Self-Stabilization: The Implicit Bias of Gradient Descent at the Edge of Stability
Alex Damian, Eshaan Nichani, Jason D. Lee
Introduction
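Almost all neural networks are trained using a variant of gradient descent, most commonly stochastic gradient descent (SGD) or Adam. When deciding on an initial learning rate, many practitioners rely on intuition drawn from classical optimization. In particular, the following classical lemma, known as the "descent lemma," provides a common heuristic for choosing a learning rate in terms of the sharpness of the loss function.
Given a loss function $L(\theta)$, the sharpness is defined to be $S(\theta) := \lambda_{\max}(\nabla^2 L(\theta))$. When this eigenvalue is unique, the associated eigenvector is denoted by $u(\theta)$.
Lemma 1 (Descent Lemma). Assume that $S(\theta) \le \ell$ for all $\theta$. If $\eta \le 2/\ell$, then the gradient descent iterates $\theta_{t+1} = \theta_t - \eta\nabla L(\theta_t)$ satisfy
$$L(\theta_{t+1}) \le L(\theta_t) - \frac{\eta(2 - \eta\ell)}{2}\,\|\nabla L(\theta_t)\|^2.$$
Hence the loss is guaranteed to decrease at every step as long as the sharpness stays below $2/\eta$.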
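The role of the threshold $2/\eta$ is easiest to see on a one-dimensional quadratic $L(x) = \frac{\lambda}{2}x^2$, whose sharpness is the constant $\lambda$. The following minimal sketch is our own illustration, not code from the paper. The helper name `gd_on_quadratic` and the chosen values are hypothetical. Each step multiplies $x$ by $1 - \eta\lambda$, so the iterates shrink when $\eta\lambda < 2$ and grow exponentially when $\eta\lambda > 2$.

```python
def gd_on_quadratic(lam, eta, x0=1.0, steps=50):
    """Run gradient descent on L(x) = 0.5 * lam * x**2 and return the final |x|.

    The gradient is lam * x, so each step is x <- (1 - eta * lam) * x.
    """
    x = x0
    for _ in range(steps):
        x = x - eta * lam * x
    return abs(x)


eta = 0.1  # the stability threshold for the sharpness is 2 / eta = 20
for lam in (10.0, 19.0, 21.0, 30.0):
    final = gd_on_quadratic(lam, eta)
    print(f"sharpness={lam:5.1f}  eta*sharpness={eta * lam:.1f}  |x_50|={final:.3e}")
```

With $\eta = 0.1$ the threshold is $2/\eta = 20$. The runs with sharpness 10 and 19 shrink, while those with sharpness 21 and 30 blow up (oscillating in sign as they do).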
However, it has recently been observed that in neural networks, the descent lemma is not predictive of the optimization dynamics. Cohen et al. observed two important phenomena for gradient descent, which refined similar observations made by Jastrzębski et al. for SGD:
Progressive Sharpening
Throughout most of the optimization trajectory, the gradient of the loss is negatively aligned with the gradient of the sharpness, i.e., $\nabla L(\theta)\cdot\nabla S(\theta) < 0$. As a result, for any reasonable learning rate $\eta$, the sharpness increases throughout training until it reaches $2/\eta$.
Edge of Stability
Once the sharpness reaches $2/\eta$ (the “break-even” point in Jastrzębski et al.), it ceases to increase and remains around $2/\eta$ for the rest of training. The descent lemma no longer guarantees that the loss decreases, but the loss still continues to decrease, albeit non-monotonically.
2 Self-stabilization: The Implicit Bias of Instability
In this work we explain the second stage, "edge of stability." We identify a new implicit bias of gradient descent which we call self-stabilization. Self-stabilization is the mechanism by which the sharpness remains bounded around $2/\eta$, despite the continued force of progressive sharpening, and by which the gradient descent dynamics do not diverge, despite instability. Unlike progressive sharpening, which only holds for specific loss functions (e.g., those arising in neural network optimization), self-stabilization is a general property of gradient descent.
Traditional non-convex optimization analyses Taylor expand the loss function to second order around $\theta$ to prove loss decrease when $S(\theta) \le 2/\eta$. When this condition is violated, the iterates diverge exponentially in the top eigenvector direction $u$, thus leaving the region in which the loss function is locally quadratic. Understanding the dynamics thus necessitates a cubic Taylor expansion.
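Our key insight concerns the Taylor expansion of the gradient after diverging a distance $x$ in the $u$ direction,
$$\nabla L(\theta + xu) \approx \nabla L(\theta) + x\,\nabla^2 L(\theta)u + \frac{x^2}{2}\nabla^3 L(\theta)(u,u).$$
The missing cubic term is $\frac{x^2}{2}\nabla^3 L(\theta)(u,u)$, and $\nabla^3 L(\theta)(u,u)$ is conveniently equal to the gradient of the sharpness at $\theta$: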
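Lemma 2. If the top eigenvalue of $\nabla^2 L(\theta)$ is unique, then the sharpness is differentiable at $\theta$ and $\nabla S(\theta) = \nabla^3 L(\theta)(u(\theta), u(\theta))$.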
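The identity $\nabla S = \nabla^3 L(u,u)$ can be sanity-checked numerically. The sketch below is our own illustration; the toy loss and helper names are hypothetical. It uses $L(a, b) = \frac12(c + b)a^2 + \frac12 b^2$ with $c = 3$. This loss has Hessian $\begin{pmatrix} c + b & a \\ a & 1\end{pmatrix}$, and its only nonzero third partial derivative is $\partial^3 L/\partial a^2\partial b = 1$. Hence $\nabla^3 L(u,u) = (2u_a u_b,\ u_a^2)$. The code compares this vector with a central finite-difference estimate of $\nabla S$ at a point where the top eigenvalue is simple.

```python
import math

C = 3.0  # hypothetical constant; keeps the top eigenvalue simple near the test point


def top_eig(a, b):
    """Top eigenvalue and unit eigenvector of the Hessian [[C + b, a], [a, 1]]
    of the toy loss L(a, b) = 0.5 * (C + b) * a**2 + 0.5 * b**2."""
    p, q, r = C + b, a, 1.0                    # Hessian entries [[p, q], [q, r]]
    disc = math.sqrt((p - r) ** 2 + 4 * q ** 2)
    s = 0.5 * (p + r + disc)                   # largest eigenvalue
    vx, vy = q, s - p                          # solves (H - s I) v = 0 when q != 0
    if math.hypot(vx, vy) < 1e-14:             # q == 0 and s == p: eigenvector is e_a
        vx, vy = 1.0, 0.0
    n = math.hypot(vx, vy)
    return s, (vx / n, vy / n)


def grad_sharpness_fd(a, b, h=1e-6):
    """Central finite-difference estimate of grad S at (a, b)."""
    ga = (top_eig(a + h, b)[0] - top_eig(a - h, b)[0]) / (2 * h)
    gb = (top_eig(a, b + h)[0] - top_eig(a, b - h)[0]) / (2 * h)
    return ga, gb


def third_derivative_uu(a, b):
    """grad^3 L(a, b)[u, u] for the toy loss. The only nonzero third partial
    derivative is d^3L / (da da db) = 1, which gives (2 u_a u_b, u_a**2)."""
    _, (ua, ub) = top_eig(a, b)
    return 2 * ua * ub, ua ** 2


point = (0.4, 0.2)
fd = grad_sharpness_fd(*point)
exact = third_derivative_uu(*point)
print("finite-difference grad S:", fd)
print("grad^3 L(u, u):          ", exact)
```

The two vectors agree up to finite-difference error. The sign ambiguity of the eigenvector does not matter, because $u$ enters the tensor product twice.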
As the iterates move in the negative gradient direction, this term has the effect of decreasing the sharpness. The story of self-stabilization is thus as follows. As the iterates diverge in the $u$ direction, the strength of the induced movement in the $-\nabla S$ direction grows until it forces the sharpness below $2/\eta$. At that point the iterates in the $u$ direction shrink and the dynamics re-enter the quadratic regime.
This negative feedback loop prevents both the sharpness $S(\theta)$ and the movement in the top eigenvector direction $u$ from growing out of control. As a consequence, we show that gradient descent implicitly solves the constrained minimization problem:
$$\min_\theta L(\theta) \quad \text{such that} \quad S(\theta) \le 2/\eta. \qquad (1)$$
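Specifically, define the stable set
$$\mathcal{M} := \{\theta : S(\theta) \le 2/\eta \ \text{and}\ \nabla L(\theta)\cdot u(\theta) = 0\}.$$
Then the gradient descent trajectory $\{\theta_t\}$ tracks the following projected gradient descent (PGD) trajectory $\{\theta^\dagger_t\}$, which solves the constrained problem:
$$\theta^\dagger_{t+1} := \mathrm{proj}_{\mathcal{M}}\big(\theta^\dagger_t - \eta\nabla L(\theta^\dagger_t)\big), \qquad \text{where} \quad \mathrm{proj}_{\mathcal{M}}(\theta) := \arg\min_{\theta'\in\mathcal{M}}\|\theta - \theta'\|. \qquad (2)$$
The condition $\nabla L(\theta)\cdot u(\theta) = 0$ is added to ensure that the constrained trajectory is not unstable; it does not affect the stationary points of eq. 2.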
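The PGD update $\theta^\dagger_{t+1} = \mathrm{proj}(\theta^\dagger_t - \eta\nabla L(\theta^\dagger_t))$ can be sketched generically with a user-supplied projection. Projecting onto the actual stable set is problem-specific, so this toy replaces it by a linearized half-space. The toy is our own; all names and values are hypothetical. The sharpness is modeled as $S_0 + s\cdot\theta$ and the loss as $g\cdot\theta$ with $g\cdot s < 0$, mimicking progressive sharpening. The condition $\nabla L\cdot u = 0$ is ignored here.

```python
def dot(v, w):
    return sum(vi * wi for vi, wi in zip(v, w))


def project_halfspace(theta, s, b):
    """Euclidean projection onto the half-space {theta : s . theta <= b}."""
    excess = dot(s, theta) - b
    if excess <= 0:
        return list(theta)
    scale = excess / dot(s, s)
    return [ti - scale * si for ti, si in zip(theta, s)]


def pgd(theta0, grad, project, eta, steps):
    """Projected gradient descent: theta <- project(theta - eta * grad(theta))."""
    theta = project(theta0)
    for _ in range(steps):
        g_t = grad(theta)
        theta = project([ti - eta * gi for ti, gi in zip(theta, g_t)])
    return theta


# Hypothetical linearized setting: loss g . theta, modeled sharpness S0 + s . theta.
g = [-1.0, -2.0]      # g . s = -1 < 0, so plain GD keeps increasing the sharpness
s = [1.0, 0.0]        # stand-in for grad S
S0, eta = 9.0, 0.1    # constraint S0 + s . theta <= 2 / eta = 20  <=>  s . theta <= 11
b = 2 / eta - S0

theta = pgd([0.0, 0.0], lambda th: g, lambda th: project_halfspace(th, s, b), eta, steps=200)
print("PGD iterate:", theta)
print("modeled sharpness:", S0 + dot(s, theta), " loss:", dot(g, theta))
```

Unconstrained gradient descent would push $s\cdot\theta$, and hence the modeled sharpness, up without bound. The projected iterate instead stops at the boundary $S_0 + s\cdot\theta = 2/\eta$, while the loss keeps decreasing through the free coordinate.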
Our main contributions are as follows. First, we explain self-stabilization as a generic property of gradient descent for a large class of loss functions, and provide precise predictions for the loss, sharpness, and deviation from the constrained trajectory throughout training (Section 4). Next, we prove that under mild conditions on the loss function (which we verify empirically for standard architectures and datasets), our predictions track the true gradient descent dynamics up to higher order error terms (Section 5). Finally, we verify our predictions by replicating the experiments in Cohen et al. and show that they model the true gradient descent dynamics (Section 6).
Related Work
Xing et al. observed that for some neural networks trained by full-batch gradient descent, the loss is not monotonically decreasing. Wu et al. remarked that gradient descent cannot converge to minima where the sharpness exceeds $2/\eta$, but did not give a mechanism for avoiding such minima. Lewkowycz et al. observed that when the initial sharpness is larger than $2/\eta$, gradient descent "catapults" into a stable region and eventually converges. Jastrzębski et al. studied the sharpness along stochastic gradient descent trajectories and observed an initial increase (i.e., progressive sharpening) followed by a peak and eventual decrease. They also observed interesting relationships between the dynamics in the top eigenvector direction and the sharpness. Jastrzębski et al. conjectured a general characterization of stochastic gradient descent dynamics asserting that the sharpness tends to grow but cannot exceed a stability criterion given by their eq (1), which reduces to $S(\theta) \le 2/\eta$ in the case of full-batch training. Cohen et al. demonstrated that for the special case of (full-batch) gradient descent training, the optimization dynamics exhibit a simple characterization. First, the sharpness rises until it reaches $2/\eta$. At that point, the dynamics transition into an “edge of stability” (EOS) regime where the sharpness oscillates around $2/\eta$ and the loss continues to decrease, albeit non-monotonically.
Recent works have sought to provide theoretical analyses for the EOS phenomenon. Ma et al. analyzes EOS when the loss satisfies a "subquadratic growth" assumption. Ahn et al. argues that unstable convergence is possible when there exists a "forward invariant subset" near the set of minimizers. Arora et al. analyzes progressive sharpening and the EOS phenomenon for normalized gradient descent close to the manifold of global minimizers. Lyu et al. uses the EOS phenomenon to analyze the effect of normalization layers on sharpness for scale-invariant loss functions. Chen and Bruna show global convergence despite instability for certain 2D toy problems and in a 1-neuron student-teacher setting. The concurrent work Li et al. proves progressive sharpening for a two-layer network and analyzes the EOS dynamics through four stages similar to ours using the norm of the output layer as a proxy for sharpness.
Beyond the EOS phenomenon itself, prior work has also shown that SGD with a large step size or a small batch size leads to a decrease in sharpness. Gilmer et al. also describes connections between EOS, learning rate warm-up, and gradient clipping.
At a high level, our proof relies on the idea that oscillations in an unstable direction, prescribed by the quadratic approximation of the loss, cause a longer-term effect arising from the third-order Taylor expansion of the dynamics. This overall idea has also been used to analyze the implicit regularization of SGD. In those settings, the oscillations come from stochasticity, while in our setting they stem from instability.
Setup
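As in eq. 2, the constrained trajectory is defined by $\theta^\dagger_{t+1} := \mathrm{proj}_{\mathcal{M}}(\theta^\dagger_t - \eta\nabla L(\theta^\dagger_t))$, initialized at $\theta^\dagger_0 := \mathrm{proj}_{\mathcal{M}}(\theta_0)$.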
Our key assumption is the existence of progressive sharpening along the constrained trajectory, which is captured by the progressive sharpening coefficient $\alpha(\theta)$:
$$\alpha(\theta) := -\nabla L(\theta)\cdot\nabla S(\theta).$$
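We define $\alpha_t := \alpha(\theta^\dagger_t)$. Assumption 1 (existence of progressive sharpening) requires $\alpha_t > 0$ along the constrained trajectory.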
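We focus on the regime in which there is a single unstable eigenvalue, and we leave understanding multiple unstable eigenvalues to future work. We thus make the following assumption on the second eigenvalue of the Hessian along the constrained trajectory:
Assumption 2 (Eigengap). For some absolute constant $c < 2$ we have $\lambda_2(\nabla^2 L(\theta^\dagger_t)) \le c/\eta$.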
The Self-stabilization Property of Gradient Descent
In this section, we derive a set of equations that predict the displacement between the gradient descent trajectory $\{\theta_t\}$ and the constrained trajectory $\{\theta^\dagger_t\}$. Viewed as a dynamical system, these equations give rise to a negative feedback loop, which prevents both the sharpness and the displacement in the unstable direction from diverging. These equations also allow us to predict the values of the sharpness and the loss throughout the gradient descent trajectory.
The analysis in this section proceeds by a cubic Taylor expansion around a fixed reference point $\theta^\star$. (Beginning in Section 5, the reference points for our Taylor expansions change at every step to minimize errors. However, fixing the reference point in this section simplifies the analysis, better illustrates the negative feedback loop, and motivates the definition of the constrained trajectory.) For notational simplicity, we define the following quantities at $\theta^\star$:
$$\nabla L := \nabla L(\theta^\star), \qquad \nabla^2 L := \nabla^2 L(\theta^\star), \qquad u := u(\theta^\star), \qquad \nabla S := \nabla S(\theta^\star).$$
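We want to track the movement in the unstable direction $u$ and in the direction of changing sharpness $\nabla S$, and thus define
$$x_t := u\cdot(\theta_t - \theta^\star), \qquad y_t := \nabla S\cdot(\theta_t - \theta^\star).$$
Note that $y_t$ is approximately equal to the change in sharpness from $\theta^\star$ to $\theta_t$, since Taylor expanding the sharpness yields
$$S(\theta_t) \approx S(\theta^\star) + \nabla S\cdot(\theta_t - \theta^\star) = S(\theta^\star) + y_t,$$
which equals $2/\eta + y_t$ for a reference point with $S(\theta^\star) = 2/\eta$.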
The mechanism for edge of stability can be described in 4 stages (see Figure 2):
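Stage 1: Progressive Sharpening
While $x_t, y_t$ are small, $\nabla L(\theta_t) \approx \nabla L$. In addition, because $\nabla L\cdot\nabla S < 0$, gradient descent naturally increases the sharpness at every step. In particular,
$$y_{t+1} - y_t = -\eta\,\nabla S\cdot\nabla L(\theta_t) \approx -\eta\,\nabla S\cdot\nabla L = \eta\alpha, \qquad \text{where } \alpha = -\nabla L\cdot\nabla S.$$
The sharpness therefore increases linearly, by $\eta\alpha$ per step.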
Stage 2: Blowup
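As $x_t$ measures the deviation from $\theta^\star$ in the $u$ direction, the dynamics of $x_t$ can be modeled by gradient descent on a quadratic with sharpness $S(\theta_t) \approx 2/\eta + y_t$. In particular, the rule for gradient descent on a quadratic gives
$$x_{t+1} = x_t - \eta S(\theta_t)\,x_t \approx -(1 + \eta y_t)\,x_t.$$
(A rigorous derivation of this update in terms of $y_t$ instead of $S(\theta_t)$ requires a third-order Taylor expansion around $\theta^\star$; see Appendix I for more details.)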
When the sharpness exceeds $2/\eta$, i.e., when $y_t > 0$, $|x_t|$ begins to grow exponentially, indicating divergence in the top eigenvector direction $u$.
Stage 3: Self-Stabilization
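Once the movement in the $u$ direction is sufficiently large, the loss is no longer locally quadratic, and understanding the dynamics necessitates a third-order Taylor expansion. The missing cubic term in the Taylor expansion of $\nabla L(\theta_t)$ is
$$\tfrac{1}{2}\nabla^3 L(u,u)\,x_t^2 = \tfrac{x_t^2}{2}\nabla S$$
by Lemma 2. This biases the optimization trajectory in the $-\nabla S$ direction, which decreases the sharpness. Recalling $\alpha = -\nabla L\cdot\nabla S$ and writing $\beta := \|\nabla S\|^2$, the new update for $y_t$ becomes:
$$y_{t+1} - y_t \approx -\eta\,\nabla S\cdot\Big(\nabla L + \tfrac{x_t^2}{2}\nabla S\Big) = \eta\Big(\alpha - \tfrac{\beta}{2}x_t^2\Big).$$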
Therefore, once $x_t^2 > 2\alpha/\beta$, the sharpness begins to decrease. It continues to do so until it goes below $2/\eta$ and the dynamics return to stability.
Stage 4: Return to Stability
At this point $|x_t|$ is still large from Stages 1 and 2. However, the self-stabilization of Stage 3 eventually drives the sharpness below $2/\eta$, so that $y_t < 0$. The rule for gradient descent on a quadratic with sharpness $S(\theta_t) \approx 2/\eta + y_t < 2/\eta$ is still
$$x_{t+1} \approx -(1 + \eta y_t)\,x_t.$$
Consequently, $|x_t|$ begins to shrink exponentially and the process returns to Stage 1.
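Combining the updates for $(x_t, y_t)$ in all four stages, we obtain the following simplified dynamics:
$$x_{t+1} = -(1 + \eta y_t)\,x_t, \qquad y_{t+1} = y_t + \eta\Big(\alpha - \frac{\beta}{2}x_t^2\Big), \qquad (3)$$
where we recall that $\alpha = -\nabla L\cdot\nabla S$ is the progressive sharpening coefficient and $\beta = \|\nabla S\|^2$.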
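The simplified dynamics $x_{t+1} = -(1+\eta y_t)x_t$, $y_{t+1} = y_t + \eta(\alpha - \frac{\beta}{2}x_t^2)$ are easy to simulate directly. The sketch below is our own illustration, with hypothetical coefficients ($\alpha = \beta = 1$, $\eta = 0.01$) and a small initial displacement. It reproduces the four stages in turn:

- $y_t$ (the sharpness in excess of $2/\eta$) rises linearly.
- $|x_t|$ blows up.
- The $-\frac{\beta}{2}x_t^2$ term pulls $y_t$ negative.
- $|x_t|$ shrinks again.

Throughout, $x_t$ flips sign every step.

```python
import math


def simplified_dynamics(alpha, beta, eta, x0, y0, steps):
    """Iterate eq. 3 with a simultaneous update of (x, y).

    x: displacement along the top eigenvector u
    y: displacement along grad S (approximately the sharpness minus 2/eta)
    """
    xs, ys = [x0], [y0]
    x, y = x0, y0
    for _ in range(steps):
        # the right-hand side is evaluated with the old (x, y) before assignment
        x, y = -(1 + eta * y) * x, y + eta * (alpha - 0.5 * beta * x ** 2)
        xs.append(x)
        ys.append(y)
    return xs, ys


alpha, beta, eta = 1.0, 1.0, 0.01          # hypothetical coefficients
delta = math.sqrt(2 * alpha / beta)        # fixed-point amplitude of |x|
xs, ys = simplified_dynamics(alpha, beta, eta, x0=0.01, y0=0.0, steps=1000)

print(f"delta = {delta:.3f}")
print(f"max |x_t| = {max(abs(x) for x in xs):.3f}")
print(f"y_t range = [{min(ys):.3f}, {max(ys):.3f}]")
print("sign of x_t flips every step:", all(a * b < 0 for a, b in zip(xs, xs[1:])))
```

Starting from a tiny $x_0$, $|x_t|$ overshoots $\delta = \sqrt{2\alpha/\beta}$ during the blowup, after which $y_t$ is driven negative. Both quantities stay bounded, as the analysis of this system predicts.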
4.2 Analyzing the simplified dynamics
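We now analyze the dynamics in eq. 3. First, note that $x_t$ changes sign at every iteration, and that $x_{t+1} \approx -x_t$ due to the instability in the $u$ direction. Because of these rapid oscillations, eq. 3 cannot be directly modeled by an ODE. We can instead model $|x_t|$, whose update $|x_{t+1}| = (1 + \eta y_t)|x_t|$ is controlled by $y_t$. As a consequence, we can couple the dynamics of $(|x_t|, y_t)$ to the following ODE $(X(t), Y(t))$:
$$X'(t) = \eta\,X(t)\,Y(t), \qquad Y'(t) = \eta\Big(\alpha - \frac{\beta}{2}X(t)^2\Big). \qquad (4)$$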
This system has the unique fixed point $(X, Y) = (\delta, 0)$, where $\delta := \sqrt{2\alpha/\beta}$. We also note that, after a change of variables, this ODE can be written as a Lotka-Volterra predator-prey model, which is a classical example of a negative feedback loop. In particular, the following quantity is conserved:
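Let $h(z) := z - \log z - 1$. Then the quantity
$$g(X, Y) := h\Big(\frac{\beta X^2}{2\alpha}\Big) + \frac{Y^2}{\alpha}$$
is constant along every solution $(X(t), Y(t))$ of the ODE.
As a result, we can use the conservation of $g$ to explicitly bound the size of the trajectory:
For all $t \ge 0$, $X(t) \le \delta\sqrt{z_{\max}}$ and $|Y(t)| \le \sqrt{\alpha\, g(X(0), Y(0))}$, where $\delta = \sqrt{2\alpha/\beta}$ and $z_{\max} \ge 1$ is the largest solution of $h(z) = g(X(0), Y(0))$.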
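The conservation law and the resulting bounds can be checked numerically. The sketch below is our own illustration; the coefficients and helper names are hypothetical. It integrates the ODE $X' = \eta XY$, $Y' = \eta(\alpha - \frac{\beta}{2}X^2)$ with the classical fourth-order Runge-Kutta method and monitors $g(X, Y) = h(\beta X^2/(2\alpha)) + Y^2/\alpha$. It also compares the trajectory with the bounds implied by $h(\beta X^2/(2\alpha)) \le g(X(0), Y(0))$ and $Y^2/\alpha \le g(X(0), Y(0))$. Both inequalities follow from $h \ge 0$.

```python
import math


def h(z):
    """h(z) = z - log z - 1 (nonnegative, zero only at z = 1)."""
    return z - math.log(z) - 1.0


def g(X, Y, alpha, beta):
    """Conserved quantity: h(beta X^2 / (2 alpha)) + Y^2 / alpha."""
    return h(beta * X ** 2 / (2 * alpha)) + Y ** 2 / alpha


def rhs(X, Y, alpha, beta, eta):
    """Right-hand side of the ODE X' = eta X Y, Y' = eta (alpha - beta X^2 / 2)."""
    return eta * X * Y, eta * (alpha - 0.5 * beta * X ** 2)


def rk4_trajectory(X, Y, alpha, beta, eta, dt, n_steps):
    """Integrate the ODE with classical RK4 and return the list of (X, Y) states."""
    traj = [(X, Y)]
    for _ in range(n_steps):
        k1 = rhs(X, Y, alpha, beta, eta)
        k2 = rhs(X + 0.5 * dt * k1[0], Y + 0.5 * dt * k1[1], alpha, beta, eta)
        k3 = rhs(X + 0.5 * dt * k2[0], Y + 0.5 * dt * k2[1], alpha, beta, eta)
        k4 = rhs(X + dt * k3[0], Y + dt * k3[1], alpha, beta, eta)
        X += dt / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
        Y += dt / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        traj.append((X, Y))
    return traj


def largest_root_of_h(level):
    """Largest z >= 1 with h(z) = level, by bisection (h is increasing on [1, inf))."""
    lo, hi = 1.0, 2.0
    while h(hi) < level:
        hi *= 2.0
    for _ in range(200):
        mid = 0.5 * (lo + hi)
        if h(mid) < level:
            lo = mid
        else:
            hi = mid
    return hi


alpha, beta, eta = 1.0, 1.0, 0.01   # hypothetical coefficients
X0, Y0 = 0.5, 0.0
traj = rk4_trajectory(X0, Y0, alpha, beta, eta, dt=0.5, n_steps=4000)

g0 = g(X0, Y0, alpha, beta)
drift = max(abs(g(X, Y, alpha, beta) - g0) for X, Y in traj)
delta = math.sqrt(2 * alpha / beta)
X_bound = delta * math.sqrt(largest_root_of_h(g0))
Y_bound = math.sqrt(alpha * g0)

print(f"g0 = {g0:.6f}, max drift of g = {drift:.2e}")
print(f"max X   = {max(X for X, _ in traj):.4f}  <=  bound {X_bound:.4f}")
print(f"max |Y| = {max(abs(Y) for _, Y in traj):.4f}  <=  bound {Y_bound:.4f}")
```

Because the orbit is closed, the trajectory actually attains both bounds, up to the sampling of the time grid. $X$ reaches its bound where $Y = 0$, and $|Y|$ reaches its bound where $X = \delta$.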
The Predicted Dynamics and Theoretical Results
We now present the equations governing edge of stability for general loss functions.
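Our general approach Taylor expands the gradient of each iterate around the corresponding iterate of the constrained trajectory. We define the following Taylor expansion quantities at $\theta^\dagger_t$:
$$\nabla L_t := \nabla L(\theta^\dagger_t), \qquad \nabla^2 L_t := \nabla^2 L(\theta^\dagger_t), \qquad u_t := u(\theta^\dagger_t), \qquad \nabla S_t := \nabla S(\theta^\dagger_t).$$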
Furthermore, for any vector-valued function $v(\theta)$, we define $v^\perp(\theta) := P^\perp_{u(\theta)}v(\theta)$, where $P^\perp_{u(\theta)}$ is the projection onto the orthogonal complement of $u(\theta)$.
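We also define the following quantities, which govern the dynamics near $\theta^\dagger_t$:
$$\alpha_t := -\nabla L_t\cdot\nabla S_t, \qquad \beta_t := \|\nabla S^\perp_t\|^2, \qquad \delta_t := \sqrt{2\alpha_t/\beta_t},$$
together with the coefficients $\beta_{s\to t}$ for $s < t$, interpreted below.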
Recall that $\alpha_t$ is the progressive sharpening force and $\beta_t$ is the strength of the stabilization force. $\delta_t$ controls the size of the deviations from $\theta^\dagger_t$ and was the fixed point in the $|x|$ direction in Section 4.2. In addition, $\beta_{s\to t}$ admits a simple interpretation: it measures the change in sharpness between times $s$ and $t$ after a displacement of $\nabla S^\perp_s$ at time $s$. In Section 4, $\nabla S$ is always perpendicular to the Hessian. Here, by contrast, the Hessian causes this displacement to change at every step: at time $k$, gradient descent multiplies the displacement by $I - \eta\nabla^2 L_k$ (restricted to the orthogonal complement of $u_k$). The displacement after $t - s$ steps is therefore the ordered product of these factors applied to $\nabla S^\perp_s$.
The change in sharpness after a displacement $v$ at time $t$ is approximately $\nabla S_t\cdot v$. Hence the change in sharpness between times $s$ and $t$ is approximately captured by $\beta_{s\to t}$.
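Suppose $\nabla S^\perp$ is constant and orthogonal to the Hessian, as in Section 4. Then this change is $\beta = \|\nabla S^\perp\|^2$, because the change in sharpness is approximately $\nabla S\cdot v$ and the Hessian leaves the displacement $v = \nabla S^\perp$ unchanged. Therefore, in the setting of Section 4, $\beta_{s\to t} = \beta$ is constant.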
5.2 The equations governing edge of stability
We now introduce the equations governing edge of stability. We track the following quantities:
Define $v_t := \theta_t - \theta^\dagger_t$, $x_t := u_t\cdot v_t$, and $y_t := \nabla S^\perp_t\cdot v_t$.
Our predicted dynamics directly predict the displacement $v_t$, and their full definition is deferred to Appendix C. However, they have a relatively simple form in the $u_t$ and $\nabla S^\perp_t$ directions. This form depends on the remaining directions only through the scalar quantities $\alpha_t$ and $\beta_{s\to t}$.
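Let $\hat v_t$ denote our predicted dynamics (defined in Appendix C). Letting $\hat x_t := u_t\cdot\hat v_t$ and $\hat y_t := \nabla S^\perp_t\cdot\hat v_t$, we have
$$\hat x_{t+1} = -(1 + \eta\hat y_t)\,\hat x_t, \qquad \hat y_{t+1} = \eta\sum_{s=0}^{t}\beta_{s\to t+1}\,\frac{\delta_s^2 - \hat x_s^2}{2}. \qquad (5)$$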
As we will see in Theorem 1, the $u_t$ and $\nabla S^\perp_t$ directions alone suffice to determine the loss and sharpness values. In fact, the EOS dynamics can be fully captured by this 2D dynamical system with time-dependent coefficients.
Note that when $\alpha_t$ and $\beta_{s\to t}$ are constant, our update reduces to the simple case discussed in Section 4, which we analyze fully. When $|\hat x_t|$ is large, eq. 5 demonstrates that there is a self-stabilization force which acts to decrease $\hat y_t$. Unlike in Section 4, however, the strength of this force changes with $t$.
5.3 Coupling Theorem
We now show that, under a mild set of assumptions which we verify to hold empirically in Appendix E, the true dynamics are accurately governed by the predicted dynamics. This lets us use the predicted dynamics to predict the loss, the sharpness, and the distance to the constrained trajectory $\theta^\dagger_t$.
Our errors depend on the unitless quantity , which we verify is small in Appendix E.
Let and .
To control Taylor expansion errors, we require upper bounds on $\nabla^3 L$ and its Lipschitz constant. (For simplicity of exposition, we impose these bounds on $\nabla^3 L$ globally; however, our proof only requires them in a small neighborhood of the constrained trajectory $\theta^\dagger$.)
Let , to be the minimum constants such that for all , and is -Lipschitz with respect to . Then we assume that .
Next, we require the following generalization of Assumption 1:
Finally, we require a set of “non-worst-case” assumptions, which are that the quantities and are nicely behaved in the directions orthogonal to $u_t$; these generalize the eigengap assumption (Assumption 2). We verify the assumptions on and empirically in Appendix E.
With these assumptions in place, we can state our main theorem, which guarantees that the predicted dynamics predict the loss, sharpness, and deviation from the constrained trajectory up to higher order terms:
Let and assume that Then for any , we have
Theorem 1 says that, up to higher order terms, the predicted dynamics capture the deviation from the constrained trajectory and allow us to predict the sharpness and loss of the current iterate. In particular, the GD trajectory should be thought of as following the constrained PGD trajectory plus a rapidly oscillating process whose dynamics are given by the predicted dynamics eq. 6.
The sharpness is controlled by the slowly evolving quantity $\hat y_t$ and the period-2 oscillations of $\hat x_t$. This combination of gradual and rapid periodic behavior was observed by Cohen et al. and appears in our experiments. Theorem 1 also shows that the loss at $\theta_t$ spikes whenever $|\hat x_t|$ is large. On the other hand, when $|\hat x_t|$ is small, $L(\theta_t)$ approaches the loss of the constrained trajectory.
Experiments
We verify that the predicted dynamics defined in eq. 5 accurately capture the dynamics of gradient descent at the edge of stability. To do so, we replicate the experiments in Cohen et al. and track the deviation of gradient descent from the constrained trajectory. In Figure 4, we evaluate our theory on a 3-layer MLP and a 3-layer CNN trained with mean squared error (MSE) on a 5k subset of CIFAR10, and on a 2-layer Transformer trained with MSE on SST2 (Socher et al.). We provide additional experiments varying the learning rate and loss function in Appendix G, which use the generalized predicted dynamics described in Section 7.2. For additional details, see Appendix D.
Figure 4 confirms that the predicted dynamics eq. 5 accurately predict the loss, sharpness, and distance from the constrained trajectory. In addition, the gradient flow trajectory diverges from the gradient descent trajectory at a linear rate, whereas the gradient descent trajectory and the constrained trajectory remain close throughout training. In particular, the $(|\hat x_t|, \hat y_t)$ dynamics converge to the fixed point $(\delta_t, 0)$ described in Section 4.2. This confirms our claim that gradient descent implicitly follows the constrained trajectory eq. 2.
In Section 5, various assumptions on the model were made to obtain the edge of stability behavior. In Appendix E, we numerically verify these assumptions to ensure the validity of our theory.
Discussion
Recall that the predicted dynamics describe the deviation of the GD trajectory $\{\theta_t\}$ from the PGD constrained trajectory $\{\theta^\dagger_t\}$. These dynamics enable many interesting observations about the EOS dynamics. First, the loss and sharpness depend only on the quantities $(\hat x_t, \hat y_t)$, which are governed by the 2D dynamical system with time-dependent coefficients eq. 5. When $\alpha_t$ and $\beta_{s\to t}$ are constant, we showed that this system cycles and has a conserved potential. In general, understanding the edge of stability dynamics only requires analyzing the 2D system eq. 5, which is generally well behaved (Figure 4).
In the limit, we expect $(|\hat x_t|, \hat y_t)$ to approach $(\delta_t, 0)$, the fixed point of the system eq. 5. In fact, Figure 4 shows that after a few cycles, $(|\hat x_t|, \hat y_t)$ indeed converges to this fixed point. We are able to accurately predict its location, as well as the loss increase from the constrained trajectory due to $\hat x_t$.
7.2 Generalized Predicted Dynamics
In order for our cubic Taylor expansions to track the true gradients, we need a bound on the fourth derivative of the loss (Assumption 3). This is usually sufficient to capture the dynamics at the edge of stability, as demonstrated by Figure 4 and Appendix E. However, this condition was violated in some of our experiments, especially when using logistic loss. To overcome this challenge, we developed a generalized form of the predicted dynamics whose definition we defer to Appendix F. These generalized predictions are qualitatively similar to those given by the predicted dynamics in Section 5; however, they precisely track the dynamics of gradient descent in a wider range of settings. See Appendix G for empirical verification of the generalized predicted dynamics.
7.3 Implications for Neural Network Training
A central phenomenon at edge of stability is that despite non-monotonic fluctuations of the loss, the loss still decreases over long time scales. Our theory provides a clear explanation for this decrease. We show that the gradient descent trajectory remains close to the constrained trajectory (Sections 4 and 5). Since the constrained trajectory is stable, it satisfies a descent lemma (Lemma 14) and has monotonically decreasing loss. Over short time periods, the loss is dominated by the rapid fluctuations of $x_t$ described in Section 4. Over longer time periods, the loss decrease of the constrained trajectory due to the descent lemma overpowers the bounded fluctuations of $x_t$, leading to an overall loss decrease. This is reflected in our experiments in Section 6.
Generalization & the Role of Large Learning Rates
Prior work has shown that in neural networks, both decreasing the sharpness of the learned solution and increasing the learning rate are correlated with better generalization. Our analysis shows that gradient descent implicitly constrains the sharpness to stay near $2/\eta$. This suggests that larger learning rates may improve generalization by reducing the sharpness. In Figure 6 we confirm that in a standard setting, full-batch gradient descent generalizes better with large learning rates.
Training Speed
Additional experiments in [8, Appendix F] show that, despite the instability in the training process, larger learning rates lead to faster convergence. Our analysis explains this phenomenon. Gradient descent is coupled to the constrained trajectory, which minimizes the loss while constraining movement in the $u_t$ and $\nabla S_t$ directions. Since only these two directions are “off limits,” the constrained trajectory can still move quickly in the orthogonal directions, using the large learning rate to accelerate convergence. We demonstrate this empirically in Figure 6.
Connection to Sharpness Aware Minimization (SAM)
Foret et al. introduced the sharpness-aware minimization (SAM) algorithm, which aims to control sharpness by solving the optimization problem $\min_\theta \max_{\|\epsilon\|\le\rho} L(\theta + \epsilon)$. This is roughly equivalent to minimizing $S(\theta)$ over all global minimizers, and thus SAM tries to explicitly minimize the sharpness. Our analysis shows that gradient descent implicitly minimizes the sharpness: for a fixed $\eta$, it seeks to minimize $L(\theta)$ subject to $S(\theta) \le 2/\eta$.
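For an exactly quadratic loss around a global minimizer $\theta^\star$, the inner maximization of SAM has a closed form: $\max_{\|\epsilon\|\le\rho} L(\theta^\star + \epsilon) = L(\theta^\star) + \frac{\rho^2}{2}S(\theta^\star)$, attained along the top eigenvector. For general losses this holds up to higher-order terms in $\rho$, which is the sense in which SAM roughly minimizes the sharpness over global minimizers. The sketch below is our own illustration, and the Hessian and radius are hypothetical. It checks the identity in two dimensions by scanning the circle $\|\epsilon\| = \rho$.

```python
import math


def quad_loss(eps, H):
    """L(theta* + eps) - L(theta*) = 0.5 * eps^T H eps for a symmetric 2x2 H."""
    (a, b), (_, d) = H
    x, y = eps
    return 0.5 * (a * x * x + 2 * b * x * y + d * y * y)


def sam_inner_max(H, rho, n_angles=20000):
    """Approximate max_{||eps|| <= rho} of the quadratic by scanning ||eps|| = rho.

    H is positive semidefinite, so the maximum over the ball lies on its boundary.
    """
    return max(
        quad_loss((rho * math.cos(t), rho * math.sin(t)), H)
        for t in (2 * math.pi * k / n_angles for k in range(n_angles))
    )


H = ((3.0, 1.0), (1.0, 2.0))   # hypothetical Hessian at a global minimizer
rho = 0.05                      # hypothetical SAM radius
(a, b), (_, d) = H
sharpness = 0.5 * (a + d + math.sqrt((a - d) ** 2 + 4 * b ** 2))  # top eigenvalue

sam_value = sam_inner_max(H, rho)
predicted = 0.5 * rho ** 2 * sharpness
print(f"scanned SAM inner max: {sam_value:.10f}")
print(f"rho^2 * S / 2:         {predicted:.10f}")
```

The scanned maximum matches $\rho^2 S/2$ up to the angular resolution. Minimizing the SAM objective over minimizers therefore amounts to minimizing the sharpness, whereas gradient descent only caps the sharpness at $2/\eta$.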
Connection to Warmup
Gilmer et al. demonstrated that learning rate warmup, which consists of gradually increasing the learning rate, empirically makes it possible to train with a larger learning rate. The self-stabilization property of gradient descent provides a plausible explanation for this phenomenon. If too large an initial learning rate is chosen (so that $S(\theta)$ is much greater than $2/\eta$), then the iterates may diverge before self-stabilization can decrease the sharpness to $2/\eta$. On the other hand, if the learning rate is chosen so that $S(\theta)$ is only slightly greater than $2/\eta$, self-stabilization will decrease the sharpness to $2/\eta$. Repeatedly increasing the learning rate slightly could then lead to small decreases in sharpness without the iterates diverging, thus allowing training to proceed with a large learning rate.
Connection to Weight Decay and Sharpness Reduction.
Lyu et al. proved that when the loss function is scale-invariant, gradient descent with weight decay and a sufficiently small learning rate leads to a reduction of the normalized sharpness $S(\theta/\|\theta\|)$. In fact, the mechanism behind this sharpness reduction is exactly the self-stabilization force described in this paper, restricted to the setting in Lyu et al. In Appendix H we present a heuristic derivation of this equivalence.
Future Work
Our main result Theorem 1 gives sufficient and verifiable conditions under which the GD trajectory and the PGD constrained trajectory can be coupled for steps. However, these conditions are not strictly necessary and our local coupling result is not sufficient to prove global convergence to a stationary point of eq. 2. These suggest two important directions for future work: Can we precisely characterize the prerequisites on the loss function and learning rate for self-stabilization to occur? Can we couple for longer periods of time or repeat this local coupling result to prove convergence to KKT points of eq. 2?
2 Multiple Unstable Eigenvalues
Our work focuses on explaining edge of stability in the presence of a single unstable eigenvalue (Assumption 2). However, Cohen et al. observed that progressive sharpening appears to apply to all eigenvalues, even after the largest eigenvalue has become unstable. As a result, the top eigenvalues successively enter edge of stability (see Figure 5). In particular, Figure 5 shows that the dynamics are fairly well behaved while only a single eigenvalue is unstable, yet appear to be significantly more chaotic when multiple eigenvalues are unstable.
3 The Mystery of Progressive Sharpening
Our analysis directly assumed the existence of progressive sharpening (Assumption 1) and focused on explaining the edge of stability dynamics under this assumption. However, this leaves open the question of why neural networks exhibit progressive sharpening, which is an important question for future work. Partial progress towards understanding the mechanism behind progressive sharpening in neural networks has been made in the concurrent works Li et al., Zhu et al., and Agarwala et al.
4 Connections to Stochasticity
Our analysis focused on understanding the edge of stability dynamics for gradient descent. However, phenomena similar to the edge of stability have also been observed for SGD. These phenomena do not admit as simple a characterization as the full-batch gradient descent dynamics. Even so, understanding the optimization dynamics of neural networks used in practice requires understanding the connections between edge of stability and SGD. Important questions include what the correct notion of “stability” is for SGD and what form the self-stabilization force takes.
One possible hypothesis is that self-stabilization occurs in SGD as a byproduct of the implicit regularization of SGD described in Blanc et al., Damian et al., and Li et al. These works show that the stochasticity in SGD has the effect of decreasing a quantity related to the trace of the Hessian, rather than directly constraining the operator norm of the Hessian (i.e., the sharpness).
Furthermore, Damian et al. showed that the implicit bias of label noise SGD is proportional to a regularizer that blows up as $S(\theta) \to 2/\eta$. This regularizer therefore heuristically acts as a log-barrier that enforces $S(\theta) < 2/\eta$, rather than as a hard constraint. The precise “break-even” point could then be approximated by the point at which this regularization force balances progressive sharpening. An interesting direction is to better understand the interactions between the self-stabilization mechanism described in this paper and the implicit regularization effects of SGD described in Blanc et al., Damian et al., and Li et al.
Acknowledgements
AD acknowledges support from a NSF Graduate Research Fellowship. EN acknowledges support from a National Defense Science & Engineering Graduate Fellowship, and NSF grants CIF-1907661 and DMS-2014279. JDL, AD, EN acknowledge support of the Sloan Research Fellowship, NSF CCF 2002272, NSF IIS 2107304, NSF CIF 2212262, and NSF-CAREER under award #2144994.
The authors would like to thank Jeremy Cohen, Kaifeng Lyu, and Lei Chen for helpful discussions throughout the course of this project. We would especially like to thank Jeremy Cohen for suggesting the term “self-stabilization” to describe the negative feedback loop derived in this paper.
Appendix A Notation
For matrices , we define
Appendix B A Toy Model for Self-Stabilization
For , consider the function
initialized at the point . Note that the constrained trajectory will follow , , as it cannot decrease without increasing the sharpness past $2/\eta$. We therefore have:
Note that this satisfies all of the assumptions in Section 4 and it satisfies and . This process will then follow eq. 4 in the directions while it tracks the constrained trajectory moving linearly in the direction.
Appendix C Definition of the Predicted Dynamics
Below, we present the full definition of the predicted dynamics:
Define , and let . Then
Given a vector and a timestep , define by
where and .
It is easy to see that .
Defining , we can unfold the recursion in eq. 6 to obtain the following formula for .
It is then immediate to see that have the following simple update:
Appendix D Experimental Details
We evaluated our theory on four different architectures. The 3-layer MLP and CNN are exact copies of the MLP and CNN used in Cohen et al. The MLP has width , the CNN has width , and both use the swish activation. We also evaluate on a ResNet18 with progressive widths and on a 2-layer Transformer with hidden dimension and two attention heads.
D.2 Data
We evaluated our theory on three primary tasks: CIFAR10 multi-class classification with both categorical MSE loss and cross-entropy loss, CIFAR10 binary classification (cats vs dogs) with binary MSE loss and logistic loss, and SST2 with binary MSE loss and logistic loss.
D.3 Experimental Setup
For every experiment, we tracked the gradient descent dynamics until they reached instability. We then began tracking the constrained trajectory, gradient descent, gradient flow, our predicted dynamics (Section 5), and our generalized predicted dynamics (Appendix F). In addition, we tracked the various quantities on which we made assumptions in Section 5 in order to validate these assumptions. We also tracked the second eigenvalue of the Hessian at the constrained trajectory throughout training and stopped training once it reached , to ensure the existence of a single unstable eigenvalue. Finally, the edge of stability dynamics are very sensitive to small perturbations when $x_t$ is small (see Figure 3). To avoid propagating floating-point errors, we switched to computing gradients in 64-bit precision after first reaching instability.
Our experiments were conducted in JAX, using https://github.com/locuslab/edge-of-stability as a reference for replicating the experimental setup used in Cohen et al. All experiments were conducted on two servers, each with 10 NVIDIA GPUs. Code is available at https://github.com/adamian98/EOS.
Appendix E Empirical Verification of the Assumptions
For each of the experimental settings considered (MLP+MSE, CNN+MSE, CNN+Logistic, ResNet18+MSE, Transformer+MSE, Transformer+Logistic), we plot a number of quantities along the constrained trajectory to verify that the assumptions made in the main text hold. For each learning rate we have 8 plots tracking various quantities, which verify the assumptions as follows:
- Assumption 1 is verified by the 1st plot.
- The smallness of the unitless error quantity from the coupling theorem (Section 5) is verified by the 2nd plot.
- Assumption 4 is verified by the 3rd and 4th plots.
- Assumption 3 is verified by the 5th plot.
- Assumption 5 is verified by the last 3 plots.
As described in the experimental setup, training is stopped once the second eigenvalue reaches the threshold, so Assumption 2 always holds as well.
Appendix F The Generalized Predicted Dynamics
Our analysis relies on a cubic Taylor expansion of the gradient. However, in order for this Taylor expansion to accurately track the gradients, we need a bound on the fourth derivative of the loss (Assumption 3). Section 6 and Appendix E show that this approximation is sufficient to capture the dynamics of gradient descent at the edge of stability for many standard models when the loss criterion is the mean squared error. However, for certain architectures and loss functions, including ResNet18 and models trained with the logistic loss, this condition is often violated.
In these situations, the loss function in the top eigenvector direction is either sub-quadratic or super-quadratic. Sub-quadratic means that the quadratic Taylor expansion overestimates the loss and sharpness; this sub-quadratic phenomenon has also been observed in prior work. Super-quadratic means that the quadratic Taylor expansion underestimates the loss and sharpness. To capture this phenomenon, we derive a more general form of the predicted dynamics. It reduces to the standard predicted dynamics of Section 5 when the loss in the top eigenvector direction is approximately quadratic. In addition, Appendix G shows that the generalized predicted dynamics capture the dynamics of gradient descent at the edge of stability for both mean squared error and cross-entropy in all settings we tested.
To derive the generalized predicted dynamics, we will abstract away the dynamics in the top eigenvector direction. Specifically, for every we define
We say that is sub-quadratic at if and super-quadratic if .
Note that knowing is not sufficient to capture the dynamics in the direction. Specifically,
It is still critically important to track the effect that the movement in the direction has on the dynamics of . As in Section 4.1, the effect of the movement in the direction on the dynamics of is changing the sharpness by . This gives us the generalized predicted dynamics update:
where
When is exactly quadratic, this reduces to the standard predicted dynamics update in eq. 6. The update for is unchanged:
Restricted to the directions, the generalized predicted dynamics imply:
The proof is identical to the proof of Lemma 4.
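Eq. 6 itself is not reproduced in this section. As a reference point, the sketch below iterates a two-variable form of the standard predicted dynamics that we assume here: `x` is the displacement along the top eigenvector, `y` is the displacement along ∇S (so the sharpness is roughly 2/η + y), and α = −∇L·∇S and β = ‖∇S‖² are held constant. Under these assumptions, starting at x = √(2α/β), y = 0 gives an exact period-2 orbit, while for small x and y < 0 the variable y grows by roughly ηα per step.

```python
import math


def predicted_step(x, y, eta, alpha, beta):
    """One step of the assumed two-variable predicted dynamics.

    x: displacement along the top eigenvector u
    y: displacement along grad S, so the sharpness is roughly 2/eta + y
    """
    x_next = -(1.0 + eta * y) * x                     # oscillation in the unstable direction
    y_next = y + eta * (alpha - 0.5 * beta * x * x)   # sharpening vs. self-stabilization
    return x_next, y_next


def simulate(x0, y0, eta, alpha, beta, steps):
    xs, ys = [x0], [y0]
    for _ in range(steps):
        x, y = predicted_step(xs[-1], ys[-1], eta, alpha, beta)
        xs.append(x)
        ys.append(y)
    return xs, ys


eta, alpha, beta = 0.01, 1.0, 2.0        # hypothetical constants
delta = math.sqrt(2.0 * alpha / beta)    # = 1.0 here

# Starting on the orbit: x flips sign every step and y stays at 0.
xs, ys = simulate(delta, 0.0, eta, alpha, beta, 6)
print(xs)  # [1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0]
print(ys)  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Holding α and β constant is a simplification; along a real trajectory both change with the constrained iterate.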
F.2 Properties of the Generalized Predicted Dynamics
By the sign-flipping argument in Appendix I, we can assume that is an even function: its odd part only influences the dynamics through additional period-2 oscillations, and otherwise we can simply replace by its even part. Throughout the remainder of this section we therefore assume that .
Next, note that the fixed point of eq. 11 is still when , regardless of the shape of , because the direction must be stabilized. This contrasts with previous one-dimensional analyses of the edge of stability, in which the fixed point in the top eigenvector direction depends strongly on the shape of , the loss in the direction.
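The even/odd split used here is the standard one: f(x) = f_even(x) + f_odd(x) with f_even(x) = (f(x) + f(−x))/2 and f_odd(x) = (f(x) − f(−x))/2. A minimal sketch, with `f` a hypothetical stand-in for the loss along the top eigenvector:

```python
def even_part(f):
    """Return x -> (f(x) + f(-x)) / 2, the even part of f."""
    return lambda x: 0.5 * (f(x) + f(-x))


def odd_part(f):
    """Return x -> (f(x) - f(-x)) / 2, the odd part of f."""
    return lambda x: 0.5 * (f(x) - f(-x))


f = lambda x: x ** 2 + x ** 3   # hypothetical profile with a cubic (odd) term
fe, fo = even_part(f), odd_part(f)
print(fe(2.0), fo(2.0))  # 4.0 8.0
```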
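For contrast, consider a one-dimensional toy in the spirit of those earlier analyses: gradient descent on an even function f alone, with no second variable tracking the sharpness. A symmetric period-2 orbit ±x* of x ↦ x − ηf′(x) must satisfy ηf′(x*) = 2x*, so the orbit depends on the shape of f and not just on its curvature at 0. The sketch below, which is not a model of eq. 11, solves this equation by bisection for two hypothetical profiles with f″(0) = 3 and η = 1 (so the origin is unstable); for f(x) = 3(√(1 + x²) − 1) the root is exactly √5/2.

```python
import math


def bisect(g, lo, hi, iters=200):
    """Find a root of g in [lo, hi], assuming g(lo) and g(hi) have opposite signs."""
    g_lo = g(lo)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        g_mid = g(mid)
        if (g_mid > 0) == (g_lo > 0):
            lo, g_lo = mid, g_mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def two_cycle_amplitude(fprime, eta, lo=1e-3, hi=10.0):
    """Positive x* with eta * f'(x*) = 2 x*, i.e. x* -> -x* -> x* under GD on an even f."""
    return bisect(lambda x: eta * fprime(x) - 2.0 * x, lo, hi)


def run_gd(fprime, eta, x0, steps):
    """Plain 1-D gradient descent x <- x - eta * f'(x)."""
    x = x0
    for _ in range(steps):
        x = x - eta * fprime(x)
    return x


eta = 1.0
fprime_logcosh = lambda x: 3.0 * math.tanh(x)               # f = 3 log cosh x
fprime_sqrt = lambda x: 3.0 * x / math.sqrt(1.0 + x * x)    # f = 3 (sqrt(1 + x^2) - 1)

x_logcosh = two_cycle_amplitude(fprime_logcosh, eta)
x_sqrt = two_cycle_amplitude(fprime_sqrt, eta)
print(x_logcosh, x_sqrt)  # different amplitudes despite equal f''(0)
```

Running `run_gd` from a small nonzero start settles onto the same ±x* found by bisection for each profile.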
The limiting value of can therefore be read from the update for . If is an orbit of period 2 of eq. 11, then
In addition, the sharpness can no longer be approximated as , because the sharpness now changes along the direction. In particular, it changes by , so that
Therefore, the limiting sharpness of eq. 11 is
When and the loss is exactly quadratic in the direction, this reduces to the fixed-point predictions of Section 4.1.
One interesting phenomenon observed by Cohen et al. is that with cross-entropy loss, the sharpness was never exactly , but usually hovered above it. This is inconsistent with the standard predicted dynamics, whose fixed point has sharpness . The generalized predicted dynamics (eq. 11), however, give a clear explanation.
When the loss is sub-quadratic, e.g. when , we have
so the sharpness will converge to a value below . On the other hand, if the loss is super-quadratic, the sharpness converges to a value above . More generally, whether the sharpness converges to a value above or below depends on the sign of .
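The direction of this effect can be illustrated, loosely, in a one-dimensional toy (again, not eq. 11 itself). At a symmetric period-2 orbit of gradient descent on an even f, f′(x*)/x* = 2/η, so the curvature f″(x*) lies below 2/η exactly when f′(x)/x is decreasing at x*, and above it when f′(x)/x is increasing. The sketch checks this for a hypothetical sub-quadratic profile f(x) = 3 log cosh x and a hypothetical super-quadratic profile f(x) = x²/2 + x⁴/4 with η = 1, for which the orbit is x* = 1 and f″(1) = 4 > 2.

```python
import math


def bisect(g, lo, hi, iters=200):
    """Root of g on [lo, hi], assuming a sign change between the endpoints."""
    g_lo = g(lo)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        g_mid = g(mid)
        if (g_mid > 0) == (g_lo > 0):
            lo, g_lo = mid, g_mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


eta = 1.0
threshold = 2.0 / eta

# Sub-quadratic toy: f = 3 log cosh x, so f' = 3 tanh x and f'' = 3 / cosh(x)^2.
x_sub = bisect(lambda x: eta * 3.0 * math.tanh(x) - 2.0 * x, 1e-3, 10.0)
curv_sub = 3.0 / math.cosh(x_sub) ** 2

# Super-quadratic toy: f = x^2/2 + x^4/4, so f' = x + x^3 and f'' = 1 + 3x^2.
# (This orbit exists but is unstable for 1-D gradient descent, since |1 - eta f''| = 3.)
x_sup = bisect(lambda x: eta * (x + x ** 3) - 2.0 * x, 0.5, 10.0)
curv_sup = 1.0 + 3.0 * x_sup ** 2

print(curv_sub < threshold, curv_sup > threshold)  # True True
```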
In our experiments in Appendix G, we observed both sub-quadratic and super-quadratic loss functions. In particular, the loss was usually sub-quadratic when it first reached instability but gradually became super-quadratic as training progressed at the edge of stability.
Appendix G Additional Experiments
We trained ResNet18 with full-batch gradient descent on the full 50k-example CIFAR-10 training set using various learning rates, as well as the commonly proposed learning rate schedule . We show that despite entering the edge of stability, large learning rates converge much faster. In addition, due to the self-stabilization effect of gradient descent, the final sharpness is bounded by 2/η, which is smaller for larger learning rates and leads to better generalization (see Figure 6).
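The sharpness in experiments like these is the top Hessian eigenvalue, which is typically estimated without forming the Hessian. This section does not say how it was computed. One common approach, sketched below as an assumption, is power iteration on Hessian-vector products, here approximated by central differences of a gradient function `grad` (a hypothetical callable). Power iteration returns the eigenvalue of largest magnitude, so this assumes the top eigenvalue dominates any negative ones. The estimate can then be compared with 2/η.

```python
import numpy as np


def hvp_fd(grad, theta, v, h=1e-5):
    """Approximate the Hessian-vector product H(theta) v by central differences of grad."""
    return (grad(theta + h * v) - grad(theta - h * v)) / (2.0 * h)


def sharpness_power_iteration(grad, theta, iters=300, seed=0):
    """Estimate the top Hessian eigenvalue at theta via power iteration on HVPs."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(theta.shape)
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(iters):
        hv = hvp_fd(grad, theta, v)
        lam = float(v @ hv)              # Rayleigh quotient for the current unit vector
        v = hv / np.linalg.norm(hv)
    return lam


# Toy check on a quadratic loss L(theta) = theta^T A theta / 2, whose Hessian is A.
A = np.diag([3.0, 1.0, 0.5])
grad = lambda theta: A @ theta
theta = np.array([0.2, -0.1, 0.4])

eta = 0.5
sharpness = sharpness_power_iteration(grad, theta)
print(sharpness, sharpness <= 2.0 / eta)
```

For a real network, `grad` would be the full-batch gradient, and an autodiff Hessian-vector product would usually replace the finite difference.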
Appendix H Scale Invariant Losses
Let be a scale invariant loss function, i.e. . Let . Then for any local minimizer of such that ,
Our primary result is that gradient descent solves the constrained problem such that . Let denote the largest eigenvalue of . To prove equivalence to the sharpness reduction, we will need the following lemma from which follows from the scale invariance of the :
Let . Then we have the following equality between minimization problems:
where the last line follows from the scale-invariance of the loss function. In particular if is sufficiently small and the dynamics are initialized near a global minimizer of the loss, this will converge to the solution of the constrained problem:
Let denote an arbitrary parameter and let . Throughout this section, let be a scale invariant function with non-vanishing Hessian.
Differentiating with respect to and using that gives,
Note that from Corollary 2, the top eigenvector of is perpendicular to . In addition,
so this is also the top eigenvector of , i.e. . Finally,
which is parallel to and concludes the proof. ∎
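The identities behind this argument can be checked numerically. Assume the standard definition of scale invariance, L(cθ) = L(θ) for all c > 0. Differentiating in c at c = 1 gives ∇L(θ)·θ = 0. Differentiating in θ gives c∇L(cθ) = ∇L(θ), whose derivative in c at c = 1 is ∇²L(θ)θ = −∇L(θ). The sketch uses a hypothetical scale-invariant loss L(θ) = ½‖Aθ/‖θ‖ − b‖² with random A and b, an analytic gradient, and a finite-difference Hessian.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 4))
b = rng.standard_normal(5)


def loss(theta):
    """Hypothetical scale-invariant loss: depends on theta only through theta / ||theta||."""
    w = theta / np.linalg.norm(theta)
    r = A @ w - b
    return 0.5 * float(r @ r)


def grad(theta):
    """Analytic gradient: chain rule through w = theta / ||theta||."""
    norm = np.linalg.norm(theta)
    w = theta / norm
    g = A.T @ (A @ w - b)                  # gradient of the inner loss at w
    return (g - w * (w @ g)) / norm        # (I - w w^T) g / ||theta||


def hessian_fd(theta, h=1e-5):
    """Symmetrized finite-difference Hessian built from the analytic gradient."""
    n = theta.size
    H = np.empty((n, n))
    for i in range(n):
        e = np.zeros(n)
        e[i] = h
        H[:, i] = (grad(theta + e) - grad(theta - e)) / (2.0 * h)
    return 0.5 * (H + H.T)


theta = rng.standard_normal(4)
print(loss(theta), loss(3.0 * theta))                          # equal up to rounding
print(theta @ grad(theta))                                     # ~ 0: gradient is orthogonal to theta
print(np.abs(hessian_fd(theta) @ theta + grad(theta)).max())   # ~ 0: H theta = -grad
```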
Let denote the largest eigenvalue of . Then by scale invariance, . This implies that
Finally, we have from that so
In particular, .
Note that implies that
by compactness of . In addition, note that so . ∎
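The scaling of the sharpness under scale invariance can be checked the same way. Under the assumed definition L(cθ) = L(θ), differentiating c∇L(cθ) = ∇L(θ) once more in θ gives c²∇²L(cθ) = ∇²L(θ), so the top eigenvalue satisfies S(cθ) = S(θ)/c². The sketch below, again with a hypothetical loss of the form ½‖Aθ/‖θ‖ − b‖², compares finite-difference Hessians at θ and cθ.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 4))
b = rng.standard_normal(5)


def grad(theta):
    """Gradient of the hypothetical loss 0.5 * ||A theta/||theta|| - b||^2."""
    norm = np.linalg.norm(theta)
    w = theta / norm
    g = A.T @ (A @ w - b)
    return (g - w * (w @ g)) / norm


def sharpness(theta, h=1e-5):
    """Largest eigenvalue of a symmetrized finite-difference Hessian."""
    n = theta.size
    H = np.empty((n, n))
    for i in range(n):
        e = np.zeros(n)
        e[i] = h
        H[:, i] = (grad(theta + e) - grad(theta - e)) / (2.0 * h)
    return float(np.linalg.eigvalsh(0.5 * (H + H.T))[-1])


theta = rng.standard_normal(4)
for c in (0.5, 2.0, 3.0):
    print(c, sharpness(c * theta), sharpness(theta) / c**2)   # the two columns should agree
```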
At any second order stationary point of ,
At any second order stationary point of ,
where the last step follows from compactness of and the fact that is non-vanishing. ∎
The lemma is simply a restatement of Corollary 3, Lemma 10, Lemma 9, and Lemma 11. ∎
Appendix I Proofs
We next prove several useful properties of the constrained trajectory. First, we require the following auxiliary lemma, which shows that several quantities are Lipschitz in a neighborhood of the constrained trajectory:
where is the constant in Assumption 2 and denotes the ball of radius centered at .
is -Lipschitz in each set .
is -Lipschitz with respect to .
is -Lipschitz.
is -Lipschitz in each set .
is -Lipschitz in each set .
The Lipschitzness of follows immediately from the bound . Weyl's inequality then implies the desired bound on the Lipschitz constant of the eigenvalues of . Therefore for any , we have for all :
Next, from the formula for the derivative of an eigenvector:
which implies the bound on the Lipschitz constant of restricted to . Finally, because ,
where the second-to-last inequality follows from the bound on restricted to and the last inequality follows from Assumption 3. ∎
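Both perturbation facts used in this proof are standard and easy to check numerically. The first is Weyl's inequality, |λ_i(H + E) − λ_i(H)| ≤ ‖E‖₂ for eigenvalues sorted in the same order. The second is the first-order formula for a simple top eigenvector u₁ of a symmetric H: the derivative of u₁(H + tE) at t = 0 is Σ_{j≠1} (u_jᵀ E u₁)/(λ₁ − λ_j) u_j. The sketch uses a symmetric matrix with a hypothetical, chosen spectrum as a stand-in for a Hessian. It is not the specific bound derived in the proof.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 6

# Symmetric "Hessian" with a known spectrum, so the top eigenvalue (4.0) is simple.
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
H = Q @ np.diag([-2.0, -1.0, 0.0, 1.0, 2.0, 4.0]) @ Q.T
H = 0.5 * (H + H.T)
M = rng.standard_normal((n, n))
E = 0.5 * (M + M.T)                      # symmetric perturbation direction

# Weyl's inequality: sorted eigenvalues move by at most the spectral norm of E.
weyl_gap = np.max(np.abs(np.linalg.eigvalsh(H + E) - np.linalg.eigvalsh(H)))
spec_norm_E = np.linalg.norm(E, 2)
print(weyl_gap <= spec_norm_E)           # True

# First-order perturbation of the top eigenvector (eigh sorts eigenvalues ascending).
vals, vecs = np.linalg.eigh(H)
u = vecs[:, -1]
du_formula = np.zeros(n)
for j in range(n - 1):
    u_j = vecs[:, j]
    du_formula += (u_j @ E @ u) / (vals[-1] - vals[j]) * u_j


def top_eigenvector(M, ref):
    """Top eigenvector of symmetric M, with its sign chosen to agree with ref."""
    v = np.linalg.eigh(M)[1][:, -1]
    return v if v @ ref >= 0 else -v


t = 1e-5
du_fd = (top_eigenvector(H + t * E, u) - top_eigenvector(H - t * E, u)) / (2 * t)
print(np.max(np.abs(du_fd - du_formula)))   # small, limited by finite-difference error
```

The 1/(λ₁ − λ_j) factors show why an eigengap is needed for the eigenvector to be Lipschitz.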
Next, note that we can decompose :
Let . We can now perform the change of variables
for sufficiently small so because we have
so . Therefore,
Then Taylor expanding around gives
so . We can also Taylor expand around and use that to get
Now note that for sufficiently small we have
Therefore, if , we would have , which contradicts . Hence , and therefore , which by complementary slackness implies . This then implies that
where the last line follows from . ∎
Taylor expanding around and using Lemma 13 gives
Next, note that because we have . Therefore, for sufficiently small,
Let . Then there exists such that
Inductively applying Lemma 14 we have that there exists an absolute constant such that
I.2 Proof of Theorem 1
We first require the following three lemmas, whose proofs are deferred to Section I.3.
Assume that . Then
Assume that there exist constants such that for all , , . Then, for all , we have
For , .
With these lemmas in hand, we can prove Theorem 1.
First, by Lemma 17, we have .
Next, we Taylor expand to calculate :
I.3 Proof of Auxiliary Lemmas
Taylor expanding the update for about , we get
Additionally, recall that the update for is
Subtracting the previous two equations and expanding via the non-worst-case bounds, we obtain
We first compute the magnitude of .
Observe that by definition of and , and since
Therefore, the one-step update in the direction is:
where we have defined the error term as
The update in the direction is
Therefore we can control the first of the error terms as
As for the second error term, we can decompose
By Assumption 5, we have . Additionally, . Therefore
where we used . Altogether, we have
We next compute the two-step update for :
We previously obtained . Furthermore,
The first of these two error terms can be bounded as
Altogether, the two-step update for is
Additionally, the two-step update for is
Define , and for . By the definition of , one sees that
The update for after applying is
Additionally, the update for is
and define the auxiliary trajectory by and . We first claim that for all even , which we prove by induction on . The base case holds by assumption, so assume the result for some even . Then,
Next, we will prove by induction that for ,
By definition, , so the claim is clearly true for . Next, assume the claim holds for . If is even then ; otherwise , and thus
Let . Then
so for we have . Therefore
so the induction is proven. Altogether, we get for all such , as desired. ∎
Since , we have that , and thus
Since , grows by at most a constant factor, and thus . Finally, recall that
Therefore . ∎