Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models

Soham De, Samuel L. Smith, Anushan Fernando, Aleksandar Botev, George Cristian-Muraru, Albert Gu, Ruba Haroun, Leonard Berrada, Yutian Chen, Srivatsan Srinivasan, Guillaume Desjardins, Arnaud Doucet, David Budden, Yee Whye Teh, Razvan Pascanu, Nando De Freitas, Caglar Gulcehre

Introduction

Recurrent neural networks (RNNs) played a central role in the early days of deep learning and NLP research (Elman, 1990; Siegelmann and Sontag, 1991; Hochreiter and Schmidhuber, 1997; Mikolov et al., 2010; Bahdanau et al., 2014; Sutskever et al., 2014), and achieved practical success in many applications, including Google’s first end to end machine translation system (Wu et al., 2016). However in recent years, both deep learning and NLP have been dominated by the Transformer architecture (Vaswani et al., 2017), which interleaves multi-layer perceptrons (MLPs) and multi-head attention (MHA). Transformers achieve better performance than RNNs in practice and are also very efficient at utilizing modern hardware (Kaplan et al., 2020). Transformer-based large language models trained on massive datasets collected from the web have achieved remarkable success (Brown et al., 2020; Rae et al., 2021; Hoffmann et al., 2022; Touvron et al., 2023; Achiam et al., 2023; Gemini Team Google, 2023).

Despite their successes, Transformers are difficult to scale efficiently to long sequences due to the quadratic complexity of global attention. Additionally, the linear growth of the Key-Value (KV) cache with the sequence length makes Transformers slow during inference. Although Multi-Query Attention (MQA) (Shazeer, 2019) partially mitigates this issue by reducing the cache size by a constant factor, the cache still grows linearly in sequence length. Recurrent language models present a compelling alternative as they compress the entire sequence into a fixed-sized hidden state which is updated iteratively. However to replace Transformers, new RNN models must demonstrate not only comparable performance at scale but also achieve similar hardware efficiency (Gu et al., 2021a; Mehta et al., 2022; Smith et al., 2022; Orvieto et al., 2023b; Dao et al., 2022b; Poli et al., 2023; Gu and Dao, 2023).

In this work, we propose the RG-LRU layer, a novel gated linear recurrent layer, around which we design a new recurrent block to replace MQA. We build two new models using this recurrent block: Hawk, a model which interleaves MLPs with recurrent blocks, and Griffin, a hybrid model which interleaves MLPs with a mixture of recurrent blocks and local attention (Beltagy et al., 2020). We show that:

Hawk and Griffin exhibit power law scaling between held-out loss and training FLOPs, up to and beyond 7B parameters (Figure 1(a)), as previously observed for Transformers (Kaplan et al., 2020).

Griffin achieves slightly lower held-out loss than strong Transformer baselines at all model scales.

We overtrain Hawk and Griffin on 300B tokens at a range of model scales. Hawk-3B exceeds the reported performance of Mamba-3B (Gu and Dao, 2023) on downstream tasks, despite being trained on half as many tokens. Griffin-7B and Griffin-14B match the performance of Llama-2 (Touvron et al., 2023) despite being trained on roughly 7 times fewer tokens (Section 3.2).

Both Hawk and Griffin achieve comparable training efficiency to Transformers on TPU-v3. Since diagonal RNN layers are memory bound, we achieve this with a kernel for the RG-LRU layer, implemented in Pallas (Bradbury et al., 2018), that minimizes memory transfers (Section 4).

During inference, both Hawk and Griffin achieve significantly higher throughput than MQA Transformers (Figure 1(b)), and they achieve lower latency when sampling long sequences (Section 5).

Griffin performs better than Transformers when evaluated on sequences longer than those seen during training, and can also efficiently learn copying and retrieval tasks from training data (Section 6). However, Hawk and Griffin perform less well than Transformers when we evaluate pre-trained models on copying and exact-retrieval tasks without fine-tuning.

Model Architecture

All our models contain the following components: (i) a residual block, (ii) an MLP block, and (iii) a temporal-mixing block. While (i) and (ii) are the same across all models, we consider three temporal mixing blocks: global Multi-Query Attention (MQA), local (sliding-window) MQA and our proposed recurrent block. As part of the recurrent block we use the Real-Gated Linear Recurrent Unit (RG-LRU) – a novel recurrent layer inspired by the Linear Recurrent Unit (Orvieto et al., 2023b).

The residual block, as shown in Figure 2(a), defines the global structure of our models and is inspired by pre-norm Transformers (Xiong et al., 2020). After embedding the input sequence we pass it through NN such blocks (NN denoting the model depth), and then we apply RMSNorm (Zhang and Sennrich, 2019) to produce the final activations. To compute the token probabilities we apply a final linear layer followed by a softmax. The weights of this layer are shared with the input embedding layer.

The residual block contains two components, applied in order. The first component takes the hidden state xx and applies an RMSNorm (Zhang and Sennrich, 2019), followed by the temporal-mixing block. We then merge the output with a skip connection from xx through addition. Similarly, the second component applies RMSNorm, followed by the MLP block and then merges its output with a skip connection from the input of the RMSNorm. This block is illustrated in Figure 2 (a).

2 MLP block

We use a gated MLP block (Dauphin et al., 2017) (illustrated in Figure 2(b)), which creates two branches from its input of dimension DD. We apply a linear layer with output dimension MDMD on each branch, where MM denotes the expansion factor. For simplicity, we use M=3M=3 throughout this work. We apply a GeLU non-linearity (Hendrycks and Gimpel, 2016) on one of the branches before merging them by element-wise multiplication, similar to GeGeLU (Shazeer, 2020). However, in our MLP block, we apply a final linear layer with output dimension DD on the outputs of the GeGeLU layer.

3 Temporal-mixing blocks

The temporal-mixing block is the component of our model that aggregates hidden layer activations at different temporal locations in the sequence. We consider three temporal-mixing blocks: global MQA (Shazeer, 2019), local MQA (Beltagy et al., 2020) and our proposed Recurrent block.

