Retentive Network: A Successor to Transformer for Large Language Models
Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, Furu Wei
Introduction
Transformer has become the de facto architecture for large language models. It was initially proposed to overcome the sequential training issue of recurrent models. However, the training parallelism of Transformers comes at the cost of inefficient inference, because of the $O(N)$ complexity per step and the memory-bound key-value cache, which makes Transformers unfriendly to deployment. Growing sequence length increases GPU memory consumption and latency, and reduces inference speed.
Numerous efforts have continued to develop the next-generation architecture, aiming at retaining training parallelism and competitive performance as Transformers while having efficient inference. It is challenging to achieve the above goals simultaneously, i.e., the so-called “impossible triangle” as shown in Figure 2.
There have been three main strands of research. First, linearized attention approximates standard attention scores with kernels, so that autoregressive inference can be rewritten in a recurrent form. However, its modeling capability and performance are worse than those of Transformers, which hinders the method’s popularity. The second strand returns to recurrent models for efficient inference while sacrificing training parallelism. As a remedy, element-wise operators are used for acceleration; however, representation capacity and performance are harmed. The third line of research explores replacing attention with other mechanisms, such as S4 and its variants. None of the previous work breaks through the impossible triangle, resulting in no clear winner compared with Transformers.
In this work, we propose retentive networks (RetNet), achieving low-cost inference, efficient long-sequence modeling, Transformer-comparable performance, and parallel model training simultaneously. Specifically, we introduce a multi-scale retention mechanism to substitute multi-head attention, which has three computation paradigms, i.e., parallel, recurrent, and chunkwise recurrent representations. First, the parallel representation empowers training parallelism to utilize GPU devices fully. Second, the recurrent representation enables efficient inference in terms of memory and computation. The deployment cost and latency can be significantly reduced. Moreover, the implementation is greatly simplified without key-value cache tricks. Third, the chunkwise recurrent representation can perform efficient long-sequence modeling. We parallelly encode each local block for computation speed while recurrently encoding the global blocks to save GPU memory.
We conduct extensive experiments to compare RetNet with Transformer and its variants. Experimental results on language modeling show that RetNet is consistently competitive in terms of both scaling curves and in-context learning. Moreover, the inference cost of RetNet is length-invariant. For a 7B model and 8k sequence length, RetNet decodes 8.4× faster and saves 70% of memory compared with Transformers with key-value caches. During training, RetNet also achieves 25-50% memory saving and 7× acceleration over the standard Transformer, and an advantage over highly optimized FlashAttention. Besides, RetNet’s inference latency is insensitive to batch size, allowing enormous throughput. These properties make RetNet a strong successor to Transformer for large language models.
Retentive Networks
In this section, we introduce the retention mechanism that has a dual form of recurrence and parallelism. So we can train the models in a parallel way while recurrently conducting inference.
where we map $v_n$ to the state vector $s_n$, and then implement a linear transform to encode sequence information recurrently.
Next, we make the projection content-aware:
where $Q_n(\gamma e^{i\theta})^n$ and $K_m(\gamma e^{i\theta})^{-m}$ are known as xPos, i.e., a relative position embedding proposed for Transformer. We further simplify $\gamma$ as a scalar, so Equation (3) becomes:
where † is the conjugate transpose. The formulation is easily parallelizable within training instances.
In summary, we start with recurrent modeling as shown in Equation (1), and then derive its parallel formulation in Equation (4). We consider the original mapping $v(n) \mapsto o(n)$ as vectors and obtain the retention mechanism as follows.
As shown in Figure 3(a), the retention layer is defined as:
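A minimal NumPy sketch of the parallel form for a single head may help make this concrete. It assumes the queries, keys, and values are already projected (and, if a position rotation is used, already rotated), builds a causal decay mask with entries $\gamma^{n-m}$ for $n \ge m$ and zero otherwise, and computes $(QK^\top \odot D)V$. The function names, the toy shapes, and the value $\gamma = 0.9$ are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def decay_mask(n, gamma):
    """D[i, j] = gamma**(i - j) if i >= j, else 0 (causal mask with exponential decay)."""
    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]
    return np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)

def retention_parallel(Q, K, V, gamma):
    """Single-head retention in parallel form. Q, K: (n, d_k); V: (n, d_v)."""
    D = decay_mask(Q.shape[0], gamma)
    return (Q @ K.T * D) @ V

# Toy inputs (hypothetical sizes, chosen only for illustration).
rng = np.random.default_rng(0)
n, d_k, d_v = 6, 4, 8
Q = rng.normal(size=(n, d_k))
K = rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_v))
out = retention_parallel(Q, K, V, gamma=0.9)
```

Because every position is computed from one masked matrix product, all timesteps of a training sequence can be processed at once.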
The Recurrent Representation of Retention
As shown in Figure 3(b), the proposed mechanism can also be written as recurrent neural networks (RNNs), which is favorable for inference. For the $n$-th timestep, we recurrently obtain the output as:
where $Q, K, V, \gamma$ are the same as in Equation (5).
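The recurrent form can be sketched as follows, assuming a single head and the state update $S_n = \gamma S_{n-1} + K_n^\top V_n$ with output $Q_n S_n$. The sketch also includes a direct double-loop reference that sums $\gamma^{n-m}(Q_n K_m^\top)V_m$ over past positions, used only to check that the recurrence gives the same result. All names and sizes are hypothetical.

```python
import numpy as np

def retention_recurrent(Q, K, V, gamma):
    """Process one token at a time, carrying a fixed-size (d_k, d_v) state."""
    n, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))
    out = np.empty((n, d_v))
    for t in range(n):
        S = gamma * S + np.outer(K[t], V[t])  # decay the old state, add the new token
        out[t] = Q[t] @ S                     # read out with the current query
    return out

def retention_reference(Q, K, V, gamma):
    """Direct sum over all past positions; quadratic, for checking only."""
    n = Q.shape[0]
    out = np.zeros((n, V.shape[1]))
    for t in range(n):
        for m in range(t + 1):
            out[t] += gamma ** (t - m) * (Q[t] @ K[m]) * V[m]
    return out

rng = np.random.default_rng(1)
Q = rng.normal(size=(7, 4))
K = rng.normal(size=(7, 4))
V = rng.normal(size=(7, 5))
rec = retention_recurrent(Q, K, V, gamma=0.95)
ref = retention_reference(Q, K, V, gamma=0.95)
```

The state `S` has the same size at every step, which is why per-token decoding cost and memory do not grow with the number of tokens already generated.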
The Chunkwise Recurrent Representation of Retention
A hybrid form of parallel representation and recurrent representation is available to accelerate training, especially for long sequences. We divide the input sequences into chunks. Within each chunk, we follow the parallel representation (Equation (5)) to conduct computation. In contrast, cross-chunk information is passed following the recurrent representation (Equation (6)). Specifically, let $B$ denote the chunk length. We compute the retention output of the $i$-th chunk via:
where $[i]$ indicates the $i$-th chunk, i.e., $x_{[i]} = [x_{(i-1)B+1}, \ldots, x_{iB}]$.
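A chunkwise computation can be sketched under the same single-head assumptions. Inside a chunk the masked parallel product is used; across chunks a state `R` summarizes everything before the chunk. In this sketch the token at offset $j$ (counting from 0) inside a chunk reads the previous state with weight $\gamma^{j+1}$, and contributes to the next state with weight $\gamma^{b-1-j}$, where $b$ is the actual length of the chunk. These weights are derived here from the recurrence and are an assumption of the sketch, as are all function names.

```python
import numpy as np

def decay_mask(n, gamma):
    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]
    return np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)

def retention_parallel(Q, K, V, gamma):
    return (Q @ K.T * decay_mask(Q.shape[0], gamma)) @ V

def retention_chunkwise(Q, K, V, gamma, chunk_len):
    n, d_k = Q.shape
    d_v = V.shape[1]
    R = np.zeros((d_k, d_v))          # cross-chunk state, empty before the first chunk
    out = np.empty((n, d_v))
    for start in range(0, n, chunk_len):
        q = Q[start:start + chunk_len]
        k = K[start:start + chunk_len]
        v = V[start:start + chunk_len]
        b = q.shape[0]                # the last chunk may be shorter
        pos = np.arange(b)
        inner = (q @ k.T * decay_mask(b, gamma)) @ v          # within-chunk, parallel
        cross = (q @ R) * (gamma ** (pos + 1))[:, None]       # contribution of earlier chunks
        out[start:start + b] = inner + cross
        # Carry the state forward: decay the old state over b steps, add this chunk.
        R = gamma ** b * R + k.T @ (v * (gamma ** (b - 1 - pos))[:, None])
    return out

rng = np.random.default_rng(2)
Q = rng.normal(size=(10, 4))
K = rng.normal(size=(10, 4))
V = rng.normal(size=(10, 6))
full = retention_parallel(Q, K, V, gamma=0.9)
chunked = retention_chunkwise(Q, K, V, gamma=0.9, chunk_len=4)
```

Only one chunk-sized score matrix and one state matrix are live at a time, so the memory for the score matrix depends on the chunk length rather than the full sequence length.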
2.2 Gated Multi-Scale Retention
The pseudocode of retention is summarized in Figure 4.
2.3 Overall Architecture of Retention Networks
We use the parallel (Equation (5)) and chunkwise recurrent (Equation (7)) representations during the training process. The parallelization within sequences or chunks efficiently utilizes GPUs to accelerate computation. More favorably, chunkwise recurrence is especially useful for long-sequence training, which is efficient in terms of both FLOPs and memory consumption.
Inference
The recurrent representation (Equation (6)) is employed during inference, which nicely fits autoregressive decoding. The $O(1)$ complexity reduces memory and inference latency while achieving equivalent results.
2.4 Relation to and Differences from Previous Methods
Table 1 compares RetNet with previous methods from various perspectives. The comparison results echo the “impossible triangle” presented in Figure 2. Moreover, RetNet has linear memory complexity for long sequences due to the chunkwise recurrent representation. We also summarize the comparisons with specific methods as follows.
S4
Unlike Equation (2), if $Q_n$ and $K_n$ are content-unaware, the formulation can be degenerated to S4, where $O = (QK^\intercal, QAK^\intercal, \ldots, QA^{|x|-1}K^\intercal) * V$.
Linear Attention
AFT/RWKV
xPos/RoPE
Compared with relative position embedding methods proposed for Transformers, Equation (3) presents a similar formulation to xPos and RoPE.
Sub-LayerNorm
As shown in Equation (8), the retention layer uses Sub-LayerNorm to normalize outputs. Because the multi-scale modeling leads to different variances for the heads, we replace the original LayerNorm with GroupNorm.
Experiments
We conduct experiments on language modeling to evaluate RetNet. We evaluate the proposed architecture with various benchmarks, i.e., language modeling performance, and zero-/few-shot learning on downstream tasks. Moreover, for training and inference, we compare speed, memory consumption, and latency.
Language Model Training
As shown in Table 2, we train language models with various sizes (i.e., 1.3B, 2.7B, and 6.7B) from scratch. The training corpus is a curated compilation of The Pile, C4, and The Stack. We append the
3.2 Comparisons with Transformer
As shown in Figure 5, we report perplexity on the validation set for the language models based on Transformer and RetNet. We present the scaling curves with three model sizes, i.e., 1.3B, 2.7B, and 6.7B. RetNet achieves comparable results with Transformers. More importantly, the results indicate that RetNet is favorable regarding size scaling. Besides performance, the RetNet training is quite stable in our experiments. Experimental results show that RetNet is a strong competitor to Transformer for large language models. Empirically, we find that RetNet starts to outperform Transformer when the model size is larger than 2B. We also summarize the language modeling results with different context lengths in Appendix B.
Zero-Shot and Few-Shot Evaluation on Downstream Tasks
We also compare the language models on a wide range of downstream tasks. We evaluate zero-shot and 4-shot learning with the 6.7B models. As shown in Table 3, the datasets include HellaSwag (HS), BoolQ, COPA, PIQA, Winograd, Winogrande, and StoryCloze (SC). The accuracy numbers are consistent with the language modeling perplexity presented in Figure 5. RetNet achieves comparable performance with Transformer in zero-shot and in-context learning settings.
3.3 Training Cost
As shown in Table 4, we compare the training speed and memory consumption of Transformer and RetNet, where the training sequence length is 8192. We also compare with FlashAttention, which improves speed and reduces GPU memory IO by recomputation and kernel fusion. In comparison, we implement RetNet using vanilla PyTorch code, and leave kernel fusion or FlashAttention-like acceleration for future work. We use the chunkwise recurrent representation of retention as described in Equation (7). The chunk size is set to . We evaluate the results with eight Nvidia A100-80GB GPUs, because FlashAttention is highly optimized for A100. Tensor parallelism is enabled for 6.7B and 13B models.
Experimental results show that RetNet is more memory-efficient and has higher throughput than Transformers during training. Even compared with FlashAttention, RetNet is still competitive in terms of speed and memory cost. Moreover, without relying on specific kernels, it is easy to train RetNet on other platforms efficiently. For example, we train the RetNet models on an AMD MI200 cluster with decent throughput. It is notable that RetNet has the potential to further reduce cost via advanced implementation, such as kernel fusion.
3.4 Inference Cost
As shown in Figure 6, we compare memory cost, throughput, and latency of Transformer and RetNet during inference. Transformers reuse KV caches of previously decoded tokens. RetNet uses the recurrent representation as described in Equation (6). We evaluate the 6.7B model on the A100-80GB GPU in our experiments. Figure 6 shows that RetNet outperforms Transformer in terms of inference cost.
As shown in Figure 6(a), the memory cost of Transformer increases linearly due to KV caches. In contrast, the memory consumption of RetNet remains consistent even for long sequences, requiring much less GPU memory to host RetNet. The additional memory consumption of RetNet is almost negligible (i.e., about 3%) while the model weights occupy 97%.
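The contrast between a cache that grows with length and a state that does not can be illustrated with back-of-the-envelope arithmetic. The sketch below assumes a hypothetical configuration (32 layers, hidden size 4096, 16-bit values, and for the retention state 16 heads with a 256 x 512 state matrix each); none of these numbers are taken from the text above, and the result ignores activations and framework overhead.

```python
def kv_cache_bytes(n_layers, hidden, seq_len, batch=1, bytes_per_elem=2):
    """Keys and values: one hidden-sized vector each, per layer and per cached token."""
    return 2 * n_layers * hidden * seq_len * batch * bytes_per_elem

def retention_state_bytes(n_layers, n_heads, d_k, d_v, batch=1, bytes_per_elem=2):
    """One d_k x d_v state matrix per head and layer, independent of sequence length."""
    return n_layers * n_heads * d_k * d_v * batch * bytes_per_elem

GIB = 1024 ** 3
for seq_len in (2048, 4096, 8192):
    kv = kv_cache_bytes(32, 4096, seq_len)          # hypothetical Transformer config
    st = retention_state_bytes(32, 16, 256, 512)    # hypothetical retention config
    print(f"len={seq_len}: KV cache {kv / GIB:.2f} GiB, retention state {st / GIB:.3f} GiB")
```

Under these assumptions the cache doubles whenever the sequence length doubles, while the recurrent state stays the same size.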
Throughput
As presented in Figure 6(b), the throughput of Transformer drops as the decoding length increases. In comparison, RetNet has higher and length-invariant throughput during decoding, by utilizing the recurrent representation of retention.
Latency
Latency is an important metric in deployment, as it greatly affects user experience. We report decoding latency in Figure 6(c). Experimental results show that increasing the batch size increases Transformer’s latency. Moreover, the latency of Transformers grows faster with longer input. To keep latency acceptable, the batch size has to be restricted, which harms the overall inference throughput of Transformers. By contrast, RetNet’s decoding latency is lower than that of Transformers and stays almost the same across different batch sizes and input lengths.
3.5 Comparison with Transformer Variants
Apart from Transformer, we compare RetNet with various efficient Transformer variants, including Linear Transformer, RWKV, H3, and Hyena. All models have 200M parameters with 16 layers and a hidden dimension of 1024. For H3, we set the head dimension as 8. For RWKV, we use the TimeMix module to substitute self-attention layers while keeping FFN layers consistent with other models for fair comparisons. We train the models for 10k steps with a batch size of 0.5M tokens. Most hyperparameters and training corpora are kept the same as in Section 3.1.
Table 5 reports the perplexity numbers on the in-domain validation set and other out-of-domain corpora, e.g., Project Gutenberg 2019-2022 (PG22), QMSum, GovReport, and SummScreen. Overall, RetNet outperforms previous methods across different datasets. RetNet not only achieves better evaluation results on the in-domain corpus but also obtains lower perplexity on several out-of-domain datasets. The favorable performance makes RetNet a strong successor to Transformer, besides the benefits of significant cost reduction (Sections 3.3 and 3.4).
In addition, we discuss the training and inference efficiency of the compared methods. Let $d$ denote the hidden dimension, and $n$ the sequence length. For training, RWKV’s token-mixing complexity is $O(dn)$ while Hyena’s is $O(dn \log n)$ with Fast Fourier Transform acceleration. The above two methods reduce training FLOPs by employing element-wise operators, trading off modeling capacity. In comparison, the chunkwise recurrent representation of retention is $O(dn(b+h))$, where $b$ is the chunk size, $h$ is the head dimension, and we usually set . For either large model size (i.e., larger $d$) or sequence length, the additional $b+h$ has negligible effects. So RetNet training is quite efficient without sacrificing modeling performance. For inference, among the compared efficient architectures, Hyena has the same complexity (i.e., $O(n)$ per step) as Transformer, while the others can perform $O(1)$ decoding.
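One way to see why the chunkwise cost is linear in sequence length is to count multiply-adds per chunk and per head. The sketch below is a rough count, assuming $d/h$ heads, key and value widths both equal to $h$, and a sequence length that is a multiple of the chunk size; the example values for `d`, `n`, `b`, and `h` are hypothetical. Each chunk costs about $2b^2h$ for the within-chunk products and about $2bh^2$ for reading and updating the cross-chunk state, which sums to $2dn(b+h)$ overall.

```python
def chunkwise_retention_macs(d, n, b, h):
    """Rough multiply-add count for chunkwise retention (constant factors are approximate)."""
    heads = d // h
    chunks = n // b
    inner = 2 * b * b * h   # Q K^T, then (masked scores) V, inside one chunk
    cross = 2 * b * h * h   # Q R for the outputs, K^T V for the state update
    return heads * chunks * (inner + cross)

# Example values (assumptions for illustration only).
d, b, h = 1024, 512, 256
cost_8k = chunkwise_retention_macs(d, 8192, b, h)
cost_16k = chunkwise_retention_macs(d, 16384, b, h)
```

Doubling the sequence length doubles the count, since the per-chunk work is fixed and only the number of chunks changes.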
3.6 Ablation Studies
We ablate various design choices of RetNet and report the language modeling results in Table 6. The evaluation settings and metrics are the same as in Section 3.5.
Multi-Scale Decay
Equation (8) shows that we use different $\gamma$ as the decay rates for the retention heads. In the ablation studies, we examine removing $\gamma$ decay (i.e., “− $\gamma$ decay”) and applying the same decay rate across heads (i.e., “− multi-scale decay”). Specifically, ablating $\gamma$ decay is equivalent to $\gamma = 1$. In the second setting, we set the same $\gamma$ for all heads. Table 6 indicates that both the decay mechanism and using multiple decay rates can improve the language modeling performance.
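To make the idea of per-head decay rates concrete, the sketch below assumes a geometric schedule $\gamma_i = 1 - 2^{-5-i}$ for head $i$. This particular schedule is an assumption of the sketch, since the text above does not spell out the values. It also reports $1/(1-\gamma)$ as a rough "effective horizon", i.e., the sum of the decay weights $\gamma^0 + \gamma^1 + \dots$, which indicates how far back a head effectively looks.

```python
import numpy as np

def multi_scale_decay(n_heads):
    """Assumed schedule: gamma_i = 1 - 2**(-5 - i), one decay rate per head."""
    return 1.0 - 2.0 ** (-5.0 - np.arange(n_heads))

gammas = multi_scale_decay(4)
horizons = 1.0 / (1.0 - gammas)   # sum of the geometric series of decay weights

for i, (g, hz) in enumerate(zip(gammas, horizons)):
    print(f"head {i}: gamma={g:.6f}, effective horizon ~{hz:.0f} tokens")
```

Heads with a rate close to 1 keep a long memory, while heads with a smaller rate focus on recent tokens; setting every rate to 1 removes decay altogether.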
Head Dimension
From the recurrent perspective of Equation (1), the head dimension implies the memory capacity of hidden states. In the ablation study, we reduce the default head dimension from to , i.e., for queries and keys, and for values. We keep the hidden dimension the same so the number of heads increases. Experimental results in Table 6 show that the larger head dimension achieves better performance.
Conclusion
In this work, we propose retentive networks (RetNet) for sequence modeling, which enables various representations, i.e., parallel, recurrent, and chunkwise recurrent. RetNet achieves significantly better inference efficiency (in terms of memory, speed, and latency), favorable training parallelization, and competitive performance compared with Transformers. The above advantages make RetNet an ideal successor to Transformers for large language models, especially considering the deployment benefits brought by the $O(1)$ inference complexity. In the future, we would like to scale up RetNet in terms of model size and training steps. Moreover, retention can efficiently work with structured prompting by compressing long-term memory. We will also use RetNet as the backbone architecture to train multimodal large language models. In addition, we are interested in deploying RetNet models on various edge devices, such as mobile phones.
Acknowledgement
We would like to acknowledge Jiayu Ding, Songlin Yang, and colleagues from MSRA System Group for the helpful discussions.
References
Appendix A Hyperparameters
Appendix B Grouped Results of Different Context Lengths
As shown in Table 8, we report language modeling results with different context lengths. In order to make the numbers comparable, we use 2048 text chunks as evaluation data and only compute perplexity for the last 128 tokens. Experimental results show that RetNet outperforms Transformer across different context lengths. Besides, RetNet can utilize longer context for better results.
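The evaluation protocol described above can be sketched in a few lines. The sketch assumes that per-token natural-log probabilities are already available from some model, that a "context length" counts the scored tokens as part of the window, and that perplexity is the exponential of the mean negative log-probability. The helper names and the toy numbers are hypothetical.

```python
import math

def context_window(chunk, context_len):
    """Keep only the last context_len tokens of a chunk, so the scored tail sees that much context."""
    return chunk[-context_len:]

def last_k_perplexity(token_logprobs, k=128):
    """Perplexity over the final k tokens, given one natural-log probability per token."""
    tail = token_logprobs[-k:]
    return math.exp(-sum(tail) / len(tail))

# Toy example: a 2048-token chunk truncated to a 512-token context.
chunk = list(range(2048))
window = context_window(chunk, 512)

# Pretend the model assigns probability 0.25 to every token in the window.
logprobs = [math.log(0.25)] * len(window)
ppl = last_k_perplexity(logprobs, k=128)
```

Scoring the same final tokens under every context length keeps the numbers comparable, because only the amount of preceding context changes.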