Small-scale proxies for large-scale Transformer training instabilities

Mitchell Wortsman, Peter J. Liu, Lechao Xiao, Katie Everett, Alex Alemi, Ben Adlam, John D. Co-Reyes, Izzeddin Gur, Abhishek Kumar, Roman Novak, Jeffrey Pennington, Jascha Sohl-dickstein, Kelvin Xu, Jaehoon Lee, Justin Gilmer, Simon Kornblith

1 Introduction

Scaling up Transformers has led to remarkable progress, from chat models to image generation. However, not every training run is successful. When training large Transformers, researchers have reported instabilities which slow or destabilize learning. As the resources required for large runs continue to grow, it is important to examine the ways that Transformer training can fail.

In this report we reproduce, study, and predict training instability in Transformer models. We find that measuring the relationship between learning rate and loss across scales is a useful tool to identify instability (e.g., Figure 1). Therefore, we introduce learning rate (LR) sensitivity, which serves as a useful summary statistic for learning rate vs. loss curves. LR sensitivity measures the deviation from optimal performance when varying LR across orders of magnitude.

We show that two sources of instability, which have previously been described at scale, can be reproduced in small Transformers. (We focus on instabilities which lead to slow divergence, not loss spikes; see Section 4.) This enables their study without access to large resource pools. In particular, we examine the growth of logits in attention layers and divergence of the output logits from the log probabilities. As evident from the learning rate vs. loss curves and from inspecting model characteristics, both instabilities appear at high learning rates in small models. Moreover, interventions which have previously been employed at scale are also successful in this regime (e.g., Figure 1). These interventions, qk-layernorm (based on currently unpublished investigations of Gilmer et al.) and z-loss regularization, reduce LR sensitivity and enable successful training across three orders of magnitude of LR variation.

These observations raise the question of how other known optimizer and model interventions affect the shape of the learning rate vs. loss curves across scales. Therefore, we study the effect of techniques such as warm-up, weight decay, and μParam in this context. When employing qk-layernorm and z-loss regularization, these other techniques usually have little impact on the range of learning rates at which models can be stably trained, but do affect the sensitivity to learning rate within this range. In line with previous work, we find that longer warm-up reduces learning rate sensitivity, as does the independent scaling of learning rate and weight decay recommended by Loshchilov and Hutter. One interesting finding is that scaling depth increases LR sensitivity at a faster rate than scaling width.

The remainder of our investigation centers on the scaling behavior of model characteristics such as activation and gradient norms. Using the attention logit growth instability as an example, we show that it is possible to predict an instability before it emerges. This is in contrast to prior works on scaling, which primarily focus on scaling trends related to loss.

We conclude by using the scaling behavior of model characteristics to search for instabilities that are currently not well documented. Our investigation shows that gradient norms decrease with both scale and learning rate, such that the default AdamW epsilon hyperparameter is too large, causing updates that are too small. We connect this phenomenon and the attention logit growth instability to parameter norm growth.

Overall, we believe our work presents new scientific opportunities for studying training stability without access to large resource pools.

2 Experimental methodology

This section details our experimental set-up (Section 2.1) and useful tools employed by our analysis: (i) measuring the relationship between learning rate and loss across scales (Section 2.2) and (ii) examining scaling trends for model characteristics (Section 2.3).

2.1 Experimental set-up

We train small Transformer models with an experimental set-up similar to GPT-2, implemented in Flax: the models are decoder-only and trained with an auto-regressive loss (refer to Appendix A for more infrastructure details). While we experimentally manipulate many of the following hyperparameters, this section provides their default values, which we use unless otherwise specified.

By default, we use AdamW with $\beta_1=0.9$, $\beta_2=0.95$, $\epsilon=$ 1e-8, and gradient clipping at global norm 1. The default warm-up is 5e3 steps, and the default number of total steps is 1e5. We use a linear schedule for warm-up and a cosine-decay schedule for the remainder, with minimum learning rate 1e-5. We use an independent weight decay of 1e-4 and auxiliary z-loss with coefficient 1e-4. Sections 3.2.2 and 3.1.2 provide additional information and ablations on decoupled weight decay and z-loss, respectively.

We use pre-normalization Transformers with qk-layernorm (see Section 3.1.1). Following Chowdhery et al., we do not use any biases, and the layernorm $\epsilon$ remains at the Flax default of 1e-6. We jointly scale up the embedding size, depth, and number of heads when scaling parameters. We do not tie the weights of the first and last layer, and when reporting the number of parameters we exclude the embedding and head (as in Kaplan et al.). We use rotary positional embeddings, and for training data we use C4. Letting $d$ refer to the model dimension (i.e., the embedding size), the feed-forward component of the Transformer is an MLP with hidden dimension $4d$ and gelu activations. As in Vaswani et al., we use a $1/\sqrt{d_h}$ scaling factor in self-attention, where $d_h$ is the head dimension. The embedding initialization is the Flax default, which is normally distributed with standard deviation $1/\sqrt{d}$. The remainder of the weights are initialized with a truncated normal distribution with inverse-root-fan-in standard deviation.

The default batch size is 256, where each batch element has a sequence length of 512 tokens. Sequences are packed so that no padding is required. Finally, we use the vocabulary from Raffel et al., which has size 32101 and uses a SentencePiece tokenizer. We train on TPUs in bfloat16 precision using Flax and JAX.
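The default learning rate schedule can be sketched as a function of the training step. This is a minimal pure-Python sketch, not the NanoDO code: it assumes warm-up starts from an LR of 0 and that the cosine phase decays from the peak LR to the 1e-5 floor exactly at the final step. The name `lr_at_step` is hypothetical.

```python
import math

def lr_at_step(step, peak_lr, warmup_steps=5_000, total_steps=100_000, min_lr=1e-5):
    """Hypothetical schedule: linear warm-up, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Linear warm-up from 0 (assumed starting value) to peak_lr.
        return peak_lr * step / warmup_steps
    # Fraction of the post-warm-up phase completed, clipped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (peak_lr - min_lr) * cosine

# "LR" in this paper refers to peak_lr, e.g. one point of the sweep:
for step in (0, 2_500, 5_000, 50_000, 100_000):
    print(step, lr_at_step(step, peak_lr=1e-2))
```

In this reading, sweeping "LR" means sweeping `peak_lr` while the shape of the schedule stays fixed.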

2.2 LR vs. loss curves and learning rate sensitivity

To investigate how model instability emerges with scale, it is useful to plot the relationship between learning rate (LR) and loss for models of different sizes. For instance, an instability is often characterized by an explosion in the loss at high learning rates. LR vs. loss curves can reveal how the lowest unstable learning rate changes as a function of model size.

To summarize LR vs. loss curves, we use LR sensitivity, which measures the deviation in final validation loss from optimal when sweeping LR across three orders of magnitude. If a model fails to train at high learning rates, then LR sensitivity will be high. There are cases where LR vs. loss curves and LR sensitivity are no longer meaningful, for instance if an intervention changes the meaning of learning rate; see Appendix B for a detailed discussion.

Unless otherwise mentioned, we use the learning rate range 3e-4 to 3e-1 with AdamW to measure LR sensitivity, where LR refers to the maximum value in a cosine decay schedule with warm-up. We consider LRs in {3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1, 3e-1} when computing the minimum loss and the expected deviation from it.
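The exact formula is not reproduced in this section, so the sketch below is one reading of "the minimum and expectation" and should be treated as an assumption. It takes the best final loss over the seven-point grid, then averages each run's deviation from it. As a further assumption, each final loss is capped at the loss at initialization so that diverged runs (NaN or infinite loss) contribute a finite penalty. `lr_sensitivity` is a hypothetical helper.

```python
import math

LR_GRID = [3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1, 3e-1]

def lr_sensitivity(final_losses, initial_loss):
    """Mean deviation from the best final loss over an LR sweep (assumed definition)."""
    # Cap each loss at the initial loss; NaN (diverged) runs are mapped to the cap.
    capped = [initial_loss if math.isnan(l) else min(l, initial_loss) for l in final_losses]
    best = min(capped)
    return sum(c - best for c in capped) / len(capped)

# Illustrative, made-up losses for one model size; the two largest LRs diverge.
# Initial loss is taken as that of a uniform predictor over the vocabulary (assumption).
losses = [3.0, 2.9, 2.8, 2.85, 3.1, float("inf"), float("nan")]
print(lr_sensitivity(losses, initial_loss=math.log(32101)))
```

Under this definition, a model that stays at its initial loss for every LR scores zero, which is the limitation discussed in Appendix B.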

2.3 Scaling trends for model characteristics

To study instability, we also find it useful to examine scaling trends for model characteristics such as gradient or activation norms. This method is helpful for predicting instabilities and contrasts with previous work on scaling, which primarily focuses on trends relating model scale and loss.

3 Results

This section presents our results on training stability for small Transformers. Equipped with LR sensitivity (Section 2.2), we study two known instabilities and their corresponding mitigation at small scale (Section 3.1). This raises the question of how other model and optimizer interventions affect the sensitivity of final loss to learning rate, which we investigate in Section 3.2. Finally, we examine whether instabilities can be reliably predicted before they emerge: Section 3.3 predicts when the logit growth instability may cause divergence in a larger model, while Section 3.4 aims to find other issues that may occur when scaling up with our default hyperparameters.

3.1 Reproducing two known instabilities at small scale

Here, we examine two instabilities that have previously been described at scale: the growth of logits in attention layers and divergence of the output logits from the log probabilities. By examining LR vs. loss curves, we show that these instabilities can be reproduced in small models by using high learning rates and that mitigations employed at scale are effective in this regime.

3.1.1 Attention logit growth instability

Researchers have previously documented that Transformer training fails when the attention logits become large. In Dehghani et al., this issue emerged when training a ViT model with 22 billion parameters.

In the self-attention layer of a Transformer, queries $q_i$ and keys $k_i$ are combined to compute the attention logits $z_{ij} = \langle q_i, k_j\rangle/\sqrt{d_h}$, where $d_h$ is the head dimension. Next, the attention logits are passed through a softmax to produce attention weights, which are used to combine values $v_i$. Dehghani et al. observed that the attention logits $z$ became large, which they referred to as attention logit growth. As a result, the attention weights collapse to one-hot vectors, which was named attention entropy collapse by Zhai et al. To resolve this issue, Dehghani et al. proposed qk-layernorm, which applies LayerNorm to the queries and keys before computing the attention logits.
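To make attention entropy collapse concrete, the sketch below computes softmax attention weights for one query against three keys, then multiplies the query and keys by a growing factor as a stand-in for parameter norm growth. Without qk-layernorm the logits grow and the weights approach one-hot (entropy near 0). With a parameter-free LayerNorm applied to queries and keys (an assumption: qk-layernorm as used in practice has a learned scale), the weights are essentially unchanged. The vectors are arbitrary toy values.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(ps):
    return -sum(p * math.log(p) for p in ps if p > 0)

def layer_norm(v, eps=1e-6):
    # Parameter-free LayerNorm (no learned scale or bias), for illustration only.
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]

def attention_weights(q, keys, qk_layernorm=False):
    d_h = len(q)
    if qk_layernorm:
        q = layer_norm(q)
        keys = [layer_norm(k) for k in keys]
    # z_j = <q, k_j> / sqrt(d_h)
    logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_h) for k in keys]
    return softmax(logits)

q = [1.0, -0.5, 0.25, 0.5]
keys = [[0.5, 0.1, -0.3, 0.2], [-0.2, 0.4, 0.1, -0.1], [0.3, -0.3, 0.2, 0.4]]

for scale in (1.0, 3.0, 10.0):
    qs = [scale * x for x in q]
    ks = [[scale * x for x in k] for k in keys]
    print(scale,
          round(entropy(attention_weights(qs, ks)), 4),
          round(entropy(attention_weights(qs, ks, qk_layernorm=True)), 4))
```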

In our experiments, we find that models need not be large to exhibit instability related to attention logit growth. As shown in Figure 1, the maximum learning rate at which small models can be trained increases when using qk-layernorm. Without qk-layernorm, the learning rate at which models diverge becomes smaller with increasing model size. By contrast, models with qk-layernorm exhibit considerably lower LR sensitivity and train to low loss at high learning rates. As a highlight, qk-layernorm allows training a model with 1.2B parameters at learning rate 0.3. Both with and without qk-layernorm, LR sensitivity increases with scale.

Figure 2 displays the loss and max attention logit for two model scales that differ by three orders of magnitude. In both cases, the loss diverges without qk-layernorm. Our results in Appendix Figure E.1 suggest that attention logit growth is due to growth in the queries and keys, not due to an increase in their alignment. Instead, we hypothesize this instability could result from the quadratic dependence of attention logits on parameter norms.

3.1.2 Output logit divergence

Another instability reported by researchers training large models is divergence in the output logits from the log probabilities. Just as before, we reproduce this instability with small models at large learning rates, and the proposed mitigation ameliorates the issue. Overall, Figure 3 summarizes the effect.

Let $y$ denote the model's output logits, which are used to compute class probabilities $p_i$ via a softmax $p_i = e^{y_i}/Z$, where $Z = \sum_j e^{y_j}$. This instability occurs when the logits diverge and become very negative, as illustrated in Figure 4 for a 2.4M parameter model at learning rate 0.1. In contrast to the attention logit growth instability, this divergence occurs towards the end of training. The mitigation proposed by Chowdhery et al. is to encourage $\log Z$ to remain close to zero. They add an auxiliary loss $\log^2 Z$, referred to as z-loss, with coefficient 1e-4.

As illustrated in Figures 3 and 4, we find that instability related to output logit divergence occurs in models with no weight decay regardless of scale, and z-loss resolves this instability. Weight decay also mitigates this instability for the larger models we test.
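Because the softmax is unchanged when a constant is added to every logit, the cross-entropy alone places no pressure on $\log Z$; z-loss adds that pressure. A minimal single-token sketch of the objective follows (pure Python; `loss_with_z_loss` is a hypothetical name, and a real implementation would average over all positions in a batch):

```python
import math

def log_sum_exp(logits):
    m = max(logits)  # shift by the max for numerical stability
    return m + math.log(sum(math.exp(y - m) for y in logits))

def loss_with_z_loss(logits, target, z_coeff=1e-4):
    """Return (total loss, cross-entropy) for one token; z-loss is z_coeff * (log Z)^2."""
    log_z = log_sum_exp(logits)
    cross_entropy = log_z - logits[target]  # -log p_target, since log p_i = y_i - log Z
    return cross_entropy + z_coeff * log_z ** 2, cross_entropy

logits = [2.0, 0.5, -1.0]
shifted = [y - 50.0 for y in logits]  # same probabilities, but log Z far from 0
print(loss_with_z_loss(logits, target=0))
print(loss_with_z_loss(shifted, target=0))
```

Both calls give the same cross-entropy, but only the shifted logits incur a sizable z-loss term.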

3.2 Measuring the effect of other known interventions

The previous section used the relationship between learning rate and loss as a useful tool for examining two known instabilities and their mitigation. This raises the question of how other known model and optimizer interventions affect the shape of LR vs. loss curves across scales. In particular, can LR sensitivity help identify additional issues or resolutions when scaling? This section aims to answer this question for common techniques such as warm-up, weight decay, and μParam.

3.2.1 Warm-up

As illustrated by Figure 5, a longer warm-up period reduces LR sensitivity. This is most clear for the larger models, which are not stable at LR 3e-1 without long warm-up. The number of total steps is fixed to 1e5 in this experiment, and all models use qk-layernorm. The importance of warm-up for stability has previously been highlighted, although these works do not measure scaling behavior.

3.2.2 Independent weight decay

Parameterizing weight decay independently of learning rate reduces LR sensitivity, as illustrated in Figure 6. While this was recommended by Loshchilov and Hutter, it is not common practice in the default AdamW implementations of PyTorch or Optax. We explain the differences below.

For parameters $\theta$, let $\Delta = v/(\sqrt{u}+\epsilon)$ denote the AdamW update without learning rate or weight decay, where $v$ and $u$ are the EMAs of the first and second gradient moments. For weight decay coefficient $\lambda$, max learning rate $\eta$, and schedule $s_t \in [0, 1]$, Loshchilov and Hutter recommend the update $\theta \leftarrow \theta - s_t(\eta\Delta + \lambda\theta)$, which we refer to as independent decay. On the other hand, the default implementation in PyTorch or Optax applies the update $\theta \leftarrow \theta - s_t\eta(\Delta + \lambda\theta)$, i.e., $\eta$ now scales both terms.
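The difference between the two parameterizations is easiest to see with the adaptive update switched off ($\Delta = 0$): independent decay shrinks parameters by a factor of $1 - s_t\lambda$ per step regardless of which peak LR is swept, while the coupled form shrinks them by $1 - s_t\eta\lambda$. A scalar sketch with hypothetical function names (this is not the PyTorch or Optax API):

```python
def independent_decay_step(theta, delta, peak_lr, s_t, wd):
    # theta <- theta - s_t * (eta * Delta + lambda * theta): decay is not scaled by eta.
    return theta - s_t * (peak_lr * delta + wd * theta)

def coupled_decay_step(theta, delta, peak_lr, s_t, wd):
    # theta <- theta - s_t * eta * (Delta + lambda * theta): eta scales both terms.
    return theta - s_t * peak_lr * (delta + wd * theta)

# Per-step shrink factor on theta = 1 with Delta = 0 and s_t = 1, across an LR sweep,
# using this paper's defaults of lambda = 1e-4 (independent) and 0.1 (coupled).
for lr in (3e-4, 3e-3, 3e-2, 3e-1):
    print(lr,
          independent_decay_step(1.0, 0.0, lr, 1.0, 1e-4),
          coupled_decay_step(1.0, 0.0, lr, 1.0, 0.1))
```

In the coupled form, every point of an LR sweep also uses a different effective decay strength.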

When reporting LR sensitivity without independent decay in Figure 6, we report the minimum LR sensitivity over ranges [1e-4, 1e-1] and [3e-4, 3e-1] because the former is sometimes better centered on the minimum. The default setting in this paper is to use independent decay. When using independent decay we set $\lambda$ = 1e-4, and without independent decay we set $\lambda$ = 0.1. A sweep on weight decay values is conducted in Figure E.10.

3.2.3 Scaling width vs. depth

We have so far consistently observed that increasing the number of parameters increases LR sensitivity. We now examine which part of scaling is most responsible.

Our results, illustrated by Figure 7, indicate that scaling depth increases LR sensitivity at a faster rate than scaling width. However, at the largest scale we test, independently scaling depth produces a model with lower validation loss. A validation loss comparison between width scaling, depth scaling, and joint scaling is in Appendix Figure E.3. The standard practice of joint scaling performs best at the largest scale and also has a more reliable scaling prediction when extrapolating.

When scaling depth, we use $d=512$, and when scaling width, we use 6 layers. The number of heads is scaled proportionally with width, so that the head dimension remains the same.

Figure E.2 repeats this experiment without qk-layernorm, finding that the attention logit growth instability occurs more frequently at scale regardless of whether width or depth is scaled.

3.2.4 μParam

Yang and Hu introduced the μParam method for parameterizing a neural network. As a result, the optimal LR remains consistent when scaling model width. This section tests the effect of μParam on LR sensitivity, and examines whether μParam alleviates the need for qk-layernorm.

As illustrated by Figure 8, μParam does succeed in stabilizing the optimal LR at the scale we test. However, μParam does not improve loss or reduce LR sensitivity in our experiments. Appendix Figure E.4 repeats this experiment without qk-layernorm. Our results indicate that μParam does not alleviate the need for this intervention at high learning rates. We note that from a practical perspective, reducing LR sensitivity is not important if the optimal LR does not change.

We refer to the variant of μParam that we use in these experiments as μParam (simple) because it maintains only the core feature of μParam. We add additional features from Yang et al. in Appendix Figure E.5 without measurable improvement at the largest scale we test. For μParam (simple), we make the following change from our standard baseline: scale the LR for linear layers by $\text{base-fan-in}/\text{fan-in}$. For μParam (full), there are three additional changes: (i) initialize the head with standard deviation $\sqrt{\text{base-fan-in}}/\text{fan-in}$; (ii) change the $1/\sqrt{d_h}$ scaling factor in attention layers to $1/d_h$, where $d_h$ is the head dimension; and (iii) initialize the query projection weights with zeros. For base-fan-in, we use the fan-in values for the smallest model we test, which has width 256.
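The per-layer LR rule of μParam (simple) can be sketched as follows. Which layers count as "linear layers" (for instance, whether the embedding and head are included) is an assumption here, as is the hypothetical `fan_ins` layout of one block; base-fan-in is read off the same layer of the width-256 model, per the text.

```python
BASE_WIDTH = 256  # width of the smallest model, whose fan-ins serve as base-fan-in

def fan_ins(width):
    # Hypothetical fan-ins of the linear layers in one Transformer block.
    return {
        "query": width,
        "key": width,
        "value": width,
        "attn_out": width,
        "mlp_in": width,
        "mlp_out": 4 * width,  # MLP hidden dimension is 4d
    }

def muparam_simple_lrs(base_lr, width):
    """Scale each linear layer's LR by base-fan-in / fan-in."""
    base, cur = fan_ins(BASE_WIDTH), fan_ins(width)
    return {name: base_lr * base[name] / cur[name] for name in cur}

print(muparam_simple_lrs(1e-2, width=1024))  # each layer gets 1e-2 * 256 / 1024
```

Because the multiplier is a fixed per-layer constant, a global LR sweep still moves every layer's LR proportionally, which is the property Appendix B relies on.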

We comment briefly on the aforementioned changes (ii) and (iii). First, we ablate change (ii) in isolation in Appendix Figure E.6. While this intervention reduces loss slightly at the smallest scale we test, the reverse is true for the largest scale we test. Also, removing the square root from the scaling factor in attention layers does not alleviate the need for qk-layernorm. Finally, with regard to change (iii), we note that in preliminary experiments this change had no noticeable effect.

3.2.5 Additional interventions

This section recreates the previous plots with additional interventions or hyperparameter changes. Corresponding figures are displayed in the appendix.

Changing the number of training steps from 1e5 to 5e4 or 2e5 does not meaningfully change LR sensitivity (Appendix Figure E.7).

We try applying qk-layernorm across the whole model dimension instead of individually per-head with shared parameters. As illustrated in Appendix Figure E.8, the latter performs better. We use per-head qk-layernorm as the default in all other experiments.

Increasing the batch size from 256 to 512 or 1024 does not meaningfully change LR sensitivity (Appendix Figure E.9, each batch element contains 512 tokens). When increasing batch size we decrease the number of training steps so that the amount of data seen is constant. We believe a similar effect would be observed if instead we held the number of steps constant because changing the number of steps has no impact on LR sensitivity at batch size 256 (Appendix Figure E.7).

The effect of changing the weight decay from 1e-4 is illustrated in Figure E.10. Increasing decay appears to slightly shift the optimal LR to the right (toward larger learning rates).

We find that the logit growth instability is not due to the softmax in the self-attention layer, as it still occurs with a pointwise variant of attention (Appendix Figure E.11).

3.3 Predicting attention logit growth instability from scaling behavior of model characteristics

A central question when studying instabilities is whether they can be predicted. We now examine whether it is possible to predict the logit growth instability before it occurs. We track the attention logit maximums across model scales and fit a curve to the data. We use this to predict that a 4.8B parameter model will be unstable at LR 1e-2 without qk-layernorm and run an experiment to confirm this prediction.

Figure 9 plots the number of parameters vs. max attention logit at different learning rate values. We use block 0, which typically has the largest logits, and consider the value at step 2e3; much earlier than 2e3 was uninformative, and much later the unstable points had long since diverged. At each learning rate, we fit a quadratic to predict how the max attention logit will change with model scale.

We first noticed that all points with attention logits above 1e4 diverged. Moreover, the quadratic fit predicted that for LR 1e-2 the next model scale would also cross that value. Based on this prediction, we trained a new 4.8B parameter model at LR 1e-2. This model diverged as predicted. Not only do we predict the divergence, but our fit closely extrapolates to predict the value of the max attention logit.
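The prediction step can be sketched as an ordinary least-squares quadratic fit followed by a threshold check. The text does not state the axes of the fit; this sketch assumes both parameter count and max attention logit are taken in log10 space, so the 1e4 divergence threshold becomes 4. The function names are hypothetical, and the example data are synthetic points generated from a known quadratic, not measurements from Figure 9.

```python
def fit_quadratic(xs, ys):
    """Least-squares fit y ~ a*x**2 + b*x + c via the 3x3 normal equations."""
    rows = [[x * x, x, 1.0] for x in xs]
    # Augmented normal equations [A^T A | A^T y].
    m = [[sum(r[i] * r[j] for r in rows) for j in range(3)]
         + [sum(r[i] * y for r, y in zip(rows, ys))] for i in range(3)]
    # Gaussian elimination with partial pivoting.
    for col in range(3):
        pivot = max(range(col, 3), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, 3):
            f = m[r][col] / m[col][col]
            for c in range(col, 4):
                m[r][c] -= f * m[col][c]
    # Back substitution.
    w = [0.0, 0.0, 0.0]
    for i in reversed(range(3)):
        w[i] = (m[i][3] - sum(m[i][j] * w[j] for j in range(i + 1, 3))) / m[i][i]
    return w  # [a, b, c]

def predicts_divergence(coeffs, log10_params, log10_threshold=4.0):
    a, b, c = coeffs
    return a * log10_params ** 2 + b * log10_params + c > log10_threshold

# Synthetic points from y = 0.1x^2 - 0.5x + 1 (illustration only).
xs = [6.0, 7.0, 8.0, 9.0]
ys = [0.1 * x * x - 0.5 * x + 1.0 for x in xs]
coeffs = fit_quadratic(xs, ys)
print(coeffs, predicts_divergence(coeffs, 10.0))
```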

3.4 Searching for new instabilities via scaling trends of model characteristics

This section examines whether the scaling behavior of model characteristics can be used to predict new issues with the default model and hyperparameter settings.

Gradient RMS decreases with both model scale and learning rate (Figure 11), and the value it approaches as models get larger is cause for concern. At the largest scale and learning rate we test, grad RMS is around the default AdamW $\epsilon$ hyperparameter. Recall that the unscaled AdamW update is $\Delta = v/(\sqrt{u}+\epsilon)$, where $v$ and $u$ are the EMAs of the first and second gradient moments, respectively. If the grad RMS is on the same order as $\epsilon$, then $\Delta$ will decrease in magnitude, as illustrated by Figure 13, and parameters will not receive learning signals as intended.
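A back-of-the-envelope check of this effect: if the gradient were constant, the EMAs would converge to $v = g$ and $u = g^2$ (ignoring bias correction), so $|\Delta| = |g|/(|g| + \epsilon)$. The sketch below evaluates this idealized steady state; real gradients are noisy, so it isolates only the role of $\epsilon$. `steady_state_update` is a hypothetical name.

```python
import math

def steady_state_update(grad, eps=1e-8):
    # Constant gradient: first-moment EMA v -> grad, second-moment EMA u -> grad**2.
    v, u = grad, grad * grad
    return v / (math.sqrt(u) + eps)  # Delta = v / (sqrt(u) + eps)

for grad in (1e-4, 1e-6, 1e-8, 1e-10):
    print(f"grad={grad:.0e}  eps=1e-8: {steady_state_update(grad):.4f}  "
          f"eps=1e-15: {steady_state_update(grad, eps=1e-15):.4f}")
```

When the gradient scale equals $\epsilon$, the update is halved; an order of magnitude below $\epsilon$, it is roughly a tenth of its intended size.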

An obvious mitigation for this issue is to simply lower the AdamW $\epsilon$ hyperparameter from its default of 1e-8. We conduct this experiment for a 4.8B parameter model at LR 0.3 and present the results in Figure 12. Decreasing $\epsilon$ to 1e-15 improves loss and mitigates a collapse in grad RMS. We believe this improvement will only increase at scale. On the other hand, increasing $\epsilon$ to 1e-6 results in an instability (shown in Figure E.15).

Figure 13 expands on this result by illustrating the grad and update RMS throughout training at the largest scale and learning rate we test. When the grad RMS reaches $\epsilon$, the update RMS becomes small. Figure E.13 presents data from an analogous experiment at many different scales and LRs, demonstrating that this issue is most apparent for the larger models and LRs we test.

Although we identified the instability above by empirically measuring the scaling behavior of the gradients, a mechanistic explanation exists. For larger networks and learning rates, the Transformer output RMS entering the final layernorm may grow. Since the layernorm gradients are scaled by the inverse of their input RMS, the gradient received by the Transformer will shrink. Refer to Appendix C for a more detailed discussion.
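The inverse-RMS scaling of the layernorm gradient can be checked numerically. The sketch below uses central finite differences on a parameter-free LayerNorm (an assumption; a learned scale would multiply the result elementwise) with $\epsilon$ set to 0, and compares the input-gradient norm at $x$ and at $10x$: the ratio comes out at about 10. The input and the upstream gradient are arbitrary toy values.

```python
import math

def layer_norm(x, eps=0.0):
    # Parameter-free LayerNorm; eps = 0 matches the idealized analysis.
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]

def input_grad_norm(x, upstream, h=1e-6):
    """Norm of the finite-difference gradient of <upstream, layer_norm(x)> w.r.t. x."""
    grads = []
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        fp = sum(u * y for u, y in zip(upstream, layer_norm(xp)))
        fm = sum(u * y for u, y in zip(upstream, layer_norm(xm)))
        grads.append((fp - fm) / (2 * h))  # central difference
    return math.sqrt(sum(g * g for g in grads))

x = [0.5, -1.2, 0.3, 0.9, -0.5]
upstream = [1.0, 0.3, -0.7, 0.2, 0.5]  # stand-in for the gradient arriving from above
for scale in (1.0, 10.0, 100.0):
    print(scale, input_grad_norm([scale * xi for xi in x], upstream))
```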

4 Related work

This paper mainly focuses on the effect of known interventions and instabilities, and so related work has been primarily discussed when relevant. This includes the attention logit growth instability observed by Dehghani et al. and Zhai et al., and the final logit divergence issue encountered by Chowdhery et al. and Thilak et al. However, we highlight similar experimental methods in previous work. For instance, Yang et al. also measure the relationship between LR and loss across scales, but their focus is on centering the optimum (see Section 3.2.4). In addition, Zhai et al. elicit instability in base models by doubling the learning rate, and Dettmers et al. measure the presence of outlier features as a function of scale.

There are also important instabilities and related topics we have not directly discussed so far. For instance, we have primarily focused on instabilities that lead to a slow divergence, and we now summarize research on fast loss spikes. This instability is characterized by a quick increase in the loss that often eventually recovers.

The conventional understanding of gradient descent predicts that loss instability only occurs when the learning rate exceeds $2/\lambda_{\max}(H)$, where $H$ is the Hessian. However, recent investigations into large-batch neural network training dynamics have revealed a more complicated picture via edge of stability (EoS). When training neural networks with large-batch SGD, the loss curvature constantly evolves via the interaction of two processes: progressive sharpening and self-stabilization. Progressive sharpening is the empirical observation that when $\text{LR} < 2/\lambda_{\max}(H)$, the curvature gradually increases until the stability threshold is violated. When the learning rate becomes too large relative to the curvature, fast loss spikes occur and the parameters oscillate into a region with smaller $\lambda_{\max}(H)$, where stable training and progressive sharpening resume. The latter process, in which instability results in smaller $\lambda_{\max}(H)$, is self-stabilization, a theoretical model of which is given in Damian et al. Gradually shrinking $\lambda_{\max}(H)$ via self-stabilization was shown to be a primary mechanism behind the success of learning rate warm-up in Gilmer et al., who closely studied the connections between curvature, initialization, architecture, and max trainable learning rates.
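The classical $2/\lambda_{\max}(H)$ threshold can be seen on a one-dimensional quadratic $f(x) = \lambda x^2/2$, where gradient descent multiplies $x$ by $1 - \text{LR}\cdot\lambda$ at every step and therefore diverges exactly when $\text{LR} > 2/\lambda$. A minimal sketch with a hypothetical helper name:

```python
def gd_on_quadratic(curvature, lr, x0=1.0, steps=200):
    """Run gradient descent on f(x) = curvature * x**2 / 2 and return |x_T|."""
    x = x0
    for _ in range(steps):
        x -= lr * curvature * x  # x_{t+1} = (1 - lr * curvature) * x_t
    return abs(x)

curvature = 10.0  # lambda_max(H) of this 1-D problem; threshold 2 / 10 = 0.2
for lr in (0.19, 0.21):
    print(lr, gd_on_quadratic(curvature, lr))
```

Just below the threshold the iterate oscillates in sign but shrinks; just above it, the oscillation grows without bound.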

Cohen et al. further analyze edge-of-stability dynamics with adaptive optimizers, showing that progressive sharpening interacts with both the self-stabilization process and the adaptive optimizer state. This interaction results in the preconditioned sharpness $\lambda_{\max}(P^{-1}H)$ oscillating around an optimizer-specific threshold (38/LR in the case of Adam with $\beta_1=0.9$). Adaptive EoS (AEoS) can also result in periodic loss spikes when progressive sharpening pushes the preconditioned sharpness above the stability threshold; however, the optimizer hyperparameters play a role. In particular, when $\text{LR} > 38/\lambda_{\max}(P^{-1}H)$, two mechanisms are now in play to resolve the step size being too big: either $H$ can shrink or $P^{-1}$ can shrink (or both). Cohen et al. found that when $\beta_2$ is large, $H$ tends to shrink and fast loss spikes result during the process, resembling the self-stabilization process observed with gradient descent. However, when $\beta_2$ is small, $P^{-1}$ tends to shrink, no loss spikes are observed, and $\lambda_{\max}(H)$ tends to gradually increase throughout training.
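The 38/LR figure is consistent with the stability threshold $(2 + 2\beta_1)/((1-\beta_1)\,\text{LR})$ of EMA-style momentum on a quadratic, which evaluates to 38/LR at $\beta_1 = 0.9$. The sketch below checks this numerically with the preconditioner frozen at the identity ($P = I$, an assumption that removes Adam's second-moment scaling, so sharpness and preconditioned sharpness coincide). Function names are hypothetical.

```python
def ema_momentum_threshold(lr, beta1=0.9):
    # Largest stable sharpness for EMA momentum: (2 + 2*beta1) / ((1 - beta1) * lr).
    return (2 + 2 * beta1) / ((1 - beta1) * lr)

def ema_momentum_on_quadratic(curvature, lr, beta1=0.9, steps=500):
    """EMA momentum on f(x) = curvature * x**2 / 2 with P = I; returns |x_T|."""
    x, m = 1.0, 0.0
    for _ in range(steps):
        m = beta1 * m + (1 - beta1) * curvature * x  # EMA of gradients
        x -= lr * m
    return abs(x)

lr = 0.01
threshold = ema_momentum_threshold(lr)  # about 3800 = 38 / 0.01
for curvature in (0.97 * threshold, 1.03 * threshold):
    print(round(curvature), ema_momentum_on_quadratic(curvature, lr))
```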

It is noteworthy that the adaptive edge of stability process (and the role of $\beta_2$) studied in Cohen et al. offers a more complete understanding of the loss spikes studied in a body of literature. For example, Shazeer and Stern argue that during training of Transformers with adaptive optimizers, the optimizer update can become too big, resulting in a loss spike followed by recovery. This is sometimes attributed to the adaptive optimizer state becoming “stale”, which is consistent with the observation that reducing $\beta_2$ resolves the loss spikes. This is perhaps the same observation as Cohen et al. that reducing $\beta_2$ allows $P^{-1}$ to change more quickly to adjust to the process of progressive sharpening. AEoS also offers an explanation for the periodic loss spikes observed when training large Transformer models.

While our work has studied sensitivity to learning rate, there is also research that aims to eliminate the need to specify a learning rate. Based on their analysis, Ivgi et al. set the step size for iteration $t$ to the maximum distance from the initialization divided by the root sum of historical gradient squares. Moreover, while our work investigated μParam, there are additional parameterizations for which it would be interesting to explore LR vs. loss.
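The step-size rule attributed to Ivgi et al. above can be written in a few lines. This is a sketch of that one formula only, not their full optimizer; the small floor `r_eps` that keeps the first step nonzero is an assumption added here, and `dog_step_size` is a hypothetical name.

```python
import math

def dog_step_size(iterates, grads, r_eps=1e-6):
    """Max distance from the initial point / sqrt(sum of squared gradient norms)."""
    x0 = iterates[0]
    max_dist = max(math.dist(x, x0) for x in iterates)
    max_dist = max(max_dist, r_eps)  # assumed floor so the first step is not zero
    grad_sq_sum = sum(g_i * g_i for g in grads for g_i in g)
    return max_dist / math.sqrt(grad_sq_sum)

# Iterates x_0..x_t and gradients g_0..g_t seen so far (toy 2-D values).
print(dog_step_size([[0.0, 0.0], [3.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]]))  # 5 / sqrt(2)
```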

5 Conclusion

As the compute required to train the largest models continues to increase, it becomes increasingly important to understand if training will be stable. This paper has shown that useful insights on stability can be found when studying small Transformers. We hope that this opens new opportunities for impactful research which benefits large runs without access to large resource pools.

We thank George Dahl for thorough comments and suggestions, and Hugo Larochelle and Rif A. Saurous for helpful discussion. Also, we thank the members of the Google DeepMind PAGI team for their support of this effort, Noah Fiedel, Noah Constant, Aaron Parisi, Alex Rizkowsky, Avi Singh, Azade Nova, Bernd Bohnet, Daniel Freeman, Gamaleldin Elsayed, Hanie Sedghi, Isabelle Simpson, James Harrison, Jiri Hron, Kathleen Kenealy, Kevin Swersky, Kshiteej Mahajan, Laura Culp, Max Bileschi, Merrie Morris, Rosanne Liu, Yundi Qian, Sharad Vikram, Tris Warkentin.

References

Appendix A Additional infrastructure details

This section provides more details on the training infrastructure, which is built on Flax, JAX, and TPUs, and which we call NanoDO. To enable larger model training, we shard the model and optimizer states as in FSDP, then specify these shardings when compiling with JIT. We use Orbax for checkpointing and Grain for deterministic data loading. When loading data, sequences are packed so that no padding is required: if a sequence has fewer tokens than the context length hyperparameter, then an end-of-sequence token is appended, followed by the beginning of a new sequence.
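The packing rule described above can be sketched as follows: documents are concatenated with an end-of-sequence token after each, and the resulting stream is cut into rows of exactly the context length. This is a hypothetical helper, not the Grain or NanoDO API; splitting long documents across rows and dropping the incomplete final row are assumptions of the sketch.

```python
def pack_sequences(sequences, context_length, eos_id):
    """Pack token lists into fixed-length rows with no padding."""
    stream = []
    for seq in sequences:
        stream.extend(seq)
        stream.append(eos_id)  # end-of-sequence, then the next sequence begins
    n_rows = len(stream) // context_length
    # Keep only full rows; the incomplete tail is dropped (assumed behavior).
    return [stream[i * context_length:(i + 1) * context_length] for i in range(n_rows)]

print(pack_sequences([[11, 12, 13], [21, 22], [31, 32, 33, 34]], context_length=4, eos_id=1))
# [[11, 12, 13, 1], [21, 22, 1, 31], [32, 33, 34, 1]]
```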

Appendix B When is learning rate sensitivity a useful metric

There are cases where LR sensitivity (defined in Section 2.2) is no longer a useful metric. This section details these scenarios and justifies the use of LR sensitivity for the interventions in this paper.

Interventions which change the meaning of learning rate

When an intervention changes the meaning of learning rate, comparing LR sensitivity is not useful. A clear example of this would be taking the square root of the LR before passing it to the optimizer, but there are more subtle cases to be cautious of when using LR sensitivity.

In general, we avoid manipulations where the meaning of LR meaningfully changes. In some cases, we have good empirical evidence that the meaning of the learning rate has not changed when intervening. For instance, the LR vs. loss curves are indistinguishable up to some critical learning rate when using qk-layernorm (Figure 1), adding z-loss (Figure 3), or changing warm-up.

In other cases, such as when testing μParam (Section 3.2.4), we believe that LR sensitivity is useful despite a per-layer modification of LR. This is because the per-layer LR is manipulated linearly, and this modification does not change for different points on the LR vs. loss curve.

The one experiment in this paper where we believe LR sensitivity is likely not a useful metric is when scaling learning rate by the root mean square of the parameters (Figure E.14). Therefore, we do not measure LR sensitivity in that case.

The definition of LR sensitivity in Section 2.2 does not account for the optimal LR shifting when specifying the LR range $[a, b]$. In practice, we recommend shifting the three-orders-of-magnitude range $[a, b]$ to correspond with this shift. For instance, we shift the range in Section 3.2.2, as discussed in more detail there. However, our main experiments (e.g., Figure 1) do not test at a large enough scale to necessitate this shift.

Another limitation of the LR sensitivity metric is that it is invariant to the scale of the loss. If the network consistently achieves random performance across learning rates, then LR sensitivity will be zero. We do not offer a solution to this, and instead recommend that LR sensitivity should always be examined in combination with the LR vs. loss curves as we do in this paper. It is meant as a useful summary of the LR vs. loss curves, not as a metric to optimize in isolation.

Appendix C Output norm growth

This section discusses the growth of the output norms during Transformer training, as previously studied by Merrill et al. and Lee, and relates this phenomenon to the attention logit growth and AdamW epsilon instabilities (Sections 3.1.1 and 3.4, respectively). As empirical evidence, Figure C.1 shows that the RMS of the Transformer block output is mainly determined by learning rate.

We have two hypotheses which relate parameter norm growth and subsequent output norm growth to instability. First, we believe that the attention logits are the first to become large because they are the only feature in the network we test whose magnitude depends quadratically on parameter RMS. For inputs $X$ with unit RMS, a typical matrix multiply $XW$ with parameters $W$ will result in features $Y$ where $\text{RMS}(Y)$ is a linear function of $\text{RMS}(W)$. On the other hand, the attention logit entries are computed via $\langle XW_1, XW_2\rangle$, so they depend quadratically on $\text{RMS}(W)$.

Second, output norm growth helps to explain the decreasing trend in gradient scale observed in Section 3.4 (Figure 11). In a pre-normalization Transformer, there is an output layernorm layer after the last Transformer block and before the final linear layer. The gradient from this output layernorm layer is scaled by the reciprocal of the input RMS. This RMS grows with depth because of the residual connections (Figure C.1). As the RMS leaving the last Transformer block grows, the gradient received shrinks.
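The first hypothesis rests on a simple scaling fact, which the sketch below checks with random toy matrices: multiplying all weights by $c$ multiplies $\text{RMS}(XW)$ by $c$ but multiplies an attention logit $\langle XW_1, XW_2\rangle$ by $c^2$. Shapes, the seed, the approximately unit-RMS Gaussian inputs, and the helper names are arbitrary choices for illustration.

```python
import math
import random

def rms(values):
    return math.sqrt(sum(v * v for v in values) / len(values))

def matvec(x, w):
    # Row vector x (length d_in) times matrix w (d_in x d_out).
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def scaled(w, c):
    return [[c * v for v in row] for row in w]

rng = random.Random(0)
d, d_h = 16, 8
x_q = [rng.gauss(0.0, 1.0) for _ in range(d)]  # token producing the query
x_k = [rng.gauss(0.0, 1.0) for _ in range(d)]  # token producing the key
w1 = [[rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d_h)] for _ in range(d)]
w2 = [[rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d_h)] for _ in range(d)]

def logit(c):
    q, k = matvec(x_q, scaled(w1, c)), matvec(x_k, scaled(w2, c))
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(d_h)

for c in (1.0, 2.0, 4.0):
    print(c, rms(matvec(x_q, scaled(w1, c))), logit(c))
```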

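For completeness, we now compute the gradient of the layernorm with respect to its input $x \in \mathbb{R}^n$. For simplicity, we assume the input has mean zero and the layernorm has no bias, and we ignore the layernorm $\epsilon$. Let $g$ denote the layernorm scale, let $\hat{x} = x/\text{RMS}(x)$ so that the output is $y = g \odot \hat{x}$, and let $\delta = \partial \mathcal{L}/\partial y$ denote the incoming gradient. Writing $u = g \odot \delta$, the gradient with respect to the input is

$$\frac{\partial \mathcal{L}}{\partial x} = \frac{1}{\text{RMS}(x)}\left(u - \frac{1}{n}\Big(\sum_{k} u_k\Big)\mathbf{1} - \frac{1}{n}\langle \hat{x}, u\rangle\,\hat{x}\right).$$

The term in parentheses is unchanged when $x$ is rescaled, so the gradient reaching the input is scaled by $1/\text{RMS}(x)$: as the RMS of the Transformer output entering the final layernorm grows, the gradient received by the Transformer shrinks.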

Appendix D Author contributions

Mitchell Wortsman led the project, ran the experiments and produced the figures, contributed substantially to the infrastructure for experimentation, the framing and direction, and the writing.

Peter J. Liu led the infrastructure and creation of NanoDO for experimentation, provided key insights and advice on multiple technical areas, and contributed to the writing.

Lechao Xiao and Katie Everett contributed to the infrastructure used for experimentation, provided key insight related to parameterization, and contributed to the writing.

Alex Alemi, Ben Adlam, John D. Co-Reyes, Izzeddin Gur, Abhishek Kumar, Roman Novak, Jeffrey Pennington, Jascha Sohl-dickstein, and Kelvin Xu were active participants in weekly brainstorming meetings which motivated, influenced, and elucidated technical concepts pertaining to this work.

Jaehoon Lee and Justin Gilmer were senior authors advising on the project, contributed substantially to the framing and direction, provided key insight and advice on multiple technical areas, and contributed to the writing. Jaehoon led the connection with output norm growth. Justin proposed to plot loss as a function of learning rate for different model sizes, and performed initial experiments demonstrating that attention logit growth could be reproduced at high learning rates in small models.

Simon Kornblith was the lead advisor on the project, contributing substantially to the framing, direction, infrastructure, and writing. Simon initially brainstormed the project with Mitchell, and was Mitchell’s host for the summer internship during which this research was conducted, providing substantial technical support.

Appendix E Additional figures

This section contains the additional figures referenced in the main text.