Implicit Bias of AdamW: $\ell_\infty$ Norm Constrained Optimization
Shuo Xie, Zhiyuan Li
Introduction
Which solution does AdamW converge to, if it converges?
Our main result, Theorem 1.1, characterizes the implicit bias of AdamW in the deterministic case, where a full-batch loss is used:
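If the loss $L$ is additionally convex, then AdamW converges to the constrained minimizer, i.e., its limit minimizes $L$ over the ball $\{x : \|x\|_\infty \le 1/\lambda\}$, where $\lambda$ is the weight decay factor.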
Despite being simplistic, the full-batch setting is still a very interesting and highly non-trivial regime, because the two main hypotheses for why Adam outperforms SGD were recently challenged in the deterministic regime (Kunstner et al., 2022). The first hypothesis is that Adam outperforms SGD by better handling heavy-tailed noise (Zhang et al., 2020). However, Kunstner et al. (2022) find that Adam still outperforms SGD for optimizing language tasks even in the full-batch setting. The second hypothesis is that the smoothness of the training loss landscape can increase linearly as the gradient norm increases, and thus clipping or normalization is necessary for gradient descent. Intriguingly, Kunstner et al. (2022) find that normalizing each update of GD cannot close the gap to Adam in the full-batch setting, but normalizing each coordinate to its sign (i.e., SignGD) closes the gap.
In Section 3.1, we prove normalized steepest descent with weight decay optimizes convex functions under norm constraints (Theorem 3.5). In Section 3.2, we prove it must converge to KKT points of the norm-constrained optimization problem for general loss functions if it converges with a learning rate schedule whose partial sum diverges (Theorem 3.7).
In Section 4, we prove that AdamW must converge to KKT points of the norm-constrained optimization problem for general loss functions if it converges with a non-increasing learning rate schedule whose partial sum diverges (Theorem 1.1).
Towards generalizing the proof of Theorem 3.7 to Theorem 1.1, we prove a novel and tight upper bound on the average update size of Adam (Lemma 4.2), which holds even in non-deterministic settings and might be of independent interest to the community. We test various predictions made by our bound in experiments.
Preliminaries and Notations
Steepest Descent:
We say is a steepest descent direction for objective function at current iterate w.r.t. norm iff and . Thus for all steepest descent direction , we have that .
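As a concrete illustration of this definition, the following minimal sketch computes one steepest descent direction for three common norms and checks numerically that it has unit norm and attains inner product $-\|g\|_*$ with the gradient $g$. The function name `steepest_descent_direction`, the norm labels, and the example gradient are assumptions made for this sketch and are not taken from the paper.

```python
import numpy as np

def steepest_descent_direction(grad, norm="linf"):
    """Return one unit-norm direction minimizing <direction, grad>."""
    grad = np.asarray(grad, dtype=float)
    if norm == "linf":
        # Unit l_inf ball: every coordinate moves at full speed against the gradient.
        return -np.sign(grad)
    if norm == "l2":
        return -grad / np.linalg.norm(grad)
    if norm == "l1":
        # Unit l_1 ball: move only along the coordinate with the largest |gradient|.
        direction = np.zeros_like(grad)
        i = int(np.argmax(np.abs(grad)))
        direction[i] = -np.sign(grad[i])
        return direction
    raise ValueError(f"unknown norm: {norm}")

# Order of the dual norm for each primal norm: l_inf <-> l_1, l_2 <-> l_2.
DUAL_NORM_ORDER = {"linf": 1, "l2": 2, "l1": np.inf}

g = np.array([3.0, -4.0, 0.5])
for name in ("linf", "l2", "l1"):
    d = steepest_descent_direction(g, name)
    dual = np.linalg.norm(g, ord=DUAL_NORM_ORDER[name])
    print(name, d, float(d @ g), -dual)
```

For the $\ell_\infty$ norm the direction is the negative sign vector of the gradient, which is why sign-based updates appear throughout the analysis.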
Given initialization , learning rate schedule and weight decay factor , the th iterate of normalized steepest descent w.r.t. with decoupled weight decay is defined as
Because the dual norm of the dual norm is always equal to the original norm, by Lemma 2.1, we can also characterize the steepest descent directions as the subgradient of its dual norm.
.
Warm Up: Implicit Bias of Normalized Steepest Descent w. Weight Decay
Our analysis in this section holds for all norms, including non-differentiable ones such as $\ell_\infty$.
In this subsection, we give a simple non-asymptotic convergence analysis for normalized steepest descent with weight decay (NSD-WD) w.r.t. general norms over smooth convex loss functions. If the norm of the initialization is no larger than $1/\lambda$, where $\lambda$ is the weight decay factor, then surprisingly NSD-WD is exactly equivalent to a well-known optimization algorithm in the literature, Frank-Wolfe (Frank et al., 1956), where the constraint set is the norm ball with radius $1/\lambda$. If the norm of the initialization is larger than $1/\lambda$, then the analysis contains an additional phase where the norm of the iterates converges linearly to $1/\lambda$. In this case, the iterates of NSD-WD may always be outside the norm ball, but the convergence analysis of Frank-Wolfe can still be adopted (e.g., Jaggi (2013)). First, we show that the norm of the iterates will shrink to $1/\lambda$ as long as the norm of each update is bounded by $1$. Note that this conclusion uses neither the convexity of the loss function nor the update being a steepest descent direction, so it also holds in non-deterministic settings.
For any learning rate schedule and update such that and , .
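The shrinkage can be checked numerically. The sketch below is only an illustration under stated assumptions: it uses the $\ell_\infty$ norm, a constant learning rate with $\lambda\eta < 1$, and random updates of norm at most $1$ (the helper `excess_norm_trajectory` is hypothetical). It compares the excess norm $\max(\|x_t\|_\infty - 1/\lambda, 0)$ with the geometric envelope $\prod_s (1 - \lambda\eta_s)$ that follows from the triangle inequality.

```python
import numpy as np

def excess_norm_trajectory(x0, lam, lrs, seed=0):
    """Track max(||x_t||_inf - 1/lam, 0) under arbitrary updates with ||update||_inf <= 1."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float).copy()
    excess = [max(np.linalg.norm(x, ord=np.inf) - 1.0 / lam, 0.0)]
    for lr in lrs:
        # Any update with norm at most 1 works; it need not be a descent direction.
        update = rng.uniform(-1.0, 1.0, size=x.shape)
        x = (1.0 - lam * lr) * x + lr * update
        excess.append(max(np.linalg.norm(x, ord=np.inf) - 1.0 / lam, 0.0))
    return np.array(excess)

lam = 0.5
lrs = np.full(200, 0.05)  # lam * lr < 1 at every step
excess = excess_norm_trajectory(10.0 * np.ones(4), lam, lrs)
# Geometric envelope implied by ||x_t|| <= (1 - lam * lr) * ||x_{t-1}|| + lr.
envelope = excess[0] * np.cumprod(np.concatenate(([1.0], 1.0 - lam * lrs)))
print(excess[[0, 50, 200]], envelope[[0, 50, 200]])
```

Since $1 - \lambda\eta \le e^{-\lambda\eta}$, the envelope decays exponentially in the sum of learning rates.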
The proof is deferred to Section A.1. Lemma 3.1 shows that the iterate is either always inside the norm ball with radius $1/\lambda$, or its distance to the ball shrinks exponentially as the sum of learning rates increases. Once the iterate gets into the norm ball with radius $1/\lambda$, it will not leave it, and the remaining trajectory of NSD-WD is exactly the same as that of Frank-Wolfe, as shown in the following theorem. We note that the relationship between Frank-Wolfe and steepest descent algorithms was also observed very recently in the continuous case (Chen et al., 2023).
For any norm , weight decay , and , - with learning rate and - (Algorithm 2) with step size and convex set generate the same next iterate .
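The equivalence is easy to verify in the $\ell_\infty$ case. The sketch below assumes the sign convention $x_{t} = (1 - \lambda\eta)\,x_{t-1} - \eta\,\mathrm{sign}(\nabla L(x_{t-1}))$ for the normalized steepest descent step and the textbook Frank-Wolfe step with step size $\lambda\eta$ over the box of radius $1/\lambda$. The function names are hypothetical.

```python
import numpy as np

def nsd_wd_step(x, grad, lr, lam):
    """Sign descent (steepest descent w.r.t. l_inf) with decoupled weight decay."""
    return (1.0 - lam * lr) * x - lr * np.sign(grad)

def frank_wolfe_step(x, grad, step, radius):
    """One Frank-Wolfe step over the l_inf ball {s : ||s||_inf <= radius}."""
    # Linear minimization oracle: the vertex of the box that minimizes <s, grad>.
    s = -radius * np.sign(grad)
    return (1.0 - step) * x + step * s

rng = np.random.default_rng(0)
x, grad = rng.normal(size=5), rng.normal(size=5)
lr, lam = 0.1, 0.25
a = nsd_wd_step(x, grad, lr, lam)
b = frank_wolfe_step(x, grad, step=lam * lr, radius=1.0 / lam)
print(np.max(np.abs(a - b)))
```

The two steps coincide because the Frank-Wolfe vertex is the steepest descent direction scaled by $1/\lambda$, so the step size $\lambda\eta$ cancels the radius.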
Define $x^*$ to be the constrained minimizer of the convex function $L$. We first compute how much the gap between $L(x_t)$ and $L(x^*)$ can decrease in one normalized steepest descent step when the iterate is bounded.
Suppose loss function is convex and has -lipschitz gradient w.r.t. norm . For iterates in - (Equation 1), we have that
The proof of Lemma 3.3 is deferred to Section A.2. With Lemma 3.3, we can prove the convergence of for learning rate schedules with certain conditions. The proof is also deferred to Section A.2.
Assume that , and . For any convex loss with -lipschitz gradient, .
We also provide a specific example of learning rates whose convergence guarantee is the same as that of Frank-Wolfe over convex objectives (Jaggi, 2013); the proof is standard. For completeness, we provide a proof of Theorem 3.5 in Section A.2.
Define . For - with learning rate schedule , we have for .
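To make the convergence statement tangible, the following sketch runs sign descent with weight decay on a small diagonal quadratic. Everything specific here is an assumption for illustration: the loss, the $\ell_\infty$ norm, and in particular the schedule $\eta_t = 2/(\lambda(t+2))$, which is the classical Frank-Wolfe step size $2/(t+2)$ divided by $\lambda$ and may differ from the schedule in the theorem. The reference bound is the standard Frank-Wolfe guarantee $2\beta D^2/(t+2)$ from Jaggi (2013), with $\beta$ the gradient Lipschitz constant and $D$ the Euclidean diameter of the box.

```python
import numpy as np

a = np.array([1.0, 2.0, 0.5])    # curvatures of a diagonal quadratic (assumed example)
c = np.array([3.0, -0.2, -5.0])  # unconstrained minimizer
lam = 1.0                        # weight decay, so the box radius is 1 / lam

def loss(x):
    return 0.5 * np.sum(a * (x - c) ** 2)

def grad(x):
    return a * (x - c)

radius = 1.0 / lam
x_star = np.clip(c, -radius, radius)  # constrained minimizer (the loss is separable)

x = np.zeros(3)
gaps = []
for t in range(2000):
    lr = 2.0 / (lam * (t + 2))  # assumed schedule: Frank-Wolfe step 2/(t+2) over lam
    x = (1.0 - lam * lr) * x - lr * np.sign(grad(x))
    gaps.append(loss(x) - loss(x_star))

print(gaps[0], gaps[99], gaps[-1])
```

The loss gap decays at the $O(1/t)$ rate of Frank-Wolfe, while the iterate itself keeps oscillating around coordinates whose optimum lies strictly inside the box.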
3.2 Non-convex setting: convergence to KKT points
In this subsection, we study the implicit bias of SignGD (or more generally, NSD-WD) when the loss is non-convex. In this case, last-iterate parameter convergence is, in general, difficult to show (indeed, even in the convex case, Frank-Wolfe may not converge in parameter (Bolte et al., 2023)), and thus we turn to study what parameters SignGD and NSD-WD can converge to. Our main result, Theorem 3.7, shows that such parameters must be KKT points (see Definition 3.6) of the constrained optimization problem. In particular, if the objective is convex, then since the norm ball constraint is convex for every norm, all KKT points are constrained minimizers.
For convex , all KKT points are optimal and the dual variable is the certificate for the optimality. To see that, for any other , it holds that , where the second inequality is because is also convex and is its subgradient at . Thus we conclude .
Now we state the main result for this subsection.
To prove Theorem 3.7, we use the following alternative characterization for KKT points of below based on Lemma 2.1.
is a KKT point of iff and .
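For the $\ell_\infty$ ball, such a characterization can be turned into a short numerical check. The sketch below assumes the condition takes the standard normal-cone form for a box, namely $\|x\|_\infty \le 1/\lambda$ together with $\langle -\lambda x, \nabla L(x)\rangle = \|\nabla L(x)\|_1$; this is my reading of the lemma for this particular norm, and the function name `is_kkt_linf` and the quadratic example are hypothetical.

```python
import numpy as np

def is_kkt_linf(x, grad, lam, tol=1e-8):
    """Check KKT for min L(x) s.t. ||x||_inf <= 1/lam, given grad = grad L(x)."""
    feasible = np.max(np.abs(x)) <= 1.0 / lam + tol
    # On the feasible set <-lam * x, grad> never exceeds ||grad||_1;
    # equality means -grad lies in the normal cone of the box at x.
    aligned = abs(np.dot(-lam * x, grad) - np.sum(np.abs(grad))) <= tol
    return bool(feasible and aligned)

# Diagonal quadratic L(x) = 0.5 * sum(a * (x - c)^2), used only as an example.
a = np.array([1.0, 2.0, 0.5])
c = np.array([3.0, -0.2, -5.0])
lam = 1.0

x_star = np.clip(c, -1.0 / lam, 1.0 / lam)  # constrained minimizer of the example
x_zero = np.zeros(3)                        # feasible but not stationary

print(is_kkt_linf(x_star, a * (x_star - c), lam))
print(is_kkt_linf(x_zero, a * (x_zero - c), lam))
```

Coordinates with a nonzero gradient must sit on the boundary of the box with the sign opposite to the gradient, and coordinates strictly inside the box must have zero gradient.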
The following lemma (Lemma 3.9) circumvents the above issue by considering the weighted average of past steepest descent directions, which provably converges, given that the iterates converge. Theorem 3.7 is a direct combination of Lemma 3.9 and Lemma 3.8, and we omit its proof. The proof of Lemma 3.9 is deferred to Section A.3.
For any learning rate schedule satisfying , if the iterates of - converges to some , we have that
exists and .
.
.
Implicit Bias of AdamW
In this section, we extend the analysis of NSD-WD in Section 3 to AdamW to prove that the converged parameters of AdamW form a KKT point of the constrained optimization problem. The proof relies on an upper bound on the average update size of Adam, and we find that the bound can also be used to guide hyperparameter tuning in empirical studies.
For non-increasing learning rate schedule satisfying and , we get by running AdamW with weight decay factor . If converges to some , then it holds that
exists and .
.
.
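A toy run makes the $\ell_\infty$ bias visible. The sketch below is an illustration only, under several assumptions: a linear loss $L(x) = \langle c, x\rangle$ with a constant full-batch gradient, a constant learning rate, and a plain AdamW loop without bias correction and with a tiny $\epsilon$ for numerical safety (the function `adamw_full_batch` is hypothetical and not the paper's Algorithm 1). On this loss the only KKT point of the box-constrained problem is the vertex $-\mathrm{sign}(c)/\lambda$.

```python
import numpy as np

def adamw_full_batch(grad_fn, x0, lr, lam, beta1=0.9, beta2=0.99, steps=5000, eps=1e-12):
    """Plain AdamW loop without bias correction (an assumption of this sketch)."""
    x = np.asarray(x0, dtype=float).copy()
    m = np.zeros_like(x)
    v = np.zeros_like(x)
    for _ in range(steps):
        g = grad_fn(x)
        m = beta1 * m + (1.0 - beta1) * g
        v = beta2 * v + (1.0 - beta2) * g * g
        # Decoupled weight decay: shrink x, then apply the Adam direction.
        x = (1.0 - lam * lr) * x - lr * m / (np.sqrt(v) + eps)
    return x

c = np.array([2.0, -0.3, 1e-2])  # linear loss L(x) = <c, x>, gradient is c everywhere
lam = 0.5
x_final = adamw_full_batch(lambda x: c, x0=np.array([5.0, 5.0, -5.0]), lr=0.01, lam=lam)
print(x_final, -np.sign(c) / lam)
```

Every coordinate ends at magnitude $1/\lambda$ regardless of the scale of the corresponding gradient entry, which is the behavior expected from an $\ell_\infty$ constraint rather than an $\ell_2$ one.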
The first two properties in Lemma 4.1 follow from an argument similar to that for Lemma 3.9, and the main technical difficulty lies in the proof of the third property. This is because, for any single step, the $\ell_\infty$ norm of the AdamW update could be larger than $1$, which is different from the case of NSD-WD. To prove the third property, we need a tight upper bound for the average update size of an Adam-like update rule, which is Lemma 4.2. The proof of Lemma 4.1 is deferred to Appendix B.
As mentioned earlier, Adam updates can easily exceed $1$ in magnitude, and thus we prove the following upper bound for the average update size of Adam (Lemma 4.2). The proof of Lemma 4.2 is deferred to Section B.1.
In particular, when , it even holds that .
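The gap between a single update and the average update can be seen on a scalar example. The sketch below is a toy under assumptions, not a verification of the lemma: it uses the update $m_t/\sqrt{v_t}$ without bias correction or $\epsilon$, initializes $v_0$ to the square of the typical gradient, and feeds a sequence of small gradients with one large spike. The helper `adam_updates` is hypothetical.

```python
import numpy as np

def adam_updates(grads, v0, beta1=0.9, beta2=0.999):
    """Return m_t / sqrt(v_t) for a scalar gradient sequence (no bias correction)."""
    m, v, out = 0.0, v0, []
    for g in grads:
        m = beta1 * m + (1.0 - beta1) * g
        v = beta2 * v + (1.0 - beta2) * g * g
        out.append(m / np.sqrt(v))
    return np.array(out)

# 2000 small gradients, one large spike, then 2000 more small gradients.
small = 1e-3
grads = np.concatenate((np.full(2000, small), [1.0], np.full(2000, small)))
u = np.abs(adam_updates(grads, v0=small ** 2))
print(u.max(), u.mean())
```

At the spike the numerator reacts at rate $1 - \beta_1$ while the denominator reacts at rate $\sqrt{1 - \beta_2}$, so the update is roughly $(1-\beta_1)/\sqrt{1-\beta_2} \approx 3.2$. The inflated second moment then suppresses later updates, and the average over the whole run stays below $1$ in this example.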
Note here only needs to satisfy a more general condition rather than to be the exact moving average of . It can be applied to the practical scenario where a small positive constant is added to in the denominator to improve the numerical stability of . It is easy to verify that for in Algorithm 1, we have that
Therefore, for with , in Lemma 4.2 is always lower bounded, and if we further have an upper bound for gradients, then we can easily control the average update size of . One nice property is that the upper bound only scales up logarithmically to , instead of linearly, as the naive upper bound scales.
Another application of Lemma 4.2 is to provide a tight upper bound for the norm of iterates for any setting, e.g., before convergence or even when the gradient is stochastic. In particular, when the learning rate does not change over steps, we have the following upper bound whose proof is in Section B.4.
For any coordinate , for with constant learning rate and weight decay factor , with , it holds that
When , we only need to guarantee that is no larger than for any . However, when and , the dominating term on the right-hand side is . Assuming , it also requires or to ensure the remaining term is small.
Experiments
5.2 A synthetic problem
Related Work
While Stochastic Gradient Descent (Robbins & Monro, 1951) remains popular for optimizing deep learning models like ResNet (He et al., 2016), only adaptive methods can efficiently train recently emerged large language models (Zhang et al., 2020). There has been a large body of research on adaptive gradient methods, including AdaGrad (Duchi et al., 2011), RMSProp (Tieleman & Hinton, 2012), AdaDelta (Zeiler, 2012), Adam (Kingma & Ba, 2014), AdaFactor (Shazeer & Stern, 2018), AMSGrad (Reddi et al., 2018), AdaBound (Luo et al., 2018), Lion (Chen et al., 2024), etc. Recently there have also been adaptive methods that attempt to accelerate training by leveraging second-order information, e.g., AdaHessian (Yao et al., 2021) and Sophia (Liu et al., 2023). However, most algorithms that are able to train large language models adopt coordinate-wise adaptivity. In contrast, stochastic gradient descent, even when equipped with global gradient norm clipping, cannot match the performance of coordinate-wise adaptive algorithms on language tasks (Li et al., 2022a). Previous work has given convergence rates for RMSProp and Adam under different assumptions (Chen et al., 2018; Zou et al., 2019; Shi & Li, 2021; Guo et al., 2021; Défossez et al., 2022; Zhang et al., 2022).
Our work shows that AdamW and SignGD with weight decay converge to the same point, assuming convergence. Balles & Hennig (2018) and Kunstner et al. (2022) point out that the similarity with SignGD largely accounts for the advantage of Adam over SGD. Moreover, when SignGD is equipped with momentum, which is one key component of Adam, it can achieve empirical results comparable to those of Adam on various tasks (Balles & Hennig, 2018; Kunstner et al., 2022; Bernstein et al., 2018; Crawshaw et al., 2022).
Role of Weight Decay:
The usage of weight decay, which refers to shrinking the parameter by a small constant fraction, can be dated back to the 1980s (Rumelhart et al., 1986; Hinton, 1987). It has been recognized as a standard trick to improve the generalization performance of neural networks (Krogh & Hertz, 1991; Bos & Chug, 1996) for a long time. Krizhevsky et al. (2012) first noticed that weight decay can sometimes accelerate optimization in deep learning. For modern architectures equipped with normalization layers, e.g., BatchNorm (Ioffe & Szegedy, 2015) and LayerNorm (Ba et al., 2016), only the direction of the parameters before normalization layers matters, rather than their norms. Turning on weight decay in such settings changes the effective learning rate of the parameters (Hoffer et al., 2018; Arora et al., 2018; Zhang et al., 2018; Li & Arora, 2019; Li et al., 2020).
Implicit Regularization:
The concurrent work by Chen et al. (2023) is arguably the work most closely related to ours, where the recently discovered optimization algorithm by auto-search, Lion (Chen et al., 2024), is elegantly generalized to a family of algorithms, Lion-$\mathcal{K}$, where $\mathcal{K}$ is some convex function. When $\mathcal{K}$ is chosen to be the dual norm and momentum in Lion-$\mathcal{K}$ is turned off, Lion-$\mathcal{K}$ becomes normalized steepest descent. Their analysis shows that even with momentum, normalized steepest descent with weight decay can be viewed as optimization under the original norm constraint. However, in any Lion-$\mathcal{K}$ algorithm, the update at one step depends on past iterates only through the first-order momentum. Their analysis cannot be applied to AdamW because AdamW cannot be written in the form of Lion-$\mathcal{K}$ for any convex function $\mathcal{K}$. To see this, simply note that the update of Lion-$\mathcal{K}$ for a fixed $\mathcal{K}$ is completely determined by the current gradient and the first-order momentum, while the update of AdamW can still be different if the second-order momentum is different. In terms of proof technique, Chen et al. (2023) construct a Lyapunov function, while we directly characterize the KKT point and connect the converged point to the KKT point through the weighted average update.
Discussion and Future Works
This work focuses on the implicit bias of AdamW in the deterministic (or full-batch) case. Though our upper bound on the average update size of Adam holds unconditionally on the input gradients, whether stochastic or not, it is unlikely that the upper bound can be reached when there is large gradient noise, especially when $\beta_2$ is very close to $1$. In that case, the denominator of the Adam update is roughly the square root of the squared expected gradient plus an additional gradient variance term, which strictly dominates the expected gradient in the numerator. Malladi et al. (2022) use a Stochastic Differential Equation (SDE) approximation to model the trajectories of Adam in this regime and empirically test the implication of the SDE approximation, namely the square root scaling rule.
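The effect of gradient noise on the update size can be illustrated with a scalar simulation. The sketch below is a rough illustration under assumptions: i.i.d. Gaussian gradient noise with mean `mu` and standard deviation `sigma`, no bias correction, and $v_0$ initialized at the stationary second moment $\mu^2 + \sigma^2$. The helper `mean_abs_adam_update` is hypothetical.

```python
import numpy as np

def mean_abs_adam_update(mu, sigma, steps=20000, burn_in=5000,
                         beta1=0.9, beta2=0.999, seed=0):
    """Average |m_t| / sqrt(v_t) after burn-in for gradients g_t = mu + sigma * noise."""
    rng = np.random.default_rng(seed)
    grads = mu + sigma * rng.standard_normal(steps)
    m = 0.0
    v = mu * mu + sigma * sigma  # start v at its stationary mean (assumption)
    total = 0.0
    for t, g in enumerate(grads):
        m = beta1 * m + (1.0 - beta1) * g
        v = beta2 * v + (1.0 - beta2) * g * g
        if t >= burn_in:
            total += abs(m) / np.sqrt(v)
    return total / (steps - burn_in)

noiseless = mean_abs_adam_update(mu=1.0, sigma=0.0)
noisy = mean_abs_adam_update(mu=1.0, sigma=2.0)
print(noiseless, noisy)
```

Without noise the update magnitude settles at $1$, matching the sign-descent picture. With noise twice as large as the mean gradient, the denominator is close to $\sqrt{\mu^2 + \sigma^2}$ and the average update magnitude falls well below $1$.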
Another important future direction is to provide non-asymptotic convergence rates for AdamW in both convex and non-convex settings.
Conclusions
References
Appendix A Omitted Proofs in Section 3
In this section, we provide the proofs omitted in Section 3, which characterize the iterates and the converged solution of normalized steepest descent with decoupled weight decay, before diving into the analysis of AdamW. In Section A.1, we prove that the iterates will enter or stay in the norm ball with radius $1/\lambda$ for any normalized update. In Section A.2, we prove that, with proper learning rates, the iterates of normalized steepest descent with weight decay converge to the constrained minimizer of the loss in the same ball.
Lemma 3.1 We prove by induction that .
When , we have that
When , . This completes the proof. ∎
A.2 Omitted proofs for convergence to constrained minimizer with proper learning rates
For normalized steepest descent update from Equation 1,
where the first inequality we use convexity of and the second inequality uses .
Since the gradient of is -lipschitz, by Taylor expansion, we have that
Because the update is normalized and thus has unit norm by definition, it holds that
The proof of Theorem 3.4 is a direct application of Lemma A.1 on the one-step descent lemma Lemma 3.3. ∎
Assume that , and . is any positive number and . If the sequence satisfies that , then .
First we show by induction that .
Because , . In order to show , it’s sufficient to show .
From Lemma 3.1, for . Define .
Suppose , we have that
A.3 Omitted Proofs for Lemma 3.9
For any , there exists such that for any . Because , we have that
There exists such that for . Then we have
So exists and .
Because is a continuous function and , . For any , there exists such that
for any . It also holds that
because . Because , there exists such that
Therefore, we prove that . On the other hand, we have that
For any , we know . By the continuity of , .
Appendix B Omitted Proofs in Section 4
We first represent and as a weighted sum of and .
By Cauchy–Schwarz inequality, we have that
We further analyze the first term on the RHS and have that
B.2 Proof for Lemma 4.1
The proof for this part is the same as the proof of Lemma 3.9 in Section A.3.
If , .
If , we consider each coordinate such that . Since we have that , we can get the convergence for and .
Then we have that . For any , there exists such that for . And there exists such that for any . Then for any , we have that
So for . Then we have that
For nonzero coordinate of , from above we have .
For such that , we know . We employ the upper bound for average update in Lemma 4.2 since and in Algorithm 1 satisfy the condition that and . By Lemma 4.2 we have
The denominator goes to when . So it suffices to bound the last two terms in the numerator by constants in order to show . Because is non-increasing in , it holds that
For the last term, we first analyze the coefficient between each . Define . We claim that . This is because
and again by monotonicity of learning rates , we have that
We can also have because
And there exists such that for any because . Then
Define , we now have
Then because . This completes the proof.
B.4 Proof for upper bound for norm of iterates in AdamW
For with constant learning rate and each coordinate , can be written as weighted average of past update
Define for . We apply Lemma 4.2 on and to bound .
We first compute . For the second term in Lemma 4.2, we have that
For the last term, we define and we can compute the exact form of as follows
Then we can bound the last term by showing that
Appendix C Experimental Details and More Results
The architecture of the two-layer transformer is the same as in Kunstner et al. (2022), which is also used as a tutorial example in PyTorch. It consists of a 200-dimensional embedding layer, transformer layers and a linear layer. Each transformer layer consists of a -head self-attention and an MLP with a hidden dimension . The experiments are run on a single A4000 or A6000.