Stable and low-precision training for large-scale vision-language models

Mitchell Wortsman, Tim Dettmers, Luke Zettlemoyer, Ari Morcos, Ali Farhadi, Ludwig Schmidt

1 Introduction

Large models trained on large datasets have recently led to multiple breakthroughs in machine learning, such as GPT-3 and PaLM. While many components are necessary for successful large-scale training, two critical elements are training speed and stability. To enable further progress, we must ensure that 1) training is fast: the model should see a large amount of data even when the model itself is large, and 2) training is stable: large models should not suffer from loss spikes, which degrade performance. We study these two directions in the context of contrastive language-image pre-training (CLIP). We examine CLIP-style models because of their importance in computer vision: they reach state-of-the-art performance on a wide range of image classification tasks and underlie image generation methods such as DALL-E 2 and Stable Diffusion. Our contributions towards fast training and stable training are as follows.

Towards fast training, we introduce SwitchBack, a linear layer for quantized training with int8 precision which matches the performance of the bfloat16 baseline within 0.1 percentage points for CLIP ViT-Huge, a larger model than those considered in the original CLIP paper. Linear layers account for the majority of the compute in standard transformer models, usually more than 90%, comprising the key, query, value, and output projections of the attention blocks as well as the multilayer perceptron. We perform all linear layers in low precision (int8) while retaining other layers, such as layer norms, in higher precision. With this setup, we observe end-to-end speedups between 13% and 25% for CLIP ViT-Huge training: 25% compared to a standard linear layer implemented using the PyTorch autograd Python module, and 13% compared to the standard PyTorch layer, which includes CUDA and C++ optimizations that happen in the background and are difficult to replicate for custom layers.

SwitchBack starts from the observation that quantization noise grows with the inner dimension of a matrix multiplication. For CLIP training, the weight gradient computation involves a large inner dimension because CLIP training requires a large batch size. Hence SwitchBack uses 16-bit precision matrix multiplication for the weight gradient computation while using int8 multiplications for the forward pass and the layer input gradient computation. This approach leads to large accuracy improvements compared to LLM.int8() (Figure 1). We provide open-source Triton kernels for SwitchBack to enable future work on efficient quantization schemes.

Besides int8 training, we also study large-scale 8-bit float (fp8) training. We do not have access to hardware that supports fp8 data types, which is currently rarer than int8 support, so we use an accurate simulation of fp8 computation. SwitchBack also outperforms straightforward fp8 baselines, because tensor-wise quantized baselines diverge at >420M scale (Figure 1). However, we demonstrate that these methods can achieve high accuracy if the network is trained while keeping feature magnitudes small, which we accomplish via layer-scale initialized with zeros.

Towards stable training, we find that loss spikes occur in CLIP training when the AdamW second moment estimator becomes out of date in the patch embedding layer. In particular, the learning signal changes so that the moving average of squared gradients underestimates their true magnitude. Indeed, in the absence of stability interventions, we show that loss spikes can be predicted by examining the ratio of the squared gradients to their moving average. We therefore recommend an AdamW-AdaFactor hybrid, which we refer to as StableAdamW because it removes instabilities at the scales we consider and outperforms gradient clipping. Concretely, StableAdamW is AdamW with the update clipping technique introduced in AdaFactor. Update clipping tracks the average ratio of the squared gradient to the second moment estimator and lowers the learning rate when this ratio is large.

The remainder of this paper is organized as follows: Section 2 focuses on low-precision training while Section 3 stabilizes training by reducing loss spikes.

2 8-bit training

This section develops and compares methods for eight-bit training of language-vision transformer models. First, Section 2.1 discusses preliminaries and related work. Next, Section 2.2 introduces and tests SwitchBack, a linear layer for int8 and float8 training. Finally, Section 2.3 develops alternatives to SwitchBack which can be used for float8.

2.1 Preliminaries and related work

Neural networks today are typically trained with 16-bit operations in either the float16 or bfloat16 format. Floating point formats use a subset of bits to represent the exponent while the remainder specifies the fraction (often referred to as the mantissa). The float16 format uses 5 bits for the exponent while bfloat16 uses 8 and therefore covers a larger range: float16 has a range of $(5.96\cdot 10^{-8}, 65504)$ while bfloat16 has a range of $(10^{-38}, 3\cdot 10^{38})$. Most floating point formats also have denormalized numbers, which allow for a “soft underflow” that gets exponentially closer to 0.0f for each additional bit in the mantissa. To prevent underflow, float16 mixed precision training uses loss scaling, which works as follows. The loss of a mini-batch is multiplied by a loss scalar so that the loss, and the gradients that follow from backpropagation, fall into the representable range of fp16. This scaling is undone by rescaling the weight gradients before the optimizer updates the fp32 main weights with the fp16 gradients. In PyTorch, the loss scalar is initialized to 65536. Every time an Inf/NaN is encountered, the update is skipped and the loss scalar is halved. If no Inf/NaN is encountered for 2k iterations, the scalar is doubled.
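The dynamic loss scaling policy described above can be sketched in a few lines of Python. This is a hedged, framework-free illustration rather than PyTorch's implementation; the class and helper names are hypothetical, and only the constants (initial value 65536, halving and skipping on Inf/NaN, doubling after 2,000 clean iterations) come from the text. These constants match the defaults of PyTorch's torch.cuda.amp.GradScaler.

```python
import math


class DynamicLossScaler:
    """Dynamic loss scaling: start at 65536, halve and skip the update on
    Inf/NaN, double after 2000 consecutive clean iterations."""

    def __init__(self, init_scale=65536.0, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def step(self, found_inf_or_nan):
        """Update the scale; return True if the optimizer step should run."""
        if found_inf_or_nan:
            self.scale /= 2.0       # back off
            self._clean_steps = 0
            return False            # skip this update entirely
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= 2.0       # probe a larger scale again
            self._clean_steps = 0
        return True


def has_inf_or_nan(grads):
    """True if any (scaled) gradient entry is Inf or NaN."""
    return any(math.isinf(g) or math.isnan(g) for g in grads)


def unscale(grads, scale):
    """Undo loss scaling before the fp32 optimizer update."""
    return [g / scale for g in grads]


if __name__ == "__main__":
    scaler = DynamicLossScaler()
    grads = [3.0e4, float("inf")]
    print(scaler.step(has_inf_or_nan(grads)), scaler.scale)  # False 32768.0
```

Under this policy, a scalar that has been halved k times needs at least 2000·k consecutive clean iterations to return to its previous value, which is why recovery after a burst of Inf/NaN gradients can take thousands of iterations.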

When the loss scalar becomes too low in float16 training, the loss slowly diverges. This was observed by Cherti et al. when training ViT-Huge CLIP models and remedied by switching to bfloat16. Another instance of float16 creating issues at scale was the training of OPT and BLOOM models. Indeed, many obstacles faced during the OPT project could have been alleviated by using bfloat16. Similarly, all float16 training runs for BLOOM ended in divergence; only after switching to bfloat16 was training stable. However, fast bfloat16 support is only available on TPUs, or on GPUs from the NVIDIA Ampere series onward (2020 or later).

While 16-bit training is the standard today, hardware support for 8-bit operations is becoming more common. Hopper GPUs support float8 (fp8) and Ampere GPUs support int8. However, it is currently (2023) very difficult to obtain Hopper GPUs. Moreover, while int8 and int4 are used for inference, and there is earlier work exploring 8-bit training for convnets, these formats are not commonly used for training transformer models at scale. The CLIP ViT-Huge models we train have 1B parameters including the image and text towers, which is 40x larger than a standard ResNet-50 (23M), and quantization is more challenging for large tensors. Additional related work on quantization of large-scale models (larger than BERT-large) and on low-precision training can be found in Appendix A.

2.2 SwitchBack

2.2.1 Method

Overview. A linear layer consists of three matrix multiplications: one in the forward pass to compute outputs and two in the backward pass to compute gradients for the input and weights. Our SwitchBack layer uses 8-bit precision for the first two matrix multiplies but switches back to higher precision for the weight gradient.

We compute the weight gradient in higher precision because this matrix multiplication involves dot products between vectors whose length is batch size times sequence length. As CLIP training requires large batch sizes, this inner dimension is much larger than for the other matrix multiplies. As we show in Appendix C, variance due to quantization increases with the inner dimension of the matrix multiply. This modification is what differentiates SwitchBack from LLM.int8(), allowing SwitchBack to match the bfloat16 baseline (Figure 1).
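A simplified calculation shows why the inner dimension matters. Assume, as an idealization rather than the analysis of Appendix C, that quantizing the operands adds an independent, zero-mean error $e_k$ with variance $\sigma^2$ to each of the $K$ products in a dot product:

$$\hat{y} = \sum_{k=1}^{K} \left(a_k b_k + e_k\right), \qquad \mathbb{E}[e_k] = 0,\ \ \mathrm{Var}(e_k) = \sigma^2 \ \Longrightarrow\ \mathrm{Var}\left(\hat{y} - \sum_{k=1}^{K} a_k b_k\right) = \sum_{k=1}^{K} \mathrm{Var}(e_k) = K\sigma^2.$$

For a linear layer with input $X \in \mathbb{R}^{n \times d_{\text{in}}}$, weight $W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$, and output gradient $G \in \mathbb{R}^{n \times d_{\text{out}}}$, the forward pass $XW^{\top}$ has $K = d_{\text{in}}$ and the input gradient $GW$ has $K = d_{\text{out}}$, but the weight gradient $G^{\top}X$ has $K = n$, the batch size times the sequence length, which is far larger in CLIP training.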

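Quantization. For the matrix multiplies in 8-bit precision we use quantization. There are multiple quantization techniques to choose from, and we release code for all these alternatives. However, we find the best trade-off of simplicity and performance comes from using i) row-wise quantization for the inputs and gradients and ii) tensor-wise quantization for the weights. Additional information on quantization methods is provided by Dettmers et al.; we summarize it below. Using int8 as an example, with which we represent integers from $-127$ to $127$, we now define row-wise and tensor-wise quantization. For a matrix $X$ with rows $x_1, \ldots, x_b$, row-wise quantization $Q_{\text{row}}$ is given by

$$Q_{\text{row}}(X) = \begin{bmatrix} \mathsf{round}\left(\frac{127}{\mathsf{absmax}(x_1)}\, x_1\right) \\ \vdots \\ \mathsf{round}\left(\frac{127}{\mathsf{absmax}(x_b)}\, x_b\right) \end{bmatrix}, \qquad \mathsf{state}_{\text{row}}(X) = \begin{bmatrix} \mathsf{absmax}(x_1) \\ \vdots \\ \mathsf{absmax}(x_b) \end{bmatrix},$$

while tensor-wise quantization $Q_{\text{tensor}}$ is given by

$$Q_{\text{tensor}}(X) = \mathsf{round}\left(\frac{127}{\mathsf{absmax}(X)}\, X\right), \qquad \mathsf{state}_{\text{tensor}}(X) = \mathsf{absmax}(X),$$

where $\mathsf{absmax}$ is the maximum of the absolute value.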
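The two quantizers can be written directly from these definitions. The following pure-Python sketch uses hypothetical function names (it is not the bitsandbytes API) and represents matrices as lists of rows; it returns both the int8 codes and the state needed for dequantization. Note that Python's round() rounds exact halves to even, which may differ from a GPU kernel's rounding mode.

```python
def absmax(values):
    """Maximum absolute value of an iterable of numbers."""
    return max(abs(v) for v in values)


def quantize_rowwise(X):
    """Row-wise int8 quantization of a matrix given as a list of rows.

    Returns (Q, state) with Q[i][j] = round(127 * X[i][j] / absmax(x_i))
    and state = [absmax(x_1), ..., absmax(x_b)].
    """
    # "or 1.0" guards all-zero rows against division by zero.
    state = [absmax(row) or 1.0 for row in X]
    Q = [[round(127 * v / s) for v in row] for row, s in zip(X, state)]
    return Q, state


def quantize_tensorwise(X):
    """Tensor-wise int8 quantization: one scale absmax(X) for the whole matrix."""
    state = absmax(v for row in X for v in row) or 1.0
    Q = [[round(127 * v / state) for v in row] for row in X]
    return Q, state


def dequantize_rowwise(Q, state):
    """Approximate inverse of quantize_rowwise: x ~= state_i * q / 127."""
    return [[s * q / 127 for q in row] for row, s in zip(Q, state)]
```

With a per-row scale, each row uses the full int8 range regardless of how other rows are scaled, which is why row-wise quantization suits inputs and gradients whose rows (tokens) can differ widely in magnitude.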

Since only the matrix multiply occurs in int8 precision, we need to dequantize the outputs back to the original floating point precision. Consequently, the forward pass with quantization and dequantization becomes

$$Y = \frac{\mathsf{state}_{\text{tensor}}(W)}{127^2}\, \mathsf{state}_{\text{row}}(X) * \left(Q_{\text{row}}(X)\, Q_{\text{tensor}}(W)^{\top}\right),$$

where $*$ denotes elementwise multiplication, which in this case is broadcast so that row $i$ of the matrix $Q_{\text{row}}(X)\, Q_{\text{tensor}}(W)^{\top}$ is multiplied by element $i$ of $\mathsf{state}_{\text{row}}(X)$.

As mentioned previously, we use row-wise quantization for the inputs and gradients and tensor-wise quantization for the weights. We find that using row-wise quantization for both matrices increases complexity for little or no accuracy gain, so we use this simpler approach.

The last detail in our algorithm is hardware specific. NVIDIA GPUs, which we use in this work, do not implement the int8/float8 operation $AB$ for matrices $A$ and $B$; only $AB^{\top}$ is implemented. As such, it is necessary to transpose the weight matrix in the backward pass. To reduce the overhead of transposition and quantization, we fuse both operations: we load the required data once from slow DRAM into fast SRAM/shared memory and then perform both operations in this cached memory, which is critical for achieving speedups. We call this operation tensor-wise_quantize_transpose, a fused tensor-wise quantize and transpose operation.

Putting the pieces together, the result is Algorithm 1.
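Since Algorithm 1 itself is not reproduced here, the following pure-Python sketch illustrates its structure under stated assumptions: lists of floats stand in for tensors, Python floats stand in for 16-bit arithmetic, integer sums stand in for int8 matmuls with int32 accumulation, and all function names are hypothetical. Only the $AB^{\top}$ matmul layout is used, so the weight is quantized and then transposed for the input gradient, mirroring (without fusing) tensor-wise_quantize_transpose.

```python
def absmax(values):
    return max(abs(v) for v in values)


def transpose(M):
    return [list(col) for col in zip(*M)]


def matmul_nt(A, B):
    """A @ B^T: the only layout the int8 GEMM provides, per the text."""
    return [[sum(a * b for a, b in zip(ra, rb)) for rb in B] for ra in A]


def quantize_rowwise(X):
    state = [absmax(r) or 1.0 for r in X]  # per-row scales (zero-row guard)
    return [[round(127 * v / s) for v in r] for r, s in zip(X, state)], state


def quantize_tensorwise(W, transpose_out=False):
    """Tensor-wise quantization; transpose_out=True stands in for the fused
    tensor-wise_quantize_transpose (here: quantize, then transpose)."""
    state = absmax(v for r in W for v in r) or 1.0
    Q = [[round(127 * v / state) for v in r] for r in W]
    return (transpose(Q) if transpose_out else Q), state


def dequantize(acc, row_state, tensor_state):
    """Row i of the integer product is scaled by row_state[i] * tensor_state / 127^2."""
    return [[row_state[i] * tensor_state / 127**2 * a for a in row]
            for i, row in enumerate(acc)]


def switchback_forward(X, W):
    """Y = X W^T with X quantized row-wise and W tensor-wise (int8 matmul)."""
    qx, sx = quantize_rowwise(X)
    qw, sw = quantize_tensorwise(W)
    return dequantize(matmul_nt(qx, qw), sx, sw)


def switchback_backward(X, W, G):
    """Gradients of Y = X W^T given dL/dY = G; returns (grad_X, grad_W)."""
    # Input gradient G W in int8: G row-wise, W tensor-wise and transposed so
    # that G W = G (W^T)^T fits the A B^T layout.
    qg, sg = quantize_rowwise(G)
    qwt, sw = quantize_tensorwise(W, transpose_out=True)
    grad_X = dequantize(matmul_nt(qg, qwt), sg, sw)
    # Weight gradient G^T X: its inner dimension is the number of rows of X
    # (batch size times sequence length), so it stays in high precision.
    grad_W = matmul_nt(transpose(G), transpose(X))
    return grad_X, grad_W
```

The weight gradient deliberately bypasses quantization: its inner dimension is exactly where quantization noise would accumulate most, so this is the one matmul that switches back to higher precision.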

Variants. While Algorithm 1 is the most straightforward version of SwitchBack, we also present two alternative versions, SwitchBackM and SwitchBackQ, and release Triton implementations for all three via the bitsandbytes library. Appendix B contains pseudocode. SwitchBackM (Algorithm 3) is a memory-efficient version of SwitchBack which only saves 8-bit tensors for the backward pass; we recommend its use when memory is limited. The small downside of SwitchBackM is that it requires an additional dequantize operation during the backward pass, which increases the runtime overhead. For CLIP ViT-Huge we observed only negligible accuracy differences between SwitchBack and SwitchBackM. In addition, we present SwitchBackQ (Algorithm 4), which uses row-wise and column-wise quantization for the weights instead of tensor-wise quantization. While we did not observe this to improve accuracy at the scales we consider, it is possible that it will perform better than SwitchBack at larger scale. For SwitchBackQ, the forward pass is given by
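
$$Y = \frac{1}{127^2}\, \mathsf{state}_{\text{row}}(X)\, \mathsf{state}_{\text{row}}(W)^{\top} * \left(Q_{\text{row}}(X)\, Q_{\text{row}}(W)^{\top}\right),$$
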
where * is an elementwise product. Again, we append _transpose to a function in Algorithm 4 to mean that the operation is fused with a transpose.
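A compact pure-Python sketch of the SwitchBackQ forward pass follows, assuming matrices are lists of rows and using hypothetical function names. Both operands get one scale per row (per token for X, per output feature for W), so dequantizing entry (i, j) multiplies it by the product of the i-th input scale and the j-th weight scale, i.e., by an outer product of the two state vectors.

```python
def quantize_rowwise(M):
    """int8 codes and per-row absmax scales (all-zero rows guarded)."""
    state = [max(abs(v) for v in row) or 1.0 for row in M]
    return [[round(127 * v / s) for v in row] for row, s in zip(M, state)], state


def switchbackq_forward(X, W):
    """Y = X W^T with X quantized per row (token) and W per row (output feature)."""
    qx, sx = quantize_rowwise(X)
    qw, sw = quantize_rowwise(W)
    Y = []
    for i, qrow in enumerate(qx):
        Y.append([
            # Entry (i, j) uses the outer-product scale sx[i] * sw[j] / 127^2.
            sx[i] * sw[j] / 127**2 * sum(a * b for a, b in zip(qrow, qw[j]))
            for j in range(len(qw))
        ])
    return Y
```

In the backward pass, the input gradient needs W quantized along its other axis (column-wise), which is why SwitchBackQ keeps both row-wise and column-wise quantized copies of the weights.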

float8. While the explanation so far has used int8 as an example, the code for SwitchBack with float8 (fp8) is nearly identical. The only modification is that operations such as $\mathsf{round}(127x/\mathsf{absmax}(x))$ are replaced by $\mathsf{float8cast}(x/\mathsf{absmax}(x))$, where we simulate $\mathsf{float8cast}$ through bitsandbytes by rounding to the exact values of the float8 data type. This improves on the simulation of prior work, which only clips the input tensors into the representable range of the float8 data type but does not round to its exact values. Our simulation theoretically matches float8 training, but we are unable to perform real float8 training because we lack hardware that supports float8 arithmetic. As such, we perform arithmetic in 16-bit with exact float8 values. For our int8 experiments we conduct the multiplications in int8 using A100 GPUs, so we perform real int8 training without any simulation.
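To make the rounding step concrete, here is a hedged pure-Python sketch of float8cast for one fp8 layout. The text does not specify which of the fp8 formats from Micikevicius et al. is simulated; this sketch assumes E4M3 with the "FN" convention (exponent bias 7, no infinities, largest finite value 448). Ties are resolved toward the smaller magnitude for simplicity, whereas hardware casts typically round to nearest even. All names are hypothetical.

```python
import bisect


def e4m3_values():
    """Sorted non-negative finite values of fp8 E4M3 (bias 7, 'FN' variant)."""
    vals = set()
    for e in range(16):          # 4 exponent bits
        for m in range(8):       # 3 mantissa bits
            if e == 15 and m == 7:
                continue         # this bit pattern encodes NaN
            if e == 0:
                vals.add(m / 8 * 2.0 ** -6)              # subnormals
            else:
                vals.add((1 + m / 8) * 2.0 ** (e - 7))   # normals
    return sorted(vals)


E4M3 = e4m3_values()


def float8cast(x):
    """Round x to a nearest E4M3 value, saturating at +-448.
    Ties go to the smaller magnitude (a simplification)."""
    mag = min(abs(x), E4M3[-1])
    k = bisect.bisect_left(E4M3, mag)
    if k == 0:
        nearest = E4M3[0]
    else:
        lo, hi = E4M3[k - 1], E4M3[k]
        nearest = lo if mag - lo <= hi - mag else hi
    return nearest if x >= 0 else -nearest


def simulate_fp8(values):
    """float8cast(x / absmax(x)) for a flat list; also returns the scale."""
    scale = max(abs(v) for v in values) or 1.0
    return [float8cast(v / scale) for v in values], scale
```

The cast values are ordinary Python floats, so subsequent arithmetic happens in higher precision on exact fp8 values, which is the same idea as the 16-bit simulation described above.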

2.2.2 Experimental setup

To evaluate SwitchBack we train CLIP vision transformer models on LAION-2B. CLIP training, especially at ViT-Huge scale, is typically prohibitively expensive. Our goal is not high final accuracy but rather to contrast different methods for low-precision training. To enable running multiple experiments, we therefore train for only a small number of samples seen (380 million images) and use patch-dropout 0.5. We note that each experiment is still very expensive, corresponding to roughly 300 epochs of ImageNet training in terms of samples seen, or approximately 2.9e20 FLOPs per training run. After training on LAION-2B, we evaluate the models zero-shot on ImageNet using the 80 prompt templates from CLIP.

We use batch size 16384 (per-GPU batch size of 256) and train for a total of 20k iterations. The first 5k iterations are linear warmup while the remaining 15k are cosine decay. Training and evaluation are conducted with the OpenCLIP library with learning rate 2e-3 and weight decay 0.2, using the optimizer described in Section 3.5.
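The schedule above (5k linear warmup iterations, then cosine decay over the remaining 15k, with peak learning rate 2e-3) can be written as a small function. Starting the warmup at exactly zero and decaying to exactly zero are assumptions of this sketch; OpenCLIP's scheduler may differ in such off-by-one details, and the function name is hypothetical.

```python
import math


def lr_at(step, base_lr=2e-3, warmup=5_000, total=20_000):
    """Linear warmup from 0 to base_lr, then cosine decay to 0 at `total`."""
    if step < warmup:
        return base_lr * step / warmup
    progress = (step - warmup) / (total - warmup)   # goes from 0 to 1
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))
```

For example, the learning rate is half its peak at iteration 2,500 (mid-warmup) and again at iteration 12,500 (mid-decay).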

2.2.3 Results

We address two main questions: (1) can SwitchBack replicate 16-bit performance, and (2) does it provide speedups? To test (1), we train CLIP models with SwitchBack across multiple scales with both int8 and float8 precision (Figures 1 and 2). To test (2), we profile operations in an individual linear layer and also measure end-to-end training speed.

Accuracy. We find that SwitchBack can match standard 16-bit training performance and outperform baselines for both a) int8 precision and b) float8 precision.

For our int8 experiments (Figures 1 and 2 left), we contrast the performance of i) the standard baseline, which uses mixed-precision bfloat16, ii) the matrix multiplication kernels from LLM.int8(), which are equivalent to SwitchBackQ (Algorithm 4) if the weight gradient multiplication were also performed in int8 using row- and column-wise quantization, and iii) SwitchBack. SwitchBack has a negligible accuracy drop of 0.1 percentage points compared to the bfloat16 baseline for CLIP ViT-Huge. In contrast, there is a drop of 5.9 percentage points when training with LLM.int8(). Appendix C details our hypothesis for why LLM.int8() fails to replicate 16-bit performance for CLIP training.

For our simulated float8 training experiments (Figures 1 and 2 right), we contrast the performance of i) the standard baseline which uses mixed-precision bfloat16, ii) a baseline which uses tensor-wise quantization for all matrices, that is the weights, inputs, and gradients, and iii) SwitchBack. SwitchBack has a negligible accuracy drop of 0.1 percentage points from the bfloat16 baseline for CLIP ViT-Huge. In contrast, training diverges for the baseline that uses tensor-wise quantization for all matrices.

Speed. We now test the speedups offered by SwitchBack by first examining individual operations and then end-to-end training.

We profile all of the operations which constitute a forward and backward pass for a single linear layer in Figure 3 (left) for both SwitchBack and the baseline. For SwitchBack we profile our custom Triton kernels, and for the baseline we profile torch.matmul. Overall, we observe that int8 multiplies take just over half the time of standard fp16 matmuls, and that quantize operations take roughly an order of magnitude less time than a matmul. Note that our int8 matmuls are fused with the dequantize operation.

Figure 3 (right) displays the % speedup of SwitchBack over a standard fp16 layer when all operations in Figure 3 (left) are summed. The advantage of SwitchBack is greater for larger $\mathsf{dim}$ and larger $\mathsf{batch\_size} * \mathsf{sequence\_length}$. Overall, the speedup ranges from 5% to 35%. We see a bump at $\mathsf{dim}=1280$ because standard PyTorch matmuls do not have optimized kernels for matrices of this size, while we use Triton's autotune feature, which provides fine-grained optimized kernels for matrices of any size. Our kernels are easy to modify as they are written in Triton, and the code to run the benchmarks and produce Figure 3 is open source. We thereby invite the community to further improve the kernels, and provide a benchmark for measuring this progress. Due to computational constraints we have not tested $\mathsf{dim}>4096$, and it is possible the kernels require additional tuning to perform well at that scale.

One downside of SwitchBack is that it requires quantize operations. However, Figure 3 already shows that quantize operations take little time compared to matmuls. This is highlighted by Figure 4 (left), which displays the fraction of time occupied by quantize operations relative to matmuls for SwitchBack linear layers. Quantize operations occupy at most 25% of the time, and this fraction decreases to around 10% or below for large $\mathsf{dim}$.

We now conduct end-to-end speed tests for CLIP training on a single node with 4x A100 GPUs (Figure 4, right). This is in contrast with the speedup measurements so far in this section, which measured individual layers independently. We benchmark speedups relative to i) a baseline linear layer which we implement in PyTorch with torch.autograd.linear (Algorithm 5) and ii) the optimized PyTorch linear layer nn.Linear. In both cases the speedups increase when going from CLIP ViT-Base to CLIP ViT-Huge. However, there is an additional ~12.5% speedup when comparing SwitchBack to the baseline linear layer which uses torch.autograd. We believe this comparison is fair because SwitchBack is also implemented using torch.autograd, while the standard PyTorch nn.Linear layer has additional C++ and CUDA optimizations that we do not implement. We hope to collaborate with the PyTorch team to realize this additional ~12.5% speedup. Finally, we note that the kernels from LLM.int8() do not provide speedups over fp16 at the scale we consider.

2.3 Float8 training by reducing feature magnitude

We find that SwitchBack is necessary for high-accuracy int8 training. However, this section develops other interventions which enable float8 training without SwitchBack. We show that high accuracy can be achieved via float8 training with tensor-wise quantization for the inputs, weights, and gradients, so long as the network is initialized and trained in a way which discourages large feature magnitudes. We accomplish this via layer-scale initialized to zero.

We use the bitsandbytes library to simulate float8 training using the fp8 types from Micikevicius et al. We use tensor-wise quantization for the inputs, weights, and gradients, so that all operations occur in simulated float8. In our simulation, we represent each value only with the exact values representable by float8, but we perform computations in float16 precision. We believe that tensor-wise quantization approximates the removal of quantize operations entirely. This is because, as we show in Appendix B.2 (Figure 14), the maximum of these tensors tends to evolve smoothly. Consequently, dividing by a moving average of the maximum directly inside the matmul would behave similarly to tensor-wise quantization.
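The moving-average idea can be sketched as follows. This is a hypothetical helper, not part of bitsandbytes: it keeps an exponential moving average of a tensor's absmax (the momentum value is an arbitrary choice) and normalizes the current tensor with the scale from previous steps, so no fresh reduction over the tensor is needed before the matmul.

```python
class MovingAbsmax:
    """Exponential moving average of a tensor's absmax across steps."""

    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.value = None

    def observe(self, values):
        current = max(abs(v) for v in values)
        if self.value is None:
            self.value = current
        else:
            self.value = self.momentum * self.value + (1 - self.momentum) * current
        return self.value


def normalize_with_stale_scale(values, tracker):
    """Divide by the scale from previous steps, then refresh the tracker."""
    scale = tracker.value if tracker.value is not None else max(abs(v) for v in values)
    normalized = [v / scale for v in values]
    tracker.observe(values)
    return normalized
```

If the true maximum jumps, normalized values can exceed 1 and would have to be clipped before an fp8 cast; this is why the approximation depends on the maximum evolving smoothly.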

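Layer-scale, introduced by Touvron et al., scales each self-attention and MLP block output hidden state by a learnable vector of shape $\mathsf{embed\_dim}$. A pre-norm transformer block with layer-scale tensors $\gamma_1$ and $\gamma_2$ is defined as

$$x_k' = x_k + \gamma_1 * \mathsf{self\_attention}\left(\mathsf{norm}_1(x_k)\right), \qquad x_{k+1} = x_k' + \gamma_2 * \mathsf{mlp}\left(\mathsf{norm}_2(x_k')\right),$$
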
where * is broadcasted elementwise multiplication.

Typically, layers are initialized so that they approximately preserve the variance of their inputs, and inputs have approximately unit variance. However, when combined with residual connections, this can lead to higher norms in deeper networks, since every block adds its output to the residual stream.

Consequently, researchers have proposed initialization and scaling schemes which remedy this issue. Layer-scale with initialization 0 is an example of one such scheme: at initialization the transformer is an identity function. While $\gamma_1, \gamma_2$ are typically initialized as vectors of $10^{-4}$ or $10^{-6}$, we use 0 for simplicity.
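A minimal sketch of the layer-scale block defined above, operating on a single token's feature vector for brevity. The attention and MLP sublayers are passed in as placeholder callables (a real self-attention mixes tokens, which this toy signature cannot express), the LayerNorm has no learnable affine parameters, and all names are hypothetical. With gamma1 = gamma2 = 0 the block returns its input unchanged, which is the identity-at-initialization property.

```python
def layer_norm(x, eps=1e-5):
    """LayerNorm over one feature vector, without learnable affine parameters."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / (var + eps) ** 0.5 for v in x]


def layerscale_block(x, attn, mlp, gamma1, gamma2):
    """Pre-norm residual block with layer-scale vectors gamma1 and gamma2.

    x is one token's features; attn and mlp are placeholder callables that
    map a feature vector to a feature vector of the same length.
    """
    h = attn(layer_norm(x))
    x = [xi + g * hi for xi, g, hi in zip(x, gamma1, h)]   # x' = x + gamma1 * attn(norm(x))
    h = mlp(layer_norm(x))
    return [xi + g * hi for xi, g, hi in zip(x, gamma2, h)]  # x' + gamma2 * mlp(norm(x'))
```

Because each residual branch is multiplied by a learned gamma that starts at zero, the residual stream grows only as fast as training pushes gamma away from zero, which keeps feature magnitudes small early on.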

Figure 5 (right) demonstrates that the layer-scale intervention successfully controls the average output magnitude. Without the intervention, the average feature magnitude $\mathbb{E}[\mathsf{abs}(x_k)]$ becomes high for later blocks. Previous work has shown that large feature magnitudes cause issues for low-precision training.

Results for simulated fp8 training are shown in Figure 5 (left) for ViT-Large. We find that all fp8 runs diverge except when we use layer-scale initialized to zero. Concretely, Figure 5 compares i) the baseline, which uses bfloat16 training, ii) fp8 with tensor-wise quantization and no further modifications, which slowly diverges, iii) adding gradient clipping to ii), which also diverges, iv) adding KQ layernorm to ii), which also diverges, and v) using zero-init layer-scale, which trains without diverging. While there is still a gap between fp8 and bfloat16 training, this is primarily due to layer-scale itself. Moreover, we believe that with hyperparameter tuning, layer-scale would match standard training in terms of accuracy.

3 Stability

We now switch focus from accelerating learning by reducing precision to addressing instabilities which can arise during training. Section 3.1 reviews preliminaries and related work, while Section 3.2 details the experimental setup. Next, Section 3.3 examines trends for training instability, finding that loss spikes increase with model scale but decrease with lower AdamW $\beta_2$. Then, Section 3.4 finds that loss spikes arise in our setting due to an out-of-date AdamW second moment estimator, leading Section 3.5 to adopt and test a fix developed in the context of AdaFactor. Finally, Section 3.6 connects loss spikes to low-precision training.

3.1 Preliminaries and related work

Loss spikes can emerge when scaling up models. These instabilities may slow learning, or even destabilize training completely. Various solutions have been proposed, including freezing the embedding layer, adding additional layer normalization, or reparametrizing the weights.

In our work we investigate instabilities which arise during CLIP training. Unlike the instabilities observed in prior work, which lead to a slow divergence, we study fast loss spikes. Our results indicate that these spikes arise when the second moment estimator is out of date for early layers.

While our analysis and methods build directly on Shazeer and Stern (AdaFactor), there are important differences. In contrast with Shazeer and Stern, who only observe instabilities without warmup, we observe instabilities despite a long warmup period. Moreover, in contrast with Shazeer and Stern, we find that an out-of-date second moment estimator is primarily an issue for the (patch) embedding layer, and we measure how well loss spikes are predicted by this event. Finally, we note that researchers have moved away from AdaFactor in its original formulation for large-scale training, finding AdaFactor to under-perform AdamW. We believe this is due to the factored second moment or the absence of a first moment. This is why we focus on AdamW, which is the de facto standard optimizer for transformers.

After the initial version of this paper, we became aware of Cohen et al., which offers a general and principled treatment of fast loss spikes and which we recommend to readers. We also direct the reader’s attention to related concurrent work.

3.2 Experimental setup

As in Section 2, we train ViT CLIP models on LAION using OpenCLIP and evaluate them zero-shot on ImageNet. Since we are interested not in final performance but in studying instability, even for very large models, we use a short run which allows us to conduct multiple experiments. Concretely, we use patch-dropout 0.5 and 20k iterations. The first 5k iterations are linear warmup while the remainder are cosine decay. We follow the CLIP paper in that i) we do not use gradient clipping unless otherwise mentioned, though we do clip the $\mathsf{logit\_scale}$ parameter, and ii) we add a layer-norm after the patch embedding and before the main transformer. (It is possible that CLIP was trained with gradient clipping even though the paper does not mention it. Our baseline follows the OpenCLIP library, which does not use gradient clipping by default since it follows what is described in Radford et al.) Unless otherwise mentioned, experiments use batch size 16384 (per-GPU batch size of 256), learning rate 2e-3, and weight decay 0.2. We initially tried adding a layer-norm before the patch embedding as in prior work, but removed it because we found it to hurt performance at CLIP ViT-Huge scale.

3.3 Loss spikes increase with model size, batch size, and learning rate

We begin our study of loss spikes by observing how their presence varies when changing model size, batch size, and learning rate. The following sections build on these observations, in particular the finding that lowering the AdamW $\beta_2$ hyperparameter removes spikes entirely.

We find that loss spikes increase when increasing model size (Figure 6), batch size (Figure 7), or learning rate (Figure 8). However, we also find that loss spikes can be avoided by reducing the AdamW $\beta_2$ hyperparameter. On the other hand, if $\beta_2$ is reduced too much, learning is slowed, which results in worse performance.

3.4 Loss spikes and an out-of-date AdamW second moment estimator

Based on the observation in the previous section that lowering $\beta_2$ reduces spikes, this section traces the cause of loss spikes to an out-of-date second moment estimator in the patch embedding layer.

Overview. Adaptive optimizers such as AdaGrad, Adam, or AdaFactor scale the update differently for each individual parameter. This is often conceptualized as a per-parameter learning rate. For instance, in Adam/AdamW, per-parameter updates are scaled by the inverse root of the exponential moving average of squared gradients (see the code for AdamW in Algorithm 2, ignoring for now the modifications in pink which we discuss in Section 3.5).

This adaptivity can be a very useful tool for accelerating training, but it can also cause issues when the learning signal changes. Concretely, exponential moving averages can become out of date, causing updates to be scaled by a value that is too large. This issue is discussed in Section 5 of Shazeer and Stern, and we summarize it below.
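
Ignoring bias correction and weight decay, AdamW maintains exponential moving averages of the gradient $g_t$ and of its square, and updates each parameter as

$$v_t = \beta_1 v_{t-1} + (1-\beta_1)\, g_t, \qquad u_t = \beta_2 u_{t-1} + (1-\beta_2)\, g_t^2, \qquad \theta_t = \theta_{t-1} - \alpha_t\, \frac{v_t}{\sqrt{u_t} + \epsilon},$$

where all operations are elementwise. The per-parameter scaling $1/(\sqrt{u_t}+\epsilon)$ works well as long as $u_t$ tracks the recent magnitude of $g_t^2$.
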
However, this method can break down when the learning signal changes and $u_t$ ceases to be a good estimator for the running average of $g_t^2$. Consider the case where the gradient magnitudes have historically been very small for some parameters, so $1/(\sqrt{u_t}+\epsilon)$ is large for those parameters. If those parameters then suddenly receive a larger gradient signal at iteration $t$, the update can be catastrophically big. We refer to this as the stuck-in-the-past scenario.

Overall, if $\beta_2$ is too small, convergence may be slowed. If $\beta_2$ is too large, $u_t$ can become out of date and no longer be a good estimator for $g_t^2$, resulting in per-parameter scaling that is too large.

Measurement. We now discuss measurement of the aforementioned stuck-in-the-past scenario and search for a predictive relationship between this event and a loss spike. We follow Shazeer and Stern and measure the root-mean-square quantity $\mathsf{RMS}_t = \sqrt{\mathbb{E}\left[g_t^2/u_t\right]}$. If $u_t$ is a good estimator for $g_t^2$, then the aggregate quantity $\mathsf{RMS}_t$ will be around 1. The stuck-in-the-past scenario described above corresponds to $\mathsf{RMS}_t \gg 1$.
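The quantities in this measurement are straightforward to compute. The sketch below uses hypothetical names, one flat list per tensor, and no bias correction (matching the notation above); it applies one EMA step and then evaluates $\mathsf{RMS}_t$. In the usage example, an estimator matched to old gradients of size 1e-3 meets a gradient of size 1, giving $\mathsf{RMS}_t$ close to $1/\sqrt{1-\beta_2}$, i.e., roughly 10 for $\beta_2 = 0.99$ and roughly 32 for $\beta_2 = 0.999$.

```python
import math


def update_second_moment(u, g, beta2):
    """Elementwise EMA step: u_t = beta2 * u_{t-1} + (1 - beta2) * g_t^2."""
    return [beta2 * ui + (1 - beta2) * gi * gi for ui, gi in zip(u, g)]


def rms_t(g, u):
    """RMS_t = sqrt(E[g_t^2 / u_t]), averaged over the entries of one tensor."""
    return math.sqrt(sum(gi * gi / ui for gi, ui in zip(g, u)) / len(g))


if __name__ == "__main__":
    # Stuck in the past: u matches old gradients of size 1e-3, then g_t jumps to 1.
    for beta2 in (0.95, 0.99, 0.999):
        u = update_second_moment([1e-6] * 4, [1.0] * 4, beta2)
        print(beta2, round(rms_t([1.0] * 4, u), 1))
```

Higher $\beta_2$ lets less of the new gradient into $u_t$ on the step where it arrives, so the same jump produces a larger $\mathsf{RMS}_t$, consistent with spikes disappearing at lower $\beta_2$.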

As illustrated in Figures 6-8, we observe instability for high $\beta_2$ in our experiments even though we use 5k iterations of warmup. While Shazeer and Stern first recognized the out-of-date second moment estimator issue, in their experimental setting they only observe instability without warmup.

We now aim to establish a predictive relationship between the stuck-in-the-past scenario and loss spikes. We present initial results in Figure 9, where we examine $\mathsf{RMS}_t$ for the vision transformer patch embedding layer, visual.conv1.weight. This means that the expectation is computed over parameters in visual.conv1.weight only. This figure illustrates a few important findings: i) loss spikes tend to follow 1-8 iterations after an RMS spike, ii) loss spikes slow learning as recovery time is required, and iii) $\mathsf{RMS}_t$ stays around 1 for lower $\beta_2$.

As this is just one example, we further elaborate on the predictive relationship between RMS spikes in the embedding layer and loss spikes in Appendix D, through Figures 16, 17, 18, 19, 20, and 21. For analysis purposes, we define a heuristic to characterize loss spikes and RMS spikes in visual.conv1.weight. We then show that 28 out of 30 detected loss spikes follow an RMS spike by 1-8 iterations, while the probability that a loss spike follows an RMS spike by chance is only 1%. Moreover, we find that the same predictive relationship does not exist for the RMS in other transformer layers.
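A sketch of the counting step, assuming spike iterations have already been detected. The spike-detection heuristic from Appendix D is not reproduced here, so the inputs are simply lists of iteration indices, and the function name is hypothetical. It reports the fraction of loss spikes that are preceded by an RMS spike 1 to 8 iterations earlier.

```python
def fraction_preceded(loss_spikes, rms_spikes, min_lag=1, max_lag=8):
    """Fraction of loss-spike iterations t such that some RMS spike occurred
    at iteration t - lag for a lag in [min_lag, max_lag]."""
    rms_set = set(rms_spikes)
    hits = sum(
        any((t - lag) in rms_set for lag in range(min_lag, max_lag + 1))
        for t in loss_spikes
    )
    return hits / len(loss_spikes) if loss_spikes else 0.0
```

With the counts reported above, this fraction is 28/30, or about 0.93.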

3.5 StableAdamW: AdamW with update clipping from AdaFactor

This section develops and tests StableAdamW (Algorithm 2), an AdamW-AdaFactor hybrid.

To stabilize training, the AdaFactor optimizer scales the learning rate for iteration $t$ by $1/\max(\mathsf{RMS}_t, 1)$, which Shazeer and Stern refer to as update clipping. (They actually introduce a hyperparameter $d$ and use $1/\max(\mathsf{RMS}_t/d, 1)$, but recommend setting $d=1$, which we follow.) The effect is to slow training when $u_t$ is no longer a good estimator for $g_t^2$.

As discussed in Section 3.4, our stability issues can be traced to an out-of-date $u_t$, which is what led Shazeer and Stern to update clipping, even though their stability issues are also solved with warmup. Therefore, we port update clipping to the standard AdamW optimizer with $d=1$ and refer to the resulting AdamW-AdaFactor hybrid as StableAdamW (Algorithm 2). One modification we make, for implementation convenience, is to compute $\max(\mathsf{RMS}_t, 1)$ and divide the learning rate by it independently for each tensor. This means that the expectation is computed independently for each layer, producing a different $\mathsf{RMS}_t$ per layer.
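The following pure-Python sketch shows one StableAdamW-style step for a single tensor. It is an illustration under assumptions, not the released implementation: it uses standard AdamW bias correction, computes $\mathsf{RMS}_t$ from the bias-corrected second moment with a floor of $\epsilon^2$, and applies the clipped learning rate to both the Adam step and the decoupled weight decay. The function name and default hyperparameters are placeholders (the weight decay of 0.2 matches the experimental setup above).

```python
import math


def stable_adamw_step(param, grad, state, lr, beta1=0.9, beta2=0.99,
                      eps=1e-6, weight_decay=0.2):
    """One StableAdamW-style update of a single tensor (flat list), in place.

    AdamW plus AdaFactor-style update clipping with d = 1: the learning rate
    is divided by max(1, RMS_t), where RMS_t is computed over this tensor only.
    Returns the effective (clipped) learning rate.
    """
    t = state["step"] = state.get("step", 0) + 1
    m = state.setdefault("m", [0.0] * len(param))   # first moment
    u = state.setdefault("u", [0.0] * len(param))   # second moment
    for i, g in enumerate(grad):
        m[i] = beta1 * m[i] + (1 - beta1) * g
        u[i] = beta2 * u[i] + (1 - beta2) * g * g
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    u_hat = [ui / (1 - beta2 ** t) for ui in u]
    # Update clipping: RMS_t = sqrt(E[g_t^2 / max(u_hat, eps^2)]) over this tensor.
    rms = math.sqrt(sum(g * g / max(uh, eps * eps)
                        for g, uh in zip(grad, u_hat)) / len(grad))
    eta = lr / max(1.0, rms)
    for i in range(len(param)):
        param[i] -= eta * weight_decay * param[i]                  # decoupled weight decay
        param[i] -= eta * m_hat[i] / (math.sqrt(u_hat[i]) + eps)   # Adam step
    return eta
```

On the first step the bias-corrected second moment equals $g_t^2$, so $\mathsf{RMS}_t = 1$ and nothing is clipped; after a long run of small gradients, a sudden large gradient yields $\mathsf{RMS}_t \gg 1$ and the effective learning rate for that tensor drops accordingly.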

We now test how StableAdamW compares with other stability interventions such as gradient clipping (we clip at global norm 1; we observed instability when trying 2 instead of 1 and did not tune this further, but note that 1.0 is standard in, e.g., PaLM and Scaling Vision Transformers) or lowering $\beta_2$. These results, presented in Figure 10, show that StableAdamW (i.e., AdamW + update clipping) outperforms these interventions for CLIP ViT-Huge. While gradient clipping and update clipping both remove instability, update clipping performs better in terms of zero-shot ImageNet accuracy. With update or gradient clipping, higher $\beta_2$ such as 0.99 tends to perform better.
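For comparison, global-norm gradient clipping rescales every gradient by one shared factor, whereas update clipping rescales the learning rate per tensor based on $\mathsf{RMS}_t$. A pure-Python sketch of the former follows; it mirrors the behavior of PyTorch's torch.nn.utils.clip_grad_norm_, including its small 1e-6 stabilizer, but it is not that function, and the name is hypothetical.

```python
import math


def clip_grad_global_norm(grads, max_norm=1.0):
    """Rescale a list of gradient tensors (flat lists) so their joint L2 norm
    is at most max_norm. Returns (clipped_grads, total_norm_before_clipping)."""
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    coef = max_norm / (total + 1e-6)   # small constant avoids division by zero
    if coef < 1.0:
        grads = [[g * coef for g in tensor] for tensor in grads]
    return grads, total
```

Because the norm is global, a burst in one layer (such as the patch embedding) shrinks the updates of every layer, while update clipping only slows the tensor whose second moment estimate is out of date.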

Appendix E provides further commentary and implementation considerations for StableAdamW.

6 Loss spikes and the loss scalar

This final Section ties the low-precision training results of Section 2 to our investigation into stability. Overall, we find that loss spikes can co-occur with large activations and gradients. Large activations and gradients may cause issues during low-precision training due to the more limited representable range. Therefore, reducing loss spikes is an important step for successful low-precision training.

Supporting data is illustrated by Figure 11, in which an RMS spike precedes a loss spike that coincides with spikes in the activations (i.e., features) and gradients. As we have previously seen (Figure 5), high feature magnitudes can pose challenges for low-precision training. Moreover, the spikes in the gradient are so large that Inf/NaN values occur, which causes the loss scalar to drop many times. There are a few takeaways from this observation. First, reducing loss spikes is an important step toward enabling low-precision training. Second, spikes in gradient magnitude can be transient, so we may be adjusting the loss scalar too often: with the PyTorch default loss scalar, thousands of iterations would be required before the loss scalar recovered to its value before this event. Finally, the layers highlighted in this figure are the main layers where Inf/NaN values are encountered. Concretely, although we only track every tenth block, we never observe any Inf/NaN in any transformer block after block 0. However, with the PyTorch default loss scalar, an Inf/NaN in a single layer skips the update for the whole network.
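The recovery cost mentioned above can be estimated with simple arithmetic. The sketch below assumes the default dynamic-scaling parameters of PyTorch's `torch.cuda.amp.GradScaler` (`growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`); the helper function and the example number of drops are hypothetical.

```python
import math


def iterations_to_recover(num_drops, growth_factor=2.0, backoff_factor=0.5,
                          growth_interval=2000):
    """Inf/NaN-free iterations a dynamic loss scaler needs to undo `num_drops` backoffs.

    Defaults match torch.cuda.amp.GradScaler: the scale is multiplied by
    backoff_factor on every Inf/NaN step, and multiplied by growth_factor only
    after growth_interval consecutive steps without Inf/NaN.
    """
    # Growth events needed to undo a single backoff (1.0 with the defaults).
    growths_per_drop = math.log(1 / backoff_factor) / math.log(growth_factor)
    # Small tolerance guards against floating-point round-up in ceil.
    growth_events = math.ceil(num_drops * growths_per_drop - 1e-9)
    return growth_events * growth_interval


# Hypothetical: a gradient spike triggers 5 consecutive Inf/NaN steps.
print(iterations_to_recover(5))  # 10000
```

Since every halving must be undone by a separate growth step, and each growth step needs 2000 consecutive clean iterations, even a handful of drops costs thousands of iterations at a reduced scale.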

This motivates the loss scalar that we use in our experiments whenever one is required (except in Figure 11). Our loss scalar i) checks for Inf/NaN at the individual tensor level and skips the update at the tensor level rather than globally, and ii) remains fixed at its initial value.
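A minimal sketch of such a loss scalar is shown below, assuming standard PyTorch parameters and optimizers. It is not the authors' implementation: the class name, the method names, and the default scale of $2^{16}$ are hypothetical. Skipping is implemented by setting a tensor's `.grad` to `None`, which `torch.optim` optimizers such as `AdamW` treat as "no update for this parameter".

```python
import torch


class TensorwiseFixedLossScaler:
    """Hypothetical sketch: fixed loss scale, Inf/NaN handled per tensor.

    i) Inf/NaN is checked for each parameter tensor separately, and only that
       tensor's update is skipped; ii) the scale never changes.
    """

    def __init__(self, scale=2.0 ** 16):
        self.scale = scale  # fixed at its initial value

    def scale_loss(self, loss):
        return loss * self.scale

    def unscale_(self, params):
        """Unscale gradients in place; drop non-finite ones. Returns #skipped tensors."""
        skipped = 0
        for p in params:
            if p.grad is None:
                continue
            if not torch.isfinite(p.grad).all():
                # torch.optim optimizers skip parameters whose .grad is None,
                # so only this tensor misses the update.
                p.grad = None
                skipped += 1
            else:
                p.grad.div_(self.scale)
        return skipped
```

In a training loop, one would call `scaler.scale_loss(loss).backward()`, then `scaler.unscale_(model.parameters())`, then `optimizer.step()`. Because the scale is fixed, a transient gradient spike affects only the tensors where it produces Inf/NaN and leaves the scale unchanged for subsequent iterations.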

This scalar enables fp16 mixed-precision training of CLIP models at ViT-Huge scale, where previously the scalar became too low and training diverged. We also believe an adaptive block-wise scalar as in Ramesh et al. would remedy this issue. Interestingly, when we observe an Inf/NaN, it is often in the patch embedding layer. Therefore, when Inf/NaN values occur frequently, our scalar effectively recovers the stability solution of Chen et al., which is to freeze the embedding layer. As a final remark, we note that loss spikes do not always cause the loss scalar to drop, and we emphasize that the loss scalar can drop for reasons other than loss spikes. Figure 11 is only an existence example showing that loss spikes can result in activation spikes and Inf/NaN gradients.

Conclusion

In summary, we have shared experiments on accelerating and stabilizing large multi-modal model training, which we believe will be useful to the community. Moreover, we have shared resources such as Triton kernels to enable others to build on and improve our work. We believe the main limitation of our work is that it is non-exhaustive. First, we only simulate float8 training. Second, we do not examine the impact on stability of width scalings for initialization and training hyperparameters such as those examined in prior work. Finally, the checkpoints we produce have low accuracy. This is due to the limited compute budget we use: our aim was to study loss spikes across multiple trials at scale, not to attain competitive model performance. A redeeming aspect is that our early exploration informed the training hyperparameters that produced the highest-accuracy open-source CLIP model so far.

For insightful discussions we thank Romain Beaumont, Yair Carmon, Mehdi Cherti, Brian Cheung, Alex Fang, Gabriel Ilharco, Jenia Jitsev, LAION, Sarah Pratt, Christoph Schuhmann, Ross Wightman, and Sho Yaida. We thank Emad Mostaque and stability.ai for compute resources.

This work is in part supported by NSF IIS 1652052, IIS 17303166, DARPA N66001-19-2-4031, DARPA W911NF-15-1-0543 and gifts from Allen Institute for Artificial Intelligence.

References

Appendix A Additional Related Work on Quantization

The literature on training neural networks in low-bit precision is vast. The main differentiating factor of our work is that we train relatively large models; in fact, we train the largest 8-bit vision transformers to date.

The literature agrees that quantization of very large networks is more difficult than that of smaller networks. As such, we divide our related work into two parts: (1) large-scale low-precision neural networks (larger than BERT-large), and (2) low-precision training of smaller networks.

Our work is currently the only work that performs low-precision (8-bit and below) training of very large networks with more than 230M parameters. Other related work studies inference at scale: SmoothQuant, ZeroQuant, NuQmm, and LLM.int8() study inference with Int8 matrix multiplication. Another line of work studies inference for large models with more than 250M parameters by considering 16-bit inputs and k-bit weights.

Training of small-scale low-precision neural networks can take many forms, such as quantization for integer-only devices, quantization for mobile devices, or quantization to accelerate training. One way to categorize these directions is by the data type used and the neural network trained. One major direction is to quantize convolutional neural networks, often for fast and memory-efficient use on edge devices. Further work in this area is discussed in a survey of the field. Another line of work centers on 8-bit float data types, which can be used to accelerate the training of neural networks. Lastly, a common application is to fine-tune BERT models in low precision on particular datasets (similar to training). This not only decreases the model footprint and increases inference speed but also adapts the model to new data.

Appendix B Additional code and figures

This Section provides additional pseudocode:

Algorithm 3 is the memory-efficient variant of SwitchBack.

Algorithm 4 is the variant of SwitchBack which uses row- and column-wise quantization for the weights.

Algorithm 5 is a standard linear layer implemented with $\mathsf{torch.autograd}$.

These implementations can be found at https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/nn/triton_based_modules.py. To use them in OpenCLIP training (https://github.com/mlfoundations/open_clip), add one of the following arguments:

--use-bnb-linear SwitchBackLinearGlobal for Algorithm 1.

--use-bnb-linear SwitchBackLinearGlobalMemEfficient for Algorithm 3.

--use-bnb-linear SwitchBackLinearVectorwise for Algorithm 4.

--use-bnb-linear StandardLinear for Algorithm 5.

B.2 Additional Figures

This section presents additional figures.

Figure 12 presents a more fine-grained version of Figure 3.

Figure 13 compares the speed-up of SwitchBack with that of LLM.int8().

Figure 14 shows the mean and max for the gradient and activation (i.e., feature) throughout training.

Figure 15 shows that using a schedule for $\beta_{2}$ of the form $1-\mathsf{iteration}^{-\lambda}$ does not improve accuracy.
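For reference, the $\beta_2$ schedule mentioned above can be written as a one-line function. The value $\lambda = 0.8$ is only a placeholder, since the value used for Figure 15 is not stated here; the function name is hypothetical.

```python
def beta2_schedule(iteration, lam=0.8):
    """beta2 at a given (1-indexed) iteration: 1 - iteration ** (-lam)."""
    return 1.0 - iteration ** (-lam)


# beta2 starts at 0 (so u_t = g_t^2 on the first step) and approaches 1 over training.
for it in (1, 10, 100, 10_000):
    print(it, round(beta2_schedule(it), 4))
```

Early in training, a small $\beta_2$ makes $u_t$ track recent squared gradients closely; as training proceeds, $\beta_2$ approaches 1 and $u_t$ averages over a longer history.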

Appendix C Analysis

This section shows that the error due to quantization increases with $k$, the inner-product dimension of a matrix multiplication. This suggests why SwitchBack may achieve high accuracy: we avoid quantizing matmuls for which $k$ is very large. For the weight gradient computation, which we leave in high precision, $k$ is batch size times sequence length, which is often $\approx 32000$ in our experiments. For the other matmuls, $k$ is at most $4\cdot\mathsf{embed\_dim}$, which is $\leq 8000$ in our experiments. These dimensions are standard for CLIP training experiments.
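The three matrix multiplications of a linear layer, and the inner dimension $k$ each one sums over, can be illustrated with a short PyTorch sketch using hypothetical small sizes. SwitchBack performs the forward and input-gradient matmuls in int8 and leaves the weight-gradient matmul, whose $k$ equals the number of tokens, in high precision.

```python
import torch

# Hypothetical small sizes; for the MLP of a ViT-Huge block these would be
# d_in = 1280 and d_out = 5120 (or the reverse), with tokens = batch x sequence.
tokens, d_in, d_out = 64, 32, 128

X = torch.randn(tokens, d_in)    # layer input
W = torch.randn(d_out, d_in)     # weight, laid out as in torch.nn.Linear
dY = torch.randn(tokens, d_out)  # gradient of the loss w.r.t. the layer output

Y = X @ W.t()    # forward pass: sums over d_in
dX = dY @ W      # input gradient (layer-to-layer backprop): sums over d_out
dW = dY.t() @ X  # weight gradient: sums over tokens

# k (the inner-product length) for each of the three matmuls of a linear layer.
inner_dims = {"forward": d_in, "input_grad": d_out, "weight_grad": tokens}
print(inner_dims)
```

Because `tokens` grows with batch size and sequence length while `d_in` and `d_out` are fixed by the architecture, the weight-gradient matmul is the one whose $k$ becomes very large in CLIP training.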

This section measures the variance due to quantization for the inner product between $u$ and $v$. Let $u$ and $v$ be vectors of length $k$ whose elements are all drawn independently, each from a distribution with mean 0. Let $u_{i}$ have variance $\sigma_{u}^{2}$ and $v_{i}$ have variance $\sigma_{v}^{2}$.

Next, let $\hat{u}$ and $\hat{v}$ be the quantized versions of $u$ and $v$, respectively. We model quantization error as $\hat{u}_{i}=u_{i}+\epsilon_{i}$ and $\hat{v}_{i}=v_{i}+\xi_{i}$, where $\epsilon_{i},\xi_{i}$ are i.i.d. mean-centered random variables with variance $\sigma_{q}^{2}$, independent of $u$ and $v$.

The aim of this section is to show that the variance due to quantization grows with $k$. Our analysis is conservative because we do not assume that the variance of $\epsilon_{i},\xi_{i}$ increases with $k$, though in practice we believe it would, since the absmax of $u$ and $v$ increases with $k$.

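We first examine the variance of $\hat{u}_{i}\hat{v}_{i}$. Since all random variables are mean-centered and independent, $\mathbb{E}[\hat{u}_{i}\hat{v}_{i}]=0$, so this variance is given by

$$\text{Var}\left(\hat{u}_{i}\hat{v}_{i}\right)=\mathbb{E}\left[\hat{u}_{i}^{2}\right]\mathbb{E}\left[\hat{v}_{i}^{2}\right]=\left(\sigma_{u}^{2}+\sigma_{q}^{2}\right)\left(\sigma_{v}^{2}+\sigma_{q}^{2}\right)=\sigma_{u}^{2}\sigma_{v}^{2}+\sigma_{q}^{2}\left(\sigma_{u}^{2}+\sigma_{v}^{2}\right)+\sigma_{q}^{4}.$$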

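Next, we use the additivity of variance for independent random variables to calculate $\text{Var}\left(\langle\hat{u},\hat{v}\rangle\right)$. This is given by

$$\text{Var}\left(\langle\hat{u},\hat{v}\rangle\right)=\sum_{i=1}^{k}\text{Var}\left(\hat{u}_{i}\hat{v}_{i}\right)=k\,\sigma_{u}^{2}\sigma_{v}^{2}+k\,\sigma_{q}^{2}\left(\sigma_{u}^{2}+\sigma_{v}^{2}+\sigma_{q}^{2}\right).$$

The first term, $k\,\sigma_{u}^{2}\sigma_{v}^{2}$, is the variance of the full-precision inner product $\langle u,v\rangle$; the second term is the variance due to quantization, which grows linearly with $k$.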

C.2 Takeaways

We have shown that for inner products of length-$k$ vectors, the variance due to quantization increases with $k$. This means the variance of output units/features due to quantization increases with $k$, which can be thought of as making the outputs noisier. Noise compounds throughout the network and will eventually drown out the useful signal: for large $k$, the network features or gradients will no longer lead to effective learning.
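The takeaway can be checked numerically. The sketch below is a small Monte Carlo simulation under our own assumptions: Gaussian vectors, symmetric absmax quantization to 127 levels with one scale per vector, and hypothetical helper names. Unlike the additive-noise model above, real absmax rounding has a step size that grows with the absmax, so we would expect the measured error variance to grow at least linearly in $k$.

```python
import torch


def absmax_quantize(x, levels=127):
    """Symmetric absmax int8-style quantize-dequantize, one scale per row."""
    scale = x.abs().amax(dim=-1, keepdim=True) / levels
    return torch.round(x / scale).clamp(-levels, levels) * scale


def quantization_error_variance(k, trials=500, seed=0):
    """Empirical variance of <u_hat, v_hat> - <u, v> for length-k Gaussian vectors."""
    gen = torch.Generator().manual_seed(seed)
    u = torch.randn(trials, k, generator=gen)
    v = torch.randn(trials, k, generator=gen)
    exact = (u * v).sum(dim=-1)
    quantized = (absmax_quantize(u) * absmax_quantize(v)).sum(dim=-1)
    return (quantized - exact).var().item()


for k in (256, 1024, 4096):
    print(k, quantization_error_variance(k))
```

Each row of `u` and `v` is one pair of vectors, so a single batched computation estimates the variance over `trials` independent inner products.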

C.3 Why LLM.int8() fails: LLMs vs CLIP models

This Section details our hypothesis for why SwitchBack outperforms LLM.int8() for CLIP training, which assumes that the analysis in Section C.1 is a good model for training.

Our analysis shows that, relative to full-precision matrix multiplication, the variance of the output features of a quantized matrix multiplication increases with the inner-product size. As such, transformers pretrained on text, such as GPT-3 or LLaMA, may have different failure modes than CLIP models.

Pretrained large language models (LLMs) tend to have larger weight matrices relative to their batch sizes than CLIP models, since CLIP models perform best when the batch size is large. As a consequence, LLMs and CLIP models have their noisiest operations in different matrix multiplications. LLMs are noisiest in the forward pass $XW^{T}$ and in layer-to-layer backpropagation $\dot{Y}_{k}W_{k}=\dot{X}_{k-1}$, where the inner-product dimensions are large: for example, they are 32768 and 8192 for the output projection of LLaMA 65B. In contrast, the inner-product size of the weight gradient is determined by the per-GPU batch size, which is 2048 for LLaMA (4M tokens per full batch distributed across 2048 GPUs). As such, if quantization produces the same error variance in every operation, the weight gradient in LLM int8 training is between 4x and 16x less noisy than the other two operations, provided the analysis in Section C.1 is a good model for training.

For CLIP training with ViT-Huge, we have 65536 tokens per GPU (256 images of size 224×224 with patch size 14×14, giving 16×16 = 256 patches per image and 65536 patches per GPU). The dimensions of the weight matrices are $1280\times 5120$. As such, analogous to the LLaMA LLM above, the weight gradient in CLIP models is between 12.8x and 51.2x noisier than the forward and layer-to-layer backpropagation operations, if the analysis in Section C.1 is a good model for training. Notice that the CLIP weight gradient is twice as noisy as the noisiest LLaMA 65B operations (65536 vs. 32768) if we assume that all quantization operations have the same error variance.
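The ratios above follow from dividing inner-product sizes, as the short Python sketch below shows. The numbers are taken from the text; the helper name is ours, and treating the ratio of inner-product sizes as the ratio of noise assumes the same per-element quantization error variance for every matmul, as in Section C.1.

```python
# Inner-product sizes from the text. Under the model of Section C.1, the
# quantization noise variance of a matmul scales with its inner-product size k.
llama_tokens_per_gpu = 2 ** 22 // 2048         # 4M tokens over 2048 GPUs -> 2048
llama_other_ks = (32768, 8192)                 # forward / layer-to-layer backprop

patches_per_image = (224 // 14) ** 2           # 16 x 16 = 256 patches
clip_tokens_per_gpu = 256 * patches_per_image  # 256 images -> 65536
clip_other_ks = (1280, 5120)                   # ViT-Huge weight dimensions


def weight_grad_vs_others(weight_grad_k, other_ks):
    """How many times noisier the weight gradient is than each other matmul."""
    return [weight_grad_k / k for k in other_ks]


print(weight_grad_vs_others(llama_tokens_per_gpu, llama_other_ks))  # [0.0625, 0.25]
print(weight_grad_vs_others(clip_tokens_per_gpu, clip_other_ks))    # [51.2, 12.8]
print(clip_tokens_per_gpu / max(llama_other_ks))                    # 2.0
```

For LLaMA the ratios are below 1 (the weight gradient is 16x and 4x less noisy), while for CLIP they are far above 1, which is why the two model families need high-precision routines in different places.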

As such, low-precision LLM training and low-precision CLIP training require high-precision quantization routines for different parts of training.

This also explains why we believe LLM.int8() fails for CLIP training despite preserving inference performance: the weight gradient in CLIP training is a highly noisy operation, which might not give SGD enough signal to converge to a local minimum.

Appendix D RMS Spikes precede Loss Spikes

This section further elaborates on the predictive relationship between an RMS spike in the embedding layer and a loss spike, as in Figure 9.

We define heuristics to characterize loss and RMS spikes, which we use for analysis. We determined these heuristics by checking whether they qualitatively coincided with what appeared to be a loss spike. We display results in this Section so that the reader can also evaluate whether these heuristics appear reasonable.

We define RMS spike events as $\{t:\mathsf{RMS}_{t}\geq 2.3\}$, while loss spike events are defined as the set of $t$ where the loss at time $t$ exceeds the running mean by $3.2$ times the running standard deviation. Finally, we ignore the first 1000 iterations, when the learning rate is low.

We also deduplicate RMS and loss spike iterations as follows: multiple spikes within a short interval of 10 iterations are counted as a single spike starting at the earliest time. Moreover, we only count a loss spike if there are multiple deviations within an interval of 10 iterations, which indicates that the loss has meaningfully spiked.
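The spike heuristics above can be sketched in plain Python as follows. The window used for the running mean and standard deviation is not specified in the text, so `window=100` over the preceding iterations is an assumption, as are the exact grouping rule (events within 10 iterations of a group's first event are merged) and the function names.

```python
def group_events(times, interval=10):
    """Merge events within `interval` iterations of a group's first event.

    Returns (start_iteration, num_raw_events) pairs; each group starts at its
    earliest event.
    """
    groups = []
    for t in sorted(times):
        if groups and t - groups[-1][0] < interval:
            groups[-1][1] += 1
        else:
            groups.append([t, 1])
    return [(start, count) for start, count in groups]


def rms_spikes(rms, threshold=2.3, skip=1000, interval=10):
    """Iterations t >= skip with RMS_t >= threshold, deduplicated."""
    raw = [t for t, r in enumerate(rms) if t >= skip and r >= threshold]
    return [start for start, _ in group_events(raw, interval)]


def loss_spikes(loss, num_std=3.2, window=100, skip=1000, interval=10):
    """Iterations where loss exceeds the running mean by num_std running stds.

    Assumption: running statistics use the previous `window` iterations.
    """
    raw = []
    for t in range(max(skip, window), len(loss)):
        past = loss[t - window:t]
        mean = sum(past) / window
        std = (sum((x - mean) ** 2 for x in past) / window) ** 0.5
        if loss[t] > mean + num_std * std:
            raw.append(t)
    # A single deviation is not enough: require multiple within the interval.
    return [start for start, count in group_events(raw, interval) if count > 1]
```

The `count > 1` filter implements the rule that an isolated deviation is not counted as a loss spike, while the shared `group_events` helper applies the same deduplication to both RMS and loss events.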

Figure 16 shows that out of 15 total loss spikes for ViT-Huge across different $\beta_{2}$, 14 come 1-8 iterations after an RMS spike in the patch embedding layer ($\mathsf{module.conv1.weight}$). With only 76 total RMS spike events, the probability that a loss spike follows an RMS spike by 1-8 iterations by chance is $<1\%$.
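The chance probability quoted above can be estimated in several ways. The sketch below shows one plausible approach, which is our assumption rather than the paper's stated procedure: count the fraction of eligible iterations that lie 1-8 iterations after some RMS spike. The function names and any example inputs are hypothetical.

```python
def follows(loss_t, rms_times, lo=1, hi=8):
    """True if the loss spike at loss_t comes lo..hi iterations after an RMS spike."""
    return any(lo <= loss_t - r <= hi for r in rms_times)


def chance_probability(rms_times, num_iterations, skip=1000, lo=1, hi=8):
    """Fraction of eligible iterations that lie lo..hi iterations after an RMS spike.

    This is the probability that a loss spike placed uniformly at random
    (after the first `skip` iterations) would appear to follow an RMS spike.
    """
    covered = set()
    for r in rms_times:
        covered.update(range(r + lo, r + hi + 1))
    hits = sum(1 for t in covered if skip <= t < num_iterations)
    return hits / (num_iterations - skip)
```

Using a set means overlapping windows from nearby RMS spikes are not double-counted, so the estimate does not overstate the chance rate.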

Figure 17 repeats this analysis for ViT-Large, wherein 13 out of 15 loss spikes follow an RMS spike by 1-8 iterations. The probability that a loss spike follows an RMS spike by chance is $1.0\%$.

Figure 18 zooms in on Figure 16 to show additional detail.

Figures 19 and 20 examine the loss spikes in Figures 16 and 17 that do not follow an RMS spike, finding that they are mainly failures of the heuristic that identifies loss spikes, i.e., false-positive loss spikes.

Finally, Figure 21 repeats Figure 16 but examines the $\mathsf{RMS}$ of a random layer in the middle of the transformer rather than the patch embedding layer. In this case, none of the loss spikes follow RMS spikes.

Appendix E StableAdamW continued

This Section asks and answers a series of questions the reader may have concerning Section 5.

First, why not just use AdaFactor? The answer is that the community has moved away from AdaFactor, finding that it under-performs AdamW at scale. We believe this is likely due to the factored moments, and not to other features such as update clipping. The goal of this work is to advocate using a hybrid. We tried porting other features from AdaFactor to AdamW, such as the $\beta_{2}$ schedule, but did not find them to help (Figure 15). Moreover, while PaLM uses an AdaFactor-AdamW hybrid, we believe it does not use update clipping.

Another question is: why not use an optimizer such as Lion, which does not divide updates by any value and is therefore immune to the stuck-in-the-past scenario? We believe this may be a promising path forward. However, while we observe that Lion outperforms AdamW at small scale, Lion still slightly under-performs AdamW at CLIP ViT-Huge scale in our experiments. (This may be out of date; for the latest results, see https://github.com/mlfoundations/open_clip/pull/432.)

A final question is: why use $g_{t}^{2}$ in the numerator when computing $\mathsf{RMS}_{t}$, and not $v_{t}^{2}$? We also tried $v_{t}^{2}$ and found performance to be worse.

E.2 Implementation considerations

To prevent divide-by-zero issues when computing $\mathsf{RMS}_{t}$, we compute $\mathsf{RMS}_{t}=\sqrt{\mathbb{E}\left[g_{t}^{2}/\mathsf{maximum}(u_{t},\epsilon^{2})\right]}$, where $\epsilon$ is the AdamW hyperparameter, for which we use 1e-6, and $\mathsf{maximum}$ is an elementwise maximum. This is instead of $\mathsf{RMS}_{t}=\sqrt{\mathbb{E}\left[g_{t}^{2}/u_{t}\right]}$.
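The effect of the $\mathsf{maximum}(u_t,\epsilon^2)$ guard can be seen on a tiny hypothetical tensor in which some entries of $u_t$ are still exactly zero. The values below are illustrative only.

```python
import torch

eps = 1e-6  # AdamW epsilon, as in the text
g = torch.tensor([0.0, 1e-3, 2.0])  # current gradient g_t
u = torch.tensor([0.0, 0.0, 4.0])   # hypothetical u_t with entries still at zero

naive = torch.sqrt(torch.mean(g.pow(2) / u))  # 0/0 = nan and 1e-6/0 = inf -> nan
guarded = torch.sqrt(torch.mean(g.pow(2) / torch.maximum(u, torch.full_like(u, eps ** 2))))

print(naive.item(), guarded.item())
```

The unguarded version evaluates to NaN. The guarded version stays finite; because an entry with $u_t = 0$ that suddenly receives a gradient yields a very large $\mathsf{RMS}_t$, update clipping then sharply reduces that tensor's learning rate.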