MambaByte: Token-free Selective State Space Model
Junxiong Wang, Tushaar Gangavarapu, Jing Nathan Yan, Alexander M. Rush
Introduction
When defining a language model, a base tokenization is typically used—either words (Bengio et al., 2000), subwords (Schuster and Nakajima, 2012; Sennrich et al., 2015; Wu et al., 2016; Wang et al., 2020), or characters (Gao et al., 2020a). Of these, subword tokenization has been the most popular choice, as it achieves a natural compromise between training efficiency and the ability to handle out-of-vocabulary words. However, several works (e.g., Xue et al. (2022)) have noted issues with subword tokenizers, such as a lack of robustness to typos, spelling and capitalization variations, and morphological changes.
Researchers (Clark et al., 2022; Xue et al., 2022; Yu et al., 2023) have employed an alternative approach of using byte sequences, i.e., an end-to-end mapping from raw data to predictions without any intermediate tokenization. Compared to subword models, byte-level language models can generalize more easily across orthographic and morphological variants. Of course, modeling text as bytes means that the resultant sequences are significantly longer than their subword counterparts. This pushes the efficiency issues upstream into the architecture itself.
Efficiency issues are particularly pronounced for autoregressive Transformers (Vaswani et al., 2017), which dominate language modeling (Brown et al., 2020; Touvron et al., 2023). Because attention has quadratic cost, Transformers scale poorly for long (byte) sequences (Brown et al., 2020; Zhang et al., 2022). Researchers have compressed the internal Transformer representation to work with long sequences, for instance, by developing length-aware modeling approaches (Dai et al., 2020; Nawrot et al., 2022), where groups of tokens are merged within the intermediate layers. Recently, Yu et al. (2023) proposed the MegaByte Transformer, which uses compression in the form of fixed-size patches of bytes as a subword analog. As a result, MegaByte lowers computational costs, although our experiments (see Figure 1) indicate that patching can also lower model performance compared to the standard Transformer.
In this work, we introduce MambaByte, an efficient and simple byte-level language model. The model is a straightforward adaptation of the recently introduced Mamba architecture (Gu and Dao, 2023), a linear-time approach for sequence modeling. Mamba builds on the approach pioneered by state space models (SSMs) (Gu et al., 2021; Gupta et al., 2022; Gu et al., 2022; Smith et al., 2023) by introducing a selection mechanism that is more effective for discrete data such as text and by providing an efficient GPU implementation. Our simple observation is that using Mamba (without modifications) relieves the main computational bottleneck in language modeling, thus allowing for the elimination of patching and effective use of the available compute budget.
Experiments compare MambaByte to Transformers, SSMs, and MegaByte (patching) architectures in a fixed parameter and fixed compute setting on several long-form text datasets. Figure 1 summarizes our main findings. Compared to byte-level Transformers, MambaByte achieves better performance faster and is significantly more compute efficient. We also consider the viability of token-free language models compared to the existing state-of-the-art subword models. In this regard, we find MambaByte to be competitive with various subword baselines despite handling significantly longer sequences. Our results establish MambaByte as a strong alternative to the existing tokenizer-dependent models and advocate its use to facilitate end-to-end learning.
Background: Selective state space sequence models
Mamba embeds the selective SSM layer into a full neural network language model. Specifically, the model uses a stack of gated layers inspired by the earlier gated SSM of Mehta et al. (2023). Figure 3 shows the Mamba architecture, which combines the SSM layer with a gated neural network.
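To make the recurrence concrete, the following is a minimal pure-Python sketch of a selective SSM scan for a single channel with a diagonal state matrix. It assumes zero-order hold (ZOH) discretization with an input-dependent step size $\Delta_t = \mathrm{softplus}(\cdot)$ and input-dependent $B_t$, $C_t$. The function names and argument layout are hypothetical; this is a slow reference loop, not Mamba's hardware-aware parallel scan, and optimized implementations may discretize $\bar{B}$ differently.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + e^z)."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def selective_scan(xs, dt_raw, A, B, C):
    """Sequential selective SSM for one input channel (illustrative sketch).

    xs:     inputs x_t, length L
    dt_raw: pre-softplus step sizes, length L (input-dependent)
    A:      N negative reals, the diagonal of the state matrix
    B, C:   L lists of N values each (input-dependent projections)
    """
    n_state = len(A)
    h = [0.0] * n_state
    ys = []
    for t, x in enumerate(xs):
        dt = softplus(dt_raw[t])  # positive step size chosen per input
        for i in range(n_state):
            a_bar = math.exp(dt * A[i])              # ZOH: exp(dt * A)
            b_bar = (a_bar - 1.0) / A[i] * B[t][i]   # ZOH for diagonal A
            h[i] = a_bar * h[i] + b_bar * x          # state update
        ys.append(sum(C[t][i] * h[i] for i in range(n_state)))  # readout
    return ys
```

Because $\Delta_t$, $B_t$, and $C_t$ vary with the input, the model can decide per byte how much of the state to keep, which is the selection mechanism described above.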
Experimental setup
Our experiments compare MambaByte to other byte-level Transformers and SSMs. All our models employ the same training recipes (see Appendix C for details). We utilize a set of diverse long-form text datasets: PG19 (Rae et al., 2020), Stories (Trinh and Le, 2018), Books (Gao et al., 2020b), ArXiv (Gao et al., 2020b), and Code (Gao et al., 2020b). Dataset sizes and average document lengths are included in Appendix A.
Performance comparison across architectures requires care. To this end, we consider two settings: compute-matched and parameter-matched. This setup is necessary because the default MegaByte Transformer employs a global module that works with patched representations of the input, thus using fewer feed-forward FLOPs per byte than a raw Transformer while having significantly more parameters. Table 1 shows the MegaByte and MambaByte model sizes employed in our experiments. The (forward-pass) FLOPs computation for various model architectures and the associated hyperparameters are detailed in Appendix B.
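As a rough illustration of why patching lowers per-byte compute, the sketch below counts only the feed-forward FLOPs of a Transformer stack (two dense layers with expansion factor $e$, at 2 FLOPs per multiply-add) and assumes each global position in a patched model covers a fixed number of bytes. The helper names and sizes are hypothetical, and attention, embeddings, and MegaByte's local module are ignored, so this is not the FLOPs accounting of Appendix B.

```python
def ffn_flops_per_position(d_model: int, expansion: int) -> int:
    """Forward FLOPs of a d -> e*d -> d MLP for one position (2 FLOPs per MAC)."""
    return 2 * (2 * d_model * expansion * d_model)


def ffn_flops_per_byte(n_layers: int, d_model: int, expansion: int,
                       patch_size: int = 1) -> float:
    """Per-byte FFN FLOPs; with patching, one global position covers patch_size bytes."""
    return n_layers * ffn_flops_per_position(d_model, expansion) / patch_size


def ffn_params(n_layers: int, d_model: int, expansion: int) -> int:
    """FFN weight count (biases ignored); independent of patch size."""
    return n_layers * 2 * d_model * expansion * d_model
```

With identical FFN parameters, a patch size of 8 divides the per-byte FFN FLOPs by 8, which is why parameter-matched and compute-matched comparisons can disagree.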
All MambaByte models were trained using the open-source Mamba code base (https://github.com/state-spaces/mamba). During training, we shuffle the documents and use contiguous sequences of bytes (one per document), starting from a random position. We enable mixed-precision training using BF16 for training efficiency at scale. The optimizer, learning rate scheduler, and other training details are specified in Appendix C.
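The data sampling described above can be sketched as follows, assuming documents arrive as Python strings and are encoded to UTF-8 bytes. The generator name `byte_windows`, the seeding, and the policy of keeping short documents whole are hypothetical choices for illustration. Each window carries one extra byte so that inputs `w[:-1]` predict targets `w[1:]`.

```python
import random


def byte_windows(docs, seq_len, seed=0):
    """Yield one contiguous window of seq_len + 1 bytes per document, in shuffled order."""
    rng = random.Random(seed)
    order = list(range(len(docs)))
    rng.shuffle(order)  # shuffle documents
    for idx in order:
        data = docs[idx].encode("utf-8")
        if len(data) <= seq_len + 1:
            yield data  # hypothetical policy: keep short documents whole
            continue
        # Random start such that the full window fits inside the document.
        start = rng.randrange(0, len(data) - seq_len)
        yield data[start:start + seq_len + 1]
```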
Press et al. (2021) proposed using a sliding window to trade off speed for performance during inference. Following this, we employ a sliding window, advancing a fixed-length byte context by a fixed stride, when comparing with the state-of-the-art subword models in Table 3.
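A minimal sketch of strided sliding-window evaluation, assuming each window sees up to `window` positions of context and only positions not already scored by an earlier window contribute to the loss, so every position is scored exactly once. The function name and the `(start, end, score_from)` layout are hypothetical.

```python
def sliding_windows(n_tokens: int, window: int, stride: int):
    """Return (start, end, score_from) triples; score positions [score_from, end)."""
    assert 0 < stride <= window
    spans = []
    prev_end = 0
    start = 0
    while prev_end < n_tokens:
        end = min(start + window, n_tokens)
        # Context is [start, end); only positions past the previous window are scored.
        spans.append((start, end, prev_end))
        prev_end = end
        start += stride
    return spans
```

A smaller stride gives each scored position more left context at the cost of more forward passes, which is the speed/performance trade-off noted by Press et al.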
Results
Table 2 shows the bits per byte (BPB) on each dataset. For this experiment, the MegaByte-M+M and MambaByte models use the same number of FLOPs per byte (see Table 1). We observe that MambaByte consistently outperforms MegaByte across all datasets. Furthermore, although we could not train MambaByte for the full B bytes due to monetary constraints, MambaByte outperforms MegaByte with less compute and training data. Additionally, MambaByte-M also outperforms the byte-level Transformer and PerceiverAR.
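For reference, bits per byte is the average next-byte negative log-likelihood expressed in base 2. A small sketch, assuming the loss is summed in nats (natural log), as is typical for cross-entropy losses in deep learning frameworks; the function name is hypothetical.

```python
import math


def bits_per_byte(total_nll_nats: float, n_bytes: int) -> float:
    """Convert a summed next-byte cross-entropy in nats into bits per byte."""
    return total_nll_nats / (n_bytes * math.log(2))
```

As a sanity check, a model that is uniform over all 256 byte values has a per-byte loss of ln 256 nats, i.e., exactly 8 bits per byte.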
How does MambaByte outperform a much larger model with so few training steps? Figure 1 further explores this relationship by looking at models with the same number of parameters. The graphs indicate that for MegaByte models of the same parameter size, models with less input patching perform better, but when compute-normalized, they perform similarly. In fact, a full-length Transformer, while slow in an absolute sense, also performs similarly to MegaByte when compute-normalized. In contrast, switching to the Mamba architecture significantly improves both compute usage and model performance.
Following these findings, Table 3 compares a larger version of these models on the PG19 dataset. For this experiment, we compare MambaByte-M with MegaByte-B+M and other byte-level models, as well as several state-of-the-art subword models. (The conversion from BPB to perplexity (PPL) is detailed in Appendix E.) We find that MambaByte-M, even though trained for only B bytes, outperforms all the byte-level models and achieves performance competitive with subword models.
Autoregressive inference in Transformer models requires caching the entire context, which can significantly affect the generation speed. MambaByte does not suffer from this bottleneck as it maintains a single hidden state per layer that evolves with time, enabling constant time per generation step. Table 4 compares the text generation speeds of MambaByte-M and MambaByte-B with MegaByte-B+M on an A100 80GB PCIe GPU. While MegaByte significantly reduces the generation cost through patching, we observe MambaByte to be faster in a parameter-matched setting due to its use of recurrent generation. Appendix F includes more information about the generation process.
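The memory contrast can be illustrated with a toy decoding loop that stores a key/value pair per generated position (as a Transformer cache would) alongside a fixed-size recurrent state. The update rule here is a hypothetical exponential decay, not Mamba's selective update, and the function name is ours; the point is only that the recurrent state, and hence the per-step work, does not grow with the number of generated bytes.

```python
def decode_memory(n_steps: int, d_state: int, d_model: int):
    """Return (floats held in a KV cache, floats held in a recurrent state) after n_steps."""
    kv_cache = []                 # one (key, value) pair per past position
    state = [0.0] * d_state       # fixed-size recurrent state
    for t in range(n_steps):
        x = float(t)              # stand-in for the current byte's features
        kv_cache.append(([x] * d_model, [x] * d_model))   # grows every step
        state = [0.9 * h + 0.1 * x for h in state]        # toy O(d_state) update
    kv_floats = sum(len(k) + len(v) for k, v in kv_cache)
    return kv_floats, len(state)
```

After 100 steps with `d_model = 8`, the cache holds 1,600 floats while the state still holds `d_state` floats; attention over the cache also costs more at each step, whereas the recurrent update does not.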
Conclusion
We introduce MambaByte, a token-free SSM for modeling long byte sequences. MambaByte outperforms other byte-level models on several datasets and shows competitive results with subword Transformers, thus serving as a promising alternative to tokenization. SSMs also enable significantly faster text generation due to their recurrent nature, making byte models practical. Our findings establish the possibility of token-free language modeling in future large models.
References
Appendix A Dataset specifics
We benchmark our results on various long-form text datasets. The PG19 dataset [Rae et al., 2020] is an extensive collection of full-length English books (written before 1919) from the Project Gutenberg online library. The PG19 dataset is ideal for testing long-distance context modeling [Gao et al., 2020b]. The Stories dataset [Trinh and Le, 2018] is a subset of the CommonCrawl data used for commonsense reasoning and language modeling. The Books dataset [Gao et al., 2020b] is another collection of English books. The ArXiv dataset [Gao et al., 2020b] comprises technical publications in LaTeX from the arXiv online archive. Finally, the Code dataset [Gao et al., 2020b] is a large dataset of publicly available open-source code (under Apache, MIT, or BSD licenses). Dataset statistics are tabulated in Table 5.
For the PG19 dataset, we use the train, validation, and test splits specified by Rae et al. [2020]. For the Stories, Books, ArXiv, and Code datasets, we randomly sample M consecutive bytes for testing and use the rest to train MambaByte.
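A sketch of the held-out split for the non-PG19 datasets, assuming the corpus is a single byte string. The function name, the seed, and the choice to return the training remainder as two pieces (so that joining them does not introduce an artificial boundary) are assumptions.

```python
import random


def split_test_block(data: bytes, test_bytes: int, seed: int = 0):
    """Hold out one random contiguous block for testing; return (train_pieces, test)."""
    if test_bytes >= len(data):
        raise ValueError("test block must be smaller than the dataset")
    rng = random.Random(seed)
    start = rng.randrange(0, len(data) - test_bytes + 1)
    test = data[start:start + test_bytes]
    # Bytes before and after the held-out block remain available for training.
    train_pieces = [data[:start], data[start + test_bytes:]]
    return train_pieces, test
```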
Appendix B Compute-constrained modeling
As noted earlier, we evaluate and benchmark MambaByte in a compute-controlled setting. To this end, we estimate the FLOPs per byte incurred by various byte-level model architectures. We parameterize the architectures using the following hyperparameters: the number of (global/local) layers, the dimension of the (global/local) residual stream, the expansion factor of linear layers, the patch size in MegaByte, the state dimension in SSMs, the 1D convolution kernel size, and the low-rank projection dimension in Mamba. We also include the number of bytes in the input context. Detailed component-wise compute counts for the forward pass are included in Table 6.
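To illustrate how these hyperparameters enter a per-byte estimate, the following rough sketch counts the forward FLOPs of one Mamba block at 2 FLOPs per multiply-add. It assumes the usual block layout: an input projection to $2ed$ channels, a depthwise causal convolution of width $k$, a projection producing the low-rank $\Delta$ plus $B$ and $C$, a $\Delta$ up-projection, the selective scan, and an output projection. The per-state scan cost and the default `r` are placeholder assumptions rather than measured values, so totals will not exactly match Table 6.

```python
def mamba_layer_flops_per_byte(d, e=2, d_state=16, k=4, r=None,
                               scan_flops_per_state=9):
    """Rough forward FLOPs per byte for one Mamba block (2 FLOPs per multiply-add).

    d: residual width, e: expansion factor, d_state: SSM state size,
    k: conv kernel size, r: low-rank Delta dimension.
    scan_flops_per_state is a hypothetical constant, not a measured value.
    """
    d_inner = e * d
    if r is None:
        r = -(-d // 16)                        # assumed default: ceil(d / 16)
    in_proj = 2 * d * (2 * d_inner)            # x branch and gate branch
    conv = 2 * d_inner * k                     # depthwise causal 1D conv
    x_proj = 2 * d_inner * (r + 2 * d_state)   # low-rank Delta, B, and C
    dt_proj = 2 * r * d_inner                  # Delta back to d_inner channels
    scan = scan_flops_per_state * d_inner * d_state
    out_proj = 2 * d_inner * d
    return in_proj + conv + x_proj + dt_proj + scan + out_proj
```

A per-byte model total is then roughly the number of layers times this value (embeddings and output head excluded), e.g., `53 * mamba_layer_flops_per_byte(1024, e=2)` for a stack shaped like MambaByte-M, with the state size, kernel size, and rank left at the assumed defaults.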
For the medium-scale language modeling experiments (Table 1 of Yu et al. [2023]), Yu et al. employ the MegaByte-M+M model, with the context length and patch size reported there, trained for B bytes. As shown in Figure 5, MambaByte-M ($n = 53$, $d = 1{,}024$, $e = 2$) and MegaByte-M+M use the same total compute in FLOPs; hence, we use MambaByte-M to benchmark against MegaByte-M+M in Table 2.
Appendix C Training recipes
All the models in this study were trained using the AdamW optimizer. We used a linear learning-rate warm-up over the first training steps, followed by cosine annealing. Keeping consistent with MegaByte training [Yu et al., 2023], we used the same batch size as MegaByte across all our experiments. Additionally, we do not use dropout in any of our models.
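The schedule described above (linear warm-up to a peak learning rate, then cosine annealing) can be written as a plain function of the step. The function name, the per-step indexing, and the optional floor `min_lr` are assumptions, and all numeric values in the usage are hypothetical.

```python
import math


def lr_at(step: int, peak_lr: float, warmup_steps: int, total_steps: int,
          min_lr: float = 0.0) -> float:
    """Linear warm-up to peak_lr, then cosine annealing down to min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr after total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

In PyTorch, such a function can be passed to `torch.optim.lr_scheduler.LambdaLR` as a multiplier by calling it with `peak_lr=1.0` and setting the optimizer's base learning rate to the peak.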
For the experiments in Figure 1, we conducted a hyperparameter search over three peak learning rates and clipped the gradient norm to a fixed value for all the models. The best-observed performance curve for each model is reported in Figure 1. Furthermore, we use an improved Transformer recipe that uses RMSNorm instead of LayerNorm, rotary positional encodings [Su et al., 2021], and linear terms without bias (same as [Yu et al., 2023]).
In our medium-scale experiments shown in Table 2, we used a fixed peak learning rate and clipped the gradient norm. We trained MambaByte-M for a total of K steps, equivalent to B bytes.
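The step-to-byte equivalence follows from simple arithmetic: each optimizer step consumes one batch of fixed-length byte sequences. A sketch, assuming no padding and non-overlapping sequences; the helper names and example values are hypothetical.

```python
def bytes_seen(steps: int, batch_size: int, seq_len: int) -> int:
    """Training bytes consumed: each step processes batch_size sequences of seq_len bytes."""
    return steps * batch_size * seq_len


def steps_for_bytes(total_bytes: int, batch_size: int, seq_len: int) -> int:
    """Smallest number of steps that covers total_bytes (ceiling division)."""
    per_step = batch_size * seq_len
    return -(-total_bytes // per_step)
```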
In the large-scale experiment on PG19, we use a setting similar to that of the medium-scale experiments, with a fixed peak learning rate and gradient-norm clipping. MambaByte-M is trained for K steps, equivalent to B bytes.
Appendix D Discretization and selection
Discretization has deep connections to continuous-time systems, which allows for desirable properties such as model normalization [Orvieto et al., 2023, Gu et al., 2023] and resolution invariance [Nguyen et al., 2022]. In this section, we show how zero-order hold discretization of a selective SSM can be viewed as a generalization of the gating mechanism in recurrent networks.
Selection mechanics and gating in recurrent networks.
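It is interesting to note from (4) that, in the scalar case with $A = -1$ and $B = 1$, zero-order hold gives $\bar{A}_t = \exp(-\Delta_t)$ and $\bar{B}_t = 1 - \exp(-\Delta_t) =: g_t$, so that $h_t = (1 - g_t)\, h_{t-1} + g_t\, x_t$, with $\lim_{\Delta_t \to \infty} g_t = 1$ and $\lim_{\Delta_t \to 0} g_t = 0$. A large $\Delta_t$ ($\Delta_t \to \infty$) makes the system focus only on the current input and forget the state. In contrast, a small $\Delta_t$ ($\Delta_t \to 0$) causes a transient input to be ignored.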
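The gating interpretation can be checked numerically. Assuming the scalar case $A = -1$, $B = 1$ and $\Delta_t = \mathrm{softplus}(z_t)$, ZOH gives $\bar{A}_t = e^{-\Delta_t}$ and $\bar{B}_t = 1 - e^{-\Delta_t}$. Since $e^{-\mathrm{softplus}(z)} = 1/(1 + e^{z})$, the input weight equals the logistic sigmoid $\sigma(z_t)$ and the two weights sum to one, so the update is a convex gate between the old state and the input. The helper names below are ours.

```python
import math

A, B = -1.0, 1.0  # scalar SSM parameters assumed for the gating view


def softplus(z: float) -> float:
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def zoh(delta: float):
    """Zero-order hold discretization of the scalar SSM (A, B)."""
    a_bar = math.exp(delta * A)          # weight on the previous state
    b_bar = (a_bar - 1.0) / A * B        # weight on the current input
    return a_bar, b_bar
```

For example, `zoh(softplus(z))[1]` matches `sigmoid(z)` up to floating-point error, and `zoh(50.0)` is essentially `(0.0, 1.0)`: the state is forgotten and the input is copied.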
Appendix E Evaluation metrics
Subword-based language models [Vaswani et al., 2017, Hawthorne et al., 2022, Hutchins et al., 2022] report their performance in word-level perplexity (PPL), while byte-level language models [Xue et al., 2022, Yu et al., 2023] report theirs in bits per byte (BPB). To facilitate meaningful comparisons, we report performance in BPB when benchmarking against byte-level models and in PPL when comparing to token-level models. In this section, we detail the conversion between word-level PPL and BPB.
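Irrespective of the underlying segmentation, the amount of information in a given dataset is constant. Simply put,

$$L_T \cdot \log_2 \mathrm{PPL} = L_B \cdot \mathrm{BPB}, \qquad (5)$$

where $L_T$ and $L_B$ are the lengths of the dataset in tokens and bytes, respectively. From (5), we observe:

$$\mathrm{PPL} = 2^{\frac{L_B}{L_T}\,\mathrm{BPB}} = \exp\!\left(\frac{L_B}{L_T}\,\ln(2)\,\mathrm{BPB}\right).$$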
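Since the total information is the same under either segmentation, $L_T \log_2 \mathrm{PPL} = L_B \cdot \mathrm{BPB}$, the conversion is a one-liner in each direction. A sketch, where `n_tokens` is the word count $L_T$ and `n_bytes` is the byte count $L_B$ of the evaluation split; the function names are hypothetical.

```python
import math


def bpb_to_word_ppl(bpb: float, n_bytes: int, n_tokens: int) -> float:
    """Word-level perplexity implied by bits per byte."""
    bits_per_token = bpb * n_bytes / n_tokens
    return 2.0 ** bits_per_token


def word_ppl_to_bpb(ppl: float, n_bytes: int, n_tokens: int) -> float:
    """Bits per byte implied by word-level perplexity."""
    return n_tokens * math.log2(ppl) / n_bytes
```

Because the ratio $L_B / L_T$ multiplies BPB inside the exponent, small BPB differences translate into noticeably larger PPL differences.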
For the PG19 dataset, we train MambaByte-M to minimize BPB over the training data and report word-level PPL on the test data. Split-wise values of the length terms in (5) for the PG19 dataset are tabulated in Table 8.
Appendix F PG19 generation samples
This section includes a few sample generations from MambaByte-M trained on the PG19 dataset. We use nucleus sampling [Holtzman et al., 2020] with a fixed threshold p and generate continuations of a fixed total length in bytes (including the given context prefix). Furthermore, we chose the same test-set prefixes used in Appendix F of Rae et al. [2020]. We observe that the model is able to continue the dialogue in the style of the prefix and effectively recall character names over hundreds of bytes.
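For completeness, nucleus (top-$p$) sampling over a next-byte distribution can be sketched as below: keep the smallest set of highest-probability bytes whose mass reaches $p$, renormalize, and sample. The function name and signature are hypothetical, and no particular threshold value is assumed.

```python
import random


def nucleus_sample(probs, p, rng):
    """Sample an index from the smallest top-probability set whose mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= p:
            break
    # Sampling r in [0, mass) renormalizes over the kept set implicitly.
    r = rng.random() * mass
    acc = 0.0
    for i in kept:
        acc += probs[i]
        if r < acc:
            return i
    return kept[-1]  # guard against floating-point round-off
```

For byte-level generation, `probs` would be the model's 256-way softmax at each step, and the sampled index is the next byte value.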