Semi-Supervised QA with Generative Domain-Adaptive Nets
Zhilin Yang, Junjie Hu, Ruslan Salakhutdinov, William W. Cohen
Introduction
Recently, various neural network models have been proposed and successfully applied to the tasks of question answering (QA) and/or reading comprehension Xiong et al. (2016); Dhingra et al. (2016); Yang et al. (2017). While achieving state-of-the-art performance, these models rely on a large amount of labeled data. However, it is extremely difficult to collect large-scale question answering datasets. Historically, many question answering datasets contain only thousands of question-answer pairs, such as WebQuestions Berant et al. (2013), MCTest Richardson et al. (2013), WikiQA Yang et al. (2015), and TREC-QA Voorhees and Tice (2000). Although larger question answering datasets with hundreds of thousands of question-answer pairs have been collected, including SQuAD Rajpurkar et al. (2016), MSMARCO Nguyen et al. (2016), and NewsQA Trischler et al. (2016a), the data collection process is expensive and time-consuming in practice. This hinders real-world applications of domain-specific question answering.
Compared to obtaining labeled question-answer pairs, it is trivial to obtain unlabeled text data. In this work, we study the following problem of semi-supervised question answering: is it possible to leverage unlabeled text to boost the performance of question answering models, especially when only a small amount of labeled data is available? The problem is challenging because conventional manifold-based semi-supervised learning algorithms Zhu and Ghahramani (2002); Yang et al. (2016a) cannot be applied straightforwardly. Moreover, since most question answering tasks focus on extraction rather than generation, it is also not sensible to use unlabeled text to improve language modeling as in machine translation Gulcehre et al. (2015).
To better leverage the unlabeled text, we propose a novel neural framework called Generative Domain-Adaptive Nets (GDANs). The starting point of our framework is to use linguistic tags to extract possible answer chunks in the unlabeled text, and then train a generative model to generate questions given the answer chunks and their contexts. The model-generated question-answer pairs and the human-generated question-answer pairs can then be combined to train a question answering model, referred to as a discriminative model in the following text. However, there is a discrepancy between the model-generated data distribution and the human-generated data distribution, which leads to suboptimal discriminative models. To address this issue, we further propose two domain adaptation techniques that treat the model-generated data distribution as a different domain. First, we use an additional domain tag to indicate whether a question-answer pair is model-generated or human-generated. We condition the discriminative model on the domain tags so that the discriminative model can learn to factor out domain-specific and domain-invariant representations. Second, we employ a reinforcement learning algorithm to fine-tune the generative model to minimize the loss of the discriminative model in an adversarial way.
In addition, we present a simple and effective baseline method for semi-supervised question answering. Although the baseline method performs worse than our GDAN approach, it is extremely easy to implement and can still lead to substantial improvement when only limited labeled data is available.
We experiment on the SQuAD dataset Rajpurkar et al. (2016) with various labeling rates and various amounts of unlabeled data. Experimental results show that our GDAN framework consistently improves over both the supervised learning setting and the baseline methods, including adversarial domain adaptation Ganin and Lempitsky (2014) and dual learning Xia et al. (2016). More specifically, the GDAN model improves the F1 score by 9.87 points over the supervised learning setting when 8K labeled question-answer pairs are used.
Our contribution is four-fold. First, different from most of the previous neural network studies on question answering, we study a critical but challenging problem, semi-supervised question answering. Second, we propose the Generative Domain-Adaptive Nets that employ domain adaptation techniques on generative models with reinforcement learning algorithms. Third, we introduce a simple and effective baseline method. Fourth, we empirically show that our framework leads to substantial improvements.
Semi-Supervised Question Answering
Let us first introduce the problem of semi-supervised question answering.
In addition to the labeled dataset $L = \{(q^{(i)}, a^{(i)}, p^{(i)})\}_{i=1}^{N}$, where $p^{(i)}$ is a paragraph, $q^{(i)}$ is a question, $a^{(i)}$ is its answer (a chunk of text in $p^{(i)}$), and $N$ is the number of labeled instances, in the semi-supervised setting we are also given a set of unlabeled data, denoted as $U = \{(a^{(i)}, p^{(i)})\}_{i=1}^{M}$, where $M$ is the number of unlabeled instances. Note that it is usually easy to access a virtually unlimited number of paragraphs from sources such as Wikipedia articles and other web pages. Since the answer is always a consecutive chunk of $p$, we argue that it is also sensible to extract possible answer chunks from the unlabeled text using linguistic tags. We discuss the technical details of answer chunk extraction in Section 4.1; in the formulation of our framework, we assume that the answer chunks are available.
We now present a simple baseline for semi-supervised question answering. Given a paragraph $p = (p_1, p_2, \ldots, p_T)$ and the answer $a = (p_{j_1}, p_{j_1+1}, \ldots, p_{j_2})$, we extract $(p_{j_1-W}, \ldots, p_{j_1-1}, p_{j_2+1}, \ldots, p_{j_2+W})$ from the paragraph and treat it as the question. Here $W$ is the window size and is set to 5 in our experiments so that the lengths of the questions are similar to those of human-generated questions. The context-based question-answer pairs on $U$ are combined with the human-generated pairs on $L$ for training the discriminative model. Intuitively, this method extracts the contexts around the answer chunks to serve as hints for the question answering model. Surprisingly, this simple baseline method leads to substantial improvements when labeled data is limited.
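The context-based baseline reduces to slicing a window of $W$ tokens on each side of the answer span. The sketch below assumes tokenized paragraphs and inclusive token indices `start`/`end` for the answer; how the authors handle answers near the paragraph boundary is not stated, so truncating the window there is an assumption of this sketch.

```python
from typing import List, Tuple

def context_question(paragraph: List[str], start: int, end: int, window: int = 5) -> List[str]:
    """Pseudo-question = up to `window` tokens before the answer + up to `window` tokens after it.

    `start` and `end` are inclusive token indices of the answer chunk; windows are
    truncated at the paragraph boundaries (an assumption, not stated in the text).
    """
    left = paragraph[max(0, start - window):start]
    right = paragraph[end + 1:end + 1 + window]
    return left + right

def context_pairs(unlabeled: List[Tuple[List[str], int, int]], window: int = 5):
    """Turn (paragraph, start, end) triples into (paragraph, question, answer) triples."""
    pairs = []
    for paragraph, start, end in unlabeled:
        question = context_question(paragraph, start, end, window)
        answer = paragraph[start:end + 1]
        pairs.append((paragraph, question, answer))
    return pairs
```

For example, with the answer "1889" in "the tower was completed in 1889 for the world fair in paris", the pseudo-question is the five tokens on each side: "the tower was completed in for the world fair in".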
Generative Domain-Adaptive Nets
Though the simple method described in Section 2.1 can lead to substantial improvement, we aim to design a learning-based model to go even further. In this section, we describe the model architecture and the training algorithms for the GDANs. We use the notation of Section 2 and present GDANs in the context of question answering, but the notion of GDANs should extend to other applications as well.
The GDAN framework consists of two models, a discriminative model and a generative model. We will first discuss the two models in detail in the context of question answering, and then present an algorithm based on reinforcement learning to combine the two models.
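The discriminative model $D$ learns the conditional probability $P_D(a \mid p, q)$ of an answer given the paragraph and the question. It is based on a gated-attention (GA) model consisting of $K$ layers, with $K$ being a hyper-parameter. Let $H_p^k$ be the intermediate paragraph representation at layer $k$, and $H_q$ be the question representation. The paragraph representation $H_p^k$ is a $T \times d$ matrix, and the question representation $H_q$ is a $T' \times d$ matrix, where $T$ and $T'$ are the paragraph and question lengths and $d$ is the dimensionality of the representations. Given the paragraph $p$, we apply a bidirectional Gated Recurrent Unit (GRU) network Chung et al. (2014) on top of the embeddings of the sequence $(p_1, p_2, \ldots, p_T)$ and obtain the initial paragraph representation $H_p^0$. Given the question $q$, we also apply another bidirectional GRU to obtain the question representation $H_q$.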
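The question and paragraph representations are combined with the gated-attention (GA) mechanism Dhingra et al. (2016). More specifically, for each paragraph token $p_j$, we compute

$$\alpha_j = \operatorname{softmax}\big(H_q \mathbf{h}_j^{k-1}\big), \qquad \beta_j = H_q^\top \alpha_j, \qquad \mathbf{h}_j^{k} = \mathbf{h}_j^{k-1} \odot \beta_j,$$

where $\mathbf{h}_j^{k-1}$ is the $j$-th row of $H_p^{k-1}$, $\mathbf{h}_j^{k}$ is the $j$-th row of $H_p^{k}$, and $\odot$ denotes the element-wise product.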
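The gated-attention update can be checked on tiny matrices. This is a plain-Python sketch of the three equations above only (matrices as lists of rows), not the authors' implementation; in the original GA reader the gated output is additionally passed through a further bidirectional GRU before the next layer, which this sketch omits.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]  # a matrix stored as a list of rows

def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def gated_attention(h_prev: Matrix, h_q: Matrix) -> Matrix:
    """One gated-attention step applied to every paragraph row.

    h_prev: T x d paragraph representation H_p^{k-1}
    h_q:    T' x d question representation H_q
    """
    d = len(h_prev[0])
    out = []
    for h_j in h_prev:
        # alpha_j = softmax(H_q h_j): attention weights over the T' question tokens
        alpha = softmax([sum(a * b for a, b in zip(q_row, h_j)) for q_row in h_q])
        # beta_j = H_q^T alpha_j: attention-weighted question summary of length d
        beta = [sum(alpha[t] * h_q[t][i] for t in range(len(h_q))) for i in range(d)]
        # h_j^k = h_j^{k-1} (element-wise) beta_j: gate the paragraph token
        out.append([x * b for x, b in zip(h_j, beta)])
    return out
```

With a single question token, $\alpha_j = 1$ and the gate is simply that token's vector, so `gated_attention([[1.0, 2.0]], [[0.5, 2.0]])` yields `[[0.5, 4.0]]`.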
Since the answer $a$ is a sequence of consecutive word tokens in the paragraph $p$, we apply two softmax layers on top of $H_p^K$ to predict the start and end indices of $a$, following Yang et al. (2017).
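The text does not say how the two softmax outputs are turned into a single span at inference time. A common heuristic, assumed here, is to pick the pair $(i, j)$ with $i \le j$ that maximizes $P_{\text{start}}(i)\,P_{\text{end}}(j)$, optionally capping the span length; `max_len` is a hypothetical parameter.

```python
from typing import List, Tuple

def best_span(p_start: List[float], p_end: List[float], max_len: int = 30) -> Tuple[int, int]:
    """Return (i, j) maximizing p_start[i] * p_end[j] subject to i <= j < i + max_len."""
    best, best_score = (0, 0), -1.0
    for i, ps in enumerate(p_start):
        # only consider end positions at or after the start, within the length cap
        for j in range(i, min(len(p_end), i + max_len)):
            score = ps * p_end[j]
            if score > best_score:
                best, best_score = (i, j), score
    return best
```

The constraint matters: independently taking the argmax of each distribution can produce an end index before the start index.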
We will train our discriminative model on both model-generated question-answer pairs and human-generated pairs. However, even a well-trained generative model will produce questions somewhat different from human-generated ones. Learning from both human-generated data and model-generated data can thus lead to a biased model. To alleviate this issue, we propose to view the model-generated data distribution and the human-generated data distribution as two different data domains and explicitly incorporate domain adaptation into the discriminative model.
More specifically, we use a domain tag as an additional input to the discriminative model. We use the tag “d_true” to represent the domain of human-generated data (i.e., the true data), and “d_gen” for the domain of model-generated data. Following a practice in domain adaptation Johnson et al. (2016); Chu et al. (2017), we append the domain tag to the end of both the questions and the paragraphs. By introducing the domain tags, we expect the discriminative model to factor out domain-specific and domain-invariant representations. At test time, the tag “d_true” is appended.
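Appending a domain tag is a pure data transformation. The sketch below assumes tokenized inputs and answer spans stored as token indices; because the tag is appended at the end, existing answer indices stay valid. The helper names are hypothetical.

```python
from typing import List, Tuple

D_TRUE, D_GEN = "d_true", "d_gen"

def tag_example(paragraph: List[str], question: List[str], tag: str) -> Tuple[List[str], List[str]]:
    """Append the domain tag token to the end of both the paragraph and the question."""
    return paragraph + [tag], question + [tag]

def build_training_set(human, generated):
    """human / generated: lists of (paragraph, question, answer_span) triples.

    Human-labeled pairs get "d_true"; model-generated pairs get "d_gen".
    Answer spans are unchanged because the tag is appended after the last token.
    """
    data = []
    for p, q, a in human:
        data.append((*tag_example(p, q, D_TRUE), a))
    for p, q, a in generated:
        data.append((*tag_example(p, q, D_GEN), a))
    return data
```

At test time every input would be passed through `tag_example(..., D_TRUE)`.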
Generative Model
The generative model $G$ learns the conditional probability $P_G(q \mid p, a)$ of a question given the paragraph and the answer, and consists of an encoder and a decoder. The encoder is a GRU that encodes the input paragraph $p$ into a sequence of hidden states $H$. We inject the answer information by appending an additional zero/one feature to the word embeddings of the paragraph tokens; i.e., if a word token appears in the answer, the feature is set to one, otherwise to zero.
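The answer-indicator feature can be sketched as an extra embedding dimension. This assumes "appears in the answer" means the token's position falls inside the answer span (rather than any occurrence of the same word type); embeddings are plain lists here for illustration.

```python
from typing import List

def add_answer_feature(embeddings: List[List[float]], start: int, end: int) -> List[List[float]]:
    """Append a 0/1 answer-indicator feature to each paragraph token embedding.

    Tokens whose index lies in [start, end] (inclusive) get 1.0, all others 0.0.
    """
    return [vec + [1.0 if start <= i <= end else 0.0] for i, vec in enumerate(embeddings)]
```

The encoder input dimensionality therefore grows by one relative to the word embeddings.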
The decoder is another GRU with an attention mechanism over the encoder hidden states $H$. At each time step $t$, the generation probabilities over all word types are defined with a copy mechanism:

$$p_{\text{overall}} = g_t \, p_{\text{vocab}} + (1 - g_t)\, p_{\text{copy}},$$

where $p_{\text{vocab}}$ is the probability of generating the token from the vocabulary, while $p_{\text{copy}}$ is the probability of copying a token from the paragraph. The probability $g_t$ is computed based on the current decoder hidden state $\mathbf{s}_t$:

$$g_t = \sigma(\mathbf{w}_g^\top \mathbf{s}_t),$$

where $\sigma$ denotes the logistic function and $\mathbf{w}_g$ is a vector of model parameters. The generation probabilities $p_{\text{vocab}}$ are defined as a softmax function over the word types in the vocabulary, and the copying probabilities $p_{\text{copy}}$ are defined as a softmax function over the word types in the paragraph. Both $p_{\text{vocab}}$ and $p_{\text{copy}}$ are defined as functions of the current hidden state $\mathbf{s}_t$ and the attention results Gu et al. (2016).
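The mixture of generation and copy distributions can be sketched with dictionaries keyed by word type. Two assumptions: the gate is passed in as a precomputed logit standing in for $\mathbf{w}_g^\top \mathbf{s}_t$, and the copy distribution is given per paragraph position and summed over positions sharing a word type.

```python
import math
from collections import defaultdict
from typing import Dict, List

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def mix_distributions(gate_logit: float,
                      p_vocab: Dict[str, float],
                      copy_attention: List[float],
                      paragraph: List[str]) -> Dict[str, float]:
    """Return g * p_vocab + (1 - g) * p_copy over the union of word types."""
    g = sigmoid(gate_logit)  # g_t = sigma(w_g^T s_t), here from a precomputed logit
    # p_copy over word types: positions holding the same word share their mass
    p_copy: Dict[str, float] = defaultdict(float)
    for word, prob in zip(paragraph, copy_attention):
        p_copy[word] += prob
    overall: Dict[str, float] = defaultdict(float)
    for word, prob in p_vocab.items():
        overall[word] += g * prob
    for word, prob in p_copy.items():
        overall[word] += (1.0 - g) * prob
    return dict(overall)
```

Because both inputs are distributions and the weights $g_t$ and $1 - g_t$ sum to one, the result is again a distribution; paragraph words outside the vocabulary (which would otherwise become "UNK") can still receive probability through the copy term.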
Training Algorithm
We first define the objective function of the GDANs, and then present an algorithm to optimize the given objective function. Similar to the Generative Adversarial Nets (GANs) Goodfellow et al. (2014) and adversarial domain adaptation Ganin and Lempitsky (2014), the discriminative model and the generative model have different objectives in our framework. However, rather than formulating the objective as an adversarial game between the two models Goodfellow et al. (2014); Ganin and Lempitsky (2014), in our framework, the discriminative model relies on the data generated by the generative model, while the generative model aims to match the model-generated data distribution with the human-generated data distribution using the signals from the discriminative model.
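Given a dataset $\mathcal{T}$ of (paragraph, question, answer) triples and a domain tag, we define the objective of the discriminative model $D$ as the average log-likelihood

$$J(\mathcal{T}, \text{tag}, D) = \frac{1}{|\mathcal{T}|} \sum_{(p, q, a) \in \mathcal{T}} \log P_D(a \mid p, q, \text{tag}),$$

meaning that the domain tag, “tag”, is appended to the paragraphs and questions in $\mathcal{T}$. We use $|\mathcal{T}|$ to denote the number of instances in $\mathcal{T}$. The objective function is averaged over all instances so that labeled and unlabeled data are balanced.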
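The objective $J$ is an average log-likelihood, which is what balances a small labeled set against a much larger generated set. In this sketch the discriminative model is abstracted as a callable returning $\log P_D(a \mid p, q, \text{tag})$; the stub used below is hypothetical and only illustrates the averaging.

```python
import math
from typing import Callable, List, Sequence, Tuple

Example = Tuple[List[str], List[str], Tuple[int, int]]  # (paragraph, question, answer span)
LogProbD = Callable[[List[str], List[str], Tuple[int, int]], float]

def objective_j(dataset: Sequence[Example], tag: str, log_prob_d: LogProbD) -> float:
    """J(T, tag, D): mean of log P_D(a | p, q, tag) with `tag` appended to p and q."""
    total = 0.0
    for paragraph, question, answer in dataset:
        total += log_prob_d(paragraph + [tag], question + [tag], answer)
    return total / len(dataset)

# Hypothetical stand-in for D: confident on human-tagged inputs, less so otherwise.
stub_d: LogProbD = lambda p, q, a: math.log(0.5) if p[-1] == "d_true" else math.log(0.25)
```

Because each term is a mean, $J(L, \text{d\_true}, D)$ computed on 2 examples and $J(U_G, \text{d\_gen}, D)$ computed on 100 examples enter the discriminator's objective on the same scale.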
Let $U_G$ denote the dataset obtained by generating questions on the unlabeled dataset $U$ with the generative model $G$. The objective of the discriminative model is then to maximize $J$ for both labeled and unlabeled data under the domain adaptation notions, i.e., $J(L, \text{d\_true}, D) + J(U_G, \text{d\_gen}, D)$.
Now we discuss the objective of the generative model. Similar to the dual learning framework Xia et al. (2016), one can define an auto-encoder objective. In this case, the generative model aims to generate questions that can be reconstructed by the discriminative model, i.e., maximizing $J(U_G, \text{d\_gen}, D)$. However, this objective function can lead to degenerate solutions because the questions can be thought of as an overcomplete representation of the answers Vincent et al. (2010). For example, given $p$ and $a$, the generative model might learn to generate trivial questions such as copying the answers, which does not contribute to learning a better $D$.
Instead, we leverage the discriminative model to better match the model-generated data distribution with the human-generated data distribution. We propose to define an adversarial training objective $J(U_G, \text{d\_true}, D)$: we append the tag “d_true” instead of “d_gen” to the model-generated data to “fool” the discriminative model. Intuitively, the goal of $G$ is to generate “useful” questions, where usefulness is measured by the probability that the generated questions can be answered correctly by $D$.
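The overall objective function can now be written as

$$\max_D \; J(L, \text{d\_true}, D) + J(U_G, \text{d\_gen}, D), \qquad \max_G \; J(U_G, \text{d\_true}, D).$$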
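With the above objective function in mind, we present a training algorithm in Algorithm 1 to train a GDAN. We first pretrain the generative model on the labeled data with maximum likelihood estimation (MLE):

$$\max_G \; \sum_{i=1}^{N} \sum_{t=1}^{|q^{(i)}|} \log P_G\big(q_t^{(i)} \mid q_{<t}^{(i)}, p^{(i)}, a^{(i)}\big).$$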
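We then alternately update $D$ and $G$ based on their objectives. To update $D$, we sample one batch from the labeled data $L$ and one batch from the unlabeled data $U_G$, and combine the two batches to perform a gradient update step. Since the output of $G$ is discrete and non-differentiable, we use the Reinforce algorithm Williams (1992) to update $G$. The action space is all possible questions of length $T'$ (possibly with padding), and the reward is the objective function $J(U_G, \text{d\_true}, D)$. Let $\theta_G$ be the parameters of $G$. The gradient can be written as

$$\nabla_{\theta_G} J(U_G, \text{d\_true}, D) = \frac{1}{|U|} \sum_{(p, a) \in U} \mathbb{E}_{q \sim P_G(\cdot \mid p, a)} \Big[ \log P_D(a \mid p, q, \text{d\_true}) \, \nabla_{\theta_G} \log P_G(q \mid p, a) \Big].$$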
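The Reinforce estimator can be sanity-checked in a toy setting. Here $G$ is replaced by a hypothetical softmax policy over a fixed set of three candidate questions, and each candidate's reward stands in for $\log P_D(a \mid p, q, \text{d\_true})$. For a softmax policy, $\nabla_\theta \log \pi(q) = \text{onehot}(q) - \pi$, so the Monte Carlo estimate should approach the exact gradient $\pi_k (r_k - \mathbb{E}[r])$. No variance-reducing baseline is used, since the text does not mention one.

```python
import math
import random
from typing import List

def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def reinforce_gradient(logits: List[float], rewards: List[float],
                       num_samples: int, rng: random.Random) -> List[float]:
    """Score-function estimate of grad_theta E_{q ~ softmax(theta)}[reward(q)]."""
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for _ in range(num_samples):
        q = rng.choices(range(len(probs)), weights=probs)[0]  # sample a "question"
        for k in range(len(logits)):
            indicator = 1.0 if k == q else 0.0
            grad[k] += rewards[q] * (indicator - probs[k])  # reward * grad log pi(q)
    return [g / num_samples for g in grad]

def exact_gradient(logits: List[float], rewards: List[float]) -> List[float]:
    """Closed form for a softmax policy: pi_k * (r_k - E[r])."""
    probs = softmax(logits)
    expected = sum(p * r for p, r in zip(probs, rewards))
    return [p * (r - expected) for p, r in zip(probs, rewards)]
```

With rewards $\log 0.9$, $\log 0.5$, $\log 0.1$ and uniform initial logits, the gradient increases the logit of the question that $D$ answers most confidently and decreases the logit of the worst one.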
MLE vs RL. The generator has two training phases, MLE training and RL training, which differ in two ways: 1) RL training does not require labels, so $G$ can explore a broader data domain of $p$ using unlabeled data, while MLE training requires labels; 2) MLE maximizes $\log P(q \mid p, a)$, while RL maximizes $\log P_D(a \mid p, q)$. Since $\log P(q \mid p, a)$ is the sum of $\log P(a \mid p, q)$ and $\log P(q \mid p)$ (plus a term that does not depend on $q$), maximizing $\log P_D(a \mid p, q)$ does not require modeling $\log P(q \mid p)$, which is irrelevant to QA; this makes optimization easier. Moreover, maximizing $\log P_D(a \mid p, q)$ is consistent with the goal of QA.
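The decomposition behind this argument follows from Bayes' rule. For a fixed paragraph $p$ and answer $a$:

$$
\begin{aligned}
\log P(q \mid p, a) &= \log \frac{P(a \mid p, q)\, P(q \mid p)}{P(a \mid p)} \\
&= \log P(a \mid p, q) + \log P(q \mid p) - \log P(a \mid p).
\end{aligned}
$$

The last term does not depend on the generated question $q$, so MLE implicitly optimizes both the QA term $\log P(a \mid p, q)$ and the question language-model term $\log P(q \mid p)$, whereas the RL reward keeps only the former.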
Experiments
As discussed in Section 2, our model assumes that answers are available for unlabeled data. In this section, we introduce how we use linguistic tags and rules to extract answer chunks from unlabeled text.
To extract answers from a large collection of unlabeled Wikipedia articles, we first sample 205,511 Wikipedia articles that are not used in the training, development, or test sets of the SQuAD dataset. We extract the paragraphs from each article and limit the length of each paragraph to fewer than 850 words. In total, we obtain 950,612 paragraphs from the unlabeled articles.
Answers in the SQuAD dataset can be categorized into ten types, i.e., “Date”, “Other Numeric”, “Person”, “Location”, “Other Entity”, “Common Noun Phrase”, “Adjective Phrase”, “Verb Phrase”, “Clause”, and “Other” Rajpurkar et al. (2016). For each paragraph from the unlabeled articles, we use the Stanford Part-Of-Speech (POS) tagger Toutanova et al. (2003) to label each word with its POS tag, and implement a simple constituency parser to extract noun phrases, verb phrases, adjective phrases, and clauses based on a small set of constituency grammar rules. Next, we use the Stanford Named Entity Recognizer (NER) Finkel et al. (2005) to assign each word one of seven labels, i.e., “Date”, “Money”, “Percent”, “Location”, “Organization”, “Person”, and “Time”. We then categorize a span of consecutive words with the same NER tag of either “Money” or “Percent” as an answer of the type “Other Numeric”. Similarly, we categorize a span of consecutive words with the NER tag “Organization” as an answer of the type “Other Entity”. Finally, for each paragraph we subsample five answers from all the extracted answers according to the distribution of answer types in the SQuAD dataset. We obtain 4,753,060 answers in total, which is about 50 times the number of answers in the SQuAD dataset.
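The NER-based part of this pipeline (grouping consecutive identical tags into spans, mapping them to answer types, then subsampling by type) can be sketched as below. Only the Money/Percent and Organization mappings come from the text; the Date, Person, and Location mappings, the weighted-sampling-without-replacement rule, and all helper names are assumptions. The type weights in the usage example are placeholders, not SQuAD's actual type distribution.

```python
import random
from typing import Dict, List, Optional, Tuple

# MONEY/PERCENT and ORGANIZATION follow the text; the other three are assumed.
NER_TO_TYPE: Dict[str, str] = {
    "MONEY": "Other Numeric",
    "PERCENT": "Other Numeric",
    "ORGANIZATION": "Other Entity",
    "DATE": "Date",          # assumption
    "PERSON": "Person",      # assumption
    "LOCATION": "Location",  # assumption
}

Span = Tuple[int, int, str]  # (start, end inclusive, label)

def ner_spans(tags: List[str]) -> List[Span]:
    """Group maximal runs of consecutive tokens that share the same non-'O' NER tag."""
    spans: List[Span] = []
    i = 0
    while i < len(tags):
        if tags[i] == "O":
            i += 1
            continue
        j = i
        while j + 1 < len(tags) and tags[j + 1] == tags[i]:
            j += 1
        spans.append((i, j, tags[i]))
        i = j + 1
    return spans

def typed_candidates(tags: List[str]) -> List[Span]:
    """Keep NER spans whose label maps to a SQuAD answer type."""
    result: List[Span] = []
    for start, end, label in ner_spans(tags):
        answer_type: Optional[str] = NER_TO_TYPE.get(label)
        if answer_type is not None:
            result.append((start, end, answer_type))
    return result

def subsample(candidates: List[Span], type_weights: Dict[str, float],
              k: int, rng: random.Random) -> List[Span]:
    """Draw up to k distinct candidates, weighting each by its answer type's share."""
    pool = list(candidates)
    chosen: List[Span] = []
    while pool and len(chosen) < k:
        weights = [type_weights.get(span[2], 0.0) for span in pool]
        if sum(weights) <= 0:
            break  # only zero-weight types remain
        pick = rng.choices(range(len(pool)), weights=weights)[0]
        chosen.append(pool.pop(pick))
    return chosen
```

For the tags `["ORGANIZATION", "O", "MONEY", "MONEY", "MONEY", "O", "DATE", "TIME"]`, the candidates are an "Other Entity" span at token 0, an "Other Numeric" span over tokens 2 to 4, and a "Date" span at token 6; the "TIME" span has no mapping and is dropped.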
Settings and Comparison Methods
The original SQuAD dataset consists of 87,636 training instances and 10,600 development instances. Since the test set is not publicly available, we split off 10% of the training set as the test set, and the remaining 90% serves as the actual training set. Instances are split based on articles; i.e., all paragraphs of an article appear in the same set. We tune the hyper-parameters and perform early stopping on the development set using the F1 scores, and the performance is evaluated on the test set using both F1 scores and exact match (EM) scores Rajpurkar et al. (2016).
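An article-level split can be sketched as follows. The sketch assumes each instance carries a hypothetical `article` field and that the 10% is taken over articles rather than over instances; the text does not specify which, nor the random seed.

```python
import random
from typing import List, Tuple

def split_by_article(instances: List[dict], test_fraction: float = 0.1,
                     seed: int = 0) -> Tuple[List[dict], List[dict]]:
    """Split instances so that all paragraphs of one article land in the same set."""
    articles = sorted({inst["article"] for inst in instances})  # sort for determinism
    rng = random.Random(seed)
    rng.shuffle(articles)
    n_test = round(len(articles) * test_fraction)
    test_articles = set(articles[:n_test])
    train = [x for x in instances if x["article"] not in test_articles]
    test = [x for x in instances if x["article"] in test_articles]
    return train, test
```

Splitting by article rather than by instance prevents questions about the same paragraph from leaking between the training and test sets.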
We compare the following methods. SL is the supervised learning setting where we train the model solely on the labeled data $L$. Context is the simple context-based method described in Section 2.1. Context + domain is the “Context” method with domain tags as described in Section 3.1.1. Gen trains a generative model and uses the generated questions as additional training data. Gen + GAN refers to the domain adaptation method using GANs Ganin and Lempitsky (2014); in contrast to the original work, the generative model is updated using Reinforce. Gen + dual refers to the dual learning method Xia et al. (2016). Gen + domain is “Gen” with domain tags, where the generative model is trained with MLE and then fixed. Gen + domain + adv is the approach we propose (cf. Figure 1 and Algorithm 1), with “adv” meaning adversarial training based on Reinforce. We use our own implementations of “Gen + GAN” and “Gen + dual”, since the GAN model Ganin and Lempitsky (2014) does not handle discrete features and the dual learning model Xia et al. (2016) cannot be directly applied to question answering. When implementing these two baselines, we adopt the learning schedule introduced by Ganin and Lempitsky (2014), i.e., gradually increasing the weights of the gradients for the generative model $G$.
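The schedule from Ganin and Lempitsky (2014) ramps the weight smoothly from 0 toward 1 as a function of training progress. The sketch below uses their formula $\lambda_p = \frac{2}{1 + \exp(-\gamma p)} - 1$ with $\gamma = 10$, the value used in their paper; the exact constant used for the baselines here is not stated, so it is an assumption.

```python
import math

def gradient_weight(progress: float, gamma: float = 10.0) -> float:
    """Ganin-Lempitsky style schedule: 2 / (1 + exp(-gamma * p)) - 1.

    `progress` is the fraction of training completed, in [0, 1]. The weight
    starts at 0 and rises smoothly toward 1, so early, noisy signals barely
    move the generative model.
    """
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0
```

At the start of training the weight is exactly 0, and by the end it is about 0.9999.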
Results and Analysis
We study the performance of different models with varying labeling rates and unlabeled dataset sizes. The labeling rate is the fraction of training instances used to train $D$. The results are reported in Table 2. Although the unlabeled dataset we collect consists of around 5 million instances, we also sample a subset of around 50,000 instances to evaluate the effect of the size of the unlabeled data. The highest labeling rate in Table 2 is 0.9 because 10% of the training instances are used for testing. Since we perform early stopping on the development set using the F1 scores, we also report the development F1. On the test set, we report two metrics, the F1 score and the exact match (EM) score Rajpurkar et al. (2016). All metrics are computed using the official evaluation scripts.
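For reference, the F1 and EM metrics can be sketched as below. This re-implements, from memory, the normalization and token-overlap logic of the official SQuAD v1.1 evaluation script (lowercasing, stripping punctuation and the articles "a", "an", "the", collapsing whitespace, then taking the maximum over all ground-truth answers); the official script itself should be used for reported numbers.

```python
import re
import string
from collections import Counter
from typing import Callable, List

PUNCT = set(string.punctuation)

def normalize_answer(text: str) -> str:
    """Lowercase, strip punctuation, drop English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction: str, ground_truth: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(ground_truth)

def f1(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between the normalized prediction and ground truth."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)  # multiset intersection
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

def max_over_ground_truths(metric: Callable[[str, str], float],
                           prediction: str, ground_truths: List[str]) -> float:
    """SQuAD scores each prediction against its best-matching reference answer."""
    return max(metric(prediction, gt) for gt in ground_truths)
```

For example, "Eiffel Tower in Paris" against "the Eiffel Tower" has precision 2/4 and recall 2/2, giving F1 = 2/3 but EM = 0.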
SL vs. SSL. We observe that semi-supervised learning leads to consistent improvements over supervised learning in all cases. The improvements are substantial when labeled data is limited. For example, the GDANs improve over supervised learning by 9.87 points in F1 and 7.26 points in EM when the labeling rate is 0.1. With our semi-supervised learning approach, we can outperform a supervised learning approach that uses more than twice as many labeled training instances, which saves more than half of the labeling costs.
Comparison with Baselines. By comparing “Gen + domain + adv” with “Gen + GAN” and “Gen + dual”, it is clear that the GDANs perform substantially better than GANs and dual learning. With labeling rate , GDANs outperform dual learning and GANs by 2.47 and 4.29 points, respectively, in terms of F1.
Ablation Study. We also perform an ablation study by examining the effects of “domain” and “adv” when added to “Gen”. Both the domain tags and the adversarial training contribute to the performance of the GDANs when the labeling rate is equal to or less than . With labeling rate , adding domain tags still leads to better performance, but adversarial training does not seem to improve the performance by much.
Unlabeled Data Size. Moreover, we observe that performance can be further improved when a larger unlabeled dataset is used, though this gain is smaller than the gains from changing the model architecture. For example, when the unlabeled dataset size increases from 50K to 5M, the performance of GDANs increases by 0.38 points in F1 and 0.52 points in EM.
Context-Based Method. Surprisingly, the simple context-based method, though performing worse than GDANs, still leads to substantial gains; e.g., 7.00 points in F1 with labeling rate . Adding domain tags can improve the performance of the context-based method as well.
MLE vs RL. We plot the loss curve of $-\log P_D(a \mid p, q)$ for both the MLE-trained generator (“Gen + domain”) and the RL-trained generator (“Gen + domain + adv”) in Figure 2. We observe that the training loss of $D$ on RL-generated questions is lower than on MLE-generated questions, which confirms that RL training maximizes $\log P_D(a \mid p, q)$.
Samples of Generated Questions. We present some questions generated by our model in Table 1. The generated questions are post-processed by removing repeated subsequences. Compared to MLE-generated questions, RL-generated questions are more informative (cf. P1, P2, and P4) and contain fewer “UNK” (unknown) tokens (cf. P1). Moreover, RL-generated questions are more accurate both semantically and syntactically (cf. P3 and P5).
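The text does not define the post-processing rule precisely. One plausible reading, sketched here as a hypothetical helper, is to collapse immediately repeated n-grams (up to a small `max_n`) in the decoded token sequence. Note that this would also collapse legitimate repetitions such as "had had".

```python
from typing import List

def remove_repeats(tokens: List[str], max_n: int = 4) -> List[str]:
    """Collapse immediately repeated n-grams (n <= max_n), e.g. 'which year which year' -> 'which year'."""
    out: List[str] = []
    for tok in tokens:
        out.append(tok)
        # Keep trimming while the sequence ends with an n-gram equal to the one right before it.
        changed = True
        while changed:
            changed = False
            for n in range(1, max_n + 1):
                if len(out) >= 2 * n and out[-n:] == out[-2 * n:-n]:
                    del out[-n:]
                    changed = True
                    break
    return out
```

For example, "who built the the tower tower" becomes "who built the tower", and "in which year which year was it built" becomes "in which year was it built".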
Related Work
Semi-Supervised Learning. Semi-supervised learning has been extensively studied in the literature Zhu (2005). A number of novel models have recently been proposed for semi-supervised learning based on representation learning techniques, such as generative models Kingma et al. (2014), ladder networks Rasmus et al. (2015), and graph embeddings Yang et al. (2016a). However, most semi-supervised learning methods are based on combinations of a supervised loss and an unsupervised loss on the unlabeled data. In the context of reading comprehension, directly modeling the likelihood of a paragraph is unlikely to improve the supervised task of question answering. Moreover, traditional graph-based semi-supervised learning Zhu and Ghahramani (2002) cannot be easily extended to modeling the unlabeled answer chunks.
Domain Adaptation. Domain adaptation has been successfully applied to various tasks, such as classification Ganin and Lempitsky (2014) and machine translation Johnson et al. (2016); Chu et al. (2017). Several domain adaptation techniques Glorot et al. (2011) focus on learning distribution-invariant features by sharing the intermediate representations for downstream tasks. Another line of research on domain adaptation attempts to minimize the distance between different domain distributions in a low-dimensional space Long et al. (2015); Baktashmotlagh et al. (2013). There are also methods seeking a domain transition from the source domain to the target domain Gong et al. (2012); Gopalan et al. (2011); Pan et al. (2011). Our work draws inspiration from the practice of appending domain tags in Johnson et al. (2016) and Chu et al. (2017). However, our method differs from the above methods in that we apply domain adaptation techniques to the outputs of a generative model rather than to a natural data domain.
Question Answering. Various neural models based on attention mechanisms Wang and Jiang (2016); Seo et al. (2016); Xiong et al. (2016); Wang et al. (2016); Dhingra et al. (2016); Kadlec et al. (2016); Trischler et al. (2016b); Sordoni et al. (2016); Cui et al. (2016); Chen et al. (2016) have been proposed to tackle the tasks of question answering and reading comprehension. However, the performance of these neural models relies heavily on the availability of large amounts of labeled training data.
Learning with Multiple Models. GANs Goodfellow et al. (2014) formulated an adversarial game between a discriminative model and a generative model for generating realistic images. Ganin and Lempitsky (2014) employed a similar idea, using two models for domain adaptation. Review networks Yang et al. (2016b) employ a discriminative model as a regularizer for training a generative model. In the context of machine translation, several recent works jointly train models to learn the mappings in both directions of a language pair Tu et al. (2016); Xia et al. (2016).
Conclusions
We study a critical and challenging problem, semi-supervised question answering. We propose a novel neural framework called Generative Domain-Adaptive Nets, which incorporate domain adaptation techniques in combination with generative models for semi-supervised learning. Empirically, we show that our approach leads to substantial improvements over supervised learning models and outperforms several strong baselines including GANs and dual learning. In the future, we plan to apply our approach to more question answering datasets in different domains. It will also be intriguing to generalize GDANs to other applications.
Acknowledgements. This work was funded by the Office of Naval Research grants N000141512791 and N000141310721 and NVIDIA.