Towards Understanding Knowledge Distillation
Mary Phuong, Christoph H. Lampert
Introduction
In 2014, Hinton et al. (2014) made a surprising observation: they found it easier to train a classifier using the real-valued outputs of another classifier as target values than using the actual ground-truth labels. Calling the procedure knowledge distillation, or distillation for short, they observed the positive effect even when the existing classifier (called the teacher) was trained on the same data that was afterwards used for the distillation-training of the new classifier (called the student). Since then, the positive properties of distillation-based training have been confirmed several times: the optimization is generally better behaved than in label-based training, and it needs little, if any, regularization or specific optimization tricks. Consequently, in several fields, distillation has become a standard technique for transferring information between classifiers with different architectures, such as from deep to shallow neural networks or from ensembles of classifiers to individual ones.
While the practical benefits of distillation are beyond doubt, its theoretical justification remains almost completely unclear. Existing explanations rarely go beyond qualitative statements, e.g. claiming that learning from soft labels should be easier than learning from hard labels, or that in a multi-class setting the teacher’s output provides information about how similar different classes are to each other.
In this work, we follow a different approach. Instead of studying distillation in full generality, we restrict our attention to a simplified, analytically tractable setting: binary classification with a linear teacher and a linear student (either shallow or deep linear networks). For this situation, we obtain the first quantitative results about the effectiveness of distillation-based training. Specifically, our main results are: 1) We prove a generalization bound that establishes extremely fast convergence of the risk of distillation-trained classifiers; in fact, the risk can reach zero from finite training sets. 2) We identify three key factors that explain the success of distillation: data geometry, i.e. geometric properties of the data distribution, in particular class separation, directly influence the convergence speed of the student's risk; optimization bias, i.e. even though the distillation objective can have many optima, gradient descent optimization is guaranteed to find a particularly favorable one; and strong monotonicity, i.e. enlarging the training set always decreases the risk of the student classifier.
Related Work
Ideas underpinning distillation have a long history dating back to the work of Ba & Caruana (2014); Bucilua et al. (2006); Craven & Shavlik (1996); Li et al. (2014); Liang et al. (2008). In its current and most widely known form, it was introduced by Hinton et al. (2014) in the context of neural network compression.
Since then, distillation has quickly gained popularity among practitioners and established its place in deep learning folklore. It has been found to work well across a wide range of applications, including e.g. transferring from one architecture to another (Geras et al., 2016), compression (Howard et al., 2017; Polino et al., 2018), integration with first-order logic (Hu et al., 2016) or other prior knowledge (Yu et al., 2017), learning from noisy labels (Li et al., 2017), defending against adversarial attacks (Papernot et al., 2016), training stabilization (Romero et al., 2015; Tang et al., 2016), distributed learning (Polino et al., 2018), reinforcement learning (Rusu et al., 2016) and data privacy (Celik et al., 2017).
In contrast to the empirical success, the mathematical principles underlying distillation's effectiveness have largely remained a mystery. Only very few works examine distillation from a theoretical perspective. Lopez-Paz et al. (2016) cast distillation as a form of learning using privileged information (LUPI, Vapnik & Izmailov 2015), a learning setting in which additional per-instance information is available at training time but not at test time. However, the LUPI view concentrates on the aspect that the teacher's supervision to the student is noise-free. This argument fails to explain, e.g., the success of distillation even when the original problem is noise-free to start with. The only other theoretical analysis we are aware of is by Urner et al. (2011), who study distillation as a form of semi-supervised learning. Specifically, they show that a two-step procedure, consisting of first training a teacher on a small labelled dataset and then training the student on a separate large dataset labelled by the teacher, can be more effective than training the student directly on the small labelled dataset. The paper's focus is on the semi-supervised aspect, i.e. the gains from having a large unlabelled dataset.
A more distantly related topic is machine teaching (Zhu, 2015). In machine teaching, a machine learning system is trained by a human teacher, whose goal is to hand-pick as small a training set as possible, while ensuring that the machine learns a desired hypothesis. Transferring knowledge via machine teaching techniques is extremely effective: perfect transfer is often possible from a small finite teaching set (Zhu, 2013; Liu & Zhu, 2016). However, the price for this radical reduction in sample complexity is the expensive training set construction. Our work shows that, at least in the linear setting, distillation achieves a similar effectiveness with a more practical form of supervision.
Background: Linear Distillation
We allow the weight vector to be parameterised as a product of matrices, $w = W_N W_{N-1} \cdots W_1$, for some $N \ge 1$. When $N \ge 2$, this parameterisation is known as a deep linear network. Although deep linear networks have no additional capacity compared to directly parameterised linear classifiers ($N = 1$), they induce different gradient-descent dynamics, and are often studied as a first step towards understanding deep nonlinear networks (Saxe et al., 2014; Kawaguchi, 2016; Hardt & Ma, 2017).
where is a normalization constant, such that the minimum of is 0. It only serves the purpose of simplifying notation and has no effect on the optimization.
The student observes the loss as a function of its parameters, i.e. the individual weight matrices,
and optimizes it via gradient descent. For the theoretical analysis, we avoid the complications of step-size selection and adopt the notion of an infinitesimal step size (for readers who are unfamiliar with gradient flows, it suffices to think of the step size as finite and "sufficiently small"), which turns the gradient descent procedure into a continuous gradient flow. We write $W_i(t)$ for the value of the matrix $W_i$ at time $t \ge 0$, with $W_i(0)$ denoting the initial value, and $w(t) = W_N(t) \cdots W_1(t)$. Then, each $W_i$, for $i = 1, \dots, N$, evolves according to the following differential equation:
$$\frac{d W_i(t)}{dt} = -\frac{\partial \mathcal{L}}{\partial W_i}\big(W_1(t), \dots, W_N(t)\big). \qquad (3)$$
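The student is trained until convergence, i.e. $t \to \infty$. We measure the transfer risk of the trained student, defined as the probability that its prediction differs from that of the teacher,
$$R(w) = \Pr_{x \sim P_x}\big[\operatorname{sign}(w^\top x) \neq \operatorname{sign}(w_*^\top x)\big],$$
where $w_*$ denotes the teacher's weight vector and $P_x$ the input distribution.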
In Section 4.2, we will derive a bound for the transfer risk and establish how rapidly it decreases as a function of the number of training examples $n$.
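To make the transfer risk concrete, the following minimal Python sketch estimates it by Monte Carlo sampling. It assumes, purely for illustration, that inputs are drawn from an isotropic Gaussian (the analysis in this paper allows a general $P_x$) and that a prediction is the sign of $w^\top x$; the helper names `transfer_risk_mc` and `gaussian_x` are hypothetical. For isotropic Gaussian inputs the risk has the closed form $\alpha(w, w_*)/\pi$, where $\alpha$ is the angle between the two vectors, which serves as a sanity check.

```python
import math
import random

def transfer_risk_mc(w_student, w_teacher, sample_x, num_samples=20000, seed=0):
    """Hypothetical helper: Monte Carlo estimate of
    P_x[sign(w_student . x) != sign(w_teacher . x)]."""
    rng = random.Random(seed)
    errors = 0
    for _ in range(num_samples):
        x = sample_x(rng)
        s = sum(a * b for a, b in zip(w_student, x))
        t = sum(a * b for a, b in zip(w_teacher, x))
        errors += (s >= 0) != (t >= 0)  # count disagreements with the teacher
    return errors / num_samples

def gaussian_x(rng, d=3):
    """Assumed input distribution for this illustration: standard Gaussian in R^d."""
    return [rng.gauss(0.0, 1.0) for _ in range(d)]

w_teacher = [1.0, 0.0, 0.0]
w_student = [1.0, 0.5, 0.0]
risk = transfer_risk_mc(w_student, w_teacher, gaussian_x)

# Closed form for isotropic Gaussian inputs: angle(w_student, w_teacher) / pi.
cos_angle = sum(a * b for a, b in zip(w_student, w_teacher)) / (
    math.hypot(*w_student) * math.hypot(*w_teacher))
exact = math.acos(cos_angle) / math.pi
print(round(risk, 3), round(exact, 3))
```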
Generalization Properties of Linear Distillation
This section contains our main technical results. First, in Section 4.1, we provide an explicit characterization of the outcome of distillation-based training in the linear setting; in other words, we identify what the student actually learns. In particular, we prove that the student is able to perfectly identify the teacher's weight vector if the number of training examples ($n$) is equal to or larger than the dimensionality of the data ($d$). If less data is available, under minor assumptions, the student finds the best approximation of the teacher's weight vector that is possible within the subspace spanned by the training data.
In Section 4.2, we use these results to study the generalization properties of the student classifier, i.e. we characterize how fast the student learns. Specifically, we prove a generalization bound with much more appealing properties than what is possible in the classic situation of learning from hard labels. As soon as enough training data is available ($n \ge d$), the student's risk is simply $0$. Otherwise, the risk can be bounded explicitly in a distribution-dependent way that, in particular, allows us to identify three key factors that explain the success of distillation, and to understand when distillation-based transfer is most effective.
In this section, we derive in closed form the asymptotic solution to the gradient flow (3) undergone by the student when trained by distillation. We state the results separately for directly parameterized linear classifiers ($N = 1$) and deep linear networks ($N \ge 2$), as the two settings require slightly different ways of initializing parameters. Namely, in the former case, initializing at $w(0) = 0$ is valid, while in the latter case, initializing all matrices at zero would lead to vanishing gradients, so we have to initialize with small (typically random) values.
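Assume the student is a directly parameterised linear classifier ($N = 1$) with weight vector initialised at zero, $w(0) = 0$. Then, the student's weight vector fulfills almost surely
$$\lim_{t \to \infty} w(t) = \begin{cases} w_* & \text{if } n \ge d, \\ \hat w & \text{if } n < d, \end{cases}$$
where $\hat w = X (X^\top X)^{-1} X^\top w_*$ is the orthogonal projection of the teacher's weight vector $w_*$ onto the span of the training data $X = [x_1, \dots, x_n] \in \mathbb{R}^{d \times n}$.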
Theorem 1 shows a remarkable property of distillation-based training for linear systems: if sufficiently many (at least $d$) data points are available, the student exactly recovers the teacher's weight vector, $\lim_{t \to \infty} w(t) = w_*$. This is a strong justification for distillation as a method of knowledge transfer between linear classifiers. Moreover, the theorem establishes that the effect occurs not just in the infinite data limit ($n \to \infty$), as one might have expected, but already in the finite sample regime ($n \ge d$).
When few data points are available ($n < d$), the weight vector learned by the student is simply the projection of the teacher's weight vector onto the data span (the subspace spanned by the columns of $X$). In a sense, this is the best the student can do: the gradient descent update direction always lies in the data span, so there is no way for the student to learn anything outside of it. The projection $\hat w$ is the best subspace-constrained approximation of $w_*$ with respect to the Euclidean norm. The extent to which Euclidean closeness implies closeness in predictions is a separate matter, and the subject of Section 4.2.
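The following sketch illustrates Theorem 1 numerically on a tiny hand-picked example. The data, teacher, step size, iteration count and helper names (`distill_gd`, `project`) are our own illustrative choices, and a finite step size stands in for the gradient flow. Plain gradient descent on the soft-label cross-entropy, started at $w = 0$, should converge to the projection of $w_*$ onto the data span when $n < d$, and to $w_*$ itself once $n = d$.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def distill_gd(xs, w_teacher, lr=1.0, steps=5000):
    """Illustrative helper (hypothetical name): gradient descent from w = 0 on the
    soft-label cross-entropy with targets y_i = sigmoid(w_teacher . x_i).
    Its gradient is (1/n) * sum_i (sigmoid(w . x_i) - y_i) * x_i."""
    n = len(xs)
    ys = [sigmoid(dot(w_teacher, x)) for x in xs]
    w = [0.0] * len(w_teacher)
    for _ in range(steps):
        grad = [0.0] * len(w)
        for x, y in zip(xs, ys):
            r = (sigmoid(dot(w, x)) - y) / n
            grad = [g + r * xi for g, xi in zip(grad, x)]
        w = [wi - lr * gi for wi, gi in zip(w, grad)]
    return w

def project(w, xs):
    """Orthogonal projection of w onto span(xs), via classical Gram-Schmidt."""
    basis = []
    for x in xs:
        r = list(x)
        for q in basis:
            c = dot(r, q)
            r = [ri - c * qi for ri, qi in zip(r, q)]
        nr = math.sqrt(dot(r, r))
        if nr > 1e-12:
            basis.append([ri / nr for ri in r])
    out = [0.0] * len(w)
    for q in basis:
        c = dot(w, q)
        out = [oi + c * qi for oi, qi in zip(out, q)]
    return out

w_teacher = [1.0, -1.0, 0.5]
few = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]   # n = 2 < d = 3
many = few + [[1.0, 1.0, 0.0]]             # n = 3 = d

w_few = distill_gd(few, w_teacher)
w_hat = project(w_teacher, few)            # equals [7/6, -5/6, 1/3] for this data
w_many = distill_gd(many, w_teacher)
print([round(v, 4) for v in w_few], [round(v, 4) for v in w_hat])
print([round(v, 4) for v in w_many])       # should match w_teacher
```

The step size of 1.0 is safe here because the sigmoid's derivative is at most 1/4, which keeps the curvature of this particular loss well below 2.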
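Proof sketch. First, notice that the claimed limit ($w_*$ if $n \ge d$, $\hat w$ if $n < d$) is a global minimiser of $\mathcal{L}$: both reproduce the teacher's outputs $w_*^\top x_i$ on every training point. Moreover, when $n \ge d$, it is (almost surely with respect to $X$) the unique minimiser, and when $n < d$, it is (almost surely) the only one lying in the span of $X$ and thus potentially reachable by gradient descent from $w(0) = 0$.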
The proof consists of two parts. We prove that a) the gradient flow (3) drives the objective value towards the optimum, $\mathcal{L}(w(t)) \to 0$ as $t \to \infty$, and b) the distance between $w(t)$ and the claimed asymptote is upper-bounded in terms of the objective gap $\mathcal{L}(w(t))$, with a constant that is the same for all $t \ge 0$.
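For part a), observe that $\mathcal{L}$ is convex. For any $t \ge 0$, the time-derivative of $\mathcal{L}(w(t))$ is negative unless we are at a global minimum,
$$\frac{d}{dt}\mathcal{L}(w(t)) = -\lVert \nabla \mathcal{L}(w(t)) \rVert^2. \qquad (8)$$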
This allows us (via a technical derivation that we omit here) to relate the objective gap to the gradient norm: it can be shown that there exists a constant $c > 0$ such that
Applying the above to the gradient norm in (8), we can bound the rate of decrease of the objective in terms of the objective itself, which ultimately proves linear (i.e. exponential in $t$) convergence.
For part b), invoke (9) with suitably chosen arguments; this gives the distance bound, which combined with part a) proves the theorem. ∎
The full proof is given in the Supplementary Material.
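As a rough illustration of the linear convergence in part a), the sketch below tracks the normalized objective along gradient descent on a toy problem. It assumes that the normalization constant is the minimum of the cross-entropy, which for soft targets $y_i$ equals the mean binary entropy of the $y_i$, so the normalized loss is the mean Bernoulli KL divergence between $y_i$ and the student's output. The data, step size and recording interval are illustrative choices; each recorded value should be smaller than the previous one by at least a constant factor.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def normalized_loss(w, xs, ys):
    """Cross-entropy minus its minimum (assumed normalization): the mean Bernoulli
    KL divergence between the soft targets y_i and the student's outputs."""
    total = 0.0
    for x, y in zip(xs, ys):
        p = sigmoid(dot(w, x))
        total += y * math.log(y / p) + (1.0 - y) * math.log((1.0 - y) / (1.0 - p))
    return total / len(xs)

w_teacher = [1.0, -1.0, 0.5]
xs = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
ys = [sigmoid(dot(w_teacher, x)) for x in xs]

w, lr, history = [0.0, 0.0, 0.0], 1.0, []
for step in range(101):
    if step % 20 == 0:
        history.append(normalized_loss(w, xs, ys))  # record every 20 steps
    grad = [0.0] * len(w)
    for x, y in zip(xs, ys):
        r = (sigmoid(dot(w, x)) - y) / len(xs)
        grad = [g + r * xi for g, xi in zip(grad, x)]
    w = [wi - lr * gi for wi, gi in zip(w, grad)]

ratios = [later / earlier for earlier, later in zip(history, history[1:])]
print([f"{v:.2e}" for v in history])
print([round(r, 3) for r in ratios])
```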
The next result is the analogue of Theorem 1 for deep linear networks. Here, some technical conditions are needed because the parameters cannot all be initialized at zero.
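Let $\hat w$ be defined as in Theorem 1. Assume the student is a deep linear network ($N \ge 2$), initialized such that, for some $\epsilon > 0$, the initial weights are sufficiently small (condition (11), whose precise threshold depends on $\epsilon$), and
$$\mathcal{L}(w(0)) < \mathcal{L}(0), \qquad (12)$$
$$W_{i+1}(0)^\top W_{i+1}(0) = W_i(0)\, W_i(0)^\top \qquad (13)$$
for $i = 1, \dots, N-1$. Then, for $t \to \infty$, the student's weight vector fulfills almost surely
$$w(t) \to w_* \ \text{ if } n \ge d, \qquad \lVert w(t) - \hat w \rVert \le \epsilon \ \text{ if } n < d.$$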
The interpretation of the theorem is analogous to Theorem 1. Given enough data ($n \ge d$), the student learns to perfectly mimic the teacher. Otherwise, it learns an approximation that lies within distance $\epsilon$ of the projection of the teacher's weight vector onto the data span.
The conditions (11)–(13) appear for technical reasons, and a closer look at them shows that they do not pose problems in practice. Condition (11) states that the network's weights should be initialised with sufficiently small values, which is easy to satisfy in practice. Condition (12) requires that the initial loss is smaller than the loss at $w = 0$. This condition guarantees that the gradient flow does not hit the point $w = 0$, where all gradients vanish and the optimization would stop prematurely. In practice, when the step size is finite, the condition is not needed. Nevertheless, it is also not hard to satisfy: for any near-zero initialisation $(W_1, \dots, W_N)$, flipping the sign of a single matrix flips the sign of $w$, and since the loss is, to first order, linear near the origin, either the original or the sign-flipped initialisation will satisfy (12). So at most one has to flip the sign of one of the matrices. Finally, condition (13) is called balancedness (Arora et al., 2018) and is discussed in depth by Arora et al. (2019). It simplifies the analysis of matrix products and makes it possible to explicitly analyze the evolution of $w$ induced by gradient flow in the $W_i$'s. Assuming near-zero initialization, the condition is automatically satisfied approximately, and there is some evidence (Arora et al., 2019) suggesting that approximate balancedness may suffice for convergence results of the kind we are interested in. Otherwise, the condition can also simply be enforced numerically.
Proof sketch. First, we establish convergence in the objective, $\mathcal{L}(w(t)) \to 0$ as $t \to \infty$, similarly to the case $N = 1$. Unlike that case, however, the evolution of the end-to-end weight vector $w$ is governed by complex dynamics induced by gradient flow in the $W_i$'s. A key tool for analyzing this induced flow was recently established by Arora et al. (2018), who show that the induced flow behaves similarly to gradient flow with momentum applied directly to $w$. Making use of this result, one can proceed analogously to the case $N = 1$ to show convergence in the objective.
Second, to show convergence in parameter space, we decompose $w(t)$ into its projection onto the span of $X$ and an orthogonal component. The in-span component converges to $\hat w$, by strong convexity arguments as in the case $N = 1$. It remains to show that the orthogonal component is small. Recall that in the case $N = 1$, we initialise at $w(0) = 0$ and move within the span, so the orthogonal component is always zero. When $N \ge 2$, the situation is different: a) we initialise with a potentially non-zero orthogonal component (because we need to avoid the spurious stationary point $w = 0$), and b) the momentum term causes the orthogonal component to grow during optimisation. Luckily, the rate of growth can be precisely characterised and controlled by the norm of the initialisation, so depending on how close to zero we initialise, we can upper-bound the size of the orthogonal component. This yields a bound on the distance $\lVert w(t) - \hat w \rVert$. ∎
For the formal proof, we refer the reader to the Supplemental Material.
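The following sketch mimics the setting of Theorem 2 with the smallest possible deep student: $N = 2$ and a hidden width of 1, so that $w = b\,a$ with $W_1 = a$ (a $1 \times d$ row) and $W_2 = b$ (a scalar). This architecture, the toy data, the hyperparameters and the helper names `loss_and_grad` and `train_two_layer_student` are our own illustrative assumptions. The initialisation is small and balanced ($b^2 = \lVert a \rVert^2$, which is what (13) reduces to here), and the sign of $b$ is flipped if needed so that the initial loss is below the loss at $w = 0$, as in (12). Because the gradient with respect to $a$ always lies in the data span, the orthogonal part of $a$ never changes, so the final error is of the order of the initialisation scale rather than exactly zero.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def loss_and_grad(w, xs, ys):
    """Hypothetical helper: soft-label cross-entropy (without the normalization
    constant) and its gradient with respect to the end-to-end vector w."""
    n = len(xs)
    loss, grad = 0.0, [0.0] * len(w)
    for x, y in zip(xs, ys):
        p = sigmoid(dot(w, x))
        loss -= (y * math.log(p) + (1.0 - y) * math.log(1.0 - p)) / n
        grad = [g + (p - y) / n * xi for g, xi in zip(grad, x)]
    return loss, grad

def train_two_layer_student(xs, w_teacher, init_scale=1e-3, lr=0.1, steps=20000, seed=0):
    """Hypothetical N = 2 student with hidden width 1: w = b * a."""
    rng = random.Random(seed)
    ys = [sigmoid(dot(w_teacher, x)) for x in xs]
    a = [init_scale * rng.gauss(0.0, 1.0) for _ in w_teacher]  # W_1, small and random
    b = math.sqrt(dot(a, a))                  # W_2, balanced: b^2 = |a|^2
    loss_zero, _ = loss_and_grad([0.0] * len(a), xs, ys)
    if loss_and_grad([b * ai for ai in a], xs, ys)[0] >= loss_zero:
        b = -b                                # flip one factor so that (12) holds
    for _ in range(steps):
        _, g = loss_and_grad([b * ai for ai in a], xs, ys)
        grad_a, grad_b = [b * gi for gi in g], dot(a, g)  # chain rule through w = b * a
        a = [ai - lr * gai for ai, gai in zip(a, grad_a)]
        b -= lr * grad_b
    return [b * ai for ai in a]

w_teacher = [1.0, -1.0, 0.5]
xs = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]   # n = 2 < d = 3
w_hat = [7 / 6, -5 / 6, 1 / 3]           # projection of w_teacher onto span(xs)
w_deep = train_two_layer_student(xs, w_teacher)
err = math.sqrt(sum((u - v) ** 2 for u, v in zip(w_deep, w_hat)))
print(err)  # small, of the order of the initialisation scale
```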
4.2 How Fast Does the Student Learn?
In this section, we present our main quantitative result, a bound for the expected transfer risk in linear distillation.
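A key quantity for us is the angle $\bar\alpha(x) \in [0, \pi/2]$ between a randomly chosen $x \sim P_x$ and the line spanned by the teacher's weight vector, i.e. the angle between $x$ and whichever of $w_*$, $-w_*$ is closer. For a given transfer task $(P_x, w_*)$, we denote by $p$ the reverse cdf of $\bar\alpha$,
$$p(\theta) = \Pr_{x \sim P_x}\big[\bar\alpha(x) \ge \theta\big], \qquad \theta \in [0, \pi/2].$$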
By construction, $p$ is monotonically decreasing, starting at $p(0) = 1$ and approaching $0$ as $\theta \to \pi/2$. Figure 1 illustrates this behavior for three exemplary data distributions as Tasks A, B and C. In Task A, the probability mass is well aligned with the direction of the teacher's weight vector. The probability that a randomly chosen data point has a large angle with $w_*$ is small. Therefore, the value of $p(\theta)$ quickly drops with growing angle $\theta$. In Task B, the data also aligns well with $w_*$, but in addition, the data region remains bounded away from the decision boundary. Therefore, certain large angles can never occur, i.e. there exists a value $\theta_0 < \pi/2$ such that $p(\theta) = 0$ for $\theta > \theta_0$. In Task C, the situation is different: the data distribution is concentrated along the decision boundary, and the probability of a large angle between $w_*$ and a randomly chosen data point is large. As a consequence, $p(\theta)$ drops more slowly with growing angle than in the previous two settings.
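A minimal sketch of how $p$ can be estimated from samples, under assumptions of ours: the angle is measured to the line through $w_*$ (so it lies in $[0, \pi/2]$), and the two toy tasks in $\mathbb{R}^2$ are hypothetical stand-ins for Figure 1, one with directions uniform on the circle and one with a wedge of zero mass around the decision boundary, similar in spirit to Task B. The helper names are illustrative.

```python
import math
import random

def folded_angle(w, x):
    """Angle between x and the line spanned by w (i.e. to w or -w), in [0, pi/2]."""
    c = abs(sum(a * b for a, b in zip(w, x))) / (math.hypot(*w) * math.hypot(*x))
    return math.acos(min(1.0, c))

def empirical_p(angles, theta):
    """Empirical reverse cdf: fraction of sampled angles that are >= theta."""
    return sum(a >= theta for a in angles) / len(angles)

def sample_uniform(rng):
    # Hypothetical task: directions uniform on the circle, folded angle uniform on [0, pi/2].
    phi = rng.uniform(0.0, 2.0 * math.pi)
    return (math.cos(phi), math.sin(phi))

def sample_margin(rng):
    # Hypothetical task: directions within pi/4 of +w_teacher or -w_teacher only.
    phi = rng.uniform(-math.pi / 4, math.pi / 4) + rng.choice([0.0, math.pi])
    return (math.cos(phi), math.sin(phi))

rng = random.Random(0)
w_teacher = (1.0, 0.0)
thetas = [0.0, math.pi / 8, math.pi / 4, 3 * math.pi / 8]
curves = {}
for name, sampler in [("uniform", sample_uniform), ("margin", sample_margin)]:
    angles = [folded_angle(w_teacher, sampler(rng)) for _ in range(10000)]
    curves[name] = [empirical_p(angles, t) for t in thetas]
    print(name, [round(v, 2) for v in curves[name]])
```

For the uniform task the curve should follow $1 - 2\theta/\pi$; for the margin task it should drop to zero beyond $\pi/4$.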
We are now ready to state the main result. For improved readability, we phrase it for a student with infinitesimally small initialization, i.e. one that learns $\hat w$ exactly (the limit $\epsilon \to 0$ in Theorem 2). The general formulation can be found in the supplemental material.
Theorem 3. For any transfer task $(P_x, w_*)$, the expected transfer risk $\mathbb{E}_X[R]$ of the distillation-trained student, taken over the random training set $X$, satisfies
$$\mathbb{E}_X[R] = 0 \quad \text{if } n \ge d, \qquad (18)$$
$$\mathbb{E}_X[R] \le \min_{\beta \in [0, \pi/2]} \Big( p\big(\tfrac{\pi}{2} - \beta\big) + p(\beta)^n \Big) \quad \text{if } n < d. \qquad (19)$$
Equation (18) is unsurprising, of course, because in Section 4.1 we already established that for $n \ge d$ the student is able to perfectly mimic the teacher.
Inequality (19), however, is, to our knowledge, the first quantitative characterization of how well a student can learn via distillation.
Before we provide the proof sketch, we present two instantiations of the bound for specific classes of tasks that provide insight into how fast the right-hand side of (19) actually decreases.
The margin case. The first class of tasks we consider are tasks in which the classes are separated by an angular margin, illustrated in Figure 2 (left). These tasks are characterized by a 'wedge' of zero probability mass near the boundary (in bounded domains, this condition is, in particular, fulfilled in the classical margin situation (Schölkopf & Smola, 2002), when the classes are separated by a positive distance from each other). For these tasks, we obtain from Theorem 3 that the expected risk decays exponentially in $n$, up to $n = d$.
If there exists such that and , then
The polynomial case. The second class are tasks for which we can upper-bound $p$ by a $k$-th order polynomial. This can be done trivially for any task by setting $k = 0$, but that choice would yield a vacuous bound. Higher values of $k$ correspond to stronger assumptions on the distribution but enable better rates. Figure 2 (center, right) shows examples of polynomial distributions for different values of $k$. The special case $k = 1$ corresponds to a uniform angle distribution, whereas distributions with $k > 1$ have low probability mass near the decision boundary without necessarily exhibiting a margin.
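To generate tasks of the polynomial type, one convenient construction (our own assumption, not necessarily the one used for Figure 2) takes the upper bound with equality, $p(\theta) = (1 - 2\theta/\pi)^k$, and draws the angle by inverse-transform sampling: if $U$ is uniform on $[0, 1]$, then $\theta = \tfrac{\pi}{2}(1 - U^{1/k})$ satisfies $\Pr[\theta \ge t] = (1 - 2t/\pi)^k$. The sketch places a unit vector in $\mathbb{R}^2$ at that angle from $\pm w_*$ and checks the empirical reverse cdf at $\theta = \pi/4$; the function names are hypothetical.

```python
import math
import random

def sample_polynomial_angle(k, rng):
    """Hypothetical helper: inverse-transform sample with
    P[angle >= theta] = (1 - 2*theta/pi)**k on [0, pi/2]."""
    u = rng.random()
    return (math.pi / 2) * (1.0 - u ** (1.0 / k))

def sample_point(k, rng):
    """Unit vector in R^2 whose folded angle to w_teacher = (1, 0) follows the law above."""
    theta = sample_polynomial_angle(k, rng)
    side = rng.choice([1.0, -1.0])   # class of the point: sign of w_teacher . x
    tilt = rng.choice([1.0, -1.0])   # which side of the w_teacher axis
    return (side * math.cos(theta), tilt * math.sin(theta))

rng = random.Random(1)
theta = math.pi / 4
results = []
for k in (1, 2, 4):
    pts = [sample_point(k, rng) for _ in range(20000)]
    angles = [math.acos(min(1.0, abs(x[0]))) for x in pts]  # folded angle to (1, 0)
    empirical = sum(a >= theta for a in angles) / len(angles)
    exact = (1.0 - 2.0 * theta / math.pi) ** k
    results.append((k, empirical, exact))
    print(k, round(empirical, 3), round(exact, 3))
```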
The following corollary establishes that for tasks with polynomial behavior of , the expected risk decays essentially at a rate of or faster.
If there exists a $k$ such that $p(\theta) \le \big(1 - \tfrac{2\theta}{\pi}\big)^k$ for all $\theta \in [0, \pi/2]$, then
We apply Theorem 3 and insert the polynomial upper bound for $p$. For the case , we get
Finally, we use the inequality and the claim follows. ∎
Note that, in contrast to many results in statistical learning theory, the bounds are far from vacuous, even when only little data is available. This can best be seen in Corollary 1, where and hence is an informative upper bound for the classification error. These observations suggest that distillation operates in a very different regime from classical hard-target learning. Standard bounds usually have little to say when $n < d$ and only start to be useful when $n \gg d$. In contrast, (linear) distillation ensures perfect transfer when $n \ge d$, and non-vacuous bounds are possible even when $n < d$.
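To see how informative such a bound can be for small $n$, the following sketch compares a Monte Carlo estimate of the expected transfer risk with the quantity $\min_{\beta} \big(p(\pi/2 - \beta) + p(\beta)^n\big)$ that the decomposition in the proof of Theorem 3 (Section 4.3) yields for $n < d$. It assumes a polynomial-angle task with $p(\theta) = (1 - 2\theta/\pi)^k$ in $\mathbb{R}^d$, the teacher $w_* = e_1$, and the exact projection student $\hat w$ of Theorem 1; all parameter values and helper names are illustrative.

```python
import math
import random

def poly_p(theta, k):
    """Assumed reverse cdf of the folded angle: p(theta) = (1 - 2*theta/pi)**k."""
    return (1.0 - 2.0 * theta / math.pi) ** k

def risk_bound(n, k, grid=2000):
    """min over beta in [0, pi/2] of p(pi/2 - beta) + p(beta)**n, on a uniform grid."""
    return min(
        poly_p(math.pi / 2 - beta, k) + poly_p(beta, k) ** n
        for beta in ((math.pi / 2) * j / grid for j in range(grid + 1))
    )

def sample_x(d, k, rng):
    """Input in R^d whose folded angle to e_1 has reverse cdf poly_p(., k)."""
    theta = (math.pi / 2) * (1.0 - rng.random() ** (1.0 / k))
    u = [rng.gauss(0.0, 1.0) for _ in range(d - 1)]
    norm_u = math.sqrt(sum(v * v for v in u))
    side = rng.choice([1.0, -1.0])
    return [side * math.cos(theta)] + [math.sin(theta) * v / norm_u for v in u]

def project(w, xs):
    """Orthogonal projection of w onto span(xs), via classical Gram-Schmidt."""
    basis = []
    for x in xs:
        r = list(x)
        for q in basis:
            c = sum(ri * qi for ri, qi in zip(r, q))
            r = [ri - c * qi for ri, qi in zip(r, q)]
        nr = math.sqrt(sum(ri * ri for ri in r))
        if nr > 1e-12:
            basis.append([ri / nr for ri in r])
    out = [0.0] * len(w)
    for q in basis:
        c = sum(wi * qi for wi, qi in zip(w, q))
        out = [oi + c * qi for oi, qi in zip(out, q)]
    return out

def expected_risk_mc(n, d, k, num_sets=40, num_test=1000, seed=0):
    """Hypothetical helper: Monte Carlo estimate of E_X[R] for w_hat = projection of e_1."""
    rng = random.Random(seed)
    w_teacher = [1.0] + [0.0] * (d - 1)
    total = 0.0
    for _ in range(num_sets):
        w_hat = project(w_teacher, [sample_x(d, k, rng) for _ in range(n)])
        errors = 0
        for _ in range(num_test):
            x = sample_x(d, k, rng)
            errors += (sum(a * b for a, b in zip(w_hat, x)) >= 0) != (x[0] >= 0)
        total += errors / num_test
    return total / num_sets

n, d, k = 3, 10, 2
mc, bound = expected_risk_mc(n, d, k), risk_bound(n, k)
print(round(mc, 3), round(bound, 3))
```

Even with only three training points in ten dimensions, the bound stays well below the trivial value of 1 for this task.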
4.3 Proof of Theorem 3
The case $n \ge d$ follows trivially from Theorems 1 and 2. For the case $n < d$, the following property turns out to be crucial for obtaining a transfer rate of the form that we do.
and an analogous statement holds for . Now, because the first columns of coincide with , we have and
Taking $\arccos$ on both sides (and remembering that $\arccos$ is decreasing) yields the claim. ∎
For the moment, think of the angle between the student's and the teacher's weight vectors as a proxy for the transfer risk, i.e. the closer the trained student is to the teacher in terms of angles, the lower the transfer risk. A direct consequence of Lemma 1, and the reason we call it 'strong monotonicity', is that including additional data in the transfer set can never harm the transfer risk, only improve it. This property is specific to distillation; it does not hold in hard-target learning.
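Lemma 1 is easy to check numerically. The sketch below, with random Gaussian data and teacher chosen only for illustration, adds training points one at a time, recomputes the projection $\hat w$ of $w_*$ onto their span, and records the angle to $w_*$. It also checks the identity $\cos\alpha(\hat w, w_*) = \lVert \hat w \rVert / \lVert w_* \rVert$, which holds because $w_* - \hat w$ is orthogonal to $\hat w$ and which turns monotonicity into the fact that projections onto larger subspaces have larger norms.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def angle(u, v):
    c = dot(u, v) / math.sqrt(dot(u, u) * dot(v, v))
    return math.acos(max(-1.0, min(1.0, c)))

def project(w, xs):
    """Orthogonal projection of w onto span(xs), via classical Gram-Schmidt."""
    basis = []
    for x in xs:
        r = list(x)
        for q in basis:
            c = dot(r, q)
            r = [ri - c * qi for ri, qi in zip(r, q)]
        nr = math.sqrt(dot(r, r))
        if nr > 1e-12:
            basis.append([ri / nr for ri in r])
    out = [0.0] * len(w)
    for q in basis:
        c = dot(w, q)
        out = [oi + c * qi for oi, qi in zip(out, q)]
    return out

rng = random.Random(0)
d = 8
w_teacher = [rng.gauss(0.0, 1.0) for _ in range(d)]
xs, angles, cos_gaps = [], [], []
for _ in range(d):
    xs.append([rng.gauss(0.0, 1.0) for _ in range(d)])  # one more training point
    w_hat = project(w_teacher, xs)
    angles.append(angle(w_hat, w_teacher))
    # cos(angle) = |w_hat| / |w_teacher|, since w_teacher - w_hat is orthogonal to w_hat
    ratio = math.sqrt(dot(w_hat, w_hat) / dot(w_teacher, w_teacher))
    cos_gaps.append(abs(math.cos(angles[-1]) - ratio))
print([round(a, 3) for a in angles])  # non-increasing, reaching ~0 at n = d
```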
We decompose the expected risk as follows:
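Write $\alpha(u, v)$ for the angle between $u$ and $v$, and $\bar\alpha(x) = \min\{\alpha(x, w_*), \alpha(x, -w_*)\}$. Let us fix some $x$ for which $w_*^\top x > 0$ and $\bar\alpha(x) < \pi/2 - \beta$ (i.e. an 'easy' positive test example); for this we have $\alpha(w_*, x) < \pi/2 - \beta$. Consider the situation where $\bar\alpha(x_i) \le \beta$ for some $i$ (i.e. there is at least one good teaching point). Then, Lemma 1 applied to the single point $x_i$ and the full training set yields $\alpha(\hat w, w_*) \le \bar\alpha(x_i) \le \beta$. Combined with the triangle inequality for angles, we obtain
$$\alpha(\hat w, x) \le \alpha(\hat w, w_*) + \alpha(w_*, x) < \beta + \big(\tfrac{\pi}{2} - \beta\big) = \tfrac{\pi}{2},$$
which implies $\hat w^\top x > 0$, i.e. a correct prediction (same as the teacher's). Conversely, an error can occur only if $\bar\alpha(x_i) > \beta$ for all $i = 1, \dots, n$. Because $x_1, \dots, x_n$ are independent, we have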
By a symmetric argument, one can show that
Combining (29), (32) and (33) yields the result:
Why Does Distillation Work?
From the formal analysis in the previous section, three concepts emerge as key factors for the success of distillation: data geometry, optimization bias, and strong monotonicity. In this section, we discuss these factors and provide some empirical confirmation of how they affect or explain variations in the transfer risk.
From Theorem 3 we know that the data geometry, in particular the angular alignment between the data distribution and the teacher, crucially impacts how fast the student can learn. Formally, this is reflected in $p$: the faster it decreases, the easier it should be for the student to learn the task.
To experimentally test the effect of data geometry on the effectiveness of distillation, we adopt the setting of Corollary 2. We consider a series of tasks of varying angular alignment, as measured by the degree, $k$, of the polynomial by which $p$ is upper bounded.
We use an input space dimension of and a transfer set size . Then, we train a linear student by distillation on each of the tasks and evaluate its transfer risk on held-out data. Figure 3 shows the results. The plot shows a clearly decreasing trend: on tasks with more favorable data geometry (higher $k$), transfer via distillation is more effective and the student achieves lower risk.
5.2 Optimization Bias
A second key factor for the success of distillation is a specific optimization bias. For $n < d$, the distillation training objective (1) has many minima of identical function value but potentially different generalization properties. Therefore, the optimization method used could have a large impact on the transfer risk. As Theorems 1 and 2 show, gradient descent has a particularly favorable bias for distillation.
To verify this observation experimentally, we consider learners that are guided by an optimisation bias to different degrees: at one end of the spectrum is the gradient-descent learner we have studied in previous sections, while at the other end is a learner that treats all minimizers of the distillation training loss equally, i.e. that has no bias toward any of the solutions. Specifically, consider learners with weights of the form $w = \hat w + \lambda z$, where $\hat w$ is the gradient-descent distillation solution and $z$ is a Gaussian random vector in the subspace orthogonal to the data span, i.e. if $X$ is the data matrix, then $X^\top z = 0$. All learners of this form globally minimize the distillation training loss, and depending on $\lambda \ge 0$, they are more or less guided by the gradient-descent bias: $\lambda = 0$ and $\lambda \to \infty$ represent the two extremes mentioned above.
Figure 4 shows the result. There is a clear trend in favor of learners that are more strongly guided by the gradient-descent bias (small $\lambda$); these learners generally achieve lower transfer risk. This result supports the idea of optimization bias as a key component of distillation's success.
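The family of learners used in this experiment can be sketched as follows, under several assumptions of ours: Gaussian data and teacher in $\mathbb{R}^{20}$ with $n = 5$, the hypothetical name `biased_learner`, and evaluation on isotropic Gaussian test inputs, for which the transfer risk equals $\alpha(w, w_*)/\pi$ and can be computed exactly instead of estimated. Each learner reproduces the distillation solution's outputs on all training points, yet in this toy setup one would expect the average risk to grow with $\lambda$.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def angle(u, v):
    c = dot(u, v) / math.sqrt(dot(u, u) * dot(v, v))
    return math.acos(max(-1.0, min(1.0, c)))

def project(w, xs):
    """Orthogonal projection of w onto span(xs), via classical Gram-Schmidt."""
    basis = []
    for x in xs:
        r = list(x)
        for q in basis:
            c = dot(r, q)
            r = [ri - c * qi for ri, qi in zip(r, q)]
        nr = math.sqrt(dot(r, r))
        if nr > 1e-12:
            basis.append([ri / nr for ri in r])
    out = [0.0] * len(w)
    for q in basis:
        c = dot(w, q)
        out = [oi + c * qi for oi, qi in zip(out, q)]
    return out

rng = random.Random(0)
d, n = 20, 5
w_teacher = [rng.gauss(0.0, 1.0) for _ in range(d)]
xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
w_hat = project(w_teacher, xs)  # gradient-descent distillation solution (Theorem 1)

def biased_learner(lam):
    """Hypothetical learner w = w_hat + lam * z, with z Gaussian and X^T z = 0."""
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]
    z = [gi - gp for gi, gp in zip(g, project(g, xs))]  # remove the in-span part
    return [wi + lam * zi for wi, zi in zip(w_hat, z)]

mean_risk = {}
for lam in (0.0, 1.0, 4.0):
    risks = []
    for _ in range(200):
        w = biased_learner(lam)
        # Every such w gives the same outputs as w_hat on the training points,
        # so it attains the same (minimal) distillation loss.
        assert all(abs(dot(w, x) - dot(w_hat, x)) < 1e-9 for x in xs)
        risks.append(angle(w, w_teacher) / math.pi)  # exact risk for Gaussian test inputs
    mean_risk[lam] = sum(risks) / len(risks)
print({lam: round(r, 3) for lam, r in mean_risk.items()})
```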
5.3 Strong Monotonicity
The third key factor we identify is strong monotonicity, as established in Lemma 1: training the student on more data always leads to a better approximation of the teacher’s weight vector.
Compared to data geometry and optimisation bias, strong monotonicity is less amenable to experimental study because it is a downstream property that cannot directly be manipulated. We therefore take an indirect approach. We consider a set of learners including the gradient-descent distillation learner, the hard-target learner, and several learners with reduced optimisation bias (as in Section 5.2), and train them on the same task. For each learner, we record its expected risk calculated on a held-out set, and its monotonicity index, defined as the probability that an additional training example reduces the angle between the student's and the teacher's weight vectors rather than increasing it, i.e.
$$\Pr_{X, x}\Big[\alpha\big(w(X \cup \{x\}), w_*\big) < \alpha\big(w(X), w_*\big)\Big],$$
where the student's weight vector $w(\cdot)$ is now treated as a function of the training set. Thus, we can relate a learner's risk and its monotonicity.
We train the learners on the polynomial-angle task from Section 5.1, with and . The expected risk as well as the monotonicity index are estimated as averages over 1000 transfer sets.
The results are shown in Figure 5. There is a negative correlation between monotonicity and transfer risk, which supports the intuition of monotonicity as a desirable property and a possible explanation of distillation’s success.
However, a few reservations are in order. First, as mentioned above, monotonicity cannot easily be manipulated, so its causal effect on transfer risk remains unknown; we can only measure correlation. Second, monotonicity is binary in nature: it only captures whether an extra data point helps or not. Yet for a quantitative characterization of risk, one would have to capture by how much an extra data point helps. We leave more refined definitions of monotonicity for future work.
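A minimal sketch of the monotonicity index for two kinds of learners from this section: the projection (gradient-descent distillation) learner and a learner with reduced optimisation bias of the form $\hat w + \lambda z$. It assumes Gaussian data and teachers with $d = 20$ and $n = 5$, and counts a trial as monotone when the extra example does not increase the angle (ties have probability zero here); these choices and the function names are illustrative, not the setup behind Figure 5. By Lemma 1, the projection learner should score exactly 1.

```python
import math
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def angle(u, v):
    c = dot(u, v) / math.sqrt(dot(u, u) * dot(v, v))
    return math.acos(max(-1.0, min(1.0, c)))

def project(w, xs):
    """Orthogonal projection of w onto span(xs), via classical Gram-Schmidt."""
    basis = []
    for x in xs:
        r = list(x)
        for q in basis:
            c = dot(r, q)
            r = [ri - c * qi for ri, qi in zip(r, q)]
        nr = math.sqrt(dot(r, r))
        if nr > 1e-12:
            basis.append([ri / nr for ri in r])
    out = [0.0] * len(w)
    for q in basis:
        c = dot(w, q)
        out = [oi + c * qi for oi, qi in zip(out, q)]
    return out

def projection_learner(w_teacher, xs, rng):
    """Distillation solution from zero init (Theorem 1): projection onto span(xs)."""
    return project(w_teacher, xs)

def make_noisy_learner(lam):
    """Hypothetical learner with reduced optimisation bias: w_hat + lam * z, X^T z = 0."""
    def learner(w_teacher, xs, rng):
        w_hat = project(w_teacher, xs)
        g = [rng.gauss(0.0, 1.0) for _ in w_teacher]
        g_in = project(g, xs)
        return [wh + lam * (gi - gp) for wh, gi, gp in zip(w_hat, g, g_in)]
    return learner

def monotonicity_index(learner, d=20, n=5, trials=300, seed=0):
    """Fraction of trials in which one extra training example does not increase
    the angle between the learner's and the teacher's weight vectors."""
    rng = random.Random(seed)
    monotone = 0
    for _ in range(trials):
        w_teacher = [rng.gauss(0.0, 1.0) for _ in range(d)]
        xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n + 1)]
        before = angle(learner(w_teacher, xs[:n], rng), w_teacher)
        after = angle(learner(w_teacher, xs, rng), w_teacher)
        monotone += after <= before + 1e-12
    return monotone / trials

proj_index = monotonicity_index(projection_learner)
noisy_index = monotonicity_index(make_noisy_learner(4.0))
print(proj_index, noisy_index)
```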
Conclusion
In this work, we have formulated and studied a linear model of knowledge distillation. Within this model, we have derived a) a characterization of the solution learned by the student, b) a bound on the transfer risk, meaningful even in the low-data regime, and c) three key factors that explain the success of distillation. In doing so, we hope to have enriched both the current intuitive and theoretical understanding of distillation, both of which have only been weakly developed.
Our work paints a picture of distillation as an extremely effective method for knowledge transfer that derives its power from an optimization bias of gradient-based methods initialized near the origin, which in particular has the effect that any additionally included training point can only improve the student’s approximation of the teacher. Distillation further benefits strongly from a favorable data geometry, in particular a margin between classes.
While we have supported this picture by theory and empirical work only in the linear case, we hypothesize that similar properties also govern the behavior of distillation in the nonlinear setting. If this hypothesis turns out to be true, it would have implications for the design of transfer sets (a large teacher model being stored along with only the minimal dataset necessary for future transfer) or active learning (which samples are most informative to have labeled by the teacher). Potentially, strong monotonicity could serve as a leading design principle for new sample-efficient algorithms. We thus consider the extension to nonlinear models the main direction for future work.
References
Appendix A Properties of the Cross-Entropy Loss
The gradient of the cross-entropy loss (35) takes the form
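The global minimum of the cross-entropy loss (35) is 0, and the set of global minimisers is $\{ w : X^\top w = X^\top w_* \}$, i.e. all weight vectors that reproduce the teacher's outputs on every training point.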
Assume $X$ is full-rank. For any sublevel set , there exists such that
Consider the 2nd-order Taylor expansion of around ,
where is the Hessian of evaluated at , a point lying between and . A straightforward calculation shows that the Hessian takes the form
We will now show that there is a constant such that
for all and , so that we can claim , or consequently .
Now, let us apply to lower-bound (41):
Assume $X$ is full-rank. For any sublevel set , there exists such that
Let . (If is empty, the claim is trivially true.) Theorem A.3 applied to implies that for some ,
Appendix B Proof of Theorem 1
We will prove a supporting lemma, and then the theorem.
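The data matrix $X$ is almost surely (wrt. $P_x$) full-rank, so we can apply Corollary A.1 to lower-bound the gradient norm on the right-hand side of (53). We obtain $\frac{d}{dt}\mathcal{L}(w(t)) \le -c\,\mathcal{L}(w(t))$ for some $c > 0$ and all $t \ge 0$, or equivalently,
Integrating over $t$ yields $\mathcal{L}(w(t)) \le \mathcal{L}(w(0))\, e^{-ct}$, which proves global convergence in the objective: $\mathcal{L}(w(t)) \to 0$ as $t \to \infty$.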
Since as , the theorem follows. ∎
Appendix C Proof of Theorem 2
For the proof, we will need a result by (Arora et al., 2018), which characterises the induced flow on when running gradient descent on the component matrices .
If the balancedness condition (13) holds, then
Similarly to the case $N = 1$, we start by looking at the time-derivative of $\mathcal{L}(w(t))$,
It is non-positive, so $w(t)$ stays within the $\mathcal{L}(w(0))$-sublevel set throughout optimisation,
Also, is convex and by Assumption (12) it does not contain . We can therefore take to be the distance between and , and it follows that for .
Now, noting that $X$ is almost surely full-rank, apply Corollary A.1 to upper-bound the right-hand side of (57),
To prove convergence in parameters, we decompose the ‘error’ into orthogonal components and bound each of them separately,
It turns out that the right-hand side expression is integrable in yet another way, namely
Equating the two and integrating over yields
because .
We now bound the norm of . Starting from an orthogonal decomposition similar to (60) and applying (62) with (67), we get
Denote . By the same orthogonal decomposition, we also know that , so we can divide both sides above by ,
On the right-hand side, we now have a decreasing function of that goes to zero as . However, evaluated at our specific , it is lower-bounded by , implying an implicit upper bound for .
How do we find this bound? Suppose we find some constant such that . Then, because is decreasing, it must be the case that . One such candidate for is
(Here we have used condition (11): .) To check that indeed , start from the inequality
Taking the leftmost and rightmost expression and multiplying by yields
Finally, let us turn back to our original goal of bounding . With (60), (62), (67) and (73), we now know that
Hence, if we initialise close enough to zero, as specified by condition (11), we can ensure that
Appendix D Theorem 3 for Approximate Distillation
We extend Theorem 3 to the setting where the student learns the solution only $\epsilon$-approximately, as is the case for deep linear networks initialised as in Theorem 2. When $n \ge d$, the teacher's weight vector is recovered exactly and the transfer risk is zero, even when the student is deep. The following theorem therefore only covers the case $n < d$.
The result is very similar to Theorem 3 in the main text; the only difference is the constant that compensates for the imprecision in learning by pushing the bound up (recall that $p$ is decreasing). However, as $\epsilon$ goes to zero, so does this constant, and we recover the original bound.
For the proof, we start with a tool for controlling the angle between and . Recall that the angle is defined as
The first step is to lower-bound the inner product . To that end, we expand and rearrange to obtain
Now use the triangle relation squared to lower-bound the right-hand side of (80) and get
The left-hand side is by assumption non-negative, so we have . On this domain,
We decompose the expected risk as follows:
Let us fix some for which and ; for this we have . Consider the situation where for some . Then by the triangle inequality, Lemma D.1 and Lemma 1,
which implies , i.e. a correct prediction (same as the teacher’s). Conversely, an error can occur only if for all . Because are independent, we have
By a symmetric argument, one can show that
Combining (86), (90) and (91) yields the result. ∎