Which Algorithmic Choices Matter at Which Batch Sizes? Insights From a Noisy Quadratic Model
Guodong Zhang, Lala Li, Zachary Nado, James Martens, Sushant Sachdeva, George E. Dahl, Christopher J. Shallue, Roger Grosse
Introduction
Increasing the batch size is one of the most appealing ways to accelerate neural network training on data parallel hardware. Larger batch sizes yield better gradient estimates and, up to a point, reduce the number of steps required for training, which reduces the training time. The importance of understanding the benefits of modern parallel hardware has motivated a lot of recent work on training neural networks with larger batch sizes (Goyal et al., 2017; Osawa et al., 2018; McCandlish et al., 2018; Shallue et al., 2018). To date, the most comprehensive empirical study of the effects of batch size on neural network training is Shallue et al. (2018), who confirmed that increasing the batch size initially achieves perfect scaling (i.e. doubling the batch size halves the number of steps needed) up to a problem-dependent critical batch size, beyond which it yields diminishing returns (Balles et al., 2017; Goyal et al., 2017; Jastrzębski et al., 2018; McCandlish et al., 2018). Shallue et al. (2018) also provided experimental evidence that the critical batch size depends on the optimization algorithm, the network architecture, and the data set. However, their experiments only covered plain SGD, SGD with (heavy-ball) momentum, and SGD with Nesterov momentum, leaving open the enticing possibility that other optimizers might extend perfect scaling to even larger batch sizes.
Empirical scaling curves like those in Shallue et al. (2018) are essential for understanding the effects of batch size, but generating such curves, even for a single optimizer on a single task, can be very expensive. On the other hand, existing theoretical analyses that attempt to analytically derive critical batch sizes (e.g. Ma et al. (2018); Yin et al. (2018); Jain et al. (2018)) do not answer our questions about which optimizers scale the best with batch size. They tend to make strong assumptions, produce parameter-dependent results that are difficult to apply, or are restricted to plain SGD. It would be ideal to find a middle ground between a purely empirical investigation and theoretical analysis by building a model of neural network optimization problems that captures the essential behavior we see in real neural networks, while still being easy to understand. Additionally, we need to study optimizers beyond momentum SGD since they might provide us an approach to exploit speedups from the very largest batch sizes. In this work, we make the following contributions:
We show that a simple noisy quadratic model (NQM) is remarkably consistent with the batch size effects observed in real neural networks, while allowing us to run experiments in seconds, making it a great tool to generate testable predictions about neural network optimization.
We show that the NQM successfully predicts that momentum should speed up training relative to plain SGD at larger batch sizes, but should have no benefit at small batch sizes.
Through large scale experiments with Adam (Kingma and Ba, 2014) and K-FAC (Martens and Grosse, 2015), we confirm that, as predicted by the NQM, preconditioning extends perfect batch size scaling to larger batch sizes than are possible with momentum SGD alone. Furthermore, unlike momentum, preconditioning can help at small batch sizes as well.
Lastly, we show that, as predicted by the NQM, exponential moving averages reduce the number of steps required for a specific batch size and can achieve the same acceleration with smaller batch sizes, thereby saving computation.
Related Work
In a classic paper, Bottou and Bousquet (2008) studied the asymptotics of stochastic optimization algorithms and found SGD to be competitive with fancier approaches. They showed that stochastic optimization involves fundamentally different tradeoffs from full-batch optimization. More recently, several studies have investigated the relationship between batch size and training time for neural networks. Chen et al. (2018) studied the effect of network width on the critical batch size, and showed experimentally that it depends on both the data set and network architecture. Golmant et al. (2018) studied how various heuristics for adjusting the learning rate as a function of batch size affect the relationship between batch size and training time. Shallue et al. (2018) conducted a comprehensive empirical study on the relationship between batch size and training time with different neural network architectures and data sets using plain SGD, heavy-ball momentum, and Nesterov momentum. Finally, McCandlish et al. (2018) used the average gradient noise over training to predict the critical batch size. All of these studies described a basic relationship between batch size and training steps to a fixed error goal, which comprises three regions: perfect scaling initially, then diminishing returns, and finally no benefit for all batch sizes greater than the critical batch size.
Other studies have attempted to characterize the critical batch size analytically in stochastic optimization. Under varying assumptions, Ma et al. (2018); Yin et al. (2018); Jain et al. (2018) all derived analytical notions of critical batch size, but to our knowledge, all for SGD.
Additionally, previous studies have shown that SGD and momentum SGD are equivalent for small learning rates (after appropriate rescaling), both in the continuous limit (Leen and Orr, 1994) and in discrete settings (Yuan et al., 2016). However, they do not explain why momentum SGD (including heavy-ball and Nesterov momentum) sometimes outperforms plain SGD in mini-batch training (as observed by Kidambi et al. (2018) and Shallue et al. (2018)). Concurrently, Smith et al. (2019) showed that momentum outperforms plain SGD at large batch sizes.
Finally, a few works study averaging the iterates rather than using the last iterate. This is a classical idea in optimization, where it is known to improve convergence (Polyak and Juditsky, 1992; Bach and Moulines, 2013; Dieuleveut and Bach, 2016). However, most of these works focus on tail averaging, which requires deciding ahead of time the iteration at which to start accumulating the running average. More commonly (especially in deep learning), an exponential moving average (Martens, 2014) is preferred for its simplicity and its ability to handle non-convex landscapes. However, to our knowledge it has not been analyzed, especially in the mini-batch setting.
Analysis of the Noisy Quadratic Model (NQM)
In this section, we work with a noisy quadratic model (NQM), a stochastic optimization problem whose dynamics can be simulated analytically, in order to reason about various phenomena encountered in training neural networks. In this highly simplified model, we first assume the loss function being optimized is a convex quadratic, with noisy observations of the gradient. For analytic tractability, we further assume the noise covariance is codiagonalizable with the Hessian. Because we are not interested in modeling overfitting effects, we focus on the online training setting, where the observations are drawn i.i.d. in every training iteration. Under these assumptions, we derive an analytic expression for the risk after any number of steps of SGD with a fixed step size, as well as a dynamic programming method to compute the risk following a given step size schedule.
Convex quadratics may appear an odd model for a complicated nonconvex optimization landscape. However, one obtains a convex quadratic objective by linearizing the network’s function around a given weight vector and taking the second-order Taylor approximation to the loss function (assuming it is smooth and convex). Indeed, recent theoretical works (Jacot et al., 2018; Du et al., 2019; Zhang et al., 2019a) show that for wide enough networks, the weights stay close enough to the initialization for the linearized approximation to remain accurate. Empirically, linearized approximations closely match a variety of training phenomena for large but realistic networks (Lee et al., 2019).
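We now introduce the noisy quadratic model (Schaul et al., 2013; Martens, 2014; Wu et al., 2018), where the true function being optimized is a convex quadratic. Because we analyze rotation-invariant and translation-invariant optimizers such as SGD and heavy-ball momentum, we assume without loss of generality that the quadratic form is diagonal, and that the optimum is at the origin. Hence, our exact cost function decomposes as a sum of scalar quadratic functions for each coordinate:
$$\mathcal{L}(\boldsymbol{\theta}) = \sum_{i=1}^{d} \ell(\theta_i) = \sum_{i=1}^{d} \frac{1}{2} h_i \theta_i^2. \tag{1}$$
We take the gradient noise covariance $\mathbf{C}$ to be diagonal as well, with entries $c_i$, so that a mini-batch of size $B$ gives the stochastic gradient $h_i\theta_i + \sqrt{c_i/B}\,\xi_i$ in coordinate $i$. The SGD update is then
$$\theta_i(t+1) = (1 - \alpha h_i)\,\theta_i(t) - \alpha\sqrt{c_i/B}\,\xi_i, \tag{2}$$
where $\alpha$ is the learning rate and $\xi_i$ is zero-mean, unit-variance i.i.d. noise. By treating $\theta_i(t)$ as a random variable, we immediately obtain the dynamics of its mean and variance:
$$\mathbb{E}[\theta_i(t+1)] = (1 - \alpha h_i)\,\mathbb{E}[\theta_i(t)], \qquad \mathbb{V}[\theta_i(t+1)] = (1 - \alpha h_i)^2\,\mathbb{V}[\theta_i(t)] + \frac{\alpha^2 c_i}{B}. \tag{3}$$
Based on eqn. (3), the expected risk after $t$ steps in a given dimension $i$ is
$$\mathbb{E}[\ell(\theta_i(t))] = (1 - \alpha h_i)^{2t}\,\mathbb{E}[\ell(\theta_i(0))] + \left(1 - (1 - \alpha h_i)^{2t}\right)\frac{\alpha c_i}{2B(2 - \alpha h_i)}, \tag{4}$$
where we have assumed that $0 < \alpha h_i < 2$. (Note that this can be seen as a special case of the convergence result derived for convex quadratics in Martens (2014).)
Remarkably, each dimension converges exponentially to a steady state risk. Unfortunately, there is a trade-off: higher learning rates (up to $1/h_i$) give faster convergence to the steady state risk, but also produce higher values of the steady state risk. The steady state risk is also inversely proportional to the batch size; this is important to note because in the following subsections, we will show that traditional acceleration techniques (e.g., momentum and preconditioning) help improve the convergence rate at the expense of increasing the steady state risk. Therefore, the NQM implies that momentum and preconditioning would benefit more from large-batch training compared to plain SGD, as shown in later sections.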
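To make eqns. (3) and (4) concrete, here is a minimal Python sketch for a single coordinate. The function names and the numbers in the usage lines are illustrative assumptions, not taken from the paper. It checks that iterating the mean/variance recursion reproduces the closed-form risk, and that the steady-state term grows with the learning rate and halves when the batch size doubles.

```python
# Illustrative helper names; a sketch of eqns. (3)-(4), not the authors' code.

def sgd_steady_state(h, c, alpha, batch):
    """Steady-state risk alpha*c / (2B(2 - alpha*h)), reached as t -> infinity."""
    return alpha * c / (2.0 * batch * (2.0 - alpha * h))


def sgd_risk_closed_form(h, c, alpha, batch, t, init_loss):
    """E[l(theta_i(t))] for one NQM coordinate under SGD (eqn. (4)); assumes 0 < alpha*h < 2."""
    decay = (1.0 - alpha * h) ** (2 * t)
    return decay * init_loss + (1.0 - decay) * sgd_steady_state(h, c, alpha, batch)


def sgd_risk_by_recursion(h, c, alpha, batch, t, mean0, var0):
    """Iterate the mean/variance dynamics of eqn. (3), then evaluate 0.5 * h * E[theta^2]."""
    mean, var = mean0, var0
    for _ in range(t):
        mean = (1.0 - alpha * h) * mean
        var = (1.0 - alpha * h) ** 2 * var + alpha ** 2 * c / batch
    return 0.5 * h * (mean ** 2 + var)


# Hypothetical values: h = c = 0.5, learning rate 0.3, batch size 4, theta(0) = 1 exactly.
h, c, alpha, batch = 0.5, 0.5, 0.3, 4
print(sgd_risk_closed_form(h, c, alpha, batch, t=25, init_loss=0.5 * h))
print(sgd_risk_by_recursion(h, c, alpha, batch, t=25, mean0=1.0, var0=0.0))
```

The recursion is what the dynamic-programming approach later generalizes to time-varying learning rates. The closed form is convenient when the learning rate is fixed.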
3.2 Momentum Accelerates Training at Large Batch Sizes
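Applied to the same noisy quadratic model as before, the update equations for momentum SGD are:
$$m_i(t+1) = \beta\, m_i(t) + h_i\theta_i(t) + \sqrt{c_i/B}\,\xi_i, \qquad \theta_i(t+1) = \theta_i(t) - \alpha\, m_i(t+1), \tag{5}$$
where $\beta$ is the momentum decay parameter and $m_i(0) = 0$.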
We show in the following theorem (see Appendix C for proof) that momentum SGD performs similarly to plain SGD in the regime of small batch sizes but helps in the large-batch regime, which can be viewed as a near-deterministic optimization problem.
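Theorem 1. Given a dimension index $i$, $0 \le \beta < 1$ with $\beta \neq (1 - \sqrt{\alpha h_i})^2$, and $0 < \alpha h_i < 2(1 + \beta)$, the expected risk at time $t$ associated with that dimension satisfies the upper bound
$$\mathbb{E}[\ell(\theta_i(t))] \le \left(\frac{r_1^{t+1} - r_2^{t+1}}{r_1 - r_2} - \beta\,\frac{r_1^{t} - r_2^{t}}{r_1 - r_2}\right)^2 \mathbb{E}[\ell(\theta_i(0))] + \frac{(1+\beta)\,\alpha c_i}{2B(2\beta + 2 - \alpha h_i)(1-\beta)}, \tag{6}$$
where $r_1$ and $r_2$ (with $|r_1| \ge |r_2|$) are the two roots of the quadratic equation $x^2 - (1 - \alpha h_i + \beta)\,x + \beta = 0$.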
As with plain SGD (cf. eqn. (4)), the loss associated with each dimension can be expressed as the sum of two terms, where the first one decays exponentially and corresponds to the behavior of the deterministic version of the algorithm, and the second remains constant.
Following the existing treatment of the deterministic version of the algorithm (Chiang, 1974; Qian, 1999; Yang et al., 2018; Goh, 2017), we divide our analysis into two cases: overdamping and underdamping. In the case of overdamping, where $\beta < (1 - \sqrt{\alpha h_i})^2$, both roots $r_1$ and $r_2$ are real and therefore the convergence rate is determined by the larger one (i.e. $r_1$), which has the value
$$r_1 = \frac{1 - \alpha h_i + \beta + \sqrt{(1 - \alpha h_i + \beta)^2 - 4\beta}}{2}. \tag{7}$$
With a fixed learning rate, the steady state risk will be constant, and the best achievable expected risk will be lower bounded by it. Thus, to achieve a certain target loss we must either drive the learning rate down or the batch size up. Assuming a small batch size and a low target risk, we are forced to pick a small learning rate, in which case one can show that $r_1 \approx 1 - \frac{\alpha h_i}{1-\beta}$. (To see this, note that the term in the square root of eqn. (7) for $r_1$ can be written as $(1-\beta)^2 - 2(1+\beta)\alpha h_i + \alpha^2 h_i^2$. Dropping the $\alpha^2 h_i^2$ term and simplifying gives the claimed expression for $r_1$.) In Figure 2 we plot the convergence rate as a function of $\beta$, and we indeed observe that the convergence rate closely matches $1 - \frac{\alpha h_i}{1-\beta}$, assuming a relatively small learning rate. We further note that the convergence rate and steady state risk of eqn. (6) are approximately the same as the ones in plain SGD (eqn. (4)), except that they use an "effective learning rate" of $\frac{\alpha}{1-\beta}$. To help validate these predictions, in Appendix E.3 we provide a comparison of momentum SGD with plain SGD using the effective learning rate.
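In the case of underdamping, where $\beta > (1 - \sqrt{\alpha h_i})^2$, both $r_1$ and $r_2$ will be complex and have norm $\sqrt{\beta}$. We note that the optimal $\beta$ should be equal to or smaller than $(1 - \sqrt{\alpha h_d})^2$, where $h_d$ is the smallest curvature, since otherwise all dimensions are underdamped, and we can easily improve the convergence rate and steady state risk by reducing $\beta$.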
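The two damping regimes can be checked directly from the roots of $x^2 - (1 - \alpha h + \beta)x + \beta = 0$. The sketch below is a minimal illustration under assumed values ($h = 10^{-3}$, $\alpha = 0.1$, so $\alpha h = 10^{-4}$); the helper names are hypothetical. It shows that an overdamped $\beta$ gives a real leading root close to $1 - \alpha h/(1-\beta)$, and that an underdamped $\beta$ gives a complex pair of modulus $\sqrt{\beta}$.

```python
import cmath
import math

# Hypothetical helpers illustrating the overdamped/underdamped analysis.


def momentum_roots(h, alpha, beta):
    """Roots of x^2 - (1 - alpha*h + beta) x + beta = 0, larger modulus first."""
    b = 1.0 - alpha * h + beta
    disc = cmath.sqrt(b * b - 4.0 * beta)  # complex sqrt also covers the underdamped case
    r1, r2 = (b + disc) / 2.0, (b - disc) / 2.0
    return (r1, r2) if abs(r1) >= abs(r2) else (r2, r1)


def is_overdamped(h, alpha, beta):
    """Overdamping condition beta < (1 - sqrt(alpha*h))^2 (small learning rate assumed)."""
    return beta < (1.0 - math.sqrt(alpha * h)) ** 2


h, alpha = 1e-3, 0.1
r1, _ = momentum_roots(h, alpha, beta=0.9)  # overdamped: threshold is 0.99^2 = 0.9801
print(is_overdamped(h, alpha, 0.9), r1.real, 1.0 - alpha * h / (1.0 - 0.9))
r1, r2 = momentum_roots(h, alpha, beta=0.99)  # underdamped
print(is_overdamped(h, alpha, 0.99), abs(r1), abs(r2), math.sqrt(0.99))
```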
Next we observe that the convergence of the total loss will eventually be dominated by the slowest converging dimension (which corresponds to the smallest curvature $h_d$), and this will be in the overdamping regime as argued above. By our analysis of the overdamping case, we can achieve the same convergence rate for this dimension by simply replacing the learning rate $\alpha$ in the bound for plain SGD (eqn. (4)) with the effective learning rate $\frac{\alpha}{1-\beta}$.
So while momentum gives no long-term training acceleration for very low fixed learning rates (which we are forced to use when the batch size is small), it can help in large-batch training. With $\alpha h_i \ll 1$, the steady state risk is amplified by a factor of roughly $\frac{1}{1-\beta}$, and the steady state risk also decreases in proportion to increases in batch size. Therefore, we expect momentum SGD to exhibit perfect scaling up to larger batch sizes than plain SGD.
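The following sketch checks the momentum bound numerically under stated assumptions. With $m(0) = 0$, momentum SGD on one coordinate is the second-order recursion $\theta_{t+1} = (1 + \beta - \alpha h)\theta_t - \beta\theta_{t-1} - \alpha\sqrt{c/B}\,\xi_t$ started from $\theta_{-1} = \theta_0$, so its exact second moments can be propagated without sampling. The helper names and parameter values are illustrative. The checks confirm that the exact risk never exceeds the bound of Theorem 1 (eqn. (6)), converges to its steady-state term, and that for $\alpha h \ll 1$ this term is about $1/(1-\beta)$ times the plain-SGD one.

```python
import cmath

# Illustrative sketch; helper names are hypothetical.


def momentum_risks(h, c, alpha, beta, batch, steps, theta0_sq=1.0):
    """Exact E[l(theta_t)] for t = 0..steps-1 from second moments of (theta_t, theta_{t-1})."""
    a1, a2, q = 1.0 + beta - alpha * h, -beta, alpha ** 2 * c / batch
    s00 = s01 = s11 = theta0_sq  # E[th_t^2], E[th_t th_{t-1}], E[th_{t-1}^2]; theta_{-1} = theta_0
    risks = []
    for _ in range(steps):
        risks.append(0.5 * h * s00)
        s00, s01, s11 = (a1 * a1 * s00 + 2 * a1 * a2 * s01 + a2 * a2 * s11 + q,
                         a1 * s00 + a2 * s01,
                         s00)
    return risks


def momentum_steady_state(h, c, alpha, beta, batch):
    """Second term of eqn. (6); reduces to the SGD steady state when beta = 0."""
    return (1 + beta) * alpha * c / (2 * batch * (2 * beta + 2 - alpha * h) * (1 - beta))


def theorem1_bound(h, c, alpha, beta, batch, t, init_loss):
    b = 1.0 + beta - alpha * h
    disc = cmath.sqrt(b * b - 4.0 * beta)
    r1, r2 = (b + disc) / 2.0, (b - disc) / 2.0  # assumes distinct roots

    def s(k):
        return (r1 ** (k + 1) - r2 ** (k + 1)) / (r1 - r2)

    coef = (s(t) - beta * s(t - 1)).real
    return coef ** 2 * init_loss + momentum_steady_state(h, c, alpha, beta, batch)


risks = momentum_risks(0.5, 0.5, 0.2, 0.5, 4, steps=2000)
print(risks[-1], momentum_steady_state(0.5, 0.5, 0.2, 0.5, 4))
print(momentum_steady_state(1e-3, 1e-3, 0.1, 0.9, 1) / momentum_steady_state(1e-3, 1e-3, 0.1, 0.0, 1))
```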
3.3 Preconditioning Further Extends Perfect Scaling to Larger Batch Sizes
Many optimizers, such as Adam and K-FAC, can be viewed as preconditioned gradient descent methods. In each update, the gradient is rescaled by $\mathbf{P}^{-1}$, i.e. $\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \alpha \mathbf{P}^{-1}\mathbf{g}$, where the positive definite matrix $\mathbf{P}$ is called the preconditioner.
In lieu of trying to construct noisy quadratic analogues of particular optimizers, we analyze preconditioners of the form $\mathbf{P} = \mathbf{H}^p$ with $0 \le p \le 1$. Note that $\mathbf{P}$ remains fixed throughout training since the Hessian is constant in the NQM. We can recover standard SGD by setting $p = 0$.
In each coordinate, preconditioned SGD behaves like plain SGD with curvature $h_i^{1-p}$ and gradient noise scaled by $h_i^{-p}$, so its steady state risk is
$$\frac{\alpha c_i h_i^{-p}}{2B(2 - \alpha h_i^{1-p})},$$
which (for $h_i \le 1$) is a monotonically increasing function with respect to $p$. Even without this amplification effect, the steady state risk will eventually become the limiting factor in the minimization of the expected risk. One way to reduce the steady state risk, apart from using Polyak averaging (Polyak and Juditsky, 1992) or decreasing the learning rate (which will harm the rate of convergence), is to increase the batch size. This suggests that the benefits of using stronger preconditioners will be more clearly observed for larger batch sizes, which is an effect that we empirically demonstrate in later sections.
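A short sketch of the preconditioned steady state, assuming $\mathbf{P} = \mathbf{H}^p$ as above; the function names and parameter values are illustrative. It compares the closed-form steady-state risk with the fixed point of the variance recursion for $\theta \leftarrow (1 - \alpha h^{1-p})\theta - \alpha h^{-p}\sqrt{c/B}\,\xi$, and checks that, for a curvature below 1 with $c = h$, the steady state grows as the preconditioning power $p$ increases.

```python
# Illustrative sketch of preconditioned SGD with P = H^p on one NQM coordinate.


def precond_steady_state(h, c, alpha, batch, p):
    """Closed-form steady-state risk alpha*c*h^(-p) / (2B(2 - alpha*h^(1-p)))."""
    h_eff = h ** (1.0 - p)  # curvature seen by the preconditioned update
    return alpha * c * h ** (-p) / (2.0 * batch * (2.0 - alpha * h_eff))


def precond_steady_state_by_recursion(h, c, alpha, batch, p, steps=20000):
    """Fixed point of the variance recursion, converted to risk 0.5 * h * Var."""
    a = 1.0 - alpha * h ** (1.0 - p)
    noise_var = alpha ** 2 * h ** (-2.0 * p) * c / batch
    var = 0.0
    for _ in range(steps):
        var = a * a * var + noise_var
    return 0.5 * h * var


# Hypothetical low-curvature coordinate with C = H.
for p in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(p, precond_steady_state(0.01, 0.01, 0.5, 4, p))
```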
3.4 Exponential Moving Average Reduces Steady State Risk
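Following the same procedure as in the previous two sections, we analyze exponential moving averages (EMA) on our NQM. The update rule of EMA can be written as
$$\theta_i(t+1) = (1 - \alpha h_i)\,\theta_i(t) - \alpha\sqrt{c_i/B}\,\xi_i, \qquad \tilde{\theta}_i(t+1) = \gamma\,\tilde{\theta}_i(t) + (1 - \gamma)\,\theta_i(t+1),$$
where $\gamma$ is the averaging coefficient, $\tilde{\theta}_i(0) = \theta_i(0)$, and the loss is evaluated at the averaged iterate $\tilde{\theta}_i$.
Theorem 2. Given a dimension index $i$, $0 \le \gamma < 1$ with $\gamma \neq 1 - \alpha h_i$, and $0 < \alpha h_i < 2$, the expected risk at time $t$ associated with that dimension satisfies the upper bound
$$\mathbb{E}[\ell(\tilde{\theta}_i(t))] \le \left(\frac{r_1^{t+1} - r_2^{t+1}}{r_1 - r_2} - (1 - \alpha h_i)\gamma\,\frac{r_1^{t} - r_2^{t}}{r_1 - r_2}\right)^2 \mathbb{E}[\ell(\theta_i(0))] + \frac{\alpha c_i}{2B(2 - \alpha h_i)} \cdot \frac{(1-\gamma)\left(1 + (1-\alpha h_i)\gamma\right)}{(1+\gamma)\left(1 - (1-\alpha h_i)\gamma\right)}, \tag{12}$$
where $r_1 = 1 - \alpha h_i$ and $r_2 = \gamma$.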
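By choosing an averaging coefficient $\gamma < 1 - \alpha h_i$, so that $r_2 < r_1$, one can show that EMA reduces the steady state risk without sacrificing the convergence rate. To see this, note that the second factor of the steady state term in eqn. (12), $\frac{(1-\gamma)(1 + (1-\alpha h_i)\gamma)}{(1+\gamma)(1 - (1-\alpha h_i)\gamma)}$, is strictly less than 1 for $\gamma > 0$ because $1 - \alpha h_i < 1$, while the first factor is exactly the steady state risk of plain SGD.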
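Here is a minimal numerical check of the EMA analysis, assuming $\tilde{\theta}(0) = \theta(0)$ and illustrative values ($h = c = 0.5$, $\alpha = 0.2$ so $1 - \alpha h = 0.9$, $\gamma = 0.5$, $B = 4$). The helper names are hypothetical. The sketch propagates the exact joint second moments of $(\theta_t, \tilde{\theta}_t)$ and checks that the risk at the averaged iterate stays below the bound of eqn. (12), settles at the plain-SGD steady state times the EMA factor, and that this factor is below 1.

```python
# Illustrative sketch of SGD + EMA on one NQM coordinate; names are hypothetical.


def ema_risks(h, c, alpha, gamma, batch, steps, theta0_sq=1.0):
    """Exact E[l(theta_avg(t))] via second moments of (theta_t, theta_avg_t), theta_avg(0) = theta(0)."""
    a, q = 1.0 - alpha * h, alpha ** 2 * c / batch
    s_tt = s_ta = s_aa = theta0_sq  # E[th^2], E[th * th_avg], E[th_avg^2]
    risks = []
    for _ in range(steps):
        risks.append(0.5 * h * s_aa)
        next_tt = a * a * s_tt + q  # E[theta_{t+1}^2]
        s_ta, s_aa = (gamma * a * s_ta + (1 - gamma) * next_tt,
                      gamma ** 2 * s_aa + 2 * gamma * (1 - gamma) * a * s_ta
                      + (1 - gamma) ** 2 * next_tt)
        s_tt = next_tt
    return risks


def ema_factor(h, alpha, gamma):
    a = 1.0 - alpha * h
    return (1 - gamma) * (1 + a * gamma) / ((1 + gamma) * (1 - a * gamma))


def theorem2_bound(h, c, alpha, gamma, batch, t, init_loss):
    r1, r2 = 1.0 - alpha * h, gamma  # assumes r1 != r2

    def s(k):
        return (r1 ** (k + 1) - r2 ** (k + 1)) / (r1 - r2)

    coef = s(t) - r1 * r2 * s(t - 1)
    sgd_steady = alpha * c / (2.0 * batch * (2.0 - alpha * h))
    return coef ** 2 * init_loss + sgd_steady * ema_factor(h, alpha, gamma)


risks = ema_risks(0.5, 0.5, 0.2, 0.5, 4, steps=3000)
print(risks[-1], ema_factor(0.5, 0.2, 0.5))
```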
3.5 Choice of H and C
We’ve found that the qualitative behavior of optimizers in our NQM depends on the choices of $\mathbf{H}$ and $\mathbf{C}$. Therefore, we choose matrices motivated by theoretical and empirical considerations about neural net training. First, we set the diagonal entries of $\mathbf{H}$ to be $h_i = 1/i$ for $i = 1, \ldots, d$ and some integer $d$, giving a condition number of $d$. This closely matches the estimated eigenspectrum of the Hessian of a convolutional network (see Figure 9 and Appendix E.4), and is also consistent with recent work finding heavy tailed eigenspectra of neural network Hessians (Ubaru et al., 2017; Ghorbani et al., 2019). We choose $d$ to approximately match the condition number of the K-FAC Hessian approximation for ResNet8. (Qualitative behaviors were consistent for a wide range of $d$.)
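A sketch of how such a spectrum might be built, and of how it could be compressed into the 100 bins used by the noisy quadratic experiments. The log-spaced bins, the choice to represent each bin by its mean eigenvalue, and $d = 10{,}000$ in the usage are all assumptions made for illustration; the function names are hypothetical.

```python
import math

# Hypothetical helpers; the binning scheme is an assumption, not the authors' procedure.


def nqm_spectrum(d):
    """Diagonal of H with h_i = 1/i for i = 1..d (condition number d)."""
    return [1.0 / i for i in range(1, d + 1)]


def quantize(eigs, num_bins=100):
    """Group eigenvalues into log-spaced bins; return (mean eigenvalue, count) per non-empty bin."""
    log_lo, log_hi = math.log(min(eigs)), math.log(max(eigs))
    width = (log_hi - log_lo) / num_bins
    sums, counts = [0.0] * num_bins, [0] * num_bins
    for e in eigs:
        k = min(int((math.log(e) - log_lo) / width), num_bins - 1)
        sums[k] += e
        counts[k] += 1
    return [(s / n, n) for s, n in zip(sums, counts) if n > 0]


eigs = nqm_spectrum(10_000)
bins = quantize(eigs)
print(max(eigs) / min(eigs), len(bins), bins[:3])
```

Because $h_i = 1/i$ is sparse at the top of the spectrum, several high-curvature bins end up empty. Most of the mass sits in a long tail of small eigenvalues.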
We also set $\mathbf{C} = \mathbf{H}$ (a nontrivial assumption). This was motivated by theoretical arguments that, under the assumption that the implicit conditional distribution over the network’s output is close to the conditional distribution of targets from the training distribution, the Hessian closely matches the gradient covariance in neural network training (Martens, 2014). Empirically, this relationship appears to hold tightly for a convolutional network and moderately well for a transformer (see Appendix E.2).
3.6 Information Theoretic Lower Bound
Since our NQM assumes the infinite data (online optimization) setting, it’s instructive to compare the performance of optimizers against an information theoretic lower bound. Specifically, under the assumption that $\mathbf{C} = \mathbf{H}$, the NQM is equivalent to maximum likelihood estimation of the mean vector of a multivariate Gaussian distribution with covariance $\mathbf{H}^{-1}$. Hence, the risk obtained by any optimizer can be bounded below by the risk of the maximum likelihood estimator for the Gaussian, which is $d/(2N)$, where $d$ is the dimension and $N$ is the total number of training examples visited. We indicate this bound with a dashed black line in our plots.
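The $d/(2N)$ figure follows from a one-line calculation: the sample mean of $N$ Gaussian examples with covariance $\mathbf{H}^{-1}$ has covariance $\mathbf{H}^{-1}/N$, so its expected risk is $\frac{1}{2}\operatorname{tr}(\mathbf{H}\,\mathbf{H}^{-1})/N = d/(2N)$. The sketch below evaluates that sum coordinate by coordinate and turns the bound into a minimum number of steps at batch size $B$ (since $N = tB$). The function names, the spectrum, and the usage numbers are illustrative assumptions.

```python
import math

# Illustrative helpers for the information theoretic lower bound.


def mle_risk(eigs, n_examples):
    """Expected risk 0.5 * E[(mu_hat - mu)^T H (mu_hat - mu)] of the Gaussian sample mean.

    Coordinate i contributes 0.5 * h_i * (1 / h_i) / N, so the total is d / (2N).
    """
    return sum(0.5 * h * (1.0 / h) / n_examples for h in eigs)


def min_steps(d, batch, target):
    """Fewest steps t for which d / (2 * t * batch) <= target."""
    return math.ceil(d / (2.0 * batch * target))


print(mle_risk([1.0 / i for i in range(1, 101)], 50))  # d = 100, N = 50
print(min_steps(10_000, 1000, 0.01))
```

Doubling the batch size halves the minimum number of steps, so this bound scales perfectly at every batch size.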
3.7 Noisy Quadratic Experiments
In this section, we simulate noisy quadratic optimization using the closed-form dynamics. Our aim is to formulate hypotheses for how different optimizers would behave for neural network optimization. Our main metric is the number of steps required to achieve a target risk. For efficiency, rather than explicitly representing all the eigenvalues of , we quantize them into 100 bins and count the number of eigenvalues in each bin. Unless otherwise specified, we initialize as and use a target risk of 0.01. (The results don’t seem to be sensitive to either the initial variance or the target risk; some results with varying target risk thresholds are shown in Appendix E.5).
We first experiment with momentum and varying preconditioner powers on our NQM. We treat both the (fixed) learning rate and momentum decay parameter as hyperparameters, which we tune using a fine-grained grid search.
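A rough sketch of this kind of simulation for plain and preconditioned SGD (no momentum). It relies on several assumptions not stated here: $h_i = 1/i$ with $d = 10{,}000$, $\mathbf{C} = \mathbf{H}$, unit initial variance per coordinate, log-spaced bins represented by their mean eigenvalue, and a 40-point log-spaced learning-rate grid. For simplicity it also skips learning rates for which some coordinate's risk would rise over time. All names are hypothetical. Steps to target come from the closed-form risk by doubling and then bisection, which is valid because every retained coordinate's risk decreases monotonically.

```python
import math

# Hypothetical simulation sketch; every modeling choice below is an assumption.


def binned_spectrum(d, num_bins=100):
    """h_i = 1/i grouped into log-spaced bins: list of (mean eigenvalue, count)."""
    eigs = [1.0 / i for i in range(1, d + 1)]
    log_lo = math.log(eigs[-1])
    width = -log_lo / num_bins  # largest eigenvalue is 1, so log_hi = 0
    sums, counts = [0.0] * num_bins, [0] * num_bins
    for e in eigs:
        k = min(int((math.log(e) - log_lo) / width), num_bins - 1)
        sums[k] += e
        counts[k] += 1
    return [(s / n, n) for s, n in zip(sums, counts) if n > 0]


def risk_terms(bins, alpha, batch, p, init_var):
    """Per bin: (a^2, initial risk, steady-state risk) for SGD with P = H^p and C = H."""
    terms = []
    for h, n in bins:
        h_eff = h ** (1.0 - p)
        steady = alpha * h * h ** (-p) / (2.0 * batch * (2.0 - alpha * h_eff))
        terms.append(((1.0 - alpha * h_eff) ** 2, n * 0.5 * h * init_var, n * steady))
    return terms


def total_risk(terms, t):
    return sum(a2 ** t * e0 + (1.0 - a2 ** t) * s for a2, e0, s in terms)


def steps_to_target(bins, alpha, batch, p, target=0.01, init_var=1.0):
    """Smallest t with total risk <= target, or None if unreachable or excluded."""
    terms = risk_terms(bins, alpha, batch, p, init_var)
    if any(a2 >= 1.0 or e0 < s for a2, e0, s in terms):
        return None  # unstable, or some coordinate's risk would increase
    if sum(s for _, _, s in terms) >= target:
        return None  # steady-state risk alone misses the target
    if total_risk(terms, 0) <= target:
        return 0
    hi = 1
    while total_risk(terms, hi) > target:
        hi *= 2
    lo = hi // 2  # risk(lo) > target here
    while hi - lo > 1:
        mid = (lo + hi) // 2
        lo, hi = (lo, mid) if total_risk(terms, mid) <= target else (mid, hi)
    return hi


def best_steps(bins, batch, p, num_alphas=40):
    """Grid search over log-spaced learning rates in [1e-5, 2)."""
    alphas = [1e-5 * (2.0 / 1e-5) ** (k / num_alphas) for k in range(num_alphas)]
    feasible = [s for s in (steps_to_target(bins, a, batch, p) for a in alphas) if s is not None]
    return min(feasible) if feasible else None


bins = binned_spectrum(10_000)
for batch in (1, 16, 256, 4096):
    print(batch, best_steps(bins, batch, p=0.0), best_steps(bins, batch, p=0.5))
```

For a fixed learning rate, a larger batch only lowers the steady-state term. The tuned steps to target therefore never increase with batch size, and the interesting question is where the curve flattens.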
Consistent with the empirical results of Shallue et al. (2018), each optimizer shows two distinct regimes: a small-batch (stochastic) regime with perfect linear scaling, and a large-batch (deterministic) regime insensitive to batch size. We call the phase transition between these regimes the critical batch size. Consistent with the analysis of Section 3.2 and the observations of Smith et al. (2018); Shallue et al. (2018); Kidambi et al. (2018), the performance of momentum-based optimizers matches that of the plain SGD methods in the small-batch regime, but momentum increases the critical batch size and gives substantial speedups in the large batch regime. Preconditioning also increases the critical batch size and gives substantial speedups in the large batch regime, but interestingly, also improves performance by a small constant factor even for very small batches. Combining momentum with preconditioning extends both of these trends.
We next experiment with EMA and varying preconditioning powers on our NQM. Following the same procedure as before, we tune both the learning rate and the averaging coefficient using grid search. As expected, EMA reduces the number of steps required, especially for plain SGD with preconditioning power $p = 0$. Another interesting observation is that EMA becomes redundant in the large batch (near-deterministic) regime, since the main effect of EMA is reducing the steady-state risk, which can also be done by increasing the batch size. This implies that EMA would reduce the critical batch size and therefore achieve the same amount of acceleration with less computation.
3.7.2 Optimal Learning Rate and Decay Scheme
In the NQM, we can calculate the optimal constant learning rate given a specific batch size. Figure 14 shows the optimal learning rate as a function of batch size for a target risk of . Notably, the optimal learning rate of plain (preconditioned) SGD (Figure 14(a)) scales linearly with batch size before it hits the critical batch size, matching the scheme used in Goyal et al. (2017). The linear scaling also holds for the effective learning rate of momentum SGD. In the small batch regime, the optimal effective learning rate for momentum SGD matches the optimal plain SGD learning rate, suggesting that the momentum and learning rate are interchangeable in the small batch regime.
While a fixed learning rate often works well for simple problems, good performance on the ImageNet benchmark (Russakovsky et al., 2015) requires a carefully tuned schedule. Here we explicitly optimize a piecewise constant learning rate schedule for SGD (with 50 pieces), in terms of the number of steps to reach the loss threshold. (For a given schedule and number of time steps, we obtain the exact risk using dynamic programming with eqn. (3). For stability, the learning rates are constrained to lie below a fixed upper bound. For a fixed number of time steps, we minimize this risk using BFGS. We determine the optimal number of time steps using binary search.) In Figure 3(b), we show that optimized learning rate schedules help significantly in the small batch regime, consistent with the analysis in Wu et al. (2018). We observe the same linear scaling as with fixed-learning-rate SGD, but with a better constant factor. In fact, optimized schedules nearly achieve the information theoretic optimum. However, learning rate schedules do not improve at all over fixed learning rates in the large batch regime. Figure 3(c) shows optimized schedules for different batch sizes; interestingly, they maintain a large learning rate throughout training followed by a roughly exponential decay, consistent with commonly used neural network training schedules. Additionally, even though the different batch sizes start with the same learning rate, their final learning rates at the end of training scale linearly with batch size (see Figure 15 in Appendix E.7).
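One plausible reading of the dynamic-programming step (our assumption, not necessarily the authors' exact procedure) is a forward pass over the schedule's pieces. Within each constant piece, the per-bin second moment $\mathbb{E}[\theta_i^2]$ follows eqn. (3) and can be advanced in closed form. The sketch assumes $\mathbf{C} = \mathbf{H}$, unit initial variance, and a small hypothetical set of bins. It checks against a step-by-step loop and confirms that splitting a constant piece does not change the result. The BFGS optimization and the binary search over the number of steps are not shown.

```python
# Hypothetical sketch; bins, schedule, and names are illustrative.


def schedule_risk(bins, schedule, batch, init_var=1.0):
    """Exact expected risk after a piecewise-constant schedule [(learning_rate, num_steps), ...].

    Within a piece with a = 1 - alpha*h, E[theta^2] moves geometrically toward its fixed point:
    m2 <- a^(2n) m2 + (1 - a^(2n)) * alpha^2 c / (B (1 - a^2)), with c = h. Needs 0 < alpha*h < 2.
    """
    total = 0.0
    for h, count in bins:
        m2 = init_var
        for alpha, n in schedule:
            a2 = (1.0 - alpha * h) ** 2
            fixed_point = alpha ** 2 * h / (batch * (1.0 - a2))
            m2 = a2 ** n * m2 + (1.0 - a2 ** n) * fixed_point
        total += count * 0.5 * h * m2
    return total


def schedule_risk_stepwise(bins, schedule, batch, init_var=1.0):
    """Reference implementation: apply eqn. (3) one step at a time."""
    total = 0.0
    for h, count in bins:
        m2 = init_var
        for alpha, n in schedule:
            for _ in range(n):
                m2 = (1.0 - alpha * h) ** 2 * m2 + alpha ** 2 * h / batch
        total += count * 0.5 * h * m2
    return total


bins = [(1.0, 1), (0.1, 3), (0.01, 10)]  # (eigenvalue, multiplicity)
schedule = [(0.5, 30), (0.2, 50), (0.05, 100)]
print(schedule_risk(bins, schedule, batch=4), schedule_risk_stepwise(bins, schedule, batch=4))
```

The closed-form version costs one update per piece and bin rather than one per step, which is what makes evaluating many candidate schedules cheap.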
Neural Network Experiments
We investigated whether the predictions made by the NQM hold in practice by running experiments with five neural network architectures across three image classification tasks and one language modeling task (see Table 1). For each model and task, we compared a range of optimizers: SGD, momentum SGD, Adam (with and without momentum), and K-FAC (with and without momentum). For K-FAC, preconditioning is applied before momentum. See Appendix F for more details.
The primary quantity we measured is the number of steps required to reach a target accuracy (for image classification tasks) or cross entropy (for language modeling). Unless otherwise specified, we measured steps to target on the validation set. We chose the target metric values based on an initial set of experiments with practical computational budgets. For each model, task, optimizer, and batch size, we independently tuned the learning rate, the parameters governing the learning rate schedule (where applicable), and optimizer-specific metaparameters (see Appendix F.4). We manually chose the search spaces based on our initial experiments, and we verified after each experiment that the optimal metaparameter values were far from the search space boundaries. We used quasi-random search (Bousquet et al., 2017) to tune the metaparameters with fixed budgets of non-divergent trials (100 for Simple CNN, ResNet8, and Transformer, and 200 for ResNet32 and VGG11); we discarded trials with a divergent training loss, which occurred when the learning rate was too high. We chose the trial that reached the target metric value in the fewest steps.
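The tuning protocol can be sketched as follows. The quasi-random sequence used in the experiments is not specified here, so this sketch assumes a 2-D Halton sequence (radical inverses in bases 2 and 3). It also assumes log-uniform ranges for the learning rate and for $1 - \beta$. `run_trial` is a hypothetical user-supplied callable, and `toy_trial` is a made-up stand-in, not real training. Divergent trials are discarded and do not count toward the budget, and the trial with the fewest steps to target wins.

```python
import math

# Hypothetical sketch of budgeted quasi-random search; ranges and the Halton choice are assumptions.


def radical_inverse(n, base):
    """Van der Corput radical inverse of the integer n in the given base."""
    inv, scale = 0.0, 1.0 / base
    while n > 0:
        n, digit = divmod(n, base)
        inv += digit * scale
        scale /= base
    return inv


def log_uniform(u, lo, hi):
    return math.exp(math.log(lo) + u * (math.log(hi) - math.log(lo)))


def quasi_random_search(run_trial, budget, max_trials=10_000):
    """run_trial(lr, beta) returns None if training diverged, else steps to target (inf if never)."""
    completed = []
    for index in range(1, max_trials + 1):
        lr = log_uniform(radical_inverse(index, 2), 1e-4, 1.0)
        beta = 1.0 - log_uniform(radical_inverse(index, 3), 1e-3, 1.0)
        steps = run_trial(lr, beta)
        if steps is None:
            continue  # divergent trials are discarded and not counted
        completed.append(((lr, beta), steps))
        if len(completed) == budget:
            break
    best = min(completed, key=lambda trial: trial[1])
    return best, completed


def toy_trial(lr, beta):
    """Made-up stand-in: diverges above an effective learning rate of 10."""
    effective_lr = lr / (1.0 - beta)
    return None if effective_lr > 10.0 else math.ceil(100.0 / effective_lr)


best, completed = quasi_random_search(toy_trial, budget=20)
print(best, len(completed))
```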
Figure 5 shows the relationship between batch size and steps to target for each model, task, and optimizer. In each case, as the batch size grows, there is an initial period of perfect scaling where doubling the batch size halves the steps to target, but once the batch size exceeds a problem-dependent critical batch size, there are rapidly diminishing returns, matching the results of Goyal et al. (2017), McCandlish et al. (2018), and Shallue et al. (2018). K-FAC has the largest critical batch size in all cases, highlighting the usefulness of preconditioning. Momentum SGD extends perfect scaling to larger batch sizes than plain SGD, but for batch sizes smaller than the plain SGD critical batch size, momentum SGD requires as many steps as plain SGD to reach the target. This is consistent with both the empirical results of Shallue et al. (2018) and our NQM simulations. By contrast, Adam and K-FAC can reduce the number of steps needed to reach the target compared to plain SGD even for the smallest batch sizes, although neither optimizer does so in all cases. Finally, we see some evidence that the benefit of momentum diminishes with preconditioning (Figures 5(a) and 5(b)), as predicted by our NQM simulations, although we do not see this in all cases (e.g. Figures 5(c) and 5(f)).
4.2 Exponential Moving Average Improves Convergence with Minimal Computation Cost
To verify the NQM's predictions about exponential moving averages (EMA), we conducted experiments comparing EMA with plain SGD. We follow the same protocol as in Figure 5 and report the results in Figure 6. As expected, the results on real neural networks closely match our predictions based on the NQM analysis. In particular, SGD with EMA appears to reach the same target with fewer steps than plain SGD at small batch sizes, though the benefit of EMA diminishes with large batch sizes. We also note that EMA leads to smaller critical batch sizes and achieves the same acceleration with less computation.
4.3 Optimal Learning Rate
The NQM predicts that the optimal constant learning rate for plain SGD (or effective learning rate for momentum SGD) scales linearly with batch size initially, and then levels off after a certain batch size. Figure 7 shows the empirical optimal (effective) learning rate as a function of batch size for simple CNN on MNIST and ResNet8 on CIFAR10. For small batch sizes, the optimal learning rate of plain SGD appears to match the optimal effective learning rate of momentum SGD. However, after a certain batch size, the optimal learning rate for plain SGD saturates while the optimal effective learning rate of momentum SGD keeps increasing. Interestingly, plain SGD and momentum SGD appear to deviate at the same batch size in the optimal effective learning rate and steps to target plots (Figures 5 and 7).
4.4 Steps to Target on the Training Set
Figure 8 shows the empirical relationship between batch size and steps to target, measured on the training set, for ResNet8 and ResNet32 on CIFAR10. For ResNet8, the curves are almost identical to those using validation accuracy (Figure 5(c)), but for ResNet32, the gaps between different optimizers become much smaller than in Figure 5(e) and the effects of momentum and preconditioning appear to become less significant. Nevertheless, the qualitative differences between optimizers are consistent with the validation set measurements.
Conclusion
In this work, we analyzed the interactions between the batch size and the optimization algorithm from two perspectives: experiments with real neural networks, and a noisy quadratic model with parameters chosen based on empirical observations about neural networks. Despite its simplicity, the noisy quadratic model agrees remarkably well with a variety of neural network training phenomena, including learning rate scaling, critical batch sizes, and the effects of momentum, preconditioning and averaging. More importantly, the noisy quadratic model allows us to run experiments in seconds, while it can take weeks, or even months, to conduct careful large-scale experiments with real neural networks. Therefore, the noisy quadratic model is a convenient and powerful way to quickly formulate testable predictions about neural network optimization.
References
Appendix A Kronecker-factored Approximate Curvature (K-FAC)
As shown by eqn. (14), computing the natural gradient using K-FAC involves only operations on matrices comparable in size to the layer's weight matrix, making it very efficient.
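The efficiency comes from a Kronecker identity. This sketch assumes the standard K-FAC form $F \approx A \otimes G$ for a layer $s = Wa$, where $A$ is the second moment of the layer inputs and $G$ the second moment of the output gradients (Martens and Grosse, 2015), together with column-stacking vectorization. Under those assumptions, $(A \otimes G)^{-1}\operatorname{vec}(V) = \operatorname{vec}(G^{-1} V A^{-1})$, so the large Kronecker product never has to be formed or inverted. The 2×2 matrices are made up for illustration, and the helper names are hypothetical.

```python
# Hypothetical pure-Python check of the Kronecker identity behind K-FAC.


def matmul(x, y):
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def kron(a, g):
    """Kronecker product a ⊗ g of square matrices given as lists of rows."""
    n, m = len(a), len(g)
    return [[a[i // m][j // m] * g[i % m][j % m] for j in range(n * m)] for i in range(n * m)]


def vec(w):
    """Column-stacking vectorization."""
    return [w[i][j] for j in range(len(w[0])) for i in range(len(w))]


def inv2(x):
    det = x[0][0] * x[1][1] - x[0][1] * x[1][0]
    return [[x[1][1] / det, -x[0][1] / det], [-x[1][0] / det, x[0][0] / det]]


def kfac_natural_gradient(grad_w, a_factor, g_factor):
    """(A ⊗ G)^{-1} vec(grad_w), computed as G^{-1} grad_w A^{-1} without forming A ⊗ G."""
    return matmul(matmul(inv2(g_factor), grad_w), inv2(a_factor))


A = [[2.0, 0.5], [0.5, 1.0]]   # input second moments (symmetric positive definite)
G = [[1.5, 0.2], [0.2, 0.8]]   # output-gradient second moments
grad = [[1.0, 2.0], [3.0, 4.0]]
nat = kfac_natural_gradient(grad, A, G)
# Multiplying back by the explicit Kronecker product should recover vec(grad).
print([sum(r * v for r, v in zip(row, vec(nat))) for row in kron(A, G)], vec(grad))
```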
K-FAC has previously been applied to autoencoders [Martens and Grosse, 2015] and various convolutional networks [Grosse and Martens, 2016, Ba et al., 2017]. To our knowledge, this is the first time K-FAC has been applied to the Transformer model. What differs from the previous models is the weight matrix shared between the embedding layer and the pre-softmax linear transformation [Vaswani et al., 2017]. In particular, the weight matrix is used in transposed form at the pre-softmax layer. With the same assumptions as in the non-transposed case, we get
i.e. the positions of the two Kronecker factors are swapped. If we name the two Kronecker factors the "input factor" and the "output factor", respectively, then for the weight matrix shared between the embedding layer and the pre-softmax layer, the input_factor receives contributions from both the embedding inputs and the gradients of the pre-softmax layer outputs, and the output_factor receives contributions from both the pre-softmax layer inputs and the gradients of the embedding outputs. In practice, when computing a Kronecker factor, we treat contributions from multiple sources in the same way as contributions from multiple training examples in a mini-batch. Also note that because of the high dimensionality of the embedding weight matrix (with a vocabulary size of 32,768), the dense input factor would have size 32,768 × 32,768. In order to save memory, we use a diagonal matrix to estimate the input_factor. The output_factor is still estimated with a dense matrix.
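A minimal sketch of how the diagonal input_factor for the shared embedding could be accumulated. Two assumptions are involved. First, each one-hot embedding input contributes $e_k e_k^\top$ and each pre-softmax output-gradient vector $g$ contributes $g g^\top$. Second, following the text, all contributions are averaged together as if they were extra mini-batch examples. Only the diagonal is kept. The tiny vocabulary in the usage, the float32 storage size, and all names are illustrative.

```python
# Hypothetical sketch; normalization and storage precision are assumptions.


def diagonal_input_factor(token_ids, presoftmax_grads, vocab_size):
    """Diagonal estimate of the shared embedding's input_factor (vocab_size numbers)."""
    diag = [0.0] * vocab_size
    for k in token_ids:  # one-hot input e_k adds 1 at position k
        diag[k] += 1.0
    for g in presoftmax_grads:  # gradient g adds g_j^2 at position j
        for j, value in enumerate(g):
            diag[j] += value * value
    num_contributions = len(token_ids) + len(presoftmax_grads)
    return [x / num_contributions for x in diag]


def factor_bytes(vocab_size, dense, bytes_per_entry=4):
    """Memory for a dense vs diagonal input factor, assuming float32 entries."""
    return bytes_per_entry * (vocab_size * vocab_size if dense else vocab_size)


print(diagonal_input_factor([0, 2, 2], [[0.5, 0.0, -0.5, 0.0, 0.0]], vocab_size=5))
print(factor_bytes(32_768, dense=True), factor_bytes(32_768, dense=False))
```

Under these assumptions the dense factor would need 4 GiB, while the diagonal needs 128 KiB.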
Appendix B Dynamics of momentum SGD on noisy quadratic model
The convergence rate is determined by the transition matrix, which has the characteristic polynomial
$$\lambda^2 - (1 - \alpha h + \beta)\,\lambda + \beta.$$
With the momentum value $\beta = (1 - \sqrt{\alpha h})^2$, the two eigenvalues of the transition matrix coincide, both equal to $1 - \sqrt{\alpha h}$, giving the fastest convergence.
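The critical-damping claim can be checked numerically. The sketch scans $\beta$ and verifies two things: the spectral radius of the momentum transition matrix (the larger root modulus of the characteristic polynomial above) is smallest at $\beta = (1 - \sqrt{\alpha h})^2$, and there it equals $1 - \sqrt{\alpha h}$. The choice $\alpha = h = 0.1$ and the helper name are illustrative assumptions.

```python
import cmath
import math


def spectral_radius(h, alpha, beta):
    """Largest |root| of lambda^2 - (1 - alpha*h + beta) lambda + beta."""
    b = 1.0 - alpha * h + beta
    disc = cmath.sqrt(b * b - 4.0 * beta)
    return max(abs((b + disc) / 2.0), abs((b - disc) / 2.0))


h, alpha = 0.1, 0.1                              # alpha * h = 0.01
beta_star = (1.0 - math.sqrt(alpha * h)) ** 2    # critically damped momentum
print(beta_star, spectral_radius(h, alpha, beta_star), 1.0 - math.sqrt(alpha * h))
print(min((spectral_radius(h, alpha, k / 1000), k / 1000) for k in range(1000)))
```

Below $\beta^\ast$ the system is overdamped and the leading real root is larger. Above it, the complex roots have modulus $\sqrt{\beta} > \sqrt{\beta^\ast}$.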
Appendix C Proof of Theorem 1
For a linear dynamical system like eqn. (21), we can get in the following form:
We first analyze the stochastic term . For notational convenience, we define
In eqn. (24), we append zero vector for convenience. To compute the infinite sum, we first focus on a single term. We have the following update:
Since only the parameter iterate determines the loss, we eliminate the momentum variable by merging the two updates, which yields a second-order difference equation:
with initial conditions and . To solve the second-order difference equation, we leverage the Z-transform to get the analytical form. Based on basic manipulation of the Z-transform, we have the Z-domain function
where and are two roots of equation . Then, we use the inverse Z-transform to get :
Now, we are ready to compute the infinite sum :
Because and are two roots with , , we have
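The Z-transform step can be sanity-checked numerically. The sketch assumes the generic form $x_{t+2} = (r_1 + r_2)x_{t+1} - r_1 r_2 x_t$ with $x_0 = 1$ and $x_1 = r_1 + r_2$, whose solution is $x_t = (r_1^{t+1} - r_2^{t+1})/(r_1 - r_2)$. It then compares the infinite sum of squares with the closed form $\frac{1 + r_1 r_2}{(1 - r_1 r_2)(1 - r_1^2)(1 - r_2^2)}$, valid for $|r_1|, |r_2| < 1$. The closed form follows from the standard stationary-variance formula for a second-order autoregression and is our addition. The roots used below, one real pair and one complex-conjugate pair, are illustrative.

```python
# Hypothetical numerical check of the second-order difference equation solution.


def impulse_closed_form(r1, r2, t):
    return (r1 ** (t + 1) - r2 ** (t + 1)) / (r1 - r2)


def impulse_by_recursion(r1, r2, steps):
    """x_{t+2} = (r1 + r2) x_{t+1} - r1 r2 x_t with x_0 = 1, x_1 = r1 + r2."""
    xs = [1.0, r1 + r2]
    while len(xs) < steps:
        xs.append((r1 + r2) * xs[-1] - r1 * r2 * xs[-2])
    return xs[:steps]


def sum_of_squares(r1, r2):
    """Closed form of sum_{t>=0} x_t^2 for |r1| < 1 and |r2| < 1."""
    p = r1 * r2
    return (1 + p) / ((1 - p) * (1 - r1 * r1) * (1 - r2 * r2))


for r1, r2 in [(0.9, 0.5), (0.6 + 0.3j, 0.6 - 0.3j)]:
    xs = impulse_by_recursion(r1, r2, 3000)
    print(sum((x * x).real for x in xs), complex(sum_of_squares(r1, r2)).real)
```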
Now, we analyze the deterministic term. Similar to the analysis of stochastic term, we have the same second-order difference equation
except the initial conditions become . According to Z-transform, we have
Appendix D Proof of Theorem 2
For such a linear dynamical system, we can easily get the in the following form:
Now, to get the closed-form of , we first analyze the second term which involves the infinite sum. For notational convenience, we introduce the following notations:
In eqn. (40), we append zero vector for convenience. To compute the infinite sum, we first focus on a single term. We have the following update:
Since only the averaged iterate determines the loss, we eliminate the raw SGD iterate by merging the two updates, which yields a second-order difference equation:
with initial conditions and . To solve the second-order difference equation, we leverage the Z-transform to get the analytical form. Based on basic manipulation of the Z-transform, we have the Z-domain function
where and are two roots of equation . Then, we use the inverse Z-transform to get :
Now, we are ready to compute the infinite sum :
It is easy to see that and , we then plug them back into eqn. (46) and get
For the other term , we can reuse the same second-order difference equation (42) except with initial conditions . According to Z-transform, we have
Therefore, we have the following upper bound:
Appendix E More results on the NQM
The main objective of this section is to examine the loss surface of modern neural networks at different stages of training in order to justify the assumptions made in the NQM. However, it is hard to visualize such a high-dimensional space. Following recent work [Sagun et al., 2016, Ghorbani et al., 2019], we instead focus on analyzing the eigenspectrum of the Hessian/Fisher matrices. The eigenvalues of the Hessian/Fisher of the training loss (with respect to the parameters) characterize the local curvature of the loss surface, which determines many training behaviors of neural networks, including the convergence rates of first-order optimization methods (at least for convex problems).
It has been noted that the true Fisher matrix is equivalent to the generalized Gauss-Newton Hessian matrix [Martens, 2014], so we take it as a proxy of the Hessian. To construct the eigenspectrum of the true Fisher matrix, we first leverage the Kronecker-factored approximation of the Fisher to get an estimation of the eigenspectrum, which may shed light upon the true eigenspectrum. Specifically, we train the network with K-FAC and then perform eigen-decomposition on saved Kronecker factors of the Fisher to calculate the eigenvalues.
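Reading eigenvalues off the saved Kronecker factors relies on a standard identity: the eigenvalues of a Kronecker product A ⊗ G are all pairwise products of the eigenvalues of A and G. The sketch below assumes one layer's Fisher block is approximated as A ⊗ G, with A the second moment of the layer inputs and G that of the back-propagated pre-activation gradients. These are the usual K-FAC factors. The code checks the identity against an explicit Kronecker product in NumPy. Because K-FAC is block-diagonal across layers, an approximate full spectrum is the concatenation of the per-layer results. The helper names are hypothetical.

```python
import numpy as np


def kfac_block_eigenvalues(A, G):
    """Eigenvalues (ascending) of the Kronecker-factored Fisher block kron(A, G).

    A: (d_in, d_in) symmetric PSD input-activation factor.
    G: (d_out, d_out) symmetric PSD pre-activation-gradient factor.
    """
    eig_a = np.linalg.eigvalsh(A)                     # decompose each small factor separately
    eig_g = np.linalg.eigvalsh(G)
    return np.sort(np.outer(eig_a, eig_g).ravel())    # all pairwise products


def kfac_spectrum(factors):
    """Approximate spectrum of a block-diagonal K-FAC Fisher from per-layer (A, G) pairs."""
    return np.sort(np.concatenate([kfac_block_eigenvalues(A, G) for A, G in factors]))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X, Y = rng.standard_normal((4, 4)), rng.standard_normal((3, 3))
    A, G = X @ X.T, Y @ Y.T                           # random PSD stand-ins for saved factors
    print(np.allclose(kfac_block_eigenvalues(A, G), np.linalg.eigvalsh(np.kron(A, G))))
```

The point of the factorization is cost. Decomposing a d_in × d_in and a d_out × d_out matrix is far cheaper than decomposing the (d_in·d_out) × (d_in·d_out) block directly.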
The eigenspectra are plotted in Figure 9. One interesting observation is that there are only a few large eigenvalues and a few small eigenvalues in the approximate Fisher matrices; the bulk of eigenvalues are in the middle of the spectrum. We also note that after 200 iterations of training the eigenspectrum remains mostly unchanged.
E.2 Gradient Covariance in the Kronecker-Factored Eigenbasis
E.3 Plots for the Evolution of the First Term in Eqn. (6)
In Section 3.2, we claim that the convergence of momentum SGD for a single dimension is very close to that of plain SGD with an adjusted learning rate (note that Figure 2 already verified that the steady-state risk of momentum SGD matches that of plain SGD using the effective learning rate). Here we verify this claim by comparing the two update rules in the NQM. The total risk consists of two terms (eqn. (6)): the first term determines convergence, while the second term (the steady-state risk) stays constant throughout training. Since the second term is constant, we only plot the first term of eqn. (6) in Figure 11, with values normalized. We observe that the convergence dynamics of the two update rules closely match each other. For this experiment we set , but the results are not sensitive to this value.
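The comparison can be sketched for the first (noiseless) term alone. The sketch assumes heavy-ball momentum of the form m_{t+1} = β m_t + h θ_t, θ_{t+1} = θ_t − α m_{t+1} with m_0 = 0, and gives plain SGD the effective learning rate α/(1 − β). The parameter values are illustrative and are not those used for Figure 11. For a low-curvature direction, where αh is small compared with (1 − β)², the dominant root of the momentum recurrence is approximately 1 − αh/(1 − β). This is why the two normalized curves track each other. The function names are hypothetical.

```python
def normalized_first_term_momentum(alpha, beta, h, steps):
    """theta_t^2 / theta_0^2 for noiseless heavy-ball momentum on a 1-D quadratic (m_0 = 0)."""
    a = 1.0 + beta - alpha * h
    prev, cur = 1.0, 1.0 - alpha * h         # theta_0 and theta_1, normalized by theta_0
    out = [prev ** 2, cur ** 2]
    for _ in range(steps - 1):
        prev, cur = cur, a * cur - beta * prev
        out.append(cur ** 2)
    return out[: steps + 1]


def normalized_first_term_sgd(lr, h, steps):
    """theta_t^2 / theta_0^2 for noiseless plain SGD: (1 - lr*h)^(2t)."""
    return [(1.0 - lr * h) ** (2 * t) for t in range(steps + 1)]


if __name__ == "__main__":
    alpha, beta, h = 1e-3, 0.5, 1.0          # illustrative low-curvature setting
    mom = normalized_first_term_momentum(alpha, beta, h, 2000)
    sgd = normalized_first_term_sgd(alpha / (1.0 - beta), h, 2000)
    print(max(abs(m / s - 1.0) for m, s in zip(mom, sgd)))   # largest relative gap
```

With these illustrative values, the two normalized curves stay within a few percent of each other over the first 2,000 steps. When αh approaches (1 − β)², the effective-learning-rate approximation degrades.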
E.4 Verification of Eigenspectrum
In Section 3.7, we assume the diagonal entries of are . To justify this choice, we compare the K-FAC eigenspectra of ResNet8 to this distribution in Figure 12. The distribution of eigenvalues we chose for in the NQM very closely matches the eigenspectra of the real neural network, validating the assumption made in Section 3.5 that the diagonal entries of are .
E.5 Effect of Loss Threshold
Recall that a main objective of this work is to characterize the effect of increasing the batch size on training time, as measured by the number of steps needed to reach a target error/loss. Here we experiment with different loss thresholds to study the relationship between batch size and the number of training steps. To obtain the minimal number of training steps for a given batch size, we perform a grid search over constant learning rates. Figure 13 shows that increasing the batch size initially decreases the required number of training steps proportionally, but eventually yields diminishing returns, consistent with empirical findings [Golmant et al., 2018, Shallue et al., 2018]. The curves have the same characteristic shape for different loss thresholds, though the critical batch size appears to increase for more difficult thresholds.
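A minimal version of this procedure can be sketched on a toy NQM. Everything specific here is an assumption for illustration:

- plain SGD rather than the optimizers in Figure 13;
- a stochastic gradient in dimension i equal to h_i θ_i plus independent noise of variance c_i/B;
- curvatures h_i = 1/i over 100 dimensions, with c_i = h_i and E[θ_i(0)²] = 1;
- a loss threshold of 0.5;
- a log-spaced learning rate grid.

Under these assumptions, E[θ_{t+1}²] = (1 − αh_i)² E[θ_t²] + α² c_i/B in each dimension. This recursion has the closed form used below. The function names are hypothetical.

```python
import numpy as np


def expected_loss_curve(lr, batch_size, h, c, theta0_sq, max_steps):
    """Expected NQM loss 0.5 * sum_i h_i * E[theta_i(t)^2] for plain SGD, t = 0..max_steps."""
    rho = (1.0 - lr * h) ** 2                          # per-dimension contraction of E[theta^2]
    steady = lr ** 2 * c / (batch_size * (1.0 - rho))  # stationary E[theta_i^2]
    t = np.arange(max_steps + 1)[:, None]              # shape (T, 1), broadcasts over dimensions
    second_moment = rho ** t * (theta0_sq - steady) + steady
    return 0.5 * (second_moment * h).sum(axis=1)


def steps_to_threshold(losses, threshold):
    """First step at which the expected loss is at or below threshold, or None."""
    below = losses <= threshold
    return int(np.argmax(below)) if below.any() else None


def min_steps(batch_size, lrs, h, c, theta0_sq, threshold, max_steps):
    """Grid search over constant learning rates for the fewest steps to reach threshold."""
    best = None
    for lr in lrs:
        if lr <= 0 or np.any(lr * h >= 2.0):           # outside the stable range for SGD
            continue
        steps = steps_to_threshold(
            expected_loss_curve(lr, batch_size, h, c, theta0_sq, max_steps), threshold)
        if steps is not None and (best is None or steps < best):
            best = steps
    return best


if __name__ == "__main__":
    d = 100
    h = 1.0 / np.arange(1, d + 1)      # hypothetical curvatures h_i = 1/i
    c = h.copy()                       # hypothetical noise variances c_i = h_i
    theta0_sq = np.ones(d)             # hypothetical E[theta_i(0)^2] = 1
    lrs = np.geomspace(1e-3, 1.9, 30)
    for batch_size in [1, 4, 16, 64, 256, 1024]:
        print(batch_size, min_steps(batch_size, lrs, h, c, theta0_sq, 0.5, 5000))
```

For a fixed learning rate, the expected loss at every step decreases with B, since only the noise term depends on it. As a result, the minimal step count in this sketch can never increase with batch size. Printing it for increasing batch sizes is a quick way to look for the perfect-scaling and diminishing-returns regimes described above.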
E.6 Results of Optimal Learning Rate on NQM
E.7 Final Learning Rate of Different Batch Sizes for PWC Learning Rate Scheme
In Section 3.7.2, we study the piecewise constant learning rate scheme. The optimal scheme starts with a high learning rate which drops later in training (Figure 3(c)). Recall that for fixed learning rates, we observed that the optimal learning rate scaled linearly with the batch size for small batch sizes, but it is unclear whether there is a similar phenomenon for learning rate decay. In Figure 15, we plot the final learning rate as a function of batch size and show that it also scales linearly with batch size.
Appendix F More Details for Experiments
The data sets in Table 1 (MNIST, Fashion MNIST, CIFAR10, ImageNet and LM1B) are identical to those of Shallue et al. (described in their Appendix A.1). For CIFAR10, we used data augmentation (including horizontal flips and random crops), whereas Shallue et al. did not.
F.2 Model Details
This section provides details of the models in Table 1. The models are very similar (and some identical) to those used in Shallue et al. (described in their Appendix B); any modifications are highlighted in this section.
Simple CNN consists of 2 convolutional layers with max-pooling followed by 1 fully connected hidden layer. The convolutional layers use 5×5 filters with stride 1, “same” padding [Goodfellow et al., 2016], and ReLU activations. Max pooling uses 2×2 windows with stride 2. Unlike Shallue et al., we did not use any dropout regularization (they used dropout with probability 0.4 in the fully connected layer). We used 32 and 64 filters in the convolutional layers and 1,024 units in the fully connected layer. This corresponds to the “base” configuration in Shallue et al.
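As a worked check of the layer sizes, the sketch below counts the Simple CNN's parameters under assumptions not stated in the text:

- 28×28 grayscale inputs, as in MNIST and Fashion MNIST;
- biases in every layer;
- pooling that halves each spatial dimension;
- a 10-way linear output layer after the 1,024-unit hidden layer.

With stride-1 “same” convolutions, only pooling changes the spatial size, which goes 28 → 14 → 7. The hidden layer therefore sees 7 × 7 × 64 = 3,136 inputs. The function names are hypothetical.

```python
def conv_params(kernel, c_in, c_out):
    """Weights plus biases of a kernel x kernel convolution."""
    return kernel * kernel * c_in * c_out + c_out


def simple_cnn_param_count(height=28, width=28, channels=1, num_classes=10):
    """Hypothetical parameter count for the Simple CNN described above."""
    total = conv_params(5, channels, 32) + conv_params(5, 32, 64)
    # "same" stride-1 convs keep the spatial size; each 2x2/stride-2 pool halves it.
    h, w = height // 2 // 2, width // 2 // 2
    flat = h * w * 64                             # 7 * 7 * 64 = 3,136 for 28x28 inputs
    total += flat * 1024 + 1024                   # fully connected hidden layer
    total += 1024 * num_classes + num_classes     # assumed linear output (logits) layer
    return total


print(simple_cnn_param_count())  # 3274634
```

Almost all of these parameters (3,212,288) sit in the fully connected hidden layer.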
ResNet8 [He et al., 2016] consists of 7 convolutional layers with residual connections followed by 1 fully connected hidden layer. We used the same architecture as Shallue et al.; in particular, we did not use batch normalization. The only difference is that we used data augmentation in our experiments.
ResNet32 [He et al., 2016] consists of 31 convolutional layers with residual connections followed by 1 fully connected hidden layer (see Section 4.2 of He et al.). We replaced batch normalization [Ioffe and Szegedy, 2015] with ghost batch normalization to keep the training objective fixed between batch sizes and to avoid possible negative effects from computing batch normalization statistics over a large number of examples [Hoffer et al., 2017]. We used a ghost batch size of 32 for all experiments. We also applied label smoothing [Szegedy et al., 2016] to regularize the model during training, which was helpful for larger batch sizes. We set the label smoothing parameter to 0.1 in all experiments. Instead of weight decay, we applied channel-wise weight normalization, constraining the Frobenius norm of each convolutional channel to be exactly 1, which controls the effective learning rate [Zhang et al., 2019b, van Laarhoven, 2017].
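The three training modifications above can be sketched in NumPy. These are textbook formulations, not the code used for our experiments, and the function names are hypothetical:

- Ghost batch normalization is shown for a (batch, features) matrix, without its learned scale and shift or inference-time moving averages.
- Label smoothing uses the common form (1 − s)·one-hot + s/K.
- The weight normalization assumes a (kh, kw, c_in, c_out) kernel layout and treats each output channel as a “channel”. It is written as a projection applied, for example, after each optimizer step, which is also an assumption.

```python
import numpy as np


def ghost_batch_norm(x, ghost_size=32, eps=1e-5):
    """Normalize each feature using statistics from consecutive sub-batches of ghost_size rows.

    x has shape (batch, features); the learned scale/shift and inference-time
    moving averages of full batch normalization are omitted in this sketch.
    """
    n, f = x.shape
    if n % ghost_size != 0:
        raise ValueError("batch size must be a multiple of ghost_size")
    groups = x.reshape(n // ghost_size, ghost_size, f)
    mean = groups.mean(axis=1, keepdims=True)       # per-ghost-batch statistics
    var = groups.var(axis=1, keepdims=True)
    return ((groups - mean) / np.sqrt(var + eps)).reshape(n, f)


def smooth_labels(labels, num_classes, smoothing=0.1):
    """Targets (1 - smoothing) * one_hot + smoothing / num_classes."""
    one_hot = np.eye(num_classes)[labels]
    return (1.0 - smoothing) * one_hot + smoothing / num_classes


def project_channels_to_unit_norm(kernel):
    """Rescale each output channel of a (kh, kw, c_in, c_out) kernel to Frobenius norm 1."""
    norms = np.sqrt(np.sum(kernel ** 2, axis=(0, 1, 2), keepdims=True))
    return kernel / norms
```

Because ghost statistics are always computed over 32 examples, the normalization seen by each example does not change when the overall batch size grows. This is the property that keeps the training objective fixed across batch sizes.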
VGG11 [Simonyan and Zisserman, 2015] consists of 8 convolutional layers followed by 1 fully connected hidden layer. As in ResNet32, we used ghost batch normalization, label smoothing, and channel-wise weight normalization.
Transformer [Vaswani et al., 2017] is a self-attention model. We used the “base” Transformer model described in Vaswani et al., except with only two hidden layers instead of six. This is identical to the “Transformer Shallow” model in Shallue et al.
F.3 Learning Rate Schedules
This section describes the two learning rate schedules mentioned in Table 1: the constant schedule and the linear decay schedule. The constant schedule simply keeps a fixed learning rate throughout training:
where is the training step index. The linear decay schedule is
where is the initial learning rate, is the rate of decay, and is the number of steps taken to reach the final learning rate. Shallue et al. experimented with various learning rate schedules and found that linear decay matched the performance of the other schedules with fewer hyperparameters to tune. Therefore, we also chose the linear decay schedule, for which we tuned , and .
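A minimal sketch of the two schedules follows. It assumes that the “rate of decay” is the ratio between the final and initial learning rates, that the learning rate falls linearly to its final value over the given number of steps, and that it stays constant afterwards. The function and argument names are hypothetical.

```python
def constant_lr(step, initial_lr):
    """Constant schedule: the same learning rate at every step."""
    return initial_lr


def linear_decay_lr(step, initial_lr, decay_factor, decay_steps):
    """Hypothetical linear decay from initial_lr to decay_factor * initial_lr over decay_steps."""
    if step >= decay_steps:
        return decay_factor * initial_lr     # hold the final learning rate
    fraction = step / decay_steps
    return initial_lr * (1.0 - (1.0 - decay_factor) * fraction)
```

For example, with an initial learning rate of 0.1, a decay factor of 0.01 and 1,000 decay steps, the rate is 0.0505 at step 500 and 0.001 from step 1,000 onward. This schedule has three tunable quantities, matching the three hyperparameters tuned for it above.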
F.4 Optimizer-Specific Hyperparameters
For momentum SGD, we tuned the momentum . For Adam, we tuned , , and (see Kingma and Ba). For K-FAC, we tuned the damping and the trust region constraint (also known as the KL clipping term) for Transformer, keeping the momentum and the moving average parameter for damping ; for all other models, we tuned all four parameters (see Martens and Grosse).