Neural Networks can Learn Representations with Gradient Descent

Alex Damian, Jason D. Lee, Mahdi Soltanolkotabi

Introduction

Crucial to the practical success of deep learning is the ability of gradient-based algorithms to learn good feature representations from the training data and learn simple functions on top of these representations. Despite significant progress towards a theoretical foundation for neural networks, a robust understanding of this unique representation learning capability of gradient descent methods has remained elusive. A major challenge is that, due to the highly nonconvex loss landscape, it is difficult to establish convergence to a global optimum that achieves near zero training loss. Furthermore, due to the overparameterized nature of modern neural nets (containing many more parameters than training data), the training landscape has many global optima. In fact, there are many global optima with poor generalization performance. This paper thus focuses on answering this intriguing question:

How do gradient-based methods learn feature representations and why do these representations allow for efficient generalization and transfer learning?

The most prominent contemporary approach to understanding neural networks is the linearization or neural tangent kernel (NTK) technique. The premise of the linearization method is that the dynamics of gradient descent are well-approximated by gradient descent on a linear regression instance with a fixed feature representation. Using this linearization technique, it is possible to prove convergence to a point with zero training loss. However, this technique often requires unrealistic hyper-parameter choices (e.g. small learning rate, large initialization, or wide networks) that do not allow the features to evolve across the iterations, and thus the generalization error with this technique cannot be better than that of a kernel method. Indeed, precise lower bounds show that the NTK solutions do not generalize better than the polynomial kernel. As a result, this regime of training is also sometimes referred to as the lazy regime. See Section 4 for a more in-depth discussion of this literature and other related work. In practice, neural networks far outperform their corresponding induced kernels. Therefore, understanding the representation learning of neural networks beyond the lazy regime is of fundamental importance.

In this paper, we initiate the study of the representation learning of neural networks beyond this NTK/linear/lazy regime. To this aim, we consider the problem of learning polynomials with a low-dimensional latent representation of the form $f^{\star}(x)=g(Ux)$, where $U$ maps from $d$ to $r$ dimensions with $d\gg r$ and $g$ is a multivariate polynomial of degree $p$. This is a natural choice as the failure of the NTK solution is in part due to its inability to learn data-dependent feature representations that adapt to the intrinsic low latent dimensionality of the ground truth function. Existing analyses based on the NTK regime provably require $n\asymp d^{p}$ samples to learn any degree $p$ polynomial, even if it only depends on a few relevant directions. In contrast, we show that gradient descent from random initialization only requires $n\asymp d^{2}r+dr^{p}$ samples, breaking the sample complexity barrier dictated by NTK proof techniques. More specifically, our contributions are as follows:

Feature Learning: When the target function $f^{\star}(x)=g(Ux)$ only depends on the projection of $x$ onto a hidden subspace $\operatorname{span}(U)$, we show that gradient descent learns features that span $\operatorname{span}(U)$. Leveraging these features, gradient descent can reach vanishing training loss with a very small network, which guarantees good generalization performance. See Section 5.1.

Lower Bound: Finally, we show a lower bound that demonstrates our non-degeneracy assumption (Assumption 2) is strictly necessary. Without the non-degeneracy, there is a family of polynomials which depend on a single relevant dimension (i.e. of the form $f^{\star}(x)=g(u\cdot x)$) which cannot be learned with fewer than $n\asymp d^{p/2}$ samples by any gradient descent based learner.

Setup

where $\varsigma^{2}$ controls the strength of the label noise.

In order to make the problem of learning $f^{\star}$ tractable, additional assumptions are necessary. The set of degree $p$ polynomials in $d$ dimensions spans a linear subspace of $L^{2}(\mathcal{D})$ of dimension $\Theta(d^{p})$. Learning arbitrary degree $p$ polynomials therefore requires $n\gtrsim d^{p}$ samples. We follow Chen and Meka and Chen et al. in assuming that the ground truth $f^{\star}$ has a special low dimensional latent structure. Specifically, we assume that $f^{\star}$ only depends on a small number of relevant dimensions and that the expected Hessian is non-degenerate. We show in Theorem 2 that this non-degeneracy assumption is strictly necessary to avoid sample complexity $d^{\Omega(p)}$.

We will call $S^{\star}:=\operatorname{span}(u_{1},\ldots,u_{r})$ the principal subspace of $f^{\star}$. We will also denote by $\Pi^{\star}:=\Pi_{S^{\star}}$ the orthogonal projection onto $S^{\star}$.

We will also denote the normalized condition number of $H$ by $\kappa:=\frac{\norm*{H^{\dagger}}}{\sqrt{r}}$.

The Network and Loss

where $m$ denotes the width of the network. We use a symmetric initialization, so that $f_{\theta_{0}}(x)=0$. Explicitly, we will assume that $m$ is an even number and that

We will use the following initialization:

We note that while we focus on such symmetric initialization for clarity of exposition, our results also hold with small random initialization that is not necessarily symmetric. This holds by simple modifications in the proof accounting for the small nonzero output of the network at initialization. We will also denote the empirical and population losses by $\mathcal{L}(\theta)$ and $\mathcal{L}_{\mathcal{D}}(\theta)$ respectively:
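The effect of a symmetric initialization can be illustrated with a small sketch. It assumes a two layer ReLU network of the form $f_{\theta}(x)=\sum_{j}a_{j}\sigma(w_{j}\cdot x+b_{j})$ in which the second half of the neurons copies the first half with negated output weights; the exact pairing, the $\pm 1$ output weights, the Gaussian scaling of the hidden weights, and the zero biases are illustrative assumptions, since the initialization itself is not reproduced in the text above.

```python
import math
import random


def relu(z):
    return max(0.0, z)


def symmetric_init(m, d, rng):
    """Hypothetical symmetric init: the second half mirrors the first half
    with the same hidden weights and negated output weights."""
    assert m % 2 == 0, "width must be even"
    half = m // 2
    a_half = [rng.choice([-1.0, 1.0]) for _ in range(half)]
    # Assumed scaling: w_j ~ N(0, I/d).
    W_half = [[rng.gauss(0.0, 1.0) / math.sqrt(d) for _ in range(d)]
              for _ in range(half)]
    a = a_half + [-v for v in reversed(a_half)]
    W = W_half + [list(row) for row in reversed(W_half)]
    b = [0.0] * m  # assumed zero biases at initialization
    return a, W, b


def forward(a, W, b, x):
    pre = [sum(wk * xk for wk, xk in zip(w, x)) + bj for w, bj in zip(W, b)]
    # fsum adds the exactly cancelling pairs without rounding drift.
    return math.fsum(aj * relu(p) for aj, p in zip(a, pre))


rng = random.Random(0)
a, W, b = symmetric_init(m=8, d=5, rng=rng)
x = [rng.gauss(0.0, 1.0) for _ in range(5)]
out = forward(a, W, b, x)
```

Each mirrored pair contributes two terms of equal magnitude and opposite sign, so the network output at initialization vanishes for every input, regardless of the random draw.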

Notation

Main Results

Before we formally state our main result let us specify the exact form of gradient-based training we use in our theory.

With this algorithm in place, we are now ready to state our main result.

It is useful to note that the use of $\lambda$ in the algorithm corresponds to the common practice of weight decay, and its value is chosen in such a way that $\norm{a^{(T)}}\leq B_{a}$, i.e. to solve a constrained minimization problem (see Section 5.1). In practice, one simply tunes the hyperparameter $\lambda$ in order to achieve the desired tradeoff between training and test loss.
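The link between weight decay and a norm constraint on the head can be seen on a toy problem. The sketch below runs plain gradient descent on a ridge objective $\frac{1}{n}\sum_{i}(a\cdot\phi_{i}-y_{i})^{2}+\lambda\norm{a}^{2}$ over fixed features; the objective normalization, the feature matrix `Phi`, the labels, the step size, and the function names are all illustrative assumptions rather than the quantities used in Algorithm 1.

```python
import math


def ridge_gradient(a, Phi, y, lam):
    """Gradient of (1/n) * sum_i (a . phi_i - y_i)^2 + lam * ||a||^2."""
    n = len(y)
    g = [2.0 * lam * aj for aj in a]
    for row, target in zip(Phi, y):
        resid = sum(aj * pj for aj, pj in zip(a, row)) - target
        for j, pj in enumerate(row):
            g[j] += 2.0 * resid * pj / n
    return g


def train_head(Phi, y, lam, lr=0.05, steps=5000):
    """Gradient descent on the head weights with weight decay lam."""
    a = [0.0] * len(Phi[0])
    for _ in range(steps):
        g = ridge_gradient(a, Phi, y, lam)
        a = [aj - lr * gj for aj, gj in zip(a, g)]
    return a


def norm(v):
    return math.sqrt(sum(t * t for t in v))


# Toy fixed features and labels (hypothetical values).
Phi = [[1.0, 0.5], [0.5, 1.0], [1.0, 1.0], [0.2, 0.8]]
y = [1.0, 2.0, 2.5, 1.5]

lams = [0.0, 0.1, 1.0]
heads = [train_head(Phi, y, lam) for lam in lams]
head_norms = [norm(a) for a in heads]
final_grad_norms = [norm(ridge_gradient(a, Phi, y, lam))
                    for a, lam in zip(heads, lams)]
```

Larger values of `lam` give heads with smaller norm, which is the sense in which tuning $\lambda$ plays the role of choosing the radius $B_{a}$ of a constrained problem.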

An intriguing aspect of the above result is that despite the fact that $f^{\star}$ may be of arbitrarily high degree, learning $f^{\star}$ requires only $n\gtrsim dr^{p}+d^{2}r$ samples and only requires a very small network with $m\gtrsim r^{p}$. We note that our dependence on the latent dimension $r$ is near optimal, as the minimax sample complexity, even when the principal subspace $S^{\star}$ is known, is $\Theta(r^{p})$.
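To get a feel for the size of this gap, the two rates can be compared for concrete values. The sketch below ignores all constants and logarithmic factors, and the values of $d$, $r$, and $p$ are arbitrary illustrative choices, so the numbers indicate orders of magnitude only.

```python
def kernel_rate(d, p):
    """Sample complexity scale d^p of kernel / NTK methods."""
    return d ** p


def gradient_descent_rate(d, r, p):
    """Sample complexity scale d^2 r + d r^p stated for gradient descent."""
    return d ** 2 * r + d * r ** p


d, r, p = 100, 2, 4
ntk_samples = kernel_rate(d, p)               # 100**4
gd_samples = gradient_descent_rate(d, r, p)   # 100**2 * 2 + 100 * 2**4
ratio = ntk_samples / gd_samples
```

For these values the kernel rate is $10^{8}$ while the gradient descent rate is $21600$, a gap of more than three orders of magnitude.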

We show in Theorem 3 that by resampling the data after the first step, the sample complexity can be further reduced to $d^{2}r+r^{p}$, dropping a factor of $d$ from the second term. The extra factor of $d$ results from the dependence between the data used in the first and second stages, and we believe that a more careful analysis could remove this additional factor.

We contrast Theorem 1 with the following lower bound for learning a function class which satisfies Assumption 1 with $r=1$ but does not satisfy Assumption 2.

For any $p\geq 0$, there exists a function class $\mathcal{F}_{p}$ of polynomials of degree $p$, each of which depends on a single relevant dimension, such that any correlational statistical query learner using $q$ queries requires a tolerance $\tau$ of at most

in order to output a function $f\in\mathcal{F}_{p}$ with $L^{2}(\mathcal{D})$ loss at most $1$.

Using the heuristic $\tau\approx\frac{1}{\sqrt{n}}$, which represents the expected scale of the concentration error, we get the immediate corollary that violating Assumption 2 allows us to construct a function class which any neural network with polynomially many parameters trained for polynomially many steps of gradient descent cannot learn without at least $n\gtrsim d^{p/2}$ samples. We emphasize that this is only a heuristic argument, as concentration errors are random rather than adversarial.
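The arithmetic behind this heuristic can be spelled out. The step below assumes that the required tolerance scales as $d^{-p/4}$ up to logarithmic factors in $q$ and $d$; this scaling is an assumption inferred from the stated conclusion $n\gtrsim d^{p/2}$ and is not a restatement of the theorem.

```latex
\[
\tau \approx \frac{1}{\sqrt{n}}
\quad\text{and}\quad
\tau \lesssim d^{-p/4}
\quad\Longrightarrow\quad
\frac{1}{\sqrt{n}} \lesssim d^{-p/4}
\quad\Longleftrightarrow\quad
n \gtrsim d^{p/2}.
\]
```

In words: a query answered from $n$ samples is only accurate to about $n^{-1/2}$, so matching the required tolerance forces $n$ to grow like the inverse square of that tolerance.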

On the other hand, Theorem 1 shows that incorporating Assumption 2 allows gradient descent to efficiently learn polynomials of arbitrarily high degree with only $d^{2}r+dr^{p}$ samples.

The difference in sample complexity between Theorem 1 and Theorem 2 arises because, in Theorem 1, our non-degeneracy assumption (Assumption 2) allows the network $f_{\theta}$ to extract useful features, which in turn allows it to learn high degree polynomials with $n\gtrsim d^{2}$ samples. Theorem 2 shows that violating this assumption allows us to construct a function class which cannot be learned without $d^{\Omega(p)}$ samples, demonstrating the necessity of Assumption 2.

The fact that the network $f_{\theta}$ extracts useful features not only allows it to learn $f^{\star}$ efficiently, but also allows for efficient transfer learning. In particular, Theorem 3 shows that we can efficiently learn any target polynomial $g^{\star}(x)$ that depends on the same relevant dimensions as $f^{\star}$ with sample complexity independent of $d$ by simply truncating and retraining the head of the network:

Learning $g^{\star}(x)$ therefore only requires $N,m\gtrsim r^{p}$, which is independent of the ambient dimension $d$. We note that this is minimax optimal for learning arbitrary degree $p$ polynomials even when the hidden subspace $S^{\star}$ is known. Theorem 3 also shows that $n\gtrsim d^{2}r$ pre-training samples are sufficient for gradient descent to learn the subspace $S^{\star}$ from the pre-training data.

Related work

A growing body of recent work shows the connection between gradient descent on the full network and the Neural Tangent Kernel (NTK). Using this technique, one can prove concrete results about neural network training and generalization in the kernel regime. The key idea is that for a large enough initialization, it suffices to consider a linearization of the neural network around the origin. This allows connecting the analysis of neural networks with the well-studied theory of kernel methods. This is also sometimes referred to as lazy training, as with such an initialization the parameters of the neural network stay close to their values at initialization, and these results can only show that neural networks are as powerful as shallow learners such as kernels. There is, however, growing evidence that this NTK-style analysis might not be sufficient to completely explain the success of neural networks in practice. Prior work provides empirical evidence that choosing a smaller initialization decreases the test error of the neural network. A similar gap between the performance of the NTK and that of neural networks has also been observed. This NTK-style analysis, however, does not yield satisfactory results in the setting studied in this paper. In particular, for learning polynomials of the form we study in this paper, it has been shown that one needs at least $d^{p}$ samples in the kernel regime. In contrast, our results only require on the order of $d^{2}$ samples.

Leveraging the fact that linearized models are not feature learners, Ghorbani et al. showed precise upper and lower bounds on the sample complexity of NTK methods. They showed that because the NTK is unable to learn new features, learning any polynomial in dimension $d$ of degree $p$ requires $n=\Theta(d^{p})$ samples, which gives no improvement over polynomial kernels. On the empirical front, the NTK linearization analysis is also lacking. Arora et al. demonstrated that the kernel predictor loses more than $20\%$ in test accuracy relative to a deep network trained with SGD and state-of-the-art regularization on CIFAR-10. Our work is motivated by the contrast between these negative theoretical results for linearized NTK models and the spectacular empirical performance of deep learning.

The gap between such shallow learners and the full neural network has been established in theory and observed in practice. There is an emerging literature on learning beyond the lazy/NTK regime in the small initialization setting. It has been shown that for the problem of low-rank reconstruction in a non-lazy regime, gradient descent with small random initialization finds globally optimal solutions with good generalization capability. This is carried out by utilizing a spectral bias phenomenon exhibited by the early stages of gradient descent from small random initialization, which puts the iterates on a trajectory towards generalizable models. For the problem of tensor decomposition, it has also been shown that gradient descent with small initialization is able to leverage low-rank structure. It has also been shown that neural networks with orthogonal weights can be learned via SGD and outperform any kernel method. One crucial element in that analysis is that the early stage of the training is connected with learning the first and second moments of the data. Higher-order approximations of the training dynamics and the Neural Tangent Hierarchy have also been recently proposed towards closing this gap. None of the above papers, however, focus on learning polynomial representations efficiently via neural networks as carried out in this paper.

Another line of work focuses on learning single activations such as the ReLU function. In this context, it has been shown that it is hard to learn a single ReLU activation via stochastic gradient descent with random features, whereas learning such activations is possible in a non-NTK regime, again highlighting this important gap. In related work where the label also only depends on a single relevant direction, the authors show that in the context of learning the parity function, gradient descent is able to efficiently learn the planted set. However, this is a result of the unbalanced data distribution, which skews the gradient towards the planted set. In contrast, we consider isotropic Gaussian data, so that no information can be extracted from the data distribution itself and features must be extracted from higher order correlations between the data and the labels. Chen and Meka also studied the problem of learning polynomials of few relevant dimensions. They provide an algorithm that learns polynomials of degree $p$ in $d$ dimensions that depend on $r$ hidden dimensions with $n\gtrsim C(r,p)d$ samples, where $C(r,p)$ is an unspecified function of $r,p$ which is likely exponential in $r$. However, their algorithm is not a variant of gradient descent, and requires a clever spectral initialization. On the other hand, this work focuses on the ability of gradient descent to automatically extract hidden features and learn representations from the data.

There is also a line of work concerned with the mean-field analysis of neural networks. The insight is that for sufficiently large width, the training dynamics of the neural network can be coupled with the evolution of a probability distribution described by a PDE. These papers use a smaller initialization than in the NTK regime and, hence, the parameters can move away from the initialization. However, these results do not provide explicit convergence rates and require an unrealistically large width of the neural network. To the best of our knowledge, such an analysis technique has not been used to show efficient learning of polynomial representations using neural networks as carried out in this paper.

A concurrent line of work studied the feature learning ability of gradient descent in the mean field regime with data sampled from the boolean cube. The authors identified a necessary and sufficient condition for learning with sample complexity linear in $d$, dubbed the merged staircase property, in the special case when the hidden weights of the two layer neural network are initialized at $0$. However, the zero initialization hinders the feature learning ability of the network. For example, the boolean function XOR violates the merged staircase property, yet noisy XOR is known to be learnable by two layer neural networks with sample complexity linear in $d$. In this work we study the impact that the nonzero initialization of the hidden weights has on the feature learning ability of neural networks.

Proof Sketches

Using the chain rule, we can further expand this as

With high probability over the random initialization,

Note that the remainder term, of order $d^{-1}$, contains all higher order terms in the series expansion.

However, it is also important to note that the population gradient is bounded by $\norm{\nabla_{w_{j}}\mathcal{L}_{\mathcal{D}}(\theta)}=O(d^{-1/2})$ and we only have access to the empirical gradient $\nabla_{w_{j}}\mathcal{L}(\theta)$. As mentioned above, extracting the necessary subspace information from $\nabla_{w_{j}}\mathcal{L}_{\mathcal{D}}(\theta)$ to learn $f^{\star}$ therefore requires $n\gtrsim d^{2}$ samples, which is the dominant term in our final sample complexity result.
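A small simulation illustrates how a single gradient evaluation at initialization already points into the principal subspace. The sketch assumes the gradient feature takes the form $g_{n}(w)=\frac{1}{n}\sum_{i}(y_{i}-\alpha-\beta\cdot x_{i})\,x_{i}\,\sigma'(w\cdot x_{i})$ with $\alpha$ and $\beta$ the empirical mean and linear correlations of the labels; the target $f^{\star}(x)=He_{2}(x_{1})/2$, the dimension, the sample size, the noise-free labels, and the fixed unit vector $w$ are illustrative assumptions. For this target the expected Hessian is $H=e_{1}e_{1}^{\top}$, so the leading term $Hw/\sqrt{2\pi}$ is a multiple of $e_{1}$.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def target(x):
    # Toy target f*(x) = He_2(x_1) / 2, with He_2(z) = z^2 - 1.
    return (x[0] ** 2 - 1.0) / 2.0


def empirical_gradient_feature(xs, ys, w):
    """Assumed form: g_n(w) = (1/n) sum_i (y_i - alpha - beta.x_i) x_i 1[w.x_i > 0]."""
    n, d = len(xs), len(xs[0])
    alpha = sum(ys) / n
    beta = [sum(y * x[k] for x, y in zip(xs, ys)) / n for k in range(d)]
    g = [0.0] * d
    for x, y in zip(xs, ys):
        if dot(w, x) > 0.0:  # ReLU derivative is 1 on the active side
            resid = y - alpha - dot(beta, x)
            for k in range(d):
                g[k] += resid * x[k] / n
    return g


rng = random.Random(0)
d, n = 10, 20000
u = [1.0] + [0.0] * (d - 1)          # the single relevant direction e_1
w = [1.0 / math.sqrt(d)] * d         # a fixed unit vector standing in for w_j
xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
ys = [target(x) for x in xs]

g = empirical_gradient_feature(xs, ys, w)
cosine = abs(dot(g, u)) / math.sqrt(dot(g, g))
predicted = dot(u, w) / math.sqrt(2.0 * math.pi)  # e_1 component of Hw / sqrt(2 pi)
```

The feature `g` is far from `w` itself and instead aligns with `u`, with its first coordinate close to `predicted`; the small remaining discrepancy comes from the higher order term and from sampling noise.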

Once we show that the gradient at initialization contains all the relevant features, we note that after the first step of gradient descent,

After the first step, the model therefore resembles a random feature model with random features $\{Hw\}_{w\in S^{d-1}}\subset S^{\star}$. Previous results have shown that in these linearized regimes, e.g. random feature models/NTK, learning degree $p$ polynomials requires $n\gtrsim d^{p}$ samples and width $m\gtrsim d^{p}$. As our “random features” are now constrained to the hidden subspace $S^{\star}$, which has dimension $r$, we should expect that our sample complexity improves to $n\gtrsim r^{p}$.

The remainder of Algorithm 1 runs ridge regression on the network head $a$ with fixed features $x\to\sigma(W^{(1)}x+b)$. We can directly analyze the generalization of this algorithm using standard techniques from Rademacher complexity. In particular, a high level sketch of the remainder of the proof goes as follows:

(Section A.3): We show that the equivalence between ridge regression and norm constrained linear regression implies the existence of $\lambda>0$ such that the $T$th iterate $a^{(T)}$ satisfies

Proof of Theorem 2

Let $\mathcal{F}$ be a class of functions and $\mathcal{D}$ be a data distribution such that

Then any correlational statistical query learner requires at least $\frac{\absolutevalue{\mathcal{F}}(\tau^{2}-\epsilon)}{2}$ queries of tolerance $\tau$ to output a function in $\mathcal{F}$ with $L^{2}(\mathcal{D})$ loss at most $2-2\epsilon$.

To construct $\mathcal{F}_{p}$, we begin by showing that there are a large number of approximately orthogonal unit vectors in $S^{d-1}$:

There exists an absolute constant $c$ such that for any $\epsilon>0$, there exists a set $S$ of $\frac{1}{2}e^{c\epsilon^{2}d}$ unit vectors such that for any $v,w\in S$ with $v\neq w$, we have $\absolutevalue{v\cdot w}\leq\epsilon$.
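One standard way to build intuition for this statement is to sample unit vectors at random, since independent random directions in high dimension are nearly orthogonal. The sketch below is an illustration of that phenomenon and not necessarily the construction used in the proof; the dimension, the number of vectors, and the seed are arbitrary choices.

```python
import math
import random


def random_unit_vector(d, rng):
    """Sample a uniformly random direction by normalizing a Gaussian vector."""
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    scale = math.sqrt(sum(t * t for t in v))
    return [t / scale for t in v]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


rng = random.Random(0)
d, count = 400, 20
vectors = [random_unit_vector(d, rng) for _ in range(count)]

# Largest absolute inner product over all distinct pairs.
max_overlap = max(abs(dot(vectors[i], vectors[j]))
                  for i in range(count) for j in range(i + 1, count))
typical_scale = 1.0 / math.sqrt(d)  # inner products concentrate at this scale
```

Pairwise inner products concentrate at the scale $d^{-1/2}$, which is why exponentially many vectors can be packed while keeping every overlap below a fixed $\epsilon$.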

Therefore $\absolutevalue{u\cdot v}\leq d^{-1/2}\sqrt{\log m}$ implies $\absolutevalue{E_{x\sim\mathcal{D}}[f_{u}(x)f_{v}(x)]}\leq d^{-k/2}(\log m)^{k/2}$. Theorem 2 then directly follows from Lemma 2 (see Appendix D for a more detailed proof).

Experiments

In this section we present a toy example that clearly demonstrates the gap between kernel methods and gradient descent on two layer networks. For $u\in S^{d-1}$, consider the target function

which satisfies $E_{x\sim\mathcal{D}}[f_{u}^{\star}(x)^{2}]=1$. Note that $f^{\star}$ only depends on the projection of $x$ onto a single relevant direction, $u$. We show in Section 5.1 that gradient descent is capable of isolating the subspace spanned by $u$ and then fitting a one dimensional random feature model to $g$, and that this entire process requires $n\asymp d^{2}$ samples to generalize.

On the other hand, existing works such as Ghorbani et al. have shown that $n\asymp d^{p}$ samples are strictly necessary in order to learn $f^{\star}$ in the NTK or random features regime. The theory predicts that with $n<d^{2}$ samples, kernel regression will return the zero predictor, and with $d^{2}<n<d^{p}$ samples, kernel regression will return $\frac{1}{2}He_{2}(u\cdot x)$, incurring an $L^{2}(\mathcal{D})$ loss of $\frac{1}{2}$.

We empirically verify these predictions. We take $d=10$ and $p=4$ and consider the function $f_{e_{1}}^{\star}(x)=\frac{He_{2}(x_{1})}{2}+\frac{He_{4}(x_{1})}{4\sqrt{3}}$. We use label noise $\sigma^{2}=1$ and attempt to learn $f^{\star}$ using Algorithm 1, a random feature model, and a linearized NTK model. All experiments are conducted on a two layer neural network with widths $m=100$ and $m=1000$. For each value of $n$, the weight decay parameter $\lambda$ is tuned on a holdout set of size $10^{5}$ and test accuracies are reported over a separate test set of size $10^{5}$. Error bars reflect the mean and standard deviation over $10$ random seeds.
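The normalization of this target, and the loss left over when only its quadratic part is fitted, can be checked exactly from Gaussian moments. The sketch below represents polynomials in $z=x_{1}$ as coefficient lists and uses $E[z^{2k}]=(2k-1)!!$ for a standard Gaussian; the helper names are illustrative.

```python
import math


def gaussian_moment(k):
    """E[z^k] for z ~ N(0, 1): zero for odd k, (k - 1)!! for even k."""
    if k % 2 == 1:
        return 0.0
    result = 1.0
    for j in range(k - 1, 0, -2):
        result *= j
    return result


def poly_mul(p, q):
    """Multiply polynomials given as coefficient lists (index = power)."""
    out = [0.0] * (len(p) + len(q) - 1)
    for i, pi in enumerate(p):
        for j, qj in enumerate(q):
            out[i + j] += pi * qj
    return out


def gaussian_expectation(p):
    return sum(c * gaussian_moment(k) for k, c in enumerate(p))


he2 = [-1.0, 0.0, 1.0]             # He_2(z) = z^2 - 1
he4 = [3.0, 0.0, -6.0, 0.0, 1.0]   # He_4(z) = z^4 - 6 z^2 + 3

quadratic_part = [c / 2.0 for c in he2] + [0.0, 0.0]
quartic_part = [c / (4.0 * math.sqrt(3.0)) for c in he4]
f_star = [a + b for a, b in zip(quadratic_part, quartic_part)]

second_moment = gaussian_expectation(poly_mul(f_star, f_star))
# Loss of a predictor that captures only the He_2 term.
residual_loss = gaussian_expectation(poly_mul(quartic_part, quartic_part))
```

The two Hermite components each carry half of the total variance, so `second_moment` equals $1$ and a predictor that only recovers the quadratic term is left with a loss of $\frac{1}{2}$.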

We note that while Algorithm 1 easily converged to vanishing excess risk, even at width $m=100$, both the random features model and the neural tangent kernel model only managed to fit the quadratic term $\frac{1}{2}He_{2}(u\cdot x)$, as predicted by the theory in Ghorbani et al.

Transfer Learning

The proof of Theorem 1 involves showing that Algorithm 1 learns features corresponding to $S^{\star}$ (see Section 5.1) and the proof of Theorem 3 shows that this implies efficient transfer learning. We again verify this empirically. We consider the function:

Note that this was exactly the hard example in Theorem 2 that was unlearnable without $n\gtrsim d^{\frac{p}{2}}$ samples by a correlational statistical query learner (and in particular, gradient-based learners).

We pretrain with $n$ samples on the target $f^{\star}(x)$ from Section 6.1, then train the output layer using $N$ samples from $f^{\star}_{\text{target}}$. As in Section 6.1, we use a label noise strength of $\sigma^{2}=1$. We pick $p=3$ so that random feature methods or the neural tangent kernel will require at least $n\gtrsim d^{3}$ samples to learn $f^{\star}$.

We note that in Figure 2, when $n=d^{0},d^{1}$, fine tuning on $N$ target samples gives trivial risk until $N\gtrsim d^{3}$, which is to be expected of a kernel method with no prior information. However, for $n\geq d^{2}$ pretraining samples, we can fine tune on just $N=O(1)$ target samples to reach nontrivial loss, and the loss decays rapidly as a function of $N$. This experiment therefore fully supports the conclusion of Theorem 3.

Discussion and Future Work

In this work we provide a clear separation between gradient-based training and kernel methods. We show that there is a large family of degree $p$ polynomials which are efficiently learnable by gradient descent with $n\asymp d^{2}$ samples, in contrast to the lower bound of $d^{p}$ for random feature/NTK analysis. The main idea driving both our sample complexity result (Theorem 1) and our transfer learning result (Theorem 3) is that gradient descent learns useful representations of the data.

One promising direction for future work is tightening the dimension dependence of our upper bound. In particular, our $n\asymp d^{2}$ sample complexity is driven by the difficulty of learning from a degree $2$ Hermite polynomial. However, our lower bound for such functions (Theorem 2) only rules out learning with $n\leq d$ samples. In this situation the lower bound is tight, as Chen et al. show that sparse degree $2$ polynomials can be efficiently learned with $n\asymp d$ samples.

Another promising direction for future work is generalizing our result to the situation in which the hidden layer and the output layer are trained together. This introduces dependencies between the hidden and output layers which are difficult to control. However, such analysis may lead to a better understanding of learning order and inductive bias in deep learning.

Acknowledgements

AD acknowledges support from an NSF Graduate Research Fellowship. JDL and AD acknowledge support of the ARO under MURI Award W911NF-11-1-0304, the Sloan Research Fellowship, NSF CCF 2002272, NSF IIS 2107304, ONR Young Investigator Award, and NSF-CAREER under award #2144994. MS is supported by the Packard Fellowship in Science and Engineering, a Sloan Fellowship in Mathematics, an NSF-CAREER under award #1846369, DARPA Learning with Less Labels (LwLL) and FastNICS programs, and NSF-CIF awards #1813877 and #2008443.

References

Appendix A Proofs

We define $\iota=C_{\iota}\log(nmd)$ for a sufficiently large constant $C_{\iota}$. Throughout the appendix we will use $e^{-\iota}$ to track failure probabilities of various lemmas and theorems.

We say that an event $A$ happens with high probability if it happens with probability at least $1-\operatorname{poly}(n,m,d)e^{-\iota}$ where $\operatorname{poly}(n,m,d)$ does not depend on $C_{\iota}$.

Note that high probability events are closed under taking union bounds over sets of size $\operatorname{poly}(n,m,d)$. We will assume throughout that $\iota\leq cd$ for a sufficiently small absolute constant $c$.
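The closure under union bounds is a one-line calculation. As an illustrative assumption, suppose each of $N\leq(nmd)^{k'}$ events $A_{i}$ fails with probability at most $(nmd)^{k}e^{-\iota}$, where the exponents $k$ and $k'$ are hypothetical constants that do not depend on $C_{\iota}$:

```latex
\[
\Pr\Big[\bigcup_{i=1}^{N} A_{i}^{c}\Big]
\le \sum_{i=1}^{N}\Pr\big[A_{i}^{c}\big]
\le (nmd)^{k'}\,(nmd)^{k}\,e^{-\iota}
= (nmd)^{k+k'}\,(nmd)^{-C_{\iota}}
= (nmd)^{k+k'-C_{\iota}} .
\]
```

The prefactor is still a polynomial in $n,m,d$ that does not depend on $C_{\iota}$, so the intersection is again a high probability event, and the failure probability can be made polynomially small by taking $C_{\iota}$ large.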

The following lemma bounds $\norm{x_{i}}$ and is a direct corollary of Lemma 15:

With high probability, $\norm{x_{i}}^{2}\in\quantity[\frac{d}{2},2d]$ for $i=1,\ldots,n$.
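This concentration is easy to observe numerically. The sketch below samples standard Gaussian vectors, consistent with the isotropic Gaussian data considered in the paper, and checks that every squared norm lies in $[d/2,2d]$; the dimension, sample count, and seed are arbitrary illustrative choices.

```python
import random

rng = random.Random(0)
d, n = 200, 100

# Squared norms of n independent standard Gaussian vectors in dimension d.
squared_norms = [sum(rng.gauss(0.0, 1.0) ** 2 for _ in range(d))
                 for _ in range(n)]

all_in_range = all(d / 2 <= s <= 2 * d for s in squared_norms)
mean_squared_norm = sum(squared_norms) / n  # concentrates near d
```

Each squared norm is a chi-squared variable with mean $d$ and standard deviation $\sqrt{2d}$, so the interval $[d/2,2d]$ is many standard deviations wide once $d$ is moderately large.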

All remaining proofs will be conditioned on this high probability event.

Let $\sigma(x):=\operatorname{ReLU}(x)=\max(0,x)$. Then the Hermite expansion of $\sigma(x)$ is

Let $c_{k}$ denote the Hermite coefficients of $\sigma$, i.e. $\sigma(x)=\sum_{k\geq 0}\frac{c_{k}}{k!}He_{k}(x)$. Note that
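Under the convention $\sigma(x)=\sum_{k\geq 0}\frac{c_{k}}{k!}He_{k}(x)$, the coefficients are $c_{k}=E_{z\sim N(0,1)}[\sigma(z)He_{k}(z)]$, and the first few can be evaluated numerically. The sketch below uses Simpson's rule on a truncated interval; the truncation point, grid size, and helper names are illustrative choices.

```python
import math


def hermite(k, z):
    """Probabilists' Hermite polynomial He_k(z) via He_{j+1} = z He_j - j He_{j-1}."""
    if k == 0:
        return 1.0
    prev, cur = 1.0, z
    for j in range(1, k):
        prev, cur = cur, z * cur - j * prev
    return cur


def relu_hermite_coefficient(k, upper=12.0, intervals=12000):
    """c_k = E[ReLU(z) He_k(z)] = integral over z > 0 of z He_k(z) phi(z) dz."""
    def integrand(z):
        return z * hermite(k, z) * math.exp(-z * z / 2.0) / math.sqrt(2.0 * math.pi)

    h = upper / intervals
    total = integrand(0.0) + integrand(upper)
    for i in range(1, intervals):
        total += (4.0 if i % 2 == 1 else 2.0) * integrand(i * h)
    return total * h / 3.0


coefficients = [relu_hermite_coefficient(k) for k in range(4)]
inv_sqrt_2pi = 1.0 / math.sqrt(2.0 * math.pi)
```

The computed values agree with the closed forms $c_{0}=\frac{1}{\sqrt{2\pi}}$, $c_{1}=\frac{1}{2}$, $c_{2}=\frac{1}{\sqrt{2\pi}}$, and $c_{3}=0$; the second-order coefficient is the source of the $\frac{1}{\sqrt{2\pi}}$ factor multiplying $Hw$ later in the appendix.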

Let the Hermite expansion of $f^{\star}$ be

Note that as an immediate consequence of Lemma 5, $\norm{C_{k}}_{F}^{2}\leq k!$. In addition, Assumption 1 guarantees that $C_{k}\quantity(x^{\otimes k})=C_{k}\quantity(\quantity(\Pi^{\star}x)^{\otimes k})$.

A.1.3 Concentrating $\alpha,\beta$

Let $\alpha=\frac{1}{n}\sum_{i=1}^{n}y_{i}$ and $\beta=\frac{1}{n}\sum_{i=1}^{n}y_{i}x_{i}$. Then, with high probability,

Let $F(x_{1},\ldots,x_{n})=\frac{1}{n}\sum_{i=1}^{n}f^{\star}(x_{i})-C_{0}$. Note that

The bound on $\absolutevalue{\alpha-C_{0}}$ therefore immediately follows from Lemma 17 applied to $F$. The bound on $\norm{\beta-C_{1}}$ is a special case of Lemma 19 with $\sigma(x)=x$. ∎

A.1.4 Hermite Expanding the Features

Note that by the scale invariance of $\sigma(x)=\operatorname{ReLU}(x)$, Algorithm 1 does not depend on $\norm{w_{j}}$ for $j=1,\ldots,m$. Therefore we can assume WLOG that $\norm{w_{j}}=1$ for $j=1,\ldots,m$ and $w_{j}\sim\operatorname{Unif}(S^{d-1})$. For the remainder of the appendix we will assume that $\norm{w_{j}}=1$.

We define $\widehat{f}^{\star}(x):=f^{\star}(x)-\alpha-\beta\cdot x$.

The functions $g(w)$ and $g_{n}(w)$ capture the features that can be learned after one step of gradient descent:

By Stein’s lemma and the orthogonality of Hermite polynomials,

Note that these sums are finite as $C_{k}=0$ for $k>p$. Next, by Corollary 12 we have the high probability bounds,

Applying these bounds term by term and using Lemma 6 to bound $\absolutevalue{C_{0}-\alpha}$ and $\norm{C_{1}-\beta}$ gives the desired result. ∎

Furthermore, it will become necessary to bound terms of the form $g_{n}(w)\cdot x_{i}$. Note that $g_{n}(w)$ and $x_{i}$ are dependent random variables. The following lemma handles this dependence.

Let $w\sim S^{d-1}$ and assume $n\geq d^{2}\iota^{p}$. Then with high probability,

For the first term, note that $g(w)$ and $x_{j}$ are independent, so $g(w)\cdot x_{j}\sim N(0,\norm{g(w)}^{2})$. Therefore, with high probability,

Note that in the first term, the $x_{j}$ and the sum are independent. Therefore by Corollary 7 the first term is bounded with high probability by $O\quantity(\sqrt{\frac{d\iota^{p+2}}{n}})$. In addition, by Lemma 17, the second term is bounded by $O\quantity(\frac{\iota^{p/2}d}{n})$, which completes the proof. ∎

A.2 Random Feature Approximation

This section shows that after we reinitialize the biases we can use random features to transform the activation $\sigma(x)=\operatorname{ReLU}(x)$ into $\sigma(x)=x^{p}$, which is more natural for learning polynomials.

Let $a\sim\operatorname{Unif}(\quantity{-1,1})$, and $b\sim\operatorname{Unif}()$. Then for any $k\geq 0$ there exists $v_{k}(a,b)$ such that for $\absolutevalue{x}\leq 1$,

First, for $k=0$ we can take $v_{0}(a,b):=6b$. Then,

and $\sup_{a,b}\absolutevalue{v_{0}(a,b)}=6$. Next, for $k=1$ we can take $v_{1}(a,b):=2a$. Then,

and we have $\sup_{a,b}\absolutevalue{v_{1}(a,b)}=2$. Next, note that by integration by parts we have for any function $f$,

Therefore for $k\geq 2$ if $f(x)=x^{k}$ and
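The first two cases of this construction can be checked numerically. The sketch below assumes that the identity being established is $E_{a,b}[v_{k}(a,b)\,\sigma(ax+b)]=x^{k}$ and that $b$ is uniform on $[-1,1]$; both are assumptions, since neither the displayed identity nor the interval is reproduced above. They are consistent with the stated choices $v_{0}(a,b)=6b$ and $v_{1}(a,b)=2a$ and with their suprema. The expectation over $b$ is computed with a midpoint rule.

```python
def relu(z):
    return max(0.0, z)


def expected_feature(v, x, grid=20000):
    """Approximate E_{a,b}[v(a, b) * ReLU(a*x + b)] with a uniform on {-1, 1}
    and b uniform on [-1, 1] (assumed interval), via the midpoint rule in b."""
    total = 0.0
    for i in range(grid):
        b = -1.0 + (2.0 * i + 1.0) / grid  # midpoint of the i-th cell of [-1, 1]
        for a in (-1.0, 1.0):
            total += v(a, b) * relu(a * x + b)
    return total / (2 * grid)


def v0(a, b):
    return 6.0 * b


def v1(a, b):
    return 2.0 * a


test_points = [-1.0, -0.3, 0.0, 0.7, 1.0]
constant_fit = [expected_feature(v0, x) for x in test_points]  # should be close to x^0 = 1
linear_fit = [expected_feature(v1, x) for x in test_points]    # should be close to x^1 = x
```

Under these assumptions, averaging ReLU features over the random sign and bias reproduces the monomials $1$ and $x$ on $[-1,1]$, which is the mechanism that lets a ReLU random feature model emulate polynomial activations.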

Let $a\sim\operatorname{Unif}(\quantity{-1,1})$, and $b\sim N(0,1)$. Then for any $k\geq 0$ there exists $v_{k}(a,b)$ such that for $\absolutevalue{x}\leq 1$,

Let $\overline{v}_{k}$ be the function constructed in Lemma 9 and let

where $\mu(b):=\frac{e^{-\frac{b^{2}}{2}}}{\sqrt{2\pi}}$ denotes the density of $b$. Then,

A.2.2 Multivariable Random Feature Approximation

With high probability over the data $\{x_{i}\}_{i\in[n]}$, we have for $j\leq 4p$,

We can decompose $r(w)=\quantity[g_{n}(w)-g(w)]+\quantity[g(w)-\frac{Hw}{\sqrt{2\pi}}]$ and note that

We can bound the $j$th moment term by term. We have by Corollary 8 and Lemma 24 that for $k\geq 2$,

We can now show that the random features $g_{n}(w)$ are sufficiently expressive to allow us to efficiently represent any polynomial of degree $p$ restricted to the principal subspace $S^{\star}$.

For any $k\leq p$, there exists an absolute constant $C$ such that if $n\geq Cd^{2}r\kappa^{2}\iota^{p+1}$ and $d\geq C\kappa r^{3/2}$,

where $\Pi_{\operatorname{Sym}^{k}(S^{\star})}$ denotes the orthogonal projection onto symmetric $k$ tensors restricted to $S^{\star}$.

for all symmetric $k$ tensors $T$ with $\norm{T}_{F}^{2}=1$. Recall that $g_{n}(w)=\frac{Hw}{\sqrt{2\pi}}+r(w)$. Therefore by the binomial theorem,

where $\absolutevalue{\delta(w)}\lesssim\sum_{i=1}^{k}\norm{T\quantity((Hw)^{\otimes k-i})}_{F}\norm{\Pi^{\star}r(w)}^{i}$. Therefore by Young’s inequality,

Let $\hat{T}$ be the symmetric $k$ tensor defined by $\hat{T}(v_{1},\ldots,v_{k})=T(Hv_{1},\ldots,Hv_{k})$. Then by Corollary 13,

Because we assumed $n\geq Cd^{2}r\kappa^{2}\iota^{p+1}$ and $d\geq C\kappa r^{3/2}$ for a sufficiently large constant $C$, we have

Assume $n\geq Cd^{2}r\kappa^{2}\iota^{p+1}$ and $d\geq C\kappa r^{3/2}$ for a sufficiently large constant $C$. Then for any $k\leq p$ and any symmetric $k$ tensor $T$ supported on $S^{\star}$, there exists $z_{T}(w)$ such that

Assume $n\geq Cd^{2}r\kappa^{2}\iota^{p+1}$ and $d\geq C\kappa r^{3/2}$ for a sufficiently large constant $C$. Let $\eta_{1}=\sqrt{\frac{d}{C^{2}\iota^{3}}}$, let $k\leq p$ and let $T$ be a $k$ tensor. Then with high probability, there exists $h_{T}(a,w,b)$ such that if

where $v_{k}(a,b)$ and $z_{T}(w)$ are constructed in Corollary 3 and Corollary 4 respectively. Recall that $w^{(1)}=2\eta_{1}ag_{n}(w)$. Then for $x\in\{x_{1},\ldots,x_{n}\}$,

where the second to last line followed from Lemma 8. The first part of the lemma now follows from a union bound over $x_{1},\ldots,x_{n}$. For the bounds on $h$, we have

Assume $n\geq Cd^{2}r\kappa^{2}\iota^{p+1}$ and $d\geq C\kappa r^{3/2}$ for a sufficiently large constant $C$ and let $\eta_{1}=\sqrt{\frac{d}{\iota^{3}}}$. Then with high probability, there exists $h(a,w,b)$ such that if

with $\norm{T_{k}}_{F}\lesssim r^{\frac{p-k}{4}}$. Let

Then $\frac{1}{n}\sum_{i=1}^{n}(f_{h}(x_{i})-f^{\star}(x_{i}))^{2}\lesssim\frac{1}{n}$ is immediate from Lemma 12 and

Let $a^{\star}_{j}:=\frac{1}{m}h(a_{j},w_{j},b_{j})$ where $h$ is the function constructed in Corollary 5. Then,

Then with probability $1-\operatorname{poly}(n,m,d)e^{-\iota}$ we have that $Z_{j}(x_{i})=\overline{Z}_{j}(x_{i})$ for $i=1,\ldots,n$. Therefore,

For the first term, by Bernstein’s inequality we have with probability at least $1-2e^{-\iota}$,

and the first part of the lemma follows from a union bound.

We will now turn to the bound on $\norm{a^{\star}}^{2}$. Let $z_{i}=(a^{\star}_{i})^{2}+(a^{\star}_{m-i})^{2}$. Note that $\{z_{i}\}_{i\leq m/2}$ are positive, i.i.d., and bounded by $O(m^{-2}r^{2p}\kappa^{4p}\iota^{12p})$. In addition, they have expectation $O(m^{-2}r^{p}\kappa^{2p}\iota^{3p})$. Therefore by Popoviciu’s inequality they have variance bounded by

Therefore by Bernstein’s inequality we have that with high probability,

A.3 Proof of Theorem 1

to be the empirical $L^{2}$ losses with respect to the true labels (recall $y_{i}=f^{\star}(x_{i})+\epsilon_{i}$, $\epsilon_{i}\sim\{-\varsigma,\varsigma\}$).

Assume $n\geq Cd^{2}r\kappa^{2}\iota^{p+1}$ and $d\geq C\kappa r^{3/2}$ for a sufficiently large constant $C$ and let $\eta_{1}=\sqrt{\frac{d}{\iota^{3}}}$. Let $a^{\star}$ be the vector constructed in the proof of Lemma 13 and let $\theta=(a^{\star},W^{(1)},b^{(1)})$. Then with high probability,

Let $\delta_{i}=f_{\theta}(x_{i})-f^{\star}(x_{i})$. Then,

First, by Hoeffding’s inequality, we have with high probability,

We are now ready to directly prove Theorem 1.

Note that we can assume that there is an absolute constant $C$ such that $n\geq Cd^{2}r\kappa^{2}\iota^{p+1}$, and $m\geq r^{p}\kappa^{2p}\iota^{6p+1}$. Otherwise, we can simply take $\lambda\to\infty$ and return the zero predictor.
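As a minimal sketch of why taking $\lambda\to\infty$ returns the zero predictor, consider ridge regression on a single fixed feature. The one-dimensional setup, the data values, and the helper name `ridge_1d` are illustrative assumptions and do not come from the paper. The closed-form solution shrinks monotonically towards zero as the penalty grows, which is also why each norm budget corresponds to some value of the penalty.

```python
def ridge_1d(xs, ys, lam):
    """Closed-form minimiser of sum_i (a*x_i - y_i)^2 + lam * a^2 over a scalar a."""
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / (sxx + lam)


# Hypothetical toy data, chosen only for illustration.
xs = [1.0, 2.0, 3.0]
ys = [2.1, 3.9, 6.2]
lams = [0.0, 1.0, 10.0, 1e3, 1e12]
coefs = [ridge_1d(xs, ys, lam) for lam in lams]

for lam, a in zip(lams, coefs):
    # The coefficient (and hence the predictor a*x) shrinks towards 0 as lam grows.
    print(f"lambda={lam:g}  a={a:.6g}")
```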

From Lemma 14 we know that with high probability, there exists $a^{\star}$ such that if $\theta=(a^{\star},W^{(1)},b^{(1)})$,

and $\norm{a^{\star}}_{2}^{2}\lesssim\frac{r^{p}\kappa^{2p}\iota^{6p+1}}{m}$. Therefore by the equivalence of norm-constrained linear regression and ridge regression, there exists $\lambda>0$ such that if

Then with high probability, $f_{(a^{(T)},W^{(1)},b^{(1)})}\in\mathcal{F}$. In addition, from Lemma 28,

Appendix B Transfer Learning

The proof of Theorem 3 is virtually identical to that of Theorem 1. We can use Lemma 13 to construct $a^{\star}$ such that if $\theta^{\star}=(a^{\star},W^{(1)},b^{(1)})$ then with high probability,

In addition, there exists $\lambda$ such that if $T\geq\Theta(\eta^{-1}\lambda^{-1})$,

Now let $\mathcal{F}=\{f_{(a,W,b)}~:~\norm{a}_{2}\leq\norm{a^{\star}}_{2}\}$. Then by Lemma 27 we have with high probability,

Appendix C Concentration Lemmas

Let $X\sim\chi^{2}(d)$. Then, for any $t\geq 0$,

Let $w\sim N(0,I_{d})$. Then for some constant $C$,

Let $g$ be a polynomial of degree $p$. Then there exists an absolute constant $C_{p}$ depending only on $p$ such that for any $\delta$,

Therefore by Theorem 1.2 of , there exists an absolute constant $C_{p}$ such that

Note that the planes $w\cdot x_{1}=0,\ldots,w\cdot x_{n}=0$ divide the sphere $S^{d-1}$ into at most $\sum_{i=0}^{d}\binom{n}{i}\lesssim n^{d}$ convex regions. For each region there exists an $\epsilon$ net of size $\quantity(\frac{3}{\epsilon})^{d}$. Therefore we can take the union of these nets over all regions, which has size at most $\quantity(\frac{3n}{\epsilon})^{d}\leq e^{Cd\log(n/\epsilon)}$. ∎
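The counting in this argument can be checked numerically. The sketch below evaluates the region bound $\sum_{i=0}^{d}\binom{n}{i}$ and the logarithm of the size of the union of per-region nets. The function names and the values of $n$ and $d$ are hypothetical choices for illustration, and the script only checks the arithmetic of the bound, not the geometric claim itself.

```python
import math


def region_bound(n, d):
    """The bound sum_{i=0}^{d} C(n, i) on the number of regions cut out by n planes."""
    return sum(math.comb(n, i) for i in range(d + 1))


def log_union_net_size(n, d, eps):
    """log of (region bound) * (3/eps)^d, the size of the union of per-region nets."""
    return math.log(region_bound(n, d)) + d * math.log(3.0 / eps)


# Hypothetical sizes; eps = sqrt(d/n) matches the choice made later in the proof.
n, d = 1000, 10
eps = math.sqrt(d / n)

print(region_bound(n, d) <= n ** d)
print(log_union_net_size(n, d, eps), d * math.log(3 * n / eps))
```

The first printed value confirms that the binomial sum is below $n^{d}$ for these sizes; the second line compares the log-size of the union with $d\log(3n/\epsilon)$.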

Let $f(x)$ be a polynomial of degree $p$ and let $\sigma(x)\in\{x,\operatorname{ReLU}(x)\}$. Then there exists an absolute constant $C_{p}$ depending only on $p$ such that for any $\iota>0$, with probability at least $1-2ne^{-\iota}$, we have

Let $Z_{i}(w):=g(x_{i})(u\cdot x_{i})\sigma^{\prime}(w\cdot x_{i})\mathbf{1}_{\{\absolutevalue{g(x_{i})}<R\}}$ so that

Then note that for fixed $w$, $Z_{i}(w)$ is $R$-sub-Gaussian, so for each $u\in\mathcal{N}_{1/4}$, with probability $1-2e^{-z}$ we have

so by a union bound we have with probability $1-2e^{Cd\log(n/\epsilon)}e^{-z}$,

so setting $z=Cd\log(n/\epsilon)+\iota$ we have with probability $1-2e^{-\iota}$,

Using $\epsilon=\sqrt{\frac{d}{n}}$ and putting everything together gives with probability $1-2ne^{-\iota}$,

Let $\epsilon_{i}\sim\{-\varsigma,\varsigma\}$. Then with high probability,

Next, note that for fixed $u,w$, $\epsilon_{i}(u\cdot x_{i})\sigma^{\prime}(w\cdot x_{i})$ is $\varsigma^{2}$ sub-Gaussian so for any $\iota>0$, with probability $1-2e^{-\iota}$,

By a union bound, with probability at least $1-2e^{-\iota}$,

Appendix D CSQ Lower Bound

The proof is a modified version of the proof in Szörényi. Let $\langle\cdot,\cdot\rangle_{\mathcal{D}}$ denote the $L^{2}$ inner product with respect to $\mathcal{D}$. We will show that there are at least two functions $f,g\in\mathcal{F}$ such that for each query $h_{k}$, $\absolutevalue{\langle f,h_{k}\rangle_{\mathcal{D}}}\leq\tau$ and $\absolutevalue{\langle g,h_{k}\rangle_{\mathcal{D}}}\leq\tau$. Therefore, we can simply respond to each query adversarially with $0$, and it is impossible for the learner to distinguish between $f$ and $g$. Note that failing to do so will result in a loss of $\norm{f-g}_{\mathcal{D}}^{2}\geq 2-2\epsilon$. Let the $k$th query be $h_{k}$ and let

Similarly, we have that $\absolutevalue{A_{k}^{-}}\leq\frac{1}{\tau^{2}-\epsilon}$ so the number of functions that are eliminated from the $k$th query is at most $\frac{2}{\tau^{2}-\epsilon}$. We can continue this process for at most $\frac{\absolutevalue{\mathcal{F}}(\tau^{2}-\epsilon)}{2}$ iterations. ∎

Let $v_{1},\ldots,v_{k}\sim S^{d-1}$. Then for every pair $i\neq j$, $v_{i}\cdot v_{j}$ is $O(d^{-1})$ subgaussian so for an absolute constant $c$, with probability $1-2e^{-2c\epsilon^{2}d}$, $\absolutevalue{v_{i}\cdot v_{j}}\leq\epsilon$. Therefore with probability $1-k^{2}e^{-2c\epsilon^{2}d}>0$ this holds for all $i\neq j$ so there must exist at least one collection of such points. ∎
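The probabilistic construction above is easy to simulate. The following sketch draws random unit vectors by normalizing Gaussian vectors and reports the largest pairwise inner product, which should be on the order of $d^{-1/2}$ up to a logarithmic factor. The dimension, the number of vectors, the seed, and the function names are hypothetical choices for illustration.

```python
import math
import random


def random_unit_vector(d, rng):
    """Uniform point on S^{d-1}, obtained by normalizing a standard Gaussian vector."""
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(x * x for x in g))
    return [x / norm for x in g]


def max_abs_inner_product(vectors):
    """Largest |v_i . v_j| over all pairs i != j."""
    best = 0.0
    for i in range(len(vectors)):
        for j in range(i + 1, len(vectors)):
            dot = sum(a * b for a, b in zip(vectors[i], vectors[j]))
            best = max(best, abs(dot))
    return best


rng = random.Random(0)  # fixed seed so the run is reproducible
d, k = 400, 30
vs = [random_unit_vector(d, rng) for _ in range(k)]

# Compare the worst pairwise overlap with the typical scale 1/sqrt(d).
print(max_abs_inner_product(vs), 1 / math.sqrt(d))
```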

Let $S$ be the set constructed in Lemma 3. Let

and note that for all $f\in\mathcal{F}$, $\norm{f}_{\mathcal{D}}=1$. Then for $v,w\in S$ and $v\neq w$,

Therefore, by Lemma 2 we have for any $\epsilon$,

In particular if we take $\epsilon=\sqrt{\frac{\log(4q(cd)^{k/2})}{cd}}$ we get

Appendix E Additional Technical Lemmas

For a $k$ tensor $T$, let $\operatorname{Sym}(T)$ denote the symmetrization of $T$ along all $k!$ permutations of indices.
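A small sketch of the symmetrization operator may help make the definition concrete. It assumes the usual averaging convention (dividing by $k!$), which the text above does not state explicitly, and it stores a tensor as a dictionary from index tuples to values; both choices, and the name `symmetrize`, are assumptions made for illustration.

```python
from fractions import Fraction
from itertools import permutations, product
from math import factorial


def symmetrize(T, k, d):
    """Average a k-tensor (dict: index tuple -> value) over all k! index permutations."""
    perms = list(permutations(range(k)))
    sym = {}
    for idx in product(range(d), repeat=k):
        # Sum the entries of T at every permutation of the index tuple.
        total = sum(Fraction(T.get(tuple(idx[p] for p in perm), 0)) for perm in perms)
        sym[idx] = total / factorial(k)
    return sym


# The 2-tensor e_0 (x) e_1 in dimension 2 has a single nonzero entry at (0, 1).
T = {(0, 1): 1}
S = symmetrize(T, k=2, d=2)
print(S)
```

Under this convention the single entry is split evenly between positions $(0,1)$ and $(1,0)$, and applying `symmetrize` a second time leaves the result unchanged.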

There exist $T_{0},\ldots,T_{p}$ such that

and $\norm{T_{k}}_{F}\lesssim r^{\frac{p-k}{4}}$ for $k\leq p$.

Note that from the Taylor series of $f^{\star}(x)$ we have

Note that by a simple counting argument, the number of permutations such that this product of indicators is nonzero is exactly $k!\prod_{j=1}^{d}\frac{c_{j}!}{(c_{j}/2)!}$, as one can first order the indices corresponding to each $c_{j}$, then split them into groups of two, then shuffle these groups of two. Therefore,

because $\sum_{j}c_{j}=2k$, which completes the proof. ∎

Let $\{h_{kl}\}$ and $\{h^{-1}_{kl}\}$ denote the change of basis matrices between Hermite polynomials and monomials, i.e.

Let $T$ be a symmetric $p$-tensor and let $w\sim N(0,I_{d})$. Then for $k\leq p$,

Let $T=\sum_{i}c_{i}v_{i}^{p}$ with $\|v_{i}\|=1$. Using the change of basis $x^{k}\to\sum_{l\leq k}h^{-1}_{kl}He_{l}(x)$,

Let $T$ be a symmetric $p$-tensor with $\dim(\operatorname{span}(T))=r$. For $k\leq p$,

The proof follows directly from Lemma 23 and the inequality $\|T(I^{\otimes l})\|_{F}^{2}=\|T(\Pi_{\operatorname{span}(T)}^{\otimes l})\|_{F}^{2}\leq\norm{T}_{F}^{2}\norm{\Pi_{\operatorname{span}(T)}^{\otimes l}}_{F}^{2}=r^{l}\norm{T}_{F}^{2}$ for $2l\leq k$. ∎
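The final equality in this chain uses only two standard facts, spelled out here as a short worked step. Writing $\Pi$ for $\Pi_{\operatorname{span}(T)}$, the Frobenius norm is multiplicative over tensor products, and an orthogonal projection of rank $r$ satisfies $\Pi^{T}\Pi=\Pi$ and $\operatorname{tr}(\Pi)=r$:

```latex
\norm{\Pi^{\otimes l}}_{F}^{2}
  = \quantity(\norm{\Pi}_{F}^{2})^{l}
  = \quantity(\operatorname{tr}(\Pi^{T}\Pi))^{l}
  = \quantity(\operatorname{tr}(\Pi))^{l}
  = r^{l}.
```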

Let $T$ be a symmetric $p$-tensor with $\dim(\operatorname{span}(T))=r$. With probability at least $1-2e^{-\iota}$,

Therefore by Lemma 17, with probability at least $1-2e^{-\iota}$, $F(w)\lesssim\norm{T}_{F}^{2}r^{\lfloor\frac{k}{2}\rfloor}\iota^{k}$ and taking square roots completes the proof. ∎

This follows immediately from Lemma 23 and $(k-2l)!\binom{k}{2l}^{2}\leq(p-2l)!\binom{p}{2l}^{2}$. ∎
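The combinatorial inequality holds because both $(k-2l)!$ and $\binom{k}{2l}$ are nondecreasing in $k$ for fixed $l$. As a sanity check, the sketch below searches for violations by brute force over all $k\leq p$ and $2l\leq k$; the value of $p$ and the helper name `weight` are arbitrary illustrative choices.

```python
from math import comb, factorial


def weight(k, l):
    """(k - 2l)! * C(k, 2l)^2, defined for 2l <= k."""
    return factorial(k - 2 * l) * comb(k, 2 * l) ** 2


p = 8  # hypothetical degree, chosen only for the check
violations = [
    (k, l)
    for k in range(p + 1)
    for l in range(k // 2 + 1)
    if weight(k, l) > weight(p, l)
]
print(violations)  # prints []
```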

Let $f$ be a polynomial of degree $p$. Then

E.2 Sphere Lemmas

This follows from the decomposition $w=\nu\overline{w}$ with $\nu\sim\chi(d)$, $\overline{w}\sim S^{d-1}$ independent. ∎
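The decomposition can be illustrated numerically: splitting a standard Gaussian vector into its radius and direction gives a unit vector and a radius whose square averages to $d$, as expected for $\nu^{2}\sim\chi^{2}(d)$. This sketch only checks these two marginal facts, not independence, and the dimension, sample count, seed, and function names are hypothetical.

```python
import math
import random


def gaussian_vector(d, rng):
    """Sample w ~ N(0, I_d)."""
    return [rng.gauss(0.0, 1.0) for _ in range(d)]


def polar_decompose(w):
    """Split w into its radius nu = ||w|| and its direction w_bar = w / ||w||."""
    nu = math.sqrt(sum(x * x for x in w))
    return nu, [x / nu for x in w]


rng = random.Random(0)  # fixed seed so the run is reproducible
d, trials = 10, 4000
radii_sq = []
for _ in range(trials):
    nu, w_bar = polar_decompose(gaussian_vector(d, rng))
    radii_sq.append(nu * nu)

mean_radius_sq = sum(radii_sq) / trials
print(mean_radius_sq)  # should be close to d, since nu^2 ~ chi^2(d)
```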

Let $T$ be a symmetric $p$-tensor with $\dim(\operatorname{span}(T))=r$. With probability at least $1-2e^{-\iota}$,

Let $\overline{w}\sim S^{d-1}$. For $k\leq p$,

E.3 Rademacher Complexity Bounds

Let $f=a^{T}\sigma(W^{\star}x+b)$ be a two layer neural network. For fixed $W,b$, let

Let $\theta=(a,W,b)$ and let $f=a^{T}\sigma(Wx+b)$ be a two layer neural network. Let