Simplified State Space Layers for Sequence Modeling
Jimmy T. H. Smith, Andrew Warrington, Scott W. Linderman
1 Introduction
Efficiently modeling long sequences is a challenging problem in machine learning. Information crucial to solving tasks may be encoded jointly between observations that are thousands of timesteps apart. Specialized variants of recurrent neural networks (RNNs) (Arjovsky et al., 2016; Erichson et al., 2021; Rusch & Mishra, 2021; Chang et al., 2019), convolutional neural networks (CNNs) (Bai et al., 2018; Oord et al., 2016; Romero et al., 2022b), and transformers (Vaswani et al., 2017) have been developed to try to address this problem. In particular, many efficient transformer methods have been introduced (Choromanski et al., 2021; Katharopoulos et al., 2020; Kitaev et al., 2020; Beltagy et al., 2020; Gupta & Berant, 2020; Wang et al., 2020) to address the standard transformer’s quadratic complexity in the sequence length. However, these more efficient transformers still perform poorly on very long-range sequence tasks (Tay et al., 2021).
Gu et al. (2021a) presented an alternative approach using structured state space sequence (S4) layers. An S4 layer defines a nonlinear sequence-to-sequence transformation via a bank of many independent single-input, single-output (SISO) linear state space models (SSMs) (Gu et al., 2021b), coupled together with nonlinear mixing layers. Each SSM leverages the HiPPO framework (Gu et al., 2020a) by initializing with specially constructed state matrices. Since the SSMs are linear, each layer can be equivalently implemented as a convolution, which can then be applied efficiently by parallelizing across the sequence length. Multiple S4 layers can be stacked to create a deep sequence model. Such models have achieved significant improvements over previous methods, including on the long range arena (LRA) (Tay et al., 2021) benchmarks specifically designed to stress test long-range sequence models. Extensions have shown good performance on raw audio generation (Goel et al., 2022) and classification of long movie clips (Islam & Bertasius, 2022).
We introduce a new state space layer that builds on the S4 layer: the S5 layer, illustrated in Figure 1. S5 streamlines the S4 layer in two main ways. First, S5 uses one multi-input, multi-output (MIMO) SSM in place of the bank of many independent SISO SSMs in S4. Second, S5 uses an efficient and widely implemented parallel scan. This removes the need for the convolutional and frequency-domain approach used by S4, which requires a non-trivial computation of the convolution kernel. The resulting state space layer has the same computational complexity as S4, but operates purely recurrently and in the time domain.
We then establish a mathematical relationship between S4 and S5. This connection allows us to inherit the HiPPO initialization schemes that are key to the success of S4. Unfortunately, the specific HiPPO matrix that S4 uses for initialization cannot be diagonalized in a numerically stable manner for use in S5. However, in line with recent work on the DSS (Gupta et al., 2022) and S4D (Gu et al., 2022) layers, we found that a diagonal approximation to the HiPPO matrix achieves comparable performance. We extend a result from Gu et al. (2022) to the MIMO setting, which justifies the diagonal approximation for use in S5. We leverage the mathematical relationship between S4 and S5 to inform several other aspects of parameterization and initialization, and we perform thorough ablation studies to explore these design choices.
The final S5 layer has many desirable properties. It is straightforward to implement (see Appendix A; the full S5 implementation is available at https://github.com/lindermanlab/S5), enjoys linear complexity in the sequence length, and can efficiently handle time-varying SSMs and irregularly sampled observations (both of which are intractable with the convolution implementation of S4). S5 achieves state-of-the-art performance on a variety of long-range sequence modeling tasks, with an LRA average of , and accuracy on the most difficult Path-X task.
2 Background
We provide the necessary background in this section prior to introducing the S5 layer in Section 3.
The discrete-time parameters of the SSM are each a function of the continuous-time parameters, specified by the discretization method. See Iserles (2009) for more information on discretization methods.
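As one concrete instance of a discretization method, the sketch below applies the zero-order hold (ZOH) to a diagonal continuous-time state matrix. For a diagonal matrix with entries $\lambda_p$ and timescale $\Delta$, ZOH gives $\bar{\lambda}_p = e^{\lambda_p \Delta}$ and $\bar{\mathbf{B}} = \boldsymbol{\Lambda}^{-1}(\bar{\boldsymbol{\Lambda}} - \mathbf{I})\tilde{\mathbf{B}}$. The choice of ZOH, the names `Lambda`, `B_tilde` and `Delta`, and the example values are our assumptions for illustration; the formula also requires every $\lambda_p \neq 0$.

```python
import numpy as np

def discretize_zoh(Lambda, B_tilde, Delta):
    """Zero-order-hold discretization of a diagonal continuous-time SSM.

    Lambda:  (P,) complex diagonal of the continuous-time state matrix (nonzero entries).
    B_tilde: (P, H) complex input matrix.
    Delta:   positive scalar timescale, or a (P,) array with one timescale per state.
    """
    Lambda_bar = np.exp(Lambda * Delta)                       # exact matrix exponential of a diagonal matrix
    B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B_tilde  # Lambda^{-1} (Lambda_bar - I) B_tilde
    return Lambda_bar, B_bar

# Example: 3 latent states, 2 input features, timescale 0.01 (all values hypothetical).
Lambda = np.array([-0.5 + 2.0j, -0.5 - 2.0j, -1.0 + 0.0j])
B_tilde = np.ones((3, 2), dtype=complex)
Lambda_bar, B_bar = discretize_zoh(Lambda, B_tilde, 0.01)
```

Because every eigenvalue has a negative real part, each $|\bar{\lambda}_p| < 1$, so the discrete recurrence is stable; for small $\Delta$, $\bar{\mathbf{B}} \approx \Delta \tilde{\mathbf{B}}$.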
2.2 Parallelizing Linear State Space Models With Scans
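We use parallel scans to efficiently compute the states of a discretized linear SSM. Given a binary associative operator $\bullet$ (i.e. $(a \bullet b) \bullet c = a \bullet (b \bullet c)$) and a sequence of $L$ elements $[a_1, a_2, \ldots, a_L]$, the scan operation (sometimes referred to as all-prefix-sum) returns the sequence $[a_1,\ (a_1 \bullet a_2),\ \ldots,\ (a_1 \bullet a_2 \bullet \cdots \bullet a_L)]$.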
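To illustrate how a scan computes a linear recurrence $x_k = \bar{a}_k \odot x_{k-1} + \bar{b}_k$, the NumPy sketch below represents each step as a pair $(\bar{a}_k, \bar{b}_k)$ and combines pairs with the associative operator $(a_i, b_i) \bullet (a_j, b_j) = (a_j \odot a_i,\ a_j \odot b_i + b_j)$, using elementwise products; this is why a diagonal state matrix keeps each combination cheap. The scan uses the simple recursive-doubling (Hillis-Steele) scheme, which needs $\lceil \log_2 L \rceil$ vectorized rounds but $O(L \log L)$ total work, rather than a work-efficient variant; all names are ours. In JAX, `jax.lax.associative_scan` would be the natural replacement for the hand-written scan.

```python
import numpy as np

def linear_recurrence_op(left, right):
    """Associative operator for x_k = a_k * x_{k-1} + b_k (elementwise)."""
    a_i, b_i = left
    a_j, b_j = right
    return a_j * a_i, a_j * b_i + b_j

def associative_scan(op, elems):
    """Inclusive scan by recursive doubling over the leading (sequence) axis.

    elems is a tuple of arrays that share a leading axis of length L.
    """
    elems = tuple(elems)
    L = elems[0].shape[0]
    shift = 1
    while shift < L:
        # Element k absorbs the partial result ending at k - shift (earlier element on the left).
        combined = op(tuple(e[:-shift] for e in elems), tuple(e[shift:] for e in elems))
        elems = tuple(np.concatenate([e[:shift], c]) for e, c in zip(elems, combined))
        shift *= 2
    return elems

# Example: a diagonal (elementwise) recurrence with L = 37 steps and P = 4 states.
rng = np.random.default_rng(0)
L, P = 37, 4
a = rng.uniform(0.5, 0.99, size=(L, P))   # stands in for the discretized diagonal state matrix
b = rng.normal(size=(L, P))               # stands in for B_bar @ u_k
_, x_scan = associative_scan(linear_recurrence_op, (a, b))

# Sequential reference, starting from a zero initial state.
x = np.zeros(P)
x_seq = []
for k in range(L):
    x = a[k] * x + b[k]
    x_seq.append(x)
x_seq = np.array(x_seq)
```

The second component of the scanned pairs is exactly the sequence of states, so `x_scan` matches the sequential loop.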
2.3 S4: Structured State Space Sequence Layers
Efficiently applying the S4 layer requires two separate implementations depending on context: a recurrent mode and a convolution mode. For online generation, the SSM is iterated recurrently, much like other RNNs. However, when the entire sequence is available and the observations are evenly spaced, a more efficient convolution mode is used. This takes advantage of the ability to represent the linear recurrence as a one-dimensional convolution between the inputs and a convolution kernel for each of the SSMs. Fast Fourier transforms (FFTs) can then be used to compute these convolutions efficiently in parallel. Figure 4(a) in the appendix illustrates the convolution approach of the S4 layer for offline processing. We note that while parallel scans could, in principle, allow a recurrent approach to be used in offline scenarios, applying the parallel scan to all of the -dimensional SSMs would in general be much more expensive than the convolution approach S4 actually uses.
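The equivalence between the recurrent and convolution modes can be checked on a toy single-input, single-output SSM. The sketch below assumes an already-discretized diagonal state matrix (as in S4D) and materializes the kernel $K_l = \mathbf{c}^\top \bar{\mathbf{A}}^{l} \bar{\mathbf{b}}$ naively from powers of the eigenvalues; S4 itself computes the kernel of its structured, non-diagonal state matrix with a more involved Cauchy-kernel algorithm, which is not shown. All names and values are hypothetical.

```python
import numpy as np

def ssm_kernel(Lambda_bar, b_bar, c, L):
    """K[l] = sum_n c_n * Lambda_bar_n**l * b_bar_n for l = 0..L-1 (diagonal SISO SSM)."""
    powers = Lambda_bar[None, :] ** np.arange(L)[:, None]   # (L, N) Vandermonde matrix
    return powers @ (c * b_bar)                              # (L,)

def fft_causal_conv(u, K):
    """Causal convolution y_k = sum_{l<=k} K[l] u[k-l] via FFT, zero-padded to avoid wrap-around."""
    L = u.shape[0]
    n = 2 * L
    return np.fft.ifft(np.fft.fft(u, n) * np.fft.fft(K, n))[:L]

rng = np.random.default_rng(1)
N, L = 8, 50
Lambda_bar = 0.9 * np.exp(1j * rng.uniform(0.0, np.pi, N))   # stable discrete eigenvalues
b_bar = rng.normal(size=N) + 1j * rng.normal(size=N)
c = rng.normal(size=N) + 1j * rng.normal(size=N)
u = rng.normal(size=L)

# Convolution mode.
y_conv = fft_causal_conv(u, ssm_kernel(Lambda_bar, b_bar, c, L))

# Recurrent mode: x_k = Lambda_bar * x_{k-1} + b_bar * u_k, y_k = c^T x_k.
x = np.zeros(N, dtype=complex)
y_rec = []
for k in range(L):
    x = Lambda_bar * x + b_bar * u[k]
    y_rec.append(c @ x)
y_rec = np.array(y_rec)
```

Padding to length $2L$ is what turns the FFT's circular convolution into the causal (linear) convolution the recurrence defines.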
3 The S5 Layer
In this section we present the S5 layer. We describe its structure, parameterization and computation, focusing in particular on how each of these differs from S4.
3.2 S5 Parameterization: Diagonalized Dynamics
Prior work showed that the performance of deep state space models is sensitive to the initialization of the state matrix (Gu et al., 2021b; a). We discussed in Section 2.2 that state matrices must be diagonal for efficient application of parallel scans. We also discussed in Section 2.3 that the HiPPO-LegS matrix cannot be diagonalized stably, but that the HiPPO-N matrix can be. In Section 4 we connect the dynamics of S5 to S4 to suggest why initializing with HiPPO-like matrices may also work well in the MIMO setting. We support this empirically, finding that diagonalizing the HiPPO-N matrix leads to good performance, and perform ablations in Appendix E to compare to other initializations. We note that the DSS (Gupta et al., 2022) and S4D (Gu et al., 2022) layers also achieved strong performance in the SISO setting by using a diagonalization of the HiPPO-N matrix.
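A minimal sketch of the diagonalization, assuming the HiPPO-LegS definition of Gu et al. (2021a) (lower triangular, with $-(2n+1)^{1/2}(2k+1)^{1/2}$ below the diagonal and $-(n+1)$ on it) and assuming that adding the rank-one term $\mathbf{p}\mathbf{p}^\top$ with $p_n = (n + 1/2)^{1/2}$ yields the normal part HiPPO-N. HiPPO-N is then $-\tfrac{1}{2}\mathbf{I}$ plus a skew-symmetric matrix $\mathbf{S}$, so the Hermitian eigensolver applied to $-i\mathbf{S}$ returns a unitary eigenvector matrix; that unitarity is what makes this diagonalization numerically stable. Function names are ours.

```python
import numpy as np

def hippo_legs(N):
    """HiPPO-LegS state matrix (negative-sign convention), indices n, k = 0..N-1."""
    n = np.arange(N)
    q = np.sqrt(2 * n + 1.0)
    return -np.tril(np.outer(q, q), k=-1) - np.diag(n + 1.0)

def hippo_n(N):
    """Normal part HiPPO-N = HiPPO-LegS + p p^T with p_n = sqrt(n + 1/2)."""
    p = np.sqrt(np.arange(N) + 0.5)
    return hippo_legs(N) + np.outer(p, p)

def diagonalize_hippo_n(N):
    """Return (Lambda, V, A_N) with A_N = V diag(Lambda) V^H and V unitary."""
    A_N = hippo_n(N)
    S = A_N + 0.5 * np.eye(N)             # skew-symmetric part of A_N
    mu, V = np.linalg.eigh(-1j * S)       # -1j * S is Hermitian: real mu, unitary V
    Lambda = -0.5 + 1j * mu               # eigenvalues of A_N
    return Lambda, V, A_N

Lambda, V, A_N = diagonalize_hippo_n(16)
```

Every eigenvalue has real part $-\tfrac{1}{2}$, and because $\mathbf{S}$ is real, the imaginary parts come in $\pm$ pairs, i.e. the eigenvalues occur in conjugate pairs.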
The complex eigenvalues of a diagonalizable matrix with real entries always occur in conjugate pairs. We enforce this conjugate symmetry by using half the number of eigenvalues and latent states, keeping one member of each conjugate pair. This ensures real outputs and reduces the runtime and memory usage of the parallel scan by a factor of two. This idea is also discussed in Gu et al. (2022).
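The sketch below checks, on random values, that keeping one eigenvalue from each conjugate pair loses nothing: with real inputs, the full system's output equals twice the real part of the half-size system's output. Recovering the output as twice the real part is our assumed mechanism here, and all names and sizes are illustrative. Since $\exp(\bar{z}) = \overline{\exp(z)}$, discretization preserves conjugate pairs, so the check is done directly on discrete-time values.

```python
import numpy as np

def run_diagonal_ssm(Lambda_bar, B_bar, C, u):
    """x_k = Lambda_bar * x_{k-1} + B_bar u_k,  y_k = C x_k  (diagonal state matrix)."""
    x = np.zeros(Lambda_bar.shape[0], dtype=complex)
    ys = []
    for u_k in u:
        x = Lambda_bar * x + B_bar @ u_k
        ys.append(C @ x)
    return np.array(ys)

rng = np.random.default_rng(0)
P_half, H, L = 4, 3, 20
Lambda_half = np.exp(-0.1 + 1j * rng.uniform(0.0, np.pi, P_half))  # |Lambda| < 1
B_half = rng.normal(size=(P_half, H)) + 1j * rng.normal(size=(P_half, H))
C_half = rng.normal(size=(H, P_half)) + 1j * rng.normal(size=(H, P_half))
u = rng.normal(size=(L, H))                                         # real inputs

# Full system: append the complex conjugate of every eigenvalue, input row and output column.
Lambda_full = np.concatenate([Lambda_half, Lambda_half.conj()])
B_full = np.concatenate([B_half, B_half.conj()], axis=0)
C_full = np.concatenate([C_half, C_half.conj()], axis=1)

y_full = run_diagonal_ssm(Lambda_full, B_full, C_full, u)
y_half = 2.0 * run_diagonal_ssm(Lambda_half, B_half, C_half, u).real
```

The conjugate states contribute the complex conjugate of the kept states' contribution, so the imaginary parts cancel and the half-size system carries all the information.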
3.3 S5 Computation: Fully Recurrent
Compared to the large effective latent size of the block-diagonal S4 layer, the smaller latent dimension of the S5 layer () allows the use of efficient parallel scans when the entire sequence is available. The S5 layer can therefore be efficiently used as a recurrence in the time domain for both online generation and offline processing. Parallel scans and the continuous-time parameterization also allow for efficient handling of irregularly sampled time series and other time-varying SSMs, by simply supplying a different matrix at each step. We leverage this feature and apply S5 to irregularly sampled data in Section 6.3. In contrast, the convolution of the S4 layer requires a time-invariant system and regularly spaced observations.
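To make the time-varying case concrete, here is a sketch that assumes ZOH discretization with the step size $\Delta_k$ set to the time elapsed since the previous observation, so each step gets its own $\bar{\boldsymbol{\Lambda}}_k = e^{\boldsymbol{\Lambda}\Delta_k}$; the per-step pairs are then combined by a recursive-doubling scan. With a constant input, ZOH is exact, so the states should match the closed-form solution at every irregular observation time, which the example checks. Names and values are ours.

```python
import numpy as np

def scan_linear(a, b):
    """Inclusive scan of x_k = a_k * x_{k-1} + b_k with zero initial state (recursive doubling)."""
    L = a.shape[0]
    shift = 1
    while shift < L:
        a_new = a[shift:] * a[:-shift]
        b_new = a[shift:] * b[:-shift] + b[shift:]
        a = np.concatenate([a[:shift], a_new])
        b = np.concatenate([b[:shift], b_new])
        shift *= 2
    return b

def irregular_ssm_states(Lambda, B_tilde, u, t):
    """States of a diagonal SSM whose observations arrive at increasing times t (start time 0)."""
    Delta = np.diff(t, prepend=0.0)                                      # (L,) elapsed time per step
    Lambda_bar = np.exp(Lambda[None, :] * Delta[:, None])                # (L, P) one transition per step
    B_bar_u = ((Lambda_bar - 1.0) / Lambda[None, :]) * (u @ B_tilde.T)   # (L, P) ZOH input term
    return scan_linear(Lambda_bar, B_bar_u)

rng = np.random.default_rng(0)
Lambda = np.array([-0.5 + 3.0j, -0.5 - 3.0j, -2.0 + 0.0j])
B_tilde = np.ones((3, 1), dtype=complex)
t = np.sort(rng.uniform(0.0, 5.0, size=40))   # irregular observation times
u = np.ones((40, 1))                          # constant input, for which ZOH is exact
x = irregular_ssm_states(Lambda, B_tilde, u, t)

# Closed form for dx/dt = Lambda x + B_tilde u with x(0) = 0 and u = 1.
x_exact = (np.exp(Lambda[None, :] * t[:, None]) - 1.0) / Lambda[None, :]
```

Nothing in the scan assumes equal spacing: the irregularity enters only through the per-step $\Delta_k$.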
3.4 Matching the Computational Efficiency of S4 and S5
A key design desideratum for S5 was matching the computational complexity of S4 for both online generation and offline recurrence. The following proposition guarantees that their complexities are of the same order if S5’s latent size .
Given an S4 layer with input/output features, an S5 layer with input/output features and a latent size has the same order of magnitude complexity as an S4 layer in terms of both runtime and memory usage.
We also support this proposition with empirical comparisons in Appendix C.2.
4 Relationship Between S4 and S5
We now establish a relationship between the dynamics of S5 and S4. In Section 4.1 we show that, under certain conditions, the outputs of the S5 SSM can be interpreted as a projection of the latent states computed by a particular S4 system. This interpretation motivates using HiPPO initializations for S5, which we discuss in more detail in Section 4.2. In Section 4.3 we discuss how the conditions required to relate the dynamics further motivate initialization and parameterization choices.
We compare the dynamics of S4 and S5 under some simplifying assumptions:
We consider only -dimensional to -dimensional sequence maps.
We assume that the same state matrix is used in S5 as in S4 (also cf. Assumption 2). Note this also specifies the S5 latent size . We also assume the S5 input matrix is the horizontal concatenation of the column input vectors used by S4: .
We will discuss relaxing these assumptions shortly, but under these conditions it is straightforward to derive a relationship between the dynamics of S4 and S5:
4.2 Diagonalizable Initialization
Ideally, given the interpretation above, we would initialize S5 with the exact HiPPO-LegS matrix. Unfortunately, as discussed in Section 2.3, this matrix is not stably diagonalizable, as is required for the efficient parallel scans used for S5. However, Gupta et al. (2022) and Gu et al. (2022) showed empirically that removing the low rank terms and initializing with the diagonalized HiPPO-N matrix still performed well. Gu et al. (2022) offered a theoretical justification for the use of this normal approximation for single-input systems: in the limit of infinite state dimension, the linear ODE with HiPPO-N state matrix produces the same dynamics as an ODE with the HiPPO-LegS matrix. Using linearity, it is straightforward to extend this result to the multi-input system that S5 uses:
We include a simple proof of this extension in Appendix D.3. This extension motivates the use of HiPPO-N to initialize S5’s MIMO SSM. Note that S4D (the diagonal extension of S4) uses the same HiPPO-N matrix. Thus, under the assumptions of Proposition 2, an S5 SSM in fact produces outputs that are equivalent to a linear combination of the latent states produced by S4D’s SSMs. Our empirical results in Section 6 suggest that S5 initialized with the HiPPO-N matrix performs just as well as S4 initialized with the HiPPO-LegS matrix.
4.3 Relaxing the Assumptions
We now revisit the assumptions required for Proposition 2, since they only relate a constrained version of S5 to a constrained version of S4. Regarding Assumption 2, Gu et al. (2021a) report that S4 models with tied state matrices can still perform well, though allowing different state matrices often yields higher performance. Likewise, requiring a single scalar timescale across all of the S4 SSMs, per Assumption 3, is restrictive. S4 typically learns different timescale parameters for each SSM (Gu et al., 2023) to capture different timescales in the data. To relax these assumptions, note that Assumption 4 constrains S5 to have dimension , and is typically much smaller than the dimensionality of the inputs, . Proposition 1 established that S5 can match S4’s complexity with . By allowing for larger latent state sizes, Assumptions 2 and 3 can be relaxed, as discussed in Appendix D.4. We also discuss how this relaxation motivates a block-diagonal initialization with HiPPO-N matrices on the diagonal. Finally, to further relax the tied timescale assumptions, we note that in practice, we find improved performance by learning different timescales (one per state). See Appendix D.5 for further discussion of this empirical finding and the ablations in Appendix E.1.
5 Related work
S5 is most directly related to S4 and its other extensions, which we have discussed thoroughly. However, there is prior literature that uses similar ideas to those developed here. For example, prior work studied approximating nonlinear RNNs with stacks of linear RNNs connected by nonlinear layers, while also using parallel scans (Martin & Cundy, 2018). Martin & Cundy (2018) showed that several efficient RNNs, such as QRNNs (Bradbury et al., 2017) and SRUs (Lei et al., 2018), fall into a class of linear surrogate RNNs that can leverage parallel scans. Kaul (2020) also used parallel scans for an approach that approximates RNNs with stacks of discrete-time single-input, multi-output (SIMO) SSMs. However, S4 and S5 are the only methods to significantly outperform other comparable state-of-the-art nonlinear RNNs, transformers and convolution approaches. Our ablation study in Appendix E.2 suggests that this performance gain over prior attempts at parallelized linear RNNs is likely due to the continuous-time parameterization and the HiPPO initialization.
6 Experiments
We now empirically compare the performance of the S5 layer to that of the S4 layer and other baseline methods. We use the S5 layer as a drop-in replacement for the S4 layer. The architecture consists of a linear input encoder, stacks of S5 layers, and a linear output decoder (Gu et al., 2021a). For all experiments we choose the S5 dimensions to ensure computational complexity similar to that of S4, following the conditions discussed in Section 3.3, as well as comparable parameter counts. The results we present show that the S5 layer matches the performance and efficiency of the S4 layer. We include further ablations, baselines and runtime comparisons in the appendix.
The long range arena (LRA) benchmark (Tay et al., 2021) is a suite of six sequence modeling tasks, with sequence lengths from 1,024 to over 16,000. The suite was specifically developed to benchmark the performance of architectures on long-range modeling tasks (see Appendix G for more details). Table 1 presents S5’s LRA performance in comparison to other methods. S5 achieves the highest average score among methods that have linear complexity in sequence length (most notably S4, S4D, and the concurrent works: Liquid-S4 (Hasani et al., 2023) and Mega-chunk (Ma et al., 2023)). Most significantly, S5 achieves the highest score among all models (including Mega (Ma et al., 2023)) on the Path-X task, which has by far the longest sequence length of the tasks in the benchmark.
6.2 Raw Speech Classification
The Speech Commands dataset (Warden, 2018) contains high-fidelity sound recordings of different human readers reciting a word from a vocabulary of 35 words. The task is to classify which word was spoken. We show in Table 2 that S5 outperforms the baselines and previous S4 methods, and performs similarly to the concurrent Liquid-S4 method (Hasani et al., 2023). As S4 and S5 methods are parameterized in continuous time, these models can be applied to datasets with different sampling rates without re-training, simply by globally re-scaling the timescale parameter by the ratio of the old to the new sampling rate (a factor of two when moving from 16kHz to 8kHz). The result of applying the best S5 model trained on 16kHz data to the speech data sampled (via decimation) at 8kHz, without any additional fine-tuning, is also presented in Table 2. S5 also improves this metric over the baseline methods.
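A toy illustration of the re-scaling, under assumptions chosen for readability: a fixed diagonal SSM with hypothetical eigenvalues, ZOH discretization, a 5 Hz sine as input, and a timescale set equal to the sampling period rather than learned. Running the decimated 8kHz signal with the timescale multiplied by 16000/8000 = 2 closely tracks the 16kHz states at the shared time points, while reusing the 16kHz timescale does not.

```python
import numpy as np

def discretize_zoh(Lambda, B_tilde, Delta):
    """ZOH discretization of a diagonal SSM (see the assumptions in the lead-in)."""
    Lambda_bar = np.exp(Lambda * Delta)
    return Lambda_bar, ((Lambda_bar - 1.0) / Lambda)[:, None] * B_tilde

def run(Lambda_bar, B_bar, u):
    """Return the sequence of states of x_k = Lambda_bar * x_{k-1} + B_bar u_k."""
    x = np.zeros(Lambda_bar.shape[0], dtype=complex)
    xs = []
    for u_k in u:
        x = Lambda_bar * x + B_bar @ u_k
        xs.append(x)
    return np.array(xs)

Lambda = np.array([-20.0 + 0j, -10.0 + 40j, -10.0 - 40j])   # hypothetical eigenvalues (1/s)
B_tilde = np.ones((3, 1), dtype=complex)
f_train, f_test = 16000, 8000
Delta_train = 1.0 / f_train                    # stand-in for a timescale learned at 16kHz
Delta_test = Delta_train * f_train / f_test    # re-scaled by old rate / new rate

t = np.arange(f_train) / f_train               # one second of audio
u16 = np.sin(2 * np.pi * 5 * t)[:, None]
u8 = u16[::2]                                  # decimated to 8kHz

x16 = run(*discretize_zoh(Lambda, B_tilde, Delta_train), u16)
x8 = run(*discretize_zoh(Lambda, B_tilde, Delta_test), u8)
x8_wrong = run(*discretize_zoh(Lambda, B_tilde, Delta_train), u8)   # no re-scaling

ref = x16[::2]                                 # 16kHz states at the 8kHz sample times
rel_err = np.linalg.norm(x8 - ref) / np.linalg.norm(ref)
rel_err_wrong = np.linalg.norm(x8_wrong - ref) / np.linalg.norm(ref)
```

Without re-scaling, the model effectively sees the input played back at twice the speed, which changes both the amplitude and phase of its response.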
6.3 Variable Observation Interval
The final application we study here highlights how S5 can naturally handle observations received at irregular intervals. S5 does so by supplying a different step size, reflecting the time elapsed between observations, to the discretization at each step. We use the pendulum regression example presented by Becker et al. (2019) and Schirmer et al. (2022), illustrated in Figure 3. The input sequence is a sequence of images, each pixels in size, that has been corrupted with a correlated noise process and sampled at irregular intervals from a continuous trajectory of duration . The targets are the sine and cosine of the angle of the pendulum, which follows a nonlinear dynamical system. The velocity is unobserved. We match the architecture, parameter count and training procedure of Schirmer et al. (2022). Table 3 summarizes the results of this experiment. S5 outperforms CRU (Schirmer et al., 2022) on the regression task, recovering a lower mean error. Furthermore, S5 is markedly faster than CRU on the same hardware.
6.4 Pixel-level 1-D Image Classification
Table 10 in Appendix F.4 shows results of S5 on other common benchmarks including sequential MNIST, permuted sequential MNIST and sequential CIFAR (color). We see that S5 broadly matches the performance of S4, and outperforms a range of state-of-the-art RNN-based methods.
7 Conclusion
We introduce the S5 layer for long-range sequence modeling. The S5 layer modifies the internal structure of the S4 layer, and replaces the frequency-domain approach used by S4 with a purely recurrent, time-domain approach leveraging parallel scans. S5 achieves high performance while retaining the computational efficiency of S4. S5 also provides further opportunities. For instance, unlike the convolutional S4 methods, the parallel scan makes it efficient and easy to process time-varying SSMs, whose parameters can change at every step. Section 6.3 illustrated an example of this for sequences observed at irregular intervals. The concurrently developed method, Liquid-S4 (Hasani et al., 2023), uses an input-dependent bilinear dynamical system and highlights further opportunities for time-varying SSMs. The more general MIMO SSM design will also enable connections to be made with classical probabilistic state space modeling as well as more recent work on parallelizing filtering and smoothing operations (Särkkä & García-Fernández, 2020). More broadly, we hope the simplicity and generality of the S5 layer can expand the use of state space layers in deep sequence modeling and lead to new formulations and extensions.
Acknowledgements and Disclosure of Funding
We thank Albert Gu for his thorough and insightful feedback. We also acknowledge The Annotated S4 Blog (Rush & Karamcheti, 2022) which inspired our JAX implementation. This work was supported by grants from the Simons Collaboration on the Global Brain (SCGB 697092), the NIH BRAIN Initiative (U19NS113201 and R01NS113119), and the Sloan Foundation. Some of the computation for this work was made possible by Stanford Data Science Microsoft Education Azure cloud credits.
References
Appendix for: Simplified State Space Layers for Sequence Modeling
Appendix A: JAX Implementation of S5 Layer.
Appendix C: Computational Efficiency of S5.
Appendix D: Relationship Between S4 and S5.
Appendix H: Background on Parallel Scans for Linear Recurrences.
Appendix A JAX Implementation of S5 Layer
Appendix B S5 Layer Details
Here we provide additional details to supplement the discussion of initialization in Section 3.2. Gu et al. (2023) explain the ability of S4 to capture long-range dependencies when using the HiPPO-LegS matrix via decomposing the input with respect to an infinitely long, exponentially decaying measure. The HiPPO-LegS matrix and corresponding SISO input vector are defined as
B.1.2 Initialization of Input, Output and Feed-through Matrices
B.1.3 Initialization of the Timescales
B.2 Comparison of S4 and S5 Computational Elements
In Figure 4 we illustrate a comparison of the computational details of the S4 and S5 layers for efficient, parallelized offline processing.
Appendix C Computational Efficiency of S5
Given an S4 layer with input/output features, an S5 layer with input/output features and a latent size has the same order of magnitude complexity as an S4 layer in terms of both runtime and memory usage.
We first consider the case where the entire sequence is available and compare the S4 layer’s convolution mode to the S5 layer’s use of a parallel scan. We then consider the online generation case where each method operates recurrently.
Thus, S4 and S5 have the same order of computational complexity and memory requirements in both cases. ∎
C.2 Empirical Runtime Comparison
Table 4 provides an empirical evaluation of the runtime performance, in terms of speed and memory, of S4, S4D and S5 across a range of sequence lengths from the LRA tasks. We compared the JAX implementation of S5 to JAX implementations of S4 and S4D based on the JAX implementation from Rush & Karamcheti (2022). For a fair comparison, we modified these existing JAX implementations of S4 and S4D to allow them both to enforce conjugate symmetry and use bidirectionality. For each task, models use bidirectionality and conjugate symmetry as reported in Gu et al. (2022). All models, except for the italicised S5 row, use the same input/output features and number of layers as reported in Gu et al. (2022). The S4 and S4D layers also use the same S4 SSM latent size as reported in Gu et al. (2022). All methods used the same batch size and all comparisons were made using a 16GB NVIDIA V100 GPU. Note that we observed the JAX S4D implementation to be generally faster than the JAX S4 implementation (possibly due to this specific S4 implementation’s use of the naive Cauchy kernel computation (Gu et al., 2021a)). For this reason, we consider S4D as the baseline.
We consider three configurations of S5 for comparison. The first two configurations, corresponding to lines 3 and 4 for each metric in Table 4, show how the runtime metrics vary as S5’s latent size is adjusted, with all other architecture choices equal to those of S4. In line 3 of each metric, we denote the S5 “Architecture” as “($P=H$) Matched to Gu et al. (2022)” to indicate that this configuration of S5 sets the latent size equal to the number of input/output features, $P=H$. This line empirically supports the complexity argument presented in Appendix C.1. In line 4 of each metric, we denote the S5 “Architecture” as “($P=N$) Matched to Gu et al. (2022)” to indicate that this configuration of S5 sets the latent size equal to the latent size $N$ that S4 uses for each of its SISO SSMs. This line also corresponds to the constrained version of S5 that performs similarly to S4/S4D, as presented in the ablation study in Table 5. The runtime results of both of these configurations support the claim in Section 4.3 that the latent size of S5 can be increased while maintaining S4’s computational efficiency.
Finally, we include a third configuration of S5, presented in the fifth line of each metric and italicized. This configuration of S5 uses the best architectural dimensions from Table 11 and was used for the corresponding LRA results in Table 1.
Importantly, the broad takeaway from this empirical study is that the runtime and memory usage of S5 and S4/S4D are broadly similar, as suggested by the complexity analysis in the main text.
Appendix D Relationship Between S4 and S5
We now describe in more detail the connection between the S4 and S5 architectures. This connection allowed us to develop more performant architectures and extend theoretical results from existing work.
We break this analysis down into three parts:
In Section D.2 we prove Proposition 2. We exploit the linearity of the systems to identify that the latent states computed by the S5 SSM are equivalent to a linear combination of latent states computed by the SISO S4 SSMs, and that the outputs of the S5 SSM are a further linear transformation of these states. We then highlight how S4 and S5 effectively define different output matrices in the block-diagonal perspective shown in Figure 2.
In Section D.3 we provide a simple extension of the proof provided by Gu et al. (2022). The original proof shows that in the SISO case, in the limit of large , the dynamics arising from a (not stably diagonalizable) HiPPO-LegS matrix are faithfully approximated by the (stably diagonalizable) normal component of the HiPPO-LegS matrix. We extend this proof to apply to the MIMO setting. This motivates initialization with the HiPPO-N matrix, which in turn allows us to use parallel scans efficiently.
In Section D.4 we conclude by showing that, by judicious choice of initialization of the S5 state matrix, S5 can implement multiple independent S4 systems and relax the assumptions made. We also discuss the vector of timescale parameters, which we found to improve performance.
We note that many of these results follow straightforwardly from the linearity of the recurrence.
In the following sections we make these assumptions unless otherwise stated:
We consider only -dimensional to -dimensional sequence maps.
We assume that the same state matrix is used in S5 as in S4 (also cf. Assumption 2). Note this also specifies the S5 latent size . We also assume the S5 input matrix is the horizontal concatenation of the column input vectors used by S4, .
D.2 Different Output Projections of Equivalent Dynamics
For an S5 layer, the latent states are expressible as
where we index as and
where this result follows directly from the linearity of (13) and (14). This shows that (under the assumptions outlined above) the states of the MIMO S5 SSM are equivalent to the summation of the states across the different SISO S4 SSMs.
We can then consider the effect of the output matrix . For S5, the output matrix is a single dense matrix
We can substitute the relationship in (15) into (16) to cast the outputs of the MIMO S5 SSM in terms of the state of the SISO S4 SSMs:
Denoting the vertical concatenation of the S4 SSM state vectors , we see that the outputs of the S5 SSM are expressible as:
and hence are equivalent to a linear combination of the states computed by the S4 SSMs. ∎
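Both statements in this proof can be checked numerically. The sketch below assumes $H$-dimensional maps, a tied discrete-time state matrix (hence a tied timescale) shared by S5 and by the $H$ SISO SSMs, and an S5 input matrix whose columns are the S4 input vectors $\mathbf{b}_h$. The state matrix is diagonal only for simplicity; the linearity argument does not need it. All names and values are ours.

```python
import numpy as np

rng = np.random.default_rng(0)
H, N, L = 3, 5, 30                                 # features, shared state size, sequence length
A_bar = np.diag(rng.uniform(0.3, 0.95, size=N))    # tied discrete-time state matrix
B = rng.normal(size=(N, H))                        # column h plays the role of S4's input vector b_h
C = rng.normal(size=(H, N))                        # dense S5 output matrix
u = rng.normal(size=(L, H))

# S5: one MIMO SSM driven by all H input channels.
x5 = np.zeros(N)
X5 = np.zeros((L, N))
for k in range(L):
    x5 = A_bar @ x5 + B @ u[k]
    X5[k] = x5

# S4-style bank: H SISO SSMs; SSM h sees only input channel h, through b_h.
x4 = np.zeros((H, N))                              # row h is the state of SSM h
X4 = np.zeros((L, H, N))
for k in range(L):
    x4 = x4 @ A_bar.T + B.T * u[k][:, None]        # row h: A_bar x^(h) + b_h u_k^(h)
    X4[k] = x4

# (i) The S5 state is the sum of the SISO states.
state_sum = X4.sum(axis=1)                         # (L, N)
# (ii) The S5 output equals [C, C, ..., C] applied to the stacked SISO states.
Y_s5 = X5 @ C.T                                    # (L, H)
C_repeated = np.tile(C, (1, H))                    # (H, H * N)
Y_from_s4_states = X4.reshape(L, H * N) @ C_repeated.T
```

The horizontally repeated output matrix `C_repeated` is only a device for the comparison; as noted in the text below, neither layer would be implemented this way.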
This shows the outputs of the constrained S5 SSM under consideration (cf. Assumption 4) can be interpreted as a linear combination of the latent states computed by constrained S4 SSMs with the same state matrices and timescale parameters. Note, however, that it does not show that the outputs of the S5 SSM directly equal the outputs of the effective block-diagonal S4 SSM. Indeed, they are not equal, and we can repeat this analysis for the S4 layer to concretely identify the difference. For comparison we assume that the output vector for each S4 SSM is given as a row in the S5 output matrix, i.e. . We can express the output of each S4 SSM as
By inspecting (19) and (21), we can concretely express the difference in the equivalent output matrix used by both layers
In S4, the effective output matrix consists of independent vectors on the leading diagonal (as is pictured in Figure 2(a)). In contrast, the effective output matrix used by S5 instead ties dense output matrices across the S4 SSMs. As such, S5 can be interpreted as simply defining a different projection of the independent SISO SSMs than is used by S4. Note that both projection matrices have the same number of parameters.
Although the projection is different, the fact that the latent dynamics can still be interpreted as a linear projection of the same underlying S4 latent dynamics suggests that initializing the state dynamics in S5 with the HiPPO-LegS matrix may lead to good performance, similarly to what was observed in S4. We discuss this in the next section. We note that it is not obvious whether tying the dense output matrices is any more or less expressive than S4’s use of a single untied output vector for each SSM, and it is unlikely that one approach is universally better than the other. We also stress that one would never implement S4 using the block-diagonal matrix in (28), or S5 using the repeated matrix in (28). These matrices are simply constructs for understanding the equivalence between S4 and S5.
D.3 Diagonalizable Initialization
Proposition 2 suggests that initializing with the HiPPO-LegS matrix may yield good performance in S5, just as it does in S4 (because the constrained version of S5 under consideration is effectively a different linear projection of the same latent dynamics). However, the HiPPO-LegS matrix is not stably diagonalizable. Corollary 1 allows us to initialize MIMO SSMs with the diagonalizable HiPPO-N matrix to approximate the HiPPO-LegS matrix and expect the performance to be comparable.
Theorem 3 in Gu et al. (2022) shows the following relationship for scalar input signals as :
We first recall (15), which shows that the latent states of the MIMO S5 SSM are the summation of the latent states of the SISO S4 SSMs (to which Theorem 3 from Gu et al. (2022) applies). Although we derived (15) in discrete time, it applies equally in continuous time:
We can therefore define the derivative of the S5 state as:
This equivalence motivates initializing S5 state matrices with the diagonalizable HiPPO-N matrix and suggests that we can expect to see similar performance gains.
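The linearity step behind this argument can be written out explicitly. In the sketch below the notation is assumed for illustration: $\mathbf{x}^{(h)}(t)$ is the state of the $h$-th SISO SSM, $\mathbf{b}_h$ its input vector, $u^{(h)}(t)$ the $h$-th input channel, and $\mathbf{A}$ the shared state matrix.

```latex
\begin{aligned}
\frac{d\mathbf{x}(t)}{dt}
  &= \sum_{h=1}^{H} \frac{d\mathbf{x}^{(h)}(t)}{dt}
   = \sum_{h=1}^{H} \Big( \mathbf{A}\,\mathbf{x}^{(h)}(t) + \mathbf{b}_h\, u^{(h)}(t) \Big) \\
  &= \mathbf{A} \sum_{h=1}^{H} \mathbf{x}^{(h)}(t)
     + \big[\mathbf{b}_1, \ldots, \mathbf{b}_H\big]\, \mathbf{u}(t)
   = \mathbf{A}\,\mathbf{x}(t) + \mathbf{B}\,\mathbf{u}(t).
\end{aligned}
```

Because the MIMO state is a finite sum of SISO states, any statement that holds for each $\mathbf{x}^{(h)}(t)$ individually as the state size grows also holds for their sum.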
D.4 Relaxing the Assumptions
Here we discuss how relaxing the constraint on S5’s latent size from Assumption 4 helps to relax the assumptions on the tied S4 SSM state matrices (Assumption 2) and timescales (Assumption 3) as well as the tied output matrices that result from Proposition 2.
It follows from Proposition 2 that the dynamics of each of these S5 SSM subsystems can be related to the dynamics of a different S4 system from Proposition 2. Each of these S4 systems has its own bank of tied S4 SSMs (cf. Assumptions 2, 3). Importantly, each of the S4 systems can have its own state matrix, timescale parameter and output matrix shared across its S4 SSMs. Thus, the outputs of a dimensional S5 SSM can be equivalent to the linear combination of the latent states of different S4 systems from Proposition 2. This fact motivates the option to initialize a block-diagonal S5 state matrix with several HiPPO-N matrices on the blocks, rather than just initializing with one larger HiPPO-N matrix. In practice we found the block-diagonal initialization to improve performance on many tasks, see Appendix E.
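A minimal sketch of such a block-diagonal initialization, assuming $J$ equal-sized HiPPO-N blocks (the block count $J$ and all function names are ours) and the same HiPPO-N construction as used elsewhere in this document's sketches. Because every block is the same normal matrix, the eigenvalues repeat $J$ times and the unitary eigenvector matrix is itself block diagonal, so each block can be diagonalized independently; $J = 1$ recovers a single HiPPO-N matrix.

```python
import numpy as np

def hippo_n(N):
    """Normal part of HiPPO-LegS: HiPPO-LegS + p p^T with p_n = sqrt(n + 1/2)."""
    n = np.arange(N)
    q = np.sqrt(2 * n + 1.0)
    legs = -np.tril(np.outer(q, q), k=-1) - np.diag(n + 1.0)
    p = np.sqrt(n + 0.5)
    return legs + np.outer(p, p)

def block_diag_hippo_n(P, J):
    """Diagonalize a P x P block-diagonal matrix with J HiPPO-N blocks of size P // J."""
    assert P % J == 0, "P must be divisible by the number of blocks J"
    size = P // J
    block = hippo_n(size)
    mu, V_block = np.linalg.eigh(-1j * (block + 0.5 * np.eye(size)))  # Hermitian eigenproblem
    Lambda_block = -0.5 + 1j * mu
    Lambda = np.tile(Lambda_block, J)             # eigenvalues of each block, repeated J times
    V = np.kron(np.eye(J), V_block)               # block-diagonal unitary eigenvector matrix
    A = np.kron(np.eye(J), block)                 # the block-diagonal state matrix itself
    return Lambda, V, A

Lambda, V, A = block_diag_hippo_n(16, 4)
```

Only the initialization is block structured: once the eigenvalues are learned, each copy can drift to its own values.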
D.5 Timescale Parameterization
Finally, we take a closer look at the parameterization of the timescale parameters . As discussed in Section 4.3, S4 can learn a different timescale parameter for each S4 SSM, potentially allowing it to capture different timescales of the data. Further, the initialization of the timescales can be important (Gu et al., 2023; Gupta et al., 2022), and limiting to sampling a single initial parameter may lead to poor initialization. The discussion in the previous section motivates potentially learning different timescale parameters, one for each of the subsystems. However, in practice, we found better performance when using different timescale parameters, one for each of the states. On the one hand, this can be viewed simply as learning a different scaling for each of the eigenvalues in the diagonalized system (see Eq. (6)). On the other hand, this could be viewed as increasing the number of timescale parameters sampled at initialization, helping to combat the possibility of poor initialization. Of course, the system could learn to use just a single timescale by setting all of the timescales to be the same. See further discussion in the ablation study in Appendix E.
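To illustrate the per-state parameterization, the sketch below samples one timescale per latent state and scales each eigenvalue by its own timescale before ZOH discretization. The log-uniform range $[10^{-3}, 10^{-1}]$, learning $\log\Delta$ to keep $\Delta$ positive, and the placeholder eigenvalues are all assumptions made for the example.

```python
import numpy as np

def init_log_timescales(P, dt_min=1e-3, dt_max=1e-1, rng=None):
    """One log-timescale per latent state, sampled log-uniformly from [dt_min, dt_max]."""
    rng = np.random.default_rng() if rng is None else rng
    return rng.uniform(np.log(dt_min), np.log(dt_max), size=P)

rng = np.random.default_rng(0)
P = 8
Lambda = -0.5 + 1j * np.linspace(0.0, 10.0, P)   # placeholder continuous-time eigenvalues
log_Delta = init_log_timescales(P, rng=rng)      # the quantity a model would learn
Delta = np.exp(log_Delta)                        # always positive
Lambda_bar = np.exp(Lambda * Delta)              # ZOH: each eigenvalue scaled by its own timescale

# A single shared timescale is the special case in which all entries of Delta are equal.
Lambda_bar_shared = np.exp(Lambda * Delta[0])
```

Drawing $P$ timescales instead of one also spreads the initial values across the range, which is the "more samples at initialization" view described above.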
Appendix E Ablations
We perform several ablations to empirically explore different aspects of S5.
The discussion in Section 4 and Appendix D raises several interesting questions: How does S5 perform when the latent size is restricted to be equal to the latent size used by each of S4’s SSMs? How important is the timescale parameterization discussed in Appendix D.5? How important is the block-diagonal initialization? Table 5 displays the results of an ablation study performed on the LRA tasks to get a better sense of this. We consider 3 versions of S5.
Finally, the complexity analysis and runtime comparison in Appendix C.2 suggest that the latent size of S5 can be increased while maintaining similar complexity and practical runtimes as S4. We include the unconstrained version of S5 reported in our main results, which uses the settings reported in the hyperparameter Table 11. These models were allowed to be parameterized with fewer input/output features (to ensure similar parameter counts to the S4 baselines) and generally used larger latent sizes . Further, we swept over the use of a block-diagonal initialization or not and the number of blocks to use (where indicates no block-diagonal initialization was used). All models benefited from the block-diagonal initialization for the LRA tasks (see Table 11).
E.2 Importance of HiPPO-N and continuous-time parameterization
We perform a further ablation study to gain insight into the differences between S5 and prior attempts at parallelized linear RNNs (discussed in Section 5), focusing on what appear to be the distinguishing features: continuous-time parameterizations and HiPPO initializations. We compare different initializations of the state matrix: random Gaussian, random antisymmetric, and HiPPO-N. The antisymmetric initialization is interesting because prior work considered these matrices in RNNs for long-range dependencies (Chang et al., 2019), and because the HiPPO-LegS matrix can be parameterized in a way related to antisymmetric matrices (Gu et al., 2021a). Moreover, to compare with a setup closer to previous parallelized linear RNN work, we also consider a direct discrete-time parameterization of S5 that neither performs repeated discretization during training nor learns the timescale parameter $\Delta$. We present the results of this ablation study, along with S5, in Table 6. We consider three of the LRA tasks, which vary in length and difficulty.
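To make the three state-matrix initializations concrete, the sketch below builds each one and reports the largest real part of its eigenvalues, which for a continuous-time system indicates whether some modes grow (positive) or all modes decay (negative). The $1/\sqrt{P}$ scaling of the random matrices and P = 64 are hypothetical choices. An antisymmetric matrix has purely imaginary eigenvalues, while HiPPO-N is $-\tfrac{1}{2}I$ plus an antisymmetric matrix, which is one way to see its connection to antisymmetric matrices.

```python
import numpy as np

rng = np.random.default_rng(0)
P = 64

def gaussian_init(P, rng):
    # Hypothetical 1/sqrt(P) scaling keeps the eigenvalues O(1).
    return rng.normal(size=(P, P)) / np.sqrt(P)

def antisymmetric_init(P, rng):
    W = rng.normal(size=(P, P)) / np.sqrt(P)
    return W - W.T  # A^T = -A, so all eigenvalues are purely imaginary

def hippo_n(P):
    p = np.sqrt(np.arange(P) + 0.5)
    lower = np.tril(np.outer(p, p), k=-1)
    return -0.5 * np.eye(P) - lower + lower.T  # -1/2 I plus an antisymmetric matrix

for name, A in [("gaussian", gaussian_init(P, rng)),
                ("antisymmetric", antisymmetric_init(P, rng)),
                ("hippo-n", hippo_n(P))]:
    eig = np.linalg.eigvals(A)
    print(f"{name:14s} max Re(lambda) = {eig.real.max():+.3f}")
```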
The main takeaway is that the only configuration that performs well consistently across all tasks, including solving Path-X, is S5 with the continuous-time parameterization and HiPPO initialization. We also observed that the discrete-time/HiPPO-N configuration was difficult to train due to stability issues, typically requiring a much lower learning rate.
E.3 S4D Initialization ablations
Finally, Gu et al. (2022) propose several diagonal alternatives to the diagonalized HiPPO-N matrix, including the S4D-Inv and S4D-Lin matrices. They perform an ablation on the LRA tasks in which they simply replace the diagonalized HiPPO-N matrix with the S4D-Inv or S4D-Lin matrix while keeping all other factors the same. We include these results in Table 7. In Table 7, we also include results for a similar ablation in S5, in which we use these matrices to initialize S5 in place of the HiPPO-N matrix while keeping all other factors constant. Both matrices perform well on most tasks, with the exception of the S4D-Lin matrix on Path-X. Interestingly, one of these S4D-Lin runs exceeded random guessing on Path-X, whereas the other runs did not. Exploring these and other matrices is an interesting direction for future work.
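For reference, a sketch of the two S4D initializations, using the formulas as we understand them from Gu et al. (2022): S4D-Lin sets $\lambda_n = -\tfrac{1}{2} + i\pi n$ and S4D-Inv sets $\lambda_n = -\tfrac{1}{2} + i\tfrac{N}{\pi}\left(\tfrac{N}{2n+1} - 1\right)$, with $n = 0, 1, \dots, N-1$ assumed. The exact indexing and state-size conventions should be checked against that paper; in the S5 ablation, such a vector takes the place of the diagonalized HiPPO-N eigenvalues.

```python
import numpy as np

def s4d_lin(N):
    """S4D-Lin: lambda_n = -1/2 + i*pi*n (frequencies spaced linearly)."""
    n = np.arange(N)
    return -0.5 + 1j * np.pi * n

def s4d_inv(N):
    """S4D-Inv: lambda_n = -1/2 + i*(N/pi)*(N/(2n+1) - 1) (inverse-law frequencies)."""
    n = np.arange(N)
    return -0.5 + 1j * (N / np.pi) * (N / (2 * n + 1) - 1)

Lambda_lin = s4d_lin(8)
Lambda_inv = s4d_inv(8)
```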
Appendix F Supplementary Results
We include further experimental results to supplement the results presented in the main text.
F.2 Extended Speech Results
F.3 Pendulum Extended Results
We also evaluate two ablations. S5-drop uses the same S5 architecture but drops the dependence on the inter-sample interval, so that every step is discretized in the same way regardless of the time elapsed between observations. We expect this network to perform poorly, as it has no knowledge of how much time has elapsed between observations. S5-append uses the same S5 architecture, but appends the integration timestep to the thirty-dimensional image encoding before it is input to the dense S5 input layer. In principle, we expect this network to perform as well as S5. However, doing so requires the S5 network to learn to process time, which may be difficult, especially in more complex domains. We include these ablations in the bottom partition of the extended pendulum regression results table.
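The difference between S5, S5-drop, and S5-append can be sketched with a diagonal SSM that recomputes its discretization at every step from the elapsed time $\delta t_k$. This sketch assumes the step size is the learned timescale multiplied by $\delta t_k$, that S5-drop corresponds to setting every $\delta t_k = 1$, and that the readout takes the real part of $C x_k$; the function name and all sizes are hypothetical.

```python
import numpy as np

def s5_irregular(Lambda, B_tilde, C, Delta, u, dt):
    """Diagonal SSM whose ZOH discretization changes with each observation gap.

    Lambda: (P,) complex; B_tilde: (P, H_in); C: (H_out, P); Delta: (P,) timescales;
    u: (L, H_in) inputs; dt: (L,) time elapsed since the previous observation.
    """
    x = np.zeros(Lambda.shape[0], dtype=complex)
    ys = []
    for u_k, dt_k in zip(u, dt):
        step = Delta * dt_k                      # assumed: timescale times elapsed time
        Lambda_bar = np.exp(Lambda * step)       # discretization recomputed every step
        B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B_tilde
        x = Lambda_bar * x + B_bar @ u_k
        ys.append((C @ x).real)                  # assumed real-valued readout
    return np.stack(ys)

rng = np.random.default_rng(0)
P, H, L = 4, 3, 10
Lambda = -0.5 + 1j * np.arange(P)
B_tilde = rng.normal(size=(P, H)) + 0j
C = rng.normal(size=(H, P)) + 0j
Delta = np.full(P, 0.1)
u = rng.normal(size=(L, H))
dt = rng.uniform(0.5, 2.0, size=L)               # irregular observation gaps

y_s5 = s5_irregular(Lambda, B_tilde, C, Delta, u, dt)
y_drop = s5_irregular(Lambda, B_tilde, C, Delta, u, np.ones(L))   # S5-drop
u_app = np.concatenate([u, dt[:, None]], axis=1)                  # S5-append input
B_app = rng.normal(size=(P, H + 1)) + 0j
y_append = s5_irregular(Lambda, B_app, C, Delta, u_app, np.ones(L))
```

In S5-append the elapsed time is available only as an extra input channel, so the network has to learn how to use it, whereas S5 builds it directly into the state update.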
Note that the runtimes quoted for the baseline methods (marked with a *) are as reported by Schirmer et al. (2022). These times are the total time for a training epoch, and hence include any time spent batching data. We re-ran the CRU using the original PyTorch code on the same hardware used for our JAX S5 experiments (labelled CRU (our run)), a single NVIDIA GeForce RTX 2080 Ti. For these runs (CRU (our run), S5, S5-drop and S5-append), we exclude the time spent batching the data in order to compare the runtimes of the models themselves more faithfully. Also note that our S5 experiments benefit from JAX compilation, but this alone is not sufficient to explain the difference in runtime.
F.4 Pixel-level 1-D Image Classification Results
Table 10 presents the results for pixel-level 1-D image classification, together with citations for the baseline results.
Appendix G Experiment Configurations
In this section we describe the experimental details. This includes the model architecture, general hyperparameters, specifics for each task, and information about the datasets.
For the experiments, we use the S5 layer as a drop-in replacement for the S4 layer in the sequence model architecture of Gu et al. (2021a). At a high level, this architecture consists of a linear encoder (to encode the input at each time step into features), multiple S5 layers, a mean pooling layer, a linear decoder, and a softmax operation for the classification tasks. The mean pooling layer compresses the output of the last S5 layer, of shape [batch size, sequence length, number of features ($H$)], across the sequence-length dimension, so that a single $H$-dimensional encoding is available for softmax classification.
Hyperparameter options such as the dropout rate, the choice of layer normalization or batch normalization, and the choice of pre-norm or post-norm are applied between the layers. Exceptions to the basic architecture described here are noted in the individual experiment sections below.
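The overall classification architecture can be sketched as below. The sequence layer is a hypothetical stand-in (a causal per-feature linear recurrence) rather than an S5 layer, layer normalization is used in place of batch normalization for brevity, and dropout and nonlinearities are omitted; only the linear encoder, the residual pre-norm/post-norm blocks, mean pooling, the linear decoder, and the softmax follow the description above.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def toy_sequence_layer(x, a=0.9):
    """Hypothetical stand-in for an S5 layer: a causal per-feature linear recurrence."""
    y = np.empty_like(x)
    state = np.zeros((x.shape[0], x.shape[2]))
    for k in range(x.shape[1]):
        state = a * state + (1 - a) * x[:, k]
        y[:, k] = state
    return y

def residual_block(x, layer, pre_norm=True):
    if pre_norm:                        # pre-norm: normalize the layer input
        return x + layer(layer_norm(x))
    return layer_norm(x + layer(x))     # post-norm: normalize after the residual sum

def classifier(u, W_enc, W_dec, n_layers=2, pre_norm=True):
    x = u @ W_enc                       # linear encoder: (B, L, D_in) -> (B, L, H)
    for _ in range(n_layers):
        x = residual_block(x, toy_sequence_layer, pre_norm)
    pooled = x.mean(axis=1)             # mean pool over the sequence: (B, H)
    logits = pooled @ W_dec             # linear decoder: (B, n_classes)
    logits -= logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
B, L, D_in, H, n_classes = 2, 16, 1, 8, 10
u = rng.normal(size=(B, L, D_in))
probs = classifier(u, rng.normal(size=(D_in, H)), rng.normal(size=(H, n_classes)))
```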
G.2 Default Hyperparameters
Table 11 presents the main hyperparameters used for each experiment. For all experiments, we ensure that the number of layers and the number of layer input/output features are less than or equal to those reported in Gu et al. (2022), and that parameter counts are comparable.
In general, the models for most tasks used batch normalization and pre-norm. Exceptions are noted in the individual experiment sections below.
G.2.2 Bidirectionality
We follow Gu et al. (2022) and use bidirectional models for the LRA and speech tasks. Unidirectional (causal) models were used for the pendulum, sequential MNIST, and permuted sequential MNIST tasks, for a fair comparison with prior methods that used unidirectional models.
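To illustrate the difference between the two settings, the sketch below contrasts a causal layer, whose output at step k depends only on inputs up to step k, with a bidirectional variant that also processes the time-reversed sequence. Summing the two directions and sharing one recurrence for both are simplifying assumptions; a bidirectional model could instead use separate parameters per direction or concatenate the outputs.

```python
import numpy as np

def causal_layer(x, a=0.9):
    """Hypothetical causal layer on an (L, H) sequence: y_k depends only on x_1..x_k."""
    y = np.empty_like(x)
    state = np.zeros(x.shape[1])
    for k in range(x.shape[0]):
        state = a * state + x[k]
        y[k] = state
    return y

def bidirectional_layer(x):
    """Forward pass plus a pass over the time-reversed sequence (assumed: summed)."""
    forward = causal_layer(x)
    backward = causal_layer(x[::-1])[::-1]
    return forward + backward
```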
G.3 Task Specific Hyperparameters
Here we specify any task-specific details, hyperparameter or architectural differences from the defaults outlined above.
G.3.2 Text
No exceptions to the defaults for this run.
G.3.3 Retrieval
This document matching task requires a slightly different architecture from the other experiments, as discussed in Tay et al. (2021). We use the same configuration as S4 (Gu et al., 2021a). Each string is passed through the input encoder, S5 layers, and mean pooling layers. Denoting $X_1$ as the output for the first document and $X_2$ as the output for the second document, four features are created and concatenated together (Tay et al., 2021) as
$$X = [X_1, X_2, X_1 * X_2, X_1 - X_2],$$
where $*$ denotes elementwise multiplication.
This concatenated feature is then fed to a linear decoder and softmax function as normal.
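A minimal sketch of this retrieval head, assuming mean-pooled encodings of shape (batch, H) for each document and a hypothetical linear decoder weight `W_dec` with two outputs:

```python
import numpy as np

def retrieval_logits(X1, X2, W_dec):
    """X1, X2: (B, H) mean-pooled encodings of the two documents."""
    features = np.concatenate([X1, X2, X1 * X2, X1 - X2], axis=-1)  # (B, 4H)
    return features @ W_dec                                          # (B, 2)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
B, H = 3, 8
X1, X2 = rng.normal(size=(B, H)), rng.normal(size=(B, H))
probs = softmax(retrieval_logits(X1, X2, rng.normal(size=(4 * H, 2))))
```

Because each document is encoded separately, the only interaction between the two texts happens in these four features.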
G.3.4 Image
G.3.5 Pathfinder
No exceptions to the defaults for this run.
G.3.6 Path-X
G.3.7 Speech Commands
G.3.8 Pendulum Regression
We use the same encoder-decoder architecture as Schirmer et al. (2022). The encoder has layers: convolution, ReLU, max pool, convolution, ReLU, max pool, dense, ReLU, dense. The first convolution layer has twelve features, a kernel, and a padding of . The second convolution layer has twelve features, a kernel, a stride of , and a padding of . Both max pools use a window size of and a stride of 2. The dense layer has thirty hidden units. The linear readout layer outputs features. This is chosen to match the encoding size in Schirmer et al. (2022), and is used for all layers. Separate mean and unconstrained variance decoders are used, defined as a one-layer MLP with a hidden size of thirty. An elu+1 activation function is used to constrain the variance to be positive.
Layer normalization and post-norm were used for this task.
For the timings presented in the main-text and appendix pendulum regression tables, we use a batch size of , instead of the batch size of used during training, to match the batch sizes reported by the baselines.
G.3.9 Sequential MNIST
No exceptions to the defaults for this run.
G.3.10 Permuted Sequential MNIST
G.3.11 Sequential CIFAR
We trained a model with exactly the same hyperparameter settings as used for the LRA-IMAGE (grayscale sequential CIFAR) task, with no further tuning.
G.4 Dataset Details
We provide more context and details for each of the LRA (Tay et al., 2021) and Speech Commands (Warden, 2018) datasets we consider. Note that we follow the same data pre-processing steps as Gu et al. (2021a), which we also include here for completeness.
Text: Based on the IMDb sentiment dataset presented by Maas et al. (2011). Given a movie review, where characters are encoded as a sequence of integer tokens, classify whether the movie review is positive or negative. Characters are encoded as one-hot vectors, with unique values possible. Sequences are of unequal length, and are padded to a maximum length of . There are two different classes, representing positive and negative sentiment. There are training examples and test examples. No validation set is provided. No normalization is applied.
Retrieval: Based on the ACL Anthology Network corpus presented by Radev et al. (2009). Given two textual citations, where characters are encoded as a sequence of integer tokens, classify whether the two citations are equivalent. The citations must be compressed separately before being passed into a final classifier layer, in order to evaluate how effectively the network can represent the text. The decoder head then uses the encoded representations to complete the task. Characters are encoded into a one-hot vector with unique values. The two paired sequences may be of unequal length, with a maximum sequence length of . There are two different classes, representing whether the citations are equivalent or not. There are training pairs, validation pairs, and test pairs. No normalization is applied.
Image: Uses the CIFAR-10 dataset presented by Krizhevsky (2009). Given a grayscale CIFAR-10 image as a one-dimensional raster scan, classify the image into one of ten classes. Sequences are of equal length (1024). There are ten different classes. There are training examples, validation examples, and test examples. RGB pixel values are converted to grayscale intensities, which are then normalized to have zero mean and unit variance (across the entire dataset).
Pathfinder: Based on the Pathfinder challenge introduced by Linsley et al. (2018). A 32 × 32 grayscale image shows a start point and an end point, each drawn as a small circle. There are a number of dashed lines on the image. The task is to classify whether there is a dashed line (or path) joining the start and end points. There are two different classes, indicating whether there is a valid path or not. Sequences are all of the same length (1024). There are training examples, validation examples, and test examples. The data is normalized to be in the range $[-1, 1]$.
Path-X: An “extreme” version of the Pathfinder challenge. Here, the images are 128 × 128 pixels, resulting in sequences that are sixteen times longer (16384). Otherwise identical to the Pathfinder challenge.
Speech Commands: Based on the dataset released by Warden (2018). Speakers recite one of 35 words, and the task is to classify which of the 35 words was spoken from a one-dimensional audio recording. There are 35 different classes, each representing one of the words in the vocabulary. Sequences are all of the same length (16000). There are training examples, validation examples, and test examples. Data is normalized to have zero mean and a standard deviation of 1.
Speech Commands 0.5: A temporally sub-sampled version of Speech Commands, in which only the validation and test datasets are sub-sampled, by a factor of 2, and are therefore shortened to length 8000. No subsequent padding is applied. The training dataset is not sub-sampled.
Sequential MNIST (sMNIST): 10-way digit classification from a 28 × 28 grayscale image of a handwritten digit, where the input image is flattened into a 784-length scalar sequence.
Permuted Sequential MNIST (psMNIST): 10-way digit classification from a 28 × 28 grayscale image of a handwritten digit, where the input image is flattened into a 784-length scalar sequence. This sequence is then permuted using a fixed order.
Sequential CIFAR (sCIFAR): 10-way image classification using the CIFAR-10 dataset. Identical to Image, except that full-colour images are input as a 1024-length input sequence, where each input is an (R,G,B) triple.
Pendulum Regression: Reproduced from Becker et al. (2019) and Schirmer et al. (2022). The input sequence is a grayscale rendering of a pendulum, driven by a random torque process. The image pixels are corrupted by a noise process that is correlated in time. The pendulum is simulated for timesteps, and frames are irregularly sampled without replacement from the simulation. The objective is to estimate the sine and cosine of the angle of the pendulum. A train/validation/test split of is used.
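The sequence construction for the pixel-level tasks can be sketched as below, assuming row-major raster scanning. The random arrays stand in for real images, and the seed used to draw the fixed psMNIST permutation is arbitrary; the point is that one permutation is drawn once and applied to every image.

```python
import numpy as np

rng = np.random.default_rng(0)

def raster_scan(image):
    """Flatten an (rows, cols[, channels]) image row by row into a (rows*cols, channels) sequence."""
    rows, cols = image.shape[:2]
    return image.reshape(rows * cols, -1)

mnist_like = rng.integers(0, 256, size=(28, 28)).astype(np.float32)     # placeholder image
cifar_like = rng.integers(0, 256, size=(32, 32, 3)).astype(np.float32)  # placeholder image

smnist = raster_scan(mnist_like)                    # (784, 1) scalar sequence
perm = np.random.default_rng(1234).permutation(784)  # drawn once, reused for every image
psmnist = smnist[perm]                              # (784, 1) permuted sequence
scifar = raster_scan(cifar_like)                    # (1024, 3): one (R, G, B) triple per step
```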
Appendix H Background on Parallel Scans for Linear Recurrences
For the interested reader, this section provides more background on using a parallel scan for a linear recurrence, as well as a simple example to illustrate how it can compute the recurrence in parallel. The parallelization of scan operations has been well studied (Ladner & Fischer, 1980; Lakshmivarahan & Dhall, 1994; Blelloch, 1990), and many standard scientific computing libraries contain efficient implementations. We note the linear recurrence we consider here is a specific instance of the more general setting discussed in Section 1.4 of Blelloch (1990).
Computing a general parallel scan requires defining two objects:
The initial elements the scan will operate on.
A binary associative operator used to combine the elements.
To compute a length $L$ linear recurrence, $x_k = \bar{A} x_{k-1} + \bar{B} u_k$, we will define the initial elements, $c_{1:L}$, such that each element is the tuple
$$c_k = (c_{k,a}, c_{k,b}) := (\bar{A}, \bar{B} u_k).$$
These will be precomputed prior to the scan. Having created the list of elements for the scan to operate on, we define the binary operator $\bullet$ for the scan to use on this linear recurrence as
$$q_i \bullet q_j := (q_{j,a} \odot q_{i,a},\; q_{j,a} \otimes q_{i,b} + q_{j,b}),$$
where $q_i$ denotes an input element to the operator that could be one of the initial elements $c_i$ or some intermediate result, $\odot$ denotes matrix-matrix multiplication, $\otimes$ denotes matrix-vector multiplication and $+$ denotes elementwise addition. We show that this operator is associative at the end of this section.
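A direct NumPy transcription of the initial elements and the binary operator, with both $\odot$ and $\otimes$ implemented by `@`, is sketched below; the function names and sizes are illustrative.

```python
import numpy as np

def make_elements(A_bar, B_bar, u):
    """Initial scan elements c_k = (A_bar, B_bar u_k), one per input u_k (rows of u)."""
    return [(A_bar, B_bar @ u_k) for u_k in u]

def binary_op(q_i, q_j):
    """q_i • q_j = (q_j,a q_i,a, q_j,a q_i,b + q_j,b)."""
    a_i, b_i = q_i
    a_j, b_j = q_j
    return a_j @ a_i, a_j @ b_i + b_j

rng = np.random.default_rng(0)
P, H, L = 3, 2, 4
A_bar = 0.5 * rng.normal(size=(P, P))
B_bar = rng.normal(size=(P, H))
u = rng.normal(size=(L, H))
c = make_elements(A_bar, B_bar, u)
```

Note the order of the products: the later element's matrix multiplies from the left, matching how $\bar{A}$ is applied to the earlier state.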
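We can illustrate how $\bullet$ can be used to compute a linear recurrence in parallel with a simple example. Consider the system $x_k = \bar{A} x_{k-1} + \bar{B} u_k$, and a length $L = 4$ sequence of inputs $u_{1:4}$. Assuming $x_0 = 0$, the desired latent states from this recurrence are:
$$\begin{aligned} x_1 &= \bar{B}u_1 \\ x_2 &= \bar{A}\bar{B}u_1 + \bar{B}u_2 \\ x_3 &= \bar{A}^2\bar{B}u_1 + \bar{A}\bar{B}u_2 + \bar{B}u_3 \\ x_4 &= \bar{A}^3\bar{B}u_1 + \bar{A}^2\bar{B}u_2 + \bar{A}\bar{B}u_3 + \bar{B}u_4 \end{aligned}$$
We first note that $\bullet$ can be used to compute this recurrence sequentially. We can initialize the scan elements $c_{1:4}$ as in (39), and then sequentially scan over these elements to compute the output elements $s_{1:4}$. Defining $s_0 := (I, 0)$, where $I$ is the identity matrix, we have for our example:
$$\begin{aligned} s_1 &= s_0 \bullet c_1 = (\bar{A},\; \bar{B}u_1) \\ s_2 &= s_1 \bullet c_2 = (\bar{A}^2,\; \bar{A}\bar{B}u_1 + \bar{B}u_2) \\ s_3 &= s_2 \bullet c_3 = (\bar{A}^3,\; \bar{A}^2\bar{B}u_1 + \bar{A}\bar{B}u_2 + \bar{B}u_3) \\ s_4 &= s_3 \bullet c_4 = (\bar{A}^4,\; \bar{A}^3\bar{B}u_1 + \bar{A}^2\bar{B}u_2 + \bar{A}\bar{B}u_3 + \bar{B}u_4) \end{aligned}$$
Note that the second entry of each of the output tuples, $s_{k,b}$, contains the desired $x_k$ computed above. Computing the scan in this way requires four sequential steps, since each $s_k$ depends on $s_{k-1}$.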
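The sequential scan can be sketched as follows and checked against the recurrence $x_k = \bar{A} x_{k-1} + \bar{B} u_k$ computed directly; each loop iteration needs the result of the previous one, so the loop has $L$ dependent steps. Names and sizes are illustrative.

```python
import numpy as np

def make_elements(A_bar, B_bar, u):
    return [(A_bar, B_bar @ u_k) for u_k in u]   # c_k = (A_bar, B_bar u_k)

def binary_op(q_i, q_j):
    a_i, b_i = q_i
    a_j, b_j = q_j
    return a_j @ a_i, a_j @ b_i + b_j            # q_i • q_j

def sequential_scan(elems, P):
    s = (np.eye(P), np.zeros(P))                 # s_0 = (I, 0)
    outputs = []
    for c_k in elems:                             # s_k = s_{k-1} • c_k
        s = binary_op(s, c_k)
        outputs.append(s)
    return outputs

rng = np.random.default_rng(0)
P, H, L = 3, 2, 4
A_bar = 0.5 * rng.normal(size=(P, P))
B_bar = rng.normal(size=(P, H))
u = rng.normal(size=(L, H))

x_scan = np.stack([b for _, b in sequential_scan(make_elements(A_bar, B_bar, u), P)])

# Reference: x_k = A_bar x_{k-1} + B_bar u_k with x_0 = 0.
x, x_direct = np.zeros(P), []
for u_k in u:
    x = A_bar @ x + B_bar @ u_k
    x_direct.append(x)
x_direct = np.stack(x_direct)
```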
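Now consider how we can use this binary operator to compute the recurrence in parallel. We will label the output elements of the parallel scan as $r_{1:4}$ and define $r_0 := (I, 0)$. We will first compute the even indexed elements $r_2$ and $r_4$, and then compute the odd indexed elements $r_1$ and $r_3$. We start by applying the binary operator $\bullet$ to adjacent pairs of our initial elements to compute $r_2$ and the intermediate result $q_{3:4}$, and we then repeat this process to compute $r_4$ by applying $\bullet$ to $r_2$ and $q_{3:4}$:
$$\begin{aligned} r_2 &= c_1 \bullet c_2 = (\bar{A}^2,\; \bar{A}\bar{B}u_1 + \bar{B}u_2) \\ q_{3:4} &= c_3 \bullet c_4 = (\bar{A}^2,\; \bar{A}\bar{B}u_3 + \bar{B}u_4) \\ r_4 &= r_2 \bullet q_{3:4} = (\bar{A}^4,\; \bar{A}^3\bar{B}u_1 + \bar{A}^2\bar{B}u_2 + \bar{A}\bar{B}u_3 + \bar{B}u_4) \end{aligned}$$
Now we will compute the odd indexed elements $r_1$ and $r_3$, using the even indexed $r_0$ and $r_2$, as $r_{2k+1} = r_{2k} \bullet c_{2k+1}$:
$$\begin{aligned} r_1 &= r_0 \bullet c_1 = (\bar{A},\; \bar{B}u_1) \\ r_3 &= r_2 \bullet c_3 = (\bar{A}^3,\; \bar{A}^2\bar{B}u_1 + \bar{A}\bar{B}u_2 + \bar{B}u_3) \end{aligned}$$
Note that the second entry of each of the output tuples, $r_{k,b}$, corresponds to the desired $x_k$. Inspecting the required dependencies for each application of $\bullet$, we see that $r_2$ and the intermediate result $q_{3:4}$ can be computed in parallel. Once $r_2$ and $q_{3:4}$ are computed, $r_1$, $r_3$ and $r_4$ can all be computed in parallel. We have therefore reduced the number of sequential steps required from four in the sequential scan version to two in the parallel scan version. This reduction in sequential steps becomes important when the sequence length is large since, given sufficient processors, the parallel time scales logarithmically with the sequence length.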
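The even/odd scheme generalizes recursively to any length: combine adjacent pairs, scan the pairs to obtain the even-indexed outputs, then fill in each odd-indexed output with one more application of $\bullet$. The sketch below runs sequentially in Python, but the applications within the pairing step and within the final loop are mutually independent, which is what a parallel implementation exploits; in practice a library routine such as JAX's `jax.lax.associative_scan` would typically be used instead. The function names are illustrative.

```python
import numpy as np

def make_elements(A_bar, B_bar, u):
    return [(A_bar, B_bar @ u_k) for u_k in u]   # c_k = (A_bar, B_bar u_k)

def binary_op(q_i, q_j):
    a_i, b_i = q_i
    a_j, b_j = q_j
    return a_j @ a_i, a_j @ b_i + b_j            # q_i • q_j

def parallel_scan(elems):
    """Inclusive scan returning [r_1, ..., r_L] with r_k = c_1 • ... • c_k."""
    L = len(elems)
    if L == 1:
        return list(elems)
    # Pairing step: c_1•c_2, c_3•c_4, ...; these applications are independent.
    pairs = [binary_op(elems[2 * i], elems[2 * i + 1]) for i in range(L // 2)]
    # Scanning the pairs yields the even-indexed outputs r_2, r_4, ...
    evens = parallel_scan(pairs)
    out = [elems[0]]                              # r_1 = r_0 • c_1 = c_1
    for k in range(1, L):                         # 0-based position k holds r_{k+1}
        if k % 2 == 1:
            out.append(evens[k // 2])             # even-indexed r_{k+1}
        else:
            out.append(binary_op(evens[k // 2 - 1], elems[k]))  # r_{k+1} = r_k • c_{k+1}
    return out

rng = np.random.default_rng(0)
P, H = 3, 2
A_bar = 0.5 * rng.normal(size=(P, P))
B_bar = rng.normal(size=(P, H))
r = parallel_scan(make_elements(A_bar, B_bar, rng.normal(size=(4, H))))
```

For $L = 4$ the recursion performs exactly the combinations listed above: $r_2$ and $q_{3:4}$ first, then $r_4 = r_2 \bullet q_{3:4}$ and $r_3 = r_2 \bullet c_3$.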
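Finally, for completeness, we show that the binary operator $\bullet$ is associative:
$$\begin{aligned} (q_i \bullet q_j) \bullet q_k &= (q_{j,a} \odot q_{i,a},\; q_{j,a} \otimes q_{i,b} + q_{j,b}) \bullet q_k \\ &= \big(q_{k,a} \odot (q_{j,a} \odot q_{i,a}),\; q_{k,a} \otimes (q_{j,a} \otimes q_{i,b} + q_{j,b}) + q_{k,b}\big) \\ &= \big((q_{k,a} \odot q_{j,a}) \odot q_{i,a},\; (q_{k,a} \odot q_{j,a}) \otimes q_{i,b} + q_{k,a} \otimes q_{j,b} + q_{k,b}\big) \\ &= q_i \bullet (q_{k,a} \odot q_{j,a},\; q_{k,a} \otimes q_{j,b} + q_{k,b}) \\ &= q_i \bullet (q_j \bullet q_k). \end{aligned}$$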
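The identity can also be checked numerically. The sketch below evaluates both groupings for random inputs, and also includes the variant for a diagonal state matrix stored as a vector, where both $\odot$ and $\otimes$ reduce to elementwise products; the sizes are arbitrary.

```python
import numpy as np

def binary_op(q_i, q_j):
    """Dense case: a-entries are (P, P) matrices, b-entries are (P,) vectors."""
    a_i, b_i = q_i
    a_j, b_j = q_j
    return a_j @ a_i, a_j @ b_i + b_j

def binary_op_diag(q_i, q_j):
    """Diagonal case: a-entries are diagonals stored as (P,) vectors."""
    a_i, b_i = q_i
    a_j, b_j = q_j
    return a_j * a_i, a_j * b_i + b_j

rng = np.random.default_rng(0)
P = 4
q_i, q_j, q_k = [(rng.normal(size=(P, P)), rng.normal(size=P)) for _ in range(3)]
lhs = binary_op(binary_op(q_i, q_j), q_k)   # (q_i • q_j) • q_k
rhs = binary_op(q_i, binary_op(q_j, q_k))   # q_i • (q_j • q_k)
```

Associativity is what lets a scan regroup the applications freely; the operator is not commutative, so the order of the elements must still be preserved.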