Patch Slimming for Efficient Vision Transformers

Yehui Tang, Kai Han, Yunhe Wang, Chang Xu, Jianyuan Guo, Chao Xu, Dacheng Tao

Introduction

Recently, transformer models have been introduced into the field of computer vision and have achieved high performance in many tasks such as object recognition, image processing, and video analysis. Compared with convolutional neural networks (CNNs), the transformer architecture introduces fewer inductive biases and hence has greater potential to absorb more training data and generalize well on more diverse tasks. However, similar to CNNs, vision transformers also suffer from high computational cost, which blocks their deployment on resource-limited devices such as mobile phones and various IoT devices. To apply a deep neural network in such real scenarios, many model compression algorithms have been proposed to reduce the required computational cost. For example, quantization algorithms approximate weights and intermediate feature maps in neural networks with low-bit data. Knowledge distillation improves the performance of a compact network by transferring knowledge from giant models.

In addition, network pruning is widely explored and used to reduce the neural architecture by directly removing useless components in the pre-defined network. Structured pruning discards whole contiguous components of a pre-trained model and has attracted much attention in recent years, as it can realize acceleration without specific hardware design. In CNNs, removing a whole filter to improve network efficiency is a representative paradigm, named channel pruning (or filter pruning). For example, Liu et al. introduce scaling factors to control the information flow in the neural network, and filters with small factors are removed. Although the aforementioned network compression methods have made tremendous progress in deploying compact convolutional neural networks, only a few works discuss how to accelerate vision transformers.

Different from the paradigm in conventional CNNs, the vision transformer splits the input image into multiple patches and calculates the features of all these patches in parallel. The attention mechanism will further aggregate all patch embeddings into visual features as the output. Elements in the attention map reflect the relationship or similarity between any two patches, and the largest attention value for constructing the feature of an arbitrary patch is usually calculated from itself. Thus, we have to preserve this information flow in the pruned vision transformers for retaining the model performance, which cannot be guaranteed in the conventional CNN channel pruning methods. Moreover, not all the manually divided patches are informative enough and deserve to be preserved in all layers, e.g., some patches are redundant with others. Hence we consider developing a patch slimming approach that can effectively identify and remove redundant patches.

In this paper, we present a novel patch slimming algorithm for accelerating vision transformers. In contrast to existing works focusing on the redundancy in the network channel dimension, we aim to explore the computational redundancy in the patches of a vision transformer (as shown in Figure 1). The proposed method removes redundant patches from the given transformer architecture in a top-down framework, in order to ensure that the retained high-level features of discriminative patches can be well calculated. Specifically, the patch pruning proceeds from the last layer to the first layer, wherein the useless patches are identified by calculating their importance scores with respect to the final classification feature (i.e., class token). To guarantee the information flow, a patch will be preserved if the patches in the same spatial location are retained by deeper layers. For other patches, the importance scores determine whether they are preserved, and patches with lower scores will be discarded. The whole pruning scheme for vision transformers is conducted under careful control of the network error, so that the pruned transformer network can maintain the original performance with significantly lower computational cost. Extensive experiments validate the effectiveness of the proposed method for deploying efficient vision transformers. For example, our method can reduce the FLOPs of the ViT-Ti model by more than 45% with only 0.2% top-1 accuracy loss on the ImageNet dataset.

Related work

Structure pruning for CNNs. Channel pruning discards entire convolution kernels to accelerate the inference process and reduce the required memory cost. Many methods have been proposed to identify the redundant filters. Wen et al. add a group-sparse regularization on the filters and remove filters with small norms. Beyond imposing sparsity regularization on the filters directly, Liu et al. introduce extra scaling factors for each channel, and these scaling factors are trained to be sparse. Filters with small scaling factors have less impact on the network output and are removed to accelerate inference. He et al. rethink the criterion that filters with small norm values are less important and propose to discard the filters having larger similarity to others. To maximally excavate redundancy, Tang et al. set up a scientific control to alleviate the disturbance of irrelevant factors and remove filters with little relation to the given task. In conventional channel pruning for CNNs, channels in different layers have no one-to-one relationship, so the choice of effective channels in one layer has little impact on that in other layers.

Structure pruning for transformers. In transformer models for NLP tasks, a series of works focus on reducing the heads in the multi-head self-attention (MSA) module. For example, Michel et al. observe that removing a large percentage of heads in pre-trained BERT models has limited impact on their performance. Voita et al. analyze the role of each head in the transformer and evaluate its contribution to the model performance. Heads with smaller contributions are removed. Besides the MSA module, the neurons in the multilayer perceptron (MLP) module are also pruned in . Designed for vision transformers, VTP reduces the number of embedding dimensions by introducing control coefficients and removes neurons with small coefficients. Different from them, the proposed patch slimming explores the redundancy from a new perspective by considering the information integration of different patches in a vision transformer. In fact, reducing patches can also be combined with pruning in other dimensions to realize higher acceleration.

Patch Slimming for Vision Transformer

In this section, we introduce the scheme of pruning patches in vision transformers. We first review the vision transformer briefly and then introduce the formulation of patch slimming.

where $d$ is the embedding dimension, $H$ is the number of heads, and $Q_{l}^{h}=Z_{l-1}W^{hq}_{l}$, $K_{l}^{h}=Z_{l-1}W^{hk}_{l}$, and $V_{l}^{h}=Z_{l-1}W^{hv}_{l}$ are the query, key and value of the $h$-th head in the $l$-th layer, respectively. $W_{l}^{a}$ and $W_{l}^{b}$ are the weights for linear transformation and $\phi(\cdot)$ is the non-linear activation function (e.g., GeLU). Most recent vision transformer models are constructed by stacking MSA and MLP modules alternately, and a block $\mathcal{B}_{l}(\cdot)$ is defined as $\mathcal{B}_{l}(Z_{l-1})={\rm MLP}({\rm MSA}(Z_{l-1})+Z_{l-1})+Z^{\prime}_{l}$.

As discussed above, there is considerable redundant information at the patch level of vision transformers. To further verify this phenomenon, we calculate the average cosine similarity between patches within a layer, and show how the similarity varies across layers in Figure 3. The similarity between patches increases rapidly with depth, and the average similarity even exceeds 0.8 in deeper layers. The high similarity implies that patches are redundant, especially in the deeper layers, and that removing them will not noticeably affect the feature calculation.
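The similarity measurement described above can be illustrated with a minimal sketch. It assumes the average is taken over all distinct pairs of patch embeddings within one layer (the text does not say whether self-pairs are included), and the function name and toy values are hypothetical.

```python
import math

def average_patch_cosine_similarity(patches):
    """Mean cosine similarity over all distinct pairs of patch embeddings.

    `patches` is a list of N embedding vectors (lists of floats) from one layer.
    Assumption: self-pairs (i, i) are excluded from the average.
    """
    n = len(patches)
    norms = [math.sqrt(sum(x * x for x in p)) for p in patches]
    total, pairs = 0.0, 0
    for i in range(n):
        for j in range(i + 1, n):
            dot = sum(a * b for a, b in zip(patches[i], patches[j]))
            total += dot / (norms[i] * norms[j])
            pairs += 1
    return total / pairs

# Toy layer with three 2-dimensional patch embeddings (made-up values):
# two identical patches and one orthogonal patch.
toy_layer = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
toy_similarity = average_patch_cosine_similarity(toy_layer)
```

A value close to 1 for a layer would indicate that its patch embeddings are nearly interchangeable, which is the kind of redundancy the paragraph above refers to.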

Patch slimming aims to recognize and discard redundant patches to accelerate the inference process (as shown in Figure 1). We use a binary vector ${\bm{m}}_{l}\in\{0,1\}^{N}$ to indicate whether each patch is preserved or not. The pruned MSA and MLP modules can then be formulated as follows:

where ${\rm diag}({\bm{m}}_{l})$ is a diagonal matrix whose diagonal is composed of the elements of ${\bm{m}}_{l}$. Specifically, ${\bm{m}}_{l,i}=0$ indicates that the $i$-th patch in the $l$-th layer is pruned. $\widehat{Z}_{l-1}$ and $\widehat{Z}^{\prime}_{l}$ are the input and the intermediate features of the $l$-th layer in a pruned vision transformer. Then the pruned block is defined as $\widehat{\mathcal{B}}_{l}(\widehat{Z}_{l-1})=\widehat{\rm MLP}_{l}(\widehat{\rm MSA}_{l}(\widehat{Z}_{l-1})+\widehat{Z}_{l-1})+\widehat{Z}^{\prime}_{l}$.

In practical implementation, only the effective patches of the input feature $\widehat{Z}_{l-1}$ are selected to calculate queries, and all the subsequent operations are implemented only on these effective patches. Thus, the computation of the pruned patches can be avoided. According to ${\bm{m}}_{l}$, only effective patches from the shortcut branch are added to the output of the pruned MSA, while the output of the pruned MLP is padded with zeros before being added to the shortcut.
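The gather, compute, and zero-pad pattern described above can be sketched as follows. This is only an illustration under simplifying assumptions: `per_patch_fn` is a hypothetical stand-in for the pruned MLP applied to one patch, features are plain lists of floats, and all function names are invented for this example.

```python
def gather_effective_patches(features, mask):
    """Keep only the rows (patches) whose mask entry is 1."""
    return [list(row) for row, keep in zip(features, mask) if keep == 1]

def scatter_with_zero_padding(kept_rows, mask, dim):
    """Put computed rows back at their positions; pruned positions get zeros."""
    kept_iter = iter(kept_rows)
    return [list(next(kept_iter)) if keep == 1 else [0.0] * dim for keep in mask]

def pruned_mlp_with_shortcut(features, mask, per_patch_fn):
    """Apply per_patch_fn to preserved patches only, then add the shortcut.

    Pruned patches cost no computation: their module output is zero, so the
    shortcut value passes through unchanged.
    """
    dim = len(features[0])
    kept = gather_effective_patches(features, mask)
    computed = [per_patch_fn(row) for row in kept]
    padded = scatter_with_zero_padding(computed, mask, dim)
    return [[s + o for s, o in zip(short, out)]
            for short, out in zip(features, padded)]

# Toy example: 3 patches of dimension 2, the middle patch is pruned.
toy_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
toy_mask = [1, 0, 1]
toy_output = pruned_mlp_with_shortcut(
    toy_features, toy_mask, lambda row: [2.0 * v for v in row])
```

In a tensor framework the gather and scatter steps would typically be index operations on the patch dimension; the list-based version here only makes the data flow explicit.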

Computation Efficiency. Compared with the original block $\mathcal{B}_{l}(\cdot)$, the pruned block $\widehat{\mathcal{B}}_{l}(\cdot)$ can save a large amount of computational cost. Given a block $\mathcal{B}(\cdot)$ with $N$ patches and $d$-dimensional embeddings, the computational costs of the MLP (2 layers with hidden dimension $d^{\prime}$) and the MSA are $(2Ndd^{\prime})$ and $(2N^{2}d+4Nd^{2})$, respectively. After pruning $\eta\%$ of the patches, every computational component in the MLP is pruned accordingly, so $\eta\%$ of the FLOPs in the MLP module are removed. For the MSA module, the cost of calculating the queries, the attention map and the output projection is reduced, so $\eta\%(2N^{2}d+2Nd^{2})$ FLOPs are removed.
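The cost expressions above translate directly into a small calculator. The sketch below simply evaluates the formulas as stated; the function names are hypothetical, `eta` is given as a fraction rather than a percentage, and the configuration in the usage lines (197 patches, embedding dimension 384, hidden dimension 1536) is an assumed example rather than a value taken from the text.

```python
def block_flops(num_patches, dim, hidden_dim):
    """Costs of one block as stated above: (MLP, MSA)."""
    mlp = 2 * num_patches * dim * hidden_dim
    msa = 2 * num_patches ** 2 * dim + 4 * num_patches * dim ** 2
    return mlp, msa

def saved_flops(num_patches, dim, hidden_dim, eta):
    """FLOPs removed when a fraction `eta` of the patches is pruned: (MLP, MSA)."""
    mlp, _ = block_flops(num_patches, dim, hidden_dim)
    mlp_saved = eta * mlp
    # Only the query, attention-map and output-projection costs shrink.
    msa_saved = eta * (2 * num_patches ** 2 * dim + 2 * num_patches * dim ** 2)
    return mlp_saved, msa_saved

def saved_fraction(num_patches, dim, hidden_dim, eta):
    """Share of the whole block's FLOPs that is removed."""
    mlp, msa = block_flops(num_patches, dim, hidden_dim)
    mlp_saved, msa_saved = saved_flops(num_patches, dim, hidden_dim, eta)
    return (mlp_saved + msa_saved) / (mlp + msa)

# Assumed example configuration; pruning half of the patches in one block.
example_fraction = saved_fraction(197, 384, 1536, 0.5)
```

Because the key and value projections still act on all input patches in this accounting, the saved share of a block is somewhat smaller than `eta` itself.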

Excavating Redundancy via Inverse Pruning

In this section, we present the top-down framework to prune patches in the vision transformer, and provide an effective importance score estimation of each patch.

For patch slimming in vision transformers, we adopt a top-down manner to prune patches layer-by-layer. This is a natural choice for the two reasons described below.

For a CNN model, pruning channels in different layers independently can achieve high performance. However, this paradigm does not work well in vision transformers. The main reason is that patches in different layers of a vision transformer have a one-to-one correspondence. Figure 2 compares pruning channels in CNNs and pruning patches in vision transformers. As shown in Figure 2(a), channels in adjacent layers of a CNN model are fully connected by learnable weights, and each channel contains information from the entire image. However, in the vision transformer (Figure 2(b)), different patches communicate with each other through an attention map, which reflects the similarity between different patches. If patch $i$ and patch $j$ are more similar, the corresponding value $A_{lh}^{ij}$ tends to be larger. The diagonal elements $A_{lh}^{ii}$ usually play a dominant role, that is, a patch pays the highest attention to the input at its own position. Besides, the shortcut connection directly copies the feature in the $l$-th layer to the corresponding patches in the next layer. This one-to-one correspondence inspires us to preserve some important patches in the same spatial locations of different layers, which can guarantee the information propagation across layers.

Another characteristic of vision transformer is that deeper layers tend to have more redundant patches. The attention mechanism in the MSA module aggregates different patches layer-by-layer, and a large number of similar patches are produced in the process (as shown in Figure 3). It implies that more redundant patches can be safely removed in deeper layers, and fewer in shallower layers.

Based on the above analysis, we start the pruning procedure from the output layer, and then prune the preceding layers by transmitting the selected effective patches from top to bottom. Specifically, all the patches preserved in the $(l+1)$-th layer will also be preserved in the $l$-th layer. Thus, this top-down pruning procedure guarantees that shallow layers maintain more patches than deep layers, which is consistent with the redundancy characteristic of vision transformers.
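The inheritance rule above (every patch kept in a deeper layer is also kept in the layer before it) can be sketched as follows. This is a minimal illustration: the layer-specific selections are supplied by hand here, and the function name and toy masks are hypothetical.

```python
def build_nested_masks(last_layer_mask, extra_selections):
    """Build per-layer masks from the deepest layer to the shallowest.

    `extra_selections[k]` marks the additional patches chosen for the layer
    that lies k+1 layers before the last one. Each layer keeps the union of
    its own selection and everything preserved by the next (deeper) layer.
    Returns the masks ordered from shallow to deep.
    """
    masks = [list(last_layer_mask)]
    for selection in extra_selections:
        deeper = masks[0]
        current = [1 if d == 1 or s == 1 else 0 for d, s in zip(deeper, selection)]
        masks.insert(0, current)
    return masks

# Toy example with 4 patches and 3 layers (values are made up).
toy_masks = build_nested_masks(
    last_layer_mask=[1, 0, 0, 0],
    extra_selections=[[0, 1, 0, 0], [0, 0, 0, 1]],
)
kept_per_layer = [sum(m) for m in toy_masks]
```

By construction the number of preserved patches never increases with depth, which matches the statement that shallow layers maintain at least as many patches as deep layers.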

Impact Estimation

With the patch pruning scheme described in the above section, all that remains is to recognize redundant patches in a vision transformer, i.e., to find the optimal mask ${\bm{m}}_{l}$ in each layer. Our goal is to prune as many patches as possible to realize maximal acceleration, while maintaining the representation ability of the output feature. In fact, only a part of the patch embeddings in the last layer are used to predict the labels of input images for a specific task. For example, in the image classification task, only the patch related to classification (i.e., the class token) is sent to the classifier for predicting labels. Other patches in the output layer can be removed safely without affecting the network output. Supposing the first patch is the class token, we can get the mask in the last layer, i.e., ${\bm{m}}_{L,1}=1$ and ${\bm{m}}_{L,i}=0,\ \forall\, i=2,3,\cdots,N$. Then for the other layers, the optimization objective is formulated as follows:

The attention mechanism aggregates information from different patches into one patch, which is the main cause of redundant patch features. To focus on the attention layer for excavating redundant patches, we rewrite the definition of a block $\mathcal{B}(\cdot)$ in a simpler form. Denoting $P_{l}^{h}={\rm softmax}\left({Q^{h}_{l}{K^{h}_{l}}^{\top}}/{\sqrt{d}}\right)$, the MSA module in Eq. 1 can be formulated as:

where $\mathcal{O}(\cdot,W_{l})$ is composed of multiple linear projection matrices $\{W_{l}\}$ in the MSA and MLP modules, as well as non-linear activation functions (e.g., GeLU).

Based on the simplified formulation of a block (Eq. 5), we here explore how a patch in the $t$-th layer affects the error $\mathcal{E}_{L}$ (defined in the optimization objective) of effective patches in the last layer. We reverse the transformer and prune it from the last layer to the first layer sequentially. Thus, when it comes to the $t$-th layer, all the deeper layers have already been pruned. To approximate the significance of each token, we have the following theorem.

where $A_{t}^{h}=\prod_{l=t+1}^{L}{\rm diag}({\bm{m}}_{l})P_{l}^{h}$ and $U_{t}^{h}=P_{t}^{h}\left|Z_{t-1}\right|$. $A_{t}^{h}[:,i]$ denotes the $i$-th column of $A_{t}^{h}$ and $U^{h}_{t}[i,:]$ is the $i$-th row of $U_{t}^{h}$. $[H]^{L\sim t+1}$ denotes all the attention heads in the $(t+1)$-th to $L$-th layers.

We use $\widehat{\mathcal{F}}_{l\sim t}(Z_{t-1},\{{\bm{m}}_{t}\}_{t}^{L})\ (l>t)$ to denote the feature of the $l$-th layer in a vision transformer whose layers from the $t$-th layer onward have been pruned, while the preceding layers have not been pruned yet, i.e., $\widehat{\mathcal{F}}_{L\sim t}(Z_{t-1},\{{\bm{m}}_{t}\}_{t}^{L})=\widehat{\mathcal{B}}_{L}\circ\widehat{\mathcal{B}}_{L-1}\circ\cdots\circ\widehat{\mathcal{B}}_{t}(Z_{t-1})$. When pruning the patches in the $t$-th layer, we compare the effective patches of the last layer from two transformers, one with and one without the $t$-th layer pruned. Then the error $\mathcal{E}_{L}$ is calculated as:

The error $\mathcal{E}_{L}$ in the last layer can be represented by the patches in the $(L-1)$-th layer, i.e.,

where $|\cdot|$ is the element-wise absolute value. The inequality above comes from the Lipschitz continuity of the function $\mathcal{O}(\cdot)$, and $C_{L}$ is the Lipschitz constant. Recalling that $\mathcal{O}(\cdot)$ is composed of multiple linear projections and non-linear activation functions, the condition of Lipschitz continuity is satisfied. $\mathcal{E}_{L}$ can be further transmitted to previous layers, and for the $t$-th layer we have

Then we get the importance of each patch, i.e., ${\bm{s}}_{t,i}=\sum_{h\in[H]^{L\sim t+1}}\left\lVert A_{t}^{h}[:,i]\,U_{t}^{h}[i,:]\right\rVert_{F}^{2}$. ∎

For the $i$-th patch in the $t$-th layer, ${\bm{s}}_{t,i}$ reflects its impact on the effective output of the final layer. A larger ${\bm{s}}_{t,i}$ implies that the corresponding patch has a larger impact on the final error, and thus reflects the importance of a patch to the model performance. The calculation of ${\bm{s}}_{t,i}$ involves all the attention maps in the subsequent layers and the input feature of the current layer. Before pruning the current layer, we randomly sample a subset of the training dataset to calculate the significance scores ${\bm{s}}_{t}$, and the average of ${\bm{s}}_{t}$ over these data is adopted. The obtained ${\bm{s}}_{t}$ can be viewed as a real-valued score for the binary mask ${\bm{m}}_{t}$.
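As a rough illustration of how such a score could be evaluated, the sketch below computes ${\bm{s}}_{t,i}$ for a single head per layer, so the sum over head combinations is omitted. Two assumptions are made that the text does not state explicitly: the product defining $A_t$ is ordered with the deepest layer leftmost, and $A_t$ is the identity when no deeper layer is supplied. The code also uses the fact that the squared Frobenius norm of an outer product equals the product of the squared vector norms. All function names and toy values are hypothetical.

```python
def matmul(a, b):
    """Plain matrix product of two list-of-lists matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def mask_rows(mask, matrix):
    """Compute diag(mask) @ matrix by scaling each row with its mask entry."""
    return [[m * v for v in row] for m, row in zip(mask, matrix)]

def importance_scores(deeper_layers, attn_t, feats_prev):
    """Single-head sketch of s_{t,i} = ||A_t[:, i] U_t[i, :]||_F^2.

    deeper_layers: list of (mask, attention) pairs for layers t+1..L,
                   ordered from shallow to deep.
    attn_t:        N x N attention map P_t of the current layer.
    feats_prev:    N x d input feature Z_{t-1} of the current layer.
    """
    n = len(attn_t)
    a = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for mask, attn in deeper_layers:
        a = matmul(mask_rows(mask, attn), a)  # deeper layers multiply on the left
    u = matmul(attn_t, [[abs(v) for v in row] for row in feats_prev])
    scores = []
    for i in range(n):
        col_sq = sum(a[r][i] ** 2 for r in range(n))  # ||A_t[:, i]||^2
        row_sq = sum(v ** 2 for v in u[i])            # ||U_t[i, :]||^2
        scores.append(col_sq * row_sq)
    return scores

def average_scores(per_sample_scores):
    """Average the score vectors obtained on several sampled inputs."""
    count = len(per_sample_scores)
    return [sum(col) / count for col in zip(*per_sample_scores)]

# Toy example: 2 patches, one deeper layer that keeps only the first patch.
toy_scores = importance_scores(
    deeper_layers=[([1, 0], [[0.75, 0.25], [0.5, 0.5]])],
    attn_t=[[1.0, 0.0], [0.0, 1.0]],
    feats_prev=[[1.0, -1.0], [2.0, 0.0]],
)
```

In the toy example the first patch receives the higher score because the preserved output patch attends to it more strongly, even though the second patch has the larger feature magnitude.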

Pruning Procedure

Here we summarize the overall pipeline of the proposed patch slimming method.

We start from the output layer and prune the preceding layers layer-by-layer from top to bottom. Specifically, all the patches preserved in the $(l+1)$-th layer will also be preserved in the $l$-th layer. The other patches are greedily selected according to their impact scores ${\bm{s}}_{l,i}$, wherein patches with larger scores are preserved preferentially. Considering that the reconstruction error $\mathcal{E}_{l+1}$ in the $(l+1)$-th layer is directly affected by the patch selection in the $l$-th layer, we use it to determine whether the $l$-th layer already has enough patches. In practice, we iteratively select $r^{\prime}$ important patches in each step and continue the selection process in the current layer until $\mathcal{E}_{l+1}$ is less than the given tolerance value $\epsilon$. To make $\mathcal{E}_{l+1}$ faithfully reflect the representation ability of the currently preserved patches, we fine-tune the current block $\widehat{\mathcal{B}}_{l}$ for a few epochs after each step of patch selection. Taking the original feature $Z_{l-1}$ in the $(l-1)$-th layer as input and the reconstruction error $\mathcal{E}_{l+1}$ as the objective, the parameters in the current block $\widehat{\mathcal{B}}_{l}$ are optimized. Since the block $\widehat{\mathcal{B}}_{l}$ is a very small model with only one MSA module and one MLP module, the fine-tuning process is very fast. After pruning, the mask ${\bm{m}}_{l}$ is fixed, and the weight parameters in the vision transformer are further fine-tuned to be compatible with the efficient architecture. The procedure of patch slimming for vision transformers is summarized in Algorithm 1.
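The per-layer selection loop described above can be sketched as follows. This is a simplified illustration, not the paper's Algorithm 1: `error_fn` is a hypothetical stand-in for measuring the reconstruction error under a given mask, the per-step block fine-tuning is omitted, and checking the error before the first selection step is an assumption.

```python
def select_patches_for_layer(deeper_mask, scores, step, tolerance, error_fn):
    """Greedy patch selection for one layer.

    deeper_mask: mask of the next (deeper) layer; its patches are inherited.
    scores:      importance score of every patch in the current layer.
    step:        number of patches added per iteration.
    tolerance:   selection stops once error_fn(mask) drops below this value.
    error_fn:    hypothetical callable returning the error for a mask.
    """
    mask = list(deeper_mask)
    # Patches not yet preserved, most important first.
    candidates = sorted(
        (i for i in range(len(mask)) if mask[i] == 0),
        key=lambda i: scores[i],
        reverse=True,
    )
    position = 0
    while error_fn(mask) >= tolerance and position < len(candidates):
        for i in candidates[position:position + step]:
            mask[i] = 1
        position += step
        # In the described procedure, the current block would be fine-tuned
        # here for a few epochs before the error is measured again.
    return mask

# Toy example: the "error" is the share of total score mass that is dropped.
toy_patch_scores = [5.0, 1.0, 3.0, 0.5, 0.5]

def toy_error(mask):
    dropped = sum(s for s, keep in zip(toy_patch_scores, mask) if keep == 0)
    return dropped / sum(toy_patch_scores)

toy_layer_mask = select_patches_for_layer(
    deeper_mask=[1, 0, 0, 0, 0],
    scores=toy_patch_scores,
    step=1,
    tolerance=0.15,
    error_fn=toy_error,
)
```

Running this selection from the last layer backwards, and passing each resulting mask as `deeper_mask` for the preceding layer, would yield nested masks of the kind the top-down procedure requires.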

A Dynamic Variant

Experiments

In this section, we empirically investigate the effectiveness of the proposed patch slimming method for efficient vision transformers (PS-ViT). We evaluate our method on the benchmark ImageNet (ILSVRC2012) dataset, which contains natural images from 1000 classes, including 1.2M training images and 50k validation images. The proposed method is compared with SOTA pruning methods, and we also conduct extensive ablation studies to better understand our method.

We conduct experiments on the standard ViT models (DeiT), an improved variant network T2T-ViT, and the state-of-the-art LV-ViT.

Implementation details. For a fair comparison, we follow the training and testing settings in the original papers, and the patch slimming is implemented based on the official pre-trained models. The global tolerant error is selected from {0.01, 0.02} to get models with different acceleration rates, and the search granularity $r$ is set to 10. We fine-tune the current block for 3 epochs after each iteration of patch selection. After determining the proper patches in each layer, the pruned transformers are fine-tuned following the training strategy in . All the experiments are conducted with PyTorch and MindSpore on NVIDIA V100 GPUs.

Competing methods. We compare our patch slimming with several representative model pruning methods, including CNN channel pruning methods and BERT pruning methods. SCOP is a SOTA network pruning method for reducing the channels of CNNs, and we re-implement it to reduce the patches in vision transformers. PoWER accelerates BERT inference by progressively eliminating word vectors. HVT directly designs efficient vision transformer architectures by progressively reducing the spatial dimensions through pooling operations.

Experimental results. The experimental results are shown in Table 1, where ‘PS-’ and ‘DPS-’ denote the proposed patch pruning method and its dynamic variant, respectively. We evaluate on three versions of DeiT with different model sizes, i.e., DeiT-Ti, DeiT-S, and DeiT-B. Our method achieves clearly higher performance than the existing methods. The SCOP method designed for CNNs achieves poor performance when applied to reducing patches in a vision transformer, implying that simply migrating channel pruning methods does not work well. PoWER has a larger accuracy drop than our method, indicating that a model compression method for NLP models is not optimal for CV models. Compared to the vision transformer structure pruning method VTP, our method investigates a new perspective by pruning patches and achieves higher accuracy with similar FLOPs.

As for the T2T-ViT model, our method reduces the FLOPs by 40.4% with only a small accuracy decrease (0.4%), which is much better than the compared PoWER method. This indicates that patch-level redundancy exists in various vision transformer models and that our method can effectively excavate the redundancy.

We further conduct experiments on a SOTA transformer model, LV-ViT, and show the results in Table 2. The results show that our patch pruning method also works well on LV-ViT, e.g., the dynamic patch slimming reduces the FLOPs of LV-ViT-M from 16.0G to 8.3G while still achieving 83.7% top-1 accuracy. Its performance is also superior to other SOTA models such as Swin transformer.

Ablation Study

We conduct extensive ablation studies on ImageNet to verify the effectiveness of each component in our method. The DeiT-S model on the ImageNet dataset is used as the base model.

The effect of global tolerant error $\epsilon$. The tolerant error $\epsilon$ affects the balance between the computational cost and the accuracy of the pruned model, which is empirically investigated in Figure 6. Increasing $\epsilon$ implies a larger reconstruction error between the features of the pruned DeiT and the original DeiT, while more patches can be pruned to achieve a higher acceleration rate. When the reduction of FLOPs is less than 45%, there is almost no accuracy loss (less than 0.4%), because a large number of patches are redundant.

Learned patch pruning vs. uniform pruning. In our method, the number of patches required in a specific layer is determined automatically via the global tolerant value $\epsilon$. The architecture of the pruned DeiT model is shown in Figure 6. We can see that a pyramid-like architecture is obtained, where most of the patches in deep layers are pruned while more patches are preserved in shallow layers. To validate the superiority of the learned pyramid architecture, we also implement a baseline that uniformly prunes all the layers with a similar pruning rate. We compare the results of the proposed patch slimming method and uniform pruning in Table 3. The accuracy of uniform pruning is only 77.2%, which corresponds to a large accuracy drop (-2.6%).

To better understand the behavior of patch pruning in the vision transformer, we prune patches in a single layer to see how the test accuracy changes. The experiments are conducted with the DeiT-S model on ImageNet.

Redundancy w.r.t. depth. We test the patch redundancy of different layers to verify the motivation of the top-down patch slimming procedure. We prune a single layer and keep the same pruning ratio for different layers. Figure 6 shows the accuracy of the pruned model after pruning patches of a certain layer, and each line denotes pruning patches with a given pruning rate. In deeper layers, more patches can be safely removed without a large impact on the final performance. However, removing a patch in lower layers usually incurs an obvious accuracy drop. The patch redundancy differs greatly across layers and deeper layers have more redundancy, which can be attributed to the fact that the attention mechanism aggregates features from different patches and the patches in deeper layers have fully communicated with each other. This phenomenon is different from channel pruning in CNNs, where lower layers are observed to have more channel-level redundancy (Figure 4 in ).

Effectiveness of impact estimation. We define the scores ${\bm{s}}_{l}$ in Eq. 6 to approximate the significance of a patch by propagating the reconstruction error of effective patches in the output layer. To validate its effectiveness, we compare it with two baseline scores: ‘Random’ denotes removing patches in the layer randomly, and ‘Attn’ approximates the importance of a patch only with the norm of its attention map in the current layer. We compare the three scores by using them to prune patches in different layers. The results are presented in Figure 7, where the $y$-axis is the test accuracy of the pruned models (without fine-tuning). The results show that our impact estimation suffers less accuracy loss than the others at the same pruning rate (e.g., 50%). This implies that our method can effectively identify patches that really contribute to the final prediction.

Conclusion

We propose to accelerate vision transformers by reducing the number of patches required to calculate. Considering that the attention mechanism aggregates different patches layer-by-layer, a top-down framework is developed to excavate the redundant patches. The importance of each patch is also approximated according to its impact on the effective output features. After pruning, a compact vision transformer with a pyramid-like architecture is obtained. Extensive experiments on benchmark datasets validate that the proposed method can effectively reduce the computational cost. In the future, we plan to combine the patch slimming methods with more compression technologies (e.g., weight pruning, model quantization) to explore extremely efficient vision transformers.

Acknowledgment. This work is supported by National Natural Science Foundation of China under Grant No.61876007, Australian Research Council under Project DP210101859 and the University of Sydney SOAR Prize.

References