Self-attention Does Not Need $O(n^2)$ Memory

Markus N. Rabe, Charles Staats

Introduction

Attention (Bahdanau et al., 2015) is widely used in modern neural architectures. In particular, it is the heart of the Transformer architecture (Vaswani et al., 2017), which has revolutionized Natural Language Processing (Devlin et al., 2019) and has since found widespread adoption across several research areas.

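Given a query $q\in\mathbb{R}^{d}$ and lists of keys and values $k_{1},\dots,k_{n}$ and $v_{1},\dots,v_{n}\in\mathbb{R}^{d}$ of length $n$, attention is defined as follows:

$$s_{i}=\mathrm{dot}(q,k_{i}),\qquad s^{\prime}_{i}=\frac{e^{s_{i}}}{\sum_{j}e^{s_{j}}},\qquad \mathrm{attention}(q,k,v)=\sum_{i}v_{i}s^{\prime}_{i}.$$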

The result of the attention operation for a single query is hence a weighted sum of the value vectors, where the weights are the softmax of the dot products of the query and the keys.

The straightforward implementation of the attention operation above requires us to first compute and remember $s_{i}$ for all $i$, leading to $O(n)$ time and memory complexity for each query. Transformers use self-attention, which issues a separate query for each position in the sequence, so the overall time and space complexity is $O(n^{2})$.
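To make the memory accounting concrete, here is a minimal sketch in plain Python (standard library only, with lists standing in for arrays, rather than the JAX used later in the paper). It follows the definition literally, without numerical stabilization; the helper names `dot`, `softmax_weighted_sum`, `naive_attention`, and `naive_self_attention` are hypothetical. The single-query function keeps all $n$ scores $s_i$ alive before normalizing, and the self-attention variant materializes the full $n \times n$ score matrix, as a batched implementation would.

```python
import math


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def softmax_weighted_sum(scores, values):
    """sum_i v_i * e^{s_i} / sum_j e^{s_j}, following the definition literally (no stabilization)."""
    exps = [math.exp(s) for s in scores]  # e^{s_i} for every i
    total = sum(exps)                      # sum_j e^{s_j}
    dim = len(values[0])
    return [sum(e * v[c] for e, v in zip(exps, values)) / total for c in range(dim)]


def naive_attention(q, keys, values):
    """Single query: all n scores s_i are stored before the softmax, hence O(n) memory."""
    scores = [dot(q, k) for k in keys]
    return softmax_weighted_sum(scores, values)


def naive_self_attention(queries, keys, values):
    """One query per position: the full n x n score matrix is built first, hence O(n^2) memory."""
    score_matrix = [[dot(q, k) for k in keys] for q in queries]
    return [softmax_weighted_sum(row, values) for row in score_matrix]
```

With a zero query all scores are equal, so `naive_attention([0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])` returns the plain average `[2.0, 3.0]`.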

In many works, the quadratic time and space complexity of self-attention has been used as the motivation for investigating variants of the original attention mechanism and architectures with more favorable complexity classes (Kitaev et al., 2020; Roy et al., 2021; Zaheer et al., 2020; Choromanski et al., 2020; Wang et al., 2020; Ren et al., 2021; Child et al., 2019; Tay et al., 2021; Wang et al., 2020; Ma et al., 2021; Shen et al., 2021; Qiu et al., 2020). Modern accelerator hardware, such as GPUs and TPUs, is often memory-constrained for deep learning applications, while compute is relatively cheap. The space complexity of Transformers is therefore a particular concern, cf. Kitaev et al. (2020); Roy et al. (2021); Zaheer et al. (2020).

In this work, we present new algorithms for attention and self-attention that require only constant memory and logarithmic memory, respectively. The basic algorithm is very simple, but it requires a trick to make it numerically feasible (see Section 3). We also present an implementation in JAX (Bradbury et al., 2018), which runs efficiently on TPUs and requires $O(\sqrt{n})$ memory for self-attention (see Section 4).

Unlike other works that aim to reduce the memory complexity of attention, the memory-efficient algorithm for attention that we suggest is not an approximation, but computes the same function. We can hence use the memory-efficient algorithm as a drop-in replacement for other attention implementations to save memory. This may allow us to reconsider architecture choices, or scale to new datasets that require longer, dense attention. However, our algorithm still requires $O(n^{2})$ time complexity for self-attention and $O(n)$ time complexity for single-query attention, and the various efficient, long-context attention mechanisms remain an interesting alternative to (dense) attention.

Algorithm

First, we present the algorithm for the attention operation with a single query, and at the end of this section we extend it to self-attention. We observe that the division by $\sum_{j}e^{s_{j}}$ can be moved to the very end of the attention operation using the distributive law:

$$\sum_{i}v_{i}\frac{e^{s_{i}}}{\sum_{j}e^{s_{j}}}=\frac{\sum_{i}v_{i}e^{s_{i}}}{\sum_{j}e^{s_{j}}}\tag{1}$$

After publishing our initial draft, we were made aware that (1) is a rediscovery of the “lazy softmax” method of Jang et al. (2019, equation 4). Unfortunately, their paper went in a different direction and did not discuss the memory complexity implications and the other innovations we present in the remainder of this paper. For more details, see Section 6.

This can be computed with constant memory: the memory overhead of this algorithm consists of a vector $v^{*}\in\mathbb{R}^{d}$ and a scalar $s^{*}\in\mathbb{R}$, both initialized with 0. Given the query $q$, keys $k_{1},\dots,k_{n}$ and values $v_{1},\dots,v_{n}$, we process the keys and values in sequence. Given a key-value pair $k_{i}$, $v_{i}$, we compute $s_{i}=\mathrm{dot}(q,k_{i})$ and update $v^{*}\leftarrow v^{*}+v_{i}e^{s_{i}}$ and $s^{*}\leftarrow s^{*}+e^{s_{i}}$. After processing all keys and values, we divide $\frac{v^{*}}{s^{*}}$ to get the final result.

The analysis of space complexity assumes that inputs are given in a particular order: we first read the query, and then a list of pairs of keys and values. If the inputs are provided in a different order, we have to additionally store an index into the sequence, requiring $O(\log n)$ memory instead.

To extend this algorithm to self-attention, we compute the results for all queries sequentially. This requires just one additional index into the list of queries, giving rise to the $O(\log n)$ memory complexity. Note that the operation produces outputs whose size is linear in the number of queries, i.e., $O(n)$, which is not counted towards the space complexity.
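A minimal plain-Python sketch of the constant-memory algorithm (standard library only; the names `lazy_softmax_attention` and `lazy_self_attention` are hypothetical, and like the formulation above it is not yet numerically stable). Apart from the output, the single-query function keeps only $v^{*}$ and $s^{*}$; because keys and values are consumed in a single pass, they may even be supplied as iterators. Self-attention simply repeats this for every query.

```python
import math


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def lazy_softmax_attention(q, keys, values):
    """Single-query attention keeping only v_star (d floats) and s_star (one float).

    keys and values may be any iterables, so they can be consumed as a stream.
    Assumes, as in the text, that q, k_i and v_i all live in R^d.
    """
    d = len(q)
    v_star = [0.0] * d  # running sum of v_i * e^{s_i}
    s_star = 0.0        # running sum of e^{s_i}
    for k, v in zip(keys, values):
        w = math.exp(dot(q, k))  # e^{s_i}; s_i itself is not kept
        for c in range(d):
            v_star[c] += v[c] * w
        s_star += w
    return [x / s_star for x in v_star]  # the single division, moved to the end


def lazy_self_attention(queries, keys, values):
    """Run the single-query algorithm for each query in turn.

    The returned list of outputs is O(n), which is not counted as working memory.
    """
    return [lazy_softmax_attention(q, keys, values) for q in queries]
```

For example, `lazy_softmax_attention([0.0, 0.0], iter([[1.0, 0.0], [0.0, 1.0]]), iter([[1.0, 2.0], [3.0, 4.0]]))` returns `[2.0, 3.0]`, the same result as the two-pass definition.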

Numerical Stability

The formulation of standard attention that we presented in the Introduction, as well as our memory-efficient algorithm, is not numerically stable when using floating-point arithmetic, because the softmax exponentiates the scores. For scores $\geq 89$ the exponentiation results in inf (for bfloat16 and float32), which is carried through to the final result of the attention operation. In practice, the softmax is implemented by subtracting the maximum score from all scores. This does not change the result of the softmax, but avoids this numerical problem.

Our incremental computation of the sum of exponentiated scores (and the values times the scores) does not immediately allow for the same trick, as the maximum may depend on the last score in the sequence. But the subtraction cannot be delayed either, since the scores must be exponentiated before they can be added to the cumulative sum.

To resolve this problem, we introduce an additional scalar $m^{*}$, which keeps track of the maximum score that the incremental algorithm has seen so far, and we renormalize the sums of exponentiated values as needed: we initialize the vector $v^{*}\in\mathbb{R}^{d}$ and scalar $s^{*}\in\mathbb{R}$ with 0, and $m^{*}$ with $-\infty$. As before, given a key-value pair $k_{i}$, $v_{i}$, we compute $s_{i}=\mathrm{dot}(q,k_{i})$, but then the algorithm differs slightly from Section 2. We first compute $m_{i}=\max(m^{*},s_{i})$ and update $v^{*}\leftarrow v^{*}e^{m^{*}-m_{i}}+v_{i}e^{s_{i}-m_{i}}$ and $s^{*}\leftarrow s^{*}e^{m^{*}-m_{i}}+e^{s_{i}-m_{i}}$ and $m^{*}\leftarrow m_{i}$. After processing all keys and values, we divide $\frac{v^{*}}{s^{*}}$ to get the final result.
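A plain-Python sketch of this running-maximum update (standard library only; `stable_lazy_attention` is a hypothetical name, and, as in the text, $q$, $k_i$ and $v_i$ are assumed to share dimension $d$). Initializing $m^{*}$ to $-\infty$ works because `math.exp(-math.inf)` is `0.0`, so the first step simply discards the empty sums.

```python
import math


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def stable_lazy_attention(q, keys, values):
    """Single-query streaming attention with a running maximum m_star.

    Assumes q, k_i and v_i all live in R^d, as in the text.
    """
    d = len(q)
    v_star = [0.0] * d
    s_star = 0.0
    m_star = -math.inf  # largest score seen so far
    for k, v in zip(keys, values):
        s_i = dot(q, k)
        m_i = max(m_star, s_i)
        scale = math.exp(m_star - m_i)  # renormalizes the old sums; exp(-inf) == 0.0 at step 1
        w = math.exp(s_i - m_i)         # exponent <= 0, so this cannot overflow
        for c in range(d):
            v_star[c] = v_star[c] * scale + v[c] * w
        s_star = s_star * scale + w
        m_star = m_i
    return [x / s_star for x in v_star]
```

With query `[100.0, 0.0]` and keys `[[10.0, 0.0], [11.0, 0.0]]` the scores are 1000 and 1100, far beyond the overflow threshold, yet the function returns (approximately) the second value vector, because the old sums are rescaled by $e^{-100}$ when the maximum changes.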

An Implementation For TPUs

In this section, we provide a version of the algorithm above that exploits the massive parallelism of modern hardware, such as GPUs or TPUs. The naive algorithm above is not trivial for a compiler to parallelize, as the incremental sum introduces a dependency across all keys and values.

We present the entire implementation, including support for multiple attention heads and memory-efficient differentiation, in Figure 1. The implementation does not optimize strictly for memory efficiency, but instead aims to strike a balance between simplicity, computational efficiency, and memory requirements.

To exploit the parallelism available in modern hardware, we split the computation into chunks at the cost of some additional memory. In the outer loop (lines 54-55), we split the queries into chunks of constant size, resulting in a linear number of iterations. In each iteration of the outer loop, we call _query_chunk_attention, which itself processes the keys and values in chunks (lines 30-31). The chunks are processed sequentially, and each chunk is summarized independently (lines 12 to 19). Assuming a chunk size of $\sqrt{n}$ for the keys and values, we hence obtain $\sqrt{n}$ summaries, giving rise to the $O(\sqrt{n})$ memory complexity.

After the summaries are computed, they need to be rescaled (lines 33 to 36) along the lines of Section 3, before we return the values divided by the weights (line 40). The result of each iteration of the outer loop is directly written to the output tensor res (line 54), so that no additional memory is consumed across iterations. (A multi-stage summarization approach could achieve $O(\log n)$ memory but would complicate the implementation.)
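The chunk-and-summarize structure can be sketched in plain Python as follows. This is not the JAX code of Figure 1: it is a hypothetical, sequential illustration (standard library only) in which the loops stand in for what an accelerator would vectorize. Each key/value chunk is reduced to a summary (chunk maximum, plus value sum and weight sum relative to that maximum); the summaries are then rescaled to the global maximum, added up, and divided once. The keyword arguments reuse the names `query_chunk_size` and `key_chunk_size` from the text; `summarize_chunk` and `chunked_attention` are hypothetical names.

```python
import math


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def summarize_chunk(q, key_chunk, value_chunk):
    """Summarize one key/value chunk for one query as (chunk max, value sum, weight sum).

    Only len(key_chunk) scores are alive at a time; both sums are taken relative to
    the chunk's own maximum, so they can be rescaled later.
    """
    scores = [dot(q, k) for k in key_chunk]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    d = len(value_chunk[0])
    v_sum = [sum(e * v[c] for e, v in zip(exps, value_chunk)) for c in range(d)]
    return m, v_sum, sum(exps)


def chunked_attention(queries, keys, values, *, query_chunk_size, key_chunk_size):
    """Self-attention over query chunks (outer loop) and key/value chunks (inner loop)."""
    d = len(values[0])
    out = []
    for q0 in range(0, len(queries), query_chunk_size):
        # Pure Python handles the queries of a chunk one by one; an accelerator would batch them.
        for q in queries[q0:q0 + query_chunk_size]:
            summaries = [
                summarize_chunk(q, keys[k0:k0 + key_chunk_size], values[k0:k0 + key_chunk_size])
                for k0 in range(0, len(keys), key_chunk_size)
            ]
            global_max = max(m for m, _, _ in summaries)
            v_total = [0.0] * d
            s_total = 0.0
            for m, v_sum, s_sum in summaries:
                scale = math.exp(m - global_max)  # rescale each summary to the common maximum
                for c in range(d):
                    v_total[c] += v_sum[c] * scale
                s_total += s_sum * scale
            out.append([x / s_total for x in v_total])  # divide once at the end
    return out
```

With `key_chunk_size` close to $\sqrt{n}$, at most about $\sqrt{n}$ scores and $\sqrt{n}$ summaries are alive per query at any time, which is the source of the $O(\sqrt{n})$ bound.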

While a constant chunk size for the queries and a chunk size of $\sqrt{n}$ for the keys and values is optimal for memory consumption, in practice the runtime also depends on the chunk sizes, and this dependence varies strongly with the hardware. Ultimately, we have to leave this trade-off to the programmer, and expose the chunk sizes as arguments query_chunk_size and key_chunk_size. In Figure 1 we provide default values for the chunk sizes that lead to minimal runtime impact on TPU, while still providing significant memory savings.
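Why $\sqrt{n}$ is the memory-optimal key chunk size can be seen from a deliberately simplified cost model, which is an assumption for illustration: per query, one chunk of `key_chunk_size` scores is alive, plus one summary per chunk, each counted as a single number. (A real summary holds $d+2$ numbers, which shifts the optimum by a constant factor but not its $O(\sqrt{n})$ scaling.) The cost $c + n/c$ is minimized at $c=\sqrt{n}$. The function names are hypothetical.

```python
def live_floats_per_query(n, key_chunk_size):
    """Simplified model (an assumption): one chunk of scores plus one summary per key chunk,
    counting each summary as a single number."""
    num_summaries = (n + key_chunk_size - 1) // key_chunk_size  # ceil(n / key_chunk_size)
    return key_chunk_size + num_summaries


def best_power_of_two_chunk(n):
    """Power-of-two key chunk size that minimizes the simplified model."""
    candidates = [2 ** p for p in range(n.bit_length())]
    return min(candidates, key=lambda c: live_floats_per_query(n, c))
```

For $n = 2^{20}$, `best_power_of_two_chunk` returns 1024, giving 2048 live numbers per query instead of $2^{20}+1$ for a single chunk covering all keys.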

Empirical Analysis

In this section, we experimentally compare the memory requirements and runtime performance of the suggested algorithm with the implementation of attention currently provided by Flax (Heek et al., 2020; see flax/linen/attention.py). We open-sourced the code of our implementation and most of the evaluation as a colab to help others reproduce the results: https://github.com/google-research/google-research/tree/master/memory_efficient_attention.

In Table 2 we compare the memory requirements and the compute time of the memory-efficient attention implementation and the Flax implementation of attention. The size of inputs and outputs includes the query, key, and value tensors of dtype bfloat16, and the output tensor of dtype float32. We measure the memory overhead as the TPU's peak memory in excess of the input and output tensors. All computations were done on a single TPUv3 chip. For this experiment, we only use one attention head.

Our memory-efficient implementation of attention removes the memory bottleneck of self-attention, scaling at least to a sequence length of 1M. At this sequence length the algorithm is multiplying over 1 trillion combinations of queries and keys. The time complexity is still quadratic.

The “relative compute speed” of the implementations was computed as the median over 100 runs, but the numbers still fluctuated across multiple runs of the evaluation, and we only provide them to demonstrate that the runtime performance is roughly similar. Please note that this experiment analyzes the attention operation in isolation; the measured relative performance is not necessarily the same when the operations are embedded in larger architectures. In fact, we observed a slight increase in steps/sec of about 4% when training a small Transformer.

For all cases where standard attention would not run out of memory (OOM, i.e. require $>16$GB of device memory), we checked that the results of the two implementations are within $1.8\times 10^{-7}$ for inputs drawn from a normal distribution with standard deviation $1$ (measured as the maximal absolute difference of any dimension in a self-attention over sequence length $2^{14}$).
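The accuracy check described above can be expressed as a small harness, sketched here in plain Python (standard library only) with hypothetical names. It assumes both implementations take `(queries, keys, values)` as lists of vectors and return one output vector per query; queries, keys and values are drawn independently from a standard normal distribution with a fixed seed. Unlike the experiment, it runs in float64 and is only practical for short sequences.

```python
import random


def max_abs_difference(attention_a, attention_b, seq_len, dim, seed=0):
    """Largest absolute difference, over all positions and dimensions, between two
    self-attention implementations on inputs drawn from N(0, 1).

    attention_a / attention_b are hypothetical callables taking (queries, keys, values)
    as lists of length-dim vectors and returning one output vector per query.
    """
    rng = random.Random(seed)

    def draw():
        return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(seq_len)]

    queries, keys, values = draw(), draw(), draw()
    out_a = attention_a(queries, keys, values)
    out_b = attention_b(queries, keys, values)
    return max(abs(x - y) for row_a, row_b in zip(out_a, out_b) for x, y in zip(row_a, row_b))
```

Passing the same implementation twice returns `0.0`; passing a two-pass softmax attention and a streaming variant with the same signature should give a difference near machine precision.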

Differentiation

During the forward pass, our algorithm saves memory by summarizing parts of the attention matrix sequentially, allowing it to forget the parts of the attention matrix it has already summarized. A naive application of differentiation would have to store all those intermediate results, and our algorithm would lose its memory advantage entirely. We therefore apply checkpointing (Chen et al., 2016) in line 11 to the function that summarizes the individual chunks. The intermediate results can thus be forgotten during the forward pass and recomputed during backpropagation.

In Table 3 we compare runtime and peak memory during differentiation of our implementation to standard attention. We used the same setting as for the forward pass, but applied jax.grad to an arbitrarily chosen loss function (the sum of the results). The relative compute speed was reduced significantly compared to standard attention. This is expected when using checkpointing since some values must be recomputed during backpropagation.

Note that applying checkpointing to the standard attention algorithm would not achieve these results. The standard algorithm with checkpointing would forget the attention matrix after it is formed; our algorithm never forms the full attention matrix at all.

Training

We integrated our memory-efficient implementation into a simple Transformer architecture provided in the Flax library and ran the WMT en-de translation experiment with the standard attention module and with the memory-efficient attention module. Throughout training, the two implementations behaved almost identically. After 100K training steps, the evaluation accuracy reached 62.69 for the memory-efficient implementation and 62.59 for the standard implementation. This demonstrates that our memory-efficient implementation of self-attention can be used to replace existing implementations. Figure 4 illustrates that both models reached very similar BLEU scores. We used the default settings for the WMT en-de experiment as given in the Flax implementation, except that we had to deactivate example packing to simplify the masking code. This also required us to lower the learning rate to 0.005.

Figure 4: BLEU scores of two Transformer models trained with standard attention and memory-efficient attention (x-axis: training step, up to $1\cdot 10^{5}$; y-axis: BLEU score; series: Standard attn, Memory-efficient attn).

Figure 5: Left: Relative runtime of self-attention on sequence length $2^{15}$ using query chunking compared to standard attention (x-axis: query chunk size, $10^{0}$ to $10^{4}$; y-axis: relative runtime of query chunking in %). Right: Relative runtime of self-attention using query chunking compared to our memory-efficient algorithm, where both are restricted to the same amount of memory (x-axis: sequence length for self-attention, $10^{4}$ to $10^{6}$; y-axis: relative runtime in %).

Comparison to Query Chunking

The algorithms introduced in this work chunk both the keys and the queries. Chunking only the queries has already been explored by Kitaev et al. (2020), but it is folklore that it slows down the computation significantly. In Figure 5 (left), we plot the runtime of self-attention using query chunking for different query chunk sizes compared to dense self-attention: for small chunk sizes (e.g. $\leq 64$) the performance indeed suffers, but for large chunk sizes the loss of performance is less significant. So, while lower memory consumption can be achieved by query chunking alone, small query chunk sizes are impractical.

In comparison to query chunking, memory-efficient attention can save additional memory by also chunking the keys. This can help to keep the query chunk size at a desirable point given a fixed memory limit. In Figure 5 (right), we constrained query chunking to the amount of memory that is used by memory-efficient attention with the default settings for key and query chunk size (see Table 2, “Memory overhead of memory-efficient att.”; we rounded the query chunk size in favor of query chunking). We see that as the sequence length increases, query chunking eventually slows down significantly, as the query chunk size has to be lowered to $\leq 64$, while memory-efficient attention does not suffer a major slowdown (see Table 2, “Relative compute speed”). So, in memory-constrained scenarios, memory-efficient attention can outperform query chunking.
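The memory-budget argument can be made concrete with simple arithmetic. With query chunking alone, each chunk materializes a block of `query_chunk_size` × $n$ scores, so under a fixed budget the largest admissible query chunk shrinks in proportion to $1/n$. The budget of $2^{24}$ score entries below is a hypothetical value chosen for illustration, not the one used in the experiments, and `query_chunk_limit` is a hypothetical name.

```python
def query_chunk_limit(budget_entries, seq_len):
    """Largest query chunk whose (chunk x seq_len) score block fits in the budget."""
    return budget_entries // seq_len


BUDGET = 2 ** 24  # hypothetical budget, in score entries
for seq_len in (2 ** 14, 2 ** 16, 2 ** 18, 2 ** 20):
    print(seq_len, query_chunk_limit(BUDGET, seq_len))
# 16384 1024
# 65536 256
# 262144 64
# 1048576 16
```

Under this budget, the query chunk size reaches 64 at $n = 2^{18}$ and keeps shrinking, whereas chunking the keys as well keeps the live score block at `query_chunk_size` × `key_chunk_size`, independent of $n$ (the per-chunk summaries aside).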

Related Work

After publishing our initial draft, we were made aware that Jang et al. (2019) already observed that the division of the softmax operation can be delayed until the end of the attention operation (“lazy softmax”), similar to our Equation (1). But their paper does not discuss memory complexity at all. They also do not address numerical stability or backpropagation, and, as far as we know, there is no publicly available implementation of their work. Instead they use this algorithm to reduce the memory bandwidth for inference when sharding key-value pairs across multiple chips.

Dao et al. (2022) provide a CUDA implementation of memory-efficient attention and demonstrate that the reduced memory requirements can translate to significant speedups on GPUs. One reason why we do not observe the same performance gains in this paper is that standard self-attention already balances the available FLOPs and memory bandwidth of TPUs.

Conclusion

This paper presents a simple trick to reduce the memory requirement of (self-)attention dramatically, which appears to have been simply overlooked by the community. We hope that this short paper raises awareness of the fact that attention is not intrinsically memory-hungry, which may allow us to revisit some of the design choices in popular neural architectures and hardware architectures.

Acknowledgements

We want to thank Andrew Jaegle for discussions on this paper, and for experimenting with memory-efficient attention in the context of Perceiver (Hawthorne et al., 2022). We are glad to see that the algorithm proposed here has already found interest and would like to thank Rezaei (2021) and Wang (2022) for reimplementations in JAX and PyTorch with additional features like masking. We also want to thank DeLesley Hutchins for detailed feedback on our draft.

References