Transformer in Transformer

Kai Han, An Xiao, Enhua Wu, Jianyuan Guo, Chunjing Xu, Yunhe Wang

Introduction

In the past decade, the mainstream deep neural architectures used in computer vision (CV) have been built mainly on convolutional neural networks (CNNs). In contrast, the transformer is a type of neural network based mainly on the self-attention mechanism, which models the relationships between different features. Transformers are widely used in natural language processing (NLP), e.g., in the well-known BERT and GPT-3 models. The power of these transformer models has inspired the whole community to investigate the use of transformers for visual tasks.

To apply transformer architectures to visual tasks, a number of researchers have explored ways of representing sequence information from different kinds of data. For example, Wang et al. explore the self-attention mechanism in non-local networks for capturing long-range dependencies in video and image recognition. Carion et al. present DETR, which treats object detection as a direct set prediction problem and solves it using a transformer encoder-decoder architecture. Chen et al. propose iGPT, the pioneering work that applies a pure transformer model (i.e., without convolution) to image recognition through self-supervised pre-training.

Unlike the data in NLP tasks, there is a semantic gap between input images and the ground-truth labels in CV tasks. To this end, Dosovitskiy et al. develop ViT, which paves the way for transferring the success of transformer-based NLP models. Concretely, ViT divides the given image into several local patches that form a visual sequence. Attention can then be calculated naturally between any two image patches to generate effective feature representations for the recognition task. Subsequently, Touvron et al. explore data-efficient training and distillation to enhance the performance of ViT on the ImageNet benchmark and obtain about 81.8% top-1 accuracy on ImageNet, which is comparable to that of state-of-the-art convolutional networks. Chen et al. further treat image processing tasks (e.g., denoising and super-resolution) as a series of translations and develop the IPT model for handling multiple low-level computer vision problems. Nowadays, transformer architectures are used in a growing number of computer vision tasks such as image recognition, object detection, and segmentation.

Although the aforementioned visual transformers have made great efforts to boost model performance, most existing works follow the conventional representation scheme used in ViT, i.e., dividing the input images into patches. This elegant paradigm can effectively capture visual sequential information and estimate the attention between different image patches. However, the diversity of natural images in modern benchmarks is very high, e.g., there are over 1.2M images with 1000 different categories in the ImageNet dataset. As shown in Figure 1, representing the given image as local patches can help us find the relationships and similarities between them. However, there are also sub-patches inside these patches with high similarity. Therefore, we are motivated to explore a finer-grained way of dividing images for generating visual sequences and improving performance.

In this paper, we propose a novel Transformer-iN-Transformer (TNT) architecture for visual recognition, as shown in Figure 1. To enhance the feature representation ability of visual transformers, we first divide the input images into several patches as “visual sentences” and then further divide them into sub-patches as “visual words”. Besides the conventional transformer blocks for extracting features and attention of visual sentences, we further embed a sub-transformer into the architecture for excavating the features and details of smaller visual words. Specifically, features and attention between visual words in each visual sentence are calculated independently using a shared network, so that the increase in parameters and FLOPs (floating-point operations) is negligible. Then, the features of words are aggregated into the corresponding visual sentence. The class token is used for the subsequent visual recognition task via a fully-connected head. With the proposed TNT model, we can extract visual information at a finer granularity and provide features with more details. We then conduct a series of experiments on the ImageNet benchmark and downstream tasks to demonstrate its superiority and thoroughly analyze the impact of the size of visual words. The results show that TNT achieves a better accuracy-FLOPs trade-off than state-of-the-art transformer networks.

Approach

In this section, we describe the proposed transformer-in-transformer architecture and analyze its computation and parameter complexity in detail.

We first briefly describe the basic components of the transformer, including MSA (Multi-head Self-Attention), MLP (Multi-Layer Perceptron) and LN (Layer Normalization).

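MSA.

In the self-attention module, the inputs $X\in\mathbb{R}^{n\times d}$ are linearly transformed into queries $Q\in\mathbb{R}^{n\times d_{k}}$, keys $K\in\mathbb{R}^{n\times d_{k}}$ and values $V\in\mathbb{R}^{n\times d_{v}}$, where $n$ is the sequence length and $d$, $d_{k}$ and $d_{v}$ are the dimensions of the inputs, queries (keys) and values, respectively. The scaled dot-product attention is then applied to $Q$, $K$ and $V$:

$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V.$$

Finally, a linear layer is used to produce the output. Multi-head self-attention splits the queries, keys and values into $h$ parts, performs the attention function on each part in parallel, and then concatenates the outputs of all heads and linearly projects them to form the final output.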
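The attention computation and the head splitting can be illustrated with a minimal pure-Python sketch. It assumes that the query, key and value projections have already been applied (Q, K and V are given as lists of row vectors) and it leaves out the final linear projection; all function names are hypothetical.

```python
import math


def softmax(row):
    # Subtract the row maximum before exponentiating for numerical stability.
    top = max(row)
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for lists of row vectors."""
    d_k = len(q[0])
    out = []
    for qi in q:
        # One score per key: scaled dot product with this query.
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k]
        weights = softmax(scores)
        # Output row = attention-weighted sum of the value rows.
        out.append([sum(w * vj[col] for w, vj in zip(weights, v))
                    for col in range(len(v[0]))])
    return out


def split_heads(x, h):
    """Cut every row (width h * w) into h consecutive chunks of width w."""
    w = len(x[0]) // h
    return [[row[i * w:(i + 1) * w] for row in x] for i in range(h)]


def multi_head_attention(q, k, v, h):
    """Attention on each of the h heads, then per-token concatenation.

    The final linear projection is left out of this sketch.
    """
    heads = [attention(qh, kh, vh)
             for qh, kh, vh in zip(split_heads(q, h), split_heads(k, h), split_heads(v, h))]
    return [sum((head[i] for head in heads), []) for i in range(len(q))]
```

Because each head only sees its own slice of the features, the per-head computations are independent of each other, which is what allows them to run in parallel.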

MLP.

The MLP is applied between self-attention layers for feature transformation and non-linearity:

$$\mathrm{MLP}(X)=\mathrm{FC}(\sigma(\mathrm{FC}(X))),\qquad \mathrm{FC}(X)=XW+b,$$

where $W$ and $b$ are the weight and bias terms of the fully-connected layer, respectively, and $\sigma(\cdot)$ is the activation function, such as GELU.
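A minimal sketch of this MLP applied to a single token, using the exact (erf-based) form of GELU. The row-vector convention FC(x) = xW + b and the helper names are assumptions of the sketch; in a transformer the hidden width would be r times the input width.

```python
import math


def gelu(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def linear(x, w, b):
    """FC(x) = x W + b for one row vector x; W has shape len(x) x len(b)."""
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) + b[j] for j in range(len(b))]


def mlp(x, w1, b1, w2, b2):
    """MLP(x) = FC(sigma(FC(x))) with sigma = GELU."""
    hidden = [gelu(value) for value in linear(x, w1, b1)]
    return linear(hidden, w2, b2)
```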

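LN.

Layer normalization is a key component of the transformer for stable training and faster convergence. LN is applied to each sample $x\in\mathbb{R}^{d}$ as

$$\mathrm{LN}(x)=\frac{x-\mu}{\delta}\circ\gamma+\beta,$$

where $\mu\in\mathbb{R}$ and $\delta\in\mathbb{R}$ are the mean and standard deviation of the feature, $\circ$ denotes element-wise multiplication, and $\gamma\in\mathbb{R}^{d}$ and $\beta\in\mathbb{R}^{d}$ are learnable affine transform parameters.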
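Layer normalization of one token's feature vector can be sketched as follows: normalize by the vector's own mean and standard deviation, then apply the learnable scale gamma and shift beta. The small eps added to the variance for numerical stability is an assumption of this sketch.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-6):
    """Normalize one feature vector x (length d), then scale and shift it."""
    d = len(x)
    mu = sum(x) / d
    var = sum((xi - mu) ** 2 for xi in x) / d
    std = math.sqrt(var + eps)
    return [(xi - mu) / std * g + b for xi, g, b in zip(x, gamma, beta)]
```

Unlike batch normalization, the statistics come from a single token, so the result does not depend on the batch or on the other tokens in the sequence.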

2 Transformer in Transformer

In TNT, there are two data flows: one operates across the visual sentences, and the other processes the visual words inside each sentence. For the word embeddings, we use a transformer block to explore the relations between visual words:

$$Y'^{i}_{l}=Y^{i}_{l-1}+\mathrm{MSA}(\mathrm{LN}(Y^{i}_{l-1})),$$
$$Y^{i}_{l}=Y'^{i}_{l}+\mathrm{MLP}(\mathrm{LN}(Y'^{i}_{l})),$$

where $l=1,2,\cdots,L$ is the index of the $l$-th block and $L$ is the total number of stacked blocks. The input of the first block, $Y_{0}^{i}$, is just $Y^{i}$ in Eq. 5. All word embeddings in the image after transformation are $\mathcal{Y}_{l}=[Y_{l}^{1},Y_{l}^{2},\cdots,Y_{l}^{n}]$. This can be viewed as an inner transformer block, denoted as $T_{in}$. This process builds the relationships among visual words by computing interactions between any two visual words. For example, in a patch of a human face, a word corresponding to an eye is more related to other words of the eyes and interacts less with the forehead.
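A data-flow sketch of the inner block, assuming the pre-norm residual form Y' = Y + MSA(LN(Y)), Y = Y' + MLP(LN(Y')). The sublayers are passed in as plain callables (hypothetical stand-ins for trained layers), so the sketch only shows that one block with shared weights runs on every visual sentence separately, and words never attend across sentences.

```python
def residual_block(tokens, msa, mlp, ln1, ln2):
    """Pre-norm transformer block on a list of token vectors.

    msa maps a list of vectors to a list of vectors; mlp, ln1, ln2 act per token.
    """
    attn = msa([ln1(t) for t in tokens])
    mid = [[a + b for a, b in zip(t, u)] for t, u in zip(tokens, attn)]
    return [[a + b for a, b in zip(t, mlp(ln2(t)))] for t in mid]


def inner_transformer(word_embeddings, msa, mlp, ln1, ln2):
    """T_in sketch: the same block (shared weights) on each sentence independently.

    word_embeddings: n sentences, each a list of m word vectors of size c.
    """
    return [residual_block(sentence, msa, mlp, ln1, ln2) for sentence in word_embeddings]
```

For example, with an "attention" that simply averages the tokens of its input, the words of one sentence are mixed only with each other; the second sentence is unaffected by the first.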

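The outer transformer block $T_{out}$ is used for modeling relationships among sentence embeddings. Before it is applied, the word embeddings of each visual sentence are flattened and projected by a linear layer to the dimension of the sentence embeddings, and the result is added to the corresponding sentence embedding.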

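In summary, the inputs and outputs of the TNT block include the visual word embeddings and the sentence embeddings, as shown in Fig. 1(b). Denoting the sentence embeddings after the $l$-th block as $\mathcal{Z}_{l}$, the TNT block can be formulated as

$$\mathcal{Y}_{l},\mathcal{Z}_{l}=\mathrm{TNT}(\mathcal{Y}_{l-1},\mathcal{Z}_{l-1}).$$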

In our TNT block, the inner transformer block models the relationships between visual words for local feature extraction, and the outer transformer block captures the intrinsic information from the sequence of sentences. By stacking $L$ TNT blocks, we build the transformer-in-transformer network. Finally, the classification token serves as the image representation, and a fully-connected layer is applied for classification.
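The two data flows of a TNT block can be sketched with plain callables standing in for trained layers (all names hypothetical): the inner block runs on each sentence's words, each sentence's words are flattened and linearly projected to the sentence dimension and added to that sentence, and the outer block then runs on all sentence embeddings. Projecting the inner block's output within the same block, and leaving the class token without word features, are assumptions of this sketch.

```python
def tnt_block(words, sentences, inner, project, outer):
    """One TNT block as a data-flow sketch.

    words: n sentences, each a list of m word vectors of size c.
    sentences: n + 1 vectors of size d; index 0 is the class token.
    inner: T_in on one sentence's words (same weights for every sentence).
    project: linear layer from the flattened m * c word features to size d.
    outer: T_out on the full list of n + 1 sentence vectors.
    """
    new_words = [inner(w) for w in words]
    fused = [sentences[0]]  # assumption: the class token gets no word features
    for i, w in enumerate(new_words):
        flat = [value for word in w for value in word]  # flatten m words -> m * c
        fused.append([a + b for a, b in zip(sentences[i + 1], project(flat))])
    return new_words, outer(fused)


def tnt_network(words, sentences, blocks, head):
    """Stack TNT blocks, then classify from the class token."""
    for inner, project, outer in blocks:
        words, sentences = tnt_block(words, sentences, inner, project, outer)
    return head(sentences[0])
```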

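Spatial information is an important factor in image recognition. We add position encodings to both the sentence embeddings and the word embeddings to retain spatial information, as shown in Fig. 1. Standard learnable 1D position encodings are used here. Specifically, each sentence is assigned a position encoding:

$$\mathcal{Z}_{0}\leftarrow\mathcal{Z}_{0}+E_{sentence},$$

where $\mathcal{Z}_{0}$ denotes the initial sentence embeddings (including the class token) and $E_{sentence}\in\mathbb{R}^{(n+1)\times d}$ are the sentence position encodings. As for the visual words in a sentence, a word position encoding is added to each word embedding:

$$Y_{0}^{i}\leftarrow Y_{0}^{i}+E_{word},\quad i=1,2,\cdots,n,$$

where $E_{word}\in\mathbb{R}^{m\times c}$ are the word position encodings, which are shared across sentences.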
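A minimal sketch of adding learnable 1D position encodings, assuming one encoding per sentence embedding (n + 1 of them, counting the class token) and one encoding per word slot that every sentence reuses. The small Gaussian initialization and the function names are assumptions for illustration only.

```python
import random


def init_position_encodings(n, m, c, d, seed=0):
    """Hypothetical initialization of the learnable encodings (small random values)."""
    rng = random.Random(seed)
    e_sentence = [[rng.gauss(0.0, 0.02) for _ in range(d)] for _ in range(n + 1)]
    e_word = [[rng.gauss(0.0, 0.02) for _ in range(c)] for _ in range(m)]
    return e_sentence, e_word


def add_position_encodings(words, sentences, e_sentence, e_word):
    """Add one encoding per sentence and one per word slot.

    The same m word encodings are reused for every sentence, so they encode the
    relative position inside a patch, while the sentence encodings encode the
    global position of the patch.
    """
    sentences = [[a + b for a, b in zip(s, e)] for s, e in zip(sentences, e_sentence)]
    words = [[[a + b for a, b in zip(w, e)] for w, e in zip(sent, e_word)] for sent in words]
    return words, sentences
```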

3 Complexity Analysis

A standard transformer block includes two parts, i.e., multi-head self-attention and a multi-layer perceptron. The FLOPs of MSA are $2nd(d_{k}+d_{v})+n^{2}(d_{k}+d_{v})$, and the FLOPs of MLP are $2nd_{v}rd_{v}$, where $r$ is the dimension expansion ratio of the hidden layer in the MLP. Overall, the FLOPs of a standard transformer block are

$$\mathrm{FLOPs}_{T}=2nd(d_{k}+d_{v})+n^{2}(d_{k}+d_{v})+2nd_{v}rd_{v}.$$

Since $r$ is usually set to 4, and the dimensions of the input, key (query) and value are usually the same, the FLOPs calculation can be simplified as

$$\mathrm{FLOPs}_{T}=2nd(6d+n).$$

The number of parameters can be obtained as

$$\mathrm{Params}_{T}=12d^{2}.$$
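The FLOPs arithmetic can be checked with a short sketch: it evaluates the general count and confirms that it reduces to 2nd(6d+n) when r = 4 and d = d_k = d_v. The parameter count covers only the weight matrices (4d² for the query, key, value and output projections and 8d² for an MLP with r = 4) and ignores biases and LN, which is our reading of the formula.

```python
def transformer_flops(n, d, d_k, d_v, r):
    """FLOPs of one standard block: MSA plus MLP, as counted in the text."""
    msa = 2 * n * d * (d_k + d_v) + n * n * (d_k + d_v)
    mlp = 2 * n * d_v * r * d_v
    return msa + mlp


def transformer_flops_simplified(n, d):
    """Closed form for r = 4 and d = d_k = d_v."""
    return 2 * n * d * (6 * d + n)


def transformer_params(d):
    """Weight matrices only: 4 d^2 (attention) + 8 d^2 (MLP with r = 4)."""
    return 12 * d * d
```

For example, with n = 196 tokens and d = 384, both FLOPs expressions evaluate to 376,320,000 per block.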

Our TNT block consists of three parts: an inner transformer block $T_{in}$, an outer transformer block $T_{out}$ and a linear layer. Here $m$ is the number of visual words in each sentence and $c$ is the dimension of the word embeddings. The computation complexities of $T_{in}$ and $T_{out}$ are $2nmc(6c+m)$ and $2nd(6d+n)$, respectively, and the linear layer has $nmcd$ FLOPs. In total, the FLOPs of the TNT block are

$$\mathrm{FLOPs}_{TNT}=2nmc(6c+m)+nmcd+2nd(6d+n).$$

Similarly, the parameter complexity of the TNT block is calculated as

$$\mathrm{Params}_{TNT}=12c^{2}+mcd+12d^{2}.$$
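A sketch comparing a TNT block with a standard block of the same outer width. It plugs in n = 196 (a 224×224 input with 16×16 patches) and m = 16; the values c = 24 and d = 384 for TNT-S, and the depth of 12 blocks used below, are assumptions not stated in this text.

```python
def tnt_block_flops(n, m, c, d):
    inner = 2 * n * m * c * (6 * c + m)   # T_in over n sentences of m words
    outer = 2 * n * d * (6 * d + n)       # T_out over n sentences
    linear = n * m * c * d                # projecting m * c word features to d
    return inner + outer + linear


def tnt_block_params(m, c, d):
    return 12 * c * c + m * c * d + 12 * d * d


def overhead(n, m, c, d):
    """Cost ratios of a TNT block to a standard block with the same n and d."""
    flops_ratio = tnt_block_flops(n, m, c, d) / (2 * n * d * (6 * d + n))
    params_ratio = tnt_block_params(m, c, d) / (12 * d * d)
    return flops_ratio, params_ratio
```

Under these assumed values a TNT block needs about 1.14× the FLOPs and 1.09× the parameters of the standard block. Twelve such blocks come to roughly 5.15B FLOPs and 23.1M parameters, excluding embeddings and the classification head, which is close to the 5.2B FLOPs and 23.8M parameters listed for TNT-S.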

4 Network Architecture

We build our TNT architectures following the basic configuration of ViT and DeiT. The patch size is set to 16×16. The number of sub-patches is set to $m=4\cdot 4=16$ by default; other values are evaluated in the ablation studies. As shown in Table 1, there are three variants of TNT networks with different model sizes, namely TNT-Ti, TNT-S and TNT-B. They have 6.1M, 23.8M and 65.6M parameters, respectively. The corresponding FLOPs for processing a 224×224 image are 1.4B, 5.2B and 14.1B, respectively.
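A small sketch of the default geometry: a 224×224 input is cut into 16×16 patches (visual sentences), and each patch into m = 4·4 = 16 visual words of 4×4 pixels. The function names are hypothetical, and the embedding step that would follow (mapping each word to c dimensions) is not shown.

```python
def tnt_geometry(image_size=224, patch_size=16, words_per_side=4, channels=3):
    """Count visual sentences and words for a square input."""
    assert image_size % patch_size == 0 and patch_size % words_per_side == 0
    n = (image_size // patch_size) ** 2     # visual sentences (patches)
    s = patch_size // words_per_side        # side length of a visual word
    m = words_per_side ** 2                 # visual words per sentence
    pixels_per_word = s * s * channels      # raw values in one flattened word
    return {"n": n, "m": m, "s": s, "pixels_per_word": pixels_per_word}


def split_into_words(patch, s):
    """Split a square patch (list of rows) into non-overlapping s x s words.

    Words are scanned row by row, and each word is flattened in row-major order.
    """
    p = len(patch)
    words = []
    for wy in range(0, p, s):
        for wx in range(0, p, s):
            words.append([patch[y][x] for y in range(wy, wy + s) for x in range(wx, wx + s)])
    return words
```

With the defaults, a 224×224 RGB image gives n = 196 sentences of 16 words, each word holding 48 raw values; at 384×384 the number of sentences grows to 576 while the words per sentence stay the same.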

Experiments

In this section, we conduct extensive experiments on visual benchmarks to evaluate the effectiveness of the proposed TNT architecture.

ImageNet ILSVRC 2012 is an image classification benchmark consisting of 1.2M training images belonging to 1000 classes and 50K validation images with 50 images per class. We adopt the same data augmentation strategy as DeiT, including random crop, random flip, Rand-Augment, Random Erasing, Mixup and CutMix. For the license of the ImageNet dataset, please refer to http://www.image-net.org/download.

In addition to ImageNet, we also test on downstream tasks with transfer learning to evaluate the generalization ability of TNT. The details of the visual datasets used are listed in Table 2. The data augmentation strategy for the image classification datasets is the same as that for ImageNet. For COCO and ADE20K, the data augmentation strategy follows that in PVT. For the licenses of these datasets, please refer to the original papers.

Implementation Details.

We use the training strategy provided in DeiT. Beyond common settings, the main techniques include AdamW, label smoothing, DropPath, and repeated augmentation. We list the hyper-parameters in Table 3 for better understanding. All models are implemented with PyTorch and MindSpore and trained on NVIDIA Tesla V100 GPUs. Potential negative societal impacts include the energy consumption and carbon dioxide emissions of GPU computation.

2 TNT on ImageNet

We train our TNT models with the same training settings as DeiT. We compare against recent transformer-based models such as ViT and DeiT. To give a better picture of the current progress of visual transformers, we also include representative CNN-based models such as ResNet, RegNet and EfficientNet. The results are shown in Table 4. Our transformer-based model, TNT, outperforms all the other visual transformer models. In particular, TNT-S achieves 81.5% top-1 accuracy, which is 1.7% higher than the baseline model DeiT-S, indicating the benefit of the TNT framework in preserving local structure information inside the patch. Compared to CNNs, TNT outperforms the widely used ResNet and RegNet. Note that all the transformer-based models are still inferior to EfficientNet, which uses special depth-wise convolutions, so how to beat EfficientNet with a pure transformer remains a challenge.

We also plot accuracy-parameters and accuracy-FLOPs curves in Fig. 2 for an intuitive comparison of these models. Our TNT models consistently outperform the other transformer-based models by a significant margin.

Deploying transformer models on devices is important for practical applications, so we test the inference speed of our TNT model. Following prior work, the throughput is measured on an NVIDIA V100 GPU with PyTorch, using a 224×224 input size. Since the resolution and content inside a patch are smaller than those of the whole image, fewer blocks may be needed to learn its representation. Thus, we can use fewer TNT blocks and replace some of them with vanilla transformer blocks. The results in Table 5 show that TNT is more efficient than DeiT and PVT, achieving higher accuracy at a similar inference speed.

3 Ablation Studies

Position information is important for image recognition. In the TNT structure, sentence position encoding maintains global spatial information, and word position encoding preserves local relative position. We verify their effects by removing them separately. As shown in Table 6, TNT-S with both sentence position encoding and word position encoding performs best, achieving 81.5% top-1 accuracy. Removing the sentence/word position encoding results in a 0.8%/0.7% accuracy drop, respectively, and removing all position encodings decreases the accuracy by 1.0%.

Number of heads.

The effect of the number of heads in standard transformers has been investigated in multiple works, and a head width of 64 is recommended for visual tasks. We adopt a head width of 64 in the outer transformer block of our model. The number of heads in the inner transformer block is another hyper-parameter to investigate. We evaluate its effect in Table 7 and find that a proper number of heads (e.g., 2 or 4) achieves the best performance.

Number of visual words.

In TNT, the input image is split into a number of 16×16 patches, and each patch is further split into $m$ sub-patches (visual words) of size $(s,s)$ for computational efficiency. Here we test the effect of the hyper-parameter $m$ on the TNT-S architecture. When we change $m$, the embedding dimension $c$ also changes correspondingly to control the FLOPs. As shown in Table 8, the value of $m$ has only a slight influence on performance, and we use $m=16$ by default for its efficiency, unless stated otherwise.

4 Visualization

We visualize the learned features of DeiT and TNT to further understand the effect of the proposed method. For better visualization, the input image is resized to 1024×1024. The feature maps are formed by reshaping the patch embeddings according to their spatial positions. The feature maps in the 1st, 6th and 12th blocks are shown in Fig. 3(a), where 12 feature maps are randomly sampled for each of these blocks. In TNT, local information is better preserved than in DeiT. We also visualize all 384 feature maps in the 12th block using t-SNE (Fig. 3(b)). The features of TNT are more diverse and contain richer information than those of DeiT. These benefits are owed to the introduction of the inner transformer block for modeling local features.

In addition to the patch-level features, we also visualize the pixel-level embeddings of TNT in Fig. 4. For each patch, we reshape the word embeddings according to their spatial positions to form feature maps and then average these feature maps over the channel dimension. The averaged feature maps corresponding to the 14×14 patches are shown in Fig. 4. Local information is well preserved in the shallow layers, and the representations gradually become more abstract as the network goes deeper.
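The reshaping described above can be sketched as follows: each word embedding is reduced to its channel mean and placed at its spatial position inside its patch, and the patches are tiled at their own positions. Row-major ordering of patches and of words within a patch is an assumption of this sketch, as is the function name.

```python
def word_feature_map(word_embeddings, grid, words_per_side):
    """Tile channel-averaged word embeddings into one 2D map.

    word_embeddings: grid * grid patches in row-major order; each patch is a
    list of words_per_side ** 2 word vectors, also in row-major order.
    Returns a (grid * words_per_side) x (grid * words_per_side) map of floats.
    """
    size = grid * words_per_side
    fmap = [[0.0] * size for _ in range(size)]
    for p, words in enumerate(word_embeddings):
        py, px = divmod(p, grid)                # patch row / column
        for w, vec in enumerate(words):
            wy, wx = divmod(w, words_per_side)  # word row / column inside the patch
            # Average over the channel dimension.
            fmap[py * words_per_side + wy][px * words_per_side + wx] = sum(vec) / len(vec)
    return fmap
```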

Visualization of Attention Maps.

There are two self-attention layers in our TNT block, i.e., an inner self-attention and an outer self-attention for modeling the relationships among visual words and among sentences, respectively. We show the attention maps of different queries in the inner transformer in Figure 5. For a given query visual word, the attention values of visual words with similar appearance are higher, indicating that their features interact more strongly with the query. Such interactions are missing in ViT, DeiT and similar models. The attention maps in the outer transformer can be found in the supplemental material.

5 Transfer Learning

To demonstrate the strong generalization ability of TNT, we transfer the TNT-S and TNT-B models trained on ImageNet to downstream tasks.

Following DeiT, we evaluate our models on 5 image classification datasets of varying size, the smallest training set (Oxford 102 Flowers) having 2,040 images. These datasets include superordinate-level object classification (CIFAR-10, CIFAR-100) and fine-grained object classification (Oxford-IIIT Pets, Oxford 102 Flowers and iNaturalist 2019), as shown in Table 2. All models are fine-tuned with an image resolution of 384×384. We adopt the same training settings as those at the pre-training stage, preserving all data augmentation strategies. To fine-tune at a different resolution, we also interpolate the position embeddings of the new patches. For CIFAR-10 and CIFAR-100, we fine-tune the models for 64 epochs, and for the fine-grained datasets, we fine-tune the models for 300 epochs. Table 9 compares the transfer learning results of TNT with those of ViT, DeiT and other convolutional networks. TNT outperforms DeiT on most datasets with fewer parameters, which shows the benefit of modeling pixel-level relations for better feature representation.
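A sketch of resizing the sentence position encodings when moving from 224×224 inputs (a 14×14 grid of 16×16 patches) to 384×384 inputs (a 24×24 grid). The interpolation method is not specified in the text; this sketch assumes bilinear interpolation with aligned corners, keeps the class-token encoding unchanged, and assumes the patch and word sizes stay fixed, so the word position encodings need no change.

```python
def _lerp(a, b, t):
    return [x + (y - x) * t for x, y in zip(a, b)]


def interpolate_grid(grid, new_size):
    """Bilinearly resize a square grid of vectors (corners aligned)."""
    old = len(grid)
    assert old >= 2 and new_size >= 2
    scale = (old - 1) / (new_size - 1)
    out = []
    for i in range(new_size):
        y = i * scale
        y0 = min(int(y), old - 2)
        ty = y - y0
        row = []
        for j in range(new_size):
            x = j * scale
            x0 = min(int(x), old - 2)
            tx = x - x0
            top = _lerp(grid[y0][x0], grid[y0][x0 + 1], tx)
            bottom = _lerp(grid[y0 + 1][x0], grid[y0 + 1][x0 + 1], tx)
            row.append(_lerp(top, bottom, ty))
        out.append(row)
    return out


def resize_position_encodings(pos, old_grid, new_grid):
    """pos: 1 + old_grid ** 2 vectors, class token first, patches row-major."""
    cls, patches = pos[0], pos[1:]
    grid = [patches[r * old_grid:(r + 1) * old_grid] for r in range(old_grid)]
    resized = interpolate_grid(grid, new_grid)
    return [cls] + [v for row in resized for v in row]
```

For 16×16 patches, calling resize_position_encodings(pos, 14, 24) turns 1 + 196 encodings into 1 + 576.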

Pure Transformer Object Detection.

We construct a pure transformer object detection pipeline by combining our TNT and DETR. For a fair comparison, we adopt the training and testing settings in PVT and add a 2×2 average pooling to make the output size of the TNT backbone the same as that of PVT and ResNet. All the compared models are trained using AdamW with a batch size of 16 for 50 epochs. The training images are randomly resized to have a shorter side in the range of and a longer side within 1333 pixels. For testing, the shorter side is set to 800 pixels. The results on COCO val2017 are shown in Table 10. Under the same setting, DETR with the TNT-S backbone outperforms the representative pure transformer detector DETR+PVT-Small by 3.5 AP with a similar number of parameters.
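A minimal sketch of the extra pooling step: the patch embeddings (class token removed) are reshaped into their H×W spatial grid, which sits at 1/16 of the input resolution, and a 2×2 average pooling with stride 2 halves both sides to 1/32. The function names and the row-major token order are assumptions of this sketch.

```python
def tokens_to_grid(tokens, h, w):
    """Reshape h * w patch embeddings (row-major, class token removed) into a grid."""
    return [tokens[r * w:(r + 1) * w] for r in range(h)]


def avg_pool_2x2(fmap):
    """2x2 average pooling with stride 2 on an H x W grid of vectors (H, W even)."""
    h, w = len(fmap), len(fmap[0])
    assert h % 2 == 0 and w % 2 == 0
    return [[[sum(vals) / 4 for vals in zip(fmap[y][x], fmap[y][x + 1],
                                            fmap[y + 1][x], fmap[y + 1][x + 1])]
             for x in range(0, w, 2)]
            for y in range(0, h, 2)]
```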

Pure Transformer Semantic Segmentation.

We adopt the segmentation framework of Trans2Seg to build a pure transformer semantic segmentation model with a TNT backbone. We follow the training and testing configuration in PVT for a fair comparison. All the compared models are trained with the AdamW optimizer, an initial learning rate of 1e-4 and a polynomial decay schedule. We apply random resizing and cropping to 512×512 during training. The ADE20K results with single-scale testing are shown in Table 11. With a similar number of parameters, Trans2Seg with the TNT-S backbone achieves 43.6% mIoU, which is 1.0% higher than with the PVT-Small backbone and 2.8% higher than with the DeiT-S backbone.

Conclusion

In this paper, we propose a novel Transformer-iN-Transformer (TNT) network architecture for visual recognition. In particular, we uniformly split the image into a sequence of patches (visual sentences) and view each patch as a sequence of sub-patches (visual words). We introduce a TNT block in which an outer transformer block processes the sentence embeddings and an inner transformer block models the relations among word embeddings. The information of the visual word embeddings is added to the visual sentence embedding after projection by a linear layer. We build our TNT architecture by stacking TNT blocks. Compared to conventional vision transformers (ViT), which corrupt the local structure of the patch, TNT better preserves and models local information for visual recognition. Extensive experiments on ImageNet and downstream tasks demonstrate the effectiveness of the proposed TNT architecture.

Acknowledgement

This work was supported by NSFC (62072449, 61632003), Guangdong-Hongkong-Macao Joint Research Grant (2020B1515130004) and Macao FDCT (0018/2019/AKP, 0015/2019/AKP).

Appendix A Appendix

In Figure 6, we plot the attention maps from each patch to all the patches. For both DeiT-S and TNT-S, more patches become related as the layers go deeper. This is because information between patches has been fully exchanged in the deeper layers. As for the difference between DeiT and TNT, the attention of TNT focuses on the meaningful patches in Block-12, while DeiT still pays attention to the tree, which is not related to the pandas.

Attention between Class Token and Patches.

In Figure 7, we plot the attention maps between the class token and all the patches for some randomly sampled images. The output feature mainly focuses on the patches related to the object to be recognized.

A.2 Exploring SE module in TNT

Inspired by the squeeze-and-excitation (SE) network for CNNs, we explore channel-wise attention for transformers. We first average all the sentence (word) embeddings and use a two-layer MLP to calculate the attention values. The attention is then multiplied with all the embeddings. The SE module brings in only a few extra parameters but is able to perform dimension-wise attention for feature enhancement. As shown in Table 12, adding the SE module into TNT further improves the accuracy slightly.
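A minimal sketch of the SE-style module on a list of token embeddings: squeeze by averaging, excite with a two-layer MLP, and rescale every embedding. The text does not name the activations or the bottleneck width; ReLU followed by a sigmoid, as in the original SE design, and the weight layout below are assumptions of this sketch.

```python
import math


def se_module(embeddings, w1, b1, w2, b2):
    """Channel-wise attention over a list of token embeddings of size d.

    1) squeeze: average all embeddings into one vector of size d;
    2) excitation: two-layer MLP (ReLU, then sigmoid) gives d gates in (0, 1);
    3) scale: multiply every embedding element-wise by the gates.
    w1: d x k and w2: k x d, where k < d is the bottleneck width.
    """
    n, d = len(embeddings), len(embeddings[0])
    mean = [sum(e[j] for e in embeddings) / n for j in range(d)]
    hidden = [max(0.0, sum(mean[i] * w1[i][k] for i in range(d)) + b1[k])
              for k in range(len(b1))]
    gates = [1.0 / (1.0 + math.exp(-(sum(hidden[k] * w2[k][j] for k in range(len(hidden))) + b2[j])))
             for j in range(d)]
    return [[x * g for x, g in zip(e, gates)] for e in embeddings]
```

The extra parameters amount to 2dk + k + d, small compared with the 12d² of a transformer block when k is much smaller than d.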

A.3 Object Detection with Faster RCNN

As a general backbone network, TNT can also be applied in multi-scale vision models such as Faster RCNN. We extract features from different layers of TNT to construct multi-scale features. In particular, FPN takes 4 levels of features ($\frac{1}{4}$, $\frac{1}{8}$, $\frac{1}{16}$ and $\frac{1}{32}$ of the input resolution) as input, while the features of every TNT block are at $\frac{1}{16}$ resolution. We select 4 layers from shallow to deep (the 3rd, 6th, 9th and 12th) to form a multi-level representation. To match the feature shapes, we insert deconvolution/convolution layers with proper strides. We evaluate TNT-S and DeiT-S on Faster RCNN with FPN, using the DeiT model in the same way. The COCO2017 val results are shown in Table 13. TNT achieves much better performance than the ResNet and DeiT backbones, indicating its generalization ability for FPN-like frameworks.
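The stride bookkeeping for this multi-scale adaptation can be sketched as a small planning function. Assigning the shallowest selected block to the finest FPN level is our reading of "from shallow to deep"; the exact layer design (for example, one stride-4 deconvolution versus two stride-2 ones, or the kernel sizes) is not given in the text, so the plan below is hypothetical.

```python
def fpn_adapters(blocks=(3, 6, 9, 12), targets=(4, 8, 16, 32), source=16):
    """Pair each selected TNT block (all at stride `source`) with an FPN stride.

    Returns (block, operation, factor): a transposed convolution (deconvolution)
    upsamples, a strided convolution downsamples, identity keeps the resolution.
    """
    plan = []
    for block, target in zip(blocks, targets):
        if target < source:
            plan.append((block, "deconv", source // target))
        elif target > source:
            plan.append((block, "conv", target // source))
        else:
            plan.append((block, "identity", 1))
    return plan
```

With the defaults, the 3rd and 6th blocks are upsampled by 4 and 2, the 9th block is used as is, and the 12th block is downsampled by 2.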

References