Context Autoencoder for Self-Supervised Representation Learning

Xiaokang Chen, Mingyu Ding, Xiaodi Wang, Ying Xin, Shentong Mo, Yunhao Wang, Shumin Han, Ping Luo, Gang Zeng, Jingdong Wang

Introduction

We study the masked image modeling (MIM) task for self-supervised representation learning. It aims to learn an encoder by masking some patches of the input image and predicting the masked patches from the visible patches. The expectation is that the encoder pretrained on the MIM task extracts patch representations carrying semantics that transfer to downstream tasks.

Typical MIM methods, such as BEiT bao2021beit , the method studied in the ViT paper DosovitskiyB0WZ21 , and iBOT zhou2021ibot , use a single ViT architecture to solve the pretraining task, i.e., reconstructing the patch tokens or the pixel colors. These methods mix two tasks: learning the encoder (representation) and reconstructing the masked patches. The subsequent method, masked autoencoder (MAE) he2021masked , adopts an encoder-decoder architecture that only partially decouples the two tasks. Because the two tasks remain at least partly mixed, the representation quality is limited. Most previous methods, except iBOT zhou2021ibot , lack explicit modeling of the relation between the encoded representations of visible patches and those of masked patches.

We present a context autoencoder (CAE) approach, illustrated in Figure 1, for improving the encoding quality. We pretrain the encoder by making predictions for the masked patches in the encoded representation space. The pretraining task combines masked representation prediction and masked patch reconstruction. The pretraining network is an encoder-regressor-decoder architecture. The encoder takes only the visible patches as input and learns representations only for them. The regressor predicts the masked patch representations from the visible patch representations; these predictions are expected to align with the representations of the masked patches computed by the encoder. The decoder reconstructs the masked patches from the predicted masked patch representations without receiving the representations of the visible patches.

The prediction in the encoded representation space, from the visible patches to the masked patches, generates a plausible semantic guess for the masked patches that lies in the same semantic space as the visible patches. We assume that the prediction is easier if the encoded representations carry higher-level semantics, and that accurate prediction in turn encourages the encoded representations to carry more semantics; this is empirically validated by the experiments.

The CAE design also encourages separating the learning of the encoder from the completion of the pretraining tasks: the responsibility for representation learning is mainly taken by the encoder, and the encoder serves only representation learning. The reasons are as follows: the encoder in the top stream of Figure 1 operates only on visible patches and thus focuses solely on learning semantic representations; the regression is performed in the encoded representation space, as a mapping between the representations of the visible patches and those of the masked patches; and the decoder operates only on the predicted representations of the masked patches.

We present the empirical performance of our approach on downstream tasks: semantic segmentation, object detection and instance segmentation, and classification. The results show that our approach outperforms supervised pretraining, contrastive self-supervised pretraining, and other MIM methods.

Related Work

Self-supervised representation learning has been widely studied in computer vision, including context prediction CarlDoersch2015UnsupervisedVR ; tian2021semantic , clustering-based methods xie2016unsupervised ; yang2016joint ; caron2018deep ; asano2019self ; zhuang2019local ; huang2019unsupervised ; caron2019unsupervised ; PriyaGoyal2021SelfsupervisedPO , contrastive self-supervised learning li2020prototypical ; AaronvandenOord2018RepresentationLW ; henaff2020data ; wang2022repre , instance discrimination dosovitskiy2014discriminative ; dosovitskiy2015discriminative , image discretization gidaris2020learning ; gidaris2020online , masked image modeling li2021mst ; fang2022corrupted ; tian2022beyond , and information maximization ermolov2021whitening ; zbontar2021barlow ; bardes2021vicreg . The following mainly reviews closely related methods.

Autoencoding. Traditionally, autoencoders were used for dimensionality reduction or feature learning phdthesis_LeCun ; gallinari1987memoires ; hinton1994autoencoders ; hinton2006reducing ; ranzato2007efficient ; vincent2008extracting ; kingma2013auto . The denoising autoencoder (DAE) is an autoencoder that receives a corrupted data point as input and is trained to estimate the original, uncorrupted data point as its output. Variants of the DAE have been adopted for self-supervised representation learning, e.g., corruption by masking pixels VincentLLBM10 ; pathak2016context ; chen2020generative , removing color channels zhang2016colorful , shuffling image patches noroozi2016unsupervised , denoising pixel-level noise atito2021sit , and so on.

Contrastive self-supervised learning. In this paper, contrastive self-supervised learning refers to self-supervised approaches that compare random views with a contrastive loss or simply an MSE loss, which are related as shown in GarridoCBNL22 . It has been popular for self-supervised representation learning ChenK0H20 ; He0WXG20 ; YonglongTian2020WhatMF ; ChenXH21 ; grill2020bootstrap ; CaronTMJMBJ21 ; chen2021exploring ; caron2020unsupervised_swav ; wu2018unsupervised ; XiangyuPeng2022CraftingBC . The basic idea is to maximize the similarity between views augmented from the same image and, optionally, to minimize the similarity between views augmented from different images. Random cropping is an important augmentation scheme, and thus typical contrastive self-supervised learning methods (e.g., MoCo v3) tend to learn knowledge mainly from the central regions of the original images. Some dense variants wang2021dense ; xie2021propagate mitigate this tendency to a limited degree by adding a contrastive loss over dense patches.
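As an illustration of the basic idea, the following is a minimal plain-Python sketch of an InfoNCE-style contrastive loss. It is not the exact objective of any method cited above; the cosine similarity, the temperature of 0.2, and treating the other images' views in the batch as negatives are assumptions made for illustration.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


def info_nce(queries, keys, temperature=0.2):
    """InfoNCE-style loss (illustrative; temperature is an assumed value).

    queries[i] and keys[i] are two augmented views of image i (the positive
    pair); keys[j] for j != i come from other images and act as negatives.
    Returns the mean loss over the batch.
    """
    total = 0.0
    for i, q in enumerate(queries):
        logits = [cosine(q, k) / temperature for k in keys]
        # log-sum-exp with max subtraction for numerical stability
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_denom - logits[i]  # -log softmax probability of the positive
    return total / len(queries)
```

Views that agree with their positive partner yield a low loss, while pairing each view with another image's view yields a high one; this is exactly the pressure toward view-invariant representations described above.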

Masked image modeling. Motivated by BERT for masked language modeling DevlinCLT19 , the method studied in DosovitskiyB0WZ21 and BEiT bao2021beit use the ViT structure to solve the masked image modeling task, e.g., estimating the pixels or the discrete tokens. The follow-up work, iBOT zhou2021ibot , combines the MIM method (BEiT) with a contrastive self-supervised approach (DINO CaronTMJMBJ21 ). However, these methods do not explicitly have an encoder for representation learning or a decoder for pretraining task completion; the ViT structure is essentially a mixture of encoder and decoder, which limits the representation learning quality.

Several subsequent MIM methods have been developed to improve the encoder quality, such as designing pretraining architectures: Masked Autoencoder (MAE) he2021masked , SplitMask el2021large , and Simple MIM (SimMIM) xie2021simmim ; and adopting new reconstruction targets: Masked Feature Prediction (MaskFeat) wei2021masked , Perceptual Codebook for BEiT (PeCo) dong2021peco , and data2vec BaevskiHXBGA22 . The technical report of our approach (https://arxiv.org/abs/2202.03026) was initially published as an arXiv paper CAE2022 and was concurrent with data2vec BaevskiHXBGA22 , MAE he2021masked , and other methods such as el2021large ; xie2021simmim . Since then, MIM methods have developed rapidly, e.g., extension to the frequency/semantic domain xie2022masked ; liu2022devil ; wei2022mvp ; li2022mc , combination with contrastive self-supervised learning tao2022siamese ; jing2022masked ; yi2022masked ; huang2022contrastive , efficient pretraining zhang2022hivit ; huang2022green ; chen2022efficient , mask strategy design kakogeorgiou2022hide ; li2022semmae ; li2022uniform , scalability of MIM xie2022data , and interpretation of MIM xie2022revealing ; li2022architecture ; kong2022understanding .

The core idea of our approach is to make predictions in the encoded representation space. We jointly solve two pretraining tasks: masked representation prediction, which predicts the representations of the masked patches in the representation space output by the encoder, and masked patch reconstruction, which reconstructs the masked patches.

Our approach is clearly different from MAE he2021masked (Figure 2 (top)). Our approach introduces an extra pretraining task, masked representation prediction, and encourages the separation of two roles: learning the encoder and completing pretraining tasks; in contrast, MAE partially mixes the two roles, and has no explicit prediction of masked patch representations.

On the other hand, our approach differs from data2vec BaevskiHXBGA22 and iBOT zhou2021ibot (Figure 2 (bottom)). Similar to BEiT, data2vec and iBOT have no explicit module separation between learning the encoder and estimating the masked patch representations, and the target representations are formed from the full view (as the teacher) with the same network as the student network, which processes the masked view and predicts the masked patch representations (except for a centering process for the teacher in iBOT, following DINO). In contrast, our approach is simple: it forms the target representations merely from the output of the encoder, and the encoder-regressor design is straightforward and explainable: the regressor predicts the representations of masked patches to match the representations computed directly by the encoder.

Approach

Our context autoencoder (CAE) is a masked image modeling approach. The network shown in Figure 1 is an encoder-regressor-decoder architecture. The key is to make predictions from visible patches to masked patches in the encoded representation space. The pretraining tasks include: masked representation prediction and masked patch reconstruction.

We randomly split an image into two sets of patches: visible patches $\mathbf{X}_v$ and masked patches $\mathbf{X}_m$. The encoder takes the visible patches as input; the regressor predicts the representations of the masked patches from the representations of the visible patches, conditioned on the positions of the masked patches, and these predictions are expected to align with the representations computed by the encoder; the decoder reconstructs the masked patches from the predicted encoded representations.

Encoder. The encoder $\mathcal{F}$ maps the visible patches $\mathbf{X}_v$ to the latent representations $\mathbf{Z}_v$ and handles only the visible patches. We use a ViT to form the encoder. It first embeds the visible patches by linear projection into patch embeddings and adds the positional embeddings $\mathbf{P}_v$. It then feeds the combined embeddings into a sequence of self-attention-based transformer blocks, producing $\mathbf{Z}_v$.

Regressor. The latent contextual regressor $\mathcal{H}$ predicts the latent representations $\mathbf{Z}_m$ of the masked patches from the latent representations $\mathbf{Z}_v$ of the visible patches output by the encoder, conditioned on the positions of the masked patches. We form $\mathcal{H}$ from a series of transformer blocks based on cross-attention.

The initial queries $\mathbf{Q}_m$, called mask queries, are mask tokens that are learned as model parameters and are the same for all the masked patches. The keys and the values are identical before linear projection and consist of the visible patch representations $\mathbf{Z}_v$ and the output of the previous cross-attention layer (the mask queries for the first cross-attention layer). The positional embeddings of the corresponding masked patches are taken into account when computing the cross-attention weights between the queries and the keys. In this process, the latent representations $\mathbf{Z}_v$ of the visible patches are not updated.
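The structure of the regressor's cross-attention can be sketched in plain Python. This is a simplified illustration, not the authors' implementation: it assumes a single attention head, identity query/key/value projections, no residual connection, MLP, or LayerNorm, and it adds positional embeddings to both queries and keys. What it does reflect from the text is the structure: one shared mask token as the initial query for every masked patch, keys and values built from $\mathbf{Z}_v$ together with the previous layer's output, and $\mathbf{Z}_v$ left unchanged. The default of 4 layers follows the Experiments section.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def add(a, b):
    return [x + y for x, y in zip(a, b)]


def regressor_layer(q_m, z_v, pos_m, pos_v):
    """One simplified cross-attention layer; returns updated masked-query states."""
    dim = len(z_v[0])
    kv = z_v + q_m              # keys and values share content before projection
    kv_pos = pos_v + pos_m      # positions of visible patches, then masked patches
    keys = [add(k, p) for k, p in zip(kv, kv_pos)]
    out = []
    for q, p in zip(q_m, pos_m):
        query = add(q, p)       # the masked patch's position enters the weights
        weights = softmax([sum(a * b for a, b in zip(query, k)) / math.sqrt(dim)
                           for k in keys])
        # weighted sum of the values (no positional embedding on the values)
        out.append([sum(w * v[d] for w, v in zip(weights, kv)) for d in range(dim)])
    return out


def regressor(z_v, mask_token, pos_m, pos_v, num_layers=4):
    """Stack of layers; z_v is only read, never modified."""
    q_m = [list(mask_token) for _ in pos_m]  # same learned token for every masked patch
    for _ in range(num_layers):
        q_m = regressor_layer(q_m, z_v, pos_m, pos_v)
    return q_m
```

Each masked position starts from the same token and becomes distinct only through its positional embedding and the visible context, which is how the regressor conditions its predictions on the positions of the masked patches.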

Decoder. The decoder $\mathcal{G}$ maps the latent representations $\mathbf{Z}_m$ of the masked patches to a target form of the masked patches, $\mathbf{Y}_m$. Like the encoder, the decoder is a stack of self-attention-based transformer blocks, followed by a linear layer that predicts the targets. The decoder receives as input only the latent representations of the masked patches (the output of the latent contextual regressor) and the positional embeddings of the masked patches, without directly using any information from the visible patches.
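Putting the three modules together, a CAE forward pass can be sketched with the modules passed in as callables. The interface (argument order, list-based inputs) is hypothetical and chosen only to make explicit which inputs each module receives; any stand-ins plugged in for testing are placeholders, not networks.

```python
def cae_forward(patches, masked_idx, encoder, regressor, decoder):
    """One CAE forward pass over a list of patches (hypothetical interface).

    encoder(patches, positions)    -> representations of those patches
    regressor(z_v, pos_v, pos_m)   -> predicted representations at masked positions
    decoder(z_m, pos_m)            -> predictions for the masked patches
    """
    masked = set(masked_idx)
    pos_v = [i for i in range(len(patches)) if i not in masked]
    pos_m = sorted(masked)
    x_v = [patches[i] for i in pos_v]
    x_m = [patches[i] for i in pos_m]

    z_v = encoder(x_v, pos_v)           # encoder sees visible patches only
    z_m = regressor(z_v, pos_v, pos_m)  # prediction in the encoded space
    y_m = decoder(z_m, pos_m)           # decoder never sees x_v or z_v
    z_m_target = encoder(x_m, pos_m)    # alignment target from the same encoder
    return z_m, z_m_target, y_m
```

The pair `(z_m, z_m_target)` is what the alignment compares, and `y_m` is what gets compared with the reconstruction targets.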

Objective Function

Masking. Following BEiT bao2021beit , we adopt the random block-wise masking strategy (illustrated in Figure 3) to split the input image into two sets of patches, visible and masked. For each image, 98 of the 196 ($14\times 14$) patches are masked.
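A simplified sketch of block-wise masking on the 14×14 patch grid follows. It is written in the spirit of BEiT's strategy (random rectangular blocks with random area and aspect ratio) but is not BEiT's code: the minimum block size of 16, the aspect-ratio range, the number of attempts, and the final top-up with single random patches (which guarantees exactly 98 masked patches) are assumptions of this sketch.

```python
import math
import random


def block_wise_mask(height=14, width=14, num_masked=98, min_block=16,
                    aspect_range=(0.3, 1 / 0.3), max_tries=50, rng=random):
    """Return a height x width boolean grid with exactly num_masked True entries."""
    mask = [[False] * width for _ in range(height)]
    count = 0
    log_lo, log_hi = math.log(aspect_range[0]), math.log(aspect_range[1])
    for _ in range(max_tries):
        remaining = num_masked - count
        if remaining <= 0:
            break
        # random block area and log-uniform aspect ratio
        area = rng.uniform(min(min_block, remaining), remaining)
        ratio = math.exp(rng.uniform(log_lo, log_hi))
        h = int(round(math.sqrt(area * ratio)))
        w = int(round(math.sqrt(area / ratio)))
        if not (0 < h <= height and 0 < w <= width):
            continue
        top = rng.randint(0, height - h)
        left = rng.randint(0, width - w)
        new = [(r, c) for r in range(top, top + h) for c in range(left, left + w)
               if not mask[r][c]]
        # accept the block only if it does not overshoot the budget
        if 0 < len(new) <= remaining:
            for r, c in new:
                mask[r][c] = True
            count += len(new)
    # top up with single random patches so the count is exactly num_masked
    unmasked = [(r, c) for r in range(height) for c in range(width) if not mask[r][c]]
    for r, c in rng.sample(unmasked, num_masked - count):
        mask[r][c] = True
    return mask
```

Passing `rng=random.Random(seed)` makes a mask reproducible. Masking contiguous blocks rather than scattered patches removes the shortcut of interpolating a masked patch from its immediate visible neighbors.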

Targets. The targets $\bar{\mathbf{Z}}_m$ for the representations of the masked patches are formed as follows. We feed the masked patches $\mathbf{X}_m$ into the encoder, which is the same as the one used to encode the visible patches, and generate the representations $\bar{\mathbf{Z}}_m$ of the masked patches as the representation targets.

The targets $\bar{\mathbf{Y}}_m$ for the patch reconstruction are formed by a discrete tokenizer, e.g., the tokenizer trained with d-VAE on ImageNet-1K without using the labels, or the DALL-E tokenizer (trained with d-VAE on 400M images) RameshPGGVRCS21 used in BEiT bao2021beit . The input image is fed into the tokenizer, which assigns a discrete token to each patch to form the reconstruction targets $\bar{\mathbf{Y}}_m$.
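The targets are specified here, but the loss functions are not. A minimal sketch of one plausible combined objective follows. It assumes a cross-entropy loss between the decoder's logits and the discrete token targets $\bar{\mathbf{Y}}_m$, an MSE alignment loss between $\mathbf{Z}_m$ and $\bar{\mathbf{Z}}_m$ (the MSE option is mentioned in the discussion of representation alignment), and a hypothetical weighting coefficient `lam`.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one categorical prediction against an integer token id."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def cae_loss(token_logits, token_targets, z_pred, z_target, lam=1.0):
    """Reconstruction loss plus lam * alignment loss (illustrative sketch).

    token_logits[i]:  decoder logits over the tokenizer vocabulary for masked patch i
    token_targets[i]: token id assigned to masked patch i by the tokenizer
    z_pred[i], z_target[i]: regressor prediction / encoder target for masked patch i
    lam: hypothetical weight between the two terms
    """
    rec = sum(cross_entropy(lg, t)
              for lg, t in zip(token_logits, token_targets)) / len(token_targets)
    align = sum(mse(p, t) for p, t in zip(z_pred, z_target)) / len(z_target)
    return rec + lam * align
```

Both terms are averaged over the masked patches only, matching the fact that reconstruction and alignment are defined only for $\mathbf{X}_m$.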

Discussions

Predictions are made in the encoded representation space. Our CAE makes predictions in the encoded representation space: it predicts the representations of the masked patches from the encoded representations of the visible patches. In other words, the output representations of the latent contextual regressor are expected to lie in the encoded representation space as well, which is ensured by the prediction alignment. This encourages the learned representations to carry substantial semantics for predicting the masked patches from the visible patches, benefiting the representation learning of the encoder.

We empirically verify through image reconstruction that the predicted representations lie in the encoded representation space. We train the CAE with pixel colors as the prediction targets in two settings: with and without the alignment (i.e., masked representation prediction). For reconstruction, we feed all the patches of an image from the ImageNet validation set (without masking, so all patches are visible) into the pretrained encoder, skip the latent contextual regressor, and send all the encoded patch representations directly to the pretrained decoder to reconstruct the whole image.

Figure 4 shows reconstruction results for several examples randomly sampled from the ImageNet-1K validation set. Our approach successfully reconstructs the images, implying that the input and output representations of the latent contextual regressor lie in the same space. In contrast, without the alignment, the reconstructed images are noisy, indicating that the input and output representations of the latent contextual regressor lie in different spaces. The results suggest that explicit prediction alignment is critical for ensuring that predictions are made in the encoded representation space.

Representation alignment in CAE and contrastive self-supervised learning. Representation alignment is also used in contrastive self-supervised learning methods, such as MoCo, BYOL, and SimCLR, and in methods mixing contrastive self-supervised learning and masked image modeling, such as iBOT and MST. The alignment loss can be an MSE loss or a contrastive loss, either of which CAE could also use.

In the CAE, the alignment is imposed over the representations $\mathbf{Z}_m=\mathcal{H}(\mathcal{F}(\mathbf{X}_v))$, predicted from the representations $\mathcal{F}(\mathbf{X}_v)$ of the visible patches through the regressor $\mathcal{H}$, and the representations $\bar{\mathbf{Z}}_m=\mathcal{F}(\mathbf{X}_m)$, computed by the encoder $\mathcal{F}$. Both $\mathbf{Z}_m$ and $\bar{\mathbf{Z}}_m$ are about the masked patches and lie in the representation space output by the encoder.

Differently, the alignment in most contrastive self-supervised learning methods is imposed over the representations $\{\mathcal{P}(\mathcal{F}(\mathbf{V}_1)),\mathcal{P}(\mathcal{F}(\mathbf{V}_2)),\cdots,\mathcal{P}(\mathcal{F}(\mathbf{V}_N))\}$, where $\mathcal{P}$ is a projector, and some views may be processed with the EMA version of the encoder and the projector. The $N$ representations to be aligned are about different views $\{\mathbf{V}_1,\mathbf{V}_2,\cdots,\mathbf{V}_N\}$ (in iBOT and MST, the views are masked views and full views) and are not directly output by the encoder. How the projector works is not entirely clear; MIMPart2022 reports that, in contrastive self-supervised learning, the projector performs a part-to-whole mapping from the representation of an object part to the representation of the whole object.

Connection

Relation to autoencoder. The original autoencoder phdthesis_LeCun ; gallinari1987memoires ; hinton1994autoencoders consists of an encoder and a decoder. The encoder maps the input into a latent representation, and the decoder reconstructs the input from the latent representation. The denoising autoencoder (DAE) VincentLLBM10 , a variant of the autoencoder, corrupts the input by adding noise and still reconstructs the uncorrupted input.

Our CAE is similar to the original autoencoder in that it also contains an encoder and a decoder. Unlike the autoencoder, where the encoder and the decoder process the whole image, our encoder takes one portion of the patches as input, and our decoder takes the estimated latent representations of the other portion of the patches as input. Importantly, the CAE makes predictions in the latent space from the visible patches to the masked patches.

Relation to BEiT, iBOT and MAE. The CAE encoder processes the visible patches to extract their representations, without making predictions for masked patches. Masked representation prediction is made through the regressor and the prediction alignment, which ensures that the output of the regressor lies in the same representation space as the encoder output. The decoder processes only the predicted representations of masked patches. Our approach thus encourages the encoder to take responsibility for, and only for, representation learning.

In contrast, BEiT bao2021beit and the MIM part of iBOT do not separate the representation extraction role from the task completion role; they use a single network, with both the visible and masked patches as input, for both roles simultaneously. In MAE he2021masked , the so-called decoder may play a partial role in representation learning, as the representations of the visible patches are also updated in the MAE decoder. Unlike CAE, MAE, iBOT, and BEiT do not explicitly predict the representations of masked patches from the representations of visible patches in the encoded representation space.

When the pretrained encoder is applied to downstream tasks, the pretext task completion part is often replaced with a downstream task layer, e.g., a segmentation or detection layer. Separating representation learning (encoding) from pretext task completion helps downstream applications take good advantage of representation pretraining.

We provide the computational graphs of CAE, BEiT bao2021beit , the denoising autoencoder, Masked Autoencoder he2021masked , and SplitMask el2021large (one stream) in Figure 5. Compared to our CAE, the main issue of MAE is that the so-called decoder $\mathcal{R}$ might also take on the encoding role, i.e., learning semantic representations of the visible patches.

Comparison to contrastive self-supervised learning. Typical contrastive self-supervised learning methods, e.g., SimCLR ChenK0H20 and MoCo He0WXG20 ; ChenXH21 , pretrain the networks by solving the pretext task, maximizing the similarities between augmented views (e.g., random crops) from the same image and minimizing the similarities between augmented views from different images.

It is shown in ChenK0H20 that random cropping plays an important role in view augmentation for contrastive self-supervised learning. By analyzing random crops (illustrated in Figure 3), we observe that the center pixels of the original image are very likely to be included in random crops. We suspect that the global representation learned by contrastive self-supervised learning for a random crop, possibly with other augmentation schemes, tends to focus mainly on the center pixels of the original image, so that the representations of different crops from the same image can be similar. Figure 6 (the second row) shows that the center region of the original image is highly attended by the typical contrastive self-supervised learning approach MoCo v3. The part of a random crop corresponding to the center of the original image is still attended, as shown in Figure 8.
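The observation that center pixels fall inside random crops more often than border pixels can be checked with a small Monte Carlo simulation. The sketch below samples crops in the style of torchvision's `RandomResizedCrop`, using its default `scale=(0.08, 1.0)` and `ratio=(3/4, 4/3)`; whether a particular method uses these exact ranges is an assumption, and the fallback when sampling fails is simplified to the whole image.

```python
import math
import random


def sample_crop(size, rng, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3)):
    """Sample a crop box (top, left, h, w) from a size x size image."""
    area = size * size
    for _ in range(10):
        target = area * rng.uniform(*scale)
        aspect = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
        w = int(round(math.sqrt(target * aspect)))
        h = int(round(math.sqrt(target / aspect)))
        if 0 < w <= size and 0 < h <= size:
            return rng.randint(0, size - h), rng.randint(0, size - w), h, w
    return 0, 0, size, size  # simplified fallback: the whole image


def coverage(pixel, size=224, trials=20000, seed=0):
    """Fraction of sampled crops that contain the given (row, col) pixel."""
    rng = random.Random(seed)
    r, c = pixel
    hits = 0
    for _ in range(trials):
        top, left, h, w = sample_crop(size, rng)
        hits += int(top <= r < top + h and left <= c < left + w)
    return hits / trials
```

Comparing `coverage((112, 112))` for the center pixel with `coverage((0, 0))` for a corner pixel shows a large gap in favor of the center: a corner pixel is covered only when a crop happens to touch that corner, whereas any moderately large crop covers the center.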

In contrast, our CAE method (like other MIM methods) randomly samples patches from the augmented views to form the visible and masked patches. Every patch can be randomly masked, in the augmented views and therefore in the original image. Thus, the CAE encoder needs to learn good representations for all the patches in order to make good predictions for the masked patches from the visible patches. Figure 6 (the third row) illustrates that almost all the patches in the original images are considered by our CAE encoder.

Considering that the instances of the 1000 categories in ImageNet-1K are located mainly around the center of the original images russakovsky2015imagenet , typical contrastive self-supervised learning methods, e.g., MoCo v3, learn knowledge mainly about the 1000 categories, which is similar to supervised pretraining. In contrast, our CAE and other MIM methods are able to learn additional knowledge beyond the 1000 categories from the non-center image regions. This suggests that the CAE has the potential to perform better on downstream tasks.

Interpretation

Intuitive interpretation for CAE. Humans are able to hallucinate what appears in the masked regions, and how it appears, from the visible regions. We speculate that humans possibly do this in a way similar to the following example: given that only the region of the dog's head is visible and the remaining parts are missing, one can (a) recognize the visible region to be about a dog, (b) predict the regions where the other parts of the dog appear, and (c) guess what the other parts look like.

Our CAE encoder is in some sense like the human recognition step (a). It understands the content by mapping the visible patches into latent representations that lie in the subspace corresponding to the category dog. (Our encoder does not know that the subspace is about a dog; it just separates it from the subspaces of other categories.) The latent contextual regressor is like step (b). It produces a plausible hypothesis for the masked patches and describes the regions corresponding to the other parts of the dog using latent representations. The CAE decoder is like step (c), mapping the latent representations to the targets. Note that the latent representations might contain other information besides the semantic information, e.g., the part information and the information for making predictions.

We adopt t-SNE van2008visualizing to visualize the high-dimensional patch representations output by our CAE encoder on ADE20K zhou2017scene in Figure 7. ADE20K has a total of 150 categories. For each patch in the image, we set its label to be the category to which more than half of its pixels belong. We collect up to 1000 patches for each category from 500 sampled images. As shown in the figure, the latent representations of CAE are clustered to some degree by category (though not perfectly, as our CAE is pretrained on ImageNet-1K). Similar observations can be made for other MIM methods.
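The patch-labeling and collection steps can be sketched as follows, assuming per-pixel category ids are available for each 16×16 patch. Patches in which no category covers more than half of the pixels are skipped; this is our assumption, since the text does not say how such patches are handled. The collected representations would then be passed to a t-SNE implementation, for example scikit-learn's `sklearn.manifold.TSNE`.

```python
from collections import Counter


def patch_label(pixel_labels):
    """Category covering more than half of a patch's pixels, else None.

    pixel_labels: flat list of per-pixel category ids for one patch.
    """
    label, count = Counter(pixel_labels).most_common(1)[0]
    return label if count * 2 > len(pixel_labels) else None


def collect(images, cap=1000):
    """Group patch representations by label, keeping at most `cap` per category.

    images: iterable of (patch_reps, patch_pixel_labels) pairs, one per image.
    """
    buckets = {}
    for reps, label_maps in images:
        for rep, pixels in zip(reps, label_maps):
            lab = patch_label(pixels)
            if lab is not None and len(buckets.setdefault(lab, [])) < cap:
                buckets[lab].append(rep)
    return buckets
```

The strict majority rule means a patch split exactly 50/50 between two categories receives no label, which keeps ambiguous boundary patches out of the visualization.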

Probabilistic interpretation for CAE. The MIM problem can be formulated in probabilistic form: maximize the probability of the predictions $\mathbf{Y}_m$ of the masked patches given the conditions, namely the visible patches $\mathbf{X}_v$, the positions $\mathbf{P}_v$ of the visible patches, and the positions $\mathbf{P}_m$ of the masked patches: $P(\mathbf{Y}_m\mid\mathbf{X}_v,\mathbf{P}_v,\mathbf{P}_m)$. It can be solved by introducing latent representations $\mathbf{Z}_m$ and $\mathbf{Z}_v$, with the assumption that $\mathbf{Z}_v$ and $\mathbf{P}_m$ (respectively, $\mathbf{Y}_m$ and $\mathbf{P}_v$) are conditionally independent (the probabilistic graphical model is given in Figure 9).
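Written out, the factorization takes the following form. This is a reconstruction consistent with the description in the next paragraph; it treats $\mathbf{Z}_v$ and $\mathbf{Z}_m$ as deterministic outputs of the encoder and the regressor, so that the step from (1) to (2) needs no explicit marginalization over the latents, which is an assumption.

```latex
\begin{align}
& P(\mathbf{Y}_m \mid \mathbf{X}_v, \mathbf{P}_v, \mathbf{P}_m) \tag{1}\\
&= p(\mathbf{Y}_m, \mathbf{Z}_m, \mathbf{Z}_v \mid \mathbf{X}_v, \mathbf{P}_v, \mathbf{P}_m) \tag{2}\\
&= p(\mathbf{Z}_v \mid \mathbf{X}_v, \mathbf{P}_v, \mathbf{P}_m)\,
   p(\mathbf{Z}_m \mid \mathbf{Z}_v, \mathbf{P}_v, \mathbf{P}_m)\,
   p(\mathbf{Y}_m \mid \mathbf{Z}_m, \mathbf{P}_v, \mathbf{P}_m) \tag{3}\\
&= p(\mathbf{Z}_v \mid \mathbf{X}_v, \mathbf{P}_v)\,
   p(\mathbf{Z}_m \mid \mathbf{Z}_v, \mathbf{P}_v, \mathbf{P}_m)\,
   p(\mathbf{Y}_m \mid \mathbf{Z}_m, \mathbf{P}_m) \tag{4}
\end{align}
```

In (4), the first factor involves only visible patches and their positions (encoder), the second maps visible representations to masked ones given both sets of positions (regressor), and the third maps masked representations to predictions given only the masked positions (decoder).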

Here, the step from (2) to (3) follows from the probabilistic graphical model of CAE shown in Figure 9, and the step from (3) to (4), which removes the condition $\mathbf{P}_m$ (from $p(\mathbf{Z}_v\mid\mathbf{X}_v,\mathbf{P}_v,\mathbf{P}_m)$ to $p(\mathbf{Z}_v\mid\mathbf{X}_v,\mathbf{P}_v)$) and the condition $\mathbf{P}_v$ (from $p(\mathbf{Y}_m\mid\mathbf{Z}_m,\mathbf{P}_v,\mathbf{P}_m)$ to $p(\mathbf{Y}_m\mid\mathbf{Z}_m,\mathbf{P}_m)$), is based on the conditional independence assumption. The three terms in (4) correspond to the three parts of our CAE: the encoder, the latent contextual regressor, and the decoder, respectively.

Similarly, the latent representation alignment constraint can be written as a conditional probability, $P(\mathbf{Z}_m\mid\bar{\mathbf{Z}}_m)$, where $\bar{\mathbf{Z}}_m$ denotes the masked patch representations computed by the encoder.

Intuitive interpretation for contrastive self-supervised learning. We consider the case in ImageNet-1K where the object mainly lies in the center of an image. (There are a few ImageNet-1K images in which the object does not lie in the center; they can be viewed as noise and have little influence on contrastive self-supervised learning.) Suppose $N$ crops are randomly sampled from an image, and each crop $\mathbf{I}_n$ contains a part $\mathbf{O}_n$ of the center object. To maximize the similarity between two crops $\mathbf{I}_m$ and $\mathbf{I}_n$, the pretraining might involve the following processes: select the regions $\mathbf{O}_m$ and $\mathbf{O}_n$ from the two crops $\mathbf{I}_m$ and $\mathbf{I}_n$, extract their features $\mathbf{f}_{om}$ and $\mathbf{f}_{on}$, and predict the feature of the object, $\mathbf{f}_o$, from the part features $\mathbf{f}_{om}$ and $\mathbf{f}_{on}$. In this way, the features of crops from the same image can be similar. Among the $N$ random crops, most contain a part of the center object, and the few that do not can be viewed as noise when optimizing the contrastive loss.

After pretraining on ImageNet-1K (where the object mainly lies in the center), the encoder is able to learn knowledge of the 1000 classes and to localize the region containing an object belonging to these classes. The object in a test image need not lie in the center, as verified in Figure 8. This further verifies that MoCo v3 (contrastive self-supervised pretraining) pretrained on ImageNet-1K tends to attend to the object region, which corresponds to the center region of the original image, as shown in Figure 6.

Experiments

We study the standard ViT small, base, and large architectures: ViT-S (12 transformer blocks with dimension 384), ViT-B (12 transformer blocks with dimension 768), and ViT-L (24 transformer blocks with dimension 1024). The latent contextual regressor consists of 4 cross-attention-based transformer blocks (self-attention over the mask tokens and the encoded visible patch representations is an alternative, but it has a slightly higher computation cost and slightly lower performance), and the decoder consists of 4 self-attention-based transformer blocks followed by an extra linear projection that makes the predictions.

Training Details

Pretraining. The pretraining settings are almost the same as in BEiT bao2021beit . We train the CAE on ImageNet-1K. We partition the $224\times 224$ image into $14\times 14$ patches with a patch size of $16\times 16$. We use standard random cropping and horizontal flipping for data augmentation. We use AdamW loshchilov2017adamw for optimization and train the CAE for 300/800/1600 epochs with a batch size of 2048. We set the learning rate to 1.5e-3 with cosine learning rate decay. The weight decay is set to 0.05. The warmup epochs for the 300/800/1600-epoch pretraining are 10/20/40, respectively. We employ drop path huang2016stochastic_depth rate 0.1 and dropout rate .
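The learning-rate schedule can be sketched as linear warmup followed by cosine decay. The peak rate of 1.5e-3 and the warmup lengths of 10/20/40 epochs for the 300/800/1600-epoch schedules come from the text; warming up linearly from zero, a final rate of zero, and evaluating the schedule at (possibly fractional) epochs are assumptions of this sketch.

```python
import math

# warmup length for each pretraining length, as stated in the text
WARMUP = {300: 10, 800: 20, 1600: 40}


def lr_at(epoch, total_epochs=300, warmup_epochs=10, peak_lr=1.5e-3, min_lr=0.0):
    """Linear warmup, then cosine decay from peak_lr to min_lr (assumed 0)."""
    if epoch < warmup_epochs:
        return peak_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

For example, `lr_at(e, total_epochs=800, warmup_epochs=WARMUP[800])` gives the rate for the 800-epoch run; calling it with fractional epochs allows per-iteration updates.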

Linear probing. We use the LARS you2017large optimizer with momentum 0.9. The model is trained for 90 epochs. The batch size is 16384, the number of warmup epochs is 10, and the learning rate is 6.4. Following he2021masked , we adopt an extra BatchNorm layer SergeyIoffe2015BatchNA without affine transformation (affine=False) before the linear classifier. We do not use mixup HongyiZhang2017mixupBE , cutmix SangdooYun2019CutMixRS , drop path huang2016stochastic_depth , or color jittering, and we set the weight decay to zero.
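The BatchNorm layer without affine transformation standardizes each feature with batch statistics and has no learnable scale or shift. A training-mode sketch in plain Python follows; running statistics for evaluation are omitted, and `eps=1e-5` matches PyTorch's default.

```python
import math


def batchnorm_no_affine(batch, eps=1e-5):
    """Standardize each feature over the batch (affine=False, training mode).

    batch: list of equal-length feature vectors. Uses the biased batch variance.
    """
    n, dim = len(batch), len(batch[0])
    means = [sum(x[d] for x in batch) / n for d in range(dim)]
    variances = [sum((x[d] - means[d]) ** 2 for x in batch) / n for d in range(dim)]
    return [[(x[d] - means[d]) / math.sqrt(variances[d] + eps) for d in range(dim)]
            for x in batch]
```

In PyTorch, the corresponding layer would be `torch.nn.BatchNorm1d(dim, affine=False)`, placed before the `torch.nn.Linear` classifier. Normalizing the frozen features this way removes differences in feature scale between pretrained encoders without adding trainable parameters.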

Attentive probing. The parameters of the encoder are fixed during attentive probing. A cross-attention module, a BatchNorm layer (affine=False), and a linear classifier are appended after the encoder. The extra class token representation in cross-attention is learned as model parameters. The keys and the values are the patch representations output by the encoder. There is no MLP or skip connection in the extra cross-attention module. We use the SGD optimizer with momentum 0.9 and train the model for 90 epochs. The batch size is 8192, the number of warmup epochs is 10, and the learning rate is 0.4. As in linear probing, we do not use mixup HongyiZhang2017mixupBE , cutmix SangdooYun2019CutMixRS , drop path, or color jittering, and we set the weight decay to zero.
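A plain-Python sketch of the attentive-probing head follows. A learned class query attends over the frozen patch representations, and the attention-pooled vector is fed to a linear classifier, with no MLP or skip connection. For brevity the sketch assumes a single head and identity key/value projections, and it omits the affine-free BatchNorm (which needs batch statistics); these are simplifications, not the evaluated configuration.

```python
import math


def attentive_probe(patch_reps, cls_query, weights, bias):
    """Cross-attention pooling with a learned query, then a linear classifier.

    patch_reps: frozen encoder outputs, one vector per patch (keys and values)
    cls_query:  extra learned class token (distinct from the encoder's own)
    weights, bias: linear classifier parameters, one row / entry per class
    Returns (attention weights over patches, class logits).
    """
    dim = len(cls_query)
    scores = [sum(q * k for q, k in zip(cls_query, p)) / math.sqrt(dim)
              for p in patch_reps]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    attn = [e / total for e in exps]
    # attention-weighted pooling; no residual, no MLP
    pooled = [sum(a * p[d] for a, p in zip(attn, patch_reps)) for d in range(dim)]
    logits = [sum(w * x for w, x in zip(row, pooled)) + b
              for row, b in zip(weights, bias)]
    return attn, logits
```

Because only the query and the classifier are trainable, the probe can still concentrate on label-relevant patches, which average pooling cannot do.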

Fine-tuning on ImageNet. We follow the fine-tuning protocol of BEiT and use layer-wise learning rate decay, weight decay, and AdamW. The batch size is 4096, the number of warmup epochs is 5, and the weight decay is 0.05. For ViT-S, we train for 200 epochs with learning rate 1.6e-2 and layer-wise decay rate 0.75. For ViT-B, we train for 100 epochs with learning rate 8e-3 and layer-wise decay rate 0.65. For ViT-L, we train for 50 epochs with learning rate 2e-3 and layer-wise decay rate 0.75.
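Layer-wise learning-rate decay assigns smaller learning rates to layers closer to the input. A minimal sketch following the indexing commonly used in BEiT-style fine-tuning (embeddings at index 0, transformer blocks at 1 to L, the head at L+1) follows; treat the exact indexing as an assumption.

```python
def layer_lr_scales(num_layers, decay):
    """Per-layer learning-rate multipliers for layer-wise decay.

    Index 0: patch/position embeddings; 1..num_layers: transformer blocks;
    num_layers + 1: classification head, which keeps the full rate.
    """
    return [decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
```

For ViT-B with base rate 8e-3 and decay 0.65, the head trains at 8e-3, while the patch embedding is scaled by 0.65^13, roughly 0.0037, i.e., a rate of about 3e-5. The pretrained low-level layers thus change little during fine-tuning.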

Semantic segmentation on ADE20K. We use AdamW as the optimizer. The input resolution is $512\times 512$. The batch size is 16. For ViT-B, the layer-wise decay rate is 0.65 and the drop path rate is 0.1. We search over four learning rates, 1e-4, 2e-4, 3e-4, and 4e-4, for all the results in Table 2. For ViT-L, the layer-wise decay rate is 0.95 and the drop path rate is 0.15. We search over three learning rates for all the methods: 3e-5, 4e-5, and 5e-5. We conduct fine-tuning for 160K steps. We do not use multi-scale testing.

Object detection and instance segmentation on COCO. We utilize multi-scale training and resize the image so that the short side is between 480 and 800 and the long side is no larger than 1333. The batch size is 32. For ViT-S, the learning rate is 3e-4, the layer-wise decay rate is 0.75, and the drop path rate is 0.1. For ViT-B, the learning rate is 3e-4, the layer-wise decay rate is 0.75, and the drop path rate is 0.2. For ViT-L, the learning rate is 2e-4, the layer-wise decay rate is 0.8, and the drop path rate is 0.2. We train the network with the $1\times$ schedule: 12 epochs with the learning rate decayed by $10\times$ at epochs 9 and 11. We do not use multi-scale testing. The Mask R-CNN implementation follows MMDetection mmdetection .
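The size computation for multi-scale training can be sketched as a keep-aspect-ratio resize: the short side is set to a value in [480, 800] unless that would push the long side past 1333, in which case the long side is capped instead. How the short side is sampled (continuously or from a discrete set) is not stated, so the sketch takes it as an argument.

```python
def keep_ratio_size(height, width, short_side, max_long=1333):
    """Output (height, width) with the short side set to short_side,
    unless that would exceed max_long on the long side, which is then capped."""
    scale = short_side / min(height, width)
    if max(height, width) * scale > max_long:
        scale = max_long / max(height, width)
    return round(height * scale), round(width * scale)
```

For example, a 480×640 image resized to short side 800 becomes 800×1067, while a 500×2000 image is limited by the long-side cap and becomes 333×1333.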

3 Pretraining Evaluation

Linear probing. Linear probing is widely used as a proxy for evaluating pretraining quality in self-supervised representation learning. It trains a linear classifier on the image-level representation output by the pretrained encoder using the image labels, and then tests the performance on the validation set.

Attentive probing. The output of an encoder pretrained with MIM methods consists of representations for all the patches. It is not suitable to linearly probe the average-pooled patch representation, because the image label in ImageNet-1K corresponds to only a portion of the patches. It is also not suitable to use the default class token within the encoder, because the default class token serves to aggregate the patch representations for better patch representation extraction and is not dedicated to the portion of patches corresponding to the image label.

To use the image-level label as a proxy for evaluating the pretraining quality of an encoder pretrained with MIM methods, we need to attend to the patches that are related to the label. We introduce a simple modification: a cross-attention unit with an extra class token (different from the class token in the encoder) as the query and the output patch representations of the encoder as the keys and the values, followed by a linear classifier. The introduced cross-attention unit is able to focus mainly on the patches belonging to the 1000 classes in ImageNet-1K and remove the interference of other patches. Figure 10 illustrates the effect of the cross-attention unit, showing that it can, to some degree, attend to the regions that are related to the 1000 ImageNet-1K classes.
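The probing head described above can be sketched in plain Python. The sketch simplifies under stated assumptions: a single attention head, no learned query/key/value projections (keys and values are the raw patch representations, following the description, though the actual module may include projections), scaled dot-product weights, a BatchNorm without affine parameters computed from batch statistics, and a linear classifier. The names `cross_attention_pool`, `batchnorm_no_affine`, and `linear` are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(xs: Vector) -> Vector:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention_pool(query: Vector, patches: List[Vector]) -> Tuple[Vector, Vector]:
    """One learned query (the extra class token) attends over the patch representations.

    Keys and values are the patch representations themselves; there is no MLP
    and no skip connection, so the output is just the attention-weighted sum.
    """
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, p)) / math.sqrt(d) for p in patches]
    weights = softmax(scores)
    pooled = [sum(w * p[j] for w, p in zip(weights, patches)) for j in range(d)]
    return pooled, weights


def batchnorm_no_affine(batch: List[Vector], eps: float = 1e-5) -> List[Vector]:
    """Normalize each feature with batch statistics; no learnable scale or shift."""
    n, d = len(batch), len(batch[0])
    means = [sum(row[j] for row in batch) / n for j in range(d)]
    variances = [sum((row[j] - means[j]) ** 2 for row in batch) / n for j in range(d)]
    return [[(row[j] - means[j]) / math.sqrt(variances[j] + eps) for j in range(d)]
            for row in batch]


def linear(x: Vector, weight: List[Vector], bias: Vector) -> Vector:
    """Class logits: one row of `weight` per class."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


if __name__ == "__main__":
    query = [1.0, 0.0]                              # learned extra class token
    patches = [[5.0, 0.0], [0.0, 5.0], [0.0, 0.0]]  # frozen encoder outputs
    pooled, weights = cross_attention_pool(query, patches)
    print(weights)  # the patch aligned with the query receives the largest weight
```

In the full head, the pooled vectors of a mini-batch would pass through `batchnorm_no_affine` and then `linear`; only the query and the classifier weights would be trained, since the encoder is frozen.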

Results. Table 1 shows the results under three schemes, linear probing (LIN), attentive probing (ATT), and fine-tuning (FT), for representative contrastive self-supervised pretraining methods (MoCo v3 and DINO) and MIM methods (BEiT and MAE), as well as for our approach with targets formed by the DALL-E tokenizer (trained on 400M images) and by the d-VAE tokenizer (trained on ImageNet-1K without using the labels), denoted as CAE* and CAE, respectively. The MAE (300 epochs) and BEiT models are pretrained by us using the official implementations, and the other models are officially released models.

We highlight a few observations. The fine-tuning performance of these methods is very similar, with only minor differences, consistent with the observation in zhou2021ibot . We believe the reason is that self-supervised pretraining and fine-tuning are conducted on the same dataset, so no extra knowledge is introduced for image classification. The minor differences might come from optimization: fine-tuning starts from different initializations provided by the pretrained models.

In terms of linear probing, the scores of the contrastive self-supervised learning methods, MoCo v3 and DINO, are higher than those of the MIM methods. This is expected because contrastive self-supervised learning focuses mainly on learning the representations for the 1000 classes (see the discussion in Section 4). Its pretraining task is relatively easier than that of existing MIM methods, as contrastive self-supervised learning mainly cares about the 1000 classes, whereas MIM methods may also care about classes beyond the 1000 classes.

For the MIM methods, the attentive probing scores are much higher than the linear probing scores. This validates our analysis: MIM methods extract representations for all the patches, and the classification task needs to attend to the corresponding portion of the patches.

The LIN and ATT scores are similar for contrastive self-supervised pretraining on ViT-B, e.g., (76.2 vs 77.0) for MoCo v3 and (77.3 vs 77.8) for DINO. This means that the extra cross-attention in attentive probing does not make a big difference, which is further evidence for our analysis in Section 4 that these methods already focus mainly on the region where the instance from the 1000 categories lies.

4 Downstream Tasks

Semantic segmentation on ADE20K zhou2017scene . We follow the implementation of bao2021beit and use UperNet xiao2018unified . CAE, with the tokenizer learned over ImageNet-1K, performs almost the same as CAE*, with the tokenizer learned over 400M images provided by DALL-E, implying that whether the tokenizer is trained on ImageNet-1K (without using the labels) or on a larger dataset does not affect the pretraining quality and, accordingly, the downstream task performance.

Table 2 shows that, using ViT-B, our CAE* with 300 training epochs performs better than DeiT, MoCo v3, DINO, MAE (1600 epochs), and BEiT. Our CAE* (1600 epochs) further improves the segmentation scores and outperforms MAE (1600 epochs), MoCo v3, and DeiT by 2.1, 3.0, and 3.2, respectively. Using ViT-L, our CAE* (1600 epochs) outperforms BEiT (1600 epochs) and MAE (1600 epochs) by 1.4 and 1.1, respectively.

The superior results over the supervised and contrastive self-supervised pretraining methods, DeiT, MoCo v3, and DINO, stem from the fact that our approach captures knowledge beyond the 1000 classes in ImageNet-1K. The superior results over BEiT and MAE stem from the fact that CAE makes predictions in the encoded representation space and separates representation learning from pretext task completion.

Object detection and instance segmentation on COCO lin2014microsoft . We adopt the Mask R-CNN approach he2017mask , which produces bounding boxes and instance masks simultaneously, with the ViT as the backbone. The results are given in Table 3. We report the box AP for object detection and the mask AP for instance segmentation. The observations are consistent with those for semantic segmentation in Table 2. Our CAE* (300 epochs, ViT-B) is superior to all the other models except MAE (1600 epochs), to which it is slightly inferior. Our approach (1600 epochs) outperforms MAE (1600 epochs), MoCo v3, and DeiT by 1.6, 4.5, and 3.1, respectively. Using ViT-L, our CAE achieves 54.6 box AP and outperforms MAE by 0.6.

We also report the results of object detection and instance segmentation on COCO with the Cascade Mask R-CNN framework ZhaoweiCai2021CascadeRH in Table 6. The results show that our CAE performs better than the other methods.

In addition, we study the scaling ability of CAE on the detection task. The detection model is built upon ViT-Huge DosovitskiyB0WZ21 , the DINO detector HaoZhang2023DINODW , and Group DETR QiangChen2022GroupDF (see groupdetrv2 for more details). The ViT-Huge is pretrained on ImageNet-22K deng2009imagenet using CAE. We are the first to obtain 64.6 mAP on COCO test-dev, outperforming previous methods that use larger models and more training data (e.g., BEiT-3 WenhuiWang2023ImageAA (63.7 mAP) and SwinV2-G ZeLiu2021SwinTV (63.1 mAP)).

Classification. We conduct fine-tuning experiments on three datasets: Food-101 bossard14 , Clipart castrejon2016learning , and Sketch castrejon2016learning . The results in Table 4 show that the proposed method outperforms the previous supervised method (DeiT) and the self-supervised methods (DINO, MAE).

5 Ablation Studies

Decoder and alignment. The CAE architecture contains several components for pretraining the encoder: the regressor and the alignment for masked representation prediction, and the decoder with a linear layer for masked patch reconstruction. We observe that if the masked patch reconstruction task is not included, the training collapses to a trivial solution. We thus study the effect of the decoder (when the decoder is removed, we use a linear layer to predict the targets), which helps target reconstruction, and of the alignment, which helps representation prediction.

Table 5 shows the ablation results. We report the scores for linear probing, attentive probing, fine-tuning, and downstream tasks: semantic segmentation on ADE20K and object detection on COCO, with the DALL-E tokenizer as the target. The downstream task performance is almost unchanged when only the decoder is added, and it increases when both the decoder and the alignment are added. This verifies that the alignment is important for ensuring that the predicted representations of masked patches lie in the encoded representation space, so that the predictions are made in that space, which in turn improves the representation quality. Without the decoder, the performance drops. This is because the reconstruction from the semantic representations to the low-level targets cannot be done by a single linear layer, and removing the decoder deteriorates the semantic quality of the encoder. The additional computational cost brought by the decoder and the alignment, i.e., the number of parameters and the training time, is relatively small: the number of parameters increases to 1.23× and the training time to 1.24×.

Mask ratio. We also conduct experiments with different mask ratios: 40%, 50%, and 60%. The results are listed in Table 7. A ratio of 50% gives better results than 40%. A higher mask ratio (60%) further improves linear probing and attentive probing, but reduces the semantic segmentation performance by 0.2%. We use 50% in our work unless otherwise specified.
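To make the mask ratios concrete: assuming a 224×224 input split into 16×16 patches, there are 14 × 14 = 196 patches, so ratios of 40%, 50%, and 60% mask about 78, 98, and 118 patches, and the encoder sees only the remaining ones. The sketch below uses uniform random masking purely for illustration; the masking strategy used for CAE pretraining may differ, and the helper names are hypothetical.

```python
import random
from typing import List, Optional, Tuple


def random_patch_mask(num_patches: int, mask_ratio: float,
                      seed: Optional[int] = None) -> List[bool]:
    """Return a boolean mask (True = masked) with round(num_patches * mask_ratio) entries set."""
    rng = random.Random(seed)
    num_masked = round(num_patches * mask_ratio)
    masked = set(rng.sample(range(num_patches), num_masked))
    return [i in masked for i in range(num_patches)]


def split_visible_masked(mask: List[bool]) -> Tuple[List[int], List[int]]:
    """Indices fed to the encoder (visible) and indices the regressor must predict (masked)."""
    visible = [i for i, m in enumerate(mask) if not m]
    masked = [i for i, m in enumerate(mask) if m]
    return visible, masked


if __name__ == "__main__":
    num_patches = (224 // 16) ** 2  # 196 patches for a 224x224 image with 16x16 patches
    for ratio in (0.4, 0.5, 0.6):
        visible, masked = split_visible_masked(random_patch_mask(num_patches, ratio, seed=0))
        print(f"ratio {ratio:.0%}: {len(visible)} visible, {len(masked)} masked")
```

Because the encoder processes only the visible patches, a higher mask ratio also shortens the encoder's input sequence during pretraining.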

Number of layers in the regressor and decoder. For the number of layers in the latent contextual regressor and decoder, we tried four choices: 1, 2, 4, and 5 layers. The linear probing results are 58.7, 62.1, 64.1, and 64.2, and the attentive probing results are 67.5, 71.1, 73.8, and 73.7, respectively. We empirically observe that the 4-layer choice performs best overall, with the 5-layer choice being comparable.

Loss tradeoff parameter. There is a tradeoff parameter λ in the loss function given in Equation 1. We did not conduct an extensive study and only tried three choices: λ = 1, λ = 1.5, and λ = 2. The linear probing results are 63.4, 63.7, and 64.1, respectively. The choice λ = 1 also works well, only slightly worse than λ = 2, which is adopted in our experiments.
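Equation 1 is not reproduced in this section. Assuming it has the form L = L_rec + λ·L_align, with a cross-entropy reconstruction loss over the target tokens of the masked patches and a mean-squared-error alignment loss between the regressor's predicted representations and the encoder's representations of the masked patches, the role of λ can be sketched as follows. The function names and the per-patch averaging are illustrative assumptions, not the paper's implementation.

```python
import math
from typing import List

Vector = List[float]


def cross_entropy(logits: Vector, target: int) -> float:
    """-log softmax(logits)[target], computed with the log-sum-exp trick."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def mse(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def cae_style_loss(decoder_logits: List[Vector], target_tokens: List[int],
                   predicted_reps: List[Vector], encoder_reps: List[Vector],
                   lam: float = 2.0) -> float:
    """Assumed form: reconstruction + lam * alignment, each averaged over masked patches.

    In an autodiff framework the encoder_reps target would typically be detached
    (stop-gradient, an assumption here); plain Python has no gradients to stop.
    """
    recon = sum(cross_entropy(logits, t) for logits, t in zip(decoder_logits, target_tokens))
    recon /= len(target_tokens)
    align = sum(mse(p, z) for p, z in zip(predicted_reps, encoder_reps))
    align /= len(encoder_reps)
    return recon + lam * align


if __name__ == "__main__":
    logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # decoder outputs for 2 masked patches
    tokens = [0, 2]                               # target token ids for those patches
    pred = [[0.9, 0.1], [0.2, 0.8]]               # regressor predictions
    enc = [[1.0, 0.0], [0.0, 1.0]]                # encoder representations of masked patches
    for lam in (1.0, 1.5, 2.0):
        print(lam, cae_style_loss(logits, tokens, pred, enc, lam=lam))
```

Under this assumed form, increasing λ only reweights the alignment term, which is consistent with the small differences between λ = 1, 1.5, and 2 reported above.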

Reconstruction targets. To study the impact of different pretraining targets on model performance, we conduct additional experiments with the RGB pixel value target. Compared with the DALL-E tokenizer and the d-VAE tokenizer trained on ImageNet-1K, the RGB pixel target gives better linear probing and segmentation results but worse attentive probing results, as shown in Table 8. Pretraining with these three targets yields similar performance, illustrating that CAE does not rely on a specific pretraining target.

Conclusion

The core design of our CAE architecture for masked image modeling is that predictions are made from visible patches to masked patches in the encoded representation space. We adopt two pretraining tasks: masked representation prediction and masked patch reconstruction. Experiments demonstrate the effectiveness of the CAE design. In addition, we point out that the advantage of MIM methods over typical contrastive self-supervised pretraining and supervised pretraining on ImageNet-1K is that MIM learns the representations for all the patches, while typical contrastive self-supervised pretraining (e.g., MoCo and SimCLR) and supervised pretraining tend to learn semantics mainly from the center patches of the original images and little from non-center patches.

Possible extensions, as mentioned in the arXiv version CAE2022 , include investigating the possibility of considering only the masked representation prediction task, without masked patch reconstruction; pretraining a depth-wise convolution network with masked convolution; and pretraining with CLIP targets CAEv22022 .

Potential limitations. The proposed method may face challenges when dealing with large, contiguous masked regions in an image, e.g., when almost the whole object region is masked. Obtaining plausible, high-quality reconstructions of large areas can be particularly difficult, as the model has to infer the missing information from limited available context. This is a common limitation of masked image modeling methods, and our method is not exempt from it.

Acknowledgments

We would like to acknowledge Hangbo Bao, Xinlei Chen, Li Dong, Qi Han, Zhuowen Tu, Saining Xie, and Furu Wei for the helpful discussions.

Declarations

This work is partially supported by the National Key Research and Development Program of China (2020YFB1708002), National Natural Science Foundation of China (61632003, 61375022, 61403005), Grant SCITLAB-20017 of Intelligent Terminal Key Laboratory of SiChuan Province, Beijing Advanced Innovation Center for Intelligent Robots and Systems (2018IRS11), and PEK-SenseTime Joint Laboratory of Machine Vision. Ping Luo is supported by the General Research Fund of HK No.27208720, No.17212120, and No.17200622.

Our code will be available at https://github.com/Atten4Vis/CAE.

The datasets used in this paper are publicly available. ImageNet: https://www.image-net.org/, ADE20K: https://groups.csail.mit.edu/vision/datasets/ADE20K/, COCO: https://cocodataset.org/, Food-101: https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/, Clipart: http://projects.csail.mit.edu/cmplaces/download.html, Sketch: http://projects.csail.mit.edu/cmplaces/download.html.

References