End-to-End Training of Multi-Document Reader and Retriever for Open-Domain Question Answering

Devendra Singh Sachan, Siva Reddy, William Hamilton, Chris Dyer, Dani Yogatama

Introduction

Open-domain question answering (OpenQA) is a question answering task where the goal is to train a language model to produce an answer for a given question. In contrast to many question answering tasks, an OpenQA model is only provided with the question as its input without accompanying documents that contain the answer. One of the most promising approaches to OpenQA is based on augmenting the language model with an external knowledge source such as Wikipedia (often referred to as the evidence documents). In this approach, the model consists of two core components (Chen et al., 2017): (i) an information retrieval system to identify useful pieces of text from the knowledge source (the retriever); and (ii) a system to produce the answer given the retrieved documents and the question (the reader).

We can view such a model as a latent variable model, where the latent variables represent retrieved documents that are used to produce answers given questions (Lee et al., 2019). End-to-end (joint) training of this model is challenging since we need to learn both what to retrieve and how to generate an answer given the retrieved documents. Previous work considers two potential solutions (see Table 1 for a high-level summary). The first is stage-wise training, where the retriever is trained while the reader is frozen and vice versa (Karpukhin et al., 2020, Izacard and Grave, 2021b, a). The second is to constrain the reader to condition on each retrieved document individually (Guu et al., 2020), which makes marginalization over the latent variables easier since we only need to consider one document at a time rather than multiple documents at once. This approach is sometimes combined with extra supervision for the latent variables in the form of the relevant document for a question (Lewis et al., 2020b).

In this paper, we consider a retrieval-augmented question answering model that combines information from multiple documents when generating answers. Expectation-maximization (Dempster et al., 1977) offers a principled template for learning this class of latent variable models. We present Emdr2: End-to-end training of Multi-Document Reader and Retriever (§2). Emdr2 iteratively uses feedback from the model itself as “pseudo labels” of the latent variables for optimizing the retriever and reader parameters. We use two estimates of the latent variables: (i) prior scores for updating the reader parameters and (ii) approximate posterior scores given all observed variables for the retriever parameters.

We evaluate our proposed method by experimenting on three commonly used OpenQA datasets: Natural Questions, TriviaQA, and WebQuestions (§3). Emdr2 achieves new state-of-the-art results for models of comparable size on all datasets, outperforming recent approaches by 2-3 absolute exact match points. We also show that Emdr2 is robust to retriever initialization. It achieves high accuracy with unsupervised initialization, suggesting that supervised training of the retriever, which prior work treats as an essential part of the training process (Karpukhin et al., 2020), may not be necessary.

In summary, our contributions are as follows: (i) we present an end-to-end training method (Emdr2) for retrieval-augmented question-answering systems; (ii) we demonstrate that Emdr2 outperforms other existing approaches of comparable size without any kind of supervision on the latent variables; (iii) we provide ablation studies for a better understanding of the contributions of different components of our proposed method; and (iv) we release our code and checkpoints to facilitate future work and reproducibility. Our code is available at: https://github.com/DevSinghSachan/emdr2

Emdr2 is a framework that can be used to train retrieval-augmented text generation models for any task. We believe that our estimation technique in Emdr2 is also useful for learning similar latent variable models in other domains.

Model

Our proposed model Emdr2 consists of two components: (i) a neural retriever and (ii) a neural reader, which we train jointly in an end-to-end setting. Figure 1 shows an illustration of our model and training procedure. We discuss each component and our training objective in detail below.

Let the collection of evidence documents be denoted by $\mathcal{D}=\{\boldsymbol{d}_{1},\ldots,\boldsymbol{d}_{M}\}$. Given a question $\boldsymbol{q}$, the goal of the retriever module is to select a subset of documents $\mathcal{Z}\subset\mathcal{D}$ to answer the question. We model the retriever as a dual-encoder network (Bromley et al., 1994), where one encoder $f_{q}$ encodes the question and another encoder $f_{d}$ encodes the evidence document, each into a vector. The retrieval score is defined as the dot product between the two resulting vectors:

$$\mathrm{score}(\boldsymbol{q},\boldsymbol{d}_{i};\Phi)=f_{q}(\boldsymbol{q};\Phi_{q})^{\top}f_{d}(\boldsymbol{d}_{i};\Phi_{d}), \qquad (1)$$

where $\Phi=[\Phi_{q},\Phi_{d}]$ denotes the retriever parameters. We select the top-$K$ documents for the question $\boldsymbol{q}$ from $\mathcal{D}$ based on the retrieval scores. We denote the set of retrieved documents by $\mathcal{Z}=\{\boldsymbol{z}_{1},\ldots,\boldsymbol{z}_{K}\}$.
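A minimal sketch of Eq. 1 and top-$K$ selection follows, assuming plain Python lists stand in for the outputs of $f_{q}$ and $f_{d}$ (in the model these are 768-dimensional [CLS] vectors). The helper names `retrieval_score` and `top_k` are hypothetical and are not taken from the released code.

```python
import heapq
from typing import List, Sequence, Tuple

def retrieval_score(q_emb: Sequence[float], d_emb: Sequence[float]) -> float:
    """Eq. 1: dot product between the question and document embeddings."""
    return sum(a * b for a, b in zip(q_emb, d_emb))

def top_k(q_emb: Sequence[float], doc_embs: List[Sequence[float]],
          k: int) -> List[Tuple[int, float]]:
    """Return (document index, score) pairs for the k highest-scoring documents."""
    scored = ((i, retrieval_score(q_emb, d)) for i, d in enumerate(doc_embs))
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])

# Toy 3-dimensional embeddings standing in for f_q(q) and f_d(d_i).
question = [1.0, 0.0, 2.0]
documents = [
    [0.5, 1.0, 0.0],  # score 0.5
    [1.0, 0.0, 1.0],  # score 3.0
    [0.0, 3.0, 0.5],  # score 1.0
    [2.0, 0.0, 2.0],  # score 6.0
]
print(top_k(question, documents, k=2))  # [(3, 6.0), (1, 3.0)]
```

`heapq.nlargest` avoids fully sorting all $M$ scores; at the scale of tens of millions of documents the scores themselves would instead be computed as a batched matrix product on accelerators.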

We use transformer encoders (Vaswani et al., 2017) as our $f_{q}$ and $f_{d}$. Our transformer architecture is similar to BERT with 12 layers and 768 hidden size (Devlin et al., 2019). We use the final representation of the first token (i.e., the standard [CLS] token from BERT’s tokenization) as our question (and similarly document) embedding. Initializing $f_{q}$ and $f_{d}$ with BERT weights has been shown to lead to poor retrieval accuracy (Lee et al., 2019, Sachan et al., 2021). Therefore, we initialize the retriever with an unsupervised training procedure. We discuss our initialization technique in detail in §3.2.

2.2 Neural Reader: Fusion-in-Decoder

The reader takes as input a question $\boldsymbol{q}$ and a set of retrieved documents (to be read) $\mathcal{Z}$ to generate an answer. Our reader is based on the Fusion-in-Decoder (FiD; Izacard and Grave, 2021b) model, which is built on top of T5 (Raffel et al., 2020). T5 is a pretrained sequence-to-sequence transformer that consists of an encoder $g_{e}$ and a decoder $g_{d}$.

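In FiD, each retrieved document $\boldsymbol{z}_{k}$ is first appended with its title ($\boldsymbol{t}_{\boldsymbol{z}_{k}}$) and the question:

$$\boldsymbol{x}_{k}=\texttt{[CLS]}\,\boldsymbol{q}\,\texttt{[SEP]}\,\boldsymbol{t}_{\boldsymbol{z}_{k}}\,\texttt{[SEP]}\,\boldsymbol{z}_{k}\,\texttt{[SEP]},$$

where [CLS] is used to indicate the start of a document and [SEP] is used as a separator for the different parts of the document as well as the final token.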

Each $\boldsymbol{x}_{k}$ is then independently given as an input to the T5 encoder $g_{e}$. The output representations corresponding to all of the retrieved documents are concatenated as:

$$\mathbf{X}_{\mathcal{Z}}=\left[g_{e}(\boldsymbol{x}_{1});\ldots;g_{e}(\boldsymbol{x}_{K})\right]\in\mathbb{R}^{NK\times H},$$

where $N$ is the number of tokens in each $\boldsymbol{x}_{k}$ (we truncate and pad as necessary so that every $\boldsymbol{x}_{k}$ has the same length $N$; see §3.2 for details) and $H$ is the hidden size of the T5 encoder $g_{e}$. In this work, we use the T5-base configuration with $N=512$ and $H=768$.
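The input construction and concatenation can be sketched with toy tokens, assuming the input layout `[CLS] question [SEP] title [SEP] document [SEP]` (our reading of the description of [CLS] and [SEP] above). `toy_encoder` is a hypothetical stand-in for $g_{e}$, and the `<pad>` token is a simplification: §3.2 describes filling short inputs with text from neighboring documents instead.

```python
from typing import List

PAD = "<pad>"  # hypothetical padding token used only in this sketch

def build_reader_input(question: List[str], title: List[str],
                       passage: List[str], n: int) -> List[str]:
    """x_k = [CLS] q [SEP] title [SEP] passage [SEP], truncated or padded to exactly n tokens."""
    tokens = ["[CLS]"] + question + ["[SEP]"] + title + ["[SEP]"] + passage + ["[SEP]"]
    tokens = tokens[:n]
    return tokens + [PAD] * (n - len(tokens))

def toy_encoder(tokens: List[str], h: int) -> List[List[float]]:
    """Stand-in for the T5 encoder g_e: one h-dimensional vector per input token."""
    return [[float(len(tok))] * h for tok in tokens]

def fuse(encoded_docs: List[List[List[float]]]) -> List[List[float]]:
    """Concatenate K encoder outputs of shape (N, H) into X_Z of shape (K*N, H)."""
    return [row for doc in encoded_docs for row in doc]

question = ["who", "wrote", "hamlet", "?"]
docs = [(["Hamlet"], ["a", "tragedy", "by", "Shakespeare"]),
        (["Globe", "Theatre"], ["a", "theatre", "in", "London"]),
        (["Stratford"], ["a", "market", "town"])]
n, h = 16, 4
inputs = [build_reader_input(question, title, body, n) for title, body in docs]
x_z = fuse([toy_encoder(x, h) for x in inputs])  # each document is encoded independently
print(len(x_z), len(x_z[0]))  # 48 4
```

With $K=50$ and $N=512$, $\mathbf{X}_{\mathcal{Z}}$ has $50\times512=25{,}600$ rows: each encoder pass still sees only $N$ tokens, while the decoder's cross attention spans all 25,600 positions.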

$\mathbf{X}_{\mathcal{Z}}$ is then given as an input to the T5 decoder $g_{d}$. When generating an answer token, the decoder attends both to previously generated tokens (i.e., causal attention) and to the tokens encoded in $\mathbf{X}_{\mathcal{Z}}$ (i.e., cross attention). Since $\mathbf{X}_{\mathcal{Z}}$ contains information from multiple documents, the decoder can aggregate useful signals contained in multiple documents and jointly reason over them. We define the probability of the answer as:

$$p(\boldsymbol{a}\mid\boldsymbol{q},\mathcal{Z};\Theta)=\prod_{t=1}^{T}p(a_{t}\mid\boldsymbol{a}_{<t},\boldsymbol{q},\mathcal{Z};\Theta), \qquad (2)$$

where $\Theta$ denotes the reader parameters (i.e., the T5 encoder and decoder) and $T$ is the number of answer tokens. We keep generating answer tokens until the decoder outputs a special EOS token or a pre-specified maximum answer length is reached.
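Eq. 2 factorizes the answer probability over tokens, so its logarithm is a sum of per-token log-probabilities, and greedy decoding (used at inference time) picks the most probable token at each step until EOS or the length limit. The sketch below shows both with a hypothetical lookup table `toy_model` standing in for the T5 decoder conditioned on $\mathbf{X}_{\mathcal{Z}}$; it assumes that the scored answer includes the EOS token.

```python
import math
from typing import Callable, Dict, List, Sequence

EOS = "</s>"
# Maps an answer prefix to a next-token distribution. In the model this is the
# T5 decoder attending to X_Z; here it is a hypothetical lookup table.
NextTokenModel = Callable[[Sequence[str]], Dict[str, float]]

def answer_log_prob(model: NextTokenModel, answer: List[str]) -> float:
    """log of Eq. 2: sum over t of log p(a_t | a_<t, q, Z)."""
    return sum(math.log(model(answer[:t])[token]) for t, token in enumerate(answer))

def greedy_decode(model: NextTokenModel, max_len: int) -> List[str]:
    """Emit the most probable token at each step until EOS or max_len tokens."""
    output: List[str] = []
    while len(output) < max_len:
        dist = model(output)
        token = max(dist, key=lambda tok: dist[tok])
        if token == EOS:
            break
        output.append(token)
    return output

def toy_model(prefix: Sequence[str]) -> Dict[str, float]:
    table = {
        (): {"william": 0.7, "shakespeare": 0.2, EOS: 0.1},
        ("william",): {"shakespeare": 0.9, EOS: 0.1},
        ("william", "shakespeare"): {EOS: 0.95, "jr": 0.05},
    }
    return table.get(tuple(prefix), {EOS: 1.0})

print(greedy_decode(toy_model, max_len=10))  # ['william', 'shakespeare']
print(answer_log_prob(toy_model, ["william", "shakespeare", EOS]))
```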

2.3 End-to-End Training of Reader and Retriever

In contrast to previous work on generative question answering, we train both the reader and the retriever jointly in an end-to-end differentiable fashion.

Denote by $Z$ our latent variable, which represents a set of retrieved documents, and let $\mathcal{Z}$ be a possible value of $Z$. The marginal likelihood of an answer (marginalizing over all possible values of $Z$) is:

$$p(\boldsymbol{a}\mid\boldsymbol{q};\Theta,\Phi)=\sum_{\mathcal{Z}}p(\boldsymbol{a}\mid\boldsymbol{q},\mathcal{Z};\Theta)\,p(\mathcal{Z}\mid\boldsymbol{q};\Phi).$$

The goal of our training procedure is to find $\Phi$ and $\Theta$ that maximize this objective. Exactly optimizing it is intractable because the sum ranges over every possible set of $K$ documents drawn from $\mathcal{D}$, which is combinatorial in nature. (Contrast our objective with REALM (Guu et al., 2020), where the reader conditions on only one retrieved document $\boldsymbol{z}_{k}$ when generating an answer; in that case, the latent variable represents a single document assignment instead of a set of retrieved documents.) For one particular value $\mathcal{Z}$, however, the log-likelihood is simple to compute: $\log\left[p(\boldsymbol{a}\mid\boldsymbol{q},\mathcal{Z};\Theta)\,p(\mathcal{Z}\mid\boldsymbol{q};\Phi)\right]=\log p(\boldsymbol{a}\mid\boldsymbol{q},\mathcal{Z};\Theta)+\log p(\mathcal{Z}\mid\boldsymbol{q};\Phi)$.
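To make the combinatorial claim concrete, the following arithmetic counts the possible values of $Z$ when $K=50$ documents are drawn from the $M=21{,}015{,}324$ evidence documents used in the experiments (both numbers come from the experimental setup; the snippet itself is only an illustration).

```python
import math

M = 21_015_324  # number of evidence documents (100-word Wikipedia segments)
K = 50          # number of documents read per question

num_subsets = math.comb(M, K)  # number of distinct K-document sets
print(len(str(num_subsets)))   # 302 decimal digits
```

The count is on the order of $10^{301}$ sets per question, so summing over them exactly is out of the question.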

The expectation-maximization (EM) algorithm (Dempster et al., 1977) offers a solution to learning this latent variable model. In classical EM, we iteratively compute the posterior of $Z$ given all observed variables and use it to update $\Theta$ and $\Phi$.

We propose using two estimates of $Z$, namely $\mathcal{Z}_{\text{reader}}$ and $\mathcal{Z}_{\text{retriever}}$, for updating the two components of the model (reader parameters $\Theta$ and retriever parameters $\Phi$):

$$\max_{\Theta,\Phi}\;\log p(\boldsymbol{a}\mid\boldsymbol{q},\mathcal{Z}_{\text{reader}};\Theta)+\log p(\mathcal{Z}_{\text{retriever}}\mid\boldsymbol{q},\boldsymbol{a};\Theta,\Phi). \qquad (3)$$

In the first term, we set the value of the latent variable $Z=\mathcal{Z}_{\text{reader}}$ based on the prior scores. In the second term, we seek to maximize an approximate posterior of $Z=\mathcal{Z}_{\text{retriever}}$. We discuss them in more detail below.

Reader parameters $\Theta$.

For updating $\Theta$ (the first term of Eq. 3), we use the top-$K$ documents with the highest individual scores (as computed by Eq. 1 based on the current value of $\Phi$) to construct $\mathcal{Z}_{\text{reader}}$. This is equivalent to relying on the prior $p(Z\mid\boldsymbol{q};\Phi)$ to estimate $\mathcal{Z}_{\text{reader}}$ (without using information from the answer $\boldsymbol{a}$). We use the prior to train the reader parameters because the prior scores are also used at evaluation time to obtain the top-$K$ documents. As a result, there is no mismatch between training and test computations when computing $p(\boldsymbol{a}\mid\boldsymbol{q},\mathcal{Z};\Theta)$ (i.e., the $\mathcal{Z}$ used at test time is obtained in exactly the same way as $\mathcal{Z}_{\text{reader}}=\mathcal{Z}_{\text{top-}K}$).

Retriever parameters $\Phi$.

For updating $\Phi$ (the second term of Eq. 3), we propose to use the posterior estimate. In other words, we use additional information from $\boldsymbol{a}$ when evaluating $\mathcal{Z}_{\text{retriever}}$ to train $\Phi$. Using the posterior allows our retriever to learn from richer training signals than relying on the prior alone.

We need to be able to compute $p(\mathcal{Z}_{\text{retriever}}\mid\boldsymbol{q},\boldsymbol{a};\Theta,\Phi)$ to optimize the retriever parameters. However, computing this quantity is difficult since it is the probability of a set (this is true whether we choose to use the posterior probability or the prior probability). Consider a set of $K$ documents (e.g., $\mathcal{Z}_{\text{top-}K}$), where $\boldsymbol{z}_{k}$ denotes a document in the set. We approximate the maximization of the probability of the set by assuming that it is maximized when the sum of the probabilities of the individual documents in the set is maximized. The intuition is that each element of the set contributes independently, which greatly simplifies finding the maximum over sets. With this approximation, we arrive at a simpler quantity: $\sum_{k=1}^{K}p(\boldsymbol{z}_{k}\mid\boldsymbol{q},\boldsymbol{a};\Theta,\Phi)$. Using Bayes' rule, we can rewrite it as:

$$\sum_{k=1}^{K}p(\boldsymbol{z}_{k}\mid\boldsymbol{q},\boldsymbol{a};\Theta,\Phi)\propto\sum_{k=1}^{K}p(\boldsymbol{a}\mid\boldsymbol{q},\boldsymbol{z}_{k};\Theta)\,p(\boldsymbol{z}_{k}\mid\boldsymbol{q};\Phi). \qquad (4)$$

We choose not to normalize with $p(\boldsymbol{a}\mid\boldsymbol{q};\Theta,\Phi)$ since computing this quantity would require summing over all $M$ evidence documents. While this means that the objective we optimize no longer corresponds to a proper probability distribution, we observe that our training method still behaves well in practice.
The reader now conditions on only one document when computing the probability of an answer $p(\boldsymbol{a}\mid\boldsymbol{q},\boldsymbol{z}_{k};\Theta)$. This single-document reader uses the same parameters $\Theta$ as the multi-document reader, but it reads only one document $\boldsymbol{z}_{k}$ instead of a set of documents.

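To compute Eq. 4, we first obtain the $K$ documents with the highest scores as computed by Eq. 1 based on the current value of $\Phi$. We compute the probability of document $\boldsymbol{z}_{k}\in\mathcal{Z}_{\text{top-}K}$ as:

$$p(\boldsymbol{z}_{k}\mid\boldsymbol{q};\Phi)=\frac{\exp\left(\mathrm{score}(\boldsymbol{q},\boldsymbol{z}_{k};\Phi)/\tau\right)}{\sum_{j=1}^{M}\exp\left(\mathrm{score}(\boldsymbol{q},\boldsymbol{d}_{j};\Phi)/\tau\right)}\approx\frac{\exp\left(\mathrm{score}(\boldsymbol{q},\boldsymbol{z}_{k};\Phi)/\tau\right)}{\sum_{j=1}^{K}\exp\left(\mathrm{score}(\boldsymbol{q},\boldsymbol{z}_{j};\Phi)/\tau\right)}, \qquad (5)$$

where $\tau$ is a temperature hyperparameter, and the approximation assumes that documents beyond the top-$K$ contribute very small scores, so we do not need to sum over all $M$ evidence documents in the denominator (which number in the tens of millions in our experiments). We then compute $p(\boldsymbol{a}\mid\boldsymbol{q},\boldsymbol{z}_{k};\Theta)$ similarly to Eq. 2.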
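Eq. 5 and the unnormalized sum in Eq. 4 can be evaluated stably in log space. The sketch below assumes the top-$K$ scores and the single-document reader log-likelihoods $\log p(\boldsymbol{a}\mid\boldsymbol{q},\boldsymbol{z}_{k};\Theta)$ are already available as Python floats; the function names and the value $\tau=1$ in the example are hypothetical choices for illustration, not the paper's settings.

```python
import math
from typing import List

def logsumexp(xs: List[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_prior_top_k(scores: List[float], tau: float) -> List[float]:
    """Right-hand side of Eq. 5 in log space: softmax of score/tau over the top-K only."""
    scaled = [s / tau for s in scores]
    log_z = logsumexp(scaled)
    return [x - log_z for x in scaled]

def log_retriever_term(reader_log_probs: List[float], scores: List[float], tau: float) -> float:
    """log of the right-hand side of Eq. 4: log sum_k p(a | q, z_k) p(z_k | q)."""
    log_prior = log_prior_top_k(scores, tau)
    return logsumexp([r + p for r, p in zip(reader_log_probs, log_prior)])

# Three retrieved documents; the second one is the most useful to the reader.
scores = [12.0, 10.0, 9.5]
reader_log_probs = [math.log(0.01), math.log(0.7), math.log(0.2)]
prior = [math.exp(p) for p in log_prior_top_k(scores, tau=1.0)]
print([round(p, 3) for p in prior])
print(log_retriever_term(reader_log_probs, scores, tau=1.0))
```

Raising $\tau$ flattens the prior over the top-$K$ documents, and lowering it sharpens the prior toward the highest-scoring one.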

Overall training objective of Emdr2.

Combining the above derivations, the end-to-end training objective that we seek to maximize for a particular example becomes:

$$\mathcal{L}=\log p(\boldsymbol{a}\mid\boldsymbol{q},\mathcal{Z}_{\text{top-}K};\Theta)+\log\sum_{k=1}^{K}p(\boldsymbol{a}\mid\boldsymbol{q},\boldsymbol{z}_{k};\Theta)\,p(\boldsymbol{z}_{k}\mid\boldsymbol{q};\Phi). \qquad (6)$$

Given a training example, we update $\Theta$ and $\Phi$ by taking gradients of Eq. 6 with respect to $\Theta$ and $\Phi$ in an end-to-end fashion. Intuitively, we train the reader to generate the correct answer given the $K$ highest scoring documents $\mathcal{Z}_{\text{top-}K}$. We train the retriever to select $K$ documents that collectively have a high score of generating an answer (since the sum over $K$ is inside the log in the second term), while taking into account feedback from the reader. Algorithm 1 summarizes our training algorithm.
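The sketch below evaluates Eq. 6 for one example and checks a property that helps explain how the reader's feedback reaches the retriever. If the reader probabilities $r_{k}=p(\boldsymbol{a}\mid\boldsymbol{q},\boldsymbol{z}_{k};\Theta)$ are held fixed and $p_{k}$ is the top-$K$ softmax of Eq. 5, the derivative of the objective with respect to a retrieval score $s_{k}$ is $(w_{k}-p_{k})/\tau$, where $w_{k}=r_{k}p_{k}/\sum_{j}r_{j}p_{j}$ is the normalized approximate posterior. This derivation and the function names are our own illustration rather than something stated in the text; a real implementation would obtain gradients from automatic differentiation in PyTorch.

```python
import math
from typing import List

def logsumexp(xs: List[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_softmax(xs: List[float]) -> List[float]:
    log_z = logsumexp(xs)
    return [x - log_z for x in xs]

def emdr2_objective(log_p_answer_top_k: float, reader_log_probs: List[float],
                    scores: List[float], tau: float) -> float:
    """Eq. 6: log p(a | q, Z_topK) + log sum_k p(a | q, z_k) p(z_k | q)."""
    log_prior = log_softmax([s / tau for s in scores])
    return log_p_answer_top_k + logsumexp([r + p for r, p in zip(reader_log_probs, log_prior)])

def score_gradient(reader_log_probs: List[float], scores: List[float], tau: float) -> List[float]:
    """d(Eq. 6)/d score_k with reader probabilities fixed: (posterior_k - prior_k) / tau."""
    prior = [math.exp(p) for p in log_softmax([s / tau for s in scores])]
    # w_k is proportional to r_k * p_k, i.e. a softmax of log r_k + score_k / tau.
    posterior = [math.exp(p) for p in
                 log_softmax([r + s / tau for r, s in zip(reader_log_probs, scores)])]
    return [(w - p) / tau for w, p in zip(posterior, prior)]

scores = [3.0, 2.5, 1.0]
reader_log_probs = [math.log(0.05), math.log(0.6), math.log(0.2)]
tau = 1.0
grad = score_gradient(reader_log_probs, scores, tau)

# Central finite differences agree with the analytic gradient.
eps = 1e-5
for k in range(len(scores)):
    up, down = scores[:], scores[:]
    up[k] += eps
    down[k] -= eps
    numeric = (emdr2_objective(0.0, reader_log_probs, up, tau)
               - emdr2_objective(0.0, reader_log_probs, down, tau)) / (2 * eps)
    assert abs(numeric - grad[k]) < 1e-6
print([round(g, 3) for g in grad])
```

In this example the second document, which the reader finds most useful, receives a positive gradient even though its prior score is below that of the first document, whose score is pushed down.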

Experiments

We experiment with three commonly used open-domain question answering datasets:

Natural Questions (NQ; Kwiatkowski et al., 2019). NQ contains questions asked by users of the Google search engine. Similar to Lee et al. (2019), we use the short answer subset.

TriviaQA (Joshi et al., 2017). TriviaQA is a collection of trivia question-answer pairs that were collected from multiple sources on the web.

WebQuestions (WebQ; Berant et al., 2013). WebQ questions were collected using the Google Suggest API, and the answers were annotated using Mechanical Turk. We use the version from Chen et al. (2017) where Freebase IDs in the answers are replaced by entity names.

We use the preprocessed English Wikipedia dump from December 2018 released by Karpukhin et al. (2020) as our evidence documents. Each Wikipedia article is split into non-overlapping 100-word segments, and each segment corresponds to one document in our setting. There are 21,015,324 documents in total.
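Splitting an article into non-overlapping 100-word documents can be sketched as below, assuming simple whitespace word splitting; the released dump of Karpukhin et al. (2020) is already segmented, and details such as how titles are stored are not reproduced here. The function name is hypothetical.

```python
from typing import List

def split_into_passages(article_text: str, words_per_passage: int = 100) -> List[str]:
    """Split an article into consecutive, non-overlapping passages of at most 100 words."""
    words = article_text.split()
    return [" ".join(words[i:i + words_per_passage])
            for i in range(0, len(words), words_per_passage)]

article = " ".join(f"w{i}" for i in range(250))
passages = split_into_passages(article)
print([len(p.split()) for p in passages])  # [100, 100, 50]
```

The last passage of an article is simply shorter; in this setup, short inputs are later filled with neighboring text at reading time rather than at segmentation time.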

We provide descriptive statistics and other preprocessing details in Appendix A.

3.2 Implementation Details

We run all of our experiments on a machine with 96 CPUs, 1.3TB physical memory, and 16 A100 GPUs. We use PyTorch (Paszke et al., 2019) to implement our proposed model and relevant baselines.

Model configurations.

For both the retriever and reader, we use the base configuration that consists of 12 layers, a 768-dimensional hidden size, and 12 attention heads. In all experiments, we retrieve 50 documents, unless stated otherwise. We only use the base configuration in our experiments due to GPU memory constraints. However, we believe that our results would generalize to larger configurations as well.

Retrieval.

To support fast retrieval, we pre-compute evidence document embeddings and store them in a distributed fashion over all the GPUs. We refer to these document embeddings as the document index. For each question, we retrieve documents in an online (on-the-fly) manner by performing exact maximum inner product search (MIPS), implemented using asynchronous distributed matrix multiplication over the document index. These documents are converted to subwords using BERT’s tokenization and are given as input to the T5 reader. If a tokenized document is shorter than 512 tokens, it is padded using the tokens from the neighboring documents until the maximum token limit is reached. Such padding additionally helps to provide an extended context for answer generation.
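A single-process sketch of exact sharded MIPS follows: the index is split into shards (standing in for the per-GPU partitions), each shard returns its local top-$K$, and a final merge keeps the global top-$K$. The sequential loop, the merge step, and all names are assumptions for illustration; the text only states that the search is exact and implemented with asynchronous distributed matrix multiplication.

```python
import heapq
import random
from typing import List, Sequence, Tuple

Vector = Sequence[float]

def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))

def shard_top_k(q: Vector, shard: List[Vector], offset: int, k: int) -> List[Tuple[float, int]]:
    """Exact top-k on one shard, returned as (score, global document id) pairs."""
    return heapq.nlargest(k, ((dot(q, d), offset + i) for i, d in enumerate(shard)))

def sharded_exact_mips(q: Vector, shards: List[List[Vector]], k: int) -> List[Tuple[float, int]]:
    """Merge per-shard top-k lists; the global top-k is always among them, so search stays exact."""
    candidates: List[Tuple[float, int]] = []
    offset = 0
    for shard in shards:  # each shard would live on its own GPU in a distributed setup
        candidates.extend(shard_top_k(q, shard, offset, k))
        offset += len(shard)
    return heapq.nlargest(k, candidates)

rng = random.Random(0)
docs = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(1000)]
shards = [docs[i:i + 250] for i in range(0, len(docs), 250)]  # 4 simulated devices
query = [rng.gauss(0, 1) for _ in range(8)]
result = sharded_exact_mips(query, shards, k=5)
brute_force = heapq.nlargest(5, ((dot(query, d), i) for i, d in enumerate(docs)))
print(result == brute_force)  # True
```

Because every global top-$K$ document is also in its own shard's top-$K$, each device only needs to send $K$ candidates to the merge step.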

Initialization and training details.

We initialize the parameters of the model with unsupervised pre-training before performing supervised training using the question-answer training examples. Unsupervised pre-training is essential as it helps to warm-start the retriever so that it outputs relevant documents for a given question.

We first pre-train the retriever parameters with unsupervised Inverse Cloze Task training (Lee et al., 2019) for 100,000 steps. We then extract sentences containing named entities from the evidence documents. Next, we replace 15% of the named entity tokens with masked tokens, which are often referred to as masked salient spans (MSS; Guu et al., 2020). The masked sentence can be considered as the question and its salient spans (i.e., named entities) as the answer to train the model with Eq. 6. We train the model on these question-answer (masked sentence, named entities) pairs for 82,000 steps with a batch size of 64 using Adam (Kingma and Ba, 2015). We refer to this initialization method as unsupervised pre-training with masked salient spans. We provide further description in Appendix C.
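The conversion of a sentence into a (masked question, answer) pair can be sketched as follows. The entity spans are assumed to come from an external named-entity tagger, and which spans get masked (the 15% rule above) is left to the caller; `make_mss_example` and its signature are hypothetical, and Appendix C describes the actual procedure.

```python
from typing import List, Tuple

MASK = "[MASK]"

def make_mss_example(tokens: List[str],
                     spans_to_mask: List[Tuple[int, int]]) -> Tuple[str, List[str]]:
    """Return (masked sentence used as the question, masked entity strings used as answers).

    spans_to_mask holds [start, end) token offsets of named-entity spans selected for masking.
    """
    masked = list(tokens)  # copy so the input sentence is left unchanged
    answers = []
    for start, end in sorted(spans_to_mask):
        answers.append(" ".join(tokens[start:end]))
        for i in range(start, end):
            masked[i] = MASK  # one mask token per masked entity token
    return " ".join(masked), answers

tokens = "Hamlet was written by William Shakespeare around 1600 .".split()
question, answers = make_mss_example(tokens, [(4, 6)])
print(question)  # Hamlet was written by [MASK] [MASK] around 1600 .
print(answers)   # ['William Shakespeare']
```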

After MSS training, we finetune the model on the dataset-specific question-answer training examples with Emdr2. We perform training for 10 epochs on NQ and TriviaQA with a batch size of 64, and for 20 epochs on WebQ with a batch size of 16. During training, we save a checkpoint every 500 steps and select the best checkpoint based on its performance on the development set.

During end-to-end training, the parameters of the document encoder ($f_{d}$) are updated at every step, so the pre-computed document embeddings become stale as training progresses. To limit this staleness, we asynchronously recompute the document embeddings with the most recent document encoder checkpoint and update the document index with them every 500 training steps.
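The bound on staleness implied by this schedule can be illustrated with a small model. It assumes, hypothetically, that a rebuild started at every multiple of 500 steps takes `rebuild_steps` training steps to finish before the new index is swapped in (the text does not report this latency), and that step 0 denotes the index built before end-to-end training. The function names are hypothetical.

```python
def active_snapshot(step: int, refresh_every: int = 500, rebuild_steps: int = 0) -> int:
    """Training step whose document-encoder weights the index reflects at `step`.

    A rebuild starts at every multiple of refresh_every from the current document
    encoder and is swapped in rebuild_steps later; training keeps running meanwhile.
    """
    if rebuild_steps >= refresh_every:
        raise ValueError("this model assumes a rebuild finishes before the next one starts")
    if step < rebuild_steps:
        return 0  # still serving the index built before end-to-end training
    return ((step - rebuild_steps) // refresh_every) * refresh_every

def max_staleness(total_steps: int, refresh_every: int = 500, rebuild_steps: int = 0) -> int:
    """Largest gap, in steps, between the current encoder and the encoder behind the index."""
    return max(step - active_snapshot(step, refresh_every, rebuild_steps)
               for step in range(total_steps))

print(max_staleness(10_000))                     # 499
print(max_staleness(10_000, rebuild_steps=100))  # 599
```

Under these assumptions the index lags the encoder by at most `refresh_every + rebuild_steps - 1` steps, so the periodic refresh bounds staleness rather than eliminating it.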

Inference.

We use greedy decoding for answer generation at inference time.

3.3 Baselines

We compare our model to other approaches for OpenQA that can be categorized under the following two classes:

Closed-book QA models. Large-scale language models capture a lot of world knowledge in their parameters, derived from the corpus they have been trained on (Petroni et al., 2019). We compare with the work of Roberts et al. (2020), who show that larger T5 models, when finetuned with question-answer pairs, can perform remarkably well. We also compare with the few-shot results of GPT-3 (Brown et al., 2020). Note that GPT-3 is not trained on the full training sets that we use, so its results are not directly comparable.

Open-book QA models. Similar to this work, these models consist of retriever and reader components and adopt the retrieve-then-predict approach for answering questions given a collection of evidence documents. These models mainly differ in how the retriever is initialized (ORQA; Lee et al., 2019, DPR; Karpukhin et al., 2020), whether the reader processes a single document (ORQA, DPR, RAG; Lewis et al., 2020b) or multiple documents (FiD; Izacard and Grave, 2021b), or whether the reader and retriever are trained jointly or in a multistage process (REALM; Guu et al., 2020, FiD-KD; Izacard and Grave, 2021a).

3.4 Results

We follow standard conventions and report exact match (EM) scores using the reference answers included in each dataset. Table 2 shows our main results. We divide the table into three main sections: closed-book QA models, open-book QA models, and our implementation. The first two sections contain results from other papers, which we include for comparisons. The last section includes results from our proposed model, as well as our reimplementation of relevant baselines to control for our experimental setup.
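Exact match is computed after normalizing both the prediction and the reference answers. We assume the widely used SQuAD-style normalization below (lowercasing, removing punctuation and English articles, collapsing whitespace); the text does not spell out the normalization, so treat this sketch and its function names as an illustration rather than the exact evaluation code.

```python
import re
import string
from typing import Iterable, List

_PUNCTUATION = set(string.punctuation)

def normalize_answer(text: str) -> str:
    """Lowercase, strip punctuation, drop English articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction: str, references: Iterable[str]) -> bool:
    """True if the normalized prediction equals any normalized reference answer."""
    pred = normalize_answer(prediction)
    return any(pred == normalize_answer(ref) for ref in references)

def em_score(predictions: List[str], references: List[List[str]]) -> float:
    """Exact match, in percent, over a dataset."""
    hits = sum(exact_match(p, refs) for p, refs in zip(predictions, references))
    return 100.0 * hits / len(predictions)

print(em_score(["The Eiffel Tower!", "Paris, France"], [["Eiffel Tower"], ["Paris"]]))  # 50.0
```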

Our reimplementation of T5-base provides strong baselines when the number of retrieved documents is set to 0 (no retrieval) and 1. From Table 2, we see that the top-1 setting vastly improves performance over the setting with no retrieved documents, signifying the importance of retrieval for OpenQA tasks. When the number of retrieved documents is further increased to 50, the performance of the FiD models substantially improves over top-1 retrieval, verifying the observation of Izacard and Grave (2021b) about the importance of modeling the retrieved documents as a set.

Comparing Emdr2 with our reimplementation of FiD illustrates the benefit of our end-to-end training approach. The underlying model is similar in both cases, but the training method is different. FiD adopts a two-stage approach to first train the retriever and then the reader. We have three variants of FiD: (i) the reader and retriever are initialized with MSS training, (ii) the retriever is initialized with DPR training, which is the setting used in the original paper (Izacard and Grave, 2021b), and (iii) the retriever is initialized with MSS + DPR training from (Sachan et al., 2021), as it further improves DPR recall. Emdr2 outperforms all the variants by large margins on all the datasets.

The current best approach for training multi-document reader and retriever is FiD-KD (Izacard and Grave, 2021a). FiD-KD is a complex training procedure that requires multiple training stages and performs knowledge distillation with inter-attention scores. We take the results from the original paper when comparing our model with FiD-KD. Emdr2 outperforms the reported numbers of FiD-KD by more than 2.5 points on NQ and TriviaQA to obtain new state-of-the-art results on these benchmarks.

In addition to better performance, Emdr2 also has three other advantages compared to FiD-KD: (i) Emdr2 is more efficient since it only uses 50 evidence documents, whereas FiD-KD leverages 100 documents; (ii) FiD-KD is based on a distillation approach which requires multiple cycles of retriever and reader training, while Emdr2 only requires one cycle of end-to-end training; and (iii) FiD-KD relies on supervised initialization of the retriever to achieve its best performance. Emdr2 is more robust to the retriever initialization, as demonstrated by state-of-the-art results even with unsupervised initialization of the retriever.

For the WebQ dataset, the training set size is much smaller compared to the other datasets (Table 5). Previous approaches such as RAG rely on supervised transfer (i.e., they finetune a model pre-trained on NQ) to obtain good results. In contrast, Emdr2 improves over the results from this RAG model by 3.5 points without the supervised transfer step. This result demonstrates the applicability of our approach to the low-resource setting where we only have a limited number of training examples.

We also perform qualitative analysis of the model outputs, which is included in Appendix E.

3.5 Ablations

We investigate the performance of Emdr2 and FiD as we vary the number of retrieved documents $K$ in Figure 2. We observe that when the number of retrieved documents is increased, both Emdr2 and FiD improve in performance. When $K$ is small, the gap between Emdr2 and FiD is larger. This indicates the efficacy of Emdr2 in a more constrained setting where we can only retrieve a small number of documents (e.g., due to memory limitations).

Retriever initialization. We explore the effect of different parameter initialization strategies when training with Emdr2: (i) unsupervised MSS pre-training, (ii) supervised retriever training (DPR), and (iii) MSS pre-training followed by supervised retriever training (MSS + DPR; Sachan et al. (2021)). Table 3 shows our results. On NQ, MSS pre-training, being unsupervised, leads to a lower initial retriever recall than DPR. After Emdr2 training, the recall improves by 20% (highlighted in yellow cells). Training with DPR initialization leads to the same final recall as MSS pre-training, suggesting that DPR initialization of the retriever may not be an essential component for obtaining good performance in OpenQA tasks. Similar trends are observed on TriviaQA and WebQ. Likewise, MSS + DPR initialization has a better initial recall but yields marginal or no improvement in answer extraction performance over MSS pre-training. Finally, we observe that MSS pre-training also improves answer extraction on WebQ by 2 points compared to the T5 reader (shown in orange cells), highlighting its importance for low-resource OpenQA tasks.
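Table 3 reports retriever recall. A common definition in OpenQA, which we assume here, is top-$K$ answer recall: the percentage of questions for which at least one of the top-$K$ retrieved documents contains a reference answer string. The substring test below is a simplification (token-level matching after normalization avoids false hits such as "art" inside "party"), and the function names are hypothetical.

```python
from typing import List, Sequence

def has_answer(passage: str, answers: Sequence[str]) -> bool:
    """Simplified match: case-insensitive substring test for any reference answer."""
    text = passage.lower()
    return any(answer.lower() in text for answer in answers)

def top_k_recall(retrieved: List[List[str]], answers: List[Sequence[str]], k: int) -> float:
    """Percentage of questions whose top-k passages contain at least one reference answer."""
    hits = sum(any(has_answer(p, ans) for p in passages[:k])
               for passages, ans in zip(retrieved, answers))
    return 100.0 * hits / len(retrieved)

retrieved = [
    ["Paris is the capital of France.", "Lyon is a city in France."],
    ["The Nile flows north.", "Mount Everest is in the Himalayas."],
]
answers = [["Paris"], ["Himalayas"]]
print(top_k_recall(retrieved, answers, k=1))  # 50.0
print(top_k_recall(retrieved, answers, k=2))  # 100.0
```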

3.6 Alternative End-to-End Training Objectives

We compare Emdr2 objective (Eq. 6) to two alternative formulations for end-to-end training.

In the first alternative formulation, when training the retriever parameters $\Phi$, we simply factorize $p(\mathcal{Z}\mid\boldsymbol{q};\Phi)=\prod_{k=1}^{K}p(\boldsymbol{z}_{k}\mid\boldsymbol{q};\Phi)$ to arrive at the following objective:

The second term in this objective is maximized by a uniform retrieval, in other words, by removing any discrimination between documents in the retriever. We include it to show the impact of an adversarial objective.

Intuitively, we try to match the probability of retrieving a document $\boldsymbol{z}_{k}$ with the “contribution” of that document to the generated answer $\boldsymbol{a}$, regardless of whether the retriever is relatively more or less likely to retrieve the document a priori.

Table 4 shows our results on the development set of NQ. We observe that training with the adversarial $\mathcal{L}_{\text{alt-1}}$ objective diverges, leading to poor performance, as expected. This shows that harming the retriever during training can significantly degrade the performance of the QA system. In contrast, although it disregards the estimated prior, the $\mathcal{L}_{\text{alt-2}}$ objective still improves over the FiD baseline for NQ and TriviaQA. However, it still lags behind Emdr2. On WebQ, the $\mathcal{L}_{\text{alt-2}}$ objective diverges and leads to poor performance. We leave further analysis of the convergence of the $\mathcal{L}_{\text{alt-2}}$ objective to future work.

Related Work

Our work is based on end-to-end training of neural readers and retrievers, which we discuss in §1, §2, and §3. Here we instead focus on discussing previous work related to standalone neural retrievers, neural readers, and their application in other natural language processing tasks.

Neural retrievers. There are two broad classes of neural retrievers based on the number of embeddings computed for a document: dual encoders (Yih et al., 2011, Lee et al., 2019) and multivector encoders (Khattab and Zaharia, 2020, Luan et al., 2021). Dual encoders store one embedding for each evidence document. Multivector encoders require multiple embeddings, which can be computationally expensive for large-scale retrieval. Due to the large size of the evidence document collection in OpenQA, our work uses the more efficient dual encoder. Sachan et al. (2021) show that the performance of supervised dual encoders in OpenQA can be improved when pre-training with the Inverse Cloze Task for the high-resource setting or masked salient spans for the low-resource setting.

Neural readers. Neural readers output an answer given retrieved documents as their input. There are also two broad classes of neural readers: extractive and generative. Extractive readers (Clark and Gardner, 2018, de Masson d’Autume et al., 2019, Wang et al., 2019, Guu et al., 2020, Karpukhin et al., 2020) extract a span from a retrieved document to produce an answer. Generative readers (Izacard and Grave, 2021b), on the other hand, generate an answer conditioned on the retrieved documents.

Other application areas. In addition to question answering, retrieval-augmented methods have been successfully applied to other natural language processing tasks. In left-to-right language modeling, retrieving similar words from an external memory has been shown to improve perplexity (Khandelwal et al., 2020, Yogatama et al., 2021). In machine translation, retrieving domain-specific target language tokens has improved performance in domain adaptation (Khandelwal et al., 2021). Finally, in dialog modeling, retrieving knowledge-informed text has helped improve factual correctness in the generated conversations (Fan et al., 2021).

We provide a detailed comparison of Emdr2 with some of the previous work in Appendices C and D.

Discussion

We presented Emdr2, an end-to-end training method for retrieval-augmented question answering systems. We showed how to arrive at our training objective using the expectation-maximization algorithm. We demonstrated that Emdr2 achieves state-of-the-art performance on three benchmark OpenQA datasets.

Technical limitations.

Emdr2 shares a few limitations with other retrieval-augmented question answering models. In particular, as evidence documents are stored in an uncompressed format, maintaining them and searching for relevant documents can be expensive in terms of both compute and memory consumption. In our experiments, we only focused on open-domain question answering. It would be interesting to see how Emdr2 performs on other text generation tasks as well. We also note that training is relatively resource-heavy (requiring 16 GPUs), which raises potential environmental concerns.

Potential negative societal impacts.

While Emdr2 has the potential to improve language models in the low-resource setting (as demonstrated by our results on WebQ in §3.4), it could exhibit typical biases that are associated with large language models. For example, our model does not have an explicit mechanism to generate answers that are calibrated for fairness across all spectra. As a retrieval-augmented method, it could also be more prone to generating fake answers if an attacker manages to gain access to and modify information in the collection of evidence documents.

Acknowledgements

The authors would like to thank the DeepMind Language team, Mila’s students, and anonymous reviewers for providing us valuable feedback and useful suggestions about this work that helped us improve the paper.

Funding Statement

DSS was supported by the Canada CIFAR AI Chair held by Prof. William Hamilton.

References

Checklist

Do the main claims made in the abstract and introduction accurately reflect the paper’s contributions and scope? [Yes] Please see the model (§2) and result (§3) sections that solidify the claims made in the abstract and introduction sections.

Did you describe the limitations of your work? [Yes] Please see limitations in §5.

Did you discuss any potential negative societal impacts of your work? [Yes] Please see negative societal impact in §5.

Have you read the ethics review guidelines and ensured that your paper conforms to them? [Yes]

If you are including theoretical results…

Did you state the full set of assumptions of all theoretical results? [N/A]

Did you include complete proofs of all theoretical results? [N/A]

Did you include the code, data, and instructions needed to reproduce the main experimental results (either in the supplemental material or as a URL)? [Yes] We include the code, data, and instructions in the supplemental material and §3.2.

Did you specify all the training details (e.g., data splits, hyperparameters, how they were chosen)? [Yes] We specify these details in the appendix included in the supplementary material.

Did you report error bars (e.g., with respect to the random seed after running experiments multiple times)? [No] Our experiments are compute-expensive, so it is not feasible to perform multiple runs of the same experiment with different seeds. All of our training runs use the same seed value (1234). As an alternative to running multiple seeds, we perform a number of ablation studies (§3.5).

Did you include the total amount of compute and the type of resources used (e.g., type of GPUs, internal cluster, or cloud provider)? [Yes] Please see §3.2 under hardware and library.

If you are using existing assets (e.g., code, data, models) or curating/releasing new assets…

If your work uses existing assets, did you cite the creators? [Yes] Please see §3.1 for the details.

Did you mention the license of the assets? [Yes] Our work is based on open-source data and frameworks. When applicable, we describe the license information in the appendix.

Did you include any new assets either in the supplemental material or as a URL? [Yes] We include our code in the supplementary material.

Did you discuss whether and how consent was obtained from people whose data you’re using/curating? [N/A]

Did you discuss whether the data you are using/curating contains personally identifiable information or offensive content? [N/A]

If you used crowdsourcing or conducted research with human subjects…

Did you include the full text of instructions given to participants and screenshots, if applicable? [N/A]

Did you describe any potential participant risks, with links to Institutional Review Board (IRB) approvals, if applicable? [N/A]

Did you include the estimated hourly wage paid to participants and the total amount spent on participant compensation? [N/A]

Appendix A Dataset Details

For validation, we randomly select approximately 10% of the examples from the training set. For all datasets, we use the dataset splits from Lee et al. (2019). We provide the sizes of the training, development, and test sets in Table 5.
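A minimal sketch of holding out roughly 10% of a training set for validation. The function name `split_train_dev` is hypothetical, and reusing the training seed (1234) for the split is an assumption made only for reproducibility of the illustration; the paper does not state how the held-out examples were sampled.

```python
import random

def split_train_dev(examples, dev_fraction=0.1, seed=1234):
    """Hold out roughly dev_fraction of the training examples as a validation set."""
    rng = random.Random(seed)  # seeded RNG so the split is reproducible
    indices = list(range(len(examples)))
    rng.shuffle(indices)
    dev_idx = set(indices[: round(dev_fraction * len(examples))])
    # Preserve the original order of examples within each split.
    train = [ex for i, ex in enumerate(examples) if i not in dev_idx]
    dev = [ex for i, ex in enumerate(examples) if i in dev_idx]
    return train, dev

train, dev = split_train_dev(list(range(1000)))
print(len(train), len(dev))  # 900 100
```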

Pre-processing.

For TriviaQA experiments, following Izacard and Grave (2021a), we use the human-annotated answers for training the QA model. We also filter out questions whose answers are longer than 55 words. Overall, this removes 2,362 examples from the training set.
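The length filter can be sketched as follows, assuming each example is a dictionary with a single answer string under a hypothetical `"answer"` key and that "words" means whitespace-separated tokens. The paper does not specify how words are counted.

```python
def filter_long_answers(examples, max_words=55):
    """Drop examples whose answer has more than max_words whitespace-separated words."""
    kept = [ex for ex in examples if len(ex["answer"].split()) <= max_words]
    return kept, len(examples) - len(kept)

data = [
    {"question": "q1", "answer": "Paris"},
    {"question": "q2", "answer": " ".join(["word"] * 60)},  # 60-word answer
]
kept, n_removed = filter_long_answers(data)
print(len(kept), n_removed)  # 1 1
```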

Dataset license and URLs.

All the datasets are open-source and widely used by the community. Below, we provide the URLs of the original dataset sources and of the preprocessed versions used in this work.

NQ: dataset: https://ai.google.com/research/NaturalQuestions/download, license: https://github.com/google-research-datasets/natural-questions/blob/master/LICENSE

TriviaQA: dataset: http://nlp.cs.washington.edu/triviaqa/, license: https://github.com/mandarjoshi90/triviaqa/blob/master/LICENSE

WebQ: dataset: https://github.com/google-research/language/tree/master/language/orqa#getting-the-data, license: https://nlp.stanford.edu/software/sempre/

Preprocessed version: We make use of NQ, TriviaQA, and evidence datasets as open-sourced by Karpukhin et al. (2020) here: https://github.com/facebookresearch/DPR/blob/master/data/download_data.py.

Appendix B Additional Training Details

In addition to the details provided in §3.2, here, we provide further training details for reproducibility.

We derive the implementations of BERT (Devlin et al., 2019) and ICT (Lee et al., 2019) from the open-source Megatron-LM toolkit (https://github.com/NVIDIA/Megatron-LM). For ICT, the dual-encoder retriever is initialized with BERT weights, and we then train the model following Lee et al. (2019). For training, we use Wikipedia paragraphs truncated to a maximum length of 256 tokens. We list the settings and hyperparameters used for training BERT and ICT in Table 6.

T5.

We derive the implementation of the T5 (Raffel et al., 2020) language model from the open-source Megatron-LM toolkit (Shoeybi et al., 2019). We list the hyperparameters used for training T5 in Table 6. For consistency, we train T5 for the same number of steps and with the same batch size as in the original paper. Additionally, we use BERT’s lowercase tokenization for both T5 and BERT.

Unsupervised pre-training with masked salient spans (MSS).

For MSS training, we initialize the retriever of our model from the ICT weights and the reader from the T5 weights. We use the Stanza toolkit (Qi et al., 2020) to segment evidence documents into sentences. We then extract named entities from these sentences using the NER model trained on the OntoNotes-5.0 dataset, as provided by Stanza. These named entities are replaced by mask tokens. Because the masked tokens correspond to named entities, they are referred to as salient spans. The masked sentence is used as the question to retrieve evidence documents, and the reader is trained to generate the named entities corresponding to the masked salient spans with the help of the retrieved documents. During retrieval, we ignore the evidence document from which the masked sentence was derived. We list the hyperparameters of MSS training in Table 6.
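The masking step and the source-document exclusion can be sketched as below. The entity spans are passed in as character offsets, standing in for the Stanza NER output (no Stanza API is called here). Using one `[MASK]` token per entity and the helper names are illustrative assumptions, not the exact Emdr2 preprocessing.

```python
def mask_salient_spans(sentence, entity_spans, mask_token="[MASK]"):
    """Replace each (start, end) character span with a mask token.

    Returns the masked sentence (the retrieval question) and the masked
    entity strings (the reader's generation targets).
    """
    parts, targets, prev_end = [], [], 0
    for start, end in sorted(entity_spans):
        parts.append(sentence[prev_end:start])
        parts.append(mask_token)
        targets.append(sentence[start:end])
        prev_end = end
    parts.append(sentence[prev_end:])
    return "".join(parts), targets

def retrieve_excluding_source(ranked_doc_ids, source_doc_id, k):
    """Skip the evidence document the masked sentence came from, then keep the top-k."""
    return [d for d in ranked_doc_ids if d != source_doc_id][:k]

sent = "Marie Curie won the Nobel Prize in 1903."
spans = [(sent.find(e), sent.find(e) + len(e)) for e in ["Marie Curie", "Nobel Prize", "1903"]]
question, targets = mask_salient_spans(sent, spans)
print(question)  # [MASK] won the [MASK] in [MASK].
print(targets)   # ['Marie Curie', 'Nobel Prize', '1903']
```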

Supervised training using the question-answer pairs.

We provide the training details in §3.2 and list the hyperparameters in Table 7. Apart from the number of epochs and the batch size for WebQ, we use the same hyperparameters for all experiments. For the temperature parameter (τ) in Eq. 5, we follow Sachan et al. (2021) and set it to the square root of the hidden size.
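The effect of the temperature can be illustrated with a short sketch. It assumes that Eq. 5 is a softmax over query-document inner products divided by τ, restricted to the retrieved documents. The function name and the toy embeddings are hypothetical.

```python
import math

def retrieval_distribution(query_emb, doc_embs, hidden_size):
    """Softmax over query-document inner products scaled by tau = sqrt(hidden_size)."""
    tau = math.sqrt(hidden_size)
    scores = [sum(q * d for q, d in zip(query_emb, doc)) / tau for doc in doc_embs]
    m = max(scores)  # subtract the max score for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

print(round(math.sqrt(768), 2))  # 27.71 for a 768-dimensional encoder
probs = retrieval_distribution([1.0, 0.0], [[30.0, 0.0], [0.0, 30.0]], hidden_size=768)
```

Dividing by √768 ≈ 27.7 keeps raw inner products of 768-dimensional embeddings from producing an almost one-hot distribution over the retrieved documents.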

Training Time.

We run all of our experiments on a machine with 96 CPUs, 1.3TB physical memory, and 16 A100 GPUs. We use PyTorch (Paszke et al., 2019) to implement our proposed model. With this hardware setup, our experiments on NQ and TriviaQA took approximately 25 hours to complete, while experiments on WebQ took roughly 8 hours to complete. Before supervised training, we also perform a one-time unsupervised MSS pre-training for 82,000 steps that took roughly 1 week.

Appendix C Unsupervised Pre-training and Comparisons with REALM

We make use of a couple of training techniques introduced in the REALM paper (Guu et al., 2020): masked salient spans (MSS) pre-training and asynchronous evidence embedding update. There are similarities and differences in the way in which we apply these ideas to Emdr2 training.

Both ICT and MSS are unsupervised techniques used to bootstrap the retriever so that it has a good initial recall.

We first initialize the retriever with ICT pre-training. For ICT, similar to REALM, we follow the settings of the ORQA paper (Lee et al., 2019). We observe that our Recall@5 is much higher than that reported in the REALM paper (see Table 8). We believe that our choice of a 768-dimensional embedding for each evidence document leads to better results than the 128-dimensional embedding used in REALM.
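Recall@k is commonly measured as the fraction of questions for which at least one of the top-k retrieved passages contains a gold answer. The sketch below assumes this definition and uses lowercase substring matching as a simplified stand-in for the normalized answer matching used in practice. The example data is made up.

```python
def recall_at_k(retrieved_passages, gold_answers, k=5):
    """Fraction of questions whose top-k passages contain at least one gold answer."""
    hits = 0
    for passages, answers in zip(retrieved_passages, gold_answers):
        top_k = [p.lower() for p in passages[:k]]
        if any(a.lower() in p for p in top_k for a in answers):
            hits += 1
    return hits / len(gold_answers)

retrieved = [
    ["Paris is the capital of France.", "Lyon is a city."],
    ["The Moon orbits Earth.", "Mars is red."],
]
answers = [["Paris"], ["Jupiter"]]
print(recall_at_k(retrieved, answers, k=5))  # 0.5
```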

We further pre-train with MSS once the retriever weights are initialized with ICT. We use a batch size of 64 and train for 82K steps using the Emdr2 objective. In comparison, REALM uses a batch size of 512 and trains the model for 200K steps. Even with a much smaller batch size and fewer training steps, Emdr2 achieves a similar Recall@5 after MSS training (Table 8). We hypothesize that with a larger batch size and longer training, Emdr2 would be able to further improve its recall. Another implementation difference is that Emdr2 does not require the additional null document used in REALM.

For low-resource datasets such as WebQ, MSS pre-training also improves the performance of the FiD reader. As Table 3 illustrates, on WebQ, the MSS pre-trained reader obtains a gain of more than 1 EM point over the T5 reader (shaded in orange).

C.2 Asynchronous Evidence Embedding Updates

The asynchronous evidence embedding updates are performed every 500 training steps and are similar to those in REALM, with two differences. First, we perform asynchronous embedding updates during both MSS pre-training and supervised training, whereas REALM performs them only during MSS pre-training. Second, and more minor, we need to compute the embeddings of 21M evidence documents, whereas REALM does so for 13M documents. We implement this with two process groups during training: one group trains the model on 8 GPUs, while the other computes the evidence embeddings asynchronously on another 8 GPUs.

C.3 Pre-computed Evidence Embeddings Storage for Retrieval

In Table 9, we provide some comparisons between REALM and Emdr2 to show that the retrieval task is more challenging in our setting. First, the evidence collection in REALM contains 13M documents because each Wikipedia article is split into chunks of 288 wordpieces, whereas the evidence collection in Emdr2 contains 21M documents because each Wikipedia article is split into passages of 100 words. Second, the embedding dimension of each evidence document is 128 in REALM but 768 in Emdr2. Due to these factors, storing the evidence embeddings in FP16 requires approximately 3 GB of memory for REALM but 30 GB for Emdr2. As GPU memory is limited (40 GB on our A100 GPUs), it was not possible to store all 30 GB of embeddings on each GPU. Therefore, for online retrieval, we store the evidence embeddings in a distributed fashion over 16 GPUs and perform distributed asynchronous MIPS for fast retrieval.
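The memory figures above can be checked with simple arithmetic: documents × dimension × 2 bytes per FP16 value. The result comes out to about 3.1 and 30.0 GiB, or roughly 1.9 GiB per GPU when the index is split over 16 GPUs. The second function is a minimal, synchronous sketch of exact MIPS over sharded embeddings. Each shard returns its local top-k, and the global top-k is merged from these local lists. This merge is exact because every global top-k document must appear in its own shard's top-k. The sketch does not model the asynchronous, multi-GPU implementation, and the function names are hypothetical.

```python
import heapq

GIB = 1024 ** 3

def fp16_index_gib(num_docs, dim):
    """Memory for num_docs embeddings of size dim stored in FP16 (2 bytes per value)."""
    return num_docs * dim * 2 / GIB

realm_gib = fp16_index_gib(13_000_000, 128)
emdr2_gib = fp16_index_gib(21_000_000, 768)
print(f"{realm_gib:.1f} {emdr2_gib:.1f} {emdr2_gib / 16:.2f}")  # 3.1 30.0 1.88

def sharded_mips(query, shards, k):
    """Exact top-k maximum inner product search over sharded embeddings.

    Each shard is a list of (doc_id, embedding) pairs.
    """
    candidates = []
    for shard in shards:  # in the setting above, each shard would live on its own GPU
        scored = [(sum(q * x for q, x in zip(query, emb)), doc_id) for doc_id, emb in shard]
        candidates.extend(heapq.nlargest(k, scored))  # local top-k per shard
    return [doc_id for _, doc_id in heapq.nlargest(k, candidates)]  # global merge
```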

Appendix D Comparison with Previous Work

Here we provide a discussion of how Emdr2 is different from some of the previous work.

At the conceptual level, Emdr2 and $\mathcal{L}_{\text{alt-2}}$ share some similarities with Hard EM (Min et al., 2019) and the Reinforced Reader-Ranker (R³; Wang et al., 2018), even though they are not equivalent. Training with REINFORCE involves sampling from a policy network (i.e., the retriever in our case). We instead take a deterministic approach and use the top-K documents in both Emdr2 and $\mathcal{L}_{\text{alt-2}}$. Compared to Hard EM, $\mathcal{L}_{\text{alt-2}}$ directly minimizes the KL divergence between the probability of a retrieved document and the probability of the answer given that document.

At the implementation level, there are many other differences between $\mathcal{L}_{\text{alt-2}}$ (and Emdr2) and the models of Min et al. (2019) and Wang et al. (2018). First, both of these methods use TF-IDF and BM25 for retrieval, and these retrievers are not trainable. In contrast, our work uses a dense retriever that is trained end-to-end. We list other differences in more detail below.

Min et al. (2019) propose a hard EM approach to train an extractive reader model for QA tasks, where the context document is assumed to contain multiple mentions of the correct answer. Their objective trains only the reader: the model is trained with maximum marginal likelihood for the first τ steps and subsequently with their proposed logmax objective. In their open-domain QA experiments on TriviaQA and NQ, the retriever is based on TF-IDF and BM25 and is not trainable. Overall, their approach applies to extractive readers without retriever training. In comparison, Emdr2 trains both the reader and the retriever, so the hard EM approach is not directly applicable to our case.
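A minimal sketch of the two reader objectives and the switch between them. It assumes that each candidate answer mention has a log-probability under the reader, that maximum marginal likelihood is −log Σᵢ p(mentionᵢ), and that the logmax (hard EM) objective is −log maxᵢ p(mentionᵢ). The function names and the `tau` argument are illustrative.

```python
import math

def mml_loss(mention_log_probs):
    """Maximum marginal likelihood: -log sum_i p(mention_i), computed stably."""
    m = max(mention_log_probs)
    return -(m + math.log(sum(math.exp(lp - m) for lp in mention_log_probs)))

def hard_em_loss(mention_log_probs):
    """Hard EM (logmax): -log max_i p(mention_i); only the most likely mention is trained."""
    return -max(mention_log_probs)

def loss_at_step(step, mention_log_probs, tau):
    # MML for the first tau steps, then the hard EM objective.
    return mml_loss(mention_log_probs) if step < tau else hard_em_loss(mention_log_probs)

lps = [math.log(0.2), math.log(0.3)]
print(round(loss_at_step(0, lps, tau=100), 3), round(loss_at_step(100, lps, tau=100), 3))  # 0.693 1.204
```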

R³ (Wang et al., 2018) involves three pipelined components: a retriever, a ranker, and a reader. The retriever is based on BM25 and is not trainable, while the ranker and the reader are trained jointly. The ranker takes 100 documents from the retriever and selects one document to give as input to the reader (in contrast, our work selects a set of documents). As this selection operation is non-differentiable, their model uses policy gradient to train the ranker. They also propose a custom reward function based on the text overlap between the extracted answer and the correct answer. The reader takes a single document as input. In contrast, our approach has no ranker component, both the FiD reader and the retriever are trainable, and our proposed objective function Emdr2 is end-to-end differentiable.
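The policy-gradient step for a ranker that selects one document can be sketched as a single-sample REINFORCE estimator over a softmax policy. For a softmax, the gradient of log p(k) with respect to logit j is 1[j = k] − p_j, and this is scaled by the reward. This is a generic REINFORCE sketch with no baseline. The `reward_fn` callback is a hypothetical stand-in for the answer-overlap reward, and the code is not R³'s actual implementation.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def reinforce_grad(logits, reward_fn, rng):
    """Single-sample REINFORCE gradient of expected reward w.r.t. ranker logits.

    Returns the sampled document index and the gradient estimate.
    """
    probs = softmax(logits)
    k = rng.choices(range(len(probs)), weights=probs)[0]  # sample one document
    reward = reward_fn(k)  # e.g. text overlap between extracted and gold answers
    # d log p(k) / d logit_j = 1[j == k] - p_j for a softmax policy
    grad = [reward * ((1.0 if j == k else 0.0) - p) for j, p in enumerate(probs)]
    return k, grad

k, grad = reinforce_grad([0.0, 1.0, 2.0], reward_fn=lambda i: 1.0, rng=random.Random(0))
```

Because only one document is sampled, the estimate is high-variance. This is one motivation for the deterministic top-K treatment described above.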

D.2 Comparison with Individual Top-K and Joint Top-K Models

Individual Top-K is another end-to-end training approach, but unlike Emdr2, which uses a multi-document reader, it uses a single-document reader. Like previous methods such as REALM and RAG, the Individual Top-K objective is defined over multiple retrieved documents, but it is better optimized than those methods. Because Emdr2 performs much better than Individual Top-K, we conclude that Emdr2 is the better modeling approach.
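To make the contrast with a multi-document reader concrete, here is a sketch of a marginal likelihood in which each of the top-K documents is read on its own: log Σₖ p(zₖ | q) p(a | q, zₖ). It assumes that Individual Top-K takes this REALM/RAG-style form. The function names and the input probabilities are illustrative.

```python
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def individual_topk_log_likelihood(retriever_log_probs, reader_log_probs):
    """log sum_k p(z_k | q) * p(a | q, z_k), with one reader pass per document."""
    return logsumexp([r + a for r, a in zip(retriever_log_probs, reader_log_probs)])

ll = individual_topk_log_likelihood([math.log(0.5), math.log(0.5)],
                                    [math.log(0.2), math.log(0.4)])
print(round(math.exp(ll), 3))  # 0.3
```

A multi-document reader such as FiD instead scores the answer once, conditioned on all K documents jointly, so evidence spread across documents can be combined.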

Comparison with Joint Top-K (Sachan et al., 2021).

While both Emdr2 and Joint Top-K are end-to-end training approaches for open-domain QA based on the FiD model, they differ in several ways. (i) Different objective functions: These approaches optimize different training objectives. To train the retriever, Joint Top-K adds the retrieval probability scores of the top-K documents to the unnormalized inter-attention scores. In this way, the reader attends more to those top-K documents with higher retriever scores. There is no explicit feedback from the reader to the retriever. In contrast, the second term in the Emdr2 training objective explicitly encourages the retriever to improve its predictions based on their agreement with the reader’s answer-generation likelihood for each top-K document. (ii) Task performance: The Emdr2 objective leads to a much better end-to-end training algorithm, as reflected by its gains over the FiD baseline. On NQ and TriviaQA, Emdr2 improves by 4.3 and 6.4 EM points, respectively, whereas Joint Top-K obtains a much smaller gain of 1 point on NQ and no improvement on TriviaQA. This demonstrates that Emdr2 training leads to substantially better retrieval, which in turn leads to higher gains in answer generation. These results also illustrate that Emdr2 is a much better end-to-end (joint) training algorithm than Joint Top-K for multi-document reader-retriever approaches.
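The reader-to-retriever feedback can be made concrete with a small sketch. It assumes that the retriever term has the form log Σₖ SG(p(a | q, zₖ)) p(zₖ | q), where SG is a stop-gradient on the reader's per-document answer likelihoods and p(zₖ | q) is a softmax over the top-K retriever logits. Under that assumption, the gradient with respect to each retriever logit is the approximate posterior minus the prior. The retriever is therefore pushed toward documents on which the reader assigns the answer higher likelihood. The function name and inputs are illustrative.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def emdr2_retriever_term(retriever_logits, reader_answer_probs):
    """Value and gradient of log sum_k SG(p(a|q,z_k)) * p(z_k|q) w.r.t. retriever logits.

    reader_answer_probs are treated as constants (stop-gradient).
    """
    prior = softmax(retriever_logits)
    joint = [c * p for c, p in zip(reader_answer_probs, prior)]
    total = sum(joint)
    posterior = [j / total for j in joint]  # reader-informed posterior over top-K
    grad = [post - pr for post, pr in zip(posterior, prior)]
    return math.log(total), grad

value, grad = emdr2_retriever_term([0.0, 0.0], [0.1, 0.3])
print(round(math.exp(value), 3), [round(g, 3) for g in grad])  # 0.2 [-0.25, 0.25]
```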

Appendix E Qualitative Analysis

In Table 10, we present some representative examples of the retriever output both after MSS pre-training and after the MSS pre-trained model is finetuned on NQ. We observe that after MSS pre-training, the top-1 outputs are related to the question but are not relevant enough to answer it. However, when the MSS pre-trained model is finetuned on NQ with Emdr2, retrieval accuracy improves, and the top-1 documents are much more relevant to answering the question. The retriever’s confidence score for the top-1 document also improves.

We analyze the reader’s training loss when the retriever is initialized either with unsupervised MSS training or with MSS pre-training followed by supervised DPR training (MSS + DPR). As indicated in Table 3, the MSS retriever, being unsupervised, has lower accuracy, while the MSS + DPR retriever has higher accuracy. However, as is also evident from the plots in Figure 4, retriever initialization has only a marginal effect on answer generation performance. For NQ, during the first 1,200 steps, the higher-accuracy MSS + DPR retriever leads to a lower training loss than the MSS retriever. After that, the difference between the two training losses diminishes as end-to-end training improves the accuracy of the MSS retriever. Similar trends are also observed for TriviaQA and WebQ, but to a lesser extent.

Visualizing reader and retriever losses.

In Figure 3, we show the trajectories of the reader and retriever training losses when the model is initialized with MSS pre-training.