Hungry Hungry Hippos: Towards Language Modeling with State Space Models

Daniel Y. Fu, Tri Dao, Khaled K. Saab, Armin W. Thomas, Atri Rudra, Christopher Ré

Introduction

State space models (SSMs) have achieved state-of-the-art sequence modeling performance in domains ranging from time series analysis to audio generation. However, they have yet to match the performance of Transformers on language modeling, often underperforming Transformers by multiple points in perplexity. A natural question is whether this gap in performance is due to inherent inductive biases and capabilities in attention, or whether it is a function of the significant organizational resources that have been spent training and tuning large attention-based language models, as well as specialized hardware support for attention, ranging from tensor cores to transformer chips.

We take first steps towards answering these questions in this paper. First, we use synthetic language modeling tasks to show that there is an expressivity gap between SSMs and attention. Using our insights, we design a new SSM layer that nearly matches attention in language modeling. Second, we propose better hardware-aware algorithms for SSMs that allow them to take advantage of modern accelerators—and run faster than attention.

Understanding the Expressivity Gap. To understand the gap between SSMs and attention, we draw on synthetic language modeling tasks that have been proposed as a mechanistic basis for in-context learning in Transformers. These synthetic languages focus on the ability to manipulate text: recalling tokens from earlier time steps, or comparing tokens from different points in a sequence. We find that existing SSMs struggle to model these synthetic languages. To probe how important these skills are for language modeling, we propose H3 (Hungry Hungry Hippo), a new SSM-based layer designed to solve these language modeling tasks. H3 stacks two SSMs, with multiplicative interactions between their outputs and input projections. The SSMs allow H3 to keep a log of tokens (to recall them later), while the multiplicative interactions allow for comparisons across the sequence.

H3 matches attention on the synthetic languages and almost closes the gap with Transformers on language modeling, coming within 0.4 perplexity of Transformers on OpenWebText (compared to 3.4 ppl for existing SSMs, even those explicitly designed for language modeling). Furthermore, a simple hybrid H3-attention model that retains two attention layers surprisingly outperforms Transformers on OpenWebText by 1.0 perplexity. To further evaluate H3 on language modeling, we train 125M-, 355M-, 1.3B-, and 2.7B-parameter hybrid H3-attention language models on the Pile, using hyperparameters from GPT-3. These hybrid models outperform Transformer-based language models of the same size in perplexity, and match or outperform them on a majority of tasks in the SuperGLUE benchmark in zero- and few-shot learning. Since the SSM layers in these hybrid models admit a recurrent view, they can also perform 2.4$\times$ faster inference than Transformers.

Scaling SSMs. Next, we improve the efficiency of SSMs on modern hardware, to reduce the hardware barrier between attention and SSMs. SSMs scale nearly linearly in sequence length instead of quadratically like attention, but still run slower on modern hardware due to poor hardware utilization. To close this gap, we propose FlashConv, a hierarchical algorithm for computing SSMs, inspired by IO-Aware attention. The technical challenge is that SSMs require an FFT-based convolution over the input sequence, which requires an FFT, pointwise multiply, and inverse FFT. When implemented in cuFFT, this operation incurs expensive GPU memory reads/writes, and cannot utilize the specialized matrix multiply units available on modern hardware. (An A100 GPU has a maximum of 312 TFLOPs/s of FP16 with tensor cores, but only 20 TFLOPs/s of FP32, and 40 TFLOPs/s of FP16, without tensor cores. This trend started with the V100 GPUs and has continued with the H100 GPUs.) To use specialized matrix multiply units, we appeal to classical techniques that split the FFT into blocks and compute it using a series of matrix multiplications. Combined with kernel fusion, this “block” FFT solution increases hardware efficiency, but only as long as the sequence length can fit into GPU SRAM (on-chip memory, analogous to L1 cache on the CPU), which holds up to sequence length 8K on a modern A100.

To scale to sequences longer than 8K, we propose a state passing algorithm (Figure 1 right), specialized to SSMs. The key insight is that we can use the recurrent properties of SSMs to process the input in chunks, as long as we keep track of an additional state vector. The state passing algorithm splits the input into the largest chunks that can fit into GPU SRAM, efficiently computes the FFT-based convolution using block FFT, and updates an intermediate state to start the next chunk. Using this state-passing algorithm, FlashConv can scale SSMs to any sequence length, even longer than can fit in GPU SRAM at once, while maintaining a near linear compute complexity. FlashConv sets state-of-the-art speed on long range arena using S4, outperforming Transformers by 5.8$\times$ and previous S4 models by 2$\times$. FlashConv trains H3 4-8$\times$ faster than attention for long sequences, and is a critical component for scaling to billion-parameter models. Code for H3 is available at https://github.com/HazyResearch/H3.

Background

We present some background on state space models and linear attention, which inspired our H3 layer.

A continuous-time state-space representation defines a linear mapping from an input signal $u(t)\in\mathbb{R}$ (as a function of time $t$) to an output signal $y(t)\in\mathbb{R}$ through a state-variable $x(t)\in\mathbb{R}^{m}$, with the following differential equation, for some matrices $\mathbf{A}\in\mathbb{R}^{m\times m}$, $\mathbf{B}\in\mathbb{R}^{m\times 1}$, $\mathbf{C}\in\mathbb{R}^{1\times m}$, $\mathbf{D}\in\mathbb{R}^{1\times 1}$: $\dot{x}(t)=\mathbf{A}x(t)+\mathbf{B}u(t)$, $y(t)=\mathbf{C}x(t)+\mathbf{D}u(t)$.

Similarly, a discrete-time state-space representation defines a linear mapping from a discrete input signal $u_{i}$ (for $i=1,2,\dots$) to a discrete output signal $y_{i}$ through a state-variable $x_{i}\in\mathbb{R}^{m}$:

$$x_{i}=\mathbf{A}x_{i-1}+\mathbf{B}u_{i},\qquad y_{i}=\mathbf{C}x_{i}+\mathbf{D}u_{i}.$$

A state-space model (SSM) uses these representations as a layer in a deep learning pipeline, where the matrices $\mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D}$ are learned from data (e.g., with gradient-based optimization). One often has $d$ of these SSMs in parallel, each corresponding to one hidden dimension. To preserve the sequence history, HiPPO projects the history on a basis of orthogonal polynomials, which translates to having SSMs whose $\mathbf{A},\mathbf{B}$ matrices are initialized to some special matrices.

This recurrent form of SSMs allows efficient inference (i.e., generation): to generate the output of the next time-step, one only needs the state of the current time-step, not the entire input history. Furthermore, SSMs can freely extrapolate to sequences longer than seen during training.

SSMs as Convolution. For efficient training, given the entire sequence of the input $u_{1},\dots,u_{N}$, the output sequence $y_{1},\dots,y_{N}$ can also be written as the convolution of the input with the filter:

$$f=[\mathbf{C}\mathbf{B},\mathbf{C}\mathbf{A}\mathbf{B},\mathbf{C}\mathbf{A}^{2}\mathbf{B},\dots,\mathbf{C}\mathbf{A}^{N-1}\mathbf{B}].$$

That is, unrolling the recurrence from an initial condition $x_{0}$ gives $x_{i}=\mathbf{A}^{i}x_{0}+\sum_{j=1}^{i}\mathbf{A}^{i-j}\mathbf{B}u_{j}$, and hence $y_{i}=\mathbf{C}\mathbf{A}^{i}x_{0}+(f\ast u)_{i}+\mathbf{D}u_{i}$, where $(f\ast u)$ denotes a linear convolution between $f$ and $u$. If we set the initial condition $x_{0}$ to be zero, then $y$ is exactly a linear convolution of $u$, with a residual connection $\mathbf{D}u$. More generally, any linear time-invariant system (of which SSMs are a special case) can be written as a convolution.
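The equivalence between the recurrent and convolutional views can be checked numerically. The following is a minimal NumPy sketch, not the authors' implementation: the function names (`ssm_recurrent`, `ssm_filter`, `ssm_conv`), the random diagonal $\mathbf{A}$, and the sizes are assumptions chosen for illustration, and the initial state is taken to be zero.

```python
import numpy as np

def ssm_recurrent(A, B, C, D, u):
    """Step x_i = A x_{i-1} + B u_i, y_i = C x_i + D u_i from a zero state."""
    x = np.zeros(A.shape[0])
    y = np.empty(len(u))
    for i, u_i in enumerate(u):
        x = A @ x + B * u_i
        y[i] = C @ x + D * u_i
    return y

def ssm_filter(A, B, C, n):
    """Build f = [CB, CAB, CA^2 B, ..., CA^{n-1} B]."""
    f = np.empty(n)
    v = B.copy()              # holds A^k B
    for k in range(n):
        f[k] = C @ v
        v = A @ v
    return f

def ssm_conv(A, B, C, D, u):
    """Same output as ssm_recurrent, computed as (f * u) + D u."""
    f = ssm_filter(A, B, C, len(u))
    # np.convolve returns the full linear convolution; keep the first N terms.
    return np.convolve(f, u)[: len(u)] + D * u

# Illustrative (assumed) parameters: a stable diagonal A with state size m = 4.
rng = np.random.default_rng(0)
m, N = 4, 16
A = np.diag(rng.uniform(0.1, 0.9, size=m))
B = rng.standard_normal(m)
C = rng.standard_normal(m)
D = 0.5
u = rng.standard_normal(N)

y_rec = ssm_recurrent(A, B, C, D, u)
y_conv = ssm_conv(A, B, C, D, u)
print(np.allclose(y_rec, y_conv))
```

The recurrent form touches one state per step, which is what makes generation cheap; the convolutional form exposes the whole sequence at once, which is what training exploits.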

Given a 1D input sequence $u\in\mathbb{R}^{N}$ of length $N$, we denote the 1D output sequence $y\in\mathbb{R}^{N}$ of an SSM parameterized by matrices $\mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D}$ as

$$y=\mathrm{SSM}_{\mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D}}(u).$$

To simplify notation, we omit the reference to $\mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D}$ and write $y=\mathrm{SSM}(u)$ if they are clear from context. When $u$ is multidimensional of dimension $d$, we stack $d$ of these SSMs together to define a mapping from $u\in\mathbb{R}^{N\times d}$ to $y\in\mathbb{R}^{N\times d}$, using the same notation $y=\mathrm{SSM}(u)$.

To construct the filter $f$ from $\mathbf{A},\mathbf{B},\mathbf{C}$ efficiently, $\mathbf{A}$ is often constrained to be diagonal, or diagonal plus low-rank.

SSM through FFTs. Computing the convolution naively through conventional matrix operations is expensive for long kernels, scaling as $O(N^{2})$. Instead, we can use FFTs: take the FFT of $f$ and $u$, multiply them together pointwise, and then take the inverse FFT. This yields an $O(N\log N)$ algorithm.
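As a small illustration of the FFT route, the sketch below computes a causal linear convolution with NumPy's real FFT and compares it with direct convolution. The helper name `fft_conv` and the zero-padding to length $2N$ are assumptions of this sketch; padding is needed because the FFT computes a circular convolution.

```python
import numpy as np

def fft_conv(f, u):
    """First N outputs of the linear convolution f * u via FFT, O(N log N)."""
    n = len(u)
    size = 2 * n                        # zero-pad so wrap-around terms vanish
    f_hat = np.fft.rfft(f, n=size)      # FFT of the filter
    u_hat = np.fft.rfft(u, n=size)      # FFT of the input
    return np.fft.irfft(f_hat * u_hat, n=size)[:n]   # pointwise multiply, inverse FFT

rng = np.random.default_rng(0)
f = rng.standard_normal(64)
u = rng.standard_normal(64)

y_fft = fft_conv(f, u)
y_direct = np.convolve(f, u)[:64]       # O(N^2) reference
print(np.allclose(y_fft, y_direct))
```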

Linear attention

We describe linear attention and its connection to RNNs, which inspired our model design (Section 3).

In standard attention, we have $N$ query/key/value tokens $Q_{i},K_{i},V_{i}\in\mathbb{R}^{d}$ for $i=1,\dots,N$, where $N$ is the sequence length and $d$ is the head dimension. For some similarity metric $\mathrm{Sim}\colon\mathbb{R}^{d}\times\mathbb{R}^{d}\to\mathbb{R}$, we want to compute the output:

$$O_{i}=\frac{\sum_{j=1}^{i}\mathrm{Sim}(Q_{i},K_{j})V_{j}}{\sum_{j=1}^{i}\mathrm{Sim}(Q_{i},K_{j})}\in\mathbb{R}^{d}.$$

For standard softmax attention, $\mathrm{Sim}(q,k)=e^{q^{\top}k}$ (often the dot product is scaled by $1/\sqrt{d}$). Linear attention makes the assumption that $\mathrm{Sim}$ has the form $\mathrm{Sim}(q,k)=\phi(q)^{\top}\phi(k)$, for some (nonlinear) function $\phi$. The output is then $O_{i}=\frac{\phi(Q_{i})^{\top}\sum_{j=1}^{i}\phi(K_{j})V_{j}^{\top}}{\phi(Q_{i})^{\top}\sum_{j=1}^{i}\phi(K_{j})}$. Let $S_{i}=\sum_{j=1}^{i}\phi(K_{j})V_{j}^{\top}\in\mathbb{R}^{d\times d}$, $z_{i}=\sum_{j=1}^{i}\phi(K_{j})\in\mathbb{R}^{d}$, $d_{i}=\phi(Q_{i})^{\top}z_{i}\in\mathbb{R}$. Then $O_{i}=\frac{\phi(Q_{i})^{\top}S_{i}}{d_{i}}$. This connects linear attention to RNNs: the output $O_{i}$ is a function of $S_{i}$ and $z_{i}$, both of which are incrementally updated (as cumulative sums).
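The RNN view of linear attention can be made concrete with a short NumPy sketch. The feature map `phi` below (ELU plus one) is an assumption picked only because it is positive; the text leaves $\phi$ unspecified. The function names are likewise hypothetical.

```python
import numpy as np

def phi(x):
    """Assumed positive feature map: elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_direct(Q, K, V):
    """Causal linear attention from the definition, O(N^2) overall."""
    O = np.empty_like(V)
    for i in range(len(Q)):
        w = phi(K[: i + 1]) @ phi(Q[i])      # Sim(Q_i, K_j) for j <= i
        O[i] = (w @ V[: i + 1]) / w.sum()
    return O

def linear_attention_recurrent(Q, K, V):
    """Same output using running sums S_i (d x d) and z_i (d)."""
    d = Q.shape[1]
    S = np.zeros((d, d))
    z = np.zeros(d)
    O = np.empty_like(V)
    for i in range(len(Q)):
        k = phi(K[i])
        S += np.outer(k, V[i])               # S_i = S_{i-1} + phi(K_i) V_i^T
        z += k                               # z_i = z_{i-1} + phi(K_i)
        q = phi(Q[i])
        O[i] = (q @ S) / (q @ z)             # O_i = phi(Q_i)^T S_i / d_i
    return O

rng = np.random.default_rng(0)
N, d = 12, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

O_direct = linear_attention_direct(Q, K, V)
O_rec = linear_attention_recurrent(Q, K, V)
print(np.allclose(O_direct, O_rec))
```

The recurrent version keeps only $S_i$ and $z_i$ between steps, so its memory does not grow with the sequence length.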

Hungry Hungry Hippos Layer to Model Discrete Sequences

To understand the gap between SSMs and attention on language modeling, we examine two synthetic language modeling tasks. These tasks motivate our H3 layer to add a discrete SSM (based on shift matrix) and multiplicative interaction to effectively model discrete sequences. We then show that the H3 layer is expressive enough to solve these synthetic tasks, and that this understanding leads to better performance on a real language modeling benchmark.

We describe two closely-related synthetic tasks, summarized in Table 1. Olsson et al. argue that the ability to solve (variants of) these tasks accounts for the majority of the in-context learning capability of Transformers, and more intuition is given in Appendix E.

The Induction Head task tests how well a model can recall content after a special token (e.g., $\vdash$ in Table 1). At the end of the sequence, the model must recall the token that appeared immediately after the special token earlier in the sequence. Associative Recall is similar to the induction head task, but requires the model to remember multiple key-value pairs. At the end of the sequence, the model must recall a specific value belonging to a specific key.
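To make the two tasks concrete, here is a small Python sketch that generates one example of each. The vocabulary, the ASCII stand-in `"|-"` for the special token, the sequence lengths, and the function names are all assumptions made for illustration; the exact data format of Table 1 is not reproduced here.

```python
import random

LETTERS = list("abcdefghij")   # assumed toy vocabulary

def make_induction_head(length, special="|-", seed=0):
    """Random tokens with one special token; the prompt ends with the special
    token and the target is whatever followed its first occurrence."""
    rng = random.Random(seed)
    seq = [rng.choice(LETTERS) for _ in range(length)]
    pos = rng.randrange(length - 1)          # leave room for a following token
    seq[pos] = special
    target = seq[pos + 1]
    return seq + [special], target

def make_associative_recall(num_pairs, seed=0):
    """Key-value pairs (letter, digit); the prompt ends with one key and the
    target is the value paired with that key."""
    rng = random.Random(seed)
    keys = rng.sample(LETTERS, num_pairs)    # distinct keys
    values = [str(rng.randrange(10)) for _ in keys]
    seq = [tok for pair in zip(keys, values) for tok in pair]
    query = rng.choice(keys)
    target = values[keys.index(query)]
    return seq + [query], target

ih_seq, ih_target = make_induction_head(8, seed=2)
ar_seq, ar_target = make_associative_recall(3, seed=1)
print(ih_seq, "->", ih_target)
print(ar_seq, "->", ar_target)
```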

Table 2 (for two-layer models) shows that S4D and Gated State Spaces both fail to model these synthetic languages, which suggests they may not have the expressivity for general language. We argue that these failures suggest two missing capabilities: (i) to remember tokens that appear after a particular event (e.g., the special token in the induction head task), and (ii) to compare tokens across the sequence (e.g., comparing keys to decide which value to recall). Attention has both these capabilities: it can compare tokens by constructing the quadratic attention matrix $\mathbf{Q}\mathbf{K}^{\top}$, and it can recall tokens by direct copying (multiplying $\mathrm{softmax}(\mathbf{Q}\mathbf{K}^{\top})$ with $\mathbf{V}$). In Section 3.2, we design our new layer H3 to enable these capabilities in SSMs, narrowing the expressivity gap between SSMs and attention.

H3 Layer

H3 uses SSMs with shift and diagonal matrices, along with multiplicative operations against projections of the input to capture the missing capabilities identified by the synthetics.

High-level Intuition. (i) To remember tokens from the past, we want the state $x_{i}$ to copy from the input $u_{i}$, and then pass that information to the next state $x_{i+1}$. As $x_{i+1}$ relates to $x_{i}$ by $\mathbf{A}x_{i}$, we use a discrete SSM with a shift matrix $\mathbf{A}$ (described formally below) that shifts the elements of a state vector (e.g., mapping $[a,b,c]\to[0,a,b]$). (ii) To compare tokens across the sequence, we use multiplicative interaction: the output of an SSM, containing information from previous time steps, is multiplied with the input at the current time step, thus measuring similarity between tokens.

H3 is loosely inspired by linear attention (Section 2): we project the input $u$ to get three signals $\mathbf{Q},\mathbf{K},\mathbf{V}$. Then we replace the non-linearity $\phi(\mathbf{K})$ with an SSM where $\mathbf{A}$ is a shift matrix ($\mathrm{SSM}_{\mathrm{shift}}$), and we replace the summation $S_{i}$ with an SSM with diagonal $\mathbf{A}$ ($\mathrm{SSM}_{\mathrm{diag}}$). The output, for the case of head dimension $d_{h}=1$, is:

$$\mathbf{Q}\odot\mathrm{SSM}_{\mathrm{diag}}(\mathrm{SSM}_{\mathrm{shift}}(\mathbf{K})\odot\mathbf{V}),$$

where $\odot$ denotes pointwise multiplication. We can view this form as stacking two SSMs with multiplicative interaction (each is a “hungry hippo”, hence the name of our layer). A more formal connection between linear attention, time-varying systems, and H3 can be found in Appendix B.

Remembering Key Tokens: Shift and Diagonal SSMs. The shift and diagonal SSMs are designed to address the capability to log tokens after particular events. In the shift SSM, we constrain $\mathbf{A}\in\mathbb{R}^{m\times m}$ to be a shift matrix $\mathbf{A}_{i,j}=\begin{cases}1&\text{for }i-1=j\\ 0&\text{otherwise}\end{cases}$. The action of this matrix on the hidden state $x_{i}$ is to shift each coordinate down by one, thereby creating a “memory” of the previous states. For example, if $\mathbf{B}=e_{1}$, the first basis vector, then $x_{i}=[u_{i},u_{i-1},\dots,u_{i-m+1}]$ contains the inputs from the previous $m$ time steps. We learn $\mathbf{B}$ and $\mathbf{C}$ ($\mathbf{B}$ can also be fixed to $e_{1}$ for simplicity, in which case the output is a 1D convolution with kernel size $m$).
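The memory effect of the shift matrix is easy to see numerically. The sketch below, with an assumed state size $m=4$ and $\mathbf{B}=e_{1}$, runs the recurrence $x_{i}=\mathbf{A}x_{i-1}+\mathbf{B}u_{i}$ on the inputs $1,\dots,7$ and records each state.

```python
import numpy as np

m = 4
A = np.eye(m, k=-1)          # shift matrix: A[i, j] = 1 exactly when i - 1 == j
B = np.zeros(m)
B[0] = 1.0                   # B = e_1, so the newest input enters slot 0

u = np.arange(1.0, 8.0)      # inputs 1, 2, ..., 7
x = np.zeros(m)
states = []
for u_i in u:
    x = A @ x + B * u_i      # shift old entries down by one, write u_i on top
    states.append(x.copy())

print(states[0])             # [1. 0. 0. 0.]
print(states[-1])            # [7. 6. 5. 4.]
```

After the last step the state holds the four most recent inputs, newest first, which matches $x_{i}=[u_{i},u_{i-1},\dots,u_{i-m+1}]$.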

The diagonal SSM constrains $\mathbf{A}$ to be diagonal and initializes it from the diagonal version of HiPPO (S4D). This parameterization allows the model to remember state over the entire sequence. The shift SSM can detect when a particular event occurs, and the diagonal SSM can remember a token afterwards for the rest of the sequence.

Multiplicative Interaction for Comparison. We take the multiplicative interactions from linear attention, but they provide another missing capability when combined with a shift matrix: comparing tokens across the sequence. The multiplicative interactions between the output of the shift SSM and the $\mathbf{V}$ projection mimic local multiplicative interactions in linear attention (depending on the size of the hidden state). Similarly, multiplicative interactions between the $\mathbf{Q}$ projection and the output of the diagonal SSM allow comparisons between tokens over the entire sequence.

H3 Layer. The overall layer is given in Algorithm 1 and shown schematically in Figure 1 (left). We use the H3 layer to construct a model in the same style as Transformers by interleaving it with MLPs, connected by residual connection and layer norm (i.e., pre-norm architecture). We will also consider a hybrid H3-attention model (two attention layers while the rest are H3, Sections 3.3 and 5).

We show that H3 scales as $O(N\log N)$ with sequence length $N$, asymptotically more efficient than attention, which typically requires $O(N^{2}d)$ time and $O(N^{2})$ space (proof in Section D.3). There are several memory-efficient algorithms for attention, though their time complexity is still quadratic in $N$, which is a lower bound for attention.

Let $N$ be the sequence length, $d$ be the hidden dimension, and assume that the head dimension $d_{h}$ is of order $O(1)$. Then the H3 layer takes $O(d^{2}N+dN\log N)$ time and $O(dN)$ space to compute.

Expressivity

We show that H3 can model our synthetic languages, as well as natural language on OpenWebText. We also present a hybrid H3-attention extension that outperforms Transformers on OpenWebText.

Mechanism for Solving Associative Recall with H3. H3 is expressive enough to solve our synthetic language modeling tasks, as shown in Table 2. Figure 1 (middle) shows a mechanism for a single H3 layer to solve the associative recall task for a particular key-value pair $(a,3)$. The shift SSM and following multiplicative interaction act as a gate on whether to let a value through to the diagonal SSM, based on whether the previous token was key $a$. The diagonal SSM stores the value $3$ in memory, and continually outputs it. The final multiplicative interaction gates whether to let the diagonal SSM’s output through, based on whether the current input token is the key $a$. We formally construct the weights of an H3 layer to solve this task in Appendix D.1.
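The gating mechanism described above can be traced on a toy example. The sketch below is illustrative only and is not the construction from Appendix D.1. It assumes head dimension 1, a layer output of the form $\mathbf{Q}\odot\mathrm{SSM}_{\mathrm{diag}}(\mathrm{SSM}_{\mathrm{shift}}(\mathbf{K})\odot\mathbf{V})$, and hand-picked scalar encodings: $\mathbf{K}$ and $\mathbf{Q}$ are indicators of the key $a$, and $\mathbf{V}$ carries the numeric value of each value token. All names and parameter choices are hypothetical.

```python
import numpy as np

def ssm_apply(A, B, C, u):
    """y_i = C x_i with x_i = A x_{i-1} + B u_i from a zero state (D omitted)."""
    x = np.zeros(A.shape[0])
    y = np.empty(len(u))
    for i, u_i in enumerate(u):
        x = A @ x + B * u_i
        y[i] = C @ x
    return y

def h3_single_head(q, k, v, shift_params, diag_params):
    """Q * SSM_diag(SSM_shift(K) * V) for 1-D signals (head dimension 1)."""
    k_prev = ssm_apply(*shift_params, k)        # shift SSM output
    return q * ssm_apply(*diag_params, k_prev * v)

# Shift SSM with m = 2, B = e_1, C = e_2: outputs the previous input k_{i-1}.
shift_params = (np.eye(2, k=-1), np.array([1.0, 0.0]), np.array([0.0, 1.0]))
# Diagonal SSM with A = [1]: a running sum that never forgets.
diag_params = (np.array([[1.0]]), np.array([1.0]), np.array([1.0]))

# Toy sequence: a 3 b 5 a   (query at the end is the key a)
k = np.array([1.0, 0.0, 0.0, 0.0, 1.0])   # 1 where the token is key a
v = np.array([0.0, 3.0, 0.0, 5.0, 0.0])   # numeric value at value positions
q = k.copy()                               # 1 where the current token is key a

out = h3_single_head(q, k, v, shift_params, diag_params)
print(out)                                 # [0. 0. 0. 0. 3.]
```

The first product lets the value 3 through only because the previous token was $a$, the diagonal SSM holds it, and the final product releases it only at the position where $a$ is queried again.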

Better Synthetic Language Modeling Translates to Better Natural Language Modeling. We validate that when H3 can solve these synthetic tasks, it also improves the modeling capability on natural language (e.g., on the OpenWebText dataset). As shown in Table 3, H3 comes within 0.4 perplexity points of Transformers when trained for 50B tokens on OpenWebText, and performs much better than existing SSM variants (S4D, GSS), by 3 to 3.9 points.

Extension: H3-attention Hybrid Model. A simple hybrid H3-attention language model surprisingly outperforms Transformers on OpenWebText by 1.0 point. Our hybrid model simply retains two self-attention layers: one in the second layer, and one in the middle (layer $2+N/2$ for an $N$-layer model, $N$ even). The H3-attention hybrid also outperforms the GSS-attention hybrid.

FlashConv: Efficiently Training SSMs

To improve the efficiency of SSMs on modern hardware, we propose FlashConv. FlashConv fuses the FFT, pointwise multiply, and inverse FFT to reduce memory reads/writes. It also uses a block FFT algorithm to make use of specialized matrix multiply units (e.g., tensor cores on A100) for sequence lengths up to 8K. For sequences longer than 8K, the computation no longer fits in GPU SRAM, so we propose a novel state-passing algorithm that splits the sequence into chunks to compute the FFT convolution one chunk at a time. (SRAM, or on-chip memory, is much faster than off-chip GPU memory, but usually much smaller, on the order of around 100KB for each streaming processor.) FlashConv can speed up any SSM (not just H3).

We deploy two techniques to speed up the FFT-based convolution for sequences shorter than 8K: kernel fusion and block FFT. Kernel fusion addresses IO bottlenecks due to reading and writing of intermediate results, while block FFT allows the FFT-based convolution to utilize specialized matrix multiplication units. These techniques allow us to speed up FFTConv by 2$\times$ (Section 6) for sequences shorter than 8K.

Kernel Fusion. Naive implementations of FFTConv using standard libraries such as cuFFT are IO-bound due to repeated reading and writing of intermediate results. The FFT convolution in an SSM with input $u$ and filter $f$ has the form $iFFT(FFT(u)\odot FFT(f))$ (where $\odot$ denotes pointwise multiplication). It requires reading and writing intermediate results to GPU memory, which can dominate the runtime. Following FlashAttention, we first fuse the entire FFTConv into a single kernel and compute it in SRAM to avoid this overhead.

Block FFT. To further speed up the computation of FFT-based convolution, we exploit specialized matrix multiplication hardware on modern GPUs (e.g., Tensor Cores on Nvidia GPUs perform fast $16\times 16$ matrix multiplication). We appeal to classical results that show that the FFT can be written as a series of block-diagonal matrix multiplications interleaved with permutation. We note that such algorithms are not new, but our setting (fused FFTConv on GPU) introduces new bottlenecks: by removing the IO bottlenecks, compute becomes the bottleneck (note that a single FFT on GPU is usually IO bound).

Suppose that we want to perform an $N$-point FFT, which is equivalent to multiplying by the DFT matrix $\mathbf{F}_{N}$. Suppose that $N=N_{1}N_{2}$ for some integers $N_{1},N_{2}$. By the Cooley-Tukey decomposition of the DFT (also known as the four-step FFT algorithm), we can write $\mathbf{F}_{N}=\mathbf{P}(\mathbf{I}_{N_{2}}\otimes\mathbf{F}_{N_{1}})\mathbf{P}^{\top}\mathbf{D}(\mathbf{I}_{N_{1}}\otimes\mathbf{F}_{N_{2}})\mathbf{P}$, where $\mathbf{P}$ denotes a fixed permutation that reshapes the input as an $N_{1}\times N_{2}$ array and then transposes it, $\otimes$ denotes the Kronecker product, $\mathbf{D}$ is an $N\times N$ diagonal matrix (called the twiddle factors), and $\mathbf{I}_{N_{i}}$ and $\mathbf{F}_{N_{i}}$ are the identity and DFT matrix of size $N_{i}\times N_{i}$. As $\mathbf{I}_{N_{2}}\otimes\mathbf{F}_{N_{1}}$ and $\mathbf{I}_{N_{1}}\otimes\mathbf{F}_{N_{2}}$ are just block-diagonal matrices, we can make use of specialized matmul units to perform these multiplications. Similarly, if $N=N_{1}N_{2}N_{3}$ then we can decompose the $N$-point FFT into a series of (block) FFTs of size $N_{1}$, $N_{2}$, and $N_{3}$, interleaved by permutation.
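The idea of computing an FFT with dense matrix multiplications can be sketched in NumPy. This is one common index ordering of the four-step algorithm and is not claimed to match the exact factor order written above or the fused GPU kernel; the function names and the choice $N_{1}=16$, $N_{2}=32$ are assumptions.

```python
import numpy as np

def dft_matrix(n):
    """Dense n x n DFT matrix with entries exp(-2*pi*i*j*k/n)."""
    idx = np.arange(n)
    return np.exp(-2j * np.pi * np.outer(idx, idx) / n)

def four_step_fft(x, n1, n2):
    """N-point DFT (N = n1 * n2) using two small DFT matmuls and a twiddle."""
    n = n1 * n2
    assert len(x) == n
    x2d = x.reshape(n1, n2)                    # x2d[a, b] = x[n2 * a + b]
    y = dft_matrix(n1) @ x2d                   # size-n1 DFT down each column
    twiddle = np.exp(-2j * np.pi * np.outer(np.arange(n1), np.arange(n2)) / n)
    y = y * twiddle                            # diagonal twiddle factors
    z = y @ dft_matrix(n2)                     # size-n2 DFT along each row
    return z.T.reshape(n)                      # output index k = k1 + n1 * k2

rng = np.random.default_rng(0)
x = rng.standard_normal(512)
X_block = four_step_fft(x, 16, 32)
X_ref = np.fft.fft(x)
print(np.allclose(X_block, X_ref))
```

The two `@` products are the steps that map onto matrix-multiply hardware; the reshape and transpose play the role of the permutations.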

The block FFT algorithm incurs $O(Nr\log N/\log r)$ FLOPs for a sequence length $N$, if $N$ can be written as $r^{p}$ for two integers $r,p$. This incurs more FLOPs than standard FFT ($O(N\log N)$), but can run faster when using specialized matrix multiplication hardware.
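As a rough worked example of this trade-off, take $N=4096=16^{3}$, so $r=16$ and $p=3$ (the radix 16 is an assumption chosen to match the $16\times 16$ tensor-core tile size mentioned earlier). Ignoring constant factors:

```latex
\begin{aligned}
\text{block FFT:}\quad & N r \,\frac{\log N}{\log r} = N r p = 4096 \cdot 16 \cdot 3 = 196{,}608,\\
\text{standard FFT:}\quad & N \log_{2} N = 4096 \cdot 12 = 49{,}152,\\
\text{ratio:}\quad & \frac{r}{\log_{2} r} = \frac{16}{4} = 4 .
\end{aligned}
```

Under these assumptions the block FFT performs about four times as many operations, which is worthwhile whenever the matrix multiply units are more than four times faster than the general-purpose units for this workload.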

State-Passing

However, the fused kernel cannot run if the sequence is too long to fit into GPU SRAM (longer than 8K on A100). We show how to exploit the particular form of the FFT in SSM to speed it up for long sequences.

The recurrent nature of SSMs allows us to split the FFTConv of a length-$N$ sequence into chunks of size $N^{\prime}$ each ($N^{\prime}$ is the longest FFT we can fit into SRAM), assuming $N$ is a multiple of $N^{\prime}$. We use FFTConv to compute each chunk and use a recurrence to connect the chunks. In particular, we split the inputs $u$ into $C=N/N^{\prime}$ chunks $u^{(c)}\in\mathbb{R}^{N^{\prime}}$ for $c=1,\dots,C$. Similarly, split the states $x$ into $x^{(c)}\in\mathbb{R}^{N^{\prime}\times m}$ and the output $y$ into $y^{(c)}\in\mathbb{R}^{N^{\prime}}$ for $c=1,\dots,C$. We will only need the end-state $x_{N^{\prime}}^{(c)}$ of each chunk $c$.

Let $f=[\mathbf{C}\mathbf{B},\mathbf{C}\mathbf{A}\mathbf{B},\mathbf{C}\mathbf{A}^{2}\mathbf{B},\dots,\mathbf{C}\mathbf{A}^{N^{\prime}-1}\mathbf{B}]$ be the SSM filter. Recall from Section 2 that for each chunk $c$, $y_{i}^{(c)}=\mathbf{C}\mathbf{A}^{i}x_{N^{\prime}}^{(c-1)}+(f\ast u^{(c)})_{i}+\mathbf{D}u_{i}^{(c)}$, since $x_{N^{\prime}}^{(c-1)}$, the end-state of the previous chunk $(c-1)$, is the initial condition for the current chunk $c$. In vector notation, $y^{(c)}=\mathbf{M}_{xy}x_{N^{\prime}}^{(c-1)}+f\ast u^{(c)}+\mathbf{D}u^{(c)}$ for some matrix $\mathbf{M}_{xy}\in\mathbb{R}^{N^{\prime}\times m}$. Additionally we need to update the end-state of each chunk with $x_{N^{\prime}}^{(c)}=\mathbf{A}^{N^{\prime}}x_{N^{\prime}}^{(c-1)}+\mathbf{M}_{ux}u^{(c)}$ for some matrix $\mathbf{M}_{ux}\in\mathbb{R}^{m\times N^{\prime}}$ (derivation in Appendix C.2). In essence, we can compute the output for each chunk with FFT-based convolution as long as we remember the end-state of the previous chunk, and the end-state of each chunk can be updated recurrently. This yields a state-passing algorithm for long sequences, where we only compute FFTs of length $N^{\prime}$, and update some hidden state each iteration.

Let BlockFFTConv refer to our fused block FFTConv kernel. Then, the state-passing algorithm for 1D input is given by Algorithm 2. For inputs of dimension $d$ where we stack $d$ SSMs, we simply batch Algorithm 2 along the $d$-dimension.

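Algorithm 2 State-Passing Algorithm
Require: Input $u\in\mathbb{R}^{N}$, SSM parameterized by matrices $\mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D}$, chunk size $N^{\prime}$ where $N$ is a multiple of $N^{\prime}$.
1: Precompute $\mathbf{A}^{N^{\prime}}$, $\mathbf{M}_{ux}$, and $\mathbf{M}_{xy}$.
2: Split the input $u$ into $C=N/N^{\prime}$ chunks $u^{(c)}$ for $c=1,\dots,C$.
3: Let the initial state be $x_{N^{\prime}}^{(0)}=0\in\mathbb{R}^{m}$.
4: for $1\leq c\leq C$ do
5: Compute $y^{(c)}=\mathbf{M}_{xy}x_{N^{\prime}}^{(c-1)}+$ BlockFFTConv($f$, $u^{(c)}$) $+\mathbf{D}u^{(c)}\in\mathbb{R}^{N^{\prime}}$.
6: Update state: $x_{N^{\prime}}^{(c)}=\mathbf{A}^{N^{\prime}}x_{N^{\prime}}^{(c-1)}+\mathbf{M}_{ux}u^{(c)}$.
7: end for
8: Return $y=[y^{(1)}\dots y^{(C)}]$.

We prove that Algorithm 2 yields the same output as if one has computed the SSM using a large FFT of size $N$ (proof in Section D.4):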

For input $u\in\mathbb{R}^{N}$ and matrices $\mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D}$, the output $y\in\mathbb{R}^{N}$ returned by Algorithm 2 is the same as the output defined by the SSM parameterized by $\mathbf{A},\mathbf{B},\mathbf{C},\mathbf{D}$.
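The claim can be checked numerically on small sizes. The sketch below is a NumPy stand-in, not the fused GPU kernel: `np.fft` replaces BlockFFTConv, and the explicit forms used for $\mathbf{M}_{xy}$ (rows $\mathbf{C}\mathbf{A}^{i}$ for $i=1,\dots,N^{\prime}$) and $\mathbf{M}_{ux}$ (columns $\mathbf{A}^{N^{\prime}-j}\mathbf{B}$ for $j=1,\dots,N^{\prime}$) are assumptions obtained by unrolling the recurrence, since the text only states that such matrices exist. Function names and sizes are hypothetical.

```python
import numpy as np

def ssm_recurrent(A, B, C, D, u):
    """Reference: step the SSM over the whole sequence from a zero state."""
    x = np.zeros(A.shape[0])
    y = np.empty(len(u))
    for i, u_i in enumerate(u):
        x = A @ x + B * u_i
        y[i] = C @ x + D * u_i
    return y

def state_passing(A, B, C, D, u, chunk):
    """Chunked SSM: FFT convolution inside each chunk, state passed between."""
    m, n = A.shape[0], len(u)
    assert n % chunk == 0
    powers = [np.linalg.matrix_power(A, p) for p in range(chunk + 1)]  # A^0..A^chunk
    f = np.array([C @ powers[p] @ B for p in range(chunk)])   # [CB, ..., CA^{chunk-1}B]
    M_xy = np.stack([C @ powers[p] for p in range(1, chunk + 1)])      # (chunk, m)
    M_ux = np.stack([powers[chunk - 1 - j] @ B for j in range(chunk)], axis=1)  # (m, chunk)
    f_hat = np.fft.rfft(f, 2 * chunk)          # filter FFT is reused by every chunk
    x = np.zeros(m)                            # end-state of the previous chunk
    out = []
    for c in range(n // chunk):
        u_c = u[c * chunk:(c + 1) * chunk]
        conv = np.fft.irfft(f_hat * np.fft.rfft(u_c, 2 * chunk), 2 * chunk)[:chunk]
        out.append(M_xy @ x + conv + D * u_c)  # output of this chunk
        x = powers[chunk] @ x + M_ux @ u_c     # update the end-state
    return np.concatenate(out)

rng = np.random.default_rng(0)
m, N, chunk = 4, 64, 16
A = np.diag(rng.uniform(0.1, 0.9, size=m))
B = rng.standard_normal(m)
C = rng.standard_normal(m)
D = 0.5
u = rng.standard_normal(N)

y_chunked = state_passing(A, B, C, D, u, chunk)
y_full = ssm_recurrent(A, B, C, D, u)
print(np.allclose(y_chunked, y_full))
```

Only FFTs of length $2N^{\prime}$ are ever taken, so the working set per step is bounded by the chunk size rather than by the full sequence length.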

H3 Evaluation

To understand how well capturing the synthetics in Section 3.1 translates to language modeling, we train hybrid H3-attention language models at sizes 125M, 355M, 1.3B, and 2.7B, and we evaluate their performance against Transformers. The hybrid models match or exceed the quality of Transformers in perplexity and zero/few-shot learning. We also validate that H3 models retain strong performance on non-text sequence modeling. Appendix F contains additional experiments on more datasets, length extrapolation, and scaling with data.

We compare hybrid H3-attention language models against Transformer-based language models. We evaluate language modeling performance using perplexity, zero-shot learning, and few-shot learning performance. Hybrid H3 models outperform Transformers, which suggests that closing the gap between SSMs and attention on the synthetic languages translates to real language modeling capabilities. We also report the generation speed of hybrid H3 models compared to Transformers; since SSMs are recurrent models, they can generate tokens 2.4$\times$ faster than Transformers. Appendix F shows performance of pure H3 language models on these same evaluation metrics.

We train hybrid models at sizes 125M, 355M, 1.3B, and 2.7B on the Pile for 400B tokens. We compare against checkpoints of equivalent sizes from OpenAI and GPT-Neo, from HuggingFace. (There is no pretrained GPT-Neo at the 350M size.)

Table 4 shows perplexity on the Pile, OpenWebText, and WikiText-103. On the Pile, our 125M hybrid model outperforms GPT-Neo, which was also trained on the Pile. Our hybrid models also outperform GPT-Neo models and GPT-2 models on zero-shot transfer to OpenWebText and WikiText-103. We report the PPL of GPT-2 models for context, though the performance is not directly comparable since they were trained on different data.

We compare the zero- and few-shot performance of hybrid H3 language models against OPT, GPT-Neo, and GPT-2 models, where public checkpoints are available. We report performance with rank classification on the logits of the possible choices (see Appendix F.7 for raw generation). Table 5 reports zero-shot performance on the SuperGLUE benchmark, and Table 6 reports the 3-shot performance. Following the perplexity results, the hybrid language models outperform or match the best Transformer baseline on more than half the tasks.

Finally, since SSMs are recurrent models, they admit faster text generation than Transformers. Table 7 shows inference throughput of a 1.3B-parameter hybrid model compared to a Transformer. The hybrid model has up to 2.4$\times$ higher throughput.

FlashConv Evaluation

We evaluate how well FlashConv speeds up SSMs. FlashConv sets state-of-the-art performance on the long range arena benchmark using S4. We report performance of training an H3 module with FlashConv compared to attention at various sequence lengths, from 256 to 32K, and demonstrate nearly linear scaling.

The Long Range Arena (LRA) benchmark is a benchmark for long-range sequence modeling. The state-of-the-art approach, S4, is an SSM. Table 8 shows that FlashConv accelerates S4 by 2$\times$, outperforming Transformers by 5.8$\times$.

We benchmark the time to run the forward and backward pass of H3 with FlashConv against attention. FlashConv maintains nearly linear scaling, even to very long sequence lengths. Fig. 2 shows overall 2-3$\times$ speedup over FFTConv with cuFFT using our techniques (block FFT, state-passing). Simple kernel fusion (even without block FFT) can yield speedup over cuFFT for short sequences, since memory reads/writes are the bottleneck for short sequences. For long sequences, SSMs using state passing can be dozens of times faster than even the fastest attention implementation.

Conclusion

Our main goal is to understand and narrow the gap between attention and SSMs in language modeling in terms of modeling capabilities and hardware efficiency. Our exploration based on synthetic language tasks motivated us to design the H3 layer, which is surprisingly competitive with attention. Our BlockFFTConv algorithm exploits matrix multiplication units and the dual recurrent–convolution view of SSMs to substantially speed up SSMs, reducing the hardware barrier between attention and SSMs. We are excited about several future directions. Our H3 layer is a simple combination of two SSMs, and more sophisticated designs could be more expressive. Our encouraging results on language models up to 2.7B parameters suggest that scaling SSMs to larger sizes is a promising avenue. Since simply adding two attention layers to H3 models already outperforms both the pure H3 model and Transformers, we are optimistic about combining the complementary strengths of SSMs and attention in the future.

Acknowledgments

We thank Albert Gu for helpful discussion regarding the model architecture, and more importantly for sending us daily hippo videos. We thank Together Computer for providing portions of the compute used to train models in this paper. We gratefully acknowledge the support of NIH under No. U54EB020405 (Mobilize), NSF under Nos. CCF1763315 (Beyond Sparsity), CCF1563078 (Volume to Velocity), and 1937301 (RTML); ARL under No. W911NF-21-2-0251 (Interactive Human-AI Teaming); ONR under No. N000141712266 (Unifying Weak Supervision); ONR N00014-20-1-2480: Understanding and Applying Non-Euclidean Geometry in Machine Learning; N000142012275 (NEPTUNE); NXP, Xilinx, LETI-CEA, Intel, IBM, Microsoft, NEC, Toshiba, TSMC, ARM, Hitachi, BASF, Accenture, Ericsson, Qualcomm, Analog Devices, Google Cloud, Salesforce, Total, the HAI-GCP Cloud Credits for Research program, the Stanford Data Science Initiative (SDSI), Department of Defense (DoD) through the National Defense Science and Engineering Graduate Fellowship (NDSEG) Program, Wu Tsai Neuroscience Stanford Interdisciplinary Graduate Fellowship, and members of the Stanford DAWN project: Facebook, Google, and VMWare. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright notation thereon. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect the views, policies, or endorsements, either expressed or implied, of DARPA, NIH, ONR, or the U.S. Government. Atri Rudra’s research is supported by NSF grant CCF-1763481.

References

Appendix A Related Work

State space models have shown promise in modeling sequential data, including time series data , audio , and visual data . Our model builds off work on simplifying and parameterizing diagonal versions of S4 . Gated state spaces also aim to adapt SSMs to language modeling, but our results suggest that the GSS model does not perform as well as H3 (or even as well as earlier SSMs like S4D). The idea to combine SSMs with attention in hybrid models is not new; Mehta et al. also showed that interleaving attention with their GSS layer can improve performance, which we also validate on our OpenWebText experiments. These positive results suggest that attention and SSMs are complementary, and that hybrid models may be a promising direction for future work.

Large language foundation models have demonstrated the power of scaling attention-based networks to billions of parameters and training them on trillions of tokens . Understanding the mechanistic basis behind these models may yield insights into better design choices for future models. These and similar explorations have informed the design of H3 and our selection of synthetic languages. A number of recent works have also explored how to address the shortcomings of attention by approximating the attention computation . We believe these efforts are complementary to SSMs, and we are excited to see how they can be combined in future work.

Linear attention and classical sequence models like RNNs serve as inspiration for H3. Appendix B draws a direct connection between linear attention and LTI systems. Luo et al. also propose a variant of linear attention that can achieve $O(n\log n)$ scaling in sequence length. Appendix F evaluates linear attention on language modeling, and finds that it underperforms exact attention, whereas H3 outperforms attention. The multiplicative interactions in H3 are reminiscent of gating mechanisms in LSTMs and GRUs , which suggests that architectural lessons from these sequence models may be useful for adapting SSMs to language modeling. A number of algorithms for scaling attention to longer sequences have also been proposed, such as Transformer-XL , Reformer , Performer , and Perceiver AR . Some of these approaches underperform exact attention on language modeling, and may be slower in wall-clock speed . A thorough comparison of these alternatives to exact attention and how well they scale in model size and amount of training data is fruitful future work.

FFT algorithms are used in a wide variety of applications, including signal processing , control theory , and more. Various algorithms for computing the FFT have existed for decades . We hope our work on appealing to these classic algorithms to accelerate new applications such as learned SSMs will inspire future algorithmic exploration, even if hardware is not designed for them .

Appendix B Linear Attention and Time-Varying Systems

We draw some connections from linear attention to LTI systems and SSMs.

We first present linear attention as a linear time-varying system, and show how converting it to a linear time-invariant system matches H3.

In general, a layer in a sequence model takes in a sequence and outputs a sequence. Many of these take the form of a linear time-varying system (thanks to the Picard-Lindelöf theorem, nonlinear systems can be approximated by a series of linear systems):

$$x_{i} = \mathbf{A}_{i}x_{i-1}+\mathbf{B}_{i}u_{i}, \qquad y_{i} = \mathbf{C}_{i}x_{i}+\mathbf{D}_{i}u_{i}.$$

This has the same form as SSMs (Section 2), except that the matrices can depend on the timestep.

Recall the form of linear attention from Section 2. For the purpose of approximation, we ignore the denominator in linear attention (Section 2), i.e., we assume $d_{i}=1$. We see that $S_{i}$ is just a cumulative sum, satisfying the recurrence $S_{i+1}=S_{i}+\phi(K_{i+1})V_{i+1}^{T}$. Similarly, $O_{i}$ satisfies the recurrence $O_{i+1}=\phi(Q_{i+1})^{T}S_{i+1}$. This is a linear time-varying system of the form $x_{i+1}=\mathbf{A}x_{i}+\mathbf{B}u_{i+1}$ and $y_{i+1}=\mathbf{C}_{i+1}x_{i+1}$ (with $\mathbf{A}=I$, $\mathbf{B}=I$, $u_{i}=\phi(K_{i})V_{i}^{T}$, $\mathbf{C}_{i}=\phi(Q_{i})^{T}$). That is, $\mathbf{A}$ and $\mathbf{B}$ are constant, but $\mathbf{C}$ is time-variant.
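As an illustration of the recurrence above, the following is a minimal sketch in plain Python that checks the cumulative-sum form of (unnormalized) linear attention against the direct causal sum. The feature map `phi` is a hypothetical choice made only for this example, and the function names are ours, not part of any released implementation.

```python
def phi(x):
    """Assumed feature map (hypothetical choice): elementwise |x| + 1."""
    return [abs(v) + 1.0 for v in x]


def linear_attention_recurrent(Q, K, V):
    """Run S_{i+1} = S_i + phi(K_{i+1}) V_{i+1}^T, O_{i+1} = phi(Q_{i+1})^T S_{i+1}."""
    d, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d)]  # running state, a d x d_v matrix
    outputs = []
    for q, k, v in zip(Q, K, V):
        fk = phi(k)
        # State update: A = I and B = I, so the state is a cumulative sum.
        for a in range(d):
            for b in range(d_v):
                S[a][b] += fk[a] * v[b]
        # Time-varying readout: C_i = phi(Q_i)^T.
        fq = phi(q)
        outputs.append([sum(fq[a] * S[a][b] for a in range(d)) for b in range(d_v)])
    return outputs


def linear_attention_direct(Q, K, V):
    """Direct causal sum: O_i = sum_{j <= i} (phi(Q_i) . phi(K_j)) V_j."""
    outputs = []
    for i, q in enumerate(Q):
        fq = phi(q)
        out = [0.0] * len(V[0])
        for j in range(i + 1):
            weight = sum(a * b for a, b in zip(fq, phi(K[j])))
            for b in range(len(out)):
                out[b] += weight * V[j][b]
        outputs.append(out)
    return outputs
```

Both functions ignore the denominator, matching the approximation $d_i = 1$ made in the text; the recurrent version does work linear in the sequence length, while the direct version is quadratic.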

To convert this into a linear time-invariant version, we treat the time-variant $\mathbf{C}_{i}$ as a post-processing step. We instead use a fixed $\mathbf{C}$ for the LTI. This yields an LTI:

$$x_{i+1} = \mathbf{A}x_{i}+\mathbf{B}\phi(K_{i})V_{i}^{T}, \qquad y_{i+1} = \mathbf{C}x_{i},$$

for some matrices $\mathbf{A},\mathbf{B},\mathbf{C}$ that are learned. We then apply post-processing by multiplying $y_{i+1}$ with $\phi(Q_{i})^{T}$. Replacing $\phi(K_{i})$ with a shift SSM yields an analogue to H3.

Appendix C Method details

Since we have described the forward pass in Section 3, we describe here the backward pass in detail.

We show how to compute the backward pass in a fused kernel.

Let $y=f\ast u+\mathbf{D}u$. In our case, $f$ and $u$ have the same length, so they are symmetric as far as the convolution is concerned.

Suppose we are given $dy=\frac{\partial l}{\partial y}$ (where $l$ is some loss function). We wish to compute $du$, $df$, and $d\mathbf{D}$ (which are $\frac{\partial l}{\partial u}$, $\frac{\partial l}{\partial f}$, and $\frac{\partial l}{\partial\mathbf{D}}$, respectively).

The most challenging part is computing the gradient through the convolution operator, but we’ll see that we can re-use our FFT infrastructure for it. The rest of the operations are straightforward; we have $d\mathbf{D}=dy\,u^{T}$.

Here, we’ll discuss how to compute $df$ by backpropagating through the convolution operator $\ast$. As an immediate consequence, we’ll be able to compute $du$ as well.

Since $f$ and $u$ are the same length $L$, $f\ast u$ and $u\ast f$ have the same result. Thus, we’ll start from $u\ast f$ here.

For some notation, let $O=u\ast f$. Then, $dO=dy$. Recall that $O[i]=\sum_{j=0}^{i-1}u[i-j]f[j]$.

We’ll start by extending $u$ and $f$ with zeros, to give them length $2L$. Let $u^{\prime}=[u[0],u[1],\dots,u[L-1],0,\dots,0]$, and $f^{\prime}$ extended similarly. Let $O^{\prime}=u^{\prime}\ast f^{\prime}$, and $O=O^{\prime}[:N]$. Assume that we have all the values of $dO^{\prime}$ (we only have them for the first half, but we’ll see that it doesn’t matter in the end).

Let’s construct a Toeplitz matrix $H_{u^{\prime}}$ such that $u^{\prime}\ast f^{\prime}=H_{u^{\prime}}f^{\prime}$:

Since we have $u^{\prime}[i]=f^{\prime}[i]=0$ for $i\geq L$, we can actually fill in the zeros of the above matrix as well:

Then, we can use the matrix multiplication chain rule to find that:

where we use $u^{\prime}[-i]$ to mean $u^{\prime}[2L-i]$. Notice that this matrix has the same format as $H_{u^{\prime}}$! Let $u_{*}^{\prime}=[u^{\prime}[0],u^{\prime}[-1],\dots,u^{\prime}[-(2N-1)]]$. Then:

So how do we compute $u_{*}^{\prime}$ efficiently? Naively, we might incur some nasty memory access issues. But a nice property of the DFT saves us!

Let $U[i]$ be the $i$-th element of the DFT of a signal $u$. Note that $U[i]$ is complex. We have:

where here the $*$ represents the complex conjugate. We can use this property to compute $df^{\prime}$ efficiently:

where $FFT^{*}$ denotes taking the complex conjugate of the FFT, and $dy^{\prime}$ denotes $dy$, padded with zeros.
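The conjugate trick can be checked numerically. The sketch below is an illustration under our own assumptions, not the fused kernel: it uses a zero-based causal convolution $y[i]=\sum_{j=0}^{i} f[j]\,u[i-j]$, a naive $O(n^2)$ DFT from the standard library in place of a real FFT, and hypothetical function names. It compares the gradient obtained as the inverse DFT of $\overline{\mathrm{DFT}(u^{\prime})}\cdot\mathrm{DFT}(dy^{\prime})$ against the gradient computed directly from the definition.

```python
import cmath


def dft(x, inverse=False):
    """Naive O(n^2) discrete Fourier transform (stand-in for an FFT)."""
    n = len(x)
    sign = 1 if inverse else -1
    out = []
    for k in range(n):
        s = sum(x[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n) for t in range(n))
        out.append(s / n if inverse else s)
    return out


def grad_f_fft(u, dy):
    """df via conj(DFT(u')) * DFT(dy'), with u and dy zero-padded to length 2L."""
    L = len(u)
    u_pad = list(u) + [0.0] * L
    dy_pad = list(dy) + [0.0] * L
    U = dft(u_pad)
    DY = dft(dy_pad)
    prod = [U[k].conjugate() * DY[k] for k in range(2 * L)]
    df_pad = dft(prod, inverse=True)
    # Only the first L entries correspond to the original filter f.
    return [df_pad[j].real for j in range(L)]


def grad_f_direct(u, dy):
    """df[j] = sum_{i >= j} dy[i] * u[i - j], read off from y[i] = sum_j f[j] u[i-j]."""
    L = len(u)
    return [sum(dy[i] * u[i - j] for i in range(j, L)) for j in range(L)]
```

Because $u^{\prime}$ and $dy^{\prime}$ are padded to length $2L$, the circular correlation computed in the frequency domain never wraps around for the first $L$ outputs, which is why the two functions agree.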

We can use this same trick to compute $du$, except we need to add in the contribution from the original $\mathbf{D}u$ term. We end up with:

C.2 State-Passing Matrices

We show how to derive $\mathbf{M}_{ux}$ for the state update in our state-passing algorithm.

We wish to construct a matrix $\mathbf{M}_{ux}\in\mathbb{R}^{m\times N^{\prime}}$ such that $\mathbf{M}_{ux}u=\sum_{i=1}^{N^{\prime}}\mathbf{A}^{N^{\prime}-i}\mathbf{B}u_{i}$. Note that $\mathbf{A}^{i}\mathbf{B}\in\mathbb{R}^{m\times 1}$ is a column vector. We can simply stack these column vectors to form a matrix: $\mathbf{M}_{ux}=[\mathbf{A}^{N^{\prime}-1}\mathbf{B},\mathbf{A}^{N^{\prime}-2}\mathbf{B},\dots,\mathbf{B}]$.
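As a sanity check on this construction, here is a small sketch in plain Python. It assumes a scalar input per time step and the recurrence $x_i=\mathbf{A}x_{i-1}+\mathbf{B}u_i$ started from a zero state, so that the final state is $\sum_{i=1}^{N^{\prime}}\mathbf{A}^{N^{\prime}-i}\mathbf{B}u_i$ (the sum that appears in the proof in Appendix D.4). The helper names are hypothetical.

```python
def matvec(A, x):
    """Multiply a square matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]


def build_M_ux(A, B, n_chunk):
    """Stack the columns [A^{N'-1} B, A^{N'-2} B, ..., B] into an m x N' matrix."""
    cols = []
    col = list(B)  # A^0 B
    for _ in range(n_chunk):
        cols.append(col)
        col = matvec(A, col)  # next power of A applied to B
    cols.reverse()  # highest power first, B last
    m = len(B)
    return [[cols[j][r] for j in range(n_chunk)] for r in range(m)]


def final_state_recurrent(A, B, u):
    """Run x_i = A x_{i-1} + B u_i from x_0 = 0 and return the last state."""
    x = [0.0] * len(B)
    for u_i in u:
        Ax = matvec(A, x)
        x = [Ax[r] + B[r] * u_i for r in range(len(B))]
    return x
```

Multiplying the matrix returned by `build_M_ux` with the chunk of inputs gives the same vector as unrolling the recurrence, which is what lets the state update be expressed as a single matrix multiplication.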

Appendix D Proofs

We show parameterizations of H3 and attention that solve the associative recall task. We prove Proposition 1 and Proposition 2.

This section formally describes a parameterization of H3 that solves the associative recall task.

Consider a simple language with 4 keys and 4 values. For concreteness, we will use the keys $\{k_{1},k_{2},k_{3},k_{4}\}=L_{K}$ and the values $\{v_{1},v_{2},v_{3},v_{4}\}=L_{V}$, i.e. our language $L=L_{K}\cup L_{V}$. Given a sequence of key-value pairs with one key at the end, we want a model to generate the value associated with the key at the end. Assume that the key at the end appeared in the sequence.

More formally, let $N+1$ be the length of the sequence, $N$ even. The language $\Lambda$ consists of sequences $x\in L^{N+1}$. Each sequence has an associated mapping $f_{x}:L_{K}\rightarrow L_{V}$. For each sequence, the odd indices are randomly sampled from $L_{K}$, for $x_{1},x_{3},\dots,x_{N-1}$. The even indices are defined by $f_{x}$: $x_{2i}=f_{x}(x_{2i-1})$, for $1\leq i\leq N/2$. The last item in the sequence, $x_{N+1}$, is randomly drawn from the keys that have appeared in $x$ already, i.e. $x_{N+1}\in\{x_{1},x_{3},\dots,x_{N-1}\}$. The goal of this language modeling task is to produce $f_{x}(x_{N+1})$ at the end of the sequence.
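To make the definition of $\Lambda$ concrete, the following sketch samples one sequence and its target. It is an illustration only: the token names, the uniform sampling of $f_x$, and the function name are assumptions for this example rather than the exact data generator used in the experiments.

```python
import random

KEYS = ["k1", "k2", "k3", "k4"]
VALUES = ["v1", "v2", "v3", "v4"]


def sample_sequence(n, rng):
    """Sample (x_1, ..., x_{N+1}) from the associative recall language and its target."""
    assert n % 2 == 0, "N must be even"
    # Per-sequence mapping f_x from keys to values (assumed uniform here).
    f_x = {k: rng.choice(VALUES) for k in KEYS}
    seq = []
    for _ in range(n // 2):
        key = rng.choice(KEYS)  # odd positions: a random key
        seq.append(key)
        seq.append(f_x[key])    # even positions: the associated value
    # The final token is a key that already appeared in the sequence.
    seen_keys = sorted(set(seq[0::2]))
    query = rng.choice(seen_keys)
    seq.append(query)
    return seq, f_x[query]
```

For example, `sample_sequence(10, random.Random(0))` returns a list of 11 tokens together with the value the model should produce after the final key.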

D.1.2 H3 Model to Solve $\Lambda$

We describe a toy H3 model that can solve $\Lambda$.

Consider a model consisting of an embedding layer, an H3 model, and an output projection with softmax. Recall that $d$ is the dimension of the H3 model, $m$ is the dimension of its hidden states, and $H$ is the number of heads. Let $d=8$, $m=2$, $H=4$. Let the embedding layer map each key $k_{i}$ to the $e_{i}$ basis vector, and map each value $v_{i}$ to the $e_{4+i}$ basis vector.

Let $\mathbf{B}_{shift}$ and $\mathbf{C}_{shift}$ denote the parameters of the shift SSM, and $\mathbf{A}_{diag}$, $\mathbf{B}_{diag}$, and $\mathbf{C}_{diag}$ denote the parameters of the diagonal SSM (let $\mathbf{D}$ be zero for both). Let $\mathbf{B}_{shift}=\mathbf{B}_{diag}=\mathbf{C}_{diag}=e_{1}$. Let $\mathbf{C}_{shift}=[0\ 1]$. Let $\mathbf{A}_{diag}$ be a diagonal matrix with $1$s along its diagonal for each H3.

The action of a diagonal SSM parameterized by $\mathbf{A}_{diag}$, $\mathbf{B}_{diag}$, and $\mathbf{C}_{diag}$ is to act as a cumulative sum over all its input. The action of the shift SSM parameterized by $\mathbf{B}_{shift}$ and $\mathbf{C}_{shift}$ is to shift its input by one time step.

Recall that the H3 layer maps its input to $Q$, $K$, and $V$ by applying $u\mathbf{W}_{Q}$, $u\mathbf{W}_{K}$, and $u\mathbf{W}_{V}$. Let $\mathbf{W}_{Q}$ and $\mathbf{W}_{K}$ be the following:

Recall that $Q$ and $K$ are split into $H$ heads ($\mathbf{Q}^{(i)},\mathbf{K}^{(i)}$ for $i\in\{1,2,3,4\}$), each of which is sent to an independent H3.

The action of $\mathbf{W}_{Q}$ and $\mathbf{W}_{K}$ is to “assign” each key to a different H3 head, i.e., $\mathbf{Q}^{(i)}_{t}$ is only non-zero when $x_{t}=k_{i}$. Similarly, $\overline{\mathbf{K}}^{(i)}_{t}$ is only non-zero when $x_{t-1}=k_{i}$ (since $\overline{\mathbf{K}}_{t}=\mathbf{K}_{t-1}$ due to the time delay of the shift SSM).

The action of this matrix is to encode the input value (as “binary”), and send it to all H3 heads. E.g., $\mathbf{V}_{t}^{(1)}=\mathbf{V}_{t}^{(2)}=\mathbf{V}_{t}^{(3)}=\mathbf{V}_{t}^{(4)}$ for all $t$, and $\mathbf{V}_{t}^{(i)}=\ \Leftrightarrow x_{t}=v_{1}$, $\mathbf{V}_{t}^{(i)}=\ \Leftrightarrow x_{t}=v_{2}$, etc.

We claim that for $x_{N+1}=k_{i}$, $\mathbf{O}_{N+1}^{(i)}$ will be a multiple of the binary encoding of $f_{x}(k_{i})$, and all the other heads of the output $\mathbf{O}_{N+1}^{(j)}$, $1\leq j\leq 4$, $j\neq i$, will be zero. Let the output projection $\mathbf{W}_{O}$ be such that, with a non-linearity afterwards, it inverts the binary encoding to produce the embedding of the desired output $f_{x}(k_{i})$. We will assume such a projection exists, proof left to the reader.

The model described above solves the associative recall problem for the language $\Lambda$.

Proof sketch. WLOG, let $x_{N+1}=k_{i}$. Then $\mathbf{Q}^{(i)}=[1,1]$, but $\mathbf{Q}^{(j)}=[0,0]$ for $j\neq i$. Thus, $\mathbf{O}^{(j)}=[0,0]$ for $j\neq i$ due to the multiplicative interaction.

Since $\mathbf{Q}^{(i)}=[1,1]$, $\mathbf{O}^{(i)}$ is the output of the diag SSMs in the H3 head corresponding to $k_{i}$ (recall that each head has two independent shift SSMs and two independent diag SSMs). The output of the diag SSMs is the cumulative sum of all the inputs they have seen in the sequence.

For one of the diag SSMs to see a non-zero input, its preceding shift SSM must have a non-zero output. The only times $t$ this can happen in the sequence are when $x_{t-1}=k_{i}$. But then $x_{t}=f_{x}(k_{i})$. Thus, the inputs to the diag SSMs are precisely the binary encoding of $f_{x}(k_{i})$. Then the output $\mathbf{O}^{(i)}$ is a multiple of the binary encoding of $f_{x}(k_{i})$, and $\mathbf{W}_{O}$ decodes this output to the embedding of $f_{x}(k_{i})$. ∎
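The argument above can be traced on a concrete sequence. The sketch below simulates the mechanism in plain Python under simplifying assumptions that are ours: values are one-hot encoded rather than binary encoded, the shift SSM is modeled as a one-step delay, the diagonal SSM as a cumulative sum, and the output projection as an argmax. It is meant to mirror the proof sketch, not the trained model.

```python
KEYS = ["k1", "k2", "k3", "k4"]
VALUES = ["v1", "v2", "v3", "v4"]


def encode_value(token):
    """Assumed value encoding: one-hot over VALUES, zeros for key tokens."""
    return [1.0 if token == v else 0.0 for v in VALUES]


def h3_associative_recall(seq):
    """Predict the value for the final key using one 'head' per key."""
    combined = [0.0] * len(VALUES)
    V = [encode_value(tok) for tok in seq]
    for key in KEYS:
        # Q and K are non-zero only where the token equals this head's key.
        Q = [1.0 if tok == key else 0.0 for tok in seq]
        K_shift = [0.0] + Q[:-1]  # shift SSM: K delayed by one time step
        state = [0.0] * len(VALUES)
        for t in range(len(seq)):
            # Diagonal SSM with A = I: cumulative sum of K_shift * V.
            state = [s + K_shift[t] * v for s, v in zip(state, V[t])]
        # Multiplicative interaction with Q at the final position.
        head_out = [Q[-1] * s for s in state]
        combined = [c + h for c, h in zip(combined, head_out)]
    # Stand-in for the output projection: pick the largest entry.
    best = max(range(len(VALUES)), key=lambda idx: combined[idx])
    return VALUES[best]
```

On the sequence `["k1", "v3", "k2", "v1", "k1", "v3", "k2"]`, only the head assigned to `k2` has a non-zero query at the last position, and its accumulated state holds the encoding of the value that followed `k2` earlier.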

D.2 Attention Expressivity

We provide an informal sketch of a two-layer attention model that can solve the associative recall task, inspired by the construction of . The first layer of the attention model outputs the embedding of the previous token in the sequence, and concatenates it with the current token in the sequence. The second layer compares the current token to the previous token embeddings, and outputs the paired embedding when there is a match—which is exactly the key-value lookup.

In the first layer, let $Q_{i}$ be mapped to the positional embedding of token $x_{i-1}$ (e.g., $p_{i-1}$ if $p_{i}$ denotes the positional embedding of token $x_{i}$), and $K_{i}$ be mapped to the positional embedding of token $x_{i}$.

The attention matrix $A$ is computed as $QK^{T}$, with a causal mask (i.e., $A_{i,j}=0$ if $j>i$).

Then, $\mathrm{softmax}(A)$ approximates the shift matrix (see Section 3).

Let $V_{i}$ be an encoding of token $x_{i}$, constrained to the first half of the hidden dimension.

Then, for output $O=\mathrm{softmax}(QK^{T})V$, the first half of the vector $O_{i}$ is the encoding of token $x_{i-1}$.

In the second layer, assume that you have a skip connection that maps the encoding of the input token $x_{i}$ to the second half of the vector $O_{i}$.

Then, the input to the second layer encodes both $x_{i-1}$ and $x_{i}$.

In the second layer, let $Q_{i}$ extract the encoding of $x_{i}$, and let $K_{i}$ extract the encoding of $x_{i-1}$.

Apply a causal mask on $QK^{T}$. Then, the value of $\mathrm{softmax}(QK^{T})_{i,j}$ is large if $x_{i}=x_{j-1}$, and $i>j-1$.

Let $V_{i}$ extract the encoding of $x_{i}$.

Then, output $O_{i}$ is the sum of values $x_{j}$ such that $x_{j-1}=x_{i}$. But then $O_{i}$ is exactly a lookup of the token that came after $x_{i}$ when it appeared previously in the sequence, which exactly solves associative recall.

We note that the above construction requires the ability for the positional encodings to select the previous token based on the dot product and softmax, and for token comparisons through the dot product and softmax.
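A rough sketch of the second-layer lookup, under strong simplifications that are ours: the first layer is hard-coded as "each position also carries its previous token", the dot product is replaced by an exact-match score scaled by a hypothetical temperature `beta`, and tokens are one-hot encoded over a small vocabulary.

```python
import math


def attention_lookup(tokens, vocab, beta=20.0):
    """Return the token that followed the final token's earlier occurrence."""
    # Layer 1 (hard-coded here): position j also carries token j-1.
    prev = [None] + tokens[:-1]
    i = len(tokens) - 1  # query position: the final token
    # Layer 2 scores: large when the previous token at j matches the query.
    # Restricting j to range(i + 1) plays the role of the causal mask.
    scores = [beta if prev[j] == tokens[i] else 0.0 for j in range(i + 1)]
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]  # numerically stable softmax
    total = sum(weights)
    # Values are one-hot encodings of the tokens; accumulate the weighted sum.
    out = [0.0] * len(vocab)
    for j, w in enumerate(weights):
        out[vocab.index(tokens[j])] += w / total
    best = max(range(len(vocab)), key=lambda idx: out[idx])
    return vocab[best]
```

With a large `beta`, nearly all of the softmax mass falls on positions whose previous token equals the query, so the weighted sum is dominated by the value paired with that key.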

D.3 H3 Complexity

We prove Proposition 1, which states that the H3 layer takes $O(d^{2}N+dN\log N)$ time and $O(dN)$ space for sequence length $N$ and hidden dimension $d$.

We first analyze the time complexity. Consider the matrix multiplies in H3, where the input $u\in\mathbb{R}^{N\times d}$ is multiplied by three weight matrices of size $d\times d$. These take time $O(d^{2}N)$. The output $\mathbf{O}$ is also multiplied with an output projection weight matrix of size $d\times d$, also taking time $O(d^{2}N)$. Therefore the matrix multiplies take time $O(d^{2}N)$.

Now consider the two SSMs in H3. The first SSM involves a convolution of $\mathbf{K}\in\mathbb{R}^{N\times d}$ (in the $N$-dimension) with a kernel of size $N\times d$. This reduces to an FFT, a pointwise multiply, and an inverse FFT (in the $N$-dimension). This takes time $O(dN\log N)$. The second SSM involves $H$ convolutions, with inputs of size $N\times d_{h}\times d_{h}$, along the $N$-dimension. This takes time:

$$O(Hd_{h}^{2}N\log N)=O(dd_{h}N\log N)=O(dN\log N),$$

where we use the fact that $d_{h}=d/H$ and that $d_{h}=O(1)$. Therefore the two SSMs take total time $O(dN\log N)$. As a result, the H3 layer takes time:

$$O(d^{2}N+dN\log N).$$

Now we analyze the space complexity. The matrix multiplies all take space $O(dN)$. The FFTs, pointwise multiplies, and inverse FFTs of the two SSMs take $O(dN)$ space and $O(Hd_{h}^{2}N)=O(dd_{h}N)=O(dN)$ space, respectively. Therefore the overall space complexity is $O(dN)$.
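The "FFT, pointwise multiply, inverse FFT" step that gives the $N\log N$ factor can be sketched for a single channel as follows. This is a plain-Python illustration with a textbook recursive radix-2 FFT, not the BlockFFTConv kernel; the function names and the choice to pad to the next power of two at least $2N$ are assumptions for this example.

```python
import cmath


def fft(x, inverse=False):
    """Recursive radix-2 FFT (unnormalized); len(x) must be a power of two."""
    n = len(x)
    if n == 1:
        return list(x)
    even = fft(x[0::2], inverse)
    odd = fft(x[1::2], inverse)
    sign = 1 if inverse else -1
    out = [0j] * n
    for k in range(n // 2):
        twiddle = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + twiddle
        out[k + n // 2] = even[k] - twiddle
    return out


def causal_conv_fft(f, u):
    """y[i] = sum_{j <= i} f[j] u[i-j] for equal-length f and u, in O(N log N)."""
    n = len(u)
    size = 1
    while size < 2 * n:  # zero-pad so the circular convolution does not wrap
        size *= 2
    F = fft([complex(v) for v in f] + [0j] * (size - n))
    U = fft([complex(v) for v in u] + [0j] * (size - n))
    prod = [a * b for a, b in zip(F, U)]       # pointwise multiply
    y = fft(prod, inverse=True)                # unnormalized inverse FFT
    return [y[i].real / size for i in range(n)]


def causal_conv_direct(f, u):
    """Reference O(N^2) causal convolution."""
    return [sum(f[j] * u[i - j] for j in range(i + 1)) for i in range(len(u))]
```

Running this once per channel for $d$ channels gives the $O(dN\log N)$ term in the bound, while the projections account for the $O(d^{2}N)$ term.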

D.4 State Passing Correctness

We prove Proposition 2. We assume that the BlockFFTConv algorithm is correct, i.e., the output $y=\textsc{BlockFFTConv}(f,u)$ is equal to the output of an SSM with convolution kernel $f$ and input $u$.

$C=1$. WTS $y=[y^{(1)}]$, $\mathbf{M}_{xx}x_{N^{\prime}}^{(0)}+\mathbf{M}_{ux}u^{(1)}=x_{N}$.

In this case, note that $N=N^{\prime}$. Then $y^{(1)}=\mathbf{M}_{xy}x_{N^{\prime}}^{(0)}+\textsc{BlockFFTConv}(f,u_{1})=\textsc{BlockFFTConv}(f,u_{1})$. But $u=u_{1}$, so $y=y^{(1)}=[y^{(1)}]$.

Additionally, by the recursive definition of a state space,

$C>1$. Assume that $[y^{(1)},\dots,y^{(C-1)}]=y[:N^{\prime}(C-1)]$, and $x_{N^{\prime}}^{(C-1)}=x_{(C-1)N^{\prime}}$. WTS that $y^{(C)}=y[N^{\prime}(C-1):N^{\prime}C]$, and $\mathbf{M}_{xx}x_{N^{\prime}}^{(C-1)}+\mathbf{M}_{ux}u^{(C)}=x_{N}$. Let $t$ denote $N^{\prime}(C-1)$.

$$\begin{aligned}
y_{i} &=\mathbf{C}\mathbf{A}^{i-t}\mathbf{B}x_{t}+(f\ast[u_{t},u_{t+1},\dots,u_{t+N^{\prime}-1}])_{i-t}+\mathbf{D}u_{i}\\
&=\mathbf{C}\mathbf{A}^{i-t}\mathbf{B}x_{t}+(f\ast u^{(C)})_{i-t}+\mathbf{D}u_{i}\\
&=\mathbf{C}\mathbf{A}^{i-t}\mathbf{B}x_{t}+\textsc{BlockFFTConv}(f,u^{(C)})_{i-N^{\prime}}\\
&=(\mathbf{M}_{xy}x_{t}+\textsc{BlockFFTConv}(f,u^{(C)}))_{i-N^{\prime}}\\
&=(\mathbf{M}_{xy}x_{N^{\prime}}^{(C-1)}+\textsc{BlockFFTConv}(f,u^{(C)}))_{i-N^{\prime}}\\
&=y^{(C)}_{i-N^{\prime}}.
\end{aligned}$$

Thus, $y^{(C)}=y[N^{\prime}(C-1):N^{\prime}C]$.

$$\begin{aligned}
x_{N} &=\mathbf{A}^{N^{\prime}-1}x_{(C-1)N^{\prime}}+\sum_{i=1}^{N^{\prime}}\mathbf{A}^{N^{\prime}-i}\mathbf{B}u_{i+t}\\
&=\mathbf{A}^{N^{\prime}-1}x_{N^{\prime}}^{(C-1)}+\sum_{i=1}^{N^{\prime}}\mathbf{A}^{N^{\prime}-i}\mathbf{B}u^{(C)}_{i}\\
&=\mathbf{M}_{xx}x_{N^{\prime}}^{(C-1)}+[\mathbf{A}^{N^{\prime}-1}\mathbf{B},\mathbf{A}^{N^{\prime}-2}\mathbf{B},\dots,\mathbf{B}]u^{(C)}\\
&=\mathbf{M}_{xx}x_{N^{\prime}}^{(C-1)}+\mathbf{M}_{ux}u^{(C)}.
\end{aligned}$$

Appendix E Experimental Details

Our synthetic tasks, inspired by , are designed to mimic the in-context learning capability of large language models: the ability to learn from examples in the input sequence, and use information from the input to generate the right answer for the output. For example, the induction head task requires memorizing the token that appears after the special $\vdash$ token in the input sequence, and the associative recall task requires learning the mapping from keys to tokens from the input sequence.

We evaluate synthetics by training two-layer versions of our GPT models, with different modules replacing attention. We train models with inner dimension 32, and MLP dimension 128. For all the synthetics, we use a learning rate of 5e-4 and a weight decay of 0.1. We sample 5000 training examples and 500 test examples from the same distribution, and we train for 200 epochs. Again, we use embedding dropout of 0.1 and residual dropout of 0.0.

E.2 Model Architecture

For our 125M models, we use 12 layers, with hidden dimension 1024, and MLP dimension 4096. For our 355M models, we use 24 layers, with the same hidden dimension and MLP dimension. The 1.3B models have 24 layers, with hidden dimension 2048, and MLP dimension 8192. The 2.7B models have 32 layers, hidden dimension 2560, and MLP dimension 10240. The hybrid models have 12, 16, 16, and 20 heads for the 125M, 355M, 1.3B, and 2.7B models, respectively. The 125M hybrid model has attention layers at layers 1 and 7, the 355M and 1.3B hybrid models have attention layers at layers 1 and 13, and the 2.7B hybrid model has attention layers at layers 10 and 21. For both our hybrid models and our H3 models, we use SSM state size 64. Our hybrid model uses head dimension 1 for H3, while our pure H3 model uses head dimension 8. We run models with mixed-precision training, with bf16 for the MLPs and attention. When training language models, we use fp32 for the FFTConv.

E.3 OpenWebText Training

For the 125M models trained on OpenWebText, we follow the training recipe of the Megatron-LM repo.

We use an effective batch size of 512, and use gradient accumulation to fit into available GPU memory. We use the AdamW optimizer, with learning rate 6e-4 for GPT-2 small and 1.5e-4 for GPT-2 medium, and weight decay of 0.1. All models are trained with the same hyperparameters for 100K steps. We run all implementations with mixed-precision training (PyTorch AMP). We train models with sequence length 1024.

We use the OpenWebText dataset, with the GPT-2 BPE tokenizer. We randomly select 0.5% of the dataset as the validation set, with the rest being used as the training set. This random selection of the validation set is done once, and all models are evaluated on the same validation set.
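One way to obtain such a one-time split is to shuffle document indices with a fixed seed and hold out the first 0.5%. The sketch below illustrates that idea; the function name, the seed value, and the index-based representation of the dataset are assumptions, not details from our training code.

```python
import random


def split_indices(num_docs, val_fraction=0.005, seed=0):
    """Return (train_indices, val_indices) using a fixed-seed shuffle."""
    rng = random.Random(seed)  # fixed seed so every model sees the same split
    indices = list(range(num_docs))
    rng.shuffle(indices)
    n_val = int(num_docs * val_fraction)
    val = sorted(indices[:n_val])
    train = sorted(indices[n_val:])
    return train, val
```

Calling `split_indices` again with the same arguments reproduces the same validation set, which is what makes results comparable across models.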

E.4 The Pile Training

For the 125M and 355M models trained on the Pile, we follow the training recipe of GPT-3. We use batch size 256, with sequence length 2048. We train our models for 800K steps. We use residual dropout 0.0 and embedding dropout 0.1. We use the AdamW optimizer, with learning rate 6e-4 for the 125M model and 3e-4 for the 355M model, and a weight decay of 0.1. We use a cosine schedule with 8000 steps for linear warmup, and decay the learning rate to 10% by 300B tokens, then continue training at 10% learning rate for another 100B tokens. We suspect that there exist better hyperparameters for H3 language models, but we did not have the resources to tune them.
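The schedule described above can be written as a function of the step count. The sketch below is one reading of it, with assumptions flagged: the cosine is taken to start at the peak learning rate right after warmup, and the 300B-token horizon is converted to steps using batch size 256 and sequence length 2048. The function and parameter names are hypothetical.

```python
import math


def lr_at_step(step, peak_lr=6e-4, warmup_steps=8000, batch_size=256,
               seq_len=2048, decay_tokens=300e9, final_frac=0.1):
    """Linear warmup, cosine decay to final_frac * peak_lr, then constant."""
    tokens_per_step = batch_size * seq_len
    decay_steps = int(decay_tokens / tokens_per_step)  # step at which decay ends
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    if step >= decay_steps:
        return peak_lr * final_frac
    progress = (step - warmup_steps) / (decay_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    return peak_lr * (final_frac + (1.0 - final_frac) * cosine)
```

With these defaults each step covers 256 x 2048 = 524,288 tokens, so the decay phase ends after roughly 572K steps and the remaining steps run at 10% of the peak learning rate.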

For the 1.3B models, we double the batch size to 512 (with sequence length 2048), again following the training recipe of GPT-3. The number of training steps is halved so that we train on the same number of tokens.

For the Pile dataset, we again use the GPT-2 BPE tokenizer, similar to GPT-3 and GPT-Neo.

E.5 SuperGLUE

We follow the prompts used in the GPT-3 paper . For rank classification on the binary classification tasks, we use yes/no for WSC, WIC, MultiRC, and BoolQ, and we use true/false for RTE. For CB, we use true/false/neither as the three choices. For COPA and ReCoRD, we use the continuations provided by the task.

E.6 Hardware

All models were trained on either a single 16xA100-40GB node or a cluster of 8xA100-80GB nodes.

Appendix F Additional Experiments

We evaluate the accuracy of H3 on LRA. We compare accuracy to S4D , since H3 uses an S4D kernel as a component in its layer. We use the same hyperparameters as S4D, and make the layer bidirectional by making two copies and running them in opposite directions.

Table 9 shows that H3 performs well on the LRA benchmark, even though it was designed for autoregressive language modeling. H3 outperforms S4D on two of the LRA tasks, and comes within 1 point on the others.

F.2 WikiText103

We train 125M-sized models on WikiText103 and compare their test PPL to transformers, as well as other variants of efficient or long-range attention. We use the same hyperparameters and setup as training on OpenWebText. We also provide results from Transformer-XL and Perceiver AR for context, though the results may not be directly comparable due to differences in model size and tokenizer.

Table 10 shows that the Hybrid H3 model is competitive with Transformers of the same size, as well as larger models such as the 358M Perceiver AR and 285M Transformer-XL. The hybrid H3 model also significantly outperforms transformers with performer, reformer, and linear attention.

We note that the Transformer-XL and Perceiver AR PPL numbers are from the original papers, and may not be directly comparable to our results. In particular, they use a tokenizer with a different vocab size, which means that the PPLs are not directly comparable. In addition, the larger vocab size necessitates a change in the model (adaptive softmax) that may affect performance. The top five numbers in Table 10 are trained with the same setup and are directly comparable to each other.

F.3 PG-19

We evaluate models trained on the PG-19 dataset , a natural language dataset comprised of texts from books. We compare the performance of Hybrid H3 against Transformers and linear attention. We use the same setup as evaluating on OpenWebText.

Table 11 shows that Hybrid H3 outperforms transformers and linear attention.

F.4 Length Extrapolation

One property of SSMs is that they can naturally extrapolate to sequence lengths longer than those seen during training. We use the synthetic associative recall task to demonstrate that H3 maintains this capability. To do so, we train a two-layer H3 model on sequences of length 20 drawn from the associative recall synthetic language. Then, we evaluate accuracy of the last token prediction on sequences of length 20 and 40.

Table 12 shows that H3 maintains accuracy on sequences of length 40, which is twice the length of the training sequences.

F.5 Scaling in Number of Tokens

We evaluate how well a Hybrid H3 model scales with the number of tokens seen during training, compared to a Transformer. For these experiments, we train a 125M Hybrid H3 model and a 125M Transformer on the Pile for 5B, 10B, and 15B tokens. We use a learning rate of 6e-4, adjusting the warmup to be 1% of the total training time, and adjusting the decay rate to decay the learning rate to 6e-5 by the end of training.

Table 13 shows the results. Both the Hybrid H3 model and Transformer model improve as the number of training tokens increases.

F.6 H3 Language Model

We report the results of a pure H3 language model on NLP evaluations. We train a 125M model on the Pile for 400B tokens. Tables 14 and 15 show zero-shot and few-shot performance on SuperGLUE, respectively.

F.7 Generation Performance

We report results on SuperGLUE for generation. Instead of using rank classification, we let the model generate a response, and we search for the gold label (i.e., “yes” or “no” for the yes/no questions) in the output. Tables 16 and 17 report the results. The trends for few-shot learning match the logit results, but the hybrid and H3 models perform very poorly in zero-shot performance on some tasks. In these cases, the models tend to generate long text responses that are not relevant to the answer. The few-shot learning examples help the models generate answers in a parsable format.

F.8 Non-Text Sequence Modeling

We show that H3 outperforms Transformers on two non-text sequence modeling tasks: raw speech classification and seizure classification over raw EEG signals. H3 sets state-of-the-art performance on seizure classification and matches S4 on speech classification—which suggests that H3, or one of its hybrids, may be a strong candidate for a multimodal foundation model. Appendix E gives experimental details, and Appendix F gives an additional experiment on brain fMRI data.

Seizures, which are characterized by uncontrolled brain activity, are some of the most common neurological disorders . Chronic seizures, or epilepsy, cause a range of psychiatric and psycho-social disorders and impact the lives of roughly one percent of the global population . The first step to treating epilepsy is manual analysis of scalp EEG by board-certified neurologists. However, the vast amount of EEG data produced by each patient (which can be up to days of data) makes manual EEG analysis a costly and time-consuming process.

To mitigate the costs associated with EEG monitoring, recent deep learning techniques have begun to show promise in flagging abnormal EEG segments for potential seizure events . A challenge with classifying EEG data is the trade-off between increasing input sequence length, where more context has been shown to improve seizure classification performance , and the increased difficulty of training deep learning models on long sequences (e.g., an EEG signal sampled at 200 Hz produces 12,000 time steps per minute). As a result, many techniques involve domain-specialized models and pre-processing steps, such as FFT transforms and graphical representations .

We use the largest publicly available EEG corpus, TUSZ v1.5.2 , which includes 5,612 EEG signals from 636 patients, with 3,050 annotated seizures. Signals are segmented into 60-second clips, and split into train/val/test by patient. The train set contains 39765 clips, the val set contains 4351 clips, and the test set contains 10001 clips.

We evaluate binary seizure classification of 60-sec EEG clips, sampled at 200 Hz, with 19 electrodes: $x\in\mathbb{R}^{12{,}000\times 19}$ and $y\in\{0,1\}$ on the TUSZ v1.5.2 corpus. Transformers cannot process the long sequence length of EEG signals without running out of GPU memory, whereas H3 can, and it sets state-of-the-art performance.

The SC10 speech commands task contains raw audio signals one second in length, sampled at 16kHz. Similarly to EEG signals, Transformers cannot process the long sequence length. Table 19 shows that H3 comes within half a point of S4, the state-of-the-art method.

Functional Magnetic Resonance Imaging (fMRI) data are typically represented in four dimensions, indicating the measured blood-oxygen-level-dependent (BOLD) signal in temporal sequences S={V1,...,Vt}S=\{V_{1},...,V_{t}\} of 3-dimensional volumes VRx×y×zV\in\mathbb{R}^{x\times y\times z}, each indicating the BOLD signal for all spatial locations of the brain (as defined by three spatial dimensions xx, yy, and zz). A key challenge for the analysis of fMRI data is the high dimensionality and low sample size of its datasets, which typically contain many hundred thousand dimensions (i.e., voxels) for each of several hundred volumes VV in each of tens to hundreds of sequences SS. In this setting, where the number of features exceed the number of samples, standard machine learning approaches are prone to overfitting.

In spite of the low sample size of individual datasets, neuroimaging research can be considered as recently entering a big data era because researchers more frequently share their collected datasets publicly . The availability of these data open up the opportunity for pre-training in neuroimaging at scale, as recently demonstrated by , enabling models to utilize the knowledge that can be learned from public functional neuroimaging data for the analysis of individual datasets. Specifically, evaluate the performance of several self-supervised learning frameworks for functional neuroimaging data by first pre-training models on a broad fMRI dataset spanning 11,98011,980 fMRI runs from 1,7261,726 individuals across 3434 datasets and subsequently adapting the pre-trained models to two downstream mental state decoding datasets (namely, the HCP and MDTB datasets). In mental state decoding, predictive models are tasked with identifying (i.e., decoding) some set of mental states (e.g., answering questions about a prose story or math problem) from measured brain activity. The authors find that a GPT-based model, pre-trained in a causal learning framework, performs best in decoding the 2020 (HCP) and 2626 (MDTB) mental states of the two downstream datasets.

To evaluate the performance of H3 on fMRI data, we replicate this analysis, using the up- and downstream fMRI datasets that were published by , treating H3 as a drop-in replacement for the GPT model. To alleviate the high dimensionality challenge of fMRI data, and due to the generally high spatial correlation of brain activity, the original authors have reduced the volumetric time series $S$ to a set $\Theta=\{\theta_{1},...,\theta_{n}\}$ of $n=1,024$ functionally-independent brain networks $\theta$ (as defined by the dictionaries of functional modes (DiFuMo) brain atlas), each describing the BOLD signal for some subset of voxels $v_{x,y,z}\in V$, such that the resulting sequences $X\in\mathbb{R}^{t\times n}$ describe the activity pattern of each of the $n$ brain networks over $t$ time points.

In line with , we first pre-train models $f(\cdot)$ to predict the distribution of brain activity for the next time point $j$ of an input sequence $X$, using a mean absolute error ($L_{rec}$) training objective given the model's output $\hat{X}\in\mathbb{R}^{t\times n}$: $L_{rec}=\frac{1}{n}\sum_{i=1}^{n}|X_{j,i}-\hat{X}_{j,i}|$; $\hat{X}_{t,n}=b_{n}+\sum_{e}f(E^{X})_{t,e}w_{e,n}$; $E^{X}_{t,e}=E^{TR}+E^{pos}+b_{e}+\sum_{n}X_{t,n}w_{n,e}$. Here, $E^{TR}\in\mathbb{R}^{e}$ and $E^{pos}\in\mathbb{R}^{e}$ represent learnable embeddings for each possible time point and position of an input sequence (for details, see ). As the sampling frequency of fMRI is variable, the same position of an input sequence can correspond to different time points. Note that $f(\cdot)$ processes the input in a lower-dimensional embedding representation $E^{X}\in\mathbb{R}^{t\times e}$ (with $e=768$ dimensions).
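The three formulas above can be traced in a few lines of NumPy. This is only a minimal sketch under stated assumptions: the sequence model $f(\cdot)$ is replaced by a hypothetical identity placeholder, all weights are random rather than learned, the output projection is read as a sum over the embedding dimension $e$, and the function names (`embed`, `unembed`, `l_rec`) are illustrative and do not come from the original work.

```python
import numpy as np


def embed(X, W_in, b_e, E_tr, E_pos):
    """E^X[t, e] = E^TR + E^pos + b_e + sum_n X[t, n] * w[n, e]."""
    return E_tr + E_pos + b_e + X @ W_in


def unembed(H, W_out, b_n):
    """X_hat[t, n] = b_n + sum_e H[t, e] * w[e, n]."""
    return b_n + H @ W_out


def l_rec(X, X_hat, j):
    """Mean absolute error over the n networks at time point j."""
    return float(np.mean(np.abs(X[j] - X_hat[j])))


def f(E):
    """Hypothetical stand-in for the sequence model (GPT or H3): identity."""
    return E


rng = np.random.default_rng(0)
t, n, e = 10, 1024, 768          # sequence length is arbitrary; n and e follow the text

X = rng.standard_normal((t, n))  # placeholder network activity, not real fMRI data
W_in = 0.01 * rng.standard_normal((n, e))
W_out = 0.01 * rng.standard_normal((e, n))
b_e, b_n = np.zeros(e), np.zeros(n)
E_tr = 0.01 * rng.standard_normal((t, e))   # one looked-up time embedding per row
E_pos = 0.01 * rng.standard_normal((t, e))  # one looked-up position embedding per row

E_X = embed(X, W_in, b_e, E_tr, E_pos)      # shape (t, e)
X_hat = unembed(f(E_X), W_out, b_n)         # shape (t, n)
loss = l_rec(X, X_hat, j=t - 1)
```

How the target time point $j$ is aligned with the model's output positions is not specified in this section, so the sketch simply applies the loss formula as written.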

We evaluate two model architectures for $f(\cdot)$, namely, the GPT variant used in , with 4 hidden layers and 12 attention heads, and a corresponding H3 variant with 4 hidden layers (with $H=64$ and $m=1$; see section 3). For both models, the sequence of hidden-state outputs of the last model layer is used to determine $\hat{X}$.

Just as , we randomly divide the upstream data into distinct training and validation datasets by randomly designating 5% of the fMRI runs of each of the 34 upstream datasets as validation data (at a minimum of 2 runs per dataset) and using the rest of the runs for training. During upstream learning, we then randomly sample sequences of 10 to 100 time points from the fMRI runs and train models with the Adam optimizer (with $\beta_{1}=0.9$, $\beta_{2}=0.999$, and $\epsilon=10^{-8}$) for 5,000 steps at a mini-batch size of 512, and a learning rate of $5\times10^{-4}$. We apply a linear learning rate decay schedule (with a warm-up phase of 1% of the total number of training steps), gradient norm clipping at 1.0, and $L2$-regularisation (weighted by 0.1). We also apply dropout at a rate of 0.1 for the GPT-based model (based on ) and evaluate three dropout rates for H3: 0.1, 0.2, and 0.3.
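The data handling and schedule described above can be sketched with the Python standard library alone. This is a hedged illustration, not the original training code: the helper names are hypothetical, the rounding rule for the 5% split and the exact shape of the warm-up ramp are assumptions, and runs are assumed to contain at least 10 time points.

```python
import random


def split_runs(runs, val_frac=0.05, min_val=2, seed=0):
    """Hold out val_frac of one dataset's runs (at least min_val) for validation."""
    rng = random.Random(seed)
    shuffled = list(runs)
    rng.shuffle(shuffled)
    n_val = max(min_val, round(val_frac * len(shuffled)))  # rounding is an assumption
    return shuffled[n_val:], shuffled[:n_val]              # (train, validation)


def linear_warmup_decay(step, total_steps=5000, peak_lr=5e-4, warmup_frac=0.01):
    """Linear warm-up to peak_lr, then linear decay to zero at total_steps."""
    warmup = max(1, int(round(warmup_frac * total_steps)))
    if step < warmup:
        return peak_lr * (step + 1) / warmup
    return peak_lr * max(0.0, (total_steps - step) / (total_steps - warmup))


def sample_window(run_length, rng, min_len=10, max_len=100):
    """Pick a random contiguous window of min_len..max_len time points."""
    length = rng.randint(min_len, min(max_len, run_length))
    start = rng.randint(0, run_length - length)
    return start, length


train_runs, val_runs = split_runs(range(100))    # placeholder run identifiers
start, length = sample_window(300, random.Random(0))
lr_at_step_1000 = linear_warmup_decay(1000)
```

With 5,000 total steps, the 1% warm-up phase covers the first 50 steps, after which the rate falls linearly from $5\times10^{-4}$ to zero.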

We find that the H3 variant trained with 0.2 dropout performs on par with the GPT model, in terms of mean absolute error (Fig. 3), and therefore continue all further analyses with this model variant. We also find that both models exhibit almost identical $L_{rec}$ error distributions throughout the brain, with relatively higher errors in the posterior parietal, occipital, and cingulate cortices as well as parts of the limbic system (Fig. 4).

To adapt the pre-trained models to mental state decoding, we add a learnable classification embedding $E^{cls}\in\mathbb{R}^{n}$ to the end of input sequences $X$ and forward the model's prediction $f(E^{X})$ to a decoding head $p(\cdot)$, composed of a dense hidden layer with $e$ model units (one for each embedding dimension, with $\tanh$ activation) as well as a softmax output layer with one model unit $i$ for each considered mental state in the data. Accordingly, we adapt models by optimizing a standard cross entropy loss objective: $L_{cls}=-\sum_{i}y_{i}\log p(f(E^{X}))_{i}$, where $y_{i}$ indicates a binary variable that is 1 if $i$ is the correct mental state and 0 otherwise.
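As a rough illustration of the decoding head and loss just described, the following NumPy sketch applies a tanh hidden layer with $e$ units, a softmax output layer, and the cross entropy for a one-hot label. It rests on assumptions: the head is fed a single $e$-dimensional vector `h`, taken here to be the model output at the position of the classification embedding, weights are random placeholders, and the function names are hypothetical.

```python
import numpy as np


def decoding_head(h, W_hid, b_hid, W_out, b_out):
    """Dense tanh layer with e units, then softmax over the mental states."""
    z = np.tanh(h @ W_hid + b_hid)
    logits = z @ W_out + b_out
    logits = logits - logits.max()   # shift for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()


def l_cls(probs, label):
    """Cross entropy with one-hot y: only the correct state's term survives."""
    return float(-np.log(probs[label]))


rng = np.random.default_rng(0)
e, n_states = 768, 20                # 20 mental states, as for HCP in the text

h = rng.standard_normal(e)           # placeholder for the model output f(E^X)
W_hid = 0.01 * rng.standard_normal((e, e))
b_hid = np.zeros(e)
W_out = 0.01 * rng.standard_normal((e, n_states))
b_out = np.zeros(n_states)

probs = decoding_head(h, W_hid, b_hid, W_out, b_out)
loss = l_cls(probs, label=3)         # label 3 is an arbitrary example
```

Because $y$ is one-hot, the sum in $L_{cls}$ reduces to the negative log probability of the correct state, which is what `l_cls` computes.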

During downstream adaptation, we begin training with the respective pre-trained model parameters and then allow all parameters to change freely. Similar to , we randomly split each of the two downstream datasets into distinct training and test datasets, each comprising 40 (HCP) or 10 (MDTB) distinct individuals. We adapt models for 750 steps at a mini-batch size of 256 and a learning rate of $5\times10^{-5}$ (otherwise using the same learning parameters as for upstream training). Importantly, we repeat each downstream training run 20 times using different random seeds, leading to different random splits of the data and variability in other non-deterministic training factors (such as random initialization and data shuffling).
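The repeated, seed-dependent splitting by individual can be sketched as follows. This is an assumption-laden illustration: `split_individuals` is a hypothetical helper, the pool of 100 subject identifiers is a placeholder, and the pool is assumed to hold at least twice as many individuals as each split needs.

```python
import random


def split_individuals(subject_ids, n_per_split, seed):
    """Draw disjoint training and test groups of n_per_split individuals each."""
    ids = sorted(set(subject_ids))
    if len(ids) < 2 * n_per_split:
        raise ValueError("not enough individuals for two disjoint groups")
    rng = random.Random(seed)        # the seed determines the split
    rng.shuffle(ids)
    return ids[:n_per_split], ids[n_per_split:2 * n_per_split]


# 20 repetitions with different seeds, 40 individuals per group (HCP setting).
subjects = range(100)                # placeholder subject identifiers
splits = [split_individuals(subjects, 40, seed) for seed in range(20)]
```

Splitting at the level of individuals rather than runs keeps every person's data entirely on one side of the split, so test performance reflects generalization to unseen individuals.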

As with the upstream data, the H3 and GPT-based models generally perform on par in mental state decoding on the two left-out test datasets (Table 20).