Towards Explaining the Regularization Effect of Initial Large Learning Rate in Training Neural Networks
Yuanzhi Li, Colin Wei, Tengyu Ma
Introduction
It is a commonly accepted fact that a large initial learning rate is required to successfully train a deep network even though it slows down optimization of the training loss. Modern state-of-the-art architectures typically start with a large learning rate and anneal it at a point when the model’s fit to the training data plateaus. Meanwhile, models trained using only small learning rates have been found to generalize poorly despite enjoying faster optimization of the training loss.
A number of papers have proposed explanations for this phenomenon, such as sharpness of the local minima, the time it takes to move from initialization, and the scale of SGD noise. However, we still have a limited understanding of a surprising and striking part of the large learning rate phenomenon: from looking at the section of the accuracy curve before annealing, it would appear that a small learning rate model should outperform the large learning rate model in both training and test error. Concretely, in Fig. 1, the model trained with small learning rate outperforms the large learning rate until epoch 60 when the learning rate is first annealed. Only after annealing does the large learning rate visibly outperform the small learning rate in terms of generalization.
In this paper, we propose to theoretically explain this phenomenon via the concept of learning order of the model, i.e., the rates at which it learns different types of examples. This is not a typical concept in the generalization literature: learning order is a training-time property of the model, but most analyses only consider post-training properties such as the classifier’s complexity, or the algorithm’s output stability. We will construct a simple distribution for which the learning order of a two-layer network trained under large and small initial learning rates determines its generalization.
Informally, consider a distribution over training examples consisting of two types of patterns (“pattern” refers to a grouping of features). The first type consists of a set of easy-to-generalize (i.e., discrete) patterns of low cardinality that is difficult to fit using a low-complexity classifier, but easily learnable via complex classifiers such as neural networks. The second type of pattern is learnable by a low-complexity classifier, but is inherently noisy, so it is difficult for the classifier to generalize. In our case, the second type of pattern requires more samples to correctly learn than the first type. Suppose we have the following split of examples in our dataset:
The following informal theorems characterize the learning order and generalization of the large and small initial learning rate models. They are a dramatic simplification of our Theorems 3.4 and 3.5 meant only to highlight the intuitions behind our results.
There is a dataset with size of the form (1.1) such that with a large initial learning rate and noisy gradient updates, a two layer network will:
1) initially only learn hard-to-generalize, easy-to-fit patterns from the examples containing such patterns.
2) learn easy-to-generalize, hard-to-fit patterns only after the learning rate is annealed.
Thus, the model learns hard-to-generalize, easily fit patterns with an effective sample size of and still learns all easy-to-generalize, hard to fit patterns correctly with samples.
In the same setting as above, with small initial learning rate the network will:
1) quickly learn all easy-to-generalize, hard-to-fit patterns.
2) ignore hard-to-generalize, easily fit patterns from the examples containing both pattern types, and only learn them from the examples containing only hard-to-generalize patterns.
Thus, the model learns hard-to-generalize, easily fit patterns with a smaller effective sample size of and will perform relatively worse on these patterns at test time.
Together, these two theorems can justify the phenomenon observed in Figure 1 as follows: in a real-world network, the large learning rate model first learns hard-to-generalize, easier-to-fit patterns and is unable to memorize easy-to-generalize, hard-to-fit patterns, leading to a plateau in accuracy. Once the learning rate is annealed, it is able to fit these patterns, explaining the sudden spike in both train and test accuracy. On the other hand, because of the low amount of SGD noise present in easy-to-generalize, hard-to-fit patterns, the small learning rate model quickly overfits to them before fully learning the hard-to-generalize patterns, resulting in poor test error on the latter type of pattern.
Both intuitively and in our analysis, the non-convexity of neural nets is crucial for the learning-order effect to occur. Strongly convex problems have a unique minimum, so what happens during training does not affect the final result. On the other hand, we show the non-convexity causes the learning order to highly influence the characteristics of the solutions found by the algorithm.
In Section E.1, we propose a mitigation strategy inspired by our analysis. In the same setting as Theorems 1.1 and 1.2, we consider training a model with small initial learning rate while adding noise before the activations which gets reduced by some constant factor at some particular epoch in training. We show that this algorithm provides the same theoretical guarantees as the large initial learning rate, and we empirically demonstrate the effectiveness of this strategy in Section 7. In Section 7 we also empirically validate Theorems 1.1 and 1.2 by adding an artificial memorizable patch to CIFAR-10 images, in a manner inspired by (1.1).
The question of training with larger batch sizes is closely tied with learning rate, and many papers have empirically studied large batch/small LR phenomena, particularly focusing on vision tasks using SGD as the optimizer. While these papers are framed as a study of large-batch training, a number of them explicitly acknowledge the connection between large batch size and small learning rate. Keskar et al. argue that training with a large batch size or small learning rate results in sharp local minima. Hoffer et al. propose training the network for longer and with larger learning rate as a way to train with a larger batch size. Wen et al. propose adding Fisher noise to simulate the regularization effect of small batch size.
Adaptive gradient methods are a popular method for deep learning that adaptively choose different step sizes for different parameters. One motivation for these methods is reducing the need to tune learning rates. However, these methods have been observed to hurt generalization performance, and modern architectures often achieve the best results via SGD and hand-tuned learning rates. Wilson et al. construct a toy example for which ADAM generalizes provably worse than SGD. Additionally, there are several alternative learning rate schedules proposed for SGD, such as warm-restarts and . Ge et al. analyze the exponentially decaying learning rate and show that its final iterate achieves optimal error in stochastic optimization settings, but they only analyze convex settings.
There are also several recent works on implicit regularization of gradient descent that establish convergence to some idealized solution under particular choices of learning rate. In contrast to our analysis, the generalization guarantees from these works would depend only on the complexity of the final output and not on the order of learning.
Other recent papers have also studied the order in which deep networks learn certain types of examples. Mangalam and Prabhu and Nakkiran et al. experimentally demonstrate that deep networks may first fit examples learnable by “simpler” classifiers. For our construction, we prove that the neural net with large learning rate follows this behavior, initially learning a classifier on linearly separable examples and learning the remaining examples after annealing. However, the phenomenon that we analyze is also more nuanced: with a small learning rate, we prove that the model first learns a complex classifier on low-noise examples which are not linearly separable.
Finally, our proof techniques and intuitions are related to recent literature on global convergence of gradient descent for over-parametrized networks. These works show that gradient descent learns a fixed kernel related to the initialization under sufficient over-parameterization. In our analysis, the underlying kernel is changing over time. The amount of noise due to SGD governs the space of possible learned kernels, and as a result, regularizes the order of learning.
Setup and Notations
We formally introduce our data distribution, which contains examples supported on two types of components: a component meant to model hard-to-generalize, easier-to-fit patterns, and a component meant to model easy-to-generalize, hard-to-fit patterns (see the discussion in our introduction). Formally, we assume that the label has a uniform distribution over , and the data is generated as
where are assumed to be two half Gaussian distributions with a margin between them:
Memorizing $\mathcal{Q}$ with a two-layer net
It is easy for a two-layer ReLU network to memorize the labels of using two neurons with weights such that , and , . In particular, we can verify that will output a negative value for and a zero value for . Thus, choosing a small enough , the classifier gives the correct sign for the label .
We assume that we have a training dataset with examples drawn i.i.d. from the distribution described above. We use and to denote the empirical fraction of data points that are drawn from equation (2.2) and (2.3).
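The two-neuron memorization argument can be illustrated with a small sketch. The exact points and weights are elided in the text above, so everything concrete here is an assumption made for illustration: three collinear points `z - zeta`, `z`, `z + zeta` with the middle one labeled positive, the second neuron taken to be the negation of the first, and the names `memorizer`, `w`, and `bias` are hypothetical. Since `z` is the midpoint of the two negatively labeled points, no affine classifier can separate them, while two ReLU neurons plus a small positive offset can.

```python
def relu(value):
    return max(value, 0.0)

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def memorizer(x, w, bias):
    # Two ReLU neurons with weights w and -w (an assumed choice), both with
    # output weight -1, plus a small positive offset. If <w, z> = 0 the two
    # neurons are silent at z and exactly one of them fires at z +/- zeta.
    return bias - relu(dot(w, x)) - relu(-dot(w, x))

# Hypothetical geometry: zeta is orthogonal to z.
z = [1.0, 0.0]
zeta = [0.0, 0.5]
w = [0.0, 1.0]          # orthogonal to z, positive inner product with zeta
bias = 0.1              # must be smaller than |<w, zeta>|

points = {
    "z": (z, +1),
    "z + zeta": ([z[0] + zeta[0], z[1] + zeta[1]], -1),
    "z - zeta": ([z[0] - zeta[0], z[1] - zeta[1]], -1),
}
for name, (x, label) in points.items():
    print(name, label, memorizer(x, w, bias))
```

The offset plays the role of the "small enough" quantity in the text: it must stay below the magnitude of the neuron response on the two outer points.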
Two-layer neural network model
Here we slightly abuse the notation to use to denote both a matrix of columns with last columns being zero, or a matrix of columns. We also extend our theorem to other such as a two layer convolution network in Section E.
Training objective
We consider a regularized training objective . For simplicity of derivation, the second layer weight vector is randomly initialized and fixed throughout this paper. Thus with slight abuse of notation the training objective can be written as .
Notations
Main Results
The training algorithm that we consider is stochastic gradient descent with spherical Gaussian noise. We remark that we analyze this algorithm as a simplification of the minibatch SGD noise encountered when training real-world networks. There are a number of works theoretically characterizing this particular noise distribution, and we leave analysis of this setting to future work.
We initialize to have i.i.d. entries from a Gaussian distribution with variance , and at each iteration of gradient descent we add spherical Gaussian noise with coordinate-wise variance to the gradient updates. That is, the learning algorithm for the model is
where denotes the learning rate at time . We will analyze two algorithms:
Algorithm 1 (L-S): The learning rate is for iterations until the training loss drops below the threshold . Then we anneal the learning rate to (which is assumed to be much smaller than ) and run until the training loss drops to .
Algorithm 2 (S): We use a fixed learning rate of and stop at training loss .
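The control flow of the two algorithms can be sketched as follows. This is a minimal illustration, not the authors' implementation: the function name `noisy_gd_with_annealing`, its parameters, and the exact placement of the weight decay and noise terms in the update are assumptions, since the update equation is elided above. Algorithm 2 (S) corresponds to calling the same loop with `lr_large == lr_small`, so that annealing has no effect.

```python
import random

def noisy_gd_with_annealing(w, loss_fn, grad_fn, lr_large, lr_small,
                            anneal_loss, stop_loss, weight_decay,
                            noise_std, max_steps, seed=0):
    """Gradient descent with spherical Gaussian noise added to the gradient.

    Starts at lr_large, switches to lr_small once the loss drops to
    anneal_loss, and stops once the loss drops to stop_loss after annealing.
    Returns (weights, step at which annealing happened, final step).
    """
    rng = random.Random(seed)
    lr = lr_large
    anneal_step = None
    for step in range(max_steps):
        loss = loss_fn(w)
        if anneal_step is None and loss <= anneal_loss:
            lr = lr_small            # anneal the learning rate once
            anneal_step = step
        if anneal_step is not None and loss <= stop_loss:
            return w, anneal_step, step
        grad = grad_fn(w)
        # Assumed update: weight decay shrinkage plus noisy gradient step.
        w = [(1.0 - lr * weight_decay) * wi - lr * (gi + rng.gauss(0.0, noise_std))
             for wi, gi in zip(w, grad)]
    return w, anneal_step, max_steps

# Toy usage on a one-dimensional quadratic (purely illustrative).
quadratic_loss = lambda w: 0.5 * sum(wi * wi for wi in w)
quadratic_grad = lambda w: list(w)
final_w, annealed_at, stopped_at = noisy_gd_with_annealing(
    [1.0], quadratic_loss, quadratic_grad, lr_large=0.5, lr_small=0.05,
    anneal_loss=0.01, stop_loss=1e-4, weight_decay=0.0, noise_std=0.0,
    max_steps=1000)
print(annealed_at, stopped_at, final_w)
```

With a nonzero `noise_std`, the loss reachable before annealing is limited by the noise level, which is the regime the analysis exploits.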
For the convenience of the analysis, we make the following assumption that we choose in a way such that the contribution of the noise in the system stabilizes at the initialization: Let be the solution to (3.3) holding fixed. If the standard deviation of the initialization is chosen to be smaller than , then the standard deviation of the noise will grow to . Otherwise, if the initialization is chosen to be larger, the contribution of the noise will decrease to the level of due to regularization. In a typical analysis of SGD with spherical noise, as long as either the noise or the learning rate is small enough, the proof goes through. However, here we will make explicit use of the large learning rate or the large noise to show better generalization performance.
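The stabilization claim can be checked numerically under an assumed form of the noise recursion. The equation (3.3) itself is elided above, so the sketch below assumes that, ignoring the gradient signal, each weight entry evolves as `w <- (1 - lr * weight_decay) * w + lr * xi` with `xi` Gaussian of standard deviation `noise_std`. Under that assumption the variance follows a linear recursion with a unique fixed point; the function names are hypothetical.

```python
import math

def stationary_noise_std(lr, weight_decay, noise_std):
    """Fixed point of v <- (1 - lr*wd)^2 * v + (lr*noise_std)^2, as a std.

    Requires 0 < lr * weight_decay < 2 so that the recursion contracts.
    """
    shrink = (1.0 - lr * weight_decay) ** 2
    return math.sqrt((lr * noise_std) ** 2 / (1.0 - shrink))

def iterate_noise_std(init_std, lr, weight_decay, noise_std, steps):
    """Run the variance recursion from a given initialization scale."""
    shrink = (1.0 - lr * weight_decay) ** 2
    variance = init_std ** 2
    for _ in range(steps):
        variance = shrink * variance + (lr * noise_std) ** 2
    return math.sqrt(variance)

tau0 = stationary_noise_std(lr=0.1, weight_decay=0.5, noise_std=1.0)
# A smaller initialization grows towards tau0, a larger one shrinks to it.
print(tau0,
      iterate_noise_std(0.01, 0.1, 0.5, 1.0, steps=2000),
      iterate_noise_std(5.0, 0.1, 0.5, 1.0, steps=2000))
```

Initializing exactly at the fixed point keeps the noise scale constant over time, which matches the behavior described in the paragraph above.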
After fixing and , we choose initialization and large learning rate so that
As a technical assumption for our proofs, we will also require .
We also require sufficient over-parametrization.
We assume throughout the paper that and where poly is a sufficiently large constant degree polynomial. We note that we can choose arbitrarily small, so long as it is fixed before we choose .
As we will see soon, the precise relation between implies that the level of over-parameterization is polynomial in , which fits with the conditions assumed in prior works, such as .
Throughout this paper, we assume the following dependencies between the parameters. We assume that with a relationship where is a small value. Or, in non-asymptotic language, we assume that are sufficiently large compared to : We set , , and . The regularizer will be chosen to be . All of these choices of hyper-parameters can be relaxed, but for simplicity of exposition we only work in this setting.
We note that under our assumptions, for sufficiently large , and up to constant multiplicative factors. Thus we will mostly work with and (the empirical fractions) in the rest of the paper. We also note that our parameter choice satisfies and , which are a few conditions that we frequently use in the technical part of the paper.
Now we present our main theorems regarding the generalization of models trained with the L-S and S algorithms. The final generalization error of the model trained with the L-S algorithm will end up a factor smaller than the generalization error of the model trained with S algorithm.
Under Assumption 3.1, 3.2, and 3.3, there exists a universal constant such that Algorithm 1 (L-S) with annealing at loss for and stopping criterion satisfies the following:
It anneals the learning rate within iterations.
It stops at at most . With probability at least 0.99, the solution has test (classification) error and test loss at most .
Roughly, the learning order and generalization of the L-S model is as follows: before annealing the learning rate, the model only learns an effective classifier for on the samples in as the large learning rate creates too much noise to effectively learn (Lemma 4.1 and Lemma 4.2). After the learning rate is annealed, the model memorizes and correctly classifies examples with only a component during test time (formally shown in Lemmas 4.3 and 4.4). For examples with only component, the generalization error is (ignoring log factors and other technicalities) via standard Rademacher complexity. The full analysis of the L-S algorithm is clarified in Section 4.
Let be chosen in Theorem 3.4. Under Assumption 3.1, 3.2 and 3.3, there exists a universal constant such that w.h.p, Algorithm 2 with any and any stopping criterion , achieves training loss in at most iterations, and both the test error and the test loss of the obtained solution are at least .
We explain this lower bound as follows: the S algorithm will quickly memorize the component which is low noise and ignore the component for the examples with both and components (shown in Lemma 5.2). Thus, it only learns on examples. It obtains a small margin on these examples and therefore misclassifies a constant fraction of -only examples at test time. This results in the lower bound of . We formalize the analysis in Section 5.
It will be fruitful for our analysis to separately consider the gradient signal and Gaussian noise components of the weight matrix . We will decompose the weight matrix as follows: . In this formula, denotes the signals from all the gradient updates accumulated over time, and refers to the noise accumulated over time:
Note that when the learning rate is always , the formula simplifies to and . The decoupling and our particular choice of initialization satisfies that the noise updates in the system stabilize at initialization, so the marginal distribution of is always the same as the initialization. Another nice aspect of the signal-noise decomposition is as follows: we use tools from to show that if the signal term is small, then using only the noise component to compute the activations roughly preserves the output of the network. This facilitates our analysis of the network dynamics. See Section A.1 for full details.
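The signal-noise decomposition can be made concrete with a short sketch. It assumes an update of the form `w <- (1 - lr * weight_decay) * w - lr * grad + lr * xi`, which is a guess at the elided update rule, and the names `run_decomposed`, `signal`, and `noise` are hypothetical. Because this update is linear in the weights once the gradient is given, tracking the accumulated gradient signal (starting at zero) and the accumulated noise (starting at the initialization) separately reproduces the weights exactly.

```python
import random

def run_decomposed(w0, grad_fn, lr, weight_decay, noise_std, steps, seed=0):
    """Run noisy gradient descent while tracking signal and noise separately."""
    rng = random.Random(seed)
    decay = 1.0 - lr * weight_decay
    w = list(w0)
    signal = [0.0] * len(w0)   # gradient updates accumulated over time
    noise = list(w0)           # initialization plus accumulated Gaussian noise
    for _ in range(steps):
        grad = grad_fn(w)      # the gradient is evaluated at the full weights
        xi = [rng.gauss(0.0, noise_std) for _ in w]
        w = [decay * wi - lr * gi + lr * x for wi, gi, x in zip(w, grad, xi)]
        signal = [decay * si - lr * gi for si, gi in zip(signal, grad)]
        noise = [decay * ni + lr * x for ni, x in zip(noise, xi)]
    return w, signal, noise

# Toy gradient of a quadratic centered at 1 (illustrative only).
toy_grad = lambda w: [2.0 * (wi - 1.0) for wi in w]
weights, signal, noise = run_decomposed([0.3, -0.2, 0.1], toy_grad,
                                        lr=0.05, weight_decay=0.1,
                                        noise_std=0.5, steps=200)
print(max(abs(wi - (si + ni)) for wi, si, ni in zip(weights, signal, noise)))
```

Note that the noise term does not depend on the data at all, which is what allows its marginal distribution to be analyzed independently of the signal.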
Decomposition of Network Outputs
For convenience, we will explicitly decompose the model prediction at each time into two components, each of which operates on one pattern: we have ,
In other words, the network acts on the component of examples, and the network acts on the component of examples.
Characterization of Algorithm 1 (L-S)
We characterize the behavior of algorithm L-S with large initial learning rate. We provide proof sketches in Section 6.1 with full proofs in Section C.
The following lemma bounds the rate of convergence to the point where the loss gets annealed. It also bounds the total gradient signal accumulated by this point.
In the setting of Theorem 3.4, at some time step , the training loss becomes smaller than . Moreover, we have .
Our proof of Lemma 4.1 views the SGD dynamics as optimization with respect to the neural tangent kernel induced by the activation patterns where the kernel is rapidly changing due to the noise terms . This is in contrast to the standard NTK regime, where the activation patterns are assumed to be stable. Our analysis extends the NTK techniques to deal with a sequence of changing kernels which share a common optimal classifier (see Section 6.1 and Theorem 6.2 for additional details).
The next lemma says that with large initial learning rate, the function does not learn anything meaningful for the component before the -timestep. Note that by our choice of parameters and Lemma 4.1, we anneal at the time step . Therefore, the function has not learned anything meaningful about the memorizable pattern on distribution before we anneal.
In the setting of Theorem 3.4, w.h.p., for every ,
After iteration , we decrease the learning rate to . The following lemma bounds how fast the loss converges after annealing.
In the setting of Theorem 3.4, there exists , such that after iterations, we have that
Moreover, .
The following lemma bounds the training loss on the example subsets , .
In the setting of Lemma 4.3 using the same , the average training losses on the subsets and are both good in the sense that
Intuitively, low training loss of on immediately implies good generalization on examples containing patterns from . Meanwhile, the classifier for , , has low loss on examples. Then the test error bound follows from standard Rademacher complexity tools applied to these examples.
Characterization of Algorithm 2 (S)
We present our small learning rate lemmas, with proof sketches in Section 6.2 and full proofs in Section D.
The below lemma shows that the algorithm will converge to small training error too quickly. In particular, the norm of is not large enough to produce a large margin solution for those such that .
Lower bound on the generalization error
The following important lemma states that our classifier for does not learn much from the examples in . Intuitively, under a small learning rate, the classifier will already learn so quickly from the component of these examples that it will not learn from the component of examples in . We make this precise by showing that the magnitude of the gradients on is small.
The above lemma implies that does not learn much from examples in , and therefore must overfit to the examples in . As by our choice of parameters, we will not have enough samples to learn the -dimensional distribution . The following lemma formalizes the intuition that the margin will be poor on samples from .
As the margin is poor, the predictions will be heavily influenced by noise. We use this intuition to prove the classification lower bound for Theorem 3.5.
Proof Sketches
We first introduce notations that will be useful in these proofs. We will explicitly decouple the noise in the weights from the signal by abstracting the loss as a function of only the signal portion of the weights. Let us define the following:
Now the proof of Lemma 4.1 relies on the following two results, which we state below and prove in Section C.1. The first says that there is a common target for the signal part of the network that is a good solution for all of the .
In the setting of Lemma 4.1, there exists a solution satisfying a) and b) for every
Now the second statement is a general one proving that gradient descent on a sequence of convex, but changing, functions will still find an optimum provided these functions share the same solution.
’s are -Lipschitz, i.e.,
For every , we have that for and , , there is a such that:
Furthermore, the iterates satisfy for all .
Combining these two statements leads to the proof of Lemma 4.1.
We can apply Theorem 6.2 with defined in (6.2) and defined in Lemma 6.1, using . We note that satisfies the conditions of Theorem 6.2 by our parameter choices, which completes the proof. ∎
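The idea behind Theorem 6.2 can be illustrated on a simplified toy problem. The sketch below is not the theorem's setting (which concerns Lipschitz convex functions with conditions that are elided above); it assumes, for illustration only, a sequence of convex quadratics `f_t(x) = 0.5 * <a_t, x - x_star>^2` with random unit directions `a_t`, all of which are minimized at the same hypothetical point `x_star`. Each gradient step can only move the iterate closer to the shared solution, even though the function changes at every step.

```python
import math
import random

def gd_on_changing_quadratics(x_star, steps, lr, seed=0):
    """Gradient descent where the objective changes every step.

    At step t the loss is 0.5 * <a_t, x - x_star>^2 for a fresh random unit
    vector a_t. All losses share the minimizer x_star. For 0 <= lr <= 2 the
    distance to x_star is non-increasing.
    """
    rng = random.Random(seed)
    dim = len(x_star)
    x = [0.0] * dim
    distances, losses = [], []
    for _ in range(steps):
        a = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(ai * ai for ai in a))
        a = [ai / norm for ai in a]
        residual = sum(ai * (xi - si) for ai, xi, si in zip(a, x, x_star))
        losses.append(0.5 * residual ** 2)
        # Gradient of the current loss is residual * a_t.
        x = [xi - lr * residual * ai for xi, ai in zip(x, a)]
        distances.append(math.sqrt(sum((xi - si) ** 2
                                       for xi, si in zip(x, x_star))))
    return x, distances, losses

x, distances, losses = gd_on_changing_quadratics([1.0, -2.0, 0.5],
                                                 steps=300, lr=1.0, seed=1)
print(distances[0], distances[-1])
```

The monotone decrease follows from expanding the squared distance after one step, which shrinks by `lr * (2 - lr) * residual^2`.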
To prove Lemma 4.2, we will essentially argue in Section C.2 that the change in activations caused by the noise will prevent the model from learning with a large learning rate. This is because the examples in require a very specific configuration of activation patterns to learn correctly, and the noise will prevent the model from maintaining this configuration.
Now after we anneal the learning rate, in order to conclude Lemmas 4.3 and 4.4, the following must hold: 1) the network learns the component of the distribution and 2) the network does not forget the component that it previously learned. To prove the latter, we rely on the following lemma stating that the activations do not change much with a small learning rate:
The activation patterns do not change much after annealing the learning rate: for every , for any and for any row of the weight matrix , we have that
Moreover, for all , , it holds that w.h.p. for every :
We prove the above lemma in Section C.3. Now to complete the proof of Lemma 4.3, we will construct a target solution for all timesteps after annealing the learning rate based on the activations at time (as they do not change by much in subsequent time steps because of Lemma 6.3) and reapply Theorem 6.2. Finally, to prove Lemma 4.4, we use the fact that the component of the solution does not change by much, and therefore the loss on is still low.
6.2 Proof Sketches for Small Learning Rate
The proof of Lemma 5.1 proceeds similarly as the proof of Lemma 4.3: we will show the existence of a target solution of for all iterations, and use Theorem 6.2 to prove convergence to this target solution.
The next two statements argue that can be large only in a limited number of time steps. As the training loss converges quickly with small learning rate, this will be used to argue that the components of examples in provide a very limited signal to . The proofs of these statements are in Section D.2.
We first show the following lemma that says that if is large (which means the loss is large as well), then the total gradient norm has to be big. This lemma holds because there is little noise in the component of the distribution, and therefore the gradient of will be large if is large.
For every , we have that if , then w.h.p.
Now we use the above lemma to bound the number of times when is large.
In the setting of Lemma 5.2, let be the set of iterations where , where is defined in Lemma 5.2. Then w.h.p,
Now if is small, the gradient accumulated on from examples in must be small. We formalize this argument in our proof of Lemma 5.2 in Section D.2.
Lemma 5.3 will then follow by explicitly decomposing into a component in and some remainder, which is shown to be small by Lemma 5.2. This is presented in the below lemma, which is proved in Section D.3.
There exist real numbers such that for every , we have
with .
This allows us to conclude Lemma 5.3 via computations carried out in Section D.3.
Finally, to complete the proof of Theorem 3.5, we will argue in Section B.2 that a classifier of the form given by (5.2) cannot have small generalization error because it will be too heavily influenced by the noise in .
Experiments
Our theory suggests that adding noise to the network could be an effective strategy to regularize a small learning rate in practice. We test this empirically by adding small Gaussian noise during training before every activation layer in a WideResNet16 architecture, as our analysis highlights pre-activation noise as a key regularization mechanism of SGD. The noise level is annealed over time. We demonstrate on CIFAR-10 images without data augmentation that this regularization can indeed counteract the negative effects of small learning rate, as we report a 4.72% increase in validation accuracy when adding noise to a small learning rate. Full details are in Section G.1.
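The pre-activation noise strategy can be sketched in a few lines. This is an illustrative sketch rather than the experimental code: the names `noise_std_at` and `noisy_relu`, the single-step annealing schedule, and all numeric values are assumptions. The only ingredients taken from the text are that Gaussian noise is added before each activation during training and that the noise level is reduced over time.

```python
import random

def noise_std_at(epoch, base_std, anneal_epoch, factor):
    """Noise level schedule: constant, then divided by `factor` once."""
    return base_std if epoch < anneal_epoch else base_std / factor

def noisy_relu(pre_activations, noise_std, rng):
    """Add Gaussian noise to the pre-activations, then apply ReLU."""
    return [max(z + rng.gauss(0.0, noise_std), 0.0) for z in pre_activations]

rng = random.Random(0)
pre_activations = [-1.0, 0.2, 2.0]
for epoch in (0, 29, 30):
    std = noise_std_at(epoch, base_std=0.1, anneal_epoch=30, factor=10.0)
    print(epoch, std, noisy_relu(pre_activations, std, rng))
```

At evaluation time one would presumably set the noise level to zero, in which case `noisy_relu` reduces to a plain ReLU.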
We will also empirically demonstrate that the choice of large vs. small initial learning rate can indeed invert the learning order of different example types. We add a memorizable 7 × 7 pixel patch to a subset of CIFAR-10 images following the scenario presented in (1.1), such that around 20% of images have no patch, 16% of images contain only a patch, and 64% contain both CIFAR-10 data and patch. We generate the patches so that they are not easily separable, as in our constructed , but they are low in variation and therefore easy to memorize. Precise details on producing the data, including a visualization of the patch, are in Section G.2. We train on the modified dataset with WideResNet16 using three methods: large learning rate with annealing at the 30th epoch, small initial learning rate, and small learning rate with noise annealed at the 30th epoch.
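The dataset split described above can be sketched as follows. The actual patch generation is specified in Section G.2 and is not reproduced here, so this sketch makes several assumptions: images are single-channel nested lists, the patch contents and location are placeholders, patch-only images use a blank (all-zero) background, and the function names are hypothetical. Only the 7 × 7 patch size and the 20% / 16% / 64% proportions come from the text.

```python
import random

PATCH_SIZE = 7

def assign_group(rng, p_clean=0.20, p_patch_only=0.16):
    """Sample which kind of example an image becomes."""
    u = rng.random()
    if u < p_clean:
        return "clean"           # image without a patch
    if u < p_clean + p_patch_only:
        return "patch_only"      # patch on a blank background (assumed)
    return "both"                # image with the patch pasted in

def paste_patch(image, patch, top, left):
    """Return a copy of `image` with `patch` written at (top, left)."""
    out = [row[:] for row in image]
    for i, patch_row in enumerate(patch):
        for j, value in enumerate(patch_row):
            out[top + i][left + j] = value
    return out

def build_example(image, patch, group, top=0, left=0):
    if group == "clean":
        return [row[:] for row in image]
    if group == "patch_only":
        blank = [[0.0] * len(row) for row in image]
        return paste_patch(blank, patch, top, left)
    return paste_patch(image, patch, top, left)

rng = random.Random(0)
image = [[1.0] * 32 for _ in range(32)]                  # stand-in for a CIFAR-10 image
patch = [[0.5] * PATCH_SIZE for _ in range(PATCH_SIZE)]  # placeholder patch
groups = [assign_group(rng) for _ in range(1000)]
print({g: groups.count(g) for g in ("clean", "patch_only", "both")})
example = build_example(image, patch, groups[0], top=3, left=4)
```

In the experiments the patch presumably depends on the label, so that it can be memorized as a predictor of the class; that dependence is omitted here.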
Figure 3 depicts the validation accuracy vs. epoch on clean (no patch) and patch-only images. From the plots, it is apparent that the small learning rate picks up the signal in the patch very quickly, whereas the other two methods only memorize the patch after annealing.
From the validation accuracy on clean images, we can deduce that the small learning rate method is indeed learning the CIFAR images using a small fraction of all the available data, as the validation accuracy of a small LR model when training on the full dataset is around 83%, but the validation on clean data after training with the patch is 70%. We provide additional arguments in Section G.2. Our code for these experiments is online at the following link: https://github.com/cwein3/large-lr-code.
Conclusion
In this work, we show that the order in which a neural net learns to fit different types of patterns plays a crucial role in generalization. To demonstrate this, we construct a distribution on which models trained with large learning rates generalize provably better than those trained with small learning rates due to learning order. Our analysis reveals that more SGD noise, or larger learning rate, biases the model towards learning “generalizing” kernels rather than “memorizing” kernels. We confirm on artificially modified CIFAR-10 data that the scale of the learning rate can indeed influence learning order and generalization. Inspired by these findings, we propose a mitigation strategy that injects noise before the activations and works both theoretically for our construction and empirically. The design of better algorithms for regularizing learning order is an exciting question for future work.
Acknowledgements
CW acknowledges support from an NSF Graduate Research Fellowship.
References
Appendix A Basic Properties and Toolbox
In this section, we collect a few basic properties of the neural networks we are studying. In Section F, we provide two lemmas on Gaussian random variables and perturbation theory of the matrices.
Let be the -th row of . We have that .
For any , if for every , then we have that and .
By equation (3.4) and Proposition A.2, we have that
Hence . Now, since each is i.i.d. uniform , using the randomness of we know that w.h.p.
Here in the last inequality we applied Lemma A.8. The second statement follows from Proposition A.4 and triangle inequality. ∎
We have the following Rademacher complexity bound:
In this section, we collect useful statements which will help with decoupling the signal from the noise in our analysis. First, we observe that if the noise updates in the system stabilize at initialization, the marginal distribution of is always the same as the initialization.
Under Assumption 3.1, suppose we run Algorithm 1. Then for any before annealing the learning rate, has marginal distribution . In other words, each entry of follows and they are independent of each other.
One nice aspect of the signal-noise decomposition is as follows: we use tools from to show that if the signal term is small, then using only the noise component to compute the activations roughly preserves the output of the network. This facilitates our analysis of the network dynamics.
As we will often apply (A.11) with , for notational simplicity we denote throughout the paper . By our choice of we know that .
Appendix B Proof of Main Theorems
We start with the following lemma that shows that if has small training error on , then the output of on is large compared to . This is because for the loss to be low, must have a good margin on . However, as the norm of is roughly uniform in , this forces $g$ to have larger output.
W.h.p. for every and every , as long as , we have that: for every ,
We use to denote the set of all such that . Similarly, we use to denote the set of all such that , and use to denote the set of all such that .
Let . By the positive homogeneity of ReLU, we know that for every , it holds:
Since , it holds that w.h.p. for every ,
Our proof of Theorem 3.4 now amounts to carefully checking that all examples in are classified correctly, and the classifier will generalize well on .
By Lemma 4.4, we know that for we have . Thus applying Lemma B.1, we obtain that as long as (which is implied by Assumption 3.3)
This implies that for , applying Lemma A.8 gives us
Moreover, applying Lemma A.6 on with by Lemma 4.2 and Lemma 4.3, we have that
where we used the fact that .
Here the last step uses the definition of that . ∎
B.2 Proof of Theorem 3.5
We will prove Theorem 3.5 using Lemma 5.3 by roughly arguing that the predictions made by will be heavily influenced by a vector in the low rank span of examples from . With high probability, this vector will be noisy and not align well with the ground truth , leading to mispredictions.
Recall that denotes the stopping criterion used in Theorem 3.5 and . Using Lemma 5.3, we know that w.h.p.
By Lemma F.2 we know that w.h.p. over the randomness of ’s, for we have as long as : . For every randomly chosen , we can also write where so is independent of , hence
Note that with , and with probability at least , . This implies that with probability at least over a randomly chosen we can have:
For , we know that with probability at least , we have:
Moreover, since is independent of , we know that with probability both events can happen, in which case:
Thus, since by Lemma 5.3, we know that as long as
However, since , we know that either , which results in but . So when , the network classifies incorrectly. On the other hand, we have when the network will classify incorrectly. Since and holds with probability , this shows that the test error is at least . ∎
Appendix C Proofs for Large Learning Rate Lemmas
To prove Lemma 4.1, we will show that the network will learn all examples with component while the learning rate is large. The key to the proof is that although the large learning rate noise only allows the network to search over coarse kernels, is still learnable by these kernels because of its linearly-separable structure. To make this precise, we decompose the weights into the signal and noise components, and show that there exists a fixed “target” signal matrix which will classify correctly no matter the noise matrix.
Recall our definitions of , in (6.1) and (6.2), and that
Recall that Lemma 6.1 leverages the linearly-separable structure of to find a “target” signal matrix that correctly classifies w.h.p. over the noise matrix. We state its proof below.
By Proposition A.3, . We apply Lemma A.8 as follows: by Proposition A.7, ’s entry has marginal distribution and therefore the column of has distribution . Since w.h.p. , the coupling Lemma A.8 gives
On the other hand, we also have by Proposition A.5, using the fact that , w.h.p.
Here in the last inequality we used the fact that the network is sufficiently over-parameterized so that .
Using (C.4), noting that our choice of satisfies , we conclude
For the term , we know that
Note that entries of are i.i.d. random Bernoulli(), thus we know that w.h.p.
Thus, by our choice that and ,
By definition of , we know that
Now we wish to argue that even though the noise matrix is changing, gradient descent will still find the fixed target signal matrix . This leverages the fact that once we fix the activation patterns, we can view each step of the optimization as gradient descent with respect to a convex, but changing, function. Below we provide a proof of Theorem 6.2, which allows for optimization of this changing function.
For the sake of contradiction, we assume that for all . Using the definition of , we have that the update rule of can be written as
Assuming that , we have that as long as and , we have:
C.2 Proof of Lemma 4.2
In this section, we will often consider the activation patterns on the inputs at various time steps. For convenience, we have the following definition:
For any , and vector , let denote the set of neurons that have positive pre-activation on the input (with weights ), and be the set of neurons with negative pre-activations on the input . (We will mostly be interested in the quantities and their intersections.)
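The definition above can be illustrated with a short sketch. The function name `activation_sets`, the row-per-neuron layout of `W`, and the toy inputs are assumptions made for illustration only and do not come from the paper.

```python
import numpy as np

def activation_sets(W, x):
    """Split neuron indices by the sign of the pre-activation <w_r, x>.

    W: (m, d) array with one row per hidden neuron (assumed layout).
    x: (d,) input vector.
    Returns (positive, negative) as sets of neuron indices.
    """
    pre = W @ x
    positive = set(np.flatnonzero(pre > 0).tolist())
    negative = set(np.flatnonzero(pre < 0).tolist())
    return positive, negative

# Toy weights and inputs (hypothetical values).
W = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 2.0]])
x1 = np.array([1.0, 1.0])
x2 = np.array([1.0, -1.0])

pos1, neg1 = activation_sets(W, x1)
pos2, neg2 = activation_sets(W, x2)

# Neurons that are active on both inputs, i.e. an intersection of the sets.
both_positive = pos1 & pos2
```

Intersections such as `both_positive` are the kind of quantity the subsequent propositions reason about when comparing activation patterns across inputs and time steps.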
Let . Then, we have that
Towards bounding the terms in equation (C.25), we will need to reason about the activation patterns of at various time steps. We first show that the activation patterns of and have to agree on all neurons except an fraction of them. This will be useful to show that the second term of the RHS of equation (C.25) is small.
In the setting of Lemma C.2, w.h.p. over the randomness of the initialization and all the randomness in the algorithm, for every , implies that . Moreover, the size of the set is bounded by
Recall that and by Proposition A.7 has distribution . Therefore, by standard Gaussian concentration and union bound, with high probability over the randomness of the initialization and the algorithm, for all ,
Moreover, note that . By the independence between ’s and standard concentration inequalities (Bernstein inequality), we have that with high probability, there are at most entries satisfying . Together with the first part of the lemma, and that is sufficiently large so that , we complete the proof of equation (C.26). ∎
We use the lemma above to conclude that the second term in the decomposition (C.25) is at most on the order of .
In the setting of Lemma C.2, we have that
By the definition of our algorithm, before annealing the learning rate, we have
Using Proposition A.3 and that , we have that . It follows that
The equation above and equation (C.30) complete the proof. ∎
The following lemma decomposes into a sum of the contribution of the gradient from all the previous steps.
In the setting of Lemma C.2, let . ( can be viewed as the raw change of at the time step without considering the effect of the regularizer.) We have that
Using the fact that we complete the proof. ∎
Define the analog of with to compute the activation pattern: for any , and vector , let and define similarly.
Suppose at some iteration , and have the same activation pattern at neuron and in the sense that , or . Then the corresponding gradient update at that iteration for the weight vectors associated with and are the same up to a potential sign flip:
Moreover, suppose we have that satisfy that and (or and ) for , then the same conclusion holds for and .
Note that by definition, , and thus it suffices to prove that . By Proposition A.1, we have that
Note that can only take (a positive scaling of) four values . We claim that for every choice of these four values, for the satisfying the condition of the lemma, we have
Note that the equation above together with suffices to complete the proof.
Now to prove the second part of the lemma, suppose satisfy that and for . Using from Proposition A.3, we have that where we used the assumption that and . Therefore, we conclude that . Now by the first part of the lemma we complete the proof. ∎
Now we are ready to bound the first term on the RHS of equation (C.25), which is the crux of the proofs in this section. The key here is to get a bound that scales quadratically in .
In the setting of Lemma C.2, let be defined in Proposition C.5. Then, we have that
As a direct corollary of the equation above and Proposition C.5, we have that
By the set operations and the facts that and that , we have that
where the notations hide universal constants that make the first conclusion of Proposition C.3 true. By the second part of Proposition C.3 (or more directly equation (C.28)), we have that , and . By Proposition C.6, we have that for any , . For notational simplicity, let and . Therefore it follows that
where in the last inequality we use that for any , , and the fact that (by Proposition A.2).
Note that the distributions of ’s are independent across the choice of . Thus we will compute and then apply a concentration inequality to the sum. Note that the event here depends on three quantities , , and . First of all, is independent of these other two because is orthogonal to and and have spherical covariance matrices.
By the definition of , we can express their relationship by writing , where . Recall that by Proposition A.7, we have and are two independent Gaussians. Let be the variance of . We compute by observing that
Note that , thus is independent of conditioned on , for every . For notational simplicity, let , , and , and where the big O notation hides the same constant factor used in defining in equation (C.40). Let where (because ). Note that by the calculation above, has standard deviation which is bounded from below by . Then, we have that
Now by equation (C.42), a standard concentration inequality, and the fact that is sufficiently large, we have that with high probability,
Using a standard concentration inequality and the fact that is sufficiently large, we have that with high probability,
We can also prove the same bound for analogously. Using equation (C.41) and the several equations above, we conclude that
where the last step uses the condition that .
Now combining the Propositions above we are ready to prove Lemma 4.2.
Using triangle inequality, Proposition A.8, and equation (A.6) of Proposition A.5, we have that for any of norm ,
C.3 Proof of Lemma 6.3
The proof of Lemma 6.3 relies on the fact that a smaller learning rate preserves the noise generated from the timestep before annealing. This allows us to reason that the new activations are similar to the original before reducing the learning rate.
where . By properties of a sum of independent Gaussians, we have where is the standard deviation of each entry of . We also have that is independent of . Moreover, for every , the standard deviation can be bounded by
(Note that since , we should expect that the standard deviations satisfy . That is, the additional randomness introduced in the pre-activation is small.)
On the other hand, for every , the contribution of to is still present because the entry of has variance at least on the order of the variance of the entries of , which is . This also implies that the variance of the entries of is lower bounded by the variance of . This in turn is lower bounded by up to constant factor.
Therefore, using the decomposition (C.59) and the bounds above, we should expect that the sign of strongly correlates with the sign of , which will be formally shown below. Using Lemma A.8, we have that the activation pattern is mostly decided by the noise part ( and ), in the sense that for every ,
Fixing , we can decompose our target into
We have bounded the first and third terms on the RHS of the equation above. For the middle term, let and . Note that and that and are zero-mean independent Gaussian random variables with variance and variance , respectively. A basic property of Gaussian random variables implies that
Since ’s are independent, by basic concentration inequality (e.g., Bernstein inequality or Hoeffding inequality), we have that with high probability
Combining the equation above with equations (C.61), (C.62), and (C.64) completes the proof of the first part.
where the last inequality is due to by Proposition A.3, and bounding by Proposition A.5. ∎
We note that this lemma also applies to the setting when , i.e. we start with an initial small learning rate and compare to the random initialization. This is useful for the proofs in the small initial learning rate setting.
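The middle-term step in the proof above relies on how rarely a small independent Gaussian perturbation flips the sign of a larger Gaussian. As a sanity check of one standard fact of this kind (not necessarily the exact display used in the paper), for independent zero-mean Gaussians a and b with standard deviations σ_a and σ_b, the probability that sign(a + b) differs from sign(a) equals arctan(σ_b/σ_a)/π. The sketch below compares this closed form with a Monte Carlo estimate; the function names and the chosen standard deviations are hypothetical.

```python
import math
import random

def sign_flip_probability(sigma_a, sigma_b):
    """Closed form for P[sign(a + b) != sign(a)] with independent
    a ~ N(0, sigma_a^2) and b ~ N(0, sigma_b^2)."""
    return math.atan2(sigma_b, sigma_a) / math.pi

def estimate_sign_flip(sigma_a, sigma_b, trials=200_000, seed=0):
    """Monte Carlo estimate of the same probability."""
    rng = random.Random(seed)
    flips = 0
    for _ in range(trials):
        a = rng.gauss(0.0, sigma_a)
        b = rng.gauss(0.0, sigma_b)
        if (a > 0) != (a + b > 0):
            flips += 1
    return flips / trials

# Hypothetical scales: the perturbation is ten times smaller than the base noise.
exact = sign_flip_probability(1.0, 0.1)
estimate = estimate_sign_flip(1.0, 0.1)
```

When σ_b is much smaller than σ_a the flip probability is approximately σ_b/(π σ_a), which matches the intuition that a small learning rate leaves most activation signs unchanged.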
C.4 Proof of Lemma 4.3
We will now show that the network learns patterns from once the learning rate is annealed by constructing a common target for the network at every subsequent time step. We will then use Theorem 6.2 to show that the optimization finds this target. Let us define
Formally, we first show the following proposition, which proves the existence of a target solution that has good accuracy on and does not unlearn the network’s progress on :
In the setting of Lemma 4.3, let be defined in equation (6.2). Then, there exists a solution satisfying and
To prove this proposition, we need the following lemma:
Suppose satisfies that for some . Then, we have that
And moreover, if for some , then the prediction of on satisfies .
For convenience, let us denote . By our assumption, we have that .
Let . We have that w.h.p, for ,
and the factor of comes from the fact that the fraction of examples that are will be , , , respectively, w.h.p. Since the function is a 2-Lipschitz function, we know that
The equation above together with the assumption implies that
which implies that . It follows that and . Now we note that By the strict convexity of , we can easily conclude that .∎
Next, we will bound and the value of . This allows us to conclude that is small, so that it is easy to “unlearn” once the learning rate is annealed.
Suppose the condition in Lemma 4.1 holds. Then
Since , we know that . Applying Proposition C.9 with and , we have that and .
Now we will complete the proof of Proposition C.8.
Let us define sets as the following:
for some sufficiently large universal constant .
Note that the random noise vector will satisfy the condition for set with probability proportional to the angle between and , which is by Taylor approximation of . Thus, as and differ in at most activations, w.h.p., . This implies that
Hence we can also easily conclude that for every ,
On the other hand we have that by Lemma C.10, it holds that
Now the first term equals , and the second term is bounded by
using Proposition A.3 to upper bound . Thus, it follows that .
It follows that for every and its corresponding label , as long as ,
The last inequality follows from our choice of parameters such that . Putting together Eq (C.105) and (C.107) and defining , we have that
By Proposition C.8, there exists with such that for every ,
By Theorem 6.2, with , starting from , we can take , , to conclude that the algorithm converges to in iterations. Applying Lemma C.10 to bound completes the proof.∎
C.5 Proof of Lemma 4.4
By the 1-Lipschitzness of logistic loss, we know that
To bound this term, we can directly use Cauchy-Schwarz and obtain that:
We can further bound by applying Lemma 6.3, as from our choice of parameters , :
Now, let us denote as the data matrix. By the standard Gaussian matrix spectral norm bound we know that w.h.p. .
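For reference, the textbook form of the Gaussian spectral norm bound says that an n × d matrix with i.i.d. N(0, 1) entries has spectral norm at most √n + √d + t with probability at least 1 − 2·exp(−t²/2). The sketch below checks this numerically; the dimensions, the slack `t`, and the unit-variance scaling are assumptions, and the normalization of the paper's data matrix is not reproduced here.

```python
import numpy as np

def spectral_norm_check(n, d, t=5.0, seed=0):
    """Return (spectral norm, sqrt(n) + sqrt(d) + t) for a standard Gaussian matrix."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, d))
    norm = np.linalg.norm(X, ord=2)  # largest singular value
    bound = np.sqrt(n) + np.sqrt(d) + t
    return norm, bound

# Hypothetical sizes: 200 examples in 50 dimensions.
norm, bound = spectral_norm_check(200, 50)
```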
By (C.113) and our definition of as
Using the bound on that by Lemma C.10, we conclude the bound on .
In the end, by and the assumption that , it must hold that (since )
Appendix D Proofs for Small Learning Rate
In the setting of Theorem 3.5, there exists a solution satisfying a) and b) for every ,
To prove Lemma 5.1, we can apply an analysis identical to that of Lemma 4.3 to show that for , . The rest of the proof follows from combining Theorem 6.2 and Lemma D.1.
D.2 Proof of Lemma 5.2
Recall the expression defined in (6.10). We first prove Lemma 6.4 here, which says that if is large (which means the loss is large as well), then the total gradient norm has to be big.
For notational simplicity, let us fix and let
The gradient with respect to can be computed by
Let us denote the set as:
i.e., the loss gradient using activations computed by the noise component of scaled by a factor of .
By the Geometry of ReLU Lemma D.2, we have that w.h.p.
where the last inequality is obtained since for every , has the same sign.
Note that , and therefore for every fixed , w.h.p. there are many such that . For each of them, we also know that , which implies that
Picking , we complete the proof by our choice of . ∎
Now we prove Proposition 6.5, which bounds the number of iterations in which can be large.
Consider the function , and let us define . We have that since ,
Now, by standard gradient descent analysis, we have that (as the logistic loss has Lipschitz derivative and the data have bounded norm):
Next, we will bound . We can compute
Plugging this back into (D.29), by the coupling Lemma 6.3 we obtain the bound
This implies that for ,
which implies that for every , as long as , we have:
By Lemma 6.4, we have that if , then . It follows that there will be at most such .
Finally, we complete the proof of Lemma 5.2 by noting that cannot be large for very many iterations, and therefore will not obtain much signal from the component of examples in .
The last line follows from the spectral norm bound on matrix . Let be defined as in Proposition 6.5. It follows that
Now the conclusion of the lemma follows by the assumption that and our choice of and in Theorem 3.5.
D.3 Proof of Lemma 5.3
We now prove the decomposition lemma of , Lemma 6.6. Recall our definition of as
For each step, we know that for every ,
Thus, multiplying by and summing, following our definition of in (5.1), we get
Now we focus on bounding the bottom term. We can see that
By Auxiliary Coupling Lemma 6.3 with , we know that for , w.h.p.
for some real values with
By the above calculation, (D.47), and Lemma 5.2, we have:
where the last inequality follows from our choice of parameters. ∎
Using the decomposition lemma, the conclusion of Lemma 5.3 now follows via computation.
Let us define the function as:
Note that is some kernel prediction function. Since each is distributed as a vector of i.i.d. spherical Gaussians, we know that for fixed :
In the above equation is the principal angle between . Since each is i.i.d., with basic concentration bounds, we know that w.h.p.
The last inequality uses the fact that w.h.p. for , .
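The angle dependence of such kernel functions comes from a standard fact about spherical Gaussians: for w with i.i.d. standard normal entries and two inputs at angle θ, both pre-activations are positive with probability (π − θ)/(2π). The sketch below verifies this numerically. It illustrates only this fact, not the exact kernel expression of the proof, and the function names, inputs, and neuron count are hypothetical.

```python
import numpy as np

def joint_activation_probability(x, y):
    """Closed form for P[<w, x> > 0 and <w, y> > 0] with w ~ N(0, I)."""
    cos = np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
    theta = np.arccos(np.clip(cos, -1.0, 1.0))
    return (np.pi - theta) / (2.0 * np.pi)

def estimate_joint_activation(x, y, num_neurons=200_000, seed=0):
    """Empirical fraction of random neurons that are active on both inputs."""
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((num_neurons, x.shape[0]))
    return float(np.mean((W @ x > 0) & (W @ y > 0)))

# Two inputs at a 45 degree angle (hypothetical example).
x = np.array([1.0, 0.0, 0.0])
y = np.array([1.0, 1.0, 0.0])
exact = joint_activation_probability(x, y)
estimate = estimate_joint_activation(x, y)
```

Averaging over many independent neurons concentrates the empirical fraction around the closed form, which is the role played by the concentration bound in the text.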
Let us define ; then
Since the training loss is at , we know that (or else the loss would not be low).
Since , we can get:
Thus, either , which implies that
We are now ready to conclude the proof: for randomly chosen , it holds that
Now using the same expansion of as before gives
Now we note that as the nonzero degrees in the polynomial expansion of are all odd, we have
The end result is that by Lemma 6.3, it will hold that:
Appendix E General case
Instead of using large learning rate and annealing to a small learning rate, the regularization effect also exists if we use a small learning rate () and large pre-activation noise and then decay the noise. Hence the update is given as:
where . However, the output of the network is given as:
Here is a (freshly sampled) Gaussian variable at each iteration.
The same conclusion as in Theorem 3.4 holds if we first use noise level and then anneal to after iterations.
E.2 Extension to two layer convolution network
We make a simplifying assumption that are only supported on the last coordinates. The main theorem can be stated as follows:
The same conclusions as in Theorem 3.4 and Theorem 3.5 hold if we replace the value of by and by in both the theorem and in Assumption 3.3.
We use this definition so that for every .
We denote for the weight of the second layer associated with each convolution.
The main difference between the convolution setting and the simple case is that there is only one hidden weight that is shared across channels. However, since the output layers of these channels have different weights, we can disentangle these channels and think of them as updating “separately”, which is given as the following two lemmas.
Here .
Since does not depend on the randomness of but only on , fixing we know that, since each entry of is i.i.d. with mean zero, we have:
Applying basic concentration bounds on , it holds that w.h.p. . Putting this back into Eq (E.12), we complete the proof.
We set , and with this lemma, we can restate Lemma 6.1, Lemma C.8 and Lemma D.1 in the following way: Suppose for every in the training set. Then the following lemmas hold by directly applying Lemma E.3.
In the setting of Theorem E.2, there exists a solution satisfying a) and b) for every :
In the setting of Theorem E.2, there exists a solution satisfying and for every :
In the setting of Theorem E.2, there exists a solution satisfying a) and b) for every ,
To prove these lemmas, we can simply define for given in the original proof and apply Lemma E.3. The reason we need here is that there are channels instead of , so the squared norm scales up by a factor of .
Now the next two convergence theorems follow directly from Lemma 4.1 and Lemma 4.3 and apply with initial learning rate .
In the setting of Theorem E.2 with initial learning rate , at some step , the training loss becomes smaller than . Moreover, we have .
In the setting of Theorem E.2, with initial learning rate , there exists , such that after iterations we have that
Moreover,
The following statement applies when we use a small initial learning rate and follows from the proof of Lemma 5.1.
In the setting of Theorem E.2, with initial learning rate , there exists with
such that after iterations. Moreover, we have that
Now, the following lemma directly adapts from Lemma 4.2 by applying Lemma E.4:
In the setting of Theorem E.2 with initial learning rate , w.h.p., for every ,
With these lemmas, we can directly conclude the following:
In the setting of Lemma E.9 with initial learning rate , the following holds:
Here
The final proof of Theorem E.2 follows directly from the proof of Theorem 3.4 and Theorem 3.5.
Appendix F Toolbox
Without loss of generality, we assume . Let and . We have that are independent random Gaussian variables with marginal distribution . Moreover, . Thus, is the same as , which has distribution . Let be a standard Gaussian, then
Note that . Since is a random Gaussian matrix and , we know that w.h.p. for every we have .
Appendix G Additional Details for Experiments
In this section we provide additional details on the experimental results of Section 7. All of our models were trained using a single NVIDIA TitanXp GPU and our code is implemented via PyTorch. We note that for all our experiments, the mean pixel is subtracted from the CIFAR image and then the image is divided by the standard deviation pixel. We use mean and standard deviation values in the PyTorch WideResNet implementation: https://github.com/xternalz/WideResNet-pytorch.
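The preprocessing described above can be sketched as follows. Reading "mean pixel" and "standard deviation pixel" as per-channel statistics is an assumption, as are the function name and the synthetic batch; the constants from the referenced repository are deliberately not reproduced, and the statistics are computed from the toy data instead.

```python
import numpy as np

def normalize_images(images, mean, std):
    """Subtract a per-channel mean and divide by a per-channel std.

    images: (N, C, H, W) float array.
    mean, std: length-C sequences of per-channel statistics.
    """
    mean = np.asarray(mean, dtype=images.dtype).reshape(1, -1, 1, 1)
    std = np.asarray(std, dtype=images.dtype).reshape(1, -1, 1, 1)
    return (images - mean) / std

# Synthetic stand-in for a batch of CIFAR-sized images (hypothetical data).
rng = np.random.default_rng(0)
batch = rng.uniform(0.0, 1.0, size=(8, 3, 32, 32))
channel_mean = batch.mean(axis=(0, 2, 3))
channel_std = batch.std(axis=(0, 2, 3))
normalized = normalize_images(batch, channel_mean, channel_std)
```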
In this section, we provide additional details for the mitigation strategy for a small learning rate described in Section 7. In Table 1, we demonstrate on CIFAR-10 images without data augmentation that this regularization can indeed counteract the negative effects of small learning rate, as we report a 4.72% increase in validation accuracy when adding noise to a small learning rate.
We train all models for 200 epochs, annealing the learning rate by a factor of 0.2 at the 60th, 120th, and 150th epochs. The large learning rate model uses an initial learning rate of 0.1, whereas the small learning rate model uses an initial learning rate of 0.01. The large learning rate is a standard hyperparameter setting for the WideResNet16 architecture, and we chose the small learning rate by scaling this value down. The other hyperparameter settings are standard. We remove data augmentation from the training set to isolate the effect of adding noise.
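The schedule in this paragraph can be written as a small function. This is a minimal sketch: the function name is hypothetical, and treating epochs as 0-indexed with the drop taking effect from the milestone epoch onward is an assumption about how "at the 60th epoch" is implemented.

```python
def step_lr(initial_lr, epoch, milestones=(60, 120, 150), factor=0.2):
    """Learning rate in effect during `epoch` under step annealing.

    Each milestone that has been reached multiplies the rate by `factor`.
    """
    drops = sum(1 for m in milestones if epoch >= m)
    return initial_lr * factor ** drops

# Large and small initial learning rates from the text.
large_schedule = [step_lr(0.1, e) for e in range(200)]
small_schedule = [step_lr(0.01, e) for e in range(200)]
```

In PyTorch the same behavior is available through `torch.optim.lr_scheduler.MultiStepLR` with `milestones=[60, 120, 150]` and `gamma=0.2`.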
We add noise immediately before every application of the ReLU activation. As it is costly to add i.i.d. noise the size of the entire hidden layer, we sample Gaussian noise whose shape equals the last two dimensions of the 4-dimensional hidden layer (the first two dimensions being batch size and number of channels) and duplicate it over the first two dimensions. We sample different noise for every batch.
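The noise-sharing scheme can be sketched with array broadcasting. NumPy is used here in place of PyTorch to keep the example self-contained; the function names, tensor sizes, and the value of `sigma` are hypothetical.

```python
import numpy as np

def add_shared_noise(hidden, sigma, rng):
    """Add one (H, W) Gaussian noise map to every example and channel.

    hidden: (B, C, H, W) pre-activations.
    The noise is sampled once per call and broadcast over the
    batch and channel dimensions.
    """
    noise = sigma * rng.standard_normal(hidden.shape[-2:])
    return hidden + noise

def relu(x):
    return np.maximum(x, 0.0)

rng = np.random.default_rng(0)
hidden = np.zeros((4, 16, 8, 8))          # toy pre-activations
noisy = add_shared_noise(hidden, sigma=0.1, rng=rng)
activated = relu(noisy)                   # noise is added before the ReLU
```

Sampling only an H × W map costs B · C times fewer random draws than fully i.i.d. noise, at the price of correlating the noise across examples and channels within a batch.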
Our annealing schedule simply multiplies the noise level by a constant factor at every iteration. We tune the standard deviation of the noise to and the annealing rate to every iteration. We show results from a single trial as the small LR with noise algorithm already shows substantial improvement over vanilla small LR.
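A multiplicative annealing schedule of this kind is a geometric decay. The sketch below uses placeholder values for the initial standard deviation and the decay factor; they are hypothetical and are not the tuned values of the experiment.

```python
def noise_level(initial_sigma, decay, iteration):
    """Noise standard deviation after `iteration` multiplicative decay steps."""
    return initial_sigma * decay ** iteration

def iterations_to_reach(initial_sigma, decay, target_sigma):
    """Number of decay steps until the noise level is at most `target_sigma`."""
    if not 0.0 < decay < 1.0:
        raise ValueError("decay must lie strictly between 0 and 1")
    if target_sigma <= 0.0:
        raise ValueError("target_sigma must be positive")
    steps, sigma = 0, initial_sigma
    while sigma > target_sigma:
        sigma *= decay
        steps += 1
    return steps

# Placeholder values for illustration only.
steps_needed = iterations_to_reach(initial_sigma=1.0, decay=0.5, target_sigma=0.1)
```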
G.2 Additional Details on Patch-Augmented CIFAR-10
We first describe in greater detail our method for producing the patch. First, the split of our data is the following: of the 50000 CIFAR-10 training images, 10000 will contain no patch and 40000 will have a patch. We generate this split randomly before training and keep it fixed. During a single epoch, we iterate through all images, loading the 10000 clean images the same way each time. For the remaining 40000 examples, we use a patch-only image with probability 0.2 and a patch mixed with CIFAR image with probability 0.8. Thus, 20% of the updates are on clean images, 16% of updates are on patches only, and 64% of updates are on mixed images, but the actual split of the data is slightly different because of our implementation.
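The update percentages quoted above follow directly from the split sizes and the patch-only probability, as the following arithmetic sketch shows. The function name is hypothetical; the default arguments are the numbers stated in the text.

```python
def update_fractions(n_clean=10_000, n_patch=40_000, p_patch_only=0.2):
    """Expected fraction of updates on clean, patch-only, and mixed images."""
    total = n_clean + n_patch
    clean = n_clean / total
    patch_only = (n_patch / total) * p_patch_only
    mixed = (n_patch / total) * (1.0 - p_patch_only)
    return clean, patch_only, mixed

clean_frac, patch_only_frac, mixed_frac = update_fractions()
```

These are expected fractions per epoch; since patch-only versus mixed is decided randomly for each of the 40000 patched examples, the realized counts fluctuate around them.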
The patch will be located in the center of the image. We visualize the patches in Figure 4. We generate the patch as follows: before training begins, we sample a random vector with i.i.d. entries from as well as for classes . Then to generate patch-only images, we add a scalar multiple of to if the example belongs to class . This scalar multiple is in the range for some we tune. We set coordinates not in the patch to . To generate images that contain both patch and a CIFAR example, we simply add . In all, the hyperparameters we tune are .
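One possible reading of this construction is sketched below. Several specifics are assumptions rather than values from the paper: the standard normal sampling distribution, the patch size, the scale, zero as the fill value outside the patch, and all function names. They are exposed as arguments or marked in comments so that they are not mistaken for the experimental settings.

```python
import numpy as np

def make_patch_templates(num_classes, patch_shape, rng):
    """Sample a shared base patch and one direction per class.

    Standard normal entries are an assumption made for this sketch.
    """
    base = rng.standard_normal(patch_shape)
    class_dirs = rng.standard_normal((num_classes,) + tuple(patch_shape))
    return base, class_dirs

def patch_only_image(base, class_dirs, label, scale, image_shape):
    """Centered patch equal to base + scale * class direction; zero elsewhere (assumed)."""
    _, h, w = image_shape
    _, ph, pw = base.shape
    image = np.zeros(image_shape)
    top, left = (h - ph) // 2, (w - pw) // 2
    image[:, top:top + ph, left:left + pw] = base + scale * class_dirs[label]
    return image

def mixed_image(cifar_image, patch_image):
    """Image containing both the CIFAR example and the patch (simple addition)."""
    return cifar_image + patch_image

rng = np.random.default_rng(0)
# Hypothetical sizes: 10 classes, a 3 x 4 x 4 patch in a 3 x 32 x 32 image.
base, class_dirs = make_patch_templates(10, (3, 4, 4), rng)
patch_img = patch_only_image(base, class_dirs, label=2, scale=0.5, image_shape=(3, 32, 32))
cifar_like = rng.uniform(0.0, 1.0, size=(3, 32, 32))  # stand-in for a real image
mixed = mixed_image(cifar_like, patch_img)
```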
We must choose on the correct scale so that large and small learning rates don’t both ignore the patch or overfit to the patch. For the experiment shown, .
Our large initial learning rate model trains with learning rate 0.1, annealing to 0.004 at the 30th epoch, and the small LR model trains with fixed learning rate 0.004. Our small LR with noise model trains with fixed learning rate 0.004, initial noise , and decays the noise to 4e-6 after the 30th epoch. We train all models for 60 epochs total, starting from the same dataset and choice of patches. Table 2 demonstrates the final validation accuracy numbers on patch-augmented and clean data.
Now we provide additional evidence that the generalization disparity is indeed due to the learning order effect and not simply because the large learning rate model can already generalize better on clean CIFAR-10 images. To see this, we consider the generalization error of models trained on 10000 clean CIFAR images: the small LR model achieves 65% validation accuracy, and the large LR model achieves 76% validation accuracy. For comparison, on the full clean dataset the small LR model achieves 83% validation accuracy whereas the large LR model achieves 90% accuracy.
We note that the final number of 69.89% clean image accuracy for the small LR model trained on the patch dataset is much closer to 65% than 83%, suggesting that it is indeed using a fraction of the available CIFAR samples because of learning order. On the other hand, the large LR model achieves final clean validation accuracy of 87.61% when trained on the patch dataset, which is very close to the 90% that is achievable training on the full clean dataset. This indicates that the large LR model is still using the majority of the images to learn CIFAR examples before annealing, as it has not yet memorized the patches.