Unless otherwise stated, we use MQA rather than MHA to improve the inference speeds of our Transformer baselines (Shazeer, 2019). We use a fixed head dimension Dhead=128D_{head}=128, and we fix the number of attention heads HH such that HDhead=DHD_{head}=D. This requires the model dimension DD to be a multiple of 128. We do not use any absolute positional embeddings, but we use Rotary Position Embedding (RoPE) (Su et al., 2021) as a relative positional embedding.

One of the key disadvantages of using global attention is that its computational complexity grows quadratically in the sequence length. To address this, several works have started to adopt local attention (Beltagy et al., 2020), also known as sliding window attention. It allows each position to attend only to a fixed number of tokens in the past. This not only reduces the computational FLOPs but also bounds the size of the KV cache to the size of window, making it no longer quadratic in the sequence length. All other details are the same as the global MQA.

Our recurrent block (Figure 2(c)) is similar to the GSS block (Mehta et al., 2022) and the block used by Mamba (Gu and Dao, 2023). We take the input of dimension DD and apply two linear layers with output dimension DRNND_{RNN} in parallel, creating two branches. On the first branch, we apply a small separable Conv1D layer, inspired by the Shift-SSM in H3 (Dao et al., 2022b), with a temporal filter dimension of 4. Note that this Conv1D layer is very small, with just 4DRNN4D_{RNN} parameters. We follow the Conv1D layer with our proposed RG-LRU layer (defined below.) On the second branch we apply a GeLU nonlinearity and then merge the branches by element-wise multiplication. We then apply a final linear layer with output dimension DD.

4 Real-Gated Linear Recurrent Unit (RG-LRU)

Our proposed RG-LRU layer has a simple recurrence inspired by the Linear Recurrent Unit (LRU) (Orvieto et al., 2023b), but incorporates a gating mechanism motivated by the literature on non-linear RNNs, in particular LSTMs (Hochreiter and Schmidhuber, 1997) and GRUs (Chung et al., 2014). The equations describing the layer are as follows:

The output of the layer is yt=hty_{t}=h_{t}, and the non-linearity σ\sigma in the equations is the sigmoid\operatorname{sigmoid} function. The recurrent weight aa in Equation (4) is diagonal. Hence all operations are element-wise. We parameterize aa in Equation (3) as a=σ(Λ)a=\sigma(\Lambda), where Λ\Lambda is a learnable parameter. This guarantees that 0a10\leq a\leq 1, ensuring that the recurrence is stable. The variable cc is a scalar-valued constant set to 88. For numerical stability, in practice we compute acrta^{cr_{t}} in log-space (see Appendix A). The layer has gates on both the input xx and the recurrent weight aa. However, neither gate depends on the recurrent state ht1h_{t-1}, which ensures that the computation can be executed efficiently on device. We initialize both WaW_{a} and WxW_{x} using LeCun init (LeCun et al., 2002). We initialize Λ\Lambda such that aca^{c} is uniformly distributed between 0.90.9 and 0.9990.999 at the start of training, similar to Orvieto et al. (2023b). Unlike many recent works in the SSM literature, the RG-LRU does not use initialization inspired by the theory of orthogonal polynomials (Gu et al., 2020), and it also is not defined as the discretization of an underlying continuous system (Gu et al., 2021a). Unlike the original LRU layer, we do not use complex algebra in the recurrence. While using complex recurrences would lead to a more expressive layer (Orvieto et al., 2023a) we found that complex recurrences were not beneficial for language modelling in practice, as also observed by Gu and Dao (2023).We suggest ablating the use of complex numbers for other modalities and provide more information about the complex-valued version of the RG-LRU layer in Appendix B.

The input gate iti_{t} is similar to the one in LSTM, which can filter (or scale down) the input xtx_{t}. However, to our knowledge, our recurrence gate rtr_{t} is different from other gating mechanisms in the literature. For example, the selection mechanism proposed in Mamba (Gu and Dao, 2023) is comparable to the update gate of GRUs which interpolates between the previous state and and the current observation xtx_{t}. Its effect on the hidden state allows it to reset its state and forget any information it holds from the past, similar to the forget gate in the LSTM. In contrast, our recurrence gate can approximately interpolate between the standard LRU update from Orvieto et al. (2023a) and the previous hidden state, which allows it to effectively discard the input and preserve all information from the previous history (see Appendix A for further details). We believe the key role of this gate is to enable the model to achieve super-exponential memory by reducing the influence of uninformative inputs.

Recurrent Models Scale as Efficiently as Transformers

Scaling studies provide important insights into how to tune the hyperparameters of a model and its behaviour at scale. Here, we define the models evaluated in our studies, and provide scaling curves up to and beyond 7B parameters. Finally, we assess the performance of our models on downstream tasks. We consider 3 model families in this work; (1) a MQA-Transformer baseline, (2) Hawk; our pure RNN model, and (3) Griffin; our hybrid model which mixes recurrent blocks with local attention. We define the key model hyper-parameters for models across a range of scales in Appendix C.

Our Transformer baseline uses the residual pattern and the gated MLP block described in Section 2, in combination with MQA (Shazeer, 2019) and RoPE (Su et al., 2021).

The Hawk architecture uses the same residual pattern and MLP block as our Transformer baseline, but we use the recurrent block introduced in Section 2.3 with a RG-LRU layer (see Section 2.4) as our temporal mixing block, instead of MQA. We expand the width of the recurrent block by a factor of approximately 43\frac{4}{3} (i.e. DRNN4D/3D_{RNN}\approx 4D/3) in order to roughly match the parameter count of a MHA block when both use the same model dimension DD.Note that we match parameters with MHA attention block, though our Transformer baseline and Griffin ended up relying on MQA attention in order to improve inference efficiency. This means that our recurrent blocks have slightly more parameters than the corresponding MQA blocks. See Appendix C for precise hyper-parameters.

