Learning Kernel-Smoothed Machine Translation with Retrieved Examples
Qingnan Jiang, Mingxuan Wang, Jun Cao, Shanbo Cheng, Shujian Huang, Lei Li
Introduction
Over the past years, end-to-end Neural Machine Translation (NMT) models have achieved great success Bahdanau et al. (2015); Wu et al. (2016); Vaswani et al. (2017). But how can a deployed NMT model be effectively updated and adapted to emerging cases? For example, after a generic NMT model has been trained on WMT data, a customer wants to use the service to translate financial documents. The customer may have a handful of translation pairs for the finance domain, but does not have the capacity to perform a full retraining.
Non-parametric adaptation methods enable incorporating individual examples on-the-fly, by retrieving similar source-target pairs from an external database to guide the translation process Bapna and Firat (2019); Gu et al. (2018); Zhang et al. (2018); Cao and Xiong (2018). The external database can be easily updated online. Most of these methods rely on effective sentence-level retrieval. Different from sentence retrieval, $k$-nearest-neighbour machine translation ($k$NN-MT) introduces token level retrieval to improve translation Khandelwal et al. (2021). It shows promising results for online domain adaptation.
Existing non-parametric methods for online adaptation still have limitations. First, since it is not easy for sentence-level retrieval to find examples that are similar enough to the test example, the low overlap between test examples and retrieved examples brings noise to translation Bapna and Firat (2019). Second, completely non-parametric methods are prone to overfit the retrieved examples. For example, although $k$NN-MT improves domain-specific translation, it overfits severely and cannot generalize to the general domain, as is shown in Figure 1. An ideal online adaptation method should introduce less noise to the translation process and generalize to varying test examples with an incrementally changing database.
In this paper, we propose to learn Kernel-Smoothed Translation with Example Retrieval (KSTER), to effectively learn and adapt neural machine translation online. KSTER retains the online adaptation advantage of non-parametric methods and avoids the drawback of easy overfitting. More specifically, KSTER improves the generalization ability of non-parametric methods in three aspects. First, we introduce a learnable kernel to dynamically measure the relevance of the retrieved examples based on the current context. Then, the example-based distribution is combined with the model-based distribution computed by NMT, using an adaptive mixing weight, for next word prediction. Further, to make the learning of KSTER stable, we introduce a retrieval dropout strategy. The intuition is that similar examples can constantly be retrieved during training, but this is not the case during inference. We therefore drop the most similar examples during training to reduce this discrepancy.
With the above improvements, KSTER shows the following advantages:
Extensive experiments show that KSTER outperforms $k$NN-MT, a strong competitor, by 1.1 to 1.5 BLEU in specific domains while keeping the performance in the general domain.
KSTER outperforms $k$NN-MT by 1.8 BLEU on average in unseen domains. Therefore, there is no strong restriction on the input domain, which makes KSTER much more practical for industry applications.
Related Work
This work is mostly related to two research areas in machine translation (MT), i.e., domain adaptation for machine translation and online adaptation of MT models by non-parametric retrieval.
Domain adaptation for MT aims to adapt general domain MT models for domain-specific language translation Chu and Wang (2018). The most popular method for this task is finetuning general domain MT models on in-domain data. However, finetuning suffers from the notorious catastrophic forgetting problem McCloskey and Cohen (1989); Santoro et al. (2016). There are also some sparse domain adaptation methods that only update part of the MT parameters Bapna et al. (2019); Wuebker et al. (2018); Liang et al. (2020); Guo et al. (2021); Lin et al. (2021); Zhu et al. (2021).
In real-world translation applications, the domain labels of test examples are often not available. This dilemma inspires a closely related research area — multi-domain machine translation Pham et al. (2021); Farajian et al. (2017); Liang et al. (2020), where one model translates sentences from all domains.
Online Adaptation of MT by Non-parametric Retrieval
Non-parametric methods enable online adaptation of deployed NMT models by updating the database from which similar examples are retrieved.
Traditional non-parametric methods search sentence-level examples to guide the translation process Cao and Xiong (2018); Gu et al. (2018); Zhang et al. (2018). Recently, n-gram level Bapna and Firat (2019) and token level Khandelwal et al. (2021) retrieval have been introduced and show strong empirical results. Generally, similar examples are retrieved based on fuzzy matching Bulte and Tezcan (2019); Xu et al. (2020), embedding similarity, or a mixture of the two approaches Bapna and Firat (2019). There are also works that utilize off-the-shelf search engines to retrieve similar translation examples He et al. (2021); Xia et al. (2019).
Methodology
In this section, we first formulate the kernel-smoothed machine translation (KSTER), which smooths neural machine translation (NMT) output with retrieved token level examples. Then we introduce the modeling and training of the learnable kernel and adaptive mixing weight. The overview of KSTER is shown in Figure 2.
The state-of-the-art NMT models are based on the encoder-decoder architecture. The encoder encodes the source text $\mathbf{x}$ into a sequence of hidden states. The decoder takes the representations of the source text as input and generates target text auto-regressively. In each decoding step $i$, the decoder predicts the probability distribution $p_m(y_i \mid \mathbf{x}, \mathbf{y}_{<i}; \theta)$ of the next token $y_i$ conditioned on the source text $\mathbf{x}$ and previously generated target tokens $\mathbf{y}_{<i}$. The NMT model parameters are denoted as $\theta$.
Kernel Smoothing with Retrieved Examples
The model-based distribution $p_m(y_i \mid \mathbf{x}, \mathbf{y}_{<i}; \theta)$ is then smoothed by an example-based distribution $p_e(y_i \mid \mathbf{x}, \mathbf{y}_{<i})$, which is computed using kernel density estimation (KDE) on retrieved examples.
We build a database from which similar examples are retrieved. The database consists of all token level examples from the training set in the form of key-value pairs. In each key-value pair $(\mathbf{k}, v)$, the key $\mathbf{k}$ is the intermediate representation of a certain layer in the NMT decoder. The value $v$ is the corresponding ground truth target token $y_i$. The key $\mathbf{k}$ can be seen as a vector representation of the context of value $v$. We obtain the key-value pairs from each training pair $(\mathbf{x}, \mathbf{y})$ by running force-decoding with a trained NMT model.
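In each decoding step, we compute the query $\mathbf{q}_i$ and retrieve the $k$ most similar examples based on the $L_2$ distance between the query and the keys. (For two $d$-dimensional vectors $\mathbf{a}$ and $\mathbf{b}$, we compute the $L_2$ distance between $\mathbf{a}$ and $\mathbf{b}$ as $\lVert \mathbf{a} - \mathbf{b} \rVert_2 = \sqrt{\sum_{t=1}^{d} (a_t - b_t)^2}$.) Each retrieved example forms a triple $(\mathbf{k}_j, v_j, d_j)$, where $\mathbf{k}_j$ is the key, $v_j$ is the corresponding value token, and $d_j$ is the distance between the query $\mathbf{q}_i$ and the key $\mathbf{k}_j$. The example-based distribution is then computed with these retrieved examples using the following equation:

$$p_e(y_i \mid \mathbf{x}, \mathbf{y}_{<i}) = \frac{\sum_{j=1}^{k} \mathbb{1}[v_j = y_i]\, K(\mathbf{q}_i, \mathbf{k}_j; \sigma)}{\sum_{j=1}^{k} K(\mathbf{q}_i, \mathbf{k}_j; \sigma)}$$

where $K(\cdot, \cdot; \sigma)$ is the kernel function and $\sigma$ is the parameter of the kernel.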
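Finally, the NMT output is smoothed by combining the model-based distribution and the example-based distribution with a mixing weight $\lambda$:

$$p(y_i \mid \mathbf{x}, \mathbf{y}_{<i}) = \lambda\, p_e(y_i \mid \mathbf{x}, \mathbf{y}_{<i}) + (1 - \lambda)\, p_m(y_i \mid \mathbf{x}, \mathbf{y}_{<i}; \theta)$$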
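The smoothing step can be illustrated with a minimal NumPy sketch. It assumes the kernel values for the retrieved examples have already been computed; the function names `example_based_distribution` and `smooth` are hypothetical and not taken from the authors' implementation.

```python
import numpy as np

def example_based_distribution(kernel_values, value_tokens, vocab_size):
    """Kernel density estimate over the vocabulary from retrieved examples.

    kernel_values[j] is assumed to hold K(q_i, k_j; sigma) for the j-th
    retrieved example, and value_tokens[j] its target token id.
    """
    weights = np.asarray(kernel_values, dtype=np.float64)
    weights = weights / weights.sum()            # normalise over retrieved examples
    p_e = np.zeros(vocab_size)
    # Accumulate the weight of every example that carries the same token.
    np.add.at(p_e, np.asarray(value_tokens), weights)
    return p_e

def smooth(p_m, p_e, lam):
    """Interpolate model-based and example-based distributions."""
    return lam * np.asarray(p_e) + (1.0 - lam) * np.asarray(p_m)
```

For instance, with kernel values `[2, 1, 1]` attached to tokens `[3, 3, 0]`, token 3 receives three quarters of the example-based probability mass, and tokens that were never retrieved receive mass only from the model-based distribution.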
Learnable Kernel Function
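Although any kernel function can be used in KDE, we choose two specific kernels in this study, the Gaussian kernel and the Laplacian kernel, since they are easy to parameterize:

$$K_g(\mathbf{q}_i, \mathbf{k}_j; \sigma) = \exp\left(-\frac{\lVert \mathbf{q}_i - \mathbf{k}_j \rVert^2}{\sigma}\right), \qquad K_l(\mathbf{q}_i, \mathbf{k}_j; \sigma) = \exp\left(-\frac{\lVert \mathbf{q}_i - \mathbf{k}_j \rVert}{\sigma}\right)$$

The only parameter of the Gaussian kernel is the bandwidth $\sigma$, which controls the smoothness of the example-based distribution, as shown in Figure 3.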
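In a learnable Gaussian kernel, the bandwidth is not a fixed hyper-parameter. Instead, it is estimated in each decoding step by a learned affine network with exponential activation:

$$\sigma = \exp\left(\mathbf{W}_1 [\mathbf{q}_i; \bar{\mathbf{k}}_i] + b_1\right)$$

where $\bar{\mathbf{k}}_i = \frac{1}{k} \sum_{j=1}^{k} \mathbf{k}_j$ is the average-pooled keys and $\{\mathbf{W}_1, b_1\}$ are trainable parameters.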
The bandwidth of learnable Laplacian kernel is modeled in the same way as the bandwidth of learnable Gaussian kernel.
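A small sketch of the two kernels and of the bandwidth estimator may help make this concrete. It assumes the kernels take the forms $\exp(-d^2/\sigma)$ and $\exp(-d/\sigma)$ on the $L_2$ distance $d$, and that the affine network sees the concatenation of the query and the average-pooled keys; both are assumptions of this illustration, and all names are hypothetical.

```python
import numpy as np

def gaussian_kernel(dist, sigma):
    """Gaussian kernel on L2 distances (assumed form exp(-d^2 / sigma))."""
    return np.exp(-np.square(dist) / sigma)

def laplacian_kernel(dist, sigma):
    """Laplacian kernel on L2 distances (assumed form exp(-d / sigma))."""
    return np.exp(-np.asarray(dist) / sigma)

def learnable_bandwidth(query, keys, W, b):
    """Affine map with exponential activation, so the bandwidth stays positive.

    query: (d,) vector; keys: (k, d) retrieved keys;
    W: (2d,) weight vector and b: scalar bias (hypothetical parameter shapes).
    """
    k_bar = keys.mean(axis=0)                    # average-pooled keys
    features = np.concatenate([query, k_bar])
    return float(np.exp(W @ features + b))
```

A smaller bandwidth concentrates the kernel weights on the nearest examples, while a larger one spreads them more evenly, which is the smoothness effect described above.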
Adaptive Mixing of Base Prediction and Retrieved Examples
To mix the model-based distribution and example-based distribution adaptively, we model the mixing weight with a learnable neural network.
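$$\lambda = \mathrm{sigmoid}\left(\mathbf{W}_3\, \mathrm{ReLU}\left(\mathbf{W}_2 [\mathbf{q}_i; \tilde{\mathbf{k}}_i] + \mathbf{b}_2\right) + b_3\right)$$

The mixing weight $\lambda$ is computed by a multi-layer perceptron with the query $\mathbf{q}_i$ and the weighted sum of keys $\tilde{\mathbf{k}}_i$ as inputs, where $\{\mathbf{W}_2, \mathbf{b}_2, \mathbf{W}_3, b_3\}$ are trainable parameters.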
In this way, $k$NN-MT Khandelwal et al. (2021) can be seen as a special case of KSTER, with a fixed Gaussian kernel and a fixed mixing weight.
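The adaptive mixing weight can be sketched as a one-hidden-layer perceptron. The ReLU hidden activation, the sigmoid output, and the use of normalised kernel values as the weights of the key sum are assumptions made for this illustration; the parameter names and shapes are hypothetical.

```python
import numpy as np

def mixing_weight(query, keys, kernel_values, W2, b2, W3, b3):
    """Predict a mixing weight in (0, 1) from the query and retrieved keys.

    query: (d,); keys: (k, d); kernel_values: (k,) non-negative weights;
    W2: (h, 2d), b2: (h,), W3: (h,), b3: scalar (hypothetical shapes).
    """
    w = np.asarray(kernel_values, dtype=np.float64)
    w = w / w.sum()
    k_tilde = w @ keys                            # weighted sum of keys, shape (d,)
    hidden = np.maximum(0.0, W2 @ np.concatenate([query, k_tilde]) + b2)
    logit = W3 @ hidden + b3
    return float(1.0 / (1.0 + np.exp(-logit)))    # sigmoid keeps the weight in (0, 1)
```

Replacing this network with a constant, and the learned bandwidth with a fixed one, gives the fixed-kernel, fixed-weight setting that the text identifies with $k$NN-MT.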
Training
We optimize the KSTER model by minimizing the cross entropy loss between the mixed distribution and ground truth target tokens:

$$\mathcal{L} = -\sum_{i=1}^{n} \log p(y_i \mid \mathbf{x}, \mathbf{y}_{<i})$$

where $n$ is the length of a target sentence $\mathbf{y}$.
We keep the NMT model parameters fixed. Only parameters of learnable kernel and mixing weight are trained.
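As a rough illustration of the objective, the following sketch evaluates the cross entropy of the mixed distribution for one sentence. It only computes the loss value; in an actual implementation an automatic differentiation framework would backpropagate it into the kernel and mixing-weight parameters while the NMT parameters stay frozen. The function name and the small epsilon for numerical safety are assumptions of this sketch.

```python
import numpy as np

def kster_loss(p_m, p_e, lam, targets, eps=1e-12):
    """Summed negative log-likelihood of the mixed distribution.

    p_m, p_e: (n, V) model-based and example-based distributions per step;
    lam: (n,) mixing weight per step; targets: (n,) gold token ids.
    """
    lam = np.asarray(lam)[:, None]
    mixed = lam * np.asarray(p_e) + (1.0 - lam) * np.asarray(p_m)
    gold = mixed[np.arange(len(targets)), targets]   # probability of each gold token
    return float(-np.log(gold + eps).sum())
```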
Retrieval Dropout
Since the database is built from the training data and KSTER is trained on the training data, similar examples can constantly be retrieved from the database during training. However, at test time, there may be no example in the database that is similar to the query. This discrepancy between training and inference may lead to overfitting.
We propose a simple training strategy called retrieval dropout to alleviate this problem. During training, we search the top $k+1$ similar examples instead of the top $k$ examples. Then we drop the most similar example and use the remaining $k$ examples for training. Retrieval dropout is not used at test time.
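Retrieval dropout can be sketched in a few lines. The example below uses brute-force $L_2$ search in NumPy as a stand-in for an approximate nearest-neighbour index; the function name `retrieve` is hypothetical.

```python
import numpy as np

def retrieve(query, keys, k, training=False):
    """Return indices and L2 distances of the retrieved examples.

    During training, k + 1 neighbours are searched and the closest one is
    dropped; at test time the plain top-k neighbours are returned.
    """
    dists = np.linalg.norm(keys - query, axis=1)
    n_search = k + 1 if training else k
    order = np.argsort(dists)[:n_search]
    if training:
        order = order[1:]                # retrieval dropout: discard the nearest example
    return order, dists[order]
```

Because the database is built from the training set, the nearest neighbour of a training query is typically the query's own entry at distance zero, so dropping it forces the model to learn from examples that are merely similar, as at test time.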
Experiments
We evaluate the proposed methods on two machine translation tasks: domain adaptation for machine translation (DAMT) and multi-domain machine translation (MDMT). In DAMT, an in-domain translation model is built for each specific domain, since the domain labels of examples are available at test time. In MDMT, the domain labels of examples are not available at test time, so examples from all domains are translated with one model, which is a more practical setting.
We conduct experiments in EN-DE translation and DE-EN translation. We use the WMT14 EN-DE dataset Bojar et al. (2014) as general domain training data, which consists of 4.5M sentence pairs. newstest2013 and newstest2014 are used as the general domain development set and test set, respectively. Five domain-specific datasets proposed by Koehn and Knowles (2017) and re-split by Aharoni and Goldberg (2020) (https://github.com/roeeaharoni/unsupervised-domain-clusters) are used to evaluate the domain-specific translation performance. The detailed statistics of the 5 datasets are shown in Table 1.
Implementation Details
We use joint Byte Pair Encoding (BPE) Sennrich et al. (2016) with 30k merge operations for subword segmentation. The resulting vocabulary is shared between source and target languages. We employ Transformer Base Vaswani et al. (2017) as the base model. Following Khandelwal et al. (2021), the normalized inputs of the feed forward network in the last Transformer decoder block are used as keys to build the databases and as queries for retrieval. The translation performance is evaluated with detokenized BLEU scores Papineni et al. (2002), computed by SacreBLEU Post (2018) (https://github.com/mjpost/sacrebleu).
We build a FAISS Johnson et al. (2017) index for nearest neighbour search. We employ inverted file and product quantization to accelerate retrieval in large scale databases. The keys of examples are stored in the fp16 format to reduce the memory demand. We choose $k$ to keep a balance between translation quality and inference speed.
We train the base model for 200k steps. The best 5 checkpoints are averaged to obtain the final model. We train KSTER for 30k steps. For the training procedures of all models, each batch contains 32,768 tokens approximately. The models are optimized by Adam optimizer Kingma and Ba (2015) with learning rates set to 0.0002.
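The checkpoint averaging step mentioned above can be sketched as an element-wise mean over parameter tensors. The sketch assumes each checkpoint is a mapping from parameter names to NumPy arrays of matching shapes; the function name is hypothetical and this is not the JoeyNMT routine itself.

```python
import numpy as np

def average_checkpoints(checkpoints):
    """Average parameters element-wise across a list of checkpoints."""
    names = checkpoints[0].keys()
    return {
        name: np.mean([ckpt[name] for ckpt in checkpoints], axis=0)
        for name in names
    }
```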
KSTER introduces 526k trainable parameters, which is 0.85% of the base model. We implement all the models based on JoeyNMT Kreutzer et al. (2019) (https://github.com/joeynmt/joeynmt).
Domain Adaptation for Machine Translation
We build an individual database for each specific domain with in-domain training data in DAMT. The sizes of the databases are shown in Table 2.
We compare the proposed method with the following baselines.
Base: the base model trained on general-domain data.
Finetuning: the base model trained on the general domain dataset and then finetuned with in-domain data for each specific domain individually.
$k$NN-MT: $k$NN-MT with an individual in-domain database for each domain, where the hyper-parameters are tuned on the development set of each domain.
The KSTER model is trained for each specific domain individually for fair comparison.
Main results
The DAMT experiment results are shown in Table 3. For domain-specific translation, KSTER outperforms $k$NN-MT by 1.2 and 1.4 BLEU on average in EN-DE and DE-EN translation respectively. Finetuning achieves the best domain-specific performance on average. However, the performance of finetuned models on the general domain drops significantly due to the catastrophic forgetting problem. The even worse general domain performance of $k$NN-MT indicates that it overfits to the retrieved examples severely. KSTER performs far better than finetuning and $k$NN-MT in the general domain, which shows strong generalization ability. We notice that KSTER with the Laplacian kernel performs slightly better than with the Gaussian kernel, since KSTER with the Gaussian kernel tends to ignore the long-tailed retrieved examples.
Robustness test
The performance of an MT model with non-parametric retrieval is influenced by the size and quality of the database. Khandelwal et al. (2021) have studied how the translation performance of $k$NN-MT changes with the size of the database. In this work, we study the performance change of $k$NN-MT and KSTER with a low-quality database. Specifically, we test the robustness of these models in DAMT when the database is noisy.
We add token-level noise to the English sentences in the parallel training data with EDA Wei and Zou (2019). (We do not experiment with adding noise to the German side, since a German WordNet, which is necessary for synonym replacement, is not available to us.) Each word in a sentence is modified with a fixed probability. The candidate modifications are synonym replacement, random insertion, random swap and random deletion, chosen with equal probability. Then we use the noisy training data to construct the noisy database.
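The noising procedure can be approximated with the following simplified sketch. It is not the EDA reference implementation: a toy synonym dictionary stands in for WordNet, the per-word modification probability `p` is a free parameter, and the function name is hypothetical.

```python
import random

OPS = ("synonym", "insert", "swap", "delete")

def add_token_noise(tokens, p, synonyms, seed=0):
    """Modify each token with probability p using one of four operations."""
    rng = random.Random(seed)
    out = []
    for tok in tokens:
        if rng.random() >= p:
            out.append(tok)                          # token left unchanged
            continue
        op = rng.choice(OPS)                         # operations are equally likely
        if op == "synonym":
            out.append(rng.choice(synonyms.get(tok, [tok])))
        elif op == "insert":
            out.append(tok)
            other = rng.choice(tokens)               # insert a synonym of a random word
            out.append(rng.choice(synonyms.get(other, [other])))
        elif op == "swap":
            out.append(tok)
            if len(out) > 1:
                j = rng.randrange(len(out) - 1)      # swap with an earlier position
                out[j], out[-1] = out[-1], out[j]
        # "delete": the token is simply not copied to the output
    return out
```

Applying such a function to the source side or the target side of the training pairs before building the database would give the two noise settings compared in this robustness test.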
We study the effects of source side noise and target side noise on translation performance. The experiment results are presented in Table 4. Target side noise has a more negative effect on translation performance than source side noise. The BLEU scores of KSTER drop less in all settings, which indicates that the proposed method is more robust to a low-quality database.
Multi-Domain Machine Translation
In MDMT, since there is no domain label available at test time, examples from all domains are translated with one model. We build a mixed database with training data of the general domain and the 5 specific domains, which is used in all MDMT experiments. The mixed databases for EN-DE translation and DE-EN translation contain 172M and 167M key-value pairs respectively.
We compare the proposed method with the following baselines.
Base: the base model trained on the general domain dataset.
Joint-training: the base model trained on the mixture of the general domain dataset and the 5 specific domain datasets.
$k$NN-MT: $k$NN-MT with the mixed database. We select the hyper-parameters that achieve the highest development set BLEU averaged over the general domain and the 5 specific domains.
We sample 500k training examples from general domain training set, which are then mixed with all 5 specific domain training examples for KSTER training.
Main results
The experiment results of MDMT are shown in Table 5. For general domain sentence translation, KSTER outperforms $k$NN-MT by 3 and 6 BLEU in the EN-DE and DE-EN directions respectively. For domain-specific translation, KSTER outperforms $k$NN-MT by 1.5 and 1.1 BLEU in the EN-DE and DE-EN directions. Besides, KSTER also significantly outperforms joint-training in both general domain performance and averaged domain-specific performance. The proposed method thus has advantages over joint-training in both online adaptation and translation performance.
General-specific domain performance trade-off
We plot the general domain performance and averaged domain-specific performance of $k$NN-MT with different hyper-parameter selections in Figure 4. We find that $k$NN-MT performs better in domain-specific translation when the system prediction relies more on the searched examples (low bandwidth and high mixing weight). In contrast, better general domain translation performance is achieved when the system prediction relies more on the NMT prediction (high bandwidth and low mixing weight).
There is a trade-off between general and specific domain performance in $k$NN-MT. Applying an identical kernel and mixing weight to all test examples cannot achieve the best performance in the general domain and specific domains simultaneously.
KSTER with the Gaussian kernel, which is a generalization of $k$NN-MT, achieves better performance in both general domain and domain-specific translation since it applies an adaptive kernel and mixing weight for different test examples. The distributions in Figure 5 indicate that KSTER learns different kernels and different weights for different examples.
Generalization ability over unseen domains
We test the generalization ability of the baselines and KSTER with the Laplacian kernel in unseen domains, which is important in real-world MDMT applications. We take Bible and QED from OPUS Tiedemann (2012) (https://opus.nlpl.eu/) as unseen domains and randomly sample 2k examples from each domain for testing. We directly use the MDMT models to translate sentences from unseen domains. The results of EN-DE translation are presented in Table 6. KSTER outperforms all baselines, which shows strong generalization ability.
Inference Speed
A common concern about non-parametric methods in MT is that searching for similar examples may slow down inference. We test the inference speed of KSTER in MDMT in EN-DE translation, which is the setting with the largest database. The averaged inference time of $k$NN-MT on the general domain and 5 specific domain test sets is 1.15 times that of the base model. The averaged inference time of KSTER is 1.19 times that of the base model, which is only slightly slower than the baselines.
Analysis
In this section, we first conduct ablation studies to verify the effectiveness of each part of the proposed method. Then we conduct detailed analysis on how kernel-smoothing with retrieved examples helps translation.
In KSTER, both the kernel and mixing weight are learnable. We study the effect of keeping only one of the two parts learnable in MDMT.
We take KSTER with Gaussian kernel for analysis. The ablation experiment results are presented in Table 7.
Both learnable kernel and learnable mixing weight bring improvement in both general domain and domain-specific translation. Keeping the two parts learnable simultaneously brings additional improvement. Overall, learnable mixing weight is more important than learnable kernel function.
KSTER outperforms $k$NN-MT with all $k$ selections
We conduct an ablation study over different $k$ selections in both DAMT and MDMT settings in EN-DE translation. We experiment with four $k$ selections, {4, 8, 16, 32}, and plot the results in Figure 6. In DAMT, KSTER achieves its best performance at a single value of $k$ (see Figure 6), whereas in MDMT the performance of our method increases with $k$. With all $k$ selections, KSTER outperforms $k$NN-MT consistently.
Retrieval dropout improves generalization
We study the effect of retrieval dropout in MDMT and select the KSTER with Laplacian kernel for analysis.
We plot the general domain and averaged domain-specific translation performance of models trained with or without retrieval dropout in Figure 7. Without retrieval dropout, the performance of both general domain and domain-specific translation drops dramatically. The discrepancy between training and inference leads to severe overfitting. This problem is alleviated by the proposed retrieval dropout, which shows that this training strategy improves the generalization ability of KSTER.
Fine-grained Effects of Kernel-smoothing with Retrieved Examples on Translation
To better understand the effects of kernel-smoothing with retrieved examples on translation, we study the following two research questions.
RQ1. Which types of words does kernel-smoothing influence most?
RQ2. Does kernel-smoothing help word sense disambiguation?
To study the first research question, we categorize the predicted words with their Part-of-Speech tags (POS tags). In each decoding step, if the predicted word obtains the highest probability of example-based distribution but it does not obtain the highest probability of model-based distribution, it is recognized as a prediction determined by kernel-smoothing with retrieved examples.
We compute the ratio of predictions determined by kernel-smoothing across different POS tags. We conduct this analysis on DAMT task in EN-DE direction and select Medical and Subtitles domains as representatives.
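The criterion above can be written down as a short sketch. It assumes each decoding step is available as a record `(pos_tag, p_m, p_e, predicted_id)`; this record layout and the function names are hypothetical.

```python
from collections import defaultdict
import numpy as np

def determined_by_retrieval(p_m, p_e, predicted):
    """True if the prediction tops the example-based but not the model-based distribution."""
    return int(np.argmax(p_e)) == predicted and int(np.argmax(p_m)) != predicted

def ratio_by_pos(records):
    """Ratio of retrieval-determined predictions for each POS tag."""
    hits, totals = defaultdict(int), defaultdict(int)
    for pos, p_m, p_e, predicted in records:
        totals[pos] += 1
        hits[pos] += int(determined_by_retrieval(p_m, p_e, predicted))
    return {pos: hits[pos] / totals[pos] for pos in totals}
```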
We report the results in Figure 8. Medical and Subtitles represent two opposite cases where non-parametric retrieval contributes more in the former and contributes less in the latter. We find that across the 2 different domains, kernel-smoothing contributes most to the predictions of verbs, adverbs and nouns, which are morphologically complex word types. Retrieving words in similar contexts may help in selecting the correct form of morphologically complex words.
Kernel-smoothing helps word sense disambiguation
In kernel-smoothing, we search examples with similar keys, i.e., contextualized hidden states. We hypothesize that the retrieved examples contain useful context information which helps word sense disambiguation. We test this hypothesis with contrastive translation pairs.
A contrastive translation pair contains a source, a reference and one or more contrastive translations. Contrastive translations are constructed by replacing a word in the reference with a word which is another translation of an ambiguous word in the source. NMT systems are used to score the reference and contrastive translations. If an NMT system assigns a higher score to the reference than to all contrastive translations in an example, the NMT system is recognized as making a correct prediction on this example.
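The scoring protocol can be sketched as follows. The `score` callable stands for any function that returns a model score (for example a log-probability) for a source-translation pair; it and the dictionary layout of the examples are hypothetical placeholders, not the ContraWSD tooling.

```python
def contrastive_accuracy(examples, score):
    """Fraction of examples where the reference outscores every contrastive translation."""
    correct = 0
    for ex in examples:
        ref_score = score(ex["source"], ex["reference"])
        if all(ref_score > score(ex["source"], c) for c in ex["contrastive"]):
            correct += 1
    return correct / len(examples)
```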
We use ContraWSD Gonzales et al. (2017) (https://github.com/ZurichNLP/ContraWSD) as the test suite, which contains 7,359 contrastive translation pairs for DE-EN translation. We encode the source sentences from ContraWSD and the training data of the 5 specific domains by averaged BERT embeddings Devlin et al. (2018). Then we whiten the sentence embeddings with BERT-whitening proposed by Huang et al. (2021); Li et al. (2020). For each domain, we select the 300 examples from ContraWSD that are most similar to the in-domain data based on the cosine similarity of sentence embeddings.
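The whitening step can be illustrated with a generic whitening transform, sketched below under the assumption that it is fitted on a matrix of sentence embeddings with full-rank covariance. This is a standard SVD-based formulation and may differ in details from the cited BERT-whitening implementations; the function names are hypothetical.

```python
import numpy as np

def fit_whitening(embeddings):
    """Estimate mean and whitening matrix from an (n, d) embedding matrix."""
    mu = embeddings.mean(axis=0, keepdims=True)
    cov = np.cov((embeddings - mu).T)            # (d, d) covariance matrix
    u, s, _ = np.linalg.svd(cov)
    W = u @ np.diag(1.0 / np.sqrt(s))            # rotate, then rescale each direction
    return mu, W

def whiten(embeddings, mu, W):
    """Map embeddings to zero mean and (approximately) identity covariance."""
    return (embeddings - mu) @ W
```

After whitening, cosine similarity between sentence embeddings is no longer dominated by a few high-variance directions, which is the usual motivation for applying it before similarity-based selection.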
We evaluate the translation performance and word sense disambiguation ability of base model and KSTER for MDMT on selected examples for each domain. The results are shown in Figure 9. Experimental results show that KSTER consistently outperforms base model in both translation performance and word sense disambiguation accuracy, which indicates that kernel-smoothing helps word sense disambiguation in machine translation.
Conclusion
In this work, we propose kernel-smoothed machine translation with retrieved examples. It improves the generalization ability over existing non-parametric methods, while keeping the advantage of online adaptation.
Acknowledgements
We would like to thank the anonymous reviewers for their insightful comments. This work is supported by National Science Foundation of China (No. U1836221, 61772261), National Key R&D Program of China (No. 2019QY1806).