Are Sixteen Heads Really Better than One?
Paul Michel, Omer Levy, Graham Neubig
Introduction
Transformers (Vaswani et al., 2017) have shown state of the art performance across a variety of NLP tasks, including, but not limited to, machine translation (Vaswani et al., 2017; Ott et al., 2018), question answering (Devlin et al., 2018), text classification (Radford et al., 2018), and semantic role labeling (Strubell et al., 2018). Central to its architectural improvements, the Transformer extends the standard attention mechanism (Bahdanau et al., 2015; Cho et al., 2014) via multi-headed attention (MHA), where attention is computed independently by parallel attention mechanisms (heads). It has been shown that beyond improving performance, MHA can help with subject-verb agreement (Tang et al., 2018) and that some heads are predictive of dependency structures (Raganato and Tiedemann, 2018). Since then, several extensions to the general methodology have been proposed (Ahmed et al., 2017; Shen et al., 2018).
However, it is still not entirely clear: what do the multiple heads in these models buy us? In this paper, we make the surprising observation that – in both Transformer-based models for machine translation and BERT-based (Devlin et al., 2018) natural language inference – most attention heads can be individually removed after training without any significant downside in terms of test performance (§3.2). Remarkably, many attention layers can even be individually reduced to a single attention head without impacting test performance (§3.3).
Based on this observation, we further propose a simple algorithm that greedily and iteratively prunes away attention heads that seem to be contributing less to the model. By jointly removing attention heads from the entire network, without restricting pruning to a single layer, we find that large parts of the network can be removed with little to no consequences, but that the majority of heads must remain to avoid catastrophic drops in performance (§4). We further find that this has significant benefits for inference-time efficiency, resulting in up to a 17.5% increase in inference speed for a BERT-based model.
We then delve into further analysis. A closer look at the case of machine translation reveals that the encoder-decoder attention layers are particularly sensitive to pruning, much more than the self-attention layers, suggesting that multi-headedness plays a critical role in this component (§5). Finally, we provide evidence that the distinction between important and unimportant heads increases as training progresses, suggesting an interaction between multi-headedness and training dynamics (§6).
Background: Attention, Multi-headed Attention, and Masking
In this section we lay out the notational groundwork regarding attention, and also describe our method for masking out attention heads.
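2.1 Single-headed Attention
Consider a sequence of $n$ $d$-dimensional vectors $x_1, \ldots, x_n \in \mathbb{R}^d$ and a query vector $q \in \mathbb{R}^d$. The attention layer parameterized by $W_k, W_q, W_v, W_o \in \mathbb{R}^{d \times d}$ computes the weighted sum
$$\mathrm{Att}_{W_k,W_q,W_v,W_o}(x, q) = W_o \sum_{i=1}^{n} \alpha_i W_v x_i, \qquad \text{where } \alpha_i = \mathrm{softmax}\left(\frac{q^\top W_q^\top W_k x_i}{\sqrt{d}}\right).$$
In self-attention, every $x_i$ is used as the query $q$ to compute a new sequence of representations, whereas in sequence-to-sequence models $q$ is typically a decoder state while $x$ corresponds to the encoder output.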
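To make the single-head computation concrete, here is a minimal plain-Python sketch, assuming matrices stored as lists of rows. The helper names (`matvec`, `softmax`, `attention`) are ours and hypothetical, not taken from the authors' code. The score for input $x_i$ is the dot product between the projected query $W_q q$ and the projected key $W_k x_i$, scaled by the square root of the key dimension.

```python
import math

def matvec(W, v):
    """Multiply a matrix W (a list of rows) by a vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(xs, q, Wk, Wq, Wv, Wo):
    """Att(x, q) = W_o * sum_i alpha_i W_v x_i,
    with alpha = softmax_i((W_q q) . (W_k x_i) / sqrt(d_k))."""
    query = matvec(Wq, q)
    d_k = len(query)  # equals d when W_q is d x d, as in a single full-size head
    scores = [sum(a * b for a, b in zip(query, matvec(Wk, x))) / math.sqrt(d_k) for x in xs]
    alphas = softmax(scores)
    values = [matvec(Wv, x) for x in xs]
    mixed = [sum(a * v[j] for a, v in zip(alphas, values)) for j in range(len(values[0]))]
    return matvec(Wo, mixed)

# Self-attention on a toy sequence: each x_i serves in turn as the query q.
I2 = [[1.0, 0.0], [0.0, 1.0]]  # identity weights (illustrative assumption)
xs = [[1.0, 0.0], [0.0, 1.0]]
outputs = [attention(xs, x, I2, I2, I2, I2) for x in xs]
```

With identity weights, each output is simply the attention-weighted average of the inputs, so its entries sum to one and lean toward the position that matches the query.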
2.2 Multi-headed Attention
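In multi-headed attention (MHA), $N_h$ independently parameterized attention layers (heads) are applied in parallel, and their outputs are summed to obtain the final result:
$$\mathrm{MHAtt}(x, q) = \sum_{h=1}^{N_h} \mathrm{Att}_{W_k^h,W_q^h,W_v^h,W_o^h}(x, q) \tag{1}$$
where $W_k^h, W_q^h, W_v^h \in \mathbb{R}^{d_h \times d}$ and $W_o^h \in \mathbb{R}^{d \times d_h}$. To keep the number of parameters constant, $d_h$ is typically set to $d/N_h$. In the following, we use $\mathrm{Att}_h(x)$ as shorthand for the output of head $h$ on input $x$.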
To allow the different attention heads to interact with each other, transformers apply a non-linear feed-forward network over the MHA’s output, at each transformer layer (Vaswani et al., 2017).
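Writing MHA as a sum over heads is equivalent to the common implementation that concatenates the head outputs and applies a single output projection, provided the full output matrix is the column-wise concatenation of the per-head $W_o^h$. The check below is a sketch under that assumption, with random toy values; `z_h` stands for head $h$'s weighted value sum before its output projection, and all function names are hypothetical.

```python
import random

def matvec(W, v):
    """Multiply a matrix W (a list of rows) by a vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def concat_then_project(head_outputs, W_O):
    """Concatenate the N_h head vectors z_h (each of size d_h), then apply W_O (d x N_h*d_h)."""
    concat = [value for z in head_outputs for value in z]
    return matvec(W_O, concat)

def sum_of_heads(head_outputs, W_O, d_h):
    """Sum-over-heads view: W_o^h is the h-th block of d_h columns of W_O."""
    total = [0.0] * len(W_O)
    for h, z in enumerate(head_outputs):
        W_o_h = [row[h * d_h:(h + 1) * d_h] for row in W_O]  # d x d_h block for head h
        total = [t + y for t, y in zip(total, matvec(W_o_h, z))]
    return total

rng = random.Random(0)
d, num_heads = 4, 2
d_h = d // num_heads  # d_h = d / N_h keeps MHA's parameter count equal to one full-size head's
head_outputs = [[rng.uniform(-1, 1) for _ in range(d_h)] for _ in range(num_heads)]
W_O = [[rng.uniform(-1, 1) for _ in range(num_heads * d_h)] for _ in range(d)]
a = concat_then_project(head_outputs, W_O)
b = sum_of_heads(head_outputs, W_O, d_h)
```

Because the two forms agree, masking or pruning a head in the summed formulation corresponds to zeroing or deleting its block of columns in the concatenated one.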
2.3 Masking Attention Heads
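In order to perform ablation experiments on the heads, we modify the formula for MHAtt:
$$\mathrm{MHAtt}(x, q) = \sum_{h=1}^{N_h} \xi_h \, \mathrm{Att}_{W_k^h,W_q^h,W_v^h,W_o^h}(x, q)$$
where the $\xi_h$ are mask variables with values in $\{0, 1\}$. When all $\xi_h$ are equal to $1$, this is equivalent to the formulation in Equation 1. In order to mask head $h$, we simply set $\xi_h = 0$.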
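A minimal sketch of the masked sum, assuming the per-head outputs $\mathrm{Att}_h(x, q)$ have already been computed. The function name `masked_mhatt` and the toy numbers are hypothetical.

```python
def masked_mhatt(head_outputs, xi):
    """Masked MHA: sum_h xi_h * Att_h(x, q), where every xi_h must be 0 or 1."""
    if len(xi) != len(head_outputs) or any(m not in (0, 1) for m in xi):
        raise ValueError("need one mask value in {0, 1} per head")
    dim = len(head_outputs[0])
    return [sum(m * out[j] for m, out in zip(xi, head_outputs)) for j in range(dim)]

# Made-up per-head outputs Att_h(x, q) for a 3-head layer.
heads = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
full = masked_mhatt(heads, [1, 1, 1])            # all xi_h = 1: ordinary MHA
without_head_1 = masked_mhatt(heads, [1, 0, 1])  # xi_1 = 0 removes head 1's contribution
```

Masking leaves the parameters in place and only zeroes a head's contribution, which is what makes it convenient for ablations.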
Are All Attention Heads Important?
We perform a series of experiments in which we remove one or more attention heads from a given architecture at test time, and measure the performance difference. We first remove a single attention head at a time (§3.2) and then remove every head in an entire layer except for one (§3.3).
In all following experiments, we consider two trained models:
WMT
This is the original “large” transformer architecture from Vaswani et al. (2017) with 6 layers and 16 heads per layer, trained on the WMT2014 English to French corpus. We use the pretrained model of Ott et al. (2018) (https://github.com/pytorch/fairseq/tree/master/examples/translation) and report BLEU scores on the newstest2013 test set. In accordance with Ott et al. (2018), we compute BLEU scores on the tokenized output of the model using Moses (Koehn et al., 2007). Statistical significance is tested with paired bootstrap resampling (Koehn, 2004) using compare-mt (https://github.com/neulab/compare-mt; Neubig et al., 2019) with 1000 resamples. A particularity of this model is that it features 3 distinct attention mechanisms: encoder self-attention (Enc-Enc), encoder-decoder attention (Enc-Dec) and decoder self-attention (Dec-Dec), all of which use MHA.
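Paired bootstrap resampling compares two systems on the same resampled test items, which controls for item difficulty. The sketch below is illustrative only: it assumes per-example scores and uses their mean as the metric, whereas compare-mt recomputes corpus-level BLEU on each resample. The function name `paired_bootstrap` is hypothetical.

```python
import random

def paired_bootstrap(scores_a, scores_b, num_samples=1000, seed=0):
    """Fraction of bootstrap resamples in which system A outscores system B.

    scores_a / scores_b are per-example scores of two systems on the SAME items.
    Metric = mean per-example score (assumption; compare-mt recomputes corpus BLEU).
    """
    assert len(scores_a) == len(scores_b) and scores_a
    rng = random.Random(seed)
    n = len(scores_a)
    wins = 0
    for _ in range(num_samples):
        idx = [rng.randrange(n) for _ in range(n)]  # same indices for both systems: "paired"
        mean_a = sum(scores_a[i] for i in idx) / n
        mean_b = sum(scores_b[i] for i in idx) / n
        if mean_a > mean_b:  # ties are not counted as wins
            wins += 1
    return wins / num_samples
```

Under this convention, a win fraction of at least 0.95 would suggest that A is better than B at the p < 0.05 level.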
BERT
BERT (Devlin et al., 2018) is a single transformer pre-trained on an unsupervised cloze-style “masked language modeling task” and then fine-tuned on specific tasks. At the time of its inception, it achieved state-of-the-art performance on a variety of NLP tasks. We use the pre-trained base-uncased model of Devlin et al. (2018), with 12 layers and 12 attention heads per layer, which we fine-tune and evaluate on MultiNLI (Williams et al., 2018). We report accuracies on the “matched” validation set. We test for statistical significance using the t-test. In contrast with the WMT model, BERT only features one attention mechanism (self-attention in each layer).
3.2 Ablating One Head
To understand the contribution of a particular attention head $h$, we evaluate the model’s performance while masking that head (i.e. replacing $\mathrm{Att}_h(x)$ with zeros). If the performance without $h$ is significantly worse than the full model’s, $h$ is clearly important; if the performance is comparable, $h$ is redundant given the rest of the model.
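The single-head ablation loop can be sketched as follows. `evaluate(mask)` is a hypothetical callable that runs the model with `mask[layer][head]` as the value of $\xi_h$ and returns a test score; the toy evaluator below, with made-up per-head costs, only stands in for a real model.

```python
def ablate_one_head(evaluate, num_layers, num_heads):
    """Mask each head in turn; return {(layer, head): score_without_head - full_score}."""
    full_mask = [[1] * num_heads for _ in range(num_layers)]
    full_score = evaluate(full_mask)
    deltas = {}
    for layer in range(num_layers):
        for head in range(num_heads):
            mask = [row[:] for row in full_mask]  # copy so each run masks exactly one head
            mask[layer][head] = 0
            deltas[(layer, head)] = evaluate(mask) - full_score
    return deltas

# Hypothetical stand-in for a real model: each masked head costs a fixed, made-up amount.
COST = [[0, 3, 0], [0, 0, -1]]  # -1: removing this head *improves* the score

def toy_evaluate(mask):
    return 30 - sum(c for mrow, crow in zip(mask, COST) for m, c in zip(mrow, crow) if m == 0)

deltas = ablate_one_head(toy_evaluate, num_layers=2, num_heads=3)
```

In the toy run, most deltas are zero (redundant heads), one is negative (an important head), and one is positive, mirroring the three outcomes described in the text.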
Figures 1(a) and 1(b) show the distribution of heads in terms of the model’s score after masking each head, for WMT and BERT respectively. We observe that the majority of attention heads can be removed without deviating much from the original score. Surprisingly, in some cases removing an attention head results in an increase in BLEU/accuracy.
To get a finer-grained view on these results we zoom in on the encoder self-attention layers of the WMT model in Table 1. Notably, we see that only 8 (out of 96) heads cause a statistically significant change in performance when they are removed from the model, half of which actually result in a higher BLEU score. This leads us to our first observation: at test time, most heads are redundant given the rest of the model.
3.3 Ablating All Heads but One
This observation raises the question: is more than one head even needed? We therefore compute the difference in performance when all heads except one are removed within a single layer. In Table 2 and Table 3 we report the best score for each layer in the model, i.e. the score when reducing the entire layer to its single most important head.
We find that, for most layers, one head is indeed sufficient at test time, even though the network was trained with 16 (WMT) or 12 (BERT) attention heads. This is remarkable because these layers can be reduced to single-headed attention with only 1/16th (resp. 1/12th) of the number of parameters of a vanilla attention layer. However, some layers do require multiple attention heads; for example, substituting the last layer in the encoder-decoder attention of WMT with a single head degrades performance by at least 13.5 BLEU points. We further analyze when different modeling components depend on more heads in §5.
Additionally, we verify that this result holds even when we do not have access to the evaluation set when selecting the head that is “best on its own”. For this purpose, we select the best head for each layer on a validation set (newstest2013 for WMT and a randomly selected 5,000-example subset of the MNLI training set for BERT) and evaluate the model’s performance on a test set (newstest2014 for WMT and the MNLI-matched validation set for BERT). We observe that similar findings hold: keeping only one head does not result in a statistically significant change in performance for 50% (resp. 100%) of layers of WMT (resp. BERT). The detailed results can be found in Appendix A.
3.4 Are Important Heads the Same Across Datasets?
There is a caveat to our two previous experiments: these results are only valid on specific (and rather small) test sets, casting doubt on their generalizability to other datasets. As a first step to understand whether some heads are universally important, we perform the same ablation study on a second, out-of-domain test set. Specifically, we consider the MNLI “mismatched” validation set for BERT and the MTNT English to French test set (Michel and Neubig, 2018) for the WMT model, both of which have been assembled for the very purpose of providing contrastive, out-of-domain test suites for their respective tasks.
We perform the same ablation study as in §3.2 on each of these datasets and report results in Figures 2(a) and 2(b). We notice that there is a positive correlation between the effects of removing a head on the two datasets. Moreover, heads that have the highest effect on performance in one domain tend to have the same effect on the other, which suggests that the most important heads from §3.2 are indeed “universally” important.
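One way to quantify this agreement is to correlate each head's ablation effect (its score delta from §3.2) on the in-domain and out-of-domain sets. The text does not specify which correlation coefficient is used, so Pearson here is our assumption, and the function name is hypothetical.

```python
import math

def pearson(xs, ys):
    """Pearson correlation between two equal-length lists of per-head effects."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)
```

On Python 3.10 and later, `statistics.correlation` computes the same coefficient from the standard library.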
Iterative Pruning of Attention Heads
In our ablation experiments (§3.2 and §3.3), we observed the effect of removing one or more heads within a single layer, without considering what would happen if we altered two or more different layers at the same time. To test the compounding effect of pruning multiple heads from across the entire model, we sort all the attention heads in the model according to a proxy importance score (described below), and then remove the heads one by one. We use this iterative, heuristic approach to avoid combinatorial search, which is impractical given the number of heads and the time it takes to evaluate each model.
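As a proxy score for head importance, we look at the expected sensitivity of the model to the mask variables $\xi_h$ defined in §2.3:
$$I_h = \mathbb{E}_{x \sim X} \left| \frac{\partial \mathcal{L}(x)}{\partial \xi_h} \right| \tag{2}$$
where $X$ is the data distribution and $\mathcal{L}(x)$ the loss on sample $x$. Intuitively, if $I_h$ has a high value then changing $\xi_h$ is liable to have a large effect on the model. In particular, we find the absolute value to be crucial to prevent datapoints with highly negative or positive contributions from nullifying each other in the sum. Plugging Equation 1 into Equation 2 and applying the chain rule yields the following final expression for $I_h$:
$$I_h = \mathbb{E}_{x \sim X} \left| \mathrm{Att}_h(x)^\top \frac{\partial \mathcal{L}(x)}{\partial \mathrm{Att}_h(x)} \right|$$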
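The importance score can be checked on a toy problem where gradients are available in closed form. This sketch assumes a squared-error loss on the masked MHA output $y = \sum_h \xi_h \mathrm{Att}_h(x)$, so that at $\xi = 1$ the gradient with respect to each head's output is $y - t$; the loss, dataset, and function names are illustrative assumptions. In a real model one would obtain the same quantity with automatic differentiation.

```python
def masked_loss(head_outputs, target, xi):
    """Toy loss L = 0.5 * ||sum_h xi_h * Att_h(x) - target||^2 (illustrative assumption)."""
    y = [sum(m * out[j] for m, out in zip(xi, head_outputs)) for j in range(len(target))]
    return 0.5 * sum((yj - tj) ** 2 for yj, tj in zip(y, target))

def head_importance(dataset, num_heads):
    """I_h = mean over examples of |Att_h(x)^T dL/dAtt_h(x)|, evaluated at xi = 1.

    `dataset` holds (head_outputs, target) pairs; head_outputs[h] plays the role of Att_h(x).
    """
    totals = [0.0] * num_heads
    for head_outputs, target in dataset:
        y = [sum(out[j] for out in head_outputs) for j in range(len(target))]  # all xi_h = 1
        grad = [yj - tj for yj, tj in zip(y, target)]  # dL/dy, equal to dL/dAtt_h when xi_h = 1
        for h, out in enumerate(head_outputs):
            totals[h] += abs(sum(o * g for o, g in zip(out, grad)))  # |Att_h^T dL/dAtt_h|
    return [t / len(dataset) for t in totals]

# Two made-up examples whose signed gradients for head 0 are +1 and -1.
dataset = [([[1.0], [0.0]], [0.0]), ([[1.0], [0.0]], [2.0])]
importance = head_importance(dataset, num_heads=2)
```

A finite-difference estimate of $\partial \mathcal{L} / \partial \xi_0$ matches the closed form, and the example shows why the absolute value matters: the signed gradients average to zero, yet head 0 clearly influences the loss.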
This formulation is reminiscent of the wealth of literature on pruning neural networks (LeCun et al., 1990; Hassibi and Stork, 1993; Molchanov et al., 2017, inter alia). In particular, it is equivalent to the Taylor expansion method from Molchanov et al. (2017).
4.2 Effect of Pruning on BLEU/Accuracy
Figures 3(a) (for WMT) and 3(b) (for BERT) describe the effect of attention-head pruning on model performance while incrementally removing 10% of the total number of heads in order of increasing $I_h$ at each step. We also report results when the pruning order is determined by the score difference from §3.2 (dashed lines), but find that using $I_h$ is faster and yields better results.
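The incremental procedure can be sketched as follows, assuming two hypothetical callables: `importance(active)`, which returns an importance score for each head, and `evaluate(active)`, which scores the model with only the `active` heads unmasked. Recomputing the scores on the currently pruned model at every step is our assumption; a fixed ranking computed once would also fit the description. To prune a single attention type, as in §5, one would pass only that type's heads.

```python
import math

def iterative_prune(importance, evaluate, heads, step_fraction=0.1):
    """Remove `step_fraction` of all heads per step, least important (lowest I_h) first.

    Returns a list of (fraction_pruned, score) pairs, starting from the unpruned model.
    """
    active = set(heads)
    per_step = max(1, math.ceil(step_fraction * len(heads)))
    curve = [(0.0, evaluate(active))]
    while active:
        scores = importance(active)  # assumption: scores recomputed at each step
        ranked = sorted(active, key=lambda h: scores[h])  # increasing I_h
        for h in ranked[:per_step]:
            active.discard(h)
        curve.append((1 - len(active) / len(heads), evaluate(active)))
    return curve

# Made-up static importances that also act as each head's contribution to the score.
CONTRIB = {h: c for h, c in enumerate([5, 1, 8, 3, 10, 2, 7, 4, 9, 6])}
curve = iterative_prune(lambda active: CONTRIB,
                        lambda active: sum(CONTRIB[h] for h in active),
                        list(CONTRIB))
```

Pruning in fixed increments keeps the number of model evaluations linear in the number of steps, which avoids the combinatorial search mentioned above.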
We observe that this approach allows us to prune up to and of heads from WMT and BERT (respectively), without incurring any noticeable negative impact. Performance drops sharply when pruning further, meaning that neither model can be reduced to a purely single-head attention model without retraining or incurring substantial losses to performance. We refer to Appendix B for experiments on four additional datasets.
4.3 Effect of Pruning on Efficiency
Beyond downstream task performance, there are intrinsic advantages to pruning heads. First, each head represents a non-negligible proportion of the total parameters in each attention layer (6.25% for WMT, ≈8.33% for BERT), and thus of the total model (roughly speaking, in both our models, approximately one third of the total number of parameters is devoted to MHA across all layers; slightly more in WMT because of the Enc-Dec attention). This is appealing in the context of deploying models in memory-constrained settings.
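The per-head shares follow from $d_h = d/N_h$: one head owns $1/N_h$ of an attention layer's projection weights. The sketch below also estimates MHA's share of a layer's weight matrices. It assumes the standard published sizes for BERT-base ($d = 768$, $d_{ff} = 3072$) and the large Transformer ($d = 1024$, $d_{ff} = 4096$), and it ignores biases, layer norms, and embeddings. The function names are hypothetical.

```python
def head_share(num_heads):
    """Fraction of an attention layer's projection weights owned by one head (d_h = d / N_h)."""
    return 1 / num_heads

def attention_share_of_layer(d_model, d_ff, num_attention_blocks=1):
    """Share of a transformer layer's weight matrices spent on MHA (biases/LayerNorm ignored).

    Each MHA block has W_q, W_k, W_v, W_o (4 * d^2 weights); the feed-forward net has 2 * d * d_ff.
    """
    attn = num_attention_blocks * 4 * d_model ** 2
    ffn = 2 * d_model * d_ff
    return attn / (attn + ffn)

bert_head = head_share(12)                    # about 8.33%
wmt_head = head_share(16)                     # 6.25%
bert_layer = attention_share_of_layer(768, 3072)
wmt_decoder_layer = attention_share_of_layer(1024, 4096, num_attention_blocks=2)
```

An encoder layer with $d_{ff} = 4d$ spends one third of its weights on MHA, while a decoder layer, which adds Enc-Dec attention, spends half. This is consistent with WMT devoting slightly more than one third of its parameters to MHA.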
Moreover, we find that actually pruning the heads (and not just masking) results in an appreciable increase in inference speed. Table 4 reports the number of examples per second processed by BERT, before and after pruning 50% of all attention heads. Experiments were conducted on two different machines, both equipped with GeForce GTX 1080Ti GPUs. Each experiment is repeated 3 times on each machine (for a total of 6 datapoints for each setting). We find that pruning half of the model’s heads speeds up inference by up to 17.5% for higher batch sizes (this difference vanishes for smaller batch sizes).
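Actual pruning differs from masking because it shrinks the weight matrices, so less computation is performed. The sketch below assumes a common fused layout: rows $h d_h$ to $(h+1) d_h - 1$ of the stacked $W_Q, W_K, W_V$ and the matching columns of $W_O$ belong to head $h$. It is not necessarily how the authors' code stores heads. Under that assumption, it checks that deleting a head's rows and columns gives the same output as masking it. All names are hypothetical.

```python
import math
import random

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def fused_mha(xs, q, WQ, WK, WV, WO, num_heads, xi=None):
    """MHA with fused projections: WQ/WK/WV are (N_h*d_h) x d and WO is d x (N_h*d_h).
    Optional masks xi (values in {0, 1}) scale each head's output."""
    d_h = len(WQ) // num_heads
    xi = xi if xi is not None else [1] * num_heads
    query, keys, values = matvec(WQ, q), [matvec(WK, x) for x in xs], [matvec(WV, x) for x in xs]
    concat = []
    for h in range(num_heads):
        lo, hi = h * d_h, (h + 1) * d_h  # assumed: head h owns this slice of the fused rows
        scores = [sum(a * b for a, b in zip(query[lo:hi], k[lo:hi])) / math.sqrt(d_h) for k in keys]
        alphas = softmax(scores)
        concat.extend(xi[h] * sum(a * v[j] for a, v in zip(alphas, values)) for j in range(lo, hi))
    return matvec(WO, concat)

def prune_heads(WQ, WK, WV, WO, num_heads, drop):
    """Physically delete the rows (W_Q, W_K, W_V) and columns (W_O) of the heads in `drop`."""
    d_h = len(WQ) // num_heads
    keep = [i for h in range(num_heads) if h not in drop for i in range(h * d_h, (h + 1) * d_h)]
    WQ_p, WK_p, WV_p = ([W[i] for i in keep] for W in (WQ, WK, WV))
    WO_p = [[row[i] for i in keep] for row in WO]
    return WQ_p, WK_p, WV_p, WO_p, num_heads - len(drop)

rng = random.Random(0)
d, n_heads = 4, 2
rand = lambda r, c: [[rng.uniform(-1, 1) for _ in range(c)] for _ in range(r)]
WQ, WK, WV, WO = rand(d, d), rand(d, d), rand(d, d), rand(d, d)
xs = rand(3, d)
masked = fused_mha(xs, xs[0], WQ, WK, WV, WO, n_heads, xi=[1, 0])
WQp, WKp, WVp, WOp, n_left = prune_heads(WQ, WK, WV, WO, n_heads, drop={1})
pruned = fused_mha(xs, xs[0], WQp, WKp, WVp, WOp, n_left)
```

The pruned matrices are half the size, so every projection in the pruned layer does roughly half the multiply-adds while producing the same output as the masked model.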
When Are More Heads Important? The Case of Machine Translation
As shown in Table 3, not all MHA layers can be reduced to a single attention head without significantly impacting performance. To get a better idea of how much each part of the transformer-based translation model relies on multi-headedness, we repeat the heuristic pruning experiment from §4 for each type of attention separately (Enc-Enc, Enc-Dec, and Dec-Dec).
Figure 4 shows that performance drops much more rapidly when heads are pruned from the Enc-Dec attention layers. In particular, pruning more than of the Enc-Dec attention heads will result in catastrophic performance degradation, while the encoder and decoder self-attention layers can still produce reasonable translations (with BLEU scores around 30) with only 20% of the original attention heads. In other words, encoder-decoder attention is much more dependent on multi-headedness than self-attention.
Dynamics of Head Importance during Training
Previous sections tell us that some heads are more important than others in trained models. To get more insight into the dynamics of head importance during training, we perform the same incremental pruning experiment of §4.2 at every epoch. We perform this experiment on a smaller version of the WMT model (6 layers and 8 heads per layer), trained for German to English translation on the smaller IWSLT 2014 dataset (Cettolo et al., 2015). We refer to this model as IWSLT.
Figure 5 reports, for each level of pruning (in increments of 10%, with 0% corresponding to the original model), the evolution of the model’s score (on newstest2013) across epochs. For better readability we display epochs on a logarithmic scale, and only report scores every 5 epochs after the 10th. To make scores comparable across epochs, the Y axis reports the relative degradation of BLEU score with respect to the un-pruned model at each epoch. Notably, we find that there are two distinct regimes: in very early epochs (especially 1 and 2), performance decreases linearly with the pruning percentage, i.e. the relative decrease in performance is independent of $I_h$, indicating that most heads are more or less equally important. From epoch 10 onwards, there is a concentration of unimportant heads that can be pruned while staying within of the original BLEU score (up to of total heads).
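Reading a per-epoch pruning curve comes down to a ratio: the pruned score divided by that epoch's un-pruned score. The sketch below finds the largest pruning level that keeps at least a given fraction of the original BLEU. Treating the relative score as a ratio is our interpretation, the curve values are invented, and the function name is hypothetical.

```python
def max_prunable_fraction(curve, threshold):
    """Largest pruning level whose score keeps at least `threshold` of the un-pruned score.

    `curve` holds (fraction_pruned, score) pairs for one epoch, including fraction 0.0.
    Levels are scanned in increasing order and the scan stops at the first failure.
    """
    points = sorted(curve)
    baseline = points[0][1]  # score of the un-pruned model at this epoch
    best = 0.0
    for fraction, score in points:
        if score / baseline < threshold:
            break
        best = fraction
    return best

# Invented BLEU scores for one epoch at 0%, 10%, ..., 50% of heads pruned.
epoch_curve = [(0.0, 30.0), (0.1, 29.5), (0.2, 28.9), (0.3, 27.1), (0.4, 25.8), (0.5, 18.0)]
```

Normalizing by each epoch's own un-pruned score is what makes curves from early and late epochs comparable, even though absolute BLEU rises during training.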
This suggests that the important heads are determined early (but not immediately) during the training process. The two phases of training are reminiscent of the analysis by Shwartz-Ziv and Tishby (2017), according to which the training of neural networks decomposes into an “empirical risk minimization” phase, where the model maximizes the mutual information of its intermediate representations with the labels, and a “compression” phase where the mutual information with the input is minimized. A more principled investigation of this phenomenon is left to future work.
Related work
The use of an attention mechanism in NLP and in particular neural machine translation (NMT) can be traced back to Bahdanau et al. (2015) and Cho et al. (2014), and most contemporaneous implementations are based on the formulation from Luong et al. (2015). Attention was soon successfully adapted to other NLP tasks, often achieving then state-of-the-art performance in reading comprehension (Cheng et al., 2016), natural language inference (Parikh et al., 2016) or abstractive summarization (Paulus et al., 2017), to cite a few. Multi-headed attention was first introduced by Vaswani et al. (2017) for NMT and English constituency parsing, and later adopted for transfer learning (Radford et al., 2018; Devlin et al., 2018), language modeling (Dai et al., 2019; Radford et al., 2019), or semantic role labeling (Strubell et al., 2018), among others.
There is a rich literature on pruning trained neural networks, going back to LeCun et al. (1990) and Hassibi and Stork (1993) in the early 90s and reinvigorated after the advent of deep learning, with two orthogonal approaches: fine-grained “weight-by-weight” pruning (Han et al., 2015) and structured pruning (Anwar et al., 2017; Li et al., 2016; Molchanov et al., 2017), wherein entire parts of the model are pruned. In NLP, structured pruning for auto-sizing feed-forward language models was first investigated by Murray and Chiang (2015). More recently, fine-grained pruning approaches have been popularized by See et al. (2016) and Kim and Rush (2016) (mostly on NMT).
Concurrently to our own work, Voita et al. (2019) have made a similar observation on multi-head attention. Their approach involves using LRP (Binder et al., 2016) for determining important heads and looking at specific properties such as attending to adjacent positions, rare words or syntactically related words. They propose an alternate pruning mechanism based on doing gradient descent on the mask variables $\xi_h$. While their approach and results are complementary to this paper, our study provides additional evidence of this phenomenon beyond NMT, as well as an analysis of the training dynamics of pruning attention heads.
Conclusion
We have observed that MHA does not always leverage its theoretically superior expressiveness over vanilla attention to the fullest extent. Specifically, we demonstrated that in a variety of settings, several heads can be removed from trained transformer models without statistically significant degradation in test performance, and that some layers can be reduced to only one head. Additionally, we have shown that in machine translation models, the encoder-decoder attention layers are much more reliant on multi-headedness than the self-attention layers, and provided evidence that the relative importance of each head is determined in the early stages of training. We hope that these observations will advance our understanding of MHA and inspire models that invest their parameters and attention more efficiently.
Acknowledgments
The authors would like to extend their thanks to the anonymous reviewers for their insightful feedback. We are also particularly grateful to Thomas Wolf from Hugging Face, whose independent reproduction efforts allowed us to find and correct a bug in our speed comparison experiments. This research was supported in part by a gift from Facebook.
References
Appendix A Ablating All Heads but One: Additional Experiment.
Tables 5 and 6 report the difference in performance when only one head is kept for any given layer. The head is chosen to be the best head on its own on a separate dataset.
Appendix B Additional Pruning Experiments
We report additional results for the importance-driven pruning approach from Section 4 on 4 additional datasets:
SST-2: The GLUE version of the Stanford Sentiment Treebank (Socher et al., 2013). We use a fine-tuned BERT as our model.
CoLA: The GLUE version of the Corpus of Linguistic Acceptability (Warstadt et al., 2018). We use a fine-tuned BERT as our model.
MRPC: The GLUE version of the Microsoft Research Paraphrase Corpus (Dolan and Brockett, 2005). We use a fine-tuned BERT as our model.
IWSLT: The German to English translation dataset from IWSLT 2014 (Cettolo et al., 2015). We use the same smaller model described in Section 6.
Figure 6 shows that in some cases up to 60% (SST-2) or 50% (CoLA, MRPC) of heads can be pruned without a noticeable impact on performance.