Implicit Bias of AdamW: $\ell_\infty$ Norm Constrained Optimization

Shuo Xie, Zhiyuan Li

Introduction

Which solution does AdamW converge to, if it converges?

Our main result, Theorem 1.1, characterizes the implicit bias of AdamW in the deterministic case, where a full-batch loss is used:

If $L$ is additionally convex, then AdamW converges to the constrained minimizer, i.e., $\bm{x}_\infty\in\operatorname*{arg\,min}_{\|\bm{x}\|_\infty\le\frac{1}{\lambda}}L(\bm{x})$.

Despite its simplicity, the full-batch setting is still an interesting and highly non-trivial regime, because the two main hypotheses for why Adam outperforms SGD have recently been challenged in the deterministic regime (Kunstner et al., 2022). The first hypothesis is that Adam outperforms SGD by better handling heavy-tailed noise (Zhang et al., 2020). However, Kunstner et al. (2022) find that Adam still outperforms GD on language tasks even in the full-batch setting. The second hypothesis is that the smoothness of the training loss landscape can increase linearly with the gradient norm, so that clipping or normalization is necessary for gradient descent. Intriguingly, Kunstner et al. (2022) find that normalizing each update of GD does not close the gap to Adam in the full-batch setting, whereas normalizing each coordinate to its sign (i.e., SignGD) does.

In Section 3.1, we prove that normalized steepest descent with weight decay optimizes convex functions under norm constraints (Theorem 3.5). In Section 3.2, we prove that for general loss functions, if it converges under a learning rate schedule whose partial sums diverge, then its limit must be a KKT point of the norm-constrained optimization problem (Theorem 3.7).

In Section 4, we prove that for general loss functions, if AdamW converges under a non-increasing learning rate schedule whose partial sums diverge, then its limit must be a KKT point of the norm-constrained optimization problem (Theorem 1.1).

Towards generalizing the proof of Theorem 3.7 to Theorem 1.1, we prove a novel and tight upper bound on the average update size of Adam (Lemma 4.2), which holds even in non-deterministic settings and might be of independent interest to the community. We test various predictions made by our bound in experiments.

Preliminaries and Notations

Steepest Descent:

We say $\bm{v}$ is a steepest descent direction for objective function $L$ at current iterate $\bm{x}$ w.r.t. norm $\|\cdot\|$ iff $\|\bm{v}\|=1$ and $\langle\bm{v},\nabla L(\bm{x})\rangle=\min_{\|\bm{v}'\|\le 1}\langle\bm{v}',\nabla L(\bm{x})\rangle$. Thus for every steepest descent direction $\bm{v}$, we have $\langle\bm{v},\nabla L(\bm{x})\rangle=-\|\nabla L(\bm{x})\|_*$.
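For the $\ell_\infty$ norm (the case relevant to AdamW), the dual norm is $\ell_1$, and a steepest descent direction is $-\mathrm{sign}(\nabla L(\bm{x}))$. The following minimal Python/NumPy sketch checks this on one made-up gradient by brute force over the vertices of the unit hypercube; the gradient values and the helper name `linf_steepest_descent_direction` are illustrative choices, not from the paper.

```python
import itertools
import numpy as np

def linf_steepest_descent_direction(grad):
    # v = -sign(grad); coordinates with zero gradient get 0, which is still optimal
    return -np.sign(grad)

grad = np.array([0.5, -2.0, 1.5])
v = linf_steepest_descent_direction(grad)

# A linear function is minimized over the l_inf unit ball at a vertex of the
# hypercube, so brute-force all 2^d sign patterns.
best = min(float(np.dot(s, grad)) for s in itertools.product([-1.0, 1.0], repeat=grad.size))

print(float(np.dot(v, grad)), best, -float(np.abs(grad).sum()))  # -4.0 -4.0 -4.0
```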

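Given initialization $\bm{x}_0$, learning rate schedule $\{\eta_t\}_{t=0}^\infty$ and weight decay factor $\lambda$, the $t$th iterate of normalized steepest descent w.r.t. $\|\cdot\|$ with decoupled weight decay is defined as

$\bm{x}_t=\bm{x}_{t-1}-\eta_t\bm{\Delta}_t-\lambda\eta_t\bm{x}_{t-1}$, where $\bm{\Delta}_t\in\operatorname*{arg\,max}_{\|\bm{\Delta}\|\le 1}\nabla L(\bm{x}_{t-1})^\top\bm{\Delta}$. (Equation 1)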

Because the dual of the dual norm is always the original norm, by Lemma 2.1 we can also characterize the (negated) steepest descent directions as the subgradients of the dual norm evaluated at the gradient:

$\operatorname*{arg\,max}_{\|\bm{\Delta}\|\le 1}\nabla L(\bm{x})^\top\bm{\Delta}=\partial\|\bm{y}\|_*\big|_{\bm{y}=\nabla L(\bm{x})}$.
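As a sketch of this characterization for three common norms, each branch below returns a subgradient of the dual norm at a nonzero gradient, so that $\|\bm{\Delta}\|\le 1$ and $\langle\nabla L(\bm{x}),\bm{\Delta}\rangle=\|\nabla L(\bm{x})\|_*$. The helper names `normalized_update` and `nsd_wd_step` are hypothetical, and the step is read off the identity $\eta_t\bm{\Delta}_t=\bm{x}_{t-1}-\bm{x}_t-\lambda\eta_t\bm{x}_{t-1}$ used in Section A.3. For $\ell_\infty$ the step reduces to SignGD with decoupled weight decay.

```python
import numpy as np

def normalized_update(grad, norm):
    """Return Delta in argmax_{||Delta|| <= 1} <grad, Delta>, i.e. a subgradient of the
    dual norm at grad (assumes grad != 0)."""
    if norm == "l2":      # dual norm l2: gradient g / ||g||_2
        return grad / np.linalg.norm(grad)
    if norm == "linf":    # dual norm l1: subgradient sign(g)
        return np.sign(grad)
    if norm == "l1":      # dual norm linf: all mass on a largest |g_i|
        i = int(np.argmax(np.abs(grad)))
        delta = np.zeros_like(grad)
        delta[i] = np.sign(grad[i])
        return delta
    raise ValueError(f"unsupported norm: {norm}")

def nsd_wd_step(x, grad, lr, lam, norm):
    """One NSD-WD step: x_t = x_{t-1} - lr * Delta_t - lam * lr * x_{t-1}."""
    return x - lr * normalized_update(grad, norm) - lam * lr * x

g = np.array([3.0, -4.0])
for norm in ["l2", "linf", "l1"]:
    d = normalized_update(g, norm)
    print(norm, d, float(g @ d))  # <g, Delta> equals ||g||_2 = 5, ||g||_1 = 7, ||g||_inf = 4
```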

Warm Up: Implicit Bias of Normalized Steepest Descent w. Weight Decay

Our analysis in this section holds for all norms, including non-differentiable ones such as $\|\cdot\|_\infty$.

In this subsection, we give a simple non-asymptotic convergence analysis for normalized steepest descent with weight decay (NSD-WD) w.r.t. general norms over smooth convex loss functions. If the norm of the initialization is no larger than $\frac{1}{\lambda}$, where $\lambda$ is the weight decay factor, then, surprisingly, NSD-WD is exactly equivalent to a well-known optimization algorithm, Frank-Wolfe (Frank et al., 1956), where the constraint set is the norm ball with radius $\frac{1}{\lambda}$. If the norm of the initialization is larger than $\frac{1}{\lambda}$, then the analysis contains an additional phase in which the norm of the iterates converges linearly to $\frac{1}{\lambda}$. In this case, the iterates of NSD-WD may always stay outside the $\frac{1}{\lambda}$ norm ball, but the convergence analysis of Frank-Wolfe can still be adapted (e.g., Jaggi (2013)). First, we show that the norm of the iterates shrinks to $\frac{1}{\lambda}$ as long as the norm of each update is bounded by $1$, i.e., $\|\bm{\Delta}_t\|\le 1$. Note that this conclusion uses neither the convexity of $L(\bm{x})$ nor the fact that $\bm{\Delta}_t$ is a steepest descent direction, so it also holds in non-deterministic settings.

For any learning rate schedule $\{\eta_t\}_{t=1}^\infty$ and updates $\{\bm{\Delta}_t\}_{t=1}^\infty$ such that $\lambda\eta_t<1$ and $\|\bm{\Delta}_t\|\le 1$, we have $\|\bm{x}_t\|-\frac{1}{\lambda}\le\max\left(e^{-\lambda\sum_{i=1}^t\eta_i}\left(\|\bm{x}_0\|-\frac{1}{\lambda}\right),0\right)$.
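The bound of Lemma 3.1 can be checked numerically with a minimal sketch. The updates below are random vectors with $\|\bm{\Delta}_t\|_\infty\le 1$ (not steepest descent directions), matching the remark that the lemma uses neither convexity nor the specific form of $\bm{\Delta}_t$. The schedule $\eta_t=0.5/\sqrt{t}$, $\lambda=0.5$, and the helper `check_lemma_3_1` are illustrative assumptions.

```python
import numpy as np

def check_lemma_3_1(x0, lam, lrs, rng):
    """Return True if ||x_t||_inf - 1/lam <= max(exp(-lam * sum eta) (||x_0||_inf - 1/lam), 0)
    for every t along x_t = (1 - lam * eta_t) x_{t-1} - eta_t * Delta_t."""
    x = np.array(x0, dtype=float)
    gap0 = np.abs(x).max() - 1 / lam
    lr_sum = 0.0
    for lr in lrs:
        assert lam * lr < 1
        delta = rng.uniform(-1.0, 1.0, size=x.shape)  # arbitrary update with ||Delta||_inf <= 1
        x = (1 - lam * lr) * x - lr * delta
        lr_sum += lr
        bound = max(np.exp(-lam * lr_sum) * gap0, 0.0)
        if np.abs(x).max() - 1 / lam > bound + 1e-12:
            return False
    return True

lrs = [0.5 / np.sqrt(t) for t in range(1, 201)]
print(check_lemma_3_1([10.0, -6.0, 3.0], lam=0.5, lrs=lrs, rng=np.random.default_rng(0)))  # starts outside
print(check_lemma_3_1([0.5, -1.0], lam=0.5, lrs=lrs, rng=np.random.default_rng(1)))        # starts inside
```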

The proof is deferred to Section A.1. Lemma 3.1 shows that $\bm{x}_t$ either always stays inside the norm ball with radius $\frac{1}{\lambda}$, or its excess norm $\|\bm{x}_t\|-\frac{1}{\lambda}$ shrinks exponentially as the sum of learning rates increases. Once $\bm{x}_t$ enters the norm ball with radius $\frac{1}{\lambda}$, it never leaves, and the remaining trajectory of NSD-WD is exactly the same as that of Frank-Wolfe, as shown in the following theorem. We note that the relationship between Frank-Wolfe and steepest descent algorithms has also been observed recently in the continuous-time case (Chen et al., 2023).

For any norm $\|\cdot\|$, weight decay $\lambda$, and $\|\bm{x}_{t-1}\|\le\frac{1}{\lambda}$, NSD-WD with learning rate $\eta_t<\frac{1}{\lambda}$ and Frank-Wolfe (Algorithm 2) with step size $\gamma_t=\eta_t\lambda$ and convex set $\mathcal{X}\triangleq\{\bm{y}\mid\|\bm{y}\|\le\frac{1}{\lambda}\}$ generate the same next iterate $\bm{x}_t$.
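For the $\ell_\infty$ ball, this equivalence can be verified directly. Assuming Algorithm 2 is the standard Frank-Wolfe step $\bm{x}_t=(1-\gamma_t)\bm{x}_{t-1}+\gamma_t\bm{s}_t$ with $\bm{s}_t\in\operatorname*{arg\,min}_{\bm{s}\in\mathcal{X}}\langle\nabla L(\bm{x}_{t-1}),\bm{s}\rangle$ (our assumption, since the algorithm box is not reproduced here), the linear minimization oracle returns $\bm{s}_t=-\frac{1}{\lambda}\mathrm{sign}(\nabla L(\bm{x}_{t-1}))$, so $\gamma_t\bm{s}_t=-\eta_t\,\mathrm{sign}(\nabla L(\bm{x}_{t-1}))$ when $\gamma_t=\eta_t\lambda$. The sketch compares one step of each on random data; the function names are hypothetical.

```python
import numpy as np

def nsd_wd_linf(x, grad, lr, lam):
    # NSD-WD for the l_inf norm: Delta = sign(grad)
    return x - lr * np.sign(grad) - lam * lr * x

def frank_wolfe_linf_ball(x, grad, gamma, radius):
    # Linear minimization oracle over {y : ||y||_inf <= radius}, then a convex combination
    s = -radius * np.sign(grad)
    return (1 - gamma) * x + gamma * s

rng = np.random.default_rng(3)
lam, lr = 2.0, 0.3                      # lam * lr = 0.6 < 1
x = rng.uniform(-0.5, 0.5, 5)           # inside the ball of radius 1/lam = 0.5
g = rng.standard_normal(5)
print(np.allclose(nsd_wd_linf(x, g, lr, lam), frank_wolfe_linf_ball(x, g, lr * lam, 1 / lam)))
```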

Define $\bm{x}^*=\operatorname*{arg\,min}_{\|\bm{x}\|\le\frac{1}{\lambda}}L(\bm{x})$ to be the constrained minimizer of the convex function $L(\bm{x})$. We first compute how much the gap between $L(\bm{x}_t)$ and $L(\bm{x}^*)$ can decrease in one normalized steepest descent step when the iterate $\bm{x}_t$ is bounded.

Suppose the loss function $L$ is convex and has $H$-Lipschitz gradient w.r.t. norm $\|\cdot\|$. For iterates $\{\bm{x}_t\}$ in NSD-WD (Equation 1), we have that

The proof of Lemma 3.3 is deferred to Section A.2. With Lemma 3.3, we can prove the convergence of $L(\bm{x}_t)$ for learning rate schedules satisfying the following conditions; the proof is also deferred to Section A.2.

Assume that $\eta_t\ge 0$, $\lim_{t\to\infty}\eta_t=0$ and $\sum_{t=1}^\infty\eta_t=\infty$. For any convex loss $L$ with $H$-Lipschitz gradient, $\lim_{t\to\infty}L(\bm{x}_t)=L(\bm{x}^*)$.

We also provide a specific learning rate schedule $\{\eta_t\}_{t=1}^\infty$ that achieves an $O(\frac{1}{t})$ convergence rate of $L(\bm{x}_t)$, the same rate as Frank-Wolfe over convex objectives (Jaggi, 2013); the proof is standard. For completeness, we provide a proof of Theorem 3.5 in Section A.2.

Define $B=\max\{\|\bm{x}_0\|,\frac{1}{\lambda}\}$. For NSD-WD with learning rate schedule $\eta_t=\frac{2}{\lambda(t+1)}$, we have $L(\bm{x}_t)-L(\bm{x}^*)\le\frac{2H(1+\lambda B)^2}{(t+2)\lambda^2}$ for $t\ge 1$.
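A small sanity check of Theorem 3.5, as a sketch: we run SignGD with decoupled weight decay (NSD-WD for $\|\cdot\|_\infty$) on the made-up quadratic $L(\bm{x})=\frac12\|\bm{x}-\bm{c}\|_2^2$ in two dimensions, starting outside the ball. Reading "$H$-Lipschitz gradient w.r.t. $\|\cdot\|$" as $\|\nabla L(\bm{x})-\nabla L(\bm{y})\|_*\le H\|\bm{x}-\bm{y}\|$ (our assumption), we have $\|\bm{x}-\bm{y}\|_1\le d\|\bm{x}-\bm{y}\|_\infty$, so we take $H=d$. Because $L$ is separable, its minimizer over the box is a coordinate-wise clip.

```python
import numpy as np

lam = 1.0
c = np.array([0.5, 2.0])            # L(x) = 0.5 * ||x - c||_2^2, an illustrative convex loss
x0 = np.array([3.0, -4.0])          # ||x0||_inf = 4 > 1/lam, so the first phase is exercised

def loss(x):
    return 0.5 * float(np.sum((x - c) ** 2))

def grad(x):
    return x - c

x_star = np.clip(c, -1 / lam, 1 / lam)   # separable L: box minimizer is a coordinate-wise clip
H = c.size                               # ||grad(x) - grad(y)||_1 <= d * ||x - y||_inf
B = max(float(np.abs(x0).max()), 1 / lam)

x = x0.copy()
gaps, bounds = [], []
for t in range(1, 501):
    lr = 2 / (lam * (t + 1))
    x = x - lr * np.sign(grad(x)) - lam * lr * x      # NSD-WD step for ||.||_inf
    gaps.append(loss(x) - loss(x_star))
    bounds.append(2 * H * (1 + lam * B) ** 2 / ((t + 2) * lam ** 2))

print(all(gap <= bound for gap, bound in zip(gaps, bounds)), gaps[-1])
```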

Non-convex setting: convergence to KKT points

In this subsection, we study the implicit bias of SignGD (or more generally, NSD-WD) when the loss is non-convex. In this case, last-iterate parameter convergence is, in general, difficult to show; indeed, even in the convex case, Frank-Wolfe may not converge in parameter space (Bolte et al., 2023). We therefore study which parameters SignGD and NSD-WD can converge to. Our main result, Theorem 3.7, shows that such parameters must be KKT points (see Definition 3.6) of the constrained optimization problem. In particular, if the objective is convex, then, since every norm ball is convex, all KKT points are constrained minimizers.

For convex $L$, all KKT points $\bm{x}^*$ are optimal, and the dual variable $s^*\ge 0$ certifies optimality. To see this, for any other $\bm{y}$ with $\|\bm{y}\|\le\frac{1}{\lambda}$, it holds that $L(\bm{y})\ge L(\bm{y})+s^*(\|\bm{y}\|-\frac{1}{\lambda})\ge L(\bm{x}^*)+s^*(\|\bm{x}^*\|-\frac{1}{\lambda})$, where the first inequality uses $s^*\ge 0$ and $\|\bm{y}\|\le\frac{1}{\lambda}$, and the second holds because $L(\bm{x})+s^*\|\bm{x}\|$ is convex and, by KKT stationarity, $\bm{0}$ is a subgradient of it at $\bm{x}^*$. By complementary slackness, we conclude $L(\bm{y})\ge L(\bm{x}^*)+s^*(\|\bm{x}^*\|-\frac{1}{\lambda})=L(\bm{x}^*)$.

Now we state the main result for this subsection.

To prove Theorem 3.7, we use the following alternative characterization of the KKT points of $\min_{\|\bm{x}\|\le\frac{1}{\lambda}}L(\bm{x})$, based on Lemma 2.1.

$\bm{x}$ is a KKT point of $\min_{\|\bm{x}\|\le\frac{1}{\lambda}}L(\bm{x})$ iff $\|\bm{x}\|\le\frac{1}{\lambda}$ and $\langle-\lambda\bm{x},\nabla L(\bm{x})\rangle=\|\nabla L(\bm{x})\|_*$.
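Specialized to $\|\cdot\|_\infty$ (dual norm $\|\cdot\|_1$), Lemma 3.8 gives a cheap test for KKT points. The sketch below (hypothetical helper `is_kkt_linf`, with a floating-point tolerance) applies it to a made-up separable quadratic, whose constrained minimizer is a coordinate-wise clip; the values of `c` and `lam` are illustrative.

```python
import numpy as np

def is_kkt_linf(x, g, lam, tol=1e-9):
    """Lemma 3.8 for ||.||_inf: x is KKT iff ||x||_inf <= 1/lam and <-lam x, g> = ||g||_1."""
    feasible = np.abs(x).max() <= 1 / lam + tol
    stationary = abs(float(np.dot(-lam * x, g)) - float(np.abs(g).sum())) <= tol
    return bool(feasible and stationary)

lam = 1.0
c = np.array([0.5, 2.0])                     # L(x) = 0.5 * ||x - c||^2 (illustrative)

def grad(x):
    return x - c

x_star = np.clip(c, -1 / lam, 1 / lam)       # constrained minimizer of this separable L
for x in [x_star, np.array([0.0, 0.0]), np.array([1.0, 1.0]), np.array([2.0, 2.0])]:
    print(x, is_kkt_linf(x, grad(x), lam))   # only x_star passes
```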

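Even when the iterates converge, the steepest descent directions $\bm{\Delta}_t$ themselves need not converge (for SignGD, the sign of a coordinate whose gradient tends to zero can keep flipping). The following lemma (Lemma 3.9) circumvents this issue by considering the weighted average of past steepest descent directions, which provably converges given that the iterates $\{\bm{x}_t\}_{t=1}^\infty$ converge. Theorem 3.7 is a direct combination of Lemma 3.9 and Lemma 3.8, and we omit its proof. The proof of Lemma 3.9 is deferred to Section A.3.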

For any learning rate schedule $\{\eta_t\}_{t=1}^\infty$ satisfying $\sum_{t=1}^\infty\eta_t=\infty$, if the iterates of NSD-WD $\{\bm{x}_t\}_{t=0}^\infty$ converge to some $\bm{x}_\infty$, we have that

$\bm{\Delta}_\infty:=\lim_{T\to\infty}\frac{\sum_{t=1}^T\eta_t\bm{\Delta}_t}{\sum_{t=1}^T\eta_t}$ exists and $\bm{\Delta}_\infty=-\lambda\bm{x}_\infty$.

$\langle\nabla L(\bm{x}_\infty),\bm{\Delta}_\infty\rangle=\|\nabla L(\bm{x}_\infty)\|_*$.

$\|\bm{\Delta}_\infty\|\le 1$.
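The first property of Lemma 3.9 rests on the telescoping identity $\sum_{t=1}^T\eta_t\bm{\Delta}_t=\bm{x}_0-\bm{x}_T-\lambda\sum_{t=1}^T\eta_t\bm{x}_{t-1}$, which follows from the update identity in Section A.3. The sketch below checks that identity for SignGD with weight decay on a made-up quadratic and reports how far the weighted average of $\bm{\Delta}_t$ is from $-\lambda\bm{x}_T$ after finitely many steps; the problem and the schedule $\eta_t=0.5/\sqrt{t}$ are illustrative assumptions. The individual $\bm{\Delta}_t$ keep oscillating in the first coordinate while their weighted average settles.

```python
import numpy as np

lam, T = 1.0, 20_000
c = np.array([0.5, 2.0])               # L(x) = 0.5 * ||x - c||^2, illustrative
x = np.zeros(2)
x0 = x.copy()
sum_eta = 0.0
sum_eta_delta = np.zeros(2)
sum_eta_xprev = np.zeros(2)
for t in range(1, T + 1):
    eta = 0.5 / np.sqrt(t)             # partial sums diverge, lam * eta < 1
    delta = np.sign(x - c)             # Delta_t = sign(grad L(x_{t-1}))
    sum_eta += eta
    sum_eta_delta += eta * delta
    sum_eta_xprev += eta * x           # x is still x_{t-1} here
    x = x - eta * delta - lam * eta * x

avg_delta = sum_eta_delta / sum_eta
identity_gap = np.abs(sum_eta_delta - (x0 - x - lam * sum_eta_xprev)).max()
print(identity_gap)                    # zero up to floating-point error
print(avg_delta, -lam * x)             # close to each other
```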

Implicit Bias of AdamW

In this section, we extend the analysis of NSD-WD in Section 3 to AdamW and prove that the parameters AdamW converges to are KKT points of the constrained optimization problem. The proof relies on an upper bound on the average update size of AdamW, and we find that this bound can also guide hyperparameter tuning in practice.
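Since Algorithm 1 is not reproduced in this section, the following minimal sketch records the AdamW update assumed when reading it: no bias correction, $\bm{m}_0=\bm{v}_0=\bm{0}$, and an optional $\epsilon$ in the denominator, consistent with $\bm{\Delta}_t=\bm{m}_t/\sqrt{\bm{v}_t}$ and $\bm{m}_{0,j}=0\le\sqrt{\bm{v}_{0,j}}$ as used in Appendix B. The function name `adamw_step` is hypothetical. With $\beta_1=\beta_2=0$, the direction is exactly $\mathrm{sign}(\bm{g}_t)$, recovering SignGD with decoupled weight decay.

```python
import numpy as np

def adamw_step(x, m, v, grad, lr, lam, beta1=0.9, beta2=0.999, eps=0.0):
    """One AdamW step (sketch): EMA moments without bias correction, then a
    normalized step plus decoupled weight decay:
        x_t = x_{t-1} - lr * m_t / (sqrt(v_t) + eps) - lam * lr * x_{t-1}
    Assumes v_t > 0 coordinate-wise when eps == 0."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    delta = m / (np.sqrt(v) + eps)
    x = x - lr * delta - lam * lr * x
    return x, m, v, delta

# With beta1 = beta2 = 0 the direction is sign(grad): SignGD with decoupled weight decay.
x, m, v, delta = adamw_step(np.array([1.0, -2.0]), np.zeros(2), np.zeros(2),
                            np.array([0.3, -5.0]), lr=0.1, lam=0.5, beta1=0.0, beta2=0.0)
print(delta, x)
```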

For a non-increasing learning rate schedule $\{\eta_t\}_{t=0}^\infty$ satisfying $\sum_{t=1}^\infty\eta_t=\infty$ and $\beta_2\ge\beta_1$, let $\{\bm{x}_t\}_{t=1}^\infty$ be the iterates of AdamW with weight decay factor $\lambda$. If $\{\bm{x}_t\}_{t=0}^\infty$ converges to some $\bm{x}_\infty$, then it holds that

$\bm{\Delta}_\infty:=\lim_{T\to\infty}\frac{\sum_{t=1}^T\eta_t\bm{\Delta}_t}{\sum_{t=1}^T\eta_t}$ exists and $\bm{\Delta}_\infty=-\lambda\bm{x}_\infty$.

$\langle\nabla L(\bm{x}_\infty),\bm{\Delta}_\infty\rangle=\|\nabla L(\bm{x}_\infty)\|_1$.

$\|\bm{\Delta}_\infty\|_\infty\le 1$.

The first two properties in Lemma 4.1 follow from an argument similar to that for Lemma 3.9, and the main technical difficulty lies in the proof of the third property. This is because, for a single $t$, $\|\bm{\Delta}_t\|_\infty$ can be larger than $1$, unlike in the case of NSD-WD. To prove the third property, we need a tight upper bound on the average update size of Adam-like update rules, which is Lemma 4.2. The proof of Lemma 4.1 is deferred to Appendix B.

As mentioned earlier, the Adam update $\left\|\frac{\bm{m}_t}{\sqrt{\bm{v}_t}}\right\|$ can easily exceed $1$, and thus we prove the following upper bound on the average update size of Adam (Lemma 4.2). The proof of Lemma 4.2 is deferred to Section B.1.

In particular, when $\beta_1=\beta_2$, it even holds that $|\Delta_t|\le 1$.
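The special case $\beta_1=\beta_2=\beta$ follows from the Cauchy–Schwarz inequality: with $m_0=v_0=0$, $m_t^2\le(1-\beta^t)v_t\le v_t$. The sketch below (hypothetical helper `adam_ratios`, scalar gradients, no $\epsilon$, no bias correction) checks this on random gradients of widely varying scale, and shows that with $\beta_1=0.9<\beta_2=0.999$ a constant gradient already gives a ratio well above $1$ after ten steps.

```python
import numpy as np

def adam_ratios(grads, beta1, beta2):
    """Return m_t / sqrt(v_t) for a scalar gradient sequence, with m_0 = v_0 = 0."""
    m = v = 0.0
    ratios = []
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        ratios.append(m / np.sqrt(v))
    return np.array(ratios)

rng = np.random.default_rng(0)
noisy = rng.standard_normal(1000) * np.exp(rng.uniform(-3, 3, 1000))

equal = adam_ratios(noisy, beta1=0.9, beta2=0.9)
print(np.abs(equal).max() <= 1.0 + 1e-12)  # True: m_t^2 <= (1 - beta^t) v_t

constant = adam_ratios(np.ones(10), beta1=0.9, beta2=0.999)
print(constant[-1])  # about 6.5 > 1 after 10 steps of a constant gradient
```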

Note that $\{v_t\}_{t=0}^\infty$ here only needs to satisfy a more general condition, rather than being the exact moving average of $g_t^2$. This covers the practical scenario where a small positive constant $\epsilon$ is added to $\sqrt{\bm{v}_t}$ in the denominator to improve the numerical stability of Adam. It is easy to verify that for $\bm{v}_t$ in Algorithm 1, we have that

Therefore, for Adam with $\epsilon$, the $v_t$ in Lemma 4.2 is always lower bounded, and if we further have an upper bound on the gradients, then we can easily control the average update size of Adam. A nice property is that the resulting upper bound scales only logarithmically in $1/\epsilon$, rather than linearly as the naive upper bound does.
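Assuming the condition required by Lemma 4.2 is the one quoted in Appendix B.2, $v_t-\beta_2v_{t-1}\ge(1-\beta_2)g_t^2$, one natural choice for Adam with $\epsilon$ (our assumption; the paper's display is not reproduced here) is $\tilde v_t=(\sqrt{v_t}+\epsilon)^2$. It satisfies the condition because $\sqrt{v_t}\ge\beta_2\sqrt{v_{t-1}}$, and it is at least $\epsilon^2$. The sketch checks both numerically on gradients spanning many orders of magnitude; all constants are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
beta2, eps = 0.99, 1e-8
# gradients spanning many orders of magnitude, including ones far below eps
grads = rng.standard_normal(500) * np.exp(rng.uniform(-20.0, 2.0, 500))

v_prev = 0.0
tilde_prev = (np.sqrt(v_prev) + eps) ** 2    # tilde_v_0 = eps^2 >= m_0^2 = 0
rel_slack = []
for g in grads:
    v = beta2 * v_prev + (1 - beta2) * g * g
    tilde = (np.sqrt(v) + eps) ** 2          # effective second moment of Adam with eps
    # assumed Lemma 4.2 condition: tilde_t - beta2 * tilde_{t-1} >= (1 - beta2) * g_t^2
    rel_slack.append((tilde - beta2 * tilde_prev - (1 - beta2) * g * g) / tilde)
    assert tilde >= eps ** 2
    v_prev, tilde_prev = v, tilde

print(min(rel_slack) > -1e-12)  # True
```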

Another application of Lemma 4.2 is a tight upper bound on the norm of the iterates in any setting, e.g., before convergence or even when the gradient is stochastic. In particular, when the learning rate is constant, we have the following upper bound, whose proof is in Section B.4.

For any coordinate $j\in[d]$, for AdamW with constant learning rate $\eta$ and weight decay factor $\lambda$, with $C\triangleq\max_{1\le t\le T}\left|\ln\frac{\bm{v}_{t,j}}{\bm{v}_{1,j}}\right|$, it holds that

When $\beta_1=\beta_2$, we only need $T=\Omega\left(\frac{\log\|\bm{x}_0\|_\infty}{\lambda\eta}\right)$ to guarantee that $|\bm{x}_{T,j}|$ is no larger than $\frac{1}{\lambda}$ for any $\lambda\eta\le 1$. However, when $\beta_1<\beta_2$ and $\beta_1<1-\lambda\eta$, the dominating term on the right-hand side is $C\cdot\frac{\eta\lambda(\beta_2-\beta_1)}{(1-\beta_2)(1-\eta\lambda-\beta_1)}$. Assuming $C=O(1)$, keeping this term small requires $\lambda\eta\ll 1-\beta_2<1-\beta_1$ or $\lambda\eta<1-\beta_2\approx 1-\beta_1$.
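When $\beta_1=\beta_2$, each $|\Delta_{t,j}|\le 1$, so the argument of Lemma 3.1 applies coordinate-wise: $|\bm{x}_{t,j}|-\frac{1}{\lambda}\le\max\left((1-\lambda\eta)^t(|\bm{x}_{0,j}|-\frac{1}{\lambda}),0\right)$ for any gradient sequence. The sketch below checks this envelope along an AdamW trajectory driven by random, stochastic-like gradients; all numerical values, and the helper `adamw_equal_betas` (no bias correction, no $\epsilon$), are illustrative assumptions. The envelope decays like $(1-\lambda\eta)^t\approx e^{-\lambda\eta t}$, consistent with the $\Omega\left(\frac{\log\|\bm{x}_0\|_\infty}{\lambda\eta}\right)$ step count above.

```python
import numpy as np

def adamw_equal_betas(x0, grads, lr, lam, beta):
    """AdamW trajectory with beta1 = beta2 = beta, m_0 = v_0 = 0, no eps, no bias correction."""
    x = np.array(x0, dtype=float)
    m = np.zeros_like(x)
    v = np.zeros_like(x)
    traj = [x.copy()]
    for g in grads:
        m = beta * m + (1 - beta) * g
        v = beta * v + (1 - beta) * g * g
        x = x - lr * m / np.sqrt(v) - lam * lr * x   # |m / sqrt(v)| <= 1 coordinate-wise
        traj.append(x.copy())
    return np.array(traj)

rng = np.random.default_rng(2)
lr, lam, beta, T = 1e-2, 0.1, 0.95, 3000
grads = rng.standard_normal((T, 4)) * rng.uniform(0.1, 10.0, (T, 4))  # noisy, varying scale
x0 = np.array([500.0, -200.0, 3.0, 0.0])
traj = adamw_equal_betas(x0, grads, lr, lam, beta)

t = np.arange(T + 1)[:, None]
envelope = 1 / lam + np.maximum((1 - lam * lr) ** t * (np.abs(x0) - 1 / lam), 0.0)
print(np.all(np.abs(traj) <= envelope + 1e-9))
```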

Experiments

A synthetic problem

Related Work

While Stochastic Gradient Descent (Robbins & Monro, 1951) remains popular for optimizing deep learning models like ResNet (He et al., 2016), only adaptive methods can efficiently train recently emerged large language models (Zhang et al., 2020). There is a large body of research on adaptive gradient methods, including AdaGrad (Duchi et al., 2011), RMSProp (Tieleman & Hinton, 2012), AdaDelta (Zeiler, 2012), Adam (Kingma & Ba, 2014), AdaFactor (Shazeer & Stern, 2018), AMSGrad (Reddi et al., 2018), AdaBound (Luo et al., 2018), Lion (Chen et al., 2024), etc. Recently, there have also been adaptive methods that attempt to accelerate training by leveraging second-order information, e.g., AdaHessian (Yao et al., 2021) and Sophia (Liu et al., 2023). However, most algorithms that are able to train large language models adopt coordinate-wise adaptivity. In contrast, stochastic gradient descent, even equipped with global gradient norm clipping, cannot match the performance of coordinate-wise adaptive algorithms on language tasks (Li et al., 2022a). Previous work has established convergence rates for RMSProp and Adam under different assumptions (Chen et al., 2018; Zou et al., 2019; Shi & Li, 2021; Guo et al., 2021; Défossez et al., 2022; Zhang et al., 2022).

Our work shows that, assuming convergence, AdamW and SignGD with weight decay both converge to KKT points of the same $\ell_\infty$-norm-constrained problem. Balles & Hennig (2018) and Kunstner et al. (2022) point out that the similarity to SignGD largely accounts for the advantage of Adam over SGD. Moreover, when SignGD is equipped with momentum, one key component of Adam, it can achieve results comparable to Adam on various tasks (Balles & Hennig, 2018; Kunstner et al., 2022; Bernstein et al., 2018; Crawshaw et al., 2022).

Role of Weight Decay:

The use of weight decay, which refers to shrinking the parameters by a small constant fraction, dates back to the 1980s (Rumelhart et al., 1986; Hinton, 1987). It has long been recognized as a standard trick to improve the generalization performance of neural networks (Krogh & Hertz, 1991; Bos & Chug, 1996). Krizhevsky et al. (2012) first noticed that weight decay can sometimes accelerate optimization in deep learning. For modern architectures equipped with normalization layers, e.g., BatchNorm (Ioffe & Szegedy, 2015) and LayerNorm (Ba et al., 2016), only the direction of the parameters before normalization layers matters, rather than their norm. Turning on weight decay in such settings changes the effective learning rate of these parameters (Hoffer et al., 2018; Arora et al., 2018; Zhang et al., 2018; Li & Arora, 2019; Li et al., 2020).

Implicit Regularization:

The concurrent work by Chen et al. (2023) is arguably the most related to ours: Lion (Chen et al., 2024), an optimization algorithm recently discovered by automated search, is elegantly generalized to a family of algorithms, Lion-$\mathcal{K}$, where $\mathcal{K}$ is some convex function. When $\mathcal{K}$ is chosen to be the dual norm and momentum in Lion-$\mathcal{K}$ is turned off, Lion-$\mathcal{K}$ becomes normalized steepest descent. Their analysis shows that even with momentum, normalized steepest descent with weight decay can be viewed as optimization under the original norm constraint. However, in any Lion-$\mathcal{K}$ algorithm, the update at step $t$ depends on past iterates only through the first-order momentum $\bm{m}_t$. Their analysis cannot be applied to AdamW because AdamW cannot be written in the form of Lion-$\mathcal{K}$ for any convex function $\mathcal{K}$. To see this, simply note that the update of Lion-$\mathcal{K}$ for a fixed $\mathcal{K}$ is completely determined by $\bm{g}_t$, $\bm{m}_t$ and $\bm{x}_t$, while the update of AdamW can still differ if the second-order momentum $\bm{v}_t$ differs. In terms of proof technique, Chen et al. (2023) construct a Lyapunov function, while we directly characterize the KKT points and connect the limit point to a KKT point through the weighted average update.

Discussion and Future Works

This work focuses on the implicit bias of AdamW in the deterministic (or full-batch) case. Though our upper bound on the average update size of Adam holds for arbitrary input gradients, stochastic or not, the $\frac{1}{\lambda}$ upper bound is unlikely to be reached when there is large gradient noise, especially when $\beta_2$ is very close to $1$. In that case, the denominator of the AdamW update is roughly the square root of the squared expected gradient plus an additional gradient variance term, which strictly dominates the expected gradient in the numerator. Malladi et al. (2022) use a Stochastic Differential Equation (SDE) approximation to model the trajectories of Adam in this regime and empirically test an implication of the SDE approximation, namely the square root scaling rule.

Another important future direction is to provide non-asymptotic convergence rates for AdamW in both convex and non-convex settings.

Conclusions

References

Appendix A Omitted Proofs in Section 3

In this section, we provide the omitted proofs from Section 3, which characterize the iterates and the limit of normalized steepest descent with decoupled weight decay before we turn to the analysis of AdamW. In Section A.1, we prove that the iterates will enter or stay in the norm ball with radius $\frac{1}{\lambda}$ for any normalized update. In Section A.2, we prove that, with proper learning rates, the loss of the iterates of normalized steepest descent with weight decay converges to that of the constrained minimizer of $L(\bm{x})$ in the same ball.

Proof of Lemma 3.1. We prove by induction that $\|\bm{x}_t\|\le\frac{1}{\lambda}+\prod_{i=1}^t(1-\lambda\eta_i)\left(\|\bm{x}_0\|-\frac{1}{\lambda}\right)$.

When $\|\bm{x}_0\|>\frac{1}{\lambda}$, we have that

When $\|\bm{x}_0\|\le\frac{1}{\lambda}$, $\|\bm{x}_t\|-\frac{1}{\lambda}\le 0$. This completes the proof. ∎

A.2 Omitted proofs for convergence to constrained minimizer with proper learning rates

For the normalized steepest descent update $\bm{\Delta}_t$ from Equation 1,

where in the first inequality we use the convexity of $L$, and the second inequality uses $\|\bm{x}^*\|\le\frac{1}{\lambda}$.

Since the gradient of $L$ is $H$-Lipschitz, by Taylor expansion, we have that

Because the update $\bm{\Delta}_t$ is normalized and thus has unit norm by definition, it holds that

Theorem 3.4 follows by applying Lemma A.1 to the one-step descent bound in Lemma 3.3. ∎

Assume that $\eta_t\ge 0$, $\lim_{t\to\infty}\eta_t=0$ and $\sum_{t=1}^\infty\eta_t=\infty$. Let $C$ be any positive number and $a_0\ge 0$. If the sequence $\{a_t\}_{t=0}^\infty$ satisfies $a_t\le(1-\eta_t)a_{t-1}+C\eta_t^2$, then $\lim_{t\to\infty}a_t=0$.

First we show by induction that $a_t\le a_0\exp\left(-\sum_{i=1}^t\eta_i\right)+C\sum_{i=1}^t\eta_i^2\exp\left(-\sum_{j=i+1}^t\eta_j\right)$.

Because $\sum_{t=1}^\infty\eta_t=\infty$, $\lim_{t\to\infty}a_0\exp\left(-\sum_{i=1}^t\eta_i\right)=0$. To show $\lim_{t\to\infty}a_t=0$, it suffices to show $\lim_{t\to\infty}\sum_{i=1}^t\eta_i^2\exp\left(-\sum_{j=i+1}^t\eta_j\right)=0$.
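Lemma A.1 can be illustrated numerically, as a sketch. The code iterates the recursion with equality (its worst case) alongside $b_t=e^{-\eta_t}b_{t-1}+C\eta_t^2$, which unrolls exactly to the closed-form bound above; since $1-\eta\le e^{-\eta}$ and all terms stay nonnegative for $\eta_t\le 1$, $a_t\le b_t$ at every step. The schedule $\eta_t=1/\sqrt{t+1}$, the constants, and the helper `simulate` are illustrative choices.

```python
import math
import numpy as np

def simulate(a0, C, etas):
    a = b = a0
    for eta in etas:
        a = (1 - eta) * a + C * eta ** 2        # worst case of a_t <= (1 - eta_t) a_{t-1} + C eta_t^2
        b = math.exp(-eta) * b + C * eta ** 2   # unrolls to the closed-form bound
        assert a <= b + 1e-12
    return a, b

etas = 1 / np.sqrt(np.arange(2, 100_002))       # eta_t = 1/sqrt(t+1): -> 0, sum diverges, eta_t <= 1
a_T, b_T = simulate(a0=5.0, C=1.0, etas=etas)
print(a_T, b_T)  # both small, on the order of C * eta_T
```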

From Lemma 3.1, $\|\bm{x}_t\|\le\max\{\|\bm{x}_0\|,\frac{1}{\lambda}\}=B$ for $t\ge 0$. Define $C\triangleq\frac{H(1+\lambda B)^2}{2\lambda^2}$.

Suppose $L(\bm{x}_{t-1})-L(\bm{x}^*)\le\frac{4C}{t+1}$; then we have that

A.3 Omitted Proofs for Lemma 3.9

For any $\epsilon>0$, there exists $t'$ such that $\|\bm{x}_t-\bm{x}_\infty\|\le\frac{\epsilon}{2\lambda}$ for any $t>t'$. Because $\eta_t\bm{\Delta}_t=\bm{x}_{t-1}-\bm{x}_t-\lambda\eta_t\bm{x}_{t-1}$, we have that

There exists $T'\ge t'$ such that $\sum_{t=1}^T\eta_t\ge\frac{2}{\epsilon}\left(\left\|\bm{x}_0-\bm{x}_\infty-\lambda\left(\sum_{t=1}^{t'}\eta_t\bm{x}_{t-1}-\sum_{t=1}^{t'}\eta_t\bm{x}_\infty\right)\right\|+\frac{\epsilon}{2}\right)$ for $T\ge T'$. Then we have

So $\bm{\Delta}_\infty:=\lim_{T\to\infty}\frac{\sum_{t=1}^T\eta_t\bm{\Delta}_t}{\sum_{t=1}^T\eta_t}$ exists and $\bm{\Delta}_\infty=-\lambda\bm{x}_\infty$.

Because $\nabla L(\bm{x})$ is a continuous function and $\lim_{t\to\infty}\bm{x}_t=\bm{x}_\infty$, we have $\lim_{t\to\infty}\nabla L(\bm{x}_t)=\nabla L(\bm{x}_\infty)$. For any $\epsilon>0$, there exists $T_1$ such that

for any $t\ge T_1$. It also holds that

because $\|\bm{\Delta}_t\|\le 1$. Because $\sum_{t=1}^\infty\eta_t=\infty$, there exists $T_2\ge T_1$ such that

Therefore, we prove that $\|\nabla L(\bm{x}_\infty)\|_*=\lim_{T\to\infty}\frac{\sum_{t=1}^T\eta_t\langle\nabla L(\bm{x}_\infty),\bm{\Delta}_t\rangle}{\sum_{t=1}^T\eta_t}$. On the other hand, we have that

For any $T$, we know $\left\|\frac{\sum_{t=1}^T\eta_t\bm{\Delta}_t}{\sum_{t=1}^T\eta_t}\right\|\le\frac{\sum_{t=1}^T\eta_t\|\bm{\Delta}_t\|}{\sum_{t=1}^T\eta_t}\le\frac{\sum_{t=1}^T\eta_t}{\sum_{t=1}^T\eta_t}=1$. By the continuity of $\|\cdot\|$, $\|\bm{\Delta}_\infty\|=\lim_{T\to\infty}\left\|\frac{\sum_{t=1}^T\eta_t\bm{\Delta}_t}{\sum_{t=1}^T\eta_t}\right\|\le 1$.

Appendix B Omitted Proofs in Section 4

We first represent $m_t$ and $v_t$ as weighted sums of $g_t$ and $g_t^2$.

By the Cauchy–Schwarz inequality, we have that

We further analyze the first term on the RHS and have that
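Assuming Algorithm 1 uses the usual exponential moving averages $m_t=\beta_1m_{t-1}+(1-\beta_1)g_t$ and $v_t=\beta_2v_{t-1}+(1-\beta_2)g_t^2$, unrolling gives $m_t=\beta_1^tm_0+(1-\beta_1)\sum_{s=1}^t\beta_1^{t-s}g_s$, and likewise for $v_t$ with $g_s^2$. The sketch below checks the unrolled form against the recursion on random data; the helper names are hypothetical.

```python
import numpy as np

def ema_recursive(xs, beta, init):
    """state_t = beta * state_{t-1} + (1 - beta) * xs_t, starting from init."""
    out, state = [], init
    for x in xs:
        state = beta * state + (1 - beta) * x
        out.append(state)
    return np.array(out)

def ema_closed_form(xs, beta, init):
    """state_t = beta^t * init + (1 - beta) * sum_{s=1}^t beta^(t-s) * xs_s."""
    xs = np.asarray(xs, dtype=float)
    out = []
    for t in range(1, len(xs) + 1):
        powers = beta ** (t - np.arange(1, t + 1))      # beta^(t-1), ..., beta^0
        out.append(beta ** t * init + (1 - beta) * np.sum(powers * xs[:t]))
    return np.array(out)

rng = np.random.default_rng(5)
g = rng.standard_normal(50)
m_rec, m_closed = ema_recursive(g, 0.9, 0.0), ema_closed_form(g, 0.9, 0.0)              # m_0 = 0
v_rec, v_closed = ema_recursive(g**2, 0.999, 0.3), ema_closed_form(g**2, 0.999, 0.3)    # v_0 = 0.3
print(np.allclose(m_rec, m_closed), np.allclose(v_rec, v_closed))
```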

B.2 Proof for Lemma 4.1

The proof for this part is the same as the proof of Lemma 3.9 in Section A.3.

If $\nabla L(\bm{x}_\infty)=\bm{0}$, then $\langle\nabla L(\bm{x}_\infty),\bm{\Delta}_\infty\rangle=0=\|\nabla L(\bm{x}_\infty)\|_1$.

If $\nabla L(\bm{x}_\infty)\neq\bm{0}$, we consider each coordinate $j$ such that $\nabla L(\bm{x}_\infty)_j\neq 0$. Since $\lim_{t\to\infty}\nabla L(\bm{x}_t)_j=\nabla L(\bm{x}_\infty)_j$, we get the convergence of $\bm{m}_{t,j}$ and $\bm{v}_{t,j}$.

Then we have that $\lim_{t\to\infty}\bm{\Delta}_{t,j}=\lim_{t\to\infty}\frac{\bm{m}_{t,j}}{\sqrt{\bm{v}_{t,j}}}=\mathrm{sign}(\nabla L(\bm{x}_\infty)_j)$. For any $\epsilon>0$, there exists $t'$ such that $\left|\bm{\Delta}_{t,j}-\mathrm{sign}(\nabla L(\bm{x}_\infty)_j)\right|\le\frac{\epsilon}{2}$ for $t\ge t'$. And there exists $T'\ge t'$ such that $\sum_{t=1}^T\eta_t\ge\frac{2}{\epsilon}\left|\sum_{t=1}^{t'}\eta_t\left(\bm{\Delta}_{t,j}-\mathrm{sign}(\nabla L(\bm{x}_\infty)_j)\right)\right|$ for any $T\ge T'$. Then for any $T\ge T'$, we have that

So $\bm{\Delta}_{\infty,j}=\lim_{T\to\infty}\frac{\sum_{t=1}^T\eta_t\bm{\Delta}_{t,j}}{\sum_{t=1}^T\eta_t}=\mathrm{sign}(\nabla L(\bm{x}_\infty)_j)$ for $\nabla L(\bm{x}_\infty)_j\neq 0$. Then we have that

For each nonzero coordinate $j$ of $\nabla L(\bm{x}_\infty)$, from the above we have $|\bm{\Delta}_{\infty,j}|=|\mathrm{sign}(\nabla L(\bm{x}_\infty)_j)|=1$.

For $j$ such that $\nabla L(\bm{x}_\infty)_j=0$, we know $\lim_{t\to\infty}\bm{g}_{t,j}=\lim_{t\to\infty}\bm{m}_{t,j}=\lim_{t\to\infty}\bm{v}_{t,j}=0$. We employ the upper bound on the average update in Lemma 4.2, since $\{\bm{g}_{t,j}\}_{t=1}^\infty$ and $\{\bm{v}_{t,j}\}_{t=0}^\infty$ in Algorithm 1 satisfy the condition $\bm{v}_{t,j}-\beta_2\bm{v}_{t-1,j}\ge(1-\beta_2)\bm{g}_{t,j}^2$ and $\bm{m}_{0,j}=0\le\sqrt{\bm{v}_{0,j}}$. By Lemma 4.2 we have

The denominator goes to $\infty$ as $T\to\infty$. So, to show $|\bm{\Delta}_{\infty,j}|\le 1$, it suffices to bound the last two terms in the numerator by constants. Because $\eta_t$ is non-increasing in $t$, it holds that

For the last term, we first analyze the coefficient of each $\ln\bm{v}_{t,j}$. Define $\alpha_t=\eta_t\frac{1-\beta_1^{t-1}}{1-\beta_1}-\sum_{i=1}^{T-t}\eta_{t+i}\beta_1^{i-1}$. We claim that $|\alpha_t|\le\max\{\frac{\beta_1^{t-1}}{1-\beta_1}\eta_{t+1},\frac{\eta_t}{1-\beta_1}\}=\frac{\eta_t}{1-\beta_1}$. This is because

and, again by the monotonicity of the learning rates $\eta_t$, we have that

We also have $\ln\frac{\bm{v}_{t,j}}{\bm{v}_{1,j}}\ge(t-1)\ln\beta_2$ because

Moreover, there exists $t'$ such that $\ln\frac{\bm{v}_{t,j}}{\bm{v}_{1,j}}\le 0$ for any $t\ge t'$, because $\lim_{t\to\infty}\bm{v}_{t,j}=0$. Then

Defining $C:=\frac{(\beta_2-\beta_1)\eta_1\beta_1}{(1-\beta_2)(1-\beta_1)}+\frac{\beta_2-\beta_1}{1-\beta_2}\left(\sum_{t=2}^{t'}\eta_t|\ln\bm{v}_{t,j}|-\frac{\eta_1\beta_1^2\ln\beta_2}{(1-\beta_1)^2}\right)$, we now have

Then $|\bm{\Delta}_{\infty,j}|=\left|\lim_{T\to\infty}\frac{\sum_{t=1}^T\eta_t\bm{\Delta}_{t,j}}{\sum_{t=1}^T\eta_t}\right|\le\lim_{T\to\infty}\left|\frac{\sum_{t=1}^T\eta_t\bm{\Delta}_{t,j}}{\sum_{t=1}^T\eta_t}\right|\le 1$ because $\sum_{t=1}^\infty\eta_t=\infty$. This completes the proof.
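The claim $|\alpha_t|\le\frac{\eta_t}{1-\beta_1}$ holds because, for non-increasing learning rates, both terms defining $\alpha_t$ lie in $[0,\frac{\eta_t}{1-\beta_1}]$, so their difference does too in absolute value. A quick numerical sketch on a random non-increasing schedule (hypothetical helper `alphas`; $T=200$ and $\beta_1=0.9$ are illustrative):

```python
import numpy as np

def alphas(etas, beta1):
    """alpha_t = eta_t (1 - beta1^(t-1)) / (1 - beta1) - sum_{i=1}^{T-t} eta_{t+i} beta1^(i-1),
    for t = 1..T, with etas[t-1] holding eta_t."""
    T = len(etas)
    out = np.empty(T)
    for t in range(1, T + 1):
        first = etas[t - 1] * (1 - beta1 ** (t - 1)) / (1 - beta1)
        tail = etas[t:]                                   # eta_{t+1}, ..., eta_T
        second = np.sum(tail * beta1 ** np.arange(len(tail)))
        out[t - 1] = first - second
    return out

rng = np.random.default_rng(6)
beta1 = 0.9
etas = np.sort(rng.uniform(0.01, 1.0, 200))[::-1]       # non-increasing schedule
a = alphas(etas, beta1)
print(np.all(np.abs(a) <= etas / (1 - beta1) + 1e-12))
```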

B.4 Proof for upper bound for norm of iterates in AdamW

For AdamW with constant learning rate $\eta$ and each coordinate $j$, $\bm{x}_{T,j}$ can be written in terms of $\bm{x}_{0,j}$ and a weighted sum of past updates:

Define $\eta_t=\eta(1-\lambda\eta)^{T-t}$ for $1\le t\le T$. We apply Lemma 4.2 to $\{\bm{v}_{t,j}\}_{t=1}^T$ and $\{\bm{g}_{t,j}\}_{t=1}^T$ to bound $\left|\sum_{t=1}^T\eta(1-\lambda\eta)^{T-t}\frac{\bm{m}_{t,j}}{\sqrt{\bm{v}_{t,j}}}\right|$.

We first compute $\sum_{t=1}^T\eta_t=\frac{1-(1-\lambda\eta)^T}{\lambda}\le\frac{1}{\lambda}$. For the second term in Lemma 4.2, we have that

For the last term, we define $\alpha_t=\eta_t\frac{1-\beta_1^{t-1}}{1-\beta_1}-\sum_{i=1}^{T-t}\eta_{t+i}\beta_1^{i-1}$ and compute the exact form of $\alpha_t$ as follows:
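With a constant learning rate, unrolling $\bm{x}_{t,j}=(1-\lambda\eta)\bm{x}_{t-1,j}-\eta\Delta_{t,j}$ gives $\bm{x}_{T,j}=(1-\lambda\eta)^T\bm{x}_{0,j}-\sum_{t=1}^T\eta_t\Delta_{t,j}$ with the weights $\eta_t=\eta(1-\lambda\eta)^{T-t}$ defined above, and these weights sum to $\frac{1-(1-\lambda\eta)^T}{\lambda}$. The sketch below checks both facts for an arbitrary update sequence; the numbers are illustrative.

```python
import numpy as np

rng = np.random.default_rng(4)
eta, lam, T = 0.05, 0.2, 400
x0 = 7.0
deltas = rng.uniform(-3.0, 3.0, T)      # arbitrary updates Delta_{t,j}; may exceed 1 in magnitude

x = x0
for d in deltas:                        # x_t = (1 - lam*eta) x_{t-1} - eta * Delta_t
    x = (1 - lam * eta) * x - eta * d

weights = eta * (1 - lam * eta) ** (T - np.arange(1, T + 1))   # eta_t = eta (1 - lam*eta)^(T-t)
x_unrolled = (1 - lam * eta) ** T * x0 - np.sum(weights * deltas)
print(np.isclose(x, x_unrolled))
print(np.isclose(weights.sum(), (1 - (1 - lam * eta) ** T) / lam))
```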

Then we can bound the last term by showing that

Appendix C Experimental Details and More Results

The architecture of the two-layer transformer is the same as in Kunstner et al. (2022), which is also used as a tutorial example in PyTorch. It consists of a 200-dimensional embedding layer, 2 transformer layers and a linear layer. Each transformer layer consists of a 2-head self-attention layer and an MLP with hidden dimension 200. The experiments are run on a single A4000 or A6000.