Towards Explaining the Regularization Effect of Initial Large Learning Rate in Training Neural Networks

Yuanzhi Li, Colin Wei, Tengyu Ma

Introduction

It is a commonly accepted fact that a large initial learning rate is required to successfully train a deep network, even though it slows down optimization of the training loss. Modern state-of-the-art architectures typically start with a large learning rate and anneal it at a point when the model’s fit to the training data plateaus. Meanwhile, models trained using only small learning rates have been found to generalize poorly despite enjoying faster optimization of the training loss.

A number of papers have proposed explanations for this phenomenon, such as the sharpness of local minima, the time it takes to move from initialization, and the scale of SGD noise. However, we still have a limited understanding of a surprising and striking part of the large learning rate phenomenon: judging from the section of the accuracy curve before annealing, it would appear that a small learning rate model should outperform the large learning rate model in both training and test error. Concretely, in Fig. 1, the model trained with a small learning rate outperforms the one trained with a large learning rate until epoch 60, when the learning rate is first annealed. Only after annealing does the large learning rate model visibly outperform the small learning rate model in terms of generalization.

In this paper, we propose to theoretically explain this phenomenon via the concept of learning order of the model, i.e., the rates at which it learns different types of examples. This is not a typical concept in the generalization literature: learning order is a training-time property of the model, whereas most analyses only consider post-training properties such as the classifier’s complexity or the algorithm’s output stability. We will construct a simple distribution for which the learning order of a two-layer network trained under large and small initial learning rates determines its generalization.

Informally, consider a distribution over training examples consisting of two types of patterns (“pattern” refers to a grouping of features). The first type consists of a set of easy-to-generalize (i.e., discrete) patterns of low cardinality that are difficult to fit using a low-complexity classifier, but easily learnable via complex classifiers such as neural networks. The second type of pattern is learnable by a low-complexity classifier, but is inherently noisy, so it is difficult for the classifier to generalize. In our case, the second type of pattern requires more samples to learn correctly than the first type. Suppose we have the following split of examples in our dataset:

20% of examples contain only hard-to-generalize, easy-to-fit patterns;

20% of examples contain only easy-to-generalize, hard-to-fit patterns;

60% of examples contain both pattern types. (1.1)

The following informal theorems characterize the learning order and generalization of the large and small initial learning rate models. They are a dramatic simplification of our Theorems 3.4 and 3.5 meant only to highlight the intuitions behind our results.

There is a dataset of size $N$ of the form (1.1) such that, with a large initial learning rate and noisy gradient updates, a two-layer network will:

1) initially only learn hard-to-generalize, easy-to-fit patterns from the $0.8N$ examples containing such patterns.

2) learn easy-to-generalize, hard-to-fit patterns only after the learning rate is annealed.

Thus, the model learns hard-to-generalize, easily fit patterns with an effective sample size of $0.8N$ and still learns all easy-to-generalize, hard-to-fit patterns correctly with $0.2N$ samples.

In the same setting as above, with small initial learning rate the network will:

1) quickly learn all easy-to-generalize, hard-to-fit patterns.

2) ignore hard-to-generalize, easily fit patterns from the $0.6N$ examples containing both pattern types, and only learn them from the $0.2N$ examples containing only hard-to-generalize patterns.

Thus, the model learns hard-to-generalize, easily fit patterns with a smaller effective sample size of $0.2N$ and will perform relatively worse on these patterns at test time.

Together, these two theorems can justify the phenomenon observed in Figure 1 as follows: in a real-world network, the large learning rate model first learns hard-to-generalize, easier-to-fit patterns and is unable to memorize easy-to-generalize, hard-to-fit patterns, leading to a plateau in accuracy. Once the learning rate is annealed, it is able to fit these patterns, explaining the sudden spike in both train and test accuracy. On the other hand, because of the low amount of SGD noise present in easy-to-generalize, hard-to-fit patterns, the small learning rate model quickly overfits to them before fully learning the hard-to-generalize patterns, resulting in poor test error on the latter type of pattern.

Both intuitively and in our analysis, the non-convexity of neural nets is crucial for the learning-order effect to occur. Strongly convex problems have a unique minimum, so what happens during training does not affect the final result. In contrast, we show that non-convexity causes the learning order to strongly influence the characteristics of the solutions found by the algorithm.

In Section E.1, we propose a mitigation strategy inspired by our analysis. In the same setting as Theorems 1.1 and 1.2, we consider training a model with a small initial learning rate while adding noise before the activations, where the noise is reduced by a constant factor at a particular epoch of training. We show that this algorithm provides the same theoretical guarantees as the large initial learning rate, and we empirically demonstrate the effectiveness of this strategy in Section 7. In Section 7, we also empirically validate Theorems 1.1 and 1.2 by adding an artificial memorizable patch to CIFAR-10 images, in a manner inspired by (1.1).

The question of training with larger batch sizes is closely tied to the learning rate, and many papers have empirically studied large batch/small LR phenomena, particularly focusing on vision tasks using SGD as the optimizer. While these papers are framed as a study of large-batch training, a number of them explicitly acknowledge the connection between large batch size and small learning rate. Keskar et al. argue that training with a large batch size or small learning rate results in sharp local minima. Hoffer et al. propose training the network for longer and with a larger learning rate as a way to train with a larger batch size. Wen et al. propose adding Fisher noise to simulate the regularization effect of small batch size.

Adaptive gradient methods are popular in deep learning; they adaptively choose different step sizes for different parameters. One motivation for these methods is reducing the need to tune learning rates. However, these methods have been observed to hurt generalization performance, and modern architectures often achieve the best results via SGD and hand-tuned learning rates. Wilson et al. construct a toy example for which Adam generalizes provably worse than SGD. Additionally, several alternative learning rate schedules have been proposed for SGD, such as warm-restarts and . Ge et al. analyze the exponentially decaying learning rate and show that its final iterate achieves optimal error in stochastic optimization settings, but they only analyze convex settings.

There are also several recent works on the implicit regularization of gradient descent that establish convergence to some idealized solution under particular choices of learning rate. In contrast to our analysis, the generalization guarantees from these works would depend only on the complexity of the final output and not on the order of learning.

Other recent papers have also studied the order in which deep networks learn certain types of examples. Mangalam and Prabhu and Nakkiran et al. experimentally demonstrate that deep networks may first fit examples learnable by “simpler” classifiers. For our construction, we prove that the neural net with large learning rate follows this behavior, initially learning a classifier on linearly separable examples and learning the remaining examples after annealing. However, the phenomenon that we analyze is also more nuanced: with a small learning rate, we prove that the model first learns a complex classifier on low-noise examples which are not linearly separable.

Finally, our proof techniques and intuitions are related to recent literature on global convergence of gradient descent for over-parametrized networks. These works show that gradient descent learns a fixed kernel related to the initialization under sufficient over-parameterization. In our analysis, the underlying kernel changes over time. The amount of noise due to SGD governs the space of possible learned kernels and, as a result, regularizes the order of learning.

Setup and Notations

We formally introduce our data distribution, which contains examples supported on two types of components: a $\mathcal{P}$ component meant to model hard-to-generalize, easier-to-fit patterns, and a $\mathcal{Q}$ component meant to model easy-to-generalize, hard-to-fit patterns (see the discussion in our introduction). Formally, we assume that the label $y$ has a uniform distribution over $\{-1,1\}$, and the data $x$ is generated as

where $\mathcal{P}_{-1},\mathcal{P}_{1}$ are assumed to be two half-Gaussian distributions with a margin $\gamma_{0}$ between them:

Memorizing $\mathcal{Q}$ with a two-layer net

It is easy for a two-layer ReLU network to memorize the labels of $x_{2}$ using two neurons with weights $w,v$ such that $\langle w,z\rangle<0$, $\langle w,z-\zeta\rangle>0$ and $\langle v,z\rangle<0$, $\langle v,z+\zeta\rangle>0$. In particular, we can verify that $-\langle w,x_{2}\rangle_{+}-\langle v,x_{2}\rangle_{+}$ outputs a negative value for $x_{2}\in\{z-\zeta,z+\zeta\}$ and zero for $x_{2}=z$. Thus, choosing a small enough $\rho>0$, the classifier $-\langle w,x_{2}\rangle_{+}-\langle v,x_{2}\rangle_{+}+\rho$ gives the correct sign for the label $y$.
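As a quick numerical check of this construction, the Python sketch below builds one valid pair $w,v$ in closed form. It is only an illustration under assumptions of our own: $z$ and $\zeta$ are taken to be orthogonal, the concrete vectors are arbitrary, and we read the sign argument above as assigning label $+1$ to $x_2=z$ and label $-1$ to $x_2=z\pm\zeta$. The helpers `relu`, `dot`, and `memorizer` are hypothetical names.

```python
def relu(a):
    return max(a, 0.0)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def memorizer(z, zeta, rho=0.5):
    """Two-neuron classifier f(x2) = -<w,x2>_+ - <v,x2>_+ + rho.

    Assumes z is orthogonal to zeta, so w and v can be written in closed form.
    """
    zz, ee = dot(z, z), dot(zeta, zeta)
    # <w,z> = -1 and <w,zeta> = -2, hence <w,z-zeta> = 1 > 0 and <w,z+zeta> = -3 < 0
    w = [-a / zz - 2 * b / ee for a, b in zip(z, zeta)]
    # <v,z> = -1 and <v,zeta> = 2, hence <v,z+zeta> = 1 > 0 and <v,z-zeta> = -3 < 0
    v = [-a / zz + 2 * b / ee for a, b in zip(z, zeta)]
    return lambda x2: -relu(dot(w, x2)) - relu(dot(v, x2)) + rho


z = [2.0, 0.0, 0.0]
zeta = [0.0, 1.0, 1.0]
f = memorizer(z, zeta)
z_minus = [a - b for a, b in zip(z, zeta)]
z_plus = [a + b for a, b in zip(z, zeta)]
print(f(z), f(z_minus), f(z_plus))  # 0.5 -0.5 -0.5
```

Any $\rho\in(0,1)$ works in this example because the single active neuron outputs exactly $1$ on $z\pm\zeta$.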

We assume that we have a training dataset with $N$ examples $\{(x^{(1)},y^{(1)}),\cdots,(x^{(N)},y^{(N)})\}$ drawn i.i.d. from the distribution described above. We use $p$ and $q$ to denote the empirical fractions of data points drawn from equations (2.2) and (2.3), respectively.

Two-layer neural network model

Here we slightly abuse notation and use $W$ to denote either a matrix with $2d$ columns whose last $d$ columns are zero, or a matrix with $d$ columns. We also extend our theorem to other $U$, such as a two-layer convolutional network, in Section E.

Training objective

We consider a regularized training objective $\widehat{L}_{\lambda}(u,U)=\widehat{L}(u,U)+\frac{\lambda}{2}\|U\|_{F}^{2}$. For simplicity of derivation, the second-layer weight vector $u$ is randomly initialized and fixed throughout this paper. Thus, with slight abuse of notation, the training objective can be written as $\widehat{L}_{\lambda}(U)=\widehat{L}(u,U)+\frac{\lambda}{2}\|U\|_{F}^{2}$.

Notations

Main Results

The training algorithm that we consider is stochastic gradient descent with spherical Gaussian noise. We remark that we analyze this algorithm as a simplification of the minibatch SGD noise encountered when training real-world networks. A number of works theoretically characterize this particular noise distribution, and we leave the analysis of this setting to future work.

We initialize $U_{0}$ to have i.i.d. entries from a Gaussian distribution with variance $\tau_{0}^{2}$, and at each iteration of gradient descent we add spherical Gaussian noise with coordinate-wise variance $\tau_{\xi}^{2}$ to the gradient updates. That is, the learning algorithm for the model is

where $\gamma_{t}$ denotes the learning rate at time $t$. We will analyze two algorithms:

Algorithm 1 (L-S): The learning rate is $\eta_{1}$ for $t_{0}$ iterations, until the training loss drops below the threshold $\varepsilon_{1}+q\log 2$. Then we anneal the learning rate to $\gamma_{t}=\eta_{2}$ (which is assumed to be much smaller than $\eta_{1}$) and run until the training loss drops to $\varepsilon_{2}$.

Algorithm 2 (S): We use a fixed learning rate of $\eta_{2}$ and stop at training loss $\varepsilon_{2}^{\prime}\leq\varepsilon_{2}$.
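The two schedules can be sketched as a single loop in which annealing is triggered by the training loss rather than by a fixed epoch. This is a minimal, noise-free Python sketch, not the authors' implementation: `run_schedule`, `step`, and `loss` are hypothetical names, and the toy quadratic loss below (with $q=0$, so the threshold is just $\varepsilon_1$) only exercises the control flow.

```python
import math


def run_schedule(state, step, loss, eta1, eta2, eps1, q, eps2, max_iters=100_000):
    """Loss-triggered annealing in the style of Algorithm 1 (L-S).

    step(state, eta) -> new state and loss(state) -> float are user-supplied
    (hypothetical) callables. Returns (final state, annealing iteration or None,
    number of iterations run).
    """
    eta, t_anneal = eta1, None
    for t in range(max_iters):
        current = loss(state)
        # Anneal once the training loss first drops below eps1 + q * log 2.
        if t_anneal is None and current < eps1 + q * math.log(2):
            eta, t_anneal = eta2, t
        # Stop once the loss reaches eps2, but only after annealing has happened.
        if t_anneal is not None and current <= eps2:
            return state, t_anneal, t
        state = step(state, eta)
    return state, t_anneal, max_iters


# Toy usage: gradient descent on the quadratic loss x**2 (q = 0 removes the log 2 offset).
final, t_anneal, t_stop = run_schedule(
    1.0, lambda x, eta: x - eta * 2 * x, lambda x: x * x,
    eta1=0.1, eta2=0.01, eps1=0.01, q=0.0, eps2=1e-3,
)
print(t_anneal, t_stop)
```

Passing `eta1 == eta2` turns the same loop into Algorithm 2 (S), since annealing then leaves the step size unchanged.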

For convenience of analysis, we choose $\tau_{0}$ so that the contribution of the noise in the system stabilizes at initialization. (Let $\tau_{0}^{\prime}$ be the solution to (3.3) holding $\tau_{\xi},\eta_{1},\lambda$ fixed. If the standard deviation of the initialization is chosen to be smaller than $\tau_{0}^{\prime}$, the standard deviation of the noise will grow to $\tau_{0}^{\prime}$; if it is chosen to be larger, the contribution of the noise will decrease to the level of $\tau_{0}^{\prime}$ due to regularization.) In typical analyses of SGD with spherical noise, the proof goes through as long as either the noise or the learning rate is small enough. Here, however, we make explicit use of the large learning rate or the large noise to show better generalization performance.
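One way to see the stabilization numerically, under an assumption we state explicitly: if each step multiplies the weights by $(1-\eta\lambda)$ and adds $\eta\xi$ with coordinate-wise noise variance $\tau_\xi^2$, then the per-coordinate variance of the noise component evolves as $v_t=(1-\eta\lambda)^2v_{t-1}+\eta^2\tau_\xi^2$, and we take $\tau_0'$ to be the square root of its fixed point. Treating this fixed-point condition as (3.3) is our reading; the values of `eta`, `lam`, and `tau_xi` below are hypothetical.

```python
import math


def stationary_std(eta, lam, tau_xi):
    """Fixed point tau0' of v = (1 - eta*lam)**2 * v + (eta*tau_xi)**2, returned as a std."""
    return eta * tau_xi / math.sqrt(1.0 - (1.0 - eta * lam) ** 2)


def noise_std_after(t, eta, lam, tau_xi, tau_init):
    """Std of the noise component after t steps of decay (1 - eta*lam) plus fresh noise eta*xi."""
    v = tau_init ** 2
    for _ in range(t):
        v = (1.0 - eta * lam) ** 2 * v + (eta * tau_xi) ** 2
    return math.sqrt(v)


eta, lam, tau_xi = 0.1, 0.01, 1.0
tau_star = stationary_std(eta, lam, tau_xi)
for tau_init in (0.01, tau_star, 10.0):
    # Small initializations grow and large ones shrink toward tau_star.
    print(round(noise_std_after(20_000, eta, lam, tau_xi, tau_init), 6), round(tau_star, 6))
```

Starting below, at, or above `tau_star` converges to the same value, matching the grow-or-shrink behavior described above.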

After fixing $\lambda$ and $\tau_{\xi}$, we choose the initialization $\tau_{0}$ and large learning rate $\eta_{1}$ so that

As a technical assumption for our proofs, we will also require $\eta_{1}\lesssim\varepsilon_{1}$.

We also require sufficient over-parametrization.

We assume throughout the paper that $\tau_{0}=1/\textup{poly}\left(\frac{d}{\varepsilon}\right)$ and $m\geq\textup{poly}\left(\frac{d}{\varepsilon\tau_{0}}\right)$, where poly is a sufficiently large constant-degree polynomial. We note that we can choose $\tau_{0}$ arbitrarily small, so long as it is fixed before we choose $m$.

As we will see soon, the precise relation between $N$ and $d$ implies that the level of over-parameterization is polynomial in $N$ and $1/\varepsilon$, which fits the conditions assumed in prior works.

Throughout this paper, we assume the following dependencies between the parameters. We assume that $N,d\rightarrow\infty$ with $\frac{N}{d}=\frac{1}{\kappa^{2}}$, where $\kappa\in(0,1)$ is a small value. (In non-asymptotic language, we assume that $N,d$ are sufficiently large compared to $1/\kappa$: $N,d\gg\textup{poly}(1/\kappa)$.) We set $r=d^{-3/4}$, $p_{0}=\kappa^{2}/2$, and $q_{0}=\Theta(1)$. The regularizer is chosen to be $\lambda=d^{-5/4}$. All of these choices of hyper-parameters can be relaxed, but for simplicity of exposition we only work in this setting.

We note that under our assumptions, for sufficiently large $N$, $p\approx p_{0}$ and $q\approx q_{0}$ up to constant multiplicative factors. Thus we will mostly work with $p$ and $q$ (the empirical fractions) in the rest of the paper. We also note that our parameter choice satisfies $(rd)^{-1},d\lambda,\lambda/r\leq\kappa^{O(1)}$ and $\lambda\leq r^{2}/(\kappa^{2}q^{3}p^{2})$, which are a few conditions that we frequently use in the technical part of the paper.
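Several of these scalings can be checked with a few lines of arithmetic. The Python sketch below plugs in hypothetical values $d=10^8$ and $\kappa=0.1$ (and an arbitrary $q_0=0.5$); it confirms that $p_0N=d/2$ (the relation behind the later claim $pN\leq d/2$), that $p\sqrt{d/N}=p\kappa$ and $\kappa=\sqrt{2p_0}$, and that $(rd)^{-1}=d\lambda=d^{-1/4}$ and $\lambda/r=d^{-1/2}$ decay polynomially in $d$. The helper name `parameters` is ours.

```python
import math


def parameters(d, kappa, q0=0.5):
    """Hyper-parameter choices above as functions of d and kappa (q0 = Theta(1) picked arbitrarily)."""
    N = d / kappa ** 2  # N / d = 1 / kappa^2
    return {"N": N, "r": d ** -0.75, "p0": kappa ** 2 / 2, "q0": q0, "lam": d ** -1.25}


d, kappa = 1e8, 0.1
P = parameters(d, kappa)
# p0 * N = d / 2: the P-only examples alone are at most half the dimension.
print(P["p0"] * P["N"] / d)
# p * sqrt(d / N) = p * kappa, the Rademacher-type rate quoted for Theorem 3.4.
print(P["p0"] * math.sqrt(d / P["N"]), P["p0"] * kappa)
# (r d)^{-1} = d * lam = d^{-1/4} and lam / r = d^{-1/2}: both vanish as d grows.
print(1 / (P["r"] * d), d * P["lam"], P["lam"] / P["r"])
```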

Now we present our main theorems regarding the generalization of models trained with the L-S and S algorithms. The final generalization error of the model trained with the L-S algorithm will be a factor $O(\kappa)=O(p^{1/2})$ smaller than the generalization error of the model trained with the S algorithm.

Under Assumptions 3.1, 3.2, and 3.3, there exists a universal constant $0<c<1/16$ such that Algorithm 1 (L-S) with annealing at loss $\varepsilon_{1}+q\log 2$ for $\varepsilon_{1}\in\left(d^{-c},\kappa^{2}p^{2}q^{3}\right)$ and stopping criterion $\varepsilon_{2}=\sqrt{\varepsilon_{1}/q}$ satisfies the following:

It anneals the learning rate within $\widetilde{O}\left(\frac{d}{\eta_{1}\varepsilon_{1}}\right)$ iterations.

It stops after at most $t=\widetilde{O}\left(\frac{d}{\eta_{1}\varepsilon_{1}}+\frac{1}{\eta_{2}r\varepsilon_{1}^{3}}\right)$ iterations. With probability at least 0.99, the solution $U_{t}$ has test (classification) error and test loss at most $O\left(p\kappa\log\frac{1}{\varepsilon_{1}}\right)$.

Roughly, the learning order and generalization of the L-S model are as follows: before annealing the learning rate, the model only learns an effective classifier for $\mathcal{P}$ on the $\approx(1-q)N$ samples in $\mathcal{M}_{1}$, as the large learning rate creates too much noise to effectively learn $\mathcal{Q}$ (Lemma 4.1 and Lemma 4.2). After the learning rate is annealed, the model memorizes $\mathcal{Q}$ and correctly classifies examples with only a $\mathcal{Q}$ component at test time (formally shown in Lemmas 4.3 and 4.4). For examples with only a $\mathcal{P}$ component, the generalization error is (ignoring log factors and other technicalities) $p\sqrt{\frac{d}{N}}=O(p\kappa)$ via standard Rademacher complexity. The full analysis of the L-S algorithm is given in Section 4.

Let $\varepsilon_{2}$ be chosen as in Theorem 3.4. Under Assumptions 3.1, 3.2, and 3.3, there exists a universal constant $c>0$ such that w.h.p., Algorithm 2 with any $\eta_{2}\leq\eta_{1}d^{-c}$ and any stopping criterion $\varepsilon_{2}^{\prime}\in(d^{-c},\varepsilon_{2}]$ achieves training loss $\varepsilon_{2}^{\prime}$ in at most $\widetilde{O}\left(\frac{d}{\eta_{2}\varepsilon_{2}^{\prime}}\right)$ iterations, and both the test error and the test loss of the obtained solution are at least $\Omega(p)$.

We explain this lower bound as follows: the S algorithm quickly memorizes the $\mathcal{Q}$ component, which is low-noise, and ignores the $\mathcal{P}$ component for the $\approx(1-p-q)N$ examples with both $\mathcal{P}$ and $\mathcal{Q}$ components (shown in Lemma 5.2). Thus, it only learns $\mathcal{P}$ on $\approx pN$ examples. It obtains a small margin on these examples and therefore misclassifies a constant fraction of $\mathcal{P}$-only examples at test time. This results in the lower bound of $\Omega(p)$. We formalize the analysis in Section 5.

It will be fruitful for our analysis to separately consider the gradient signal and Gaussian noise components of the weight matrix $U_{t}$. We decompose the weight matrix $U_{t}$ as follows: $U_{t}=\overline{U}_{t}+\widetilde{U}_{t}$. In this formula, $\overline{U}_{t}$ denotes the signal from all the gradient updates accumulated over time, and $\widetilde{U}_{t}$ refers to the noise accumulated over time:

Note that when the learning rate $\gamma_{t}$ is always $\eta$, the formula simplifies to $\overline{U}_{t}=-\sum_{s=1}^{t}\eta(1-\eta\lambda)^{t-s}\nabla\widehat{L}(U_{s-1})$ and $\widetilde{U}_{t}=(1-\eta\lambda)^{t}U_{0}+\sum_{s=1}^{t}\eta(1-\eta\lambda)^{t-s}\xi_{s-1}$. Under this decoupling and our particular choice of initialization, the noise updates in the system stabilize at initialization, so the marginal distribution of $\widetilde{U}_{t}$ is always the same as that of the initialization. Another nice aspect of the signal-noise decomposition is as follows: we use existing tools to show that if the signal term $\overline{U}$ is small, then using only the noise component $\widetilde{U}$ to compute the activations roughly preserves the output of the network. This facilitates our analysis of the network dynamics. See Section A.1 for full details.
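The decomposition is an exact identity for the noisy update, which the Python sketch below checks numerically. It assumes the update $U_t=(1-\eta\lambda)U_{t-1}-\eta\nabla\widehat{L}(U_{t-1})+\eta\xi_{t-1}$, i.e., gradient descent on $\widehat{L}_{\lambda}$ in which the accumulated gradient enters with a minus sign (the sign of the symmetric noise is immaterial). The quadratic toy gradient and every name in the code are hypothetical.

```python
import random


def run_and_decompose(U0, grad, eta, lam, tau_xi, T, seed=0):
    """Noisy GD U_t = (1 - eta*lam) U_{t-1} - eta*grad(U_{t-1}) + eta*xi_{t-1}, tracked as signal + noise.

    U0 is a flat list of weights; grad(U) returns a list of the same length (a hypothetical loss gradient).
    """
    rng = random.Random(seed)
    decay = 1.0 - eta * lam
    U = list(U0)
    signal = [0.0] * len(U0)  # \bar U_t: accumulated, decayed gradient steps
    noise = list(U0)          # \tilde U_t: decayed initialization plus accumulated noise
    grads = []                # grad(U_{s-1}) for s = 1..T, kept for the closed form
    for _ in range(T):
        g = grad(U)
        xi = [rng.gauss(0.0, tau_xi) for _ in U]
        grads.append(g)
        U = [decay * u - eta * gi + eta * x for u, gi, x in zip(U, g, xi)]
        signal = [decay * s - eta * gi for s, gi in zip(signal, g)]
        noise = [decay * n + eta * x for n, x in zip(noise, xi)]
    # Closed form: \bar U_T = -sum_{s=1}^T eta (1 - eta*lam)^{T-s} grad(U_{s-1}).
    closed = [-sum(eta * decay ** (T - s) * grads[s - 1][i] for s in range(1, T + 1))
              for i in range(len(U0))]
    return U, signal, noise, closed


U, signal, noise, closed = run_and_decompose(
    [0.3, -0.2], lambda W: [w - 1.0 for w in W], eta=0.1, lam=0.01, tau_xi=0.5, T=200)
print(max(abs(u - (s + n)) for u, s, n in zip(U, signal, noise)))
print(max(abs(s - c) for s, c in zip(signal, closed)))
```

Both printed discrepancies should be at the level of floating-point round-off.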

Decomposition of Network Outputs

For convenience, we will explicitly decompose the model prediction at each time into two components, each of which operates on one pattern: we have $N_{U_{t}}(u,U_{t};x)=g_{t}(x)+r_{t}(x)$.

In other words, the network $g_{t}$ acts on the $\mathcal{Q}$ component of examples, and the network $r_{t}$ acts on the $\mathcal{P}$ component of examples.

Characterization of Algorithm 1 (L-S)

We characterize the behavior of algorithm L-S with large initial learning rate. We provide proof sketches in Section 6.1 with full proofs in Section C.

The following lemma bounds the time needed to reach the point where the learning rate is annealed. It also bounds the total gradient signal accumulated by this point.

In the setting of Theorem 3.4, at some time step $t_{0}\leq\widetilde{O}\left(\frac{d}{\eta_{1}\varepsilon_{1}}\right)$, the training loss $\widehat{L}(U_{t_{0}})$ becomes smaller than $q\log 2+\varepsilon_{1}$. Moreover, we have $\|\overline{U}_{t_{0}}\|_{F}^{2}=O\left(d\log^{2}\frac{1}{\varepsilon_{1}}\right)$.

Our proof of Lemma 4.1 views the SGD dynamics as optimization with respect to the neural tangent kernel induced by the activation patterns, where the kernel changes rapidly due to the noise terms $\xi$. This is in contrast to the standard NTK regime, where the activation patterns are assumed to be stable. Our analysis extends the NTK techniques to deal with a sequence of changing kernels that share a common optimal classifier (see Section 6.1 and Theorem 6.2 for additional details).

The next lemma says that, with a large initial learning rate, the function $g_{t}$ does not learn anything meaningful about the $\mathcal{Q}$ component before time step $\frac{1}{\eta_{1}\lambda}$. Note that, by our choice of parameters ($1/\lambda\gg d$) and Lemma 4.1, we anneal at time step $\widetilde{O}\left(\frac{d}{\eta_{1}\varepsilon_{1}}\right)\leq\frac{1}{\eta_{1}\lambda}$. Therefore, the function has not learned anything meaningful about the memorizable pattern on distribution $\mathcal{Q}$ before we anneal.
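To make this comparison explicit (a short calculation of ours, using $\varepsilon_{1}>d^{-c}$ with $c<1/16$ from Theorem 3.4 and $\lambda=d^{-5/4}$):

$$\frac{d}{\eta_{1}\varepsilon_{1}}<\frac{d^{1+c}}{\eta_{1}}=d^{-(1/4-c)}\cdot\frac{d^{5/4}}{\eta_{1}}=d^{-(1/4-c)}\cdot\frac{1}{\eta_{1}\lambda}.$$

The polynomial gap $d^{1/4-c}\geq d^{3/16}$ dominates the logarithmic factors hidden in $\widetilde{O}(\cdot)$ once $d$ is large.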

In the setting of Theorem 3.4, w.h.p., for every $t\leq\frac{1}{\eta_{1}\lambda}$,

After iteration $t_{0}$, we decrease the learning rate to $\eta_{2}$. The following lemma bounds how fast the loss converges after annealing.

In the setting of Theorem 3.4, there exists $t=\widetilde{O}\left(\frac{1}{\varepsilon_{1}^{3}\eta_{2}r}\right)$ such that after $t_{0}+t$ iterations, we have that

Moreover, $\|\overline{U}_{t_{0}+t}-\overline{U}_{t_{0}}\|_{F}^{2}\leq\widetilde{O}\left(\frac{1}{\varepsilon_{1}^{2}r}\right)\leq O(d)$.

The following lemma bounds the training loss on the example subsets $\mathcal{M}_{1}$ and $\bar{\mathcal{M}}_{1}$.

In the setting of Lemma 4.3, using the same $t=\widetilde{O}\left(\frac{1}{\varepsilon_{1}^{3}\eta_{2}r}\right)$, the average training losses on the subsets $\mathcal{M}_{1}$ and $\bar{\mathcal{M}}_{1}$ are both small in the sense that

Intuitively, low training loss of $g_{t_{0}+t}$ on $\bar{\mathcal{M}}_{1}$ immediately implies good generalization on examples containing patterns from $\mathcal{Q}$. Meanwhile, the classifier for $\mathcal{P}$, $r_{t_{0}+t}$, has low loss on $(1-q)N$ examples. The test error bound then follows from standard Rademacher complexity tools applied to these $(1-q)N$ examples.

Characterization of Algorithm 2 (S)

We present our small learning rate lemmas, with proof sketches in Section 6.2 and full proofs in Section D.

The lemma below shows that the algorithm converges to small training error too quickly. In particular, the norm of $W_{t}$ is not large enough to produce a large-margin solution for those $x$ such that $x_{2}=0$.

Lower bound on the generalization error

The following important lemma states that our classifier for $\mathcal{P}$ does not learn much from the examples in $\mathcal{M}_{2}$. Intuitively, under a small learning rate, the classifier will already learn so quickly from the $\mathcal{Q}$ component of these examples that it will not learn from the $\mathcal{P}$ component of examples in $\mathcal{M}_{1}\cap\mathcal{M}_{2}$. We make this precise by showing that the magnitude of the gradients on $\mathcal{M}_{2}$ is small.

The above lemma implies that $W$ does not learn much from examples in $\mathcal{M}_{2}$, and therefore must overfit to the $pN$ examples in $\bar{\mathcal{M}}_{2}$. As $pN\leq d/2$ by our choice of parameters, we will not have enough samples to learn the $d$-dimensional distribution $\mathcal{P}$. The following lemma formalizes the intuition that the margin will be poor on samples from $\mathcal{P}$.

As the margin is poor, the predictions will be heavily influenced by noise. We use this intuition to prove the classification lower bound for Theorem 3.5.

Proof Sketches

We first introduce notation that will be useful in these proofs. We will explicitly decouple the noise in the weights from the signal by abstracting the loss as a function of only the signal portion $\overline{U}_{t}$ of the weights. Let us define the following:

Now the proof of Lemma 4.1 relies on the following two results, which we state below and prove in Section C.1. The first says that there is a common target for the signal part of the network that is a good solution for all of the $K_{t}$.

In the setting of Lemma 4.1, there exists a solution $U^{\star}$ satisfying a) $\|U^{\star}\|_{F}^{2}\leq O\left(d\log^{2}\frac{1}{\varepsilon_{1}}\right)$ and b) for every $t\geq 0$

The second statement is a general one: gradient descent on a sequence of convex but changing functions will still find an optimum, provided these functions share the same solution.

$K_{t}$’s are $L$-Lipschitz, i.e., $\|\nabla K_{t}(z)\|_{2}\leq L$ for all $z,t$.

For every $\mu>0$, if $\lambda R^{2}\leq\frac{1}{100}\mu$, $\eta\leq\frac{\mu}{100(\lambda^{2}R^{2}+L^{2})}$, and $\eta T>\frac{R^{2}}{\mu}$, then there is a $t^{\star}\in[T]$ such that:

Furthermore, the iterates satisfy $\|z_{t}-z^{\star}\|_{2}\leq R$ for all $t\leq t^{\star}$.
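The following Python sketch is a toy illustration of the setting of Theorem 6.2, not of its proof or its exact guarantee: regularized gradient descent on a sequence of convex, Lipschitz functions $K_t$ that change at every step but share the minimizer $z^\star$. The family $K_t(z)=c_t\sum_i\sqrt{1+(z_i-z^\star_i)^2}$ with random $c_t\in[0.5,1.5]$, the function name, and all parameter values are hypothetical choices.

```python
import math
import random


def gd_on_changing_functions(z_star, eta=0.1, lam=1e-4, T=500, seed=0):
    """Regularized GD z <- z - eta * (grad K_t(z) + lam * z) on a changing sequence of convex functions.

    Hypothetical family: K_t(z) = c_t * sum_i sqrt(1 + (z_i - z_star_i)**2) with c_t ~ U[0.5, 1.5].
    Every K_t is convex, Lipschitz, and minimized at the same point z_star.
    Returns the final iterate and the largest distance to z_star seen along the way.
    """
    rng = random.Random(seed)

    def dist(v):
        return math.sqrt(sum((a - b) ** 2 for a, b in zip(v, z_star)))

    z = [0.0] * len(z_star)
    max_dist = dist(z)
    for _ in range(T):
        c = rng.uniform(0.5, 1.5)  # the function changes at every step
        grad = [c * (a - b) / math.sqrt(1 + (a - b) ** 2) for a, b in zip(z, z_star)]
        z = [a - eta * (g + lam * a) for a, g in zip(z, grad)]
        max_dist = max(max_dist, dist(z))
    return z, max_dist


z_star = [3.0, -2.0]
z_final, max_dist = gd_on_changing_functions(z_star)
print(z_final, max_dist)
```

The iterate settles near $z^\star$ (offset only by the small weight decay) and never moves farther from $z^\star$ than where it started, in the spirit of the bound $\|z_t-z^\star\|_2\leq R$.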

Combining these two statements leads to the proof of Lemma 4.1.

We can apply Theorem 6.2 with $K_{t}$ defined in (6.2) and $z^{\star}=U^{\star}$ defined in Lemma 6.1, using $R=O\left(d\log^{2}\frac{1}{\varepsilon_{1}}\right)$. We note that $\eta_{1}$ satisfies the conditions of Theorem 6.2 by our parameter choices, which completes the proof. ∎

To prove Lemma 4.2, we will essentially argue in Section C.2 that the change in activations caused by the noise prevents the model from learning $\mathcal{Q}$ with a large learning rate. This is because the examples in $\mathcal{Q}$ require a very specific configuration of activation patterns to be learned correctly, and the noise prevents the model from maintaining this configuration.

Now, after we anneal the learning rate, the following must hold in order to conclude Lemmas 4.3 and 4.4: 1) the network learns the $\mathcal{Q}$ component of the distribution, and 2) the network does not forget the $\mathcal{P}$ component that it previously learned. To prove the latter, we rely on the following lemma, which states that the activations do not change much with a small learning rate:

The activation patterns do not change much after annealing the learning rate: for every $t_{0},t\leq\frac{1}{\eta_{2}\lambda}$, for any $x$, and for any row $[U_{t}]_{i}$ of the weight matrix $U$, we have that

Moreover, for all $i\in[m]$, $\left\|[\overline{U}_{t}]_{i}\right\|_{2}\leq\frac{1}{\lambda\sqrt{m}}$; it holds that w.h.p. for every $x$:

We prove the above lemma in Section C.3. To complete the proof of Lemma 4.3, we construct a target solution for all time steps after annealing the learning rate based on the activations at time $t_{0}$ (as, by Lemma 6.3, they do not change much in subsequent time steps) and reapply Theorem 6.2. Finally, to prove Lemma 4.4, we use the fact that the $W_{t}$ component of the solution does not change much, and therefore the loss on $\mathcal{M}_{1}$ is still low.

Proof Sketches for Small Learning Rate

The proof of Lemma 5.1 proceeds similarly to the proof of Lemma 4.3: we show the existence of a target solution of $K_{t}$ for all iterations, and use Theorem 6.2 to prove convergence to this target solution.

The next two statements argue that $\rho_{t}$ can be large only in a limited number of time steps. As the training loss converges quickly with a small learning rate, this will be used to argue that the $\mathcal{P}$ components of examples in $\mathcal{M}_{2}$ provide a very limited signal to $W_{t}$. The proofs of these statements are in Section D.2.

We first show the following lemma, which says that if $\rho_{t}$ is large (which means the loss is large as well), then the total gradient norm has to be big. This lemma holds because there is little noise in the $\mathcal{Q}$ component of the distribution, and therefore the gradient of $V_{t}$ will be large if $\rho_{t}$ is large.

For every $t\leq\frac{1}{\eta_{2}\lambda}$, we have that if $\rho_{t}=\Omega\left(\frac{1}{N}\right)$, then w.h.p.

Now we use the above lemma to bound the number of time steps at which $\rho_{t}$ is large.

In the setting of Lemma 5.2, let $\mathcal{T}$ be the set of iterations where $\rho_{t}\geq\varepsilon_{2}^{\prime 2}\varepsilon_{3}^{2}$, where $\varepsilon_{3}$ is defined in Lemma 5.2. Then w.h.p., $|\mathcal{T}|\lesssim\frac{1}{r\varepsilon_{2}^{\prime 8}\varepsilon_{3}^{8}\eta_{2}}$.

Now if $\rho_{t}$ is small, the gradient accumulated on $W_{t}$ from examples in $\mathcal{M}_{2}$ must be small. We formalize this argument in our proof of Lemma 5.2 in Section D.2.

Lemma 5.3 will then follow by explicitly decomposing $\overline{W}_{t}$ into a component in $\text{span}\{x_{1}^{(i)}\}_{i\in\bar{\mathcal{M}}_{2}}$ and a remainder, which is shown to be small by Lemma 5.2. This is presented in the lemma below, which is proved in Section D.3.

There exist real numbers $\{\alpha_{k}\}_{k\in\bar{\mathcal{M}}_{2}}$ such that for every $j\in[m]$, we have

with $\|\overline{W}_{t}^{\prime}\|_{F}\leq\widetilde{O}\left(\varepsilon_{3}\sqrt{d}\right)$.

This allows us to conclude Lemma 5.3 via computations carried out in Section D.3.

Finally, to complete the proof of Theorem 3.5, we argue in Section B.2 that a classifier $r_{t}$ of the form given by (5.2) cannot have small generalization error, because it is too heavily influenced by the noise in $x_{1}$.

Experiments

Our theory suggests that adding noise to the network could be an effective strategy to regularize a small learning rate in practice. We test this empirically by adding small Gaussian noise during training before every activation layer in a WideResNet16 architecture, as our analysis highlights pre-activation noise as a key regularization mechanism of SGD. The noise level is annealed over time. We demonstrate on CIFAR-10 images without data augmentation that this regularization can indeed counteract the negative effects of small learning rate, as we report a 4.72% increase in validation accuracy when adding noise to a small learning rate. Full details are in Section G.1.
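A framework-agnostic sketch of this mitigation strategy is shown below: Gaussian noise is added to each pre-activation before the ReLU during training, and its standard deviation is divided by a constant factor at a chosen epoch. This is not the authors' WideResNet16 code; `noise_scale`, `noisy_relu`, `sigma0`, `anneal_epoch`, and `factor` are hypothetical names and values (the settings actually used are described in Section G.1).

```python
import random


def noise_scale(epoch, sigma0, anneal_epoch, factor):
    """Pre-activation noise std: sigma0 until anneal_epoch, then divided by a constant factor."""
    return sigma0 if epoch < anneal_epoch else sigma0 / factor


def noisy_relu(pre_activations, sigma, rng, training=True):
    """Add N(0, sigma^2) noise to each pre-activation during training, then apply ReLU."""
    if training and sigma > 0:
        pre_activations = [a + rng.gauss(0.0, sigma) for a in pre_activations]
    return [max(a, 0.0) for a in pre_activations]


rng = random.Random(0)
for epoch in (0, 29, 30, 60):
    sigma = noise_scale(epoch, sigma0=0.1, anneal_epoch=30, factor=10.0)
    print(epoch, sigma, noisy_relu([-1.0, 0.5, 2.0], sigma, rng))
```

In a real network the same logic would be applied elementwise to each layer's pre-activation tensor; switching the noise off at evaluation time is a design choice assumed here.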

We also empirically demonstrate that the choice of large vs. small initial learning rate can indeed invert the learning order of different example types. We add a memorizable 7 × 7 pixel patch to a subset of CIFAR-10 images following the scenario presented in (1.1), such that around 20% of images have no patch, 16% of images contain only a patch, and 64% contain both CIFAR-10 data and the patch. We generate the patches so that they are not easily separable, as in our constructed $\mathcal{Q}$, but they are low in variation and therefore easy to memorize. Precise details on producing the data, including a visualization of the patch, are in Section G.2. We train WideResNet16 on the modified dataset using three methods: a large learning rate with annealing at the 30th epoch, a small initial learning rate, and a small learning rate with noise annealed at the 30th epoch.
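The group assignment for this experiment can be sketched as a seeded random split. The Python helper below is hypothetical: it only assigns image indices to the three groups in the proportions stated above (using 50,000 indices, the size of the CIFAR-10 training set), while generating and pasting the patch itself follows Section G.2 and is not reproduced here.

```python
import random


def assign_patch_groups(n, frac_clean=0.20, frac_patch_only=0.16, seed=0):
    """Randomly split n image indices into the three groups of the patch experiment.

    The remaining fraction (0.64 with the defaults) gets CIFAR-10 content plus the patch.
    """
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_clean = round(frac_clean * n)
    n_patch = round(frac_patch_only * n)
    groups = {}
    for rank, i in enumerate(idx):
        if rank < n_clean:
            groups[i] = "clean"        # CIFAR-10 image, no patch
        elif rank < n_clean + n_patch:
            groups[i] = "patch_only"   # patch without CIFAR-10 content
        else:
            groups[i] = "both"         # CIFAR-10 image with the patch added
    return groups


groups = assign_patch_groups(50_000)
print({g: list(groups.values()).count(g) for g in ("clean", "patch_only", "both")})
# {'clean': 10000, 'patch_only': 8000, 'both': 32000}
```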

Figure 3 depicts the validation accuracy vs. epoch on clean (no patch) and patch-only images. From the plots, it is apparent that the small learning rate picks up the signal in the patch very quickly, whereas the other two methods only memorize the patch after annealing.

From the validation accuracy on clean images, we can deduce that the small learning rate method is indeed learning the CIFAR images using only a small fraction of the available data: a small-LR model trained on the full dataset reaches around 83% validation accuracy, but after training with the patch its validation accuracy on clean data is 70%. We provide additional arguments in Section G.2. Our code for these experiments is online at the following link: https://github.com/cwein3/large-lr-code.

Conclusion

In this work, we show that the order in which a neural net learns to fit different types of patterns plays a crucial role in generalization. To demonstrate this, we construct a distribution on which models trained with large learning rates generalize provably better than those trained with small learning rates due to learning order. Our analysis reveals that more SGD noise, or larger learning rate, biases the model towards learning “generalizing” kernels rather than “memorizing” kernels. We confirm on artificially modified CIFAR-10 data that the scale of the learning rate can indeed influence learning order and generalization. Inspired by these findings, we propose a mitigation strategy that injects noise before the activations and works both theoretically for our construction and empirically. The design of better algorithms for regularizing learning order is an exciting question for future work.

Acknowledgements

CW acknowledges support from a NSF Graduate Research Fellowship.

References

Appendix A Basic Properties and Toolbox

In this section, we collect a few basic properties of the neural networks we are studying. In Section F, we provide two lemmas on Gaussian random variables and on matrix perturbation theory.

Let $[\nabla\widehat{L}(U)]_{i}$ be the $i$-th row of $\nabla\widehat{L}(U)$. We have that $\|[\nabla\widehat{L}(U)]_{i}\|_{2}\lesssim 1/\sqrt{m}$.

For any $t$, if $\gamma_{s}=\eta$ for every $s\leq t$, then we have that $\|[\overline{U}_{t}]_{i}\|_{2}\lesssim\min\{\frac{1}{\sqrt{m}\lambda},\eta t/\sqrt{m}\}$ and $\|\overline{U}_{t}\|_{F}\lesssim\frac{1}{\lambda}$.

By equation (3.4) and Proposition A.2, we have that

Hence $\|[\widetilde{U}x]_{+}\|_{2}\leq\|\widetilde{U}x\|_{2}\lesssim\tau\sqrt{m}\|x\|_{2}$. Now, since each $u_{i}$ is i.i.d. uniform on $\{-m^{-1/2},m^{-1/2}\}$, using the randomness of $u_{i}$ we know that w.h.p.

Here in the last inequality we applied Lemma A.8. The second statement follows from Proposition A.4 and triangle inequality. ∎

We have the following Rademacher complexity bound:

In this section, we collect useful statements which will help with decoupling the signal $\overline{U}$ from the noise $\widetilde{U}$ in our analysis. First, we observe that if the noise updates in the system stabilize at initialization, the marginal distribution of $U_{t}$ is always the same as the initialization.

Under Assumption 3.1, suppose we run Algorithm 1. Then for any $t$ before annealing the learning rate, $\widetilde{U}_{t}$ has marginal distribution $\mathcal{N}(0,\tau_{0}^{2}I_{m\times m}\otimes I_{d\times d})$. In other words, the entries of $\widetilde{U}_{t}$ are independent of each other and each follows $\mathcal{N}(0,\tau_{0}^{2})$.

One nice aspect of the signal-noise decomposition is as follows: we use tools from to show that if the signal term $\overline{U}$ is small, then using only the noise component $\widetilde{U}$ to compute the activations roughly preserves the output of the network. This facilitates our analysis of the network dynamics.

As we will often apply (A.11) with $\|\overline{U}\|_{F}\lesssim\frac{1}{\lambda}$, for notational simplicity we denote throughout the paper $\varepsilon_{s}=\left(\frac{1}{\lambda\tau_{0}}\right)^{4/3}m^{-1/3}$. By our choice of $m\geq\textup{poly}(d/\tau_{0})$ we know that $\varepsilon_{s}\leq d^{-\Theta(1)}$.
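To see how the over-parameterization requirement controls $\varepsilon_s$, note that $\varepsilon_s = (\lambda\tau_0)^{-4/3}m^{-1/3}$ falls below a target $\epsilon$ exactly when $m \geq (\lambda\tau_0)^{-4}\epsilon^{-3}$, which is polynomial in $1/\lambda$, $1/\tau_0$ and $1/\epsilon$. The short sketch below evaluates both expressions; the function names and the sample values of $\lambda$, $\tau_0$, $d$ are hypothetical.

```python
def eps_s(lam: float, tau0: float, m: float) -> float:
    """epsilon_s = (1 / (lam * tau0))**(4/3) * m**(-1/3)."""
    return (1.0 / (lam * tau0)) ** (4.0 / 3.0) * m ** (-1.0 / 3.0)


def min_width(lam: float, tau0: float, target: float) -> float:
    """Smallest m with eps_s(lam, tau0, m) <= target: m = (lam * tau0)**-4 / target**3."""
    return (1.0 / (lam * tau0)) ** 4 / target ** 3


lam, tau0, d = 0.1, 0.01, 100
target = d ** -0.5          # an illustrative d^{-Theta(1)} target
width = min_width(lam, tau0, target)
print(f"m >= {width:.3e}, eps_s = {eps_s(lam, tau0, width):.4f}")
```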

Appendix B Proof of Main Theorems

We start with the following lemma, which shows that if $g$ has small training error on $\bar{\mathcal{M}}_{1}$, then the output of $g$ on $x_{2}$ is large compared to $\|x_{2}\|$. This is because for the loss to be low, $g$ must have a good margin on $x_{2}$. However, as the norm of $x_{2}$ is roughly uniform over its range, the examples with small norm will force $g$ to have larger output.
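The argument relies on the positive homogeneity of a bias-free ReLU network: $g(\alpha x) = \alpha\, g(x)$ for every $\alpha > 0$. An example $x_2 = \alpha(z\pm\zeta)$ or $\alpha z$ with small $\alpha$ therefore receives only an $\alpha$-scaled output, so a good margin on it forces $|g|$ on the unscaled input to be large. The sketch below checks the identity numerically for a toy two-layer network; the weights and the helper name are illustrative assumptions.

```python
def two_layer_relu(u, V, x):
    """g(x) = sum_i u_i * relu(<V_i, x>), a bias-free two-layer ReLU network."""
    return sum(ui * max(0.0, sum(w * xj for w, xj in zip(row, x)))
               for ui, row in zip(u, V))


u = [1.0, -1.0, 0.5]
V = [[1.0, -2.0], [0.5, 1.0], [-1.0, 0.25]]
x = [0.3, 0.7]
for alpha in (0.1, 1.0, 2.5):
    # Positive homogeneity: g(alpha * x) == alpha * g(x) for alpha > 0.
    print(alpha, two_layer_relu(u, V, [alpha * xi for xi in x]), alpha * two_layer_relu(u, V, x))
```

Homogeneity holds only for positive scalings: for this network $g(-x) \neq -g(x)$, which is why the scalar $\alpha$ multiplying $x_2$ is taken to be positive.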

W.h.p. for every $t\geq 0$ and every $\delta\geq\frac{1}{\sqrt{qN}}$, as long as $\widehat{L}_{\bar{\mathcal{M}}_{1}}(g_{t_{0}+t})\leq\delta$, we have that: for every $(x,y)$,

We use $\bar{\mathcal{M}}_{1}^{(1)}$ to denote the set of all $x_{2}^{(i)}\in\bar{\mathcal{M}}_{1}$ such that $x_{2}^{(i)}=\alpha(z-\zeta)$. Similarly, we use $\bar{\mathcal{M}}_{1}^{(2)}$ to denote the set of all $x_{2}^{(i)}\in\bar{\mathcal{M}}_{1}$ such that $x_{2}^{(i)}=\alpha(z+\zeta)$, and use $\bar{\mathcal{M}}_{1}^{(3)}$ to denote the set of all $x_{2}^{(i)}\in\bar{\mathcal{M}}_{1}$ such that $x_{2}^{(i)}=\alpha z$.

Let $g_{t_{0}+t}(z+\zeta)=\rho_{1}$, $g_{t_{0}+t}(z-\zeta)=\rho_{2}$, $g_{t_{0}+t}(z)=\rho_{3}$. By the positive homogeneity of ReLU, we know that for every $x_{2}\in\bar{\mathcal{M}}_{1}^{(i)}$, it holds:

Since $\widehat{L}_{\bar{\mathcal{M}}_{1}}(g_{t_{0}+t})\leq\delta$, it holds that w.h.p. for every $i\in\{1,2,3\}$,

Our proof of Theorem 3.4 now amounts to carefully checking that all examples in $\mathcal{M}_{2}$ are classified correctly, and that the classifier $r_{t_{0}+t}$ will generalize well on $\bar{\mathcal{M}}_{2}$.

By Lemma 4.4, we know that for $t=\widetilde{O}\left(\frac{1}{\varepsilon_{1}^{3}\eta_{2}r}\right)$ we have $\widehat{L}_{\bar{\mathcal{M}}_{1}}(g_{t_{0}+t})=O(\sqrt{\varepsilon_{1}/q^{3}})$. Thus applying Lemma B.1, we obtain that as long as $\varepsilon_{1}\geq\frac{1}{\sqrt{N}}$ (which is implied by Assumption 3.3)

This implies that for $x_{1}\sim\mathcal{D}_{x_{1}}$, applying Lemma A.8 gives us

Moreover, applying Lemma A.6 on $r_{t_{0}+t}$ with $\|W_{t_{0}+t}\|_{F}^{2}\leq\|W_{t_{0}}\|_{F}^{2}+\|W_{t_{0}+t}-W_{t_{0}}\|_{F}^{2}\lesssim d\log^{2}\frac{1}{\varepsilon}$ by Lemma 4.2 and Lemma 4.3, we have that

where we used the fact that $\varepsilon_{1}\leq\kappa^{2}p^{2}q^{3}$.

Here the last step uses $\varepsilon_{1}\leq\kappa^{2}p^{2}q^{3}$, which holds by the definition of $\varepsilon_{1}$. ∎

B.2 Proof of Theorem 3.5

We will prove Theorem 3.5 using Lemma 5.3 by roughly arguing that the predictions made by $r_{t}$ will be heavily influenced by a vector $\alpha$ in the low rank span of examples from $\bar{\mathcal{M}}_{2}$. With high probability, this vector $\alpha$ will be noisy and not align well with the ground truth $w^{\star}$, leading to mispredictions.

Recall that $\varepsilon_{2}^{\prime}$ denotes the stopping criterion used in Theorem 3.5 and $\varepsilon_{3}=d^{-1/32}\frac{1}{\varepsilon_{2}^{\prime 2}}$. Using Lemma 5.3, we know that w.h.p.

By Lemma F.2 we know that w.h.p. over the randomness of the $x_{1}^{(i)}$'s, for $\alpha\in\text{span}\{x_{1}^{(i)}\}_{i\in\bar{\mathcal{M}}_{2}}$ we have, as long as $Np\leq d/2$, $\frac{\langle\alpha,w^{\star}\rangle}{\|\alpha\|_{2}\|w^{\star}\|_{2}}\leq 0.9$. For every randomly chosen $x_{1}$, we can also write $x_{1}=\gamma w^{\star}+\beta$ where $\beta\bot w^{\star}$, so $\beta$ is independent of $\gamma$; hence

Note that $\langle\alpha,\beta\rangle\sim\mathcal{N}(0,\sigma^{2}\|\alpha\|_{2}^{2}/d)$ with $\sigma\geq 0.1$, and with probability at least $0.1$, $\gamma\leq 2\|\alpha\|_{2}/\sqrt{d}$. This implies that with probability at least $\Omega(1)$ over a randomly chosen $x_{1}$ we can have:

For $\beta$, we know that with probability at least $\Omega(1)$, we have:

Moreover, since $\beta$ is independent of $\gamma$, we know that with probability $\Omega(1)$ both events can happen, in which case:

Thus, since $\|\alpha\|_{2}=\Omega(\sqrt{Np})$ by Lemma 5.3, we know that as long as

However, since $\langle w^{\star},x_{1}\rangle<0$, one of two cases occurs. If $r_{t}(x_{1})<0$, then $r_{t}(-x_{1})\leq r_{t}(x_{1})<0$ while $\langle w^{\star},-x_{1}\rangle>0$, so when $x_{2}=0$ the network classifies $(-x_{1},0)$ incorrectly. If instead $r_{t}(x_{1})>0$, the network classifies $(x_{1},0)$ incorrectly. Since the event that $\langle w^{\star},x_{1}\rangle<0$ and $r_{t}(x_{1})\geq r_{t}(-x_{1})$ holds with probability $\Omega(1)$, this shows that the test error is at least $\Omega(p)$. ∎
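The decomposition $x_1 = \gamma w^\star + \beta$ with $\beta \perp w^\star$ used in this proof is the usual orthogonal projection onto $w^\star$: $\gamma = \langle x_1, w^\star\rangle/\|w^\star\|_2^2$ and $\beta = x_1 - \gamma w^\star$. For Gaussian $x_1$ with spherical covariance, $\gamma$ and $\beta$ are jointly Gaussian and uncorrelated, hence independent. The sketch below computes the split and the cosine $\langle\alpha,w^\star\rangle/(\|\alpha\|_2\|w^\star\|_2)$ that Lemma F.2 bounds by $0.9$; the helper names and vectors are illustrative.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def split_along(x, w):
    """Return (gamma, beta) with x = gamma * w + beta and <beta, w> = 0."""
    gamma = dot(x, w) / dot(w, w)
    beta = [xi - gamma * wi for xi, wi in zip(x, w)]
    return gamma, beta


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


w_star = [1.0, 0.0, 1.0]
x1 = [1.0, 2.0, 3.0]
gamma, beta = split_along(x1, w_star)
print(gamma, beta, dot(beta, w_star))  # 2.0 [-1.0, 2.0, 1.0] 0.0
print(cosine(x1, w_star))
```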

Appendix C Proofs for Large Learning Rate Lemmas

To prove Lemma 4.1, we will show that the network will learn all examples with a $\mathcal{P}$ component while the learning rate is large. The key to the proof is that although the large learning rate noise only allows the network to search over coarse kernels, $\mathcal{P}$ is still learnable by these kernels because of its linearly-separable structure. To make this precise, we decompose the weights $U_{t}$ into the signal and noise components, and show that there exists a fixed “target” signal matrix which will classify $\mathcal{P}$ correctly no matter the noise matrix.

Recall our definitions of $f_{t}(B;x)$, $K_{t}(B)$ in (6.1) and (6.2), and that

Recall that Lemma 6.1 leverages the linearly-separable structure of $\mathcal{P}$ to find a “target” signal matrix that correctly classifies $\mathcal{P}$ w.h.p. over the noise matrix. We state its proof below.

By Proposition A.3, $\|\overline{U}_{t}\|_{F}\leq O\left(\frac{1}{\lambda}\right)$. We apply Lemma A.8 as follows: by Proposition A.7, each entry of $\widetilde{U}_{t}$ has marginal distribution $\mathcal{N}(0,\tau_{0}^{2})$, and therefore each column of $\widetilde{U}_{t}$ has distribution $\mathcal{N}(0,\tau_{0}^{2}I_{m\times m})$. Since w.h.p. $\|x\|_{2}\lesssim\sqrt{\log d}$, the coupling Lemma A.8 gives

On the other hand, by Proposition A.5, using the fact that $\max_{i}\|[\overline{U}_{t}]_{i}\|_{2}\lesssim\frac{1}{\sqrt{m}\lambda}$, we also have w.h.p.

Here in the last inequality we used the fact that the network is sufficiently over-parameterized so that $\varepsilon_{s}=\widetilde{O}(\tau_{0}\lambda)$.

Using (C.4), noting that our choice of $m,\lambda,\tau_{0}$ satisfies $\tau_{0}\log d=o(\varepsilon_{1})$, we conclude

For the term $N_{U_{t}}(u,U^{*};x)$, we know that

Note that the entries of $\widetilde{W}_{t}x_{1}$ are i.i.d. symmetric Gaussians, so their signs are i.i.d. Bernoulli($1/2$); thus we know that w.h.p.

Thus, by our choice that $m^{-1/3}=O(\varepsilon_{1})$ and $\sqrt{d}\varepsilon_{s}=O(\varepsilon_{1})$,

By definition of $w^{\star}$, we know that

Now we wish to argue that even though the noise matrix is changing, gradient descent will still find the fixed target signal matrix $U^{\star}$. This leverages the fact that once we fix the activation patterns, we can view each step of the optimization as gradient descent with respect to a convex, but changing, function. Below we provide a proof of Theorem 6.2, which allows for optimization of this changing function.
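The proof of Theorem 6.2 that follows imposes two conditions on the radius $R$ and the step size $\eta$: $\lambda R^2 \leq \mu/100$ and $\eta \leq \mu/(100(\lambda^2R^2 + L^2))$. The sketch below simply evaluates these two conditions; the function names and the numbers in the example are hypothetical, and it makes no claim about the theorem's convergence rate.

```python
def max_step_size(mu: float, lam: float, R: float, L: float) -> float:
    """Largest eta allowed by eta <= mu / (100 * (lam**2 * R**2 + L**2))."""
    return mu / (100.0 * (lam ** 2 * R ** 2 + L ** 2))


def conditions_hold(eta: float, mu: float, lam: float, R: float, L: float) -> bool:
    """Check both conditions: lam * R**2 <= mu / 100 and the step-size bound."""
    return lam * R ** 2 <= mu / 100.0 and eta <= max_step_size(mu, lam, R, L)


mu, lam, R, L = 0.1, 1e-4, 1.0, 1.0
print(max_step_size(mu, lam, R, L))          # slightly below 1e-3
print(conditions_hold(5e-4, mu, lam, R, L))  # True
print(conditions_hold(2e-3, mu, lam, R, L))  # False
```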

For the sake of contradiction, we assume that $K_{t}(z_{t})\geq c^{\star}+\mu$ for all $t\leq T$. Using the definition of $K_{t}^{\lambda}$, we have that the update rule of $z_{t}$ can be written as

Assuming that $\|z_{t}-z^{\star}\|_{2}\leq R$, we have that as long as $\lambda R^{2}\leq\frac{1}{100}\mu$ and $\eta\leq\frac{\mu}{100(\lambda^{2}R^{2}+L^{2})}$, we have:

C.2 Proof of Lemma 4.2

In this section, we will often consider the activation patterns on the inputs $z,z-\zeta,z+\zeta$ at various time steps. For convenience, we have the following definition:

For any $s$ and vector $w$, let $\mathcal{E}^{w}_{s}\triangleq\{i\in[m]:[\widetilde{V}_{s}]_{i}w\geq 0\}$ denote the set of neurons that have nonnegative pre-activation on the input $w$ (with weights $\widetilde{V}_{s}$), and let $\bar{\mathcal{E}}^{w}_{s}\triangleq\{i\in[m]:[\widetilde{V}_{s}]_{i}w<0\}$ be the set of neurons with negative pre-activations on the input $w$. (We will mostly be interested in the quantities $\mathcal{E}^{z-\zeta},\bar{\mathcal{E}}^{z-\zeta},\mathcal{E}^{z+\zeta},\bar{\mathcal{E}}^{z+\zeta}$ and their intersections.)
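For a concrete picture of the sets $\mathcal{E}^w_s$, the sketch below computes them for a small hand-picked matrix standing in for $\widetilde{V}_s$ (indices are 0-based here), together with the symmetric difference $\mathcal{E}^{z-\zeta}\oplus\mathcal{E}^{z+\zeta}$. It also checks a deterministic fact behind the next proposition: if neuron $i$ lies in the symmetric difference, then $|[\widetilde{V}_s]_i z| \leq |[\widetilde{V}_s]_i \zeta| \leq \|[\widetilde{V}_s]_i\|_2\, r$. The matrix, the vectors, and the helper names are illustrative assumptions.

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def active_set(V, w):
    """E^w = {i : <V_i, w> >= 0}: neurons with nonnegative pre-activation on w."""
    return {i for i, row in enumerate(V) if dot(row, w) >= 0.0}


z = [1.0, 0.0]
zeta = [0.0, 0.1]                      # orthogonal to z, ||zeta|| = r = 0.1
z_minus = [a - b for a, b in zip(z, zeta)]
z_plus = [a + b for a, b in zip(z, zeta)]
V = [[1.0, 0.0], [0.05, 1.0], [-1.0, 0.0], [0.0, -1.0]]

flipped = active_set(V, z_minus) ^ active_set(V, z_plus)   # symmetric difference
print(sorted(flipped))                                     # [1, 3]
# Neurons whose pattern differs on z - zeta and z + zeta satisfy |<V_i, z>| <= |<V_i, zeta>|.
assert all(abs(dot(V[i], z)) <= abs(dot(V[i], zeta)) for i in flipped)
```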

Let $Q_{t}\triangleq\operatorname{diag}(v)\overline{V}_{t}$. Then, we have that

Towards bounding the terms in equation (C.25), we will need to reason about the activation patterns of $z,z-\zeta,z+\zeta$ at various time steps. We first show that the activation patterns of $z-\zeta$ and $z+\zeta$ agree on all but roughly an $r$ fraction of the neurons (up to logarithmic factors). This will be useful to show that the second term of the RHS of equation (C.25) is small.

In the setting of Lemma C.2, w.h.p. over the randomness of the initialization and all the randomness in the algorithm, for every $t\leq\textup{poly}(d)$ and $i\in[m]$, $i\in\mathcal{E}^{z-\zeta}_{t}\oplus\mathcal{E}^{z+\zeta}_{t}$ implies that $|[\widetilde{V}_{t}]_{i}z|\lesssim\tau_{0}r\sqrt{\log d}$. Moreover, the size of the set $\mathcal{E}^{z-\zeta}_{t}\oplus\mathcal{E}^{z+\zeta}_{t}$ is bounded by

Recall that $\|\zeta\|_{2}=r$ and by Proposition A.7 $[\widetilde{V}_{t}]_{i}$ has distribution $\mathcal{N}(0,\tau_{0}^{2}I_{d\times d})$. Therefore, by standard Gaussian concentration and a union bound, with high probability over the randomness of the initialization and the algorithm, for all $t\leq\textup{poly}(d)$,

Moreover, note that $\Pr\left[|[\widetilde{V}_{t}]_{i}z|\leq\tau_{0}r\sqrt{\log d}\right]\lesssim r\sqrt{\log d}$. By the independence between the $[\widetilde{V}_{t}]_{i}$'s and standard concentration inequalities (Bernstein inequality), we have that with high probability, there are at most $rm\sqrt{\log d}+\log d$ entries $i\in[m]$ satisfying $|[\widetilde{V}_{t}]_{i}z|\leq\tau_{0}r\sqrt{\log d}$. Together with the first part of the lemma, and since $m$ is sufficiently large so that $rm\sqrt{\log d}+\log d\lesssim rm\sqrt{\log d}$, we complete the proof of equation (C.26). ∎
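The size bound on $\mathcal{E}^{z-\zeta}_{t}\oplus\mathcal{E}^{z+\zeta}_{t}$ is consistent with the classical fact that a spherical Gaussian vector $w$ separates two directions $a$ and $b$ (i.e., $\langle w,a\rangle$ and $\langle w,b\rangle$ have different signs) with probability $\theta(a,b)/\pi$, where $\theta$ is the angle between them. For unit $z$ and $\zeta\perp z$ with $\|\zeta\|_2=r$, the angle between $z-\zeta$ and $z+\zeta$ is $2\arctan r \leq 2r$, so the expected fraction of neurons in the symmetric difference is at most $2r/\pi$. The seeded Monte Carlo sketch below checks this in two dimensions; the dimension, the value of $r$, and the helper name are illustrative choices.

```python
import math
import random


def separated_fraction(r: float, n: int = 200_000, seed: int = 0) -> float:
    """Fraction of Gaussian w in R^2 with sign<w, z - zeta> != sign<w, z + zeta>,
    where z = (1, 0) and zeta = (0, r)."""
    rng = random.Random(seed)
    count = 0
    for _ in range(n):
        w0, w1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        count += (w0 - r * w1 >= 0.0) != (w0 + r * w1 >= 0.0)
    return count / n


r = 0.1
print(separated_fraction(r), 2.0 * math.atan(r) / math.pi, 2.0 * r / math.pi)
```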

We use the lemma above to conclude that the second term in the decomposition (C.25) is at most on the order of $r^{2}/\lambda$.

In the setting of Lemma C.2, we have that

By the definition of our algorithm, before annealing the learning rate, we have

Using Proposition A.3 and that $|v_{i}|=\frac{1}{\sqrt{m}}$, we have that $\|[Q_{t}]_{i}\|_{2}\lesssim\frac{1}{\lambda m}$. It follows that

The equation above and equation (C.30) complete the proof. ∎

The following lemma decomposes $Q$ into a sum of the contributions of the gradients from all the previous steps.

In the setting of Lemma C.2, let $\Delta Q_{t}\triangleq\operatorname{diag}(v)\nabla_{V}\widehat{L}(U_{t})$. ($\Delta Q_{t}$ can be viewed as the raw change of $Q_{t}$ at time step $t$ without considering the effect of the regularizer.) We have that

Using the fact that $\|z\|_{2}\leq 1$ we complete the proof. ∎

Define the analog of $\mathcal{E}_{s}^{w}$ that uses $V_{s}$ to compute the activation pattern: for any $s$ and vector $w$, let $\mathcal{G}^{w}_{s}\triangleq\{i\in[m]:[V_{s}]_{i}w\geq 0\}$ and define $\bar{\mathcal{G}}^{w}_{s}\triangleq\{i\in[m]:[V_{s}]_{i}w<0\}$ similarly.

Suppose at some iteration $s$, $z-\zeta$ and $z+\zeta$ have the same activation pattern at neurons $i$ and $j$, in the sense that $i,j\in\mathcal{G}^{z-\zeta}_{s}\cap\mathcal{G}^{z+\zeta}_{s}$ or $i,j\in\bar{\mathcal{G}}^{z-\zeta}_{s}\cap\bar{\mathcal{G}}^{z+\zeta}_{s}$. Then the corresponding gradient updates at that iteration for the weight vectors associated with $i$ and $j$ are the same up to a potential sign flip:

Moreover, suppose that $i,j$ satisfy $[\widetilde{V}_{s}]_{i}x\gtrsim\tau_{0}r\sqrt{\log d}$ and $[\widetilde{V}_{s}]_{j}x\gtrsim\tau_{0}r\sqrt{\log d}$ (or $[\widetilde{V}_{s}]_{i}x\lesssim-\tau_{0}r\sqrt{\log d}$ and $[\widetilde{V}_{s}]_{j}x\lesssim-\tau_{0}r\sqrt{\log d}$) for $x\in\{z-\zeta,z+\zeta\}$. Then the same conclusion holds for $i$ and $j$.

Note that by definition, $[\Delta Q_{s}]_{i}=v_{i}[\nabla_{V}\widehat{L}(U_{s})]_{i}$, and thus it suffices to prove that $v_{i}[\nabla_{V}\widehat{L}(U_{s})]_{i}=v_{j}[\nabla_{V}\widehat{L}(U_{s})]_{j}$. By Proposition A.1, we have that

Note that $x_{2}$ can only take (a positive scaling of) four values: $z-\zeta,z,z+\zeta,0$. We claim that for every choice of these four values, for the $i,j$ satisfying the condition of the lemma, we have

Note that the equation above together with $v_{i}^{2}=v_{j}^{2}=1/m$ suffices to complete the proof.

Now to prove the second part of the lemma, suppose $i,j$ satisfy $[\widetilde{V}_{s}]_{i}x\gtrsim\tau_{0}r\sqrt{\log d}$ and $[\widetilde{V}_{s}]_{j}x\gtrsim\tau_{0}r\sqrt{\log d}$ for $x\in\{z-\zeta,z+\zeta\}$. Using $\|[\overline{V}_{s}]_{i}\|_{2}\lesssim\frac{1}{\lambda\sqrt{m}}$ from Proposition A.3, we have that $[V_{s}]_{i}x\geq[\widetilde{V}_{s}]_{i}x-|[\overline{V}_{s}]_{i}x|\gtrsim\tau_{0}r\sqrt{\log d}-O(\frac{1}{\lambda\sqrt{m}})\gtrsim\tau_{0}r\sqrt{\log d}$, where we used the assumption that $1/\lambda=\textup{poly}(d)$ and $m=\textup{poly}(d/\tau_{0})$; the same holds for $j$. Therefore, we conclude that $i,j\in\mathcal{G}^{z-\zeta}_{s}\cap\mathcal{G}^{z+\zeta}_{s}$. Now by the first part of the lemma we complete the proof. ∎

Now we are ready to bound the first term on the RHS of equation (C.25), which is the crux of the proofs in this section. The key here is to get a bound that scales quadratically in $r$.

In the setting of Lemma C.2, let $\Delta Q_{s}$ be defined in Proposition C.5. Then, we have that

As a direct corollary of the equation above and Proposition C.5, we have that

By the set operations and the facts that $\mathcal{E}^{z-\zeta}_{t}\cap\mathcal{E}^{z+\zeta}_{t}\subset\mathcal{E}^{z}_{t}$ and that $\mathcal{E}^{z}_{t}\subset\mathcal{E}^{z-\zeta}_{t}\cup\mathcal{E}^{z+\zeta}_{t}$, we have that

where the $\lesssim,\gtrsim$ notations hide universal constants that make the first conclusion of Proposition C.3 true. By the second part of Proposition C.3 (or more directly equation (C.28)), we have that $\mathcal{F}^{+}_{s}\subset\mathcal{E}^{z-\zeta}_{s}\cap\mathcal{E}^{z+\zeta}_{s}$ and $\mathcal{F}^{-}_{s}\subset\bar{\mathcal{E}}^{z-\zeta}_{s}\cap\bar{\mathcal{E}}^{z+\zeta}_{s}$. By Proposition C.6, we have that for any $i,j\in\mathcal{F}^{-}_{s}$, $[\Delta Q_{s}]_{i}=[\Delta Q_{s}]_{j}$. For notational simplicity, let $A=\mathcal{E}^{z+\zeta}_{t}\backslash\mathcal{E}^{z}_{t}$ and $B=\mathcal{E}^{z}_{t}\backslash\mathcal{E}^{z-\zeta}_{t}$. Therefore it follows that

where in the last inequality we use that for any $i,j\in\mathcal{F}^{-}_{s}$, $[\Delta Q_{s}]_{i}=[\Delta Q_{s}]_{j}$, and the fact that $\|[\Delta Q_{s}]_{i}\|_{2}=\frac{1}{\sqrt{m}}\|[\nabla_{V}\widehat{L}(U_{s})]_{i}\|_{2}\leq 1/m$ (by Proposition A.2).

Note that the pairs $([\widetilde{V}_{s}]_{i},[\widetilde{V}_{t}]_{i})$ are independent across the choice of $i$. Thus we will compute $\Pr[i\in\mathcal{E}^{z+\zeta}_{t},i\notin\mathcal{E}^{z}_{t},i\in\mathcal{F}^{+}_{s}]-\Pr[i\in\mathcal{E}^{z}_{t},i\notin\mathcal{E}^{z-\zeta}_{t},i\in\mathcal{F}^{+}_{s}]$ and then apply a concentration inequality to the sum. Note that the event here depends on three quantities: $[\widetilde{V}_{s}]_{i}z$, $[\widetilde{V}_{t}]_{i}z$, and $[\widetilde{V}_{t}]_{i}\zeta$. First of all, $[\widetilde{V}_{t}]_{i}\zeta$ is independent of the other two because $\zeta$ is orthogonal to $z$ and $[\widetilde{V}_{t}]_{i}$ and $[\widetilde{V}_{s}]_{i}$ have spherical covariance matrices.

By the definition of $\widetilde{V}_{s},\widetilde{V}_{t}$, we can express their relationship by writing $[\widetilde{V}_{t}]_{i}z=(1-\eta_{1}\lambda)^{t-s}[\widetilde{V}_{s}]_{i}z+[\Xi_{t,s}]_{i}z$, where $\Xi_{t,s}=\eta_{1}\sum_{j\in[t-s]}(1-\eta_{1}\lambda)^{t-s-j}\xi_{s+j}$. Recall that by Proposition A.7, we have $[\widetilde{V}_{s}]_{i}z\sim\mathcal{N}(0,\tau_{0}^{2})$, and $[\widetilde{V}_{s}]_{i}z$ and $[\Xi_{t,s}]_{i}z$ are two independent Gaussians. Let $\sigma_{t,s}$ be the standard deviation of $[\Xi_{t,s}]_{i}z$. We compute $\sigma_{t,s}$ by observing that $[\widetilde{V}_{t}]_{i}z$ also has variance $\tau_{0}^{2}$ by Proposition A.7, so that $\sigma_{t,s}^{2}=\tau_{0}^{2}\left(1-(1-\eta_{1}\lambda)^{2(t-s)}\right)$.

Note that $\zeta^{\top}z=0$, thus $[\widetilde{V}_{s}]_{i}z$ is independent of $[\widetilde{V}_{t}]_{i}\zeta$ conditioned on $[\widetilde{V}_{t}]_{i}z$, for every $s\leq t$. For notational simplicity, let $Y_{1}=[\widetilde{V}_{s}]_{i}z$, $Y_{2}=[\widetilde{V}_{t}]_{i}z$, $Y_{3}=[\widetilde{V}_{t}]_{i}\zeta$, and $\kappa=O(\tau_{0}r\sqrt{\log d})$, where the big-O notation hides the same constant factor used in defining $\mathcal{F}^{+}_{s}$ in equation (C.40). Let $Y_{4}=[\Xi_{t,s}]_{i}z=Y_{2}-\beta Y_{1}$ where $\beta=(1-\eta_{1}\lambda)^{t-s}\gtrsim 1$ (because $t\leq 1/(\eta_{1}\lambda)$). Note that by the calculation above, $Y_{4}$ has standard deviation $\sigma_{t,s}$, which satisfies $\sigma_{t,s}\gtrsim\tau_{0}\sqrt{\lambda\eta_{1}(t-s)}$. Then, we have that

Now by equation (C.42) and standard concentration inequality, and the fact that mm is sufficiently large, we have that with high probability,

Using standard concentration inequality and the fact that mm is sufficiently large, we have that with high probability,

We can also prove the same bound for $|B\cap\mathcal{F}^{c}_{s}|$ analogously. Using equation (C.41) and the several equations above, we conclude that

where the last step uses the condition that $t\leq 1/(\eta_{1}\lambda)$.
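The variance computation for $Y_4=[\Xi_{t,s}]_i z$ can be checked numerically. Assume, as a hypothetical model consistent with Proposition A.7, that for unit $z$ each $[\xi_{s+j}]_i z$ has variance $\sigma_\xi^2$ and that the stationary variance is $\tau_0^2 = \eta_1^2\sigma_\xi^2/(1-(1-\eta_1\lambda)^2)$. Then the geometric sum defining $\Xi_{t,s}$ has variance $\tau_0^2(1-(1-\eta_1\lambda)^{2(t-s)})$, which for $t-s \leq 1/(\eta_1\lambda)$ is at least $(1-e^{-2})\,\eta_1\lambda(t-s)\,\tau_0^2$ by concavity of $y\mapsto 1-e^{-2y}$. The sketch below compares the sum with the closed form and checks that lower bound; $\sigma_\xi$, the sample values, and the helper names are assumptions, not quantities defined in this section.

```python
import math


def xi_variance_sum(eta, lam, sigma_xi, k):
    """Var of eta * sum_{j=1..k} (1 - eta*lam)**(k-j) * xi_j for i.i.d. xi_j ~ N(0, sigma_xi^2)."""
    c = 1.0 - eta * lam
    return sum((eta * c ** (k - j)) ** 2 * sigma_xi ** 2 for j in range(1, k + 1))


def stationary_var(eta, lam, sigma_xi):
    """Fixed point tau0^2 of v -> (1 - eta*lam)**2 * v + eta**2 * sigma_xi**2."""
    c = 1.0 - eta * lam
    return eta ** 2 * sigma_xi ** 2 / (1.0 - c ** 2)


def xi_variance_closed(eta, lam, tau0_sq, k):
    """tau0^2 * (1 - (1 - eta*lam)**(2k))."""
    return tau0_sq * (1.0 - (1.0 - eta * lam) ** (2 * k))


eta, lam, sigma_xi = 0.1, 0.01, 1.0      # illustrative values; 1/(eta*lam) = 1000
tau0_sq = stationary_var(eta, lam, sigma_xi)
for k in (1, 10, 100, 1000):
    s, c = xi_variance_sum(eta, lam, sigma_xi, k), xi_variance_closed(eta, lam, tau0_sq, k)
    lower = (1.0 - math.exp(-2.0)) * eta * lam * k * tau0_sq
    print(k, s, c, s >= lower)
```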

Now combining the Propositions above we are ready to prove Lemma 4.2.

Using the triangle inequality, Lemma A.8, and equation (A.6) of Proposition A.5, we have that for any $x$ of norm $O(1)$,

C.3 Proof of Lemma 6.3

The proof of Lemma 6.3 relies on the fact that a smaller learning rate preserves the noise generated from the timestep before annealing. This allows us to reason that the new activations are similar to the original before reducing the learning rate.

where $\Xi_{t}:=\eta_{2}\sum_{j\leq t}(1-\lambda\eta_{2})^{t-j}\xi_{t_{0}+j}$. By properties of a sum of independent Gaussians, we have $[\Xi_{t}]_{i}\sim\mathcal{N}(0,\sigma_{t}^{2}I)$, where $\sigma_{t}$ is the standard deviation of each entry of $\Xi_{t}$. We also have that $\Xi_{t}$ is independent of $\widetilde{U}_{t_{0}}$. Moreover, for every $t\leq\frac{1}{\eta_{2}\lambda}$, the standard deviation $\sigma_{t}$ can be bounded by

(Note that since $\eta_{2}\ll\eta_{1}$, we should expect that the standard deviations satisfy $\sigma_{t}\ll\sigma_{0}$. That is, the additional randomness introduced in the pre-activation is small.)

On the other hand, for every $t\leq\frac{1}{\eta_{2}\lambda}$, the contribution of $\widetilde{U}_{t_{0}}$ to $U_{t+t_{0}}$ is still present, because each entry of $(1-\eta_{2}\lambda)^{t}[\widetilde{U}_{t_{0}}]_{i}$ has variance at least on the order of the variance of the entries of $[\widetilde{U}_{t_{0}}]_{i}$, which is $\gtrsim\tau_{0}^{2}$. This also implies that the variance of the entries of $\widetilde{U}_{t_{0}+t}$ is lower bounded by the variance of the entries of $(1-\eta_{2}\lambda)^{t}[\widetilde{U}_{t_{0}}]_{i}$, which in turn is lower bounded by $\tau_{0}^{2}$ up to a constant factor.

Therefore, using the decomposition (C.59) and the bounds above, we should expect that the sign of $U_{t_{0}+t}$ strongly correlates with the sign of $U_{t_{0}}$, which we show formally below. Using Lemma A.8, we have that the activation pattern is mostly decided by the noise parts ($\widetilde{U}_{t+t_{0}}$ and $\widetilde{U}_{t_{0}}$), in the sense that for every $x$,

Fixing xx, we can decompose our target to

We have bounded the first and third terms on the RHS of the equation above. For the middle term, let $\alpha_{i}=(1-\eta_{2}\lambda)^{t}[\widetilde{U}_{t_{0}}]_{i}x$ and $\beta_{i}=[\Xi_{t+t_{0}}]_{i}x$. Note that $[\widetilde{U}_{t+t_{0}}]_{i}x=\alpha_{i}+\beta_{i}$ and that $\alpha_{i}$ and $\beta_{i}$ are zero-mean independent Gaussian random variables with variance $\gtrsim\tau_{0}^{2}\|x\|^{2}$ and variance $\lesssim\eta_{2}\tau_{0}^{2}\|x\|^{2}/\eta_{1}$, respectively. The basic property of Gaussian random variables implies that

Since the $\alpha_{i},\beta_{i}$'s are independent, by a basic concentration inequality (e.g., Bernstein inequality or Hoeffding inequality), we have that with high probability

Combining the equation above with equations (C.61), (C.62), and (C.64) completes the proof for the first part.
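The sign comparison between $\alpha_i+\beta_i$ and $\alpha_i$ can be understood via a standard Gaussian fact: for independent zero-mean Gaussians with standard deviations $s_\alpha$ and $s_\beta$, $\Pr[\mathrm{sign}(\alpha_i+\beta_i)\neq\mathrm{sign}(\alpha_i)] = \arctan(s_\beta/s_\alpha)/\pi \leq s_\beta/(\pi s_\alpha)$. With $s_\beta/s_\alpha \lesssim \sqrt{\eta_2/\eta_1}$, only a small fraction of neurons change sign after annealing. The sketch below compares the formula with a seeded Monte Carlo estimate; the numbers are illustrative.

```python
import math
import random


def sign_flip_prob(s_alpha: float, s_beta: float) -> float:
    """P[sign(alpha + beta) != sign(alpha)] for independent N(0, s_alpha^2), N(0, s_beta^2)."""
    return math.atan(s_beta / s_alpha) / math.pi


def sign_flip_monte_carlo(s_alpha, s_beta, n=200_000, seed=0):
    rng = random.Random(seed)
    flips = 0
    for _ in range(n):
        a = rng.gauss(0.0, s_alpha)
        b = rng.gauss(0.0, s_beta)
        flips += (a + b >= 0.0) != (a >= 0.0)
    return flips / n


s_alpha, s_beta = 1.0, 0.1            # e.g. s_beta / s_alpha ~ sqrt(eta2 / eta1)
print(sign_flip_prob(s_alpha, s_beta), sign_flip_monte_carlo(s_alpha, s_beta))
```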

where the last inequality is due to $\max_{i}\|[\overline{U}_{t_{0}+t}]_{i}\|_{2}=O(1/(\sqrt{m}\lambda))$ by Proposition A.3, and bounding $\left|N_{U_{t_{0}+t}}(u,\widetilde{U}_{t_{0}+t};x)\right|\lesssim\frac{\varepsilon_{s}}{\lambda}+\tau_{0}\log d$ by Proposition A.5. ∎

We note that this lemma also applies to the setting when $t_{0}=0$, i.e., we start with an initial small learning rate and compare to the random initialization. This is useful for the proofs in the small initial learning rate setting.

C.4 Proof of Lemma 4.3

We will now show that the network learns patterns from $\mathcal{Q}$ once the learning rate is annealed by constructing a common target for the network at every subsequent time step. We will then use Theorem 6.2 to show that the optimization finds this target. Let us define

Formally, we first show the following proposition, which proves the existence of a target solution that has good accuracy on $\bar{\mathcal{M}}_{1}$ and does not unlearn the network’s progress on $\mathcal{M}_{1}$:

In the setting of Lemma 4.3, let $K_{t}(B)$ be defined in equation (6.2). Then, there exists a solution $U^{*}$ satisfying $\|U^{*}\|_{F}^{2}=\widetilde{O}\left(\frac{1}{\varepsilon_{1}^{2}r}\right)$ and

To prove this proposition, we need the following lemma:

Suppose $g_{t}$ satisfies $\left|g_{t}(z+\zeta)+g_{t}(z-\zeta)-2g_{t}(z)\right|\leq\delta$ for some $\delta\lesssim 1$. Then, we have that

And moreover, if $\widehat{L}_{\bar{\mathcal{M}}_{1}}(u,U)\leq\log 2+O(\delta^{\prime})$ for some $\delta^{\prime}\geq\delta$, then the predictions of $g_{t}$ on $z-\zeta,z,z+\zeta$ satisfy $|g_{t}(z-\zeta)|,|g_{t}(z+\zeta)|,|g_{t}(z)|=O(\sqrt{\delta^{\prime}+\log d/\sqrt{qN}})$.

For convenience, let us denote $g_{t}(z+\zeta)=u$, $g_{t}(z-\zeta)=v$, $g_{t}(z)=(u+v)/2+\gamma$. By our assumption, we have that $|\gamma|\leq\delta$.

Let $h(z):=-\log\frac{1}{1+e^{-z}}$. We have that w.h.p., for $c=O(\log d/\sqrt{qN})$,

and the factor of $1-c$ comes from the fact that the fractions of examples that are $z-\zeta$, $z+\zeta$, $z$ will be $1/4\pm O(\log d/\sqrt{qN})$, $1/4\pm O(\log d/\sqrt{qN})$, $1/2\pm O(\log d/\sqrt{qN})$, respectively, w.h.p. Since the function $h(z)$ is 2-Lipschitz (indeed 1-Lipschitz), we know that

The equation above together with the assumption $\widehat{L}_{\bar{\mathcal{M}}_{1}}(u,U)\leq\log 2+O(\delta^{\prime})$ implies that

which implies that $h((u+v)/2)+h(-(u+v)/2)-2h(0)+\Delta\leq O(\delta^{\prime})+O(c)$. It follows that $h((u+v)/2)+h(-(u+v)/2)-2h(0)\leq O(\delta^{\prime})+O(c)$ and $\Delta\leq O(\delta^{\prime})+O(c)$. Now, by the strict convexity of $h(z)$, we can easily conclude that $|u|,|v|\leq O(\sqrt{\delta^{\prime}+c})$. ∎
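The strict-convexity step can be made quantitative. For the logistic loss $h(a)=\log(1+e^{-a})$ one has $h(a)+h(-a)-2h(0) = 2\log\cosh(a/2)$, which is at least $a^2/5$ when $|a|\leq 1$ and increases in $|a|$; so if this gap is at most $\epsilon \leq 1/5$, then $|a| \leq \sqrt{5\epsilon}$, matching the $O(\sqrt{\delta'+c})$ form of the conclusion. The sketch below checks the identity and the bound numerically; the constant $1/5$ is a convenient choice made here, not one taken from the paper.

```python
import math


def h(a: float) -> float:
    """Logistic loss h(a) = -log(1 / (1 + exp(-a))) = log(1 + exp(-a)), computed stably."""
    if a >= 0.0:
        return math.log1p(math.exp(-a))
    return -a + math.log1p(math.exp(a))


def curvature_gap(a: float) -> float:
    """h(a) + h(-a) - 2 h(0), which is >= 0 by convexity and = 0 only at a = 0."""
    return h(a) + h(-a) - 2.0 * h(0.0)


for a in (0.01, 0.1, 0.5, 1.0):
    gap = curvature_gap(a)
    print(a, gap, 2.0 * math.log(math.cosh(a / 2.0)), gap >= a * a / 5.0)
```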

Next, we will bound $\varepsilon_{0}$ and the value of $g_{t_{0}}$. This allows us to conclude that $g_{t_{0}}$ is small, so that it is easy to “unlearn” once the learning rate is annealed.

Suppose the condition in Lemma 4.1 holds. Then

Since $L_{t_{0}}\leq q\log 2+\varepsilon_{1}$, we know that $\widehat{L}_{\bar{\mathcal{M}}_{1}}(u,U_{t_{0}})\leq\log 2+2\varepsilon_{1}/q$. Applying Proposition C.9 with $\delta^{\prime}=\varepsilon_{1}$ and $\delta=O(r^{2}/\lambda)=O(\varepsilon_{1})$, we have that $|g_{t_{0}}(z)|,|g_{t_{0}}(z+\zeta)|,|g_{t_{0}}(z-\zeta)|\leq O(\sqrt{\varepsilon_{1}/q})$ and $\widehat{L}_{\bar{\mathcal{M}}_{1}}(u,U_{t_{0}})\geq\log 2-\varepsilon_{1}$.

Now we will complete the proof of Proposition C.8.

Let us define sets $\mathcal{E}_{1},\mathcal{E}_{2},\mathcal{E}_{3}$ as follows:

for some sufficiently large universal constant $c$.

Note that the random noise vector $[\widetilde{V}_{t_{0}}]_{i}$ will satisfy the condition for set $\mathcal{E}_{i}$ with probability proportional to the angle between $z-\zeta$ and $z$, which is $r\pm O(r^{2})$ by Taylor approximation of $\arcsin$. Thus, as $V_{t_{0}}$ and $\widetilde{V}_{t_{0}}$ differ in at most $\varepsilon_{s}m$ activations, w.h.p. $|\mathcal{E}_{1}|,|\mathcal{E}_{2}|,|\mathcal{E}_{3}|=\frac{1}{2\pi}rm\pm\widetilde{O}\left(r^{2}m+\sqrt{m}\right)\pm\varepsilon_{s}m$. This implies that

Hence we can also easily conclude that for every $x_{2}\in\{\alpha(z-\zeta),\alpha z,\alpha(z+\zeta)\}$,

On the other hand we have that by Lemma C.10, it holds that

Now the first term satisfies $|N_{V_{t_{0}}}(v,\overline{V}_{t_{0}};x_{2})|=O(1)$, and the second term is bounded by

using Proposition A.3 to upper bound $\|[\overline{V}_{t_{0}}]_{i}\|_{2}$. Thus, it follows that $|yN_{V_{t_{0}+t}}(v,\overline{V}_{t_{0}};x_{2})|=O(1)$.

It follows that for every $x_{2}\in\{z-\zeta,z,z+\zeta\}$ and its corresponding label $y$, as long as $\|x_{2}\|_{2}\geq\varepsilon_{1}$,

The last inequality follows from our choice of parameters such that $\tau_{0}\log d\leq q\varepsilon_{1}$. Putting together Eq (C.105) and (C.107) and defining $U^{*}=(0,V^{*})$, we have that

By Proposition C.8, there exists $V^{*}$ with $\|V^{*}\|_{F}^{2}\leq\widetilde{O}\left(\frac{1}{r\varepsilon_{1}^{2}}\right)$ such that for every $t\leq\frac{1}{\eta_{2}\lambda}$,

By Theorem 6.2, with $z^{*}=(\overline{W}_{t_{0}},V^{*})$, starting from $z_{0}=(\overline{W}_{t_{0}},\overline{V}_{t_{0}})$, we can take $R^{2}=\widetilde{O}\left(\frac{1}{r\varepsilon_{1}^{2}}\right)$, $L=1$, $\mu=\varepsilon_{1}$ to conclude that the algorithm converges to $\varepsilon_{0}+2\varepsilon_{1}$ in $\widetilde{O}\left(\frac{1}{\eta_{2}r\varepsilon_{1}^{3}}\right)$ iterations. Applying Lemma C.10 to bound $\varepsilon_{0}$ completes the proof. ∎

C.5 Proof of Lemma 4.4

By the 1-Lipschitzness of logistic loss, we know that

To bound this term, we can directly use Cauchy-Schwarz and obtain that:

We can further bound rt0+t(x1(i))rt0(x1(i))r_{t_{0}+t}(x_{1}^{(i)})-r_{t_{0}}(x_{1}^{(i)}) by applying Lemma 6.3, as from our choice of parameters η2η1ε14λ2,εs/λε12\eta_{2}\leq\eta_{1}\varepsilon_{1}^{4}\lambda^{2},\varepsilon_{s}/\lambda\leq\varepsilon_{1}^{2}, τ0logdε12\tau_{0}\log d\leq\varepsilon_{1}^{2}:

Now, let us denote by $X=(x^{(i)})_{i\in[N]}$ the data matrix. By the standard spectral norm bound for Gaussian matrices, we know that w.h.p. $\|X\|_{2}^{2}\leq 10\frac{N}{d}$.
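The spectral norm claim can be sanity-checked by simulation. The sketch below assumes, hypothetically, that the columns of $X$ have i.i.d. $\mathcal{N}(0,1/d)$ entries and that $N\geq d$, the regime in which $\|X\|_{2}^{2}$ concentrates around $(\sqrt{N/d}+1)^{2}\leq 4N/d$. The sizes and number of trials are arbitrary.

```python
import numpy as np

def max_sq_spectral_ratio(n_samples, dim, trials=5, seed=0):
    """Largest observed ||X||_2^2 / (N/d) over a few random draws.

    Hypothetical setup: X has shape (dim, n_samples) with i.i.d. N(0, 1/dim) entries.
    """
    rng = np.random.default_rng(seed)
    worst = 0.0
    for _ in range(trials):
        X = rng.normal(scale=1.0 / np.sqrt(dim), size=(dim, n_samples))
        sq_norm = np.linalg.norm(X, ord=2) ** 2   # squared largest singular value
        worst = max(worst, sq_norm / (n_samples / dim))
    return worst

ratio = max_sq_spectral_ratio(n_samples=2000, dim=200)
print(f"max ||X||_2^2 / (N/d) = {ratio:.2f}")   # well below 10 in this regime
```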

By (C.113) and our definition of $\varepsilon_{0}$ as

Using the bound $\varepsilon_{0}=O(\sqrt{\varepsilon_{1}/q})$ from Lemma C.10, we conclude the bound on $\widehat{L}_{\mathcal{M}_{1}}(r_{t_{0}+t})$.

Finally, by $\widehat{L}_{\bar{\mathcal{M}}_{1}}(g_{t_{0}+t})\leq\widehat{L}_{t_{0}+t}$ and the assumption that $\widehat{L}_{t_{0}+t}\leq O(\sqrt{\varepsilon_{1}/q})$, it must hold that (since $|\bar{\mathcal{M}}_{1}|=qN$)

Appendix D Proofs for Small Learning Rate

In the setting of Theorem 3.5, there exists a solution $U^{\star}$ satisfying a) $\|U^{\star}\|_{F}^{2}\leq\widetilde{O}(\frac{1}{\varepsilon_{2}^{\prime 2}r}+Np)$ and b) for every $t\leq\frac{1}{\eta_{2}\lambda}$,

To prove Lemma 5.1, we can apply an analysis identical to that of Lemma 4.3 to show that for $t^{\prime}=\widetilde{O}\left(\frac{1}{\eta_{2}\varepsilon_{2}^{\prime 3}r}\right)$, $\widehat{L}_{\mathcal{M}_{2}}(U_{t^{\prime}})\leq\varepsilon_{2}^{\prime}$. The rest of the proof follows from combining Theorem 6.2 and Lemma D.1.

D.2 Proof of Lemma 5.2

Recall the expression $\rho_{t}$ defined in (6.10). We first prove Lemma 6.4, which says that if $\rho_{t}$ is large (which means the loss is large as well), then the total gradient norm must be large.

For notational simplicity, let us fix $t$ and let

The gradient with respect to $V$ can be computed by

Let us denote the sets $\mathcal{S}_{2,1}^{(\alpha_{0})},\mathcal{S}_{2,2}^{(\alpha_{0})},\mathcal{S}_{2,3}^{(\alpha_{0})}$ as:

i.e., the loss gradient using activations computed by the noise component of $V_{t}$, scaled by a factor of $Nmv_{k}$.

By the Geometry of ReLU Lemma D.2, we have that w.h.p.

where the last inequality holds since for every $j\in\mathcal{S}_{2,j^{\prime}}^{(0)}$, $Q_{j}$ has the same sign.

Note that $\alpha_{j}\sim U(0,1)$, and therefore for every fixed $\alpha_{0}\geq\frac{1}{\sqrt{N}}$, w.h.p. there are $O(N\alpha_{0})$ many $\alpha_{j}$ such that $\alpha_{j}\leq\alpha_{0}$. For each of them, we also know that $|Q_{j}|\leq 1$, which implies that

Picking $\alpha_{0}=\Theta(\sqrt{\rho_{t}})$, we complete the proof by our choice of $m\geq N^{10}\frac{1}{(\lambda\tau_{0})^{4}}$. ∎
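The counting step in the proof above (w.h.p. only $O(N\alpha_{0})$ of the $\alpha_{j}\sim U(0,1)$ fall below $\alpha_{0}$ once $\alpha_{0}\geq 1/\sqrt{N}$) can be checked by simulation. This is a minimal sketch; $N$ and the $\alpha_{0}$ values are arbitrary illustrative choices, and `count_at_most` is a hypothetical helper.

```python
import numpy as np

def count_at_most(alphas, alpha0):
    """Number of alpha_j with alpha_j <= alpha0."""
    return int(np.sum(alphas <= alpha0))

rng = np.random.default_rng(0)
N = 10_000
alphas = rng.uniform(0.0, 1.0, size=N)   # alpha_j ~ U(0, 1), i.i.d.

counts = {}
for alpha0 in (1 / np.sqrt(N), 0.05, 0.2):   # all satisfy alpha0 >= 1/sqrt(N)
    counts[alpha0] = count_at_most(alphas, alpha0)
    print(f"alpha0={alpha0:.3f}  count={counts[alpha0]}  N*alpha0={N * alpha0:.0f}")
```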

Now we prove Proposition 6.5, which bounds the number of iterations in which $\rho_{t}$ can be large.

Consider the function $\mathcal{F}_{s}(x):=N_{U_{0}}(u,\overline{U}_{s};x)$, and let us define $\mathcal{G}_{s+1}(x):=N_{U_{0}}(u,\overline{U}_{s}-\frac{\eta_{2}}{1-\eta_{2}\lambda}\nabla\widehat{L}(U_{s});x)$. Since $\overline{U}_{s+1}=(1-\eta_{2}\lambda)\overline{U}_{s}-\eta_{2}\nabla\widehat{L}(U_{s})$, we have that

Now, by standard gradient descent analysis, we have that (as the logistic loss has Lipschitz derivative and the data have bounded norm):

Next, we bound $\|\nabla\widehat{L}(U_{s})-\nabla\widehat{L}(\mathcal{F}_{s})\|_{F}$. We can compute

Plugging this back into (D.29), by the coupling Lemma 6.3 we obtain the bound

This implies that for $\eta_{2}\lambda<0.1$,

which implies that for every $t\leq\frac{1}{\eta_{2}\lambda}$, as long as $\eta_{2},\varepsilon_{c}=O(\lambda)$, we have:

By Lemma 6.4, if $\rho_{t}\geq{\varepsilon_{2}^{\prime}}^{2}\varepsilon_{3}^{2}$, then $\|\nabla\widehat{L}(U_{s})\|_{F}^{2}\geq r{\varepsilon_{2}^{\prime}}^{8}\varepsilon_{3}^{8}$. It follows that there are at most $O(\frac{1}{r{\varepsilon_{2}^{\prime}}^{8}\varepsilon_{3}^{8}\eta_{2}})$ such $t$.
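The step from a per-iteration gradient lower bound to a bound on the number of such iterations follows the standard descent-lemma argument: for a nonnegative objective $f$ with $L$-Lipschitz gradient and step size $\eta\leq 1/L$, each step decreases $f$ by at least $\frac{\eta}{2}\|\nabla f\|^{2}$, so at most $2f(x_{0})/(\eta g)$ steps can have $\|\nabla f\|^{2}\geq g$. The sketch below checks this on plain (unregularized) logistic regression with synthetic data. It illustrates the counting idea only and does not reproduce the regularized, coupled analysis above. The data, threshold, and step count are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 10
X = rng.normal(size=(n, d))
y = np.sign(X @ rng.normal(size=d))           # labels in {-1, +1}

def loss_and_grad(w):
    """Mean logistic loss log(1 + exp(-y <w, x>)) and its gradient."""
    margins = y * (X @ w)
    loss = np.mean(np.logaddexp(0.0, -margins))
    sig_neg = 0.5 * (1.0 - np.tanh(margins / 2.0))   # sigmoid(-margin), computed stably
    grad = X.T @ (-y * sig_neg) / n
    return loss, grad

smoothness = np.linalg.norm(X, 2) ** 2 / (4 * n)   # Lipschitz constant of the gradient
eta = 1.0 / smoothness
threshold = 1e-3                                   # "large gradient" cutoff on ||grad||^2

w = np.zeros(d)
loss0, _ = loss_and_grad(w)
large_steps = 0
for _ in range(5_000):
    _, g = loss_and_grad(w)
    if g @ g >= threshold:
        large_steps += 1
    w -= eta * g

bound = 2 * loss0 / (eta * threshold)   # descent-lemma bound on the count
print(large_steps, bound)
```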

Finally, we complete the proof of Lemma 5.2 by noting that $\rho_{t}$ cannot be large for many iterations, and therefore $W_{t}$ will not obtain much signal from the $\mathcal{P}$ component of examples in $\mathcal{M}_{2}$.

The last line follows from the spectral norm bound on the matrix $X$. Let $\mathcal{T}$ be defined as in Proposition 6.5. It follows that

Now the conclusion of the lemma follows from the assumption that $t=O(d/(\eta_{2}\varepsilon_{2}^{\prime}))$ and our choices of $\eta_{2}$ and $\frac{1}{\varepsilon_{2}^{\prime 8}\varepsilon_{3}^{8}r}\leq\varepsilon_{2}^{\prime}d$ in Theorem 3.5.

D.3 Proof of Lemma 5.3

We now prove the decomposition lemma for $\overline{W}_{t}$, Lemma 6.6. Recall our definition of $\overline{W}_{t}^{(2)}$ as

For each step, we know that for every $j\in[m]$,

Thus, multiplying by $\eta_{2}(1-\eta_{2}\lambda)^{t-s}$ and summing, following our definition of $\overline{W}_{t}^{(2)}$ in (5.1), we get

Now we focus on bounding the bottom term. We can see that

By the Auxiliary Coupling Lemma 6.3 with $t_{0}=0$, we know that for $s\leq\frac{1}{\eta_{2}\lambda}$, w.h.p.

for some real values $\{\alpha_{k}\}_{k\in\bar{\mathcal{M}}_{2}}$ with

By the above calculation, (D.47), and Lemma 5.2, we have:

where the last inequality follows from our choice of parameters. ∎
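The unrolling step in the proof above (multiplying each per-step update by $\eta_{2}(1-\eta_{2}\lambda)^{t-s}$ and summing) is the closed form of a gradient step with weight decay. This is a minimal numerical check in which random matrices stand in for the per-step gradients. With the 0-indexed convention used here, the exponent is $t-1-s$; the offset depends on how steps are indexed.

```python
import numpy as np

rng = np.random.default_rng(0)
eta, lam, T = 0.05, 0.2, 30
W0 = rng.normal(size=(4, 3))
grads = [rng.normal(size=(4, 3)) for _ in range(T)]   # stand-ins for per-step gradients

# Iterate the weight-decay step W <- (1 - eta*lam) W - eta * g directly.
W = W0.copy()
for g in grads:
    W = (1 - eta * lam) * W - eta * g

# Closed form: decayed initialization plus a geometrically weighted sum of gradients.
decay = 1 - eta * lam
W_closed = decay**T * W0 - eta * sum(decay ** (T - 1 - s) * grads[s] for s in range(T))
print(np.max(np.abs(W - W_closed)))   # floating-point error only
```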

Using the decomposition lemma, the conclusion of Lemma 5.3 now follows via computation.

Let us define the function $\mathfrak{U}$ as:

Note that $\mathfrak{U}$ is a kernel prediction function. Since each $[W_{0}]_{j}$ is distributed as a vector of i.i.d. spherical Gaussians, we know that for fixed $x_{1}^{(k)},x_{1}$:

In the above equation, $\Theta(x_{1}^{(k)},x_{1}^{(i)})$ is the principal angle between $x_{1}^{(k)}$ and $x_{1}^{(i)}$. Since the $[W_{0}]_{j}$ are i.i.d., by basic concentration bounds we know that w.h.p.

The last inequality uses the fact that w.h.p., for $k\neq i$, $\frac{\langle x_{1}^{(k)},x_{1}^{(i)}\rangle}{\|x_{1}^{(k)}\|_{2}\|x_{1}^{(i)}\|_{2}}=\widetilde{O}(d^{-1/2})$.
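Two standard facts about Gaussian random vectors drive this computation. The display following "for fixed $x_{1}^{(k)},x_{1}$" is missing from this copy, so it is not reproduced here. Instead, the sketch checks (1) the standard identity that for $w\sim\mathcal{N}(0,I_{d})$ the ReLU gates of $x$ and $x'$ are both open with probability $(\pi-\Theta)/(2\pi)$, which is how an $\arccos$ of the principal angle enters kernels of this type, and (2) the fact used in the last inequality, that distinct random Gaussian vectors have cosine similarity of order $d^{-1/2}$ up to log factors. Dimensions, sample sizes, and the correlated pair are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# (1) Co-activation probability of a ReLU gate under a spherical Gaussian weight.
d, n_w = 50, 100_000
x = rng.normal(size=d)
x_prime = 0.6 * x + rng.normal(size=d)            # correlated pair, angle well below 90 degrees
cos = x @ x_prime / (np.linalg.norm(x) * np.linalg.norm(x_prime))
theta = np.arccos(np.clip(cos, -1.0, 1.0))        # principal angle
W = rng.normal(size=(n_w, d))                     # i.i.d. rows ~ N(0, I_d)
empirical = np.mean((W @ x >= 0) & (W @ x_prime >= 0))
predicted = (np.pi - theta) / (2 * np.pi)
print(f"co-activation: empirical={empirical:.4f}, (pi - theta)/(2 pi)={predicted:.4f}")

# (2) Near-orthogonality: max |cosine| between distinct random vectors scales like d^{-1/2}.
n = 50
scaled_max = {}
for dim in (100, 1_000, 10_000):
    V = rng.normal(size=(n, dim))
    V /= np.linalg.norm(V, axis=1, keepdims=True)
    off_diag = np.abs((V @ V.T)[~np.eye(n, dtype=bool)])
    scaled_max[dim] = off_diag.max() * np.sqrt(dim)   # roughly constant across dims
    print(f"d={dim}: max|cos| * sqrt(d) = {scaled_max[dim]:.2f}")
```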

Let us define $\alpha=\frac{1}{4}\sum_{k\in\bar{\mathcal{M}}_{2}}\alpha_{k}x_{1}^{(k)}$; then

Since the training loss is at most $\varepsilon_{2}\leq p/10$, we know that $\frac{1}{|\bar{\mathcal{M}}_{2}|}\sum_{i\in\bar{\mathcal{M}}_{2}}|\mathfrak{U}(x_{1}^{(i)})|\geq 1$ (otherwise the loss would not be low).

Since $|\mathfrak{U}(x_{1}^{(i)})|\leq|\langle\alpha,x_{1}^{(i)}\rangle|+\frac{3}{2}|\alpha_{i}|\|x_{1}^{(i)}\|_{2}^{2}+\frac{1}{d}\widetilde{O}\left(\sum_{k\in\bar{\mathcal{M}}_{2},k\neq i}|\alpha_{k}|\right)+O(m^{-1/6})$, we get:

Thus, either $\frac{1}{|\bar{\mathcal{M}}_{2}|}\sum_{i\in\bar{\mathcal{M}}_{2}}|\langle\alpha,x_{1}^{(i)}\rangle|\geq\frac{1}{4}$, which implies that

We are now ready to conclude the proof: for randomly chosen $x_{1}$, it holds that

Now using the same expansion of $\mathfrak{U}$ as before gives

Now we note that since all terms of nonzero degree in the polynomial expansion of $\arccos$ have odd degree, we have

The end result is that by Lemma 6.3, it will hold that:
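The parity claim about $\arccos$ used above can be checked directly: $\arccos(x)=\pi/2-\arcsin(x)$, and the Maclaurin series of $\arcsin$ contains only odd powers, so every term of nonzero degree is odd (equivalently, $\arccos(x)+\arccos(-x)=\pi$). This is a minimal check using the standard series $\arcsin(x)=\sum_{n\geq 0}\binom{2n}{n}\frac{x^{2n+1}}{4^{n}(2n+1)}$; `arccos_series` is a hypothetical helper name.

```python
import math

def arccos_series(x, n_terms=60):
    """pi/2 minus the Maclaurin series of arcsin; valid for |x| < 1.

    Apart from the constant pi/2, every term has odd degree 2n + 1.
    """
    total = 0.0
    for n in range(n_terms):
        coef = math.comb(2 * n, n) / (4 ** n * (2 * n + 1))
        total += coef * x ** (2 * n + 1)
    return math.pi / 2 - total

for x in (0.1, 0.5, -0.5):
    # The series matches math.acos, and acos(x) + acos(-x) equals pi.
    print(x, arccos_series(x), math.acos(x), math.acos(x) + math.acos(-x))
```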

Appendix E General case

Instead of using a large learning rate and annealing to a small learning rate, the regularization effect also appears if we use a small learning rate ($\eta_{2}$) with large pre-activation noise and then decay the noise. In this case, the update is given by:

where $\xi_{t}\sim\mathcal{N}(0,\tau_{\xi}^{2}I_{m\times m}\otimes I_{d\times d})$. However, the output of the network is given by:

Here $\Xi_{t}\sim\mathcal{N}(0,\tau_{t}^{2}I_{m\times m})$ is a fresh Gaussian random variable drawn at each iteration.

The same conclusion as in Theorem 3.4 holds if we first use noise level $\tau_{t}=\tau_{0}$ and then anneal to $\tau_{t}=0$ after $\widetilde{O}\left(\frac{d}{\eta_{1}\varepsilon_{1}}\right)$ iterations.
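The displayed update and output are missing from this copy. As a rough illustration only, the sketch below assumes a generic two-layer ReLU output $\sum_{j}u_{j}\,\mathrm{ReLU}(\langle w_{j},x\rangle+\Xi_{j})$, not necessarily the exact parametrization used above, with fresh noise $\Xi\sim\mathcal{N}(0,\tau_{t}^{2}I_{m})$ on each call and the step noise schedule from the statement above. It omits the weight-space noise $\xi_{t}$ in the update. The names `tau_schedule` and `noisy_forward`, the sizes, and `t_anneal = 100` are hypothetical.

```python
import numpy as np

def tau_schedule(t, tau0, t_anneal):
    """Step schedule: noise level tau0 for t < t_anneal, then 0."""
    return tau0 if t < t_anneal else 0.0

def noisy_forward(W, u, x, tau, rng):
    """Two-layer ReLU net with fresh pre-activation noise Xi ~ N(0, tau^2 I_m).

    W: (m, d) first-layer weights, u: (m,) second-layer weights, x: (d,) input.
    """
    xi = rng.normal(scale=tau, size=W.shape[0])   # resampled on every call
    return float(u @ np.maximum(W @ x + xi, 0.0))

rng = np.random.default_rng(0)
m, d = 32, 10
W, u, x = rng.normal(size=(m, d)), rng.normal(size=m), rng.normal(size=d)
for t in (0, 99, 100, 500):
    tau = tau_schedule(t, tau0=0.5, t_anneal=100)
    print(t, tau, noisy_forward(W, u, x, tau, rng))
```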

E.2 Extension to two-layer convolution network

We make the simplifying assumption that $z,\zeta$ are supported only on the last $d/k$ coordinates. The main theorem can be stated as follows:

The same conclusions as in Theorem 3.4 and Theorem 3.5 hold if we replace $r$ by $r/k$ and $d$ by $dk$ in both theorems and in Assumption 3.3.

We use this definition so that $N_{U_{t}}(u,U_{t};x)=g_{t}(x)+r_{t}(x)$ for every $t\geq 0$.

We denote by $u=(u_{1},\cdots,u_{k})$ the weights of the second layer associated with each convolution.

The main difference between the convolutional setting and the simple case is that there is only one hidden weight, which is shared across channels. However, since the output layers of these channels have different weights, we can disentangle the channels and think of them as updating “separately”, as formalized in the following two lemmas.

Here $u_{i}\odot U_{i}=((u_{i})_{j}(U_{i})_{j})_{j\in[\frac{m}{k}]}$.

Since $U_{i}$ does not depend on the randomness of $u_{i^{\prime}}$ but only on $\widetilde{U}_{t}$, fixing $\widetilde{U}_{t},U_{i}$ and using that the entries of $u_{i^{\prime}}$ are i.i.d. with mean zero, we have:

Applying basic concentration bounds to $N_{\widetilde{U}_{t}}(u_{i^{\prime}},u_{i}\odot U_{i};x)$, it holds w.h.p. that $|N_{\widetilde{U}_{t}}(u_{i^{\prime}},u_{i}\odot U_{i};x)|\leq\widetilde{O}\left(\frac{\|x\|_{2}}{\lambda m}\right)$. Putting this back into Eq. (E.12) completes the proof.

We set $\varepsilon_{c}=\widetilde{O}\left(kd^{4}\frac{1}{\lambda m^{1/2}}\right)$, and with this lemma we can restate Lemma 6.1, Lemma C.8, and Lemma D.1 as follows. Suppose $\varepsilon_{c}\leq\min\{\varepsilon_{1}/10,\varepsilon_{2}^{\prime}/10\}$ for every $x$ in the training set. Then the following lemmas hold by directly applying Lemma E.3.

In the setting of Theorem E.2, there exists a solution $U^{\star}$ satisfying a) $\|U^{\star}\|_{F}^{2}\leq O(dk\log^{2}(1/\varepsilon))$ and b) for every $t\geq 0$:

In the setting of Theorem E.2, there exists a solution $U^{*}$ satisfying $\|U^{*}\|_{F}^{2}=\widetilde{O}\left(\frac{k}{\varepsilon_{1}^{2}r}\right)$ and for every $t\leq\frac{1}{\eta_{2}\lambda}$:

In the setting of Theorem E.2, there exists a solution $U^{\star}$ satisfying a) $\|U^{\star}\|_{F}^{2}\leq\widetilde{O}\left(\frac{k}{\varepsilon_{2}^{\prime 2}r}+Npk\right)$ and b) for every $t\leq\frac{1}{\eta_{2}\lambda}$,

To prove these lemmas, we can simply define $U^{*}=\sqrt{k}W^{*}+\sqrt{k}V^{*}$ for $W^{*},V^{*}$ given in the original proofs and apply Lemma E.3. We need the factor $k$ here because there are $\frac{m}{k}$ channels instead of $m$, so the squared norm scales up by a factor of $k$.

The next two convergence theorems follow directly from Lemma 4.1 and Lemma 4.3 and apply with initial learning rate $\eta_{1}$.

In the setting of Theorem E.2 with initial learning rate $\eta_{1}$, at some step $t_{0}\leq\widetilde{O}\left(\frac{dk}{\eta_{1}\varepsilon_{1}}\right)$, the training loss $\widehat{L}(u,U_{t_{0}})$ becomes smaller than $q\log 2+\varepsilon_{1}$. Moreover, we have $\|\overline{U}_{t_{0}}\|_{F}^{2}=O\left(dk\log^{2}(1/\varepsilon_{1})\right)$.

In the setting of Theorem E.2, with initial learning rate $\eta_{1}$, there exists $t=\widetilde{O}\left(\frac{k}{\varepsilon_{1}^{3}\eta_{2}r}\right)$ such that after $t_{0}+t$ iterations we have that

Moreover, $\|\overline{U}_{t_{0}+t}-\overline{U}_{t_{0}}\|_{F}^{2}\leq\widetilde{O}\left(\frac{k}{\varepsilon_{1}^{2}r}\right)$.

The following statement applies when we use a small initial learning rate and follows from the proof of Lemma 5.1.

In the setting of Theorem E.2, with initial learning rate $\eta_{2}$, there exists $t$ with

such that $L_{t}\leq\varepsilon_{2}^{\prime}$ after $t$ iterations. Moreover, we have that $\|\overline{U}_{t}\|_{F}^{2}\leq\widetilde{O}\left(\frac{k}{\varepsilon_{2}^{\prime 2}r}+Npk\right)$.

Now, the following lemma follows directly from Lemma 4.2 by applying Lemma E.4:

In the setting of Theorem E.2 with initial learning rate $\eta_{1}$, w.h.p., for every $t\leq\frac{1}{\eta_{1}\lambda}$,

With these lemmas, we can directly conclude the following:

In the setting of Lemma E.9 with initial learning rate $\eta_{1}$, the following holds:

Here $x_{1}^{(i),(j)}=([x_{1}^{(i)}]_{s})_{s\in\{(j-1)d/k+1,(j-1)d/k+2,\cdots,d\}}$.

The final proof of Theorem E.2 follows directly from the proof of Theorem 3.4 and Theorem 3.5.

Appendix F Toolbox

Without loss of generality, we assume $a\gamma_{2}/b\geq 0$. Let $Y_{1}=aX_{1}+bX_{2}$ and $Y_{2}=bX_{1}-aX_{2}$. Then $Y_{1},Y_{2}$ are independent Gaussian random variables with marginal distribution $\mathcal{N}(0,1)$. Moreover, $X_{1}=aY_{1}+bY_{2}$. Thus, $X_{1}\mid aX_{1}+bX_{2}=\gamma_{2}$ has the same distribution as $aY_{1}+bY_{2}\mid Y_{1}=\gamma_{2}$, which is $\mathcal{N}(a\gamma_{2},b^{2})$. Let $Z$ be a standard Gaussian; then

Note that $Mz=w^{\star}\langle\beta,z\rangle+M_{1}z$. Since $M_{1}$ is a random Gaussian matrix and $d^{\prime}\leq d$, we know that w.h.p. for every $z$ we have $\frac{\langle w^{\star},M_{1}z\rangle}{\|M_{1}z\|_{2}}\leq\frac{\sqrt{2}}{2}$.
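The conditioning step in the proof beginning "Without loss of generality" can be checked numerically. The sketch assumes, as the change of variables requires, that $X_{1},X_{2}$ are i.i.d. standard Gaussians and $a^{2}+b^{2}=1$. It verifies the identity $X_{1}=aY_{1}+bY_{2}$ exactly and compares samples of $X_{1}$ with $Y_{1}$ near $\gamma_{2}$ against $\mathcal{N}(a\gamma_{2},b^{2})$. The values $a=0.6$, $b=0.8$, $\gamma_{2}=1$, and the window width are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = 0.6, 0.8                       # assumed normalization a^2 + b^2 = 1
n = 400_000
X1, X2 = rng.normal(size=n), rng.normal(size=n)
Y1 = a * X1 + b * X2
Y2 = b * X1 - a * X2

# Exact identity used in the proof: X1 = a*Y1 + b*Y2.
identity_holds = np.allclose(X1, a * Y1 + b * Y2)

# Condition on Y1 lying in a narrow window around gamma2 and compare with N(a*gamma2, b^2).
gamma2 = 1.0
cond = X1[np.abs(Y1 - gamma2) < 0.02]
print(identity_holds, len(cond))
print(cond.mean(), a * gamma2)   # conditional mean vs a*gamma2
print(cond.var(), b ** 2)        # conditional variance vs b^2
```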

Appendix G Additional Details for Experiments

In this section, we provide additional details on the experimental results of Section 7. All of our models were trained on a single NVIDIA TitanXp GPU, and our code is implemented in PyTorch. For all our experiments, the mean pixel is subtracted from each CIFAR image and the result is divided by the pixel standard deviation. We use the mean and standard deviation values from the PyTorch WideResNet implementation: https://github.com/xternalz/WideResNet-pytorch.

In this section, we provide additional details on the mitigation strategy for a small learning rate described in Section 7. In Table 1, we demonstrate on CIFAR-10 images without data augmentation that this regularization can indeed counteract the negative effects of a small learning rate: we report a 4.72% increase in validation accuracy when adding noise to a small learning rate.

We train all models for 200 epochs, annealing the learning rate by a factor of 0.2 at the 60th, 120th, and 150th epochs. The large learning rate model uses an initial learning rate of 0.1, whereas the small learning rate model uses an initial learning rate of 0.01. The large learning rate is a standard hyperparameter setting for the WideResNet16 architecture, and we chose the small learning rate by scaling this value down. The other hyperparameter settings are standard. We remove data augmentation from the training set to isolate the effect of adding noise.
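The learning-rate schedule above is a step schedule. This is a minimal sketch, assuming 0-indexed epochs and that each decay takes effect at the start of the milestone epoch; `lr_at_epoch` is a hypothetical helper name.

```python
def lr_at_epoch(epoch, base_lr, milestones=(60, 120, 150), gamma=0.2):
    """Step schedule: base_lr times gamma for each milestone already reached (0-indexed epochs)."""
    n_decays = sum(epoch >= m for m in milestones)
    return base_lr * gamma ** n_decays

large_lr = [lr_at_epoch(e, 0.1) for e in range(200)]    # large initial learning rate model
small_lr = [lr_at_epoch(e, 0.01) for e in range(200)]   # small initial learning rate model
print(large_lr[0], large_lr[60], large_lr[120], large_lr[150])
```

In PyTorch, the same step schedule is available as `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 150], gamma=0.2)`, stepped once per epoch.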

We add noise before every ReLU activation. As it is costly to add i.i.d. noise the size of the entire hidden layer, we sample Gaussian noise whose shape equals the last two dimensions of the 4-dimensional hidden layer (the first two dimensions being batch size and number of channels) and duplicate it over the first two dimensions. We sample different noise for every batch.
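The shared-noise trick above amounts to sampling one $(H,W)$ noise map per batch and relying on broadcasting to repeat it over the batch and channel dimensions. This is a NumPy sketch of that idea, not the authors' PyTorch code; the function names and tensor sizes are hypothetical.

```python
import numpy as np

def add_spatial_noise(h, sigma, rng):
    """Add one Gaussian (H, W) noise map, shared across the batch and channel dims.

    h: activations of shape (batch, channels, H, W), i.e. PyTorch's NCHW layout.
    """
    noise = rng.normal(scale=sigma, size=h.shape[-2:])   # sampled fresh per call (per batch)
    return h + noise                                    # broadcasting repeats it over dims 0 and 1

def noisy_relu(h, sigma, rng):
    """ReLU applied after adding the shared pre-activation noise."""
    return np.maximum(add_spatial_noise(h, sigma, rng), 0.0)

rng = np.random.default_rng(0)
h = np.zeros((4, 3, 8, 8))
out = add_spatial_noise(h, sigma=0.2, rng=rng)
print(out.shape, np.allclose(out[0, 0], out[3, 2]))
```

In PyTorch, the equivalent broadcast is `h + sigma * torch.randn(h.shape[-2:], device=h.device)`.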

Our annealing schedule simply multiplies the noise level by a constant factor at every iteration. We tune the standard deviation of the noise to 0.2 and the annealing rate to 0.995 per iteration. We report results from a single trial, as the small LR with noise algorithm already shows substantial improvement over vanilla small LR.
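Under this schedule, the noise standard deviation after $t$ iterations is $0.2\cdot 0.995^{t}$, so it halves roughly every 138 iterations. This is a minimal sketch, assuming the decay is applied once per iteration starting from iteration 0; `noise_std` is a hypothetical name.

```python
import math

def noise_std(iteration, sigma0=0.2, rate=0.995):
    """Noise level after `iteration` per-iteration multiplicative decays."""
    return sigma0 * rate ** iteration

half_life = math.log(0.5) / math.log(0.995)   # iterations for the noise level to halve
print(noise_std(0), noise_std(1000), half_life)
```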

G.2 Additional Details on Patch-Augmented CIFAR-10

We first describe in greater detail our method for producing the patch. First, the data split is as follows: of the 50000 CIFAR-10 training images, 10000 contain no patch and 40000 have a patch. We generate this split randomly before training and keep it fixed. During a single epoch, we iterate through all images, loading the 10000 clean images the same way each time. For the remaining 40000 examples, we use a patch-only image with probability 0.2 and a patch mixed with the CIFAR image with probability 0.8. Thus, 20% of the updates are on clean images, 16% on patches only, and 64% on mixed images, although the actual split of the data is slightly different because of our implementation.
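The update fractions follow from $10000/50000=20\%$ clean, and $80\%\times 0.2=16\%$ and $80\%\times 0.8=64\%$ for the patched examples. The sketch below simulates one idealized epoch of this sampling to confirm the arithmetic; it does not model the implementation detail that makes the actual split slightly different, and `sample_variant` is a hypothetical helper.

```python
import numpy as np

def sample_variant(is_patched, rng, p_patch_only=0.2):
    """Return which version of an example is used for one update."""
    if not is_patched:
        return "clean"
    return "patch_only" if rng.random() < p_patch_only else "mixed"

n_clean, n_patched = 10_000, 40_000
rng = np.random.default_rng(0)
is_patched = np.zeros(n_clean + n_patched, dtype=bool)
is_patched[rng.permutation(n_clean + n_patched)[:n_patched]] = True   # fixed random split

counts = {"clean": 0, "patch_only": 0, "mixed": 0}
for flag in is_patched:                     # one epoch over all 50000 examples
    counts[sample_variant(flag, rng)] += 1
fractions = {k: v / len(is_patched) for k, v in counts.items()}
print(fractions)   # approximately 0.20 / 0.16 / 0.64
```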

The patch is located in the center of the image. We visualize the patches in Figure 4. We generate the patch as follows: before training begins, we sample a random vector $z$ with i.i.d. entries from $\mathcal{N}(0,\sigma_{z}^{2})$ as well as $\zeta_{i}\sim[-\beta,\beta]$ for classes $i=1,\ldots,10$. Then, to generate patch-only images, we add a scalar multiple of $\zeta_{i}$ to $z$ if the example belongs to class $i$. This scalar multiple is in the range $[-\alpha,\alpha]$ for some $\alpha$ we tune. We set coordinates not in the patch to . To generate images that contain both a patch and a CIFAR example, we simply add $z\pm\zeta_{i}$. In all, the hyperparameters we tune are $\sigma_{z},\beta,\alpha$.
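The patch construction can be sketched as follows, using the values $\sigma_{z}=1.25$, $\beta=0.1$, $\alpha=1.75$ reported for the experiment shown. Several details are not specified above and are assumptions labeled in the code: the entries of $\zeta_{i}$ and the scalar multiple are drawn uniformly, the sign in $z\pm\zeta_{i}$ is chosen at random, and the patch size `PATCH_DIM` is hypothetical. Placing the patch in the image center and filling the remaining coordinates are omitted.

```python
import numpy as np

PATCH_DIM = 3 * 8 * 8   # HYPOTHETICAL patch size (channels x height x width), not from the text

def make_patch_params(n_classes=10, sigma_z=1.25, beta=0.1, seed=0):
    """Sample the shared vector z and one perturbation zeta_i per class, once before training."""
    rng = np.random.default_rng(seed)
    z = rng.normal(scale=sigma_z, size=PATCH_DIM)                  # i.i.d. N(0, sigma_z^2)
    zetas = rng.uniform(-beta, beta, size=(n_classes, PATCH_DIM))  # ASSUMPTION: uniform entries
    return z, zetas

def patch_only(z, zetas, label, alpha, rng):
    """Patch content for a patch-only example: z plus a scalar multiple of zeta_label."""
    c = rng.uniform(-alpha, alpha)   # ASSUMPTION: scalar drawn uniformly from [-alpha, alpha]
    return z + c * zetas[label], c

def mixed_patch(z, zetas, label, rng):
    """Patch content added to a CIFAR image: z + zeta_label or z - zeta_label."""
    sign = rng.choice([-1.0, 1.0])   # ASSUMPTION: sign chosen uniformly at random
    return z + sign * zetas[label]

z, zetas = make_patch_params()
rng = np.random.default_rng(1)
p, c = patch_only(z, zetas, label=3, alpha=1.75, rng=rng)
q = mixed_patch(z, zetas, label=3, rng=rng)
```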

We must choose $\sigma_{z},\beta,\alpha$ on the correct scale so that the large and small learning rate models neither both ignore the patch nor both overfit to it. For the experiment shown, $\sigma_{z}=1.25,\beta=0.1,\alpha=1.75$.

Our large initial learning rate model trains with learning rate 0.1, annealing to 0.004 at the 30th epoch, and the small LR model trains with a fixed learning rate of 0.004. Our small LR with noise model trains with a fixed learning rate of 0.004 and initial noise 0.4, and decays the noise to 4e-6 after the 30th epoch. We train all models for 60 epochs in total, starting from the same dataset and choice of patches. Table 2 reports the final validation accuracies on patch-augmented and clean data.
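The three configurations above can be summarized as per-epoch schedules. This is a minimal sketch, assuming 0-indexed epochs and that both switches take effect from epoch index 30; the function name and model labels are hypothetical.

```python
def patch_experiment_schedule(epoch, model):
    """Return (learning_rate, noise_std) for a 0-indexed epoch in the 60-epoch runs.

    ASSUMPTION: "at/after the 30th epoch" is read as taking effect from epoch index 30.
    """
    switched = epoch >= 30
    if model == "large_lr":
        return (0.004 if switched else 0.1), 0.0
    if model == "small_lr":
        return 0.004, 0.0
    if model == "small_lr_noise":
        return 0.004, (4e-6 if switched else 0.4)
    raise ValueError(f"unknown model: {model}")

for name in ("large_lr", "small_lr", "small_lr_noise"):
    print(name, [patch_experiment_schedule(e, name) for e in (0, 29, 30, 59)])
```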

Now we provide additional evidence that the generalization disparity is indeed due to the learning order effect and not simply because the large learning rate model can already generalize better on clean CIFAR-10 images. To see this, we consider the generalization error of models trained on 10000 clean CIFAR images: the small LR model achieves 65% validation accuracy, and the large LR model achieves 76% validation accuracy. For comparison, on the full clean dataset the small LR model achieves 83% validation accuracy whereas the large LR model achieves 90% accuracy.

We note that the final clean-image accuracy of 69.89% for the small LR model trained on the patch dataset is much closer to 65% than to 83%, suggesting that it effectively uses only a fraction of the available CIFAR samples because of learning order. On the other hand, the large LR model achieves a final clean validation accuracy of 87.61% when trained on the patch dataset, which is very close to the 90% achievable by training on the full clean dataset. This indicates that the large LR model still uses the majority of the images to learn CIFAR examples before annealing, as it has not yet memorized the patches.