BertGCN: Transductive Text Classification by Combining GCN and BERT
Yuxiao Lin, Yuxian Meng, Xiaofei Sun, Qinghong Han, Kun Kuang, Jiwei Li, Fei Wu
Introduction
Text classification is a core task in natural language processing (NLP) and has been used in many real-world applications such as spam detection (Wang, 2010) and opinion mining (Bakshi et al., 2016). Transductive learning (Vapnik, 1998) is a particular approach to text classification that makes use of both labeled and unlabeled examples in the training process. Graph neural networks (GNNs) serve as an effective approach for transductive learning (Yao et al., 2019; Liu et al., 2020). In these works, a graph is constructed to model the relationships between documents. Nodes in the graph represent text units such as words and documents, while edges are constructed based on the semantic similarity between nodes. GNNs are then applied to the graph to perform node classification. The merits of GNNs and transductive learning are as follows: (1) the decision for an instance (both training and test) does not depend merely on itself, but also on its neighbors, which makes the model more robust to data outliers; (2) at training time, since the model propagates influence from supervised labels across both training and test instances through graph edges, unlabeled data also contributes to representation learning, which consequently leads to higher performance.
Large-scale pretraining has recently demonstrated its effectiveness on a variety of NLP tasks (Devlin et al., 2018; Liu et al., 2019). Trained on large-scale unlabeled corpora in an unsupervised manner, large-scale pretrained models are able to learn implicit but rich text semantics at scale. Intuitively, large-scale pretrained models have the potential to benefit transductive learning. However, existing models for transductive text classification (Yao et al., 2019; Liu et al., 2020) did not take large-scale pretraining into consideration, and its effectiveness in this setting remains unclear.
In this work, we propose BertGCN, a model that combines the advantages of both large-scale pretraining and transductive learning for text classification. BertGCN constructs a heterogeneous graph over the corpus in which each node is either a word or a document, initializes node embeddings with pretrained BERT representations, and uses graph convolutional networks (GCN) for classification. By jointly training the BERT and GCN modules, the proposed model is able to leverage the best of both worlds: large-scale pretraining, which takes advantage of the massive amount of raw data, and transductive learning, which jointly learns representations for both training data and unlabeled test data by propagating label influence through graph edges. BertGCN thus combines the strengths of large-scale pretraining and graph networks, and achieves new state-of-the-art performance on a wide range of text classification datasets.
Related Work
Graph neural networks (GNNs) are connectionist models that capture dependencies and relations between graph nodes via message passing through edges that connect nodes (Scarselli et al., 2008; Hamilton et al., 2017; Xu et al., 2018). Following Wu et al. (2020), GNNs can be categorized into graph convolutional networks (Kipf and Welling, 2016a; Wu et al., 2019), graph attention networks (Veličković et al., 2017; Zhang et al., 2018a), graph auto-encoders (Cao et al., 2016; Kipf and Welling, 2016b), graph generative networks (De Cao and Kipf, 2018; Li et al., 2018b) and graph spatial-temporal networks (Li et al., 2017; Yu et al., 2017). GNNs serve as powerful tools to utilize the relationships between different objects, and have been applied to various domains such as traffic prediction (Yu et al., 2018; Zhang et al., 2018a) and recommendation (Zhang et al., 2020; Monti et al., 2017). In the context of NLP, GNNs have achieved remarkable success across a wide range of end tasks such as relation extraction (Zhang et al., 2018b), semantic role labeling (Marcheggiani and Titov, 2017), data-to-text generation (Marcheggiani and Perez-Beltrachini, 2018), machine translation (Bastings et al., 2017) and question answering (Song et al., 2018; De Cao et al., 2018).
The prevalence of neural networks has motivated a diverse array of works on developing neural models for text classification. Different neural model architectures (Kim, 2014; Zhou et al., 2015; Radford et al., 2018; Chai et al., 2020) have demonstrated their effectiveness against traditional statistical feature based methods (Wallach, 2006). Other works leverage label embeddings and jointly train them along with input texts (Wang et al., 2018; Pappas and Henderson, 2019). More recently, the success achieved by large-scale pretrained models has spurred great interest in adapting the large-scale pretraining framework (Devlin et al., 2018) to text classification (Reimers and Gurevych, 2019), leading to remarkable progress on few-shot (Mukherjee and Awadallah, 2020) and zero-shot (Ye et al., 2020) learning.
Our work is inspired by prior work on using graph neural networks for text classification (Yao et al., 2019; Huang et al., 2019; Zhang and Zhang, 2020). Different from these works, we focus on combining large-scale pretrained models and GNNs, and show that GNNs can significantly benefit from large-scale pretraining. Existing works that combine BERT and GNNs use graphs to model relationships between tokens within a single document sample (Lu et al., 2020; He et al., 2020b), and therefore fall into the category of inductive learning. In contrast, we use a graph to model relationships between different samples from the whole corpus in order to utilize the similarity between labeled and unlabeled documents, and use GNNs to learn these relationships.
Method
In the proposed BertGCN model, we initialize representations for document nodes in a text graph using a BERT-style model (e.g., BERT, RoBERTa). These representations are used as inputs to GCN. Document representations will then be iteratively updated based on the graph structures using GCN, the outputs of which are treated as final representations for document nodes, and are sent to the softmax classifier for predictions. In this way, we are able to leverage the complementary strengths of pretrained models and graph models.
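Specifically, we construct a heterogeneous graph containing both word nodes and document nodes following TextGCN (Yao et al., 2019). We define word-document edges and word-word edges based on term frequency-inverse document frequency (TF-IDF) and positive point-wise mutual information (PPMI), respectively. The weight of an edge between two nodes $i$ and $j$ is defined as:
$$A_{i,j} = \begin{cases} \mathrm{PPMI}(i,j), & i, j \text{ are words and } i \neq j \\ \mathrm{TF\text{-}IDF}(i,j), & i \text{ is a document, } j \text{ is a word} \\ 1, & i = j \\ 0, & \text{otherwise} \end{cases}$$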
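The edge weighting described above (TF-IDF for word-document edges, PPMI for word-word edges) can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the function name `build_edge_weights`, whitespace tokenization, the window size, the `tf * log(N / df)` form of TF-IDF, storing each edge in both directions, and the self-loops of weight 1 (carried over from TextGCN) are all assumptions made for the example.

```python
import math
from collections import Counter
from itertools import combinations


def build_edge_weights(docs, window=3):
    """Return {(node_i, node_j): weight} for a word-document graph.

    Nodes are ("doc", index) or ("word", token). Tokenization, window
    size and the exact TF-IDF variant are illustrative assumptions.
    """
    tokenized = [d.lower().split() for d in docs]
    n_docs = len(tokenized)
    edges = {}

    # Document-word edges: TF-IDF with tf = count / doc length, idf = log(N / df).
    df = Counter(w for toks in tokenized for w in set(toks))
    for i, toks in enumerate(tokenized):
        for w, count in Counter(toks).items():
            weight = (count / len(toks)) * math.log(n_docs / df[w])
            if weight > 0:
                edges[(("doc", i), ("word", w))] = weight
                edges[(("word", w), ("doc", i))] = weight

    # Word-word edges: positive PMI estimated from sliding windows.
    windows = [
        toks[k:k + window]
        for toks in tokenized if toks
        for k in range(max(1, len(toks) - window + 1))
    ]
    n_win = len(windows)
    word_count = Counter(w for win in windows for w in set(win))
    pair_count = Counter(
        pair for win in windows for pair in combinations(sorted(set(win)), 2)
    )
    for (a, b), count in pair_count.items():
        pmi = math.log(count * n_win / (word_count[a] * word_count[b]))
        if pmi > 0:  # keep only positive PMI, i.e. PPMI
            edges[(("word", a), ("word", b))] = pmi
            edges[(("word", b), ("word", a))] = pmi

    # Self-loops of weight 1 on every node (assumption following TextGCN).
    for i in range(n_docs):
        edges[(("doc", i), ("doc", i))] = 1.0
    for w in word_count:
        edges[(("word", w), ("word", w))] = 1.0
    return edges
```

Note that documents are never connected to each other directly: information flows between two documents only through the word nodes they share.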
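We feed the initial node feature matrix $X$, whose document-node rows are the BERT document embeddings, into a GCN model (Kipf and Welling, 2016a) which iteratively propagates messages across training and test examples. Specifically, the output feature matrix $L^{(l)}$ of the $l$-th GCN layer is computed as
$$L^{(l)} = \rho\left(\tilde{A} L^{(l-1)} W^{(l)}\right)$$
where $\rho$ is an activation function, $\tilde{A}$ is the normalized version of the adjacency matrix $A$, $W^{(l)}$ is the weight matrix of the layer, and $L^{(0)} = X$ is the input feature matrix. The outputs of the GCN are treated as the final representations of document nodes and are fed to a softmax layer for classification:
$$Z_{\mathrm{GCN}} = \mathrm{softmax}(g(X, A))$$
where $g$ represents the GCN model. We use the cross entropy loss over labeled document nodes to jointly optimize parameters for BERT and GCN.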
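To make the propagation step concrete, the following is a minimal sketch of a two-layer GCN forward pass on plain Python lists. It is not the authors' code: the function names are hypothetical, the symmetric normalization $D^{-1/2} A D^{-1/2}$ follows Kipf and Welling and is an assumption about how the adjacency matrix is normalized here, the adjacency matrix is assumed to already contain self-loops, and the last layer is assumed to use an identity activation before the softmax.

```python
import math


def matmul(a, b):
    """Dense matrix product of two list-of-lists matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def normalize_adjacency(adj):
    """Symmetric normalization D^{-1/2} A D^{-1/2} (assumes self-loops, so degrees > 0)."""
    deg = [sum(row) for row in adj]
    n = len(adj)
    return [
        [adj[i][j] / math.sqrt(deg[i] * deg[j]) if adj[i][j] else 0.0 for j in range(n)]
        for i in range(n)
    ]


def relu(m):
    return [[max(0.0, v) for v in row] for row in m]


def softmax_rows(m):
    out = []
    for row in m:
        shift = max(row)  # subtract the max for numerical stability
        exps = [math.exp(v - shift) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def gcn_forward(adj, features, w1, w2):
    """Two-layer GCN: softmax(A_norm * relu(A_norm * X * W1) * W2)."""
    a_norm = normalize_adjacency(adj)
    hidden = relu(matmul(matmul(a_norm, features), w1))
    logits = matmul(matmul(a_norm, hidden), w2)
    return softmax_rows(logits)
```

In a BertGCN-style setup the document rows of `features` would come from the BERT encoder; in a toy run one might simply set the word rows to zero, which is an assumption of the example rather than something stated in the text above.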
2 Interpolating BERT and GCN Predictions
In practice, we find that optimizing BertGCN with an auxiliary classifier that operates directly on BERT embeddings leads to faster convergence and better performance. Specifically, we construct an auxiliary classifier by feeding the document embeddings (denoted by $X$) directly to a dense layer with softmax activation:
$$Z_{\mathrm{BERT}} = \mathrm{softmax}(W X)$$
where $W$ is the weight matrix of the dense layer.
The final training objective is the linear interpolation of the prediction from BertGCN ($Z_{\mathrm{GCN}}$) and the prediction from BERT ($Z_{\mathrm{BERT}}$), which is given by:
$$Z = \lambda Z_{\mathrm{GCN}} + (1 - \lambda) Z_{\mathrm{BERT}}$$
where $\lambda$ controls the tradeoff between the two objectives. $\lambda = 1$ means we use the full BertGCN model, and $\lambda = 0$ means we only use the BERT module. When $\lambda \in (0, 1)$, we are able to balance the predictions from both models, and the BertGCN model can be better optimized.
The explanation for the better performance achieved by the interpolation is as follows: the auxiliary classifier $Z_{\mathrm{BERT}}$ operates directly on the input of the GCN, making sure that the inputs to the GCN are regulated and optimized towards the objective. This helps the multi-layer GCN model to overcome intrinsic drawbacks such as vanishing gradients or over-smoothing (Li et al., 2018a), and thus leads to better performance.
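The interpolation of the two predictions described in this section can be sketched for a single document as follows. The function names and the use of raw logits as inputs are assumptions made for illustration; the sketch only shows how the mixed distribution and its cross entropy loss would be computed, not how gradients are propagated.

```python
import math


def softmax(logits):
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def interpolate(gcn_logits, bert_logits, lam):
    """Mix the two class distributions: lam * Z_GCN + (1 - lam) * Z_BERT."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    z_gcn = softmax(gcn_logits)
    z_bert = softmax(bert_logits)
    return [lam * g + (1.0 - lam) * b for g, b in zip(z_gcn, z_bert)]


def cross_entropy(probs, label):
    """Negative log-likelihood of the gold label under the mixed prediction."""
    return -math.log(probs[label])
```

Because both inputs are probability distributions and the weights sum to one, the mixed output is again a valid distribution, so the usual cross entropy loss applies unchanged.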
3 Optimization using Memory Bank
The original GCN model uses full-batch gradient descent for training, which is intractable for the proposed BertGCN model, since the full-batch method cannot be applied to BERT due to memory limitations. Inspired by techniques in contrastive learning that decouple the dictionary size from the mini-batch size (Wu et al., 2018; He et al., 2020a), we introduce a memory bank that stores all document embeddings to decouple the training batch size from the total number of nodes in the graph.
Specifically, during training, we maintain a memory bank $M$ that tracks input features for all document nodes. At the beginning of each epoch, we first compute all document embeddings using the current BERT module and store them in $M$. During each iteration, we sample a mini-batch from both labeled and unlabeled document nodes with the index set $\mathcal{B} = \{b_1, \dots, b_n\}$, where $n$ is the mini-batch size. We then compute their document embeddings $M_{\mathcal{B}}$, also using the current BERT module, and update the corresponding memories in $M$. Note that the BERT module used to compute $M_{\mathcal{B}}$ is the one obtained after the previous training iteration, which differs from the BERT module used to compute the initial $M$. Next, we use the updated $M$ as input to derive the GCN output and compute the loss for the current mini-batch. For back-propagation, $M$ is treated as a constant except for the records in $\mathcal{B}$.
With the memory bank, we are able to efficiently train the BertGCN model including the BERT module. However, during training, the embeddings in the memory bank are computed using the BERT module at different steps in an epoch and are thus inconsistent. To overcome this issue, we set a small learning rate for the BERT module to improve the consistency of the stored embeddings. With a low learning rate, training takes more time. In order to speed up training, we fine-tune a BERT model on the target dataset before training begins, and use it to initialize the BERT parameters in BertGCN.
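The bookkeeping of the memory bank can be sketched without any deep learning library. In the illustration below, `MemoryBank`, `ToyEncoder`, `run_epoch` and `step_fn` are hypothetical names: `ToyEncoder` is a stand-in for the BERT module (its single `scale` parameter plays the role of the trainable weights), and `step_fn` stands for one optimization step. No gradients are computed; the sketch only shows which rows are refreshed and when.

```python
import random


class MemoryBank:
    """Stores one embedding per document node."""

    def __init__(self):
        self.rows = []

    def refresh_all(self, encoder, docs):
        """Recompute every row with the current encoder (start of an epoch)."""
        self.rows = [encoder(doc) for doc in docs]

    def update(self, indices, embeddings):
        """Overwrite only the rows that belong to the sampled mini-batch."""
        for i, emb in zip(indices, embeddings):
            self.rows[i] = emb


class ToyEncoder:
    """Hypothetical stand-in for BERT: embeds a document as [scale * length]."""

    def __init__(self):
        self.scale = 1.0

    def __call__(self, doc):
        return [self.scale * len(doc.split())]


def run_epoch(bank, encoder, docs, batch_size, step_fn, rng=None):
    rng = rng or random.Random()
    bank.refresh_all(encoder, docs)
    order = list(range(len(docs)))  # labeled and unlabeled documents alike
    rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        batch = order[start:start + batch_size]
        bank.update(batch, [encoder(docs[i]) for i in batch])
        # A GCN would consume all of bank.rows here. Only the rows in `batch`
        # would carry gradients back to the encoder; the rest act as constants.
        step_fn(bank.rows, batch)
```

If `step_fn` changes the encoder, rows written in earlier iterations were produced by an older encoder than rows written later, so the bank holds embeddings from several encoder versions at once.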
Experiments
We run experiments on five widely-used text classification benchmarks: 20 Newsgroups (20NG; http://qwone.com/~jason/20Newsgroups/), R8 and R52 (https://www.cs.umb.edu/~smimarog/textmining/datasets/), Ohsumed (http://disi.unitn.it/moschitti/corpora.htm) and Movie Review (MR; http://www.cs.cornell.edu/people/pabo/movie-review-data/).
We compare BertGCN to current state-of-the-art pretrained and GCN models: TextGCN (Yao et al., 2019), SGC (Wu et al., 2019), BERT (Devlin et al., 2018) and RoBERTa (Liu et al., 2019). Details of the datasets and baselines are left to the supplementary material.
We follow the protocols in TextGCN to preprocess data. For BERT and RoBERTa, we use the output feature of the [CLS] token as the document embedding, followed by a feedforward layer to derive the final prediction. We use BERT and a two-layer GCN to implement BertGCN. We initialize the learning rate to 1e-3 for the GCN module and 1e-5 for the fine-tuned BERT module. We also implement our model with RoBERTa and GAT (Veličković et al., 2017). GAT variants are trained over the same graph as GCN variants, but learn edge weights through an attention mechanism instead of using a pre-defined weight matrix.
2 Main Results
Table 1 presents the test accuracy of each model. We can see that BertGCN and RoBERTaGCN perform the best across all datasets. Using only BERT or RoBERTa generally performs better than the GCN variants, except on 20NG, which reflects the benefit of large-scale pretraining. Compared with BERT and RoBERTa, the performance boost from BertGCN and RoBERTaGCN is significant on the 20NG and Ohsumed datasets. This is because the average document length in 20NG and Ohsumed is much longer than that in the other datasets: the graph is constructed using word-document statistics, which means that long texts may produce more document connections routed via an intermediate word node, and this potentially benefits message passing through the graph, leading to better performance when combined with GCN. This may also explain why GCN models perform better than BERT models on 20NG. For datasets with shorter documents such as R52 and MR, the power of the graph structure is limited, and thus the performance boost is smaller relative to 20NG. BertGAT and RoBERTaGAT can also benefit from the graph structure, but their performance is not as good as that of the GCN variants due to the lack of edge weight information.
3 The Effect of $\lambda$
$\lambda$ controls the trade-off between training BertGCN and BERT. The optimal value of $\lambda$ can be different for different tasks. Fig. 9 shows the accuracy of RoBERTaGCN with different values of $\lambda$. On 20NG, the accuracy is consistently higher with a larger $\lambda$. This can be explained by the high performance of graph-based methods on 20NG. The model reaches its best at a value of $\lambda$ below 1, performing slightly better than only using the GCN prediction ($\lambda = 1$).
4 The Effect of Strategies in Joint Training
To overcome the inconsistency of embeddings in the memory bank, we set a smaller learning rate for the BERT module and use a fine-tuned BERT model for initialization. We evaluate the effect of the two strategies. Table 2 shows the results of RoBERTaGCN on 20NG with and without these strategies. With the same learning rate for RoBERTa and GCN, the model cannot be trained due to inconsistency in the memory bank, regardless of whether the fine-tuned RoBERTa is used. Models can be successfully trained when we set a smaller learning rate for the RoBERTa module, and additionally using the fine-tuned RoBERTa leads to the best performance.
Conclusion and Future Work
In this work, we propose BertGCN, which combines the advantages of large-scale pretrained models and transductive learning for text classification. We efficiently train BertGCN by using a memory bank that stores all document embeddings and updates part of them with respect to the sampled mini-batch. The framework of BertGCN can be built on top of any document encoder and any graph model. Experiments demonstrate the power of the proposed BertGCN model. However, in this work, we only use document statistics to build the graph, which might be sub-optimal compared to models that are able to automatically construct edges between nodes. We leave this to future work.
Acknowledgement
This work is supported by the National Key R&D Program of China (2020AAA0105200) and the Beijing Academy of Artificial Intelligence (BAAI).
References
Appendix A Dataset Details
The 20NG dataset (http://qwone.com/~jason/20Newsgroups/) contains 18,846 newsgroup posts from 20 different topics. We use the bydate version, which splits the dataset into 11,314 training samples and 7,532 test samples based on the posting date.
R8 and R52 (https://www.cs.umb.edu/~smimarog/textmining/datasets/) are two subsets of the Reuters dataset with 8 and 52 categories, respectively. R8 has 5,485 training and 2,189 test documents. R52 has 6,532 training and 2,568 test documents.
The OHSUMED test collection (http://disi.unitn.it/moschitti/corpora.htm) is a set of references from MEDLINE, the online medical information database. Following previous works, we use 7,400 documents belonging to one of the 23 disease categories to form a classification dataset, with 3,357 documents for training and 4,043 for testing.
MR (Pang and Lee, 2005; http://www.cs.cornell.edu/people/pabo/movie-review-data/) is a movie review dataset for binary sentiment classification. The corpus has 10,662 reviews. We use the train/test split in Tang et al. (2015).
Appendix B Baselines
TextGCN (Yao et al., 2019): TextGCN is a model that operates graph convolution over a word-document heterogeneous graph. Node features are initialized using an identity matrix.
SGC (Wu et al., 2019): Simple Graph Convolution is a variant of GCN that reduces the complexity of GCN by removing non-linearities and collapsing weight matrices between consecutive layers.
BERT (Devlin et al., 2018): BERT is a large-scale pretrained NLP model.
RoBERTa (Liu et al., 2019): a robustly optimized BERT model that improves upon BERT with different pretraining methods.
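The simplification behind the SGC baseline listed above (removing non-linearities and collapsing the weight matrices of consecutive layers) can be checked numerically. The sketch below uses hypothetical helper names and tiny hand-picked matrices; it shows that a two-layer GCN without activations gives the same logits as two propagation steps followed by a single weight matrix equal to the product of the two layer weights.

```python
def matmul(a, b):
    """Dense matrix product of two list-of-lists matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def linear_gcn_logits(s, x, weights):
    """GCN layers with the non-linearity removed: S(...S(S X W1) W2...)."""
    h = x
    for w in weights:
        h = matmul(matmul(s, h), w)
    return h


def sgc_logits(s, x, w, k):
    """SGC form: K propagation steps S^K X, then one collapsed weight matrix W."""
    h = x
    for _ in range(k):
        h = matmul(s, h)
    return matmul(h, w)
```

Since the propagated features `S^K X` do not depend on any trainable parameter, they can be computed once before training, which is where the reduction in complexity comes from.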