Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients

Lukas Balles, Philipp Hennig

1 Introduction

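Many prominent machine learning models pose empirical risk minimization problems with objectives of the form

$$\mathcal{L}(\theta)=\frac{1}{M}\sum_{k=1}^{M}\ell(\theta;x_{k}),$$

where $\ell(\theta;x_{k})$ is the loss incurred by the model with parameters $\theta$ on the $k$-th data point $x_{k}$. Such problems are typically tackled with stochastic optimization methods based on mini-batch gradients

$$g(\theta)=\frac{1}{|\mathcal{B}|}\sum_{k\in\mathcal{B}}\nabla\ell(\theta;x_{k}),$$

computed on a mini-batch $\mathcal{B}\subset\{1,\dots,M\}$ drawn uniformly at random. The mini-batch gradient is a random variable with $\mathbf{E}[g(\theta)]=\nabla\mathcal{L}(\theta)$. An important quantity for this paper will be the (element-wise) variances of the stochastic gradient, which we denote by $\sigma_{i}^{2}(\theta):=\mathbf{var}[g(\theta)_{i}]$.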

Widely-used stochastic optimization algorithms are stochastic gradient descent (sgd, Robbins & Monro, 1951) and its momentum variants (Polyak, 1964; Nesterov, 1983). A number of methods popular in deep learning choose per-element update magnitudes based on past gradient observations. Among these are adagrad (Duchi et al., 2011), rmsprop (Tieleman & Hinton, 2012), adadelta (Zeiler, 2012), and adam (Kingma & Ba, 2015).

Notation: In the following, we occasionally drop $\theta$, writing $g$ instead of $g(\theta)$, et cetera. We use shorthands like $\nabla\mathcal{L}_{t}$, $g_{t}$ for sequences $\theta_{t}$ and double indices where needed, e.g., $g_{t,i}=g(\theta_{t})_{i}$, $\sigma^{2}_{t,i}=\sigma_{i}^{2}(\theta_{t})$. Divisions, squares and square-roots on vectors are to be understood element-wise. To avoid confusion with inner products, we explicitly denote element-wise multiplication of vectors by $\odot$.

We start out from a reinterpretation of the widely-used adam optimizer (some of our considerations naturally extend to adam’s relatives rmsprop and adadelta, but we restrict our attention to adam to keep the presentation concise), which maintains moving averages of stochastic gradients and their element-wise square,

$$m_{t}=\beta_{1}m_{t-1}+(1-\beta_{1})g_{t},\qquad v_{t}=\beta_{2}v_{t-1}+(1-\beta_{2})g_{t}^{2},$$

with $\beta_{1},\beta_{2}\in(0,1)$ and updates

$$\theta_{t+1}=\theta_{t}-\alpha\frac{m_{t}}{\sqrt{v_{t}}+\varepsilon}$$

with a small constant $\varepsilon>0$ preventing division by zero. Ignoring $\varepsilon$ and assuming $|m_{t,i}|>0$ for the moment, we can rewrite the update direction as

$$\frac{m_{t}}{\sqrt{v_{t}}}=\frac{\operatorname{sign}(m_{t})}{\sqrt{v_{t}/m_{t}^{2}}}=\sqrt{\frac{1}{1+\frac{v_{t}-m_{t}^{2}}{m_{t}^{2}}}}\odot\operatorname{sign}(m_{t}),$$

where the sign is to be understood element-wise. Assuming that $m_{t}$ and $v_{t}$ approximate the first and second moment of the stochastic gradient (a notion that we will discuss further in §4.1), $(v_{t}-m_{t}^{2})$ can be seen as an estimate of the stochastic gradient variances. The use of the non-central second moment effectively cancels out the magnitude of $m_{t}$; it only appears in the ratio $(v_{t}-m_{t}^{2})/m_{t}^{2}$. Hence, adam can be interpreted as a combination of two aspects:

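The update direction for the $i$-th coordinate is given by the sign of $m_{t,i}$.

The update magnitude for the $i$-th coordinate is solely determined by the global step size $\alpha$ and the factor

$$\gamma_{t,i}:=\frac{1}{\sqrt{1+\hat{\eta}_{t,i}^{2}}},$$

where $\hat{\eta}_{t,i}$ is an estimate of the relative variance,

$$\hat{\eta}_{t,i}^{2}:=\frac{v_{t,i}-m_{t,i}^{2}}{m_{t,i}^{2}}\approx\frac{\sigma_{t,i}^{2}}{\nabla\mathcal{L}_{t,i}^{2}}=:\eta_{t,i}^{2}.$$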

We will refer to the second aspect as variance adaptation. The variance adaptation factors shorten the update in directions of high relative variance, adapting for varying reliability of the stochastic gradient in different coordinates.
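The decomposition can be checked numerically. The following is a minimal plain-Python sketch (the helper names and the moving-average values are hypothetical, chosen only for illustration; $\varepsilon$ and any bias correction are ignored, as in the text) that computes $m_t/\sqrt{v_t}$ directly and as $\operatorname{sign}(m_t)$ times the variance adaptation factor $(1+\hat{\eta}_t^2)^{-1/2}$:

```python
import math

def adam_direction(m, v):
    """Adam's update direction m / sqrt(v), element-wise (epsilon ignored)."""
    return [mi / math.sqrt(vi) for mi, vi in zip(m, v)]

def sign_times_variance_factor(m, v):
    """Same direction written as sign(m) * (1 + eta_hat^2)^(-1/2),
    with eta_hat^2 = (v - m^2) / m^2 (requires m != 0)."""
    out = []
    for mi, vi in zip(m, v):
        eta2 = (vi - mi * mi) / (mi * mi)      # estimated relative variance
        gamma = 1.0 / math.sqrt(1.0 + eta2)    # variance adaptation factor
        out.append(math.copysign(1.0, mi) * gamma)
    return out

# Hypothetical moving averages for three coordinates.
m = [0.5, -0.02, 1.0]
v = [0.26, 0.01, 1.0]
print(adam_direction(m, v))
print(sign_times_variance_factor(m, v))
```

In this example the second coordinate has a large estimated relative variance ($\hat{\eta}^2=24$), so its step is shrunk to a fifth of the plain sign step, while the third coordinate ($v=m^2$, i.e. no estimated noise) keeps the full sign step.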

The above interpretation of adam’s update rule has to be viewed in contrast to existing ones. A motivation given by Kingma & Ba (2015) is that $v_{t}$ is a diagonal approximation to the empirical Fisher information matrix (FIM), making adam an approximation to natural gradient descent (Amari, 1998). Apart from fundamental reservations towards the empirical Fisher and the quality of diagonal approximations (Martens, 2014, §11), this view is problematic because the FIM, if anything, is approximated by $v_{t}$, whereas adam adapts with the square-root $\sqrt{v_{t}}$.

Another possible motivation (which is not found in peer-reviewed publications but circulates in the community as “conventional wisdom”) is that adam performs an approximate whitening of stochastic gradients. However, this view is equally problematic, since adam divides by the square-root of the non-central second moment, not by the standard deviation.

1.2 Overview

Both aspects of adam, taking the sign and variance adaptation, are briefly mentioned in Kingma & Ba (2015), who note that “[t]he effective stepsize […] is also invariant to the scale of the gradients” and refer to $m_{t}/\sqrt{v_{t}}$ as a “signal-to-noise ratio”. The purpose of this work is to disentangle these two aspects in order to discuss and analyze them in isolation.

This perspective naturally suggests two alternative methods by incorporating one of the aspects while excluding the other. Taking the sign of a stochastic gradient without any further modification gives rise to Stochastic Sign Descent (ssd). On the other hand, Stochastic Variance-Adapted Gradient (svag), to be derived in §3.2, applies variance adaptation directly to the stochastic gradient instead of its sign. Together with adam, the momentum variants of sgd, ssd, and svag constitute the four possible recombinations of the sign aspect and the variance adaptation, see Fig. 1.

We proceed as follows: Section 2 discusses the sign aspect. In a simplified setting we investigate under which circumstances the sign of a stochastic gradient is a better update direction than the stochastic gradient itself. Section 3 presents a principled derivation of element-wise variance adaptation factors. Subsequently, we discuss the practical implementation of variance-adapted methods (Section 4). Section 5 draws a connection to recent work on adam’s effect on generalization. Finally, Section 6 presents experimental results.

1.3 Related Work

Sign-based optimization algorithms have received some attention in the past. rprop (Riedmiller & Braun, 1993) is based on gradient signs and adapts per-element update magnitudes based on observed sign changes. Seide et al. (2014) empirically investigate the use of stochastic gradient signs in a distributed setting with the goal of reducing communication cost. Karimi et al. (2016) prove convergence results for sign-based methods in the non-stochastic case.

Variance-based update directions have been proposed before, e.g., by Schaul et al. (2013), where the variance appears together with curvature estimates in a diagonal preconditioner for sgd. Their variance-dependent terms resemble the variance adaptation factors we will derive in Section 3. The corresponding parts of our work complement that of Schaul et al. (2013) in various ways. Most notably, we provide a principled motivation for variance adaptation that is independent of the update direction and use that to extend the variance adaptation to the momentum case.

A somewhat related line of research aims to obtain reduced-variance gradient estimates (e.g., Johnson & Zhang, 2013; Defazio et al., 2014). This is largely orthogonal to our notion of variance adaptation, which alters the search direction to mitigate adverse effects of the (remaining) variance.

1.4 The Sign of a Stochastic Gradient

For later use, we briefly establish some facts about the sign of a stochastic gradient, $s=\operatorname{sign}(g)$. (To avoid a separate zero-case, we define $\operatorname{sign}(0)=1$ for all theoretical considerations. Note that $g_{i}\neq 0$ a.s. if $\mathbf{var}[g_{i}]>0$.) The distribution of the binary random variable $s_{i}$ is fully characterized by the success probability $\rho_{i}:=\mathbf{P}\left[s_{i}=\operatorname{sign}(\nabla\mathcal{L}_{i})\right]$, which generally depends on the distribution of $g_{i}$. If we assume $g_{i}$ to be normally distributed, which is supported by the Central Limit Theorem applied to Eq. (3), we have

$$\rho_{i}=\frac{1}{2}+\frac{1}{2}\operatorname{erf}\left(\frac{|\nabla\mathcal{L}_{i}|}{\sqrt{2}\sigma_{i}}\right),$$

see §B.1 of the supplementary material. Note that $\rho_{i}$ is uniquely determined by the relative variance of $g_{i}$.
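A small sketch of this formula in plain Python, together with a Monte Carlo estimate under the same Gaussian assumption (function names are hypothetical; sampling uses the standard library's `random.Random.gauss`). It also illustrates that $\rho_i$ depends only on the ratio $\sigma_i/|\nabla\mathcal{L}_i|$:

```python
import math
import random

def success_probability(mean, std):
    """rho = P[sign(g) = sign(mean)] for g ~ N(mean, std^2), mean != 0."""
    return 0.5 + 0.5 * math.erf(abs(mean) / (math.sqrt(2.0) * std))

def success_probability_mc(mean, std, n=200_000, seed=0):
    """Monte Carlo estimate of the same probability, using sign(0) := 1."""
    rng = random.Random(seed)
    target = 1 if mean >= 0 else -1
    hits = 0
    for _ in range(n):
        g = rng.gauss(mean, std)
        if (1 if g >= 0 else -1) == target:
            hits += 1
    return hits / n

# Same relative variance (std / |mean| = 1) gives the same success probability.
print(success_probability(1.0, 1.0))
print(success_probability(-3.0, 3.0))
print(success_probability_mc(1.0, 1.0))
```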

2 Why the Sign?

Can it make sense to use the sign of a stochastic gradient as the update direction instead of the stochastic gradient itself? This question is difficult to tackle in a general setting, but we can get an intuition using the simple, yet insightful, case of stochastic quadratic problems, where we can investigate the effects of curvature properties and noise.

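We consider stochastic quadratic problems (sQPs) of the form $\mathcal{L}(\theta)=\mathbf{E}_{x}\left[\tfrac{1}{2}(\theta-x)^{T}Q(\theta-x)\right]$ with a symmetric positive definite $Q\in\mathbb{R}^{d\times d}$ and data $x\sim\mathcal{N}(x^{\ast},\nu^{2}I)$. The gradient is $\nabla\mathcal{L}(\theta)=Q(\theta-x^{\ast})$, and stochastic gradients are given by $g(\theta)=Q(\theta-x)\sim\mathcal{N}(\nabla\mathcal{L}(\theta),\nu^{2}QQ)$.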

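We compare update directions on sQPs in terms of their local expected decrease in function value from a single step. For any stochastic direction $z$, updating from $\theta$ to $\theta+\alpha z$ results in $\mathbf{E}[\mathcal{L}(\theta+\alpha z)]=\mathcal{L}(\theta)+\alpha\nabla\mathcal{L}(\theta)^{T}\mathbf{E}[z]+\frac{\alpha^{2}}{2}\mathbf{E}[z^{T}Qz]$. For this comparison of update directions we use the optimal step size minimizing $\mathbf{E}[\mathcal{L}(\theta+\alpha z)]$, which is easily found to be $\alpha_{\ast}=-\nabla\mathcal{L}(\theta)^{T}\mathbf{E}[z]/\mathbf{E}[z^{T}Qz]$ and yields an expected improvement of

$$\mathcal{I}(z):=\mathcal{L}(\theta)-\mathbf{E}[\mathcal{L}(\theta+\alpha_{\ast}z)]=\frac{\left(\nabla\mathcal{L}(\theta)^{T}\mathbf{E}[z]\right)^{2}}{2\,\mathbf{E}[z^{T}Qz]}.$$

Locally, a larger expected improvement implies a better update direction. We compute this quantity for sgd ($z=-g(\theta)$) and ssd ($z=-\operatorname{sign}(g(\theta))$) in §B.2 of the supplementary material and find

$$\mathcal{I}_{\textsc{sgd}}=\frac{(\nabla\mathcal{L}^{T}\nabla\mathcal{L})^{2}}{2\left(\nabla\mathcal{L}^{T}Q\nabla\mathcal{L}+\nu^{2}\sum_{i=1}^{d}\lambda_{i}^{3}\right)},\qquad\mathcal{I}_{\textsc{ssd}}\geq\frac{p_{\text{diag}}(Q)\left(\sum_{i=1}^{d}(2\rho_{i}-1)|\nabla\mathcal{L}_{i}|\right)^{2}}{2\sum_{i=1}^{d}\lambda_{i}},$$

where $\lambda_{1},\dots,\lambda_{d}$ are the eigenvalues of $Q$ and $p_{\text{diag}}(Q):=\sum_{i}|q_{ii}|/\sum_{i,j}|q_{ij}|$.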

Firstly, the term $p_{\text{diag}}(Q)$, which features only in $\mathcal{I}_{\textsc{ssd}}$, relates to the orientation of the eigenbasis of $Q$. If $Q$ is diagonal, the problem is perfectly axis-aligned and we have $p_{\text{diag}}(Q)=1$. This is the obvious best case for the intrinsically axis-aligned sign update. However, $p_{\text{diag}}(Q)$ can become as small as $1/d$ in the worst case and will on average (over random orientations) be $p_{\text{diag}}(Q)\approx 1.57/d$. (We show these properties in §B.2 of the supplementary material.) This suggests that the sign update will have difficulties with arbitrarily-rotated eigenbases and crucially relies on the problem being “close to axis-aligned”.

Secondly, $\mathcal{I}_{\textsc{sgd}}$ contains the term $\nu^{2}\sum_{i=1}^{d}\lambda_{i}^{3}$ in which stochastic noise and the eigenspectrum of the problem interact. $\mathcal{I}_{\textsc{ssd}}$, on the other hand, has a milder dependence on the eigenvalues of $Q$ and there is no such interaction between noise and eigenspectrum. The noise only manifests in the element-wise success probabilities $\rho_{i}$.
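To make the comparison concrete, the sketch below evaluates $\mathcal{I}_{\textsc{sgd}}$ and a lower bound on $\mathcal{I}_{\textsc{ssd}}$ for small hand-picked sQPs. It assumes the expressions $\mathcal{I}_{\textsc{sgd}}=(\nabla\mathcal{L}^T\nabla\mathcal{L})^2/\big(2(\nabla\mathcal{L}^TQ\nabla\mathcal{L}+\nu^2\sum_i\lambda_i^3)\big)$ and $\mathcal{I}_{\textsc{ssd}}\geq p_{\text{diag}}(Q)\big(\sum_i(2\rho_i-1)|\nabla\mathcal{L}_i|\big)^2/(2\sum_i\lambda_i)$, with $p_{\text{diag}}(Q)=\sum_i|q_{ii}|/\sum_{i,j}|q_{ij}|$ and $\sigma_i^2=\nu^2(QQ)_{ii}$. Traces replace explicit eigenvalues ($\sum_i\lambda_i=\operatorname{tr}Q$, $\sum_i\lambda_i^3=\operatorname{tr}Q^3$). The matrices, gradient and noise level are hypothetical, and since the ssd quantity is only a lower bound it can understate the sign update's improvement:

```python
import math

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def trace(A):
    return sum(A[i][i] for i in range(len(A)))

def p_diag(Q):
    """Share of the absolute 'mass' of Q on its diagonal (1 iff Q is diagonal)."""
    diag = sum(abs(Q[i][i]) for i in range(len(Q)))
    total = sum(abs(q) for row in Q for q in row)
    return diag / total

def improvement_sgd(Q, grad, nu):
    d = len(Q)
    gg = sum(g * g for g in grad)
    gQg = sum(grad[i] * Q[i][j] * grad[j] for i in range(d) for j in range(d))
    tr_q3 = trace(matmul(Q, matmul(Q, Q)))           # = sum of lambda_i^3
    return gg ** 2 / (2.0 * (gQg + nu ** 2 * tr_q3))

def improvement_ssd_lower_bound(Q, grad, nu):
    """Lower bound on I_ssd; assumes nu > 0 so that every sigma_i > 0."""
    QQ = matmul(Q, Q)
    total = 0.0
    for i, g in enumerate(grad):
        sigma = nu * math.sqrt(QQ[i][i])             # cov[g] = nu^2 Q Q
        two_rho_minus_one = math.erf(abs(g) / (math.sqrt(2.0) * sigma))
        total += two_rho_minus_one * abs(g)
    return p_diag(Q) * total ** 2 / (2.0 * trace(Q))  # trace(Q) = sum of lambda_i

Q_axis = [[100.0, 0.0], [0.0, 1.0]]   # ill-conditioned, axis-aligned
Q_rot = [[50.5, 49.5], [49.5, 50.5]]  # same eigenvalues (100, 1), rotated by 45 degrees
grad, nu = [1.0, 1.0], 1.0
for name, Q in (("axis-aligned", Q_axis), ("rotated", Q_rot)):
    print(name, p_diag(Q), improvement_sgd(Q, grad, nu),
          improvement_ssd_lower_bound(Q, grad, nu))
```

For the axis-aligned problem, the noise in the high-curvature direction inflates the denominator of $\mathcal{I}_{\textsc{sgd}}$ and the ssd bound is larger by orders of magnitude; rotating the same spectrum halves $p_{\text{diag}}(Q)$ and the bound drops below $\mathcal{I}_{\textsc{sgd}}$.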

In summary, we can expect the sign direction to be beneficial for noisy, ill-conditioned problems with diagonally dominant Hessians. It is unclear to what extent these properties hold for real problems, on which sign-based methods like adam are usually applied. Becker & LeCun (1988) empirically investigated the diagonal dominance of Hessians of simple neural network training problems and found comparatively high values of $p_{\text{diag}}(Q)=0.1$ up to $p_{\text{diag}}(Q)=0.6$. Chaudhari et al. (2017) empirically investigated the eigenspectrum in deep learning problems and found it to be very ill-conditioned, with the majority of eigenvalues close to zero and a few very large ones. However, this empirical evidence is far from conclusive.

2.2 Experimental Evaluation

3 Variance Adaptation

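We now proceed to the second component of adam: variance-based element-wise step sizes. Considering this variance adaptation in isolation from the sign aspect naturally suggests employing it on arbitrary update directions, for example directly on the stochastic gradient instead of its sign. A principled motivation arises from the following consideration:

Lemma 1. Let $\hat{p}\in\mathbb{R}^{d}$ be a random variable with $\mathbf{E}[\hat{p}]=p$ and $\mathbf{var}[\hat{p}_{i}]=\sigma_{i}^{2}$. Then $\mathbf{E}[\|\gamma\odot\hat{p}-p\|_{2}^{2}]$ is minimized by

$$\gamma_{i}=\frac{\mathbf{E}[\hat{p}_{i}]^{2}}{\mathbf{E}[\hat{p}_{i}^{2}]}=\frac{p_{i}^{2}}{p_{i}^{2}+\sigma_{i}^{2}}=\frac{1}{1+\sigma_{i}^{2}/p_{i}^{2}}$$

and $\mathbf{E}[\|\gamma\odot\operatorname{sign}(\hat{p})-\operatorname{sign}(p)\|_{2}^{2}]$ is minimized by

$$\gamma_{i}=2\rho_{i}-1,$$

where $\rho_{i}:=\mathbf{P}[\operatorname{sign}(\hat{p}_{i})=\operatorname{sign}(p_{i})]$. (Proof in §B.3)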
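Both minimizers in Lemma 1 follow from expanding the squares: $\mathbf{E}[(\gamma_i\hat{p}_i-p_i)^2]=\gamma_i^2(p_i^2+\sigma_i^2)-2\gamma_i p_i^2+p_i^2$ and, since both signs are $\pm1$, $\mathbf{E}[(\gamma_i\operatorname{sign}(\hat{p}_i)-\operatorname{sign}(p_i))^2]=\gamma_i^2-2\gamma_i(2\rho_i-1)+1$. The plain-Python sketch below (hypothetical helper names and example values) minimizes both expressions over a grid and compares the result with the closed forms:

```python
def expected_sq_error(gamma, p, sigma2):
    """E[(gamma * p_hat - p)^2] for E[p_hat] = p, var[p_hat] = sigma2."""
    second_moment = p * p + sigma2
    return gamma ** 2 * second_moment - 2.0 * gamma * p * p + p * p

def expected_sign_error(gamma, rho):
    """E[(gamma * sign(p_hat) - sign(p))^2] with success probability rho."""
    return gamma ** 2 - 2.0 * gamma * (2.0 * rho - 1.0) + 1.0

def argmin_on_grid(f, lo=0.0, hi=1.0, n=10_001):
    grid = [lo + (hi - lo) * k / (n - 1) for k in range(n)]
    return min(grid, key=f)

p, sigma2, rho = 2.0, 1.0, 0.8
print(argmin_on_grid(lambda g: expected_sq_error(g, p, sigma2)), p * p / (p * p + sigma2))
print(argmin_on_grid(lambda g: expected_sign_error(g, rho)), 2 * rho - 1)
```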

According to Lemma 1, the optimal variance adaptation factors for the sign of a stochastic gradient are $\gamma_{i}=2\rho_{i}-1$, where $\rho_{i}=\mathbf{P}[\operatorname{sign}(g_{i})=\operatorname{sign}(\nabla\mathcal{L}_{i})]$. Appealing to intuition, this means that $\gamma_{i}$ increases linearly with the success probability, with a maximum of $1$ when we are certain about the sign of the gradient ($\rho_{i}=1$) and a minimum of $0$ in the absence of information ($\rho_{i}=0.5$).

Recall from Eq. (10) that, under the Gaussian assumption, the success probabilities satisfy $2\rho_{i}-1=\operatorname{erf}[(\sqrt{2}\eta_{i})^{-1}]$. Figure 3 shows that this term is closely approximated by $(1+\eta_{i}^{2})^{-1/2}$, the variance adaptation terms of adam. Hence, adam can be regarded as an approximate realization of this optimal variance adaptation scheme. This comes with the caveat that adam applies these factors to $\operatorname{sign}(m_{t})$ instead of $\operatorname{sign}(g_{t})$. Variance adaptation for $m_{t}$ will be discussed further in §4.3 and in the supplements §C.2.
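The quality of this approximation is easy to inspect numerically. The sketch below (hypothetical function names, plain Python) tabulates both factors for a few values of the relative standard deviation $\eta_i$; on the range $\eta_i\in[0.05,10]$ the two differ by less than $0.07$:

```python
import math

def optimal_sign_factor(eta):
    """2*rho - 1 = erf(1 / (sqrt(2) * eta)) under the Gaussian assumption."""
    return math.erf(1.0 / (math.sqrt(2.0) * eta))

def adam_factor(eta):
    """Adam's implicit variance adaptation factor (1 + eta^2)^(-1/2)."""
    return 1.0 / math.sqrt(1.0 + eta * eta)

for eta in [0.1, 0.5, 1.0, 2.0, 5.0]:
    print(f"eta={eta:4.1f}  erf-factor={optimal_sign_factor(eta):.4f}  "
          f"adam={adam_factor(eta):.4f}")
```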

3.2 Stochastic Variance-Adapted Gradient (SVAG)

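Applying Eq. (15) to $\hat{p}=g$, the optimal variance adaptation factors for a stochastic gradient are found to be

$$\gamma^{g}_{i}=\frac{\nabla\mathcal{L}_{i}^{2}}{\nabla\mathcal{L}_{i}^{2}+\sigma_{i}^{2}}=\frac{1}{1+\sigma_{i}^{2}/\nabla\mathcal{L}_{i}^{2}}=\frac{1}{1+\eta_{i}^{2}}.$$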

A term of this form also appears, together with diagonal curvature estimates, in Schaul et al. (2013). We refer to the method updating along $\gamma^{g}\odot g$ as Stochastic Variance-Adapted Gradient (svag). To support intuition, Fig. 4 shows a conceptual sketch of this variance adaptation scheme.

Variance adaptation of this form guarantees convergence without manually decreasing the global step size. We recover the $\mathcal{O}(1/t)$ rate of sgd for smooth, strongly convex functions. We emphasize that this result considers an idealized version of svag with exact $\gamma^{g}_{i}$. It should be considered as a motivation for this variance adaptation strategy, not a statement about its performance with estimated variance adaptation factors.

where $f_{\ast}$ is the minimum value of $f$. (Proof in §B.4)

The assumption $\sum_{i=1}^{d}\sigma_{t,i}^{2}\leq c_{v}\|\nabla f_{t}\|^{2}+M_{v}$ is a mild restriction on the variances, allowing them to be non-zero everywhere and to grow quadratically in the gradient norm.
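The constant-step-size behaviour of idealized svag can be illustrated on a one-dimensional toy problem. This is a sketch, not the proof from §B.4: it assumes $f(\theta)=\tfrac12\theta^2$ (so the smoothness constant is $1$ and $\alpha=1$), stochastic gradients $g=\theta+\xi$ with $\xi\sim\mathcal{N}(0,\sigma^2)$, and exact factors $\gamma=\theta^2/(\theta^2+\sigma^2)$, which require the true gradient and are therefore only available in a toy. For this setup, $\mathbf{E}[\theta_{t+1}^2\mid\theta_t]=\theta_t^2\sigma^2/(\theta_t^2+\sigma^2)$, and concavity of this map gives $\mathbf{E}[\theta_t^2]\leq\sigma^2/(t+\sigma^2/\theta_0^2)$, an $\mathcal{O}(1/t)$ decay. The function name is hypothetical:

```python
import random

def run_idealized_svag(theta0=5.0, sigma=1.0, alpha=1.0, steps=500, seed=0):
    """SVAG with exact factors on f(theta) = theta^2 / 2 (smoothness 1, alpha = 1)."""
    rng = random.Random(seed)
    theta = theta0
    for _ in range(steps):
        grad = theta                                   # exact gradient of f
        g = grad + rng.gauss(0.0, sigma)               # stochastic gradient
        gamma = grad ** 2 / (grad ** 2 + sigma ** 2)   # exact variance adaptation factor
        theta -= alpha * gamma * g                     # constant global step size
    return theta

runs = [run_idealized_svag(seed=s) for s in range(200)]
mean_sq = sum(th * th for th in runs) / len(runs)
bound = 1.0 / (500 + 1.0 / 25.0)   # sigma^2 / (t + sigma^2 / theta0^2)
print(mean_sq, bound)
```

The global step size stays fixed throughout; the shrinking of the steps near the optimum comes entirely from $\gamma$, which falls as the true gradient becomes small relative to the noise.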

4 Practical Implementation of M-SVAG

Section 3 has introduced the general idea of variance adaptation; we now discuss its practical implementation. For the sake of a concise presentation, we focus on one particular variance-adapted method, m-svag, which applies variance adaptation to the update direction $m_{t}$. This method is of particular interest due to its relationship to adam outlined in Figure 1. Many of the following considerations correspondingly apply to other variance-adapted methods, e.g., svag and variants of adam, some of which are discussed and evaluated in the supplementary material (§C).

In practice, the optimal variance adaptation factors are unknown and have to be estimated. A key ingredient is an estimate of the stochastic gradient variance. We have argued in the introduction that adam obtains such an estimate from moving averages, $\sigma_{t,i}^{2}\approx v_{t,i}-m_{t,i}^{2}$. The underlying assumption is that the distribution of stochastic gradients is approximately constant over the effective time horizon of the exponential moving average, making $m_{t}$ and $v_{t}$ estimates of the first and second moment of $g_{t}$, respectively:

Assumption 1. At step $t$, assume

$$\mathbf{E}[m_{t,i}]\approx\nabla\mathcal{L}_{t,i},\qquad\mathbf{E}[v_{t,i}]\approx\nabla\mathcal{L}_{t,i}^{2}+\sigma_{t,i}^{2}.$$

While this can only ever hold approximately, Assumption 1 is the tool we need to obtain gradient variance estimates from past gradient observations. It will be more realistic in the case of high noise and small step size, where the variation between successive stochastic gradients is dominated by stochasticity rather than change in the true gradient.

We make two modifications to adam’s variance estimate. First, we will use the same moving average constant $\beta_{1}=\beta_{2}=\beta$ for $m_{t}$ and $v_{t}$. This constant should define the effective range for which we implicitly assume the stochastic gradients to come from the same distribution, making different constants for the first and second moment implausible.

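Secondly, we adapt for a systematic bias in the variance estimate. As we show in §B.5, under Assumption 1,

$$\mathbf{var}[m_{t,i}]\approx\rho(\beta,t)\,\sigma_{t,i}^{2},\qquad\rho(\beta,t):=\frac{(1-\beta)(1+\beta^{t+1})}{(1+\beta)(1-\beta^{t+1})},$$

and consequently $\mathbf{E}[v_{t,i}-m_{t,i}^{2}]\approx(1-\rho(\beta,t))\,\sigma_{t,i}^{2}$. We correct for this bias and use the variance estimate

$$\hat{s}_{t}:=\frac{1}{1-\rho(\beta,t)}\left(v_{t}-m_{t}^{2}\right).$$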
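The bias correction can be checked by simulation. The sketch below assumes the idealized setting behind Assumption 1 (i.i.d. Gaussian gradients with fixed mean and variance), moving averages initialized at zero and divided by $1-\beta^{t+1}$ (with $t$ counted from $0$), and $\rho(\beta,t)=\frac{(1-\beta)(1+\beta^{t+1})}{(1+\beta)(1-\beta^{t+1})}$; the function names are hypothetical. Averaged over many runs, $\hat{s}_t$ should match $\sigma^2$:

```python
import random

def rho(beta, t):
    """Variance factor of a bias-corrected moving average after t+1 samples (t from 0)."""
    return (1 - beta) * (1 + beta ** (t + 1)) / ((1 + beta) * (1 - beta ** (t + 1)))

def variance_estimate(grads, beta):
    """Bias-corrected moving averages m_t, v_t and the corrected estimate s_hat_t."""
    m = v = 0.0
    for g in grads:
        m = beta * m + (1 - beta) * g
        v = beta * v + (1 - beta) * g * g
    t = len(grads) - 1
    m_hat = m / (1 - beta ** (t + 1))
    v_hat = v / (1 - beta ** (t + 1))
    return (v_hat - m_hat ** 2) / (1 - rho(beta, t))   # undefined for t = 0

mu, sigma, beta, n_steps = 1.0, 2.0, 0.9, 10
rng = random.Random(0)
estimates = [variance_estimate([rng.gauss(mu, sigma) for _ in range(n_steps)], beta)
             for _ in range(20_000)]
print(sum(estimates) / len(estimates), "vs sigma^2 =", sigma ** 2)
print(rho(beta, 0))   # equals 1, so the correction is ill-defined at t = 0
```

For large $t$, $\rho(\beta,t)$ approaches $(1-\beta)/(1+\beta)$, i.e. about $0.053$ for $\beta=0.9$.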

Mini-Batch Gradient Variance Estimates: An alternative variance estimate can be computed locally “within” a single mini-batch, see §D of the supplements. We have experimented with both estimators and found the resulting methods to have similar performance. For the main paper, we stick to the moving average variant for its ease of implementation and direct correspondence with adam. We present experiments with the mini-batch variant in the supplementary material. These demonstrate the merit of variance adaptation irrespective of how the variance is estimated.

4.2 Estimating the Variance Adaptation Factors

The gradient variance itself is not of primary interest; we have to estimate the variance adaptation factors, given by Eq. (17) in the case of svag. We propose to use the estimate

$$\hat{\gamma}^{g}_{t}=\frac{m_{t}^{2}}{m_{t}^{2}+\hat{s}_{t}}.$$

While $\hat{\gamma}^{g}_{t}$ is an intuitive quantity, it is not an unbiased estimate of the exact variance adaptation factors as defined in Eq. (17). To our knowledge, unbiased estimation of the exact factors is intractable. We have experimented with several partial bias correction terms but found them to have destabilizing effects.

4.3 Incorporating Momentum

So far, we have considered variance adaptation for the update direction $g_{t}$. In practice, we may want to update in the direction of $m_{t}$ to incorporate momentum. Our use of the term momentum is somewhat colloquial. To highlight the relationship with adam (Fig. 1), we have defined m-sgd as the method using the update direction $m_{t}$, which is a rescaled version of sgd with momentum. m-svag applies variance adaptation to $m_{t}$. This is not to be confused with the application of momentum acceleration (Polyak, 1964; Nesterov, 1983) on top of a svag update. According to Lemma 1, the variance adaptation factors should then be determined by the relative variance of $m_{t}$.

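Once more adopting Assumption 1, we have $\mathbf{E}[m_{t}]\approx\nabla\mathcal{L}_{t}$ and $\mathbf{var}[m_{t,i}]\approx\rho(\beta,t)\sigma_{t,i}^{2}$, the latter being due to Eq. (20). Hence, the relative variance of $m_{t}$ is $\rho(\beta,t)$ times that of $g_{t}$, such that the optimal variance adaptation factors for the update direction $m_{t}$ according to Lemma 1 are

$$\gamma^{m}_{t,i}=\frac{\nabla\mathcal{L}_{t,i}^{2}}{\nabla\mathcal{L}_{t,i}^{2}+\rho(\beta,t)\sigma_{t,i}^{2}}.$$

In analogy to the svag case, we estimate them by

$$\hat{\gamma}^{m}_{t}=\frac{m_{t}^{2}}{m_{t}^{2}+\rho(\beta,t)\hat{s}_{t}}.$$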

Note that $m_{t}$ now serves a double purpose: It determines the base update direction and, at the same time, is used to obtain an estimate of the gradient variance.
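Because $\rho(\beta,t)\leq 1$, averaging reduces the relative variance of the update direction, so the factors for $m_t$ shrink the update less than those for $g_t$. A short sketch (hypothetical names; it assumes $\rho(\beta,t)=\frac{(1-\beta)(1+\beta^{t+1})}{(1+\beta)(1-\beta^{t+1})}$ for bias-corrected averages) compares $1/(1+\eta^2)$ with $1/(1+\rho(\beta,t)\eta^2)$:

```python
def rho(beta, t):
    """Assumed variance factor of a bias-corrected moving average (t counted from 0)."""
    return (1 - beta) * (1 + beta ** (t + 1)) / ((1 + beta) * (1 - beta ** (t + 1)))

def gamma_g(eta):
    """Optimal factor for the stochastic gradient with relative std eta."""
    return 1.0 / (1.0 + eta ** 2)

def gamma_m(eta, beta, t):
    """Optimal factor for m_t, whose relative variance is rho(beta, t) * eta^2."""
    return 1.0 / (1.0 + rho(beta, t) * eta ** 2)

for eta in (0.5, 1.0, 2.0, 5.0):
    print(f"eta={eta}: gamma_g={gamma_g(eta):.3f}, "
          f"gamma_m(t=100)={gamma_m(eta, 0.9, 100):.3f}")
```

At $t=0$ the moving average consists of a single gradient, $\rho(\beta,0)=1$, and both factors coincide.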

4.4 Details

Note that Eq. (22) is ill-defined for $t=0$, since $\rho(\beta,0)=1$. We use $\hat{s}_{0}=0$ for the first iteration, making the initial step of m-svag coincide with an sgd-step. One final detail concerns a possible division by zero in Eq. (25). Unlike adam, we do not add a constant offset $\varepsilon$ in the denominator. A division by zero only occurs when $m_{t,i}=v_{t,i}=0$; we check for this case and perform no update, since $m_{t,i}=0$.

This completes the description of our implementation of m-svag. Alg. 1 provides pseudo-code (ignoring the details discussed in §4.4 for readability).
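For illustration, here is a minimal plain-Python sketch of the m-svag update described in this section. It is not the authors' TensorFlow implementation, and the class name `MSVAG` is hypothetical. It assumes moving averages $m_t, v_t$ divided by $1-\beta^{t+1}$, $\rho(\beta,t)=\frac{(1-\beta)(1+\beta^{t+1})}{(1+\beta)(1-\beta^{t+1})}$, the variance estimate $\hat{s}_t=(v_t-m_t^2)/(1-\rho(\beta,t))$ with $\hat{s}_0=0$, and factors $m_t^2/(m_t^2+\rho(\beta,t)\hat{s}_t)$, and it runs on a hypothetical noisy quadratic:

```python
import random

class MSVAG:
    """Minimal m-svag sketch on plain Python lists (hypothetical, not the authors' code)."""

    def __init__(self, lr, beta=0.9):
        self.lr = lr
        self.beta = beta
        self.m = None   # moving average of gradients
        self.v = None   # moving average of squared gradients
        self.t = -1     # iteration counter; the first step has t = 0

    def _rho(self, t):
        b = self.beta
        return (1 - b) * (1 + b ** (t + 1)) / ((1 + b) * (1 - b ** (t + 1)))

    def step(self, params, grads):
        b = self.beta
        if self.m is None:
            self.m = [0.0] * len(params)
            self.v = [0.0] * len(params)
        self.t += 1
        bias = 1 - b ** (self.t + 1)          # bias correction for zero initialization
        rho = self._rho(self.t)
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = b * self.m[i] + (1 - b) * g
            self.v[i] = b * self.v[i] + (1 - b) * g * g
            m_hat, v_hat = self.m[i] / bias, self.v[i] / bias
            if self.t == 0:
                s_hat = 0.0                   # first step coincides with an SGD step
            else:
                # Variance estimate; clip tiny negative values caused by rounding.
                s_hat = max(v_hat - m_hat ** 2, 0.0) / (1 - rho)
            denom = m_hat ** 2 + rho * s_hat
            if denom == 0.0:                  # m_hat = v_hat = 0: perform no update
                new_params.append(p)
                continue
            gamma = m_hat ** 2 / denom        # estimated variance adaptation factor
            new_params.append(p - self.lr * gamma * m_hat)
        return new_params

# Toy run on f(theta) = 0.5 * ||theta||^2 with per-coordinate gradient noise.
rng = random.Random(0)
noise_std = [0.1, 2.0]
theta = [3.0, 3.0]
opt = MSVAG(lr=0.1)
losses = []
for _ in range(1000):
    grads = [th + rng.gauss(0.0, s) for th, s in zip(theta, noise_std)]
    theta = opt.step(theta, grads)
    losses.append(0.5 * sum(th * th for th in theta))
print(losses[0], sum(losses[-100:]) / 100)
```

The per-coordinate loop mirrors the element-wise nature of the method; in a vectorized framework the same steps would be expressed as array operations.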

5 Connection to Generalization

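Of late, the question of the effect of the optimization algorithm on generalization has received increased attention. Especially in deep learning, different optimizers might find solutions with varying generalization performance. Recently, Wilson et al. (2017) have argued that “adaptive methods” (referring to adagrad, rmsprop, and adam) have adverse effects on generalization compared to “non-adaptive methods” (gradient descent, sgd, and their momentum variants). In addition to an extensive empirical validation of that claim, the authors make a theoretical argument using a binary least-squares classification problem,

$$R(\theta)=\frac{1}{2n}\sum_{i=1}^{n}(x_{i}^{T}\theta-y_{i})^{2}=\frac{1}{2n}\|X\theta-y\|^{2},$$

with $n$ data points $x_{i}\in\mathbb{R}^{d}$ stacked in $X\in\mathbb{R}^{n\times d}$ and binary labels $y_{i}\in\{-1,1\}$. Their key lemma reads:

Lemma 2 (Wilson et al., 2017). Suppose $X^{T}y$ has no zero entries and there exists a scalar $c$ such that $X\operatorname{sign}(X^{T}y)=cy$. Then, when initialized at $\theta_{0}=0$, the iterates generated by full-batch adagrad, rmsprop, and adam satisfy $\theta_{t}\propto\operatorname{sign}(X^{T}y)$.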

Intriguingly, as we show in §B.6 of the supplementary material, this statement easily extends to sign descent, i.e., the method updating $\theta_{t+1}=\theta_{t}-\alpha\operatorname{sign}(\nabla R(\theta_{t}))$.

Under the assumptions of Lemma 2, the iterates generated by sign descent satisfy $\theta_{t}\propto\operatorname{sign}(X^{T}y)$.
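The sign-descent statement can be checked on a tiny instance. The sketch below uses a hypothetical $X\in\mathbb{R}^{2\times 3}$ and $y$ constructed so that $X^Ty=(1,-1,2)$ has no zero entries and $X\operatorname{sign}(X^Ty)=2y$. It runs sign descent on $R(\theta)=\frac{1}{2n}\|X\theta-y\|^2$ from $\theta_0=0$ and checks that every iterate is a multiple of $\operatorname{sign}(X^Ty)$. The step size is chosen so that the residual never becomes exactly zero, which is where the convention $\operatorname{sign}(0)=1$ would otherwise come into play:

```python
def sign(x):
    """Sign with the convention sign(0) = 1."""
    return 1.0 if x >= 0 else -1.0

def grad_R(X, y, theta):
    """Gradient of R(theta) = ||X theta - y||^2 / (2n)."""
    n = len(y)
    residual = [sum(a * b for a, b in zip(row, theta)) - yi for row, yi in zip(X, y)]
    return [sum(X[i][j] * residual[i] for i in range(n)) / n for j in range(len(theta))]

# Hypothetical data satisfying the assumptions of the lemma.
X = [[1.0, -1.0, 0.0],
     [0.0, 0.0, 2.0]]
y = [1.0, 1.0]
Xty = [sum(X[i][j] * y[i] for i in range(len(y))) for j in range(3)]
s = [sign(val) for val in Xty]
Xs = [sum(a * b for a, b in zip(row, s)) for row in X]
print("X^T y =", Xty, " X sign(X^T y) =", Xs)

theta = [0.0, 0.0, 0.0]
alpha = 0.03
proportional = True
for _ in range(100):
    g = grad_R(X, y, theta)
    theta = [th - alpha * sign(gj) for th, gj in zip(theta, g)]
    ratios = [th / sj for th, sj in zip(theta, s)]    # equal ratios <=> theta is c * s
    proportional = proportional and max(ratios) - min(ratios) < 1e-12
print(theta, proportional)
```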

On the other hand, this does not extend to m-svag, an adaptive method by any standard. As noted before, the first step of m-svag coincides with a gradient descent step. The iterates generated by m-svag will, thus, not generally be proportional to $\operatorname{sign}(X^{T}y)$. While this by no means implies that it converges to the max-margin solution or has otherwise favorable generalization properties, the construction of Wilson et al. (2017) does not apply to m-svag.

This suggests that it is the sign that impedes generalization in the examples constructed by Wilson et al. (2017), rather than the element-wise adaptivity as such. Our experiments substantiate this suspicion. The fact that all currently popular adaptive methods are also sign-based has led to a conflation of these two aspects. The main motivation for this work was to disentangle them.

6 Experiments

We experimentally compare m-svag and adam to their non-variance-adapted counterparts m-sgd and m-ssd (Alg. 2). Since these are the four possible recombinations of the sign and the variance adaptation (Fig. 1), this comparison allows us to separate the effects of the two aspects.

We evaluated the four methods on the following problems:

(P1) A vanilla convolutional neural network (CNN) with two convolutional and two fully-connected layers on the Fashion-mnist data set (Xiao et al., 2017).

(P2) A vanilla CNN with three convolutional and three fully-connected layers on cifar-10 (Krizhevsky, 2009).

(P3) The wide residual network WRN-40-4 architecture of Zagoruyko & Komodakis (2016) on cifar-100.

(P4) A two-layer LSTM (Hochreiter & Schmidhuber, 1997) for character-level language modelling on Tolstoy’s War and Peace.

A detailed description of all network architectures has been moved to §A of the supplementary material.

For all experiments, we used $\beta=0.9$ for m-sgd, m-ssd and m-svag and default parameters ($\beta_{1}=0.9$, $\beta_{2}=0.999$, $\varepsilon=10^{-8}$) for adam. The global step size $\alpha$ was tuned for each method individually by first finding the maximal stable step size by trial and error, then searching downwards. We selected the one that yielded maximal test accuracy within a fixed number of training steps, a scenario close to an actual application of the methods by a practitioner. (Loss and accuracy have been evaluated at a fixed interval on the full test set as well as on an equally-sized portion of the training set.) Experiments with the best step size have been replicated ten times with different random seeds. While (P1) and (P2) were trained with constant $\alpha$, we used a decrease schedule for (P3) and (P4), which was fixed in advance for all methods. Full details can be found in §A of the supplements.

6.2 Results

Fig. 5 shows results. We make four main observations.

With the exception of (P4), the performance of the four methods distinctly clusters into sign-based and non-sign-based methods. Of the two components of adam identified in §1.1, the sign aspect seems to be by far the dominant one, accounting for most of the difference between adam and m-sgd. adam and m-ssd display surprisingly similar performance; an observation that might inform practitioners’ choice of algorithm, especially for very high-dimensional problems, where adam’s additional memory requirements are an issue.

Considering only training loss, the two sign-based methods clearly outperform the two non-sign-based methods on problems (P1) and (P3). On (P2), adam and m-ssd make rapid initial progress, but later plateau and are undercut by m-sgd and m-svag. On the language modelling task (P4) the non-sign-based methods show superior performance. Relating to our analysis in Section 2, this shows that the usefulness of sign-based methods depends on the particular problem at hand.

In all experiments, the variance-adapted variants perform at least as well as, and often better than, their “base algorithms”. The magnitude of the effect varies. For example, adam and m-ssd have identical performance on (P3), but m-svag significantly outperforms m-sgd on (P3) as well as (P4).

The cifar-100 example (P3) displays similar effects as reported by Wilson et al. (2017): adam vastly outperforms m-sgd in training loss, but has significantly worse test performance. Observe that m-ssd behaves almost identically to adam in both train and test and, thus, displays the same generalization-harming effects. m-svag, on the other hand, improves upon m-sgd and, in particular, does not display any adverse effects on generalization. This corroborates the suspicion raised in §5 that the generalization-harming effects of adam are caused by the sign aspect rather than the element-wise adaptive step sizes.

7 Conclusion

We have argued that adam combines two components: taking signs and variance adaptation. Our experiments show that the sign aspect is by far the dominant one, but its usefulness is problem-dependent. Our theoretical analysis suggests that it depends on the interplay of stochasticity, the conditioning of the problem, and its axis-alignment. Sign-based methods also seem to have an adverse effect on the generalization performance of the obtained solution; a possible starting point for further research into the generalization effects of optimization algorithms.

The second aspect, variance adaptation, is not restricted to adam but can be applied to any update direction. We have provided a general motivation for variance adaptation factors that is independent of the update direction. In particular, we introduced m-svag, a variance-adapted variant of momentum sgd, which is a useful addition to the practitioner’s toolbox for problems where sign-based methods like adam fail. A TensorFlow (Abadi et al., 2015) implementation can be found at https://github.com/lballes/msvag.

Acknowledgements

The authors thank Maren Mahsereci for helpful discussions. The authors acknowledge financial support by the European Research Council through ERC StG Action 757275 / PANAMA during a part of the project. Lukas Balles kindly acknowledges the support of the International Max Planck Research School for Intelligent Systems (IMPRS-IS).

References

—Supplementary Material—

Appendix A Experiments

We trained a simple convolutional neural network with two convolutional layers (size 5×5, 32 and 64 filters, respectively), each followed by max-pooling over 3×3 areas with stride 2, and a fully-connected layer with 1024 units. ReLU activation was used for all layers. The output layer has 10 units with softmax activation. We used cross-entropy loss, without any additional regularization, and a mini-batch size of 64. We trained for a total of 6000 steps with a constant global step size $\alpha$.

We trained a CNN with three convolutional layers (64 filters of size 5×5, 96 filters of size 3×3, and 128 filters of size 3×3) interspersed with max-pooling over 3×3 areas with stride 2 and followed by two fully-connected layers with 512 and 256 units. ReLU activation was used for all layers. The output layer has 10 units with softmax activation. We used the cross-entropy loss function and applied $L_{2}$-regularization on all weights, but not the biases. During training we performed some standard data augmentation operations (random cropping of sub-images, left-right mirroring, color distortion) on the input images. We used a batch size of 128 and trained for a total of 40k steps with a constant global step size $\alpha$.

We use the WRN-40-4 architecture of Zagoruyko & Komodakis (2016); details can be found in the original paper. We used cross-entropy loss and applied $L_{2}$-regularization on all weights, but not the biases. We used the same data augmentation operations as for cifar-10, a batch size of 128, and trained for 80k steps. For the global step size $\alpha$, we used the decrease schedule suggested by Zagoruyko & Komodakis (2016), which amounts to multiplying with a factor of 0.2 after 24k, 48k, and 64k steps. TensorFlow code was adapted from https://github.com/dalgu90/wrn-tensorflow.

We preprocessed War and Peace, extracting a vocabulary of 83 characters. The language model is a two-layer LSTM with 128 hidden units each. We used a sequence length of 50 characters and a batch size of 50. Drop-out regularization was applied during training. We trained for 200k steps; the global step size $\alpha$ was multiplied with a factor of 0.1 after 125k steps. TensorFlow code was adapted from https://github.com/sherjilozair/char-rnn-tensorflow.

A.2 Step Size Tuning

Step sizes $\alpha$ (initial step sizes for the experiments with a step size decrease schedule) for each optimizer have been tuned by first finding the maximal stable step size by trial and error and then searching downwards over multiple orders of magnitude, testing $6\cdot 10^{m}$, $3\cdot 10^{m}$, and $1\cdot 10^{m}$ for order of magnitude $m$. We evaluated loss and accuracy on the full test set (as well as on an equally-sized portion of the training set) at a constant interval and selected the best-performing step size for each method in terms of maximally reached test accuracy. Using the best choice, we replicated the experiment ten times with different random seeds, randomizing the parameter initialization, data set shuffling, drop-out, et cetera. In some rare cases where the accuracies for two different step sizes were very close, we replicated both and then chose the one with the higher maximum mean accuracy.

The following list shows all explored step sizes, with the “winner” in bold face.

Problem 1: Fashion-mnist m-sgd: 3,1,6101,3101,1101,6102,3102,1102,6103,31033,1,6\cdot 10^{-1},3\cdot 10^{-1},\mathbf{1\cdot 10^{-1}},6\cdot 10^{-2},3\cdot 10^{-2},1\cdot 10^{-2},6\cdot 10^{-3},3\cdot 10^{-3} adam: 3102,102,6103,3103,1103,6104,3104,11043\cdot 10^{-2},10^{-2},6\cdot 10^{-3},3\cdot 10^{-3},\mathbf{1\cdot 10^{-3}},6\cdot 10^{-4},3\cdot 10^{-4},1\cdot 10^{-4} m-ssd: 102,6103,3103,1103,6104,3104,110410^{-2},6\cdot 10^{-3},3\cdot 10^{-3},1\cdot 10^{-3},6\cdot 10^{-4},\mathbf{3\cdot 10^{-4}},1\cdot 10^{-4} m-svag: 3,1,6101,3101,1101,6102,3102,1102,6103,31033,1,6\cdot 10^{-1},\mathbf{3\cdot 10^{-1}},1\cdot 10^{-1},6\cdot 10^{-2},3\cdot 10^{-2},1\cdot 10^{-2},6\cdot 10^{-3},3\cdot 10^{-3}

Problem 2: cifar-10 m-sgd: 6101,3101,1101,6102,3102,1102,6103,31036\cdot 10^{-1},3\cdot 10^{-1},1\cdot 10^{-1},6\cdot 10^{-2},\mathbf{3\cdot 10^{-2}},1\cdot 10^{-2},6\cdot 10^{-3},3\cdot 10^{-3} adam: 6103,3103,1103,6104,3104,1104,61056\cdot 10^{-3},3\cdot 10^{-3},1\cdot 10^{-3},\mathbf{6\cdot 10^{-4}},3\cdot 10^{-4},1\cdot 10^{-4},6\cdot 10^{-5} m-ssd: 6103,3103,1103,6104,3104,1104,6105,31056\cdot 10^{-3},3\cdot 10^{-3},1\cdot 10^{-3},6\cdot 10^{-4},3\cdot 10^{-4},\mathbf{1\cdot 10^{-4}},6\cdot 10^{-5},3\cdot 10^{-5} m-svag: 1,6101,3101,1101,6102,3102,1102,61031,6\cdot 10^{-1},3\cdot 10^{-1},1\cdot 10^{-1},\mathbf{6\cdot 10^{-2}},3\cdot 10^{-2},1\cdot 10^{-2},6\cdot 10^{-3}

Problem 3: cifar-100 m-sgd: 6,3,1,6101,3101,1101,6102,3102,11026,\mathbf{3},1,6\cdot 10^{-1},3\cdot 10^{-1},1\cdot 10^{-1},6\cdot 10^{-2},\mathbf{3\cdot 10^{-2}},1\cdot 10^{-2} adam: 1102,6103,3103,1103,6104,3104,1104,6105,31051\cdot 10^{-2},6\cdot 10^{-3},3\cdot 10^{-3},1\cdot 10^{-3},6\cdot 10^{-4},\mathbf{3\cdot 10^{-4}},1\cdot 10^{-4},6\cdot 10^{-5},3\cdot 10^{-5} m-ssd: 1102,6103,3103,1103,6104,3104,1104,6105,31051\cdot 10^{-2},6\cdot 10^{-3},3\cdot 10^{-3},1\cdot 10^{-3},6\cdot 10^{-4},3\cdot 10^{-4},\mathbf{1\cdot 10^{-4}},6\cdot 10^{-5},3\cdot 10^{-5} m-svag: 6,3,1,6101,3101,1101,6102,3102,11026,\mathbf{3},1,6\cdot 10^{-1},3\cdot 10^{-1},1\cdot 10^{-1},6\cdot 10^{-2},\mathbf{3\cdot 10^{-2}},1\cdot 10^{-2}

Problem 4: War and Peace
m-sgd: $10, 6, \mathbf{3}, 1, 6\cdot 10^{-1}, 3\cdot 10^{-1}, 1\cdot 10^{-1}, 6\cdot 10^{-2}$
adam: $1\cdot 10^{-2}, 6\cdot 10^{-3}, \mathbf{3\cdot 10^{-3}}, 1\cdot 10^{-3}, 6\cdot 10^{-4}, 3\cdot 10^{-4}, 1\cdot 10^{-4}, 6\cdot 10^{-5}$
m-ssd: $1\cdot 10^{-2}, 6\cdot 10^{-3}, 3\cdot 10^{-3}, \mathbf{1\cdot 10^{-3}}, 6\cdot 10^{-4}, 3\cdot 10^{-4}, 1\cdot 10^{-4}, 6\cdot 10^{-5}$
m-svag: $30, \mathbf{10}, 6, 3, 1, 6\cdot 10^{-1}, 3\cdot 10^{-1}, 1\cdot 10^{-1}$

Appendix B Mathematical Details

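We have stated in the main text that the sign of a stochastic gradient, $s(\theta)=\operatorname{sign}(g(\theta))$, has success probabilities
$\rho_{i}:=\mathbf{P}[s(\theta)_{i}=\operatorname{sign}(\nabla\mathcal{L}(\theta)_{i})]=\frac{1}{2}+\frac{1}{2}\operatorname{erf}\left(\frac{|\nabla\mathcal{L}(\theta)_{i}|}{\sqrt{2}\,\sigma_{i}(\theta)}\right)$

under the assumption that $g\sim\mathcal{N}(\nabla\mathcal{L},\Sigma)$. The following Lemma formally proves this statement, and Figure 6 provides a pictorial illustration.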

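If $X\sim\mathcal{N}(\mu,\sigma^{2})$, then
$\mathbf{P}[\operatorname{sign}(X)=\operatorname{sign}(\mu)]=\frac{1}{2}\left(1+\operatorname{erf}\left(\frac{|\mu|}{\sqrt{2}\sigma}\right)\right).$

Define $\rho:=\mathbf{P}[\operatorname{sign}(X)=\operatorname{sign}(\mu)]$. The cumulative distribution function (cdf) of $X\sim\mathcal{N}(\mu,\sigma^{2})$ is $\mathbf{P}[X\leq x]=\Phi((x-\mu)/\sigma)$, where $\Phi(z)=0.5(1+\operatorname{erf}(z/\sqrt{2}))$ is the cdf of the standard normal distribution. If $\mu<0$, then
$\rho=\mathbf{P}[X<0]=\Phi(-\mu/\sigma)=\frac{1}{2}\left(1+\operatorname{erf}\left(\frac{-\mu}{\sqrt{2}\sigma}\right)\right)=\frac{1}{2}\left(1+\operatorname{erf}\left(\frac{|\mu|}{\sqrt{2}\sigma}\right)\right).$
If $\mu>0$, then
$\rho=\mathbf{P}[X>0]=1-\Phi(-\mu/\sigma)=\frac{1}{2}\left(1-\operatorname{erf}\left(\frac{-\mu}{\sqrt{2}\sigma}\right)\right)=\frac{1}{2}\left(1+\operatorname{erf}\left(\frac{|\mu|}{\sqrt{2}\sigma}\right)\right),$

where the last step used the anti-symmetry of the error function, $\operatorname{erf}(-z)=-\operatorname{erf}(z)$. ∎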
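The lemma is easy to sanity-check numerically. The following Python sketch (standard library only; the function names are illustrative, not from the paper) compares the closed form $\frac{1}{2}(1+\operatorname{erf}(|\mu|/(\sqrt{2}\sigma)))$ with a Monte Carlo estimate of $\mathbf{P}[\operatorname{sign}(X)=\operatorname{sign}(\mu)]$ for a few illustrative choices of $\mu\neq 0$ and $\sigma$.

```python
import math
import random


def sign_agreement_prob(mu: float, sigma: float) -> float:
    """Closed form from the lemma: P[sign(X) = sign(mu)] for X ~ N(mu, sigma^2), mu != 0."""
    return 0.5 * (1.0 + math.erf(abs(mu) / (math.sqrt(2.0) * sigma)))


def sign_agreement_mc(mu: float, sigma: float, n: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of the same probability from n Gaussian draws."""
    rng = random.Random(seed)
    target = math.copysign(1.0, mu)
    hits = sum(1 for _ in range(n) if math.copysign(1.0, rng.gauss(mu, sigma)) == target)
    return hits / n


for mu, sigma in [(0.5, 1.0), (-0.5, 1.0), (2.0, 0.5)]:
    exact = sign_agreement_prob(mu, sigma)
    approx = sign_agreement_mc(mu, sigma)
    print(f"mu={mu:+.1f} sigma={sigma:.1f}  closed form={exact:.4f}  Monte Carlo={approx:.4f}")
```

The probability depends on $\mu$ and $\sigma$ only through $|\mu|/\sigma$, which is why it is symmetric in the sign of $\mu$.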

B.2 Analysis on Stochastic QPs

We derive the expressions in Eq. (13), dropping the fixed $\theta$ from the notation for readability.
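The sgd derivation below rests on the general identity $\mathbf{E}[g^{T}Qg]=\nabla\mathcal{L}^{T}Q\nabla\mathcal{L}+\operatorname{tr}(Q\mathbf{cov}[g])$ combined with $\mathbf{cov}[g]=\nu^{2}QQ$. A minimal Monte Carlo sketch in Python checks the resulting expression $\nabla\mathcal{L}^{T}Q\nabla\mathcal{L}+\nu^{2}\operatorname{tr}(QQQ)$ on a 2×2 example. For illustration only, it assumes the stochastic gradient is generated as $g=\nabla\mathcal{L}+\nu Qz$ with $z\sim\mathcal{N}(0,I)$, which has exactly this mean and covariance; the matrix, gradient and $\nu$ are arbitrary values.

```python
import random


def matvec(A, x):
    """Matrix-vector product for small dense lists-of-lists."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]


def matmul(A, B):
    """Matrix-matrix product for small dense lists-of-lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def quad_form(x, A):
    """x^T A x."""
    return sum(xi * yi for xi, yi in zip(x, matvec(A, x)))


Q = [[2.0, 0.5], [0.5, 1.0]]   # symmetric positive definite (illustrative)
grad = [1.0, -2.0]             # true gradient (illustrative)
nu = 0.3                       # noise scale (illustrative)

# Closed form: grad^T Q grad + nu^2 * tr(QQQ); tr(QQQ) is the sum of cubed eigenvalues.
QQQ = matmul(Q, matmul(Q, Q))
closed_form = quad_form(grad, Q) + nu ** 2 * (QQQ[0][0] + QQQ[1][1])

# Monte Carlo: g = grad + nu * Q z has mean grad and covariance nu^2 * Q Q^T = nu^2 * Q Q.
rng = random.Random(0)
n = 100_000
total = 0.0
for _ in range(n):
    z = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
    g = [gi + nu * ni for gi, ni in zip(grad, matvec(Q, z))]
    total += quad_form(g, Q)
monte_carlo = total / n

print(f"closed form: {closed_form:.4f}, Monte Carlo: {monte_carlo:.4f}")
```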

For sgd, we have $\mathbf{E}[g]=\nabla\mathcal{L}$ and $\mathbf{E}[g^{T}Qg]=\nabla\mathcal{L}^{T}Q\nabla\mathcal{L}+\operatorname{tr}(Q\mathbf{cov}[g])$, which is a general fact for quadratic forms of random variables. For the stochastic QP, the gradient covariance is $\mathbf{cov}[g]=\nu^{2}QQ$, thus $\operatorname{tr}(Q\mathbf{cov}[g])=\nu^{2}\operatorname{tr}(QQQ)=\nu^{2}\sum_{i}\lambda_{i}^{3}$, where $\lambda_{i}$ are the eigenvalues of $Q$. Plugging everything into Eq. (12) yields

For stochastic sign descent, $s=\operatorname{sign}(g)$, we have $\mathbf{E}[s_{i}]=(2\rho_{i}-1)\operatorname{sign}(\nabla\mathcal{L}_{i})$ and thus $\nabla\mathcal{L}^{T}\mathbf{E}[s]=\sum_{i=1}^{d}\nabla\mathcal{L}_{i}\mathbf{E}[s_{i}]=\sum_{i}(2\rho_{i}-1)|\nabla\mathcal{L}_{i}|$. Regarding the denominator, it is

since $|s_{i}|=1$. Further, by definition of $p_{\text{diag}}(Q)$, we have $\sum_{i,j=1}^{d}|q_{ij}|=p_{\text{diag}}(Q)^{-1}\sum_{i=1}^{d}|q_{ii}|$. Since $Q$ is positive definite, its diagonal elements are positive, such that $\sum_{i=1}^{d}|q_{ii}|=\sum_{i=1}^{d}q_{ii}=\sum_{i=1}^{d}\lambda_{i}$, where the last step uses that the trace equals the sum of the eigenvalues. Plugging everything into Eq. (12) yields

As mentioned before, $\sum_{i}|q_{ii}|=\sum_{i}\lambda_{i}$. Hence,

As we have already seen, the best case arises if the eigenvectors are axis-aligned (diagonal $Q$), resulting in $\|v_{i}\|_{1}=\|v_{i}\|_{2}=1$.
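The role of the orientation can be explored with a short Python sketch (helper names are illustrative). It computes the diagonal ratio $p_{\text{diag}}(Q)=\sum_{i}|q_{ii}|/\sum_{i,j}|q_{ij}|$, as implied by its use in the derivation above, confirms that a diagonal $Q$ attains the best case $p_{\text{diag}}=1$, checks the bound $s^{T}Qs\leq\sum_{i,j}|q_{ij}|$ for every sign vector $s$ of an illustrative dense matrix, and estimates $\|w\|_{1}/\|w\|_{2}$ for randomly oriented vectors, which the next paragraph approximates by $\sqrt{2d/\pi}$.

```python
import itertools
import math
import random


def p_diag(Q):
    """Share of the total absolute mass of Q that sits on its diagonal."""
    diag = sum(abs(Q[i][i]) for i in range(len(Q)))
    total = sum(abs(q) for row in Q for q in row)
    return diag / total


def max_sign_quadratic(Q):
    """Largest value of s^T Q s over all sign vectors s in {-1, +1}^d (brute force)."""
    d = len(Q)
    best = -math.inf
    for s in itertools.product((-1.0, 1.0), repeat=d):
        val = sum(s[i] * Q[i][j] * s[j] for i in range(d) for j in range(d))
        best = max(best, val)
    return best


def mean_l1_l2_ratio(d, trials=200, seed=0):
    """Average of ||w||_1 / ||w||_2 for w ~ N(0, I_d), i.e. a random orientation."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        w = [rng.gauss(0.0, 1.0) for _ in range(d)]
        total += sum(abs(x) for x in w) / math.sqrt(sum(x * x for x in w))
    return total / trials


Q_diag = [[3.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]]
Q_dense = [[3.0, 1.0, -0.5], [1.0, 2.0, 0.8], [-0.5, 0.8, 1.0]]  # illustrative SPD matrix

print("p_diag:", p_diag(Q_diag), p_diag(Q_dense))
print("max s^T Q s vs. sum |q_ij|:",
      max_sign_quadratic(Q_dense), sum(abs(q) for row in Q_dense for q in row))
for d in (100, 1000):
    print(d, mean_l1_l2_ratio(d), math.sqrt(2 * d / math.pi))
```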

We can get a rough intuition for the average case from the following consideration: for a $d$-dimensional random vector $w\sim\mathcal{N}(0,I)$, which corresponds to a random orientation, we have
$\mathbf{E}[\|w\|_{1}]=d\sqrt{2/\pi},\qquad\mathbf{E}[\|w\|_{2}]\approx\sqrt{d}.$

As a rough approximation, we can thus assume that a randomly-oriented vector will satisfy $\|w\|_{1}\approx\sqrt{2d/\pi}\|w\|_{2}$. Plugging that in for the eigenvectors of $Q$ in Eq. (35) yields an approximate average case value of

B.3 Variance Adaptation Factors
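The two optimal factors derived in this subsection can be checked numerically. The Python sketch below assumes, as the derivation suggests, that the factors minimize the expected squared error: $\mathbf{E}[(\gamma_{i}\hat{p}_{i}-p_{i})^{2}]=\gamma_{i}^{2}(p_{i}^{2}+\sigma_{i}^{2})-2\gamma_{i}p_{i}^{2}+p_{i}^{2}$ for the plain estimate and $\gamma_{i}^{2}-2\gamma_{i}(2\rho_{i}-1)+1$ for its sign. It compares a brute-force grid minimum with the closed forms $p_{i}^{2}/(p_{i}^{2}+\sigma_{i}^{2})$ and $2\rho_{i}-1$, the latter evaluated under a Gaussian assumption as $\operatorname{erf}(|p_{i}|/(\sqrt{2}\sigma_{i}))$. The function names and the values of $p$ and $\sigma$ are illustrative.

```python
import math


def sq_error(gamma, p, sigma2):
    """E[(gamma * p_hat - p)^2], using E[p_hat] = p and E[p_hat^2] = p^2 + sigma^2."""
    return gamma ** 2 * (p ** 2 + sigma2) - 2.0 * gamma * p ** 2 + p ** 2


def sign_sq_error(gamma, rho):
    """E[(gamma * sign(p_hat) - sign(p))^2], using E[sign(p_hat)] = (2 rho - 1) sign(p)."""
    return gamma ** 2 - 2.0 * gamma * (2.0 * rho - 1.0) + 1.0


def grid_argmin(f, lo=0.0, hi=1.0, steps=100_000):
    """Brute-force minimizer of a scalar function on a uniform grid over [lo, hi]."""
    grid = (lo + (hi - lo) * k / steps for k in range(steps + 1))
    return min(grid, key=f)


p, sigma = 0.8, 0.6                                               # illustrative values
rho = 0.5 * (1.0 + math.erf(abs(p) / (math.sqrt(2.0) * sigma)))  # Gaussian success probability

gamma_plain = p ** 2 / (p ** 2 + sigma ** 2)   # = 1 / (1 + eta^2) with eta = sigma / |p|
gamma_sign = 2.0 * rho - 1.0                   # = erf(|p| / (sqrt(2) sigma))

print(gamma_plain, grid_argmin(lambda g: sq_error(g, p, sigma ** 2)))
print(gamma_sign, grid_argmin(lambda g: sign_sq_error(g, rho)))
```

Both objectives are quadratics in $\gamma_{i}$, so the grid search only confirms the stationary point obtained by setting the derivative to zero.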

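Using $\mathbf{E}[\hat{p}_{i}]=p_{i}$ and $\mathbf{E}[\hat{p}_{i}^{2}]=p_{i}^{2}+\sigma_{i}^{2}$, we get
$\mathbf{E}[(\gamma_{i}\hat{p}_{i}-p_{i})^{2}]=\gamma_{i}^{2}(p_{i}^{2}+\sigma_{i}^{2})-2\gamma_{i}p_{i}^{2}+p_{i}^{2}.$

Setting the derivative w.r.t. $\gamma_{i}$ to zero, we find the optimal choice
$\gamma_{i}=\frac{p_{i}^{2}}{p_{i}^{2}+\sigma_{i}^{2}}.$

For the second part, using $\mathbf{E}[\operatorname{sign}(\hat{p}_{i})]=(2\rho_{i}-1)\operatorname{sign}(p_{i})$ and $\operatorname{sign}(\cdot)^{2}=1$, we get
$\mathbf{E}[(\gamma_{i}\operatorname{sign}(\hat{p}_{i})-\operatorname{sign}(p_{i}))^{2}]=\gamma_{i}^{2}-2\gamma_{i}(2\rho_{i}-1)+1,$
which is minimized by $\gamma_{i}=2\rho_{i}-1$.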

B.4 Convergence of Idealized SVAG

Regarding the first inequality, we use $\nabla f(\theta_{\ast})=0$ and the Lipschitz continuity of $\nabla f(\cdot)$ to get $\|\nabla f(\theta)\|^{2}=\|\nabla f(\theta)-\nabla f(\theta_{\ast})\|^{2}\leq L^{2}\|\theta-\theta_{\ast}\|^{2}$. Using strong convexity, we have $f(\theta)\geq f_{\ast}+\nabla f(\theta_{\ast})^{T}(\theta-\theta_{\ast})+(\mu/2)\|\theta-\theta_{\ast}\|^{2}=f_{\ast}+(\mu/2)\|\theta-\theta_{\ast}\|^{2}$. Combining the two inequalities yields the desired inequality, $\|\nabla f(\theta)\|^{2}\leq\frac{2L^{2}}{\mu}(f(\theta)-f_{\ast})$.

The second inequality arises from strong convexity, by minimizing both sides of
$f(\theta^{\prime})\geq f(\theta)+\nabla f(\theta)^{T}(\theta^{\prime}-\theta)+\frac{\mu}{2}\|\theta^{\prime}-\theta\|^{2}$

w.r.t. $\theta^{\prime}$. The left-hand side has minimal value $f_{\ast}$. For the right-hand side, we set its derivative, $\nabla f(\theta)+\mu(\theta^{\prime}-\theta)$, to zero to find the minimizer $\theta^{\prime}=\theta-\nabla f(\theta)/\mu$. Plugging that back in yields the minimal value $f(\theta)-\|\nabla f(\theta)\|^{2}/(2\mu)$. Hence $f_{\ast}\geq f(\theta)-\|\nabla f(\theta)\|^{2}/(2\mu)$, i.e., $\|\nabla f(\theta)\|^{2}\geq 2\mu(f(\theta)-f_{\ast})$. ∎
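Both inequalities can be checked on a simple $\mu$-strongly convex, $L$-smooth test function. Purely for illustration, the Python sketch below uses a diagonal quadratic $f(\theta)=\frac{1}{2}\sum_{i}a_{i}\theta_{i}^{2}$ with curvatures $a_{i}\in[\mu,L]$ (so $\theta_{\ast}=0$ and $f_{\ast}=0$) and verifies $2\mu(f(\theta)-f_{\ast})\leq\|\nabla f(\theta)\|^{2}\leq\frac{2L^{2}}{\mu}(f(\theta)-f_{\ast})$ at random points.

```python
import random

mu, L = 0.5, 4.0                 # strong convexity and smoothness constants (illustrative)
a = [0.5, 1.0, 2.5, 4.0]         # per-coordinate curvatures, all within [mu, L]


def f(theta):
    """Diagonal quadratic with minimizer theta_* = 0 and minimum f_* = 0."""
    return 0.5 * sum(ai * t * t for ai, t in zip(a, theta))


def grad_sq_norm(theta):
    """Squared Euclidean norm of the gradient (a_i * theta_i)_i."""
    return sum((ai * t) ** 2 for ai, t in zip(a, theta))


rng = random.Random(0)
violations = 0
tol = 1.0 + 1e-12                # slack for floating-point rounding in the equality cases
for _ in range(10_000):
    theta = [rng.uniform(-5.0, 5.0) for _ in a]
    gap = f(theta)               # f(theta) - f_*
    g2 = grad_sq_norm(theta)
    lower_ok = 2.0 * mu * gap <= g2 * tol
    upper_ok = g2 <= (2.0 * L ** 2 / mu) * gap * tol
    if not (lower_ok and upper_ok):
        violations += 1

print("violations:", violations)
```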

Using the Lipschitz continuity of $\nabla f$, we can bound $f(\theta+\Delta\theta)\leq f(\theta)+\nabla f(\theta)^{T}\Delta\theta+\frac{L}{2}\|\Delta\theta\|^{2}$. Hence,

Plugging in the definition $\gamma_{t,i}=\nabla f_{t,i}^{2}/(\nabla f_{t,i}^{2}+\sigma_{t,i}^{2})$ and simplifying, we get

This shows that $\mathbf{E}_{t}[f_{t+1}]\leq f_{t}$. Defining $e_{t}:=f_{t}-f_{\ast}$, this implies

Using the assumption $\sum_{i=1}^{d}\sigma_{t,i}^{2}\leq c_{v}\|\nabla f_{t}\|^{2}+M_{v}$ in the denominator, we obtain

where the last equality defines the (positive) constants $c_{1},c_{2}$ and $c_{3}$. Combining Eqs. (48), (49) and (51), inserting in (46), and subtracting $f_{\ast}$ from both sides, we obtain

and, consequently, by taking expectations on both sides,

where the last step is due to Jensen's inequality applied to the function $\phi(x)=\frac{c_{1}x^{2}}{c_{2}x+c_{3}}$, which is convex on $[0,\infty)$. Using $\mathbf{E}[e_{t}]\leq e_{0}$ in the denominator and introducing the shorthand $\bar{e}_{t}:=\mathbf{E}[e_{t}]$, we get

with $c:=c_{1}/(2L(c_{2}e_{0}+c_{3}))>0$. To conclude the proof, we show that this implies $\bar{e}_{t}\in\mathcal{O}(\frac{1}{t})$. Without loss of generality, we assume $\bar{e}_{t+1}>0$ and obtain

where the second step is due to the simple fact that $(1-x)^{-1}\geq(1+x)$ for any $x\in[0,1)$. Summing this inequality over $t=0,\dotsc,T-1$ yields $\bar{e}_{T}^{-1}\geq e_{0}^{-1}+Tc$ and, thus,

which shows that $\bar{e}_{t}\in\mathcal{O}(\frac{1}{t})$. ∎
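The final step can be illustrated by iterating the worst case of the recursion, $\bar{e}_{t+1}=\bar{e}_{t}-c\,\bar{e}_{t}^{2}$, and comparing it with the bound $\bar{e}_{T}\leq 1/(e_{0}^{-1}+Tc)$. The constants in this Python sketch are illustrative; they satisfy $c\,e_{0}<1$, which keeps every iterate in $(0,e_{0}]$, so that the step $(1-x)^{-1}\geq 1+x$ with $x=c\,\bar{e}_{t}\in[0,1)$ applies.

```python
e0, c = 2.0, 0.3        # illustrative constants with c * e0 < 1
T_max = 10_000

e = e0
ok = True
for T in range(1, T_max + 1):
    e = e - c * e * e                     # worst case of e_{t+1} <= e_t - c e_t^2
    bound = 1.0 / (1.0 / e0 + T * c)      # the O(1/T) bound from the proof
    ok = ok and (0.0 < e <= bound)

print(f"e_T = {e:.6f}, bound = {1.0 / (1.0 / e0 + T_max * c):.6f}, T * e_T = {T_max * e:.4f}")
```

Since $\bar{e}_{T}<1/(Tc)$, the product $T\bar{e}_{T}$ stays below $1/c$, which is the $\mathcal{O}(\frac{1}{t})$ statement in concrete form.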

B.5 Gradient Variance Estimates via Moving Averages

with coefficients $c(\beta_{1},t,s)$ summing to one by the geometric sum formula, making $m_{t}$ a convex combination of stochastic gradients. Likewise, $v_{t}=\sum_{s=0}^{t}c(\beta_{2},t,s)g_{s}^{2}$ is a convex combination of squared stochastic gradients. Hence,

Assumption 1 thus implies $\mathbf{E}[g_{s,i}]\approx\nabla\mathcal{L}_{t,i}$ and $\mathbf{E}[g_{s,i}^{2}]\approx\nabla\mathcal{L}_{t,i}^{2}+\sigma_{t,i}^{2}$. (This will of course be badly wrong for gradient observations far in the past, but these do not contribute significantly to the moving average.) It follows that

where the second step is due to the fact that $g_{s}$ and $g_{s^{\prime}}$ are stochastically independent for $s\neq s^{\prime}$. The last term evaluates to

where the fourth step is another application of the geometric sum formula, and the fifth step uses $1-x^{2}=(1-x)(1+x)$. Note that

such that, for large $t$, $\rho(\beta,t)$ depends only on $\beta$.
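The coefficients and $\rho(\beta,t)$ can be made concrete with a short Python sketch. It assumes $c(\beta,t,s)=(1-\beta)\beta^{t-s}/(1-\beta^{t+1})$, i.e., the weights of an exponential moving average with adam-style bias correction, which sum to one by the geometric sum formula. Under that assumption, the same two steps as above give $\rho(\beta,t)=\sum_{s}c(\beta,t,s)^{2}=\frac{(1-\beta)(1+\beta^{t+1})}{(1+\beta)(1-\beta^{t+1})}$, which tends to $(1-\beta)/(1+\beta)$ as $t\to\infty$. The code checks these statements and that the recursive bias-corrected average equals the explicit convex combination; $\beta$, $t$ and the data are illustrative.

```python
import random


def ema_weights(beta, t):
    """Assumed bias-corrected EMA weights c(beta, t, s) for s = 0, ..., t."""
    norm = 1.0 - beta ** (t + 1)
    return [(1.0 - beta) * beta ** (t - s) / norm for s in range(t + 1)]


def rho(beta, t):
    """Closed form of sum_s c(beta, t, s)^2."""
    return (1.0 - beta) * (1.0 + beta ** (t + 1)) / ((1.0 + beta) * (1.0 - beta ** (t + 1)))


def bias_corrected_ema(xs, beta):
    """Recursive EMA m_t = beta m_{t-1} + (1 - beta) x_t, divided by 1 - beta^(t+1)."""
    m = 0.0
    for x in xs:
        m = beta * m + (1.0 - beta) * x
    return m / (1.0 - beta ** len(xs))


beta, t = 0.9, 25
w = ema_weights(beta, t)
rng = random.Random(0)
xs = [rng.gauss(1.0, 0.5) for _ in range(t + 1)]

print(sum(w))                                             # weights sum to one
print(sum(c * c for c in w), rho(beta, t))                # explicit vs closed form
print(rho(beta, 10_000), (1.0 - beta) / (1.0 + beta))     # long-term limit
print(bias_corrected_ema(xs, beta), sum(c * x for c, x in zip(w, xs)))
```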

As an interesting side note, the division by $1-\rho(\beta,t)$ in Eq. (22) is the analogue of Bessel's correction (the use of $n-1$ instead of $n$ in the classical sample variance) for the case where we use moving averages instead of arithmetic means.
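The Bessel-type correction can be observed in simulation. The Python sketch below assumes i.i.d. Gaussian gradients for a single coordinate (mean and variance are illustrative), forms bias-corrected moving averages $m_{t}$ and $v_{t}$ with a shared $\beta$, and averages the variance estimate $v_{t}-m_{t}^{2}$ over many independent runs, with and without division by $1-\rho(\beta,t)$. Under these assumptions, $\mathbf{E}[v_{t}]=\mu^{2}+\sigma^{2}$ and $\mathbf{E}[m_{t}^{2}]=\mu^{2}+\rho(\beta,t)\sigma^{2}$, so $\mathbf{E}[v_{t}-m_{t}^{2}]=(1-\rho(\beta,t))\sigma^{2}$ and only the corrected estimate should be centred on $\sigma^{2}$. The closed form for $\rho$ again assumes bias-corrected weights.

```python
import random


def rho(beta, t):
    """Sum of squared bias-corrected EMA weights after t + 1 observations."""
    return (1.0 - beta) * (1.0 + beta ** (t + 1)) / ((1.0 + beta) * (1.0 - beta ** (t + 1)))


def variance_estimates(mu, sigma, beta, t, runs, seed=0):
    """Average raw and corrected moving-average variance estimates over independent runs."""
    rng = random.Random(seed)
    raw_total = 0.0
    for _ in range(runs):
        m = v = 0.0
        for _ in range(t + 1):
            g = rng.gauss(mu, sigma)
            m = beta * m + (1.0 - beta) * g
            v = beta * v + (1.0 - beta) * g * g
        correction = 1.0 - beta ** (t + 1)        # adam-style bias correction
        m_hat, v_hat = m / correction, v / correction
        raw_total += v_hat - m_hat ** 2
    raw = raw_total / runs
    return raw, raw / (1.0 - rho(beta, t))


mu, sigma, beta, t = 1.0, 0.5, 0.9, 19
raw, corrected = variance_estimates(mu, sigma, beta, t, runs=20_000)
print(f"true variance {sigma**2:.4f}, raw {raw:.4f}, corrected {corrected:.4f}")
```

As with $n-1$ in the classical sample variance, the correction matters most when few observations effectively contribute, here roughly $1/\rho(\beta,t)$ of them.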

B.6 Connection to Generalization
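The inductive argument that follows can be illustrated on a tiny example. The Python sketch below rests on assumptions not stated in this appendix: it takes the objective to be the least-squares loss $\frac{1}{2}\|X\theta-y\|^{2}$ (gradient $X^{T}(X\theta-y)$), uses sign descent on an exponential moving average of gradients as the update rule, and picks a hypothetical $X$, $y$ satisfying $X\operatorname{sign}(X^{T}y)=cy$ with $c=3$. Every iterate should then remain a scalar multiple $\lambda_{t}\operatorname{sign}(X^{T}y)$.

```python
def matvec(A, x):
    """Matrix-vector product for small dense lists-of-lists."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def sign(x):
    """Sign function with sign(0) = 0."""
    return (x > 0) - (x < 0)


X = [[1.0, 0.0, 2.0], [0.0, -1.0, -2.0]]   # hypothetical data matrix
y = [1.0, -1.0]                            # hypothetical targets
Xt = transpose(X)
s_star = [sign(v) for v in matvec(Xt, y)]  # sign(X^T y) = [1, 1, 1]
assert matvec(X, s_star) == [3.0 * yi for yi in y]   # X sign(X^T y) = c y with c = 3

alpha, beta = 0.05, 0.9                    # illustrative step size and averaging constant
theta = [0.0, 0.0, 0.0]                    # theta_0 = 0
m = [0.0, 0.0, 0.0]
iterates = [theta[:]]
for _ in range(200):
    residual = [r - yi for r, yi in zip(matvec(X, theta), y)]
    grad = matvec(Xt, residual)                          # X^T (X theta - y)
    m = [beta * mi + (1.0 - beta) * gi for mi, gi in zip(m, grad)]
    theta = [ti - alpha * sign(mi) for ti, mi in zip(theta, m)]
    iterates.append(theta[:])

# Each iterate should equal lambda_t * sign(X^T y) for some scalar lambda_t.
parallel = all(all(abs(ti - it[0] * si) < 1e-12 for ti, si in zip(it, s_star)) for it in iterates)
print("all iterates parallel to sign(X^T y):", parallel, "final lambda:", iterates[-1][0])
```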

Like in the proof of Lemma 3.1 in Wilson et al. (2017), we inductively show that $\theta_{t}=\lambda_{t}\operatorname{sign}(X^{T}y)$ with a scalar $\lambda_{t}$. This trivially holds for $\theta_{0}=0$. Assume that the assertion holds for all $s\leq t$. Then

where the first step is the gradient of the objective (Eq. 26), the second step uses the inductive assumption, and the third step uses the assumption $X\operatorname{sign}(X^{T}y)=cy$. Now, plugging Eq. (62) into the update rule, we find

Appendix C Alternative Methods

m-svag applies variance adaptation to the update direction $m_{t}$, resulting in the variance adaptation factors in Eq. (25). We can also update in direction $g_{t}$ and choose the appropriate estimated variance adaptation factors, resulting in an implementation of svag without momentum. We have already derived the necessary variance adaptation factors en route to those for the momentum variant, see Eq. (23) in §4.2. Pseudo-code is provided in Alg. 3. It differs from m-svag only in the last two lines.
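For concreteness, here is a minimal Python sketch of both variants on a noisy toy problem. It is not a transcription of Alg. 3, which is not reproduced here; it assumes that the variance estimate is $\hat{s}_{t}=(v_{t}-m_{t}^{2})/(1-\rho(\beta,t))$, computed from bias-corrected moving averages with a shared $\beta$, that m-svag scales $m_{t}$ by $m_{t}^{2}/(m_{t}^{2}+\rho(\beta,t)\hat{s}_{t})$, and that svag scales $g_{t}$ by $m_{t}^{2}/(m_{t}^{2}+\hat{s}_{t})$, so that the two differ only in the final step. The toy objective $\frac{1}{2}\|\theta\|^{2}$ with additive Gaussian gradient noise, the step size, and the guards against division by zero (including at $t=0$, where a single observation carries no variance information) are illustrative choices.

```python
import random


def rho(beta, t):
    """Sum of squared bias-corrected EMA weights after t + 1 observations."""
    return (1.0 - beta) * (1.0 + beta ** (t + 1)) / ((1.0 + beta) * (1.0 - beta ** (t + 1)))


def svag(theta0, alpha=0.1, beta=0.9, steps=500, noise=0.1, momentum=True, seed=0):
    """Variance-adapted updates on f(theta) = 0.5 ||theta||^2 with noisy gradients (sketch)."""
    rng = random.Random(seed)
    theta = list(theta0)
    d = len(theta)
    m, v = [0.0] * d, [0.0] * d
    for t in range(steps):
        g = [th + rng.gauss(0.0, noise) for th in theta]        # stochastic gradient
        m = [beta * mi + (1 - beta) * gi for mi, gi in zip(m, g)]
        v = [beta * vi + (1 - beta) * gi * gi for vi, gi in zip(v, g)]
        corr = 1.0 - beta ** (t + 1)                             # bias correction
        r = rho(beta, t)
        new_theta = []
        for i in range(d):
            m_hat, v_hat = m[i] / corr, v[i] / corr
            # Variance estimate; undefined at t = 0 (r = 1), where this sketch uses 0.
            s_hat = max(v_hat - m_hat ** 2, 0.0) / (1.0 - r) if r < 1.0 else 0.0
            if momentum:   # m-svag: adapt the momentum direction m_t
                denom = m_hat ** 2 + r * s_hat
                step = (m_hat ** 2 / denom if denom > 0 else 0.0) * m_hat
            else:          # svag: adapt the raw stochastic gradient g_t
                denom = m_hat ** 2 + s_hat
                step = (m_hat ** 2 / denom if denom > 0 else 0.0) * g[i]
            new_theta.append(theta[i] - alpha * step)
        theta = new_theta
    return theta


def loss(theta):
    return 0.5 * sum(th * th for th in theta)


start = [3.0, -2.0]
for use_momentum in (True, False):
    final = svag(start, momentum=use_momentum)
    print("m-svag" if use_momentum else "svag", loss(start), "->", loss(final))
```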

C.2 Variants of ADAM

This paper interpreted adam as variance-adapted m-ssd. The experiments in the main paper used a standard implementation of adam as described by Kingma & Ba (2015). However, in the derivation of our implementation of m-svag, we have made multiple adjustments regarding the estimation of variance adaptation factors which correspondingly apply to the sign case. Specifically, this concerns:

The use of the same moving average constant for the first and second moment ($\beta_{1}=\beta_{2}=\beta$).

The bias correction in the gradient variance estimate, see Eq. (22).

The adjustment of the variance adaptation factors for the momentum case, see §4.3.

The omission of a constant offset $\varepsilon$ in the denominator.

Applying these adjustments to the sign case gives rise to a variant of the original adam algorithm, which we will refer to as adam*. Pseudo-code is provided in Alg. 4. Note that we use the variance adaptation factors $(1+\eta^{2})^{-1/2}$ and not the optimal ones derived in §3.1, which under the Gaussian assumption would be $\operatorname{erf}[(\sqrt{2}\eta)^{-1}]$. We initially experimented with both variants and found them to perform almost identically, which is not surprising given how similar the two are (see Fig. 3). We thus stuck with the first option, for direct correspondence with the original adam and to avoid the cumbersome error function.
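The near-equivalence of the two factors can be quantified with a few lines of Python. The sketch first checks the identity behind the adam interpretation, $m/\sqrt{v}=\operatorname{sign}(m)\,(1+\hat{\eta}^{2})^{-1/2}$ with $\hat{\eta}^{2}=(v-m^{2})/m^{2}$, for illustrative values $m\neq 0$ and $v>m^{2}$, and then tabulates $(1+\eta^{2})^{-1/2}$ against $\operatorname{erf}[(\sqrt{2}\eta)^{-1}]$ over a grid of relative noise levels $\eta$ (the grid is an arbitrary choice).

```python
import math


def adam_factor(eta):
    """Variance adaptation factor implied by adam: (1 + eta^2)^(-1/2)."""
    return 1.0 / math.sqrt(1.0 + eta ** 2)


def optimal_sign_factor(eta):
    """Optimal factor for sign-based updates under the Gaussian assumption: erf(1 / (sqrt(2) eta))."""
    return math.erf(1.0 / (math.sqrt(2.0) * eta))


# Identity m / sqrt(v) = sign(m) / sqrt(1 + eta_hat^2) with eta_hat^2 = (v - m^2) / m^2.
m, v = -0.3, 0.25
eta_hat_sq = (v - m ** 2) / m ** 2
lhs = m / math.sqrt(v)
rhs = math.copysign(1.0, m) / math.sqrt(1.0 + eta_hat_sq)
print(lhs, rhs)

etas = [k / 100 for k in range(1, 1001)]            # eta in [0.01, 10]
gaps = [abs(adam_factor(e) - optimal_sign_factor(e)) for e in etas]
print(f"largest gap between the two factors on the grid: {max(gaps):.3f}")
for e in (0.1, 0.5, 1.0, 2.0, 5.0):
    print(f"eta={e:4.1f}  adam={adam_factor(e):.3f}  erf={optimal_sign_factor(e):.3f}")
```

Both factors equal one in the noise-free limit and decay like $1/\eta$ for large $\eta$, so on this grid they never differ by much more than a few hundredths.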

In analogy to svag versus m-svag, we could also define a variance-adapted version of stochastic sign descent without momentum, i.e., using the base update direction $\operatorname{sign}(g_{t})$. We did not explore this further in this work.

C.3 Experiments

We tested svag as well as adam* with and without momentum on the problems (P2) and (P3) from the main paper. Results are shown in Figure 7.

We observe that svag performs better than m-svag on (P2). On (P3), it makes faster initial progress but later plateaus, leading to slightly worse outcomes in both training loss and test accuracy. Overall, svag appears to be a viable alternative. In future work, it will be interesting to apply svag to problems where sgd outperforms m-sgd.

Next, we compare adam* to the original adam algorithm. In the cifar-100 example (P3), the two methods are on par. On (P2), adam is marginally faster in the early stages of the optimization process, but adam* quickly catches up and reaches lower minimal training loss values. We conclude that the adjustments to the variance adaptation factors derived in §4 do have a positive effect.

Appendix D Mini-Batch Gradient Variance Estimates

Several recent papers (Mahsereci & Hennig, 2015; Balles et al., 2017b) have used this variance estimate for other aspects of stochastic optimization. In contrast to the moving average-based estimators, this is an unbiased estimate of the local gradient variance. The (non-trivial) implementation of this estimator for neural networks is described in Balles et al. (2017a).
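The estimator referred to here is not reproduced in this appendix. As an illustration only, the Python sketch below implements the standard unbiased estimate under i.i.d. sampling: with per-example gradients $\nabla\ell_{b}$, $b=1,\dots,|B|$, and their mean $g_{B}$, the element-wise variance of $g_{B}$ is estimated by $\frac{1}{|B|}\cdot\frac{1}{|B|-1}\sum_{b}(\nabla\ell_{b}-g_{B})^{2}$. The function name and the toy per-example gradients are hypothetical; efficiently obtaining per-example gradient information in a neural network is the non-trivial part mentioned above.

```python
import random


def minibatch_mean_and_variance(per_example_grads):
    """Mean g_B of per-example gradients and an unbiased element-wise estimate of var[g_B].

    Assumes the examples are drawn i.i.d., so var[g_B] = var[grad l] / |B|.
    """
    n = len(per_example_grads)
    if n < 2:
        raise ValueError("need at least two examples for a variance estimate")
    d = len(per_example_grads[0])
    mean = [sum(g[i] for g in per_example_grads) / n for i in range(d)]
    sample_var = [sum((g[i] - mean[i]) ** 2 for g in per_example_grads) / (n - 1) for i in range(d)]
    return mean, [s / n for s in sample_var]


# Deterministic toy batch of per-example gradients (hypothetical values).
mean, var = minibatch_mean_and_variance([[1.0, 0.0], [3.0, 2.0]])
print(mean, var)

# Monte Carlo check of unbiasedness: per-example gradients ~ N(0.5, 2^2), batch size 8.
rng = random.Random(0)
batch, trials = 8, 20_000
est = sum(minibatch_mean_and_variance([[rng.gauss(0.5, 2.0)] for _ in range(batch)])[1][0]
          for _ in range(trials)) / trials
print(est, 2.0 ** 2 / batch)
```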

We explored a variant of m-svag which uses mini-batch gradient variance estimates. The local variance estimation allows for a theoretically more pleasing treatment of the variance of the update direction $m_{t}$. Starting from the formulation of $m_{t}$ in Eq. (57) and considering that $g_{s}$ and $g_{s^{\prime}}$ are stochastically independent for $s\neq s^{\prime}$, we have

Given that we now have access to a true, local, unbiased estimate of $\mathbf{var}[g_{s}]$, we can estimate $\mathbf{var}[m_{t}]$ by

It turns out that we can track this quantity with another exponential moving average: it is $\bar{s}_{t}=\rho(\beta,t)r_{t}$ with

This can be shown by iterating Eq. (67) backwards and comparing coefficients with Eq. (66). The resulting mini-batch variant of m-svag is presented in Algorithm 5.
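The coefficient comparison can be reproduced numerically. The Python sketch below assumes the bias-corrected weights $c(\beta,t,s)=(1-\beta)\beta^{t-s}/(1-\beta^{t+1})$ for $m_{t}$. Under that assumption, $c(\beta,t,s)^{2}/\rho(\beta,t)=(1-\beta^{2})\beta^{2(t-s)}/(1-\beta^{2(t+1)})$, so $r_{t}$ can be taken to be a bias-corrected moving average of the variance estimates with constant $\beta^{2}$; this reading of the recursion is an assumption, not a transcription of Eq. (67). The code compares $\sum_{s}c(\beta,t,s)^{2}\hat{s}_{s}$ with $\rho(\beta,t)r_{t}$ for a random sequence of illustrative variance estimates $\hat{s}_{s}$.

```python
import random


def rho(beta, t):
    """Closed form of sum_s c(beta, t, s)^2 for bias-corrected EMA weights."""
    return (1.0 - beta) * (1.0 + beta ** (t + 1)) / ((1.0 + beta) * (1.0 - beta ** (t + 1)))


def direct_var_m(s_hats, beta):
    """sum_s c(beta, t, s)^2 * s_hat_s with t = len(s_hats) - 1."""
    t = len(s_hats) - 1
    norm = 1.0 - beta ** (t + 1)
    return sum(((1.0 - beta) * beta ** (t - s) / norm) ** 2 * sh for s, sh in enumerate(s_hats))


def tracked_var_m(s_hats, beta):
    """rho(beta, t) * r_t, where r_t is a bias-corrected EMA of s_hat with constant beta^2."""
    b2 = beta * beta
    r = 0.0
    for sh in s_hats:
        r = b2 * r + (1.0 - b2) * sh
    t = len(s_hats) - 1
    return rho(beta, t) * r / (1.0 - b2 ** (t + 1))


rng = random.Random(0)
s_hats = [rng.uniform(0.1, 2.0) for _ in range(30)]   # illustrative variance estimates
for t in (0, 5, 29):
    print(t, direct_var_m(s_hats[: t + 1], 0.9), tracked_var_m(s_hats[: t + 1], 0.9))
```

The recursive form needs only one extra running average per parameter, instead of storing all past variance estimates.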

Note that mini-batch gradient variance estimates could likewise be used for the alternative methods discussed in §C. We do not explore this further in this paper.

D.2 Experiments

We tested the mini-batch variant of m-svag on the problems (P1) and (P2) from the main text and compared it to the moving average version. Results are shown in Figure 8. The two algorithms have almost identical performance.