The key advantage of recurrent blocks over global attention is that they use a fixed state size to summarize the sequence, whereas the size of MQA’s KV cache grows proportional to sequence length. Since local attention (Section 2.3) has the same property, mixing recurrent blocks with local attention preserves this benefit. We have found this combination extremely effective, since local attention accurately models the recent past, while the recurrent layers can transmit information across long sequences.

Griffin uses the same residual pattern and MLP block as our Transformer baseline. However unlike both our MQA Transformer baseline and the Hawk model, Griffin uses a mixture of recurrent blocks and MQA blocks. Specifically, we employ a layered structure by alternating two residual blocks with a recurrent block followed by one residual block which uses the local (MQA) attention block described in Section 2.3. Unless otherwise stated, the local attention window size is fixed to 1024 tokens.

1 Scaling curves

We present our main scaling results in Figure 1(a). All three model families are trained at a range of model scales from 100M to 7B parameters, with an additional Griffin model with 14 billion parameters. We increase the number of training tokens to be roughly proportional to the number of parameters of the model, as prescribed by the Chinchilla scaling laws (Hoffmann et al., 2022). Models are trained on the MassiveText dataset (Hoffmann et al., 2022), previously used to train Gopher (Rae et al., 2021) and Chinchilla (Hoffmann et al., 2022), although we use a slightly different data subset distribution. A sequence length of 2048 tokens was used (see Section 6 for results with longer sequences.) All experiments use the AdamW optimizer (Loshchilov and Hutter, 2017). We tune the learning rate, weight decay and β2\beta_{2} parameters for small models, and use these runs to identify scaling rules for these hyper-parameters which predict their optimal values for the 7B and 14B models.

All three model families demonstrate a linear scaling relationship between the validation loss and training FLOPs (see Figure 1(a); note both axes are in log scale), as previously observed for Transformers by Brown et al. (2020). Notably, Griffin achieves lower validation loss than the Transformer baseline across all FLOPs budgets despite not using any global attention layers. Hawk on the other hand achieves slightly higher validation loss, but this gap appears to close as the training budget increases.

2 Evaluation on downstream tasks

In order to compare to other models in the literature, we train all our models for 300B tokens before evaluating on downstream tasks. The two external baselines that we compare to are Mamba-3B (Gu and Dao, 2023), the strongest small recurrent model reported in the literature to date, and Llama-2 (Touvron et al., 2023), a widely used open Transformer model. Both external baselines have been trained on significantly more than 300B tokens – Mamba has been trained on 600B tokens, twice more, and Llama-2 has been trained on 2T tokens, nearly seven times more. We note however that both Mamba and Llama-2 were trained on different datasets and with different hyper-parameter tuning strategies, which may partially explain our strong performance. We therefore also include our own MQA transformer baseline, trained on the same data and with the same hyper-parameter tuning budget as Hawk and Griffin.

We provide an evaluation on downstream tasks in Table 1. We find that both Hawk and Griffin achieve very strong performance. In line with other works, we report character normalized accuracy on MMLU, HellaSwag, PIQA, ARC-E and ARC-C, while we report absolute accuracy on WinoGrande with partial scoring. The performance of Hawk improves significantly as we increase the model size, and Hawk-3B achieves stronger performance on downstream tasks than Mamba-3B, despite being trained on half as many tokens. Griffin-3B significantly outperforms Mamba-3B, and Griffin-7B and Griffin-14B achieve performance competitive with Llama-2, despite being trained on nearly 7 times fewer tokens. Hawk is also competitive with our MQA Transformer baseline, while Griffin outperforms this baseline.

Training Recurrent Models Efficiently on Device

We encountered two main engineering challenges when developing and scaling our models. First, how to efficiently shard our models across multiple devices. Second, how to efficiently implement linear recurrences to maximize training efficiency on TPUs. We address both of these challenges in this section, before providing an empirical comparison of the training speed of Griffin and our MQA baseline.

As our model increases in size, we cannot fit the model on a single device during training, even with a batch size of 1 per-device. We therefore use model parallelism to shard our large models across devices during training. Since communication costs across different training devices are expensive, efficiently sharding the model is critical for fast training at scale.

For our gated-MLP block we use Megatron-style sharding (Shoeybi et al., 2019), which requires a single all-reduce operation in both the forward and the backward pass. Similarly, we apply the same strategy to the linear layers in the attention block, and additionally shard the attention mechanism over its heads (Narayanan et al., 2021).

The recurrent block contains two linear layers per branch. This allows us to apply Megatron sharding to these layers in an equivalent fashion. The Conv1D layer operates independently across channels, enabling us to split its parameters across devices without incurring any communication overhead. To avoid additional cross-device communication, we use block-diagonal weights for the gates in the RG-LRU (see equations 1 and 2), instead of dense matrices. For all experiments in this paper, we use 16 blocks for both the recurrence gate and the input gate (such that WxW_{x} and WaW_{a} each have DRNN2/16D_{RNN}^{2}/16 parameters). The diagonal structure of the recurrence offers the same advantage as the Conv1D, allowing parameter sharding and computation without any communication. With this strategy, the recurrent block’s communication requirements are equivalent to those of the MLP block.

Optimizer states can consume significant memory, exceeding the size of the model parameters themselves. To address this, we employ ZeRO parallelism (Rajbhandari et al., 2020), distributing both optimizer states and model parameters across the batch shards. We also use bfloat16 representation for model parameters and activations, minimizing any data transfer overhead.

2 Efficient linear recurrences on device

Current deep learning accelerators are optimized for classical architectures which are composed largely of matrix multiplications and convolutions. These operations have a high FLOPs-to-byte ratio, motivating the development of specialized hardware units like Nvidia GPUs’ TensorCores (Markidis et al., 2018) and Google TPUs’ MXUs (Norrie et al., 2021; Jouppi et al., 2021, 2023). Classical RNNs also benefit from this due to their dense recurrence matrices. In contrast, our proposed RG-LRU layer, like other diagonal RNN models, has a low FLOPs-to-byte ratio. This fundamental difference poses a computational challenge, as existing accelerators lack optimization for such workloads. Since we run all our experiments on TPU-v3, we focus on developing an efficient implementation tailored to this deviceThe conclusions drawn here do not necessarily apply to other accelerators..

