Scaling Vision Transformers to 22 Billion Parameters
Mostafa Dehghani, Josip Djolonga, Basil Mustafa, Piotr Padlewski, Jonathan Heek, Justin Gilmer, Andreas Steiner, Mathilde Caron, Robert Geirhos, Ibrahim Alabdulmohsin, Rodolphe Jenatton, Lucas Beyer, Michael Tschannen, Anurag Arnab, Xiao Wang, Carlos Riquelme, Matthias Minderer, Joan Puigcerver, Utku Evci, Manoj Kumar, Sjoerd van Steenkiste, Gamaleldin F. Elsayed, Aravindh Mahendran, Fisher Yu, Avital Oliver, Fantine Huot, Jasmijn Bastings, Mark Patrick Collier, Alexey Gritsenko, Vighnesh Birodkar, Cristina Vasconcelos, Yi Tay, Thomas Mensink, Alexander Kolesnikov, Filip Pavetić, Dustin Tran, Thomas Kipf, Mario Lučić, Xiaohua Zhai, Daniel Keysers, Jeremiah Harmsen, Neil Houlsby
Introduction
Similar to natural language processing, transfer of pre-trained vision backbones has improved performance on a wide variety of vision tasks (Pan and Yang, 2010; Zhai et al., 2019; Kolesnikov et al., 2020). Larger datasets, scalable architectures, and new training methods (Mahajan et al., 2018; Dosovitskiy et al., 2021; Radford et al., 2021; Zhai et al., 2022a) have accelerated this growth. Despite this, vision models have trailed far behind language models, which have demonstrated emergent capabilities at massive scales (Chowdhery et al., 2022; Wei et al., 2022). Specifically, the largest dense vision model to date is a mere 4B parameter ViT (Chen et al., 2022), while a modestly parameterized model for an entry-level competitive language model typically contains over 10B parameters (Raffel et al., 2019; Tay et al., 2022; Chung et al., 2022), and the largest dense language model has 540B parameters (Chowdhery et al., 2022). Sparse models demonstrate the same trend, where language models go beyond a trillion parameters (Fedus et al., 2021) but the largest reported sparse vision models are only 15B (Riquelme et al., 2021).
This paper presents ViT-22B, the largest dense ViT model to date. En route to 22B parameters, we uncover pathological training instabilities which prevent scaling the default recipe, and demonstrate architectural changes which make it possible. Further, we carefully engineer the model to enable model-parallel training at unprecedented efficiency. ViT-22B’s quality is assessed via a comprehensive evaluation suite of tasks, ranging from (few-shot) classification to dense output tasks, where it reaches or advances the current state-of-the-art. For example, even when used as a frozen visual feature extractor, ViT-22B achieves an accuracy of 89.5% on ImageNet. With a text tower trained to match these visual features (Zhai et al., 2022b), it achieves 85.9% accuracy on ImageNet in the zero-shot setting. The model is also a strong teacher: using it as a distillation target, we train a ViT-B student that achieves 88.6% on ImageNet, state-of-the-art at this scale.
This performance comes with improved out-of-distribution behaviour, reliability, uncertainty estimation and fairness tradeoffs. Finally, the model’s features are better aligned with human perception, achieving a previously unseen shape bias of 87%.
Model Architecture
ViT-22B is a Transformer-based encoder model that resembles the architecture of the original Vision Transformer (Dosovitskiy et al., 2021) but incorporates the following three main modifications to improve efficiency and training stability at scale: parallel layers, query/key (QK) normalization, and omitted biases.
As in Wang and Komatsuzaki (2021), ViT-22B applies the Attention and MLP blocks in parallel, instead of sequentially as in the standard Transformer:
$$y' = \mathrm{LayerNorm}(x), \qquad y = x + \mathrm{MLP}(y') + \mathrm{Attention}(y').$$
This enables additional parallelization via combination of linear projections from the MLP and attention blocks. In particular, the matrix multiplication for query/key/value-projections and the first linear layer of the MLP are fused into a single operation, and the same is done for the attention out-projection and second linear layer of the MLP. This approach is also used by PaLM (Chowdhery et al., 2022), where this technique sped up the largest model’s training by 15% without performance degradation.
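The two fusions can be checked numerically. The plain-Python sketch below (toy sizes and random weights; all names are hypothetical, and this is not the JAX/FLAX implementation) verifies that one matmul against the concatenated weights [W_q | W_k | W_v | W_mlp_in] reproduces the four separate projections, and that, because the parallel block adds the attention and MLP outputs, one matmul over the concatenated inputs against the stacked [W_o ; W_mlp_out] reproduces that sum.

```python
import random

def matmul(a, b):
    """Plain-Python product of an (n x m) matrix a and an (m x p) matrix b."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def add(a, b):
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]

def hcat(*mats):
    """Concatenate matrices with equal row counts along the column axis."""
    return [[v for m in mats for v in m[i]] for i in range(len(mats[0]))]

def vcat(*mats):
    """Stack matrices with equal column counts along the row axis."""
    return [row for m in mats for row in m]

def close(a, b, tol=1e-9):
    return all(abs(u - v) < tol for ra, rb in zip(a, b) for u, v in zip(ra, rb))

def random_matrix(n, m, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(m)] for _ in range(n)]

rng = random.Random(0)
d_model, d_attn, d_mlp, n_tokens = 4, 3, 5, 2    # toy sizes (hypothetical)
y_norm = random_matrix(n_tokens, d_model, rng)   # stands in for y' = LayerNorm(x), shared by both branches
w_q, w_k, w_v = (random_matrix(d_model, d_attn, rng) for _ in range(3))
w_mlp_in = random_matrix(d_model, d_mlp, rng)

# Fusion 1: a single matmul computes Q, K, V and the first MLP layer; slice afterwards.
fused_in = matmul(y_norm, hcat(w_q, w_k, w_v, w_mlp_in))
q = [row[:d_attn] for row in fused_in]
k = [row[d_attn:2 * d_attn] for row in fused_in]
v = [row[2 * d_attn:3 * d_attn] for row in fused_in]
mlp_hidden = [row[3 * d_attn:] for row in fused_in]

# Fusion 2: the parallel block adds the two branch outputs, so the attention
# out-projection and the second MLP layer become one matmul over concatenated
# inputs. attn_ctx and mlp_act stand in for the attention context vectors and
# the activated MLP hidden units.
attn_ctx = random_matrix(n_tokens, d_attn, rng)
mlp_act = random_matrix(n_tokens, d_mlp, rng)
w_o, w_mlp_out = random_matrix(d_attn, d_model, rng), random_matrix(d_mlp, d_model, rng)
fused_out = matmul(hcat(attn_ctx, mlp_act), vcat(w_o, w_mlp_out))
separate_out = add(matmul(attn_ctx, w_o), matmul(mlp_act, w_mlp_out))
```

The second fusion relies on the parallel formulation: in the sequential block the MLP input depends on the attention output, so neither fusion applies there.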
In scaling ViT beyond prior works, we observed divergent training loss after a few thousand steps. In particular, this instability was observed for models with around 8B parameters (see Appendix B). It was caused by extremely large values in attention logits, which led to (almost one-hot) attention weights with near-zero entropy. To solve this, we adopt the approach of Gilmer et al. (2023), which applies LayerNorm (Ba et al., 2016) to the queries and keys before the dot-product attention computation. Specifically, the attention weights are computed as
$$\mathrm{softmax}\left[\frac{1}{\sqrt{d}}\,\mathrm{LN}(XW^Q)\,\bigl(\mathrm{LN}(XW^K)\bigr)^{\top}\right],$$
where d is the query/key dimension, X is the input, LN stands for layer normalization, W^Q is the query weight matrix, and W^K is the key weight matrix. The effect on an 8B parameter model is shown in Figure 1, where normalization prevents divergence due to uncontrolled attention logit growth.
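A small numerical illustration of why QK normalization helps: the sketch below (plain Python, random vectors, hypothetical names; the LayerNorm has no learned scale or bias, which is a simplification) multiplies the queries and keys by a growing factor to mimic the logit growth described above and reports the entropy of the resulting attention weights. Without normalization the logits grow quadratically in the factor and the weights collapse towards one-hot; with LayerNorm applied to queries and keys the weights barely change, because LayerNorm is (up to its epsilon) invariant to rescaling its input.

```python
import math
import random

def layer_norm(v, eps=1e-6):
    """LayerNorm over one vector, without learned scale or bias (a simplification)."""
    mean = sum(v) / len(v)
    var = sum((u - mean) ** 2 for u in v) / len(v)
    return [(u - mean) / math.sqrt(var + eps) for u in v]

def softmax(z):
    m = max(z)
    e = [math.exp(u - m) for u in z]
    s = sum(e)
    return [u / s for u in e]

def entropy(p):
    return -sum(u * math.log(u) for u in p if u > 0)

def attention_weights(query, keys, qk_norm):
    """Attention weights of one query over all keys, optionally with QK normalization."""
    d = len(query)
    if qk_norm:
        query, keys = layer_norm(query), [layer_norm(k) for k in keys]
    logits = [sum(a * b for a, b in zip(query, k)) / math.sqrt(d) for k in keys]
    return softmax(logits)

rng = random.Random(0)
d, n_keys = 16, 8
q0 = [rng.gauss(0, 1) for _ in range(d)]
k0 = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n_keys)]

results = {}
for scale in (1.0, 10.0, 1000.0):  # mimics growing query/key magnitudes during training
    q, ks = [scale * u for u in q0], [[scale * u for u in k] for k in k0]
    h_plain = entropy(attention_weights(q, ks, qk_norm=False))
    h_norm = entropy(attention_weights(q, ks, qk_norm=True))
    results[scale] = (h_plain, h_norm)
    print(f"scale={scale:7.1f}  entropy plain={h_plain:.3f}  with QK-norm={h_norm:.3f}")
```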
Following PaLM (Chowdhery et al., 2022), the bias terms were removed from the QKV projections and all LayerNorms were applied without bias and centering (Zhang and Sennrich, 2019). This improved accelerator utilization (by 3%), without quality degradation. However, unlike PaLM, we use bias terms for the (in- and out-) MLP dense layers as we have observed improved quality and no speed reduction.
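As a sketch of what "without bias and centering" means for LayerNorm, the code below contrasts standard LayerNorm with the scale-only variant of Zhang and Sennrich (2019), usually called RMSNorm. It is plain Python with hypothetical names, not the model's code; the epsilon value is an assumption.

```python
import math

def layer_norm(x, gain, bias, eps=1e-6):
    """Standard LayerNorm: subtract the mean, divide by the std, apply gain and bias."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gain, bias)]

def rms_norm(x, gain, eps=1e-6):
    """LayerNorm without centering and without bias: divide by the root mean square."""
    ms = sum(v * v for v in x) / len(x)
    return [g * v / math.sqrt(ms + eps) for v, g in zip(x, gain)]

x = [1.0, 2.0, 3.0, 6.0]
gain = [1.0, 0.5, 2.0, 1.0]

# On a zero-mean input, and with a zero LayerNorm bias, the two coincide ...
x_centered = [v - sum(x) / len(x) for v in x]
same = rms_norm(x_centered, gain)
ref = layer_norm(x_centered, gain, bias=[0.0] * len(x))

# ... but in general RMSNorm does not remove the mean, so the outputs differ.
out = rms_norm(x, gain)
```

Dropping the mean subtraction and the bias removes a reduction and an addition per normalization, which is consistent with the small utilization gain reported above.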
Figure 2 illustrates a ViT-22B encoder block. The embedding layer, which includes extracting patches, linear projection, and the addition of position embedding, follows that of the original ViT. We use multi-head attention pooling (Cordonnier et al., 2019; Zhai et al., 2022a) to aggregate the per-token representations in the head.
ViT-22B uses a patch size of 14×14 with images at resolution 224×224 (pre-processed by Inception crop followed by random horizontal flip). Similar to the original ViT (Dosovitskiy et al., 2021), ViT-22B employs a learned 1D positional embedding. During fine-tuning on high-resolution images (which yields a different number of visual tokens), we perform a 2D interpolation of the pre-trained position embeddings, according to their location in the original image.
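The token-count arithmetic and the 2D interpolation step can be sketched as follows. This plain-Python version (hypothetical names) uses bilinear interpolation with align-corners sampling; the exact resizing method used for ViT-22B is not specified here, so treat that choice as an assumption. The 518px target is one of the fine-tuning resolutions used in the out-of-distribution experiments, and 518 = 37 × 14.

```python
def grid_size(resolution, patch):
    """Patches per side; assumes the resolution is a multiple of the patch size."""
    assert resolution % patch == 0
    return resolution // patch

def interpolate_pos_emb(emb, old, new):
    """Bilinearly resize old*old position embeddings (row-major order) to new*new.

    Uses align-corners sampling, so the four corner embeddings are kept exactly.
    """
    assert old >= 2 and new >= 2 and len(emb) == old * old
    grid = [emb[r * old:(r + 1) * old] for r in range(old)]
    dim = len(emb[0])
    out = []
    for i in range(new):
        y = i * (old - 1) / (new - 1)      # position of output row i on the old grid
        y0 = min(int(y), old - 2)
        wy = y - y0
        for j in range(new):
            x = j * (old - 1) / (new - 1)  # position of output column j on the old grid
            x0 = min(int(x), old - 2)
            wx = x - x0
            out.append([
                (1 - wy) * ((1 - wx) * grid[y0][x0][ch] + wx * grid[y0][x0 + 1][ch])
                + wy * ((1 - wx) * grid[y0 + 1][x0][ch] + wx * grid[y0 + 1][x0 + 1][ch])
                for ch in range(dim)
            ])
    return out

old = grid_size(224, 14)   # 16 patches per side -> 256 tokens at pre-training
new = grid_size(518, 14)   # 37 patches per side at a higher fine-tuning resolution
# Toy 2-D "embeddings" that simply encode (row, col); real ones are learned vectors.
pos = [[float(r), float(c)] for r in range(old) for c in range(old)]
resized = interpolate_pos_emb(pos, old, new)
```

Because the toy embeddings are linear in position, bilinear interpolation reproduces them exactly on the new grid, which makes the sketch easy to sanity-check.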
Other hyperparameters for the ViT-22B model architecture are presented in Table 1, compared to the previously reported largest ViT models, ViT-G (Zhai et al., 2022a) and ViT-e (Chen et al., 2022).
Following the template in Mitchell et al. (2019), we provide the model card in Appendix C.
Training Infrastructure and Efficiency
ViT-22B is implemented in JAX (Bradbury et al., 2018) using the FLAX library (Heek et al., 2020) and built within Scenic (Dehghani et al., 2022). It leverages both model and data parallelism. In particular, we used the jax.xmap API, which provides explicit control over both the sharding of all intermediates (e.g. weights and activations) and inter-chip communication. We organized the chips into a 2D logical mesh of size t × k, where t is the size of the data-parallel axis and k is the size of the model axis. Then, for each of the t groups, k devices get the same batch of images; each device keeps only 1/k of the activations and is responsible for computing 1/k of the output of all linear layers (detailed below).
As we use explicit sharding, we built a wrapper around the dense layers in FLAX that adapts them to the setting where their inputs are split across devices. To maximize throughput, two aspects have to be considered: computation and communication. Namely, we want the operations to be analytically equivalent to the unsharded case, to communicate as little as possible, and ideally to overlap communication with computation (Wang et al., 2022a) so that we can keep the matrix multiply unit, where most of the FLOP capacity is, busy at all times.
Furthermore, matrix multiplications are overlapped with the communication with the neighbours. This asynchronous approach allows for high matrix core utilization and increased device efficiency, while minimizing waiting on incoming communication. Figure 3 presents the overlapping communication and computation across devices with the parallel linear operation in row-sharding and column-sharding modes. The general case of this technique is presented in Wang et al. (2022a), who also introduce the XLA operations we leverage here.
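To make the sharded linear layers concrete, the following plain-Python simulation (hypothetical names, toy sizes, no real devices or XLA collectives) runs k simulated devices on a ring. The input activations are sharded along their feature dimension, each device produces its own slice of output columns, and each input shard is passed to a neighbour after every step, which is where real hardware would overlap communication with the next matmul. The gathered result matches the unsharded product, i.e. the scheme is analytically equivalent to the unsharded layer. How this maps onto the paper's row- and column-sharding modes is an assumption.

```python
import random

def matmul(a, b):
    """Plain-Python product of an (n x m) matrix a and an (m x p) matrix b."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def add(a, b):
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]

def col_slice(m, lo, hi):
    return [row[lo:hi] for row in m]

def ring_collective_matmul(x_shards, w, k):
    """Simulate k devices computing y = x @ w when x is sharded along its features.

    Device `dev` owns input shard x_shards[dev] and produces output columns
    [dev*c, (dev+1)*c). At each of k steps it multiplies the input shard it
    currently holds by the matching row block of its weight columns, then the
    shard moves one hop around the ring.
    """
    d_in, d_out = len(w), len(w[0])
    assert d_in % k == 0 and d_out % k == 0
    r, c = d_in // k, d_out // k
    held = list(range(k))   # which input shard each device holds right now
    acc = [None] * k        # running partial sum of each device's output shard
    for _ in range(k):
        for dev in range(k):
            j = held[dev]
            w_block = col_slice(w[j * r:(j + 1) * r], dev * c, (dev + 1) * c)
            partial = matmul(x_shards[j], w_block)
            acc[dev] = partial if acc[dev] is None else add(acc[dev], partial)
        held = [held[(dev + 1) % k] for dev in range(k)]  # receive shard from device dev + 1
    return acc              # acc[dev] stays on device dev: output remains sharded

rng = random.Random(0)
n_tokens, d_in, d_out, k = 3, 4, 6, 2   # toy sizes (hypothetical)
x = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(n_tokens)]
w = [[rng.uniform(-1, 1) for _ in range(d_out)] for _ in range(d_in)]
x_shards = [col_slice(x, j * d_in // k, (j + 1) * d_in // k) for j in range(k)]
y_shards = ring_collective_matmul(x_shards, w, k)
y_gathered = [sum((shard[i] for shard in y_shards), []) for i in range(n_tokens)]
y_reference = matmul(x, w)
```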
The model is data-parallel on the first axis. Each parameter can be either fully replicated over this axis, or have each device hold a chunk of it. We opted to shard some large tensors from the model parameters to be able to fit larger models and batch sizes. This means that the device has to gather the parameters before computing the forward pass and scatter on the backward pass, but again, this happens asynchronously with computation. In particular, while computing one layer the device can start communicating the weights of the next one, thus minimizing the communication overhead.
Using these techniques, ViT-22B processes k tokens per second per core during training (forward and backward pass) on TPUv4 (Jouppi et al., 2020). ViT-22B’s model flops utilization (MFU) (Chowdhery et al., 2022; Dehghani et al., 2021a) is 54.9%, indicating a very efficient use of the hardware. Note that PaLM reports 46.2% MFU (Chowdhery et al., 2022; Pope et al., 2022) and we measured 44.0% MFU for ViT-e (data-parallel only) on the same hardware.
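For readers who want to reproduce an MFU number, a common way to compute it is sketched below: observed tokens per second times the model FLOPs per token, divided by the hardware's peak FLOP/s. The per-token estimate of 6N plus an attention term 12·L·H·Q·T follows the accounting in the PaLM paper; whether ViT-22B's reported figure uses exactly this formula is an assumption, and the example inputs are made up. By construction MFU counts only the model's own FLOPs, so recomputation (rematerialization) does not inflate it.

```python
def training_flops_per_token(n_params, n_layers=0, n_heads=0, head_dim=0, seq_len=0):
    """Approximate forward+backward FLOPs per token for a dense Transformer.

    6 * N covers the parameter matmuls (2 FLOPs per multiply-add in the forward
    pass, with the backward pass costing about twice the forward). The optional
    12 * L * H * Q * T term accounts for the attention score and context matmuls.
    """
    return 6 * n_params + 12 * n_layers * n_heads * head_dim * seq_len

def model_flops_utilization(tokens_per_sec, peak_flops_per_sec, **model):
    """Fraction of the hardware's peak FLOP/s spent on the model's own FLOPs."""
    return tokens_per_sec * training_flops_per_token(**model) / peak_flops_per_sec

# Hypothetical example: a 1B-parameter model at 1,000 tokens/s on a 10 TFLOP/s chip.
example_mfu = model_flops_utilization(1000, 1e13, n_params=1e9)
```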
Experiments
ViT-22B is trained on a version of JFT (Sun et al., 2017), extended to around 4B images (Zhai et al., 2022a). These images have been semi-automatically annotated with a class-hierarchy of 30k labels. Following the original Vision Transformer, we flatten the hierarchical label structure and use all the assigned labels in a multi-label classification fashion employing the sigmoid cross-entropy loss.
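The multi-label objective amounts to an independent binary (sigmoid) cross-entropy per class. A minimal plain-Python sketch for a single image follows; the function name is hypothetical, and summing rather than averaging over classes is an assumption (implementations differ).

```python
import math

def sigmoid_cross_entropy(logits, positive_labels):
    """Sum over classes of per-class binary cross-entropy for one image.

    logits: one score per class; positive_labels: set of class indices assigned
    to the image (after flattening the hierarchy, an image may carry several).
    """
    total = 0.0
    for c, z in enumerate(logits):
        y = 1.0 if c in positive_labels else 0.0
        # Numerically stable form of -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))].
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total

loss = sigmoid_cross_entropy([2.0, -1.0, 0.5], positive_labels={0, 2})
```

Unlike a softmax, the sigmoid does not force the classes to compete, so an image can be confidently assigned both a fine-grained label and its ancestors in the flattened hierarchy.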
ViT-22B was trained using 256 visual tokens per image, where each token represents a patch extracted from 224×224 images. ViT-22B is trained for 177k steps with batch size of 65k: approximately 3 epochs. We use a reciprocal square-root learning rate schedule with a peak of , and linear warmup (first 10k steps) and cooldown (last 30k steps) phases. For better few-shot adaptation, we use a higher weight decay on the head () than body () for upstream training (Zhai et al., 2022a; Abnar et al., 2021).
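The schedule and the epoch count can be sketched as below. The peak learning rate is elided in the text, so `peak` is a placeholder; the precise form of the reciprocal square-root decay (here anchored to equal `peak` at the end of warmup) and the reading of "65k" as 65,536 are assumptions.

```python
import math

TOTAL_STEPS, WARMUP, COOLDOWN = 177_000, 10_000, 30_000

def learning_rate(step, peak=1.0):
    """Reciprocal square-root schedule with linear warmup and linear cooldown."""
    if step < WARMUP:
        lr = peak * step / WARMUP                  # linear warmup from 0
    else:
        lr = peak * math.sqrt(WARMUP / step)       # rsqrt decay, equals peak at WARMUP
    cooldown_start = TOTAL_STEPS - COOLDOWN
    if step >= cooldown_start:
        lr *= max(0.0, (TOTAL_STEPS - step) / COOLDOWN)  # linear ramp down to 0
    return lr

# Rough epoch count, assuming a batch of 65,536 images and about 4B training images.
epochs = TOTAL_STEPS * 65_536 / 4e9   # about 2.9, i.e. "approximately 3 epochs"
```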
2 Transfer to image classification
Efficient transfer learning with large scale backbones is often achieved by using them as frozen feature extractors. This section presents the evaluation results of ViT-22B for image classification using linear probing and locked-image tuning as well as out-of-distribution transfer. Additional results for Head2Toe transfer, few-shot transfer, and linear probing with L-BFGS can be found in Appendix D (Tables 10, 12, and 13 and Figure 12).
We explored various ways of training a linear probe; our final setup on ImageNet uses SGD with momentum for 10 epochs at 224px resolution, with mild random cropping and horizontal flipping as the only data augmentations, and no further regularization.
The results presented in Table 2 show that while the returns are diminishing, there is still a notable improvement at this scale (we repeated a subset of the experiments multiple times and the results are almost identical). Furthermore, we show that linear probing of larger models like ViT-22B can approach or exceed the performance of full high-resolution fine-tuning of smaller models, while often being cheaper or easier to do.
We further test linear separability on the fine-grained classification dataset iNaturalist 2017 (Cui et al., 2018). It has 5,089 fine-grained categories, belonging to 13 super-categories. Unlike ImageNet, the number of images per category is not balanced, and this long-tail distribution of concepts makes classification more challenging. We compare ViT-22B with the other ViT variants. Similar to the linear probing on ImageNet, we use SGD with a starting learning rate of 0.001 and no weight decay to optimize the models, and train for 30 epochs with a cosine learning rate schedule with 3 epochs of linear warm-up. We test both px and px input resolutions. Figure 4 shows the results. We observe that ViT-22B significantly improves over the other ViT variants, especially with the standard px input resolution. This suggests that the large number of parameters in ViT-22B is useful for extracting detailed information from the images.
2.2 Zero-shot via locked-image tuning
Following the Locked-image Tuning (LiT) (Zhai et al., 2022b) protocol, we train a text tower contrastively to match the embeddings produced by the frozen ViT-22B model. With this text tower, we can easily perform zero-shot classification and zero-shot retrieval tasks. We train a text Transformer with the same size as ViT-g (Zhai et al., 2022a) on the English subset of the WebLI dataset (Chen et al., 2022) for 1M steps with a 32K batch size. The images are resized to px, and the text is tokenized to 16 tokens using a SentencePiece (Kudo and Richardson, 2018) tokenizer trained on the English C4 dataset.
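As a sketch of the LiT setup, the code below shows a symmetric contrastive (InfoNCE-style) loss over a batch of paired embeddings and zero-shot classification by cosine similarity against class-name text embeddings. Plain Python, hypothetical names; the fixed temperature is a placeholder (CLIP-style models often learn it), and in LiT only the text tower would receive gradients from this loss.

```python
import math

def l2_normalize(v):
    n = math.sqrt(sum(u * u for u in v))
    return [u / n for u in v]

def log_softmax(z):
    m = max(z)
    lse = m + math.log(sum(math.exp(u - m) for u in z))
    return [u - lse for u in z]

def contrastive_loss(image_emb, text_emb, temperature=0.1):
    """Symmetric InfoNCE loss; image i and text i form the positive pair."""
    imgs = [l2_normalize(v) for v in image_emb]   # frozen image tower outputs in LiT
    txts = [l2_normalize(v) for v in text_emb]    # trainable text tower outputs
    sims = [[sum(a * b for a, b in zip(i, t)) / temperature for t in txts] for i in imgs]
    n = len(imgs)
    img_to_txt = -sum(log_softmax(sims[i])[i] for i in range(n)) / n
    txt_to_img = -sum(log_softmax([sims[r][j] for r in range(n)])[j] for j in range(n)) / n
    return (img_to_txt + txt_to_img) / 2

def zero_shot_classify(image_emb, class_text_emb):
    """Index of the class whose text embedding is most cosine-similar to the image."""
    img = l2_normalize(image_emb)
    scores = [sum(a * b for a, b in zip(img, l2_normalize(t))) for t in class_text_emb]
    return max(range(len(scores)), key=scores.__getitem__)
```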
Table 3 shows the zero-shot transfer results of ViT-22B against CLIP (Radford et al., 2021), ALIGN (Jia et al., 2021), BASIC (Pham et al., 2021), CoCa (Yu et al., 2022a), LiT (Zhai et al., 2022b) with ViT-g (Zhai et al., 2022a) and ViT-e (Chen et al., 2022) models. The bottom part of Table 3 compares three ViT models using the LiT recipe. On all the ImageNet test sets, ViT-22B achieves either comparable or better results. Notably, zero-shot results on the ObjectNet test set are highly correlated with the ViT model size. ViT-22B, the largest model, sets a new SOTA on the challenging ObjectNet test set. Appendix A shows zero-shot classification examples on OOD images.
2.3 Out-of-distribution
We construct a label-map from JFT to ImageNet, and label-maps from ImageNet to different out-of-distribution datasets, namely ObjectNet (Barbu et al., 2019), ImageNet-v2 (Recht et al., 2019), ImageNet-R (Hendrycks et al., 2020), and ImageNet-A (Hendrycks et al., 2021). ImageNet-R and ImageNet-A use the same 200-label subspace of ImageNet (constructed in such a way that misclassifications would be considered egregious (Hendrycks et al., 2021)), while ObjectNet has 313 categories, of which we only consider the 113 that overlap with the ImageNet label space. For ObjectNet and ImageNet-A we do an aspect-preserving crop of the central 75% of the image; for the other datasets we first resize them to a square format and then take an 87.5% central crop. Image input resolution is 224px for pre-trained checkpoints and 384px, 518px, or 560px for models fine-tuned on ImageNet.
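Evaluating an ImageNet classifier on a dataset that covers only a subset of its classes typically means taking the argmax over the mapped classes only, and the crops described above reduce to simple box arithmetic. Both are sketched below in plain Python with hypothetical names; interpreting "central 75%" as 75% of each side is an assumption.

```python
def restricted_prediction(logits, allowed_classes):
    """Argmax over only the ImageNet classes that map to the OOD dataset's label space."""
    return max(allowed_classes, key=lambda c: logits[c])

def central_crop_box(width, height, fraction):
    """(left, top, right, bottom) of an aspect-preserving crop keeping `fraction` of each side."""
    crop_w, crop_h = width * fraction, height * fraction
    left, top = (width - crop_w) / 2, (height - crop_h) / 2
    return (left, top, left + crop_w, top + crop_h)

# Class 1 has the highest score overall but is outside the mapped subset, so it is ignored.
pred = restricted_prediction([0.1, 5.0, 2.0, 3.0], allowed_classes=[0, 2, 3])
box = central_crop_box(400, 200, 0.75)   # an ObjectNet-style crop on a 400x200 image
```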
We confirm the findings of Taori et al. (2020), Djolonga et al. (2021), and Kolesnikov et al. (2020) that scaling the model increases out-of-distribution performance in line with the improvements on ImageNet. This holds true for models that have only seen JFT images, and for models fine-tuned on ImageNet. In both cases, ViT-22B continues the trend of better OOD performance with larger models (Figure 5, Table 11). While fine-tuning boosts accuracy on both ImageNet and out-of-distribution datasets, the effective robustness (Andreassen et al., 2021) decreases (Figure 5). Even though ImageNet accuracy saturates, we see a significant increase on ObjectNet from ViT-e/14 to ViT-22B.
3 Transfer to dense prediction
Transfer learning for dense prediction is critical, especially since obtaining pixel-level labels can be costly. In this section, we investigate the quality of the geometric and spatial information captured by the ViT-22B model (trained using an image-level classification objective) on semantic segmentation and monocular depth estimation tasks.
We evaluate ViT-22B as a backbone in semantic segmentation on three benchmarks: ADE20K (Zhou et al., 2017b), Pascal Context (Mottaghi et al., 2014) and Pascal VOC (Everingham et al., 2010). We analyze the performance in two scenarios: first, using a limited amount of data for transfer; second (in Section E.1), comparing end-to-end fine-tuning versus a frozen backbone with either a linear decoder (Strudel et al., 2021) or UperNet (Xiao et al., 2018). The number of additional parameters (M for linear and M for UperNet) is negligible compared to the size of the backbone. We use a fixed resolution (px) and report single scale evaluation.
We compare ViT-22B to the ViT-L of DeiT-III (Touvron et al., 2022) and ViT-G of Zhai et al. (2022a), when only a fraction of the ADE20k semantic segmentation data is available. We use the linear decoder and end-to-end fine-tuning. From Table 4, we observe that our ViT-22B backbone transfers better when seeing only a few segmentation masks. For example, when fine-tuning with only 1200 images (i.e. ) of ADE20k training data, we reach a performance of 44.7 mIoU, an improvement of mIoU over DeiT-III Large (Touvron et al., 2022) and mIoU over ViT-G (Zhai et al., 2022a). When transferring with more data, the performance of ViT-G and ViT-22B converges.
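For reference, mIoU can be computed from per-pixel predictions as sketched below. This is a plain-Python illustration with hypothetical names; the ignore label of 255 is a common convention, assumed here, and benchmark toolkits usually accumulate intersections and unions over the whole evaluation set rather than per image.

```python
def mean_iou(pred, target, num_classes, ignore_index=255):
    """Mean intersection-over-union from flat lists of per-pixel class ids.

    Classes that appear in neither prediction nor target are skipped; pixels
    labelled `ignore_index` are excluded.
    """
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(pred, target):
        if t == ignore_index:
            continue
        if p == t:
            inter[t] += 1
            union[t] += 1           # a correct pixel counts once in the union
        else:
            union[t] += 1           # missed pixel of class t
            if 0 <= p < num_classes:
                union[p] += 1       # false positive for class p
    ious = [i / u for i, u in zip(inter, union) if u > 0]
    return sum(ious) / len(ious)

# Class 0: IoU 1/2, class 1: IoU 2/3, so mIoU = 7/12.
miou = mean_iou([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2)
```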
3.2 Monocular depth estimation
We largely mirror the set-up explored in Ranftl et al. (2021) and train their Dense Prediction Transformer (DPT) on top of frozen ViT-22B backbone features obtained from the Waymo Open real-world driving dataset (Sun et al., 2020). Here we use only a single feature map (of the last layer) to better manage the high-dimensional ViT features. We also explore a much simpler “linear” decoder as a lightweight readout. In both cases we predict obtained from sparse LiDAR as the target and use Mean Squared Error (MSE) as the decoder training loss. We quantify performance using standard depth estimation metrics from the literature (Hermann et al., 2020; Eigen et al., 2014) and also report MSE. We use a resolution of . Remaining details are deferred to Section E.2.
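Because LiDAR supervision is sparse, the training loss is only evaluated where a return exists. The sketch below (plain Python, hypothetical names) shows a masked MSE and three metrics that are standard in the depth-estimation literature (absolute relative error, RMSE, and the δ < 1.25 accuracy); which exact metric set Table 5 uses, and whether the metrics are computed on raw depth or on the transformed target, is not specified here.

```python
import math

def masked_mse(pred, target, valid):
    """MSE over pixels that have a LiDAR return (valid[i] is True); others are ignored."""
    errs = [(p - t) ** 2 for p, t, v in zip(pred, target, valid) if v]
    return sum(errs) / len(errs)

def depth_metrics(pred, gt):
    """Common monocular-depth metrics over valid pixels with positive depths."""
    n = len(gt)
    abs_rel = sum(abs(p - g) / g for p, g in zip(pred, gt)) / n
    rmse = math.sqrt(sum((p - g) ** 2 for p, g in zip(pred, gt)) / n)
    delta1 = sum(max(p / g, g / p) < 1.25 for p, g in zip(pred, gt)) / n
    return {"abs_rel": abs_rel, "rmse": rmse, "delta1": delta1}

loss = masked_mse([1.0, 2.0, 3.0], [1.0, 0.0, 5.0], [True, False, True])
metrics = depth_metrics([2.0, 4.0], [2.0, 5.0])
```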
Table 5 summarizes our main findings. From the top rows (DPT decoder), we observe that using ViT-22B features yields the best performance (across all metrics) compared to different backbones. By comparing the ViT-22B backbone to ViT-e (a smaller model but trained on the same data as ViT-22B) we find that scaling the architecture improves performance. Further, comparing the ViT-e backbone to ViT-L (a similar architecture to ViT-e but trained on less data) we find that these improvements also come from scaling the pre-training data. These findings demonstrate that both the greater model size and the greater dataset size contribute substantially to the improved performance. Using the linear decoder, it can be observed again that using ViT-22B features yields the best performance. The gap between DPT and linear decoding suggests that while enough geometric information is retained in the ViT features, only some of it is available for a trivial readout. We report qualitative results in Figure 6 and Figures 13 and 14 in Section E.2.
4 Transfer to video classification
We evaluate the quality of the representations learned by ViT-22B by adapting the model pretrained on images for video classification. We follow the “factorised encoder” architecture of Arnab et al. (2021): our video model consists of an initial “spatial transformer”, which encodes each frame of the video independently. Thereafter, the representation from each frame is pooled into a single token, which is then fed to a subsequent “temporal transformer” that models the temporal relations between the representations of each frame.
Here, we initialize the “spatial transformer” with the pretrained weights from ViT-22B and freeze them, as this represents a computationally efficient method of adapting large-scale models for video, and also because it allows us to effectively evaluate the representations learned by pretraining ViT-22B. Exhaustive experimental details are included in Appendix F. The temporal transformer is lightweight both in terms of parameters (only 63.7M parameters compared to the 22B frozen parameters in the spatial transformer), and FLOPs as it operates on a single token per frame.
Table 6 presents our results on video classification on the Kinetics 400 (Kay et al., 2017) and Moments in Time (Monfort et al., 2019) datasets, showing that we can achieve competitive results with a frozen backbone. We first compare to ViT-e (Chen et al., 2022), the largest previous vision backbone with 4 billion parameters, which was also trained on the JFT dataset. We observe that our larger ViT-22B model improves by 1.5 points on Kinetics 400, and 1.3 points on Moments in Time. Our results with a frozen backbone are also competitive with CoCa (Yu et al., 2022a), which combines contrastive and generative captioning pretraining, in contrast to our supervised pretraining, and uses many tokens per frame (vs. a single one produced by the pretrained frozen pooling) as well as a higher testing resolution.
Finally, we note that there is headroom for further improvement by full end-to-end fine-tuning. This is evidenced by the current state-of-the-art on Kinetics 400 (Wang et al., 2022b) and Moments in Time (Yu et al., 2022a) which leverage a combination of large-scale video pretraining and full end-to-end fine-tuning on the target dataset.
5 Beyond accuracy on downstream tasks
When studying the impact of scaling, there are important aspects to consider beyond downstream task performance. In this section, we probe ViT-22B’s fairness, alignment with human perception, robustness, reliability, and calibration. We find that favorable characteristics emerge when increasing model size. Additional analysis on perceptual similarity and feature attribution can be found in Appendix K and Appendix L.
Machine learning models are susceptible to unintended bias. For example, they can amplify spurious correlations in the training data (Hendricks et al., 2018; Caliskan et al., 2017; Zhao et al., 2017; Wang et al., 2020) and result in error disparities (Zhao et al., 2017; Buolamwini and Gebru, 2018; Deuschel et al., 2020). Here, we identify how scaling the model size can help mitigate such issues, by evaluating the bias of ViT-22B and ViT-{L, g, G, e} (Zhai et al., 2022a; Chen et al., 2022) using demographic parity (DP) as a measure of fairness (Dwork et al., 2012; Zafar et al., 2017).
We use CelebA (Liu et al., 2015) with binary gender as a sensitive attribute while the target is “attractive” or “smiling”. We emphasize that such experiments are carried out only to verify technical claims and shall by no means be interpreted as an endorsement of such vision-related tasks. We choose these attributes because they exhibit gender-related bias, as shown in Figure 15.
We train a logistic regression classifier on top of the ViT-22B pretrained features for a total of 50 epochs and batch size , with a learning rate schedule of (first 25 epochs) and 0.001 (last 25 epochs). After that, we debias using the randomized threshold optimizer (RTO) algorithm of Alabdulmohsin and Lucic (2021), which was shown to be near-optimal and competitive with in-processing methods.
We observe that scale by itself does not impact DP (cf. Figure 15). This is perhaps not surprising: the model is trained to predict a chosen target, so the level of DP in an accurate model is similar to that of the data itself.
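Demographic parity compares positive-prediction rates across the groups defined by the sensitive attribute. A minimal sketch for a binary attribute (plain Python, hypothetical names):

```python
def demographic_parity_gap(predictions, groups):
    """|P(y_hat = 1 | s = a) - P(y_hat = 1 | s = b)| for a binary sensitive attribute.

    predictions: 0/1 classifier decisions; groups: sensitive-attribute value per example.
    """
    rates = {}
    for g in set(groups):
        preds = [p for p, s in zip(predictions, groups) if s == g]
        rates[g] = sum(preds) / len(preds)
    assert len(rates) == 2, "this sketch assumes exactly two groups"
    a, b = rates.values()
    return abs(a - b)

preds = [1, 1, 0, 1, 0, 0, 0, 1]
groups = ["a", "a", "a", "a", "b", "b", "b", "b"]
gap = demographic_parity_gap(preds, groups)   # 0.75 vs 0.25 positive rate
```

If the predictions are replaced by the ground-truth labels, the function returns the labels' own gap, which illustrates why an accurate classifier tends to inherit the DP level of the data.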
However, scaling to ViT-22B offers benefits for fairness in other aspects. First, scale offers a more favorable tradeoff: performance improves with scale subject to any prescribed level of bias constraint. This is consistent with earlier observations reported in the literature (Alabdulmohsin and Lucic, 2021). Second, all subgroups tend to benefit from the improvement in scale. Third, ViT-22B reduces disparities in performance across subgroups. Figure 7 summarizes results for classification accuracy, and Appendix G reports results for expected calibration error (ECE) (Naeini et al., 2015; Guo et al., 2017) and OC-AUC (Kivlichan et al., 2021).
5.2 Human Alignment
How well do ViT-22B classification decisions align with human classification decisions? Using the model-vs-human toolbox (Geirhos et al., 2021), we evaluate three ViT-22B models fine-tuned on ImageNet with different resolutions (224, 384, 560). Across all toolbox metrics, ViT-22B is SOTA: ViT-22B-224 for highest OOD robustness (Figure 19(a)), ViT-22B-384 for the closest alignment with human classification accuracies (Figure 19(b)), and ViT-22B-560 for the largest error consistency (i.e. most human-like error patterns, Figure 19(d)). The ViT-22B models have the highest ever recorded shape bias in vision models: while most models have a strong texture bias (approx. 20–30% shape bias / 70–80% texture bias) (Geirhos et al., 2019), humans are at 96% shape / 4% texture bias, and ViT-22B-384 achieves a previously unseen 87% shape bias / 13% texture bias (Figure 8). Overall, ViT-22B measurably improves alignment to human visual object recognition.
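Shape bias is computed on cue-conflict images, whose shape and texture suggest different classes: among the trials where the model's decision matches one of the two cues, it is the fraction matching the shape. A plain-Python sketch with hypothetical names; skipping trials whose two cues coincide mirrors the usual protocol of Geirhos et al. (2019) as we understand it.

```python
def shape_bias(decisions):
    """Fraction of shape decisions among cue-conflict trials decided by shape or texture.

    decisions: list of (predicted, shape_label, texture_label) tuples. Trials where
    the prediction matches neither cue, or where both cues agree, are ignored.
    """
    conflict = [(p, s, t) for p, s, t in decisions if s != t]
    shape = sum(1 for p, s, t in conflict if p == s)
    texture = sum(1 for p, s, t in conflict if p == t)
    return shape / (shape + texture)

trials = [
    ("cat", "cat", "elephant"),       # shape decision
    ("elephant", "cat", "elephant"),  # texture decision
    ("cat", "cat", "clock"),          # shape decision
    ("dog", "cat", "clock"),          # neither cue: ignored
    ("cat", "cat", "cat"),            # no conflict: ignored
]
bias = shape_bias(trials)             # 2 shape vs 1 texture decision
```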
5.3 Plex - pretrained large model extensions
Tran et al. (2022) comprehensively evaluate the reliability of models through the lens of uncertainty, robustness (see Section 4.2.3) and adaptation (see Section 4.2.2). We focus here on the first aspect of that benchmark. To this end, we consider (1) the OOD robustness under covariate shift with ImageNet-C (Hendrycks and Dietterich, 2019), which we evaluate not only with the accuracy but also uncertainty metrics measuring the calibration (NLL, ECE) and the selective prediction (El-Yaniv and Wiener, 2010) (OC-AUC, see Section 4.5.1), and (2) open-set recognition, also known as OOD detection (Fort et al., 2021), which we evaluate via the AUROC and AUPRC, with Places365 as the OOD dataset (Hendrycks et al., 2019); for more details, see Appendix I.
In Table 7, we report the performance of ViT-L and ViT-22B (both with resolution 384) fine-tuned on ImageNet. To put in perspective the strong gains of ViT-22B, we also show Plex-L, a ViT-L equipped with the two components advocated by Tran et al. (2022), namely efficient ensembles (Wen et al., 2019) and heteroscedastic layers (Collier et al., 2021). We discuss the challenges and the results of the usage of those components at the 22B scale (Plex-22B) in Appendix I.
5.4 Calibration
Along with the robustness of Section 4.2.3, it is also natural to wonder how the calibration property of ViT evolves as the scale increases. To this end, we focus on the study of Minderer et al. (2021) that we extend with ViT-22B.
In Figure 9, we consider ViT-22B fine-tuned on ImageNet (resolution 384) and report the error (i.e., one minus accuracy) versus the calibration, as measured by the expected calibration error (ECE) (Naeini et al., 2015; Guo et al., 2017). We see how ViT-22B remarkably improves the tradeoff between accuracy and calibration. The conclusion holds both without (left) and with (right) a temperature-scaling of the logits that was observed to better capture the calibration trends across model families (Minderer et al., 2021). More details can be found in Appendix H.
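ECE bins predictions by confidence and averages the gap between confidence and accuracy, weighted by bin size; temperature scaling divides the logits by a scalar fitted on held-out data. The sketch below is plain Python with hypothetical names; 15 equal-width bins and a grid search for the temperature are assumptions, not necessarily the settings behind Figure 9.

```python
import math

def softmax(z, temperature=1.0):
    m = max(z)
    e = [math.exp((u - m) / temperature) for u in z]
    s = sum(e)
    return [u / s for u in e]

def expected_calibration_error(probs, labels, n_bins=15):
    """ECE with equal-width confidence bins."""
    bins = [[0, 0.0, 0.0] for _ in range(n_bins)]  # count, sum of confidences, sum of correct
    for p, y in zip(probs, labels):
        conf = max(p)
        pred = p.index(conf)
        b = min(int(conf * n_bins), n_bins - 1)
        bins[b][0] += 1
        bins[b][1] += conf
        bins[b][2] += float(pred == y)
    n = len(labels)
    # sum_b (n_b / n) * |conf_b - acc_b| == sum_b |sum_conf_b - sum_correct_b| / n
    return sum(abs(s_conf - s_acc) / n for cnt, s_conf, s_acc in bins if cnt)

def nll(logits, labels, temperature):
    return -sum(math.log(softmax(z, temperature)[y]) for z, y in zip(logits, labels)) / len(labels)

def fit_temperature(logits, labels, grid=tuple(0.25 * i for i in range(1, 41))):
    """Temperature minimizing held-out NLL, found by a simple grid search."""
    return min(grid, key=lambda t: nll(logits, labels, t))

ece = expected_calibration_error([[0.9, 0.1], [0.6, 0.4]], labels=[0, 1])
```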
5.5 Distillation
We perform model distillation (Hinton et al., 2015) to compress the ViT-22B into smaller, more widely usable ViTs. We distill ViT-22B into ViT-B/16 and ViT-L/16 by following the procedure of Beyer et al. (2022b). Using ImageNet-finetuned (at 384px) ViT-22B, we annotated 500 random augmentations and mixup transforms of each ImageNet image with ViT-22B logits. Then, we minimize the KL divergence between the student and the teacher predictive distributions. We train for 1000 epochs after initializing the student architecture from checkpoints pre-trained on JFT. The results are shown in Table 8, and we see that we achieve new SOTA on both the ViT-B and ViT-L sizes.
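The distillation objective reduces to a KL divergence between the teacher's and the student's predictive distributions on the same augmented view. A plain-Python sketch with hypothetical names, using temperature 1 (an assumption):

```python
import math

def log_softmax(z):
    m = max(z)
    lse = m + math.log(sum(math.exp(u - m) for u in z))
    return [u - lse for u in z]

def distillation_kl(teacher_logits, student_logits):
    """KL(teacher || student) between predictive distributions for one augmented image.

    The teacher logits can be precomputed offline for each augmented view, so the
    teacher need not be run while the student trains.
    """
    lt, ls = log_softmax(teacher_logits), log_softmax(student_logits)
    return sum(math.exp(a) * (a - b) for a, b in zip(lt, ls))
```

Since the teacher term is fixed, minimizing this KL with respect to the student is equivalent to minimizing cross-entropy against the teacher's soft labels.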
Conclusion
We presented ViT-22B, the currently largest vision transformer model at 22 billion parameters. We show that with small, but critical changes to the original architecture, we can achieve both excellent hardware utilization and training stability, yielding a model that advances the SOTA on several benchmarks. In particular, great performance can be achieved using the frozen model to produce embeddings, and then training thin layers on top. Our evaluations further show that ViT-22B is more aligned with humans when it comes to shape and texture bias, and offers benefits in fairness and robustness, when compared to existing models.
Acknowledgment
We would like to thank Jasper Uijlings, Jeremy Cohen, Arushi Goel, Radu Soricut, Xingyi Zhou, Lluis Castrejon, Adam Paszke, Joelle Barral, Federico Lebron, Blake Hechtman, and Peter Hawkins. Their expertise and unwavering support played a crucial role in the completion of this paper. We also acknowledge the collaboration and dedication of the talented researchers and engineers at Google Research.
References
Appendix A Zero-shot Classification Examples
Figure 10 contains example zero-shot classifications of generated images. These images were provided by the Parti (Yu et al., 2022b) and Imagen (Saharia et al., 2022) models. The training data for the ViT-22B vision backbone and the LiT text backbone was created before these models were trained, therefore these images are not present in the training data. Further, the objects and scenes contained in these images are highly out-of-distribution relative to the distribution of natural images on the web.
Appendix B Scalability
When scaling up the default ViT architecture, we encountered training instability in ViT at Adam . Initially, the loss would decrease as normal, but within 2000 steps the loss steadily increased. Figure 1 shows the behavior of attention logits during training for an 8B parameter model. Without normalization, attention logits quickly grow to over in magnitude, resulting in one-hot attention weights after the softmax, and subsequently unstable training losses and gradients.
To avoid instability, the learning rate of ViT was originally reduced with increasing model scale, from 1e-3 down to 4e-4 for ViT-H (Dosovitskiy et al., 2021). We retrain models up to ViT-L, comparing models trained as in the original ViT with models that use the normalization/reduced precision. For the latter, the learning rate is kept at 1e-3 and not reduced for larger models. With the QK-normalization, the higher 1e-3 learning rate remains stable. The results, shown in Figure 11, demonstrate increasing benefits with scale, likely due to enabling the larger learning rate.
Appendix C Model Card
This appendix presents the model card (Mitchell et al., 2019) of the ViT-22B model.
Appendix D Transfer to image classification: More results and additional details
An alternative to linear probing with SGD is to use the convex optimization technique L-BFGS (Byrd et al., 1995). It is very effective and has strict convergence guarantees. We compare SGD and L-BFGS for a variety of ViT models using the ImageNet-1k dataset. Specifically, we precompute image embeddings by resizing input images to 224px resolution and then solve the multiclass logistic regression problem with L-BFGS. We also sweep the L2 regularization parameter and select the optimal one using 20000 holdout images from the training data (approximately 2% of the training data). In Table 10 we compare the resulting model with the SGD baseline from the main text. The comparison shows that L-BFGS matches or lags behind the SGD approach, so we selected the latter technique for our core experiments.
D.2 Out of distribution classification
D.3 Head2Toe
The cost of fine-tuning a model during transfer learning goes up with increased model size and often requires the same level of resources as training the model itself. Linear probing, on the other hand, is much cheaper to run; however, it often performs worse than fine-tuning. Recent work showed that training a linear classifier on top of the intermediate features can provide significant gains compared to using the last layer only, especially for target tasks that are significantly different from the original pre-training task (Evci et al., 2022; Adler et al., 2020; Khalifa et al., 2022).
In Table 12 we compare Head2Toe (Evci et al., 2022) with Linear probe on common vision benchmarks and VTAB-1k (Zhai et al., 2019). We include Finetuning results as a comparison point. We use a simplified version of Head2Toe with no feature selection. Experimental details are shared below. Head2Toe achieves 7% better results on VTAB-1k, but fails to match the full finetuning performance (-6%). On the other benchmarks (CIFARs, Flowers and Pets), all methods perform similarly. Head2Toe improves over Linear only for the CIFAR-100 task; for the remaining tasks it either achieves the same performance or worse (Pets).
All experiments presented here use images at the default resolution of 224. Head2Toe uses the following intermediate features: (1) the output of each of the 48 blocks, (2) the features after the positional embedding, (3) the features after the pooling head, and (4) the pre-logits and logits. We average each of these features over the token dimension and concatenate them, resulting in a 349,081-dimensional feature vector. In contrast, the linear probe uses the 6144-dimensional pre-logit features, which makes Head2Toe training roughly 50 times more expensive. However, given the extraordinary size of the original model, Head2Toe still requires far fewer FLOPs and less memory than fine-tuning, on the order of 1000x less; the exact factor depends on the number of classes. For all tasks (4 standard and 19 VTAB-1k), we search over two learning rates (0.01, 0.001) and two training lengths (500 and 10,000 steps; 500 and 2,500 steps for VTAB-1k) using the validation set.
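The Head2Toe feature construction above amounts to pooling and concatenation. The following sketch uses plain lists with toy shapes far smaller than the real 6144-wide ViT-22B features; the function names are hypothetical. It applies no feature selection, matching the simplified variant used here, and also spells out where the "roughly 50 times" figure comes from.

```python
def token_mean(feature_map):
    """Average a [num_tokens][dim] feature map over the token dimension."""
    num_tokens = len(feature_map)
    return [sum(tok[j] for tok in feature_map) / num_tokens for j in range(len(feature_map[0]))]


def head2toe_vector(features):
    """Concatenate token-averaged intermediate features into one flat vector.

    Each entry is either a token-level map ([num_tokens][dim]) or an already
    pooled vector such as the pre-logits or logits ([dim]).
    """
    out = []
    for f in features:
        out.extend(token_mean(f) if isinstance(f[0], list) else f)
    return out


# A linear head's cost grows with its input width, so for a fixed number of
# classes the Head2Toe head is about 349081 / 6144 times larger than the probe.
print(round(349081 / 6144, 1))  # 56.8
```

The exact ratio is about 57, which the text rounds to "roughly 50 times".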
D.4 Few-shot
We replicate the experimental setup of Abnar et al. (2021) to evaluate ViT-22B and baselines on 25 tasks (Table 13) in few-shot transfer settings. Figure 12 presents the few-shot transfer results of different models using 1, 5, 10, and 25 shots. Scaling up improves performance on many tasks, but in some cases downstream accuracy does not improve with scale. This may be due to the higher dimensionality of the ViT-22B representation: as the head grows, more regularization may be needed to prevent overfitting. Further study is needed to investigate this.
Appendix E Transfer to dense prediction: More results and additional details
In this experiment, we evaluate the effect of fine-tuning versus freezing the ViT-22B backbone when transferring to semantic segmentation. The results are shown in Table 14. For the linear decoder, fine-tuning yields much better performance than using frozen features. For the UperNet decoder, however, the gap between fine-tuning and freezing the backbone is much smaller. This can be explained by the fact that UperNet has times more parameters than the linear model. Figure 6 shows qualitative results using UperNet.
E.2 Monocular Depth Estimation
We sub-sample videos to 5 fps, and crop and resize frames to resolution (both RGB inputs and depth targets). The LiDAR projection is done after cropping and resizing, to retain a high-quality signal. For ViT-L, we upscale the RGB input frames to resolution to account for the larger patch size, while keeping the same information content as for ViT-e and ViT-22B, which both use a patch size of 14. For evaluation frames, we use a simple center-crop. For training, we use Inception-style (Szegedy et al., 2015) random-resized crops as our only form of data augmentation. We ensure that at least of the original frame is retained after cropping.
For efficiency reasons, we pre-compute ViT-22B feature maps for 1,024,000 randomly sampled and augmented frames from the training set, which amounts to approx. 6.4 epochs of training data. When training the decoder, we iterate over these pre-computed feature maps in random order, irrespective of the number of training steps used. We follow the same protocol for all compared models.
E.2.2 Decoder Architectures
We largely follow the design of Ranftl et al. (2021), using four reassemble and fusion blocks that process the ViT feature map at , , , and spatial resolutions. We use features at each stage and thus can omit the projection convolution in the fusion block. The final fusion stage feeds into a monocular depth estimation head, where we use the default features and adjust the final re-sampling stage to yield the desired resolution of . As in Ranftl et al. (2021), we use neither dropout nor batch normalization for depth estimation.
For efficiency, we reuse the same ViT feature map at each stage. We empirically verified that this does not significantly affect results: our implementation of DPT using four ViT-22B feature maps (from layers 12, 24, 36, and 48), each normalized with LayerNorm, obtained scores similar to those reported in Table 5: MSE, AbsRel, , 0.906 , 0.979 . Directly feeding pre-norm feature maps led to instabilities.
For the ViT-e and ViT-L baselines, the linear decoder is exactly the same except for a much smaller input feature dimension ( for ViT-e and for ViT-L). Thus, the linear decoder on top of ViT-22B has more capacity than the same decoder on top of ViT-e or ViT-L. We controlled for this in two ways: (a) using a convolution on the ViT-22B features, we down-project them to dimensions to match the feature map size of ViT-e, or (b) using a large hidden dimension ( in ViT-e’s decoder and in ViT-L’s decoder) after the first transposed convolution layer, we approximately matched the number of parameters across the three models. With control (a), performance for ViT-22B stayed roughly the same at relative absolute error (AbsRel). With control (b), the relative absolute error of the baselines did not change substantially: for ViT-e and for ViT-L. We therefore report results without these controls in Table 5.
E.2.3 Training Details
We train the decoder for 300k steps with a batch size of 64 using Adam (Kingma and Ba, 2015), clipping gradients to a global norm of 0.05 to stabilize training. The learning rate increases linearly from 0 to 0.0002 over the first 2500 steps and then decays back to 0 with a cosine schedule (Loshchilov and Hutter, 2017).
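The optimization settings above can be sketched as a learning-rate function and a gradient-clipping helper. This is a minimal illustration assuming the standard half-cosine decay after a linear warmup. The exact formula used in the authors' code (for example, whether it matches Optax's warmup-cosine schedule step for step) is an assumption, and the function names are hypothetical.

```python
import math


def learning_rate(step, peak=2e-4, warmup_steps=2500, total_steps=300_000):
    """Linear warmup from 0 to `peak`, then half-cosine decay back to 0."""
    if step < warmup_steps:
        return peak * step / warmup_steps
    progress = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    return peak * 0.5 * (1.0 + math.cos(math.pi * progress))


def clip_by_global_norm(grads, max_norm=0.05):
    """Rescale a flat list of gradients so their joint L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    return [g * (max_norm / norm) for g in grads]


for step in (0, 1250, 2500, 151_250, 300_000):
    print(step, learning_rate(step))
```

At a batch size of 64, 300k steps amount to 19.2M decoder examples, i.e. about 18.75 passes over the 1,024,000 pre-computed feature maps.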
E.2.4 Metrics
We quantify performance using the following standard depth estimation metrics from the literature (Hermann et al., 2020; Eigen et al., 2014), and also report the MSE loss on the validation set: AbsRel measures the mean absolute error between the ground truth and predicted depth relative to the ground truth, while the inlier fraction metrics () measure the fraction of valid pixels within a certain percentage from ground truth. All metrics were measured after undoing the log transformation.
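The two metric families above can be computed as follows once the log transform is undone. This is a hedged sketch. It treats pixels with non-positive ground truth as invalid (e.g. no LiDAR return), and it uses the common convention from the literature that a pixel is an inlier when max(pred/gt, gt/pred) falls below 1.25, 1.25², or 1.25³. The thresholds actually used in this evaluation are not stated here, so these defaults are an assumption.

```python
import math


def depth_metrics(log_pred, gt, thresholds=(1.25, 1.25 ** 2, 1.25 ** 3)):
    """AbsRel and inlier fractions for predictions made in log-depth space.

    Pixels with non-positive ground truth are treated as invalid and skipped.
    The 1.25**k thresholds are a common convention, assumed here.
    """
    pairs = [(math.exp(lp), g) for lp, g in zip(log_pred, gt) if g > 0]  # undo the log
    abs_rel = sum(abs(p - g) / g for p, g in pairs) / len(pairs)
    ratios = [max(p / g, g / p) for p, g in pairs]
    inliers = [sum(r < t for r in ratios) / len(ratios) for t in thresholds]
    return abs_rel, inliers


gt = [1.0, 2.0, 4.0, 0.0]  # the last pixel has no valid depth
pred = [math.log(v) for v in (1.0, 2.2, 8.0, 5.0)]
print(depth_metrics(pred, gt))
```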
E.2.5 Qualitative Results
We report qualitative depth predictions by DPT from different ViT backbones in Figure 13, and absolute prediction errors in Figure 14.
Appendix F Video Classification
We sample 128 and 32 frames with a stride of 2 frames from Kinetics 400 (Kay et al., 2017) and Moments in Time (Monfort et al., 2019) videos, respectively. For both ViT-22B and ViT-e, we rely on the frozen, pre-trained models and use the pre-logit feature representation to extract a single embedding per frame. This results in token sequences of length 128 and 32, respectively, which are then processed by a shallow transformer model equipped with a class-token classifier.
This is in contrast to CoCa (Yu et al., 2022a), which uses one token per image patch for its video classification experiments and a resolution of 576px (compared to 224px in our experiments), resulting in much longer token sequences. We explored using one token per image patch (i.e., unpooled features) in preliminary experiments, but found that this leads to inferior performance. One potential reason is that CoCa applies a contrastive loss to a pooled feature representation and additionally feeds the unpooled token sequences to a generative decoder. This might lead to a different structure in the unpooled representation than the supervised classification loss used to pretrain ViT-22B and ViT-e.
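The token-sequence lengths for the per-frame approach follow directly from the frame sampling. The sketch below illustrates this with plain lists. The helper names, the video lengths of 300 and 90 frames, and the rule of clamping indices that run past the end of a video are all illustrative assumptions rather than details from the text.

```python
def frame_indices(num_video_frames, num_samples, stride=2, start=0):
    """Indices of `num_samples` frames taken every `stride` frames from `start`.

    Indices past the end of the video are clamped to the last frame; this
    padding rule is an assumption for illustration.
    """
    return [min(start + i * stride, num_video_frames - 1) for i in range(num_samples)]


def token_sequence(frame_embeddings, cls_token):
    """One pre-logit embedding per frame, with a class token prepended."""
    return [cls_token] + list(frame_embeddings)


kinetics = frame_indices(300, 128)  # 128 frames spanning 255 source frames
mit = frame_indices(90, 32)         # 32 frames for a Moments in Time clip
print(kinetics[:3], kinetics[-1], len(mit))
```

With a class token, the shallow transformer therefore sees 129 tokens per Kinetics clip and 33 per Moments in Time clip, far fewer than CoCa's one-token-per-patch sequences.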
To facilitate experimentation, we pre-compute frozen features for the two ViT variants we consider, using the same augmentations as Arnab et al. (2021). To improve the robustness of our model and prevent overfitting, we process the entire training set ten times, with different data augmentations for every pass. We train for 30 epochs on these precomputed features with a batch size of 256, using SGD with momentum and a cosine schedule with a linear warmup of 2.5 epochs. We sweep the following hyperparameters and corresponding value ranges to train our video model: transformer layers of width , using a learning rate in and a weight decay in .
Appendix G Fairness
We report the full experimental results described in Section 4.5.1 for all three evaluation metrics: (1) classification accuracy (ACC), (2) expected calibration error (ECE) (Naeini et al., 2015; Guo et al., 2017), and (3) Oracle Collaborative AUC (OC-AUC) (Kivlichan et al., 2021). ECE measures calibration, while OC-AUC is computed from four binned quantities (true/false positives/negatives) as a function of a linearly spaced set of thresholds and score bins. The full results are presented in Figure 16, Figure 17, and Figure 18.
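For reference, ECE as commonly defined (Naeini et al., 2015; Guo et al., 2017) bins predictions by their top-class confidence and averages the gap between accuracy and confidence, weighted by bin size. Below is a minimal sketch with equal-width bins. The choice of 15 bins is a frequent default and an assumption here, not a setting taken from the text.

```python
def expected_calibration_error(probs, labels, num_bins=15):
    """Top-label ECE with equal-width confidence bins.

    ECE = sum_b (n_b / n) * |accuracy_b - confidence_b|; 15 bins is an assumed default.
    """
    counts = [0] * num_bins
    conf_sums = [0.0] * num_bins
    correct_sums = [0.0] * num_bins
    for p, y in zip(probs, labels):
        conf = max(p)
        b = min(int(conf * num_bins), num_bins - 1)
        counts[b] += 1
        conf_sums[b] += conf
        correct_sums[b] += float(p.index(conf) == y)
    n = len(labels)
    # (n_b / n) * |acc_b - conf_b| simplifies to |correct_b - conf_sum_b| / n.
    return sum(abs(correct_sums[b] - conf_sums[b]) for b in range(num_bins) if counts[b]) / n


probs = [[0.9, 0.1], [0.9, 0.1], [0.6, 0.4], [0.6, 0.4]]
labels = [0, 1, 0, 0]
print(round(expected_calibration_error(probs, labels), 3))  # 0.4
```

In the toy example, the 0.9-confidence bin is overconfident (50% accurate) and the 0.6-confidence bin is underconfident (100% accurate), and both gaps contribute.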
Appendix H Calibration
We follow the setup of Minderer et al. (2021) exactly: since temperature scaling (Guo et al., 2017) requires held-out data, we use 20% of the ImageNet validation set to learn the temperature parameter and report accuracy and expected calibration error on the remaining 80%.
Moreover, since the expected calibration error is defined with respect to a probability distribution normalised over the classes, we use a softmax loss function during fine-tuning; the sigmoid loss function is defined independently for each class and does not yield the required normalisation. We fine-tune for 20k steps with a learning rate of 0.03.
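Temperature scaling learns a single scalar T that divides the logits before the softmax, chosen to minimize negative log-likelihood on the held-out 20%. The sketch below uses a simple grid search over T rather than a gradient-based optimizer, which is a simplification. The grid range, helper names, and toy data are illustrative assumptions.

```python
import math
import random


def mean_nll(logits, labels, temperature):
    """Average negative log-likelihood of softmax(logits / temperature)."""
    total = 0.0
    for z, y in zip(logits, labels):
        scaled = [v / temperature for v in z]
        m = max(scaled)
        log_norm = m + math.log(sum(math.exp(v - m) for v in scaled))
        total += log_norm - scaled[y]
    return total / len(labels)


def fit_temperature(logits, labels, grid=None):
    """Pick the temperature minimizing held-out NLL by a simple grid search."""
    grid = grid or [0.05 * k for k in range(1, 401)]  # 0.05 .. 20.0
    return min(grid, key=lambda t: mean_nll(logits, labels, t))


def split_for_calibration(n, fraction=0.2, seed=0):
    """Shuffle indices; the first `fraction` fits T, the rest is used for evaluation."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    cut = int(round(fraction * n))
    return idx[:cut], idx[cut:]


# Overconfident toy model: always logits [10, 0], but correct only 80% of the time.
logits = [[10.0, 0.0]] * 10
labels = [0] * 8 + [1] * 2
T = fit_temperature(logits, labels)
print(round(T, 2))  # about 7.2, since softmax(10 / T) should equal 0.8
```

On the 50,000-image ImageNet validation set, the 20/80 split gives 10,000 images for fitting T and 40,000 for reporting accuracy and ECE.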
We reuse the plotting tools provided at https://github.com/google-research/robustness_metrics/tree/master/robustness_metrics/projects/revisiting_calibration.
Appendix I Plex
We start by providing some details about the datasets and the different evaluation protocols based on Djolonga et al. (2020).
This variant of the ImageNet dataset contains algorithmically generated corruptions (e.g., blur and noise) applied to the ImageNet test-set. The results that we report in the paper are averaged over the 16 corruptions and over their 5 different intensity levels.
In this task, we try to classify whether a given test point belongs to the in-distribution dataset (in our case, ImageNet) or an out-of-distribution dataset (following Hendrycks et al. (2019); Tran et al. (2022), we take Places365 which consists of about 1.8 million images from 365 scene categories, where there are at most 5000 images per category (Zhou et al., 2017a)).
To perform the detection, we use the maximum softmax probability (MSP) as the score (Hendrycks et al., 2019; Tran et al., 2022), and we evaluate the resulting binary classification task with AUROC and AUPRC.
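The MSP detector and its AUROC can be sketched in a few lines. In-distribution images are treated as positives scored by their MSP, and AUROC is computed as the probability that a random positive outscores a random negative (the Mann-Whitney form, with ties counting one half). The quadratic loop is for clarity only. AUPRC is omitted, and the toy logits are made up for illustration.

```python
import math


def max_softmax_prob(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    return max(exps) / sum(exps)


def auroc(pos_scores, neg_scores):
    """Probability that a random positive outscores a random negative (ties count 1/2)."""
    wins = 0.0
    for p in pos_scores:
        for q in neg_scores:
            wins += 1.0 if p > q else 0.5 if p == q else 0.0
    return wins / (len(pos_scores) * len(neg_scores))


# In-distribution images are the positive class, scored by their MSP.
in_dist = [[8.0, 0.0, 0.0], [5.0, 1.0, 0.0], [3.0, 2.9, 0.0]]
ood = [[1.0, 1.0, 1.0], [3.0, 1.0, 1.0]]
score = auroc([max_softmax_prob(z) for z in in_dist], [max_softmax_prob(z) for z in ood])
print(score)
```

Here the ambiguous in-distribution example (logits 3.0 vs 2.9) has a lower MSP than the confident-looking OOD example, so one of the six pairs is ranked wrongly.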
In this task, a model may defer its predictions to human experts when it is not confident enough. In particular, this task jointly assesses a model’s predictive performance and quality of uncertainty estimates (El-Yaniv and Wiener, 2010). Following Tran et al. (2022), we measure the performance with the oracle collaborative AUC (Kivlichan et al., 2021), with a review fraction of 0.5% of all predictions.
For this evaluation, we aim to demonstrate the model's ability to capture the inherent ambiguity of image labels assigned by humans. Following Tran et al. (2022), we focus on the ImageNet ReaL-H dataset, which exploits the human ratings from Beyer et al. (2020) to construct a label distribution representing rater uncertainty for each image. Performance is measured by the negative log likelihood computed with respect to the soft labels (i.e., vectors in the simplex as opposed to the usual one-hot vectors).
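Concretely, the soft-label negative log likelihood for one image is the cross-entropy between the rater distribution q and the model's softmax p, i.e. -Σ_c q_c log p_c. It reduces to the usual loss when q is one-hot. The minimal per-image sketch below would be averaged over the dataset; the example rater split is made up.

```python
import math


def soft_label_nll(logits, soft_labels):
    """Cross-entropy -sum_c q_c * log p_c between a rater label distribution q
    (a point in the probability simplex) and the model's softmax p."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(q * (z - log_norm) for q, z in zip(soft_labels, logits))


# Raters split 70/30 between two classes; a one-hot label would ignore the 30%.
print(soft_label_nll([2.0, 1.0, -1.0], [0.7, 0.3, 0.0]))
```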
I.2 Details about the Plex architecture
Plex (Tran et al., 2022) calls for BatchEnsemble layers (Wen et al., 2019) to be added to the model architecture during both pre-training and fine-tuning. In Tran et al. (2022), the BatchEnsemble layers are added only to a few of the last layers of the encoder to reduce computational and memory cost. The efficient implementation of ViT-22B constrains us to apply BatchEnsemble layers throughout the network. Due to the high cost of training ViT-22B, we add the BatchEnsemble layers during the fine-tuning stage only. We replace all Dense layers in ViT-22B with BatchEnsemble layers, except for the Dense layer in the MLP of the pooling head. Tran et al. (2022) further suggest replacing the final Dense layer of the network with a heteroscedastic output head (Collier et al., 2021). We follow this approach and evaluate both a heteroscedastic and a BatchEnsemble final layer.
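A BatchEnsemble layer (Wen et al., 2019) shares one kernel W across ensemble members and gives each member i cheap rank-1 "fast weights" r_i and s_i. Member i computes ((x ∘ r_i) W) ∘ s_i + b_i, which equals a dense layer with effective kernel W ∘ (r_i s_iᵀ). The pure-Python sketch below is illustrative only. It uses per-member biases and a plain ±1 random-sign initialization of the fast weights. It does not reproduce the exact initialization parameterized by the -0.5 setting in the hyperparameter details, and the class name is hypothetical.

```python
import math
import random


class BatchEnsembleDense:
    """Dense layer shared by E ensemble members, each with rank-1 fast weights.

    Member i computes y = ((x * r_i) @ W) * s_i + b_i, which equals a dense
    layer with the effective kernel W * outer(r_i, s_i).
    """

    def __init__(self, in_dim, out_dim, ensemble_size, seed=0):
        rng = random.Random(seed)
        self.W = [[rng.gauss(0.0, 1.0 / math.sqrt(in_dim)) for _ in range(out_dim)]
                  for _ in range(in_dim)]
        # Random +/-1 fast-weight initialization (an illustrative choice).
        self.r = [[rng.choice((-1.0, 1.0)) for _ in range(in_dim)] for _ in range(ensemble_size)]
        self.s = [[rng.choice((-1.0, 1.0)) for _ in range(out_dim)] for _ in range(ensemble_size)]
        self.b = [[0.0] * out_dim for _ in range(ensemble_size)]

    def __call__(self, x, member):
        xr = [xi * ri for xi, ri in zip(x, self.r[member])]
        h = [sum(xr[i] * self.W[i][j] for i in range(len(xr))) for j in range(len(self.W[0]))]
        return [hj * sj + bj for hj, sj, bj in zip(h, self.s[member], self.b[member])]

    def num_params(self):
        in_dim, out_dim, e = len(self.W), len(self.W[0]), len(self.r)
        return in_dim * out_dim + e * (in_dim + 2 * out_dim)


layer = BatchEnsembleDense(in_dim=4, out_dim=3, ensemble_size=2)  # two members
x = [0.5, -1.0, 2.0, 0.0]
print([layer(x, i) for i in range(2)], layer.num_params())
```

The parameter count shows why the approach is cheap: each additional member adds only in_dim + 2 * out_dim parameters, compared with the shared in_dim * out_dim kernel.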
I.3 Details about the hyperparameters
All models were fine-tuned on ImageNet with a batch size of 512. We swept over 20k and 40k fine-tuning steps and learning rates of 0.01 and 0.03. As already observed by Tran et al. (2022), Plex models performed better with 40k fine-tuning steps, and they also performed better with a learning rate of 0.03. We used two BatchEnsemble members and set the random sign initialization of the BatchEnsemble layers to -0.5.
For the experiments with a heteroscedastic output layer, 1k MC samples were used and the low-rank component of the covariance matrix employed 15 factors. Furthermore, we report results for a temperature parameter of 5 (after a hyperparameter search over the [0.5, 10] range).
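The heteroscedastic head (Collier et al., 2021) treats the logits as a Gaussian u ~ N(μ, VVᵀ + diag(d²)) with a low-rank factor V, and predicts the Monte Carlo average of softmax(u / τ). The sketch below uses the settings quoted above (1k samples, 15 factors, temperature 5). It is a simplified illustration. In the real head, μ, V and d are network outputs that depend on the input, whereas here they are fixed hypothetical numbers, and the function names are illustrative.

```python
import math
import random


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def heteroscedastic_probs(mu, V, d, temperature=5.0, num_samples=1000, seed=0):
    """Monte Carlo estimate of E[softmax(u / temperature)] with
    u ~ N(mu, V V^T + diag(d^2)); V is a K x R low-rank factor matrix."""
    rng = random.Random(seed)
    K, R = len(mu), len(V[0])
    avg = [0.0] * K
    for _ in range(num_samples):
        eps_low = [rng.gauss(0.0, 1.0) for _ in range(R)]   # shared low-rank noise
        eps_diag = [rng.gauss(0.0, 1.0) for _ in range(K)]  # independent per-class noise
        u = [mu[k] + sum(V[k][r] * eps_low[r] for r in range(R)) + d[k] * eps_diag[k]
             for k in range(K)]
        for k, p in enumerate(softmax([v / temperature for v in u])):
            avg[k] += p / num_samples
    return avg


mu = [2.0, 0.5, -1.0]               # hypothetical mean logits for one image
V = [[0.3] * 15 for _ in range(3)]  # 15 low-rank factors, as in the text
d = [0.1, 0.1, 0.1]
probs = heteroscedastic_probs(mu, V, d)
print([round(p, 3) for p in probs])
```

With zero noise (V = 0, d = 0), the head collapses to an ordinary tempered softmax of μ, which is a useful sanity check.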
Unlike most of the models in the rest of the paper, the models in this section are fine-tuned with a softmax loss function. We do so to be consistent with the design choices of Tran et al. (2022) and because several of the metrics employed (e.g., ECE) require a distribution normalised across the classes.
I.4 Results of Plex-22B and challenges
In Table 15, we report the results of ViT-L/32, Plex-L/32, ViT-22B and the extensions of Plex to the 22B scale, Plex-22B, with the BatchEnsemble (BE) and heteroscedastic (HET) heads. All the models are fine-tuned at a resolution of 384.
The main observation is that the increased scale of ViT-22B comes with substantial improvements across all metrics, except for the label uncertainty over ImageNet-ReaL-H.
More surprisingly, across all metrics (except for the label uncertainty over ImageNet-ReaL-H), the Plex-22B variants perform worse than the vanilla ViT-22B model. The findings of Tran et al. (2022), where Plex consistently leads to improvements at the S, B and L scales, therefore do not extend to the 22B scale.
We believe that this surprising observation may be related to specific challenges faced at the 22B scale:
Pre-training vs. fine-tuning: While Tran et al. (2022) introduce BatchEnsemble layers already at pre-training time, the high training cost of ViT-22B forces us to add them only at fine-tuning time. In this regime, it may not be possible to properly learn the BatchEnsemble and heteroscedastic layers. Moreover, while fine-tuning standard ViT backbones benefits from a well-performing and robust recipe, namely initializing the final Dense layer kernel to all zeros, we do not have an equivalent approach for the added Plex components.
Hyperparameter tuning: Even though we already covered a reasonable combination of hyperparameters (fine-tuning duration, learning rate and temperature), it is possible that a finer-grained search is required to close the performance gap.
Numerical stability: As discussed in Section 2, it was required to use particular techniques to stabilize the training of ViT-22B. We hypothesise that similar techniques may have to be developed specifically for the Plex components (BatchEnsemble and heteroscedastic layers) to keep their efficiency at this scale.
Appendix J Error Consistency & Human Alignment
In Section 4.5.2, we described results for ViT-22B fine-tuned on ImageNet on the model-vs-human benchmark. Figure 19(a), Figure 19(b), Figure 19(c), and Figure 19(d) provide additional benchmarking results.
Appendix K Perceptual similarity
Kumar et al. (2022) show a trade-off between the accuracy of recent ImageNet classifiers and their inherent ability to capture perceptual similarity. Here, we explore whether large-scale classification on a more diverse training dataset than ImageNet can break this trade-off. To compare the perceptual similarity of ViT-22B with prior ImageNet-trained models, we make minor changes to adapt ViT-22B to low-resolution ImageNet. ViT-22B fine-tuned on ImageNet achieves 84.2% accuracy on ImageNet, which is better than the best models trained directly on ImageNet. As in Zhang et al. (2018), we assess the ability of ViT-22B to capture perceptual similarity using 48 intermediate representations. The perceptual score of ViT-22B (64.9) is much lower than that of all other models, indicating that models trained on large-scale classification also lie on the observed accuracy-perceptual similarity Pareto frontier.
To make a fair comparison with the models in Kumar et al. (2022), we adapt ViT-22B to low-resolution ImageNet as follows. Directly fine-tuning ViT-22B on these images with the default patch size of 14 has two undesirable consequences: (a) a short sequence length of 16 and (b) cropping of 8 pixels on the right borders. So, as proposed by Beyer et al. (2022a), we resize the trained embedding layer from the default patch size of 14 to a patch size of 8, which leads to a longer sequence length of 64. We then follow standard fine-tuning protocols.
We make three more observations: (1) an untrained ViT-22B obtains an even lower Perceptual Score of 62.3, so some amount of training is desirable; (2) ViT-e lies in the same ballpark as ViT-22B, with slightly lower accuracy and Perceptual Score; and (3) with the newly proposed Mean Pool distance function (Kumar et al., 2022), ViT-22B improves its Perceptual Score to 66.2.
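The sequence-length and cropping figures in the patch-size discussion follow from integer division. The following sketch assumes 64×64 inputs, which is the image size consistent with the 16-token and 8-pixel figures; the text does not state it explicitly. It checks only the arithmetic and does not reproduce the resizing of the patch-embedding weights proposed by Beyer et al. (2022a). The helper name is hypothetical.

```python
def patch_grid(image_size, patch_size):
    """Tokens per side, total sequence length, and pixels dropped per side
    when a square image does not divide evenly into patches."""
    per_side = image_size // patch_size
    return per_side, per_side * per_side, image_size - per_side * patch_size


print(patch_grid(64, 14))  # (4, 16, 8)
print(patch_grid(64, 8))   # (8, 64, 0)
```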
Appendix L Feature attribution analysis
To better understand how ViT-22B arrives at its predictions, we use gradient-based feature attribution methods (a.k.a. saliency maps). Figure 21 shows the result of applying Integrated Gradients (Sundararajan et al., 2017, IG) to three example datapoints before and after the ViT-22B cooldown. We find that a gray (0.5) baseline and 1024 steps yield qualitatively the best results. The images show a subtle difference in how the two model checkpoints process the example inputs, where more yellow indicates higher sensitivity. The patches in which ViT-22B processes input images are also clearly visible: the model is less sensitive around the edges of each patch, which suggests a direction for future work on handling patch edges better.
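Integrated Gradients attributes F(x) - F(x') to each input by integrating the gradient along the straight path from the baseline x' to x. In practice this is approximated with a finite number of steps. The sketch below uses a midpoint Riemann sum with the gray (0.5) baseline and 1024 steps mentioned above. A toy quadratic F with an analytic gradient stands in for ViT-22B's class score, which in practice would be differentiated with an autodiff library. The toy model and function names are illustrative assumptions.

```python
def integrated_gradients(grad_fn, x, baseline, steps=1024):
    """Midpoint Riemann approximation of Integrated Gradients:
    attr_i = (x_i - x'_i) * mean_k dF/dx_i(x' + alpha_k * (x - x'))."""
    avg_grad = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            avg_grad[i] += g / steps
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


# Toy differentiable "model" F(x) = sum_i w_i * x_i**2 with an analytic gradient.
w = [1.0, -2.0, 0.5]


def F(v):
    return sum(wi * vi * vi for wi, vi in zip(w, v))


def grad_F(v):
    return [2.0 * wi * vi for wi, vi in zip(w, v)]


x = [1.0, 0.0, 0.25]   # pixel intensities in [0, 1]
gray = [0.5] * len(x)  # gray (0.5) baseline
attr = integrated_gradients(grad_F, x, gray)
# Completeness: the attributions sum to F(x) - F(baseline).
print(sum(attr), F(x) - F(gray))
```

The completeness check is a convenient way to choose the number of steps: if the attributions do not sum to F(x) - F(x'), the path integral is under-resolved.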