LongNet: Scaling Transformers to 1,000,000,000 Tokens
Jiayu Ding, Shuming Ma, Li Dong, Xingxing Zhang, Shaohan Huang, Wenhui Wang, Nanning Zheng, Furu Wei
1 Introduction
Recent years have witnessed a trend toward scaling neural networks. Depth is primarily scaled up for exponential expressivity, producing many powerful deep networks. Then, sparse MoE models and model parallelism approaches efficiently enlarge the hidden dimension. Sequence length, as the last atomic dimension of the neural network, should ideally be unlimited. Breaking the limitation of sequence length introduces significant advantages. First, it provides models with a large memory and receptive field, which helps them interact with humans and the world. Second, a longer context contains more complex causality and reasoning paths that models can exploit in training data. In contrast, short dependencies contain more spurious correlations, which harm generalization. Third, it makes it possible to explore the limits of in-context learning, which has the potential to be a paradigm shift for many-shot learning, as an extremely long context may help models alleviate catastrophic forgetting.
The major challenge of scaling up sequence length is striking the right balance between computational complexity and model expressivity. RNN-style models are primarily used to increase the length. However, their sequential nature limits parallelization during training, which is essential in long-sequence modeling. More recently, state space models have become appealing for sequence modeling. They can operate as a CNN during training and be transformed into an efficient RNN at test time. While they perform well on long-range benchmarks, their performance at regular lengths is not as good as that of Transformers, limited mainly by model expressivity.
Another strand of scaling the sequence length is to decrease the complexity of Transformers, i.e., the quadratic complexity of self-attention. Implementing sliding windows or convolution modules over the attention is a straightforward way to make the complexity nearly linear. Nevertheless, this sacrifices the ability to recall early tokens, forgetting the prompts at the very beginning of the sequence. Sparse attention reduces the computation by sparsifying the attention matrix, preserving the possibility of recalling long-distance information. For example, Sparse Transformer obtains $O(N\sqrt{N}d)$ time complexity with a fixed sparse pattern. Besides heuristic patterns, learnable patterns prove to be useful for sparse attention. There are also other efficient Transformer-based variants, including low-rank attention, kernel-based methods, downsampling approaches, recurrent models, and retrieval-based methods. Yet none has been scaled to 1 billion tokens (see Figure 1).
In this work, we successfully scale the sequence length to 1 billion tokens. Our solution is LongNet, which replaces the attention of vanilla Transformers with a novel component named dilated attention. The general design principle is that attention allocation decreases exponentially as the distance between tokens grows. We prove that it achieves linear computation complexity and a logarithmic dependency between tokens. This resolves the tension between limited attention resources and the accessibility of every token. In the implementation, LongNet can be transformed into a dense Transformer, which seamlessly supports off-the-shelf optimizations for Transformers (e.g., kernel fusion, quantization, and distributed training). Taking advantage of the linear complexity, LongNet can parallelize training across nodes, breaking both the computation and memory constraints with a distributed algorithm. This allows us to efficiently scale up the sequence length to 1B tokens with nearly constant runtime (see Figure 5), while the vanilla Transformer suffers from quadratic complexity.
2 LongNet
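Self-attention struggles with long sequences due to its quadratic dependency on the sequence length: each query attends to all keys and values, leading to computational inefficiency. Given the inputs $Q, K, V \in \mathbb{R}^{N \times d}$, single-head attention computes the output $O$ as

$$O = \mathrm{softmax}(QK^{T})\,V \tag{1}$$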
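Sparse attention alleviates this issue by restricting each query's access to a subset of keys and values. The key component of sparse attention is the sparse attention pattern $S \in \{0, 1\}^{N \times N}$, which determines the specific keys and values that each query can attend to:

$$O = \mathrm{softmax}(QK^{T} + M_S)\,V, \qquad (M_S)_{ij} = \begin{cases} 0 & S_{ij} = 1 \\ -\infty & S_{ij} = 0 \end{cases} \tag{2}$$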
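To make Equations 1 and 2 concrete, the following pure-Python sketch computes attention restricted to a pattern. It is an illustrative assumption, not the paper's kernel: the helper names `masked_attention` and `dense` are hypothetical, inputs are plain lists of row vectors, and keys outside $S$ are simply skipped, which is equivalent to adding $-\infty$ before the softmax.

```python
import math

def masked_attention(q, k, v, allowed):
    """Compute softmax(Q K^T + M_S) V row by row (no 1/sqrt(d) scaling, as in Equation 1).

    allowed(i) returns the key indices j with S_ij = 1 for query i; all other
    keys are skipped, which matches a -inf mask entry.
    """
    out = []
    for i, qi in enumerate(q):
        keys = list(allowed(i))
        scores = [sum(a * b for a, b in zip(qi, k[j])) for j in keys]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * v[j][c] for w, j in zip(weights, keys)) / z
                    for c in range(len(v[0]))])
    return out

def dense(n):
    """Vanilla attention pattern: every query may attend to all n keys."""
    return lambda i: range(n)
```

Passing `dense(N)` recovers Equation 1, while a causal pattern such as `lambda i: range(i + 1)` gives an autoregressive mask.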
For example, the fixed pattern of Sparse Transformer is composed of a local pattern and a strided pattern. The sequence is divided into blocks of length $l$. The local pattern allows a query to attend to tokens within the same block, while the strided pattern allows a query to attend to the last $c$ tokens of each block. Formally, the local pattern is $A^{(1)}_i = \{j \mid \lfloor j/l \rfloor = \lfloor i/l \rfloor\}$, and the strided pattern is $A^{(2)}_i = \{j \mid j \bmod l \in \{t, t+1, \ldots, l-1\}\}$, where $t = l - c$.
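A minimal sketch of the fixed pattern's index sets, assuming 0-based positions. The function names are hypothetical, and the optional causal filter (keeping only keys $j \le i$) reflects the autoregressive setting rather than the set definitions themselves.

```python
def local_pattern(i, n, l):
    """A^(1)_i: keys j in the same length-l block as query i."""
    return {j for j in range(n) if j // l == i // l}

def strided_pattern(n, l, c):
    """A^(2): keys j in the last c positions of every block, i.e. j mod l >= t with t = l - c."""
    t = l - c
    return {j for j in range(n) if j % l >= t}

def fixed_pattern(i, n, l, c, causal=True):
    """Union of the local and strided patterns; optionally keep only keys j <= i."""
    keys = local_pattern(i, n, l) | strided_pattern(n, l, c)
    if causal:
        keys = {j for j in keys if j <= i}
    return sorted(keys)

print(fixed_pattern(5, n=8, l=4, c=1))                # [3, 4, 5]
print(fixed_pattern(5, n=8, l=4, c=1, causal=False))  # [3, 4, 5, 6, 7]
```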
2.2 Dilated Attention
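Figure 2 illustrates the overview of dilated attention. Dilated attention splits the input ($Q$, $K$, $V$) into $N/w$ segments of equal length $w$. Each segment is then sparsified along the sequence dimension by selecting the rows with an interval $r$. The computation can be written as:

$$\widetilde{Q}_i = [Q_{iw}, Q_{iw+r}, Q_{iw+2r}, \ldots, Q_{(i+1)w-r}] \tag{3}$$

$$\widetilde{K}_i = [K_{iw}, K_{iw+r}, K_{iw+2r}, \ldots, K_{(i+1)w-r}] \tag{4}$$

$$\widetilde{V}_i = [V_{iw}, V_{iw+r}, V_{iw+2r}, \ldots, V_{(i+1)w-r}] \tag{5}$$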
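The sparsified segments $\{(\widetilde{Q}_i, \widetilde{K}_i, \widetilde{V}_i)\}^{N/w}$ are fed into the attention in parallel, after which the outputs are scattered and concatenated as the output $O$:

$$\widetilde{O}_i = \mathrm{softmax}(\widetilde{Q}_i \widetilde{K}_i^{T})\,\widetilde{V}_i \tag{6}$$

$$\hat{O}_{i,j} = \begin{cases} \widetilde{O}_{i,\,j/r} & j \bmod r = 0 \\ 0 & j \bmod r \neq 0 \end{cases}, \qquad j = 0, 1, \ldots, w-1 \tag{7}$$

$$O = [\hat{O}_0, \hat{O}_1, \ldots, \hat{O}_{N/w-1}] \tag{8}$$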
In the implementation, dilated attention can be transformed into dense attention between a gathering operation over the input $(Q, K, V)$ and a scattering operation over the output $\widetilde{O}_i$, so it can directly reuse any optimization for vanilla attention (e.g., FlashAttention). Dilated attention reduces the computation cost by a factor of $\frac{N}{w} r^2$ over vanilla attention, since the cost drops from $O(N^2 d)$ to $O\big(\frac{N}{w}(\frac{w}{r})^2 d\big)$: each of the $N/w$ segments attends over only $w/r$ rows.
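The gather, dense attention, and scatter structure can be sketched in pure Python as follows. This is a non-causal illustration under stated assumptions ($w$ divides $N$, $r$ divides $w$, no $1/\sqrt{d}$ scaling, hypothetical function names), not the FlashAttention-based implementation; `dilated_flops` simply counts $2\cdot\frac{N}{w}(\frac{w}{r})^2 d$ so the $\frac{N}{w}r^2$ reduction factor can be checked.

```python
import math

def attention(q, k, v):
    """Dense softmax(q k^T) v over lists of row vectors (Equation 6 for one segment)."""
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / z
                    for c in range(len(v[0]))])
    return out

def dilated_attention(q, k, v, w, r):
    """One dilated pattern: segments of length w, keep every r-th row, attend, scatter back."""
    n, d = len(q), len(v[0])
    assert n % w == 0 and w % r == 0, "sketch assumes w divides N and r divides w"
    out = [[0.0] * d for _ in range(n)]
    for start in range(0, n, w):
        rows = list(range(start, start + w, r))  # gather rows iw, iw+r, ..., (i+1)w-r
        seg = attention([q[j] for j in rows], [k[j] for j in rows], [v[j] for j in rows])
        for j, o in zip(rows, seg):              # scatter; unselected rows stay zero (Eq. 7)
            out[j] = o
    return out

def dilated_flops(n, d, w, r):
    """(N/w) segments, each costing 2 * (w/r)^2 * d."""
    return 2 * (n // w) * (w // r) ** 2 * d
```

With `w=N` and `r=1` the function reduces to vanilla attention, which is a convenient sanity check; for $N = 8192$, $w = 2048$, $r = 2$, the flop count drops by $\frac{N}{w} r^2 = 16$.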
In practice, the segment size $w$ trades the globality of attention for efficiency, while the dilation rate $r$ reduces the computation cost by approximating the attention matrix. To capture both long-range and short-range information efficiently, we implement a mixture of dilated attentions with different segment sizes and dilation rates $\{r_i, w_i\}_{i=0}^{k-1}$:

$$O = \sum_{i=0}^{k-1} \lambda_i \, O|_{r_i, w_i} \tag{9}$$

$$\lambda_i = \frac{s_i}{\sum_{j=0}^{k-1} s_j} \tag{10}$$
where $s_i$ denotes the denominator of the attention softmax for $O|_{r_i, w_i}$. Note that the computations of $\{O|_{r_i, w_i}\}_{i=0}^{k-1}$ can run in parallel because there is no computational dependency among them. Experiments show that dynamic weights computed from the softmax denominators work better than learnable fixed weights. When a query attends to keys in different dilated attentions, this way of mixing them is equivalent to gathering the keys from the different parts and computing a single softmax over them.
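The equivalence in the last sentence can be checked numerically. The sketch below (hypothetical helper names, a single query, raw exponentials without max subtraction for readability) computes each pattern's output together with its softmax denominator $s_i$ and mixes them with $\lambda_i = s_i / \sum_j s_j$. When the patterns cover disjoint key sets, the result matches one softmax over all keys; keys shared by several patterns would be counted once per pattern.

```python
import math

def attend_with_denominator(q, keys, values):
    """Attention for one query over a key subset; also return the softmax denominator s."""
    weights = [math.exp(sum(a * b for a, b in zip(q, kj))) for kj in keys]
    s = sum(weights)
    o = [sum(w * vj[c] for w, vj in zip(weights, values)) / s
         for c in range(len(values[0]))]
    return o, s

def mix(parts):
    """Equations 9-10: O = sum_i lambda_i * O_i with lambda_i = s_i / sum_j s_j."""
    total = sum(s for _, s in parts)
    d = len(parts[0][0])
    return [sum(o[c] * s / total for o, s in parts) for c in range(d)]
```

A production kernel would carry log-sum-exp values instead of raw denominators to avoid overflow; the mixing rule is the same.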
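Intuitively, local attention should be computed precisely, while global attention can be approximate. Therefore, we pair a larger segment size $w_i$ with a larger dilation rate $r_i$. Moreover, we gradually increase $w_i$ for each attention until it reaches the maximum length $N$ or the number of attention patterns reaches $k$:

$$w = \{w_0, w_1, w_2, \ldots, w_{k-1}\} \quad (w_i < w_{i+1} \le N) \tag{11}$$

$$r = \{r_0, r_1, r_2, \ldots, r_{k-1}\} \quad (r_0 = 1,\ r_i < r_{i+1}) \tag{12}$$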
In practice, we set $w$ and $r$ to geometric sequences for an exponential attentive field.
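As an illustration, assuming $w_i = w_0\alpha^i$ and $r_i = \alpha^i$ (the values $w_0 = 2048$ and $\alpha = 2$ are chosen only for this example), a hypothetical helper that builds both sequences could look like this:

```python
def geometric_config(w0, alpha, n, k):
    """Hypothetical helper: w_i = w0 * alpha**i, r_i = alpha**i, stop at N or after k patterns."""
    ws, rs = [], []
    for i in range(k):
        w = min(w0 * alpha ** i, n)  # never exceed the sequence length N
        ws.append(w)
        rs.append(alpha ** i)
        if w == n:
            break
    return ws, rs

ws, rs = geometric_config(w0=2048, alpha=2, n=32768, k=8)
# ws == [2048, 4096, 8192, 16384, 32768], rs == [1, 2, 4, 8, 16]
```

Under this choice every uncapped pattern keeps $w_i / r_i = w_0$ rows per segment, so each segment's attention stays $w_0 \times w_0$ while the span it covers grows exponentially.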
2.3 Multi-Head Dilated Attention
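As shown in Figure 3, we vary the computation across heads by sparsifying different parts of the query-key-value pairs. Specifically, for the $j$-th head, we use an offset $\delta_j = j \bmod r$ when selecting the $(Q, K, V)$:

$$\widetilde{Q}_i = [Q_{iw+\delta_j}, Q_{iw+\delta_j+r}, Q_{iw+\delta_j+2r}, \ldots, Q_{(i+1)w+\delta_j-r}]$$

$$\widetilde{K}_i = [K_{iw+\delta_j}, K_{iw+\delta_j+r}, K_{iw+\delta_j+2r}, \ldots, K_{(i+1)w+\delta_j-r}]$$

$$\widetilde{V}_i = [V_{iw+\delta_j}, V_{iw+\delta_j+r}, V_{iw+\delta_j+2r}, \ldots, V_{(i+1)w+\delta_j-r}]$$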
Following the vanilla multi-head attention, the outputs of different heads are concatenated into a final output. The rest of the computation remains the same as the single-head counterpart in Section 2.2.
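A small sketch of the per-head row selection, assuming the offset $\delta_j = j \bmod r$ and 0-based positions (the helper name is hypothetical). With at least $r$ heads, the heads jointly cover every row of a segment.

```python
def head_rows(segment, w, r, head):
    """Rows selected by one head in a segment: start at iw + (head mod r), then stride r."""
    offset = head % r
    start = segment * w + offset
    return list(range(start, (segment + 1) * w, r))

for h in range(4):
    print(h, head_rows(0, w=8, r=4, head=h))
# 0 [0, 4]
# 1 [1, 5]
# 2 [2, 6]
# 3 [3, 7]
```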
2.4 Computational Complexity and Token Dependency
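For a single dilated attention with segment size $w$ and dilation rate $r$, each of the $N/w$ segments computes attention over $w/r$ rows, so the number of flops is $\mathrm{FLOPs} = 2\frac{N}{w}\left(\frac{w}{r}\right)^2 d = \frac{2Nwd}{r^2}$ (vanilla attention corresponds to $w = N$ and $r = 1$, i.e., $2N^2 d$). We further extend this to dilated attention with multiple segment sizes and dilation rates. The flops can be written as:

$$\mathrm{FLOPs} = 2Nd \sum_{i=0}^{k-1} \frac{w_i}{r_i^2}$$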
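With the segment sizes and dilation rates in Equation 11 and Equation 12 chosen as geometric sequences $w_i = w_0 \alpha^i$ and $r_i = \alpha^i$, the flops are given by

$$\mathrm{FLOPs} = 2 w_0 N d \sum_{i=0}^{k-1} \frac{1}{\alpha^{i}} \le \frac{2\alpha}{\alpha - 1} w_0 N d$$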
where $w_0$ is a predefined constant and $\alpha > 1$ is the common ratio of the geometric sequences $w$ and $r$. Therefore, the computation complexity of dilated attention is approximately $O(Nd)$.
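The flops formula can be evaluated directly. This sketch (hypothetical helper; geometric sequences with $w_0 = 2048$, $\alpha = 2$, $k = 5$ chosen for illustration) shows that the cost per token stays constant as $N$ grows and respects the $\frac{2\alpha}{\alpha-1} w_0 N d$ bound.

```python
def mixture_flops(n, d, ws, rs):
    """FLOPs = 2 N d * sum_i w_i / r_i**2."""
    return 2 * n * d * sum(w / r ** 2 for w, r in zip(ws, rs))

w0, alpha, k, d = 2048, 2, 5, 64
ws = [w0 * alpha ** i for i in range(k)]  # 2048 ... 32768
rs = [alpha ** i for i in range(k)]       # 1 ... 16

for n in (32768, 65536, 131072):
    f = mixture_flops(n, d, ws, rs)
    bound = 2 * alpha / (alpha - 1) * w0 * n * d
    # f / n is identical for every n: linear in the sequence length
    print(n, f / n, f <= bound)
```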
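Moreover, the information of each token can be propagated to a maximum distance of $D$:

$$D = \sum_{i=0}^{l-1} w_i = w_0 \sum_{i=0}^{l-1} \alpha^{i} \approx \frac{w_0}{\alpha - 1} \alpha^{l}$$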
where $l$ is the length of the propagated path. Therefore, the maximum path length of a sequence with $N$ tokens can be estimated as:

$$L \approx \log_{\alpha} \frac{N(\alpha - 1)}{w_0} \quad (\alpha > 1)$$
This shows that the token dependency is approximately $O(\log N)$.
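To see the logarithmic growth without floating-point logarithms, the sketch below counts the smallest path length $l$ whose reach $D = \sum_{i<l} w_0\alpha^i$ covers $N$ tokens. The helper name and the example values ($w_0 = 2048$, $\alpha = 2$) are illustrative assumptions.

```python
def path_length(n, w0, alpha):
    """Smallest l with sum_{i<l} w0 * alpha**i >= n, using exact integer arithmetic."""
    reach, l = 0, 0
    while reach < n:
        reach += w0 * alpha ** l
        l += 1
    return l

for n in (2 ** 20, 2 ** 25, 2 ** 30):
    print(n, path_length(n, w0=2048, alpha=2))
```

With these values, each 32-fold increase in $N$ adds only 5 hops (10, 15, and 20 hops for $2^{20}$, $2^{25}$, and $2^{30}$ tokens).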
3 LongNet as a Distributed Trainer: Scaling up to 1B Tokens
Although the computation complexity of dilated attention has been greatly reduced to $O(Nd)$, it is infeasible to scale the sequence length to the million level on a single GPU device due to computation and memory constraints. There are some distributed training algorithms for large-scale model training, such as model parallelism, sequence parallelism, and pipeline parallelism. However, they are insufficient for LongNet, especially when the sequence dimension is extremely large.
We take advantage of the linear computation complexity of LongNet to distribute training along the sequence dimension. Without loss of generality, Figure 4 presents our distributed algorithm on two GPUs, which can be further scaled to an arbitrary number of devices. We start by splitting the input sequence along the sequence dimension and placing each part on a separate device:

$$X = [X_1, X_2]$$
Then, they are projected into queries, keys, and values on the two devices:

$$[Q_j, K_j, V_j] = [X_j W_Q,\ X_j W_K,\ X_j W_V], \qquad j = 1, 2$$
For a segment length $w \le l$ (where $l$ is the sequence length on the local device), we compute the attention locally with Equation 3 to Equation 8. For a segment length $w > l$, the keys and values are distributed across different devices, so we collect the key-value pairs before computing the attention. We use Equation 3 to Equation 5 to sparsify each device's $\{Q_j, K_j, V_j\}$ into $\{\widetilde{Q}_j, \widetilde{K}_j, \widetilde{V}_j\}$. An all-gather operation is implemented to collect the key-value pairs:

$$\widetilde{K} = [\widetilde{K}_1, \widetilde{K}_2], \qquad \widetilde{V} = [\widetilde{V}_1, \widetilde{V}_2]$$
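Note that the all-gather operation in the forward pass becomes a reduce-scatter operation in the backward pass. Unlike vanilla attention, the sizes of $\widetilde{K}_j$ and $\widetilde{V}_j$ are independent of the sequence length $N$, because the dilation rate grows with the segment length so that each sparsified segment keeps a bounded number of rows, making the communication cost constant.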
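Finally, we compute the cross-attention between the local queries $\widetilde{Q}_j$ and the global key-value pairs $\{\widetilde{K}, \widetilde{V}\}$:

$$\widetilde{O}_1 = \mathrm{softmax}(\widetilde{Q}_1 \widetilde{K}^{T})\,\widetilde{V}, \qquad \widetilde{O}_2 = \mathrm{softmax}(\widetilde{Q}_2 \widetilde{K}^{T})\,\widetilde{V}$$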
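The concatenation of the outputs across different devices becomes the final attention output, where $\hat{O}_j$ denotes $\widetilde{O}_j$ scattered back to its local positions as in Equation 7:

$$O = [\hat{O}_1, \hat{O}_2]$$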
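The two-device algorithm can be simulated on one machine to check that it reproduces single-device dilated attention for a segment spanning the whole sequence ($w = N > l$). In this sketch, Python lists stand in for devices and list concatenation stands in for the all-gather; it is non-causal, skips the projections (it takes $Q, K, V$ directly), and assumes $r$ divides the local length $l$. All names are hypothetical.

```python
import math

def attention(q, k, v):
    """Dense softmax(q k^T) v over lists of row vectors."""
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / z
                    for c in range(len(v[0]))])
    return out

def distributed_dilated_attention(q, k, v, n_devices, r):
    """Simulate the w > l branch: shard, sparsify locally, all-gather K/V, attend, concatenate."""
    n, d = len(q), len(v[0])
    l = n // n_devices
    assert n % n_devices == 0 and l % r == 0, "sketch assumes even shards and r dividing l"
    # Each "device" holds one contiguous shard and keeps every r-th local row (Eqs. 3-5).
    local = [(q[s:s + l][::r], k[s:s + l][::r], v[s:s + l][::r]) for s in range(0, n, l)]
    # All-gather: every device receives the concatenated sparsified keys and values.
    k_all = [row for _, ks, _ in local for row in ks]
    v_all = [row for _, _, vs in local for row in vs]
    out = []
    for qs, _, _ in local:
        o_sparse = attention(qs, k_all, v_all)  # local queries, global keys/values
        shard = [[0.0] * d for _ in range(l)]
        for j, row in enumerate(o_sparse):       # scatter back to local positions (Eq. 7)
            shard[j * r] = row
        out.extend(shard)                        # concatenating shards yields O
    return out
```

Each device contributes only $l/r$ rows to the all-gather, which is the quantity that stays bounded when $r$ grows with $w$.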
The distributed algorithm described above is orthogonal to other parallelisms, including data parallelism which partitions the batch dimension, model parallelism which partitions the hidden dimension, and pipeline parallelism which partitions the layers.
3.2 Scaling up to 1B Tokens
We verify the feasibility of scaling to 1B tokens with modern distributed systems. Starting from 8K, we gradually scale the sequence length until reaching the limit of GPU memory, reducing the batch size accordingly to keep the number of tokens per batch at 1 billion. Each model has up to 3 segment lengths: 2,048, the number of tokens per device, and the sequence length. We report the average forward-propagation speed over 10 runs.
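The segment-length rule above can be written as a small planning function. This is a hypothetical helper, not part of the paper's code: it deduplicates the three lengths and marks which ones fit on a single device (computed locally) and which need the all-gather path described earlier in this section.

```python
def plan_segments(seq_len, n_devices, short=2048):
    """Hypothetical helper: {2048, tokens per device, sequence length}, deduplicated."""
    per_device = seq_len // n_devices
    plan = []
    for w in sorted({short, per_device, seq_len}):
        plan.append((w, "local" if w <= per_device else "all-gather"))
    return plan

print(plan_segments(8192, 1))
print(plan_segments(2 ** 20, 64))
```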
Figure 5 reports the runtime of vanilla attention and our dilated attention. Both are implemented with the FlashAttention kernel to save memory and improve speed. It shows that dilated attention can scale up the sequence length with almost constant latency. By partitioning the sequence dimension, it can leverage distributed systems to scale the sequence length to 1 billion tokens. In contrast, vanilla attention suffers from its quadratic dependency on the sequence length, and its latency increases dramatically as the length grows. Moreover, there is no distributed algorithm for vanilla attention to break the sequence length limitation. This demonstrates the advantage of both the linear complexity and the distributed algorithm of LongNet.
4 Experiments on Language Modeling
We apply LongNet to language modeling. The backbone architecture is Magneto with xPos relative position encoding, except that we replace the standard attention with our dilated attention. We use the base-size configuration of Magneto, which has a hidden dimension of 768, 12 attention heads, and 12 decoder layers. We pre-train the model on The Stack dataset, a source code collection in over 300 programming languages. The data is preprocessed with the tiktoken tokenizer (https://github.com/openai/tiktoken) with cl100k_base encoding. The models are trained with a batch size of 0.5M tokens for 300K steps. More details regarding the hyperparameters can be found in the appendix. All experiments are conducted based on the torchscale codebase.
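For reference, the stated setup can be captured in a small configuration object. This is a hypothetical sketch, not the torchscale configuration API; in particular, reading "0.5M tokens" as $2^{19} = 524{,}288$ is an assumption. `sequences_per_batch` shows how the batch shrinks as the sequence length grows so that tokens per batch stay constant.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class PretrainConfig:
    """Base-size setup from the text; batch_tokens = 2**19 is an assumed reading of 0.5M."""
    hidden_dim: int = 768
    num_heads: int = 12
    num_layers: int = 12
    batch_tokens: int = 2 ** 19
    steps: int = 300_000

    @property
    def head_dim(self) -> int:
        return self.hidden_dim // self.num_heads

    def sequences_per_batch(self, seq_len: int) -> int:
        """Fewer, longer sequences per batch keep the token count per batch fixed."""
        return self.batch_tokens // seq_len

    @property
    def total_tokens(self) -> int:
        return self.batch_tokens * self.steps

cfg = PretrainConfig()
print(cfg.head_dim, cfg.sequences_per_batch(2048), cfg.sequences_per_batch(32768))
```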
4.2 Results
We compare LongNet with both the vanilla Transformer and sparse Transformers. The architectures differ only in their attention layers; everything else remains the same. We scale the sequence length of these models from 2K to 32K, while reducing the batch size to keep the number of tokens per batch constant. For LongNet, we use multiple segment lengths $w$ with corresponding dilation rates $r$. We implement the fixed pattern for sparse attention as in Sparse Transformer, with multiple heads attending to distinct subblocks. The block size is set to 2048. We adjust their sparse ratios to match the computation flops of LongNet so that the comparison is fair. The attention layers in vanilla Transformers are dense and fully connected, so the computation cost is much higher. Due to computation constraints, we only scale the vanilla Transformer up to a sequence length of 32K. All of our implementations of attention variants are based on FlashAttention (https://github.com/HazyResearch/flash-attention/tree/main) for training efficiency. We customize the FlashAttention kernels for both sparse attention and dilated attention.
Table 2 summarizes the results of these models on the Stack dataset, using perplexity as the evaluation metric. The models are tested with different sequence lengths, ranging from 2K to 32K. When the input is longer than the maximum length that a model supports, we apply blockwise causal attention (BCA), a state-of-the-art extrapolation method for language model inference. In addition, we remove the absolute position encoding. First, the results demonstrate that increasing the sequence length during training generally leads to a better language model. Second, extrapolating the sequence length at inference time does not work when the length is much larger than what the model supports. Finally, LongNet consistently outperforms the baseline models, demonstrating its effectiveness in language modeling.
4.3 Scaling Curves of Sequence Length
Previous work has shown that language models follow scaling laws as parameters or training tokens increase. We are interested in the performance of language models when the context length is scaled up during training. We test the losses on inputs with a mixture of lengths, from 1K to 32K, and use blockwise causal attention during inference to improve generalization across sequence lengths.
Figure 6 plots the scaling curves of sequence length for both vanilla Transformers and LongNet. We estimate the amount of compute by calculating the total flops of matrix multiplication. The results show that both vanilla Transformers and LongNet benefit from a larger context length during training. However, LongNet scales up the context length more efficiently, achieving a lower test loss with less compute. This demonstrates the advantage of longer training inputs over extrapolation. In conclusion, our experiments show that LongNet is a more efficient way to scale up the context length in language models, which we attribute to its ability to learn longer-range dependencies more effectively.
4.4 Scaling up Model Size
An important property of large language models is that the loss scales as a power law with compute. To verify whether LongNet still follows a similar scaling law, we train a series of models with sizes from 125 million to 2.7 billion parameters. The 2.7B model is trained on 300B tokens, while the others are trained on about 40B tokens. Figure 7 plots the scaling curve of LongNet with respect to compute. We compute the perplexity on the same test set and estimate the amount of compute by calculating the total flops of matrix multiplication during training. The results show that LongNet still follows the power law. This implies that the dense Transformer is not a prerequisite for scaling language models. Additionally, LongNet achieves both scalability and efficiency.
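Checking for a power law amounts to fitting a straight line in log-log space. The sketch below fits $L = a\,C^{-b}$ by ordinary least squares on $\log L$ versus $\log C$; the helper name is hypothetical, and the test uses synthetic data, not measurements from the paper.

```python
import math

def fit_power_law(compute, loss):
    """Least-squares fit of log L = log a - b log C, i.e. L = a * C**(-b)."""
    xs = [math.log(c) for c in compute]
    ys = [math.log(v) for v in loss]
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    slope = (sum((x - mx) * (y - my) for x, y in zip(xs, ys))
             / sum((x - mx) ** 2 for x in xs))
    a = math.exp(my - slope * mx)
    return a, -slope  # (coefficient a, exponent b)
```

A straight line in log-log space (a good fit with a constant exponent $b$) is what "follows the power law" means operationally.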
4.5 Long Context Prompting
Prompting is an essential way to guide language models and provide them with additional information. We conduct experiments to verify whether LongNet can benefit from a longer context window for prompting. Specifically, we reserve a prefix as the prompt and test the perplexity of its suffix. We gradually scale the length of the prompt from 2K to 32K. For a fair comparison, we keep the suffixes the same while increasing the length of the prefixes up to the maximum lengths of the models. The results on the test set are reported in Figure 7. The test loss of LongNet gradually decreases as the context window grows, showing that LongNet can fully leverage long context to improve language modeling.
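The protocol keeps the evaluated suffix fixed and varies only how much preceding text is given as the prompt. A hypothetical helper that builds such (prompt, suffix) pairs from a tokenized document might look like this; taking the prompt as the text immediately preceding the suffix is an assumption of the sketch.

```python
def prompt_suffix_pairs(tokens, suffix_len, prompt_lens):
    """Fix the suffix to the last suffix_len tokens; prepend prompts of growing length."""
    assert suffix_len > 0, "suffix must be non-empty"
    end = len(tokens) - suffix_len
    suffix = tokens[end:]
    pairs = []
    for p in prompt_lens:
        start = end - p
        if start < 0:
            raise ValueError("document too short for this prompt length")
        pairs.append((tokens[start:end], suffix))
    return pairs

pairs = prompt_suffix_pairs(list(range(100)), suffix_len=10, prompt_lens=[20, 40])
print([len(p) for p, _ in pairs])  # [20, 40]
```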
5 Conclusion and Future Work
We present LongNet, a Transformer variant that can scale the sequence length to 1 billion tokens and beyond, with no loss on shorter sequences. The core of LongNet is dilated attention, which reduces the computation complexity from quadratic to linear. LongNet can serve as a distributed trainer that parallelizes the training of a sequence across multiple GPU devices. Experiments show that LongNet outperforms strong baselines on modeling both long and short sequences. In the future, we will extend LongNet to support more tasks, e.g., multimodal large language modeling, BEiT pretraining, and genomic data modeling.
We would like to acknowledge Yuqing Xia and Jilong Xue for the early exploration of the flash attention kernel.