MambaByte: Token-free Selective State Space Model

Junxiong Wang, Tushaar Gangavarapu, Jing Nathan Yan, Alexander M. Rush

Introduction

When defining a language model, a base tokenization is typically used—either words (Bengio et al., 2000), subwords (Schuster and Nakajima, 2012; Sennrich et al., 2015; Wu et al., 2016; Wang et al., 2020), or characters (Gao et al., 2020a). Of these, subword tokenization has been the most popular choice, as it achieves a natural compromise between training efficiency and the ability to handle out-of-vocabulary words. However, several works (e.g., Xue et al. (2022)) have noted issues with subword tokenizers, such as a lack of robustness to typos, spelling and capitalization variations, and morphological changes.

Researchers (Clark et al., 2022; Xue et al., 2022; Yu et al., 2023) have employed an alternative approach of using byte sequences, i.e., an end-to-end mapping from raw data to predictions without any intermediate tokenization. Compared to subword models, byte-level language models can generalize more easily across orthographic and morphological variants. Of course, modeling text as bytes means that the resultant sequences are significantly longer than their subword counterparts. This pushes the efficiency issues upstream into the architecture itself.
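To make the length trade-off concrete, the sketch below maps a string to its UTF-8 byte IDs, a fixed vocabulary of 256 symbols with no tokenizer, and compares the resulting length with a naive whitespace split. The whitespace split is only an assumed stand-in for a real subword tokenizer, whose token count usually falls between the two.

```python
def to_byte_ids(text: str) -> list[int]:
    """Token-free input: each UTF-8 byte is one symbol from a fixed 256-entry vocabulary."""
    return list(text.encode("utf-8"))


def from_byte_ids(ids: list[int]) -> str:
    """Inverse mapping; errors="replace" guards against invalid UTF-8 in sampled bytes."""
    return bytes(ids).decode("utf-8", errors="replace")


text = "Tokenizers stumble on typos like 'langauge' and on accents like naïve."
byte_ids = to_byte_ids(text)
words = text.split()  # crude whitespace split standing in for a subword tokenizer

print(len(words), len(text), len(byte_ids))  # the non-ASCII "ï" costs two bytes
```

A misspelling such as "langauge" needs no special handling here: it is simply another byte sequence over the same 256 symbols.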

Efficiency issues are particularly pronounced for autoregressive Transformers (Vaswani et al., 2017), which dominate language modeling (Brown et al., 2020; Touvron et al., 2023). Due to the quadratic cost of attention, Transformers scale poorly for long (byte) sequences (Brown et al., 2020; Zhang et al., 2022). Researchers have compressed the internal Transformer representation to work with long sequences, for instance, by developing length-aware modeling approaches (Dai et al., 2020; Nawrot et al., 2022), where groups of tokens are merged within the intermediate layers. Recently, Yu et al. (2023) proposed the MegaByte Transformer, which uses compression in the form of fixed-size patches of bytes as a subword analog. As a result, MegaByte lowers computational costs, although our experiments (see Figure 1) indicate that patching can also degrade model performance compared to the standard Transformer.
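The following sketch illustrates fixed-size patching under simplifying assumptions: a byte sequence is padded to a multiple of the patch size $p$ and split into $L/p$ patches, so a global model attending over patches sees a sequence $p$ times shorter, and its pairwise attention cost shrinks by roughly $p^2$. The pad value and helper names are hypothetical, and MegaByte's patch embedding and local model are omitted.

```python
def patchify(byte_ids: list[int], p: int, pad: int = 0) -> list[list[int]]:
    """Group a byte sequence into fixed-size patches of p bytes, padding the tail (pad value assumed)."""
    remainder = len(byte_ids) % p
    if remainder:
        byte_ids = byte_ids + [pad] * (p - remainder)
    return [byte_ids[i : i + p] for i in range(0, len(byte_ids), p)]


def attention_pairs(seq_len: int) -> int:
    """Query-key pairs scored by full self-attention (ignoring the causal mask's factor of ~2)."""
    return seq_len * seq_len


L_ctx, p = 8192, 8
patches = patchify(list(range(256)) * 32, p)  # 8,192 dummy bytes
print(len(patches))  # 1024 patches for the global model
print(attention_pairs(L_ctx) // attention_pairs(len(patches)))  # 64, i.e. p ** 2
```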

In this work, we introduce MambaByte, an efficient and simple byte-level language model. The model is a straightforward adaptation of the recently introduced Mamba architecture (Gu and Dao, 2023), a linear-time approach for sequence modeling. Mamba builds off the approach pioneered by state space models (SSMs) (Gu et al., 2021; Gupta et al., 2022; Gu et al., 2022; Smith et al., 2023) by introducing a selection mechanism that is more effective for discrete data such as text and providing an efficient GPU implementation. Our simple observation is that using Mamba (without modifications) relieves the main computational bottleneck in language modeling, thus allowing for the elimination of patching and effective use of the available compute budget.

Experiments compare MambaByte to Transformers, SSMs, and MegaByte (patching) architectures in a fixed parameter and fixed compute setting on several long-form text datasets. Figure 1 summarizes our main findings. Compared to byte-level Transformers, MambaByte achieves better performance faster and is significantly more compute efficient. We also consider the viability of token-free language models compared to the existing state-of-the-art subword models. In this regard, we find MambaByte to be competitive with various subword baselines despite handling significantly longer sequences. Our results establish MambaByte as a strong alternative to the existing tokenizer-dependent models and advocate its use to facilitate end-to-end learning.

Background: Selective state space sequence models

Mamba embeds the selective SSM layer into a full neural network language model. Specifically, the model utilizes a stack of gated layers inspired by the previous gated SSM (Mehta et al., 2023). Figure 3 shows the Mamba architecture combining the SSM layer with a gated neural network.
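For orientation, the toy sketch below mirrors this gated data flow in pure Python. An input projection produces a main branch $u$ and a gate branch $z$; $u$ passes through a short causal convolution and a selective SSM whose step size $\Delta_t$ and matrices $B_t$, $C_t$ depend on the input; the SSM output is multiplied by SiLU($z$), projected back, and added to the residual stream. Everything here is an assumption for illustration: random weights, a full (rather than low-rank) $\Delta$ projection, no normalization, and the simplified Euler step $\Delta B$ for the input matrix. The class name is hypothetical, and this is not the code in the official Mamba repository.

```python
import math
import random


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def softplus(v: float) -> float:
    return math.log1p(math.exp(v))


def matvec(W: list[list[float]], v: list[float]) -> list[float]:
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def rand_matrix(rows: int, cols: int, rng: random.Random, scale: float = 0.5) -> list[list[float]]:
    return [[rng.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]


class ToyGatedSSMBlock:
    """in_proj -> causal depthwise conv -> selective SSM, gated by SiLU(z) -> out_proj, plus residual."""

    def __init__(self, d: int, expand: int = 2, n_state: int = 4, k: int = 4, seed: int = 0):
        rng = random.Random(seed)
        self.d_in, self.n, self.k = expand * d, n_state, k
        self.W_in = rand_matrix(2 * self.d_in, d, rng)      # produces [u; z]
        self.conv = rand_matrix(self.d_in, k, rng)          # one causal kernel per channel
        self.W_dt = rand_matrix(self.d_in, self.d_in, rng)  # Δ_t depends on the input
        self.W_B = rand_matrix(n_state, self.d_in, rng)     # B_t depends on the input
        self.W_C = rand_matrix(n_state, self.d_in, rng)     # C_t depends on the input
        # A is fixed and negative, so exp(Δ·A) lies in (0, 1) and the state decays.
        self.A = [[-(j + 1.0) for j in range(n_state)] for _ in range(self.d_in)]
        self.D = [1.0] * self.d_in                          # skip connection inside the SSM
        self.W_out = rand_matrix(d, self.d_in, rng)

    def __call__(self, xs: list[list[float]]) -> list[list[float]]:
        proj = [matvec(self.W_in, x) for x in xs]
        us = [row[: self.d_in] for row in proj]
        zs = [row[self.d_in :] for row in proj]
        h = [[0.0] * self.n for _ in range(self.d_in)]  # fixed-size recurrent state
        outs = []
        for t, x in enumerate(xs):
            # Causal depthwise convolution over the last k positions, then SiLU.
            u = [
                silu(sum(self.conv[c][j] * us[t - j][c] for j in range(self.k) if t - j >= 0))
                for c in range(self.d_in)
            ]
            dt = [softplus(v) for v in matvec(self.W_dt, u)]
            B, C = matvec(self.W_B, u), matvec(self.W_C, u)
            y = []
            for c in range(self.d_in):
                for j in range(self.n):
                    # ZOH for A; simplified (Euler) discretization Δ·B for B.
                    h[c][j] = math.exp(dt[c] * self.A[c][j]) * h[c][j] + dt[c] * B[j] * u[c]
                y.append(sum(C[j] * h[c][j] for j in range(self.n)) + self.D[c] * u[c])
            gated = [yc * silu(zc) for yc, zc in zip(y, zs[t])]
            outs.append([xi + oi for xi, oi in zip(x, matvec(self.W_out, gated))])
        return outs


block = ToyGatedSSMBlock(d=4)
seq = [[math.sin(0.7 * t + i) for i in range(4)] for t in range(10)]
out = block(seq)
print(len(out), len(out[0]))  # 10 4
```

The loop makes the linear-time, causal nature explicit: each position updates a state of fixed size, and changing a later input cannot affect earlier outputs.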

Experimental setup

Our experiments compare MambaByte to other byte-level Transformers and SSMs. All our models employ the same training recipes (see Appendix C for details). We utilize a set of diverse long-form text datasets: PG19 (Rae et al., 2020), Stories (Trinh and Le, 2018), Books (Gao et al., 2020b), ArXiv (Gao et al., 2020b), and Code (Gao et al., 2020b). Dataset sizes and average document lengths are included in Appendix A.

Performance comparison across architectures requires care. To this end, we consider two settings: compute-matched and parameter-matched. This setup is necessary as the default MegaByte Transformer employs a global module that works with 8×-patched representations of the input, thus using 8× fewer feed-forward FLOPs per byte than a raw Transformer, while having significantly more parameters. Table 1 shows the MegaByte and MambaByte model sizes employed in our experiments. The (forward pass) FLOPs computation for various model architectures and the associated hyperparameters employed are detailed in Appendix B.
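As a rough sanity check on the compute-matched pairing, the sketch below uses the common approximation of about $2N$ forward FLOPs per processed position for a model with $N$ parameters (one multiply and one add per weight). It ignores the attention and SSM-scan terms that Appendix B counts in detail. MegaByte's global module runs once per patch of $p$ bytes, so its per-byte cost is divided by $p$. The parameter counts come from the model names used in this paper; the $2N$ rule is our assumption, not the Table 6 formula.

```python
def approx_forward_flops_per_byte(params: float, positions_per_byte: float = 1.0) -> float:
    """~2 FLOPs (multiply + add) per parameter per processed position; ignores attention/scan terms."""
    return 2.0 * params * positions_per_byte


p = 8  # MegaByte patch size
megabyte = approx_forward_flops_per_byte(758e6, 1 / p) + approx_forward_flops_per_byte(262e6)
mambabyte = approx_forward_flops_per_byte(353e6)

print(f"MegaByte-758M+262M: {megabyte / 1e6:.1f} MFLOPs/byte")  # 713.5
print(f"MambaByte-353M:     {mambabyte / 1e6:.1f} MFLOPs/byte")  # 706.0
```

Under this approximation, the two models land within about 1% of each other per byte, even though MegaByte has almost three times as many parameters.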

All MambaByte models were trained using the open-source Mamba code base (https://github.com/state-spaces/mamba). During training, we shuffle the documents and use contiguous sequences of 8,192 bytes (one per document), starting from a random position. We enable mixed-precision training using BF16 for training efficiency at scale. The optimizer, learning rate scheduler, and other training details are specified in Appendix C.
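A minimal sketch of this sampling scheme, assuming each document is available as a `bytes` object: documents are shuffled, and each contributes one contiguous 8,192-byte window starting at a uniformly random offset. The text does not say how documents shorter than the context are handled, so the sketch pads them with a hypothetical pad byte 0. The helper names are illustrative.

```python
import random


def sample_windows(docs: list[bytes], ctx_len: int = 8192, pad: int = 0, seed: int = 0) -> list[list[int]]:
    """Shuffle documents, then take one contiguous ctx_len-byte window from each at a random offset."""
    rng = random.Random(seed)
    order = list(range(len(docs)))
    rng.shuffle(order)
    windows = []
    for i in order:
        doc = docs[i]
        if len(doc) <= ctx_len:
            # Short-document handling is not specified in the text; padding is an assumption.
            windows.append(list(doc) + [pad] * (ctx_len - len(doc)))
        else:
            start = rng.randrange(len(doc) - ctx_len + 1)  # uniform random start position
            windows.append(list(doc[start : start + ctx_len]))
    return windows


def make_doc(n: int, seed: int) -> bytes:
    """Toy document of n random bytes."""
    rng = random.Random(seed)
    return bytes(rng.randrange(256) for _ in range(n))


docs = [make_doc(20_000 + k, seed=k) for k in range(3)] + [b"a short document"]
batch = sample_windows(docs)
inputs = [w[:-1] for w in batch]   # next-byte prediction: the model reads these ...
targets = [w[1:] for w in batch]   # ... and predicts the same bytes shifted by one
```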

Press et al. (2021) proposed using a sliding window to trade off speed for performance during inference. Following this, we employ a sliding window (with a stride of $L_{\text{ctx}}/2$ for a byte sequence of length $L_{\text{ctx}}$) when comparing with the state-of-the-art subword models in Table 3.
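The sliding-window evaluation can be sketched as follows, assuming a hypothetical scoring function `log2_prob(context, target)` that returns the model's log2-probability of the next byte. Windows of $L_{\text{ctx}}$ bytes advance by $L_{\text{ctx}}/2$; after the first window, only the last $L_{\text{ctx}}/2$ positions of each window are scored, so each byte is scored exactly once and every byte after the first window sees at least $L_{\text{ctx}}/2$ bytes of context. A real implementation would score a whole window in one forward pass; per-byte calls keep the sketch simple.

```python
from typing import Callable, Sequence


def sliding_window_bpb(
    data: Sequence[int],
    log2_prob: Callable[[Sequence[int], int], float],
    ctx_len: int,
) -> float:
    """Bits per byte using windows of ctx_len bytes and a stride of ctx_len // 2."""
    stride = ctx_len // 2
    total_bits, scored = 0.0, 0
    next_to_score = 0  # first position not yet scored
    start = 0
    while next_to_score < len(data):
        end = min(start + ctx_len, len(data))
        for pos in range(next_to_score, end):
            # Context is the part of the current window that precedes pos.
            total_bits -= log2_prob(data[start:pos], data[pos])
            scored += 1
        next_to_score = end
        start += stride
    return total_bits / scored


uniform = lambda context, byte: -8.0  # log2(1/256): a model with no information
print(sliding_window_bpb(list(range(100)), uniform, ctx_len=16))  # 8.0
```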

Results

Table 2 shows the bits per byte (BPB) across each dataset. For this experiment, the MegaByte-758M+262M and MambaByte models use the same number of FLOPs per byte (see Table 1). We observe that MambaByte consistently outperforms MegaByte across all datasets. Furthermore, although we could not train MambaByte for the full 80B bytes due to monetary constraints, MambaByte outperforms MegaByte with about 63% less compute and training data. MambaByte-353M also outperforms the byte-level Transformer and PerceiverAR.

How does MambaByte outperform a much larger model in so few training steps? Figure 1 explores this relationship further by looking at models with the same number of parameters. The graphs indicate that for MegaByte models of the same parameter size, models with less input patching perform better, but when compute-normalized, they perform similarly. In fact, a full-length Transformer, while slow in an absolute sense, also performs similarly to MegaByte when compute-normalized. In contrast, switching to the Mamba architecture significantly improves both the compute usage and the model performance.

Following these findings, Table 3 compares larger versions of these models on the PG19 dataset. For this experiment, we compare MambaByte-972M with MegaByte-1.3B+350M and other byte-level models, as well as several state-of-the-art subword models. (The conversion from BPB to perplexity (PPL) is detailed in Appendix E.) We find that MambaByte-972M, even when trained for only 150B bytes, outperforms all the byte-level models and achieves competitive performance with subword models.

Autoregressive inference in Transformer models requires caching the entire context, which can significantly affect the generation speed. MambaByte does not suffer from this bottleneck as it maintains a single hidden state per layer that evolves with time, enabling constant time per generation step. Table 4 compares the text generation speeds of MambaByte-972M and MambaByte-1.6B with MegaByte-1.3B+350M on an A100 80GB PCIe GPU. While MegaByte significantly reduces the generation cost through patching, we observe MambaByte to be 2.6× faster in a parameter-matched setting due to its use of recurrent generation. Appendix F includes more information about the generation process.
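To see why recurrent generation keeps a constant per-step cost, the sketch below counts the cached values per layer that each architecture must hold while generating. A Transformer's key/value cache grows with every generated byte, while a Mamba-style layer keeps an SSM state of size $e\,d \times n_{\text{state}}$ plus a short convolution buffer of about $k$ entries per channel, independent of position. The dimensions are illustrative assumptions, not the configurations of the models in Table 4.

```python
def kv_cache_values(d: int, t: int) -> int:
    """Per-layer Transformer cache after t positions: one key and one value vector of width d each."""
    return 2 * d * t


def recurrent_state_values(d: int, expand: int, n_state: int, k: int) -> int:
    """Per-layer Mamba-style state: SSM state (expand*d x n_state) plus a k-step conv buffer."""
    d_inner = expand * d
    return d_inner * n_state + d_inner * k


d, expand, n_state, k = 1024, 2, 16, 4  # illustrative sizes (assumption)
for t in (1_024, 8_192, 65_536):
    print(f"t={t:>6}: KV cache {kv_cache_values(d, t):>11,}  vs  SSM state {recurrent_state_values(d, expand, n_state, k):,}")
```

Besides memory, each new byte in a Transformer attends over all $t$ cached positions, whereas the recurrent update touches only the fixed-size state.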

Conclusion

We introduce MambaByte, a token-free SSM for modeling long byte sequences. MambaByte outperforms other byte-level models over several datasets and shows competitive results with subword Transformers, thus serving as a promising tokenization alternative. SSMs also enable significantly faster text generation due to their recurrent nature, making byte models practical. Our findings establish the possibility of token-free language modeling in future large models.

References

Appendix A Dataset specifics

We benchmark our results on various long-form text datasets. The PG19 dataset [Rae et al., 2020] is an extensive collection of full-length English books (written before 1919) from the Project Gutenberg online library. The PG19 dataset is ideal for testing long-distance context modeling [Gao et al., 2020b]. The Stories dataset [Trinh and Le, 2018] is a subset of the CommonCrawl data used for commonsense reasoning and language modeling. The Books dataset [Gao et al., 2020b] is another collection of English books. The ArXiv dataset [Gao et al., 2020b] comprises technical publications in LaTeX from the arXiv online archive. Finally, the Code dataset [Gao et al., 2020b] is a large dataset of publicly available open-source code (under Apache, MIT, or BSD licenses). Dataset statistics are tabulated in Table 5.

For the PG19 dataset, we employ the train, validation, and test data splits as indicated by Rae et al. [2020]. For the Stories, Books, ArXiv, and Code datasets, we randomly sample 40M consecutive bytes for testing and use the rest to train MambaByte.
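One way to realize such a held-out split, as a sketch: choose a random offset, take that many consecutive bytes as the test set, and keep the bytes before and after it for training. The function name, the seed, and keeping the two remaining pieces as separate training segments (so no training sequence spans the cut) are assumptions. The toy corpus is far smaller than the 40M-byte test size used here.

```python
import random


def split_contiguous_test(data: bytes, test_size: int, seed: int = 0) -> tuple[bytes, list[bytes]]:
    """Hold out test_size consecutive bytes at a random offset; return (test, train segments)."""
    if test_size > len(data):
        raise ValueError("test_size exceeds dataset length")
    start = random.Random(seed).randrange(len(data) - test_size + 1)
    test = data[start : start + test_size]
    # Keep the two remaining pieces separate so no training sequence spans the cut.
    train = [seg for seg in (data[:start], data[start + test_size :]) if seg]
    return test, train


corpus = bytes(range(256)) * 1000  # 256,000-byte toy corpus; the paper holds out 40M bytes
test, train = split_contiguous_test(corpus, test_size=10_000)
print(len(test), sum(len(s) for s in train))  # 10000 246000
```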

Appendix B Compute-constrained modeling

As noted earlier, we evaluate and benchmark MambaByte in a compute-controlled setting. To this end, we estimate the FLOPs per byte incurred by various byte-level model architectures. We parameterize the architectures using the following hyperparameters: the number of (global/local) layers $n$ ($n_g/n_l$), the dimension $d$ ($d_g/d_l$) of the (global/local) residual stream, the expansion factor $e$ of linear layers, the patch size $p$ in MegaByte, the state dimension $n_{\text{state}}$ in SSMs, the 1D convolution kernel size $k$, and the low-rank projection dimension $r$ in Mamba. We denote the number of bytes in the input context by $L_{\text{ctx}}$. Detailed component-wise compute counts for the forward pass are included in Table 6.

For the medium-scale language modeling experiments, Yu et al. [2023] (Table 1, §5) employ the MegaByte-758M+262M model, with a context length of 8,192 and a patch size of 8, trained for 80B bytes. As shown in Figure 5, MambaByte-353M ($n = 53$, $d = 1{,}024$, $e = 2$) and MegaByte-758M+262M use the same total compute in FLOPs; hence, we employ MambaByte-353M to benchmark against MegaByte-758M+262M in Table 2 (Results section).
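As a rough cross-check of the MambaByte-353M size, the sketch below counts the parameters of a stack of Mamba layers from the hyperparameters above. It assumes values not stated in this appendix: $n_{\text{state}} = 16$, $k = 4$, $r = \lceil d/16 \rceil$, biases only on the convolution and the $\Delta$ projection, one RMSNorm per layer, and a 256-entry byte embedding tied with the output layer. We believe these match the open-source Mamba layer's defaults, but treat them as assumptions.

```python
import math


def mamba_layer_params(d: int, e: int = 2, n_state: int = 16, k: int = 4) -> int:
    """Parameter count of one Mamba layer plus its pre-norm (assumed defaults)."""
    d_inner = e * d
    r = math.ceil(d / 16)                 # low-rank Δ projection dimension
    in_proj = d * 2 * d_inner             # produces the SSM branch and the gate branch
    conv = d_inner * k + d_inner          # depthwise conv weights + bias
    x_proj = d_inner * (r + 2 * n_state)  # input-dependent Δ (rank r), B, C
    dt_proj = r * d_inner + d_inner       # lifts Δ back to d_inner, with bias
    a_log, skip_d = d_inner * n_state, d_inner
    out_proj = d_inner * d
    norm = d
    return in_proj + conv + x_proj + dt_proj + a_log + skip_d + out_proj + norm


def mambabyte_params(n: int, d: int, e: int = 2, vocab: int = 256) -> int:
    embedding = vocab * d                 # tied with the output projection (assumption)
    final_norm = d
    return n * mamba_layer_params(d, e) + embedding + final_norm


print(f"{mambabyte_params(n=53, d=1024, e=2) / 1e6:.1f}M")  # 353.6M
```

Under these assumptions, the $6 e d^2$ from the input and output projections dominates each layer, which is why the total lands close to the 353M label.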

Appendix C Training recipes

All the models in this study were trained using an AdamW optimizer with $\beta = (0.9, 0.95)$. We used a linear learning rate warm-up (for the first 500 steps) followed by cosine annealing. Keeping consistent with MegaByte training [Yu et al., 2023], we used a batch size of 48 across all our experiments. Additionally, we do not use dropout with any of our models.
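The learning-rate schedule can be sketched as a function of the training step: linear warm-up over the first 500 steps to the peak rate, then cosine annealing. The decay horizon (taken to be the total number of steps) and the final learning rate of zero are assumptions; the text does not state a minimum learning rate.

```python
import math


def lr_at(step: int, peak_lr: float, total_steps: int, warmup_steps: int = 500, min_lr: float = 0.0) -> float:
    """Linear warm-up to peak_lr, then cosine annealing down to min_lr at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Medium-scale setting: peak 4e-4 over 80K steps.
for s in (0, 499, 500, 40_250, 80_000):
    print(s, f"{lr_at(s, 4e-4, 80_000):.2e}")
```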

For the experiments in Figure 1, we conducted a hyperparameter search using peak learning rates of 0.0002, 0.0006, and 0.0008 and clipped the gradient norm to 1.0 for all the models. The best-observed performance curve for each model is reported in Figure 1. Furthermore, we use an improved Transformer recipe that uses RMSNorm instead of LayerNorm, rotary positional encodings [Su et al., 2021], and linear terms without bias (same as [Yu et al., 2023]).
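For reference, the sketch below contrasts the two normalizations. LayerNorm subtracts the mean, divides by the standard deviation, and applies a gain and bias; RMSNorm only rescales by the root mean square and keeps a learned gain, with no bias. The epsilon value and the pure-Python form are our choices for illustration.

```python
import math


def layer_norm(x: list[float], gain: list[float], bias: list[float], eps: float = 1e-5) -> list[float]:
    """Center by the mean, scale by the standard deviation, then apply gain and bias."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gain, bias)]


def rms_norm(x: list[float], gain: list[float], eps: float = 1e-5) -> list[float]:
    """No mean subtraction and no bias: rescale by the root mean square only."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for v, g in zip(x, gain)]


x = [1.0, 2.0, 3.0, 6.0]
print(rms_norm(x, [1.0] * 4))
print(layer_norm(x, [1.0] * 4, [0.0] * 4))
```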

In our medium-scale experiments shown in Table 2, we set the peak learning rate to 0.0004 and clipped the gradient norm to 0.1. We trained MambaByte-353M for a total of 80K steps, equivalent to 80,000 × 48 × 8,192 ≈ 30B bytes.

In the large-scale experiment on PG19, we use a similar setting to that in the medium-scale experiments: the peak learning rate is set to 0.0004, and the gradient norm is clipped to 0.1. MambaByte-972M is trained for 380K steps, equivalent to 380,000 × 48 × 8,192 ≈ 150B bytes.
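The byte counts above follow from steps × batch size × context length; the short computation below reproduces them. The exact products are about 31.5B and 149.4B bytes, which the text rounds to roughly 30B and 150B.

```python
BATCH_SIZE = 48     # sequences per step
CTX_BYTES = 8_192   # bytes per sequence
RUNS = {"MambaByte-353M": 80_000, "MambaByte-972M": 380_000}  # training steps

totals = {name: steps * BATCH_SIZE * CTX_BYTES for name, steps in RUNS.items()}
for name, total in totals.items():
    print(f"{name}: {total:,} bytes (~{total / 1e9:.1f}B)")
```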

Appendix D Discretization and selection

Discretization has deep connections to continuous-time systems, which allows for desirable properties such as model normalization [Orvieto et al., 2023, Gu et al., 2023] and resolution invariance [Nguyen et al., 2022]. In this section, we show how zero-order hold discretization of a selective SSM can be viewed as a generalization of the gating mechanism in recurrent networks.

Selection mechanics and gating in recurrent networks.

It is interesting to note from (4) that $\lim_{\Delta\to\infty} h[k] = x[k]$ and $\lim_{\Delta\to 0} h[k] = h[k-1]$: a large $\Delta$ ($\Delta\to\infty$) makes the system focus only on the current input and forget the state. In contrast, a small $\Delta$ ($\Delta\to 0$) causes a transient input to be ignored.
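A numerical illustration of these limits, under the scalar setting used in the Mamba paper's gating argument (Gu and Dao, 2023), with a single state, $A = -1$, and $B = 1$. Zero-order hold then gives $\bar{A} = e^{-\Delta}$ and $\bar{B} = 1 - e^{-\Delta}$, so the update is the convex combination $h[k] = e^{-\Delta} h[k-1] + (1 - e^{-\Delta})\, x[k]$, a gated recurrence with gate $g = 1 - e^{-\Delta}$. If $\Delta = \operatorname{softplus}(z)$, this gate equals the sigmoid $\sigma(z)$.

```python
import math


def zoh_scalar_step(h_prev: float, x: float, delta: float) -> float:
    """ZOH-discretized scalar SSM with A = -1, B = 1: a gated (convex) update."""
    a_bar = math.exp(-delta)   # exp(Δ·A) with A = -1
    b_bar = 1.0 - a_bar        # (ΔA)^{-1}(exp(ΔA) - 1)·ΔB with B = 1
    return a_bar * h_prev + b_bar * x


def softplus(z: float) -> float:
    return math.log1p(math.exp(z))


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


h_prev, x = 0.25, 1.0
for delta in (1e-6, 1.0, 50.0):
    print(delta, zoh_scalar_step(h_prev, x, delta))
# Small Δ keeps h[k-1] (input ignored); large Δ copies x[k] (state forgotten).

z = 0.3
gate = 1.0 - math.exp(-softplus(z))
print(gate, sigmoid(z))  # equal up to rounding
```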

Appendix E Evaluation metrics

Subword-based language models [Vaswani et al., 2017, Hawthorne et al., 2022, Hutchins et al., 2022] report their performance in word-level PPL, while byte-level language models [Xue et al., 2022, Yu et al., 2023] report theirs in BPB. To facilitate meaningful comparisons, we report performance in BPB when benchmarking against byte-level models and PPL when comparing to token-level models. In this section, we detail the conversion between word-level PPL and BPB.

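Irrespective of the underlying segmentation, the amount of information $I(D)$ in a given dataset $D$ is constant. Simply put,

$$I(D) = L_T \log_2(\operatorname{PPL}) = L_B \operatorname{BPB}, \tag{5}$$

where $L_T$ and $L_B$ are the length of the dataset in tokens and bytes, respectively, and $\log_2(\operatorname{PPL})$ is the number of bits per token. From (5), we observe:

$$\operatorname{PPL} = 2^{\operatorname{BPB} \cdot L_B / L_T}.$$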

For the PG19 dataset, we train MambaByte-972M to minimize BPB over the training data and report word-level PPL on the test data. Split-wise values of $L_B/L_T$ for the PG19 dataset are tabulated in Table 8.
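Because both metrics count the same total information, the conversion is short: multiply BPB by the bytes-per-word ratio $L_B/L_T$ to obtain bits per word, then exponentiate base 2, i.e., $\operatorname{PPL} = 2^{\operatorname{BPB} \cdot L_B / L_T}$. The ratio and BPB values below are placeholders, not the Table 8 statistics or the paper's results.

```python
import math


def bpb_to_word_ppl(bpb: float, bytes_per_word: float) -> float:
    """Word-level perplexity from bits per byte: PPL = 2 ** (BPB * L_B / L_T)."""
    bits_per_word = bpb * bytes_per_word
    return 2.0 ** bits_per_word


def word_ppl_to_bpb(ppl: float, bytes_per_word: float) -> float:
    """Inverse conversion: BPB = log2(PPL) * L_T / L_B."""
    return math.log2(ppl) / bytes_per_word


ratio = 5.0   # hypothetical L_B / L_T (bytes per word); use the split-specific value in practice
bpb = 0.8     # hypothetical bits per byte
ppl = bpb_to_word_ppl(bpb, ratio)
print(ppl)    # 16.0 (0.8 bits/byte × 5 bytes/word = 4 bits/word)
```

Since the ratio differs between splits, the value for the split being evaluated (here, the test split) should be used.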

Appendix F PG19 generation samples

This section includes a few sample generations from MambaByte-972M trained on the PG19 dataset. We use nucleus sampling with $p = 0.98$ [Holtzman et al., 2020] and generate continuations for a total of 8,192 bytes (including the given context prefix). Furthermore, we chose the same test set prefixes used in Appendix F of Rae et al. [2020]. We observe that the model is able to continue the dialogue in the style of the prefix and effectively recall the character names over hundreds of bytes.
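A minimal sketch of nucleus (top-$p$) sampling over the 256 byte values: keep the smallest set of most probable bytes whose cumulative probability reaches $p$, renormalize, and sample from that set. The input is assumed to be an already-computed probability vector (a hypothetical model output); the toy distribution and helper names are illustrative.

```python
import random
from typing import Optional


def nucleus_filter(probs: list[float], p: float) -> dict[int, float]:
    """Smallest set of highest-probability bytes with cumulative mass >= p, renormalized."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= p:
            break
    return {i: probs[i] / mass for i in kept}


def nucleus_sample(probs: list[float], p: float = 0.98, rng: Optional[random.Random] = None) -> int:
    """Draw one byte value from the renormalized nucleus."""
    rng = rng or random.Random()
    ids, weights = zip(*nucleus_filter(probs, p).items())
    return rng.choices(ids, weights=weights, k=1)[0]


probs = [0.01 / 254] * 256
probs[101], probs[32] = 0.7, 0.29  # toy values: 'e' and ' ' dominate
rng = random.Random(0)
samples = [nucleus_sample(probs, 0.98, rng) for _ in range(1000)]
print(sorted(set(samples)))  # [32, 101]
```

With $p = 0.98$, the long tail of improbable bytes is cut off, which helps avoid the occasional invalid or off-style byte that plain sampling over all 256 values would produce.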

The Diary of Samuel Pepys

The Patrol of the Sun Dance Trail by Ralph Connor