Visual Transformers: Token-based Image Representation and Processing for Computer Vision
Bichen Wu, Chenfeng Xu, Xiaoliang Dai, Alvin Wan, Peizhao Zhang, Zhicheng Yan, Masayoshi Tomizuka, Joseph Gonzalez, Kurt Keutzer, Peter Vajda
Introduction
In computer vision, visual information is captured as arrays of pixels. These pixel arrays are then processed by convolutions, the de facto deep learning operator for computer vision. Although this convention has produced highly successful vision models, there are critical challenges:
1) Not all pixels are created equal: Image classification models should prioritize foreground objects over the background. Segmentation models should prioritize pedestrians over disproportionately large swaths of sky, road, vegetation etc. Nevertheless, convolutions uniformly process all image patches regardless of importance. This leads to spatial inefficiency in both computation and representation.
2) Not all images have all concepts: Low-level features such as corners and edges exist in all natural images, so applying low-level convolutional filters to all images is appropriate. However, high-level features such as ear shape exist in specific images, so applying high-level filters to all images is computationally inefficient. For example, dog features may not appear in images of flowers, vehicles, aquatic animals etc. This results in rarely-used, inapplicable filters expending a significant amount of compute.
3) Convolutions struggle to relate spatially-distant concepts: Each convolutional filter is constrained to operate on a small region, but long-range interactions between semantic concepts are vital. To relate spatially-distant concepts, previous approaches increase kernel sizes, increase model depth, or adopt new operations like dilated convolutions, global pooling, and non-local attention layers. However, by working within the pixel-convolution paradigm, these approaches at best mitigate the problem, compensating for the convolution’s weaknesses by adding model and computational complexity.
To overcome the above challenges, we address the root cause, the pixel-convolution paradigm, and introduce the Visual Transformer (VT) (Figure 1), a new paradigm to represent and process high-level concepts in images. Our intuition is that a sentence with a few words (or visual tokens) suffices to describe high-level concepts in an image. This motivates a departure from the fixed pixel-array representation later in the network; instead, we use spatial attention to convert the feature map into a compact set of semantic tokens. We then feed these tokens to a transformer, a self-attention module widely used in natural language processing to capture token interactions. The resulting visual tokens can be directly used for image-level prediction tasks (e.g., classification) or be spatially re-projected to the feature map for pixel-level prediction tasks (e.g., segmentation). Unlike convolutions, our VT can better handle the three challenges above: 1) judiciously allocating computation by attending to important regions, instead of treating all pixels equally; 2) encoding semantic concepts in a few visual tokens relevant to the image, instead of modeling all concepts across all images; and 3) relating spatially-distant concepts through self-attention in token-space.
To validate the effectiveness of VTs and understand their key components, we run controlled experiments by using VTs to replace convolutions in ResNet, a common test bed for new building blocks for image classification. We also use VTs to re-design feature-pyramid networks (FPN), a strong baseline for semantic segmentation. Our experiments show that VTs achieve higher accuracy with lower computational cost in both tasks. For the ImageNet benchmark, we replace the last stage of ResNet with VTs, reducing FLOPs of the stage by 6.9x and improving top-1 accuracy by 4.6 to 7 points. For semantic segmentation on COCO-Stuff and Look-Into-Person, VT-based FPN achieves 0.35 points higher mIoU while reducing the regular FPN module’s FLOPs by 6.4x.
Relationship to previous work
Transformers in vision models: A notable recent and relevant trend is the adoption of transformers in vision models. Dosovitskiy et al. propose a Vision Transformer (ViT), dividing an image into patches and feeding these patches (i.e., tokens) into a standard transformer. Although simple, this requires transformers to learn dense, repeatable patterns (e.g., textures), which convolutions are drastically more efficient at learning. The simplicity incurs an extremely high computational price: ViT requires up to 7 GPU years and 300M JFT dataset images to outperform competing convolutional variants. By contrast, we leverage the respective strengths of each operation, using convolutions for extracting low-level features and transformers for relating high-level concepts. We further use spatial attention to focus on important regions, instead of treating each image patch equally. This yields strong performance despite orders-of-magnitude less data and training time.
Another relevant work, DETR, adopts transformers to simplify the hand-crafted anchor matching procedure in object detection training. Although both adopt transformers, DETR is not directly comparable to our VT given their orthogonal use cases, i.e., insights from both works could be used together in one model for compounded benefit.
Graph convolutions in vision models: Our work is also related to previous efforts such as GloRe and LatentGNN, among others, that densely relate concepts in latent space using graph convolutions. To augment convolutions, these methods adopt a procedure similar to ours: (1) extracting latent variables to represent in graph nodes (analogous to our visual tokens), (2) applying graph convolution to capture node interactions (analogous to our transformer), and (3) projecting the nodes back to the feature map. Although these approaches avoid spatial redundancy, they are susceptible to concept redundancy, the second limitation listed in the introduction. In particular, by using fixed weights that are not content-aware, the graph convolution expects a fixed semantic concept in each node, regardless of whether the concept exists in the image. By contrast, a transformer uses content-aware weights, allowing visual tokens to represent varying concepts. As a result, while graph convolutions require hundreds of nodes (128, 340, and 150 nodes in the respective works) to encode potential semantic concepts, our VT uses just 16 visual tokens and attains higher accuracy. Furthermore, while modules from these works can only be added to a pretrained network to augment convolutions, VTs can replace convolution layers to save FLOPs and parameters, and support training from scratch.
Attention in vision models: In addition to being used in transformers, attention is also widely used in different forms in computer vision models. Attention was first used to modulate the feature map: attention values are computed from the input and multiplied to the feature map. Later work interprets this “modulation” as a way to make convolution spatially adaptive and content-aware. Wang et al. introduced non-local operators, equivalent to self-attention, to video understanding to capture long-range interactions. However, the computational complexity of self-attention grows quadratically with the number of pixels. Some works use self-attention to augment convolutions and reduce the compute cost by using small channel sizes for attention. Others restrict the receptive field of self-attention and use it in a convolutional manner. Subsequently, self-attention has been used as a stand-alone building block for vision models. Our work differs from all of the above since we propose a novel token-transformer paradigm to replace the inefficient pixel-convolution paradigm and achieve superior performance.
Efficient vision models: Many recent research efforts have been focusing on building vision models to achieve better performance with lower computational cost. Early work in this direction includes . Recently, people use neural architecture search to optimize network’s performance within a search space that consists of existing convolution operators. The efforts above all seek to make the common convolutional-neural net more computationally efficient. In contrast, we propose a new building block that naturally eliminates the redundant computations in the pixel-convolution paradigm.
Visual Transformer
We illustrate the overall diagram of a Visual Transformer (VT) based model in Figure 1. First, process the input image with several convolution blocks, then feed the output feature map to VTs. Our insight is to leverage the strengths of both convolutions and VTs: (1) early in the network, use convolutions to learn densely-distributed, low-level patterns and (2) later in the network, use VTs to learn and relate more sparsely-distributed, higher-order semantic concepts. At the end of the network, use visual tokens for image-level prediction tasks and use the augmented feature map for pixel-level prediction tasks.
A VT module involves three steps: First, group pixels into semantic concepts, to produce a compact set of visual tokens. Second, to model relationships between semantic concepts, apply a transformer to these visual tokens. Third, project these visual tokens back to pixel-space to obtain an augmented feature map. Similar paradigms can be found in prior work, but with one critical difference: Previous methods use hundreds of semantic concepts (termed “nodes”), whereas our VT uses as few as 16 visual tokens to achieve superior performance.
However, many high-level semantic concepts are sparse and may each appear in only a few images. As a result, the fixed set of learned weights potentially wastes computation by modeling all such high-level concepts at once. We call this a “filter-based” tokenizer, since it uses convolutional filters to extract visual tokens.
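A minimal NumPy sketch of a filter-based tokenizer may help make this concrete. It assumes that tokens are formed by a spatial softmax over a learned linear projection of the feature map, i.e. each token is a weighted average of pixels. The symbol names (`X`, `W_A`), the 14 x 14 spatial size, and the random weights are illustrative assumptions, not values taken from this text.

```python
import numpy as np

def softmax(z, axis):
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def filter_tokenizer(X, W_A):
    """Group pixels into tokens with fixed (input-independent) weights.

    X:   (HW, C) flattened feature map
    W_A: (C, L)  learned projection, one column per visual token (assumed form)
    Returns tokens of shape (L, C).
    """
    A = softmax(X @ W_A, axis=0)  # (HW, L): spatial attention, normalized over pixels
    return A.T @ X                # (L, C): each token is a weighted average of pixels

rng = np.random.default_rng(0)
X = rng.standard_normal((14 * 14, 256))      # hypothetical 14x14 map with 256 channels
W_A = rng.standard_normal((256, 16)) * 0.05  # 16 tokens, as used in the text
T = filter_tokenizer(X, W_A)
print(T.shape)
```

Because `W_A` is the same for every image, each column acts like a fixed concept detector, which is the limitation the paragraph above describes.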
3.1.2 Recurrent Tokenizer
To remedy the limitation of filter-based tokenizers, we propose a recurrent tokenizer with weights that depend on the previous layer’s visual tokens. The intuition is to let the previous layer’s tokens guide the extraction of new tokens for the current layer. The name “recurrent tokenizer” reflects the fact that the current tokens are computed from the previous ones.
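The idea of token-dependent tokenizer weights can be sketched as follows. This is an illustrative sketch under stated assumptions: the attention weights are obtained by linearly projecting the previous tokens, the spatial softmax form is assumed, and the names `T_prev` and `W_TR` are hypothetical. Token and feature-map channel sizes are taken to be equal for simplicity.

```python
import numpy as np

def softmax(z, axis):
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def recurrent_tokenizer(X, T_prev, W_TR):
    """Extract new tokens using weights derived from the previous layer's tokens.

    X:      (HW, C) flattened feature map
    T_prev: (L, C)  visual tokens from the previous VT module
    W_TR:   (C, C)  learned projection (hypothetical name)
    Returns tokens of shape (L, C).
    """
    W_R = (T_prev @ W_TR).T       # (C, L): content-dependent tokenizer weights
    A = softmax(X @ W_R, axis=0)  # (HW, L): spatial attention over pixels
    return A.T @ X                # (L, C)

rng = np.random.default_rng(0)
X = rng.standard_normal((196, 64))
W_TR = rng.standard_normal((64, 64)) * 0.1
T_a = recurrent_tokenizer(X, rng.standard_normal((16, 64)), W_TR)
T_b = recurrent_tokenizer(X, rng.standard_normal((16, 64)), W_TR)
print(T_a.shape, np.allclose(T_a, T_b))
```

The same feature map yields different tokens when the previous tokens differ, which is the property that distinguishes this tokenizer from the filter-based one.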
3.2 Transformer
After tokenization, we need to model interactions between these visual tokens. Previous works use graph convolutions to relate concepts. However, these operations use fixed weights during inference, meaning each token (or “node”) is bound to a specific concept. Graph convolutions therefore waste computation by modeling all high-level concepts, even those that appear in only a few images. To address this, we adopt transformers, which use input-dependent weights by design. As a result, transformers support visual tokens with variable meaning, covering more possible concepts with fewer tokens.
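To illustrate what "input-dependent weights" means here, the sketch below applies single-head scaled dot-product self-attention with a residual connection to a set of tokens. It is a generic attention layer, not the exact variant used in this work; the matrix names and sizes are assumptions.

```python
import numpy as np

def softmax(z, axis):
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def token_self_attention(T, W_Q, W_K, W_V):
    """Relate visual tokens to each other.

    T: (L, C) tokens; W_Q, W_K: (C, d); W_V: (C, C).
    Returns (updated tokens (L, C), attention matrix (L, L)).
    """
    Q, K = T @ W_Q, T @ W_K
    # The mixing weights are computed from the tokens themselves,
    # so they change with every input image.
    attn = softmax(Q @ K.T / np.sqrt(K.shape[1]), axis=-1)
    return T + attn @ (T @ W_V), attn

rng = np.random.default_rng(0)
L, C, d = 16, 64, 32
T = rng.standard_normal((L, C))
W_Q = rng.standard_normal((C, d)) * 0.1
W_K = rng.standard_normal((C, d)) * 0.1
W_V = rng.standard_normal((C, C)) * 0.1
T_out, attn = token_self_attention(T, W_Q, W_K, W_V)
print(T_out.shape, attn.shape)
```

With only 16 tokens the attention matrix is 16 x 16, so the quadratic cost of self-attention is negligible compared with attention over all pixels.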
We employ a standard transformer with minor changes:
3.3 Projector
Many vision tasks require pixel-level details, but such details are not preserved in visual tokens. Therefore, we fuse the transformer’s output with the feature map to refine the feature map’s pixel-array representation.
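One plausible way to realize this fusion is for every pixel to attend over the visual tokens and add the attended result back to the feature map. The sketch below assumes that form; the query/key projections `W_Q` and `W_K` are hypothetical names, and token and pixel channel sizes are assumed equal so the residual addition is well defined.

```python
import numpy as np

def softmax(z, axis):
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def project_tokens(X, T, W_Q, W_K):
    """Fuse token information back into the pixel array.

    X: (HW, C) feature map; T: (L, C) tokens; W_Q, W_K: (C, d).
    Returns the refined feature map of shape (HW, C).
    """
    # Each pixel (query) scores every token (key); softmax is over tokens.
    attn = softmax((X @ W_Q) @ (T @ W_K).T, axis=-1)  # (HW, L)
    return X + attn @ T                               # residual update per pixel

rng = np.random.default_rng(0)
X = rng.standard_normal((196, 64))
T = rng.standard_normal((16, 64))
W_Q = rng.standard_normal((64, 32)) * 0.1
W_K = rng.standard_normal((64, 32)) * 0.1
X_out = project_tokens(X, T, W_Q, W_K)
print(X_out.shape)
```

The output keeps the spatial layout of the input, so it can be passed to the next VT module or to a pixel-level prediction head.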
Using Visual Transformers in vision models
In this section, we discuss how to use VTs as building blocks in vision models. We define three hyper-parameters for each VT: channel size of the feature map; channel size of the visual tokens; and the number of visual tokens.
Image classification model: For image classification, following the convention of previous work, we build our networks with backbones inherited from ResNet. Based on ResNet-{18, 34, 50, 101}, we build corresponding visual-transformer-ResNets (VT-ResNets) by replacing the last stage of convolutions with VT modules. The last stage of ResNet-{18, 34, 50, 101} contains 2 basic blocks, 3 basic blocks, 3 bottleneck blocks, and 3 bottleneck blocks, respectively. We replace them with the same number (2, 3, 3, 3) of VT modules. At the end of stage-4 (before stage-5 max pooling), ResNet-{18, 34} generate feature maps with the shape of , and ResNet-{50, 101} generate feature maps with the shape of . We set VT’s feature map channel size to be 256, 256, 1024, 1024 for ResNet-{18, 34, 50, 101}. We adopt 16 visual tokens with a channel size of 1024 for all the models. At the end of the network, we output 16 visual tokens to the classification head, which applies average pooling over the tokens and uses a fully-connected layer to predict the probabilities. A table summarizing the stage-wise description of the model is provided in Appendix A. Since VTs only operate on 16 visual tokens, we can reduce the last stage’s FLOPs by up to 6.9x, as shown in Table 1.
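The classification head described above (average pooling over the tokens followed by a fully-connected layer) can be sketched in a few lines. The 1000-class output size corresponds to ImageNet; the random weights and the names `W_fc` and `b_fc` are placeholders for illustration.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over a 1-D vector of logits."""
    e = np.exp(z - z.max())
    return e / e.sum()

def classification_head(tokens, W_fc, b_fc):
    """tokens: (L, CT) visual tokens; W_fc: (CT, num_classes); b_fc: (num_classes,)."""
    pooled = tokens.mean(axis=0)          # average pooling over the L tokens
    return softmax(pooled @ W_fc + b_fc)  # class probabilities

rng = np.random.default_rng(0)
tokens = rng.standard_normal((16, 1024))       # 16 tokens with channel size 1024
W_fc = rng.standard_normal((1024, 1000)) * 0.01
b_fc = np.zeros(1000)
probs = classification_head(tokens, W_fc, b_fc)
print(probs.shape, probs.argmax())
```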
Semantic segmentation: We show that using VTs for semantic segmentation can tackle several challenges with the pixel-convolution paradigm. First, the computational complexity of convolution grows with the image resolution. Second, convolutions struggle to capture long-range interactions between pixels. VTs, on the other hand, operate on a small number of visual tokens regardless of the image resolution, and since they model concept interactions in token-space, they bypass the “long-range” challenge of pixel arrays.
To validate our hypothesis, we use panoptic feature pyramid networks (FPN) as a baseline and use VTs to improve the network. Panoptic FPNs use ResNet as a backbone to extract feature maps from different stages with various resolutions. These feature maps are then fused by a feature pyramid network in a top-down manner to generate a multi-scale, detail-preserving feature map with rich semantics for segmentation (Figure 4 left). FPN is computationally expensive since it relies heavily on spatial convolutions operating on high-resolution feature maps with large channel sizes. We use VTs to replace convolutions in FPN and name the new module VT-FPN (Figure 4 right). From each resolution’s feature map, VT-FPN extracts 8 visual tokens with a channel size of 1024. The visual tokens are combined and fed into one transformer to compute interactions between visual tokens across resolutions. The output tokens are then projected back to the original feature maps, which are then used to perform pixel-level prediction. Compared with the original FPN, the computational cost of VT-FPN is much smaller since we operate only on a very small number of visual tokens rather than on all the pixels. Our experiments show that VT-FPN uses 6.4x fewer FLOPs than FPN while preserving or surpassing its performance (Tables 9 and 10).
Experiments
We conduct experiments with VTs on image classification and semantic segmentation to (a) understand the key components of VTs and (b) validate their effectiveness.
We conduct experiments on the ImageNet dataset with around 1.3 million images in the training set and 50 thousand images in the validation set. We implement VT models in PyTorch . We use stochastic gradient descent (SGD) optimizer with Nesterov momentum . We use an initial learning rate of , a momentum of , and a weight decay of 4e-5. We train the model for 90 epochs, and decay the learning rate by 10x every 30 epochs. We use a batch size of 256 and 8 V100 GPUs for training.
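The step schedule described above (90 epochs, learning rate divided by 10 every 30 epochs) can be written as a small function. The initial learning rate is not preserved in this text, so `BASE_LR` below is a hypothetical placeholder, not the value used in the experiments.

```python
def step_lr(base_lr, epoch, step_size=30, gamma=0.1):
    """Learning rate at a given epoch under step decay (10x drop every 30 epochs)."""
    return base_lr * gamma ** (epoch // step_size)

BASE_LR = 0.1  # hypothetical value, for illustration only
schedule = [step_lr(BASE_LR, epoch) for epoch in range(90)]

# Print the rate at the first epoch of each of the three phases.
for epoch in (0, 30, 60):
    print(epoch, schedule[epoch])
```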
VT vs. ResNet with default training recipe: In Table 2, we first compare VT-ResNets and vanilla ResNets under the same training recipe. VT-ResNets in this experiment use a filter-based tokenizer for the first VT module and recurrent tokenizers in later modules. After replacing the last stage of convolutions in ResNet18 and ResNet34, VT-based ResNets use far fewer FLOPs: 244M fewer FLOPs for ResNet18 and 384M fewer for ResNet34. Meanwhile, VT-ResNets achieve higher top-1 validation accuracy than the corresponding ResNets by up to 2.2 points. This confirms the effectiveness of VTs. Also note that the training accuracy achieved by VT-ResNets is much higher than that of baseline ResNets: VT-R18 is 7.9 points higher and VT-R34 is 6.9 points higher. This indicates that VT-ResNets overfit more heavily than regular ResNets. We hypothesize this is because VT-ResNets have much larger capacity and need stronger regularization (e.g., data augmentation) to fully utilize that capacity. We address this in Section 5.2 and Table 8.
Tokenizer ablation studies: In Table 3, we compare different types of tokenizers used by VTs. We consider a pooling-based tokenizer, a clustering-based tokenizer, and a filter-based tokenizer (Section 3.1.1). We use the candidate tokenizer in the first VT module and use recurrent tokenizers in later modules. As a baseline, we implement a pooling-based tokenizer, which spatially downsamples a feature map to reduce its spatial dimensions, instead of grouping pixels by their semantics. As a more advanced baseline, we consider a clustering-based tokenizer, which is described in Appendix C. It applies K-Means clustering in the semantic space to group pixels into visual tokens. As can be seen from Table 3, filter-based and clustering-based tokenizers perform significantly better than the pooling-based baseline, validating our hypothesis that feature maps contain redundancies, and that this can be addressed by grouping pixels in semantic space. The difference between filter-based and clustering-based tokenizers is small and varies between R18 and R34. We hypothesize this is because both tokenizers have their own drawbacks. The filter-based tokenizer relies on fixed convolution filters to detect and assign pixels to semantic groups, and is limited by the capacity of the convolution filters to deal with diverse and sparse high-level semantic concepts. On the other hand, the clustering-based tokenizer extracts semantic concepts that exist in the image, but it is not designed to capture the essential semantic concepts.
In Table 4, we validate the recurrent tokenizer’s effectiveness. We use a filter-based tokenizer in the first VT module and use recurrent tokenizers in subsequent modules. Experiments show that using recurrent tokenizers leads to higher accuracy.
Modeling token relationships: In Table 5, we compare different methods of capturing token relationships. As a baseline, we do not compute the interactions between tokens. This leads to the worst performance among all variations, which validates the necessity of capturing the relationships between different semantic concepts. Another alternative is to use graph convolutions, but their performance is worse than that of VTs. This is likely because graph convolutions bind each visual token to a fixed semantic concept. In comparison, using transformers allows each visual token to encode any semantic concept as long as the concept appears in the image. This allows the model to fully utilize its capacity.
Token efficiency ablation: In Table 6, we test varying numbers of visual tokens, only to find negligible or no increase in accuracy. This agrees with our hypothesis that VTs are already capturing a wide variety of concepts with just a few handfuls of tokens; additional tokens are not needed, as the space of possible high-level concepts is already covered.
Projection ablation: In Table 7, we study whether we need to project visual tokens back to feature maps. We hypothesize that projecting the visual tokens is an important step since, in vision understanding, pixel-level semantics are very important, and visual tokens are representations in the semantic space that do not encode any spatial information. As validated by Table 7, projecting visual tokens back to the feature map leads to higher performance, even for image classification tasks.
5.2 Training with Advanced Recipe
In Table 2, we show that under the regular training recipe, the VT-ResNets experience serious overfitting. Despite their accuracy improvement on the validation set, their training accuracy improves by a significantly larger margin. We hypothesize that this is because VT-based models have much higher model capacity. To fully unleash the potential of VT, we use a more advanced training recipe to train VT models. To prevent overfitting, we train with more training epochs, stronger data augmentation, stronger regularization, and distillation. Specifically, we train VT-ResNet models for 400 epochs with RMSProp. We use an initial learning rate of 0.01 and increase it to 0.16 in 5 warmup epochs, then reduce the learning rate by a factor of 0.9875 per epoch. We use synchronized batch normalization and distributed training with a batch size of 2048. We use label smoothing and AutoAugment, and we set the stochastic depth survival probability and dropout ratio to 0.9 and 0.2, respectively. We use an exponential moving average (EMA) on the model weights with 0.99985 decay. We use knowledge distillation in the training recipe, where the teacher model is FBNetV3-G. The total loss is a weighted sum of distillation loss (0.8) and cross entropy loss (0.2).
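The numeric parts of this recipe can be summarized in a short sketch. The constants come from the paragraph above; the linear shape of the warmup, the point at which the per-epoch decay starts, and the EMA update form are assumptions made for illustration.

```python
def lr_at_epoch(epoch, init_lr=0.01, peak_lr=0.16, warmup_epochs=5, decay=0.9875):
    """Warm up from init_lr to peak_lr, then multiply by `decay` each epoch."""
    if epoch < warmup_epochs:
        # Linear warmup is an assumption; the text only gives the endpoints.
        return init_lr + (peak_lr - init_lr) * epoch / warmup_epochs
    return peak_lr * decay ** (epoch - warmup_epochs)

def ema_update(ema_weight, weight, decay=0.99985):
    """One exponential-moving-average update of a model weight."""
    return decay * ema_weight + (1.0 - decay) * weight

def total_loss(distill_loss, ce_loss, w_distill=0.8, w_ce=0.2):
    """Weighted sum of distillation loss and cross entropy loss."""
    return w_distill * distill_loss + w_ce * ce_loss

# Learning rate at the start, at the end of warmup, and at the last of 400 epochs.
for epoch in (0, 5, 399):
    print(epoch, lr_at_epoch(epoch))
print(total_loss(distill_loss=1.0, ce_loss=2.0))
```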
Our results are reported in Table 8. Compared with the baseline ResNet models, VT-ResNet models achieve 4.6 to 7 points higher accuracy and surpass all other related work that adopts attention of different forms based on ResNets. This validates that our advanced training recipe better utilizes the model capacity of VT-ResNet models to outperform all previous baselines.
Note that in addition to the architecture differences, previous works also used their own training recipes, and it is infeasible for us to try these recipes one by one. To understand the source of the accuracy gain, we use the same training recipe to train baseline ResNet18 and ResNet34 and also observe significant accuracy improvement on baseline ResNets. However, under the advanced training recipe, the accuracy gap between VT-ResNets and baselines increases from 1.7 and 2.2 points to 2.2 and 3.0 points. This further validates that a stronger training recipe can better utilize the model capacity of VTs. While achieving higher accuracy than previous work, VT-ResNets also use far fewer FLOPs and parameters, even though we only replace the last stage of the baseline model. If we consider the reduction over the original stage, we observe FLOP reductions of up to 6.9x, as shown in Table 1.
5.3 Visual Transformer for Semantic Segmentation
We conduct experiments to test the effectiveness of VT for semantic segmentation on the COCO-stuff dataset and the LIP dataset. The COCO-stuff dataset contains annotations for 91 stuff classes with 118K training images and 5K validation images. The LIP dataset is a collection of images containing humans with challenging poses and views. For the COCO-stuff dataset, we train a VT-FPN model with ResNet-{50, 101} backbones. Our implementation is based on Detectron2. Our training recipe is based on the semantic segmentation FPN recipe with 1x training steps, except that we use synchronized batch normalization in the VT-FPN, change the batch size to 32, and use a base learning rate of 0.04. For the LIP dataset, we also use synchronized batch normalization with a batch size of 96. We optimize the model via SGD with a weight decay of 0.0005 and a learning rate of 0.01.
As we can see in Tables 9 and 10, after replacing FPN with VT-FPN, both ResNet-50 and ResNet-101 based models achieve slightly higher mIoU, while VT-FPN requires 6.5x fewer FLOPs than an FPN module.
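For readers unfamiliar with the metric, mIoU (mean intersection-over-union) averages the per-class overlap between predicted and ground-truth label maps. The sketch below computes it for a single toy label map. It is a simplified illustration: benchmark evaluation code typically accumulates counts over the whole dataset, and the handling of absent classes here is an assumption.

```python
import numpy as np

def mean_iou(pred, target, num_classes):
    """Mean intersection-over-union between two integer label maps."""
    ious = []
    for c in range(num_classes):
        p, t = pred == c, target == c
        union = np.logical_or(p, t).sum()
        if union == 0:
            continue  # class absent from both maps: skipped (assumed convention)
        ious.append(np.logical_and(p, t).sum() / union)
    return float(np.mean(ious))

pred = np.array([[0, 0, 1],
                 [1, 1, 2]])
target = np.array([[0, 1, 1],
                   [1, 1, 2]])
miou = mean_iou(pred, target, num_classes=3)
print(miou)  # per-class IoUs are 0.5, 0.75 and 1.0, so the mean is 0.75
```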
5.4 Visualizing Visual Tokens
Conclusion
The convention in computer vision is to represent images as pixel arrays and to apply the de facto deep learning operator – the convolution. In lieu of this, we propose Visual Transformers (VTs), as hallmarks of a new computer vision paradigm, learning and relating sparsely-distributed, high-level concepts far more efficiently: Instead of pixel arrays, VTs represent just the high-level concepts in an image using visual tokens. Instead of convolutions, VTs apply transformers to directly relate semantic concepts in token-space. To evaluate this idea, we replace convolutional modules with VTs, obtaining significant accuracy improvements across tasks and datasets. Using an advanced training recipe, our VT improves ResNet accuracy on ImageNet by 4.6 to 7 points. For semantic segmentation on LIP and COCO-stuff, VT-based feature pyramid networks (FPN) achieve 0.35 points higher mIoU despite 6.5x fewer FLOPs than convolutional FPN modules. This paradigm can furthermore be compounded with other contemporaneous tricks beyond the scope of this paper, including extra training data and neural architecture search. However, instead of presenting a mosh pit of deep learning tricks, our goal is to show that the pixel-convolution paradigm is fraught with redundancies. To compensate, modern methods add exceptional amounts of computational complexity. However, as model designers and practitioners, we can tackle the root cause instead of exacerbating compute demands, addressing redundancy in the pixel-convolution convention by adopting the newfound token-transformer paradigm moving forward.
References
Appendix A Stage-wise model description of VT-ResNet
In this section, we provide a stage-wise description of the model configurations for VT-based ResNet (VT-ResNet). We use three hyper-parameters to control a VT module: channel size of the output feature map C, channel size of visual token CT, and the number of visual tokens L. The model configurations are described in Table 11.
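The three hyper-parameters can be captured in a small configuration object. The sketch below is hypothetical: the class name, field names, and the dictionary keys for the deeper models are illustrative, while the values follow the channel sizes and token counts stated in the image classification section.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class VTConfig:
    feature_channels: int  # C: channel size of the output feature map
    token_channels: int    # CT: channel size of each visual token
    num_tokens: int        # L: number of visual tokens

# Values taken from the text: C = 256 for ResNet-{18, 34}, C = 1024 for
# ResNet-{50, 101}; CT = 1024 and L = 16 for all models.
VT_RESNET_CONFIGS = {
    "VT-R18": VTConfig(feature_channels=256, token_channels=1024, num_tokens=16),
    "VT-R34": VTConfig(feature_channels=256, token_channels=1024, num_tokens=16),
    "VT-R50": VTConfig(feature_channels=1024, token_channels=1024, num_tokens=16),
    "VT-R101": VTConfig(feature_channels=1024, token_channels=1024, num_tokens=16),
}

for name, cfg in VT_RESNET_CONFIGS.items():
    print(name, cfg)
```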
Appendix B More visualization results
We provide more visualization of the spatial attention on images from the LIP dataset in Figure 7.
Appendix C Clustering-based tokenizer
Pseudo-code for our K-means implementation is provided in Listing 1, and can be summarized as: Normalize all pixels to unit vectors, initialize centroids with a spatially-downsampled feature map, and run Lloyd’s algorithm to produce centroids.
Although this tokenizer efficiently models only concepts in the current image, the drawback is that it is not designed to choose the most discriminative concepts.
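The three steps summarized in this appendix (normalize, initialize from a downsampled map, run Lloyd's algorithm) can be sketched in NumPy as follows. This is not the listing referred to above. Strided subsampling as the downsampling method, cosine similarity as the assignment criterion, re-normalizing centroids after each update, and the iteration count are all assumptions.

```python
import numpy as np

def kmeans_tokenizer(X, H, W, stride=4, iters=5, eps=1e-8):
    """Cluster pixels into visual tokens with K-means.

    X: (H*W, C) flattened feature map. Returns centroids of shape (L, C),
    where L is the number of positions kept by the strided subsampling.
    """
    # Step 1: normalize every pixel to a unit vector.
    X = X / (np.linalg.norm(X, axis=1, keepdims=True) + eps)
    # Step 2: initialize centroids from a spatially downsampled feature map.
    centroids = X.reshape(H, W, -1)[::stride, ::stride].reshape(-1, X.shape[1]).copy()
    # Step 3: Lloyd's algorithm, alternating assignment and centroid update.
    for _ in range(iters):
        assign = (X @ centroids.T).argmax(axis=1)  # nearest centroid by cosine similarity
        for k in range(centroids.shape[0]):
            members = X[assign == k]
            if len(members) > 0:                   # an empty cluster keeps its old centroid
                mean = members.mean(axis=0)
                centroids[k] = mean / (np.linalg.norm(mean) + eps)
    return centroids

rng = np.random.default_rng(0)
feature_map = rng.standard_normal((14 * 14, 32))  # hypothetical 14x14 map
tokens = kmeans_tokenizer(feature_map, H=14, W=14, stride=4)
print(tokens.shape)
```

With a 14 x 14 map and a stride of 4, the subsampling keeps a 4 x 4 grid, which gives 16 centroids and matches the token count used elsewhere in the paper.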