Bridging the Gap between Training and Inference for Neural Machine Translation

Wen Zhang, Yang Feng, Fandong Meng, Di You, Qun Liu

Introduction

Neural Machine Translation (NMT) has shown promising results and drawn more attention recently. Most NMT models fit in the encoder-decoder framework, including the RNN-based Sutskever et al. (2014); Bahdanau et al. (2015); Meng and Zhang (2019), the CNN-based Gehring et al. (2017) and the attention-based Vaswani et al. (2017) models, which predict the next word conditioned on the previous context words, deriving a language model over target words. The problem is that at training time the ground truth words are used as context, while at inference the entire sequence is generated by the resulting model on its own, so the previous words generated by the model are fed as context. As a result, the predicted words at training and inference are drawn from different distributions, namely, from the data distribution as opposed to the model distribution. This discrepancy, called exposure bias Ranzato et al. (2015), leads to a gap between training and inference. As the target sequence grows, errors accumulate along the sequence, and the model has to predict under conditions it has never met at training time.

Intuitively, to address this problem, the model should be trained to predict under the same condition it will face at inference. Inspired by Data As Demonstrator (DaD) Venkatraman et al. (2015), feeding both ground truth words and predicted words as context during training can be a solution. NMT models usually optimize the cross-entropy loss, which requires strict pairwise matching at the word level between the predicted sequence and the ground truth sequence. Once the model generates a word deviating from the ground truth sequence, the cross-entropy loss will correct the error immediately and draw the remaining generation back to the ground truth sequence. However, this causes a new problem. A sentence usually has multiple reasonable translations, and it cannot be said that the model makes a mistake simply because it generates a word different from the ground truth word.

For example, suppose the reference contains the phrase “with the rule”. Once the model generates “abide” as the third target word, the cross-entropy loss would force the model to generate “with” as the fourth word, so as to produce a larger sentence-level likelihood and stay in line with the reference, although “by” is the right choice after “abide”. Then “with” will be fed as context to generate “the rule”; as a result, the model is taught to generate “abide with the rule” (cand1), which is actually wrong. cand1 illustrates the overcorrection phenomenon. Another potential error is that even if the model predicts the right word “by” following “abide”, it may improperly produce “the law” when “by” is fed as context for the subsequent translation (cand2, “abide by the law”). Assume that the references and the training criterion let the model memorize the pattern that “the rule” always follows the word “with”. Then, to help the model recover from these two kinds of errors and create the correct translation “abide by the rule” (cand3), we should feed “with” as context rather than “by”, even when the previously predicted phrase is “abide by”. We refer to this solution as Overcorrection Recovery (OR).

In this paper, we present a method to bridge the gap between training and inference and to improve the overcorrection recovery capability of NMT. Our method first selects oracle words from its predicted words and then samples as context from the oracle words and the ground truth words. Meanwhile, the oracle words are selected not only with a word-by-word greedy search but also with a sentence-level evaluation, e.g. BLEU, which allows greater flexibility than the strict pairwise matching required by cross-entropy. At the beginning of training, the model selects ground truth words as context with a greater probability. As the model gradually converges, oracle words are chosen as context more often. In this way, the training process changes from a fully guided scheme towards a less guided scheme. Under this mechanism, the model has the chance to learn to handle the mistakes made at inference and also gains the ability to recover from overcorrection over alternative translations. We verify our approach on both the RNNsearch model and the stronger Transformer model. The results show that our approach significantly improves the performance of both models.

RNN-based NMT Model

Encoder. A bidirectional Gated Recurrent Unit (GRU) Cho et al. (2014) is used to acquire two sequences of hidden states, and the annotation of $x_i$ is $h_i = [\overrightarrow{h}_i; \overleftarrow{h}_i]$. Note that $e_{x_i}$ denotes the embedding vector of the word $x_i$.
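In the standard formulation, which we assume here, the forward GRU reads the source sentence from left to right and the backward GRU from right to left, each with its own parameters:

$$
\begin{aligned}
\overrightarrow{h}_i &= \mathrm{GRU}(\overrightarrow{h}_{i-1}, e_{x_i}),\\
\overleftarrow{h}_i &= \mathrm{GRU}(\overleftarrow{h}_{i+1}, e_{x_i}),\\
h_i &= [\overrightarrow{h}_i; \overleftarrow{h}_i].
\end{aligned}
$$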

Attention. The attention mechanism is designed to extract source information (called the source context vector). At the $j$-th step, the relevance between the target word $y_j^*$ and the $i$-th source word is evaluated and normalized over the source sequence.
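In the additive attention of Bahdanau et al. (2015), which we assume here, the relevance is scored against the previous decoder state $s_{j-1}$ and then normalized with a softmax over source positions; $v_a$, $W_a$ and $U_a$ are hypothetical parameter names:

$$
\begin{aligned}
r_{ij} &= v_a^{\top}\tanh\left(W_a s_{j-1} + U_a h_i\right),\\
\alpha_{ij} &= \frac{\exp(r_{ij})}{\sum_{i'=1}^{|x|}\exp(r_{i'j})}.
\end{aligned}
$$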

The source context vector $c_j$ is the weighted sum of all source annotations.
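With the normalized weights $\alpha_{ij}$, the context vector is $c_j = \sum_{i=1}^{|x|} \alpha_{ij} h_i$. The sketch below computes the attention weights and $c_j$ for one decoder step in plain Python, assuming the additive scoring form with hypothetical parameters `W_a`, `U_a`, `v_a` and toy dimensions; it is an illustration, not the experimental implementation.

```python
import math


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def softmax(xs):
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def additive_attention(s_prev, annotations, W_a, U_a, v_a):
    """Return (alpha, c) for one decoder step.

    s_prev: previous decoder state s_{j-1}
    annotations: list of encoder annotations h_i
    W_a, U_a, v_a: hypothetical attention parameters
    """
    q = matvec(W_a, s_prev)  # W_a s_{j-1}, shared by all source positions
    scores = []
    for h in annotations:
        k = matvec(U_a, h)  # U_a h_i
        hidden = [math.tanh(a + b) for a, b in zip(q, k)]
        scores.append(sum(w * x for w, x in zip(v_a, hidden)))  # r_ij
    alpha = softmax(scores)  # normalize over the source sequence
    dim = len(annotations[0])
    # c_j = sum_i alpha_ij h_i
    c = [sum(alpha[i] * annotations[i][d] for i in range(len(annotations)))
         for d in range(dim)]
    return alpha, c


# Toy usage: three source annotations of dimension 2.
s_prev = [0.1, -0.2]
H = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
W_a = [[1.0, 0.0], [0.0, 1.0]]
U_a = [[0.5, 0.0], [0.0, 0.5]]
v_a = [1.0, -1.0]
alpha, c = additive_attention(s_prev, H, W_a, U_a, v_a)
```

Because the weights sum to one, $c_j$ is a convex combination of the annotations, so each of its coordinates lies within the range spanned by the $h_i$.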

Decoder. The decoder employs a variant of GRU to unroll the target information. At the $j$-th step, it computes the target hidden state $s_j$.
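Written generically, with the gating equations of the GRU variant omitted (the argument order is our notation), the update at the $j$-th step takes the previous state, the embedding of the previous ground truth word and the source context vector:

$$
s_j = \mathrm{GRU}\left(s_{j-1}, e_{y_{j-1}^{*}}, c_j\right)
$$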

The probability distribution $P_j$ over all the words in the target vocabulary is produced conditioned on the embedding of the previous ground truth word, the source context vector and the hidden state.
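A common realization, assumed here, passes these three inputs through a feed-forward layer $g$ (for example with maxout) and projects the result to the vocabulary with a matrix $W_o$:

$$
\begin{aligned}
t_j &= g\left(e_{y_{j-1}^{*}}, c_j, s_j\right),\\
o_j &= W_o t_j,\\
P_j &= \mathrm{softmax}(o_j).
\end{aligned}
$$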

Approach

The main framework of our method (as shown in Figure 1) is to feed as context either the ground truth words or the previously predicted words, i.e. oracle words, with a certain probability. This can potentially reduce the gap between training and inference by training the model to handle the situation that will appear at test time. We introduce two methods to select the oracle words. One method selects the oracle words at the word level with a greedy search algorithm; the other selects an oracle sequence that is optimal at the sentence level. The sentence-level oracle allows $n$-gram matching with the ground truth sequence and hence inherently has the ability to recover from overcorrection for the alternative context. To predict the $j$-th target word $y_j$, our approach involves the following steps:

1. Select an oracle word $y_{j-1}^{\text{oracle}}$ at the $(j-1)$-th step, at either the word level or the sentence level.
2. Sample from the ground truth word $y_{j-1}^{*}$ with probability $p$, or from the oracle word $y_{j-1}^{\text{oracle}}$ with probability $1-p$.
3. Use the sampled word as $y_{j-1}$, replace $y_{j-1}^{*}$ in Equations (6) and (7) with $y_{j-1}$, and then perform the remaining prediction steps of the attention-based NMT model.
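Steps 1 and 2 can be sketched for the word-level case as follows. This is a minimal illustration under stated assumptions: the oracle is taken as the argmax of the previous step's logits $o_{j-1}$ after adding Gumbel noise, $p$ comes from the decay schedule, and all function names are hypothetical rather than taken from the authors' code.

```python
import math
import random


def gumbel_noise(rng):
    """Standard Gumbel sample: -log(-log(u)) with u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() can return 0.0; log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def word_level_oracle(logits, rng, tau=1.0, noisy=True):
    """Step 1 (word level): index of the oracle word at step j-1.

    Gumbel noise perturbs the greedy choice. Dividing by tau > 0 does not
    change the argmax; tau only matters if the perturbed scores are
    renormalized, e.g. when ranking hypotheses in beam search.
    """
    scores = [(o + (gumbel_noise(rng) if noisy else 0.0)) / tau for o in logits]
    return max(range(len(scores)), key=scores.__getitem__)


def choose_context(gold_prev, oracle_prev, p, rng):
    """Step 2: feed the ground truth word with probability p, else the oracle."""
    return gold_prev if rng.random() < p else oracle_prev


# Toy usage over a 3-word vocabulary (values are illustrative only).
rng = random.Random(0)
logits = [2.0, 0.5, -1.0]  # o_{j-1}
oracle = word_level_oracle(logits, rng, noisy=False)  # greedy choice: 0
ctx = choose_context(gold_prev=1, oracle_prev=oracle, p=1.0, rng=rng)  # always gold
```

Only the choice of context is randomized here; the selected index is then fed to the decoder as $y_{j-1}$ in step 3.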

Note that the Gumbel noise is just used to select the oracle and it does not affect the loss function for training.

Sentence-Level Oracle

However, the sentence-level oracle raises a problem. Since the model samples from the ground truth word and the sentence-level oracle word at each step, the two sequences must have the same number of words, which the naive beam search decoding algorithm cannot guarantee. To solve this problem, we introduce force decoding to make sure that the two sequences have the same length.
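A minimal sketch of force decoding, written as a greedy loop for brevity (in beam search the same constraint would be applied to every hypothesis). The rule used here is an assumption: before the reference length $|y^*|$ is reached, EOS is never chosen, and if it ranks first the runner-up is taken instead; at step $|y^*|+1$, EOS is emitted. `step_scores` is a hypothetical stand-in for the decoder.

```python
def force_decode_greedy(step_scores, ref_len, eos):
    """Greedy decoding forced to emit exactly ref_len words, then EOS.

    step_scores(prefix) -> list of scores over the vocabulary
    (a hypothetical stand-in for the NMT decoder).
    """
    prefix = []
    for _ in range(ref_len):
        scores = step_scores(prefix)
        ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        # EOS is not allowed before the reference length is reached:
        # if it ranks first, take the second-best word instead.
        word = ranked[1] if ranked[0] == eos else ranked[0]
        prefix.append(word)
    return prefix + [eos]  # EOS is forced at step ref_len + 1


# Toy usage: vocabulary {0: EOS, 1, 2, 3}; the toy model wants to stop early.
EOS = 0


def toy_scores(prefix):
    if len(prefix) >= 2:
        return [5.0, 1.0, 0.5, 0.2]  # EOS ranks first after two words
    return [0.1, 0.3, 2.0, 1.0]


out = force_decode_greedy(toy_scores, ref_len=4, eos=EOS)  # [2, 2, 1, 1, 0]
```

The output always has the reference length plus the final EOS, so it can be aligned position by position with the ground truth sequence.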

Sampling with Decay

Borrowing ideas from, but differing from, Bengio et al. (2015), who used a schedule to decrease $p$ as a function of the mini-batch index, we define $p$ with a decay function that depends on the index of training epochs $e$ (starting from 0).
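A decay function with the properties stated next (a single hyper-parameter $\mu$ and strict monotone decrease in $e$) is the inverse-sigmoid form below; we assume this form here, indexed by epoch rather than by mini-batch:

$$
p = \frac{\mu}{\mu + \exp\left(e/\mu\right)}
$$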

Here $\mu$ is a hyper-parameter, and the function is strictly monotonically decreasing. As training proceeds, the probability $p$ of feeding ground truth words therefore decreases gradually.
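The schedule can be tabulated with a few lines of Python. This sketch assumes the inverse-sigmoid form $p = \mu/(\mu + \exp(e/\mu))$ (our assumption, consistent with the description above) and uses $\mu = 12$, the value reported in the experimental settings; the function name is hypothetical.

```python
import math


def ground_truth_prob(epoch, mu):
    """Probability p of feeding the ground truth word at a given epoch (from 0).

    Assumes the inverse-sigmoid form p = mu / (mu + exp(epoch / mu)).
    """
    return mu / (mu + math.exp(epoch / mu))


# p at epochs 0, 10, ..., 50 with mu = 12.
schedule = [ground_truth_prob(e, mu=12) for e in range(0, 60, 10)]
```

With $\mu = 12$, $p$ starts at $12/13 \approx 0.92$ at epoch 0 and falls to about 0.5 around epoch 30, so ground truth words dominate the early epochs. A larger $\mu$ both raises the starting value and stretches the decay over more epochs.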

Training

After selecting $y_{j-1}$ with the above method, we obtain the word distribution of $y_j$ according to Equations (6), (7), (8) and (9). We do not add the Gumbel noise to the distribution when calculating the loss for training. The objective is to maximize the probability of the ground truth sequence based on maximum likelihood estimation (MLE); thus, we minimize the negative log-likelihood of the ground truth words over the training data.
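Written out, with $N$ training pairs $(x^n, y^{*n})$ and $P_j^n[w]$ denoting the probability of word $w$ in $P_j$ for the $n$-th pair (our notation), the loss is:

$$
\mathcal{L}(\theta) = -\sum_{n=1}^{N}\sum_{j=1}^{|y^{*n}|}\log P_j^n\left[y_j^{*n}\right]
$$

The context used to compute each $P_j^n$ is the sampled word $y_{j-1}$, while the target is always the ground truth word $y_j^{*n}$.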

Related Work

Some other researchers have noticed the problem of exposure bias in NMT and tried to solve it. Venkatraman et al. (2015) proposed Data As Demonstrator (DaD), which initializes the training examples as pairs of adjacent ground truth words and, at each step, adds the predicted word paired with the next ground truth word as a new training example. Bengio et al. (2015) further developed the method by sampling as context from the previous ground truth word and the previous predicted word with a changing probability, rather than treating them equally throughout the whole training process. This is similar to our method, but they include neither the sentence-level oracle to relieve the overcorrection problem nor the noise perturbations on the predicted distribution.

Another line of work is sentence-level training, based on the idea that a sentence-level metric, e.g., BLEU, brings a certain degree of flexibility for generation and is hence more robust in mitigating the exposure bias problem. To avoid the problem of exposure bias, Ranzato et al. (2015) presented a novel algorithm, Mixed Incremental Cross-Entropy Reinforce (MIXER), for sequence-level training, which directly optimizes the sentence-level BLEU used at inference. Shen et al. (2016) introduced Minimum Risk Training (MRT) into the end-to-end NMT model, which optimizes model parameters by directly minimizing the expected loss with respect to arbitrary evaluation metrics, e.g., sentence-level BLEU. Shao et al. (2018) proposed to eliminate the exposure bias through a probabilistic n-gram matching objective, which trains NMT models under the greedy decoding strategy.

Experiments

We carry out experiments on the NIST Chinese→English (Zh→En) and the WMT’14 English→German (En→De) translation tasks.

In training the NMT model, we limit the source and target vocabularies to the 30K most frequent words for both sides in the Zh→En translation task, covering approximately 97.7% and 99.3% of the words in the two corpora, respectively. For the En→De translation task, sentences are encoded using byte-pair encoding (BPE) Sennrich et al. (2016) with 37K merge operations for both source and target languages, which have vocabularies of 39,418 and 40,274 tokens, respectively. We limit the length of sentences in the training datasets to 50 words for Zh→En and 128 subwords for En→De. For the RNNSearch model, the dimension of the word embeddings and the hidden layer is 512, and the beam size in testing is 10. All parameters are initialized from the uniform distribution over $[-0.1, 0.1]$. The mini-batch stochastic gradient descent (SGD) algorithm is employed to train the model parameters with a batch size of 80. Moreover, the learning rate is adjusted by the Adadelta optimizer Zeiler (2012) with $\rho = 0.95$ and $\epsilon = 10^{-6}$. Dropout is applied to the output layer with a dropout rate of 0.5. For the Transformer model, we train the base model with the default settings of fairseq (https://github.com/pytorch/fairseq).

Systems

RNNsearch: Our implementation of an improved model as described in Section 2, where the decoder employs two GRUs and an attention mechanism. Specifically, Equation 6 is substituted with a computation that uses the two GRUs.
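One two-GRU arrangement consistent with this description, given here as an assumption in the notation of Section 2, first updates an intermediate state $\tilde{s}_j$ from the previous word, uses it as the attention query, and then combines it with the context vector; $\mathrm{GRU}_1$, $\mathrm{GRU}_2$ and $\mathrm{ATT}$ are hypothetical names:

$$
\begin{aligned}
\tilde{s}_j &= \mathrm{GRU}_1\left(e_{y_{j-1}^{*}}, s_{j-1}\right),\\
c_j &= \mathrm{ATT}\left(\tilde{s}_j, [h_1, \ldots, h_{|x|}]\right),\\
s_j &= \mathrm{GRU}_2\left(c_j, \tilde{s}_j\right).
\end{aligned}
$$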

SS-NMT: Our implementation of the scheduled sampling (SS) method of Bengio et al. (2015) on the basis of the RNNsearch. The decay scheme is the same as Equation 15 in our approach.

MIXER: Our implementation of Mixed Incremental Cross-Entropy Reinforce Ranzato et al. (2015), where the sentence-level metric is BLEU and the average reward is acquired according to its offline method with a 1-layer linear regressor.

OR-NMT: Based on the RNNsearch, we introduce the word-level oracles, the sentence-level oracles and the Gumbel noise to enhance the overcorrection recovery capacity. For the sentence-level oracle selection, we set the beam size to 3, $\tau = 0.5$ in Equation (11) and $\mu = 12$ for the decay function in Equation (15). OR-NMT is the abbreviation of NMT with Overcorrection Recovery.
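With a beam of 3, the sentence-level oracle is the force-decoded candidate that scores best under sentence-level BLEU against the reference. The sketch below assumes the candidates are already available and uses add-one smoothing for higher-order n-grams, which is our assumption rather than a documented setting; the function names are hypothetical.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Counter of the n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu(hyp, ref, max_n=4):
    """Smoothed sentence-level BLEU (add-one smoothing for n >= 2; our choice)."""
    if not hyp:
        return 0.0
    log_prec = 0.0
    for n in range(1, max_n + 1):
        h, r = ngrams(hyp, n), ngrams(ref, n)
        match = sum(min(c, r[g]) for g, c in h.items())  # clipped counts
        total = max(len(hyp) - n + 1, 0)
        if n > 1:
            match, total = match + 1, total + 1
        if match == 0:
            return 0.0
        log_prec += math.log(match / total) / max_n
    # Brevity penalty.
    bp = 1.0 if len(hyp) > len(ref) else math.exp(1 - len(ref) / len(hyp))
    return bp * math.exp(log_prec)


def sentence_level_oracle(candidates, ref):
    """Pick the candidate with the highest sentence BLEU against the reference."""
    return max(candidates, key=lambda cand: sentence_bleu(cand, ref))


# Toy usage: three same-length candidates, as produced by force decoding.
ref = "the cat sat on the mat".split()
cands = [
    "the cat sat on a mat".split(),
    "a cat is on the mat".split(),
    "the cat sat on the mat".split(),
]
oracle_seq = sentence_level_oracle(cands, ref)  # the exact match wins
```

Because BLEU rewards n-gram overlap anywhere in the sentence, a candidate can be chosen as the oracle even when it diverges from the reference word by word.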

Results on Zh→En Translation

We verify our method on two baseline models with the NIST Zh→En datasets in this section.

As shown in Table 1, Tu et al. (2016) propose to model coverage in RNN-based NMT to improve the adequacy of translations. Shen et al. (2016) propose minimum risk training (MRT) for NMT to directly optimize model parameters with respect to BLEU scores. Zhang et al. (2017) model distortion to enhance the attention model. Compared with them, our baseline system RNNsearch 1) outperforms the previous shallow RNN-based NMT system equipped with the coverage model Tu et al. (2016); and 2) achieves competitive performance with MRT Shen et al. (2016) and Distortion Zhang et al. (2017) on the same datasets. We hope that the strong shallow baseline system used in this work makes the evaluation convincing.

We also compare with two other related methods that aim at solving the exposure bias problem: scheduled sampling Bengio et al. (2015) (SS-NMT) and sentence-level training Ranzato et al. (2015) (MIXER). From Table 1, we can see that both SS-NMT and MIXER achieve improvements by taking measures to mitigate the exposure bias. Our approach, OR-NMT, outperforms the baseline system RNNsearch and the competitive comparison systems by directly incorporating the sentence-level oracle and noise perturbations to relieve the overcorrection problem. In particular, OR-NMT significantly outperforms RNNsearch by +2.36 BLEU points on average over the four test sets. Compared with the two related models, our approach gives further significant improvements on most test sets, by about +1.2 BLEU points on average.

Results on the Transformer

The methods we propose can also be adapted to the stronger Transformer model. The evaluation results are listed in Table 1. Our word-level method improves the base model by +0.54 BLEU points on average, and the sentence-level method brings a further +1.0 BLEU point improvement.

Factor Analysis

We propose several strategies to improve the performance of our approach in relieving the overcorrection problem, including utilizing the word-level oracle, utilizing the sentence-level oracle, and incorporating Gumbel noise for oracle selection. To investigate the influence of these factors, we conduct experiments and list the results in Table 2.

When only employing the word-level oracle, the translation performance is improved by +1.21 BLEU points, which indicates that feeding predicted words as context can mitigate exposure bias. When employing the sentence-level oracle, we achieve a further +0.62 BLEU point improvement. This shows that the sentence-level oracle performs better than the word-level oracle in terms of BLEU. We conjecture that the superiority may come from greater flexibility for word generation, which can mitigate the problem of overcorrection. By incorporating Gumbel noise during the generation of the word-level and sentence-level oracle words, the BLEU scores are further improved by 0.56 and 0.53, respectively. This indicates that Gumbel noise can help the selection of each oracle word, which is consistent with our claim that Gumbel-Max provides an efficient and robust way to sample from a categorical distribution.
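The Gumbel-Max property referred to above can be checked empirically: taking $\arg\max_k (o_k + g_k)$ with independent standard Gumbel noise $g_k$ yields samples distributed as $\mathrm{softmax}(o)$. The sketch below is a self-contained illustration with toy logits, not the experimental code.

```python
import math
import random
from collections import Counter


def gumbel_max_sample(logits, rng):
    """Draw one index via the Gumbel-Max trick: argmax_k (logits[k] + g_k)."""
    best, best_score = 0, -math.inf
    for k, o in enumerate(logits):
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        score = o - math.log(-math.log(u))  # logit plus standard Gumbel noise
        if score > best_score:
            best, best_score = k, score
    return best


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


rng = random.Random(42)
logits = [1.0, 0.0, -1.0]
n = 100_000
counts = Counter(gumbel_max_sample(logits, rng) for _ in range(n))
empirical = [counts[k] / n for k in range(len(logits))]
target = softmax(logits)  # about [0.665, 0.245, 0.090]
```

With 100,000 draws the empirical frequencies agree with the softmax probabilities to within about one percentage point, while each draw costs only one noise sample per vocabulary entry and an argmax.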

About Convergence

In this section, we analyze the influence of different factors on convergence. Figure 4 gives the training loss curves of the RNNsearch, the word-level oracle (WO) without noise and the sentence-level oracle (SO) with noise. In training, the BLEU score on the validation set is used to select the best model; a detailed comparison among the BLEU score curves under different factors is shown in Figure 5. RNNsearch converges fast and achieves its best result at the 7th epoch, while the training loss continues to decline after the 7th epoch until the end. Thus, the training of RNNsearch may encounter the overfitting problem.

Figures 4 and 5 also reveal that integrating the oracle sampling and the Gumbel noise leads to slightly slower convergence, and the training loss does not keep decreasing after the best results appear on the validation set. This is consistent with our intuition that oracle sampling and noise can avoid overfitting, although they need a longer time to converge.

Figure 6 shows the BLEU score curves on the MT03 test set under different factors (note that the “SO” model without noise is trained based on the pre-trained RNNsearch model, as shown by the red dashed lines in Figures 5 and 6). When sampling oracles with noise ($\tau = 0.5$) at the sentence level, we obtain the best model. Without noise, our system converges to a lower BLEU score. This is easy to understand: repeatedly using its own results during training without any regularization leads to overfitting and quick convergence. In this sense, our method benefits from the sentence-level sampling and Gumbel noise.

About Length

Figure 7 shows the BLEU scores of generated translations on the MT03 test set with respect to the lengths of the source sentences. In particular, we split the translations for the MT03 test set into different bins according to the length of the source sentences, then compute the BLEU score for the translations in each bin separately, with the results reported in Figure 7. Our approach achieves substantial improvements over the baseline system in all bins, especially in the bins (10,20] and (40,50] and in the bin (70,80] of super-long sentences. The cross-entropy loss requires the predicted sequence to be exactly the same as the ground truth sequence, which is more difficult to achieve for long sentences, while our sentence-level oracle can help recover from this kind of overcorrection.
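The binning can be sketched as follows, assuming bins of width 10 that are open on the left and closed on the right, as in (10,20]; the BLEU score of each bin would then be computed with any corpus-level BLEU implementation over the sentence indices it contains. The function name is hypothetical.

```python
from collections import defaultdict


def bucket_by_length(src_lengths, width=10):
    """Group sentence indices into length bins (lo, lo + width]."""
    bins = defaultdict(list)
    for idx, n in enumerate(src_lengths):
        lo = ((n - 1) // width) * width  # length 10 -> (0, 10], 11 -> (10, 20]
        bins[(lo, lo + width)].append(idx)
    return dict(bins)


# Toy usage: source lengths of five sentences.
bins = bucket_by_length([5, 10, 11, 20, 75])
# {(0, 10): [0, 1], (10, 20): [2, 3], (70, 80): [4]}
```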

Effect on Exposure Bias

To validate whether the improvements are mainly obtained by addressing the exposure bias problem, we randomly select 1K sentence pairs from the Zh→En training data and use the pre-trained RNNSearch model and the proposed model to decode the source sentences. The BLEU score of the RNNSearch model is 24.87, while our model scores 2.18 points higher. We then count the ground truth words whose probabilities in the predicted distributions produced by our model are greater than those produced by the baseline model, and denote this number as $\mathcal{N}$. There are 28,266 gold words in total in the references, and $\mathcal{N} = 18{,}391$. The proportion is $18{,}391/28{,}266 = 65.06\%$, which supports the claim that the improvements are mainly obtained by addressing the exposure bias problem.
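The count and the proportion can be reproduced as below, assuming the two models' probabilities of each gold word have been collected into position-aligned lists; the helper name is hypothetical, and only the final ratio uses the numbers reported above.

```python
def count_improved(gold_probs_ours, gold_probs_base):
    """Count gold words whose probability is higher under our model."""
    return sum(1 for a, b in zip(gold_probs_ours, gold_probs_base) if a > b)


# Toy usage with three gold words: only the first improves (ties do not count).
toy_n = count_improved([0.5, 0.2, 0.9], [0.4, 0.3, 0.9])

# Proportion reported above: N / total gold words.
ratio = 18_391 / 28_266
print(f"{ratio:.2%}")  # 65.06%
```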

Results on En→De Translation

We also evaluate our approach on the WMT’14 En→De translation benchmark. From the results listed in Table 3, we conclude that the proposed method significantly outperforms the competitive baseline models as well as related approaches. Similar to the results on the Zh→En task, both scheduled sampling and MIXER improve the two baseline systems. Our method improves the RNNSearch and Transformer baseline models by +1.59 and +1.31 BLEU points, respectively. These results demonstrate that our model works well across different language pairs.

Conclusion

The end-to-end NMT model generates a translation word by word, with the ground truth words as context at training time, as opposed to the previous words generated by the model as context at inference. To mitigate the discrepancy between training and inference, when predicting a word, we feed as context either the ground truth word or the previously predicted word according to a sampling scheme. The predicted words, referred to as oracle words, can be generated with word-level or sentence-level optimization. Compared to the word-level oracle, the sentence-level oracle can further equip the model with the ability of overcorrection recovery. To fully expose the model to the circumstances it faces at inference, we sample the context word from the ground truth words with a decaying probability. We verified the effectiveness of our method with two strong baseline models and related works on real translation tasks and achieved significant improvements on all the datasets. We also conclude that the sentence-level oracle shows superiority over the word-level oracle.

Acknowledgments

We thank the three anonymous reviewers for their valuable suggestions. This work was supported by National Natural Science Foundation of China (NO. 61662077, NO. 61876174) and National Key R&D Program of China (NO. YS2017YFGH001428).

References