The Power of Interpolation: Understanding the Effectiveness of SGD in Modern Over-parametrized Learning
Siyuan Ma, Raef Bassily, Mikhail Belkin
Introduction
In recent years, Stochastic Gradient Descent (SGD) with a small mini-batch size has become the backbone of machine learning, used in nearly all large-scale applications of machine learning methods, notably in conjunction with deep neural networks. Mini-batch SGD is a first-order method which, instead of computing the full gradient of the loss, computes the gradient with respect to a subset of the data points, often chosen sequentially. In practice, small mini-batch SGD consistently outperforms full gradient descent (GD) by a large factor in terms of the computation required to achieve a given accuracy. However, the theoretical evidence has been mixed. While SGD needs fewer computations per iteration, most analyses suggest that it requires adaptive step sizes and has a rate of convergence that is far slower than that of GD, making computational efficiency comparisons difficult.
In this paper, we explain the reasons for the effectiveness of SGD by taking a different perspective. We note that most of modern machine learning, especially deep learning, relies on classifiers which are trained to achieve near zero classification and regression losses on the training data. Indeed, the goal of achieving near-perfect fit on the training set is stated explicitly by practitioners as a best practice in supervised learning (potentially using regularization at a later stage, after a near-perfect fit is achieved); see, e.g., the tutorial [Sal17]. The ability to achieve near-zero loss is provided by over-parametrization. The number of parameters for most deep architectures is very large and often exceeds by far the size of the datasets used for training (see, e.g., [CPC16] for a summary of different architectures). There is significant theoretical and empirical evidence that in such over-parametrized systems most or all local minima are also global and hence correspond to the regime where the output of the learning algorithm matches the labels exactly [GAGN15, CCSL16, ZBH+16, HLWvdM16, SEG+17, BFT17]. Since continuous loss functions are typically used for training, the resulting function interpolates the data, i.e., it fits the training labels (nearly) exactly. (Most of these architectures should be able to achieve perfect interpolation. In practice, of course, this is not possible even for linear systems due to computational and numerical limitations.)
While we do not yet understand why these interpolated classifiers generalize so well to unseen data, there is ample empirical evidence for their excellent generalization performance in deep neural networks [GAGN15, CCSL16, ZBH+16, HLWvdM16, SEG+17], kernel machines [belkin2018understand] and boosting [SFBL98]. In this paper we look at the significant computational implications of this startling phenomenon for stochastic gradient descent.
Our first key observation is that in the interpolated regime SGD with fixed step size converges exponentially fast for convex loss functions. The results showing exponential convergence of SGD when the optimal solution minimizes the loss function at each point go back to the Kaczmarz method [Kac37] for quadratic functions, more recently analyzed in [SV09]. For the general convex case, it was first proved in [MB11]. The rate was later improved in [NWS14]. However, to the best of our knowledge, exponential convergence in that regime has not been connected to over-parametrization and interpolation in modern machine learning. Still, exponential convergence by itself does not allow us to make any comparisons between the computational efficiency of SGD with different mini-batch sizes and full gradient descent, as the existing results do not depend on the mini-batch size $m$. This dependence is crucial for understanding SGD, as small mini-batch SGD seems to dramatically outperform full gradient descent in nearly all applications. Motivated by this, in this paper we provide an explanation for the empirically observed efficiency of small mini-batch SGD. We provide a detailed analysis of the rates of convergence and computational efficiency for different mini-batch sizes and a discussion of its implications in the context of modern machine learning.
We first analyze convergence of mini-batch SGD for convex loss functions as a function of the batch size $m$. We show that there is a critical mini-batch size that is nearly independent of the data size, such that the following holds:
(linear scaling) One SGD iteration with a mini-batch of size $m$ no larger than the critical size is equivalent to $m$ iterations with mini-batch size one, up to a multiplicative constant close to 1.
(saturation) One SGD iteration with a mini-batch larger than the critical size is nearly (up to a small constant) as effective as one iteration of full gradient descent.
We see that the critical mini-batch size can be viewed as the limit for the effective parallelization of mini-batch computations. If an iteration with a mini-batch of size $m$ up to the critical size can be computed in parallel, it is nearly equivalent to $m$ sequential steps with mini-batch size one. Beyond the critical size, parallel computation has limited added value.
Next, for the quadratic loss function, we obtain a sharp characterization of these regimes based on an explicit derivation of optimal step size as a function of . In particular, in this case we show that the critical mini-batch size is given by , where is the Hessian at the minimizer and is its spectral norm.
Our result shows that is nearly independent of the data size (depending only on the properties of the Hessian). Thus SGD with mini-batch size (typically a small constant) gives essentially the same convergence per iteration as full gradient descent, implying acceleration by a factor of over GD per unit of computation.
We also show that a mini-batch of size one is optimal in terms of computations required to achieve a given error. Our theoretical results are based on upper bounds which we show to be tight in the quadratic case and nearly tight in the general convex case.
There has been work on understanding the interplay between the mini-batch size and computational efficiency, including [TBRS13, LZCS14, YPL+18], in the standard non-interpolated regime. However, in that setting the issue of bridging the exponential convergence of full GD and the much slower convergence rates of mini-batch SGD is harder to resolve, requiring extra components, such as tail averaging [JKK+16] (for quadratic loss).
We provide experimental evidence corroborating this on real data. In particular, we demonstrate the regimes of linear scaling and saturation and also show that on real data the critical mini-batch size is in line with our estimate. It is typically several orders of magnitude smaller than the data size, implying a correspondingly large computational advantage over full gradient descent in realistic scenarios in the over-parametrized (or fully parametrized) setting. We believe this sheds light on the impressive effectiveness of SGD observed in many real-world situations and is the reason why full gradient descent is rarely, if ever, used. In particular, the “linear scaling rule” recently used in deep convolutional networks [Kri14, GDG+17, YGG17, SKL17] is consistent with our theoretical analyses.
The rest of the paper is organized as follows:
In Section 3, we analyze the fast convergence of mini-batch SGD and discuss some implications for the variance reduction techniques. It turns out that in the interpolated regime, simple SGD with constant step size is equally or more effective than the more complex variance reduction methods.
Section 4 contains the analysis of the special case of quadratic losses, where we obtain optimal convergence rates of mini-batch SGD, and derive the optimal step size as a function of the mini-batch size. We also analyze the computational efficiency as a function of the mini-batch size.
In Section 5 we provide experimental evidence using several datasets. We show that the experimental results correspond closely to the behavior predicted by our bounds. We also briefly discuss the connection to the linear scaling rule in neural networks.
Preliminaries
Before we start our technical discussion, we briefly review some standard notions in convex analysis. Here, we will focus on differentiable convex functions; however, the definitions below extend to general functions simply by replacing the gradient of the function at a given point by the set of all sub-gradients at that point. In fact, since in this paper we only consider smooth functions, differentiability is directly implied.
Interpolation and Fast SGD: Convex Loss
Next, we state our key assumption in this work. This assumption describes the interpolation setting, which is aligned with what we usually observe in over-parametrized settings in modern machine learning.
where is the size of a mini-batch of data points whose indices are drawn uniformly with replacement at each iteration from .
The theorem below shows exponential convergence for mini-batch SGD in the interpolated regime.
By the -smoothness of , we have
Then we prove inequality (2) by showing that
Applying the expectation to the inner product and using the -strong convexity of , we have
From (7) and (8), it is easy to see that for any when choosing , we have
We see that for , reaches its minimum (for fastest convergence) when . Thus we choose corresponding to the best and obtain,
Incorporating this result with inequality (3), we have
For , this theorem is a special case of Theorem 2.1 in [NWS14], which is a sharper version of Theorem 1 in [MB11].
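The exponential convergence described above can be observed numerically. The following is a minimal sketch, not the setting of the theorem itself: it assumes an over-parametrized least-squares problem with noiseless labels (so an interpolating solution exists) and uses the conservative constant step size `eta = 1 / beta`, where `beta` is taken here to be the largest squared feature norm. The helper names `dot`, `loss` and `minibatch_sgd` are illustrative choices of ours.

```python
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def loss(w, X, y):
    # Average squared loss over the training set (factor 1/2 for a clean gradient).
    return sum((dot(w, x) - t) ** 2 for x, t in zip(X, y)) / (2 * len(X))

def minibatch_sgd(X, y, m, eta, steps, rng):
    """Constant-step mini-batch SGD; indices drawn uniformly with replacement."""
    n, d = len(X), len(X[0])
    w = [0.0] * d
    history = [loss(w, X, y)]
    for step in range(steps):
        grad = [0.0] * d
        for _ in range(m):
            i = rng.randrange(n)
            r = dot(w, X[i]) - y[i]          # residual on the sampled point
            for j in range(d):
                grad[j] += r * X[i][j] / m
        w = [wj - eta * gj for wj, gj in zip(w, grad)]
        history.append(loss(w, X, y))
    return w, history

rng = random.Random(0)
n, d = 10, 50                                 # more parameters than data points
X = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
w_true = [rng.gauss(0, 1) for _ in range(d)]
y = [dot(w_true, x) for x in X]               # noiseless labels: interpolation is possible

beta = max(dot(x, x) for x in X)              # assumed smoothness constant of the per-point losses
w, history = minibatch_sgd(X, y, m=2, eta=1.0 / beta, steps=2000, rng=rng)

for t in (0, 500, 1000, 2000):
    print(t, history[t])
```

On a logarithmic scale the printed losses should fall roughly along a straight line until they reach floating-point precision, with no step-size decay required.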
Speedup factor. Let be the number of iterations needed to reach a desired accuracy with batch size . Assuming , the speed up factor , which measures the number of iterations saved by using larger batch, is
Critical batch size . By estimating the speedup factor for each batch size , we directly obtain
Linear scaling regime: one iteration with batch size $m$ up to the critical batch size is nearly equivalent to $m$ iterations with batch size 1.
Saturation regime: one iteration with batch size above the critical batch size is nearly equivalent to one full gradient iteration.
We give a sharper analysis for the case of quadratic loss in Section 4.
For general convex optimization, a set of important stochastic methods [RSB12, JZ13, DBLJ14, XZ14, AZ16] have been proposed to achieve exponential (linear) convergence rate with constant step size. The effectiveness of these methods derives from their ability to reduce the stochastic variance caused by sampling. In a general convex setting, this variance prevents SGD from both adopting a constant step size and achieving an exponential convergence rate.
Remarkably, in the interpolated regime, Theorem 1 implies that SGD obtains the benefits of variance reduction “for free” without the need for any modification or extra information (e.g., full gradient computations for variance reduction). The table on the right compares the convergence of SGD in the interpolation setting with several popular variance reduction methods. Overall, SGD has the largest step size and achieves the fastest convergence rate without the need for any further assumptions. The only comparable or faster rate is given by Katyusha, which is an accelerated SGD method combining momentum and variance reduction for faster convergence.
How Fast is Fast SGD: Analysis of Step, Mini-batch Sizes and Computational Efficiency for Quadratic Loss
In this section, we analyze the convergence of mini-batch SGD for quadratic losses. We will consider the following key questions:
What is the optimal convergence rate of mini-batch SGD and the corresponding step size as a function of $m$ (the mini-batch size)?
What is the computational efficiency of different batch sizes and how do they compare to full GD?
The case of quadratic losses covers over-parametrized linear or kernel regression with a positive definite kernel. The quadratic case also captures general smooth convex functions in the neighborhood of a minimum where higher order terms can be ignored.
Consider the problem of minimizing the sum of squares
For any , let denote the projection of onto the subspace spanned by and denote the projection of onto the subspace spanned by . That is, is the decomposition of into two orthogonal components: its projection onto (i.e., the range space of , which is the subspace spanned by ) and its projection onto (i.e., the null space of , which is the subspace spanned by ). Hence, the above quadratic loss can be written as
To minimize the loss in this setting, consider the following SGD update with mini-batch of size and step size :
Let . Observe that we can write (11) as
Now, we make the following simple claim (whose proof is given in the appendix).
This also implies that for any we must have .
By the above claim, the update equation (12) can be decomposed into two components:
From (10), it follows that for any iteration , the target loss function is not affected at all by , that is, is the only component that matters. Hence, by (13-14), we only need to consider the effective SGD update (13), i.e., the update in the span of .
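The decomposition above can be checked numerically. The sketch below is an illustration under our own assumptions (a small random least-squares instance with more parameters than data points, step size set to the inverse of the largest squared feature norm, and a hypothetical helper `remove_span`). Every mini-batch gradient is a linear combination of the data vectors, so the component of the iterate along a direction `v` orthogonal to all data points should stay fixed while the loss still goes to zero.

```python
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def remove_span(vectors, v):
    # Modified Gram-Schmidt: return the part of v orthogonal to span(vectors).
    basis = []
    for x in vectors:
        for b in basis:
            c = dot(x, b)
            x = [xi - c * bi for xi, bi in zip(x, b)]
        norm = dot(x, x) ** 0.5
        basis.append([xi / norm for xi in x])
    for b in basis:
        c = dot(v, b)
        v = [vi - c * bi for vi, bi in zip(v, b)]
    return v

rng = random.Random(1)
n, d = 3, 6
X = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
w_true = [rng.gauss(0, 1) for _ in range(d)]
y = [dot(w_true, x) for x in X]

def loss(w):
    return sum((dot(w, x) - t) ** 2 for x, t in zip(X, y)) / (2 * n)

w = [rng.gauss(0, 1) for _ in range(d)]                     # arbitrary starting point
v = remove_span(X, [rng.gauss(0, 1) for _ in range(d)])     # direction in the null space of the Hessian
null_before, loss_before = dot(w, v), loss(w)

eta, m = 1.0 / max(dot(x, x) for x in X), 2
for step in range(5000):
    grad = [0.0] * d
    for _ in range(m):
        i = rng.randrange(n)
        r = dot(w, X[i]) - y[i]
        grad = [g + r * xj / m for g, xj in zip(grad, X[i])]
    w = [wj - eta * gj for wj, gj in zip(w, grad)]

null_after, loss_after = dot(w, v), loss(w)
print("null-space component:", null_before, "->", null_after)
print("loss:", loss_before, "->", loss_after)
```

The null-space component is untouched up to rounding error, which is why only the update restricted to the span of the data matters for the loss.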
4.1 Upper bound on the expected empirical loss
The following theorem provides an upper bound on the expected empirical loss after iterations of mini-batch SGD whose update step is given by (11).
Let In the interpolation setting, for any the mini-batch SGD with update step (11) yields the following guarantee
By reordering terms in the update equation 13 and using the independence of and , the variance in the effective component of the parameter update can be written as
Let . Then the variance is bounded as
Furthermore, the convergence rate relies on the eigenvalues of . Let be a non-zero eigenvalue of , then the corresponding eigenvalue of is given by
When the step size and mini-batch size are chosen (satisfying constraint 16), we have
4.2 Tightness of the bound on the expected empirical loss
We now show that our upper bound given above is indeed tight in the interpolation setting for the class of quadratic loss functions defined in (9). Namely, we give a specific instance of (9) where the upper bound in Theorem 2 is tight.
which shows that inequality (17) is also achieved with equality in that setting. This completes the proof. ∎
Remark. From the experimental results, it appears that our upper bound can be close to tight even in some settings when the eigenvalues are far apart. We plan to investigate this phenomenon further.
4.3 Optimal step size for a given batch size
To fully answer the first question we posed at the beginning of this section, we will derive an optimal rule for choosing the step size as a function of the batch size. Specifically, we want to find step size to achieve fastest convergence. Given Theorem 2, our task reduces to finding the minimizer
Let denote the resulting minimum, that is, . The resulting expression for the minimizer generally depends on the least non-zero eigenvalue of the Hessian matrix. In situations where we don’t have a good estimate for this eigenvalue (which can be close to zero in practice), one would rather have a step size that is independent of . In Theorem 5, we give a near-optimal approximation for step size with no dependence on under the assumption that , which is valid in many practical settings such as in kernel learning with positive definite kernels.
We first characterize exactly the optimal step size and the resulting .
For every batch size , the optimal step size function and convergence rate function are given by:
Note that if , then the first case in each expression will be valid for all .
The proof of the above theorem follows from the following two lemmas.
Let , and let . Then,
For any fixed and observe that is a quadratic function of . Hence, the maximum must occur at either or . Define and . Now, depending on the value of and , we would either have or . In particular, it is not hard to show that
where . This completes the proof. ∎
For all , .
For all , and , where and are as given by (19) and (20), respectively, (in Theorem 4).
First, consider . For any fixed , it is not hard to show that the minimizer of as a function of , constrained to , is given by . That is,
Substituting in , we get
Note that and are equal to and given in Theorem 4, respectively. This proves item 2 of the lemma.
Next, consider . Again, for any fixed , one can easily show that the minimum of as a function of , constrained to , is actually achieved at the boundary . Hence, . Substituting this in , we get
We conclude the proof by showing that for all , Note that for and are identical. For given the expressions above, one can verify that .
Given Lemma 1 and item 1 of Lemma 2, it follows that is the minimizer given by (18). Item 2 of Lemma 2 concludes the proof of the theorem.
Nearly optimal step size with no dependence on : In practice, it is usually easy to obtain a good estimate for , but it is hard to reliably estimate which is typically much smaller than (e.g., [CCSL16]). That is why one would want to avoid dependence on in practical SGD algorithms. Under a mild assumption which is typically valid in practice, we can easily find an accurate approximation of optimal that depends only on and . Namely, we assume that . In particular, this is always true in kernel learning with positive definite kernels, when the data points are distinct.
The following theorem provides such approximation resulting in a nearly optimal convergence rate .
Suppose that . Let be defined as:
Then, the step size yields the following upper bound on , denoted as :
The proof easily follows by observing that if , then lies in the feasible region for the minimization problem in (18). In particular, , where is as defined in Lemma 1. The upper bound follows from substituting in defined in Lemma 1, then upper-bounding the resulting expression. ∎
It is easy to see that the convergence rate resulting from the step size is at most factor slower than the optimal rate . This factor is negligible when . Since we expect , we can further approximate when and when .
4.4 Batch size selection
In this section, we will derive the optimal batch size given a fixed computational budget in terms of the computational efficiency defined as the number of gradient computations to obtain a fixed desired accuracy. We will show that single-point batch is in fact optimal in that setting. Moreover, we will show that any mini-batch size in the range from to a certain constant independent of , is nearly optimal in terms of gradient computations. Interestingly, for values beyond the computational efficiency drops sharply. This result has direct implications for the batch size selection in parallel computation.
Suppose we are limited by a fixed number of gradient computations. Then, what would be the batch size that yields the least approximation error? Equivalently, suppose we are required to achieve a certain target accuracy (i.e., want to reach parameter such that ). Then, again, what would be the optimal batch size that requires the least amount of computation?
Suppose we are being charged a unit cost for each gradient computation, then it is not hard to see that the cost function we seek to minimize is , where is as given by Theorem 4. To see this, note that for a batch size , the number of iterations to reach a fixed desired accuracy is . Hence, the computation cost is . Hence, minimizing the computation cost is tantamount to minimizing The following theorem shows that the exact minimizer is . Later, we will see that any value for from to is actually not far from optimal. So, if we have cheap or free computation available (e.g., parallel computation), then it would make sense to choose . We will provide more details in the following subsection.
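The reduction from computational cost to a function of the convergence rate can be spelled out as follows. This is a sketch in our own notation, which the text does not fix: $g(m) \in (0,1)$ denotes a per-iteration contraction factor of the expected loss for batch size $m$, $L_0$ the initial loss, $\epsilon$ the target accuracy, and $t(m)$ the number of iterations needed.

```latex
\[
g(m)^{t}\, L_0 \le \epsilon
\quad\Longleftrightarrow\quad
t \ge \frac{\log(L_0/\epsilon)}{\log\bigl(1/g(m)\bigr)},
\qquad\text{so}\qquad
\mathrm{cost}(m) \;=\; m \cdot t(m)
\;=\; \log(L_0/\epsilon)\cdot \frac{m}{\log\bigl(1/g(m)\bigr)} .
\]
```

Since $\log(L_0/\epsilon)$ does not depend on $m$, under these assumptions minimizing the number of gradient computations amounts to minimizing $m / \log(1/g(m))$ over the batch size.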
When we are charged a unit cost per gradient computation, the batch size that minimizes the overall computational cost required to achieve a fixed accuracy (i.e., maximizes the computational efficiency) is . Namely,
The detailed and precise proof is deferred to the appendix. Here, we give a less formal but more intuitive argument based on a reasonable approximation for . Such an approximation is in fact valid in most practical settings. In the full version of this paper, we give an exact and detailed analysis. Note that can be written as , where is given by
Note that defined in (23) indeed captures the speed-up factor we gain in convergence relative to standard SGD (with ) where the convergence is dictated by . Now, note that . This approximation becomes very accurate when , which is typically the case for most of the practical settings where and is very large. Assuming that this is the case (for the sake of this intuitive argument), minimizing becomes equivalent to maximizing . Now, note that when then , which is decreasing in . Hence, for we have . On the other hand, when we have
which is also decreasing in , and hence, it is upper bounded by its value at . By direct substitution and simple cancellations, we can show that . Thus, $m = 1$ is optimal.
One may wonder whether the above result is valid if the near-optimal step size (that does not depend on ) is used. That is, one may ask whether the same optimality result is valid if the near optimal error rate function is used instead of in Theorem 6. Indeed, we show that the same optimality remains true even if computational efficiency is measured with respect to . This is formally stated in the following theorem.
When the near-optimal step size is used (and assuming that , the batch size that minimizes the overall computational cost required to achieve a fixed accuracy is . Namely,
The proof of the above theorem follows along the same lines as the proof of Theorem 6.
4.4.2 Near optimal larger batch sizes
Suppose that several gradient computations can be performed in parallel. Sometimes doubling the number of machines used in parallel can halve the number of iterations needed to reach a fixed desired accuracy. Such observations have motivated many works to use large batch sizes with distributed synchronized SGD [CMBJ16, GDG+17, YGG17, SKL17]. One critical problem in this large batch setting is how to choose the step size. To keep the same covariance, [BCN16, Li17, HHS17] scale the step size as $\sqrt{m}$ for batch size $m$, while [Kri14, GDG+17, YGG17, SKL17] have observed that rescaling the step size linearly in $m$ works well in practice for not too large $m$. To explain these observations, we directly connect the parallelism, or the batch size $m$, to the required number of iterations defined previously. It turns out that (a) when the batch size is small, doubling the size will almost halve the required iterations; (b) after the batch size surpasses a certain value, increasing the size to any amount would only reduce the required iterations by at most a constant factor.
Linear scaling regime (): This is the regime where increasing the batch size will quickly drive down needed to reach certain accuracy. When , , which suggests . In other words, doubling the batch size in this regime will roughly halve the number of iterations needed. Note that we choose step size . When , , which is consistent with the linear scaling heuristic used in [Kri14, GDG+17, SKL17]. In this case, the largest batch size in the linear scaling regime can be practically calculated through
Saturation regime (): Increasing batch size in this regime becomes much less beneficial. Although is monotonically increasing, it is upper bounded by . In fact, since for small , no batch size in this regime can reduce the needed iterations by a factor of more than 4.
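The two regimes can be illustrated with a small simulation. The sketch below rests on several assumptions of ours: a synthetic realizable least-squares problem, `beta` taken as the largest squared feature norm, `lam1` as the top Hessian eigenvalue (estimated by power iteration), a critical size of the form `beta / lam1 + 1`, and a piecewise step-size rule `step_size` that is our reconstruction of a near-optimal rule rather than a formula quoted from the text. It counts the iterations needed to reduce the loss by a fixed factor for several batch sizes.

```python
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def matvec(A, v):
    return [dot(row, v) for row in A]

def top_eigenvalue(A, iters=200):
    # Power iteration for a symmetric positive semi-definite matrix.
    v = [1.0] * len(A)
    for _ in range(iters):
        u = matvec(A, v)
        norm = dot(u, u) ** 0.5
        v = [a / norm for a in u]
    return dot(v, matvec(A, v))

def step_size(m, beta, lam1):
    # ASSUMED rule (our reconstruction): roughly m / beta for small m, roughly 1 / (2 * lam1) for large m.
    if m <= beta / lam1 + 1:
        return m / (beta + (m - 1) * lam1)
    return m / (2 * lam1 * (m - 1))

def iterations_to_accuracy(X, H, w_star, m, eta, rng, tol=1e-6, max_steps=100000):
    n, d = len(X), len(w_star)
    e = [-a for a in w_star]                   # error w - w_star when starting from w = 0
    target = tol * 0.5 * dot(e, matvec(H, e))  # loss is 0.5 * e^T H e for realizable labels
    for t in range(1, max_steps + 1):
        grad = [0.0] * d
        for _ in range(m):
            x = X[rng.randrange(n)]            # uniform sampling with replacement
            r = dot(x, e)                      # residual, since y_i = <w_star, x_i>
            for j in range(d):
                grad[j] += r * x[j] / m
        e = [a - eta * g for a, g in zip(e, grad)]
        if 0.5 * dot(e, matvec(H, e)) <= target:
            return t
    return max_steps

data_rng = random.Random(0)
n, d = 200, 10
X = [[data_rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
w_star = [data_rng.gauss(0, 1) for _ in range(d)]
H = [[sum(x[a] * x[b] for x in X) / n for b in range(d)] for a in range(d)]
beta = max(dot(x, x) for x in X)
lam1 = top_eigenvalue(H)
print("assumed critical batch size:", beta / lam1 + 1)

iters = {}
for m in (1, 4, 16, 128):
    eta = step_size(m, beta, lam1)
    runs = [iterations_to_accuracy(X, H, w_star, m, eta, random.Random(s)) for s in range(5)]
    iters[m] = sum(runs) / len(runs)
    print("m =", m, " iterations =", iters[m], " gradient computations =", m * iters[m])
```

For batch sizes well below the assumed critical size, the iteration count should drop almost in proportion to the batch size while the number of gradient computations stays roughly flat; for the largest batch size the iteration count barely improves and the total computation grows sharply.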
Experimental Results
This section will provide empirical evidence for our theoretical results on the effectiveness of mini-batch SGD in the interpolated setting. We first consider a kernel learning problem, where the parameters , , and can be computed efficiently (see [MB17] for details). In all experiments we set the step size to be defined in (21).
We observe empirically that increasing the step size from to consistently leads to divergence, indicating that differs from the optimal step size by at most a factor of 2. This is consistent with our Theorem 5 on near-optimal step size.
Theorem 4 suggests that SGD using batch size defined in (24) can reach the same error as GD using at most times the number of iterations. This is consistent with our experimental results for MNIST, HINT-S [HYWW13], and TIMIT, shown in Figure 3. Moreover, in line with our analysis, SGD with batch size larger than but still much smaller than the data size, converges nearly identically to full gradient descent.
Since our analysis is concerned with the training error, only the training error is reported here. For completeness, we report the test error in Appendix C. As consistently observed in such over-parametrized settings, test error decreases with the training error.
5.2 Optimality of batch size m=1
Our theoretical results, Theorem 6 and Theorem 7, show that $m = 1$ achieves the optimal computational efficiency. Note that for a given batch size, the corresponding optimal step size is chosen according to equation (21). The experiments in Figure 4 show that $m = 1$ indeed achieves the lowest error for any fixed number of epochs.
5.3 Linear scaling and saturation regimes
In the interpolation regime, Theorem 6 shows linear scaling for mini-batch sizes up to a (typically small) “critical” batch size defined in (24), followed by the saturation regime. In Figure 4 we plot the training error for different batch sizes as a function of the number of epochs. Note that the number of epochs is proportional to the amount of computation measured in terms of gradient evaluations. The linear scaling regime is reflected in the small difference in the training error among the smallest batch sizes in Figure 4 (the bottom three curves). As expected from our theoretical results, they have similar computational efficiency. On the other hand, we see that large mini-batch sizes require drastically more computation, which is the saturation phenomenon reflected in the top two curves.
Relation to the “linear scaling rule” in neural networks. A number of recent large-scale neural network methods, including [Kri14, CMBJ16, GDG+17], use the “linear scaling rule” to accelerate training using parallel computation. After the initial “warmup” stage to find a good region of parameters, this rule suggests increasing the step size to a level proportional to the mini-batch size $m$. In spite of the wide adoption and effectiveness of this technique, there has been no satisfactory explanation [Kri14], as the usual variance-based analysis suggests increasing the step size by a factor of $\sqrt{m}$ instead of $m$ [BCN16]. We note that this “linear scaling” can be explained by our analysis, assuming that the warmup stage ends up in a neighborhood of an interpolating minimum.
5.4 Interpolation in kernel methods
To give additional evidence of the interpolation in the over-parametrized settings, we provide empirical results showing that this is indeed the case in kernel learning. We give two examples: Laplace kernel trained using EigenPro [MB17] on MNIST [LBBH98] and on a subset (of size ) of TIMIT [GLF+93]. The histograms in Figure 5 show the number of points with a given loss calculated as (on feature vector and corresponding binary label vector ). As evident from the histograms, the test loss keeps decreasing as we converge to an interpolated solution.
It is interesting to observe that even when SGD is slow to converge to the interpolated solution, our theoretical bounds still accurately describe the relative efficiency of different mini-batch sizes. We examine two different settings: interpolation with Laplace kernel trained using EigenPro [MB17] on HINT-S (Figure 6) and with Gaussian kernel in Figure 6. As is clear from the figures, the Laplace kernel converges to the interpolated solution much faster than the Gaussian. However, as our experiments depicted in Figure 7 show, the relative computational efficiency of different batch sizes for these two settings is very similar. As before, we plot the training error against the number of epochs (which is proportional to computation) for different batch sizes. Note that while the scale of the error is very different for these two settings, the profiles of the curves are remarkably similar.
References
Appendix A Proof of Claim 1
Let denote the eigen-basis of corresponding to eigenvalues . For every let be the expansion of w.r.t. the eigen-basis of .
where the last equality follows from expanding each w.r.t. the eigen-basis of . Thus,
The proof immediately follows from (26) since for any , we can write where and denote the projections of onto and , respectively. Hence, (26) implies that and , which proves the claim.
Appendix B Proof of Theorem 6
Here, we will provide an exact analysis for the optimality of batch size for the cost function , which, as discussed in Section 4.4.1, captures the total computational cost required to achieve any fixed target accuracy (in a model with no parallel computation).
We prove this theorem by showing that is strictly increasing for . We do this via the following two simple lemmas. First, we introduce the following notation.
is strictly increasing for .
Define . We will show that is strictly decreasing for , which is tantamount to showing that is strictly increasing over . For more compact notation, let’s define , and . First note that, after straightforward simplification, . Hence, where . Now, it is not hard to see that is strictly decreasing since the function is strictly decreasing in as long as . ∎
, for all
Proving the lemma is equivalent to proving . After direct manipulation, this is equivalent to showing that
which is true for all since the left-hand side is a complete square: . ∎
Given these two simple lemmas, observe that
where the first and last equalities follow from the fact that for , and the second inequality follows from Lemma 3. Also, observe that
where the third inequality follows from Lemma 4, and the last equality follows from the fact that for Putting (27) and (28) together, we have for all , which completes the proof.