Hidden Progress in Deep Learning: SGD Learns Parities Near the Computational Limit

Boaz Barak, Benjamin L. Edelman, Surbhi Goel, Sham Kakade, Eran Malach, Cyril Zhang

Introduction

In deep learning, performance improvements are frequently observed upon simply scaling up resources (such as data, model size, and training time). While these improvements are often continuous in terms of these resources, some of the most surprising recent advances in the field have been emergent capabilities: at a certain threshold, behavior changes qualitatively and discontinuously. Through a statistical lens, it is well-understood that larger models, trained with more data, can fit more complex and expressive functions. However, far less is known about the analogous computational question: how does the scaling of these resources influence the success of gradient-based optimization?

These phase transitions cannot be explained via statistical capacity alone: they can appear even when the amount of data remains fixed, with only model size or training time increasing. A timely example is the emergence of reasoning and few-shot learning capabilities when scaling up language models (Radford et al., 2019, Brown et al., 2020, Chowdhery et al., 2022, Hoffmann et al., 2022); Srivastava et al. (2022) identify various tasks which language models are only able to solve if they are larger than a critical scale. Power et al. (2022) give examples of discontinuous improvements in population accuracy (“grokking”) when running time increases, while dataset and model sizes remain fixed.

In this work, we analyze the computational aspects of scaling in deep learning, in an elementary synthetic setting which already exhibits discontinuous improvements. Specifically, we consider the supervised learning problem of learning a sparse parity: the label is the parity (XOR) of $k\ll n$ bits in a random length-$n$ binary string. This problem is computationally difficult for a range of algorithms, including gradient-based (Kearns, 1998) and streaming (Kol et al., 2017) algorithms. We focus on analyzing the resource measure of training time, and demonstrate that the loss curves for sparse parities display a phase transition across a variety of architectures and hyperparameters (see Figure 1, left). Strikingly, we observe that SGD finds the sparse subset (and hence reaches 0 error) with a variety of activation functions and initialization schemes, even with no over-parameterization.
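For concreteness, here is a minimal sketch of a sampler for this distribution, assuming the $\pm 1$ encoding used throughout (so that the XOR of bits becomes a product of coordinates). The function name `sample_sparse_parity_batch` and its arguments are illustrative assumptions, not taken from the paper's code.

```python
import random


def sample_sparse_parity_batch(n, support, batch_size, rng):
    """Draw i.i.d. examples (x, y) with x uniform on {-1, +1}^n.

    The label is the parity of the bits indexed by `support`; in the
    +-1 encoding, XOR corresponds to the product of the coordinates.
    """
    batch = []
    for _ in range(batch_size):
        x = [rng.choice((-1, 1)) for _ in range(n)]
        y = 1
        for i in support:
            y *= x[i]
        batch.append((x, y))
    return batch


# Example: a (30, 3)-sparse parity with hidden support {3, 7, 19}.
rng = random.Random(0)
batch = sample_sparse_parity_batch(30, (3, 7, 19), 8, rng)
```

Each label depends on only $k$ of the $n$ coordinates, yet (for $k\geq 2$) it is uncorrelated with every individual coordinate, which is what makes the hidden support hard to find.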

A natural hypothesis to explain SGD’s success in learning parities, with no visible progress in error and loss for most of training, would be that it simply “stumbles in the dark”, performing random search for the unknown target (e.g., via stochastic gradient Langevin dynamics). If that were the case, we might expect to observe a convergence time of $2^{\Omega(n)}$, like a naive search over parameters or subsets of indices. However, Figure 1 (right) already provides some evidence against this “random search” hypothesis: the convergence time adapts to the sparsity parameter $k$, with a scaling of $n^{O(k)}$ on small instances. Notably, such a convergence rate implies that SGD comes close to achieving the optimal computation time among a natural class of algorithms (namely, statistical query algorithms).
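To get a sense of the gap between these two scalings on the small instances considered here, one can compare the number of candidate supports $\binom{n}{k}\leq n^{k}$ with the $2^{n}$ scale of a naive search. This is plain arithmetic on hypothesis counts, not a model of SGD's runtime; the helper name is hypothetical.

```python
from math import comb


def search_space_sizes(n, k):
    """Return (number of k-subsets of [n], 2**n)."""
    return comb(n, k), 2 ** n


for n in (10, 20, 30):
    for k in (2, 3, 4):
        subsets, full = search_space_sizes(n, k)
        print(f"n={n:2d} k={k}: C(n,k)={subsets:>6d}  2^n={full:>10d}")
```

At $n=30$, even $k=4$ gives fewer than $3\times 10^{4}$ candidate supports, versus roughly $10^{9}$ for $2^{n}$.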

Through an extensive empirical analysis of the scaling behavior of a variety of models, as well as theoretical analysis, we give strong evidence against the “stumbling in the dark” viewpoint. Instead, there is a hidden progress measure under which SGD is steadily improving. Furthermore, and perhaps surprisingly, we show that SGD achieves a computational runtime much closer to the optimal SQ lower bound than simply doing (non-sparse) parameter search. More generally, our investigations reveal a number of notable phenomena regarding the dependence of SGD’s performance on resources: we identify phase transitions when varying data, model size, and training time.

It is known from SQ lower bounds that with a constant noise level, gradient descent on any architecture requires at least $n^{\Omega(k)}$ computational steps to learn $k$-sparse $n$-dimensional parities (for background, see Appendix A). We first show a wide variety of positive empirical results, in which neural networks successfully solve the parity problem in a number of iterations which scales near this computational limit:

For all small instances ($n\leq 30$, $k\leq 4$) of the sparse parity problem, architectures $\mathcal{A}$ in {2-layer MLPs, Transformers (with a smaller range of hyperparameters), sinusoidal/oscillating neurons, PolyNets (a non-standard architecture introduced in this work; see Section 3 for the definition)}, initializations in {uniform, Gaussian, Bernoulli}, and batch sizes $1\leq B\leq 1024$, SGD on $\mathcal{A}$ solves the $(n,k)$-sparse parity problem (w.p. $\geq 0.2$) within at most $c\cdot n^{\alpha k}$ steps, for small constants $c,\alpha$.

Our empirical results suggest that, in a number of computational steps matching the SQ limit, SGD is able to solve the parity problem and identify the influential coordinates, without an explicit sparse prior. We give a theoretical analysis which validates this claim.

We also present a stronger analysis for an idealized architecture (which we call the disjoint-PolyNet), which allows for any batch size, and captures the phase transitions observed in the error curves.

On disjoint-PolyNets, SGD (with any batch size $B\geq 1$) converges with high probability to a solution with at most $\epsilon$ error on the $(n,k)$-parity problem in at most $n^{O(k)}\cdot\log(1/\epsilon)$ iterations. Continuous-time gradient flow exhibits a phase transition: it spends a $1-o(1)$ fraction of its time before convergence with error $\geq 49\%$.

Our theoretical and empirical results hold in non-overparameterized regimes (including with a width-1 sinusoidal neuron), in which no fixed kernel, including the neural tangent kernel (NTK) (Jacot et al., 2018), is sufficiently expressive to fit all sparse parities with a large margin. Thus, our findings comprise an elementary example of combinatorial feature learning: SGD can only successfully converge by learning a low-width sparse representation.

Building upon our core positive results, we provide a wide variety of preliminary experiments, showing sparse parity learning to be a versatile testbed for understanding the challenges and surprises in solving combinatorial problems with neural networks. These include quantities which reveal the continual hidden progress behind uninformative training curves (as predicted by the theory), experiments at small sample sizes which exhibit grokking (Power et al., 2022), as well as an example where greedy layer-wise learning is impossible but end-to-end SGD can learn the layers jointly.

Related work

We present the most directly related work on feature learning, and learning parities with neural nets. A broader discussion can be found in Appendix A.3.

Theoretical analysis of gradient descent on neural networks is notoriously hard, due to the non-convex nature of the optimization problem. That said, it has been established that in some settings, the dynamics of GD keep the weights close to their initialization, thus behaving like convex optimization over the Neural Tangent Kernel (see, for example, (Jacot et al., 2018, Allen-Zhu et al., 2019, Du et al., 2018)). In contrast, it has been shown that in various tasks, moving away from the fixed features of the NTK is essential for the success of neural networks trained with GD (for example (Yehudai and Shamir, 2019, Allen-Zhu and Li, 2019, Wei et al., 2019) and the review in (Malach et al., 2021)). These results demonstrate that feature learning is an important part of the GD optimization process. Our work also focuses on a setting where feature learning is essential for the target task. In our theoretical analysis, we show that the initial population gradient encodes the relevant features for the problem. The importance of the first gradient step for feature learning has been recently studied in (Ba et al., 2022).

The problem of learning parities using neural networks has been investigated in prior works from various perspectives. It has been demonstrated that parities are hard for gradient-based algorithms, using arguments similar to those in the SQ analysis (Shalev-Shwartz et al., 2017, Abbe and Sandon, 2020). One possible approach for overcoming the computational hardness is to make favorable assumptions on the input distribution. Indeed, recent works show that under various assumptions on the input distribution, neural networks can be efficiently trained to learn parities (XORs) (Daniely and Malach, 2020, Shi et al., 2021, Frei et al., 2022, Malach et al., 2021). In contrast to these results, this work intentionally focuses on a hard benchmark task, without assuming that the distribution has some favorable (namely, non-uniform) structure. This setting allows us to probe the performance of deep learning at a known computational limit. Notably, the work of Andoni et al. (2014) provides analysis for learning polynomials (and in particular, parities) under the uniform distribution. However, their main results require a network of size $n^{O(k)}$ (i.e., an extremely overparameterized network), and they provide only partial theoretical and empirical evidence for the success of smaller networks. Studying a related subject, some works have shown that neural networks display a spectral bias, learning to fit low-frequency coefficients before high-frequency ones (Rahaman et al., 2019, Cao et al., 2019).

Preliminaries

We provide an expanded discussion of background and related work in Appendix A.

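A key fact about parities is that they are orthogonal under the correlation inner product: writing $\chi_{S}(x):=\prod_{i\in S}x_{i}$, for $S'\subseteq[n]$ we have $\mathbb{E}_{x\sim\mathrm{Unif}(\{\pm 1\}^{n})}[\chi_{S}(x)\chi_{S'}(x)]=1$ if $S'=S$, and $0$ otherwise.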
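The orthogonality relation $\mathbb{E}_{x}[\chi_{S}(x)\chi_{S'}(x)]=\mathbb{1}[S=S']$, with $\chi_{S}(x)=\prod_{i\in S}x_{i}$ and $x$ uniform on $\{\pm 1\}^{n}$, can be checked exhaustively for small $n$. The sketch below does this with exact rational arithmetic; the helper names `chi` and `correlation` are illustrative.

```python
import itertools
from fractions import Fraction


def chi(x, subset):
    """Parity (character) chi_S(x) = prod_{i in S} x_i for x in {-1, +1}^n."""
    out = 1
    for i in subset:
        out *= x[i]
    return out


def correlation(n, s, s_prime):
    """Exact E_x[chi_S(x) chi_S'(x)] for x uniform on {-1, +1}^n."""
    points = list(itertools.product((-1, 1), repeat=n))
    total = sum(chi(x, s) * chi(x, s_prime) for x in points)
    return Fraction(total, len(points))


n = 4
subsets = [c for size in range(n + 1) for c in itertools.combinations(range(n), size)]
# Every pair of distinct subsets is uncorrelated; each parity has unit norm.
assert all(correlation(n, s, t) == (1 if s == t else 0) for s in subsets for t in subsets)
```

The product $\chi_{S}\chi_{S'}$ equals $\chi_{S\triangle S'}$, and any non-empty parity averages to zero over the hypercube, which is why the check succeeds.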

Our main results are presented in the online learning setting, with a stream of i.i.d. batches of examples. At each iteration $t=1,\ldots,T$, a learning algorithm $\mathcal{A}$ receives a batch of $B$ examples $\{(x_{t,i},y_{t,i})\}_{i=1}^{B}$ drawn i.i.d. from $\mathcal{D}_{S}$, then outputs a classifier $\widehat{y}_{t}:\{\pm 1\}^{n}\rightarrow\{\pm 1\}$. We say that $\mathcal{A}$ solves the parity task in $t$ steps (with error $\epsilon$) if

Empirical findings

The central phenomenon of study in this work is the empirical observation that neural networks, with standard initialization and training, can solve the $(n,k)$-parity problem in a number of iterations scaling as $n^{O(k)}$ on small instances. We observed robust positive results for randomly-initialized SGD on the following architectures, indexed by Roman numerals:

2-layer MLPs: ReLU ($\sigma(z)=(z)_{+}$) or polynomial ($\sigma(z)=z^{k}$) activation, in a wide variety of width regimes $r\geq k$. Settings (i), (ii), (iii) (resp. (iv), (v), (vi)) use $r\in\{10,100,1000\}$ ReLU (resp. polynomial) activations. We also consider $r=k$ (exceptional settings (*i), (*ii)), the minimum width for representing a $k$-wise parity for both activations.

1-neuron networks: Next, we consider non-standard activation functions $\sigma$ which allow a one-neuron architecture $f(x;w)=\sigma(w^{\top}x)$ to realize $k$-wise parities. The constructions stem from letting $w^{*}=\sum_{i\in S}e_{i}$, and constructing $\sigma(\cdot)$ to interpolate (the appropriate scaling of) $\frac{k-{w^{*}}^{\top}x}{2}\text{ mod }2$ with a piecewise linear $k$-zigzag activation (vii), or a degree-$k$ polynomial (viii). Going a step further, a single $\infty$-zigzag (ix) or sinusoidal (x) neuron can represent all $k$-wise parities. In settings (xi), (xii), (xiii), (xiv), we remove the second trainable layer (setting $u=1$). We find that wider architectures with these activations also train successfully.
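As a concrete check of the sinusoidal construction: ${w^{*}}^{\top}x=k-2m$, where $m$ is the number of relevant coordinates equal to $-1$, so $\frac{k-{w^{*}}^{\top}x}{2}=m$ and the parity equals $(-1)^{m}=\cos(\pi m)$. The sketch below uses the specific choice $\sigma(z)=\cos(\pi(k-z)/2)$ for a fixed $k$; this parameterization is our illustrative assumption (not necessarily the one used in the experiments), and, as in the $u=1$ settings, it has no second layer.

```python
import itertools
import math


def sinusoid_neuron(x, w, k):
    """One neuron sigma(w^T x) with the assumed activation sigma(z) = cos(pi * (k - z) / 2)."""
    z = sum(wi * xi for wi, xi in zip(w, x))
    return math.cos(math.pi * (k - z) / 2)


n, support = 6, (0, 2, 5)
k = len(support)
w_star = [1 if i in support else 0 for i in range(n)]  # w* = sum_{i in S} e_i

# Exhaustive check over {-1, +1}^6: the neuron reproduces the 3-sparse parity.
for x in itertools.product((-1, 1), repeat=n):
    parity = x[0] * x[2] * x[5]
    assert abs(sinusoid_neuron(x, w_star, k) - parity) < 1e-9
```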

Transformers: There is growing interest in using parity as a benchmark for combinatorial function learning, long-range dependency learning, and length generalization in Transformers (Lu et al., 2021, Edelman et al., 2021, Hahn, 2020, Anil et al., 2022, Liu et al., 2022). Motivated by these recent theoretical and empirical works, we consider a simplified specialization of the Transformer architecture to this sequence classification problem. This is the less-robust setting (*iii); the architecture and optimizer are described in Appendix D.1.3.

PolyNets: Our final setting (xv) is the PolyNet, a slightly modified version of the parity machine architecture. Parity machines have been studied extensively in the statistical mechanics of ML literature (see the related work section) as well as in a line of work on ‘neural cryptography’ (Rosen-Zvi et al., 2002). A parity machine outputs the sign of the product of $k$ linear functions of the input. A PolyNet simply outputs the product itself. Both architectures can clearly realize $k$-sparse parities. The PolyNet architecture was originally motivated by the search for an idealized setting where an end-to-end optimization trajectory analysis is tractable (see Section 4.1); we found in these experiments that this architecture trains very stably and sample-efficiently.
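A minimal sketch of the two output rules, assuming each linear function is a bias-free inner product $\langle w_{i},x\rangle$ (an assumption; the exact parameterization used in the experiments may differ). With one-hot weights $w_{i}=e_{s_{i}}$ for $S=\{s_{1},\ldots,s_{k}\}$, both realize the $k$-sparse parity.

```python
import itertools
import math


def polynet(x, weights):
    """PolyNet output: the product of k linear functions <w_i, x>."""
    return math.prod(sum(wj * xj for wj, xj in zip(w, x)) for w in weights)


def parity_machine(x, weights):
    """Parity machine output: the sign of the same product."""
    out = polynet(x, weights)
    return 1 if out > 0 else -1  # sign; mapping 0 to -1 is an arbitrary choice


n, support = 5, (1, 3, 4)
one_hot = [[1 if j == s else 0 for j in range(n)] for s in support]  # w_i = e_{s_i}
for x in itertools.product((-1, 1), repeat=n):
    parity = x[1] * x[3] * x[4]
    assert polynet(x, one_hot) == parity == parity_machine(x, one_hot)
```

The PolyNet's real-valued output keeps the product differentiable in the weights, whereas the sign in the parity machine has zero gradient almost everywhere.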

All of the networks listed above were observed to successfully learn sparse parities in a variety of settings. We summarize our findings as follows: for all combinations of $n\in\{10,20,30\}$, $k\in\{2,3,4\}$, batch sizes $B\in\{1,2,4,\ldots,1024\}$, initializations in {uniform, Gaussian, Bernoulli}, loss functions in {hinge, square, cross entropy}, and architecture configurations $\{\text{(i)},\text{(ii)},\ldots,\text{(xv)}\}$, SGD solved the parity problem (with $100\%$ accuracy, validated on a batch of $2^{13}$ samples) in at least $20\%$ of $25$ random trials, for at least one choice of learning rate $\eta\in\{0.001,0.01,0.1,1\}$. The models converged in $t_{c}\leq c\cdot n^{\alpha k}\leq 10^{5}$ steps, for small architecture-dependent constants $c,\alpha$ (see Appendix C). Figure 1 (left) shows some representative training curves.

Settings (*i) and (*ii), where the MLP just barely represents a $k$-sparse parity, and the Transformer setting (*iii), are less robust to small batch sizes. In these settings, the same positive results as above only held for sufficiently large batch sizes: $B\geq 16$. Also, setting (*iii) used the Adam optimizer (which is standard for Transformers); see Appendix D.1.3 for details.

For almost all of the architectures, we find that the training curves exhibit phase transitions in terms of running time (and thus, in the online learning setting, dataset size as well): long durations of seemingly no progress, followed by periods of rapid decrease in the validation error. Strikingly, for architectures (v) and (vi), this plateau is absent: the error in the initial phase appears to decrease with a linear slope. See Appendix C.8 for more plots.

Random search or hidden progress?

The remainder of this paper seeks to answer the question: “By what mechanism does deep learning solve these emblematic computationally-hard optimization problems?”

A natural hypothesis would be that SGD somehow implicitly performs Monte Carlo random search, “bouncing around” the loss landscape in the absence of a useful gradient signal. Upon closer inspection, several empirical observations clash with this hypothesis:

Scaling of convergence times: Without an explicit sparsity prior in the architecture or initialization, it is unclear how to account for the runtimes observed in experiments, which adapt to the sparsity $k$. The initializations, which certainly do not prefer sparse functions, are close to the correct solutions with probability $2^{-\Omega(n)}\ll n^{-k}$. (Indeed, under all standard architectures and initializations, the probability that a random network is $\Omega(1)$-correlated with a sparse parity would be $2^{-\Omega(n)}$, since with all but that probability, a $1-o(1)$ fraction of the total influence would be accounted for by the $n-k$ irrelevant features.)

Sensitivity to initialization, not SGD samples: Running these training setups over multiple stochastic batches from a common initialization, we find that loss curves and convergence times are highly correlated with the architecture’s random initialization, and are quite concentrated conditioned on initialization; see Figure 2 (center).

Elbows in the scaling curves: For larger $n$, the power-law scaling ceases to hold: the exponent worsens (see Figure 2 (right), as well as the discussion in Appendix C.2). This would not be true for random exhaustive search.

Even these observations, which do not probe the internal state of the algorithm, suggest that exhaustive search is an insufficient picture of the training dynamics, and a different mechanism is at play.

Theoretical analyses

We now provide a theoretical account for the success of SGD in solving the $(n,k)$-parity problem. Our main theoretical observation is that, in many cases, the population gradient of the weights at initialization contains enough “information” for solving the parity problem. That is, given an accurate enough estimate of the initial gradient (e.g., by computing the gradient over a large enough batch size), the relevant subset $S$ can be found.
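The following toy computation illustrates this observation in a case small enough to compute exactly. Its setting is our assumption and not that of the formal MLP result: one bias-free ReLU neuron, all-ones weights (one particular sign vector), the correlation objective $\mathbb{E}[y\,\sigma(w^{\top}x)]$, $n=7$ and $k=2$. Gradient coordinate $j$ equals $\mathbb{E}[y\,x_{j}\,\mathbb{1}[\sum_{i}x_{i}>0]]$, a Fourier coefficient of a shifted majority: of degree $k-1$ for $j\in S$ and degree $k+1$ for $j\notin S$. Enumerating the hypercube gives $5/32$ on the relevant coordinates and $-1/32$ elsewhere, so the relevant indices stand out in magnitude. (With this all-ones toy and odd $k\geq 3$, both coefficients have even degree and vanish, which is why we use $k=2$.)

```python
import itertools
from fractions import Fraction


def relu_neuron_population_gradient(n, support, w):
    """Exact gradient of E[y * relu(w . x)] w.r.t. w, for x uniform on {-1, +1}^n
    and y = prod_{i in support} x_i."""
    grad = [Fraction(0)] * n
    points = list(itertools.product((-1, 1), repeat=n))
    for x in points:
        y = 1
        for i in support:
            y *= x[i]
        pre = sum(wi * xi for wi, xi in zip(w, x))
        # ReLU derivative is 1[pre > 0]; with odd n and all-ones w, pre is never 0.
        if pre > 0:
            for j in range(n):
                grad[j] += y * x[j]
    return [g / len(points) for g in grad]


g = relu_neuron_population_gradient(7, (0, 1), [1] * 7)
```

The values are half the degree-1 and degree-3 Fourier coefficients of the 7-bit majority ($5/16$ and $-1/16$), since $\mathbb{1}[\sum_i x_i>0]=(1+\mathrm{Maj}(x))/2$ for odd $n$.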

Carefully extending this insight, we obtain an end-to-end convergence result for ReLU-activation MLP networks with a particular symmetric choice of $\pm 1$ initialization, trained with the hinge loss:

This result does not capture the full range of settings in which we empirically observe successful convergence. First, it requires a sign vector initialization, while we observe convergence with other random initialization schemes (namely, uniform and Gaussian). Second, it requires the batch size to scale with $n^{\Omega(k)}$ (in fact, at this batch size, the correct parity indices emerge in a single SGD step), while we also obtain positive results when $B$ is small (even $B=1$). Analogous statements for these cases (as well as other activations and losses) would require Fourier gaps for population gradient functions other than majority; lower bounds on the degree-$(k-1)$ coefficients (“Fourier anti-concentration”) are particularly elusive in the literature, and we leave it as an open challenge to establish them in more general settings. We provide preliminary empirics in Appendix C.1, suggesting that the Fourier gaps in our empirical settings are sufficiently large. Interestingly, we observe that the Fourier gap tends to increase over the course of training; this is not captured by our current theoretical analysis.

We note that in the low-width (non-overparameterized) regimes considered in this work, no fixed kernel (including the neural tangent kernel (Jacot et al., 2018), whose dimensionality is the network’s parameter count) can solve the sparse parity problem. The following is a consequence of results in (Kamath et al., 2020, Malach and Shalev-Shwartz, 2022):

Thus, our low-width results lie outside the NTK regime, which requires far larger models (size $n^{\Omega(k)}$) to express parities. However, we note that better sample complexity bounds are possible in the NTK regime, with an algorithm more similar to standard SGD (see (Telgarsky, 2022) and Appendix A.3).

Disjoint-PolyNet: exact trajectory analysis for an idealized architecture

In this section, we present an architecture (a version of PolyNets (xv)) which empirically exhibits similar behavior to MLPs and bypasses the difficulty of analyzing Fourier gaps. The disjoint-PolyNet takes a product over $k$ linear functions of an equal-sized partition $P_{1},\ldots,P_{k}$ of the input coordinates (we assume for simplicity that $n$ is divisible by $k$): $f(x;w_{1:k}):=\prod_{i=1}^{k}\langle w_{i},x_{P_{i}}\rangle$. As noted in Section 1.2, this is equivalent to a tree parity machine, with real-valued (instead of $\pm 1$) outputs.

This architecture also requires us to assume that the set $S$ of size $k$ in the $(n,k)$-parity problem is selected such that exactly one index belongs to each part of the partition, that is, $|S\cap P_{i}|=1$ for all $i\in[k]$. We refer to this problem as the $(n,k)$-disjoint parity problem. Note that there are still $(n')^{k}=(n/k)^{k}$ different possibilities for the set $S$ under this restriction, where $n'=n/k$ is the size of each part. For fixed $k$, these represent a constant fraction of the $\binom{n}{k}\approx(ne/k)^{k}$ (by Stirling’s approximation) possibilities for $S$ in the general non-disjoint case.
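The “constant fraction” claim can be made explicit: for fixed $k$, $\binom{n}{k}\sim n^{k}/k!$, so $(n/k)^{k}/\binom{n}{k}\to k!/k^{k}$ as $n\to\infty$. The sketch below evaluates this ratio numerically; the helper name `disjoint_fraction` is hypothetical.

```python
from math import comb, factorial


def disjoint_fraction(n, k):
    """Fraction of k-subsets of [n] that hit each of k equal-sized parts exactly once."""
    if n % k != 0:
        raise ValueError("n must be divisible by k")
    return (n // k) ** k / comb(n, k)


for n in (30, 300, 3000):
    print(n, round(disjoint_fraction(n, 3), 4), "limit:", round(factorial(3) / 3 ** 3, 4))
```

For $k=3$ the limit is $6/27\approx 0.22$, so roughly a fifth of all 3-sparse supports are disjoint in this sense.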

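Consider training a disjoint-PolyNet w.r.t. the correlation loss $L(w_{1:k}):=-\mathbb{E}_{(x,y)\sim\mathcal{D}_{S}}[y\,f(x;w_{1:k})]$. Without loss of generality, assume that each relevant coordinate in $S$ is the first element of $P_{i}$, and write $w_{i,j}$ for the weight on the $j$-th coordinate of $P_{i}$. Then, the population gradient is non-zero only at the coordinates in $S$: $\frac{\partial L}{\partial w_{i,j}}=-\mathbb{1}[j=1]\prod_{l\neq i}w_{l,1}$.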
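The support of this gradient can be checked by brute force on a small instance. The sketch below assumes $n=6$, $k=3$, parts $P_{i}=\{2i,2i+1\}$ (0-indexed), relevant coordinates $S=\{0,2,4\}$ (the first element of each part), and arbitrary fixed rational weights. It computes the exact gradient of the correlation $\mathbb{E}[y\,f(x;w_{1:k})]$ (the negative of the correlation loss) by enumerating $\{\pm 1\}^{6}$. By independence across parts, coordinate $(i,j)$ should equal $\mathbb{1}[j=0]\prod_{l\neq i}w_{l,0}$ in this 0-indexed notation.

```python
import itertools
from fractions import Fraction
from math import prod

K, PART = 3, 2  # k = 3 parts of size n' = 2, so n = 6
W = [[Fraction(1, 2), Fraction(-1, 3)],
     [Fraction(2, 5), Fraction(3, 7)],
     [Fraction(-3, 4), Fraction(1, 6)]]


def blocks(x):
    """Split x into the disjoint parts P_1, ..., P_k."""
    return [x[i * PART:(i + 1) * PART] for i in range(K)]


def correlation_gradient(w):
    """Exact gradient of E[y * prod_i <w_i, x_{P_i}>] w.r.t. every w[i][j]."""
    grad = [[Fraction(0)] * PART for _ in range(K)]
    points = list(itertools.product((-1, 1), repeat=K * PART))
    for x in points:
        parts = blocks(x)
        y = prod(p[0] for p in parts)  # relevant coordinate = first element of each part
        lin = [sum(wj * xj for wj, xj in zip(w[i], parts[i])) for i in range(K)]
        for i in range(K):
            rest = prod(lin[l] for l in range(K) if l != i)  # d f / d w[i][j] = x_{P_i, j} * rest
            for j in range(PART):
                grad[i][j] += y * parts[i][j] * rest
    return [[g / len(points) for g in row] for row in grad]


grad = correlation_gradient(W)
for i in range(K):
    predicted = prod(W[l][0] for l in range(K) if l != i)
    assert grad[i] == [predicted, 0]
```

Only the relevant weight in each part receives a gradient, and its magnitude is the product of the other parts' relevant weights, which is the multiplicative coupling behind the slow start and sudden takeoff.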

Suppose $k\geq 3$. Let $T(\epsilon)$ denote the smallest time at which the error is at most $\epsilon$. Then,

Informally, the network takes much longer to reach slightly-better-than-trivial accuracy than it takes to go from slightly-better-than-trivial to perfect accuracy. Returning to discrete time, we also analyze the trajectory of disjoint-PolyNets trained with online SGD at any batch size, confirming that a neural network can learn $k$-sparse disjoint parities within $n^{O(k)}$ iterations.

Extended versions of these theorems, along with proofs, can be found in Appendix B.3.

Hidden progress: discussion and additional experiments

So far, we have shown that sparse parity learning provides an idealized setting in which neural networks successfully learn sparse combinatorial features, with a mechanism of continual progress hiding behind discontinuous training curves. In this section, we outline preliminary explorations on a broader range of interesting phenomena which arise in this setting. Details are provided in Appendix C, while more systematic investigations are deferred to future work.

The theoretical and (black-box) empirical results suggest that SGD does not learn parities via the memoryless process of random exhaustive search. This suggests the existence of progress measures: scalar quantities which are functions of the training algorithm’s state (i.e., the model weights $w_{t}$) and are predictive of the time to successful convergence. We provide some white-box investigations which further support the hypothesis of hidden progress, by examining the gradual improvement in quantities other than the training loss. In Appendix C.1, we directly plot the Fourier gaps of the population gradient as a function of $t$, finding that they are large (within a small constant factor of those of majority) in practice. In Figure 3 and Appendix C.3, we examine the weight movement norm $\rho(w_{0:t}):=\lVert w_{t}-w_{0}\rVert_{\infty}$ to reveal hidden progress, motivated by the fact that $w_{t}-w_{0}$ is a linearized estimate for the initial population gradient.
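For flattened parameter vectors, this progress measure is a one-line computation; the helper below is a hypothetical illustration (not code from the paper), written for plain Python lists.

```python
def weight_movement_norm(w_t, w_0):
    """rho(w_{0:t}) = ||w_t - w_0||_inf for two flattened parameter vectors."""
    if len(w_t) != len(w_0):
        raise ValueError("parameter vectors must have the same length")
    return max(abs(a - b) for a, b in zip(w_t, w_0))
```

For plain SGD without weight decay, $w_{1}-w_{0}=-\eta\,\widehat{g}_{0}$ after one step, so $\rho$ then equals $\eta$ times the largest-magnitude coordinate of the first minibatch gradient $\widehat{g}_{0}$.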

An interesting consequence of our analysis is that it illuminates scaling behaviors with respect to a third fundamental resource parameter: model size, which we study in terms of network width $r$. If SGD operated by a “random search” mechanism, one would expect width to provide a parallel speedup. Instead, we find that SGD sequentially amplifies progress. The sharp lower tails in Figure 2 (left) imply that running $r$ identical copies of SGD does not give $(1/r)\times$ speedups; more directly, in Appendix C.4 (previewed in Figure 4 (left)), we find that convergence times for sparse parities empirically plateau at large model sizes.
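For comparison, the parallel speedup expected under the random search hypothesis can be computed exactly in an idealized model, which is our assumption: each of $r$ independent searchers succeeds at each step with probability $p$. The first success over all copies is then geometric with parameter $1-(1-p)^{r}$, with mean $1/(1-(1-p)^{r})\approx 1/(rp)$ when $rp\ll 1$.

```python
def expected_time_parallel_search(p, r):
    """Mean number of steps until the first of r independent searchers succeeds,
    when each succeeds at every step independently with probability p."""
    return 1.0 / (1.0 - (1.0 - p) ** r)


p = 1e-4
for r in (1, 10, 100):
    speedup = expected_time_parallel_search(p, 1) / expected_time_parallel_search(p, r)
    print(f"r={r:3d}: speedup {speedup:.2f}")
```

Under this model the speedup is nearly linear in $r$, which is inconsistent with the observed plateau in convergence times at large widths.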

Our main results are presented in the online learning setting (fresh minibatches from $\mathcal{D}_{S}$ at each iteration). While this mitigates the confounding factor of overfitting, it couples the resources of training time and independent samples in a suboptimal way, due to the computational-statistical gap for parity learning. In Appendix C.5, we find empirically that minibatch SGD (with weight decay) can learn sparse parities even with smaller sample sizes $m\ll n^{k}$. We reliably observe the grokking phenomenon (Power et al., 2022): an initial overfitting phase, then a delayed phase transition in the generalization error; see the two center panels of Figure 4 (right). These results complement and corroborate the findings of Nanda and Lieberum (2022), who analyze the hidden progress of Transformers trained on arithmetic tasks (a setting which also exhibits grokking).

It is a significant challenge (and generally outside the scope of this work) to understand the interactions between network depth and computational/statistical efficiency. In Appendix C.7, we show that learning parities with deeper polynomial-activation MLPs comprises a simple counterexample to the “deep only works if shallow is good” principle of Malach and Shalev-Shwartz (2019): a deep network can get near-perfect accuracy, even when greedy layer-wise training (e.g. (Belilovsky et al., 2019)) cannot beat trivial performance. By providing positive theory and empirics which elude these simplified explanations of SGD, we hope to point the way to a more complete understanding of learning dynamics in the challenging cases where no apparent progress is made for extended periods of time.

Conclusion

This work puts forward sparse parity learning as an elementary test case to explore the puzzling features of the role of computational (as opposed to statistical) resources in deep learning. In particular, we have shown that a variety of neural architectures solve this combinatorial search problem, with a number of computational steps nearly matching the sparsity-dependent SQ lower bound. Furthermore, we have shown that despite abrupt phase transitions in the loss and accuracy curves, SGD works by gradually amplifying the sparse features “under the hood”.

Even in this simple setting, there are several open experimental and theoretical questions. This work largely focuses on the online learning case, which couples training iterations with fresh i.i.d. samples. We believe it would be instructive to investigate parity learning when the three resources of samples, time, and model size are scaled separately. Some very preliminary findings along these lines are presented in Section 3. It is an open problem to extend our theoretical results to the small-batch setting, as well as to the full range of architectures and losses in our experiments. Resolving these questions would require a better understanding of the anti-concentration behavior of Boolean Fourier coefficients, which is much less studied than the analogous concentration phenomena.

Another important follow-up direction is understanding the extent to which these insights extend from parity learning to more complex (including real-world) combinatorial problem settings, as well as the extent to which non-synthetic tasks (in, e.g., natural language processing and program synthesis) embed within them parity-like subtasks of exhaustive combinatorial search. We hope that our results will lead to further progress towards understanding and improving the optimization dynamics behind the recent slew of dramatic empirical successes of deep learning in these types of domains.

This work seeks to contribute to the foundational understanding of computational scaling behaviors in deep learning. Our theoretical and empirical analyses are in a heavily-idealized synthetic problem setting. Hence, we see no direct societal impacts of the results in this study.

We would like to thank Lenka Zdeborová for providing us with references to the statistical physics literature on phase transitions in the learning curves of neural networks, and Matus Telgarsky for bringing to our attention the better sample complexity guarantees of 2-sparse parity learning in the NTK regime. Sham Kakade acknowledges funding from the Office of Naval Research under award N00014-22-1-2377.

References

Appendix

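This work leverages the parity problem as a “computationally hard case” for identifying the features $S$ which are relevant to the label. Observe that for any $S'\subseteq[n]$, it holds that $\mathbb{E}_{(x,y)\sim\mathcal{D}_{S}}[y\,\chi_{S'}(x)]=\mathbb{1}[S'=S]$, where $\chi_{S'}(x):=\prod_{i\in S'}x_{i}$.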

That is, a learner who guesses indices $S'$ cannot use correlations as feedback to reveal which (or how many) indices in $S'$ are correct, unless $S'$ is exactly the correct subset. In this sense, for the $(n,k)$-parity problem, the $\binom{n}{k}-1$ wrong answers are indistinguishable from each other. Thus, the structure of this problem forces this form of learner (but not necessarily all learning algorithms) to perform exhaustive search over subsets.

It should be mentioned that the $(n,k)$-parity problem can be solved efficiently by a learning algorithm that has access to individual examples (i.e., an algorithm that does not operate in the SQ framework). Specifically, this problem can be solved by Gaussian elimination. Moreover, it has been shown that (stochastic) gradient descent, discussed in the next section, can also be utilized for solving parities, given accurate enough estimates of the gradient and a very particular choice of neural network architecture (Abbe and Sandon, 2020). That said, when the accuracy of the gradients is not sufficient, GD suffers from the same SQ lower bound mentioned above (i.e., GD is essentially an SQ algorithm (Abbe et al., 2021)).
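To illustrate why access to individual examples changes the picture, here is a minimal sketch of the Gaussian elimination approach (a standard technique, not code from this paper; the helper names are hypothetical). Mapping $x_{i}\mapsto b_{i}=(1-x_{i})/2\in\{0,1\}$ and $y\mapsto(1-y)/2$ turns each noiseless example into a linear equation $\sum_{i\in S}b_{i}\equiv(1-y)/2 \pmod 2$ over $\mathrm{GF}(2)$. With somewhat more than $n$ random examples, the system has a unique solution with high probability, namely the indicator vector of $S$.

```python
import random


def sample_parity_examples(n, support, m, rng):
    """Draw m noiseless examples x ~ Unif({-1,+1}^n), y = prod_{i in support} x_i."""
    xs, ys = [], []
    for _ in range(m):
        x = [rng.choice((-1, 1)) for _ in range(n)]
        y = 1
        for i in support:
            y *= x[i]
        xs.append(x)
        ys.append(y)
    return xs, ys


def solve_parity_gf2(xs, ys):
    """Recover the support S via Gauss-Jordan elimination over GF(2).

    Each example becomes one equation: (bitmask of {i : x_i = -1}) . a = [y == -1] (mod 2),
    where a is the unknown indicator vector of S. Returns None if the examples do not
    determine a unique solution.
    """
    n = len(xs[0])
    rows = []
    for x, y in zip(xs, ys):
        mask = sum(1 << i for i, xi in enumerate(x) if xi == -1)
        rows.append([mask, 1 if y == -1 else 0])
    for col in range(n):
        pivot = next((j for j in range(col, len(rows)) if (rows[j][0] >> col) & 1), None)
        if pivot is None:
            return None  # rank < n: more examples are needed
        rows[col], rows[pivot] = rows[pivot], rows[col]
        for j in range(len(rows)):
            if j != col and (rows[j][0] >> col) & 1:
                rows[j][0] ^= rows[col][0]
                rows[j][1] ^= rows[col][1]
    # After full reduction, row c reads a_c = rows[c][1].
    return {c for c in range(n) if rows[c][1] == 1}


rng = random.Random(0)
n, support = 20, {2, 5, 11}
xs, ys = sample_parity_examples(n, support, n + 40, rng)
recovered = solve_parity_gf2(xs, ys)
```

This relies on exact labels: a single flipped label corrupts the linear system, consistent with the hardness of noisy parities discussed next.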

Learning sparse noisy parities, even at a very small noise level (i.e., $o(1)$ or $n^{-\delta}$), is believed to be computationally hard. This was first explicitly conjectured by Alekhnovich (2003), and has been the basis for several cryptographic schemes (e.g., (Ishai et al., 2008, Applebaum et al., 2009, 2010)). For noiseless sparse parities, Kol et al. (2017) show time-space hardness in the setting where $k=\omega(1)$. We present some experiments with noisy parities in Appendix C.6, finding that our empirical results (and theoretical analysis) are robust to $\Theta(1)$ noise.

A.2 Neural networks and standard training

where $\sigma(\cdot)$ is applied entrywise. It is standard to use GD to jointly update the network’s parameters. Our results include positive results about “single neurons”: MLPs with width $r=1$. We note that for our theoretical analysis, when training MLPs with GD, we allow for different learning rates and weight decay schedules for the different layers.

Finally, we will analyze randomized learning algorithms, such as GD with random initialization $\theta_{0}$, whose iterates $\theta_{t}$ (and thus classifiers $\widehat{y}$) are random variables even when the samples are not. A learning algorithm has permutation symmetry if, for all sequences of data $\{(x_{t,i},y_{t,i})\}$, the classifiers $\widehat{y}_{t}\circ\pi$ resulting from feeding $\{(\pi(x_{t,i}),y_{t,i})\}$ to the learner have identical distributions as $\pi$ ranges over all permutations of indices. The neural architectures and initializations (and thus, SGD) considered in this work are permutation-symmetric; for this reason, it is convenient for notation to choose $S=[k]$ as the canonical $(n,k)$-parity learning problem, without loss of generality.

A.3 Additional related work

A line of recent work has focused on understanding the feature learning ability of gradient descent dynamics on neural networks. These analyses go beyond the Neural Tangent Kernel (NTK) regime, showing a separation between learning with fixed features and GD on neural networks for the problems they study. Several of these works assume structure (often "sparse") in the input data which is useful for the prediction task and helps avoid computational hardness. In contrast, our work focuses on studying hard problems at their computational limit. Here we discuss the most relevant works in detail:

A line of work (Diakonikolas et al. (2020), Yehudai and Shamir (2020), Frei et al. (2020)) focuses on learning a single non-linearity $y=\sigma(w^{\top}x)$ (typically $\sigma(\cdot)$ is the ReLU or sigmoid) using gradient-based methods. These works obtain polynomial-time convergence guarantees when the distribution satisfies a spread condition. These results do not extend to the Boolean hypercube.

Daniely and Malach (2020) also study the problem of learning sparse parities using neural networks. One key difference from our work is that they assume a modified version of the problem, where the input distribution is not uniform over the hypercube but instead leaks information about the label. In particular, the distribution ensures that the relevant parity bits always have the same value. Shi et al. (2021) generalize this setting by considering labels generated from certain class-specific patterns, with the data itself generated from these patterns plus some extra background patterns. This also embeds information about the label in the data itself. In our setting, by contrast, the labels are uncorrelated with every individual input feature. Under these structural assumptions, the papers study how GD on a two-layer network can learn useful features in polynomial time. Both of these analyses also exploit the first gradient step to find useful features. Shi et al. (2021) additionally require a second step to refine the features.

Ba et al. (2022) show how the first gradient step is important for feature learning. In particular, they show that the first update is essentially rank-1 and aligns with the linear component of the underlying function. The functions we consider (parities) do not have a linear component.

Abbe et al. (2022) define a notion of initial alignment between the network at initialization and the target function. They show that this alignment is essential for polynomial-time learnability with noisy gradients on a fully connected network. Our MLP results also exploit the correlation between the gradient and the label to ensure that the gradient update carries signal.

Frei et al. (2022) also study the learnability of a parity-like function with $k=2$ under noisy labels. The paper analyzes early-stopped GD for learning the underlying labeling function. Our setup is quite different from theirs and can handle $k>2$.

In concurrent work, Damian et al. (2022) consider the problem of learning polynomials which depend on a few relevant directions, using gradient descent on a two-layer network. They assume that the distributional Hessian of the ground-truth function spans exactly the subspace of relevant directions. Using this, they show that gradient descent can learn the relevant subspace with sample complexity scaling as $d^{2}$ rather than $d^{p}$, where $p$ is the degree of the underlying polynomial. This holds as long as the number of relevant directions is much less than $d$. Their proof technique is similar to that of our two-layer MLP result, which also exploits correlation in the first gradient step. However, in our setting the distributional Hessian has rank 0 and does not satisfy their assumptions.

An extensive body of work originating in the statistical physics community has studied phase transitions in the learning curves of neural networks (Gardner and Derrida, 1989, Watkin et al., 1993, Engel and Van den Broeck, 2001). These works typically focus on student-teacher learning in the “thermodynamic limit”, in which the number of training examples is $\alpha$ times the input dimension and both are taken to infinity. One of the classic toy architectures analyzed in this literature is the parity machine (Mitchison and Durbin, 1989, Hansel et al., 1992, Opper, 1994, Kabashima, 1994, Simonetti and Caticha, 1996). In our work, we introduce PolyNets, a variant of parity machines in which the output is real-valued rather than $\pm 1$. We theoretically analyze disjoint-PolyNets, which are the real-output analogue of the oft-considered parity machines with tree architecture. Much of the statistical mechanics of ML literature focuses on an idealized training limit in which the weights reach a Gibbs distribution equilibrium. However, a strand of this literature aims to characterize the trajectory of SGD training in the high-dimensional limit with constant-sized sets of ordinary differential equations (Saad and Solla, 1995b, a, Goldt et al., 2019). These papers discuss cases where the network gets stuck in (and then escapes from) a plateau of suboptimal generalization error, including problems that share aspects with 2-sparse parities (Refinetti et al., 2021). Recently, Arous et al. (2021) studied (for rank-one parameter estimation problems) the relative amount of time spent by SGD in an initial high-error “search” phase versus a final “descent” phase, which is reminiscent of the framing of Theorem 6.
However, to our knowledge, prior work has not shown that $k$-sparse parities can be learned with a number of iterations that nearly matches known lower bounds, nor has it specifically studied phase transitions in $k$-sparse parity learning during gradient descent.

Another relevant line of work studies learning the parity problem using the neural tangent kernel (NTK) (Jacot et al., 2018). Namely, in some settings, when the network’s weights stay close to their initialization throughout training, SGD converges to a solution given by a linear function over the initial features of the NTK. As shown in Theorem 5, learning parities over a fixed set of features requires the size of the model to be $\Omega(n^{k})$. In contrast, the model size (number of hidden neurons) considered in this paper does not depend at all on the input dimension $n$. Nevertheless, the NTK analysis does give better sample complexity guarantees than the ones presented in this work, with a somewhat more natural version of SGD. For example, Ji and Telgarsky (2019) demonstrate learning 2-sparse parities using NTK analysis with a sample complexity of $O(n^{2})$, which matches the sample complexity lower bound for learning this problem with the NTK (see Wei et al. (2019)). Concurrent work by Telgarsky (2022) shows that this sample complexity can be improved to $O(n)$ once the optimization leaves the NTK regime. However, this analysis is given for networks of size $O(n^{n})$, much larger than the networks considered in this paper. We refer the reader to Table 1 in Telgarsky (2022) for a complete comparison of the sample complexity, run-time and model-size bounds achieved by different works studying 2-sparse parities.

Appendix B Proofs

For some even number $r$, consider a ReLU MLP of size $r$:

For all $1\leq i\leq r/2$, randomly initialize

Fix some $k$, and assume that $n\geq 2k^{2}$. Then, for every $S\subseteq[n]$ s.t. $|S|=k$ it holds that:

for some universal constants $c_{1},c_{2}$. More precisely,

Therefore, by the previous lemma we get that for every even $k$ the following holds:

Fix some $k$ and assume that $n\geq 4k$. Then, Majority has a $\gamma_{k}$-Fourier gap at $S$ of size $k$ with $\gamma_{k}=0.03(n-1)^{-\frac{k-1}{2}}$.

First we establish a simple relationship between $|\xi_{k-1}|$ and $|\xi_{k+1}|$.

Here, the first equation follows from Lemma 1, and the second follows by simple algebra using the identity $\binom{m}{r}=\frac{m-r+1}{r}\binom{m}{r-1}$.
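The display relating $|\xi_{k-1}|$ and $|\xi_{k+1}|$ is not reproduced here. As a hedged illustration, assume that $\xi_j$ denotes the degree-$j$ Fourier coefficient of $\mathrm{Maj}_n$ for odd $n$ and odd $j$. The paper's normalization may differ, for example for even $n$. Under that assumption, a short calculation with the standard closed form for these coefficients gives $|\xi_{j+2}|=\frac{j}{n-j-1}|\xi_j|$, that is, $|\xi_{k+1}|=\frac{k-1}{n-k}|\xi_{k-1}|$. This ratio is our own derivation, not the paper's statement. The Python sketch below checks both the closed form and the ratio against brute-force enumeration for small $n$. The function names are hypothetical.

```python
from itertools import product
from math import comb, isclose


def maj(x):
    """Majority vote of a +/-1 vector of odd length."""
    return 1 if sum(x) > 0 else -1


def maj_fourier_coeff(n, j):
    """Brute-force E[Maj_n(x) * x_0 * ... * x_{j-1}] over x uniform on {-1,+1}^n.

    By symmetry of Maj_n this equals the coefficient at every j-element subset.
    """
    total = 0
    for x in product((-1, 1), repeat=n):
        chi = 1
        for i in range(j):
            chi *= x[i]
        total += maj(x) * chi
    return total / 2 ** n


def maj_fourier_closed_form(n, j):
    """Standard closed form of the degree-j coefficient of Maj_n (odd n, odd j)."""
    m, h = (n - 1) // 2, (j - 1) // 2
    return (-1) ** h * comb(m, h) / comb(n - 1, j - 1) * 2 * comb(n - 1, m) / 2 ** n


# Ratio check: |xi_{j+2}| = j / (n - j - 1) * |xi_j| for odd j with j + 2 <= n.
n = 9
for j in range(1, n - 1, 2):
    lhs = abs(maj_fourier_coeff(n, j + 2))
    rhs = abs(maj_fourier_coeff(n, j)) * j / (n - j - 1)
    assert isclose(lhs, rhs, rel_tol=1e-9)
```

For $n>2k-1$ the ratio $\frac{k-1}{n-k}$ is below 1, so the degree-$(k-1)$ coefficient dominates the degree-$(k+1)$ one. This is the kind of separation a Fourier gap requires.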

Here, the first equality follows from the display above and the second from Lemma 1. The third inequality follows from standard approximations of the binomial coefficients. The last inequality follows from $n-2k+1\geq(n-1)/2$ (by the assumption on $n$) and $\frac{\sqrt{2\pi}\,k^{k/2}}{(k-1)e^{k+2}}>0.03$ (by standard calculus). This gives us the desired result.
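The constant $0.03$ can be sanity-checked numerically. This is a check over a finite range, not a proof. The Python sketch below evaluates $\frac{\sqrt{2\pi}\,k^{k/2}}{(k-1)e^{k+2}}$ in log space for $2\leq k\leq 200$. Over this range the minimum occurs at $k=5$, where the value is about $0.0319$. For larger $k$ the expression grows, because $k^{k/2}$ dominates $e^{k}$. The function name is hypothetical.

```python
import math


def gap_constant(k):
    """sqrt(2*pi) * k**(k/2) / ((k - 1) * e**(k + 2)), evaluated in log space."""
    log_value = (0.5 * math.log(2 * math.pi) + (k / 2) * math.log(k)
                 - math.log(k - 1) - (k + 2))
    return math.exp(log_value)


values = {k: gap_constant(k) for k in range(2, 201)}
k_min = min(values, key=values.get)
print(k_min, round(values[k_min], 4))  # prints: 5 0.0319
```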

Assume that $k$ is even and that $n\geq 2(k+1)^{2}$. Then, the following hold:

Let $\tau>0$ be some tolerance parameter, fix $\epsilon\in(0,1)$ and let $\eta=\frac{1}{k|\xi_{k-1}|}$. Assume that $k$ is an even number. Fix some $w_{1},\dots,w_{k}\in\{\pm 1\}^{n}$, $b_{1},\dots,b_{r}\in(-1,1)$ and $u_{1},\dots,u_{k}\in\{\pm 1\}$. Let $\widehat{w}_{i}=-\eta\widehat{g}_{i}$ and $\widehat{b}_{i}=b_{i}-\eta\widehat{\gamma}_{i}$ s.t. $\|\widehat{g}_{i}-g_{i}\|_{\infty}\leq\tau$ and $\|\widehat{\gamma}_{i}-\gamma_{i}\|\leq\tau$. Assume the following holds:

Additionally, for all $i$ and all $x$ it holds that $|\sigma(\widehat{w}_{i}\cdot x+\widehat{b}_{i})|\leq n+1$.

Claim 1: For all $i$ and for all $j\in[k]$ it holds that $\left|\widehat{w}_{i,j}-\frac{1}{2k}\right|\leq\frac{\tau}{|\xi_{k-1}|}$.

Proof: First, observe that by the assumption it holds that for all $i,j\in[k]$,

Claim 2: For all $i$ and for all $j>k$ it holds that $|\widehat{w}_{i,j}|\leq\frac{|\xi_{k+1}|+2\tau}{2k|\xi_{k-1}|}$.

Claim 3: For all $i$ it holds that $|\widehat{b}_{i}-b_{i}|\leq\frac{\tau}{k|\xi_{k-1}|}$.

Claim 4: Fix $\delta>0$. Let $h_{i}$ be the function $h_{i}(x)=\sigma\left(\frac{1}{2k}\sum_{j=1}^{k}x_{j}+b_{i}\right)$ and $\widehat{h}_{i}$ the function $\widehat{h}_{i}(x)=\sigma(\widehat{w}_{i}\cdot x+\widehat{b}_{i})$. Then, if $\tau\leq\frac{\delta}{2}\frac{k|\xi_{k-1}|}{\sqrt{2n\log(2k/\epsilon)}}$ and $n\geq\frac{32c_{2}^{2}\log(2k/\epsilon)}{c_{1}^{2}\delta^{2}}$, the following holds:

where in the last inequality we use Eq. (6). So, choosing $\tau\leq\frac{\delta}{2}\frac{k|\xi_{k-1}|}{\sqrt{2n\log(2k/\epsilon)}}$ and $n\geq\frac{32c_{2}^{2}\log(2k/\epsilon)}{c_{1}^{2}\delta^{2}}$ gives the required bound.

Claim 5: Let $h_{1},\dots,h_{k}$ be the functions defined in the previous claim. Then, there exist weights $u^{*}$ with $\|u^{*}\|_{\infty}\leq 8k$ s.t. for $f^{*}(x)=\sum_{i=1}^{k}u^{*}_{i}h_{i}(x)$ it holds that $f^{*}(x)=2\chi_{[k]}(x)$ for all $x\in\{\pm 1\}^{n}$.

Proof: For $i\leq k-2$ define $u^{*}_{i}=8k(-1)^{i+1}$, and let $u^{*}_{k-1}=6k$ and $u^{*}_{k}=-2k$.

Proof of Lemma 4: Choose $\widehat{u}=u^{*}$. Using Claim 4 and the union bound, w.p. $1-\epsilon$ over $x\sim\{\pm 1\}^{n}$ it holds that for all $i\in[k]$, $|h_{i}(x)-\widehat{h}_{i}(x)|\leq\delta$. Therefore, w.p. $\geq 1-\epsilon$

so, choosing $\delta=\frac{1}{8k^{2}}$, we get that, w.p. at least $1-\epsilon$ over the choice of $x$, it holds that $f(x)\chi_{[k]}(x)\geq 1$. Additionally, for every $x$ it holds that
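The choice $\delta=\frac{1}{8k^{2}}$ can be traced as follows. We assume, as in Claims 4 and 5, that $f(x)=\sum_{i=1}^{k}\widehat{u}_{i}\widehat{h}_{i}(x)$ with $\widehat{u}=u^{*}$. This is our reconstruction of the omitted step, not the paper's display:

```latex
\begin{aligned}
|f(x) - 2\chi_{[k]}(x)| &= \Big|\sum_{i=1}^{k} u^{*}_{i}\big(\widehat{h}_{i}(x) - h_{i}(x)\big)\Big|
  \le \sum_{i=1}^{k} |u^{*}_{i}|\,\big|\widehat{h}_{i}(x) - h_{i}(x)\big|
  \le k \cdot 8k \cdot \delta = 8k^{2}\delta = 1, \\
f(x)\chi_{[k]}(x) &= 2 + \big(f(x) - 2\chi_{[k]}(x)\big)\chi_{[k]}(x) \ge 2 - 1 = 1,
\end{aligned}
```

where the first line uses Claim 5 ($f^{*}=2\chi_{[k]}$) and $\|u^{*}\|_{\infty}\leq 8k$, and the second line uses $\chi_{[k]}(x)^{2}=1$.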

Assume we randomly initialize an MLP using the unbiased initialization defined previously. Consider the following update:

Additionally, it holds that $\|\sigma(W^{(1)}\cdot x+b^{(1)})\|_{\infty}\leq n+1$.

Claim: with probability at least $1-\delta$,

Proof: Fix some $i,j$ and note that by Hoeffding’s inequality,

Using the union bound, with probability at least $1-\delta$, there exists a set of $k$ neurons satisfying the conditions of Lemma 4, and therefore the required result follows from the lemma. ∎

We use the following well-known result on convergence of SGD (see for example Shalev-Shwartz and Ben-David (2014)):

Let $k$ be an even number. Assume we randomly initialize an MLP using the unbiased initialization defined previously. Fix $\epsilon\in(0,1/2)$ and let $T\geq\frac{2^{9}k^{3}rn^{2}}{\epsilon^{2}}$, $r\geq k\cdot 2^{k}\log(8k/\epsilon)$, $B\geq c_{1}^{-1}2^{8}k^{7/6}n\binom{n}{k-1}\log(128k^{3}n/\epsilon)\log(32nr/\epsilon)$, and $n\geq\frac{2^{11}k^{4}c_{2}^{2}\log(128k^{3}n/\epsilon)}{c_{1}^{2}}$. Choose the following learning rate and weight decay schedule:

At the first step, use $\eta_{0}=\frac{1}{k|\xi_{k-1}|}$ and $\lambda_{0}=1$ for all weights.

After the first step, use $\eta_{t}=0$ for the first-layer weights and biases and $\eta_{t}=\frac{4k^{1.5}}{n\sqrt{r(T-1)}}$ for the second-layer weights, with $\lambda_{t}=0$ for both layers.
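The following is a minimal Python sketch of this two-phase schedule. It assumes the caller supplies the value of $\xi_{k-1}$ (written `xi_km1`). The function name and signature are hypothetical, and step $t=0$ denotes the first step.

```python
import math


def two_phase_schedule(t, layer, k, n, r, T, xi_km1):
    """(learning_rate, weight_decay) for step t (t = 0 is the first step) and layer 1 or 2.

    xi_km1 is the value of xi_{k-1}; only its absolute value is used.
    """
    if t == 0:
        # First step: eta_0 = 1 / (k |xi_{k-1}|) and lambda_0 = 1 for all weights.
        return 1.0 / (k * abs(xi_km1)), 1.0
    if layer == 1:
        # Afterwards the first-layer weights and biases are frozen.
        return 0.0, 0.0
    # Second layer: eta_t = 4 k^1.5 / (n sqrt(r (T - 1))), no weight decay.
    return 4 * k ** 1.5 / (n * math.sqrt(r * (T - 1))), 0.0
```

Freezing the first layer after step 0 leaves a linear model over fixed features in the second-layer weights. This is presumably why the standard convex SGD guarantee quoted above applies from step 1 onward.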

Then, the following holds, with expectation over the randomness of the initialization and the sampling of the batches:

Now, we will show that w.h.p. there exists $u^{*}$ with good loss. Let $\epsilon^{\prime}=\frac{\epsilon}{64k^{2}n}$ and $\delta^{\prime}=\frac{\epsilon}{8}$. Denote $\tau=\frac{|\xi_{k-1}|}{16k\sqrt{2n\log(128k^{3}n/\epsilon)}}=\frac{|\xi_{k-1}|}{16k\sqrt{2n\log(2k/\epsilon^{\prime})}}$. Observe that $r\geq k\cdot 2^{k}\log(k/\delta^{\prime})$, and using the fact that $|\xi_{k-1}|\geq c_{1}(k-1)^{-1/3}\binom{n}{k-1}^{-1}$ we get

and additionally $n\geq\frac{2^{11}k^{4}c_{2}^{2}\log(2k/\epsilon^{\prime})}{c_{1}^{2}}$.

So, we can apply Theorem 8 with $M=8k\sqrt{k}$ and $\rho=2\sqrt{r}n$ and get that, w.p. $1-\epsilon/4$ over the initialization and the first step, it holds that

B.2 Recoverability of the parity indices from Fourier gaps

Given a network architecture where some neuron has a $\gamma$-Fourier gap with respect to the target subset $S$, we quantify how the indices in $S$ can be determined by observing an estimate of the population gradient, for a general activation function $\sigma$ and iterate $w_{t}$:

Then, for every $w$ such that $\sigma^{\prime}(w^{\top}x)$ has a $\gamma$-Fourier gap at $S$, the $k$ indices at which $g(w)$ has the largest absolute values are exactly the indices in $S$.

Let $h(x):=\sigma^{\prime}(w^{\top}x)$. We compute the population gradient, which we call $\bar{g}(w)$:

where the inequality in the final $i\notin S$ case is due to the Fourier gap property. Then, it holds that for all $i\in S$ we have $|g_{i}|>\gamma/2$ and for all $i\notin S$ we have $|g_{i}|<\gamma/2$. Thus, the largest entries of the estimate $g(w)$ occur at the indices in $S$, as claimed. ∎
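To make the recovery step concrete, the Python sketch below computes the population gradient exactly by enumerating all of $\{\pm 1\}^{n}$, as in the experiments of Appendix C.1. It does this for a single ReLU neuron with a sign-vector weight $w$, zero bias, and odd $n$, and then reads off the $k$ largest entries. We take $\sigma^{\prime}(z)=\mathbb{1}[z>0]$ and the correlation form $\bar{g}_i=\mathbb{E}[\chi_S(x)\,\sigma^{\prime}(w^{\top}x)\,x_i]$. The overall sign convention (loss gradient vs. correlation) does not affect the absolute values. With zero bias the relevant entries vanish for odd $k\geq 3$, which is consistent with the evenness assumption above. The function names are hypothetical.

```python
from itertools import product


def population_gradient(n, S, w, b=0.0):
    """Exact E[chi_S(x) * 1[w.x + b > 0] * x_i] for each i, enumerating {-1,+1}^n."""
    g = [0.0] * n
    for x in product((-1, 1), repeat=n):
        y = 1
        for i in S:
            y *= x[i]
        if sum(wi * xi for wi, xi in zip(w, x)) + b > 0:  # ReLU derivative is a threshold
            for i in range(n):
                g[i] += y * x[i]
    return [gi / 2 ** n for gi in g]


def top_k_indices(g, k):
    """Indices of the k entries of g with the largest absolute value."""
    return set(sorted(range(len(g)), key=lambda i: abs(g[i]), reverse=True)[:k])


# Usage: a 4-sparse parity on 11 bits, neuron initialized at the all-ones sign vector.
g = population_gradient(11, {1, 4, 6, 9}, [1] * 11)
print(top_k_indices(g, 4))
```

In this example, the relevant entries all share one value and the irrelevant entries all share another, smaller value, so the top-$k$ readout returns $S$.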

B.3 Global convergence for disjoint-PolyNets

In this section we will develop theory for disjoint-PolyNets trained with correlation loss, as illustrated in Figure 5. Section B.3.1 will consider optimization with gradient flow, and Section B.4 will consider optimization with SGD at any batch size $B\geq 1$.

In gradient flow, the relevant weights evolve according to the following differential equations:

Suppose a disjoint-PolyNet with $k>2$ is initialized such that $\prod_{i}v_{i}(0)>0$, and optimized with gradient flow. Let $\bar{v}_{a}:=\frac{1}{k}\sum_{i=1}^{k}(v_{i}(0))^{2}$, and $\bar{v}_{g}:=\left(\prod_{i=1}^{k}(v_{i}(0))^{2}\right)^{1/k}$. For any $b\geq 0$ and $i\in[k]$, let $T_{i}(b):=\arg\sup_{t\geq 0}(|v_{i}(t)|\leq b)$. Then

Let $T_{i}(\infty):=\arg\sup_{t\geq 0}(|v_{i}(t)|<\infty)$. Then

First, observe that the product of the relevant weights is non-decreasing during gradient flow.

Thus, $\prod_{i}v_{i}(0)>0$ implies that $\prod_{i}v_{i}(t)>0$ for all $t$.

In other words, the squares of the relevant weights each follow the same trajectory, shifted according to their initializations. Let $q(t):=(v_{i}(t))^{2}-(v_{i}(0))^{2}$, for any $i$. This quantity evolves as follows:

Since $q(t)$ is strictly increasing, its inverse $q^{-1}$ is well-defined, and we can use the inverse function theorem to characterize $q^{-1}$ for all $t\geq 0$:

We can upper- and lower-bound the integrand by applying Maclaurin’s inequality (see page 52 in Hardy et al. (1952)):

The amount of time it takes for qq to reach a value of cc is thus:

Hence, for any $b\geq 0$, for each $i$, $|v_{i}(t)|\leq b$ as long as

Meanwhile, the amount of time after $q^{-1}(c)$ it takes for $q$ to explode to infinity is

Substituting $c=b^{2}-v_{i}(0)^{2}$, we obtain that the amount of time it takes for $|v_{i}|$ to grow from $b$ to $\infty$ is
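The following is a numerical illustration of the finite-time blow-up. We assume, consistent with the SGD update in Section B.4, that the population correlation loss of a disjoint-PolyNet depends on the relevant weights only through $-\prod_i v_i$. Under that assumption, gradient flow reads $\dot v_i=\prod_{j\neq i}v_j$ and the irrelevant weights stay fixed. With $\pm 1$ initialization and $\prod_i v_i(0)>0$, every $|v_i|$ follows $\dot v=v^{k-1}$. Its solution $v(t)=(1-(k-2)t)^{-1/(k-2)}$ blows up at $t=1/(k-2)$. The Python sketch below integrates the ODE with forward Euler, a deliberately simple non-adaptive scheme, and reports when $\max_i|v_i|$ first exceeds a large cap. The function name and its parameters are hypothetical.

```python
def blowup_time(v0, dt=1e-4, cap=1e6, t_max=10.0):
    """Forward-Euler integration of dv_i/dt = prod_{j != i} v_j.

    Returns the first time max_i |v_i| exceeds cap (a proxy for the finite-time
    blow-up), or None if t_max is reached first.
    """
    v = [float(x) for x in v0]
    t = 0.0
    while t < t_max:
        if max(abs(x) for x in v) > cap:
            return t
        grads = []
        for i in range(len(v)):
            p = 1.0
            for j, vj in enumerate(v):
                if j != i:
                    p *= vj  # product of the other relevant weights
            grads.append(p)
        v = [vi + dt * gi for vi, gi in zip(v, grads)]
        t += dt
    return None


print(blowup_time([1.0, 1.0, 1.0]))  # close to 1/(k-2) = 1 for k = 3
print(blowup_time([1.0] * 4))        # close to 1/(k-2) = 0.5 for k = 4
```

With one negative weight, as in `[1.0, 1.0, -1.0]`, the product starts negative. It then increases toward 0 and no blow-up occurs, which mirrors the role of the condition $\prod_i v_i(0)>0$.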

We can upper- and lower-bound $\dot{q}$ by applying Maclaurin’s inequality (see page 52 in Hardy et al. (1952)):

When $k=2$, solving the LHS and RHS differential inequalities yields:

From the lower bound on $q(t)$, we can infer that the relevant weights all explode to infinity by the following time:

From the upper bound, we can infer that for any $c>0$, it is the case that $q(t)\leq c$ so long as

Hence, for each $i$, $|v_{i}(t)|\leq b$ for all

Now we analyze the relationship between the relevant weights and the accuracy of the disjoint-PolyNet.

Then define the error of $f$ with parameters $w_{1:k}$ as

Let $w$ be any setting of the weights of a disjoint-PolyNet such that $\prod_{i}v_{i}>0$. For ease of notation, let $u_{i}:=w_{i,2:n^{\prime}}$ be the irrelevant portion of $w_{i}$. There is a constant $c$ such that

where $\operatorname{erf}$ is the Gauss error function $\operatorname{erf}(y):=\frac{2}{\sqrt{\pi}}\int_{0}^{y}e^{-\tau^{2}}\,d\tau$.

Let $z_{i}=x_{(i-1)n^{\prime}+1}$ be the $i$th relevant coordinate of $x$, and let $z_{i}^{-}=x_{(i-1)n^{\prime}+2:in^{\prime}}$ be the irrelevant coordinates in $P_{i}$.

Then we have that the error of $f$ with parameters $w_{1:k}$ is

Line 10 follows because $u_{i}^{\top}z_{i}^{-}$ and $z_{i}$ are independent. In line 11 we use that $u_{i}^{\top}z_{i}^{-}$ is symmetric about 0. Finally, line 12 uses the assumption that $\prod_{i=1}^{k}v_{i}>0$.

We can bound $\Pr_{x}\left[\prod_{i=1}^{k}w_{i}^{\top}x_{P_{i}}=0\right]$ using Hoeffding’s inequality:

The indicator random variables $\mathbb{1}[u_{i}^{\top}z_{i}^{-}>|v_{i}|]$ are independent of each other, so the first term in line 12 can be characterized using the distribution of the parity of a sum of independent Bernoulli random variables. Let $X_{i}\sim\operatorname{Ber}(p_{i})$ for $i\in[k]$, and let $X=\sum_{i=1}^{k}X_{i}$. The generating function for $X$ is $f(z)=\prod_{i=1}^{k}((1-p_{i})+p_{i}z)$. The parity of $X$ then satisfies
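Evaluating the generating function at $z=\pm 1$ gives $f(1)=1$ and $f(-1)=\Pr[X\text{ even}]-\Pr[X\text{ odd}]=\prod_{i=1}^{k}(1-2p_i)$. This yields the standard identity $\Pr[X\text{ odd}]=\frac{1}{2}\left(1-\prod_{i=1}^{k}(1-2p_i)\right)$. The Python sketch below checks this identity against brute-force enumeration. The function names are hypothetical.

```python
from itertools import product


def prob_odd_formula(ps):
    """Pr[X odd] for X = sum of independent Ber(p_i), via (f(1) - f(-1)) / 2."""
    f_minus_one = 1.0
    for p in ps:
        f_minus_one *= 1 - 2 * p
    return (1 - f_minus_one) / 2


def prob_odd_bruteforce(ps):
    """Pr[X odd] by summing the probabilities of all outcomes with an odd number of ones."""
    total = 0.0
    for bits in product((0, 1), repeat=len(ps)):
        if sum(bits) % 2 == 1:
            weight = 1.0
            for b, p in zip(bits, ps):
                weight *= p if b else 1 - p
            total += weight
    return total
```

The product form shows that a single $p_i=1/2$ already makes the parity exactly uniform. More generally, the parity is close to uniform whenever the product $\prod_i(1-2p_i)$ is small.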

We first prove the upper bound on $\mathsf{err}$. Observe that

by Hoeffding’s inequality, and we are done.

Now we will prove the lower bound on $\mathsf{err}$. We have

We can bound this expression using the Berry-Esseen theorem (Berry, 1941, Esseen, 1942). Let $\beta\sim N(0,\|u_{i}\|_{2}^{2})$. Then, the Berry-Esseen theorem states that there is a constant $c$ (which can be taken to be $0.56$) such that for any $i\in[k]$,

Plugging this into equation 13, we obtain the lower bound on err\mathsf{err}. ∎

First, we’ll apply Lemma 6 and Lemma 7 to the situation where the $w_{i}$’s have $\pm 1$ initialization. This generalizes Theorem 6.

Suppose all the weights in a disjoint-PolyNet are initialized randomly in $\pm 1$ and $k\geq 3$.

Let $T(\alpha):=\arg\sup_{t\geq 0}(\mathsf{err}(w_{1:k}(t))\geq\alpha)$. Then, for $\gamma\in(0,1/2)$, if $\prod_{i}v_{i}(0)>0$ (which happens w.p. 1/2),

Thus, even for $\gamma$ arbitrarily close to 0, when the input is sufficiently long, the network spends almost all of training with error above $1/2-\gamma$.

By Lemma 6, $|v_{i}(t)|\leq b$ for all $i$ whenever

we obtain that $T(1/2-\gamma)=\frac{1}{k-2}\left(1-O\left((n^{\prime})^{1-k/2}\cdot\gamma^{2/k-1}\right)\right)$.

Also by Lemma 6, using the language of that lemma statement, for all $i$

Once all the relevant weights have exploded to infinity, the network will have zero error, so the result follows. ∎

Now let us apply Lemma 6 and Lemma 7 to the situation where the $w_{i}$’s have standard normal initialization. Again we find that with high probability, the phase of learning with near-trivial accuracy is much longer than the subsequent period until perfect accuracy, as illustrated by the left-hand plot in Figure 5. This culminates in the full theorem statement regarding phase transitions in the loss:

where $T(\cdot)$ is defined as in Corollary 10.

By a standard application of the generic Chernoff bound (Wainwright, 2019), for each $i\in[k],j\in[n^{\prime}-1]$, we have

This implies that w.p. $\geq 1-\epsilon/2$,

For each $i$, $\|u_{i}\|_{2}^{2}$ follows a chi-squared distribution with $n^{\prime}-1$ degrees of freedom. By Laurent and Massart (2000), for any $\tau\geq 0$,
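For reference, the bound of Laurent and Massart (2000) for $Z\sim\chi^{2}_{D}$ is usually stated as follows, with $D=n^{\prime}-1$ and $Z=\|u_{i}\|_{2}^{2}$ here. The paper's own display at this point is not reproduced, and its constants may be arranged differently:

```latex
\Pr\left[Z - D \ge 2\sqrt{D\tau} + 2\tau\right] \le e^{-\tau},
\qquad
\Pr\left[D - Z \ge 2\sqrt{D\tau}\right] \le e^{-\tau}.
```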

Combining Equation 16 and Equation 17, we obtain the desired statement.

B.4 Global convergence for SGD on disjoint-PolyNets

In this section we will analyze the training of a disjoint-PolyNet using SGD with online (i.i.d.) batches. We will show a convergence result for the 0-1 error of the learned classifier.

Assume we randomly initialize the disjoint-PolyNet with weights drawn uniformly from $\{\pm 1\}$. Fix $\epsilon\in(0,1/2)$ and run SGD at any batch size $B\geq 1$ for $T\geq 6\log(2nT/\delta)\log(2k/\epsilon)(3n^{\prime}-2)^{2k-1}$ iterations. There exists an adaptive learning rate schedule such that the following holds with probability 1/2 over the randomness of the initialization and $1-\delta$ over the sampling of SGD:

For simplicity of presentation, we will assume $B=1$. Let the sample at iteration $t$ be $(x^{(t)},y^{(t)})$, where $x^{(t)}\sim\mathsf{Unif}(\{\pm 1\}^{n})$ and $y^{(t)}=\chi_{S}(x^{(t)})$. Denote the population and stochastic gradient at time $t$ as:

Then, for every $t<T$ and $i,j$, by the Azuma-Hoeffding inequality, with probability $1-\frac{\delta}{nT}$, $|s^{(t)}_{i,j}|\leq 1/2$. By the union bound, w.p. at least $1-\delta$, for every $t<T$ and all $i,j$ it holds that $|s^{(t)}_{i,j}|\leq 1/2$. Let us assume this holds.

For all $t\leq T+1$ and $j>1$, $|w_{i,j}^{(t)}|\leq 3/2$.

Note that $w_{i,j}^{(t+1)}=w_{i,j}^{(1)}+s_{i,j}^{(t)}$; therefore, we have

since $|w_{i,j}^{(1)}|=1$ and $|s_{i,j}^{(t-1)}|\leq 1/2$. ∎

For all $t\leq T+1$ and $i\in[k]$ it holds that

Observe that the claim holds for $t=1$ since $|w_{i,1}^{(1)}|=1$. By induction on $t$, assume the claim holds for all $\tau\leq t$. We now prove it for $t+1$.

Observe that for all $\tau\leq t$, by the assumption:

Note that $w_{i,1}^{(t+1)}=w_{i,1}^{(1)}+\sum_{\tau=1}^{t}\eta^{(\tau)}_{i}\left(\prod_{j\neq i}w_{j,1}^{(\tau)}\right)+s_{i,1}^{(t)}$; then we have

(18) follows from observing that $\xi_{i}w_{i,1}^{(1)}=|w_{i,1}^{(1)}|=1$ and $\prod_{j=1}^{k}\xi_{j}=1$. (19) follows from the inductive hypothesis $\xi_{j}w_{j,1}^{(\tau)}=|\xi_{j}w_{j,1}^{(\tau)}|$. (20) follows from our assumption that $|s_{i,1}^{(t)}|\leq 1/2$ and Claim 11. (21) follows from the inductive hypothesis $|w_{i,1}^{(\tau)}|\geq 1/2$. ∎

Setting $T$ such that $T\geq 2\log(2nT/\delta)\alpha^{2}(3n^{\prime}-2)^{2k-2}$ for $\alpha=3\sqrt{(n^{\prime}-1)\log(2k/\epsilon)}-1$, from the above claims, after iteration $T$, we have

Using Lemma 7, we have with probability $1-\delta$,

Appendix C Additional figures, experiments, and discussion

This section contains our unabridged empirical results, visualizations, and accompanying discussion. Additional example training curves (like the assortment in Figure 1 (left)) are shown in Figure 6; more examples can be found in the subsections below.

We first present the full empirical results outlined in Section 3 of the main paper. Figure 7 shows convergence times $t_{c}$ on small parity instances for all of the architecture configurations enumerated in Section 3.1. In some of these settings, $t_{c}$ exhibits high variance due to unlucky initializations (see Figure 8); thus, we report $10^{\text{th}}$ percentile convergence times. Figure 9 gives coarse-grained estimates of how $t_{c}$ scales with $(n,k)$, based on small examples. For selected architectures, Figure 10 shows more precisely how these convergence times scale with $n$ and $k$. For small $n$, power law relationships $t_{c}\propto n^{\alpha\cdot k}$ (for small constants $\alpha$) are observed for all configurations. Note that for larger $n$, the exponent (i.e. the slope in the log-log plot) increases: with a constant learning rate and standard training, the $n^{\Theta(k)}$ scaling does not continue indefinitely. All additional details are in Appendix D.
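One simple way to estimate the exponent of a power law $t_{c}\propto n^{\alpha k}$ from measured convergence times is a least-squares fit of $\log t_{c}$ against $\log n$. The fit should be restricted to the range of $n$ where the relationship looks linear, and dividing the fitted slope by $k$ gives $\alpha$. The Python sketch below shows this fit together with the $10^{\text{th}}$ percentile summary used above, computed by linear interpolation between order statistics as in NumPy's default. The helper names are hypothetical. This is an illustration, not necessarily how the figures were produced.

```python
import math


def percentile(values, q):
    """q-th percentile (0 <= q <= 100), linearly interpolating between order statistics."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100
    lo, hi = math.floor(pos), math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def loglog_slope(ns, ts):
    """Least-squares slope of log(t) against log(n): the exponent in t ~ C * n**slope."""
    xs = [math.log(n) for n in ns]
    ys = [math.log(t) for t in ts]
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var = sum((x - mx) ** 2 for x in xs)
    return cov / var
```

In use, one would take `percentile(runs, 10)` over repeated runs at each $n$ and pass the resulting values to `loglog_slope`.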

The remainder of Appendix C expands on the various discussions and figures from Sections 4 and 5.

Appendix C.1 gives experimental evidence that Fourier gaps are present at iterates $w_{t}$ and initializations $w_{0}$ other than sign vectors, as well as for activation functions other than ReLU. This suggests that the feature amplification mechanism is robust, and illuminates directions for strengthening the theoretical results.

Appendix C.2 discusses how the building blocks of deep learning (activation functions, biases, initializations, learning rates, and batch sizes) play multiple, sometimes conflicting roles in this setting.

Appendix C.3 provides additional white-box visualizations of hidden progress from Figure 3.

Appendix C.4 explores the implications of the feature amplification mechanism for scaling model size: namely, unlike random search, large width does not impart parallel speedups.

Appendix C.5 shows that our results hold in the finite-sample setting (allowing for multiple passes over a training set of size $m$). In particular, we show that in low-data regimes, the models exhibit the grokking phenomenon.

Appendix C.6 extends our results to noisy parities (which comprise the true “emblematic computationally-hard problem”).

Appendix C.7 introduces a counterexample for “layer-by-layer learning”, using parity distributions whose degrees are higher than those of the individual layers’ polynomial activations. Preliminary experiments show that standard training works in this setting.

Appendix C.8 presents examples of training curves for wide polynomial-activation MLPs, where, unlike the other settings, there is no initial plateau in the model’s error.

C.1 Fourier gaps at initialization and SGD iterates

Proposition 9 shows that if the function $x\mapsto\sigma^{\prime}(w_{0}^{\top}x)$ has a Fourier gap at $S$, then $S$ can be identified from a batch gradient at initialization $w_{0}$ with $B=O(1/\gamma^{2})$ samples. Our end-to-end result (Theorem 4) requires ReLU activations and sign vector initialization, because the Fourier gap condition (Definition 1) arises from exact formulas for the Fourier coefficients of the majority function. Stronger end-to-end theoretical guarantees would follow from analogous Fourier gaps in more general population gradients. This requires $x\mapsto\sigma^{\prime}(w^{\top}x)$ to satisfy these conditions simultaneously:

Fourier concentration: upper bounds on the degree-$(k+1)$ coefficients $\widehat{f}(S\cup\{i\})$, for $i\notin S$. The term is borrowed from Klivans et al. (2004), who use upper bounds on Fourier coefficients of LTFs to approximate them with low-degree polynomials (and thus learn halfspaces).

Fourier anti-concentration: lower bounds on the degree-$(k-1)$ coefficients $\widehat{f}(S\setminus\{i\})$, for $i\in S$.

A natural question is: which Boolean functions, other than majority, satisfy the $\gamma$-Fourier gap property at $S$ for $\gamma\geq n^{-\Omega(k)}$?

We present some numerical evidence for large Fourier gaps in functions $x\mapsto\sigma^{\prime}(w^{\top}x)$ other than majority, which arise from gradients of architectures other than ReLU MLPs with sign initialization. This shows that the mechanism of feature emergence is empirically robust in settings not fully explained by our current theory. Establishing corresponding theoretical guarantees would enable stronger end-to-end global convergence guarantees for MLPs and other architectures.

In these experiments, population gradients were computed by brute-force integration over all $2^{n}$ Boolean inputs $x\in\{\pm 1\}^{n}$. In all cases, for various choices of $\sigma,w$, we measure a slightly relaxed notion of Fourier gap $\Gamma$ in the population gradient:

If $\Gamma>0$, then one coordinate from the parity can be identified from $O\left(\frac{\log n}{\Gamma^{2}}\right)$ samples of the gradient at $w$. (Replacing the first $\max$ in the definition of $\Gamma$ with $\min$ would give us the same notion of Fourier gap as Definition 1: if all the relevant coordinates are larger than all of the irrelevant ones, estimating the population gradient allows us to recover the relevant coordinates.)
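The following Python sketch illustrates the measurement. It assumes $\Gamma:=\max_{i\in S}|\bar g_i|-\max_{i\notin S}|\bar g_i|$, which is our reading of the parenthetical above; the paper's own definition is not reproduced here and may include a normalization. Passing `strict=True` uses the minimum over $S$ instead, matching the Definition 1 style of gap. The population gradient $\bar g_i=\mathbb{E}[\chi_S(x)\,\sigma^{\prime}(w^{\top}x)\,x_i]$ is computed by enumerating $\{\pm 1\}^{n}$ for an arbitrary derivative `dsigma`, e.g. `math.cos` for a sinusoidal neuron. All names are hypothetical.

```python
import math
import random
from itertools import product


def population_gradient(n, S, w, dsigma):
    """Exact E[chi_S(x) * dsigma(w.x) * x_i] for each i, enumerating all of {-1,+1}^n."""
    g = [0.0] * n
    for x in product((-1, 1), repeat=n):
        y = math.prod(x[i] for i in S)
        d = dsigma(sum(wi * xi for wi, xi in zip(w, x)))
        for i in range(n):
            g[i] += y * d * x[i]
    return [gi / 2 ** n for gi in g]


def fourier_gap_stat(g, S, strict=False):
    """Max (or min, if strict) of |g_i| over i in S, minus max of |g_i| over i not in S."""
    rel = [abs(g[i]) for i in S]
    irr = [abs(g[i]) for i in range(len(g)) if i not in S]
    return (min(rel) if strict else max(rel)) - max(irr)


# Example: a sinusoidal neuron (sigma = sin, so sigma' = cos) at a random Gaussian w.
rng = random.Random(0)
w = [rng.gauss(0.0, 1.0) for _ in range(10)]
g = population_gradient(10, {0, 1, 2}, w, math.cos)
print(fourier_gap_stat(g, {0, 1, 2}))
```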

The successful convergence of architectures with smoother activations (in the parity setting and beyond) motivates the question of whether large Fourier gaps are present in population gradients corresponding to functions other than LTFs. Figure 12 (left) shows that this is the case for sinusoidal activations.

Finally, to further close the gap between Theorem 4 and the empirical results, it is necessary to address the fact that SGD accumulates gradients with respect to time-varying iterates, while our analysis approximates this using a large-batch gradient at a static iterate $w_{0}$. In fact, SGD seems to help in some cases: Figure 12 (right) shows that when training a sinusoidal neuron, SGD amplifies the initial Fourier gap.

C.2 Counterintuitive roles of the building blocks of deep learning

Even in this simple problem setting, the simultaneous computational and statistical considerations lead to counterintuitive consequences for the optimal choice of architectures and algorithms. We encountered the following in our search for architecture configurations for the empirical study:

Activation functions. The mechanism of features emerging via Fourier gaps (see Definition 1) is strongest with non-smooth activations such as the ReLU, whose derivatives are discontinuous threshold functions. This consideration is orthogonal to representational capacity and the mitigation of local minima, from which one might conclude that degree-$k$ polynomial activations are optimal. In summary, in feature learning settings where Fourier gaps and low-complexity solutions are simultaneously relevant, there is a sharpness-smoothness tradeoff for the activation function.

Biases. The symmetry of the majority function (as well as all unbiased LTFs) causes its even-degree Fourier coefficients to be zero; thus, certain variants of the setups in Section 3 fail for odd $k$. Bias terms (trainable or fixed) are necessary to break this symmetry, in theory and practice. Simultaneously, biases serve the more conventional role of shifting the loss surface; see Section D.1.2 for how this affected the choice of biases in the experiments.

Initializations. The role of the initialization distribution is similarly twofold in this setting: $w_{0}$ should be close to the desired solution $w^{*}$, but it must also be selected such that SGD will successfully amplify the Fourier gap. A third consideration, which we do not attempt to study in this work, is that multiple randomly-initialized neurons will tend to learn the correct features at different times. This is visible in the weight trajectory visualizations in Figure 3 and Figure 13, as well as in the staircase-like training curves seen for MLPs in Figure 1 (left). We expect this symmetry-breaking phenomenon to be present in more complex feature learning settings. Finally, as shown in the training curves from setting (vi) in Figure 6, and in more detail in Appendix C.8, the choice of activation function influences the qualitative behavior of the training curves: namely, whether the plateaus disappear at large widths and batch sizes.

C.3 Hidden progress measures

In this section, we provide an expanded discussion and plots for the investigations outlined in Figure 3 and the “hidden progress measures” section in Section 5.

Note that many trivial progress measures exist: an example to keep in mind is that for the algorithm which exhaustively enumerates over a deterministic list of hypotheses (say, the possible kk-element subsets SS in lexicographical order) and terminates when it finds the correct one, the current iteration tt is a progress measure. Thus, the purpose of demonstrating hidden progress measures ρ\rho is not to provide further evidence that SGD finds the features using Fourier gaps. Rather, it is to (1) further refute the hypothesis of SGD performing a memoryless Langevin-like random search, and (2) provide a preliminary exploration of how progress can be quantified even when the natural metrics of loss and accuracy appear to be flat.

The Fourier gap visualizations in Section C.1 already provide an example of a quantity which varies continuously as the model trains, despite no apparent progress in the loss and accuracy curves. However, none of our theoretical analyses capture the empirical observation that this quantity tends to amplify over time. Below, we consider other quantities which reveal hidden progress in parity learning, which are more straightforward and closer to our analyses.

This progress measure is shown alongside the loss curves in Figure 13, in red. We do not attempt to characterize the dynamics of $\rho$; we only note that they are clearly distinguishable from the maximum of $n$ unbiased random walks, even when SGD appears to make no progress in terms of loss and accuracy. Studying hidden progress measures in deep learning more quantitatively, as well as in more general settings, presents a fruitful direction for future work.

C.4 Convergence time vs. width

We provide supplementary plots for the experiment outlined in Figure 4 (left), which probes whether extremely large widths ($r \gg n^k$) afford factor-$r$ parallel speedups of the parity learning mechanism (as one would expect from random search). On 3 parity instances $n \in \{30, 40, 50\}$, $k = 3$, we varied the width $r \in \{1, 2, 3, \ldots, 9, 10, 30, 100, 300, \ldots, 10^6, 3 \times 10^6\}$, keeping all other parameters the same ($B = 128$, $\eta = 0.1$).

We did not find evidence of such parallel speedups over $1000$ runs in each setting; see Figure 14. This serves as further evidence that the mechanism by which standard training solves parity is best understood as deterministic and sequential, rather than behaving like random search over size-$k$ subsets. A benefit of width appears to be variance reduction: the upper tail of long convergence times is mitigated by a large number of randomly-initialized neurons.

C.5 Learning and grokking in the finite-sample case

We provide some supplementary plots for the experiments outlined in Figure 4 (right). In these settings, a fixed architecture (width-$100$ MLP with ReLU activations) is trained with minibatch SGD in an otherwise fixed configuration (hinge loss, learning rate $\eta = 0.1$, batches of size $B = 32$) on a finite training sample of size $m$. We also vary a weight decay parameter $\lambda$.

As shown in Figure 15, the weight decay parameter $\lambda$ modulates a delicate computational-statistical tradeoff: it improves generalization (expanding the range of $m$ for which training eventually finds the correct solution), but the model fails to train at large values of $\lambda$. For small $m$ and appropriately tuned $\lambda$, we observe grokking: the model initially overfits the training data, but finds a classifier that generalizes after a large number of iterations.

C.6 Learning noisy parities

The other empirical results in this work focus on noiseless parity distributions $\mathcal{D}_S$, to reduce the number of sources of variance and degrees of freedom. However, the setting of random classification noise is important for several reasons. In this section, we briefly demonstrate that our results extend to this case. Let $\mathcal{D}_S^{(\epsilon)}$ denote the $(n, k, \epsilon)$-noisy parity distribution, defined by flipping the labels in the $(n, k)$-parity distribution $\mathcal{D}_S$ independently with probability $\frac{1}{2} - \epsilon$. Note that when $\epsilon = 0$, the labels are completely random (thus, $S$ cannot be learned). By a standard PAC-learning argument, when $0 < \epsilon \leq \frac{1}{2}$, the statistical limit for identifying $S$ from i.i.d. samples from $\mathcal{D}_S^{(\epsilon)}$ scales as $\Theta\left(\frac{k \log n}{\epsilon^2}\right)$.

First, learning parities from noisy samples is the true “emblematic computationally-hard distribution”. Without noise, there is a non-SQ algorithm which avoids the exponential-in-$k$ computational barrier: Gaussian elimination can identify $S$ in $O(n^3)$ time and $\Theta(n)$ samples. Second, if parities are to serve as an idealized setting in which to understand training dynamics, resource scaling, and emergence in deep learning, it is important to check that this phenomenon is robust to label noise.
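To illustrate why the noiseless case is not the hard one, here is a minimal sketch of the Gaussian elimination algorithm mentioned above. It assumes the bit encoding $b_i = (1 - x_i)/2$, under which the label bit $(1 - \chi_S(x))/2$ equals $\sum_{i \in S} b_i \bmod 2$, so recovering $S$ reduces to a linear system over $\mathrm{GF}(2)$. The helper name `solve_gf2` and the choice of $2n$ samples (one arbitrary $\Theta(n)$ budget) are ours. A single flipped label can make such an overdetermined system inconsistent or point it to the wrong subset, which is one way to see why label noise removes this route.

```python
import random

def solve_gf2(rows, rhs, n):
    """Gauss-Jordan elimination over GF(2).

    rows: list of n-bit integers (bit i = coefficient of unknown i); rhs: list of bits.
    Returns one solution as an n-bit integer (free variables set to 0), or None if inconsistent.
    """
    rows, rhs = list(rows), list(rhs)
    pivot_cols = []
    r = 0
    for col in range(n):
        pivot = next((i for i in range(r, len(rows)) if (rows[i] >> col) & 1), None)
        if pivot is None:
            continue  # no pivot in this column: unknown `col` is free
        rows[r], rows[pivot] = rows[pivot], rows[r]
        rhs[r], rhs[pivot] = rhs[pivot], rhs[r]
        for i in range(len(rows)):
            if i != r and (rows[i] >> col) & 1:
                rows[i] ^= rows[r]  # XOR is addition over GF(2)
                rhs[i] ^= rhs[r]
        pivot_cols.append(col)
        r += 1
    # Rows r, r+1, ... are now all-zero; a nonzero right-hand side there means no solution.
    if any(rhs[i] for i in range(r, len(rows))):
        return None
    solution = 0
    for i, col in enumerate(pivot_cols):
        solution |= rhs[i] << col
    return solution

rng = random.Random(0)
n, S = 30, {3, 11, 17}
rows, rhs = [], []
for _ in range(2 * n):  # Theta(n) noiseless samples
    bits = [rng.getrandbits(1) for _ in range(n)]  # bit b_i = (1 - x_i) / 2
    rows.append(sum(b << i for i, b in enumerate(bits)))
    rhs.append(sum(bits[i] for i in S) % 2)  # label bit (1 - chi_S(x)) / 2
solution = solve_gf2(rows, rhs, n)
S_hat = {i for i in range(n) if (solution >> i) & 1}
```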

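In particular, if an architecture’s population gradient has a Fourier gap with parameter $\gamma$ in the noiseless case, then it has a Fourier gap with parameter proportional to $\epsilon \cdot \gamma$ in the noisy case: for the correlation loss, flipping each label independently with probability $\frac{1}{2} - \epsilon$ scales $\mathbb{E}[y \mid x]$, and hence the population gradient, by a factor of $2\epsilon$.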
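The effect of label noise on the signal can be checked directly. Under $\mathcal{D}_S^{(\epsilon)}$, each label agrees with $\chi_S(x)$ with probability $\frac{1}{2} + \epsilon$, so $\mathbb{E}[y\,\chi_S(x)] = 2\epsilon$; by linearity, any population quantity that is linear in the label shrinks by a factor of $2\epsilon$, i.e. proportionally to $\epsilon$. The Monte Carlo sketch below (sample size, seed, planted subset, and helper names are illustrative assumptions) compares an empirical estimate with this value.

```python
import math
import random

def sample_noisy_parity(rng, n, S, eps):
    """One draw from D_S^(eps): x uniform on {-1,+1}^n, label chi_S(x) flipped w.p. 1/2 - eps."""
    x = [rng.choice((-1, 1)) for _ in range(n)]
    y = math.prod(x[i] for i in S)
    if rng.random() < 0.5 - eps:
        y = -y
    return x, y

def label_correlation(eps):
    # E[y * chi_S(x)] = (1/2 + eps) * (+1) + (1/2 - eps) * (-1) = 2 * eps
    return 2 * eps

rng = random.Random(0)
n, S, eps = 10, (1, 4, 6), 0.2
draws = [sample_noisy_parity(rng, n, S, eps) for _ in range(50_000)]
estimate = sum(y * math.prod(x[i] for i in S) for x, y in draws) / len(draws)
```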

We find that the experimental findings are robust to label noise, in the sense that models are able to obtain nontrivial (and sometimes $100\%$) accuracy; see Figure 16 for some training curves under various settings of $\epsilon$. This provides concrete evidence against the (already extremely dubious) hypothesis that neural networks, with standard initialization and training, learn noiseless parities by implicitly simulating an efficient algorithm such as Gaussian elimination. Note that with a constant learning rate (here, $\eta = 0.1$) and label noise, the iterates of SGD do not always converge to $100\%$ accurate solutions.

C.7 Counterexample for layer-by-layer learning

Consider an $L$-layer MLP with activation $\sigma$, parameterized by weights and biases $\theta = (W_1, b_1, \ldots, W_{L-1}, b_{L-1}, u)$:

$$f(x; \theta) = f_L \circ f_{L-1} \circ \cdots \circ f_1(x),$$

where $f_i$ denotes the function $z \mapsto \sigma(W_i z + b_i)$ for $1 \leq i \leq L-1$, and $f_L$ denotes $z \mapsto u^\top z$. The shapes of the parameters $W_i, b_i, u$ are selected such that each function composition is well-defined. Let the intermediate activations at layer $i$ be denoted by

$$z_0 := x, \qquad z_i := f_i(z_{i-1}) \quad \text{for } 1 \leq i \leq L-1.$$

Finally, $r_i$ (the width at layer $i$) refers to the dimensionality of $z_i$ as defined above.

Notice that when $\sigma$ is a degree-$2$ polynomial (say, $\sigma(z) = z^2$), an $L$-layer MLP can represent parities of degree up to $2^{L-1}$; for example, a $3$-layer MLP (which composes quadratic activations twice) can represent a $4$-sparse parity as a $2$-sparse parity of $2$-sparse parities. However, Equation (2) implies the following:

An individual layer cannot represent a parity of $k > 2$ inputs.

The population gradient (as in Equation (5)) is zero, since every coordinate of the gradient is the correlation between a $k$-wise parity and a polynomial of degree $2$.

Thus, this setting serves as an idealized counterexample for layer-by-layer learning: if SGD succeeds on parities with higher degree than the architecture’s polynomial activations, it must do so by an end-to-end mechanism. Intuitively, earlier layers can only make progress by knowing how their outputs will be used downstream. Concretely, consider the population gradient of the correlation loss, with respect to a first-layer neuron’s weights $w := (W_1)_{j,:}$. With layer-by-layer training, this gradient contains no information:

However, in end-to-end training, the presence of downstream layers removes this barrier:

giving the gradient greater representation capacity (in terms of polynomial degree). The question remains whether end-to-end training succeeds in this setting, which we answer affirmatively in small experiments.

We empirically observed successful training (to $100\%$ accuracy) in a few settings (with SGD, learning rate $\eta = 0.01$, batch size $B = 32$, and default uniform initialization as described in Appendix D.1):

$L = 3$, $n \in \{10, 20, 30\}$, $k \in \{1, 2, 3, 4\}$. Small widths suffice: $(r_1, r_2) = (2, 1)$. Over 10 random seeds, all models converged within $20000$ iterations.

$L = 4$, $n \in \{10, 20, 30\}$, $k \in \{1, 2, 3, 4, 5, 6\}$. Widths were chosen to be slightly larger for stability: $(r_1, r_2, r_3) = (10, 10, 1)$. Over 10 random seeds, all models converged within $50000$ iterations. Additionally, models trained on $(n, k) \in \{(10, 7), (20, 7), (30, 7), (10, 8)\}$ converged within $500000$ iterations.

As a sanity check, the models failed to converge in experimental setups where $k > 2^{L-1}$: $(L = 2, k \geq 3)$ and $(L = 3, k \geq 5)$.

This construction serves as a simple counterexample to the “deep only works if shallow is good” principle of Malach and Shalev-Shwartz (2019), demonstrating a case where a deep network can get near-perfect accuracy even when greedy layerwise training (e.g., Belilovsky et al., 2019) cannot beat trivial performance. It remains to characterize these positive empirical results theoretically, as well as to investigate whether there are pertinent analogues in real data distributions.
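To make the degree-$2^{L-1}$ representation claim concrete for $L = 3$, the sketch below hand-picks weights that compute a $4$-sparse parity as a $2$-sparse parity of $2$-sparse parities. This is our own construction, with widths $(r_1, r_2) = (2, 2)$ and no output bias, so it does not reproduce the width-$(2, 1)$ trained models reported above. It relies on the identities $(x_a + x_b)^2 = 2 + 2x_a x_b$ for $x_a, x_b \in \{-1,+1\}$ and $pq = \frac{1}{4}\left((p+q)^2 - (p-q)^2\right)$, and checks the result exhaustively on $\{-1,+1\}^6$.

```python
import itertools
import math

def affine(W, b, z):
    return [sum(w * zj for w, zj in zip(row, z)) + bi for row, bi in zip(W, b)]

def quadratic(v):
    return [t * t for t in v]  # the activation sigma(z) = z^2, applied elementwise

def three_layer_quadratic_mlp(x, S):
    """f = u^T sigma(W2 sigma(W1 x + b1) + b2) with hand-picked weights computing chi_S, |S| = 4."""
    a, b, c, d = S
    n = len(x)
    W1 = [[1.0 if j in (a, b) else 0.0 for j in range(n)],
          [1.0 if j in (c, d) else 0.0 for j in range(n)]]
    b1 = [0.0, 0.0]
    # h = ((x_a + x_b)^2, (x_c + x_d)^2) = (2 + 2p, 2 + 2q), with p = x_a x_b and q = x_c x_d
    h = quadratic(affine(W1, b1, x))
    W2 = [[0.5, 0.5], [0.5, -0.5]]
    b2 = [-2.0, 0.0]
    # second-layer pre-activations are (p + q, p - q), so z = ((p + q)^2, (p - q)^2)
    z = quadratic(affine(W2, b2, h))
    u = [0.25, -0.25]
    # ((p + q)^2 - (p - q)^2) / 4 = p * q = x_a x_b x_c x_d
    return sum(ui * zi for ui, zi in zip(u, z))

S = (0, 2, 3, 5)
all_match = all(
    three_layer_quadratic_mlp(x, S) == math.prod(x[i] for i in S)
    for x in itertools.product((-1, 1), repeat=6)
)
```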

C.8 Lack of plateaus for wide polynomial-activation MLPs

An interesting qualitative observation from the training curves in Figure 6 is that the validation accuracy curves in setting (vi) (width-1000 polynomial-activation MLPs) do not follow the same “plateau” or “staircase” pattern as the others. Figure 17 shows a few additional examples of training curves for polynomial-activation MLPs, varying the width $r$ and batch size $B$. We find that the rate of descent of the validation error increases with both of these parameters; note that this does not occur with ReLU activations (where there are sharp phase transitions between plateaus at all batch sizes).

This constitutes an exception to this paper’s theme of “hidden progress” behind flat loss (or error) curves: with enough overparameterization and “over-sampling”, the continuous progress of SGD in this setting is no longer hidden, and manifests in the training curves. This phenomenon seems to be specific to certain activation functions (e.g., $x^k$ but not ReLU); we leave it for future work to understand why and when it occurs, as well as its potential practical implications.

Appendix D Details for all experiments

Our “robust space” of empirical results uses the following loss functions:

In the configurations corresponding to all of the figures and convergence time experiments, we used the hinge loss. This was a relatively arbitrary choice (the losses appeared to be interchangeable in small experiments); an advantage of the hinge and square losses over cross-entropy is that, for architectures that can realize the parity function, there is a zero-loss solution with finite weights.
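For $\pm 1$ labels, the point about finite weights can be read off the loss formulas. The sketch below assumes the standard forms $\max(0, 1 - yf)$, $(y - f)^2$, and $\log(1 + e^{-yf})$, since the exact definitions used in the experiments are not reproduced in this excerpt: an output equal to $\chi_S(x)$ already has zero hinge and square loss, whereas cross-entropy only decreases toward zero as the output scale grows without bound.

```python
import math

def hinge_loss(y, f):
    return max(0.0, 1.0 - y * f)

def square_loss(y, f):
    return (y - f) ** 2

def cross_entropy_loss(y, f):
    # logistic loss for labels y in {-1, +1}: log(1 + exp(-y f)) > 0 for every finite f
    return math.log1p(math.exp(-y * f))

# A network that outputs exactly chi_S(x) in {-1, +1} has margin y * f = 1 on every sample.
for y in (-1, 1):
    print(y, hinge_loss(y, y), square_loss(y, y), cross_entropy_loss(y, y))

# Cross-entropy only approaches 0 as the output scale (hence the weights) grows without bound.
print([cross_entropy_loss(1, scale) for scale in (1, 10, 100)])
```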

Our empirical results use the following i.i.d. weight initializations:

Uniform on the interval $[-c, c]$, where the scale $c$ is chosen for all affine transformation parameters using the “Xavier initialization” convention (Glorot and Bengio, 2010). The experiments are quite tolerant to the particular choice of $c$ (as these are not deep networks); this choice, which is the default in deep learning packages, emphasizes that our positive empirical results hold under a standard initialization scheme.

Gaussian with mean $0$ and variance $\sigma^2$, selected using the “Kaiming initialization” convention (He et al., 2015).
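The two conventions reduce to simple scale formulas. The sketch below uses the standard Glorot uniform bound and the fan-in Kaiming normal standard deviation; the gains ($1$ for Xavier, $\sqrt{2}$ for Kaiming, as is usual for ReLU) and the fan-in mode are assumptions, since the exact settings are not stated here. In PyTorch, the corresponding in-place initializers are `torch.nn.init.xavier_uniform_` and `torch.nn.init.kaiming_normal_`; the helper names below are hypothetical.

```python
import math
import random

def xavier_uniform_bound(fan_in, fan_out, gain=1.0):
    # Glorot and Bengio (2010): weights ~ U[-c, c] with c = gain * sqrt(6 / (fan_in + fan_out))
    return gain * math.sqrt(6.0 / (fan_in + fan_out))

def kaiming_normal_std(fan_in, gain=math.sqrt(2.0)):
    # He et al. (2015), fan-in mode: weights ~ N(0, sigma^2) with sigma = gain / sqrt(fan_in);
    # gain = sqrt(2) is the usual choice for ReLU (an assumption here)
    return gain / math.sqrt(fan_in)

def init_matrix(fan_out, fan_in, scheme, rng):
    """Return a fan_out x fan_in weight matrix (list of rows) drawn i.i.d. under `scheme`."""
    if scheme == "xavier_uniform":
        c = xavier_uniform_bound(fan_in, fan_out)
        return [[rng.uniform(-c, c) for _ in range(fan_in)] for _ in range(fan_out)]
    if scheme == "kaiming_normal":
        s = kaiming_normal_std(fan_in)
        return [[rng.gauss(0.0, s) for _ in range(fan_in)] for _ in range(fan_out)]
    raise ValueError(f"unknown scheme: {scheme}")

# First layer of a width-100 MLP on n = 50 inputs: the Xavier bound c and the
# Kaiming standard deviation both equal 0.2 here.
W = init_matrix(100, 50, "xavier_uniform", random.Random(0))
```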

We consider $2$-layer MLPs $f(x; W, b, u) = u^\top \sigma(Wx + b)$ for two choices of activations:

ReLU: $\sigma(z) := \max(0, z)$.

Degree-$k$ polynomial: $\sigma(z) := z^k$.

In both cases, whenever $r \geq k$ (and, in the case of polynomial activations, choosing the degree to be $k$), there exists a width-$r$ MLP which can represent $k$-sparse parities: for all $(n, k)$ and $|S| = k$, there is a setting of $W, b, u$ such that $f(x; W, b, u) = \chi_S(x)$.

Note that if the output $f(x; \theta)$ is a degree-$k'$ polynomial in $x$ with $k' < k$ (e.g., an MLP with $\sigma(z) = z^{k'}$ activations), the architecture is incapable of representing a parity of $k$ inputs. In fact, it is incapable of representing any function that has a nonzero correlation with the parity; this follows from orthogonality (Equation (2)).
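The orthogonality argument can be checked by brute force. Since $x_i^2 = 1$ on $\{-1,+1\}^n$, every polynomial of degree at most $k'$ agrees on the cube with a multilinear polynomial of degree at most $k'$, and its correlation with $\chi_S$ is exactly the coefficient of $\prod_{i \in S} x_i$. The sketch below (random coefficients, $n = 5$, and helper names chosen for illustration) shows that this correlation vanishes for degree $2 < |S| = 3$ and equals the top coefficient for degree $3$.

```python
import itertools
import math
import random

def random_multilinear_poly(n, degree, rng):
    """Random coefficients for every monomial prod_{i in T} x_i with |T| <= degree.

    On {-1,+1}^n, x_i^2 = 1, so every polynomial of degree <= d reduces to this form.
    """
    return {
        T: rng.uniform(-1.0, 1.0)
        for d in range(degree + 1)
        for T in itertools.combinations(range(n), d)
    }

def evaluate(poly, x):
    return sum(c * math.prod(x[i] for i in T) for T, c in poly.items())

def correlation_with_parity(poly, S, n):
    """Exact E_x[p(x) * chi_S(x)] over the uniform distribution on {-1,+1}^n."""
    cube = list(itertools.product((-1, 1), repeat=n))
    return sum(evaluate(poly, x) * math.prod(x[i] for i in S) for x in cube) / len(cube)

rng = random.Random(0)
n, S = 5, (0, 1, 2)
p2 = random_multilinear_poly(n, 2, rng)  # degree k' = 2 < k = 3
p3 = random_multilinear_poly(n, 3, rng)  # degree 3 includes the monomial x_0 x_1 x_2
corr2 = correlation_with_parity(p2, S, n)
corr3 = correlation_with_parity(p3, S, n)
```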

To explore the limits of concise parameterization for architectures capable of learning parities, we propose a variety of non-standard activation functions which allow a single neuron to learn sparse parities. These constructions leverage the fact that the parity is a nonlinear function of the sum of its inputs $w_S^\top x$, where $w_S := \sum_{i \in S} e_i$.

$\sigma(\cdot)$ is the piecewise linear function which interpolates the $k+1$ points $\{(k, +1), (k-2, -1), (k-4, +1), \ldots, (-k+2, \pm 1), (-k, \mp 1)\}$ with $k$ linear regions $\{(-\infty, -k], [-k, -k+2], \ldots, [k-2, k], [k, +\infty)\}$. Then, $\sigma(w_S^\top x) = \chi_S(x)$.

$\sigma(\cdot)$ is the degree-$k$ polynomial which interpolates the same points as above.

$\sigma(z) := \sin(z)$. The sinusoidal neuron $\sin(w^\top x + b)$ can also express parities of arbitrary degree, since it can interpolate the same set of points as the $\infty$-zigzag activation. In the experiments, we pick a shift $\beta$ and use the activation $\sigma(z) := \sin(\frac{\pi}{2} z + \beta)$, such that $\sigma(z)$ interpolates the same points as the sign convention selected for the $\infty$-zigzag activation. In the experiments in Section 3, the sinusoidal activation is additionally scaled by a factor of 2 ($z \mapsto \sigma(2z)$); this is interchangeable with scaling the learning rate and initialization, and is done to obtain more robust convergence in the particular setting of $(n, k) = (50, 3)$.
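The three single-neuron constructions above can be verified exhaustively. The sketch below implements the $k$-zigzag (held constant outside $[-k, k]$, an assumption that does not matter on the cube, where $w_S^\top x \in [-k, k]$), the oscillating polynomial in Lagrange form, and the shifted sinusoid $\sin(\frac{\pi}{2} z + \beta)$ with $\beta = \frac{\pi}{2}(1 - k)$, and checks that $\sigma(w_S^\top x) = \chi_S(x)$. Function names are ours, and the extra factor-2 scaling used in Section 3 is omitted.

```python
import itertools
import math

def parity_points(k):
    """Interpolation points (s, chi) for s in {-k, -k+2, ..., k}, with chi = (-1)^((k - s)/2)."""
    return [(s, (-1) ** ((k - s) // 2)) for s in range(-k, k + 1, 2)]

def zigzag(z, k):
    # Piecewise linear interpolation of parity_points(k); held constant outside [-k, k].
    pts = parity_points(k)
    if z <= pts[0][0]:
        return float(pts[0][1])
    for (s0, v0), (s1, v1) in zip(pts, pts[1:]):
        if z <= s1:
            return v0 + (v1 - v0) * (z - s0) / (s1 - s0)
    return float(pts[-1][1])

def oscillating_poly(z, k):
    # Lagrange form of the unique polynomial of degree <= k through the k + 1 points.
    pts = parity_points(k)
    total = 0.0
    for j, (sj, vj) in enumerate(pts):
        term = float(vj)
        for m, (sm, _) in enumerate(pts):
            if m != j:
                term *= (z - sm) / (sj - sm)
        total += term
    return total

def sinusoid(z, k):
    # sin(pi/2 * z + beta), with beta chosen so that sin(pi/2 * k + beta) = +1.
    beta = math.pi / 2 - math.pi * k / 2
    return math.sin(math.pi / 2 * z + beta)

def single_neuron(x, S, activation, k):
    """sigma(w_S^T x) with w_S the indicator vector of S."""
    return activation(sum(x[i] for i in S), k)
```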

Figure 18 visualizes all families of activations considered in this paper.

The Transformer experiments use a slightly simplified version of the architecture introduced in (Vaswani et al., 2017). In particular, it omits dropout, layer normalization, tied input/output embedding weights, and a positional embedding on the special [CLS] token. Including these does not change the results significantly; they are all present in the preliminary findings in (Edelman et al., 2021) (in which an “off-the-shelf” Transformer implementation successfully learns sparse parities). We specify the architecture below.

where $\mathsf{softmax}(z) := \exp(z) / \mathbf{1}^\top \exp(z)$. Note that we have specialized this architecture to a single output, at the [CLS] position.

where $\sigma(\cdot) = \mathsf{GeLU}(\cdot)$ (the Gaussian error linear unit) is the standard choice in Transformers.

Each matrix-shaped parameter was initialized using PyTorch’s default “Xavier uniform” convention. Unlike the other settings considered in this paper, we were unable to observe successful convergence beyond a few small $(n, k)$ using standard SGD. As is common practice when training Transformers, we used Adam (Kingma and Ba, 2014) with default adaptive parameters $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$ in our experiments. While there are more fine-grained accounts of why Adam outperforms vanilla SGD (Zhang et al., 2020, Agarwal et al., 2020), finding the optimal optimizer configuration and investigating ablations of this optimizer are outside the scope of this work. We tune only Adam’s learning rate $\eta$.

Even with all biases $b_i$ set to $0$, this architecture can realize a $k$-wise parity, by setting $\{w_i\} = \{e_j : j \in S\}$ in any permutation.
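Since PolyNet is not redefined in this excerpt, the sketch below assumes it computes a product of $k$ affine functions, $f(x) = \prod_{i=1}^{k} (w_i^\top x + b_i)$, which is consistent with the weights $\{w_i\}$ and biases $b_i$ referenced above. Under that assumption, it checks the stated zero-bias construction on all of $\{-1,+1\}^5$; the helper name `polynet` is hypothetical.

```python
import itertools
import math

def polynet(x, W, b):
    """Product of affine functions prod_i (w_i^T x + b_i) (our reading of the PolyNet architecture)."""
    return math.prod(sum(wj * xj for wj, xj in zip(w, x)) + bi for w, bi in zip(W, b))

n, S = 5, (1, 3, 4)
order = (4, 1, 3)  # any permutation of S works, since multiplication commutes
W = [[1 if j == i else 0 for j in range(n)] for i in order]  # w_i = e_j for j in S
b = [0] * len(order)  # all biases set to 0
realizes_parity = all(
    polynet(x, W, b) == math.prod(x[i] for i in S)
    for x in itertools.product((-1, 1), repeat=n)
)
```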

Figure 1 (left) shows training curves from 8 representative configurations, with online i.i.d. samples from the same distribution, corresponding to the $(n = 50, k = 3)$-sparse parity problem. The first row encompasses various MLP settings with standard activations:

Setting (i): width-$10$ MLP with ReLU activation ($B = 32$, $\eta = 0.5$).

Setting (i): width-$10$ MLP with ReLU activation, with large batches ($B = 1024$, $\eta = 0.05$).

Setting (ii): width-$100$ MLP with ReLU activation, with tiny batches ($B = 1$, $\eta = 0.05$).

Setting (iv): width-$10$ MLP with polynomial $\sigma(z) = z^3$ activation ($B = 32$, $\eta = 0.05$).

Setting (vii): width-$1$ MLP with a piecewise linear $k$-zigzag activation ($B = 32$, $\eta = 0.2$).

Setting (x): width-$1$ MLP with a sinusoidal activation (scaled and shifted for $k = 3$; see the discussion in Section D.1.2) ($B = 32$, $\eta = 0.05$).

Setting (xv): degree-$3$ PolyNet ($B = 32$, $\eta = 0.07$).

The first row uses the width-$10$ ReLU MLP configuration (ii), holding $B = 32$ and $\eta = 0.1$ while varying the task difficulty across 6 settings: $(n, k) \in \{(30, 3), (60, 3), (90, 3), (30, 4), (30, 5), (30, 6)\}$. The remaining plots are all for the $(50, 3)$ setting.

The second row uses the $k = 3$ PolyNet configuration (xv), varying $(B, \eta) \in \{(1, 0.005), (4, 0.01), (16, 0.1), (64, 0.1), (256, 0.1), (1024, 0.1)\}$.

The third row uses the minimally-wide configurations (*i), (*ii), (vii), (viii), (xi), (xii) (thus presenting an example for each non-standard activation), holding batch size $B = 1$. The learning rate is $\eta = 0.1$ in each of the cases except (*ii), where $\eta = 0.01$.

The fourth row uses three large architectures: settings (iii), (vi), and (*iii), with $(B, \eta) \in \{(1, 0.1), (1024, 0.1), (1, 0.001), (1024, 0.01), (32, 0.0003), (1024, 0.0003)\}$. Setting (*iii) uses the Adam optimizer instead of SGD.

Figure 10 contains scaling plots in various settings for the median convergence time $t_c$. Below, we give comprehensive details about these settings. For each of these runs, we chose $B = 32$ (settings with smaller batch sizes exhibited additional variance; with larger batch sizes, the models were slower to converge), as well as the hinge loss. We used SGD with constant learning rate $\eta$ (enumerated below), except in setting (*iii).

The top row shows MLP settings (i) through (vi). From left to right:

Setting (i): width-$10$ MLP with ReLU activation ($\eta = 1$).

Setting (ii): width-$100$ MLP with ReLU activation ($\eta = 1$).

Setting (iii): width-$1000$ MLP with ReLU activation ($\eta = 1$).

Setting (iv): width-$10$ MLP with $\sigma(z) = z^k$ activation ($\eta = 0.01$).

Setting (v): width-$100$ MLP with $\sigma(z) = z^k$ activation ($\eta = 0.01$).

Setting (vi): width-$1000$ MLP with $\sigma(z) = z^k$ activation ($\eta = 0.01$).

The bottom row shows miscellaneous settings. From left to right:

Setting (vii): width-$1$ MLP with degree-$k$ oscillating polynomial activation interpolating the parity function ($\eta = 0.01$).

Setting (xiv): single sinusoidal neuron with no second layer ($\eta = 0.01$).

Setting (xv): degree-$k$ PolyNet ($\eta = 0.05$).

Setting (ii): width-$100$ MLP with ReLU activation ($\eta = 1$), showing an expanded range of $n$ for smaller $k$.

Setting (v): width-$100$ MLP with $\sigma(z) = z^k$ activation ($\eta = 1$), showing an expanded range of $n$ for smaller $k$.

D.2 Training curves and convergence time plots

For all example training curves in all figures (in Sections 3 and 5, as well as the appendix), population losses and accuracies are approximated using a batch of size $8192$, sampled once at the beginning of training from the same distribution $\mathcal{D}_S$. All plots of single representative training runs use a fixed random seed (torch.manual_seed(0)); when $R$ training runs are shown, seeds $0, \ldots, R-1$ are used.
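A minimal sketch of the data pipeline described above, using Python's standard `random` module in place of PyTorch (an assumption made for self-containment; the planted subset $S$ and the helper names are hypothetical). The validation batch of size $8192$ is drawn once and reused for every evaluation, while online training draws a fresh minibatch at every step.

```python
import random

def sample_sparse_parity(rng, batch_size, n, S):
    """i.i.d. samples from D_S: x uniform on {-1,+1}^n, y = prod_{i in S} x_i."""
    xs, ys = [], []
    for _ in range(batch_size):
        x = [rng.choice((-1, 1)) for _ in range(n)]
        y = 1
        for i in S:
            y *= x[i]
        xs.append(x)
        ys.append(y)
    return xs, ys

def accuracy(predict, xs, ys):
    """Fraction of samples where sign(predict(x)) == y, with sign(0) taken as +1."""
    return sum((1 if predict(x) >= 0 else -1) == y for x, y in zip(xs, ys)) / len(ys)

rng = random.Random(0)  # stand-in for torch.manual_seed(0)
n, S = 50, (0, 1, 2)
val_x, val_y = sample_sparse_parity(rng, 8192, n, S)    # drawn once, reused at every evaluation
batch_x, batch_y = sample_sparse_parity(rng, 32, n, S)  # online training: a fresh batch per step
```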

In Figures 7 and 8, validation accuracies were recorded every $10$ iterations, and a run was recorded as converged if it reached $100\%$ accuracy within $10^5$ iterations; we report the $10^{\text{th}}$ percentile over 25 random seeds, to reduce variance arising from the more initialization-sensitive settings. In Figure 9, coarse-grained scaling estimates for the ($10^{\text{th}}$ percentile) convergence time are computed as follows: for $n \in \mathcal{N} := \{10, 20, 30\}$, the smallest $\alpha$ is chosen such that $t_c \leq c \cdot (n - n_0)^\alpha$, choosing $n_0 = \min \mathcal{N} - 1 = 9$, so that $c = t_c$ at $n = 10$. These estimates are calculated to give quantitative order-of-magnitude upper bounds for the convergence time over this range of $n$. However, the power-law convergence times do not extrapolate to larger $n$ at a constant learning rate; see Figure 2 (right), the “larger $n$” plots in Figure 10, and the discussion on batch sizes and learning rates in Appendix C.2.
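The coarse scaling estimate can be written out in a few lines. Because $n - n_0 = 1$ at the smallest $n$, the constant is pinned to $c = t_c(10)$, and each larger $n$ imposes $\alpha \geq \log(t_c(n)/c) / \log(n - n_0)$; the smallest valid $\alpha$ is the maximum of these lower bounds. The function name is hypothetical, and the convergence times below are invented placeholders, not measurements from the paper.

```python
import math

def coarse_power_law(tc_by_n):
    """Return (c, alpha): the smallest alpha with tc(n) <= c * (n - n0)^alpha for all n.

    n0 = min(n) - 1, so (n - n0) = 1 at the smallest n and c equals tc there.
    """
    ns = sorted(tc_by_n)
    n0 = ns[0] - 1
    c = tc_by_n[ns[0]]
    # For n > n0 + 1: tc <= c * (n - n0)^alpha  <=>  alpha >= log(tc / c) / log(n - n0).
    alpha = max(math.log(tc_by_n[n] / c) / math.log(n - n0) for n in ns[1:])
    return c, alpha

# Made-up convergence times (not measurements), for n in N = {10, 20, 30}:
c, alpha = coarse_power_law({10: 100.0, 20: 12100.0, 30: 5000.0})
```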

To reduce computational load, for the larger-scale probes of convergence times $t_c$, validation accuracies were instead checked on a sample of size $128$. For the underparameterized networks (i.e., those unable to represent parity but still able to obtain a meaningful gradient signal), the convergence criterion was changed to $10$ consecutive batches with accuracy at least $55\%$. Note that for parity learning in particular, a weak learner can be converted into a strong learner: there is an efficient algorithm (Goldreich and Levin, 1989, Kushilevitz and Mansour, 1993) which, given a classifier which achieves $1/2 + \epsilon$ accuracy on $\mathcal{D}_S$ for a constant $\epsilon > 0$, outputs $S$ with high probability.
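The relaxed convergence criterion for underparameterized networks amounts to a streak counter. The sketch below is a hypothetical helper, not the original training code; its defaults mirror the threshold ($55\%$) and streak length ($10$) stated above.

```python
def first_weak_convergence(batch_accuracies, threshold=0.55, patience=10):
    """Index of the first step ending a run of `patience` consecutive batches with accuracy >= threshold."""
    streak = 0
    for t, acc in enumerate(batch_accuracies):
        streak = streak + 1 if acc >= threshold else 0
        if streak >= patience:
            return t
    return None  # never converged within the recorded steps
```

For example, five batches at $50\%$ followed by ten at $60\%$ satisfy the criterion at index 14, while a streak of nine broken by a single batch below threshold does not.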

In the median convergence time plots in Figure 1 (right), Figure 2 (right), and Figure 10, error bars are $95\%$ confidence intervals for the median convergence time, computed from $100$ bootstrap samples. Each point on each curve corresponds to $1000$ random trials. Halted curves signify more than $50\%$ of runs failing to converge within $T = 10^5$ iterations (hence, infinite medians).
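The error bars can be reproduced in spirit with a percentile bootstrap. In the sketch below, non-converged runs are recorded as `math.inf`, so the median becomes infinite once enough runs fail, matching the halted curves. The use of the percentile method (rather than another bootstrap interval) is our assumption, the helper names are hypothetical, and the data are synthetic placeholders.

```python
import math
import random
import statistics

def median_convergence_time(times):
    """Median over runs; non-converged runs are math.inf, so too many failures give an infinite median."""
    return statistics.median(times)

def bootstrap_median_ci(times, num_resamples=100, level=0.95, seed=0):
    """Percentile-bootstrap confidence interval for the median convergence time."""
    rng = random.Random(seed)
    medians = sorted(
        statistics.median(rng.choices(times, k=len(times))) for _ in range(num_resamples)
    )
    lo = medians[math.floor((1 - level) / 2 * (num_resamples - 1))]
    hi = medians[math.ceil((1 + level) / 2 * (num_resamples - 1))]
    return lo, hi

times = [float(t) for t in range(1, 1001)]  # placeholder data, not results from the paper
lo, hi = bootstrap_median_ci(times)
```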

D.3 Implementation, hardware, and compute time

All training experiments were implemented using PyTorch (Paszke et al., 2019).

Although most of the networks considered in the main empirical results are relatively small, a large total number ($\sim 10^8$) of models were trained to certify the “robust space” of results and obtain precise scaling curves. These individual experiments were not large enough to benefit from GPU acceleration; on an internal cluster, the CPU compute expenditure totaled approximately $1500$ CPU hours.

A subset of these experiments stood to benefit from GPU acceleration: width $r \geq 100$ MLPs; scaling behaviors for $n \geq 100$; all experiments involving Transformers. These were performed with NVIDIA Tesla P100, Tesla P40, and RTX A6000 GPUs on an internal cluster, consuming a total of approximately $200$ GPU hours.