Non-Autoregressive Neural Machine Translation with Enhanced Decoder Input

Junliang Guo, Xu Tan, Di He, Tao Qin, Linli Xu, Tie-Yan Liu

1 Introduction

The neural network based encoder-decoder framework has achieved very promising performance for machine translation, and different network architectures have been proposed, including RNNs (?; ?; ?; ?), CNNs (?), and the self-attention based Transformer (?). All these models translate a source sentence in an autoregressive manner, i.e., they generate a target sentence word by word from left to right (?), and the generation of the $t$-th token $y_t$ depends on the previously generated tokens $y_{1:t-1}$:

$$P(y \mid x) = \prod_{t=1}^{T_y} P(y_t \mid y_{1:t-1}, x). \qquad (1)$$

Since autoregressive translation (AT) models generate target tokens sequentially, inference speed becomes a bottleneck for real-world translation systems, in which fast response and low latency are expected. To speed up the inference of machine translation, non-autoregressive translation (NAT) models (?) have been proposed, which generate all target tokens independently and simultaneously. Instead of using previously generated tokens as in AT models, NAT models take other global signals derived from the source sentence as input. Specifically, the Non-AutoRegressive Transformer (NART) (?) takes a copy of the source sentence $x$ as the decoder input, where the copy process is guided by fertilities (?), which represent how many times each source token will be copied; after that, all target tokens are predicted simultaneously:

$$P(y \mid x) = \prod_{t=1}^{T_y} P(y_t \mid \hat{x}, x), \qquad (2)$$

where $\hat{x}=(\hat{x}_1,\dots,\hat{x}_{T_y})$ is the copied source sentence and $T_y$ is the length of the target sentence $y$.
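The fertility-guided copy that builds $\hat{x}$ can be sketched in a few lines of Python. This is a minimal illustration, not NART's implementation: the fertilities are given by hand here, whereas NART predicts them with a learned fertility predictor, and the function name `copy_by_fertility` and the toy German tokens are hypothetical.

```python
from typing import List, Sequence


def copy_by_fertility(source: Sequence[str], fertilities: Sequence[int]) -> List[str]:
    """Hypothetical helper: build a NART-style decoder input by repeating each
    source token as many times as its (here hand-picked) fertility."""
    if len(source) != len(fertilities):
        raise ValueError("expected exactly one fertility per source token")
    copied: List[str] = []
    for token, fertility in zip(source, fertilities):
        if fertility < 0:
            raise ValueError("fertilities must be non-negative")
        # a fertility of 0 drops the token; a fertility of k copies it k times
        copied.extend([token] * fertility)
    return copied


source = ["wir", "danken", "ihnen"]
x_hat = copy_by_fertility(source, [1, 2, 0])
print(x_hat)       # ['wir', 'danken', 'danken']
print(len(x_hat))  # 3, i.e. T_y equals the sum of the fertilities
```

Note that the copied sequence still consists of source-language tokens, which is exactly the weakness discussed next.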

While NAT models significantly reduce the inference latency, they suffer from accuracy degradation compared with their autoregressive counterparts. We notice that the encoders of AT and NAT models are the same; the differences lie in the decoder. In AT models, the generation of the $t$-th token $y_t$ is conditioned on the previously generated tokens $y_{1:t-1}$, which provide strong target-side context information. In contrast, as NART models generate tokens in parallel, no such target-side information is available. Although the fertilities are learned to cover target-side information in NART (?), the information contained in the copied source tokens $\hat{x}$ is indirect and weak, because the copied tokens are still in the domain of the source language, while the inputs to the decoder of AT models are target-side tokens $y_{1:t-1}$. Consequently, the decoder of a NAT model has to handle the translation task conditioned on less and weaker information than its AT counterpart, which leads to inferior accuracy. As verified by our study (see Figure 2 and Table 3), NART performs poorly on long sentences, which need stronger target-side conditional information for correct translation than short sentences do.

In this paper, we aim to enhance the decoder inputs of NAT models so as to reduce the difficulty of the task that the decoder needs to handle. Our basic idea is to directly feed target-side tokens as the inputs of the decoder. We propose two concrete methods to generate a decoder input $\hat{y}=(\hat{y}_1,\dots,\hat{y}_{T_y})$ that contains coarse target-side information. The first is based on a phrase table and explicitly translates source tokens into target-side tokens through a pre-trained phrase table. The second linearly maps the embeddings of source tokens into the target-side embedding space, and the mapped embeddings are then fed into the decoder. The mapping is learned in an end-to-end manner by minimizing the $L_2$ distance between the mapped source and target embeddings at the sentence level, as well as an adversarial loss between the mapped source embeddings and the target embeddings at the word level.

With target-side information as inputs, the decoder works as follows:

$$P(y \mid x) = \prod_{t=1}^{T_y} P(y_t \mid \hat{y}, x), \qquad (3)$$

where $\hat{y}$ is the enhanced decoder input provided by our methods. The decoder can now generate all $y_t$ in parallel, conditioned on the global information $\hat{y}$, which is closer to the target tokens $y_{1:t-1}$ used by the AT model. In this way, the difficulty of the task for the decoder is largely reduced.

We conduct experiments on three tasks to verify the proposed methods. On the WMT14 English-German, WMT16 English-Romanian and IWSLT14 German-English translation tasks, our model outperforms all compared non-autoregressive baseline models. Specifically, we obtain BLEU scores of 24.28 and 34.51 on WMT14 En-De and WMT16 En-Ro, outperforming the non-autoregressive baseline (19.17 and 29.79, reported in ? (?)).

2 Background

Deep neural networks with the encoder-decoder framework have achieved great success in machine translation, with different choices of architecture such as recurrent neural networks (RNNs) (?; ?), convolutional neural networks (CNNs) (?), and the self-attention based Transformer (?; ?). Early RNN-based models have an inherently sequential architecture that prevents them from being parallelized during training and inference, which is partially solved by CNNs and self-attention based models (?; ?; ?; ?; ?). Since the entire target translation is exposed to the model at training time, each input token of the decoder is the previous ground-truth token, and the whole training process can be parallelized with well-designed CNN or self-attention models. However, the autoregressive nature still creates a bottleneck at the inference stage, since without the ground truth, the prediction of each target token has to be conditioned on previously predicted tokens. See Table 1 for a comparison of whether different models are parallelizable.

2.2 Non-Autoregressive Neural Machine Translation

To be consistent throughout the rest of the paper, we denote the decoder input generally as $z=(z_1,\dots,z_{T_y})$, which stands for $\hat{x}$ and $\hat{y}$ in Equations (2) and (3). The recently proposed non-autoregressive model NART (?) breaks the inference bottleneck by exposing all decoder inputs to the network simultaneously. The generation of $z$ is guided by a fertility prediction function, which predicts how many target tokens each source token translates to; each source token is then copied according to its fertility to form the decoder input $z$. Given $z$, the conditional probability $P(y \mid x)$ is defined as:

$$P(y \mid x; \theta_{\textrm{enc}}, \theta_{\textrm{dec}}) = \prod_{t=1}^{T_y} P(y_t \mid z, x; \theta_{\textrm{enc}}, \theta_{\textrm{dec}}), \qquad (4)$$

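where $T_y$ is the length of the target sentence, which equals the sum of all fertilities, and $\theta_{\textrm{enc}}$ and $\theta_{\textrm{dec}}$ denote the parameters of the encoder and decoder. The negative log-likelihood loss for the NAT model becomes:

$$L_{\textrm{neg}}(x, y; \theta_{\textrm{enc}}, \theta_{\textrm{dec}}) = -\sum_{t=1}^{T_y} \log P(y_t \mid z, x; \theta_{\textrm{enc}}, \theta_{\textrm{dec}}). \qquad (5)$$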

Although non-autoregressive models can achieve a 15× speedup compared with autoregressive models, they suffer from accuracy degradation. Since the conditional dependencies within the target sentence ($y_t$ depends on $y_{1:t-1}$) are removed from the decoder input, the decoder is unable to leverage the inherent sentence structure for prediction. Hence the decoder has to figure out such target-side information by itself from only the source-side information during training, which is a much more challenging task than the one faced by its autoregressive counterpart. In our study, we find that the NART model fails to handle target sentence generation well: it usually generates repetitive and semantically incoherent sentences with missing words, as shown in Table 3. Therefore, strong conditional signals should be introduced as the decoder input to help the model learn better internal dependencies within a sentence.
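To make the independence assumption concrete, the following pure-Python sketch computes the non-autoregressive negative log-likelihood from one score vector per target position, plus a greedy decode that takes the argmax at every position at once. It is an illustration under stated assumptions, not the NART code: the per-position logits are made-up numbers standing in for the decoder's outputs given $z$ and $x$, and all function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary distribution."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def nat_nll(logits_per_position: Sequence[Sequence[float]], target: Sequence[int]) -> float:
    """Hypothetical helper: NLL of `target` when every position t has its own
    distribution P(y_t | z, x) that does not depend on other target tokens."""
    if len(logits_per_position) != len(target):
        raise ValueError("need one distribution per target position (T_y)")
    # positions are scored independently, so on real hardware this sum is parallel
    return -sum(log_softmax(logits)[y_t] for logits, y_t in zip(logits_per_position, target))


def nat_greedy_decode(logits_per_position: Sequence[Sequence[float]]) -> List[int]:
    """Pick the argmax token at every position independently (no left-to-right loop)."""
    return [max(range(len(logits)), key=logits.__getitem__) for logits in logits_per_position]


uniform = [[0.0, 0.0, 0.0, 0.0]] * 3  # T_y = 3 positions, toy vocabulary of 4
print(nat_nll(uniform, [0, 1, 2]))    # 3 * log(4) ≈ 4.1589
print(nat_greedy_decode([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]))  # [1, 0]
```

Because no position reads another position's prediction, nothing in the decoder can prevent two positions from choosing the same word, which is one way to see where the repetitive outputs come from.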

3 Methodology

As discussed in Section 1, to improve the accuracy of NAT models, we need to enhance the inputs of the decoder. In this section we introduce our model, the Enhanced Non-Autoregressive Transformer (ENAT). We design two kinds of enhanced inputs: token-level enhancement based on phrase-table lookup, and embedding-level enhancement based on embedding mapping. The phrase-table lookup and embedding mapping are illustrated in Figure 1.

Previous NAT models take tokens in the source language as decoder inputs, which makes the decoding task difficult. Considering that AT models take (already generated) target tokens as inputs, a straightforward idea to enhance the decoder inputs is to also feed tokens in the target language into the decoder of NAT models. Given a source sentence, a simple way to obtain target tokens is to translate the source tokens into target tokens using a phrase table, which adds negligible latency at inference.

To implement this idea, we pre-train a phrase table on the bilingual training corpus using Moses (?), an open-source statistical machine translation (SMT) toolkit. We then greedily segment the source sentence into $T_p$ phrases and translate the phrases one by one according to the phrase table. The details are as follows. We first calculate the maximum length $L$ of all the phrases contained in the phrase table. For the $i$-th source word $x_i$, we first check whether the phrase $x_{i:i+L-1}$ (of length $L$) has a translation in the phrase table; if not, we check $x_{i:i+L-2}$, and so on. If a translation exists for the phrase $x_{i:i+L-j}$, we translate it and continue from $x_{i+L-j+1}$ with the same strategy. This procedure adds only 0.14 ms of latency per sentence on average over the newstest2014 test set on an Intel Xeon E5-2690 CPU, which is negligible compared with the whole inference latency of the NAT model (e.g., 25 to 200+ ms), as shown in Table 2.

Note that to reduce inference latency, we only search the phrase table to obtain a coarse phrase-to-phrase translation, without using the full procedure (including language model scoring and tree-based search). During inference, we generate $z$ by phrase-table lookup and skip phrases that do not have translations.
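The greedy longest-match lookup described above can be sketched as follows. This is a minimal illustration, assuming the phrase table has already been loaded into a Python dictionary that maps each source phrase to a single best translation (a real Moses table stores several scored candidates per phrase); the names `PhraseTable` and `phrase_table_lookup` and the toy entries are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical in-memory phrase table: source phrase -> its single best translation.
PhraseTable = Dict[Tuple[str, ...], Tuple[str, ...]]


def phrase_table_lookup(source: Sequence[str], table: PhraseTable) -> List[str]:
    """Greedy longest-match segmentation of `source` followed by phrase-by-phrase
    translation; tokens that start no translatable phrase are skipped."""
    max_len = max((len(phrase) for phrase in table), default=0)  # L in the text
    output: List[str] = []
    i = 0
    while i < len(source):
        # try the longest candidate phrase starting at position i first
        for length in range(min(max_len, len(source) - i), 0, -1):
            phrase = tuple(source[i:i + length])
            if phrase in table:
                output.extend(table[phrase])
                i += length  # continue right after the translated phrase
                break
        else:
            i += 1  # no translation starts at position i: skip this token
    return output


table: PhraseTable = {
    ("vielen", "dank"): ("thank", "you"),
    ("dank",): ("thanks",),
    ("an", "alle"): ("to", "all"),
}
z = phrase_table_lookup(["vielen", "dank", "an", "alle", "!"], table)
print(z)  # ['thank', 'you', 'to', 'all']
```

Each position is visited once and tries at most $L$ dictionary lookups, which is consistent with the sub-millisecond latency reported above.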

3.2 Embedding Mapping

As the phrase table is pre-trained from SMT systems, it cannot be updated/optimized during NAT model training, and may lead to poor translation quality if the table is not very accurate. Therefore, we propose the embedding mapping approach, which first linearly maps the source token embeddings to target embeddings and feeds them into the decoder as inputs. This linear mapping can be trained end-to-end together with NAT models.

Since we already have the sentence-level alignment from the training set, we can minimize the $L_2$ distance between the mapped source embeddings and the ground-truth target embeddings at the sentence level:

$$L_{\textrm{align}}(x, y) = \lVert e(x)W - e(y) \rVert_2, \qquad (7)$$

where $e(x)=\frac{1}{T_x}\sum_{i=1}^{T_x}e(x_i)$ is the embedding of the source sentence $x$, computed simply as the average of the embeddings of all source tokens, and $e(y)$ is the embedding of the target sentence $y$, defined in the same way.
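A pure-Python sketch of the linear mapping and the sentence-level alignment loss follows. It assumes row-vector embeddings, a mapping of the form $E_z = E_x W$, and the plain (unsquared) Euclidean norm; the section does not pin down these details, and all names and toy numbers are hypothetical.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]  # row-major: one row per token embedding


def row_times_matrix(v: Sequence[float], W: Matrix) -> Vector:
    """(1 x d) row vector times (d x d') matrix."""
    return [sum(v[k] * W[k][j] for k in range(len(v))) for j in range(len(W[0]))]


def map_embeddings(E_x: Matrix, W: Matrix) -> Matrix:
    """Assumed mapping E_z = E_x W: linearly map every source token embedding."""
    return [row_times_matrix(row, W) for row in E_x]


def sentence_embedding(E: Matrix) -> Vector:
    """Average of the token embeddings, e(x) = (1/T_x) * sum_i e(x_i)."""
    return [sum(column) / len(E) for column in zip(*E)]


def align_loss(E_x: Matrix, E_y: Matrix, W: Matrix) -> float:
    """Euclidean distance between mapped source and target sentence embeddings."""
    mapped = sentence_embedding(map_embeddings(E_x, W))
    target = sentence_embedding(E_y)
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(mapped, target)))


E_x = [[1.0, 0.0], [3.0, 2.0]]  # two source tokens, d = 2 (toy values)
E_y = [[2.0, 1.0]]              # one target token
W = [[2.0, 0.0], [0.0, 2.0]]
print(align_loss(E_x, E_y, W))  # sqrt(5) ≈ 2.2361
```

Because the mapping is linear, averaging the mapped token embeddings gives the same vector as mapping the averaged source embedding, so the loss can be computed either way.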

As the regularization in Equation (7) only ensures a coarse alignment between the sentence embeddings, which are simply the averages of the word embeddings, it misses the fine-grained token-level alignment. Since we do not have a supervision signal for the word-level mapping, we propose word-level adversarial learning. Specifically, we use a Generative Adversarial Network (GAN) (?), which is widely used in NLP tasks such as unsupervised word translation (?) and text generation (?), to regularize the projection matrix $W$. During training, the discriminator $f_D$ takes an embedding as input and outputs a confidence score between 0 and 1 to differentiate the embeddings mapped from source tokens, i.e., $E_z$, from the ground-truth embeddings of the target tokens, i.e., $E_y$. The linear mapping function $f_G$ acts as the generator, whose goal is to produce plausible $E_z$ that are indistinguishable from $E_y$ in the embedding space, so as to fool the discriminator. We implement the discriminator as a two-layer multi-layer perceptron (MLP). Although other architectures such as CNNs could also be used, we find that the simple MLP already achieves fairly good performance.
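A forward pass of such a two-layer MLP discriminator might look like the sketch below. The hidden size, the ReLU activation, the sigmoid output and the Gaussian initialization are all assumptions made for illustration (the text only states that $f_D$ is a two-layer MLP producing a score between 0 and 1), and the class name is hypothetical; training code is omitted.

```python
import math
import random
from typing import List, Sequence


class MLPDiscriminator:
    """Hypothetical two-layer MLP f_D: embedding -> confidence in (0, 1) that the
    embedding is a real target embedding rather than a mapped source embedding.
    ReLU hidden layer, sigmoid output and init scale are assumptions."""

    def __init__(self, dim: int, hidden: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.W1: List[List[float]] = [[rng.gauss(0.0, 0.1) for _ in range(hidden)] for _ in range(dim)]
        self.b1: List[float] = [0.0] * hidden
        self.w2: List[float] = [rng.gauss(0.0, 0.1) for _ in range(hidden)]
        self.b2: float = 0.0

    def __call__(self, e: Sequence[float]) -> float:
        # layer 1: affine map followed by ReLU
        h = [max(0.0, sum(e[k] * self.W1[k][j] for k in range(len(e))) + self.b1[j])
             for j in range(len(self.b1))]
        # layer 2: affine map to a single logit, squashed by a sigmoid
        logit = sum(h_j * w_j for h_j, w_j in zip(h, self.w2)) + self.b2
        return 1.0 / (1.0 + math.exp(-logit))


f_D = MLPDiscriminator(dim=4, hidden=8)
score = f_D([0.5, -1.0, 0.25, 2.0])
print(0.0 < score < 1.0)  # True
```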

Formally, given the linear mapping function $f_G(\cdot;W)$, i.e., the generator, and the discriminator $f_D(\cdot;\theta_D)$, the adversarial training objective $L_{\textrm{adv}}$ can be written as:

$$L_{\textrm{adv}} = \min_{W} \max_{\theta_D} V_{\textrm{word}}(f_G, f_D), \qquad (8)$$

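where $V_{\textrm{word}}$ is the word-level value function, which encourages every word in $z$ and $y$ to be distinguishable:

$$V_{\textrm{word}}(f_G, f_D) = \mathbb{E}_{y_i \sim y}\left[\log f_D(e(y_i))\right] + \mathbb{E}_{x_j \sim x}\left[\log\left(1 - f_D(f_G(e(x_j)))\right)\right], \qquad (9)$$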

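where $e(x_j)$ and $e(y_i)$ denote the embeddings of the $j$-th source token and the $i$-th target token, respectively. In summary, for each training pair $(x,y)$, along with the original negative log-likelihood loss $L_{\textrm{neg}}(x,y)$ defined in Equation (5), the total loss function of our model is:

$$L(x, y; \Theta, \theta_D) = L_{\textrm{neg}}(x, y; \theta_{\textrm{enc}}, \theta_{\textrm{dec}}) + \mu L_{\textrm{align}}(x, y; W) + \lambda L_{\textrm{adv}}(x, y; \theta_D, W), \qquad (10)$$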

where $\Theta=(\theta_{\textrm{enc}},\theta_{\textrm{dec}},W)$ and $\theta_D$ together comprise all parameters to be learned, and $\mu$ and $\lambda$ are hyper-parameters that control the weights of the different losses.
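The word-level value and the weighted total loss can be sketched as below. Several points are assumptions rather than statements from the text: the expectations are estimated by simple averages over the tokens of one sentence pair, the value itself is used as the mapping's adversarial term (many GAN implementations use a non-saturating variant for the generator instead), and $\mu$ is taken to weight $L_{\textrm{align}}$ while $\lambda$ weights $L_{\textrm{adv}}$, with the defaults $\mu=0.1$, $\lambda=1.0$ reported in the training setup. Function names and toy inputs are hypothetical.

```python
import math
from typing import Callable, Sequence

Embedding = Sequence[float]


def word_value(f_D: Callable[[Embedding], float],
               target_embs: Sequence[Embedding],
               mapped_source_embs: Sequence[Embedding]) -> float:
    """Word-level GAN value (assumed form): mean log f_D on real target embeddings
    plus mean log(1 - f_D) on mapped source embeddings. The discriminator
    ascends this value; the mapping W descends it."""
    real = sum(math.log(f_D(e)) for e in target_embs) / len(target_embs)
    fake = sum(math.log(1.0 - f_D(e)) for e in mapped_source_embs) / len(mapped_source_embs)
    return real + fake


def total_loss(l_neg: float, l_align: float, l_adv: float,
               mu: float = 0.1, lam: float = 1.0) -> float:
    """Weighted sum of the three losses (assignment of mu/lambda is assumed)."""
    return l_neg + mu * l_align + lam * l_adv


def undecided(e: Embedding) -> float:
    return 0.5  # a discriminator that cannot tell real from mapped embeddings


v = word_value(undecided, [[1.0, 0.0]], [[0.9, 0.1], [0.2, 0.8]])
print(round(v, 4))  # -1.3863, i.e. 2 * log(0.5)
print(total_loss(2.0, 1.0, v))
```

The value $2\log 0.5$ is what an undecided discriminator yields; it is the equilibrium point that the mapping tries to push the discriminator toward.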

3.3 Discussion

The phrase-table lookup approach is simple and efficient. When the phrase table is good enough, it achieves considerable performance in experiments by providing direct token-level enhancement. However, when the training data is messy and noisy, the generated phrase table might be of low quality and consequently hurt NAT model training. We observe that the phrase table trained by Moses achieves fairly good performance on small and clean datasets such as IWSLT14, but performs very poorly on big and noisy datasets such as WMT14. See Section 5.3 for more details. In contrast, the embedding mapping approach learns to adjust the mapping function together with the training of the NAT model, resulting in more stable results.

As for the two components of embedding mapping, the sentence-level alignment $L_{\textrm{align}}$ leverages bilingual supervision, which can guide the learning of the mapping function well but lacks a fine-grained word-level mapping signal; the word-level adversarial loss $L_{\textrm{adv}}$ can provide complementary information to $L_{\textrm{align}}$. Our ablation study in Section 5.3 (see Table 5) verifies the benefit of combining the two loss functions.

4 Experimental Setup

We evaluate our model on three widely used public machine translation datasets: IWSLT14 De-En (https://wit3.fbk.eu/), WMT14 En-De (https://www.statmt.org/wmt14/translation-task) and WMT16 En-Ro (https://www.statmt.org/wmt16/translation-task), which have 153K, 4.5M and 2.9M bilingual sentence pairs in their training sets, respectively. For the WMT14 tasks, newstest2013 and newstest2014 are used as the validation and test sets, respectively. For the WMT16 En-Ro task, newsdev2016 is the validation set and newstest2016 is the test set. For IWSLT14 De-En, we split 7K sentence pairs from the training set as the validation set and use the concatenation of dev2010, tst2010, tst2011 and tst2012 as the test set, a setup widely used in prior works (?; ?). All the data are tokenized and segmented into subword tokens using byte-pair encoding (BPE) (?), and we share the source and target vocabularies and embeddings within each language pair. The phrase table is extracted from each training set by Moses (?), using the default hyper-parameters of the toolkit.

4.2 Model Configurations

For the decoder, we also use multi-head self-attention and encoder-to-decoder attention, as well as feed-forward networks, as in the Transformer (?). Since the enhanced decoder input follows the word order of the source sentence, we add multi-head positional attention to rearrange the local word order within a sentence, as in NART (?). These three kinds of attention, together with residual connections (?) and layer normalization (?), constitute our model.

To enable a fair comparison, we use the same network architectures as NART (?). Specifically, for the WMT14 and WMT16 datasets, we use the default hyper-parameters of the base model described in ? (?): both the encoder and the decoder have 6 layers, the sizes of the hidden states and embeddings are set to 512, and the number of heads is set to 8. As IWSLT14 is a smaller dataset, we choose a smaller architecture, which consists of a 5-layer encoder and a 5-layer decoder; the sizes of the hidden states and embeddings are set to 256, and the number of heads is set to 4.
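The two architecture sizes can be summarized as a small configuration object. This is only a restatement of the numbers above; the class name `ModelConfig` and its fields are hypothetical, and settings not given in the text (feed-forward size, dropout, etc.) are deliberately left out rather than guessed.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelConfig:
    """Hypothetical container for the Transformer sizes stated in the text."""
    encoder_layers: int
    decoder_layers: int
    hidden_size: int  # also the embedding size
    num_heads: int

    @property
    def head_dim(self) -> int:
        # multi-head attention splits the hidden state evenly across heads
        if self.hidden_size % self.num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        return self.hidden_size // self.num_heads


WMT_BASE = ModelConfig(encoder_layers=6, decoder_layers=6, hidden_size=512, num_heads=8)
IWSLT_SMALL = ModelConfig(encoder_layers=5, decoder_layers=5, hidden_size=256, num_heads=4)
print(WMT_BASE.head_dim, IWSLT_SMALL.head_dim)  # 64 64
```

Both settings keep the per-head dimension at 64; the smaller IWSLT14 model shrinks depth, width and head count together.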

4.3 Training and Inference

We follow the optimizer settings in ? (?). Models on the WMT and IWSLT tasks are trained on 8 and 1 NVIDIA M40 GPUs, respectively. We set $\mu=0.1$ and $\lambda=1.0$ in Equation (10) for all tasks to keep $L_{\textrm{neg}}$, $L_{\textrm{align}}$ and $L_{\textrm{adv}}$ on the same scale. We implement our model in TensorFlow (?). We describe the knowledge distillation and the inference stage in detail below.

Sequence-Level Knowledge Distillation During training, we apply the same knowledge distillation method as used in (?; ?; ?). We first train an autoregressive teacher model with the same architecture as the non-autoregressive student model, and collect the beam-search translation of each source sentence in the training set; these translations are then used as the ground truth for training the student. In this way, we provide less noisy and more deterministic training data, which makes the NAT model easier to learn (?; ?; ?). Specifically, we pre-train the state-of-the-art Transformer (?) architecture as the autoregressive teacher model, with a beam size of 4 during decoding.
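At the data level, sequence-level distillation amounts to swapping each reference for the teacher's output. The sketch below assumes a callable `teacher_translate(src, beam_size)` that runs beam search with a trained AT teacher; both that callable and the toy dictionary teacher standing in for it are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Pair = Tuple[str, str]


def distill_corpus(pairs: Sequence[Pair],
                   teacher_translate: Callable[[str, int], str],
                   beam_size: int = 4) -> List[Pair]:
    """Sequence-level knowledge distillation: keep each source sentence but
    replace its reference with the AT teacher's beam-search translation."""
    return [(src, teacher_translate(src, beam_size)) for src, _reference in pairs]


def toy_teacher(src: str, beam_size: int) -> str:
    """Hypothetical stand-in for a trained Transformer teacher (ignores beam_size)."""
    return {"danke": "thank you", "hallo": "hello"}.get(src, "<unk>")


train = [("danke", "thanks a lot"), ("hallo", "hi")]
print(distill_corpus(train, toy_teacher))  # [('danke', 'thank you'), ('hallo', 'hello')]
```

The student never sees the original references; it is trained only on the teacher's more deterministic outputs.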

As for efficiency, the decoder input $z$ is obtained through table lookup or a multiplication between dense matrices, which adds negligible latency. The teacher model rescoring procedure introduced above is fully parallelizable, as it is identical to the teacher-forcing training process of autoregressive models, and thus does not increase the latency much. We analyze the inference latency per sentence and demonstrate the efficiency of our model in our experiments.
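Teacher rescoring reduces to an argmax over candidates. The sketch assumes the candidate translations are already available (for example from sampling several decoder inputs, in the spirit of NART's noisy parallel decoding) and that `teacher_log_prob(y)` returns the teacher's teacher-forced log-likelihood of candidate `y`; both the callable and the toy scores are hypothetical.

```python
from typing import Callable, List, Sequence


def rescore_with_teacher(candidates: Sequence[List[str]],
                         teacher_log_prob: Callable[[List[str]], float]) -> List[str]:
    """Return the NAT candidate with the highest teacher score.

    teacher_log_prob(y) is assumed to compute log P_AT(y | x) with teacher forcing,
    which needs only one parallel decoder pass per candidate, so all candidates
    can be scored in a single batch; max() loops sequentially here for clarity."""
    if not candidates:
        raise ValueError("need at least one candidate")
    return max(candidates, key=teacher_log_prob)


scores = {("thank", "you", "you"): -4.2, ("thank", "you"): -1.3, ("thanks",): -2.0}
best = rescore_with_teacher([["thank", "you", "you"], ["thank", "you"], ["thanks"]],
                            lambda y: scores[tuple(y)])
print(best)  # ['thank', 'you']
```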

5 Results

We compare our model with non-autoregressive baselines including NART (?); the semi-non-autoregressive Latent Transformer (LT) (?), which incorporates an autoregressive module into NART; and Iterative Refinement NAT (IR-NAT) (?), which trains extra decoders to iteratively refine the translation output, for which we list the “Adaptive” results reported in their paper. We also compare with strong autoregressive baselines based on LSTM (?; ?) and self-attention (?). In addition, we list the translation quality obtained purely by lookup from the phrase table, denoted as Phrase-Table Lookup, which serves as the decoder input in the hard model. For inference latency, we report the average per-sentence decoding latency on the WMT14 En-De task over the newstest2014 test set, measured on a single NVIDIA P100 GPU to be consistent with NART (?). Results are shown in Table 2.

Across the different datasets, our model achieves state-of-the-art performance among non-autoregressive models, outperforming all non-autoregressive baselines. Specifically, our model outperforms NART with rescoring 10 candidates by 4.26 to 5.62 BLEU points on the different tasks. Compared with autoregressive models, our model is only 1.1 BLEU points behind its Transformer teacher on the En-Ro task, and it also outperforms the state-of-the-art LSTM-based baseline (?) on the IWSLT14 De-En task. These promising results demonstrate that the proposed method makes the decoder easier to learn by providing a strong input close to the target tokens, resulting in a better model. For inference latency, NART needs to first predict the fertilities of the source sentence before the translation process, which is slower than the phrase-table lookup procedure and the matrix multiplication in our method. Moreover, our method outperforms NART with rescoring 100 candidates on all tasks while being nearly 5 times faster, which also demonstrates the advantages of the enhanced decoder input.

Translation Quality w.r.t. Different Lengths We compare the translation quality of AT (?), NART (?) and our method with regard to different sentence lengths. We conduct the analysis on the WMT14 En-De test set and divide the sentence pairs into length buckets according to the length of the reference sentence. The results are shown in Figure 2. As the sentence length increases, the accuracy of the NART model drops quickly, and the gap between the AT and NART models also widens. Our method achieves larger improvements on longer sentences, which shows that NART performs worse on long sentences due to its weak decoder input, while our enhanced decoder input provides strong conditional information for the decoder, resulting in larger accuracy improvements on these sentences.

5.2 Case Study

We conduct several case studies on IWSLT14 De-En task to intuitively demonstrate the superiority of our model, listed in Table 3.

As claimed in Section 1, the NART model tends to repetitively translate the same words or phrases, sometimes misses meaningful words, and performs poorly when translating long sentences. In the first case, NART fails to translate a long sentence due to the weak signal provided by the decoder input, while both of our models successfully translate the second half of the sentence thanks to the strong information carried in our decoder input. In the second case, NART translates “to you” twice and misses “all of”, which results in a wrong translation, while our model again achieves better translation results.

5.3 Method Analysis

Phrase-Table Lookup vs. Embedding Mapping We have proposed two different approaches to provide decoder input of enhanced quality, and we compare the two approaches in this subsection.

According to Table 2, the phrase-table lookup achieves better BLEU scores on the IWSLT14 De-En and WMT14 De-En tasks, and the embedding mapping performs better on the other two tasks. We find that the performance of the first approach is related to the quality of the phrase table, which can be judged by the BLEU score of the phrase-to-phrase translation. As IWSLT14 De-En is a cleaner and smaller dataset, the pre-trained phrase table tends to be of good quality (with a BLEU score of 15.69, as shown in Table 2), and is therefore able to provide a sufficiently accurate signal to the decoder. Although the WMT14 En-De and WMT16 En-Ro datasets are much larger, their phrase tables are of low quality (with BLEU scores of 6.03 on WMT14 En-De and 9.16 on WMT16 En-Ro), which may provide noisy signals, such as missing too many tokens, and misguide the learning procedure. Therefore, our embedding mapping outperforms the phrase-table lookup by providing implicit guidance and allowing the model to adjust the decoder input through end-to-end learning.

Varying the Quality of Decoder Input We study how the quality of the decoder input influences the performance of the NAT model. We mainly analyze the phrase-table lookup approach, as the quality of its decoder input is easy to vary by using a word table instead. After obtaining the phrase table from the training data with Moses, we further extract a word table from the phrase table following the word alignments. We can then use word-table lookup with the extracted word table as the decoder input $z$, which provides relatively weaker signals than phrase-table lookup. We measure the BLEU scores of the phrase/word-table lookup against the reference, as well as those of the NAT model outputs against the reference, on the WMT14 En-De test set, listed in Table 4. The quality of the word-table lookup is relatively poor compared with the phrase-table lookup. Under this circumstance, the signal provided by the decoder input is weaker, which in turn affects the accuracy of the NAT model.
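The extraction rule is not detailed here; a minimal sketch, assuming the common Moses text format in which the fourth `|||`-separated field of each phrase-table line holds word alignments such as `0-0 1-1`, could count every alignment link and keep, for each source word, its most frequently aligned target word. The counting rule, the function name and the toy lines (including their score fields) are assumptions for illustration.

```python
from collections import Counter, defaultdict
from typing import DefaultDict, Dict, Iterable


def extract_word_table(phrase_table_lines: Iterable[str]) -> Dict[str, str]:
    """Hypothetical word-table extraction from phrase pairs and their word alignments.

    Each line is assumed to look like
    'src phrase ||| tgt phrase ||| scores ||| i-j i-j ... ||| ...'.
    Every link i-j counts one co-occurrence of source word i and target word j;
    each source word keeps its most frequently aligned target word."""
    counts: DefaultDict[str, Counter] = defaultdict(Counter)
    for line in phrase_table_lines:
        fields = [field.strip() for field in line.split("|||")]
        if len(fields) < 4 or not fields[3]:
            continue  # no alignment information on this line
        src, tgt = fields[0].split(), fields[1].split()
        for link in fields[3].split():
            i, j = (int(k) for k in link.split("-"))
            counts[src[i]][tgt[j]] += 1
    return {word: c.most_common(1)[0][0] for word, c in counts.items()}


lines = [
    "vielen dank ||| many thanks ||| 0.5 0.4 0.6 0.3 ||| 0-0 1-1 ||| 10 8 6",
    "dank ||| thanks ||| 0.7 0.6 0.8 0.5 ||| 0-0 ||| 7 7 7",
]
print(extract_word_table(lines))  # {'vielen': 'many', 'dank': 'thanks'}
```

Looking up such a table word by word discards the multi-word context a phrase entry carries, which is one reason the word-table decoder input is the weaker signal.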

Ablation Study on Embedding Mapping We conduct an ablation study to examine the different components of the embedding mapping approach, i.e., the sentence-level alignment and the word-level adversarial learning. Results are shown in Table 5. The sentence-level alignment $L_{\textrm{align}}$ slightly outperforms the word-level adversarial learning $L_{\textrm{adv}}$. However, adding $L_{\textrm{adv}}$ to $L_{\textrm{align}}$ improves the BLEU score to 24.13, which illustrates that the complementary information provided by the two loss functions is indispensable.

6 Conclusion

We aimed to improve the accuracy of non-autoregressive translation models and proposed two methods to enhance the decoder inputs of NAT models: one based on a phrase table and the other based on word embeddings. Our methods outperform the baseline on all tasks by 3.47 to 5.02 BLEU points.

In the future, we will extend this study from several aspects. First, we will test our methods on more language pairs and larger scale datasets. Second, we will explore better methods to utilize the phrase table. For example, we may sample multiple candidate target tokens (instead of using the one with largest probability in this work) for each source token and feed all the candidates into the decoder. Third, it is interesting to investigate better methods (beyond the phrase table and word embedding based methods in this work) to enhance the decoder inputs and further improve translation accuracy for NAT models.

Acknowledgements

This research was supported by the National Natural Science Foundation of China (No. 61673364, No. 91746301) and the Fundamental Research Funds for the Central Universities (WK2150110008).

References