Fast Inference from Transformers via Speculative Decoding
Yaniv Leviathan, Matan Kalman, Yossi Matias
Introduction
Large autoregressive models, notably large Transformers (Vaswani et al., 2017), are much more capable than smaller models, as has been evidenced countless times in recent years, e.g. in the text or image domains, by models like GPT-3 (Brown et al., 2020), LaMDA (Thoppilan et al., 2022), Parti (Yu et al., 2022), and PaLM (Chowdhery et al., 2022). Unfortunately, a single decode step from these larger models is significantly slower than a step from their smaller counterparts, and, making things worse, these steps are done serially: decoding $K$ tokens takes $K$ serial runs of the model.
Given the importance of large autoregressive models and specifically large Transformers, several approaches were developed to make inference from them faster. Some approaches aim to reduce the inference cost for all inputs equally (e.g. Hinton et al., 2015; Jaszczur et al., 2021; Hubara et al., 2016; So et al., 2021; Shazeer, 2019). Other approaches stem from the observation that not all inference steps are born alike - some require a very large model, while others can be approximated well by more efficient models. These adaptive computation methods (e.g. Han et al., 2021; Sukhbaatar et al., 2019; Schuster et al., 2021; Scardapane et al., 2020; Bapna et al., 2020; Elbayad et al., 2019; Schwartz et al., 2020) aim to use less compute resources for easier inference steps. While many of these solutions have proven extremely effective in practice, they usually require changing the model architecture, changing the training-procedure and re-training the models, and don’t maintain identical outputs.
The key observation above, that some inference steps are “harder” and some are “easier”, is also a key motivator for our work. We additionally observe that inference from large models is often not bottlenecked on arithmetic operations, but rather on memory bandwidth and communication, so additional computation resources might be available. Therefore we suggest increasing concurrency as a complementary approach to using an adaptive amount of computation. Specifically, we are able to accelerate inference without changing the model architectures, without changing the training-procedures or needing to re-train the models, and without changing the model output distribution. This is accomplished via speculative execution.
Speculative execution (Burton, 1985; Hennessy & Patterson, 2012) is an optimization technique, common in processors, where a task is performed in parallel to verifying if it’s actually needed - the payoff being increased concurrency. A well-known example of speculative execution is branch prediction. For speculative execution to be effective, we need an efficient mechanism to suggest tasks to execute that are likely to be needed. In this work, we generalize speculative execution to the stochastic setting - where a task might be needed with some probability. Applying this to decoding from autoregressive models like Transformers, we sample generations from more efficient approximation models as speculative prefixes for the slower target models. With a novel sampling method, speculative sampling, we maximize the probability of these speculative tasks to be accepted, while guaranteeing that the outputs from our system have the same distribution as those from the target model alone. For example, the sentence in Figure 1, consisting of 38 tokens, was generated by our method with only 9 serial runs of a larger target model (97M parameters) thanks to a smaller and more efficient approximation model (6M parameters), while the probability of generating it is unchanged.
We analyze our method in a variety of tasks and model sizes: unconditional generation from a 97M parameter GPT-like model trained on lm1b, English to German translation and news article summarization with an 11B parameters T5-XXL model, and a dialog task with a 137B parameter LaMDA model. We implement our method and compare actual walltimes for T5-XXL to those of the robust T5X implementation (Roberts et al., 2022), showing an out-of-the-box latency improvement of 2X-3X, without any change to the outputs (Section 4).
Our method is easy to employ in actual production settings, doesn’t require training new models, and doesn’t change the outputs. Therefore, in common situations where memory bandwidth is the bottleneck, and compute resources are available, it may be a good default to accelerate sampling from autoregressive models like Transformers.
To summarize, our main contributions are: (1) A generalization of speculative execution to the stochastic setting, with a novel sampling method we call speculative sampling, and (2) A decoding mechanism we call speculative decoding that can accelerate decoding from autoregressive models, without any change to the model architectures, training regimes and output distributions.
Speculative Decoding
2.2 Standardized Sampling
First, note that while there are many methods and parameters of sampling, like argmax, top-k, nucleus, and setting a temperature, and popular implementations usually treat them differently at the logits level, they can all easily be cast into standard sampling from an adjusted probability distribution. For example, argmax sampling is equivalent to zeroing out non-max elements of the distribution and normalizing. We can therefore deal only with standard sampling from a probability distribution, and cast all of the other types of sampling into that framework. Going forward we’ll assume that $p(x)$ and $q(x)$ are the distributions from the target model and the approximation model respectively, adjusted for the sampling method.
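The casting described above can be sketched in a few lines. This is a minimal illustration, not the implementation used in the experiments: the function name `standardize` and its parameters are assumptions introduced here, and only temperature, argmax, and top-k are covered.

```python
import math

def standardize(logits, temperature=1.0, top_k=None):
    """Return the adjusted distribution that plain sampling should draw from.

    Hypothetical helper: temperature == 0 is treated as argmax sampling.
    """
    n = len(logits)
    if temperature == 0:
        # Argmax: zero out the non-max elements and normalize.
        best = max(range(n), key=lambda i: logits[i])
        return [1.0 if i == best else 0.0 for i in range(n)]
    scaled = [value / temperature for value in logits]
    if top_k is not None:
        # Top-k: keep the k largest logits, give the rest zero probability.
        keep = set(sorted(range(n), key=lambda i: scaled[i], reverse=True)[:top_k])
        scaled = [s if i in keep else -math.inf for i, s in enumerate(scaled)]
    # Numerically stable softmax over the remaining logits.
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    total = sum(weights)
    return [w / total for w in weights]
```

Once both models' outputs have been passed through the same function, the rest of the method only ever sees two ordinary probability distributions.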
2.3 Speculative Sampling
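To sample $x \sim p(x)$, we instead sample $x \sim q(x)$, keeping it if $q(x) \leq p(x)$, and in case $q(x) > p(x)$ we reject the sample with probability $1 - \frac{p(x)}{q(x)}$ and sample $x$ again from an adjusted distribution $p'(x) = \mathrm{norm}(\max(0, p(x) - q(x)))$ instead. It’s easy to show (see Section A.1) that for any distributions $p(x)$ and $q(x)$, and $x$ sampled in this way, indeed $x \sim p(x)$.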
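The sampling rule in this paragraph can be checked empirically. The sketch below assumes the adjusted distribution is the normalized positive part of $p(x) - q(x)$. The function names and the two toy distributions are illustrative choices, not taken from the experiments.

```python
import random

def adjusted_distribution(p, q):
    """Normalized positive part of p - q, used only after a rejection."""
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    return [r / total for r in residual]

def speculative_sample(p, q, rng):
    """Return (token, accepted): token is distributed as p, guessed from q."""
    x = rng.choices(range(len(q)), weights=q)[0]
    # Keep the guess if q(x) <= p(x), otherwise keep it with probability p(x)/q(x).
    if q[x] <= p[x] or rng.random() < p[x] / q[x]:
        return x, True
    # Rejected: sample again from the adjusted distribution.
    x = rng.choices(range(len(p)), weights=adjusted_distribution(p, q))[0]
    return x, False

# Toy check: the empirical frequencies should approach p, not q.
rng = random.Random(0)
p = [0.6, 0.3, 0.1]
q = [0.2, 0.5, 0.3]
n = 50_000
counts = [0, 0, 0]
accepted_count = 0
for _ in range(n):
    token, accepted = speculative_sample(p, q, rng)
    counts[token] += 1
    accepted_count += accepted
freqs = [count / n for count in counts]
accept_rate = accepted_count / n
```

For these toy distributions the fraction of accepted guesses should be close to $\sum_x \min(p(x), q(x)) = 0.6$.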
Given the distribution $q(x)$ obtained from running the approximation model on a conditioning prefix, we can sample a token $x_1 \sim q(x)$. We then calculate the distribution $p(x)$ by running the target model on the prefix while in parallel speculatively calculating the distribution of the next token $x_2$ by running the target model on the prefix followed by $x_1$. Once both computations complete, we proceed as per above: if $x_1$ is rejected, we discard the computation of $x_2$ and re-sample $x_1$ from an adjusted distribution, and if $x_1$ is accepted, we keep both tokens. Algorithm 1 generalizes this idea to sample between 1 and $\gamma + 1$ tokens at once.
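One way to picture the generalization to several guesses is the sketch below. It is a simplified reading of the procedure described above, not a reproduction of Algorithm 1: models are assumed to be plain functions mapping a token prefix to a distribution, the target model's evaluations are written as a sequential list comprehension although a real system would batch them into one parallel call, and all names are hypothetical.

```python
import random

def sample(dist, rng):
    return rng.choices(range(len(dist)), weights=dist)[0]

def speculative_decoding_step(target, approx, prefix, gamma, rng):
    """Extend prefix by between 1 and gamma + 1 tokens."""
    # 1. Draft gamma guesses autoregressively from the approximation model.
    guesses, qs = [], []
    for _ in range(gamma):
        q = approx(prefix + guesses)
        qs.append(q)
        guesses.append(sample(q, rng))
    # 2. Evaluate the target model on all gamma + 1 prefixes
    #    (one parallel call in a real system).
    ps = [target(prefix + guesses[:i]) for i in range(gamma + 1)]
    # 3. Accept guesses left to right until the first rejection.
    n = 0
    while n < gamma:
        x = guesses[n]
        if rng.random() >= min(1.0, ps[n][x] / qs[n][x]):
            break
        n += 1
    # 4. Emit one more token: from the adjusted distribution after a
    #    rejection, or from the target's next-token distribution otherwise.
    dist = ps[n]
    if n < gamma:
        residual = [max(0.0, a - b) for a, b in zip(ps[n], qs[n])]
        total = sum(residual)
        dist = [r / total for r in residual]
    return prefix + guesses[:n] + [sample(dist, rng)]

# Toy usage with context-free stand-in "models" over a 3-token vocabulary.
def toy_target(prefix):
    return [0.6, 0.3, 0.1]

def toy_approx(prefix):
    return [0.2, 0.5, 0.3]

rng = random.Random(0)
tokens, runs = [0], 0
while len(tokens) < 40:
    tokens = speculative_decoding_step(toy_target, toy_approx, tokens, 4, rng)
    runs += 1
```

Here `runs` counts the serial calls to the target model, which is the quantity the analysis that follows is concerned with.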
Analysis
Let’s analyze the reduction factor in the number of serial calls to the target model, or equivalently, the expected number of tokens produced by a single run of Algorithm 1.
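Let $\beta$ denote the acceptance rate (Definition 3.1): the probability that speculative sampling accepts the token sampled from $q(x)$, given a prefix. $E(\beta)$ is then a natural measure of how well the approximation model approximates the target model. If we make the simplifying assumption that the $\beta$s are i.i.d., and denote $\alpha = E(\beta)$, then the number of tokens produced by a single run of Algorithm 1 is a capped geometric variable, with success probability $1 - \alpha$ and cap $\gamma + 1$, and the expected number of tokens generated by Algorithm 1 satisfies Equation 1: $E(\#\text{generated tokens}) = \frac{1-\alpha^{\gamma+1}}{1-\alpha}$. See Figure 2.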
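The closed form for the expected number of tokens can be recovered in two lines. The following is a sketch under the i.i.d. assumption, writing $N$ (a symbol introduced here for illustration) for the number of tokens produced by one run of Algorithm 1, $\alpha$ for the expected acceptance rate, and $\gamma$ for the number of guesses. A run produces more than $k$ tokens exactly when the first $k$ guesses are all accepted:

```latex
\begin{align*}
P(N > k) &= \alpha^{k}, \qquad 0 \le k \le \gamma, \\
E(N) &= \sum_{k=0}^{\gamma} P(N > k)
      = \sum_{k=0}^{\gamma} \alpha^{k}
      = \frac{1-\alpha^{\gamma+1}}{1-\alpha}.
\end{align*}
```

For instance, $\alpha = 0.8$ and $\gamma = 4$ give $(1 - 0.8^{5})/0.2 = 3.3616$ tokens per run on average.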
3.2 Calculating α
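We’ll now derive a simple formula for calculating $\alpha$ given a prefix and the two models. We start by defining a natural divergence $D_{LK}$:
$$D_{LK}(p, q) = \sum_x |p(x) - M(x)| = \sum_x |q(x) - M(x)|$$
where $M(x) = \frac{p(x) + q(x)}{2}$.
Lemma 3.3. $D_{LK}(p, q) = 1 - \sum_x \min(p(x), q(x))$.
Proof. $D_{LK}(p, q) = \sum_x |p(x) - M(x)| = \sum_x \frac{|p(x) - q(x)|}{2} = 1 - \sum_x \frac{p(x) + q(x) - |p(x) - q(x)|}{2} = 1 - \sum_x \min(p(x), q(x))$. ∎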
From Lemma 3.3 we immediately get the following results:
$$\beta = E_{x\sim q(x)} \begin{cases} 1 & q(x) \leq p(x) \\ \frac{p(x)}{q(x)} & q(x) > p(x) \end{cases} = E_{x\sim q(x)} \min\left(1, \frac{p(x)}{q(x)}\right) = \sum_x \min(p(x), q(x))$$ ∎
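The relation between the acceptance rate and the divergence is easy to check numerically. The sketch below assumes the divergence is $\sum_x |p(x) - M(x)|$ with $M(x) = (p(x) + q(x))/2$. The function names and the toy distributions are illustrative.

```python
def d_lk(p, q):
    """Divergence: sum over x of |p(x) - M(x)|, with M the midpoint of p and q."""
    midpoint = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return sum(abs(pi - mi) for pi, mi in zip(p, midpoint))

def acceptance_rate(p, q):
    """beta = sum over x of min(p(x), q(x)) for one prefix."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))

p = [0.6, 0.3, 0.1]
q = [0.2, 0.5, 0.3]
beta = acceptance_rate(p, q)       # 0.2 + 0.3 + 0.1
one_minus_d = 1.0 - d_lk(p, q)     # should agree with beta
```

Averaging `acceptance_rate` over the prefixes encountered during generation gives an empirical estimate of $\alpha$.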
See Table 3 for empirically observed $\alpha$ values in our experiments.
3.3 Walltime Improvement
We’ve shown that with the i.i.d. assumption our algorithm reduces the number of calls to the target model by a factor of $\frac{1-\alpha^{\gamma+1}}{1-\alpha}$. Note that speculative execution in general, and our algorithm in particular, assume that we have enough compute resources to support the increased concurrency (Section 3.4). For the walltime analysis, we’ll assume that we can run $\gamma + 1$ concurrent evaluations of the target model in parallel without increasing the walltime. To get the total walltime improvement, we now consider the cost of running the approximation model.
Let $c$, the cost coefficient, be the ratio between the time for a single run of the approximation model and the time for a single run of the target model.
Note that unlike $\alpha$, which is an intrinsic property of the models and the task, the value of $c$ depends on the hardware configuration and software implementation details. In our experiments where the approximation model is typically a couple of orders of magnitude smaller than the target model, $c$ was always less than 0.05 and often negligibly close to 0.
Theorem 3.8. The expected improvement factor in total walltime by Algorithm 1 is $\frac{1-\alpha^{\gamma+1}}{(1-\alpha)(\gamma c+1)}$.
Proof. Denote the cost of running a single step of the target model by $T$. Now, each run of Algorithm 1 costs $Tc\gamma + T$ (for running the approximation model $\gamma$ times and running the target model once) and according to Equation 1 produces $\frac{1-\alpha^{\gamma+1}}{1-\alpha}$ tokens on average. So the overall expected cost for producing a token with Algorithm 1 is $\frac{(c\gamma+1)(1-\alpha)}{1-\alpha^{\gamma+1}} T$. Since the cost of producing a single token with the standard decoding algorithm is $T$, we get the desired result. ∎
Note that Theorem 3.8 assumes long enough generations (for example, since we run the target model at least once, the improvement factor is capped by the number of generated tokens).
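If $\alpha > c$, there exists $\gamma$ for which we’ll get an improvement, and the improvement factor will be at least $\frac{1+\alpha}{1+c}$.
Proof. If we get an improvement for $\gamma$, we’d also get an improvement for any $0 < \gamma^* < \gamma$, so for our method to yield an improvement, we can evaluate Theorem 3.8 for $\gamma = 1$, yielding $\frac{1-\alpha^2}{(1-\alpha)(c+1)} = \frac{1+\alpha}{1+c}$, which is greater than 1 exactly when $\alpha > c$. ∎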
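The walltime formula of Theorem 3.8 is simple enough to evaluate directly. The sketch below assumes the improvement factor is the expected number of tokens per run divided by the relative cost of a run, $\gamma c + 1$, and that $\alpha < 1$. The function names are illustrative.

```python
def expected_tokens(alpha, gamma):
    """Expected tokens per run of Algorithm 1 under the i.i.d. assumption."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

def walltime_improvement(alpha, gamma, c):
    """Expected walltime improvement factor, assuming enough parallel compute."""
    return expected_tokens(alpha, gamma) / (gamma * c + 1)

# With a single guess the factor reduces to (1 + alpha) / (1 + c),
# so it exceeds 1 exactly when alpha > c.
single_guess = walltime_improvement(0.8, 1, 0.1)
# A negligible-cost approximation model (c = 0) with alpha = 0.2 and gamma = 3.
bigram_like = walltime_improvement(0.2, 3, 0.0)
```

The second call evaluates to about 1.248, in line with the roughly 1.25X figure reported for the bigram approximation model.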
3.4 Number of Arithmetic Operations
Algorithm 1 does $\gamma + 1$ runs of the target model in parallel, so the number of concurrent arithmetic operations grows by a factor of $\gamma + 1$. Now, since Algorithm 1 produces at most $\gamma + 1$ tokens per run, the total number of arithmetic operations might be higher than that of the standard decoding algorithm. When we accept the sample from the approximation model, the increased concurrency is “free” and the total number of operations isn’t increased (neglecting the cost of the approximation model). When we reject a guess though, computation is wasted. Let’s now analyze the effect of our method on the total number of arithmetic operations.
Let $\hat{c}$ be the ratio of arithmetic operations per token of the approximation model to that of the target model.
The expected factor of increase in the number of total operations of Algorithm 1 is $\frac{(1-\alpha)(\gamma\hat{c}+\gamma+1)}{1-\alpha^{\gamma+1}}$.
Proof. Denote by $\hat{T}$ the number of arithmetic operations done by a standard decoding baseline per token, i.e. the number of operations of a single run of the target model. Then a single iteration of Algorithm 1 costs $\hat{T}\hat{c}\gamma + \hat{T}(\gamma+1)$ operations (for $\gamma$ runs of the approximation model and $\gamma + 1$ parallel runs of the target model). Dividing by the expected number of tokens produced by Algorithm 1, i.e. Equation 1, and by $\hat{T}$, we get the desired result. ∎
If $\alpha$ is low, the increase in the number of arithmetic operations is high, and vice-versa. Note that for Transformer decoders, the total number of arithmetic operations by Algorithm 1 (not counting runs of the approximation model) can be bounded from above by a single run of the same-size Transformer encoder.
Unlike the total number of arithmetic operations, the total number of memory accesses can go down with our method. Specifically, the target model’s weights and KV cache can be read once per execution of Algorithm 1, so the number of memory accesses for reading them shrinks by a factor of $\frac{1-\alpha^{\gamma+1}}{1-\alpha}$, according to Equation 1.
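The two effects discussed in this subsection, more arithmetic but fewer reads of the target model's weights, can be put side by side. The sketch below assumes the operation count per run is $\gamma$ approximation-model runs plus $\gamma + 1$ target-model runs, divided by the expected tokens per run. The function names are illustrative and `c_hat` stands for the operations ratio defined above.

```python
def expected_tokens(alpha, gamma):
    """Expected tokens per run of Algorithm 1 under the i.i.d. assumption."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

def operations_increase(alpha, gamma, c_hat):
    """Expected factor of increase in total arithmetic operations."""
    per_run = gamma * c_hat + gamma + 1   # in units of one target-model run
    return per_run / expected_tokens(alpha, gamma)

def weight_reads_reduction(alpha, gamma):
    """Factor by which reads of the target model's weights and KV cache shrink."""
    return expected_tokens(alpha, gamma)

ops = operations_increase(0.5, 1, 0.0)    # 2 target runs for 1.5 tokens on average
reads = weight_reads_reduction(0.5, 1)    # weights read once per 1.5 tokens
```

A low acceptance rate pushes the first factor up while leaving the second close to 1, which is the trade-off the text describes.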
3.5 Choosing γ
Given $c$ and $\alpha$ and assuming enough compute resources (see Section 3.4), the optimal $\gamma$ is the one maximizing the walltime improvement equation (Theorem 3.8): $\frac{1-\alpha^{\gamma+1}}{(1-\alpha)(\gamma c+1)}$. Since $\gamma$ is an integer, it can be easily found numerically, see Figure 3.
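Since $\gamma$ is a small integer, a brute-force search is enough. The sketch below assumes the objective is the walltime improvement factor of Theorem 3.8. The search bound `max_gamma` and the function names are arbitrary choices made here, and $\gamma = 0$ is included to represent plain decoding.

```python
def walltime_improvement(alpha, gamma, c):
    """Expected walltime improvement factor (Theorem 3.8), for alpha < 1."""
    return (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (gamma * c + 1))

def optimal_gamma(alpha, c, max_gamma=100):
    """Integer gamma in [0, max_gamma] maximizing the improvement factor."""
    return max(range(max_gamma + 1),
               key=lambda gamma: walltime_improvement(alpha, gamma, c))

best = optimal_gamma(0.8, 0.05)
no_gain = optimal_gamma(0.1, 0.5)   # alpha <= c: plain decoding is best
```

When $\alpha \le c$ the search falls back to $\gamma = 0$, matching the condition that an improvement requires $\alpha > c$.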
Table 1 and Figure 4 illustrate the trade-off between inference speed and the total number of arithmetic operations for various values of $\alpha$ and $\gamma$, assuming $c = \hat{c} = 0$. Figure 5 shows a simplified trace diagram.
Instead of picking a single value for $\gamma$ based on $\alpha$, since the $\beta$s aren’t constant, we could get further improvement by predicting the value of $\beta$ and accordingly varying the value of $\gamma$ during the run of Algorithm 1. To get an upper bound on the additional improvement factor, assume we had an oracle for $\gamma$. We would then have $E(\#\text{generated tokens}) = \frac{1}{1-\alpha}$. For typical values of $c$ and $\alpha$, and assuming unbounded compute resources, the enhanced walltime improvement factor can be up to 60% higher than the improvement factor with a fixed $\gamma$. We leave exploring this for future work. (The above bound assumes that we still run the target model to verify the oracle’s predictions. If we skip those verifications the bound doesn’t hold and we would get a substantial additional improvement.)
3.6 Approximation Models
Speculative sampling, and therefore speculative decoding, guarantee an identical output distribution for any choice of approximation model without restriction (see Section A.1). In our experiments, we mostly tested existing off-the-shelf smaller Transformers as the approximation models. Further, we only tested approximation models of the same architecture as the target models and using the same probability standardization. In this setup, choosing the approximation model to be around two orders of magnitude smaller than the target model usually performed best, balancing $\alpha$ and $c$ (Theorem 3.8).
Another type of approximation models, negligible-cost models, are those for which $c \approx 0$, i.e. approximation models with a negligible cost relative to the target model. In this case, we get an expected walltime improvement of $\frac{1-\alpha^{\gamma+1}}{1-\alpha}$, which is bounded from above by $\frac{1}{1-\alpha}$ (we approach equality if $\gamma$ is large). One interesting type of negligible-cost approximation models are n-gram models, where the evaluation amounts to a table lookup. Interestingly, in empirical tests (Section 4.2) we get non zero $\alpha$s even for these trivial n-gram models. For example, for the English-German translation task, with the target model being T5-XXL 11B and the approximation model being a trivial bigram model, we get $\alpha \approx 0.2$ which leads to an inference speed improvement factor of 1.25X with $\gamma = 3$.
Other simple heuristics can be used as negligible-cost approximation models. For example, in cases where long sequences are likely to repeat, such as for summarization tasks or chat-like interfaces (e.g. where a user and a language model iterate on content, like text or code: “can you rewrite this story but change the ending”, “can you make this function also do X”), an approximation model that simply copies tokens from the context in case we find a matching prefix might yield high values of $\alpha$. These parameter-less approximation models have the additional advantage of being even simpler to deploy from a production standpoint.
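A copy-from-context heuristic of the kind mentioned above might look as follows. This is only one possible reading of "copies tokens from the context in case we find a matching prefix": the function name, the `match_len` parameter, and the choice of the most recent match are all assumptions made for illustration. Its guesses correspond to a deterministic (one-hot) approximation distribution.

```python
def copy_guesses(context, gamma, match_len=2):
    """Propose up to gamma tokens by copying what followed an earlier
    occurrence of the last match_len tokens of the context."""
    if len(context) <= match_len:
        return []
    suffix = context[-match_len:]
    # Scan from the most recent earlier position back to the start.
    for start in range(len(context) - match_len - 1, -1, -1):
        if context[start:start + match_len] == suffix:
            follow = start + match_len
            return context[follow:follow + gamma]
    return []   # no match: fall back to plain decoding for this step

guesses = copy_guesses([1, 2, 3, 4, 5, 1, 2], gamma=3)
```

In the usage line the context ends with `1, 2`, which also occurs at the start, so the tokens that followed it there are proposed as guesses.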
Another type of approximation models that can be used by speculative decoding are non-autoregressive models, like those from Stern et al. (2018). Then, instead of the autoregressive loop in Algorithm 1 we’d just call the non-autoregressive model once.
A final example, interesting mostly from a theoretical perspective, is an approximation model which chooses tokens at random, which guarantees some improvement (although very small) for all target models.
Experiments
We implement our algorithm and compare it to the implementation in the T5X codebase for accelerating T5-XXL.
We test a standard encoder-decoder T5 version 1.1 model (Raffel et al., 2020) on two tasks from the T5 paper: (1) English to German translation fine tuned on WMT EnDe, and (2) Text summarization fine tuned on CNN/DM. For both tasks, we use T5-XXL (11B) as the target model. For the approximation model we test several existing configurations, namely T5-large (800M), T5-base (250M), and T5-small (77M) (Raffel et al., 2020). We use existing checkpoints for all models. We measure walltime improvements with a batch size of 1 on a single TPU-v4 for both argmax sampling (temp=0) and standard sampling (temp=1).
Results
Table 2 shows the empirical results from our method. We see that T5-small (77M), with a good balance of $c$ and $\alpha$, provides the highest speedup out of the tested approximation models. As expected we see that $\alpha$ increases with the size of the approximation model. Interestingly, $\alpha$ and walltime improvement are higher for argmax sampling (temp=0). We observe speedups of 2.6X (temp=1) and 3.4X (temp=0) on the translation task and slightly lower speedups of 2.3X (temp=1) and 3.1X (temp=0) for the summarization task. These empirical results match well with the theoretical predictions, with some variance due to implementation details (see Section A.3).
4.2 Empirical α Values
While we only implemented our method for T5, we measured $\alpha$ values for various tasks, sampling methods, target models, and approximation models. Specifically, we evaluated the expectation from Corollary 3.6 on 10K tokens generated by the target model, for each of the settings below.
We test a decoder-only Transformer model on unconditional language generation, trained on lm1b (Chelba et al., 2013). The model here is a GPT-like Transformer decoder with Gelu activations (Hendrycks & Gimpel, 2016). For the approximation model we experimented with a Transformer decoder model with 6M parameters: dim 256, dim feed-forward 1024, 2 layers, 4 attention heads, as well as simple unigram and bigram models. The target model has 97M parameters: dim 768, dim feed-forward 3072, 12 layers, 12 attention heads. We used Bert tokenization (Devlin et al., 2019) with 8k tokens for all models.
LaMDA (137B params)
We tested a decoder-only LaMDA model on a dialog task (Thoppilan et al., 2022). We used existing checkpoints from LaMDA 137B as the target model and LaMDA 8B, LaMDA 2B, and LaMDA 100M as the approximation models.
See Section 4.1 for the setup of the T5-XXL (11B params) model.
Table 3 summarizes the $\alpha$ values for the tested cases. We observe that approximation models that are a couple of orders of magnitude smaller than the target model tend to produce $\alpha$ values between 0.5 and 0.9. Interestingly, we also note that for all models, the sharper the adjusted distribution, the higher the $\alpha$ values. Finally, we note that even trivial unigram and bigram approximations yield non negligible $\alpha$ values. For example, for the case of English to German translation, the bigram model has an $\alpha$ value of 0.2, and since $c = 0$ in this case, yields a 1.25X speed improvement, which is surprisingly high for this trivial approximation model (but is still lower than the speedup we get from using T5-small as the approximation model).
Related work
The efficiency of inference from large models has been studied extensively (Dehghani et al., 2021). Many approaches aim to speed up inference from large models in general, and autoregressive models like Transformers in particular. Numerous techniques try to make inference more efficient for all tokens, e.g. distillation (Hinton et al., 2015), sparsification (Jaszczur et al., 2021), quantization (Hubara et al., 2016), and architecture modification (So et al., 2021; Shazeer, 2019). Closer to our approach are adaptive computation methods which adapt the amount of computation to problem difficulty (Han et al., 2021). Examples include attending to a subset of the inputs (Sukhbaatar et al., 2019), and early exits (Schuster et al., 2021; Scardapane et al., 2020; Bapna et al., 2020; Elbayad et al., 2019; Schwartz et al., 2020). Notably, Wisdom of Committees (Schwartz et al., 2020) leverages off-the-shelf smaller models, but is an adaptive computation approach, and so it uses a heuristic to determine when to stop, losing the guarantee of identical outputs to those of the target models. In general, adaptive computation methods usually learn, either within the model itself or with an auxiliary model, when a computation shortcut can be taken. Usually, these methods save on both inference time and arithmetic operations, but require a change of architecture, a change of training procedure and training custom models or re-training of existing models. They usually also change the outputs of the model. We note that while many of the methods above improve the memory to arithmetic-operations ratio, in cases where the ratio remains high, these methods and our speculative decoding method might be effective in tandem.
Two prior methods leverage speculative execution for speeding up decoding from autoregressive models. Blockwise Parallel Decoding (Stern et al., 2018) decodes several tokens in parallel, similarly to our work. However, it only supports greedy decoding (temperature=0) and not the general stochastic setting, it requires additional training of a custom model, and focuses on preserving down-stream task quality, instead of guaranteeing identical outputs. Shallow Aggressive Decoding (SAD) (Sun et al., 2021) also decodes several tokens in parallel, similarly to our work. Unlike our work, SAD only supports copying the input to the output, and not general approximation models, making it only suitable for the cases where the inputs and outputs are very similar like grammatical error correction. In addition, similarly to Blockwise Parallel Decoding, SAD does not support the general stochastic sampling setting.
After we initially published our work, an independent implementation of speculative decoding (Chen et al., 2023) showed similar 2X-2.5X improvements on Chinchilla 70B.
Discussion
We presented speculative sampling which enables efficient stochastic speculative execution - i.e. speculative execution in the stochastic setting. We analyzed its impact on decoding from autoregressive models like Transformers via speculative decoding and have shown that given enough compute resources, we get meaningful 2X-3X speedups in practice vs T5X, a popular optimized implementation.
One limitation of speculative execution in general, and of speculative decoding in particular, is that latency is improved through increased concurrency at the cost of an increased number of arithmetic operations. Thus, our method is not helpful for configurations where additional computation resources are not available. However, in common cases where additional computation resources are available (e.g. when memory bandwidth is the bottleneck) our method provides the speedup with significant benefits: the model architecture doesn’t change, retraining isn’t required, and most importantly, the output distribution is guaranteed to stay the same. Our method is easy to implement, and can be used to speed up inference using out-of-the-box models without developing and evaluating custom schemes.
There are several directions for follow up research, importantly, further investigating the compatibility of speculative decoding with beam search (see Section A.4). Also, while our method yields substantial speedups with existing off-the-shelf approximation models, greater improvements might be obtained via custom approximation models (Section 3.6), such as those with custom architectures (e.g. custom sizes, non-autoregressive models, or various heuristics) or with custom training procedures (e.g. standard distillation with soft targets from the target model, or optimizing for $\alpha$ directly). It could also be interesting to explore a hierarchical version of the algorithm, where the approximation model is itself accelerated by an even faster model, which could allow for more capable approximation models. In this work we fixed the approximation model and the number of guesses $\gamma$ throughout inference, but varying them during inference could yield additional improvements (Section 3.5). In our experiments we always performed the same standardization on the distributions generated by the approximation model as the desired one for the target model (Section 2.2), but further improvements might be obtained by applying different transformations. We tested speculative decoding only in the text modality, but it might work well in other domains (e.g. images) which would be interesting to experiment with.
Finally, we note that stochastic speculative execution and speculative sampling can be helpful outside the scope of speculative decoding from autoregressive models. For example, given two slow functions, $f(x)$ and $g(y)$ such that $f(x)$ generates a distribution from which $g$’s input is sampled, we could use our method to run $f$ and $g$ in parallel. This setup might arise e.g. in physics simulations, or in reinforcement learning where $f$ is a large model that produces a distribution on actions, and $g$ is the world simulation, which would be interesting to explore.
Acknowledgments
We would like to extend a special thank you to YaGuang Li for help with everything LaMDA related and for calculating the LaMDA figures in the paper, and to Blake Hechtman for great insights and help with XLA. We would also like to thank the reviewers for insightful comments, as well as Asaf Aharoni, Reiner Pope, Sasha Goldshtein, Nadav Sherman, Eyal Segalis, Eyal Molad, Dani Valevski, Daniel Wasserman, Valerie Nygaard, Danny Vainstein, the LaMDA and Theta Labs teams at Google, and our families.
References
Appendix A Appendix
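We will now show that for any distributions $p(x)$ and $q(x)$, the tokens sampled via speculative sampling from $p(x)$ and $q(x)$ are distributed identically to those sampled from $p(x)$ alone. Let $\beta$ be the acceptance probability (Definition 3.1).
Note that as $p'(x) = \mathrm{norm}(\max(0, p(x) - q(x))) = \frac{p(x) - \min(q(x), p(x))}{\sum_{x'} (p(x') - \min(q(x'), p(x')))} = \frac{p(x) - \min(q(x), p(x))}{1 - \beta}$, the normalizing constant for the adjusted distribution $p'(x)$ is $1 - \beta$, where the last equation follows immediately from Lemma 3.3 and Theorem 3.5.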
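The remaining step of the argument can be written out as follows. This is a sketch that assumes the adjusted distribution is $p'(x) = (p(x) - \min(q(x), p(x)))/(1-\beta)$ and that a guess $x'$ drawn from $q$ is accepted with probability $\min(1, p(x')/q(x'))$. Splitting on whether the guess was accepted or rejected:

```latex
\begin{align*}
P(x = x') &= P(\text{accepted},\, x = x') + P(\text{rejected},\, x = x') \\
P(\text{accepted},\, x = x') &= q(x') \min\!\left(1, \frac{p(x')}{q(x')}\right)
   = \min\bigl(q(x'), p(x')\bigr) \\
P(\text{rejected},\, x = x') &= (1-\beta)\, p'(x')
   = p(x') - \min\bigl(q(x'), p(x')\bigr) \\
P(x = x') &= \min\bigl(q(x'), p(x')\bigr) + p(x') - \min\bigl(q(x'), p(x')\bigr) = p(x').
\end{align*}
```

The two $\min$ terms cancel, so the output distribution does not depend on $q$ at all. Only the acceptance rate does.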
A.2 Speculative Sampling vs. Rejection Sampling
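Rejection sampling is the following iterative sampling procedure that looks superficially similar to ours:
1. Sample $x \sim q(x)$ and $r \sim U(0, 1)$.
2. If $r < \frac{p(x)}{M q(x)}$ return $x$.
3. Go to 1.
Where $M = \max_x \frac{p(x)}{q(x)}$. We could employ a non-iterative version of rejection sampling instead of speculative sampling - specifically go through steps 1 and 2 above, and otherwise sample from an unmodified $p(x)$ directly. That would be much less efficient than our method though. Specifically, the expected accept probability here, $E_{x\sim q(x)} \frac{p(x)}{M q(x)} = \sum_x p(x) \min_{x'} \frac{q(x')}{p(x')} \leq \sum_x p(x) \min\left(1, \frac{q(x)}{p(x)}\right) = \sum_x \min(p(x), q(x))$, is (potentially much) lower than the expected accept probability in our method, $\sum_x \min(p(x), q(x))$.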
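The gap between the two accept probabilities is easy to see on a small example. The sketch below assumes the standard rejection sampling bound $M = \max_x p(x)/q(x)$, so that one round accepts with probability $1/M$, and assumes $q(x) > 0$ wherever $p(x) > 0$. The function names and distributions are illustrative.

```python
def rejection_accept_probability(p, q):
    """Expected accept probability of one rejection sampling round: 1 / M."""
    m = max(pi / qi for pi, qi in zip(p, q) if pi > 0)
    return 1.0 / m

def speculative_accept_probability(p, q):
    """Expected accept probability of speculative sampling: sum of min(p, q)."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))

p = [0.6, 0.3, 0.1]
q = [0.2, 0.5, 0.3]
rejection = rejection_accept_probability(p, q)       # M = 0.6 / 0.2 = 3
speculative = speculative_accept_probability(p, q)   # 0.2 + 0.3 + 0.1
```

A single token on which $q$ badly underestimates $p$ drives $M$ up and penalizes every guess in rejection sampling, whereas speculative sampling only loses the mass where the two distributions actually differ.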
A.3 Theoretical Predictions vs. Empirical Runtimes
Table 4 compares the expected runtime improvements based on Theorem 3.8 to the empirically measured runtimes from Table 2. We estimated the values of $c$ for the various models based on profiler traces. We can see that the theoretical predictions mostly match the measured runtimes. The larger differences are due to: (1) optimization differences between our implementation and the baseline, and (2) the simplifying assumption that the $\beta$s are i.i.d. being only an approximation (see Section 3.1).
A.4 Application to Beam Search
Our method can be applied, with some performance penalty, to beam search sampling. Given the original beam width , we can perform beam search with the approximation model and beam width for steps. Then, we can use to check all of the candidates in parallel (costing a compute budget of runs of ). Finally, for each step, we can accept the guesses of as long as to get identical results to regular beam search with alone (with a more elaborate procedure we could also accept cases where the candidates we got happen to have higher probabilities than those of alone). The analysis of our method in this setting is more involved and we leave it for future work.
A.5 Lenience
A strong property of Algorithm 1 is that the output distribution is guaranteed to remain unchanged. That said, if we’re willing to allow some changes, with nice guarantees, we can get further inference speed improvements. To further motivate this, note that when we train two models with identical architectures and sizes on the same dataset, the generated probability distributions will not be identical, so some lenience might make sense. Note that the results in this paper except for this section use the strictest version of Algorithm 1 and don’t allow lenience of any kind.
We could include a lenience parameter $l \in [0, 1]$ and multiply $q(x)$ by $l$ before comparing with $p(x)$ in Algorithm 1. This still maintains the nice guarantee that no token can be sampled with probability greater than $\frac{p(x)}{l}$. This means for example, that with $l = \frac{1}{10}$ no token can be sampled with more than 10X its ground truth probability, so we can guarantee that extremely rare tokens will remain extremely rare (there is no guarantee on the minimum probability, so lenience could hurt the diversity of the samples).
Specifically, with a lenience factor $l$ we have
$$\alpha = E_{x\sim q(x)} \begin{cases} 1 & l q(x) \leq p(x) \\ \frac{p(x)}{l q(x)} & l q(x) > p(x) \end{cases} = E_{x\sim q(x)} \frac{p(x)}{\max(p(x), l q(x))} = \sum_x \frac{p(x) q(x)}{\max(p(x), l q(x))} = \frac{1}{l} \sum_x \min(p(x), l q(x)) = \sum_x \min\left(\frac{p(x)}{l}, q(x)\right).$$
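The effect of the lenience factor on the acceptance rate can be sketched numerically. The code below evaluates the last two forms of the expression above for a single prefix. The function names and the toy distributions are illustrative, and `lenience` stands for $l$.

```python
def lenient_acceptance_rate(p, q, lenience):
    """sum over x of min(p(x) / l, q(x)); lenience = 1 recovers sum of min(p, q)."""
    return sum(min(pi / lenience, qi) for pi, qi in zip(p, q))

def lenient_acceptance_rate_alt(p, q, lenience):
    """Equivalent form: (1 / l) * sum over x of min(p(x), l * q(x))."""
    return sum(min(pi, lenience * qi) for pi, qi in zip(p, q)) / lenience

p = [0.6, 0.3, 0.1]
q = [0.2, 0.5, 0.3]
strict = lenient_acceptance_rate(p, q, 1.0)    # 0.2 + 0.3 + 0.1
lenient = lenient_acceptance_rate(p, q, 0.5)   # 0.2 + 0.5 + 0.2
```

Lowering `lenience` can only raise the acceptance rate, at the price of letting the output distribution drift from the target's.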
Table 5 shows values for different values of when is T5-XXL (11B) and is T5-small (77M). With , using lenience values of 1, 0.5, 0.3, and 0.1 (meaning that no token can be sampled with probability greater than 1X, 2X, 3X and 10X of the ground truth) we get improvement factors of 2.5X, 3.1X, 3.6X, and 5X respectively.
Note that when using temperature = 0 (i.e. argmax sampling), we can no longer use lenience as above. Instead, we could allow some lenience before standardizing the distributions. For example, we could accept the token $x$ sampled from the approximation model in case $p(x) \geq l \cdot \max(p)$. In this case, we measure similar empirical increases in $\alpha$ values to those with temperature = 1. For example, when using lenience values of 1, 0.5, 0.3, and 0.1 for T5-XXL and T5-small on English-German translation, we get $\alpha$ values of 0.75, 0.75, 0.8, 0.87. Taking for example $c = 0.015$ and $\gamma = 8$ we get speed improvement factors of 3.3X, 3.3X, 3.9X, and 4.9X respectively. (In this case, unlike in the standard sampling case shown in Table 5, a lenience factor of 0.5 doesn’t improve the speed-up.)