Confident Adaptive Language Modeling

Tal Schuster, Adam Fisch, Jai Gupta, Mostafa Dehghani, Dara Bahri, Vinh Q. Tran, Yi Tay, Donald Metzler

Introduction

Recent advances in Large Language Models (LLMs) have led to breakthroughs in language understanding and language generation across almost every widely-used Natural Language Processing (NLP) task considered in the field today [5; 15; 17; 20; 51; 52; 53; 75; 89; 73]. Autoregressive language modeling provides a flexible framework for solving complex tasks with a unified natural language input and output format, while also relaxing the need for large-scale task-specific data collection and training [67; 15; 17; 58; 80]. The large size of LLMs, however, results in massive computational load that might be limiting for certain real-world applications (e.g., machine translation) [9; 30; 42; 49; 59; 63; 71]. This is especially pronounced in the autoregressive decoding process where the full stack of Transformer layers is repeatedly computed for each output token [37; 40; 86].

While large models do better in general, the same amount of computation may not be required for every input to achieve similar performance (e.g., depending on whether the input is easy or hard). Early exiting is a promising approach to decreasing the computational cost of multilayered architectures such as those used in Transformer-based LLMs, where the number of layers used by the model is dynamically decided on an input-by-input basis [18; 23; 57; 60; 70]. In this setting, an LLM can choose to generate a new token based on the representation at an intermediate layer instead of using the full model, and save computation as a result. A natural question that arises, however, is when it is a good decision to exit early, as opposed to waiting. Naively choosing when to exit can be suboptimal in terms of saving computation time, and can also result in unpredictable degradations to model performance, especially when predictions depend on each other, as in autoregressive language generation.

In this work, we analyze the early exiting paradigm for LLMs, and present a principled method for increasing model efficiency while remaining confident in the quality of the resulting predictions. Specifically, we develop a method for calibrating local, per-token, exit decisions such that global, sequence-level constraints—as determined by lexical or semantic sequence-level metrics like ROUGE or BLEURT score—are provably maintained with arbitrarily high probability (e.g., 95%). This process, which we call Confident Adaptive Language Modeling (CALM), is illustrated in Figure 1.

Finally, we empirically validate our method on multiple, diverse NLP generation tasks, including text summarization, machine translation, and question answering. Our experiments demonstrate the potential of CALM in reducing the average complexity of the model and accelerating inference by about $\times 3$ while reliably controlling for high performance.

Contributions. In summary, our main contributions are as follows:

A framework (CALM) for reliably accelerating Transformer-based LLM generations.

A systematic analysis of the token-wise early exit mechanism that motivates a simple-but-effective class of confidence measures and threshold functions that are used as part of the CALM framework.

An empirical demonstration of CALM’s efficiency gains on three diverse generation datasets.

Related Work

Improving inference-time efficiency of LLMs has been an ongoing effort of the research community over the past several years [49; 72; 85], leveraging techniques such as knowledge distillation [6; 32; 36; 69; 78; 56], floating point quantization [71; 65], layer pruning, vector dropping, and others. Another line of work involves conditional computation to train larger models that only use a sparser subset of the full network during inference, for example by routing over mixture-of-experts [9; 22; 39; 91], recurring modules [18; 29; 35], or accessing external memory. These models, however, still use the same amount of compute for all input examples.

Here, we focus on adaptive compute, a specific kind of conditional compute that aims to dynamically allocate different computational power per example, with the goal of reducing the overall complexity while maintaining high performance. This approach, often referred to as early exiting [16; 25; 47; 74; 79; 87], is complementary to many of the solutions above and can potentially be combined with them. Multiple early-exit techniques for encoder-only Transformers (e.g., BERT) have recently been proposed [8; 34; 43; 44; 45; 60; 68; 83; 90; 92]. Most of these methods rely on intrinsic confidence measures (e.g., based on the softmax distribution), while others try to predict the routing in advance [46; 70], or train a small early-exit classifier [57; 84], as we also examine here. These measures can be calibrated to reliably guarantee consistency of the early prediction with the full model. However, the techniques used for encoder-only classifiers are unsuitable for global consistency constraints over a sequence of dependent predictions, which are inherent in the decoding process of autoregressive language models and which we address here.

Our work is also motivated by recent findings on the existence of saturation events in LMs, where the top-ranked prediction is unchanged after some layer and is propagated upward. Geva et al. examined interactions of the hidden state with feed-forward layers to predict these events. However, they only consider local single predictions and do not address the challenges involved with sequence generation. Our early-exit LM architecture most closely relates to Elbayad et al., who found a token-level early-exit classifier to provide the best efficiency-performance tradeoffs on machine translation. Here, we introduce a theoretically grounded calibration method for provably controlling the quality of the full sequence. By doing so, we provide reliable efficiency gains, deriving local early-exiting decisions from the desired global constraints. Moreover, we introduce several model improvements and empirical analyses, including (1) analyzing the primary sources of performance degradation, leading us to propose a decaying threshold function for better tradeoff control without inflating the search space; (2) improving the early-exit classifier training; and (3) experimenting with two new tasks.

Our calibration procedure for connecting global constraints to local decisions relates to recent research on distribution-free uncertainty quantification [1; 62; 77]. Several recent studies have expanded and adjusted this theoretical framework to obtain practical efficiency gains on target applications [4; 7; 21; 26; 27; 48; 88]. Here, we frame our consistency requirements around the Learn then Test (LTT) framework, and leverage the approximately monotonic behavior of our confidence measures and the nested structure of our problem (which, by definition, guarantees consistency with a large enough threshold) to form tight and effective bounds.

Early Exiting for Adaptive Language Modeling

In the following, we describe and analyze the early-exiting Transformer LM. We begin with a brief recap of the Transformer architecture (§3.1) and early exiting (§3.2) for convenience, following previous work [23; 70; 76]. We then investigate the effects of early exiting on model performance, and identify primary sources of performance degradation and how to alleviate them (§3.3)—which guide our architecture and training design (§3.4) and proposed per-token confidence measures (§3.5).

We use the Transformer sequence-to-sequence model, based on the T5X implementation. Here, we only review simplified details of the Transformer architecture relevant to early exiting, and refer the reader to Vaswani et al. for full details. At a high level, both the encoder and decoder networks contain $L$ stacked layers, where each layer is composed of a multi-head self-attention sub-layer followed by a feed-forward sub-layer, each with residual connections and layer normalization. The decoder network has an additional multi-head attention sub-layer that attends to the encoder states.

Consider a prompt $x = (x_1, \ldots, x_p)$, processed by the encoder to yield encoder states $(e_1, \ldots, e_p)$, and the current, partially generated response $(y_1, \ldots, y_t)$. When generating the next token $y_{t+1}$, the decoder computes a decoder state $d_t^i$ for layer $i$ out of $L$ as:
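A simplified form consistent with the symbols used in this section is sketched below: self-attention over the previous layer's states, then cross-attention over the encoder states, then a feed-forward sub-layer, with residual connections and normalization dropped. The exact typesetting of the original equation is not recoverable from this text, so treat this as a reconstruction. It also assumes that $d_t^0$ denotes the input embedding of the current token $y_t$.

$$
\begin{aligned}
h_t^i &:= \mathrm{Attention}\big(\mathbf{W}^i_Q\, d_t^{i-1},\ \mathbf{W}^i_K\, d_{1:t}^{i-1},\ \mathbf{W}^i_V\, d_{1:t}^{i-1}\big),\\
a_t^i &:= \mathrm{Attention}\big(\mathbf{W}^i_Q\, h_t^i,\ \mathbf{W}^i_K\, e_{1:p},\ \mathbf{W}^i_V\, e_{1:p}\big),\\
d_t^i &:= \mathrm{FeedForward}\big(a_t^i\big).
\end{aligned}
$$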

Multi-head and normalization components are omitted for brevity. Each layer uses different projections $\mathbf{W}^i_Q$, $\mathbf{W}^i_K$, and $\mathbf{W}^i_V$ (which are also unique for computing $h_t^i$ versus $a_t^i$).

3.2 Decoding with early exiting

Note that, due to the self-attention mechanism of the Transformer, computing the input hidden state $h_t^i$ for layer $i$ depends on $d_{1:t-1}^{i-1}$, i.e., the output hidden states of the previous layer for all the tokens that have been generated so far. (In autoregressive decoding, the $\mathbf{k}_s, \mathbf{v}_s$ vectors are cached to avoid repetitive compute for tokens $t > s$.) Therefore, if the model has early exited at some layer $j < i-1$ for a token $s < t$, then $d_s^{i-1}$ is not available. As an approximation, we follow Elbayad et al. and set $d_s^k = d_s^j$ for all layers $k > j$, with the understanding that this will introduce some error. In the next section, among other factors, we analyze the impact of this copied state on performance.
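The sketch below shows the per-token exit rule together with state copying. At each layer $i$, a confidence $c_t^i$ is compared with a threshold $\lambda_t^i$, and the model exits at the first layer where $c_t^i \geq \lambda_t^i$. The top layer always produces a prediction. It is a toy under stated assumptions: each "layer" is a plain function of the current state, self-attention over past tokens and the key/value cache are omitted, and all names (`decode_token_with_early_exit`, `confidence`, `predict`) are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def decode_token_with_early_exit(
    d0: Vector,
    layers: Sequence[Callable[[Vector], Vector]],
    confidence: Callable[[Vector], float],
    predict: Callable[[Vector], int],
    thresholds: Sequence[float],
) -> Tuple[int, int, List[Vector]]:
    """Run decoder layers for one timestep t, exiting at the first layer i
    whose confidence c_t^i reaches its threshold lambda_t^i.

    Returns (predicted token, exit layer j, states d_t^1..d_t^L). States of
    skipped layers k > j are filled by copying d_t^j, so later tokens still
    have a state to attend to at every layer.
    """
    num_layers = len(layers)
    states: List[Vector] = []
    d = d0  # assumed: d_t^0 is the input embedding of the current token
    for i, layer in enumerate(layers, start=1):
        d = layer(d)  # toy stand-in for attention + feed-forward at layer i
        states.append(d)
        # The top layer always predicts; lower layers exit only when confident.
        if i == num_layers or confidence(d) >= thresholds[i - 1]:
            # State copying: d_t^k := d_t^j for every skipped layer k > j
            # (the copies share one list object, which is never mutated).
            states.extend([d] * (num_layers - i))
            return predict(d), i, states
    raise ValueError("at least one layer is required")
```

For example, take four layers that each add 1 to a one-dimensional state, a confidence of `d[0] / 10`, and thresholds of 0.25. The token exits at layer 3, and layer 4's recorded state is a copy of layer 3's.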

3.3 The effects of early exiting on error propagation

We perform several controlled experiments to investigate the behavior and the potential of early-exiting during decoding. We use an 8-layer T5 encoder-decoder and the CNN/DM dataset for these experiments. See §5 for more details on this model and data.

First, we control for the correctness of the predicted tokens to examine the effect of state copying (§3.2), and also measure an approximate upper bound for compute reduction. We use an oracle confidence measure that exits at the earliest layer that agrees with the top prediction (i.e., replacing the conditions in Eq. 4 with $\arg\max p(y_{t+1} \mid d_t^i) = \arg\max p(y_{t+1} \mid d_t^L)$). Hence, the only factor that can cause divergence in the generation is the state copying mechanism for skipped layers. The results of this experiment are highly encouraging. This oracle achieves a ROUGE-L score of 38.24, compared to 38.32 with the full model, while using an average of only 1.53 layers per token. We also try an oracle that always uses $d^1_{1:t-1}$, and it reaches 38.31 ROUGE-L. These results indicate that (1) the model is robust to state copying from lower layers, and (2) there is remarkable potential for saving compute (by up to $\times 5.2$) while preserving performance, given a good confidence measure.
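Below is a minimal sketch of the oracle rule and the compute-reduction arithmetic. With an 8-layer decoder and an average of 1.53 layers per token, the reduction is $8/1.53 \approx 5.2$. The function names are hypothetical, and the oracle receives each layer's top-1 token as a precomputed list. Because it needs every layer's prediction, it is an analysis tool for estimating an upper bound, not a measure that saves compute at inference time.

```python
from typing import Sequence


def oracle_exit_layer(per_layer_argmax: Sequence[int]) -> int:
    """Earliest 1-based layer i whose top token matches the top layer's,
    i.e., argmax p(y_{t+1} | d_t^i) == argmax p(y_{t+1} | d_t^L)."""
    if not per_layer_argmax:
        raise ValueError("per_layer_argmax must be non-empty")
    final_token = per_layer_argmax[-1]
    # Always returns: at the latest, the top layer matches itself.
    return next(
        i for i, token in enumerate(per_layer_argmax, start=1) if token == final_token
    )


def compute_reduction(total_layers: int, avg_layers_used: float) -> float:
    """Ratio of the full decoder depth to the average number of layers used."""
    return total_layers / avg_layers_used


print(oracle_exit_layer([5, 7, 7, 7]))        # 2
print(round(compute_reduction(8, 1.53), 1))   # 5.2
```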

We also experiment with copying the projected states $\mathbf{K}^j, \mathbf{V}^j$ to skipped layers $k > j$. This version of the oracle results in a significant drop in performance, to 23.02 ROUGE-L. Overall, we conjecture that the self-attention at layer $i$ for token $t$ can safely use hidden states $d^j_s$ for $j < i-1$ as key-values of tokens $s < t$, as long as the projections $\mathbf{W}^i_{K/V}$ of layer $i$ are used. Notably, this projection can now be computed concurrently for all skipped layers, as they all use the same $d$ from the exited layer.
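The sketch below contrasts the working variant with the one that fails. For a past token that exited at layer $j$, every later layer $k$ projects the same exited state $d_s^j$ with its own $\mathbf{W}^k_K, \mathbf{W}^k_V$. The failing variant instead reuses layer $j$'s already-projected $\mathbf{K}^j, \mathbf{V}^j$, and the text above reports that it drops ROUGE-L to 23.02. Matrices are plain nested lists, and the helper names are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, v: Vector) -> Vector:
    """Plain matrix-vector product w @ v."""
    return [sum(a * b for a, b in zip(row, v)) for row in w]


def keys_values_for_skipped_layers(
    d_exit: Vector,
    exit_layer: int,
    w_k: Sequence[Matrix],
    w_v: Sequence[Matrix],
) -> Dict[int, Tuple[Vector, Vector]]:
    """Cached (key, value) pairs for a past token s that exited at layer j.

    Every layer k > j sees the same exited state d_s^j (layer j+1 as its true
    input, layers k >= j+2 through state copying) but projects it with its
    OWN W_K^k and W_V^k rather than reusing layer j's keys and values.
    Layers are 1-based: w_k[k - 1] is W_K for layer k.
    """
    num_layers = len(w_k)
    # Each entry depends only on d_exit, so all skipped layers could be
    # projected concurrently (e.g., as one batched matrix multiply).
    return {
        k: (matvec(w_k[k - 1], d_exit), matvec(w_v[k - 1], d_exit))
        for k in range(exit_layer + 1, num_layers + 1)
    }
```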

3.3.2 Sensitivity to local errors

Next, we examine the impact of local token modifications, which might occur due to early exits, on the whole generated sequence. We experiment with two kinds of perturbations: sampling-based, where we select the 10th-ranked token according to layer $L$; and layer-based, where we select the first layer's prediction at timestep $t$. All other tokens are predicted greedily by layer $L$. As shown in Figure 2(a), earlier perturbations result in lower sequence-level scores, as more subsequent tokens might suffer from the divergence. The degradation, though, is much smaller with layer-based than with sampling-based perturbations since, in practice, the early-exit predictions are mostly accurate.

Decaying threshold. Following the above observation, we introduce a decaying early-exiting threshold that is more permissive towards exiting as the decoding process continues. Motivated by the logarithmic behavior in Figure 2(a), we use an exponential function with a user-defined temperature $\tau$:
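One form consistent with this description is shown below. It decays exponentially in the relative position $t/N$, with the rate controlled by $\tau$, and collapses to a single constant threshold when $\tau = 0$. The mixing weights $9/10$ and $1/10$ and the clipping to $[0,1]$ are assumptions in this reconstruction.

$$
\lambda_t(\lambda) := \mathrm{clip}_{[0,1]}\Big(\tfrac{9}{10}\,\lambda + \tfrac{1}{10}\,e^{-\tau t / N}\Big)
$$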

where $N$ is the maximum output length. Figure 2(b) illustrates this function. Essentially, this function offers an effective compromise between simply using the same threshold for all tokens and searching over a huge space of different per-position thresholds. In practice, it supports finer and better control over the performance-efficiency tradeoff than a single threshold. Figure 2(c) presents the outcomes of a search over $\lambda$ with steps of $0.01$ and softmax-based confidence (§3.5). With the single-threshold variant ($\tau = 0$), attempting to improve the efficiency leads to a drastic drop of more than 10 points in textual similarity against the full model's prediction. In contrast, the decaying thresholds reveal several intermediate points with desirable tradeoffs to consider.
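Below is a minimal sketch of a decaying threshold schedule. It assumes the form $\mathrm{clip}_{[0,1]}(0.9\lambda + 0.1\,e^{-\tau t/N})$, and those constants are an assumption, not values stated here. The helper names are hypothetical. The sketch shows the role of $\tau$: with $\tau = 0$ every position gets the same threshold, while a larger $\tau$ lowers the bar for later tokens.

```python
import math
from typing import List


def decaying_threshold(lam: float, t: int, max_len: int, tau: float) -> float:
    """Assumed form: clip_[0,1](0.9 * lam + 0.1 * exp(-tau * t / N)).

    With tau = 0 this is the constant 0.9 * lam + 0.1 at every position t,
    i.e., the single-threshold variant.
    """
    value = 0.9 * lam + 0.1 * math.exp(-tau * t / max_len)
    return min(1.0, max(0.0, value))


def threshold_schedule(lam: float, max_len: int, tau: float) -> List[float]:
    """Per-position thresholds for t = 0, ..., N - 1."""
    return [decaying_threshold(lam, t, max_len, tau) for t in range(max_len)]


flat = threshold_schedule(0.5, max_len=5, tau=0.0)
decayed = threshold_schedule(0.5, max_len=5, tau=4.0)
print([round(x, 3) for x in flat])     # [0.55, 0.55, 0.55, 0.55, 0.55]
print([round(x, 3) for x in decayed])  # [0.55, 0.495, 0.47, 0.459, 0.454]
```

Under this form, searching over the single scalar $\lambda$ (for a fixed $\tau$) still traces out a whole family of per-position schedules. This is how the search space stays small.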

3.4 Training early exit classifiers for local consistency

While our goal is to preserve the quality of the complete output sequence, we note that this doesn't necessarily demand local token-level consistency. Consider the target sequence “the concert was wonderful and long.” An output that switches the order of the adjectives to “the concert was long and wonderful” would be called consistent by most semantic measures (and obtains a token-$F_1$ score of 100). Yet, the two sentences diverge at the first adjective, “long”, which is semantically different from “wonderful”.

Training for global consistency, however, could be challenging as it depends on possibly noisy signals that might affect the learning, and also breaks the efficient teacher-forcing training strategy of LMs that relies on local-decisions. On the other hand, perfect local consistency implies global consistency. Therefore, we opt to train for local consistency, which requires minimal changes to the training procedure, and relax the local requirement to a global one during inference.

Specifically, similar to Elbayad et al., we take a weighted average of the per-layer losses to obtain the objective
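In symbols, write $\mathcal{L}_i$ for the loss of the prediction made from layer $i$ and $\omega_i$ for its weight. The objective then takes the following form, which renders the weighted average described here with weights normalized to sum to one:

$$
\mathcal{L} = \sum_{i=1}^{L} \omega_i\, \mathcal{L}_i, \qquad \sum_{i=1}^{L} \omega_i = 1
$$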

$\mathcal{L}_i$ is the negative log-likelihood loss of layer $i$. We set $\omega_i = i / \sum_{j=1}^{L} j$ to favor higher layers, and find that this objective mostly preserves the full model's performance compared to regular training. We note that there is some misalignment between this training and the inference behavior due to the hidden states of skipped layers. However, as discussed in §3.3.1, the performance is not affected if the hidden state is copied.
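Below is a small sketch of the layer weighting $\omega_i = i / \sum_{j=1}^{L} j$ and the resulting weighted loss. The helper names are hypothetical, and the per-layer losses are plain numbers rather than model outputs. For the 8-layer decoder used here, $\sum_{j=1}^{8} j = 36$, so the top layer gets weight $8/36 \approx 0.22$ and the first layer gets $1/36 \approx 0.03$.

```python
from typing import List, Sequence


def layer_weights(num_layers: int) -> List[float]:
    """omega_i = i / sum_{j=1}^{L} j for 1-based layer i; the weights sum to 1."""
    total = num_layers * (num_layers + 1) / 2  # closed form of sum_{j=1}^{L} j
    return [i / total for i in range(1, num_layers + 1)]


def weighted_layer_loss(per_layer_nll: Sequence[float]) -> float:
    """Weighted average of per-layer negative log-likelihoods (layer 1 first)."""
    weights = layer_weights(len(per_layer_nll))
    return sum(w * loss for w, loss in zip(weights, per_layer_nll))


print([round(w, 3) for w in layer_weights(8)])
```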

3.5 Local confidence measures

We experiment with three confidence measures for Eq. (4) that differ in their parameter and compute operation efficiencies. Our experiments (§6) will also show that they differ in their predictive power.

Softmax response. We take the difference between the top two values of $\mathrm{Softmax}(\mathbf{W}_i d^i_t)$. With a large output vocabulary, this results in many floating point operations (FLOPs), though the next layer $i+1$ can start its computation in parallel, avoiding additional runtime.
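Below is a minimal sketch of the softmax-response confidence for a single logit vector, using only the standard library. The logits stand in for $\mathbf{W}_i d_t^i$, and the function names are hypothetical. The softmax over the full vocabulary is what makes this measure FLOP-heavy. Selecting the top two probabilities with `heapq.nlargest` avoids a full sort.

```python
import heapq
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_response(logits: Sequence[float]) -> float:
    """Confidence c_t^i: gap between the two largest softmax probabilities."""
    if len(logits) < 2:
        raise ValueError("need at least two vocabulary entries")
    top1, top2 = heapq.nlargest(2, softmax(logits))
    return top1 - top2


print(softmax_response([0.0, 0.0]))  # 0.0: a tie, so no confidence
```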

Calibrating Local Early Exits from Global Constraints

We now describe our calibration procedure for finding a shared exit threshold $\lambda \in [0,1]$ that can be used directly in Eq. (4), or via Eq. (5), such that we provably satisfy our desired global constraints over the fully generated sequences. At a high level, our approach uses the following basic recipe:

We specify a grid of possible values $\Lambda = (\lambda_1, \ldots, \lambda_k)$ that may result in acceptable generations;

We choose the lowest valid $\lambda \in \Lambda$ that we can identify with rigorous statistical testing tools.

Choosing a value of $\lambda$ that rigorously satisfies our consistency objectives is challenging, as the performance impact of increasing or decreasing $\lambda$ is not necessarily monotonic. Naively setting $\lambda$, for example based simply on average calibration set performance, can lead to statistically invalid results in our finite-sample, distribution-free setting. The LTT framework proposed by Angelopoulos et al. solves this problem by reframing hyper-parameter selection as a multiple testing problem.

Here, we use consistency to refer to either textual consistency or risk consistency. Eq. (7) can be satisfied by applying standard multiple hypothesis testing techniques, as long as super-uniform p-values, $p_j$, are supplied for each value $\lambda_j \in \Lambda$ for the null hypothesis $H_j$: $\lambda_j$ does not provide consistency.

4.2 Defining p-values for consistent early-exiting

where $\widehat{E}(\lambda_j) := \frac{1}{n}\sum_{i=1}^{n} L_i(\lambda_j)$ is the empirical average of the random variable $L_i(\lambda_j) \in [0,1]$, with

for textual consistency versus risk consistency, respectively. Note that, as a technicality of enforcing the r.v. $L_i(\lambda_j)$ to be within $[0,1]$, Eq. (11) computes a conservative estimate of the difference in the empirical risk that doesn't reward instances in which the risk of the early-exit model is lower.
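The two per-example losses can be sketched as follows. The sketch assumes that the textual loss is a distance $D \in [0,1]$ between the early-exit and full-model outputs. It also assumes that the risk loss is the early-exit risk minus the full-model risk, clipped at zero, which is the conservative choice described above. The distance used here, one minus token-level $F_1$, is a simple stand-in for a metric-based distance such as one derived from ROUGE or BLEURT. All function names are hypothetical.

```python
from collections import Counter
from typing import Sequence


def token_f1_distance(candidate: str, reference: str) -> float:
    """Hypothetical stand-in for D: 1 - token-level F1, which lies in [0, 1]."""
    cand, ref = candidate.split(), reference.split()
    if not cand and not ref:
        return 0.0
    overlap = sum((Counter(cand) & Counter(ref)).values())
    if overlap == 0:
        return 1.0
    precision, recall = overlap / len(cand), overlap / len(ref)
    return 1.0 - 2 * precision * recall / (precision + recall)


def textual_loss(y_early: str, y_full: str) -> float:
    """Textual consistency: distance between early-exit and full-model outputs."""
    return token_f1_distance(y_early, y_full)


def risk_loss(risk_early: float, risk_full: float) -> float:
    """Risk consistency: only increases in risk count. Clipping at 0 keeps the
    value in [0, 1] for risks in [0, 1] and never rewards a lower early risk."""
    return max(0.0, risk_early - risk_full)


def empirical_average(losses: Sequence[float]) -> float:
    """E_hat(lambda_j) = (1/n) * sum_i L_i(lambda_j)."""
    return sum(losses) / len(losses)
```

For instance, the reordered sentence “the concert was long and wonderful” has textual loss 0 against “the concert was wonderful and long” under this stand-in. This holds even though the two sequences diverge token-by-token.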

4.3 Efficient fixed sequence testing

Here, we define a sequence of descending thresholds $\lambda_1 > \lambda_2 > \ldots > \lambda_k$ with a relatively coarse step size (e.g., increments of $0.05$). For each $\lambda_j$ in order, we compute $p_j$, and reject $H_j$ if $p_j \leq \epsilon$. The first time we fail to reject $H_j$, we immediately terminate our search and return $\lambda_{j-1}$ to use as our calibrated threshold (or $1$, if we fail to reject $H_1$). An algorithm for the full procedure is provided in Appendix E.
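Below is a minimal sketch of this fixed-sequence search. It assumes that per-example losses in $[0,1]$ can be evaluated on a calibration set for any threshold. In place of the p-values defined in §4.2, it uses a plain Hoeffding p-value, $p_j = \exp\big(-2n\,(\delta - \widehat{E}(\lambda_j))_+^2\big)$. That p-value is super-uniform under $H_j: \mathbb{E}[L(\lambda_j)] > \delta$ but is generally looser. The function names and the toy loss function are hypothetical.

```python
import math
from typing import Callable, List, Sequence


def hoeffding_p_value(emp_avg: float, delta: float, n: int) -> float:
    """Super-uniform p-value for H_j: E[L(lambda_j)] > delta, losses in [0, 1].

    By Hoeffding's inequality, P(E_hat <= delta - s) <= exp(-2 n s^2) under H_j.
    """
    gap = max(0.0, delta - emp_avg)
    return math.exp(-2.0 * n * gap * gap)


def fixed_sequence_calibrate(
    lambdas: Sequence[float],
    losses_at: Callable[[float], Sequence[float]],
    delta: float,
    epsilon: float,
) -> float:
    """Test thresholds from high to low and stop at the first non-rejection.

    lambdas must be descending (lambda_1 > ... > lambda_k); losses_at(lam)
    returns the calibration-set losses L_i(lam), each in [0, 1].
    """
    calibrated = 1.0  # returned if H_1 is not rejected (most conservative)
    for lam in lambdas:
        losses = losses_at(lam)
        p = hoeffding_p_value(sum(losses) / len(losses), delta, len(losses))
        if p > epsilon:
            break  # fail to reject H_j: keep lambda_{j-1}
        calibrated = lam  # H_j rejected: lam is certified, try a lower one
    return calibrated


def toy_losses(lam: float, n: int = 1000) -> List[float]:
    """Hypothetical losses: fully consistent down to 0.5, inconsistent below."""
    return [0.0] * n if lam >= 0.5 else [1.0] * n


grid = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]
print(fixed_sequence_calibrate(grid, toy_losses, delta=0.1, epsilon=0.05))  # 0.5
print(fixed_sequence_calibrate(grid, lambda lam: toy_losses(lam, n=10),
                               delta=0.1, epsilon=0.05))                    # 1.0
```

The second call illustrates the role of calibration-set size. With only $n = 10$ examples, even zero observed loss gives $p = e^{-0.2} \approx 0.82 > \epsilon$, so no threshold below 1 can be certified.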

Experimental Setting

We empirically evaluate our methods on three popular text generation tasks that vary in their target generation length and in how extractive the target is with respect to the input. CNN/DM is a collection of news articles to be summarized in a few sentences. WMT15 EN-FR contains English sentences (one per example) to be machine translated to French. Open-book SQuAD 1.1 is a QA dataset with Wikipedia paragraphs paired with questions, where the target answer is a text span from the input. Length statistics of the validation sets are summarized in Table 1.

Model. We implement CALM on top of the T5 encoder-decoder model, which showed good performance on the tasks above, using the T5X framework. We use the 8-layer T5 1.1 model, which doesn't share input and output embeddings. We share all output embeddings for the softmax predictions, and the early-exit classifier, across all decoder layers. Based on validation results, we set the temperature of our decaying threshold to $\tau = 4$ for the softmax and classifier measures of CNN/DM and WMT. In other settings, we use $\tau = 0$. See App. C for more details, and App. B.3 for a 12-layer T5 model.

Our main efficiency metric is the average number of decoder layers used per output token, as it directly measures complexity reduction without conflating it with implementation- or infrastructure-specific details. For reference, we also report the average reduction in decoder FLOPs per token. We also compute an estimated speedup of the whole encoder-decoder model for generating the full sequence, based on TPUv3 benchmarking with 200 examples in Colab (see App. C for details).

Baselines. We emphasize that the CALM framework is general for any autoregressive multi-layered LM with any confidence measure, allowing controlled consistency by Eq. (1) or Eq. (2). To empirically evaluate the efficiency gains enabled by our proposed confidence measures, we compare with static baselines that use the same number of layers for all tokens. We also compare our early-exit classifier training with the geometric-like method of Elbayad et al. in Appendix D. Finally, we compute an oracle local measure (§3.3.1) as an upper-bound estimate of the performance-efficiency tradeoff.

Experimental Results

We first report the empirical performance-efficiency tradeoff achieved with each confidence measure. For each task and measure, we evaluate the full range of $\lambda$ on the validation set, with steps of $0.05$. The results, presented in Figure 3, show the power of the softmax response measure, which incurs only minor performance loss while skipping more than half of the layers in all three tasks. The early-exit classifier, which is more FLOP-efficient, is also effective, mostly when targeting high performance (right-hand side of the plots). The simple and parameter-free state saturation measure is competitive, but often falls below the static baseline, despite enabling per-token exit decisions.

The dynamic oracle obtains compelling efficiency gains, using only 1.5, 1.3, and 1.2 layers on average for summarization, WMT, and QA, respectively, without losing any performance. This illustrates the full potential of CALM and leaves further room for improvements with better confidence measures. It also shows the effectiveness of inference-time state propagation for skipped layers (§3.3.1).

Next, we examine the outcomes of the calibration process. Since the obtained risk is guaranteed to be valid (i.e., $\leq \delta$ at least 95% of the time), we focus here on the efficiency gains for each chosen $\delta$. We refer the reader to Appendix B for empirical validation and for additional results and qualitative examples.

Table 2 presents the efficiency gains per choice of $\delta$ for each consistency objective and confidence measure. We examine larger $\delta$ values for textual consistency, as this is generally a stricter requirement since the full model's error is not considered.

Across all settings, the softmax confidence measure leads to the greatest decrease in the number of decoder layers required. Accordingly, softmax mostly enables the highest speedups, of up to about three times faster than running through all of the model's layers. The very lightweight early-exit classifier sometimes provides better speedups than softmax, even if more decoder layers are used. Since the speedup is computed over the full generated output, we see larger gains on the longer outputs of summarization and translation, where decoding takes most of the time. The gains are smaller on the short QA outputs, where the whole decoding time is not much longer than the encoding time.

These encouraging efficiency gains are enabled even with the rigorous performance guarantees that are sometimes conservative (e.g., Eq. (11)). We note that relaxing these constraints, or tightening the confidence intervals (e.g., with larger calibration sets), can further improve the empirical gains.

The softmax operation over the full output vocabulary is FLOP-heavy (though this compute can potentially be parallelized), sometimes leading to increased total FLOPs even when fewer layers are used. The state-based and early-exit classifier measures require minimal FLOPs and provide a good alternative with compelling efficiency gains if total FLOPs (parallelizable or not) are a concern.

6.2 Example output: effectively distributing the model’s capacity across timesteps

Conclusion

We present confident adaptive language modeling (CALM) for dynamically allocating different amounts of compute per generated token, following explicitly defined tolerance levels on the full generation output. This paper covers both modeling solutions and analyses towards this goal, as well as a theoretically grounded framework for provably controlling the quality of the full output to meet the user-specified tolerance levels. We investigate the effects of local early exiting during decoding on the final output, leading us to propose a decaying function over the initial threshold that enables finer control over the performance-efficiency tradeoffs without inflating the search space. We also study different solutions for addressing the missing computations of early-exited tokens that future tokens depend upon. Overall, our complete adaptive compute framework for LMs requires minimal modifications to the underlying model and enables efficiency gains while satisfying rigorous quality guarantees for the output. Also, our oracle experiments and runtime analysis demonstrate the full potential of this framework and leave room for future research to further improve efficiency in a controllable way.

Acknowledgements

We thank Ionel Gog for significantly improving the implementation after submission. We also thank Anselm Levskaya, Hyung Won Chung, Seungyeon Kim, Tao Wang, Paul Barham, and Michael Isard for great discussions and code suggestions. We thank Orhan Firat, Carlos Riquelme, Aditya Menon, Zhifeng Chen, Sanjiv Kumar, and Jeff Dean for helpful discussions and feedback on the project.

References

Appendix A Mathematical Details

Appendix B Additional Results

We provide additional experimental results to supplement Section 6. In Section B.1, we include calibration plots for both the validation and test sets, with the full range of $\delta$, also showing the standard deviation across random trials. In Section B.2, we present a few example outputs with a visualization of the per-token early-exit decisions to illustrate CALM’s behavior. In Section B.3, we include results of a larger 12-layer model, showing the generalizability of our framework to other configurations.

We present complementary results to Table 2. Figure B.1 and Figure B.2 present the empirical consistencies and efficiency gains for textual and risk consistency constraints, respectively. Figure B.3 and Figure B.4 report the same on the validation datasets. First, we observe that the calibration holds empirically, achieving risk values that are not greater than the specified $\delta$ (i.e., lying under the diagonal in the upper subplots). We also see that the risk is often lower than allowed for (a good thing), especially with the risk consistency objective. This is due to the conservativeness of our measure (Eq. (11)), which does not reward instances where the early prediction has lower risk. While obtaining lower risk than the target is not a downside, this indicates that there is further potential for improving the efficiency gains achieved per $\delta$. Yet, even with the rigorous and conservative theoretical guarantees, we already obtain significant efficiency gains that, naturally, increase with larger tolerance values.

B.2 Qualitative examples

Figure B.5 presents two example outputs of CALM for instances from the machine translation and question-answering (QA) datasets (see Figure 4 for summarization). The colors depict the number of decoder layers used for generating each output token. We also report the risk values for textual and risk consistency of both outputs, as well as the speedup compared to the full model. We observe that the textual distance generally increases as we accelerate the decoding. However, the outputs remain relatively similar to the full model's even when using very few layers. The risk consistency doesn't always correlate with the textual one when the full model's risk is non-zero. In some cases, the accelerated output even has lower risk than the full model's output. This demonstrates the value of having both our textual and risk consistency configurations, which the user can pick from based on their objective and on whether quality reference outputs for calibration are available.

Interestingly, in line with our initial intuition, CALM distributes the compute unevenly, using very few layers for certain “easy” tokens and additional compute for “hard” tokens. Examining the examples, we see that “hard” generation steps often come at the beginning of sentences, or when generating a verb. We leave further investigation of perceived difficulty to future work.

B.3 T5-base results

Throughout the rest of this paper, we experiment with an 8-layer encoder-decoder T5 model. Here, we include results for a 12-layer T5-base model, which, besides the additional layers, is also larger in its internal dimensions: it has 12 attention heads and dimensions of 64, 768, and 2048 for the attention head, embeddings, and MLP, respectively.

Figure B.6 shows the empirical performance-efficiency tradeoffs achieved with this model on the three tasks. Overall, the trends are very similar to those observed with the 8-layer model (Figure 3). One exception is SQuAD, where the static baseline that uses only one or two decoder layers completely fails. This suggests that the actual predictions of this model only start to form from the third layer. Also, the local oracle measure on SQuAD obtains slightly lower global performance than the full model. This also suggests that, in this case, the hidden state of the very low layers might not be a good enough representation for follow-up generations. Yet, the softmax and early-exit classifier confidence measures provide good proxies for consistency and often outperform the static baselines. In the other two datasets, the local oracle matches the performance of the full model, similar to the behavior of the 8-layer model.

Figure B.7 and Figure B.8 present the validity and efficiency gains of our calibration procedure on the 12-layer model for textual and risk consistency objectives, respectively. We observe a largely similar behavior as the 8-layer model, showing the generality of our framework to other configurations of the backbone language model.

Appendix C Implementation Details

As mentioned in Section 5, we build on the T5 encoder-decoder model [Raffel et al., 2019], and use the T5X repository [Roberts et al., 2022] for implementing CALM. Appendix E describes the main algorithmic components.

Our main experiments use the T5 1.1 version with 8 layers for both the encoder and decoder modules, 6 attention heads, and dimensions of 64, 512, and 1024 for the attention head, embeddings, and MLP, respectively. The vocabulary contains 32,128 tokens. This model doesn't share the input and output embeddings. For our early-exit heads, we share the output embeddings between all intermediate layers and the top one, without introducing any new parameters to the model. Our binary early-exit classifier is also shared across all layers, adding only a very small number of new parameters. We add early-exit heads to all layers.

We fine-tune the models on the training set of each task for a maximum of 500K steps, and choose the best checkpoint by performance on the validation set (using the full model's predictions). We use a batch size of 128, the regular LM cross-entropy loss, the AdaFactor optimizer [Shazeer and Stern, 2018], and experiment with learning rates of $10^{-3}$ and $10^{-4}$. We aggregate the loss of individual layers with a weighted average, as discussed in Section 3.4. For the early-exit classifier training, we use an unweighted average (see Appendix D for more details). We use 64 TPUv3 chips for training. For inference, we use a single TPUv4 chip with a batch size of one, simulating a one-at-a-time processing setting that is convenient when serving models for online requests. As described in Section 3.2, CALM exits early whenever the per-token confidence value $c$ (Section 3.5) exceeds the possibly decaying threshold $\lambda$ (Section 3.3.2). This threshold is derived from the user-defined $\delta, \epsilon$ tolerance levels and the textual or risk consistency objective (Section 4). If necessary, the hidden state is propagated to the skipped layers (Section 3.3.1).

We detail our procedures for approximating the reference efficiency gains using the FLOPs and speedup measures. For FLOPs computation, to be consistent with Elbayad et al. (See their Appendix B), we adopt their formula to compute the average decoder FLOPs per output token.

To measure the speedup of early exiting with CALM, we execute 200 inference predictions for each task under each examined configuration in a JIT-compiled function in Colab on a TPUv3. We ignore the time of the first execution, since it is drastically slower due to compilation, and average the rest of the measured times. For each inference prediction, we use batch size one, and measure the full generation time, including both the encoder and all decoding steps until completion. For each examined confidence measure (softmax, state, or classifier), we compute the speedup by comparing to the average inference time of the full model that uses all layers, obtained by setting the confidence threshold to its maximum. We note that our implementation adds conditionals between layers that add some overhead to the compute graph. Yet, even compared to a conditional-free implementation, the gains of CALM’s early exits often outweigh the overheads of the added conditionals. We leave further technical improvements to the implementation to future work.
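The timing protocol can be sketched in plain Python. The sketch times each call, drops the first (compilation) call, and averages the rest. The speedup is then the full model's mean latency divided by the early-exit model's. The function names are hypothetical. JAX dispatches work asynchronously, so a timed JAX function should block on its result, for example by calling `.block_until_ready()` on the returned array, for the measurement to include the computation itself.

```python
import time
from typing import Any, Callable, Sequence


def mean_latency(fn: Callable[[Any], Any], inputs: Sequence[Any]) -> float:
    """Average wall-clock seconds per call, discarding the first call
    (which, for a JIT-compiled function, includes compilation)."""
    if len(inputs) < 2:
        raise ValueError("need at least two inputs to discard the first call")
    times = []
    for x in inputs:
        start = time.perf_counter()
        fn(x)  # for JAX: fn should block until its output is ready
        times.append(time.perf_counter() - start)
    warm = times[1:]
    return sum(warm) / len(warm)


def speedup(full_model_latency: float, early_exit_latency: float) -> float:
    """Speedup of early exiting relative to running all layers."""
    return full_model_latency / early_exit_latency
```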

Appendix D Training Details of the Early Exit Classifier

Following the setup above, we compute the binary cross-entropy loss for each layer individually, and aggregate by taking the unweighted average across all layers. We use the loss value on the validation set to pick the best checkpoint. We also explore the geometric-like objective proposed by Elbayad et al. Their approach views the exiting decisions as a Bernoulli process, using the “exit”/“don’t exit” predicted probabilities. The goal is to make an “exit” prediction at the first true layer, determined by the oracle, and “don’t exit” predictions at all preceding layers. Accordingly, the training objective maximizes the probability of this oracle-guided event, modeled as a product of all respective predicted probabilities. In practice, due to the numerical instability of this product operation, we maximize the sum of the log probabilities.
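Both objectives can be sketched for a single token as follows. The sketch assumes a per-layer oracle label of 1 if layer $i$'s prediction already matches the top layer's and 0 otherwise; that label and the function names are assumptions for illustration. `exit_probs` holds the classifier's predicted "exit" probability at each layer, with layer 1 first. For example, with `exit_probs = [0.5, 0.5, 0.5]` and a first oracle exit at layer 2, the geometric-like loss is $2\log 2$ and does not depend on layer 3 at all.

```python
import math
from typing import Sequence

EPS = 1e-12  # guards against log(0)


def independent_bce(exit_probs: Sequence[float], oracle_labels: Sequence[int]) -> float:
    """Unweighted mean over layers of the binary cross-entropy between the
    predicted 'exit' probability and the per-layer oracle label."""
    losses = [
        -(y * math.log(p + EPS) + (1 - y) * math.log(1.0 - p + EPS))
        for p, y in zip(exit_probs, oracle_labels)
    ]
    return sum(losses) / len(losses)


def geometric_like_nll(exit_probs: Sequence[float], first_exit_layer: int) -> float:
    """Negative log-probability of 'don't exit' at layers 1..j-1 and 'exit' at
    layer j (1-based first oracle exit). Summing logs replaces the numerically
    unstable product of probabilities; layers after j contribute nothing."""
    j = first_exit_layer
    if not 1 <= j <= len(exit_probs):
        raise ValueError("first_exit_layer must be between 1 and the number of layers")
    log_prob = sum(math.log(1.0 - p + EPS) for p in exit_probs[: j - 1])
    log_prob += math.log(exit_probs[j - 1] + EPS)
    return -log_prob
```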

Table D.1 presents the $F_1$ validation scores of “don’t exit” predictions (with a $0.5$ threshold) by the early-exit classifier, measured against the oracle for layers 1, 4, and 7. Our per-layer independent training objective outperforms the geometric-like objective across all layers and tasks. The advantage is typically most pronounced for higher layers. We conjecture that this is due to the equal weighting of the independent objective, which uses signal from all layers, whereas the geometric-like objective only learns from layers up to the first oracle exit.

Appendix E Algorithms and Code

Algorithm 1 describes the calibration process of CALM for obtaining global textual or risk $(\delta, \epsilon)$ consistency (Section 4).

The JAX [Bradbury et al., 2018] code for training CALM models and for executing the early-exit functionality at inference-time is available at: https://github.com/google-research/t5x/tree/main/t5x/contrib/calm