FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Introduction
Transformer models have emerged as the most widely used architecture in applications such as natural language processing and image classification. Transformers have grown larger and deeper, but equipping them with longer context remains difficult, since the self-attention module at their heart has time and memory complexity quadratic in sequence length. An important question is whether making attention faster and more memory-efficient can help Transformer models address their runtime and memory challenges for long sequences.
Many approximate attention methods have aimed to reduce the compute and memory requirements of attention. These methods range from sparse approximation to low-rank approximation, and their combinations. Although these methods reduce the compute requirements to linear or near-linear in sequence length, many of them do not display wall-clock speedup against standard attention and have not gained wide adoption. One main reason is that they focus on FLOP reduction (which may not correlate with wall-clock speed) and tend to ignore overheads from memory access (IO).
In this paper, we argue that a missing principle is making attention algorithms IO-aware, that is, carefully accounting for reads and writes to different levels of fast and slow memory (e.g., between fast GPU on-chip SRAM and relatively slow GPU high bandwidth memory, or HBM; Figure 1 left). On modern GPUs, compute speed has outpaced memory speed, and most operations in Transformers are bottlenecked by memory accesses. IO-aware algorithms have been critical for similar memory-bound operations, where reading and writing data can account for a large portion of the runtime, such as database joins, image processing, numerical linear algebra, and more. However, common Python interfaces to deep learning such as PyTorch and TensorFlow do not allow fine-grained control of memory access.
We propose FlashAttention, a new attention algorithm that computes exact attention with far fewer memory accesses. Our main goal is to avoid reading and writing the attention matrix to and from HBM. This requires (i) computing the softmax reduction without access to the whole input, and (ii) not storing the large intermediate attention matrix for the backward pass. We apply two well-established techniques to address these challenges. (i) We restructure the attention computation to split the input into blocks and make several passes over input blocks, thus incrementally performing the softmax reduction (also known as tiling). (ii) We store the softmax normalization factor from the forward pass to quickly recompute attention on-chip in the backward pass, which is faster than the standard approach of reading the intermediate attention matrix from HBM. We implement FlashAttention in CUDA to achieve fine-grained control over memory access and fuse all the attention operations into one GPU kernel. Even with the increased FLOPs due to recomputation, our algorithm both runs faster (up to 7.6× on GPT-2; Figure 1 right) and uses less memory (linear in sequence length) than standard attention, thanks to the massively reduced amount of HBM access.
We analyze the IO complexity of FlashAttention, proving that it requires $O(N^2 d^2 M^{-1})$ HBM accesses, where $N$ is the sequence length, $d$ is the head dimension, and $M$ is the size of SRAM, as compared to $\Omega(Nd + N^2)$ for standard attention. For typical values of $d$ and $M$, FlashAttention requires many times fewer HBM accesses than standard attention (up to 9× fewer, as shown in Fig. 2). Moreover, we provide a lower bound, showing that no exact attention algorithm can asymptotically improve on the number of HBM accesses over all SRAM sizes.
We also show that FlashAttention can serve as a useful primitive for realizing the potential of approximate attention algorithms by overcoming their issues with memory access overhead. As a proof of concept, we implement block-sparse FlashAttention, a sparse attention algorithm that is 2-4× faster than even FlashAttention, scaling up to sequence length of 64K. We prove that block-sparse FlashAttention has better IO complexity than FlashAttention by a factor proportional to the sparsity ratio. We discuss further extensions to other operations (attention on multi-GPU, kernel regression, block-sparse matrix multiply) in Section 5. We open-source FlashAttention to make it easier to build on this primitive (code available at https://github.com/HazyResearch/flash-attention).
We empirically validate that FlashAttention speeds up model training and improves model quality by modeling longer context. We also benchmark the runtime and memory footprint of FlashAttention and block-sparse FlashAttention compared to prior attention implementations.
Faster Model Training. FlashAttention trains Transformer models faster in wall-clock time. We train BERT-large (seq. length 512) 15% faster than the training speed record in MLPerf 1.1, GPT-2 (seq. length 1K) 3× faster than baseline implementations from HuggingFace and Megatron-LM, and long-range arena (seq. length 1K-4K) 2.4× faster than baselines.
Higher Quality Models. FlashAttention scales Transformers to longer sequences, which improves their quality and enables new capabilities. We observe a 0.7 improvement in perplexity on GPT-2 and 6.4 points of lift from modeling longer sequences on long-document classification. FlashAttention enables the first Transformer that can achieve better-than-chance performance on the Path-X challenge, solely from using a longer sequence length (16K). Block-sparse FlashAttention enables a Transformer to scale to even longer sequences (64K), resulting in the first model that can achieve better-than-chance performance on Path-256.
Benchmarking Attention. FlashAttention is up to 3× faster than the standard attention implementation across common sequence lengths from 128 to 2K and scales up to 64K. Up to sequence length of 512, FlashAttention is both faster and more memory-efficient than any existing attention method, whereas for sequence lengths beyond 1K, some approximate attention methods (e.g., Linformer) start to become faster. On the other hand, block-sparse FlashAttention is faster than all existing approximate attention methods that we know of.
Background
We provide some background on the performance characteristics of common deep learning operations on modern hardware (GPUs). We also describe the standard implementation of attention.
We focus here on GPUs. Performance on other hardware accelerators is similar.
GPU Memory Hierarchy. The GPU memory hierarchy (Fig. 1 left) comprises multiple forms of memory of different sizes and speeds, with smaller memory being faster. As an example, the A100 GPU has 40-80GB of high bandwidth memory (HBM) with bandwidth 1.5-2.0TB/s and 192KB of on-chip SRAM on each of its 108 streaming multiprocessors, with bandwidth estimated at around 19TB/s. The on-chip SRAM is an order of magnitude faster than HBM but many orders of magnitude smaller in size. As compute has gotten faster relative to memory speed, operations are increasingly bottlenecked by memory (HBM) accesses. Thus exploiting fast SRAM becomes more important.
Execution Model. GPUs have a massive number of threads to execute an operation (called a kernel). Each kernel loads inputs from HBM to registers and SRAM, computes, then writes outputs to HBM.
Performance characteristics. Depending on the balance of computation and memory accesses, operations can be classified as either compute-bound or memory-bound. This is commonly measured by the arithmetic intensity, which is the number of arithmetic operations per byte of memory access.
Compute-bound: the time taken by the operation is determined by how many arithmetic operations there are, while time accessing HBM is much smaller. Typical examples are matrix multiply with large inner dimension, and convolution with large number of channels.
Memory-bound: the time taken by the operation is determined by the number of memory accesses, while time spent in computation is much smaller. Examples include most other operations: elementwise (e.g., activation, dropout), and reduction (e.g., sum, softmax, batch norm, layer norm).
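To make the distinction concrete, the sketch below estimates arithmetic intensity for a matrix multiply and for an elementwise operation. The numbers are idealized under our own assumptions: fp16 data (2 bytes per element), each operand read from HBM exactly once, each result written exactly once, and textbook FLOP counts. The function names are hypothetical.

```python
# Minimal sketch: arithmetic intensity = FLOPs / bytes moved to and from HBM.
# Assumptions (ours, for illustration): fp16 elements, every operand read once,
# every result written once, 2*m*k*n FLOPs for a matmul, 1 FLOP per element otherwise.

BYTES_PER_ELEMENT = 2  # fp16


def matmul_intensity(m: int, k: int, n: int) -> float:
    """Intensity of C = A @ B with A (m x k), B (k x n), C (m x n)."""
    flops = 2 * m * k * n  # one multiply and one add per inner-product term
    bytes_moved = (m * k + k * n + m * n) * BYTES_PER_ELEMENT
    return flops / bytes_moved


def elementwise_intensity(num_elements: int, flops_per_element: int = 1) -> float:
    """Intensity of y = f(x) applied to every element (e.g., an activation)."""
    flops = num_elements * flops_per_element
    bytes_moved = 2 * num_elements * BYTES_PER_ELEMENT  # read x, write y
    return flops / bytes_moved


if __name__ == "__main__":
    # Large inner dimension: over a thousand FLOPs per byte, typically compute-bound.
    print(f"matmul 4096^3: {matmul_intensity(4096, 4096, 4096):.1f} FLOPs/byte")
    # Elementwise op: a fraction of a FLOP per byte, memory-bound.
    print(f"elementwise:   {elementwise_intensity(1 << 24):.2f} FLOPs/byte")
```

Whether an operation is compute-bound or memory-bound then depends on how its intensity compares with the GPU's ratio of peak FLOP/s to HBM bandwidth (the "ridge point" of the Roofline model).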
Kernel fusion. The most common approach to accelerate memory-bound operations is kernel fusion: if there are multiple operations applied to the same input, the input can be loaded once from HBM, instead of multiple times for each operation. Compilers can automatically fuse many elementwise operations. However, in the context of model training, the intermediate values still need to be written to HBM to save for the backward pass, reducing the effectiveness of naive kernel fusion.
2.2 Standard Attention Implementation
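Given input sequences $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$, where $N$ is the sequence length and $d$ is the head dimension, we want to compute the attention output $\mathbf{O} \in \mathbb{R}^{N \times d}$:
$$\mathbf{S} = \mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{N \times N}, \quad \mathbf{P} = \mathrm{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O} = \mathbf{P}\mathbf{V} \in \mathbb{R}^{N \times d},$$
where $\mathrm{softmax}$ is applied row-wise.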
Standard attention implementations materialize the matrices $\mathbf{S}$ and $\mathbf{P}$ to HBM, which takes $O(N^2)$ memory. Often $N \gg d$ (e.g., for GPT-2, $N = 1024$ and $d = 64$). We describe the standard attention implementation in Algorithm 0. As some or most of the operations are memory-bound (e.g., softmax), the large number of memory accesses translates to slow wall-clock time.
This problem is exacerbated by other elementwise operations applied to the attention matrix, such as masking applied to $\mathbf{S}$ or dropout applied to $\mathbf{P}$. As a result, there have been many attempts to fuse several elementwise operations, such as fusing masking with softmax.
In Section 3.2, we will show that the standard attention implementation performs HBM accesses quadratic in the sequence length $N$. We also compare the number of FLOPs and number of HBM accesses of standard attention and of our method (FlashAttention).
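For reference, here is a minimal pure-Python sketch of the standard computation: form the full $N \times N$ score matrix $\mathbf{S}$, turn it into the full $N \times N$ matrix $\mathbf{P}$, then multiply by $\mathbf{V}$. Python lists stand in for HBM, so the sketch only illustrates which intermediates are materialized, not actual memory traffic. The helper names are ours, and, like the equations above, it omits the usual $1/\sqrt{d}$ scaling.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a_it * b[t][j] for t, a_it in enumerate(row))
             for j in range(len(b[0]))] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(s: Matrix) -> Matrix:
    """Row-wise softmax with the usual max-shift for numerical stability."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        total = sum(e)
        out.append([x / total for x in e])
    return out


def standard_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    s = matmul(q, transpose(k))  # full N x N score matrix S, kept in memory
    p = softmax_rows(s)          # full N x N matrix P, a second N^2 buffer
    return matmul(p, v)          # O = P V
```

Both `s` and `p` hold $N^2$ numbers each; in a GPU implementation that writes them to HBM and reads them back, this is exactly the quadratic traffic the text describes.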
FlashAttention: Algorithm, Analysis, and Extensions
We show how to compute exact attention with fewer HBM reads/writes and without storing large intermediate matrices for the backward pass. This yields an attention algorithm that is both memory-efficient and faster in wall-clock time. We analyze its IO complexity, showing that our method requires far fewer HBM accesses than standard attention. We further show that FlashAttention can serve as a useful primitive by extending it to handle block-sparse attention.
We focus here on the forward pass for ease of exposition; Appendix B contains details for the backward.
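Given the inputs $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$ in HBM, we aim to compute the attention output $\mathbf{O} \in \mathbb{R}^{N \times d}$ and write it to HBM. Our goal is to reduce the amount of HBM accesses (to sub-quadratic in $N$).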
We apply two established techniques (tiling, recomputation) to overcome the technical challenge of computing exact attention in sub-quadratic HBM accesses. We describe this in Algorithm 1. The main idea is that we split the inputs into blocks, load them from slow HBM to fast SRAM, then compute the attention output with respect to those blocks. By scaling the output of each block by the right normalization factor before adding them up, we get the correct result at the end.
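Tiling. We compute attention by blocks. Softmax couples columns of $\mathbf{K}$, so we decompose the large softmax with scaling. For numerical stability, the softmax of vector $x \in \mathbb{R}^{B}$ is computed as:
$$m(x) := \max_i x_i, \quad f(x) := \begin{bmatrix} e^{x_1 - m(x)} & \dots & e^{x_B - m(x)} \end{bmatrix}, \quad \ell(x) := \sum_i f(x)_i, \quad \mathrm{softmax}(x) := \frac{f(x)}{\ell(x)}.$$
For vectors $x^{(1)}, x^{(2)} \in \mathbb{R}^{B}$, we can decompose the softmax of the concatenated $x = \begin{bmatrix} x^{(1)} & x^{(2)} \end{bmatrix} \in \mathbb{R}^{2B}$ as:
$$m(x) = m(\begin{bmatrix} x^{(1)} & x^{(2)} \end{bmatrix}) = \max(m(x^{(1)}), m(x^{(2)})), \quad f(x) = \begin{bmatrix} e^{m(x^{(1)}) - m(x)} f(x^{(1)}) & e^{m(x^{(2)}) - m(x)} f(x^{(2)}) \end{bmatrix},$$
$$\ell(x) = \ell(\begin{bmatrix} x^{(1)} & x^{(2)} \end{bmatrix}) = e^{m(x^{(1)}) - m(x)} \ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)} \ell(x^{(2)}), \quad \mathrm{softmax}(x) = \frac{f(x)}{\ell(x)}.$$
Therefore if we keep track of some extra statistics ($m(x), \ell(x)$), we can compute softmax one block at a time (this style of aggregation is called algebraic aggregation). We thus split the inputs $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ into blocks (Algorithm 1 line 3), compute the softmax values along with extra statistics (Algorithm 1 line 10), and combine the results (Algorithm 1 line 12).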
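The block decomposition of softmax can be checked numerically. The sketch below computes $(m, f, \ell)$ per block and merges blocks with the rescaling factors $e^{m(x^{(i)}) - m(x)}$. Function names are ours, and, purely so the result can be compared against the one-shot softmax, it keeps $f(x)$ for every block; Algorithm 1 instead folds each block straight into the output and keeps only $m$ and $\ell$.

```python
import math
from typing import List, Optional, Tuple

# (m(x), f(x), l(x)) for one block, using the names from the text.
Stats = Tuple[float, List[float], float]


def block_stats(x: List[float]) -> Stats:
    m = max(x)                          # m(x): max-shift for stability
    f = [math.exp(xi - m) for xi in x]  # f(x): shifted exponentials
    return m, f, sum(f)                 # l(x): their sum


def combine(s1: Stats, s2: Stats) -> Stats:
    """Statistics of the concatenation [x1 x2] from those of x1 and x2."""
    m1, f1, l1 = s1
    m2, f2, l2 = s2
    m = max(m1, m2)
    a1, a2 = math.exp(m1 - m), math.exp(m2 - m)  # e^{m(x^(i)) - m(x)}
    f = [a1 * v for v in f1] + [a2 * v for v in f2]
    return m, f, a1 * l1 + a2 * l2


def blockwise_softmax(x: List[float], block_size: int) -> List[float]:
    stats: Optional[Stats] = None
    for start in range(0, len(x), block_size):
        s = block_stats(x[start:start + block_size])
        stats = s if stats is None else combine(stats, s)
    m, f, l = stats
    return [v / l for v in f]
```

For example, `blockwise_softmax([1000.0, 1001.0, -5.0, 3.0], 2)` matches the single-block result, while a version without the max-shift would fail outright, since `math.exp(1000.0)` raises `OverflowError`.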
Recomputation. One of our goals is to not store $O(N^2)$ intermediate values for the backward pass. The backward pass typically requires the matrices $\mathbf{S}, \mathbf{P} \in \mathbb{R}^{N \times N}$ to compute the gradients with respect to $\mathbf{Q}, \mathbf{K}, \mathbf{V}$. However, by storing the output $\mathbf{O}$ and the softmax normalization statistics $(m, \ell)$, we can easily recompute the attention matrices $\mathbf{S}$ and $\mathbf{P}$ in the backward pass from blocks of $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ in SRAM. This can be seen as a form of selective gradient checkpointing. While gradient checkpointing has been suggested to reduce the maximum amount of memory required, all implementations (that we know of) have to trade speed for memory. In contrast, even with more FLOPs, our recomputation speeds up the backward pass due to reduced HBM accesses (Fig. 2). The full backward pass description is in Appendix B.
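The sketch below illustrates only the recomputation idea, not the gradient formulas of Appendix B: once the forward pass has kept the two numbers $(m_i, \ell_i)$ per query row, any block of $\mathbf{P}$ can be regenerated from the matching rows of $\mathbf{Q}$ and $\mathbf{K}$ as $\exp(\mathbf{S} - m)/\ell$. For brevity, `forward_row_stats` computes the statistics directly from full rows; in FlashAttention they fall out of the tiled forward pass. All names here are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def forward_row_stats(q: Matrix, k: Matrix) -> List[Tuple[float, float]]:
    """Per query row i, keep only (m_i, l_i): 2N numbers instead of N^2."""
    stats = []
    for q_r in q:
        s = [dot(q_r, k_c) for k_c in k]
        m = max(s)
        stats.append((m, sum(math.exp(x - m) for x in s)))
    return stats


def recompute_p_block(q: Matrix, k: Matrix, stats: List[Tuple[float, float]],
                      rows: Sequence[int], cols: Sequence[int]) -> Matrix:
    """Rebuild P[rows, cols] = exp(S - m) / l from Q, K and the saved stats."""
    return [[math.exp(dot(q[r], k[c]) - stats[r][0]) / stats[r][1] for c in cols]
            for r in rows]
```

The saved statistics are row-wise, so a block needs only its own rows of $\mathbf{Q}$ and columns of $\mathbf{S}$; no other part of $\mathbf{P}$ has to be present.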
Implementation details: Kernel fusion. Tiling enables us to implement our algorithm in one CUDA kernel: loading input from HBM, performing all the computation steps (matrix multiply, softmax, optionally masking and dropout, matrix multiply), then writing the result back to HBM (masking and dropout are described in Appendix B). This avoids repeatedly reading and writing inputs and outputs from and to HBM.
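Algorithm 1 FlashAttention
Require: Matrices $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$ in HBM, on-chip SRAM of size $M$.
1: Set block sizes $B_c = \lceil M/(4d) \rceil$, $B_r = \min(\lceil M/(4d) \rceil, d)$.
2: Initialize $\mathbf{O} = (0)_{N \times d} \in \mathbb{R}^{N \times d}$, $\ell = (0)_N \in \mathbb{R}^{N}$, $m = (-\infty)_N \in \mathbb{R}^{N}$ in HBM.
3: Divide $\mathbf{Q}$ into $T_r = \lceil N/B_r \rceil$ blocks $\mathbf{Q}_1, \dots, \mathbf{Q}_{T_r}$ of size $B_r \times d$ each, and divide $\mathbf{K}, \mathbf{V}$ into $T_c = \lceil N/B_c \rceil$ blocks $\mathbf{K}_1, \dots, \mathbf{K}_{T_c}$ and $\mathbf{V}_1, \dots, \mathbf{V}_{T_c}$, of size $B_c \times d$ each.
4: Divide $\mathbf{O}$ into $T_r$ blocks $\mathbf{O}_1, \dots, \mathbf{O}_{T_r}$ of size $B_r \times d$ each, divide $\ell$ into $T_r$ blocks $\ell_1, \dots, \ell_{T_r}$ of size $B_r$ each, and divide $m$ into $T_r$ blocks $m_1, \dots, m_{T_r}$ of size $B_r$ each.
5: for $1 \le j \le T_c$ do
6: Load $\mathbf{K}_j, \mathbf{V}_j$ from HBM to on-chip SRAM.
7: for $1 \le i \le T_r$ do
8: Load $\mathbf{Q}_i, \mathbf{O}_i, \ell_i, m_i$ from HBM to on-chip SRAM.
9: On chip, compute $\mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^\top \in \mathbb{R}^{B_r \times B_c}$.
10: On chip, compute $\tilde{m}_{ij} = \mathrm{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_r}$, $\tilde{\mathbf{P}}_{ij} = \exp(\mathbf{S}_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c}$ (pointwise), $\tilde{\ell}_{ij} = \mathrm{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_r}$.
11: On chip, compute $m_i^{\mathrm{new}} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_r}$, $\ell_i^{\mathrm{new}} = e^{m_i - m_i^{\mathrm{new}}} \ell_i + e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}} \tilde{\ell}_{ij} \in \mathbb{R}^{B_r}$.
12: Write $\mathbf{O}_i \leftarrow \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1} \left(\mathrm{diag}(\ell_i) e^{m_i - m_i^{\mathrm{new}}} \mathbf{O}_i + e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}} \tilde{\mathbf{P}}_{ij} \mathbf{V}_j\right)$ to HBM.
13: Write $\ell_i \leftarrow \ell_i^{\mathrm{new}}$, $m_i \leftarrow m_i^{\mathrm{new}}$ to HBM.
14: end for
15: end for
16: Return $\mathbf{O}$.
We show FlashAttention's correctness, runtime, and memory requirement (proof in Appendix C).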
Theorem 1. Algorithm 1 returns $\mathbf{O} = \mathrm{softmax}(\mathbf{Q}\mathbf{K}^\top)\mathbf{V}$ with $O(N^2 d)$ FLOPs and requires $O(N)$ additional memory beyond inputs and output.
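A minimal pure-Python sketch of the forward pass in Algorithm 1, under simplifying assumptions of our own: the block sizes `b_r` and `b_c` are passed in directly (the algorithm derives them from $M$), Python lists stand in for both HBM and SRAM, and the rows of a query block are processed one at a time, which is equivalent because rows do not interact. It therefore shows the online rescaling arithmetic, not the memory behavior of the CUDA kernel. Function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def reference_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Untiled softmax(Q K^T) V, used only to check the tiled version."""
    out = []
    for q_r in q:
        s = [sum(a * b for a, b in zip(q_r, k_c)) for k_c in k]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(p[c] * v[c][col] for c in range(len(v))) / l
                    for col in range(len(v[0]))])
    return out


def flash_attention_forward(q: Matrix, k: Matrix, v: Matrix,
                            b_r: int, b_c: int) -> Matrix:
    """Tiled forward pass following the loop structure of Algorithm 1."""
    n, d_v = len(q), len(v[0])
    o = [[0.0] * d_v for _ in range(n)]   # line 2: O = 0
    l = [0.0] * n                         # line 2: l = 0
    m = [-math.inf] * n                   # line 2: m = -inf
    for c0 in range(0, n, b_c):           # line 5: loop over K_j, V_j
        k_j, v_j = k[c0:c0 + b_c], v[c0:c0 + b_c]            # line 6
        for r0 in range(0, n, b_r):       # line 7: loop over Q_i
            for r in range(r0, min(r0 + b_r, n)):            # rows of Q_i
                s = [sum(a * b for a, b in zip(q[r], kk)) for kk in k_j]  # line 9
                m_t = max(s)                                  # line 10
                p_t = [math.exp(x - m_t) for x in s]
                l_t = sum(p_t)
                m_new = max(m[r], m_t)                        # line 11
                a_old = math.exp(m[r] - m_new)  # 0.0 on the first block (m = -inf)
                a_new = math.exp(m_t - m_new)
                l_new = a_old * l[r] + a_new * l_t
                pv = [sum(p_t[c] * v_j[c][col] for c in range(len(v_j)))
                      for col in range(d_v)]
                # line 12: un-normalize old output, rescale, add new block, renormalize
                o[r] = [(l[r] * a_old * o[r][col] + a_new * pv[col]) / l_new
                        for col in range(d_v)]
                l[r], m[r] = l_new, m_new                     # line 13
    return o
```

Only $m$ and $\ell$ (one pair per row) are carried between key blocks, so the extra state is $O(N)$, consistent with Theorem 1's memory claim; different choices of `b_r` and `b_c` change the order of accumulation but, up to floating-point rounding, not the result.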
3.2 Analysis: IO Complexity of FlashAttention
We analyze the IO complexity of FlashAttention, showing significant reduction in HBM accesses compared to standard attention. We also provide a lower bound, proving that no exact attention algorithm can asymptotically improve on the number of HBM accesses over all SRAM sizes. Proofs are in Appendix C.
Theorem 2. Let $N$ be the sequence length, $d$ be the head dimension, and $M$ be the size of SRAM with $d \le M \le Nd$. Standard attention (Algorithm 0) requires $\Theta(Nd + N^2)$ HBM accesses, while FlashAttention (Algorithm 1) requires $\Theta(N^2 d^2 M^{-1})$ HBM accesses.
For typical values of $d$ (64-128) and $M$ (around 100KB), $d^2$ is many times smaller than $M$, and thus FlashAttention requires many times fewer HBM accesses than the standard implementation. This leads to both faster execution and lower memory footprint, which we validate in Section 4.3.
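As a rough worked example: ignoring the constant factors hidden by $\Theta(\cdot)$, and taking $N \gg d$ so that the $N^2$ term dominates standard attention, the ratio of the FlashAttention bound $\Theta(N^2 d^2 M^{-1})$ to the standard bound $\Theta(Nd + N^2)$ is about $d^2/M$. Plugging in $M \approx 10^5$ (treated here as an element count, which is our simplifying assumption) gives:

$$\frac{N^2 d^2 M^{-1}}{Nd + N^2} \approx \frac{d^2}{M}: \qquad \frac{64^2}{10^5} \approx 0.041 \;(\approx 24\times \text{ fewer}), \qquad \frac{128^2}{10^5} \approx 0.16 \;(\approx 6\times \text{ fewer}).$$

These asymptotic ratios are only indicative; measured reductions also reflect constant factors and lower-order terms.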
The main idea of the proof is that given the SRAM size of $M$, we can load blocks of $\mathbf{K}, \mathbf{V}$ of size $\Theta(M)$ each (Algorithm 1 line 6). For each block of $\mathbf{K}$ and $\mathbf{V}$, we iterate over all blocks of $\mathbf{Q}$ (Algorithm 1 line 8) to compute the intermediate values, resulting in $\Theta(NdM^{-1})$ passes over $\mathbf{Q}$. Each pass loads $\Theta(Nd)$ elements, which amounts to $\Theta(N^2 d^2 M^{-1})$ HBM accesses. We similarly prove that the backward pass of standard attention requires $\Theta(Nd + N^2)$ HBM accesses while the backward pass of FlashAttention requires $\Theta(N^2 d^2 M^{-1})$ HBM accesses (Appendix B).
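The counting argument can be mirrored by a small model that tallies elements (not bytes) moved between HBM and SRAM. Under our assumptions, standard attention reads and writes the full $\mathbf{S}$ and $\mathbf{P}$, while FlashAttention follows the loop structure of Algorithm 1 with block sizes $B_c = \lceil M/(4d) \rceil$ and $B_r = \min(B_c, d)$, treating $M$ as an element count. The model ignores caching and everything the real kernel does differently; the function names are hypothetical.

```python
import math


def standard_attention_hbm(n: int, d: int) -> int:
    """Elements moved to/from HBM when S and P are materialized."""
    compute_s = 2 * n * d + n * n       # read Q, K; write S
    compute_p = n * n + n * n           # read S; write P
    compute_o = n * n + n * d + n * d   # read P, V; write O
    return compute_s + compute_p + compute_o


def flash_attention_hbm(n: int, d: int, sram: int) -> int:
    """Elements moved to/from HBM by the loop structure of Algorithm 1."""
    b_c = math.ceil(sram / (4 * d))     # assumed block sizes (see lead-in)
    b_r = min(b_c, d)
    total = n * d + 2 * n               # initialize O, l, m in HBM
    for c0 in range(0, n, b_c):         # outer loop over K_j, V_j
        bc = min(b_c, n - c0)
        total += 2 * bc * d             # load K_j, V_j (line 6)
        for r0 in range(0, n, b_r):     # inner loop: one pass over Q
            br = min(b_r, n - r0)
            total += 2 * br * d + 2 * br  # load Q_i, O_i, l_i, m_i (line 8)
            total += br * d + 2 * br      # write back O_i, l_i, m_i
    return total


if __name__ == "__main__":
    n, d = 1024, 64
    print("standard:", standard_attention_hbm(n, d))
    for sram in (25_000, 100_000, 400_000):
        print(f"flash, M={sram}:", flash_attention_hbm(n, d, sram))
```

The number of outer iterations, $\lceil N/B_c \rceil \approx 4Nd/M$, is the number of passes over $\mathbf{Q}$ in the proof sketch, so the FlashAttention count shrinks as $M$ grows while the standard count does not depend on $M$ at all.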
We prove a lower bound: one cannot asymptotically improve on the number of HBM accesses for all values of $M$ (the SRAM size) when computing exact attention.
Proposition 3. Let $N$ be the sequence length, $d$ be the head dimension, and $M$ be the size of SRAM with $d \le M \le Nd$. There does not exist an algorithm to compute exact attention with $o(N^2 d^2 M^{-1})$ HBM accesses for all $M$ in the range $[d, Nd]$.
The proof relies on the fact that for $M = \Theta(Nd)$ any algorithm must perform $\Omega(N^2 d^2 M^{-1}) = \Omega(Nd)$ HBM accesses. This type of lower bound over a subrange of $M$ is common in the streaming algorithms literature. We leave proving parameterized complexity lower bounds in terms of $M$ as exciting future work.
We validate that the number of HBM accesses is the main determining factor of attention runtime. In Fig. 2 (left), we see that even though FlashAttention has a higher FLOP count than standard attention (due to recomputation in the backward pass), it has far fewer HBM accesses, resulting in much faster runtime. In Fig. 2 (middle), we vary the block size of FlashAttention, which results in different amounts of HBM accesses, and measure the runtime of the forward pass. As block size increases, the number of HBM accesses decreases (as we make fewer passes over the input), and runtime decreases. For large enough block sizes (beyond 256), the runtime is instead bottlenecked by other factors (e.g., arithmetic operations). Moreover, larger block sizes will not fit into the small SRAM.
3.3 Extension: Block-Sparse FlashAttention
We extend FlashAttention to approximate attention: we propose block-sparse FlashAttention, whose IO complexity is smaller than FlashAttention by a factor proportional to the sparsity.
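Given inputs $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$ and a mask matrix $\tilde{\mathbf{M}} \in \{0, 1\}^{N \times N}$, we want to compute:
$$\mathbf{S} = \mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{N \times N}, \quad \mathbf{P} = \mathrm{softmax}(\mathbf{S} \odot \mathbb{1}_{\tilde{\mathbf{M}}}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O} = \mathbf{P}\mathbf{V} \in \mathbb{R}^{N \times d},$$
where $(\mathbf{S} \odot \mathbb{1}_{\tilde{\mathbf{M}}})_{kl} = \mathbf{S}_{kl}$ if $\tilde{\mathbf{M}}_{kl} = 1$ and $-\infty$ if $\tilde{\mathbf{M}}_{kl} = 0$. We require $\tilde{\mathbf{M}}$ to have block form: for some block sizes $B_r, B_c$, for all $k, l$, $\tilde{\mathbf{M}}_{k,l} = \mathbf{M}_{ij}$ with $i = \lfloor k / B_r \rfloor$, $j = \lfloor l / B_c \rfloor$ for some $\mathbf{M} \in \{0, 1\}^{N/B_r \times N/B_c}$.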
Given a predefined block sparsity mask $\mathbf{M} \in \{0, 1\}^{N/B_r \times N/B_c}$, we can easily adapt Algorithm 1 to only compute the nonzero blocks of the attention matrix. The algorithm is identical to Algorithm 1, except we skip zero blocks. We reproduce the algorithm description in Algorithm 5 in Appendix B.
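A minimal sketch of the block-skipping idea: the tiled loop of Algorithm 1 with one extra test that skips any $(i, j)$ block whose mask entry $\mathbf{M}_{ij}$ is zero. This is our own illustration, not a transcription of Algorithm 5, and the function names are hypothetical. We also assume every query row has at least one nonzero block; this sketch leaves a row with none as zeros, whereas the masked softmax would be undefined there.

```python
import math
from typing import List

Matrix = List[List[float]]


def expand_block_mask(block_mask: List[List[int]], b_r: int, b_c: int,
                      n: int) -> List[List[int]]:
    """Element mask with M~[k][l] = M[k // b_r][l // b_c] (the block form)."""
    return [[block_mask[r // b_r][c // b_c] for c in range(n)] for r in range(n)]


def masked_reference_attention(q: Matrix, k: Matrix, v: Matrix,
                               mask: List[List[int]]) -> Matrix:
    """Dense reference: masked-out scores become -inf before the softmax."""
    out = []
    for r, q_r in enumerate(q):
        s = [sum(a * b for a, b in zip(q_r, k_c)) if mask[r][c] else -math.inf
             for c, k_c in enumerate(k)]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(p[c] * v[c][col] for c in range(len(v))) / l
                    for col in range(len(v[0]))])
    return out


def block_sparse_attention(q: Matrix, k: Matrix, v: Matrix,
                           block_mask: List[List[int]], b_r: int, b_c: int) -> Matrix:
    """Tiled forward pass that never touches blocks whose mask entry is 0."""
    n, d_v = len(q), len(v[0])
    o = [[0.0] * d_v for _ in range(n)]
    l = [0.0] * n
    m = [-math.inf] * n
    for c0 in range(0, n, b_c):
        k_j, v_j = k[c0:c0 + b_c], v[c0:c0 + b_c]
        for r0 in range(0, n, b_r):
            if not block_mask[r0 // b_r][c0 // b_c]:
                continue  # zero block: skip the load and all of its arithmetic
            for r in range(r0, min(r0 + b_r, n)):
                s = [sum(a * b for a, b in zip(q[r], kk)) for kk in k_j]
                m_t = max(s)
                p_t = [math.exp(x - m_t) for x in s]
                m_new = max(m[r], m_t)
                a_old, a_new = math.exp(m[r] - m_new), math.exp(m_t - m_new)
                l_new = a_old * l[r] + a_new * sum(p_t)
                pv = [sum(p_t[c] * v_j[c][col] for c in range(len(v_j)))
                      for col in range(d_v)]
                o[r] = [(l[r] * a_old * o[r][col] + a_new * pv[col]) / l_new
                        for col in range(d_v)]
                l[r], m[r] = l_new, m_new
    return o  # rows with no nonzero block stay all-zero in this sketch
```

Work and traffic scale with the number of nonzero blocks, which is the intuition behind the factor $s$ in the IO bound below; with an all-ones mask the function reduces to plain tiled attention.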
We also analyze the IO complexity of block-sparse FlashAttention.
Proposition 4. Let $N$ be the sequence length, $d$ be the head dimension, and $M$ be the size of SRAM with $d \le M \le Nd$. Block-sparse FlashAttention (Algorithm 5) requires $\Theta(Nd + N^2 d^2 M^{-1} s)$ HBM accesses, where $s$ is the fraction of nonzero blocks in the block-sparsity mask.
We see that applying block-sparsity yields a direct improvement by the sparsity to the larger term in the IO complexity. For large sequence lengths $N$, $s$ is often set to $N^{-1/2}$ or $N^{-1}\log N$, resulting in $\Theta(N\sqrt{N})$ or $\Theta(N \log N)$ IO complexity. For downstream experiments, we use the fixed butterfly sparsity pattern, which has been shown to be able to approximate arbitrary sparsity.
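Substituting the two choices of $s$ into the block-sparse bound $\Theta(Nd + N^2 d^2 M^{-1} s)$, and treating $d$ and $M$ as constants, recovers the stated complexities:

$$s = N^{-1/2}: \quad \Theta\!\left(Nd + \frac{N^2 d^2}{M} \cdot N^{-1/2}\right) = \Theta\!\left(Nd + \frac{d^2}{M} N\sqrt{N}\right) = \Theta(N\sqrt{N}),$$
$$s = N^{-1}\log N: \quad \Theta\!\left(Nd + \frac{N^2 d^2}{M} \cdot \frac{\log N}{N}\right) = \Theta\!\left(Nd + \frac{d^2}{M} N \log N\right) = \Theta(N \log N).$$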
In Fig. 2 (right), we validate that as the sparsity increases, the runtime of block-sparse FlashAttention improves proportionally. On the LRA benchmark, block-sparse FlashAttention achieves 2.8× speedup, while performing on par with standard attention (Section 4).
Experiments
We evaluate the impact of using FlashAttention to train Transformer models. We validate two claims about training time and model accuracy, and report attention runtime and memory benchmarks.
Training Speed. FlashAttention outperforms the MLPerf 1.1 speed record for BERT by 15%, and speeds up GPT-2 training over the standard Transformer implementations from HuggingFace (by up to 3×) and Megatron-LM. FlashAttention speeds up the long-range arena (LRA) benchmark by 2.4×.
Quality. FlashAttention scales Transformers to longer sequences, yielding higher quality. FlashAttention trains GPT-2 with context length 4K faster than Megatron trains GPT-2 with context length 1K, while achieving 0.7 better perplexity. Modeling longer sequences yields 6.4 points of lift on two long-document classification tasks. Finally, FlashAttention yields the first Transformer that can achieve better-than-random performance on the challenging Path-X task (sequence length 16K), and block-sparse FlashAttention yields the first sequence model that we know of that can achieve better-than-random performance on Path-256 (sequence length 64K).
Benchmarking Attention. We measure the runtime and memory performance of FlashAttention and block-sparse FlashAttention based on sequence length. We confirm that the memory footprint of FlashAttention scales linearly with seq. length and is up to 3× faster than standard attention for common seq. lengths (up to 2K). We confirm that runtime of block-sparse FlashAttention scales linearly in seq. length and is faster than all existing approximate attention baselines.
Additional experiment details are in Appendix E.
FlashAttention yields the fastest single-node BERT training speed that we know of. We train a BERT-large model with FlashAttention on Wikipedia. Table 1 compares our training time to the implementation from Nvidia that set the training speed record for MLPerf 1.1. Our implementation is 15% faster.
FlashAttention yields faster training times for GPT-2 on the large OpenWebText dataset than the widely used HuggingFace and Megatron-LM implementations. Table 2 shows up to 3× end-to-end speedup compared to HuggingFace and 1.7× speedup compared to Megatron-LM. FlashAttention achieves the same perplexity as the other two implementations, as we do not change the model definition. Appendix E includes plots of the validation perplexity throughout training, confirming that FlashAttention is as numerically stable as the baselines and produces the same training / validation curves.
We compare vanilla Transformers (with either the standard implementation or FlashAttention) on the long-range arena (LRA) benchmark. We measure accuracy, throughput, and training time of all models. Each task has a different sequence length varying between 1024 and 4096. We follow the implementation and experimental setting in Tay et al. and Xiong et al. (LRA accuracy results are known to be highly dependent on the tuning procedure; our reproduced baselines perform better than reported in the original comparison.) Table 3 shows that FlashAttention achieves up to 2.4× speed-up compared to standard attention. Block-sparse FlashAttention is faster than all of the approximate attention methods that we have tested.
4.2 Better Models with Longer Sequences
The runtime and memory-efficiency of FlashAttention allow us to increase the context length of GPT-2 by 4× while still running faster than the optimized implementation from Megatron-LM. Table 4 shows that GPT-2 with FlashAttention and context length 4K is still 30% faster than GPT-2 from Megatron with context length 1K, while achieving 0.7 better perplexity.
Training Transformers with longer sequences with FlashAttention improves performance on the MIMIC-III and ECtHR datasets. MIMIC-III contains intensive care unit patient discharge summaries, each annotated with multiple labels. ECtHR contains legal cases from the European Court of Human Rights, each of which is mapped to articles of the Convention of Human Rights that were allegedly violated. Both of these datasets contain very long text documents: the average document in MIMIC has 2,395 tokens and the longest contains 14,562 tokens, while the average and longest documents in ECtHR have 2,197 and 49,392 tokens, respectively. We evaluate lift from increasing the sequence length of a pretrained RoBERTa model (we repeat the positional embeddings, as in Beltagy et al.).
Table 6 shows that sequence length 16K outperforms length 512 by 4.3 points on MIMIC, and that length 8K outperforms length 512 by 8.5 points on ECtHR. The discrepancies may be due to subtle distribution shifts: MIMIC-III contains specialized medical text and thus may be more susceptible to a distribution shift in the document length, whereas ECtHR contains general language.
The Path-X and Path-256 benchmarks are challenging tasks from the long-range arena benchmark designed to test long context. The task is to classify whether two points in a black-and-white 128×128 (or 256×256) image have a path connecting them, and the images are fed to the Transformer one pixel at a time. In prior work, all Transformer models have either run out of memory or only achieved random performance. There has been a search for alternative architectures that can model such long context. We present here the first result of Transformer models being able to solve Path-X and Path-256 (Table 6). We pretrain a Transformer on Path-64, and then transfer to Path-X by spatially interpolating the positional embeddings. FlashAttention achieves 61.4 accuracy on Path-X. Additionally, block-sparse FlashAttention enables the Transformers to scale to sequence length 64K, achieving 63.1 accuracy on Path-256 (Path-256 requires longer sequences but has relatively shorter paths than Path-X, so it is easier to obtain a higher accuracy).
4.3 Benchmarking Attention
We vary sequence length and measure runtime and memory usage of FlashAttention and block-sparse FlashAttention against various attention baselines on one A100 GPU with 40 GB HBM, with dropout and a padding mask. We compare against reference implementations for exact attention, approximate attention, and sparse attention. We report a subset of baselines in the main body; Appendix E contains more baselines and full details.
Figure 3 (left) reports the runtime in milliseconds of the forward + backward pass of FlashAttention and block-sparse FlashAttention compared to the baselines in exact, approximate, and sparse attention (exact numbers in Appendix E). Runtime grows quadratically with sequence length, but FlashAttention runs significantly faster than exact attention baselines, up to 3× faster than the PyTorch implementation. The runtimes of many approximate/sparse attention mechanisms grow linearly with sequence length, but FlashAttention still runs faster than approximate and sparse attention for short sequences due to fewer memory accesses. The approximate attention runtimes begin to cross over with FlashAttention at sequence lengths between 512 and 1024. On the other hand, block-sparse FlashAttention is faster than all implementations of exact, sparse, and approximate attention that we know of, across all sequence lengths.
Figure 3 (right) shows the memory footprint of FlashAttention and block-sparse FlashAttention compared to various exact, approximate, and sparse attention baselines. FlashAttention and block-sparse FlashAttention have the same memory footprint, which grows linearly with sequence length. FlashAttention is up to 20× more memory-efficient than exact attention baselines, and is more memory-efficient than the approximate attention baselines. All other algorithms except for Linformer run out of memory on an A100 GPU before 64K, and FlashAttention is still 2× more efficient than Linformer.
Limitations and Future Directions
We discuss limitations of our approach and future directions. Related work is given in Appendix A.
Compiling to CUDA. Our current approach to building IO-aware implementations of attention requires writing a new CUDA kernel for each new attention implementation. This requires writing the attention algorithm in a considerably lower-level language than PyTorch, and requires significant engineering effort. Implementations may also not be transferable across GPU architectures. These limitations suggest the need for a method that supports writing attention algorithms in a high-level language (e.g., PyTorch) and compiling them to IO-aware implementations in CUDA, similar to efforts such as Halide in image processing.
IO-Aware Deep Learning. We believe that the IO-aware approach can extend beyond attention. Attention is the most memory-intensive computation in Transformers, but every layer in a deep network touches GPU HBM. We hope our work inspires IO-aware implementations of additional modules. We discuss these potential extensions in Appendix D.
Multi-GPU IO-Aware Methods. Our IO-aware implementation of attention is optimal within constants for computing attention on a single GPU. However, the attention computation may be parallelizable across multiple GPUs. Using multiple GPUs adds an additional layer to IO analysis: accounting for data transfer between GPUs. We hope our work inspires future work in this direction.
Our implementation uses Apex’s FMHA code (https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) as a starting point. We thank Young-Jun Ko for the in-depth explanation of his FMHA implementation and for his thoughtful answers to our questions about CUDA. We thank Sabri Eyuboglu, Megan Leszczynski, Laurel Orr, Yuhuai Wu, Beidi Chen, and Xun Huang for their constructive feedback and suggestions on early drafts of the paper. We thank Markus Rabe and Charles Staats for helpful discussion of their attention algorithm.
We gratefully acknowledge the support of NIH under No. U54EB020405 (Mobilize), NSF under Nos. CCF1763315 (Beyond Sparsity), CCF1563078 (Volume to Velocity), and 1937301 (RTML); ARL under No. W911NF-21-2-0251 (Interactive Human-AI Teaming); ONR under No. N000141712266 (Unifying Weak Supervision); ONR N00014-20-1-2480: Understanding and Applying Non-Euclidean Geometry in Machine Learning; N000142012275 (NEPTUNE); NXP, Xilinx, LETI-CEA, Intel, IBM, Microsoft, NEC, Toshiba, TSMC, ARM, Hitachi, BASF, Accenture, Ericsson, Qualcomm, Analog Devices, Google Cloud, Salesforce, Total, the HAI-GCP & HAI-Azure Cloud Credits for Research program, the Stanford Data Science Initiative (SDSI), Department of Defense (DoD) through the National Defense Science and Engineering Graduate Fellowship (NDSEG) Program, and members of the Stanford DAWN project: Facebook, Google, and VMWare. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright notation thereon. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views, policies, or endorsements, either expressed or implied, of NIH, ONR, or the U.S. Government. Atri Rudra’s research is supported by NSF grant CCF-1763481.
References
Appendix A Related Work
IO-Aware Runtime Optimization. The broad concept of optimizing for reading and writing to fast/slow memory has a long history in computer science and has been known by many names. We draw the most direct connection to the literature of analyzing I/O complexity in this work, but concepts of memory hierarchies are fundamental and have appeared in many forms, from the working set model, to data locality, to the Roofline model of arithmetic intensity, to analyses of scalability, to standard textbook treatments of computer architecture. We hope that this work encourages the community to adopt these ideas in more parts of the deep learning stack.
Efficient ML Models with Structured Matrices. Matrix multiply is the core computational bottleneck of most machine learning models. To reduce the computational complexity, there have been numerous approaches to learn over a more efficient set of matrices. These matrices are called structured matrices, which have a subquadratic ($o(n^2)$ for dimension $n \times n$) number of parameters and runtime. The most common examples of structured matrices are sparse and low-rank matrices, along with fast transforms commonly encountered in signal processing (Fourier, Chebyshev, sine/cosine, orthogonal polynomials). There have been several more general classes of structured matrices proposed in machine learning: Toeplitz-like, low-displacement rank, and quasi-separable. The butterfly pattern we use for our block-sparse attention is motivated by the fact that butterfly matrices and their products have been shown to be able to express any structured matrix with almost optimal runtime and number of parameters. However, even though structured matrices are efficient in theory, they have not seen wide adoption, since it is hard to translate their efficiency into wall-clock speedup when dense unconstrained matrix multiply has very optimized implementations, a phenomenon known as the hardware lottery. Extensions of butterfly matrices have aimed to make butterfly matrices more hardware-friendly.
Sparse Training. Our block-sparse FlashAttention can be seen as a step towards making sparse model training more efficient. Sparse models have seen success in compressing models for inference (pruning) by sparsifying the weight matrices. For model training, the lottery ticket hypothesis suggests that there is a set of small sub-networks derived from a larger dense network that performs as well as the original dense network. Our block-sparse FlashAttention can also be seen as a fixed lottery ticket in the context of attention: we fix the sparsity pattern to be the butterfly pattern through training, and observe that it performs almost as well as the (dense) FlashAttention on the Long-range Arena tasks.
Efficient Transformer. Transformer-based models have become the most widely used architecture in natural language processing and computer vision. However, one of their computational bottlenecks is that their time and memory scale quadratically in the sequence length. There are numerous approaches to overcome this bottleneck, including approximation with hashing (i.e., sparse) such as Reformer and Smyrf, and with low-rank approximation such as Performer. One can even combine sparse and low-rank approximation for better accuracy (e.g., Longformer, BigBird, Scatterbrain, Long-short transformer, Combiner). Other approaches include compressing along the sequence dimension to attend to multiple tokens at once. One can also attend over the states from previous sequences to help lengthen the context (e.g., Transformer-XL and Compressive Transformer). We recommend the survey for more details.
There are several lines of work on developing other modules instead of attention to model longer context. HiPPO and its extensions, most notably S4, project the history onto a polynomial basis, allowing accurate reconstruction of the history through state-space models. They combine the strengths of CNNs (efficient training), RNNs (efficient inference), and continuous models (robust to change in sampling rates). LambdaNetworks, AFT and FLASH are other attempts at replacing attention in the context of image classification and language modeling.
Appendix B Algorithm Details
We first derive the forward and backward passes of attention and show that they can be computed in a memory-efficient manner (requiring extra memory linear instead of quadratic in the sequence length). Though they reduce the amount of extra memory required, naively they still incur quadratic HBM accesses, resulting in slower execution speed. We then describe the FlashAttention algorithm, which implements both the forward and the backward passes on GPUs with fewer HBM accesses, leading to both faster runtime and a smaller memory footprint.
B.1 Memory-efficient forward pass
The main challenge in making attention memory-efficient is the softmax that couples the rows of $\mathbf{K}$ (and rows of $\mathbf{V}$). Our approach is to compute the softmax normalization constant separately to decouple the rows. This technique has been used in the literature to show that attention computation does not need quadratic extra memory (though the number of HBM accesses is still quadratic, resulting in slow run-time).
For simplicity, we omit here the max-shifting step during softmax. The full algorithm in Section B.3 contains all the steps.
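Recall that given input sequences $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$, we want to compute the attention output $\mathbf{O} \in \mathbb{R}^{N \times d}$:
$$\mathbf{S} = \mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{N \times N}, \qquad \mathbf{P} = \mathrm{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \qquad \mathbf{O} = \mathbf{P}\mathbf{V} \in \mathbb{R}^{N \times d}.$$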
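We have that $S_{ij} = q_i^\top k_j$, where $q_i, k_j \in \mathbb{R}^d$ are the $i$-th and $j$-th rows of $\mathbf{Q}$ and $\mathbf{K}$ respectively, written as column vectors. Define the normalization constants of softmax:
$$L_i = \sum_j e^{q_i^\top k_j}. \tag{1}$$
Let $v_j \in \mathbb{R}^d$ be the $j$-th row of $\mathbf{V}$; then the $i$-th row of the output is
$$o_i = P_{i:}\mathbf{V} = \sum_j P_{ij} v_j = \sum_j \frac{e^{q_i^\top k_j}}{L_i} v_j. \tag{2}$$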
We see that once $L_i$ is computed, we can compute $o_i$ without extra memory by repeatedly summing $\frac{e^{q_i^\top k_j}}{L_i} v_j$. Therefore the forward pass can be computed with $O(N)$ extra memory:
Compute $L_i$ for all $i$ according to Eq. 1, which takes $O(N)$ extra memory.
Compute $o_i$ for all $i$ according to Eq. 2, which takes $O(d)$ extra memory.
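The two steps above can be sketched in plain Python. This is a minimal illustration, not the authors' implementation: the function names are hypothetical, inputs are lists of rows, and, as in the derivation, no max-shifting is applied, so large scores could overflow. The dense reference exists only to check the result.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_forward_linear_memory(Q, K, V):
    """O = softmax(Q K^T) V computed row by row, never storing the N x N matrix."""
    N, dv = len(Q), len(V[0])
    # Step 1 (Eq. 1): normalization constants L_i, O(N) extra memory.
    L = [sum(math.exp(dot(Q[i], K[j])) for j in range(N)) for i in range(N)]
    O = []
    for i in range(N):
        # Step 2 (Eq. 2): accumulate o_i with O(d) extra memory.
        o_i = [0.0] * dv
        for j in range(N):
            w = math.exp(dot(Q[i], K[j])) / L[i]
            for c in range(dv):
                o_i[c] += w * V[j][c]
        O.append(o_i)
    return O, L


def attention_forward_naive(Q, K, V):
    """Reference that materializes S and P explicitly (O(N^2) memory)."""
    S = [[dot(q, k) for k in K] for q in Q]
    P = [[math.exp(s) / sum(math.exp(t) for t in row) for s in row] for row in S]
    return [[sum(P[i][j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))]
            for i in range(len(Q))]
```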
B.2 Memory-efficient backward pass
We derive the backward pass of attention and show that it can also be computed with linear memory. Rabe and Staats suggest that the backward pass can be done without quadratic extra memory by applying gradient checkpointing to the memory-efficient forward pass. We instead derive the backward pass explicitly and show how it can be computed in a memory-efficient manner.
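Suppose that there is a scalar loss function $\phi$, and let the output gradient be $\mathbf{dO} \in \mathbb{R}^{N \times d}$ (where $\mathbf{dO}$ denotes $\frac{\partial \phi}{\partial \mathbf{O}}$). We want to compute the input gradients $\mathbf{dQ}, \mathbf{dK}, \mathbf{dV} \in \mathbb{R}^{N \times d}$ (where $\mathbf{dQ}, \mathbf{dK}, \mathbf{dV}$ denote $\frac{\partial \phi}{\partial \mathbf{Q}}, \frac{\partial \phi}{\partial \mathbf{K}}, \frac{\partial \phi}{\partial \mathbf{V}}$ respectively).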
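The gradient $\mathbf{dV}$ is easy to see. Applying reverse-mode autodiff by hand (aka the chain rule), we obtain (in matrix notation) $\mathbf{dV} = \mathbf{P}^\top \mathbf{dO}$. Thus:
$$dv_j = \sum_i P_{ij}\, do_i = \sum_i \frac{e^{q_i^\top k_j}}{L_i}\, do_i. \tag{3}$$
Since we already computed $L_i$, $dv_j$ can be computed without extra memory by repeated summing.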
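The gradients $\mathbf{dQ}$ and $\mathbf{dK}$ are a little more complicated. We go through the gradients $\mathbf{dP}$ and $\mathbf{dS}$ first. From Eq. 2, we have that $\mathbf{dP} = \mathbf{dO}\mathbf{V}^\top$, and so:
$$dP_{ij} = do_i^\top v_j.$$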
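Recall that $P_{i:} = \mathrm{softmax}(S_{i:})$. Using the fact that the Jacobian of $y = \mathrm{softmax}(x)$ is $\mathrm{diag}(y) - yy^\top$, we have that
$$dS_{i:} = (\mathrm{diag}(P_{i:}) - P_{i:}P_{i:}^\top)\, dP_{i:} = P_{i:} \circ dP_{i:} - (P_{i:}^\top dP_{i:})\, P_{i:},$$
where $\circ$ denotes pointwise multiplication. Define
$$D_i = P_{i:}^\top dP_{i:} = \sum_j \frac{e^{q_i^\top k_j}}{L_i}\, do_i^\top v_j = do_i^\top \sum_j \frac{e^{q_i^\top k_j}}{L_i} v_j = do_i^\top o_i, \tag{4}$$
so that $dS_{i:} = P_{i:} \circ dP_{i:} - D_i P_{i:}$. Hence
$$dS_{ij} = P_{ij}\, dP_{ij} - D_i P_{ij} = P_{ij}(dP_{ij} - D_i).$$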
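Now we can get the gradients $\mathbf{dQ}$ and $\mathbf{dK}$. Recall that $S_{ij} = q_i^\top k_j$, so
$$dq_i = \sum_j dS_{ij}\, k_j = \sum_j P_{ij}(dP_{ij} - D_i)\, k_j = \sum_j \frac{e^{q_i^\top k_j}}{L_i}(do_i^\top v_j - D_i)\, k_j. \tag{5}$$
Similarly,
$$dk_j = \sum_i dS_{ij}\, q_i = \sum_i P_{ij}(dP_{ij} - D_i)\, q_i = \sum_i \frac{e^{q_i^\top k_j}}{L_i}(do_i^\top v_j - D_i)\, q_i. \tag{6}$$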
Therefore the backward pass can also be computed with $O(N)$ extra memory:
Compute $dv_j$ for all $j$ according to Eq. 3, which takes $O(d)$ extra memory.
Compute $D_i$ for all $i$ according to Eq. 4, which takes $O(N)$ extra memory.
Compute $dq_i$ for all $i$ according to Eq. 5, which takes $O(d)$ extra memory.
Compute $dk_j$ for all $j$ according to Eq. 6, which takes $O(d)$ extra memory.
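A minimal pure-Python sketch of Eqs. 3 to 6 follows, assuming the same unscaled, unmasked attention as in the derivation. The names `attention_output` and `attention_backward_linear_memory` are hypothetical. The sketch recomputes each $P_{ij}$ from $L_i$ instead of storing $\mathbf{P}$, and for brevity recomputes $\mathbf{O}$ rather than receiving it from the forward pass.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_output(Q, K, V):
    """O = softmax(Q K^T) V, computed row by row (Eqs. 1 and 2)."""
    out = []
    for q in Q:
        w = [math.exp(dot(q, k)) for k in K]
        L_i = sum(w)
        out.append([sum(w[j] * V[j][c] for j in range(len(V))) / L_i
                    for c in range(len(V[0]))])
    return out


def attention_backward_linear_memory(Q, K, V, dO):
    """dQ, dK, dV for O = softmax(Q K^T) V without materializing P."""
    N, d, dv = len(Q), len(Q[0]), len(V[0])
    L = [sum(math.exp(dot(q, k)) for k in K) for q in Q]            # Eq. 1

    def P(i, j):  # recompute P_ij on the fly instead of storing P
        return math.exp(dot(Q[i], K[j])) / L[i]

    O = attention_output(Q, K, V)  # in practice O is saved by the forward pass
    dV = [[sum(P(i, j) * dO[i][c] for i in range(N)) for c in range(dv)]
          for j in range(N)]                                         # Eq. 3
    D = [dot(dO[i], O[i]) for i in range(N)]                         # Eq. 4

    def dS(i, j):  # dS_ij = P_ij (dP_ij - D_i), with dP_ij = do_i . v_j
        return P(i, j) * (dot(dO[i], V[j]) - D[i])

    dQ = [[sum(dS(i, j) * K[j][c] for j in range(N)) for c in range(d)]
          for i in range(N)]                                         # Eq. 5
    dK = [[sum(dS(i, j) * Q[i][c] for i in range(N)) for c in range(d)]
          for j in range(N)]                                         # Eq. 6
    return dQ, dK, dV
```

A finite-difference check on $\phi = \sum_i do_i^\top o_i$, whose output gradient is exactly $\mathbf{dO}$, is a convenient way to test such a sketch.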
B.3 FlashAttention: Forward Pass
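We describe the full details of the FlashAttention forward pass. Given input sequences $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$, we want to compute the attention output $\mathbf{O} \in \mathbb{R}^{N \times d}$:
$$\mathbf{S} = \tau\mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{N \times N}, \qquad \mathbf{S}^{\mathrm{masked}} = \mathrm{mask}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \qquad \mathbf{P} = \mathrm{softmax}(\mathbf{S}^{\mathrm{masked}}) \in \mathbb{R}^{N \times N},$$
$$\mathbf{P}^{\mathrm{dropped}} = \mathrm{dropout}(\mathbf{P}, p_{\mathrm{drop}}), \qquad \mathbf{O} = \mathbf{P}^{\mathrm{dropped}}\mathbf{V} \in \mathbb{R}^{N \times d},$$
where $\tau \in \mathbb{R}$ is some softmax scaling (typically $\frac{1}{\sqrt{d}}$), $\mathrm{mask}$ is some masking function that sets some entries of the input to $-\infty$ and keeps other entries the same (e.g., a key padding mask when sequences in the batch don’t have the same lengths and are padded), and $\mathrm{dropout}(x, p)$ applies dropout to $x$ elementwise (i.e., it outputs $\frac{x}{1-p}$ with probability $1-p$ and outputs 0 with probability $p$ for each element $x$).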
The full algorithm is in Algorithm 2. We save the output $\mathbf{O}$, the softmax statistics $\ell$ and $m$, and the pseudo-random number generator state $\mathcal{R}$ for the backward pass.
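To make the tiling concrete, here is a minimal pure-Python sketch of the forward recurrence with softmax scaling $\tau$ and an optional causal mask. It is our illustration under simplifying assumptions, not Algorithm 2 itself: dropout is omitted, rows inside a tile are processed one at a time, and block sizes are arbitrary rather than derived from the SRAM size. Function names are hypothetical.

```python
import math


def flash_attention_forward(Q, K, V, block_r=2, block_c=2, tau=None, causal=False):
    """Exact attention computed tile by tile with an online softmax.

    Returns the output O and the per-row statistics l (running sum) and m (running max).
    """
    N, d, dv = len(Q), len(Q[0]), len(V[0])
    tau = 1.0 / math.sqrt(d) if tau is None else tau
    O = [[0.0] * dv for _ in range(N)]
    l = [0.0] * N
    m = [-math.inf] * N
    for j0 in range(0, N, block_c):              # outer loop: blocks K_j, V_j
        cols = range(j0, min(j0 + block_c, N))
        for i0 in range(0, N, block_r):          # inner loop: blocks Q_i, O_i
            for i in range(i0, min(i0 + block_r, N)):
                # Scores S_ij = tau * q_i . k_j; masked entries become -inf.
                s = [-math.inf if causal and j > i
                     else tau * sum(a * b for a, b in zip(Q[i], K[j]))
                     for j in cols]
                m_tilde = max(s)
                if m_tilde == -math.inf:         # fully masked tile row: nothing to add
                    continue
                p = [math.exp(x - m_tilde) for x in s]   # exp(-inf) is 0.0
                m_new = max(m[i], m_tilde)
                alpha = math.exp(m[i] - m_new)   # rescales the old running state
                beta = math.exp(m_tilde - m_new)  # rescales this tile
                l_new = alpha * l[i] + beta * sum(p)
                pv = [sum(pk * V[j][c] for pk, j in zip(p, cols)) for c in range(dv)]
                O[i] = [(l[i] * alpha * O[i][c] + beta * pv[c]) / l_new
                        for c in range(dv)]
                l[i], m[i] = l_new, m_new
    return O, l, m


def reference_attention(Q, K, V, tau=None, causal=False):
    """Standard attention that builds each full score row (used only for checking)."""
    d = len(Q[0])
    tau = 1.0 / math.sqrt(d) if tau is None else tau
    out = []
    for i, q in enumerate(Q):
        s = [-math.inf if causal and j > i else tau * sum(a * b for a, b in zip(q, k))
             for j, k in enumerate(K)]
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        z = sum(w)
        out.append([sum(w[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out
```

Keeping only $\ell$ and $m$ per row is what lets the output be rescaled when a later tile raises the row maximum, so no tile's probabilities need to be revisited.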
B.4 FlashAttention: Backward Pass
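We describe the full details of the FlashAttention backward pass. Given input sequences $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$, the output $\mathbf{O} \in \mathbb{R}^{N \times d}$, and the output gradient $\mathbf{dO}$, we want to compute the input gradients $\mathbf{dQ}, \mathbf{dK}, \mathbf{dV} \in \mathbb{R}^{N \times d}$.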
We first describe the standard attention backward pass in Algorithm 3 for completeness.
We now make two observations about FlashAttention backward pass:
We do not need to store the dropout mask of size $O(N^2)$ from the forward pass. Instead, we can save the pseudo-random number generator states from the forward pass and re-generate the dropout mask in the backward pass. This allows us to use only $O(N)$ extra memory.
When computing the softmax gradient, we use Eq. 4 to compute $D_i = P_{i:}^\top dP_{i:}$ without reducing over $P_{i:}$ and $dP_{i:}$ of size $N$ (they might not fit into SRAM). Instead, we can rewrite $D_i = do_i^\top o_i$ and compute the dot product between vectors of size $d$.
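Both observations can be illustrated in a few lines of Python. In this sketch, `random.Random(seed)` stands in for the GPU's pseudo-random number generator state (an assumption for illustration; GPU kernels use a different generator), and the helper names are hypothetical.

```python
import random


def dropout_mask(seed, rows, cols, p_drop):
    """Regenerate a dropout mask from a saved seed instead of storing it.

    Entries are 1/(1 - p_drop) with probability 1 - p_drop and 0 otherwise, so
    P * mask matches dropout(x, p) = x / (1 - p) with probability 1 - p.
    """
    rng = random.Random(seed)
    keep = 1.0 / (1.0 - p_drop)
    return [[keep if rng.random() >= p_drop else 0.0 for _ in range(cols)]
            for _ in range(rows)]


def d_from_softmax_row(p_row, dp_row):
    """D_i = P_i: . dP_i:, a reduction over N entries."""
    return sum(a * b for a, b in zip(p_row, dp_row))


def d_from_output_row(do_row, o_row):
    """D_i = do_i . o_i, the same value from two length-d vectors."""
    return sum(a * b for a, b in zip(do_row, o_row))
```

Calling `dropout_mask` twice with the same seed yields the same mask, so the backward pass only needs the seed, not the $N \times N$ mask. The identity $D_i = do_i^\top o_i$ holds for any row $P_{i:}$, since $o_i = \sum_j P_{ij} v_j$.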
The full FlashAttention backward pass algorithm is in Algorithm 4. Conceptually it is just a block version of the derivation in Section B.2.
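16: On chip, compute $\tilde{\mathbf{dV}}_j \leftarrow \tilde{\mathbf{dV}}_j + (\mathbf{P}_{ij}^{\mathrm{dropped}})^\top \mathbf{dO}_i \in \mathbb{R}^{B_c \times d}$.
17: On chip, compute $\mathbf{dP}_{ij}^{\mathrm{dropped}} = \mathbf{dO}_i \mathbf{V}_j^\top \in \mathbb{R}^{B_r \times B_c}$.
18: On chip, compute $\mathbf{dP}_{ij} = \mathbf{dP}_{ij}^{\mathrm{dropped}} \circ \mathbf{Z}_{ij}$ (pointwise multiply).
19: On chip, compute $D_i = \mathrm{rowsum}(\mathbf{dO}_i \circ \mathbf{O}_i) \in \mathbb{R}^{B_r}$.
20: On chip, compute $\mathbf{dS}_{ij} = \mathbf{P}_{ij} \circ (\mathbf{dP}_{ij} - D_i) \in \mathbb{R}^{B_r \times B_c}$.
21: Write $\mathbf{dQ}_i \leftarrow \mathbf{dQ}_i + \tau \mathbf{dS}_{ij} \mathbf{K}_j \in \mathbb{R}^{B_r \times d}$ to HBM.
22: On chip, compute $\tilde{\mathbf{dK}}_j \leftarrow \tilde{\mathbf{dK}}_j + \tau \mathbf{dS}_{ij}^\top \mathbf{Q}_i \in \mathbb{R}^{B_c \times d}$.
23: end for
24: Write $\mathbf{dK}_j \leftarrow \tilde{\mathbf{dK}}_j$, $\mathbf{dV}_j \leftarrow \tilde{\mathbf{dV}}_j$ to HBM.
25: end for
26: Return $\mathbf{dQ}, \mathbf{dK}, \mathbf{dV}$.
We see that, similar to the forward pass, the backward pass performs $O(N^2 d)$ FLOPs and only requires $O(N)$ extra memory beyond inputs, output, output gradient, and input gradients.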
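The backward loop of Algorithm 4 can be sketched in pure Python as follows, as an illustration only: dropout and masking are omitted, $D_i$ is computed once up front rather than per tile, and tile products are written as scalar loops. The statistics $\ell, m$ come from a dense reference forward pass included for testing; all function names are hypothetical.

```python
import math


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def forward_with_stats(Q, K, V, tau):
    """Dense reference forward returning O and the softmax statistics l, m."""
    S = [[tau * x for x in row] for row in matmul(Q, transpose(K))]
    m = [max(row) for row in S]
    l = [sum(math.exp(x - mi) for x in row) for row, mi in zip(S, m)]
    P = [[math.exp(x - mi) / li for x in row] for row, mi, li in zip(S, m, l)]
    return matmul(P, V), l, m


def flash_attention_backward(Q, K, V, O, dO, l, m, tau, block_r=2, block_c=2):
    """Tile-by-tile gradients dQ, dK, dV without storing P (no dropout or mask)."""
    N, d, dv = len(Q), len(Q[0]), len(V[0])
    dQ = [[0.0] * d for _ in range(N)]
    dK = [[0.0] * d for _ in range(N)]
    dV = [[0.0] * dv for _ in range(N)]
    # D_i = rowsum(dO_i * O_i): one length-d dot product per row.
    D = [sum(a * b for a, b in zip(dO[i], O[i])) for i in range(N)]
    for j0 in range(0, N, block_c):                   # outer loop: K_j, V_j
        for i0 in range(0, N, block_r):               # inner loop: Q_i, dO_i, dQ_i
            for i in range(i0, min(i0 + block_r, N)):
                for j in range(j0, min(j0 + block_c, N)):
                    s = tau * sum(a * b for a, b in zip(Q[i], K[j]))
                    p = math.exp(s - m[i]) / l[i]                 # recompute P_ij
                    dp = sum(a * b for a, b in zip(dO[i], V[j]))  # dP_ij = do_i . v_j
                    ds = p * (dp - D[i])                          # dS_ij
                    for c in range(dv):
                        dV[j][c] += p * dO[i][c]                  # dV_j += P_ij^T dO_i
                    for c in range(d):
                        dQ[i][c] += tau * ds * K[j][c]            # dQ_i += tau dS_ij K_j
                        dK[j][c] += tau * ds * Q[i][c]            # dK_j += tau dS_ij^T Q_i
    return dQ, dK, dV
```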
We analyze the IO-complexity of the backward pass, similar to the forward pass (Theorem 2).
Let $N$ be the sequence length, $d$ be the head dimension, and $M$ be the size of SRAM with $d \le M \le Nd$. Standard attention (Algorithm 3) backward pass requires $\Theta(Nd + N^2)$ HBM accesses, while FlashAttention backward pass (Algorithm 4) requires $\Theta(N^2 d^2 M^{-1})$ HBM accesses.
B.5 Comparison with Rabe and Staats [66]
We describe here some similarities and differences between our FlashAttention algorithm and the algorithm of Rabe and Staats .
Conceptually, both FlashAttention and Rabe and Staats operate on blocks of the attention matrix using the well-established technique of tiling (or softmax scaling) . To reduce the memory footprint, both methods avoid storing the large attention matrix in the forward pass and recompute it in the backward pass.
The first major difference is that Rabe and Staats focus on reducing the total memory footprint (maximum amount of GPU memory required), while FlashAttention focuses on reducing memory accesses (the number of memory reads/writes). As mentioned in Section 2, the amount of memory access is the primary determining factor of runtime. Reducing memory accesses also necessarily reduces the total amount of memory required (e.g., if an operation incurs $A$ memory accesses, then its total memory requirement is at most $A$). As a result, FlashAttention is faster than standard attention (2-4×), while the method of Rabe and Staats runs at around the same speed as or slightly slower than standard attention. In terms of total memory required, both methods offer substantial memory saving.
The second difference between the two methods is the way information is summarized from each block to pass to the next block. Rabe and Staats summarize each block with its temporary output along with the softmax normalization statistics. At the end of the forward pass, the temporary outputs of all the blocks are combined using the statistics to produce the final output. FlashAttention instead incrementally updates the output (Algorithm 1 line 12) after processing each block, so only one copy of the output is needed (instead of $K$ copies for $K$ blocks). This means that FlashAttention has a smaller total memory requirement than Rabe and Staats.
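The two ways of summarizing blocks can be compared on a single query row. The sketch below (hypothetical names, plain Python) either keeps one partial output per block and merges them at the end, in the spirit of the description above, or folds each block into a single running output in the manner of Algorithm 1 line 12. Both return the same row, but the first holds $K$ partial outputs at once.

```python
import math


def block_summary(scores, values):
    """One block's statistics: max score, sum of shifted exps, unnormalized output."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    acc = [sum(wk * v[c] for wk, v in zip(w, values)) for c in range(len(values[0]))]
    return m, sum(w), acc


def combine_at_end(blocks):
    """Store a summary per block (K partial outputs), then merge them once."""
    summaries = [block_summary(s, v) for s, v in blocks]
    m = max(mk for mk, _, _ in summaries)
    l = sum(lk * math.exp(mk - m) for mk, lk, _ in summaries)
    dim = len(summaries[0][2])
    return [sum(acc[c] * math.exp(mk - m) for mk, _, acc in summaries) / l
            for c in range(dim)]


def incremental(blocks):
    """Fold each block into one running, normalized output."""
    m, l = -math.inf, 0.0
    out = [0.0] * len(blocks[0][1][0])
    for s, v in blocks:
        mk, lk, acc = block_summary(s, v)
        m_new = max(m, mk)
        alpha, beta = math.exp(m - m_new), math.exp(mk - m_new)
        l_new = alpha * l + beta * lk
        out = [(l * alpha * o + beta * a) / l_new for o, a in zip(out, acc)]
        m, l = m_new, l_new
    return out
```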
The final major difference is the way the backward pass is computed. Rabe and Staats use gradient checkpointing to recompute the attention matrix and the temporary output of each block. FlashAttention instead simplifies the backward pass analytically (Sections B.2 and B.4). It only recomputes the attention matrix and does not recompute the temporary output of each block. This reduces the memory requirement for the backward pass and yields a speedup.
Appendix C Proofs
We first count the number of FLOPs and extra memory required.
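The dominating FLOPs are from matrix multiplication. In the inner loop (Algorithm 1 line 9), we compute $\mathbf{Q}_i\mathbf{K}_j^\top \in \mathbb{R}^{B_r \times B_c}$ for $\mathbf{Q}_i \in \mathbb{R}^{B_r \times d}$ and $\mathbf{K}_j \in \mathbb{R}^{B_c \times d}$, which takes $O(B_r B_c d)$ FLOPs. We also compute (Algorithm 1 line 12) $\tilde{\mathbf{P}}_{ij}\mathbf{V}_j \in \mathbb{R}^{B_r \times d}$ for $\tilde{\mathbf{P}}_{ij} \in \mathbb{R}^{B_r \times B_c}$ and $\mathbf{V}_j \in \mathbb{R}^{B_c \times d}$, which takes $O(B_r B_c d)$ FLOPs. We execute the inner loops $T_c T_r = \left\lceil \frac{N}{B_c} \right\rceil \left\lceil \frac{N}{B_r} \right\rceil$ times. Therefore the total number of FLOPs is
$$O\left(\frac{N^2}{B_c B_r} B_r B_c d\right) = O(N^2 d).$$
In terms of extra memory required, we see that we need $O(N)$ memory to store the statistics $(\ell, m)$.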
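We now prove the algorithm’s correctness by induction on $j$ for $0 \le j \le T_c$. Let $\mathbf{K}_{:j} \in \mathbb{R}^{jB_c \times d}$ be the first $jB_c$ rows of $\mathbf{K}$, and similarly let $\mathbf{V}_{:j} \in \mathbb{R}^{jB_c \times d}$ be the first $jB_c$ rows of $\mathbf{V}$. Let $\mathbf{S}_{:,:j} = \mathbf{Q}\mathbf{K}_{:j}^\top \in \mathbb{R}^{N \times jB_c}$ and $\mathbf{P}_{:,:j} = \mathrm{softmax}(\mathbf{S}_{:,:j}) \in \mathbb{R}^{N \times jB_c}$ (softmax applied row-wise). Let $m^{(j)}, \ell^{(j)}, \mathbf{O}^{(j)}$ be the values of $m, \ell, \mathbf{O}$ in HBM after the $j$-th iteration of the outer loop (Algorithm 1 line 5). (Note that these values of $m, \ell, \mathbf{O}$ are updated after each iteration of the outer loop.) We want to show that after the $j$-th iteration of the outer loop, we have computed in HBM:
$$m^{(j)} = \mathrm{rowmax}(\mathbf{S}_{:,:j}) \in \mathbb{R}^N, \qquad \ell^{(j)} = \mathrm{rowsum}(\exp(\mathbf{S}_{:,:j} - m^{(j)})) \in \mathbb{R}^N, \qquad \mathbf{O}^{(j)} = \mathbf{P}_{:,:j}\mathbf{V}_{:j} \in \mathbb{R}^{N \times d}.$$
Based on our initialization (Algorithm 1 line 2), this claim is true for $j = 0$ (i.e., before any iteration of the outer loop is executed). Suppose that the claim holds for some $j = 0, \ldots, T_c - 1$. We want to show that the claim also holds for $j + 1$. Indeed, when we update the statistics in the inner loop (Algorithm 1 line 10) on the $(j+1)$-th iteration of the outer loop, we update $m^{(j+1)} = \max(m^{(j)}, \tilde{m})$, where $\tilde{m} \in \mathbb{R}^N$ is the row-max of $\mathbf{S}_{:,j:j+1}$, the slice of $\mathbf{S}$ from column $jB_c$ to column $(j+1)B_c - 1$. This implies that
$$m^{(j+1)} = \mathrm{rowmax}(\mathbf{S}_{:,:j+1}) \in \mathbb{R}^N.$$
Similarly, we update
$$\ell^{(j+1)} = e^{m^{(j)} - m^{(j+1)}}\ell^{(j)} + e^{\tilde{m} - m^{(j+1)}}\tilde{\ell},$$
where $\tilde{\ell} = \mathrm{rowsum}(\exp(\mathbf{S}_{:,j:j+1} - \tilde{m})) \in \mathbb{R}^N$. By the same algebraic manipulation as in Section 3.1, we obtain:
$$\ell^{(j+1)} = \mathrm{rowsum}(\exp(\mathbf{S}_{:,:j+1} - m^{(j+1)})) \in \mathbb{R}^N.$$
Let $\mathbf{V}_{j:j+1}$ be the slice of $\mathbf{V}$ from row $jB_c$ to row $(j+1)B_c - 1$; we also update:
$$\begin{aligned}
\mathbf{O}^{(j+1)} &= \mathrm{diag}(\ell^{(j+1)})^{-1}\left(\mathrm{diag}(\ell^{(j)})\, e^{m^{(j)} - m^{(j+1)}}\mathbf{O}^{(j)} + e^{\tilde{m} - m^{(j+1)}}\exp(\mathbf{S}_{:,j:j+1} - \tilde{m})\mathbf{V}_{j:j+1}\right) \\
&= \mathrm{diag}(\ell^{(j+1)})^{-1}\left(\mathrm{diag}(\ell^{(j)})\, e^{m^{(j)} - m^{(j+1)}}\mathbf{P}_{:,:j}\mathbf{V}_{:j} + e^{-m^{(j+1)}}\exp(\mathbf{S}_{:,j:j+1})\mathbf{V}_{j:j+1}\right) \\
&= \mathrm{diag}(\ell^{(j+1)})^{-1}\left(\exp(\mathbf{S}_{:,:j} - m^{(j+1)})\mathbf{V}_{:j} + \exp(\mathbf{S}_{:,j:j+1} - m^{(j+1)})\mathbf{V}_{j:j+1}\right) \\
&= \mathrm{diag}(\ell^{(j+1)})^{-1}\exp(\mathbf{S}_{:,:j+1} - m^{(j+1)})\mathbf{V}_{:j+1} \\
&= \mathrm{softmax}(\mathbf{S}_{:,:j+1})\mathbf{V}_{:j+1} = \mathbf{P}_{:,:j+1}\mathbf{V}_{:j+1}.
\end{aligned}$$
We then see that the claim is also true for $j + 1$. By induction, the claim is true for all $j = 0, \ldots, T_c$.
When $j = T_c$, we conclude that the final value of $\mathbf{O}$ in HBM is $\mathrm{softmax}(\mathbf{S})\mathbf{V} = \mathrm{softmax}(\mathbf{Q}\mathbf{K}^\top)\mathbf{V}$.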
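We first analyze the IO complexity of the standard attention implementation. The inputs $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$ reside in HBM, and at the end of the algorithm the output $\mathbf{O} \in \mathbb{R}^{N \times d}$ is written to HBM.
In the first step of computing the matrix multiply $\mathbf{S} = \mathbf{Q}\mathbf{K}^\top$, the inputs $\mathbf{Q}, \mathbf{K}$ are read from HBM and the output $\mathbf{S} \in \mathbb{R}^{N \times N}$ is written to HBM (Algorithm 0 line 1). This incurs $\Theta(Nd + N^2)$ HBM accesses.
In the second step of computing $\mathbf{P} = \mathrm{softmax}(\mathbf{S})$, the input $\mathbf{S}$ is read from HBM and the output $\mathbf{P}$ is written to HBM (Algorithm 0 line 2). This incurs $\Theta(N^2)$ HBM accesses.
In the last step of computing $\mathbf{O} = \mathbf{P}\mathbf{V}$, the inputs $\mathbf{P}, \mathbf{V}$ are read from HBM and the output $\mathbf{O}$ is written to HBM (Algorithm 0 line 3). This incurs $\Theta(Nd + N^2)$ HBM accesses.
Overall, the standard attention implementation requires $\Theta(Nd + N^2)$ HBM accesses.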
We now analyze the IO complexity of FlashAttention (Algorithm 1).
Following Algorithm 1, we see that each element of $\mathbf{K}$ and $\mathbf{V}$ is loaded from HBM once (Algorithm 1 line 6). We make $T_c$ passes over $\mathbf{Q}$ and $\mathbf{O}$, each pass loading all of $\mathbf{Q}$ and all of $\mathbf{O}$ from HBM (Algorithm 1 line 8). Therefore the number of HBM accesses is $\Theta(Nd + NdT_c) = \Theta(NdT_c)$.
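We derive the conditions on the block sizes $B_c$ and $B_r$. We need the blocks $\mathbf{K}_j$ and $\mathbf{V}_j$ of size $B_c \times d$ to fit into on-chip memory, which translates to:
$$B_c d = O(M) \iff B_c = O\left(\frac{M}{d}\right).$$
Similarly, we need the blocks $\mathbf{Q}_i$ and $\mathbf{O}_i$ of size $B_r \times d$ to fit into on-chip memory, which translates to:
$$B_r d = O(M) \iff B_r = O\left(\frac{M}{d}\right).$$
Finally, we need the block $\mathbf{S}_{ij}$ of size $B_r \times B_c$ to fit into on-chip memory, which translates to:
$$B_r B_c = O(M).$$
We therefore set
$$B_c = \Theta\left(\frac{M}{d}\right), \qquad B_r = \Theta\left(\min\left(\frac{M}{d}, \frac{M}{B_c}\right)\right) = \Theta\left(\min\left(\frac{M}{d}, d\right)\right),$$
which gives $T_c = \frac{N}{B_c} = \Theta\left(\frac{Nd}{M}\right)$. As a result, the number of HBM accesses is:
$$\Theta(NdT_c) = \Theta\left(\frac{N^2 d^2}{M}\right).$$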
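As a back-of-the-envelope check, the sketch below plugs numbers into these bounds using the block-size rule of Algorithm 1, $B_c = \lceil M/(4d) \rceil$ and $B_r = \min(\lceil M/(4d) \rceil, d)$. Constant factors are dropped, so the counts indicate scaling only, and the SRAM size in the example (98304 elements, i.e. 192 KB of FP16 values) is an assumption chosen for illustration. Function names are hypothetical.

```python
import math


def block_sizes(M, d):
    """Block sizes B_r, B_c from the rule B_c = ceil(M / 4d), B_r = min(ceil(M / 4d), d)."""
    b_c = math.ceil(M / (4 * d))
    return min(b_c, d), b_c


def hbm_accesses(N, d, M):
    """Leading-order HBM element accesses, constant factors dropped.

    standard:       Theta(N d + N^2)
    FlashAttention: Theta(N d T_c), with T_c = ceil(N / B_c) passes over Q and O
    """
    _, b_c = block_sizes(M, d)
    t_c = math.ceil(N / b_c)
    return N * d + N * N, N * d * t_c


standard, flash = hbm_accesses(N=4096, d=64, M=98304)
print(standard, flash, round(standard / flash, 1))  # 17039360 2883584 5.9
```

In this example the tiled schedule touches roughly 6× fewer HBM elements than the standard implementation.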
For contradiction, suppose that there exists an algorithm that computes exact attention where the number of HBM accesses for all $M \in [d, Nd]$ is
$$o\left(\frac{N^2 d^2}{M}\right).$$
In the regime of $M = \Theta(Nd)$, this results in the number of HBM accesses:
$$o\left(\frac{N^2 d^2}{Nd}\right) = o(Nd).$$
However, the input to attention (matrices $\mathbf{Q}, \mathbf{K}, \mathbf{V}$) and the output $\mathbf{O}$ have size $Nd$ and they start out being in HBM, so if the algorithm computes exact attention it must incur at least $\Omega(Nd)$ HBM accesses. This is a contradiction. ∎
The IO complexity of the attention backward pass is very similar to the IO complexity of the attention forward pass (Theorem 2). Here we provide a sketch of the proof.
We first analyze the IO complexity of the standard attention backward pass. The inputs $\mathbf{Q}, \mathbf{K}, \mathbf{V}, \mathbf{dO} \in \mathbb{R}^{N \times d}$ reside in HBM, and at the end of the algorithm the outputs $\mathbf{dQ}, \mathbf{dK}, \mathbf{dV} \in \mathbb{R}^{N \times d}$ are written to HBM.
At each step of the standard attention backward pass, one needs to load inputs of size $Nd$ or $N^2$ from HBM, and needs to write the outputs of size $N^2$ or $Nd$ to HBM. This incurs $\Theta(Nd + N^2)$ HBM accesses.
We now analyze the IO complexity of FlashAttention backward pass.
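Similar to Theorem 2, we see that each element of $\mathbf{K}$ and $\mathbf{V}$ is loaded from HBM once. Each element of $\mathbf{dK}$ and $\mathbf{dV}$ is only written to HBM once. We make $T_c$ passes over $\mathbf{Q}, \mathbf{O}, \mathbf{dO}$, each pass loading all of $\mathbf{Q}, \mathbf{O}, \mathbf{dO}$ from HBM. We also make $T_c$ passes over $\mathbf{dQ}$, each pass reading/writing all of $\mathbf{dQ}$ from/to HBM. Therefore the number of HBM accesses is $\Theta(Nd + NdT_c) = \Theta(NdT_c)$.
As in the proof of Theorem 2, the constraints on the block sizes are that:
$$B_c = \Theta\left(\frac{M}{d}\right), \qquad B_r = \Theta\left(\min\left(\frac{M}{d}, d\right)\right).$$
As a result, the number of HBM accesses is:
$$\Theta(NdT_c) = \Theta\left(\frac{N^2 d^2}{M}\right).$$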
Appendix D Extension Details
D.1 Block-sparse FlashAttention
We describe the full block-sparse FlashAttention algorithm in Algorithm 5. The algorithm is identical to Algorithm 2, except that we skip zero blocks.
13: On chip, compute $m_i^{\mathrm{new}} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_r}$, $\ell_i^{\mathrm{new}} = e^{m_i - m_i^{\mathrm{new}}}\ell_i + e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}}\tilde{\ell}_{ij} \in \mathbb{R}^{B_r}$.
14: On chip, compute $\tilde{\mathbf{P}}_{ij}^{\mathrm{dropped}} = \mathrm{dropout}(\tilde{\mathbf{P}}_{ij}, p_{\mathrm{drop}})$.
15: Write $\mathbf{O}_i \leftarrow \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1}\left(\mathrm{diag}(\ell_i)\, e^{m_i - m_i^{\mathrm{new}}}\mathbf{O}_i + e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}}\tilde{\mathbf{P}}_{ij}^{\mathrm{dropped}}\mathbf{V}_j\right)$ to HBM.
16: Write $\ell_i \leftarrow \ell_i^{\mathrm{new}}$, $m_i \leftarrow m_i^{\mathrm{new}}$ to HBM.
17: end if
18: end for
19: end for
20: Return $\mathbf{O}, \ell, m, \mathcal{R}$.
We prove the IO-complexity of block-sparse FlashAttention.
The proof is very similar to the proof of Theorem 2. For the block-sparse case, notice that we only need to load blocks corresponding to nonzero blocks. As a result, the number of HBM accesses is scaled by $s$, the fraction of nonzero blocks in the block-sparsity mask. However, for small values of $s$, we would still need to write the result $\mathbf{O} \in \mathbb{R}^{N \times d}$. Therefore the number of HBM accesses is
$$\Theta\left(Nd + \frac{N^2 d^2}{M} s\right).$$ ∎
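A minimal sketch of the block-skipping idea in plain Python: tiles whose block-mask entry is zero are never visited, and the result matches dense attention with those tiles masked to $-\infty$. Dropout is omitted, the block mask is an arbitrary example rather than the butterfly pattern, every row block is assumed to have at least one nonzero tile, and the names are hypothetical.

```python
import math


def block_sparse_attention(Q, K, V, block_mask, block=2):
    """Online-softmax attention that never visits tiles whose mask entry is 0.

    Returns the output and the number of tiles actually visited.
    """
    N, d, dv = len(Q), len(Q[0]), len(V[0])
    tau = 1.0 / math.sqrt(d)
    O = [[0.0] * dv for _ in range(N)]
    l, m = [0.0] * N, [-math.inf] * N
    visited = 0
    for bj, j0 in enumerate(range(0, N, block)):        # outer loop over K_j, V_j
        cols = range(j0, min(j0 + block, N))
        for bi, i0 in enumerate(range(0, N, block)):    # inner loop over Q_i
            if block_mask[bi][bj] == 0:                 # zero block: skip entirely
                continue
            visited += 1
            for i in range(i0, min(i0 + block, N)):
                s = [tau * sum(a * b for a, b in zip(Q[i], K[j])) for j in cols]
                m_tilde = max(s)
                p = [math.exp(x - m_tilde) for x in s]
                m_new = max(m[i], m_tilde)
                alpha, beta = math.exp(m[i] - m_new), math.exp(m_tilde - m_new)
                l_new = alpha * l[i] + beta * sum(p)
                pv = [sum(pk * V[j][c] for pk, j in zip(p, cols)) for c in range(dv)]
                O[i] = [(l[i] * alpha * O[i][c] + beta * pv[c]) / l_new for c in range(dv)]
                l[i], m[i] = l_new, m_new
    return O, visited


def masked_reference(Q, K, V, block_mask, block=2):
    """Dense attention with zero-mask tiles set to -inf (used only for checking)."""
    N, d = len(Q), len(Q[0])
    tau = 1.0 / math.sqrt(d)
    out = []
    for i in range(N):
        s = [tau * sum(a * b for a, b in zip(Q[i], K[j]))
             if block_mask[i // block][j // block] else -math.inf for j in range(N)]
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        z = sum(w)
        out.append([sum(w[j] * V[j][c] for j in range(N)) / z for c in range(len(V[0]))])
    return out
```

With the mask `[[1, 0], [1, 1]]` on a length-4 sequence, 3 of the 4 tiles are visited ($s = 3/4$), the fraction that scales the $N^2 d^2 / M$ term.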
D.2 Potential Extensions
We discuss here a few potential extensions of the IO-aware approach to speed up deep learning training.
Multi-GPU Attention. Large language models are trained on hundreds or thousands of GPUs, and one typically splits the attention computation between 4-8 GPUs on the same node. This introduces another level of memory hierarchy: besides GPU SRAM and GPU HBM, we also have the HBM of other GPUs. For very long sequences, the different GPUs on the same node can cooperate to compute attention by taking into account the asymmetry of different levels of memory hierarchy.
Sparse MLP layers. Typical dense MLP layers are compute-bound and not memory-bound. To improve their efficiency, MLP layers with sparse weight matrices can be used. However, many sparse MLP layers are instead memory-bound, and their speedup is often not proportional to the sparsity. We believe that an IO-aware implementation can alleviate this issue and realize the benefits of sparsity. We are excited about future work in this direction, to reduce the computational requirement of large models and improve their wall-clock runtime.
Kernel machine learning. Our approach in FlashAttention relies on the fact that the $N \times N$ attention matrix is a function of a low-rank matrix $\mathbf{Q}\mathbf{K}^\top$ (of rank $d \ll N$). As a result, we can repeatedly load the inputs $\mathbf{Q}, \mathbf{K}$ and recompute the block of the attention matrix that we need, significantly reducing HBM access. A similar scenario happens in kernel machine learning: each element $K_{ij}$ of the $N \times N$ kernel matrix $\mathbf{K}$ is a function of two vectors of size $d \ll N$, as it measures the similarity between two datapoints $x_i$ and $x_j$. The KeOps library is a successful example of how reducing memory reads/writes can speed up kernel operations. We hope that this will motivate kernel methods that focus more on reducing IOs instead of just FLOPs.
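The same recompute-instead-of-store pattern can be sketched for a Gaussian kernel matrix-vector product. This is a plain-Python illustration of the idea only; it does not use KeOps, and the function names and the choice of kernel are assumptions.

```python
import math


def gaussian_kernel(x, y, sigma=1.0):
    """k(x, y) = exp(-||x - y||^2 / (2 sigma^2)) for two d-dimensional points."""
    return math.exp(-sum((a - b) ** 2 for a, b in zip(x, y)) / (2 * sigma ** 2))


def kernel_matvec_streaming(X, Y, b, block=2, sigma=1.0):
    """Compute (K b)_i = sum_j k(x_i, y_j) b_j without storing the kernel matrix.

    Each column block of K is recomputed from the raw points and consumed
    immediately, so only the length-len(X) result is kept in memory.
    """
    out = [0.0] * len(X)
    for j0 in range(0, len(Y), block):
        for i, x in enumerate(X):
            for j in range(j0, min(j0 + block, len(Y))):
                out[i] += gaussian_kernel(x, Y[j], sigma) * b[j]
    return out
```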
Appendix E Full Experimental Results
E.1 BERT
We train BERT-large following the training procedure and hyperparameters of the reference MLPerf 1.1 implementation. In particular, we use the LAMB optimizer with learning rate 3.75e-3, with batch size 448, trained for at most 7100 steps. The training is stopped once the validation accuracy (for masked language modeling) reaches the target 72.0%, and the wall-clock run-time is measured. We train with FP16 precision using Apex AMP (with O2 optimization level).
We compare our results with the reported training speed from Nvidia that was submitted to MLPerf 1.1 (Table 1).
We use the same train / validation data split provided by MLPerf 1.1 reference implementation. In particular, we evaluate on the same 10000 validation examples as the baseline from Nvidia.
We train the model on 8×A100-80GB GPUs. Each training run takes between 16 and 19 minutes, and we average the results of 10 runs.
E.2 GPT-2
We use the standard implementations of GPT-2 from Huggingface transformers library and from Nvidia’s Megatron-LM repo. We follow the training recipe of the Megatron-LM repo.
We use an effective batch size of 512, and use gradient accumulation to fit into available GPU memory. We use the AdamW optimizer, with learning rate 6e-4 for GPT-2 small and 1.5e-4 for GPT-2 medium, and weight decay of 0.1. All models are trained with the same hyperparameters for 400K steps. We run all implementations with mixed-precision training (PyTorch AMP).
We use the OpenWebText dataset, with the GPT-2 BPE tokenizer. We randomly select 0.5% of the dataset as the validation set, with the rest being used as the training set. This random selection of the validation set is done once, and all models are evaluated on the same validation set.
We train the model on 8×A100-40GB GPUs, and we measure the wall-clock training time. Training GPT-2 small takes between 2.7 and 9.5 days, and training GPT-2 medium takes between 6.9 and 21.0 days (Table 2).
In Fig. 4, we plot the validation perplexity throughout training of GPT-2 small/medium, using either the HuggingFace implementation or our FlashAttention implementation. We see that FlashAttention behaves the same as the baseline implementation: the validation perplexity curves of the two implementations almost lie on top of each other.
For MIMIC-III and ECtHR, we follow the hyperparameters of Dai et al.
E.3 LRA details
We follow the hyperparameters from the Long-range arena paper , the Long-range arena repo (https://github.com/google-research/long-range-arena), and the Nyströmformer reproduction . To be generous to the baseline methods, if we are unable to reproduce the performance of any baseline for any of the five tasks, we report the better performance from Tay et al. or Xiong et al. for that baseline on that task.
After hyperparameter tuning, almost all of the attention methods achieve similar accuracy on all of the five LRA tasks.
We run all methods with mixed-precision training, except for Performer (not stable with mixed precision) and Local Attention (implementation does not support FP16).
To calculate the overall wallclock-time speedup, we take the geometric mean of the wallclock-time speedup of each of the five tasks.
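For reference, the aggregation is a one-liner with Python's `statistics.geometric_mean`. The helper name is hypothetical and any timings passed to it here are made up. The geometric mean suits ratios because a 2× speedup on one task and a 2× slowdown on another average to 1.

```python
from statistics import geometric_mean


def overall_speedup(baseline_times, method_times):
    """Geometric mean over tasks of the per-task speedup baseline / method."""
    return geometric_mean([b / m for b, m in zip(baseline_times, method_times)])
```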
For Path-X and Path-256, we follow the hyperparameters from the PathFinder-32 experiments from the long-range arena paper. For both, we first pretrain a model on Path-64. We take the checkpoint after 200 epochs, upsample its positional embedding (we duplicate the positional embeddings gridwise in space), and fine-tune it on the downstream task for 200 epochs with one epoch of linear warmup, and cosine decay of the learning rate. For Path-X, we take the best performing checkpoint (according to val accuracy), and additionally fine-tune it for 200 epochs with the same warmup and learning rate (this adds roughly 4 points of accuracy to FlashAttention for Path-X, but the model starts overfitting afterwards).
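One way to read "duplicate the positional embeddings gridwise" is nearest-neighbour upsampling of a 2D grid of embeddings, sketched below. The row-major layout of the flattened table and the function name are our assumptions; going from Path-64 ($64 \times 64$) to Path-X ($128 \times 128$) corresponds to `factor=2`.

```python
def upsample_pos_embedding(pos_emb, old_side, factor):
    """Nearest-neighbour (gridwise) duplication of a square grid of embeddings.

    pos_emb is a flattened, row-major list of old_side * old_side vectors; cell
    (r, c) of the new grid copies cell (r // factor, c // factor) of the old one.
    """
    new_side = old_side * factor
    return [list(pos_emb[(r // factor) * old_side + (c // factor)])
            for r in range(new_side) for c in range(new_side)]
```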
E.4 Comparison with Apex FMHA
We compare our method/implementation with Apex FMHA (https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha).
When we started this project, Apex FMHA was the fastest implementation of attention (that we knew of), tailored for short sequences of length at most 512. In fact, almost all MLPerf submissions for BERT training benchmark running on Nvidia GPUs use FMHA for their model code, as of MLPerf 1.1 . Since FMHA targets BERT models, it only supports head dimension 64, and only runs on A100 GPUs. FMHA fuses the attention computation into one CUDA kernel. In the forward pass, it stores the attention matrix to HBM to be used in gradient computation. As a result, it does not offer substantial memory saving (though for shorter sequences memory footprint is often not a primary concern).
We use FMHA code as a starting point, and apply two well-established techniques (tiling and recomputation) to deal with long sequences and to save memory as mentioned in Section 3. As a result, we can support much longer sequences (e.g., up to length 64K). We also support more head dimensions (16, 32, 64, 128) and broader GPU types (all Turing and Ampere GPUs at the time of writing).
In Table 7, we compare the performance of FlashAttention and Apex FMHA for short sequences (as FMHA only supports sequence length at most 512). Generally FlashAttention is slightly faster than FMHA in the forward pass and slightly slower than FMHA in the backward pass. This is because we do not store the attention matrix in the forward pass and recompute it in the backward pass. Compared to FMHA, the overall runtime of FlashAttention is about 4% slower for sequence length 128, 8% faster for sequence length 256, and 5% faster for sequence length 512.
E.5 Speedup On Different Hardware and Configurations
Speedup varies across GPU types and generations depending on HBM bandwidth and SRAM size. In this section, we profile FlashAttention speedup on different GPUs and configurations.
Figure 5 shows speedup on an A100 GPU with batch size 8, head dimension 64, and 12 attention heads, across different sequence lengths. We generally see 2-4× speedup, and we see more speedup when using dropout and masking due to kernel fusion.
Speedup also changes when we increase the head dimension. Each block requires more memory, so we need to use smaller block sizes to fit into SRAM. Figure 6 shows speedup with head dimension 128 on an A100 (batch size 16, 12 heads). We see less speedup overall, but we can still see significant speedup (up to 3×) with a causal mask, where half the blocks are masked out.
Figure 7 shows speedup on an RTX 3090 GPU. Here, we use batch size 12 with 12 attention heads. We observe slightly higher speedups on the RTX 3090 (between 2.5× and 4.5×), since the memory bandwidth on an RTX 3090 is lower than on an A100 (roughly 900 GB/s vs. 1.5 TB/s).
Figure 8 shows speedup on a T4 GPU. The T4's SRAM is smaller than the A100's, so we need to make the block sizes smaller in FlashAttention. As a result, we observe less speedup on T4, which matches the IO complexity analysis in Section 3.2. T4 GPUs are commonly used for inference, so we also report speedup on the forward pass only.
E.6 Full Benchmarking Results
We report the full benchmarking results and experimental details on A100.
We compare against reference implementations for exact attention from PyTorch/HuggingFace and Megatron, approximate attention, and sparse attention. For approximate attention, we compare against reference implementations of Reformer, Local Attention, Linformer Attention, Smyrf, and LongShortFormer (LSFormer). For sparse attention, we compare against reference implementations of Block-Sparse Attention from OpenAI, Longformer, and BigBird Attention. For the approximate and sparse attention, we use a compression ratio of 1/8, or a compressed sequence length of 256, whichever is smaller.
We measure runtime and memory usage of the attention computation with 8 heads of dimension 64, and batch size 16 on a machine with one A100 GPU with 40 GB of GPU HBM. We vary sequence length in our experiments. We compute attention on random vectors for $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ (we do not measure the projection from the hidden layer). For dropout, we use dropout 0.1; for masking, we use a padding mask with uniformly-random mask lengths between the total sequence length and the total sequence length minus 20. To measure runtime, we take the average of 100 measurements of the attention call. We only measure memory footprint once, since it does not vary between runs.
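A minimal sketch of the padding mask described above, assuming a boolean key-padding layout in which `True` marks a real token; the function name, the `max_pad` parameter, and the seeding are hypothetical choices for illustration.

```python
import random


def random_padding_mask(batch, seq_len, max_pad=20, seed=0):
    """Key-padding mask keeping between seq_len - max_pad and seq_len tokens per sequence.

    Each sequence length is drawn uniformly (inclusive bounds); True marks an
    unpadded position and False a padded one.
    """
    rng = random.Random(seed)
    mask = []
    for _ in range(batch):
        length = rng.randint(seq_len - max_pad, seq_len)
        mask.append([pos < length for pos in range(seq_len)])
    return mask
```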
We report timing results on the forward pass, backward pass, and combined forward + backward pass. We measure each method with and without dropout, masking, or both—except for Block Sparse, Longformer, and BigBird. These methods did not successfully run the backward pass with masking due to a bug in external libraries, so we measured them without masking to be generous. We use FP16 for all measurements, except for Local Attention, whose implementation only supports FP32.
For each baseline, we increase the sequence length until it runs out of memory on the GPU, with the following exceptions: The Megatron implementation does not support sequence lengths longer than 2048. Block-Sparse (OpenAI) does not support sequence lengths longer than 4096. Longformer and BigBird do not support sequence lengths longer than 8092.
We measure memory usage on the combined forward + backward pass, without dropout or masking.
Table 8 summarizes all the experimental configurations and contains pointers to the results tables.