HGRN2: Gated Linear RNNs with State Expansion

Zhen Qin, Songlin Yang, Weixuan Sun, Xuyang Shen, Dong Li, Weigao Sun, Yiran Zhong

Introduction

Large language models (LLMs) have achieved great empirical success in recent years. However, serving Transformer-based LLMs is costly because of expensive KV cache management. Recurrent neural networks (RNNs), on the other hand, offer linear inference complexity with constant state size, which is ideal for serving. Therefore, there is great interest in studying parallelizable linear recurrent models such as linear RNNs (Peng et al., 2023; Orvieto et al., 2023; Qin et al., 2023c; De et al., 2024), linear attention (Sun et al., 2023b; Qin et al., 2023b; Yang et al., 2023; Arora et al., 2024; Peng et al., 2024), and state space models (Gu et al., 2022a; Smith et al., 2023; Gu & Dao, 2023).

RNNs have a fixed recurrent state size to encode all historical information. Therefore, it is important for RNNs (i) to better exploit the fixed-size state and (ii) to increase the recurrent state size. Recent improvements in linear RNNs follow exactly this track, including the use of data-dependent decays and state expansion techniques.

Data-dependent decays (also known as forget gates) are of great importance for RNNs (van der Westhuizen & Lasenby, 2018): they allow the model to selectively retain useful information while erasing irrelevant information, so that the fixed-size recurrent state can be better exploited. HGRN (Qin et al., 2023c) first highlighted the importance of data-dependent decays for linear RNNs. Since then, many recent linear recurrent models have also employed data-dependent decays, such as Mamba (Gu & Dao, 2023), Gated Linear Attention (GLA; Yang et al., 2023), Griffin (De et al., 2024), and RWKV-6 (Peng et al., 2024).

However, HGRN did not increase the recurrent state size and is therefore greatly restricted by its limited memory capacity, which prevents it from achieving LLaMa-like (Touvron et al., 2023a; b) language modeling performance, as noted in Qin et al. (2024). Recent state-of-the-art linear recurrent models, such as Mamba, GLA, and RWKV-6, additionally use state expansion techniques, which greatly increase the recurrent state size and thereby enhance the memory capacity.

In this work, we propose HGRN2, with the aim of increasing the size of the recurrent state for HGRN while retaining both parameter efficiency and training efficiency. We first explore structured matrices to directly expand the state size in a parameter-efficient manner. Empirically, we found that they improve language modeling performance but still suffer from training inefficiency, so they cannot scale the recurrent state size well. Inspired by linear attention, we then explore a nonparametric outer product-based state expansion mechanism, which allows efficient scaling of the recurrent state size during training without introducing any additional parameters. Thanks to the gated linear attention form, we can borrow the hardware-efficient GLA training algorithm described in Yang et al. (2023) for large-scale experiments.

We extensively evaluate HGRN2 on language modeling, image classification, and the Long Range Arena benchmark. HGRN2 outperforms HGRN1 on all these benchmarks, showing the significant benefit of state expansion. Our largest 3B HGRN2 model slightly outperforms Mamba and a LLaMa-architecture Transformer for language modeling in a controlled experimental setting, and performs competitively with many open-source 3B models in downstream evaluation while using far fewer total training tokens.

Background

2 HGRN (Qin et al., 2023c)

Compared to Eq. 1, HGRN makes two adjustments: (i) complex-valued recurrence, and (ii) forget gates whose lower bound values increase monotonically from bottom layers to upper layers.

For (i), similarly to the findings of Gu & Dao (2023) and De et al. (2024), we empirically found that the complex-valued recurrence is not necessary, as shown in Table 1. We speculate that the reason Qin et al. (2023c) found it useful is due to state expansion: the complex-valued recurrent state is twice the size of the real-valued recurrent state. If we directly expand the real-valued recurrent state size from $d$ to $2d$, the language modeling performance on the Wikitext-103 corpus is even better. Therefore, we only consider the real-valued recurrence thereafter.

To avoid the lower bound being one in the highest layer, HGRN subtracts $\beta^{0}$ from all $\beta$, so that the lower bound for the first layer is zero. After obtaining the lower bound values, the old forget gate $\mathbf{g}_{t}$ learns the residuals instead, resulting in the new forget gate $\mathbf{f}_{t}$.
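As a rough illustration of the construction just described, the sketch below computes per-layer lower bounds and the resulting forget gate in plain Python. It assumes the formulation of HGRN (Qin et al., 2023c), in which the bounds come from a cumulative sum of a softmax taken over layers and the gate is $\mathbf{f}_{t}=\beta+(1-\beta)\odot\mathbf{g}_{t}$. The names `gamma`, `lower_bounds`, and `forget_gate` are hypothetical and are not taken from any released code.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def lower_bounds(gamma):
    """Per-layer lower bounds beta[l][j] from layer logits gamma[l][j].

    Assumption: softmax over the layer axis, then a cumulative sum, then
    subtraction of the first layer's value so that beta[0] is zero.
    """
    num_layers, dim = len(gamma), len(gamma[0])
    beta = [[0.0] * dim for _ in range(num_layers)]
    for j in range(dim):
        p = softmax([gamma[l][j] for l in range(num_layers)])
        cumsum, running = [], 0.0
        for l in range(num_layers):
            running += p[l]
            cumsum.append(running)
        for l in range(num_layers):
            # Subtract the first layer's value (beta^0 in the text).
            beta[l][j] = cumsum[l] - cumsum[0]
    return beta


def forget_gate(beta_l, g_t):
    """f_t = beta + (1 - beta) * g_t, element-wise; g_t is assumed to lie in [0, 1]."""
    return [b + (1.0 - b) * g for b, g in zip(beta_l, g_t)]
```

With this construction the bounds are non-decreasing with depth and stay strictly below one, so upper layers are biased toward retaining long-range information while the first layer remains unconstrained.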

Method

The result is shown in Table 3. We found that state expansion generally improves the performance and that the low-rank matrix performs the best among these candidates.

However, these PESE methods face a training inefficiency issue, as they need to conduct element-wise linear recurrence in high dimension (i.e., $nd$). Since these element-wise operations cannot leverage tensor cores (fast matrix multiply units on GPUs), the dramatically increased FLOPs and I/O costs significantly slow down training when $n$ is large. We notice that this is similar to the case of Mamba, which needs a relatively small expansion ratio (i.e., $n=16$) and a custom I/O-efficient CUDA implementation to achieve a reasonably fast running speed. (Though Mamba has an attention view (Ali et al., 2024) similar to that of linear attention, its attention computation cannot be written as a matrix multiply like linear attention, and thus does not facilitate tensor core-based GPU acceleration.)

In the next subsection, we explore another strategy that does not replace the dense projection matrices with structured ones, but instead changes the element-wise gating operations in Eq. 1 to matrix/vector operations similar to those in linear attention, which allows for efficient training.

2 HGRN2

Note that the complexity of recurrence increases dramatically from $O(BNd)$ to $O(BNd^{2})$ due to state expansion. Therefore, we introduce multi-head HGRN (similar to that in linear attention) such that the complexity is reduced to $O(BNd^{2}/H)$ for $H$ heads, and the state size effectively becomes $d^{2}/H$, i.e., the expansion ratio is $n=d_{h}=d/H$ (see Bolya et al. (2022) for a more detailed complexity analysis). We conduct an ablation study on the expansion ratio (head dimension) $n=\frac{d}{H}$, as shown in Figure 2. We can see that state expansion is very effective in improving language modeling performance. However, when the head dimension (i.e., state expansion ratio) is larger than 128, the additional performance gain is small. To balance computational cost and performance, we chose $d_{h}=128$ for the main experiments.
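The arithmetic behind these counts can be checked with a short sketch. It assumes that $B$ and $N$ denote batch size and sequence length, drops all constant factors, and uses a hypothetical hidden size $d=2048$ purely for illustration. The function names are likewise made up for this example.

```python
def state_size_per_layer(d, num_heads):
    """Recurrent state entries per layer with multi-head outer-product expansion.

    Each head keeps a d_h x d_h state, so the total is H * d_h^2 = d^2 / H.
    """
    if d % num_heads != 0:
        raise ValueError("d must be divisible by the number of heads")
    d_h = d // num_heads  # head dimension, equal to the expansion ratio n
    return num_heads * d_h * d_h


def recurrence_cost(batch, seq_len, d, num_heads):
    """Order-of-magnitude cost B * N * d^2 / H (constant factors ignored)."""
    return batch * seq_len * state_size_per_layer(d, num_heads)


if __name__ == "__main__":
    d = 2048  # hypothetical hidden size, not taken from the paper
    for d_h in (32, 64, 128, 256):
        heads = d // d_h
        size = state_size_per_layer(d, heads)
        # The ratio to the unexpanded state size d equals the head dimension.
        print(d_h, heads, size, size // d)
```

With a single head the state grows to $d^{2}$, while choosing $d_{h}=128$ keeps the state 128 times larger than the unexpanded size $d$, independent of $d$ itself.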

Also note that this form of recurrence is similar to that of Gated Linear Attention (Yang et al., 2023) except for the concrete parameterization. We list the correspondence between the two parameterizations in Table 4. We can see that the output gate in HGRN2 amounts to the query in GLA, while the output gate in GLA is removed in HGRN2. The key vector in GLA corresponds to the input gate in HGRN2, which is tied to the forget gate, thus saving parameters.
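To make the correspondence concrete, here is a minimal single-head sketch of a recurrence with the properties described above: the forget gate decays the rows of a matrix-valued state, the tied input gate $1-\mathbf{f}_{t}$ plays the role of the key, the input vector plays the role of the value, and the output gate reads the state like a query. This is an assumption-laden reading of the text in naive sequential Python, not the chunkwise algorithm used for training, and all function and argument names are hypothetical.

```python
def hgrn2_step(state, f_t, i_t, o_t):
    """One recurrent step for a single head.

    state: d_h x d_h nested list (matrix-valued recurrent state)
    f_t:   forget gate, values in (0, 1); the key is the tied input gate 1 - f_t
    i_t:   input vector (plays the role of the value)
    o_t:   output gate (plays the role of the query)
    """
    d_h = len(f_t)
    # S_t = diag(f_t) S_{t-1} + outer(1 - f_t, i_t)
    new_state = [
        [f_t[k] * state[k][j] + (1.0 - f_t[k]) * i_t[j] for j in range(d_h)]
        for k in range(d_h)
    ]
    # y_t = o_t^T S_t
    y_t = [sum(o_t[k] * new_state[k][j] for k in range(d_h)) for j in range(d_h)]
    return new_state, y_t


def hgrn2_recurrent(forgets, inputs, out_gates):
    """Run the recurrence over a sequence, starting from a zero state."""
    d_h = len(forgets[0])
    state = [[0.0] * d_h for _ in range(d_h)]
    outputs = []
    for f_t, i_t, o_t in zip(forgets, inputs, out_gates):
        state, y_t = hgrn2_step(state, f_t, i_t, o_t)
        outputs.append(y_t)
    return outputs
```

Because the key and the forget gate sum to one in every coordinate, each row of the state is a running convex combination of past input vectors, which is the gated-RNN intuition that the linear attention view does not make obvious.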

Thanks to the similar computation structure to GLA, we can directly leverage their chunkwise algorithm and their highly optimized kernels (https://github.com/sustcsonglin/flash-linear-attention) for hardware-efficient large-scale training. We refer the reader to their paper for more details.

Although HGRN2 shares many similarities with GLA, we believe that HGRN2 offers a unique perspective that differs from linear attention, since it starts from the gated linear RNN approach. For example, it is not immediately apparent from the perspective of linear attention why key vectors should be constrained to the range (0, 1), as in Schlag et al. (2021), or why the key vector and forget gate values should sum to one. Both become quite intuitive when one starts from the gated linear RNN framework and aims to explore state expansion.

Experiments

Multi-Query Associative Recall (MQAR) (Arora et al., 2023) is an enhanced version of the synthetic induction head dataset (Fu et al., 2023) that tests the in-context associative recall ability of subquadratic models. Arora et al. (2023) found strong correlations between MQAR accuracy and language modeling performance. Our experimental setting strictly follows the original paper (https://github.com/HazyResearch/zoology); other technical details can also be found there. Our hyperparameter sweep ranges are: expansion ratio $\in\{64,128\}$ and learning rate $\in\{10^{-5},5\times 10^{-5},10^{-4},5\times 10^{-4},10^{-3},5\times 10^{-3},10^{-2}\}$.
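For reference, the sweep described above amounts to a small grid. The sketch below simply enumerates it; the dictionary keys and the function name are hypothetical and do not correspond to the configuration format of the zoology repository.

```python
from itertools import product

# Values taken from the sweep ranges stated in the text.
EXPAND_RATIOS = [64, 128]
LEARNING_RATES = [1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2]


def sweep_configs():
    """Enumerate every (expansion ratio, learning rate) pair in the grid."""
    return [
        {"expand_ratio": n, "lr": lr}
        for n, lr in product(EXPAND_RATIOS, LEARNING_RATES)
    ]


if __name__ == "__main__":
    for cfg in sweep_configs():
        print(cfg)
```

The grid contains 2 x 7 = 14 runs per model dimension.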

We can see from Fig. 3 that HGRN2 significantly outperforms HGRN1 across various model dimensions, showing the benefit of using state expansion.

2 Language modeling

For the Wikitext-103 experiment, we followed the configuration of HGRN1 to validate the performance of 44M models against a wide range of subquadratic models: FLASH (Hua et al., 2022), 1+elu (Katharopoulos et al., 2020), Performer (Choromanski et al., 2021), cosFormer (Qin et al., 2022b), Syn(D), Syn(R) (Tay et al., 2021a), gMLP (Liu et al., 2021), S4 (Gu et al., 2022a), DSS (Gupta et al., 2022b), RWKV-v4 (Peng et al., 2023), LRU (Orvieto et al., 2023), HGRN1 (Qin et al., 2023c), and TNN (Qin et al., 2023a).

Table 5 shows the result. We can see that HGRN2 achieves the lowest perplexity among all compared subquadratic models. In particular, HGRN2 clearly outperforms HGRN1 while using even fewer parameters.

2.2 The Pile

We conduct a controlled experiment in a relatively large-scale setting on the Pile (Gao et al., 2020) to compare the LLaMa architecture (i.e., Transformer++), Mamba (Gu & Dao, 2023), HGRN1, and HGRN2. Due to limited computational resources, we only train the 1B/3B models for 30B tokens.

The training curve is shown in Fig. 4. We can see that, for 1B models, HGRN2 slightly underperforms Mamba. For 3B models, HGRN2 outperforms both LLaMa and Mamba, indicating the potential to further scale up HGRN2.

2.3 Downstream evaluation

We train 150M/350M/1B/3B HGRN2 models for 100B tokens sampled from subsets of the Pile, C4, and Wikipedia datasets and evaluate them on commonsense reasoning tasks. Table 6 shows the result. HGRN2 outperforms HGRN1 in almost all benchmarks and competes strongly against other models.

3 LRA

Long Range Arena (Tay et al., 2021b) is a benchmark for assessing a model's long-dependency modeling ability. We use HGRN1's setting and compare with the existing methods listed below.

Transformer (Tay et al., 2021b), Cosformer (Qin et al., 2022c), FLASH (Hua et al., 2022), S4 (Gu et al., 2022b), DSS (Gupta et al., 2022a), TNN (Qin et al., 2023a), S5 (Smith et al., 2023), Mega (Ma et al., 2022), SGConv (Li et al., 2022), LRU (Orvieto et al., 2023), Mamba (Gu & Dao, 2023), Griffin (De et al., 2024).

Table 7 shows the result. We can see that HGRN2 outperforms HGRN1 and is competitive with other state-of-the-art models, while Mamba and Griffin fail to achieve high accuracy, indicating the importance of the lower bound, as ablated in HGRN1.

4 Image modeling

For the image classification task, we follow the configuration of HGRN1, train on ImageNet-1k, and compare against TNN and the vanilla Transformer.

Table 8 shows the result. We can see that HGRN2 outperforms HGRN1 at a similar parameter count, while also having an advantage over the previous TNN (Qin et al., 2023a) and DeiT (Touvron et al., 2021) models.

Related work

Linear recurrent models mainly include linear RNNs, state-space models, and linear attention. State-space models (SSMs) have gained great attention since the seminal work S4 (Gu et al., 2022a) and its more efficient diagonalized version (Gu et al., 2022c). Despite excellent performance on the LRA benchmark, they have been shown to have inferior performance in language modeling. Gating mechanisms have been shown to be crucial in improving SSMs' language modeling performance (Mehta et al., 2023; Wang et al., 2022; Gu & Dao, 2023). Gupta et al. (2022c) build the connection between SSMs and linear RNNs. Orvieto et al. (2023) propose a linear RNN layer (i.e., LRU) inspired by SSMs. Peng et al. (2023) successfully scale linear RNN models to billions of parameters for the first time.

The language modeling performance of linear attention models has long lagged behind that of softmax attention. Several improvements have been proposed to bridge the performance gap: (i) incorporating a forgetting mechanism (Peng et al., 2021; Schlag et al., 2021; Sun et al., 2023a; Qin et al., 2023b; Yang et al., 2023), (ii) using local attention (Qin et al., 2022a; Zhang et al., 2023; Arora et al., 2024), and (iii) using higher-order polynomial feature maps (Arora et al., 2024; Kacham et al., 2023) to make the resulting attention distribution sharper (Zhang et al., 2024).

Martin & Cundy (2018) first proposed a minimal gated linear recurrent layer and showed how to use the parallel scan algorithm to train linear RNNs in parallel over the sequence dimension. Qin et al. (2023c) is largely based on this work with several adaptations and highlights the importance of data-dependent decay. De et al. (2024) build their model on LRU (Orvieto et al., 2023) and replace data-independent decays with data-dependent ones. They further use sliding-window attention to boost performance. These models are limited in recurrent state size.
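The reason a scan applies to linear recurrences can be sketched briefly: each step $h_{t}=a_{t}h_{t-1}+b_{t}$ is an affine map, and composing affine maps is associative. The code below is an illustrative scalar version in plain Python and makes no claim to reproduce the implementation of Martin & Cundy (2018). For brevity it uses a simple Hillis-Steele style scan rather than a work-efficient variant, and it only simulates the rounds sequentially. All names are hypothetical.

```python
def combine(left, right):
    """Compose two affine maps h -> a*h + b, with `left` applied first."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(a, b):
    """Reference recurrence h_t = a_t * h_{t-1} + b_t with h_0 = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def parallel_style_scan(a, b):
    """Inclusive scan in ceil(log2 N) rounds.

    Within a round, every position depends only on values from the previous
    round, so on parallel hardware all positions could be updated at once.
    """
    elems = list(zip(a, b))
    n = len(elems)
    shift = 1
    while shift < n:
        elems = [
            combine(elems[t - shift], elems[t]) if t >= shift else elems[t]
            for t in range(n)
        ]
        shift *= 2
    # With h_0 = 0, the hidden state is the offset of the composed map.
    return [b_t for _, b_t in elems]
```

Both functions return the same hidden states up to floating-point rounding. The scan version needs only a logarithmic number of dependent rounds, which is what enables sequence-level parallel training.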

Gated recurrent models with matrix-valued recurrent state have been investigated in the literature on Neural Turing Machines (NTM; Graves et al., 2014) and linear Transformers (Katharopoulos et al., 2020). In NTM, the number of memory slots can be regarded as the state expansion ratio discussed in this work. NTM also includes data-dependent decays in the form of erase vectors. However, NTM is hard to parallelize and thus slow to train in practice. ABC (Peng et al., 2022) could be considered a simplified and parallelizable version of NTM. The linear Transformer is known to have a recurrent form (Katharopoulos et al., 2020) and to be closely related to fast weight programming (FWP; Schlag et al., 2021). Gated FWPs have been investigated since Schlag & Schmidhuber (2017) and Zhang & Zhou (2017), and have recently been revisited in Peng et al. (2021), Mao (2022), Yang et al. (2023), Katsch (2023), and Pramanik et al. (2023). In particular, Yang et al. (2023) proposed a hardware-efficient training algorithm for these types of models.

Conclusion

In this work, we proposed HGRN2, which improves HGRN (Qin et al., 2023c) using an outer product-based state expansion mechanism inspired by linear attention, which allows for hardware-efficient training. Experiments on multiple tasks validate the advantages of HGRN2 over HGRN1. Large-scale language modeling experiments show that HGRN2 is competitive with other state-of-the-art models.

References

Appendix A Appendix

Table 9 provides the detailed experiment configurations for both the Auto-regressive Language Modeling (ALM) and ImageNet (IM) experiments, which use the WikiText-103 and ImageNet-1k datasets, respectively. ALM experiments use Byte Pair Encoding (BPE) with a vocabulary size of 50,265 and a sequence length of 512, with a total batch size of 128 and 50,000 updates. ImageNet experiments distinguish between 6 million and 23 million parameter models, with total batch sizes of 1024 and 2048, both running for 300 epochs but with different warm-up periods. The optimizer is Adam for ALM and AdamW for IM, with learning rate schedulers and hyperparameters tailored to each model's scale. Additional configurations cover models from 0.15 to 2.9 million parameters, varying the number of layers, hidden dimensions, and GPUs used, so that performance can be compared across scales and setups.

A.2 Loss curve of HGRN2

Fig. 5 shows the training loss curves for HGRN2 models of different sizes (150M, 385M, and 1B). As the number of parameters increases, the model's performance improves, with the 1B model consistently outperforming the others.