Linformer: Self-Attention with Linear Complexity
Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, Hao Ma
Introduction
Transformer models (Vaswani et al., 2017) have become ubiquitous for a wide variety of problems in natural language processing (NLP), including translation (Ott et al., 2018), text classification, and question answering, among others (Raffel et al., 2019; Mohamed et al., 2019). Over the last couple of years, the number of parameters in state-of-the-art NLP Transformers has grown drastically, from the original 340 million introduced in BERT-Large to 175 billion in GPT-3 (Brown et al., 2020). Although these large-scale models yield impressive results on a wide variety of tasks, training and deploying such models is slow in practice. For example, the original BERT-Large model (Devlin et al., 2019) takes four days to train on 16 Cloud TPUs, and the recent GPT-3 (Brown et al., 2020) consumed orders of magnitude more petaflops/day to train compared to its predecessor, GPT-2 (Radford et al., 2019). Beyond training, deploying Transformer models to real-world applications is also expensive, usually requiring extensive distillation (Hinton et al., 2015) or compression.
The main efficiency bottleneck in Transformer models is their self-attention mechanism. Here, each token’s representation is updated by attending to all other tokens in the previous layer. This operation is key for retaining long-term information, giving Transformers the edge over recurrent models on long sequences. However, attending to all tokens at each layer incurs a complexity of $O(n^2)$ with respect to sequence length $n$. Thus, in this paper, we seek to answer the question: can Transformer models be optimized to avoid this quadratic operation, or is this operation required to maintain strong performance?
Prior work has proposed several techniques for improving the efficiency of self-attention. One popular technique is introducing sparsity into attention layers (Child et al., 2019; Qiu et al., 2019; Beltagy et al., 2020) by having each token attend to only a subset of tokens in the whole sequence. This reduces the overall complexity of the attention mechanism to $O(n\sqrt{n})$ (Child et al., 2019). However, as shown in Qiu et al. (2019), this approach suffers from a large performance drop with limited efficiency gains, i.e., a 2% drop with only a 20% speed-up. More recently, the Reformer (Kitaev et al., 2020) used locally-sensitive hashing (LSH) to reduce the self-attention complexity to $O(n\log(n))$. However, in practice, the Reformer’s efficiency gains only appear on sequences with length $> 2048$ (Figure 5 in Kitaev et al. (2020)). Furthermore, the Reformer’s multi-round hashing approach actually increases the number of sequential operations, which further undermines its final efficiency gains.
In this work, we introduce a novel approach for tackling the self-attention bottleneck in Transformers. Our approach is inspired by the key observation that self-attention is low rank. More precisely, we show both theoretically and empirically that the stochastic matrix formed by self-attention can be approximated by a low-rank matrix. Empowered by this observation, we introduce a novel mechanism that reduces self-attention to an $O(n)$ operation in both space and time complexity: we decompose the original scaled dot-product attention into multiple smaller attentions through linear projections, such that the combination of these operations forms a low-rank factorization of the original attention. A summary of runtimes for various Transformer architectures, including ours, can be found in Table 1.
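To make the scaling argument concrete, the following minimal sketch counts the dominant number of attention-score entries per head under four simplified cost models: $n^2$ for full attention, $n\sqrt{n}$ for sparse attention, $n\log_2 n$ for LSH attention, and $nk$ for a projected attention with a fixed projected dimension. The function name `attention_cost`, the default `k = 256`, and the omission of all constant factors are illustrative assumptions; these are not measurements and do not reproduce Table 1.

```python
import math

def attention_cost(n: int, k: int = 256) -> dict:
    """Rough count of attention-score entries per head (constants ignored)."""
    return {
        "full": n * n,                      # every token attends to every token
        "sparse": int(n * math.sqrt(n)),    # assumed n * sqrt(n) sparsity pattern
        "lsh": int(n * math.log2(n)),       # assumed n * log(n) hashing cost
        "projected": n * k,                 # n queries against k projected keys
    }

for n in (512, 4096, 65536):
    costs = attention_cost(n)
    ratio = costs["full"] / costs["projected"]
    print(f"n={n}: {costs}, full/projected={ratio:.1f}")
```

Because `k` stays fixed, the ratio between the full and projected counts grows linearly with `n`, which is the sense in which the projected variant is linear in sequence length.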
One predominant application of Transformers, and the one that has seen the most gains, is their use as pretrained language models, whereby models are first pretrained with a language modeling objective on a large corpus, then finetuned on target tasks using supervised data (Devlin et al., 2019; Liu et al., 2019; Lewis et al., 2019). Following Devlin et al. (2019), we pretrain our model on BookCorpus (Zhu et al., 2015) plus English Wikipedia using a masked-language-modeling objective. We observe similar pretraining performance to the standard Transformer model. We then finetune our pretrained models on three tasks from GLUE (Wang et al., 2018) and one sentiment analysis task, IMDB reviews (Maas et al., 2011). On these tasks, we find that our model performs comparably to, or even slightly better than, the standard pretrained Transformer, while observing significant training and inference speedups.
Background and Related Work
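The Transformer is built upon the idea of Multi-Head Self-Attention (MHA), which allows the model to jointly attend to information at different positions from different representation subspaces. MHA is defined as

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)\, W^O, \qquad (1)$$

where $Q, K, V \in \mathbb{R}^{n \times d_m}$ are input embedding matrices, $n$ is the sequence length, $d_m$ is the embedding dimension, and $h$ is the number of heads. Each head is defined as

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) = \underbrace{\text{softmax}\left[\frac{QW_i^Q (KW_i^K)^T}{\sqrt{d_k}}\right]}_{P}\, VW_i^V, \qquad (2)$$

where $W_i^Q, W_i^K \in \mathbb{R}^{d_m \times d_k}$, $W_i^V \in \mathbb{R}^{d_m \times d_v}$ and $W^O \in \mathbb{R}^{h d_v \times d_m}$ are learned matrices, and $d_k$, $d_v$ are the hidden dimensions of the projection subspaces. The self-attention in (2) defines the context mapping matrix $P \in \mathbb{R}^{n \times n}$. Computing $P$ requires multiplying two $n \times d$ matrices, which is $O(n^2)$ in time and space complexity.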
Related Work
There has been much prior literature on improving the efficiency of Transformers, especially the self-attention bottleneck. The most common techniques for model efficiency that can be applied to Transformers (some specific to Transformers, others more general-purpose) include:
Mixed Precision (Micikevicius et al., 2017): Using half-precision or mixed-precision representations of floating points is popular in deep learning, and is also widely used in training Transformers (Ott et al., 2019). This technique can be further improved through Quantization Aware Training (Jacob et al., 2018; Fan et al., 2020), where the weights are quantized during training and the gradients are approximated with the Straight-Through Estimator. This line of work is orthogonal to our approach, and we use mixed-precision training by default.
Knowledge Distillation (Hinton et al., 2015): Knowledge distillation aims to transfer the “knowledge” from a large teacher model to a lightweight student model. The student model is then used during inference. However, this approach has drawbacks: it does not address speeding up the teacher model during training, and moreover, student models usually suffer performance degradation compared to the teacher model. For example, when distilling a 12-layer BERT to a 6-layer BERT, the student model experiences an average 2.5% performance drop on several benchmark tasks (Sanh et al., 2019).
Sparse Attention (Child et al., 2019): This technique improves the efficiency of self-attention by adding sparsity in the context mapping matrix $P$. For example, the Sparse Transformer (Child et al., 2019) only computes $P_{ij}$ around the diagonal of matrix $P$ (instead of all $P_{ij}$). Meanwhile, blockwise self-attention (Qiu et al., 2019) divides $P$ into multiple blocks and only computes $P_{ij}$ within the selected blocks. However, these techniques also suffer a large performance degradation, while having only limited additional speed-up, i.e., a 2% drop with a 20% speed-up.
LSH Attention (Kitaev et al., 2020): Locally-sensitive hashing (LSH) attention utilizes a multi-round hashing scheme when computing dot-product attention, which in theory reduces the self-attention complexity to $O(n\log(n))$. However, in practice, the complexity term has a large constant, and the method is only more efficient than the vanilla Transformer when the sequence length is extremely long.
Improving Optimizer Efficiency: Microbatching (Huang et al., 2019) splits a batch into small microbatches (which can be fit into memory), and then separately runs forward and backward passes on them with gradient accumulation. Gradient checkpointing (Chen et al., 2016) saves memory by only caching activations of a subset of layers. The uncached activations are recomputed during backpropagation from the latest checkpoint. Both techniques trade off time for memory, and do not speed up inference.
As noted above, the most common techniques are limited in how far they reduce both training and inference time and memory consumption. We therefore investigate how to optimize the self-attention layers and introduce our approach next.
Self-Attention is Low Rank
In this section, we demonstrate that the self-attention mechanism, i.e., the context mapping matrix $P$, is low-rank.
We first provide a spectrum analysis of the context mapping matrix $P$. We use two pretrained Transformer models, RoBERTa-base (12-layer stacked Transformer) and RoBERTa-large (24-layer stacked Transformer) (Liu et al., 2019), on two tasks: a masked-language-modeling task on Wiki103 (Merity et al., 2016) and a classification task on IMDB (Maas et al., 2011). In Figure 1 (left), we apply singular value decomposition to $P$ across different layers and different heads of the model, and plot the normalized cumulative singular value averaged over 10k sentences. The results exhibit a clear long-tail spectrum distribution across each layer, head and task. This implies that most of the information of matrix $P$ can be recovered from the first few largest singular values. In Figure 1 (right), we plot a heatmap of the normalized cumulative singular value at the 128-th largest singular value (out of 512). We observe that the spectrum distribution in higher layers is more skewed than in lower layers, meaning that, in higher layers, more information is concentrated in the largest singular values and the rank of $P$ is lower.
Below, we provide a theoretical analysis of the above spectrum results.
where the context mapping matrix $P$ is defined in (2).
Based on the definition of the context mapping matrix $P$, we can write
For more details, refer to the supplementary materials. ∎
Given the low-rank property of the context mapping matrix $P$, one straightforward idea is to use singular value decomposition (SVD) to approximate $P$ with a low-rank matrix $P_{\text{low}}$, as follows

$$P \approx P_{\text{low}} = \sum_{i=1}^{k} \sigma_i u_i v_i^T$$

where $\sigma_i$, $u_i$ and $v_i$ are the $k$ largest singular values and their corresponding singular vectors. Based on the results in Theorem 1 and the Eckart–Young–Mirsky Theorem (Eckart & Young, 1936), one can use $P_{\text{low}}$ to approximate self-attention (2) with $\epsilon$ error and $O(nk)$ time and space complexity. However, this approach requires performing an SVD of each self-attention matrix, which adds additional complexity. Therefore, we propose another approach for low-rank approximation that avoids this added complexity.
Model
In this section, we propose a new self-attention mechanism which allows us to compute the contextual mapping in linear time and memory complexity with respect to sequence length.
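Finally, we compute context embeddings for each $\text{head}_i$ using $\bar{P} \cdot (F_i V W_i^V)$, where $\bar{P}$ is the projected context mapping matrix from (7) and $F_i$ is the value projection matrix. Note that the above operations only require $O(nk)$ time and space complexity. Thus, if we can choose a very small projected dimension $k$, such that $k \ll n$, then we can significantly reduce the time and memory consumption. The following theorem states that, when $k = O(d/\epsilon^2)$ (independent of $n$), one can approximate $P \cdot VW_i^V$ using linear self-attention (7) with $\epsilon$ error.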
by setting $k = 5\log(nd)/(\epsilon^2 - \epsilon^3)$. This result does not utilize the low-rank property of matrix $A$ ($\text{rank}(A) = d$), and the resultant $k$ has a dependency on sequence length $n$. We will further utilize the fact that $\text{rank}(A) = d$ to prove that the choice of $k$ can be constant and independent of sequence length $n$. For more details, refer to the supplementary materials. ∎
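The projected attention described in this section can be sketched in a few lines of dependency-free Python. This is a minimal illustration under stated assumptions, not the authors' implementation: `queries`, `keys` and `values` stand for the already-projected per-head inputs ($QW_i^Q$, $KW_i^K$, $VW_i^V$); the projection matrices `e` and `f` are given shape $k \times n$ so that the products are defined; and their entries are drawn at random here purely for illustration, whereas in the model they would be learned. All function names are hypothetical.

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax_rows(a):
    """Row-wise softmax; subtracting the row max keeps exp() stable."""
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def gaussian(rows, cols, std, rng):
    """Matrix with i.i.d. N(0, std^2) entries."""
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]

def attention(queries, keys, values):
    """Scaled dot-product attention; returns (mapping matrix, context)."""
    d = len(queries[0])
    scores = matmul(queries, transpose(keys))
    scaled = [[s / math.sqrt(d) for s in row] for row in scores]
    p = softmax_rows(scaled)
    return p, matmul(p, values)

def linear_attention(queries, keys, values, e, f):
    """Project keys and values from n rows down to k rows, then attend."""
    return attention(queries, matmul(e, keys), matmul(f, values))

rng = random.Random(0)
n, d, k = 16, 8, 4
queries = gaussian(n, d, 1.0, rng)
keys = gaussian(n, d, 1.0, rng)
values = gaussian(n, d, 1.0, rng)
e = gaussian(k, n, 1.0 / math.sqrt(k), rng)  # assumed random; learned in practice
f = gaussian(k, n, 1.0 / math.sqrt(k), rng)

p, full_out = attention(queries, keys, values)
p_bar, lin_out = linear_attention(queries, keys, values, e, f)
print(len(p), len(p[0]))          # full mapping matrix: n x n
print(len(p_bar), len(p_bar[0]))  # projected mapping matrix: n x k
```

The only difference between the two paths is the size of the mapping matrix: $n \times n$ for full attention versus $n \times k$ after projection, while the output keeps the same $n \times d$ shape.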
In Figure 2 (top right), we plot the inference speed of Linformer and standard Transformer versus sequence length, while holding the total number of tokens fixed. We see that while standard Transformer becomes slower at longer sequence lengths, the Linformer speed remains relatively flat and is significantly faster at long sequences.
Several additional techniques can be introduced on top of Linformer to further optimize for both performance and efficiency:
Parameter sharing between projections: One can share parameters for the linear projection matrices across layers and heads. In particular, we experimented with 3 levels of sharing:
Headwise sharing: for each layer, we share two projection matrices $E$ and $F$ such that $E_i = E$ and $F_i = F$ across all heads $i$.
Key-value sharing: we do headwise sharing, with the additional constraint of sharing the key and value projections. For each layer, we create a single projection matrix $E$ such that $E_i = F_i = E$ for each key-value projection matrix across all heads $i$.
Layerwise sharing: we use a single projection matrix $E$ across all layers, for all heads, and for both key and value.
For example, in a 12-layer, 12-head stacked Transformer model, headwise sharing, key-value sharing and layerwise sharing will introduce 24, 12, and 1 distinct linear projection matrices, respectively.
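The counts in this example follow directly from the three sharing rules, as the small sketch below shows. The function name `num_projection_matrices` and the `"none"` case (one key projection and one value projection per head per layer) are illustrative assumptions added for comparison.

```python
def num_projection_matrices(layers: int, heads: int, sharing: str) -> int:
    """Count distinct projection matrices under each sharing scheme."""
    if sharing == "none":
        return 2 * layers * heads   # separate key and value projection per head
    if sharing == "headwise":
        return 2 * layers           # one key and one value projection per layer
    if sharing == "key-value":
        return layers               # a single shared projection per layer
    if sharing == "layerwise":
        return 1                    # a single projection for the whole model
    raise ValueError(f"unknown sharing scheme: {sharing}")

for scheme in ("none", "headwise", "key-value", "layerwise"):
    print(scheme, num_projection_matrices(layers=12, heads=12, sharing=scheme))
```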
Nonuniform projected dimension: One can choose a different projected dimension for different heads and layers. As shown in Figure 1 (right), the contextual mapping matrices in different heads and layers have distinct spectrum distributions, and heads in higher layers tend towards a more skewed spectrum (lower rank). This implies one can choose a smaller projected dimension $k$ for higher layers.
General projections: One can also choose different kinds of low-dimensional projection methods instead of a simple linear projection. For example, one can choose mean/max pooling, or convolution where the kernel and stride are set to $n/k$. The convolutional functions contain parameters that require training.
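As an illustration of the pooling alternative, the sketch below mean-pools an $n \times d$ sequence of key (or value) vectors down to $k$ rows, assuming the pooling window and stride both equal $n/k$ and that $k$ divides $n$. The function name `mean_pool` is hypothetical, and this parameter-free variant is only one of the options mentioned above.

```python
def mean_pool(seq, k):
    """Mean-pool an n x d sequence to k x d with kernel = stride = n // k."""
    n = len(seq)
    if n % k != 0:
        raise ValueError("this sketch assumes k divides n")
    w = n // k
    return [
        [sum(col) / w for col in zip(*seq[i:i + w])]  # average each window column-wise
        for i in range(0, n, w)
    ]

keys = [[float(i), float(2 * i)] for i in range(8)]  # n = 8, d = 2
pooled = mean_pool(keys, 4)                          # k = 4, window of 2 rows
print(pooled)  # [[0.5, 1.0], [2.5, 5.0], [4.5, 9.0], [6.5, 13.0]]
```

Unlike a learned linear projection or a convolution, this pooling adds no trainable parameters.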
Experiments
In this section, we present experimental results for the techniques described above. We analyze the techniques one-by-one and explore how they impact performance.
We first compare the pretraining performance of our proposed architecture against RoBERTa (Liu et al., 2019), which is based on the Transformer. Following Devlin et al. (2019), we use BookCorpus (Zhu et al., 2015) plus English Wikipedia as our pretraining set (3300M words). All models are pretrained with the masked-language-modeling (MLM) objective, and the training for all experiments is parallelized across 64 Tesla V100 GPUs with 250k updates.
Effect of projected dimension: We experiment with various values for the projected dimension $k$. (We use the same $k$ across all layers and heads of Linformer.) In Figure 3(a) and (b), we plot the validation perplexity curves for both the standard Transformer and the Linformer across different $k$, for maximum sequence lengths $n = 512$ and $n = 1024$. As expected, the Linformer performs better as the projected dimension $k$ increases. However, even at $k = 128$ for $n = 512$ and $k = 256$ for $n = 1024$, Linformer’s performance is already nearly on par with the original Transformer.
Effect of sharing projections: In Figure 3(c), we plot the validation perplexity curves for the three parameter sharing strategies (headwise, key-value, and layerwise) with $n = 512$. Note that when we use just a single projection matrix (i.e., for layerwise sharing), the resulting Linformer model’s validation perplexity almost matches that of the non-shared model. This suggests that we can decrease the number of additional parameters in our model, and consequently its memory consumption, without much detriment to performance.
Effect of longer sequences: We evaluate the effect of sequence length during Linformer pretraining. In Figure 3(d), we plot the validation perplexity for Linformer with $n \in \{512, 1024, 2048, 4096\}$, holding the projected dimension $k$ fixed at $256$. Note that as sequence length increases, even though our projected dimension is fixed, the final perplexities after convergence remain about the same. This further empirically supports our assertion that the Linformer is linear-time.
Downstream Results
Thus far, we have only examined the pretraining perplexities of our model. However, we wish to show that our conclusions hold after finetuning on downstream tasks. We finetune our Linformer on IMDB (Maas et al., 2011) and SST-2 (Socher et al., 2013) (sentiment classification), as well as QNLI (natural language inference) (Rajpurkar et al., 2016), and QQP (textual similarity) (Chen et al., 2018). We do the same with RoBERTa, 12-layer BERT-base and 6-layer distilled BERT. All of our models, including the Transformer baselines, were pretrained with the same objective, pretraining corpus, and up to 250k updates (although our Linformer takes much less wall-clock time to get to 250k updates, and was consequently trained for less time). Results are listed in Table 2.
We observe that the Linformer model ($n = 512$, $k = 128$) has comparable downstream performance to the RoBERTa model, and in fact even slightly outperforms it at $k = 256$. Moreover, we note that although the Linformer’s layerwise sharing strategy shares a single projection matrix across the entire model, it actually exhibits the best accuracy of all three parameter sharing strategies. Furthermore, the Linformer pretrained with a longer sequence length ($n = 1024$, $k = 256$) has similar results to the one pretrained with a shorter length ($n = 512$, $k = 256$). This empirically supports the notion that the performance of the Linformer model is mainly determined by the projected dimension $k$ instead of the ratio $n/k$.
Inference-time Efficiency Results
In Table 3, we report the inference efficiencies of Linformer (with layerwise sharing) against a standard Transformer. We benchmark both models’ inference speed and memory on a 16GB Tesla V100 GPU card. We randomly generate data up to some sequence length $n$ and perform a full forward pass on multiple batches. We choose the batch size as the maximum batch size that can fit in memory, and our memory savings are computed based on this number.
From Table 3, we see that even with $n = 512$ and $k = 128$, Linformer has faster inference time and allows for a larger maximum batch size than the Transformer. As sequence length increases, the inference-time speed-up and memory savings are even more dramatic. We also plot inference times of both Linformer and Transformer on 100 data samples in the top right of Figure 2.
Conclusion
Transformer models are notoriously slow to train and deploy in practice since their self-attention operations have $O(n^2)$ time and space complexity with respect to sequence length $n$. In this paper, we demonstrate, both theoretically and empirically, that the stochastic matrix formed by the self-attention mechanism is low-rank. We further leverage this observation to propose a new, highly efficient self-attention mechanism. Through a combination of theoretical and empirical analysis, we demonstrate that our proposed approach is $O(n)$ with respect to sequence length.
Broader Impact
Our work focuses on making Transformers more efficient by introducing a mechanism that reduces self-attention to linear-time complexity. Potential positive impacts of efficient Transformers include increasing the accessibility of our models, both for deployment on devices and during training for research purposes. It also has a potential impact on training Transformers on images, since the mechanism supports very long sequences. Furthermore, there are positive environmental benefits associated with decreasing the power consumption of models. As such, we see no immediate negative ethical or societal impacts of our work beyond what applies to other core building blocks of deep learning.
References
Appendix A Proof of Theorem 1
The main proof idea is based on the distributional Johnson–Lindenstrauss lemma (Lindenstrauss, 1984) (JL, for short); the following version is from Arriaga & Vempala (2006).
Based on the definition of the contextual mapping matrix $P$, we have
where $D_A$ is an $n \times n$ diagonal matrix such that
Here we provide a constructive proof. Given any approximation error $\epsilon > 0$, define the following matrix.
In the above, step (a) is based on the union bound, and step (b) uses the result of the JL lemma. Let $k = 5\log(n)/(\epsilon^2 - \epsilon^3)$; then the theorem follows. ∎
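The norm-preservation property that the JL lemma provides can be checked numerically with a short, dependency-free sketch. It assumes a random $k \times n$ projection with i.i.d. $N(0, 1/k)$ entries, which is the usual Gaussian construction; the sizes `n = 1000` and `k = 500`, the seed, and the function names are illustrative choices, and a single random draw only illustrates the statement rather than proving it.

```python
import math
import random

def project(x, k, rng):
    """Apply a random k x n Gaussian matrix with N(0, 1/k) entries to x."""
    std = 1.0 / math.sqrt(k)
    # Each output coordinate is the dot product of x with a fresh random row.
    return [sum(rng.gauss(0.0, std) * xi for xi in x) for _ in range(k)]

def sq_norm(x):
    """Squared Euclidean norm."""
    return sum(v * v for v in x)

rng = random.Random(0)
n, k = 1000, 500
x = [rng.uniform(-1.0, 1.0) for _ in range(n)]
ratio = sq_norm(project(x, k, rng)) / sq_norm(x)
print(f"squared-norm ratio after projection: {ratio:.3f}")
```

With this scaling the expected ratio is 1, and its spread shrinks as `k` grows, which is why a projected dimension that depends only logarithmically on the number of vectors suffices.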
Appendix B Proof of Theorem 2
Based on the triangle inequality, we have
In the above, step (a) is based on the Cauchy inequality and the JL lemma in (11). Step (b) uses the fact that the exponential function is Lipschitz continuous in a compact region. Then we can choose a small enough , i.e., such that
The step (c) is based on the JL Lemma defined in (12).
Applying the result in (21) to every row vector of matrix $A$ and every column vector of matrix $V$, one can directly prove that, for any row vector $A_i$ of matrix $A$,
by setting $k = 5\log(nd)/(\epsilon^2 - \epsilon^3)$. This result does not utilize the low-rank property of matrix $A$ ($\text{rank}(A) = d$), and the resultant $k$ has a dependency on sequence length $n$. We will further prove that the choice of $k$ can be constant and independent of sequence length $n$.
We have that, for any row vector of matrix , .
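In the above, step (a) uses the inequality $\|Ax\| \le \|A\|_2 \cdot \|x\|$, where $\|A\|_2 = \sqrt{\lambda_{\max}(A^T A)}$ ($\lambda_{\max}(\cdot)$ is the largest eigenvalue) is the spectral norm of a matrix $A$. Step (b) is based on the matrix norm inequality $\|A\|_2 \le \|A\|_F$, where $\|A\|_F$ is the Frobenius norm of matrix $A$. Step (c) is based on the results of (24). ∎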