REALM: Retrieval-Augmented Language Model Pre-Training

Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat, Ming-Wei Chang

1 Introduction

Recent advances in language model pre-training have shown that models such as BERT (Devlin et al., 2018), RoBERTa (Liu et al., 2019) and T5 (Raffel et al., 2019) store a surprising amount of world knowledge, acquired from the massive text corpora they are trained on (Petroni et al., 2019). For example, BERT is able to correctly predict the missing word in the following sentence: “The ___ is the currency of the United Kingdom” (answer: “pound”).

In these language models, the learned world knowledge is stored implicitly in the parameters of the underlying neural network. This makes it difficult to determine what knowledge is stored in the network and where. Furthermore, storage space is limited by the size of the network—to capture more world knowledge, one must train ever-larger networks, which can be prohibitively slow or expensive.

To capture knowledge in a more interpretable and modular way, we propose a novel framework, Retrieval-Augmented Language Model (REALM) pre-training, which augments language model pre-training algorithms with a learned textual knowledge retriever. In contrast to models that store knowledge in their parameters, this approach explicitly exposes the role of world knowledge by asking the model to decide what knowledge to retrieve and use during inference. Before making each prediction, the language model uses the retriever to retrieve documents from a large corpus such as Wikipedia, and then attends over those documents to help inform its prediction. (We use the term “document” loosely to refer to a passage from the knowledge corpus, not necessarily a whole article.) Learning this model end-to-end requires backpropagating through a retrieval step that considers an entire corpus of textual knowledge, as shown in Figure 1.

The key intuition of REALM is to train the retriever using a performance-based signal from unsupervised text: a retrieval that improves the language model’s perplexity is helpful and should be rewarded, while an uninformative retrieval should be penalized. For example, in Figure 1, if the model needs to fill the blank in “the ___ at the top of the pyramid”, the retriever should be rewarded for selecting a document containing “The pyramidion on top allows for less material higher up the pyramid”. We achieve this behavior by modeling our retrieve-then-predict approach as a latent variable language model and optimizing the marginal likelihood.

Incorporating a large-scale neural retrieval module during pre-training constitutes a significant computational challenge, since the retriever must consider millions of candidate documents for each pre-training step, and we must backpropagate through its decisions. To address this, we structure the retriever such that the computation performed for each document can be cached and asynchronously updated, and selection of the best documents can be formulated as Maximum Inner Product Search (MIPS).

Numerous prior works have demonstrated the benefit of adding a discrete retrieval step to neural networks (Miller et al., 2016; Chen et al., 2017), but did not apply the framework to language model pre-training and employed non-learned retrievers to handle large-scale document collections. In the language modeling literature, the $k$-Nearest Neighbor Language Model (kNN-LM; Khandelwal et al., 2019) retrieves similar LM examples to improve memorization. However, kNN-LM was not fine-tuned for downstream tasks, perhaps because it is unclear how to adapt the retrieval mechanism: a kNN can only use examples labeled for the target task, and during fine-tuning this precludes LM examples, which contain the desired world knowledge. In contrast, REALM’s retriever is designed to transfer to other tasks, and the retrieval is just text, not a labeled example.

We evaluate our approach by fine-tuning the models pre-trained with REALM on the task of Open-domain Question Answering (Open-QA), one of the most knowledge-intensive tasks in natural language processing. We evaluate on three popular Open-QA benchmarks (NaturalQuestions-Open, WebQuestions, and CuratedTrec) and compare to state-of-the-art Open-QA models, including both extremely large models that store knowledge implicitly (such as T5) as well as previous approaches that also use a knowledge retriever to access external knowledge, but implement retrieval in a more heuristic fashion (Lee et al., 2019; Min et al., 2019a; Asai et al., 2019). REALM achieves new state-of-the-art results on all three benchmarks, significantly outperforming all previous systems by 4-16% absolute accuracy. We also demonstrate qualitative benefits of REALM, including interpretability and modularity.

2 Background

The goal of language model pre-training is to learn useful representations of language, usually from unlabeled text corpora. The resulting pre-trained model can then be further trained (fine-tuned) for a downstream task of primary interest (in our case, Open-QA), often leading to better generalization than training from scratch (Dai & Le, 2015; Radford et al., 2019).

We focus on the masked language model (MLM) variant of pre-training popularized by BERT (Devlin et al., 2018). (Strictly speaking, an MLM is not a standard language model, since it does not define a distribution over the entire sequence of tokens; in this paper we sometimes use the term “language model” loosely for brevity.) In its basic form, an MLM is trained to predict the missing tokens in an input text passage. Given an unlabeled pre-training corpus $\mathcal{X}$ (e.g., Wikipedia text), a training example $(x, y)$ can be generated by randomly masking tokens in a sampled piece of text (e.g., $x =$ “The [MASK] is the currency [MASK] the UK”; $y =$ (“pound”, “of”)). The model uses its representation of the masked input $x$ to predict the token that should go in each mask. A good MLM must learn to encode syntactic and semantic information (e.g., to predict “of”) as well as some world knowledge (e.g., to predict “pound”).
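To make the example concrete, here is a minimal Python sketch of how an $(x, y)$ pair can be generated by masking. The helpers `mask_tokens` and `random_mlm_example` are hypothetical, whitespace splitting stands in for wordpiece tokenization, and the sampling is simplified: BERT's actual recipe also sometimes keeps or randomly replaces a selected token instead of masking it, which is omitted here.

```python
import random

MASK = "[MASK]"

def mask_tokens(tokens, positions):
    """Replace the tokens at `positions` with [MASK].

    Returns (x, y): the masked sequence and the original tokens at the
    masked positions, in left-to-right order.
    """
    x = list(tokens)
    y = []
    for i in sorted(set(positions)):
        y.append(x[i])
        x[i] = MASK
    return x, tuple(y)

def random_mlm_example(tokens, mask_prob=0.15, rng=None):
    """Mask a random subset of about mask_prob * len(tokens) positions (at least one)."""
    rng = rng or random.Random()
    n = max(1, round(mask_prob * len(tokens)))
    return mask_tokens(tokens, rng.sample(range(len(tokens)), n))

tokens = "The pound is the currency of the UK".split()
x, y = mask_tokens(tokens, [1, 5])
print(" ".join(x))  # The [MASK] is the currency [MASK] the UK
print(y)            # ('pound', 'of')
```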

Open-domain question answering (Open-QA)

To measure a model’s ability to incorporate world knowledge, we need a downstream task where world knowledge is critical. Perhaps one of the most knowledge-intensive tasks in natural language processing is open-domain question answering (Open-QA): given a question $x$ such as “What is the currency of the UK?”, a model must output the correct answer string $y$, “pound”. The “open” part of Open-QA refers to the fact that the model does not receive a pre-identified document that is known to contain the answer, unlike traditional reading comprehension (RC) tasks such as SQuAD (Rajpurkar et al., 2016, 2018). While RC models comprehend a single document, Open-QA models must retain knowledge from millions of documents, since a question could be about any of them.

We focus on Open-QA systems that utilize a textual knowledge corpus $\mathcal{Z}$ as the knowledge source. Many of these systems employ a retrieval-based approach: given a question $x$, retrieve potentially relevant documents $z$ from the corpus $\mathcal{Z}$, and then extract an answer $y$ from the documents (Brill et al., 2002; Chen et al., 2017; Lee et al., 2019). Our approach, REALM, is inspired by this paradigm and extends it to language model pre-training. Alternatively, some recent work has proposed generation-based systems that apply a sequence-to-sequence model on $x$ to directly generate $y$ token-by-token (Lewis et al., 2019; Raffel et al., 2019). We will compare against state-of-the-art systems from both paradigms in our experiments.

3 Approach

We start by formalizing REALM’s pre-training and fine-tuning tasks as a retrieve-then-predict generative process in Section 3.1. Then in Section 3.2, we describe the model architectures for each component of that process. In Section 3.3, we show how to implement REALM pre-training and fine-tuning by maximizing the likelihood of REALM’s generative process. En route, we address important computational challenges, explain why training works, and also discuss strategies for injecting useful inductive biases. The overall framework is illustrated in Figure 2.

3.1 REALM’s generative process

For both pre-training and fine-tuning, REALM takes some input $x$ and learns a distribution $p(y \mid x)$ over possible outputs $y$. For pre-training, the task is masked language modeling: $x$ is a sentence from a pre-training corpus $\mathcal{X}$ with some tokens masked out, and the model must predict the value of those missing tokens, $y$. For fine-tuning, the task is Open-QA: $x$ is a question, and $y$ is the answer.

REALM decomposes $p(y \mid x)$ into two steps: retrieve, then predict. Given an input $x$, we first retrieve possibly helpful documents $z$ from a knowledge corpus $\mathcal{Z}$. We model this as a sample from the distribution $p(z \mid x)$. Then, we condition on both the retrieved $z$ and the original input $x$ to generate the output $y$, modeled as $p(y \mid z, x)$. To obtain the overall likelihood of generating $y$, we treat $z$ as a latent variable and marginalize over all possible documents $z$, yielding

$$p(y \mid x) = \sum_{z \in \mathcal{Z}} p(y \mid z, x)\, p(z \mid x). \tag{1}$$
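As a numeric illustration of the marginal $p(y \mid x) = \sum_{z} p(y \mid z, x)\, p(z \mid x)$, the following Python sketch evaluates it over a toy three-document corpus. The function name, document names, and probabilities are all made up for illustration.

```python
def marginal_likelihood(p_z_given_x, p_y_given_zx):
    """p(y|x) = sum over documents z of p(y|z,x) * p(z|x).

    Both arguments are dicts keyed by document id.
    """
    return sum(p_z_given_x[z] * p_y_given_zx[z] for z in p_z_given_x)

# Hypothetical retrieval and prediction probabilities for three documents.
p_z = {"doc_a": 0.7, "doc_b": 0.2, "doc_c": 0.1}
p_y = {"doc_a": 0.9, "doc_b": 0.1, "doc_c": 0.0}
print(marginal_likelihood(p_z, p_y))  # approximately 0.65
```

Because each term is weighted by $p(z \mid x)$, a helpful document raises the likelihood only to the extent that the retriever assigns it probability mass.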

3.2 Model architecture

We now describe the two key components: the neural knowledge retriever, which models $p(z \mid x)$, and the knowledge-augmented encoder, which models $p(y \mid z, x)$.

Knowledge Retriever

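The retriever is defined using a dense inner product model:

$$p(z \mid x) = \frac{\exp f(x, z)}{\sum_{z'} \exp f(x, z')}, \qquad f(x, z) = \mathtt{Embed_{input}}(x)^{\top}\, \mathtt{Embed_{doc}}(z),$$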

where $\mathtt{Embed_{input}}$ and $\mathtt{Embed_{doc}}$ are embedding functions that map $x$ and $z$ respectively to $d$-dimensional vectors. The relevance score $f(x, z)$ between $x$ and $z$ is defined as the inner product of the vector embeddings. The retrieval distribution is the softmax over all relevance scores.
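A minimal pure-Python sketch of this retrieval distribution, assuming the embeddings have already been computed. The function names are hypothetical and plain lists stand in for $d$-dimensional vectors. Subtracting the maximum score before exponentiating is a standard numerical-stability trick that leaves the softmax unchanged.

```python
import math

def relevance(x_emb, z_emb):
    """f(x, z): inner product of the query and document embeddings."""
    return sum(a * b for a, b in zip(x_emb, z_emb))

def retrieval_distribution(x_emb, doc_embs):
    """p(z | x): softmax of f(x, z) over every document in the corpus."""
    scores = [relevance(x_emb, z_emb) for z_emb in doc_embs]
    m = max(scores)  # shift for numerical stability; softmax is shift-invariant
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

x_emb = [1.0, 0.0]
doc_embs = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]  # relevance scores 2, 0, 1
probs = retrieval_distribution(x_emb, doc_embs)
```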

We implement the embedding functions using BERT-style Transformers (Devlin et al., 2018). Following standard practices, we join spans of text by applying wordpiece tokenization, separating them with [SEP] tokens, prefixing a [CLS] token, and appending a final [SEP] token:

$$\mathtt{join_{BERT}}(x) = \texttt{[CLS]}\; x\; \texttt{[SEP]}, \qquad \mathtt{join_{BERT}}(x_1, x_2) = \texttt{[CLS]}\; x_1\; \texttt{[SEP]}\; x_2\; \texttt{[SEP]}.$$

As in Devlin et al. (2018), we pass this into a Transformer, which produces one vector for each token, including the vector corresponding to [CLS], which is used as a “pooled” representation of the sequence (denoted $\mathtt{BERT_{CLS}}$). Finally, we perform a linear projection to reduce the dimensionality of the vector, denoted as a projection matrix $\mathbf{W}$:

$$\mathtt{Embed_{input}}(x) = \mathbf{W}_{\text{input}}\, \mathtt{BERT_{CLS}}(\mathtt{join_{BERT}}(x)), \qquad \mathtt{Embed_{doc}}(z) = \mathbf{W}_{\text{doc}}\, \mathtt{BERT_{CLS}}(\mathtt{join_{BERT}}(z_{\text{title}}, z_{\text{body}})),$$

where $z_{\text{title}}$ is the document’s title and $z_{\text{body}}$ is its body. We let $\theta$ denote all parameters associated with the retriever, which include the Transformer and projection matrices.
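The [CLS]/[SEP] joining convention can be sketched over already-tokenized spans as follows. `join_bert` is a hypothetical helper, wordpiece tokenization is assumed to have happened upstream, and whitespace-split words stand in for wordpieces. The example builds the two sequences the retriever embeds: the query for $\mathtt{Embed_{input}}$ and the title/body pair for $\mathtt{Embed_{doc}}$.

```python
def join_bert(*spans):
    """Join tokenized spans BERT-style: [CLS] span_1 [SEP] span_2 [SEP] ..."""
    joined = ["[CLS]"]
    for span in spans:
        joined.extend(span)
        joined.append("[SEP]")
    return joined

query = "what is the currency of the uk".split()
title = ["pound", "sterling"]
body = "the pound is the official currency of the united kingdom".split()

query_input = join_bert(query)      # input sequence for Embed_input
doc_input = join_bert(title, body)  # input sequence for Embed_doc
print(doc_input[:5])  # ['[CLS]', 'pound', 'sterling', '[SEP]', 'the']
```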

Knowledge-Augmented Encoder

Given an input $x$ and a retrieved document $z$, the knowledge-augmented encoder defines $p(y \mid z, x)$. We join $x$ and $z$ into a single sequence that we feed into a Transformer (distinct from the one used in the retriever). This allows us to perform rich cross-attention between $x$ and $z$ before predicting $y$. See Figure 1 for a concrete example.

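At this stage, the architectures for pre-training and fine-tuning differ slightly. For the masked language model pre-training task, we must predict the original value of each [MASK] token in $x$. To do so, we use the same masked language modeling (MLM) loss as in Devlin et al. (2018):

$$p(y \mid z, x) = \prod_{j=1}^{J_x} p(y_j \mid z, x), \qquad p(y_j \mid z, x) \propto \exp\left(w_j^{\top}\, \mathtt{BERT}_{\mathtt{MASK}(j)}(\mathtt{join_{BERT}}(x, z_{\text{body}}))\right),$$

where $\mathtt{BERT}_{\mathtt{MASK}(j)}$ denotes the Transformer output vector corresponding to the $j^{\text{th}}$ masked token (computed over the joined sequence of $x$ and the document body $z_{\text{body}}$), $J_x$ is the total number of [MASK] tokens in $x$, and $w_j$ is a learned word embedding for token $y_j$.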
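The factorization over masked positions can be sketched in Python as below. As an assumption for illustration, each masked position's Transformer output vector is passed in directly as a list (a stand-in for the real encoder output), and the vocabulary is a small dict of hypothetical word embeddings $w$.

```python
import math

def mask_log_prob(h_mask, word_embs, target):
    """log p(y_j | z, x): log-softmax over the vocabulary of w^T h for one [MASK]."""
    logits = {word: sum(a * b for a, b in zip(w, h_mask)) for word, w in word_embs.items()}
    m = max(logits.values())
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits.values()))
    return logits[target] - log_norm

def mlm_log_likelihood(mask_vectors, word_embs, targets):
    """log p(y | z, x) = sum over the J_x masked positions of log p(y_j | z, x)."""
    return sum(mask_log_prob(h, word_embs, y) for h, y in zip(mask_vectors, targets))

word_embs = {"pound": [1.0, 0.0], "euro": [0.0, 1.0], "of": [0.5, 0.5]}
mask_vectors = [[2.0, 0.0], [1.0, 1.0]]  # hypothetical outputs for two [MASK] tokens
likelihood = math.exp(mlm_log_likelihood(mask_vectors, word_embs, ["pound", "of"]))
```

Working in log space turns the product over masks into a sum, which avoids underflow when $J_x$ is large.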

For Open-QA fine-tuning, we wish to produce the answer string $y$. Following previous reading comprehension work (Rajpurkar et al., 2016; Seo et al., 2016; Lee et al., 2016; Clark & Gardner, 2017), we assume that the answer $y$ can be found as a contiguous sequence of tokens in some document $z$. Let $S(z, y)$ be the set of spans matching $y$ in $z$. Then we can define $p(y \mid z, x)$ as:

$$p(y \mid z, x) \propto \sum_{s \in S(z, y)} \exp\left(\mathtt{MLP}\left(\left[h_{\mathtt{START}(s)};\, h_{\mathtt{END}(s)}\right]\right)\right), \qquad h_{\mathtt{START}(s)} = \mathtt{BERT_{START(s)}}(\mathtt{join_{BERT}}(x, z_{\text{body}})), \qquad h_{\mathtt{END}(s)} = \mathtt{BERT_{END(s)}}(\mathtt{join_{BERT}}(x, z_{\text{body}})),$$

where $\mathtt{BERT_{START(s)}}$ and $\mathtt{BERT_{END(s)}}$ denote the Transformer output vectors corresponding to the start and end tokens of span $s$, respectively, while $\mathtt{MLP}$ denotes a feed-forward neural network. We let $\phi$ denote all parameters associated with the knowledge-augmented encoder.
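The span-based likelihood can be sketched as follows. The scoring function `span_score` is a hypothetical stand-in for the MLP applied to a span's start and end vectors, and, as an assumption, the proportionality is resolved by normalizing over all spans of $z$ up to a maximum length `max_len` (matching spans longer than that are ignored).

```python
import math

def matching_spans(doc_tokens, answer_tokens):
    """S(z, y): all (start, end) spans of the document equal to the answer (end inclusive)."""
    n = len(answer_tokens)
    return [(i, i + n - 1) for i in range(len(doc_tokens) - n + 1)
            if doc_tokens[i:i + n] == answer_tokens]

def answer_prob(doc_tokens, answer_tokens, span_score, max_len=10):
    """p(y | z, x): softmax mass of the spans matching y, among all spans up to max_len."""
    all_spans = [(i, j) for i in range(len(doc_tokens))
                 for j in range(i, min(i + max_len, len(doc_tokens)))]
    m = max(span_score(i, j) for i, j in all_spans)
    denom = sum(math.exp(span_score(i, j) - m) for i, j in all_spans)
    numer = sum(math.exp(span_score(i, j) - m)
                for i, j in matching_spans(doc_tokens, answer_tokens)
                if j - i + 1 <= max_len)
    return numer / denom

doc = "the pound sterling is the currency of the uk".split()
print(matching_spans(doc, ["the"]))  # [(0, 0), (4, 4), (7, 7)]
```

Summing over every occurrence means no single gold span has to be chosen; the model can place its mass on whichever mention of the answer is best supported by context.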

3.3 Training

For both pre-training and fine-tuning, we train by maximizing the log-likelihood $\log p(y \mid x)$ of the correct output $y$. Since both the knowledge retriever and the knowledge-augmented encoder are differentiable neural networks, we can compute the gradient of $\log p(y \mid x)$ (defined in Equation 1) with respect to the model parameters $\theta$ and $\phi$, and optimize using stochastic gradient descent.

The key computational challenge is that the marginal probability $p(y \mid x) = \sum_{z \in \mathcal{Z}} p(y \mid x, z)\, p(z \mid x)$ involves a summation over all documents $z$ in the knowledge corpus $\mathcal{Z}$. We approximate this by instead summing over the top $k$ documents with highest probability under $p(z \mid x)$. This is reasonable if most documents have near-zero probability.
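The top-$k$ approximation can be sketched as below, again with made-up probabilities and a hypothetical function name. Because every $p(y \mid z, x) \le 1$, the approximation error is at most the total probability $p(z \mid x)$ of the documents left out, which is why the approximation is safe when that mass is near zero.

```python
import heapq

def topk_marginal(p_z_given_x, p_y_given_zx, k):
    """Approximate p(y|x) by summing p(y|z,x) p(z|x) over only the k most probable documents."""
    top = heapq.nlargest(k, p_z_given_x, key=p_z_given_x.get)
    return sum(p_z_given_x[z] * p_y_given_zx[z] for z in top)

# Hypothetical, sharply peaked retrieval distribution.
p_z = {"a": 0.90, "b": 0.08, "c": 0.01, "d": 0.01}
p_y = {"a": 0.50, "b": 0.40, "c": 1.00, "d": 1.00}
exact = sum(p_z[z] * p_y[z] for z in p_z)  # approximately 0.502
approx = topk_marginal(p_z, p_y, k=2)      # approximately 0.482
```

Here the two dropped documents carry 0.02 of the retrieval mass, and the error is exactly that bound because both of them predict $y$ with probability 1.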

Even with this approximation, we still need an efficient way to find the top $k$ documents. Note that the ordering of documents under $p(z \mid x)$ is the same as under the relevance score $f(x, z) = \mathtt{Embed_{input}}(x)^{\top} \mathtt{Embed_{doc}}(z)$, which is an inner product, because the softmax is monotonic in the score. Thus, we can employ Maximum Inner Product Search (MIPS) algorithms to find the approximate top $k$ documents, using running time and storage space that scale sub-linearly with the number of documents (Ram & Gray, 2012; Shrivastava & Li, 2014; Shen et al., 2015).
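The following sketch checks that ranking by inner product and ranking by $p(z \mid x)$ agree. The brute-force `topk_by_score` is only an exact, linear-time reference for what a MIPS index computes approximately in sub-linear time; it is not a MIPS algorithm itself, and the random embeddings and function names are purely illustrative.

```python
import heapq
import math
import random

def inner(u, v):
    return sum(a * b for a, b in zip(u, v))

def topk_by_score(query, doc_embs, k):
    """Exact top-k documents by relevance score f(x, z) (linear scan)."""
    return heapq.nlargest(k, range(len(doc_embs)), key=lambda i: inner(query, doc_embs[i]))

def topk_by_prob(query, doc_embs, k):
    """Exact top-k documents by p(z | x) = softmax of f(x, z)."""
    scores = [inner(query, d) for d in doc_embs]
    m = max(scores)
    total = sum(math.exp(s - m) for s in scores)
    probs = [math.exp(s - m) / total for s in scores]
    return heapq.nlargest(k, range(len(doc_embs)), key=lambda i: probs[i])

rng = random.Random(0)
docs = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(50)]
query = [rng.gauss(0, 1) for _ in range(4)]
```

Since $\exp$ is monotonic and the normalizer is shared by all documents, any index that returns the highest inner products also returns the most probable documents.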

To employ MIPS, we must pre-compute $\mathtt{Embed_{doc}}(z)$ for every $z \in \mathcal{Z}$ and construct an efficient search index over these embeddings. However, this data structure will no longer be consistent with $p(z \mid x)$ if the parameters $\theta$ of $\mathtt{Embed_{doc}}$ are later updated. Hence, the search index goes “stale” after every gradient update on $\theta$.

Our solution is to “refresh” the index by asynchronously re-embedding and re-indexing all documents every several hundred training steps. The MIPS index is slightly stale between refreshes, but note that it is only used to select the top $k$ documents. We recompute $p(z \mid x)$ and its gradient, using the fresh $\theta$, for these top $k$ documents after retrieving them. In Section 4.5, we empirically demonstrate that this procedure results in stable optimization, provided that refreshes happen at a sufficiently frequent rate.
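The division of labor between the stale index and the fresh parameters might look like the sketch below. All names are hypothetical: `stale_doc_embs` plays the role of the MIPS index built from an older parameter snapshot, and `fresh_embed_doc` stands in for $\mathtt{Embed_{doc}}$ under the current $\theta$. In a real implementation the recomputed scores would be differentiable; here they are plain floats.

```python
import heapq
import math

def inner(u, v):
    return sum(a * b for a, b in zip(u, v))

def retrieve_then_rescore(query, stale_doc_embs, fresh_embed_doc, docs, k):
    """Select k candidates with the stale index, then recompute p(z|x) with fresh parameters.

    Returns {doc_index: p(z|x)}, normalized over the k retrieved candidates only.
    """
    # 1. Candidate selection uses the (possibly stale) precomputed embeddings.
    top = heapq.nlargest(k, range(len(docs)), key=lambda i: inner(query, stale_doc_embs[i]))
    # 2. Scores for the selected candidates are recomputed with the current parameters.
    fresh = {i: inner(query, fresh_embed_doc(docs[i])) for i in top}
    m = max(fresh.values())
    exps = {i: math.exp(s - m) for i, s in fresh.items()}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}

def fresh_embed_doc(doc):
    """Hypothetical updated Embed_doc: a 1-d embedding that now favors shorter documents."""
    return [-float(len(doc))]

docs = ["a", "bb", "ccc", "dddd"]
stale_doc_embs = [[1.0], [2.0], [3.0], [4.0]]  # index built from an old snapshot
probs = retrieve_then_rescore([1.0], stale_doc_embs, fresh_embed_doc, docs, k=2)
```

In this toy case the fresh parameters reverse the stale ordering of the two retrieved candidates, so their recomputed probabilities differ from what the index alone would suggest.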

We asynchronously refresh the MIPS index by running two jobs in parallel: a primary trainer job, which performs gradient updates on the parameters, and a secondary index builder job, which embeds and indexes the documents. The trainer sends the index builder a snapshot of its parameters, $\theta'$. The trainer then continues to train while the index builder uses $\theta'$ to construct a new index in the background. As soon as the index builder is done, it sends the new index back to the trainer, and the process repeats.
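The two-job protocol can be sketched with a single background worker from Python's `concurrent.futures`. This is a toy, single-process stand-in for the trainer and index builder jobs: the "parameters" are a dict, a gradient update is a fixed increment, `build_index` is a hypothetical function whose "embedding" is just the parameter times the document length, and sleeps simulate compute time.

```python
import concurrent.futures
import time

def build_index(theta_snapshot, corpus):
    """Index builder job: re-embed every document with a frozen parameter snapshot."""
    time.sleep(0.005)  # simulate the cost of re-embedding the whole corpus
    return {"built_at_step": theta_snapshot["step"],
            "embeddings": [theta_snapshot["w"] * len(doc) for doc in corpus]}

def train_with_async_refresh(corpus, num_steps=100):
    theta = {"w": 1.0, "step": 0}
    index = build_index(dict(theta), corpus)  # initial, synchronous build
    refreshes = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as builder:
        # Send the builder a copy of the parameters (the snapshot theta').
        pending = builder.submit(build_index, dict(theta), corpus)
        for step in range(1, num_steps + 1):
            theta["w"] += 0.01  # stand-in for one gradient update on theta
            theta["step"] = step
            time.sleep(0.001)   # stand-in for the cost of a training step
            if pending.done():  # new index ready: swap it in and start the next build
                index = pending.result()
                refreshes += 1
                pending = builder.submit(build_index, dict(theta), corpus)
        index = pending.result()  # wait for the last in-flight build
        refreshes += 1
    return index, refreshes, theta
```

Passing `dict(theta)` gives the builder a copy, so it works from a frozen snapshot while the trainer keeps updating `theta`; the gap between `index["built_at_step"]` and the current step is the staleness of the index.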

While asynchronous refreshes can be used for both pre-training and fine-tuning, in our experiments we only use them for pre-training. For fine-tuning, we build the MIPS index only once (using the pre-trained $\theta$) for simplicity and do not update $\mathtt{Embed_{doc}}$. This works because pre-training already yields a good $\mathtt{Embed_{doc}}$ function, although it is possible that refreshing the index would further improve performance. Note that we still fine-tune $\mathtt{Embed_{input}}$, so the retrieval function is still updated from the query side.

What does the retriever learn?

Since the knowledge retrieval of REALM is latent, it is not obvious how the training objective encourages meaningful retrievals. Here, we show how it rewards retrievals that improve prediction accuracy.

For a given query $x$ and document $z$, recall that $f(x, z)$ is the “relevance score” that the knowledge retriever assigns to document $z$. We can see how a single step of gradient descent during REALM pre-training alters this score by analyzing the gradient with respect to the parameters of the knowledge retriever, $\theta$:

$$\nabla \log p(y \mid x) = \sum_{z \in \mathcal{Z}} r(z)\, \nabla f(x, z), \qquad r(z) = \left[\frac{p(y \mid z, x)}{p(y \mid x)} - 1\right] p(z \mid x).$$

For each document $z$, the gradient encourages the retriever to change the score $f(x, z)$ by $r(z)$: increasing it if $r(z)$ is positive, and decreasing it if negative. The multiplier $r(z)$ is positive if and only if $p(y \mid z, x) > p(y \mid x)$. The term $p(y \mid z, x)$ is the probability of predicting the correct output $y$ when using document $z$. The term $p(y \mid x)$ is the expected value of $p(y \mid z, x)$ when randomly sampling a document from $p(z \mid x)$. Hence, document $z$ receives a positive update whenever it performs better than expected.
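The multiplier $r(z)$ can be computed directly from these quantities. In this hypothetical two-document example (function name and probabilities made up), the document that makes the correct prediction more likely than average gets a positive multiplier and the other a negative one. Over the candidate set the multipliers always sum to zero, since $r(z) = p(z \mid y, x) - p(z \mid x)$ and both distributions sum to one.

```python
def retriever_multipliers(p_z_given_x, p_y_given_zx):
    """r(z) = [p(y|z,x) / p(y|x) - 1] * p(z|x) for each candidate document z."""
    p_y_given_x = sum(p_z_given_x[z] * p_y_given_zx[z] for z in p_z_given_x)
    return {z: (p_y_given_zx[z] / p_y_given_x - 1.0) * p_z_given_x[z] for z in p_z_given_x}

p_z = {"pyramidion_doc": 0.3, "unrelated_doc": 0.7}
p_y = {"pyramidion_doc": 0.9, "unrelated_doc": 0.1}
r = retriever_multipliers(p_z, p_y)
# p(y|x) = 0.34, so the first document beats expectation and the second falls short.
```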

3.4 Injecting inductive biases into pre-training

In the process of developing REALM, we discovered several additional strategies that further guide the model towards meaningful retrievals, described below.

Salient span masking

During REALM pre-training, we want to focus on examples $x$ that require world knowledge to predict the masked tokens. As explained in Section 2, some MLM spans only require local context. To focus on problems that require world knowledge, we mask salient spans such as “United Kingdom” or “July 1969”. We use a BERT-based tagger trained on CoNLL-2003 data (Sang & De Meulder, 2003) to identify named entities, and a regular expression to identify dates. We select and mask one of these salient spans within a sentence for the masked language modeling task. We show that this significantly outperforms other masking strategies in Section 4.5.
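A rough sketch of salient span selection follows. As assumptions: the entity spans are passed in as character offsets (standing in for the output of the BERT-based CoNLL-2003 tagger), the date regular expression is a hypothetical pattern covering forms like “July 1969” or “20 July 1969” rather than the one used in the paper, and each whitespace-separated word in the chosen span becomes one [MASK] (a real implementation would mask wordpieces).

```python
import random
import re

MONTHS = ("January|February|March|April|May|June|July|August|"
          "September|October|November|December")
# Hypothetical date pattern: optional day, month name, four-digit year.
DATE_RE = re.compile(r"\b(?:\d{1,2} )?(?:" + MONTHS + r") \d{4}\b")

def salient_spans(sentence, entity_spans):
    """Union of tagger-provided entity spans and regex-matched dates, as (start, end) offsets."""
    dates = [m.span() for m in DATE_RE.finditer(sentence)]
    return sorted(set(entity_spans) | set(dates))

def mask_salient_span(sentence, entity_spans, rng=random):
    """Choose one salient span and replace each of its words with [MASK].

    Returns (masked_sentence, answer), or None if the sentence has no salient span.
    """
    spans = salient_spans(sentence, entity_spans)
    if not spans:
        return None
    start, end = rng.choice(spans)
    answer = sentence[start:end]
    masked = " ".join("[MASK]" for _ in answer.split())
    return sentence[:start] + masked + sentence[end:], answer

sentence = "Apollo 11 landed on the Moon in July 1969."
entities = [(0, 9)]  # "Apollo 11", assumed to come from an NER tagger
print(salient_spans(sentence, entities))  # [(0, 9), (32, 41)]
```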

Null document

Even with salient span masking, not all masked tokens require world knowledge to predict. We model this by adding an empty null document $\varnothing$ to the top $k$ retrieved documents, allowing appropriate credit to be assigned to a consistent sink when no retrieval is necessary.

Prohibiting trivial retrievals

If the pre-training corpus $\mathcal{X}$ and the knowledge corpus $\mathcal{Z}$ are the same, there exists a trivial retrieval candidate $z$ that is too informative: if the masked sentence $x$ comes from document $z$, the knowledge-augmented encoder can trivially predict $y$ by looking at the unmasked version of $x$ in $z$. This results in a large positive gradient for $p(z \mid x)$. If this occurs too often, the knowledge retriever ends up learning to look for exact string matches between $x$ and $z$, which does not capture other forms of relevance. For this reason, we exclude this trivial candidate during pre-training.
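Candidate construction for a single pre-training example, combining this exclusion with the null document, might look like the sketch below. The function and the `NULL_DOC` sentinel are hypothetical; how the null document is scored is not specified in this section, so the sketch simply appends it. The default of 8 total candidates mirrors the pre-training setup described later (8 candidates, including the null document).

```python
import heapq

NULL_DOC = "<null>"  # hypothetical sentinel for the empty null document

def pretraining_candidates(scores, source_doc_id, k=8):
    """Build the k candidates for one pre-training example.

    scores: {doc_id: f(x, z)} for documents returned by the MIPS index.
    The document the masked sentence came from is excluded (when X and Z are
    the same corpus), and the null document fills the last slot.
    """
    eligible = {d: s for d, s in scores.items() if d != source_doc_id}
    top = heapq.nlargest(k - 1, eligible, key=eligible.get)
    return top + [NULL_DOC]

scores = {"d0": 5.0, "d1": 4.0, "d2": 3.0, "d3": 1.0}
print(pretraining_candidates(scores, source_doc_id="d0", k=3))  # ['d1', 'd2', '<null>']
```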

Initialization

At the beginning of training, if the retriever does not have good embeddings for $\mathtt{Embed_{input}}(x)$ and $\mathtt{Embed_{doc}}(z)$, the retrieved documents $z$ will likely be unrelated to $x$. This causes the knowledge-augmented encoder to learn to ignore the retrieved documents. Once this occurs, the knowledge retriever does not receive a meaningful gradient and cannot improve, creating a vicious cycle. To avoid this cold-start problem, we warm-start $\mathtt{Embed_{input}}$ and $\mathtt{Embed_{doc}}$ using a simple training objective known as the Inverse Cloze Task (ICT) where, given a sentence, the model is trained to retrieve the document from which that sentence came. We defer to Lee et al. (2019) for details. For the knowledge-augmented encoder, we warm-start it with BERT pre-training, specifically the uncased BERT-base model (12 layers, 768 hidden units, 12 attention heads).
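A rough sketch of ICT data and loss follows, assuming (per our reading of Lee et al. (2019), to which this section defers) that a random sentence serves as the pseudo-query, its surrounding passage as the positive document, and the other passages in the batch as negatives. The `keep_prob` parameter, which occasionally leaves the sentence inside its context, is an assumed detail, and all function names are hypothetical.

```python
import math
import random

def ict_example(passage_sentences, rng, keep_prob=0.1):
    """One ICT pair: a random sentence (pseudo-query) and the passage it came from.

    With probability 1 - keep_prob the sentence is removed from its context, so
    the pair cannot be matched by exact string overlap alone.
    """
    i = rng.randrange(len(passage_sentences))
    query = passage_sentences[i]
    if rng.random() < keep_prob:
        context = list(passage_sentences)
    else:
        context = passage_sentences[:i] + passage_sentences[i + 1:]
    return query, " ".join(context)

def ict_loss(query_embs, context_embs):
    """Average in-batch softmax loss: query i should retrieve context i."""
    loss = 0.0
    for i, q in enumerate(query_embs):
        scores = [sum(a * b for a, b in zip(q, c)) for c in context_embs]
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        loss -= scores[i] - log_norm
    return loss / len(query_embs)
```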

4 Experiments

We now evaluate our approach on the Open-QA task. In this section, we describe in detail the benchmarks used and the different approaches to which we compare empirically.

4.1 Open-QA Benchmarks

A number of benchmarks have been proposed for Open-QA. In this work, we focus on datasets where the question writers did not already know the answer. This yields questions that reflect more realistic information-seeking needs, and also avoids artifacts that can arise if the question is formulated with a particular answer in mind. A deeper justification is given in Lee et al. (2019). In all cases, the predicted answer is evaluated via exact match with any reference answer, following previous Open-QA work (Chen et al., 2017).

NaturalQuestions-Open

The NaturalQuestions dataset (Kwiatkowski et al., 2019) consists of naturally occurring Google queries and their answers. Each answer also comes with an “answer type”: following Lee et al. (2019), we only keep questions that are categorized as “short answer type” with at most five tokens. The dataset also provides a suggested Wikipedia document to retrieve; like all models we compare against, we do not provide this to our model.

WebQuestions

The WebQuestions dataset (Berant et al., 2013) was collected from the Google Suggest API, using one seed question and expanding the set to related questions. We follow the setting defined by Chen et al. (2017).

CuratedTrec

The CuratedTrec dataset is a collection of question-answer pairs drawn from real user queries issued on sites such as MSNSearch and AskJeeves. To account for multiple correct answers or different spelling variations, the answers in this dataset are defined as regular expressions that match all correct answers. It is unclear how to train generation-based models with this type of supervision, so we do not evaluate them on this dataset.

4.2 Approaches compared

Retrieval-based Open-QA

Most existing Open-QA systems answer the input question by first retrieving potentially relevant documents from a knowledge corpus, and then using a reading comprehension system to extract an answer from the documents. In this paradigm, the knowledge is stored explicitly in the corpus. We wish to compare different methods for implementing retrieval.

Many approaches use non-learned heuristic retrieval such as sparse bag-of-words matching (Robertson et al., 2009) or entity linking on the question to select a small set of relevant documents (e.g., 20). These documents are typically then re-ranked using a learned model, but coverage may be limited by the initial heuristic retrieval step. Approaches such as DrQA (Chen et al., 2017), HardEM (Min et al., 2019a), GraphRetriever (Min et al., 2019b), and PathRetriever (Asai et al., 2019) in Table 1 are in this category.

Some recent approaches have proposed to implement learnable retrieval using a MIPS index. ORQA (Lee et al., 2019) formulates Open-QA using a similar latent variable model as REALM, and also trains by maximizing the marginal likelihood. However, REALM adds a novel language model pre-training step, and backpropagates into the MIPS index, rather than using a fixed index. In Table 1, we directly compare the two. It is also important to note that the retrievers for both REALM pre-training and ORQA are initialized using the Inverse Cloze Task, described in Section 3.4.

Generation-based Open-QA

An emerging alternative approach to Open-QA is to model it as a sequence prediction task: simply encode the question, and then decode the answer token-by-token based on the encoding. While it was initially unclear how large amounts of knowledge could be injected into the model, GPT-2 (Radford et al., 2019) hinted at the possibility of directly generating answers without using any given context via sequence-to-sequence. However, its performance was not competitive, possibly due to the lack of fine-tuning. Orthogonally, T5 (Raffel et al., 2019) showed that directly generating answers without explicit extraction from the given context is a viable approach, but its authors only experimented on the reading comprehension task, where a context document is provided.

For the most competitive and comparable generation-based baseline, we compare to concurrent work which fine-tunes T5 for Open-QA (Roberts et al., 2020). (We initially conducted our own T5 experiments using the code from https://tinyurl.com/t5-openqa-colab (Raffel et al., 2019); we now report results from the concurrent work of Roberts et al. (2020), which has an improved fine-tuning procedure.) We compare against the Base, Large, and even larger 11-billion parameter models to measure the effect of model size.

4.3 Implementation Details

We reuse all hyperparameters from Lee et al. (2019), to enable direct comparison. Our knowledge corpus is derived from the December 20, 2018 snapshot of English Wikipedia. Documents are greedily split into chunks of up to 288 BERT wordpieces, resulting in just over 13 million retrieval candidates. During fine-tuning inference, we consider the top-5 candidates, and the entire model can be run on a single machine with a 12GB GPU.
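One plausible reading of “greedily split into chunks of up to 288 BERT wordpieces” is sketched below: already-tokenized sentences are appended to the current chunk until the next one would overflow, and over-long sentences are cut into pieces. The function is hypothetical, and the actual preprocessing may differ in details such as sentence boundary handling.

```python
def greedy_chunks(sentences, max_len=288):
    """Greedily pack tokenized sentences into chunks of at most max_len tokens."""
    chunks, current = [], []
    for sent in sentences:
        # Cut any sentence longer than max_len into max_len-sized pieces.
        for start in range(0, len(sent), max_len):
            piece = sent[start:start + max_len]
            if len(current) + len(piece) > max_len:  # piece would overflow: close chunk
                chunks.append(current)
                current = []
            current.extend(piece)
    if current:
        chunks.append(current)
    return chunks

sentences = [["tok"] * 100, ["tok"] * 100, ["tok"] * 100, ["tok"] * 300]
print([len(c) for c in greedy_chunks(sentences)])  # [200, 100, 288, 12]
```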

Pre-training

We pre-train for 200k steps on 64 Google Cloud TPUs, with a batch size of 512 and a learning rate of 3e-5, using BERT’s default optimizer. The document embedding step for the MIPS index is parallelized over 16 TPUs. For each example, we retrieve and marginalize over 8 candidate documents, including the null document $\varnothing$.

We experiment with two choices of the pre-training corpus $\mathcal{X}$: (1) Wikipedia, which is identical to the knowledge corpus $\mathcal{Z}$, and (2) CC-News, our reproduction of the corpus of English news proposed by Liu et al. (2019).

4.4 Main results

Table 1 shows the accuracy of different approaches on the three Open-QA datasets. REALM outperforms all previous approaches by a significant margin. Table 1 also shows the number of parameters for each model.

As reported in the concurrent work of Roberts et al. (2020), the generative Open-QA systems based on T5 are surprisingly powerful, with the largest T5-11B model outperforming the previous best Open-QA system. Increasing the size of T5 yields consistent improvement, but comes at significant computational cost (from Base to 11B, the model is 50 times larger, and gains roughly 5 points in accuracy). In contrast, REALM outperforms the largest T5-11B model while being 30 times smaller. It is also important to note that T5 accesses additional reading comprehension data from SQuAD during its pre-training (100,000+ examples). Access to such data could also benefit REALM, but was not used in our experiments.

Among all systems, the most direct comparison with REALM is ORQA (Lee et al., 2019), where the fine-tuning setup, hyperparameters and training data are identical. The improvement of REALM over ORQA is purely due to better pre-training methods. The results also indicate that our method of pre-training can be applied both in (1) the single-corpus setting ($\mathcal{X}$ = Wikipedia, $\mathcal{Z}$ = Wikipedia) and in (2) the separate-corpus setting ($\mathcal{X}$ = CC-News, $\mathcal{Z}$ = Wikipedia).

Compared to other retrieval-based systems (Asai et al., 2019; Min et al., 2019a, b) which often retrieve from 20 to 80 documents, our system gets the overall best performance while only retrieving 5 documents.

4.5 Analysis

In Table 2 we present results for NaturalQuestions-Open after ablating critical components of REALM. In addition to the end-to-end results, we also report how often the gold answer appears in the top-5 retrievals before applying any fine-tuning. The latter metric more directly isolates the contribution of improving the retriever during pre-training.

We first aim to determine whether REALM pre-training improves the retriever or the encoder, or both. To do so, we can reset the parameters of either the retriever or the encoder to their baseline state before REALM pre-training, and feed that into fine-tuning. Resetting both the retriever and encoder reduces the system to our main baseline, ORQA. We find that both the encoder and retriever benefit from REALM training separately, but the best result requires both components acting in unison.

Masking scheme

We compare our salient span masking scheme (Section 3.4) with (1) random token masking introduced in BERT (Devlin et al., 2018) and (2) random span masking proposed by SpanBERT (Joshi et al., 2019). While such salient span masking has not been shown to be impactful in previous work with standard BERT training (Joshi et al., 2019), it is crucial for REALM. Intuitively, the latent variable learning relies heavily on the utility of retrieval and is therefore more sensitive to a consistent learning signal.

MIPS index refresh rate

During pre-training, we run a parallel process to re-embed corpus documents and rebuild the MIPS index. This results in one index refresh per approximately 500 training steps. To demonstrate the importance of frequent index refreshes, we compare against using a slower refresh rate. The results in Table 2 suggest that a stale index can hurt model training, and further reducing this staleness could offer better optimization.

Examples of retrieved documents

Table 3 shows an example of the REALM masked language model prediction. In this example, “Fermat” is the correct word, and REALM (row (c)) gives the word a much higher probability than the BERT model (row (a)). Since REALM manages to retrieve some documents with a related fact (row (b)), the marginalized probability of the correct answer dramatically increases. This shows that REALM is able to retrieve documents to fill in the masked word even though it is trained with unsupervised text only.

5 Discussion and Related Work

We previously discussed related methods for Open-QA. Here we present several alternate ways of viewing REALM that connect it to a broader set of ideas beyond Open-QA:

Language representation models have been incorporating contexts of increasingly large scope when making predictions. Examples of this progression include models that condition on surrounding words (Mikolov et al., 2013a, b), sentences (Kiros et al., 2015; Peters et al., 2018), and paragraphs (Radford et al., 2018; Devlin et al., 2018). We can view REALM as a generalization of the above work to the next level of scope: the entire text corpus.

Retrieve-and-edit with learned retrieval

In order to better explain the variance in the input text and enable controllable generation, Guu et al. (2018) proposed a language model with the retrieve-and-edit framework (Hashimoto et al., 2018) that conditions on text with high lexical overlap. REALM has a similar approach, except that the model learns for itself which texts are most useful for reducing perplexity. By jointly learning the retriever, REALM has the capacity to depend on information beyond lexical overlap.

Scalable grounded neural memory

The document index can be viewed as a memory where the keys are the document embeddings. From this view, our work shares motivations with works such as product key memory (Lample et al., 2019), which enables sub-linear memory access in a memory network (Weston et al., 2014; Graves et al., 2014; Sukhbaatar et al., 2015), allowing these scalable memory layers to be integrated into large language models. One main difference is that our memories are grounded: each memory is associated with a document rather than with an unnamed value vector. This level of interpretability is crucial for applications like Open-QA, where users would require provenance for a predicted answer to be trustworthy.

Unsupervised Corpus Alignment

In sequence-to-sequence models with attention (Bahdanau et al., 2014), text is generated with latent selection of relevant tokens. This results in a set of model-centric unsupervised alignments between target and source tokens. Analogously, REALM also generates text with latent selection of relevant documents. A by-product of our method is that we offer a set of model-centric unsupervised alignments between text in the pre-training corpus $\mathcal{X}$ and the knowledge corpus $\mathcal{Z}$.

6 Future Work

The work presented here is the minimal instantiation of a family of REALM-like approaches where a representation is pre-trained to perform reasoning over a large corpus of knowledge on-the-fly during inference. We are particularly optimistic about generalizations of this work to (1) structured knowledge, which would result in a generalization of Peters et al. (2019) where we would also learn the decision of which entities are informative, (2) the multi-lingual setting, e.g., retrieving knowledge in a high-resource language to better represent text in a low-resource language, and (3) the multi-modal setting, e.g., retrieving images or videos that can provide knowledge rarely observed in text.

References

Appendix A Derivation of the gradient with respect to the knowledge retriever

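We compute the gradient of the REALM pre-training objective (a log-likelihood) with respect to the parameters of the knowledge retriever, $\theta$ (only the retrieval distribution $p(z \mid x)$ depends on $\theta$; the likelihood $p(y \mid z, x)$ does not):

$$
\begin{aligned}
\nabla \log p(y \mid x) &= p(y \mid x)^{-1} \nabla p(y \mid x) \\
&= p(y \mid x)^{-1} \sum_z p(y \mid z, x) \nabla p(z \mid x) \\
&= \sum_z \frac{p(y \mid z, x)\, p(z \mid x)}{p(y \mid x)} \nabla \log p(z \mid x) \\
&= \sum_z p(z \mid y, x) \nabla \log p(z \mid x),
\end{aligned}
$$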

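where the last line follows from applying conditional Bayes’ rule. Using the softmax form of the retriever, $p(z \mid x) = \exp f(x,z) / \sum_{z'} \exp f(x,z')$, we can then expand $\nabla \log p(z \mid x)$ as:

$$
\begin{aligned}
\nabla \log p(z \mid x) &= \nabla \log \frac{\exp f(x,z)}{\sum_{z'} \exp f(x,z')} \\
&= \nabla \left[ f(x,z) - \log \sum_{z'} \exp f(x,z') \right] \\
&= \nabla f(x,z) - \sum_{z'} p(z' \mid x) \nabla f(x,z').
\end{aligned}
$$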

Plugging this back into the first set of equations yields:

$$
\begin{aligned}
\nabla \log p(y \mid x) &= \sum_z p(z \mid y, x) \left[ \nabla f(x,z) - \sum_{z'} p(z' \mid x) \nabla f(x,z') \right] \\
&= \sum_z p(z \mid y, x) \nabla f(x,z) - \sum_{z'} p(z' \mid x) \nabla f(x,z') \\
&= \sum_z \left[ p(z \mid y, x) - p(z \mid x) \right] \nabla f(x,z) \\
&= \sum_z \left[ \frac{p(y \mid z, x)\, p(z \mid x)}{p(y \mid x)} - p(z \mid x) \right] \nabla f(x,z) \\
&= \sum_z \left[ \frac{p(y \mid z, x)}{p(y \mid x)} - 1 \right] p(z \mid x) \nabla f(x,z).
\end{aligned}
$$

In the second line, we used the fact that the overall expression is an expectation with respect to $p(z \mid y, x)$, and the terms that depend on $z'$ but not on $z$ can be moved out of that expectation unchanged, because $\sum_z p(z \mid y, x) = 1$.
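The final line can be sanity-checked numerically. The sketch below treats each retrieval score $f(x,z)$ as a free parameter, so $\nabla f(x,z)$ is a unit vector and the gradient with respect to the score of document $z$ reduces to $[p(y \mid z, x)/p(y \mid x) - 1]\, p(z \mid x)$. It compares that expression with central finite differences of $\log p(y \mid x)$. The function names and the random toy scores and likelihoods are hypothetical placeholders, not values from REALM.

```python
import numpy as np

def log_marginal(f, lik):
    """log p(y|x) = log sum_z p(y|z,x) p(z|x), with p(z|x) = softmax(f)."""
    p_z = np.exp(f - np.logaddexp.reduce(f))
    return np.log(np.dot(lik, p_z))

def analytic_grad(f, lik):
    """d log p(y|x) / d f(x,z) = [p(y|z,x) / p(y|x) - 1] p(z|x)."""
    p_z = np.exp(f - np.logaddexp.reduce(f))
    p_y = np.dot(lik, p_z)
    return (lik / p_y - 1.0) * p_z

def numeric_grad(f, lik, eps=1e-6):
    """Central finite differences, perturbing one retrieval score at a time."""
    g = np.zeros_like(f)
    for i in range(len(f)):
        step = np.zeros_like(f)
        step[i] = eps
        g[i] = (log_marginal(f + step, lik) - log_marginal(f - step, lik)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
f = rng.normal(size=5)                # toy retrieval scores f(x, z) for 5 documents
lik = rng.uniform(0.01, 1.0, size=5)  # toy likelihoods p(y | z, x)
ok = np.allclose(analytic_grad(f, lik), numeric_grad(f, lik), atol=1e-6)
print(ok)
```

The sign of each component matches the intuition behind REALM: a document's score is pushed up when $p(y \mid z, x) > p(y \mid x)$, that is, when conditioning on it predicts $y$ better than the retriever does on average, and pushed down otherwise.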

Appendix B Connection between REALM and supervised learning

From the equations in Appendix A, we saw that

$$
\nabla \log p(y \mid x) = \sum_z \left[ p(z \mid y, x) - p(z \mid x) \right] \nabla f(x,z).
$$

Suppose that there exists one document $z^*$ which causes the model to achieve perfect prediction accuracy (i.e., $p(y \mid z^*, x) = 1$), while all other documents $z'$ result in zero accuracy (i.e., $p(y \mid z', x) = 0$). Under this setting, Bayes’ rule gives $p(z^* \mid y, x) = 1$ (provided that $p(z^* \mid x)$ is non-zero), which causes the gradient to become

$$
\nabla \log p(y \mid x) = \nabla f(x,z^*) - \sum_z p(z \mid x) \nabla f(x,z) = \nabla \log p(z^* \mid x).
$$

From this, we see that gradient descent on the REALM objective is equivalent to gradient descent on $\log p(z^* \mid x)$. This is none other than the typical maximum likelihood training objective used in supervised learning, where $z^*$ is the “gold” document.
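This reduction can also be checked numerically. The sketch below again treats the scores $f(x,z)$ as free parameters, builds the one-hot likelihood described above, and compares the REALM gradient $p(z \mid y, x) - p(z \mid x)$ with the gradient of $\log p(z^* \mid x)$ for a softmax retriever. The function names, the toy scores, and the choice of gold index are hypothetical.

```python
import numpy as np

def softmax(f):
    """Numerically stable softmax, playing the role of p(z | x)."""
    e = np.exp(f - f.max())
    return e / e.sum()

def realm_grad(f, lik):
    """d log p(y|x) / d f(x,z) from Appendix A: p(z|y,x) - p(z|x)."""
    p_z = softmax(f)
    post = lik * p_z / np.dot(lik, p_z)  # p(z | y, x) by Bayes' rule
    return post - p_z

def supervised_grad(f, gold):
    """d log p(z*|x) / d f(x,z) for a softmax retriever: onehot(z*) - p(z|x)."""
    p_z = softmax(f)
    onehot = np.zeros_like(p_z)
    onehot[gold] = 1.0
    return onehot - p_z

f = np.array([0.3, -1.2, 2.0, 0.5])  # toy retrieval scores
gold = 1                             # hypothetical gold document z*
lik = np.zeros_like(f)
lik[gold] = 1.0                      # p(y|z*,x) = 1 and p(y|z',x) = 0 otherwise
match = np.allclose(realm_grad(f, lik), supervised_grad(f, gold))
print(match)
```

The gold document is deliberately given a low retrieval score here; the equivalence does not depend on that, only on $p(z^* \mid x) > 0$, which a softmax over finite scores guarantees.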

Appendix C Adapting to new knowledge

An explicit retrieval system allows us to adapt to new world knowledge simply by modifying the corpus documents. To demonstrate this ability, we replace the knowledge corpus with a more recent version of the Wikipedia corpus after pre-training is done. When the input query is about a fact on which the two corpora disagree, REALM can change its prediction to reflect the updated information, as exemplified in Table 4. However, even with an explicit retrieval mechanism, the knowledge-augmented encoder still ends up memorizing some world knowledge, so its predictions for some inputs are not updated by the new corpus. (For instance, the model predicts “Thatcher” for “____ is the prime minister of United Kingdom.” on both corpora, perhaps because her name is mentioned frequently in Wikipedia articles.)

Appendix D Retrieval Utility

The null document $\varnothing$ described in Section 3.4 provides a way to measure the importance of a retrieved document $z$: we define the retrieval utility (RU) of $z$ for the masked input $x$ as the difference between the log-likelihood of the knowledge-augmented encoder when conditioning on $z$ versus on $\varnothing$:

$$
\mathrm{RU}(z \mid x) = \log p(y \mid z, x) - \log p(y \mid \varnothing, x).
$$

A negative RU shows that $z$ is less useful for predicting $y$ than the null document. This could mean that $z$ is irrelevant to $x$, but it could also mean that the masked tokens in $x$ do not require world knowledge to predict, or that the required world knowledge is so commonplace that it has been baked into the model’s parameters. In practice, we find that RU increases steadily over the course of pre-training and is more predictive of good performance on the downstream Open-QA task than even the overall log-likelihood. Figure 4 shows an example of how RU behaves over time and across different settings.
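As an illustration, the sketch below computes RU from per-token log-probabilities of the masked tokens, once conditioning on the retrieved document and once on the null document. It assumes, as in masked language modeling, that the masked tokens are predicted independently given the conditioning document, so the log-likelihood of $y$ is a sum over masked positions. The function name `retrieval_utility` and the numbers are hypothetical.

```python
import numpy as np

def retrieval_utility(log_p_with_doc, log_p_with_null):
    """RU(z|x) = log p(y|z,x) - log p(y|null,x).

    Each argument holds log p(y_j | ., x) for every masked token y_j.
    Assumption: masked tokens are predicted independently given the
    conditioning document, so the log-likelihood of y is their sum.
    """
    return float(np.sum(log_p_with_doc) - np.sum(log_p_with_null))

# Two masked tokens: the document doubles the probability of the first token
# and leaves the second unchanged, so RU = log 2 > 0 (the document helps).
with_doc = np.log([0.5, 0.8])
with_null = np.log([0.25, 0.8])
ru = retrieval_utility(with_doc, with_null)
print(round(ru, 4))
```

A token whose probability is the same with and without the document, like the second one here, contributes nothing to RU, which is one way a document can end up with RU near zero even when it is topically relevant.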