Vision Transformer with Deformable Attention
Zhuofan Xia, Xuran Pan, Shiji Song, Li Erran Li, Gao Huang
Introduction
The Transformer was originally introduced to solve natural language processing tasks. It has recently shown great potential in the field of computer vision. The pioneering work, Vision Transformer (ViT), stacks multiple Transformer blocks to process sequences of non-overlapping image patches (i.e., visual tokens), leading to a convolution-free model for image classification. Compared to their CNN counterparts, Transformer-based models have larger receptive fields and excel at modeling long-range dependencies, and they have been shown to achieve superior performance in the regime of large training data and model size. However, the superfluous attention in visual recognition is a double-edged sword with multiple drawbacks. Specifically, the excessive number of keys to attend to per query patch yields high computational cost and slow convergence, and increases the risk of overfitting.
In order to avoid excessive attention computation, existing works have leveraged carefully designed efficient attention patterns to reduce the computational complexity. As two representative approaches, Swin Transformer adopts window-based local attention to restrict attention to local windows, while Pyramid Vision Transformer (PVT) downsamples the key and value feature maps to save computation. Though effective, these hand-crafted attention patterns are data-agnostic and may not be optimal: relevant keys/values are likely to be dropped, while less important ones are still kept.
Ideally, one would expect that the candidate key/value set for a given query is flexible and has the ability to adapt to each individual input, such that the issues in hand-crafted sparse attention patterns can be alleviated. In fact, in the literature of CNNs, learning a deformable receptive field for the convolution filters has been shown to be effective in selectively attending to more informative regions on a data-dependent basis. The most notable work, Deformable Convolution Networks (DCN), has yielded impressive results on many challenging vision tasks. This motivates us to explore a deformable attention pattern in Vision Transformers. However, a naive implementation of this idea leads to an unreasonably high memory/computation complexity: the overhead introduced by the deformable offsets is quadratic w.r.t. the number of patches. As a consequence, although some recent works have investigated the idea of a deformable mechanism in Transformers, none of them has treated it as a basic building block for constructing a powerful backbone network like the DCN, due to the high computational cost. Instead, their deformable mechanism is either adopted in the detection head, or used as a pre-processing layer to sample patches for the subsequent backbone network.
In this paper, we present a simple and efficient deformable self-attention module, with which a powerful pyramid backbone, named Deformable Attention Transformer (DAT), is constructed for image classification and various dense prediction tasks. Different from DCN, which learns different offsets for different pixels in the whole feature map, we propose to learn a few groups of query-agnostic offsets to shift keys and values to important regions (as illustrated in Figure 1(d)), based on the observation that global attention usually results in almost the same attention patterns for different queries. This design both keeps a linear space complexity and introduces a deformable attention pattern to Transformer backbones. Specifically, for each attention module, reference points are first generated as uniform grids, which are the same across the input data. Then, an offset network takes the query features as input and generates the corresponding offsets for all the reference points. In this way, the candidate keys/values are shifted towards important regions, thus augmenting the original self-attention module with higher flexibility and efficiency to capture more informative features.
To summarize, our contributions are as follows: we propose the first deformable self-attention backbone for visual recognition, where the data-dependent attention pattern endows higher flexibility and efficiency. Extensive experiments on ImageNet, ADE20K and COCO demonstrate that our model consistently outperforms competitive baselines including Swin Transformer, by a margin of 0.7 on the top-1 accuracy of image classification, 1.2 on the mIoU of semantic segmentation, and 1.1 on object detection for both box AP and mask AP. The advantages on small and large objects are more distinct, with a margin of 2.1.
Related Work
Transformer vision backbone. Since the introduction of ViT, improvements have focused on learning multi-scale features for dense prediction tasks and on efficient attention mechanisms. These attention mechanisms include windowed attention, global tokens, focal attention and dynamic token sizes. More recently, convolution-based approaches have been introduced into Vision Transformer models. Among them, some studies focus on complementing Transformer models with convolution operations to introduce additional inductive biases. CvT adopts convolution in the tokenization process and utilizes strided convolution to reduce the computational complexity of self-attention. ViT with a convolutional stem adds convolutions at the early stage to achieve more stable training. CSwin Transformer adopts a convolution-based positional encoding technique and shows improvements on downstream tasks. Many of these convolution-based techniques can potentially be applied on top of DAT for further performance improvements.
Deformable CNN and attention. Deformable convolution is a powerful mechanism to attend to flexible spatial locations conditioned on input data. Recently it has been applied to Vision Transformers. Deformable DETR improves the convergence of DETR by selecting a small number of keys for each query on top of a CNN backbone. Its deformable attention is not suited to a visual backbone for feature extraction, as the lack of keys restricts representation power. Furthermore, the attention in Deformable DETR comes from simply learned linear projections, and keys are not shared among query tokens. DPT and PS-ViT build deformable modules to refine visual tokens. Specifically, DPT proposes a deformable patch embedding to refine patches across stages, and PS-ViT introduces a spatial sampling module before a ViT backbone to improve visual tokens. Neither of them incorporates deformable attention into vision backbones. In contrast, our deformable attention takes a powerful yet simple design to learn a set of global keys shared among visual tokens, and can be adopted as a general backbone for various vision tasks. Our method can also be viewed as a spatially adaptive mechanism, which has been proved effective in various works.
Deformable Attention Transformer
With normalization layers and identity shortcuts, the l-th Transformer block is formulated as
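One standard pre-norm form that is consistent with this description, and with the later references to Eq. (4) for the attention sub-layer and Eq. (5) for the MLP sub-layer, is sketched below. The symbols z_l (output of the l-th block), LN (layer normalization), MHSA and MLP, as well as the equation numbers, are notation assumed here rather than taken from the text above.

```latex
\begin{align}
z'_l &= \mathrm{MHSA}\left(\mathrm{LN}(z_{l-1})\right) + z_{l-1}, \tag{4} \\
z_l  &= \mathrm{MLP}\left(\mathrm{LN}(z'_l)\right) + z'_l. \tag{5}
\end{align}
```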
2 Deformable Attention
Existing hierarchical Vision Transformers, notably PVT and Swin Transformer, try to address the challenge of excessive attention. The downsampling technique of the former results in severe information loss, and the shift-window attention of the latter leads to a much slower growth of receptive fields, which limits the potential of modeling large objects. Thus, a data-dependent sparse attention is required to flexibly model relevant features, which leads to the deformable mechanism first proposed in DCN. However, simply implementing DCN in Transformer models is a non-trivial problem. In DCN, each element on the feature map learns its offsets individually, so the space complexity of a deformable convolution grows linearly with the size of the feature map. If we directly apply the same mechanism in the attention module, the space complexity rises drastically, scaling with the product N_q N_k, where N_q and N_k are the numbers of queries and keys, which usually have the same scale as the feature map size HW, bringing approximately a biquadratic complexity. Although Deformable DETR has managed to reduce this overhead by setting a lower number of keys at each scale and works well as a detection head, attending to so few keys is inferior in a backbone network because of the unacceptable loss of information (see the detailed comparison in the Appendix). Meanwhile, prior observations have revealed that different queries have similar attention maps in visual attention models. Therefore, we opt for a simpler solution with shifted keys and values shared across queries to achieve an efficient trade-off.
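To make the gap between per-query and shared deformed keys concrete, the following back-of-the-envelope sketch counts how many feature elements would have to be sampled in each case. The counting rule (queries × keys × channels) and the 14 × 14 × 384 feature map are illustrative assumptions, not figures taken from the text above.

```python
def per_query_sampled_elements(num_queries: int, num_keys: int, channels: int) -> int:
    """Elements sampled when every query owns its own set of deformed keys."""
    return num_queries * num_keys * channels


def shared_sampled_elements(num_keys: int, channels: int) -> int:
    """Elements sampled when one set of deformed keys is shared by all queries."""
    return num_keys * channels


# Assumed example: a 14 x 14 feature map with 384 channels, keys as numerous as queries.
height, width, channels = 14, 14, 384
num_tokens = height * width

per_query = per_query_sampled_elements(num_tokens, num_tokens, channels)
shared = shared_sampled_elements(num_tokens, channels)
ratio = per_query // shared  # equals the number of queries

print(per_query, shared, ratio)
```

Under these assumptions the per-query variant samples as many times more elements as there are queries, which is the quadratic growth in the number of patches that the shared design avoids.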
Specifically, we propose deformable attention to model the relations among tokens effectively under the guidance of the important regions in the feature maps. These focused regions are determined by multiple groups of deformed sampling points which are learned from the queries by an offset network. We adopt bilinear interpolation to sample features from the feature maps, and then the sampled features are fed to the key and value projections to get the deformed keys and values. Finally, standard multi-head attention is applied to attend queries to the sampled keys and aggregate features from the deformed values. Additionally, the locations of deformed points provide a more powerful relative position bias to facilitate the learning of the deformable attention, which will be discussed in the following sections.
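The bilinear sampling step mentioned above can be sketched in plain Python for a single-channel feature map. This is only an illustration of bilinear interpolation at a fractional location; the function name, the pixel-unit coordinate convention and the clamp-to-border behaviour are assumptions made for this sketch, not details of the authors' implementation.

```python
import math


def bilinear_sample(feature, y, x):
    """Sample a 2D feature map (a list of rows) at a fractional location (y, x).

    Coordinates are in pixel units. Locations outside the map are clamped
    to the border, which is an assumption made for this sketch.
    """
    h, w = len(feature), len(feature[0])
    y = min(max(y, 0.0), h - 1.0)
    x = min(max(x, 0.0), w - 1.0)
    y0, x0 = int(math.floor(y)), int(math.floor(x))
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    wy, wx = y - y0, x - x0
    # Interpolate along x on the two neighbouring rows, then along y.
    top = (1 - wx) * feature[y0][x0] + wx * feature[y0][x1]
    bottom = (1 - wx) * feature[y1][x0] + wx * feature[y1][x1]
    return (1 - wy) * top + wy * bottom


feature = [[0.0, 1.0],
           [2.0, 3.0]]
center = bilinear_sample(feature, 0.5, 0.5)  # average of the four corners
corner = bilinear_sample(feature, 0.0, 1.0)  # exactly the top-right value
```

In a full module, the values sampled this way at every deformed point would be passed through the key and value projections before standard multi-head attention is applied.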
Offset generation. As stated above, a sub-network is adopted for offset generation, which consumes the query features and outputs the offset values for the reference points. Considering that each reference point covers a local s × s region (s is the largest value for the offset), the generation network should also perceive local features in order to learn reasonable offsets. Therefore, we implement the sub-network as two convolution modules with a nonlinear activation, as depicted in Figure 2(b). The input features are first passed through a depthwise convolution to capture local features. Then, a GELU activation and a convolution are applied to produce the 2D offsets. It is also worth noting that the bias in the convolution is dropped to avoid a compulsory shift at all locations.
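A minimal sketch of how raw offsets might be bounded and applied to a uniform grid of reference points is given below. The use of tanh scaled by a maximum offset `s`, the pixel-unit grid, and all function names are assumptions for illustration; the text above only states that the offset has a largest value.

```python
import math


def reference_points(grid_h, grid_w):
    """Uniform grid of (y, x) reference points, in units of grid cells."""
    return [(float(i), float(j)) for i in range(grid_h) for j in range(grid_w)]


def bounded_offset(raw, s):
    """Squash an unbounded raw offset into the range (-s, s) (assumed tanh bound)."""
    return s * math.tanh(raw)


def deform(points, raw_offsets, s):
    """Shift each reference point by its bounded 2D offset."""
    return [
        (py + bounded_offset(dy, s), px + bounded_offset(dx, s))
        for (py, px), (dy, dx) in zip(points, raw_offsets)
    ]


refs = reference_points(2, 2)  # [(0,0), (0,1), (1,0), (1,1)]
raw = [(0.0, 0.0), (100.0, -100.0), (0.5, 0.5), (-0.5, 0.0)]  # stand-in network outputs
deformed = deform(refs, raw, s=2.0)
```

Bounding the offsets keeps every deformed point within `s` of its reference point regardless of how large the raw network output is, which matches the idea that each reference point is responsible for a local region.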
Offset groups. To promote the diversity of the deformed points, we follow a paradigm similar to MHSA and split the feature channels into G groups. Features from each group use the shared sub-network to generate the corresponding offsets. In practice, the number of attention heads M is set to a multiple of the number of offset groups G, ensuring that multiple attention heads are assigned to one group of deformed keys and values.
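The relationship between attention heads and offset groups can be sketched as a simple index mapping. The contiguous assignment of heads to groups and the example sizes (12 heads, 3 groups) are hypothetical choices for illustration.

```python
def head_to_group(num_heads, num_groups):
    """Map each attention head index to the offset group whose keys/values it uses."""
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be a multiple of num_groups")
    heads_per_group = num_heads // num_groups
    return [head // heads_per_group for head in range(num_heads)]


mapping = head_to_group(num_heads=12, num_groups=3)
print(mapping)
```

With these example sizes, four consecutive heads share each group of deformed keys and values.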
Computational complexity. Deformable multi-head attention (DMHA) has a computational cost similar to that of its counterparts in PVT or Swin Transformer. The only additional overhead comes from the sub-network that is used to generate offsets. The complexity of the whole module can be summarized as:
where N_s is the number of sampled points. It can be immediately seen that the computational cost of the offset network has linear complexity w.r.t. the channel size, which is comparatively minor relative to the cost of the attention computation. As a typical example, consider the third stage of a Swin-T model for image classification, where H = W = 14, N_s = 49 and C = 384: the computational cost for the attention module in a single block is 79.63M FLOPs. If equipped with our deformable module, the additional overhead is 5.08M FLOPs, which is only about 6.0% of the whole module. Additionally, by choosing a large downsample factor, the complexity will be further reduced, which makes it friendly to tasks with much higher resolution inputs such as object detection and instance segmentation.
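The attention figure quoted above can be checked with a short calculation. The sketch below assumes that one multiply-accumulate counts as one FLOP, that the attention cost splits into the three terms named in the comments, and that H = W = 14, N_s = 49 and C = 384. These are assumptions that happen to reproduce the 79.63M figure, not a formula taken from the text. The 5.08M overhead is used as reported and is not re-derived.

```python
def attention_flops(height, width, num_sampled, channels):
    """Multiply-accumulate count of attention over num_sampled keys (assumed breakdown)."""
    num_queries = height * width
    # Query-key dot products plus attention-weighted sum of values.
    qk_and_av = 2 * num_queries * num_sampled * channels
    # Query projection and output projection over all query tokens.
    q_and_out_proj = 2 * num_queries * channels * channels
    # Key and value projections over the sampled points only.
    kv_proj = 2 * num_sampled * channels * channels
    return qk_and_av + q_and_out_proj + kv_proj


base = attention_flops(height=14, width=14, num_sampled=49, channels=384)
base_mflops = base / 1e6

reported_overhead_mflops = 5.08  # offset network cost as reported in the text
overhead_share = reported_overhead_mflops / (base_mflops + reported_overhead_mflops)

print(f"attention: {base_mflops:.2f}M, offset share: {overhead_share:.1%}")
```

Because the key and value terms depend on the number of sampled points rather than on the full feature map, reducing the number of sampled points lowers both the dot-product and the projection costs.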
3 Model Architectures
We replace the vanilla MHSA with our deformable attention in the Transformer block (Eq. (4)), and combine it with an MLP (Eq. (5)) to build a deformable vision transformer block. In terms of the network architecture, our model, Deformable Attention Transformer, shares a similar pyramid structure with existing hierarchical Vision Transformers such as PVT and Swin Transformer, which is broadly applicable to various vision tasks requiring multiscale feature maps. As illustrated in Figure 3, an input image with shape H × W × 3 is first embedded by a 4 × 4 non-overlapped convolution with stride 4, followed by a normalization layer to get the patch embeddings. Aiming to build a hierarchical feature pyramid, the backbone includes 4 stages with a progressively increasing stride. Between two consecutive stages, there is a non-overlapped 2 × 2 convolution with stride 2 that downsamples the feature map, halving the spatial size and doubling the feature dimensions. In the classification task, we first normalize the feature maps output from the last stage and then adopt a linear classifier with pooled features to predict the logits. In object detection, instance segmentation and semantic segmentation tasks, DAT plays the role of a backbone in an integrated vision model to extract multiscale features. We add a normalization layer to the features from each stage before feeding them into the following modules, such as FPN in object detection or decoders in semantic segmentation.
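The pyramid described above can be traced with a small shape calculation: a stride-4 patch embedding followed by stride-2 downsampling between stages. The 224 × 224 input size and the base width of 96 channels are assumed example values, not values given in the text.

```python
def stage_shapes(height, width, base_dim, num_stages=4):
    """Return (H_i, W_i, C_i) for each stage of a stride-4, then stride-2, pyramid."""
    shapes = []
    stride, dim = 4, base_dim
    for _ in range(num_stages):
        shapes.append((height // stride, width // stride, dim))
        stride *= 2  # spatial size halves between consecutive stages
        dim *= 2     # feature dimension doubles between consecutive stages
    return shapes


shapes = stage_shapes(height=224, width=224, base_dim=96)
for index, shape in enumerate(shapes, start=1):
    print(f"stage {index}: {shape}")
```

Under these assumed values, the third stage works on a 14 × 14 map with 384 channels.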
We introduce successive local attention and deformable attention blocks in the third and the fourth stages of DAT. The feature maps are first processed by a window-based local attention to aggregate information locally, and then passed through the deformable attention block to model the global relations between the locally augmented tokens. This alternating design of attention blocks with local and global receptive fields helps the model learn strong representations, sharing a similar pattern with GLiT, TNT and Pointformer. Since the first two stages mainly learn local features, deformable attention in these early stages is less preferred. In addition, the keys and values in the first two stages have a rather large spatial size, which greatly increases the computational overhead of the dot products and bilinear interpolations in deformable attention. Therefore, to achieve a trade-off between model capacity and computational burden, we only place deformable attention in the third and the fourth stages and adopt the shift-window attention of Swin Transformer to obtain a better representation in the early stages. We build three variants of DAT with different parameters and FLOPs for a fair comparison with other Vision Transformer models. We change the model size by stacking more blocks in the third stage and increasing the hidden dimensions. The detailed architectures are reported in Table 1. Note that there are other design choices for the first two stages of DAT, e.g., the SRA module in PVT. We show the comparison results in Table 7.
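The placement of block types across stages can be sketched as a small layout generator. The per-stage depths `[2, 2, 6, 2]`, the block labels, and the window/shift-window alternation in the early stages (following the usual Swin Transformer pattern) are hypothetical choices for illustration. The actual configurations are those reported in Table 1.

```python
def stage_blocks(depths, deformable_stages=(2, 3)):
    """List the attention type of every block, stage by stage (0-indexed stages)."""
    layout = []
    for stage, depth in enumerate(depths):
        # Later stages pair local attention with deformable attention;
        # early stages pair it with shift-window attention instead.
        second = "deformable" if stage in deformable_stages else "shift-window"
        layout.append(["local" if i % 2 == 0 else second for i in range(depth)])
    return layout


layout = stage_blocks([2, 2, 6, 2])
for index, blocks in enumerate(layout, start=1):
    print(f"stage {index}: {blocks}")
```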
Experiments
We conduct experiments on several datasets to verify the effectiveness of our proposed DAT. We show our results on ImageNet-1K classification, COCO object detection and ADE20K semantic segmentation tasks. In addition, we provide ablation studies and visualizations to further show the effectiveness of our method.
The ImageNet-1K dataset has 1.28M images for training and 50K images for validation. We train three variants of DAT on the training split and report the Top-1 accuracy on the validation split to compare with other Vision Transformer models.
We use the AdamW optimizer to train our models for 300 epochs with a cosine learning rate decay. The basic learning rate is defined for a batch size of 1024 and scaled linearly w.r.t. the batch size. To stabilize training, we schedule a linear warm-up of the learning rate up to the basic learning rate, and for better convergence the cosine decay rule is applied to gradually decrease the learning rate during training. We follow DeiT in setting the advanced data augmentation, including RandAugment, Mixup and CutMix, to avoid overfitting. In addition, stochastic depth and a weight decay of 0.05 are applied, where the stochastic depth degree is set to 0.2, 0.3 and 0.5 for the tiny, small and base models, respectively. We do not adopt EMA, random erasing or vanilla dropout, which do not improve the training of Vision Transformers. For larger-resolution finetuning, we finetune our DAT-B using the AdamW optimizer with a cosine-scheduled learning rate for 30 epochs. We set the stochastic depth rate to 0.5 and lower the weight decay to keep the regularization.
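The schedule described above (linear scaling with batch size, linear warm-up, then cosine decay) can be sketched as follows. All numeric values in the example, including the learning rates and step counts, are placeholders chosen for illustration and are not the settings used in the experiments.

```python
import math


def scaled_base_lr(lr_at_1024, batch_size):
    """Scale the basic learning rate linearly with the batch size."""
    return lr_at_1024 * batch_size / 1024


def lr_at(step, total_steps, warmup_steps, warmup_start_lr, base_lr, min_lr):
    """Learning rate at a given step: linear warm-up followed by cosine decay."""
    if step < warmup_steps:
        frac = step / warmup_steps
        return warmup_start_lr + frac * (base_lr - warmup_start_lr)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


# Placeholder hyper-parameters (assumptions, not the paper's values).
base_lr = scaled_base_lr(lr_at_1024=1e-3, batch_size=2048)
schedule = dict(total_steps=1000, warmup_steps=100,
                warmup_start_lr=1e-6, base_lr=base_lr, min_lr=1e-7)

start = lr_at(0, **schedule)      # equals warmup_start_lr
peak = lr_at(100, **schedule)     # equals base_lr at the end of warm-up
final = lr_at(1000, **schedule)   # equals min_lr at the last step
```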
We report our results in Table 2, with 300 training epochs. Compared with other state-of-the-art Vision Transformers, our DATs achieve significant improvements in Top-1 accuracy at similar computational complexity. DAT outperforms Swin Transformer, PVT, DPT and DeiT at all three scales. Without inserting convolutions in Transformer blocks or using overlapped convolutions in patch embeddings, DATs achieve gains of +0.7, +0.7 and +0.5 over their Swin Transformer counterparts. When finetuned at a larger resolution, our model continues to perform better than Swin Transformer, by 0.3.
2 COCO Object Detection
The COCO object detection and instance segmentation dataset has 118K training images and 5K validation images. We use our DAT as the backbone in the RetinaNet, Mask R-CNN and Cascade Mask R-CNN frameworks to evaluate the effectiveness of our method. We pretrain our models on the ImageNet-1K dataset for 300 epochs and follow training strategies similar to those of Swin Transformer to compare our methods fairly.
We report the results of DAT with RetinaNet under the 1x and 3x training schedules. As shown in Table 3, DAT outperforms Swin Transformer by 1.1 and 1.2 mAP for the tiny and small models. When implemented in two-stage detectors, e.g., Mask R-CNN and Cascade Mask R-CNN, our model achieves consistent improvements over Swin Transformer models of different sizes, as shown in Table 4. We can see that DAT achieves the largest improvements on large objects (up to +2.1) due to its flexibility in modeling long-range dependencies. The gaps for small-object detection and instance segmentation are also pronounced (up to +2.1), which shows that DATs also have the capacity to model relations in local regions.
3 ADE20K Semantic Segmentation
ADE20K is a popular dataset for semantic segmentation with 20K training images and 2K validation images. We employ our DAT on two widely adopted segmentation models, SemanticFPN and UperNet. To make a fair comparison with PVT and Swin Transformer, we follow their learning rate schedules and training epochs, except for the degree of stochastic depth, which is a key hyper-parameter affecting the final performance. We set it to 0.3, 0.3 and 0.5 for the tiny, small and base variants of our DAT, respectively, for both models. With the models pretrained on ImageNet-1K, we train SemanticFPN for 40k steps and UperNet for 160k steps. In Table 5, we report the results on the validation set with the highest mIoU score for all methods. In comparison with PVT, our tiny model outperforms PVT-S by +0.5 mIoU even with fewer FLOPs, and achieves sharp boosts of +3.1 and +2.5 mIoU with a slightly larger model size. Our DAT shows a significant improvement over Swin Transformer at each of the three model scales, with +1.0, +0.7 and +1.2 mIoU respectively, showing our method's effectiveness.
4 Ablation Study
In this section, we ablate the key components in our DAT to verify the effectiveness of these designs. We report the results on ImageNet-1K classification based on DAT-T.
Geometric information exploitation. We first evaluate the effectiveness of our proposed deformable offsets and deformable relative position embeddings, as shown in Table 6. Either adopting offsets in feature sampling or using the deformable relative position embedding provides a +0.3 improvement. We also try other types of position embeddings, including a fixed learnable position bias and a depth-wise convolution. However, none of them is effective, with only a +0.1 gain over the variant without position embedding, which shows that our deformable relative position bias is more compatible with deformable attention. Rows 6 and 7 in Table 6 also show that our model can adapt to different attention modules in the first two stages and achieve competitive results. Our model with SRA in the first two stages outperforms PVT-M by 0.5 with 65% of the FLOPs.
Deformable attention at different stages. We replace the shift-window attention of Swin Transformer with our deformable attention at different stages. As shown in Table 7, replacing the attention only in the last stage improves accuracy by 0.1, and replacing it in the last two stages leads to a performance gain of 0.7 (achieving an overall accuracy of 82.0). However, replacing more of the attention at the early stages with deformable attention slightly decreases the accuracy.
Ablation on different s. We further study the impact of different maximum offsets, i.e., the offset range scale factor s in the paper. We conduct an ablation experiment with s ranging from 0 to 16, where 14 corresponds to the largest reasonable offset given the size of the feature map (14 × 14 at stage 3). As shown in Figure 4, the wide selection range of s shows the robustness of DAT to this hyper-parameter. In practice, we choose a small s for all models in the paper without additional tuning.
5 Visualization
To verify the effectiveness of deformable attention, we use a mechanism similar to that of DCNs to visualize the most important keys across multiple deformable attention layers by propagating their attention weights. As shown in Figure 5, our deformable attention learns to place the keys mostly in the foreground, indicating that it focuses on the important regions of the objects, which supports our hypothesis shown in Figure 1 of the paper. More visualizations can be found in the appendix (Figures 6 and 7).
Conclusion
This paper presents Deformable Attention Transformer, a novel hierarchical Vision Transformer that can be adapted to both image classification and dense prediction tasks. With the deformable attention module, our model is capable of learning sparse attention patterns in a data-dependent way and modeling geometric transformations. Extensive experiments demonstrate the effectiveness of our model over competitive baselines. We hope our work can inspire insights towards designing flexible attention techniques.
Acknowledgments
This work is supported in part by the National Science and Technology Major Project of the Ministry of Science and Technology of China under Grant 2018AAA0100701, and the National Natural Science Foundation of China under Grants 61906106 and 62022048. The computational resources supporting this work are provided by Hangzhou High-Flyer AI Fundamental Research Co., Ltd.
Appendix
A. DAT and Deformable DETR
In this section, we provide a detailed comparison between our proposed deformable attention and the direct adaptation of deformable convolution, which is also known as the multiscale deformable attention in Deformable DETR.
First, our deformable attention serves as a feature extractor in vision backbones, while the one in Deformable DETR, which replaces the vanilla attention in DETR with a linear deformable attention, plays the role of the detection head. Second, the m-th head of query q in the single-scale attention of Deformable DETR is formulated as
Third, the deformable attention in Deformable DETR is not compatible with dot-product attention because of the enormous memory consumption mentioned in Sec. 3.2 of the paper. Therefore, linearly predicted attention is used to avoid computing dot products, and a smaller number of keys is also adopted to reduce the memory cost.
To experimentally validate our claim, we replace the deformable attention modules in DAT with the modules from Deformable DETR to verify that the naive adaptation is inferior for a vision backbone. The comparison results are shown in Table 8. Comparing the first and last rows, we can see that under a smaller memory budget, the number of keys for the Deformable DETR model is set to 16 to reduce memory usage, and it achieves lower performance. Comparing the third and last rows, we can see that the D-DETR attention with the same number of keys as DAT consumes 2.6× the memory and 1.3× the FLOPs, while its performance is still lower than that of DAT.
B. Adding Convolutions to DAT
Recent works have shown that adopting convolution layers in the Vision Transformer architecture can further improve model performance. For example, using a convolutional patch embedding can generally boost model performance on ImageNet classification tasks. It is worth noting that our proposed DAT can readily be combined with these techniques, while we maintain the convolution-free architecture in the main paper to allow a fair comparison with baselines.
To fully explore the capacity of DAT, we substitute the patch embedding layers in the original model with strided and overlapped convolutions. The comparison results are shown in Table 9, where the baseline models have similar modifications. It is shown that our model with additional convolution modules achieves an improvement compared to the original version and consistently outperforms the other baselines.
C. More Visualizations
We visualize examples of learned deformed locations in our DAT to verify the effectiveness of our method. As illustrated in Figure 6, the sampling points are depicted on top of the object detection boxes and instance segmentation masks, from which we can see that the points are shifted to the target objects. In the left column, the deformed points are contracted to the two target giraffes, while the other points keep a nearly uniform grid with small offsets. In the middle column, the deformed points are distributed densely over the person's body and the surfboard in both stages. The right column shows that the deformed points focus well on each of the six donuts, which shows that our model is able to better model geometric shapes even with multiple targets. The above visualizations demonstrate that DAT learns meaningful offsets to sample better keys for attention, improving performance on various vision tasks.
We also provide visualization results of the attention map for specific query tokens, and compare them with Swin Transformer in Figure 7. We show the key tokens with the highest attention values. It can be observed that our model focuses on the more relevant parts. As a showcase, our model allocates most attention to foreground objects, e.g., both giraffes in the first row. On the other hand, the regions of interest in Swin Transformer are comparatively local and fail to distinguish foreground from background, as depicted in the last surfboard example.