WNGrad: Learn the Learning Rate in Gradient Descent
Xiaoxia Wu, Rachel Ward, Léon Bottou
Introduction
In the stochastic setting, the question of how to choose the learning rate is less settled. There are different guidelines for setting the learning rate “schedule”, each with its own justification in the form of a convergence result under a set of structural assumptions on the loss function. The classical Robbins/Monro theory [Robbins and Monro, 1951] says that if the learning rate is chosen such that
Thus, in the stochastic setting, there is no clear “best choice” for the learning rate. In many deep learning problems, where the underlying loss function is highly non-convex, one often tests several different learning rate schedules of the form
where is the maximal diameter of the feasible set, and is the norm of the current gradient or an average of recent gradients; the schedule which works best on the problem at hand is then chosen. Another popular and effective choice is to start with a constant learning rate which gives good empirical convergence results (or with a small one followed by a warmup scheme [Goyal et al., 2017]), maintain this constant learning rate for a fixed number of epochs over the training data, then decrease the learning rate, and repeat this process until convergence.
Another line of work on adaptive learning rates [Needell et al., 2014, Zhao and Zhang, 2015] considers importance sampling in stochastic gradient descent in the setting where the loss function can be expressed as a sum of component functions, and provides precise ways for setting different constant learning rates for different component functions based on their Lipschitz constants; if the sampling distribution over the parameters is weighted so that parameters with smaller Lipschitz constants are sampled less frequently, then this reparametrization affords a faster convergence rate, depending on the average Lipschitz constant over all parameters rather than the largest one. Of course, in practice, the Lipschitz constants are not known in advance and must be learned along the way.
This raises the question: if we take a step back to the batch/non-stochastic gradient descent setting, is it possible to learn even a single Lipschitz constant, corresponding to the gradient function, so that we can match the convergence rate of gradient descent with an optimized constant learning rate, which requires knowledge of the Lipschitz constant beforehand? To our knowledge, this question has not been addressed until now.
Weight Normalization
To answer this question, we turn to simple reparametrizations of weight vectors in neural networks which have been proposed in recent years and have already gained widespread adoption in practice due to their effectiveness in accelerating training without compromising generalization performance, while simultaneously being robust to the tuning of learning rates. The celebrated batch normalization [Ioffe and Szegedy, 2015] accomplishes these objectives by normalizing the means and variances of minibatches in a particular way which reduces the dependence of gradients on the scale of the parameters or their initial values, allowing the use of much higher learning rates without the risk of divergence. Inspired by batch normalization, the weight normalization algorithm [Salimans and Kingma, 2016] was introduced as an even simpler reparametrization, also effective in making the resulting stochastic gradient descent more robust to specified learning rates and initialization. The weight normalization algorithm, roughly speaking, reparametrizes the loss function in polar coordinates, and runs (stochastic) gradient descent with respect to polar coordinates: if the loss function is $f(w)$, where $w$ is a $d$-dimensional vector, then weight normalization considers instead $f\left(g \frac{v}{\|v\|}\right)$, where $v$ is a $d$-dimensional vector, $g$ is a scalar, and $\|v\|$ is the Euclidean norm of $v$. The analog of the weight normalization algorithm in the batch gradient setting would simply be gradient descent in polar coordinates as follows:
where denotes the orthogonal projection of onto the subspace of co-dimension orthogonal to . One important feature of note is that, since the gradient of with respect to is orthogonal to the current direction , the norm grows monotonically with the update, thus effectively producing a dynamically-updated decay in the effective learning rate . More precisely, considering weight normalization in the batch setting, restricted to the unit sphere (fixing ), the gradient update reduces to
Our contributions
As a nod to its inspiration, weight normalization, we call this algorithm WNGrad, but note that the update can also be interpreted as a close variant of AdaGrad with the dynamic update applied to a single learning rate; indeed, WNGrad -update satisfies
which matches the coordinate-wise update rule in AdaGrad if is one-dimensional. Nevertheless, the WNGrad update (1.3) offers some insight and advantages over the family of modifications/improvements of the AdaGrad update. First, it gives a precise correspondence between the accumulated gradient and the current gradient in the learning-rate update. Additionally, it does not require any square root computations, thus making the update more efficient.
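As a worked check of the AdaGrad connection, assume the WNGrad learning-rate parameter is updated as $b_{j+1} = b_j + \|\nabla f(x_{j+1})\|^2 / b_j$. This is the form of the update as we understand it from the published algorithm; the symbols $b_j$ and $x_j$ are assumptions here, since the displayed equations are not reproduced in this text. Squaring both sides gives:

```latex
\begin{align*}
b_{j+1}^2
  &= \left( b_j + \frac{\|\nabla f(x_{j+1})\|^2}{b_j} \right)^2 \\
  &= b_j^2 + 2\,\|\nabla f(x_{j+1})\|^2 + \frac{\|\nabla f(x_{j+1})\|^4}{b_j^2}.
\end{align*}
```

Under that assumption, $b_{j+1}^2$ accumulates squared gradient norms much as AdaGrad's accumulator does, up to the factor 2 and a higher-order term that shrinks as $b_j$ grows, while the update itself needs only a division and no square root.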
In this paper, we provide some basic theoretical guarantees about WNGrad update. Surprisingly, we are able to provide guarantees for the same learning rate update rule in both the batch and stochastic settings.
In the batch gradient descent setting, we show that WNGrad will converge to a weight vector satisfying in at most iterations, if has -Lipschitz smooth gradient. The proof involves showing that if grows up to the critical level it automatically stabilizes, satisfying for all time (here C is a constant factor). This should be compared to the standard gradient descent convergence rate using a constant learning rate, which in the ideal case achieves convergence rate but which is not guaranteed to converge at all if the learning rate is even slightly too big. Thus, WNGrad is a variant of gradient descent that is provably robust to the scale of the Lipschitz constant when parameters like the Lipschitz smoothness are not known in advance.
On the other hand, in the stochastic setting, the update in WNGrad has dramatically different behavior, growing like where is a bound on the variance of the stochastic gradients. As a result, in the stochastic setting, we also show that WNGrad achieves the optimal rate of convergence for convex loss functions, and moreover settles in expectation on the “correct” constant, . Thus, WNGrad also works robustly in the stochastic setting, and finds a good learning rate.
We supplement all of our theorems with numerical experiments, which show that WNGrad compares favorably with plain stochastic gradient descent in terms of robustness to the Lipschitz constant of the loss function, speed of convergence, and generalization error, in training neural networks on two standard data sets.
WNGrad for Batch Gradient Descent
The following convergence result is classical ([Nesterov, 1998]).
Lemma 2.1. Suppose that and that . Consider gradient descent with constant learning rate .
If and , then
On the other hand, gradient descent can oscillate or diverge once .
Note that this result requires knowledge of the Lipschitz constant or an upper bound on it. Even if such a bound is known, the algorithm is quite conservative: the Lipschitz constant represents the worst-case oscillation of the function over all points in the domain, while the local behavior of the gradient might be much more regular, indicating that a larger learning rate (and hence a faster convergence rate) might be permissible. In any case, it is beneficial to consider a modified gradient descent algorithm which, starting from a large initial learning rate, decreases the learning rate according to the gradient information received so far, and stabilizes at a rate that depends on the local smoothness behavior and is therefore no smaller than .
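The sensitivity of constant-step gradient descent to the Lipschitz constant can be seen on a one-dimensional quadratic. The sketch below is an illustration only: the function `gradient_descent`, the test loss $f(x) = \frac{L}{2}x^2$, and the value `L = 4.0` are assumptions chosen for the example, not part of the paper's experiments.

```python
def gradient_descent(grad, x0, eta, steps):
    """Plain gradient descent with a constant learning rate eta."""
    x = x0
    for _ in range(steps):
        x = x - eta * grad(x)
    return x


L = 4.0                      # Lipschitz constant of the gradient (illustrative)
grad = lambda x: L * x       # gradient of f(x) = (L / 2) * x**2

# eta = 1 / L: each step multiplies x by (1 - eta * L) = 0, so x lands on the minimizer.
x_safe = gradient_descent(grad, x0=1.0, eta=1.0 / L, steps=20)

# eta = 2.5 / L: each step multiplies x by (1 - 2.5) = -1.5, so |x| grows geometrically.
x_large = gradient_descent(grad, x0=1.0, eta=2.5 / L, steps=20)
```

On this quadratic the iterates contract only when the per-step factor $|1 - \eta L|$ is below 1, which is why a learning rate chosen without knowledge of $L$ can fail outright.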
We consider the following modified gradient descent scheme:
Initializing and scale invariance. Ideally, one could initialize in WNGrad by sampling points close to the initialization , and , and take
If this is not possible, it is also reasonable to consider an initialization with a constant . With either choice, one observes that the resulting WNGrad algorithm is invariant to the scale of : if is replaced by , then the sequence of iterates remains unchanged.
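The scale invariance can be checked numerically. The following is a minimal sketch, assuming the batch WNGrad update has the form `x <- x - grad(x) / b` followed by `b <- b + ||grad(x_new)||^2 / b`; the names `wngrad`, `b0`, and `lam`, and the quadratic test loss, are hypothetical choices made for this illustration.

```python
def wngrad(grad, x0, b0, steps):
    """Batch WNGrad sketch on a list of floats.

    Assumed update (the displayed algorithm is not reproduced in the text):
        x <- x - grad(x) / b
        b <- b + ||grad(x_new)||^2 / b
    """
    x, b = list(x0), b0
    iterates = [list(x)]
    for _ in range(steps):
        g = grad(x)
        x = [xi - gi / b for xi, gi in zip(x, g)]
        g_new = grad(x)
        b = b + sum(gi * gi for gi in g_new) / b
        iterates.append(list(x))
    return iterates, b


# Illustrative quadratic f(x) = 0.5 * (x_0^2 + 10 * x_1^2).
grad_f = lambda x: [x[0], 10.0 * x[1]]

lam = 8.0  # hypothetical rescaling factor applied to the loss
grad_scaled = lambda x: [lam * gi for gi in grad_f(x)]

# Run once on f with b0, and once on lam * f with lam * b0.
path, b_end = wngrad(grad_f, [1.0, 1.0], b0=2.0, steps=50)
path_scaled, b_end_scaled = wngrad(grad_scaled, [1.0, 1.0], b0=lam * 2.0, steps=50)
```

Under the assumed update, rescaling the loss and the initial `b` by the same factor leaves every iterate unchanged and scales `b` by that factor throughout, which is what the comparison of `path` with `path_scaled` shows.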
We show that the WNGrad algorithm has the following properties:
After a reasonable number of initial iterations, either or
If at some point , then the learning rate stabilizes: for all .
As a consequence, we have the following convergence result.
Theorem 2.3. Consider the WNGrad algorithm. Set . Suppose that , is the point satisfying and that .
Case 1 steps if and
Case 2 steps if .
Comparing the convergence rate of batch gradient descent in Theorem 2.3 and the classical convergence result in Lemma 2.1, we see that WNGrad adjusts the learning rate automatically with decreasing learning rate based on the gradient information received so far, and without knowledge of the constant L, and still achieves linear convergence at nearly the same rate as gradient descent in Lemma 2.1 with constant learning rate .
We will use the following lemmas to prove Theorem 2.3. For more details, see Appendix A.1.
Lemma 2.4. Fix and . Consider the sequence
after iterations, either , or .
Lemma 2.5. Suppose that , and . Denote by the first index such that . Then for all ,
Lemma 2.5 guarantees that the learning rate stabilizes once it reaches the (unknown) Lipschitz constant, up to an additive term. To be complete, we can also bound as a function of , then arrive at the main result of this section.
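The stabilization behavior can be observed on a simple example. This is a sketch under stated assumptions: the update form (`x <- x - grad(x) / b`, then `b <- b + grad(x_new)^2 / b`), the helper `wngrad_trace`, and the quadratic loss $f(x) = \frac{L}{2}x^2$ with `L = 4.0` are all illustrative choices, not the paper's setup.

```python
def wngrad_trace(L, x0, b0, steps):
    """Run the assumed WNGrad update on f(x) = (L / 2) * x**2, recording (f, b)."""
    f = lambda x: 0.5 * L * x * x
    grad = lambda x: L * x
    x, b = x0, b0
    trace = [(f(x), b)]
    for _ in range(steps):
        x = x - grad(x) / b            # step with learning rate 1 / b
        b = b + grad(x) ** 2 / b       # b never decreases
        trace.append((f(x), b))
    return trace


L = 4.0
trace = wngrad_trace(L, x0=1.0, b0=L / 2, steps=60)

# First index at which b has reached the Lipschitz constant.
j0 = next(j for j, (_, b) in enumerate(trace) if b >= L)
tail = trace[j0:]
largest_b = max(b for _, b in tail)
```

In this run `b` starts below `L`, jumps past it after one step, and then increases only slightly while the loss decreases at every subsequent step, consistent with the qualitative statement of the lemma.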
Lemma 2.6. Suppose that and that . Denote by the first index such that . Then
WNGrad for Stochastic Gradient Descent
We now shift from the setting of batch gradient descent to stochastic gradient descent. The update rule for the learning rate in WNGrad extends without modification to this setting, but since the gradient norms now do not converge to zero but rather remain noisy, the WNGrad learning rate does not converge to a fixed size, but rather settles eventually on the rate of , where is a bound on the variance of the stochastic gradients. In order to tackle this issue and derive a convergence rate, we assume for the analysis that the loss function is convex but not necessarily smooth.
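To see why the learning-rate parameter keeps growing when gradients do not vanish, consider the recursion `b <- b + g_sq / b` with a constant `g_sq`. This is a simplified stand-in for the stochastic setting, and the update form, the helper `b_sequence`, and the constant `G` are assumptions made for this sketch. Squaring the recursion gives $b_{j+1}^2 = b_j^2 + 2G^2 + G^4/b_j^2$, so $b_j^2$ grows linearly in $j$ and $b_j$ grows like a constant multiple of $\sqrt{j}$.

```python
import math


def b_sequence(b1, g_sq, steps):
    """Iterate b <- b + g_sq / b with a constant squared gradient norm g_sq,
    a simplified stand-in for gradients that stay noisy."""
    bs = [b1]
    for _ in range(steps):
        bs.append(bs[-1] + g_sq / bs[-1])
    return bs


G = 3.0                       # hypothetical bound on the gradient norm
bs = b_sequence(b1=G, g_sq=G * G, steps=10_000)

# Compare the final value with G * sqrt(number of steps).
ratio = bs[-1] / (G * math.sqrt(len(bs) - 1))
```

For this choice of `b1 = G`, the squared recursion implies $G^2 + 2G^2 j \le b_j^2 \le G^2 + 3G^2 j$ after $j$ steps, so `ratio` lies between roughly $\sqrt{2}$ and $\sqrt{3}$.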
Consider the general optimization problem
Theorem 3.1. Consider the WNGrad algorithm. Suppose is convex. Suppose that, independent of ,
Under the same assumptions, excluding the assumption that , one obtains the same convergence rate using decreasing learning rate for some constant .
We will use the following lemma, which is easily proved by induction.
Lemma 3.3. Consider a positive constant and a sequence of positive numbers and for each ,
Proof of Theorem 3.1: First, note that under the stated assumptions, satisfies Lemma 3.3 for . Thus, with probability 1,
Numerical Experiments
With convergence of WNGrad guaranteed in both the batch and stochastic settings under appropriate conditions (a non-convex smooth loss function in the batch setting, and a convex but not necessarily smooth loss function in the stochastic setting), we perform experiments in this section to show that WNGrad exhibits the same robustness for highly non-convex loss functions associated with deep learning problems.
Consider a loss function whose gradient has Lipschitz constant . Then, the gradient of the rescaled loss function has Lipschitz constant . If we were to also rescale to , then the dynamics would remain unchanged due to scale invariance. If instead we fix while letting vary, we can test the robustness of WNGrad to different Lipschitz constants, and compare its robustness to stochastic gradient descent (SGD, Algorithm 5 in Appendix). To be precise, we consider the following variant of WNGrad, Algorithm 3, and explore its performance as we vary . Note that in this algorithm is analogous to the constant learning rate in weight normalization and batch normalization as discussed in (1.2).
WNGrad is mainly tested on two data sets: MNIST [LeCun et al., 1998a] and CIFAR-10 [Krizhevsky, 2009]. Table 2 gives a summary. We use batch size 100 for both MNIST and CIFAR-10. The experiments are done in PyTorch, and parameters are left at their default values unless otherwise specified. The data sets are preprocessed by normalizing with the mean and standard deviation of the entire training set. Details on implementing WNGrad in a neural network are explained in Appendix A.3.
We first test a wide range of scales of the loss function with two fully connected layers (without bias term) on MNIST (input dimension is ) in a very simple setting that excludes other factors that come into effect, such as regularization (weight decay), dropout, momentum, batch normalization, etc. In addition, we repeat each experiment 5 times to reduce the effect of initialization, since random initialization of the weight vectors is used in our experiments.
The outcome of the experiments shown in Figure 2 verifies that WNGrad is very robust to the Lipschitz constant, while SGD is much more sensitive. This shows that the learning rate can be initialized at a high value if we consider to be the learning rate. When picking and , we have the train/test loss with respect to epoch shown in blue and dark-red curves respectively. With a larger scale of the Lipschitz constant (), WNGrad does much better than SGD in both training and test loss. It is interesting to note that even with a smaller scale of the Lipschitz constant, although SGD obtains the smaller training loss, it does worse in generalization. In contrast, WNGrad gives better generalization (smaller test loss) despite the larger training loss. Thus, WNGrad is to some extent not only robust to the scale of the Lipschitz constant but also generalizes well; we aim to study this property of WNGrad in future work.
We next compare the methods on a larger dataset, CIFAR-10, with a wide range of the scale, from to . We apply a simple convolutional neural network (see Table 2 for details) with weight decay ; the results are shown in Figure 2. In comparison with SGD, WNGrad is very robust to the scale: it performs better at and does as well as SGD when smaller than . At the best for each algorithm ( WNGrad and SGD), WNGrad outperforms SGD in training and testing along the way.
A common practice in training deep and recurrent neural networks is to add momentum to stochastic gradient descent [Sutskever et al., 2013]. The more recent adaptive moment estimation (Adam) [Kingma and Ba, 2014] seems to improve the performance of models on a number of datasets. However, these methods are considerably sensitive to the scale of the Lipschitz constant and require careful tuning in order to obtain the best result. Here we combine our algorithm with momentum (WNGrad-Momentum) and adapt it in the style of Adam (WN-Adam, Algorithm 4) in the hope of improving the robustness to the relationship between the learning rate and the Lipschitz constant. We train ResNet-18 on CIFAR-10 in Figure 3. Because of the batch normalization built into ResNet-18, we widen the range of up to . As we can see, WN-Adam (green curve) and WNGrad-Momentum (black curve) do seem to be more robust compared to Adam (red) and SGD-Momentum (orange). In particular, WN-Adam is very robust even at and still does fairly well in generalization.
Conclusion
We propose WNGrad, a method for dynamically updating the learning rate according to gradients received so far, which works in both batch and stochastic gradient methods and converges.
In the batch gradient descent setting, we show that WNGrad converges to a weight vector satisfying in at most iterations, if has -Lipschitz smooth gradient. This nearly matches the convergence rate for standard gradient descent with fixed learning rate , but WNGrad does not need to know in advance.
In the stochastic setting, the update in WNGrad has different behavior, growing like where is a bound on the variance of the stochastic gradients. As a result, in the stochastic setting, we also show that WNGrad achieves the optimal rate of convergence for convex loss functions, and moreover settles in expectation on the “correct” rate, . Thus, WNGrad works robustly in the stochastic setting, and finds a good learning rate.
In numerical experiments, WNGrad compares favorably with plain stochastic gradient descent in terms of robustness to the relationship between the learning rate and the Lipschitz constant, and in terms of generalization error, in training neural networks on two standard data sets. This robustness extends further to the variants that incorporate momentum (WN-Adam and WNGrad-Momentum).
Acknowledgments
We thank Arthur Szlam and Mark Tygert for constructive suggestions. Also, we appreciate the help (with the experiments) from Sam Gross, Shubho Sengupta, Teng Li, Ailing Zhang, Zeming Lin, and Timothee Lacroix.
References
Appendix A
Proof: If , we are done. So suppose . Thus,
So and hence
A.1.2 Proof of Lemma 2.5
Suppose is the first index such that . Then for all , and by Lemma A.1, for
since . Finally, since for , we can bound . By Lemma A.1,
since and .
A.1.3 Proof of Lemma 2.6
We use shorthand . Let be the first index such that . Then,
A.2 Proof of Theorem 2.3
By Lemma 2.4, if is not satisfied after steps, then there is a first index such that . By Lemma 2.5, for all ,
and thus the stated result holds straightforwardly.
Otherwise, if then, by Lemma A.1, for any ,
By Lemma 2.6, since , we have
where .
A.3 Implementing the Algorithm in a Neural Network
In this section, we give the details for implementing our algorithm in a neural network. In the standard neural network architecture, the computation of each neuron consists of an elementwise nonlinearity of a linear transform of input features or output of previous layer:
where is the -dimensional weight vector, is a scalar bias term, , are respectively a -dimensional vector of input features (or output of previous layer) and the output of current neuron, denotes an elementwise nonlinearity. When using backpropagation [LeCun et al., 1998b], the stochastic gradient of in Algorithms 2, 3 and 4 represents the gradient of the current neuron (see Figure 4). Thus, when implementing our algorithm in PyTorch, WNGrad assigns one learning rate to each neuron, while SGD uses a single learning rate for all neurons.
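The per-neuron bookkeeping described above can be sketched in plain Python, independently of PyTorch. This is an illustration under assumptions, not the authors' implementation: the function `wngrad_layer_step`, the row-per-neuron layout, and the order of operations (move the weights with the current `b`, then update `b` with the same gradient) are hypothetical choices.

```python
def wngrad_layer_step(weights, grads, b):
    """One WNGrad-style step for a layer, with one learning rate 1 / b[k] per neuron.

    weights, grads: lists of rows; row k holds neuron k's weight vector / gradient.
    b: list with one positive scalar per neuron.
    Assumed order: move the weights with the current b, then update b.
    """
    new_weights, new_b = [], []
    for w_k, g_k, b_k in zip(weights, grads, b):
        new_weights.append([w - g / b_k for w, g in zip(w_k, g_k)])
        new_b.append(b_k + sum(g * g for g in g_k) / b_k)
    return new_weights, new_b


W = [[1.0, 2.0], [3.0, 4.0]]     # two neurons, two inputs each
G = [[1.0, 0.0], [0.0, 2.0]]     # hypothetical stochastic gradients
W_next, b_next = wngrad_layer_step(W, G, b=[1.0, 2.0])
```

Each neuron's `b` grows according to the squared norm of that neuron's own gradient, so neurons that see large gradients shrink their learning rate faster than those that see small ones.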