Neural Networks can Learn Representations with Gradient Descent
Alex Damian, Jason D. Lee, Mahdi Soltanolkotabi
Introduction
Crucial to the practical success of deep learning is the ability of gradient-based algorithms to learn good feature representations from the training data and learn simple functions on top of these representations. Despite significant progress towards a theoretical foundation for neural networks, a robust understanding of this unique representation learning capability of gradient descent methods has remained elusive. A major challenge is that, due to the highly nonconvex loss landscape, it is difficult to establish convergence to a global optimum that achieves near zero training loss. Furthermore, due to the overparameterized nature of modern neural nets (containing many more parameters than training data), the training landscape has many global optima. In fact, there are many global optima with poor generalization performance. This paper thus focuses on answering this intriguing question:
How do gradient-based methods learn feature representations and why do these representations allow for efficient generalization and transfer learning?
The most prominent contemporary approach to understanding neural networks is the linearization or neural tangent kernel (NTK) technique. The premise of the linearization method is that the dynamics of gradient descent are well-approximated by gradient descent on a linear regression instance with fixed feature representation. Using this linearization technique, it is possible to prove convergence to a zero training loss point. However, this technique often requires unrealistic hyper-parameter choices (e.g. small learning rate, large initialization, or wide networks) that do not allow the features to evolve across the iterations, and thus the generalization error with this technique cannot be better than that of a kernel method. Indeed, precise lower bounds show that the NTK solutions do not generalize better than the polynomial kernel. As a result, this regime of training is also sometimes referred to as the lazy regime. See Section 4 for a more in-depth discussion of this literature and other related work. In practice, neural networks far outperform their corresponding induced kernels. Therefore, understanding the representation learning of neural networks beyond the lazy regime is of fundamental importance.
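For concreteness, the linearization referred to above is usually written as a first-order Taylor expansion of the network in its parameters. The symbols $\theta$ (all trainable parameters) and $\theta_0$ (their value at initialization) are notation assumed here for illustration and are not taken from the text.

```latex
f(x;\theta) \;\approx\; f(x;\theta_0)
  + \big\langle \nabla_\theta f(x;\theta_0),\; \theta - \theta_0 \big\rangle
```

The right-hand side is linear in $\theta$ with the fixed feature map $x \mapsto \nabla_\theta f(x;\theta_0)$, which is why training in this regime behaves like regression with a fixed kernel.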
In this paper, we initiate the study of the representation learning of neural networks beyond this NTK/linear/lazy regime. To this end, we consider the problem of learning polynomials with low-dimensional latent representation of the form , where maps from to dimensions with with a multivariate polynomial of degree . This is a natural choice as the failure of the NTK solution is in part due to its inability to learn data-dependent feature representations that adapt to the intrinsic low latent dimensionality of the ground truth function. Existing analyses based on the NTK regime provably require samples to learn any degree polynomial, even if it only depends on a few relevant directions. In contrast, we show that gradient descent from random initialization only requires samples, breaking the sample complexity barrier dictated by NTK proof techniques. More specifically, our contributions are as follows:
Feature Learning: When the target function only depends on the projection of onto a hidden subspace , we show that gradient descent learns features that span . Leveraging these features, gradient descent can reach vanishing training loss with a very small network which guarantees good generalization performance. See Section 5.1.
Lower Bound: Finally, we show a lower bound that demonstrates that our non-degeneracy assumption (Assumption 2) is strictly necessary. Without non-degeneracy, there is a family of polynomials, each depending on a single relevant dimension (i.e. of the form ), which cannot be learned with fewer than by any gradient descent based learner.
Setup
where controls the strength of the label noise.
In order to make the problem of learning tractable, additional assumptions are necessary. The set of degree polynomials in dimensions spans a linear subspace of of dimension . Learning arbitrary degree polynomials therefore requires samples. We follow Chen and Meka and Chen et al. in assuming that the ground truth has a special low dimensional latent structure. Specifically, we assume that only depends on a small number of relevant dimensions and that the expected Hessian is non-degenerate. We show in Theorem 2 that this non-degeneracy assumption is strictly necessary to avoid sample complexity .
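To make the dimension count above concrete, the following minimal sketch counts the monomials of total degree at most p in d variables using the standard stars-and-bars formula. The function name `num_monomials` is hypothetical and chosen for illustration only.

```python
from math import comb


def num_monomials(d: int, p: int) -> int:
    """Number of monomials of total degree at most p in d variables.

    Stars-and-bars: exponent vectors (e_1, ..., e_d) with sum <= p are
    counted by the binomial coefficient C(d + p, p).
    """
    return comb(d + p, p)


# For fixed degree p the count grows polynomially in the ambient dimension d.
for d in (10, 100, 1000):
    print(d, [num_monomials(d, p) for p in (1, 2, 3, 4)])
```

For fixed p, this count behaves like d^p / p! as d grows, which is the source of the polynomial-in-d cost of learning an unstructured polynomial.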
We will call the principal subspace of . We will also denote by the orthogonal projection onto .
We will also denote the normalized condition number of by .
2.2 The Network and Loss
where denotes the width of the network. We use a symmetric initialization, so that . Explicitly, we will assume that is an even number and that
We will use the following initialization:
We note that while we focus on such symmetric initialization for clarity of exposition, our results also hold with small random initialization that is not necessarily symmetric. This holds by simple modifications in the proof accounting for the small nonzero output of the network at initialization. We will also denote the empirical and population losses by and respectively:
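The effect of a symmetric initialization can be illustrated with a short sketch. It assumes a ReLU activation, Gaussian hidden weights, random sign output weights, and zero biases, and it pairs neuron j with neuron j + m/2; the exact pairing and distributions used in the paper are not recoverable from the text here, so all of these choices (and the helper names) are hypothetical.

```python
import random


def relu(z):
    return max(z, 0.0)


def symmetric_init(d, m, rng):
    """Hypothetical symmetric initialization: the second half of the neurons
    copies the hidden weights and biases of the first half and negates the
    output weights, so the paired contributions cancel."""
    assert m % 2 == 0, "width must be even"
    half = m // 2
    W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(half)]
    a = [rng.choice([-1.0, 1.0]) for _ in range(half)]
    b = [0.0] * half
    return a + [-v for v in a], W + [row[:] for row in W], b + b[:]


def forward(a, W, b, x):
    """Two layer network: sum_j a_j * relu(<w_j, x> + b_j)."""
    return sum(aj * relu(sum(wk * xk for wk, xk in zip(wj, x)) + bj)
               for aj, wj, bj in zip(a, W, b))


rng = random.Random(0)
a, W, b = symmetric_init(d=5, m=8, rng=rng)
x = [rng.gauss(0.0, 1.0) for _ in range(5)]
out = forward(a, W, b, x)
print(f"network output at initialization: {out:.2e}")
```

Because each neuron has a twin with the same hidden weights and the opposite output weight, the output vanishes (up to floating point rounding) for every input.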
2.3 Notation
Main Results
Before we formally state our main result let us specify the exact form of gradient-based training we use in our theory.
With this algorithm in place, we are now ready to state our main result.
It is useful to note that the use of in the algorithm corresponds to the common practice of weight decay and its value is chosen in such a way that , i.e. to solve a constrained minimization problem (see Section 5.1). In practice, one simply tunes the hyperparameter in order to achieve the desired tradeoff between training and test loss.
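The link between weight decay and a norm constraint mentioned above can be stated in one line. The notation below ($\hat L$ for the empirical loss as a function of the output weights $a$, $\hat a_\lambda$ for the regularized minimizer, $B_\lambda$ for its norm) is assumed for illustration and is not the paper's.

```latex
\hat a_\lambda \in \arg\min_{a}\; \hat L(a) + \frac{\lambda}{2}\,\|a\|^2
\quad\Longrightarrow\quad
\hat a_\lambda \in \arg\min_{\|a\| \le B_\lambda} \hat L(a),
\qquad B_\lambda := \|\hat a_\lambda\| .
```

Indeed, any $a$ with $\|a\| \le B_\lambda$ and $\hat L(a) < \hat L(\hat a_\lambda)$ would have a strictly smaller regularized objective than $\hat a_\lambda$, a contradiction. Varying $\lambda$ therefore traces out solutions of the constrained problem at different radii.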
An intriguing aspect of the above result is that despite the fact that may be of arbitrarily high degree, learning requires only samples and only requires a very small network with . We note that our dependence on the latent dimension is near optimal as the minimax sample complexity even when the principal subspace is known is .
We show in Theorem 3 that by resampling the data after the first step, the sample complexity can be further reduced to , dropping a factor of from the second term. The extra factor of results from the dependence between the data used in the first and second stages and we believe that a more careful analysis could remove this additional factor.
We contrast Theorem 1 with the following lower bound for learning a function class which satisfies Assumption 1 with but does not satisfy Assumption 2.
For any , there exists a function class of polynomials of degree , each of which depends on a single relevant dimension, such that any correlational statistical query learner using queries requires a tolerance of at most
in order to output a function with loss at most .
Using the heuristic , which represents the expected scale of the concentration error, we get the immediate corollary that violating Assumption 2 allows us to construct a function class which any neural network with polynomially many parameters trained for polynomially many steps of gradient descent cannot learn without at least samples. We emphasize that this is only a heuristic argument as concentration errors are random rather than adversarial.
On the other hand, Theorem 1 shows that incorporating Assumption 2 allows gradient descent to efficiently learn polynomials of arbitrarily high degree with only samples.
The difference in sample complexity between Theorem 1 and Theorem 2 arises because, in Theorem 1, our non-degeneracy assumption (Assumption 2) allows the network to extract useful features that aid robust learning, which in turn allows learning high degree polynomials with samples. Theorem 2 shows that violating this assumption allows us to construct a function class which cannot be learned without samples, demonstrating the necessity of Assumption 2.
The fact that the network extracts useful features not only allows it to learn efficiently, but also allows for efficient transfer learning. In particular, Theorem 3 shows that we can efficiently learn any target polynomial that depends on the same relevant dimensions as with sample complexity independent of by simply truncating and retraining the head of the network:
Learning therefore only requires , which is independent of the ambient dimension . We note that this is minimax optimal for learning arbitrary degree polynomials even when the hidden subspace is known. Theorem 3 also shows that pre-training samples are sufficient for gradient descent to learn the subspace from the pre-training data.
Related work
A growing body of recent work shows the connection between gradient descent on the full network and the Neural Tangent Kernel (NTK). Using this technique one can prove concrete results about neural network training and generalization in the kernel regime. The key idea is that for a large enough initialization, it suffices to consider a linearization of the neural network around the origin. This allows connecting the analysis of neural networks with the well-studied theory of kernel methods. This is also sometimes referred to as lazy training, as with such an initialization the parameters of the neural networks stay close to the parameters at initialization and these results can only show that neural networks are as powerful as shallow learners such as kernels. There is however growing evidence that this NTK-style analysis might not be sufficient to completely explain the success of neural networks in practice. The papers provide empirical evidence that by choosing a smaller initialization the test error of the neural network decreases. A similar performance gap between the performance of the NTK and neural networks has been observed in . This NTK-style analysis however does not yield satisfactory results in the setting studied in this paper. In particular for learning the polynomials of the form we study in this paper, demonstrates that one needs at least samples in the kernel regime. In contrast, our results only require on the order of samples.
Leveraging the fact that linearized models are not feature learners, Ghorbani et al. and showed precise upper and lower bounds on the sample complexity of NTK methods. They showed that because NTK is unable to learn new features, learning any polynomial in dimension of degree requires samples, which gives no improvement over polynomial kernels. On the empirical front, the NTK linearization analysis is also lacking. Arora et al. demonstrated that the kernel predictor loses more than in test accuracy relative to a deep network trained with SGD and state-of-the-art regularization on CIFAR-10. Our work is motivated by the contrast between these negative theoretical results for linearized NTK models and the spectacular empirical performance of deep learning.
The gap between such shallow learners and the full neural network has been established in theory and observed in practice. There is an emerging literature on learning beyond the lazy/NTK regime in the small initialization setting. The papers show that for the problem of low-rank reconstruction in a non-lazy regime with small random initialization, gradient descent finds globally optimal solutions with good generalization capability. This is carried out by utilizing a spectral bias phenomenon exhibited by the early stages of gradient descent from small random initialization that puts the iterates on the trajectory towards generalizable models. For the problem of tensor decomposition it has also been shown that gradient descent with small initialization is able to leverage low-rank structure. In , it has been shown that neural networks with orthogonal weights can be learned via SGD and outperform any kernel method. One crucial element in their analysis is that the early stage of the training is connected with learning the first and second moment of the data. Higher-order approximations of the training dynamics and the Neural Tangent Hierarchy have also been recently proposed towards closing this gap. None of the above papers, however, focus on learning polynomial representations efficiently via neural networks as carried out in this paper.
Another line of work focuses on learning single activations such as the ReLU function. In this context shows that it is hard to learn a single ReLU activation via stochastic gradient descent with random features, whereas learning such activations is possible in a non-NTK regime, again highlighting this important gap. In related work where the label also only depends on a single relevant direction, the authors show that in the context of learning the parity function, gradient descent is able to efficiently learn the planted set. However, this is a result of the unbalanced data distribution which skews the gradient towards the planted set. In contrast, we consider isotropic Gaussian data so that no information can be extracted from the data distribution itself and features must be extracted from higher order correlations between the data and the labels. Chen and Meka also studied the problem of learning polynomials of few relevant dimensions. They provide an algorithm that learns polynomials of degree in dimensions that depends on hidden dimensions with samples where is an unspecified function of which is likely exponential in . However, their algorithm is not a variant of gradient descent, and requires a clever spectral initialization. On the other hand, this work focuses on the ability of gradient descent to automatically extract hidden features and learn representations from the data.
There is also a line of work which is concerned with the mean-field analysis of neural networks. The insight is that for sufficiently large width the training dynamics of the neural network can be coupled with the evolution of a probability distribution described by a PDE. These papers use a smaller initialization than in the NTK-regime and, hence, the parameters can move away from the initialization. However, these results do not provide explicit convergence rates and require an unrealistically large width of the neural network. To the best of our knowledge such an analysis technique has not been used to show efficient learning of polynomial representations using neural networks as carried out in this paper.
A concurrent line of work studied the feature learning ability of gradient descent in the mean field regime with data sampled from the boolean cube . The authors identified a necessary and sufficient condition for learning with sample complexity linear in , dubbed the merged staircase property, in the special case when the hidden weights of the two layer neural network are initialized at . However, the zero initialization hinders the feature learning ability of the network. For example, the boolean function XOR violates the merged staircase property; however, noisy XOR is known to be learnable by two layer neural networks with sample complexity linear in . In this work we study the impact that the nonzero initialization of the hidden weights has on the feature learning ability of neural networks.
Proof Sketches
Using the chain rule, we can further expand this as
With high probability over the random initialization,
Note that the remainder term, of order , contains all higher order terms in the series expansion.
However, it is also important to note that the population gradient is bounded by and we only have access to the empirical gradient . As mentioned above, extracting the necessary subspace information from to learn therefore requires samples, which is the dominant term in our final sample complexity result.
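The claim that the gradient at initialization carries information about the relevant subspace can be checked numerically on a toy problem. The sketch below is not the paper's experiment: it assumes a ReLU activation, zero biases, a network output of zero at initialization (as under a symmetric initialization), the squared loss, and the hypothetical target x_0^2 - 1, which depends only on coordinate 0 of a standard Gaussian input. All names and constants are chosen for illustration.

```python
import random


def first_step_gradients(X, y, W, a):
    """Gradient of (1/n) * sum_i (f(x_i) - y_i)^2 with respect to each hidden
    weight vector w_j, evaluated at a point where the network output is 0."""
    n, d = len(X), len(X[0])
    grads = []
    for wj, aj in zip(W, a):
        g = [0.0] * d
        for xi, yi in zip(X, y):
            # ReLU derivative: 1 if the pre-activation is positive, else 0.
            if sum(wk * xk for wk, xk in zip(wj, xi)) > 0.0:
                for k in range(d):
                    g[k] += -2.0 * yi * aj * xi[k] / n
        grads.append(g)
    return grads


rng = random.Random(0)
d, m, n = 10, 20, 2000
X = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
y = [x[0] ** 2 - 1.0 for x in X]  # toy target: depends on coordinate 0 only
W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
a = [rng.choice([-1.0, 1.0]) for _ in range(m)]

grads = first_step_gradients(X, y, W, a)
mass_on_relevant = sum(g[0] ** 2 for g in grads)
total_mass = sum(v ** 2 for g in grads for v in g)
fraction = mass_on_relevant / total_mass
print(f"fraction of squared gradient norm on the relevant direction: {fraction:.3f}")
```

In this toy setting most of the squared gradient norm falls on the single relevant coordinate even though it is one of ten. The remaining mass is finite-sample noise plus higher order terms, and the noise part shrinks as n grows.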
Once we show that the gradient at initialization contains all the relevant features, we note that after the first step of gradient descent,
After the first step, the model therefore resembles a random feature model with random features . Previous results have shown that in these linearized regimes, e.g. random feature models/NTK, learning degree polynomials requires samples and width . As our “random features” are now constrained to the hidden subspace , which has dimension , we should expect that our sample complexity improves to .
The remainder of Algorithm 1 runs ridge regression on the network head with fixed features . We can directly analyze the generalization of this algorithm using standard techniques from Rademacher complexity. In particular, a high level sketch of the remainder of the proof goes as follows:
(Section A.3): We show the equivalence between ridge regression and norm constrained linear regression implies the existence of such that the th iterate satisfies
5.2 Proof of Theorem 2
Let be a class of functions and be a data distribution such that
Then any correlational statistical query learner requires at least queries of tolerance to output a function in with loss at most .
To construct , we begin by showing that there are a large number of approximately orthogonal unit vectors in :
There exists an absolute constant such that for any , there exists a set of unit vectors such that for any such that , we have .
Therefore implies . Theorem 2 then directly follows from Lemma 2 (see Appendix D for a more detailed proof).
Experiments
In this section we present a toy example that clearly demonstrates the gap between kernel methods and gradient descent on two layer networks. For , consider the target function
which satisfies . Note that only depends on the projection of onto a single relevant direction, . We show in Section 5.1 that gradient descent is capable of isolating the subspace spanned by and then fitting a one dimensional random feature model to , and that this entire process requires samples to generalize.
On the other hand, existing works Ghorbani et al. have shown that samples are strictly necessary in order to learn in the NTK or random features regime. The theory predicts that with samples, kernel regression will return the predictor and with samples, kernel regression will return , incurring a loss of .
We empirically verify these predictions. We take and and consider the function . We use label noise and attempt to learn using Algorithm 1, a random feature model, and a linearized NTK model. All experiments are conducted on a two layer neural network with widths and . For each value of , the weight decay parameter is tuned on a holdout set of size and test accuracies are reported over a separate test set of size . Error bars reflect the mean and standard deviation over random seeds.
We note that while Algorithm 1 easily converged to vanishing excess risk, even at width , both the random features model and the neural tangent kernel model only managed to fit the quadratic term , as predicted by the theory in Ghorbani et al. .
6.2 Transfer Learning
The proof of Theorem 1 involves showing that Algorithm 1 learns features corresponding to (see Section 5.1) and the proof of Theorem 3 shows that this implies efficient transfer learning. We again verify this empirically. We consider the function:
Note that this was exactly the hard example in Theorem 2 that was unlearnable without samples by a correlational statistical query learner (and in particular, gradient-based learners).
We pretrain with samples on the from Section 6.1, then train the output layer using samples from . As in Section 6.1, we use a label noise strength of . We pick so that random feature methods or the neural tangent kernel will require at least samples to learn .
We note that in Figure 2, when , fine tuning on target samples gives trivial risk until , which is to be expected of a kernel method with no prior information. However, for pretraining samples, we can fine tune on just target samples to reach nontrivial loss and the loss decays rapidly as a function of . This experiment therefore fully supports the conclusion of Theorem 3.
Discussion and Future Work
In this work we provide a clear separation between gradient-based training and kernel methods. We show that there is a large family of degree polynomials which are efficiently learnable by gradient descent with samples, in contrast to the lower bound of for random feature/NTK analysis. The main idea driving both our sample complexity result (Theorem 1) and our transfer learning result (Theorem 3) is that gradient descent learns useful representations of the data.
One promising direction for future work is tightening the dimension dependence of our upper bound. In particular, our sample complexity is driven by the difficulty of learning from a degree Hermite polynomial. However, our lower bound for such functions (Theorem 2) only rules out learning with samples. In this situation the lower bound is tight as Chen et al. show that sparse degree polynomials can be efficiently learned with samples.
Another promising direction for future work is generalizing our result to the situation in which the hidden layer and the output layer are trained together. This introduces dependencies between the hidden and output layers which are difficult to control. However, such analysis may lead to a better understanding of learning order and inductive bias in deep learning.
Acknowledgements
AD acknowledges support from an NSF Graduate Research Fellowship. JDL and AD acknowledge support of the ARO under MURI Award W911NF-11-1-0304, the Sloan Research Fellowship, NSF CCF 2002272, NSF IIS 2107304, ONR Young Investigator Award, and NSF-CAREER under award #2144994. MS is supported by the Packard Fellowship in Science and Engineering, a Sloan Fellowship in Mathematics, an NSF-CAREER under award #1846369, DARPA Learning with Less Labels (LwLL) and FastNICS programs, and NSF-CIF awards #1813877 and #2008443.
References
Appendix A Proofs
We define for a sufficiently large constant . Throughout the appendix we will use to track failure probabilities of various lemmas and theorems.
We say that an event happens with high probability if it happens with probability at least where does not depend on .
Note that high probability events are closed under taking union bounds over sets of size . We will assume throughout that for a sufficiently small absolute constant .
The following lemma bounds and is a direct corollary of Lemma 15:
With high probability, for .
All remaining proofs will be conditioned on this high probability event.
Let . Then the Hermite expansion of is
Let denote the Hermite coefficients of , i.e. . Note that
Let the Hermite expansion of be
Note that as an immediate consequence of Lemma 5, . In addition, Assumption 1 guarantees that .
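For readers less familiar with Hermite expansions, the following sketch evaluates the probabilists' Hermite polynomials by their standard three-term recurrence. Which normalization the paper uses is not recoverable from the text here, so both the unnormalized polynomial and the version divided by sqrt(k!) are shown; the function names are hypothetical.

```python
import math


def hermite(k: int, x: float) -> float:
    """Probabilists' Hermite polynomial He_k(x), computed with the recurrence
    He_{j+1}(x) = x * He_j(x) - j * He_{j-1}(x), He_0 = 1, He_1 = x."""
    prev, curr = 1.0, x
    if k == 0:
        return prev
    for j in range(1, k):
        prev, curr = curr, x * curr - j * prev
    return curr


def normalized_hermite(k: int, x: float) -> float:
    """He_k(x) / sqrt(k!), which has unit second moment under N(0, 1)."""
    return hermite(k, x) / math.sqrt(math.factorial(k))


# He_2(x) = x^2 - 1, He_3(x) = x^3 - 3x, He_4(x) = x^4 - 6x^2 + 3
print([hermite(k, 2.0) for k in range(5)])
```

These polynomials are orthogonal under the standard Gaussian measure, which is what allows the cross terms in the expansions of this section to vanish.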
A.1.3 Concentrating α, β
Let and . Then, with high probability,
Let . Note that
The bound on therefore immediately follows from Lemma 17 applied to . The bound on is a special case of Lemma 19 with . ∎
A.1.4 Hermite Expanding the Features
Note that by the scale invariance of , Algorithm 1 does not depend on for . Therefore we can assume WLOG that for and . For the remainder of the appendix we will assume that .
We define .
The functions and capture the features that can be learned after one step of gradient descent:
By Stein’s lemma and the orthogonality of Hermite polynomials,
Note that these sums are finite as for . Next, by Corollary 12 we have the high probability bounds,
Applying these bounds term by term and using Lemma 6 to bound and gives the desired result. ∎
Furthermore, it will become necessary to bound terms of the form . Note that and are dependent random variables. The following lemma handles this dependence.
Let and assume . Then with high probability,
For the first term, note that and are independent so so with high probability,
Note that in the first term, the and the sum are independent. Therefore by Corollary 7 the first term is bounded with high probability by . In addition, by Lemma 17, the second term is bounded by which completes the proof. ∎
A.2 Random Feature Approximation
This section shows that after we reinitialize the biases we can use random features to transform the activation into which is more natural for learning polynomials.
Let , and . Then for any there exists such that for ,
First, for we can take . Then,
and . Next, for we can take . Then,
and we have . Next, note that by integration by parts we have for any function ,
Therefore for if and
Let , and . Then for any there exists such that for ,
Let be the function constructed in Lemma 9 and let
where denotes the density of . Then,
A.2.2 Multivariable Random Feature Approximation
With high probability over the data , we have for ,
We can decompose and note that
We can bound the th moment term by term. We have by Corollary 8 and Lemma 24 that for ,
We can now show that the random features are sufficiently expressive to allow us to efficiently represent any polynomial of degree restricted to the principal subspace .
For any , there exists an absolute constant such that if and ,
where denotes the orthogonal projection onto symmetric tensors restricted to .
for all symmetric tensor with . Recall that . Therefore by the binomial theorem,
where Therefore by Young’s inequality,
Let be the symmetric tensor defined by . Then by Corollary 13,
Because we assumed and for a sufficiently large constant , we have
Assume and for a sufficiently large constant . Then for any and any symmetric tensor supported on , there exists such that
Assume and for a sufficiently large constant . Let , let and let be a tensor. Then with high probability, there exists such that if
where and are constructed in Corollary 3 and Corollary 4 respectively. Recall that . Then for ,
where the second to last line followed from Lemma 8. The first part of the lemma now follows from a union bound over . For the bounds on , we have
Assume and for a sufficiently large constant and let . Then with high probability, there exists such that if
with . Let
Then is immediate from Lemma 12 and
Let where is the function constructed in Corollary 5. Then,
Then with probability we have that for . Therefore,
For the first term, by Bernstein’s inequality we have with probability at least ,
and the first part of the lemma follows from a union bound.
We will now turn to the bound on . Let . Note that are positive, i.i.d., and bounded by . In addition, they have expectation . Therefore by Popoviciu’s inequality they have variance bounded by
Therefore by Bernstein’s inequality we have that with high probability,
A.3 Proof of Theorem 1
to be the empirical losses with respect to the true labels (recall , ).
Assume and for a sufficiently large constant and let . Let be the vector constructed in the proof of Lemma 13 and let . Then with high probability,
Let . Then,
First, by Hoeffding’s inequality, we have with high probability,
We are now ready to directly prove Theorem 1.
Note that we can assume that there is an absolute constant such that , and . Otherwise, we can simply take and return the zero predictor.
From Lemma 14 we know that with high probability, there exists such that if ,
and Therefore by equality of norm constrained linear regression and ridge regression, there exists such that if
Then with high probability, . In addition, from Lemma 28,
Appendix B Transfer Learning
The proof of Theorem 3 is virtually identical to that of Theorem 1. We can use Lemma 13 to construct such that if then with high probability,
In addition, there exists such that if ,
Now let . Then by Lemma 27 we have with high probability,
Appendix C Concentration Lemmas
Let . Then, for any ,
Let . Then for some constant ,
Let be a polynomial of degree . Then there exists an absolute constant depending only on such that for any ,
Therefore by Theorem 1.2 of , there exists an absolute constant such that
Note that the planes divide the sphere into at most convex regions. For each region there exists an net of size . Therefore we can take the union of these nets over each region which has size at most . ∎
Let be a polynomial of degree and let . Then there exists an absolute constant depending only on such that for any , with probability at least , we have
Let so that
Then note that for fixed , is -sub Gaussian so for each , with probability we have
so by a union bound we have with probability ,
so setting we have with probability ,
Using and putting everything together gives with probability ,
Let . Then with high probability,
Next, note that for fixed , is sub-Gaussian so for any , with probability ,
By a union bound, with probability at least ,
Appendix D CSQ Lower Bound
The proof is a modified version of the proof in Szörényi . Let denote the inner product with respect to . We will show that there are at least two functions such that for each query , and . Therefore, we can simply respond to each query adversarially with and it is impossible for the learner to distinguish between . Note that failing to do so will result in a loss of . Let the th query be and let
Similarly, we have that so the number of functions that are eliminated from the th query is at most . We can continue this process for at most iterations. ∎
Let . Then for every pair , is subgaussian so for an absolute constant , with probability , . Therefore with probability this holds for all so there must exist at least one collection of such points. ∎
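The probabilistic argument above can be illustrated numerically: independent random unit vectors in high dimension have small pairwise inner products. The dimension and the number of vectors below are arbitrary illustrative choices, not the quantities from the lemma, and the helper names are hypothetical.

```python
import math
import random


def random_unit_vector(d, rng):
    """Sample a Gaussian vector and normalize it to unit Euclidean norm."""
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]


def max_abs_inner_product(vectors):
    """Largest |<u, v>| over all distinct pairs in the collection."""
    best = 0.0
    for i in range(len(vectors)):
        for j in range(i + 1, len(vectors)):
            ip = sum(p * q for p, q in zip(vectors[i], vectors[j]))
            best = max(best, abs(ip))
    return best


rng = random.Random(0)
d, count = 400, 30
vectors = [random_unit_vector(d, rng) for _ in range(count)]
worst = max_abs_inner_product(vectors)
print(f"largest |<u, v>| over {count} random unit vectors in dimension {d}: {worst:.3f}")
```

Each inner product has standard deviation of about 1/sqrt(d), so a union bound over all pairs leaves room for a very large collection before any pair becomes noticeably correlated.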
Let be the set constructed in Lemma 3. Let
and note that for all , . Then for and ,
Therefore, by Lemma 2 we have for any ,
In particular if we take we get
Appendix E Additional Technical Lemmas
For a tensor , let denote the symmetrization of along all permutations of indices.
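As a small illustration of symmetrization, the sketch below averages a tensor over all permutations of its indices. It assumes the normalization by 1/k! (an average rather than a sum) and stores the tensor as a dictionary from index tuples to values; both choices, and the function name, are hypothetical.

```python
from itertools import permutations, product
from math import factorial


def symmetrize(T, d, k):
    """Average a k-tensor over all permutations of its k indices.

    T maps index tuples in range(d)**k to floats; missing entries are zero.
    Only nonzero entries of the result are stored.
    """
    sym = {}
    for idx in product(range(d), repeat=k):
        total = sum(T.get(tuple(idx[p] for p in perm), 0.0)
                    for perm in permutations(range(k)))
        if total != 0.0:
            sym[idx] = total / factorial(k)
    return sym


T = {(0, 1): 2.0}          # a non-symmetric 2-tensor (matrix) in dimension 2
S = symmetrize(T, d=2, k=2)
print(S)
```

For k = 2 this reduces to the familiar (T + T^T) / 2, so the single off-diagonal entry 2.0 is split evenly between the two mirrored positions.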
There exist such that
and for .
Note that from the Taylor series of we have
Note that by a simple counting argument, the number of permutations such that this product of indicators is nonzero is exactly as you can first order the indices corresponding to each , then split them into groups of two, then shuffle these groups of two. Therefore,
because , which completes the proof. ∎
Let and denote the change of basis matrices between Hermite polynomials and monomials, i.e.
Let be a symmetric -tensor and let . Then for ,
Let with . Using the change of basis ,
Let be a symmetric -tensor with . For ,
The proof follows directly from Lemma 23 and the inequality for . ∎
Let be a symmetric -tensor with . With probability at least ,
Therefore by Lemma 17, with probability at least , and taking square roots completes the proof. ∎
This follows immediately from Lemma 23 and . ∎
Let be a polynomial of degree . Then
E.2 Sphere Lemmas
This follows from the decomposition with independent. ∎
Let be a symmetric -tensor with . With probability at least ,
Let . For ,
E.3 Rademacher Complexity Bounds
Let be a two layer neural network. For fixed , let
Let and let be a two layer neural network. Let