Vision Transformer with Progressive Sampling

Xiaoyu Yue, Shuyang Sun, Zhanghui Kuang, Meng Wei, Philip Torr, Wayne Zhang, Dahua Lin

Introduction

Transformers have become the de-facto standard architecture for natural language processing tasks. Thanks to their powerful global relation modeling abilities, researchers have recently attempted to introduce them to fundamental computer vision tasks such as image classification, object detection, and image segmentation. However, transformers were originally tailored for processing mid-size sequences, and their computational complexity is quadratic w.r.t. the sequence length. Thus, they cannot be directly used to process images with massive numbers of pixels.
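To get a rough sense of scale, assume a 224×224 input (the test-time crop used later in this paper) and count only the entries of one attention map, ignoring channel width, head count, and constant factors. The comparison below treats every pixel as a token and sets that against the 16×16 patches used by ViT (introduced below):

$$
\left(224 \times 224\right)^{2} = 50{,}176^{2} \approx 2.52 \times 10^{9}
\qquad \text{vs.} \qquad
\left(\tfrac{224}{16} \times \tfrac{224}{16}\right)^{2} = 196^{2} = 38{,}416 .
$$

Patchifying shortens the sequence by a factor of 16² = 256, so the quadratic attention map shrinks by a factor of 256² = 65,536.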

To overcome the computational complexity issue, the pioneering Vision Transformer (ViT) adopts a naive tokenization scheme that partitions an image into a sequence of regularly spaced patches, which are linearly projected into tokens. In this way, the image is converted into hundreds of visual tokens, which are fed into a stack of transformer encoder layers for classification. ViT attains excellent results, especially when pre-trained on large-scale datasets, which shows that a full-transformer architecture is a promising alternative for vision tasks. However, the limitations of such a naive tokenization scheme are obvious. First, the hard splitting might separate highly semantically correlated regions that should be modeled with the same group of parameters, which destroys inherent object structures and makes the input patches less informative. Figure 1 (a) shows that the cat head is divided into several parts, so recognition has to rely on a single part only. Second, tokens are placed on regular grids irrespective of the underlying image content. Figure 1 (a) shows that most grid cells cover uninteresting background, which might cause the foreground object of interest to be submerged in interfering signals.
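The hard-splitting tokenizer can be sketched in a few lines of NumPy. The sketch assumes a 224×224 RGB input, 16×16 patches and a hypothetical token width of 384. The random matrix `w_proj` stands in for ViT's learned linear projection, and the bias is omitted. The patch grid is fixed before any image content is seen, and this is exactly why an object can end up cut across several tokens.

```python
import numpy as np

def naive_patch_tokens(image: np.ndarray, patch: int, weight: np.ndarray) -> np.ndarray:
    """ViT-style hard splitting: cut an (H, W, 3) image into non-overlapping
    patch x patch squares on a fixed grid and linearly project each one."""
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "image must tile exactly"
    gh, gw = h // patch, w // patch
    # (gh, patch, gw, patch, c) -> (gh, gw, patch, patch, c): group pixels by patch
    patches = image.reshape(gh, patch, gw, patch, c).transpose(0, 2, 1, 3, 4)
    # flatten each patch into one vector of length patch * patch * c (row-major order)
    flat = patches.reshape(gh * gw, patch * patch * c)
    # linear projection to the token width (bias omitted for brevity)
    return flat @ weight

rng = np.random.default_rng(0)
img = rng.standard_normal((224, 224, 3))
embed_dim = 384  # hypothetical token width
w_proj = rng.standard_normal((16 * 16 * 3, embed_dim)) * 0.02
tokens = naive_patch_tokens(img, 16, w_proj)
print(tokens.shape)  # (196, 384)
```

Token `i` always covers the same 16×16 pixel window, whatever the image shows.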

The human visual system organizes visual information very differently from indiscriminately processing a whole scene at once. Instead, it progressively and selectively focuses attention on interesting parts of the visual space when and where needed. It ignores uninteresting parts and combines information from different fixations over time to understand the scene.

Inspired by this procedure, we propose a novel transformer-based progressive sampling module that learns where to look in images, to mitigate the issues caused by the naive tokenization scheme in ViT. Instead of sampling from fixed locations, our module updates the sampling locations iteratively. As shown in Figure 1 (b), at each iteration, the tokens of the current sampling step are fed into a transformer encoder layer, and a group of sampling offsets is predicted to update the sampling locations for the next step. This mechanism exploits the transformer's ability to capture global information and combines it with the local contexts and positions of the current tokens to estimate offsets towards regions of interest. In this way, attention converges to the discriminative regions of an image step by step, much as human vision does. Our progressive sampling is differentiable and can readily replace the hard splitting in ViT, yielding end-to-end Vision Transformers with Progressive Sampling, dubbed PS-ViT. Thanks to task-driven training, PS-ViT tends to sample object regions correlated with semantic structures. Moreover, compared with simple tokenization, it pays more attention to foreground objects and less to ambiguous background.

The proposed PS-ViT outperforms current state-of-the-art transformer-based approaches when trained from scratch on ImageNet. Concretely, it achieves 82.3% top-1 accuracy on ImageNet. This is higher than the recent ViT variant DeiT, while PS-ViT uses about 4× fewer parameters and 2× fewer FLOPs. As shown in Figure 2, PS-ViT is considerably better, faster, and more parameter-efficient than the state-of-the-art transformer-based networks ViT and DeiT.

Related Work

Transformers were first proposed for sequence modeling tasks such as machine translation. Benefiting from their powerful global relation modeling abilities and highly efficient training, transformers have achieved significant improvements and become the de-facto standard in many Natural Language Processing (NLP) tasks.

Transformers in Computer Vision. Inspired by the success of transformers in NLP, many researchers have attempted to apply transformers or attention mechanisms to computer vision tasks. These include image classification, object detection, image segmentation, low-level image processing, video understanding and generation, multi-modality understanding, and Optical Character Recognition (OCR). Transformers' powerful modeling capacity comes at the cost of computational complexity: their memory and computation grow quadratically w.r.t. the token length, which prevents them from being directly applied to images with pixels as tokens. Axial attention applied attention along a single axis of the tensor without flattening, to reduce the computational resource requirements. iGPT simply down-sampled images to a low resolution, trained a sequence of transformers to auto-regressively predict pixels, and achieved promising performance with a linear probe. ViT regularly partitioned a high-resolution image into 16×16 patches, which were fed into a pure transformer architecture for classification. For the first time, this attained excellent results even compared with state-of-the-art convolutional networks. However, ViT needs pretraining on large-scale datasets, which limits its adoption. DeiT proposed a data-efficient training strategy and a teacher-student distillation mechanism, which greatly improved ViT's performance. Moreover, DeiT is trained on ImageNet only, which considerably simplifies the overall pipeline of ViT. Our proposed PS-ViT also starts from ViT. Instead of splitting pixels into a small number of visual tokens, we propose a novel progressive sampling strategy to avoid structure destruction and focus more attention on interesting regions.

Hard Visual Attention. PS-ViT, as a series of glimpses akin to hard visual attention, makes decisions based on only a subset of locations in the input image. However, PS-ViT is differentiable and can easily be trained end to end, whereas previous hard visual attention approaches are non-differentiable and trained with Reinforcement Learning (RL) methods. These RL-based methods have proven less effective when scaled to more complex datasets. Moreover, PS-ViT aims to progressively sample discriminative tokens for Vision Transformers, while previous approaches locate regions of interest for convolutional neural networks or sequence decoders. Our work is also related to deformable convolution and deformable attention. However, both the motivation and the way pixels are sampled in this work differ from those of deformable convolution and attention.

Methodology

In this section, we first introduce our progressive sampling strategy and then describe the overall architecture of our proposed PS-ViT network. Finally, we will elaborate on the details of PS-ViT. Symbols and notations of our method are presented in Table 1.

1 Progressive Sampling

ViT regularly partitions an image into 16×16 patches, which are linearly projected into a set of tokens. This ignores both how important the content of each image region is and the integrity of object structures. To pay more attention to interesting regions of images and mitigate the problem of structure destruction, we propose a novel progressive sampling module. Because the module is differentiable, it is adaptively trained under the supervision of the subsequent vision-transformer-based image classification task.

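As shown in Figure 3, at each iteration the sampling locations are updated by adding to them the offset vectors predicted at the previous iteration. Formally,

$$\mathbf{p}_{t+1} = \mathbf{p}_{t} + \mathbf{o}_{t}, \quad t \in \{1, \ldots, N-1\}, \tag{1}$$

where $\mathbf{p}_{t}$ and $\mathbf{o}_{t}$ denote the sampling locations and the offsets predicted at iteration $t$. For the first iteration, the $i$-th sampling location is initialized on a regular grid:

$$\mathbf{p}_{1}^{i} = \left[\pi_{i}^{y} s_{h} + \tfrac{s_{h}}{2},\ \pi_{i}^{x} s_{w} + \tfrac{s_{w}}{2}\right], \quad \pi_{i}^{y} = \left\lfloor \tfrac{i}{n} \right\rfloor, \quad \pi_{i}^{x} = i - \pi_{i}^{y} n, \quad s_{h} = \tfrac{H}{n}, \quad s_{w} = \tfrac{W}{n}, \tag{2}$$

where $\pi_{i}^{y}$ and $\pi_{i}^{x}$ map the location index $i$ to the row index and the column index, respectively, and $\lfloor\cdot\rfloor$ denotes the floor operation. $s_{h}$ and $s_{w}$ are the step sizes along the $y$ and $x$ axes, respectively. Initial tokens are then sampled from the input feature map at the sampled locations as follows:

$$\mathbf{T}'_{t} = \mathbf{F}(\mathbf{p}_{t}). \tag{3}$$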

With the progressive sampling strategy, the sampling locations progressively converge to interesting regions of images, hence the name progressive sampling.
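The NumPy sketch below shows the sampling loop. It rests on several assumptions:

- Locations start at the centres of a regular n×n grid over the feature map.
- Sampling uses bilinear interpolation with zero padding.
- The transformer encoder layer and the offset predictor are replaced by hypothetical stand-ins: a tanh of the accumulated tokens, followed by a random linear map to 2-D offsets.

Only the update pattern is meant to mirror the text: sample at the current locations, then add the predicted offsets to obtain the next locations.

```python
import numpy as np

def initial_grid(n: int, h: int, w: int) -> np.ndarray:
    """Centres of a regular n x n grid over an h x w map, as (y, x) rows of shape (n*n, 2)."""
    sh, sw = h / n, w / n                 # step sizes along y and x
    i = np.arange(n * n)
    row = i // n                          # row index of location i
    col = i - row * n                     # column index of location i
    return np.stack([row * sh + sh / 2, col * sw + sw / 2], axis=1)

def bilinear_sample(feat: np.ndarray, pts: np.ndarray) -> np.ndarray:
    """Sample feat (C, H, W) at fractional (y, x) points (L, 2); returns (L, C).
    Neighbours that fall outside the map contribute zero."""
    _, h, w = feat.shape
    y, x = pts[:, 0], pts[:, 1]
    y0, x0 = np.floor(y).astype(int), np.floor(x).astype(int)
    out = np.zeros((pts.shape[0], feat.shape[0]))
    for dy in (0, 1):
        for dx in (0, 1):
            yi, xi = y0 + dy, x0 + dx
            # bilinear kernel: max(0, 1 - |q - p|) along each axis
            wgt = np.maximum(0, 1 - np.abs(yi - y)) * np.maximum(0, 1 - np.abs(xi - x))
            ok = (yi >= 0) & (yi < h) & (xi >= 0) & (xi < w)
            vals = np.zeros_like(out)
            vals[ok] = feat[:, yi[ok], xi[ok]].T
            out += wgt[:, None] * vals
    return out

def progressive_sampling(feat, n, num_iters, offset_head):
    _, h, w = feat.shape
    p = initial_grid(n, h, w)
    tokens = np.zeros((n * n, feat.shape[0]))
    for t in range(num_iters):
        sampled = bilinear_sample(feat, p)          # tokens sampled at the current locations
        tokens = np.tanh(sampled + tokens)          # hypothetical stand-in for the encoder layer
        if t < num_iters - 1:
            p = p + offset_head(tokens)             # next locations = current + predicted offsets
    return tokens, p

rng = np.random.default_rng(0)
feat = rng.standard_normal((8, 56, 56))             # hypothetical feature map with C = 8
w_off = rng.standard_normal((8, 2)) * 0.5           # hypothetical linear offset predictor
tokens, locs = progressive_sampling(feat, n=14, num_iters=4, offset_head=lambda t: t @ w_off)
print(tokens.shape, locs.shape)  # (196, 8) (196, 2)
```

Bilinear interpolation makes the sampled values a continuous, piecewise-differentiable function of the fractional locations. This is what allows offsets to be learned by back-propagation rather than chosen discretely.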

2 Overall Architecture

As shown in Figure 4, the PS-ViT architecture consists of four main components: 1) a feature extractor module that predicts dense tokens; 2) a progressive sampling module that samples discriminative locations; 3) a vision transformer module that follows a configuration similar to ViT and DeiT; and 4) a classification module.

The feature extractor module extracts the dense feature map $\mathbf{F}$, from which the progressive sampling module samples tokens $\mathbf{T}_{t}$. Each pixel of the dense feature map $\mathbf{F}$ can be treated as a token associated with a patch of the image. We employ the convolutional stem and the first two residual blocks of the first stage of ResNet-50 as our feature extractor module, because convolution is especially effective at modeling spatially local contexts.
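The four components can be traced with a simple shape calculation. This sketch assumes a 224×224 input and the standard ResNet-50 stem (a 7×7 convolution with stride 2, then 3×3 max-pooling with stride 2). The projection to a token width C = 384 is hypothetical, since this section does not say how the channel width of $\mathbf{F}$ is set, and the class token is omitted. Under these assumptions, $\mathbf{F}$ offers 56×56 = 3,136 dense candidate tokens, one per 4-pixel step of the input, of which the progressive sampling module keeps n² = 196 for n = 14.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial size after a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

def ps_vit_shapes(img: int = 224, n: int = 14, c: int = 384, num_classes: int = 1000) -> dict:
    s = conv_out(img, kernel=7, stride=2, padding=3)  # ResNet-50 stem convolution: 224 -> 112
    s = conv_out(s, kernel=3, stride=2, padding=1)    # ResNet-50 stem max-pooling: 112 -> 56
    return {
        # residual blocks in ResNet-50 stage 1 keep the resolution and output 256 channels
        "resnet_stage1": (256, s, s),
        # hypothetical projection of the stage-1 features to the token width C
        "feature_map_F": (c, s, s),
        "dense_tokens": s * s,           # every pixel of F is a candidate token
        "sampled_tokens": (n * n, c),    # progressive sampling keeps n * n of them
        "logits": (num_classes,),        # output of the classification module
    }

shapes = ps_vit_shapes()
for name, shape in shapes.items():
    print(name, shape)
```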

3 Implementation

Transformer Encoder Layer. The transformer encoder layer serves as the basic building block of both the progressive sampling module and the vision transformer module. Each transformer encoder layer consists of a multi-head self-attention unit and a feed-forward unit.

where $\mathbf{Q}^{T}$ denotes the transpose of $\mathbf{Q}$, and $\text{softmax}(\cdot)$ is the softmax operation applied over each row of the input matrix. For multi-head self-attention (MHA), the queries, keys, and values are generated by applying linear transformations to the inputs $M$ times, with an individually learned weight for each head. The attention function is then applied in parallel to the queries, keys, and values of each head. Formally,

The feed-forward unit of the transformer encoder layer consists of two fully connected layers with a GELU non-linearity between them and a hidden dimension of $3C$. For simplicity, the transformer encoder layers in the progressive sampling module and in the vision transformer module share the same settings.
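A minimal NumPy sketch of one encoder layer built from these two units is given below. It assumes the following:

- Standard scaled dot-product attention, scaled by the per-head width √d.
- ViT's pre-norm residual layout, with LayerNorm shown without its learnable scale and shift.
- A hypothetical width C = 384 with M = 6 heads.

The paper's exact normalization and residual arrangement are not reproduced here. Packing the M per-head projections into one C×C matrix is equivalent to learning an individual weight per head, because each head reads only its own column block.

```python
import math
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # LayerNorm without learnable scale/shift, for brevity
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x: np.ndarray) -> np.ndarray:
    # exact GELU, x * Phi(x), with the Gaussian CDF written via erf
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

def mha(x, wq, wk, wv, wo, num_heads):
    """Multi-head self-attention over tokens x of shape (L, C)."""
    L, C = x.shape
    d = C // num_heads
    def split(t):  # (L, C) -> (M, L, d); column block h belongs to head h
        return t.reshape(L, num_heads, d).transpose(1, 0, 2)
    q, k, v = split(x @ wq), split(x @ wk), split(x @ wv)
    attn = softmax(q @ k.transpose(0, 2, 1) / math.sqrt(d), axis=-1)  # (M, L, L), row-wise softmax
    heads = attn @ v                                                   # (M, L, d)
    return heads.transpose(1, 0, 2).reshape(L, C) @ wo                 # concatenate heads, project

def ffn(x, w1, b1, w2, b2):
    """Two fully connected layers with GELU in between; w1 maps C -> 3C."""
    return gelu(x @ w1 + b1) @ w2 + b2

def encoder_layer(x, params, num_heads):
    # pre-norm residual layout as in ViT (an assumption; not specified in this section)
    x = x + mha(layer_norm(x), params["wq"], params["wk"], params["wv"], params["wo"], num_heads)
    return x + ffn(layer_norm(x), params["w1"], params["b1"], params["w2"], params["b2"])

C, M, L = 384, 6, 196  # hypothetical width and head count; L = n * n tokens with n = 14
rng = np.random.default_rng(0)
params = {name: rng.standard_normal(shape) * 0.02 for name, shape in {
    "wq": (C, C), "wk": (C, C), "wv": (C, C), "wo": (C, C),
    "w1": (C, 3 * C), "b1": (3 * C,), "w2": (3 * C, C), "b2": (C,)}.items()}
tokens = rng.standard_normal((L, C))
out = encoder_layer(tokens, params, num_heads=M)
print(out.shape)  # (196, 384)
```

Every operation here is either per-token or permutation-equivariant attention. Reordering the input tokens therefore only reorders the outputs, which is why positional information has to be injected separately.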

Progressive Sampling Back-propagation. Back-propagation through progressive sampling is straightforward. According to Equation (1) and Equation (3), for each sampling location $i$, the gradient w.r.t. the sampling offsets $\mathbf{o}_{t}^{i}$ at iteration $t$ is computed as:

where $K(\cdot,\cdot)$ is the bilinear interpolation kernel that computes the weight of each integral spatial location $\mathbf{q}$.
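Each offset is simply added to its location, so the next sampling location moves one-for-one with the offset. The gradient w.r.t. an offset therefore equals the gradient of the bilinear sample w.r.t. its location. The sketch below works on a single-channel map at a hypothetical point. It differentiates $\sum_{\mathbf{q}} K(\mathbf{q}, \mathbf{p})\,\mathbf{F}(\mathbf{q})$ analytically, using the standard separable kernel $K(\mathbf{q}, \mathbf{p}) = \max(0, 1-|q_y-p_y|)\max(0, 1-|q_x-p_x|)$ (assumed here), and checks the result against central finite differences.

```python
import numpy as np

def tri(a: float) -> float:
    # 1-D bilinear (triangle) kernel: max(0, 1 - |a|)
    return max(0.0, 1.0 - abs(a))

def sample_with_grad(feat: np.ndarray, p: tuple[float, float]):
    """Return sum_q K(q, p) F(q) and its gradient w.r.t. p = (y, x)."""
    h, w = feat.shape
    y, x = p
    val, grad = 0.0, np.zeros(2)
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    for qy in (y0, y0 + 1):
        for qx in (x0, x0 + 1):
            if not (0 <= qy < h and 0 <= qx < w):
                continue  # zero padding outside the map
            ky, kx, f = tri(qy - y), tri(qx - x), feat[qy, qx]
            val += ky * kx * f
            # d/dp of max(0, 1 - |q - p|) is sign(q - p) inside the kernel support
            dky = np.sign(qy - y) if abs(qy - y) < 1 else 0.0
            dkx = np.sign(qx - x) if abs(qx - x) < 1 else 0.0
            grad += np.array([dky * kx, ky * dkx]) * f
    return val, grad

rng = np.random.default_rng(0)
feat = rng.standard_normal((6, 6))   # hypothetical single-channel feature map
p = (2.3, 3.7)                       # hypothetical fractional sampling location
val, grad = sample_with_grad(feat, p)

eps = 1e-6
fd = np.array([
    (sample_with_grad(feat, (p[0] + eps, p[1]))[0] - sample_with_grad(feat, (p[0] - eps, p[1]))[0]) / (2 * eps),
    (sample_with_grad(feat, (p[0], p[1] + eps))[0] - sample_with_grad(feat, (p[0], p[1] - eps))[0]) / (2 * eps),
])
print(np.allclose(grad, fd, atol=1e-6))  # True
```

Only the four integer neighbours of $\mathbf{p}$ have non-zero kernel weight, so the gradient is cheap to compute. At exactly integral locations the kernel has a kink, and the sketch returns a zero subgradient there.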

Network Configuration. The feature dimension $C$, the iteration number $N$ of the progressive sampling module, the number of transformer layers $N_{v}$ in the vision transformer module, and the number of heads $M$ in each transformer layer affect the model size, FLOPs, and performance. In this paper, we configure them with different speed-performance trade-offs in Table 2 so that PS-ViT can be used in different application scenarios. The number of sampling points $n$ along each spatial dimension is set to 14 by default.

Since sampling in every iteration of the progressive sampling module is conducted over the same feature map $\mathbf{F}$, we also try sharing weights across iterations to further reduce the number of trainable parameters. As shown in Table 2, about 25% of the parameters can be saved in this setting.
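The rough count below shows where savings of this size come from. It assumes that each encoder layer holds only its attention projections, a 3C-wide feed-forward unit and two LayerNorms. It ignores the feature extractor, positional encodings, offset predictors and classifier. Under that simplification, sharing one set of weights across the N sampling iterations removes N − 1 of the N + N_v encoder layers. The saved fraction (N − 1)/(N + N_v) is then independent of C. With the hypothetical split N = 4, N_v = 10, it is 3/14 ≈ 21%. This is an illustration, not a reproduction of Table 2, whose figures also depend on the components this count ignores.

```python
def encoder_layer_params(c: int, ffn_ratio: int = 3) -> int:
    """Approximate parameter count of one encoder layer of width c."""
    attention = 4 * (c * c + c)                                           # Q, K, V and output projections with biases
    ffn = (c * ffn_ratio * c + ffn_ratio * c) + (ffn_ratio * c * c + c)   # C -> 3C -> C with biases
    norms = 2 * (2 * c)                                                   # two LayerNorms (scale and shift), assumed
    return attention + ffn + norms

def shared_weight_savings(c: int, n_iters: int, n_vit: int) -> float:
    """Fraction of encoder-layer parameters removed when the n_iters sampling layers share weights."""
    per_layer = encoder_layer_params(c)
    unshared = (n_iters + n_vit) * per_layer
    shared = (1 + n_vit) * per_layer
    return 1 - shared / unshared

print(f"{shared_weight_savings(c=384, n_iters=4, n_vit=10):.3f}")  # 0.214
```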

Experiments

All image classification experiments are conducted on the ImageNet 2012 dataset, which includes 1k classes, 1.2 million training images, and 50 thousand validation images. We train PS-ViT on ImageNet without pretraining on large-scale datasets. All PS-ViT models are trained using PyTorch on 8 GPUs. Inspired by the data-efficient training of DeiT, we use AdamW as the optimizer. The total number of training epochs and the batch size are set to 300 and 512, respectively. The learning rate is initialized to 0.0005 and decays with a cosine annealing schedule. We regularize the loss with label smoothing ($\epsilon=0.1$). We use random cropping, RandAugment, Mixup, and CutMix to augment images during training. At test time, images are resized to 256×256 and center-cropped to 224×224. The training strategy and its hyper-parameter settings are summarized in Table 4.
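Two of these settings can be sketched in plain Python: the cosine learning-rate decay from 0.0005 over 300 epochs and the ε = 0.1 smoothed targets over 1000 classes. The minimum learning rate of 0 and the absence of warm-up are assumptions here, since Table 4 holds the actual settings. The smoothing shown is one common formulation, which spreads ε uniformly over all classes, including the true one.

```python
import math

def cosine_lr(epoch: float, base_lr: float = 5e-4, total_epochs: int = 300, min_lr: float = 0.0) -> float:
    """Cosine annealing from base_lr down to min_lr over total_epochs (no warm-up)."""
    progress = min(epoch / total_epochs, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def smoothed_target(label: int, num_classes: int = 1000, eps: float = 0.1) -> list[float]:
    """Label smoothing: eps spread uniformly over all classes, 1 - eps added to the true class."""
    target = [eps / num_classes] * num_classes
    target[label] += 1.0 - eps
    return target

for epoch in (0, 150, 300):
    print(epoch, cosine_lr(epoch))
t = smoothed_target(label=3)
print(round(t[3], 4), round(t[0], 4), round(sum(t), 6))  # 0.9001 0.0001 1.0
```

The learning rate halves at the midpoint of training and reaches the minimum only at the final epoch. The smoothed target keeps a small non-zero probability on every wrong class.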

2 Results on ImageNet

Table 3 compares PS-ViT with state-of-the-art networks on the standard ImageNet image classification benchmark in terms of the number of parameters, FLOPs, and top-1 and top-5 accuracy.

Comparison with CNN-based networks. Our PS-ViTs considerably outperform ResNets while requiring far fewer parameters and FLOPs. Specifically, compared with ResNet-18, PS-ViT-Ti/14 improves the top-1 accuracy by 5.8% (absolute) while using 6.9M fewer parameters and 0.2B fewer FLOPs. We observe a similar trend when comparing PS-ViT-B/10 (PS-ViT-B/14) with ResNet-50 (ResNet-101). PS-ViT also achieves better performance and computational efficiency than the state-of-the-art CNN-based network RegNet. In particular, compared with RegNetY-16GF, PS-ViT-B/18 improves the top-1 accuracy by 1.8% with about a quarter of the parameters and half of the FLOPs.

Comparison with transformer-based networks. Table 3 shows that PS-ViT outperforms ViT and its recent variant DeiT. In particular, PS-ViT-B/18 achieves 82.3% top-1 accuracy, 0.5% higher than the baseline model DeiT-B, while using only 21M parameters and 8.8B FLOPs. We attribute the performance gain to two factors. First, PS-ViT samples tokens from CNN features, which are more efficient than the raw image patches used in ViT and DeiT. Second, our progressive sampling module adaptively focuses on regions of interest and produces more semantically correlated tokens than naive tokenization.

3 Ablation Studies

In all ablation studies, the PS-ViT models make predictions from the class token.

A larger sampling number $n$ leads to better performance. We first evaluate how the sampling number $n$ affects PS-ViT's performance. The sequence of sampled tokens fed into the vision transformer module has length $n^{2}$. The more tokens are sampled, the more information PS-ViT can extract. However, sampling more tokens also increases computation and memory usage. Table 5 reports the FLOPs and top-1 and top-5 accuracy for different values of $n$. The FLOPs increase as $n$ grows, while the accuracy increases for $n \leq 16$ and plateaus for $n > 16$. Considering the speed-accuracy trade-off, we set $n=14$ by default unless otherwise noted.
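The trade-off can be made concrete by looking at how the two main cost terms scale with $n$. The per-token parts of each layer (projections and the feed-forward unit) grow with the $n^2$ tokens, while each attention map grows with $n^4$. The sketch only counts these two quantities relative to the default $n = 14$. It ignores the feature extractor, so it indicates trends rather than reproducing the FLOPs in Table 5, and the listed values of $n$ are illustrative.

```python
def sequence_cost(n: int) -> tuple[int, int]:
    """Sequence length and attention-map size for an n x n sampling grid."""
    tokens = n * n              # tokens fed to the vision transformer module
    attn_entries = tokens ** 2  # entries of one attention map (per head, per layer)
    return tokens, attn_entries

base_tokens, base_attn = sequence_cost(14)  # default n = 14
for n in (10, 12, 14, 16, 18):
    tokens, attn = sequence_cost(n)
    print(f"n={n:2d}  tokens={tokens:3d}  "
          f"per-token cost x{tokens / base_tokens:.2f}  attention cost x{attn / base_attn:.2f}")
```

Going from $n = 14$ to $n = 16$ multiplies the attention term by about 1.7. This helps explain why the plateau beyond $n = 16$ makes larger grids unattractive.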

The performance can be further improved with more iterations of progressive sampling. We then evaluate the effect of the iteration number $N$ of the progressive sampling module in Table 6. To keep the computational complexity unchanged, all models in Table 6 have $14-N$ transformer layers in the vision transformer module, i.e., 14 transformer layers in the entire network. $N=1$ indicates that the sampling points are not updated. PS-ViT performs best when $N=8$, and the accuracy begins to decline when $N>8$. Because the total number of transformer layers is fixed, increasing $N$ reduces the number of transformer layers in the vision transformer module, which might hurt performance. Since the accuracy improvement from $N=4$ to $N=8$ is negligible, we set $N=4$ by default unless otherwise noted.

Fair comparison with ViT. The hyper-parameters of the transformer encoder in PS-ViT differ from the original ViT setting. For a fair comparison, we further study how ViT performs when its network hyper-parameters are set to match ours. We set the number of layers, channels, heads, and tokens to those of PS-ViT-B/14, and train the network under the same training regime. As shown in Table 7, ViT achieves 78.4% top-1 accuracy, which is substantially lower than its PS-ViT counterpart. We therefore conclude that the progressive sampling module boosts the performance of ViT even under matched settings.

Sharing weights between sampling iterations. Model size (number of parameters) is one of the key factors when deploying deep models on terminal devices. PS-ViT is well suited to terminal devices because it can share weights in the progressive sampling module with a negligible performance drop. Table 8 compares PS-ViT with and without weight sharing in the progressive sampling module. Weight sharing reduces the number of parameters by about 21% to 23% with only a slight performance drop, especially for PS-ViT-B/12 and PS-ViT-B/14.

4 Speed Comparison

Our proposed PS-ViT is efficient not only in theory but also in practice. Table 9 compares the efficiency of state-of-the-art networks in terms of FLOPs and speed (images per second). For a fair comparison, we measure the speed of all models on a server with one 32GB V100 GPU. The batch size is fixed at 128, and the number of images inferred per second is averaged over 50 runs. PS-ViT is much more efficient than ViT and DeiT at comparable top-1 accuracy. Specifically, PS-ViT-B/14 and DeiT-B have similar accuracy, around 81.7%. However, PS-ViT-B/14 is about 2.0× faster than DeiT-B in measured speed and requires about 3.3× fewer FLOPs. PS-ViT-B/10 is about 13.7× faster than ViT-B/16 and requires about 16.9× fewer FLOPs, while improving top-1 accuracy by 2.7%.
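The measurement protocol can be sketched as a small, framework-agnostic timing harness. Here `infer` is a hypothetical callable that runs one forward pass on a batch of the given size, and `dummy_infer` is a CPU stand-in for a model. The untimed warm-up calls are an assumption, not something the text states. On a GPU, the callable must block until the work finishes (in PyTorch, for example, by calling `torch.cuda.synchronize()`). Kernel launches are asynchronous, so without this the clock would stop before the batch is actually done.

```python
import statistics
import time
from typing import Callable

def measure_throughput(infer: Callable[[int], object], batch_size: int = 128,
                       runs: int = 50, warmup: int = 5) -> float:
    """Mean images/second over `runs` timed calls of infer(batch_size)."""
    for _ in range(warmup):                  # untimed warm-up calls (assumed)
        infer(batch_size)
    rates = []
    for _ in range(runs):
        start = time.perf_counter()
        infer(batch_size)                    # must return only after the batch is done
        rates.append(batch_size / (time.perf_counter() - start))
    return statistics.mean(rates)

def dummy_infer(batch_size: int) -> int:
    # hypothetical CPU workload standing in for a model's forward pass
    return sum(i * i for i in range(batch_size * 200))

ips = measure_throughput(dummy_infer)
print(f"{ips:.1f} images/s")
```

This version averages the per-run throughput. Dividing the total number of images by the total time is an equally reasonable reading of "averaged over 50 runs" and weights slow runs more heavily.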

5 Visualization

In order to explore the mechanism of the learnable sampling locations in our method, we visualize the predicted offsets of our proposed progressive sampling module in Figure 5. We can observe that the sampling locations are adaptively adjusted according to the content of the images. Sampling points around objects tend to move to the foreground area and converge to the key parts of objects. With this mechanism, discriminative regions such as the chicken head are sampled densely, retaining the intrinsic structure information of highly semantically correlated regions.

6 Transfer Learning

In addition to ImageNet, we also transfer PS-ViT to downstream tasks to demonstrate its generalization ability. For a fair comparison, we follow the practice of DeiT. Table 10 shows results for models pre-trained on ImageNet and fine-tuned on other datasets, including CIFAR-10, CIFAR-100, Flowers-102, and Stanford Cars. PS-ViT-B/14 performs on par with or even better than DeiT-B on all these datasets, with about 4× fewer FLOPs and parameters, which demonstrates the superiority of PS-ViT.

Conclusions

In this paper, we propose an efficient Vision Transformer with Progressive Sampling (PS-ViT). PS-ViT first extracts feature maps via a feature extractor, and then progressively selects discriminative tokens with a progressive sampling module. The sampled tokens are fed into a vision transformer module and a classification module for image classification. PS-ViT mitigates the structure destruction issue in ViT and adaptively focuses on interesting regions of objects. It achieves considerable improvement on ImageNet compared with ViT and its recent variant DeiT. We also provide an in-depth analysis of the experimental results to investigate the effectiveness of each component. Moreover, PS-ViT is more efficient than its transformer-based competitors both in theory and in practice.

Acknowledgement. This work was partially supported by Innovation and Technology Commission of the Hong Kong Special Administrative Region, China (Enterprise Support Scheme under the Innovation and Technology Fund B/E030/18).

References