One of the main challenges of utilizing a device like the TPU-v3 for the RG-LRU is that the update equation of the hidden state in eq. (4) is a pure elementwise operation. For each element update it requires to load 6 bytes (assuming bfloat16 we need 2 bytes for each of the variables ht1,at,xth_{t-1},a_{t},x_{t}) and write 2 bytes (the hidden state hth_{t}) while the computation only executes 6 FLOPs (number of arithmetic operations in eq. 4) per element. This translates to a low FLOPs-to-byte ratio of 0.750.75 – significantly below the device’s capacity for elementwise operations of 4.24.2 (see Appendix 3). Execution time is therefore dominated by memory transfers between HBM and VMEM, making the computation memory bound.

To address this we have written a custom Pallas kernel for the computation of eq. (4) using a linear scan. This allows us to minimize memory transfers, by keeping the hidden state in VMEM all the time, and also to perform memory transfers in larger chunks rather than one at a time. In practice, this translates to almost 3x speed up over the native Jax implementation of the linear scan. Additionally, we observe 10-20% lower training times per step of the full Hawk model, relative to the same model using the native Jax implementation (see Appendix D.2 for more details.)

The initial appeal of linear recurrence models stemmed from their high parallelizability, enabled by the associativity of their computations. This permitted efficient execution on device via convolutions (Gu et al., 2021b) or prefix-sum algorithms (the associative scan) (Smith et al., 2022). However, the RG-LRU’s gating mechanism on ata_{t} is not compatible with the convolutional view. Although we can still use the associative scan in principle, the associative scan reduces the number of FLOPs required but does not reduce memory overheads, which is our primary bottleneck in practice. Empirically we observe that on a TPU-v3 the associative scan is significantly slower that the native Jax linear scan (see Appendix D.2 for more details.) We speculate that the random access nature of the tree recombination of the parallel prefix-sum algorithm makes is poorly suited for the TPU architecture, leading to even slower memory transfers – the main bottleneck of this operation.

3 Training speed on longer sequences

We compare the training speeds across different model sizes and sequence lengths to investigate the computational advantages of our models during training. For each model size, we keep the total number of tokens per batch fixed, meaning that as we increase the sequence length, we proportionally decrease the number of sequences. In Figure 3, we plot the relative runtimes of our Griffin model compared to that of the MQA baseline at 2048 sequence length. At the lowest sequence length, the two models have similar training time, but as we increase the sequence length the Transformer becomes slower, while Griffin’s runtime remains the same. The drop in speed for the baseline is more pronounced at smaller model sizes and decreases at larger model sizes. This can be explained by the fact that all models contain a large number of linear layers. Their computation scales O(TD2)O(TD^{2}), while the RG-LRU is O(TD)O(TD) vs O(T2D)O(T^{2}D) of global attention. This means that as we increase the model width DD compared to the sequence length TT, the linear layers become the primary computational bottleneck, minimizing the efficiency gains from the RNN block. Therefore, replacing Transformers with Hawk or Griffin offers the most significant wall-clock time improvement when sequence length is sufficiently large relative to model width to ensure the attention computation constitutes a major portion of the total computation time. We also note that in practice, our MQA baseline has slightly fewer parameters than Griffin at the same model scale (and performs fewer FLOPs). This explains why Griffin trains slightly slower than our MQA baseline at 7B for short sequences.

Inference Speed

Inference in LLMs is composed of two stages. In the “prefill” stage, we receive and process the prompt. This step is effectively performing a forward pass of the model. Since the prompt can be processed in parallel across the sequence, most model operations are compute bound during this stage. We therefore expect the relative speeds of Transformers and recurrent models during the prefill stage to be similar to the relative speeds of the same models during training, which we discussed in Section 4.

Prefill is followed by a “decode” stage, in which we sample tokens auto-regressively from the model. As we show below, recurrent models have lower latency and higher throughput during the decoding stage, especially for longer sequence lengths where the key-value (KV) cache used in attention can get large.

There are two main metrics to consider when evaluating inference speed. The first is latency, which measures the time taken to generate a specified number of tokens at a certain batch size. The second is throughput, which measures the largest number of tokens per second that can be generated on a single device when sampling a specified number of tokens. Since throughput is given by tokens sampled times batch size divided by latency, one can improve throughput either by reducing the latency or by reducing memory usage to enable the use of larger batch sizes on device. Latency can be useful to consider for real-time applications that require a quick response time. Throughput is also useful to consider as it can tell us the maximum number of tokens we could sample from a particular model in a given time. This property is useful when considering other language applications such as Reinforcement Learning from Human Feedback (RLHF) or scoring language model outputs such as done in AlphaCode (Li et al., 2022) where being able to output a large number of tokens in a given time is an appealing feature.

All components of language models are memory bound during decoding as long as batch size isn’t too big (i.e. B128B\lesssim 128- see Appendix F.1 for details) and we will assume this for the remainder of this section. The largest memory overheads of Transformers typically come from the parameters themselves and the KV cache. Therefore we can approximate the time required to generate a single token for each sequence in the batch BB during decoding as the time needed to load these two quantities from memory:

Here, cache size refers to either the size of the KV cache at batch size 1 (for Transformers), or to the size of the recurrent state at batch size 1 (for RNNs).

The difference in cache size relative to model parameters has important implications for sampling efficiency. In recurrent and local attention blocks, parameter loading is the primary bottleneck, (because the cache size is substantially smaller). In contrast, global attention’s KV cache scales with the sequence length TT and can be comparable to, or even exceed, the size of the model parameters. This introduces considerable overhead when the sequence length TT is large enough (as shown in F.4). Consequently, an equally sized recurrent model can exhibit substantially lower latency than a Transformer when TT is large. Note however that as the model size grows the sequence length at which we see latency benefits (where the KV cache size is comparable to parameter size) also increases. It is important to note that, as well as improving latency, having a small recurrent state can also increase the largest batch size that fits in memory on a single device, leading to higher throughput.

