Towards Understanding Knowledge Distillation

Mary Phuong, Christoph H. Lampert

Introduction

In 2014, Hinton et al. (2014) made a surprising observation: they found it easier to train a classifier using the real-valued outputs of another classifier as target values than using actual ground-truth labels. Calling the procedure knowledge distillation, or distillation for short, they noticed the positive effect to occur even when the existing classifier (called the teacher) was trained on the same data as was used afterwards for the distillation-training of the new classifier (called the student). Since that time, the positive properties of distillation-based training have been confirmed several times: the optimization step is generally more well-behaved than the optimization step in label-based training, and it needs less, if any, regularization or specific optimization tricks. Consequently, in several fields, distillation has become a standard technique for transferring information between classifiers with different architectures, such as from deep to shallow neural networks or from ensembles of classifiers to individual ones.

While the practical benefits of distillation are beyond doubt, its theoretical justification remains almost completely unclear. Existing explanations rarely go beyond qualitative statements, e.g. claiming that learning from soft labels should be easier than learning from hard labels, or that in a multi-class setting the teacher’s output provides information about how similar different classes are to each other.

In this work, we follow a different approach. Instead of studying distillation in full generality, we restrict our attention to a simplified, analytically tractable, setting: binary classification with linear teacher and linear student (either shallow or deep linear networks). For this situation, we achieve the first quantitative results about the effectiveness of distillation-based training. Specifically, our main results are: 1) We prove a generalization bound that establishes extremely fast convergence of the risk of distillation-trained classifiers. In fact, it can reach zero risk from finite training sets. 2) We identify three key factors that explain the success of distillation: data geometry – geometric properties of the data distribution, in particular class separation, directly influence the convergence speed of the student’s risk; optimization bias – even though the distillation objective can have many optima, gradient descent optimization is guaranteed to find a particularly favorable one; and strong monotonicity – increasing the training set always decreases the risk of the student classifier.

Related Work

Ideas underpinning distillation have a long history dating back to the work of Ba & Caruana (2014); Bucilua et al. (2006); Craven & Shavlik (1996); Li et al. (2014); Liang et al. (2008). In its current and most widely known form, it was introduced by Hinton et al. (2014) in the context of neural network compression.

Since then, distillation has quickly gained popularity among practitioners and established its place in deep learning folklore. It has been found to work well across a wide range of applications, including e.g. transferring from one architecture to another (Geras et al., 2016), compression (Howard et al., 2017; Polino et al., 2018), integration with first-order logic (Hu et al., 2016) or other prior knowledge (Yu et al., 2017), learning from noisy labels (Li et al., 2017), defending against adversarial attacks (Papernot et al., 2016), training stabilization (Romero et al., 2015; Tang et al., 2016), distributed learning (Polino et al., 2018), reinforcement learning (Rusu et al., 2016) and data privacy (Celik et al., 2017).

In contrast to the empirical success, the mathematical principles underlying distillation’s effectiveness have largely remained a mystery. Only very few works examine distillation from a theoretical perspective. Lopez-Paz et al. (2016) cast distillation as a form of learning using privileged information (LUPI, Vapnik & Izmailov 2015), a learning setting in which additional per-instance information is available at training time but not at test time. However, the LUPI view concentrates on the aspect that the teacher’s supervision to the student is noise-free. This argument fails to explain, e.g., the success of distillation even when the original problem is noise-free to start with. The only other theoretical analysis we are aware of is by Urner et al. (2011), who study distillation as a form of semi-supervised learning. Specifically, they show that a two-step procedure, consisting of first training a teacher on a small labelled dataset and then training the student on a separate large dataset labelled by the teacher, can be more effective than training the student directly on the small labelled dataset. The paper’s focus is on the semi-supervised aspect, i.e. the gains from having a large unlabelled dataset.

A more distantly related topic is machine teaching (Zhu, 2015). In machine teaching, a machine learning system is trained by a human teacher, whose goal is to hand-pick as small a training set as possible, while ensuring that the machine learns a desired hypothesis. Transferring knowledge via machine teaching techniques is extremely effective: perfect transfer is often possible from a small finite teaching set (Zhu, 2013; Liu & Zhu, 2016). However, the price for this radical reduction in sample complexity is the expensive training set construction. Our work shows that, at least in the linear setting, distillation achieves a similar effectiveness with a more practical form of supervision.

Background: Linear Distillation

We allow the weight vector to be parameterised as a product of matrices, $\mathbf{w}^{\intercal}=\mathbf{W}_{N}\mathbf{W}_{N-1}\cdots\mathbf{W}_{1}$ for some $N\geq 1$. When $N\geq 2$, this parameterisation is known as a deep linear network. Although deep linear networks have no additional capacity compared to directly parameterised linear classifiers $(N=1;\ \mathbf{w}^{\intercal}=\mathbf{W}_{1})$, they induce different gradient-descent dynamics, and are often studied as a first step towards understanding deep nonlinear networks (Saxe et al., 2014; Kawaguchi, 2016; Hardt & Ma, 2017).
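The claim that a deep linear network has no additional capacity can be checked numerically. The following is a minimal sketch, assuming illustrative layer sizes and random weights (none of these values come from the paper): the layer-by-layer output coincides with the output of the collapsed weight vector.

```python
import numpy as np

rng = np.random.default_rng(0)
d, hidden = 6, 4  # illustrative sizes (assumptions, not from the paper)

# A deep linear network with N = 3 layers.
W1 = rng.normal(size=(hidden, d))
W2 = rng.normal(size=(hidden, hidden))
W3 = rng.normal(size=(1, hidden))

# End-to-end weight vector: w^T = W3 W2 W1.
w = (W3 @ W2 @ W1).ravel()

x = rng.normal(size=d)
layerwise_output = (W3 @ (W2 @ (W1 @ x))).item()  # forward pass through the layers
direct_output = float(w @ x)                      # single linear classifier
```

Both outputs agree up to floating-point error, so any difference between the two parameterisations can only come from the optimisation dynamics, not from the function class.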

where $L^{*}$ is a normalization constant, such that the minimum of $L^{1}$ is $0$. It only serves the purpose of simplifying notation and has no effect on the optimization.

The student observes the loss as a function of its parameters, i.e. the individual weight matrices,

and optimizes it via gradient descent. For the theoretical analysis, we avoid the complications of stepsize selection and adopt the notion of infinitesimal step size (for readers who are unfamiliar with gradient flows, it suffices to think of the stepsize as finite and “sufficiently small”), which turns the gradient descent procedure into a continuous gradient flow. We write $\mathbf{W}_{i}(\tau)$ for the value of the matrix $\mathbf{W}_{i}$ at time $\tau\in[0,\infty)$, with $\mathbf{W}_{i}(0)$ denoting the initial value, and $\mathbf{w}(\tau)^{\intercal}=\mathbf{W}_{N}(\tau)\cdots\mathbf{W}_{1}(\tau)$. Then, each $\mathbf{W}_{i}(\tau)$, for $i\in\{1,\dots,N\}$, evolves according to the following differential equation.

The student is trained until convergence, i.e. $\tau\to\infty$. We measure the transfer risk of the trained student, defined as the probability that its prediction differs from that of the teacher,

In Section 4.2, we will derive a bound for the transfer risk and establish how rapidly it decreases as a function of $n$.
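The transfer risk can be estimated by Monte Carlo as the fraction of held-out points on which student and teacher disagree in sign. The sketch below assumes Gaussian test data and a hypothetical helper name `transfer_risk`; neither is taken from the paper.

```python
import numpy as np

def transfer_risk(w_student, w_teacher, X_test):
    """Fraction of test points (columns of X_test) where the sign predictions differ."""
    student_pred = np.sign(w_student @ X_test)
    teacher_pred = np.sign(w_teacher @ X_test)
    return float(np.mean(student_pred != teacher_pred))

rng = np.random.default_rng(0)
d, m = 10, 5000                       # illustrative dimension and test-set size
X_test = rng.normal(size=(d, m))      # assumed data distribution
w_teacher = rng.normal(size=d)
w_student = w_teacher + 0.3 * rng.normal(size=d)  # an imperfect student

risk_same = transfer_risk(w_teacher, w_teacher, X_test)      # identical weights
risk_flipped = transfer_risk(-w_teacher, w_teacher, X_test)  # opposite weights
risk_noisy = transfer_risk(w_student, w_teacher, X_test)     # somewhere in between
```

A student identical to the teacher has zero estimated risk, and a sign-flipped student disagrees on every point, which brackets the range of the quantity being bounded.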

Generalization Properties of Linear Distillation

This section contains our main technical results. First, in Section 4.1, we provide an explicit characterization of the outcome of distillation-based training in the linear setting. In other words, we identify what the student actually learns. In particular, we prove that the student is able to perfectly identify the teacher’s weight vector, if the number of training examples ($n$) is equal to the dimensionality of the data ($d$) or higher. If less data is available, under minor assumptions, the student finds the best approximation of the teacher’s weight vector that is possible within the subspace spanned by the training data.

In Section 4.2 we use these results to study the generalization properties of the student classifier, i.e. we characterize how fast the student learns. Specifically, we prove a generalization bound with much more appealing properties than what is possible in the classic situation of learning from hard labels. As soon as enough training data is available ($n\geq d$), the student’s risk is simply $0$. Otherwise, the risk can be bounded explicitly in a distribution-dependent way that, in particular, allows us to identify three key factors that explain the success of distillation, and to understand when distillation-based transfer is most effective.

In this section, we derive in closed form the asymptotic solution to the gradient flow (3) undergone by the student when trained by distillation. We state the results separately for directly parameterized linear classifiers $(N=1)$ and deep linear networks $(N\geq 2)$, as the settings require slightly different ways of initializing parameters. Namely, in the former case, initializing $\mathbf{w}(0)=\mathbf{0}$ is valid, while in the latter case, this would lead to vanishing gradients, and we have to initialize with small (typically random) values.

Assume the student is a directly parameterised linear classifier $(N=1)$ with weight vector initialised at zero, $\mathbf{w}(0)=\mathbf{0}$. Then, the student’s weight vector fulfills almost surely

Theorem 1 shows a remarkable property of distillation-based training for linear systems: if sufficiently many (at least $d$) data points are available, the student exactly recovers the teacher’s weight vector, $\mathbf{w}_{*}$. This is a strong justification for distillation as a method of knowledge transfer between linear classifiers, and the theorem establishes that the effect occurs not just in the infinite data limit ($n\to\infty$), as one might have expected, but already in the finite sample regime ($n\geq d$).

When few data points are available ($n<d$), the weight vector learned by the student is simply the projection of the teacher’s weight vector onto the data span (the subspace spanned by the columns of $\mathbf{X}$). In a sense, this is the best the student can do: the gradient descent update direction $\frac{\partial\mathbf{w}(\tau)}{\partial\tau}$ always lies in the data span, so there is no way for the student to learn anything outside of it. The projection is the best subspace-constrained approximation of $\mathbf{w}_{*}$ with respect to the Euclidean norm. The extent to which Euclidean closeness implies closeness in predictions is a separate matter, and the subject of Section 4.2.
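The two regimes described above can be illustrated numerically. The sketch below makes several assumptions that are not spelled out in this excerpt: the distillation loss is taken to be the cross-entropy between the teacher’s sigmoid outputs and the student’s, the data are Gaussian, and the gradient flow is replaced by plain gradient descent with a small fixed step size. The function names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def distill_linear(X, w_teacher, lr=0.2, steps=20000):
    """Gradient descent from w(0) = 0 on the soft-label cross-entropy (assumed loss)."""
    d, n = X.shape
    y = sigmoid(w_teacher @ X)              # teacher's soft labels
    w = np.zeros(d)
    for _ in range(steps):
        grad = X @ (sigmoid(w @ X) - y) / n  # gradient lies in the span of X's columns
        w -= lr * grad
    return w

def project_onto_span(X, v):
    """Orthogonal projection of v onto the column span of X."""
    coeffs, *_ = np.linalg.lstsq(X, v, rcond=None)
    return X @ coeffs

rng = np.random.default_rng(0)
d = 10
w_teacher = rng.normal(size=d)
w_teacher /= np.linalg.norm(w_teacher)

X_small = rng.normal(size=(d, 3))    # n < d
X_large = rng.normal(size=(d, 60))   # n >= d

w_small = distill_linear(X_small, w_teacher)
w_large = distill_linear(X_large, w_teacher)

err_small = np.linalg.norm(w_small - project_onto_span(X_small, w_teacher))
err_large = np.linalg.norm(w_large - w_teacher)
```

In this setup the student trained on three points ends up at the projection of the teacher onto the data span, while the student trained on sixty points recovers the teacher itself, both up to numerical error.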

First, notice that $\hat{\mathbf{w}}$ is a global minimiser of $L^{1}$. Moreover, when $n\geq d$, it is (almost surely wrt. $\mathbf{X}\sim P_{\mathbf{x}}^{n}$) unique, and when $n<d$, it is (almost surely) the only one lying in the span of $\mathbf{X}$ and thus potentially reachable by gradient descent.

The proof consists of two parts. We prove that a) the gradient flow (3) drives the objective value towards the optimum, $L^{1}(\mathbf{w}(t))\to 0$ as $t\to\infty$, and b) the distance between $\mathbf{w}(t)$ and the claimed asymptote $\hat{\mathbf{w}}$ is upper-bounded by the objective gap,

for some constant $c>0$ and all $t\in[0,\infty)$.

For part a), observe that $L^{1}$ is convex. For any $\tau\in[0,\infty)$, the time-derivative of $L^{1}(\mathbf{w}(\tau))$ is negative unless we are at a global minimum,

This allows us (via a technical derivation that we omit here) to relate the objective gap to the gradient norm: it can be shown that there exists $c^{\prime}>0$, such that

Applying the above to $\mathbf{w}(\tau)$ in (8), we are able to bound the amount of reduction in the objective in terms of the objective itself, ultimately proving linear convergence.

For part b), invoke (9) with $\mathbf{v}=\mathbf{w}(\tau)$ and $\mathbf{w}=\hat{\mathbf{w}}$; this gives $L^{1}(\mathbf{w}(\tau))\geq\frac{\mu}{2}\left\|\mathbf{w}(\tau)-\hat{\mathbf{w}}\right\|^{2}$.

The full proof is given in the Supplementary Material.

The next result is the analog of Theorem 1 for deep linear networks. Here, some technical conditions are needed because the parameters cannot all be initialized at zero.

Let $\hat{\mathbf{w}}$ be defined as in Theorem 1. Assume the student is a deep linear network, initialized such that for some $\epsilon>0$,

for $j=1,\dots,N-1$. Then, for $n\geq d$, the student’s weight vector fulfills almost surely

The interpretation of the theorem is analogous to Theorem 1. Given enough data ($n\geq d$), the student learns to perfectly mimic the teacher. Otherwise, it learns an approximation at least $\epsilon$-close to the projection of the teacher’s weight vector onto the data span.

The conditions (11)–(13) appear for technical reasons, and a closer look at them shows that they do not pose problems in practice. Condition (11) states that the network’s weights should be initialised with sufficiently small values, so it is easy to satisfy in practice. Condition (12) requires that the initial loss is smaller than the loss at $\mathbf{w}=\mathbf{0}$. This condition guarantees that the gradient flow does not hit the point $\mathbf{w}=\mathbf{0}$, where all gradients vanish and the optimization would stop prematurely. In practice, when the step size is finite, the condition is not needed. Nevertheless, it is also not hard to satisfy: for any near-zero initialisation, $\mathbf{w}(0)=\mathbf{w}_{0}$, either $\mathbf{w}_{0}$ or $-\mathbf{w}_{0}$ will satisfy (12), so at most one has to flip the sign of one of the $\mathbf{W}_{i}(0)$ matrices. Finally, condition (13) is called balancedness (Arora et al., 2018) and is discussed in depth in (Arora et al., 2019). It simplifies the analysis of matrix products and makes it possible to explicitly analyze the evolution of $\mathbf{w}$ induced by gradient flow in the $\mathbf{W}_{i}$’s. Assuming near-zero initialization, the condition is automatically satisfied approximately and there is some evidence (Arora et al., 2019) suggesting that approximate balancedness may suffice for convergence results of the kind we are interested in. Otherwise, the condition can also simply be enforced numerically.
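The sign-flip argument for condition (12) can be sketched in a few lines. The example assumes a two-layer linear student, Gaussian data, the soft-label cross-entropy as the loss (the normalization constant is omitted because it cancels in the comparison), and an initialisation scale of `1e-3`; all of these are illustrative choices rather than values from the paper.

```python
import numpy as np

def soft_cross_entropy(w, X, y):
    """Cross-entropy between teacher soft labels y and sigmoid(w^T x_i), averaged over i."""
    z = w @ X
    # -log sigmoid(z) = logaddexp(0, -z) and -log(1 - sigmoid(z)) = logaddexp(0, z)
    return float(np.mean(y * np.logaddexp(0.0, -z) + (1.0 - y) * np.logaddexp(0.0, z)))

def end_to_end(W2, W1):
    """Collapse a two-layer linear network into its weight vector w^T = W2 W1."""
    return (W2 @ W1).ravel()

rng = np.random.default_rng(0)
d, hidden, n = 20, 8, 5
X = rng.normal(size=(d, n))
w_teacher = rng.normal(size=d)
y = 1.0 / (1.0 + np.exp(-(w_teacher @ X)))   # teacher's soft labels

scale = 1e-3                                  # near-zero initialisation (assumed scale)
W1 = scale * rng.normal(size=(hidden, d))
W2 = scale * rng.normal(size=(1, hidden))

loss_at_zero = soft_cross_entropy(np.zeros(d), X, y)
if soft_cross_entropy(end_to_end(W2, W1), X, y) >= loss_at_zero:
    W2 = -W2                                  # flipping one factor flips the sign of w(0)
loss_at_init = soft_cross_entropy(end_to_end(W2, W1), X, y)
```

Flipping a single factor is enough because the end-to-end weight vector is linear in each $\mathbf{W}_{i}$ separately, so negating one matrix negates the whole product.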

First, we establish convergence in the objective, $L^{1}(\mathbf{w}(t))\to 0$ as $t\to\infty$, similarly to the case $N=1$. Unlike that case, however, the evolution of the end-to-end weight vector $\mathbf{w}(\tau)$ is governed by complex mechanics induced by gradient flow in the $\mathbf{W}_{i}$’s. A key tool for analyzing this induced flow was recently established in (Arora et al., 2018): the authors show that the induced flow behaves similarly to gradient flow with momentum applied directly to $\mathbf{w}$. Making use of this result, one can proceed analogously to the case $N=1$ to show convergence in the objective.

Second, to show convergence in parameter space, we decompose $\mathbf{w}(t)$ into its projection onto the span of $\mathbf{X}$, and an orthogonal component. The $\mathbf{X}$-component converges to $\hat{\mathbf{w}}$, by strong convexity arguments as in the case $N=1$. It remains to show that the orthogonal component is small. Now, recall that in the case $N=1$, we initialise at $\mathbf{w}(0)=\mathbf{0}$ and move within the span, so the orthogonal component is always zero. When $N\geq 2$, the situation is different: a) we initialise with a potentially non-zero orthogonal component (because we need to avoid the spurious stationary point $\mathbf{w}=\mathbf{0}$), and b) the momentum term causes the orthogonal component to grow during optimisation. Luckily, the rate of growth can be precisely characterised and controlled by the initialisation norm $\left\|\mathbf{w}(0)\right\|$, so depending on how close to zero we initialise, we can upper-bound the size of the orthogonal component. This yields a bound on the distance $\left\|\mathbf{w}(t)-\hat{\mathbf{w}}\right\|$. ∎

For the formal proof, we refer the reader to the Supplemental Material.

4.2 How Fast Does the Student Learn?

In this section, we present our main quantitative result, a bound for the expected transfer risk in linear distillation.

A key quantity for us is the angle between $\mathbf{w}_{*}$ and a randomly chosen $\mathbf{x}$, for $\mathbf{x}\sim P_{\mathbf{x}}$. For a given transfer task $(P_{\mathbf{x}},\mathbf{w}_{*})$, we denote by $p$ the reverse cdf of $\bar{\alpha}(\mathbf{w}_{*},\mathbf{x})$,

By construction, $p(\theta)$ is monotonically decreasing, starting at $p(0)=1$ and approaching $0$ for $\theta\to\pi/2$. Figure 1 illustrates this behavior for three exemplary data distributions, Tasks A, B and C. In Task A, the probability mass is well aligned with the direction of the teacher’s weight vector. The probability that a randomly chosen data point $\mathbf{x}\sim P_{\mathbf{x}}$ has a large angle with $\mathbf{w}_{*}$ is small. Therefore, the value of $p(\theta)$ quickly drops with growing angle $\theta$. In Task B, the data also aligns well with $\mathbf{w}_{*}$, but in addition, the data region remains bounded away from the decision boundary. Therefore, certain large angles can never occur, i.e. there exists a value $\theta_{0}<\pi/2$, such that $p(\theta)=0$ for $\theta\geq\theta_{0}$. In Task C, the situation is different: the data distribution is concentrated along the decision boundary and the probability of a large angle between $\mathbf{w}_{*}$ and a randomly chosen data point $\mathbf{x}\sim P_{\mathbf{x}}$ is large. As a consequence, $p(\theta)$ drops more slowly with growing angle than in the previous two settings.
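An empirical version of $p(\theta)$ can be computed from samples. The sketch below assumes that $\bar{\alpha}$ is the sign-insensitive angle in $[0,\pi/2]$ between two vectors and that $p(\theta)$ is the probability of $\bar{\alpha}(\mathbf{w}_{*},\mathbf{x})\geq\theta$; the defining equation is not reproduced in this excerpt, so this reading is an assumption, as are the Gaussian data and the function names.

```python
import numpy as np

def acute_angle(w, X):
    """Angle in [0, pi/2] between w and each column of X, ignoring orientation."""
    cos = np.abs(w @ X) / (np.linalg.norm(w) * np.linalg.norm(X, axis=0))
    return np.arccos(np.clip(cos, 0.0, 1.0))

def reverse_cdf(angles, thetas):
    """Empirical p(theta): fraction of sampled angles that are at least theta."""
    return np.array([np.mean(angles >= t) for t in thetas])

rng = np.random.default_rng(0)
d, m = 3, 10000
X = rng.normal(size=(d, m))          # assumed isotropic data distribution
w_teacher = rng.normal(size=d)

thetas = np.linspace(0.0, np.pi / 2, 50)
p_hat = reverse_cdf(acute_angle(w_teacher, X), thetas)
```

Plotting `p_hat` against `thetas` for differently shaped data distributions would give curves of the kind described for Tasks A, B and C.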

We are now ready to state the main result. For improved readability, we phrase it for a student with infinitesimally small initialization, i.e. $\epsilon\to 0$. The general formulation can be found in the supplemental material.

Equation (18) is unsurprising, of course, because in Section 4.1 we already established that for $n\geq d$ the student is able to perfectly mimic the teacher.

Inequality (19), however, is, to our knowledge, the first quantitative characterization of how well a student can learn via distillation.

Before we provide the proof sketch, we present two instantiations of the bound for specific classes of tasks that provide insight into how fast the right-hand side of (19) actually decreases.

The margin case. The first class of tasks we consider are tasks in which the classes are separated by an angular margin, illustrated in Figure 2 (left). These tasks are characterized by a ‘wedge’ of zero probability mass near the boundary. (In bounded domains this condition is, in particular, fulfilled in the classical margin situation (Schölkopf & Smola, 2002), when the classes are separated by a positive distance from each other.) For these tasks, we obtain from Theorem 3 that the expected risk decays exponentially in $n$, up to $n=d-1$.

If there exists $\beta\in[0,\pi/2]$ such that $p(\beta)=0$ and $\gamma:=p(\pi/2-\beta)<1$, then

The polynomial case. The second class are tasks for which we can upper-bound $p$ by a $\kappa$-order polynomial. This can be done trivially for any task by setting $\kappa=0.0$, but that choice would yield a vacuous bound. Higher values of $\kappa$ correspond to stronger assumptions on the distribution but enable better rates. Figure 2 (center, right) shows examples of polynomial distributions for $\kappa\in\{1.0,2.0\}$. The special case $\kappa=1.0$ corresponds to a uniform angle distribution, while distributions with $\kappa=2.0$ have low probability mass near the decision boundary without necessarily exhibiting a margin.

The following corollary establishes that for tasks with polynomial behavior of $p(\theta)$, the expected risk decays essentially at a rate of $(\log n/n)^{\kappa}$ or faster.

If there exists $\kappa\geq 0$ such that $p(\theta)\leq c\cdot(1-(2/\pi)\theta)^{\kappa}$ for all $\theta\in[0,\pi/2]$, then

We apply Theorem 3 and insert the polynomial upper bound for $p$. For the case $n<d$, we get

Finally, we use the inequality $e^{x}\geq 1+x$ and the claim follows. ∎

Note that, in contrast to many results in statistical learning theory, the bounds are far from vacuous, even when only little data is available. This can best be seen in Corollary 1, where $\gamma<1$ and hence $\gamma^{n}$ is an informative upper bound for the classification error. These observations suggest that distillation operates in a very different regime from classical hard-target learning. Standard bounds usually have little to say when $n<d$ and only start to be useful when $n\gg d$. In contrast, (linear) distillation ensures perfect transfer when $n\geq d$ and non-vacuous bounds are possible even when $n<d$.
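To get a feeling for the numbers, the bound can be evaluated on a grid. The display form of inequality (19) is not reproduced in this excerpt, so the sketch below assumes the form $\min_{\beta}\,[\,p(\beta)+p(\pi/2-\beta)^{n}\,]$ for $n<d$, which is how the proof sketch in the next section decomposes the risk. The polynomial envelope is the one from Corollary 2; the function names, grid resolution and the margin value `gamma` are illustrative assumptions.

```python
import numpy as np

def poly_p(theta, kappa, c=1.0):
    """Polynomial envelope c * (1 - (2/pi) * theta)^kappa, capped at 1 since p is a probability."""
    base = np.clip(1.0 - (2.0 / np.pi) * theta, 0.0, 1.0)
    return np.minimum(1.0, c * base ** kappa)

def risk_bound(n, kappa, c=1.0, grid=2001):
    """Assumed bound for n < d: min over beta of p(beta) + p(pi/2 - beta)^n."""
    betas = np.linspace(0.0, np.pi / 2, grid)
    values = poly_p(betas, kappa, c) + poly_p(np.pi / 2 - betas, kappa, c) ** n
    return float(values.min())

bound_k0 = risk_bound(n=10, kappa=0.0)    # trivial envelope, expected to be vacuous
bound_k1 = risk_bound(n=10, kappa=1.0)
bound_k2 = risk_bound(n=10, kappa=2.0)
bound_k2_more_data = risk_bound(n=20, kappa=2.0)

# Margin case (Corollary 1): the bound is gamma^n, with gamma < 1 an assumed value.
gamma = 0.8
margin_bounds = [gamma ** n for n in (1, 5, 10, 20)]
```

Under these assumptions the bound tightens both with larger $\kappa$ and with larger $n$, and the margin case decays geometrically, in line with the qualitative discussion above.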

4.3 Proof of Theorem 3

The case $n\geq d$ follows trivially from the results of Theorems 1 and 2. For the case $n<d$, the following property turns out to be crucial for obtaining a transfer rate of the form that we do.

and an analogous statement holds for $\mathbf{X}_{-}$. Now, because the first $n_{-}$ columns of $\mathbf{Q}_{+}$ coincide with $\mathbf{Q}_{-}$, we have $\left\|\mathbf{Q}_{+}^{\intercal}\mathbf{w}_{*}\right\|\geq\left\|\mathbf{Q}_{-}^{\intercal}\mathbf{w}_{*}\right\|$ and

Taking $\cos^{-1}$ on both sides (and remembering that $\cos^{-1}$ is decreasing) yields the claim. ∎

For the moment, think of $\bar{\alpha}(\mathbf{w}_{*},\hat{\mathbf{w}})$ as a proxy for the transfer risk, i.e. the closer the trained student $\hat{\mathbf{w}}$ is to the teacher $\mathbf{w}_{*}$ in terms of angles, the lower the transfer risk. A direct consequence of Lemma 1, and the reason we call it ‘strong monotonicity’, is that including additional data in the transfer set can never harm the transfer risk, only improve it. This property is specific to distillation; it does not hold in hard-target learning.
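The monotonicity property can be checked directly with the QR-based argument used in the proof. The sketch below assumes Gaussian data and takes the distilled student to be the projection of the teacher onto the data span (the $n<d$ case of Theorem 1); the helper name is hypothetical.

```python
import numpy as np

def angle_to_projection(X, w):
    """Angle between w and its orthogonal projection onto the column span of X."""
    Q, _ = np.linalg.qr(X)                       # orthonormal basis of the span
    cos = np.linalg.norm(Q.T @ w) / np.linalg.norm(w)
    return float(np.arccos(np.clip(cos, 0.0, 1.0)))

rng = np.random.default_rng(0)
d = 30
X = rng.normal(size=(d, d - 1))                  # nested transfer sets use its first k columns
w_teacher = rng.normal(size=d)

# Angle between teacher and student as the transfer set grows one point at a time.
angles = [angle_to_projection(X[:, :k], w_teacher) for k in range(1, d)]
```

The resulting sequence of angles never increases as columns are added, which is the nested-subspace statement of Lemma 1 in numerical form.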

We decompose the expected risk as follows:

Let us fix some $\mathbf{x}$ for which $\bar{\alpha}(\mathbf{w}_{*},\mathbf{x})<\beta$ and $\mathbf{w}_{*}^{\intercal}\mathbf{x}>0$ (i.e. an ‘easy’ positive test example); for this $\mathbf{x}$ we have $\alpha(\mathbf{w}_{*},\mathbf{x})=\bar{\alpha}(\mathbf{w}_{*},\mathbf{x})$. Consider the situation where $\bar{\alpha}(\mathbf{w}_{*},\mathbf{x}_{i})<\pi/2-\beta$ for some $i$ (i.e. there is at least one good teaching point). Then, Lemma 1 with $\mathbf{X}_{+}=\mathbf{X}$ and $\mathbf{X}_{-}=\mathbf{x}_{i}$ yields $\bar{\alpha}(\mathbf{w}_{*},\hat{\mathbf{w}})\leq\bar{\alpha}(\mathbf{w}_{*},\mathbf{x}_{i})<\pi/2-\beta$. Combined with the triangle inequality, we obtain

which implies $\hat{\mathbf{w}}^{\intercal}\mathbf{x}>0$, i.e. a correct prediction (same as the teacher’s). Conversely, an error can occur only if $\bar{\alpha}(\mathbf{w}_{*},\mathbf{x}_{i})\geq\pi/2-\beta$ for all $i$. Because the $\mathbf{x}_{i}$ are independent, we have

By a symmetric argument, one can show that

Combining (29), (32) and (33) yields the result:

Why Does Distillation Work?

From the formal analysis in the previous section, three concepts emerge as key factors for the success of distillation: data geometry, optimization bias, and strong monotonicity. In this section, we discuss these factors and provide some empirical confirmation of how they affect or explain variations in the transfer risk.

From Theorem 3 we know that the data geometry, in particular the angular alignment between the data distribution and the teacher, crucially impacts how fast the student can learn. Formally, this is reflected in $p(\theta)$: the faster it decreases, the easier it should be for the student to learn the task.

To experimentally test the effect of data geometry on the effectiveness of distillation, we adopt the setting of Corollary 2. We consider a series of tasks of varying angular alignment, as measured by the degree, $\kappa$, of the polynomial by which $p(\theta)$ is upper bounded.

We use an input space dimension of $d=1000$ and a transfer set size $n=20$. Then, we train a linear student by distillation on each of the tasks and evaluate its transfer risk on held-out data. Figure 3 shows the results. The plot shows a clearly decreasing trend: on tasks with more favorable data geometry (higher $\kappa$), transfer via distillation is more effective and the student achieves lower risk.

5.2 Optimization Bias

A second key factor for the success of distillation is a specific optimization bias. For $n<d$, the distillation training objective (1) has many minima of identical function value but potentially different generalization properties. Therefore, the optimization method used could have a large impact on the transfer risk. As Theorems 1 and 2 show, gradient descent has a particularly favorable bias for distillation.

To verify this observation experimentally, we consider learners that are guided by an optimisation bias to different degrees: at one end of the spectrum is the gradient-descent learner we have studied in previous sections, while at the other end is a learner that treats all minimizers of the distillation training loss equally, i.e. that has no bias toward any of the solutions. Specifically, consider learners with weights of the form $\mathbf{w}_{\delta}=\hat{\mathbf{w}}+\delta\frac{\left\|\hat{\mathbf{w}}\right\|}{\left\|\mathbf{q}\right\|}\mathbf{q}$, where $\hat{\mathbf{w}}$ is the gradient-descent distillation solution and $\mathbf{q}$ is a Gaussian random vector in the subspace orthogonal to the data span, i.e. if $\mathbf{X}$ is the data matrix, then $\mathbf{X}^{\intercal}\mathbf{q}=\mathbf{0}$. All learners of this form globally minimize the distillation training loss, and depending on $\delta$, they are more or less guided by the gradient-descent bias: $\delta=0$ and $\left|\delta\right|\to\infty$ represent the two extremes mentioned above.
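The family of learners $\mathbf{w}_{\delta}$ can be constructed as follows. This is a minimal sketch assuming Gaussian data and using the projection of the teacher onto the data span as a stand-in for the gradient-descent solution $\hat{\mathbf{w}}$ (which is what Theorem 1 predicts for $n<d$); the function name and sizes are illustrative.

```python
import numpy as np

def perturbed_learner(w_hat, Q, delta, rng):
    """Return w_hat + delta * (||w_hat|| / ||q||) * q with q orthogonal to the data span."""
    g = rng.normal(size=w_hat.shape)
    q = g - Q @ (Q.T @ g)          # remove the data-span component, so that X^T q = 0
    return w_hat + delta * np.linalg.norm(w_hat) / np.linalg.norm(q) * q

rng = np.random.default_rng(0)
d, n = 50, 10
X = rng.normal(size=(d, n))
w_teacher = rng.normal(size=d)

Q, _ = np.linalg.qr(X)             # orthonormal basis of the data span
w_hat = Q @ (Q.T @ w_teacher)      # stand-in for the gradient-descent solution

learners = {delta: perturbed_learner(w_hat, Q, delta, rng) for delta in (0.0, 0.5, 5.0)}
```

Every learner in `learners` produces the same outputs on the training points as `w_hat`, and hence the same training loss; they differ only in the component that the training data cannot see.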

Figure 4 shows the result. There is a clear trend in favor of learners that are more strongly guided by the gradient-descent bias (small $\delta$); these learners generally achieve lower transfer risk. This result supports the idea of optimization bias as a key component of distillation’s success.

5.3 Strong Monotonicity

The third key factor we identify is strong monotonicity, as established in Lemma 1: training the student on more data always leads to a better approximation of the teacher’s weight vector.

Compared to data geometry and optimisation bias, strong monotonicity is less amenable to experimental study because it is a downstream property that cannot directly be manipulated. We therefore take an indirect approach. We consider a set of learners including the gradient-descent distillation learner, the hard-target learner, and several learners with reduced optimisation bias (as in Section 5.2), and train them on the same task. For each learner, we note its expected risk calculated on a held-out set, and its monotonicity index, defined as the probability that an additional training example reduces the angle between the student’s and the teacher’s weight vectors rather than increasing it, i.e.

where the student’s weight vector $\mathbf{w}$ is now treated as a function of the training set. Thus, we can relate a learner’s risk and its monotonicity.
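A Monte Carlo estimate of the monotonicity index might look as follows. The sketch assumes Gaussian data instead of the polynomial-angle task, counts a trial as a success when the angle does not increase (up to a small tolerance), and represents the gradient-descent distillation learner by the projection of the teacher onto the data span; all names and settings are illustrative.

```python
import numpy as np

def angle(u, v):
    cos = u @ v / (np.linalg.norm(u) * np.linalg.norm(v))
    return float(np.arccos(np.clip(cos, -1.0, 1.0)))

def projection_learner(X, w_teacher):
    """Stand-in for the distilled student: projection of the teacher onto span(X)."""
    coeffs, *_ = np.linalg.lstsq(X, w_teacher, rcond=None)
    return X @ coeffs

def monotonicity_index(learner, w_teacher, d, n, trials, rng, tol=1e-9):
    """Fraction of trials in which adding one training point does not increase the angle."""
    hits = 0
    for _ in range(trials):
        X = rng.normal(size=(d, n + 1))
        before = angle(learner(X[:, :n], w_teacher), w_teacher)
        after = angle(learner(X, w_teacher), w_teacher)
        hits += after <= before + tol
    return hits / trials

rng = np.random.default_rng(0)
d, n = 100, 5
w_teacher = rng.normal(size=d)
index = monotonicity_index(projection_learner, w_teacher, d, n, trials=200, rng=rng)
```

For the projection learner the index equals one, as Lemma 1 requires; other learners can be plugged in through the `learner` argument to compare their index against their transfer risk.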

We train the learners on the polynomial-angle task $(P_{\mathbf{x}}^{\kappa},\mathbf{w}_{*}^{\kappa})$ from Section 5.1, with $\kappa=1$, $d=100$ and $n=5$. The expected risk as well as the monotonicity index are estimated as averages over 1000 transfer sets.

The results are shown in Figure 5. There is a negative correlation between monotonicity and transfer risk, which supports the intuition of monotonicity as a desirable property and a possible explanation of distillation’s success.

However, a few reservations are in order. First, as mentioned above, monotonicity cannot easily be manipulated, so its effect on transfer risk remains unknown. We can only measure correlation. Second, monotonicity is of binary nature; it only captures whether an extra data point helps or not. Yet for a quantitative characterization of risk, one would have to capture by how much an extra data point helps. We leave more refined definitions of monotonicity for future work.

Conclusion

In this work, we have formulated and studied a linear model of knowledge distillation. Within this model, we have derived a) a characterization of the solution learned by the student, b) a bound on the transfer risk, meaningful even in the low-data regime, and c) three key factors that explain the success of distillation. In doing so, we hope to have enriched both the current intuitive and theoretical understanding of distillation, both of which have only been weakly developed.

Our work paints a picture of distillation as an extremely effective method for knowledge transfer that derives its power from an optimization bias of gradient-based methods initialized near the origin, which in particular has the effect that any additionally included training point can only improve the student’s approximation of the teacher. Distillation further benefits strongly from a favorable data geometry, in particular a margin between classes.

While we have supported this picture by theory and empirical work only in the linear case, we hypothesize that similar properties also govern the behavior of distillation in the nonlinear setting. If this hypothesis turns out to be true, it would have implications for the design of transfer sets (a large teacher model being stored along with only the minimal dataset necessary for future transfer) or active learning (which samples are most informative to have labeled by the teacher). Potentially, strong monotonicity could serve as a leading design principle for new sample-efficient algorithms. We thus consider the extension to nonlinear models the main direction for future work.

References

Appendix A Properties of the Cross-Entropy Loss

The gradient of the cross-entropy loss (35) takes the form

The global minimum of the cross-entropy loss (35) is 0 and the set of global minimisers is

Assume $\mathbf{X}$ is full-rank. For any sublevel set $\mathcal{W}=\{\mathbf{w}:L^{1}(\mathbf{w})\leq l\}$, there exists $\mu>0$ such that

Consider the 2nd-order Taylor expansion of $L^{1}$ around $\mathbf{w}$,

where $\nabla^{2}L^{1}(\bar{\mathbf{w}})$ is the Hessian of $L^{1}$ evaluated at $\bar{\mathbf{w}}$, a point lying between $\mathbf{v}$ and $\mathbf{w}$. A straightforward calculation shows that the Hessian takes the form

We will now show that there is a constant $\omega>0$ such that

for all $\bar{\mathbf{w}}\in\mathcal{W}$ and $i\in\{1,\dots,n\}$, so that we can claim $\mathbf{D}_{\bar{\mathbf{w}}}\succeq\omega\mathbf{I}$, or consequently $\nabla^{2}L^{1}(\bar{\mathbf{w}})\succeq\omega\mathbf{X}\mathbf{X}^{\intercal}$.

Now, let us apply $\nabla^{2}L^{1}(\mathbf{w})\succeq\omega\mathbf{X}\mathbf{X}^{\intercal}$ to lower-bound (41):

Assume $\mathbf{X}$ is full-rank. For any sublevel set $\mathcal{W}=\{\mathbf{w}:L^{1}(\mathbf{w})\leq l\}$, there exists $c>0$ such that

Let $\mathbf{w}\in\mathcal{W}$. (If $\mathcal{W}$ is empty, the claim is trivially true.) Theorem A.3 applied to $\mathcal{W}$ implies that for some $\mu>0$,

Appendix B Proof of Theorem 1

We will prove a supporting lemma, and then the theorem.

The data matrix $\mathbf{X}$ is almost surely (with respect to $\mathbf{X}\sim P_{\mathbf{x}}^{n}$) full-rank. We can therefore apply Corollary A.1 to $\mathcal{W}=\{\mathbf{w}:L^{1}(\mathbf{w})\leq L^{1}(\mathbf{0})\}$ and $\mathbf{w}(\tau)$ to lower-bound the gradient norm on the right-hand side of (53). We obtain $L^{\prime}(\tau)\leq-cL(\tau)$ for some $c>0$ and all $\tau\in[0,\infty)$, or equivalently,

Integrating over $[0,t]$ yields $L(t)\leq L(0)\cdot e^{-ct}$, which proves global convergence in the objective: $L(t)\to 0$ as $t\to\infty$.
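For completeness, the integration step can be written out as follows. This is a standard Grönwall-type argument, stated under the assumption that $L(\tau)>0$ on $[0,t]$ (if $L$ reaches zero at some time, it stays there because it is non-negative and non-increasing, and the bound holds trivially).

```latex
\begin{align*}
L'(\tau) \le -c\,L(\tau)
  &\;\Longrightarrow\;
  \frac{\mathrm{d}}{\mathrm{d}\tau}\log L(\tau)
  = \frac{L'(\tau)}{L(\tau)} \le -c, \\
\log L(t) - \log L(0)
  &= \int_{0}^{t} \frac{\mathrm{d}}{\mathrm{d}\tau}\log L(\tau)\,\mathrm{d}\tau
  \le -ct, \\
L(t) &\le L(0)\,e^{-ct}.
\end{align*}
```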

Since $L(t)\to 0$ as $t\to\infty$, the theorem follows. ∎
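The convergence behaviour can be illustrated with a small simulation. The sketch below is not the gradient flow analysed in the proof: it uses discrete gradient descent as a stand-in, and it assumes a cross-entropy loss with teacher soft labels $y_i=\sigma(\mathbf{w}_{*}^{\intercal}\mathbf{x}_i)$, a zero-initialised linear student, and Gaussian data. It then compares the result with the projection $\hat{\mathbf{w}}=\mathbf{X}(\mathbf{X}^{\intercal}\mathbf{X})^{-1}\mathbf{X}^{\intercal}\mathbf{w}_{*}$. All variable names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
d, n = 20, 5                                   # fewer samples than dimensions
X = rng.standard_normal((d, n))                # columns are training inputs x_i
w_star = rng.standard_normal(d) / np.sqrt(d)   # hypothetical teacher weights
y = sigmoid(X.T @ w_star)                      # teacher soft labels

# Step size 1/beta, where beta upper-bounds the Hessian norm of the
# assumed loss (the sigmoid derivative is at most 1/4).
beta = np.linalg.eigvalsh(X @ X.T).max() / (4 * n)
lr = 1.0 / beta

w = np.zeros(d)                                # zero initialisation
for _ in range(20000):
    w -= lr * X @ (sigmoid(X.T @ w) - y) / n   # gradient step on the loss

# Projection of the teacher onto the span of the training inputs.
w_hat = X @ np.linalg.solve(X.T @ X, X.T @ w_star)

param_error = np.linalg.norm(w - w_hat)
logit_error = np.max(np.abs(X.T @ (w - w_star)))
print(param_error, logit_error)
```

Because the iterates start at zero and every gradient lies in the span of the columns of $\mathbf{X}$, the student can only match the teacher within that span, which is why the comparison is against $\hat{\mathbf{w}}$ rather than $\mathbf{w}_{*}$.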

Appendix C Proof of Theorem 2

For the proof, we will need a result by Arora et al. (2018), which characterises the induced flow on $\mathbf{w}(\tau)$ when running gradient descent on the component matrices $\mathbf{W}_{i}$.

If the balancedness condition (13) holds, then

Similarly to the case $N=1$, we start by looking at the time-derivative of $L$,

It is non-positive, so $\mathbf{w}(\tau)$ stays within the $L(0)$-sublevel set throughout optimisation,

Also, $\mathcal{W}$ is convex and by Assumption (12) it does not contain $\mathbf{0}$. We can therefore take $\delta>0$ to be the distance between $\mathcal{W}$ and $\mathbf{0}$, and it follows that $\|\mathbf{w}(\tau)\|\geq\delta$ for $\tau\in[0,\infty)$.

Now, noting that $\mathbf{X}$ is almost surely full-rank, apply Corollary A.1 to $\mathcal{W}$ and $\mathbf{w}(\tau)$ to upper-bound the right-hand side of (57),

To prove convergence in parameters, we decompose the ‘error’ $\mathbf{w}(\tau)-\hat{\mathbf{w}}$ into orthogonal components and bound each of them separately,

It turns out that the right-hand side expression is integrable in yet another way, namely

Equating the two and integrating over [0,t][0,t] yields

because $q(0)\leq\|\mathbf{w}(0)\|^{2}$.

We now bound the norm of $\mathbf{w}(t)$. Starting from an orthogonal decomposition similar to (60) and applying (62) with (67), we get

Denote $\nu:=\limsup_{t\to\infty}\|\mathbf{w}(t)\|$. By the same orthogonal decomposition, we also know that $\nu^{2}\geq\limsup_{t\to\infty}\|\mathbf{P}_{\mathbf{X}}\mathbf{w}(t)\|^{2}=\|\hat{\mathbf{w}}\|^{2}>0$, so we can divide both sides above by $\nu^{2}$,

On the right-hand side, we now have a decreasing function of $\nu$ that goes to zero as $\nu\to\infty$. However, evaluated at our specific $\nu$, it is lower-bounded by $1$, implying an implicit upper bound for $\nu$.

How do we find this bound? Suppose we find some constant $K$ such that $f(K)\leq 1$. Then, because $f$ is decreasing, it must be the case that $\nu\leq K$. One such candidate for $K$ is

(Here we have used condition (11): $\|\mathbf{w}(0)\|<\|\hat{\mathbf{w}}\|$.) To check that indeed $f(K)\leq 1$, start from the inequality

Taking the leftmost and rightmost expression and multiplying by $(\|\hat{\mathbf{w}}\|/K)^{2/N}$ yields

Finally, let us turn back to our original goal of bounding $\|\mathbf{w}(\tau)-\hat{\mathbf{w}}\|^{2}$. With (60), (62), (67) and (73), we now know that

Hence, if we initialise close enough to zero, as specified by condition (11), we can ensure that

Appendix D Theorem 3 for Approximate Distillation

We extend Theorem 3 to the setting where the student learns the solution $\hat{\mathbf{w}}=\mathbf{X}(\mathbf{X}^{\intercal}\mathbf{X})^{-1}\mathbf{X}^{\intercal}\mathbf{w}_{*}$ only $\epsilon$-approximately, as is the case for deep linear networks initialised as in Theorem 2. When $n\geq d$, the teacher’s weight vector is recovered exactly and the transfer risk is zero, even when the student is deep. The following theorem therefore only covers the case $n<d$.
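The two regimes can be illustrated with a small Monte Carlo sketch. It assumes isotropic Gaussian inputs and an exactly learned student (the $\epsilon=0$ case), and estimates the transfer risk as the fraction of held-out points on which student and teacher disagree. The helper names `student` and `transfer_risk` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 10
w_star = rng.standard_normal(d)             # hypothetical teacher weights
X_test = rng.standard_normal((d, 20000))    # held-out inputs, one per column

def student(X, w_star):
    """Orthogonal projection of w_star onto the column span of X.
    Equals X (X^T X)^{-1} X^T w_star when X has full column rank."""
    return X @ np.linalg.pinv(X) @ w_star

def transfer_risk(w_hat, w_star, X_test):
    """Fraction of test points where student and teacher labels differ."""
    return np.mean(np.sign(w_hat @ X_test) != np.sign(w_star @ X_test))

X_all = rng.standard_normal((d, d))   # nested training sets: first n columns
risks = {}
for n in (1, 3, 5, 8, 10):
    w_hat = student(X_all[:, :n], w_star)
    risks[n] = transfer_risk(w_hat, w_star, X_test)
    print(n, risks[n])
```

Using nested training sets (the first $n$ columns of one matrix) makes the runs comparable across $n$. At $n=d$ the projection is the identity up to rounding, so the student coincides with the teacher.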

The result is very similar to Theorem 3 in the main text. The only difference is the constant $\delta$, which compensates for the imprecision in learning $\hat{\mathbf{w}}$ by pushing the bound up (recall that $p$ is decreasing). However, as $\epsilon$ goes to zero, so does $\delta$, and we recover the original bound.
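Why an $\epsilon$-perturbation of a weight vector costs only a small angular slack can be pictured with an elementary geometric fact: if $\|\mathbf{w}-\mathbf{v}\|\leq\epsilon<\|\mathbf{w}\|$, then the angle between $\mathbf{w}$ and $\mathbf{v}$ is at most $\arcsin(\epsilon/\|\mathbf{w}\|)$, which vanishes as $\epsilon\to 0$. The sketch below checks this fact by random sampling. It is an illustration only, and the exact definition of $\delta$ in the theorem may differ from this arcsin expression.

```python
import numpy as np

def angle(w, v):
    """Unsigned angle between w and v in radians, in [0, pi]."""
    c = w @ v / (np.linalg.norm(w) * np.linalg.norm(v))
    return np.arccos(np.clip(c, -1.0, 1.0))

rng = np.random.default_rng(2)
d, eps = 5, 0.3
w = rng.standard_normal(d)
w /= np.linalg.norm(w)             # unit norm, so eps < ||w|| holds

bound = np.arcsin(eps / np.linalg.norm(w))
worst = 0.0
for _ in range(10000):
    u = rng.standard_normal(d)
    # Rescale u to be uniformly distributed in the ball of radius eps.
    u *= eps * rng.uniform() ** (1.0 / d) / np.linalg.norm(u)
    worst = max(worst, angle(w, w + u))
print(worst, bound)
```

The largest sampled angle stays below the bound, which is attained only when $\mathbf{v}$ lies on the tangent from the origin to the $\epsilon$-ball around $\mathbf{w}$.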

For the proof, we start with a tool for controlling the angle between $\hat{\mathbf{w}}$ and $\hat{\mathbf{w}}_{\epsilon}$. Recall that the angle is defined as

The first step is to lower-bound the inner product $\mathbf{w}^{\intercal}\mathbf{v}$. To that end, we expand and rearrange $\|\mathbf{w}-\mathbf{v}\|^{2}\leq\epsilon^{2}$ to obtain

Now square the triangle inequality $\|\mathbf{v}\|\geq\|\mathbf{w}\|-\epsilon$ and use it to lower-bound the right-hand side of (80), which gives

The left-hand side is by assumption non-negative, so we have $\alpha(\mathbf{w},\mathbf{v})\in[-\pi/2,\pi/2]$. On this domain,

We decompose the expected risk as follows:

Let us fix some $\mathbf{x}$ for which $\bar{\alpha}(\mathbf{w}_{*},\mathbf{x})<\beta$ and $\mathbf{w}_{*}^{\intercal}\mathbf{x}>0$; for this $\mathbf{x}$ we have $\alpha(\mathbf{w}_{*},\mathbf{x})=\bar{\alpha}(\mathbf{w}_{*},\mathbf{x})$. Consider the situation where $\bar{\alpha}(\mathbf{w}_{*},\mathbf{x}_{i})<\pi/2-\beta-\delta$ for some $i$. Then by the triangle inequality, Lemma D.1 and Lemma 1,

which implies $\hat{\mathbf{w}}_{\epsilon}^{\intercal}\mathbf{x}>0$, i.e. a correct prediction (same as the teacher’s). Conversely, an error can occur only if $\bar{\alpha}(\mathbf{w}_{*},\mathbf{x}_{i})\geq\pi/2-\delta-\beta$ for all $i$. Because the $\mathbf{x}_{i}$ are independent, we have

By a symmetric argument, one can show that

Combining (86), (90) and (91) yields the result. ∎