Depth-Adaptive Transformer

Maha Elbayad, Jiatao Gu, Edouard Grave, Michael Auli

Introduction

The size of modern neural sequence models (Gehring et al., 2017; Vaswani et al., 2017; Devlin et al., 2019) can amount to billions of parameters (Radford et al., 2019). For example, the winning entry of the WMT’19 news machine translation task in English-German used an ensemble totaling two billion parameters (Ng et al., 2019). While large models are required to do better on hard examples, small models are likely to perform as well on easy ones, e.g., the aforementioned ensemble is probably not required to translate a short phrase such as "Thank you". However, current models apply the same amount of computation regardless of whether the input is easy or hard.

In this paper, we propose Transformers which adapt the number of layers to each input in order to achieve a good speed-accuracy trade-off at inference time. We extend Adaptive Computation Time (ACT; Graves, 2016), which introduced dynamic computation for recurrent neural networks, in several ways: we apply different layers at each stage, we investigate a range of designs and training targets for the halting module, and we explicitly supervise the halting decision with simple oracles to achieve good performance on large-scale tasks.

Universal Transformers (UT) rely on ACT for dynamic computation and repeatedly apply the same layer (Dehghani et al., 2018). Our work considers a variety of mechanisms to estimate the network depth and applies a different layer at each step. Moreover, Dehghani et al. (2018) fix the number of steps for large-scale machine translation, whereas we vary the number of steps and demonstrate substantial improvements in speed at no loss in accuracy. UT uses a layer which contains as many weights as an entire standard Transformer, and this layer is applied several times, which slows down computation. Our approach does not increase the size of individual layers. We also extend the resource-efficient object classification work of Huang et al. (2017) and Bolukbasi et al. (2017) to structured prediction, where dynamic computation decisions impact future computation. Related work from computer vision includes Teerapittayanon et al. (2016), Figurnov et al. (2017) and Wang et al. (2018), who explored the idea of dynamic routing either by exiting early or by skipping layers.

We encode the input sequence using a standard Transformer encoder and generate the output sequence with a varying amount of computation in the decoder network. Dynamic computation poses a challenge for self-attention because layers omitted at prior time-steps may be required in the future. We experiment with two approaches to address this and show that a simple approach works well (§2). Next, we investigate different mechanisms to control the amount of computation in the decoder network, either for the entire sequence or on a per-token basis. These include multinomial and binomial classifiers, supervised either by the model likelihood or by whether the argmax is already correct, as well as simply thresholding the model score (§3). Experiments on IWSLT’14 German-English translation (Cettolo et al., 2014) as well as WMT’14 English-French translation show that we can match the performance of well-tuned baseline models with up to 76% less computation (§4).

Anytime structured prediction

We first present a model that can make predictions at different layers. This is known as anytime prediction for computer vision models (Huang et al., 2017) and we extend it to structured prediction.

We base our approach on the Transformer sequence-to-sequence model (Vaswani et al., 2017). Both encoder and decoder networks contain $N$ stacked blocks, where each has several sub-blocks surrounded by residual skip-connections. The first sub-block is a multi-head dot-product self-attention and the second a position-wise fully connected feed-forward network. For the decoder, there is an additional sub-block after the self-attention to add source context via another multi-head attention.

Given a pair of source-target sequences $(\boldsymbol{x},\boldsymbol{y})$, $\boldsymbol{x}$ is processed with the encoder to give representations $\boldsymbol{s}=(s_1,\ldots,s_{|\boldsymbol{x}|})$. Next, the decoder generates $\boldsymbol{y}$ step-by-step. For every new token $y_t$ input to the decoder at time $t$, the $N$ decoder blocks process it to yield hidden states $(h_t^n)_{1\leq n\leq N}$:

$$h_t^0 = \operatorname{embed}(y_t), \qquad h_t^n = \operatorname{block}_n\big(h_1^{n-1},\ldots,h_t^{n-1}\big),$$

where $\operatorname{block}_n$ is the mapping associated with the $n^{\text{th}}$ block and $\operatorname{embed}$ is a lookup table.

The output distribution for predicting the next token is computed by feeding the activations of the last decoder layer $h_t^N$ into a softmax-normalized output classifier $W$:

$$p(y_{t+1}\mid h_t^N) = \operatorname{softmax}(W h_t^N).$$

Standard Transformers have a single output classifier attached to the top of the decoder network. However, for dynamic computation we need to be able to make predictions at different stages of the network. To achieve this, we attach output classifiers $\mathscr{C}_n$ parameterized by $W_n$ to the output $h_t^n$ of each of the $N$ decoder blocks:

$$\forall n,\quad p(y_{t+1}\mid h_t^n) = \operatorname{softmax}(W_n h_t^n).$$

The classifiers can be parameterized independently or we can share the weights across the $N$ blocks.
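A minimal sketch of the multi-exit output layer, written in plain Python so it runs without dependencies. It computes $p(y_{t+1}\mid h_t^n)=\operatorname{softmax}(W_n h_t^n)$ for every block, either with one matrix per block or with a single shared matrix. The function names and the list-of-lists representation are illustrative assumptions, not the authors' fairseq code.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def softmax(logits: Sequence[float]) -> Vector:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(W: Matrix, h: Sequence[float]) -> Vector:
    # Each row of W holds the weights of one vocabulary entry.
    return [sum(w * x for w, x in zip(row, h)) for row in W]


def exit_distributions(hidden: Sequence[Vector], classifiers: Sequence[Matrix]) -> List[Vector]:
    """Return softmax(W_n h_t^n) for n = 1..N.

    hidden[n - 1] is h_t^n. classifiers holds either N matrices (independent
    W_n) or a single matrix shared by all blocks (hypothetical interface).
    """
    if len(classifiers) == 1:
        classifiers = list(classifiers) * len(hidden)  # weight sharing
    if len(classifiers) != len(hidden):
        raise ValueError("expected one classifier per block or a single shared one")
    return [softmax(matvec(W, h)) for W, h in zip(classifiers, hidden)]


# Toy example: N = 2 blocks, hidden size 2, vocabulary of 2 types, shared W.
W_shared = [[1.0, 0.0], [0.0, 1.0]]
probs = exit_distributions([[1.0, 0.0], [0.0, 1.0]], [W_shared])
```

Sharing one matrix adds no parameters compared with a single output classifier, while independent $W_n$ let each exit specialize to the features of its block.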

2.2 Training multiple output classifiers

Dynamic computation enables the model to use any of the $N$ exit classifiers instead of just the final one. Some of our models can choose a different output classifier at each time-step, which results in a number of possible output classifier combinations that is exponential in the sequence length ($N^{|\boldsymbol{y}|}$ for an output of length $|\boldsymbol{y}|$).

We consider two possible ways to train the decoder network (Figure 1). Aligned training optimizes all classifiers simultaneously and assumes all previous hidden states required by the self-attention are available. However, at test time this is often not the case when we choose a different exit for every token, which leads to misaligned states. Instead, mixed training samples several sequences of exits for a given sentence and exposes the model to hidden states from different layers.

Generally, for a given output sequence $\boldsymbol{y}$, we have a sequence of chosen exits $(n_1,\ldots,n_{|\boldsymbol{y}|})$ and we denote the block at which we exit at time $t$ as $n_t$.

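2.2.1 Aligned training

Aligned training assumes all hidden states $h_1^{n-1},\ldots,h_t^{n-1}$ are available in order to compute self-attention, and it optimizes $N$ loss terms, one for each exit (Figure 1(a)):

$$\operatorname{LL}_t^n = \log p(y_t\mid h_{t-1}^n), \qquad \operatorname{LL}^n = \sum_{t=1}^{|\boldsymbol{y}|}\operatorname{LL}_t^n, \qquad \mathcal{L}_{dec}(\boldsymbol{x},\boldsymbol{y}) = -\frac{1}{\sum_n \omega_n}\sum_{n=1}^{N}\omega_n\operatorname{LL}^n. \tag{4}$$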

The compound loss $\mathcal{L}_{dec}(\boldsymbol{x},\boldsymbol{y})$ is a weighted average of $N$ terms w.r.t. the weights $(\omega_1,\ldots,\omega_N)$. We found that uniform weights achieve better BLEU than other weighting schemes (cf. Appendix A). At inference time, not all time-steps will have hidden states for the current layer, since the model exited early at those time-steps. In this case, we simply copy the last computed state to all upper layers, similar to mixed training (§2.2.2). However, we do apply layer-specific key and value projections to the copied state.
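The compound loss of aligned training can be sketched as follows, assuming the per-token log-probabilities $\log p(y_t\mid h_{t-1}^n)$ of every exit have already been computed in one forward pass. The helper name and the input layout are hypothetical.

```python
from typing import Optional, Sequence


def aligned_loss(token_logprobs: Sequence[Sequence[float]],
                 weights: Optional[Sequence[float]] = None) -> float:
    """Weighted average over exits of the negative sequence log-likelihood.

    token_logprobs[n - 1][t] = log p(y_t | h_{t-1}^n) for exit n.
    weights = (omega_1, ..., omega_N); uniform if omitted.
    """
    num_exits = len(token_logprobs)
    if weights is None:
        weights = [1.0] * num_exits
    if len(weights) != num_exits:
        raise ValueError("need one weight per exit")
    # Sequence log-likelihood LL^n of each exit: sum over time-steps.
    seq_ll = [sum(lp) for lp in token_logprobs]
    return -sum(w * ll for w, ll in zip(weights, seq_ll)) / sum(weights)


# Two exits, two target tokens.
lp = [[-2.0, -2.0], [-1.0, -1.0]]
uniform = aligned_loss(lp)               # (4 + 2) / 2
weighted = aligned_loss(lp, [1.0, 3.0])  # (4 + 3 * 2) / 4
```

Because all exits read from the same forward pass, the $N$ terms come at the cost of $N$ extra classifier evaluations rather than extra decoder passes.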

2.2.2 Mixed training

Aligned training assumes that all hidden states of the previous time-steps are available, but this assumption is unrealistic since an early exit may have been chosen previously. This creates a mismatch between training and testing. Mixed training reduces the mismatch by training the model to use hidden states from different blocks of previous time-steps for self-attention. We sample $M$ different exit sequences $(n_1^{(m)},\ldots,n_{|\boldsymbol{y}|}^{(m)})_{1\leq m\leq M}$ and evaluate the following loss:

$$\operatorname{LL}(n_1,\ldots,n_{|\boldsymbol{y}|}) = \sum_{t=1}^{|\boldsymbol{y}|}\log p(y_t\mid h_{t-1}^{n_t}), \qquad \mathcal{L}_{dec}(\boldsymbol{x},\boldsymbol{y}) = -\frac{1}{M}\sum_{m=1}^{M}\operatorname{LL}\big(n_1^{(m)},\ldots,n_{|\boldsymbol{y}|}^{(m)}\big).$$

When $n_t<N$, we copy the last evaluated hidden state $h_t^{n_t}$ to the subsequent layers so that the self-attention of future time-steps can function as usual (see Figure 1(b)).
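The copying rule can be made concrete with a small table that records, for each layer $n$ and time-step $t$, which block actually produced the state used as $h_t^n$. The sketch also samples $M$ exit sequences; drawing exits uniformly at random and the function names are assumptions for illustration.

```python
import random
from typing import List, Sequence


def source_block_table(exits: Sequence[int], num_blocks: int) -> List[List[int]]:
    """table[n - 1][t] is the block whose output serves as h_t^n.

    Blocks 1..n_t are computed; for n > n_t the state of block n_t is copied
    upwards so later time-steps can still attend to layer n.
    """
    return [[min(n, n_t) for n_t in exits] for n in range(1, num_blocks + 1)]


def sample_exit_sequences(length: int, num_blocks: int, num_paths: int,
                          rng: random.Random) -> List[List[int]]:
    """Draw M exit sequences (n_1, ..., n_|y|), each exit uniform in 1..N."""
    return [[rng.randint(1, num_blocks) for _ in range(length)]
            for _ in range(num_paths)]


table = source_block_table([1, 3], num_blocks=3)
paths = sample_exit_sequences(length=4, num_blocks=6, num_paths=6, rng=random.Random(0))
```

Here the first token exits at block 1, so layers 2 and 3 reuse its block-1 state, while the second token computes all three blocks.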

Adaptive depth estimation

We present a variety of mechanisms to predict the decoder block at which the model should stop and output the next token, i.e., when it should exit to achieve a good speed-accuracy trade-off. We consider two approaches: sequence-specific depth decodes all output tokens using the same block (§3.1), while token-specific depth determines a separate exit for each individual token (§3.2).

We model the distribution of exiting at time-step $t$ with a parametric distribution $q_t$, where $q_t(n)$ is the probability of computing $\operatorname{block}_1,\ldots,\operatorname{block}_n$ and then emitting a prediction with $\mathscr{C}_n$. The parameters of $q_t$ are optimized to match an oracle distribution $q_t^*$ with cross-entropy:

$$\mathcal{L}_{exit}(\boldsymbol{x},\boldsymbol{y}) = -\sum_{t}\sum_{n=1}^{N} q_t^*(n\mid\boldsymbol{x},\boldsymbol{y})\log q_t(n\mid\boldsymbol{x},\boldsymbol{y}_{<t}). \tag{6}$$

The exit loss ($\mathcal{L}_{exit}$) is back-propagated to the encoder-decoder parameters. We simultaneously optimize the decoding loss (Eq. (4)) and the exit loss (Eq. (6)), balanced by a hyper-parameter $\alpha$, to ensure that the model maintains good generation accuracy. The final loss takes the form:

$$\mathcal{L} = \mathcal{L}_{dec} + \alpha\,\mathcal{L}_{exit}. \tag{7}$$
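A sketch of the cross-entropy exit loss and the combined objective, with the exit distributions $q_t$ and the oracle targets $q_t^*$ given as plain lists (a hypothetical layout). When the oracle is a Dirac delta, as for the likelihood-based oracle, only the chosen exit contributes; skipping zero-mass terms also avoids evaluating $\log 0$.

```python
import math
from typing import Sequence


def exit_loss(q: Sequence[Sequence[float]], q_star: Sequence[Sequence[float]]) -> float:
    """-sum_t sum_n q*_t(n) log q_t(n), with q[t][n - 1] and q_star[t][n - 1]."""
    total = 0.0
    for q_t, target_t in zip(q, q_star):
        for p, target in zip(q_t, target_t):
            if target > 0.0:  # zero oracle mass contributes nothing
                total -= target * math.log(p)
    return total


def total_loss(decoding_loss: float, q, q_star, alpha: float) -> float:
    """L = L_dec + alpha * L_exit; alpha = 0 disables exit prediction."""
    return decoding_loss + alpha * exit_loss(q, q_star)


q = [[0.5, 0.5]]       # predicted exit distribution for one time-step, N = 2
q_star = [[1.0, 0.0]]  # oracle: exit at block 1
```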

In the following, we describe for each approach how the exit distribution $q_t$ is modeled (illustrated in Figure 2) and how the oracle distribution $q_t^*$ is inferred.

3.1 Sequence-specific depth:

For sequence-specific depth, the same exit is used for every output token, and it is predicted by a halting classifier with weights $W_h$ and biases $b_h$. We consider two oracles to determine which of the $N$ blocks should be chosen. The first is based on the sequence likelihood and the second looks at an aggregate of the correctly predicted tokens at each block.
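A sketch of a sequence-level halting classifier: a single linear layer with weights $W_h$ and biases $b_h$, followed by a softmax over the $N$ exits. The text here does not state what the classifier reads; this sketch assumes, for illustration only, the mean of the encoder states $s_1,\ldots,s_{|\boldsymbol{x}|}$.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sequence_exit(encoder_states: Sequence[Sequence[float]],
                  W_h: Sequence[Sequence[float]],
                  b_h: Sequence[float]) -> Tuple[List[float], int]:
    """Return q(n | x) for n = 1..N and the most probable (1-based) exit.

    Assumption: the classifier input is the mean of the encoder states.
    """
    dim = len(encoder_states[0])
    mean = [sum(s[i] for s in encoder_states) / len(encoder_states) for i in range(dim)]
    logits = [sum(w * x for w, x in zip(row, mean)) + b for row, b in zip(W_h, b_h)]
    q = softmax(logits)
    best = max(range(len(q)), key=q.__getitem__) + 1
    return q, best


q, n_hat = sequence_exit([[1.0, 0.0], [3.0, 0.0]],
                         W_h=[[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]],
                         b_h=[0.0, 0.0, 0.0])
```

Every token of the hypothesis would then be decoded with the classifier of block `n_hat`.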

Likelihood-based: This oracle is based on the likelihood of the entire sequence after each block, and we optimize the halting classifier towards a Dirac delta centered on the exit with the highest sequence likelihood.

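We add a regularization term to encourage lower exits that achieve good likelihood:

$$q^*(\boldsymbol{x},\boldsymbol{y}) = \delta\Big(\arg\max_{n}\,\big(\operatorname{LL}^n - \lambda n\big)\Big). \tag{9}$$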

Correctness-based: Likelihood ignores whether the model already assigns the highest score to the correct target. Instead, this oracle chooses the lowest block that assigns the largest score to the correct prediction. For each block, we count the number of correctly predicted tokens over the sequence and choose the block with the largest number of correct tokens. A regularization term controls the trade-off between speed and accuracy:

$$C^n = \#\big\{t \mid y_t = \arg\max_{y} p(y\mid h_{t-1}^n)\big\}, \qquad q^*(\boldsymbol{x},\boldsymbol{y}) = \delta\Big(\arg\max_{n}\,\big(C^n - \lambda n\big)\Big).$$
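Both sequence-level oracles reduce to an argmax over blocks of a per-block score minus the penalty $\lambda n$. The sketch takes precomputed per-token log-probabilities (likelihood oracle) or 0/1 correctness indicators (correctness oracle); the helper names are hypothetical. Python's `max` returns the first maximizer, so ties go to the lowest block.

```python
from typing import Sequence


def best_block(scores: Sequence[float], lam: float) -> int:
    """1-based argmax_n (scores[n - 1] - lam * n); ties resolve to the lowest n."""
    penalized = [s - lam * n for n, s in enumerate(scores, start=1)]
    return max(range(len(penalized)), key=penalized.__getitem__) + 1


def sequence_ll_oracle(token_logprobs: Sequence[Sequence[float]], lam: float) -> int:
    """token_logprobs[n - 1][t] = log p(y_t | h_{t-1}^n); scores LL^n = sum over t."""
    return best_block([sum(lp) for lp in token_logprobs], lam)


def sequence_correctness_oracle(correct: Sequence[Sequence[bool]], lam: float) -> int:
    """correct[n - 1][t] is True when block n's argmax equals y_t; scores C^n."""
    return best_block([sum(c) for c in correct], lam)


lp = [[-3.0, -3.0], [-1.0, -1.0], [-0.9, -0.9]]
hits = [[True, False, False], [True, True, False], [True, True, True]]
```

With $\lambda=1$ in the correctness example, one extra correct token no longer pays for one extra block, so the lowest exit wins.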

Oracles based on test metrics such as BLEU are feasible but expensive to compute, since we would need to decode every training sentence $N$ times. We leave this for future work.

3.2 Token-specific depth:

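The token-specific approach can choose a different exit at every time-step. We consider two options for the exit distribution $q_t$ at time-step $t$: a multinomial with a classifier conditioned on the first decoder hidden state $h_t^1$, and a geometric-like distribution where an exit probability $\chi_t^n$ is estimated after each block based on the activations of the current block $h_t^n$:

$$\text{Multinomial:}\quad q_t(n\mid\boldsymbol{x},\boldsymbol{y}_{<t}) = \operatorname{softmax}(W_h h_t^1 + b_h),$$

$$\text{Geometric-like:}\quad \chi_t^n = \operatorname{sigmoid}(w_h^\top h_t^n + b_h), \qquad q_t(n\mid\boldsymbol{x},\boldsymbol{y}_{<t}) = \begin{cases}\chi_t^n\prod_{n'<n}(1-\chi_t^{n'}) & \text{if } n<N,\\ \prod_{n'<N}(1-\chi_t^{n'}) & \text{otherwise.}\end{cases}$$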

The most probable exit $\arg\max_n q_t(n\mid\boldsymbol{x},\boldsymbol{y}_{<t})$ is selected at inference.
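The geometric-like distribution turns the halting probabilities $\chi_t^1,\ldots,\chi_t^{N-1}$ into a distribution over exits: exit at $n$ if block $n$ halts and no earlier block did, with the remaining mass on the last block. The sketch also includes the rule used for geometric-like classifiers in the experiments, exiting at the first block with $\chi_t^n>0.5$. Inputs are assumed to be already-computed sigmoid outputs, and the function names are hypothetical.

```python
from typing import List, Sequence


def geometric_like(chi: Sequence[float]) -> List[float]:
    """Exit distribution q_t(1..N) from halting probabilities chi_t^1..chi_t^{N-1}."""
    q: List[float] = []
    survive = 1.0  # probability that no earlier block has halted
    for c in chi:
        q.append(survive * c)
        survive *= 1.0 - c
    q.append(survive)  # the last block always emits a prediction
    return q


def first_confident_exit(chi: Sequence[float], threshold: float = 0.5) -> int:
    """Stop at the first block whose halting probability exceeds the threshold."""
    for n, c in enumerate(chi, start=1):
        if c > threshold:
            return n
    return len(chi) + 1  # fall through to block N
```

Since $\chi_t^n$ depends only on $h_t^n$, the halting rule can be checked block by block at inference, and blocks above the chosen exit never need to be computed for that token.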

The two classifiers are trained to minimize the cross-entropy with respect to either one of the following oracle distributions:

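Likelihood-based: At each time-step $t$, we choose the block whose exit classifier has the highest likelihood, minus a regularization term weighted by $\lambda$ that encourages lower exits:

$$q_t^*(\boldsymbol{x},\boldsymbol{y}) = \delta\Big(\arg\max_{n}\,\big(\operatorname{LL}_t^n - \lambda n\big)\Big).$$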

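This oracle ignores the impact of the current decision on the future time-steps, and we therefore consider smoothing the likelihoods with an RBF kernel:

$$\kappa(t,t') = e^{-\frac{|t-t'|^2}{\sigma}}, \qquad \widetilde{\operatorname{LL}}_t^n = \sum_{t'=1}^{|\boldsymbol{y}|}\kappa(t,t')\operatorname{LL}_{t'}^n, \qquad q_t^*(\boldsymbol{x},\boldsymbol{y}) = \delta\Big(\arg\max_{n}\,\big(\widetilde{\operatorname{LL}}_t^n - \lambda n\big)\Big),$$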

where we control the size of the surrounding context with the kernel width $\sigma$. We refer to this oracle as $\operatorname{LL}(\sigma,\lambda)$, including the case where we only look at the likelihood of the current token with $\sigma\rightarrow 0$.
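A sketch of the smoothed token-level oracle $\operatorname{LL}(\sigma,\lambda)$: per-token scores of each block are smoothed over neighboring time-steps with an RBF kernel, penalized by $\lambda n$, and the best block is chosen at every step. The kernel form $\exp(-|t-t'|^2/\sigma)$ is an assumption of this sketch, and $\sigma=0$ is treated as the no-smoothing limit.

```python
import math
from typing import List, Sequence


def smoothed_token_oracle(scores: Sequence[Sequence[float]], sigma: float, lam: float) -> List[int]:
    """scores[n - 1][t]: per-token score of block n (e.g. log p(y_t | h_{t-1}^n)).

    Returns one 1-based exit per time-step; ties resolve to the lowest block.
    """
    num_blocks, length = len(scores), len(scores[0])
    exits = []
    for t in range(length):
        best_n, best_value = 1, -math.inf
        for n in range(1, num_blocks + 1):
            row = scores[n - 1]
            if sigma == 0:
                smoothed = row[t]  # sigma -> 0: only the current token counts
            else:
                smoothed = sum(math.exp(-((t - u) ** 2) / sigma) * row[u] for u in range(length))
            value = smoothed - lam * n
            if value > best_value:
                best_n, best_value = n, value
        exits.append(best_n)
    return exits


# Block 1 is confident except at t = 1; block 2 is uniformly decent.
ll = [[-0.1, -3.0, -0.1], [-0.2, -0.2, -0.2]]
```

Without smoothing the oracle picks exits `[1, 2, 1]`; with $\sigma=1$ the hard token at $t=1$ leaks into its neighbors and every step moves to block 2, in line with the later hyper-parameter study where wider kernels favor higher exits.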

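Correctness-based: Similar to the likelihood-based oracle, we can look at the correctness of the prediction at time-step $t$ as well as at surrounding positions. We define the target $q_t^*$ as follows, smoothing with the same RBF kernel of width $\sigma$:

$$C_t^n = \mathbb{1}\big[y_t = \arg\max_{y} p(y\mid h_{t-1}^n)\big], \qquad \widetilde{C}_t^n = \sum_{t'=1}^{|\boldsymbol{y}|}e^{-\frac{|t-t'|^2}{\sigma}}\,C_{t'}^n, \qquad q_t^*(\boldsymbol{x},\boldsymbol{y}) = \delta\Big(\arg\max_{n}\,\big(\widetilde{C}_t^n - \lambda n\big)\Big).$$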
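The correctness indicators can be computed directly from each exit's output distribution. This sketch covers only the $\sigma\to 0$ case of the correctness-based oracle (no smoothing), with hypothetical names and vocabulary distributions given as lists of probabilities.

```python
from typing import List, Sequence


def correctness(probs: Sequence[Sequence[Sequence[float]]], targets: Sequence[int]) -> List[List[int]]:
    """C[n - 1][t] = 1 if argmax_y p(y | h_{t-1}^n) == y_t else 0.

    probs[n - 1][t] is block n's distribution over the vocabulary at step t.
    """
    return [[int(max(range(len(p)), key=p.__getitem__) == y) for p, y in zip(block, targets)]
            for block in probs]


def token_correctness_oracle(probs, targets: Sequence[int], lam: float) -> List[int]:
    """Per step, the 1-based block maximizing C_t^n - lam * n (lowest block on ties)."""
    C = correctness(probs, targets)
    exits = []
    for t in range(len(targets)):
        values = [C[n - 1][t] - lam * n for n in range(1, len(C) + 1)]
        exits.append(max(range(len(values)), key=values.__getitem__) + 1)
    return exits


targets = [0, 2]
probs = [
    [[0.6, 0.3, 0.1], [0.5, 0.4, 0.1]],  # block 1: right at t=0, wrong at t=1
    [[0.7, 0.2, 0.1], [0.2, 0.2, 0.6]],  # block 2: right at both steps
]
```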

Confidence thresholding: Finally, we consider thresholding the model predictions (§2), i.e., exiting when the maximum score of the current output classifier $p(y_{t+1}\mid h_t^n)$ exceeds a hyper-parameter threshold $\tau_n$. This does not require training, and the thresholds $\boldsymbol{\tau}=(\tau_1,\ldots,\tau_{N-1})$ are simply tuned on the validation set to maximize BLEU. Concretely, for 10k iterations, we sample a sequence of thresholds $\boldsymbol{\tau}\sim\mathcal{U}(0,1)^{N-1}$, decode the validation set with the sampled thresholds, and then evaluate the BLEU score and computational cost achieved with this choice of $\boldsymbol{\tau}$. After 10k evaluations, we pick the best-performing thresholds, that is, the $\boldsymbol{\tau}$ with the highest BLEU in each cost segment.
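A sketch of confidence thresholding and of the random search over thresholds. The `evaluate` callback, which would decode the validation set and return (BLEU, cost), is a hypothetical stand-in; the cost-segment boundaries are also an assumption, since the text does not specify them.

```python
import bisect
import random
from typing import Callable, Dict, List, Sequence, Tuple


def threshold_exit(max_scores: Sequence[float], taus: Sequence[float]) -> int:
    """max_scores[n - 1] = max_y p(y | h_t^n); exit at the first n < N with score > tau_n."""
    for n, (score, tau) in enumerate(zip(max_scores, taus), start=1):
        if score > tau:
            return n
    return len(max_scores)  # no threshold exceeded: use block N


def random_search(evaluate: Callable[[List[float]], Tuple[float, float]], num_blocks: int,
                  iterations: int, rng: random.Random) -> List[Tuple[float, float, List[float]]]:
    """Sample tau ~ U(0, 1)^(N-1) and record (cost, bleu, tau) for each draw."""
    results = []
    for _ in range(iterations):
        taus = [rng.random() for _ in range(num_blocks - 1)]
        bleu, cost = evaluate(taus)
        results.append((cost, bleu, taus))
    return results


def best_per_segment(results, edges: Sequence[float]) -> Dict[int, Tuple[float, List[float]]]:
    """Keep the highest-BLEU thresholds within each cost segment defined by `edges`."""
    best: Dict[int, Tuple[float, List[float]]] = {}
    for cost, bleu, taus in results:
        segment = bisect.bisect_right(edges, cost)
        if segment not in best or bleu > best[segment][0]:
            best[segment] = (bleu, taus)
    return best
```

A full search along the lines described would call `random_search` with 10k iterations and pass its output to `best_per_segment`.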

Experiments

We evaluate on several benchmarks and measure tokenized BLEU (Papineni et al., 2002):

IWSLT’14 German to English (De-En). We use the setup of Edunov et al. (2018) and train on 160K sentence pairs. We use $N=6$ blocks, a feed-forward network (ffn) of intermediate dimension 1024, 4 heads, dropout 0.3, embedding dimension $d_{\text{enc}}=512$ for the encoder and $d_{\text{dec}}=256$ for the decoder. Embeddings are untied with 6 different output classifiers. We evaluate with a single checkpoint and a beam of width 5.

WMT’14 English to French (En-Fr). We also experiment on the much larger WMT’14 English-French task comprising 35.5M training sentence pairs. We develop on 26k held-out pairs and test on newstest14. The vocabulary consists of 44k joint BPE types (Sennrich et al., 2016). We use a Transformer big architecture and tie the embeddings of the encoder, the decoder and the output classifiers ($(W_n)_{1\leq n\leq 6}$; §2.1). We average the last ten checkpoints and use a beam of width 4.

Models are implemented in fairseq (Ott et al., 2019) and are trained with Adam (Kingma & Ba, 2015). We train for 50k updates on 128 GPUs with a batch size of 460k tokens for WMT’14 En-Fr and on 2 GPUs with 8k tokens per batch for IWSLT’14 De-En. To stabilize training, we re-normalize the gradients if the norm exceeds $g_{\text{clip}}=3$.

For models with adaptive exits, we first train without exit prediction ($\alpha=0$ in Eq. (7)) using the aligned mode (cf. §2.2.1) for 50k updates and then continue training with $\alpha\neq 0$ until convergence. The exit prediction classifiers are parameterized by a single linear layer (Eq. (8)) with the same input dimension as the embedding dimension, e.g., 1024 for a big Transformer; the output dimension is $N$ for a multinomial classifier or one for geometric-like. We exit when $\chi_t^n>0.5$ for geometric-like classifiers.

4.2 Training multiple output classifiers

We first compare the two training regimes for our model (§2.2). Aligned training performs self-attention on aligned states (§2.2.1) and mixed training exposes self-attention to hidden states from different blocks (§2.2.2).

We compare the two training modes when choosing either a uniformly sampled exit or a fixed exit $n=1,\ldots,6$ at inference time for every time-step. The sampled exit experiment tests the robustness to mixed hidden states, and the fixed exit setup simulates an ideal setting where all previous states are available. As baselines, we show six separate standard Transformers with $N\in[1..6]$ decoder blocks. All models are trained with an equal number of updates, and mixed training with $M=6$ paths is most comparable to aligned training since the number of losses per sample is identical.

Table 1 shows that aligned training outperforms mixed training both for fixed exits and for randomly sampled exits. The latter is surprising since aligned training never exposes the self-attention mechanism to hidden states from other blocks. We suspect that this is due to the residual connections, which copy features from lower blocks to subsequent layers and which are ubiquitous in Transformer models (§2). Aligned training also performs very competitively with the individual baseline models.

Aligned training is conceptually simple and fast. We can process a training example with $N$ exits in a single forward/backward pass, while $M$ passes are needed for mixed training. In the remaining paper, we use the aligned mode to train our models. Appendix A reports experiments with weighting the various output classifiers differently, but we found that a uniform weighting scheme worked well. On our largest setup, WMT’14 English-French, the training time of an aligned model with six output classifiers increases only marginally, by about 1%, compared to a baseline with a single output classifier, keeping everything else equal.

4.3 Adaptive depth estimation

Next, we train models with aligned states and compare adaptive depth classifiers in terms of BLEU as well as computational effort. We measure the latter as the average exit per output token ($\operatorname{AE}$).

As baselines, we again use six separate standard Transformers with $N\in[1..6]$ and a single output classifier. We also measure the performance of the model trained in aligned mode for fixed exits $n\in[1..6]$. For the adaptive depth token-specific models (Tok), we train four combinations: likelihood-based oracle (LL) + geometric-like, likelihood-based oracle (LL) + multinomial, correctness-based oracle (C) + geometric-like and correctness-based oracle (C) + multinomial. Sequence-specific models (Seq) are trained with the correctness oracle (C) and the likelihood oracle (LL) with different values for the regularization weight $\lambda$. All parameters are tuned on the validation set and we report results on the test set for a range of average exits.

Figure 3 shows that the aligned model (blue line) can match the accuracy of a standard 6-block Transformer (black line) at half the number of layers ($n=3$) by always exiting at the third block. The aligned model outperforms the baseline for $n=2,\ldots,6$.

For token-specific halting mechanisms (Figure 3(a)), the geometric-like classifiers achieve a better speed-accuracy trade-off than the multinomial classifiers (filled vs. empty triangles). For geometric-like classifiers, the correctness oracle outperforms the likelihood oracle (Tok-C geometric-like vs. Tok-LL geometric-like), but the trend is less clear for multinomial classifiers. At the sequence level, likelihood is the better oracle (Figure 3(b)).

The rightmost Tok-C geometric-like point ($\sigma=0$, $\lambda=0.1$) achieves 34.73 BLEU at $\operatorname{AE}=1.42$, which corresponds to similar accuracy as the $N=6$ baseline with 76% fewer decoding blocks. The best accuracy of the aligned model is 34.95 BLEU at exit 5, and the best comparable Tok-C geometric-like configuration achieves 34.99 BLEU at $\operatorname{AE}=1.97$, or 61% fewer decoding blocks than exit 5. When fixing the budget to two decoder blocks, Tok-C geometric-like with $\operatorname{AE}=1.97$ achieves 34.99 BLEU, a 0.64 BLEU improvement over the baseline ($N=2$) and the aligned model, which both achieve 34.35 BLEU.
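The average exit and the savings quoted above are simple arithmetic. The sketch assumes per-token exits are available for every hypothesis and compares $\operatorname{AE}$ to the depth of a reference model; the function names are hypothetical.

```python
from typing import Sequence


def average_exit(exits_per_hypothesis: Sequence[Sequence[int]]) -> float:
    """AE: mean exit over all output tokens of all hypotheses."""
    total = sum(sum(exits) for exits in exits_per_hypothesis)
    count = sum(len(exits) for exits in exits_per_hypothesis)
    return total / count


def fraction_saved(ae: float, reference_depth: int) -> float:
    """Fraction of decoder blocks saved relative to a reference depth."""
    return 1.0 - ae / reference_depth
```

Tokens are pooled across hypotheses, so longer hypotheses weigh more in $\operatorname{AE}$. Note that the 61% figure is relative to the aligned model's best exit (5 blocks), whereas the 76% figure is relative to the 6-block baseline.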

Confidence thresholding (Figure 3(c)) performs very well but cannot outperform Tok-C geometric-like.

In this section, we look at the effect of the two main hyper-parameters on IWSLT’14 De-En: $\lambda$, the regularization scale (cf. Eq. (9)), and $\sigma$, the RBF kernel width used to smooth the scores (cf. Eq. (15)). We train Tok-LL geometric-like models and evaluate them with their default thresholds (exit if $\chi_t^n>0.5$). Figure 4(a) shows that higher values of $\lambda$ lead to lower exits. Figure 4(b) shows the effect of $\sigma$ for two values of $\lambda$. In both curves, we see that wider kernels favor higher exits.

4.4 Scaling the adaptive-depth models

Finally, we take the best-performing models from the IWSLT benchmark and test them on the large WMT’14 English-French benchmark. Results on the test set (Figure 5(a)) show that adaptive depth still yields improvements, but that they are diminished in this very large-scale setup. Confidence thresholding works very well, and sequence-specific depth approaches improve only marginally over the baseline. Tok-LL geometric-like can match the best baseline result of 43.4 BLEU ($N=6$) by using only $\operatorname{AE}=2.40$, which corresponds to 40% of the decoder blocks; the best aligned result of 43.6 BLEU can be matched with $\operatorname{AE}=3.25$. In this setup, Tok-LL geometric-like slightly outperforms its Tok-C counterpart.

Confidence thresholding matches the accuracy of the $N=6$ baseline with $\operatorname{AE}=2.5$, or 59% fewer decoding blocks. However, confidence thresholding requires computing the output classifier at each block to determine whether to halt or continue. This is a large overhead since output classifiers predict 44k types for this benchmark (§4.1). To better account for this, we measure the average number of FLOPs per output token (details in Appendix B). Figure 5(b) shows that the Tok-LL geometric-like approach provides a better trade-off when the overhead of the output classifiers is considered.
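A back-of-the-envelope sketch of why per-block output classifiers are costly. It assumes a multiply-add counts as 2 FLOPs, that an output classifier costs about $2\,d\,V$ FLOPs for hidden size $d$ and vocabulary size $V$, and that a geometric-like halting unit costs about $2d$ per visited block; `block_flops` is a hypothetical per-block constant. This is not the exact accounting of Appendix B.

```python
def decoder_token_flops(exit_n: int, block_flops: float, d_model: int, vocab: int,
                        mode: str) -> float:
    """Approximate decoder FLOPs for one token that exits at block exit_n."""
    classifier = 2 * d_model * vocab  # one (vocab x d_model) matrix-vector product
    if mode == "threshold":
        # The output classifier is evaluated at every visited block to read its max score.
        return exit_n * (block_flops + classifier)
    if mode == "geometric":
        # A scalar halting unit per visited block, then a single output classifier.
        return exit_n * (block_flops + 2 * d_model) + classifier
    raise ValueError(f"unknown mode: {mode}")


# With d = 1024 and V = 44k, one classifier call alone is about 90 MFLOPs.
classifier_cost = 2 * 1024 * 44_000
```

Under these assumptions, thresholding pays the classifier cost once per visited block, while the geometric-like halting unit pays it only once per token.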

4.5 Qualitative results

The exit distribution for a given sample can give insights into what a Depth-Adaptive Transformer decoder considers to be a difficult task. In this section, for each hypothesis $\widetilde{\boldsymbol{y}}$, we look at the sequence of selected exits $(n_1,\ldots,n_{|\widetilde{\boldsymbol{y}}|})$ and the probability scores $(p_1,\ldots,p_{|\widetilde{\boldsymbol{y}}|})$ with $p_t=p(\widetilde{y}_t\mid h_{t-1}^{n_t})$, i.e., the confidence of the model in the sampled token at the selected exit.

Figures 6 and 7 show hypotheses from the WMT’14 En-Fr and IWSLT’14 De-En test sets, respectively. For each hypothesis, we state the exits and the probability scores. In Figure 6(a), predicting ‘présent’ (meaning ‘present’) is hard. A straightforward translation is ‘était là’, but the model chooses ‘présent’, which is also appropriate. In Figure 6(b), the model uses more computation to predict the definite article ‘les’ since the source has omitted the article for ‘passengers’.

A clear trend in both benchmarks is that the model requires less computation near the end of decoding to generate the end-of-sequence marker $</\text{s}>$ and the preceding full stop when relevant. In Figure 8, we show the distribution of the exits at the beginning and near the end of test set hypotheses. We consider the beginning of a sequence to be the first 10% of tokens and the end to be the last 10% of tokens. The exit distributions are shown for three models on WMT’14 En-Fr: $\text{Model}_1$ has an average exit of $\operatorname{AE}=2.53$, $\text{Model}_2$ exits at $\operatorname{AE}=3.79$ on average and $\text{Model}_3$ at $\operatorname{AE}=4.68$. Within the same model, late exits are used at the beginning of the sequence and early exits are selected near the end. $\text{Model}_2$ and $\text{Model}_3$ are less regularized (higher $\operatorname{AE}$) and show this pattern clearly. For heavily regularized models such as $\text{Model}_1$, the disparity between beginning and end is less severe, as the model exits early most of the time.

There is also a correlation between the model probability and the amount of computation, particularly in models with low $\operatorname{AE}$. Figure 9 shows the joint histogram of the scores and the selected exit. For both $\text{Model}_1$ and $\text{Model}_2$, low exits ($n\leq 2$) are used in the high-confidence range $[0.8, 1]$ and high exits ($n\geq 4$) are used in the low-confidence range $[0, 0.5]$. $\text{Model}_3$ has a high average exit ($\operatorname{AE}=4.68$), so most tokens exit late; however, in low-confidence ranges the model does not exit earlier than $n=5$.
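A sketch of how the beginning and end exit distributions can be tallied from per-token exits. Rounding the 10% windows down while keeping at least one token per window is an assumption; the text does not say how short hypotheses are handled.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def begin_end_exit_distributions(hypothesis_exits: Sequence[Sequence[int]], num_blocks: int,
                                 fraction: float = 0.1) -> Tuple[List[float], List[float]]:
    """Normalized exit histograms over the first and last `fraction` of tokens."""
    begin, end = Counter(), Counter()
    for exits in hypothesis_exits:
        k = max(1, int(len(exits) * fraction))  # assumption: at least one token per window
        begin.update(exits[:k])
        end.update(exits[-k:])

    def normalize(counts: Counter) -> List[float]:
        total = sum(counts.values())
        return [counts[n] / total for n in range(1, num_blocks + 1)]

    return normalize(begin), normalize(end)
```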

Conclusion

We extended anytime prediction to the structured prediction setting and introduced simple but effective methods to equip sequence models to make predictions at different points in the network. We compared a number of different mechanisms to predict the required network depth and found that a simple correctness-based geometric-like classifier obtains the best trade-off between speed and accuracy. Results show that the number of decoder layers can be reduced by more than three quarters at no loss in accuracy compared to a well-tuned Transformer baseline.

We thank Laurens van der Maaten for fruitful comments and suggestions.

References

Appendix A Loss scaling

In this section, we experiment with different weights for scaling the output classifier losses. Instead of uniform weighting, we bias towards specific output classifiers by assigning higher weights to their losses. Table 2 shows that weighting the classifiers equally provides good results.

Adding intermediate supervision at different levels of the decoder results in richer gradients for lower blocks compared to upper blocks. This is because earlier layers affect more loss terms in the compound loss of Eq. (4). To balance the gradients of each block in the decoder, we scale up the gradients of each loss term $(-\operatorname{LL}^n)$ when it is updating the parameters of its associated block ($\operatorname{block}_n$ with parameters $\theta_n$) and revert it back to its normal scale before back-propagating it to the previous blocks. Figure 10 and Algorithm 1 illustrate this gradient scaling procedure. The $\theta_n$ are updated with $\gamma_n$-amplified gradients from the block’s supervision and $(N-n)$ gradients from the subsequent blocks. We choose $\gamma_n=\gamma(N-n)$ to control the ratio $\gamma{:}1$ as the ratio of the block supervision to the subsequent blocks’ supervisions.
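Worked out for a block with $n<N$ (so that subsequent blocks exist), the choice $\gamma_n=\gamma(N-n)$ gives the stated ratio between the block's own amplified supervision and the $N-n$ unit-scale gradients it receives from the losses of the blocks above it:

```latex
\frac{\gamma_n}{\underbrace{1+\cdots+1}_{N-n\ \text{terms}}}
  = \frac{\gamma\,(N-n)}{N-n}
  = \gamma,
\qquad\text{i.e.}\qquad
\gamma_n : (N-n) = \gamma : 1 .
```

For example, with $N=6$ and $\gamma=2$, block 2 receives its own loss gradient amplified by $\gamma_2=8$ against 4 unit-scale gradients from blocks 3 to 6.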

Table 3 shows that gradient scaling can benefit the lowest layer at the expense of higher layers. However, no scaling generally works very well.

Appendix B FLOPs approximation

This section details the computation of the FLOPs we report. The per-token FLOPs are for the decoder network only, since we use an encoder of the same size for all models. We break down the FLOPs of every operation in Algorithm 2 (shown in blue in front of each algorithmic statement). We omit non-linearities, normalizations and residual connections. The main operations we account for are dot-products and, by extension, matrix-vector products, since those represent the vast majority of FLOPs (we assume a batch size of one to simplify the calculation).

With this breakdown, the total computational cost at time-step tt of a decoder block that we actually go through, denoted with FC, is:

where the cost of mapping the source keys and values is incurred the first time the block is called (flagged with FirstCall). This occurs at $t=1$ for the baseline model, but it is input-dependent with depth-adaptive estimation and may never occur if all tokens exit early.

Depending on the halting mechanism, an exit prediction cost, denoted FP, is added:

For a set of source sequences $\{\boldsymbol{x}^{(i)}\}_{i\in\mathcal{I}}$ and generated hypotheses $\{\boldsymbol{y}^{(i)}\}_{i\in\mathcal{I}}$, the average FLOPs per token is:
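A sketch of the corpus-level average, assuming a hypothetical constant per-block step cost, a one-time cost for projecting the source keys and values the first time a block is reached in a hypothesis, and an optional per-block exit-prediction cost. Real per-block costs grow with $t$ and $|\boldsymbol{x}|$ because of attention, which this sketch ignores.

```python
from typing import Sequence


def average_flops_per_token(exit_sequences: Sequence[Sequence[int]], step_cost: float,
                            first_call_cost: float, exit_cost: float = 0.0) -> float:
    """Average decoder FLOPs per generated token over a set of hypotheses.

    exit_sequences[i][t] is the exit n_t used for token t of hypothesis i.
    """
    total, tokens = 0.0, 0
    for exits in exit_sequences:
        deepest = 0  # deepest block reached so far in this hypothesis
        for n_t in exits:
            total += n_t * (step_cost + exit_cost)
            if n_t > deepest:
                # Blocks deepest+1..n_t are called for the first time: pay the source K/V cost.
                total += (n_t - deepest) * first_call_cost
                deepest = n_t
            tokens += 1
    return total / tokens
```

If every token of a hypothesis exits at block 1, the source projections of blocks 2 to $N$ are never paid, which matches the observation that the first call may never occur.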