Diagonal State Spaces are as Effective as Structured State Spaces

Ankit Gupta, Albert Gu, Jonathan Berant

Introduction

The Transformer architecture [VSP+17] has been successful across many areas of machine learning. Transformers pre-trained on large amounts of unlabelled text via a denoising objective have become the standard in natural language processing, exhibiting impressive amounts of linguistic and world knowledge [BMR+20, BMH+21]. This recipe has also led to remarkable developments in the areas of vision [RPG+21, RKH+21] and speech [DXX18, BZMA20].

The contextualizing component of the Transformer is the multi-head attention layer which, for inputs of length $L$, has an expensive $\Omega(L^{2})$ complexity. This becomes prohibitive on tasks where the model is required to capture long-range interactions of various parts of a long input. To alleviate this issue, several Transformer variants have been proposed with reduced compute and memory requirements [QML+20, KKL20, RSVG21, BPC20, GB20, KVPF20, WLK+20, CLD+21, ZS21, GDG+21] (cf. [TDBM20] for a survey). Despite this effort, all these models have reported inadequate performance on benchmarks created to formally evaluate and quantify a model’s ability to perform long-range reasoning (such as Long Range Arena [TDA+21] and SCROLLS [SSI+22]).

In a recent breakthrough result, Gu et al. [GGR22] proposed S4, a sequence-to-sequence model that uses linear state spaces for contextualization instead of attention. It has shown remarkable performance on tasks requiring long-range reasoning in domains such as text, images and audio. For instance, on Long Range Arena it advances the state-of-the-art by 19 accuracy points over the best performing Transformer variant. Its remarkable abilities are not limited to text and images and carry over to tasks such as time-series forecasting, speech recognition and audio generation [GGDR22].

Despite S4’s achievements, its design is complex and is centered around the HiPPO theory, which is a mathematical framework for long-range modeling [VKE19, GDE+20, GJG+21]. [GGR22] showed that state space models with various alternative initializations perform poorly in comparison to initializing the state-space parameters with a particular HiPPO matrix. In order to leverage this matrix, they parameterize the learned state spaces using a Diagonal Plus Low Rank (DPLR) structure and, as a result, need to employ several reduction steps and linear algebraic techniques to be able to compute the state space output efficiently, making S4 difficult to understand, implement and analyze.

In this work, we show that it is possible to match S4’s performance while using a much simpler, fully-diagonal parameterization of state spaces. While we confirm that random diagonal state spaces are less effective, we observe that there do in fact exist effective diagonal state matrices: simply removing the low-rank component of the DPLR HiPPO matrix still preserves its performance. Leveraging this idea, our proposed Diagonal State Space (DSS) model enforces state matrices to be diagonal, making it significantly simpler to formulate, implement and analyze, while being provably as expressive as general state spaces. In contrast to S4, DSS does not assume any specialized background beyond basic linear algebra and can be implemented in just a few lines of code. Our implementation fits in a single page and is provided in §A.5 (Figure 6).

We evaluate the performance of DSS on Long Range Arena (LRA), which is a suite of sequence-level classification tasks with diverse input lengths (1K-16K) requiring similarity, structural, and visual-spatial reasoning over a wide range of modalities such as text, natural/synthetic images, and mathematical expressions. Despite its simplicity, DSS delivers an average accuracy of 81.88 across the 6 tasks of LRA, comparable to the state-of-the-art performance of S4 (80.21). In addition, DSS maintains a comfortable 20 point lead over the best Transformer variant (81.88 vs 61.41).

In the audio domain, we evaluate the performance of DSS on raw speech classification. On the Speech Commands dataset [War18], which consists of raw audio samples of length 16K, we again found the performance of DSS to be comparable to that of S4 (98.2 vs 98.1).

To summarize, our results demonstrate that DSS is a simple and effective method for modeling long-range interactions in modalities such as text, images and audio. We believe that the effectiveness, efficiency and transparency of DSS can significantly contribute to the adoption of state space models over their attention-based peers. Our code is available at https://github.com/ag1988/dss.

Background

We start by reviewing the basics of time-invariant linear state spaces.

Discretization

Assuming $x_{-1}=0$ for simplicity, this recurrence can be explicitly unrolled as

where the last equality follows by substituting the values of $\overline{A}$, $\overline{B}$, $\overline{C}$ from Equation 2. Hence,

Computing $y$ from $u$ and $\overline{K}$ is easy.

$y_{k}$ is the coefficient of $z^{k}$ in the polynomial $\overline{K}(z)\cdot u(z)$, i.e. all $y_{k}$’s can be computed simultaneously by multiplying two degree $L-1$ polynomials. It is well-known that this can be done in $O(L\log(L))$ time via Fast Fourier Transform (FFT) [CLRS09]. We denote this fast computation of Equation 5 via the discrete convolution as $y = \overline{K} * u$.
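As a minimal sketch of this computation, the snippet below multiplies the two polynomials with a textbook radix-2 FFT and checks the result against the direct sum $y_k = \sum_{j=0}^{k}\overline{K}_j \cdot u_{k-j}$. The function names `fft`, `causal_conv_fft` and `causal_conv_direct` are illustrative and are not taken from the released code; in practice a library FFT would be used instead of the hand-written one.

```python
import cmath


def fft(a, invert=False):
    """Recursive radix-2 FFT; len(a) must be a power of two."""
    n = len(a)
    if n == 1:
        return list(a)
    even = fft(a[0::2], invert)
    odd = fft(a[1::2], invert)
    sign = 1 if invert else -1
    out = [0j] * n
    for k in range(n // 2):
        twiddle = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + twiddle
        out[k + n // 2] = even[k] - twiddle
    return out


def causal_conv_fft(kernel, u):
    """y_k = sum_{j<=k} kernel[j] * u[k-j] for k < L, in O(L log L) time."""
    L = len(u)
    size = 1
    while size < 2 * L:  # zero-pad so the circular product equals the linear one
        size *= 2
    fk = fft(list(kernel) + [0.0] * (size - len(kernel)))
    fu = fft(list(u) + [0.0] * (size - L))
    prod = fft([a * b for a, b in zip(fk, fu)], invert=True)
    # The inverse transform above is unnormalised, so divide by the padded size.
    return [(v / size).real for v in prod[:L]]


def causal_conv_direct(kernel, u):
    """Reference O(L^2) evaluation of the same sum."""
    return [sum(kernel[j] * u[k - j] for j in range(k + 1)) for k in range(len(u))]


kernel = [1.0, 0.5, 0.25, 0.125]
u = [1.0, 2.0, 3.0, 4.0]
print(causal_conv_direct(kernel, u))
print(causal_conv_fft(kernel, u))
```

Both calls should agree up to floating-point error; only the first $L$ coefficients of the product polynomial are kept because the output has the same length as the input.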

Method

Having stated the necessary background, we now turn to the main contribution of our work.

The proof of Proposition 1 is elementary and is provided in §A.1. In the above equations, the last equality follows by using Equation 4 to explicitly compute the expression for the kernel of the corresponding diagonal state space. Hence, for any given state space with a well-behaved state matrix there exists a diagonal state space computing the same kernel. (The norms of the parameters of the resulting diagonal state space can, however, be much larger than those of the original state space; for example, this occurs for the HiPPO matrix [GGR22].) More importantly, the expressions for the kernels of the said diagonal state spaces no longer involve matrix powers but only a structured matrix-vector product.

DSSexp provides a remarkably simple computation of state space kernels but restricts the space of the learned $\Lambda$ (the real part must be negative). It is not clear if such a restriction could be detrimental for some tasks, and we now present an alternate method that provides the simplicity of Proposition 1(a) while being provably as expressive as general state spaces.

DSSsoftmax

3.2 Diagonal State Space (DSS) Layer

We are now ready to describe the DSS layer. We retain the skeletal structure of S4 and simply replace the parameterization and computation of the SSM kernel by one of the methods described in §3.1.

The DSS layer can be implemented in just a few lines of code and our PyTorch implementation of DSSsoftmax layer is provided in §A.5 (Figure 6). The implementation of DSSexp layer is even simpler and is omitted.

For batch size $B$, sequence length $L$ and hidden size $H$, the DSS layer requires $O(NHL)$ time and space to compute the kernels, $O(BHL\log(L))$ time for the discrete convolution and $O(BH^{2}L)$ time for the output projection. For small batch size $B$, the time taken to compute the kernels becomes important whereas for large batches more compute is spent on the convolution and the linear projection. The kernel part of the DSS layer has $2N+H+2HN$ real-valued parameters.
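The parameter count can be checked with a few lines of arithmetic. The sketch below assumes one natural reading of the formula: a complex $\Lambda$ of size $N$ ($2N$ reals), one real step-size parameter per hidden channel ($H$ reals), and a complex $H \times N$ weight matrix ($2HN$ reals). The function name and the hidden size in the example are illustrative choices, not values from our experiments.

```python
def dss_kernel_param_count(N: int, H: int) -> int:
    """Real-valued parameters in the kernel part of one DSS layer: 2N + H + 2HN."""
    lambda_params = 2 * N      # assumed: real and imaginary parts of Lambda
    step_params = H            # assumed: one real step-size parameter per channel
    weight_params = 2 * H * N  # assumed: complex H x N weights
    return lambda_params + step_params + weight_params


# N = 64 is the state size used in the paper; H = 128 is an arbitrary example.
print(dss_kernel_param_count(N=64, H=128))
```

The $2HN$ term dominates for any realistic $H$, so the kernel parameter count grows essentially linearly in both the state size and the hidden size.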

3.3 Initialization of DSS layer

The performance of state space models is known to be highly sensitive to initialization [GGR22]. In line with past work, we found that carefully initializing the parameters of the DSS layer is crucial to obtain state-of-the-art performance (§4).

Henceforth, we refer to the above initialization of $\Lambda$ as Skew-Hippo initialization. In all our experiments, we used the above initialization with $N=64$. The initial learning rate of all DSS parameters was $10^{-3}$ and weight decay was not applied to them. Exceptions to these settings are noted in §A.3.

3.4 States of DSS via the Recurrent View

As stated in Proposition 1(a), DSSexp computes $\overline{K}_{\Delta,L}(\Lambda,\ (1)_{1\leq i\leq N},\ \widetilde{w})$. For this state space and sample time $\Delta$, we use Equation 2 to obtain its discretization
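As a rough numerical illustration of the recurrent view, the sketch below runs a diagonal recurrence coordinate by coordinate and compares its output with the convolution of the input against the corresponding kernel. It assumes the discretization $\overline{A}=e^{\Lambda\Delta}$, $\overline{B}=(\overline{A}-I)\Lambda^{-1}B$ with $B$ the all-ones vector, the recurrence $x_k=\overline{A}x_{k-1}+\overline{B}u_k$, $y_k=\widetilde{w}\cdot x_k$, and $x_{-1}=0$. The function names and the numerical values are illustrative only.

```python
import cmath


def dss_exp_kernel(lam, w, delta, L):
    """K_k = sum_i w_i * (e^{lam_i*delta} - 1) / lam_i * e^{lam_i*k*delta}."""
    return [
        sum(wi * (cmath.exp(li * delta) - 1) / li * cmath.exp(li * k * delta)
            for li, wi in zip(lam, w))
        for k in range(L)
    ]


def dss_exp_recurrent(lam, w, delta, u):
    """Run x_k = A_bar * x_{k-1} + B_bar * u_k, y_k = w . x_k with diagonal A_bar."""
    a_bar = [cmath.exp(li * delta) for li in lam]          # assumed: e^{Lambda*delta}
    b_bar = [(ai - 1) / li for ai, li in zip(a_bar, lam)]  # assumed: (A_bar - I) Lambda^{-1} 1
    x = [0j] * len(lam)                                    # x_{-1} = 0
    ys = []
    for uk in u:
        # Each state coordinate is updated independently of the others.
        x = [ai * xi + bi * uk for ai, xi, bi in zip(a_bar, x, b_bar)]
        ys.append(sum(wi * xi for wi, xi in zip(w, x)))
    return ys


lam = [-0.5 + 1j, -0.1 - 2j]   # illustrative eigenvalues with negative real part
w = [1 + 0.5j, -0.3 + 0.2j]    # illustrative weights
u = [1.0, -2.0, 0.5, 3.0]
K = dss_exp_kernel(lam, w, 0.1, len(u))
y_conv = [sum(K[j] * u[k - j] for j in range(k + 1)) for k in range(len(u))]
y_rec = dss_exp_recurrent(lam, w, 0.1, u)
print(max(abs(a - b) for a, b in zip(y_conv, y_rec)))
```

The values are complex in this sketch. Because the state matrix is diagonal, the recurrence costs $O(N)$ per step and no matrix powers are formed.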

DSSsoftmax

As stated in Proposition 1(b), DSSsoftmax computes $\overline{K}_{\Delta,L}(\Lambda,\ ((e^{L\lambda_{i}\Delta}-1)^{-1})_{1\leq i\leq N},\ w)$. For this state space and sample time $\Delta$, we obtain the discretization

For the $i$’th coordinate we can independently compute

Let us drop the coordinate index $i$ for clarity to obtain

Experiments

We evaluate the performance of DSS on sequence-level classification tasks over text, images and audio. Overall, we find its performance is comparable to S4.

LRA [TDA+21] is a standard benchmark for assessing the ability of models to process long sequences. LRA contains 6 tasks with diverse input lengths (1K-16K), encompassing modalities such as text and images. Several Transformer variants have been benchmarked on LRA but all have underperformed due to factors such as their high compute-memory requirements, implicit locality bias and inability to capture long-range dependencies.

Table 1 compares DSS against S4, the Transformer variants reported in [TDA+21], as well as follow-up work. State space models (S4, DSS) shown in Table 1 are left-to-right unidirectional whereas other models could be bidirectional.

Despite its simplicity, DSS delivers state-of-the-art performance on LRA. Its performance is comparable to that of S4, with a modest improvement in test accuracy averaged across the 6 tasks (81.88 vs 80.21). (The large gap between S4 and DSS on Text is due to the use of a larger learning rate for $\Delta_{\log}$ in DSS; for our S4 runs, we decided to use the official hyperparameters as provided by [GGR22].) In addition, DSS maintains a comfortable 20 point lead over the best performing Transformer variant (81.88 vs 61.41).

Raw Speech Classification

Audio is typically digitized using a high sampling rate resulting in very long sequences. This provides an interesting domain for investigating the abilities of long-range models. We evaluate the performance of DSS on the Speech Commands (SC) dataset [War18], consisting of raw audio samples of length 16000, modeled as a 10-way classification task. As shown in Table 2, the performance of all DSS variants is comparable to that of S4 (98.2 vs 98.1).

In all experiments and ablations, S4 and DSS use identical model hyperparameters such as hidden size, number of layers, etc. Our experimental setup was built on top of the training framework provided by the S4 authors and for our S4 runs we followed their official instructions (https://github.com/HazyResearch/state-spaces). Details about model initialization and hyperparameters are provided in §A.3.

4.1 Analyzing the Performance of DSS

While the experimental results presented above are encouraging, and clearly demonstrate the effectiveness of DSS at modeling long-range dependencies, it is not clear what exactly are the main factors contributing to its performance. To investigate this further, we performed an ablation analysis aimed at answering the following questions:

How significant is the Skew-Hippo initialization (§3.3) to the model performance? Would initializing $\Lambda$ randomly work just as well?

Is the main source of superior performance of state space models (S4, DSS), compared to previous models, their ability to model long-range dependencies? Would restricting DSS to only model local interactions hurt its performance on the above tasks?

Truncated Kernels

To answer the second question, instead of constructing a kernel of length equal to the length $L$ of the input, we restricted the length of the kernel constructed in DSSsoftmax (Algorithm 1) to 128, significantly shorter than the length of the input. To understand the implication of this restriction recall Equation 5, which states that $y_{k} = \sum_{j=0}^{k}\overline{K}_{j}\cdot u_{k-j}$.

For a given context size $c=128$, restricting $\overline{K}_{\geq c}=0$ would imply

$$y_{k} = \sum_{j=0}^{\min(k,\,c-1)}\overline{K}_{j}\cdot u_{k-j}$$

and hence the output $y_{k}$ at position $k$ would only depend on $u_{k},\ldots,u_{k-c+1}$. This would restrict each DSSsoftmax layer to only model local interactions and the model would require several layers to have a broader receptive field and capture long-range interactions.
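The effect of truncation on the receptive field can be seen in a small sketch. The code below uses a toy kernel and a context size of 3 instead of 128; the helper names `causal_conv` and `truncate` are hypothetical and the numbers are chosen only for illustration. Perturbing the first input changes a late output under the full kernel but not under the truncated one.

```python
def causal_conv(kernel, u):
    """y_k = sum_{j=0}^{min(k, len(kernel)-1)} kernel[j] * u[k-j]."""
    return [
        sum(kernel[j] * u[k - j] for j in range(min(k + 1, len(kernel))))
        for k in range(len(u))
    ]


def truncate(kernel, c):
    """Keep the first c kernel entries, i.e. set K_j = 0 for all j >= c."""
    return list(kernel[:c])


kernel = [0.5 ** j for j in range(8)]  # toy kernel of the same length as the input
u = [1.0] * 8
u_perturbed = [100.0] + [1.0] * 7      # only the first input differs

full = causal_conv(kernel, u)[7], causal_conv(kernel, u_perturbed)[7]
local = causal_conv(truncate(kernel, 3), u)[7], causal_conv(truncate(kernel, 3), u_perturbed)[7]
print(full)   # the two values differ: y_7 sees u_0 through the full kernel
print(local)  # the two values coincide: y_7 only sees u_7, u_6, u_5
```

With a single layer the receptive field is exactly $c$ positions, which is why stacking several layers is needed before distant positions can interact.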

As shown in Table 3, randomly initializing the $\Lambda$ parameters of DSS leads to significant performance degradation on the majority of tasks, with the model failing to perform on Path-X. This is in line with the findings of [GGR22] who also reported the initialization of S4 to be critical to its performance. Interestingly, despite this performance reduction, DSS manages to outperform all non state-space-based models on every task.

Truncating the length of the kernel also leads to a significant reduction in performance across most tasks (Table 3), suggesting that the superior performance of state-space models on these tasks can indeed be partly attributed to their ability to capture long-range dependencies. Moreover, on some tasks such as ListOps and Image, using a truncated kernel still manages to outperform all Transformer variants, which is surprising as Transformer layers are known to be effective at capturing interactions at such short ranges.

4.2 Analysis of Learned DSS Parameters

To further explore the inner workings of DSS, we visually inspected the trained parameters and kernels of DSSsoftmax.

The kernels of all layers of the trained DSSsoftmax are shown in Figure 2 and reveal a stark contrast between the tasks. On the tasks Image and SC, for almost all kernels, the absolute values of the first 128 positions are significantly higher than the later positions, indicating that these kernels are mostly local. On the other hand, for Path-X, for a significant proportion of kernels the opposite is true, indicating that these kernels are modeling long-range interactions.

We note that the plot for ListOps reveals an outlier with a value of 22 which after exponentiation in Algorithm 1 would result in an extremely large $\Delta$. This can potentially lead to training instabilities and we plan to address this issue in future work.
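To give a sense of scale, the following two-line check exponentiates the outlier. It assumes that the plotted quantity is the log-parameter and that the step size is obtained by a plain `exp`, as the sentence above suggests; the variable names are illustrative.

```python
import math

delta_log = 22.0             # outlier value read off the ListOps plot
delta = math.exp(delta_log)  # step size after exponentiation
print(f"{delta:.3e}")        # roughly 3.6e9
```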

Discussion

In a long line of work, several variants of the Transformer have been proposed to address its quadratic complexity (cf. [GB21] and references therein). Recently, Gu et al. [GGR22] introduced S4, a new type of model that leverages linear state spaces for contextualization instead of attention. Our work is inspired by S4 but uses a diagonal parameterization of state spaces. As a result, our method is significantly simpler compared to S4 and we do not require (1) Padé approximations to $\overline{A}=e^{A\Delta}$ (Euler, Bilinear, etc.), (2) Woodbury Identity reductions to compute the matrix inverse, and (3) Fourier analysis for computing the SSM kernel efficiently via truncated generating functions.

Limitations and future work

Acknowledgments and Disclosure of Funding

We thank Ramon Fernandez Astudillo for carefully reviewing the preliminary draft and suggesting several helpful edits. We thank Omer Levy, Achille Fokoue and Luis Lastras for their support. Our experiments were conducted on IBM’s Cognitive Computing Cluster, with additional resources from Tel Aviv University. This research was supported by (1) IBM AI Residency program and (2) Defense Advanced Research Projects Agency (DARPA) through Cooperative Agreement D20AC00004 awarded by the U.S. Department of the Interior (DOI), Interior Business Center.

References

Appendix A Supplemental Material

We restate Proposition 1 for convenience.

where the last equality follows from $z^{L}-1=(z-1)(z^{0}+\ldots+z^{L-1})$ and using $z^{L}\neq 1$.

Similarly, for the state space $(\Lambda,\ (1)_{1\leq i\leq N},\ \widetilde{w})$ and sample time $\Delta$ its kernel $\widetilde{K}$ can be obtained from Equation 4 as

It is easy to verify that Equation 10 can be expressed as a vector-matrix product

Similarly, for the state space $(\Lambda,\ ((e^{L\lambda_{i}\Delta}-1)^{-1})_{1\leq i\leq N},\ w)$ and sample time $\Delta$ its kernel $\widehat{K}$ can be obtained from Equation 4 as

which is also the expression for $K_{k}$ (Equation 7). ∎
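The equivalence behind part (b) can also be checked numerically. The sketch below compares the kernel of the diagonal state space $(\Lambda,\ ((e^{L\lambda_{i}\Delta}-1)^{-1})_{1\leq i\leq N},\ w)$ with the form $\sum_i (w_i/\lambda_i)\cdot\mathrm{softmax}_k(\lambda_i k\Delta)$, where the softmax runs over positions $k=0,\ldots,L-1$. It assumes the discretization $\overline{A}=e^{\Lambda\Delta}$, $\overline{B}=(\overline{A}-I)\Lambda^{-1}B$, uses a naive softmax without numerical stabilization, and all names and values are illustrative rather than taken from the released implementation.

```python
import cmath


def kernel_state_space(lam, b, w, delta, L):
    """Kernel of the diagonal state space (Lambda, b, w) under the assumed discretization."""
    return [
        sum(wi * bi * (cmath.exp(li * delta) - 1) / li * cmath.exp(li * k * delta)
            for li, bi, wi in zip(lam, b, w))
        for k in range(L)
    ]


def kernel_softmax(lam, w, delta, L):
    """K_k = sum_i (w_i / lam_i) * softmax_k(lam_i * k * delta), softmax over k = 0..L-1."""
    kernel = [0j] * L
    for li, wi in zip(lam, w):
        exps = [cmath.exp(li * k * delta) for k in range(L)]
        norm = sum(exps)  # geometric sum: (e^{L*lam_i*delta} - 1) / (e^{lam_i*delta} - 1)
        for k in range(L):
            kernel[k] += wi / li * exps[k] / norm
    return kernel


lam = [-0.5 + 1j, -0.1 - 2j]  # illustrative; nonzero, with e^{L*lam*delta} != 1
w = [1 + 0.5j, -0.3 + 0.2j]
delta, L = 0.1, 6
b = [1 / (cmath.exp(L * li * delta) - 1) for li in lam]
k_ss = kernel_state_space(lam, b, w, delta, L)
k_sm = kernel_softmax(lam, w, delta, L)
print(max(abs(x - y) for x, y in zip(k_ss, k_sm)))
```

The agreement rests on the geometric-sum identity $z^{L}-1=(z-1)(z^{0}+\ldots+z^{L-1})$ with $z=e^{\lambda_i\Delta}$, which turns the factor $(e^{\lambda_i\Delta}-1)/(e^{L\lambda_i\Delta}-1)$ into the softmax normalizer.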

Case 1 ($p=0,n=1$): In this case we have $e=\exp(c)$. For the map

where the last equality follows from $\omega^{L}=1$.

Case 2 ($p=1,n=0$): In this case we have $e=\exp(-c)$. For the map

where the last equality follows from $\omega^{L}=1$.

Finally, the second equality of the main Claim follows from $1-e=n-pe+p-ne$. ∎

A.3 Experimental Setup

We now describe the training details for DSS and S4 on LRA and Speech Commands (§4).

Sequence Classification Head: Both LRA and Speech Commands are sequence classification tasks. The final layer of the DSS stack outputs a sequence which is aggregated into a single vector via mean pooling along the length dimension. Exceptions to this were Text and Pathfinder tasks where the rightmost token was used as the aggregate.
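The two aggregation schemes can be sketched on a toy sequence as follows. The sequence is represented as a list of $L$ vectors of size $H$, and the helper names `mean_pool` and `last_token` are hypothetical; the numbers are for illustration only.

```python
def mean_pool(seq):
    """Average a length-L sequence of H-dimensional vectors along the length dimension."""
    L = len(seq)
    return [sum(step[h] for step in seq) / L for h in range(len(seq[0]))]


def last_token(seq):
    """Use the rightmost position as the aggregate."""
    return list(seq[-1])


outputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # toy layer output with L = 3, H = 2
print(mean_pool(outputs))   # one vector of size H, averaged over positions
print(last_token(outputs))  # one vector of size H, taken from the final position
```

Since the state space models used here are unidirectional, the rightmost position is the only one that has seen the entire input.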

For all datasets, we used the AdamW optimizer with a constant learning rate schedule with decay on validation plateau. However, for the DSS parameters (§3.2) the initial learning rate was $10^{-3}$ and weight decay was not used, with a few exceptions noted below.

We used hyperparameters such as model sizes, number of update steps, etc. as recommended by the S4 authors on their official repository; these are listed in Table 4. We made the following exceptions for DSS training:

ListOps: learning rate of $\Delta_{\log}$ was 0.02 instead of $10^{-3}$.

Text: learning rate of $\Delta_{\log}$ was 0.02 instead of $10^{-3}$.

Image: we used seed and trained for 200 epochs instead of 100.

Path-X: we used batch size 16 and trained for 35 epochs. $\Delta_{\log}$ was initialized as $e^{r}$ where $r\sim\mathcal{U}(\log(.0001),\log(.01))$ and its learning rate was $10^{-4}$. This was beneficial in early convergence of the model.
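The Path-X initialization above amounts to log-uniform sampling. A minimal sketch is given below; the function name, the seed and the number of samples are illustrative and not taken from our training scripts.

```python
import math
import random


def sample_log_uniform(low, high, rng):
    """Return e^r with r ~ U(log(low), log(high)), so the result lies in [low, high]."""
    r = rng.uniform(math.log(low), math.log(high))
    return math.exp(r)


rng = random.Random(0)  # illustrative seed
samples = [sample_log_uniform(1e-4, 1e-2, rng) for _ in range(1000)]
print(min(samples), max(samples))
```

Sampling uniformly in log space spreads the values evenly across the two orders of magnitude between $10^{-4}$ and $10^{-2}$, rather than concentrating them near the upper end.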

For our experiments, the test accuracy that we report in §4 was measured at the checkpoint with the highest validation accuracy.
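This selection rule can be written as a one-line reduction over the training history. In the sketch below, the record keys `epoch`, `val_acc` and `test_acc` are hypothetical and the accuracies are made-up placeholder numbers, not results from our runs.

```python
def select_checkpoint(history):
    """Return the record with the highest validation accuracy."""
    return max(history, key=lambda rec: rec["val_acc"])


# Made-up placeholder history, for illustration only.
history = [
    {"epoch": 1, "val_acc": 0.70, "test_acc": 0.69},
    {"epoch": 2, "val_acc": 0.82, "test_acc": 0.80},
    {"epoch": 3, "val_acc": 0.79, "test_acc": 0.81},
]
best = select_checkpoint(history)
print(best["epoch"], best["test_acc"])
```

Note that the reported test accuracy is the one at the selected checkpoint, which need not be the highest test accuracy observed during training.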

All our experiments were conducted on a single A100 GPU (40GiB).

A.4 Learned Parameters of DSSsoftmax

A.5 Implementation of DSSsoftmax