Repeat After Me: Transformers are Better than State Space Models at Copying
Samy Jelassi, David Brandfonbrener, Sham M. Kakade, Eran Malach
Introduction
Transformers (Vaswani et al., 2017) are the workhorse of modern sequence modeling, achieving remarkable performance on a variety of tasks, but they have unavoidable inefficiencies. Specifically, they require $\Omega(L)$ memory and compute to predict the next token of a sequence of length $L$. (In some naive implementations of transformers, it is common to allocate an $L \times L$ matrix to compute the attention. However, memory-efficient implementations, such as FlashAttention (Dao et al., 2022), compute the attention with $O(L)$ memory.)
This has spurred a boom in attempts to create architectures that can achieve performance similar to transformers, but with $O(1)$ memory to predict each token. This class of models includes state space models like S4 (Gu et al., 2021) or Mamba (Gu & Dao, 2023), as well as traditional RNN models (Hochreiter & Schmidhuber, 1997) and models that can be trained in parallel like linear attention (Katharopoulos et al., 2020; Choromanski et al., 2020) and parallel RNNs (Bradbury et al., 2016; Peng et al., 2023; Sun et al., 2023). In this paper, we will refer to this entire class of models that use a fixed-size memory as “generalized state space models” or GSSMs (see a formal definition in Section 2).
Recent work has demonstrated impressive performance of GSSMs, but it is not yet clear what these models sacrifice for their improved efficiency, if anything. In this paper, we find that one particular capability that is sacrificed is the ability to retrieve and repeat parts of the input context. As a result, transformers are better than GSSMs at a variety of tasks that require accessing arbitrary parts of the context.
To understand this gap in capabilities, we begin by presenting a theoretical analysis of the copying task. (Note that we study copying of the input and not copying of training data (McCoy et al., 2023; Carlini et al., 2022).) First, we show via construction that a simple transformer model can copy strings whose length is exponential in the number of heads of the transformer. This construction relies on the ability of the transformer to implement a mechanism of “storage” and retrieval of sequences of n tokens (n-grams), where the n-grams are used to track where to copy from. In contrast, we show that, trivially, GSSMs cannot accurately copy strings with more bits than the size of the latent state.
Our theory studies representational expressivity, but not whether these representations will be learned. Moreover, in practice a large GSSM may have enough capacity to represent the entire input in the latent state, at least in theory. To resolve these concerns, we conduct a variety of synthetic experiments with models of 160M parameters. We find that transformers are both much more efficient at learning to copy (Figure 1(a)) and also generalize better to longer inputs (Figure 1(b)). Additionally, we verify experimentally that the copy “algorithm” learned by transformers indeed relies on n-grams to perform a lookup of where to copy from (Figure 3), similarly to our theoretical construction.
Finally, we present a variety of experiments on pre-trained models to test their ability to remember and access the input context. In particular, we show that Pythia transformers (Biderman et al., 2023) outperform Mamba GSSMs (Gu & Dao, 2023) of similar size at a variety of memory-intensive tasks including copying and retrieving information from the context (Figure 1(c)). This is especially notable since the Mamba models achieve lower perplexity than the Pythia models at language modeling on the Pile (Gao et al., 2020). These experiments illustrate the practical relevance of the memory issues that we raise, and hint at one way that architectural choices can impact the downstream performance of LLMs above and beyond training perplexity.
Theory: Representational Capacity
In this section we use the copy task for a theoretical comparison between state space models and transformers. We prove two main results. First, we construct a small transformer that solves the copy task for sequence lengths that are exponential in the transformer size. Second, we show that any state space model fails to solve the copy task, unless its latent state grows linearly with the sequence length.
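Let $\mathbb{D}$ be a dictionary, which contains $D$ “alphabet” tokens. A sequence-to-sequence model is a function $H: \mathbb{D}^* \to \mathbb{D}^*$, which maps an input sequence of tokens to an output sequence. We think of the input $x_1, \dots, x_i$ as the “prompt” to the model, and of the output sequence $H(x_1, \dots, x_i)$ as the generated “answer”.
A sequence-to-token mapping is a function $h: \mathbb{D}^* \to \mathbb{D}$. Any sequence-to-token model $h$ naturally defines a sequence-to-sequence model $H$ by auto-regressive inference. Namely, for every input sequence $x_1, \dots, x_i \in \mathbb{D}$ we define recursively $x_{i+j} = h(x_1, \dots, x_{i+j-1})$ and let $H(x_1, \dots, x_i) = (x_{i+1}, x_{i+2}, \dots)$.
A state space is some finite set $\mathcal{S}$. We denote by $\mathrm{mem}(\mathcal{S})$ the number of bits required to encode the states of $\mathcal{S}$, namely $\mathrm{mem}(\mathcal{S}) = \log(|\mathcal{S}|)$. A generalized state space model (GSSM) is a sequence model defined by an update rule $u: \mathcal{S} \times \mathbb{D} \to \mathcal{S}$ and some output function $r: \mathcal{S} \to \mathbb{D}$. Let $s_0 \in \mathcal{S}$ be some initial state. Given some sequence $x_1, \dots, x_L$, the state of the model at iteration $i$ is denoted by $S_i(x_1, \dots, x_i)$ and the output token is denoted by $R_i(x_1, \dots, x_i)$. The state and output are defined recursively: 1) $S_0(\emptyset) = s_0$, 2) $S_i(x_1, \dots, x_i) = u(S_{i-1}(x_1, \dots, x_{i-1}), x_i)$, 3) $R_i(x_1, \dots, x_i) = r(S_i(x_1, \dots, x_i))$.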
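The two definitions above can be sketched together in a few lines of Python: an auto-regressive wrapper that turns a sequence-to-token map $h$ into a sequence-to-sequence map $H$, and a scan that computes the GSSM states $S_i$ and outputs $R_i$ from an update rule $u$, an output function $r$ and an initial state $s_0$. This is an illustration under our own naming (`autoregress`, `gssm_scan`); the running-maximum GSSM and the repeat-last-token map in the usage lines are hypothetical toy examples, not models from the paper.

```python
from typing import Any, Callable, List, Sequence, Tuple


def autoregress(h: Callable[[List[Any]], Any], prompt: Sequence[Any], steps: int) -> List[Any]:
    """Sequence-to-sequence map H induced by a sequence-to-token map h.

    Implements x_{i+j} = h(x_1, ..., x_{i+j-1}) and returns the generated tokens.
    """
    seq = list(prompt)
    for _ in range(steps):
        seq.append(h(seq))
    return seq[len(prompt):]


def gssm_scan(u: Callable[[Any, Any], Any], r: Callable[[Any], Any],
              s0: Any, xs: Sequence[Any]) -> Tuple[List[Any], List[Any]]:
    """States S_1..S_L and outputs R_1..R_L of the GSSM (u, r, s0) on x_1..x_L."""
    states, outputs = [], []
    s = s0                       # S_0 = s_0
    for x in xs:
        s = u(s, x)              # S_i = u(S_{i-1}, x_i)
        states.append(s)
        outputs.append(r(s))     # R_i = r(S_i)
    return states, outputs


# Hypothetical toy GSSM over integer tokens whose state is the running maximum.
_, outputs = gssm_scan(lambda s, x: max(s, x), lambda s: s, 0, [3, 1, 4, 1, 5])
print(outputs)                                        # [3, 3, 4, 4, 5]

# Hypothetical sequence-to-token map that repeats the last token.
print(autoregress(lambda seq: seq[-1], [7, 8], 3))    # [8, 8, 8]
```

Only `s` is carried from one step to the next, however long `xs` is; this carried value is the input-dependent memory $\mathrm{mem}(\mathcal{S})$ that the definition constrains.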
It is important to note that for any sequence model, there are two types of memory considerations: 1) input-independent memory (parameters) and 2) input-dependent memory (activations). The GSSM definition constrains the input-dependent memory (activations), which corresponds to $\mathrm{mem}(\mathcal{S})$, and does not restrict in any way the amount of input-independent memory (parameters) or the run-time of state updates. Since our main goal is to show a lower bound on the state space memory, leaving all other considerations unconstrained only strengthens our results.
Transformers.
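Given some input $X \in \mathbb{R}^{L \times d}$ of length $L$ and dimension $d$, denoted $\mathbf{x}_1, \dots, \mathbf{x}_L$, an attention head is parameterized by $W_k, W_q, W_v \in \mathbb{R}^{d \times d}$. We denote $\mathbf{k}_i = W_k \mathbf{x}_i$, $\mathbf{q}_i = W_q \mathbf{x}_i$, $\mathbf{v}_i = W_v \mathbf{x}_i$, and denote $K_i = [\mathbf{k}_1, \dots, \mathbf{k}_i] \in \mathbb{R}^{d \times i}$ and $V_i = [\mathbf{v}_1, \dots, \mathbf{v}_i] \in \mathbb{R}^{d \times i}$. We denote the output of the head at token $i$ by $\mathbf{o}_i \in \mathbb{R}^d$, where $\mathbf{o}_i = V_i \cdot \mathrm{softmax}(K_i^\top \mathbf{q}_i)$.
We consider a transformer with $l$ attention heads, each one of dimension $d$, so that the full dimension of the transformer is $dl$. An embedding is some mapping $\Psi: \mathbb{D} \to \mathbb{R}^d$. An MLP is a function $f: \mathbb{R}^{dl} \to \mathbb{R}^{dl}$ s.t. $f(\mathbf{x}) = U_1 \sigma(U_2 \mathbf{x})$, for some activation function $\sigma$. Both the embedding and the MLP layer are assumed to be applied on the token level. An attention-block is a set of $l$ heads applied in parallel, and a transformer-block is an attention-block followed by an MLP which operates on the concatenated output of the $l$ heads. The output of the model is sampled based on the output of the final layer. For simplicity, we study argmax “sampling” (i.e., predicting the most probable token).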
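A single causal head, whose output at token $i$ is $\mathbf{o}_i = V_i\,\mathrm{softmax}(K_i^\top \mathbf{q}_i)$ over the keys and values of tokens $1, \dots, i$, can be written in plain Python. This is a sketch for illustration only (nested lists instead of tensors, no batching, no positional bias); the weight matrices in the usage lines are arbitrary choices.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(W: Matrix, x: Vector) -> Vector:
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def softmax(z: Vector) -> Vector:
    top = max(z)                      # subtract the max for numerical stability
    e = [math.exp(v - top) for v in z]
    total = sum(e)
    return [v / total for v in e]


def attention_head(X: List[Vector], Wq: Matrix, Wk: Matrix, Wv: Matrix) -> List[Vector]:
    """Causal head: o_i = V_i softmax(K_i^T q_i), using only tokens 1..i."""
    qs = [matvec(Wq, x) for x in X]
    ks = [matvec(Wk, x) for x in X]
    vs = [matvec(Wv, x) for x in X]
    d = len(vs[0])
    outputs = []
    for i in range(len(X)):
        scores = [sum(a * b for a, b in zip(ks[j], qs[i])) for j in range(i + 1)]  # K_i^T q_i
        weights = softmax(scores)
        outputs.append([sum(weights[j] * vs[j][c] for j in range(i + 1)) for c in range(d)])
    return outputs


# With W_q = 0 every score is 0, so each token averages the values of its prefix.
I2 = [[1.0, 0.0], [0.0, 1.0]]
Z2 = [[0.0, 0.0], [0.0, 0.0]]
print(attention_head([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], Z2, I2, I2))
```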
The copy task.
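To define the copy task, we add two special tokens to $\mathbb{D}$: (1) a beginning-of-sequence token, denoted $\langle\mathrm{BOS}\rangle$, and (2) a copy token, denoted $\langle\mathrm{COPY}\rangle$. So now $|\mathbb{D}| = D + 2$. A length-$L$ copy distribution $\mathcal{D}_L$ over $\mathbb{D}^{L+2}$ generates strings of the form “$\langle\mathrm{BOS}\rangle, x_1, \dots, x_L, \langle\mathrm{COPY}\rangle$”, where $\mathbf{x} \in (\mathbb{D} \setminus \{\langle\mathrm{BOS}\rangle, \langle\mathrm{COPY}\rangle\})^L$.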
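For some sequence-to-sequence model $H$, we denote the error of $H$ on a copy distribution $\mathcal{D}_L$ by

$$\mathrm{err}_{\mathcal{D}_L}(H) = \Pr_{\mathcal{D}_L}\left[H_{1:L}\left(\langle\mathrm{BOS}\rangle, \mathbf{x}, \langle\mathrm{COPY}\rangle\right) \neq \mathbf{x}\right],$$

where $H_{1:L}(\cdot)$ denotes the first $L$ tokens generated by $H$. That is, we expect the model to output an exact copy of $\mathbf{x}$.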
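The copy error can be estimated by Monte Carlo sampling: draw $\mathbf{x}$, build the prompt $\langle\mathrm{BOS}\rangle, \mathbf{x}, \langle\mathrm{COPY}\rangle$, and count how often the first $L$ generated tokens differ from $\mathbf{x}$. The sketch assumes the uniform copy distribution, represents the special tokens by the hypothetical strings `"<BOS>"` and `"<COPY>"`, and treats `H` as any callable mapping a prompt to its generated tokens.

```python
import random
from typing import Callable, List, Sequence, Tuple

BOS, COPY = "<BOS>", "<COPY>"   # hypothetical names for the special tokens


def sample_copy_prompt(L: int, alphabet: Sequence[str], rng: random.Random) -> Tuple[List[str], List[str]]:
    """Draw x uniformly from alphabet^L and return (prompt, target)."""
    x = [rng.choice(alphabet) for _ in range(L)]
    return [BOS, *x, COPY], x


def estimate_copy_error(H: Callable[[List[str]], List[str]], L: int,
                        alphabet: Sequence[str], trials: int = 1000, seed: int = 0) -> float:
    """Monte Carlo estimate of Pr[H_{1:L}(<BOS>, x, <COPY>) != x]."""
    rng = random.Random(seed)
    errors = 0
    for _ in range(trials):
        prompt, x = sample_copy_prompt(L, alphabet, rng)
        errors += H(prompt)[:L] != x      # only the first L generated tokens count
    return errors / trials


perfect = lambda p: p[1:-1]                    # reads x back out of the prompt
constant = lambda p: ["a"] * (len(p) - 2)      # always answers "aaa...a"
print(estimate_copy_error(perfect, 10, "ab"))  # 0.0
print(estimate_copy_error(constant, 10, "ab")) # close to 1 - 2**-10
```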
2.2 Transformers can copy inputs of exponential length
In this section we show that transformers can implement the copy operation for input sequences with length exponential in the number of heads. Namely, we construct a transformer with two blocks that gets small error on the copy task.
The key idea in the construction is to first “hash” sequences of $n$ tokens ($n$-grams), then at each iteration of the auto-regression attend to the previous occurrence of the most recent $n$-gram, and output the succeeding token. That is, we show that a transformer can implement the copying algorithm illustrated in Figure 3 (and see also Algorithm 1 in the Appendix).
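The store-then-retrieve idea can be simulated without any neural network. The sketch below is our own illustration in the spirit of the construction, not Algorithm 1 verbatim: every input token is stored under the $n$ tokens that precede it (left-padded with a hypothetical `<PAD>` symbol near the start), and generation repeatedly looks up the most recent $n$-gram. A Python dictionary stands in for the hash-based attention lookup.

```python
BOS, PAD = "<BOS>", "<PAD>"   # hypothetical special-token names


def context(seq, i, n):
    """The n tokens preceding position i of seq, left-padded with PAD."""
    window = seq[max(0, i - n):i]
    return (PAD,) * (n - len(window)) + tuple(window)


def ngram_copy(x: str, n: int) -> str:
    """Copy x by n-gram storage and retrieval (illustrative sketch)."""
    src = [BOS] + list(x)
    # Storage: key every input token by the n tokens that precede it.
    store = {}
    for i in range(1, len(src)):
        store.setdefault(context(src, i, n), src[i])   # keep the first occurrence
    # Retrieval: generate auto-regressively, each time looking up the latest n-gram.
    out = [BOS]
    for _ in range(len(x)):
        out.append(store.get(context(out, len(out), n), PAD))
    return "".join(out[1:])


print(ngram_copy("abcabd", 2))   # abcabc: the 2-gram "ab" repeats, so retrieval is ambiguous
print(ngram_copy("abcabd", 3))   # abcabd: all 3-grams are distinct, so the copy is exact
```

Because $\langle\mathrm{BOS}\rangle$ occurs once, the padded keys of the first positions are always unique; the only possible collisions are repeated $n$-grams inside $\mathbf{x}$, which is exactly the failure event that the guarantee below is stated in terms of.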
Positional embedding: Hard-ALiBi.
To perform the hashing described in the algorithm, we need to be able to leverage local positional information to define a hash, and also to apply this hash function globally on the entire input. To do this, we use a hard version of ALiBi (Press et al., 2021), which we call Hard-ALiBi. Just as in ALiBi, we add a bias $\mathbf{b}_i$ to the attention scores of the $i$-th token in each head: $\mathbf{o}_i = V_i \cdot \mathrm{softmax}(K_i^\top \mathbf{q}_i - \mathbf{b}_i)$. Specifically, for a head with window $m$, we set $\mathbf{b}_i$ s.t. $b_{i,j} = \infty$ for $j \le i - m$ and $b_{i,j} = 0$ for $j > i - m$, so that token $i$ attends only to the $m$ most recent tokens (itself included). We allow different heads with different choices of $m$ and also allow for $m = \infty$, which corresponds to softmax attention with no positional embedding. This is illustrated in Figure 8(c) (Appendix). While Hard-ALiBi is introduced for our theoretical construction, we observe that it also offers significant benefits empirically, as discussed in Section 3.
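The Hard-ALiBi bias of a single head can be sketched as a 0/$\infty$ matrix: with window $m$, token $i$ keeps keys $j$ with $i - m < j \le i$, and $m = \infty$ leaves the whole causal prefix (NoPE). This is a minimal plain-Python sketch with 0-indexed positions; masking future keys $j > i$ is the usual causal mask, which we add explicitly. The equal raw scores in the usage lines are chosen only to isolate the effect of the mask.

```python
import math


def hard_alibi_bias(L, m):
    """Bias b[i][j] of one Hard-ALiBi head over L positions (0-indexed).

    Keys with i - m < j <= i get bias 0; all others get +inf.
    m = math.inf keeps the full causal prefix (no positional information).
    """
    return [[0.0 if i - m < j <= i else math.inf for j in range(L)] for i in range(L)]


def attention_weights(scores, bias_row):
    """softmax(scores - b_i) for one query; exp(-inf) = 0 removes masked keys."""
    shifted = [s - b for s, b in zip(scores, bias_row)]
    top = max(shifted)               # finite, since key j = i is never masked
    w = [math.exp(v - top) for v in shifted]
    z = sum(w)
    return [v / z for v in w]


scores = [0.0] * 5
for m in (1, 2, math.inf):
    print(m, attention_weights(scores, hard_alibi_bias(5, m)[3]))
# m=1 attends only to token 3; m=2 splits evenly over tokens 2 and 3;
# m=inf spreads evenly over tokens 0..3.
```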
Guarantees.
The copy algorithm given in Algorithm 1 (and similarly, our transformer construction) can perfectly copy the input sequence, as long as there are no repeated $n$-gram patterns in the input. Therefore, the error of the algorithm depends on the probability of repeated $n$-grams:
Let $\mathcal{D}_L$ be some copy distribution. For some $n \in \mathbb{N}$, let $p_{\text{n-gram}}(\mathcal{D}_L)$ be the probability that $x_1, \dots, x_L$ contains two repeated sequences of $n$ tokens. Namely:

$$p_{\text{n-gram}}(\mathcal{D}_L) = \Pr_{\mathcal{D}_L}\left[\exists\, i \neq j \text{ s.t. } (x_i, \dots, x_{i+n-1}) = (x_j, \dots, x_{j+n-1})\right]$$

Below we state the main theoretical result on copying with transformers, showing that transformers can copy their input, with error bounded by the probability of repeated $n$-grams:
Theorem 2.3. For all $n$, there exists a depth-2 transformer $\mathcal{T}$ of dimension $O(n \log(D))$ s.t. for all $2n \le L \le D^n$, and for any copy distribution $\mathcal{D}_L$, $\mathrm{err}_{\mathcal{D}_L}(\mathcal{T}) \le p_{\text{n-gram}}(\mathcal{D}_L)$.
Intuitively, the probability of repeated $n$-grams decays quickly when increasing the value of $n$. Indeed, we show that for the uniform distribution over sequences, this probability decays exponentially with $n$:
Let $\mathcal{D}_L$ be the copy distribution generated by sampling $\mathbf{x}$ from the uniform distribution over the “alphabet” (non-special) tokens. Then, $p_{\text{n-gram}}(\mathcal{D}_L) < L^2 D^{-n}$.
Combining the above results, we get that transformers can copy sequences of tokens drawn from the uniform distribution, using a number of parameters that depends only logarithmically on the input sequence length.
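Fix some $\epsilon \in (0, 1/2)$ and some $L \ge \Omega(\log(1/\epsilon))$. There exists a depth-2 transformer $\mathcal{T}$ of dimension $O(\log(L/\epsilon)\log(D))$ s.t. for the uniform copy distribution $\mathcal{D}^{\mathrm{unif}}_L$, $\mathrm{err}_{\mathcal{D}^{\mathrm{unif}}_L}(\mathcal{T}) < \epsilon$.
For simplicity we do not limit the precision of the parameters or activations, but note that our results hold for finite-precision transformers, using $O(\log(L))$ bits.
2.3 State Space Models cannot copy inputs beyond memory size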
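The union bound $p_{\text{n-gram}} < L^2 D^{-n}$ for uniform strings, and the choice of $n$ it implies, can be checked numerically. The sketch below computes $p_{\text{n-gram}}$ exactly by enumerating all $D^L$ strings (so only tiny sizes are feasible) and finds the smallest $n$ with $L^2 D^{-n} < \epsilon$. The sizes are arbitrary illustrations; the output shows that multiplying $L$ by 10 adds only 2 to the required $n$, i.e. the logarithmic dependence in the corollary.

```python
from itertools import product


def has_repeated_ngram(x, n):
    """True if two different positions of x start the same n-gram."""
    grams = [tuple(x[i:i + n]) for i in range(len(x) - n + 1)]
    return len(grams) != len(set(grams))


def exact_p_ngram(L, D, n):
    """Exact p_{n-gram} for uniform strings over D symbols (enumerates all D**L strings)."""
    hits = sum(has_repeated_ngram(x, n) for x in product(range(D), repeat=L))
    return hits / D**L


def smallest_n(L, D, eps):
    """Smallest n whose union bound L^2 * D^(-n) is below eps."""
    n = 1
    while L**2 >= eps * D**n:
        n += 1
    return n


L, D, n = 6, 4, 3
print(exact_p_ngram(L, D, n), "<", L**2 / D**n)
print(smallest_n(300, 26, 0.01), smallest_n(3000, 26, 0.01))   # 5 7
```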
We saw that transformers are able to copy uniform sequences of tokens, with parameter count logarithmic in the sequence length. We now show that GSSMs cannot copy uniform input sequences, unless the capacity of their state space grows linearly with the sequence length. This is intuitive: to be able to copy the entire input sequence, the model needs to store it in its state space, which requires the memory to grow linearly with the sequence length.
Theorem 2.7. Fix some GSSM $H$ over state space $\mathcal{S}$. Then, for all $L$, for the uniform copy distribution $\mathcal{D}^{\mathrm{unif}}_L$, the model $H$ has error $\mathrm{err}_{\mathcal{D}^{\mathrm{unif}}_L}(H) \ge 1 - \frac{|\mathcal{S}|}{D^L}$.
Given Theorem 2.7, the following Corollary is immediate:
Corollary 2.8. Fix some $L \in \mathbb{N}$. Then, every GSSM $H$ with state space $\mathcal{S}$ s.t. $\mathrm{mem}(\mathcal{S}) < L \log(D) - 1$ has error $\mathrm{err}_{\mathcal{D}^{\mathrm{unif}}_L}(H) > 1/2$ for the uniform copy distribution $\mathcal{D}^{\mathrm{unif}}_L$.
As mentioned previously, the input-dependent memory of transformers grows linearly with the sequence length, which is less memory-efficient compared to GSSMs. However, it is interesting to note that from the above result, at least for the copy task, transformers are almost optimal in terms of their input-dependent memory. More specifically, an implication of Theorem 2.3 is that there exists a transformer which can copy inputs of length $L$ using $\tilde{O}(L)$ input-dependent memory (we use $\tilde{O}$ to hide logarithmic factors), and due to Corollary 2.8 this is indeed optimal (up to logarithmic factors).
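The counting argument behind Theorem 2.7 can be seen on a toy example: every token generated after $\langle\mathrm{COPY}\rangle$ is a function of the state reached at that point, so a GSSM copies correctly at most as many strings as it has distinct states there. The sketch uses a hypothetical GSSM (our own, not from the paper) that remembers only the last $k$ alphabet tokens, a read pointer and a copy-phase flag; it reaches at most $D^k$ states after $\langle\mathrm{COPY}\rangle$, and the exact error over all $D^L$ inputs respects $1 - D^k/D^L$.

```python
from itertools import product

BOS, COPY = "<BOS>", "<COPY>"   # hypothetical special-token names


def make_last_k_gssm(k):
    """Hypothetical toy GSSM; returns (s0, u, r). State: (last k tokens, pointer, copying)."""
    assert k >= 1
    s0 = ((), 0, False)

    def u(state, tok):
        buf, ptr, copying = state
        if tok == BOS:
            return ((), 0, False)
        if tok == COPY:
            return (buf, 0, True)
        if copying:                              # generated token fed back: advance pointer
            return (buf, ptr + 1, True)
        return ((buf + (tok,))[-k:], 0, False)   # keep only the last k input tokens

    def r(state):
        buf, ptr, copying = state
        return buf[ptr] if copying and ptr < len(buf) else 0   # 0: arbitrary default token

    return s0, u, r


def copy_output(s0, u, r, x):
    """Read <BOS> x <COPY>, then generate len(x) tokens auto-regressively."""
    s = s0
    for tok in (BOS, *x, COPY):
        s = u(s, tok)
    out = []
    for _ in x:
        y = r(s)
        out.append(y)
        s = u(s, y)
    return tuple(out)


def exact_copy_error(s0, u, r, L, D):
    """Error on the uniform copy distribution over {0, ..., D-1}^L, by enumeration."""
    inputs = list(product(range(D), repeat=L))
    return sum(copy_output(s0, u, r, x) != x for x in inputs) / len(inputs)


D, k = 3, 2
s0, u, r = make_last_k_gssm(k)
for L in range(1, 6):
    err = exact_copy_error(s0, u, r, L, D)
    assert err >= 1 - D**k / D**L     # at most D**k reachable states after <COPY>
    print(L, round(err, 4))           # 0 for L <= k, then close to 1
```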
Learning to Copy
In the previous section, we proved that transformers can represent the copy operation for exponentially long sequences, while GSSMs fail to copy long sequences due to their limited memory. While these results show that in theory, transformers can outperform GSSMs, our theoretical results do not establish that such a gap will be observed in practice for two reasons. First, it is not clear that transformers can indeed learn to copy from examples. Second, GSSMs in practice may use a large latent state memory, so that our bounds only hold for very long sequences of tokens. For example, a latent state of 1000 32-bit floating point numbers has enough bits to store at least 2000 tokens from a 50K token vocabulary. However, even though a GSSM could fit the context into memory, it may not learn to do so.
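The bit count in this example can be checked directly. A plain fixed-width code needs 16 bits per token for a 50K vocabulary, which gives exactly 2000 tokens; an ideal code needs about 15.61 bits and gives slightly more. This counts raw capacity only and says nothing about whether a trained GSSM can use its state this way.

```python
import math

state_bits = 1000 * 32                          # 1000 float32 numbers
bits_per_token = math.log2(50_000)              # ~15.61 bits to identify one of 50K tokens
print(state_bits // math.ceil(bits_per_token))  # 2000 tokens with a plain 16-bit code
print(math.floor(state_bits / bits_per_token))  # 2050 tokens with an ideal code
```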
Our goal in this section is to verify that our theoretical analysis bears out experimentally when training models from scratch on synthetic data, before moving on to study pretrained models in the next section. Specifically, we train transformers and GSSMs (LSTM (Hochreiter & Schmidhuber, 1997) and Mamba (Gu & Dao, 2023)) on variants of the copy task shown in Figure 2.
We now provide a brief overview of our experimental setup. Further details may be found in Appendix A.
In all our experiments, we set the model hyperparameters so that the Mamba and transformer models have a similar number of parameters ($\approx 160$ million parameters). Since we find that large LSTMs are hard to train (consistent with Pascanu et al. (2013)), we use the largest LSTM we managed to train, which has 4 layers and width 1024 (see Appendix A).
Dataset.
During training, we generate in an online manner a batch of 64 examples at each epoch. At test time, we evaluate our models on 10 batches of examples, and we report the mean and standard deviation over these 10 batches. If not specified otherwise, our token space is of size 30 and made of the alphabet letters together with special tokens, including a beginning-of-sentence token $\langle\mathrm{BOS}\rangle$, an end-of-sentence token $\langle\mathrm{EOS}\rangle$ and a separator token $\langle\mathrm{SEP}\rangle$. All the strings are sampled uniformly, i.e. we first sample the length of the sequence and then independently sample each position of the string from the alphabet letters. Finally, we “pack the context” with i.i.d. sequences during training, similarly to Zhou et al. (2023): we fill the context with multiple independent samples of the task.
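The online data pipeline can be sketched as follows: sample a length, sample each letter independently, format a copy instance, and greedily pack i.i.d. instances into one context. The instance layout `<BOS> x <SEP> x <EOS>`, the token names, the lowercase alphabet and the defaults `context_len=512` and `max_len=50` are all hypothetical choices for illustration, not the paper's exact configuration.

```python
import random
import string

BOS, EOS, SEP = "<BOS>", "<EOS>", "<SEP>"   # hypothetical names for the special tokens
LETTERS = list(string.ascii_lowercase)      # assumption: 26 letters as alphabet tokens


def copy_example(rng, max_len):
    """One copy instance: sample the length first, then each letter independently."""
    L = rng.randint(1, max_len)
    x = [rng.choice(LETTERS) for _ in range(L)]
    return [BOS, *x, SEP, *x, EOS]


def packed_context(rng, context_len, max_len):
    """Greedily pack i.i.d. copy instances into one context (stops at the first misfit)."""
    assert 2 * max_len + 3 <= context_len, "every instance must fit in a context"
    context = []
    while True:
        example = copy_example(rng, max_len)
        if len(context) + len(example) > context_len:
            return context
        context.extend(example)


def online_batch(rng, batch_size=64, context_len=512, max_len=50):
    """A fresh batch is generated on the fly instead of reading a fixed dataset."""
    return [packed_context(rng, context_len, max_len) for _ in range(batch_size)]


batch = online_batch(random.Random(0))
print(len(batch), len(batch[0]), batch[0][:6])
```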
Positional information.
Positional information also plays an important role in the length generalization capacity of transformers (Jelassi et al., 2023; Kazemnejad et al., 2023; Shen et al., 2023). Previously popular methods of input-layer positional embeddings (e.g. sinusoidal (Vaswani et al., 2017) or learned (Radford et al., 2019)) have been replaced by relative positional encodings at each attention layer (e.g. RoPE (Su et al., 2023), ALiBi (Press et al., 2021), or NoPE (Kazemnejad et al., 2023)). Below, we experiment with these positional encodings along with the Hard-ALiBi encoding introduced in Section 2.
3.2 Data efficiency on the copy task
We begin by training our models on the simple task of copying a sequence of input tokens described in Figure 2. The model gets an input of $L$ tokens followed by a separator ($\langle\mathrm{SEP}\rangle$) token, and needs to output the same sequence again from the beginning. In this section, we focus on in-distribution learning: we train on strings of random length $L \le 300$ and record the string-level accuracy on evaluation strings sampled from the training distribution.
Results for this experiment are shown in Figure 1(a). Clearly, there is a large gap between the transformers and GSSMs: the transformers need 100x fewer samples than the best GSSMs to learn the copy task.
Note that the sharp changes in accuracy displayed in Figure 1(a) are due to the log-scaled x-axis and the choice of string-level accuracy as a metric. In Figure 9(a), we report the character-level accuracy, which yields smoother curves demonstrating the learning process of GSSMs. Regarding LSTMs, we find that they do not manage to learn on length-300 strings even at the character level. In Figure 9(b), we show that LSTMs are able to learn to copy shorter strings and that string length is the bottleneck.
3.3 Length generalization on the copy task
The prior experiment demonstrates superior efficiency of learning in-distribution. Now, we test the ability of the learned functions to generalize out-of-distribution. Specifically, we consider generalization from short sequences to longer sequences. Testing this sort of generalization can help us to better understand which function the model has learned, i.e. whether the model has truly learned the “correct” copy operation or whether it just learned to copy sequences of the particular size it was trained on.
Here, we train all models on short sequences and test them on sequences of up to 1000 tokens, reporting string-level accuracy. As seen in Figure 1(b), all models are able to (eventually) solve the task in-distribution on the training lengths, but transformer-based models display much better generalization to longer inputs compared to GSSMs. Namely, we observe that the performance of the GSSMs (LSTM and Mamba) drops to zero almost immediately when increasing the input length, while the performance of transformers decays much more gradually with length.
When looking at the relative performance of different transformer models in Figure 1(b), it becomes clear that the positional encoding is important to length generalization. Specifically, the ALiBi and NoPE transformers dramatically outperform the RoPE model on longer inputs. This is likely because the sinusoidal embeddings of RoPE change more dramatically than the decay of ALiBi or NoPE when we move to longer inputs.
Improved generalization with Hard-ALiBi.
To test our understanding of how transformers learn to copy, we now consider swapping in the Hard-ALiBi positional encoding that we used in our theoretical construction of hash-based copying (introduced in Subsection 2.2 and illustrated in Figure 8 in the Appendix). Figure 1(b) shows that a transformer trained with Hard-ALiBi on the same short training sequences achieves almost perfect length generalization up to sequences of length 1000. Note that this is well beyond the context length ever encountered in training.
3.4 Transformers learn to use n-gram hashing
Next, we attempt to determine whether the transformer trained on the copy task indeed applies the mechanism of storage and retrieval of n-grams. To do this, we evaluate a transformer with Hard-ALiBi positional encoding, trained on the copy task, on a distribution of examples that intentionally contains duplicate n-grams. That is, we draw uniform sequences of tokens, and then randomly replace some n-gram with another n-gram that already appears in the sequence, such that each example always contains two copies of the same n-gram (typically followed by a different token). We use the Hard-ALiBi model here since it performs the best on the copy task, as shown in Figure 1(a). Figure 4 shows the performance of the transformer for different choices of $n$. We observe that the transformer maintains roughly the same accuracy for $n \le 4$, but that its accuracy starts dropping when the inputs contain duplicate sequences of 5 or more tokens. This suggests that the transformer relies on something like 5-gram retrieval to do the copy task.
3.5 GSSMs cannot arbitrarily retrieve from context
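The duplicate-$n$-gram test distribution can be sketched as follows: draw a uniform string, pick two windows of length $n$, and overwrite one with the other. Choosing two non-overlapping windows and a random copy direction are our assumptions for illustration; the paper only states that an $n$-gram is replaced by another one already in the sequence.

```python
import random


def with_duplicate_ngram(rng, L, n, alphabet):
    """Uniform string of length L in which one n-gram is overwritten by a copy of
    another n-gram of the same string (non-overlapping windows, an assumption)."""
    assert L >= 2 * n
    x = [rng.choice(alphabet) for _ in range(L)]
    i = rng.randint(0, L - 2 * n)          # first window
    j = rng.randint(i + n, L - n)          # second window, disjoint from the first
    src, dst = (i, j) if rng.random() < 0.5 else (j, i)
    x[dst:dst + n] = x[src:src + n]        # plant the duplicate
    return x


def has_repeated_ngram(x, n):
    grams = [tuple(x[t:t + n]) for t in range(len(x) - n + 1)]
    return len(grams) != len(set(grams))


rng = random.Random(0)
examples = [with_duplicate_ngram(rng, 100, 5, "abcdefghijklmnopqrstuvwxyz") for _ in range(3)]
print(all(has_repeated_ngram(x, 5) for x in examples))   # True by construction
```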
We now introduce another task to probe the mechanisms that the models use to copy from the context: the n-gram lookup task. In this task the model needs to use a given n-gram as a key to look up the k tokens that follow this key in the input. We consider two variants of the task: suffix keys and prefix keys. In both variants, we assess length generalization to understand the function that the models have learned.
First, we consider the suffix key version of n-gram lookup. In this task, the model is given a sequence of input tokens, a separator, and then an n-gram from the input sequence. The model then needs to output a sequence of tokens following the chosen n-gram (see Figure 5 for an illustration). This task is closely related to induction heads (Olsson et al., 2022). It requires the model to be able to “store” the entire context in order to find the key and retrieve the tokens that follow it. We train all models on sequences of at most 30 tokens and show results in Figure 5. Transformers perform well on this task, with a relatively small drop in performance when increasing the sequence length up to 100. This suggests that transformers can learn to perform n-gram storage and retrieval. GSSMs, however, perform poorly beyond their training distribution. Intuitively, this task still requires the models to store the entire input sequence, something that GSSMs struggle to do.
Next, we try the prefix key version of n-gram lookup. Here we provide the n-gram key at the beginning and then the full input sequence (illustrated in Figure 6). In this version of the task the model does not need to store the entire input, since it can look for the key on the fly as the sequence is processed. This is good for the GSSMs, since they can write the key into the state and then ignore inputs that do not match. Indeed, GSSMs achieve perfect length generalization on this variant. Interestingly, the GSSMs even outperform the NoPE and ALiBi transformers (although not the Hard-ALiBi model). We hypothesize that these positional embeddings make it more difficult to perform the hashing lookup effectively over long relative distances. Taken together, these results illustrate how GSSMs seem to be memory limited, but can be effective when the tasks only require a summary of the inputs rather than storing the entire context.
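Both lookup variants can be generated from the same ingredients; only the order of the key and the input differs. The sketch below is an illustration: the `<SEP>` token name and the exact prompt layout are assumptions, and if the key happens to occur twice in the input, the answer refers to the sampled occurrence.

```python
import random

SEP = "<SEP>"   # hypothetical separator token


def lookup_instance(rng, L, n, k, alphabet, prefix_key=False):
    """n-gram lookup: given an n-gram of x as key, answer with the k tokens after it.

    Suffix-key layout: x, <SEP>, key.   Prefix-key layout: key, <SEP>, x.
    """
    x = [rng.choice(alphabet) for _ in range(L)]
    i = rng.randint(0, L - n - k)            # start of the queried n-gram
    key, answer = x[i:i + n], x[i + n:i + n + k]
    prompt = key + [SEP] + x if prefix_key else x + [SEP] + key
    return prompt, answer


rng = random.Random(0)
print(lookup_instance(rng, 12, 3, 2, "abcd"))
print(lookup_instance(rng, 12, 3, 2, "abcd", prefix_key=True))
```

In the prefix-key layout a fixed-state model can keep `key` in its state and watch for it while reading `x`; in the suffix-key layout the key only arrives after all of `x` has been read, so the relevant continuation must already be stored.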
Pre-trained Models
In this section, we compare the performance of pre-trained transformers and pre-trained GSSMs on memory-intensive tasks such as copying long strings, retrieval and few-shot question answering. We show that transformers outperform GSSMs of similar scale on such memory-intensive tasks, even when the GSSM has lower perplexity as a language model. These results confirm that the limitation of GSSMs raised in previous sections applies to large-scale models trained on real pretraining data.
In the experiments below, we compare Pythia transformer models (Biderman et al., 2023) of sizes ranging from 410M to 2.8B against Mamba models (Gu & Dao, 2023) of similar sizes. All these models have been pre-trained on the Pile (Gao et al., 2020) and use the same tokenizer. The Mamba models generally have slightly lower perplexity on the training set for a given size. The main difference between the Pythia and the Mamba models is their architectural design.
We compare these models by measuring their performance while varying the input instance length and consider two types of tasks: copy-based and information retrieval tasks. The copy-based tasks consist of presenting a random text to the model and asking it to copy the text. In the information retrieval tasks, we provide a text to the model and ask it a related question. These retrieval tasks can be seen as “selective copy”, since the model needs to copy a small chunk of the input text in order to respond to the question. To measure performance, we use the string-level accuracy in all the experiments except in Figure 7(c), where we consider question answering and thus report the F1 score. We evaluate the models over 10 batches of size 64 for all the tasks except for question answering, where we evaluate over 50 questions because the number of questions with a given context length is limited. Further details are in Appendix A.
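The metrics used in this paper can be sketched as follows: string-level accuracy (exact match per example), the character-level accuracy mentioned for the synthetic experiments, and a bag-of-tokens F1 in the style of SQuAD. The F1 here is simplified: it only lowercases and splits on whitespace, whereas the official SQuAD evaluation script also strips punctuation and articles.

```python
from collections import Counter


def string_accuracy(preds, targets):
    """Fraction of examples whose output matches the target exactly."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def char_accuracy(pred, target):
    """Fraction of target positions reproduced correctly (missing characters count as errors)."""
    return sum(p == t for p, t in zip(pred, target)) / len(target)


def token_f1(pred, gold):
    """Simplified SQuAD-style F1 over bags of lowercased whitespace tokens."""
    p, g = pred.lower().split(), gold.lower().split()
    overlap = sum((Counter(p) & Counter(g)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(g)
    return 2 * precision * recall / (precision + recall)


print(string_accuracy(["abc", "abd"], ["abc", "abc"]))  # 0.5
print(char_accuracy("abd", "abc"))                      # 2 of 3 characters correct
print(token_f1("the cat sat", "the cat"))               # precision 2/3, recall 1
```

String-level accuracy gives no credit for a nearly correct copy, which is why it produces sharper curves than character-level accuracy.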
4.2 Copying the input context
We first observe that pre-trained transformers outperform pre-trained GSSMs at copying long natural language strings. In Figure 7(a), we randomly sample strings from the C4 dataset (Raffel et al., 2020) with varying numbers of tokens. Our prompt consists of two copies of the sampled string plus the first word of the string, and we expect the model to complete the third copy. Even the smallest transformer model dramatically outperforms the largest GSSM. This happens even though the large GSSMs have enough bits in the state variable to potentially store the context. This confirms the idea that this is an architectural bias of transformers that makes it easier for them to copy from the context.
Unlike strings of tokens sampled uniformly at random, natural text can often be compressed, possibly allowing language models to copy longer strings even with limited memory. To test whether this matters, in Figure 7(b) we conduct the same experiment as above but randomly shuffle the order of the words in the strings. We find that when we shuffle the words, both GSSMs and transformers perform worse on the task, but the effect is starker for GSSMs. Even the largest GSSM now gets zero accuracy on strings of length 300. This suggests that when the input is more difficult to compress, the GSSM suffers due to its fixed-size state.
4.3 Retrieval from the input context
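The copy prompts for natural text can be sketched as below: two copies of the text followed by its first word, with the rest of the third copy as the target, plus a word-shuffled control. Joining the pieces with single spaces and measuring length in words rather than tokens are assumptions made for this illustration.

```python
import random
from typing import Tuple


def copy_prompt(text: str) -> Tuple[str, str]:
    """Two copies of the text plus its first word; the model should finish the third copy."""
    text = " ".join(text.split())            # normalize whitespace
    first_word = text.split()[0]
    prompt = f"{text} {text} {first_word}"
    target = text[len(first_word):]          # remainder of the third copy
    return prompt, target


def shuffle_words(text: str, rng: random.Random) -> str:
    """Word-shuffled control: same words, far less compressible order."""
    words = text.split()
    rng.shuffle(words)
    return " ".join(words)


prompt, target = copy_prompt("the quick brown fox")
print(repr(prompt))   # 'the quick brown fox the quick brown fox the'
print(repr(target))   # ' quick brown fox'
print(shuffle_words("the quick brown fox", random.Random(0)))
```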
While copying provides a clear task to separate the model classes, it is not a particularly realistic task. That said, it presents an extreme case of a type of behavior that is highly relevant for many tasks of interest. In particular, many tasks require retrieving specific information from the context that is relevant to the desired output. This subsection presents examples of how our results transfer to more practical tasks.
We first consider a “phone-book” experiment where we provide a synthetic phone-book to the model and ask it to return the phone number when given a name. We generate the phone-book by randomly sampling names and their associated phone numbers. One line of this phone-book looks like “John Powell: 609-323-7777”. Our prompt to the model consists of the phone-book, two few-shot examples and a question asking for the phone number of a randomly sampled name from the phone-book. Figure 1(c) reports the accuracy obtained by the pretrained transformers and GSSMs while varying the size of the phone-book. We observe that even the smallest transformer (410M parameters) outperforms the largest GSSM (2.8B parameters) when the phone-book is long enough. This shows that in retrieval tasks which require access to the whole context, GSSMs struggle to store the relevant information in their fixed-size state.
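The phone-book prompt can be sketched as follows. The name pools, the phone-number format beyond the single example line, and the question wording are hypothetical; only the overall structure (phone-book, two few-shot examples, query) follows the description above.

```python
import random

FIRST = ["John", "Mary", "Ada", "Alan", "Grace", "Linus"]        # hypothetical name pools
LAST = ["Powell", "Smith", "Lovelace", "Turing", "Hopper", "Torvalds"]


def phone_number(rng):
    return f"{rng.randint(100, 999)}-{rng.randint(100, 999)}-{rng.randint(1000, 9999)}"


def phonebook_prompt(rng, size):
    """Phone-book of `size` entries, two few-shot examples, then the query.
    Returns (prompt, expected answer). The question template is hypothetical."""
    names = rng.sample([f"{first} {last}" for first in FIRST for last in LAST], size)
    book = {name: phone_number(rng) for name in names}
    lines = [f"{name}: {number}" for name, number in book.items()]
    shot1, shot2, query = rng.sample(names, 3)
    shots = [f"What is {name}'s phone number? {book[name]}" for name in (shot1, shot2)]
    prompt = "\n".join(lines + [""] + shots + [f"What is {query}'s phone number?"])
    return prompt, book[query]


prompt, answer = phonebook_prompt(random.Random(0), 5)
print(prompt)
print("expected:", answer)
```

Answering requires the whole book to remain accessible, since the queried name is revealed only after every entry has been read.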
Question-Answering.
In this experiment, we compare the 2.8B parameter Mamba and transformer models on the SQuAD question-answering dataset (Rajpurkar et al., 2018); in our experiments, smaller models were unable to achieve reasonable and consistent performance on this dataset. This dataset provides text paragraphs together with a few questions regarding the text. We probe the models to answer the question by providing a single demonstration of a question/answer pair (corresponding to the same text) before giving the target question. We bin the paragraphs according to their lengths, and report the F1 score as a function of the paragraph length for both models in Figure 7(c). We observe that while both the Pythia transformer and Mamba achieve comparable performance on short paragraphs, the performance of Mamba degrades more quickly with the paragraph length, while the transformer-based model maintains a similar accuracy even for longer texts. This result shows that the fixed memory of GSSMs also limits their performance on standard natural-language tasks.
Related Work
There exists a broad body of prior work on the representational capacity of GSSMs like RNNs (Merrill, 2019; Merrill et al., 2020) as well as transformers (Weiss et al., 2021; Merrill et al., 2022; Wei et al., 2022; Sanford et al., 2023; Edelman et al., 2022). Previous works that study transformers do so through comparison to other complexity classes, such as threshold circuits (Merrill et al., 2022), RASP language (Weiss et al., 2021) or first-order logic (Chiang et al., 2023) (see Strobl et al. (2023) for a thorough review). These works do not provide insights into how transformers implement algorithms for solving specific problems. In contrast, our theoretical result constructs a transformer for the copy task, which illustrates the mechanism and provides tight bounds on the model size. Together with the result showing that GSSMs cannot copy long sequences, our theory characterizes the power of different sequence models on the copy task. Other theoretical separation results between transformers and RNNs (Sanford et al., 2023; Merrill, 2019) use more complex tasks of less practical relevance.
Other papers have previously demonstrated the capacity of transformers to leverage the entire input context for tasks like retrieval, question answering, and in-context learning (Devlin et al., 2018; Raffel et al., 2020; Petroni et al., 2020; Brown et al., 2020; Liu et al., 2023b). Another line of work has studied the “induction head” mechanism in transformers that performs a retrieval operation much like the one we observe for copying (Olsson et al., 2022). But, to our knowledge, no prior work compares transformers and GSSMs of similar quality on these tasks.
Several of our experiments study length generalization as a way to assess whether the model found the “right way” to solve the task. Prior work on length generalization in transformers has focused on the data distribution (Anil et al., 2022), positional embeddings (Kazemnejad et al., 2023), and arithmetic tasks (Delétang et al., 2022; Ruoss et al., 2023; Jelassi et al., 2023; Zhou et al., 2023). We extend many of these ideas to the copying task.
Finally, we note that while we focus on tasks where transformers outperform GSSMs, there are also tasks where GSSMs outperform transformers. For example, Liu et al. (2023a) show that transformers fail to generalize out of distribution for “flip-flop language modeling”, while LSTMs do so easily. These tasks require tracking a small state variable over time. Another benefit of GSSMs is the ability to input long contexts like DNA sequences that may be impractical for transformers (Nguyen et al., 2023).
Discussion
We have demonstrated through theory and experiments that transformers are better than GSSMs at copying from their input context. However, we emphasize that state space models have many advantages over transformers. The memory and computational complexity of GSSMs does not increase with the input length, which is ideal for training and inference on long inputs. Additionally, state space models such as RNNs are better at tracking state variables across long sequences (Liu et al., 2023a), which may be useful for generating long consistent text. Importantly, language processing in the human brain appears to be much more similar to how state space models process language (Tikochinski et al., 2024). We therefore believe that future work should focus on building hybrid architectures that endow state space models with an attention-like mechanism, allowing them to retrieve relevant pieces of text from their input. Indeed, humans have an incredibly limited capacity for memorizing sequences (Miller, 1956), but can translate entire novels if we allow them to look back at the text (Shelton, 1612).
Acknowledgements
We thank Boaz Barak for helpful discussions. Kempner Institute computing resources enabled this work. Samy Jelassi acknowledges funding supported by the Center of Mathematical Sciences and Applications. This work has been made possible in part by a gift from the Chan Zuckerberg Initiative Foundation to establish the Kempner Institute for the Study of Natural and Artificial Intelligence. Sham Kakade acknowledges funding from the Office of Naval Research under award N00014-22-1-2377.
References
Appendix A Experimental setup
In this section, we provide additional details about our experimental setup. We first give a description of the positional encodings used in our transformers experiments (Subsection A.1) and then give details about the training and evaluation procedures (Subsection A.2).
We consider multiple positional encoding schemes in our experiments in Section 3:
the NoPE scheme (Kazemnejad et al., 2023) where no positional information is added to any of the attention scores (8(a)). This architecture choice helps to get better length generalization in multiple tasks including the copy task.
the ALiBi scheme (Press et al., 2021) which biases the attention scores with a penalty that is proportional to their distance (8(b)). is a head-specific slope fixed before training.
the Hard-ALiBi scheme introduced in Section 2, which has a number of masked attention heads, in which we explicitly force the model to attend only to the directly preceding tokens, while the remaining heads are NoPE attention heads. Figure 8(c) displays an example: in the first head, each token attends only to itself; in the second head, each token attends to itself and to the preceding token; in the third head, each token attends to itself and to the two preceding tokens. The remaining heads are set to NoPE.
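The three schemes differ only in the additive bias applied to the attention scores of each head. The Python sketch below builds these bias matrices for a single head. It assumes the usual causal masking with $-\infty$ entries and, for Hard-ALiBi, that masked head number $h$ (counting from 1) sees a window of its last $h$ tokens, as in the example of Figure 8(c). The function names are illustrative and are not taken from the authors' code.

```python
# Illustrative sketch (not the authors' code): additive attention biases for
# one head over a length-L sequence. Row i lists the bias added to the score
# of query position i against key positions j = 0..L-1 (0-based).
NEG_INF = float("-inf")


def nope_bias(L):
    """NoPE: causal masking only, no positional term."""
    return [[0.0 if j <= i else NEG_INF for j in range(L)] for i in range(L)]


def alibi_bias(L, slope):
    """ALiBi: causal scores penalised by slope * (i - j)."""
    return [[-slope * (i - j) if j <= i else NEG_INF for j in range(L)]
            for i in range(L)]


def hard_alibi_bias(L, window):
    """Masked head (assumed convention): token i sees positions i-window+1..i."""
    return [[0.0 if i - window < j <= i else NEG_INF for j in range(L)]
            for i in range(L)]


def hard_alibi_heads(L, num_heads, num_masked):
    """Head h (0-based) with h < num_masked gets window h + 1; the rest are NoPE."""
    return [hard_alibi_bias(L, h + 1) if h < num_masked else nope_bias(L)
            for h in range(num_heads)]
```

For example, `hard_alibi_heads(5, 4, 3)` gives three masked heads with windows 1, 2 and 3, followed by one NoPE head.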
A.2 Pretraining and evaluation details
We implement all of our training in PyTorch (Paszke et al., 2019), using the HuggingFace library (Wolf et al., 2019) and the Mamba GitHub repository (Gu & Dao, 2023).
Architectures.
In our experiments in Section 3, the backbone of our transformers is the GPT-NeoX architecture. We set the number of layers to 12, the hidden size to 1024 and the number of heads to $H$. We consider the different positional encodings described in Subsection A.1. For ALiBi, we set the head-specific slopes as in the original paper (Press et al., 2021), i.e. the geometric sequence $m_h = 2^{-8h/H}$ for $h \in \{1, \dots, H\}$. For the Hard-ALiBi model, we sweep over the number of masked heads and keep the best-performing value. Regarding the Mamba models, we set the number of layers to 24 and the hidden size to 1024. We also sweep over the state space dimension and keep the best-performing value. This choice of hyperparameters ensures that the transformers and the Mamba models have a comparable number of parameters. Lastly, our LSTM has 4 layers and width 1024.
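For reference, the slope recipe of Press et al. (2021) for $H$ heads, stated there for $H$ a power of two, is the geometric sequence $2^{-8/H}, 2^{-16/H}, \dots, 2^{-8}$. A minimal sketch, with a hypothetical helper name:

```python
def alibi_slopes(num_heads):
    """ALiBi slopes m_h = 2^(-8h/H) for h = 1..H (geometric sequence).

    Press et al. (2021) give this recipe for H a power of two; other head
    counts need their interpolation rule, which is not reproduced here.
    """
    return [2.0 ** (-8.0 * h / num_heads) for h in range(1, num_heads + 1)]
```

For $H = 8$ this yields $1/2, 1/4, \dots, 1/256$.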
Training hyperparameters.
In Section 3, at each epoch, we sample online a batch of size 64. We fill the context with examples: we choose a fixed context length (the same for all experiments except that of Figure 1(a), which uses a different value) and pack as many examples as possible into this context. Hence, in our case, one sample contains many instances. We run the experiments for 15 epochs for both transformers and Mamba, while LSTMs need 300 epochs. All methods are trained with the AdamW optimizer (Loshchilov & Hutter, 2017) with learning rate 5e-5, a linear rate decay schedule, 300 warmup steps and the default weight decay of 1e-1. Finally, all models are trained with the next-token prediction loss, but we apply a mask on the input instance so that the model is only penalized for mistakes on the labels (and not on the inputs and labels jointly).
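The packing and label masking described above can be sketched as follows. This is an illustrative reconstruction, not the training code: `pack_examples` and `masked_next_token_loss` are hypothetical helpers, examples are assumed to be (input tokens, label tokens) pairs, and the loss takes per-position log-probabilities instead of model logits. In a PyTorch pipeline the same effect is usually obtained by setting the masked target positions to the cross-entropy `ignore_index`.

```python
def pack_examples(examples, context_length, pad_id=0):
    """Greedily pack (input_ids, label_ids) pairs into one fixed-length context.

    Returns (tokens, loss_mask) where loss_mask[k] == 1 iff tokens[k] is a
    label token, i.e. a position whose prediction is penalised.
    """
    tokens, loss_mask = [], []
    for inp, lab in examples:
        if len(tokens) + len(inp) + len(lab) > context_length:
            break  # the next example no longer fits in the context
        tokens += inp + lab
        loss_mask += [0] * len(inp) + [1] * len(lab)
    pad = context_length - len(tokens)
    return tokens + [pad_id] * pad, loss_mask + [0] * pad


def masked_next_token_loss(log_probs, tokens, loss_mask):
    """Mean of -log p(tokens[k] | tokens[:k]) over positions with loss_mask[k] == 1.

    log_probs[k - 1][v] is assumed to be the model's log-probability that the
    token at position k is v (hypothetical input format).
    """
    total, count = 0.0, 0
    for k in range(1, len(tokens)):
        if loss_mask[k]:
            total -= log_probs[k - 1][tokens[k]]
            count += 1
    return total / max(count, 1)
```

Input positions and padding carry mask 0, so a wrong guess of the next input character costs nothing, while every label character is penalised.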
Compute resources.
All pretraining was done on an internal cluster using RTX8000 GPUs. We estimate that the final training runs needed to produce the results in the paper took approximately 600 GPU hours.
Evaluation algorithm.
We evaluate the models on 10 batches of size 64 for all tasks except question answering, where we evaluate on 50 questions because the number of questions with a given context length is limited.
Decoding algorithm.
At inference, all our models use greedy decoding for generation and we set the temperature to 0.
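Greedy decoding picks the highest-scoring token at every step, which is the limit of temperature sampling as the temperature goes to 0. A sketch, where `next_token_scores` is a hypothetical stand-in for a model forward pass:

```python
def greedy_decode(next_token_scores, prompt, max_new_tokens, eos_id=None):
    """Generate by always taking the arg-max token (temperature 0).

    next_token_scores(tokens) is a hypothetical callable returning one score
    (e.g. a logit) per vocabulary id for the token after `tokens`.
    """
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        scores = next_token_scores(tokens)
        # max() returns the first maximiser, so ties go to the lowest id.
        best = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(best)
        if eos_id is not None and best == eos_id:
            break
    return tokens[len(prompt):]
```

Tie-breaking toward the lowest token id is a choice of this sketch; library implementations may break ties differently.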
Appendix B Additional Experiments
In Subsection B.1, we focus on in-distribution learning of the copy task and show that GSSMs need many more training samples than transformers. In Subsection B.2, we study the performance of pre-trained models on the copy task when the strings are sampled uniformly at random. This experiment shows that when the text to copy is totally random, the gap between pre-trained transformers and GSSMs is even larger.
Figure legend: transformers with RoPE, NoPE, ALiBi and Hard-ALiBi positional encodings; GSSMs: LSTM and Mamba.
In this section, we provide additional plots to complement the data-efficiency experiment of Figure 1(a). We want to highlight the following points:
In Figure 1(a), we see a sharp transition in the Mamba learning curve. However, Figure 9(a) shows that learning is smoother when measured at the character level. Moreover, LSTMs are not able to learn to copy length-300 strings, even at the character level.
We also consider learning to copy much shorter strings. Figure 9(b) shows that the gap in the number of training examples between transformers and Mamba is then much smaller, i.e. Mamba only needs 10x more data. Moreover, the LSTM is able to learn the copy task, but it needs 100x more data than transformers.
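The two metrics contrasted above can be made concrete. The sketch below assumes that string-level accuracy means an exact match of the whole copied string and that character-level accuracy is the fraction of target characters reproduced at the right position; the exact definitions used for Figures 1(a) and 9(a) are not restated in this section. A single wrong character sets the string-level score of an example to 0 but lowers its character-level score only slightly, which is consistent with the smoother character-level curves.

```python
def string_accuracy(predictions, targets):
    """Fraction of examples copied exactly (every character correct)."""
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)


def char_accuracy(predictions, targets):
    """Fraction of target characters reproduced at the correct position.

    Characters missing from a too-short prediction count as errors.
    """
    correct = total = 0
    for p, t in zip(predictions, targets):
        correct += sum(a == b for a, b in zip(p, t))
        total += len(t)
    return correct / total
```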
B.2 Pre-trained models on the uniform copy task
In this section, we provide an additional experiment that shows the superiority of pre-trained Pythia over pre-trained Mamba models in the copy task.
Figure legend: Pythia (410M, 1.4B, 2.8B); Mamba (360M, 1.4B, 2.8B).
We consider the same setup as in Section 3: we sample uniform strings of alphabet characters with a fixed length and ask the model to copy them, using the same prompt format as the one described in Subsection 4.2.
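A sketch of the string sampling, assuming lowercase English letters as the alphabet and a seeded generator for reproducibility (both are assumptions of this sketch); the prompt format of Subsection 4.2 is not reproduced here.

```python
import random
import string


def sample_uniform_strings(num, length, alphabet=string.ascii_lowercase, seed=0):
    """Draw `num` strings whose characters are i.i.d. uniform over `alphabet`.

    The lowercase alphabet and the fixed seed are illustrative assumptions.
    """
    rng = random.Random(seed)
    return ["".join(rng.choice(alphabet) for _ in range(length))
            for _ in range(num)]
```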
This setting is a more extreme version of Figure 7(b), since the strings are more random: in Figure 7(b), the order of the nouns was random but the nouns were English nouns, while here the strings are totally random. In Figure 10, we see a clear separation between the transformers and Mamba models, with the smallest Pythia outperforming the largest Mamba. Compared to Figure 7(b), however, Pythia's performance is much higher: the 1.4B model reaches almost 100% accuracy.
Appendix C Proofs - Upper Bound
This section gives a detailed proof of Theorem 2.3 and Lemma 2.4.
We begin by introducing some technical lemmas that we use in the proof of Theorem 2.3.
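Algorithm: hash-based copying
Input: sequence $x_1, \dots, x_L$
Let $s$ be some hashing function that maps $n$-grams to vectors.
for $i = n+2, \dots, L$ do
    $k_i \leftarrow s(x_{i-n}, \dots, x_{i-1})$
    $v_i \leftarrow x_i$
end for
for $j = 1, \dots, L$ do
    if $j \le n+1$ then
        $y_j \leftarrow x_j$
    else
        $q_j \leftarrow s(y_{j-n}, \dots, y_{j-1})$
        Let $i$ s.t. $k_i = q_j$, and set $y_j \leftarrow v_i$
    end if
end for
Output: sequence $y_1, \dots, y_L$
Lemma C.1. Let $h_m(\bm{x}_1, \dots, \bm{x}_i) = \frac{1}{\min(m,i)} \sum_{j=\max(1,\, i-m+1)}^{i} \bm{x}_j$. Then, $h_m$ can be computed using a hard-ALiBi attention head.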
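Proof. Let $W_K = 0$ (zero matrix) and let $W_V = I$ (identity matrix). We choose the bias s.t. $b_{i,j} = 0$ if $i - m < j \le i$ and $b_{i,j} = -\infty$ otherwise. Since all query-key products vanish, the softmax puts uniform weight $1/\min(m,i)$ on the last $\min(m,i)$ positions, so the head outputs $h_m(\bm{x}_1, \dots, \bm{x}_i)$.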
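The construction can be checked numerically. The sketch below assumes that a hard-ALiBi head with parameter $m$ uses bias $0$ on the last $m$ positions up to the current one and $-\infty$ elsewhere. It implements such a head with $W_K = 0$ and $W_V = I$ and compares its output with the average of the last $\min(m, i)$ inputs. Helper names are hypothetical.

```python
import math

NEG_INF = float("-inf")


def softmax(scores):
    top = max(scores)  # at least one score is finite (the current position)
    exps = [math.exp(s - top) for s in scores]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def hard_alibi_head(xs, m):
    """One attention head with W_K = 0, W_V = I and hard-ALiBi bias window m.

    xs holds the input vectors x_1..x_L as lists (0-based here). Because
    W_K = 0, every query-key product is 0 and only the bias b_{i,j} matters.
    """
    d = len(xs[0])
    outputs = []
    for i in range(len(xs)):
        scores = [0.0 if i - m < j <= i else NEG_INF for j in range(len(xs))]
        weights = softmax(scores)
        outputs.append([sum(w * x[k] for w, x in zip(weights, xs))
                        for k in range(d)])
    return outputs


def window_average(xs, i, m):
    """h_m(x_1..x_i) for 1-based i: mean of the last min(m, i) vectors."""
    window = xs[max(0, i - m):i]
    return [sum(col) / len(window) for col in zip(*window)]
```

With $m = 1$ the head simply returns its input at each position, which is the value output $h_1$ used later.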
Assume that . Then, there exists an embedding s.t.
For every it holds that and .
For it holds that .
For every , , and for every , .
Denote , and observe that we can encode all “non-special” tokens as vectors in , and denote this encoding by . Now, define:
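Lemma C.3. Let $\bm{z} \in \mathbb{R}^K$ be some vector such that, for some constants $a, b$, there exists $i \in [K]$ s.t. $z_i = a$ and for all $j \ne i$ we have $z_j \le b$. Denote $\bm{s} = \mathrm{softmax}(\bm{z})$. Then $s_i \ge \frac{1}{1 + K\exp(b-a)}$ and $s_j \le \exp(b-a)$ for all $j \ne i$.
Proof. Observe the following:
$\exp(a) = \exp(z_i) \le \sum_{j=1}^{K}\exp(z_j) \le \exp(z_i) + (K-1)\exp(b) \le \exp(a) + K\exp(b) = \exp(a)(1 + K\exp(b-a))$
Therefore,
$s_i = \frac{\exp(z_i)}{\sum_{j=1}^{K}\exp(z_j)} \ge \frac{\exp(a)}{\exp(a)(1 + K\exp(b-a))} = \frac{1}{1 + K\exp(b-a)}$
Finally, for every $j \ne i$:
$s_j = \frac{\exp(z_j)}{\sum_{k=1}^{K}\exp(z_k)} \le \frac{\exp(b)}{\exp(a)} = \exp(b-a)$ ∎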
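The softmax bounds are easy to test empirically: if $z_i = a$ and every other coordinate is at most $b$, then $\mathrm{softmax}(\bm{z})_i \ge 1/(1 + K e^{b-a})$ and $\mathrm{softmax}(\bm{z})_j \le e^{b-a}$ for $j \ne i$. The following randomized check is an illustration, not part of the proof; the sampling range for the other coordinates is arbitrary.

```python
import math
import random


def softmax(z):
    top = max(z)
    exps = [math.exp(v - top) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def check_softmax_bounds(K, a, b, trials=1000, seed=0):
    """Draw z with z_i = a and z_j <= b (j != i) and assert both bounds."""
    rng = random.Random(seed)
    for _ in range(trials):
        i = rng.randrange(K)
        z = [rng.uniform(-abs(b) - 5.0, b) for _ in range(K)]
        z[i] = a
        s = softmax(z)
        # Lower bound on the weight of the distinguished coordinate.
        assert s[i] >= 1.0 / (1.0 + K * math.exp(b - a)) - 1e-12
        # Upper bound on the weight of every other coordinate.
        assert all(s[j] <= math.exp(b - a) + 1e-12 for j in range(K) if j != i)
    return True
```

As $a - b$ grows (e.g. by scaling the scores with a temperature), $s_i \to 1$ and the other weights vanish, which is how the lemma is used for attention later.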
C.2 Proof of Theorem 2.3
We begin by constructing the first block of the transformer, which computes the “lookup table” for the copy algorithm. This lookup table consists of (key, value) pairs $(k_i, v_i)$ for each position $i$, where the key encodes the $n$-gram preceding the $i$-th token, and the value is the $i$-th token. Namely, if the sequence is $x_1, \dots, x_L$, then $k_i = (x_{i-n}, \dots, x_{i-1})$ and $v_i = x_i$. Additionally, the transformer block also computes a query, which is just the “current” $n$-gram, i.e. $q_i = (x_{i-n+1}, \dots, x_i)$. The copy algorithm matches the current $q_i$ with previous $k_j$-s, retrieving the matching $v_j$.
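The lookup-table view can be written directly as a dictionary from $n$-grams to the following token. This Python sketch replaces the hashing function with Python's built-in tuple hashing, copies the first $n$ tokens directly, and does not model special start or copy tokens. It only illustrates the key/value/query matching, not the transformer construction itself.

```python
def ngram_copy(x, n):
    """Copy x (n >= 1) by n-gram lookup: keys k_i are the n tokens before i.

    Illustrative sketch: raises ValueError when the same n-gram is followed
    by two different tokens, since the lookup would then be ambiguous.
    """
    table = {}
    for i in range(n, len(x)):
        key = tuple(x[i - n:i])  # k_i = (x_{i-n}, ..., x_{i-1})
        if key in table and table[key] != x[i]:
            raise ValueError("n-gram %r is followed by two different tokens" % (key,))
        table[key] = x[i]  # v_i = x_i
    y = list(x[:n])  # the first n tokens have no complete preceding n-gram
    while len(y) < len(x):
        query = tuple(y[-n:])  # the current n-gram of the generated output
        y.append(table[query])
    return y
```

By induction, each query equals the true $n$-gram of $x$ at that position, so the lookup returns the correct next token whenever no $n$-gram is followed by two different tokens; this is why the proof assumes unique $n$-grams.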
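The following lemma shows that, by using a combination of hard-ALiBi attention heads (with a different choice of $m$ for each head) together with an MLP layer, a transformer block can compute the correct $(k_i, v_i, q_i)$ for each position. We use slightly modified keys and queries to handle the first positions of the sequence, where the $n$ preceding tokens are not all available (or where position $i$ is one of the first $n$ tokens after the copy token).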
Lemma C.4. Let $\Psi$ be the one-hot embedding. Then, there exists a hard-ALiBi transformer block with 3 outputs, denoted $T^{\mathrm{key}}, T^{\mathrm{value}}, T^{\mathrm{query}}$, which correspond to 3 blocks of the output dimension, s.t. $T^{\mathrm{key}}, T^{\mathrm{query}} \in \mathbb{R}^{n(d+1)}$ and $T^{\mathrm{value}} \in \mathbb{R}^{d}$, and satisfying, for all $x$ sampled from a length-$L$ copy distribution,
1. Value output: $T^{\mathrm{value}}_{i}(\Psi(x_1), \dots, \Psi(x_i)) = \Psi(x_i)$.
2. Key output: for $t \in [n]$, if $i > t$:
$T^{\mathrm{key}}_{(t-1)d+1:td,\,i}(\Psi(x_1), \dots, \Psi(x_i)) = \Psi(x_{i-t})$
and if $i \le t$:
$T^{\mathrm{key}}_{(t-1)d+1:td,\,i}(\Psi(x_1), \dots, \Psi(x_i)) = 0$
• Additionally, for $t \in [n]$, for all $i$:
$T^{\mathrm{key}}_{nd+t,\,i}(\Psi(x_1), \dots, \Psi(x_i)) = \bm{1}\{i = t+1\}$
3. Query output: for $t \in [n]$, if $i \ge t$:
$T^{\mathrm{query}}_{(t-1)d+1:td,\,i}(\Psi(x_1), \dots, \Psi(x_i)) = \Psi(x_{i-t+1})$
and if $i < t$:
$T^{\mathrm{query}}_{(t-1)d+1:td,\,i}(\Psi(x_1), \dots, \Psi(x_i)) = 0$
• Additionally, for $t \in [n]$, for all $i$:
$T^{\mathrm{query}}_{nd+t,\,i}(\Psi(x_1), \dots, \Psi(x_i)) = n \cdot \bm{1}\{i = L+t\}$
Proof. We prove the following:
For the value output, we simply take $T^{\mathrm{value}} = h_1$, as defined in Lemma C.1. Next, for $t \ge 0$, define:
$g_{t}(\bm{x}_{1},\dots,\bm{x}_{i}) = (t+1)\cdot h_{t+1}(\bm{x}_{1},\dots,\bm{x}_{i}) - t\cdot h_{t}(\bm{x}_{1},\dots,\bm{x}_{i})$
where we define $h_0 = 0$. Observe that if $t < i$ then:
$g_{t}(\bm{x}_{1},\dots,\bm{x}_{i}) = (t+1)\cdot\frac{1}{t+1}\sum_{j=i-t}^{i}\bm{x}_{j} - t\cdot\frac{1}{t}\sum_{j=i-t+1}^{i}\bm{x}_{j} = \bm{x}_{i-t}$
and if $t \ge i$ then:
$g_{t}(\bm{x}_{1},\dots,\bm{x}_{i}) = (t+1)\cdot\frac{1}{i}\sum_{j=1}^{i}\bm{x}_{j} - t\cdot\frac{1}{i}\sum_{j=1}^{i}\bm{x}_{j} = \frac{1}{i}\sum_{j=1}^{i}\bm{x}_{j}$
Claim:
Proof: Fix some . Observe that for all , .
If , we have and so where we use the properties of and the fact that . Therefore, .
where we use the fact that and therefore .
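The identity $g_t = (t+1)h_{t+1} - t\,h_t = \bm{x}_{i-t}$ for $t < i$ can be verified numerically. The sketch assumes that $h_m$ averages the last $\min(m, i)$ inputs and that $h_0 = 0$. Under that convention, for $t \ge i$ the difference reduces to the mean of all inputs seen so far rather than to $0$, which is why that case is handled separately. Helper names are hypothetical.

```python
def window_average(xs, i, m):
    """h_m(x_1..x_i) for 1-based i: mean of the last min(m, i) vectors; h_0 = 0."""
    if m == 0:
        return [0.0] * len(xs[0])
    window = xs[max(0, i - m):i]
    return [sum(col) / len(window) for col in zip(*window)]


def g(xs, i, t):
    """g_t = (t+1) h_{t+1} - t h_t, which recovers x_{i-t} whenever t < i."""
    h_next = window_average(xs, i, t + 1)
    h_cur = window_average(xs, i, t)
    return [(t + 1) * a - t * b for a, b in zip(h_next, h_cur)]
```

Each $g_t$ is a fixed linear combination of two hard-ALiBi head outputs, so a single layer of such heads gives access to every token within distance $n$.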
Claim:
Proof: Denote and . Observe:
If , then and and therefore .
If then and and so .
If then and and so .
If then and therefore .
Finally, we can take .
For all , define .
Claim:
Proof: Denote . Observe:
If then and therefore .
If and then and therefore .
If then since we get and therefore .
Therefore, we can take .
Now, we prove Theorem 2.3 by showing that a single attention head with no positional embedding, placed on top of the construction in Lemma C.4, realizes the copy algorithm. Since the first block computes the correct choice of $(k_i, v_i, q_i)$, by correctly scaling the attention matrix we verify that the output of the second layer at position $i$ corresponds to $v_j$ for $j$ s.t. $k_j = q_i$.
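Let $T^{\mathrm{key}}, T^{\mathrm{value}}, T^{\mathrm{query}}$ be the outputs of the Transformer block guaranteed by Lemma C.4. Observe that, for some temperature $\tau > 0$, the following function can be computed by a softmax-attention layer on top of this block:
$H(\Psi(x_1),\dots,\Psi(x_i)) = \sum_{j=1}^{i} \mathrm{softmax}\left(\tau\, T_{1}^{\mathrm{key}}\cdot T_{i}^{\mathrm{query}}, \dots, \tau\, T_{i}^{\mathrm{key}}\cdot T_{i}^{\mathrm{query}}\right)_{j} T_{j}^{\mathrm{value}}$
where e.g. $T_{i}^{\mathrm{key}}$ denotes $T^{\mathrm{key}}(\Psi(x_1),\dots,\Psi(x_i))$.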
For now, assume that all the -grams in are unique, and that the length of the input satisfies for .
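Uniqueness of $n$-grams is likely for random strings. The sketch below estimates, by Monte Carlo, the probability that a uniformly random length-$L$ string over $D$ symbols contains a repeated $n$-gram, and compares it with the union bound over pairs of positions: two length-$n$ windows (overlapping or not) coincide with probability exactly $D^{-n}$. This mirrors the counting in Subsection C.3, but the precise statement of Lemma 2.4 is not reproduced here; helper names and parameter values are illustrative.

```python
import random


def has_repeated_ngram(x, n):
    """True if some length-n window of x occurs at two different positions."""
    seen = set()
    for i in range(len(x) - n + 1):
        gram = tuple(x[i:i + n])
        if gram in seen:
            return True
        seen.add(gram)
    return False


def repeat_probability(L, n, D, trials=2000, seed=0):
    """Monte Carlo estimate for uniform strings of length L over D symbols."""
    rng = random.Random(seed)
    hits = sum(has_repeated_ngram([rng.randrange(D) for _ in range(L)], n)
               for _ in range(trials))
    return hits / trials


def union_bound(L, n, D):
    """Sum over the C(L-n+1, 2) window pairs, each equal w.p. exactly D**-n."""
    pairs = (L - n + 1) * (L - n) // 2
    return pairs / D ** n
```

The bound decays exponentially in $n$, so moderately long $n$-grams are already unique with high probability.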
Claim: Fix some , denote . Then, and for all .
Proof: We separate into the following cases:
$= \bm{1}\{j>n\}\cdot[\Psi(x_{j-1}),\dots,\Psi(x_{j-n})]^{\top}[\Psi(x_{i}),\dots,\Psi(x_{i-n+1})]$ Now, if then and since we get
$T_{j}^{\mathrm{key}}\cdot T_{i}^{\mathrm{query}} = \sum_{t=1}^{n}\left\lVert\Psi(x_{i-t+1})\right\rVert^{2} = n$ If , since there are no repeated $n$-grams, there is at least some s.t. and by the choice of the embedding . In this case, we get .
$T_{j}^{\mathrm{key}}\cdot T_{i}^{\mathrm{query}} = n\,e_{j-1}\cdot e_{i-L} = n\cdot\bm{1}\{j=i-L+1\}$, which satisfies the requirement.
$T_{j}^{\mathrm{key}}\cdot T_{i}^{\mathrm{query}} = \sum_{t=1}^{n}\Psi(x_{j-t})\cdot\Psi(x_{i-t+1})$ and, as before, since there are no repeated $n$-grams, we get
Claim: Fix some and some , denote . If , then and for all .
Proof: Using the previous claim, together with Lemma C.3, we get that:
Claim: Fix some and some . Then, for , it holds that:
$\left\lVert H(\Psi(x_{1}),\dots,\Psi(x_{i}))-\Psi(x_{i-L+1})\right\rVert\leq\epsilon$ Proof: Let as defined in the previous claim. Then:
$\left\lVert H(\Psi(x_{1}),\dots,\Psi(x_{i}))-\Psi(x_{i-L+1})\right\rVert$ Now, denote by the output map given by (which can be computed by an over a linear function).
Claim: If , then for all we have .
Proof: Denote . First, using the previous claim, we observe that
$\bm{y}_{i}\cdot\Psi(x_{i-L+1})$ Next, observe that for all we have
$\bm{y}_{i}\cdot\Psi(x_{j}) = (\bm{y}_{i}-\Psi(x_{i-L+1}))\cdot\Psi(x_{j})+\Psi(x_{j})\cdot\Psi(x_{i-L+1})$ From the above claim, the Transformer construction outputs the correct token at each step of the auto-regressive generation. ∎
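To illustrate the second layer, the sketch below replaces the hard-ALiBi first block with directly computed keys and queries (concatenated one-hot embeddings of the preceding and current $n$-grams) and applies temperature-scaled softmax attention followed by an arg-max read-out. With one-hot embeddings, a matching key scores $\tau n$ and any other key at most $\tau(n-1)$, so by Lemma C.3 the attention concentrates on the correct value as $\tau$ grows. Special start and copy tokens are not modelled, and all names are hypothetical.

```python
import math


def one_hot_embedding(vocab):
    """Map each token to a one-hot vector (distinct tokens are orthogonal)."""
    return {tok: [1.0 if k == idx else 0.0 for k in range(len(vocab))]
            for idx, tok in enumerate(vocab)}


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def attention_copy(x, n, emb, tau):
    """Copy x with temperature-scaled softmax attention over n-gram keys.

    Illustrative sketch: keys and queries are computed directly instead of
    by hard-ALiBi heads, and no special start/copy tokens are used.
    """
    def concat(tokens):
        return [c for tok in tokens for c in emb[tok]]

    # Key for position j: the n tokens before j; value: the token at j.
    keys = [concat(x[j - n:j]) for j in range(n, len(x))]
    values = [emb[x[j]] for j in range(n, len(x))]
    y = list(x[:n])
    while len(y) < len(x):
        query = concat(y[-n:])  # the current n-gram
        scores = [tau * dot(k, query) for k in keys]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out = [sum(w * v[c] for w, v in zip(weights, values)) / total
               for c in range(len(values[0]))]
        # Read-out: the token whose embedding best aligns with the output.
        y.append(max(emb, key=lambda tok: dot(out, emb[tok])))
    return y
```

With $\tau = 20$ and $n = 3$, each non-matching key receives at most $e^{-20}$ times the weight of the matching one, so the read-out is exact on inputs with unique $n$-grams.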
C.3 Proof of Lemma 2.4
Fix some . Let and . We first bound the probability of drawing some s.t. . Note that there are choices for . We count the number of choices for s.t. . Notice that in this case, is determined by , therefore there are possible choices. We conclude that
Appendix D Proofs - Lower Bound
In this section, we prove Theorem 2.7. We begin by showing that, for every input, the output of the model in each iteration is a deterministic function of the state of the model after observing the input:
Let be some fixed-state sequence-to-sequence model. Then, there exists map s.t. for all
Let be the outputs of . We need to show that there exist functions s.t. . We give the following recursive definition:
, .
, .
Denote We prove by induction that and also that .
.
Given the previous Lemma, we bound the error of the model by comparing the number of possible states to the number of possible inputs.
From Lemma D.1, there exists some function s.t. . For each , we denote by the sequence . Now, observe the following:
$= \frac{1}{D^{n}}\sum_{s\in\mathcal{S}}\sum_{\bm{x}\in S_{n+2}^{-1}(\tilde{\bm{x}})}\bm{1}\{G\circ S_{n'+2}(\tilde{\bm{x}})=\bm{x}\}$ ∎
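The counting argument can be illustrated with a toy model. Since the output is a function $G$ of the final state, each state yields a single output sequence, so at most $|\mathcal{S}|$ of the $D^n$ inputs can be copied correctly. The sketch below enumerates all inputs for a hypothetical model whose state stores only the last $k$ symbols. Its success rate equals $D^k / D^n = |\mathcal{S}| / D^n$, so the bound is tight for this model.

```python
from itertools import product


def copy_success_rate(n, D, encode, decode):
    """Fraction of all inputs x in [D]^n that are copied exactly.

    encode(x) -> state reached after reading x; decode(state) -> output
    sequence (the map G). Both stand in for a hypothetical fixed-state model.
    """
    inputs = list(product(range(D), repeat=n))
    correct = sum(decode(encode(x)) == x for x in inputs)
    return correct / len(inputs)


def encode_last_k(x, D, k):
    """Toy state (k >= 1): the last k symbols, packed into an int in [0, D**k)."""
    state = 0
    for sym in x[-k:]:
        state = state * D + sym
    return state


def decode_last_k(state, n, D, k):
    """Toy G: zeros for the forgotten prefix, then the stored k symbols."""
    digits = []
    for _ in range(k):
        digits.append(state % D)
        state //= D
    return (0,) * (n - k) + tuple(reversed(digits))
```

For $n = 6$, $D = 2$ and $k = 3$ (so $|\mathcal{S}| = 8$), only the 8 inputs whose first three symbols are zero are copied correctly, giving a success rate of $8/64$.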