2 Results

Here, we look at inference results for models of size 1B parameters. For our baseline, we compare against a MQA Transformer, which is significantly faster during inference than the standard MHA Transformer often used in the literature. The models that we compare are: i) MQA Transformer, ii) Hawk, and iii) Griffin. For comparing different models we report both latency and throughput.

We compare the latency for models with a batch size of 16 with an empty prefill as well as a prefill of 4096 tokens as seen in Figure 4. Hawk and Griffin achieve faster sampling latency than MQA Transformers for long sequences. This is particularly noticeable as the sequence length and the prefill length (which affect the size of the KV cache) are increased. Griffin achieves similar latency to Hawk, demonstrating the excellent compatibility of linear recurrences and local attention.

We compare the maximum throughput (tokens/s) for the same models when sampling 512, 1024, 2048 and 4196 tokens following an empty prompt in Figure 1(b). We see that both Griffin and Hawk achieve significantly higher throughput than the MQA Transformer baseline. This is partially due to recurrent models having lower latency but also primarily occurs because Griffin and Hawk can fit larger batch sizes than the MQA Transformer on a single device, since their cache size is smaller. Hawk achieves higher throughputs than Griffin, since the size of the local attention cache eventually becomes comparable to the size of the parameters when the batch size is large.

Long Context Modeling

In this section, we explore the effectiveness of Hawk and Griffin to use longer contexts to improve their next token prediction, and investigate their extrapolation capabilities during inference. Additionally, we explore our models’ performance on tasks that require copying and retrieval capabilities, both for models that are trained on such tasks, as well as when testing for these capabilities with our pre-trained language models.

We investigate the ability of Hawk and Griffin to improve their predictions with longer contexts. In particular, we evaluate our trained models by measuring the loss on a held-out books dataset across a range of sequence lengths. Using these long documents allows us to evaluate the ability of the models to extrapolate, i.e. the ability to accurately predict the next token given contexts that are longer than those seen during training.

In Transformers, this ability to extrapolate is largely determined by the positional encoding used for the attention layers (Kazemnejad et al., 2024). For recurrent models, it is instead dictated by the capacity of the model to keep refining the representation stored in the recurrence state as the context becomes longer. From the left plot of Figure 5, we observe that, up to some maximal length, both Hawk and Griffin improve next token prediction given longer contexts, and they are overall able to extrapolate to significantly longer sequences (at least 4x longer) than they were trained on. In particular, Griffin extrapolates remarkably well even when using RoPE (Su et al., 2021) for the local attention layers.

The results so far evaluate models that have been trained on sequences of 2048 tokens. In order to assess whether our models can also effectively learn from longer contexts, we train 1B parameter models on sequences of 8192 (8k) tokens on MassiveText, and compare them to models trained on the same dataset but on sequences of length 2048 (2k) tokens. We keep the total number of training tokens the same across the models by reducing the batch size by a factor of 4 for the models trained on the sequence length of 8192 (while keeping the number of training steps fixed). As illustrated in the right plot of Figure 5, we find that Hawk-8k and Griffin-8k do achieve lower evaluation loss for sequences of length 8192 or larger, compared to Hawk-2k and Griffin-2k. This indicates that Hawk and Griffin are able to learn to use longer contexts during training. Interestingly, when evaluating at short sequence lengths, we find that Hawk-2k and Griffin-2k perform slightly better than Hawk-8k and Griffin-8k. This suggests that the training sequence length should be carefully chosen according to the intended downstream use of the model.

2 Copy and retrieval capabilities

Recent work (Jelassi et al., 2024) has shown that Transformers can be significantly more efficient than state space models (SSMs), a popular new family of RNNs, at learning synthetic tasks such as copying the context or retrieving relevant tokens from the context. Additionally, Jelassi et al. (2024) showed that pre-trained Transformers such as Pythia (Biderman et al., 2023) are much better at copying and retrieval tasks at evaluation time compared to pre-trained SSM models such as Mamba (Gu and Dao, 2023). In this section, we investigate the efficiency of Griffin and Hawk in learning how to copy and retrieve tokens from the context. Additionally, we evaluate pre-trained Hawk and Griffin models on a phone number lookup task designed to test both copying and retrieval capabilities.

To investigate the efficiency of learning how to copy and retrieve relevant tokens from the context, we train on two synthetic tasks: Selective Copying and Induction Heads. To be able to compare Transformers with Hawk and Griffin, we consider 5-block deep networks with model dimension 64, totalling roughly 250K parameters, where Griffin uses a single local attention in the middle of the network, in the third block.

Selective copying task: In this task, the model needs to learn to copy data tokens from a sequence while ignoring noise tokens from the context. See Appendix H for more details on the setup for this task. This task is inspired by Gu and Dao (2023), where the authors showed that Mamba was able to solve this task better than previously proposed SSMs. We use a vocabulary size of 16, and train on sequences of length 1024, containing 16 data tokens (randomly sampled from the vocabulary and at random locations), with the rest of the tokens set to the noise token. Griffin uses a local attention window size of 512.

Induction heads: In this task, the model needs to learn to recall the token immediately following a special token. This requires the model to learn the special token, and retrieve the token immediately following it in the context. If the model is able to learn the task, it should be able to extrapolate to significantly longer sequences than it was trained for. We use a vocabulary size of 16 and train on sequences of length 256 where the tokens are sampled randomly, and we randomly sample the location of the special token in the sequence. Griffin uses a local attention window of size 128.

We show our results in Figure 6. On the Selective Copying task, we find that all 3 models are able to solve the task perfectly. When comparing speed of learning on this task, we find Hawk to be significantly slower than Transformers, similar to the observation made by Jelassi et al. (2024), where the authors showed that Mamba was significantly slower to learn on similar tasks. Interestingly though, Griffin shows almost no slowdown, effectively matching the speed of learning of Transformers, despite using only a single local attention layer.

