SGD on Neural Networks Learns Functions of Increasing Complexity

Preetum Nakkiran, Gal Kaplun, Dimitris Kalimeris, Tristan Yang, Benjamin L. Edelman, Fred Zhang, Boaz Barak

Introduction

Neural networks have been extremely successful in modern machine learning, achieving state-of-the-art results in a wide range of domains, including image-recognition, speech-recognition, and game-playing. Practitioners often train deep neural networks with hundreds of layers and millions of parameters and manage to find networks with good out-of-sample performance. However, this practical prowess is accompanied by feeble theoretical understanding. In particular, we are far from understanding the generalization performance of neural networks: why can we train large, complex models on relatively few training examples and still expect them to generalize to unseen examples? It has been observed in the literature that the classical generalization bounds that guarantee a small generalization gap (i.e., the gap between train and test error) in terms of VC dimension or Rademacher complexity do not yield meaningful guarantees in the context of real neural networks. More concretely, for most if not all real-world settings, there exist neural networks which fit the train set exactly, but have arbitrarily bad test error.

The existence of such “bad” empirical risk minimizers (ERMs) with large gaps between the train and test error means that the generalization performance of deep neural networks depends on the particular algorithm (and initialization) used in training, which is most often stochastic gradient descent (SGD). It has been conjectured that SGD provides some form of “implicit regularization” by outputting “low complexity” models, but it is safe to say that the precise notion of complexity and the mechanism by which this happens are not yet understood (see related works below).

In this paper, we provide evidence for this hypothesis and shed some light on how it comes about. Specifically, our thesis is that the dynamics of SGD play a crucial role and that SGD finds generalizing ERMs because:

In the initial epochs of learning, SGD has a bias towards simple classifiers as opposed to complex ones; and

in later epochs, SGD is relatively stable and retains the information from the simple classifier it obtained in the initial epochs.

Figure 1 illustrates qualitatively the predictions of this thesis for the dynamics of SGD over time. In this work, we give experimental and theoretical evidence for both parts of this thesis. While several quantitative measures of complexity of neural networks have been proposed in the past, including the classic notions of VC dimension, Rademacher complexity and margin, we do not propose such a measure here. Our focus is on the qualitative question of how much of SGD’s early progress in learning can be explained by simple models. Our main findings are the following:

In natural settings, the initial performance gains of SGD on a randomly initialized neural network can be attributed almost entirely to its learning a function correlated with a linear classifier of the data.

In natural settings, once SGD finds a simple classifier with good generalization, it is likely to retain it, in the sense that it will perform well on the fraction of the population classified by the simple classifier, even if training continues until it fits all training samples.

We state these claims broadly, using “in natural settings” to refer to settings of network architecture, initialization, and data distributions that are used in practice. We emphasize that this holds for vanilla SGD with standard architecture and random initialization, without using any regularization, dropout, early stopping or other explicit methods of biasing towards simplicity.

Some indications for variants of Claim 2 have been observed in practice, but we provide further experimental evidence and also show (Theorem 1) a simple setting where it provably holds. Our main novelty is Claim 1, which is established via several experiments described in Sections 3 and 4. We emphasize that our claims do not imply that during the early stages of training the decision boundary of the neural network is linear, but rather that there often exists a linear classifier highly agreeing with the network’s predictions. The decision boundary itself may be very complex. Figure 6 in Appendix C provides a simple illustration of this phenomenon.

The other core contribution of this paper is a novel formulation of a mutual-information based measure to quantify how much of the prediction success of the neural network produced by SGD can be attributed to a simple classifier. We believe this measure is of independent interest.

While our main findings relate to linear classifiers, our methodology extends beyond this. We conjecture that generally, the dynamics of SGD are such that it initially learns simpler components of its final classifier, and retains these as it continues to learn more and more complex parts (see Figure 2). We provide evidence for this conjecture in Section 4.

This paper is focused on binary classification tasks but our mutual-information based definitions and methodology can be extended to multi-class classification. Preliminary results suggest that our results continue to hold.

There is a substantial body of work that attempts to understand the generalization of (deep) neural networks, tackling the problem from different perspectives. Previous works by Hardt et al. (2016) and Kuzborskij & Lampert (2017) argue that generalization is due to stability. Neyshabur et al. (2015); Keskar et al. (2016); Bartlett et al. (2016) consider margin-based approaches, while Dziugaite & Roy (2017); Neyshabur et al. (2017); Neyshabur et al. (2018); Golowich et al. (2018); Pérez et al. (2019); Zhou et al. (2019) focus on PAC-Bayes analysis and norm-based bounds. Arora et al. (2018) propose a compression-based approach.

The implicit bias of (stochastic) gradient descent has also been studied in various contexts, including linear classification, matrix factorization and neural networks. This includes the works of Brutzkus et al. (2017); Gunasekar et al. (2017); Soudry et al. (2018); Gunasekar et al. (2018); Li et al. (2018); Wu et al. (2019) and Ji & Telgarsky (2019). There are also recent works proving generalization of overparameterized networks by analyzing the specific behavior of SGD from random initialization. These results are so far restricted to simplified settings.

Several prior works propose measures of the complexity of neural networks, and claim that training involves learning simple patterns. However, a key difference in our work is that our measures are intrinsic to the classification function and data-distribution (and do not depend on the representation of the classifier, or its behavior outside the data distribution). Moreover, our measures address the extent to which one classifier “explains” the performance of another.

The concept of mutual information has also been used in the study of neural networks, though in different ways than ours. For example, Shwartz-Ziv and Tishby (2017) use it to argue that a network compresses information as a means of noise reduction, saving only the most meaningful representation of the input.

Paper Organization.

We begin by defining our mutual-information based formalization of Claims 1 and 2 in Section 2. In Section 3, we establish the main result of the paper—that for many synthetic and real data sets, the performance of neural networks in the early phase of training is well explained by a linear classifier. In Section 4, we investigate extensions to non-linear classifiers (see also Remark 1). We make the case that as training proceeds, SGD moves beyond this “linear learning” regime, and learns concepts of increasing complexity. In Section 5 we focus on the overfitting regime. We provide a simple theoretical setting where, provably, if we start from a “simple” generalizable solution, then overfitting to the train set will not hurt generalization. Moreover, the overfit classifier retains the information from the initial classifier. Finally, in Section 6 we discuss future directions.

Performance Correlation via Mutual Information

In this section, we present our measures for the contribution of a “simple classifier” to the performance of a “more complex” one. This allows us to state what it means for the performance of a neural network to be “almost entirely explained by a linear classifier”, formalizing Claims 1 and 2.

Key to our formalism are the quantities of mutual information and conditional mutual information. Recall that for three random variables $X, Y, Z$, the mutual information between $X$ and $Y$ is defined as $I(X;Y) = H(Y) - H(Y|X)$ and the conditional mutual information between $X$ and $Y$ conditioned on $Z$ is defined as $I(X;Y|Z) = H(Y|Z) - H(Y|X,Z)$, where $H$ is the (conditional) entropy.
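As a quick sanity check of these two definitions, the following worked example evaluates them on a hypothetical "noisy copy" setting. The variables $N$ and $q$ and the binary entropy function $h$ are assumptions introduced only for this illustration; they do not appear in the paper.

```latex
% Illustrative assumption: Y ~ Bernoulli(1/2), and X = Y \oplus N,
% where N ~ Bernoulli(q) is independent of Y (so X is a copy of Y flipped w.p. q).
\begin{align*}
H(Y) &= 1 \text{ bit}, \qquad
H(Y \mid X) = h(q) := -q \log_2 q - (1-q) \log_2 (1-q), \\
I(X;Y) &= H(Y) - H(Y \mid X) = 1 - h(q), \\
I(X;Y \mid Z) &= H(Y \mid Z) - H(Y \mid X, Z) = 0 \quad \text{when } Z = X .
\end{align*}
```

In this toy case a noiseless copy ($q = 0$) carries the full bit about $Y$, a coin flip ($q = 1/2$) carries none, and conditioning on a variable that already determines $X$ leaves nothing further for $X$ to explain.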

2 Performance correlation

For random variables $F, L, Y$ we define the performance correlation of $F$ and $L$ as

$$\mu_Y(F;L) := I(F;Y) - I(F;Y|L).$$

Throughout this paper, we denote by $f_t$ the classifier SGD outputs on a randomly-initialized neural network after $t$ gradient steps, and denote by $F_t$ the corresponding random variable $f_t(X)$. We now formalize Claim 1 and Claim 2:

Claim 2 (Restated). In natural settings, for $t > T_0$, $\mu_Y(F_t;L)$ plateaus at value $\approx I(L;Y)$ and does not shrink significantly even if training continues until SGD fits all the training set.
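To see why $I(L;Y)$ is a natural plateau value, it may help to evaluate the performance correlation in a few idealized cases. The derivation below is a sketch that assumes the definition $\mu_Y(F;L) = I(F;Y) - I(F;Y|L)$, which is the form consistent with the surrounding discussion; the three cases themselves are illustrative and are not taken from the experiments.

```latex
% Assumed definition: \mu_Y(F;L) = I(F;Y) - I(F;Y \mid L).
\begin{align*}
\text{(a) } F = L: \quad
  & I(F;Y \mid L) = H(Y \mid L) - H(Y \mid L, L) = 0
  \;\Rightarrow\; \mu_Y(F;L) = I(F;Y) = I(L;Y). \\
\text{(b) } F = Y: \quad
  & I(F;Y) = H(Y), \quad I(F;Y \mid L) = H(Y \mid L)
  \;\Rightarrow\; \mu_Y(F;L) = H(Y) - H(Y \mid L) = I(L;Y). \\
\text{(c) } F \text{ independent of } (L, Y): \quad
  & I(F;Y) = 0, \quad I(F;Y \mid L) = 0
  \;\Rightarrow\; \mu_Y(F;L) = 0.
\end{align*}
```

Cases (a) and (b) show that both a network that exactly mimics the linear classifier and a perfect classifier attain $\mu_Y(F;L) = I(L;Y)$, whereas an uninformative $F$ as in case (c) gives zero.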

SGD Learns a Linear Model First

In this section, we provide experimental evidence for Claim 1—the first phase of SGD is dominated by “linear learning”—and Claim 2—at later stages SGD retains information from early phases. We demonstrate these claims by evaluating our information-theoretic measures empirically on real and simulated classification tasks.

We provide a brief description of our experimental setup here; a full description is provided in Appendix B. We focus on binary classification because: (1) there is a natural choice for the “simplest” model class (i.e., linear models), and (2) our mutual-information based metrics can be more accurately estimated from samples. We have preliminary work extending our results to the multi-class setting. We consider the following binary classification tasks:

Binary MNIST: predict whether the image represents a number from 0 to 4 or from 5 to 9.

CIFAR-10 Animals vs Objects: predict whether the image represents an animal or an object.

CIFAR-10 First 5 vs Last 5: predict whether the image is in classes $\{0\dots 4\}$ or $\{5\dots 9\}$.

We train neural networks with standard architectures: CNNs for image-recognition tasks and Multi-layer Perceptrons (MLPs) for the other tasks. We use standard uniform Xavier initialization and we train with binary cross-entropy loss. In all experiments, we use vanilla SGD without regularization (e.g., dropout, weight decay) for simplicity and consistency. (Preliminary experiments suggest our results are robust with respect to these choices). We use a relatively small step-size for SGD, in order to more closely examine the early phase of training.

Results and Discussion.

The results of our experiments are presented in Figure 4. We observe the following similar behaviors across several architectures and datasets:

Define the first phase of training as all steps $t \leq T_0$, where $T_0$ is the first SGD step such that the network’s performance $I(F_t;Y)$ reaches the linear model’s performance $I(L;Y)$. Now:

In the following epochs, for $t > T_0$, $\mu_Y(F_t;L)$ plateaus around $I(L;Y)$. This means that $F_t$ retains its correlation with $L$, which keeps explaining as much of $F_t$’s generalization performance as possible.
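The definition of the first phase above can be turned into a small helper that locates $T_0$ on a recorded training curve. This is a minimal sketch: the function name `first_phase_end`, the tolerance parameter `tol`, and the example numbers are all hypothetical and only illustrate the "first step at which $I(F_t;Y)$ reaches $I(L;Y)$" rule.

```python
from typing import Optional, Sequence


def first_phase_end(network_mi: Sequence[float],
                    linear_mi: float,
                    tol: float = 0.0) -> Optional[int]:
    """Return the first step t with I(F_t;Y) >= I(L;Y) - tol, or None.

    network_mi[t] is an estimate of I(F_t;Y) at step t; linear_mi is an
    estimate of I(L;Y). `tol` is a hypothetical slack for noisy estimates.
    """
    for t, value in enumerate(network_mi):
        if value >= linear_mi - tol:
            return t
    return None  # the network never reached the linear model's performance


# Made-up curve, for illustration only.
curve = [0.00, 0.10, 0.25, 0.31, 0.40]
t0 = first_phase_end(curve, linear_mi=0.30)
```

With this made-up curve the first phase would consist of steps 0 through `t0`, and everything after it belongs to the later regime discussed in the second observation.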

Observation (1) provides strong support for Claim 1. Since neural networks are a richer class than linear classifiers, a priori one might expect that throughout the learning process, some of the growth in the mutual information between the label $Y$ and the classifier’s output $F_t$ will be attributable to the linear classifier, and some of this growth will be attributable to a more complex classifier. However, what we observe is a relatively clean (though not perfect) separation of the learning process: in the initial phase, all of the mutual information between $F_t$ and $Y$ disappears if we condition on $L$.

Observation (2) can be seen as offering support to Claim 2. If SGD “forgets” the linear model as it continues to fit the training examples, then we would expect the value of $\mu_Y(F_t;L)$ to shrink with time. However, this does not occur. Since the linear classifier itself would generalize, this explains at least part of the generalization performance of $F_t$. To fully explain the generalization performance, we would need to extend this theory to models more complex than linear; some preliminary investigations are given in Section 4.

Table 1 summarizes the qualitative behavior of several information theoretic quantities we observe across different datasets and architectures. We stress that these phenomena would not occur for an arbitrary learning algorithm that increases model test accuracy. Rather, it is SGD (with a random, or at least “non-pathological” initialization, see Section 5 and Figure 8) that produces such behavior.

Beyond Linear: SGD Learns Functions of Increasing Complexity

In this section we investigate Remark 1—that SGD learns functions of increasing complexity—through the lens of the mutual information framework, and provide experimental evidence supporting the natural extension of the results from Section 3 to models more complex than linear.

Conjecture 1 (Beyond linear classifiers: Remark 1 restated). There exist increasingly complex functions $(g_1, g_2, \dots)$ under some measure of complexity, and a monotonically increasing sequence $(T_1, T_2, \dots)$ such that $\mu_Y(F_t;G_i) \approx I(F_t;Y)$ for $t \leq T_i$ and $\mu_Y(F_t;G_i) \approx I(G_i;Y)$ for $t > T_i$. Note that implicit in our conjecture is that each $G_i$ is itself explained by $G_{<i}$, so we should not have to condition on all previous $G_i$’s; i.e. $\mu_Y(F_t;(G_{1:i})) \approx \mu_Y(F_t;G_i)$.

It is problematic to show Conjecture 1 in full generality, as the correct measure of complexity is unclear; it may depend on the distribution, architecture, and even initialization. Nevertheless, we are able to support it in the image-classification setting, parameterizing complexity using the number of convolutional layers.

In order to explore the behavior of more complex classifiers we consider the CIFAR “First 5 vs. Last 5” task introduced in Section 3, for which there is no high-accuracy linear classifier. We observed that the performance of various architectures on this task was similar to their performance on the full 10-way CIFAR classification task, which supports the relevance of this example to standard use-cases (potentially because the task requires distinguishing between visually similar classes, e.g. automobile/truck or cat/dog).

As our model $f$, we train an 18-layer pre-activation ResNet, which achieves over 90% accuracy on this task. For the simple models $g_i$, we use convolutional neural networks corresponding to the 2nd, 4th, and 6th shallowest layers of the network for $f$. Similarly to Section 3, the models $g_i$ are trained on the images labeled by $f_\infty$ (that is, the model at the end of training). For more details refer to Appendix B: “Finding the Conditional Models”.

Results and Discussion.

Our results are illustrated in Figure 5. We can see a separation in phases for learning, where all curves $\mu_Y(F_t;G_i)$ are initially close to $I(F_t;Y)$, before each successively plateaus as training progresses. Moreover, note that $I(G_i;Y)$ remains flat in the overfitting regime for all three $i$, demonstrating that SGD does not “forget” the simpler functions as stated in Claim 2.

Overfitting Does Not Hurt Generalization

In the previous sections, we investigated the early and middle phases of SGD training. In this section, we focus on the last phase, i.e. the overfitting regime. In practice, we often observe that in late phases of training, train error goes to 0, while test error stabilizes, despite the fact that bad ERMs exist. The previous sections suggest that this phenomenon is an inherent property of SGD in the overparameterized setting, where training starts from a “simpler” model at the beginning of the overfitting regime and does not forget it even as it learns more “complex” models and fits the noise.

In what follows, we demonstrate this intuition formally in an illustrative simplified setting where, provably, a heavily overparameterized (linear) model trained with SGD fits the training set exactly, and yet its population accuracy is optimal for a class of “simple” initializations.

We confine ourselves to the linear classification setting. To formalize notions of “simple” we consider a data distribution that explicitly decomposes into a component explainable by a sparse classifier, and a remaining orthogonal noisy component on which it is possible to overfit. Specifically, we define the data distribution $\mathcal{D}$ as follows:

Theorem 1. Consider training a linear classifier via minimizing the empirical square loss using SGD. Let $\varepsilon > 0$ be a small constant and let the initial vector $\bm{w}_0$ satisfy $\bm{w}_0(1) \geq -n^{0.99}$, and $|\bm{w}_0(i)| \leq 1-2p-\varepsilon$ for all $i > 1$. Then, with high probability, sample accuracy approaches 1 and population accuracy approaches $1-p$ as the number of gradient steps goes to infinity.

Proof sketch. The displacement of the weight vector from initialization will always lie in the span of the sample vectors which, because the samples are sparse, is in expectation almost orthogonal to the population. Moreover, as long as the initialization is bounded sufficiently, the first coordinate of the learned vector will approach a constant. The full proof is deferred to Appendix A. ∎

Discussion and Future Work

Our findings yield new insight into the inductive bias of SGD on deep neural networks. In particular, it appears that SGD increases the complexity of the learned classifier as training progresses, starting by learning an essentially linear classifier.

There are several natural questions that arise from our work. First, why does this “linear learning” occur? We pose this problem of understanding why Claims 1 and 2 are true as an important direction for future work. Second, what is the correct measure of complexity which SGD increases over time? That is, we would like the correct formalization of Conjecture 1—ideally with a measure of complexity that implies generalization. We view our work as an initial step in a framework towards understanding why neural networks generalize, and we believe that theoretically establishing our claims would be significant progress in this direction.

Acknowledgements. We thank all of the participants of the Harvard ML Theory Reading Group for many useful discussions and presentations that motivated this work. We especially thank: Noah Golowich, Yamini Bansal, Thibaut Horel, Jarosław Błasiok, Alexander Rakhlin, and Madhu Sudan.

This work was supported by NSF awards CCF 1565264, CNS 1618026, CCF 1565641, CCF 1715187, NSF GRFP Grant No. DGE1144152, a Simons Investigator Fellowship, and Investigator Award.

References

Appendix A Proof of Theorem 1

For the simplicity of our argument, we work with the following assumptions on the training set.

Label noise: Exactly a $p$ fraction of the sample points have their first coordinate flipped.

Orthogonality: the non-zero coordinates are distinct for all $n$ data points (except for the first coordinate).

Notice that by the fact that $n = o(\sqrt{d})$ and a simple union bound, Assumption 2 holds with high probability. For each $i \in [d]$, we let $j(i)$ denote the index $j$ that satisfies $\bm{x}_j(i) = 1$, if it exists. To simplify the notation, we assume that all labels $y_i$ are $1$; this is without loss of generality, since one can always replace $\bm{x}_i$ with $y_i\bm{x}_i$.

In order to prove Theorem 1, we will precisely characterize the limiting behavior of SGD in this setting. We remark that as the optimization objective is strongly convex, SGD and GD with appropriate choices of step size are guaranteed to converge to the global minimum.

as the number of steps goes to infinity. Moreover, we have

where $\bm{s}$ is the first column of $\bm{X}$ and $\eta = 1-2p$, and for $i \neq 1$

We focus on the gradient descent case; the proof for SGD is analogous. Consider the empirical loss $\mathcal{L}(\bm{w})$. By the definition of gradient descent, the iterations always stay in the affine space $V = \bm{w}_0 + \text{RowSpan}(\bm{X})$. Gradient descent solves the linear least squares problem $\min_{\bm{w}\in V}\mathcal{L}(\bm{w})$. We claim that $\bm{w}'$ is indeed the optimal solution to this program, and thus gradient descent converges to it.
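The claim that the iterates never leave the affine space $\bm{w}_0 + \text{RowSpan}(\bm{X})$ can be checked numerically on a tiny instance. The sketch below is only an illustration under stated assumptions: the two sample vectors, the initialization, the step size, and the helper names are made up, and the data is a hand-picked toy rather than a draw from the paper's distribution.

```python
def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def gradient_descent(X, y, w0, lr=0.5, steps=200):
    """Full-batch gradient descent on the mean squared loss (1/2n) * sum (x.w - y)^2."""
    n = len(X)
    w = list(w0)
    for _ in range(steps):
        residuals = [dot(x, w) - t for x, t in zip(X, y)]
        # The gradient is a linear combination of the rows of X.
        grad = [sum(r * x[k] for r, x in zip(residuals, X)) / n
                for k in range(len(w))]
        w = [wk - lr * gk for wk, gk in zip(w, grad)]
    return w


# Toy data (assumption): shared first coordinate, distinct second non-zero
# coordinate per sample, and a fourth coordinate that no sample touches.
X = [[1.0, 1.0, 0.0, 0.0],
     [1.0, 0.0, 1.0, 0.0]]
y = [1.0, 1.0]
w0 = [0.0, 0.2, -0.1, 0.3]

w = gradient_descent(X, y, w0)
displacement = [a - b for a, b in zip(w, w0)]
orthogonal = [1.0, -1.0, -1.0, 0.0]  # orthogonal to both rows of X
max_residual = max(abs(dot(x, w) - t) for x, t in zip(X, y))
```

Here `displacement` has no component along `orthogonal`, the untouched fourth coordinate keeps its initial value, and the training residuals vanish, which is the small-scale analogue of fitting the sample exactly while leaving unseen coordinates at initialization.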

First, one can check that $\bm{X}\bm{X}^T$ is non-singular under Assumption 2, since $\bm{X}\bm{X}^T = \bm{I} + \bm{s}\bm{s}^T$ and $\bm{s}^T\bm{s} \neq -1$. Moreover, $\bm{w}' \in V$ since the range of $\bm{X}^T(\bm{X}\bm{X}^T)^{-1}$ lies in the row space of $\bm{X}$. Finally, we check that $\bm{w}'$ achieves zero empirical loss

For the first coordinate of $\bm{w}'$, by the Sherman-Morrison formula [11, p. 51],

Substituting this into (1) and simplifying yields claims (2) and (3). ∎
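For reference, the Sherman-Morrison step used in the proof above can be written out for the specific matrix $\bm{I} + \bm{s}\bm{s}^T$. The block below states the general identity and its specialization, followed by a direct verification; the symbols $\bm{A}, \bm{u}, \bm{v}$ and $\sigma$ are notation introduced here for the illustration only.

```latex
% General Sherman-Morrison identity (A invertible, 1 + v^T A^{-1} u \neq 0):
\[
(\bm{A} + \bm{u}\bm{v}^T)^{-1}
  = \bm{A}^{-1} - \frac{\bm{A}^{-1}\bm{u}\bm{v}^T\bm{A}^{-1}}{1 + \bm{v}^T\bm{A}^{-1}\bm{u}} .
\]
% Specialization to A = I, u = v = s, writing \sigma := s^T s \geq 0:
\[
(\bm{I} + \bm{s}\bm{s}^T)^{-1} = \bm{I} - \frac{\bm{s}\bm{s}^T}{1 + \sigma} .
\]
% Verification by direct multiplication, using s s^T s s^T = \sigma s s^T:
\[
(\bm{I} + \bm{s}\bm{s}^T)\Big(\bm{I} - \frac{\bm{s}\bm{s}^T}{1+\sigma}\Big)
  = \bm{I} + \bm{s}\bm{s}^T - \frac{\bm{s}\bm{s}^T}{1+\sigma} - \frac{\sigma\,\bm{s}\bm{s}^T}{1+\sigma}
  = \bm{I} .
\]
```

Since $\sigma = \bm{s}^T\bm{s} \geq 0$, the denominator $1 + \sigma$ is never zero, which matches the non-singularity condition $\bm{s}^T\bm{s} \neq -1$ stated in the proof.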

Let $k(j)$ denote the non-zero coordinate of $\bm{x}_j$ (besides the first coordinate). First we note that

Because each $\bm{x}_j(k(j))$ is a $\text{Bernoulli}(p)$ random variable and every $\bm{w}_0(k(j)) \leq 1$, we can apply Chebyshev’s inequality to further deduce that with high probability, $\bm{s}^T\bm{X}\bm{w}_0 = n\bm{w}_0(1) + O(\sqrt{n})$. Substituting this into (2) of Lemma 1, letting $\bm{w}'$ be the weight vector SGD converges to, we obtain

By (3) of Lemma 1, for every coordinate $i$ such that $\bm{x}_j(i) = 0$ for all $j$, we have $\bm{w}'(i) = \bm{w}_0(i)$. Consider a point $\bm{x}$ drawn from the population, and let $k$ be the index of the non-zero coordinate of $\bm{x}$. With high probability, $k \neq k(j)$ for all $j \in [n]$. With probability $1-p$, $\bm{x}(1) = 1$, and in this case we obtain

For sufficiently large $n$ (corresponding to sufficiently large $d$) this quantity is always positive, so with probability approaching $1-p$ the model correctly classifies $\bm{x}$. ∎

Appendix B Experimental Setup and Results for Sections 3 and 4

We used the following four datasets in our experiments.

Binary MNIST: predict whether the image represents a number from 0 to 4 or from 5 to 9. It admits a linear classifier with accuracy $\approx 87\%$,

CIFAR-10 Animals vs Objects: predict whether the image represents an animal or an object. In order not to bias the task towards either class, we included all 4 object classes (airplane, automobile, ship, truck) and only 4 of the 6 animal classes (bird, cat, dog, horse). Hence the numbers of positive and negative samples are the same. CIFAR-10 Animals vs Objects admits a linear classifier with accuracy $\approx 75\%$,

CIFAR-10 First 5 vs Last 5: predict whether an image belongs to any of the first 5 classes of CIFAR10 (airplane, automobile, bird, cat, deer) or the last 5 classes (dog, frog, horse, ship, truck). CIFAR-10 First 5 vs Last 5 does not admit a linear classifier with satisfactory accuracy. The best linear classifier achieves accuracy of $\approx 58\%$,

In the cases of datasets (i), (ii), and (iii) we created the train and test sets by relabeling the train and test sets of MNIST and CIFAR10 with $\{0,1\}$ labels according to the specific dataset (and excluded the images that are not relevant). All experiments are repeated 10 times with random initialization; standard deviations are reported in the figures (shaded area).
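The relabeling step described above can be sketched as follows. This is an illustrative reconstruction rather than the authors' code: the helper names are hypothetical, and the choice of which side of each task receives label 0 versus 1 is an assumption, since the text does not specify it. The class order is the standard CIFAR-10 order, which matches the listing in the task descriptions.

```python
from typing import Callable, List, Optional, Sequence, Tuple

CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]
OBJECTS = {"airplane", "automobile", "ship", "truck"}
ANIMALS = {"bird", "cat", "dog", "horse"}  # deer and frog are excluded


def first5_vs_last5(label: int) -> Optional[int]:
    """Classes/digits 0-4 -> 0, classes/digits 5-9 -> 1 (label polarity is an assumption)."""
    return 0 if label <= 4 else 1


def animals_vs_objects(label: int) -> Optional[int]:
    """Object -> 0, animal -> 1, excluded class -> None (label polarity is an assumption)."""
    name = CIFAR10_CLASSES[label]
    if name in ANIMALS:
        return 1
    if name in OBJECTS:
        return 0
    return None


def relabel(labels: Sequence[int],
            rule: Callable[[int], Optional[int]]) -> Tuple[List[int], List[int]]:
    """Return (indices of kept examples, their binary labels)."""
    kept, binary = [], []
    for index, label in enumerate(labels):
        new_label = rule(label)
        if new_label is not None:  # drop images that are not relevant to the task
            kept.append(index)
            binary.append(new_label)
    return kept, binary


# Made-up original labels, for illustration only.
kept_a, binary_a = relabel([0, 4, 5, 9], first5_vs_last5)
kept_b, binary_b = relabel([0, 2, 4, 6, 9], animals_vs_objects)
```

The same `first5_vs_last5` rule covers both Binary MNIST and CIFAR-10 First 5 vs Last 5, while `animals_vs_objects` additionally drops the two excluded animal classes so that the binary task stays balanced.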

Model details.

Our results were consistent across various architectures and hyperparameter choices. For the Sinusoid distribution we train a 2-layer MLP with ReLU activations. Each layer has 256 neurons. For the MNIST and CIFAR tasks in Section 3 we train a CNN with 4 2D-Conv layers; each layer has 32 filters of size $3\times 3$. After the first and second layers we have a $2\times 2$ Max-Pooling layer, and at the end of the 4 convolutional layers we have two Dense layers of 2000 units. The activations on all intermediate layers are ReLUs. Across all architectures the last layer is a sigmoid neuron.

Training procedure.

We initialize the neural networks with Uniform Xavier. We note that in all the experiments we use vanilla SGD without regularization (e.g. dropout) since we want to isolate and investigate purely the effect of the optimization algorithm. We use batch size of 32 for MLPs and 64 for CNNs. For the MLPs in Section 3, the learning rate is 0.01. For the CNNs in Section 3, the learning rate is 0.001. For Section 4, we train the ResNet and all smaller CNNs using SGD with momentum 0.9, batch size 128 and learning rate 0.01.
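For readers unfamiliar with the initialization named above, the following is a minimal standard-library sketch of uniform Xavier (Glorot) initialization for a single dense layer, where weights are drawn uniformly from $[-a, a]$ with $a = \sqrt{6/(\text{fan\_in} + \text{fan\_out})}$. The function names and the layer sizes in the example are assumptions for illustration; the experiments presumably relied on a deep learning framework's built-in initializer.

```python
import math
import random
from typing import List


def xavier_limit(fan_in: int, fan_out: int) -> float:
    """Half-width of the uniform range used by uniform Xavier initialization."""
    return math.sqrt(6.0 / (fan_in + fan_out))


def xavier_uniform(fan_in: int, fan_out: int, rng: random.Random) -> List[List[float]]:
    """Return a fan_out x fan_in weight matrix with entries drawn from U(-limit, limit)."""
    limit = xavier_limit(fan_in, fan_out)
    return [[rng.uniform(-limit, limit) for _ in range(fan_in)]
            for _ in range(fan_out)]


# Hypothetical layer sizes, chosen only for this example.
rng = random.Random(0)
weights = xavier_uniform(fan_in=256, fan_out=128, rng=rng)
limit = xavier_limit(256, 128)
```

The range shrinks as the layer gets wider, which keeps the scale of activations roughly comparable across layers at the start of training.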

In Section 4, we perform a similar procedure: After training $f_\infty$, we train simple models $g_i$ on the predictions of $f_\infty$ on the train set, via SGD.
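The idea of fitting a simple model to the predictions of the fully trained model, rather than to the true labels, can be sketched on a toy problem. Everything below is a hypothetical stand-in: `f_inf` is a made-up rule playing the role of the trained network, the simple model is a logistic regression rather than a shallow CNN, and the hyperparameters are arbitrary.

```python
import math
import random


def f_inf(x):
    """Hypothetical stand-in for the hard predictions of the fully trained model."""
    return 1 if x[0] + 0.5 * x[1] > 0 else 0


def train_simple_model(inputs, teacher_labels, lr=0.5, epochs=20, seed=0):
    """Fit a logistic-regression model to the teacher's labels with plain SGD (batch size 1)."""
    rng = random.Random(seed)
    w = [0.0] * len(inputs[0])
    b = 0.0
    order = list(range(len(inputs)))
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            x, target = inputs[i], teacher_labels[i]
            z = sum(wk * xk for wk, xk in zip(w, x)) + b
            pred = 1.0 / (1.0 + math.exp(-z))
            err = pred - target  # gradient of binary cross-entropy w.r.t. the logit z
            w = [wk - lr * err * xk for wk, xk in zip(w, x)]
            b -= lr * err
    return w, b


def predict(w, b, x):
    """Hard 0/1 prediction of the simple model."""
    return 1 if sum(wk * xk for wk, xk in zip(w, x)) + b > 0 else 0


data_rng = random.Random(0)
train_x = [[data_rng.uniform(-1, 1), data_rng.uniform(-1, 1)] for _ in range(200)]
teacher = [f_inf(x) for x in train_x]  # labels come from f_inf, not from ground truth
w, b = train_simple_model(train_x, teacher)
agreement = sum(predict(w, b, x) == t for x, t in zip(train_x, teacher)) / len(train_x)
```

Training on the teacher's outputs means the resulting simple model approximates the part of the complex model that the simple class can express, which is the role the conditional models play in the mutual-information measurements.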

Estimating the Mutual Information.

Let $f, g$ be two classifiers, and $y$ be the true labels. In order to estimate our mutual-information metrics, we use the empirical distribution of $(F, G, Y)$ on the test set. For example, to estimate $I(F;Y|G)$ we use the definition

$$I(F;Y|G) = \sum_{f,y,g} p(f,y,g) \log \frac{p(f,y|g)}{p(f|g)\,p(y|g)} \tag{6}$$

where $p(f,y,g)$ is the joint probability density function of $(F,Y,G)$ and $p(f|g)$, etc. are the conditional density functions. To estimate this quantity, we first compute the empirical distribution of $(f,y,g) \in \{0,1\}^3$ over the test set. Let $\hat{p}(f,y,g)$ be this empirical density function. Then we estimate $I(F;Y|G)$ by evaluating Equation 6 using $\hat{p}$ in place of $p$.
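The plug-in estimation procedure described above can be sketched with the standard library alone. This is a minimal illustration under stated assumptions: the function names are hypothetical, logarithms are taken base 2 (so values are in bits, a choice the text does not specify), and the performance correlation is assumed to be $I(F;Y) - I(F;Y|L)$.

```python
import math
from collections import Counter


def conditional_mi(f, y, g):
    """Plug-in estimate of I(F;Y|G) in bits from paired samples of discrete predictions."""
    n = len(f)
    joint = Counter(zip(f, y, g))   # counts for p(f, y, g)
    fg = Counter(zip(f, g))         # counts for p(f, g)
    yg = Counter(zip(y, g))         # counts for p(y, g)
    g_counts = Counter(g)           # counts for p(g)
    total = 0.0
    for (fv, yv, gv), count in joint.items():
        # p(f,y|g) / (p(f|g) p(y|g)) written in terms of raw counts.
        ratio = count * g_counts[gv] / (fg[(fv, gv)] * yg[(yv, gv)])
        total += (count / n) * math.log2(ratio)
    return total


def mutual_information(f, y):
    """I(F;Y), obtained by conditioning on a constant variable."""
    return conditional_mi(f, y, [0] * len(f))


def performance_correlation(f, l, y):
    """Assumed definition: mu_Y(F;L) = I(F;Y) - I(F;Y|L)."""
    return mutual_information(f, y) - conditional_mi(f, y, l)


# Tiny made-up samples, for illustration only.
labels = [0, 1, 0, 1]
perfect = list(labels)        # a classifier that always matches the label
unrelated = [0, 0, 1, 1]      # predictions independent of the label in this sample
mu_perfect = performance_correlation(perfect, perfect, labels)
```

Only triples that actually occur in the sample contribute to the sum, so no special handling of zero-probability cells is needed; with binary predictions there are at most eight cells, which is why the estimate is stable on test sets of moderate size.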

Further Quantitative Results.

In Tables 2 and 3 we provide further quantitative results for our experiments in Sections 3 and 4 respectively.

Appendix C Additional Plots