DistillSpec: Improving Speculative Decoding via Knowledge Distillation

Yongchao Zhou, Kaifeng Lyu, Ankit Singh Rawat, Aditya Krishna Menon, Afshin Rostamizadeh, Sanjiv Kumar, Jean-François Kagy, Rishabh Agarwal

Introduction

Large language models (LLMs) have revolutionized natural language understanding and generation across diverse applications (OpenAI, 2023; Anil et al., 2023). However, their autoregressive nature poses significant computational challenges, especially in real-time deployments with stringent latency constraints (Thoppilan et al., 2022; Pope et al., 2023). Conversely, smaller language models, while computationally efficient, often lack the expressive power of their larger counterparts and achieve subpar performance. Reducing the inference cost of larger models (e.g., via quantization or pruning) and improving the performance of smaller models (e.g., via knowledge distillation (KD) (Hinton et al., 2015)) are natural approaches to a favorable trade-off between performance and inference cost, but they frequently leave an unacceptable performance gap compared to the high-quality large models. This has inspired a growing literature on mechanisms that combine both large and small models at inference to approximate the performance of the larger models without incurring their high computational cost.

Conventionally, model cascading approaches aim to identify easy instances where smaller models suffice to achieve good performance, thereby soliciting larger models only on a subset of hard instances (Rowley et al., 1998; Xu et al., 2014) or tasks (Cai et al., 2023b). Different from such task- or instance-level cascading, speculative decoding (SD) (Leviathan et al., 2023; Chen et al., 2023) aims to exploit the token-level variability in the computation demand during LLM inference by interactively invoking a small “draft” model and a large “target” model. At a given stage during inference, the draft model generates successive candidate tokens for multiple inference steps via autoregressive decoding. The target model then verifies the candidate tokens via parallel decoding, and employs rejection sampling to accept a subset of candidate tokens at contiguous positions.

The main objective of SD is to speed up text generation while guaranteeing that the decoded tokens follow the target model distribution. SD relies on the insight that the combined cost of autoregressive decoding with a small draft model followed by parallel decoding with the target model is lower than the cost of autoregressive decoding with the target model alone. However, the realized inference cost reduction or latency improvement crucially depends on the acceptance rate of the draft-generated tokens by the target model, which can be shown to be directly tied to the alignment between the token distributions of the draft and target models. Thus, a successful application of SD hinges on identifying a compact draft model that simultaneously has small autoregressive decoding cost and is closely aligned with the target model.

In this work, we propose DistillSpec, a novel approach that relies on KD (Hinton et al., 2015) to obtain an effective draft model. Unlike the standard application of KD which primarily focuses on improving the task performance of a small student model, DistillSpec aims at aligning the student (draft) model with the teacher (target) model to enhance the acceptance rate during speculative decoding. This requires the student model to closely approximate the teacher distribution at the token and sequence level, even if it translates to suboptimal downstream task performance.

We undertake a comprehensive exploration of the distillation process for speeding up SD, considering several factors including the composition of training data, the choice of divergence function defining the KD training objective, and decoding strategies. Notably, our findings underscore that using model-generated data is crucial for ensuring strong student-teacher alignment across various tasks via KD, and that the best-performing divergence function in DistillSpec is highly task-dependent and sensitive to the decoding strategy (i.e., greedy versus non-greedy). Furthermore, we explore the utility of DistillSpec for lossy SD (Leviathan et al., 2023), which allows for sampling away from the target model distribution. We show that combining DistillSpec with lossy SD enables more fine-grained control over the latency versus task performance trade-off.

Finally, we carry out a systematic study of how to design an efficient inference scheme in a practical setting where one has access to multiple language models of increasing size and quality. Leveraging the insights about KD and SD laid out in this paper, our study concludes that the most effective strategy involves first distilling a large model into a smaller one to serve as the target model for performance optimization, followed by DistillSpec for distilling an even smaller model to be used as the draft model in SD. This approach results in a remarkable 6-10× reduction in latency, compared to a standalone non-distilled target of the same size, with minimal performance degradation.

We propose DistillSpec that uses KD to enhance draft model alignment with the target model (§4), and show that our method can improve SD speed by 10-45% while preserving model performance across four diverse datasets with both greedy and non-greedy sampling (Figure 1).

We conduct an extensive analysis of the optimal distillation recipe (§5.2) for model alignment, encompassing factors such as training data generation and different divergences, and emphasizing the distinctions between standard KD and distillation tailored for SD.

We extend DistillSpec to lossy SD, enabling refined control over the quality-latency trade-off. Moreover, we offer insights for combining KD and SD when several models are available (§5.3).

Related Work

Due to the inherently sequential nature of autoregressive decoding, the primary latency bottleneck in LLM inference arises from memory read/write operations rather than arithmetic computations (Pope et al., 2023). Speculative decoding (SD) (Leviathan et al., 2023; Chen et al., 2023) addresses this challenge by utilizing a compact draft model to generate a block of tokens sequentially, while validating them in parallel with a larger target model. Prior to SD, various parallel computing paradigms were explored for autoregressive models, including blockwise parallel decoding (Stern et al., 2018), shallow aggressive decoding (Sun et al., 2021), and aggressive decoding (Ge et al., 2022). However, these approaches are not readily adaptable to typical language models due to potential deviations from the target model’s distribution, strict input constraints, or limited support for general stochastic sampling. Notably, recent variants of SD have also incorporated parallel computation along the batch axis, sometimes combined with token tree verification, as seen in SpecTr (Sun et al., 2023), SpecInfer (Miao et al., 2023), and Medusa (Cai et al., 2023a). In contrast, our work focuses on enhancing SD by improving the alignment between the small draft model and the large target model through KD, which does not require any changes to serving infrastructures already implementing SD and is complementary to the recent variants of SD. Furthermore, we conduct a systematic study of lossy SD to provide nuanced control over the trade-off between quality and latency for specific serving models.

KD (Buciluǎ et al., 2006; Hinton et al., 2015), which trains high-quality smaller student models with the help of larger teacher models, has emerged as a vital technique for reducing inference cost while maintaining performance across a range of domains. In the context of LLMs, prior uses of KD (Taori et al., 2023; Fu et al., 2023) have mostly focused on black-box KD, wherein only the teacher’s output generations, often obtained via APIs, are accessible during student training. However, with the proliferation of open-source LLMs (Zhang et al., 2022; Touvron et al., 2023), there is growing interest in white-box KD, where we have access to the teacher’s weights and logits. White-box KD allows student models to benefit from the richer supervision signals provided by white-box teacher models, leading to enhanced language abilities (Agarwal et al., 2023; Gu et al., 2023; Wen et al., 2023). Despite notable improvements in student quality, substantial performance gaps persist between large and small models (OpenAI, 2023; Anil et al., 2023), which may remain unbridgeable through distillation alone.

Unlike prior works focused on creating highly capable standalone student models, we harness KD to foster closer collaboration between smaller and larger models in SD, which may be particularly valuable when a small distilled model alone cannot meet stringent quality requirements. While Stern et al. (2018) use a black-box KD approach (SeqKD) to enhance blockwise parallel decoding, they use samples generated from the large target model, which is prohibitively expensive for LLMs. Furthermore, they ignore the teacher model’s logits and train their draft model using only one-hot teacher labels, a reasonable choice for greedy decoding but a less effective one for non-greedy sampling (Figure 2). Relatedly, Rawat et al. (2021) leverage KD to improve model cascading. However, unlike our work, which focuses on text generation with LLMs, their study focuses on classification tasks in the vision and NLP domains.

Background: Speculative Decoding

Given an input sequence $x$ comprising tokens from a pre-defined vocabulary, a language model $\mathscr{M}$ provides a distribution over possible output sequences $y$. Suppose we employ SD with a compact draft model $\mathscr{M}_q$ to assist a larger target model $\mathscr{M}_p$. Let $p(y_t \mid x, y_{<t})$ and $q(y_t \mid x, y_{<t})$ represent the distributions governing next-token predictions at time step $t$ for $\mathscr{M}_p$ and $\mathscr{M}_q$, respectively, given the context $\rho = \{x, y_{<t}\}$. Given input $x$ as prefix, let $p_{\leq T}(y \mid x)$ and $q_{\leq T}(y \mid x)$ represent the distributions governing the sequence $y$ sampled autoregressively from $\mathscr{M}_p$ and $\mathscr{M}_q$, respectively, where generation stops either when an end-of-sequence token is sampled or when the maximum sequence length $T$ is reached. For simplicity, we use $p(y_t)$ and $q(y_t)$ as shorthands for $p(y_t \mid x, y_{<t})$ and $q(y_t \mid x, y_{<t})$ whenever the context $\rho$ is clear. Similarly, $p_{\leq T}(y)$ and $q_{\leq T}(y)$ serve as shorthands for $p_{\leq T}(y \mid x)$ and $q_{\leq T}(y \mid x)$ whenever the input $x$ is clear.

Standard SD uses a procedure called speculative sampling to generate tokens from the draft model while maintaining the same output distribution as the target model. As detailed in Algorithm A.1 (Appendix), each step of SD works as follows. First, a block of $\gamma$ tokens, denoted $y_t, \dots, y_{t+\gamma-1}$, is autoregressively sampled from $q(y_t), \dots, q(y_{t+\gamma-1})$. Next, the $\gamma$ tokens are verified in parallel by passing them to $\mathscr{M}_p$ as a whole block, which sequentially accepts token $y_{t+i}$ with probability $\min\left(1, p(y_{t+i})/q(y_{t+i})\right)$. If any token $y_{t+i}$ is rejected before the end of the block, the subsequent tokens are discarded and the rejected token is resampled from the adjusted distribution $p'(y_{t+i}) \propto \max\left(0, p(y_{t+i}) - q(y_{t+i})\right)$; otherwise, all tokens are accepted and an extra token is sampled from $p(y_{t+\gamma})$ and appended to the output sequence. This process guarantees that the sequence of accepted and resampled tokens follows the same output distribution as $p(y_{t+i})$ (Leviathan et al., 2023). The procedure is repeated until an end-of-sequence token is accepted or the maximum sequence length $T$ has been reached.
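To make the accept/resample logic concrete, here is a minimal pure-Python sketch of one speculative sampling step. It assumes toy models exposed as hypothetical callables `draft(ctx)` and `target(ctx)` that return next-token distributions as lists of probabilities. A real system scores all $\gamma + 1$ positions with the target in a single parallel forward pass, which the sketch imitates with one list comprehension. It illustrates the procedure described above and is not the paper's Algorithm A.1.

```python
import random


def sample(probs, rng):
    """Draw an index from a categorical distribution given as a list of probabilities."""
    r, acc = rng.random(), 0.0
    for i, p in enumerate(probs):
        acc += p
        if r < acc:
            return i
    # Floating-point round-off: fall back to the last index with positive mass.
    return max(i for i, p in enumerate(probs) if p > 0)


def speculative_step(prefix, draft, target, gamma, rng):
    """Run one speculative sampling step and return the tokens it appends to `prefix`.

    `draft(ctx)` and `target(ctx)` are hypothetical callables returning the
    next-token distributions q(. | ctx) and p(. | ctx) as lists of probabilities.
    """
    # 1) Draft a block of gamma tokens autoregressively from q.
    ctx, proposals, q_dists = list(prefix), [], []
    for _ in range(gamma):
        q = draft(tuple(ctx))
        y = sample(q, rng)
        proposals.append(y)
        q_dists.append(q)
        ctx.append(y)
    # 2) Score all gamma + 1 positions with the target (one parallel pass in practice).
    p_dists = [target(tuple(prefix) + tuple(proposals[:i])) for i in range(gamma + 1)]
    # 3) Verify left to right: accept y with probability min(1, p(y) / q(y)).
    out = []
    for i, y in enumerate(proposals):
        p, q = p_dists[i], q_dists[i]
        if rng.random() < min(1.0, p[y] / q[y]):
            out.append(y)
            continue
        # Rejected: discard the rest and resample from p'(y) ∝ max(0, p(y) - q(y)).
        residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
        z = sum(residual)
        out.append(sample([r / z for r in residual], rng))
        return out
    # Whole block accepted: append one extra token sampled from p(y_{t+gamma}).
    out.append(sample(p_dists[gamma], rng))
    return out
```

With context-independent toy distributions, e.g. `speculative_step((), lambda c: [0.2, 0.5, 0.3], lambda c: [0.6, 0.3, 0.1], 3, random.Random(0))`, each call emits between 1 and $\gamma + 1$ tokens, and over many calls the first emitted token is distributed according to $p$, which is the distribution-preservation property cited above.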

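Each SD step takes a constant amount of time, so the wall-clock time scales linearly with the number of steps. This number equals the total number of steps in which the target model rejects a token, plus the number of blocks accepted as a whole, where the latter term is small for large $\gamma$. This motivates us to use the acceptance rate as a surrogate efficiency measure for the wall-clock time. For an ideal SD process with $\gamma = \infty$, we define the sequence-level acceptance rate $\alpha(x)$ for a given input $x$ as follows:

$$\alpha(x) := \frac{\mathbb{E}[\text{number of accepted tokens in generating } y]}{\mathbb{E}[\text{number of tokens in } y]} = \frac{\mathbb{E}_{y \sim p_{\leq T}(y \mid x)}\left[\sum_{t=1}^{|y|} \beta(x, y_{<t})\right]}{L_p(x)}, \qquad (1)$$

where $\beta(x, y_{<t}) := \mathbb{E}_{y_t \sim q(y_t)}\left[\min\left(1, p(y_t)/q(y_t)\right)\right]$ is the token-level acceptance rate and the expectations are taken over the randomness in SD. Since speculative sampling preserves the output distribution of $\mathscr{M}_p$, the denominator is the expected length of the target output, $L_p(x) := \mathbb{E}_{y \sim p_{\leq T}(y \mid x)}[|y|]$, which does not depend on the draft model. Hence, the expected total number of rejected tokens is $(1 - \alpha(x))\, L_p(x)$.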

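Walltime improvement. Let the block efficiency $\tau(x)$ denote the expected number of tokens generated per SD step. For a given block efficiency $\tau(x)$, the expected speedup of SD is $\tau(x)/(c\gamma + 1)$, where the relative latency $c$ is the ratio between the times for a single forward pass through the draft model $\mathscr{M}_q$ and through the target model $\mathscr{M}_p$. Under the assumption that tokens are accepted i.i.d. with rate $\alpha(x)$, the block efficiency is $\tau(x) = (1 - \alpha(x)^{\gamma+1})/(1 - \alpha(x))$ (Leviathan et al., 2023).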
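As a back-of-the-envelope aid, the sketch below evaluates the speedup formula $\tau/(c\gamma + 1)$. It assumes the i.i.d. acceptance model of Leviathan et al. (2023), under which $\tau = (1 - \alpha^{\gamma+1})/(1 - \alpha)$; `best_gamma` is a hypothetical helper that simply scans candidate block sizes.

```python
def block_efficiency(alpha, gamma):
    """Expected tokens per SD step if each draft token is accepted i.i.d. with probability alpha."""
    if alpha >= 1.0:
        return float(gamma + 1)  # every draft accepted, plus the extra target token
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


def expected_speedup(alpha, gamma, c):
    """Walltime improvement tau / (c * gamma + 1); c = draft/target forward-pass latency ratio."""
    return block_efficiency(alpha, gamma) / (c * gamma + 1.0)


def best_gamma(alpha, c, max_gamma=16):
    """Block size in 1..max_gamma that maximizes the expected speedup."""
    return max(range(1, max_gamma + 1), key=lambda g: expected_speedup(alpha, g, c))
```

For instance, with $\alpha = 0.8$, $\gamma = 3$, and $c = 0.05$, the formula gives $\tau = 2.952$ and an expected speedup of about 2.57×, and scanning block sizes for the same $\alpha$ and $c$ picks $\gamma = 8$. Raising $\alpha$ raises $\tau$ for every $\gamma$, which is why better draft-target alignment translates directly into lower latency.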

DistillSpec: Knowledge Distillation for Speculative Decoding

As described in § 3, speculative decoding (SD) can reduce the latency of the larger (target) model with the help of a smaller (draft) model without any performance drop. However, the realized latency reduction via SD critically depends on how well aligned the draft model is with the target model. Taking the speedup of SD as the primary objective, our proposed DistillSpec method achieves a closer alignment between the draft and target models by leveraging KD. We first present KD-based training of the draft model. We highlight how our objective of enhancing SD via KD influences our selection of the training data generation method and the divergence function, the two key ingredients of DistillSpec. We then discuss how DistillSpec can be extended to lossy SD.

Let the draft model $\mathscr{M}_q^\theta$ be parameterized by $\theta$. DistillSpec utilizes predictions from the target model $\mathscr{M}_p$ as a source of supervision (teacher) while training the draft model $\mathscr{M}_q^\theta$ (student). We assume white-box access to both models, i.e., we can obtain their token-level distributions $p(y_t)$ and $q(y_t)$, and can therefore generate samples from both the target and draft models. Given a divergence function $D$ that measures the misalignment between two distributions, KD-based training of the draft model seeks to minimize the divergence between the teacher (target) and student (draft) distributions over a training set $\mathscr{G}$:

$$\theta^* := \arg\min_{\theta} \; \mathbb{E}_{(x, y) \sim \mathscr{G}}\left[ D(\mathscr{M}_p \,\|\, \mathscr{M}_q^\theta)(y \mid x) \right], \qquad (2)$$

where $D(\mathscr{M}_p \,\|\, \mathscr{M}_q^\theta)(y \mid x) := \frac{1}{|y|} \sum_{t=1}^{|y|} D\big(p(\cdot \mid x, y_{<t}) \,\|\, q^{\theta}(\cdot \mid x, y_{<t})\big)$ averages the token-level divergence over the positions of $y$.

Our choices for $\mathscr{G}$ and $D$ are guided by how the resulting distilled model, once employed as the draft model, improves the speed of SD. Towards this, we first highlight the role that the total variation distance between $p(y_t)$ and $q(y_t)$ plays in dictating the acceptance rate (§ 3), a key efficiency measure for SD.

Leviathan et al. (2023, Corollary 3.6) show that the token-level acceptance rate $\beta(x, y_{<t})$ satisfies $\beta(x, y_{<t}) = 1 - D_{\mathrm{TVD}}(p(y_t), q(y_t))$. Hence, Eq. 1 implies that maximizing the sequence-level acceptance rate $\alpha(x)$ is equivalent to minimizing the expected $D_{\mathrm{TVD}}$ between $p(y_t)$ and $q(y_t)$ over the output sequence distribution of $\mathscr{M}_p$, i.e.,

$$\alpha(x) = 1 - \frac{\mathbb{E}_{y \sim p_{\leq T}(y \mid x)}\left[\sum_{t=1}^{|y|} D_{\mathrm{TVD}}(p(y_t), q(y_t))\right]}{\mathbb{E}_{y \sim p_{\leq T}(y \mid x)}[|y|]}. \qquad (3)$$

Choice of divergence. Based on Eq. 3, directly minimizing $D_{\mathrm{TVD}}$ appears to be a principled objective for draft model distillation. However, while optimizing $D_{\mathrm{TVD}}(p, q)$ is theoretically motivated, our empirical study shows that such an objective does not consistently yield the best results. We find that the choice of the most suitable divergence is highly task-dependent (§ 5.2).
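The identity behind this reasoning can be checked numerically. Since $q(y)\min(1, p(y)/q(y)) = \min(p(y), q(y))$, summing over the vocabulary gives $\beta = \sum_y \min(p(y), q(y)) = 1 - D_{\mathrm{TVD}}(p, q)$. The toy distributions used below are made up for illustration.

```python
def token_acceptance_rate(p, q):
    """beta = E_{y ~ q}[min(1, p(y) / q(y))]: probability that a drafted token is accepted."""
    return sum(qi * min(1.0, pi / qi) for pi, qi in zip(p, q) if qi > 0)


def tvd(p, q):
    """Total variation distance between two distributions over the same vocabulary."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))
```

For $p = (0.6, 0.3, 0.1)$ and $q = (0.2, 0.5, 0.3)$, the two functions agree: $\beta = 0.6$ and $D_{\mathrm{TVD}} = 0.4$.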

As for $\mathscr{G}$, Eq. 3 suggests optimizing the expected $D_{\mathrm{TVD}}$ over outputs generated from the teacher. However, decoding from a large teacher is generally prohibitively expensive, especially at the dataset scale required for KD. Alternatively, one could resort to an existing ground-truth dataset; however, the teacher’s output distribution may deviate from the ground-truth distribution even though the teacher has been fine-tuned on it. Moreover, ground-truth datasets are often limited in size, so training only on such data could result in overfitting. To resolve these issues, we explore using on-policy data during distillation, i.e., output sequences sampled from the student itself. Besides being more computationally efficient than teacher generations, this approach is inspired by Gu et al. (2023) and Agarwal et al. (2023), where distilling on on-policy data is shown to improve student task performance. However, unlike these prior works, our primary focus is on improving student-teacher alignment. Thus, it is not immediately clear whether minimizing the expected $D_{\mathrm{TVD}}$ over on-policy (student-generated) data ensures an improved acceptance rate, which is computed as an expectation over the teacher’s output distribution (cf. Eq. 3). Our following result shows that this is indeed the case.

We defer the proof to Section B.2. Intuitively, it builds upon the following insights. If the on-policy KD loss is small, then, for any $1 \leq t \leq T$, the same loss evaluated only at the $t$-th token should also be small. Since the first token generation is independent of any other tokens, a small on-policy KD loss ensures that the first-token distributions of the draft and target models are close. Then, an inductive argument shows that once the draft and target are similar on the first $t$ tokens, the distributions of the $(t+1)$-th token should also be close. Our proof makes this argument rigorous by utilizing variational representations of $D_{\mathrm{TVD}}$, leading to an error bound that is linear in $T$.
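The following toy experiment, our own illustration rather than the paper's setup, mirrors the argument above: a tabular bigram "student" is trained with forward KL only on contexts it samples itself (on-policy data), and alignment is then measured as the average token-level TVD along sequences sampled from a bigram "teacher", i.e., one minus the token-level acceptance rate. The vocabulary size `V`, length `T`, learning rate, and step counts are arbitrary assumptions.

```python
import math
import random

V, T = 4, 8  # toy vocabulary size and sequence length (hypothetical)


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def rollout(logits, rng, bos=0):
    """Sample T tokens from a bigram model; return the contexts (previous tokens) visited."""
    contexts, prev = [], bos
    for _ in range(T):
        contexts.append(prev)
        prev = rng.choices(range(V), weights=softmax(logits[prev]))[0]
    return contexts


def mean_tvd_on_teacher_data(teacher, student, rng, n=500):
    """Average per-token TVD along teacher samples, i.e. 1 - token acceptance rate."""
    total = 0.0
    for _ in range(n):
        for c in rollout(teacher, rng):
            p, q = softmax(teacher[c]), softmax(student[c])
            total += 0.5 * sum(abs(a - b) for a, b in zip(p, q))
    return total / (n * T)


def distill_on_policy(teacher, student, rng, steps=300, lr=0.5):
    """Gradient descent on forward KL, evaluated only at contexts the *student* generates."""
    for _ in range(steps):
        for c in rollout(student, rng):
            p, q = softmax(teacher[c]), softmax(student[c])
            # Gradient of KL(p || softmax(z)) w.r.t. the logits z is softmax(z) - p.
            student[c] = [z - lr * (qi - pi) for z, qi, pi in zip(student[c], q, p)]
    return student
```

Seeding a random teacher (e.g., Gaussian logits), starting from a uniform student, and comparing `mean_tvd_on_teacher_data` before and after `distill_on_policy` shows the teacher-side TVD dropping substantially, even though the student never trains on teacher samples.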

Experiments

We evaluate the effectiveness of KD in improving the speed of speculative decoding (SD). We specifically investigate its impact on the enhancement of acceptance rate, block efficiency, and latency.

Following Leviathan et al. (2023), we evaluate two model types: 1) GPT-like decoder-only Transformer models trained using the standard autoregressive objective on the LM1B task (Chelba et al., 2013), where the target and draft models have 234M and 33M parameters, respectively; and 2) standard encoder-decoder T5 v1.1 models (Raffel et al., 2020) supervised fine-tuned on four different tasks, with T5-XL (3B) and T5-Small (77M) serving as the target and draft models, respectively. Of the four datasets, two come from the T5 paper, namely WMT EnDe (Bojar et al., 2014) and CNN/DM (Hermann et al., 2015), which deal with translation and text summarization, respectively. The remaining two tasks used to test T5 models are XSum (Narayan et al., 2018) and GSM8K (Cobbe et al., 2021), which deal with abstractive summarization and arithmetic reasoning, respectively. See Appendix C for more details.

We study five KD algorithms outlined in Table 1. For SeqKD (Kim & Rush, 2016) and $f$-Distill (Wen et al., 2023), we opt for online data generation from the teacher instead of a conventional fixed, offline teacher-generated dataset. This approach, while computationally expensive, yields a more diverse dataset. For GKD, we rely solely on data generated by the online student model, excluding the static ground-truth data. All data generated by either the teacher or the student is obtained via temperature sampling with a temperature of 1.0 (see Appendix D.1.3 for an ablation study on sampling temperature).

For each pair of target and draft models, we measure the empirical acceptance rate $\alpha$ and block efficiency $\tau$ with three different block sizes $\gamma \in \{3, 5, 7\}$, both with and without distillation. As in Leviathan et al. (2023), we measure the relative wall-time improvements with a batch size of 1 for both greedy sampling ($T = 0$) and standard temperature sampling ($T = 1$).

Figure 1 shows that distillation has a clear impact on SD speed, consistently yielding a 10-46% improvement across various datasets. This effect is most pronounced with greedy decoding. A summary of results for different block sizes and decoding strategies across five datasets is presented in Table D.1 (Appendix), highlighting the latency gains from KD. The findings demonstrate that KD significantly enhances the acceptance rate and block efficiency for both decoder-only and encoder-decoder models across all datasets. Distillation algorithms utilizing model-generated data consistently outperform other approaches, resulting in approximately 20% additional speedup compared to standard SD on LM1B, XSum, CNN/DM, and GSM8K. However, the gains on the WMT dataset are marginal, as the non-distilled draft model already achieves a high acceptance rate and block efficiency.

Figure 2 compares block efficiency across different algorithms, using temperature sampling ($T = 1$) with block size $\gamma = 7$. The figure showcases the utility of model-generated data: using fixed ground-truth data (i.e., Supervised KD) ranks lowest across all settings except WMT. In contrast, $f$-Distill and GKD, which use purely model-generated data, significantly outperform their counterparts. Moreover, the subpar performance of SeqKD, despite being trained purely on data generated by the target model, implies that white-box distillation (using information from the target model’s logits) is vital for SD. This is corroborated by Figure 3(a), which illustrates the evolution of the acceptance rate throughout training. Supervised KD ranks lowest, and its performance plateaus as training progresses because its dataset is static. In contrast, all other algorithms that employ model-generated data continue to improve. Although $f$-Distill is much more computationally costly than GKD due to its use of teacher-generated data, both exhibit comparable performance. Notably, GKD achieves the best wall-time improvement (see Appendix D.1.2 for more visualizations of performance improvement during training).

We investigate whether KD improves block efficiency universally or mainly impacts a subset of examples. We depict the change in block efficiency per example in Figure 3(b). The results reveal a consistent enhancement in block efficiency across most examples, a pattern that also holds across datasets (see Figure D.12). Figure 3(c) illustrates strong agreement between theoretical and empirical block efficiency values for several distilled models (each model shown as a filled circle). Although the theoretical values occasionally over- or underestimate the empirical values, possibly due to deviations from the i.i.d. token-level assumption (cf. §3), the ranking of distilled models remains highly consistent. In summary, these findings affirm that KD effectively optimizes block efficiency.
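The effect of the i.i.d. assumption can be probed with a small calculation. Below, acceptance probabilities may vary with the position inside a block, a made-up toy model chosen only for illustration. One SD step emits the accepted prefix plus exactly one target token, so its expected output is $1 + \sum_{k=1}^{\gamma} \prod_{i=1}^{k} a_i$, which reduces to $(1 - \alpha^{\gamma+1})/(1 - \alpha)$ when every $a_i = \alpha$.

```python
def expected_tokens_per_step(accept_probs):
    """Expected tokens emitted by one SD step when the i-th drafted token is accepted
    with probability accept_probs[i], independently of the others.

    The step emits the accepted prefix plus exactly one token from the target
    (either the resampled token or the extra token after a fully accepted block).
    """
    total, all_accepted_so_far = 1.0, 1.0
    for a in accept_probs:
        all_accepted_so_far *= a
        total += all_accepted_so_far
    return total


def iid_block_efficiency(alpha, gamma):
    """Closed form (1 - alpha^(gamma+1)) / (1 - alpha) under the i.i.d. assumption."""
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)
```

With $\gamma = 4$ and a mean acceptance probability of 0.7 in both cases, the ordering (0.9, 0.5, 0.9, 0.5) gives 2.9575 expected tokens per step and (0.5, 0.9, 0.5, 0.9) gives 2.3775, while the i.i.d. formula gives 2.7731 for both, so the closed form can err in either direction.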

5.2 DistillSpec recipe

We now focus on identifying the optimal KD approach for SD. Following the training and evaluation protocols in § 5.1, we explore four training data construction methods and four divergence functions on XSum and GSM8K. In particular, we explore the following variants of training data: 1) a fixed ground-truth dataset $\mathcal{D}_{\text{Train}}$, 2) data generated only from the draft $\mathscr{M}_q^\theta$, 3) data generated only from the teacher $\mathscr{M}_p$, and 4) data generated from both $\mathscr{M}_q^\theta$ and $\mathscr{M}_p$ in equal proportion. We further examine the following divergences: 1) forward KL (FKL), 2) Jensen-Shannon divergence (JSD), 3) reverse KL (RKL), and 4) total variation distance (TVD).

Figure 4 illustrates the block efficiency improvement on XSum and GSM8K, in line with observations from § 5.1. Note that using model-generated data (last three rows) yields better performance than using a fixed dataset (first row). Specifically, on XSum with greedy decoding, using data generated from both $\mathscr{M}_q^\theta$ and $\mathscr{M}_p$ gives the best performance, with JSD slightly outperforming the other divergences. However, on GSM8K with greedy decoding, FKL with data generated only by the draft $\mathscr{M}_q^\theta$ emerges as the best KD setup. In contrast, with temperature sampling (at $T = 1$), a different trend is observed, as RKL combined with data generated by $\mathscr{M}_p$ is the most effective. Nonetheless, using only draft-generated data remains competitive. See Appendix D.2.1 for results on different datasets and decoding strategies.

We also study how different distillation approaches affect the draft model’s task performance and whether the same design choices are optimal for improving both the draft model’s task performance and its utility for SD (cf. Figures D.18 and D.19). Similar to our earlier observations, the use of generated data is paramount for improving draft performance. More notably, utilizing data generated from $\mathscr{M}_q^\theta$ yields comparable or superior results compared to using data generated from $\mathscr{M}_p$. However, the choice of the optimal KD algorithm largely depends on the task and the underlying decoding strategy. More interestingly, Figure 5(a) highlights a dichotomy between block efficiency improvements and task performance gains via KD: a draft model with high task performance is not necessarily a powerful drafter for SD. See Appendix D.2.2 for more results on different datasets and decoding strategies.

Interestingly, although TVD is the objective we aim to optimize for SD (cf. Eq. 3), its direct optimization does not yield the best performance in most of the settings explored. We generally find that the choice of divergence in KD is a hyperparameter that needs to be tuned based on the task and decoding strategy used. For training data construction, we propose using the draft $\mathscr{M}_q^\theta$ for data generation, as it can achieve similar or superior performance compared to the teacher $\mathscr{M}_p$ at a much lower cost.

5.3 Quality versus latency trade-off

We analyze the quality-latency trade-off using lossy SD variants, as detailed in Algorithm A.1. Figure 5(b) illustrates that employing either KD ($\star$) or SD ($\times$) alone does not fully bridge the performance or latency gap, respectively. In such cases, a leniency parameter can help interpolate between these two approaches, as demonstrated in Figure 5(b). Interpolating on GSM8K presents challenges, as $f_{\rm lin}$ still results in high performance and high latency when using a lenience of $10^{-5}$, while $f_{\rm sq}$ pushes it further along the trade-off curve. Although $f_{\rm exp}$ can still interpolate, it yields a worse trade-off. Interestingly, it is possible to significantly reduce latency while almost preserving quality on GSM8K, possibly because many tokens have little impact on the final performance, and accepting them is a matter of model preference with minimal effect on generation quality. See Appendix D.3.1 for a comparison between raw and distilled draft models, where we show that the distilled draft achieves a much better trade-off.
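A minimal sketch of how lenience raises acceptance, assuming a lossy rule that accepts a drafted token $y$ with probability $\min(1, f(p(y))/q(y))$ for some transform $f$ with $f(p) \geq p$. The linear form `f_lin_example`, $f(p) = p/\epsilon$, is a hypothetical stand-in and is not claimed to match the paper's $f_{\rm lin}$, $f_{\rm sq}$, or $f_{\rm exp}$; the sketch also ignores how rejected tokens are resampled in lossy SD.

```python
def lenient_acceptance_rate(p, q, f):
    """Expected acceptance sum_y q(y) * min(1, f(p(y)) / q(y)) = sum_y min(q(y), f(p(y))).

    `f` is a leniency transform with f(p) >= p; f(p) = p recovers lossless SD.
    """
    return sum(min(qi, f(pi)) for pi, qi in zip(p, q))


def f_lin_example(eps):
    """Hypothetical linear leniency f(p) = p / eps (smaller eps = more lenient); illustration only."""
    return lambda p: p / eps
```

With $p = (0.6, 0.3, 0.1)$ and $q = (0.2, 0.5, 0.3)$, the expected acceptance is 0.6 for lossless SD, 0.9 for $\epsilon = 0.5$, and 1.0 for $\epsilon = 10^{-5}$. The faster decoding is paid for by sampling away from the target distribution.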

In practical scenarios, we often have access to multiple models of different sizes (a model garden) to design the inference pipeline. We consider this setting by focusing on T5 models of five sizes: T5-Small (77M), T5-Base (250M), T5-Large (800M), T5-XL (3B), and T5-XXL (11B). We study four different quality-latency trade-off curves using KD and SD: 1) Raw: deploying supervised fine-tuned (SFT) T5 models; 2) Distilled: applying KD to T5 models to optimize downstream task performance; 3) Speculative: applying SD to T5 models; and 4) DistillSpec: applying KD to T5 models and using the distilled models as target and draft.

Figure 6 illustrates that SD effectively shifts the trade-off curve leftward, especially for larger models. However, its efficacy diminishes for smaller models, where the forward-pass times of the draft and target models are closely matched. In contrast, distillation, which optimizes the model for downstream task performance, appears to offer a superior trade-off between quality and latency, particularly for smaller models. However, the trend reverses for larger models when they are evaluated with temperature sampling. Figure 25(a) indicates a substantial gap between the distilled model and the larger teacher model, while the SD-based method significantly reduces latency. This suggests that SD remains a valuable approach when stringent constraints on performance and decoding strategy are in place. Notably, our method, DistillSpec, which combines the benefits of distillation and SD, consistently achieves the best trade-off between quality and latency, substantially reducing latency while maintaining nearly identical performance. Specifically, DistillSpec reduces relative latency from 17.3 to 2.7 on XSum and from 15.0 to 1.4 on GSM8K, representing speedups of 6.4× and 10.7×, respectively. Meanwhile, the ROUGE-2 score decreases only marginally on XSum, from 23.1 to 23.0, while model accuracy on GSM8K actually improves, from 33.1 to 34.8.

Conclusion

In this paper, we evaluate the efficacy of white-box knowledge distillation (KD) in enhancing speculative decoding (SD) through improved alignment between target and draft models. We conduct a thorough analysis of the impact of training data construction and divergence functions on KD performance. We underscore the significance of utilizing model-generated data and argue that employing the draft model’s on-policy data during KD is a cost-efficient way to achieve improved alignment. Additionally, we assess the trade-off between quality and latency in terms of lenience and the availability of multiple models of varying quality and size, concluding that KD yields a superior trade-off compared to standard SD. The optimal strategy involves first applying KD for downstream task performance, followed by SD, resulting in a six- to ten-fold decrease in latency with negligible performance loss. Our study contributes novel insights into white-box KD algorithms for LLMs and provides guidance on striking an effective balance between quality and latency using KD and SD.

YZ led the project and conducted all distillation experiments and evaluations for T5 models. YZ also wrote the initial draft of the paper and made all the plots and tables. KL provided a theoretical justification of the proposed method via Theorem 4.1 and revised the paper. ASR and AKM gave high-level guidance on the project and revised the paper. AR and SK gave high-level feedback on the project. JK supervised the project and served as a host to YZ at Google Research, and conducted evaluation of the decoder-only models. RA served as an advisor, provided detailed technical feedback, and revised the paper.

We would like to extend a special thank you to Neha Gupta, Wittawat Jitkrittum, Nino Veillard, Yaniv Leviathan, Matan Kalman, Danny Vainstein, Natan Potikha, Ananda Theertha Suresh, Laz Karydas, Aishwarya PS, Pranav Nair, Praneeth Netrapalli, Nikunj Saunshi, Ziteng Sun, Keiran Paster, Olivier Bachem, Aleksandra Faust for insightful discussion and valuable feedback.

References

Appendix

Appendix B Method

Below are some common divergence functions used in distillation, given two discrete probability distributions $P(\mathcal{C})$ and $Q(\mathcal{C})$.
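The divergences compared in § 5.2 can be written in a few lines of Python, with $P$ playing the role of the teacher and $Q$ of the student. This is a sketch under two assumptions: the JSD uses the equal-weight mixture $(P + Q)/2$ (weighted variants also appear in the distillation literature), and `sequence_divergence` is a hypothetical helper that averages a token-level divergence over the positions of one output sequence.

```python
import math


def kl(p, q):
    """KL(p || q) in nats; infinite if q has zero mass where p does not."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0.0:
            continue
        if qi == 0.0:
            return math.inf
        total += pi * math.log(pi / qi)
    return total


def forward_kl(p, q):
    return kl(p, q)  # KL(P || Q): mass-covering for the student


def reverse_kl(p, q):
    return kl(q, p)  # KL(Q || P): mode-seeking for the student


def jsd(p, q):
    """Jensen-Shannon divergence with the equal-weight mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def tvd(p, q):
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


def sequence_divergence(teacher_dists, student_dists, divergence):
    """Average token-level divergence D(p(. | x, y_<t), q(. | x, y_<t)) over the positions of y."""
    pairs = list(zip(teacher_dists, student_dists))
    return sum(divergence(p, q) for p, q in pairs) / len(pairs)
```

All four functions are zero when $P = Q$; JSD and TVD are symmetric and bounded (by $\ln 2$ and 1, respectively), whereas FKL and RKL are asymmetric and can be infinite.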

B.2 Justification of using on-policy data

First, we decompose α(x)\alpha(x) and ϵ(x)\epsilon(x) into sums of contributions from each token to ease the analysis:

For all xx, α(x)=11Lp(x)t=1TAt\alpha(x)=1-\frac{1}{L_{p}(x)}\sum_{t=1}^{T}A_{t} and ϵ(x)1Tt=1TEt\epsilon(x)\geq\frac{1}{T}\sum_{t=1}^{T}E_{t}, where

Similarly, by the definitions of $\epsilon(x)$ and $D_{\operatorname{TVD}}$, we have

Lemma B.1 motivates us to study $A_{t}(x)$ and $E_{t}(x)$ instead. Below we rewrite them in variational forms that will be used later. To this end, we introduce some definitions.

For any sequence $z\in\{\mathtt{P},\mathtt{Q}\}^{\tau}$ that consists only of letters $\mathtt{P}$ and $\mathtt{Q}$, we define $\mathscr{M}(x,y,z)$ as the distribution of sequences sampled as follows:

If there are $t-1$ tokens, sample the $t$-th token from $\mathscr{M}_{p}$ if $z_{t}=\mathtt{P}$, and from $\mathscr{M}_{q}$ otherwise;

Repeat until an end-of-sequence token is sampled, or the sequence length has reached $\tau$.
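The two sampling steps above can be sketched as follows. This is only an illustration under stated assumptions: toy next-token functions `p_next` and `q_next` stand in for $\mathscr{M}_p$ and $\mathscr{M}_q$, the prompt is held fixed, the token `"<eos>"` is a hypothetical end-of-sequence marker, and $\tau$ is taken to be `len(z)`; the sketch does not model how $\mathscr{M}$ depends on $x$ and $y$.

```python
import random
from typing import Callable, Dict, List, Tuple

EOS = "<eos>"  # hypothetical end-of-sequence token
# A toy "model": maps the current prefix to a next-token distribution.
NextTokenDist = Callable[[Tuple[str, ...]], Dict[str, float]]

def sample_mixed(p_next: NextTokenDist, q_next: NextTokenDist,
                 z: str, rng: random.Random) -> List[str]:
    """Sample one sequence: with t - 1 tokens so far, draw the t-th token
    from p_next if z[t-1] == "P" and from q_next otherwise.
    Stops once EOS is sampled or the length reaches tau = len(z)."""
    seq: List[str] = []
    for letter in z:
        dist = p_next(tuple(seq)) if letter == "P" else q_next(tuple(seq))
        tokens, weights = zip(*dist.items())
        token = rng.choices(tokens, weights=weights, k=1)[0]
        seq.append(token)
        if token == EOS:
            break
    return seq
```

With `z = "P" * tau` (the string $\mathtt{P}^{\tau}$) every token comes from $\mathscr{M}_p$, and with `z = "Q" * tau` every token comes from $\mathscr{M}_q$; mixed strings switch between the two models position by position.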

We use the shorthands $\mathtt{P}^{k}$ and $\mathtt{Q}^{k}$ to denote sequences of $k$ consecutive letters $\mathtt{P}$ and $\mathtt{Q}$, respectively. We use $\Omega^{k}$ to denote the set of all possible strings of length $k$, and $\delta:\Omega^{t}\to[-1/2,1/2]$ to denote a function that maps a sequence of $t$ tokens to a real number in $[-1/2,1/2]$. We abuse the notation and assign $\delta(y)=0$ for all $y\notin\Omega^{t}$.

For a fixed pair of $x$ and $y$, we rewrite the total variation distance between $p(y_{t}\,|\,x,y_{<t})$ and $q(y_{t}\,|\,x,y_{<t})$ as the following variational form:

Then after taking the expectations we have

where the last step follows from our abuse of notation that $\delta(y)=0$ for all $y\notin\Omega^{t}$. This proves Equation B.7; Equation B.8 can be proved similarly. ∎
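The variational form referred to above rests on the standard identity for total variation distance over a finite vocabulary $\Omega$: $D_{\operatorname{TVD}}(p,q)=\max_{\delta:\Omega\to[-1/2,1/2]}\sum_{y\in\Omega}\delta(y)\,\big(p(y)-q(y)\big)$, attained at $\delta^{*}(y)=\tfrac12\operatorname{sign}\big(p(y)-q(y)\big)$. The brute-force check below is our own illustration over a toy single-token vocabulary, not the paper's formulation over prefixes in $\Omega^{t}$; the function names are hypothetical.

```python
import itertools
import numpy as np

def tvd(p, q):
    """Total variation distance (1/2) * sum_y |p(y) - q(y)|."""
    return 0.5 * float(np.abs(np.asarray(p, float) - np.asarray(q, float)).sum())

def tvd_variational(p, q):
    """max over delta: Omega -> [-1/2, 1/2] of sum_y delta(y) * (p(y) - q(y)).

    The objective is linear in delta, so a maximum is attained at a vertex
    of the box [-1/2, 1/2]^|Omega|; we enumerate all 2^|Omega| vertices,
    which is only feasible for a tiny vocabulary."""
    diff = np.asarray(p, float) - np.asarray(q, float)
    return max(float(np.dot(signs, diff))
               for signs in itertools.product((-0.5, 0.5), repeat=len(diff)))

def optimal_delta(p, q):
    """A maximizer: delta*(y) = (1/2) * sign(p(y) - q(y))."""
    return 0.5 * np.sign(np.asarray(p, float) - np.asarray(q, float))
```

Any other $\delta$ with values in $[-1/2,1/2]$ gives a value no larger than the TVD, which is what makes the variational form usable as an upper or lower bound by plugging in a particular $\delta$.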

With these variational forms of $A_{t}(x)$ and $E_{t}(x)$, we obtain the following lemma, which bounds $A_{t}(x)$ in terms of $E_{t}(x)$.

Now, taking a telescoping sum over $1\leq k\leq t$, we obtain

Taking a telescoping sum over $1\leq k\leq t-1$ gives

Subtracting Equation B.11 from Equation B.10, we find that the following holds for all functions $\delta:\mathscr{S}\to[-1/2,1/2]$:

Summing Equation B.9 over $1\leq t\leq T$, we have

which proves the theorem statement after taking the expectation over $x\sim X$. ∎
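The token-level fact underlying acceptance-rate analyses of this kind is the standard speculative sampling result (Leviathan et al., 2023; Chen et al., 2023): a draft token $x\sim q$ is accepted with probability $\sum_{x}\min\big(p(x),q(x)\big)=1-D_{\operatorname{TVD}}(p,q)$, and the accept-or-resample rule leaves the output distributed exactly as $p$. The single-token simulation below, with toy distributions and function names of our own choosing, illustrates both facts; it is not part of the proof above.

```python
import numpy as np

def acceptance_prob(p, q):
    """P(draft token x ~ q is accepted) = sum_x q(x) * min(1, p(x) / q(x))
    = sum_x min(p(x), q(x)) = 1 - D_TVD(p, q)."""
    return float(np.minimum(p, q).sum())

def speculative_step(p, q, rng):
    """One token of speculative sampling: propose x from the draft q, accept
    with probability min(1, p(x) / q(x)), otherwise resample from the
    normalized residual max(p - q, 0). The returned token is distributed as p."""
    x = rng.choice(len(q), p=q)
    if rng.random() < min(1.0, p[x] / q[x]):
        return x, True
    residual = np.maximum(p - q, 0.0)
    return rng.choice(len(p), p=residual / residual.sum()), False

def simulate(p, q, n=50_000, seed=0):
    """Monte Carlo estimates of the acceptance rate and the output distribution."""
    rng = np.random.default_rng(seed)
    counts = np.zeros(len(p))
    accepted = 0
    for _ in range(n):
        x, ok = speculative_step(p, q, rng)
        counts[x] += 1
        accepted += ok
    return accepted / n, counts / n
```

For `p = [0.6, 0.3, 0.1]` and `q = [0.3, 0.3, 0.4]`, the acceptance probability is $0.3+0.3+0.1=0.7=1-0.3$, and the simulated output frequencies approach `p` rather than `q`.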

Appendix C Implementation Details

In this section, we present a comprehensive overview of the datasets employed in this study.

The Extreme Summarization (XSum) dataset is an evaluation benchmark for abstractive single-document summarization. It comprises 226,711 BBC news articles from 2010 to 2017, covering a wide range of domains, including News, Politics, Sports, Weather, Business, Technology, Science, Health, Family, Education, Entertainment, and Arts. We evaluate summarization performance with ROUGE scores on the XSum validation split, focusing primarily on ROUGE-2; ROUGE-LSum and ROUGE-1 show similar trends. We use a maximum input length of 1024 tokens and a maximum output length of 64 tokens for both distillation and evaluation.

The CNN/Daily Mail (CNN/DM) dataset is tailored for text summarization. It pairs news stories from the CNN and Daily Mail websites with human-written abstractive summary bullets; in the dataset's original question-answering form, these bullets were turned into questions with entities hidden, to be answered from the corresponding source passages. As with XSum, we report ROUGE scores on the validation split, focusing primarily on ROUGE-2, with similar trends in ROUGE-LSum and ROUGE-1. We use a maximum input length of 2048 tokens and a maximum output length of 128 tokens for distillation and evaluation.

The WMT14 EnDe dataset is a standard benchmark for machine translation. The task is to translate English text into German while preserving content, semantics, and style. Evaluation uses the BLEU score, which measures the similarity of machine-translated text to high-quality reference translations. We use a maximum input length of 80 tokens and a maximum output length of 80 tokens for distillation and evaluation, and assess performance on the original test split.

The GSM8K dataset comprises 8.5K high-quality, linguistically diverse grade school math word problems crafted by human problem writers. The dataset is divided into 7.5K training problems and 1K test problems, with solutions typically requiring 2 to 8 steps involving elementary calculations using basic arithmetic operations. To enhance reasoning abilities, we explored distillation alongside the zero-shot chain-of-thought (CoT) method, as described in Agarwal et al. (2023). A maximum input length of 256 and a maximum output length of 320 are used for distillation and evaluation.

The One Billion Word dataset (LM1B) is a widely recognized benchmark for language modeling. The training and held-out data are derived from the WMT 2011 News Crawl dataset, created using Bash shell and Perl scripts. A maximum input length of 128 and a maximum output length of 128 are used for distillation and evaluation.
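The sequence-length limits stated in the dataset descriptions above can be collected into one configuration. This is a sketch, not the authors' code: the dictionary keys, the `LengthLimits` type and the `truncate_input` helper are hypothetical, and the limits are assumed to be counted in tokens.

```python
from typing import Dict, List, NamedTuple

class LengthLimits(NamedTuple):
    max_input_len: int   # assumed to be in tokens
    max_output_len: int  # assumed to be in tokens

# Values as stated above; keys and layout are illustrative.
LENGTH_LIMITS: Dict[str, LengthLimits] = {
    "xsum": LengthLimits(1024, 64),
    "cnn_dm": LengthLimits(2048, 128),
    "wmt14_en_de": LengthLimits(80, 80),
    "gsm8k": LengthLimits(256, 320),
    "lm1b": LengthLimits(128, 128),
}

def truncate_input(task: str, token_ids: List[int]) -> List[int]:
    """Hypothetical helper: keep only the first max_input_len input tokens."""
    return token_ids[: LENGTH_LIMITS[task].max_input_len]
```

Keeping the limits in one table makes it easy to apply the same values during both distillation and evaluation, as the descriptions above require.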

C.2 Models

Following Leviathan et al. (2023), we evaluate two model types: 1) GPT-like decoder-only Transformer models trained with the standard autoregressive objective on the LM1B task (Chelba et al., 2013), where the target and draft models have 234M and 33M parameters, respectively; and 2) standard encoder-decoder T5 v1.1 models (Raffel et al., 2020) supervised fine-tuned on four different tasks, with T5-XL (3B) and T5-Small (77M) serving as the target and draft models, respectively.

The target model $\mathscr{M}_{p}$ in the decoder-only experiment has model dimension 1024, feed-forward dimension 4096, 12 layers, and 16 attention heads. The corresponding draft model $\mathscr{M}_{q}^{\theta}$ has 33M parameters: model dimension 512, feed-forward dimension 1024, 4 layers, and 4 attention heads. All models use the T5 tokenizer with a 32k-token vocabulary. For the T5 experiments, we start from LM-adapted T5v1.1 checkpoints. These LM-adapted models are initialized from T5v1.1 and trained for an additional 100K steps on the LM objective discussed in the T5 paper (Raffel et al., 2020). These checkpoints are open-sourced at https://console.cloud.google.com/storage/browser/t5-data/pretrained_models.

In our experiments, both the student and teacher models for distillation are initialized from checkpoints obtained by supervised fine-tuning on the original training dataset, as detailed below:

XSum. For small, base, large, XL and XXL models, we use LM-adapted T5v1.1 models supervised fine-tuned for 100K, 50K, 30K, 20K and 8K steps, respectively.

CNN/DM. For small, base, large, XL and XXL models, we use LM-adapted T5v1.1 models supervised fine-tuned for 200K, 80K, 20K, 20K and 4K steps, respectively.

WMT. For small, base, large, XL and XXL models, we use LM-adapted T5v1.1 models supervised fine-tuned for 250K, 250K, 110K, 50K and 10K steps, respectively.

GSM8K. All models were supervised fine-tuned for 10K steps, starting from FLAN-T5 models, on the CoT dataset generated by PaLM-540B.

C.3 Distillation

We use the Adafactor optimizer (Shazeer & Stern, 2018) to train the draft student model $\mathscr{M}_{q}^{\theta}$ in all experiments, following Algorithm A.2. In the KD loss defined in Eq. 2, we fix both the target model temperature $T_{p}$ and the draft model temperature $T_{q}$ to 1.0. Keeping this uniform temperature matters for speculative decoding: it ensures a consistent interpretation of the output distributions of both $\mathscr{M}_{p}$ and $\mathscr{M}_{q}^{\theta}$. Table C.1 summarizes the hyperparameters used in our distillation process.
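Eq. 2 is not reproduced in this section, so the following is only a sketch of a temperature-scaled, token-level distillation loss under stated assumptions: forward KL as the divergence (one of the options listed in Appendix B), NumPy arrays of per-position logits, and hypothetical function names.

```python
import numpy as np

def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of logits / temperature over the last axis."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def token_level_fkl(target_logits, draft_logits, t_p=1.0, t_q=1.0):
    """Mean over positions of D_KL(p_{T_p} || q_{T_q}) (forward KL).

    target_logits, draft_logits: arrays of shape [seq_len, vocab_size] with
    the target (teacher) and draft (student) logits at the same positions."""
    log_p = log_softmax(target_logits, t_p)
    log_q = log_softmax(draft_logits, t_q)
    per_token = np.sum(np.exp(log_p) * (log_p - log_q), axis=-1)
    return float(per_token.mean())
```

With $T_p=T_q=1.0$, identical logits give zero loss, whereas identical logits scored at mismatched temperatures do not. Minimizing this loss drives $q_{T_q}$ toward $p_{T_p}$, so a temperature mismatch would push the draft model toward a rescaled version of the target distribution rather than the distribution the target uses at verification time.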

C.4 Evaluation

To obtain task scores (ROUGE-2 for XSum and CNN/DM, BLEU for WMT, and accuracy for GSM8K), we follow the evaluation methodology of Agarwal et al. (2023), evaluating all examples in the test or validation set and reporting the average. Likewise, the empirical speculative decoding metrics, namely the empirical acceptance rate and the empirical block efficiency, are computed over all instances in the test or validation set and averaged. To measure actual latency, we follow the procedure of Leviathan et al. (2023): the target and draft models run on the same device, a TPUv4, without model parallelism. We randomly sample 500 examples from the test or validation set and measure the decoding time with a batch size of 1. This measurement is repeated three times and the mean is reported; we observed minimal variance across different random seeds.
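The speculative decoding metrics and the latency protocol above can be sketched as follows. The sketch assumes a log format of our own: the number of accepted draft tokens per target-model call, with a fixed number `gamma` of draft tokens per block. Block efficiency is taken as the average number of tokens emitted per target-model call and can be compared with the closed form $(1-\alpha^{\gamma+1})/(1-\alpha)$ of Leviathan et al. (2023), which assumes independent per-token acceptance. The acceptance-rate estimator, the timing harness (per-example wall-clock time averaged over repeats) and all function names are illustrative choices, not necessarily the exact definitions behind the reported numbers.

```python
import random
import statistics
import time
from typing import Callable, Sequence, Tuple

def block_metrics(accepted_per_block: Sequence[int], gamma: int) -> Tuple[float, float]:
    """Empirical block efficiency and an acceptance-rate estimate.

    accepted_per_block[i] is how many of the gamma draft tokens were accepted
    in block i (one target-model call); each block also emits one token from
    the target model, so it yields accepted + 1 tokens. Assumes at least one
    block and gamma >= 1."""
    n_blocks = len(accepted_per_block)
    accepted = sum(accepted_per_block)
    block_efficiency = (accepted + n_blocks) / n_blocks
    # Drafts after the first rejection in a block are never checked, so we
    # divide by the number of drafts actually judged (accepted + rejections).
    rejections = sum(1 for a in accepted_per_block if a < gamma)
    acceptance_rate = accepted / (accepted + rejections)
    return block_efficiency, acceptance_rate

def expected_block_efficiency(alpha: float, gamma: int) -> float:
    """(1 - alpha**(gamma + 1)) / (1 - alpha): expected tokens per block if each
    draft token is accepted independently with probability alpha."""
    if alpha == 1.0:
        return float(gamma + 1)
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)

def mean_latency(decode_fn: Callable[[object], object], examples: Sequence[object],
                 n_samples: int = 500, n_repeats: int = 3, seed: int = 0) -> float:
    """Mean per-example decoding time in seconds over n_repeats passes on a
    random subset of examples, decoding one example per call (batch size 1)."""
    subset = random.Random(seed).sample(list(examples), k=min(n_samples, len(examples)))
    per_pass = []
    for _ in range(n_repeats):
        start = time.perf_counter()
        for example in subset:
            decode_fn(example)
        per_pass.append((time.perf_counter() - start) / len(subset))
    return statistics.mean(per_pass)
```

For example, blocks with `[3, 1, 0, 3]` accepted drafts at `gamma = 3` emit 11 tokens in 4 target calls, a block efficiency of 2.75.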

Appendix D Additional Results

D.2 Distillation recipe

D.3 Quality versus latency trade-off