On the Induction Heads task, while all 3 models can solve the task perfectly up to the training sequence length, our Transformer baseline is not able to extrapolate to longer sequences during evaluation. While our MQA baseline uses RoPE, Gu and Dao (2023) had similar observation for Transformers with a range of positional encodings. We find that Hawk is able to perfectly extrapolate on this task to evaluation sequences several orders of magnitude longer than the training sequence length. Notably, Griffin, with its local attention, also demonstrated exceptional ability to extrapolate on this task.

We now evaluate whether copying and retrieval capabilities naturally emerge in our pre-trained models. We consider our 7B Hawk and Griffin models and our 6B MQA Transformer baseline, all trained on 300B tokens on the MassiveText dataset. We consider the same phonebook lookup task introduced in Jelassi et al. (2024), where we provide to the model a synthetic phonebook containing names and numbers, and the model is asked to retrieve the correct phone number given a name. The prompt to the model is a phonebook consisting of randomly sampled list of names and numbers of a certain length, followed by two randomly sampled examples of the task, followed by a randomly sampled name from the phonebook for which the model needs to retrieve the correct phone number.

From Figure 6(c), we see that while Hawk can do reasonably well on the task for very short phonebook lengths, it fails to memorize and retrieve the correct phone number when the phonebook length grows, similar to the observation made by Jelassi et al. (2024) on the Mamba model’s performance on this task. This is not particularly surprising since Hawk uses a small fixed-size state. Our Transformer baseline can almost perfectly solve this task up to the training sequence length, but fails to retrieve the correct phone number for context lengths longer than the training sequence length. Interestingly, Griffin can perfectly solve this task up to a context length that matches its local attention window size of 1024, in spite of using only a single local attention layer. Once the context length is long enough such that the local attention window does not cover the whole phonebook, performance starts to degrade. Griffin is also able to extrapolate better to longer sequence lengths compared to Transformers. While the performance of Griffin is promising for the ability of models with fixed-size state to solve copying and retrieval tasks, our results suggest more work is needed to improve these capabilities for such models.

Related Works

The Transformer architecture has become a more scalable alternative to RNNs. Transformers achieve superior scalability through fully parallelized training, contrasting with the inherent limitations of RNNs. Due to their sequential processing structure, classical RNNs suffer from slow training speeds during both forward and backward propagation (Werbos, 1990). To mitigate this issue, researchers have explored alternative RNN-based methods. Notable examples include Quasi-RNNs (Bradbury et al., 2016), which combine convolutions and linear RNNs for greater parallelization, and the use of input-based gating mechanisms to parallelize linear RNN training (Martin and Cundy, 2017).

State-space Models (SSMs) have recently emerged as a powerful tool for modeling long input sequences. They demonstrated strong performance on tasks from the long-range arena benchmark (Tay et al., 2020), and audio generation (Goel et al., 2022). SSMs successfully integrate concepts from classical state-space models (Kalman, 1960) with those of RNNs. Their reliance on linear recurrences allows for efficient hidden state computation, either through parallel scan operations or convolutions, resulting in training speeds comparable to Transformer models. The S4 (Gu et al., 2021a) model proposed a sophisticated parameterization called normal plus low-rank to diagonalize the recurrence computation. The S4D parametrized the SSM directly with a diagonal state matrix and showed that it performed just as well while being much simpler (Gu et al., 2022). S5 also diagonalized the recurrence, and showed that the recurrence can be computed using the associative scan (Smith et al., 2022). The H3 model (Dao et al., 2022b) generalizes the recurrent interpretation of linear attention (Katharopoulos et al., 2020). Hyena (Poli et al., 2023) uses a similar architecture, but replaces the S4D layer with a global convolution kernel parametrized by an MLP. RetNet (Sun et al., 2023) uses a simpler SSM design with a gating mechanism which allows them to parallelize the computation using a variant of multi-head attention. Orvieto et al. (2023b) systematically analyzed and ablated multiple modifications to standard RNNs. Their finding showed that through better parameterization and initialization simplified linear RNNs (the LRU), perform just as well as other SSMs variants on various long-range tasks. RWKV (Peng et al., 2023) is a recent RNN, shown to be competitive on language modeling tasks, based on another linear attention approximation inspired by the attention-free Transformer (Zhai et al., 2021). Concurrent to our work Gu and Dao (2023) developed an SSM architecture called Mamba with an input dependant selection mechanism and showed that it achieves performance comparable to Transformers with efficient inference. Several extensions of Mamba have been proposed (Wang et al., 2024; Zhu et al., 2024) for different applications. An input-dependent gating similar to Mamba was also proposed by Gateloop (Katsch, 2023).

Linear attention (Katharopoulos et al., 2020) offers a computationally efficient approximation of the self-attention mechanism by linearizing the attention, which can be computed recurrently as a linear RNN. While this approach significantly reduces computational cost compared to full attention, it often comes with a trade-off in model performance. Flash Attention (Dao et al., 2022a) improves the training speed of attention on GPUs by making efficient use of the memory hierarchy. Another approach to reducing the computational cost of global attention, which is becoming increasingly more popular, is using sparse-local attention (Child et al., 2019) or sliding window attention (Jiang et al., 2023).

Conclusion

This work introduces Hawk; a recurrent model incorporating a novel gated linear recurrent layer, the RG-LRU. We also introduce Griffin; a hybrid model which mixes the RG-LRU layer with local attention. These models demonstrate exceptional language modeling performance across varying scales, with held-out loss exhibiting power-law scaling as compute resources increase. Hawk exceeds the reported performance of Mamba on downstream tasks when trained on half as many tokens, while Griffin slightly exceeds the performance of Llama-2 when trained on over 6 times fewer tokens. Furthermore, we empirically validate the inference-time advantages of Hawk and Griffin and observe reduced latency and significantly increased throughput compared to our Transformer baselines. Lastly, Hawk and Griffin exhibit the ability to extrapolate on longer sequences than they have been trained on and are capable of efficiently learning to copy and retrieve data over long horizons. These findings strongly suggest that our proposed models offer a powerful and efficient alternative to Transformers with global attention.

