f-Divergence Minimization for Sequence-Level Knowledge Distillation

Yuqiao Wen, Zichao Li, Wenyu Du, Lili Mou

Introduction

Increasingly large language models have continued to achieve state-of-the-art performance across various natural language generation tasks, such as data-to-text generation (Lebret et al., 2016; Li and Liang, 2021), summarization (Paulus et al., 2018; Zhang et al., 2020a), and dialogue generation (Li et al., 2016b; Zhang et al., 2020b). However, super-large language models are inaccessible to most users and researchers due to their prohibitively large model size, emphasizing the importance of high-performing, parameter-efficient small neural models.

A widely used approach to training small models is knowledge distillation (KD, Hinton et al., 2015), where the small model (known as the student) learns the knowledge from a much larger model (known as the teacher). KD has shown great success in helping smaller models achieve competitive performance across a wide range of applications (Sun et al., 2019; Jiao et al., 2020; Shleifer and Rush, 2020).

Existing KD approaches can be categorized into two main branches: representation matching and distribution matching. The former aims to imitate the teacher’s real-valued intermediate-layer representations, say, with mean squared error (Sun et al., 2019; Jiao et al., 2020). Our work focuses on the latter, distribution matching, where the student model learns the teacher’s predictive distribution. Hinton et al. (2015) minimize the cross-entropy loss against the teacher-predicted soft labels, which is equivalent to minimizing the Kullback–Leibler (KL) divergence between the teacher and student. Kim and Rush (2016) propose SeqKD, arguing that KL divergence should be minimized at the sequence level for language models. However, such an approach tends to learn an overly smooth student distribution to cover the entire support of the teacher distribution due to the asymmetric nature of the KL divergence. This is often known as the mode-averaging problem (Figure 1a).

Tu et al. (2020) propose ENGINE, a non-autoregressive translation model that minimizes the energy function defined by the teacher’s output distribution. It can be shown that their objective is related to minimizing the reverse KL between the teacher and student (see Section 2.2). This, on the other hand, results in the mode-collapsing problem, where the student model is overly concentrated on certain high-probability regions of the teacher distribution (Figure 1b).

In this paper, we address knowledge distillation for text generation tasks, and propose f-distill, a unified framework that formulates sequence-level knowledge distillation as minimizing f-divergence functions. Existing SeqKD (Kim and Rush, 2016) and ENGINE (Tu et al., 2020) methods are approximations of KL and reverse KL distillations under the f-distill framework. Further, our formulation naturally leads to Jensen–Shannon (JS) divergence and total variation distance (TVD) distillations, where the divergence measures are symmetric in teacher and student distributions. This forces the student to learn the teacher’s distribution better, alleviating mode averaging and collapsing problems.

We further develop efficient algorithms for our f-distill approach. First, we show that sequence-level f-divergence can be decomposed step by step either exactly or as an upper bound. Second, we propose to sample from the teacher model in an offline manner, mitigating the additional training cost of symmetric divergence measures (namely, JS and TVD).

We evaluated our approach on four datasets: DART for data-to-text generation (Nan et al., 2021), XSum for summarization (Narayan et al., 2018), WMT16 EN-RO for machine translation (Bojar et al., 2016), and Commonsense Dialogue (Zhou et al., 2021). Experiments show that our proposed f-distill variants consistently outperform existing distribution-matching KD methods, allowing f-distill to achieve an add-on performance improvement when combined with representation-matching KD methods. Further, results show that our symmetric distilling losses outperform asymmetric ones, confirming that extreme mode averaging or collapsing is not ideal.

To sum up, our contributions are three-fold:

We propose f-distill, a novel distilling framework that generalizes KL distillation and balances mode averaging and collapsing;

We derive step-wise decomposition and propose an offline sampling method to efficiently compute sequence-level f-divergences; and

We provide detailed experimental analysis across four text generation datasets to show the effectiveness of our approach.

Approach

In this section, we first review classic knowledge distillation (KD) algorithms and analyze their drawbacks. Then, we propose f-distill, a generalized distilling framework for sequence-level distillation.

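In classic KD, the KL divergence is often used to train the student model to match the teacher’s distribution (Hinton et al., 2015). Denoting the teacher and student distributions by $p$ and $q_\theta$, for autoregressive text generation this is decomposed into a sum of step-wise KL divergences between the two models’ next-token distributions, each conditioned on the same prefix.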
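As a minimal sketch of this step-wise (word-level) KL, assuming each model’s next-token distributions for every step are already available as plain Python lists over a small vocabulary (the helper names `kl` and `stepwise_kl` are hypothetical), the loss is a categorical KL summed over positions:

```python
import math
from typing import Sequence

Dist = Sequence[float]  # a next-token distribution over a small vocabulary


def kl(p: Dist, q: Dist) -> float:
    """KL(p || q) between two categorical distributions (0 log 0 := 0)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def stepwise_kl(teacher_steps: Sequence[Dist], student_steps: Sequence[Dist]) -> float:
    """Sum over time steps of KL(p(.|y_<t) || q_theta(.|y_<t)).

    teacher_steps[t] and student_steps[t] are the two models' next-token
    distributions, both conditioned on the same prefix y_<t.
    """
    assert len(teacher_steps) == len(student_steps)
    return sum(kl(p, q) for p, q in zip(teacher_steps, student_steps))


# Two decoding steps over a 3-token vocabulary (toy numbers).
teacher = [[0.7, 0.2, 0.1], [0.5, 0.5, 0.0]]
student = [[0.6, 0.3, 0.1], [0.4, 0.4, 0.2]]
loss = stepwise_kl(teacher, student)
```

Tokens that the teacher gives zero probability contribute nothing, so the student is only pushed towards tokens the teacher supports.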

Kim and Rush (2016) propose SeqKD and minimize the cross-entropy loss at the sequence level, i.e., between the teacher’s and the student’s distributions over entire output sequences.

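In practice, the expectation over the sentence space is intractable, so they approximate it with a hard sequence $\mathbf{y}$ generated by beam search on the teacher model. Their loss is thus the student’s negative log-likelihood of this sequence, $-\log q_\theta(\mathbf{y})=-\sum_{t}\log q_\theta(y_t\mid\mathbf{y}_{<t})$.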
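A minimal sketch of this hard-sequence loss, assuming a hypothetical student interface that maps a prefix (a tuple of token ids) to a next-token distribution; the fixed list `y_teacher` merely stands in for the teacher’s beam-search output:

```python
import math
from typing import Callable, Sequence

# Hypothetical student interface: prefix (tuple of token ids) -> next-token distribution.
StudentFn = Callable[[tuple], Sequence[float]]


def seqkd_loss(student: StudentFn, y: Sequence[int]) -> float:
    """Negative log-likelihood of a teacher-generated hard sequence y under the student."""
    nll = 0.0
    for t, token in enumerate(y):
        probs = student(tuple(y[:t]))  # condition on the teacher's prefix y_<t
        nll -= math.log(probs[token])
    return nll


def toy_student(prefix: tuple) -> Sequence[float]:
    # Illustrative student that ignores the prefix.
    return [0.5, 0.25, 0.25]


y_teacher = [0, 1]  # stands in for the teacher's beam-search output
loss = seqkd_loss(toy_student, y_teacher)
```

Only the single decoded sequence provides supervision here; the teacher’s probabilities over alternative tokens are discarded.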

2.2 Our Proposed f-distill Framework

We now propose a generalized f-distill framework, a family of distilling methods based on f-divergence functions (Ali and Silvey, 1966; Sason and Verdú, 2016).

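Formally, the f-divergence of two distributions $p$ and $q$ is defined as $D_f(p\,\|\,q)=\sum_{\mathbf{Y}} q(\mathbf{Y})\, f\!\left(\frac{p(\mathbf{Y})}{q(\mathbf{Y})}\right)$, where $f$ is a convex function with $f(1)=0$; different choices of $f$ yield different divergences.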
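To make the definition concrete, here is a small sketch over a finite support. The generator functions are the textbook choices of $f$ for KL, reverse KL, JS, and TVD; the function names are illustrative, and the sketch assumes every $q_i > 0$:

```python
import math
from typing import Callable, Dict, Sequence


def f_divergence(p: Sequence[float], q: Sequence[float], f: Callable[[float], float]) -> float:
    """D_f(p || q) = sum_i q_i * f(p_i / q_i); this sketch assumes every q_i > 0."""
    return sum(qi * f(pi / qi) for pi, qi in zip(p, q))


def f_kl(t: float) -> float:
    """f(t) = t log t gives KL(p || q) (with 0 log 0 := 0)."""
    return t * math.log(t) if t > 0 else 0.0


def f_rkl(t: float) -> float:
    """f(t) = -log t gives the reverse KL, KL(q || p); needs every p_i > 0."""
    return -math.log(t)


def f_js(t: float) -> float:
    """Gives JS(p, q) = 0.5 KL(p || m) + 0.5 KL(q || m), with m = (p + q) / 2."""
    return 0.5 * (f_kl(t) - (1.0 + t) * math.log((1.0 + t) / 2.0))


def f_tvd(t: float) -> float:
    """f(t) = |t - 1| / 2 gives the total variation distance."""
    return 0.5 * abs(t - 1.0)


GENERATORS: Dict[str, Callable[[float], float]] = {
    "KL": f_kl, "RKL": f_rkl, "JS": f_js, "TVD": f_tvd,
}

p, q = [0.5, 0.5], [0.9, 0.1]
values = {name: f_divergence(p, q, f) for name, f in GENERATORS.items()}
```

Every generator satisfies $f(1)=0$, so each divergence vanishes when $p=q$; only the JS and TVD generators make the result symmetric in $p$ and $q$.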

In the rest of this subsection, we first present Kullback–Leibler (KL) and reverse KL (RKL) distilling methods, which are closely related to previous work (Kim and Rush, 2016; Tu et al., 2020). Then, we propose Jensen–Shannon (JS) and total variation distance (TVD) distillations; they are based on symmetric f-divergence functions and are able to force the student to better learn from the teacher distribution.

Kullback–Leibler (KL) distillation. Recall that we denote the teacher distribution by $p$ and the student distribution by $q_\theta$. Using the common KL divergence leads to the standard distilling objective
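One way to write this objective, assuming the step-wise decomposition of Theorem 1 and a single teacher-sampled sequence $\mathbf{y}$ (with $V$ denoting the vocabulary; our notation), is:

```latex
\begin{aligned}
\mathcal{J}_{\text{KL}}
&= D_{\text{KL}}(p\,\|\,q_\theta)
 = \mathbb{E}_{\mathbf{Y}\sim p}\!\left[\log\frac{p(\mathbf{Y})}{q_\theta(\mathbf{Y})}\right] \\
&\approx -\sum_{t=1}^{T}\sum_{Y_t\in V} p(Y_t\mid\mathbf{y}_{<t})\,\log q_\theta(Y_t\mid\mathbf{y}_{<t}) + \text{const}
\end{aligned}
```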

where $\mathbf{y}$ is sampled from the teacher distribution $p$ (in our method, the expectation (5) is approximated by one Monte Carlo-sampled sequence; we denote a sampled sequence by a lowercase letter $\mathbf{y}$). Here, the constant is the entropy of $p$, which can be ignored as it does not involve the student parameters.

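Similar to SeqKD, such KL distillation may also suffer from the mode-averaging problem and learn an overly smooth distribution, because $q_\theta$ is in the denominator in (5): any sequence to which the teacher assigns non-negligible probability but the student assigns near-zero probability incurs a very large loss, so the student is pushed to spread its mass over all of the teacher’s modes.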

Reverse KL (RKL) distillation. We propose RKL distillation, which can potentially address the mode-averaging problem:
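Analogously, a sketch of this objective with one student-sampled sequence $\mathbf{y}'$ (again writing $V$ for the vocabulary, and $\mathcal{H}$ for entropy; our notation) is:

```latex
\begin{aligned}
\mathcal{J}_{\text{RKL}}
&= D_{\text{KL}}(q_\theta\,\|\,p)
 = \mathbb{E}_{\mathbf{Y}'\sim q_\theta}\!\left[\log\frac{q_\theta(\mathbf{Y}')}{p(\mathbf{Y}')}\right] \\
&\approx -\sum_{t=1}^{T}\sum_{Y'_t\in V} q_\theta(Y'_t\mid\mathbf{y}'_{<t})\,\log p(Y'_t\mid\mathbf{y}'_{<t})
 \;-\; \sum_{t=1}^{T} \mathcal{H}\!\left(q_\theta(\cdot\mid\mathbf{y}'_{<t})\right)
\end{aligned}
```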

where $\mathbf{y}'$ is sampled from the student distribution. In other words, the loss can be decomposed into the teacher’s negative log probability of the student’s tokens (a cross-entropy term) minus the entropy of the student.

RKL does not suffer from mode averaging because the student distribution $q_\theta$ goes to the numerator and does not have to cover the teacher distribution: the cross-entropy term is small only if the student places its mass where the teacher’s probability is high. Note that the entropy term in (7) enters with a negative sign, so it rewards rather than penalizes a wide-spreading student distribution; it therefore tempers, but does not remove, this mode-seeking pressure.

However, RKL distillation has the opposite problem, known as mode collapsing, where the student only learns one or a few modes of the teacher distribution. This is because the RKL loss would be large if $q_\theta(\mathbf{Y}')$ is high but $p(\mathbf{Y}')$ is low for some $\mathbf{Y}'$. As a result, the student tends to overly concentrate its probability mass on certain high-probability regions of the teacher model, which may not be ideal either (Figure 1b).

Remarks. KL and RKL have the mode-averaging and mode-collapsing problems, respectively, because $D_{\text{KL}}(\cdot\|\cdot)$ is asymmetric in its two arguments, requiring the second distribution to cover the support of the first. In the following, we propose two f-distill variants based on symmetric divergence functions to seek a balance between these two extremes.
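A small numerical illustration of this asymmetry (toy numbers chosen for illustration, not taken from our experiments): against a bimodal teacher, forward KL prefers a student that spreads its mass everywhere, whereas reverse KL prefers a student that keeps only one mode.

```python
import math
from typing import Sequence


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for categorical distributions; assumes q_i > 0 wherever p_i > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# A bimodal "teacher" over four outcomes (modes at 0 and 3), and two candidate
# students: one that smears mass everywhere and one that keeps a single mode.
teacher = [0.45, 0.05, 0.05, 0.45]
smooth = [0.25, 0.25, 0.25, 0.25]
one_mode = [0.85, 0.05, 0.05, 0.05]

scores = {
    "KL(p||q)": {"smooth": kl(teacher, smooth), "one_mode": kl(teacher, one_mode)},
    "RKL(q||p)": {"smooth": kl(smooth, teacher), "one_mode": kl(one_mode, teacher)},
}
```

Here forward KL is about 0.37 for the smooth student versus 0.70 for the single-mode one, while reverse KL is about 0.51 versus 0.43, reversing the preference.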

Jensen–Shannon (JS) distillation. Our proposed JS distillation minimizes the JS divergence, which measures how far each of the two distributions is from their average. We derive the step-wise decomposition of the sequence-level JS loss:

where $\mathbf{y}$ and $\mathbf{y}'$ are sampled from the teacher’s and student’s distributions, respectively, and are compared with their average $m(\cdot)=\frac{1}{2}p(\cdot)+\frac{1}{2}q_\theta(\cdot)$. Appendix A provides the proof of this decomposition, and Subsection 2.3 presents an efficient approximation that avoids on-the-fly sampling from the teacher.
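A minimal sketch of a step-wise JS estimate, assuming both models’ next-token distributions along one teacher-sampled prefix $\mathbf{y}$ and one student-sampled prefix $\mathbf{y}'$ are already available as lists. Each prefix is sampled once and reused for all steps, so the cost is linear in the sequence length. As a simplifying assumption, the sketch takes $m$ at each step to be the per-step average $\frac12 p(\cdot\mid\text{prefix})+\frac12 q_\theta(\cdot\mid\text{prefix})$; the exact conditional of the sequence-level mixture $m$ generally differs from this average. Function names are hypothetical.

```python
import math
from typing import Sequence

Dist = Sequence[float]


def _kl(p: Dist, q: Dist) -> float:
    return sum(a * math.log(a / b) for a, b in zip(p, q) if a > 0)


def stepwise_js(
    p_on_teacher: Sequence[Dist], q_on_teacher: Sequence[Dist],
    p_on_student: Sequence[Dist], q_on_student: Sequence[Dist],
) -> float:
    """Step-wise JS estimate from one teacher sample y and one student sample y'.

    p_on_teacher[t] / q_on_teacher[t]: teacher / student next-token distributions
    given the teacher-sampled prefix y_<t; *_on_student likewise for y'_<t.
    Assumption: m at each step is the per-step average 0.5 p + 0.5 q.
    """
    loss = 0.0
    for p, q in zip(p_on_teacher, q_on_teacher):  # teacher-side term
        m = [0.5 * a + 0.5 * b for a, b in zip(p, q)]
        loss += 0.5 * _kl(p, m)
    for p, q in zip(p_on_student, q_on_student):  # student-side term
        m = [0.5 * a + 0.5 * b for a, b in zip(p, q)]
        loss += 0.5 * _kl(q, m)
    return loss


# One step where teacher and student put all mass on different tokens: JS = log 2.
js = stepwise_js([[1.0, 0.0]], [[0.0, 1.0]], [[1.0, 0.0]], [[0.0, 1.0]])
```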

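Total variation distance (TVD) distillation. Our f-distill gives rise to another novel distilling variant based on the total variation distance, $D_{\text{TVD}}(p,q_\theta)=\frac{1}{2}\sum_{\mathbf{Y}}\bigl|p(\mathbf{Y})-q_\theta(\mathbf{Y})\bigr|$.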

We would like to decompose the sequence-level TVD step by step, because the summation over the sentence space is intractable. However, TVD does not decompose exactly; instead, we show in Appendix A that the sequence-level TVD is upper bounded by step-wise terms, which serve as our objective to minimize:

where $\mathbf{y}$ and $\mathbf{y}'$ are again sampled from the teacher and student models, respectively.
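A minimal sketch of this step-wise objective, assuming (i) the per-step distributions of both models along one teacher-sampled prefix and one student-sampled prefix are given as lists, and (ii) the teacher-prefix and student-prefix bounds are combined with equal weights (our assumption about how the two bounds in Appendix A are combined). Function names are hypothetical.

```python
from typing import Sequence

Dist = Sequence[float]


def tvd(p: Dist, q: Dist) -> float:
    """Total variation distance between two categorical distributions."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def stepwise_tvd(
    p_on_teacher: Sequence[Dist], q_on_teacher: Sequence[Dist],
    p_on_student: Sequence[Dist], q_on_student: Sequence[Dist],
) -> float:
    """Equal-weight average of the teacher-prefix and student-prefix step-wise TVD sums."""
    teacher_side = sum(tvd(p, q) for p, q in zip(p_on_teacher, q_on_teacher))
    student_side = sum(tvd(p, q) for p, q in zip(p_on_student, q_on_student))
    return 0.5 * (teacher_side + student_side)


# One step where the two models disagree completely: TVD = 1.
bound = stepwise_tvd([[1.0, 0.0]], [[0.0, 1.0]], [[1.0, 0.0]], [[0.0, 1.0]])
```

Unlike the KL-based terms, each step-wise TVD is bounded in $[0, 1]$, so no single step can dominate the loss.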

Summary. In this part, we have described our proposed f-distill framework with four variants based on different f-divergence functions. We have also presented their step-wise decompositions, whose justification is summarized by the following theorem, proved in Appendix A.

Theorem 1. (a) The sequence-level KL, RKL, and JS divergences can be decomposed exactly into step-wise terms. (b) The sequence-level TVD can be upper bounded by step-wise terms.
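Theorem 1 can be sanity-checked numerically on tiny autoregressive models whose sequence space is small enough to enumerate. In this sketch, the toy `teacher` and `student` models and all numbers are illustrative assumptions; expectations over prefixes are computed exactly rather than by sampling, and the JS step-wise terms use the exact conditional of the sequence-level mixture $m$ given each prefix.

```python
import itertools
import math
from typing import Callable, Sequence, Tuple

Prefix = Tuple[int, ...]
Model = Callable[[Prefix], Sequence[float]]  # prefix -> next-token distribution

VOCAB = (0, 1)
T = 3


def teacher(prefix: Prefix) -> Sequence[float]:
    # Toy teacher that tends to repeat the previous token.
    if not prefix:
        return [0.6, 0.4]
    return [0.8, 0.2] if prefix[-1] == 0 else [0.3, 0.7]


def student(prefix: Prefix) -> Sequence[float]:
    # Toy student with a different dependence on the prefix.
    if not prefix:
        return [0.5, 0.5]
    return [0.6, 0.4] if sum(prefix) % 2 == 0 else [0.45, 0.55]


def prob(model: Model, seq: Prefix) -> float:
    """Probability of a (partial) sequence under an autoregressive model."""
    out = 1.0
    for t, tok in enumerate(seq):
        out *= model(seq[:t])[tok]
    return out


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    return sum(a * math.log(a / b) for a, b in zip(p, q) if a > 0)


def tvd(p: Sequence[float], q: Sequence[float]) -> float:
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


# Sequence level: enumerate all |V|^T sequences exactly.
seqs = list(itertools.product(VOCAB, repeat=T))
P = [prob(teacher, y) for y in seqs]
Q = [prob(student, y) for y in seqs]
M = [0.5 * a + 0.5 * b for a, b in zip(P, Q)]
seq_kl = kl(P, Q)
seq_js = 0.5 * kl(P, M) + 0.5 * kl(Q, M)
seq_tvd = tvd(P, Q)

# Step level: exact expectations over every prefix of length t = 0, ..., T-1.
step_kl = step_js = tvd_p = tvd_q = 0.0
for t in range(T):
    for prefix in itertools.product(VOCAB, repeat=t):
        wp, wq = prob(teacher, prefix), prob(student, prefix)
        p_t, q_t = teacher(prefix), student(prefix)
        # Exact conditional of the sequence-level mixture m given this prefix.
        m_t = [(wp * a + wq * b) / (wp + wq) for a, b in zip(p_t, q_t)]
        step_kl += wp * kl(p_t, q_t)
        step_js += 0.5 * wp * kl(p_t, m_t) + 0.5 * wq * kl(q_t, m_t)
        tvd_p += wp * tvd(p_t, q_t)
        tvd_q += wq * tvd(p_t, q_t)
tvd_bound = 0.5 * (tvd_p + tvd_q)
```

With these toy models, `seq_kl` and `seq_js` match their step-wise counterparts up to floating-point error, and `seq_tvd` stays below `tvd_bound`, as parts (a) and (b) state.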

2.3 Implementation Considerations

Efficient approximation. Symmetric distilling losses (i.e., JS and TVD) are slow to compute, because they require sampling from both teacher and student models during training.

We propose to mitigate this by offline sampling for the teacher model to improve training efficiency. Specifically, we obtain the teacher samples, i.e., $\mathbf{y}$ in Eqns. (8) and (10), beforehand and keep them fixed during training. This is feasible because the teacher model is frozen, so its samples need not be regenerated, whereas the student model is continuously updated and thus requires inference in an online fashion. Experiments show that this treatment more than doubles the training speed of both JS and TVD distillations (Table 4).
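The data flow of offline teacher sampling can be sketched as follows. This is an illustration, not our training code: `OfflineTeacherSamples`, `training_triples`, and the toy samplers are hypothetical stand-ins for decoding from the real models. The teacher is sampled once per training input before training, while a fresh student sample is drawn at every step.

```python
import random
from typing import Callable, Dict, Iterator, List, Sequence, Tuple

# Hypothetical sampler type: maps a source input to one sampled output (token ids).
Sampler = Callable[[str], List[int]]


class OfflineTeacherSamples:
    """Samples the frozen teacher once per training input and reuses the result."""

    def __init__(self, inputs: Sequence[str], teacher_sample: Sampler) -> None:
        # Done once, before training starts.
        self._cache: Dict[str, List[int]] = {x: teacher_sample(x) for x in inputs}

    def __getitem__(self, x: str) -> List[int]:
        return self._cache[x]


def training_triples(
    inputs: Sequence[str], offline: OfflineTeacherSamples,
    student_sample: Sampler, epochs: int,
) -> Iterator[Tuple[str, List[int], List[int]]]:
    """Yields (x, y, y'): a fixed teacher sample y and a fresh student sample y'."""
    for _ in range(epochs):
        for x in inputs:
            # The student changes after every update, so y' must be drawn online.
            yield x, offline[x], student_sample(x)


# Toy usage with stand-in samplers; the counter records teacher calls.
calls = {"teacher": 0}


def toy_teacher_sample(x: str) -> List[int]:
    calls["teacher"] += 1
    return [len(x), 1]


rng = random.Random(0)


def toy_student_sample(x: str) -> List[int]:
    return [rng.randrange(5)]


data = ["a", "bb", "ccc"]
offline = OfflineTeacherSamples(data, toy_teacher_sample)
steps = list(training_triples(data, offline, toy_student_sample, epochs=4))
```

With three inputs and four epochs, the teacher is queried three times instead of twelve, which is where the saving comes from.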

Pre-distillation. We warm-start our student model with the techniques developed by Shleifer and Rush (2020), who combine MLE training, word-level KL, and hidden state matching. Such a pre-distilling process is crucial to our f-distill method, because most variants (namely, RKL, JS, and TVD distillations) require sampling from a student, but a randomly initialized student model generates poor samples, making the distilling process less meaningful.

Note that, for a fair comparison, all baseline models are built upon the same pre-distilling process. This setup also lets us show that our f-distill is compatible with existing techniques and yields an add-on performance gain (Section 3.2).

Experiments

Datasets and metrics. We evaluated f-distill on a wide range of text generation tasks.

• DART. The DART dataset (Nan et al., 2021) is a popular data-to-text generation benchmark, where samples consist of structured data records and their corresponding text descriptions. We report common string-matching metrics, BLEU (Papineni et al., 2002), METEOR (Banerjee and Lavie, 2005), and TER (Snover et al., 2006), as well as popular learned metrics, BERTScore (Zhang et al., 2019), MoverScore (Zhao et al., 2019), and BLEURT (Sellam et al., 2020).

• XSum. Extreme Summarization (XSum, Narayan et al., 2018) is a large-scale dataset consisting of BBC articles and their one-sentence summaries. We report ROUGE scores, the most widely used metrics for summarization (Lin, 2004).

• WMT16 EN-RO. This dataset contains parallel texts for English and Romanian, and is one of the commonly used machine translation datasets (Bojar et al., 2016). We extracted 100K samples from the original dataset, as the teacher performance is nearly saturated at this size. We report BLEU, chrF (Popović, 2015), and TER scores for the translation quality, following existing machine translation literature (Sennrich et al., 2016; Barrault et al., 2019).

• Commonsense Dialogue. The Commonsense Dialogue dataset (Zhou et al., 2021) consists of dialogue sessions that are grounded on social contexts. We evaluated the output quality by BLEU and BERTScore. We only report BLEU-1 and BLEU-2, as higher-order BLEU scores are known to be unreliable for dialogue evaluation (Liu et al., 2016).

Model architectures. We evaluated f-distill using state-of-the-art teacher models for different tasks. We followed the encoder–decoder architecture and used BART (Lewis et al., 2020) as the teacher for DART and XSum. We used T5 (Raffel et al., 2020), another encoder–decoder model, for WMT16 EN-RO, as it excels at machine translation. For Commonsense Dialogue, we followed Zhang et al. (2020b) and used DialoGPT, a decoder-only model pretrained on massive dialogue data.

Our student models followed the teachers’ architectures, but we reduced the number of layers. In our experiments, we generally set the total number of layers to be four; specifically, encoder–decoder models had three encoder layers and one decoder layer, following the suggestion of deep encoders and shallow decoders in Kasai et al. (2020). For XSum, we set both the encoder and decoder to be three layers to compensate for the larger dataset. Additional experimental details can be found in Appendix B.

3.2 Results and Analyses

Main results. Table 2 presents the main results of our f-distill along with a number of competing methods in the four experiments.

We first trained a neural network without distillation. The network was identical to our student model in terms of the neural architecture and hyperparameters, but we trained it directly by maximum likelihood estimation (MLE) based on ground-truth target sequences. As seen, the non-distilling model performs significantly worse than distilling methods, which agrees with existing literature and justifies the need for knowledge distillation (Hinton et al., 2015; Tang et al., 2019; Jiao et al., 2020).

We pre-distilled our student model based on Shleifer and Rush (2020), a classic distilling approach that combines ground-truth training, word-level distillation, and intermediate-layer matching. Our f-distill approach requires pre-distillation, because it provides a meaningful initialization of the student model, from which our f-distill generates samples during training. Importantly, all distilling methods were built on the same pre-distilled model, constituting a fair comparison. The results show that, although the pre-distilling approach outperforms ground-truth MLE training, it is generally worse than the other distilling methods. This implies that our contribution is “orthogonal” to existing methods, and that our f-distill provides an add-on performance improvement.

We further experimented with SeqKD (Kim and Rush, 2016) and ENGINE (Tu et al., 2020), two established distilling methods in the distribution-matching category (see Section 1). They learn from hard sequences rather than probabilities, and thus are hard approximations of our KL and RKL distillations, respectively (Section 2.1). As seen, our soft label-based methods consistently outperform SeqKD and ENGINE. This suggests that soft labels (i.e., probabilities) provide more informative supervision signals than hard sentences for sequence-level distillation, which is consistent with early literature on classification tasks (Buciluǎ et al., 2006; Hinton et al., 2015).

Among our f-distill variants, we further observe that the symmetric distilling losses (JS and TVD) are consistently better than the asymmetric ones (KL and RKL) across all datasets, except on WMT16 EN-RO, where KL achieves a slightly better TER. A plausible reason is that machine translation is semantically grounded: given a source text, there are limited ways to translate it, because the output has to preserve the meaning of the input sentence. This is analogous to learning a uni-modal distribution, where mode averaging does not occur because there is only one mode. Apart from this case, JS and TVD perform better in all scenarios, as their symmetric divergences force the student to better learn from its teacher distribution. They rank first or second on all tasks for most metrics in Table 2, consistently and substantially outperforming previous methods.

Likelihood and coverage. We further analyze the mode-averaging and mode-collapsing behaviors of different distilling methods in Table 3. We propose to measure these aspects by a likelihood risk $R_{\text{llh}}$ and a coverage risk $R_{\text{cvg}}$.

The likelihood risk is computed by $R_{\text{llh}}=\frac{1}{|\mathcal{D}_{\text{student}}|}\sum_{\mathbf{y}'\in\mathcal{D}_{\text{student}}}-\log p(\mathbf{y}')$. Here, $\mathcal{D}_{\text{student}}$ is the set of sentences generated from the student, where we sample a sentence for each input in the test set; $p(\mathbf{y}')$ is the teacher’s predicted probability of a student-sampled sentence $\mathbf{y}'$. A large likelihood risk suggests that the student may have averaged the teacher’s modes, causing it to generate atypical sentences from the teacher’s point of view (Figure 1a).

Conversely, the coverage risk is computed by $R_{\text{cvg}}=\frac{1}{|\mathcal{D}_{\text{teacher}}|}\sum_{\mathbf{y}\in\mathcal{D}_{\text{teacher}}}-\log q_\theta(\mathbf{y})$, where we use the student $q_\theta$ to evaluate a teacher-sampled sentence $\mathbf{y}\in\mathcal{D}_{\text{teacher}}$. This measures whether the teacher’s samples are typical from the student’s point of view, i.e., how well a student covers the support of the teacher’s distribution. A large coverage risk means that the teacher’s typical outputs are not captured by the student, which is an indicator of mode collapse (Figure 1b).
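Both risks are plain averages of negative log-probabilities, computed with the "other" model. A minimal sketch, assuming a hypothetical scorer that returns the log-probability of an output given its input, and one sampled output per test input:

```python
import math
from typing import Callable, List, Sequence, Tuple

# Hypothetical scorer: log-probability of output y given input x under some model.
LogProb = Callable[[str, Sequence[int]], float]
Pairs = List[Tuple[str, Sequence[int]]]  # (test input, one sampled output)


def likelihood_risk(teacher_logprob: LogProb, student_samples: Pairs) -> float:
    """R_llh: mean of -log p(y') over student-sampled outputs y' (one per test input)."""
    return -sum(teacher_logprob(x, y) for x, y in student_samples) / len(student_samples)


def coverage_risk(student_logprob: LogProb, teacher_samples: Pairs) -> float:
    """R_cvg: mean of -log q_theta(y) over teacher-sampled outputs y (one per test input)."""
    return -sum(student_logprob(x, y) for x, y in teacher_samples) / len(teacher_samples)


def toy_logprob(x: str, y: Sequence[int]) -> float:
    # Stand-in model that assigns probability 0.5 to every token.
    return len(y) * math.log(0.5)


r_llh = likelihood_risk(toy_logprob, [("in1", [3, 4]), ("in2", [5])])
```

The two functions are identical in form; what distinguishes them is which model samples and which model scores.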

In addition, we notice that mode averaging and collapsing are strongly affected by how “multi-modal” a task is. We propose to measure this by the distinct bi-gram percentage (Li et al., 2016a) of the teacher model (denoted by TeacherDist): for each test input, we sampled five outputs from the teacher and computed the percentage of distinct bi-grams, which was then averaged across the test set. As seen in Table 3, the dialogue task exhibits the highest diversity, i.e., it is the most multi-modal, whereas machine translation is the least multi-modal.
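A minimal sketch of such a TeacherDist-style score, assuming outputs are already tokenized and that the bi-grams of an input’s five samples are pooled before counting (our assumption); the function returns a fraction, so multiply by 100 for a percentage:

```python
from typing import Sequence

Tokens = Sequence[str]


def distinct_bigram_ratio(outputs: Sequence[Tokens]) -> float:
    """Fraction of distinct bi-grams among all bi-grams pooled over a group of outputs."""
    bigrams = [tuple(o[i:i + 2]) for o in outputs for i in range(len(o) - 1)]
    return len(set(bigrams)) / len(bigrams) if bigrams else 0.0


def teacher_dist(samples_per_input: Sequence[Sequence[Tokens]]) -> float:
    """Average, over test inputs, of the distinct bi-gram ratio of that input's samples."""
    return sum(distinct_bigram_ratio(s) for s in samples_per_input) / len(samples_per_input)


# Five identical samples (low diversity) versus two different samples (high diversity).
same = [["i", "am", "fine"]] * 5
varied = [["i", "am", "fine"], ["not", "too", "bad"]]
score = teacher_dist([same, varied])
```

Identical samples repeat the same bi-grams and drive the ratio down, so a low score indicates a nearly uni-modal teacher for that input.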

Comparing KL and RKL, we find that KL distillation consistently achieves lower $R_{\text{cvg}}$ risks (i.e., better coverage) than RKL across all datasets. This confirms that KL distillation yields a smooth student distribution that covers the teacher’s, whereas RKL distillation does not have the covering property due to its mode-collapsing nature.

We further observe that RKL achieves a significantly higher likelihood (i.e., a lower $R_{\text{llh}}$) on the Commonsense Dialogue dataset. This shows that the mode-collapsing behavior of RKL distillation allows the student to generate plausible responses for the one-to-many dialogue task (Figure 1b), whereas mode-averaging KL distillation places the student in low-probability regions of the teacher’s distribution (Figure 1a). On the other hand, RKL does not achieve lower likelihood risks on the other tasks, since their one-to-many phenomenon is not as severe as in dialogue generation (Wei et al., 2019; Bao et al., 2020; Wen et al., 2023).

Referring back to Table 2, we see that mode-averaging KL distillation is preferred over RKL for less multi-modal tasks, such as machine translation (which has a low TeacherDist score), whereas mode-collapsing RKL is preferred for highly multi-modal tasks, such as dialogue generation (which has a higher TeacherDist score).

Last, our symmetric distilling objectives (JS and TVD) generally have moderate likelihood and coverage risks between the two extremes. This shows that they achieve a compromise between mode collapsing and averaging, allowing them to yield high performance in all tasks (Table 2).

Analysis of the student size. We analyze our f-distill variants with different student sizes in comparison with the SeqKD model. Due to limited time and resources, we chose the DART dataset as our testbed. We reduced the student model to different sizes by changing the number of encoder layers, as we had already used a single-layer decoder following the suggested architecture in Kasai et al. (2020). Results are shown in Figure 2.

As seen, our f-distill outperforms SeqKD across all model sizes. The symmetric losses (JS and TVD) also consistently outperform the asymmetric ones (KL and RKL). This is consistent with our main results and further validates the effectiveness and robustness of our f-distill framework.

Analysis of training efficiency. Our f-distill involves sampling sequences from the teacher. We propose an offline approach that obtains the teacher’s samples before training. We analyze the efficiency of offline sampling for JS and TVD distillations by comparing them with their online counterparts. We ran this experiment on an NVIDIA RTX A6000 GPU and an Intel Xeon Gold 5317 CPU. To obtain a rigorous time estimate, we ran the efficiency analysis on an unshared, consumer-grade server, whereas the other experiments were run on clusters (Appendix B).

As seen in Table 4, the offline variant achieves comparable performance while more than doubling the training speed. This is expected, as the offline distilling methods do not require inference from the teacher model during training, which otherwise constitutes a significant portion of the training time. This shows that our symmetric distilling methods can achieve high performance without sampling from the teacher online.

Human Evaluation. We further validated f-distill by human evaluation, where models were rated on fluency, missing information, and hallucination on a scale from 1 to 5 on the DART dataset, following previous work (Nan et al., 2021; Keymanesh et al., 2022). We invited five human annotators to evaluate 50 test samples for four competing models: SeqKD, ENGINE, JS, and TVD. For each test sample, the annotators were presented with shuffled model outputs, so they could not tell which output was generated by which model. Results are shown in Table 5.

As seen, our f-distill enables students to capture the input data records more faithfully while also retaining a high level of fluency. This is further supported by the p-values: comparing SeqKD and TVD, there is no statistically significant difference in fluency (p-value = 32.6%), whereas the improvements in missing information (p-value = 1.28%) and hallucination (p-value = 0.669%) are statistically significant. Our human evaluation confirms the effectiveness of f-distill.

Case Study. Appendix C shows example outputs of our f-distill variants. Indeed, we observe that KL distillation yields short and generic utterances, which are believed to be an indicator of mode averaging (Wei et al., 2019; Bao et al., 2020). Our symmetric losses (JS and TVD) are able to generate more meaningful, fluent, and coherent sentences.

Related Work

Knowledge distillation (KD) was pioneered by Buciluǎ et al. (2006), who use an ensemble model as the teacher to train a single-model student by minimizing the squared difference between their predicted logits. Hinton et al. (2015) propose to directly learn from the output probabilities by minimizing their KL divergence. Sun et al. (2019) propose patient knowledge distillation (PKD), which requires the student to learn from the teacher’s intermediate layers. Jiao et al. (2020) propose TinyBERT, extending knowledge distillation for Transformer models by additional treatments on the attention layers. Other recent distilling methods include finding the optimal layer mapping between two models (Li et al., 2020; Jiao et al., 2021) and learning from multiple teachers (Yang et al., 2020; Wu et al., 2021; Li et al., 2022).

The success of KD has since sparked significant interest in its applications to text generation. Kim and Rush (2016) investigate sequence-level knowledge distillation (SeqKD) for neural machine translation, where they use sampled, hard sequences to approximate the KL divergence. Tu et al. (2020) train a student model by minimizing the energy function defined by a teacher model, which we show is an approximation to reverse KL distillation. Lin et al. (2020) propose imitation-based KD, where the teacher provides oracle probabilities on student-sampled partial sequences to address the exposure bias problem. Further, KD has been extensively used to train non-autoregressive text generation models to reduce the complexity of the training data (Gu et al., 2018; Shao et al., 2022; Huang et al., 2022).

Note that our f-distill requires meaningful student sampling and is thus built upon existing KD techniques (Shleifer and Rush, 2020), including word-level and intermediate-layer KD. Nevertheless, our experiments show that our approach achieves an add-on performance improvement, and that our contributions are orthogonal to previous work.

Besides KD, common model compression techniques include parameter pruning and sparse modeling. Parameter pruning first trains a dense network and then removes certain neural weights in hopes of not significantly affecting the model performance (LeCun et al., 1989; Liu et al., 2018; Fan et al., 2021). Alternatively, one may apply sparse modeling techniques such as regularization during the training process to ensure zero-valued parameters (Frankle and Carbin, 2018; Louizos et al., 2018; Tang et al., 2022). Our work does not follow these directions, as we consider the knowledge distilling setting.

Regarding the f-divergence function, it has many applications in the machine learning literature. Standard cross-entropy training is equivalent to minimizing the KL divergence between the ground-truth label distribution (often one-hot) and the model distribution (Bishop, 2006). Generative adversarial networks (Goodfellow et al., 2014) minimize the Jensen–Shannon divergence by simultaneously training a generator and a discriminator against each other. Zhao et al. (2020) minimize the $\alpha$-divergence for adversarial learning, which generalizes KL and RKL and is a special case of f-divergence functions. Zhang et al. (2021) use total variation distance as a regularizer to encourage the model to predict more distinguishable probabilities. Further, the JS divergence is used in computer vision KD (Yin et al., 2020; Fang et al., 2021), but these tasks do not involve sequential data and the underlying techniques largely differ from our approach. To the best of our knowledge, we are the first to systematically formulate sequence-level knowledge distillation as f-divergence minimization.

Conclusion

We propose f-distill, a family of sequence-level distilling methods beyond minimizing the KL divergence. Under our framework, we propose and analyze four variants: KL, RKL, JS, and TVD distillations, where existing SeqKD and ENGINE are approximations of the KL and RKL variants; we further derive step-wise decompositions for our f-distill. Results on four text generation tasks show that f-distill consistently outperforms existing KD methods, and that our symmetric losses (JS and TVD) outperform asymmetric ones by avoiding extreme mode averaging and collapsing.

Limitations

Our f-distill variants are less efficient to train than SeqKD and ENGINE, as we require the teacher’s soft probabilities instead of hard, sampled sequences. However, our methods achieve a significant performance improvement, and more importantly, the additional training time does not affect inference when the model is deployed. This follows the spirit of knowledge distillation in general, i.e., to obtain a small and efficient model for deployment.

Another potential threat to validity is that we have not reported multi-run statistics. In our preliminary experiments, we ran our approach multiple times and found the results to be generally consistent. Due to our extensive experimentation (estimated at 2000 GPU hours), it was not feasible to run each model multiple times. We instead adopted a wide range of established automatic metrics, consistently showing the effectiveness of our approach. We further conducted in-depth analyses to better understand our proposed framework. We deem multi-run statistics not crucial to this paper, as this paper does not purely focus on empirical analysis. Rather, our main contributions lie in the novel machine learning framework, f-distill, and the theoretical connections between step-wise and sequence-level f-divergence functions.

Finally, a limitation of the formally published paper in the proceedings of ACL’23 is that we could not foresee the discussions during the conference, which are highlighted in the next section of this arXiv manuscript.

Notes from ACL’23 Conference

During the ACL conference, we received a number of questions and feedback, based on which we would like to make two clarifications.

which is not the same as (11). We realize that we made the above mistake in our implementation, which nevertheless works empirically well and can be thought of as an approximation. It is also noted that calculating (12) requires storing the probabilities for each step and would be less efficient.

Acknowledgments

We thank Wai Hong Ong for discussing the technical details of $m(\cdot)$ in JS distillation, and Lucas Torroba Hennigen for discussing the connection between f-distill and reinforcement learning.

We also thank all reviewers and chairs for their valuable comments. The research is supported in part by the Natural Sciences and Engineering Research Council of Canada (NSERC) under Grant No. RGPIN2020-04465, the Amii Fellow Program, the Canada CIFAR AI Chair Program, a UAHJIC project, a donation from DeepMind, and the Digital Research Alliance of Canada (alliancecan.ca).

References

Appendix A Proof of Theorem 1

[Part (a)] We first consider the JS decomposition. Let $p$ and $q_\theta$ be the predicted distributions of the teacher and student, respectively. Let $m(\mathbf{Y})=\frac{1}{2}p(\mathbf{Y})+\frac{1}{2}q_\theta(\mathbf{Y})$ be their average. We claim that the JS divergence between two distributions over length-$T$ sequences (in practice, $T$ can be thought of as the maximum length; alternatively, we may consider varying-length sequences by a mixture of different values of $T$) can be decomposed step by step as

In fact, the partially sampled sequences are reused for the summation over $t=1,\cdots,T$. That is to say, we will first sample the sequences $\mathbf{y}_{1:T-1}\sim p$ and $\mathbf{y}'_{1:T-1}\sim q_\theta$ and then compute the summation; thus, the complexity is linear rather than quadratic.

To prove (14), we first focus on the first term of (13):

Then, we can unroll the first term of (19) recursively, resulting in

We state KL and RKL decompositions below. Their proofs are similar and thus omitted.

[Part (b)] This part shows that the same step-wise decomposition for TVD is an upper bound:

We again start by re-writing the TVD loss in a recursive form

Likewise, we can obtain the following inequality by multiplying and dividing by $q_\theta(\mathbf{y}_{1:T-1})$ in (28)

These two upper bounds, (36) and (37), are then combined to obtain (25), concluding the proof.

Admittedly, both (36) and (37) are valid upper bounds for the TVD, but we nevertheless combine these two formulas to obtain a more computationally robust upper bound, in the same spirit as the JS decomposition.
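The single-step inequality behind bounds of this kind can be sketched as follows, writing $a$ for a prefix and $b$ for the next token (our shorthand). The first line is the triangle inequality applied to $p(a)p(b\mid a)-q_\theta(a)q_\theta(b\mid a)=p(a)\,[p(b\mid a)-q_\theta(b\mid a)]+q_\theta(b\mid a)\,[p(a)-q_\theta(a)]$; summing over $b$ (using $\sum_b q_\theta(b\mid a)=1$) and then over $a$ gives the second line. Applying it recursively over prefixes yields a step-wise bound weighted by teacher prefixes, and swapping the roles of $p$ and $q_\theta$ gives one weighted by student prefixes:

```latex
\begin{aligned}
\bigl|p(a)\,p(b\mid a) - q_\theta(a)\,q_\theta(b\mid a)\bigr|
&\le p(a)\,\bigl|p(b\mid a) - q_\theta(b\mid a)\bigr| + q_\theta(b\mid a)\,\bigl|p(a) - q_\theta(a)\bigr|,\\
\tfrac12\sum_{a,b}\bigl|p(a,b) - q_\theta(a,b)\bigr|
&\le \mathbb{E}_{a\sim p}\Bigl[\tfrac12\sum_b\bigl|p(b\mid a) - q_\theta(b\mid a)\bigr|\Bigr] + \tfrac12\sum_a\bigl|p(a) - q_\theta(a)\bigr|.
\end{aligned}
```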

Appendix B Experimental Details

Table 6 shows the statistics of our datasets. As seen, we benchmarked our models on a variety of natural language generation tasks with different data sizes. We chose state-of-the-art models as the teachers, with 200M–400M parameters. Accordingly, our students had 50M–150M parameters. The high performance of f-distill variants across these datasets highlights the robustness of our approach.

For training, we used the Adam optimizer (Kingma and Ba, 2015) with default hyperparameters $\beta=(0.9,0.999)$ on DART, XSum, and Commonsense Dialogue. For WMT16 EN-RO, we followed the T5 teacher model (Raffel et al., 2020) and used the AdaFactor optimizer (Shazeer and Stern, 2018). We chose a small batch size of eight to fit the student as well as the large teacher in our GPU. All student models were trained for 28 epochs for pre-distillation and another 12 epochs for each distilling method, as additional training did not further improve performance.

The main experiments were conducted on AMD Milan 7413 CPUs and NVIDIA A100 GPUs, and the total training time was estimated at 2000 GPU hours. Note that this is not because our algorithm is slow (efficiency analyzed in Table 4), but because we have extensively experimented with a variety of datasets and model variants.

Appendix C Case Study

Table 7 shows example outputs for DART and Commonsense Dialogue. On the DART dataset, the KL and RKL distillations fail to produce coherent descriptions of the input data records. By contrast, JS and TVD distillations enable the student to generate sentences of much higher quality: they correctly recognize the name of the train as well as its origin and destination.

We additionally show an example output from the Commonsense Dialogue dataset, because the dialogue task exhibits the most severe multi-modal problem, which in turn requires the student to carefully balance mode averaging and collapsing. As seen, the KL-distilled student generates a short and generic response, which is consistent with existing literature (Wei et al., 2019; Bao et al., 2020) and is explained as mode averaging in our paper. The RKL-distilled student generates a detailed, but ungrammatical and incoherent, response. For JS and TVD distillations, the students generate responses that are both coherent and detailed. The case studies confirm our main claim that JS and TVD are more effective sequence-level distilling approaches.