Adversarial Training for Large Neural Language Models
Xiaodong Liu, Hao Cheng, Pengcheng He, Weizhu Chen, Yu Wang, Hoifung Poon, Jianfeng Gao
Introduction
Generalization and robustness are two fundamental considerations in assessing machine learning methods. Ideally, a learned model should perform well on unseen test examples and withstand adversarial attacks. In natural language processing (NLP), pre-training neural language models on unlabeled text has proven very effective at improving generalization performance for a variety of downstream tasks, as exemplified by BERT Devlin et al. (2018) and other transformer-based models Liu et al. (2019c); Radford et al. (2018); Clark et al. (2020); Dong et al. (2019); Bao et al. (2020). However, these models may still suffer catastrophic failures in adversarial scenarios Nie et al. (2019); Hsieh et al. (2019). For example, Jin et al. (2019) show that the classification accuracy of a BERT model on a Yelp dataset drops from 95.6% on the standard test set to 6.8% on the robust test set.
Adversarial training Madry et al. (2017); Goodfellow et al. (2014) has been well studied in computer vision, but past work shows that it often hurts generalization Raghunathan et al. (2019); Min et al. (2020). In NLP, there is growing interest in adversarial training, but existing work typically focuses on assessing the impact on generalization Zhu et al. (2019); Jiang et al. (2019); Cheng et al. (2019); Wang et al. (2019). Moreover, adversarial training is generally limited to task-specific fine-tuning (a notable exception is Wang et al. (2019), but it only applied adversarial training to generative language modeling). See Minaee et al. (2020a) for a recent survey.
In this paper, we present the first comprehensive study on adversarial pre-training, and show that it can improve both generalization and robustness for a wide range of NLP tasks. We propose a unifying algorithm ALUM (Adversarial training for large neural LangUage Models), which augments the standard training objective with an additional term that maximizes the adversarial loss by applying perturbations in the embedding space. ALUM is generally applicable to pre-training and fine-tuning, on top of any Transformer-based language model.
We conduct a comprehensive evaluation on various NLP tasks across multiple benchmark datasets, including GLUE, SQuAD v1.1/v2.0, SNLI, and SciTail for assessing model generalization, and ANLI, HELLASWAG, SWAG, and Adversarial SQuAD for assessing model robustness. Experimental results show that by conducting adversarial pre-training, ALUM attains significant improvements, often outperforming the previous state of the art by a large margin. This is true even for the extremely well-trained RoBERTa model, where continual pre-training without adversarial training fails to attain any gain.
Remarkably, in addition to improving generalization, we find that adversarial pre-training also substantially improves robustness, as exemplified by the large gains on adversarial datasets such as ANLI, Adversarial SQuAD, and HELLASWAG, which significantly reduce the gap between standard errors and robust errors for popular models like BERT and RoBERTa. This suggests that adversarial training on unlabeled data can provide a promising direction to reconcile the apparent conflict between generalization and robustness observed in prior work Raghunathan et al. (2019); Min et al. (2020). We also show that adversarial pre-training can be combined with adversarial fine-tuning, resulting in extra gains.
Our contributions are summarized as follows:
We propose ALUM, a general algorithm to incorporate adversarial training for pre-training and fine-tuning large neural language models.
We conduct a comprehensive evaluation on a wide range of NLP tasks and assess the impact of adversarial training in pre-training from scratch, continual pre-training, task-specific fine-tuning, and their combinations.
We obtain significant improvements over prior state of the art, including extremely well-trained models such as RoBERTa, in both generalization and robustness.
To facilitate research, we will release our code and pre-trained models.
Preliminary
In this section, we give a quick overview of language model pre-training, using BERT Devlin et al. (2018) as a running example for transformer-based neural language models.
We assume that the input consists of text spans (typically sentences) separated by a special token [SEP]. To address the problem of out-of-vocabulary words, tokens are divided into subword units using Byte-Pair Encoding (BPE) Sennrich et al. (2015) or its variants Kudo and Richardson (2018), which generate a fixed-size subword vocabulary to compactly represent words in training text corpora.
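The subword idea can be illustrated with a minimal sketch of the BPE merge-learning loop in the spirit of Sennrich et al. (2015). The toy word-frequency table and the function name `learn_bpe` are hypothetical; production tokenizers (and variants such as Kudo and Richardson (2018)) differ in details such as end-of-word markers, tie-breaking, and byte-level fallback.

```python
from collections import Counter


def learn_bpe(word_freqs: dict[str, int], num_merges: int) -> list[tuple[str, str]]:
    """Learn up to num_merges BPE merges from a {word: count} table (toy sketch)."""
    # Each word starts as a tuple of single characters.
    vocab = {tuple(word): freq for word, freq in word_freqs.items()}
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency.
        pairs = Counter()
        for symbols, freq in vocab.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # ties go to the first pair encountered
        merges.append(best)
        # Rewrite every word, fusing each occurrence of the best pair.
        new_vocab = {}
        for symbols, freq in vocab.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            new_vocab[tuple(out)] = new_vocab.get(tuple(out), 0) + freq
        vocab = new_vocab
    return merges
```

On the toy table {low: 5, lower: 2, newest: 6, widest: 3}, the first two merges are `('e', 's')` and then `('es', 't')`, so the frequent suffix "est" becomes a single subword.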
2 Model Architecture
Following recent pre-training methods Devlin et al. (2018); Liu et al. (2019c), we use transformer-based models Vaswani et al. (2017) to leverage a multi-head attention mechanism, which has demonstrated superiority in parallel computation and in modeling long-range dependencies compared to recurrent neural networks such as LSTM Hochreiter and Schmidhuber (1997). The input is first passed to a lexical encoder, which combines a token embedding, a (token) position embedding, and a segment embedding (i.e., which text span the token belongs to) by element-wise summation. The output of the embedding layer is then passed to multiple layers of transformer modules to generate the contextual representation Vaswani et al. (2017).
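As a sketch of the lexical encoder, the snippet below sums token, position, and segment embeddings element-wise. The table sizes, the 0.02 initialization scale, and the helper names are illustrative assumptions, and the sketch omits the layer normalization and dropout that BERT implementations typically apply after the sum.

```python
import random


def make_table(rows: int, dim: int, seed: int) -> list[list[float]]:
    """Hypothetical embedding table (rows x dim) with small Gaussian initialization."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(rows)]


def lexical_encoder(token_ids, segment_ids, tok_emb, pos_emb, seg_emb):
    """Element-wise sum of token, position, and segment embeddings at each position."""
    if len(token_ids) != len(segment_ids):
        raise ValueError("token_ids and segment_ids must have the same length")
    return [
        [t + p + s for t, p, s in zip(tok_emb[tok], pos_emb[pos], seg_emb[seg])]
        for pos, (tok, seg) in enumerate(zip(token_ids, segment_ids))
    ]
```

Each output vector has the embedding dimension and is what the stack of transformer layers consumes.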
3 Self Supervision
A key innovation in BERT Devlin et al. (2018) is the use of the Masked Language Model (MLM) for self-supervised pre-training. Instead of predicting the next token based on the preceding tokens, as in traditional generative language models, MLM randomly replaces a subset of tokens with a special token (e.g., [MASK]) and asks the model to predict them. Essentially, it is a cloze task Taylor (1953), where the training objective is the cross-entropy loss between the original tokens and the predicted ones. In BERT and RoBERTa, 15% of the input tokens are chosen, among which a random 80% are replaced by [MASK], 10% are left unchanged, and 10% are randomly replaced by a token from the vocabulary. In our experiments, instead of using a fixed masking rate of 15%, we gradually increase it from 5% to 25% in 5% increments, one increment after every 20% of training epochs, as we find this makes pre-training more stable.
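The masking procedure and the increasing masking-rate schedule can be sketched as follows. This is a minimal illustration, assuming the schedule is indexed by training step rather than epoch and that an exact number of positions is chosen per sequence (the 80/10/10 split then holds in expectation); `masking_rate` and `mask_tokens` are hypothetical helpers, not code from BERT, RoBERTa, or Megatron.

```python
import random


def masking_rate(step: int, total_steps: int) -> float:
    """5% masking at the start, +5% after every 20% of training, capped at 25%."""
    stage = min(5 * step // total_steps, 4)  # integer arithmetic avoids float boundary errors
    return 5 * (stage + 1) / 100


def mask_tokens(tokens, vocab, rate, rng, mask_token="[MASK]"):
    """BERT-style corruption. Returns (corrupted tokens, {position: original token})."""
    n_chosen = max(1, round(rate * len(tokens)))
    chosen = rng.sample(range(len(tokens)), n_chosen)
    corrupted = list(tokens)
    targets = {}
    for i in chosen:
        targets[i] = tokens[i]  # the MLM loss is computed only at chosen positions
        r = rng.random()
        if r < 0.8:
            corrupted[i] = mask_token  # 80%: replace with the mask token
        elif r < 0.9:
            pass  # 10%: leave the token unchanged
        else:
            corrupted[i] = rng.choice(vocab)  # 10%: random vocabulary token
    return corrupted, targets
```

For example, `masking_rate(600_000, 1_000_000)` returns 0.2, the fourth stage of the schedule.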
Additionally, BERT uses Next Sentence Prediction (NSP), a binary classification task that determines, for a given sentence pair, whether one sentence follows the other in the original text. There has been debate about how much NSP helps Liu et al. (2019c), but we include it in our experiments for a fair comparison with BERT.
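For reference, BERT builds NSP examples by pairing a sentence with its true successor half of the time (IsNext) and with a random corpus sentence otherwise (NotNext) Devlin et al. (2018). A hypothetical helper that mirrors this construction, as a sketch rather than BERT's data pipeline:

```python
import random


def make_nsp_example(doc_sentences, idx, corpus_sentences, rng):
    """Return (sentence_a, sentence_b, label) with label 1 = IsNext, 0 = NotNext."""
    sentence_a = doc_sentences[idx]
    if rng.random() < 0.5 and idx + 1 < len(doc_sentences):
        return sentence_a, doc_sentences[idx + 1], 1  # true continuation
    return sentence_a, rng.choice(corpus_sentences), 0  # random sentence from the corpus
```

This sketch does not guard against a random sentence that happens to equal the true successor, and it always yields NotNext for the last sentence of a document.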
ALUM (Adversarial training for large neural LangUage Models)
In this section, we first present a unifying view of standard training objectives and prior approaches to adversarial training. We then present ALUM, a general adversarial training algorithm applicable to pre-training and fine-tuning, on top of any transformer-based neural language model.
Both pre-training and fine-tuning can be viewed as minimizing the standard error on training data, with the training objectives derived from self-supervision (MLM and NSP in pre-training) or direct supervision (labeled examples in task-specific fine-tuning), respectively.
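Specifically, the training algorithms seek to learn a function $f(x;\theta): x \rightarrow C$, parametrized by $\theta$. In MLM, $C$ is the vocabulary, and $f(x;\theta)$ tries to predict the masked token $y$. In fine-tuning, $C$ is the task-specific label set, and $f(x;\theta)$ is the classifier. Given a training dataset $D$ of input-output pairs $(x, y)$ and the loss function $l(\cdot,\cdot)$ (e.g., cross entropy), $f(x;\theta)$ is trained to minimize the empirical risk:

$$\min_{\theta} \mathbb{E}_{(x,y)\sim D}\left[ l(f(x;\theta), y) \right] \tag{1}$$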
2 Adversarial Training
Pre-training a large neural language model such as BERT has proven effective to improve generalization performance in task-specific fine-tuning Devlin et al. (2018). However, such models can still suffer catastrophic loss in adversarial scenarios Nie et al. (2019); Hsieh et al. (2019); Madry et al. (2017); Jin et al. (2019), with attacks as simple as replacing a few words in input sentences while preserving the semantics.
To improve model robustness and withstand adversarial attacks, adversarial training has been proposed and studied extensively, predominantly in the computer vision literature Goodfellow et al. (2014); Madry et al. (2017). The key idea is to modify the training objective by applying a small perturbation $\delta$, with $\|\delta\| \le \epsilon$, to input images that maximizes the adversarial loss:

$$\min_{\theta} \mathbb{E}_{(x,y)\sim D}\left[ \max_{\|\delta\| \le \epsilon} l(f(x+\delta;\theta), y) \right] \tag{2}$$
where the inner maximization can be solved by running a number of projected gradient descent steps Madry et al. (2017).
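To make the inner maximization concrete, here is a minimal sketch of projected gradient ascent under an $\ell_\infty$ constraint for a toy logistic-regression model, where the gradient of the cross-entropy loss with respect to the perturbation has the closed form $(p - y)\,w$. The model, the sign-gradient step (the $\ell_\infty$ variant used by Madry et al. (2017)), and all function names are illustrative assumptions; in practice the gradient would come from automatic differentiation.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bce(p: float, y: int) -> float:
    """Binary cross-entropy loss l(p, y)."""
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def pgd_linf(x, y, w, b, eps, step, k):
    """Approximate argmax over ||delta||_inf <= eps of l(f(x + delta), y) with k ascent steps."""
    delta = [0.0] * len(x)
    for _ in range(k):
        p = sigmoid(sum(wi * (xi + di) for wi, xi, di in zip(w, x, delta)) + b)
        grad = [(p - y) * wi for wi in w]  # closed-form d l / d delta for this model
        # Ascend along the gradient sign, then project back onto the eps-ball by clipping.
        delta = [max(-eps, min(eps, di + step * sign(gi))) for di, gi in zip(delta, grad)]
    return delta
```

The outer minimization would then take a descent step on `bce` evaluated at `x + delta` instead of `x`.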
While adversarial training has been successful in mitigating adversarial attacks, past work often encounters an apparent conflict between generalization and robustness Raghunathan et al. (2019, 2020); Min et al. (2020), as adversarial training could hurt generalization performance.
3 The ALUM Algorithm
In NLP, applying adversarial training is not straightforward, since the inputs are discrete elements (token or subword sequences), but there have been some recent successes Zhu et al. (2019); Jiang et al. (2019); Cheng et al. (2019); Wang et al. (2019); Minaee et al. (2020b). However, aside from Wang et al. (2019), there has not been any prior work on adversarial pre-training, and Wang et al. (2019) only applied adversarial training to generative language modeling using LSTMs.
ALUM is applicable to both pre-training and fine-tuning. It builds on several key ideas that have proven useful in prior work. First, instead of applying perturbations to the input text directly, we perturb the embedding space. Namely, $x$ is the sub-word embedding in $f(x;\theta)$, following Jiang et al. (2019); Zhu et al. (2019).
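Second, instead of adopting the adversarial training objective of Eq. 2, as in Zhu et al. (2019) and most other approaches, we follow Jiang et al. (2019) to regularize the standard objective using virtual adversarial training Miyato et al. (2018):

$$\min_{\theta} \mathbb{E}_{(x,y)\sim D}\left[ l(f(x;\theta), y) + \alpha \max_{\|\delta\| \le \epsilon} l(f(x+\delta;\theta), f(x;\theta)) \right] \tag{3}$$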
Effectively, the adversarial term favors label smoothness in the embedding neighborhood, and $\alpha$ is a hyperparameter that controls the trade-off between standard errors and robust errors.
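The regularized objective can be sketched for a single example as the standard loss plus $\alpha$ times a divergence between the clean and perturbed predictions. This sketch assumes a binary logistic model and takes the adversarial loss to be KL(f(x) || f(x+δ)); δ is passed in rather than maximized. Note that the adversarial term never reads the label $y$, which is consistent with the motivation that virtual adversarial training is less sensitive to noisy labels.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def kl_bernoulli(q: float, p: float) -> float:
    """KL(Bernoulli(q) || Bernoulli(p))."""
    return q * math.log(q / p) + (1.0 - q) * math.log((1.0 - q) / (1.0 - p))


def vat_objective(x, y, delta, w, b, alpha):
    """Standard loss plus alpha * l(f(x + delta), f(x)) for one example and a given delta."""
    q = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)  # f(x; theta)
    p = sigmoid(sum(wi * (xi + di) for wi, xi, di in zip(w, x, delta)) + b)  # f(x + delta; theta)
    standard = -(y * math.log(q) + (1 - y) * math.log(1.0 - q))  # needs the label y
    adversarial = kl_bernoulli(q, p)  # label-free smoothness term
    return standard + alpha * adversarial
```

With `delta` equal to zero the adversarial term vanishes, so the objective reduces to the standard loss; a larger `alpha` weights prediction smoothness more heavily.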
We found that virtual adversarial training is superior to conventional adversarial training, especially when labels might be noisy. E.g., BERT pre-training uses the masked words as self-supervised labels, but in many cases, they could be replaced by other words to form completely legitimate text. Empirically, we verified that this is indeed the case, as pre-training benefits from larger $\alpha$. We set for pre-training, and for fine-tuning in all our experiments.
Compared to standard training, adversarial training is rather expensive due to the inner maximization. Zhu et al. (2019) adopted the free adversarial training idea of Shafahi et al. (2019) for acceleration, reusing the backward pass for gradient computation to carry out the inner ascent step and the outer descent step simultaneously. Inspired by ERNIE Sun et al. (2019) and other continual pre-training approaches, we instead adopt a curriculum learning approach: first train the model using the standard objective (1), and then continue the training with virtual adversarial training (3).
Jiang et al. (2019) also incorporated a momentum term using the Bregman proximal point method, which can be quite costly in training time. We found that our curriculum learning approach largely rendered this term unnecessary, so we simplified our algorithm by omitting it.
Algorithm 1 shows the details of ALUM. Lines 4-6 run projected gradient steps to find the perturbation $\delta$ that maximizes the adversarial loss (violation of local smoothness). Note that a larger $K$ leads to a better approximation Madry et al. (2017); Qin et al. (2019), but it is more expensive. To attain a good trade-off between speed and performance, we set $K = 1$ in all our experiments.
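The following toy sketch mirrors the structure described above: standard training first (the curriculum), then, for each example, a random initial perturbation, $K$ projected gradient ascent steps on the smoothness violation, and one descent step on the combined loss. It is not the authors' implementation. It assumes a logistic-regression model with closed-form gradients instead of a transformer with autograd, uses KL(f(x) || f(x+δ)) as the adversarial loss, normalizes the ascent direction by its $\ell_\infty$ norm, and treats the clean prediction as a fixed target (stop-gradient, as in Miyato et al. (2018)). All names are hypothetical.

```python
import math
import random


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def dot(u, v) -> float:
    return sum(a * b for a, b in zip(u, v))


def alum_step(x, y, w, b, *, adversarial, alpha, eps, eta, sigma, k, lr, rng):
    """One SGD update of a toy logistic-regression model; returns the new (w, b)."""
    q = sigmoid(dot(w, x) + b)  # clean prediction f(x; theta)
    gw = [(q - y) * xi for xi in x]  # gradient of the standard cross-entropy loss
    gb = q - y
    if adversarial:
        # Random start: the gradient of KL(q || p) w.r.t. delta vanishes at delta = 0.
        delta = [rng.gauss(0.0, sigma) for _ in x]
        for _ in range(k):  # K projected gradient ascent steps
            p = sigmoid(dot(w, [xi + di for xi, di in zip(x, delta)]) + b)
            g = [(p - q) * wi for wi in w]  # d KL(q || p) / d delta
            norm = max(abs(gi) for gi in g) or 1.0  # inf-norm scaling (assumption)
            delta = [max(-eps, min(eps, di + eta * gi / norm)) for di, gi in zip(delta, g)]
        xa = [xi + di for xi, di in zip(x, delta)]
        p = sigmoid(dot(w, xa) + b)
        # q is held fixed as the target (stop-gradient), so d KL / d z_adv = p - q.
        gw = [gwi + alpha * (p - q) * xai for gwi, xai in zip(gw, xa)]
        gb += alpha * (p - q)
    return [wi - lr * gwi for wi, gwi in zip(w, gw)], b - lr * gb


def train(data, steps, adv_start, **hparams):
    """Curriculum: standard objective for adv_start steps, then add the adversarial term."""
    w, b = [0.0] * len(data[0][0]), 0.0
    rng = random.Random(0)
    for t in range(steps):
        x, y = data[t % len(data)]
        w, b = alum_step(x, y, w, b, adversarial=t >= adv_start, rng=rng, **hparams)
    return w, b
```

Setting `k=1` corresponds to the single projected gradient step used in the experiments; each adversarial step costs extra forward and backward computations, which is why the curriculum defers it.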
4 Generalization vs. Robustness
Empirically, we found that by applying adversarial pre-training using ALUM, we were able to improve both generalization and robustness for a wide range of NLP tasks, as seen in Section 4. This is notable, as prior work often finds that adversarial training hurts generalization, even with theoretical justification Raghunathan et al. (2019, 2020); Min et al. (2020).
We hypothesize that adversarial pre-training might be the key to reconciling this apparent incongruence, as prior work on the conflict between generalization and robustness generally focuses on the supervised learning setting. Interestingly, some nascent results in reconciling the two also leverage unlabeled data, such as self-training Raghunathan et al. (2020). Additionally, we hypothesize that by perturbing the embedding space rather than the input space, adversarial training in NLP might inadvertently bias toward on-manifold perturbations rather than regular perturbations, which helps generalization Stutz et al. (2019). We leave the theoretical analysis of all these connections to future work.
Experiments
In this section, we present a comprehensive study of adversarial training on large neural language models. We show that ALUM substantially improves both generalization and robustness in a wide range of NLP tasks, for both the standard BERT model and the extremely well-trained RoBERTa model. We also show that ALUM can be applied to adversarial pre-training and fine-tuning alike and attain further gain by combining the two.
For BERT pre-training, we use Wikipedia (English Wikipedia dump, https://dumps.wikimedia.org/enwiki/; 13GB). For continual pre-training of RoBERTa, we use Wikipedia (13GB), OPENWEBTEXT (public Reddit content, Gokaslan and Cohen; 38GB), and STORIES (a subset of CommonCrawl, Trinh and Le (2018); 31GB).
NLP application benchmarks:
To assess the impact of adversarial training on generalization, we use standard benchmarks such as GLUE Wang et al. (2018) and SQuAD (v1.1 and v2.0) Rajpurkar et al. (2016, 2018), as well as three named entity recognition (NER) tasks in the biomedical domain. To evaluate the impact of adversarial training on robustness, we use ANLI Nie et al. (2019), Adversarial SQuAD Jia and Liang (2017), and HELLASWAG Zellers et al. (2019). To assess the combination of adversarial pre-training and fine-tuning, we follow Jiang et al. (2019) and use MNLI Williams et al. (2018) (from GLUE), ANLI, SWAG Zellers et al. (2018), SNLI Bowman et al. (2015), and SciTail Khot et al. (2018). These benchmarks cover a wide range of NLP tasks such as named entity recognition, textual entailment, and machine reading comprehension, spanning classification, ranking, and regression. For details, see Appendix A.
2 Implementation Details
We perform three types of adversarial training in our experiments: pre-training from scratch, continual pre-training on a well-trained model, and task-specific fine-tuning.
We pre-train BERT models from scratch using Wikipedia (BookCorpus is no longer publicly available). The training code is based on Megatron Shoeybi et al. (2019) (https://github.com/NVIDIA/Megatron-LM), implemented in PyTorch. We use ADAM Kingma and Ba (2014) as the optimizer, with a standard learning rate schedule that increases linearly from zero to the peak rate of in the first 1% of steps and then decays linearly to zero over the remaining 99% of steps. Following Devlin et al. (2018), training is done for one million steps with a batch size of 256. We set the perturbation size , the step size , and the variance for initializing perturbation . We set for heightened regularization in virtual adversarial training, and set $K = 1$ for training efficiency (i.e., one projected gradient step). The training takes 10 days on one DGX-2 machine with 16 V100 GPUs.
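The learning-rate schedule described above (linear warm-up over the first 1% of steps, then linear decay to zero) can be written as a small function. The peak rate is elided in the text, so `peak_lr` is left as a parameter; the function name is hypothetical.

```python
def learning_rate(step: int, total_steps: int, peak_lr: float, warmup_frac: float = 0.01) -> float:
    """Linear warm-up from 0 to peak_lr over warmup_frac of the steps, then linear decay to 0."""
    warmup_steps = max(1, round(total_steps * warmup_frac))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(1, total_steps - warmup_steps)  # guard against division by zero
    return peak_lr * max(0.0, (total_steps - step) / remaining)
```

In PyTorch, such a schedule could be plugged into `torch.optim.lr_scheduler.LambdaLR` by returning the ratio `learning_rate(step, total_steps, peak_lr) / peak_lr` as the multiplicative factor.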
For continual pre-training of RoBERTa Liu et al. (2019c), we use RoBERTa's default training parameters, except for a smaller learning rate (), and run 100K training steps with a batch size of 256 on the union of Wikipedia, OPENWEBTEXT, and STORIES (total size 82GB). The code is based on FairSeq (https://github.com/pytorch/fairseq). The training takes 7 days on two DGX-2 machines.
For fine-tuning with or without adversarial training, we use the MT-DNN open-source toolkit Liu et al. (2020, 2015) (https://github.com/namisan/mt-dnn). For a head-to-head comparison, we follow Jiang et al. (2019) and use ADAM Kingma and Ba (2014) and RADAM Liu et al. (2019a) as optimizers, with peak learning rates of , and batch sizes of 16, 32, or 64, depending on the task. The dropout rate is set to for all the task-specific layers, except for MNLI and for CoLA. To avoid exploding gradients, the gradient is clipped to keep the norm within . All texts are tokenized using WordPiece and chopped to spans of up to tokens. We conduct fine-tuning for up to 10 epochs and pick the best model using the dev set.
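Gradient clipping by norm rescales the gradients whenever their norm exceeds a threshold. A minimal sketch, assuming clipping by the global L2 norm over all parameter tensors (the threshold is elided in the text, so `max_norm` is a parameter, and the function name is hypothetical):

```python
import math


def clip_by_global_norm(grads: list[list[float]], max_norm: float) -> list[list[float]]:
    """Rescale all gradient tensors jointly so that their global L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total == 0.0 or total <= max_norm:
        return [list(tensor) for tensor in grads]  # already within the bound
    scale = max_norm / total
    return [[g * scale for g in tensor] for tensor in grads]
```

In PyTorch, `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)` performs the equivalent operation in place.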
3 Improving Generalization
In this subsection, we study the impact of adversarial pre-training on generalization, by comparing the performance of pre-trained models in various downstream applications. First, we study the scenario of pre-training from scratch, by comparing three BERT models:
BERTBASE is the standard BERT base model trained using the same setting as Devlin et al. (2018) (i.e., 1M steps with a batch size of 256).
BERT+BASE is similar to BERTBASE, except that it is trained for 1.6M steps, which takes roughly the same amount of time as adversarial pre-training (see ALUM below).
ALUM is a BERT model trained using the same setting as BERTBASE, except that ALUM is used in the last 500K steps. Each adversarial training step takes approximately 1.5 times longer than a step in standard training (with K=1 in Algorithm 1, ALUM requires two more forward passes and one more backward pass compared to standard training).
Table 1 compares these pre-trained models on three standard benchmarks (SQuAD v1.1 Rajpurkar et al. (2016) and v2.0 Rajpurkar et al. (2018), and MNLI from GLUE Wang et al. (2018)), using the same standard fine-tuning setting (without adversarial training). The standard BERT models trained using only the Wikipedia data attain results similar to those in Devlin et al. (2018), thus providing a good baseline for comparison. ALUM consistently outperforms the standard BERT models across all the datasets, even after adjusting for the slightly longer training time. E.g., on SQuAD v1.1, ALUM gains 2.3% points in F1 over BERTBASE and 1.2% points over BERT+BASE. Figure 1 shows ALUM at work on the development set of MNLI. Once adversarial training is applied midway through training (after the first 500K steps), ALUM starts outperforming BERT, and the gap keeps widening.
We also assess the impact of adversarial pre-training in the biomedical domain, which is substantially different from the Wikipedia corpus used in pre-training. Table 2 shows the results on standard biomedical named entity recognition (NER) datasets: BC2GM Smith et al. (2008), NCBI Dogan et al. (2014), and JNLPBA Collier and Kim (2004). Interestingly, ALUM still outperforms the standard BERT model on all three tasks, even though the application domain is substantially different from the pre-training one.
Next, we assess the impact of adversarial training in the continual pre-training setting. We use our pre-training dataset (Wikipedia, OPENWEBTEXT, STORIES; 82GB), which is a subset of the data (160GB) used in RoBERTa pre-training, and run 100K steps in all our continual pre-training experiments. We choose the RoBERTa models as the baseline; they use the same neural architecture as BERT but were pre-trained on an order of magnitude more text (160GB vs. 13GB). They are state-of-the-art pre-trained language models, outperforming the standard BERT models in many NLP tasks.
RoBERTa models are extremely well trained. Standard continual pre-training fails to attain any gains in downstream applications such as MNLI Williams et al. (2018) and SST Socher et al. (2013) from GLUE Wang et al. (2018), as shown in Table 3. On the other hand, ALUM is able to attain further gains from continual pre-training of RoBERTa, as shown in Table 4. E.g., on the MNLI development set, ALUM on top of RoBERTaBASE outperforms RoBERTaBASE by +0.5%, and ALUM on top of RoBERTaLARGE outperforms RoBERTaLARGE by +0.7%. This is rather remarkable, as standard continual pre-training, by contrast, is unable to attain any gain.
4 Improving Robustness
In this subsection, we assess the impact of adversarial pre-training on the model’s robustness against adversarial attacks, using three standard adversarial NLP benchmarks: ANLI Nie et al. (2019), HELLASWAG Zellers et al. (2019) and adversarial SQuAD Jia and Liang (2017). On ANLI, we follow the experimental setting of Nie et al. (2019) to enable a head-to-head comparison, which combines four datasets (ANLI, MNLI, SNLI and FEVER Thorne et al. (2018)) for fine-tuning.
Adversarial pre-training substantially improves model robustness, as shown in Table 5 and Table 6. On all three adversarial datasets, ALUM consistently outperforms its standard pre-training counterparts, for BERT and RoBERTa alike. For example, on ANLI, ALUM gains 7.3% points in test accuracy over RoBERTa and outperforms XLNet Yang et al. (2019) by 5.0% points, creating a new state-of-the-art result. The gains on Adversarial SQuAD and HELLASWAG are equally significant. For example, on Adversarial SQuAD, ALUM outperforms BERTBASE by +6.4% F1 in the AddSent setting and +5.0% F1 in the AddOneSent setting. Against RoBERTa, ALUM gains +3.4% F1 in AddSent and +2.1% F1 in AddOneSent.
5 Combining Adversarial Pre-Training and Fine-tuning
Adversarial training has been shown to be effective in task-specific fine-tuning Jiang et al. (2019); Zhu et al. (2019). In this subsection, we explore combining adversarial pre-training with adversarial fine-tuning. Specifically, we use RoBERTa as the base model and compare it with ALUM, which uses adversarial continual pre-training but standard fine-tuning, and ALUMRoBERTa-LARGE-SMART, which uses adversarial training in both continual pre-training and fine-tuning. Figure 2 shows the results on the development sets of MNLI and ANLI, two representative natural language inference tasks. Combining adversarial pre-training and fine-tuning attains the best results and substantially outperforms RoBERTa. E.g., on ANLI, ALUMRoBERTa-LARGE-SMART outperforms ALUM by +1.1% points in accuracy and outperforms RoBERTa by +5.1% points. On SNLI, SciTail, SWAG, and HELLASWAG, we observe similar gains from combining adversarial pre-training and fine-tuning, attaining new state-of-the-art results on these tasks. See Tables 7 and 8.
Conclusion
We propose ALUM, a general adversarial training algorithm, and present the first comprehensive study of adversarial training in large neural language models. We show that adversarial pre-training can significantly improve both generalization and robustness, which provides a promising direction for reconciling the conflict between them observed in prior work. ALUM substantially improves accuracy for BERT and RoBERTa in a wide range of NLP tasks and can be combined with adversarial fine-tuning for further gains.
Future directions include further study of the role of adversarial pre-training in improving generalization and robustness, speeding up adversarial training, and applying ALUM to other domains.
Acknowledgments
We thank Haoming Jiang, Tuo Zhao, Zhe Gan, Kevin Duh, Yangfeng Ji, Greg Yang, Pengchuan Zhang, Lei Zhang, Furu Wei, Li Dong, Masayuki Asahara, and Lis Pereira for valuable discussions and comments, and the Microsoft Research Technology Engineering team for setting up GPU machines.
References
Appendix A NLP Application Benchmarks
GLUE. The General Language Understanding Evaluation (GLUE) benchmark is a collection of nine natural language understanding (NLU) tasks. As shown in Table 9, it includes question answering Rajpurkar et al. (2016), linguistic acceptability Warstadt et al. (2018), sentiment analysis Socher et al. (2013), text similarity Cer et al. (2017), paraphrase detection Dolan and Brockett (2005), and natural language inference (NLI) Dagan et al. (2006); Bar-Haim et al. (2006); Giampiccolo et al. (2007); Bentivogli et al. (2009); Levesque et al. (2012); Williams et al. (2018). The diversity of the tasks makes GLUE very suitable for evaluating the generalization and robustness of NLU models.
SNLI. The Stanford Natural Language Inference (SNLI) dataset contains 570k human-annotated sentence pairs, in which the premises are drawn from the captions of the Flickr30k corpus and the hypotheses are manually annotated Bowman et al. (2015). This is the most widely used entailment dataset for NLI.
SciTail. This is a textual entailment dataset derived from a science question answering (SciQ) dataset Khot et al. (2018). The task involves assessing whether a given premise entails a given hypothesis. In contrast to other entailment datasets mentioned previously, the hypotheses in SciTail are created from science questions while the corresponding answer candidates and premises come from relevant web sentences retrieved from a large corpus. As a result, these sentences are linguistically challenging and the lexical similarity of premise and hypothesis is often high, thus making SciTail particularly difficult.
ANLI. The Adversarial Natural Language Inference (ANLI, Nie et al. (2019)) is a new large-scale NLI benchmark dataset, collected via an iterative, adversarial human-and-model-in-the-loop procedure. Specifically, the instances are chosen to be difficult for the state-of-the-art models such as BERT and RoBERTa.
SWAG. It is a large-scale adversarial dataset for the task of grounded commonsense inference, which unifies natural language inference and physically grounded reasoning Zellers et al. (2018). SWAG consists of 113k multiple choice questions about grounded situations.
HELLASWAG. It is similar to SWAG but more challenging Zellers et al. (2019). Each query in HELLASWAG also has four choices, and the goal is to select the best one.
SQuAD v1.1/v2.0. Stanford Question Answering Dataset (SQuAD) v1.1 and v2.0 Rajpurkar et al. (2016, 2018) are popular machine reading comprehension benchmarks. Their passages come from approximately 500 Wikipedia articles and the questions and answers are obtained by crowdsourcing. The SQuAD v2.0 dataset includes unanswerable questions about the same paragraphs.
BC2GM. The Gene Mention Task at the BioCreative II workshop Smith et al. (2008) provides an annotated dataset for gene named entity recognition.
NCBI. The NCBI disease corpus Dogan et al. (2014) contains annotations of disease mentions from a collection of PubMed abstracts.
JNLPBA. JNLPBA is a biomedical entity recognition shared task Collier and Kim (2004). It is one of the largest datasets covering a large fraction of major taxonomies in molecular biology.