DistillSpec: Improving Speculative Decoding via Knowledge Distillation
Yongchao Zhou, Kaifeng Lyu, Ankit Singh Rawat, Aditya Krishna Menon, Afshin Rostamizadeh, Sanjiv Kumar, Jean-François Kagy, Rishabh Agarwal
Introduction
Large language models (LLMs) have revolutionized natural language understanding and generation across diverse applications (OpenAI, 2023; Anil et al., 2023). However, their autoregressive nature poses significant computational challenges, especially in real-time deployments with stringent latency constraints (Thoppilan et al., 2022; Pope et al., 2023). Conversely, smaller language models, while computationally efficient, often lack the expressive power of their larger counterparts and achieve subpar performance. While reducing the inference cost of larger models, e.g., via quantization or pruning, or improving the performance of the smaller models, e.g., via knowledge distillation (KD) (Hinton et al., 2015), constitute natural approaches to enable a favorable performance versus inference cost trade-off, these approaches frequently result in an unacceptable performance gap compared to the high-quality large models. This has inspired a growing literature on designing mechanisms that combine both large and small models at inference to approximate the performance of the larger models without incurring their high computational cost.
Conventionally, model cascading approaches aim to identify easy instances where smaller models suffice to achieve good performance, thereby soliciting larger models only on a subset of hard instances (Rowley et al., 1998; Xu et al., 2014) or tasks (Cai et al., 2023b). Different from such task- or instance-level cascading, speculative decoding (SD) (Leviathan et al., 2023; Chen et al., 2023) aims to exploit the token-level variability in the computation demand during LLM inference by interactively invoking a small “draft” model and a large “target” model. At a given stage during inference, the draft model generates successive candidate tokens for multiple inference steps via autoregressive decoding. The target model then verifies the candidate tokens via parallel decoding, and employs rejection sampling to accept a subset of candidate tokens at contiguous positions.
The main objective of SD is to speed up text generation while guaranteeing that the decoded tokens follow the target model distribution. SD relies on the insight that the combined cost of autoregressive decoding with a small draft model followed by parallel decoding with the target model is lower than the cost of autoregressive decoding with the target model alone. However, the realized inference cost reduction or latency improvement crucially depends on the acceptance rate of the draft-generated tokens by the target model, which can be shown to be directly tied to the alignment between the token distributions of the draft and target models. Thus, a successful application of SD hinges on identifying a compact draft model that simultaneously has small autoregressive decoding cost and is closely aligned with the target model.
In this work, we propose DistillSpec, a novel approach that relies on KD (Hinton et al., 2015) to obtain an effective draft model. Unlike the standard application of KD which primarily focuses on improving the task performance of a small student model, DistillSpec aims at aligning the student (draft) model with the teacher (target) model to enhance the acceptance rate during speculative decoding. This requires the student model to closely approximate the teacher distribution at the token and sequence level, even if it translates to suboptimal downstream task performance.
We undertake a comprehensive exploration of the distillation process for speeding up SD, considering several factors including the composition of training data, choice of divergence functions to define the training objective for KD, and decoding strategies. Notably, our findings underscore that using model-generated data is crucial for ensuring strong student-teacher alignment across various tasks via KD, and that the selection of the best-performing divergence function in DistillSpec is highly task-dependent and sensitive to the decoding strategy (i.e., greedy versus non-greedy). Furthermore, we explore the utility of DistillSpec for lossy SD (Leviathan et al., 2023) which allows for sampling away from the target model distribution. We show that combining DistillSpec with lossy SD enables a more fine-grained control over the latency versus task performance trade-off.
Finally, we carry out a systematic study of how to design an efficient inference scheme in a practical setting where one has access to multiple language models of increasing size and quality. Leveraging the insights that we have laid out in this paper about KD and SD, our study concludes that the most effective strategy involves first distilling a large model into a smaller one as the potential target model for performance optimization, followed by DistillSpec for distilling an even smaller model to be used as the draft model in SD. Compared to a standalone non-distilled target of the same size, this approach reduces latency six- to ten-fold with minimal performance degradation.
We propose DistillSpec, which uses KD to enhance draft model alignment with the target model (§4), and show that our method can improve SD speed by 10–45% while preserving model performance across four diverse datasets with both greedy and non-greedy sampling (Figure 1).
We conduct an extensive analysis of the optimal distillation recipe (§5.2) for model alignment, encompassing factors such as training data generation and different divergences, and emphasizing the distinctions between standard KD and distillation tailored for SD.
We extend DistillSpec to lossy SD, enabling refined control over the quality-latency trade-off. Moreover, we offer insights for combining KD and SD when several models are available (§5.3).
Related Work
Due to the inherent sequential nature of autoregressive decoding, the primary latency bottleneck in LLM inference arises from memory read/write operations rather than arithmetic computations (Pope et al., 2023). Speculative decoding (SD) (Leviathan et al., 2023; Chen et al., 2023) addresses this challenge by utilizing a compact draft model to generate a block of tokens sequentially, while validating them in parallel with a larger target model. Prior to SD, various parallel computing paradigms have been explored for autoregressive models, including block parallel sampling (Stern et al., 2018), shallow aggressive decoding (Sun et al., 2021), and aggressive decoding (Ge et al., 2022). However, these approaches are not readily adaptable to typical language models due to potential deviations from the target model’s distribution, strict input constraints, or limited support for general stochastic sampling. Notably, recent variants of SD have also incorporated parallel computation along the batch axis, sometimes combined with token tree verification, as seen in SpecTr (Sun et al., 2023), SpecInfer (Miao et al., 2023), and Medusa (Cai et al., 2023a). In contrast, our work focuses on enhancing SD by improving the alignment between the small draft model and the large target model through KD, which does not require any changes to serving infrastructures already implementing SD and is complementary to the recent variants of SD. Furthermore, we conduct a systematic study of lossy SD for providing nuanced control over the trade-off between quality and latency for specific serving models.
KD (Buciluǎ et al., 2006; Hinton et al., 2015), which trains high-quality smaller student models with the help of larger teacher models, has emerged as a vital technique for reducing the inference cost while maintaining performance quality across a range of domains. In the context of LLMs, prior uses of KD (Taori et al., 2023; Fu et al., 2023) have mostly focused on black-box KD, wherein only the teacher’s output generations, often via APIs, are accessible during student training. However, with the proliferation of open-source LLMs (Zhang et al., 2022; Touvron et al., 2023), there is a growing interest in white-box KD, where we have access to teacher weights and logits. White-box KD allows student models to benefit from richer supervision signals provided by white-box teacher models, leading to enhanced language abilities (Agarwal et al., 2023; Gu et al., 2023; Wen et al., 2023). Despite notable improvements in student quality, substantial performance gaps persist between large and small models (OpenAI, 2023; Anil et al., 2023), which may remain unbridgeable through distillation alone.
Unlike prior works focused on creating highly capable standalone student models, we harness KD to foster closer collaboration between smaller and larger models in SD, which may be particularly valuable when a small distilled model alone cannot meet stringent quality requirements. While Stern et al. (2018) use a black-box KD approach (SeqKD) to enhance blockwise parallel decoding, they use samples generated from the large target model, which is prohibitively expensive for LLMs. Furthermore, they ignore the teacher model’s logits and train their draft model using only one-hot teacher labels – a reasonable choice for greedy decoding but a less effective one for non-greedy sampling (Figure 2). Relatedly, Rawat et al. (2021) leverage KD to improve model cascading. However, unlike our work, which focuses on text generation with LLMs, their study focuses on classification tasks in the vision and NLP domains.
Background: Speculative Decoding
Given an input sequence $x$ comprising tokens from a pre-defined vocabulary, a language model $M$ provides a distribution over possible output sequences $y$. Suppose we employ SD with a compact draft model $M_q$ to assist a larger target model $M_p$. Let $p(y_t \mid x, y_{<t})$ and $q(y_t \mid x, y_{<t})$ represent the distributions governing next-token predictions at time step $t$ for $M_p$ and $M_q$, respectively, given the context $\rho = \{x, y_{<t}\}$. Given input $x$ as prefix, let $p_{\le T}(y \mid x)$ and $q_{\le T}(y \mid x)$ represent the distributions governing the sequence $y$ sampled autoregressively from $M_p$ and $M_q$, respectively, where the generation stops either when an end-of-sequence token is sampled or when the maximum sequence length $T$ is reached. For simplicity, we use $p(y_t)$ and $q(y_t)$ as shorthands for $p(y_t \mid x, y_{<t})$ and $q(y_t \mid x, y_{<t})$ whenever the context $\rho$ is clear. Similarly, $p_{\le T}(y)$ and $q_{\le T}(y)$ serve as shorthands for $p_{\le T}(y \mid x)$ and $q_{\le T}(y \mid x)$ whenever the input $x$ is clear.
Standard SD uses a procedure called speculative sampling to generate tokens from the draft model while maintaining the same output distribution as the target model. As detailed in Algorithm A.1 (Appendix), each step of SD works as follows. First, a block of $\gamma$ tokens, denoted as $y_{t+1}, \ldots, y_{t+\gamma}$, is autoregressively sampled from $q$. Next, the tokens are verified in parallel by passing them to $M_p$ as a whole block, which sequentially accepts token $y_{t+i}$ with probability $\min\left(1, p(y_{t+i})/q(y_{t+i})\right)$. If any token $y_{t+i}$ is rejected before the end of the block, the subsequent tokens are discarded and the rejected token is resampled from the adjusted distribution $p'(y_{t+i}) \propto \max\left(0, p(y_{t+i}) - q(y_{t+i})\right)$; otherwise, all tokens in the block are accepted and an extra token is sampled from $p(y_{t+\gamma+1})$ and appended to the output sequence. This process guarantees that the sequence of accepted and resampled tokens follows the same output distribution as $p_{\le T}(y \mid x)$ (Leviathan et al., 2023). The procedure is repeated until an end-of-sequence token is accepted or the maximum sequence length $T$ has been reached.
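A minimal NumPy sketch of a single speculative-sampling step, written for illustration rather than taken from Algorithm A.1. It assumes the target distributions for all $\gamma + 1$ positions have already been computed in one parallel pass, and that the draft tokens and the draft distributions they were sampled from are given; the function and argument names are hypothetical.

```python
import numpy as np


def speculative_step(p_probs, q_probs, draft_tokens, rng):
    """Run one (hypothetical) speculative-sampling step; return emitted token ids.

    p_probs:      (gamma + 1, V) target next-token distributions, one row per
                  draft position plus one extra row for the bonus token.
    q_probs:      (gamma, V) draft distributions the draft tokens came from.
    draft_tokens: gamma token ids sampled autoregressively from q_probs.
    """
    emitted = []
    for i, tok in enumerate(draft_tokens):
        p, q = p_probs[i], q_probs[i]
        # Accept the draft token with probability min(1, p(tok) / q(tok)).
        if rng.random() < min(1.0, p[tok] / q[tok]):
            emitted.append(int(tok))
            continue
        # Rejected: resample from the normalized residual max(0, p - q) and
        # discard the remaining draft tokens of this block.
        residual = np.maximum(p - q, 0.0)
        emitted.append(int(rng.choice(len(p), p=residual / residual.sum())))
        return emitted
    # Whole block accepted: sample one extra token from the target.
    emitted.append(int(rng.choice(p_probs.shape[1], p=p_probs[-1])))
    return emitted


rng = np.random.default_rng(0)
p_probs = np.array([[0.6, 0.3, 0.1], [1 / 3, 1 / 3, 1 / 3]])  # made-up target rows
q_probs = np.array([[0.2, 0.2, 0.6]])                          # made-up draft row
draft = [rng.choice(3, p=q_probs[0])]
print(speculative_step(p_probs, q_probs, draft, rng))
```

Because a rejected position is resampled from the normalized residual, the first emitted token follows the target distribution at that position regardless of how poorly the draft matches it; the draft only affects how many tokens each step emits.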
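Each SD step takes a constant amount of time, so the wall-clock time scales linearly with the number of steps. This number equals the total number of times that the target model rejects a token, plus the number of blocks accepted as a whole, where the latter term is small for large $\gamma$. This motivates us to use the acceptance rate as a surrogate efficiency measure for the wall-clock time. For an ideal SD process with $\gamma = \infty$, we define the sequence-level acceptance rate $\alpha(x)$ for a given input $x$ as follows:
$$\alpha(x) := \frac{\mathbb{E}_{y \sim p_{\le T}(y \mid x)}\left[\sum_{t=1}^{|y|} \beta(x, y_{<t})\right]}{L_p(x)}, \qquad (1)$$
where $\beta(x, y_{<t}) := \mathbb{E}_{y_t \sim q(y_t)}\left[\min\left(1, p(y_t)/q(y_t)\right)\right]$ is the token-level acceptance rate and $L_p(x) := \mathbb{E}_{y \sim p_{\le T}(y \mid x)}\left[|y|\right]$ is the expected length of the target output, which does not depend on the draft model because speculative sampling preserves the target distribution.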
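The sketch below estimates the sequence-level acceptance rate from a few sequences sampled from the target, assuming the target and draft next-token distributions along each sequence are available as arrays. It relies on $\mathbb{E}_{y_t \sim q}[\min(1, p(y_t)/q(y_t))] = \sum_{y_t} \min(p(y_t), q(y_t))$. The function names and the toy distributions are illustrative only.

```python
import numpy as np


def token_acceptance(p, q):
    """beta = E_{y ~ q}[min(1, p(y) / q(y))] = sum_y min(p(y), q(y))."""
    return float(np.minimum(p, q).sum())


def sequence_acceptance_rate(sequences):
    """Ratio estimate of alpha(x) from sequences y sampled from the target.

    Each element of `sequences` is a list of (p_t, q_t) pairs holding the
    target and draft next-token distributions at each position of one y.
    """
    accepted = sum(token_acceptance(p, q) for seq in sequences for p, q in seq)
    # Total length; the ratio of sums equals the ratio of sample means.
    length = sum(len(seq) for seq in sequences)
    return accepted / length


same = np.array([0.5, 0.5])
seqs = [
    [(np.array([0.5, 0.5]), np.array([1.0, 0.0])), (same, same)],
    [(same, same)],
]
print(sequence_acceptance_rate(seqs))
```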
Walltime improvement. For a given block efficiency $\tau(x)$, i.e., the expected number of tokens generated per SD step, the expected speedup of SD is given by $\tau(x)/(c\gamma + 1)$, where the relative latency $c$ is the ratio between the times for making a single forward pass through the draft model $M_q$ and the target model $M_p$.
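Assuming, as in Leviathan et al. (2023), that each draft token is accepted independently with rate $\alpha$, the block efficiency has the closed form $\tau = (1 - \alpha^{\gamma+1})/(1 - \alpha)$. A small sketch combining it with the speedup formula; the numbers in the example call are hypothetical.

```python
def block_efficiency(alpha: float, gamma: int) -> float:
    """Expected tokens per SD step if each draft token is accepted i.i.d. with prob. alpha."""
    if alpha >= 1.0:
        return gamma + 1.0  # every block fully accepted, plus the bonus token
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    """Expected wall-time speedup tau / (c * gamma + 1); c = draft/target forward-pass time."""
    return block_efficiency(alpha, gamma) / (c * gamma + 1.0)


# Hypothetical values: alpha = 0.8, block size gamma = 5, relative latency c = 0.05.
print(block_efficiency(0.8, 5), expected_speedup(0.8, 5, 0.05))
```

The denominator grows linearly in $\gamma$ while $\tau$ saturates at $1/(1-\alpha)$, so increasing the block size eventually stops paying off; raising $\alpha$ through better alignment raises that ceiling.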
DistillSpec: Knowledge Distillation for Speculative Decoding
As described in § 3, speculative decoding (SD) can reduce the latency of the larger (target) model with the help of a smaller (draft) model without any performance drop. However, the realized latency reduction via SD critically depends on how “well-aligned” the draft model is to the target model. Setting performance improvement of SD as a primary objective, our proposed DistillSpec method enables a closer alignment between the draft and target models by leveraging KD. We first present KD-based training of the draft model. We highlight how our objective of enhancing SD via KD influences our selection of training data generation method and divergence function – two key ingredients of DistillSpec. We then discuss how DistillSpec can be extended to lossy SD.
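Let the draft model $M_q^{\theta}$ be parameterized by $\theta$. DistillSpec utilizes predictions from the target model $M_p$ as a source of supervision (teacher) while training the draft model (student). We assume white-box access to both models, i.e., we can obtain their token-level distributions $p(y_t)$ and $q(y_t)$, and therefore we are able to generate samples from both the target and the draft models. Given a divergence function $D$ that measures the misalignment between two distributions, KD-based training of the draft model seeks to minimize the divergence between the teacher (target) and student (draft) distributions over a training set $\mathcal{G}$:
$$\theta^{*} := \arg\min_{\theta}\; \mathbb{E}_{(x, y) \sim \mathcal{G}}\left[D(M_p \,\|\, M_q^{\theta})(y \mid x)\right], \quad \text{where} \quad D(M_p \,\|\, M_q)(y \mid x) := \frac{1}{|y|} \sum_{t=1}^{|y|} D\big(p(\cdot \mid x, y_{<t}) \,\|\, q(\cdot \mid x, y_{<t})\big). \qquad (2)$$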
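A minimal sketch of evaluating this distillation objective on a batch, assuming the teacher and student next-token distributions at every position of each training sequence are precomputed, and assuming a per-token average along each sequence. TVD is used as the default divergence only for illustration; the helper names are hypothetical, and a real trainer would compute the student distributions inside an autodiff framework to obtain gradients with respect to $\theta$.

```python
import numpy as np


def tvd(p, q):
    """Total variation distance between two discrete distributions."""
    return 0.5 * float(np.abs(p - q).sum())


def sequence_divergence(teacher_dists, student_dists, divergence=tvd):
    """Per-token average of D(p(.|x, y_<t) || q(.|x, y_<t)) along one sequence."""
    values = [divergence(p, q) for p, q in zip(teacher_dists, student_dists)]
    return sum(values) / len(values)


def kd_objective(batch, divergence=tvd):
    """Empirical KD loss: mean sequence-level divergence over (teacher, student) pairs."""
    return float(np.mean([sequence_divergence(t, s, divergence) for t, s in batch]))


# One training sequence of length 2: the models disagree completely at t = 1
# and agree at t = 2.
teacher = [np.array([1.0, 0.0]), np.array([0.3, 0.7])]
student = [np.array([0.0, 1.0]), np.array([0.3, 0.7])]
print(kd_objective([(teacher, student)]))
```

Swapping the `divergence` argument is all it takes to compare objectives, which mirrors the divergence ablation studied later.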
Our choices for $D$ and $\mathcal{G}$ are guided by how the resulting distilled model, once employed as the draft model, improves the speed of SD. Towards this, we first highlight the role that the total variation distance (TVD) between $p$ and $q$ plays in dictating the acceptance rate (§ 3), a key efficiency measure for SD.
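Leviathan et al. (2023, Corollary 3.6) show that the token-level acceptance rate satisfies $\beta(x, y_{<t}) = 1 - D_{\mathrm{TVD}}\big(p(y_t), q(y_t)\big)$. Hence, Eq. 1 implies that maximizing the sequence-level acceptance rate is equivalent to minimizing the expected TVD between $p$ and $q$ over the output sequence distribution of $M_p$, i.e., with $L_p(x)$ the expected length of the target output,
$$\alpha(x) = 1 - \frac{\mathbb{E}_{y \sim p_{\le T}(y \mid x)}\left[\sum_{t=1}^{|y|} D_{\mathrm{TVD}}\big(p(y_t), q(y_t)\big)\right]}{L_p(x)}. \qquad (3)$$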
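The identity behind Eq. 3 follows from $\min(a, b) = (a + b - |a - b|)/2$: summing over the vocabulary and using $\sum_y p(y) = \sum_y q(y) = 1$ gives $\sum_y \min(p(y), q(y)) = 1 - \tfrac{1}{2}\sum_y |p(y) - q(y)|$. A quick numerical check on random distributions (the vocabulary size and sample count are arbitrary):

```python
import numpy as np


def beta_from_min(p, q):
    """Token-level acceptance rate of speculative sampling: sum_y min(p, q)."""
    return float(np.minimum(p, q).sum())


def beta_from_tvd(p, q):
    """1 - D_TVD(p, q), the right-hand side of the identity."""
    return 1.0 - 0.5 * float(np.abs(p - q).sum())


rng = np.random.default_rng(0)
max_gap = 0.0
for _ in range(1000):
    p = rng.dirichlet(np.ones(10))  # random "target" distribution
    q = rng.dirichlet(np.ones(10))  # random "draft" distribution
    max_gap = max(max_gap, abs(beta_from_min(p, q) - beta_from_tvd(p, q)))
print(f"largest discrepancy: {max_gap:.1e}")  # floating-point noise only
```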
Choice of divergence. Based on Eq. 3, it appears that directly minimizing $D_{\mathrm{TVD}}$ may be a principled objective for draft model distillation. While optimizing TVD is theoretically inspired, our empirical study shows that such an objective may not consistently yield optimal results. We find that the choice of the most suitable divergence is highly task-dependent (§ 5.2).
As for $\mathcal{G}$, Eq. 3 suggests optimizing the expected TVD over outputs generated from the teacher. However, decoding from a large teacher is generally prohibitively expensive, especially at the scale of data required for KD. Alternatively, one could resort to an existing ground-truth dataset; however, the teacher’s output distribution may deviate from the ground-truth distribution even though the teacher has been fine-tuned on it. Moreover, ground-truth datasets are often limited in size, so training only on such data could result in overfitting. To resolve these issues, we explore using on-policy data during distillation, i.e., output sequences sampled from the student itself. Besides being more computationally efficient than teacher generations, this approach is inspired by Gu et al. (2023) and Agarwal et al. (2023), who show that distilling on on-policy data improves student task performance. However, different from these prior works, our primary focus is on improving the student-teacher alignment. Thus, it is not immediately clear whether minimizing the expected TVD over on-policy (student-generated) data ensures an improved acceptance rate, which is computed as an expectation over the teacher’s output distribution (cf. Eq. 3). Our following result shows that this is indeed the case.
We defer the proof to Section B.2. Intuitively, it builds upon the following insights. If the on-policy KD loss is small, then, for any $t$, the same loss evaluated only at the $t$-th token should also be small. Since the first token generation is independent of any other tokens, a small value of the on-policy KD loss ensures that the first-token distributions of the draft and target models are close. Then, an inductive argument shows that once the draft and target are similar on the first $t$ tokens, the distributions of the $(t+1)$-th token should also be close. Our proof makes this argument rigorous by utilizing variational representations of TVD, leading to a linear error bound in $T$.
Experiments
We evaluate the effectiveness of KD in improving the speed of speculative decoding (SD). We specifically investigate its impact on the enhancement of acceptance rate, block efficiency, and latency.
Following Leviathan et al. (2023), we evaluate two model types: 1) GPT-like decoder-only Transformer models trained using the standard autoregressive objective on the LM1B task (Chelba et al., 2013), where the target and draft models have 234M and 33M parameters, respectively; and 2) standard encoder-decoder T5 v1.1 models (Raffel et al., 2020) supervised fine-tuned on four different tasks, with T5-XL (3B) and T5-Small (77M) serving as the target and draft models, respectively. As for the four datasets, we utilize two datasets from the T5 paper, namely WMT EnDe (Bojar et al., 2014) and CNN/DM (Hermann et al., 2015), which deal with translation and text summarization, respectively. The remaining two tasks used to test T5 models are XSum (Narayan et al., 2018) and GSM8K (Cobbe et al., 2021), which deal with abstractive summarization and arithmetic reasoning, respectively. See Appendix C for more details.
We study five KD algorithms outlined in Table 1. For SeqKD (Kim & Rush, 2016) and f-Distill (Wen et al., 2023), we opt for online data generation from the teacher instead of a conventional fixed, offline teacher-generated dataset. This approach, while computationally expensive, yields a more diverse dataset. For GKD, we rely solely on data generated online by the student model, excluding the static ground-truth data. All data generated by either the teacher or the student is obtained via temperature sampling with a temperature of 1.0 (see Appendix D.1.3 for an ablation study on sampling temperature).
For each pair of target and draft models, we measure the empirical acceptance rate and block efficiency with three different block sizes $\gamma$, both with and without distillation. Following Leviathan et al. (2023), we measure the relative wall-time improvements with a batch size of 1 for both greedy sampling (temperature 0) and standard temperature sampling (temperature 1).
Figure 1 shows that distillation has a clear impact on SD speed, consistently yielding a 10–46% improvement across various datasets. This effect is most pronounced when employing greedy decoding. A summary of results for different block sizes and decoding strategies across five datasets is presented in Table D.1 (Appendix), highlighting the latency benefits of KD. The findings demonstrate that KD significantly enhances the acceptance rate and block efficiency for both decoder-only and encoder-decoder models across all datasets. Distillation algorithms utilizing model-generated data consistently outperform other approaches, resulting in approximately a 20% additional speedup compared to standard SD on LM1B, XSum, CNN/DM, and GSM8K. However, the gains on the WMT dataset are marginal, as the draft model already achieves a high acceptance rate and block efficiency without KD.
Figure 2 compares block efficiency across different algorithms, employing temperature sampling (temperature 1) with a fixed block size $\gamma$. The figure showcases the utility of model-generated data: using fixed ground-truth data (i.e., Supervised KD) ranks lowest across all settings except WMT. In contrast, f-Distill and GKD, which use purely model-generated data, significantly outperform their counterparts. Moreover, the subpar performance of SeqKD, despite being trained purely on data generated by the target model, implies that white-box distillation (information from the target model’s logits) is vital for SD. This is corroborated by Figure 3(a), which illustrates the evolution of the acceptance rate throughout training. Supervised KD ranks lowest, and its performance plateaus as training progresses due to the static dataset. In contrast, all other algorithms that employ model-generated data continue to improve. Although f-Distill is much more computationally costly than GKD because it uses teacher-generated data, both exhibit comparable performance. Notably, GKD achieves the best wall-time improvement (see Appendix D.1.2 for more visualizations of performance improvement during training).
We investigate whether KD improves block efficiency universally or mainly impacts a subset of examples. We depict the change in block efficiency per example in Figure 3(b). The results reveal a consistent enhancement in block efficiency across most examples, which also holds across datasets (see Figure D.12). Figure 3(c) illustrates a strong agreement between theoretical and empirical block efficiency values for several distilled models (each model shown as a filled circle). Although the theoretical values occasionally overestimate or underestimate the empirical values due to potential deviations from the i.i.d. token-level assumption (cf. §3), the ranking of distilled models remains highly consistent. In summary, these findings affirm that KD effectively optimizes block efficiency.
5.2 DistillSpec recipe
We now focus on identifying the optimal KD approach for SD. Following the training and evaluation protocols in § 5.1, we explore four training data construction methods and four divergence functions on XSum and GSM8K. In particular, we explore the following variants of training data: 1) a fixed ground-truth dataset, 2) data generated only from the draft model, 3) data generated only from the teacher (target) model, and 4) data generated from both the draft and the teacher in equal proportion. We further examine the following divergences: 1) forward KL (FKL), 2) Jensen-Shannon divergence (JSD), 3) reverse KL (RKL), and 4) total variation distance (TVD).
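The four training-data variants can be sketched as a single batch builder. `sample_draft` and `sample_target` are hypothetical callables standing in for temperature-1.0 sampling from the current draft and the target; with online data generation they would be called afresh at every training step. Alternating sources by position is just one simple way to realize the equal-proportion mixture, chosen here for determinism.

```python
from typing import Callable, Dict, List, Tuple


def build_kd_batch(
    inputs: List[str],
    ground_truth: Dict[str, str],
    source: str,
    sample_draft: Callable[[str], str],
    sample_target: Callable[[str], str],
) -> List[Tuple[str, str]]:
    """Return (x, y) training pairs for one of the four data variants.

    source: "fixed" (ground truth), "draft", "target", or "mixed"
    (draft and target outputs in equal proportion).
    """
    batch = []
    for i, x in enumerate(inputs):
        if source == "fixed":
            y = ground_truth[x]
        elif source == "draft":
            y = sample_draft(x)  # on-policy: sampled from the student itself
        elif source == "target":
            y = sample_target(x)
        elif source == "mixed":
            # Alternate sources so that exactly half of an even-sized batch
            # comes from each model.
            y = sample_draft(x) if i % 2 == 0 else sample_target(x)
        else:
            raise ValueError(f"unknown data source: {source!r}")
        batch.append((x, y))
    return batch


# Stand-in samplers for illustration only.
def draft(x: str) -> str:
    return f"draft({x})"


def target(x: str) -> str:
    return f"target({x})"


print(build_kd_batch(["a", "b"], {"a": "A", "b": "B"}, "mixed", draft, target))
```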
Figure 4 illustrates the block efficiency improvement on XSum and GSM8K, in line with the observations from § 5.1. Note that using model-generated data (last three rows) yields better performance than using a fixed dataset (first row). Specifically, on XSum with greedy decoding, using data generated from both the draft and the teacher gives the best performance, with JSD slightly outperforming the other divergences. However, on GSM8K with greedy decoding, FKL with only draft-generated data emerges as the best KD setup. In contrast, with temperature sampling (temperature 1), a different trend is observed, as RKL combined with data generated by is the most effective. Nonetheless, using only draft-generated data is competitive. See Appendix D.2.1 for results on different datasets and decoding strategies.
We also study how different distillation approaches affect the draft model’s task performance and whether the same design choices are optimal for improving both draft task performance and its utility for SD (cf. Figures D.18 and D.19). Similar to our earlier observations, the use of generated data is paramount for improving draft performance. More notably, utilizing data generated from the draft yields comparable or superior results compared to using data generated from the target. However, the choice of the optimal KD algorithm largely depends on the task and the underlying decoding strategy. More interestingly, Figure 5(a) highlights a dichotomy between block efficiency improvements and task performance gains via KD: a drafter with high task performance is not necessarily a powerful drafter for SD. See Appendix D.2.2 for more results on different datasets and decoding strategies.
Interestingly, although TVD is the objective we aim to optimize for SD (cf. Eq. 3), optimizing it directly does not yield the best performance in most of the settings explored. We generally find that the choice of divergence in KD is a hyperparameter that needs to be tuned based on the task and decoding strategy. For training data construction, we propose using the draft model for data generation, as it can achieve similar or superior performance compared to the teacher (target) at a much lower cost.
5.3 Quality versus latency trade-off
We analyze the quality-latency trade-off using lossy SD variants, as detailed in Algorithm A.1. Figure 5(b) illustrates that employing either KD () or SD () alone does not fully bridge the performance or latency gaps, respectively. In such cases, a leniency parameter can help interpolate between these two approaches, as demonstrated in Figure 5(b). Interpolating on GSM8K presents challenges, as still results in high performance and latency when using a lenience of , while pushes it further along the trade-off curve. Although can still interpolate, it yields a worse trade-off. Interestingly, it is possible to significantly reduce latency while almost preserving the quality in GSM8K, possibly because many tokens have little impact on the final performance, and accepting them is a model preference with minimal effect on generation quality. See Appendix D.3.1 for comparison between raw draft and distilled draft models, where we show that distilled draft achieves a much better trade-off.
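The exact lenience function of Algorithm A.1 is not reproduced in this section. Purely to illustrate how lenience trades fidelity for acceptance, the sketch assumes a hypothetical rule that accepts a draft token $y \sim q$ with probability $\min(1, p(y)/(l \cdot q(y)))$ for a lenience $l \in (0, 1]$, so that $l = 1$ recovers lossless speculative sampling; the distributions are made up.

```python
import numpy as np


def lenient_acceptance_rate(p, q, lenience):
    """Expected acceptance probability of a draft token y ~ q under a
    hypothetical lenient rule: accept with prob. min(1, p(y) / (lenience * q(y))).

    lenience = 1 recovers lossless speculative sampling; smaller values accept
    more draft tokens at the cost of drifting from the target distribution.
    """
    if not 0.0 < lenience <= 1.0:
        raise ValueError("lenience must lie in (0, 1]")
    # sum_y q(y) * min(1, p(y) / (lenience * q(y))) = sum_y min(q(y), p(y) / lenience)
    return float(np.minimum(q, p / lenience).sum())


p = np.array([0.7, 0.2, 0.1])  # target next-token distribution (made up)
q = np.array([0.3, 0.3, 0.4])  # draft next-token distribution (made up)
for lenience in (1.0, 0.5, 0.1):
    print(lenience, lenient_acceptance_rate(p, q, lenience))
```

A better-aligned (distilled) draft starts from a higher acceptance rate at $l = 1$, so it needs less lenience, and hence less deviation from the target, to reach a given latency.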
In practical scenarios, we often have access to multiple models of different sizes (a model garden) to design the inference pipeline. We consider this setting by focusing on T5 models of five sizes: T5-Small (77M), T5-Base (250M), T5-Large (800M), T5-XL (3B), and T5-XXL (11B). We study four different quality-latency trade-off curves using KD and SD: 1) Raw: deploying supervised fine-tuned (SFT) T5 models; 2) Distilled: applying KD to T5 models to optimize downstream task performance; 3) Speculative: applying SD to T5 models; and 4) DistillSpec: applying KD to T5 models and using the distilled models as target and draft.
Figure 6 illustrates that SD effectively shifts the trade-off curve leftward, especially for larger models. However, its efficacy diminishes for smaller models, where the forward-pass times of the draft and target models are closer. In contrast, distillation, which optimizes the model for downstream task performance, appears to offer a superior trade-off between quality and latency, particularly for smaller models. Conversely, a reverse trend is observed for larger models when evaluating with temperature sampling. Figure 25(a) indicates a substantial gap between the distilled model and the larger teacher model, while the SD-based method significantly reduces latency. This suggests that when stringent performance and decoding-strategy constraints are in place, SD remains a valuable approach. Notably, our method, DistillSpec, which combines the benefits of distillation and SD, consistently achieves the best trade-off between quality and latency, substantially reducing latency while maintaining nearly identical performance. Specifically, DistillSpec reduces relative latency from 17.3 to 2.7 on XSum and from 15.0 to 1.4 on GSM8K, corresponding to speedups of 6.4× and 10.7×, respectively. Meanwhile, the ROUGE-2 score decreases only marginally on XSum, from 23.1 to 23.0, and accuracy on GSM8K actually improves, from 33.1 to 34.8.
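The quoted speedups are simply the ratios of the relative latencies before and after applying DistillSpec:

```python
# Relative latency (before, after) reported for DistillSpec in the model garden study.
reported = {"XSum": (17.3, 2.7), "GSM8K": (15.0, 1.4)}
speedups = {task: round(before / after, 1) for task, (before, after) in reported.items()}
print(speedups)  # {'XSum': 6.4, 'GSM8K': 10.7}
```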
Conclusion
In this paper, we evaluate the efficacy of white-box knowledge distillation (KD) in enhancing speculative decoding (SD) through improved alignment between target and draft models. A thorough analysis is conducted to understand the impact of training data construction and divergence functions on KD performance. We underscore the significance of utilizing model-generated data and argue that employing the draft model’s on-policy data during KD is a cost-efficient method for realizing the improved alignment. Additionally, we assess the trade-off between quality and latency within the scope of lenience and availability of multiple models of varying quality and size, concluding that KD procures a superior trade-off compared to standard SD. The optimal strategy involves initially applying KD for downstream task performance, followed by SD, resulting in a six to ten-fold decrease in latency with negligible performance loss. Our study contributes novel insights into the white-box KD algorithms for LLMs and provides guidance on striking an effective balance between quality and latency using KD and SD.
YZ led the project and conducted all distillation experiments and evaluations for T5 models. YZ also wrote the initial draft of the paper and made all the plots and tables. KL provided a theoretical justification of the proposed method via Theorem 4.1 and revised the paper. ASR and AKM gave high-level guidance on the project and revised the paper. AR and SK gave high-level feedback on the project. JK supervised the project and served as a host to YZ at Google Research, and conducted evaluation of the decoder-only models. RA served as an advisor, provided detailed technical feedback, and revised the paper.
We would like to extend a special thank you to Neha Gupta, Wittawat Jitkrittum, Nino Veillard, Yaniv Leviathan, Matan Kalman, Danny Vainstein, Natan Potikha, Ananda Theertha Suresh, Laz Karydas, Aishwarya PS, Pranav Nair, Praneeth Netrapalli, Nikunj Saunshi, Ziteng Sun, Keiran Paster, Olivier Bachem, Aleksandra Faust for insightful discussion and valuable feedback.
References
Appendix
Appendix B Method
Below are some common divergence functions used in distillation, given two discrete probability distributions $P$ and $Q$.
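For reference, a NumPy sketch of the four divergences compared in § 5.2, written in their standard textbook forms with $P$ as the teacher (target) and $Q$ as the student (draft). The exact parameterization used in the paper (for example, a weighted variant of JSD) may differ. The sketch assumes $Q(c) > 0$ wherever $P(c) > 0$ (and vice versa for the reverse KL), as holds for softmax outputs.

```python
import numpy as np


def kl(a, b):
    """KL(a || b) for discrete distributions; assumes b > 0 wherever a > 0."""
    mask = a > 0
    return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))


def forward_kl(p, q):
    """D_FKL(P || Q) = KL(P || Q): mass-covering."""
    return kl(p, q)


def reverse_kl(p, q):
    """D_RKL(P || Q) = KL(Q || P): mode-seeking."""
    return kl(q, p)


def jsd(p, q):
    """Jensen-Shannon divergence with the equal-weight mixture M = (P + Q) / 2."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def tvd(p, q):
    """Total variation distance: half the L1 distance."""
    return 0.5 * float(np.abs(p - q).sum())


p = np.array([0.5, 0.5])    # teacher (target)
q = np.array([0.25, 0.75])  # student (draft)
for name, d in [("FKL", forward_kl), ("RKL", reverse_kl), ("JSD", jsd), ("TVD", tvd)]:
    print(name, round(d(p, q), 4))
```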
B.2 Justification of using on-policy data
First, we decompose and into sums of contributions from each token to ease the analysis:
For all , and , where
Similarly, by definition of and we have
Lemma B.1 motivates us to study and instead. Below we rewrite them as variational forms that will be used later. For this, we introduce some definitions.
For any sequence that consists only of letters and , we define as the distribution of sequences sampled as follows:
If there are tokens, sample the -th token from if , and from otherwise;
Repeat until an end-of-sequence token is sampled, or the sequence length has reached .
We use the shorthand and to denote the sequence of consecutive letters of and respectively. We use $\Omega^{t}$ to denote the set of all possible strings of length $t$, and $\delta$ to denote a function that maps a sequence of tokens to a real number in $[0, 1]$; with an abuse of notation, we set $\delta(y) = 0$ whenever $y \notin \Omega^{t}$.
For a fixed pair of and , we rewrite the total variation distance between and as the following variational form:
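The variational form rests on the finite-outcome identity $D_{\mathrm{TVD}}(P, Q) = \max_{\delta: \Omega \to [0,1]} \mathbb{E}_{P}[\delta] - \mathbb{E}_{Q}[\delta]$, attained by the indicator $\delta^{*} = \mathbb{1}[P > Q]$, since $\sum_{c} \delta(c)\,(P(c) - Q(c)) \le \sum_{c:\,P(c) > Q(c)} (P(c) - Q(c))$ for any $\delta$ with values in $[0, 1]$. A small numerical illustration over an arbitrary six-element outcome space:

```python
import numpy as np


def tvd(p, q):
    """Total variation distance: half the L1 distance."""
    return 0.5 * float(np.abs(p - q).sum())


def witness_gap(p, q, delta):
    """E_P[delta] - E_Q[delta] for a test function delta: Omega -> [0, 1]."""
    return float(np.dot(p, delta) - np.dot(q, delta))


rng = np.random.default_rng(1)
p = rng.dirichlet(np.ones(6))
q = rng.dirichlet(np.ones(6))
best = witness_gap(p, q, (p > q).astype(float))  # indicator of {P > Q}
others = [witness_gap(p, q, rng.random(6)) for _ in range(1000)]  # arbitrary deltas
print(tvd(p, q), best, max(others))
```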
Then after taking the expectations we have
where the last step is due to our abuse of notation that for all . This proves Equation B.7, and Equation B.8 can be proved similarly. ∎
With our variational forms of and , we obtain the following lemma for bounding in terms of .
Now taking a telescoping sum over , we obtain
Taking a telescoping sum over gives
Subtracting Equation B.11 from Equation B.10, we obtain that the following holds for all functions $\delta$:
Taking a sum of Equation B.9 over , we have
which proves the theorem statement after taking the expectation over . ∎
Appendix C Implementation Details
In this section, we present a comprehensive overview of the datasets employed in this study.
The Extreme Summarization (XSum) dataset serves as an evaluation benchmark for abstractive single-document summarization systems. This dataset comprises 226,711 news articles, sourced from BBC articles spanning the years 2010 to 2017. These articles encompass a wide range of domains, including News, Politics, Sports, Weather, Business, Technology, Science, Health, Family, Education, Entertainment, and Arts. Summarization performance is evaluated using ROUGE scores on the validation dataset split of XSum, primarily emphasizing ROUGE-2, while observing similar trends in ROUGE-LSum and ROUGE-1. A maximum input length of 1024 and a maximum output length of 64 are employed for distillation and evaluation.
The CNN/Daily Mail (CNN/DM) dataset is widely used for text summarization. It pairs news stories from the CNN and Daily Mail websites with human-written abstractive summary bullets; in the dataset's original question-answering form, these bullets were cast as questions with entities hidden, to be answered from the corresponding passages of the source text. As with XSum, we report ROUGE scores on the validation split, focusing primarily on ROUGE-2, with similar trends in ROUGE-LSum and ROUGE-1. We use a maximum input length of 2048 and a maximum output length of 128 for distillation and evaluation.
The WMT14 EnDe dataset is a standard benchmark for machine translation: the task is to translate English text into German while preserving content, semantics, and style. Evaluation uses the BLEU score, which measures the similarity of machine-translated text to high-quality reference translations. We use a maximum input length of 80 and a maximum output length of 80 for distillation and evaluation, and assess performance on the original test split.
The GSM8K dataset comprises 8.5K high-quality, linguistically diverse grade-school math word problems written by human problem writers. It is divided into 7.5K training problems and 1K test problems, whose solutions typically require 2 to 8 steps of elementary calculation with basic arithmetic operations. To enhance reasoning ability, we apply distillation together with the zero-shot chain-of-thought (CoT) method, as described in Agarwal et al. (2023). We use a maximum input length of 256 and a maximum output length of 320 for distillation and evaluation.
The One Billion Word dataset (LM1B) is a widely used benchmark for language modeling. Its training and held-out data are derived from the WMT 2011 News Crawl dataset using Bash shell and Perl scripts. We use a maximum input length of 128 and a maximum output length of 128 for distillation and evaluation.
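For reference, the per-task length limits listed above can be collected into a single configuration. This is a sketch under assumptions: the limits are treated as token counts and over-length examples as right-truncated, neither of which is stated in the text, and the task keys and the `truncate` helper are hypothetical.

```python
from typing import Dict, List, NamedTuple, Tuple


class LengthLimits(NamedTuple):
    max_input: int   # maximum input length (presumably in tokens)
    max_output: int  # maximum output length (presumably in tokens)


# Limits transcribed from the dataset descriptions; keys are hypothetical.
LENGTH_LIMITS: Dict[str, LengthLimits] = {
    "xsum": LengthLimits(1024, 64),
    "cnn_dm": LengthLimits(2048, 128),
    "wmt14_en_de": LengthLimits(80, 80),
    "gsm8k": LengthLimits(256, 320),
    "lm1b": LengthLimits(128, 128),
}


def truncate(task: str, input_ids: List[int],
             target_ids: List[int]) -> Tuple[List[int], List[int]]:
    """Clip an (input, target) pair to the task's limits.

    Assumes simple right-truncation; how over-length examples are actually
    handled is not specified.
    """
    limits = LENGTH_LIMITS[task]
    return input_ids[:limits.max_input], target_ids[:limits.max_output]


src, tgt = truncate("xsum", list(range(2000)), list(range(100)))
print(len(src), len(tgt))  # 1024 64
```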
C.2 Models
Following Leviathan et al. (2023), we evaluate two model types: 1) GPT-like decoder-only Transformer models trained with the standard autoregressive objective on the LM1B task (Chelba et al., 2013), where the target and draft models have 234M and 33M parameters, respectively; and 2) standard encoder-decoder T5 v1.1 models (Raffel et al., 2020) supervised fine-tuned on four different tasks, with T5-XL (3B) and T5-Small (77M) serving as the target and draft models, respectively.
The target model for the decoder-only experiment has model dimension 1024, feed-forward dimension 4096, 12 layers, and 16 attention heads. The corresponding draft model, with 33M parameters, has model dimension 512, feed-forward dimension 1024, 4 layers, and 4 attention heads. All models use the T5 tokenizer with a 32k-token vocabulary. For the T5 experiments, our starting checkpoints are LM-adapted T5v1.1 models, which are initialized from T5v1.1 and trained for an additional 100K steps on the LM objective discussed in the T5 paper (Raffel et al., 2020). These checkpoints are open-sourced at https://console.cloud.google.com/storage/browser/t5-data/pretrained_models.
In our experiments, both the student and teacher models for distillation are initialized from checkpoints supervised fine-tuned on the original training dataset, as follows:
XSum. For small, base, large, XL, and XXL models, we use LM-adapted T5v1.1 models supervised fine-tuned for 100K, 50K, 30K, 20K, and 8K steps, respectively.
CNN/DM. For small, base, large, XL, and XXL models, we use LM-adapted T5v1.1 models supervised fine-tuned for 200K, 80K, 20K, 20K, and 4K steps, respectively.
WMT. For small, base, large, XL, and XXL models, we use LM-adapted T5v1.1 models supervised fine-tuned for 250K, 250K, 110K, 50K, and 10K steps, respectively.
GSM8K. All models were supervised fine-tuned for 10K steps, starting from FLAN-T5 models, on the CoT dataset generated by PaLM-540B.
C.3 Distillation
We use the Adafactor optimizer (Shazeer & Stern, 2018) to train the draft student model in all experiments, following Algorithm A.2. For the knowledge distillation (KD) loss defined in Eq. 2, we fix the temperatures of both the target model and the draft model at 1.0. Keeping this temperature setting uniform is important for speculative decoding, since it ensures that the target and draft distributions are interpreted consistently. Table C.1 summarizes the hyperparameters used in our knowledge distillation process.
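A minimal sketch of a token-level KD loss with explicit temperatures follows. Eq. 2 is not reproduced in this section, so the sketch assumes that the sequence-level loss averages a token-level divergence D(p ‖ q) over output positions. Forward KL is used as the default divergence, TVD is included as an alternative, and all names are hypothetical. With both temperatures at 1.0, training matches the unscaled model distributions. The second call shows that a temperature mismatch on its own produces a nonzero loss, even when the logits are identical.

```python
import math
from typing import Callable, List, Sequence

Probs = List[float]


def softmax(logits: Sequence[float], temperature: float = 1.0) -> Probs:
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def forward_kl(p: Probs, q: Probs) -> float:
    """KL(p || q), skipping terms where p is zero."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


def tvd(p: Probs, q: Probs) -> float:
    """Total variation distance between two token distributions."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


def sequence_kd_loss(target_logits: Sequence[Sequence[float]],
                     draft_logits: Sequence[Sequence[float]],
                     divergence: Callable[[Probs, Probs], float] = forward_kl,
                     t_target: float = 1.0, t_draft: float = 1.0) -> float:
    """Average token-level divergence between target and draft over positions."""
    losses = [divergence(softmax(zp, t_target), softmax(zq, t_draft))
              for zp, zq in zip(target_logits, draft_logits)]
    return sum(losses) / len(losses)


target = [[2.0, 1.0, 0.1], [0.5, 2.5, -1.0]]
draft = [[1.5, 1.2, 0.3], [0.4, 2.0, -0.5]]
print(sequence_kd_loss(target, draft))                # both temperatures at 1.0
print(sequence_kd_loss(target, target, t_draft=2.0))  # mismatch alone is penalized
```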
C.4 Evaluation
To obtain task scores (ROUGE-2 for XSum and CNN/DM, BLEU for WMT, and accuracy for GSM8K), we follow the evaluation methodology of Agarwal et al. (2023), evaluating all examples in the test or validation set and reporting the average. The empirical speculative decoding metrics, namely the empirical acceptance rate and the empirical block efficiency, are likewise computed over all instances in the test or validation set and averaged. To measure actual latency, we follow the procedure of Leviathan et al. (2023): the target and draft models run on the same device, a TPUv4, without model parallelism. We randomly sample 500 examples from the test or validation set and measure decoding time with a batch size of 1. This measurement is repeated three times and the mean is reported; we observed minimal variance across random seeds.
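The two empirical speculative decoding metrics can be computed from a per-block log of how many drafted tokens the target model accepted. The text does not specify the estimators, so this Python sketch makes two assumptions. First, the acceptance rate is estimated as accepted drafts divided by verification trials, where a block stops at its first rejection. Second, block efficiency is the mean number of tokens emitted per block, including the one token the target model contributes in each block. For comparison, the sketch also computes the closed form $(1-\alpha^{\gamma+1})/(1-\alpha)$ from Leviathan et al. (2023), together with their per-token acceptance probability $\sum_y \min(p(y), q(y)) = 1 - \mathrm{TV}(p, q)$. The simulation treats each acceptance as independent, which real decoding does not satisfy. The log format and function names are hypothetical.

```python
import math
import random
from typing import List, Sequence, Tuple


def acceptance_prob(p: Sequence[float], q: Sequence[float]) -> float:
    """Chance that a token drafted from q survives rejection sampling against p:
    sum_y min(p(y), q(y)), which equals 1 - TV(p, q)."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))


def empirical_metrics(accepted_per_block: List[int], gamma: int) -> Tuple[float, float]:
    """Return (empirical acceptance rate, empirical block efficiency).

    accepted_per_block[i] is how many of the gamma drafted tokens were accepted
    in block i (hypothetical log format).
    """
    if gamma < 1 or not accepted_per_block:
        raise ValueError("need gamma >= 1 and at least one block")
    n_blocks = len(accepted_per_block)
    total_accepted = sum(accepted_per_block)
    # Verification stops at the first rejection; a fully accepted block has none.
    n_rejections = sum(1 for a in accepted_per_block if a < gamma)
    acceptance_rate = total_accepted / (total_accepted + n_rejections)
    # Each block emits its accepted drafts plus one token sampled from the target.
    block_efficiency = (total_accepted + n_blocks) / n_blocks
    return acceptance_rate, block_efficiency


def expected_block_efficiency(alpha: float, gamma: int) -> float:
    """Closed form under i.i.d. acceptance (Leviathan et al., 2023)."""
    if alpha == 1.0:
        return gamma + 1.0
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def simulate_block(alpha: float, gamma: int, rng: random.Random) -> int:
    """Accepted drafts in one block when each is accepted i.i.d. with prob alpha."""
    accepted = 0
    while accepted < gamma and rng.random() < alpha:
        accepted += 1
    return accepted


rng = random.Random(0)
gamma, alpha = 4, 0.7
log = [simulate_block(alpha, gamma, rng) for _ in range(100_000)]
alpha_hat, tau_hat = empirical_metrics(log, gamma)
print(alpha_hat, tau_hat, expected_block_efficiency(alpha, gamma))
```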