Learning Overparameterized Neural Networks via Stochastic Gradient Descent on Structured Data
Yuanzhi Li, Yingyu Liang
Introduction
Neural networks have achieved great success in many applications, but despite a recent increase in theoretical studies, much remains to be explained. For example, it is empirically observed that learning with stochastic gradient descent (SGD) in the overparameterized setting (i.e., learning a large network whose number of parameters exceeds the number of training data points) does not lead to overfitting. Some recent studies use the low complexity of the learned solution to explain generalization, but usually do not explain how SGD or its variants favor low-complexity solutions (i.e., the inductive bias or implicit regularization). It is also observed that overparameterization and proper random initialization can help optimization, but it is not well understood why a particular initialization can improve learning. Moreover, most existing works that try to explain these phenomena rely on unrealistic assumptions about the data distribution, such as Gaussianity and/or linear separability.
This paper thus proposes to study the problem of learning a two-layer overparameterized neural network using SGD for classification, on data with a more realistic structure. In particular, the data in each class is a mixture of several components, and components from different classes are well separated in distance (but the components in each class can be close to each other). This is motivated by practical data. For example, in MNIST, each class corresponds to a digit and can have several components corresponding to different writing styles of the digit, and an image in it is a small perturbation of one of the components. On the other hand, images that belong to the same component are closer to each other than to an image of another digit. Analysis in this setting can then help us understand how the structure of practical data affects optimization and generalization.
In this setting, we prove that when the network is sufficiently overparameterized, SGD provably learns a network close to the random initialization and with a small generalization error. This result shows that in the overparameterized setting and when the data is well structured, though in principle the network can overfit, SGD with random initialization introduces a strong inductive bias and leads to good generalization.
Our result also shows that the overparameterization requirement and the learning time depend on parameters inherent to the structure of the data, but not on the ambient dimension of the data. More importantly, the analysis behind the result also provides some interesting theoretical insights into various aspects of learning neural networks. It reveals that the success of learning crucially relies on overparameterization and random initialization. Combined, these two lead to a tight coupling around the initialization between SGD and another learning process that has a benign optimization landscape. This coupling, together with the structure of the data, allows SGD to find a solution that has a low generalization error while still remaining in the aforementioned neighborhood of the initialization. Our work makes a step towards explaining how overparameterization and random initialization help optimization, and how the inductive bias and good generalization arise from the SGD dynamics on structured data. Some other, more technical implications of our analysis will be discussed in later sections, such as the existence of a good solution close to the initialization and the low-rankness of the learned weights. Complementary empirical studies on synthetic data and on the benchmark dataset MNIST support the analysis and insights.
Related Work
Generalization of neural networks. Empirical studies show interesting phenomena about the generalization of neural networks: practical neural networks have the capacity to fit random labels of the training data, yet they still generalize well when trained on practical data. These networks are overparameterized in that they have more parameters than statistically necessary, and their good generalization cannot be explained by naïvely applying traditional theory. Several lines of work have proposed low-complexity measures of the learned network and derived generalization bounds to better explain these phenomena: prior work has proved spectrally-normalized margin-based generalization bounds, derived bounds from a PAC-Bayes approach, and derived bounds from the compression point of view. In general, these works do not address why the low complexity arises. This paper takes a step in this direction, though for two-layer networks and a simplified model of the data.
Overparameterization and implicit regularization. The training objectives of overparameterized networks in principle have many (approximate) global optima, some of which generalize better than others, while empirical observations imply that the optimization process in practice prefers those with better generalization. It is then an interesting question how this implicit regularization or inductive bias arises from the optimization and the structure of the data. Recent studies analyze SGD on different tasks, such as logistic regression and matrix factorization. More closely related to our work is the study of learning a two-layer overparameterized network on linearly separable data, which shows that SGD converges to a global optimum with good generalization. Our work studies the problem on data with a well-clustered (and potentially not linearly separable) structure that we believe is closer to practical scenarios, and thus can advance this line of research.
Theoretical analysis of learning neural networks. There also exists a large body of work that analyzes the optimization landscape of learning neural networks. These works generally make unrealistic assumptions about the data, such as Gaussianity, and/or strong assumptions about the network, such as using only linear activations. They also do not study the implicit regularization of the optimization algorithms.
Problem Setup
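In this work, a two-layer neural network with ReLU activation for $k$-class classification is given by $f = (f_1, f_2, \dots, f_k)$ such that for each $i \in [k]$: $f_i(x) = \sum_{r=1}^{m} a_{i,r}\,\mathrm{ReLU}(\langle w_r, x\rangle)$, where $w_r \in \mathbb{R}^d$ is the weight vector of the $r$-th of the $m$ hidden nodes, $a_{i,r} \in \mathbb{R}$ are the output weights, and $\mathrm{ReLU}(z) = \max\{z, 0\}$.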
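As a concrete reading of this definition, here is a minimal forward-pass sketch in plain Python. It assumes inputs and weights are given as lists of floats; the function names `relu`, `forward` and `predict` are illustrative and not from the paper.

```python
from typing import List

Vector = List[float]


def relu(z: float) -> float:
    """ReLU(z) = max{z, 0}."""
    return z if z > 0.0 else 0.0


def forward(x: Vector, W: List[Vector], a: List[Vector]) -> Vector:
    """Class scores f_i(x) = sum_r a[i][r] * ReLU(<w_r, x>) for i = 0..k-1.

    W holds the m hidden weight vectors w_r (each of dimension d);
    a is a k x m matrix of output weights a_{i,r}.
    """
    hidden = [relu(sum(wj * xj for wj, xj in zip(w_r, x))) for w_r in W]
    return [sum(a_ir * h_r for a_ir, h_r in zip(a_i, hidden)) for a_i in a]


def predict(x: Vector, W: List[Vector], a: List[Vector]) -> int:
    """Predicted label: the index of the largest class score."""
    scores = forward(x, W, a)
    return max(range(len(scores)), key=lambda i: scores[i])


# Tiny example with d = 2, m = 2 hidden nodes and k = 2 classes.
W = [[1.0, 0.0], [0.0, 1.0]]
a = [[1.0, 0.0], [0.0, 1.0]]
print(forward([0.6, 0.8], W, a))  # [0.6, 0.8]
print(predict([0.6, 0.8], W, a))  # 1
```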
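Let us define the support of a distribution $\mathcal{D}$ with density $p$ over $\mathbb{R}^d$ as $\mathrm{supp}(\mathcal{D}) = \{x : p(x) > 0\}$, the distance between two sets $S_1, S_2 \subseteq \mathbb{R}^d$ as $\mathrm{dist}(S_1, S_2) = \min_{x \in S_1, y \in S_2} \|x - y\|_2$, and the diameter of a set $S \subseteq \mathbb{R}^d$ as $\mathrm{diam}(S) = \max_{x, y \in S} \|x - y\|_2$. Then we are ready to make the assumptions about the data.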
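(Separability) There exists $\delta > 0$ such that for every $i_1 \neq i_2 \in [k]$ and every $j_1, j_2 \in [l]$, $\mathrm{dist}\big(\mathrm{supp}(\mathcal{D}_{i_1,j_1}), \mathrm{supp}(\mathcal{D}_{i_2,j_2})\big) \geq \delta$. Moreover, for every $i \in [k]$ and every $j \in [l]$, $\mathrm{diam}\big(\mathrm{supp}(\mathcal{D}_{i,j})\big) \leq \lambda\delta$, for $\lambda \leq 1/(8l)$. (The assumption $\lambda \le 1/(8l)$ can be relaxed to $\lambda \le 1/((1+\alpha)l)$ for any $\alpha > 0$ by paying a large polynomial in $1/\alpha$ in the sample complexity. We will not prove it in this paper because we would like to highlight the key factors.)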
(Normalization) Any $x$ from the distribution has $\|x\|_2 = 1$.
A few remarks are in order. Instead of having one distribution per class, we allow an arbitrary number $l$ of distributions in each class, which we believe is a better fit to real data. For example, in MNIST, a class $i$ can be the digit 1, and $\mathcal{D}_{i,1}, \dots, \mathcal{D}_{i,l}$ can be the different styles of writing that digit.
Assumption (A2) is for simplicity, while (A1) is our key assumption. With $l$ distributions inside each class, our assumption allows data that is not linearly separable, e.g., XOR-type data in $\mathbb{R}^2$ where there are two classes, one consisting of two balls of a small diameter with centers $(0,0)$ and $(1,1)$, and the other consisting of two balls of the same diameter with centers $(0,1)$ and $(1,0)$. See Figure 3 in Appendix C for an illustration. Moreover, essentially the only assumption we have here is $\lambda = O(1/l)$. When $l = 1$, $\lambda = O(1)$, which is the minimal requirement on the order of $\lambda$ for the distribution to be efficiently learnable. Our work allows larger $l$, so that the data can be more complicated inside each class; in this case, we require the separation to be higher relative to the diameters. When we increase $l$ to refine the distributions inside each class, we should expect the diameter of each distribution to become smaller as well. As long as the diameters decrease faster than the number of distributions grows, our assumption will hold.
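The quantities in the separability assumption can be computed directly on finite samples. The sketch below is our own illustration, not part of the paper's method: it approximates each support $\mathrm{supp}(\mathcal{D}_{i,j})$ by a few sample points, takes $\delta$ to be the smallest distance between components of different classes, and checks the diameter condition for a given bound on $\lambda$ (the parameter `lam`). It is applied to XOR-type data whose point sets and spread we chose, and it ignores the normalization (A2).

```python
import itertools
import math
from typing import Dict, List, Sequence, Tuple

Point = Sequence[float]


def dist(S1: List[Point], S2: List[Point]) -> float:
    """dist(S1, S2): smallest Euclidean distance between two finite point sets."""
    return min(math.dist(x, y) for x in S1 for y in S2)


def diam(S: List[Point]) -> float:
    """diam(S): largest pairwise Euclidean distance inside S (0 for a single point)."""
    return max((math.dist(x, y) for x, y in itertools.combinations(S, 2)), default=0.0)


def check_separability(components: Dict[Tuple[int, int], List[Point]], lam: float):
    """components maps (class i, component j) to sample points from D_{i,j}.

    Returns (delta, ok): delta is the smallest distance between components of
    different classes, and ok says whether every diameter is <= lam * delta.
    """
    keys = list(components)
    delta = min(dist(components[p], components[q])
                for p, q in itertools.combinations(keys, 2) if p[0] != q[0])
    ok = all(diam(components[p]) <= lam * delta for p in keys)
    return delta, ok


# XOR-type data: class 0 near (0,0) and (1,1), class 1 near (0,1) and (1,0).
xor = {
    (0, 0): [(0.0, 0.0), (0.01, 0.0)],
    (0, 1): [(1.0, 1.0), (1.0, 0.99)],
    (1, 0): [(0.0, 1.0), (0.01, 1.0)],
    (1, 1): [(1.0, 0.0), (0.99, 0.0)],
}
components_per_class = 2  # l in the text
print(check_separability(xor, lam=1 / (8 * components_per_class)))
```

No single line separates the two classes here, yet the cross-class distance (about 0.98) is far larger than the component diameters (0.01), so the condition holds with room to spare.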
Assumptions about the learning process. We will only learn the hidden-layer weights $W = (w_1, \dots, w_m)$ to simplify the analysis. Since the ReLU activation is positive homogeneous, the effect of overparameterization can still be studied, and a similar approach has been adopted in previous work. So the network is also written as $y = f(x, w) = (f_1(x, w), \dots, f_k(x, w))$ for $w = (w_1, \dots, w_m)$.
We assume the learning is from a random initialization:
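The learning process minimizes the cross entropy loss over the softmax, defined as:
$$L(w) = -\frac{1}{N}\sum_{s=1}^{N} \log o_{y_s}(x_s, w), \qquad o_y(x, w) = \frac{e^{f_y(x, w)}}{\sum_{i=1}^{k} e^{f_i(x, w)}},$$
where $\{(x_s, y_s)\}_{s=1}^{N}$ are the training examples.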
Let $L(w, x_s, y_s) = -\log o_{y_s}(x_s, w)$ denote the cross entropy loss for a particular point $(x_s, y_s)$.
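We consider a minibatch SGD of batch size $B$, number of iterations $T$ and learning rate $\eta$ as the following process: Randomly divide the total training examples into $T$ batches, each of size $B$. Let the indices of the examples in the $t$-th batch be $\mathcal{B}_t$. At each iteration $t$, the update is
$$w_r^{(t+1)} = w_r^{(t)} - \eta\,\frac{1}{B}\sum_{s \in \mathcal{B}_t} \frac{\partial L(w^{(t)}, x_s, y_s)}{\partial w_r}, \qquad \forall r \in [m],$$
where, with $o_i$ the softmax output,
$$\frac{\partial L(w, x, y)}{\partial w_r} = \Big(\sum_{i=1}^{k} a_{i,r}\big(o_i(x, w) - \mathbb{1}_{i = y}\big)\Big)\,\mathbb{1}_{\langle w_r, x\rangle \ge 0}\, x. \qquad (1)$$
Strictly speaking, $L$ does not have a gradient everywhere due to the non-smoothness of ReLU. One can view $\frac{\partial L(w, x, y)}{\partial w_r}$ as a convenient notation for the right-hand side of (1).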
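A minimal sketch of this update in plain Python, under assumptions of ours: the output weights `a` stay fixed, $w_r^{(0)} \sim \mathcal{N}(0, I/m)$ and $a_{i,r} \sim \mathcal{N}(0,1)$ serve as an illustrative initialization, the ReLU derivative at 0 is taken as 1 to match the indicator $\mathbb{1}_{\langle w_r, x\rangle \ge 0}$, and for brevity the demo draws a fresh random minibatch at every step instead of partitioning a fixed training set into $T$ disjoint batches. The toy data and the learning rate $0.5/m$ are arbitrary choices.

```python
import math
import random
from typing import List, Sequence

Vector = List[float]


def pre_activations(W: List[Vector], x: Sequence[float]) -> Vector:
    return [sum(wj * xj for wj, xj in zip(w_r, x)) for w_r in W]


def softmax(scores: Vector) -> Vector:
    top = max(scores)  # shift by the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def outputs(W, a, x) -> Vector:
    """Softmax probabilities o_i(x, w) of the two-layer ReLU network."""
    z = pre_activations(W, x)
    scores = [sum(a_i[r] * max(z[r], 0.0) for r in range(len(W))) for a_i in a]
    return softmax(scores)


def loss(W, a, data) -> float:
    """Average cross entropy -log o_y(x, w) over the examples in data."""
    return -sum(math.log(outputs(W, a, x)[y]) for x, y in data) / len(data)


def grad_point(W, a, x, y) -> List[Vector]:
    """Per-example gradient: (sum_i a_{i,r}(o_i - 1[i = y])) * 1[<w_r, x> >= 0] * x."""
    o = outputs(W, a, x)
    z = pre_activations(W, x)
    grads = []
    for r in range(len(W)):
        coef = sum(a[i][r] * (o[i] - (1.0 if i == y else 0.0)) for i in range(len(a)))
        grads.append([coef * xj if z[r] >= 0.0 else 0.0 for xj in x])
    return grads


def sgd_step(W, a, batch, lr) -> List[Vector]:
    """w_r <- w_r - lr * (1/B) * sum of per-example gradients; a is not updated."""
    new_W = [w_r[:] for w_r in W]
    for x, y in batch:
        for r, g_r in enumerate(grad_point(W, a, x, y)):
            for j, g in enumerate(g_r):
                new_W[r][j] -= lr * g / len(batch)
    return new_W


rng = random.Random(0)
m, d, k = 50, 2, 2
W = [[rng.gauss(0.0, 1.0 / math.sqrt(m)) for _ in range(d)] for _ in range(m)]
a = [[rng.gauss(0.0, 1.0) for _ in range(m)] for _ in range(k)]
data = [((1.0, 0.0), 0), ((0.8, 0.6), 0), ((-1.0, 0.0), 1), ((-0.8, -0.6), 1)]

before = loss(W, a, data)
for t in range(20):
    batch = rng.sample(data, 2)  # batch size B = 2
    W = sgd_step(W, a, batch, lr=0.5 / m)  # learning rate shrinking like 1/m
print(before, loss(W, a, data))
```

Scaling the learning rate as $1/m$ mirrors the choice in the main result below; with a fixed rate, a wider network would move its output faster per step.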
Main Result
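Suppose the assumptions (A1)(A2)(A3) are satisfied. Then for every $\epsilon > 0$, there is $M = \mathrm{poly}(k, l, 1/\delta, 1/\epsilon)$ such that for every $m \ge M$, after doing a minibatch SGD with batch size $B = \mathrm{poly}(k, l, 1/\delta, 1/\epsilon, \log m)$ and learning rate $\eta = \frac{1}{m \cdot \mathrm{poly}(k, l, 1/\delta, 1/\epsilon, \log m)}$ for $T = \mathrm{poly}(k, l, 1/\delta, 1/\epsilon, \log m)$ iterations, with high probability:
$$\Pr_{(x, y) \sim \mathcal{D}}\Big[\forall j \in [k], j \neq y:\ f_y(x, w^{(T)}) > f_j(x, w^{(T)})\Big] \ge 1 - \epsilon.$$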
Our theorem implies that if the data satisfies our assumptions and we parametrize the network properly, then we only need a number of samples polynomial in $k$, $l$, $1/\delta$ and $1/\epsilon$ to achieve a good prediction error. This error is measured directly on the true distribution $\mathcal{D}$, not merely on the input data used to train the network. Our result is also dimension free: there is no dependency on the underlying dimension $d$ of the data, and the complexity is fully captured by $k$, $l$ and $1/\delta$. Moreover, no matter how much the network is overparameterized, it will only increase the total number of iterations by factors of $\mathrm{poly}(\log m)$. So we can overparameterize by a sub-exponential amount without significantly increasing the complexity.
Furthermore, we can always treat each input example as an individual distribution, so $\lambda$ is always zero. In this case, if we use batch size $B$ for $T$ iterations, we would have . Then our theorem indicates that as long as , where is the minimal distance between examples, we can actually fit arbitrary labels of the input data. However, since the total number of iterations only depends on , when but the input data is actually structured (with small $l$ and large $\delta$), SGD can actually achieve a small generalization error, even when the network has enough capacity to fit arbitrary labels of the training examples (which can also be done by SGD). Thus, we prove that SGD has a strong inductive bias on structured data: instead of finding a bad global optimum that can fit arbitrary labels, it actually finds those with good generalization guarantees. This gives a more thorough explanation of the empirical observations in prior work.
Intuition and Proof Sketch for A Simplified Case
To train a neural network with ReLU activations, two questions need to be addressed:
Why can SGD optimize the training loss, or even find a critical point? Since the underlying network is highly non-smooth, existing theorems do not give any finite convergence rate of SGD for training neural networks with ReLU activations.
Why can the trained network generalize, even when its capacity is large enough to fit random labels of the input data? This is known as the inductive bias of SGD.
This work takes a step towards answering these two questions. We show that when the network is overparameterized, it becomes more “pseudo smooth”, which makes it easier for SGD to minimize the training loss, and furthermore, this does not hurt the generalization error. Our proof is based on the following important observation:
The more we overparameterize the network, the less likely the activation pattern for one neuron and one data point will change in a fixed number of iterations.
This observation allows us to couple the gradient of the true neural network with a “pseudo gradient” where the activation pattern for each data point and each neuron is fixed. That is, when computing the “pseudo gradient”, for a fixed neuron $r$ and data point $x$, whether the $r$-th hidden node is activated on $x$ will always be the same for different iterations $t$. (But for a fixed $t$, the activation can differ across neurons $r$ or data points $x$.) We are able to prove that unless the generalization error is small, the “pseudo gradient” will always be large. Moreover, we show that the network is actually smooth, so SGD can minimize the loss.
We then show that when the number of hidden neurons increases, with a properly decreasing learning rate, the total number of iterations it takes to minimize the loss stays roughly the same. However, the total number of iterations for which we can couple the true gradient with the pseudo one increases. Thus, there is a polynomially large $m$ such that we can couple these two gradients until the network reaches a small generalization error.
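The scaling behind this observation can be made concrete with a back-of-the-envelope computation. The assumptions are ours rather than the paper's exact setting: a unit-norm input $x$, initialization $w_r^{(0)} \sim \mathcal{N}(0, I/m)$ so that $\langle w_r^{(0)}, x\rangle \sim \mathcal{N}(0, 1/m)$, and a total movement $\|w_r - w_r^{(0)}\|_2 \le c/m$ for every hidden weight (a learning rate of order $1/m$ over a fixed number of steps with gradients of order one). A neuron's activation on $x$ can only flip if $|\langle w_r^{(0)}, x\rangle| \le c/m$, and the probability of that shrinks like $1/\sqrt{m}$:

```python
import math


def at_risk_fraction(m: int, c: float) -> float:
    """P[|g| <= c/m] for g ~ N(0, 1/m): the chance that a neuron whose weight moves
    by at most c/m (in l2 norm) could change its activation on a unit-norm x."""
    sigma = 1.0 / math.sqrt(m)
    tau = c / m
    return math.erf(tau / (sigma * math.sqrt(2.0)))


def anti_concentration_bound(m: int, c: float) -> float:
    """Upper bound 2 * tau / (sqrt(2 pi) * sigma): interval length times peak density."""
    sigma = 1.0 / math.sqrt(m)
    tau = c / m
    return 2.0 * tau / (math.sqrt(2.0 * math.pi) * sigma)


for m in (100, 1_000, 10_000, 100_000):
    print(m, round(at_risk_fraction(m, c=1.0), 5), round(anti_concentration_bound(m, c=1.0), 5))
```

Each hundredfold increase in $m$ cuts the at-risk fraction by about a factor of ten, which is why a wider network keeps its activation pattern for longer at a fixed number of steps.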
Here we illustrate the proof sketch for a simplified case and Appendix A provides the proof. The proof for the general case is provided in Appendix B. In the simplified case, we further assume:
(No variance) Each $\mathcal{D}_{a,b}$ is a single data point $(x_{a,b}, a)$, and we are doing full batch gradient descent as opposed to minibatch SGD.
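Then we overload the loss notation as $L(w) = \sum_{a \in [k], b \in [l]} p_{a,b}\, L(w, x_{a,b}, a)$, where $p_{a,b}$ is the probability of component $\mathcal{D}_{a,b}$, and the gradient is
$$\frac{\partial L(w)}{\partial w_r} = \sum_{a \in [k], b \in [l]} p_{a,b} \Big(\sum_{i=1}^{k} a_{i,r}\big(o_i(x_{a,b}, w) - \mathbb{1}_{i=a}\big)\Big)\,\mathbb{1}_{\langle w_r, x_{a,b}\rangle \ge 0}\, x_{a,b}.$$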
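Following the intuition above, we define the pseudo gradient as
$$\frac{\tilde{\partial} L(w)}{\partial w_r} = \sum_{a \in [k], b \in [l]} p_{a,b} \Big(\sum_{i=1}^{k} a_{i,r}\big(o_i(x_{a,b}, w) - \mathbb{1}_{i=a}\big)\Big)\,\mathbb{1}_{\langle w^{(0)}_r, x_{a,b}\rangle \ge 0}\, x_{a,b},$$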
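where it uses $\mathbb{1}_{\langle w^{(0)}_r, x_{a,b}\rangle \ge 0}$ instead of $\mathbb{1}_{\langle w_r, x_{a,b}\rangle \ge 0}$ as in the true gradient. That is, the activation pattern is set to be that at initialization. Intuitively, the pseudo gradient is similar to (but not exactly the same as) the gradient of a pseudo network $g$, defined as $g_i(x, w) := \sum_{r=1}^{m} a_{i,r}\,\langle w_r, x\rangle\,\mathbb{1}_{\langle w^{(0)}_r, x\rangle \ge 0}$. Coupling the gradients is then similar to coupling the networks $f$ and $g$.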
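The pseudo network can be written out directly. In this sketch (function names and the tiny example are ours), `pseudo_network` freezes each hidden node's activation at its value under `W0`. The two networks coincide at initialization, and they can only differ through hidden nodes whose activation on $x$ has flipped:

```python
from typing import List, Sequence

Vector = List[float]


def pre_activations(W: List[Vector], x: Sequence[float]) -> Vector:
    return [sum(wj * xj for wj, xj in zip(w_r, x)) for w_r in W]


def network(W, a, x) -> Vector:
    """f_i(x) = sum_r a_{i,r} ReLU(<w_r, x>)."""
    z = pre_activations(W, x)
    return [sum(a_i[r] * max(z[r], 0.0) for r in range(len(W))) for a_i in a]


def pseudo_network(W, W0, a, x) -> Vector:
    """g_i(x) = sum_r a_{i,r} <w_r, x> 1[<w_r^(0), x> >= 0]: activation frozen at W0."""
    z, z0 = pre_activations(W, x), pre_activations(W0, x)
    return [sum(a_i[r] * z[r] for r in range(len(W)) if z0[r] >= 0.0) for a_i in a]


def flipped(W, W0, x) -> List[int]:
    """Hidden nodes whose activation on x differs from the one at initialization."""
    z, z0 = pre_activations(W, x), pre_activations(W0, x)
    return [r for r in range(len(W)) if (z[r] >= 0.0) != (z0[r] >= 0.0)]


W0 = [[1.0, 0.0], [0.0, 1.0]]
W = [[-1.0, 0.0], [0.0, 2.0]]  # node 0 flips on x below, node 1 does not
a = [[2.0, 3.0]]
x = [1.0, 0.5]
print(network(W, a, x), pseudo_network(W, W0, a, x), flipped(W, W0, x))
```

Here $f = 3$ and $g = 1$; the gap of 2 comes entirely from node 0, whose pre-activation turned negative.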
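For simplicity, let $v_{a,a,b} := 1 - o_a(x_{a,b}, w)$ and, when $i \neq a$, $v_{i,a,b} := o_i(x_{a,b}, w)$. Roughly, if $v_{a,a,b}$ is small, then $f_a(x_{a,b})$ is relatively larger compared to the other $f_i(x_{a,b})$, so the classification error is small.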
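These quantities are simple functions of the softmax outputs, and they make the link to classification error precise: since the outputs sum to one, $\sum_{i \neq a} v_{i,a,b} = v_{a,a,b}$, and if $v_{a,a,b} < 1/2$ then $o_a > 1/2 > o_i$ for every $i \neq a$, so $x_{a,b}$ is classified correctly. A small check, with names and scores chosen for illustration:

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    top = max(scores)  # shift by the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def v_values(scores: List[float], label: int) -> List[float]:
    """v_label = 1 - o_label and v_i = o_i for i != label (o = softmax of the scores)."""
    o = softmax(scores)
    return [1.0 - o[i] if i == label else o[i] for i in range(len(o))]


scores, label = [3.0, 0.0, 0.0], 0
v = v_values(scores, label)
# The off-label entries sum to v[label], because the softmax outputs sum to 1.
print(v, sum(v[i] for i in range(len(v)) if i != label))
# v[label] < 1/2 forces o_label > 1/2 > o_i, so the largest score is the label's.
print(v[label] < 0.5, max(range(len(scores)), key=lambda i: scores[i]) == label)
```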
We prove the following two main lemmas. The first says that at each iteration, the total number of hidden units whose gradient can be coupled with the pseudo one is quite large.
The second lemma says that the pseudo gradient is large unless the error is small.
Discussion of Insights from the Analysis
Our analysis, though for learning two-layer networks on well structured data, also sheds some light upon learning neural networks in more general settings.
Generalization. Several lines of recent work explain the generalization phenomenon of overparameterized networks by the low complexity of the learned networks, from the points of view of spectrally-normalized margins, compression, and PAC-Bayes.
Our analysis partially explains how SGD (with proper random initialization) on structured data leads to low complexity from the compression and PAC-Bayes points of view. We have shown that in a neighborhood of the random initialization, w.h.p. the gradients are similar to those of another benign learning process, and thus SGD can reduce the error and reach a good solution while still in the neighborhood. The closeness to the initialization then means that the weights (or more precisely, the difference between the learned weights and the initialization) can be easily compressed. In fact, related empirical observations have been made and connected to generalization in prior work. Furthermore, prior work explicitly points out such a compression using a helper string (corresponding to the initialization in our setting). It has also been pointed out that the compression view can be regarded as a more explicit form of the PAC-Bayes view, and thus our intuition also applies to the latter.
The existence of a solution with a small generalization error near the initialization is itself not obvious. Intuitively, on structured data, the updates are structured signals spread out across the weights of the hidden neurons. Then, for prediction, the randomly initialized part of the weights has strong cancellation, while the structured signal part of the weights collectively affects the output. Therefore, the latter can be much smaller than the former while the network still gives accurate predictions. In other words, with high probability there can be a solution not far from the initialization.
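A toy numerical illustration of this cancellation, under simplifications of our own: the ReLU is ignored, output weights are $a_r \sim \mathcal{N}(0,1)$, the random part of each hidden unit's input is $u_r \sim \mathcal{N}(0,1)$, and the structured update shifts each input by $\varepsilon\,\mathrm{sign}(a_r)$. The random part contributes $\sum_r a_r u_r$, of order $\sqrt{m}$, while the structured part contributes $\varepsilon\sum_r |a_r| \approx \varepsilon m\sqrt{2/\pi}$, so it dominates once $\varepsilon \gg 1/\sqrt{m}$, even though every per-unit shift is far smaller than the per-unit random value:

```python
import math
import random


def output_parts(m: int, eps: float, seed: int = 0):
    """Split the linearized output sum_r a_r * (u_r + eps * sign(a_r)) into its
    random part sum_r a_r * u_r and its structured part eps * sum_r |a_r|."""
    rng = random.Random(seed)
    random_part = 0.0
    structured_part = 0.0
    for _ in range(m):
        a_r = rng.gauss(0.0, 1.0)  # output weight of hidden unit r
        u_r = rng.gauss(0.0, 1.0)  # random (initialization) part of its input
        random_part += a_r * u_r
        structured_part += eps * abs(a_r)  # a_r * eps * sign(a_r) = eps * |a_r|
    return random_part, structured_part


for m in (100, 10_000, 40_000):
    noise, signal = output_parts(m, eps=0.05)
    print(m, round(noise, 1), round(signal, 1), round(0.05 * m * math.sqrt(2 / math.pi), 1))
```

With $\varepsilon = 0.05$ the random part typically wins at $m = 100$, while at $m = 40{,}000$ the structured part is several times larger.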
Our analysis also provides some insight into the low rank of the weights. More precisely, when the data are well clustered around a few patterns, the accumulated updates (the difference between the learned weights and the initialization) should be approximately low rank, which can be seen by checking the SGD updates. However, when the difference is small compared to the initialization, the spectrum of the final weight matrix is dominated by that of the initialization and thus tends to be close to that of a random matrix. Again, such observations and intuitions have been made in the literature and connected to compression and generalization.
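In the no-variance case this can be seen exactly: each gradient with respect to $w_r$ is a scalar multiple of an input, so every row of the accumulated update lies in the span of the distinct inputs, and its rank is at most their number. The sketch below checks this in a toy setup of our own: random centers, random per-neuron coefficients standing in for the scalar factors in the gradient, and a basic Gaussian-elimination rank routine.

```python
import random
from typing import List

Matrix = List[List[float]]


def matrix_rank(M: Matrix, tol: float = 1e-9) -> int:
    """Numerical rank via Gaussian elimination with partial pivoting."""
    A = [row[:] for row in M]
    rows, cols = len(A), len(A[0])
    scale = max(abs(v) for row in A for v in row) or 1.0
    rank = 0
    for c in range(cols):
        if rank == rows:
            break
        pivot = max(range(rank, rows), key=lambda i: abs(A[i][c]))
        if abs(A[pivot][c]) <= tol * scale:
            continue  # no usable pivot in this column
        A[rank], A[pivot] = A[pivot], A[rank]
        for i in range(rank + 1, rows):
            f = A[i][c] / A[rank][c]
            for j in range(c, cols):
                A[i][j] -= f * A[rank][j]
        rank += 1
    return rank


def accumulated_update(m: int, centers: List[List[float]], steps: int, seed: int = 0) -> Matrix:
    """Sum over steps of outer products coef (length m) x center (length d)."""
    rng = random.Random(seed)
    d = len(centers[0])
    delta = [[0.0] * d for _ in range(m)]
    for t in range(steps):
        x = centers[t % len(centers)]  # cycle through the distinct data points
        coef = [rng.gauss(0.0, 1.0) for _ in range(m)]  # stand-in per-neuron factors
        for r in range(m):
            for j in range(d):
                delta[r][j] += coef[r] * x[j]
    return delta


rng = random.Random(1)
centers = [[rng.gauss(0.0, 1.0) for _ in range(10)] for _ in range(3)]
delta = accumulated_update(m=20, centers=centers, steps=30)
print(matrix_rank(delta))  # at most 3, the number of distinct inputs
```

With noisy inputs the rows no longer lie exactly in a low-dimensional span, which is why the text speaks of approximately low rank.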
Effect of random initialization. Our analysis also shows how proper random initialization helps optimization and, consequently, generalization. Essentially, it guarantees that w.h.p., for weights close to the initialization, many hidden ReLU units will have the same activation patterns (i.e., activated or not) as at initialization, which means the gradients in the neighborhood look like those when the hidden units have fixed activation patterns. This allows SGD to make progress when the loss is large, and eventually to learn a good solution. We also note that it is essential to carefully set the scale of the initialization, which is an extensively studied topic. Our initialization has a scale related to the number of hidden units, which is particularly useful when the network size varies, and thus can be of interest in such practical settings.
Experiments
This section aims at verifying some key implications: (1) the activation patterns of the hidden units couple with those at initialization; (2) the distance of the learned solution from the initialization is relatively small compared to the size of the initialization; (3) the accumulated updates (i.e., the difference between the learned weight matrix and the initialization) have approximately low rank. These are indeed supported by the results on the synthetic and the MNIST data. Additional experiments are presented in Appendix D.
The network structure and the learning process follow those in Section 3; the number of hidden units $m$ varies in the experiments, and the weights are initialized with . On the synthetic data, SGD is run for steps with batch size and learning rate . On MNIST, SGD is run for steps with batch size and learning rate .
Besides the test accuracy, we report three quantities corresponding to the three implications to be verified. First, for coupling, we compute the fraction of hidden units whose activation pattern changed compared to that at initialization. Here, the activation pattern is defined as $1$ if the input to the ReLU is positive and $0$ otherwise. Second, for distance, we compute the relative ratio $\|W^{(t)} - W^{(0)}\| / \|W^{(0)}\|$, where $W^{(t)}$ is the weight matrix at time $t$. Finally, for the rank of the accumulated updates, we plot the singular values of $W^{(T)} - W^{(0)}$, where $T$ is the final step. All experiments are repeated 5 times, and the mean and standard deviation are reported.
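The first two diagnostics can be computed as below. This is a sketch with assumptions of ours: the pattern change is counted over (hidden unit, example) pairs on a given set of inputs `X`, the norm in the relative distance is taken to be the Frobenius norm, and the singular values of $W^{(T)} - W^{(0)}$ (which would call for a linear algebra library) are left out.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def activation_pattern(W: Matrix, X: Sequence[Sequence[float]]) -> List[List[int]]:
    """1 if the input to the ReLU of unit r on example x is positive, else 0."""
    return [[1 if sum(wj * xj for wj, xj in zip(w_r, x)) > 0.0 else 0 for x in X] for w_r in W]


def pattern_change_fraction(W0: Matrix, Wt: Matrix, X) -> float:
    """Fraction of (hidden unit, example) pairs whose activation differs from init."""
    p0, pt = activation_pattern(W0, X), activation_pattern(Wt, X)
    changed = sum(b0 != bt for row0, rowt in zip(p0, pt) for b0, bt in zip(row0, rowt))
    return changed / (len(W0) * len(X))


def relative_distance(W0: Matrix, Wt: Matrix) -> float:
    """||W_t - W_0||_F / ||W_0||_F with Frobenius norms."""
    diff = math.sqrt(sum((u - v) ** 2 for r0, rt in zip(W0, Wt) for u, v in zip(rt, r0)))
    return diff / math.sqrt(sum(v * v for r0 in W0 for v in r0))


W0 = [[1.0, 0.0], [0.0, 1.0]]
Wt = [[1.0, 0.0], [0.0, -1.0]]
X = [[1.0, 1.0]]
print(pattern_change_fraction(W0, Wt, X), relative_distance(W0, Wt))
```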
Results. Figure 1 shows the results on the synthetic data. The test accuracy quickly converges to , which is even more pronounced with a larger number of hidden units, showing that overparameterization helps optimization and generalization. Recall that our analysis shows that, for a learning rate decreasing linearly with the number of hidden nodes $m$, the number of iterations needed to achieve a desired accuracy should be roughly the same, which is also verified here. The activation pattern difference ratio is less than , indicating a strong coupling. The relative distance is less than , so the final solution is indeed close to the initialization. Finally, the top 20 singular values of the accumulated updates are much larger than the rest, while the spectrum of the weight matrix does not have such structure, which is also consistent with our analysis.
Figure 2 shows the results on MNIST. The observations are in general similar to those on the synthetic data (though less pronounced), and the observed trends become more evident with more overparameterization. Additional results (e.g., varying the variance of the synthetic data), provided in the appendix, also support our theory.
Conclusion
Acknowledgements
We would like to thank the anonymous reviewers of NeurIPS’18 and Jason Lee for helpful comments. This work was supported in part by FA9550-18-1-0166, NSF grants CCF-1527371, DMS-1317308, Simons Investigator Award, Simons Collaboration Grant, and ONR-N00014-16-1-2329. Yingyu Liang would also like to acknowledge that support for this research was provided by the Office of the Vice Chancellor for Research and Graduate Education at the University of Wisconsin Madison with funding from the Wisconsin Alumni Research Foundation.
References
Appendix A Proofs for the Simplified Case
In the simplified case, we make the following simplifying assumption:
(No variance) Each $\mathcal{D}_{a,b}$ is a single data point $(x_{a,b}, a)$, and we are doing full batch gradient descent as opposed to minibatch SGD.
Recall that the loss is then The gradient descent update on is given by
where . The pseudo gradient is defined as
When clear from the context, we write as . Then we can simplify the above expression as:
Furthermore, indicates the “classification error”. The smaller is, the smaller the classification error is.
In the following subsections, we first show that the gradient is coupled with the pseudo gradient, then show that if the classification error is large then the pseudo gradient is large, and finally prove the convergence.
which implies that .
Now, for every , we consider the set such that
For every and every , we know that for every :
Now, we need to bound the size of . Since , by a standard property of the Gaussian distribution we directly have that for . ∎
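For reference, the standard Gaussian anti-concentration fact behind this step has a one-line justification. It is stated here for a generic $g \sim \mathcal{N}(0, \sigma^2)$ and threshold $\tau > 0$; how $\sigma$ and $\tau$ are instantiated in the proof above is not recoverable from the text. The density of $g$ never exceeds its value at $0$, so

$$\Pr\big[|g| \le \tau\big] = \int_{-\tau}^{\tau} \frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-z^2/(2\sigma^2)}\,dz \;\le\; \int_{-\tau}^{\tau} \frac{1}{\sqrt{2\pi}\,\sigma}\,dz \;=\; \sqrt{\frac{2}{\pi}}\,\frac{\tau}{\sigma}.$$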
A.2 Error Large ⟹ Gradient Large
The pseudo gradient can be rewritten as the following summation:
We would like to show that if some is large, a good fraction of will have large pseudo gradient. Now, the first step is to show that for any fixed (that does not depend on the random initialization ), with good probability (over the random choice of ) we have that is large; see Lemma A.2. Then we will take a union bound over an epsilon net on to show that for every (that can depend on ), at least a good fraction of of is large; see Lemma A.3.
For any possible fixed such that , we have:
where . For every , consider the event defined as
for all : .
By the definition of initialization , we know that:
By assumption we know that for every :
Thus if we pick , taking a union bound we know that
Moreover, since and is independent of , we know that .
The following proof is conditioned on this event , and then treats as fixed and lets be the only random variable. In this way, we will have: for every such that and for every , since and ,
which is a linear function of . With this information, we can rewrite as:
where and is some linear function in . Thus, we know that
is a convex function with . Then applying Lemma A.5 gives
Since for , conditional on the density , which implies that
Now we can look at . By the random initialization of , and since by our assumption are not functions of , a standard tail bound of Gaussian random variables shows that for every fixed and every :
Taking and putting it together with inequality (2) with completes the proof. ∎
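For reference, the standard Gaussian tail bound invoked in this proof follows from the moment generating function. It is stated here for a generic $g \sim \mathcal{N}(0, \sigma^2)$ and $t > 0$; the exact instantiation used above is not recoverable from the text. For any $s > 0$, Markov's inequality applied to $e^{sg}$ gives

$$\Pr[g \ge t] \le e^{-st}\,\mathbb{E}\big[e^{sg}\big] = e^{-st + s^2\sigma^2/2},$$

and choosing $s = t/\sigma^2$ yields $\Pr[g \ge t] \le e^{-t^2/(2\sigma^2)}$. By symmetry, $\Pr[|g| \ge t] \le 2e^{-t^2/(2\sigma^2)}$.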
Now, we can take an epsilon net and switch the order of the quantifiers in Lemma A.2 as shown in the following lemma.
This lemma implies that if the classification error is large, then many will have a large gradient.
We first consider fixed . First of all, using the randomness of we know that with probability at least ,
A.3 Convergence
Having the lemmas, we can now prove the convergence:
Thus, this lemma shows that eventually will be small. However, we do not give any bound on how small the step size needs to be, and how a small leads to a small classification error. These are addressed in the proof of the general case in the next section, but here we are content with an eventually small for a sufficiently small .
By Lemma A.3, we know that there are at least fraction of such that
Now combine with Lemma A.1. If we pick , then at least fraction of have
Thus, for a sufficiently small , we have:
A.4 Technical Lemmas
The following lemma about non-smooth convex functions vs. linear functions is needed in the proof.
We have for every , for every linear function :
Without loss of generality (up to subtracting a linear function on ), let us assume that and .
Moreover, denote , we know that at least one of the following is true:
,
.
We shall give the proof for the case . The other case follows from replacing with .
Let us then consider the following two cases.
, in this case, by convexity of we have that . Thus,
, in this case, intersects with at a point . Consider two cases:
, then we have: . Thus,
, then we have:
This completes the proof of the first claim. For the second claim, in case 1, we know that every would have . In case 2(a), every satisfies this claim. In case 2(b) we can take every . This completes the proof. ∎
Appendix B Proofs for the General Case
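We consider a minibatch SGD of batch size $B$, number of iterations $T$ and learning rate $\eta$ as the following process: Randomly divide the total training examples into $T$ batches, each of size $B$. Let the indices of the examples in the $t$-th batch be $\mathcal{B}_t$. The update rule is:
$$w_r^{(t+1)} = w_r^{(t)} - \eta\,\frac{1}{B}\sum_{s \in \mathcal{B}_t} \frac{\partial L(w^{(t)}, x_s, y_s)}{\partial w_r}, \qquad \forall r \in [m].$$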
The pseudo gradient on a point is defined as:
In the following subsections, we first show that the gradient is coupled with the pseudo gradient, then show that if the classification error is large then the pseudo gradient is large, and finally prove the convergence.
We have the following lemma for coupling, analogous to Lemma A.1.
B.2 Expected Error Large ⟹ Gradient Large
Following the same structure as before, we can write the expected pseudo gradient as:
where is defined as:
When clear from the context, we use for short. When the choice of is not important, we will also use .
The proof is very similar to the proof of Lemma A.2.
where . For every , consider the event defined as
By the definition of initialization , we know that:
By assumption we can simply calculate that for every : . This implies that
With , we know that . The following proof is conditioned on this event , and then treats as fixed and lets be the only random variable. In this way, for every such that and for every :
is a linear function for with probability .
With this information, we can rewrite as:
where is a convex function with .
We will have . Now, applying Lemma B.5, the conclusion follows from the same proof as that of Lemma A.2. ∎
Now we can take the union bound to switch the order of quantifiers. However, we cannot do a naive union bound since there are infinitely many . Instead, we will use a sampling trick to prove the following Lemma:
This lemma implies that if the classification error is large, then many ’s have a large pseudo gradient.
We first pick samples , with many from distribution , and with the corresponding value function . Since each , we know that w.h.p., for every :
B.3 Convergence
We now show the following important lemma about convergence.
We know that for at least fraction of such that
Now we consider the non-smooth gradient descent. Consider a newly sampled point , and let us denote
Let denote the event that (3) holds.
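Theorem 4.1. Suppose the assumptions (A1)(A2)(A3) are satisfied. Then for every $\epsilon > 0$, there is $M = \mathrm{poly}(k, l, 1/\delta, 1/\epsilon)$ such that for every $m \ge M$, after doing a minibatch SGD with batch size $B = \mathrm{poly}(k, l, 1/\delta, 1/\epsilon, \log m)$ and learning rate $\eta = \frac{1}{m \cdot \mathrm{poly}(k, l, 1/\delta, 1/\epsilon, \log m)}$ for $T = \mathrm{poly}(k, l, 1/\delta, 1/\epsilon, \log m)$ iterations, with high probability:
$$\Pr_{(x, y) \sim \mathcal{D}}\Big[\forall j \in [k], j \neq y:\ f_y(x, w^{(T)}) > f_j(x, w^{(T)})\Big] \ge 1 - \epsilon.$$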
Let denote .
Then for every , if , then
. For all such , even if all the predictions are wrong, it will only increase the total error by , so the other half of the error must come from other .
Therefore, to prove the theorem, it suffices to show that will be smaller than after a proper number of iterations. Suppose , then by Lemma B.4, as long as
B.4 Technical Lemmas
The following lemma about non-smooth convex functions vs. linear functions is needed in the proof.
We have that for every , for every convex function , let , then
Without loss of generality, we can assume that either and , or and . The lemma can be proved using the same argument as in Lemma A.5. ∎
We also need the following lemma regarding the gradient descent on non-smooth function.
The proof of this lemma follows directly from
where the last line follows from the chain rule and Lipschitz smoothness, and the second-to-last line follows from
Appendix C Illustration of the Separability Assumption
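(Separability) There exists $\delta > 0$ such that for every $i_1 \neq i_2 \in [k]$ and every $j_1, j_2 \in [l]$, $\mathrm{dist}\big(\mathrm{supp}(\mathcal{D}_{i_1,j_1}), \mathrm{supp}(\mathcal{D}_{i_2,j_2})\big) \geq \delta$,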
Appendix D Additional Experimental Results
Here we provide some additional experimental results.
Recall that our analysis shows that, for a learning rate decreasing with the number of hidden nodes $m$, the number of iterations needed to achieve a desired accuracy remains roughly the same. A more direct way to check this is to plot the number of steps needed to achieve the accuracy for different $m$. As shown in Figure 4, the number of steps roughly matches what our theory predicts.
Furthermore, Figure 5 shows the relative distances when achieving the desired accuracies. It is observed that the distances scale roughly as . In particular, they closely match on the synthetic data and on MNIST (the red lines in the figures), where is the number of hidden nodes. Explanations are left for future work.
D.2 Synthetic Data with Larger Variances
Figure 6 shows that the test accuracy decreases with increasing variance, and it takes longer to get a good solution. On the other hand, an increasing variance does not change the trends for activation patterns, distance, and the rank of the weight matrix. This is possibly because the signal in the updates remains small with increasing variance, while the noise in the updates acts similarly to the randomness in the weights.