Acknowledgements

We thank Adam Paszke, Sharad Vikram, Trevor Gale, Sebastian Borgeaud, George Scrivener, Raia Hadsell, Oriol Vinyals, Toby Boyd, Zhifeng Chen, Chris Dyer, Kelvin Xu, Andriy Mnih for their guidance and advice. We make use of the DeepMind Jax ecosystem (Bradbury et al., 2018) and especially thank Andy Brock for building the internal framework we used for training and evaluating our models.

References

Appendix A RG-LRU Recurrence Gate

In Figure 7, we demonstrate the behavior of different gating mechanisms applied on the recurrent weight aa.

We implement our recurrence gate, as defined in Section 2.4, in a slightly different, but mathematically equivalent form, for numerical stability. In particular, we compute the logarithm of ata_{t} and then we exponentiate it, instead of computing a sigmoid and then taking a power:

Our gate is quite different than other standard gates in the literature. In particular, most gating mechanisms, like the one used in Mamba and GRU, allow through the gate to interpolate fully between the hidden state and the new observation. Ours on the other hand is biased towards retaining information, and does not allow to fully discard the contribution of ht1h_{t-1} (this depends, however, on the value of Λ\Lambda). To demonstrate this, we analyze the relative weight of xtx_{t} compare to ht1h_{t-1} in the output yty_{t}. For a general recurrence we define this as:

For our model we have α(rt)=at=acrt\alpha(r_{t})=a_{t}=a^{cr_{t}} and β(rt)=1α(rt)2\beta(r_{t})=\sqrt{1-\alpha(r_{t})^{2}}. For a standard GRU style gating we have α(rt)=1rt\alpha(r_{t})=1-r_{t} and β(rt)=rt\beta(r_{t})=r_{t}. For Mamba, assuming in their notation B=1,C=1B=1,C=1, then α(rt)=(1rt)A\alpha(r_{t})=(1-r_{t})^{-A} and β(rt)=(1α)/A\beta(r_{t})=(1-\alpha)/A. The behaviour of the different gating mechanisms is depicted in Figure 7, where for clarity we have also included the update value of the LRU (Orvieto et al., 2023b), which has no gating. As can be seen, the Mamba gating is almost identical to the GRU for values of AA close to 11, with minor deviations at smaller values. On the other hand, our gating mechanism performs a very different non-linear interpolation between fully discarding the input xtx_{t} and the update of the LRU.

Appendix B Complex-Gated Linear Recurrent Unit (CG-LRU)

With this we rewrite the equations for the LRU (see eq. 4) as:

Appendix C Model Scale Hyper-Parameters

In Table 2, we present the hyper-parameters of the models at different scales. These hyperparameters are shared for all the model families that we explored in this paper.

Appendix D Efficient Linear Recurrences on Device

The initial step in computational optimization lies in identifying the primary performance bottleneck on the target hardware. For most accelerators, the key limiting factors are computational throughput (FLOPs/s) and memory bandwidth between the high-bandwidth memory (HBM) and the fast vector memory (VMEM). While factors like HBM capacity and host-device communication are relevant, techniques such as ZeRO sharding and pipelined data transfer offer practical mitigations. Modern accelerator designs often prioritize a high FLOPs-to-byte ratio to accommodate workloads where computations significantly outnumber memory transfers. We show the key specification of the TPU-v3 pod (two chips per pod) in Table 3, which we use for all our experiments.

A typical matrix multiplication of a D×DD\times D matrix with a D×ND\times N matrix has 2ND22ND^{2} FLOPs and 2(D2+2ND)2(D^{2}+2ND) bytes to transfer (both read and write) which translates to NDD+N\frac{ND}{D+N} FLOPs/byte ratio. When D>>ND>>N and running on a TPU-v3 this implies that the dimension NN must be at least 136 to saturate the device, in which case the operation is “compute bound”, or otherwise most of the time will be spent on waiting for memory transfers, in which case the operation is “memory bound”.

D.2 Scan runtimes

In Figure 8(a) we demonstrate that on a TPU-v3 our Pallas kernel achieves nearly x3 speed up compared to the naive Jax implementation. In addition, the associative scan is significantly slower, even if fully run in bfloat16 precision. Figure 8(b) demonstrates that these gains also translate to significant improvements of the overall training time per step of the full Hawk model even at the 7b scale. For completeness we have also added the runtime of the associative scan, which can be up to 50% slower.

Appendix E The Local Attention Window Size of Griffin

Griffin uses both recurrent blocks as well as local attention layers in its temporal mixing blocks. For all experiments previously shown using a training sequence length of 2048, we use a local attention window size of 1024. We now investigate how the performance of different window sizes for the local attention layer varies with the training sequence length.

We consider 400M parameter models trained on sequence lengths of 2048, 4096 and 8192 tokens, where we keep the total number of training tokens fixed. For each sequence length, we train Griffin models using different local attention window sizes. As baselines, we train MQA Transformers using global attention layers, as well MQA Transformers using local attention layers with different window sizes. The results are shown in Figure 9, where the window sizes used are shown on top of each bar (MQA Transformer bars with window size equal to the training sequence length are the global attention MQA Transformer baseline).

From Figure 9, we see that remarkably, even when using a fixed window size of 1024 for the local attention layers in Griffin, it outperforms the global attention MQA Transformer baseline across all sequence lengths tested. However, it is worth noting that the performance gap between Griffin with local attention window 1024 and the global attention MQA Transformer reduces as the sequence length grows. Therefore, if the sequence length grows further, it is likely important to slowly also grow the local attention window size. In practice, the hardware used will also heavily determine the optimal local attention window size in terms of training and inference speed. Finally, we note that MQA Transformers purely using local attention (window sizes less than the training sequence length) perform significantly worse than both global attention MQA Transformers, as well as Griffin.

