Efficient Nearest Neighbor Language Models

Junxian He, Graham Neubig, Taylor Berg-Kirkpatrick

Introduction

Language models (LMs) are one of the most fundamental technologies in NLP, with applications spanning text generation (Bahdanau et al., 2015; Rush et al., 2015), representation learning (Peters et al., 2018; Devlin et al., 2019; Yang et al., 2019), and few-shot learning (Radford et al., 2019; Brown et al., 2020).

Modern neural language models (NLMs) based on recurrent (Mikolov et al., 2010; Sundermeyer et al., 2012) or self-attentional (Vaswani et al., 2017; Al-Rfou et al., 2019) neural networks are mostly parametric, where the predictions are solely dependent on the model parameters given the input data.

In contrast, recent non-parametric LMs (Guu et al., 2018; Khandelwal et al., 2020; He et al., 2020) model text distributions by referencing both the parameters of the underlying model and examples from an external datastore. Non-parametric LMs are appealing because they allow for effective language modeling, particularly for rarer patterns, through explicit memorization via a datastore. This relieves the model parameters of the burden of encoding all the information in a large dataset. One effective and representative example is the k-nearest neighbors LM (kNN-LM; Khandelwal et al., 2020). The kNN-LM computes the probability of the next token by interpolating a parametric LM with a distribution calculated from the $k$ nearest context-token pairs in the datastore, as illustrated in Figure 2. This model is particularly notable for its large performance improvements: it outperforms the previous best parametric LMs by a large margin on standard language modeling benchmarks, in domain adaptation settings, and on other conditional generation tasks such as machine translation (Khandelwal et al., 2021).

However, one downside of the kNN-LM is that the datastore stores a high-dimensional dense vector for each token in the training data, which can easily scale to hundreds of millions or even billions of records. As a result, the extra retrieval step from such datastores greatly decreases model efficiency at test time. For example, as shown in Figure 1, a 100M-entry datastore can slow evaluation down by more than 10x compared to parametric models (§3.3). This issue poses a serious hurdle for the practical deployment of non-parametric LMs, despite their effectiveness.

In this paper, we attempt to address this test-time inefficiency and make non-parametric LMs more applicable in real-world settings. Taking kNN-LM as an example, we first analyze its evaluation overhead and then raise three questions that we aim to answer: (1) Do we really need to perform retrieval for the prediction of every single token? (2) Can we identify and prune redundant records from the datastore? (3) Is it possible to further compress the datastore by reducing the vector dimensionality without losing performance? We propose and explore potential solutions to each question. Specifically, we (1) show that a lightweight network can be trained to automatically prune unnecessary retrieval operations (adaptive retrieval, §4.1), (2) explore several methods for datastore pruning based on clustering, importance-guided filtering, or greedy merging (§4.2), and (3) empirically demonstrate that simple dimension reduction techniques can improve both performance and speed (§4.3). Figure 1 illustrates the overall performance of these methods. Our experiments on the WikiText-103 language modeling benchmark (Merity et al., 2017) and in a training-free domain adaptation setting demonstrate speed improvements of up to 6x with perplexity comparable to the kNN-LM. More broadly, we expect the empirical results and analysis in this paper to help researchers better understand the speed-performance tradeoff in non-parametric NLMs and to provide a springboard for future research on more efficient non-parametric LMs.

k-Nearest Neighbors Language Model

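In this section, we give an overview of kNN-LM (Khandelwal et al., 2020) and its implementation details. Given a context $c_t = (w_1, \ldots, w_{t-1})$, kNN-LM computes the probability of the next token $w_t$ by interpolating the pretrained NLM with a distribution derived from nearest neighbor retrieval:

$$p(w_t|c_t) = \lambda\, p_{\text{kNN}}(w_t|c_t) + (1-\lambda)\, p_{\text{NLM}}(w_t|c_t) \qquad (1)$$

where $\lambda$ is the interpolation hyperparameter. Note that applying kNN-LM requires no additional training: the parameters of the NLM remain as-is, and Eq. 1 is applied only at test time. The workflow of kNN-LM is shown in Figure 2.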
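As a minimal sketch of Eq. 1 (not the original implementation), the Python snippet below mixes two next-token distributions represented as token-to-probability dictionaries over a toy vocabulary. The function name `interpolate` and the dictionary representation are illustrative assumptions; the value 0.25 is the $\lambda$ tuned for WikiText-103 in Appendix A.

```python
def interpolate(p_knn: dict[str, float], p_nlm: dict[str, float], lam: float) -> dict[str, float]:
    """Eq. 1: p(w|c) = lam * p_kNN(w|c) + (1 - lam) * p_NLM(w|c)."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    vocab = set(p_knn) | set(p_nlm)
    # Tokens absent from p_knn (predicted by no neighbor) receive zero kNN mass.
    return {w: lam * p_knn.get(w, 0.0) + (1.0 - lam) * p_nlm.get(w, 0.0) for w in vocab}


p_nlm = {"the": 0.5, "a": 0.3, "Paris": 0.2}
p_knn = {"Paris": 1.0}  # hypothetical: every retrieved neighbor predicted "Paris"
p = interpolate(p_knn, p_nlm, lam=0.25)
print(round(p["Paris"], 4))  # 0.4
```

Because both inputs are normalized and $\lambda \in [0, 1]$, the mixture is itself a normalized distribution; a token that no neighbor predicts keeps only the down-weighted NLM mass $(1-\lambda)\, p_{\text{NLM}}$.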

Datastore.

The datastore in kNN-LM stores context vectors from the pretrained NLM as keys and their corresponding next tokens as values. Formally, let $f$ be the key function that maps a context sequence $c$ to a fixed-size vector; the datastore $(\mathcal{K}, \mathcal{V})$ then contains all the key-value pairs constructed from the training examples $\mathcal{D}$:

$$(\mathcal{K}, \mathcal{V}) = \{(f(c_t), w_t) \mid (c_t, w_t) \in \mathcal{D}\} \qquad (2)$$
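The datastore construction described above can be sketched in a few lines of Python. The real key function $f$ is an intermediate hidden state of the pretrained NLM; here `toy_key` is a hypothetical stand-in that returns a small vector, so the example runs without a model.

```python
from typing import Callable, List, Sequence, Tuple

Vector = Tuple[float, ...]


def build_datastore(
    tokens: Sequence[str], f: Callable[[Sequence[str]], Vector]
) -> Tuple[List[Vector], List[str]]:
    """Collect one (f(c_t), w_t) key-value pair per target position t."""
    keys: List[Vector] = []
    values: List[str] = []
    for t in range(1, len(tokens)):
        context = tokens[:t]       # c_t = (w_1, ..., w_{t-1})
        keys.append(f(context))    # key: fixed-size context vector
        values.append(tokens[t])   # value: the observed next token w_t
    return keys, values


def toy_key(context: Sequence[str]) -> Vector:
    """Hypothetical stand-in for the NLM key function (a real f is a hidden state)."""
    last = context[-1]
    return (float(len(context)), float(len(last)), float(ord(last[0])))


keys, values = build_datastore("the cat sat on the mat".split(), toy_key)
print(len(keys), values)  # 5 ['cat', 'sat', 'on', 'the', 'mat']
```

Every target position contributes one record, which is why the datastore size tracks the number of training tokens.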

The size of such a datastore is almost equal to the number of training tokens, because each context $c_t$ is (nearly) unique given the large context windows of modern recurrent (Sundermeyer et al., 2012) or self-attentional (Vaswani et al., 2017) NLMs. The datastore can therefore easily scale to hundreds of millions or even billions of records. Moreover, each $f(c_t)$ is a high-dimensional dense vector, which makes the datastore difficult to fit in memory. For example, a datastore built from a 100M-token training dataset, using 1024-dimensional context vectors at 16-bit precision, could require 200GB of memory. Note that the dataset used to construct the datastore need not be the data used to train the parametric NLM in Eq. 1: a separate dataset may be used for datastore construction, which enables applications such as training-free domain adaptation or a gradient-free way to utilize extra training data (Khandelwal et al., 2020).
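The 200GB figure follows from simple arithmetic on the key matrix alone (ignoring the stored token ids and any index overhead). A short Python check, with a hypothetical helper name:

```python
def datastore_bytes(num_tokens: int, dim: int, bytes_per_value: int = 2) -> int:
    """Raw size of the key matrix: one dim-sized vector per token.
    Values (token ids) and index overhead are not included."""
    return num_tokens * dim * bytes_per_value


# The example from the text: 100M tokens, 1024-dim keys, 16-bit (2-byte) floats.
size = datastore_bytes(100_000_000, 1024)
print(size / 1e9)  # 204.8, i.e. roughly 200GB
```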

At inference time, kNN-LM (1) computes the context vector $f(c)$ from the current sequence using the pretrained NLM, (2) uses $f(c)$ as the query to retrieve the $k$ nearest neighbors $\mathcal{N} = \{(q_i, v_i) | i = 1, \ldots, k\}$ from the datastore, and (3) aggregates the retrieved tokens to form the distribution $p_{\text{kNN}}(w|c)$ used in Eq. 1:

$$p_{\text{kNN}}(w|c) \propto \sum_{(q_i, v_i) \in \mathcal{N}} \mathbb{1}_{w = v_i} \exp\big(-d(q_i, f(c))\big) \qquad (3)$$

where $d(\cdot,\cdot)$ is a distance function between two vectors; $L^2$ distance was shown to be more effective than other alternatives (Khandelwal et al., 2020). Intuitively, kNN-LM finds context sequences in the datastore that are similar to the test context and then uses the next tokens observed after these contexts to help prediction. Such a mechanism allows language modeling through explicit memorization from the datastore and may be particularly helpful for patterns rarely seen by the pretrained NLM (Khandelwal et al., 2020, 2021).
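The following sketch computes the kNN distribution described above with exact, brute-force search over a toy in-memory datastore, standing in for the ANN index discussed below. It assumes $d$ is the squared $L^2$ distance (FAISS's L2 indexes return squared distances); the keys, tokens, and function names are hypothetical.

```python
import heapq
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def knn_distribution(
    query: Sequence[float], keys: List[Sequence[float]], values: List[str], k: int
) -> Dict[str, float]:
    """Softmax over negative distances of the k nearest neighbors,
    summing the mass of neighbors that share the same target token."""
    # The k (distance, token) pairs closest to the query.
    neighbors = heapq.nsmallest(
        k, ((squared_l2(query, key), v) for key, v in zip(keys, values))
    )
    d_min = neighbors[0][0]  # shift by the smallest distance for numerical stability
    scores: Dict[str, float] = defaultdict(float)
    for d, v in neighbors:
        scores[v] += math.exp(-(d - d_min))
    z = sum(scores.values())
    return {w: s / z for w, s in scores.items()}


keys = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (5.0, 5.0)]
values = ["cat", "dog", "cat", "fish"]
p = knn_distribution((0.0, 0.0), keys, values, k=3)
print(p)  # "fish" is not among the 3 nearest neighbors, so it receives no mass
```

Subtracting the minimum distance does not change the normalized result, since the shift cancels in the ratio; it only avoids underflow when all distances are large.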

Sources of Inference Overhead.

The extra inference overhead stems from the kNN search performed when computing $p_{\text{kNN}}(w_t|c_t)$. We denote the inference time per token as $t = t_{\text{NLM}} + t_{\text{kNN}}$. While $t_{\text{NLM}}$ remains constant regardless of the datastore, $t_{\text{kNN}}$ unfortunately grows as the datastore scales.

In practice, the kNN search is often performed only approximately (approximate nearest neighbor search, ANN; Gionis et al., 1999; Muja and Lowe, 2009) to reduce computational cost. Khandelwal et al. (2020) implemented ANN search in kNN-LM (https://github.com/urvashik/knnlm) using FAISS (Johnson et al., 2019), which combines inverted file systems (Sivic and Zisserman, 2003) and product quantization (Jegou et al., 2010). This type of index reduces memory usage by storing only quantized vectors and accelerates kNN search by pre-clustering the datastore vectors; interested readers can refer to Jegou et al. (2010) for more details. For the purposes of this paper, we treat kNN-LM with this indexing method as a black box and aim to improve efficiency in an index-agnostic way. At the same time, we note that building fast and accurate indexing methods remains an active area of research (André et al., 2019; Guo et al., 2020), and the selection or improvement of the index itself (possibly in concert with the methods proposed in this paper) is an interesting avenue for future work.

Distance Recomputation.

The distances to the nearest neighbors are required to compute $p_{\text{kNN}}(w_t|c_t)$, as shown in Eq. 3. However, as described above, kNN-LM's nearest neighbor search operates over quantized vectors and can therefore return only approximate distances. While exact distances could be computed by reading the full-precision vectors from the datastore after retrieval, this presents its own challenges: (1) storing the entire datastore in memory does not scale to large datastores, and (2) reading vectors on-the-fly from a large datastore on disk is too slow to be practical (< 1 token per second). (Disk random I/O could perhaps be improved with further engineering effort, which is also interesting future work.) Therefore, in this paper we use the approximate distances directly to compute $p_{\text{kNN}}$. This comes at the cost of a minor performance loss, as we show in §3.3. Similar approximations were adopted when applying kNN-LM to machine translation (Khandelwal et al., 2021).
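To make the source of approximate distances concrete, here is a toy product quantizer in the spirit of Jegou et al. (2010), written in plain Python. It is not FAISS's implementation: it omits the inverted-file coarse quantizer and learns codebooks with a basic $k$-means, and all class and function names are hypothetical. Each key is stored only as a few codebook indices, so search can compare the query only against decoded (quantized) keys, never against the original vectors.

```python
import random
from typing import List, Sequence

Vec = List[float]


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points: List[Vec], n_centroids: int, iters: int = 10, seed: int = 0) -> List[Vec]:
    """Plain Lloyd's k-means, used here only to learn PQ codebooks."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, n_centroids)]
    for _ in range(iters):
        buckets: List[List[Vec]] = [[] for _ in centroids]
        for p in points:
            j = min(range(n_centroids), key=lambda c: sq_dist(p, centroids[c]))
            buckets[j].append(p)
        for j, members in enumerate(buckets):
            if members:  # an empty cluster keeps its previous centroid
                centroids[j] = [sum(col) / len(members) for col in zip(*members)]
    return centroids


class ProductQuantizer:
    """Toy PQ: split each vector into n_sub chunks and store one codebook index per chunk."""

    def __init__(self, n_sub: int, n_centroids: int):
        self.n_sub, self.n_centroids = n_sub, n_centroids
        self.codebooks: List[List[Vec]] = []

    def _split(self, v: Sequence[float]) -> List[Vec]:
        step = len(v) // self.n_sub
        return [list(v[m * step:(m + 1) * step]) for m in range(self.n_sub)]

    def train(self, vectors: List[Vec]) -> None:
        chunks = [self._split(v) for v in vectors]
        self.codebooks = [
            kmeans([c[m] for c in chunks], self.n_centroids, seed=m) for m in range(self.n_sub)
        ]

    def encode(self, v: Sequence[float]) -> List[int]:
        return [
            min(range(self.n_centroids), key=lambda j: sq_dist(chunk, self.codebooks[m][j]))
            for m, chunk in enumerate(self._split(v))
        ]

    def approx_sq_dist(self, query: Sequence[float], code: List[int]) -> float:
        """Asymmetric distance: the exact query against the quantized (decoded) key."""
        return sum(
            sq_dist(q_m, self.codebooks[m][c])
            for m, (q_m, c) in enumerate(zip(self._split(query), code))
        )


rng = random.Random(42)
data = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(200)]
pq = ProductQuantizer(n_sub=4, n_centroids=16)
pq.train(data)
codes = [pq.encode(v) for v in data]  # 4 small ints per key instead of 8 floats
query = [rng.gauss(0.0, 1.0) for _ in range(8)]
print(f"exact={sq_dist(query, data[0]):.3f}  approx={pq.approx_sq_dist(query, codes[0]):.3f}")
```

Recovering the exact distance would require the original full-precision vector, which is exactly the per-neighbor lookup that the text rules out for large on-disk datastores.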

The Efficiency of kNN-LM

In this section, we first introduce the datasets and setup used throughout the paper, and then compare the inference speed of kNN-LM to that of parametric NLMs.

We study kNN-LM in two different settings: (1) the standard setting, where the datastore is constructed from the same data used to train the NLM, and (2) a domain adaptation setting, where the datastore is built from training data in the test domain, so that the NLM never sees the examples included in the datastore. The following two datasets are used for these two settings, respectively:

WikiText-103

WikiText-103 is a standard language modeling benchmark derived from Wikipedia, with a 250K word-level vocabulary. It consists of 103M training tokens and thus yields a datastore with 103M records that takes 200GB of space. Following Khandelwal et al. (2020), we use the transformer-based (Vaswani et al., 2017) language model checkpoint released by Baevski and Auli (2019), which is trained on the WikiText-103 training split, as the underlying pretrained NLM.

Law-MT

Law-MT is an English-German machine translation dataset in the law domain, originally released by Koehn and Knowles (2017) and resplit by Aharoni and Goldberg (2020). We use only the English text for language modeling. The training set consists of 19M tokens, which we use to build a datastore that occupies 55GB of space. To evaluate domain adaptation, our pretrained NLM is a 12-layer transformer model trained on WMT News Crawl (http://data.statmt.org/news-crawl/) and released by Ng et al. (2019).

3.2 Setup

Throughout the rest of the paper, we adopt the same hyperparameters and index as Khandelwal et al. (2020) for kNN-LM, and we base our experiments directly on the original kNN-LM implementation. Specifically, the number of nearest neighbors $k$ is set to 1024 during evaluation. Perplexity continues to improve as $k$ grows, as shown by Khandelwal et al. (2020) and confirmed by us, yet in our observation $k$ has no noticeable effect on evaluation speed within this range. Our pretrained NLMs are the state-of-the-art decoder-only transformers mentioned above, and the key function $f(c)$ used to obtain context vectors is the input to the final layer's feedforward network. The context vectors are 1024-dimensional for WikiText-103 and 1536-dimensional for Law-MT. For each dataset, we tune the interpolation weight $\lambda$ on the validation set for vanilla kNN-LM performance and keep it fixed unless otherwise specified. Complete details of the setup can be found in Appendix A.

Evaluation efficiency is benchmarked on 32 CPU cores (1.5 GHz AMD EPYC 7282) and 1 NVIDIA RTX 3090 GPU, which represents a normalized environment: index search uses all the CPU cores, while neural network computation runs on the GPU. Running retrieval on 32 CPU cores is also the standard setting used in the FAISS repository to benchmark large-scale retrieval (https://github.com/facebookresearch/faiss/wiki/Indexing-1G-vectors).

3.3 Baseline Speed

We measure perplexity (ppl) and evaluation speed in terms of tokens evaluated per second; Table 1 reports the results on the test sets of the two datasets. For reference, we also include “kNN-LM (exact)”, the kNN-LM variant that recomputes exact distances as explained in §2. While kNN-LM is very effective, with a gain of 2 ppl points on WikiText-103 and over 90 points on Law-MT in the domain adaptation setting, it is 10x to 30x slower to evaluate on these datasets because of the extra retrieval step. When exact distances are computed by reading vectors from disk on-the-fly, kNN-LM (exact) takes over 1 second to evaluate a single token.

The Remedies

In this section, we propose and explore several methods that may improve the efficiency of kNN-LM along three axes: (1) adaptive retrieval, (2) datastore pruning, and (3) dimension reduction. We analyze the performance of each method on WikiText-103 in order to identify best practices, which we then evaluate in §5.

4.1 Adaptive Retrieval

Just as people consult their books during an open-book quiz only when they are uncertain, parametric NLMs may not always need help from the external datastore. To test this hypothesis, we compare $p_{\text{kNN}}(w|c)$ and $p_{\text{NLM}}(w|c)$ for every token in the WikiText-103 validation set. Interestingly, $p_{\text{kNN}}(w|c) \geq p_{\text{NLM}}(w|c)$ only 39% of the time. Since the interpolated probability in Eq. 1 lies between the two, it falls below $p_{\text{NLM}}(w|c)$ whenever $p_{\text{kNN}}(w|c) < p_{\text{NLM}}(w|c)$; the likelihood of the remaining 61% of tokens therefore becomes worse after interpolation, despite the overall improvement. This indicates that if we could identify these positions perfectly, 61% of the retrieval operations could be removed entirely, and we would achieve even better perplexity. Inspired by this observation, we aim to automatically identify and prune unnecessary retrieval operations to speed up inference.
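The oracle comparison above can be sketched as follows. The inputs are hypothetical per-token probabilities of the gold token under each model, and the function names are illustrative. The sketch also shows why the oracle can only lower perplexity. Because this analysis needs the gold token, it cannot be used at test time, which is what motivates learning a predictor instead.

```python
import math
from typing import List, Tuple


def perplexity(probs: List[float]) -> float:
    return math.exp(-sum(math.log(p) for p in probs) / len(probs))


def oracle_analysis(p_knn: List[float], p_nlm: List[float], lam: float) -> Tuple[float, float, float]:
    """Return (fraction of tokens where p_kNN >= p_NLM, ppl when always
    interpolating, ppl when interpolating only where retrieval helps)."""
    helps = [pk >= pn for pk, pn in zip(p_knn, p_nlm)]
    always = [lam * pk + (1 - lam) * pn for pk, pn in zip(p_knn, p_nlm)]
    # Oracle: skip retrieval (fall back to p_NLM) wherever interpolation would hurt.
    oracle = [a if h else pn for a, h, pn in zip(always, helps, p_nlm)]
    return sum(helps) / len(helps), perplexity(always), perplexity(oracle)


# Hypothetical gold-token probabilities for four validation tokens.
frac, ppl_always, ppl_oracle = oracle_analysis(
    p_knn=[0.9, 0.0, 0.1, 0.6], p_nlm=[0.3, 0.2, 0.4, 0.5], lam=0.25
)
print(frac, ppl_always, ppl_oracle)  # frac is 0.5; ppl_oracle <= ppl_always
```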

We propose to train a lightweight neural network, the retrieval adaptor, to identify when the retrieval operation can be removed. Specifically, given the context $c$ as input, the retrieval adaptor may be trained with either (1) a classification objective to predict whether $p_{\text{kNN}}(w|c) \geq p_{\text{NLM}}(w|c)$, or (2) a likelihood maximization objective to predict the interpolation weight $\lambda(c)$ and maximize the overall likelihood of kNN-LM as in Eq. 1. In our preliminary experiments, the classification method performs only on par with a random removal baseline, partially due to the noisy, discretized supervision. Therefore, we directly maximize the kNN-LM log likelihood by modeling $\lambda$ as a function of the context, minimizing the following loss over the $T$ training tokens:

$$\mathcal{L}(\theta) = -\frac{1}{T}\sum_{t=1}^{T} \log\Big(\lambda_\theta(c_t)\, p_{\text{kNN}}(w_t|c_t) + \big(1-\lambda_\theta(c_t)\big)\, p_{\text{NLM}}(w_t|c_t)\Big) + a \cdot \frac{1}{T}\sum_{t=1}^{T} \big|\lambda_\theta(c_t)\big| \qquad (4)$$

where only $\theta$, the parameters of the retrieval adaptor, are updated. The second term is an $L^1$ regularizer that encourages sparse weights for $p_{\text{kNN}}$, which we find helpful for pruning unnecessary retrievals. At inference time, we prune a given fraction of retrievals, namely those with the smallest kNN weights $\lambda(c)$, by resetting $\lambda(c)$ to zero. The hyperparameters of the retrieval adaptor network, including the regularizer coefficient $a$, are tuned on the validation set for perplexity at 50% retrieval pruning. Learning interpolation weights for pruning is related to Johansen and Socher (2017), who learn to skip text in classification tasks. Optimizing the interpolation weights of kNN-LM has also been applied at training time to train the NLM jointly (Yogatama et al., 2021).
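For fixed weights, the adaptor objective and the inference-time pruning rule can be sketched as below. This is an illustration, not the authors' code: during training, $\lambda_\theta(c_t)$ would be the adaptor's output and the loss would be minimized by gradient descent over $\theta$ (not shown). The coefficient 0.05 is the value of $a$ reported in Appendix A.2; the probabilities and weights are hypothetical.

```python
import math
from typing import List


def ar_loss(lams: List[float], p_knn: List[float], p_nlm: List[float], a: float) -> float:
    """Mean negative kNN-LM log-likelihood with per-token weights lambda(c_t),
    plus an L1 penalty a * mean|lambda(c_t)| that pushes weights toward zero."""
    n = len(lams)
    nll = -sum(math.log(lam * pk + (1 - lam) * pn) for lam, pk, pn in zip(lams, p_knn, p_nlm)) / n
    l1 = sum(abs(lam) for lam in lams) / n
    return nll + a * l1


def prune_retrievals(lams: List[float], fraction: float) -> List[float]:
    """Zero out lambda for the given fraction of tokens with the smallest lambda;
    a zero weight means the kNN search for that token can be skipped."""
    n_prune = int(round(fraction * len(lams)))
    smallest_first = sorted(range(len(lams)), key=lambda i: lams[i])
    pruned = list(lams)
    for i in smallest_first[:n_prune]:
        pruned[i] = 0.0
    return pruned


lams = [0.05, 0.6, 0.1, 0.4]  # hypothetical adaptor outputs for four tokens
print(prune_retrievals(lams, 0.5))  # [0.0, 0.6, 0.0, 0.4]
print(ar_loss(lams, p_knn=[0.1, 0.8, 0.2, 0.5], p_nlm=[0.3, 0.2, 0.4, 0.3], a=0.05))
```

Pruning by rank rather than by a fixed threshold makes the speed-up predictable: the fraction of skipped searches is set directly.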

Architecture and Input Features:

The retrieval adaptor is a lightweight MLP in which each layer is a linear transformation followed by a ReLU activation. The output layer maps the hidden representation to a 2-dimensional vector, followed by a LogSoftmax layer that yields $\log(\lambda)$ and $\log(1-\lambda)$, respectively. Complete details on the retrieval adaptor can be found in Appendix A.2. As input to the retrieval adaptor, we concatenate several neural and count-based features, as listed in Table 2. For the scalar features (i.e., all features except $f(c)$), we found it helpful to map each of them to a vector with a small network before concatenation. Note that all the features are trivial to obtain at test time: the neural features come from the intermediate computation of $p_{\text{NLM}}(w|c)$, and the count-based features are looked-up values. An ablation analysis of these features can be found in Appendix B.
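A small sketch of the output head, using plain Python floats in place of a tensor library: a 2-way log-softmax returns $\log\lambda$ and $\log(1-\lambda)$ directly. One plausible benefit of this parameterization (our reading, not a claim from the text) is that Eq. 1 can then be evaluated stably in log space with a log-sum-exp. The function names are hypothetical; a token that no neighbor predicts can be passed as `float("-inf")` for its kNN log-probability.

```python
import math
from typing import Tuple


def logsumexp2(x: float, y: float) -> float:
    m = max(x, y)  # subtract the max for numerical stability
    return m + math.log(math.exp(x - m) + math.exp(y - m))


def lambda_head(logits: Tuple[float, float]) -> Tuple[float, float]:
    """2-way log-softmax over the adaptor's output: (log lambda, log(1 - lambda))."""
    z0, z1 = logits
    log_norm = logsumexp2(z0, z1)
    return z0 - log_norm, z1 - log_norm


def log_interpolate(log_lam: float, log_one_minus_lam: float,
                    log_p_knn: float, log_p_nlm: float) -> float:
    """log p(w|c) from Eq. 1, computed entirely in log space."""
    return logsumexp2(log_lam + log_p_knn, log_one_minus_lam + log_p_nlm)


log_lam, log_rest = lambda_head((1.0, -1.0))
print(math.exp(log_lam) + math.exp(log_rest))  # sums to 1 (up to rounding)
```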

Training:

During training, only the retrieval adaptor is updated, while the pretrained NLM is fixed. Training the retrieval adaptor on the training set would be inappropriate and lead to biased solutions, since $p_{\text{NLM}}$ may already overfit the training data, and the datastore contains each training example itself (so its own next token would always be retrieved). To generalize to the test data, we hold out 10% of the validation data for validation and use the remaining 90% to train the retrieval adaptor. The retrieval adaptor is light and converges quickly; training it on WikiText-103 took several minutes on a single GPU.

Results:

Figure 3 shows the perplexity and evaluation speed of adaptive retrieval on the WikiText-103 test set as the percentage of removed retrieval operations varies. The $\lambda$ thresholds used to cut off retrieval are selected on the held-out validation split described above. We also include a random baseline that uniformly selects a given fraction of retrieval operations to discard. Adaptive retrieval (AR) exhibits a much flatter increase in perplexity than the random baseline as the number of removed retrievals grows. Notably, by removing 50% of the retrieval operations, AR achieves perplexity comparable to the original kNN-LM (16.67 vs. 16.65) while being nearly 2x faster (530 vs. 277 tokens/s). AR's gain comes from both the learned pruning mask and the optimized $\lambda$; we perform an ablation study on this in Appendix B.

4.2 Datastore Pruning

The information present in a large training dataset is often redundant, which suggests that a datastore constructed from training tokens may be pruned with no or only minor performance cost. To validate this hypothesis, we propose several different methods to prune the number of entries and reduce the datastore size:

Random Pruning:

As a simple baseline, we keep a randomly selected fraction of the datastore entries and discard the rest. Random pruning has been shown to work well with a billion-scale datastore in machine translation tasks (Khandelwal et al., 2021).

k-Means Pruning:

Clustering is a common technique for pruning redundant vectors by keeping only the cluster centroids. In our task specifically, however, general clustering on the context vectors is not directly applicable, since vectors in the same cluster may still correspond to different target tokens, as language use in context is not deterministic. Therefore, we propose target-aware $k$-means clustering: for each word $w_i$ in the vocabulary, we perform a separate $k$-means clustering over all the context vectors that have $w_i$ as the target token, and we keep only the centroid of each cluster along with the cluster size $s$. The (centroid vector, cluster size, target token) triples form a new, compressed datastore. Since we approximate multiple vectors in the same cluster by their centroid and save the centroid vector only once in the new datastore, the computation of the kNN distribution $p_{\text{kNN}}$ needs to be rectified as:

$$p_{\text{kNN}}(w|c) \propto \sum_{(q_i, v_i, s_i) \in \mathcal{N}} \mathbb{1}_{w = v_i}\, s_i \exp\big(-d(q_i, f(c))\big) \qquad (5)$$

where the cluster size $s_i$ acts as a weight for each datastore entry. Eq. 5 recovers Eq. 3 when every cluster has size 1. In addition, the centroid formulation is roughly equivalent to storing each centroid $s_i$ times in the original, unpruned formulation. In practice, due to the high computational cost, we run 5000 separate $k$-means clusterings only for the 5000 most frequent words, which account for 84% of all the training tokens; every other vector is treated as a separate cluster of size 1. The number of clusters in each $k$-means run is set to 1/20 of the number of vectors being clustered, which produces a datastore that is 5x smaller overall. We did not extensively tune the $k$-means hyperparameters due to the computational burden. Note that this clustering differs from the pre-clustering in the ANN index with inverted file systems mentioned in §2: the index's pre-clustering does not reduce the datastore size and serves only for lookup.
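A plain-Python sketch of target-aware $k$-means pruning under simplifying assumptions: Lloyd's $k$-means with random initialization and a fixed iteration count, the text's 1/20 cluster ratio, and a `frequent` set standing in for the 5000-word cutoff. All names are hypothetical, and a datastore of the size discussed here would need a scalable clustering library rather than these loops.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Set, Tuple

Vec = List[float]


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans_with_sizes(points: List[Vec], n_clusters: int, iters: int = 10,
                      seed: int = 0) -> Tuple[List[Vec], List[int]]:
    """Lloyd's k-means; also returns how many points the last assignment step
    gave to each centroid (these counts become the weights s_i)."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, n_clusters)]
    sizes = [0] * n_clusters
    for _ in range(iters):
        buckets: List[List[Vec]] = [[] for _ in range(n_clusters)]
        for p in points:
            j = min(range(n_clusters), key=lambda c: sq_dist(p, centroids[c]))
            buckets[j].append(p)
        for j, members in enumerate(buckets):
            if members:
                centroids[j] = [sum(col) / len(members) for col in zip(*members)]
        sizes = [len(b) for b in buckets]
    return centroids, sizes


def target_aware_prune(keys, values, frequent: Set[str], ratio: int = 20):
    """Cluster keys separately for each frequent target token; return
    (vector, weight s_i, token) triples for use in Eq. 5."""
    by_token: Dict[str, List[Vec]] = defaultdict(list)
    for key, tok in zip(keys, values):
        by_token[tok].append(list(key))
    pruned = []
    for tok, vecs in by_token.items():
        if tok in frequent and len(vecs) >= ratio:
            centroids, sizes = kmeans_with_sizes(vecs, len(vecs) // ratio)
            pruned.extend((c, s, tok) for c, s in zip(centroids, sizes) if s > 0)
        else:  # rare tokens: each vector is its own cluster of size 1
            pruned.extend((v, 1, tok) for v in vecs)
    return pruned


rng = random.Random(0)
keys = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(103)]
values = ["the"] * 100 + ["aardvark"] * 3  # hypothetical target tokens
pruned = target_aware_prune(keys, values, frequent={"the"})
print(len(pruned))  # at most 5 centroids for "the" plus 3 singletons
```

Because each weight counts the original records it replaces, the weights sum to the original datastore size, which is what lets Eq. 5 approximate Eq. 3.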

Greedy Merging:

In general, we aim to merge records that share the same target token and are close to each other in vector space. Target-aware clustering is one attempt at this, but forcing all points to participate in clustering produces large clusters, and in clusters with high variance some member points lie far from one another, so approximating all of them with the cluster centroid may lead to large errors. To address this issue, we propose a simple approach, greedy merging (GM), which inspects every record in the datastore and greedily merges its nearest neighbors into it if a merging condition is satisfied. The detailed algorithm is shown in Algorithm 1. Intuitively, GM is density-based in that it groups points with their nearest neighbors, but the merging operation happens only locally between a point and its nearest neighbors: unlike typical density-based clustering methods (Ester et al., 1996), it never propagates to merge the nearest neighbors of nearest neighbors, which could amplify errors. As in $k$-means pruning, we also compute a weight $s_i$ for each entry in the compressed datastore to correct the $p_{\text{kNN}}$ computation using Eq. 5. Without a global clustering mechanism, this approach ensures that merged vectors are close to each other by inspecting only a small number of nearest neighbors. In the following analysis, we vary the number of nearest neighbors $K$ within a range to achieve different compression rates.
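Algorithm 1 itself is not reproduced in this text, so the following is only one plausible reading of the description above, in plain Python with brute-force neighbor search standing in for the ANN index. The merging condition (same target token, and the neighbor has neither been merged nor absorbed others) is our assumption, chosen so that merges stay local as the text describes. The function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def greedy_merge(keys: List[Tuple[float, ...]], values: List[str], K: int):
    """Visit records in order; each surviving record absorbs those of its K nearest
    neighbors that share its target token and are still untouched."""
    n = len(keys)
    weight = [1] * n  # s_i in Eq. 5
    alive = [True] * n
    for i in range(n):
        if not alive[i]:
            continue
        # Brute-force K nearest neighbors of record i (an ANN index in practice).
        neighbors = sorted((j for j in range(n) if j != i),
                           key=lambda j: sq_dist(keys[i], keys[j]))[:K]
        for j in neighbors:
            # Assumed merging condition: same target token, and j is untouched
            # (alive with weight 1), so merges never chain beyond one hop.
            if alive[j] and weight[j] == 1 and values[j] == values[i]:
                weight[i] += 1
                alive[j] = False
    return [(keys[i], weight[i], values[i]) for i in range(n) if alive[i]]


keys = [(0.0,), (0.1,), (0.2,), (5.0,), (5.1,), (5.2,)]
values = ["a", "a", "a", "b", "b", "b"]
print(greedy_merge(keys, values, K=2))  # [((0.0,), 3, 'a'), ((5.0,), 3, 'b')]
```

A larger $K$ lets each record absorb more neighbors and therefore compresses more, which is how varying $K$ trades compression rate against accuracy.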

Rank-based Pruning:

It is well known that embedding spaces contain “hubs”, which are nearest neighbors of many other embeddings (Tomasev et al., 2013), as well as points that are not the nearest neighbor of any other point. We hypothesize that entries that are rarely nearest neighbors can be removed without significantly affecting performance. To verify this assumption, we use every $(c_i, w_i)$ pair in the training data as a query to search for its $k$ nearest neighbors in the datastore, with $k$ set to a large value (1024). In this process, we compute an “importance score” for every entry in the datastore as $g = \sum_i 1/\text{rank}_i$, where $\text{rank}_i$ is the rank of this entry among the nearest neighbors of the query $f(c_i)$, and $\text{rank}_i = +\infty$ if the entry is not in the retrieval results. Intuitively, the importance score up-weights datastore records that appear more often, and at higher positions (smaller rank values), in the retrieval results. We then sort all datastore records by $g$ and remove those with the smallest scores, varying the compression rate. This method is similar in spirit to the technique of Min et al. (2020), which filters out articles that are never retrieved in memory-constrained open-domain question answering.
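The importance score can be sketched as below. The sketch assumes the nearest neighbor lists (entry ids, nearest first; $k = 1024$ in the text) have already been obtained by querying the index with each training context $f(c_i)$. The toy neighbor lists and function names are hypothetical.

```python
from typing import List


def importance_scores(neighbor_lists: List[List[int]], n_entries: int) -> List[float]:
    """g = sum over queries of 1/rank (rank 1 = nearest); entries never
    retrieved keep g = 0, which corresponds to rank = +inf."""
    g = [0.0] * n_entries
    for neighbors in neighbor_lists:  # entry ids sorted from nearest to farthest
        for rank, entry in enumerate(neighbors, start=1):
            g[entry] += 1.0 / rank
    return g


def rank_prune(g: List[float], keep_fraction: float) -> List[int]:
    """Ids of the entries kept after removing those with the smallest scores."""
    n_keep = int(round(keep_fraction * len(g)))
    by_score = sorted(range(len(g)), key=lambda e: g[e], reverse=True)
    return sorted(by_score[:n_keep])


# Hypothetical retrieval results for three training queries over a 4-entry datastore.
g = importance_scores([[0, 2], [2, 1], [2, 0]], n_entries=4)
print(g)                   # [1.5, 0.5, 2.5, 0.0]
print(rank_prune(g, 0.5))  # [0, 2]
```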

Results:

Figure 4 shows perplexity vs. speed on the WikiText-103 validation set for the datastore pruning methods described above. Only one operating point is reported for $k$-means, since we do not vary its hyperparameters across compression rates, given that its computational cost is much higher than that of the other methods. Using 20% of the original datastore, $k$-means even underperforms the vanilla NLM baseline, suggesting that approximating vectors with cluster centroids may introduce large distance errors that reduce the accuracy of the kNN distribution. Surprisingly, simple random pruning outperforms more complicated methods such as $k$-means and rank-based pruning. The best approach is greedy merging, which exhibits a relatively flat curve compared with the others.

4.3 Dimension Reduction

The context vectors $f(c)$ from large NLMs are often high-dimensional. For example, the pretrained NLMs we use produce 1024- and 1536-dimensional vectors for WikiText-103 and Law-MT, respectively, which incurs significant datastore space and distance computation costs. To mitigate this issue, we empirically explore the effect of dimension reduction in kNN-LM. Specifically, we use principal component analysis (PCA), an efficient and scalable dimension reduction algorithm, to reduce the dimensionality and generate a new, compressed datastore. We treat the PCA output dimension as a hyperparameter and report results for several values.

As shown in Figure 5, evaluation becomes faster with smaller dimensions, as expected, yet overly aggressive compression (fewer than 256 dimensions) incurs a large perplexity cost, and kNN-LM even loses its advantage over the NLM when the dimension is below 128. However, at 256 and 512 dimensions, PCA achieves comparable or even better performance than the original 1024-dimensional vectors while attaining a 3x to 4x speed-up. (The tool we use for PCA, the FAISS PCA implementation, applies a random rotation to the PCA output vectors by default to re-balance the variances of a vector's components (Gong et al., 2012), which may provide additional benefits over vanilla PCA for the product quantization inside the index.)
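The sketch below implements PCA with power iteration and deflation in plain Python. The text uses FAISS's PCA, which also applies a random rotation; this sketch does not. The key point it illustrates is that the same mean and projection must be applied both to the datastore keys and to every query vector $f(c)$, so that $L^2$ distances are computed in the same reduced space. The function names and toy data are hypothetical; a production setup would use an optimized linear algebra library.

```python
import random
from typing import List, Tuple

Vec = List[float]


def pca_fit(vectors: List[Vec], n_components: int, iters: int = 100,
            seed: int = 0) -> Tuple[Vec, List[Vec]]:
    """Top principal directions via power iteration on the covariance matrix,
    deflating the matrix after each component is found."""
    n, d = len(vectors), len(vectors[0])
    mean = [sum(col) / n for col in zip(*vectors)]
    centered = [[x - m for x, m in zip(v, mean)] for v in vectors]
    cov = [[sum(r[i] * r[j] for r in centered) / n for j in range(d)] for i in range(d)]
    rng = random.Random(seed)
    components: List[Vec] = []
    for _ in range(n_components):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for _ in range(iters):
            w = [sum(cov[i][j] * v[j] for j in range(d)) for i in range(d)]
            norm = sum(x * x for x in w) ** 0.5
            v = [x / norm for x in w]
        eigval = sum(v[i] * cov[i][j] * v[j] for i in range(d) for j in range(d))
        # Remove the found direction so the next power iteration finds the next one.
        cov = [[cov[i][j] - eigval * v[i] * v[j] for j in range(d)] for i in range(d)]
        components.append(v)
    return mean, components


def pca_transform(v: Vec, mean: Vec, components: List[Vec]) -> Vec:
    """Project a key (or a query f(c)) onto the learned principal directions."""
    centered = [x - m for x, m in zip(v, mean)]
    return [sum(c * x for c, x in zip(comp, centered)) for comp in components]


rng = random.Random(0)
scales = [3.0, 1.0, 0.3, 0.1]  # a toy 4-dim distribution with decreasing variance
data = [[s * rng.gauss(0.0, 1.0) for s in scales] for _ in range(500)]
mean, comps = pca_fit(data, n_components=2)
reduced_keys = [pca_transform(v, mean, comps) for v in data]  # 4 dims -> 2 dims
```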

Putting it All Together

Based on the analysis in §4, in this section we combine the best practices for adaptive retrieval, datastore pruning, and dimension reduction and assess their joint performance. We select the retrieval pruning rate $r$, the datastore pruning rate $n$, and the reduced dimension $d$ on the validation set so that they achieve the largest speed-up at a cost of at most 0.1 perplexity relative to vanilla kNN-LM. (Since adaptive retrieval uses part of the validation data to train the retrieval adaptor network, we select $r$ separately on its own held-out split and then combine it with the others.) We report the results on the test set.
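The selection rule amounts to a filter-then-maximize step over validation results, sketched below. The numbers are made up purely for illustration and are not results from this paper; in the text, the rule is applied to $r$, $n$, and $d$, with $r$ chosen on the adaptor's own held-out split. The function name is hypothetical.

```python
from typing import List, Optional, Tuple

Candidate = Tuple[str, float, float]  # (setting, validation ppl, tokens per second)


def select_fastest(candidates: List[Candidate], base_ppl: float,
                   max_increase: float = 0.1) -> Optional[Candidate]:
    """Fastest candidate whose validation perplexity is at most base_ppl + max_increase."""
    feasible = [c for c in candidates if c[1] <= base_ppl + max_increase]
    return max(feasible, key=lambda c: c[2]) if feasible else None


# Made-up validation results for three settings of one knob (e.g. the PCA dimension d).
candidates = [("d=512", 10.00, 100.0), ("d=256", 10.05, 150.0), ("d=128", 11.00, 300.0)]
print(select_fastest(candidates, base_ppl=10.0))  # ('d=256', 10.05, 150.0)
```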

Table 3 shows the results on the test sets of WikiText-103 and Law-MT, where we assess the combination of all three strategies; the separate performance of each strategy is included for reference. On WikiText-103, adaptive retrieval removes 50% of the retrieval operations and achieves a nearly 2x speed-up, while greedy merging prunes 40% of the datastore at the cost of 0.2 perplexity points. PCA dimension reduction yields a minor perplexity improvement over kNN-LM while being 3.6x faster. Combining all three techniques yields perplexity comparable to vanilla kNN-LM (16.67 vs. 16.65) with a 6.6x speed-up (1835 vs. 277 tokens/s).

Unlike WikiText-103, where the datastore is constructed from the data used to train the pretrained NLM, in the Law-MT domain adaptation setting the datastore represents domain-specific knowledge that the pretrained NLM never sees during training; it is thus critical for good perplexity, as the large ppl gain that the datastore offers (94 points) indicates. From another perspective, though, the large improvement from datastore retrieval makes it difficult to remove retrieval operations adaptively. This is reflected in the oracle comparison: $p_{\text{kNN}}(w|c) \geq p_{\text{NLM}}(w|c)$ 76% of the time, compared to 39% on WikiText-103. Our learned retrieval adaptor is able to remove only 10% of the retrieval operations, at a cost of 0.1 ppl points. Greedy merging prunes 40% of the datastore while losing 0.7 ppl points. We suspect that the Law-MT datastore is more vulnerable to pruning than the WikiText-103 one because of its smaller size (19M vs. 103M) and correspondingly lower redundancy. Interestingly, consistent with WikiText-103, PCA dimension reduction yields a 1-point ppl gain over vanilla kNN-LM while achieving a 3.3x speed-up. This suggests that a PCA transformation may produce a new vector space that is better suited for defining $p_{\text{kNN}}$ with $L^2$ distances; we leave investigating the underlying reasons to future work. Finally, combining the three techniques allows kNN-LM to be evaluated 5.4x faster while even obtaining better perplexity.

Implications and Future Work

In this paper, we explore several ways to improve the efficiency of the $k$-nearest neighbors language model, achieving up to 6x speed-up while attaining comparable performance. For future work, it would be interesting to explore features from the datastore side to better decide when to retrieve. The gap between retrieval-based NLMs and parametric NLMs may also be further reduced by combining more optimized indexing methods with the approaches in this paper.

Acknowledgements

We thank the anonymous reviewers for their comments, Emma Strubell, André Martins, Pedro Martins, and Uri Alon for helpful advice and discussions, and Wanzhen He for help with figure plotting. This material is based upon work supported by the National Science Foundation under Grant 1815287.

References

Appendix A Experimental Setup Details

The interpolation hyperparameter $\lambda$ is tuned over the range [0.1, 0.9] with a step of 0.05 on the validation split of each dataset separately. As a result, $\lambda = 0.25$ for WikiText-103 and $\lambda = 0.9$ for Law-MT.
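A minimal sketch of this grid search, assuming the gold-token probabilities under both models on the validation split are already available (the inputs here are hypothetical, and the function name is illustrative):

```python
import math
from typing import List


def tune_lambda(p_knn: List[float], p_nlm: List[float], grid: List[float]) -> float:
    """Return the grid value of lambda with the lowest validation perplexity under Eq. 1."""
    def ppl(lam: float) -> float:
        nll = -sum(math.log(lam * pk + (1 - lam) * pn) for pk, pn in zip(p_knn, p_nlm))
        return math.exp(nll / len(p_knn))
    return min(grid, key=ppl)


grid = [round(0.1 + 0.05 * i, 2) for i in range(17)]  # 0.1, 0.15, ..., 0.9
# Hypothetical gold-token probabilities: retrieval helps a lot in the first case only.
print(tune_lambda([0.9] * 5, [0.1] * 5, grid))   # 0.9
print(tune_lambda([0.01] * 5, [0.5] * 5, grid))  # 0.1
```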

A.2 Adaptive Retrieval

We use the same adaptive retrieval hyperparameters for both datasets, validated on the WikiText-103 dev set. The retrieval adaptor is an MLP network with 1 input layer, 4 hidden layers, and 1 output layer. Each layer is a linear transformation followed by a ReLU non-linearity and a dropout layer with dropout probability 0.2, except for the output layer, where the hidden units are transformed to 2 dimensions followed by a log softmax to produce $\log\lambda$ and $\log(1-\lambda)$. Each layer has 128 hidden units. Before passing the input features to the MLP, we transform each scalar feature (all features except $f(c)$) into an $m$-dimensional vector, where $m = \dim(f(c))/n$ and $n$ is the number of scalar feature types. This balances the context vector feature against the other features, since the scalar features together then occupy as many dimensions as $f(c)$. The scalar-feature transformation is performed with a small Linear(in, out)-ReLU-Linear(out, out) network. We also tried an LSTM (Hochreiter and Schmidhuber, 1997) to capture temporal relations, but found that it led to very unstable training and failed to converge; we note that the MLP is also faster at test time and that the $f(c)$ feature already captures temporal correlations between tokens. The coefficient $a$ of the $L^1$ regularizer is tuned on the WikiText-103 validation set among $\{0.01, 0.05, 0.1, 0.2, 0.5, 1\}$ and fixed at 0.05 for both WikiText-103 and Law-MT. The model is trained using the Adam optimizer (Kingma and Ba, 2015) with a learning rate of 0.0005, and the checkpoint with the best validation perplexity at 50% pruning is saved.

Appendix B Ablation Analysis

We analyze the effect of different input features to the retrieval adaptor by removing subsets of features. We report perplexities at 50% retrieval pruning, since the choice of features has only a marginal effect on evaluation speed. Results on the WikiText-103 test set are shown in Table 4. Using all features together produces the best results, while perplexity is relatively robust to the removal of any single feature. In our experiments (§4 and §5), we drop the $\log\text{freq}$ feature and use the others to save memory, while achieving perplexities comparable to using all features.

Effect of learnable interpolation weights:

In the adaptive retrieval analysis (§4.1), we observed gains of the learned retrieval adaptor over a random baseline at different fractions of retrieval pruning. However, these gains may come from two sources: (1) the automatically identified pruning masks, as opposed to random masks, and (2) the learned interpolation weights on the remaining retrievals, as opposed to the constant weight used by the random baseline. To separate the two effects, we perform an ablation study comparing (1) random mask, constant weight (“Random” in §4.1), (2) random mask, learned weight, where the weights come from the trained retrieval adaptor, (3) learned mask, constant weight, and (4) learned mask, learned weight (“Adaptor” in §4.1). As shown in Figure 6, “learned mask, learned weight” performs best. While the automatically learned weights bring minor gains (“Random mask, learned weights”), most of the advantage comes from the learned pruning strategy alone, even with constant weights (“Learned mask, constant weights”).
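The four ablation conditions can be expressed as different ways of producing per-token interpolation weights. In this sketch, which is an illustration rather than the authors' code, the constant weight is assumed to be the tuned $\lambda$, and the learned mask prunes the tokens with the smallest adaptor weights, as in §4.1. The function name and inputs are hypothetical.

```python
import random
from typing import List


def ablation_lambdas(learned: List[float], const_lam: float, fraction: float,
                     mask: str = "learned", weight: str = "learned",
                     seed: int = 0) -> List[float]:
    """Per-token weights for one cell of the 2x2 ablation:
    mask in {"learned", "random"} x weight in {"learned", "constant"}."""
    n = len(learned)
    n_prune = int(round(fraction * n))
    if mask == "learned":
        pruned = set(sorted(range(n), key=lambda i: learned[i])[:n_prune])
    else:
        pruned = set(random.Random(seed).sample(range(n), n_prune))
    lams = []
    for i in range(n):
        if i in pruned:
            lams.append(0.0)  # retrieval skipped: p(w|c) = p_NLM(w|c)
        else:
            lams.append(learned[i] if weight == "learned" else const_lam)
    return lams


learned = [0.05, 0.6, 0.1, 0.4]  # hypothetical adaptor outputs
print(ablation_lambdas(learned, 0.25, 0.5, mask="learned", weight="constant"))  # [0.0, 0.25, 0.0, 0.25]
```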