SpecTr: Fast Speculative Decoding via Optimal Transport

Ziteng Sun, Ananda Theertha Suresh, Jae Hun Ro, Ahmad Beirami, Himanshu Jain, Felix Yu

Introduction

Autoregressive language models have been shown to achieve state-of-the-art results in several natural language tasks. During inference, given a context $x^{t} := x(1), x(2), \ldots, x(t)$, an autoregressive model $\mathcal{M}_{b}$ generates successive tokens $x(t+1), x(t+2), \ldots$ via temperature sampling, where the next token $x(t+1)$ is drawn from the temperature-scaled distribution $\mathcal{M}_{b}(\cdot \mid x^{t})$. If the temperature is zero, i.e., greedy decoding, the next token is the most likely one, i.e., $x(t+1) = \arg\max_{x \in \Omega} \mathcal{M}_{b}(x \mid x^{t})$, where $\Omega$ is the domain of a single token, also referred to as the vocabulary. The sampling approach can be further combined with other sampling primitives such as nucleus sampling and top-$k$ sampling.

All these approaches are autoregressive decoding methods (in this work, we use the words sampling and decoding interchangeably to refer to the process of sequentially generating tokens from a language model), where tokens are generated serially one after another, which can be slow or even prohibitive in several applications. Hence, several techniques have been proposed to improve decoding speed. Before we proceed further, we first present some notation and a simplified computational model.

Standard inference. Given a context $x^{t}$, with $O(t^{2})$ computation and $O(1)$ time, an autoregressive model $\mathcal{M}_{b}$ can compute $\mathcal{M}_{b}(y \mid x^{t})$, the (temperature-scaled) probability of every possible next token $y \in \Omega$.

Parallelization along the time axis. Given a context $x^{t}$, with $O(t^{2})$ computation and $O(1)$ time, an autoregressive model $\mathcal{M}_{b}$ can compute $\mathcal{M}_{b}(y \mid x^{i})$ for all $y \in \Omega$ and $i \in \{1, 2, \ldots, t\}$.

Parallelization along the time and batch axes. Let $K$ be the maximum batch size that can be used during inference of the autoregressive model. Given several contexts $x^{t}_{1}, x^{t}_{2}, \ldots, x^{t}_{K}$, with $O(Kt^{2})$ computation and $O(1)$ time, an autoregressive model $\mathcal{M}_{b}$ can compute $\mathcal{M}_{b}(y \mid x^{i}_{j})$ for all $y \in \Omega$, $i \in [t]$, and $j \in [K]$. (When this assumption holds, one could naively batch multiple decoding contexts, which improves decoding throughput but not the latency of each context.)

The above computational model shows that parallelizing along the time and batch axes does not increase the (wall-clock) computation time. It is a simplified characterization of the hardware typically used for neural network inference, such as TPUs and GPUs. Previous approaches also assume a similar computational model to devise faster decoding algorithms. In practice, there will be some overhead depending on hardware, implementation, and resource utilization. In Appendix E, we experimentally verify that the theoretical gains are largely preserved in practice for a large transformer model. We also note that there are efficient transformer architectures that reduce the computation cost from $O(t^{2})$ to $O(t \log t)$ (see existing surveys for a detailed overview). Such approaches are orthogonal to the focus of this paper and can be easily combined with our approach.

Broadly speaking, multiple previous approaches propose to guess a few possible future tokens using an efficient model. They then compute several conditional probability distributions from the large model based on the guesses. Computing these distributions takes $O(1)$ time due to parallelization along the time axis. The guessed tokens are then accepted or rejected by a statistical procedure such that the accepted tokens are effectively samples from the large model. This guarantees that there is provably no degradation in the quality of the decoded output compared to that of the large model. When the guesses are plausible under the large model, multiple tokens will be accepted, leading to a larger latency improvement. We will further characterize the acceptance probability as a function of the closeness between the distributions of the large model and the small model. While this approach incurs the same computation cost as vanilla decoding (under the simplified computational model assumed in this paper), it can significantly improve decoding latency due to parallelization.

The goal of this work is to provide a principled understanding of the above approaches and discuss optimality conditions and algorithmic improvements. We start by providing a more formal overview of speculative decoding and related works.

Previous works and speculative decoding

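Previous approaches make use of parallelization along the time axis to provide speedups. They first predict multiple tokens and then validate whether these tokens can be generated by the model with the corresponding sampling or decoding scheme. For greedy decoding, multiple tokens can be predicted by a separate model, by aggressive decoding, or by retrieval-augmented text. For sampling, recent work proposed an algorithm called speculative decoding, and we provide an overview of this algorithm in the rest of the section. Suppose we have access to a computationally inexpensive draft model $\mathcal{M}_{s}$, which predicts the next token given the context, and whose predictions are close to those of $\mathcal{M}_{b}$ for most contexts. Suppose we have obtained the prefix $x^{t}$. The next iteration of the speculative algorithm can be broken down into three steps (see Fig. 1 for an illustration).

Draft construction. Use the draft model $\mathcal{M}_{s}$ to autoregressively sample a draft sequence of $L$ tokens given $x^{t}$.

Conditional probability computation. Compute the next-token conditional probabilities of the large model $\mathcal{M}_{b}$ given $x^{t}$ extended by each prefix of the draft, in parallel along the time axis.

Draft selection. Accept the first $L' \leq L$ draft tokens and sample one additional token from $\mathcal{M}_{b}$, either as a correction for the first rejected token or, if all $L$ draft tokens are accepted, from $\mathcal{M}_{b}(\cdot \mid x^{t+L})$.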

After this step, we use $x^{t+L'+1}_{1}$ as the next context and sample the next few tokens using speculative decoding iteratively. For a complete statement of the algorithm, we refer the reader to the original work. The crux of the above steps is draft selection, which, given a draft sequence and the conditional probabilities from both models, selects a valid sequence such that the output has the same distribution as that of the large model. In speculative decoding, this is achieved by recursively applying a token-level maximal coupling algorithm, which is provided in Algorithm 1. Note that for draft selection, Algorithm 1 is applied with $p$ the conditional distribution of the draft model $\mathcal{M}_{s}(\cdot \mid x^{t})$ and $q$ the conditional distribution of the large model $\mathcal{M}_{b}(\cdot \mid x^{t})$ (which may be further conditioned on the newly decoded tokens).

Algorithm 1 returns a random variable $Y$, which is either the accepted input $X$ or a sample from the residual distribution $p^{\text{res}}$ defined in Step 11 of Algorithm 1. The algorithm is applied recursively as long as the draft tokens are accepted, selecting the first $L' \leq L$ tokens from the draft model. For the first rejected token, the sample $Y$ from the residual distribution is used as a correction. Previous work showed that if $X \sim p$, then $Y \sim q$. For draft selection, this means that the output of the algorithm is distributed according to $\mathcal{M}_{b}(\cdot \mid x^{t})$, which is exactly the desired outcome. Furthermore,
$$\Pr(Y = X) = \sum_{x \in \Omega} \min\big(p(x), q(x)\big) = 1 - d_{\text{TV}}(p, q),$$

where $d_{\text{TV}}$ is the total variation distance between $p$ and $q$. The closer $p$ and $q$ are in $d_{\text{TV}}$, the higher $\Pr(Y = X)$ and the fewer serial calls to the large model are needed. In the ideal case, if $p = q$, then $\Pr(Y = X) = 1$, i.e., the draft token is always accepted, and when used for speculative decoding we have $L' = L$. Together with the extra token sampled from $\mathcal{M}_{b}$ (when $L' = L$, $x(t+L+1)$ is sampled from $\mathcal{M}_{b}(\cdot \mid x^{t+L})$), $L+1$ tokens are obtained in one iteration. In such a case, based on our computational model (Section 1), and assuming the decoding time of the draft model is negligible, the speedup is $(L+1)$ times.
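The token-level coupling of Algorithm 1 can be sketched in a few lines. This is a minimal illustration, not the authors' code: it assumes $p$ and $q$ are NumPy probability vectors over a toy three-token vocabulary, uses the hypothetical helper name `maximal_coupling_select`, and assumes the residual is the standard speculative-decoding residual $\max(q - p, 0)$ normalized. It then checks empirically that $Y \sim q$ and that $\Pr(Y = X) \approx 1 - d_{\text{TV}}(p, q)$.

```python
import numpy as np


def maximal_coupling_select(x, p, q, rng):
    """Given a draft token x ~ p, return a token Y distributed according to q."""
    # Acceptance phase: keep the draft with probability min(1, q(x) / p(x)).
    if rng.random() < min(1.0, q[x] / p[x]):
        return x
    # Rejection: sample a correction from the residual max(q - p, 0), normalized.
    residual = np.maximum(q - p, 0.0)
    return int(rng.choice(len(q), p=residual / residual.sum()))


rng = np.random.default_rng(0)
p = np.array([0.6, 0.3, 0.1])  # toy draft-model distribution
q = np.array([0.2, 0.3, 0.5])  # toy large-model distribution
n = 50_000
xs = rng.choice(len(p), size=n, p=p)
ys = np.array([maximal_coupling_select(int(x), p, q, rng) for x in xs])

empirical_q = np.bincount(ys, minlength=len(q)) / n
accept_rate = float(np.mean(ys == xs))
d_tv = 0.5 * float(np.abs(p - q).sum())
print(empirical_q, accept_rate, 1 - d_tv)  # accept_rate should be close to 0.6
```

Rejection only happens where $q(x) < p(x)$, while the residual is supported where $q(x) > p(x)$, so in this sketch $Y = X$ exactly when the draft is accepted.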

Our contributions

From a theoretical viewpoint, the speculative decoding algorithm raises multiple questions.

What is the relationship between speculative decoding and the broader literature of sampling in statistics?

Is speculative decoding optimal in an information-theoretic sense?

Speculative decoding uses parallelization along time to speed up decoding; would it be possible to use parallelization along batch (number of drafts) to further improve decoding speed?

We provide answers to all the above questions in this work. We first relate the problem of speculative decoding to the broader and well-studied theory of discrete optimal transport through a token-level coupling problem (Section 4). With this connection, it becomes clear that token-level draft selection is the optimal solution for optimal transport with the indicator cost function, and that it is also related to the problem of maximal coupling. Based on the connection to optimal transport, we show that one can further speed up decoding by parallelizing along the batch axis, i.e., by using multiple drafts from the draft model (Section 5).

More precisely, we formulate the token-level draft selection problem as a discrete optimal transport problem with membership cost, which we refer to as OTM. Discrete optimal transport can be solved with a linear program, but the number of variables is exponential in the batch size, which can be prohibitive. To address this, we propose a valid transport plan that can be efficiently computed and that achieves a $(1 - 1/e)$-approximation of the optimal acceptance probability (Section 6).

With the theoretically motivated algorithms and guarantees, we circle back to speeding up decoding, propose a new algorithm called SpecTr, and theoretically show that it can be used to derive valid sequences from the large model with better speedups (Section 7). See Fig. 2 for an illustration of SpecTr. Compared to speculative decoding (Fig. 1), the main differences lie in the number of drafts sampled from the small model and in the selection algorithm that selects a valid sequence from multiple draft sequences. We remark that the latter requires completely new statistical tools, and the connection between token-level draft selection and OTM is critical for obtaining valid transport plans with good guarantees. We view this as one of the main contributions of this work. As with speculative decoding, there is provably no degradation in the quality of the decoded output compared to that of the large model.

We then experimentally demonstrate the benefit of our approach on standard datasets (Section 8). More precisely, we show that for state-of-the-art large language models, SpecTr achieves a wall clock speedup of 2.13X, a further 1.37X speedup over speculative decoding on standard benchmarks.

Token-level draft selection and optimal transport

In this section, we focus on the draft selection step of SpecTr. We start with the case $L = 1$, which is a token-level draft selection problem. In particular, given context $x^{t}$, let $X_{1}, \ldots, X_{k}$ be a collection of draft tokens sampled from the small model, e.g., i.i.d. from $\mathcal{M}_{s}(\cdot \mid x^{t})$. By our assumption on the computational model, we can compute the following conditional probabilities from the large model in parallel (along the time and batch axes):
$$\mathcal{M}_{b}(\cdot \mid x^{t}) \quad \text{and} \quad \mathcal{M}_{b}(\cdot \mid x^{t}, X_{i}) \ \ \text{for all } i \in [k].$$

The goal of the draft selection algorithm $f: \Omega^{k} \to \Omega$ is to output $Y = f(X^{k})$ whose distribution follows $\mathcal{M}_{b}(\cdot \mid x^{t})$, so that it is a valid sample from the large model. Moreover, when $Y \in \{X_{1}, \ldots, X_{k}\}$, we can sample an extra token from $\mathcal{M}_{b}(\cdot \mid x^{t}, Y)$ without calling $\mathcal{M}_{b}$ again, since we have already computed the conditional probabilities $\mathcal{M}_{b}(\cdot \mid x^{t}, Y)$. Hence we would like to maximize the probability of accepting one token from the set of drafts.

When $L > 1$, the drafts are sequences sampled from $\mathcal{M}_{s}$, and a sequence of token-level draft selection algorithms can be applied along the time axis to select a valid sequence from $\mathcal{M}_{b}$. See an example in Fig. 3. The full details of the sequence-level selection algorithm are provided in Section 7.

The remainder of this section focuses on the token-level draft selection problem. From the above discussion, the draft selection problem has two main goals.

Validity. The output token is always a valid token from the large model, i.e., its distribution follows the conditional probability of the large model. This guarantees that there is no quality degradation compared to the large model.

Maximizing acceptance. The higher the probability that we accept a draft token, the more serial computation we can save through parallelization, and hence the better the speedup.

Before proposing our framework to achieve the above goals, we first discuss the technical challenge of draft selection with multiple draft tokens. One attempt is to sequentially apply the acceptance phase of Algorithm 1 (lines 3-5) to each draft token $X_{i}$ with $p = \mathcal{M}_{s}(\cdot \mid x^{t})$ and $q = \mathcal{M}_{b}(\cdot \mid x^{t})$. However, this approach does not guarantee that the final accepted token is from the desired distribution. To see this, consider the example $p = \text{Ber}(1)$ and $q = \text{Ber}(1/2)$, where $\text{Ber}(b)$ denotes a Bernoulli distribution with probability $b$ of seeing a head. Then $X_{i} = 1$ for all $i = 1, \ldots, k$, and each of them is accepted with probability $1/2$. After applying Algorithm 1 to all $X_{i}$'s, the probability of getting a $1$ is at least $1 - 1/2^{k}$, and hence the output distribution is not $\text{Ber}(1/2)$ for $k > 1$. Therefore the algorithm does not produce valid samples, which is a requirement of the draft selection problem.
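The counterexample can be checked exactly. The sketch below uses the hypothetical helper `naive_sequential_output` to compute the output distribution when the acceptance step of Algorithm 1 is applied to $k$ i.i.d. drafts in turn and, if all are rejected, a correction is drawn from the residual $\max(q - p, 0)$; this fallback is our assumption about how the naive scheme would be completed. Tokens are indexed so that index 1 is a head.

```python
import numpy as np


def naive_sequential_output(p, q, k):
    """Exact output distribution of the (invalid) naive multi-draft scheme.

    Assumes p != q (so the residual is non-empty) and sum(min(p, q)) > 0.
    """
    overlap = np.minimum(p, q)        # P(draft = x and it is accepted)
    beta = overlap.sum()              # P(a single draft is accepted)
    p_any = 1.0 - (1.0 - beta) ** k   # P(at least one of k drafts is accepted)
    residual = np.maximum(q - p, 0.0)
    residual = residual / residual.sum()
    # The first accepted draft is distributed proportionally to `overlap`.
    return p_any * overlap / beta + (1.0 - beta) ** k * residual


p = np.array([0.0, 1.0])  # Ber(1): always a head (index 1)
q = np.array([0.5, 0.5])  # Ber(1/2)
for k in (1, 2, 3):
    print(k, naive_sequential_output(p, q, k))
# k=1 gives [0.5, 0.5]; k=2 gives [0.25, 0.75]; k=3 gives [0.125, 0.875]
```

Only $k = 1$ reproduces $q$; for larger $k$ the probability of a head is exactly $1 - 1/2^{k}$ under this fallback.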

In this work, we conduct a principled investigation of the draft selection problem, and show that these two main goals could be captured by the framework of optimal transport with a properly defined cost function. Next we define optimal transport formally and then connect it to draft selection with one draft. The generalization to multiple drafts is provided in Section 5.

To simplify notation, we assume $\Omega$ is a discrete domain.

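For two probability distributions $P$ over $\mathcal{X}$ and $Q$ over $\mathcal{Y}$, we say a joint distribution $\pi$ supported over $\mathcal{X} \times \mathcal{Y}$ is a coupling between $P$ and $Q$ if $\pi(x, y) \geq 0$ for all $x, y$,
$$\sum_{y \in \mathcal{Y}} \pi(x, y) = P(x) \ \ \forall x \in \mathcal{X}, \qquad \sum_{x \in \mathcal{X}} \pi(x, y) = Q(y) \ \ \forall y \in \mathcal{Y}.$$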

We use $\Pi(P, Q)$ to denote the set of all possible couplings between $P$ and $Q$.

When it is clear from context, we overload notation and refer to the probabilistic mapping $f_{\pi}: \mathcal{X} \to \mathcal{Y}$ induced by the conditional probability $\pi(y \mid x) := \pi(x, y) / P(x)$ as a coupling, which is also referred to as a transport plan from $P$ to $Q$. In this paper, we set $P$ to be the distribution of the draft tokens and $Q$ to be the target distribution of the output token. In this case, $f_{\pi}$ is a valid draft selection algorithm, as stated in the claim below.

For all $\pi \in \Pi(P, Q)$, let $f_{\pi}$ be the probabilistic mapping defined above. If $X \sim P$, then $f_{\pi}(X) \sim Q$.

In this paper, we will design selection algorithms by finding valid couplings between the draft distribution and target distribution to guarantee validity of the output tokens.

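Given a cost function $c: \mathcal{X} \times \mathcal{Y} \to \mathbb{R}$, the transportation cost of a coupling $\pi$ is $C(\pi) = \mathbb{E}_{(X, Y) \sim \pi}[c(X, Y)]$. The optimal transport plan is the coupling $\pi \in \Pi(P, Q)$ that minimizes the transportation cost.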

Speculative decoding with one draft token.

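With these definitions in place, we can see that with $\mathcal{X} = \mathcal{Y} = \Omega$, the domain of the tokens, and $P = p$, $Q = q$, we recover the speculative decoding objective with one draft token by using the indicator cost, which captures the resampling cost, defined below:
$$c(x, y) = \mathbb{1}\{y \neq x\}.$$
Under this cost, the optimal transport cost is
$$\min_{\pi \in \Pi(p, q)} \Pr_{(X, Y) \sim \pi}(Y \neq X) = d_{\text{TV}}(p, q),$$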

which is achieved by the maximal coupling between $p$ and $q$ stated in Algorithm 1. Hence speculative sampling achieves the optimal cost with one draft token.
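To make the indicator-cost statement concrete, here is a sketch, assuming NumPy and the hypothetical helper name `maximal_coupling_matrix`, that builds the joint distribution $\pi(x, y)$ of the maximal coupling (the coupling induced by Algorithm 1): the diagonal carries $\min(p(x), q(x))$, and the excess mass of $p$ is spread proportionally over the excess mass of $q$. Its marginals are $p$ and $q$, and its cost under $c(x, y) = \mathbb{1}\{y \neq x\}$ equals $d_{\text{TV}}(p, q)$.

```python
import numpy as np


def maximal_coupling_matrix(p, q):
    """Joint distribution pi[x, y] of the maximal coupling between p and q."""
    overlap = np.minimum(p, q)
    d_tv = 1.0 - overlap.sum()
    pi = np.diag(overlap)  # keep X = Y with mass min(p, q)
    if d_tv > 0:
        # Excess of p (where p > q) is moved to the excess of q (where q > p);
        # these supports are disjoint, so all of this mass is off-diagonal.
        pi = pi + np.outer(p - overlap, q - overlap) / d_tv
    return pi


p = np.array([0.6, 0.3, 0.1])
q = np.array([0.2, 0.3, 0.5])
pi = maximal_coupling_matrix(p, q)
indicator_cost = 1.0 - np.eye(len(p))  # c(x, y) = 1{y != x}
cost = float((pi * indicator_cost).sum())
print(pi.sum(axis=1), pi.sum(axis=0), cost)  # marginals p and q; cost equals d_TV = 0.4 up to rounding
```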

Optimal transport with multiple draft tokens

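For $k$ draft tokens, we set $\mathcal{X} = \Omega^{k}$ and $\mathcal{Y} = \Omega$, and use the membership cost
$$c(x, y) = \mathbb{1}\{y \notin S(x)\},$$
where $S(x) = \{o \in \Omega \mid o \text{ appears in } x\}$ denotes the set of distinct elements in $x$. When $k = 1$, it recovers the indicator cost mentioned before. The transportation cost of the coupling is
$$C(\pi) = \mathbb{E}_{(X, Y) \sim \pi}\big[\mathbb{1}\{Y \notin S(X)\}\big] = \Pr_{(X, Y) \sim \pi}\big(Y \notin S(X)\big).$$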

From now on we use the membership cost as the default cost function and refer to the optimal transport solution as optimal transport with membership cost (OTM). We use $\pi^{*}$ to denote the coupling that minimizes this cost, $\pi^{*} = \arg\min_{\pi \in \Pi(P, Q)} C(\pi)$ (the existence of an optimal coupling in a discrete domain is well known; when the optimal coupling is not unique, we use $\pi^{*}$ to denote one of the optimal couplings), and $C(\pi^{*})$ is referred to as the optimal transport cost between $P$ and $Q$. We use $\alpha(P, Q) = 1 - C(\pi^{*})$ to denote the corresponding optimal acceptance probability.

In this paper, we mainly focus on the case when the draft tokens are i.i.d. samples from a base distribution. (The above generic formulation immediately allows generalization to more complex draft selection strategies, such as sampling $k$ tokens without replacement, or using a different drafting distribution for each draft.) Let $p, q$ be supported over $\Omega$; the goal is to obtain one valid token from $q$ given $k$ i.i.d. samples from $p$. For SpecTr with context $x^{t}$, we have $p = \mathcal{M}_{s}(\cdot \mid x^{t})$ and $q = \mathcal{M}_{b}(\cdot \mid x^{t})$. We set $P = p^{\otimes k}$, a product distribution whose marginals are all $p$, and $Q = q$. The OT problem we want to solve is the following:
$$\pi^{*} = \arg\min_{\pi \in \Pi(p^{\otimes k}, q)} \Pr_{(X^{k}, Y) \sim \pi}\big(Y \notin S(X^{k})\big). \tag{3}$$

We overload notation and denote the optimal acceptance probability by $\alpha_{k}(p, q) := \alpha(p^{\otimes k}, q) = 1 - C(\pi^{*})$. To better understand this quantity, we state a few properties of $\alpha_{k}$.

Lemma 1 (proof in Appendix A.2). The optimal acceptance probability satisfies the following properties.

Monotonicity. For any $p, q$ and $k \geq 1$, $\alpha_{k}(p, q) \leq \alpha_{k+1}(p, q)$.

The above properties demonstrate that for large $k$, the value of $\alpha_{k}$ can become large. Hence increasing $k$ can increase the acceptance probability, leading to further speedups. We now focus on computing the optimal transport plan and the optimal acceptance probability.

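OTM via linear programming. Optimal transport in discrete domains has been studied extensively, and it is known that the optimal transport problem is equivalent to the following linear program:
$$\min_{\pi \geq 0} \sum_{x \in \Omega^{k}} \sum_{y \in \Omega} \pi(x, y)\, \mathbb{1}\{y \notin S(x)\} \quad \text{s.t.} \quad \sum_{y \in \Omega} \pi(x, y) = p^{\otimes k}(x) \ \ \forall x \in \Omega^{k}, \qquad \sum_{x \in \Omega^{k}} \pi(x, y) = q(y) \ \ \forall y \in \Omega. \tag{4}$$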

The linear program in (4) has $|\Omega|^{k+1}$ variables and $|\Omega|^{k} + |\Omega|$ equality constraints (see Definition 1). Linear programs can be solved in time polynomial in the number of variables and constraints (to our best knowledge, the best practical computation bound, through interior-point methods, is $O(|\Omega|^{3k})$, and the best theoretical computation bound is $O(|\Omega|^{2.5k})$), implying the following lemma.

Given $p, q$ over $\Omega$, the solution to Eq. 3 can be computed in time $O(|\Omega|^{O(k)})$.

We refer to the optimal coupling obtained above as OTM-$k$ and denote it by $\pi^{\mathrm{OTM}\text{-}k}$. When $k = 1$, there is a closed-form expression for the optimal acceptance probability (see Eq. 1), whereas for larger values of $k$, we are unaware of a general closed-form expression. In Section A.1, we provide information-theoretic upper and lower bounds, which are tight up to a multiplicative constant of $1 - (1 - 1/k)^{k} \geq 1 - 1/e$.
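For toy vocabularies and small $k$, the linear program (4) can be solved directly. The sketch below assumes SciPy's `scipy.optimize.linprog` with the HiGHS solver and the hypothetical helper name `otm_acceptance`; it enumerates all $|\Omega|^{k}$ draft tuples, so it is only practical for tiny sizes, which illustrates the exponential dependence on $k$. As a sanity check, for $k = 1$ it returns $1 - d_{\text{TV}}(p, q)$.

```python
import itertools

import numpy as np
from scipy.optimize import linprog


def otm_acceptance(p, q, k):
    """Optimal acceptance probability alpha_k(p, q) from the LP (4)."""
    n = len(p)
    drafts = list(itertools.product(range(n), repeat=k))  # all x in Omega^k
    num_vars = len(drafts) * n  # variable index i * n + y stands for pi(x_i, y)
    # Membership cost: 1 if y does not appear among the draft tokens x.
    cost = np.array([0.0 if y in x else 1.0 for x in drafts for y in range(n)])
    a_eq, b_eq = [], []
    for i, x in enumerate(drafts):  # sum_y pi(x, y) = p^{(x)k}(x)
        row = np.zeros(num_vars)
        row[i * n:(i + 1) * n] = 1.0
        a_eq.append(row)
        b_eq.append(float(np.prod([p[t] for t in x])))
    for y in range(n):  # sum_x pi(x, y) = q(y)
        row = np.zeros(num_vars)
        row[y::n] = 1.0
        a_eq.append(row)
        b_eq.append(float(q[y]))
    res = linprog(cost, A_eq=np.array(a_eq), b_eq=np.array(b_eq),
                  bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    return 1.0 - res.fun


p = np.array([0.75, 0.25])  # Ber(0.25), index 1 = head
q = np.array([0.25, 0.75])  # Ber(0.75)
for k in (1, 2, 3):
    print(k, otm_acceptance(p, q, k))
```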

While solving OTM in Eq. 4 gives the plan with the optimal acceptance probability, to the best of our knowledge, the best-known runtime is exponential in $k$, which can be prohibitive when either the vocabulary size $|\Omega|$ or the number of draft tokens $k$ is large. (For discrete OT, the Sinkhorn algorithm can be used to solve an entropy-regularized version of OT, which has a better computational complexity. However, its computation cost still depends linearly on $|\Omega|^{k}$, which can be prohibitive.) In the next section, we present a selection algorithm that can be efficiently computed and show that it achieves an acceptance probability of at least $(1 - (1 - 1/k)^{k})\alpha_{k} \geq (1 - 1/e)\alpha_{k}$.

Draft selection via $k$-sequential selection

In this section, we present a sequential selection algorithm (k-Seq), an approximate solution to the optimal transport problem in Eq. 3 (the solution still satisfies the constraints in Eq. 3 and hence is a valid transport plan; "approximate" here means that it is not the exact minimizer of the cost in Eq. 3), which can be efficiently computed in time almost linear in $|\Omega|$ and logarithmic in $k$. The algorithm is presented in Algorithm 2.

At a high level, the algorithm goes over all $k$ draft samples generated from $p$ sequentially and decides whether to accept each $X_{i}$ based on the ratio $q(X_{i})/p(X_{i})$. The algorithm outputs the first accepted sample, or a sample from a residual distribution $p^{\text{res}}$ if none of the samples is accepted. To guarantee that the final returned token is a valid sample from $q$, we choose an appropriate $\rho \in [1, k]$ and accept $X_{i}$ with probability $\min(1, q(X_{i})/(\rho \cdot p(X_{i})))$ instead of $\min(1, q(X_{i})/p(X_{i}))$ as in Algorithm 1. In Theorem 1, we show that with an appropriately chosen $\rho$, Algorithm 2 is indeed a valid transport plan from $p^{\otimes k}$ to $q$. Moreover, to find the best transport plan within this family, we only need to search over a single parameter $\rho$, which reduces the computation cost significantly. We also show that restricting the search to this sub-family of couplings decreases the optimal acceptance probability by at most a multiplicative constant factor. The performance of Algorithm 2 is stated in Theorem 1.
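A minimal sketch of k-Seq as described above, assuming NumPy probability vectors, $p(x) > 0$ on every draft, and a given $\rho$; the names `beta` and `k_seq_select` are hypothetical. The residual used after all $k$ rejections is our own derivation rather than a transcription of Algorithm 2: a single draft is accepted with probability $\beta_{p,q}(\rho) = \sum_x \min(p(x), q(x)/\rho)$, the first accepted draft equals $x$ with probability proportional to $\min(p(x), q(x)/\rho)$, and the residual is the mass of $q$ left over, which is non-negative exactly when $1 - (1 - \beta_{p,q}(\rho))^{k} \leq \rho\,\beta_{p,q}(\rho)$. The usage takes $\rho = k$, which always satisfies this condition by Bernoulli's inequality but is conservative.

```python
import numpy as np


def beta(p, q, rho):
    """beta_{p,q}(rho) = sum_x min(p(x), q(x) / rho)."""
    return float(np.minimum(p, q / rho).sum())


def k_seq_select(drafts, p, q, rho, rng):
    """Return one token from k i.i.d. drafts of p; distributed as q if rho is valid."""
    for x in drafts:
        # Accept X_i with probability min(1, q(X_i) / (rho * p(X_i))).
        if rng.random() < min(1.0, q[x] / (rho * p[x])):
            return int(x)
    # All drafts rejected: sample from the mass of q not covered by acceptances.
    b = beta(p, q, rho)
    p_any = 1.0 - (1.0 - b) ** len(drafts)
    accepted = p_any * np.minimum(p, q / rho) / b if b > 0 else np.zeros_like(q)
    residual = np.maximum(q - accepted, 0.0)  # clip tiny negative rounding errors
    return int(rng.choice(len(q), p=residual / residual.sum()))


rng = np.random.default_rng(0)
p = np.array([0.6, 0.3, 0.1])
q = np.array([0.2, 0.3, 0.5])
k, n = 3, 50_000
rho = float(k)  # always valid, but usually larger than necessary
samples = [k_seq_select(rng.choice(len(p), size=k, p=p), p, q, rho, rng) for _ in range(n)]
empirical_q = np.bincount(samples, minlength=len(q)) / n
print(empirical_q)  # close to q
```

Choosing a smaller valid $\rho$ raises $\beta_{p,q}(\rho)$ and hence the acceptance probability $1 - (1 - \beta_{p,q}(\rho))^{k}$, which is why the smallest valid value $\rho^{*}$ matters.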

Theorem 1. Let $\beta_{p,q}(\rho) = \sum_{x \in \Omega} \min\big(p(x), \frac{q(x)}{\rho}\big)$ and let $\rho^{*}$ be the solution to the identity below:
$$1 - \big(1 - \beta_{p,q}(\rho)\big)^{k} = \rho\, \beta_{p,q}(\rho).$$

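When $\rho \geq \rho^{*}$, the coupling $\pi^{\textsc{k-Seq}}_{\rho}$ in Algorithm 2 is a valid transport plan from $p^{\otimes k}$ to $q$. When $\rho = \rho^{*}$, we have
$$\alpha(\pi^{\textsc{k-Seq}}_{\rho^{*}}) = 1 - \big(1 - \beta_{p,q}(\rho^{*})\big)^{k} \geq \big(1 - (1 - 1/k)^{k}\big)\, \alpha_{k}(p, q).$$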

Moreover, $\rho^{*}$ can be computed up to accuracy $\delta$ in time $O(|\Omega| \log((k-1)/\delta))$.
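The runtime claim can be illustrated with a bisection sketch, assuming NumPy vectors and the hypothetical names `solve_rho_star` and `k_seq_acceptance`. Writing $g(\rho) = \rho\,\beta_{p,q}(\rho) - (1 - (1 - \beta_{p,q}(\rho))^{k})$, the first term $\sum_x \min(\rho p(x), q(x))$ is non-decreasing in $\rho$ and the subtracted term is non-increasing, $g(1) \leq 0$, and $g(k) \geq 0$ by Bernoulli's inequality. Hence bisection on $[1, k]$ needs $O(\log((k-1)/\delta))$ steps of $O(|\Omega|)$ work each.

```python
import numpy as np


def beta(p, q, rho):
    """beta_{p,q}(rho) = sum_x min(p(x), q(x) / rho)."""
    return float(np.minimum(p, q / rho).sum())


def solve_rho_star(p, q, k, delta=1e-9):
    """Bisection for the smallest valid rho in [1, k], up to accuracy delta."""
    def g(rho):
        b = beta(p, q, rho)
        return rho * b - (1.0 - (1.0 - b) ** k)

    lo, hi = 1.0, float(k)
    while hi - lo > delta:  # O(log((k - 1) / delta)) iterations, O(|Omega|) each
        mid = 0.5 * (lo + hi)
        if g(mid) >= 0:
            hi = mid
        else:
            lo = mid
    return hi  # g(hi) >= 0 is maintained, so hi is always a valid rho


def k_seq_acceptance(p, q, k):
    """Acceptance probability 1 - (1 - beta(rho*))^k of k-Seq at rho = rho*."""
    return 1.0 - (1.0 - beta(p, q, solve_rho_star(p, q, k))) ** k


p = np.array([0.75, 0.25])  # Ber(0.25), index 1 = head
q = np.array([0.25, 0.75])  # Ber(0.75)
for k in (1, 2, 4, 8):
    print(k, solve_rho_star(p, q, k), k_seq_acceptance(p, q, k))
```

For $k = 1$ the interval is empty, $\rho^{*} = 1$, and the acceptance probability reduces to $\sum_x \min(p(x), q(x)) = 1 - d_{\text{TV}}(p, q)$, matching Algorithm 1.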

We provide the proof in Section C.1. In Appendix B, using a few canonical examples of distributions, we plot the acceptance probability of k-Seq and compare it with the optimal acceptance probability $\alpha_{k}$. k-Seq can have a strictly worse acceptance probability than the OTM solution in certain cases, while there also exist non-trivial cases where k-Seq achieves the optimal acceptance probability.

Concurrent and recent work has proposed another efficient algorithm for the draft selection phase. To the best of our knowledge, no optimality guarantee has been proved for this algorithm. In Section B.3, we present its acceptance probability empirically for the canonical case of Bernoulli distributions and show that both of our proposed algorithms (OTM and k-Seq) have a higher acceptance probability.

SpecTr: Application of OTM in autoregressive sampling

In this section, we describe how OTM can be used to speed up auto-regressive sampling, which we refer to as SpecTr sampling. Similar to speculative decoding, each iteration of SpecTr can be decomposed into three phases (Fig. 2):

Draft set construction. Given the current context $x^{t}$, use the draft model to sample a set of $K$ draft sequences of length $L$, denoted by $S = \{z^{L} \sim \mathcal{M}_{s}(\cdot \mid x^{t})\}$. We keep the conditional probabilities $\mathcal{M}_{s}(y \mid x^{t}, z^{i})$ for all $y \in \Omega$, $i \leq L$, and $z^{L} \in S$.

Conditional probability computation. Compute the conditional probabilities on the next token for the large model, $\mathcal{M}_{b}(y \mid x^{t}, z^{i})$, for all $y \in \Omega$, $i \leq L$, and $z^{L} \in S$ in parallel.

Draft selection. Given the set of draft sequences and the conditional probabilities from both models, select the first $L'$ of the $L$ tokens, i.e., set $x(t+i) = z(i)$ for $i \leq L'$ and some $z \in S$. Sample a token from a residual distribution as a correction for the rejected tokens.

The conditional probability computation step takes $O(1)$ time when $|S|$ is not large, based on our simplified computational model. We mainly focus on the draft set construction and draft selection phases.

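Draft set with i.i.d. draft sequences. Given context $x^{t}$, a natural way to come up with a set of $K$ drafts is to independently sample $K$ draft sequences from $\mathcal{M}_{s}(\cdot \mid x^{t})$, i.e.,
$$S = \{z^{L}_{1}, \ldots, z^{L}_{K}\}, \qquad z^{L}_{j} \overset{\text{i.i.d.}}{\sim} \mathcal{M}_{s}(\cdot \mid x^{t}) \ \ \text{for } j \in [K]. \tag{7}$$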

The draft set construction method in (7) can be generalized to a prefix-tree based algorithm. However, this generalized version did not perform better in our experiments. We include this construction in Appendix D for completeness.

Draft selection with multiple candidates. We present the sequence-level selection algorithm given a set of draft sequences in Algorithm 3. We assume the conditional probabilities on the next token are available given any prefix in the candidate set since they are computed in parallel in the second phase, and won’t list them as inputs explicitly in Algorithm 3.

A sample run of the algorithm is presented in Fig. 3. The algorithm proceeds recursively. Given prompt $x^{t}$ and a candidate set $S$ sampled from $\mathcal{M}_{s}(\cdot \mid x^{t})$, the algorithm first computes a token-level draft selection algorithm $f_{\pi}: \Omega^{|S|} \to \Omega$, which is a transport plan from $\mathcal{M}_{s}(\cdot \mid x^{t})^{\otimes |S|}$ to $\mathcal{M}_{b}(\cdot \mid x^{t})$. Then $f_{\pi}$ is applied to the set of first tokens of the draft sequences in $S$ to obtain a valid token $Y$ from $\mathcal{M}_{b}(\cdot \mid x^{t})$. If $Y$ is not the last token ($L \geq 2$), we filter out the sequences in $S$ whose first token is not $Y$, denote the remaining sequences (with their first token removed) by $S_{\rm next}$, and feed them to the algorithm with context $(x^{t}, Y)$ and draft length $L - 1$. This continues until $L = 1$ or $S_{\rm next} = \emptyset$.

When $Y$ is the last token (i.e., $L = 1$) and $Y \in S$, we can sample an additional token from $\mathcal{M}_{b}(\cdot \mid (x^{t}, Y))$, since this conditional probability was already computed in the second phase. Due to the properties of the token-level selection algorithms and the autoregressive structure of language models, it can be shown that $Y$ is always a valid sample from $\mathcal{M}_{b}(\cdot \mid x^{t})$. Let $L'$ be the number of decoded tokens in one iteration. Note that this is a random variable in the range $[1, L+1]$.
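The three phases and the recursive selection can be sketched end to end, under loud assumptions: the two models are the hypothetical toy functions `small_model` and `large_model`, which map a prefix (a list of token ids over a 3-token vocabulary) to a next-token distribution that depends only on the last token; the token-level selector is the k-Seq sketch with $\rho = \rho^{*}$ found by bisection; and conditional probabilities are recomputed on demand rather than batched. This is an illustration of the control flow, not the authors' implementation of Algorithm 3.

```python
import numpy as np

V = 3  # toy vocabulary {0, 1, 2}


def small_model(prefix):
    """Hypothetical draft model M_s(. | prefix); depends only on the last token."""
    if not prefix or prefix[-1] == 0:
        return np.array([0.6, 0.3, 0.1])
    return np.array([0.2, 0.5, 0.3])


def large_model(prefix):
    """Hypothetical large model M_b(. | prefix); depends only on the last token."""
    if not prefix or prefix[-1] == 0:
        return np.array([0.4, 0.4, 0.2])
    return np.array([0.1, 0.6, 0.3])


def beta(p, q, rho):
    return float(np.minimum(p, q / rho).sum())


def rho_star(p, q, k, iters=40):
    """Bisection on [1, k] for the smallest rho with 1 - (1 - beta)^k <= rho * beta."""
    lo, hi = 1.0, float(k)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        b = beta(p, q, mid)
        if mid * b >= 1.0 - (1.0 - b) ** k:
            hi = mid
        else:
            lo = mid
    return hi


def k_seq_select(drafts, p, q, rng):
    """Token-level selection: a valid sample from q given i.i.d. drafts from p."""
    rho = rho_star(p, q, len(drafts))
    for x in drafts:
        if rng.random() < min(1.0, q[x] / (rho * p[x])):
            return x
    b = beta(p, q, rho)
    accepted = (1.0 - (1.0 - b) ** len(drafts)) * np.minimum(p, q / rho) / b
    residual = np.maximum(q - accepted, 0.0)
    return int(rng.choice(V, p=residual / residual.sum()))


def sample_drafts(context, K, L, rng):
    """Draft set construction: K i.i.d. length-L sequences from the small model."""
    drafts = []
    for _ in range(K):
        z = []
        for _ in range(L):
            z.append(int(rng.choice(V, p=small_model(context + z))))
        drafts.append(z)
    return drafts


def select_sequence(context, drafts, rng):
    """Recursive draft selection; returns the 1 to L+1 newly decoded tokens."""
    prefix, out = list(context), []
    while True:
        p, q = small_model(prefix), large_model(prefix)
        y = k_seq_select([z[0] for z in drafts], p, q, rng)
        out.append(y)
        prefix.append(y)
        drafts = [z[1:] for z in drafts if z[0] == y]  # keep drafts that agree with y
        if not drafts:  # no draft starts with y: stop after this token
            return out
        if not drafts[0]:  # y was the last draft token: one extra token from M_b
            out.append(int(rng.choice(V, p=large_model(prefix))))
            return out


rng = np.random.default_rng(0)
context = [0]
print(select_sequence(context, sample_drafts(context, K=3, L=2, rng=rng), rng))
```

Filtering is by token value, so a correction drawn from the residual that happens to match some draft's first token still continues the recursion with those drafts; their suffixes remain i.i.d. samples from $\mathcal{M}_{s}$ given the extended prefix.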

The formal quality guarantee is stated in Theorem 2. We present the proof in Section C.2.

Theorem 2. Assume all drafts in the set $S$ are generated from the small model with input $x^{t}$, or more precisely, $\forall z \in S$,
$$z^{L} \sim \mathcal{M}_{s}(\cdot \mid x^{t}).$$

Let $(x^{t}, Y^{\tau})$ be the output of Algorithm 3, where $\tau$ is the number of newly decoded tokens. Then $Y^{1:\tau}$ is distributed according to $\mathcal{M}_{b}(\underbrace{\cdot, \cdot, \ldots, \cdot}_{\tau \text{ dots}} \mid x^{t})$. More precisely, for any $\tau_{0} \in [1, L+1]$ and any sequence $o^{\tau_{0}} = (o(1), \ldots, o(\tau_{0})) \in \Omega^{\tau_{0}}$ of length $\tau_{0}$, we have

Experiments

We empirically evaluate SpecTr and compare it with two methods: (1) baseline autoregressive decoding; and (2) speculative decoding with $K = 1$. Note that all three methods effectively generate samples from the same baseline large model, and hence the quality of both speculative decoding methods is provably neutral relative to that of the large model. Thus, we only focus on measuring the speedup in our experiments. In the simplified computational model, we made the following assumptions: (1) the decoding time of the small model is negligible compared to decoding from the large model; (2) parallelization along the batch and time axes does not increase the time of a serial call to the large model. With these, the theoretical speedup over baseline decoding is the average number of decoded tokens per serial call, which is called block efficiency, defined below:
$$\text{block efficiency} = \frac{\text{total number of decoded tokens}}{\text{number of serial calls to } \mathcal{M}_{b}}.$$
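As a small illustration of the definition, the sketch below computes block efficiency from per-iteration token counts; the counts are made up for illustration and are not experimental data. It also evaluates an idealized value for $K = 1$ under a simplifying assumption that this paper does not make: each of the $L$ draft tokens is accepted independently with the same probability $a$. An iteration then yields at least $i + 1$ tokens with probability $a^{i}$ for $i = 0, \ldots, L$, so the expected count is $\sum_{i=0}^{L} a^{i}$, which equals $L + 1$ when $a = 1$, matching the ideal case $p = q$ discussed earlier. Both function names are hypothetical.

```python
def block_efficiency(tokens_per_iteration):
    """Average number of decoded tokens per serial call to the large model."""
    return sum(tokens_per_iteration) / len(tokens_per_iteration)


def idealized_block_efficiency(a, L):
    """Expected tokens per iteration for K = 1 if each of the L draft tokens is
    accepted independently with probability a (a simplifying assumption)."""
    return sum(a ** i for i in range(L + 1))


print(block_efficiency([3, 1, 5, 2]))      # 2.75 (illustrative counts only)
print(idealized_block_efficiency(1.0, 4))  # 5.0, i.e. L + 1
print(idealized_block_efficiency(0.5, 2))  # 1.75
```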

However, in real deployments of the SpecTr algorithm, the actual end-to-end (wall clock) speedup is further affected by the following aspects: (1) the decoding time of $\mathcal{M}_{s}$ might not be negligible; (2) parallelization along the batch and time axes might increase the time of a single call to $\mathcal{M}_{b}$; (3) overhead from implementing additional functionality in SpecTr, such as the draft selection algorithm and switching between models. These factors depend on how the algorithm is implemented and optimized. In our experiments, we consider both the block efficiency and the average wall clock speedup of our implementation of SpecTr.

We first present the performance of our algorithm and compare it to speculative decoding using state-of-the-art PALM-2 models with prompts from the One Billion Word benchmark (LM1B). In Appendix E, we use a pair of smaller transformer models to break down the different factors mentioned above.

In Table 1, we use PALM-2-Gecko and PALM-2-Bison as the small model and large model, respectively. The wall clock speedup is normalized by the wall clock latency of baseline autoregressive decoding. The time we log includes all of the above-mentioned aspects. In the considered parameter configurations, the wall clock speedup increases as $K$ and $L$ increase. As seen from the table, the actual wall clock speedup is smaller than the theoretical speedup given by block efficiency, which is consistent with our expectations. Importantly, the benefit from SpecTr outweighs these overheads. In particular, when $L = 8$ and $K = 8$, our proposed SpecTr algorithm has a speedup of 2.13x, a further 1.37x increase compared to speculative decoding ($K = 1$).

Acknowledgements

Authors thank Asaf Aharoni, Kwangjun Ahn, Badih Ghazi, Sanjiv Kumar, Teodor Marinov, Michael Riley, and NeurIPS reviewers for helpful comments and discussions.

References

Appendix A Properties of optimal transport cost

Below we provide information-theoretic upper and lower bounds in Lemma 3, which are tight up to a multiplicative constant of $1 - (1 - 1/k)^{k} \geq 1 - 1/e$. The proof is presented in Section A.3. For $k = 1$, the upper bound matches the optimal acceptance probability.

Lemma 3. For any two distributions $p, q$ and $\forall k \geq 1$, we have

In Appendix B, we plot $\alpha_{k}$ as a function of $k$ for a few simple pairs $(p, q)$ as illustrative examples. We note that the upper bound in Lemma 3 is tight for the examples considered in Appendix B.

A.2 Proof of Lemma 1

We first prove monotonicity. By definition,

Moreover, for any $\pi \in \Pi(p^{\otimes k}, q)$, we can construct $\pi' \in \Pi(p^{\otimes (k+1)}, q)$ by setting

i.e., by adding an independent sample from $p$ to $X^{k}$.

Next we prove consistency. We start with the case when $q(x)/p(x) < \infty$ for all $x \in \Omega$. To prove this, we show that Algorithm 2 with $\rho_{\max} = \max_{x \in \Omega} q(x)/p(x)$ satisfies

Since $\alpha(\pi^{\textsc{k-Seq}}_{\rho_{\max}}) \leq \alpha_{k}(p, q)$, the above equation implies $\lim_{k \to \infty} \alpha_{k}(p, q) = 1$. Notice that by Lemma 4 and Theorem 1, $\pi^{\textsc{k-Seq}}_{\rho_{\max}}$ is a valid coupling, and

where $\beta_{p,q}(\rho_{\max}) = \sum_{x \in \Omega} \min\big(p(x), \frac{q(x)}{\rho_{\max}}\big) \geq 1/\rho_{\max} > 0$. Taking $k \to \infty$ concludes the proof.

For the case where $q(x)/p(x)$ is unbounded, there exists $x\in\Omega$ such that $q(x)>0$ and $p(x)=0$. Let

Let $x_0$ be such that $p(x_0)>0$. We define $q'$ such that

Then we have $d_{\rm TV}(q,q')=p_{\rm off}$, and hence, by subadditivity of the transport cost,

Moreover, $q'(x)/p(x)<\infty$ for all $x\in\Omega$. Hence

A.3 Proof of Lemma 3

For the upper bound, it suffices to show that for any $\pi\in\Pi(p^{\otimes k},q)$ and any $\Omega_0\subset\Omega$, we have

For the lower bound, we show that k-Seq achieves an acceptance probability of at least $(1-(1-1/k)^k)\bar{\alpha}_k(p,q)$ (see Eq. 11), which implies the lower bound.
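The upper-bound display is missing from this extraction. As a hedged sketch, one elementary argument for an inequality of this form uses only the two marginals of $\pi$. This reconstruction is ours, chosen to be consistent with how the bound is specialized in Appendix B. The key observation is that the event $\{Y\notin\Omega_0,\ Y\in X^k\}$ forces some draft to lie outside $\Omega_0$:

```latex
\begin{aligned}
\Pr_{\pi}\big(Y\in X^k\big)
&\le \Pr\big(Y\in\Omega_0\big)+\Pr\big(Y\notin\Omega_0,\ Y\in X^k\big)\\
&\le \sum_{x\in\Omega_0}q(x)+\Pr\big(\exists\, i\le k:\ X_i\notin\Omega_0\big)\\
&= \sum_{x\in\Omega_0}q(x)+1-\Big(\sum_{x\in\Omega_0}p(x)\Big)^k .
\end{aligned}
```

The last step uses that $X_1,\ldots,X_k$ are i.i.d. from $p$. Minimizing the right-hand side over $\Omega_0$ then upper-bounds $\alpha_k(p,q)$.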

We illustrate the acceptance probabilities of our proposed token-level selection algorithms using a few simple examples, plotted in the two panels of Figure 5. The analysis of these simple distributions is presented in Appendices B.1 and B.2.

Let $\text{Ber}(b)$ be a Bernoulli distribution with probability $b$ of getting a head. In Figure 5, we compare the acceptance probabilities of OTM-$k$ and k-Seq as a function of $k$, for Bernoulli distributions $q=\text{Ber}(b)$ with different $b$ and $p=\text{Ber}(0.25)$. Note that when $p=q$ ($b=0.25$), the acceptance probability is always one for both methods. When $p\neq q$, the acceptance probabilities of both methods increase with $k$ toward one. When $b=0.1$ or $0.75$, k-Seq has a lower acceptance probability than OTM-$k$. When $b=1$, the two algorithms perform identically.

Pairs of uniform distributions.

Let $U(d)$ denote the uniform distribution over $[d]$. In Figure 5, we plot the optimal acceptance probability for different uniform distributions $q$ as a function of $k$. For these distributions, it can be shown that k-Seq achieves the optimal acceptance probability $\alpha_k$ (see Appendix B.2), so only $\alpha_k$ is plotted. Observe that all acceptance probabilities increase monotonically and tend to one as $k\to\infty$, as stated in Lemma 1.

In this section, we sketch the optimal acceptance probability calculations for the results in Figure 5.

With slight abuse of notation, let $p$ and $q$ denote the head probabilities of the two Bernoulli distributions. Consider the transport plan $\pi$ given by $\pi(1^k,1)=\min(p^k,q)$, $\pi(1^k,0)=p^k-\min(p^k,q)$, $\pi(0^k,0)=\min((1-p)^k,1-q)$, and $\pi(0^k,1)=(1-p)^k-\min((1-p)^k,1-q)$. It can be checked that this is a valid transport plan. Draft strings containing both symbols contain every possible output and are therefore always accepted. To see that the acceptance probability of $\pi$ matches the upper bound from Lemma 3, notice that

If $p^k\leq q$ and $(1-p)^k\leq 1-q$, then the above expression simplifies to $1$, and so does (10). If $p^k>q$ and $(1-p)^k\leq 1-q$, then the above expression simplifies to $1+q-p^k$, and (10) simplifies to the same quantity. The case $p^k\leq q$ and $(1-p)^k>1-q$ follows similarly. The two inequalities cannot fail simultaneously, since $p^k+(1-p)^k\leq 1$.
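The case analysis can be checked numerically. The sketch below is our illustration. It compares the acceptance probability of this plan with a brute-force evaluation of $\min_{\Omega_0}\{\sum_{x\in\Omega_0}q(x)+1-(\sum_{x\in\Omega_0}p(x))^k\}$ over the four subsets of $\{0,1\}$. We assume this is the form of (10) for Bernoulli distributions, which is consistent with the simplifications stated above.

```python
def plan_acceptance(p: float, q: float, k: int) -> float:
    """Acceptance probability of the plan above for Ber(p) drafts and target Ber(q).

    Only the all-ones and all-zeros draft strings can be rejected; strings that
    contain both symbols contain every possible output, so they are always accepted.
    """
    reject_ones = p**k - min(p**k, q)                   # mass on (1^k, 0)
    reject_zeros = (1 - p)**k - min((1 - p)**k, 1 - q)  # mass on (0^k, 1)
    return 1.0 - reject_ones - reject_zeros


def upper_bound(p: float, q: float, k: int) -> float:
    """Brute-force min over Omega_0 of q(Omega_0) + 1 - p(Omega_0)^k on {0, 1}."""
    P = {0: 1 - p, 1: p}
    Q = {0: 1 - q, 1: q}
    subsets = [(), (0,), (1,), (0, 1)]
    return min(sum(Q[x] for x in s) + 1 - sum(P[x] for x in s) ** k for s in subsets)


grid = [0.0, 0.1, 0.25, 0.5, 0.75, 1.0]
for p in grid:
    for q in grid:
        for k in range(1, 7):
            assert abs(plan_acceptance(p, q, k) - upper_bound(p, q, k)) < 1e-9
print(plan_acceptance(0.25, 1.0, 3))  # 0.578125, i.e., 1 - 0.75**3
```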

Figure 5: $p=U(d)$ and $q=U(d/r)$.

We first prove $\alpha_k(U(d),U(d/r))\geq 1-(1-1/r)^k$ by construction. Let $S(X^k)$ be the set of unique symbols in $X^k$. Consider the following transport plan. If $S(X^k)\cap[d/r]\neq\emptyset$, then $Y$ is drawn uniformly from $S(X^k)\cap[d/r]$. Otherwise, $Y$ is a fresh uniform sample from $[d/r]$. Since $U(d)$ is uniform over $[d]$, the output $Y$ is uniform over $[d/r]$ by symmetry, so this is a valid transport plan. Furthermore, $Y\in X^k$ whenever $S(X^k)\cap[d/r]\neq\emptyset$, which happens with probability $1-(1-1/r)^k$.

The upper bound follows by setting $\Omega_0=[d]\setminus[d/r]$ in Lemma 3.
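A Monte Carlo sketch of the construction for $p=U(d)$ and $q=U(d/r)$ follows. It is illustrative only; the function names are ours, and it assumes $r$ divides $d$. The simulation checks that the output is uniform on $[d/r]$ and that the acceptance rate is close to $1-(1-1/r)^k$.

```python
import random


def sample_plan(d: int, r: int, k: int, rng: random.Random):
    """One draw (X^k, Y) from the plan above with p = U(d) and q = U(d/r).

    Symbols are 1..d; assumes r divides d, so [d/r] = {1, ..., d // r}.
    """
    m = d // r
    xs = [rng.randint(1, d) for _ in range(k)]
    hits = sorted({x for x in xs if x <= m})  # S(X^k) intersected with [d/r]
    y = rng.choice(hits) if hits else rng.randint(1, m)
    return xs, y


def estimate(d: int, r: int, k: int, n: int = 100_000, seed: int = 0):
    """Monte Carlo estimates of Pr(Y in X^k) and of the law of Y."""
    rng = random.Random(seed)
    accepted = 0
    y_counts = [0] * (d // r)
    for _ in range(n):
        xs, y = sample_plan(d, r, k, rng)
        accepted += y in xs
        y_counts[y - 1] += 1
    return accepted / n, [c / n for c in y_counts]


acc, y_freq = estimate(d=12, r=3, k=2)
print(acc, 1 - (1 - 1 / 3) ** 2)  # estimate vs. 1 - (1 - 1/r)^k = 5/9
print(y_freq)                     # each entry should be close to 1/4
```

A fallback sample never lands in $X^k$, because in that case no draft lies in $[d/r]$. The acceptance probability is therefore exactly $\Pr(S(X^k)\cap[d/r]\neq\emptyset)$.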

B.2 Acceptance probability of k-Seq for the example in Figure 5

In this section, we show that for the example in Figure 5, k-Seq achieves the optimal acceptance probability. In this case, $p=U(d)$ and $q=U(d/r)$. Recall from Appendix B.1 that the optimal acceptance probability is $\alpha_k(U(d),U(d/r))=1-(1-1/r)^k$.

Moreover, for $\rho\in[1,r]$, we have $\beta(\rho)=\sum_{x\in[d/r]}\min\{1/d,\,r/(d\rho)\}=1/r$. Hence, solving $1-(1-\beta(\rho))^k=\rho\beta(\rho)$ gives $\rho^*=r(1-(1-1/r)^k)$, which indeed lies in $[1,r]$. By Theorem 1, we have

Equality holds since $1-(1-1/r)^k$ is also an upper bound on the acceptance probability of any coupling.

B.3 Comparison to multi-round rejection sampling in [21, 20]

In this section, we compare our proposed draft selection algorithms (OTM and k-Seq) with the multi-round rejection sampling algorithm (multi-round) from concurrent and recent work [21, 20] (see Algorithm 1 therein), using the example of Bernoulli distributions. As Figure 6 demonstrates, both of our proposed algorithms outperform multi-round. The advantage of OTM is expected, since OTM is optimal among all algorithms that guarantee the validity of the final accepted token. Our efficient algorithm k-Seq also outperforms multi-round in the considered examples. We leave a systematic comparison of the algorithms to future work.

Appendix C Analysis of SpecTr

We start by proving the following lemma on ρ\rho^{*}.

Let $\rho^*$ be the solution to Eq. 6, i.e., the root of $f(\rho):=1-(1-\beta_{p,q}(\rho))^k-\rho\beta_{p,q}(\rho)$. Then, when $d_{\rm TV}(p,q)\in(0,1)$,

$f(\rho)$ is monotone in $\rho$ on $[1,\infty)$;

$\rho^*\in\big[1,\min\{k,\max_x\frac{q(x)}{p(x)}\}\big]$.

It suffices to prove the following: (1) $f(\rho)$ is monotone in $\rho$ on $[1,\infty)$; (2) $f(1)\geq 0$; (3) $f(k)\leq 0$; (4) $f\big(\max_x\frac{q(x)}{p(x)}\big)\leq 0$.

To see (1), note that since $\beta_{p,q}(\rho)$ is decreasing in $\rho$, so is $1-(1-\beta_{p,q}(\rho))^k$. Moreover, $\rho\beta_{p,q}(\rho)=\sum_x\min\{\rho p(x),q(x)\}$, which is non-decreasing in $\rho$. Hence $1-(1-\beta_{p,q}(\rho))^k-\rho\beta_{p,q}(\rho)$ is decreasing.

To see (2), note that when $\rho=1$, $\beta_{p,q}(\rho)=1-d_{\rm TV}(p,q)$. Hence $f(1)=1-d_{\rm TV}(p,q)^k-(1-d_{\rm TV}(p,q))=d_{\rm TV}(p,q)-d_{\rm TV}(p,q)^k\geq 0$, since $d_{\rm TV}(p,q)\in(0,1)$.

When $\rho=k$, (3) holds since $1-(1-x)^k\leq kx$ for $x\in[0,1]$. Moreover, when $\rho=\max_x\frac{q(x)}{p(x)}>1$, we have $\beta_{p,q}(\rho)=1/\rho$, and (4) holds since $f(\rho)=1-(1-1/\rho)^k-1=-(1-1/\rho)^k\leq 0$.

Next, we prove Theorem 1. We break the proof into four parts: (1) computational efficiency; (2) $\pi_\rho^{\textsc{k-Seq}}$ is a valid transport plan; (3) acceptance probability; (4) optimality guarantee of $\pi^{\textsc{k-Seq}}_{\rho^*}$.
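Algorithm 2 is not reproduced in this excerpt, and the display defining $p^{\text{res}}$ is missing. The sketch below assumes the sequential rule implied by $p_{\rm acc}=1-(1-\beta_{p,q}(\rho))^k$. Each draft $X_i$ is accepted with probability $\min\{1,q(X_i)/(\rho p(X_i))\}$. If all $k$ drafts are rejected, $Y$ is drawn from a residual $p^{\text{res}}(x)\propto q(x)-\min\{p(x),q(x)/\rho\}\,p_{\rm acc}/\beta_{p,q}(\rho)$.

This residual is our reconstruction. It is the choice that makes the marginal of $Y$ equal to $q$, because the first stage outputs $x$ with probability $\min\{p(x),q(x)/\rho\}\sum_{i=1}^k(1-\beta_{p,q}(\rho))^{i-1}=\min\{p(x),q(x)/\rho\}\,p_{\rm acc}/\beta_{p,q}(\rho)$. The simulation checks the output marginal empirically on a hypothetical three-symbol example with $\rho=k$, which satisfies $\rho\geq\rho^*$ by Lemma 4.

```python
import random


def beta(p, q, rho):
    """beta_{p,q}(rho) = sum_x min(p(x), q(x) / rho)."""
    support = set(p) | set(q)
    return sum(min(p.get(x, 0.0), q.get(x, 0.0) / rho) for x in support)


def k_seq_sample(p, q, k, rho, rng):
    """Sketch of k-Seq: returns (drafts, y) for drafts X_1..X_k i.i.d. from p.

    Assumes rho >= rho* (so the residual weights are non-negative) and
    beta_{p,q}(rho) > 0 (the supports of p and q overlap).
    """
    drafts = rng.choices(list(p), weights=list(p.values()), k=k)
    for x in drafts:
        # Accept draft x with probability min(1, q(x) / (rho * p(x))).
        if rng.random() < min(1.0, q.get(x, 0.0) / (rho * p[x])):
            return drafts, x
    # All drafts rejected (probability (1 - beta)^k): sample from the residual.
    b = beta(p, q, rho)
    p_acc = 1.0 - (1.0 - b) ** k
    symbols = sorted(set(p) | set(q))
    weights = [
        max(0.0, q.get(x, 0.0) - min(p.get(x, 0.0), q.get(x, 0.0) / rho) * p_acc / b)
        for x in symbols
    ]
    # random.choices normalizes the weights, i.e., divides them by 1 - p_acc.
    return drafts, rng.choices(symbols, weights=weights, k=1)[0]


rng = random.Random(0)
p = {"a": 0.5, "b": 0.3, "c": 0.2}
q = {"a": 0.1, "b": 0.3, "c": 0.6}
k, n = 2, 100_000
counts = {x: 0 for x in q}
for _ in range(n):
    _, y = k_seq_sample(p, q, k, float(k), rng)  # rho = k >= rho* by Lemma 4
    counts[y] += 1
print({x: c / n for x, c in counts.items()})  # should be close to q
```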

Note that Lemma 4 immediately implies that $\rho^*$ can be computed to arbitrary accuracy $\delta$ in time $O(|\Omega|\log((k-1)/\delta))$ using binary search over $[1,k]$.
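A minimal bisection sketch for $\rho^*$ follows. The function names are hypothetical. By Lemma 4, $f$ is decreasing with $f(1)\geq 0\geq f(k)$, so each step halves the bracket at $O(|\Omega|)$ cost, which gives the stated running time. The usage example checks the closed form $\rho^*=r(1-(1-1/r)^k)$ from Appendix B.2.

```python
def beta(p, q, rho):
    """beta_{p,q}(rho) = sum_x min(p(x), q(x) / rho)."""
    support = set(p) | set(q)
    return sum(min(p.get(x, 0.0), q.get(x, 0.0) / rho) for x in support)


def f(p, q, rho, k):
    """f(rho) = 1 - (1 - beta(rho))^k - rho * beta(rho); decreasing in rho (Lemma 4)."""
    b = beta(p, q, rho)
    return 1.0 - (1.0 - b) ** k - rho * b


def solve_rho_star(p, q, k, delta=1e-10):
    """Bisection on [1, k] using f(1) >= 0 >= f(k); O(|Omega|) work per step.

    Returns the upper end of the final bracket, so that (up to floating-point
    error) f(rho) <= 0, i.e., rho >= rho*, the side on which k-Seq is valid.
    """
    lo, hi = 1.0, float(k)
    while hi - lo > delta:
        mid = 0.5 * (lo + hi)
        if f(p, q, mid, k) >= 0:
            lo = mid
        else:
            hi = mid
    return hi


# Uniform example from Appendix B.2 with d = 12, r = 3: p = U(12), q = U(4).
d, r, k = 12, 3, 4
p = {x: 1 / d for x in range(1, d + 1)}
q = {x: r / d for x in range(1, d // r + 1)}
rho_star = solve_rho_star(p, q, k)
print(rho_star, r * (1 - (1 - 1 / r) ** k))  # both approximately 65/27
```

Returning the upper end of the bracket errs on the side of $\rho\geq\rho^*$. This matters because validity of $\pi^{\textsc{k-Seq}}_\rho$ is guaranteed only for $\rho\geq\rho^*$.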

Valid transport plan.

We next prove that $\pi^{\textsc{k-Seq}}_\rho$ is a valid transport plan when $\rho\geq\rho^*$. By Lemma 4, $f$ is decreasing with $f(\rho^*)=0$. Hence, when $\rho\geq\rho^*$, we have $1-(1-\beta_{p,q}(\rho))^k\leq\rho\beta_{p,q}(\rho)$. Recall that $p_{\rm acc}=1-(1-\beta_{p,q}(\rho))^k$, and

this implies $p^{\text{res}}(x)\geq 0$ for all $x\in\Omega$. Moreover,

Hence, $p^{\text{res}}$ is a valid distribution. It remains to show that the marginal of $Y$ is $q$. We first compute the probability that the output is $Y=x$. Note that the probability that $Y=X_1$ is

Acceptance probability.

It can be seen that $\beta_{p,q}(\rho)$ is decreasing in $\rho$, and so is $1-(1-\beta_{p,q}(\rho))^k$. Hence we have

The statement holds since $f(x)=\frac{1-(1-x)^k}{kx}$ is monotonically decreasing for $x\in(0,1/k]$, with $f(1/k)=1-(1-1/k)^k$ and $\lim_{x\rightarrow 0^+}f(x)=1$.
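A numerical illustration of this fact evaluates the ratio on a grid in $(0,1/k]$. We call it `g` in code to avoid a clash with $f(\rho)$ from Lemma 4.

```python
def g(x: float, k: int) -> float:
    """g(x) = (1 - (1 - x)^k) / (k x), the ratio used in the acceptance bound."""
    return (1.0 - (1.0 - x) ** k) / (k * x)


for k in (2, 3, 5, 10):
    grid = [j / (1000 * k) for j in range(1, 1001)]  # points in (0, 1/k]
    values = [g(x, k) for x in grid]
    decreasing = all(a >= b for a, b in zip(values, values[1:]))
    # Monotonicity flag, value near 0, value at 1/k, and 1 - (1 - 1/k)^k.
    print(k, decreasing, values[0], values[-1], 1 - (1 - 1 / k) ** k)
```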

Moreover, $kx\geq 1-(1-x)^k$ for all $x\in[0,1]$. Hence we have

where the last inequality is due to the upper bound in Lemma 3 with $\Omega_0=\Omega$.

C.2 Proof of Theorem 2

We prove the theorem by induction. When $L=1$, $\tau\in\{1,2\}$. Let $k=|S|$. For the first step, $f_\pi$ in Algorithm 3 is a valid transport plan from ${\cal M}_s(\cdot\mid x^t)^{\otimes k}$ to ${\cal M}_b(\cdot\mid x^t)$. Hence $Y_1\sim{\cal M}_b(\cdot\mid x^t)$, which completes the proof when $\tau=1$. When $\tau=2$, we have $Y_2\sim{\cal M}_b(\cdot\mid x^t,Y_1)$, as stated in Step 5 of Algorithm 3. Hence the statement holds.

Note that in this case $\tau=\tau'+1$, and for any $(\tau_0+1)$-length sequence $o^{\tau_0+1}=(o(1),\ldots,o(\tau_0),o(\tau_0+1))\in\Omega^{\tau_0+1}$, we have

Combining the two cases, we complete the proof.

Appendix D Candidate set construction via a prefix-tree

As discussed in Section 1, the size of the draft set $S$ is constrained by the number of parallel computations that the hardware can support. Hence, it is important to design the draft set carefully so that longer sequences of draft tokens are accepted. In addition to the i.i.d. draft set selection approach described in Section 7, we present an algorithm that samples a draft set forming the leaves of a prefix tree. Given a draft set size $K$, the algorithm is specified by a sequence of parameters $(k_1,k_2,\ldots,k_L)$ satisfying $\prod_{i=1}^L k_i=K$.

The algorithm starts with a root node containing the sequence $x^{1:t}$ and builds a prefix tree of depth $L$. At depth $i\in[0:L-1]$ (with the root at depth $0$), each node is expanded into $k_{i+1}$ children. Each child contains a sequence that satisfies two conditions: (1) its prefix agrees with the sequence in the parent node; (2) its next token is sampled from the small model's conditional distribution given that prefix. These child nodes are at depth $i+1$, and the process continues until it reaches depth $L$. We give a detailed description of the algorithm in Algorithm 4.
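Algorithm 4 is not reproduced in this excerpt. The sketch below follows the description above. A hypothetical `sample_next` callback stands in for sampling one token from the small model given a prefix, and integer tokens are a toy assumption for illustration.

```python
import math
import random
from typing import Callable, List, Sequence

Token = int  # toy token type for illustration


def prefix_tree_drafts(
    context: Sequence[Token],
    ks: Sequence[int],
    sample_next: Callable[[Sequence[Token]], Token],
) -> List[List[Token]]:
    """Return the prod(ks) leaves of a depth-len(ks) prefix tree of drafts.

    Starting from the root (the context), every node at depth i is expanded
    into ks[i] children (k_{i+1} in 1-indexed notation); each child copies its
    parent's sequence and appends one token drawn via sample_next, a
    hypothetical stand-in for sampling from the small model given the prefix.
    """
    frontier = [list(context)]
    for k_i in ks:
        frontier = [node + [sample_next(node)] for node in frontier for _ in range(k_i)]
    return frontier


rng = random.Random(0)
ks = [2, 2, 2]  # K = 8 drafts, each extending the context by L = 3 tokens
drafts = prefix_tree_drafts([1, 2, 3], ks, lambda prefix: rng.randrange(10))
assert len(drafts) == math.prod(ks)
```

Unlike $K$ i.i.d. drafts, sibling leaves share their prefixes. This lets a fixed parallel budget cover more distinct continuations near the root.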

Appendix E Additional experiments

In this section, we investigate in detail the factors that affect the speed of SpecTr, using smaller transformer models. We train decoder-only transformer models on the LM1B dataset, based on the example provided in the FLAX library. For the draft model, we use transformer models with 2M, 6M, and 20M parameters, and for the large model, we use a 97M-parameter transformer model.

We first verify the computational model introduced in Section 1. To do so, we report the latencies of the large model when computing the probability distributions with parallelization over the time and batch axes. As shown in Table 2, the latency stays roughly constant in these settings.

As in Table 2, we report in Table 3 the relative latency when parallelizing across the time and batch axes, here using the small 6M draft model. The latencies in Table 3 are relative to the large 97M model. This gives a sense of the cost of sampling multiple drafts with the small model compared to running the large model.

To see how the size of the draft model affects the block efficiency, we also include results in Table 4 for varying draft model sizes, using the same 97M large model on LM1B. These draft models were produced by either halving (2M) or doubling (20M) the original 6M draft model's number of layers, embedding dimension, MLP dimension, and number of attention heads. As expected, larger draft models improve the block efficiency of all speculative methods, and SpecTr achieves the best performance across all draft model sizes.