How Does Adaptive Optimization Impact Local Neural Network Geometry?

Kaiqi Jiang, Dhruv Malik, Yuanzhi Li

Introduction

The efficient minimization of a parameterized loss function is a core primitive in statistics, optimization and machine learning. Gradient descent (GD), which iteratively updates a parameter vector with a step along the gradient of the loss function evaluated at that vector, is a simple yet canonical algorithm which has been applied to efficiently solve such minimization problems with enormous success. However, in modern machine learning, and especially deep learning, one frequently encounters problems where the loss functions are high dimensional, non-convex and non-smooth. The optimization landscape of such problems is thus extremely challenging, and in these settings gradient descent often suffers from prohibitively high iteration complexity.

To deal with these difficulties and improve optimization efficiency, practitioners in recent years have developed many variants of GD. One prominent class of these GD variants is the family of adaptive algorithms. At a high level, adaptive methods scale the gradient with an adaptively selected preconditioning matrix, which is constructed via a moving average of past gradients. These methods are reminiscent of second-order methods, since they construct approximations to the Hessian of the loss function, while remaining computationally feasible since they eschew full computation of the Hessian. A vast line of empirical work has demonstrated the superiority of adaptive methods over GD for optimizing deep neural networks, especially on Natural Language Processing (NLP) tasks with transformers.

From a theoretical perspective, adaptive methods are well understood in the traditional context of convex optimization. For instance, Duchi et al. show that when the loss function is convex, the Adagrad algorithm yields regret guarantees that are provably as good as those obtained by using the best (diagonal) preconditioner in hindsight. The key mechanism underlying this improved performance is that the loss function has some global geometric property (such as sparsity or a coordinate-wise bounded Lipschitz constant), and the algorithm adapts to this global geometry by adaptively selecting learning rates for features that are more informative.

However, in non-convex optimization, and deep learning in particular, it is highly unclear whether this simple characterization is sufficient to explain the superiority of adaptive methods over GD. Indeed, for large scale neural networks, global guarantees on the geometric properties of the loss are typically vacuous. For instance, for a 20-layer feedforward neural network, if we scale up the weights in each layer by a factor of $1.5$, then the global Lipschitz constant of the network is scaled up by a factor of at least $e^{10}$. Hence it only makes sense to study convergence by looking at the local geometry of the loss along the trajectory of the optimization algorithm.

Moreover, the interaction between an optimization algorithm and neural network geometry is highly complex: recent work has shown that the geometric characteristics of the iterates encountered during optimization are highly dependent on the choice of optimization algorithm and associated hyperparameters. For instance, Cohen et al. demonstrate that while training neural networks with GD, the maximum eigenvalue of the Hessian evaluated at the GD iterates first increases and then plateaus at a level 2/(step size). The viewpoint from convex optimization, where a loss function has some (potentially) non-uniform but fixed underlying geometry that we must adapt to, is thus insufficient for neural networks, since the choice of optimization algorithm can actually interact with and influence the observed geometry significantly.
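To give some intuition for why a level of 2/(step size) is natural, the following worked example considers GD on a one-dimensional quadratic. The quadratic, the curvature $a$ and the step size $\eta$ are illustrative assumptions of ours and are not the setting studied by Cohen et al.

```latex
f(x)=\tfrac{a}{2}x^{2},\quad a>0,
\qquad
x^{(t+1)} = x^{(t)} - \eta f'(x^{(t)}) = (1-\eta a)\,x^{(t)},
\qquad
|1-\eta a|<1 \iff a<\frac{2}{\eta}.
```

In this toy case the iterates contract exactly when the curvature is below 2/(step size), which is the same threshold at which the maximum Hessian eigenvalue is reported to plateau.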

To provide another example of this interactive phenomenon, we consider the following experiment. On the same network training loss function $f$, we run stochastic gradient descent with momentum (SGD+M) and Adam to obtain two different trajectories. We select an iterate $x_{\text{Adam}}$ from the Adam trajectory and an iterate $x_{\text{SGD}}$ from the SGD+M trajectory, such that $f(x_{\text{Adam}})=f(x_{\text{SGD}})$. We then run SGD+M twice, once from $x_{\text{Adam}}$ and once from $x_{\text{SGD}}$. If the underlying geometry of the loss function $f$ were truly fixed, then we would not expect a significant difference in the performance of running SGD+M from either of the two iterates. However, as shown in Figure 1(a), there is a noticeable difference in performance, and running SGD+M from $x_{\text{Adam}}$ achieves lower loss than running SGD+M from $x_{\text{SGD}}$. This suggests that Adam may bias the optimization trajectory towards a region which is more favorable for rapid training. This motivates the following question.

How does adaptive optimization impact the observed geometry of a neural network loss function, relative to SGD (with momentum)?

This statistic thus measures the uniformity of the diagonal of the Hessian, where a smaller value of $R_{\text{med}}^{\text{OPT}}(t)$ implies that the Hessian has a more uniform diagonal. It can also be viewed as a stable variant of the condition number. (To see why the median is used, consider the case where one parameter has little impact on the loss: the second derivative w.r.t. this parameter is then almost zero, which makes $\frac{\max\{|H^{(t)}_{ii}|\}_{i=1}^{d}}{\min\{|H^{(t)}_{ii}|\}_{i=1}^{d}}$ blow up to infinity. We therefore use the median, which is more stable.) Instead of eigenvalues, we choose diagonal entries because adaptive methods used in practice are coordinate-wise and can be viewed as diagonal scaling approaches. (Recall that the main theoretical bound in the original Adagrad paper is in terms of the diagonal scaling; hence we believe the diagonal of the Hessian is more relevant than the spectrum.) As a supplementary result, in Appendix E, we demonstrate that the loss Hessian approaches a diagonal matrix during training for Adam and SGD+M. There has been prior theoretical work on overparameterized neural networks showing that a smaller condition number of the Hessian, the Neural Tangent Kernel, etc. could yield a faster convergence rate for (S)GD. As for (diagonal) adaptive methods (e.g. Adagrad), they were originally designed to adapt to nonuniform diagonal geometry. Intuitively, a smaller $R_{\text{med}}^{\text{OPT}}(t)$, which implies more uniform diagonal geometry, could lead to faster convergence.
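The statistic can be illustrated with a short sketch. It assumes that eq. (1) is the ratio of the largest to the median absolute diagonal Hessian entry, which is our reading of the surrounding discussion; the function name `r_med` and the two example diagonals are hypothetical.

```python
from statistics import median

def r_med(hessian_diag):
    """Ratio of the largest to the median absolute diagonal entry.

    Assumed form of eq. (1); a zero median would raise ZeroDivisionError.
    """
    abs_diag = [abs(h) for h in hessian_diag]
    return max(abs_diag) / median(abs_diag)

# Hypothetical diagonals: one nearly uniform, one with a near-zero entry.
uniform_diag = [1.0, 1.1, 0.9, 1.0, 1.05]
skewed_diag = [1.0, 1e-8, 50.0, 0.5, 2.0]

uniform_ratio = r_med(uniform_diag)
skewed_ratio = r_med(skewed_diag)
# The max/min ratio is dominated by the single near-zero entry.
skewed_max_over_min = max(abs(h) for h in skewed_diag) / min(abs(h) for h in skewed_diag)

print(uniform_ratio, skewed_ratio, skewed_max_over_min)
```

On the second diagonal the max/min ratio is driven almost entirely by one nearly flat coordinate, while the median-based ratio is not, which is the stability argument made above.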

Armed with this statistic, we make the following contributions:

On a wide variety of transformer architectures and language modeling datasets, we conduct experiments to compare how $R_{\text{med}}^{\text{Adam}}(t)$ and $R_{\text{med}}^{\text{SGDM}}(t)$ evolve over time, when Adam and SGD+M are run from the same initialization and with their optimal (initial) learning rates respectively. In each case, we demonstrate that the Adam trajectory attains $R_{\text{med}}^{\text{Adam}}(t)$ values that are significantly smaller than the $R_{\text{med}}^{\text{SGDM}}(t)$ values found by SGD+M. We show a simple example of this phenomenon in Figure 1(b). This suggests that relative to SGD+M, Adam biases the optimization trajectory to a region where the Hessian diagonal is more uniform. We call this phenomenon the uniformity of diagonal geometry for adaptive methods. As an aside, we observe that larger improvements in performance of Adam over SGD+M are correlated with larger gaps between $R_{\text{med}}^{\text{Adam}}(t)$ and $R_{\text{med}}^{\text{SGDM}}(t)$. This suggests that a region where the Hessian diagonal is more uniform is also a region that is more amenable to rapid optimization.

We complement our empirical results with a theoretical analysis of this phenomenon in the simplified setting of large batch Adam and SGD+M, on a two-layer linear network with $d$-dimensional input and hidden layer, and one-dimensional output. We show that for a wide range of $t$, $R_{\text{med}}^{\text{Adam}}(t)=1\pm o(1)$ but $R_{\text{med}}^{\text{SGDM}}(t)=\Omega(\log d)$. Our proof reveals that Adam induces the weight matrices to have low rank, with leading singular vectors that exhibit a certain type of uniformity (see Section 6 for discussion), a fact that we also observe empirically in large scale neural networks, suggesting that this may be a mechanism by which adaptive methods bias trajectories to have uniformity of diagonal geometry.

Related work

Existing analyses of adaptive methods. The vast majority of prior theoretical work on adaptive methods has focused on the blackbox setting. These works make minimal assumptions about the structure of the loss function, beyond (possibly) some global properties such as convexity or smoothness. These global properties (governed by parameters such as the smoothness parameter) are assumed to hold over the entire domain. Hence this style of analysis is worst case, since the resulting convergence bounds depend polynomially on these global parameters. However, as we show in Section 3.1, in neural networks these parameters are prohibitively large. This worst case analysis is hence unlikely to explain the success of adaptive methods on neural networks. By contrast, our focus is on analyzing the local trajectory that is induced by running the optimization method.

Existing analyses of (S)GD on neural networks. There is an extensive literature on the analysis of GD/SGD in the non-blackbox setting, e.g. on overparameterized neural networks. However, it is unclear how to translate these analyses of GD/SGD into an analysis that explains the gap between GD/SGD and adaptive methods.

Influence of algorithms on the loss geometry. In many simple convex settings, e.g. linear or logistic regression and the Neural Tangent Kernel, the loss geometry is usually fixed and not influenced by learning algorithms. However, in neural networks the interaction between algorithms and loss landscapes is more complicated. Lewkowycz et al. find a so-called catapult effect of the initial learning rate on the training trajectory of SGD and the related loss curvature. Cohen et al. demonstrate that while training neural networks with GD, the maximum eigenvalue of the Hessian evaluated at the GD iterates first increases and then plateaus at a level that is inversely proportional to the step size. However, Cohen et al. leave open the problem of whether similar interactive phenomena occur in algorithms that are not GD, including adaptive methods.

Overview of results and setup

As mentioned in Section 2, existing work on adaptive algorithms has mainly focused on black-box analysis assuming some global worst-case parameters. However, these global bounds can be extremely loose in complicated deep learning models, as discussed in Section 1. To see this, we initialized a transformer model (https://pytorch.org/tutorials/beginner/transformer_tutorial.html) with the default initialization in PyTorch but chose a large gain (the gain parameter in some commonly used PyTorch initialization functions, e.g. torch.nn.init.xavier_uniform_()), and computed the smoothness parameter (denoted $l$) and the condition number (denoted $\kappa$) of the loss Hessian on one layer. We observed that setting the gain to a large constant (e.g. 800) results in extremely large $l$ and $\kappa$ ($l\geq 10^{7}$ and $\kappa\geq 10^{10}$), which makes the convergence rates in prior black-box analyses vacuous.

The failure of global worst-case analysis implies that we need to focus on the local trajectory of algorithms. However, it is unclear whether two different optimization algorithms encounter the same geometry along their local trajectories. In particular, although in theory adaptive algorithms can yield a convergence rate with better dependency on certain local geometry of the function compared to SGD (with momentum), it could still be the case that the local geometry along the trajectory of an adaptive algorithm is much worse than that of SGD (with momentum).

This motivates us to study the local geometry in this paper, especially that obtained by adaptive methods compared to SGD (with momentum). Motivated by the diagonal scaling of Adagrad and Adam for neural network training, we ask the following main question:

How does the local diagonal geometry (diagonal of the loss Hessian) along the local trajectory of adaptive algorithms compare to that of SGD (with momentum)?

3.2 Overview of the experiments

As discussed in Section 1, we consider $R_{\text{med}}^{\text{OPT}}(t)$ defined in eq. (1) as a measurement of the uniformity of the diagonal of the loss Hessian. We conduct experiments on different NLP tasks to examine $R_{\text{med}}^{\text{OPT}}(t)$, since in language models adaptive methods have shown significantly faster convergence than SGD (with momentum). The details of these experiments are given in Section 4. To explore potentially different patterns across layers, we do the computation layer by layer. On a wide variety of transformer architectures and language modeling datasets, starting from the same initialization, we observe that:

When we train the neural network using Adam, $R_{\text{med}}^{\text{OPT}}(t)$, which measures the uniformity of diagonal geometry, is smaller than when we train using SGD+M from the same initialization, except for the first several layers.

Table 1 shows a typical example of $R_{\text{med}}^{\text{Adam}}(t)$ compared to $R_{\text{med}}^{\text{SGDM}}(t)$ on a sentence classification task using BERT-small (see Section 4.1 for details). We repeated the experiments 12 times starting from the same initialization. Table 1 shows the averaged $R_{\text{med}}^{\text{Adam}}(t)$ and $R_{\text{med}}^{\text{SGDM}}(t)$ in some randomly selected layers (except for the first several). We also report the averaged $\frac{R_{\text{med}}^{\text{SGDM}}(t)}{R_{\text{med}}^{\text{Adam}}(t)}$ with its standard deviation in brackets. The $R_{\text{med}}^{\text{SGDM}}(t)$ values in Table 1 for most layers are roughly 1.4 to 2 times the $R_{\text{med}}^{\text{Adam}}(t)$ values in the corresponding layers. In practice, this can be considered significant because it might imply 1.4 to 2 times faster convergence. Figure 2 shows the corresponding training losses for one of these 12 runs.

To understand this phenomenon from a more principled point of view, we also provide a formal proof of the statement in a simplified setting: large batch Adam and SGD+M on a two-layer linear network. Although simple, the two-layer linear network is a common choice in prior work for understanding learning dynamics. Section 3.3 below describes the theoretical setup.

3.3 Setup of the theoretical analysis

Let $[d]=\{1,2,\dots,d\}$. We use $\|\cdot\|_{2}$ to denote the $l_{2}$ norm of a vector, and $\|\cdot\|_{F}$ to denote the Frobenius norm of a matrix. Let $\langle\cdot,\cdot\rangle$ be the Euclidean inner product between vectors or matrices. Let $\mathcal{N}(\mu,\sigma^{2})$ be the one-dimensional Gaussian distribution with mean $\mu$ and variance $\sigma^{2}$. For a scalar (vector, matrix) $A$ which evolves over time, we use $A^{(t)}$ to denote its value at time $t$.

where $c$ does not depend on $W$. We consider the following model with small Gaussian initialization.

$\forall i,j:\ w_{2i}^{(0)}\sim\mathcal{N}(0,\frac{1}{d^{2\alpha}}),\ W_{1}^{(0)}[i,j]\sim\mathcal{N}(0,\frac{1}{d^{4\alpha}})$ are independently initialized with sufficiently large $\alpha>0$.

where $\eta$ is the learning rate, $\beta,\beta_{1},\beta_{2}$ are momentum parameters, and $\xi$ is for numerical stability. All operations on vectors are element-wise.
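For readers who want to see the role of each parameter concretely, here is a minimal sketch of one SGD+M step and one Adam step. It uses the commonly used heavy-ball and bias-corrected Adam forms, which are our assumptions; the update rules in eq. (3) of the paper may differ in details such as bias correction or where $\xi$ enters. The function names and the example numbers are hypothetical.

```python
import math

def sgdm_step(params, grads, buf, lr, beta):
    """Heavy-ball form (assumed): buf <- beta * buf + grad, params <- params - lr * buf."""
    new_buf = [beta * b + g for b, g in zip(buf, grads)]
    new_params = [p - lr * b for p, b in zip(params, new_buf)]
    return new_params, new_buf

def adam_step(params, grads, m, v, step, lr, beta1, beta2, xi):
    """Bias-corrected Adam (assumed form); `step` counts from 1, `xi` is the stability term."""
    new_m = [beta1 * mi + (1.0 - beta1) * g for mi, g in zip(m, grads)]
    new_v = [beta2 * vi + (1.0 - beta2) * g * g for vi, g in zip(v, grads)]
    m_hat = [mi / (1.0 - beta1 ** step) for mi in new_m]
    v_hat = [vi / (1.0 - beta2 ** step) for vi in new_v]
    # All operations are element-wise, one learning rate scaling per coordinate.
    new_params = [p - lr * mh / (math.sqrt(vh) + xi)
                  for p, mh, vh in zip(params, m_hat, v_hat)]
    return new_params, new_m, new_v

# Hypothetical two-parameter example: one step of each method from the same point.
x0 = [1.0, -1.0]
g0 = [2.0, -4.0]
sgdm_x, sgdm_buf = sgdm_step(x0, g0, [0.0, 0.0], lr=0.5, beta=0.9)
adam_x, adam_m, adam_v = adam_step(x0, g0, [0.0, 0.0], [0.0, 0.0], step=1,
                                   lr=0.5, beta1=0.9, beta2=0.999, xi=1e-8)
print(sgdm_x, adam_x)
```

In this sketch the SGD+M displacement is proportional to each gradient entry, whereas the first Adam step moves every coordinate by roughly the learning rate regardless of the gradient magnitude.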

The uniformity of diagonal geometry

As mentioned in Section 3.2, we computed $R_{\text{med}}^{\text{OPT}}(t)$ defined in eq. (1) on different language models. In this section, we present the results of SGD+M and Adam on different architectures and datasets. In Appendix A, we present the results of other adaptive algorithms.

During training we started from the same initial weights and used the same learning rate schedule (constant or decreasing) for SGD+M and Adam. We tuned and chose the best (initial) learning rate of SGD+M. The (initial) learning rate of Adam was set to a value under which Adam converged faster than SGD+M with its best learning rate. The concrete values are stated later in this section. We used large batch sizes to make the training procedure stable. When computing the Hessian, we also used large batch sizes. Due to the extremely large dimension, we did the computation only on uniformly selected coordinates, namely 200 coordinates per layer.
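The coordinate subsampling described above can be sketched as follows. The text does not say how the diagonal Hessian entries were computed, so this sketch substitutes central finite differences as a stand-in technique; the function `sampled_hessian_diag`, the toy loss and the step size are all hypothetical.

```python
import random

def sampled_hessian_diag(loss, x, num_coords, step=1e-3, seed=0):
    """Estimate H_ii on a uniformly sampled subset of coordinates.

    Uses the central difference (f(x + h e_i) - 2 f(x) + f(x - h e_i)) / h**2,
    an assumption of this sketch rather than the method used in the paper.
    """
    rng = random.Random(seed)
    coords = sorted(rng.sample(range(len(x)), min(num_coords, len(x))))
    base = loss(x)
    diag = {}
    for i in coords:
        plus = list(x)
        plus[i] += step
        minus = list(x)
        minus[i] -= step
        diag[i] = (loss(plus) - 2.0 * base + loss(minus)) / (step * step)
    return diag

def toy_loss(x):
    # Diagonal quadratic whose true Hessian diagonal is H_ii = i + 1.
    return sum(0.5 * (i + 1) * xi * xi for i, xi in enumerate(x))

x0 = [0.1 * (i + 1) for i in range(10)]
estimates = sampled_hessian_diag(toy_loss, x0, num_coords=4)
print(estimates)
```

Each sampled coordinate costs two extra loss evaluations, so the cost scales with the number of sampled coordinates rather than with the full dimension.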

We fine-tuned BERT-small on the IMDB dataset: the task is to classify whether movie reviews are positive or negative (https://huggingface.co/docs/transformers/v4.16.2/en/training). The momentum parameter $\beta$ in SGD was set to 0.9. The two momentum parameters $(\beta_{1},\beta_{2})$ of Adam were set to (0.9, 0.999). We trained the model using linearly decreasing learning rates for 10 epochs (2500 iterations). The initial learning rates of SGD+M and Adam were 0.001 and 5e-5, respectively. As mentioned in Section 3.2, Figure 2 and Table 1 show the training losses and the comparison between $R_{\text{med}}^{\text{Adam}}(t)$ and $R_{\text{med}}^{\text{SGDM}}(t)$.

We trained a Seq2Seq network that uses a Transformer to solve a machine translation task on Multi30k (CC BY-NC-SA 4.0): the task is to train a German to English translation model (https://pytorch.org/tutorials/beginner/translation_transformer.html). The momentum parameter $\beta$ in SGD was set to 0.9. The two momentum parameters $(\beta_{1},\beta_{2})$ of Adam were set to (0.9, 0.98). We trained the model using constant learning rates (0.03 for SGD+M and 1e-4 for Adam) for 60 epochs (1800 iterations). The experiments were repeated 8 times starting from the same initialization. Figure 3(a) shows the training losses for one of the runs. Table 2(a) shows the averaged $R_{\text{med}}^{\text{Adam}}(t)$, $R_{\text{med}}^{\text{SGDM}}(t)$ and $\frac{R_{\text{med}}^{\text{SGDM}}(t)}{R_{\text{med}}^{\text{Adam}}(t)}$ (with standard deviation in brackets) in some randomly selected layers.

4.2 Experiments on random datasets

We used the same model and momentum parameters as in the translation task described in Section 4.1 but generated random integers as targets. Similar to the setting with real targets, the model was trained using constant learning rates (0.015 for SGD+M and 5e-5 for Adam) for 60 epochs (1800 iterations), and we repeated the experiments 8 times starting from the same initialization. Figure 3(b) shows the training losses for one of the runs. Table 2(b) shows the averaged $R_{\text{med}}^{\text{Adam}}(t)$, $R_{\text{med}}^{\text{SGDM}}(t)$ and $\frac{R_{\text{med}}^{\text{SGDM}}(t)}{R_{\text{med}}^{\text{Adam}}(t)}$ (with standard deviation in brackets) of the same 10 layers as in Table 2(a). To prevent $R_{\text{med}}^{\text{OPT}}(t)$ from getting too large due to a tiny median, we added an additional term $0.001\max\{|H^{(t)}_{ii}|\}_{i=1}^{d}$ to the denominator of eq. (1) when computing it.
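One consequence of the extra denominator term can be written out explicitly. The expression below assumes that eq. (1) is the ratio of the maximum to the median of the absolute diagonal Hessian entries, which is our reading of the text; the notation $\tilde{R}$ for the stabilized statistic is ours.

```latex
\tilde{R}_{\text{med}}^{\text{OPT}}(t)
=\frac{\max\{|H^{(t)}_{ii}|\}_{i=1}^{d}}
{\operatorname{median}\{|H^{(t)}_{ii}|\}_{i=1}^{d}+0.001\max\{|H^{(t)}_{ii}|\}_{i=1}^{d}}
\;\le\;\frac{1}{0.001}=1000 .
```

Under this assumption, the stabilized statistic is capped at 1000 whenever the maximum is nonzero, even if the median is arbitrarily close to zero.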

4.3 How the (adaptive) gradient aligns with diagonal of loss Hessian

In this section we present the uniformity of diagonal geometry of adaptive methods from another perspective. Denote by $H_{ii}$ the $(i,i)$-th element of the loss Hessian $H$ and by $g_{i}$ the $i$-th element of the gradient. It is conjectured that when $|H_{ii}|$ is large, the corresponding $|g_{i}|$ is usually large as well. For adaptive methods, we can regard the update per step as the learning rate times the “adaptive gradient”. We use $g_{\text{adapt},i}$ to represent the $i$-th component of the adaptive gradient. Through experiments on language models, we find that the values of $|g_{\text{adapt},i}|$ for different $i$ are quite uniform and do not align with $|H_{ii}|$ as the true gradient magnitudes $|g_{i}|$ do.
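The contrast between $|g_{i}|$ and $|g_{\text{adapt},i}|$ can be seen in a toy setting. The sketch below runs a bias-corrected Adam (an assumed form of the update) on a diagonal quadratic where $g_{i}=H_{ii}x_{i}$, so the gradient aligns with the Hessian diagonal by construction. It is not the language model experiment described above, and the function name, curvatures and hyperparameters are hypothetical.

```python
import math

def adaptive_gradient_spread(curvatures, num_steps=5, lr=0.01,
                             beta1=0.9, beta2=0.999, xi=1e-8):
    """Run Adam on f(x) = 0.5 * sum_i h_i * x_i**2 from x = (1, ..., 1).

    Returns max/min of |g_i| and max/min of |g_adapt_i| at the last step.
    """
    x = [1.0] * len(curvatures)
    m = [0.0] * len(curvatures)
    v = [0.0] * len(curvatures)
    grads, adapt = [], []
    for step in range(1, num_steps + 1):
        grads = [h * p for h, p in zip(curvatures, x)]  # g_i = H_ii * x_i
        m = [beta1 * mi + (1.0 - beta1) * g for mi, g in zip(m, grads)]
        v = [beta2 * vi + (1.0 - beta2) * g * g for vi, g in zip(v, grads)]
        # Adaptive gradient: bias-corrected first moment over root second moment.
        adapt = [(mi / (1.0 - beta1 ** step))
                 / (math.sqrt(vi / (1.0 - beta2 ** step)) + xi)
                 for mi, vi in zip(m, v)]
        x = [p - lr * a for p, a in zip(x, adapt)]

    def spread(values):
        magnitudes = [abs(t) for t in values]
        return max(magnitudes) / min(magnitudes)

    return spread(grads), spread(adapt)

grad_spread, adapt_spread = adaptive_gradient_spread([100.0, 1.0, 0.01])
print(grad_spread, adapt_spread)
```

In this toy case the raw gradient magnitudes inherit the full spread of the curvatures, while the adaptive gradient magnitudes are nearly identical across coordinates because the per-coordinate normalization cancels the curvature scale.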

4.4 Summarization of the empirical results and discussion

Overall, through extensive experiments on language models, we demonstrate that starting from the same initialization, the $R_{\text{med}}^{\text{OPT}}(t)$ values found by Adam are smaller than those found by SGD+M, except for the first several layers. This suggests that Adam is biased towards a region with a more uniform diagonal Hessian than SGD+M.

We observe that on the random dataset, SGD+M plateaus after about 400 steps, so its slowdown relative to Adam is much larger than on the real dataset (see Figure 3(a) and Figure 3(b)). Correspondingly, the gaps between $R_{\text{med}}^{\text{SGDM}}(t)$ and $R_{\text{med}}^{\text{Adam}}(t)$ are also more significant on random data than on real data (see Table 2(a) and Table 2(b)). In Appendix A.4, we conduct another experiment where we switch from SGD to Adam in the middle of training and compare the result with the model trained by Adam from the beginning. We observe that both the loss gap and the gap in $R_{\text{med}}^{\text{OPT}}(t)$ gradually close after switching (see Figure 8 and Table 8). Hence we find a positive correlation between fast convergence and the uniformity of the diagonal of the loss Hessian, suggesting that a region with a more uniform Hessian diagonal is also a region that is more amenable to fast optimization. In Appendix A we study other adaptive algorithms (Adagrad, RMSprop and AMSGrad) and make similar observations: all these adaptive methods converge faster than SGD or SGD+M and also bias the trajectory to a region with smaller $R_{\text{med}}^{\text{OPT}}(t)$, suggesting that the uniformity of the diagonal Hessian might be a universal mechanism (partially) explaining the faster optimization of adaptive algorithms compared to SGD (with momentum).

Since our comparison between $R_{\text{med}}^{\text{Adam}}(t)$ and $R_{\text{med}}^{\text{SGDM}}(t)$ is made at the same iteration, at which SGD+M has a larger training loss than Adam, there is a potential alternative explanation of the Hessian diagonal uniformity. That is, the global minimum has a uniform Hessian, and Adam simply converges to it faster than SGD+M, thus giving the appearance that it induces better geometry. To rule out this possibility, in Appendix A.3 we add a comparison of our measurements $R^{\text{Adam}}_{\text{med}}(t)$ and $R^{\text{SGDM}}_{\text{med}}(t^{\prime})$, where $t,t^{\prime}$ are picked such that the $t$-th Adam iterate and the $t^{\prime}$-th SGD+M iterate have the same training loss. The results (in Table 7) show that $R^{\text{Adam}}_{\text{med}}(t)<R^{\text{SGDM}}_{\text{med}}(t^{\prime})$ for most layers, thus demonstrating that the trajectories of Adam and SGD+M are truly different and that the difference arises because Adam biases the local geometry (as opposed to merely converging faster).
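The pairing of iterates by training loss can be sketched as a simple matching step. The nearest-loss rule, the tolerance `rel_tol` and the two loss curves below are illustrative assumptions; the exact selection procedure used for Table 7 is not specified here.

```python
def match_by_loss(adam_losses, sgdm_losses, rel_tol=0.05):
    """Pair each Adam iterate t with the SGD+M iterate t' of closest training loss.

    A pair is kept only if the two losses agree within `rel_tol` (relative).
    """
    pairs = []
    for t, loss_a in enumerate(adam_losses):
        t_prime = min(range(len(sgdm_losses)),
                      key=lambda s: abs(sgdm_losses[s] - loss_a))
        if abs(sgdm_losses[t_prime] - loss_a) <= rel_tol * abs(loss_a):
            pairs.append((t, t_prime))
    return pairs

# Hypothetical loss curves: Adam reaches each loss level earlier than SGD+M.
adam_curve = [4.0, 2.0, 1.0, 0.5]
sgdm_curve = [4.0, 3.0, 2.5, 2.0, 1.5, 1.0]
pairs = match_by_loss(adam_curve, sgdm_curve)
print(pairs)
```

The statistic would then be compared at each matched pair $(t,t^{\prime})$ rather than at equal iteration counts; Adam iterates whose loss SGD+M never reaches are simply left unpaired.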

In practice, weight decay (equivalent to $l_{2}$ regularization) is usually added to encourage better generalization. In Appendix A.7 we compare SGD+M and Adam when both use small weight decay values (0.001). The results in Figure 13(a) and Table 9 suggest that in this case, the relationship between $R_{\text{med}}^{\text{OPT}}(t)$ and convergence speed still holds: Adam converges faster than SGD+M and in most of the layers except for the first several, $R_{\text{med}}^{\text{Adam}}(t)$ values are smaller than $R_{\text{med}}^{\text{SGDM}}(t)$. This reveals the robustness of our observation under weak regularization. However, under large weight decay parameters, we observed cases where Adam still converged faster but $R_{\text{med}}^{\text{Adam}}(t)$ values were larger rather than smaller. In the case of strong regularization, the adaptivity of Adam requires further exploration and we hope to find new mechanisms in the future.

Although in this paper we focus on language models, where Adam shows significantly faster convergence, we also add supplementary results in Appendix A.8 on image tasks where SGD+M performs better. On a residual network trained on CIFAR-10, we observed that Adam did not converge faster than SGD+M (see Figure 13(b)) and, at the same time, $R_{\text{med}}^{\text{Adam}}(t)$ values were no longer smaller than $R_{\text{med}}^{\text{SGDM}}(t)$ during training (see Table 10). This reveals the connection between the local diagonal geometry and the convergence speed from another perspective. That is, when the Hessian diagonal along the Adam trajectory is not more uniform than along the SGD+M trajectory, Adam's convergence speed is not better either. In summary, all the observations on language and image tasks together suggest a positive correlation between the uniformity of the diagonal Hessian and fast optimization.

Theoretical analysis

In Section 4, we empirically demonstrate the uniformity of diagonal geometry. In this section, we theoretically analyze this property for large batch Adam and SGD+M on a two-layer linear network with 1-dimensional output.

Since the weights and Hessians in different layers may have different magnitudes, we compute $R_{\text{med}}^{\text{OPT}}(t)$ layer by layer. We denote by $R_{\text{med},k}^{\text{SGDM}}(t)$ (resp. $R_{\text{med},k}^{\text{Adam}}(t)$) the $R_{\text{med}}^{\text{OPT}}(t)$ found by SGD+M (resp. Adam) w.r.t. $W_{k}$ at time $t$, where $k=1,2$.

Under Assumptions 1, 2 and 3, consider the weights $\left\{W^{(t)}_{\text{SGD}}\right\}_{t\geq 0}$ (resp. $\left\{W^{(t)}_{\text{Adam}}\right\}_{t\geq 0}$) obtained by SGD+M (resp. Adam) defined in (3).

An immediate corollary of this theorem, stated below, gives the difference between iterates of Adam and SGD+M that have the same loss.

The low rank structure of weight matrices and uniformity of leading singular vectors

The proof sketch in Appendix B highlights one crucial intuition of Theorem 1: after $T_{\text{SGD},1}$ (resp. $T_{\text{Adam},1}$) steps, $W_{1}$ of SGD+M (resp. Adam) becomes an approximately rank-1 matrix. Consider the left singular vector $\boldsymbol{u}:=[u_{1},u_{2},\dots,u_{d}]^{T}$ which corresponds to the leading singular value $\sigma_{1}$. We can show that the distribution of $u_{1}^{2},u_{2}^{2},\dots,u_{d}^{2}$ for Adam is more uniform than that of SGD+M. This property, which we call the uniformity of the leading singular vector, is related to the uniformity of the diagonal of the loss Hessian; see Appendix F for more details.

After reviewing the weight matrices we got in different settings, we observed that (A) and (B) hold for many layers in those models. For example, on the translation task mentioned in Section 4.1, we found 12 layers which have approximately low rank structures, and for 10 of them, the $R_{u}$ values (defined in (B)) obtained by Adam are smaller than those found by SGD+M. Figure 5 shows the result on one typical layer. Results for more layers can be found in Appendix A.5.

Although in multi-layer nonlinear neural networks the connection between the diagonal of the loss Hessian and the weight matrices is more complicated, and $R_{\text{med},2}^{\text{OPT}}(t)$ may depend on the product of many weight matrices rather than one single matrix, we still believe that this definition of $R_{u}$ is a reasonable ratio to consider.

Conclusion and future work

We demonstrate that adaptive optimization methods bias the training trajectory towards a region where the diagonal of loss Hessian is more uniform, through extensive experiments on language models and theoretical analysis in a simplified setting of two-layer linear networks. Although our findings may not directly lead to an improved algorithm for practical use, they provide a new way of thinking when designing new algorithms: in contrast with the traditional view which tries to design a method that performs better in the bad loss geometry, our findings suggest that we can design algorithms which implicitly avoid regions with bad geometry. There are a lot of future directions along this line. For example, our theoretical results on the two-layer linear networks may be able to generalize to multi-layer networks. In fact, people conjecture that the key-value-query structure in language models can be approximated by a three-layer linear network. Hence the generalization to multi-layer networks might provide more connection to real deep models and could be an interesting and challenging future direction. Moreover, it is also possible to relax our large-batch assumption (Assumption 3) and prove similar results in the general stochastic setting.

Appendix A More experiments of the uniformity of diagonal geometry

In this section, we present the $R_{\text{med}}^{\text{OPT}}(t)$ values defined in eq. (1) obtained by SGD and Adagrad on a language modeling task (https://pytorch.org/tutorials/beginner/transformer_tutorial.html). The task is to assign a probability for the likelihood of a given word (or a sequence of words) to follow a sequence of words. We trained a transformer model to solve this problem on both Wikitext-2 (CC BY-SA 3.0) and a random dataset (generating random integers as targets). This model has roughly 8 layers (not counting normalization and dropout layers).

The setup is the same as in Section 3.2. We used the same learning rate schedule (constant or decreasing) for SGD and Adagrad. We tuned and chose the best (initial) learning rate of SGD. The (initial) learning rate of Adagrad was set to a value under which Adagrad converged faster than SGD with its best (initial) learning rate. We used large batch sizes to make the training procedure more stable. When computing the Hessian, we also used large batch sizes. Due to the extremely large dimension, we did the computation only on uniformly selected coordinates, namely 200 coordinates per layer.

We tried different initializations (normal and uniform) by using different gains in the PyTorch initialization functions.

Figure 6(a) shows the training losses on the real dataset (Wikitext-2). Table 3 (resp. Table 4) shows the $R_{\text{med}}^{\text{OPT}}(t)$ for Adagrad and SGD under uniform (resp. normal) initialization with different gains.

A.1.2 Experiments on random dataset

Figure 6(b) shows the training losses on the random dataset and Table 5 shows the $R_{\text{med}}^{\text{OPT}}(t)$ in different layers.

A.2 RMSprop and AMSGrad

In this section, we present the results of RMSprop and AMSGrad and compare them with SGD+M. The experiments are conducted on the translation task described in Section 4.1. The learning rates we used were 0.000025 for RMSprop, 0.0005 for AMSGrad and 0.03 for SGD+M. Both RMSprop and SGD+M used momentum parameter 0.9. The two momentum parameters $(\beta_{1},\beta_{2})$ of AMSGrad are $(0.9,0.98)$. Figure 7 shows the training losses and Table 6 shows the corresponding $R_{\text{med}}^{\text{OPT}}(t)$.

A.3 Comparison conditioned on the same loss

In this section, we compare $R_{\text{med}}^{\text{SGDM}}(t)$ and $R_{\text{med}}^{\text{Adam}}(t)$ conditioned on the same training loss. More precisely, we compare $R^{\text{Adam}}_{\text{med}}(t)$ and $R^{\text{SGDM}}_{\text{med}}(t^{\prime})$, where $t,t^{\prime}$ are picked such that the $t$-th Adam iterate and the $t^{\prime}$-th SGD+M iterate have the same training loss. The details of the tasks are described in Section 4.1. Table 7 shows the results of $R^{\text{Adam}}_{\text{med}}(t)$ and $R^{\text{SGDM}}_{\text{med}}(t^{\prime})$ in some layers.

A.4 Experiments of switching from SGD to Adam

In this section we describe another learning schedule: the “Adam after SGD” schedule, where we switched from SGD to Adam in the middle of training to see whether the loss and $R_{\text{med}}^{\text{OPT}}(t)$ can catch up with those of the model trained by Adam from the very beginning. Again, we used the same model as in the translation task in Section 4.1. In this section, we did not add a momentum term to SGD, in order to get a larger gap between SGD and Adam than in the case with momentum. We want to see whether this larger gap can be closed after switching to Adam in the middle.

As shown in Figure 8 and Table 8, both the loss gap and the gap in $R_{\text{med}}^{\text{OPT}}(t)$ were closed after a period of training following the switch, which provides evidence of the connection between convergence speed and uniformity of the diagonal of the loss Hessian.
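The “Adam after SGD” schedule can be sketched on a toy problem. The example below is only illustrative: it runs plain gradient descent without momentum on a hypothetical ill-conditioned two-dimensional quadratic and then switches to a textbook Adam update with $(\beta_{1},\beta_{2})=(0.9,0.98)$. Starting Adam with fresh (zero) moment estimates at the switch is an assumption, since this appendix does not say how the optimizer state was initialized; the learning rates and step counts are also hypothetical.

```python
import math

CURVATURES = [1.0, 100.0]  # hypothetical ill-conditioned toy problem

def loss(x):
    return 0.5 * sum(a * xi * xi for a, xi in zip(CURVATURES, x))

def grad(x):
    return [a * xi for a, xi in zip(CURVATURES, x)]

def sgd_step(x, lr):
    """One full-gradient step without momentum."""
    return [xi - lr * gi for xi, gi in zip(x, grad(x))]

class Adam:
    """Textbook Adam with bias correction."""

    def __init__(self, dim, lr, beta1=0.9, beta2=0.98, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * dim
        self.v = [0.0] * dim
        self.t = 0

    def step(self, x):
        self.t += 1
        out = []
        for i, (xi, gi) in enumerate(zip(x, grad(x))):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * gi
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * gi * gi
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            out.append(xi - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return out

def adam_after_sgd(x0, sgd_steps, adam_steps, sgd_lr, adam_lr):
    """Run SGD first, then switch to Adam; return the loss after every step."""
    x = list(x0)
    history = [loss(x)]
    for _ in range(sgd_steps):
        x = sgd_step(x, sgd_lr)
        history.append(loss(x))
    adam = Adam(len(x), adam_lr)  # assumption: fresh moment estimates at the switch
    for _ in range(adam_steps):
        x = adam.step(x)
        history.append(loss(x))
    return history

history = adam_after_sgd([1.0, 1.0], sgd_steps=50, adam_steps=200,
                         sgd_lr=0.005, adam_lr=0.01)
```

On this toy problem, SGD quickly removes the error along the high-curvature direction but makes slow progress along the low-curvature one; after the switch, Adam's per-coordinate scaling reduces the remaining loss.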

A.5 The low rank structure

In this section, we present more results for the experiments in Section 6.

We examined the weights of the model trained for the translation task in Section 4.1. Among roughly 30 layers, we observed that for 12 layers, at least the weight matrices obtained by Adam after training have approximately low rank structures.

Figure 9 shows the examples of layers with or without the low rank structure.

We then studied the uniformity of the leading singular vectors of these 12 layers, i.e. computed $R_{u}$ and $R_{v}$ defined in (B) and the second remark of Section 6. The observation is that for 10 out of these 12 layers, the $R_{u}$ values of Adam are smaller than those of SGD, which implies the uniformity of the leading left singular vectors of Adam. However, we did not observe significant uniformity for Adam in terms of the leading right singular vectors ($R_{v}$). The second remark of Section 6 discusses possible reasons.

Figure 10 shows how $R_{u}$ and $R_{v}$ changed over time in some layers.

A.6 How the (adaptive) gradient aligns with diagonal of loss Hessian

In this section, we present more empirical results on how the (adaptive) gradient aligns with diagonal of loss Hessian. The detailed setup is described in Section 4.3.

Here we compare SGD and Adagrad on the language modeling task on wikitext-2 described in Section A.1. We observed that the figures of all layers are quite similar so we select one layer as an example, as is shown in Figure 11.

A.6.2 SGD with momentum vs. Adam

Figure 4 presents the comparison between Adam and SGD with momentum on the sentence classification task using BERT-small. Here we add more results of the comparison of these two algorithms on the translation task described in Section 4.1. Again, we select one layer as an example, as is shown in Figure 12.

A.7 Adding regularization and other tricks

In this section, we add weight decay to both Adam and SGD+M on the translation task described in Section 4. The momentum parameter $\beta$ in SGD was set as 0.9. The two momentum parameters $(\beta_{1},\beta_{2})$ of Adam were set as (0.9, 0.98). For both algorithms, we set the weight decay parameter as 0.001. We trained the model using constant learning rates for 60 epochs (1800 iterations). We tuned and chose the best learning rate 0.03 for SGD+M. The learning rate of Adam was set as 0.0001, under which Adam converged faster than SGD+M with its best learning rate 0.03. Figure 13(a) shows the training losses and Table 9 shows the values of $R_{\text{med}}^{\text{Adam}}(t)$, $R_{\text{med}}^{\text{SGDM}}(t)$ and $\frac{R_{\text{med}}^{\text{SGDM}}(t)}{R_{\text{med}}^{\text{Adam}}(t)}$ in some randomly selected layers.

A.8 Results on image tasks

We trained a ResNet on the CIFAR-10 dataset (we borrowed the implementation from https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/ and replaced the “layers” array) and compared the convergence speed and $R_{\text{med}}^{\text{OPT}}(t)$ of SGD+M and Adam. The momentum parameter $\beta$ in SGD was set as 0.9. The two momentum parameters $(\beta_{1},\beta_{2})$ of Adam were set as (0.9, 0.98). The model was trained using constant learning rates for 41 epochs (2050 iterations). We tuned and chose the best learning rates for both algorithms: 0.5 for SGD+M and 0.005 for Adam. Figure 13(b) shows the training losses and Table 10 shows the values of $R_{\text{med}}^{\text{Adam}}(t)$, $R_{\text{med}}^{\text{SGDM}}(t)$ and $\frac{R_{\text{med}}^{\text{SGDM}}(t)}{R_{\text{med}}^{\text{Adam}}(t)}$.

Appendix B Proof sketch of Theorem 1

Now we give a proof sketch of Theorem 1, which contains three major steps. The detailed proof can be found in Appendix F, C and D.

First we relate the diagonal of the Hessian to the weight matrices $W_{1},W_{2}$. Under Assumption 1, denote $W_{1}[i,:]$ as the $i$-th row of $W_{1}$ and $W_{2}:=[w_{21},w_{22},...,w_{2d}]$. Since the input dataset is whitened, we can show that

Next, due to the one-dimensional output, we can prove that $W_{1}$ converges to an approximately rank-1 matrix. More precisely, we have

Using the rank 1 structure, we can further simplify $R_{\text{med},1}^{\text{OPT}}(t)$ and $R_{\text{med},2}^{\text{OPT}}(t)$ by

The final step is the detailed analysis of $\boldsymbol{u}^{(t)}$.

Appendix C Analysis of SGD+M

Note that $A=\frac{1}{m}YX^{T}$, $\Lambda_{xx}:=\frac{1}{m}XX^{T}$. Denote $g_{k}^{(t)}:=\nabla_{W_{k}}L(W^{(t)}),k=1,2$. We have that

Based on the magnitude of $W_{2}$ and $W_{1}$, we can intuitively divide the training procedure into 2 phases.

First phase: the first several iterations when $W_{1}$ and $W_{2}$ are “small” so that $W_{2}W_{1}-A\approx-A$.

Second phase: later iterations when $W_{2}W_{1}$ cannot be ignored.

More formally, the boundary between the first and second phase is defined below.

Definition 1. The end of the first phase (denoted as $T_{1}$) is defined as $T_{1}:=\inf\left\{t\geq 0:\exists i,j\in[d]:\left|w_{2i}^{(t)}\right|\geq\frac{1}{d^{\frac{\alpha}{2}}}\text{ or }\left|W_{1}^{(t)}[i,j]\right|\geq\frac{1}{d^{\frac{\alpha}{2}}}\right\}$.

By Assumption 2 and the assumption that $\forall j\in[d]:A_{j}>0,A_{j}=\Theta(1)$, at the beginning, w.h.p., $\forall j\in[d]:(W_{2}W_{1})_{j}-A_{j}<0$. During the training, each $(W_{2}W_{1})_{j}$ increases and approaches $A_{j}$. We hope that by choosing a small learning rate, when $(W_{2}W_{1})_{j}$ overshoots for some coordinate $j$, i.e. $(W_{2}W_{1})_{j}>A_{j}$, it will be close to convergence. To analyze this overshooting issue more carefully, let’s first define the following “almost overshooting time”.

Definition 2. For $\epsilon>0$, denote $\epsilon_{0}:=\frac{1}{d^{\frac{1}{4}\alpha-1}}+\epsilon\log\sqrt{\frac{d}{\epsilon}}$. Define $T_{2}:=\inf\left\{t\geq 0:\exists j\in[d]:\left(W_{2}^{(t)}W_{1}^{(t)}\right)_{j}-A_{j}\geq-\sqrt{\epsilon_{0}}\right\}$.

Definition 3. For $\epsilon>0$, we define the “convergence time” $T_{3}:=\inf\left\{t\geq 0:\left\|E^{(t)}\right\|^{2}_{2}\leq\epsilon\right\}$.
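The three time points above are first-hitting times of a recorded trajectory, which the following minimal sketch makes concrete. It assumes the trajectory is stored as a list of pairs `(W1, W2)`, with `W1` a list of rows and `W2` a row vector, and that $E^{(t)}=W_{2}^{(t)}W_{1}^{(t)}-A$. All function names and the toy trajectory are hypothetical.

```python
def first_time(predicate, trajectory):
    """Return the first index t at which predicate(trajectory[t]) holds, else None."""
    for t, state in enumerate(trajectory):
        if predicate(state):
            return t
    return None

def product_row(w1, w2):
    """(W2 W1)_j = sum_i w_2i * W1[i][j] for a row vector W2."""
    return [sum(w2[i] * w1[i][j] for i in range(len(w2)))
            for j in range(len(w1[0]))]

def end_of_first_phase(trajectory, d, alpha):
    """T1: first t where some |w_2i| or |W1[i][j]| reaches d^(-alpha/2)."""
    threshold = d ** (-alpha / 2.0)

    def hit(state):
        w1, w2 = state
        return (any(abs(w) >= threshold for w in w2)
                or any(abs(w) >= threshold for row in w1 for w in row))

    return first_time(hit, trajectory)

def almost_overshooting_time(trajectory, a, eps0):
    """T2: first t where some error (W2 W1)_j - A_j is at least -sqrt(eps0)."""
    def hit(state):
        w1, w2 = state
        return any(p - aj >= -(eps0 ** 0.5)
                   for p, aj in zip(product_row(w1, w2), a))

    return first_time(hit, trajectory)

def convergence_time(trajectory, a, eps):
    """T3: first t where the squared error norm is at most eps."""
    def hit(state):
        w1, w2 = state
        return sum((p - aj) ** 2
                   for p, aj in zip(product_row(w1, w2), a)) <= eps

    return first_time(hit, trajectory)

def constant_state(w1_value, w2_value, d=2):
    """Toy state in which every entry of W1 (resp. W2) takes the same value."""
    return ([[w1_value] * d for _ in range(d)], [w2_value] * d)

# Hypothetical toy trajectory with d = 2 and A = [1, 1].
trajectory = [constant_state(0.1, 0.1), constant_state(0.3, 0.6),
              constant_state(0.5, 0.9), constant_state(0.55, 0.9)]
target = [1.0, 1.0]
t1 = end_of_first_phase(trajectory, d=2, alpha=2.0)
t2 = almost_overshooting_time(trajectory, target, eps0=0.04)
t3 = convergence_time(trajectory, target, eps=0.001)
```

In this toy run the three times occur in the order $T_{1}<T_{2}<T_{3}$; in the analysis the ordering of $T_{2}$ and $T_{3}$ is not fixed, which is why the lemmas are stated up to $\min\{T_{2},T_{3}\}$.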

We can first show that after the first phase, i.e. when $t=T_{1}$, $W_{1}$ will become an approximately rank-1 matrix, as described in the following lemma.

Lemma 1. Under Assumption 1, 2 and 3, suppose $\sigma\leq\frac{\eta^{3/2}}{d^{\alpha/2+1}}$. By picking $\eta\leq\mathcal{O}\left(\frac{1}{d^{\alpha}}\right)$, we have that when $t=T_{1}$, $\bar{L}\left(W^{(T_{1})}\right)=\Theta(d)$, and that

The following lemma tells us that this approximate rank-1 structure is preserved when $T_{1}\leq t\leq\min\{T_{2},T_{3}\}$.

Lemma 2. Under Assumption 1, 2 and 3, suppose $\sigma\leq\frac{\eta^{3/2}}{d^{\alpha/2+1}}$. By picking $\eta\leq\mathcal{O}\left(\frac{\epsilon}{d^{\frac{7\alpha}{4}+4}}\right)$, we have that w.h.p. for $T_{1}\leq t\leq\min\{T_{2},T_{3}\}$,

and $\epsilon_{0}$ is defined in Definition 2. Moreover, when $t=\min\{T_{2},T_{3}\}$, $\bar{L}\left(W^{(t)}\right)=\mathcal{O}(\epsilon_{0}d)$.

The following lemma gives us a more detailed description of $\boldsymbol{u}^{(T_{1})}$.

Now we are ready to prove the SGD+M part of Theorem 1.

Here $X_{i},i\in[d]$ are i.i.d. Gaussian random variables by Lemma 3. To prove the concentration of $\text{median }X_{i}^{2}$, we borrow Proposition 12 in Chapter 2.3 of . By setting $K=N=d$ in this proposition, we have

That means with high probability, $\text{median }X_{i}^{2}\leq C\sigma^{2}$ for some $C>0$. By Lemma 34 in Appendix G, we know that w.h.p.

Hence we have proved that $R_{\text{med},1}^{\text{SGDM}}(t),R_{\text{med},2}^{\text{SGDM}}(t)\geq\Omega(\log d)$.

C.2 Proof of Lemma 1

In the first phase, $W_{2}W_{1}$ is “small”, and we write the update equations in the following way

The following lemma gives us an explicit formula of $W_{2}^{(t)}$.

Lemma 4. Let $\lambda_{1}<\lambda_{2}$ be the two roots of the quadratic equation $x^{2}-2x+1-\frac{\eta^{2}}{(1-\beta)^{2}}\|A\|_{2}^{2}=0$. Pick $\eta<\frac{1-\beta}{\|A\|_{2}}$, then we have that

where $C_{1}=-\frac{W_{2}^{(1)}-\lambda_{2}W_{2}^{(0)}}{\lambda_{2}-\lambda_{1}}$, $C_{2}=\frac{W_{2}^{(1)}-\lambda_{1}W_{2}^{(0)}}{\lambda_{2}-\lambda_{1}}$. The term $r_{5}^{(t)}$ will be specified in the proof.

We can prove that in the first phase, $r_{5}^{(t)}$ is “small”. More specifically, denote its $i$-th coordinate as $r_{5i}^{(t)}$, and the $i$-th coordinate of $C_{2}$ as $C_{2i}$. Then the following lemmas tell us that $\forall i\in[d],\left|r_{5i}^{(t)}\right|\leq\mathcal{O}\left(\frac{1}{d^{p(\alpha)}}\right)$, where w.h.p. $\mathcal{O}\left(\frac{1}{d^{p(\alpha)}}\right)\ll\min_{i\in[d]}|C_{2i}|$.

We first have the following bounds of $\left|r_{1i}^{(t)}\right|,\left|r_{2i}^{(t)}\right|$ and $\left|r_{5i}^{(t)}\right|$ for $i\in[d]$.

Next we prove upper and lower bounds of $|C_{1i}|$ and $|C_{2i}|$ for $i\in[d]$.

Now we are ready to prove Lemma 1. Lemma 4 tells us that

where λ1=1η1βA2\lambda_{1}=1-\frac{\eta}{1-\beta}\|A\|_{2} and λ2=1+η1βA2\lambda_{2}=1+\frac{\eta}{1-\beta}\|A\|_{2}.

Under the conditions of Theorem 1 and pick ηO(1dα)\eta\leq\mathcal{O}\left(\frac{1}{d^{\alpha}}\right), by Lemma 6 and 7, we know that w.h.p. tT1,1id\forall t\leq T_{1},\forall 1\leq i\leq d,

We first prove that w2i(t)\left|w_{2i}^{(t)}\right| reaches 1dα/2\frac{1}{d^{\alpha/2}} for some coordinate ii before W1(t)[k,j]\left|W_{1}^{(t)}[k,j]\right| for k,j[d]\forall k,j\in[d]. To see this, first note that

where v(t)T=η1βA\boldsymbol{v}^{(t)T}=\frac{\eta}{1-\beta}A and

For tT1t\leq T_{1}, by eq. (7), we get that w.h.p.,

Here we used i[d]:Ai=Θ(1)\forall i\in[d]:A_{i}=\Theta(1) by Assumption 1.

Further we notice that for tT1t\leq T_{1}, we have j[d]\forall j\in[d],

which yields that ui(t)vj(t)O(1d)c(t)ui(t)\left|u_{i}^{(t)}v_{j}^{(t)}\right|\leq\mathcal{O}\left(\frac{1}{\sqrt{d}}\right)\left|c^{(t)}u_{i}^{(t)}\right|. Together with eq. (9) gives us that w2i(t)\left|w_{2i}^{(t)}\right| reaches 1dα/2\frac{1}{d^{\alpha/2}} for some i[d]i\in[d] before W1(t)[k,j]\left|W_{1}^{(t)}[k,j]\right| for k,j[d]\forall k,j\in[d], i.e. T1=inf{t0:i[d]:w2i(t)1dα2}T_{1}=\inf\left\{t\geq 0:\exists i\in[d]:\left|w_{2i}^{(t)}\right|\geq\frac{1}{d^{\frac{\alpha}{2}}}\right\}.

Further, we know that at time T1T_{1}, c(T1)ui0(T1)=C2i0λ2T1=Θ(1dα/2)\left|c^{(T_{1})}u_{i_{0}}^{(T_{1})}\right|=|C_{2i_{0}}|\lambda_{2}^{T_{1}}=\Theta\left(\frac{1}{d^{\alpha/2}}\right) for some i0[d]i_{0}\in[d], which means w.h.p.

This is the length of the first phase. As for c(T1)ui(T1)c^{(T_{1})}u_{i}^{(T_{1})} and ui(T1)vj(T1)u_{i}^{(T_{1})}v_{j}^{(T_{1})} for other coordinates, we have that w.h.p. 1i,jd\forall 1\leq i,j\leq d,

Finally, we consider the loss. Since j[d]:(W2(T1)W1(T1))jAj=Θ(1)\forall j\in[d]:\left(W_{2}^{(T_{1})}W^{(T_{1})}_{1}\right)_{j}-A_{j}=-\Theta(1), we know that Lˉ(W(T1))=Θ(d)\bar{L}\left(W^{(T_{1})}\right)=\Theta(d).

C.3 Proof of Lemma 3

C.4 Proof of Lemma 4

Replacing $t$ by $t-1$ in eq. (6), we get

Subtracting eq. (11) from eq. (6) and substituting eq. (5) yields

where $r_{3}(t):=\frac{\eta^{2}}{(1-\beta)^{2}}Ar_{1}^{(t-1)T}+\frac{\eta}{1-\beta}\left(r_{2}^{(t)}-r_{2}^{(t-1)}\right)$.

For the equation $x^{2}-2x+1-\frac{\eta^{2}}{(1-\beta)^{2}}\|A\|_{2}^{2}=0$, the roots are $\lambda_{1}=1-\frac{\eta}{1-\beta}\|A\|_{2}$ and $\lambda_{2}=1+\frac{\eta}{1-\beta}\|A\|_{2}$. We have that

where $r_{5}^{(t)}=\sum_{\tau=1}^{t}\lambda_{2}^{-\tau}r_{4}^{(\tau)}$, $C_{1}=-\frac{W_{2}^{(1)}-\lambda_{2}W_{2}^{(0)}}{\lambda_{2}-\lambda_{1}}$ and $C_{2}=\frac{W_{2}^{(1)}-\lambda_{1}W_{2}^{(0)}}{\lambda_{2}-\lambda_{1}}$.
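The constants $C_{1}$ and $C_{2}$ above are the coefficients of the standard closed-form solution of a second-order linear recurrence whose characteristic equation is the quadratic in this proof. The sketch below checks this numerically in a simplified setting: a scalar sequence, all error terms dropped, and $c:=\frac{\eta}{1-\beta}\|A\|_{2}$ so that $\lambda_{1,2}=1\mp c$. The recurrence $x_{t+1}=2x_{t}-(1-c^{2})x_{t-1}$ is inferred from the characteristic equation and is an assumption about the form of the error-free update; the numerical values are hypothetical.

```python
def closed_form_gap(w0, w1, c, steps):
    """Largest gap between iterating the recurrence and C1*lam1^t + C2*lam2^t."""
    lam1, lam2 = 1.0 - c, 1.0 + c
    c1 = -(w1 - lam2 * w0) / (lam2 - lam1)
    c2 = (w1 - lam1 * w0) / (lam2 - lam1)
    prev, cur = w0, w1
    max_gap = 0.0
    for t in range(2, steps + 1):
        # x_{t} = 2 x_{t-1} - (1 - c^2) x_{t-2}, whose characteristic roots are lam1, lam2
        prev, cur = cur, 2.0 * cur - (1.0 - c * c) * prev
        closed = c1 * lam1 ** t + c2 * lam2 ** t
        max_gap = max(max_gap, abs(cur - closed))
    return max_gap

eta, beta, norm_a = 0.01, 0.9, 2.0   # hypothetical values with eta < (1 - beta) / norm_a
c = eta / (1.0 - beta) * norm_a      # so that 0 < lam1 < 1 < lam2
gap = closed_form_gap(w0=0.01, w1=0.012, c=c, steps=50)
```

Because $\lambda_{2}>1>\lambda_{1}>0$ under the stated condition on $\eta$, the $C_{2}\lambda_{2}^{t}$ term dominates for large $t$, which is why the proofs track $|C_{2i}|$ against the error term.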

C.5 Proof of Lemma 5

Write $r_{1}^{(t)}=-\beta^{t+1}W_{2}^{(t)T}A+q_{12}^{(t)}+q_{13}^{(t)}+q_{14}^{(t)}$ where $q_{12}^{(t)}=(1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}\left(W_{2}^{(\tau)T}-W_{2}^{(t)T}\right)A$, $q_{13}^{(t)}=-(1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}W_{2}^{(\tau)T}W_{2}^{(\tau)}W_{1}^{(\tau)}$ and $q_{14}^{(t)}=-(1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}Dg_{1}^{(\tau)}$. And write $r_{2}^{(t)}=-\beta^{t+1}AW_{1}^{(t)T}+q_{22}^{(t)}+q_{23}^{(t)}+q_{24}^{(t)}$, where $q_{22}^{(t)}=(1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}A\left(W_{1}^{(\tau)T}-W_{1}^{(t)T}\right)$, $q_{23}^{(t)}=-(1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}W_{2}^{(\tau)}W_{1}^{(\tau)}W_{1}^{(\tau)T}$ and $q_{24}^{(t)}=-(1-\beta)\sum_{\tau=0}^{t}\beta^{t-\tau}Dg_{2}^{(\tau)}$.

Let’s first try to bound $\left|q_{12}^{(t)}[i,j]\right|$ and $\left|q_{22,i}^{(t)}\right|$. For any $\tau\leq T_{1}$, we have that

and thus $\forall i\in[d]:\left|E_{i}^{(\tau)}\right|=\mathcal{O}(1)$. Then we have for all $i,j\in[d]$,

Then we bound $\left|q_{13}^{(t)}[i,j]\right|$ and $\left|q_{23,i}^{(t)}\right|$. We have for all $i,j\in[d]$,

Finally we use Lemma 31 to bound $\left|q_{14}^{(t)}[i,j]\right|$ and $\left|q_{24,i}^{(t)}\right|$. For $t\leq T_{1}$, the $M_{1}^{(t)},M_{2}^{(t)}$ in Lemma 31 are upper bounded by $\frac{1}{d^{\frac{\alpha}{2}}}$. In the theorem we consider the training period before $T_{\text{SGD},2}$ so the time $T$ in Lemma 31 is set as $T_{\text{SGD},2}$. In the following sections, we will prove that $T_{\text{SGD},2}\leq\mathcal{O}\left(\frac{d^{\alpha}\log(\sqrt{d/\epsilon})}{\eta}\right)$. Then by Lemma 31, we have with probability at least $1-\frac{1}{d}$, for all $t\leq T_{1}$ and all $i,j\in[d]$,

Combining all the above bounds and substituting $\eta\leq\mathcal{O}\left(\frac{1}{d^{\alpha}}\right)$ gives us for all $t\leq T_{1}$ and all $i,j\in[d]$,

For $t\leq T_{1}$, we have $\forall i,j\in[d]$, $\left|w_{2i}^{(t)}A_{j}\right|\leq\mathcal{O}\left(\frac{1}{d^{\alpha/2}}\right)$ and $\left|\sum_{j=1}^{d}A_{j}W_{1}^{(t)}[i,j]\right|\leq\mathcal{O}\left(\frac{1}{d^{\alpha/2-1}}\right)$, which gives us $\left|r_{1}^{(t)}[i,j]\right|\leq\mathcal{O}\left(\frac{1}{d^{\alpha/2}}\right)$ and $\left|r_{2i}^{(t)}\right|\leq\mathcal{O}\left(\frac{1}{d^{\alpha/2-1}}\right)$. Substituting into eq. (5) and eq. (6) yields that for $t\leq T_{1}$ and all $i,j\in[d]$,

Hence for $t\leq\min\left\{\frac{\alpha\log d}{\log(1/\beta)},T_{1}\right\}$, we have $\forall i,j\in[d]$,

Then we know that $T_{1}>\frac{\alpha\log d}{\log(1/\beta)}$ and also get tighter bounds of $\left|W_{1}^{(t)}[i,j]\right|,\left|w_{2i}^{(t)}\right|$ for $t\leq\frac{\alpha\log d}{\log(1/\beta)}$. Now we use these new bounds to analyze $\left|r_{1}^{(t)}[i,j]\right|$ and $\left|r_{2i}^{(t)}\right|$ again.

C.6 Proof of Lemma 6

Since λ1=1η1βA2,λ2=1+η1βA2\lambda_{1}=1-\frac{\eta}{1-\beta}\|A\|_{2},\lambda_{2}=1+\frac{\eta}{1-\beta}\|A\|_{2}, and note that A2=Θ(d)\|A\|_{2}=\Theta\left(\sqrt{d}\right), we have that

C.7 Proof of Lemma 7

For the equation $x^{2}-2x+1-\frac{\eta^{2}}{(1-\beta)^{2}}\|A\|_{2}^{2}=0$, the roots are $\lambda_{1}=1-\frac{\eta}{1-\beta}\|A\|_{2}$ and $\lambda_{2}=1+\frac{\eta}{1-\beta}\|A\|_{2}$, which gives us

Then we have that w.p. at least $1-\delta$, $\forall 1\leq i,j\leq d$,

Next, we bound the $i$-th coordinate of $W_{2}^{(0)}+\frac{1-\beta}{\|A\|_{2}}AW_{1}^{(0)T}$, i.e. $w_{2i}^{(0)}+\frac{1-\beta}{\|A\|_{2}}A\left(W_{1}^{(0)}[i,:]\right)^{T}$.

By independence under Assumption 2, we have that

Using the Gaussian tail bound and union bound, w.p. at least $1-\delta$, for every $1\leq i\leq d$, we have that

Since for $X\sim\mathcal{N}(0,\sigma^{2})$, we have that $P(|X|\leq t)\leq\frac{2t}{\sqrt{2\pi}\sigma}$, then for a fixed $i$,

Then by union bound, we have that w.p. at least $1-\frac{1}{d^{\frac{\alpha}{4}-1}}$, for every $1\leq i\leq d$,

where $(i)$ follows from eq. (14) and the fact that $\|A\|_{2}=\sqrt{d}$. Then we get that w.h.p.

Substituting eq. (15) into eq. (13), we get that w.h.p.,
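The Gaussian small-ball bound $P(|X|\leq t)\leq\frac{2t}{\sqrt{2\pi}\sigma}$ used in this proof follows from bounding the density of $\mathcal{N}(0,\sigma^{2})$ by its peak value $\frac{1}{\sqrt{2\pi}\sigma}$ and integrating over $[-t,t]$. The sketch below compares the bound with the exact probability, computed through the error function; the value of $\sigma$ and the grid of $t$ values are hypothetical.

```python
import math

def prob_abs_below(t, sigma):
    """Exact P(|X| <= t) for X ~ N(0, sigma^2), via the error function."""
    return math.erf(t / (sigma * math.sqrt(2.0)))

def density_bound(t, sigma):
    """Upper bound 2t / (sqrt(2*pi) * sigma): interval length times peak density."""
    return 2.0 * t / (math.sqrt(2.0 * math.pi) * sigma)

sigma = 0.5  # hypothetical standard deviation
checks = [(t, prob_abs_below(t, sigma), density_bound(t, sigma))
          for t in (0.001, 0.01, 0.1, 1.0)]
```

The bound is nearly tight for $t\ll\sigma$ and becomes vacuous (larger than 1) once $t$ exceeds roughly $1.25\sigma$, which is consistent with its use here for lower-bounding the magnitude of small initial coordinates.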

C.8 Proof of Lemma 2

The proof in Section C.2 tells us that at the end of the first phase (when $t=T_{1}$),

Denote the $i$-th coordinate of $\boldsymbol{u}^{(t)},\boldsymbol{v}^{(t)},R_{2}^{(t)}$ as $u_{i}^{(t)},v_{i}^{(t)},R_{2i}^{(t)}$, respectively. Denote the $(i,j)$-th element of $R_{1}^{(t)}$ as $R_{1}^{(t)}[i,j]$. For $t\geq T_{1}$, we prove by induction that,

with $E^{(t)}:=W_{2}^{(t)}W_{1}^{(t)}-A$, $\eta_{t}=\eta\sum_{\tau=0}^{t}\beta^{t-\tau}$, $r_{1}^{(t)}:=\eta\sum_{\tau=0}^{t}\beta^{t-\tau}\left(W_{2}^{(t)T}E^{(t)}-W_{2}^{(\tau)T}E^{(\tau)}\right)-\eta\sum_{\tau=0}^{t}\beta^{t-\tau}Dg_{1}^{(\tau)}$ and $r_{2}^{(t)}=\eta\sum_{\tau=0}^{t}\beta^{t-\tau}\left(E^{(t)}W_{1}^{(t)T}-E^{(\tau)}W_{1}^{(\tau)T}\right)-\eta\sum_{\tau=0}^{t}\beta^{t-\tau}Dg_{2}^{(\tau)}$. Note that the $r_{1}^{(t)}$ and $r_{2}^{(t)}$ here are different from those defined in Section C.2, but we abuse the notation and still use $r_{1}^{(t)}$ and $r_{2}^{(t)}$ to represent the error terms.

The base case is already given by eq. (16).

Suppose our lemma holds for $t$, then for $t+1$, using the same techniques as in eq. (5) and eq. (6), we have that

Plugging in the inductive hypothesis yields

It implies that our lemma holds for $t+1$, which completes the proof.

Now we analyze the error terms $\left|R_{1}^{(t)}[i,j]\right|$ and $\left|R_{2i}^{(t)}\right|$. Eq. (16) tells us that $c^{(T_{1})}$ and $\forall i\in[d],v_{i}^{(T_{1})}$ are all positive. We first prove by induction that for all $T_{1}\leq t\leq T_{2}$, $c^{(t)}>0,\forall i\in[d],v_{i}^{(t)}>0$.

The above discussion already proves the base case. Suppose at time $t$, we have $c^{(t)}>0,\forall i\in[d],v_{i}^{(t)}>0$. Note that when $T_{1}\leq t<T_{2}$, $\forall i\in[d]:E_{i}^{(t)}\leq 0$, then for $t+1$,

Therefore by induction, we have proved that for all $T_{1}\leq t\leq T_{2}$, $c^{(t)}>0,\forall i\in[d],v_{i}^{(t)}>0$.

Now we prove that for all $T_{1}\leq t\leq T_{2}$,

The left hand sides of the inequalities are trivial since we have proved that $c^{(t)}>0,\forall i\in[d],v_{i}^{(t)}>0$ for all $T_{1}\leq t\leq T_{2}$. Now we prove the right hand sides by induction.

The base case is already verified by the definition of $\delta_{i}$. Suppose eq. (18) holds for $T_{1}\leq t<T_{2}$. Then for $t+1$, using $\forall i\in[d]:E_{i}^{(t)}\leq 0$ and $v^{(t+1)}\geq v^{(t)},c^{(t+1)}\geq c^{(t)}$, we can get that $\forall 1\leq i,j\leq d$

Similarly, we have that $\forall 1\leq i\leq d$

Therefore by induction, eq. (18) holds for all $t$ in the second phase.

So far we have proved the rank 1 structure stated in Lemma 2. The remaining part of the proof is given by the following lemma, whose proof is deferred to Section C.9.

Lemma 8. Under Assumption 1, 2 and 3, suppose $\sigma\leq\frac{\eta^{3/2}}{d^{\alpha/2+1}}$. By picking $\eta\leq\mathcal{O}\left(\frac{\epsilon}{d^{\frac{7\alpha}{4}+4}}\right)$, we have that w.h.p. for $T_{1}\leq t\leq\min\{T_{2},T_{3}\}$,

and that when $t=\min\{T_{2},T_{3}\}$, we have $\left\|E^{(t)}\right\|^{2}_{2}=\mathcal{O}(\epsilon_{0}d)$.

C.9 Proof of Lemma 8

We first have the following lemma which describes the structure of $\boldsymbol{v}^{(t)}$ for $t\geq T_{1}$.

Lemma 9. Under Assumption 1, 2 and 3, for $t\geq T_{1}$, we can write $\boldsymbol{v}^{(t)T}$ as $\boldsymbol{v}^{(t)T}=a^{(t)}A+R_{v}^{(t)T}$, with $a^{(T_{1})}=\frac{\eta}{1-\beta},R_{v}^{(T_{1})T}=[0,0,...,0]$, and

We prove Lemma 8 by induction. Denote the $i$-th coordinate of $R_{3}^{(t)}$ and $R_{v}^{(t)}$ as $R_{3i}^{(t)}$ and $R_{vi}^{(t)}$, respectively. The following lemmas constitute the inductive part.

Under the conditions of Lemma 10 and picking $\eta\leq\mathcal{O}\left(\frac{\epsilon}{d^{\frac{7\alpha}{4}+4}}\right)$, we have that at time $t+1$,

where $\epsilon_{0}$ is defined in Definition 2.

Under the conditions of Lemma 10 and picking $\eta\leq\mathcal{O}\left(\frac{\epsilon}{d^{\frac{7\alpha}{4}+4}}\right)$, we have that at time $t+1$,

$\forall i,j\in[d]:\frac{E_{j}^{(t)}}{E_{i}^{(t)}}=\Theta(1)$,

$\forall j\in[d]:\frac{v_{j}^{(t+1)}}{c^{(t+1)}}=\Theta\left(\frac{1}{\sqrt{d}}\right)$,

$\forall i,j\in[d]:\left|w_{2i}^{(t+1)}\right|\leq\mathcal{O}\left(d^{1/4}\right),\left|W_{1}^{(t+1)}[i,j]\right|\leq\mathcal{O}\left(\frac{1}{d^{1/4}}\right)$,

By combining Lemma 10, 11 and 13, we can prove by induction that for all $T_{1}\leq t\leq\min\{T_{2},T_{3}\}$, eq. (19) holds (which follows from Lemma 11), and

So far we have proved eq. (19) in Lemma 8. Now let’s prove that when $t=\min\{T_{2},T_{3}\}$, we have $\left\|E^{(t)}\right\|^{2}_{2}=\mathcal{O}(\epsilon_{0}d)$.

If $\min\{T_{2},T_{3}\}=T_{3}$, by Definition 3, we have $\left\|E^{(t)}\right\|_{2}^{2}\leq\epsilon$. If $\min\{T_{2},T_{3}\}=T_{2}$, by Definition 2, there exists $j\in[d]$ such that $E_{j}^{(t)}=-\Theta\left(\sqrt{\epsilon_{0}}\right)$. Combining with eq. (22) gives us $\forall i\in[d]:E_{i}^{(t)}=-\Theta\left(\sqrt{\epsilon_{0}}\right)$. Combining these two cases, we get that when $t=\min\{T_{2},T_{3}\}$, $\left\|E^{(t)}\right\|^{2}_{2}\leq\max\{\epsilon,\Theta\left(\epsilon_{0}d\right)\}=\mathcal{O}\left(\epsilon_{0}d\right)$.

C.10 Proof of Lemma 9

We prove this lemma by induction. The base case ($t=T_{1}$) of $\boldsymbol{v}^{(t)}$ is verified by eq. (16).

Suppose at time $t$, $\boldsymbol{v}^{(t)T}=a^{(t)}A+R_{v}^{(t)T}$, then by eq. (17), we have that

where $d^{(t)}:=c^{(t)}\left\|\boldsymbol{u}^{(T_{1})}\right\|^{2}+R_{2}^{(t)T}\boldsymbol{u}^{(T_{1})},\quad R_{3}^{(t)T}:=c^{(t)}\boldsymbol{u}^{(T_{1})T}R_{1}^{(t)}+R_{2}^{(t)T}R_{1}^{(t)}$. That gives us

Therefore we have proved by induction that for $t$ in the second phase, $\boldsymbol{v}^{(t)T}=a^{(t)}A+R_{v}^{(t)T}$. The above steps also proved eq. (20).

C.11 Proof of Lemma 10

Write $r_{1}^{(t)}=q_{11}^{(t)}+q_{12}^{(t)}$ where we have $q_{11}^{(t)}=\eta\sum_{\tau=0}^{t}\beta^{t-\tau}\left(W_{2}^{(t)T}E^{(t)}-W_{2}^{(\tau)T}E^{(\tau)}\right)$, $q_{12}^{(t)}=-\eta\sum_{\tau=0}^{t}\beta^{t-\tau}Dg_{1}^{(\tau)}$. Write $r_{2}^{(t)}=q_{21}^{(t)}+q_{22}^{(t)}$ where $q_{21}^{(t)}=\eta\sum_{\tau=0}^{t}\beta^{t-\tau}\left(E^{(t)}W_{1}^{(t)T}-E^{(\tau)}W_{1}^{(\tau)T}\right)$, $q_{22}^{(t)}=-\eta\sum_{\tau=0}^{t}\beta^{t-\tau}Dg_{2}^{(\tau)}$.

Let’s first bound $\left|q_{11}^{(t)}[i,j]\right|$ and $\left|q_{21,i}^{(t)}\right|$. By definition of $T_{2}$, we know that for $T_{1}\leq\tau\leq t$, $\forall i\in[d]:\left|E_{i}^{(\tau)}\right|=\mathcal{O}(1)$. Then we have for all $i,j\in[d]$,

We can further get that for all $j\in[d]$,

Combining the above inequalities gives us $\forall i,j\in[d]$,

Next let’s bound $\left|q_{12}^{(t)}[i,j]\right|$ and $\left|q_{22,i}^{(t)}\right|$. By the assumption of this lemma and the analysis before $T_{1}$, we know that for all $\tau\leq t$, the $M_{1}^{(\tau)},M_{2}^{(\tau)}$ in Lemma 31 are upper bounded by $\mathcal{O}\left(\frac{1}{d^{1/4}}\right)$ and $\mathcal{O}\left(d^{1/4}\right)$, respectively. In the theorem we consider the training period before $T_{\text{SGD},2}$ so the time $T$ in Lemma 31 is set as $T_{\text{SGD},2}$. In the following sections, we will prove that $T_{\text{SGD},2}\leq\mathcal{O}\left(\frac{d^{\alpha}\log(\sqrt{d/\epsilon})}{\eta}\right)$. Then by Lemma 31, we have with probability at least $1-\frac{1}{d}$, for all $\tau\leq t$ and all $i,j\in[d]$,

Combining the above bounds, we get that $\forall i,j\in[d]$,

C.12 Proof of Lemma 11

Let’s first try to bound the length of $\min\{T_{2},T_{3}\}$. More formally, we prove that under the conditions of Lemma 10 and picking $\eta\leq\mathcal{O}\left(\frac{\epsilon}{d^{\frac{7\alpha}{4}+4}}\right)$, we have that $\min\{T_{2},T_{3}\}\leq\mathcal{O}\left(\frac{d^{\alpha}\log(\sqrt{d/\epsilon})}{\eta}\right)$.

Under the conditions of Lemma 10, we know that

When $T_{1}\leq t<T_{2}$, we have proved that $c^{(t)}$ is increasing over time in Section C.8, which implies that $\left\|W_{2}^{(t)}\right\|_{2}^{2}\geq C\left\|W_{2}^{(T_{1})}\right\|_{2}^{2}$ since $c^{(t)}\boldsymbol{u}^{(T_{1})T}$ is the leading term of $W_{2}^{(t)}$. Combining with $\eta_{t}\geq\eta$ gives us

where $(i)$ uses $\left\|E^{(T_{1})}\right\|_{2}=\mathcal{O}(\sqrt{d})$. By picking $\eta\leq\mathcal{O}\left(\frac{\epsilon}{d^{\frac{7\alpha}{4}+4}}\right)$ and noticing that $\left\|W_{2}^{(T_{1})}\right\|_{2}^{2}\geq\Omega\left(\frac{1}{d^{\alpha}}\right)$, we have $\frac{\eta d^{4}}{\left\|W_{2}^{(T_{1})}\right\|_{2}^{2}}<\frac{\sqrt{\epsilon}}{2}$. Hence when $t-T_{1}\geq\Theta\left(\frac{\log\left(\sqrt{d/\epsilon}\right)}{\eta\left\|W_{2}^{(T_{1})}\right\|_{2}^{2}}\right)$, we have that $\left\|E^{(t)}\right\|_{2}\leq\sqrt{\epsilon}$, i.e. $\left\|E^{(t)}\right\|^{2}_{2}\leq\epsilon$.

That means after at most $\mathcal{O}\left(\frac{\log\left(\sqrt{d/\epsilon}\right)}{\eta\left\|W_{2}^{(T_{1})}\right\|_{2}^{2}}\right)$ steps from $T_{1}$, either $t\geq T_{2}$, or we have $\left\|E^{(t)}\right\|^{2}_{2}\leq\epsilon$. In other words, $\min\{T_{2},T_{3}\}\leq T_{1}+\mathcal{O}\left(\frac{\log\left(\sqrt{d/\epsilon}\right)}{\eta\left\|W_{2}^{(T_{1})}\right\|_{2}^{2}}\right)\leq\mathcal{O}\left(\frac{d^{\alpha}\log\left(\sqrt{d/\epsilon}\right)}{\eta}\right)$.
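The step count above has the usual form for a linearly contracting error. As a simplified illustration, suppose the error satisfied a clean contraction $\left\|E^{(t+1)}\right\|_{2}\leq(1-\rho)\left\|E^{(t)}\right\|_{2}$ with rate $\rho=\Theta\left(\eta\left\|W_{2}^{(T_{1})}\right\|_{2}^{2}\right)\in(0,1)$, ignoring the additive error term that the proof controls separately. This clean form and the symbol $\rho$ are assumptions made only for this illustration. Then the following calculation recovers the stated order of the number of steps:

```latex
\begin{align*}
\left\|E^{(T_{1}+k)}\right\|_{2}
  &\leq (1-\rho)^{k}\left\|E^{(T_{1})}\right\|_{2}
   \leq e^{-\rho k}\,\mathcal{O}\left(\sqrt{d}\right), \\
e^{-\rho k}\sqrt{d}\leq\sqrt{\epsilon}
  &\iff k\geq\frac{1}{\rho}\log\sqrt{\frac{d}{\epsilon}}
   =\Theta\left(\frac{\log\left(\sqrt{d/\epsilon}\right)}{\eta\left\|W_{2}^{(T_{1})}\right\|_{2}^{2}}\right).
\end{align*}
```

Using $\left\|W_{2}^{(T_{1})}\right\|_{2}^{2}\geq\Omega\left(\frac{1}{d^{\alpha}}\right)$ in the denominator then gives a bound of order $\frac{d^{\alpha}\log\left(\sqrt{d/\epsilon}\right)}{\eta}$ on the number of steps after $T_{1}$.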

Combining $\min\{T_{2},T_{3}\}\leq\mathcal{O}\left(\frac{d^{\alpha}\log(\sqrt{d/\epsilon})}{\eta}\right)$ and Lemma 10 yields that for $t+1\leq\min\{T_{2},T_{3}\}$,

C.13 Proof of Lemma 12

The proof in Section C.8 tells us that for $T_{1}\leq\tau\leq T_{2}$, $c^{(\tau)}>0,\forall j\in[d]:v_{j}^{(\tau)}>0$, which gives us $0\leq\frac{\left|R_{2}^{(t+1)T}\boldsymbol{u}^{(T_{1})}\right|}{c^{(t+1)}\left\|\boldsymbol{u}^{(T_{1})}\right\|^{2}}$ and $0\leq\frac{\left|R^{(t+1)}_{3j}\right|}{c^{(t+1)}\left\|\boldsymbol{u}^{(T_{1})}\right\|^{2}v_{j}^{(t+1)}}$. By Lemma 11, we have that

Since $t<T_{2}$, we have $\forall j\in[d]:\frac{\left(W_{2}^{(t+1)}W_{1}^{(t+1)}\right)_{j}}{A_{j}}=\mathcal{O}(1)$, which yields

C.14 Proof of Lemma 13

Substituting into the time $t$ version of eq. (24) yields

Since $t<T_{2}$, we have $E_{j}^{(t)}<-\sqrt{\epsilon_{0}}$. Combining with $A_{j}=\Theta(1)$ gives us $a^{(t)}c^{(t)}\left\|\boldsymbol{u}^{(T_{1})}\right\|^{2}-1=-\Omega\left(\sqrt{\epsilon_{0}}\right)$. Then we can rewrite $E_{j}^{(t)}$ as $\forall j\in[d]$,

(B) Note that we assume $\forall j\in[d]:\quad\frac{v_{j}^{(t)}}{c^{(t)}}=\Theta\left(\frac{1}{\sqrt{d}}\right)$, then we have for $j\in[d]$,

Then for $t+1$, we have that for $j\in[d]$,

(C) Combining eq. (25) and $\forall j\in[d]:A_{j}=\Theta(1)$, we know that

Under the conditions of Lemma 10, and picking $\eta\leq\mathcal{O}\left(\frac{\epsilon}{d^{\frac{7\alpha}{4}+4}}\right)$, we have that $1-\eta_{t}c^{(t)}d^{(t)}\geq 1-\eta c^{(t)}d^{(t)}>0$.

Appendix D Analysis of Adam

Note that $A=\frac{1}{m}YX^{T}$, $\Lambda_{xx}:=\frac{1}{m}XX^{T}$. Denote $g_{k}^{(t)}:=\nabla_{W_{k}}L(W^{(t)}),k=1,2$. We have that

Denote the $i$-th coordinate of $W_{2}W_{1}$ and $A$ as $(W_{2}W_{1})_{i}$ and $A_{i}$, respectively. By Assumption 2 and the assumption that $\forall i\in[d]:A_{i}>0,A_{i}=\Omega(1)$, at the beginning, w.h.p., $\forall i\in[d]:(W_{2}W_{1})_{i}-A_{i}<0$. Based on this, we divide the training procedure into two phases (note that these two phases are different from those of GD).

First phase: when the error $(W_{2}W_{1})_{i}-A_{i}$ is negative and its absolute value is big for all $i\in[d]$.

Second phase: when $(W_{2}W_{1})_{i}-A_{i}$ is close to zero for some coordinate $i\in[d]$.

More formally, we define the boundary between the two phases below.

Definition 4. The end of the first phase (denoted as $T_{1}$) is defined as $T_{1}=\inf\left\{t>0:\exists i\in[d]:E^{(t)}_{i}\geq-\sqrt{\eta d}\right\}$.

In the second phase, we define some time points.

Define $T_{g}:=\inf\left\{t>T_{1}:\exists i\in[d]:\left|g_{2i}^{(t)}\right|\leq d\sqrt{\eta}\right\}$.

For $t<T_{1}$, we have $\forall i\in[d]:E_{i}^{(t)}<0$ by Definition 4. For $t>T_{1}$, some $E_{i}^{(t)}$ may flip sign and become positive. For a given coordinate $i$, we define the following “flip time”.

We can first show that after a few steps in the first phase, $W_{1}$ will become an approximately rank-1 matrix, as described in the following lemma.

Lemma 14. Under Assumption 1, 2 and 3, suppose $\sigma\leq\frac{\eta^{3/2}\xi^{2}}{d^{13/4}}$. By picking $\eta\leq\mathcal{O}\left(\frac{1}{d^{3\alpha}}\right),\xi\leq\sqrt{\frac{\eta}{d^{3\alpha-1}}}$, and $\beta_{2}=\beta_{1}^{2}$, there exists $t_{\text{inc}}>0$ such that w.h.p. for $t_{\text{inc}}\leq t<T_{1}$,

Now we are ready to prove the Adam part of Theorem 1.

D.2 Proof of Lemma 14

For some time $t$, we introduce two conditions.

Next we prove that, under Assumption 1 and 2, by picking $\eta\leq\mathcal{O}\left(\frac{1}{d^{3\alpha}}\right)$, $\xi\leq\sqrt{\frac{\eta}{d^{3\alpha-1}}}$, and $\beta_{2}=\beta_{1}^{2}$, there exists $t_{\text{inc}}>0$ such that for $t_{\text{inc}}\leq t<T_{1}$, the weights can be approximated in the following way.

Before we dive into the proof, let’s introduce some useful lemmas.

The following lemma reflects our key idea: converting the exponential average in Adam to a finite-step average, and trying to bound the stochastic error terms in eq. (28).

where $H\geq\frac{1}{1-\beta_{1}}\log\frac{\max\left\{G_{1}^{(t)},G_{2}^{(t)},\left(G_{1}^{(t)}\right)^{2},\left(G_{2}^{(t)}\right)^{2}\right\}}{\eta\xi^{2}}$ and

The following lemma analyzes the magnitude of weights during a short period at the beginning.

Under Assumption 1, 2 and 3, suppose $\sigma\leq\frac{\eta^{3/2}\xi^{2}}{d^{13/4}}$. Pick $\xi\leq\frac{1}{d^{\frac{3}{2}\alpha}}$, then there exists some time point $t_{\text{inc}}\in(H,T_{1})$, such that w.h.p., for $t\leq t_{\text{inc}}$, for every $i,j\in[d]$,

Specifically, when $t=t_{\text{inc}}$, we have $\text{sign}\left(w_{2i}^{(t_{\text{inc}})}\right)=\text{sign}\left(W_{1}^{(t_{\text{inc}})}[i,j]\right)=\text{sign}\left(w_{2i}^{(0)}\right)$, $\left|W_{1}^{(t_{\text{inc}})}[i,j]\right|=\Theta\left(\frac{1}{d^{\frac{3}{2}\alpha+1}}\right)$ and $\left|g_{1}^{\left(t_{\text{inc}}\right)}[i,j]\right|\geq\Omega(\xi)$, $\left|g_{2i}^{\left(t_{\text{inc}}\right)}\right|\geq\Omega(\xi)$. Moreover, Condition 1 and 2 are satisfied for $t=t_{\text{inc}}$. The $s_{1}^{(t)}[i,j]$ and $s_{2i}^{(t)}$ in the conditions are both $-\text{sign}\left(w_{2i}^{(0)}\right)$.

The following lemma gives us lower bounds on $\left|g_{1}^{(t)}[i,j]\right|$ and $\left|g_{2i}^{(t)}\right|$.

The following lemma shows that when $t_{\text{inc}}\leq t<T_{1}$, we have $\forall i,j\in[d]:\left|g^{(t)}_{2i}\right|\gg\left|g_{2i}^{(t)}-g_{2i}^{(t-1)}\right|$ and $\left|g^{(t)}_{1}[i,j]\right|\gg\left|g_{1}^{(t)}[i,j]-g_{1}^{(t-1)}[i,j]\right|$.

Under Assumption 1, 2 and 3, suppose $\sigma\leq\frac{\eta^{3/2}\xi^{2}}{d^{13/4}}$. Pick $\xi\leq\sqrt{\frac{\eta}{d^{3\alpha-1}}}$, $\eta\leq\mathcal{O}\left(\frac{1}{d^{3\alpha}}\right)$. For $t_{\text{inc}}$ in Lemma 17, we have that w.h.p. for $t_{\text{inc}}\leq t<T_{1}$ and $\tau\leq t$, $\forall i,j\in[d]$,

Equipped with these lemmas, now let’s prove eq. (29).

Let’s first look at the update of $W_{1}^{(t)}[i,j]$. For $t$ in the first phase, we write the RHS of eq. (32) as

By Lemma 18, we know that $\left|g_{1}^{(t)}[i,j]\right|=\Omega\left(\sqrt{\eta}\right)$. Then we have that

Therefore by Lemma 33 in Appendix G, we have

Since $\beta\in(0,1)$, we know that $\log\beta\leq\beta-1<0$. Then our choice of $H$ gives us $H=\frac{1}{1-\beta_{1}}\log\frac{d}{\eta\xi^{2}}\geq\frac{\log\frac{\eta\xi^{2}}{d}}{\log\beta_{1}}$ and $H>\frac{1}{1-\beta_{2}}\log\frac{d}{\eta\xi^{2}}\geq\frac{\log\frac{\eta\xi^{2}}{d}}{\log\beta_{2}}$, which implies that $\beta_{1}^{H},\beta_{2}^{H}\leq\eta\xi^{2}/d$. Hence for $t\geq t_{\text{inc}}>H$, $\eta_{t}\frac{1-\beta_{1}^{H+1}}{\sqrt{1-\beta_{2}^{H+1}}}=\eta\frac{\sqrt{1-\beta_{2}^{t+1}}}{\sqrt{1-\beta_{2}^{H+1}}}\frac{1-\beta_{1}^{H+1}}{1-\beta_{1}^{t+1}}=\eta(1\pm\mathcal{O}(\eta))$.
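A quick numerical sanity check of this step, under assumed values. The sketch takes the bias-corrected step size to be $\eta_{t}=\eta\sqrt{1-\beta_{2}^{t+1}}/(1-\beta_{1}^{t+1})$, which is the form implied by the identity above; the concrete values of $\eta$, $\xi$, $d$ and $\beta_{1}$ are illustrative only and are not taken from the analysis.

```python
import math


def horizon(beta1: float, eta: float, xi: float, d: int) -> int:
    """Integer H >= (1 / (1 - beta1)) * log(d / (eta * xi^2))."""
    return math.ceil(math.log(d / (eta * xi ** 2)) / (1.0 - beta1))


def corrected_ratio(t: int, H: int, beta1: float, beta2: float) -> float:
    """(eta_t / eta) * (1 - beta1^(H+1)) / sqrt(1 - beta2^(H+1)).

    Assumes eta_t = eta * sqrt(1 - beta2^(t+1)) / (1 - beta1^(t+1)).
    """
    eta_t_over_eta = math.sqrt(1 - beta2 ** (t + 1)) / (1 - beta1 ** (t + 1))
    return eta_t_over_eta * (1 - beta1 ** (H + 1)) / math.sqrt(1 - beta2 ** (H + 1))


# Illustrative (hypothetical) hyperparameters with beta2 = beta1^2.
beta1 = 0.9
beta2 = beta1 ** 2
eta, xi, d = 1e-3, 1e-2, 10
H = horizon(beta1, eta, xi, d)
```

With these values both $\beta_{1}^{H}$ and $\beta_{2}^{H}$ fall below $\eta\xi^{2}/d$, and `corrected_ratio(t, H, beta1, beta2)` stays within $\eta$ of 1 for every $t>H$ that we tried.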

So far we have successfully proved eq. (29). By $\text{sign}\left(\Delta W_{1}^{(t)}[i,j]\right)=\text{sign}\left(\Delta w_{2i}^{(t)}\right)=\text{sign}\left(w_{2i}^{(0)}\right)$ in Lemma 18, we know that $\text{sign}\left(-g_{1}^{(t)}[i,j]\right)=\text{sign}\left(-g_{2i}^{(t)}\right)=\text{sign}\left(w_{2i}^{(0)}\right)$, which gives us

Finally, to complete the proof, we show that $T_{1}=\Theta\left(\frac{1}{\sqrt{d}\eta}\right)$. When $t=T_{1}$, we have $\forall j\in[d]:\sum_{i=1}^{d}w_{2i}^{(T_{1})}W_{1}^{(T_{1})}[i,j]=\Theta(1)$. Combining with the above results, we know that $d\eta^{2}(T_{1}-t_{\text{inc}})^{2}=\Theta(1)$, i.e. $\eta(T_{1}-t_{\text{inc}})=\Theta\left(\frac{1}{\sqrt{d}}\right)$. In Section D.5, we will prove $t_{\text{inc}}=\Theta\left(\frac{1}{\eta d^{\frac{3}{2}\alpha+1}}\right)$. Since $\frac{1}{\eta d^{\frac{3}{2}\alpha+1}}\ll\frac{1}{\sqrt{d}\eta}$, the term $t_{\text{inc}}$ is of lower order than $T_{1}-t_{\text{inc}}$, and we have $T_{1}=\Theta\left(\frac{1}{\sqrt{d}\eta}\right)$.

D.3 Proof of Lemma 16

For certain $t$ and $H$, we write eq. (27) as

and $r_{1n}^{(t)}[i,j],r_{1d}^{(t)}[i,j],r_{2n,i}^{(t)},r_{2d,i}^{(t)}$ are defined in eq. (28).

Since $\beta_{2}=\beta_{1}^{2}<\beta_{1}$, then if we pick $H\geq\frac{1}{1-\beta_{1}}\log\frac{\max\left\{G_{1}^{(t)},G_{2}^{(t)},\left(G_{1}^{(t)}\right)^{2},\left(G_{2}^{(t)}\right)^{2}\right\}}{\eta\xi^{2}}$, we can get that $H\geq\frac{1}{1-\beta_{1}}\log\frac{G_{1}^{(t)}}{\eta\xi^{2}}$, $H\geq\frac{1}{1-\beta_{2}}\log\frac{\left(G_{1}^{(t)}\right)^{2}}{\eta\xi^{2}}$, $H\geq\frac{1}{1-\beta_{1}}\log\frac{G_{2}^{(t)}}{\eta\xi^{2}}$, $H\geq\frac{1}{1-\beta_{2}}\log\frac{\left(G_{2}^{(t)}\right)^{2}}{\eta\xi^{2}}$. Hence we can apply Lemma 32 in Appendix G to get that $\left|q_{1n}^{(t)}[i,j]\right|,\left|q_{1d}^{(t)}[i,j]\right|,\left|q_{2n,i}^{(t)}\right|,\left|q_{2d,i}^{(t)}\right|\leq\eta\xi^{2}$.

D.4 Proof of Corollary 2

D.5 Proof of Lemma 17

The proof is based on the following two lemmas.

Under Assumption 1 and 2, we have that w.p. at least $1-\frac{1}{d^{\frac{\alpha}{2}-1}}$, for every $1\leq i\leq d$, $\frac{\sqrt{\pi}}{d^{\frac{3}{2}\alpha}}\leq\left|w_{2i}^{(0)}\right|\leq\sqrt{\frac{2}{d^{2\alpha}}\log\frac{2d}{\delta}}$, and that w.p. at least $1-\delta$ for any given $\delta>0$, $\left|W_{1}^{(0)}[i,j]\right|\leq\sqrt{\frac{2}{d^{4\alpha}}\log\frac{2d^{2}}{\delta}}$.

Furthermore, if for certain $i,j\in[d]$, Condition 1 (resp. Condition 2) is satisfied, we will have

Hence we have shown that at some time point $t_{0}$, we have $\forall i,j\in[d]:\text{sign}\left(W_{1}^{(t)}[i,j]\right)=\text{sign}\left(w_{2i}^{(t)}\right)=\text{sign}\left(w_{2i}^{(0)}\right)$. Now we analyze the period $t\geq t_{0}$.

Moreover, $\forall t_{\text{inc}}-H\leq\tau\leq t_{\text{inc}},\forall i\in[d]:\text{sign}\left(g_{2i}^{(\tau)}\right)=-\text{sign}\left(w_{2i}^{(0)}\right)$. Then Condition 2 is satisfied with $s_{2i}^{(t)}=-\text{sign}\left(w_{2i}^{(0)}\right)$ for $t=t_{\text{inc}}$. In the analysis of $g_{1}^{(t)}[i,j]$, we have already shown that for all $t\leq t_{\text{sign}}$ (and thus for $t=t_{\text{inc}}$), Condition 1 is satisfied, which completes the proof.

D.6 Proof of Lemma 20

Since for $X\sim\mathcal{N}\left(0,\sigma^{2}\right)$, we have that $P(|X|\leq t)\leq\frac{2t}{\sqrt{2\pi}\sigma}$, then for a fixed $i$,

Then by union bound, we have that w.p. at least $1-\frac{1}{d^{\frac{\alpha}{2}-1}}$, for every $1\leq i\leq d$, $\left|w_{2i}^{(0)}\right|\geq\frac{\sqrt{\pi}}{d^{\frac{3}{2}\alpha}}$.
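The anti-concentration bound $P(|X|\leq t)\leq\frac{2t}{\sqrt{2\pi}\sigma}$ used at the start of this proof follows because the Gaussian density is at most $\frac{1}{\sqrt{2\pi}\sigma}$ and the interval $[-t,t]$ has length $2t$. A minimal numerical check, with helper names that are purely illustrative, compares the bound against the exact probability computed with the error function:

```python
import math


def prob_abs_below(t: float, sigma: float) -> float:
    """Exact P(|X| <= t) for X ~ N(0, sigma^2), via the error function."""
    return math.erf(t / (sigma * math.sqrt(2.0)))


def anti_concentration_bound(t: float, sigma: float) -> float:
    """Upper bound 2t / (sqrt(2*pi) * sigma): density maximum times interval length."""
    return 2.0 * t / (math.sqrt(2.0 * math.pi) * sigma)
```

The bound is tight for small $t/\sigma$ and becomes vacuous (larger than 1) once $t$ exceeds roughly $1.25\sigma$, which is why it is applied here only to a threshold far below the initialization scale.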

As for the upper bounds, using the Gaussian tail bound and union bound, we have w.p. at least $1-\delta$,

D.7 Proof of Lemma 21

Now we analyze the magnitude order of $\Delta W_{1}^{(t)}[i,j]$. The analysis of $\Delta w_{2i}^{(t)}$ is similar.

where $(i)$ uses Cauchy-Schwarz inequality for the numerator.

On the other hand, when $\text{sign}\left(g_{1}^{(t-H)}[i,j]\right)=\text{sign}\left(g_{1}^{(t-H+1)}[i,j]\right)=\dots=\text{sign}\left(g_{1}^{(t)}[i,j]\right)=s_{1}^{(t)}[i,j]$, we have

Using $\sqrt{x+y}\leq\sqrt{|x|}+\sqrt{|y|}$, we obtain that

Together with the upper bound, this completes the proof.
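The displays of this proof are not reproduced here, so the following is only a hedged illustration of the two mechanisms it names, not the exact inequalities used. With $\beta_{2}=\beta_{1}^{2}$, Cauchy-Schwarz gives $\left(\sum_{\tau=0}^{H}\beta_{1}^{\tau}g_{\tau}\right)^{2}\leq(H+1)\sum_{\tau=0}^{H}\beta_{2}^{\tau}g_{\tau}^{2}$ for any window of gradients, and when all $g_{\tau}$ share one sign we also have $\left|\sum_{\tau=0}^{H}\beta_{1}^{\tau}g_{\tau}\right|\geq\sqrt{\sum_{\tau=0}^{H}\beta_{2}^{\tau}g_{\tau}^{2}}$. The gradient windows below are made-up numbers.

```python
import math


def window_sums(grads, beta1):
    """Return (sum beta1^k * g_k, sum beta2^k * g_k^2) with beta2 = beta1^2."""
    beta2 = beta1 ** 2
    num = sum(beta1 ** k * g for k, g in enumerate(grads))
    den = sum(beta2 ** k * g * g for k, g in enumerate(grads))
    return num, den


def upper_bound_holds(grads, beta1):
    """Cauchy-Schwarz: |num| <= sqrt((H + 1) * den), for any signs."""
    num, den = window_sums(grads, beta1)
    return abs(num) <= math.sqrt(len(grads) * den) + 1e-12


def same_sign_lower_bound_holds(grads, beta1):
    """If all terms share a sign, |num| >= sqrt(den)."""
    num, den = window_sums(grads, beta1)
    return abs(num) >= math.sqrt(den) - 1e-12
```

The lower bound can fail once signs are mixed (for example the window `[1.0, -1.0]`), which is consistent with the proof requiring a constant sign over the last $H+1$ steps.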

D.8 Proof of Lemma 18

The proof is based on the following lemma, which gives a coarse analysis on the magnitude of weights and their increments per step during the first phase.

Under Assumption 1, 2 and 3, suppose $\sigma\leq\frac{\eta^{3/2}\xi^{2}}{d^{13/4}}$. Pick $\xi\leq\min\left\{\sqrt{\frac{\eta}{d^{3\alpha-1}}},\frac{1}{d^{\frac{3}{2}\alpha}}\right\}$, for $t_{\text{inc}}$ in Lemma 17, we have that w.h.p. for all $t_{\text{inc}}\leq t\leq T_{1}$, $\forall i,j\in[d]$,

Now we go back to the proof of Lemma 18. For $t_{\text{inc}}\leq t<T_{1}$, since $E_{j}^{(t)}=\left(W_{2}^{(t)}W_{1}^{(t)}\right)_{j}-A_{j}=\sum_{i=1}^{d}w_{2i}^{(t)}W_{1}^{(t)}[i,j]-A_{j}$, we have,

Combining Lemma 22 and eq. (34) gives us $\forall j\in[d]$,

Let’s first analyze $g_{1}^{(t)}[i,j]$. Note that

where $\text{sign}\left(w_{2i}^{(t+1)}\Delta E_{j}^{(t)}\right)=\text{sign}\left(w_{2i}^{(0)}\right)$ while $\text{sign}\left(\Delta w_{2i}^{(t)}E_{j}^{(t)}\right)=-\text{sign}\left(w_{2i}^{(0)}\right)$.

As for $g_{2i}^{(t)}$, for every $i\in[d]$ the entries $W_{1}^{(t)}[i,j]$ for different $j$ have the same sign. Combining this with $\forall j\in[d]:E_{j}^{(t)}<0$ gives us

D.9 Proof of Lemma 22

For any $i,j\in[d]$, and any $t$ in the interval $[t_{\text{inc}},T_{1}]$, we prove by induction that

$\forall\tau\in[t-H,t]:\text{sign}\left(W_{1}^{(\tau)}[i,j]\right)=\text{sign}\left(w_{2i}^{(\tau)}\right)=\text{sign}\left(w_{2i}^{(0)}\right)$.

$\left|g_{1}^{(t)}[i,j]\right|\geq\Omega(\xi)$, $\left|g_{2i}^{(t)}\right|\geq\Omega(\xi)$.

The base case $t=t_{\text{inc}}$ was already proven by Lemma 17.

In Section D.2 we have shown that $T_{1}=\Theta\left(\frac{1}{\sqrt{d}\eta}\right)$. Then for $t\in[t_{\text{inc}},T_{1})$, we can use Lemma 21 to get that $\forall t_{\text{inc}}\leq\tau\leq t$, $\forall i,j\in[d]$,

Since, when $t=t_{\text{inc}}$, $\text{sign}\left(W_{1}^{(t_{\text{inc}})}[i,j]\right)=\text{sign}\left(w_{2i}^{(t_{\text{inc}})}\right)=\text{sign}\left(w_{2i}^{(0)}\right)$, we get that for $t_{\text{inc}}\leq\tau\leq t$,

That means $\forall\tau\in[t+1-H,t+1]:\text{sign}\left(W_{1}^{(\tau)}[i,j]\right)=\text{sign}\left(w_{2i}^{(\tau)}\right)=\text{sign}\left(w_{2i}^{(0)}\right)$. This proves (B) for time $t+1$.

On the other hand, we get that $\left|W_{1}^{(t+1)}[i,j]\right|\geq\left|W_{1}^{(t_{\text{inc}})}[i,j]\right|=\Theta\left(\frac{1}{d^{\frac{3}{2}\alpha+1}}\right)$ and $\left|w_{2i}^{(t+1)}\right|\geq\left|w_{2i}^{(t_{\text{inc}})}\right|=\Omega\left(\frac{1}{d^{\frac{3}{2}\alpha}}\right)$. Since $t+1\leq T_{1}$, we have $\forall j\in[d]:\left|E_{j}^{(t+1)}\right|\geq\sqrt{\eta d}$. Then

Since $t+1\leq T_{1}$, which means $\forall j\in[d]:\left(W_{2}^{(t+1)}W_{1}^{(t+1)}\right)_{j}\leq\mathcal{O}(1)$, we obtain that

D.10 Proof of Lemma 19

Then eq. (31) immediately follows from eq. (30).

D.11 Proof of Lemma 15

We divide Lemma 15 into the following three lemmas. Combining them together immediately gives us the whole proof.

The first lemma below gives us the structure of $W_{2}$ in the second phase and that of $W_{1}$ under some conditions.

The second lemma below also analyzes the structure of $W_{1}$ but removes the conditions in Lemma 23.

D.12 Proof of Lemma 23

The proof is based on the following lemma, which gives a coarse analysis on the magnitude of weights and their increments per step during the second phase.

Equipped with Lemma 26, we are ready to prove Lemma 23. We will only prove the results of $w_{2i}^{(t)}$. The proof for $W_{1}^{(t)}[i,j]$ uses the same techniques.

By Lemma 14, we have that at the end of the first phase ($t=T_{1}$),

D.13 Proof of Lemma 26

Now we analyze the magnitude order of $\left|w_{2i}^{(t+1)}\right|$ and $\left|W_{1}^{(t+1)}[i,j]\right|$. Let’s first analyze $\left|w_{2i}^{(t+1)}\right|$.

D.14 Proof of Lemma 24

Since $\left(c^{(t-\tau)}E_{j}^{(t-\tau)}\right)^{2}>0$, in eq. (38) we have that

However, in eq. (37), we cannot similarly prove that $\left|\sum_{\tau=0}^{H}\beta_{1}^{\tau}R_{g,1}^{(t-\tau)}[i,j]\right|\ll\left|\sum_{\tau=0}^{H}\beta_{1}^{\tau}c^{(t-\tau)}E_{j}^{(t-\tau)}\right|$ because $c^{(t-\tau)}E_{j}^{(t-\tau)}$ may not have the same sign for $\tau=0,1,\dots,H$. To deal with eq. (37), we need to consider the two cases where $\left|\sum_{\tau=0}^{H}\beta_{1}^{\tau}R_{g,1}^{(t-\tau)}[i,j]\right|\ll\left|\sum_{\tau=0}^{H}\beta_{1}^{\tau}c^{(t-\tau)}E_{j}^{(t-\tau)}\right|$ or $\left|\sum_{\tau=0}^{H}\beta_{1}^{\tau}R_{g,1}^{(t-\tau)}[i,j]\right|\not\ll\left|\sum_{\tau=0}^{H}\beta_{1}^{\tau}c^{(t-\tau)}E_{j}^{(t-\tau)}\right|$.

$\left|(1-\beta_{1})\sum_{\tau=0}^{H}\beta_{1}^{\tau}R_{g,1}^{(t-\tau)}[i,j]+\epsilon_{1n}^{(t)}[i,j]\right|\leq\left(\eta^{\frac{1}{4}}+\frac{1}{d^{\frac{\alpha}{2}-\frac{1}{4}}}\right)\left|(1-\beta_{1})\sum_{\tau=0}^{H}\beta_{1}^{\tau}c^{(t-\tau)}E_{j}^{(t-\tau)}\right|$.

$\left|(1-\beta_{1})\sum_{\tau=0}^{H}\beta_{1}^{\tau}R_{g,1}^{(t-\tau)}[i,j]+\epsilon_{1n}^{(t)}[i,j]\right|>\left(\eta^{\frac{1}{4}}+\frac{1}{d^{\frac{\alpha}{2}-\frac{1}{4}}}\right)\left|(1-\beta_{1})\sum_{\tau=0}^{H}\beta_{1}^{\tau}c^{(t-\tau)}E_{j}^{(t-\tau)}\right|$.

where $(i)$ uses Cauchy-Schwarz inequality and $\beta_{2}=\beta_{1}^{2}$, $(ii)$ uses eq. (38) and (39).

By the first phase analysis, we have that

where $\left|R_{\mathcal{T}}[i,j]\right|\leq\mathcal{O}\left(\eta^{\frac{1}{4}}+\frac{1}{d^{\frac{\alpha}{2}-\frac{1}{4}}}\right)\left(\left|V_{j}^{(T_{1})}\right|+\sum_{\tau\in\mathcal{T}}\left|v_{j}^{(\tau)}\right|\right)$.

Combining the above results together yields

Therefore, we have that for any $j\in[d]$,

D.15 Proof of Lemma 25

D.16 Proof of Lemma 27

where $(i)$ is because $\text{sign}\left(w_{2i}^{(t)}\right)=\text{sign}\left(W_{1}^{(t)}[i,j]\right)=\text{sign}\left(w_{2i}^{(0)}\right)$. $E_{j}^{(t+1)}$ may not be smaller than $E_{j}^{(t)}$, but we will show that after at most $t_{s}$ steps for some $t_{s}$, we will have $E_{j}^{(t+t_{s}+1)}<E_{j}^{(t+t_{s})}$.

Appendix E Hessian tends to become more and more diagonal during training

In this section, we empirically demonstrate that the trend of loss Hessian in practice is to become more and more diagonal during training. We also give a rigorous theoretical analysis on a two-layer network under Assumption 1 and 2.

Let’s first define the diagonal domination of the $i$-th coordinate at time $t$.

To measure the diagonal domination of the whole Hessian, we need to consider the distribution of $r_{\text{diag},i}^{\text{OPT}}(t)$ for different $i$. Figure 14 shows the mean and median of $r_{\text{diag},i}^{\text{SGDM}}(t)$ and $r_{\text{diag},i}^{\text{Adam}}(t)$ on the sentence classification task (see Section 4.1). Here we chose 4 layers (Layer #6, 12, 17 and 22) and computed the Hessians across these 4 layers. Since the number of parameters is very large, we did the computation by random sampling. As we can see, for both $r_{\text{diag},i}^{\text{SGDM}}(t)$ and $r_{\text{diag},i}^{\text{Adam}}(t)$, the trend of their mean or median is to decrease over time, although there might be some oscillation.

E.2 Theoretical Analysis

To simplify the theoretical analysis, we consider the mean of $r_{\text{diag},i}^{\text{OPT}}(t)$ over all coordinates and define

We consider a 2-layer network under Assumption 1 and 2, and have two goals in our proof:

To show that $R_{\text{diag}}^{\text{OPT}}(t)$ after training is smaller than that before training ($t=0$).

Note that in our setting (see Assumption 1), the Hessian is a $(d^{2}+d)\times(d^{2}+d)$ matrix. For a completely “uniform” matrix of the same size, we have that $R_{\text{diag}}^{\text{OPT}}(t)=\Theta\left(\sqrt{d^{2}+d}\right)=\Theta(d)$. Hence our second goal is to show that $R_{\text{diag}}^{\text{OPT}}(t)$ after training is of lower order than $\Theta(d)$.
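The defining equations of the per-coordinate ratio and of eq. (40) are not reproduced in this excerpt. As a loudly labeled assumption, the sketch below takes the per-row ratio to be the $\ell_{2}$ norm of the off-diagonal entries of row $i$ divided by $|H_{ii}|$, and the aggregate to be its mean over rows. This choice is consistent with the $\Theta\left(\sqrt{d^{2}+d}\right)$ value quoted for a uniform matrix, but the paper's exact definition may differ.

```python
import numpy as np


def diag_domination(H: np.ndarray) -> np.ndarray:
    """Assumed per-row ratio: l2 norm of off-diagonal entries over |diagonal entry|."""
    diag = np.abs(np.diag(H))
    off = H - np.diag(np.diag(H))  # zero out the diagonal
    return np.linalg.norm(off, axis=1) / diag


def mean_ratio(H: np.ndarray) -> float:
    """Assumed aggregate: mean of the per-row ratios."""
    return float(np.mean(diag_domination(H)))


# A "uniform" matrix of size (d^2 + d) x (d^2 + d) with d = 3.
d = 3
n = d * d + d
uniform_ratio = mean_ratio(np.ones((n, n)))  # sqrt(n - 1) under the assumed definition
```

Under this assumed definition a perfectly diagonal matrix has ratio 0, so smaller values mean stronger diagonal domination.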

Consider the ratio $R_{\text{diag}}^{\text{OPT}}(t)$ defined in eq. (40). Under Assumption 1 and 2, we have that before training ($t=0$), with high probability,

For SGD+M defined in eq. (3): for any $p>0$, by picking the same hyperparameters as in Theorem 1, for $T_{\text{SGD},1},T_{\text{SGD},2}$ mentioned in Theorem 1, we have with constant probability, for any $t\in[T_{\text{SGD},1},T_{\text{SGD},2}]$,

For Adam defined in eq. (3): for any $p>0$, by picking the same hyperparameters as in Theorem 1, for $T_{\text{Adam},1},T_{\text{Adam},2}$ mentioned in Theorem 1, we have with high probability, for any $t\in[T_{\text{Adam},1},T_{\text{Adam},2}]$,

E.3 Proof of Theorem 2

Lemma 4.3 of gives us the following forms of Hessian.

For any $k\in\{1,2,\dots,H+1\}$, we know that

where $r=(W_{H+1}\dots W_{1}-A)^{T}$ and $C=W_{H+1}W_{H}\cdots W_{2}$.

For the 2-layer linear network, write the Hessian as

Intuitively, before training the elements of $W_{1}$ and $W_{2}$ are very close to zero, and $W_{2}W_{1}-A\approx-A$. Since the elements of $A$ are $\Theta(1)$, we know that the magnitudes of elements of $H_{21}$ are much bigger than those of $H_{11}$ and $H_{22}$.

After training, for both SGD+M and Adam, $W_{2}W_{1}-A\approx 0$. Then $H_{21}\approx(W_{2})^{T}\otimes(W_{1})^{T}$ and the magnitudes of its elements are no longer much larger than those of $H_{11}$ and $H_{22}$. From the formula of $H_{11}$, we know that all the diagonal entries are nonzero, and among the $d^{4}-d^{2}$ off-diagonal entries, there are only $d^{3}-d^{2}$ nonzero entries, which helps us to bound $R_{\text{diag}}^{\text{OPT}}(t)$.
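The entry count for $H_{11}$ can be checked on a small instance. The sketch below is an assumption-laden illustration: it takes the loss to be $\frac{1}{2}\|W_{2}W_{1}-A\|_{2}^{2}$ with identity input covariance, for which the second derivative with respect to $W_{1}[i,j]$ and $W_{1}[k,l]$ is $w_{2i}w_{2k}\mathbb{1}[j=l]$, i.e. a Kronecker product of the identity with $W_{2}^{T}W_{2}$ under column stacking. The function name and vectorization order are our choices, not the paper's.

```python
import numpy as np


def h11_offdiag_counts(w2: np.ndarray):
    """Count off-diagonal entries of the assumed W_1 Hessian block.

    Assumes H11 = I_d (Kronecker) w2 w2^T, a d^2 x d^2 block-diagonal matrix.
    Returns (number of off-diagonal entries, number of nonzero off-diagonal entries).
    """
    d = w2.shape[0]
    H11 = np.kron(np.eye(d), np.outer(w2, w2))
    n = d * d
    total_offdiag = n * n - n
    nonzero_offdiag = int(np.count_nonzero(H11) - np.count_nonzero(np.diag(H11)))
    return total_offdiag, nonzero_offdiag
```

For $d=4$ and a vector $W_{2}$ with no zero entries, this gives $d^{4}-d^{2}=240$ off-diagonal positions of which $d^{3}-d^{2}=48$ are nonzero, matching the counts stated above.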

For the $i$-th row where $1\leq i\leq d$, i.e. the $i$-th row of the submatrix $[H_{22}\quad H_{21}^{T}]$, we have

On the other hand, for the diagonal elements, we have w.h.p.

For the $(id+k)$-th row where $1\leq i\leq d,1\leq k\leq d$, i.e. the $((i-1)d+k)$-th row of the submatrix $[H_{21}\quad H_{11}]$, we have

On the other hand, for the diagonal elements, we have w.h.p.

Then we have that for $1\leq i\leq d,1\leq k\leq d$,

Taking the average, we obtain that before training, i.e. when $t=0$,

E.3.2 Proof of eq. (42)

Suppose the weight matrices have the following structure:

where $\forall 1\leq i,j\leq d:\quad\frac{|R_{1}[i,j]|}{|u_{i}v_{j}|}\leq\delta,\quad\frac{|R_{2i}|}{|cu_{i}|}\leq\delta,\quad\delta\in(0,1)$.

By the analyses in Section C.1, we know that for $t\in[T_{\text{SGD},1},T_{\text{SGD},2}]$, the weights obtained by GD with momentum satisfy

Here $\epsilon_{0}$ is defined in Definition 2. Since $\boldsymbol{u}^{(T_{1})}$ doesn’t depend on time $t$ in the period $(T_{\text{SGD},1},T_{\text{SGD},2}]$, we write $\boldsymbol{u}^{(T_{1})}$ as $\boldsymbol{u}$ for ease of notation.

Hence by Lemma 28, when $t\in[T_{\text{SGD},1},T_{\text{SGD},2}]$, we have for $1\leq i\leq d$,

By Lemma 3, we have $\boldsymbol{u}=X+Y$ where $X_{i},i\in[d]$ are i.i.d. Gaussian random variables and w.h.p.,

By the proof in Section C.8, we know that for $t\in[T_{\text{SGD},1},T_{\text{SGD},2}]$, $\forall i\in[d]:v_{i}^{(t)},c^{(t)}$ are positive. The induction in Section C.9 further gives us that for $t\in[T_{\text{SGD},1},T_{\text{SGD},2}]$, w.h.p. $\forall k\in[d]:\frac{v^{(t)}_{k}}{c^{(t)}}=\Theta\left(\frac{1}{\sqrt{d}}\right)$, which yields $\frac{c^{(t)}}{\left\|\boldsymbol{v}^{(t)}\right\|_{2}}=\Theta(1)$. Combining with eq. (47), we obtain

By the proof in Section C.8, we know that for $t\in[T_{\text{SGD},1},T_{\text{SGD},2}]$, $\forall i\in[d]:v_{i}^{(t)},c^{(t)}$ are positive and monotonically increasing. On the other hand, the proof in Section C.2 and C.9 tells us that w.h.p. $\left\|E^{(t)}\right\|_{2}$ (resp. $\forall k\in[d],\left|E^{(t)}_{k}\right|$) decreases from $\Theta(\sqrt{d})$ (resp. $\Theta(1)$) when $t=T_{\text{SGD},1}$ to $\mathcal{O}(\sqrt{\epsilon_{0}d})$ (resp. $\mathcal{O}(\sqrt{\epsilon_{0}})$) when $t=T_{\text{SGD},2}$. Therefore, the trend of $\frac{\left\|E^{(t)}\right\|_{2}}{u_{i}^{2}\left\|\boldsymbol{v}^{(t)}\right\|_{2}^{2}}$ and $\frac{\left|E^{(t)}_{k}\right|}{\left(c^{(t)}\right)^{2}u_{i}^{2}}$ is to decrease over time, and when $t=T_{\text{SGD},2}$, we have w.h.p.

Moreover, when $t=T_{\text{SGD},2}$, the inequality in eq. (26) becomes equality, i.e. $c^{2}\|\boldsymbol{u}\|_{2}^{2}=\Theta\left(\sqrt{d}\right)$ and $\forall j\in[d]:\|\boldsymbol{u}\|_{2}^{2}v_{j}^{2}=\Theta\left(\frac{1}{\sqrt{d}}\right)$.

Using $\boldsymbol{u}=X+Y$ and eq. (46), we have

which together with the second inequality in eq. (47) yields

Substituting eq. (48) and (50) into eq. (44) and (45) gives us

where the trend of $q_{1i}^{(t)}$ is to decrease over time and $q_{1i}^{(T_{\text{SGD},2})}\leq\mathcal{O}\left(\frac{\sum_{j=1}^{d}X_{j}^{2}}{X_{i}^{2}}\cdot\sqrt{\epsilon_{0}}\right)$.

where the trend of $q_{2i}^{(t)}$ is to decrease over time and $q_{2i}^{(T_{\text{SGD},2})}\leq\mathcal{O}\left(\frac{\sum_{j=1}^{d}X_{j}^{2}}{X_{i}^{2}}\cdot\sqrt{\frac{\epsilon_{0}}{d}}\right)$.

where the trend of $q^{(t)}$ is to decrease over time and

Denote $\sigma^{2}$ as the variance of $X_{i}$ for $i\in[d]$. By concentration of chi-squared distribution, we know that with probability at least $1-\delta$ for $\delta>0$,

By Lemma 35 in Appendix G, we know that with constant probability $\frac{1}{d}\sum_{i=1}^{d}\frac{1}{|X_{i}|}=\mathcal{O}\left(\frac{1}{\sigma}\log d\right)$. Then with constant probability, $\frac{1}{d}\sum_{i=1}^{d}\frac{1}{X_{i}^{2}}\leq\frac{1}{d}\left(\sum_{i=1}^{d}\frac{1}{|X_{i}|}\right)^{2}=\mathcal{O}\left(\frac{d}{\sigma^{2}}\log^{2}d\right)$. Hence

E.3.3 Proof of eq. (43)

By the analyses in Section D.1, we know that for $t\in[T_{\text{Adam},1},T_{\text{Adam},2}]$, the weights obtained by Adam satisfy

where $\forall i\in[d]:u_{i}=\text{sign}(w_{2i}^{(0)})\in\{\pm 1\}$ and

Hence by Lemma 28, when $t\in[T_{\text{Adam},1},T_{\text{Adam},2}]$, we have for $1\leq i\leq d$,

Combining (A) and (B), we get that the trend of $\frac{\left\|E^{(t)}\right\|_{2}}{\left\|\boldsymbol{v}^{(t)}\right\|_{2}^{2}}$ and $\frac{\left|E^{(t)}_{k}\right|}{\left(c^{(t)}\right)^{2}}$ is to decrease over time, and when $t=T_{\text{Adam},2}$, we have w.h.p.

Substituting (A) and eq. (53) into eq. (51) and (52) gives us w.h.p.,

E.4 Proof of Lemma 28

By the assumed weight structure, we get that

For the $i$-th row where $1\leq i\leq d$, i.e. the $i$-th row of the submatrix $[H_{22}\quad H_{21}^{T}]$, by triangle inequality, we have

For the $(id+k)$-th row where $1\leq i\leq d,1\leq k\leq d$, i.e. the $((i-1)d+k)$-th row of the submatrix $[H_{21}\quad H_{11}]$, by triangle inequality again, we have

Then we have that for $1\leq i\leq d,1\leq k\leq d$,

Appendix F Connection between diagonal of loss Hessian and weights

The partial derivative at $W_{i}$ of the cost function for each $i$ is given by:

In our experiments, we were interested in the diagonal elements of the Hessian. These are given by:

for each possible $i,a,b$. For ease of notation, define for each $i$ the quantities $M_{i}:=W_{i+1}^{T}\dots W_{H+1}^{T}$ and $N_{i}:=W_{1}^{T}\dots W_{i-1}^{T}$. Then we have the following lemma.

The diagonal elements of the Hessian of the cost function are given by:

where the last step follows since $M_{i}$ and $N_{i}$ are not functions of $W_{i}$.

Note that $M_{i}:=W_{i+1}^{T}\dots W_{H+1}^{T}$ and $N_{i}:=W_{1}^{T}\dots W_{i-1}^{T}$. Now define $C_{i}:=M_{i}W_{H+1}\dots W_{i+1}=M_{i}M_{i}^{T}$ and $D_{i}:=W_{i-1}\dots W_{2}W_{1}N_{i}=N_{i}^{T}N_{i}$ so that:

where $C_{i}$ and $D_{i}$ are not functions of $W_{i}$. Now, Equation 74 in the Matrix Cookbook (https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf) shows us that for any matrices $A$ and $X$ we have:

Similarly, consider the Hessian w.r.t. $W_{2}$, we have that $M_{1}M_{1}^{T}$ is an identity matrix and $N_{1}^{T}N_{1}=W_{1}W_{1}^{T}$. Therefore,

Hence we have related the uniformity of diagonal Hessian to that of weight matrices. In the detailed analysis, for both GD and Adam, we can prove that $W_{1}$ converges to an approximately rank 1 matrix. The following lemma allows us to use this rank 1 structure to compute $R_{\text{med},1}$ and $R_{\text{med},2}$.

Let’s first consider $R_{\text{med},1}$. We have

Similarly, for $R_{\text{med},2}$, we have that

Appendix G Auxiliary lemmas

By Assumption 3 and Chebyshev’s inequality, we have for fixed $i,j\in[d]$ and $t\leq T$,

which gives us with probability at least $1-\frac{1}{d}$, for $\forall t\leq T,\forall i,j\in[d]$,

Note that for all $t\leq T$ and $\forall i\in[d]$,

Then we have with probability at least $1-\frac{1}{d}$, for all $t\leq T$ and $\forall i\in[d]$,

Consider two sequences $\{a^{(t)}\}_{t\geq 0},\{b^{(t)}\}_{t\geq 0}$, which satisfy

Suppose $\forall\tau\leq t:\left|b^{(\tau)}\right|\leq B$, then for any $\epsilon>0$, the following truncated version

To make it less than $\epsilon$, it suffices to choose $H\geq\log(\frac{\epsilon}{B})/\log\beta$.
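The defining recursion of the two sequences is not reproduced in this excerpt. As a sketch under the assumption that $a^{(t)}$ is the exponential average $(1-\beta)\sum_{\tau=0}^{t}\beta^{\tau}b^{(t-\tau)}$ and that the truncated version keeps only the terms $\tau\leq H$, the discarded tail is at most $B\beta^{H+1}$, which the choice of $H$ above pushes below $\epsilon$. The helper names and numeric values are illustrative.

```python
import math


def truncation_horizon(beta: float, B: float, eps: float) -> int:
    """Smallest integer H with H >= log(eps / B) / log(beta)."""
    return math.ceil(math.log(eps / B) / math.log(beta))


def sufficient_horizon(beta: float, B: float, eps: float) -> int:
    """Simpler sufficient choice H >= log(B / eps) / (1 - beta), via log(beta) <= beta - 1."""
    return math.ceil(math.log(B / eps) / (1.0 - beta))


def tail_mass(beta: float, B: float, H: int, t: int) -> float:
    """Worst-case size of the discarded terms: (1 - beta) * sum_{tau=H+1}^{t} beta^tau * B."""
    return (1.0 - beta) * sum(beta ** tau * B for tau in range(H + 1, t + 1))


beta, B, eps = 0.9, 5.0, 1e-3
H = truncation_horizon(beta, B, eps)
H_simple = sufficient_horizon(beta, B, eps)
```

The second helper reflects the step that replaces $\log\beta$ by $\beta-1$: it is never smaller than the exact threshold, so it is a valid, slightly more conservative horizon.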

Since $\beta\in(0,1)$, we know that $\log\beta\leq\beta-1<0$. We also have $\log\frac{\epsilon}{B}<0$. Then it suffices to choose

Define $R:=q_{1}+q_{2}$. The term $|q_{1}|$ can be bounded by

Here the denominator of $(i)$ uses $b+e_{b}\geq b(1-\delta)>0$ and $\sqrt{x+y}\geq\sqrt{x}-\sqrt{|y|}$ when $x\geq 0,x+y\geq 0$.
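For completeness, the second inequality can be verified by a short case split. The derivation below is our own sketch and is not taken from the paper.

```latex
% Claim: for x >= 0 and x + y >= 0, \sqrt{x+y} >= \sqrt{x} - \sqrt{|y|}.
\begin{align*}
y\geq 0:&\quad \sqrt{x+y}\geq\sqrt{x}\geq\sqrt{x}-\sqrt{|y|},\\
y<0:&\quad x=(x+y)+|y|
  \;\Rightarrow\;
  \sqrt{x}\leq\sqrt{x+y}+\sqrt{|y|}
  \;\Rightarrow\;
  \sqrt{x+y}\geq\sqrt{x}-\sqrt{|y|},
\end{align*}
% The middle implication uses \sqrt{a+b} <= \sqrt{a} + \sqrt{b} for a, b >= 0.
```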

Now let’s bound $|q_{4}|$. If $e_{c}>0$, we have $e_{c}=|e_{c}|$ and $|q_{4}|\leq\frac{\sqrt{e_{c}}}{\sqrt{e_{c}}}=1$ since $b+e_{b}\geq b(1-\delta)>0$.

If $e_{c}\leq 0$, note that $b+e_{b}+e_{c}>0$, so we have $|e_{c}|<b+e_{b}\leq b(1+\delta)$, which yields $|q_{4}|\leq\frac{\sqrt{|e_{c}|}}{\sqrt{b}}=\mathcal{O}(1)$. Combining the above bounds gives us $|q_{1}|\leq|q_{3}|+\delta|q_{4}|=\mathcal{O}(\delta)$.

On the other hand, $|q_{2}|$ can be bounded by

Then $|R|\leq|q_{1}|+|q_{2}|=\mathcal{O}(\delta)$.

Suppose $X_{1},X_{2},\dots,X_{d}$ are i.i.d. Gaussian with mean 0 and variance $\sigma^{2}$, then for $0<\delta<\frac{1}{e}$, we have with probability at least $1-\delta$,

It suffices to assume that $\sigma^{2}=1$ and prove that w.p. at least $1-\delta$, $\max_{1\leq i\leq d}X_{i}^{2}\geq C_{1}\log d-C_{2}\log\log\frac{1}{\delta}$.

where the last inequality uses $1-x\leq e^{-x}$, which holds for every real $x$. Let $\exp(-d\alpha e^{-\beta x^{2}})=\delta$, we get that w.p. at least $1-\delta$,

Suppose $X_{1},X_{2},\dots,X_{d}$ are i.i.d. Gaussian with mean 0 and variance $\sigma^{2}$, then we have with constant probability,

It suffices to assume that $\sigma^{2}=1$ and prove that with constant probability, $\frac{1}{d}\sum_{i=1}^{d}\frac{1}{|X_{i}|}\leq\mathcal{O}\left(\log d\right)$.

which means with constant probability, $\frac{1}{d}\sum_{i=1}^{d}\frac{1}{|X_{i}|}=\mathcal{O}\left(\log d\right)$. ∎