Hidden Progress in Deep Learning: SGD Learns Parities Near the Computational Limit
Boaz Barak, Benjamin L. Edelman, Surbhi Goel, Sham Kakade, Eran Malach, Cyril Zhang
Introduction
In deep learning, performance improvements are frequently observed upon simply scaling up resources (such as data, model size, and training time). While these improvements are often continuous in terms of these resources, some of the most surprising recent advances in the field have been emergent capabilities: at a certain threshold, behavior changes qualitatively and discontinuously. Through a statistical lens, it is well-understood that larger models, trained with more data, can fit more complex and expressive functions. However, far less is known about the analogous computational question: how does the scaling of these resources influence the success of gradient-based optimization?
These phase transitions cannot be explained via statistical capacity alone: they can appear even when the amount of data remains fixed, with only model size or training time increasing. A timely example is the emergence of reasoning and few-shot learning capabilities when scaling up language models (Radford et al., 2019, Brown et al., 2020, Chowdhery et al., 2022, Hoffmann et al., 2022); Srivastava et al. (2022) identify various tasks which language models are only able to solve if they are larger than a critical scale. Power et al. (2022) give examples of discontinuous improvements in population accuracy (“grokking”) when running time increases, while dataset and model sizes remain fixed.
In this work, we analyze the computational aspects of scaling in deep learning, in an elementary synthetic setting which already exhibits discontinuous improvements. Specifically, we consider the supervised learning problem of learning a sparse parity: the label is the parity (XOR) of $k$ bits in a random length-$n$ binary string. This problem is computationally difficult for a range of algorithms, including gradient-based (Kearns, 1998) and streaming (Kol et al., 2017) algorithms. We focus on analyzing the resource measure of training time, and demonstrate that the loss curves for sparse parities display a phase transition across a variety of architectures and hyperparameters (see Figure 1, left). Strikingly, we observe that SGD finds the sparse subset (and hence reaches 0 error) with a variety of activation functions and initialization schemes, even with no over-parameterization.
A natural hypothesis to explain SGD’s success in learning parities, with no visible progress in error and loss for most of training, would be that it simply “stumbles in the dark”, performing random search for the unknown target (e.g., via stochastic gradient Langevin dynamics). If that were the case, we might expect a convergence time comparable to that of a naive search over parameters or subsets of indices. However, Figure 1 (right) already provides some evidence against this “random search” hypothesis: the convergence time adapts to the sparsity parameter $k$, scaling as $n^{O(k)}$ on small instances. Notably, such a convergence rate implies that SGD is closer to achieving the optimal computation time among a natural class of algorithms (namely, statistical query algorithms).
Through an extensive empirical analysis of the scaling behavior of a variety of models, as well as theoretical analysis, we give strong evidence against the “stumbling in the dark” viewpoint. Instead, there is a hidden progress measure under which SGD is steadily improving. Furthermore, and perhaps surprisingly, we show that SGD achieves a computational runtime much closer to the optimal SQ lower bound than simply doing (non-sparse) parameter search. More generally, our investigations reveal a number of notable phenomena regarding the dependence of SGD’s performance on resources: we identify phase transitions when varying data, model size, and training time.
It is known from SQ lower bounds that with a constant noise level, gradient descent on any architecture requires at least $n^{\Omega(k)}$ computational steps to learn $k$-sparse $n$-dimensional parities (for background, see Appendix A). We first show a wide variety of positive empirical results, in which neural networks successfully solve the parity problem in a number of iterations which scales near this computational limit:
For all small instances of the sparse parity problem, architectures in {2-layer MLPs, Transformers (with a smaller range of hyperparameters), sinusoidal/oscillating neurons, PolyNets (a non-standard architecture introduced in this work; see Section 3 for the definition)}, initializations in {uniform, Gaussian, Bernoulli}, and a range of batch sizes, SGD solves the $(n,k)$-sparse parity problem (with constant probability) within at most $c\, n^{\alpha k}$ steps, for small constants $c, \alpha$.
Our empirical results suggest that, in a number of computational steps matching the SQ limit, SGD is able to solve the parity problem and identify the influential coordinates, without an explicit sparse prior. We give a theoretical analysis which validates this claim.
We also present a stronger analysis for an idealized architecture (which we call the disjoint-PolyNet), which allows for any batch size, and captures the phase transitions observed in the error curves.
On disjoint-PolyNets, SGD (with any batch size $B \ge 1$) converges with high probability to a solution with at most $\epsilon$ error on the $(n,k)$-parity problem in at most $n^{O(k)}$ iterations. Continuous-time gradient flow exhibits a phase transition: it spends most of its time before convergence with error close to that of random guessing.
Our theoretical and empirical results hold in non-overparameterized regimes (including with a width-1 sinusoidal neuron), in which no fixed kernel, including the neural tangent kernel (NTK) (Jacot et al., 2018), is sufficiently expressive to fit all sparse parities with a large margin. Thus, our findings comprise an elementary example of combinatorial feature learning: SGD can only successfully converge by learning a low-width sparse representation.
Building upon our core positive results, we provide a wide variety of preliminary experiments, showing sparse parity learning to be a versatile testbed for understanding the challenges and surprises in solving combinatorial problems with neural networks. These include quantities which reveal the continual hidden progress behind uninformative training curves (as predicted by the theory), experiments at small sample sizes which exhibit grokking (Power et al., 2022), as well as an example where greedy layer-wise learning is impossible but end-to-end SGD can learn the layers jointly.
2 Related work
We present the most directly related work on feature learning, and learning parities with neural nets. A broader discussion can be found in Appendix A.3.
Theoretical analysis of gradient descent on neural networks is notoriously hard, due to the non-convex nature of the optimization problem. That said, it has been established that in some settings, the dynamics of GD keep the weights close to their initialization, thus behaving like convex optimization over the Neural Tangent Kernel (see, for example, (Jacot et al., 2018, Allen-Zhu et al., 2019, Du et al., 2018)). In contrast, it has been shown that in various tasks, moving away from the fixed features of the NTK is essential for the success of neural networks trained with GD (for example (Yehudai and Shamir, 2019, Allen-Zhu and Li, 2019, Wei et al., 2019) and the review in (Malach et al., 2021)). These results demonstrate that feature learning is an important part of the GD optimization process. Our work also focuses on a setting where feature learning is essential for the target task. In our theoretical analysis, we show that the initial population gradient encodes the relevant features for the problem. The importance of the first gradient step for feature learning has been recently studied in (Ba et al., 2022).
The problem of learning parities using neural networks has been investigated in prior works from various perspectives. It has been demonstrated that parities are hard for gradient-based algorithms, using similar arguments as in the SQ analysis (Shalev-Shwartz et al., 2017, Abbe and Sandon, 2020). One possible approach for overcoming the computational hardness is to make favorable assumptions on the input distribution. Indeed, recent works show that under various assumptions on the input distribution, neural networks can be efficiently trained to learn parities (XORs) (Daniely and Malach, 2020, Shi et al., 2021, Frei et al., 2022, Malach et al., 2021). In contrast to these results, this work takes the approach of intentionally focusing on a hard benchmark task, without assuming that the distribution has some favorable (namely, non-uniform) structure. This setting allows us to probe the performance of deep learning at a known computational limit. Notably, the work of Andoni et al. (2014) provides analysis for learning polynomials (and in particular, parities) under the uniform distribution. However, their main results require an extremely overparameterized network, and they provide only partial theoretical and empirical evidence for the success of smaller networks. Studying a related subject, some works have shown that neural networks display a spectral bias, learning to fit low-frequency coefficients before high-frequency ones (Rahaman et al., 2019, Cao et al., 2019).
Preliminaries
We provide an expanded discussion of background and related work in Appendix A.
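A key fact about parities is that they are orthogonal under the correlation inner product: writing $\chi_S(x) = \prod_{i \in S} x_i$ for $S \subseteq [n]$, for all $S, S' \subseteq [n]$,
$$\mathbb{E}_{x \sim \mathrm{Unif}(\{\pm 1\}^n)}\left[\chi_S(x)\,\chi_{S'}(x)\right] = \mathbb{1}[S = S'].$$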
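For small $n$ the orthogonality relation can be checked exactly by enumerating all $2^n$ points of the hypercube. This is a minimal sketch under our own conventions (inputs in $\{\pm 1\}^n$, hypothetical helper names `parity` and `gram_matrix`); it also shows that two distinct $k$-subsets are uncorrelated even when they share $k-1$ indices.

```python
import itertools

import numpy as np


def parity(X, S):
    """chi_S(x) = prod_{i in S} x_i, evaluated on every row x of X (entries in {-1, +1})."""
    return np.prod(X[:, list(S)], axis=1)


def gram_matrix(n, k):
    """Exact matrix of E_x[chi_S(x) chi_S'(x)] over all k-subsets S, S' of {0, ..., n-1}."""
    # All 2^n points of {-1, +1}^n; under the uniform distribution each has mass 2^-n.
    X = np.array(list(itertools.product([-1, 1], repeat=n)))
    subsets = list(itertools.combinations(range(n), k))
    chis = np.stack([parity(X, S) for S in subsets])  # shape (num_subsets, 2^n)
    return chis @ chis.T / len(X)


G = gram_matrix(n=6, k=3)
print(np.allclose(G, np.eye(len(G))))  # True: distinct parities are orthogonal
```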
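Our main results are presented in the online learning setting, with a stream of i.i.d. batches of examples. At each iteration $t$, a learning algorithm $\mathcal{A}$ receives a batch of $B$ examples drawn i.i.d. from $\mathcal{D}_S$, the distribution of $(x, \prod_{i \in S} x_i)$ with $x$ uniform on $\{\pm 1\}^n$, then outputs a classifier $f_t$. We say that $\mathcal{A}$ solves the parity task in $T$ steps (with error $\epsilon$) if
$$\Pr_{(x,y) \sim \mathcal{D}_S}\left[f_T(x) \neq y\right] \le \epsilon.$$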
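The online protocol can be sketched as follows. This is an illustrative sketch rather than the paper's code: `sample_batch` and `error` are hypothetical helpers, the classifier is any callable mapping a batch of inputs to labels in $\{\pm 1\}$, and the error probability is estimated by Monte Carlo on a fresh held-out batch instead of being computed exactly.

```python
import numpy as np


def sample_batch(rng, n, S, B):
    """Draw B i.i.d. examples (x, y): x uniform on {-1, +1}^n and y = chi_S(x)."""
    X = rng.choice([-1, 1], size=(B, n))
    y = np.prod(X[:, list(S)], axis=1)
    return X, y


def error(f, rng, n, S, B_eval=4096):
    """Monte Carlo estimate of Pr[f(x) != y] using a fresh batch of B_eval examples."""
    X, y = sample_batch(rng, n, S, B_eval)
    return float(np.mean(f(X) != y))


rng = np.random.default_rng(0)
n, S = 20, (0, 1, 2)
print(error(lambda X: np.prod(X[:, [0, 1, 2]], axis=1), rng, n, S))  # 0.0 (the target itself)
print(error(lambda X: np.ones(len(X), dtype=int), rng, n, S))  # close to 0.5 (constant guess)
```

In the online setting, each SGD step would call `sample_batch` once, so the number of steps and the number of fresh samples grow together.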
Empirical findings
The central phenomenon of study in this work is the empirical observation that neural networks, with standard initialization and training, can solve the $(n,k)$-parity problem in a number of iterations scaling as $n^{O(k)}$ on small instances. We observed robust positive results for randomly-initialized SGD on the following architectures, indexed by Roman numerals:
2-layer MLPs: ReLU or polynomial activations, in a wide variety of width regimes. Settings (i), (ii), (iii) (resp. (iv), (v), (vi)) use ReLU (resp. polynomial) activations. We also consider the minimum width for representing a $k$-wise parity with each of these activations (exceptional settings (*i), (*ii)).
1-neuron networks: Next, we consider non-standard activation functions which allow a one-neuron architecture to realize $k$-wise parities. The constructions stem from letting the weights select the coordinates in $S$, so that the pre-activation depends only on how many relevant bits equal $-1$, and constructing the activation to interpolate (the appropriate scaling of) the parity as a function of this pre-activation, with a piecewise linear $k$-zigzag activation (vii) or a degree-$k$ polynomial (viii). Going a step further, a single zigzag (ix) or sinusoidal (x) neuron can represent all $k$-wise parities. In settings (xi), (xii), (xiii), (xiv), we remove the second trainable layer (fixing the output weight). We find that wider architectures with these activations also train successfully.
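One way to see that a single neuron suffices is to take the weight vector to be the indicator $\mathbb{1}_S$, so that the pre-activation $z = \langle \mathbb{1}_S, x \rangle = k - 2m$ is determined by the number $m$ of coordinates in $S$ equal to $-1$, while $\chi_S(x) = (-1)^m$. A sinusoid $\cos(\pi (z - k)/2)$, or a piecewise-linear zigzag through the points $(z, (-1)^{(k-z)/2})$, then reproduces the parity exactly. This is a minimal sketch under those assumptions (no bias, no output weight, hypothetical function names); the exact parameterizations used in settings (vii) to (x) may differ.

```python
import itertools

import numpy as np


def sinusoid_neuron(X, w, k):
    """Single neuron with activation cos(pi * (z - k) / 2) applied to z = <w, x>."""
    z = X @ w
    return np.cos(np.pi * (z - k) / 2)


def zigzag_neuron(X, w, k):
    """Single neuron with a piecewise-linear zigzag through (z, (-1)^((k - z) / 2))
    for z in {-k, -k + 2, ..., k}; np.interp joins these knots linearly."""
    knots = np.arange(-k, k + 1, 2)
    values = (-1.0) ** ((k - knots) // 2)
    return np.interp(X @ w, knots, values)


def check_all_subsets(n, k):
    """On all of {-1, +1}^n, check that both neurons with w = 1_S equal chi_S for every k-subset S."""
    X = np.array(list(itertools.product([-1, 1], repeat=n)))
    for S in itertools.combinations(range(n), k):
        w = np.zeros(n)
        w[list(S)] = 1.0  # w is the indicator vector of S
        chi = np.prod(X[:, list(S)], axis=1)
        if not (np.allclose(sinusoid_neuron(X, w, k), chi)
                and np.allclose(zigzag_neuron(X, w, k), chi)):
            return False
    return True


print(check_all_subsets(n=6, k=3))  # True
```

The activation depends only on $k$, and the choice of $w$ selects which $k$-subset is computed, which is the sense in which one activation covers all $k$-wise parities.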
Transformers: There is growing interest in using parity as a benchmark for combinatorial function learning, long-range dependency learning, and length generalization in Transformers (Lu et al., 2021, Edelman et al., 2021, Hahn, 2020, Anil et al., 2022, Liu et al., 2022). Motivated by these recent theoretical and empirical works, we consider a simplified specialization of the Transformer architecture to this sequence classification problem. This is the less-robust setting (*iii); the architecture and optimizer are described in Appendix D.1.3.
PolyNets: Our final setting (xv) is the PolyNet, a slightly modified version of the parity machine architecture. Parity machines have been studied extensively in the statistical mechanics of ML literature (see the related work section) as well as in a line of work on ‘neural cryptography’ (Rosen-Zvi et al., 2002). A parity machine outputs the sign of the product of $k$ linear functions of the input. A PolyNet simply outputs the product itself. Both architectures can clearly realize $k$-sparse parities. The PolyNet architecture was originally motivated by the search for an idealized setting where an end-to-end optimization trajectory analysis is tractable (see Section 4.1); we found in these experiments that this architecture trains very stably and sample-efficiently.
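The two architectures differ only in the final sign. The sketch below uses our own minimal parameterization (the $k$ linear functions are the rows of a matrix `W`, with no biases, which may not match the paper's exact setup) and checks that choosing $w_i = e_{s_i}$, the standard basis vector of the $i$-th index $s_i \in S$, makes both realize $\chi_S$.

```python
import numpy as np


def polynet(X, W):
    """PolyNet: the product of the k linear functions <w_i, x> (w_i = row i of W)."""
    return np.prod(X @ W.T, axis=1)


def parity_machine(X, W):
    """Parity machine: the sign of the same product."""
    return np.sign(polynet(X, W))


rng = np.random.default_rng(0)
n, S = 12, [1, 4, 9]
X = rng.choice([-1.0, 1.0], size=(1000, n))
W = np.eye(n)[S]  # row i is the standard basis vector e_{S[i]}
chi = np.prod(X[:, S], axis=1)
print(np.array_equal(polynet(X, W), chi))  # True
```

Unlike the parity machine, whose output is piecewise constant in `W`, the PolyNet output is a polynomial in `W`, so its gradients are informative everywhere.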
All of the networks listed above were observed to successfully learn sparse parities in a variety of settings. We summarize our findings as follows: for all tested combinations of $n$ and $k$, batch sizes, initializations (uniform, Gaussian, Bernoulli), loss functions (hinge, square, cross entropy), and architecture configurations, SGD solved the parity problem (validated on a held-out batch of samples) in at least a fixed fraction of random trials, for at least one choice of learning rate. The models converged in at most $c\, n^{\alpha k}$ steps, for small architecture-dependent constants $c, \alpha$ (see Appendix C). Figure 1 (left) shows some representative training curves.
Settings (*i) and (*ii), where the MLP just barely represents a $k$-sparse parity, and the Transformer setting (*iii), are less robust to small batch sizes. In these settings, the same positive results as above only held for sufficiently large batch sizes. Also, setting (*iii) used the Adam optimizer (which is standard for Transformers); see Appendix D.1.3 for details.
For almost all of the architectures, we find that the training curves exhibit phase transitions in terms of running time (and thus, in the online learning setting, dataset size as well): long durations of seemingly no progress, followed by periods of rapid decrease in the validation error. Strikingly, for architectures (v) and (vi), this plateau is absent: the error in the initial phase appears to decrease with a linear slope. See Appendix C.8 for more plots.
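For concreteness, the training protocol described above can be sketched as online SGD on a 2-layer ReLU MLP with the hinge loss, drawing a fresh batch at every step and tracking error on a held-out batch. All hyperparameters here (width, learning rate, uniform initialization scaled by fan-in, evaluation schedule) and the function name are illustrative assumptions rather than the paper's settings, which are given in Appendices C and D; we make no claim about how many steps this particular configuration needs to leave its plateau.

```python
import numpy as np


def train_mlp_sparse_parity(n=20, k=3, width=100, batch=32, lr=0.1,
                            steps=2000, eval_every=100, eval_size=4096, seed=0):
    """Online SGD on f(x) = sum_j a_j relu(<w_j, x> + b_j) with hinge loss max(0, 1 - y f(x)).

    Returns held-out errors measured at step 0 and every `eval_every` steps."""
    rng = np.random.default_rng(seed)
    S = list(range(k))  # WLOG (permutation symmetry) the relevant bits come first
    W = rng.uniform(-1, 1, size=(width, n)) / np.sqrt(n)
    b = rng.uniform(-1, 1, size=width) / np.sqrt(n)
    a = rng.uniform(-1, 1, size=width) / np.sqrt(width)

    def sample(B):
        X = rng.choice([-1.0, 1.0], size=(B, n))
        return X, np.prod(X[:, S], axis=1)

    def forward(X):
        H = np.maximum(X @ W.T + b, 0.0)  # hidden activations, shape (B, width)
        return H, H @ a                   # outputs, shape (B,)

    X_val, y_val = sample(eval_size)

    def val_error():
        return float(np.mean(np.sign(forward(X_val)[1]) != y_val))

    errors = [val_error()]
    for t in range(1, steps + 1):
        X, y = sample(batch)  # fresh i.i.d. batch: the online setting
        H, out = forward(X)
        # d(hinge)/d(out) is -y where the margin y * out is below 1, and 0 otherwise.
        g = np.where(y * out < 1.0, -y, 0.0) / batch
        G = np.outer(g, a) * (H > 0)  # backprop through the ReLU (uses the old a), shape (B, width)
        a -= lr * (H.T @ g)
        W -= lr * (G.T @ X)
        b -= lr * G.sum(axis=0)
        if t % eval_every == 0:
            errors.append(val_error())
    return errors
```

Plotting the returned list against the step index gives a curve of the kind shown in Figure 1 (left); the initial entry should sit near 0.5, since a randomly initialized network is essentially uncorrelated with the target parity.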
2 Random search or hidden progress?
The remainder of this paper seeks to answer the question: “By what mechanism does deep learning solve these emblematic computationally-hard optimization problems?”
A natural hypothesis would be that SGD somehow implicitly performs Monte Carlo random search, “bouncing around” the loss landscape in the absence of a useful gradient signal. Upon closer inspection, several empirical observations clash with this hypothesis:
Scaling of convergence times: Without an explicit sparsity prior in the architecture or initialization, it is unclear how to account for the runtimes observed in experiments, which adapt to the sparsity $k$. The initializations, which certainly do not prefer sparse functions, are close to the correct solutions with only a vanishingly small probability. (Indeed, under all standard architectures and initializations, the probability that a random network is noticeably correlated with a sparse parity would be vanishingly small, since with high probability most of the total influence would be accounted for by the irrelevant features.)
Sensitivity to initialization, not SGD samples: Running these training setups over multiple stochastic batches from a common initialization, we find that loss curves and convergence times are highly correlated with the architecture’s random initialization, and are quite concentrated conditioned on initialization; see Figure 2 (center).
Elbows in the scaling curves: For larger instances, the power-law scaling ceases to hold: the exponent worsens (see Figure 2 (right), as well as the discussion in Appendix C.2). This would not be true for random exhaustive search.
Even these observations, which do not probe the internal state of the algorithm, suggest that exhaustive search is an insufficient picture of the training dynamics, and a different mechanism is at play.
Theoretical analyses
We now provide a theoretical account for the success of SGD in solving the -parity problem. Our main theoretical observation is that, in many cases, the population gradient of the weights at initialization contains enough “information” for solving the parity problem. That is, given an accurate enough estimate of the initial gradient (by e.g. computing the gradient over a large enough batch size), the relevant subset can be found.
Carefully extending this insight, we obtain an end-to-end convergence result for ReLU-activation MLP networks with a particular symmetric choice of initialization, trained with the hinge loss:
This does not capture the full range of settings in which we empirically observe successful convergence. First, it requires a sign vector initialization, while we observe convergence with other random initialization schemes (namely, uniform and Gaussian). Second, it requires the batch size to scale with the problem size (in fact, at this batch size, the correct parity indices emerge in a single SGD step), while we also obtain positive results when the batch size is small. Analogous statements for these cases (as well as other activations and losses) would require Fourier gaps for population gradient functions other than majority; lower bounds on the degree-$(k-1)$ coefficients (“Fourier anti-concentration”) are particularly elusive in the literature, and we leave it as an open challenge to establish them in more general settings. We provide preliminary empirics in Appendix C.1, suggesting that the Fourier gaps in our empirical settings are sufficiently large. (Interestingly, we observe that the Fourier gap tends to increase over the course of training; this is not captured by our current theoretical analysis.)
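To make the notion of a Fourier gap concrete, the sketch below computes, exactly by enumeration, the population gradient at initialization in the simplest case we can check by hand: a single neuron $\mathrm{relu}(\langle w, x\rangle)$ with the correlation loss $-\mathbb{E}[y\,\mathrm{relu}(\langle w, x\rangle)]$ (a simplification we adopt here), sign-vector initialization $w = (1, \ldots, 1)$, and odd $n$. Then $\mathrm{relu}'(\langle w, x\rangle) = (1 + \mathrm{Maj}(x))/2$, so the gradient entry for $j \in S$ is (up to sign and a factor of $1/2$) a degree-$(k-1)$ Fourier coefficient of majority, while for $j \notin S$ it is a degree-$(k+1)$ coefficient. We use $k = 2$ because majority has no nonzero even-degree coefficients, so with this particular activation and initialization the gap only appears for even $k$. The function name is hypothetical.

```python
import itertools

import numpy as np


def relu_population_gradient(n, S, w):
    """Exact gradient of L(w) = -E[y * relu(<w, x>)], with y = chi_S(x) and x uniform
    on {-1, +1}^n, computed by enumerating all 2^n inputs (small n only)."""
    X = np.array(list(itertools.product([-1.0, 1.0], repeat=n)))
    y = np.prod(X[:, list(S)], axis=1)
    active = (X @ w > 0).astype(float)  # relu'(<w, x>)
    return -np.mean(X * (y * active)[:, None], axis=0)  # dL/dw_j = -E[y x_j relu'(<w, x>)]


n, S = 11, (0, 1)  # odd n, so <w, x> is never 0 for a sign vector w
g = relu_population_gradient(n, S, np.ones(n))
# |g_j| = 252/2048 (about 0.123) for j in S, but only 28/2048 (about 0.0137) for j not in S
print(np.abs(g[:2]).min() > np.abs(g[2:]).max())  # True
```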
We note that in the low-width (non-overparameterized) regimes considered in this work, no fixed kernel (including the neural tangent kernel (Jacot et al., 2018), whose dimensionality is the network’s parameter count) can solve the sparse parity problem. The following is a consequence of results in (Kamath et al., 2020, Malach and Shalev-Shwartz, 2022):
Thus, our low-width results lie outside the NTK regime, which requires far larger models (of size $n^{\Omega(k)}$) to express parities. However, we note that better sample complexity bounds are possible in the NTK regime, with an algorithm more similar to standard SGD (see (Telgarsky, 2022) and Appendix A.3).
2 Disjoint-PolyNet: exact trajectory analysis for an idealized architecture
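In this section, we present an architecture (a version of PolyNets (xv)) which empirically exhibits similar behavior to MLPs and bypasses the difficulty of analyzing Fourier gaps. The disjoint-PolyNet takes a product over linear functions of an equal-sized partition of the input coordinates (we assume for simplicity that $n$ is divisible by $k$): $f(x) = \prod_{i=1}^{k} \langle w_i, x_{P_i} \rangle$, where $P_1, \ldots, P_k$ partition $[n]$ into blocks of size $n/k$ and $x_{P_i}$ denotes the restriction of $x$ to $P_i$. As noted in Section 1.2, this is equivalent to a tree parity machine with real-valued (instead of $\pm 1$) outputs.
This architecture also requires us to assume that the set $S$ of size $k$ in the $(n,k)$-parity problem is selected such that exactly one index belongs to each disjoint partition, that is, $|S \cap P_i| = 1$ for all $i \in [k]$. We refer to this problem as the $(n,k)$-disjoint parity problem. Note that there are still $(n/k)^k$ different possibilities for $S$ under this restriction. For fixed $k$, these represent a constant fraction (tending to $k!/k^k \approx \sqrt{2\pi k}\, e^{-k}$ as $n \to \infty$, by Stirling's approximation) of the $\binom{n}{k}$ possibilities for $S$ in the general non-disjoint case.
Consider training a disjoint-PolyNet w.r.t. the correlation loss $L(w) = -\mathbb{E}[y f(x)]$. Without loss of generality, assume that each relevant coordinate in $S$ is the first element of its partition $P_i$, and write $w_{i,j}$ for the $j$-th entry of $w_i$. Since $y = \prod_{l} x_{P_l,1}$, the expectation $\mathbb{E}[y\, x_{P_i,j} \prod_{l \neq i} x_{P_l,j_l}]$ equals 1 when $j = 1$ and every $j_l = 1$, and 0 otherwise. Hence the population gradient is non-zero only at the indices $(i, 1)$:
$$\frac{\partial L}{\partial w_{i,j}} = -\mathbb{1}[j = 1] \prod_{l \neq i} w_{l,1}.$$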
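The counting claim is easy to check numerically. The helper below (a hypothetical name) computes the exact fraction $(n/k)^k / \binom{n}{k}$ of $k$-subsets that hit every block exactly once; it decreases toward $k!/k^k$ as $n$ grows, since $\binom{n}{k} \sim n^k/k!$.

```python
from math import comb, factorial


def disjoint_fraction(n, k):
    """Fraction of the k-subsets of [n] with exactly one element in each of k equal blocks."""
    if n % k != 0:
        raise ValueError("n must be divisible by k")
    return (n // k) ** k / comb(n, k)


for n in (12, 120, 1200):
    print(n, round(disjoint_fraction(n, 3), 4))
print(factorial(3) / 3 ** 3)  # the n -> infinity limit k!/k^k for k = 3, about 0.2222
```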
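The gradient formula can be verified by brute force on a small instance. In this sketch (our own construction, with hypothetical names), block $P_i$ holds coordinates $im, \ldots, (i+1)m - 1$ with $m = n/k$, the relevant coordinate of each block is its first one (index 0 in the code, index 1 in the math), and the population gradient of the correlation loss $L(W) = -\mathbb{E}[y f(x)]$ is computed exactly over all $2^n$ inputs. It should equal $-\prod_{l \ne i} w_{l,1}$ at the relevant coordinate of block $i$ and $0$ everywhere else.

```python
import itertools

import numpy as np


def disjoint_polynet(X, W):
    """f(x) = prod_i <w_i, x_{P_i}>: row i of W (length m) acts on block P_i of x."""
    k, m = W.shape
    blocks = X.reshape(len(X), k, m)  # blocks[:, i, :] is x restricted to P_i
    return np.prod(np.einsum("bim,im->bi", blocks, W), axis=1)


def population_gradient(W):
    """Exact gradient of L(W) = -E[y f(x)], where y is the parity of the first
    coordinate of every block, by enumerating all 2^(k*m) inputs."""
    k, m = W.shape
    X = np.array(list(itertools.product([-1.0, 1.0], repeat=k * m)))
    blocks = X.reshape(len(X), k, m)
    y = np.prod(blocks[:, :, 0], axis=1)
    lin = np.einsum("bim,im->bi", blocks, W)  # <w_i, x_{P_i}> for every input
    grad = np.zeros_like(W)
    for i in range(k):
        others = np.prod(np.delete(lin, i, axis=1), axis=1)  # product over blocks l != i
        grad[i] = -np.mean((y * others)[:, None] * blocks[:, i, :], axis=0)
    return grad


rng = np.random.default_rng(0)
W = rng.normal(size=(3, 3))  # k = 3 blocks of size m = 3, so n = 9
expected = np.zeros_like(W)
for i in range(3):
    expected[i, 0] = -np.prod(np.delete(W[:, 0], i))
print(np.allclose(population_gradient(W), expected))  # True
```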
Suppose . Let denote the smallest time at which the error is at most . Then,
Informally, the network takes much longer to reach slightly-better-than-trivial accuracy than it takes to go from slightly better than trivial to perfect accuracy. Returning to discrete time, we also analyze the trajectory of disjoint-PolyNets trained with online SGD at any batch size, confirming that a neural network can learn $k$-sparse disjoint parities within $n^{O(k)}$ iterations.
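A back-of-the-envelope version of this phase transition follows from the gradient of the correlation loss, $-\partial L / \partial w_{i,1} = \prod_{l \ne i} w_{l,1}$. Assuming, for illustration only, that all relevant weights start at the same value $u_0 > 0$ and stay equal, gradient flow reduces to $\dot u = u^{k-1}$, whose solution for $k \ge 3$ is $u(t) = (u_0^{2-k} - (k-2)t)^{-1/(k-2)}$. It blows up at $T^* = u_0^{2-k}/(k-2)$, and the fraction of $T^*$ spent before $u$ reaches $c\,u_0$ is $1 - c^{2-k}$. The sketch checks the closed form against forward-Euler integration. It ignores the irrelevant weights and any normalization, so it only illustrates why most of the time is spent while the relevant weights are still near their initial scale; it is not the formal statement of Appendix B.3.

```python
def time_to_reach(u0, c, k):
    """Closed-form time for du/dt = u**(k - 1) (k >= 3, u0 > 0) to grow from u0 to c * u0."""
    return (u0 ** (2 - k) - (c * u0) ** (2 - k)) / (k - 2)


def euler_time_to_reach(u0, c, k, dt=1e-3):
    """The same quantity by forward-Euler integration, as a numerical cross-check."""
    u, t = u0, 0.0
    while u < c * u0:
        u += dt * u ** (k - 1)
        t += dt
    return t


u0, k = 0.1, 4
blowup = u0 ** (2 - k) / (k - 2)  # finite-time blow-up T*
for c in (2, 10, 100):
    # Fraction of T* spent before u reaches c * u0: 1 - c**(2 - k), about 0.75, 0.99, 0.9999.
    print(c, time_to_reach(u0, c, k) / blowup)
```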
Extended versions of these theorems, along with proofs, can be found in Appendix B.3.
Hidden progress: discussion and additional experiments
So far, we have shown that sparse parity learning provides an idealized setting in which neural networks successfully learn sparse combinatorial features, with a mechanism of continual progress hiding behind discontinuous training curves. In this section, we outline preliminary explorations on a broader range of interesting phenomena which arise in this setting. Details are provided in Appendix C, while more systematic investigations are deferred to future work.
The theoretical and (black-box) empirical results suggest that SGD does not learn parities via the memoryless process of random exhaustive search. This suggests the existence of progress measures: scalar quantities which are functions of the training algorithm’s state (i.e., the model weights) and are predictive of the time to successful convergence. We provide some white-box investigations which further support the hypothesis of hidden progress, by examining the gradual improvement in quantities other than the training loss. In Appendix C.1, we directly plot the Fourier gaps of the population gradient, finding that they are large (within a small constant factor of those of majority) in practice. In Figure 3 and Appendix C.3, we examine the weight movement norm $\|w_t - w_0\|$ to reveal hidden progress, motivated by the fact that $w_t - w_0$ is a linearized estimate for the initial population gradient.
An interesting consequence of our analysis is that it illuminates scaling behaviors with respect to a third fundamental resource parameter: model size, which we study in terms of network width $r$. If SGD operated by a “random search” mechanism, one would expect width to provide a parallel speedup. Instead, we find that SGD sequentially amplifies progress. The sharp lower tails in Figure 2 (left) imply that running identical copies of SGD does not give speedups; more directly, in Appendix C.4 (previewed in Figure 4 (left)), we find that convergence times for sparse parities empirically plateau at large model sizes.
Our main results are presented in the online learning setting (fresh minibatches from the data distribution at each iteration). While this mitigates the confounding factor of overfitting, it couples the resources of training time and independent samples in a suboptimal way, due to the computational-statistical gap for parity learning. In Appendix C.5, we find empirically that minibatch SGD (with weight decay) can learn sparse parities even with smaller sample sizes. We reliably observe the grokking phenomenon (Power et al., 2022): an initial overfitting phase, then a delayed phase transition in the generalization error; see the two center panels of Figure 4 (right). These results complement and corroborate the findings of Nanda and Lieberum (2022), who analyze the hidden progress of Transformers trained on arithmetic tasks (a setting which also exhibits grokking).
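As a toy illustration of weight movement as a progress measure, the sketch below runs online SGD on a single ReLU neuron (correlation loss, sign-vector initialization $w_0 = (1, \ldots, 1)$, $n = 11$, and $S = \{0, 1\}$ in 0-indexed code) and reports $|w_t - w_0|$ coordinate by coordinate. These choices and the function name are ours, not the paper's experimental setup. With this small learning rate, $w_t - w_0$ stays close to $-\eta t$ times the initial population gradient, so the relevant coordinates move several times farther than the irrelevant ones.

```python
import numpy as np


def weight_movement(n=11, S=(0, 1), steps=20, batch=8192, lr=0.05, seed=0):
    """Online SGD on a single neuron relu(<w, x>) with correlation loss -y * relu(<w, x>).

    Returns |w_t - w_0| for each coordinate after `steps` updates."""
    rng = np.random.default_rng(seed)
    w0 = np.ones(n)  # sign-vector initialization (illustrative choice)
    w = w0.copy()
    for _ in range(steps):
        X = rng.choice([-1.0, 1.0], size=(batch, n))
        y = np.prod(X[:, list(S)], axis=1)
        active = (X @ w > 0).astype(float)  # relu'(<w, x>)
        grad = -np.mean(X * (y * active)[:, None], axis=0)  # minibatch gradient
        w -= lr * grad
    return np.abs(w - w0)


move = weight_movement()
print(np.round(move, 3))  # coordinates 0 and 1 move roughly 0.12; the rest move far less
```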
It is a significant challenge (and generally outside the scope of this work) to understand the interactions between network depth and computational/statistical efficiency. In Appendix C.7, we show that learning parities with deeper polynomial-activation MLPs comprises a simple counterexample to the “deep only works if shallow is good” principle of Malach and Shalev-Shwartz (2019): a deep network can get near-perfect accuracy, even when greedy layer-wise training (e.g. (Belilovsky et al., 2019)) cannot beat trivial performance. By providing positive theory and empirics which elude these simplified explanations of SGD, we hope to point the way to a more complete understanding of learning dynamics in the challenging cases where no apparent progress is made for extended periods of time.
Conclusion
This work puts forward sparse parity learning as an elementary test case to explore the puzzling features of the role of computational (as opposed to statistical) resources in deep learning. In particular, we have shown that a variety of neural architectures solve this combinatorial search problem, with a number of computational steps nearly matching the sparsity-dependent SQ lower bound. Furthermore, we have shown that despite abrupt phase transitions in the loss and accuracy curves, SGD works by gradually amplifying the sparse features “under the hood”.
Even in this simple setting, there are several open experimental and theoretical questions. This work largely focuses on the online learning case, which couples training iterations with fresh i.i.d. samples. We believe it would be instructive to investigate parity learning when the three resources of samples, time, and model size are scaled separately. Some very preliminary findings along these lines are presented in Section 3. It is an open problem to extend our theoretical results to the small-batch setting, as well as to the full range of architectures and losses in our experiments. Resolving these questions would require a better understanding of the anti-concentration behavior of Boolean Fourier coefficients, which is much less studied than the analogous concentration phenomena.
Another important follow-up direction is understanding the extent to which these insights extend from parity learning to more complex (including real-world) combinatorial problem settings, as well as the extent to which non-synthetic tasks (in, e.g., natural language processing and program synthesis) embed within them parity-like subtasks of exhaustive combinatorial search. We hope that our results will lead to further progress towards understanding and improving the optimization dynamics behind the recent slew of dramatic empirical successes of deep learning in these types of domains.
This work seeks to contribute to the foundational understanding of computational scaling behaviors in deep learning. Our theoretical and empirical analyses are in a heavily-idealized synthetic problem setting. Hence, we see no direct societal impacts of the results in this study.
We would like to thank Lenka Zdeborová for providing us with references to the statistical physics literature on phase transitions in the learning curves of neural networks, and Matus Telgarsky for bringing to our attention the better sample complexity guarantees of 2-sparse parity learning in the NTK regime. Sham Kakade acknowledges funding from the Office of Naval Research under award N00014-22-1-2377.
References
Appendix
This work leverages the parity problem as a “computationally hard case” for identifying the features which are relevant to the label. Observe that for any $S' \subseteq [n]$ with $S' \neq S$, writing $\chi_{S'}(x) = \prod_{i \in S'} x_i$ and $y = \chi_S(x)$, it holds that
$$\mathbb{E}_{x \sim \mathrm{Unif}(\{\pm 1\}^n)}\left[\chi_{S'}(x)\, y\right] = 0.$$
That is, a learner who guesses indices $S'$ cannot use correlations as feedback to reveal which (or how many) indices in $S'$ are correct, unless $S'$ is exactly the correct subset. In this sense, for the $(n,k)$-parity problem, the wrong answers are indistinguishable from each other. Thus, the structure of this problem forces this form of learner (but not necessarily all learning algorithms) to perform exhaustive search over subsets.
It should be mentioned that the $(n,k)$-parity problem can be solved efficiently by a learning algorithm that has access to examples (i.e., an algorithm that does not operate in the SQ framework). Specifically, this problem can be solved by the Gaussian elimination algorithm. Moreover, it has been shown that the (Stochastic) Gradient Descent algorithm, discussed in the next section, can also be utilized for solving parities, given accurate enough estimates of the gradient and a very particular choice of neural network architecture (Abbe and Sandon, 2020). That said, when the accuracy of the gradients is not sufficient, GD suffers from the same SQ lower bound mentioned above (i.e., GD is essentially an SQ algorithm (Abbe et al., 2021)).
Learning sparse noisy parities, even at a very small noise level, is believed to be computationally hard. This was first explicitly conjectured by Alekhnovich (2003), and has been the basis for several cryptographic schemes (e.g., (Ishai et al., 2008, Applebaum et al., 2009, 2010)). For noiseless sparse parities, Kol et al. (2017) show time-space hardness under a condition on $k$ relative to $n$. We present some experiments with noisy parities in Appendix C.6, finding that our empirical results (and theoretical analysis) are robust to noise.
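For comparison, here is a minimal sketch of the Gaussian elimination approach for the noiseless problem; it is not an SQ algorithm, since it manipulates individual examples. Writing each $x_i = -1$ as the bit 1 and $x_i = +1$ as the bit 0 turns $\chi_S(x) = -1$ into "an odd number of 1-bits inside $S$", so the indicator vector of $S$ solves a linear system over $\mathrm{GF}(2)$. The function name is hypothetical, the sketch assumes the examples determine $S$ uniquely (which holds with high probability once the number of examples exceeds $n$ by a margin), and it does not handle label noise.

```python
import numpy as np


def learn_parity_gf2(X, y):
    """Recover S from noiseless examples (x, chi_S(x)) by Gauss-Jordan elimination mod 2."""
    B = (X < 0).astype(np.uint8)  # bit 1 where x_i = -1
    c = (y < 0).astype(np.uint8)  # bit 1 where chi_S(x) = -1
    A = np.concatenate([B, c[:, None]], axis=1)  # augmented matrix [B | c]
    m, n = B.shape
    pivots, row = [], 0
    for col in range(n):
        if row == m:
            break
        candidates = np.nonzero(A[row:, col])[0]
        if len(candidates) == 0:
            continue  # no pivot in this column: s[col] is a free variable
        p = row + candidates[0]
        A[[row, p]] = A[[p, row]]  # move the pivot row into place
        others = np.nonzero(A[:, col])[0]
        others = others[others != row]
        A[others] ^= A[row]  # clear this column in every other row (mod 2)
        pivots.append(col)
        row += 1
    s = np.zeros(n, dtype=np.uint8)  # free variables, if any, are set to 0
    for r, col in enumerate(pivots):
        s[col] = A[r, n]
    return sorted(np.nonzero(s)[0].tolist())


rng = np.random.default_rng(0)
n, S = 30, [3, 11, 17]
X = rng.choice([-1, 1], size=(3 * n, n))
y = np.prod(X[:, S], axis=1)
print(learn_parity_gf2(X, y))  # [3, 11, 17] (unique solution w.h.p. with 3n examples)
```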
A.2 Neural networks and standard training
where the activation $\sigma$ is applied entrywise. It is standard to use GD to jointly update the network’s parameters. Our results include positive results about “single neurons”: MLPs with width $r = 1$. We note that for our theoretical analysis, when training MLPs with GD, we allow for different learning rates and weight decay schedules for different layers.
Finally, we will analyze randomized learning algorithms, such as GD with random initialization $w_0$, whose iterates $w_t$ (and thus classifiers $f_t$) are random variables even when the samples are not. A learning algorithm has permutation symmetry if, for all sequences of data $(x_1, y_1), (x_2, y_2), \ldots$, the classifiers resulting from feeding the permuted data $(\pi(x_1), y_1), (\pi(x_2), y_2), \ldots$ to the learner, composed with $\pi$, have identical distributions as $\pi$ ranges over all permutations of indices. The neural architectures and initializations (and thus SGD) considered in this work are permutation-symmetric; for this reason, it is convenient for notation to choose $S = [k] = \{1, \ldots, k\}$ as the canonical $(n,k)$-parity learning problem, without loss of generality.
A.3 Additional related work
A line of recent work has focused on understanding the feature learning ability of gradient descent dynamics on neural networks. These analyses go beyond the Neural Tangent Kernel (NTK) regime and show a separation between learning with fixed features and GD on neural networks for these problems. Several of these works assume structure (often sparse) in the input data which is useful for the prediction task and helps avoid computational hardness. In contrast, our work focuses on studying hard problems at their computational limit. Here we discuss the most relevant works in detail:
A line of work (Diakonikolas et al. (2020), Yehudai and Shamir (2020), Frei et al. (2020)) focuses on learning a single neuron $x \mapsto \sigma(\langle w, x \rangle)$ (where $\sigma$ is typically the ReLU or sigmoid) using gradient-based methods. These works obtain polynomial-time convergence guarantees when the distribution satisfies a spread condition. These results do not extend to the Boolean hypercube.
Daniely and Malach (2020) also study the problem of learning sparse parities using neural networks. One key difference from our work is that they assume a modified version of the problem, where the input distribution is not uniform over the hypercube, but instead leaks information about the label. In particular, the distribution ensures that the relevant parity bits always have the same value. Shi et al. (2021) generalize this setting by considering labels generated based on certain class-specific patterns, with the data itself generated from these patterns along with some extra background patterns. This also embeds information about the label in the data itself, unlike our setting, where the labels are uncorrelated with the input features. Under these structural assumptions, the papers study how GD on a two-layer network can learn useful features in polynomial time. Both of these analyses also exploit the first gradient step to find useful features; Shi et al. (2021) additionally require a second step to refine the features.
Ba et al. (2022) show how the first gradient step is important for feature learning. In particular, they show that the first update is essentially rank-1 and aligns with the linear component of the underlying function. The functions we consider (parities) do not have a linear component.
Abbe et al. (2022) define a notion of initial alignment between the network at initialization and the target function and show that it is essential to get polynomial time learnability with noisy gradients on a fully connected network. Our MLP results also exploit the correlation between the gradient and the label to ensure that the gradient update gives us signal.
Frei et al. (2022) also study the learnability of a parity-like function with $k = 2$ under noisy labels. The paper analyzes early-stopped GD for learning the underlying labeling function. Our setup is quite different from theirs and can handle $k > 2$.
In concurrent work, Damian et al. (2022) consider the problem of learning polynomials which depend on few relevant directions using gradient descent on a two-layer network. They assume that the distributional Hessian of the ground truth function spans exactly the subspace of the relevant directions. Using this, they show that gradient descent can learn the relevant subspace with sample complexity scaling as $n^2$ rather than $n^p$, where $p$ is the degree of the underlying polynomial, as long as the number of relevant directions is much smaller than $n$. Their proof technique is similar to our two-layer MLP result, which also exploits correlation in the first gradient step. However, for our setting, the distributional Hessian has rank 0 and does not satisfy their assumptions.
An extensive body of work originating in the statistical physics community has studied phase transitions in the learning curves of neural networks (Gardner and Derrida, 1989, Watkin et al., 1993, Engel and Van den Broeck, 2001). These works typically focus on student-teacher learning in the “thermodynamic limit”, in which the number of training examples is times larger than the input dimension and both are taken to infinity. One of the classic toy architectures analyzed in this literature is the parity machine (Mitchison and Durbin, 1989, Hansel et al., 1992, Opper, 1994, Kabashima, 1994, Simonetti and Caticha, 1996). In our work, we introduce PolyNets, a variant of parity machines in which the output is real-valued rather than ; and we theoretically analyze disjoint-Polynets, which are the real-output analogue of the oft-considered parity machines with tree architecture. While much of the statistical mechanics of ML literature focuses on an idealized training limit in which the weights reach a Gibbs distribution equilibrium, there is a strand of the literature that aims to characterize the trajectory of SGD training in the high-dimensional limit with constant-sized sets of ordinary differential equations (Saad and Solla, 1995b, a, Goldt et al., 2019). These papers discuss cases, including problems that share aspects with 2-sparse parities Refinetti et al. (2021), where the network gets stuck in (and then escapes from) a plateau of suboptimal generalization error. Recently, Arous et al. (2021) studied (for rank-one parameter estimation problems) the relative amount of time spent by SGD in an initial high-error “search” phase versus a final “descent” phase, which is reminiscent of the framing of Theorem 6. However, to our knowledge prior work has not shown -sparse parities can be learned with a number of iterations that nearly matches known lower bounds, nor has it specifically studied phase transitions in -sparse parity learning during gradient descent.
Another relevant line of work studies learning the parity problem using the neural tangent kernel (NTK) (Jacot et al., 2018). Namely, in some settings, when the network’s weights stay close to their initialization throughout the training, SGD converges to a solution that is given by a linear function over the initial features of the NTK. As shown in Theorem 5, learning parities over a fixed set of features requires the size of the model to be . In contrast, the model size (number of hidden neurons) considered in this paper does not depend at all on the input dimension . Nevertheless, the NTK analysis does give better sample complexity guarantees than the ones presented in this work, with a somewhat more natural version of SGD. For example, the work of Ji and Telgarsky (2019) demonstrates learning 2-sparse parities using NTK analysis with a sample complexity of , which matches the sample complexity lower bound for learning this problem with NTK (see Wei et al. (2019)). Concurrent work by Telgarsky (2022) shows that this sample complexity can be improved to once the optimization leaves the NTK regime. However, this analysis is given for networks of size , much larger than the networks considered in this paper. We refer the reader to Table 1 in Telgarsky (2022) for a complete comparison of the sample-complexity, run-time and model-size bounds achieved by different works studying 2-sparse parities.
Appendix B Proofs
For some even number , consider a ReLU MLP of size :
For all , randomly initialize
Fix some , and assume that . Then, for every s.t. it holds that:
for some universal constants . More precisely,
Therefore, by the previous lemma we get that for every even the following holds:
Fix some and assume that . Then, Majority has a -Fourier gap at of size with .
First we establish a simple relationship between and .
Here, the first equation follows from Lemma 1, and the second equation follows by simple algebra using the following equality: .
Here, the first equality follows from the above, the second from Lemma 1, the third inequality from standard approximations of the binomial coefficients, and the last inequality from the following inequalities: (by assumption on ) and (by standard calculus). This gives us the desired result.
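The gap in this lemma can be observed numerically on small instances. The sketch below is our illustration and assumes odd n, so that Majority has no ties. It computes Fourier coefficients of Majority by brute force and compares the degree-(k-1) coefficients, which in our reading govern gradient coordinates inside S, with the degree-(k+1) coefficients, which govern coordinates outside S. By symmetry only the subset size matters, so prefixes are used.

```python
from itertools import product
from math import prod


def majority(x):
    """Majority of an odd number of +/-1 bits."""
    return 1 if sum(x) > 0 else -1


def fourier_coefficient(f, n, subset):
    """hat f(S) = E_x[f(x) * chi_S(x)] for x uniform on {-1, +1}^n (exact enumeration)."""
    xs = list(product((-1, 1), repeat=n))
    return sum(f(x) * prod(x[i] for i in subset) for x in xs) / len(xs)


def majority_gap(n, k):
    """(|hat Maj| at degree k-1, |hat Maj| at degree k+1); by symmetry any subsets of
    these sizes give the same values, so we use prefixes of range(n)."""
    low = abs(fourier_coefficient(majority, n, range(k - 1)))
    high = abs(fourier_coefficient(majority, n, range(k + 1)))
    return low, high


for k in (2, 4, 6):
    low, high = majority_gap(9, k)
    print(k, round(low * 256), round(high * 256), low > high)
# 2 70 10 True
# 4 10 6 True
# 6 6 10 False
```

The last row shows the gap disappearing once k is no longer small relative to n, consistent with the lemma requiring an assumption that relates the two.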
Assume that is even and that . Then, the following hold:
Let be some tolerance parameter, fix and let . Assume that is an even number. Fix some , and . Let and s.t. and . Assume the following holds:
Additionally, for all and all it holds that .
Claim 1: For all and for all it holds that .
Proof: First, observe that by the assumption it holds that for all ,
Claim 2: For all and for all it holds that
Claim 3: For all it holds that
Claim 4: Fix . Let be a function s.t. and a function s.t. . Then, if and , the following holds:
where in the last inequality we use Eq. (6). So, choosing and gives the required.
Claim 5: Let be the functions defined in the previous claim. Then, there exists weights with s.t. for it holds that for all .
Proof: For define and , .
Proof of Lemma 4: Choose . Using Claim 4 and the union bound, w.p. over it holds that for all , . Therefore, w.p.
so, choosing we get that, w.p. at least over the choice of it holds that . Additionally, for every it holds that
Assume we randomly initialize an MLP using the unbiased initialization defined previously. Consider the following update:
Additionally, it holds that .
Claim: with probability at least ,
Proof: Fix some and note that by Hoeffding’s inequality,
Using the union bound, with probability at least , there exists a set of neurons satisfying the conditions of Lemma 4, and therefore the required follows from the Lemma. ∎
We use the following well-known result on convergence of SGD (see for example Shalev-Shwartz and Ben-David (2014)):
Let be an even number. Assume we randomly initialize an MLP using the unbiased initialization defined previously. Fix and let . Choose the following learning rate and weight decay schedule:
At the first step, use , for all weights.
After the first step, use for the first-layer weights and biases and for the second-layer weights, with for both layers.
Then, the following holds, with expectation over the randomness of the initialization and the sampling of the batches:
Now, we will show that w.h.p. there exists with good loss. Let . Denote . Observe that , and using the fact that we get
and additionally .
So, we can apply Theorem 8 with and and get that, w.p. over the initialization and the first step, it holds that
B.2 Recoverability of the parity indices from Fourier gaps
Given a network architecture where some neuron has a -Fourier gap with respect to the target subset , we quantify how the indices in can be determined by observing an estimate of the population gradient for a general activation function and :
Then, for every such that has a -Fourier gap at , the indices at which has the largest absolute values are exactly the indices in .
Let . We compute the population gradient, which we call :
where the inequality in the final case is due to the Fourier gap property. Then, it holds that for all we have and for all we have . Thus, the largest entries of the estimate occur at the indices in , as claimed. ∎
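To make the recovery step concrete, here is a sketch under simplifying assumptions of our own: a single ReLU neuron with +/-1 (sign-vector) weights, zero bias, and odd n, so that the pre-activation is never zero. It computes the exact population gradient of the correlation loss -y * ReLU(w.x + b) by enumerating the hypercube, and returns the k largest-magnitude coordinates, which should be exactly S whenever the Fourier gap holds.

```python
from itertools import product
from math import prod


def neg_population_gradient(w, b, subset):
    """Exact E[y * 1[w.x + b > 0] * x] for y = chi_S(x) and x uniform on {-1, +1}^n.

    This is minus the population gradient, with respect to w, of the correlation
    loss -y * relu(w.x + b), using relu'(z) = 1[z > 0].
    """
    n = len(w)
    xs = list(product((-1, 1), repeat=n))
    grad = [0.0] * n
    for x in xs:
        if sum(wi * xi for wi, xi in zip(w, x)) + b > 0:
            y = prod(x[i] for i in subset)
            for j in range(n):
                grad[j] += y * x[j]
    return [g / len(xs) for g in grad]


def top_k_indices(grad, k):
    """Sorted indices of the k entries of largest absolute value."""
    return sorted(sorted(range(len(grad)), key=lambda j: -abs(grad[j]))[:k])


w = [1, -1, -1, 1, 1, -1, 1, 1, -1]  # an arbitrary sign vector, n = 9
g = neg_population_gradient(w, 0, [1, 4])
print([round(v * 256) for v in g])  # entries 1 and 4 have magnitude 35, the rest 5
print(top_k_indices(g, 2))  # [1, 4]
```

The magnitudes do not depend on the particular sign vector: flipping signs of w only flips signs of the gradient entries.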
B.3 Global convergence for disjoint-PolyNets
In this section we develop theory for disjoint-PolyNets trained with correlation loss, as illustrated in Figure 5. Section B.3.1 considers optimization with gradient flow, and Section B.4 considers optimization with SGD at any batch size .
In gradient flow, the relevant weights evolve according to the following differential equations:
Suppose disjoint-PolyNet for is initialized such that , and optimized with gradient flow. Let , and . For any and , let . Then
Let . Then
First, observe that the product of the relevant weights is non-decreasing during gradient flow.
Thus, implies that for all .
In other words, the squares of the relevant weights each follow the same trajectory, shifted according to their initializations. Let , for any . This quantity evolves as follows:
Since is strictly increasing, its inverse is well-defined, and we can use the inverse function theorem to characterize for all :
We can upper- and lower-bound the integrand by applying Maclaurin’s inequality (see page 52 in Hardy et al. (1952)):
The amount of time it takes for to reach a value of is thus:
Hence, for any , for each , as long as
Meanwhile, the amount of time after it takes for to explode to infinity is
Substituting , we obtain that the amount of time it takes for to grow from to is
We can upper- and lower-bound by applying Maclaurin’s inequality (see page 52 in Hardy et al. (1952)):
When , solving the LHS and RHS differential inequalities yields:
From the lower bound on , we can infer that the relevant weights all explode to infinity by the following time:
From the upper bound, we can infer that for any , it is the case that so long as
Hence, for each , for all
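The differential equations for the relevant weights did not survive extraction. As a hedged illustration, assume they take the form du_i/dt = prod_{j != i} u_j. This is a hypothetical reconstruction on our part, chosen because it reproduces the two facts used in the proof: the product of the relevant weights is non-decreasing, and every square u_i^2 increases at the same rate 2 * prod_j u_j. The forward-Euler sketch below then shows the squares moving in lockstep and the finite-time blow-up.

```python
from math import prod


def relevant_weight_flow(u0, dt=1e-4, t_end=10.0, blowup=1e3):
    """Forward-Euler integration of the ASSUMED dynamics du_i/dt = prod_{j != i} u_j.

    Stops at time t_end or as soon as some |u_i| exceeds `blowup`; returns the stopping
    time, the final weights, and the product of the weights after every step.
    """
    u = list(u0)
    t = 0.0
    products = [prod(u)]
    while t < t_end and max(abs(v) for v in u) < blowup:
        # Each relevant weight is driven by the product of the other relevant weights.
        drift = [prod(u[:i] + u[i + 1:]) for i in range(len(u))]
        u = [v + dt * d for v, d in zip(u, drift)]
        t += dt
        products.append(prod(u))
    return t, u, products


u0 = [0.5, 0.6, 0.7]
_, u_half, _ = relevant_weight_flow(u0, t_end=0.5)
# The squares shift together: u_3^2 - u_1^2 stays (up to Euler error) at 0.49 - 0.25.
print(round(u_half[2] ** 2 - u_half[0] ** 2, 4))  # 0.24
t_blow, _, products = relevant_weight_flow(u0)
print(t_blow < 2.5, all(b >= a for a, b in zip(products, products[1:])))  # True True
```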
Now we analyze the relationship between the relevant weights and the accuracy of the disjoint-PolyNet.
Then define the error of with parameters as
Let be any setting of the weights of a disjoint-PolyNet such that . For ease of notation, let be the irrelevant portion of . There is a constant such that
where is the Gauss error function .
Let be the th relevant coordinate of , and let be the irrelevant coordinates in .
Then we have that the error of with parameters is
Line 10 follows because and are independent. In line 11 we use that is symmetric about 0. Finally, line 12 uses the assumption that .
We can bound using Hoeffding’s inequality:
The indicator random variables are independent of each other, so the first term in line 12 can be characterized using the distribution of the parity of a sum of independent Bernoulli random variables. Let for , and let . The generating function for is . The parity of then satisfies
We first prove the upper bound on . Observe that
by Hoeffding’s inequality, and we are done.
Now we will prove the lower bound on . We have
We can bound this expression using the Berry-Esseen theorem (Berry, 1941, Esseen, 1942). Let . Then, the Berry-Esseen theorem states that there is a constant (which in practice can be ) such that for any ,
Plugging this into equation 13, we obtain the lower bound on . ∎
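The parity computation in this proof presumably rests on the standard identity obtained by evaluating the generating function prod_i (1 - p_i + p_i z) of a sum M of independent Bernoulli(p_i) variables at z = -1. This gives E[(-1)^M] = prod_i (1 - 2 p_i), and hence P[M odd] = (1 - prod_i (1 - 2 p_i)) / 2. A brute-force check of the identity (the probabilities below are arbitrary):

```python
from itertools import product
from math import isclose, prod


def prob_odd_closed_form(ps):
    """P[B_1 + ... + B_m is odd] for independent B_i ~ Bernoulli(p_i), via E[(-1)^M] = prod(1 - 2 p_i)."""
    return (1 - prod(1 - 2 * p for p in ps)) / 2


def prob_odd_brute_force(ps):
    """The same probability, summing over all 2^m outcomes."""
    total = 0.0
    for bits in product((0, 1), repeat=len(ps)):
        if sum(bits) % 2 == 1:
            total += prod(p if b else 1 - p for p, b in zip(ps, bits))
    return total


ps = [0.1, 0.25, 0.4, 0.3]
print(isclose(prob_odd_closed_form(ps), prob_odd_brute_force(ps)))  # True
```

Note that a single variable with p_i = 1/2 makes the parity exactly unbiased, regardless of the others.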
We first apply Lemma 6 and Lemma 7 to the case where the ’s have initialization. This generalizes Theorem 6.
Suppose all the weights in a disjoint-PolyNet are initialized randomly in and .
Let . Then, for , if (which happens w.p. 1/2),
Thus, even for arbitrarily close to , when the input is sufficiently long, the network spends almost all of training with error above .
By Lemma 6, for all whenever
we obtain that .
Also by Lemma 6, using the language of that lemma statement, for all
Once all the relevant weights have exploded to infinity, the network has zero error, so the result follows. ∎
Now let us apply Lemma 6 and Lemma 7 to the situation where the ’s have standard normal initialization. Again we find that with high probability, the phase of learning with near-trivial accuracy is much longer than the subsequent period until perfect accuracy, as illustrated by the left-hand plot in Figure 5. This culminates in the full theorem statement regarding phase transitions in the loss:
where is defined as in Corollary 10.
By a standard application of the generic Chernoff bound (Wainwright, 2019), for each , we have
This implies that w.p. ,
For each , follows a chi-squared distribution with degrees of freedom. By Laurent and Massart (2000), for any ,
Combining Equation 16 and Equation 17, we obtain the desired statement.
B.4 Global convergence and phase transition for SGD on disjoint-PolyNets
In this section we will analyze the training of a disjoint-PolyNet using SGD with online (i.i.d.) batches. We will show a convergence result for the 0-1 error of the learned classifier.
Assume we randomly initialize the disjoint-PolyNet with weights drawn uniformly from . Fix and run SGD at any batch size for iterations. There exists an adaptive learning rate schedule, such that, with probability 1/2 over the randomness of the initialization and over the sampling of SGD, the following holds:
For simplicity of presentation, we will assume . Let the sample at iteration be where and . Denote the population and stochastic gradient at time as:
Then, for every and , by the Azuma-Hoeffding inequality, with probability , . By the union bound, w.p. at least , for every and all it holds that . Let us assume this holds.
For all , for , .
Note that , therefore, we have,
since and . ∎
For all and it holds that
Observe that the claim holds for since . By induction on , assume the claim holds for all . Now we will prove it holds for .
Observe that for all , by the assumption:
Note that , then we have
(18) follows from observing that and . (19) follows from the inductive hypothesis . (20) follows from our assumption that and Claim 11. (21) follows from the inductive hypothesis . ∎
Setting such that for , from the above claims, after iteration , we have
Using Lemma 7, we have with probability ,
Appendix C Additional figures, experiments, and discussion
This section contains our unabridged empirical results, visualizations, and accompanying discussion. Additional example training curves (like the assortment in Figure 1 (left)) are shown in Figure 6; more examples can be found in the subsections below.
We first present the full empirical results outlined in Section 3 of the main paper. Figure 7 shows convergence times on small parity instances for all of the architecture configurations enumerated in Section 3.1. In some of these settings, exhibits high variance due to unlucky initializations (see Figure 8); thus, we report percentile convergence times. Figure 9 gives coarse-grained estimates for how scales with , based on small examples. For selected architectures, Figure 10 shows how these convergence times scale with and more precisely: for small , power law relationships (for small constants ) are observed for all configurations. Note that for larger , the exponent (i.e. the slope in the log-log plot) increases: with a constant learning rate and standard training, the does not continue indefinitely. All additional details are in Appendix D.
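For readers reproducing such plots, the exponent of a power law t ~ c * n^a is the least-squares slope of log t against log n. The sketch below is a generic estimator; the data are synthetic (generated from an exact power law with made-up constants), not measurements from this paper.

```python
import math


def loglog_slope(ns, ts):
    """Least-squares slope of log(t) on log(n): the exponent a in t ~ c * n**a."""
    xs = [math.log(n) for n in ns]
    ys = [math.log(t) for t in ts]
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var = sum((x - mx) ** 2 for x in xs)
    return cov / var


# Synthetic, noise-free data from t = 3 * n**2.5 (illustrative constants only).
ns = [10, 20, 40, 80]
print(round(loglog_slope(ns, [3 * n ** 2.5 for n in ns]), 6))  # 2.5
```

Fitting separate slopes over small and large n is one simple way to quantify the steepening exponent described above.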
The remainder of Appendix C expands on the various discussions and figures from Sections 4 and 5.
Appendix C.1 gives experimental evidence that Fourier gaps are present at iterates and initializations other than sign vectors, as well as for activation functions other than ReLU. This suggests that the feature amplification mechanism is robust, and illuminates directions for strengthening the theoretical results.
Appendix C.2 discusses how the building blocks of deep learning (activation functions, biases, initializations, learning rates, and batch sizes) play multiple, sometimes conflicting roles in this setting.
Appendix C.3 provides additional white-box visualizations of hidden progress from Figure 3.
Appendix C.4 explores the implications of the feature amplification mechanism for scaling model size: unlike random search, large width does not yield parallel speedups.
Appendix C.5 shows that our results hold in the finite-sample setting (allowing for multiple passes over a training set of size ). In particular, we show that in low-data regimes, the models exhibit the grokking phenomenon.
Appendix C.6 extends our results to noisy parities (which comprise the true “emblematic computationally-hard problem”).
Appendix C.7 introduces a counterexample for “layer-by-layer learning”, using parity distributions whose degrees are higher than those of the individual layers’ polynomial activations. Preliminary experiments show that standard training works in this setting.
Appendix C.8 presents examples of training curves for wide polynomial-activation MLPs, where, unlike the other settings, there is no initial plateau in the model’s error.
C.1 Fourier gaps at initialization and SGD iterates
Proposition 9 shows that if the function has a Fourier gap at , then can be identified from a batch gradient at initialization with samples. Our end-to-end result (Theorem 4) requires ReLU activations and sign vector initialization, because the Fourier gap condition (Definition 1) arises from exact formulas for the Fourier coefficients of the majority function. Stronger end-to-end theoretical guarantees would follow from analogous Fourier gaps in more general population gradients. This requires to satisfy these conditions simultaneously:
Fourier concentration: upper bounds on the degree- coefficients , for . The term is borrowed from Klivans et al. (2004), who use upper bounds on Fourier coefficients of LTFs to approximate them (thus, learn halfspaces) with low-degree polynomials.
Fourier anti-concentration: lower bounds on the degree- coefficients , for .
A natural question is: which Boolean functions, other than majority, satisfy the -Fourier gap property at , for ?
We present some numerical evidence for large Fourier gaps in functions other than majority, which arise from gradients of architectures other than ReLU MLPs with sign initialization. This shows that the mechanism of feature emergence is empirically robust in settings not fully explained by our current theory. Establishing corresponding theoretical guarantees would enable stronger end-to-end global convergence guarantees for MLPs and other architectures.
In these experiments, population gradients were computed by brute force integration over all Boolean inputs . In all cases, for various choices of , we measure a slightly relaxed notion of Fourier gap in the population gradient:
If , then one coordinate from the parity can be identified from samples of the gradient at . (Replacing the first in the definition of with would give us the same notion of Fourier gap as Definition 1: if all the relevant coordinates are larger than all of the irrelevant ones, estimating the population gradient allows us to recover the relevant coordinates.)
The successful convergence of architectures with smoother activations (in the parity setting and beyond) motivates the question of whether large Fourier gaps are present in population gradients corresponding to functions other than LTFs. Figure 12 (left) shows that this is the case for sinusoidal activations.
Finally, to further close the gap between Theorem 4 and the empirical results, it is necessary to address the fact that SGD accumulates gradients with respect to time-varying iterates, while our analysis approximates this using a large-batch gradient at a static iterate . In fact, SGD seems to help in some cases: Figure 12 (right) shows that when training a sinusoidal neuron, SGD amplifies the initial Fourier gap.
C.2 Counterintuitive roles of the building blocks of deep learning
Even in this simple problem setting, the simultaneous computational and statistical considerations lead to counterintuitive consequences for the optimal configurations of architectures and algorithms. We encountered the following in our search for architecture configurations for the empirical study:
Activation functions. This mechanism of features emerging via Fourier gaps (see Definition 1) is strongest with non-smooth activations such as the ReLU, whose derivatives are discontinuous threshold functions. This is an orthogonal consideration to representational capacity and mitigation of local minima (under which one might conclude that degree- polynomial activations are optimal). In summary, in feature learning settings where Fourier gaps and low-complexity solutions are simultaneously relevant, there is a sharpness-smoothness tradeoff for the activation function.
Biases. The symmetry of the majority function (as well as all unbiased LTFs) causes its even-degree Fourier coefficients to be zero; thus, certain variants of the setups in Section 3 fail for odd . Bias terms (trainable or fixed) are necessary to break this symmetry, in theory and practice. Simultaneously, biases serve the more conventional role of shifting the loss surface; see Section D.1.2 for how this affected the choice of biases in the experiments.
Initializations. The role of the initialization distribution is similarly twofold in this setting: should be close to the desired solution , but it must also be selected such that SGD will successfully amplify the Fourier gap. A third consideration, which we do not attempt to study in this work, is that multiple randomly initialized neurons will tend to learn the correct features at different times (see the weight trajectory visualizations in Figure 3 and Figure 13, as well as the staircase-like training curves seen for MLPs in Figure 1 (left)). We expect this symmetry-breaking phenomenon to be present in more complex feature learning settings. Finally, as shown in the training curves from setting (vi) in Figure 6, and in more detail in Appendix C.8, the choice of activation function influences the qualitative behavior of the training curves: namely, whether the plateaus disappear at large widths and batch sizes.
C.3 Hidden progress measures
In this section, we provide an expanded discussion and plots for the investigations outlined in Figure 3 and the “hidden progress measures” section in Section 5.
Note that many trivial progress measures exist: an example to keep in mind is that for the algorithm which exhaustively enumerates over a deterministic list of hypotheses (say, the possible -element subsets in lexicographical order) and terminates when it finds the correct one, the current iteration is a progress measure. Thus, the purpose of demonstrating hidden progress measures is not to provide further evidence that SGD finds the features using Fourier gaps. Rather, it is to (1) further refute the hypothesis of SGD performing a memoryless Langevin-like random search, and (2) provide a preliminary exploration of how progress can be quantified even when the natural metrics of loss and accuracy appear to be flat.
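The trivial progress measure mentioned above can be made concrete. The following sketch (our illustration; the sample size and helper names are arbitrary) runs the exhaustive search over k-subsets in lexicographic order, roughly n^k steps in the worst case, and returns the index at which it stops. That iteration counter increases steadily even though every hypothesis tried before the last one has trivial accuracy.

```python
import random
from itertools import combinations
from math import prod


def make_samples(n, subset, m, seed=0):
    """m noiseless samples (x, chi_S(x)) with x uniform on {-1, +1}^n."""
    rng = random.Random(seed)
    samples = []
    for _ in range(m):
        x = [rng.choice((-1, 1)) for _ in range(n)]
        samples.append((x, prod(x[i] for i in subset)))
    return samples


def exhaustive_parity_search(samples, n, k):
    """Return (subset, iteration) for the first k-subset, in lexicographic order, that is
    consistent with every sample. The iteration counter is a trivial progress measure."""
    for iteration, subset in enumerate(combinations(range(n), k)):
        if all(prod(x[i] for i in subset) == y for x, y in samples):
            return subset, iteration
    return None, None


found, steps = exhaustive_parity_search(make_samples(10, (2, 5, 7), 60), 10, 3)
print(found, steps)  # (2, 5, 7) 76
```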
The Fourier gap visualizations in Section C.1 already provide an example of a quantity which varies continuously as the model trains, despite no apparent progress in the loss and accuracy curves. However, none of our theoretical analyses capture the empirical observation that this quantity tends to amplify over time. Below, we consider other quantities which reveal hidden progress in parity learning, which are more straightforward and closer to our analyses.
This progress measure is shown alongside the loss curves in Figure 13, in red. We do not attempt to characterize the dynamics of ; we only note that they are clearly distinguishable from the maximum of unbiased random walks, even when SGD appears to make no progress in terms of loss and accuracy. Studying hidden progress measures in deep learning more quantitatively, as well as in more general settings, presents a fruitful direction for future work.
C.4 Convergence time vs. width
We provide supplementary plots for the experiment outlined in Figure 4 (left), which probes whether extremely large widths () afford factor- parallel speedups of the parity learning mechanism (as one would expect from random search). On 3 parity instances , we varied the width , keeping all other parameters the same .
We did not find evidence of such parallel speedups over runs in each setting; see Figure 14. This serves as further evidence that the mechanism by which standard training solves parity is best understood as deterministic and sequential, rather than behaving like random search over size- subsets. A benefit of width appears to be variance reduction: the upper tail of long convergence times is mitigated by a large number of randomly-initialized neurons.
C.5 Learning and grokking in the finite-sample case
We provide some supplementary plots for the experiments outlined in Figure 4 (right). In these settings, a fixed architecture (width- MLP with ReLU activations) is trained with minibatch SGD in an otherwise fixed configuration (hinge loss, learning rate , batches of size ) on a finite training sample of size . We also vary a weight decay parameter .
As shown in Figure 15, the weight decay parameter modulates a delicate computational-statistical tradeoff: it improves generalization (expanding the range of for which training eventually finds the correct solution), but the model fails to train at large values of . For small and appropriately tuned , we observe grokking: the model initially overfits the training data, but finds a classifier that generalizes after a large number of iterations.
C.6 Learning noisy parities
The other empirical results in this work focus on noiseless parity distributions , to reduce the number of sources of variance and degrees of freedom. However, the setting of random classification noise is important for several reasons. In this section, we briefly demonstrate that our results extend to this case. Let denote the -noisy parity distribution, defined by flipping the labels in the -parity distribution independently with probability . Note that when , the labels are completely random (thus, cannot be learned). By a standard PAC-learning argument, when , the statistical limit for identifying from i.i.d. samples from scales as .
First, learning parities from noisy samples is the true “emblematic computationally-hard distribution”. Without noise, there is a non-SQ algorithm which avoids the exponential-in- computational barrier: Gaussian elimination can identify in time and samples. Second, viewing parities as an idealized setting in which to understand training dynamics, resource scaling, and emergence in deep learning, it is important to see that this phenomenon is robust to label noise.
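In particular, if an architecture’s population gradient has a Fourier gap with parameter in the noiseless case, then it also has a Fourier gap with parameter in the noisy case, since label noise scales the expected label, and hence every coordinate of the population gradient, by the same factor.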
We find that the experimental findings are robust to label noise, in the sense that models are able to obtain nontrivial (and sometimes ) accuracy; see Figure 16 for some training curves under various settings of . This provides concrete evidence against the (already extremely dubious) hypothesis that neural networks, with standard initialization and training, learn noiseless parities by implicitly simulating an efficient algorithm such as Gaussian elimination. Note that with a constant learning rate (here, ) and label noise, the iterates of SGD do not always converge to accurate solutions.
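For contrast with the noisy setting, here is a sketch of the Gaussian-elimination algorithm mentioned above (our illustration, not part of the paper's experiments). Writing each +/-1 coordinate as a bit b = (1 - x)/2 turns the label into the XOR of the bits in S, so S solves a linear system over GF(2). With noiseless samples and a full-rank system this recovers S in polynomial time; with enough samples, a single flipped label makes the system inconsistent, which the sketch detects.

```python
import random
from math import prod


def make_samples(n, subset, m, seed=0):
    """m noiseless samples (x, chi_S(x)) with x uniform on {-1, +1}^n."""
    rng = random.Random(seed)
    out = []
    for _ in range(m):
        x = [rng.choice((-1, 1)) for _ in range(n)]
        out.append((x, prod(x[i] for i in subset)))
    return out


def solve_parity_gf2(samples, n):
    """Recover S from (x, y) samples by Gauss-Jordan elimination over GF(2).

    Row for a sample: bits of x (+1 -> 0, -1 -> 1) followed by the bit of y.
    Returns sorted S, or None if the system is rank-deficient or inconsistent.
    """
    rows = [[(1 - v) // 2 for v in x] + [(1 - y) // 2] for x, y in samples]
    r = 0
    for c in range(n):
        pivot = next((i for i in range(r, len(rows)) if rows[i][c]), None)
        if pivot is None:
            return None  # column c has no pivot: S is not uniquely determined
        rows[r], rows[pivot] = rows[pivot], rows[r]
        for i in range(len(rows)):
            if i != r and rows[i][c]:
                rows[i] = [a ^ b for a, b in zip(rows[i], rows[r])]
        r += 1
    if any(row[n] for row in rows[n:]):
        return None  # a leftover row reads 0 = 1: the labels are inconsistent (e.g. noise)
    # Row c now has a single 1 among the first n columns, at column c; its last bit is s_c.
    return [c for c in range(n) if rows[c][n]]


print(solve_parity_gf2(make_samples(12, (1, 4, 6, 10), 80), 12))  # [1, 4, 6, 10]
```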
C.7 Counterexample for layer-by-layer learning
Consider an -layer MLP with activation , parameterized by weights and biases
where denotes the function for , and denotes . The shapes of the parameters are selected such that each function composition is well-defined. Let the intermediate activations at layer be denoted by
Finally, (the width at layer ) refers to the dimensionality of as defined above.
Notice that when is a degree- polynomial (say, ), an -layer MLP can represent parities up to degree – for example, a -layer MLP (which composes quadratic activations twice) can represent a -sparse parity as a -sparse parity of -sparse parities. However, Equation (2) implies the following:
An individual layer cannot represent a parity of inputs.
The population gradient (as in Equation (5)) is zero (since every coordinate of the gradient is the correlation between a -wise parity and a polynomial of degree ).
Thus, this setting serves as an idealized counterexample for layer-by-layer learning: if SGD succeeds on parities with higher degree than the architecture’s polynomial activations, it must do so by an end-to-end mechanism. Intuitively, earlier layers can only make progress by knowing how their outputs will be used downstream. Concretely, consider the population gradient of the correlation loss, with respect to a first-layer neuron’s weights . With layer-by-layer training, this gradient contains no information:
However, in end-to-end training, the presence of downstream layers removes this barrier:
giving the gradient greater representation capacity (in terms of polynomial degree). The question remains of whether end-to-end training works in this setting, which we resolve positively in small experiments.
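The representational side of this argument can be checked directly. The sketch below is our own construction, based on the identity (a + b)^2 = 2 + 2ab for a, b in {-1, +1}. It builds a 2-layer network with activation z -> z^2 that computes a 4-sparse parity as a 2-sparse parity of 2-sparse parities, and checks that a single squared neuron has zero correlation with that parity while it does correlate with a 2-sparse one.

```python
from itertools import product
from math import isclose, prod


def sq(z):
    """The quadratic activation z -> z**2."""
    return z * z


def two_layer_quadratic_parity(x):
    """A depth-2 network with activation sq computing x[0]*x[1]*x[2]*x[3] on {-1, +1}^n.

    Uses a * b = (sq(a + b) - 2) / 2, valid for a, b in {-1, +1}.
    """
    # Layer 1: two neurons; affine readouts turn them into 2-sparse parities.
    a = (sq(x[0] + x[1]) - 2) / 2  # equals x0 * x1
    c = (sq(x[2] + x[3]) - 2) / 2  # equals x2 * x3
    # Layer 2: one neuron (its input a + c is affine in layer 1); readout gives a * c.
    return (sq(a + c) - 2) / 2


def single_neuron_correlation(w, bias, subset):
    """E[chi_S(x) * sq(w.x + bias)] over x uniform on {-1, +1}^len(w)."""
    xs = list(product((-1, 1), repeat=len(w)))
    total = sum(prod(x[i] for i in subset) * sq(sum(wi * xi for wi, xi in zip(w, x)) + bias)
                for x in xs)
    return total / len(xs)


print(all(two_layer_quadratic_parity(x) == prod(x[:4]) for x in product((-1, 1), repeat=6)))  # True
w = [0.3, -1.2, 0.7, 2.0, 0.5]
print(isclose(single_neuron_correlation(w, 0.4, [0, 1, 2, 3]), 0.0, abs_tol=1e-9))  # True
print(round(single_neuron_correlation(w, 0.4, [0, 1]), 6))  # -0.72, i.e. 2 * w[0] * w[1]
```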
We empirically observed successful training (to accuracy) in a few settings (with SGD, learning rate , batch size , and default uniform initialization as described in Appendix D.1):
. Small widths suffice: . Over 10 random seeds, all models converged within iterations.
. Widths were chosen to be slightly larger for stability: . Over 10 random seeds, all models converged within iterations. Additionally, models trained on converged within iterations.
As a sanity check, the models failed to converge in experimental setups where : and .
This construction serves as a simple counterexample to the “deep only works if shallow is good” principle of Malach and Shalev-Shwartz (2019), demonstrating a case where a deep network can get near-perfect accuracy even when greedy layerwise training (e.g., Belilovsky et al., 2019) cannot beat trivial performance. It remains to characterize these positive empirical results theoretically, as well as to investigate whether there are pertinent analogues in real data distributions.
C.8 Lack of plateaus for wide polynomial-activation MLPs
An interesting qualitative observation from the training curves in Figure 6 is that the validation accuracy curves in setting (vi) (width-1000 polynomial-activation MLPs) do not follow the same “plateau” or “staircase” pattern as the others. Figure 17 shows a few additional examples of training curves for polynomial-activation MLPs, varying the width and batch size . We find that the rate of descent of the validation error increases with both of these parameters; note that this does not occur with ReLU activations (where there are sharp phase transitions between plateaus at all batch sizes).
This constitutes an exception to this paper’s theme of “hidden progress” behind flat loss (or error) curves: with enough overparameterization and “over-sampling”, the continuous progress of SGD in this setting is no longer hidden, and manifests in the training curves. This phenomenon seems to be specific to certain activation functions (i.e. but not ReLU); we leave it for future work to understand why and when it occurs, as well as potential practical implications.
Appendix D Details for all experiments
Our “robust space” of empirical results uses the following loss functions:
In the configurations corresponding to all of the figures and convergence time experiments, we used the hinge loss. This was a relatively arbitrary choice (i.e. they appeared to be interchangeable upon running small experiments); an advantage of the hinge and square losses over cross entropy is that for architectures that can realize the parity function, there is a zero-loss solution with finite weights.
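The list of losses is missing from this extraction, so the sketch below uses the standard forms of the hinge, square, and logistic (binary cross-entropy) losses, whose normalization may differ from the paper's. It illustrates the finite-weight point: on +/-1 labels, an output equal to the label gives exactly zero hinge and square loss, while the logistic loss is positive for every finite output, so driving it to zero requires unbounded outputs (hence weights).

```python
import math


def hinge_loss(y, f):
    """max(0, 1 - y*f) for a label y in {-1, +1} and real output f."""
    return max(0.0, 1.0 - y * f)


def square_loss(y, f):
    """(y - f)**2."""
    return (y - f) ** 2


def logistic_loss(y, f):
    """log(1 + exp(-y*f)): binary cross-entropy with logit f."""
    return math.log1p(math.exp(-y * f))


# A network that realizes the parity exactly outputs f = y with finite weights:
for y in (-1, 1):
    print(hinge_loss(y, float(y)), square_loss(y, float(y)), logistic_loss(y, float(y)) > 0)
# 0.0 0.0 True  (printed for both labels)
```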
Our empirical results use the following i.i.d. weight initializations:
Uniform on the interval , where the scale is chosen for all affine transformation parameters using the “Xavier initialization” convention (Glorot and Bengio, 2010). The experiments are quite tolerant to the particular choice of (as these are not deep networks); this choice, which is the default in deep learning packages, emphasizes that our positive empirical results hold under a standard initialization scheme.
Gaussian with mean and variance , selected using the “Kaiming initialization” convention (He et al., 2015).
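For reference, a sketch of the two scale conventions as they are usually defined; how exactly they were instantiated here is our assumption, since the constants are missing from this text. Xavier uniform samples U[-a, a] with a = sqrt(6 / (fan_in + fan_out)), and Kaiming normal uses variance gain^2 / fan_in, with gain sqrt(2) being the usual choice for ReLU. PyTorch exposes these as torch.nn.init.xavier_uniform_ and torch.nn.init.kaiming_normal_; the pure-Python helpers below are hypothetical stand-ins.

```python
import math
import random


def xavier_uniform_bound(fan_in, fan_out):
    """Half-width a of U[-a, a] under the Glorot/Xavier convention."""
    return math.sqrt(6.0 / (fan_in + fan_out))


def kaiming_normal_std(fan_in, gain=math.sqrt(2.0)):
    """Standard deviation under the He/Kaiming (fan-in) convention."""
    return gain / math.sqrt(fan_in)


def init_weights(fan_in, fan_out, scheme, seed=0):
    """A fan_out x fan_in matrix of i.i.d. entries under the chosen scheme."""
    rng = random.Random(seed)
    if scheme == "uniform":
        a = xavier_uniform_bound(fan_in, fan_out)
        return [[rng.uniform(-a, a) for _ in range(fan_in)] for _ in range(fan_out)]
    if scheme == "gaussian":
        s = kaiming_normal_std(fan_in)
        return [[rng.gauss(0.0, s) for _ in range(fan_in)] for _ in range(fan_out)]
    raise ValueError(f"unknown scheme: {scheme!r}")


W = init_weights(50, 100, "uniform")
print(len(W), len(W[0]), max(abs(v) for row in W for v in row) <= xavier_uniform_bound(50, 100))
# 100 50 True
```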
We consider -layer MLPs for two choices of activations:
Degree- polynomial: .
In both cases, whenever (and, in the case of polynomial activations, choosing the degree to be ), there exists a width- MLP which can represent -sparse parities: for all and , there is a setting of such that .
Note that if the output is a degree- polynomial in (e.g. an MLP with activations), the architecture is incapable of representing a parity of inputs. In fact, it is incapable of representing any function that has a nonzero correlation with parity; this follows from orthogonality (Equation (2)).
To explore the limits of concise parameterization for architectures capable of learning parities, we propose a variety of non-standard activation functions which allow a single neuron to learn sparse parities. These constructions leverage the fact that the parity is a nonlinear function of the sum of its inputs: for bits taking values in $\{\pm 1\}$, the product of the relevant bits depends only on their sum.
The zigzag activation is the piecewise linear function which interpolates the points given by each attainable value of this sum together with the corresponding parity, with one linear region between consecutive points. Then, a single neuron applying this activation to the sum of the relevant inputs computes the parity.
The oscillating polynomial activation is the lowest-degree polynomial which interpolates the same points as above.
The sinusoidal activation is a shifted sine. A single sinusoidal neuron can also express parities of arbitrary degree, since it can interpolate the same set of points as the zigzag activation. In the experiments, we pick the shift so that the sinusoid interpolates the same points as the zigzag activation, with the same sign convention. In the experiments in Section 3, the sinusoidal activation is additionally scaled by a factor of 2; this is interchangeable with scaling the learning rate and initialization, and is done to obtain more robust convergence in that particular setting.
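The three special activations can be compared directly as a single neuron with all-ones weights on the relevant bits and zero bias. The sketch below assumes $\pm 1$ inputs and fixes one sign convention (value $+1$ when all $k$ relevant bits equal $+1$, i.e. at sum $k$); the zigzag is extended as a periodic triangle wave outside $[-k, k]$, and the oscillating polynomial is built by Lagrange interpolation. These conventions and names are assumptions, not the exact definitions used in the experiments.

```python
import itertools
import math


def parity_of_sum(s: int, k: int) -> float:
    """Parity of k bits in {-1, +1}, written as a function of their sum s."""
    return (-1.0) ** ((k - s) // 2)


def zigzag(s: float, k: int) -> float:
    """Assumed convention: triangle wave with slopes +/-1 and value +1 at s = k."""
    return abs(((s - k) % 4) - 2) - 1


def shifted_sine(s: float, k: int) -> float:
    """sin with a phase shift chosen so that it equals cos(pi * (s - k) / 2)."""
    return math.sin(math.pi * (s - k) / 2 + math.pi / 2)


def oscillating_polynomial(k: int):
    """Lagrange interpolant of the parity through s = -k, -k+2, ..., k (degree at most k)."""
    knots = list(range(-k, k + 1, 2))
    values = [parity_of_sum(t, k) for t in knots]

    def p(s: float) -> float:
        total = 0.0
        for j, (tj, vj) in enumerate(zip(knots, values)):
            basis = 1.0
            for m, tm in enumerate(knots):
                if m != j:
                    basis *= (s - tm) / (tj - tm)
            total += vj * basis
        return total

    return p


def single_neuron(activation, x, subset) -> float:
    """One neuron: all-ones weights on the relevant bits, zero bias, then the activation."""
    return activation(sum(x[i] for i in subset))


n, subset = 5, (0, 2, 4)
k = len(subset)
activations = {
    "zigzag": lambda s: zigzag(s, k),
    "sine": lambda s: shifted_sine(s, k),
    "polynomial": oscillating_polynomial(k),
}
for name, act in activations.items():
    worst = max(abs(single_neuron(act, x, subset) - math.prod(x[i] for i in subset))
                for x in itertools.product((-1, 1), repeat=n))
    print(name, worst)
```

The shifted sine agrees with the zigzag at every attainable sum, which is why the sinusoidal neuron inherits the ability to express parities of arbitrary degree.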
Figure 18 visualizes all families of activations considered in this paper.
The Transformer experiments use a slightly simplified version of the architecture introduced by Vaswani et al. (2017). In particular, it omits dropout, layer normalization, tied input/output embedding weights, and a positional embedding on the special [CLS] token. Including these does not change the results significantly; they are all present in the preliminary experiments of Edelman et al. (2021), in which an “off-the-shelf” Transformer implementation successfully learns sparse parities. We specify the architecture below.
Note that we have specialized this architecture to a single output, at the [CLS] position.
The nonlinearity is GELU (the Gaussian error linear unit), the standard choice in Transformers.
Each matrix-shaped parameter was initialized using PyTorch’s default “Xavier uniform” convention. Unlike the other settings considered in this paper, we were unable to observe successful convergence beyond a few small using standard SGD. As is common practice when training Transformers, we used Adam (Kingma and Ba, 2014) with default adaptive parameters in our experiments. While there are more fine-grained accounts of why Adam outperforms vanilla SGD (Zhang et al., 2020, Agarwal et al., 2020), finding the optimal optimizer configuration and investigating ablations of this optimizer are outside the scope of this work. In this work, we only tune Adam’s learning rate .
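Since only the learning rate is tuned, the remaining Adam hyperparameters stay at their defaults; in PyTorch's `torch.optim.Adam` these are `betas=(0.9, 0.999)` and `eps=1e-8`. The scalar sketch below writes out the bias-corrected Adam update of Kingma and Ba (2014) with those defaults, only to make the update rule concrete; the experiments themselves use the library optimizer, and `adam_minimize` is a hypothetical helper.

```python
import math


def adam_minimize(grad_fn, w0: float, lr: float, steps: int,
                  beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> float:
    """Scalar Adam with bias correction; only lr is treated as a tunable knob."""
    w, m, v = w0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g        # first-moment (momentum) estimate
        v = beta2 * v + (1 - beta2) * g * g    # second-moment estimate
        m_hat = m / (1 - beta1 ** t)           # bias corrections for zero initialization
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w


# Usage: minimize (w - 3)^2, whose gradient is 2 (w - 3).
print(adam_minimize(lambda w: 2.0 * (w - 3.0), w0=0.0, lr=0.1, steps=500))
```

Because each step is normalized by a running estimate of the gradient's magnitude, the step size is governed mainly by `lr`, which is one reason tuning the learning rate alone is a common practice.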
Even with all biases set to zero, this architecture can realize a sparse parity with an appropriate setting of the weights.
Figure 1 (left) shows training curves from 8 representative configurations, with online i.i.d. samples from the same distribution, corresponding to the -sparse parity problem. The first row encompasses various MLP settings with standard activations:
Setting (i): width- MLP with ReLU activation ().
Setting (i): width- MLP with ReLU activation, with large batches ().
Setting (ii): width- MLP with ReLU activation, with tiny batches ().
Setting (iv): width- MLP with polynomial activation ().
Setting (vii): width- MLP with a piecewise linear -zigzag activation ().
Setting (x): width- MLP with a sinusoidal activation (scaled and shifted for ; see the discussion in Section D.1.2) ().
Setting (xv): degree- PolyNet ().
The first row uses the width- ReLU MLP configuration , holding and while varying the task difficulty across 6 settings: . The remaining plots are all for the setting.
The second row uses the PolyNet configuration , varying .
The third row uses the minimally-wide configurations (*i), (*ii), (vii), (viii), (xi), (xii) (thus presenting an example for each non-standard activation), holding batch size . in each of the cases except (*ii), where .
The fourth row uses three large architectures: settings (iii), (vi), and (*iii), where (*iii) uses the Adam optimizer instead of SGD.
Figure 10 contains scaling plots in various settings for the median convergence time. Below, we give comprehensive details about these settings. For each of these runs, we chose a fixed batch size (smaller batch sizes exhibited additional variance, while with larger batch sizes the models were slower to converge), as well as the hinge loss. We used SGD with a constant learning rate (enumerated below), except in setting (*iii).
The top row shows MLP settings (i) through (vi). From left to right:
Setting (i): width- MLP with ReLU activation ().
Setting (ii): width- MLP with ReLU activation ().
Setting (iii): width- MLP with ReLU activation ().
Setting (iv): width- MLP with activation ().
Setting (v): width- MLP with activation ().
Setting (vi): width- MLP with activation ().
The bottom row shows miscellaneous settings. From left to right:
Setting (vii): width- MLP with degree- oscillating polynomial activation interpolating the parity function ().
Setting (xiv): single sinusoidal neuron with no second layer ().
Setting (iv): degree- PolyNet ().
Setting (ii): width- MLP with ReLU activation (), showing an expanded range of for smaller .
Setting (xv): width- MLP with activation (), showing an expanded range of for smaller .
D.2 Training curves and convergence time plots
For all example training curves in all figures (in Sections 3 and 5, as well as the appendix), population losses and accuracies are approximated using a batch of size , sampled once at the beginning of training from the same distribution . All plots of single representative training runs use a fixed random seed (torch.manual_seed(0)); when training runs are shown, seeds are used.
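A minimal sketch of the data pipeline described here, assuming inputs drawn uniformly from $\{\pm 1\}^n$ and labels equal to the product of the bits in a hidden subset. A seeded `random.Random(0)` stands in for `torch.manual_seed(0)`; the validation batch is drawn once and reused, while training draws a fresh batch at every step. The problem size, subset, and batch sizes are illustrative, not the values used in the experiments, and the function names are hypothetical.

```python
import math
import random


def sample_sparse_parity(batch_size: int, n: int, subset, rng: random.Random):
    """Draw x uniformly from {-1,+1}^n and label it with the parity of the bits in `subset`."""
    xs = [[rng.choice((-1, 1)) for _ in range(n)] for _ in range(batch_size)]
    ys = [math.prod(x[i] for i in subset) for x in xs]
    return xs, ys


def accuracy(predict, xs, ys) -> float:
    """Fraction of examples where sign(predict(x)) matches the +/-1 label."""
    correct = sum(1 for x, y in zip(xs, ys) if (1 if predict(x) >= 0 else -1) == y)
    return correct / len(ys)


rng = random.Random(0)           # plays the role of torch.manual_seed(0)
n, subset = 20, (3, 7, 11)       # illustrative problem size and hidden subset
val_x, val_y = sample_sparse_parity(2000, n, subset, rng)  # sampled once, then reused

# Training would draw a fresh batch at every step (online i.i.d. samples), e.g.:
train_x, train_y = sample_sparse_parity(32, n, subset, rng)

oracle = lambda x: math.prod(x[i] for i in subset)
constant = lambda x: 1.0
print(accuracy(oracle, val_x, val_y), accuracy(constant, val_x, val_y))
```

Reusing one fixed validation batch keeps the plotted curves comparable across iterations, at the cost of a small, constant estimation error in the population accuracy.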
In Figures 7 and 8, validation accuracies were recorded every iterations, and a run was recorded as converged if it reached accuracy within iterations; we report the percentile over 25 random seeds, to reduce variance arising from the more initialization-sensitive settings. In Figure 9, coarse-grained scaling estimates for the ( percentile) convergence time are computed as follows: for , the smallest is chosen such that , choosing , so that at . These estimates are meant only as quantitative order-of-magnitude upper bounds on the convergence time: the power-law convergence times do not extrapolate at a constant learning rate (see Figure 2 (right), the “larger ” plots in Figure 10, and the discussion on batch sizes and learning rates in Appendix C.2).
To reduce computational load, for the larger-scale probes of convergence times , validation accuracies were instead checked on a sample of size . For the underparameterized networks (i.e. those unable to represent the parity, but still able to obtain a meaningful gradient signal), this threshold was changed to consecutive batches with accuracy at least . Note that for parity learning in particular, a weak learner can be converted into a strong learner: there is an efficient algorithm (Goldreich and Levin, 1989, Kushilevitz and Mansour, 1993) which, given a classifier that achieves accuracy at least $1/2 + \varepsilon$ for a constant $\varepsilon > 0$, outputs the hidden subset of relevant bits with high probability.
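The two convergence criteria (the first evaluation reaching the accuracy threshold, or a streak of consecutive evaluations above it) and the percentile over seeds can be sketched as follows. The threshold, the patience, and the nearest-rank percentile convention are assumptions standing in for the unspecified values; non-converged runs are recorded as infinite so that they sort last. The function names are hypothetical.

```python
import math


def convergence_time(eval_steps, accuracies, threshold: float, patience: int = 1) -> float:
    """First evaluation step at which accuracy has been >= threshold for `patience`
    consecutive evaluations; math.inf if that never happens within the budget."""
    streak = 0
    for step, acc in zip(eval_steps, accuracies):
        streak = streak + 1 if acc >= threshold else 0
        if streak >= patience:
            return step
    return math.inf


def nearest_rank_percentile(values, q: float) -> float:
    """Nearest-rank q-th percentile (0 < q <= 100); non-converged runs (inf) sort last."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q / 100 * len(ordered)))
    return ordered[rank - 1]


# Per-seed evaluation logs (made-up numbers, for illustration only).
steps = [1000, 2000, 3000, 4000]
runs = [[0.5, 0.5, 1.0, 1.0], [0.5, 1.0, 1.0, 1.0], [0.5, 0.5, 0.5, 0.5]]
times = [convergence_time(steps, accs, threshold=0.99) for accs in runs]
print(times, nearest_rank_percentile(times, 50))
```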
In the median convergence time plots in Figure 1 (right), Figure 2 (right), and Figure 10, error bars for median convergence times are confidence intervals, computed from bootstrap samples. Each point on each curve corresponds to random trials. Halted curves signify that more than half of the runs failed to converge within iterations (hence, infinite medians).
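The bootstrap error bars can be sketched as a percentile bootstrap over the per-seed convergence times, with non-converged runs recorded as `math.inf`. The number of resamples and the confidence level below are placeholders, and `bootstrap_median_ci` is a hypothetical helper. The sketch uses `statistics.median_low`, which always returns an actual sample, so a median is infinite exactly when more than half of the runs fail to converge; whether the original plots used this median convention is an assumption.

```python
import math
import random
import statistics


def bootstrap_median_ci(times, num_resamples: int = 1000, confidence: float = 0.95, seed: int = 0):
    """Median convergence time plus a percentile-bootstrap confidence interval.

    Non-converged runs are math.inf; median_low picks an actual sample, so a
    (resampled) median is infinite exactly when more than half the draws are inf.
    """
    rng = random.Random(seed)
    n = len(times)
    medians = sorted(
        statistics.median_low([times[rng.randrange(n)] for _ in range(n)])
        for _ in range(num_resamples)
    )
    alpha = (1.0 - confidence) / 2.0
    lower = medians[int(alpha * num_resamples)]
    upper = medians[min(num_resamples - 1, int((1.0 - alpha) * num_resamples))]
    return statistics.median_low(times), (lower, upper)


# Usage with made-up per-seed convergence times (two runs did not converge).
times = [1200, 900, 1500, math.inf, 1100, 1300, math.inf, 1000]
print(bootstrap_median_ci(times))
```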
D.3 Implementation, hardware, and compute time
All training experiments were implemented using PyTorch (Paszke et al., 2019).
Although most of the networks considered in the main empirical results are relatively small, a large total number of models were trained to certify the “robust space” of results and obtain precise scaling curves. These individual experiments were not large enough to benefit from GPU acceleration; on an internal cluster, the CPU compute expenditure totaled approximately CPU hours.
A subset of these experiments stood to benefit from GPU acceleration: width MLPs; scaling behaviors for ; all experiments involving Transformers. These were performed with NVIDIA Tesla P100, Tesla P40, and RTX A6000 GPUs on an internal cluster, consuming a total of approximately GPU hours.