Understanding Masked Image Modeling via Learning Occlusion Invariant Feature
Xiangwen Kong, Xiangyu Zhang
Introduction
Invariance matters in science. In self-supervised learning, invariance is particularly important: since ground-truth labels are not provided, one could expect the favored learned feature to be invariant (or, more generally, equivariant) to a certain group of transformations on the inputs. In recent years, one of the most successful self-supervised frameworks in visual recognition, contrastive learning, has benefited greatly from learning invariance. The key insight of contrastive learning is that, because recognition results are typically insensitive to deformations of the input images (e.g. cropping, resizing, color jittering), a good feature should also be invariant to these transformations. Therefore, contrastive learning suggests minimizing the distance between two (or more) feature maps computed from augmented copies of the same data, which is formulated as follows:
$$\min_{\theta}\ \mathbb{E}_{x\sim\mathcal{D}}\ \mathcal{L}(z_1, z_2),\quad z_1 = f_\theta(\mathcal{T}_1(x)),\quad z_2 = f_\theta(\mathcal{T}_2(x)) \tag{1}$$
where $\mathcal{D}$ is the data distribution; $f_\theta(\cdot)$ is the encoder network parameterized by $\theta$; $\mathcal{T}_1(\cdot)$ and $\mathcal{T}_2(\cdot)$ are two transformations on the input data, which define what invariance to learn; and $\mathcal{L}(\cdot,\cdot)$ is the distance function (or similarity measurement) that measures the similarity between the two feature maps $z_1$ and $z_2$. (Following the viewpoint in prior work, we suppose distance functions may contain parameters that are jointly optimized with Eq. 1. For example, the weights of a projection head or prediction head are regarded as part of the distance function $\mathcal{L}$.) Clearly, the choices of $\mathcal{T}$ and $\mathcal{L}$ are essential in contrastive learning algorithms, and researchers have come up with a variety of alternatives. For the transformation $\mathcal{T}$, popular choices include random cropping, color jittering, rotation, jigsaw puzzles, colorization, etc. For the similarity measurement $\mathcal{L}$, the InfoMax principle (which can be implemented with the MINE or InfoNCE loss), feature de-correlation, asymmetric teachers, triplet loss, etc. have been proposed.
Apart from contrastive learning, Masked Image Modeling (MIM) has very recently become a new trend in visual self-supervised learning. Inspired by Masked Language Modeling (MLM) in Natural Language Processing, MIM learns features via a form of denoising autoencoder. Images occluded with random patch masks are fed into the encoder, and the decoder then predicts the original embeddings of the masked patches:
$$\min_{\theta,\phi}\ \mathbb{E}_{x\sim\mathcal{D}}\ \mathcal{M}\big(d_\phi(z),\ x\odot(1-m)\big),\quad z = f_\theta(x\odot m) \tag{2}$$
where $\odot$ denotes the element-wise product; $m$ is the patch mask, so $x\odot m$ represents the unmasked patches and $x\odot(1-m)$ the masked ones; $f_\theta$ and $d_\phi$ are the encoder and decoder respectively; $z$ is the learned representation; and $\mathcal{M}(\cdot,\cdot)$ is the similarity measurement. The measurement varies across works, e.g. $\ell_2$-distance, cross-entropy, or perceptual loss in codebook space. Compared with conventional contrastive methods, MIM requires less effort on tuning the augmentations and achieves outstanding performance, especially in combination with vision transformers. It has also been shown to scale to large vision models.
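The patch mask $m$ in Eq. 2 can be made concrete with a small sketch. It assumes patches are stored as a flat list of per-patch pixel lists and that masked positions are drawn uniformly at random without replacement. The function names `random_patch_mask` and `apply_mask` are hypothetical; real implementations work on batched tensors.

```python
import random


def random_patch_mask(num_patches, mask_ratio, rng=None):
    """Sample a binary patch mask m with 1 = visible and 0 = masked.

    Convention from Eq. 2: x * m keeps the unmasked patches and
    x * (1 - m) keeps the masked ones.
    """
    rng = rng or random.Random()
    num_masked = int(round(num_patches * mask_ratio))
    # Uniform sampling without replacement (an assumption for this sketch).
    masked = set(rng.sample(range(num_patches), num_masked))
    return [0 if i in masked else 1 for i in range(num_patches)]


def apply_mask(patches, mask, keep_visible=True):
    """Return x * m (keep_visible=True) or x * (1 - m) (keep_visible=False).

    `patches` is a list of per-patch pixel lists.
    """
    out = []
    for patch, visible in zip(patches, mask):
        keep = visible if keep_visible else 1 - visible
        out.append([keep * value for value in patch])
    return out


# ViT-B/16 on a 224x224 image gives a 14x14 grid of 196 patches.
m = random_patch_mask(196, 0.75, random.Random(0))
print(sum(m))  # number of visible patches: 49
```

With a mask ratio of 0.75, 147 of the 196 patches are hidden. The two masked copies $x\odot m$ and $x\odot(1-m)$ always add back up to the original image.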
In this paper, we aim to build a unified understanding framework for MIM and contrastive learning. Our motivation is that, even though MIM has achieved great success, how it works is still an open question. Several works try to interpret MIM from different views. For example, one suggests that MIM models learn "rich hidden representations" via reconstruction from masked images, and a later one gives a mathematical understanding of MAE. However, what the model learns is still not obvious. The difficulty lies in the fact that MIM is essentially reconstructive (Eq. 2), so the supervision on the learned feature $z$ is implicit. In contrast, contrastive learning is siamese in nature (Eq. 1) and involves explicit supervision on the representation. If we manage to formulate MIM in an equivalent siamese form like Eq. 1, MIM can be explicitly interpreted as learning a certain invariance according to some distance measurement. We hope the framework may inspire more powerful self-supervised methods in the community.
In the next sections, we introduce our methodology. Notice that we do not aim to set up a new state-of-the-art MIM method, but to improve the understanding of MIM frameworks. Our findings are summarized as follows:
We propose RelaxMIM, a new siamese framework to approximate the original reconstructive MIM method. In the view of RelaxMIM, MIM can be interpreted as a special case of contrastive learning: the data transformation is random patch masking and the similarity measurement relates to the decoder. In other words, MIM models intrinsically learn occlusion invariant features.
Based on RelaxMIM, we replace the similarity measurement with a simpler InfoNCE loss. Surprisingly, the performance remains on par with the original model. This suggests that the reconstructive decoder in the MIM framework does not matter much and that other measurements could also work well. Instead, patch masking may be the key to success.
To understand why patch masking is important, we perform MIM pretraining on very few images (e.g. only one image) and then finetune the encoder with supervised training on full ImageNet. Although the learned representations lack semantic information after pretraining, the finetuned model still significantly outperforms models trained from scratch. We hypothesize that the encoder learns data-agnostic occlusion invariant features during pretraining, which could be a favored initialization for finetuning.
MIM intrinsically learns occlusion invariant feature
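In this section, we mainly introduce how to approximate the MIM formulation (Eq. 2) with a siamese model. For simplicity, we take MAE as a representative example of MIM, in which the similarity measurement is simply the $\ell_2$ distance on the masked patches. Other MIM methods can be analyzed in a similar way. Following the notations in Eq. 2, the loss function for MAE training is:
$$\min_{\theta,\phi}\ \mathbb{E}_{x\sim\mathcal{D}}\ \big\| d_\phi\big(f_\theta(x\odot m)\big)\odot(1-m) - x\odot(1-m) \big\|^2 \tag{3}$$
(In the original MAE, the encoder network only generates tokens for the unmasked patches and the decoder only predicts the masked patches during training. In our formulation, for simplicity, we suppose both networks predict the whole feature map, and we equivalently extract the desired part via proper masking when necessary.)
Let us focus on the second term, $x\odot(1-m)$. Typically, the dimension of the feature embedding is at least comparable to that of the input image. For ViT-B/16 at 224×224 resolution, 196 patch tokens × 768 channels = 224 × 224 × 3 = 150,528. The encoder therefore at least has a chance to be lossless. That means that for the encoder function $f_\theta$ there exists a network $h_{\phi'}$, parameterized by $\phi'$, that satisfies $h_{\phi'}\big(f_\theta(x\odot(1-m))\big) = x\odot(1-m)$. We can then rewrite Eq. 3 in the following equivalent form:
$$\min_{\theta,\phi}\ \mathbb{E}_{x\sim\mathcal{D}}\ \big\| \big(d_\phi(f_\theta(x\odot m)) - h_{\phi'}(f_\theta(x\odot(1-m)))\big)\odot(1-m) \big\|^2,\quad \text{s.t.}\ \phi' = \arg\min_{\phi'}\ \mathbb{E}_{x\sim\mathcal{D}}\ \big\| h_{\phi'}(f_\theta(x\odot(1-m))) - x\odot(1-m) \big\|^2 \tag{4}$$
Eq. 4 can be further simplified. Since $h_{\phi'}$ just approximates the "inverse" (if it exists) of $f_\theta$, there is no reason to use an architecture different from $d_\phi$. We therefore let $h = d$, i.e. $h_{\phi'} = d_{\phi'}$. We then define a new similarity measurement
$$\mathcal{L}_{\phi,\phi'}(z_1, z_2) = \big\| \big(d_\phi(z_1) - d_{\phi'}(z_2)\big)\odot(1-m) \big\|^2 \tag{5}$$
and two transformations
$$\mathcal{T}_1(x) = x\odot m,\qquad \mathcal{T}_2(x) = x\odot(1-m) \tag{6}$$
so that Eq. 4 becomes
$$\min_{\theta,\phi}\ \mathbb{E}_{x\sim\mathcal{D}}\ \mathcal{L}_{\phi,\phi'}\big(f_\theta(\mathcal{T}_1(x)),\ f_\theta(\mathcal{T}_2(x))\big),\quad \text{s.t.}\ \phi' = \arg\min_{\phi'}\ \mathbb{E}_{x\sim\mathcal{D}}\ \big\| d_{\phi'}(f_\theta(\mathcal{T}_2(x))) - \mathcal{T}_2(x) \big\|^2 \tag{7}$$
Eq. 7 helps us understand MIM from an explicit view. Comparing Eq. 7 with Eq. 1, the formulation can be viewed as a special case of contrastive learning: the loss aims to minimize the difference between the representations derived from two masking transformations. Therefore, we conclude that MIM pretraining encourages occlusion invariant features. The decoder acts as part of the similarity measurement (see Eq. 5), which is reasonable. It is difficult to define a proper distance function directly in the latent space, so a feasible solution is to project the representation back into the image space, where similarities like the $\ell_2$-distance are usually explainable (analogous to PSNR). In addition, the constraint term in Eq. 7 can be viewed as a standard autoencoder defined on the space of $\mathcal{T}_2(x)$. This guarantees that the projection is informative and avoids collapse of the similarity measurement.
Although Eq. 7 explicitly uncovers the invariance property of MIM in theory, it involves a nested optimization, which is difficult to compute. We thus propose a relaxed form of Eq. 7, named R-MAE (or RelaxMIM in general):
$$\min_{\theta,\phi,\phi'}\ \mathbb{E}_{x\sim\mathcal{D}}\ \Big[\mathcal{L}_{\phi,\phi'}\big(f_\theta(\mathcal{T}_1(x)),\ f_\theta(\mathcal{T}_2(x))\big) + \lambda\,\big\| d_{\phi'}(f_\theta(\mathcal{T}_2(x))) - \mathcal{T}_2(x) \big\|^2\Big] \tag{8}$$
Eq. 8 jointly optimizes the distance term and the constraint term of Eq. 7, and $\lambda$ controls the balance between the two terms. In practice, we let $\phi' = \phi$ to save computational cost, as we empirically find that the optimization targets of $\phi$ and $\phi'$ in Eq. 8 do not diverge much.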
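To make the two terms of the relaxed objective concrete, here is a single-sample sketch over flat vectors. It assumes one shared decoder ($\phi' = \phi$, the practical setting described above) and $\lambda = 1$. `encode` and `decode` are hypothetical callables standing in for $f_\theta$ and $d_\phi$; they are not the paper's networks.

```python
def sq_l2(a, b):
    """Squared l2 distance between two flat vectors."""
    return sum((u - v) ** 2 for u, v in zip(a, b))


def masked(x, m):
    """Element-wise product x * m."""
    return [xi * mi for xi, mi in zip(x, m)]


def r_mae_loss(x, m, encode, decode, lam=1.0):
    """Single-sample value of the relaxed objective (Eq. 8) with phi' = phi.

    `encode` and `decode` are hypothetical stand-ins for f_theta and d_phi.
    """
    inv = [1 - mi for mi in m]
    t1 = masked(x, m)    # T1(x) = x * m        (visible patches)
    t2 = masked(x, inv)  # T2(x) = x * (1 - m)  (masked patches)
    rec1 = decode(encode(t1))
    rec2 = decode(encode(t2))
    # Distance term (Eq. 5): both reconstructions compared on masked positions.
    distance = sq_l2(masked(rec1, inv), masked(rec2, inv))
    # Constraint term: the shared decoder must autoencode T2(x).
    constraint = sq_l2(rec2, t2)
    return distance + lam * constraint


x = [1.0, 2.0, 3.0, 4.0]
m = [1, 0, 1, 0]
identity = lambda v: list(v)
# An identity encoder/decoder autoencodes T2(x) perfectly (constraint = 0)
# but cannot inpaint, so the distance term is 2^2 + 4^2 = 20.
print(r_mae_loss(x, m, identity, identity))  # 20.0
```

Conversely, a decoder that always outputs the full image drives the distance term to zero but pays in the constraint term (here 1² + 3² = 10). This is one way to see why the constraint keeps the measurement from collapsing.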
Empirical evaluation.
First, we verify our claim that MIM representations are robust to image occlusion, as suggested by Eq. 7. At each block of the encoder, we compute the CKA similarity between the features learned from full images and those from images masked with different mask ratios. Figure 1 shows the CKA similarities of different models. The numbers (0.1 to 0.9) indicate the mask ratios of the test images (i.e. the fraction of image patches dropped). As shown in Figure 1, both the original MAE and our relaxed R-MAE obtain high CKA scores, as does another variant, C-MAE, introduced in the next section. This suggests that these methods learn occlusion invariant features. In contrast, other methods such as supervised training or MoCo v3 do not share this property, especially when the mask ratio is large. After finetuning, the CKA similarities drop, but they remain higher than those of models trained from scratch.
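The section does not say which CKA variant is used. The sketch below assumes the common linear-kernel CKA, $\|Y^\top X\|_F^2 / (\|X^\top X\|_F\,\|Y^\top Y\|_F)$ on column-centered feature matrices. Rows are the same test inputs, so in the setting above $X$ would hold features of full images and $Y$ features of their masked versions. The helper names are hypothetical.

```python
import math


def center_columns(mat):
    """Subtract the per-column mean (rows are samples, columns are features)."""
    n = len(mat)
    means = [sum(row[j] for row in mat) / n for j in range(len(mat[0]))]
    return [[row[j] - means[j] for j in range(len(row))] for row in mat]


def cross_fro_sq(a, b):
    """Squared Frobenius norm of A^T B for row-major A (n x p) and B (n x q)."""
    total = 0.0
    for i in range(len(a[0])):
        for j in range(len(b[0])):
            s = sum(a[k][i] * b[k][j] for k in range(len(a)))
            total += s * s
    return total


def linear_cka(x, y):
    """Linear CKA between two feature matrices computed on the same n inputs."""
    x, y = center_columns(x), center_columns(y)
    return cross_fro_sq(x, y) / math.sqrt(cross_fro_sq(x, x) * cross_fro_sq(y, y))


full = [[1.0, 0.0], [0.0, 1.0], [2.0, 3.0], [4.0, 1.0]]   # features of full images
scaled = [[2.0 * v for v in row] for row in full]        # isotropic rescaling
print(round(linear_cka(full, scaled), 6))  # 1.0
```

Linear CKA is invariant to isotropic scaling and orthogonal transforms of the features. A score near 1 therefore means the masked-input features carry the same structure as the full-input features, not that they are numerically identical.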
Next, we verify how well R-MAE (Eq. 8) approximates the original MAE. We pretrain the original MAE and R-MAE on ImageNet using the same settings: the mask ratio is 0.75 and the number of training epochs is 100 ($\lambda$ is set to 1 for ours). We then finetune the models on labeled ImageNet data for another 100 epochs. Results are shown in Table 1. Our finetuning accuracy is 0.4% lower than that of MAE, which may be caused by the relaxation. Nevertheless, R-MAE largely maintains the benefit of MAE. It is still much better than supervised training from scratch and competitive with other self-supervised methods that use longer pretraining. Another interesting observation is that the reconstruction quality of R-MAE is even better than that of the original MAE (see the PSNR column in Table 1). We think this may reflect the trade-off controlled by the choice of $\lambda$ in Eq. 8, and we leave this topic to future investigation.
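The PSNR column measures reconstruction quality. A minimal sketch of the standard definition, $10\log_{10}(\mathrm{MAX}^2/\mathrm{MSE})$, is shown below. It assumes pixel values scaled to $[0, 1]$. Exactly which pixels Table 1 averages over (e.g. only the masked patches, before or after patch normalization) is not stated here, so treat this as the generic metric only.

```python
import math


def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(MAX^2 / MSE)."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
    if mse == 0:
        return math.inf  # identical signals
    return 10.0 * math.log10(max_val ** 2 / mse)


# An MSE of 0.01 on [0, 1]-scaled pixels corresponds to 20 dB.
print(round(psnr([0.5, 0.5], [0.6, 0.4]), 3))  # 20.0
```

Because PSNR is a monotone transform of the $\ell_2$ reconstruction error, it is a natural readout for the image-space distance that Eq. 5 relies on.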
Similarity measurement in MIM is replaceable
Eq. 7 bridges MIM and contrastive learning with a unified siamese framework. Compared with conventional contrastive learning methods, two things are special in MIM. 1) The data transformation $\mathcal{T}$: previous contrastive learning methods usually employ random cropping or other image jittering, while MIM methods adopt patch masking. 2) The similarity measurement $\mathcal{L}$: contrastive learning often uses InfoNCE or other losses, while MIM implies the relatively complex formulation of Eq. 5 (note that the constraint term in Eq. 7 also belongs to the similarity measurement). To understand whether these two differences are important, in this section we study how the choice of $\mathcal{L}$ affects the performance.
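We aim to replace the measurement with a much simpler InfoNCE loss, and we name the new method contrastive MAE (C-MAE). Following prior siamese methods, we transform the representations with asymmetric MLPs before applying the loss. For a query $u$ and its positive key $v$, the InfoNCE loss is
$$\ell_{\text{NCE}}(u, v) = -\log \frac{\exp\big(\mathrm{sim}(u, v)/\tau\big)}{\sum_{v'\in\mathcal{V}} \exp\big(\mathrm{sim}(u, v')/\tau\big)} \tag{9}$$
where $\mathrm{sim}(\cdot,\cdot)$ is cosine similarity and $\mathcal{V}$ contains $v$ together with the negative keys. The new distance measurement is defined as follows:
$$\mathcal{L}_{C}(z_1, z_2) = \ell_{\text{NCE}}\big(q(g(z_1)),\ g(z_2)\big) \tag{10}$$
where $g$ and $q$ are the projection head and prediction head respectively (following the naming in BYOL), both implemented with MLPs, and $\tau$ is the temperature of the softmax. Readers can refer to prior work for details. Hence the objective function of C-MAE is:
$$\min_{\theta,g,q}\ \mathbb{E}_{x\sim\mathcal{D}}\ \ell_{\text{NCE}}\Big(q\big(g(f_\theta(\mathcal{T}_1(x)))\big),\ g\big(f_\theta(\mathcal{T}_2(x))\big)\Big) \tag{11}$$
Unlike Eq. 7, C-MAE does not include a nested optimization, so it can be optimized directly without relaxation.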
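A minimal sketch of the InfoNCE term over plain Python lists follows. It assumes cosine similarity and in-batch negatives (the standard InfoNCE setup), with the projector and predictor outputs already computed. The temperature default of 0.2 matches the value given in the experimental details. Which tokens or samples serve as negatives in C-MAE's token-wise setting is an assumption left open here.

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(queries, keys, tau=0.2):
    """Mean InfoNCE loss: queries[i] and keys[i] form the positive pair,
    and every other key in the batch serves as a negative."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [cosine(q, k) / tau for k in keys]
        top = max(logits)  # subtract the max for numerical stability
        log_denominator = top + math.log(sum(math.exp(logit - top) for logit in logits))
        total += log_denominator - logits[i]
    return total / len(queries)


# queries stand for q(g(z1)) from the online branch, keys for g(z2) from the target branch.
queries = [[1.0, 0.0], [0.0, 1.0]]
aligned = info_nce(queries, [[1.0, 0.0], [0.0, 1.0]])
swapped = info_nce(queries, [[0.0, 1.0], [1.0, 0.0]])
print(aligned < swapped)  # True
```

The loss is small when each query is most similar to its own key and large when a negative key matches better. This is the contrast that the reconstructive measurement of Eq. 5 is replaced with.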
The design of transformation $\mathcal{T}$.
We initially intended to use the same transformation as in MAE and R-MAE (Eq. 6). However, we find that directly using Eq. 6 in C-MAE leads to convergence problems. We conjecture that even though the two transformations derive different patches from the same image, the patches may share the same color distribution, which may lead to information leakage. Inspired by SimCLR, we introduce additional color augmentation after the transformation to cancel out the leakage. The detailed color jittering strategy follows SimSiam.
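The transformation pair described above can be sketched as complementary masks followed by an independent color perturbation per view. The jitter here is a hypothetical brightness-only stand-in (a random factor in [0.6, 1.4], clipped to [0, 1]), not the full SimSiam color-jittering recipe. The name `complementary_views` is also hypothetical.

```python
import random


def complementary_views(patches, mask_ratio, rng):
    """Build T1(x) = x * m and T2(x) = x * (1 - m) (Eq. 6), then jitter each view.

    The jitter is a toy, brightness-only stand-in for color jittering.
    """
    n = len(patches)
    hidden = set(rng.sample(range(n), int(round(n * mask_ratio))))
    m = [0 if i in hidden else 1 for i in range(n)]

    def jitter(view):
        # One random brightness factor per view, values clipped to [0, 1].
        factor = rng.uniform(0.6, 1.4)
        return [[min(1.0, max(0.0, v * factor)) for v in patch] for patch in view]

    view1 = [[v * mi for v in patch] for patch, mi in zip(patches, m)]
    view2 = [[v * (1 - mi) for v in patch] for patch, mi in zip(patches, m)]
    # Independent factors are meant to break the shared color statistics of the views.
    return jitter(view1), jitter(view2), m


patches = [[0.5, 0.5, 0.5] for _ in range(16)]  # a toy 4x4 grid of RGB patches
v1, v2, m = complementary_views(patches, 0.75, random.Random(1))
print(sum(m))  # 4 visible patches in T1(x)
```

Every patch is non-zero in exactly one of the two views. Without the per-view jitter, both views would share identical color statistics, which is the leakage conjectured above.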
Token-wise vs. instance-wise loss.
We mainly evaluate our method on the ViT-B model. By default, the model generates a latent representation composed of 14×14 patch tokens and one class token. Each patch token relates to one image patch, while the class token relates to the whole instance. It is worth discussing how the loss in Eq. 11 applies to the tokens. We consider four alternatives, applying the loss in Eq. 11: 1) only to the class token; 2) to the average of all patch tokens; 3) to each patch token respectively; 4) to each patch token as well as the class token respectively. If the loss is applied to multiple tokens, we average all loss terms. Table 3 shows the ablation results. Clearly, the token-wise loss on the patch tokens achieves the best finetuning accuracy on ImageNet. In comparison, adding the class token does not lead to improvement, which may imply that the class token in self-supervised learning is not as semantic as in supervised learning. Therefore, for C-MAE we apply the loss token-wise to the patch tokens only by default.
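The four alternatives in the ablation differ only in which tokens the loss sees and how the terms are averaged. The sketch below takes any pairwise token loss as a parameter. A squared $\ell_2$ loss is used purely as a hypothetical stand-in, since a real InfoNCE term would also need negatives. `aggregate_loss` and the mode names are hypothetical.

```python
def sq_l2(a, b):
    """Toy pairwise token loss used only for illustration."""
    return sum((u - v) ** 2 for u, v in zip(a, b))


def aggregate_loss(token_loss, cls1, cls2, patches1, patches2, mode):
    """Combine a per-token loss under the four alternatives of the ablation."""
    if mode == "cls":  # 1) class token only
        return token_loss(cls1, cls2)
    if mode == "avg_patch":  # 2) average of all patch tokens
        dim = len(patches1[0])
        mean1 = [sum(p[j] for p in patches1) / len(patches1) for j in range(dim)]
        mean2 = [sum(p[j] for p in patches2) / len(patches2) for j in range(dim)]
        return token_loss(mean1, mean2)
    # 3) each patch token respectively, then average the terms
    terms = [token_loss(a, b) for a, b in zip(patches1, patches2)]
    if mode == "patch_and_cls":  # 4) patch tokens plus the class token
        terms.append(token_loss(cls1, cls2))
    elif mode != "patch":
        raise ValueError(f"unknown mode: {mode}")
    return sum(terms) / len(terms)


cls1, cls2 = [0.0, 0.0], [1.0, 0.0]
p1, p2 = [[0.0, 0.0], [2.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]
for mode in ("cls", "avg_patch", "patch", "patch_and_cls"):
    print(mode, aggregate_loss(sq_l2, cls1, cls2, p1, p2, mode))
```

Averaging before the loss (mode 2) can hide per-patch disagreement that the token-wise variant (mode 3) penalizes. In the toy data, the average-pooled loss is 1.0 while the token-wise loss is 2.0.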
Implementation details.
Following prior siamese methods, we use a siamese network that contains an online model and a target model. The target model's parameters are updated as an exponential moving average (EMA) of the online model. We use a 2-layer projector (the projection head in Eq. 11) and a 2-layer predictor (the prediction head), with GELU as the activation layer. To feed the masked patches into the encoder network, we adopt learnable mask tokens, as some previous MIM methods do. We do not directly discard the tokens within the masked region as the original MAE does, because unlike MAE, our C-MAE does not include a heavy transformer-based decoder to predict the embeddings for the masked region.
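The target-model update can be sketched as a parameter-wise EMA step. The momentum default of 0.996 is the value given in the experimental details. Keeping it constant throughout training is an assumption of this sketch, and storing parameters as a dict of float lists is purely hypothetical.

```python
def ema_update(target, online, momentum=0.996):
    """In-place EMA step: target <- momentum * target + (1 - momentum) * online.

    Parameters are stored as a hypothetical dict of name -> list of floats.
    """
    for name, online_values in online.items():
        target[name] = [
            momentum * t + (1.0 - momentum) * o
            for t, o in zip(target[name], online_values)
        ]
    return target


online = {"w": [1.0, -2.0]}
target = {"w": [0.0, 0.0]}
for _ in range(1000):
    ema_update(target, online)
print([round(v, 3) for v in target["w"]])
```

With a fixed online model, the target approaches it as $1 - 0.996^n$ after $n$ steps. The target therefore lags the online branch by a few hundred iterations, which is what makes it a slowly moving teacher.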
Result and discussion.
Table 2 shows the finetuning results of C-MAE and a few other self-supervised methods. C-MAE achieves results comparable to the MAE baselines. This suggests that in the MIM framework the reconstructive decoder, or equivalently the measurement in siamese form (Eq. 5), does not matter much, and that a simple InfoNCE loss works fine. Our findings also agree with recent advances in siamese MIMs, e.g. iBOT, MSN and data2vec. These frameworks involve various distance measurements between the siamese branches instead of reconstructing the masked parts, yet achieve comparable or even better results than the original reconstruction-based MIMs. Beyond these empirical observations, our work uncovers the underlying reason: both reconstructive and siamese methods target learning occlusion invariant features, so it is reasonable that they obtain similar performance.
Table 2 also indicates that, among siamese frameworks, C-MAE achieves comparable or even better results than previous counterparts such as DINO. This holds even though the former mainly adopts random patch masking while the latter involves complex data transformation strategies. Prior work reports a similar phenomenon, namely that data augmentation is less important in MIM. This observation further supports the view that learning occlusion invariant features, rather than the particular loss, is the key to MIM. Intuitively, patch masking is a simple but strong approach to encourage occlusion invariance. For example, compared with random cropping, patch masking is more general: cropping can be viewed as a special mask pattern on the whole image. Yet according to prior experiments, it is good enough, or even better, to leave patch masking fully randomized (although very recent studies suggest that more sophisticated masking strategies can still help).
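The claim that cropping is a special mask pattern can be illustrated on the patch grid. A rectangular crop keeps a contiguous block of patches and drops the rest. This sketch ignores the resizing step that usually follows a crop, and `crop_as_patch_mask` is a hypothetical helper.

```python
def crop_as_patch_mask(grid, top, left, height, width):
    """Express a rectangular crop on a grid x grid patch layout as a binary mask
    (1 = patch kept inside the crop, 0 = patch dropped), in row-major order."""
    return [
        1 if top <= r < top + height and left <= c < left + width else 0
        for r in range(grid)
        for c in range(grid)
    ]


crop = crop_as_patch_mask(14, top=0, left=0, height=7, width=7)
print(sum(crop), 1 - sum(crop) / len(crop))  # 49 0.75
```

A 7×7-patch crop of the 14×14 ViT-B/16 grid keeps 49 of 196 patches, i.e. an effective mask ratio of 0.75. This is the same budget as random masking at 0.75, but with all kept patches forced into one contiguous block.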
Additional ablations.
Table 4 presents additional results on MAE and C-MAE. First, although C-MAE shows fine-tuning results comparable to MAE, we find that C-MAE models lead to inferior results under the linear probing and few-shot protocols (few-shot meaning fine-tuning on 10% of the ImageNet training data). Further study shows that the degradation is mainly caused by the use of mask tokens in C-MAE, which are absent in the original MAE. If we remove the mask tokens as done in MAE's encoder, linear probing and few-shot accuracy largely recover, although fine-tuning accuracy drops slightly. We think this is because mask tokens enlarge the structural gap between pretraining and linear/few-shot probing, since the network is not fully fine-tuned under those settings.
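The structural difference discussed above lies in how the encoder input is built. With mask tokens, every masked position is filled with a shared (learnable) token and the sequence keeps its full length. Without them, masked positions are simply dropped, as in MAE's encoder. The sketch omits positional embeddings, and `encoder_inputs` is a hypothetical helper.

```python
def encoder_inputs(tokens, mask, mask_token, use_mask_tokens):
    """Build the encoder input sequence from patch tokens and a 0/1 visibility mask.

    use_mask_tokens=True replaces every masked position with a shared mask token,
    keeping the full sequence length; use_mask_tokens=False drops masked positions,
    as MAE's encoder does. Positional embeddings are omitted for brevity.
    """
    if use_mask_tokens:
        return [tok if keep else list(mask_token) for tok, keep in zip(tokens, mask)]
    return [tok for tok, keep in zip(tokens, mask) if keep]


tokens = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
mask = [1, 0, 1, 0]
mask_token = [0.0, 0.0]  # learnable in practice; fixed here for illustration
print(len(encoder_inputs(tokens, mask, mask_token, True)))   # 4
print(len(encoder_inputs(tokens, mask, mask_token, False)))  # 2
```

At linear-probing time no patch is masked, so an encoder pretrained with mask tokens sees an input distribution it was never trained on. The drop-tokens variant avoids this, consistent with the gap described above.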
Second, we try replacing the InfoNCE loss (Eq. 9) with the BYOL loss in C-MAE. Following the ablations in Table 3, we still apply the BYOL loss in a token-wise manner. Compared with InfoNCE, the BYOL loss does not have explicit negative pairs. The results show that the BYOL loss follows a similar trend to the InfoNCE loss, which supports our viewpoint that the similarity measurement in MIM is replaceable. However, we also find that the BYOL loss is less stable, resulting in slightly lower accuracy than InfoNCE.
Last, since our C-MAE involves color jittering, one may argue that invariance to color transformations could be another key factor besides occlusion invariance. Unfortunately, this ablation is nontrivial because we find that the contrastive loss quickly collapses without color jittering. Instead, we study the original MAE with additional color jittering, comparing two configurations: a) augmenting the whole image before applying MAE; b) only augmenting the unmasked patches (i.e. the reconstruction targets remain unchanged). Neither setting boosts MAE further, which implies that invariance to color jittering does not matter much.
MIM can learn a favored, (almost) data-agnostic initialization
As discussed in the above sections, learning occlusion invariant features is the key "philosophy" of MIM methods. Hence an interesting question arises: how do the learned networks model the invariance? One possible hypothesis is that occlusion invariance is represented in a data-agnostic way, analogous to the structure of max pooling. The output feature is robust as long as the most significant input part is not masked out, so the invariance is obtained by design rather than from data. Another reasonable hypothesis is, in contrast, that the invariance requires knowledge from a lot of data. In this section we investigate this question.
Inspired by prior work, we verify our hypotheses by significantly reducing the number of images used for MAE pretraining. We use from 1 to 1000 images randomly sampled from the ImageNet training set, so the semantic information available from the training data is very limited during pretraining. Notice that MAE training tends to over-fit on very small training sets, as the network may easily "remember" the training images. Therefore, we adopt stronger data augmentation and early stopping to avoid over-fitting. Table 5 presents the results. Very surprisingly, we find that pretraining on only one image for 5 epochs already leads to an improved finetuning score. It is much better than 100-epoch training from scratch and on par with training for 300 epochs. The fine-tuning results do not improve when the number of pretraining images increases to 1000. Since a single image is unlikely to contain much of the semantic information of the whole dataset, the experiment provides strong evidence that MIM can learn a favored initialization. More importantly, this initialization is (almost) data-agnostic. Table 6 also indicates that the choice of sampling strategy does not affect the fine-tuning accuracy, further suggesting that this benefit of MIM pretraining might be free of category information.
Moreover, in Table 7 we benchmark various pretraining methods on a 1000-image subset of the ImageNet training data, which provides more insight into MIM training. We find that the linear probing accuracy of MAE is very low, only slightly better than a random feature (first row). This suggests that the feature learned from 1000 images is not very semantic; however, the fine-tuning and few-shot fine-tuning results remain good. Our proposed R-MAE and C-MAE share similar properties with the original MAE: relatively low linear probing scores but high fine-tuning performance. This observation strongly supports our first hypothesis at the beginning of Sec. 4: the occlusion invariance learned by MIM could be data-agnostic, and it also serves as a good initialization for the network. In comparison, supervised training and MoCo v3 on 1000 images fail to obtain high fine-tuning scores, even though their linear probing accuracy is higher. This may be because these methods cannot effectively learn occlusion invariant features from a small dataset. We discuss this topic further in the appendix.
Experimental Details
We use ViT-B/16 as the default backbone. For MAE pretraining, we follow the original settings and use patch normalization when computing the loss. We use a mask ratio of 0.75, which prior work found to be the most effective. We use the AdamW optimizer with a cosine decay scheduler and a batch size of 1024. We set the base learning rate (the learning rate for a batch size of 256) to 1.5e-4 with a 20-epoch linear warm-up, and scale the learning rate linearly as the batch size increases. For R-MAE, we search the learning rate and finally set the base learning rate to 3.0e-4; other training settings are unchanged. For C-MAE, the momentum for updating the teacher model is set to 0.996, and the temperature of the contrastive loss is set to 0.2. For the projector and predictor heads, we use 2048-d hidden layers. We search the learning rate and finally set the base learning rate to 1.5e-4; other parameters are unchanged. By default, we train the models for 100 epochs on ImageNet. Due to computational resource constraints, we report 400-epoch results only to show that our method benefits from longer training.
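The learning-rate arithmetic above can be checked with a short sketch: linear scaling from a batch size of 256, a linear warm-up, then cosine decay. Decaying to a minimum of 0 and updating once per epoch (rather than per iteration) are simplifying assumptions. The function names are hypothetical.

```python
import math


def scaled_lr(base_lr, batch_size):
    """Linear scaling rule: the base learning rate is defined for a batch size of 256."""
    return base_lr * batch_size / 256


def lr_at_epoch(epoch, peak_lr, warmup_epochs, total_epochs, min_lr=0.0):
    """Linear warm-up to peak_lr, then cosine decay to min_lr (per-epoch for simplicity)."""
    if epoch < warmup_epochs:
        return peak_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


peak = scaled_lr(1.5e-4, 1024)  # MAE / C-MAE pretraining: 6e-4
for epoch in (0, 10, 20, 60, 100):
    print(epoch, lr_at_epoch(epoch, peak, warmup_epochs=20, total_epochs=100))
```

At a batch size of 1024, the base rates of 1.5e-4 (MAE, C-MAE) and 3.0e-4 (R-MAE) become peak rates of 6e-4 and 1.2e-3 respectively. The schedule reaches the peak at epoch 20 and decays to zero by epoch 100.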
Finetuning.
We follow the training settings of prior work. We use the average-pooled feature of the encoded patch tokens as the input of the classifier and train the model end-to-end. Following prior work, we reset the parameters of the final normalization layer. We use the AdamW optimizer with a cosine decay scheduler and set the batch size to 1024. We set the base learning rate to 1.0e-3 with a 5-epoch linear warm-up and train the model for 100 epochs. Note that the supervised ViT baseline in our paper uses the same settings as finetuning and is trained for 100 epochs.
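The classifier input described above is the mean of the encoded patch tokens. The sketch below assumes the class token, if present, sits at index 0 of the sequence (a common ViT convention, not stated in the text) and excludes it from the average. `pooled_feature` is a hypothetical helper.

```python
def pooled_feature(tokens, has_cls_token=True):
    """Average the encoded patch tokens to form the classifier input.

    Assumes the class token, if present, is the first token in the sequence.
    """
    patch_tokens = tokens[1:] if has_cls_token else tokens
    dim = len(patch_tokens[0])
    return [sum(tok[j] for tok in patch_tokens) / len(patch_tokens) for j in range(dim)]


encoded = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0]]  # [class token, patch 1, patch 2]
print(pooled_feature(encoded))  # [2.0, 3.0]
```

Pooling over patch tokens fits the ablation in Table 3, where the class token trained without labels turned out to be less useful than the patch tokens.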
Related Work
As ViT models achieve breakthrough results in computer vision, self-supervised pretraining for ViTs has become an active research area. In addition to siamese frameworks, MIM is an efficient and popular approach to self-supervised learning, in which the model learns rich hidden information by optimizing a reconstruction objective. Following BERT, one line of work compresses the image to a few pixels and then directly learns the color of the masked pixels. Another maps all image patches to 8192 embeddings with a trained dVAE and then learns the correct embedding for each masked patch, and a follow-up optimizes the masking process based on BEiT. Other works combine MIM with siamese frameworks and improve linear probing performance. Still others reconstruct the original image with a simple method and also learn rich features effectively. Another work gives a mathematical understanding of MAE. MSN, a concurrent work of ours, also discusses invariance to masking.
Siamese approaches in SSL.
Self-supervised pretraining has achieved great success in classification, detection and segmentation. One promising family of methods is based on siamese frameworks, which learn representations by minimizing the distance between positive samples with siamese networks. In practice, some methods use the same parameters in the online and target models, while others update the target parameters from the online ones using an exponential moving average. Only minimizing the distance between positive samples causes the model to fall into trivial solutions, so a critical problem in SSL is how to prevent the model from collapsing. Some methods use negative samples from different images and compute a contrastive loss. Others add an extra predictor on top of the online model and stop the gradient of the target model. Instead of optimizing the loss per instance, some methods optimize the variance, covariance or cross-covariance along the channel dimension. Others optimize the distributions of the two features and avoid trivial solutions by centering and sharpening.
Limitation
When implementing RelaxMIM, we only use MAE as the base MIM method and do not try other MIM methods. When implementing discriminative MIM, we simply use InfoNCE, which we believe can be replaced by other contrastive learning losses. Due to limited computational resources, we only use ViT-B as the backbone and train all models on ImageNet-1k. Our models are also trained for far fewer epochs than is common for self-supervised methods. In future work, we plan to train the models longer and use larger models (such as ViT-L) to get better results.
Conclusion
In this paper, we propose a new viewpoint that MIM implicitly learns occlusion invariant features, and we build a unified understanding framework, RelaxMIM, for MIM and contrastive learning. In the view of RelaxMIM, MIM models intrinsically learn occlusion invariant features, and we verify that the representations of RelaxMIM are robust to image occlusion. Based on RelaxMIM, we replace the similarity measurement with a simpler InfoNCE loss and achieve results comparable to the original MIM framework. This suggests that patch masking may be the critical component of the framework. To understand why patch masking is important, we perform MIM pretraining on very few images and finetune the encoder with supervised training on full ImageNet. We find that the encoder learns almost data-agnostic occlusion invariant features during pretraining, which could be a favored initialization for finetuning. To measure whether MIM methods learn human-like recognition patterns, we compare the shape bias of different self-supervised models. We conclude that MIM could improve the recognition ability of ViT to make it closer to human recognition, but the improvement may be limited. We hope the RelaxMIM framework may inspire more powerful self-supervised methods in the community.
Appendix A More Visualization Experiments
Here we discuss the occlusion invariance of MAE models pretrained on only a few images. Our protocol is the CKA similarity between the representations of the masked image and of the full image under different mask ratios. The numbers (0.1 to 0.9) indicate the mask ratios of the test images (i.e. the fraction of image patches dropped). A higher CKA similarity at a large mask ratio means that the model has learned better occlusion invariance.
Figure 2 shows the CKA similarities of MAE pretrained with different amounts of data. As the figures show, the model learns occlusion invariance even when pretrained on a single image. Unfortunately, the model does not keep this occlusion invariance after finetuning: when the mask ratio increases to 0.7, the CKA similarities drop significantly, below 0.5. In comparison, the MAE pretrained on the full set is much less sensitive to the mask ratio (beyond 0.7) after finetuning.
Furthermore, we discuss the relationship between occlusion invariance and overfitting. We train MAE for different numbers of epochs on 10 images and plot the CKA similarities. The results in Figure 3 show that overfitting hurts the learning of occlusion invariance and causes the performance to drop. We further explore whether stronger data augmentation, as a way to prevent overfitting, helps maintain occlusion invariance. As shown in Figure 3, although the finetuning results improve slightly with stronger augmentation, the occlusion invariance is not improved.
A.2 Comparison with Human Recognition
Prior work shows that ViTs behave more like humans in classification, and we wonder whether our proposed siamese framework learns more high-level perception. Following the method of that work, we plot the shape bias of MIM models in Figure 4.
Figure 4 shows the shape bias of MAE, MoCo v3, R-MAE and C-MAE. As shown in the figure, the grey line represents the supervised model, which has the lowest shape bias. In other words, fully supervised learning relies more on texture information than the self-supervised pretrained models do. Both MAE (blue line) and R-MAE (green line) show less shape bias than MoCo v3 (yellow line) and C-MAE (orange line). We speculate that this is because the target of the MIM pretext task is close to (or exactly) the original images, which makes the model learn more texture features. Additionally, C-MAE learns a shape bias similar to that of MoCo v3. These results indicate that instance-wise learning is not necessary for models to learn as humans do; learning occlusion invariance can also improve the shape bias of the model. When trained longer, all mask-based models become biased toward texture features. We conclude that mask-based models may quickly learn to complete object shapes within a few epochs, and then learn to reconstruct the texture of images.