Improving language models by retrieving from trillions of tokens
Sebastian Borgeaud, Arthur Mensch, Jordan Hoffmann, Trevor Cai, Eliza Rutherford, Katie Millican, George van den Driessche, Jean-Baptiste Lespiau, Bogdan Damoc, Aidan Clark, Diego de Las Casas, Aurelia Guy, Jacob Menick, Roman Ring, Tom Hennigan, Saffron Huang, Loren Maggiore, Chris Jones, Albin Cassirer, Andy Brock, Michela Paganini, Geoffrey Irving, Oriol Vinyals, Simon Osindero, Karen Simonyan, Jack W. Rae, Erich Elsen, Laurent Sifre
Introduction
Language modelling (LM) is an unsupervised task that consists of modelling the probability of text, usually by factorising it into conditional next-token predictions, $p(x_1, \ldots, x_n) = \prod_i p(x_i \mid x_{<i})$. Neural networks have proven to be powerful language models, first in the form of recurrent architectures (Mikolov et al., 2010; Graves, 2013; Jozefowicz et al., 2016) and more recently in the form of Transformers (Vaswani et al., 2017), which use attention to contextualise the past. Large performance improvements have come from increasing the amount of data, training compute, or model parameters. Transformers have been scaled from models with millions of parameters in seminal work to over a hundred billion parameters (Radford et al., 2019; Brown et al., 2020) in the last two years, which has led to models that do very well on a wide array of tasks in a zero- or few-shot formulation. Increasing model size predictably improves performance on a wide range of downstream tasks (Kaplan et al., 2020). The benefits of increasing the number of parameters come from two factors: additional computations at training and inference time, and increased memorization of the training data.
In this work, we endeavor to decouple these by exploring efficient means of augmenting language models with a massive-scale memory without significantly increasing computations. Specifically, we suggest retrieval from a large text database as a complementary path to scaling language models. Instead of increasing the size of the model and training on more data, we equip models with the ability to directly access a large database to perform predictions: a semi-parametric approach. At a high level, our Retrieval Transformer (Retro) model splits the input sequence into chunks and retrieves text similar to the previous chunk to improve the predictions in the current chunk. Existing retrieval work for language modelling only considers small transformers (millions of parameters) and databases of limited size (up to billions of tokens) (Khandelwal et al., 2020; Yogatama et al., 2021; Guu et al., 2020; Lewis et al., 2020). To our knowledge, our work is the first to show the benefits of scaling the retrieval database to trillions of tokens for large parametric language models. Our main contributions are the following.
We introduce Retro, a retrieval-enhanced autoregressive language model (§2.2). We use a chunked cross-attention module to incorporate the retrieved text (§2.4), with time complexity linear in the amount of retrieved data. We show that retrieving based on a pre-trained frozen Bert model (§2.3) works at scale, removing the need for training and updating a retriever network.
We show that our method scales well with model size and database size (Fig. 1): Retro provides a constant gain for models ranging from 150M to 7B parameters, and Retro can be improved at evaluation time by increasing the database size and the number of retrieved neighbours. Our largest model obtains state-of-the-art results on a range of downstream evaluation datasets including Wikitext103 (Merity et al., 2017) and the Pile (Gao et al., 2020) (§4). We show that Retro can be fine-tuned to achieve competitive performance on downstream tasks such as question answering (§4.3).
We propose an evaluation aware of the proximity of test documents to the training set (§2.6), addressing the problem of test set leakage (Lee et al., 2021). This is relevant for all language models, and especially for retrieval-enhanced models since they have direct access to the training dataset during evaluation. Using this methodology, we show that the performance of Retro comes from both explicit neighbour copying and general knowledge extraction (§4.4).
Method
We design our retrieval-enhanced architecture to be capable of retrieving from a database with trillions of tokens. For this purpose, we retrieve at the level of contiguous token chunks instead of individual tokens, which reduces storage and computation requirements by a large linear factor. Our method first constructs a key-value database, where values store raw chunks of text tokens and keys are frozen Bert embeddings (Devlin et al., 2019). We use a frozen model to avoid having to periodically re-compute embeddings over the entire database during training. Each training sequence is then split into chunks, which are augmented with their $k$-nearest neighbours retrieved from the database. An encoder-decoder architecture integrates retrieval chunks into the model’s predictions. We summarize the Retro architecture in Fig. 2, and detail it in this section. We end the section by introducing a new methodology to evaluate language models when an evaluation set is partially present in the training set.
We use a multi-lingual version of MassiveText (Rae et al., 2021) for both training and retrieval data. The dataset consists of text documents from multiple sources and multiple languages totalling over 5 trillion tokens (detailed in Table 1). Sequences are sampled from subsets of the training data, with sampling weights given in the right-most column of Table 1. We tokenize the dataset using SentencePiece (Kudo and Richardson, 2018) with a vocabulary of 128,000 tokens. During training (unless otherwise specified), we retrieve from a 600B-token subset of the training data. The training retrieval database is made of the same subsets as the training data, in proportions that match the training sampling frequencies. During evaluation, the retrieval database consists of the full union of these datasets, with the exception of books, for which we use a sub-sample of 4%. The evaluation retrieval database thus contains 1.75T tokens. To limit test set leakage, we compute the $n$-gram Jaccard similarity between train and test documents using the MinHash scheme and remove all training documents with high similarity (0.8 or higher) to a validation or test set document. Additionally, we remove all validation and test articles of Wikitext103 (Merity et al., 2017) from our Wikipedia training data.
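The leakage filter above can be illustrated with a small MinHash sketch. This is not the pipeline used for MassiveText: the $n$-gram size is left as a parameter, seeded 64-bit BLAKE2b hashes stand in for random permutations, and the all-pairs comparison would in practice be replaced by locality-sensitive bucketing. The names `ngrams`, `minhash_signature` and `filter_training_docs` are illustrative.

```python
import hashlib
from typing import List, Sequence, Set, Tuple

Shingle = Tuple[str, ...]
SEP = "\x1f"  # unit separator, unlikely to occur inside tokens


def ngrams(tokens: Sequence[str], n: int) -> Set[Shingle]:
    """Set of contiguous n-grams ("shingles") of a token sequence."""
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}


def _seeded_hash(seed: int, shingle: Shingle) -> int:
    data = f"{seed}{SEP}{SEP.join(shingle)}".encode("utf-8")
    return int.from_bytes(hashlib.blake2b(data, digest_size=8).digest(), "big")


def minhash_signature(shingles: Set[Shingle], num_perm: int = 128) -> List[int]:
    """For each seeded hash function, keep the minimum hash over all shingles."""
    return [min((_seeded_hash(s, sh) for sh in shingles), default=2**64)  # sentinel for empty sets
            for s in range(num_perm)]


def estimated_jaccard(sig_a: List[int], sig_b: List[int]) -> float:
    """The fraction of matching minima is an unbiased estimate of Jaccard similarity."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


def filter_training_docs(train_docs, eval_docs, n: int, threshold: float = 0.8, num_perm: int = 128):
    """Drop training docs whose estimated n-gram Jaccard similarity to any eval doc is >= threshold."""
    eval_sigs = [minhash_signature(ngrams(d, n), num_perm) for d in eval_docs]
    kept = []
    for doc in train_docs:
        sig = minhash_signature(ngrams(doc, n), num_perm)
        if all(estimated_jaccard(sig, e) < threshold for e in eval_sigs):
            kept.append(doc)
    return kept


test_doc = "the quick brown fox jumps over the lazy dog".split()
train_docs = [
    list(test_doc),  # verbatim copy of a test document: should be removed
    "retrieval augments language models with a large text database".split(),
]
kept = filter_training_docs(train_docs, [test_doc], n=3)
```

Only the second training document survives, since the first shares all of its shingles with the test document.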
2.2 Retrieval-enhanced autoregressive token models
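As a reading aid, the autoregressive log-likelihood that the Retro architecture defines can be sketched as follows. This is our own write-up under stated assumptions, using the notation of the following subsections: a sequence $X = (x_1, \ldots, x_n)$ is split into $l$ chunks $C_1, \ldots, C_l$ of $m = n/l$ tokens, $\mathrm{Ret}(C_u)$ denotes the neighbours retrieved for chunk $C_u$ from database $\mathcal{D}$, and $\ell_\theta$ is the model's log-probability. Each token is conditioned on all previous tokens and on the neighbours of previous chunks only, which is what keeps retrieval causal.

```latex
L(X \mid \theta, \mathcal{D}) = \sum_{u=1}^{l} \sum_{i=1}^{m}
  \ell_\theta\Big( x_{(u-1)m+i} \;\Big|\; (x_j)_{j < (u-1)m+i},\;
  \big(\mathrm{Ret}(C_{u'})\big)_{u' < u} \Big)
```

For the first chunk ($u = 1$) the set of previous neighbours is empty, so its tokens are predicted as in a standard language model.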
2.3 Nearest neighbour retrieval
Our database consists of a key-value memory. Each value consists of two contiguous chunks of tokens, which we denote $[N, F]$, where $N$ is the neighbour chunk used to compute the key, and $F$ is its continuation in the original document. The corresponding key is the Bert embedding of $N$, averaged over time, which we denote $\mathrm{Bert}(N)$. For each chunk $C$, we retrieve its approximate $k$-nearest neighbours from our key-value database using the $L_2$ distance on Bert embeddings, $d(C, N) = \|\mathrm{Bert}(C) - \mathrm{Bert}(N)\|_2^2$. The model receives the corresponding values $\mathrm{Ret}(C) \triangleq ([N^1, F^1], \ldots, [N^k, F^k])$. Both neighbour chunks and their continuations provide meaningful improvements, as illustrated in our ablation study (Appendix D). We use a length of 64 for both $N^j$ and $F^j$, thus $\mathrm{Ret}(C)$ has a shape of $k \times r$ with $r = 128$. To avoid retrieving the chunk $C_{u+1}$ in the retrieval set $\mathrm{Ret}(C_u)$, which would break causality during training, we filter out neighbours originating from the same document as the training sequence $X$.
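A minimal sketch of the key-value lookup described above, under simplifying assumptions: exact brute-force search stands in for approximate search, and `embed` is a hypothetical stand-in for the frozen Bert encoder (any function mapping tokens to per-token vectors). It shows the three ingredients of the paragraph: time-averaged keys, squared $L_2$ distance, and the same-document filter that preserves causality.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Embedder = Callable[[Sequence[int]], List[Vector]]  # hypothetical: tokens -> per-token vectors


def mean_pool(token_vectors: Sequence[Vector]) -> Vector:
    """Average per-token vectors over time to get one key per chunk."""
    dim = len(token_vectors[0])
    return [sum(v[d] for v in token_vectors) / len(token_vectors) for d in range(dim)]


def squared_l2(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


@dataclass
class Entry:
    key: Vector              # time-averaged embedding of the neighbour chunk N
    neighbour: List[int]     # N: the chunk itself
    continuation: List[int]  # F: the next chunk of the same document (may be short at doc end)
    doc_id: int


def build_database(docs: Sequence[Tuple[int, List[int]]], chunk_len: int, embed: Embedder) -> List[Entry]:
    db = []
    for doc_id, tokens in docs:
        for start in range(0, len(tokens) - chunk_len + 1, chunk_len):
            n = tokens[start:start + chunk_len]
            f = tokens[start + chunk_len:start + 2 * chunk_len]
            db.append(Entry(mean_pool(embed(n)), n, f, doc_id))
    return db


def retrieve(chunk: List[int], query_doc_id: int, db: List[Entry], embed: Embedder, k: int):
    """Return the k values [N, F] whose keys are closest to the query chunk."""
    query = mean_pool(embed(chunk))
    # Causality filter: never retrieve from the document the query comes from.
    candidates = [e for e in db if e.doc_id != query_doc_id]
    candidates.sort(key=lambda e: squared_l2(query, e.key))
    return [(e.neighbour, e.continuation) for e in candidates[:k]]


def toy_embed(tokens: Sequence[int]) -> List[Vector]:
    """Hypothetical embedder used only for this example."""
    return [[float(t), float(t % 3)] for t in tokens]


docs = [(0, [1, 2, 3, 4, 5, 6]), (1, [1, 2, 3, 7, 8, 9]), (2, [50, 60, 70, 80, 90, 99])]
db = build_database(docs, chunk_len=3, embed=toy_embed)
neighbours = retrieve([1, 2, 3], query_doc_id=0, db=db, embed=toy_embed, k=1)
# neighbours == [([1, 2, 3], [7, 8, 9])]: the identical chunk in doc 0 is filtered out
```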
For a database of $T$ elements, we can query the approximate nearest neighbours in $\mathcal{O}(\log T)$ time. We use the ScaNN library (Guo et al., 2020) to achieve this. This means that we can query our trillion-token database fast enough to retrieve whilst evaluating or sampling from the model; this expense is amortized over a chunk length. Performing retrieval on-the-fly is too slow to keep up with the training calculations, so we leverage the frozen aspect of the embedding operator Bert to precompute all approximate nearest neighbours and save the results as part of the data. In Fig. 9 in the Appendix, we show results where we only retrieve neighbours within Wikipedia. We find that neighbours tend to come from 2-3 links away from a given article, whereas random articles are more than 5 links apart.
2.4 Retro model architecture
Since Ffw, Attn and Cca are all autoregressive operators whose output at position $i$ only depends on $(h_j)_{j \le i}$, any succession of Retro and Lm layers, followed by a token classification head, defines an autoregressive log-likelihood (1). An overview of the model architecture is given in Algorithm 1 and in Fig. 2. We next describe the retrieval encoder and the chunked cross-attention layer in more detail, and explain how to sample from Retro.
where the softmax is performed on the second dimension and all products are matrix products. We use multi-head cross-attention, and add positional encodings to the softmax (see §B.1.2).
The first $m-1$ tokens cannot attend to any neighbour of a previous chunk; at these positions, we define Cca as the identity, setting $\mathrm{Cca}(H, E)_j \triangleq h_j$ for all tokens $j \in [1, m-1]$. Finally, the last token $h_{lm}$ attends to the last retrieval set $E_l$ (not shown in Fig. 2). Listing 1 contains a simplified implementation of Cca. Note that chunked cross-attention is autoregressive: the output of Cca at position $i$ depends on the sequence of tokens from $0$ to $i$ that is input to Cca.
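The alignment rule above is easy to get wrong, so here is a minimal sketch of which retrieval set each position attends to. It is not Listing 1: it assumes single-head attention with no learned projections, flattens each retrieval set over neighbours and time, and omits residual connections; `attend` and `chunked_cross_attention` are illustrative names. Positions are 0-indexed, so the first $m-1$ tokens of the text are indices $0, \ldots, m-2$.

```python
import math
from typing import List, Sequence

Vector = List[float]


def attend(query: Vector, memory: Sequence[Vector]) -> Vector:
    """Single-head dot-product attention of one query over a set of memory vectors.

    Learned Q/K/V projections, multiple heads and positional terms are omitted.
    """
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in memory]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # numerically stable softmax
    total = sum(weights)
    dim = len(memory[0])
    return [sum(w * v[d] for w, v in zip(weights, memory)) / total for d in range(dim)]


def chunked_cross_attention(h: Sequence[Vector], enc: Sequence[Sequence[Vector]], m: int) -> List[Vector]:
    """h: n hidden states; enc[u]: encoded neighbours of chunk u (0-indexed), flattened.

    Position i attends to the neighbours of chunk u = (i + 1) // m - 1, so the last token
    of chunk u and the first m - 1 tokens of chunk u + 1 share the retrieval set of chunk u.
    The first m - 1 positions have no previous chunk and are passed through unchanged.
    """
    out = []
    for i, h_i in enumerate(h):
        u = (i + 1) // m - 1
        out.append(list(h_i) if u < 0 else attend(h_i, enc[u]))
    return out


# Toy example: m = 2, four 1-d hidden states, one encoded vector per retrieval set.
h = [[1.0], [2.0], [3.0], [4.0]]
enc = [[[10.0]], [[20.0]]]
print(chunked_cross_attention(h, enc, m=2))  # [[1.0], [10.0], [10.0], [20.0]]
```

In the toy output, only the final position sees the last retrieval set, matching the rule that the last token attends to $E_l$.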
With Retro models, even though each Cca cross-attention attends only to the neighbours of the preceding chunk, $\mathrm{Ret}(C_{u-1})$, the dependencies over previous neighbours are propagated via the self-attention operations. The activations of the $i$-th token in the $u$-th chunk therefore potentially depend upon the set of all previous neighbours $\mathrm{Ret}(C_{u'})_{u' < u}$, without incurring the quadratic cost of cross-attending to that set.
When sampling, at the end of a chunk $C_u$, we use ScaNN to retrieve neighbours $\mathrm{Ret}(C_u)$ based on the embedding $\mathrm{Bert}(C_u)$. The encoded neighbours $E_u$ are then used to condition the generation of the next chunk $C_{u+1}$, which we do incrementally. Overall, the cost of sampling is thus quadratic in the size of the sampled sequence, as when sampling from regular Transformers; the added cost of retrieval is linear in the number of chunks $l$, and is negligible compared to the token sampling cost in practice.
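The sampling procedure can be sketched as a loop that retrieves once per completed chunk, which makes the linear retrieval cost visible. Here `next_token` and `retrieve` are hypothetical callables standing in for the model's incremental decoding step and for the nearest-neighbour lookup; nothing about the real model's interface is implied.

```python
from typing import Callable, Dict, List, Sequence, Tuple


def sample_with_retrieval(
    prompt: Sequence[int],
    num_new_tokens: int,
    m: int,
    next_token: Callable[[List[int], Dict[int, object]], int],
    retrieve: Callable[[List[int]], object],
) -> Tuple[List[int], Dict[int, object]]:
    """Incremental sampling that retrieves once per completed chunk of m tokens.

    neighbours[u] holds the retrieval set of chunk u (0-indexed).
    """
    tokens = list(prompt)
    neighbours: Dict[int, object] = {}
    for u in range(len(tokens) // m):  # chunks already complete in the prompt
        neighbours[u] = retrieve(tokens[u * m:(u + 1) * m])
    for _ in range(num_new_tokens):
        tokens.append(next_token(tokens, neighbours))
        if len(tokens) % m == 0:  # a chunk was just completed: retrieve before the next token
            u = len(tokens) // m - 1
            neighbours[u] = retrieve(tokens[u * m:(u + 1) * m])
    return tokens, neighbours


calls = []


def dummy_retrieve(chunk: List[int]) -> tuple:
    """Hypothetical retriever that just records and echoes the query chunk."""
    calls.append(tuple(chunk))
    return tuple(chunk)


tokens, nbrs = sample_with_retrieval(
    [0, 1, 2], 6, m=4, next_token=lambda toks, nb: len(toks), retrieve=dummy_retrieve
)
# tokens == [0, 1, ..., 8]; two chunks were completed, so retrieval ran twice
```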
2.5 Baseline Transformer architecture
We use a transformer (Vaswani et al., 2017) similar to the one described by Radford et al. (2019), with some minimal changes: we replace LayerNorm with RMSNorm (Zhang and Sennrich, 2019) and use relative position encodings (Dai et al., 2019). As baselines, we train retrieval-free transformers with 132M, 368M, 1.3B and 7.0B parameters (embedding matrices are excluded from parameter counts). The hyperparameters we used are detailed in Table 2. All retrieval models use an encoder of the same size for the retrieval data, with 2 layers, which adds a fixed number of parameters. The encoder uses relative positional encodings. The retrieval models contain one Retro-block every 3 blocks, starting from layer 6. For our smallest model, Cca is applied in layers 6, 9 and 12 of the main pathway and also once for query conditioning in the encoder, which adds further parameters. The relative number of extra parameters decreases as we increase the baseline model size. All models are implemented using JAX (Bradbury et al., 2018) and Haiku (Hennigan et al., 2020).
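The placement rule "one Retro-block every 3 blocks, starting from layer 6" can be expressed as a one-line configuration helper. This is a hypothetical illustration, not the configuration code used for the paper; it assumes 1-indexed layers and a 12-layer main pathway, which is consistent with the layers 6, 9 and 12 listed for the smallest model.

```python
def retro_block_layers(num_layers: int, first: int = 6, every: int = 3) -> list:
    """1-indexed decoder layers that carry chunked cross-attention.

    Defaults follow the text: one Retro block every 3 blocks, starting from layer 6.
    The extra Cca used for query conditioning inside the encoder is not counted here.
    """
    return list(range(first, num_layers + 1, every))


print(retro_block_layers(12))  # [6, 9, 12]
```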
2.6 Quantifying dataset leakage exploitation
Retro models may arguably benefit more easily from evaluation dataset leakage, i.e. the fact that we evaluate on data that were also present in the training set. To better understand how retrieval improves language modelling performance, we therefore quantify evaluation likelihood as a function of the overlap between the evaluation and training datasets.
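The details of the overlap metric are not restated here, so the following is only a sketch of one way to implement chunk-level leakage filtering. It assumes, following the discussion in §4.4, that leakage is measured as the longest run of contiguous tokens an evaluation chunk shares with its retrieved training neighbours, normalised by chunk length; evaluation losses are then averaged over chunks below a threshold `alpha`. All names are illustrative.

```python
from typing import List, Sequence


def longest_common_run(a: Sequence[int], b: Sequence[int]) -> int:
    """Length of the longest contiguous token run shared by a and b (dynamic programming)."""
    best = 0
    prev = [0] * (len(b) + 1)
    for i in range(1, len(a) + 1):
        cur = [0] * (len(b) + 1)
        for j in range(1, len(b) + 1):
            if a[i - 1] == b[j - 1]:
                cur[j] = prev[j - 1] + 1
                best = max(best, cur[j])
        prev = cur
    return best


def leakage_ratio(chunk: Sequence[int], neighbours: Sequence[Sequence[int]]) -> float:
    """Longest shared run with any retrieved training neighbour, as a fraction of the chunk length."""
    longest = max((longest_common_run(chunk, nb) for nb in neighbours), default=0)
    return longest / len(chunk)


def filtered_mean_loss(chunk_losses: Sequence[float], ratios: Sequence[float], alpha: float) -> float:
    """Mean loss over chunks whose leakage ratio is at most alpha (NaN if none qualify)."""
    kept: List[float] = [loss for loss, r in zip(chunk_losses, ratios) if r <= alpha]
    return sum(kept) / len(kept) if kept else float("nan")


print(leakage_ratio([1, 2, 3, 4], [[9, 2, 3, 4, 7], [1, 5]]))  # 0.75
```

Sweeping `alpha` from 1 down towards 0 traces how the loss changes as increasingly leaked chunks are excluded.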
Related Work
We first review existing work on using retrieval for language modelling, and compare Retro to these works (see Table 3). As we train Retro models on a large dataset containing a substantial section of the internet, our work raises potential privacy, safety, and fairness issues that we then review.
Brants et al. (2007) show that scaling the training data to trillions of tokens improves the machine translation performance of $n$-gram models. More recently, GPT-2 (Radford et al., 2019), GPT-3 (Brown et al., 2020), and Jurassic-1 (Lieber et al., 2021) show that scaling up language models leads to massive improvements on many downstream tasks. At the same time, Carlini et al. (2021) demonstrate that large-scale language models can perfectly memorise parts of their training data, suggesting that enhancing models with retrieval may lead to further improvements. However, significant leakage between train and test datasets (Lee et al., 2021; Lewis et al., 2021) makes comparing and evaluating large models trained on large datasets difficult, especially once retrieval capabilities over the training dataset are added.
Historically, information retrieval for text relies on inverted index matching such as TF-IDF and BM25 (Robertson and Zaragoza, 2009). Foundational work uses latent topic modelling approaches like LDA (Blei et al., 2003) to identify relevant neighbours (Wei and Croft, 2006). In machine translation, Zhang et al. (2018) and Gu et al. (2018) retrieve translation pairs based on the edit distance between source sentences and guide the translation output using the closest retrieved target sentences. The retrieval database may also be structured; for example, Ahn et al. (2016) use a symbolic knowledge graph to improve an RNN language model.
With the success of deep learning, retrieval systems have partly switched to dense learned representations based on a neural network’s activations. Continuous cache (Grave et al., 2017) adds probability mass to tokens for which previous activations resemble the current activation vector, extending the model’s context to the local history. kNN-LM (Khandelwal et al., 2020) applies this idea to transformers and extends the retrieval database to English Wikipedia, resulting in substantial improvements on Wikitext103 evaluation. Continuous cache and kNN-LM do not modify the underlying neural-network models, but interpolate at inference between the language model’s output and distributions computed from retrieved tokens. These methods can therefore be plugged into any model without additional training, although this limits the model’s ability to reason about the retrieved text. Spalm (Yogatama et al., 2021) addresses this limitation by adding an extra gating network to post-process the retrieved data; yet most of the network is unaffected by the retrieval during inference.
The retrieval representations may be trained directly instead of relying on a pre-trained model; retriever systems have been developed for this purpose, primarily for open-domain question answering. For example, Dpr (Karpukhin et al., 2020) trains two Bert models (for queries and keys respectively) using a contrastive loss to align the representations of a question and of its answers. Lee et al. (2019) use an inverse cloze task to find semantic representations of passages for retrieval. These works differ from continuous cache and kNN-LM in that they embed passages (or chunks) of text together, as opposed to each token individually. The retriever network is trained in isolation from the downstream task that uses the retrieval data. This potential issue is specifically addressed by Realm (Guu et al., 2020), which trains the retrieval system end-to-end to maximize the final training cross-entropy. This comes with the extra complexity of searching the database during training and periodically updating the embedding table, severely limiting the scale at which it can operate. RAG (Lewis et al., 2020) and FiD (Izacard and Grave, 2021) build upon Dpr to set the state of the art on question answering benchmarks by training encoder-decoder transformer models. More recently, Sachan et al. (2021) extend FiD by using an expectation-maximization algorithm to train the retriever end-to-end and achieve state-of-the-art results compared to similarly sized models.
In the open-domain dialogue setting, BlenderBot 2.0 (Komeili et al., 2021) learns to issue textual internet queries, outperforming dense retrieval methods when evaluated on a task measuring how close model responses are to those of humans. This involves collecting a dataset of human dialogues with associated search queries, which limits the scalability of this approach. Hashemi et al. (2020) introduce the Guided Transformer, a modified Transformer similar to Retro, for document retrieval and clarifying question selection. Although effective on question answering and other tasks with strong conditioning, none of these methods are designed to model arbitrary text sequences, in contrast with Retro.
Retro shares components with kNN-LM and Dpr in that it uses frozen retrieval representations. Retro models longer sequences than QA examples; this requires reasoning at a sub-sequence level and retrieving different documents for the different chunks of a sequence. Similar to FiD, Retro processes the retrieved neighbours separately in the encoder and assembles them in the chunked cross-attention. This differs from, e.g., Realm, which prepends retrieved documents to the prompt. Using chunks allows for repeated retrieval whilst generating a sequence, as opposed to retrieving only once based on the prompt alone. Furthermore, retrieval is done during the whole pre-training process in Retro, and is not simply plugged in to solve a certain downstream task. Finally, previous methods based on dense query vectors use small models and retrieval datasets with less than 3B tokens (English Wikipedia). Table 3 summarizes the differences between Retro and existing approaches.
3.2 Privacy, safety and fairness
Bender et al. (2021); Weidinger et al. (2021) highlight several dangers of large language models. Those stem from their ability to memorise training data, their high training cost, the static nature of their training data (Lazaridou et al., 2021), their tendency to amplify inherent biases in the training data, and their ability to generate toxic language (Gehman et al., 2020). In this section we inspect these dangers, focusing on how retrieval-augmented language models may exacerbate or mitigate them.
Large language models can perfectly memorise parts of their training data (Carlini et al., 2021). When coupled with large training datasets gathered from the web or other sources, this has clear privacy and safety implications. Retrieval models such as Retro that have access to the entire training dataset during inference exacerbate these privacy issues by being able to directly copy training data. However, retrieval systems offer a path towards mitigating these concerns via obliteration of the retrievable data at inference time. In addition, differential privacy training (Abadi et al., 2016) of retrieval models could guarantee that no private information is stored in the model weights, while individualisation on private data could be made by updating the retrieval database at inference time.
Due to their high training cost, regularly re-training large language models to incorporate new data, languages, and norms is prohibitively expensive. To keep retrieval models up-to-date, it may be sufficient to update the retrieval database, which is orders of magnitude cheaper than re-training a model from scratch. In addition to the benefits of updating models in terms of fairness and bias, simply training large language models has a significant energy cost (Strubell et al., 2019; Schwartz et al., 2020). Retrieval mechanisms offer a path to reducing the compute requirements needed to train and update language models that reach a certain performance.
Large language models are prone to generating toxic outputs, as shown in Gehman et al. (2020). Bender et al. (2021); Jo and Gebru (2020) advocate for the importance of better training data curation and documentation. Additionally, if portions of the training data are found to be eliciting biased or toxic outputs after training, retrieval allows for some correction, as the offending retrieval data can be retroactively filtered. However, it is also the case that without careful analysis and intervention, retrieval models may exacerbate biases that are present in the training data. Retrieval models can also add a further source of bias through the selection mechanism for retrieval documents. Further work in this area is required to better understand how retrieval affects the bias and toxicity of the model outputs.
Finally, samples from large models are difficult to interpret, making mitigating these issues all the more challenging (Belinkov et al., 2020; Jain and Wallace, 2019). Retrieval provides more insight into the outputs of a model, as one can directly visualise or modify the neighbours that are being used. The examples in Tables 6, 7, 20 and 21 illustrate how retrieval makes language models more factual and interpretable by providing more transparent outputs.
Results
We first report results on language modelling benchmarks. Second, we show how to Retrofit pre-trained Transformer language models into retrieval models with few additional FLOPs. Next, we report Retro results on question answering. Finally, we report evaluation metrics with leakage filtering, to better understand the source of the gains with retrieval.
We evaluate our models on C4 (Raffel et al., 2020), Wikitext103 (Merity et al., 2017), Curation Corpus (Curation, 2020), Lambada (Paperno et al., 2016) and the Pile (Gao et al., 2020). We also evaluate on a set of manually selected Wikipedia articles that were added or heavily edited in September 2021, months after our pre-training and retrieval dataset was collected (details are given in §A.2). We construct the dataset with articles from the “future” and manually remove new articles that strongly overlap documents in our training data. This guarantees that the evaluation documents are not leaked in our training data.
For C4, Wikitext103, the Pile, and our Wikipedia dataset we evaluate the language modelling performance on entire documents and measure the bits-per-byte (bpb). We favour bits-per-byte over loss as it is tokenizer agnostic. We evaluate with a sequence length of 2048 tokens but use a stride of 1024 within documents to mitigate boundary effects. On Curation Corpus we concatenate the article, the “TL;DR:” string, and the summary, but only evaluate the bpb on the summary. For Lambada we evaluate the accuracy on the last word, using greedy generation.
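Two details of this protocol can be made concrete: the conversion from token losses to bits-per-byte, and strided evaluation. The sketch below assumes per-token losses in nats and counts the UTF-8 bytes of the scored text only (for Curation Corpus, the summary); the window scheme is one plausible reading of "sequence length of 2048 tokens with a stride of 1024", in which every token is scored exactly once. Function names are illustrative.

```python
import math
from typing import List, Sequence, Tuple


def bits_per_byte(token_nll_nats: Sequence[float], scored_text: str) -> float:
    """Convert summed token negative log-likelihoods (nats) to bits per UTF-8 byte."""
    num_bytes = len(scored_text.encode("utf-8"))
    return sum(token_nll_nats) / (math.log(2) * num_bytes)


def strided_windows(n_tokens: int, seq_len: int = 2048, stride: int = 1024) -> List[Tuple[int, int, int]]:
    """Return (context_start, end, score_start) triples.

    Each window feeds tokens [context_start, end) to the model but only scores
    [score_start, end), so every token is scored once and, after the first window,
    always with at least seq_len - stride tokens of preceding context.
    """
    windows = []
    start, scored_until = 0, 0
    while scored_until < n_tokens:
        end = min(start + seq_len, n_tokens)
        windows.append((start, end, scored_until))
        scored_until = end
        start += stride
    return windows


print(strided_windows(4096))
# [(0, 2048, 0), (1024, 3072, 2048), (2048, 4096, 3072)]
```

Because the denominator counts bytes rather than tokens, two models with different tokenizers can be compared directly, which is the motivation for the metric given above.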
In Fig. 1 (left) and Fig. 3 we show the language modelling performance as we scale models from 150 million to 7 billion (non-embedding) parameters. We see that on all datasets, Retro outperforms the baseline at all model sizes. Furthermore, we observe that improvements do not diminish as we scale the models. The performance is dataset dependent, with the largest gains on Wikitext103 and C4. Wikitext103 documents resemble Wikipedia articles and other web pages in our retrieval data, even if they are not exact copies (§4.4); we thus obtain dramatic improvements on Wikitext103, as our retrieval model is able to directly exploit these overlaps. The smallest gains are for Curation Corpus, where Retro only slightly outperforms the baseline. This is expected, as Curation Corpus summaries are designed to only contain information from the source article and are not included in our retrieval database. On our “future” Wikipedia September 2021 dataset, we also observe consistent gains for all model sizes.
Fig. 1 (middle) shows how scaling the retrieval database at evaluation improves the language modelling performance. We observe dramatic gains as the retrieval data is increased from Wikipedia (4 billion tokens) to all of MassiveText (1.7T tokens). Fig. 1 (right) shows how performance scales as we increase the number of retrieved chunks. Despite being trained with only 2 neighbours, all models improve consistently when the number of neighbours is increased from 1 to 10. Furthermore, we observe that larger models are able to better utilise more neighbours: the 172M model improves with up to 10 neighbours, whereas the 7B model improves with up to 40 neighbours.
We evaluate our 7B models on the Pile test sets (due to legal and ethical concerns relating to their use, we exclude the Enron Emails and the Youtube Subtitles datasets) and compare against the 178B parameter Jurassic-1 (Lieber et al., 2021) model and the 280B parameter Gopher (Rae et al., 2021) model. We do not compare against GPT-3 as it is outperformed by Jurassic-1 and Gopher on almost all subsets. Fig. 4 shows the relative improvements in bits-per-byte over our 7B transformer baseline for our 7.5B Retro model, Jurassic-1 and Gopher. Jurassic-1 outperforms the baseline on all datasets except for books, likely due to the inclusion of books in our training data. Gopher and Retro outperform the baseline on all test sets. Overall, Retro 7.5B outperforms Jurassic-1 and Gopher on a majority of the test sets. On the dm_mathematics and ubuntu_irc subsets, our Retro model does not outperform our 7B baseline and underperforms Jurassic-1. We hypothesise that the retrieved neighbours on these datasets are not helpful, due to a combination of what is in our retrieval dataset and the efficacy of the nearest-neighbour search.
To validate our approach in a controlled setting, we compare our method with (Khandelwal et al., 2020) on the Wikitext103 dataset in Table 4. We train a baseline transformer on the training set of Wikitext103. This transformer has 24 layers, 1024 hidden units, 16 heads and a key size of 64, as in Baevski and Auli (2019). Our baseline does not have adaptive input, and our tokenizer has an open vocabulary, unlike Baevski and Auli (2019), which makes our baseline perplexities a bit higher. The full experiment details and hyperparameters are given in §C.2 and Table 11.
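We re-implement kNN-LM with our tokenizer and baseline transformer to produce embeddings of size 1024 for every token in Wikitext103. kNN-LM has probabilities $p_{\mathrm{kNN\text{-}LM}} = \lambda\, p_{\mathrm{kNN}} + (1 - \lambda)\, p_{\mathrm{Lm}}$ with $p_{\mathrm{kNN}}(n_k) \propto \exp(-\alpha d_k)$. We tune $\lambda$ and $\alpha$ on the validation set (Fig. 7) and report performance for these hyperparameters on both the validation and test set.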
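The kNN-LM interpolation above can be sketched in a few lines. The sketch assumes the retrieved neighbours are given as (distance, stored next token) pairs and represents distributions as dictionaries; the values of $\lambda$ and $\alpha$ used below are arbitrary and are not the tuned values.

```python
import math
from collections import defaultdict
from typing import Dict, Hashable, Sequence


def knn_distribution(distances: Sequence[float], next_tokens: Sequence[Hashable], alpha: float) -> Dict[Hashable, float]:
    """p_kNN: each retrieved neighbour votes for its stored next token with weight exp(-alpha * d)."""
    d_min = min(distances)
    # Shifting by d_min cancels in the normalisation and avoids underflow.
    weights = [math.exp(-alpha * (d - d_min)) for d in distances]
    total = sum(weights)
    probs: Dict[Hashable, float] = defaultdict(float)
    for w, tok in zip(weights, next_tokens):
        probs[tok] += w / total
    return dict(probs)


def interpolate(p_lm: Dict[Hashable, float], p_knn: Dict[Hashable, float], lam: float) -> Dict[Hashable, float]:
    """lam * p_kNN + (1 - lam) * p_LM over the union of both supports."""
    vocab = set(p_lm) | set(p_knn)
    return {y: lam * p_knn.get(y, 0.0) + (1.0 - lam) * p_lm.get(y, 0.0) for y in vocab}


p_knn = knn_distribution([1.0, 1.0, 3.0], ["cat", "cat", "dog"], alpha=0.5)
p = interpolate({"cat": 0.2, "dog": 0.8}, p_knn, lam=0.25)
```

Because the base model is untouched and only its output distribution is mixed, this is the "plug-in at inference" behaviour that the Related Work section contrasts with Retro.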
We fine-tune our baseline transformer into a Retro model (Fig. 7), using the Wikitext103 training data and retrieving from Wikipedia with 2 neighbours. We only train the new weights, as explained in §4.2, and share the embedding weights between the encoder and the main pathway. This is necessary for Wikitext103, which is quite small: training Retro from scratch in this setting leads to over-fitting.
We evaluate the fine-tuned Retro model with different retrieval sets. We use 10 neighbours at evaluation for both Retro and kNN-LM. When retrieving from Wikipedia, we obtain results comparable to our kNN-LM implementation. Furthermore, scaling the retrieval database to MassiveText yields dramatic improvements, though this is partly due to leakage (see §4.4). For reproducibility, we also include results when retrieving from C4, which are close to the previous state of the art and comparable to using 10% of MassiveText.
It is worth noting that kNN-LM requires 1024 floats for every token in the retrieval dataset, totalling 15 terabytes (TB) for the 4 billion tokens in Wikipedia. kNN-LM and other token-level retrieval approaches therefore don’t scale to retrieval databases with trillions of tokens such as MassiveText. In comparison, Retro only requires 215 GB to index our Wikipedia dataset, and 93 TB for MassiveText. Inspecting the number of retrieval database entries in Table 4 makes it clear why retrieving at the chunk level is necessary when scaling to datasets with trillions of tokens.
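A back-of-the-envelope check of the token-level figure, assuming 4-byte (float32) keys: 4 billion tokens times 1024 floats gives about 16.4 TB in decimal units, or about 14.9 TiB, so the reported 15 terabytes appears consistent with this count expressed in binary units. The second part shows how chunk-level indexing with 64-token chunks divides the number of database entries by 64; it says nothing about the per-entry size of the Retro index.

```python
def token_level_index_bytes(num_tokens: int, dim: int = 1024, bytes_per_float: int = 4) -> int:
    """Storage for one dim-sized float key per token (float32 assumed)."""
    return num_tokens * dim * bytes_per_float


wikipedia_tokens = 4 * 10**9
size = token_level_index_bytes(wikipedia_tokens)
print(f"{size / 10**12:.1f} TB, {size / 2**40:.1f} TiB")  # 16.4 TB, 14.9 TiB

chunk_len = 64
print(wikipedia_tokens // chunk_len)  # 62500000 chunk-level entries instead of 4e9 token-level ones
```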
4.2 Retro-fitting baseline models
In Fig. 5, we extend baseline models into Retro models by freezing the pre-trained weights and training only the chunked cross-attention and neighbour encoder parameters (less than 10% of weights for the 7B model). This offers an efficient alternative path to enhance transformers with retrieval, requiring only 6 million sequences (3% of the pre-training sequences that we used). Additionally, by only training the new weights we ensure that, when evaluated without retrieval, the original model performance is exactly maintained. Retrofitted models quickly surpass the performance of baseline models and even achieve performance close to that of Retro models trained from scratch. The experiment hyperparameters are given in §C.3.
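The key property of Retro-fitting, that pre-trained weights never change, amounts to masking the update. The sketch below uses plain SGD on a dictionary of parameters and a hypothetical naming convention (`cca/`, `neighbour_encoder/`) to identify the new retrieval parameters; a JAX/Haiku codebase would achieve the same with a gradient or optimizer mask.

```python
from typing import Dict, List

Params = Dict[str, List[float]]

NEW_PREFIXES = ("cca/", "neighbour_encoder/")  # hypothetical naming convention


def is_new(name: str) -> bool:
    return name.startswith(NEW_PREFIXES)


def masked_sgd_step(params: Params, grads: Params, lr: float) -> Params:
    """Update only the newly added retrieval parameters; pre-trained ones are copied unchanged."""
    return {
        name: [p - lr * g for p, g in zip(values, grads[name])] if is_new(name) else list(values)
        for name, values in params.items()
    }


def trainable_fraction(params: Params) -> float:
    total = sum(len(v) for v in params.values())
    return sum(len(v) for n, v in params.items() if is_new(n)) / total


params = {"decoder/ffw": [1.0, 2.0], "cca/query": [0.5], "neighbour_encoder/w": [1.0]}
grads = {name: [1.0] * len(v) for name, v in params.items()}
updated = masked_sgd_step(params, grads, lr=0.1)
# updated["decoder/ffw"] is still [1.0, 2.0]: the frozen baseline is untouched
```

Since the frozen weights are bit-for-bit identical after training, disabling retrieval recovers the original model exactly, which is the guarantee stated above.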
4.3 Question answering
We fine-tune our retrieval models on the Natural Questions (Kwiatkowski et al., 2019) dataset to demonstrate that our retrieval pathway can be used to inject information from arbitrary data sources. We use the version provided by Izacard and Grave (2021) (https://github.com/facebookresearch/FiD), which is augmented with the retrieved passages from Dpr (Karpukhin et al., 2020). We fine-tune all the weights of our 7.5B pre-trained Retro model for 25,000 steps using the top 20 retrieved passages. We format the data as “question: {question} \n answer: {answer}” and left pad the data such that “answer:” coincides with the end of the first chunk of 64 tokens and thus aligns with the first retrieving chunk. The model has access to the question via the previous tokens in the sequence, as well as to the top 20 DPR Wikipedia passages and their titles via the chunked cross-attention mechanism. The exact match scores are shown in Table 5 and the full fine-tuning details are given in §C.4. Our method is competitive with previous approaches such as Realm, RAG and Dpr, but underperforms the more recent FiD. In contrast with FiD, we find that increasing the number of neighbours past 20 does not improve Retro performance on this task. We hypothesise that the encoder-decoder structure of T5 (the base model in FiD) and the T5 pre-training objective lead to a model that relies more on the encoder output than Retro does, which is important in the QA setting. To compete with T5-finetuned models, future work should consider ways of forcing Retro to rely further on the retrieval encoder output when producing tokens.
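The left-padding step can be illustrated with a toy tokenizer. This sketch assumes a hypothetical `<pad>` token and uses whitespace splitting (which drops the newline) in place of the SentencePiece tokenizer; it only demonstrates the alignment rule that “answer:” must be the last token of the first $m$-token chunk, so that the first retrieval set is available for every answer token.

```python
from typing import Callable, List

PAD = "<pad>"  # hypothetical padding token


def format_and_left_pad(question: str, answer: str, m: int = 64,
                        tokenize: Callable[[str], List[str]] = str.split) -> List[str]:
    """Left-pad so that the token 'answer:' is the last token of the first m-token chunk.

    str.split is a toy whitespace tokenizer (it drops the newline); a real run would
    use the model's SentencePiece tokenizer instead.
    """
    prefix = tokenize(f"question: {question} \n answer:")
    if len(prefix) > m:
        raise ValueError("question does not fit in the first chunk")
    return [PAD] * (m - len(prefix)) + prefix + tokenize(answer)


tokens = format_and_left_pad("who wrote hamlet", "william shakespeare", m=8)
print(tokens.index("answer:"))  # 7, i.e. the last position of the first chunk
```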
4.4 Relating retrieval performance to dataset leakage
We report the filtered eval losses as detailed in §2.6 on C4, Curation Corpus and Wikitext103 in Fig. 6. On C4 and Wikitext103, for which there is leakage into the training set, the slope is negative for both baseline models and Retro models. Retro models exploit leakage more strongly than baseline models, as indicated by the more negative slope. This is due to their explicit ability to copy-paste existing training chunks to predict leaked evaluation chunks (see a qualitative example of this model behavior on a Wikitext103 article in Table 19). On Curation Corpus, retrieval provides a constant offset, which is expected as there is by design no leakage between Curation Corpus and the training dataset.
On the other hand, Retro outperforms baseline models at all leakage levels, down to the lowest level considered. At this level, the loss is computed only on chunks that share just a few contiguous tokens with the closest matching chunk in the training dataset; this is a reasonable level of overlap at which we consider that there is no local leakage. Retrieval thus improves predictions both on chunks that are syntactically similar to chunks in the training set and on chunks that are syntactically different from all training chunks. This points toward a non-trivial Retro capacity for generalizing based on both model parameters and the retrieval database. Similar results are found on the Pile dataset (see Fig. 12, §F.3).
4.5 Using Retro for sampling
We show examples of samples obtained using the 7.5B Retro model in Table 6, Table 7 and Appendix E. For each chunk $C_u$ (the first one being the prompt), we juxtapose sampled chunks $C_u$ with retrieved neighbours $\mathrm{Ret}(C_u)$. To give an indication of local overlap, we colour each sampled token in chunk $C_u$ based on the length of the longest common prefix (LCP) found in the retrieved chunks $\mathrm{Ret}(C_{u-1})$. Similarly, we colour the retrieved chunks based on the LCP in the sampled chunk. For the sample in Table 6, for which we chose the prompt, we observe that the retrieved chunks influence the sample, as there are overlaps between the sampled tokens and neighbour tokens. Overall, retrieval reduces hallucinations (in line with the findings of Shuster et al. (2021)) and makes the model more knowledgeable, when compared with samples produced with retrieval disabled. In the sample in Table 7, the model recognises that the prompt is the beginning of the first scene of Hamlet and leverages retrieval data to continue it with only a few mistakes. We provide further examples in Appendix E, including examples from the evaluation sets, as well as the detailed procedure used for colouring the tables.
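Appendix E gives the exact colouring procedure; as an illustration only, here is one plausible way to compute a per-token LCP length. It assumes that a sampled token's score is the longest match that starts at that token and appears anywhere in the retrieved chunks, computed by brute force; `lcp_lengths` is an illustrative name.

```python
from typing import List, Sequence


def common_prefix_len(a: Sequence[int], b: Sequence[int]) -> int:
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def lcp_lengths(sampled: Sequence[int], neighbours: Sequence[Sequence[int]]) -> List[int]:
    """For each position i, the longest match of sampled[i:] starting anywhere in any neighbour."""
    return [
        max((common_prefix_len(sampled[i:], nb[j:]) for nb in neighbours for j in range(len(nb))),
            default=0)
        for i in range(len(sampled))
    ]


print(lcp_lengths([1, 2, 3, 9], [[5, 1, 2, 3]]))  # [3, 2, 1, 0]
```

Higher values would map to stronger colours, highlighting spans that the sample copies from its neighbours.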
Conclusion
We present Retrieval-Enhanced Transformers (Retro), a method for modelling arbitrary text sequences whilst retrieving from databases with trillions of tokens, scaling the data available to models by an order of magnitude compared to what is typically consumed during training. The gains of Retro models do not diminish for models with up to at least 7B parameters, and correspond to non-retrieval models with 10× more parameters on certain datasets. On Wikitext103 and the Pile, Retro outperforms previous models trained on large-scale datasets. We also show that Retro is competitive on retrieval-intensive downstream tasks such as question answering.
Retro models are flexible: they can be used without retrieval at evaluation time and still achieve performance comparable to baseline models. Conversely, baseline models can be rapidly fine-tuned into Retro models to obtain nearly the same performance as if trained from scratch. Careful analysis shows that only a modest fraction of the gains obtained by Retro are due to test set leakage. In general, we urge caution about such leakage in large-scale language datasets and suggest further work to better understand the role of test set leakage in the performance of large-scale language models.
Overall, our work demonstrates at an unprecedented scale that semi-parametric approaches can provide an orthogonal, more efficient approach than raw parameter scaling as we seek to build more powerful language models.
Acknowledgements
We would like to thank Nikolai Grigorev, Marc’aurelio Ranzato, Cyprien de Masson d’Autume, Po-Sen Huang, Johannes Welbl, Lisa Anne Hendricks, Ethan Perez, Jeff Stanway, Eric Noland, Gregory Wayne, John Jumper, Julian Schrittwieser, Lorrayne Bennett, Devang Agrawal, Dani Yogatama, Susannah Young, Nando de Freitas, Demis Hassabis, and Koray Kavukcuoglu for their help, advice and reviews. Additionally, we would like to thank Zonglin Li, David Simcha, and the ScaNN developers for their help.
References
Appendix A Datasets
We provide a full description of MassiveText and of our extract of recent Wikipedia articles.
The full breakdown of MassiveText by source and language is given in Table 8. For a full description and analysis of MassiveText, see Rae et al. (2021).
A.2 Wikipedia September 2021
We create an evaluation dataset consisting of 23 Wikipedia articles that were added or heavily edited in September 2021, after we collected our training dataset. In addition, we filter out articles that rely too heavily on templated content, using the method detailed in §2.6 to identify articles with chunks that have a high overlap with their neighbours. Fig. 10 shows that little overlap remains between our test dataset and the retrieved neighbours from the training dataset. The full list of included articles is given in Table 9.
We first parse articles using mwparserfromhell (https://github.com/earwig/mwparserfromhell). We then remove sections with the following titles: “references”, “external links”, “sources”, “further reading”, “see also”, “citations”, and “note”. In the remaining sections, we remove Wikilinks and the following templates: “reflist”, “notelist”, “notelist-ua”, “notelist-lr”, “notelist-ur”, and “notelist-lg”. We also exclude objects with the “ref” or “table” tag and clean the remaining text with the strip_code function. Finally, we concatenate the title and all the sections, using \n\n to delimit them.
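The final assembly step can be sketched in plain Python. The sketch assumes that the mwparserfromhell stage (template and tag removal, then strip_code) has already produced a (heading, cleaned text) pair for each section, with heading None for the lead section. The function name and the case-insensitive title matching are our assumptions.

```python
from typing import List, Optional, Tuple

# Section titles dropped during preprocessing (matched case-insensitively: our assumption).
EXCLUDED_SECTION_TITLES = {
    "references", "external links", "sources", "further reading",
    "see also", "citations", "note",
}


def assemble_article(title: str, sections: List[Tuple[Optional[str], str]]) -> str:
    """Join the title and the kept section texts, separated by blank lines.

    `sections` holds (heading, text) pairs whose text has already been cleaned;
    the lead section has heading None and is always kept.
    """
    kept = [
        text.strip()
        for heading, text in sections
        if heading is None or heading.strip().lower() not in EXCLUDED_SECTION_TITLES
    ]
    return "\n\n".join([title] + [text for text in kept if text])
```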
Appendix B Details on the retrieval architecture
We give details on the Retro architecture, and on the fine-tuning procedure we use for Retrofitting existing language models.
B.1.2 Relative positional encoding in the chunked cross-attention layer
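The Ca operator uses relative positional logits, which are computed from a specific relative distance separating data tokens from retrieval tokens. Indeed, we expect any retrieval neighbour $\mathrm{Ret}(C_u)^j$ and the chunk $C_u$ to be relatively well aligned, and assume that they start at the same position. Therefore, when computing $\mathrm{Ca}(H_u^+, E_u)$, we set the distance between the data token $i \in [1, m]$ of the shifted chunk $C_u^+$ and the retrieval token $i' \in [1, 2m]$ of $\mathrm{Ret}(C_u)^j$ to be
$$d(i, i') \triangleq i - i' + m - 1.$$
When computing the encoder cross-attentions $\mathrm{Ca}(\mathrm{Ret}(C_u)^j, H_u)$, we set the distance between the retrieval token $i' \in [1, 2m]$ and the data token $i \in [1, m]$ to be
$$d_{\mathrm{enc}}(i', i) \triangleq i' - i.$$
Positional logits are obtained as a linear transform of a cosine vector computed from $d(i, i')$ (respectively $d_{\mathrm{enc}}(i', i)$), and are added to content logits, as in a regular self-attention block.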
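Under the alignment assumption above, the distances and positional logits can be sketched as follows. The sketch assumes a chunk length m, neighbours of length 2m (a neighbour plus its continuation) and Transformer-style sine and cosine frequencies. The feature construction and the function names are illustrative assumptions, not the exact Retro parameterisation. Data token i of the shifted chunk sits at position i + m - 1 counted from the start of the unshifted chunk, which is why the decoder-side distance is i - k + m - 1 while the encoder-side distance is simply k - i.

```python
import math
from typing import List


def decoder_cca_distances(m: int) -> List[List[int]]:
    """d[i-1][k-1] = i - k + m - 1 for data token i in [1, m] of the shifted chunk
    and retrieval token k in [1, 2m] of a neighbour aligned with the unshifted chunk."""
    return [[i - k + m - 1 for k in range(1, 2 * m + 1)] for i in range(1, m + 1)]


def encoder_ca_distances(m: int) -> List[List[int]]:
    """d_enc[k-1][i-1] = k - i for retrieval token k in [1, 2m] and data token i in [1, m]."""
    return [[k - i for i in range(1, m + 1)] for k in range(1, 2 * m + 1)]


def sinusoidal_features(distance: int, dim: int, max_wavelength: float = 10_000.0) -> List[float]:
    """Cosine and sine features of a signed distance (Transformer-style frequencies)."""
    half = dim // 2
    freqs = [max_wavelength ** (-j / half) for j in range(half)]
    return [math.cos(distance * f) for f in freqs] + [math.sin(distance * f) for f in freqs]


def positional_logit(distance: int, weights: List[float]) -> float:
    """Linear map of the features to one scalar logit, to be added to the content logit."""
    feats = sinusoidal_features(distance, len(weights))
    return sum(w * x for w, x in zip(weights, feats))
```

The sine half is included because cosines alone are even in the distance and could not distinguish a token that lies before a retrieval token from one that lies after it.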
B.1.3 Chunked cross-attention implementation
Our implementation of the Cca operator, shown in Listing 1, is based on a vectorized application of a cross-attention layer. For simplicity, we omit the multi-head attention logic and use the simplest Q,K,V attention. We also omit the computation of the relative positional logits described above.
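The following is a minimal pure-Python sketch of chunked cross-attention under the same simplifications (single head, no projections, no positional logits). It is not the authors' listing. The names, the list-based shapes and the zero outputs for the first m - 1 positions are our assumptions.

```python
import math
from typing import List

Vec = List[float]


def attention(queries: List[Vec], keys_values: List[Vec]) -> List[Vec]:
    """Single-head Q,K,V attention with no projections (K = V = keys_values)."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    outputs = []
    for q in queries:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys_values]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]  # numerically stable softmax
        total = sum(exps)
        outputs.append([sum(e / total * v[c] for e, v in zip(exps, keys_values))
                        for c in range(len(q))])
    return outputs


def chunked_cross_attention(h: List[Vec], neighbours: List[List[List[Vec]]], m: int) -> List[Vec]:
    """h: n decoder vectors; neighbours[u]: k encoded neighbours (each r vectors) of chunk u.

    Tokens are shifted by m - 1: the last token of chunk u and the first m - 1 tokens of
    chunk u + 1 attend to chunk u's neighbours. The first m - 1 positions attend to nothing
    and are left as zeros.
    """
    n, d = len(h), len(h[0])
    out = [[0.0] * d for _ in range(n)]
    for u in range(n // m):
        start = u * m + m - 1
        positions = range(start, min(start + m, n))
        kv = [tok for nb in neighbours[u] for tok in nb]  # flatten k neighbours into one memory
        for p, vec in zip(positions, attention([h[p] for p in positions], kv)):
            out[p] = vec
    return out
```

The shift keeps the operator causal. Chunk u's neighbours are retrieved with chunk u, so only tokens from the last token of chunk u onwards may attend to them.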
B.1.4 Optional sharing of embedding matrices
We use disjoint embeddings for the encoder and decoder by default, which allows us to use a different dimensionality for the encoder (typically kept at ) and for the decoder (which we scale up to ). It is possible to share the embeddings, with little difference in training, as we show in the ablation section.
B.2 Baseline to Retro model fine-tuning
As shown in Fig. 5, we found that we were able to take a pre-trained baseline transformer and add Retro through fine-tuning. In all cases, we froze all weights from pre-training and freshly initialised the retrieval encoder and cross-attention weights, and the cross-attention is added every third layer starting at layer six. The learning rate for the three smaller models was set to and half that for the largest model. We experimented with allowing the entire model to resume training during fine-tuning, but consistently found that the best approach was to freeze the pre-trained model. This kept the retrieval-off performance frozen, whereas when all weights were tuned, the retrieval-off performance would degrade.
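The layer placement and freezing rule can be written as a small sketch. It assumes 1-indexed layers and hypothetical parameter-name prefixes for the new retrieval weights.

```python
from typing import Dict, Iterable, List, Tuple


def cca_layers(num_layers: int, first: int = 6, every: int = 3) -> List[int]:
    """1-indexed decoder layers that receive a chunked cross-attention block."""
    return list(range(first, num_layers + 1, every))


def trainable_mask(param_names: Iterable[str],
                   new_prefixes: Tuple[str, ...] = ("retrieval_encoder/", "cca/")) -> Dict[str, bool]:
    """Only freshly initialised retrieval weights are trained; pre-trained weights stay frozen."""
    return {name: name.startswith(new_prefixes) for name in param_names}
```

Freezing every pre-trained parameter leaves the retrieval-off path unchanged, which is the property described above.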
Appendix C Training details and hyperparameters
We provide the hyperparameters used in the various experiments of §4.
In Table 10, we show the hyperparameters of the different models we train. In all cases, we train for 419,430,400,000 training tokens. The three smaller models are trained with a batch size of 256 and the largest model is trained with a batch size of 1024. The minimum learning rate is set to 0.1 times the maximum learning rate, which is shown in Table 10. The learning rate is decayed with a cosine schedule whose cycle length matches the total number of training tokens. All models are trained using AdamW (Loshchilov and Hutter, 2019) with a weight decay parameter of 0.1. The learning rate linearly increases from to the maximum learning rate over the first 750 steps of training.
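If each training sequence is 2048 tokens long (our assumption here), the token budget translates into 419,430,400,000 / (256 × 2048) = 800,000 steps for the three smaller models and 200,000 steps at batch size 1024. The schedule can be sketched as below. The warm-up starting value is a placeholder parameter, and measuring the cosine decay from the end of warm-up is our assumption.

```python
import math


def total_steps(num_tokens: int, batch_size: int, seq_len: int = 2048) -> int:
    """Optimiser steps needed to consume num_tokens with batch_size sequences of seq_len tokens."""
    return num_tokens // (batch_size * seq_len)


def learning_rate(step: int, max_lr: float, num_steps: int,
                  warmup_steps: int = 750, init_lr: float = 0.0, min_ratio: float = 0.1) -> float:
    """Linear warm-up to max_lr, then cosine decay to min_ratio * max_lr at num_steps."""
    if step < warmup_steps:
        return init_lr + (max_lr - init_lr) * step / warmup_steps
    min_lr = min_ratio * max_lr
    progress = min(1.0, (step - warmup_steps) / (num_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```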
All models use ZeRO to shard the optimiser state (Rajbhandari et al., 2020). Additional infrastructure details can be found in Rae et al. (2021).
C.2 Wikitext103 comparison
We provide more details on our Wikitext103 results presented in §4.1 and Table 4. We train a baseline transformer on the Wikitext103 training set with the hyperparameters presented in Table 11. The learning rate ramps linearly from to in the first 4,000 steps, then decays to at 100,000 steps using a cosine schedule. The baseline checkpoint at step 35,000 has the lowest perplexity on Wikitext103 valid, of , for an overlapping proportion of 75% (a sliding-window evaluation that only uses probabilities for tokens that have at least 75% of the sequence length as context, when available). We use this checkpoint for all our baseline and numbers reported in Table 4, except that Table 4 reports for an overlapping proportion of 87.5%, which slightly lowers the perplexity of our baseline to 21.53 on Wikitext103 valid.
We also use the 35,000-step baseline checkpoint as initialization for a Retrofit, which otherwise uses the same optimiser and schedule hyperparameters but only trains the new retrieval weights, as explained in §4.2. Our best Retrofit checkpoint has a Wikitext103 valid perplexity , when retrieving from Wikipedia. We use this Retro checkpoint in Table 4 for all other retrieval sets. The evaluation curves for our baseline and Retrofit are shown in Fig. 7 (left). In this particular case, because Wikitext103 is quite small, training a Retro model from scratch led to weaker results than the baseline, at least when retrieving from Wikipedia, as we could not find an effective way to mitigate the increased over-fitting due to the additional weights of Retro.
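A sliding-window evaluation with a given overlapping proportion can be sketched as follows. Consecutive windows advance by window × (1 - overlap) tokens and score only the tokens the previous window did not, so every scored token after the first window has at least overlap × window tokens of context. This is a common scheme that we assume matches the description; the function name and return format are ours.

```python
from typing import List, Tuple


def sliding_windows(num_tokens: int, window: int, overlap: float) -> List[Tuple[int, int, int]]:
    """Return (window_start, score_start, score_end) triples.

    Every token is scored exactly once. Apart from tokens in the first window, each
    scored token sees at least window - stride tokens of context inside its window.
    """
    context = round(window * overlap)
    stride = max(1, window - context)  # guard against a zero stride when overlap is ~1
    end = min(window, num_tokens)
    spans = [(0, 0, end)]  # the first window has no earlier context to wait for
    while end < num_tokens:
        new_end = min(end + stride, num_tokens)
        spans.append((max(0, new_end - window), end, new_end))
        end = new_end
    return spans
```

With 2048-token windows (an assumption here), an overlapping proportion of 75% gives a stride of 512 tokens and 87.5% gives a stride of 256. The higher proportion gives more context per scored token at roughly twice the evaluation cost.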
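We also re-implement kNN-LM using the same tokenizer and dataset that we use for our baseline and Retrofitting experiments. kNN-LM has probabilities
$$p_{\text{kNN-LM}}(y \mid x) = \lambda\, p_{\text{kNN}}(y \mid x) + (1 - \lambda)\, p_{\text{LM}}(y \mid x),$$
with $p_{\text{kNN}}(n_k) \propto \exp(-\alpha d_k)$, where $d_k$ is the distance of the $k$-th retrieved neighbour $n_k$. To tune $\lambda$ and $\alpha$, we begin with $\alpha$ set to the inverse of the standard deviation of the norm of the embeddings that we use as keys and queries for kNN-LM. We find the best $\lambda$. We then find the best $\alpha$ for that value of $\lambda$. Fig. 7 center and right respectively show the perplexity of kNN-LM as a function of $\lambda$ and $\alpha$.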
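The interpolation can be sketched in plain Python, assuming that the re-implemented model follows the kNN-LM form p(y | x) = λ p_kNN(y | x) + (1 - λ) p_LM(y | x) of Khandelwal et al. (2020). The sketch further assumes that each retrieved neighbour is given as a (distance, next token) pair weighted by exp(-α·distance). Names and data layout are illustrative.

```python
import math
from typing import Dict, Hashable, List, Tuple


def knn_distribution(neighbours: List[Tuple[float, Hashable]], alpha: float) -> Dict[Hashable, float]:
    """p_kNN over next tokens: each (distance, next_token) pair gets weight exp(-alpha * distance)."""
    d_min = min(dist for dist, _ in neighbours)
    weights: Dict[Hashable, float] = {}
    for dist, token in neighbours:
        # Subtracting d_min rescales all weights equally, so it only avoids underflow.
        weights[token] = weights.get(token, 0.0) + math.exp(-alpha * (dist - d_min))
    total = sum(weights.values())
    return {token: w / total for token, w in weights.items()}


def knn_lm(p_lm: Dict[Hashable, float], neighbours: List[Tuple[float, Hashable]],
           lam: float, alpha: float) -> Dict[Hashable, float]:
    """Interpolate lam * p_kNN + (1 - lam) * p_LM over the union of both supports."""
    p_knn = knn_distribution(neighbours, alpha)
    vocab = set(p_lm) | set(p_knn)
    return {t: lam * p_knn.get(t, 0.0) + (1.0 - lam) * p_lm.get(t, 0.0) for t in vocab}
```

Tuning then amounts to a one-dimensional search over `lam` at a fixed `alpha`, followed by a search over `alpha` at the best `lam`.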
C.3 Retrofitting baseline models experiments
In Table 12, we give the hyperparameters used for Retrofitting the models on MassiveText.
C.4 Question answering experiments
We fine-tune our 7.5B Retro model for 25,000 steps, using a batch size of 128, a learning rate cosine scheduled from to , with a linear ramp of 750 steps. We use dropout in the decoder only, as it performs better than using dropout in both the encoder and the decoder. Each neighbour is formatted as title: {title}, source: {source}. We use the top 20 neighbours from Dpr when training and evaluating.
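The neighbour formatting can be sketched as a one-liner. It assumes that the Dpr passages arrive already ranked, as dictionaries with hypothetical `title` and `source` keys.

```python
from typing import Dict, List


def format_neighbours(passages: List[Dict[str, str]], k: int = 20) -> List[str]:
    """Render the top-k retrieved passages in the 'title: ..., source: ...' format."""
    return [f"title: {p['title']}, source: {p['source']}" for p in passages[:k]]
```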
Appendix D Model ablations
We validate important design choices by evaluating what happens when we do not include them. We use the 247M parameter model for all experiments, and we train on a compressed 157 billion token schedule for all ablation experiments. We describe results relative to the default settings presented in the main text and recalled here. We report the C4 evaluation loss at the end of the training process, and also compare how the evaluation loss decreases with training time, measured relative to the baseline training time. Results are reported in Fig. 8 and Table 13.
Using relative encodings in cross-attention, as described in §B.1.2, provides a pure improvement both in the number of steps to reach a given performance and computational efficiency.
Conditioning the encoder on the previous chunk’s intermediate embeddings, as described in §B.1.1, provides a pure improvement both in terms of number of steps and computational efficiency.
Sharing embeddings across the encoder and the decoder does not affect performance. This motivates our use of separate embeddings, as they allow a narrower encoder than decoder as we scale up the decoder size.
Retro models are trained by attending, for a given chunk, to both the neighbours of the preceding chunk and their continuation in time. We measure how training and evaluating Retro models on neighbours only and on their continuation only affects performance. Overall, attending to neighbours only provides of the performance improvement due to retrieval in Retro, while attending to the future of the neighbours gives of the performance. Attending to both neighbours and their continuation is the most efficient choice both in terms of final performance and training efficiency.
All models in the text use a relatively small Retro encoder. We experimented with a deeper encoder and found that it resulted in a tiny decrease in loss, 0.15%, at the cost of a larger training time (). Overall, using a shallow encoder is the best choice in terms of training efficiency.
We measure the effect of training on a single retrieved neighbour, as well as training on 4 neighbours (Retro uses 2 neighbours in training). Training on a single neighbour results in a large decrease in performance, while training on 4 neighbours does not give a substantial performance improvement at the end of training, but induces a large computational overhead. Overall, we find that using 2 neighbours is the best choice in terms of training efficiency. Furthermore, evaluation can be done with additional neighbours.
We measure how the frequency of cross-attention in the decoder affects performance. Overall, attending only once at the top or the bottom layer is a bad choice, while attending once at a mid-depth layer is relatively sound. We choose to have cross-attention every 3 layers, as this provides a good trade-off between performance and run-time.
Appendix E Qualitative experiments
We illustrate the usage of Retro models by looking at the perplexity of evaluation samples and by producing samples autoregressively.
To build an intuition of what kind of information is leveraged by Retro models, we suggest having a closer look at a few evaluation documents and the corresponding retrieved data in Tables 16, 17, 18 and 19. In these tables, the 4 rows correspond to the first 4 chunks of the documents. The left-most column shows the chunk from the document being evaluated, where each token is coloured by the negative cross-entropy loss difference (the loss of Retro with retrieval disabled minus its loss with retrieval); a positive value, coloured in yellow, indicates that Retro performs better when it has access to neighbour data. The second column also shows the evaluated chunk, but each token is coloured by the length of the longest common prefix (LCP) with the preceding neighbours, i.e. the largest integer ℓ such that the ℓ tokens ending at that token also appear contiguously in the neighbours retrieved for the preceding chunk. Conversely, columns three and four show the first two neighbours and their continuation, respectively, coloured by LCP with the subsequent chunk. LCP colouring helps to visually identify where the evaluated document overlaps the retrieved data. Note that the first chunk in the second column is not coloured, as it does not have any preceding neighbours to compute LCP with. Similarly, we do not show the neighbours of the fourth chunk, as these are not used to condition any of the first four chunks.
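The per-token LCP value can be computed with a longest-common-suffix dynamic programme. The sketch below rests on two assumptions: chunks and neighbours are lists of token ids, and runs are restricted to tokens of the evaluated chunk itself. The function name is ours.

```python
from typing import List, Sequence


def lcp_lengths(chunk: Sequence[int], neighbours: List[Sequence[int]]) -> List[int]:
    """For each token i of chunk, the largest l such that chunk[i-l+1 : i+1] occurs
    contiguously in some neighbour (0 if token i occurs in none)."""
    best = [0] * len(chunk)
    for nb in neighbours:
        # prev[j] = length of the common run ending at chunk[i-1] and nb[j-1].
        prev = [0] * (len(nb) + 1)
        for i, token in enumerate(chunk):
            cur = [0] * (len(nb) + 1)
            for j, nb_token in enumerate(nb):
                if token == nb_token:
                    cur[j + 1] = prev[j] + 1
            best[i] = max(best[i], max(cur))
            prev = cur
    return best
```

Long runs of increasing values mark spans copied from a neighbour, while isolated 1s usually correspond to common tokens that happen to appear somewhere in the retrieved text.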
Our qualitative analysis exhibits two major behaviors.
Firstly, we observe that specific facts in a chunk can sometimes be extracted from the neighbours of the preceding chunk, and that this can correspond to a significant reduction in loss from the Retro model for the corresponding tokens. Some examples of such behavior include the journal name Publishers Weekly in Table 16, the football team name Tyrone in Table 17 or the event dates 25 August to 6 September 2020 in Table 18. In these three examples, the evaluated data consists of recent Wikipedia articles written in September 2021, after we built our retrieval dataset (see §A.2). Yet, relevant information to predict this new data was available in the pre-existing retrieval data, and the Retro model seems to be able to correctly leverage it.
On the other hand, we also observe that some of the evaluation data can partially leak into our training and retrieval data, despite the use of deduplication. Retro can dramatically exploit such leakage. Table 19 illustrates this behavior: some evaluated chunks largely overlap the neighbours retrieved for their preceding chunks, up to small formatting differences, which leads to a much lower Retro loss for all the corresponding tokens. Fig. 6 shows that it is possible to quantify how much of the Retro loss reduction is due to each of these two behaviors, by filtering out evaluation chunks that overlap with the retrieval set.
E.2 Inspecting samples
We can follow the same procedure as above on samples generated using Retro models, in order to better understand where retrieval data had an influence on sampling. We show examples of samples obtained using the 7.5B Retro model in Tables 6, 7, 20 and 21.
E.3 Neighbour quantification
To quantify a notion of distance between the source document and the retrieved chunks, we can measure the distance between source articles when retrieving only from Wikipedia. Consonni et al. (2019) provide a Wikipedia link dataset which, for each article, contains a list of neighbouring articles. Using this, we construct a directed graph and compute the link distance from one page to another. In Fig. 9 we compute the link distance between training sequences and the retrieved neighbours. We find that retrieved documents tend to come from articles that are quite close to the article containing the target. Furthermore, we find that on average the distance increases with rank, suggesting both that our neighbours are useful and that their ordering is reasonable. This provides confidence for our larger-scale experiments, where document distance is less well defined.
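The link distance can be computed with breadth-first search over the directed link graph. The sketch below assumes that the graph is a dictionary mapping each article title to its outgoing links. The data structure and the function name are assumptions.

```python
from collections import deque
from typing import Dict, Hashable, Iterable, Optional


def link_distance(graph: Dict[Hashable, Iterable[Hashable]], source: Hashable,
                  target: Hashable) -> Optional[int]:
    """Fewest links needed to reach target from source in a directed graph (None if unreachable)."""
    if source == target:
        return 0
    seen = {source}
    queue = deque([(source, 0)])
    while queue:
        node, dist = queue.popleft()
        for nxt in graph.get(node, ()):
            if nxt == target:
                return dist + 1
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, dist + 1))
    return None
```

Averaging `link_distance` from each target article to the article of its k-th neighbour, separately for each k, gives the per-rank statistic described above.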
Appendix F Complementary quantitative results
We report tables corresponding to quantitative figures of the main text, as well as further filtered language model results on the Pile.
We report the performance of Retro and baseline models, measured in bits-per-byte on each evaluation set, in Table 14.
F.2 The Pile
In Fig. 4, we compare Retro against Jurassic-1 (Lieber et al., 2021). The full bits-per-byte results are reported in Table 15.
F.3 Filtered results
We evaluate leakage between the evaluation sets and the training set by measuring the proportion of evaluation chunks with a given overlap ratio. We show histograms in Fig. 10. We can see that C4 has some slight overlap between train and evaluation. Similarly, chunks of Wikitext103 appear in the training set despite having removed the actual Wikitext103 evaluation documents from the training set. On the other hand, our Wikipedia September 2021 dataset shows almost no leakage (its data being original documents that did not exist when the training data was created), and neither does Curation Corpus.
We report chunk overlap distribution and filtered performance curves on the Pile in Fig. 12 and Fig. 11, respectively. The qualitative interpretation of the filtered curves is the same: Retro models exploit leakage more, but the performance improvement they provide remains significant even on original chunks that haven’t been observed in the training set.