The Power of Interpolation: Understanding the Effectiveness of SGD in Modern Over-parametrized Learning

Siyuan Ma, Raef Bassily, Mikhail Belkin

Introduction

In recent years, Stochastic Gradient Descent (SGD) with a small mini-batch size has become the backbone of machine learning, used in nearly all large-scale applications of machine learning methods, notably in conjunction with deep neural networks. Mini-batch SGD is a first-order method which, instead of computing the full gradient of $\mathcal{L}(\boldsymbol{w})$, computes the gradient with respect to a subset of the data points, often chosen sequentially. In practice, small mini-batch SGD consistently outperforms full gradient descent (GD) by a large factor in terms of the computation required to achieve a given accuracy. However, the theoretical evidence has been mixed. While SGD needs less computation per iteration, most analyses suggest that it requires adaptive step sizes and converges far more slowly than GD, making comparisons of computational efficiency difficult.

In this paper, we explain the reasons for the effectiveness of SGD by taking a different perspective. We note that most of modern machine learning, especially deep learning, relies on classifiers which are trained to achieve near-zero classification and regression losses on the training data. Indeed, practitioners explicitly state the goal of achieving a near-perfect fit on the training set as a best practice in supervised learning (potentially using regularization at a later stage, after a near-perfect fit is achieved); see, e.g., the tutorial [Sal17]. The ability to achieve near-zero loss is provided by over-parametrization. The number of parameters for most deep architectures is very large and often far exceeds the size of the datasets used for training (see, e.g., [CPC16] for a summary of different architectures). There is significant theoretical and empirical evidence that in such over-parametrized systems most or all local minima are also global and hence correspond to the regime where the output of the learning algorithm matches the labels exactly [GAGN15, CCSL16, ZBH+16, HLWvdM16, SEG+17, BFT17]. Since continuous loss functions are typically used for training, the resulting function interpolates the data, i.e., $f_{\boldsymbol{w}^{*}}(\boldsymbol{x}_{i})\approx y_{i}$. (Most of these architectures should be able to achieve perfect interpolation, $f_{\boldsymbol{w}^{*}}(\boldsymbol{x}_{i})=y_{i}$; in practice, of course, this is not possible even for linear systems due to computational and numerical limitations.)

While we do not yet understand why these interpolated classifiers generalize so well to unseen data, there is ample empirical evidence for their excellent generalization performance in deep neural networks [GAGN15, CCSL16, ZBH+16, HLWvdM16, SEG+17], kernel machines [belkin2018understand] and boosting [SFBL98]. In this paper we look at the significant computational implications of this startling phenomenon for stochastic gradient descent.

Our first key observation is that in the interpolated regime, SGD with a fixed step size converges exponentially fast for convex loss functions. Results showing exponential convergence of SGD when the optimal solution minimizes the loss function at each point go back to the Kaczmarz method [Kac37] for quadratic functions, more recently analyzed in [SV09]. For the general convex case, such convergence was first proved in [MB11], and the rate was later improved in [NWS14]. However, to the best of our knowledge, exponential convergence in that regime has not been connected to over-parametrization and interpolation in modern machine learning. Still, exponential convergence by itself does not allow us to compare the computational efficiency of SGD with different mini-batch sizes and of full gradient descent, as the existing results do not depend on the mini-batch size $m$. This dependence is crucial for understanding SGD, as small mini-batch SGD seems to dramatically outperform full gradient descent in nearly all applications. Motivated by this, in this paper we provide an explanation for the empirically observed efficiency of small mini-batch SGD. We give a detailed analysis of the rates of convergence and computational efficiency for different mini-batch sizes and discuss its implications in the context of modern machine learning.

We first analyze convergence of mini-batch SGD for convex loss functions as a function of the batch size $m$. We show that there is a critical mini-batch size $m^{*}$, nearly independent of $n$, such that the following holds:

(linear scaling) One SGD iteration with a mini-batch of size $m\leq m^{*}$ is equivalent to $m$ iterations with a mini-batch of size one, up to a multiplicative constant close to $1$.

(saturation) One SGD iteration with a mini-batch of size $m>m^{*}$ is nearly (up to a small constant) as effective as one iteration of full gradient descent.

We see that the critical mini-batch size $m^{*}$ can be viewed as the limit for the effective parallelization of mini-batch computations. If an iteration with a mini-batch of size $m\leq m^{*}$ can be computed in parallel, it is nearly equivalent to $m$ sequential steps with a mini-batch of size $1$. For $m>m^{*}$, parallel computation has limited added value.

Next, for the quadratic loss function, we obtain a sharp characterization of these regimes based on an explicit derivation of the optimal step size as a function of $m$. In particular, in this case we show that the critical mini-batch size is given by $m^{*}=\frac{\max_{i=1}^{n}\{\lVert\boldsymbol{x}_{i}\rVert^{2}\}}{\lambda_{1}(H)}$, where $H$ is the Hessian at the minimizer and $\lambda_{1}$ is its spectral norm.
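As an illustration, this ratio can be computed directly from a data matrix. The sketch below is a minimal NumPy example; it assumes the square loss is normalized so that $H=\frac{1}{n}\sum_{i}\boldsymbol{x}_{i}\boldsymbol{x}_{i}^{\top}$, and the function name `critical_batch_size` is hypothetical.

```python
import numpy as np


def critical_batch_size(X):
    """Return m* = max_i ||x_i||^2 / lambda_1(H) for the rows x_i of X.

    Assumes H = (1/n) X^T X, the Hessian of (1/(2n)) * sum_i (<w, x_i> - y_i)^2.
    """
    n = X.shape[0]
    beta = np.max(np.sum(X ** 2, axis=1))   # max_i ||x_i||^2
    H = X.T @ X / n                         # Hessian of the quadratic loss
    lambda_1 = np.linalg.eigvalsh(H)[-1]    # largest eigenvalue = spectral norm of H
    return beta / lambda_1


rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 50))
print(critical_batch_size(X))
```

Note that $1\le m^{*}\le n$ always holds under this normalization: $\lambda_1\le\operatorname{tr}(H)\le\beta$, and $H\succeq\frac{1}{n}\boldsymbol{x}_i\boldsymbol{x}_i^{\top}$ gives $\lambda_1\ge\beta/n$.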

Our result shows that $m^{*}$ is nearly independent of the data size $n$ (depending only on the properties of the Hessian). Thus SGD with mini-batch size $m^{*}$ (typically a small constant) gives essentially the same convergence per iteration as full gradient descent, implying an acceleration by a factor of $O(n)$ over GD per unit of computation.

We also show that a mini-batch of size one is optimal in terms of computations required to achieve a given error. Our theoretical results are based on upper bounds which we show to be tight in the quadratic case and nearly tight in the general convex case.

There has been work on understanding the interplay between the mini-batch size and computational efficiency in the standard non-interpolated regime, including [TBRS13, LZCS14, YPL+18]. However, in that setting the gap between the exponential convergence of full GD and the much slower convergence rates of mini-batch SGD is harder to bridge, requiring extra components such as tail averaging [JKK+16] (for quadratic loss).

We provide experimental evidence corroborating this on real data. In particular, we demonstrate the regimes of linear scaling and saturation and also show that on real data $m^{*}$ is in line with our estimate. It is typically several orders of magnitude smaller than the data size $n$, implying a computational advantage of at least a factor of $10^{3}$ over full gradient descent in realistic scenarios in the over-parametrized (or fully parametrized) setting. We believe this sheds light on the impressive effectiveness of SGD observed in many real-world situations and is the reason why full gradient descent is rarely, if ever, used. In particular, the “linear scaling rule” recently used in deep convolutional networks [Kri14, GDG+17, YGG17, SKL17] is consistent with our theoretical analyses.

The rest of the paper is organized as follows:

In Section 3, we analyze the fast convergence of mini-batch SGD and discuss some implications for variance reduction techniques. It turns out that in the interpolated regime, simple SGD with a constant step size is as effective as, or more effective than, the more complex variance reduction methods.

Section 4 contains the analysis of the special case of quadratic losses, where we obtain optimal convergence rates of mini-batch SGD, and derive the optimal step size as a function of the mini-batch size. We also analyze the computational efficiency as a function of the mini-batch size.

In Section 5 we provide experimental evidence using several datasets. We show that the experimental results correspond closely to the behavior predicted by our bounds. We also briefly discuss the connection to the linear scaling rule in neural networks.

Preliminaries

Before we start our technical discussion, we briefly review some standard notions in convex analysis. Here, we focus on differentiable convex functions; however, the definitions below extend to general functions simply by replacing the gradient of the function at a given point with the set of all sub-gradients at that point. In fact, since in this paper we only consider smooth functions, differentiability is directly implied.

Interpolation and Fast SGD: Convex Loss

Next, we state our key assumption in this work. This assumption describes the interpolation setting, which is aligned with what we usually observe in over-parametrized settings in modern machine learning.

where $m$ is the size of a mini-batch of data points whose indices $\{i_{t}^{(1)},\ldots,i_{t}^{(m)}\}$ are drawn uniformly with replacement at each iteration $t$ from $\{1,\ldots,n\}$.
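A minimal sketch of this sampling scheme, assuming NumPy and a hypothetical per-example gradient oracle `grad_i(w, i)` (indices are 0-based in the code). The usage applies it to an interpolating least-squares problem with the conservative step size $1/\beta$, an illustrative choice rather than the optimal one derived later.

```python
import numpy as np


def minibatch_sgd(grad_i, w0, n, m, eta, num_iters, seed=0):
    """Mini-batch SGD with indices drawn uniformly with replacement.

    grad_i(w, i) is a hypothetical oracle returning the gradient of the
    i-th per-example loss at w.
    """
    rng = np.random.default_rng(seed)
    w = np.array(w0, dtype=float)
    for _ in range(num_iters):
        batch = rng.integers(0, n, size=m)  # m indices, with replacement
        g = np.mean([grad_i(w, i) for i in batch], axis=0)
        w = w - eta * g
    return w


# Usage: over-parametrized least squares (d > n), so an interpolating w* exists.
rng = np.random.default_rng(1)
n, d = 10, 100
X = rng.standard_normal((n, d))
y = X @ rng.standard_normal(d)            # labels exactly fit by some w*
beta = np.max(np.sum(X ** 2, axis=1))     # max_i ||x_i||^2


def grad_sq(w, i):
    # gradient of the assumed per-example loss (1/2) * (<w, x_i> - y_i)^2
    return (X[i] @ w - y[i]) * X[i]


w = minibatch_sgd(grad_sq, np.zeros(d), n, m=1, eta=1.0 / beta, num_iters=2000)
train_loss = 0.5 * np.mean((X @ w - y) ** 2)
```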

The theorem below shows exponential convergence for mini-batch SGD in the interpolated regime.

By the $\lambda$-smoothness of $\mathcal{L}$, we have

Then we prove inequality (2) by showing that

Applying the expectation to the inner product and using the $\alpha$-strong convexity of $\mathcal{L}$, we have

From (7) and (8), it is easy to see that for any $p\in[0,1]$, when choosing $\eta(p)\triangleq\min\{\frac{p\cdot m}{\beta},\frac{1-p}{\lambda}\cdot\frac{m}{m-1}\}$, we have

We see that for $p\in[0,1]$, $1-\eta(p)\cdot\alpha$ reaches its minimum (for fastest convergence) when $p=\frac{\beta}{\beta+\lambda(m-1)}$. Thus we choose $\eta^{*}(m)=\frac{m}{\beta+\lambda(m-1)}$, corresponding to the best $p$, and obtain

Incorporating this result with inequality (3), we have

For $m=1$, this theorem is a special case of Theorem 2.1 in [NWS14], which is a sharper version of Theorem 1 in [MB11].
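To make the choice of $p$ concrete, the following NumPy sketch (helper names hypothetical; $\beta$, $\lambda$, $m$ are arbitrary illustrative values) evaluates $\eta(p)$ on a grid over $[0,1]$ and checks that its maximum matches $\eta^{*}(m)$, attained at $p=\beta/(\beta+\lambda(m-1))$. The first term of the min increases with $p$ and the second decreases, so the maximum sits where the two terms cross.

```python
import numpy as np


def eta_of_p(p, m, beta, lam):
    """eta(p) = min{p*m/beta, (1-p)/lam * m/(m-1)}; the second term is +inf for m = 1."""
    second = np.inf if m == 1 else (1.0 - p) / lam * m / (m - 1)
    return min(p * m / beta, second)


def eta_star(m, beta, lam):
    """eta*(m) = m / (beta + lam*(m-1)), attained at p = beta / (beta + lam*(m-1))."""
    return m / (beta + lam * (m - 1))


beta, lam, m = 10.0, 2.0, 8              # illustrative values (assumption)
ps = np.linspace(0.0, 1.0, 100001)
best = max(eta_of_p(p, m, beta, lam) for p in ps)
p_star = beta / (beta + lam * (m - 1))
```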

Speedup factor. Let $t(m)$ be the number of iterations needed to reach a desired accuracy with batch size $m$. Assuming $\lambda\gg\alpha$, the speedup factor $\frac{t(1)}{t(m)}$, which measures the number of iterations saved by using a larger batch, is

Critical batch size $m^{*}\triangleq\frac{\beta}{\lambda}+1$. By estimating the speedup factor for each batch size $m$, we directly obtain

Linear scaling regime: one iteration of batch size $m\leq m^{*}$ is nearly equivalent to $m$ iterations of batch size $1$.

Saturation regime: one iteration with batch size $m>m^{*}$ is nearly equivalent to one full gradient iteration.
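Under the simplifying assumption that the per-iteration contraction factor $1-\eta^{*}(m)\alpha$ from the analysis above governs the iteration count, the speedup can be evaluated numerically. The Python sketch below uses arbitrary illustrative constants $\alpha=10^{-3}$, $\beta=100$, $\lambda=1$ (so $m^{*}=101$); the speedup lies between $m/2$ and $m$ for $m\le m^{*}$ and then saturates, approaching roughly $\beta/\lambda$.

```python
import math


def eta_star(m, beta, lam):
    """Step size eta*(m) = m / (beta + lam*(m-1)) from the analysis above."""
    return m / (beta + lam * (m - 1))


def iterations_needed(m, alpha, beta, lam, reduction=1e-6):
    """Real t with (1 - eta*(m)*alpha)^t = reduction (assumes 0 < eta*(m)*alpha < 1)."""
    rate = 1.0 - eta_star(m, beta, lam) * alpha
    return math.log(reduction) / math.log(rate)


def speedup(m, alpha, beta, lam):
    """Speedup factor t(1)/t(m)."""
    return iterations_needed(1, alpha, beta, lam) / iterations_needed(m, alpha, beta, lam)


alpha, beta, lam = 1e-3, 100.0, 1.0     # illustrative constants (assumption)
m_star = beta / lam + 1                 # critical batch size, here 101
for m in [1, 2, 10, 101, 1000, 10**6]:
    print(m, round(speedup(m, alpha, beta, lam), 2))
```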

We give a sharper analysis for the case of quadratic loss in Section 4.

For general convex optimization, a set of important stochastic methods [RSB12, JZ13, DBLJ14, XZ14, AZ16] have been proposed to achieve exponential (linear) convergence rate with constant step size. The effectiveness of these methods derives from their ability to reduce the stochastic variance caused by sampling. In a general convex setting, this variance prevents SGD from both adopting a constant step size and achieving an exponential convergence rate.

Remarkably, in the interpolated regime, Theorem 1 implies that SGD obtains the benefits of variance reduction “for free” without the need for any modification or extra information (e.g., full gradient computations for variance reduction). The table on the right compares the convergence of SGD in the interpolation setting with several popular variance reduction methods. Overall, SGD has the largest step size and achieves the fastest convergence rate without the need for any further assumptions. The only comparable or faster rate is given by Katyusha, which is an accelerated SGD method combining momentum and variance reduction for faster convergence.

How Fast is Fast SGD: Analysis of Step, Mini-batch Sizes and Computational Efficiency for Quadratic Loss

In this section, we analyze the convergence of mini-batch SGD for quadratic losses. We will consider the following key questions:

What is the optimal convergence rate of mini-batch SGD and the corresponding step size as a function of mm (size of mini-batch)?

What is the computational efficiency of different batch sizes and how do they compare to full GD?

The case of quadratic losses covers over-parametrized linear or kernel regression with a positive definite kernel. The quadratic case also captures general smooth convex functions in the neighborhood of a minimum where higher order terms can be ignored.

Consider the problem of minimizing the sum of squares

For any $\boldsymbol{v}\in\mathcal{H}$, let $\mathbf{P}_{\boldsymbol{v}}$ denote the projection of $\boldsymbol{v}$ onto the subspace spanned by $\{\boldsymbol{e}_{1},\ldots,\boldsymbol{e}_{k}\}$ and $\mathbf{Q}_{\boldsymbol{v}}$ denote the projection of $\boldsymbol{v}$ onto the subspace spanned by $\{\boldsymbol{e}_{k+1},\ldots,\boldsymbol{e}_{d}\}$. That is, $\boldsymbol{v}=\mathbf{P}_{\boldsymbol{v}}+\mathbf{Q}_{\boldsymbol{v}}$ is the decomposition of $\boldsymbol{v}$ into two orthogonal components: its projection onto $\mathsf{Range}(H)$ (i.e., the range space of $H$, which is the subspace spanned by $\{\boldsymbol{e}_{1},\ldots,\boldsymbol{e}_{k}\}$) and its projection onto $\mathsf{Null}(H)$ (i.e., the null space of $H$, which is the subspace spanned by $\{\boldsymbol{e}_{k+1},\ldots,\boldsymbol{e}_{d}\}$). Hence, the above quadratic loss can be written as

To minimize the loss in this setting, consider the following SGD update with a mini-batch of size $m$ and step size $\eta$:

Let $\boldsymbol{\delta}_{t}\triangleq\boldsymbol{w}_{t}-\boldsymbol{w}^{*}$. Observe that we can write (11) as

Now, we make the following simple claim (whose proof is given in the appendix).

This also implies that for any $\boldsymbol{v}\in\mathsf{Null}(H)=\mathsf{Span}\{\boldsymbol{e}_{k+1},\ldots,\boldsymbol{e}_{d}\}$, we must have $H_{m}\boldsymbol{v}=0$.
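A small numerical check of this claim, as a NumPy sketch under the assumed normalizations $H=\frac{1}{n}\sum_{i}\boldsymbol{x}_{i}\boldsymbol{x}_{i}^{\top}$ and $H_{m}=\frac{1}{m}\sum_{j}\boldsymbol{x}_{i_j}\boldsymbol{x}_{i_j}^{\top}$: we compute the eigenbasis of $H$, build the projections $\mathbf{P}$ and $\mathbf{Q}$, and verify that a random mini-batch Hessian annihilates the null-space component. The underlying reason is that $\boldsymbol{v}^{\top}H\boldsymbol{v}=\frac{1}{n}\sum_{i}\langle\boldsymbol{x}_{i},\boldsymbol{v}\rangle^{2}=0$ forces every $\boldsymbol{x}_{i}$ to be orthogonal to $\boldsymbol{v}\in\mathsf{Null}(H)$.

```python
import numpy as np

rng = np.random.default_rng(3)
n, d, m = 5, 12, 3
X = rng.standard_normal((n, d))
H = X.T @ X / n                                # assumed normalization of H

evals, evecs = np.linalg.eigh(H)               # eigenvalues in ascending order
k = int(np.sum(evals > 1e-10 * evals[-1]))     # numerical rank of H
E_range = evecs[:, d - k:]                     # e_1, ..., e_k (non-zero eigenvalues)
E_null = evecs[:, : d - k]                     # e_{k+1}, ..., e_d (zero eigenvalues)


def P(v):
    """Projection onto Range(H)."""
    return E_range @ (E_range.T @ v)


def Q(v):
    """Projection onto Null(H)."""
    return E_null @ (E_null.T @ v)


v = rng.standard_normal(d)
batch = rng.integers(0, n, size=m)
H_m = X[batch].T @ X[batch] / m                # mini-batch Hessian (assumed form)
```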

By the above claim, the update equation (12) can be decomposed into two components:

From (10), it follows that for any iteration $t$, the target loss function $\mathcal{L}(\boldsymbol{w}_{t})$ is not affected at all by $\mathbf{Q}_{\boldsymbol{\delta}_{t}}$, that is, $\mathbf{P}_{\boldsymbol{\delta}_{t}}$ is the only component that matters. Hence, by (13)-(14), we only need to consider the effective SGD update (13), i.e., the update in the span of $\{\boldsymbol{e}_{1},\ldots,\boldsymbol{e}_{k}\}$.
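The error recursion can be checked numerically. The NumPy sketch below assumes the per-example loss $\frac{1}{2}(\langle\boldsymbol{w},\boldsymbol{x}_{i}\rangle-y_{i})^{2}$, so that update (11) reads $\boldsymbol{w}_{t}=\boldsymbol{w}_{t-1}-\frac{\eta}{m}\sum_{j}(\langle\boldsymbol{w}_{t-1},\boldsymbol{x}_{i_j}\rangle-y_{i_j})\boldsymbol{x}_{i_j}$, with $H_{m}=\frac{1}{m}\sum_{j}\boldsymbol{x}_{i_j}\boldsymbol{x}_{i_j}^{\top}$; both normalizations are assumptions of this sketch. Under interpolation ($y_{i}=\langle\boldsymbol{w}^{*},\boldsymbol{x}_{i}\rangle$), the error then evolves as $\boldsymbol{\delta}_{t}=(I-\eta H_{m})\boldsymbol{\delta}_{t-1}$, and the code runs both forms side by side.

```python
import numpy as np


def sgd_step(w, Xb, yb, eta):
    """One mini-batch SGD step on the square loss (assumed normalization)."""
    m = Xb.shape[0]
    return w - eta * Xb.T @ (Xb @ w - yb) / m


def error_step(delta, Xb, eta):
    """The same step for delta = w - w*: delta <- (I - eta * H_m) delta."""
    m = Xb.shape[0]
    H_m = Xb.T @ Xb / m                  # mini-batch Hessian
    return delta - eta * H_m @ delta


rng = np.random.default_rng(2)
n, d, m = 10, 100, 4
X = rng.standard_normal((n, d))
w_star = rng.standard_normal(d)
y = X @ w_star                           # interpolation: y_i = <w*, x_i>
eta = 1.0 / np.max(np.sum(X ** 2, axis=1))  # conservative step size 1/beta (illustrative)

w = np.zeros(d)
delta = w - w_star
for _ in range(200):
    batch = rng.integers(0, n, size=m)   # indices drawn with replacement
    w = sgd_step(w, X[batch], y[batch], eta)
    delta = error_step(delta, X[batch], eta)
```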

1 Upper bound on the expected empirical loss

The following theorem provides an upper bound on the expected empirical loss after tt iterations of mini-batch SGD whose update step is given by (11).

Let $g(m,\eta)\triangleq\max_{\lambda\in[\lambda_{k},\lambda_{1}]}g(\lambda;m,\eta)$. In the interpolation setting, for any $t\geq 1$, mini-batch SGD with update step (11) yields the following guarantee

By reordering terms in the update equation (13) and using the independence of $H_{m}$ and $\boldsymbol{w}_{t-1}$, the variance in the effective component of the parameter update can be written as

Let $G_{m,\eta}\triangleq I-2\eta H+\eta^{2}\left(\frac{\beta}{m}H+\frac{m-1}{m}H^{2}\right)$. Then the variance is bounded as

Furthermore, the convergence rate depends on the eigenvalues of $G_{m,\eta}$. Let $\lambda$ be a non-zero eigenvalue of $H$; then the corresponding eigenvalue of $G_{m,\eta}$ is given by

When the step size $\eta$ and mini-batch size $m$ are chosen to satisfy constraint (16), we have
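Because $G_{m,\eta}$ is a polynomial in $H$, each eigenvalue $\lambda$ of $H$ maps to the eigenvalue $1-2\eta\lambda+\eta^{2}\left(\frac{\beta}{m}\lambda+\frac{m-1}{m}\lambda^{2}\right)$ of $G_{m,\eta}$, which follows directly from the definition above. A minimal NumPy sketch (function names hypothetical, constants illustrative) that evaluates this map, checks it against the spectrum of an explicitly formed $G_{m,\eta}$, and computes $g(m,\eta)$ as the maximum over $[\lambda_k,\lambda_1]$:

```python
import numpy as np


def g_lambda(lam, m, eta, beta):
    """Eigenvalue of G_{m,eta} = I - 2*eta*H + eta^2*(beta/m*H + (m-1)/m*H^2) at eigenvalue lam of H."""
    return 1.0 - 2.0 * eta * lam + eta ** 2 * (beta / m * lam + (m - 1) / m * lam ** 2)


def g(m, eta, beta, lam_k, lam_1):
    """g(m, eta) = max over lam in [lam_k, lam_1] of g_lambda.

    g_lambda is quadratic in lam with non-negative leading coefficient
    eta^2*(m-1)/m, hence convex, so the maximum is attained at an endpoint.
    """
    return max(g_lambda(lam_k, m, eta, beta), g_lambda(lam_1, m, eta, beta))


# Check against the spectrum of G_{m,eta} for a diagonal H (illustrative values).
beta, m, eta = 20.0, 4, 0.1
H = np.diag([5.0, 2.0, 0.5])
G = np.eye(3) - 2 * eta * H + eta ** 2 * (beta / m * H + (m - 1) / m * H @ H)
```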

2 Tightness of the bound on the expected empirical loss

We now show that our upper bound given above is indeed tight in the interpolation setting for the class of quadratic loss functions defined in (9). Namely, we give a specific instance of (9) where the upper bound in Theorem 2 is tight.

which shows that inequality (17) is also achieved with equality in that setting. This completes the proof. ∎

Remark. From the experimental results, it appears that our upper bound can be close to tight even in some settings when the eigenvalues are far apart. We plan to investigate this phenomenon further.

3 Optimal step size for a given batch size

To fully answer the first question we posed at the beginning of this section, we derive an optimal rule for choosing the step size as a function of the batch size. Specifically, we want to find the step size $\eta^{*}(m)$ that achieves the fastest convergence. Given Theorem 2, our task reduces to finding the minimizer

Let $g^{*}(m)$ denote the resulting minimum, that is, $g^{*}(m)=g\left(m,\eta^{*}(m)\right)$. The resulting expression for the minimizer $\eta^{*}(m)$ generally depends on the smallest non-zero eigenvalue $\lambda_{k}$ of the Hessian matrix. In situations where we do not have a good estimate of this eigenvalue (which can be close to zero in practice), one would rather have a step size that is independent of $\lambda_{k}$. In Theorem 5, we give a near-optimal approximation of the step size with no dependence on $\lambda_{k}$ under the assumption that $\beta/\lambda_{k}=\Omega(n)$, which is valid in many practical settings such as kernel learning with positive definite kernels.

We first characterize exactly the optimal step size and the resulting $g^{*}(m)$.

For every batch size $m$, the optimal step size function $\eta^{*}(m)$ and convergence rate function $g^{*}(m)$ are given by:

Note that if $\lambda_{1}=\lambda_{k}$, then the first case in each expression will be valid for all $m\geq 1$.
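The sketch below reconstructs closed forms for $\eta^{*}(m)$ and $g^{*}(m)$ from the proof that follows; it is a derivation under that reading, not a verbatim copy of (19) and (20). Substituting $\eta^{\textup{I}}(m)=\min\left(\eta_{0}(m),\frac{m}{\beta+(m-1)\lambda_{k}}\right)$ into $g(\lambda_{k};m,\eta)$ gives $\eta^{*}(m)=\frac{m}{\beta+(m-1)\lambda_{k}}$ and $g^{*}(m)=1-\frac{m\lambda_{k}}{\beta+(m-1)\lambda_{k}}$ for $m\le\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$, and $\eta^{*}(m)=\eta_{0}(m)$, $g^{*}(m)=1-\frac{4m(m-1)\lambda_{1}\lambda_{k}}{(\beta+(m-1)(\lambda_{1}+\lambda_{k}))^{2}}$ otherwise. The NumPy code compares these formulas with a brute-force minimization of $g(m,\eta)$ over a fine grid of step sizes; the constants are illustrative and the function names hypothetical.

```python
import numpy as np


def eta_star(m, beta, lam_1, lam_k):
    """Optimal step size: min(eta_0(m), m / (beta + (m-1)*lam_k))."""
    eta_0 = 2 * m / (beta + (m - 1) * (lam_1 + lam_k))
    return min(eta_0, m / (beta + (m - 1) * lam_k))


def g_star(m, beta, lam_1, lam_k):
    """Optimal rate g*(m), obtained by substituting eta_star into g(lam_k; m, eta)."""
    if m <= beta / (lam_1 - lam_k) + 1:
        return 1 - m * lam_k / (beta + (m - 1) * lam_k)
    return 1 - 4 * m * (m - 1) * lam_1 * lam_k / (beta + (m - 1) * (lam_1 + lam_k)) ** 2


def g_brute(m, beta, lam_1, lam_k, num=200001):
    """min over a grid of eta of max(g(lam_k; m, eta), g(lam_1; m, eta))."""
    eta_1 = 2 * m / (beta + (m - 1) * lam_1)
    etas = np.linspace(1e-9, 1.5 * eta_1, num)

    def g_lam(lam):
        return 1 - 2 * etas * lam + etas ** 2 * (beta / m * lam + (m - 1) / m * lam ** 2)

    return np.min(np.maximum(g_lam(lam_k), g_lam(lam_1)))


beta, lam_1, lam_k = 100.0, 10.0, 0.01   # illustrative; threshold beta/(lam_1-lam_k)+1 ~ 11.01
for m in [1, 4, 11, 50, 500]:
    print(m, eta_star(m, beta, lam_1, lam_k), g_star(m, beta, lam_1, lam_k))
```

The threshold test in `g_star` divides by $\lambda_1-\lambda_k$, so this sketch assumes $\lambda_1>\lambda_k$.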

The proof of the above theorem follows from the following two lemmas.

Let $\eta_{0}(m)\triangleq\frac{2m}{\beta+(m-1)(\lambda_{1}+\lambda_{k})}$, and let $\eta_{1}(m)\triangleq\frac{2m}{\beta+(m-1)\lambda_{1}}$. Then,

For any fixed $m\geq 1$ and $\eta<\eta_{1}(m)$, observe that $g(\lambda;m,\eta)$ is a quadratic function of $\lambda$ with non-negative leading coefficient $\eta^{2}\frac{m-1}{m}$, hence convex. Therefore, its maximum over $[\lambda_{k},\lambda_{1}]$ must occur at either $\lambda=\lambda_{k}$ or $\lambda=\lambda_{1}$. Define $g^{\textup{I}}(m,\eta)\triangleq g(\lambda_{k};m,\eta)$ and $g^{\textup{II}}(m,\eta)\triangleq g(\lambda_{1};m,\eta)$. Now, depending on the values of $m$ and $\eta$, we have either $g^{\textup{I}}(m,\eta)\geq g^{\textup{II}}(m,\eta)$ or $g^{\textup{I}}(m,\eta)<g^{\textup{II}}(m,\eta)$. In particular, it is not hard to show that

where $\eta_{0}(m)\triangleq\frac{2m}{\beta+(m-1)(\lambda_{1}+\lambda_{k})}$. This completes the proof. ∎

For all $m\geq 1$, $g^{\textup{I}}\left(m,\eta^{\textup{I}}(m)\right)\leq g^{\textup{II}}\left(m,\eta^{\textup{II}}(m)\right)$.

For all $m\geq 1$, $\eta^{\textup{I}}(m)=\eta^{*}(m)$ and $g^{\textup{I}}\left(m,\eta^{\textup{I}}(m)\right)=g^{*}(m)$, where $\eta^{*}(m)$ and $g^{*}(m)$ are as given by (19) and (20), respectively (in Theorem 4).

First, consider $g^{\textup{I}}(m,\eta)$. For any fixed $m$, it is not hard to show that the minimizer of $g^{\textup{I}}(m,\eta)$ as a function of $\eta$, constrained to $\eta\leq\eta_{0}(m)$, is given by $\eta^{\textup{I}}(m)\triangleq\min\left(\eta_{0}(m),\frac{m}{\beta+(m-1)\lambda_{k}}\right)$. That is,

Substituting $\eta=\eta^{\textup{I}}(m)$ in $g^{\textup{I}}(m,\eta)$, we get

Note that $\eta^{\textup{I}}(m)$ and $g^{\textup{I}}\left(m,\eta^{\textup{I}}(m)\right)$ are equal to $\eta^{*}(m)$ and $g^{*}(m)$ given in Theorem 4, respectively. This proves item 2 of the lemma.

Next, consider $g^{\textup{II}}(m,\eta)$. Again, for any fixed $m$, one can easily show that the minimum of $g^{\textup{II}}(m,\eta)$ as a function of $\eta$, constrained to $\eta_{0}(m)<\eta\leq\eta_{1}(m)$, is actually achieved at the boundary $\eta=\eta_{0}(m)$. Hence, $\eta^{\textup{II}}(m)=\eta_{0}(m)$. Substituting this in $g^{\textup{II}}(m,\eta)$, we get

We conclude the proof by showing that for all $m\geq 1$, $g^{\textup{I}}\left(m,\eta^{\textup{I}}(m)\right)\leq g^{\textup{II}}\left(m,\eta^{\textup{II}}(m)\right)$. Note that for $m>\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$, $g^{\textup{I}}\left(m,\eta^{\textup{I}}(m)\right)$ and $g^{\textup{II}}\left(m,\eta^{\textup{II}}(m)\right)$ are identical. For $m\leq\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$, given the expressions above, one can verify that $g^{\textup{I}}\left(m,\eta^{\textup{I}}(m)\right)\leq g^{\textup{II}}\left(m,\eta^{\textup{I}}(m)\right)$.

Given Lemma 1 and item 1 of Lemma 2, it follows that $\eta^{\textup{I}}(m)$ is the minimizer $\eta^{*}(m)$ given by (18). Item 2 of Lemma 2 concludes the proof of the theorem.

Nearly optimal step size with no dependence on $\lambda_{k}$: In practice, it is usually easy to obtain a good estimate of $\lambda_{1}$, but it is hard to reliably estimate $\lambda_{k}$, which is typically much smaller than $\lambda_{1}$ (e.g., [CCSL16]). That is why one would want to avoid dependence on $\lambda_{k}$ in practical SGD algorithms. Under a mild assumption which is typically valid in practice, we can easily find an accurate approximation $\hat{\eta}(m)$ of the optimal $\eta^{*}(m)$ that depends only on $\lambda_{1}$ and $\beta$. Namely, we assume that $\lambda_{k}/\beta\leq 1/n$. In particular, this is always true in kernel learning with positive definite kernels when the data points are distinct.

The following theorem provides such an approximation, resulting in a nearly optimal convergence rate $\hat{g}(m)$.

Suppose that $\lambda_{k}/\beta\leq 1/n$. Let $\hat{\eta}(m)$ be defined as:

Then, the step size $\hat{\eta}(m)$ yields the following upper bound on $g\left(m,\hat{\eta}(m)\right)$, denoted as $\hat{g}(m)$:

The proof easily follows by observing that if $\lambda_{k}/\beta\leq 1/n$, then $\hat{\eta}(m)$ lies in the feasible region of the minimization problem in (18). In particular, $\hat{\eta}(m)\leq\eta_{0}(m)$, where $\eta_{0}(m)$ is as defined in Lemma 1. The upper bound $\hat{g}(m)$ follows from substituting $\hat{\eta}(m)$ in $g^{\textup{I}}(m,\eta)$ defined in Lemma 1, then upper-bounding the resulting expression. ∎

It is easy to see that the convergence rate $\hat{g}(m)$ resulting from the step size $\hat{\eta}$ is at most a factor $1+O(m/n)$ slower than the optimal rate $g^{*}(m)$. This factor is negligible when $m\ll n$. Since we expect $n\gg\beta$, we can further approximate $\hat{\eta}(m)\approx m/\beta$ when $m\lessapprox\beta/\lambda_{1}$ and $\hat{\eta}(m)\approx\frac{2m}{\beta+(m-1)\lambda_{1}}$ when $m\gtrapprox\beta/\lambda_{1}$.
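The two approximations just stated can be compared numerically with the exact $\eta^{*}(m)$. The plain-Python sketch below is illustrative only: it switches between $m/\beta$ and $\frac{2m}{\beta+(m-1)\lambda_{1}}$ at $m=\beta/\lambda_{1}$ (an assumed cut-off, since the text only says "approximately"), sets $\lambda_{k}=\beta/n$ as the boundary case of the assumption $\lambda_{k}/\beta\le 1/n$, and takes $\eta^{*}(m)=\eta^{\textup{I}}(m)$ from Lemma 2. For the batch sizes sampled here, the ratio of the approximation to $\eta^{*}(m)$ stays within $[1,1+(m-1)/n]$, in line with the $1+O(m/n)$ factor.

```python
def eta_star(m, beta, lam_1, lam_k):
    """Exact optimal step size eta^I(m) = min(eta_0(m), m / (beta + (m-1)*lam_k))."""
    eta_0 = 2 * m / (beta + (m - 1) * (lam_1 + lam_k))
    return min(eta_0, m / (beta + (m - 1) * lam_k))


def eta_hat_approx(m, beta, lam_1):
    """Approximate step size using only beta and lam_1 (cut-off at m = beta/lam_1 assumed)."""
    if m <= beta / lam_1:
        return m / beta
    return 2 * m / (beta + (m - 1) * lam_1)


n, beta, lam_1 = 10_000, 100.0, 10.0     # illustrative values (assumption)
lam_k = beta / n                         # boundary case of lam_k / beta <= 1/n
for m in [1, 2, 4, 8, 50, 200, 1000]:
    ratio = eta_hat_approx(m, beta, lam_1) / eta_star(m, beta, lam_1, lam_k)
    print(m, ratio)
```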

4 Batch size selection

In this section, we derive the optimal batch size for a fixed computational budget, measuring computational efficiency by the number of gradient computations needed to obtain a fixed desired accuracy. We will show that a single-point batch is in fact optimal in that setting. Moreover, we will show that any mini-batch size in the range from $1$ to a certain constant $m^{*}$ independent of $n$ is nearly optimal in terms of gradient computations. Interestingly, for values beyond $m^{*}$, the computational efficiency drops sharply. This result has direct implications for batch size selection in parallel computation.

Suppose we are limited by a fixed number of gradient computations. What batch size then yields the smallest approximation error? Equivalently, suppose we are required to achieve a certain target accuracy $\epsilon$ (i.e., we want to reach a parameter $\hat{\boldsymbol{w}}$ such that $\mathcal{L}(\hat{\boldsymbol{w}})-\mathcal{L}(\boldsymbol{w}^{*})\leq\epsilon$). Again, what is the optimal batch size, i.e., the one that requires the least amount of computation?

Suppose we are charged a unit cost for each gradient computation. Then it is not hard to see that the cost function we seek to minimize is $g^{*}(m)^{\frac{1}{m}}$, where $g^{*}(m)$ is as given by Theorem 4. To see this, note that for a batch size $m$, the number of iterations needed to reach a fixed desired accuracy is $t(m)=\frac{\mathsf{constant}}{\log(1/g^{*}(m))}$. Hence, the computation cost is $m\cdot t(m)=\frac{\mathsf{constant}}{\frac{1}{m}\log(1/g^{*}(m))}=\frac{\mathsf{constant}}{\log\left(1/g^{*}(m)^{1/m}\right)}$, so minimizing the computation cost is tantamount to minimizing $g^{*}(m)^{1/m}$. The following theorem shows that the exact minimizer is $m=1$. Later, we will see that any value of $m$ from $2$ to $\approx\beta/\lambda_{1}$ is actually not far from optimal. So, if we have cheap or free computation available (e.g., parallel computation), then it makes sense to choose $m\approx\beta/\lambda_{1}$. We provide more details in the following subsection.

When we are charged a unit cost per gradient computation, the batch size that minimizes the overall computational cost required to achieve a fixed accuracy (i.e., maximizes the computational efficiency) is $m=1$. Namely,

The detailed and precise proof is deferred to the appendix. Here, we give a less formal but more intuitive argument based on a reasonable approximation of $g^{*}(m)$. Such an approximation is in fact valid in most practical settings. In the full version of this paper, we give an exact and detailed analysis. Note that $g^{*}(m)$ can be written as $1-\frac{\lambda_{k}}{\beta}s(m)$, where $s(m)$ is given by

Note that $s(m)$ defined in (23) indeed captures the speedup factor we gain in convergence relative to standard SGD (with $m=1$), where the convergence is dictated by $\lambda_{k}/\beta$. Now, note that $g^{*}(m)^{1/m}\approx e^{-\frac{\lambda_{k}}{\beta}\cdot\frac{s(m)}{m}}$. This approximation becomes very accurate when $\lambda_{k}\ll\lambda_{1}$, which is the case in most practical settings, where $\lambda_{1}/\lambda_{k}\approx n$ and $n$ is very large. Assuming that this is the case (for the sake of this intuitive argument), minimizing $g^{*}(m)^{1/m}$ becomes equivalent to maximizing $s(m)/m$. Now, note that when $m\leq\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$, we have $s(m)/m=\frac{1}{1+(m-1)\frac{\lambda_{k}}{\beta}}$, which is decreasing in $m$. Hence, for $m\leq\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$, we have $s(m)/m\leq s(1)=1$. On the other hand, when $m>\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$, we have

which is also decreasing in $m$ and hence upper bounded by its value at $m=m^{*}\triangleq\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$. By direct substitution and simple cancellations, we can show that $s(m^{*})/m^{*}\leq\frac{\lambda_{1}-\lambda_{k}}{\lambda_{1}}<1$. Thus, $m=1$ is optimal.
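The optimality of $m=1$ and the value of $s(m^{*})/m^{*}$ can be illustrated numerically. The plain-Python sketch below writes $s(m)=\frac{\beta}{\lambda_{k}}(1-g^{*}(m))$ using the closed form of $g^{*}(m)$ obtained by substituting $\eta^{\textup{I}}(m)$ from Lemma 2 into $g(\lambda_{k};m,\eta)$ (a reconstruction of (20) from that lemma), and ranks batch sizes by $\frac{1}{m}\log g^{*}(m)=\log g^{*}(m)^{1/m}$, which orders them the same way as the cost $g^{*}(m)^{1/m}$. The constants and function names are illustrative.

```python
import math


def s(m, beta, lam_1, lam_k):
    """Speedup s(m), defined through g*(m) = 1 - (lam_k / beta) * s(m)."""
    if m <= beta / (lam_1 - lam_k) + 1:
        return m * beta / (beta + (m - 1) * lam_k)
    return 4 * m * (m - 1) * beta * lam_1 / (beta + (m - 1) * (lam_1 + lam_k)) ** 2


def log_cost_per_gradient(m, beta, lam_1, lam_k):
    """log of g*(m)^(1/m); smaller means fewer gradient computations overall."""
    g_star = 1 - lam_k / beta * s(m, beta, lam_1, lam_k)
    return math.log(g_star) / m


beta, lam_1, lam_k = 100.0, 10.0, 0.01     # illustrative values (assumption)
costs = {m: log_cost_per_gradient(m, beta, lam_1, lam_k) for m in range(1, 501)}
best_m = min(costs, key=costs.get)
m_crit = beta / (lam_1 - lam_k) + 1
```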

One may wonder whether the above result remains valid if the near-optimal step size $\hat{\eta}(m)$ (which does not depend on $\lambda_{k}$) is used, that is, whether the same optimality result holds if the near-optimal error rate function $\hat{g}(m)$ is used instead of $g^{*}(m)$ in Theorem 6. Indeed, we show that the same optimality remains true even if computational efficiency is measured with respect to $\hat{g}(m)$. This is formally stated in the following theorem.

When the near-optimal step size $\hat{\eta}(m)$ is used (and assuming that $\lambda_{k}/\beta\leq 1/n$), the batch size that minimizes the overall computational cost required to achieve a fixed accuracy is $m=1$. Namely,

The proof of the above theorem follows along the same lines as the proof of Theorem 6.

4.2 Near optimal larger batch sizes

Suppose that several gradient computations can be performed in parallel. Sometimes doubling the number of machines used in parallel can halve the number of iterations needed to reach a fixed desired accuracy. This observation has motivated many works to use large batch sizes with distributed synchronized SGD [CMBJ16, GDG+17, YGG17, SKL17]. One critical problem in this large-batch setting is how to choose the step size. To keep the same covariance, [BCN16, Li17, HHS17] choose the step size $\eta\sim\sqrt{m}$ for batch size $m$, whereas [Kri14, GDG+17, YGG17, SKL17] have observed that rescaling the step size as $\eta\sim m$ works well in practice for not too large $m$. To explain these observations, we directly connect the parallelism, or the batch size $m$, to the required number of iterations $t(m)$ defined previously. It turns out that (a) when the batch size is small, doubling it almost halves the required number of iterations; (b) once the batch size surpasses a certain value, increasing it by any amount reduces the required number of iterations by at most a constant factor.

Linear scaling regime ($m\leq\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$): This is the regime where increasing the batch size $m$ quickly drives down the number of iterations $t(m)$ needed to reach a certain accuracy. When $\lambda_{k}\ll\lambda_{1}$, $s(m)\approx m$, which suggests $t(m/2)\approx 2\cdot t(m)$. In other words, doubling the batch size in this regime roughly halves the number of iterations needed. Note that we choose the step size $\eta\leftarrow\frac{m}{\beta+(m-1)\lambda_{k}}$. When $\lambda_{k}\leq\frac{\beta}{n}\ll\lambda_{1}$, $\eta\sim m$, which is consistent with the linear scaling heuristic used in [Kri14, GDG+17, SKL17]. In this case, the largest batch size in the linear scaling regime can be practically calculated through

Saturation regime ($m>\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$): Increasing the batch size in this regime becomes much less beneficial. Although $s(m)$ is monotonically increasing, it is upper bounded by $\lim_{m\rightarrow\infty}s(m)=\frac{4\beta\lambda_{1}}{(\lambda_{1}+\lambda_{k})^{2}}\leq\frac{4\beta}{\lambda_{1}}$. In fact, since $t(\frac{\beta}{\lambda_{1}-\lambda_{k}}+1)/\lim_{m\rightarrow\infty}t(m)<4$ for small $\lambda_{k}$, no batch size in this regime can reduce the number of iterations needed by more than a factor of 4.
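Both regimes can be illustrated with the closed form of $g^{*}(m)$ reconstructed from Lemma 2 and with $t(m)\propto 1/\log(1/g^{*}(m))$ from the computation-cost argument above. The plain-Python sketch below (illustrative constants, hypothetical names) checks that halving $m$ roughly doubles $t(m)$ for small $m$, and that $t(m^{*})/\lim_{m\to\infty}t(m)<4$, where the limit uses $\lim_{m\to\infty}g^{*}(m)=1-\frac{4\lambda_{1}\lambda_{k}}{(\lambda_{1}+\lambda_{k})^{2}}$.

```python
import math


def g_star(m, beta, lam_1, lam_k):
    """Optimal rate g*(m) (closed form derived from Lemma 2's eta^I(m))."""
    if m <= beta / (lam_1 - lam_k) + 1:
        return 1 - m * lam_k / (beta + (m - 1) * lam_k)
    return 1 - 4 * m * (m - 1) * lam_1 * lam_k / (beta + (m - 1) * (lam_1 + lam_k)) ** 2


def iterations(g, eps=1e-6):
    """Iterations t with g^t = eps (the constant in t(m) is arbitrary)."""
    return math.log(1 / eps) / math.log(1 / g)


beta, lam_1, lam_k = 100.0, 10.0, 0.01     # illustrative values (assumption)
m_crit = beta / (lam_1 - lam_k) + 1
t = {m: iterations(g_star(m, beta, lam_1, lam_k)) for m in [2, 4, 8]}
t_crit = iterations(g_star(m_crit, beta, lam_1, lam_k))
t_inf = iterations(1 - 4 * lam_1 * lam_k / (lam_1 + lam_k) ** 2)
```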

Experimental Results

This section provides empirical evidence for our theoretical results on the effectiveness of mini-batch SGD in the interpolated setting. We first consider a kernel learning problem, where the parameters $\beta$, $\lambda_{1}$, and $m^{*}$ can be computed efficiently (see [MB17] for details). In all experiments, we set the step size to $\hat{\eta}$ defined in (21).

We observe empirically that increasing the step size from $\hat{\eta}$ to $2\hat{\eta}$ consistently leads to divergence, indicating that $\hat{\eta}$ differs from the optimal step size by at most a factor of $2$. This is consistent with our Theorem 5 on the near-optimal step size.

Theorem 4 suggests that SGD using batch size $m^{*}$ defined in (24) can reach the same error as GD using at most $4$ times the number of iterations. This is consistent with our experimental results for MNIST, HINT-S [HYWW13], and TIMIT, shown in Figure 3. Moreover, in line with our analysis, SGD with a batch size larger than $m^{*}$ but still much smaller than the data size converges nearly identically to full gradient descent.

Since our analysis is concerned with the training error, only the training error is reported here. For completeness, we report the test error in Appendix C. As consistently observed in such over-parametrized settings, test error decreases with the training error.

2 Optimality of batch size m=1

Our theoretical results, Theorems 6 and 7, show that $m=1$ achieves the optimal computational efficiency. Note that, for a given batch size, the corresponding optimal step size is chosen according to equation (21). The experiments in Figure 4 show that $m=1$ indeed achieves the lowest error for any fixed number of epochs.

3 Linear scaling and saturation regimes

In the interpolation regime, Theorem 6 shows linear scaling for mini-batch sizes up to a (typically small) “critical” batch size $m^{*}$ defined in (24), followed by the saturation regime. In Figure 4 we plot the training error for different batch sizes as a function of the number of epochs. Note that the number of epochs is proportional to the amount of computation measured in terms of gradient evaluations. The linear scaling regime ($1\leq m\leq m^{*}$) is reflected in the small difference in the training error for $m=1$ and $m=m^{*}$ in Figure 4 (the bottom three curves). As expected from our theoretical results, they have similar computational efficiency. On the other hand, large mini-batch sizes ($m\gg m^{*}$) require drastically more computation, which is the saturation phenomenon reflected in the top two curves.
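Because an epoch corresponds to a fixed number of gradient evaluations, the total computation for batch size $m$ scales like $m\cdot t(m)$. A minimal sketch, assuming the same piecewise contraction factor as in Appendix B and an illustrative spectrum ($\beta=100$, $\lambda_1=1$, $\lambda_k=10^{-5}$), tabulates this cost relative to $m=1$. In this setting it stays close to 1 up to $m^{*}$ and then grows roughly linearly in $m$. The helper names are hypothetical.

```python
import math


def contraction_gap(m, beta, lam1, lamk):
    """1 - g*(m), switching from the linear to the saturation formula at m*."""
    m_star = beta / (lam1 - lamk) + 1.0
    if m <= m_star:
        return m * lamk / (beta + (m - 1) * lamk)
    return 4 * m * (m - 1) * lam1 * lamk / (beta + (m - 1) * (lam1 + lamk)) ** 2


def relative_computation(m, beta, lam1, lamk):
    """m * t(m) / t(1): gradient evaluations (epochs) relative to batch size 1."""
    def t(b):
        # Iteration count up to the common factor ln(1/eps), which cancels in the ratio.
        return 1.0 / -math.log1p(-contraction_gap(b, beta, lam1, lamk))
    return m * t(m) / t(1)


beta, lam1, lamk = 100.0, 1.0, 1e-5
m_star = int(beta / (lam1 - lamk) + 1.0)  # 101 for these values
for m in (1, 10, m_star, 4 * m_star, 16 * m_star):
    print(f"m = {m:5d}  relative computation = {relative_computation(m, beta, lam1, lamk):.2f}")
```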

Relation to the “linear scaling rule” in neural networks. A number of recent large-scale neural network methods, including [Kri14, CMBJ16, GDG+17], use the “linear scaling rule” to accelerate training using parallel computation. After an initial “warmup” stage to find a good region of parameters, this rule suggests increasing the step size to a level proportional to the mini-batch size $m$. In spite of the wide adoption and effectiveness of this technique, there has been no satisfactory explanation for it [Kri14], as the usual variance-based analysis suggests increasing the step size by a factor of $\sqrt{m}$ instead of $m$ [BCN16]. We note that this “linear scaling” can be explained by our analysis, assuming that the warmup stage ends up in a neighborhood of an interpolating minimum.

4 Interpolation in kernel methods

To give additional evidence of interpolation in over-parametrized settings, we provide empirical results showing that this is indeed the case in kernel learning. We give two examples: a Laplace kernel trained using EigenPro [MB17] on MNIST [LBBH98] and on a subset (of size $5\cdot 10^{4}$) of TIMIT [GLF+93]. The histograms in Figure 5 show the number of points with a given loss, calculated as $\lVert\boldsymbol{y}_{i}-f(\boldsymbol{x}_{i})\rVert^{2}$ for feature vector $\boldsymbol{x}_{i}$ and corresponding binary label vector $\boldsymbol{y}_{i}$. As evident from the histograms, the test loss keeps decreasing as we converge to an interpolated solution.
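The per-point loss behind such histograms is simple to compute. The sketch below assumes the model outputs $f(\boldsymbol{x}_i)$ are already available as score vectors (the toy outputs are invented for illustration, not taken from the experiments). As one possible binning choice, it groups the losses $\lVert\boldsymbol{y}_i-f(\boldsymbol{x}_i)\rVert^2$ by decade on a log scale; the binning used for Figure 5 is not specified here. Function names are hypothetical.

```python
import math
from collections import Counter


def per_point_losses(labels, outputs):
    """Squared loss ||y_i - f(x_i)||^2 for each point; labels are binary (one-hot) vectors."""
    return [sum((y - f) ** 2 for y, f in zip(y_vec, f_vec)) for y_vec, f_vec in zip(labels, outputs)]


def decade_histogram(losses, floor=1e-16):
    """Count points per decade: key -3 collects losses in [1e-3, 1e-2); zeros are clipped to floor."""
    return Counter(math.floor(math.log10(max(loss, floor))) for loss in losses)


# Toy example: three points, three classes, made-up model outputs.
labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
outputs = [[0.9, 0.1, 0.0], [0.0, 1.0, 0.0], [0.01, 0.0, 0.99]]
losses = per_point_losses(labels, outputs)
print(losses)
print(sorted(decade_histogram(losses).items()))
```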

It is interesting to observe that even when SGD is slow to converge to the interpolated solution, our theoretical bounds still accurately describe the relative efficiency of different mini-batch sizes. We examine two different settings: interpolation with a Laplace kernel trained using EigenPro [MB17] on HINT-S, and interpolation with a Gaussian kernel; both are shown in Figure 6. As is clear from the figures, the Laplace kernel converges to the interpolated solution much faster than the Gaussian kernel. However, as the experiments depicted in Figure 7 show, the relative computational efficiency of different batch sizes is very similar in these two settings. As before, we plot the training error against the number of epochs (which is proportional to computation) for different batch sizes. Note that while the scale of the error is very different for these two settings, the profiles of the curves are remarkably similar.

References

Appendix A Proof of Claim 1

Let $\{\boldsymbol{e}_{1},\ldots,\boldsymbol{e}_{d}\}$ denote the eigen-basis of $H$ corresponding to eigenvalues $\lambda_{1}\geq\cdots\geq\lambda_{k}>0=\lambda_{k+1}=\cdots=\lambda_{d}$. For every $i\in\{1,\ldots,n\}$, let $\boldsymbol{x}_{i}=\sum_{j=1}^{d}\alpha_{i,j}\boldsymbol{e}_{j}$ be the expansion of $\boldsymbol{x}_{i}$ w.r.t. the eigen-basis of $H$.

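For every $j\in\{k+1,\ldots,d\}$, we have
$$0=\lambda_{j}=\boldsymbol{e}_{j}^{\top}H\boldsymbol{e}_{j}=\frac{1}{n}\sum_{i=1}^{n}\left(\boldsymbol{e}_{j}^{\top}\boldsymbol{x}_{i}\right)^{2}=\frac{1}{n}\sum_{i=1}^{n}\alpha_{i,j}^{2},$$
where the last equality follows from expanding each $\boldsymbol{x}_{i}$ w.r.t. the eigen-basis of $H$. Thus,
$$\alpha_{i,j}=0\ \text{ for all } i\in\{1,\ldots,n\},\ j\in\{k+1,\ldots,d\},\quad\text{i.e.,}\quad\boldsymbol{x}_{i}\in\mathsf{Span}\{\boldsymbol{e}_{1},\ldots,\boldsymbol{e}_{k}\}\ \text{ for all } i.\qquad(26)$$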

The proof immediately follows from (26) since for any $\boldsymbol{u}\in\mathcal{H}$, we can write $\boldsymbol{u}=\mathbf{P}_{\boldsymbol{u}}+\mathbf{Q}_{\boldsymbol{u}}$, where $\mathbf{P}_{\boldsymbol{u}}$ and $\mathbf{Q}_{\boldsymbol{u}}$ denote the projections of $\boldsymbol{u}$ onto $\mathsf{Span}\{\boldsymbol{e}_{1},\ldots,\boldsymbol{e}_{k}\}$ and $\mathsf{Span}\{\boldsymbol{e}_{k+1},\ldots,\boldsymbol{e}_{d}\}$, respectively. Hence, (26) implies that $H_{m}\mathbf{P}_{\boldsymbol{u}}\in\mathsf{Span}\{\boldsymbol{e}_{1},\ldots,\boldsymbol{e}_{k}\}$ and $H_{m}\mathbf{Q}_{\boldsymbol{u}}=0$, which proves the claim.
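Claim 1 can also be checked numerically on a toy instance. The sketch below assumes the least-squares forms $H=\frac{1}{n}\sum_i\boldsymbol{x}_i\boldsymbol{x}_i^{\top}$ and $H_m=\frac{1}{m}\sum_{i\in B}\boldsymbol{x}_i\boldsymbol{x}_i^{\top}$ for a batch $B$ of size $m$. It draws points that all lie in a random $k$-dimensional subspace, so $\mathsf{Span}\{\boldsymbol{e}_1,\ldots,\boldsymbol{e}_k\}$ is exactly that subspace, and it checks that $H_m\mathbf{Q}_{\boldsymbol{u}}$ vanishes and $H_m\mathbf{P}_{\boldsymbol{u}}$ stays inside the subspace, up to rounding. Dimensions, seeds, and helper names are illustrative.

```python
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def add_scaled(alpha, x, y):
    """Return alpha * x + y."""
    return [alpha * a + b for a, b in zip(x, y)]


def orthonormal_basis(vectors):
    """Gram-Schmidt orthonormalization; assumes the vectors are linearly independent."""
    basis = []
    for v in vectors:
        for q in basis:
            v = add_scaled(-dot(q, v), q, v)
        norm = dot(v, v) ** 0.5
        basis.append([a / norm for a in v])
    return basis


def project(u, basis):
    """Orthogonal projection of u onto the span of an orthonormal basis."""
    p = [0.0] * len(u)
    for q in basis:
        p = add_scaled(dot(q, u), q, p)
    return p


def apply_minibatch_operator(X, batch, u):
    """H_m u = (1/m) * sum_{i in batch} x_i <x_i, u>."""
    out = [0.0] * len(u)
    for i in batch:
        out = add_scaled(dot(X[i], u) / len(batch), X[i], out)
    return out


rng = random.Random(0)
d, k, n, m = 6, 2, 8, 3
generators = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(k)]
X = []
for _ in range(n):
    # Each x_i is a random combination of the k generators, so H has rank k.
    coeffs = [rng.gauss(0.0, 1.0) for _ in range(k)]
    X.append([sum(c * g[j] for c, g in zip(coeffs, generators)) for j in range(d)])

basis = orthonormal_basis(generators)  # plays the role of Span{e_1, ..., e_k}
u = [rng.gauss(0.0, 1.0) for _ in range(d)]
P_u = project(u, basis)
Q_u = [a - b for a, b in zip(u, P_u)]
batch = rng.sample(range(n), m)

HQ = apply_minibatch_operator(X, batch, Q_u)
HP = apply_minibatch_operator(X, batch, P_u)
HP_outside = [a - b for a, b in zip(HP, project(HP, basis))]
print("max |H_m Q_u|:", max(abs(a) for a in HQ))
print("max |part of H_m P_u outside the span|:", max(abs(a) for a in HP_outside))
```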

Appendix B Proof of Theorem 6

Here, we provide an exact analysis of the optimality of batch size $m=1$ for the cost function $g^{*}(m)^{1/m}$, which, as discussed in Section 4.4.1, captures the total computational cost required to achieve any fixed target accuracy (in a model with no parallel computation).

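We prove this theorem by showing that $g^{*}(m)^{\frac{1}{m}}$ is strictly increasing for $m\geq 1$. We do this via the following two simple lemmas. First, we introduce the following notation:
$$g_{1}(m)\triangleq 1-\frac{m\lambda_{k}}{\beta+(m-1)\lambda_{k}},\qquad g_{2}(m)\triangleq 1-\frac{4m(m-1)\lambda_{1}\lambda_{k}}{\left(\beta+(m-1)(\lambda_{1}+\lambda_{k})\right)^{2}},$$
so that $g^{*}(m)=g_{1}(m)$ for $m\leq\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$ and $g^{*}(m)=g_{2}(m)$ otherwise.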

Lemma 3. $g_{1}(m)^{\frac{1}{m}}$ is strictly increasing for $m\geq 1$.

Define T(m)1mln(1/g1(m))T(m)\triangleq\frac{1}{m}\ln(1/g_{1}(m)). We will show that T(m)T(m) is strictly decreasing for m1m\geq 1, which is tantamount to showing that g1(m)1mg_{1}(m)^{\frac{1}{m}} is strictly increasing over m1m\geq 1. For more compact notation, let’s define τβλkβ\tau\triangleq\frac{\beta-\lambda_{k}}{\beta}, and τˉ=1τ\bar{\tau}=1-\tau. First note that, after straightforward simplification, g(m)=ττ+τˉmg^{*}(m)=\frac{\tau}{\tau+\bar{\tau}m}. Hence, T(m)=1mln(1+um),T(m)=\frac{1}{m}\ln(1+um), where uτˉτ=λkβλk>0u\triangleq\frac{\bar{\tau}}{\tau}=\frac{\lambda_{k}}{\beta-\lambda_{k}}>0. Now, it is not hard to see that T(m)T(m) is strictly decreasing since the function 1xln(1+ux)\frac{1}{x}\ln(1+ux) is strictly decreasing in xx as long as u>0u>0. ∎

Lemma 4. $g_{1}(m)\leq g_{2}(m)$ for all $m\geq 1$.

Proof. Proving the lemma is equivalent to proving $\frac{4m(m-1)\lambda_{1}\lambda_{k}}{\left(\beta+(m-1)(\lambda_{1}+\lambda_{k})\right)^{2}}\leq\frac{m\lambda_{k}}{\beta+(m-1)\lambda_{k}}$. After direct manipulation, this is equivalent to showing that
$$\left(\beta+(m-1)(\lambda_{1}+\lambda_{k})\right)^{2}-4(m-1)\lambda_{1}\left(\beta+(m-1)\lambda_{k}\right)\geq 0,$$

which is true for all $m$ since the left-hand side is a complete square: $\left((m-1)(\lambda_{1}-\lambda_{k})-\beta\right)^{2}$. ∎
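The algebra in the proof of Lemma 4 can be spot-checked on random instances. The sketch below assumes the factors $g_1(m)=1-\frac{m\lambda_k}{\beta+(m-1)\lambda_k}$ and $g_2(m)=1-\frac{4m(m-1)\lambda_1\lambda_k}{(\beta+(m-1)(\lambda_1+\lambda_k))^2}$ compared in the lemma. It draws $\lambda_k\leq\lambda_1\leq\beta$ (which holds when $\beta=\max_i\lVert\boldsymbol{x}_i\rVert^2$, since $\beta$ is then at least the trace of $H$), and checks both the perfect-square identity and $g_1\leq g_2$. The sampling ranges and tolerances are arbitrary choices.

```python
import math
import random


def g1(m, beta, lam1, lamk):
    """Linear-regime factor: 1 - m*lamk / (beta + (m - 1)*lamk)."""
    return 1.0 - m * lamk / (beta + (m - 1) * lamk)


def g2(m, beta, lam1, lamk):
    """Saturation-regime factor: 1 - 4m(m-1)*lam1*lamk / (beta + (m - 1)(lam1 + lamk))^2."""
    return 1.0 - 4 * m * (m - 1) * lam1 * lamk / (beta + (m - 1) * (lam1 + lamk)) ** 2


rng = random.Random(0)
for _ in range(10_000):
    lamk = rng.uniform(0.01, 1.0)
    lam1 = lamk + rng.uniform(0.0, 5.0)
    beta = lam1 + rng.uniform(0.0, 20.0)
    m = rng.randint(1, 500)
    a = m - 1
    # Expression left after clearing denominators, and the claimed perfect square.
    expanded = (beta + a * (lam1 + lamk)) ** 2 - 4 * a * lam1 * (beta + a * lamk)
    square = (a * (lam1 - lamk) - beta) ** 2
    assert math.isclose(expanded, square, rel_tol=1e-9, abs_tol=1e-6)
    assert g1(m, beta, lam1, lamk) <= g2(m, beta, lam1, lamk) + 1e-12
print("Lemma 4 checks passed on 10,000 random instances")
```

Equality $g_1=g_2$ occurs exactly where the square vanishes, i.e. at $m=\frac{\beta}{\lambda_1-\lambda_k}+1$, which is where the two regimes meet.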

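Given these two simple lemmas, observe that, for $m\in\left[1,\frac{\beta}{\lambda_{1}-\lambda_{k}}+1\right]$,
$$g^{*}(1)=g_{1}(1)\leq g_{1}(m)^{\frac{1}{m}}=g^{*}(m)^{\frac{1}{m}},\qquad(27)$$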

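where the first and last equalities follow from the fact that $g_{1}(m)=g^{*}(m)$ for $m\in\left[1,\frac{\beta}{\lambda_{1}-\lambda_{k}}+1\right]$, and the second inequality follows from Lemma 3. Also, observe that, for $m>\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$,
$$g^{*}(1)=g_{1}(1)\leq g_{1}(m)^{\frac{1}{m}}\leq g_{2}(m)^{\frac{1}{m}}=g^{*}(m)^{\frac{1}{m}},\qquad(28)$$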

where the third inequality follows from Lemma 4, and the last equality follows from the fact that $g_{2}(m)=g^{*}(m)$ for $m>\frac{\beta}{\lambda_{1}-\lambda_{k}}+1$. Putting (27) and (28) together, we have $g^{*}(1)\leq g^{*}(m)^{\frac{1}{m}}$ for all $m\geq 1$, which completes the proof.
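The conclusion $g^{*}(1)\leq g^{*}(m)^{1/m}$ can be checked numerically for a given spectrum. The sketch below assumes the piecewise $g^{*}(m)$ built from $g_1$ (up to $m=\frac{\beta}{\lambda_1-\lambda_k}+1$) and $g_2$ (beyond it), as in this appendix, and uses arbitrary illustrative values $\beta=50$, $\lambda_1=2$, $\lambda_k=0.1$. Over the tested range, it reports that the per-gradient-evaluation contraction $g^{*}(m)^{1/m}$ is smallest at $m=1$. Function names are hypothetical.

```python
def g_star(m, beta, lam1, lamk):
    """Optimal per-iteration contraction factor, switching regimes at m* = beta/(lam1 - lamk) + 1."""
    m_star = beta / (lam1 - lamk) + 1.0
    if m <= m_star:
        return 1.0 - m * lamk / (beta + (m - 1) * lamk)
    return 1.0 - 4 * m * (m - 1) * lam1 * lamk / (beta + (m - 1) * (lam1 + lamk)) ** 2


def contraction_per_gradient(m, beta, lam1, lamk):
    """g*(m)^(1/m): error contraction per gradient evaluation (smaller is better)."""
    return g_star(m, beta, lam1, lamk) ** (1.0 / m)


beta, lam1, lamk = 50.0, 2.0, 0.1  # here m* = 50 / 1.9 + 1, about 27.3
values = {m: contraction_per_gradient(m, beta, lam1, lamk) for m in range(1, 501)}
best_m = min(values, key=values.get)
print(f"best batch size: {best_m}")
print(f"m = 1: {values[1]:.6f}   m = 27: {values[27]:.6f}   m = 500: {values[500]:.6f}")
```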

Appendix C Experiments: Comparison of Train and Test losses