Appendix F Inference Speeds

The inference speed of language models at decode time is bounded by memory loading. As described already in 4.2 the linear RNN is memory bound. In the following we will show this is true for the other components (which are linear layers and self-attention) in our recurrent models and Transformer models.

F.2 Estimating the memory boundedness of linear layers

As shown in D.1 the outer dimension (usually consisting of batch BB and sequence length TT dimensions) must be at least 136 in order to be compute bound. At decode time T=1T=1 and if we assume B128B\lesssim 128 then any linear layers will be memory bound at decode time.

F.3 Estimating the memory boundedness of self-attention

In the following, we calculate the ratio of memory accesses to arithmetic operations for the attention computation for the LL-th decode step, to show it is also memory-bound.

To simplify the following analysis, we assume that we start from an empty prompt (or equivalently assume that the prefill contains 0 tokens).

When sampling auto-regressively from MHA or MQA, standard practice is to save the key and value vectors in a Key-Value (KV) cache. For LL tokens already sampled, the KV cache would therefore be of size 2×L×Hk×dhead2\times L\times H_{k}\times d_{head} for each sequence in the batch, where HkH_{k} denotes the number of heads used for the keys and values, and dheadd_{head} denotes the dimension of the key and value vectors in each head.

For sampling the LL-th token, once we calculate the query, key and value vectors corresponding to the LL-th token. The attention weights and the output of the attention layer are then computed using the LL-th key and value vectors in the KV cache. This requires O(LD)O(LD) operations overall and it requires loading the O(L×Hk×dhead)O(L\times H_{k}\times d_{head}) sized KV cache from HBM, for each sequence in the minibatch. The size of the KV cache, as well as the number of FLOPs, scales linearly with the batch size BB.

For MHA, the number of heads for the key and values HkH_{k} is typically equal to the number of heads used for the queries HH. For MQA, a single head is used for keys and values, i.e., Hk=1H_{k}=1. Therefore for MQA, the size of the KV cache is a factor of HkH_{k} smaller (i.e., of size 2×L×dhead2\times L\times d_{head}).

For a batch size of BB, the memory access to FLOPs ratio for the attention computation goes as O(B×L×Hk×dheadB×L×D)O(\frac{B\times L\times H_{k}\times d_{head}}{B\times L\times D}). For typical Transformer architectures, D=H×dheadD=H\times d_{head} and further Hk=HH_{k}=H for MHA and Hk=1H_{k}=1 for MQA. Therefore the memory access to flops ratio is O(1)O(1) for MHA and O(1/H)O(1/H) for MQA. As explained in 3, in order to be compute bound on TPU-v3 a FLOPs-to-byte ratio of 136 is required, and therefore both MHA and MQA would typically be memory bound. Nevertheless, MQA significantly speeds up Transformer inference (when compared to MHA), since it lowers the memory boundedness by a factor of HH.

F.4 Cache sizes

In the following we do an analysis of the relative sizes of caches used in our recurrent and Transformers. All caches sizes scale linearly with batch size and in the following we assume B=1B=1.

For attention, the KV cache has size 2NThkdhead2NTh_{k}d_{head}, where NN denotes the number of attention layers (the depth), TT denotes the length of the sequence, hkh_{k} denotes the number of KV heads and dheadd_{head} denotes the head dimension. Throughout this work, dhead=128d_{head}=128. For MHA, hkdhead=Dh_{k}d_{head}=D, while for MQA, hk=1h_{k}=1. (We therefore expect MQA to be faster when decoding long sequences than MHA, since the size of the KV cache is significantly smaller and less memory needs to be moved.)

For either MHA or MQA the size of the KV cache can exceed the number of model parameters when the sequence length TT is large. We therefore expect to observe a transition from a ‘parameter bound’ regime when the sequence length is short, during which the decoding speed is dominated by the time taken to load the model parameters on device, to a ‘cache bound’ regime for large sequences, where the decoding speed is dominated by the time taken to load the KV cache.

F.4.2 The size of the recurrent state

The recurrent state of a single RG-LRU layer has size DRNND_{RNN}, and the total state size for the entire Hawk model is NDRNN4BND/3ND_{RNN}\approx 4BND/3. Unlike the KV cache, this state does not grow with sequence length and is very small in comparison to parameter size. We therefore expect the decoding speed of RG-LRU to be dominated by the time taken to load the model parameters on device at all sequence lengths.

A similar consideration applies to the size of the 1D convolution state size. For a temporal filter width of size 4, the state has size (41)DRNN=3DRNN=4D(4-1)D_{RNN}=3D_{RNN}=4D for each recurrent block which is also substantially smaller than parameter sizes.

F.4.3 The local attention cache

A single local MQA layer has cache size upper bounded by 2TWSdhead2T_{WS}d_{head}, where TWST_{WS} denotes the local attention window size. So long as TWSD2/(Bdhead)T_{WS}\lesssim D^{2}/(Bd_{head}), the size of the local attention cache is also small relative to the parameter count. We therefore expect the decoding speed of Griffin to be similar to the decoding speed of the Hawk model.

Appendix G Improving Next Token Prediction with Longer Contexts: Additional Results

Figure 10 shows an additional result demonstrating next token prediction performance at different context lengths on a held out dataset of arXiv articles. We find that the results on this dataset are qualitatively similar to the results shown in Figure 5.

Appendix H Additional Details of the Copy and Retrieval Tasks

Figure 11 is an illustration of the Selective Copying and Induction Heads tasks.

In the Selective Copying task, the model needs to learn to copy data tokens (coloured tokens in Figure 11) from a sequence while ignoring noise tokens (white tokens in Figure 11). Crossed out tokens in the output in Figure 6 denote tokens that are masked out in the loss.

In the Induction Heads task, the model needs to learn to recall the token immediately following a special token (black token in Figure 11). As before, crossed out tokens in the output denote tokens that are masked out in the loss.