Fast Inference from Transformers via Speculative Decoding

Yaniv Leviathan, Matan Kalman, Yossi Matias

Introduction

Large autoregressive models, notably large Transformers (Vaswani et al., 2017), are much more capable than smaller models, as is evidenced countless times in recent years e.g., in the text or image domains, like GPT-3 (Brown et al., 2020), LaMDA (Thoppilan et al., 2022), Parti (Yu et al., 2022), and PaLM (Chowdhery et al., 2022). Unfortunately, a single decode step from these larger models is significantly slower than a step from their smaller counterparts, and making things worse, these steps are done serially - decoding $K$ tokens takes $K$ serial runs of the model.

Given the importance of large autoregressive models and specifically large Transformers, several approaches were developed to make inference from them faster. Some approaches aim to reduce the inference cost for all inputs equally (e.g. Hinton et al., 2015; Jaszczur et al., 2021; Hubara et al., 2016; So et al., 2021; Shazeer, 2019). Other approaches stem from the observation that not all inference steps are born alike - some require a very large model, while others can be approximated well by more efficient models. These adaptive computation methods (e.g. Han et al., 2021; Sukhbaatar et al., 2019; Schuster et al., 2021; Scardapane et al., 2020; Bapna et al., 2020; Elbayad et al., 2019; Schwartz et al., 2020) aim to use less compute resources for easier inference steps. While many of these solutions have proven extremely effective in practice, they usually require changing the model architecture, changing the training-procedure and re-training the models, and don’t maintain identical outputs.

The key observation above, that some inference steps are “harder” and some are “easier”, is also a key motivator for our work. We additionally observe that inference from large models is often not bottlenecked on arithmetic operations, but rather on memory bandwidth and communication, so additional computation resources might be available. Therefore we suggest increasing concurrency as a complementary approach to using an adaptive amount of computation. Specifically, we are able to accelerate inference without changing the model architectures, without changing the training-procedures or needing to re-train the models, and without changing the model output distribution. This is accomplished via speculative execution.

Speculative execution (Burton, 1985; Hennessy & Patterson, 2012) is an optimization technique, common in processors, where a task is performed in parallel to verifying if it’s actually needed - the payoff being increased concurrency. A well-known example of speculative execution is branch prediction. For speculative execution to be effective, we need an efficient mechanism to suggest tasks to execute that are likely to be needed. In this work, we generalize speculative execution to the stochastic setting - where a task might be needed with some probability. Applying this to decoding from autoregressive models like Transformers, we sample generations from more efficient approximation models as speculative prefixes for the slower target models. With a novel sampling method, speculative sampling, we maximize the probability of these speculative tasks to be accepted, while guaranteeing that the outputs from our system have the same distribution as those from the target model alone. For example, the sentence in Figure 1, consisting of 38 tokens, was generated by our method with only 9 serial runs of a larger target model (97M parameters) thanks to a smaller and more efficient approximation model (6M parameters), while the probability of generating it is unchanged.

We analyze our method in a variety of tasks and model sizes: unconditional generation from a 97M parameter GPT-like model trained on lm1b, English to German translation and news article summarization with an 11B parameters T5-XXL model, and a dialog task with a 137B parameter LaMDA model. We implement our method and compare actual walltimes for T5-XXL to those of the robust T5X implementation (Roberts et al., 2022), showing an out-of-the-box latency improvement of 2X-3X, without any change to the outputs (Section 4).

Our method is easy to employ in actual production settings, doesn’t require training new models, and doesn’t change the outputs. Therefore, in common situations where memory bandwidth is the bottleneck, and compute resources are available, it may be a good default to accelerate sampling from autoregressive models like Transformers.

To summarize, our main contributions are: (1) A generalization of speculative execution to the stochastic setting, with a novel sampling method we call speculative sampling, and (2) A decoding mechanism we call speculative decoding that can accelerate decoding from autoregressive models, without any change to the model architectures, training regimes and output distributions.

Speculative Decoding

2.2 Standardized Sampling

First, note that while there are many methods and parameters of sampling, like argmax, top-k, nucleus, and setting a temperature, and popular implementations usually treat them differently at the logits level, they can all easily be cast into standard sampling from an adjusted probability distribution. For example, argmax sampling is equivalent to zeroing out non-max elements of the distribution and normalizing. We can therefore only deal with standard sampling from a probability distribution, and cast all of the other types of sampling into that framework. Going forward we’ll assume that $p(x)$ and $q(x)$ are the distributions from $M_p$ and $M_q$ respectively, adjusted for the sampling method.

2.3 Speculative Sampling

To sample $x\sim p(x)$, we instead sample $x\sim q(x)$, keeping it if $q(x)\leq p(x)$, and in case $q(x)>p(x)$ we reject the sample with probability $1-\frac{p(x)}{q(x)}$ and sample $x$ again from an adjusted distribution $p'(x)=\mathrm{norm}(\max(0,p(x)-q(x)))$ instead. It’s easy to show (see Section A.1) that for any distributions $p(x)$ and $q(x)$, and $x$ sampled in this way, indeed $x\sim p(x)$.
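The sampling rule above can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: the function name `speculative_sample` and the representation of $p$ and $q$ as plain lists of probabilities over a shared vocabulary are assumptions made for the example.

```python
import random


def speculative_sample(p, q, rng):
    """Draw one index distributed according to p, using a proposal drawn from q.

    p and q are lists of probabilities over the same vocabulary
    (an assumption of this sketch).
    """
    n = len(p)
    # Propose x ~ q(x).
    x = rng.choices(range(n), weights=q, k=1)[0]
    # Keep x with probability min(1, p(x) / q(x)); q[x] > 0 because x was drawn from q.
    if rng.random() < min(1.0, p[x] / q[x]):
        return x
    # Rejected: resample from p'(x) = norm(max(0, p(x) - q(x))).
    # rng.choices normalizes the weights, so no explicit division is needed.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(range(n), weights=residual, k=1)[0]
```

Drawing many samples with, say, `p = [0.5, 0.3, 0.2]` and `q = [0.2, 0.2, 0.6]` should give empirical frequencies close to `p`, even though every proposal comes from `q`.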

Given the distribution $q(x)$ obtained from running $M_q$ on a conditioning $prefix$, we can sample a token $x_1\sim q(x)$. We then calculate the distribution $p(x)$ by running $M_p$ on $prefix$ while in parallel speculatively calculating the distribution of the next token $x_2$ by running $M_p$ on $prefix+[x_1]$. Once both computations complete, we proceed as per above: If $x_1$ is rejected, we discard the computation of $x_2$ and re-sample $x_1$ from an adjusted distribution, and if $x_1$ is accepted, we keep both tokens. Algorithm 1 generalizes this idea to sample between 1 and $\gamma+1$ tokens at once.
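The generalization to $\gamma$ guesses can be sketched as follows. This is an illustrative sketch under stated assumptions rather than a transcription of Algorithm 1: `target_probs` and `draft_probs` are hypothetical callables standing in for $M_p$ and $M_q$, each mapping a token list to a next-token distribution, and the $\gamma+1$ target evaluations are written as a sequential loop although in practice they would run in parallel.

```python
import random


def speculative_decoding_step(target_probs, draft_probs, prefix, gamma, rng):
    """Return between 1 and gamma + 1 new tokens for `prefix`.

    target_probs / draft_probs are hypothetical stand-ins for M_p and M_q.
    """
    # 1. Draft gamma guesses autoregressively with the cheap model M_q.
    guesses, qs = [], []
    for _ in range(gamma):
        q = draft_probs(prefix + guesses)
        guesses.append(rng.choices(range(len(q)), weights=q, k=1)[0])
        qs.append(q)

    # 2. Score all gamma + 1 prefixes with M_p (parallel in practice).
    ps = [target_probs(prefix + guesses[:i]) for i in range(gamma + 1)]

    # 3. Accept guesses left to right until the first rejection.
    n = 0
    while n < gamma:
        x = guesses[n]
        if rng.random() >= min(1.0, ps[n][x] / qs[n][x]):
            break
        n += 1

    # 4. Emit one more token: from the adjusted distribution after a
    #    rejection, or from p itself if every guess was accepted.
    if n < gamma:
        weights = [max(0.0, a - b) for a, b in zip(ps[n], qs[n])]
    else:
        weights = ps[gamma]
    extra = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    return guesses[:n] + [extra]
```

Each call performs one (parallel) pass of the target model and always yields at least one token, which is why the number of serial target runs can only go down relative to standard decoding.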

Analysis

Let’s analyze the reduction factor in the number of serial calls to the target model, or equivalently, the expected number of tokens produced by a single run of Algorithm 1.

$E(\beta)$ is then a natural measure of how well $M_q$ approximates $M_p$. If we make the simplifying assumption that the $\beta$s are i.i.d., and denote $\alpha=E(\beta)$, then the number of tokens produced by a single run of Algorithm 1 is a capped geometric variable, with success probability $1-\alpha$ and cap $\gamma+1$, and the expected number of tokens generated by Algorithm 1 satisfies Equation 1. See Figure 2.
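As a sanity check on the capped geometric claim, the sketch below compares the closed form $\frac{1-\alpha^{\gamma+1}}{1-\alpha}$ (the expression used for the expected number of tokens later in this section) against a small Monte Carlo simulation. The function names and the simulation itself are illustrative assumptions, and the simulation relies on the same i.i.d. simplification.

```python
import random


def expected_tokens(alpha, gamma):
    """Closed form E(# tokens) = (1 - alpha^(gamma + 1)) / (1 - alpha)."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def simulate_tokens(alpha, gamma, runs, rng):
    """Monte Carlo estimate under the i.i.d. acceptance assumption."""
    total = 0
    for _ in range(runs):
        accepted = 0
        # Each guess is accepted independently with probability alpha,
        # and we stop at the first rejection or after gamma guesses.
        while accepted < gamma and rng.random() < alpha:
            accepted += 1
        total += accepted + 1  # one extra token always comes from M_p
    return total / runs
```

For example, with `alpha = 0.8` and `gamma = 5` the closed form gives about 3.69 tokens per run, and the simulated average should land close to that value.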

3.2 Calculating $\alpha$

We’ll now derive a simple formula for calculating $\alpha$ given a prefix and the two models $M_p$ and $M_q$. We start by defining a natural divergence $D_{LK}$:

$D_{LK}(p,q)=\sum_{x}|p(x)-M(x)|=\sum_{x}|q(x)-M(x)|$ where $M(x)=\frac{p(x)+q(x)}{2}$.

$D_{LK}(p,q)=\sum_{x}|p(x)-M(x)|=\sum_{x}\frac{|p-q|}{2}=1-\sum_{x}\frac{p+q-|p-q|}{2}=1-\sum_{x}\min(p(x),q(x))$

From Lemma 3.3 we immediately get the following results:

$D_{LK}(p,q)$ is a symmetric divergence in $[0,1]$.
$D_{LK}(p,q)=0\iff p=q$.
$D_{LK}(p,q)=1\iff p$ and $q$ have disjoint support.

$\beta=E_{x\sim q(x)}\begin{cases}1 & q(x)\leq p(x)\\ \frac{p(x)}{q(x)} & q(x)>p(x)\end{cases}=E_{x\sim q(x)}\min\left(1,\frac{p(x)}{q(x)}\right)=\sum_{x}\min(p(x),q(x))$
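The two quantities above are cheap to compute for a single prefix once both distributions are known. The following sketch (function names are illustrative, distributions are plain lists) evaluates $D_{LK}$ from its definition and the acceptance probability from the $\min$ formula, so that the identity $\beta=1-D_{LK}(p,q)$ can be checked numerically.

```python
def d_lk(p, q):
    """D_LK(p, q) = sum_x |p(x) - M(x)| with M(x) = (p(x) + q(x)) / 2."""
    return sum(abs(pi - (pi + qi) / 2) for pi, qi in zip(p, q))


def acceptance_rate(p, q):
    """beta = E_{x~q} min(1, p(x) / q(x)) = sum_x min(p(x), q(x))."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))
```

For `p = [0.5, 0.3, 0.2]` and `q = [0.2, 0.2, 0.6]`, the divergence is 0.4 and the acceptance probability is 0.6, consistent with $\beta=1-D_{LK}$.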

See Footnote for empirically observed $\alpha$ values in our experiments.

3.3 Walltime Improvement

We’ve shown that with the i.i.d. assumption our algorithm reduces the number of calls to the target model by a factor of $\frac{1-\alpha^{\gamma+1}}{1-\alpha}$. Note that speculative execution in general, and our algorithm in particular, assume that we have enough compute resources to support the increased concurrency (Section 3.4). For the walltime analysis, we’ll assume that we can run $\gamma+1$ concurrent evaluations of $M_p$ in parallel without increasing the walltime. To get the total walltime improvement, we now consider the cost of running the approximation model $M_q$.

Let $c$, the cost coefficient, be the ratio between the time for a single run of $M_q$ and the time for a single run of $M_p$.

Note that unlike $\alpha$ which is an intrinsic property of the models and the task, the value of $c$ depends on the hardware configuration and software implementation details. In our experiments where $M_q$ is typically a couple of orders of magnitude smaller than $M_p$, $c$ was always less than $0.05$ and often negligibly close to 0.

The expected improvement factor in total walltime by Algorithm 1 is $\frac{1-\alpha^{\gamma+1}}{(1-\alpha)(\gamma c+1)}$.

Denote the cost of running a single step of $M_p$ by $T$. Now, each run of Algorithm 1 costs $Tc\gamma+T$ (for running the approximation model $M_q$ $\gamma$ times and running $M_p$ once) and according to Equation 1 produces $\frac{1-\alpha^{\gamma+1}}{1-\alpha}$ tokens on average. So the overall expected cost for producing a token with Algorithm 1 is $\frac{(c\gamma+1)(1-\alpha)}{1-\alpha^{\gamma+1}}T$. Since the cost of producing a single token with the standard decoding algorithm is $T$, we get the desired result. ∎
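The improvement factor is easy to evaluate numerically. The helper below is a direct transcription of the formula; the function name is an assumption, and the example values are the ones quoted elsewhere in the paper ($\alpha\approx 0.2$, $c=0$, $\gamma=3$ for the bigram model, and $\alpha=0.75$, $c=0.015$, $\gamma=8$ in the lenience discussion).

```python
def walltime_improvement(alpha, gamma, c):
    """Expected walltime improvement (1 - alpha^(gamma+1)) / ((1 - alpha)(gamma*c + 1)).

    alpha: expected acceptance rate, gamma: number of guesses per run,
    c: cost of one M_q run relative to one M_p run.
    """
    return (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (gamma * c + 1))


if __name__ == "__main__":
    print(walltime_improvement(0.2, 3, 0.0))     # roughly 1.25
    print(walltime_improvement(0.75, 8, 0.015))  # roughly 3.3
```

Setting `gamma = 1` reduces the expression to $\frac{1+\alpha}{1+c}$, so a single guess already helps whenever $\alpha>c$.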

Note that Theorem 3.8 assumes long enough generations (for example, since we run $M_p$ at least once, the improvement factor is capped by the number of generated tokens).

If $\alpha>c$, there exists $\gamma$ for which we’ll get an improvement, and the improvement factor will be at least $\frac{1+\alpha}{1+c}$.

If we get an improvement for $\gamma$, we’d also get an improvement for any $0<\gamma^{*}<\gamma$, so for our method to yield an improvement, we can evaluate Theorem 3.8 for $\gamma=1$, yielding $\frac{1-\alpha^{2}}{(1-\alpha)(c+1)}=\frac{1+\alpha}{1+c}$. ∎

3.4 Number of Arithmetic Operations

Algorithm 1 does $\gamma+1$ runs of $M_p$ in parallel, so the number of concurrent arithmetic operations grows by a factor of $\gamma+1$. Now, since Algorithm 1 produces at most $\gamma+1$ tokens per run, the total number of arithmetic operations might be higher than that of the standard decoding algorithm. When we accept the sample from $M_q$ the increased concurrency is “free” and the total number of operations isn’t increased (neglecting the cost of $M_q$). When we reject a guess though, computation is wasted. Let’s now analyze the effect of our method on the total number of arithmetic operations.

Let $\hat{c}$ be the ratio of arithmetic operations per token of the approximation model $M_q$ to that of the target model $M_p$.

The expected factor of increase in the number of total operations of Algorithm 1 is $\frac{(1-\alpha)(\gamma\hat{c}+\gamma+1)}{1-\alpha^{\gamma+1}}$.

Denote by $\hat{T}$ the number of arithmetic operations done by a standard decoding baseline per token, i.e. the number of operations of a single run of $M_p$. Then a single iteration of Algorithm 1 costs $\hat{T}\hat{c}\gamma+\hat{T}(\gamma+1)$ operations (for $\gamma$ runs of $M_q$ and $\gamma+1$ parallel runs of $M_p$). Dividing by the expected number of tokens produced by Algorithm 1, i.e. Equation 1, and by $\hat{T}$, we get the desired result. ∎

If $\alpha$ is low, the increase in the number of arithmetic operations is high, and vice-versa. Note that for Transformer decoders, the total number of arithmetic operations by Algorithm 1 (not counting runs of $M_q$) can be bounded from above by a single run of the same-size Transformer encoder.

Unlike the total number of arithmetic operations, the total number of memory accesses can go down with our method. Specifically, the target model’s weights and KV cache can be read once per execution of Algorithm 1, so the number of memory accesses for reading them shrinks by a factor of $\frac{1-\alpha^{\gamma+1}}{1-\alpha}$, according to Equation 1.
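The trade-off described in this subsection, more arithmetic but fewer reads of the target model's weights, can be made concrete with two small helpers. Both are direct transcriptions of the formulas above; the function names are assumptions of this sketch.

```python
def ops_increase(alpha, gamma, c_hat):
    """Expected factor of increase in total arithmetic operations:
    (1 - alpha)(gamma * c_hat + gamma + 1) / (1 - alpha^(gamma + 1))."""
    return (1 - alpha) * (gamma * c_hat + gamma + 1) / (1 - alpha ** (gamma + 1))


def weight_read_reduction(alpha, gamma):
    """Factor by which reads of the target model's weights and KV cache shrink:
    (1 - alpha^(gamma + 1)) / (1 - alpha)."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


if __name__ == "__main__":
    alpha, gamma = 0.8, 5
    print(ops_increase(alpha, gamma, c_hat=0.0))   # roughly 1.63x more operations
    print(weight_read_reduction(alpha, gamma))     # roughly 3.69x fewer weight reads
```

With `gamma = 0` the operation count is unchanged (factor 1), which matches standard decoding.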

3.5 Choosing $\gamma$

Given $c$ and $\alpha$ and assuming enough compute resources (see Section 3.4), the optimal $\gamma$ is the one maximizing the walltime improvement equation (Theorem 3.8): $\frac{1-\alpha^{\gamma+1}}{(1-\alpha)(\gamma c+1)}$. Since $\gamma$ is an integer, it can be easily found numerically, see Figure 3.
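One simple way to find the optimal $\gamma$ numerically is a brute-force search over a bounded range of integers. The sketch below assumes an upper bound `max_gamma` (a hypothetical parameter, not from the paper) and unlimited compute for the parallel target runs.

```python
def walltime_improvement(alpha, gamma, c):
    """Walltime improvement factor from Theorem 3.8."""
    return (1 - alpha ** (gamma + 1)) / ((1 - alpha) * (gamma * c + 1))


def optimal_gamma(alpha, c, max_gamma=100):
    """Integer gamma in [1, max_gamma] maximizing the improvement factor.

    max_gamma is an arbitrary search bound chosen for this sketch.
    """
    return max(range(1, max_gamma + 1),
               key=lambda g: walltime_improvement(alpha, g, c))


if __name__ == "__main__":
    g = optimal_gamma(alpha=0.8, c=0.05)
    print(g, walltime_improvement(0.8, g, 0.05))
```

The search is cheap because the improvement factor eventually decreases in $\gamma$ whenever $c>0$: the numerator saturates at 1 while the denominator keeps growing.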

Table 1 and Figure 4 illustrate the trade-off between inference speed and the total number of arithmetic operations for various values of $\alpha$ and $\gamma$, assuming $c=\hat{c}=0$. Figure 5 shows a simplified trace diagram.

Instead of picking a single value for $\gamma$ based on $\alpha$, since the $\beta$s aren’t constant, we could get further improvement by predicting the value of $\beta$ and accordingly varying the value of $\gamma$ during the run of Algorithm 1. To get an upper bound on the additional improvement factor, assume we had an oracle for $\gamma$. We would then have $E(\#\ \text{generated tokens})=\frac{1}{1-\alpha}$. For typical values of $c$ and $\alpha$, and assuming unbounded compute resources, the enhanced walltime improvement factor can be up to $\sim$60% higher than the improvement factor with a fixed $\gamma$. We leave exploring this for future work. (The above bound assumes that we still run $M_p$ to verify the oracle’s predictions. If we skip those verifications the bound doesn’t hold and we would get a substantial additional improvement.)

3.6 Approximation Models

Speculative sampling, and therefore speculative decoding, guarantee an identical output distribution for any choice of approximation model $M_q$ without restriction (see Section A.1). In our experiments, we mostly tested existing off-the-shelf smaller Transformers as the approximation models. Further, we only tested approximation models of the same architecture as the target models $M_p$ and using the same probability standardization. In this setup, choosing $M_q$ to be around two orders of magnitude smaller than $M_p$ usually performed best, balancing $\alpha$ and $c$ (Theorem 3.8).

Another type of approximation models, negligible-cost models, are those for which $c\approx 0$, i.e. approximation models with a negligible cost relative to the target model. In this case, we get an expected walltime improvement of $\frac{1-\alpha^{\gamma+1}}{1-\alpha}$, which is bounded from above by $\frac{1}{1-\alpha}$ (we approach equality if $\gamma$ is large). One interesting type of negligible-cost approximation models are n-gram models, where the evaluation amounts to a table lookup. Interestingly, in empirical tests (Section 4.2) we get non-zero $\alpha$s even for these trivial n-gram models. For example, for the English-German translation task, with $M_p$ being T5-XXL 11B and $M_q$ being a trivial bigram model, we get $\alpha\approx 0.2$ which leads to an inference speed improvement factor of 1.25X with $\gamma=3$.

Other simple heuristics can be used as negligible-cost approximation models. For example, in cases where long sequences are likely to repeat, such as for summarization tasks or chat-like interfaces (e.g. where a user and a language model iterate on content, like text or code: “can you rewrite this story but change the ending”, “can you make this function also do X”), an approximation model that simply copies tokens from the context in case we find a matching prefix might yield high values of $\alpha$. These parameter-less approximation models have the additional advantage of being even simpler to deploy from a production standpoint.
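One possible shape for such a copy-based approximation model is sketched below. The paper does not specify a matching rule, so everything here is an assumption made for illustration: the function name, the `match_len` parameter, and the choice to copy from the most recent earlier occurrence of the last `match_len` tokens. The proposals are deterministic, i.e. $q$ is one-hot on each copied token.

```python
def copy_from_context(tokens, gamma, match_len=2):
    """Propose up to gamma tokens by copying what followed the most recent
    earlier occurrence of the last `match_len` tokens.

    Returns an empty list when no earlier match exists (hypothetical heuristic).
    """
    if len(tokens) <= match_len:
        return []
    suffix = tokens[-match_len:]
    # Scan right to left, starting just before the suffix itself.
    for start in range(len(tokens) - match_len - 1, -1, -1):
        if tokens[start:start + match_len] == suffix:
            follow = start + match_len
            return tokens[follow:follow + gamma]
    return []


if __name__ == "__main__":
    context = [1, 2, 3, 4, 5, 1, 2]
    print(copy_from_context(context, gamma=3))  # copies the tokens that followed the earlier [1, 2]
```

When the function returns no proposals, a decoder would simply fall back to a regular step of the target model.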

Another type of approximation model that can be used by speculative decoding is non-autoregressive models, like those from Stern et al. (2018). Then, instead of the autoregressive loop in Algorithm 1 we’d just call the non-autoregressive model once.

A final example, interesting mostly from a theoretical perspective, is an approximation model which chooses tokens at random, which guarantees some improvement (although very small) for all models $M_p$.

Experiments

We implement our algorithm and compare it to the implementation in the T5X codebase for accelerating T5-XXL.

We test a standard encoder-decoder T5 version 1.1 model (Raffel et al., 2020) on two tasks from the T5 paper: (1) English to German translation fine-tuned on WMT EnDe, and (2) Text summarization fine-tuned on CNN/DM. For both tasks, we use T5-XXL (11B) for $M_p$. For the approximation model $M_q$ we test several existing configurations, namely T5-large (800M), T5-base (250M), and T5-small (77M) (Raffel et al., 2020). We use existing checkpoints for all models. We measure walltime improvements with a batch size of 1 on a single TPU-v4 for both argmax sampling (temp=0) and standard sampling (temp=1).

Results

Table 2 shows the empirical results from our method. We see that T5-small (77M), with a good balance of $c$ and $\alpha$, provides the highest speedup out of the tested approximation models. As expected we see that $\alpha$ increases with the size of the approximation model. Interestingly, $\alpha$ and walltime improvement are higher for argmax sampling (temp=0). We observe speedups of 2.6X (temp=1) and 3.4X (temp=0) on the translation task and slightly lower speedups of 2.3X (temp=1) and 3.1X (temp=0) for the summarization task. These empirical results match well with the theoretical predictions, with some variance due to implementation details (see Section A.3).

4.2 Empirical $\alpha$ Values

While we only implemented our method for T5, we measured $\alpha$ values for various tasks, sampling methods, target models $M_p$, and approximation models $M_q$. Specifically, we evaluated the expectation from Corollary 3.6 on 10K tokens generated by $M_p$, for each of the settings below.

We test a decoder-only Transformer model on unconditional language generation, trained on lm1b (Chelba et al., 2013). The model here is a GPT-like Transformer decoder with Gelu activations (Hendrycks & Gimpel, 2016). For $M_q$ we experimented with a Transformer decoder model with 6M parameters: dim 256, dim feed-forward 1024, 2 layers, 4 attention heads, as well as simple unigram and bigram models. $M_p$ has 97M parameters: dim 768, dim feed-forward 3072, 12 layers, 12 attention heads. We used Bert tokenization (Devlin et al., 2019) with 8k tokens for all models.

LaMDA (137B params)

We tested a decoder-only LaMDA model on a dialog task (Thoppilan et al., 2022). We used existing checkpoints from LaMDA 137B as $M_p$ and LaMDA 8B, LaMDA 2B, and LaMDA 100M for $M_q$.

See Section 4.1 for the setup of the T5-XXL (11B params) model.

Footnote summarizes the $\alpha$ values for the tested cases. We observe that approximation models that are a couple of orders of magnitude smaller than the target model tend to produce $\alpha$ values between 0.5 and 0.9. Interestingly, we also note that for all models, the sharper the adjusted distribution, the higher the $\alpha$ values. Finally, we note that even trivial unigram and bigram approximations yield non-negligible $\alpha$ values. For example, for the case of English to German translation, the bigram model has an $\alpha$ value of 0.2, and since $c=0$ in this case, yields a 1.25X speed improvement, which is surprisingly high for this trivial approximation model (but is still lower than the speedup we get from using T5-small as the approximation model).

Related work

The efficiency of inference from large models was studied extensively (Dehghani et al., 2021). Many approaches aim to speed up inference from large models in general, and autoregressive models like Transformers in particular. Numerous techniques try to make inference more efficient for all tokens, e.g. distillation (Hinton et al., 2015), sparsification (Jaszczur et al., 2021), quantization (Hubara et al., 2016), and architecture modification (So et al., 2021; Shazeer, 2019). Closer to our approach are adaptive computation methods which adapt the amount of computation to problem difficulty (Han et al., 2021). Examples include attending to a subset of the inputs (Sukhbaatar et al., 2019), and early exits (Schuster et al., 2021; Scardapane et al., 2020; Bapna et al., 2020; Elbayad et al., 2019; Schwartz et al., 2020). Notably, Wisdom of Committees (Schwartz et al., 2020) leverages off-the-shelf smaller models, but is an adaptive computation approach, and so it uses a heuristic to determine when to stop, losing the guarantee of identical outputs to those of the target models. In general, adaptive computation methods usually learn, either within the model itself or with an auxiliary model, when a computation shortcut can be taken. Usually, these methods save on both inference time and arithmetic operations, but require a change of architecture, a change of training procedure and training custom models or re-training of existing models. They usually also change the outputs of the model. We note that while many of the methods above improve the memory to arithmetic-operations ratio, in cases where the ratio remains high, these methods and our speculative decoding method might be effective in tandem.

Two prior methods leverage speculative execution for speeding up decoding from autoregressive models. Blockwise Parallel Decoding (Stern et al., 2018) decodes several tokens in parallel, similarly to our work. However, it only supports greedy decoding (temperature=0) and not the general stochastic setting, it requires additional training of a custom model, and focuses on preserving down-stream task quality, instead of guaranteeing identical outputs. Shallow Aggressive Decoding (SAD) (Sun et al., 2021) also decodes several tokens in parallel, similarly to our work. Unlike our work, SAD only supports copying the input to the output, and not general approximation models, making it only suitable for the cases where the inputs and outputs are very similar like grammatical error correction. In addition, similarly to Blockwise Parallel Decoding, SAD does not support the general stochastic sampling setting.

After we initially published our work, an independent implementation of speculative decoding (Chen et al., 2023) showed similar 2X-2.5X improvements on Chinchilla 70B.

Discussion

We presented speculative sampling which enables efficient stochastic speculative execution - i.e. speculative execution in the stochastic setting. We analyzed its impact on decoding from autoregressive models like Transformers via speculative decoding and have shown that given enough compute resources, we get meaningful 2X-3X speedups in practice vs T5X, a popular optimized implementation.

One limitation of speculative execution in general, and of speculative decoding in particular, is that latency is improved through increased concurrency at the cost of an increased number of arithmetic operations. Thus, our method is not helpful for configurations where additional computation resources are not available. However, in common cases where additional computation resources are available (e.g. when memory bandwidth is the bottleneck) our method provides the speedup with significant benefits: the model architecture doesn’t change, retraining isn’t required, and most importantly, the output distribution is guaranteed to stay the same. Our method is easy to implement, and can be used to speed up inference using out-of-the-box models without developing and evaluating custom schemes.

There are several directions for follow up research, importantly, further investigating the compatibility of speculative decoding with beam search (see Section A.4). Also, while our method yields substantial speedups with existing off-the-shelf approximation models, greater improvements might be obtained via custom approximation models (Section 3.6), such as those with custom architectures (e.g. custom sizes, non-autoregressive models, or various heuristics) or with custom training procedures (e.g. standard distillation with soft targets from $M_p$, or optimizing $M_q$ for $\alpha$ directly). It could also be interesting to explore a hierarchical version of the algorithm, where the approximation model is itself accelerated by an even faster model, which could allow for more capable approximation models. In this work we fixed the approximation model and the number of guesses $\gamma$ throughout inference, but varying them during inference could yield additional improvements (Section 3.5). In our experiments we always performed the same standardization on the distributions generated by the approximation model as the desired one for the target model (Section 2.2), but further improvements might be obtained by applying different transformations. We tested speculative decoding only in the text modality, but it might work well in other domains (e.g. images) which would be interesting to experiment with.

Finally, we note that stochastic speculative execution and speculative sampling can be helpful outside the scope of speculative decoding from autoregressive models. For example, given two slow functions, $f(x)$ and $g(y)$ such that $f(x)$ generates a distribution from which $g$’s input is sampled, we could use our method to run $f$ and $g$ in parallel. This setup might arise e.g. in physics simulations, or in reinforcement learning where $f$ is a large model that produces a distribution on actions, and $g$ is the world simulation, which would be interesting to explore.

Acknowledgments

We would like to extend a special thank you to YaGuang Li for help with everything LaMDA related and for calculating the LaMDA figures in the paper, and to Blake Hechtman for great insights and help with XLA. We would also like to thank the reviewers for insightful comments, as well as Asaf Aharoni, Reiner Pope, Sasha Goldshtein, Nadav Sherman, Eyal Segalis, Eyal Molad, Dani Valevski, Daniel Wasserman, Valerie Nygaard, Danny Vainstein, the LaMDA and Theta Labs teams at Google, and our families.

References

Appendix A Appendix

We will now show that for any distributions $p(x)$ and $q(x)$, the tokens sampled via speculative sampling from $p(x)$ and $q(x)$ are distributed identically to those sampled from $p(x)$ alone. Let $\beta$ be the acceptance probability (Definition 3.1).

Note that as $p'(x)=\mathrm{norm}(\max(0,p(x)-q(x)))=\frac{p(x)-\min(q(x),p(x))}{\sum_{x'}(p(x')-\min(q(x'),p(x')))}=\frac{p(x)-\min(q(x),p(x))}{1-\beta}$, the normalizing constant for the adjusted distribution $p'(x)$ is $1-\beta$, where the last equation follows immediately from Lemma 3.3 and Theorem 3.5.
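The remaining steps of the argument can be written out as follows. This is a sketch of the standard two-case decomposition (the sampled token is either an accepted guess or a re-sample after a rejection), using only the acceptance rule from Section 2.3 and the normalizing constant $1-\beta$ derived above; $x'$ denotes an arbitrary fixed token.

```latex
\begin{aligned}
P(x = x') &= P(\text{guess accepted},\, x = x') + P(\text{guess rejected},\, x = x') \\[4pt]
P(\text{guess accepted},\, x = x') &= q(x')\,\min\!\left(1, \frac{p(x')}{q(x')}\right) = \min(q(x'), p(x')) \\[4pt]
P(\text{guess rejected},\, x = x') &= (1-\beta)\,p'(x') = p(x') - \min(q(x'), p(x')) \\[4pt]
P(x = x') &= \min(p(x'), q(x')) + p(x') - \min(p(x'), q(x')) = p(x')
\end{aligned}
```

The rejection term uses the fact that the overall rejection probability is $1-\beta$, which cancels the normalizing constant of $p'(x)$.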

A.2 Speculative Sampling vs. Rejection Sampling

Rejection sampling is the following iterative sampling procedure that looks superficially similar to ours:

Where $M=\max_{x}\frac{p(x)}{q(x)}$. We could employ a non-iterative version of rejection sampling instead of speculative sampling - specifically go through steps 1 and 2 above, and otherwise sample from an unmodified $p(x)$ directly. That would be much less efficient than our method though. Specifically, the expected accept probability here, $E_{x\sim q(x)}\frac{p(x)}{Mq(x)}=\sum_{x}p(x)\min_{x'}\frac{q(x')}{p(x')}\leq\sum_{x}p(x)\min\left(1,\frac{q(x)}{p(x)}\right)=\sum_{x}\min(p(x),q(x))=\alpha$, is (potentially much) lower than the expected accept probability in our method, $\alpha$.
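The gap between the two acceptance probabilities is easy to see on a small example. The sketch below computes the expected acceptance of a single rejection-sampling round, which equals $\frac{1}{M}$ by the expression above, and compares it with $\sum_{x}\min(p(x),q(x))$. The function names are illustrative, and the code assumes $q(x)>0$ wherever $p(x)>0$ so that $M$ is finite.

```python
def rejection_accept_prob(p, q):
    """Expected acceptance of one rejection-sampling round: 1 / M,
    with M = max_x p(x) / q(x). Assumes q(x) > 0 wherever p(x) > 0."""
    m = max(pi / qi for pi, qi in zip(p, q) if pi > 0)
    return 1.0 / m


def speculative_accept_prob(p, q):
    """Expected acceptance of speculative sampling: sum_x min(p(x), q(x))."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))


if __name__ == "__main__":
    p = [0.5, 0.3, 0.2]
    q = [0.2, 0.2, 0.6]
    print(rejection_accept_prob(p, q))    # about 0.4
    print(speculative_accept_prob(p, q))  # about 0.6
```

Rejection sampling is limited by the single worst ratio $\frac{p(x)}{q(x)}$ across the vocabulary, whereas speculative sampling only loses probability mass where $q$ actually exceeds $p$.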

A.3 Theoretical Predictions vs. Empirical Runtimes

Table 4 compares the expected runtime improvements based on Theorem 3.8 to the empirically measured runtimes from Table 2. We estimated the values of $c$ for the various models based on profiler traces. We can see that the theoretical predictions mostly match the measured runtimes. The larger differences are due to: (1) optimization differences between our implementation and the baseline, and (2) the simplifying assumption that the $\beta$s are i.i.d. being only an approximation (see Section 3.1).

A.4 Application to Beam Search

Our method can be applied, with some performance penalty, to beam search sampling. Given the original beam width $w$, we can perform beam search with the approximation model $M_q$ and beam width $u\geq w$ for $\gamma$ steps. Then, we can use $M_p$ to check all of the candidates in parallel (costing a compute budget of $(w+u\gamma)$ runs of $M_p$). Finally, for each step, we can accept the guesses of $M_q$ as long as $top_{w}(M_p)\subseteq top_{u}(M_q)$ to get identical results to regular beam search with $M_p$ alone (with a more elaborate procedure we could also accept cases where the candidates we got happen to have higher probabilities than those of $M_p$ alone). The analysis of our method in this setting is more involved and we leave it for future work.

A.5 Lenience

A strong property of Algorithm 1 is that the output distribution is guaranteed to remain unchanged. That said, if we’re willing to allow some changes, with nice guarantees, we can get further inference speed improvements. To further motivate this, note that when we train two models with identical architectures and sizes on the same dataset, the generated probability distributions will not be identical, so some lenience might make sense. Note that the results in this paper except for this section use the strictest version of Algorithm 1 and don’t allow lenience of any kind.

We could include a lenience parameter $l\in[0,1]$ and multiply $q(x)$ by $l$ before comparing with $p(x)$ in Algorithm 1. This still maintains the nice guarantee that no token can be sampled with probability greater than $\frac{p(x)}{l}$. This means for example, that with $l=\frac{1}{10}$ no token can be sampled with more than 10X its ground truth probability, so we can guarantee that extremely rare tokens will remain extremely rare (there is no guarantee on the minimum probability, so lenience could hurt the diversity of the samples).

Specifically, with a lenience factor $l$ we have $\alpha=E_{x\sim q(x)}\begin{cases}1 & lq(x)\leq p(x)\\ \frac{p(x)}{lq(x)} & lq(x)>p(x)\end{cases}=E_{x\sim q(x)}\frac{p(x)}{\max(p(x),lq(x))}=\sum_{x}\frac{p(x)q(x)}{\max(p(x),lq(x))}=\frac{1}{l}\sum_{x}\min(p(x),lq(x))=\sum_{x}\min\left(\frac{p(x)}{l},q(x)\right)$.
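The lenient acceptance rate can be checked numerically by computing it in two of the equivalent forms above: the closed form $\sum_{x}\min\left(\frac{p(x)}{l},q(x)\right)$ and the expectation over $x\sim q(x)$ of the per-token acceptance probability. Function names are illustrative, and the sketch assumes $l>0$.

```python
def lenient_alpha(p, q, l):
    """Closed form: alpha = sum_x min(p(x) / l, q(x)). Assumes l > 0."""
    return sum(min(pi / l, qi) for pi, qi in zip(p, q))


def lenient_alpha_expectation(p, q, l):
    """Same quantity as an expectation over x ~ q of min(1, p(x) / (l * q(x)))."""
    return sum(qi * min(1.0, pi / (l * qi)) for pi, qi in zip(p, q) if qi > 0)


if __name__ == "__main__":
    p = [0.5, 0.3, 0.2]
    q = [0.2, 0.2, 0.6]
    for l in (1.0, 0.5, 0.3, 0.1):
        print(l, lenient_alpha(p, q, l))
```

With `l = 1` the value reduces to $\sum_{x}\min(p(x),q(x))$, the strict acceptance rate, and it grows toward 1 as `l` decreases.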

Table 5 shows $\alpha$ values for different values of $l$ when $M_p$ is T5-XXL (11B) and $M_q$ is T5-small (77M). With $c=0.015$, using lenience values of 1, 0.5, 0.3, and 0.1 (meaning that no token can be sampled with probability greater than 1X, 2X, 3X and 10X of the ground truth) we get improvement factors of 2.5X, 3.1X, 3.6X, and 5X respectively.

Note that when using temperature = 0 (i.e. argmax sampling), we can no longer use lenience as above. Instead, we could allow some lenience before standardizing the distributions. For example, we could accept the token $x$ sampled from $M_q$ in case $p(x)\geq l\cdot\max(p)$, so that $l=1$ accepts only the argmax token. In this case, we measure similar empirical increases in $\alpha$ values to those with temperature = 1. For example, when using lenience values of 1, 0.5, 0.3, and 0.1 with T5-XXL as $M_p$ and T5-small as $M_q$ for English-German translation, we get $\alpha$ values of 0.75, 0.75, 0.8, and 0.87. Taking for example $c=0.015$ and $\gamma=8$ we get speed improvement factors of 3.3X, 3.3X, 3.9X, and 4.9X respectively. (In this case, unlike in the standard sampling case shown in Table 5, a lenience factor of 0.5 doesn’t improve the speed-up.)