Learned Token Pruning for Transformers

Sehoon Kim, Sheng Shen, David Thorsley, Amir Gholami, Woosuk Kwon, Joseph Hassoun, Kurt Keutzer

Introduction

Transformer-based deep neural network architectures (Vaswani et al., 2017), such as BERT (Devlin et al., 2018) and RoBERTa (Liu et al., 2019), achieve state-of-the-art results in Natural Language Processing (NLP) tasks such as sentence classification and question answering. However, efficiently deploying these models is increasingly challenging due to their large size, the need for real-time inference, and the limited energy, compute, and memory resources available. The heart of a transformer layer is the multi-head self-attention mechanism, where each token in the input sequence attends to every other token to compute a new representation of the sequence. Because all tokens attend to each other, the computational complexity is quadratic with respect to the input sequence length; thus the ability to apply transformer models to long input sequences becomes limited.

Pruning is a popular technique to reduce the size of neural networks and the amount of computation required. Unstructured pruning allows arbitrary patterns of sparsification for parameters and feature maps and can, in theory, produce significant computational savings while preserving accuracy. However, commodity DNN accelerators cannot efficiently exploit unstructured sparsity patterns. Thus, structured pruning methods are typically preferred in practice due to their relative ease of deployment to hardware.

Multi-head self-attention provides several possibilities for structured pruning; for example, head pruning (Michel et al., 2019; Voita et al., 2019) decreases the size of the model by removing unneeded heads in each transformer layer. Another orthogonal approach that we consider in this paper is token pruning, which reduces computation by progressively removing unimportant tokens in the sequence during inference. For NLP tasks such as sentence classification, token pruning is an attractive approach to consider as it exploits the intuitive observation that not all tokens (i.e., words) in an input sentence are necessarily required to make a successful inference.

There are two main classes of token pruning methods. In the first class, methods like PoWER-BERT (Goyal et al., 2020) and Length-Adaptive Transformer (LAT) (Kim and Cho, 2020) search for a single token pruning configuration (i.e., sequence length for each layer) for an entire dataset. In other words, they prune all input sequences to the same length. However, input sequence lengths can vary greatly within tasks and between training and validation sets as in Figure 1, and thus applying a single pruning configuration to all input sequences can potentially under-prune shorter sequences or over-prune longer sequences.

In the other class, the token pruning method adjusts the configuration based on the input sequence. SpAtten (Wang et al., 2020b) uses a pruning configuration proportional to input sentence length; however, it does not adjust the proportion of pruned tokens based on the content of the input sequence. The recently published TR-BERT (Ye et al., 2021) uses reinforcement learning (RL) to find a policy network that dynamically reduces the number of tokens based on the length and content of the input sequence; however, it requires additional costly training for convergence of the RL-based method. Additionally, all of these prior methods rely in part on selecting the $k$ most significant tokens during inference or training. This selection can be computationally expensive without the development of specialized hardware, such as the top-$k$ engine introduced in SpAtten (Wang et al., 2020b).

To this end, we propose a learned threshold-based token pruning method which adapts to the length and content of individual examples and avoids the use of top-$k$ operations. In particular, our contributions are as follows:

We propose Learned Token Pruning (LTP), a threshold-based token pruning method, which only needs a simple threshold operation to detect unimportant tokens. In addition, LTP fully automates the search for optimal pruning configurations by introducing a differentiable soft binarized mask that allows training the correct thresholds for different layers and tasks. (Section 3.3)

We apply LTP to RoBERTa and evaluate its performance on GLUE and SQuAD tasks. We show LTP achieves up to $2.10\times$ FLOPs reduction with less than 1% accuracy degradation, which results in up to $1.93\times$ and $1.97\times$ throughput improvement on NVIDIA V100 GPU and Intel Haswell CPU, respectively, as compared to the unpruned FP16 baseline. We also show that LTP outperforms SpAtten and LAT in most cases, achieving additional FLOPs reduction for the same drop in accuracy. (Sections 4.2 and 4.5)

We show that LTP is highly robust against sentence length variations. LTP exhibits consistently better accuracy over different sentence length distributions, with an accuracy gap of up to 16.4% over LAT. (Section 4.3)

Related Work

Multiple approaches have been proposed to improve the speed and reduce the memory footprint of transformers. These can be broadly categorized as follows: (i) efficient architecture design (Lan et al., 2019; Child et al., 2019; Kitaev et al., 2020; Wang et al., 2020a; Iandola et al., 2020; Vyas et al., 2020; Tay et al., 2020; Katharopoulos et al., 2020; Zaheer et al., 2020; Roy et al., 2021); (ii) knowledge distillation (Sun et al., 2020; Jiao et al., 2019; Tang et al., 2019; Sanh et al., 2019; Sun et al., 2019); (iii) quantization (Bhandare et al., 2019; Zafrir et al., 2019; Shen et al., 2020; Fan et al., 2020; Zadeh et al., 2020; Zhang et al., 2020; Bai et al., 2020; Kim et al., 2021); and (iv) pruning. Here, we focus only on pruning and briefly discuss the related work.

2.2. Transformer Pruning

Pruning methods can be categorized into unstructured pruning and structured pruning. For unstructured pruning, the lottery-ticket hypothesis (Frankle and Carbin, 2018) has been explored for transformers in (Prasanna et al., 2020; Chen et al., 2020). Recently, (Zhao et al., 2020) leverages pruning as an effective way to fine-tune transformers on downstream tasks. (Sanh et al., 2020) proposes movement pruning, which achieves significant performance improvements over prior magnitude-based methods by considering how the weights change during fine-tuning. However, it is often quite difficult to efficiently deploy unstructured sparsity on commodity neural accelerators for meaningful speedup.

For this reason, a number of structured pruning methods have been introduced to remove structured sets of parameters. (Michel et al., 2019; Voita et al., 2019) drop attention heads in multi-head attention layers, and (Sajjad et al., 2020; Fan et al., 2019) prune entire transformer layers. (Wang et al., 2019) structurally prunes weight matrices via low-rank factorization, and (Khetan and Karnin, 2020; Lin et al., 2020) attempt to jointly prune attention heads and filters of weight matrices. (Liu et al., 2021; Hou et al., 2020) dynamically determine structured pruning ratios during inference. Recent block pruning schemes chunk weight matrices into multiple blocks and prune them based on group Lasso optimization (Li et al., 2020), adaptive regularization (Yao et al., 2021), and movement pruning (Lagunas et al., 2021). All of these methods correspond to weight pruning, where model parameters (i.e., weights) are pruned.

Recently, there has been work on pruning input sentences to transformers, rather than model parameters. This is referred to as token pruning, where less important tokens are progressively removed during inference. PoWER-BERT (Goyal et al., 2020), one of the earliest works, proposes to directly learn token pruning configurations. LAT (Kim and Cho, 2020) extends this idea by introducing LengthDrop, a procedure in which a model is trained with different token pruning configurations, followed by an evolutionary search. This method has an advantage that the former training procedure need not be repeated for different pruning ratios of the same model. While these methods have shown a large computation reduction on various NLP downstream tasks, they fix a single token pruning configuration for the entire dataset. That is, they prune all input sequences to the same length. However, as shown in Figure 1, input sequence lengths vary greatly within a task. As a consequence, fixing a single pruning configuration can under-prune shorter sequences so as to retain sufficient tokens for processing longer sequences or, conversely, over-prune longer sequences to remove sufficient tokens to efficiently process shorter sequences. More importantly, a single pruning configuration lacks robustness against input sequence length variations, where input sentences at inference time are longer than those in the training dataset (Press et al., 2021).

In contrast, SpAtten (Wang et al., 2020b) circumvents this issue by assigning a pruning configuration proportional to the input sequence length. While this allows pruning more tokens from longer sequences and fewer tokens from shorter ones, it is not adaptive to individual input sequences as it assigns the same configuration to all sequences with the same length regardless of their contents. In addition, the pruning configurations are determined heuristically and thus can result in a suboptimal solution. Recently, TR-BERT (Ye et al., 2021) proposes to learn an RL policy network to apply adaptive pruning configurations for each input sequence. However, as noted by the authors, the problem has a large search space which can be hard for RL to solve. This issue is mitigated by heuristics involving imitation learning and sampling of action sequences, which significantly increases the cost of training. Importantly, all of the aforementioned token pruning methods depend partially or entirely on a top-$k$ operation for selecting the $k$ most important tokens during inference or training. This operation can be costly without specialized hardware support, as discussed in (Wang et al., 2020b). LTP, on the other hand, is based on a fully learnable threshold-based pruning strategy. Therefore, it is (i) adaptive to both input length and content, (ii) robust to sentence length variations, (iii) computationally efficient, and (iv) easy to deploy.

Methodology

where Eq. 3 is the residual connection and the follow-up LayerNorm (LN). The output of the MHA is then fed into the FFN block which applies two feed-forward layers to this input:

where $\mathbf{W}_1$, $\mathbf{W}_2$, $b_1$, and $b_2$ are the FFN parameters, and $\sigma$ is the activation function (typically GELU for BERT).

3.2. Threshold Token Pruning

Let us denote the attention probability of head $h$ between tokens $x_i$ and $x_j$ in layer $l$ as $\mathbf{A}^{(h,l)}(x_i, x_j)$:

The computational complexity of computing the attention matrix is $\mathcal{O}(d^{2}n+n^{2}d)$, which scales quadratically with the sequence length. As such, the attention operation becomes a bottleneck when applied to long sequences. To address this, we apply token pruning, which removes unimportant tokens as the input passes through the transformer layers to reduce the sequence length $n$ for later blocks. This is schematically shown in Figure 2 (Left).

For token pruning, we must define a metric to determine unimportant tokens. Following (Goyal et al., 2020; Wang et al., 2020b; Kim and Cho, 2020), we define the importance score of token $x_i$ in layer $l$ as:

Intuitively, the attention probability $\mathbf{A}^{(h,l)}(x_i, x_j)$ is interpreted as the normalized amount that all the other tokens $x_j$ attend to token $x_i$. Token $x_i$ is thus considered important if it receives more attention from all tokens across all heads, which directly leads us to Eq. 7. The procedure for computing importance scores from attention probabilities is illustrated in Figure 2 (Right).
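As a minimal sketch of this computation, assume the attention probabilities of one layer are stored as nested lists indexed `[head][query j][key i]`. The layout, the uniform averaging over heads and querying tokens, and the name `importance_scores` are illustrative assumptions, not the authors' implementation.

```python
def importance_scores(attn):
    """Average attention received by each token over heads and queries.

    attn[h][j][i] is the attention probability that query token j assigns
    to key token i in head h; each row attn[h][j] sums to 1.
    """
    num_heads = len(attn)
    n = len(attn[0])
    scores = [0.0] * n
    for head in attn:
        for row in head:
            for i, p in enumerate(row):
                scores[i] += p  # attention that token i receives
    return [s / (num_heads * n) for s in scores]


# Two heads, three tokens (hypothetical values).
attn = [
    [[0.6, 0.3, 0.1], [0.5, 0.4, 0.1], [0.7, 0.2, 0.1]],
    [[0.2, 0.7, 0.1], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]],
]
scores = importance_scores(attn)
```

Because every attention row sums to one, the scores computed this way also sum to one within a layer, so a token's score can be read as its share of the total attention.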

Note that this operation only requires a simple comparison operator without any expensive top-$k$ calculation. Once a token is pruned, it is excluded from calculations in all succeeding layers, thereby gradually reducing computational complexity towards the output layers.
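The contrast between the two selection rules can be sketched as follows. This is an illustration only: the function names, the strict `>` comparison, and the example scores are assumptions.

```python
def threshold_prune(tokens, scores, theta):
    """Keep tokens whose importance score exceeds the layer threshold."""
    return [tok for tok, s in zip(tokens, scores) if s > theta]


def top_k_prune(tokens, scores, k):
    """Reference top-k selection: requires ordering the scores first."""
    keep = sorted(range(len(tokens)), key=lambda i: scores[i], reverse=True)[:k]
    return [tokens[i] for i in sorted(keep)]  # restore original token order


tokens = ["the", "movie", "was", "great", "."]
scores = [0.10, 0.30, 0.05, 0.45, 0.10]  # hypothetical importance scores
kept_by_threshold = threshold_prune(tokens, scores, theta=0.2)
kept_by_top_k = top_k_prune(tokens, scores, k=2)
```

With a threshold, the number of retained tokens follows from the scores of the particular input, whereas top-$k$ fixes that number in advance.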

3.3. Learnable Threshold for Token Pruning

A key concern with the method above is how to determine the threshold values for each layer. Not only do threshold values change for different layers, they also vary between different tasks. We address this by making the thresholds (i.e., $\theta$ in Eq. 8) learnable. However, there are several challenges to consider. First, due to the binary nature of $M$ there is no gradient flow for pruned tokens. Second, the $M$ operator is non-differentiable, which prevents gradient flow into the thresholds. To address these challenges, we use a soft pruning scheme that simulates the original hard pruning while still propagating gradients to the thresholds as shown in Figure 3.

Soft Pruning Scheme. In the soft pruning scheme, we replace the non-differentiable mask $M^{(l)}$ with a differentiable soft mask using the sigmoid operation $\sigma$:
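One plausible form of such a relaxation applies the sigmoid to the gap between the importance score and the threshold, scaled by a temperature $T$. The sketch below assumes that form; the function names and the numeric values are hypothetical.

```python
import math


def hard_mask(score, theta):
    """Binary mask: 1 keeps the token, 0 prunes it."""
    return 1.0 if score > theta else 0.0


def soft_mask(score, theta, temperature):
    """Sigmoid relaxation of the hard mask; a smaller temperature is sharper."""
    z = (score - theta) / temperature
    # Numerically stable sigmoid: never exponentiate a large positive number.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


theta = 0.2
kept = soft_mask(0.3, theta, temperature=1e-3)    # close to 1
pruned = soft_mask(0.1, theta, temperature=1e-3)  # close to 0
border = soft_mask(0.2, theta, temperature=1e-3)  # exactly at the threshold
```

For small temperatures the soft mask is nearly binary away from the threshold, which is what lets it stand in for the hard mask while remaining differentiable in $\theta$.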

Here, $\mathcal{L}$ is the original loss function (e.g., cross-entropy loss), and $\lambda$ is the regularization parameter. Larger values of $\lambda$ result in higher pruning ratios. This regularization operator induces an additional gradient to the threshold:
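As a worked sketch, suppose the regularizer is the L1 norm of the soft masks averaged over the $L$ layers, and the soft mask is a temperature-scaled sigmoid of the importance score $s^{(l)}(x_i)$ minus the threshold. Both forms, and the symbols $\mathcal{L}_{\text{new}}$, $\mathcal{L}_{\text{reg}}$, $\tilde{M}$, and $s^{(l)}$, are assumptions made for illustration. Counting only the direct dependence of layer $l$'s mask on its own threshold gives:

```latex
\mathcal{L}_{\text{new}} = \mathcal{L} + \lambda \mathcal{L}_{\text{reg}}, \qquad
\mathcal{L}_{\text{reg}} = \frac{1}{L} \sum_{l=1}^{L} \big\| \tilde{M}^{(l)}(\mathbf{x}) \big\|_1, \qquad
\tilde{M}^{(l)}(x_i) = \sigma\!\left( \frac{s^{(l)}(x_i) - \theta^{(l)}}{T} \right)

\frac{\partial \mathcal{L}_{\text{reg}}}{\partial \theta^{(l)}}
  = \frac{1}{L} \sum_{i=1}^{n} \frac{\partial \tilde{M}^{(l)}(x_i)}{\partial \theta^{(l)}}
  = -\frac{1}{L\,T} \sum_{i=1}^{n} \tilde{M}^{(l)}(x_i) \left( 1 - \tilde{M}^{(l)}(x_i) \right)
```

Under these assumptions the gradient is always negative, so a gradient descent step on the regularizer raises $\theta^{(l)}$ and prunes more tokens. The largest contributions come from tokens whose scores lie near the threshold, where $\tilde{M}(1-\tilde{M})$ peaks.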

Experiments

We implemented LTP on RoBERTa$_{\text{base}}$ (Liu et al., 2019) using HuggingFace's repo (https://github.com/huggingface/transformers/) and tested on (English) GLUE tasks (Wang et al., 2018) and SQuAD 2.0 (Rajpurkar et al., 2018). For GLUE tasks (Wang et al., 2018), our evaluation covers sentence similarity (QQP (Iyer et al., 2017), MRPC (Dolan and Brockett, 2005), STS-B (Cer et al., 2017)), sentiment classification (SST-2 (Socher et al., 2013)), textual entailment (RTE (Dagan et al., 2005)) and natural language inference (MNLI (Williams et al., 2017), QNLI (Rajpurkar et al., 2016)). For evaluating the results, we measure classification accuracy and F1 score for MRPC and QQP, Pearson Correlation and Spearman Correlation for STS-B, and classification accuracy for the remaining tasks on validation sets. For the tasks with multiple metrics (i.e., MRPC, QQP, STS-B), we report their average. For SQuAD 2.0 (Rajpurkar et al., 2018), which is a question answering task, we measure F1 score for evaluating the results.

As mentioned in Section 3.3, the training procedure of LTP consists of two stages: soft pruning that trains both the model parameters and thresholds on downstream tasks, followed by hard pruning that fine-tunes the model parameters with fixed thresholds. We also compare LTP with the current state-of-the-art token pruning methods of SpAtten (Wang et al., 2020b) and LAT (Kim and Cho, 2020) following the implementation details in their papers. See A.1 for the details of the training process. We use PyTorch 1.8 throughout all experiments. For CPU inference speed experiments, we use an Intel Haswell CPU with 3.75 GB of memory on Google Cloud Platform. For GPU inference speed experiments, we use an AWS p3.2xlarge instance that has an NVIDIA V100 GPU with CUDA 11.1.

An important issue in previous work (Goyal et al., 2020; Kim and Cho, 2020) is that all input sequences for a specific task are padded to the nearest power of 2 from the 99th percentile of the sequence lengths, and then the pruned performance is compared with the padded baseline. This results in exaggerated performance gain over the baseline. For instance, in (Goyal et al., 2020), inputs from the SST-2 dataset are padded to 64, while its average sentence length is 26 (cf. Figure 1). With this approach, one can achieve roughly $2.5\times$ speedup by just removing padding. As such, we avoid any extra padding of input sequences and all speedups and throughputs we report are compared with the unpadded baselines.
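The rough $2.5\times$ figure follows from simple arithmetic if one assumes that cost grows linearly with the number of processed tokens. That linear cost model and the helper name `padding_speedup` are assumptions of this sketch.

```python
def padding_speedup(padded_len, avg_len):
    """Speedup from dropping padding if cost is linear in processed tokens."""
    return padded_len / avg_len


# SST-2: inputs padded to 64 tokens versus an average length of 26 tokens.
sst2_speedup = padding_speedup(padded_len=64, avg_len=26)
```

The ratio is about 2.46. Since attention also has a term that is quadratic in sequence length, a purely linear estimate likely understates what padding removal alone can contribute.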

4.2. Performance Evaluation

Table 1 lists the accuracy and GFLOPs for LTP. We select a model for each downstream task that achieves the smallest GFLOPs while constraining the accuracy degradation from the baseline (RoBERTa$_{\text{base}}$) to be at most 1%. Using our method, sequence lengths in each layer can vary across different input sentences. Therefore, we report the averaged GFLOPs of processing all input sentences in the development set. As shown in the table, our method achieves a speedup of $1.96\times$ on average and up to $2.10\times$ within 1% accuracy degradation.

Figure 4 plots the accuracy of LTP (blue lines) as well as the prior pruning methods (red lines for SpAtten and orange lines for LAT) with different FLOPs on GLUE tasks. LTP consistently outperforms SpAtten on all tasks, with up to ~2% higher accuracy under the same amount of FLOPs. Compared with LAT, LTP achieves up to ~2.5% higher accuracy for the same target FLOPs on all tasks except QQP. For QQP alone, LTP achieves at most ~0.2% lower accuracy than LAT.

An important observation is that for SST-2 and STS-B, where LTP (ours) outperforms LAT by large margins, the sequence length varies greatly from the training dataset to the evaluation dataset, as can be seen from the large KL-divergence in Table 2 and Figure 1 (b, c). On the other hand, for QQP, the only dataset on which LAT slightly outperforms LTP (ours), the sequence length distribution of the training dataset is almost identical to that of the evaluation dataset, as can be seen from the small KL-divergence in Table 2 and Figure 1 (a). This observation supports our claim in Sections 1 and 2: LTP is robust to sequence length variations because it does not fix the pruning configuration, unlike methods that use a single pruning configuration regardless of the input sequence length. This is important in practice because the sequence lengths during inference do not always follow the sequence length distribution of the training dataset, as in SST-2 and STS-B. We discuss this further in Section 4.3.
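A KL-divergence between two sequence length distributions can be computed along the following lines. This sketch assumes fixed-width length bins, additive smoothing for empty bins, and the direction KL(train || eval). None of these choices is specified above, and the lengths are hypothetical.

```python
import math
from collections import Counter


def length_histogram(lengths, bin_width, num_bins):
    """Normalized histogram of sequence lengths; the last bin absorbs overflow."""
    counts = Counter(min(n // bin_width, num_bins - 1) for n in lengths)
    total = len(lengths)
    return [counts.get(b, 0) / total for b in range(num_bins)]


def kl_divergence(p, q, eps=1e-8):
    """KL(p || q) with additive smoothing so empty bins stay finite."""
    p = [x + eps for x in p]
    q = [x + eps for x in q]
    zp, zq = sum(p), sum(q)
    return sum((a / zp) * math.log((a / zp) / (b / zq)) for a, b in zip(p, q))


train_lengths = [10, 12, 11, 30, 9]   # hypothetical training lengths
eval_lengths = [40, 45, 50, 12, 60]   # hypothetical evaluation lengths
p = length_histogram(train_lengths, bin_width=16, num_bins=4)
q = length_histogram(eval_lengths, bin_width=16, num_bins=4)
shifted = kl_divergence(p, q)    # large: the two distributions differ
matched = kl_divergence(p, p)    # zero: identical distributions
```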

For SQuAD 2.0, we have similar results to GLUE. As can be seen in Table 1 and Figure 5 (Left), we obtain nearly identical F1 score to baseline at 0.58 relative FLOPs, and $1.89\times$ speedup with less than 1% drop of F1 score. The SQuAD 2.0 dataset is divided into two subsets: the subset of examples where the answer to the question is included in the context text, and the subset that has no answer. In Figure 5 (Right), we further plot the results on each subset of the dataset (black and red for the one with and without answers, respectively). We see that the F1 score decreases for the subset with answers and increases for the subset without answers as we decrease the relative FLOPs. This is to be expected as the question answering head will predict no answer if the start and end points of the answer within the context cannot be determined due to high token pruning ratios. Thus, a careful setting of $\lambda$ in Eq. 12 is necessary to balance the accuracy between the two subsets.

Finally, we highlight that LTP has an additional gain over the prior top-$k$ based approaches by avoiding computationally inefficient top-$k$ operations, as further discussed in Section A.2.

4.3. Robustness to Sequence Length Variation

In Section 4.2, we claimed that LTP is more robust against sequence length variations from training time to evaluation time. Here, we analyze this more systematically. Ideally, performance should be independent of sequence length. To quantitatively test the robustness of pruning methods against sequence length variations, we train LTP and LAT on QNLI and QQP, but only using the training examples whose sequence lengths are below the median length of the evaluation dataset. We then evaluate the resulting models using the evaluation examples with sequence lengths (i) below the median (~Q2), (ii) between the median and the third quartile (Q2~Q3), and (iii) above the third quartile (Q3~) of the evaluation dataset. To make a fair comparison, we choose models from LTP and LAT that require similar FLOPs on ~Q2.
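The data split for this experiment can be sketched as follows. The helper names, the handling of examples that fall exactly on a boundary, and the quantile interpolation (the default of Python's `statistics.quantiles`) are assumptions of this sketch, and the lengths are hypothetical.

```python
import statistics


def filter_short_training(train_lengths, eval_lengths):
    """Indices of training examples shorter than the evaluation-set median."""
    median = statistics.median(eval_lengths)
    return [i for i, n in enumerate(train_lengths) if n < median]


def split_eval_by_length(eval_lengths):
    """Indices of evaluation examples in the ~Q2, Q2~Q3, and Q3~ buckets."""
    _, q2, q3 = statistics.quantiles(eval_lengths, n=4)
    buckets = {"~Q2": [], "Q2~Q3": [], "Q3~": []}
    for i, n in enumerate(eval_lengths):
        if n <= q2:
            buckets["~Q2"].append(i)
        elif n <= q3:
            buckets["Q2~Q3"].append(i)
        else:
            buckets["Q3~"].append(i)
    return buckets


eval_lengths = [5, 8, 12, 15, 20, 25, 30, 40]
train_lengths = [6, 10, 18, 35, 16]
buckets = split_eval_by_length(eval_lengths)
short_train = filter_short_training(train_lengths, eval_lengths)
```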

The results are listed in Table 3. LTP consistently achieves better accuracy and FLOPs over different sequence lengths, even with those that are significantly longer than the training sequences. In contrast, LAT shows significant accuracy degradation as longer sequences are over-pruned, which can be seen from the significant FLOPs reduction. In particular, LTP outperforms LAT by up to 16.44% and 9.20% on QNLI and QQP for the Q3~ evaluation dataset.

4.4. Ablation Studies

Instead of learning thresholds, we can set them manually. Because manually searching over the exponential search space is intractable, we add a constraint to the search space by assigning linearly rising threshold values for each layer, similar to how SpAtten (Wang et al., 2020b) assigns the token retain ratios: given the threshold of the final layer $\theta^{(L)}$, the threshold for layer $l$ is set as $\theta^{(L)} l / L$. We plot the accuracy and FLOPs of the manual threshold approach in Figure 4 as black lines. While this approach exhibits decent results on all downstream tasks, the learned thresholds consistently outperform the manual thresholds under the same FLOPs. This provides empirical evidence for the effectiveness of our threshold learning method.
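The linearly rising schedule is short enough to write out directly. In this sketch, layers are assumed to be indexed from 1 to $L$, and the final-layer threshold value and the function name are hypothetical.

```python
def linear_thresholds(final_threshold, num_layers):
    """Threshold of layer l (1-indexed) is final_threshold * l / num_layers."""
    return [final_threshold * l / num_layers for l in range(1, num_layers + 1)]


# Hypothetical final-layer threshold for a 12-layer model.
thresholds = linear_thresholds(final_threshold=0.012, num_layers=12)
```

Under this constraint only one value, the final-layer threshold, has to be searched, instead of one threshold per layer.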

4.5. Direct Throughput Measurement on Hardware

We directly measure throughputs on real hardware by deploying LTP on an NVIDIA V100 GPU and an Intel Haswell CPU. For inference, we completely remove the pruned tokens and rearrange the retained tokens into a blank-free sequence to have a latency gain. One consequence of adaptive pruning, however, is that each sequence will end up with a different pruning pattern and sequence length. As such, a naive hardware implementation of batched inference may require padding all the sequences in a batch to ensure that they all have the same length (i.e., the maximum sequence length in the batch), which results in a significant portion of computation being wasted to process padding tokens. To avoid this, we use NVIDIA's Faster Transformer (https://github.com/NVIDIA/FasterTransformer) for GPU implementation that requires large batch sizes. This framework dynamically removes and inserts padding tokens during inference so that most of the transformer operations effectively skip processing padding tokens. This enables fast inference even with irregular pruning lengths of individual sequences. For the CPU implementation, we find naive batching (i.e., padding sequences to the maximum sentence length) enough for good throughput.

The measured throughput results are shown in Figure 6 for different batch sizes. For all experiments, relative throughput is evaluated 3 times on the randomly shuffled datasets. LTP achieves up to ~$1.9\times$ and ~$2.0\times$ throughput improvement for QNLI and QQP on both CPU and GPU, as compared to the baseline. This is similar to the theoretical speedup inferred from the FLOPs reduction reported in Table 1. Importantly, the speedup of LTP increases with larger batch sizes on both CPU and GPU, demonstrating the effectiveness of LTP in batched cases.

4.6. LTP with Quantization and Knowledge Distillation

Here, we show that our token-level pruning method is compatible with other compression methods. In particular, we perform compression experiments by combining LTP with quantization and knowledge distillation (Hinton et al., 2015) together. For quantization, we use the static uniform symmetric integer quantization method (Gholami et al., 2021), which is easy to deploy in commodity hardware with minimal run-time overhead. All the model parameters are quantized to 8-bit integers, except for those of the embedding layer whose bit-width does not affect the inference speed. Afterwards, we apply knowledge distillation that helps recover accuracy for high compression ratios. We set the baseline RoBERTa$_{\text{base}}$ model as the teacher and the quantized LTP model as the student. We then distill knowledge from the teacher model into the student model through a knowledge distillation loss that matches the output logits of the classification layer and the output representations of the embedding layer in the teacher model to the counterparts in the student model. The training objective is a convex combination of the original loss and the knowledge distillation loss. As shown in Figure 7, we achieve up to $10\times$ reduction in bit operations (BOPs) with less than 2% accuracy degradation as compared to FP16 RoBERTa$_{\text{base}}$ by combining quantization and knowledge distillation. The results empirically show the effectiveness of LTP with other compression methods.
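The two ingredients named above can be illustrated with a small sketch. It assumes a per-tensor scale derived from the largest absolute value and a scalar mixing `weight`; these details, the function names, and the sample values are hypothetical and are not taken from the experimental setup.

```python
def quantize_symmetric(values, num_bits=8):
    """Uniform symmetric quantization to signed integers with a single scale."""
    qmax = 2 ** (num_bits - 1) - 1            # 127 for 8 bits
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale


def dequantize(q, scale):
    """Map integer codes back to approximate real values."""
    return [x * scale for x in q]


def combined_loss(task_loss, distill_loss, weight):
    """Convex combination of the task loss and the distillation loss."""
    return weight * task_loss + (1.0 - weight) * distill_loss


weights = [-1.0, 0.0, 0.25, 1.0]              # hypothetical parameter values
q, scale = quantize_symmetric(weights)
restored = dequantize(q, scale)
loss = combined_loss(task_loss=1.0, distill_loss=3.0, weight=0.25)
```

Symmetric quantization uses no zero-point offset, so zero maps exactly to the integer 0 and the round-trip error of each value is at most half a quantization step.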

Conclusions

In this work, we present Learned Token Pruning (LTP), a fully automated token pruning framework for transformers. LTP only requires comparison of token importance scores with threshold values to determine unimportant tokens, thus adding minimal complexity over the original transformer inference. Importantly, the threshold values are learned for each layer during training through a differentiable soft binarized mask that enables backpropagation of gradients to the threshold values. Compared to the state-of-the-art token pruning methods, LTP achieves up to ~2.5% higher accuracy with the same amount of FLOPs. Extensive experiments on GLUE and SQuAD show the effectiveness of LTP, as it achieves up to $2.10\times$ FLOPs reduction over the baseline model within only 1% of accuracy degradation. Our preliminary (and not highly optimized) implementation shows up to $1.9\times$ and $2.0\times$ throughput improvement on an Intel Haswell CPU and an NVIDIA V100 GPU. Furthermore, LTP exhibits significantly better robustness and consistency over different input sequence lengths.

References

Appendix A

The training procedure of LTP consists of two separate stages: soft pruning followed by hard pruning. For soft pruning, we train both the model parameters and the thresholds on downstream tasks for 1 to 10 epochs, depending on the dataset size. We find it effective to initialize the thresholds with linearly rising values as described in Section 4.4 with a fixed threshold of the final layer. We search for the optimal temperature $T$ in a search space of {1, 2, 5, 10, 20}e-4 and vary $\lambda$ from 0.001 to 0.4 to control the number of tokens to be pruned (and thus the FLOPs) for all experiments. We then fix the thresholds and perform an additional training with the hard pruning to fine-tune the model parameters only. More detailed hyperparameter settings are listed in Table A.1 for GLUE and SQuAD 2.0.

SpAtten is trained based on the implementation details in the paper: the first three layers retain all tokens and the remaining layers are assigned a linearly decaying token retain ratio until it reaches the final token retain ratio at the last layer. We vary the final token retain ratio from 1.0 to -1.0 (prune all tokens for non-positive retain ratios) to control the FLOPs of SpAtten. For both LTP and SpAtten, we use a learning rate of {0.5, 1, 2}e-5, except for the soft pruning stage of LTP where we use 2e-5. We follow the optimizer setting in RoBERTa (Liu et al., 2019) and use a batch size of 64 for all experiments.

LAT is trained using the same hyperparameter and optimizer setting in the paper except for the length drop probabilities: for more extensive search on more aggressive pruning configurations, we used 0.25, 0.3, 0.35, and 0.4 for the length drop probability instead of 0.2 in the original setting.

A.2. Computation Efficiency Comparison

Here we compare the efficiency of top-$k$ versus threshold operation. To do this, we use a batch size of 32 and average the latency over 1000 independent runs. For each sequence length, we test over five different token retain ratios from 10% to 50% (e.g., 10% token retain ratio is the case where we select the top-$k$ 10% of tokens from the input sequence).
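The shape of such a comparison can be sketched with a small timing harness. This is a pure-Python stand-in, not the measured setup: it ignores batching, uses far fewer runs, assumes evenly spaced retain ratios, and its absolute numbers will not match the reported latencies. All names are hypothetical.

```python
import heapq
import random
import timeit


def top_k_indices(scores, k):
    """Select the indices of the k largest scores."""
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)


def threshold_indices(scores, theta):
    """Select the indices whose score exceeds the threshold."""
    return [i for i, s in enumerate(scores) if s > theta]


def time_op(fn, repeats=20):
    """Average wall-clock seconds per call."""
    return timeit.timeit(fn, number=repeats) / repeats


random.seed(0)
results = {}
for seq_len in (128, 1024):
    scores = [random.random() for _ in range(seq_len)]
    for ratio in (0.1, 0.2, 0.3, 0.4, 0.5):
        k = int(seq_len * ratio)
        # Threshold chosen so that both operations keep the same k tokens.
        theta = sorted(scores, reverse=True)[k]
        results[(seq_len, ratio)] = (
            time_op(lambda: top_k_indices(scores, k)),
            time_op(lambda: threshold_indices(scores, theta)),
        )
```

Each entry of `results` holds the (top-$k$, threshold) latencies for one sequence length and retain ratio. Matching the threshold to the same $k$ keeps the two operations comparable, since both then return the same set of tokens.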

With the above setting, we directly measure the latency of these two operations on an Intel Haswell CPU, and report the results in Figure A.1. For the top-$k$ operation, there is a noticeable increase in latency when token retain ratios and sequence lengths become larger, whereas this is not an issue for our threshold pruning method as it only requires a comparison operation. More importantly, the top-$k$ operation incurs a huge latency overhead that is up to $7.4\times$ and $33.4\times$ slower than the threshold operation for sequence lengths of 128 and 1024, respectively. The inefficiency of top-$k$ is also further confirmed by (Wang et al., 2020b), where they report only $1.1\times$ speedup for GPT-2 without the top-$k$ hardware engine that they developed.

A.3. Discussion

Figure A.2 shows how the pruned sequence length decreases for input sequences of varying lengths. For LAT, the token pruning configuration is fixed for all sequences in the dataset. In LTP, token pruning can be more or less aggressive depending on the sequence content and the number of important tokens in the sequence. On average, LTP calculates 25.86% fewer tokens per layer than LAT for MNLI-m and 12.08% fewer tokens for SST-2. For both LTP and LAT, the model has been trained to produce a 1% drop in accuracy compared to baseline.

A.3.2. Unbiased Token Pruning for Various Sequence Lengths

Figure A.3 shows the distributions of initial sequence lengths for sequences that are correctly classified and for sequences that are not. We see that for multiple tasks, there is no significant correlation between the length of the sequence and the accuracy of the pruned models. Importantly, this suggests that our method is not biased towards being more accurate on longer or shorter sequences.

A.4. Comparison with TR-BERT on GLUE

Unlike LAT and SpAtten, TR-BERT (Ye et al., 2021) does not report results on the GLUE benchmark tasks described in the paper. We attempted to run TR-BERT on the GLUE tasks using the TR-BERT repo (https://github.com/thunlp/TR-BERT), but were unable to get the algorithm to converge to a high accuracy, despite varying the learning rate between 1e-6 and 1e-3 and the value of $\alpha$, the parameter that defines the length penalty, over the search space of $\{0.01, 0.05, 0.1, 0.5, 1, 2, 5\}$. We also varied the number of training epochs based on the number of examples in each task's training set. The authors of TR-BERT note the convergence difficulties of RL learning while describing the algorithm in their paper.