Text Understanding with the Attention Sum Reader Network

Rudolf Kadlec, Martin Schmid, Ondrej Bajgar, Jan Kleindienst

Introduction

Most of the information humanity has gathered up to this point is stored in the form of plain text. Hence the task of teaching machines how to understand this data is of utmost importance in the field of Artificial Intelligence. One way of testing the level of text understanding is simply to ask the system questions for which the answer can be inferred from the text. A well-known example of a system that can use a huge collection of unstructured documents to answer questions is IBM’s Watson system, used for the Jeopardy challenge [Ferrucci et al., 2010].

Cloze-style questions [Taylor, 1953], i.e. questions formed by removing a phrase from a sentence, are an appealing form of such questions (for example see Figure 1). While the task is easy to evaluate, one can vary the context, the question sentence or the specific phrase missing in the question to dramatically change the task structure and difficulty.

One way of altering the task difficulty is to vary the word type being replaced, as in [Hill et al., 2015]. The complexity of such variation comes from the fact that the level of context understanding needed in order to correctly predict different types of words varies greatly. While predicting prepositions can easily be done using relatively simple models with very little context knowledge, predicting named entities requires a deeper understanding of the context.

Also, as opposed to selecting a random sentence from a text as in [Hill et al., 2015], the question can be formed from a specific part of the document, such as a short summary or a list of tags. Since such sentences often paraphrase in a condensed form what was said in the text, they are particularly suitable for testing text comprehension [Hermann et al., 2015].

An important property of cloze-style questions is that a large number of such questions can be automatically generated from real world documents. This opens the task to data-hungry techniques such as deep learning. This is an advantage compared to smaller machine understanding datasets like MCTest [Richardson et al., 2013], which have only hundreds of training examples, so that the best performing systems usually rely on hand-crafted features [Sachan et al., 2015, Narasimhan and Barzilay, 2015].

In the first part of this article we introduce the task at hand and the main aspects of the relevant datasets. Then we present our own model to tackle the problem. Subsequently we compare the model to previously proposed architectures and finally describe the experimental results on the performance of our model.

Task and datasets

In this section we introduce the task that we are seeking to solve and relevant large-scale datasets that have recently been introduced for this task.

The task consists of answering a cloze-style question, the answer to which depends on the understanding of a context document provided with the question. The model is also provided with a set of possible answers from which the correct one is to be selected. This can be formalized as follows:

The training data consist of tuples $(\mathbf{q},\mathbf{d},a,A)$, where $\mathbf{q}$ is a question, $\mathbf{d}$ is a document that contains the answer to question $\mathbf{q}$, $A$ is a set of possible answers and $a\in A$ is the ground truth answer. Both $\mathbf{q}$ and $\mathbf{d}$ are sequences of words from vocabulary $V$. We also assume that all possible answers are words from the vocabulary, that is $A\subseteq V$, and that the ground truth answer $a$ appears in the document, that is $a\in\mathbf{d}$.

Datasets

We will now briefly summarize important features of the datasets.

The first two datasets [Hermann et al., 2015] (the CNN and Daily Mail datasets, available at https://github.com/deepmind/rc-data) were constructed from a large number of news articles from the CNN and Daily Mail websites. The main body of each article forms a context, while the cloze-style question is formed from one of the short highlight sentences appearing at the top of each article page. Specifically, the question is created by replacing a named entity from the summary sentence (e.g. “Producer X will not press charges against Jeremy Clarkson, his lawyer says.”).

Furthermore the named entities in the whole dataset were replaced by anonymous tokens which were further shuffled for each example so that the model cannot build up any world knowledge about the entities and hence has to genuinely rely on the context document to search for an answer to the question.

Qualitative analysis of reasoning patterns needed to answer questions in the CNN dataset together with human performance on this task are provided in [Chen et al., 2016].

Children’s Book Test

The third dataset, the Children’s Book Test (CBT) [Hill et al., 2015] (available at http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz), is built from books that are freely available thanks to Project Gutenberg (https://www.gutenberg.org/). Each context document is formed by 20 consecutive sentences taken from a children’s book story. Due to the lack of summary, the cloze-style question is then constructed from the subsequent (21st) sentence.
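As a rough illustration of the construction just described, the sketch below builds one CBT-style example from 21 consecutive sentences. The helper name `make_cloze_example`, the placeholder token `XXXXX`, whitespace tokenization and the toy story are assumptions made for illustration; they are not taken from the dataset's own tooling.

```python
from typing import List, NamedTuple


class ClozeExample(NamedTuple):
    context: List[str]  # 20 consecutive sentences
    question: str       # 21st sentence with one word blanked out
    answer: str         # the removed word


def make_cloze_example(sentences: List[str], answer: str,
                       placeholder: str = "XXXXX") -> ClozeExample:
    """Hypothetical helper: build one example from 21 consecutive sentences."""
    if len(sentences) != 21:
        raise ValueError("expected exactly 21 sentences")
    context, query_sentence = sentences[:20], sentences[20]
    words = query_sentence.split()
    if answer not in words:
        raise ValueError("answer must occur in the 21st sentence")
    # Blank out only the first occurrence of the answer word.
    words[words.index(answer)] = placeholder
    return ClozeExample(context, " ".join(words), answer)


# Toy story: 20 context sentences followed by the sentence used as the question.
story = [f"Sentence number {i} mentions Alice ." for i in range(1, 21)]
story.append("Then Alice went home .")
example = make_cloze_example(story, "Alice")
print(example.question)
```

Note that the toy answer also occurs in the context, which matches the assumption in the formal task description that the ground truth answer appears in the document.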

One can also see how the task complexity varies with the type of the omitted word (named entity, common noun, verb, preposition). [Hill et al., 2015] have shown that while standard LSTM language models have human level performance on predicting verbs and prepositions, they lag behind on named entities and common nouns. In this article we therefore focus only on predicting the first two word types.

Basic statistics about the CNN, Daily Mail and CBT datasets are summarized in Table 1.

Our Model — Attention Sum Reader

Our model, called the Attention Sum Reader (AS Reader), is tailor-made to leverage the fact that the answer is a word from the context document (our implementation of the AS Reader is available at https://github.com/rkadlec/asreader). This is a double-edged sword. While it achieves state-of-the-art results on all of the mentioned datasets (where this assumption holds true), it cannot produce an answer which is not contained in the document. Intuitively, our model is structured as follows:

1. We compute a vector embedding of the query.

2. We compute a vector embedding of each individual word in the context of the whole document (contextual embedding).

3. Using a dot product between the question embedding and the contextual embedding of each occurrence of a candidate answer in the document, we select the most likely answer.

Our model uses one word embedding function and two encoder functions. The word embedding function $e$ translates words into vector representations. The first encoder function is a document encoder $f$ that encodes every word from the document $\mathbf{d}$ in the context of the whole document. We call this the contextual embedding. For convenience we will denote the contextual embedding of the $i$-th word in $\mathbf{d}$ as $f_{i}(\mathbf{d})$. The second encoder $g$ is used to translate the query $\mathbf{q}$ into a fixed length representation of the same dimensionality as each $f_{i}(\mathbf{d})$. Both encoders use word embeddings computed by $e$ as their input. Then we compute a weight for every word in the document as the dot product of its contextual embedding and the query embedding. This weight might be viewed as an attention over the document $\mathbf{d}$.

To form a proper probability distribution over the words in the document, we normalize the weights using the softmax function. This way we model probability $s_{i}$ that the answer to query $\mathbf{q}$ appears at position $i$ in the document $\mathbf{d}$. In a functional form this is:

$$s_{i}\propto\exp\left(f_{i}(\mathbf{d})\cdot g(\mathbf{q})\right) \qquad (1)$$

Finally we compute the probability that word $w$ is a correct answer as:

$$P(w|\mathbf{q},\mathbf{d})\propto\sum_{i\in I(w,\mathbf{d})}s_{i} \qquad (2)$$

where $I(w,\mathbf{d})$ is a set of positions where $w$ appears in the document $\mathbf{d}$. We call this mechanism pointer sum attention since we use attention as a pointer over discrete tokens in the context document and then we directly sum the word’s attention across all the occurrences. This differs from the usual use of attention in sequence-to-sequence models [Bahdanau et al., 2015] where attention is used to blend representations of words into a new embedding vector. Our use of attention was inspired by Pointer Networks (Ptr-Nets) [Vinyals et al., 2015].
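The pointer sum attention step can be sketched in a few lines of plain Python. This is a minimal illustration and not the released implementation: the contextual embeddings and the query embedding are hand-written toy vectors standing in for the encoder outputs, and the names `softmax`, `dot` and `attention_sum` are hypothetical.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def attention_sum(doc_tokens: Sequence[str],
                  contextual_embeddings: Sequence[Sequence[float]],
                  query_embedding: Sequence[float]) -> Dict[str, float]:
    """Attention over positions, then summed per distinct word."""
    # One attention weight per document position: softmax of dot products.
    s = softmax([dot(f_i, query_embedding) for f_i in contextual_embeddings])
    # Sum the attention over all positions where each word occurs.
    probs: Dict[str, float] = defaultdict(float)
    for token, s_i in zip(doc_tokens, s):
        probs[token] += s_i
    return dict(probs)


# Toy inputs (assumed values): entity positions align with the query vector.
tokens = ["ent1", "met", "ent2", "and", "ent1"]
embeddings = [[1.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 0.0], [1.0, 0.0]]
query = [1.0, 0.0]
probs = attention_sum(tokens, embeddings, query)
print(max(probs, key=probs.get))
```

In this toy case `ent1` receives twice the probability of `ent2` purely because it occurs twice with the same per-position attention, which is the frequency effect the summation introduces.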

A high level structure of our model is shown in Figure 2.

Model instance details

In our model the document encoder $f$ is implemented as a bidirectional Gated Recurrent Unit (GRU) network [Cho et al., 2014, Chung et al., 2014] whose hidden states form the contextual word embeddings, that is $f_{i}(\mathbf{d})=\overrightarrow{f_{i}}(\mathbf{d})\,||\,\overleftarrow{f_{i}}(\mathbf{d})$, where $||$ denotes vector concatenation and $\overrightarrow{f_{i}}$ and $\overleftarrow{f_{i}}$ denote forward and backward contextual embeddings from the respective recurrent networks. The query encoder $g$ is implemented by another bidirectional GRU network. This time the last hidden state of the forward network is concatenated with the last hidden state of the backward network to form the query embedding, that is $g(\mathbf{q})=\overrightarrow{g_{|\mathbf{q}|}}(\mathbf{q})\,||\,\overleftarrow{g_{1}}(\mathbf{q})$. The word embedding function $e$ is implemented in the usual way as a look-up table $\mathbf{V}$. $\mathbf{V}$ is a matrix whose rows can be indexed by words from the vocabulary, that is $e(w)=V_{w}, w\in V$. Therefore, each row of $\mathbf{V}$ contains the embedding of one word from the vocabulary. During training we jointly optimize parameters of $f$, $g$ and $e$.

Related Work

Several recent deep neural network architectures [Hermann et al., 2015, Hill et al., 2015, Chen et al., 2016, Kobayashi et al., 2016] were applied to the task of text comprehension. The last two architectures were developed independently at the same time as our work. All of these architectures use an attention mechanism that allows them to highlight places in the document that might be relevant to answering the question. We will now briefly describe these architectures and compare them to our approach.

Attentive and Impatient Readers were proposed in [Hermann et al., 2015]. The simpler Attentive Reader is very similar to our architecture. It also uses bidirectional document and query encoders to compute an attention in a way similar to ours. The more complex Impatient Reader computes attention over the document after reading every word of the query. However, empirical evaluation has shown that both models perform almost identically on the CNN and Daily Mail datasets.

The key difference between the Attentive Reader and our model is that the Attentive Reader uses attention to compute a fixed length representation $r$ of the document $\mathbf{d}$ that is equal to a weighted sum of contextual embeddings of words in $\mathbf{d}$, that is $r=\sum_{i}s_{i}f_{i}(\mathbf{d})$. A joint query and document embedding $m$ is then a non-linear function of $r$ and the query embedding $g(\mathbf{q})$. This joint embedding $m$ is in the end compared against all candidate answers $a^{\prime}\in A$ using the dot product $e(a^{\prime})\cdot m$, and the scores are normalized by softmax. That is: $P(a^{\prime}|\mathbf{q},\mathbf{d})\propto\exp\left(e(a^{\prime})\cdot m\right)$.

In contrast to the Attentive Reader, we select the answer from the context directly using the computed attention rather than using such attention for a weighted sum of the individual representations (see Eq. 2). The motivation for such simplification is the following.

Consider a context “A UFO was observed above our city in January and again in March.” and question “An observer has spotted a UFO in ___ .”

Since both January and March are equally good candidates, the attention mechanism might put the same attention on both these candidates in the context. The blending mechanism described above would compute a vector between the representations of these two words and propose the closest word as the answer; this may well happen to be February (it is indeed the case for Word2Vec trained on Google News). By contrast, our model would correctly propose January or March.
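The January/March argument can be reproduced with a toy calculation. The sketch below is an illustration under strong assumptions: the three month embeddings are invented unit vectors placed at 0, 30 and 60 degrees (not Word2Vec vectors), the word embeddings stand in for contextual embeddings, and the non-linear step that produces the joint embedding in the Attentive Reader is skipped, so the blended vector is scored directly with a dot product.

```python
import math
from typing import Dict, Tuple


def unit(angle_deg: float) -> Tuple[float, float]:
    """Unit vector at the given angle (toy embedding, assumed for illustration)."""
    a = math.radians(angle_deg)
    return (math.cos(a), math.sin(a))


# Invented embeddings: February lies between January and March.
embedding: Dict[str, Tuple[float, float]] = {
    "January": unit(0), "February": unit(30), "March": unit(60),
}

# Attention spread equally over the two months that occur in the context.
doc_tokens = ["January", "March"]
attention = [0.5, 0.5]

# Blending: weighted sum of representations, then pick the best-matching word.
r = tuple(sum(s * embedding[t][k] for s, t in zip(attention, doc_tokens))
          for k in range(2))
blended_answer = max(
    embedding, key=lambda w: embedding[w][0] * r[0] + embedding[w][1] * r[1])

# Pointer sum: add up attention per word that actually occurs in the context.
pointer_scores: Dict[str, float] = {}
for s, t in zip(attention, doc_tokens):
    pointer_scores[t] = pointer_scores.get(t, 0.0) + s
pointer_answer = max(pointer_scores, key=pointer_scores.get)

print(blended_answer, pointer_answer)
```

With these toy vectors the blended representation matches February best, although February never occurs in the context, while the pointer sum can only return a word that does occur.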

Chen et al. 2016

A model presented in [Chen et al., 2016] is inspired by the Attentive Reader. One difference is that the attention weights are computed with a bilinear term instead of a simple dot product, that is $s_{i}\propto\exp\left(f_{i}(\mathbf{d})^{\intercal}\,W\,g(\mathbf{q})\right)$. The document embedding $r$ is computed using a weighted sum as in the Attentive Reader, $r=\sum_{i}s_{i}f_{i}(\mathbf{d})$. In the end $P(a^{\prime}|\mathbf{q},\mathbf{d})\propto\exp\left(e^{\prime}(a^{\prime})\cdot r\right)$, where $e^{\prime}$ is a new embedding function.

Even though it is a simplification of the Attentive Reader, this model performs significantly better than the original.

Memory Networks

MemNNs [Weston et al., 2014] were applied to the task of text comprehension in [Hill et al., 2015].

The best performing memory network setup, window memory, uses windows of fixed length (8) centered around the candidate words as memory cells. Due to this limited context window, the model is unable to capture dependencies outside the scope of this window. Furthermore, the representation within such a window is computed simply as the sum of embeddings of words in that window. By contrast, in our model the representation of each individual word is computed using a recurrent network, which not only allows the model to capture context from the entire document but also makes the embedding computation much more flexible than a simple sum.

To improve on the initial accuracy, a heuristic approach called self supervision is used in [Hill et al., 2015] to help the network select the right supporting “memories” using an attention mechanism similar to ours. Plain MemNNs without this heuristic are not competitive on these machine reading tasks. Our model does not need any similar heuristics.

Dynamic Entity Representation

The Dynamic Entity Representation model [Kobayashi et al., 2016] has a complex architecture also based on the weighted attention mechanism and max-pooling over contextual embeddings of vectors for each named entity.

Pointer Networks

Our model architecture was inspired by Ptr-Nets [Vinyals et al., 2015] in using an attention mechanism to select the answer in the context rather than to blend words from the context into an answer representation. While a Ptr-Net consists of an encoder as well as a decoder, which uses the attention to select the output at each step, our model outputs the answer in a single step. Furthermore, the pointer networks assume that no input in the sequence appears more than once, which is not the case in our settings.

Summary

Our model combines the best features of the architectures mentioned above. We use recurrent networks to “read” the document and the query as done in [Hermann et al., 2015, Chen et al., 2016, Kobayashi et al., 2016] and we use attention in a way similar to Ptr-Nets. We also use summation of attention weights in a way similar to MemNNs [Hill et al., 2015].

From a high level perspective we simplify all the discussed text comprehension models by removing all transformations past the attention step. Instead we use the attention directly to compute the answer probability.

Evaluation

In this section we evaluate our model on the CNN, Daily Mail and CBT datasets. We show that despite the model’s simplicity its ensembles achieve state-of-the-art performance on each of these datasets.

To train the model we used stochastic gradient descent with the ADAM update rule [Kingma and Ba, 2015] and learning rate of $0.001$ or $0.0005$. During training we minimized the following negative log-likelihood with respect to $\theta$:

$$-\log P_{\theta}(a|\mathbf{q},\mathbf{d}) \qquad (3)$$

where $a$ is the correct answer for query $\mathbf{q}$ and document $\mathbf{d}$, and $\theta$ represents parameters of the encoder functions $f$ and $g$ and of the word embedding function $e$. The optimized probability distribution $P(a|\mathbf{q},\mathbf{d})$ is defined in Eq. 2.
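For a single training example, the negative log-likelihood objective reduces to the negative logarithm of the probability the model assigns to the correct answer. The sketch below assumes the answer distribution has already been computed (for instance by summing attention per word); the function name and the toy probabilities are illustrative only, and how losses are aggregated over a batch is not specified here.

```python
import math
from typing import Dict


def negative_log_likelihood(answer_probs: Dict[str, float], answer: str) -> float:
    """Loss for one example: -log of the probability of the correct answer."""
    return -math.log(answer_probs[answer])


# Toy answer distribution over three anonymized entities (assumed values).
probs = {"ent1": 0.7, "ent2": 0.2, "ent3": 0.1}
loss_confident = negative_log_likelihood(probs, "ent1")
loss_unlikely = negative_log_likelihood(probs, "ent3")
print(loss_confident, loss_unlikely)
```

The loss is small when the correct answer already receives most of the probability mass and grows without bound as that probability approaches zero.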

The initial weights in the word embedding matrix were drawn uniformly at random from the interval $[-0.1,0.1]$. Weights in the GRU networks were initialized by random orthogonal matrices [Saxe et al., 2014] and biases were initialized to zero. We also used a gradient clipping [Pascanu et al., 2012] threshold of 10 and batches of size 32.

During training we randomly shuffled all examples in each epoch. To speed up training, we always pre-fetched 10 batches worth of examples and sorted them according to document length. Hence each batch contained documents of roughly the same length.
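One way to realize this shuffling and length-sorting scheme is sketched below. It is an assumption-laden illustration rather than the original data pipeline: documents are represented as token lists, the generator name `length_bucketed_batches` is hypothetical, and the order in which the resulting batches are fed to the model is left as produced.

```python
import random
from typing import Iterator, List, Sequence


def length_bucketed_batches(documents: Sequence[List[str]],
                            batch_size: int = 32,
                            prefetch_batches: int = 10,
                            seed: int = 0) -> Iterator[List[List[str]]]:
    """Shuffle once per epoch, then sort each pre-fetched chunk by length."""
    rng = random.Random(seed)
    order = list(documents)
    rng.shuffle(order)
    chunk_size = batch_size * prefetch_batches
    for start in range(0, len(order), chunk_size):
        # Sort only within the pre-fetched chunk, so randomness is preserved
        # across chunks while each batch holds documents of similar length.
        chunk = sorted(order[start:start + chunk_size], key=len)
        for b in range(0, len(chunk), batch_size):
            yield chunk[b:b + batch_size]


# Toy corpus: 100 documents with lengths 1..100 (assumed data).
docs = [["w"] * n for n in range(1, 101)]
batches = list(length_bucketed_batches(docs, batch_size=4, prefetch_batches=5))
print(len(batches), [len(d) for d in batches[0]])
```

Sorting within a small pre-fetched window keeps padding low inside each batch without fixing the global order of examples.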

For each batch of the CNN and Daily Mail datasets we randomly reshuffled the assignment of named entities to the corresponding word embedding vectors to match the procedure proposed in [Hermann et al., 2015]. This guaranteed that word embeddings of named entities were used only as semantically meaningless labels not encoding any intrinsic features of the represented entities. This forced the model to truly deduce the answer from the single context document associated with the question. We also do not use pre-trained word embeddings to make our training procedure comparable to [Hermann et al., 2015].
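The per-batch reshuffling can be pictured as applying a fresh random permutation to the entity tokens, which has the same effect as reassigning entities to embedding vectors. The sketch below is illustrative: the `@entityN` token format, the helper name `reshuffle_entities` and the toy batch are assumptions, and in a full pipeline the same mapping would also have to be applied to the question, the candidate list and the answer of each example.

```python
import random
from typing import Dict, List, Tuple


def reshuffle_entities(batch: List[List[str]], entity_tokens: List[str],
                       rng: random.Random) -> Tuple[List[List[str]], Dict[str, str]]:
    """Apply one random permutation of entity tokens to every sequence in a batch."""
    permuted = list(entity_tokens)
    rng.shuffle(permuted)
    mapping = dict(zip(entity_tokens, permuted))
    # Non-entity tokens are left untouched.
    remapped = [[mapping.get(tok, tok) for tok in seq] for seq in batch]
    return remapped, mapping


entities = ["@entity0", "@entity1", "@entity2"]
batch = [["@entity0", "met", "@entity1"], ["@entity2", "called", "@entity0"]]
remapped, mapping = reshuffle_entities(batch, entities, random.Random(0))
print(remapped)
```

Because the permutation changes from batch to batch, no embedding vector can accumulate information about a particular real-world entity.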

We did not perform any text pre-processing since the original datasets were already tokenized.

We do not use any regularization since in our experience it leads to longer training times for single models, while the performance of a model ensemble is usually the same. This way we can train the whole ensemble faster when using multiple GPUs for parallel training.

For additional details about the training procedure see Appendix A.

Evaluation Method

We evaluated the proposed model both as a single model and using ensemble averaging. Although the model computes attention for every word in the document we restrict the model to select an answer from a list of candidate answers associated with each question-document pair.

For single models we report results for the best model as well as the average accuracy of the best 20% of models according to validation performance, since single models display considerable variation of results due to random weight initialization even for identical hyperparameter values (the standard deviation for models with the same hyperparameters was between 0.6% and 2.5% in absolute test accuracy). Single model performance may consequently prove difficult to reproduce.

As for ensembles, we used simple averaging of the answer probabilities predicted by ensemble members. For ensembling we used 14, 16, 84 and 53 models for CNN, Daily Mail, CBT CN and CBT NE respectively. The ensemble members were chosen in one of two ways. In the first, which we call the avg ensemble, we took the top 70% of all trained models. In the second, we used the following algorithm: we started with the best performing model according to validation performance. Then in each step we tried adding the best performing model that had not been previously tried. We kept it in the ensemble if it improved the ensemble's validation performance and discarded it otherwise. This way we gradually tried each model once. We call the resulting model a greedy ensemble.
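The greedy selection procedure can be sketched as follows. This is a minimal illustration under stated assumptions: each model is represented only by its predicted answer distributions on the validation set, "improve" is read as a strict increase in validation accuracy, and the names `ensemble_accuracy` and `greedy_ensemble` as well as the toy predictions are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

Prediction = Dict[str, float]   # candidate answer -> probability
Model = List[Prediction]        # one prediction per validation example


def ensemble_accuracy(members: Sequence[Model], answers: Sequence[str]) -> float:
    """Accuracy of the simple average of the members' answer probabilities."""
    correct = 0
    for i, answer in enumerate(answers):
        candidates = members[0][i].keys()
        avg = {c: sum(m[i][c] for m in members) / len(members) for c in candidates}
        correct += max(avg, key=avg.get) == answer
    return correct / len(answers)


def greedy_ensemble(models: Sequence[Model],
                    answers: Sequence[str]) -> Tuple[List[Model], float]:
    # Rank models by their individual validation accuracy, best first.
    ranked = sorted(models, key=lambda m: ensemble_accuracy([m], answers),
                    reverse=True)
    ensemble = [ranked[0]]
    best = ensemble_accuracy(ensemble, answers)
    for model in ranked[1:]:  # each remaining model is tried exactly once
        acc = ensemble_accuracy(ensemble + [model], answers)
        if acc > best:        # keep only if validation accuracy improves
            ensemble.append(model)
            best = acc
    return ensemble, best


# Toy validation set with three examples whose correct answer is always "a".
answers = ["a", "a", "a"]
m1 = [{"a": 0.6, "b": 0.4}, {"a": 0.6, "b": 0.4}, {"a": 0.4, "b": 0.6}]
m2 = [{"a": 0.45, "b": 0.55}, {"a": 0.55, "b": 0.45}, {"a": 0.9, "b": 0.1}]
m3 = [{"a": 0.1, "b": 0.9}, {"a": 0.1, "b": 0.9}, {"a": 0.1, "b": 0.9}]
ensemble, best = greedy_ensemble([m1, m2, m3], answers)
print(len(ensemble), best)
```

In the toy run the two complementary models are kept because their averaged probabilities answer every example correctly, while the third model is discarded because adding it lowers validation accuracy.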

Results

Performance of our models on the CNN and Daily Mail datasets is summarized in Table 2, and Table 3 shows results on the CBT dataset. The tables also list performance of other published models that were evaluated on these datasets. Ensembles of our models set new state-of-the-art results on all evaluated datasets.

Table 4 then measures accuracy as the proportion of test cases where the ground truth was among the top $k$ answers proposed by the greedy ensemble model for $k=1,2,5$.
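The top-$k$ accuracy described above can be computed as in the following sketch. The predictions and answers are toy values invented for illustration, and the function name `top_k_accuracy` is hypothetical.

```python
from typing import Dict, Sequence


def top_k_accuracy(predictions: Sequence[Dict[str, float]],
                   answers: Sequence[str], k: int) -> float:
    """Fraction of examples whose correct answer is among the k most probable."""
    hits = 0
    for probs, answer in zip(predictions, answers):
        top_k = sorted(probs, key=probs.get, reverse=True)[:k]
        hits += answer in top_k
    return hits / len(answers)


# Toy predicted answer distributions for three test cases (assumed values).
predictions = [
    {"a": 0.5, "b": 0.3, "c": 0.2},
    {"a": 0.1, "b": 0.6, "c": 0.3},
    {"a": 0.2, "b": 0.3, "c": 0.5},
]
answers = ["a", "c", "a"]
for k in (1, 2, 3):
    print(k, top_k_accuracy(predictions, answers, k))
```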

CNN and Daily Mail. The CNN dataset is the most widely used dataset for evaluation of text comprehension systems published so far. Performance of our single model is slightly worse than performance of simultaneously published models [Chen et al., 2016, Kobayashi et al., 2016]. Compared to our work, these models were trained with Dropout regularization [Srivastava et al., 2014], which might improve single model performance. However, an ensemble of our models outperforms these models even though they use pre-trained word embeddings.

On the CNN dataset our single model with best validation accuracy achieves a test accuracy of 69.5%. The average performance of the top 20% of models according to validation accuracy is 69.9%, which is even 0.5% better than the single best-validation model. This shows that there were many models that performed better on the test set than the best-validation model. Fusing multiple models then gives a significant further increase in accuracy on both CNN and Daily Mail datasets.

CBT. In named entity prediction our best single model with accuracy of 68.6% performs 2% absolute better than the MemNN with self supervision, and the averaging ensemble performs 4% absolute better than the best previous result. In common noun prediction our single model is 0.4% absolute better than the MemNN; however, the ensemble improves the performance to 69%, which is 6% absolute better than the MemNN.

Analysis

To further analyze the properties of our model, we examined the dependence of accuracy on the length of the context document (Figure 5), the number of candidate answers (Figure 6) and the frequency of the correct answer in the context (Figure 7).

On the CNN and Daily Mail datasets, the accuracy decreases with increasing document length (Figure 5(a)). We hypothesize this may be due to multiple factors. Firstly, long documents may make the task more complex. Secondly, such cases are quite rare in the training data (Figure 5(b)), which motivates the model to specialize on shorter contexts. Finally, the context length is correlated with the number of named entities, i.e. the number of possible answers, which is itself negatively correlated with accuracy (see Figure 6).

On the CBT dataset this negative trend seems to disappear (Fig. 5(c)). This supports the latter two explanations since the distribution of document lengths is somewhat more uniform (Figure 5(d)) and the number of candidate answers is constant ($10$) for all examples in this dataset.

The effect of an increasing number of candidate answers on the model’s accuracy can be seen in Figure 6(a). We can clearly see that as the number of candidate answers increases, the accuracy drops. On the other hand, the number of examples with a large number of candidate answers is quite small (Figure 6(b)).

Finally, since the summation of attention in our model inherently favours frequently occurring tokens, we also visualize how the accuracy depends on the frequency of the correct answer in the document. Figure 7(a) shows that the accuracy significantly drops as the correct answer gets less and less frequent in the document compared to other candidate answers. On the other hand, the correct answer is likely to occur frequently (Fig. 7(a)).

Conclusion

In this article we presented a new neural network architecture for natural language text comprehension. While our model is simpler than previously published models, it gives a new state-of-the-art accuracy on all evaluated datasets.

An analysis by [Chen et al., 2016] suggests that on CNN and Daily Mail datasets a significant proportion of questions is ambiguous or too difficult to answer even for humans (partly due to entity anonymization) so the ensemble of our models may be very near to the maximal accuracy achievable on these datasets.

Acknowledgments

We would like to thank Tim Klinger for providing us with masked softmax code that we used in our implementation.

References

Appendix A Training Details

During training we evaluated the model performance after each epoch and stopped the training when the error on the validation set started increasing.
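The stopping rule can be sketched as a simple loop. This is an illustration under assumptions: `train_one_epoch` and `validation_error` are hypothetical callables standing in for the real training and evaluation code, training stops at the first epoch whose validation error is higher than the previous best, and whether the best earlier checkpoint is restored afterwards is not specified here.

```python
from typing import Callable, List


def train_with_early_stopping(train_one_epoch: Callable[[], None],
                              validation_error: Callable[[], float],
                              max_epochs: int = 20) -> List[float]:
    """Train epoch by epoch and stop once validation error starts increasing."""
    best_error = float("inf")
    history: List[float] = []
    for _ in range(max_epochs):
        train_one_epoch()
        error = validation_error()
        history.append(error)
        if error > best_error:  # validation error went up: stop training
            break
        best_error = error
    return history


# Simulated validation errors per epoch (assumed values, no real training).
simulated = iter([0.40, 0.33, 0.31, 0.34, 0.30])
history = train_with_early_stopping(lambda: None, lambda: next(simulated))
print(history)
```

In the simulated run training halts after the fourth epoch, the first one whose validation error exceeds the best value seen so far.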

The models usually converged after two epochs of training. Time needed to complete a single epoch of training on each dataset on an Nvidia K40 GPU is shown in Table 5.

The hyperparameters, namely the recurrent hidden layer dimension and the source embedding dimension, were chosen by grid search. We started with a range of 128 to 384 for both parameters and subsequently kept increasing the upper bound by 128 until we started observing a consistent decrease in validation accuracy. The region of the parameter space that we explored together with the parameters of the model with best validation accuracy are summarized in Table 6.

Our model was implemented using Theano [Bastien et al., 2012] and Blocks [van Merrienboer et al., 2015].

Appendix B Dependence of accuracy on the frequency of the correct answer

In the Analysis section we analysed how the test accuracy depends on how frequent the correct answer is compared to other answer candidates for the news datasets. The plots for the Children’s Book Test look very similar; however, we add them here for completeness.