Training Language Models with Memory Augmentation
Zexuan Zhong, Tao Lei, Danqi Chen
Introduction
Memory augmentation has become a remarkable approach to enhance language modeling performance without significantly increasing the number of parameters or the amount of computation. By accessing memory units such as a neural cache of recent inputs Merity et al. (2017); Grave et al. (2017b) and an external look-up table Khandelwal et al. (2020), a memory-augmented language model (LM) enjoys increased memorization capacity and sets new state-of-the-art records in various language modeling benchmarks.
A major limitation of existing approaches, however, is that the memory units are either introduced at testing time Grave et al. (2017b, a); Khandelwal et al. (2020) or taken from a separately trained model Yogatama et al. (2021). As a consequence, they are not directly optimized during the training process, resulting in a missed opportunity to achieve even stronger results. In this paper, we present Trime (Training with In-batch Memories), a novel yet simple training approach that is well-suited for memory augmentation in language modeling. (Trime can also be interpreted as three types of memories, as we elaborate later in the paper.) Our approach makes two major departures from standard language model training:
Training objective
Inspired by contrastive representation learning, we propose a training objective that directly leverages in-batch examples as accessible memory (Figure 1). Our training objective is closely connected to neural cache models Grave et al. (2017b); Merity et al. (2017) and nearest-neighbor language models Khandelwal et al. (2020), where the next-token probabilities are calculated by comparing encoder outputs against static token embeddings and memory representations. However, previous work only incorporates memories at testing time, while we use them during both training and testing.
In-batch memory construction
With this training objective in mind, the key challenge is how to construct memories effectively during training while keeping training efficient. We identify three types of memories that can be leveraged at testing time and have been explored in the literature: (a) local memory denotes the words that appear in the recent past and are modeled using attention Vaswani et al. (2017); (b) long-term memory denotes long-range context from the same document that cannot be directly accessed due to the limit of input length (the term may have different interpretations in other contexts; following previous work Martins et al. (2022); Wu et al. (2022), we use it to refer to long-range context in modeling long sequences); (c) external memory is used to store the entire training set or any additional corpus Khandelwal et al. (2020); Borgeaud et al. (2021).
To better leverage these memories at testing time, we devise new data batching strategies to improve the construction of training memories (§4). By packing consecutive segments from the same document into one training batch, our model can access long-term memories beyond the attention context. We pack segments from other documents that have high lexical overlap as a proxy for all external memory units. Importantly, these working memories are generated on the fly during training, allowing us to back-propagate to all memory representations.
We instantiate Trime in three models by considering different sets of training and testing memories (Table 1) and evaluate them on multiple language modeling and machine translation benchmarks. We highlight our results as follows:
We first show that we can simply optimize a language model using our training objective without long-term or external memory. Without any other modifications, we demonstrate that a 247M Transformer-based model improves perplexity from 18.70 to 17.76 on WikiText-103 Merity et al. (2017) with negligible overhead. This model can be viewed as a simple replacement for vanilla language models.
By training with consecutive segments in the same batch, our approach is capable of leveraging very long context at testing time: up to 15k-25k tokens on WikiText-103 and Enwik8 Mahoney (2009). Our approach achieves performance at least competitive with previous works Dai et al. (2019); Martins et al. (2022); Ji et al. (2022) that modify the Transformer architecture to incorporate memories from previous segments, yet our solution is conceptually simpler and computationally cheaper.
Finally, we train language models by incorporating all other segments in the same batch as memories. Our model works better with a large datastore at testing time and improves over the kNN-LM model Khandelwal et al. (2020) by reducing the test perplexity from 16.23 to 15.41 on WikiText-103. We also demonstrate significant improvements over the kNN-MT baseline Khandelwal et al. (2021) on an IWSLT’14 De-En machine translation task.
In summary, we propose a simple approach Trime for optimizing language models with memory augmentation and demonstrate consistent and significant gains in multiple experimental settings. Our approach only uses memories at the final prediction step, and hence adds little computational overhead and can be combined with different model architectures such as recurrent networks and other attention variants Lei (2021); Dai et al. (2019); Rae et al. (2020). We hope that our work can encourage the research community to think about better training objectives for language models, given their significant societal impacts Brown et al. (2020); Chowdhery et al. (2022); Zhang et al. (2022).
Preliminaries
2 Memory Augmentation
We consider memory as a set of context-target pairs $\{(c_j, x_j)\}$ following Grave et al. (2017b); Khandelwal et al. (2020). These context-target pairs can be aggregated to obtain the next-token probability weighted by the similarity between hidden representations. (Other memory-augmented models differ in how the memory is incorporated, e.g., within attention, and may retrieve texts of different granularity as memory Guu et al. (2020); Borgeaud et al. (2021).) We formalize three types of context-target memories as follows:
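Local memory
The local memory is simply the preceding tokens in the same input. Specifically, for $c_t = x_1, \ldots, x_{t-1}$, it is defined as:
$$\mathcal{M}_{\text{local}}(c_t) = \{(c_j, x_j)\}_{1 \le j < t}.$$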
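To make the definition concrete, the sketch below enumerates the local memory $\mathcal{M}_{\text{local}}(c_t) = \{(c_j, x_j)\}_{1 \le j < t}$ for one segment, representing each context $c_j$ by its raw token prefix $x_1, \ldots, x_{j-1}$. The function name `local_memory` and the tuple representation are illustrative assumptions; in the model, each context is represented by a hidden vector rather than by its tokens.

```python
from typing import List, Tuple

Pair = Tuple[Tuple[str, ...], str]  # (context c_j as a token prefix, target x_j)


def local_memory(tokens: List[str], t: int) -> List[Pair]:
    """Return M_local(c_t) = {(c_j, x_j) : 1 <= j < t} for one segment.

    `tokens` holds x_1..x_n as a 0-indexed Python list, and `t` is the
    1-indexed position of the token being predicted. Context c_j is the
    prefix x_1..x_{j-1} (empty for j = 1), and its target is x_j.
    """
    return [(tuple(tokens[: j - 1]), tokens[j - 1]) for j in range(1, t)]


segment = ["the", "cat", "sat", "on", "the", "mat"]
for context, target in local_memory(segment, t=4):
    print(context, "->", target)
```

The memory grows by one pair per position, so predicting the last token of a segment of length $L$ can draw on $L - 1$ pairs.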
Grave et al. (2017b) use the local memory at testing time, denoted as the “continuous cache” model. However, it has been argued to be less effective for Transformer-based models, because they can already learn to leverage recent tokens in the self-attention layers Khandelwal et al. (2020). Interestingly, we show that using local memory is still beneficial if we consider it during training.
Long-term memory
Long-term memory denotes long-range context from the same document that cannot be directly accessed by attention. For example, if a document contains 10K tokens, only a short segment of text (e.g., 100 to 3K tokens) can be fed into a Transformer model, because the complexity scales quadratically with the input length. Formally, we divide a document into consecutive segments $s^{(1)}, \ldots, s^{(T)}$, where a segment $s^{(i)}$ contains $L$ contexts $s^{(i)} = \{c^{(i)}_1, \ldots, c^{(i)}_L\}$. The long-term memory for $c^{(i)}_t$ is:
$$\mathcal{M}_{\text{long}}(c^{(i)}_t) = \{(c^{(k)}_j, x^{(k)}_j)\}_{1 \le k < i,\; 1 \le j \le L}.$$
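A minimal sketch of this segmentation, assuming a fixed segment length and 0-indexed segments in code (the math above is 1-indexed). Each pair $(c^{(k)}_j, x^{(k)}_j)$ is identified here by the hypothetical triple `(k, j, token)` instead of a hidden representation; the function names are illustrative only.

```python
from typing import List, Tuple


def split_segments(document: List[str], seg_len: int) -> List[List[str]]:
    """Divide a document into consecutive segments of length seg_len.

    The final segment may be shorter if the document length is not a
    multiple of seg_len.
    """
    return [document[i : i + seg_len] for i in range(0, len(document), seg_len)]


def long_term_memory(segments: List[List[str]], i: int) -> List[Tuple[int, int, str]]:
    """Return the long-term memory for any context in segment i (0-indexed).

    Every earlier segment k < i contributes all of its positions. Tokens that
    precede the current position inside segment i belong to the local memory
    instead, so they are not included here.
    """
    return [(k, j, tok) for k in range(i) for j, tok in enumerate(segments[k])]


doc = "a b c d e f g".split()
segs = split_segments(doc, seg_len=3)  # [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']]
print(long_term_memory(segs, 2))
```

Note that the long-term memory is shared by every context in the same segment, which is what makes it cheap to reuse across positions.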
Previous works Dai et al. (2019); Rae et al. (2020); Martins et al. (2022); Ji et al. (2022); Wu et al. (2022); Lei (2021) leverage hidden representations from previous segments with modified Transformer architectures to learn long-range dependencies. Our approach does not modify the model architecture and is compatible with these neural architectures. Note that the continuous cache can be naturally extended to long-term memory, as we show in our experiments. The original continuous cache work was applied to LSTMs on long sequences: since the cost of an LSTM scales linearly with sequence length, there is no need to segment documents.
External memory
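Finally, external memory assumes a large corpus $\mathcal{D}$, and the external memory set can be defined as:
$$\mathcal{M}_{\text{ext}} = \{(c_j, x_j) \in \mathcal{D}\}.$$
$\mathcal{D}$ can simply be the training corpus, or a domain-specific corpus when the testing domain shifts (§5.3). Note that $\mathcal{M}_{\text{ext}}$ is usually several orders of magnitude larger than the previous two types (e.g., $|\mathcal{M}_{\text{ext}}| \approx 10^8$ when $\mathcal{D}$ is the 103M-token WikiText-103 training set); accessing all the memories is computationally expensive and requires approximate nearest neighbor search Johnson et al. (2019).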
Training with In-batch Memories
In this section, we propose a new training approach Trime for language model training. Compared to standard language model training, our training objective assumes a set of training memories $\mathcal{M}_{\text{train}} = \{(c_j, x_j)\}$. We differentiate training memories from testing memories, as they are constructed on the fly during training and may deviate from the testing memories used during inference. Importantly, the training memories are constructed from the same training batch, which enables back-propagating the training signal to the current hidden representation as well as all the memory representations. We will discuss how to construct training memories in the next section (§4) and only discuss the training objective in a general form here.
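Our training objective is illustrated in Figure 1. Given a memory set $\mathcal{M}_{\text{train}}$ and a context $c$, Trime defines the next-token probability distribution as:
$$P(w \mid c) \propto \exp\big(E_w^\top f_\theta(c)\big) + \sum_{(c_j, x_j) \in \mathcal{M}_{\text{train}}:\; x_j = w} \exp\big(\text{sim}(g_\theta(c), g_\theta(c_j))\big),$$
where the normalizer sums the first term over all tokens $w$ in the vocabulary $\mathcal{V}$ and the second term over all pairs in $\mathcal{M}_{\text{train}}$. The training loss for a context-target pair $(c_t, x_t)$ is $-\log P(x_t \mid c_t)$.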
Here, $f_\theta(c)$ is the output representation of a Transformer model and $E_w$ is the embedding of token $w$. $g_\theta(\cdot)$ denotes the representations used to compute the similarity between $c$ and all the contexts $c_j$ in the memory $\mathcal{M}_{\text{train}}$. It is possible to simply take $g_\theta = f_\theta$; however, we find that taking $g_\theta(c)$ to be the input of the final feed-forward layer in the Transformer works better, which is consistent with the observation in Khandelwal et al. (2020). In addition, $\text{sim}(\cdot, \cdot)$ is a similarity function; in our preliminary experiments, the scaled dot product Vaswani et al. (2017), $\text{sim}(q, k) = q^\top k / \sqrt{d}$ with $d$ the representation dimension, led to stable training and better performance.
This training objective can be viewed as a contrastive loss Hadsell et al. (2006): for a context-target pair $(c_t, x_t)$, the goal is to align the query representation $f_\theta(c_t)$ (and $g_\theta(c_t)$) with the static token representation $E_{x_t}$ and with the contextualized representations that share the same next token, i.e., $g_\theta(c_j)$ for $x_j = x_t$. Our objective handles rare words nicely: if $x_t$ does not appear in the training memory, the objective falls back to aligning $f_\theta(c_t)$ with only the word embedding $E_{x_t}$. Similar to the vanilla training loss (Eq. 1), our Trime loss minimizes the negative log-likelihood of the next token, and all the parameters $\theta$ and $E$ are updated during training.
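For illustration, here is a minimal pure-Python sketch of the Trime negative log-likelihood for a single context, assuming the scaled dot-product similarity. Vectors are plain lists, the vocabulary is a dict from tokens to embeddings $E_w$, and the memory is given as parallel lists of keys $g_\theta(c_j)$ and targets $x_j$. The function name and argument layout are our own illustrative choices, not the authors' Fairseq implementation. Both the normalizer and the numerator use log-sum-exp for numerical stability.

```python
import math
from typing import Dict, List, Sequence


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def _logsumexp(scores: List[float]) -> float:
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def trime_nll(
    f_c: Sequence[float],                 # f_theta(c): output representation
    g_c: Sequence[float],                 # g_theta(c): query for memory similarity
    target: str,                          # gold next token x_t
    emb: Dict[str, Sequence[float]],      # static token embeddings E_w
    mem_keys: List[Sequence[float]],      # g_theta(c_j) for each memory entry
    mem_targets: List[str],               # x_j for each memory entry
) -> float:
    """Return -log P(target | c) under the Trime objective (single example)."""
    scale = math.sqrt(len(g_c))
    token_scores = {w: _dot(e, f_c) for w, e in emb.items()}
    mem_scores = [_dot(g_c, k) / scale for k in mem_keys]
    # Normalizer: all |V| token terms plus all |M| memory terms.
    log_z = _logsumexp(list(token_scores.values()) + mem_scores)
    # Numerator: the target's embedding term plus every memory entry with
    # x_j == target. If the target never appears in memory, only the
    # embedding term remains (the rare-word fallback described above).
    matching = [s for s, x in zip(mem_scores, mem_targets) if x == target]
    log_num = _logsumexp([token_scores[target]] + matching)
    return log_z - log_num


emb = {"a": [1.0, 0.0], "b": [0.0, 1.0]}
query = [1.0, 0.0]
no_memory = trime_nll(query, query, "a", emb, [], [])
with_memory = trime_nll(query, query, "a", emb, [[1.0, 0.0]], ["a"])
print(no_memory, with_memory)  # the supporting memory entry lowers the loss
```

With an empty memory the function reduces to ordinary softmax cross-entropy, which is one way to see that Trime generalizes the vanilla objective rather than replacing it.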
Our training objective is also inspired by the success of contrastive learning in dense retrieval Karpukhin et al. (2020). As we will show in §6, it helps retrieve contexts that share the same next token when the set of testing memories is large. Our objective is also closely connected to the objective used in Grave et al. (2017b); Khandelwal et al. (2020), which linearly interpolates the distribution of standard language modeling and a distribution defined by the cache or external datastore, e.g., $P(w \mid c) = (1 - \lambda) P_{\text{LM}}(w \mid c) + \lambda P_{\text{kNN}}(w \mid c)$. Our work differs from previous works in that we use our objective during training (and testing), while they only used theirs at testing time; the key is how to construct training memories, which we elaborate next. (Grave et al. (2017b) described a “global normalization” variant, which is similar to our objective; however, they only used it at testing time and only considered short-term contexts in calculating the distribution. Other works Merity et al. (2017); See et al. (2017) trained a pointer network with a learned gating component for the interpolation; we attempted training with a similar objective earlier and found it to perform worse than our current objective.)
Adaptation to Different Memories
We are interested in incorporating the three types of memories defined in §2.2 and their combinations at testing time. The testing objective is basically the same as the training objective defined above, except that we take the testing memories $\mathcal{M}_{\text{test}}$ as a combination of $\mathcal{M}_{\text{local}}$, $\mathcal{M}_{\text{long}}$, and $\mathcal{M}_{\text{ext}}$. As $\mathcal{M}_{\text{ext}}$ can be very large, we approximate it by retrieving the top-$K$ closest terms to $g_\theta(c)$. We tune a temperature term to adjust the weight of the memory component (see Appendix A for details).
Notation
Throughout this section, we use $L$ to denote the segment length, $B$ to denote the total number of segments in one training batch, and $m$ to denote the number of consecutive segments from each document in the batch. Correspondingly, each batch contains $B/m$ different documents. $L$, $B$, and $m$ are hyper-parameters that we choose for training, and their values depend on which memories we consider during inference.
A key challenge is that the testing memories can be very large (e.g., $|\mathcal{M}_{\text{long}}| \approx 10^4$ and $|\mathcal{M}_{\text{ext}}| \approx 10^8$ in our experiments), and it is computationally infeasible to keep the training memories the same as the testing memories. In the following, we discuss three ways of constructing training memories and data batching, aiming to reduce the discrepancy between training and testing. Along the way, we also present three major model instantiations, TrimeLM, TrimeLM_long, and TrimeLM_ext (Table 1), which combine the training strategies with different sets of testing memories.
1 Local Memory
$\mathcal{M}_{\text{local}}$ only considers the previous tokens in the same segment, so we can simply use $\mathcal{M}_{\text{train}} = \mathcal{M}_{\text{local}}$. As shown in Fig. 2(a), this requires essentially no modifications to standard language model training: all we need is to replace the training objective of Eq. 1 with our objective, incorporating $\mathcal{M}_{\text{local}}(c_t)$ in the memory during both training and testing. The computational overhead is negligible compared to running neural encoders on the segment itself. We denote this model as TrimeLM, which can be viewed as a lightweight replacement for vanilla language models. As we will show in the experiments, simply incorporating local memory provides a notable gain on multiple LM benchmarks, showing the effectiveness of training with memories explicitly.
2 Long-term Memory
In order to enable long-term memory augmentation, we pack multiple consecutive segments from the same document into a training batch (i.e., $m > 1$). For a context-target pair in the training batch, its accessible memory $\mathcal{M}_{\text{train}}$ includes tokens from the previous segments of the same document as well as the preceding tokens in the same segment. Figure 2(b) illustrates the training batch construction and the training memory for a given token. At testing time, we can use a much longer context: we simply enumerate the number of segments used in $\mathcal{M}_{\text{long}}$ and choose the optimum based on the development set.
We denote this model as TrimeLM_long. It shares a similar motivation with many previous works that aim to leverage memory from previous segments through attention recurrence Dai et al. (2019); Ji et al. (2022) or memory compression Rae et al. (2020); Martins et al. (2022); Wu et al. (2022). However, our solution deviates significantly from previous approaches. First, previous works need to store the hidden representations (of every layer) from previous segments and modify the self-attention layers to incorporate them, whereas our approach does not modify the architecture and only uses the outputs from the last layer. Second, previous works use stale memory representations and do not back-propagate gradients to the representations of previous segments, whereas our batching method enables gradient propagation to the memory and previous segments (we also attempted using segments from previous training batches as stale representations and did not find any improvement in preliminary experiments). As we will show in the experiments, our approach is competitive with previous works while being conceptually simpler and computationally cheaper.
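One way to picture in-batch training memories is as a boolean mask over all pairs of positions in a batch: similarities are computed between every query position and every key position, and entries outside $\mathcal{M}_{\text{train}}$ are masked out. The sketch below builds such a mask under stated assumptions: the $B \times L$ positions are flattened segment by segment, and the $B/m$ documents occupy contiguous groups of $m$ consecutive segments. The function name and layout are hypothetical, and the nested Python loops are for clarity only; a practical implementation would use tensor operations.

```python
from typing import List


def training_memory_mask(B: int, L: int, m: int, use_long_term: bool) -> List[List[bool]]:
    """Boolean mask over the B*L in-batch positions (flattened segment-major).

    mask[q][k] is True when the context-target pair at position k belongs to
    the training memory of the context at position q. Documents are assumed
    to occupy contiguous groups of m consecutive segments.
    """
    assert B % m == 0, "B must be divisible by m"
    n = B * L
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        seg_q, pos_q = divmod(q, L)
        for k in range(n):
            seg_k, pos_k = divmod(k, L)
            # Local memory: earlier positions in the same segment.
            local = seg_k == seg_q and pos_k < pos_q
            # Long-term memory: any position in an earlier segment of the same document.
            earlier = seg_k // m == seg_q // m and seg_k < seg_q
            mask[q][k] = local or (use_long_term and earlier)
    return mask


mask = training_memory_mask(B=4, L=2, m=2, use_long_term=True)
print([k for k, ok in enumerate(mask[7]) if ok])  # memory of the last position
```

Setting `use_long_term=False`, or `m=1`, leaves only the local memory used by TrimeLM; with `m > 1` the same mask gives the TrimeLM_long training memory.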
3 External Memory
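Finally, we consider external memory $\mathcal{M}_{\text{ext}}$. Since $\mathcal{M}_{\text{ext}}$ contains the context-target pairs in a large corpus such as the entire training set, we need to retrieve the top-$K$ pairs from $\mathcal{M}_{\text{ext}}$, ranked by $\text{sim}(g_\theta(c), g_\theta(c_j))$, through (approximate) similarity search (more details are given in §5.2).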
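The retrieval step can be sketched with an exact, brute-force search standing in for the approximate FAISS search used in the paper; it is only practical for small memories. Ranking by the raw dot product gives the same order as the scaled dot product, since dividing by $\sqrt{d}$ does not change the ranking. The function name `top_k_memory` is a hypothetical illustration.

```python
import heapq
from typing import List, Sequence, Tuple


def top_k_memory(
    query: Sequence[float],
    keys: List[Sequence[float]],
    targets: List[str],
    k: int,
) -> List[Tuple[int, float, str]]:
    """Exact top-k search by dot product, returned as (index, score, target), best first."""
    scores = [sum(q * v for q, v in zip(query, key)) for key in keys]
    best = heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)
    return [(i, scores[i], targets[i]) for i in best]


keys = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]
targets = ["a", "b", "c"]
print(top_k_memory([1.0, 0.0], keys, targets, k=2))
```

An approximate index trades a small loss in recall for sub-linear search time, which is what makes a memory of roughly $10^8$ entries usable at test time.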
Since the retrieved contexts at testing time are expected to be similar to the query context, we propose a simple heuristic for constructing training memories: we pack segments that have large lexical overlap into the same batch using BM25 scores Robertson and Zaragoza (2009). Specifically, we start with a single segment and repeatedly add the segments with the highest BM25 scores into the same batch (Appendix B). A high BM25 score indicates that two segments have high lexical overlap, so such segments can serve as a good proxy for nearest neighbors in the external memory, which improves our model predictions at testing time. $\mathcal{M}_{\text{train}}$ contains all tokens from other segments as well as the previous tokens in the same segment (Figure 2(c)). We set $m = 1$ during training, since many segments from the same document already tend to have high lexical overlap, and denote this model by TrimeLM_ext.
In practice, when considering tokens from both the current segment and other segments in the batch, we observe that the model tends to rely more on local memory and ignore other segments. To encourage the use of information from other segments, we exclude the local memory from $\mathcal{M}_{\text{train}}$ with probability $p$ during training (see Appendix H for the choice of $p$). This significantly improves performance when the model is evaluated with a large set of external memory.
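The sketch below builds a TrimeLM_ext-style training-memory mask with $m = 1$: every position may attend to all positions in other segments of the batch, and its own segment's preceding tokens are dropped with probability $p$. The granularity of the random draw (once per segment here) is an assumption made for illustration, as is the flattened segment-major layout and the function name; the value of $p$ used in the paper is not restated here.

```python
import random
from typing import List, Optional


def ext_training_mask(
    B: int, L: int, p: float, rng: Optional[random.Random] = None
) -> List[List[bool]]:
    """Training-memory mask with m = 1 (each segment from a different document).

    mask[q][k] is True when position k is in the training memory of position q.
    Positions in other segments are always included; the preceding positions of
    q's own segment (its local memory) are excluded with probability p.
    Assumption: one random draw per segment.
    """
    rng = rng or random.Random()
    n = B * L
    # rng.random() is uniform on [0, 1), so this keeps local memory with probability 1 - p.
    keep_local = [rng.random() >= p for _ in range(B)]
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        seg_q, pos_q = divmod(q, L)
        for k in range(n):
            seg_k, pos_k = divmod(k, L)
            if seg_k != seg_q:
                mask[q][k] = True
            elif keep_local[seg_q] and pos_k < pos_q:
                mask[q][k] = True
    return mask


mask = ext_training_mask(B=2, L=2, p=0.5, rng=random.Random(0))
```

Dropping the local memory forces the gradient signal for some positions to flow only through cross-segment entries, which is the behavior the model later relies on when retrieving from $\mathcal{M}_{\text{ext}}$.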
Experiments
We evaluate our approach on two popular language modeling benchmarks: WikiText-103 Merity et al. (2017), Enwik8 Mahoney (2009), and a machine translation benchmark: IWSLT’14 De-En. We also evaluate domain-adaptation performance on the BooksCorpus dataset Zhu et al. (2015).
WikiText-103 is a word-level language modeling dataset consisting of 103M training tokens. We evaluate two model configurations: one uses a 247M Transformer model with the segment length of Baevski and Auli (2019) (see Appendix C), and the other uses a 150M Transformer model with a segment length $L = 150$.
Enwik8 is a character-level language modeling dataset that contains a total of 100M characters. We use a 12-layer Transformer model with segment length $L = 512$ (see Appendix D for the hidden dimension and other configurations).
BooksCorpus is a word-level language modeling dataset. We build our own train/dev/test splits, which consist of 100M/250K/250K tokens. On this dataset, we evaluate the models trained on WikiText-103 to study how our approach can adapt to a new domain without re-training.
IWSLT’14 De-En is a machine translation task, which consists of 170K translation pairs. We use a Transformer encoder-decoder model. See Appendix C for how we adapt our approach to the machine translation task.
See Appendix C for data statistics and task setups and Appendix D for model configurations.
2 Training and Inference Details
We implement our approach using the Fairseq library Ott et al. (2019). For TrimeLM_long and TrimeLM_ext, we tune the number of segments used in $\mathcal{M}_{\text{long}}$ on the development set during evaluation. Our TrimeLM_ext model requires building a large datastore at testing time, and we use the FAISS library Johnson et al. (2019) for approximate nearest neighbor search (details in Appendix D).
We first train our model with the standard LM objective (Eq. 1) for the first 5% of updates. Without this warmup stage, we observe that training is unstable, probably due to a large variance in the estimated distributions. We use different memories when evaluating different instantiations of Trime, as shown in Table 1. We find that when a large set of external memory is considered during inference, the performance can be improved by linearly interpolating the output distribution and a distribution over the memory, similarly to kNN-LM Khandelwal et al. (2020). Thus, we apply an additional linear interpolation to our output probability distribution when considering external memory (see Appendix A for details).
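The warmup rule can be expressed as a one-line schedule check, sketched below under the assumptions that steps are 0-indexed and that the 5% boundary is computed with integer arithmetic (to avoid floating-point surprises at the boundary). The function name is hypothetical; in a training loop it would select between the standard LM loss (Eq. 1) and the Trime loss.

```python
def use_trime_objective(step: int, total_steps: int, warmup_percent: int = 5) -> bool:
    """Return False during the warmup updates (standard LM loss), True afterwards."""
    warmup_steps = total_steps * warmup_percent // 100
    return step >= warmup_steps


total_steps = 1000
print(use_trime_objective(49, total_steps), use_trime_objective(50, total_steps))
```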
3 Results: Language Modeling
We first compare our TrimeLM model, which only uses local memory during training and testing. Table 2 shows that adding a continuous cache during inference can improve the perplexity of the vanilla Transformer from 18.70 to 18.26, and our TrimeLM further improves the perplexity to 17.76. These results suggest that even though the attention mechanism can “see” local context, using local memory during both training and testing can still improve model performance. TrimeLM has no computational overhead compared to the vanilla LM (indicated by the “speed” column), making it a simple and better replacement for vanilla language models. Similar trends can be observed in Table 3 and Table 4 (25.87 vs. 25.60 and 1.16 vs. 1.12). The improvement is much smaller, though, due to a much smaller segment length $L$. More analysis is given in Appendix G.
We then examine our TrimeLM_long model, which is trained with the data batching method described in §4.2. As shown in Table 3 and Table 4, TrimeLM_long substantially improves vanilla Transformer models on both WikiText-103 and Enwik8 by leveraging long-range contexts at inference time. We find that the model achieves its best results when leveraging 15,000 tokens on WikiText-103 and 24,576 tokens on Enwik8, even though the segments used during training are much shorter ($L$ = 150 and 512, respectively). We also add a continuous cache to the vanilla Transformer model and find that it underperforms our model, demonstrating the importance of joint training using our approach.
Compared to previous methods that explicitly leverage hidden representations from previous segments Dai et al. (2019); Rae et al. (2020); Martins et al. (2022); Ji et al. (2022); Lei (2021), our approach achieves better or at least competitive performance. Unlike these approaches, which need to store the hidden representations of every layer and modify the model architecture, we only incorporate the outputs from the last layer, which requires less computation and GPU memory. Our approach is orthogonal and can be applied on top of these models. To verify this, we adapt our approach to SRU++ Lei (2021) (see details in Appendix E). As shown in the bottom block of Table 4, TrimeLM_long consistently improves over vanilla SRU++, outperforming previously reported results given the same model size.
Finally, our TrimeLM_ext model outperforms the kNN-LM model Khandelwal et al. (2020), which uses external memory only at testing time, improving the perplexity from 16.23 to 15.41 on WikiText-103 (Table 2). For a fair comparison with kNN-LM with continuous cache, we also evaluate a variant that does not use long-term memory (denoted by w/o $\mathcal{M}_{\text{long}}$), and its result is very close to that of the full model (15.55 vs. 15.41). Our results suggest that by using the contrastive loss and BM25 batching (§4.3), the model learns to better retrieve and leverage information from a large external memory.
Domain adaptation
We evaluate the domain-adaptation performance of Trime on BooksCorpus Zhu et al. (2015). We take models that are trained on WikiText-103 and evaluate them on BooksCorpus without any re-training or fine-tuning. As shown in Table 5, a vanilla Transformer model trained on WikiText-103 performs poorly on BooksCorpus. TrimeLM and TrimeLM_long significantly improve the performance, as they leverage local or long-term memory to adapt to the new domain. By building the external memory from BooksCorpus, both kNN-LM and TrimeLM_ext perform much better on BooksCorpus than the vanilla Transformer model, and TrimeLM_ext outperforms kNN-LM. This indicates that although the memory representations are optimized on one domain, our approach does not overfit, and building an external memory from the target-domain dataset enables the model to perform well under domain shift.
4 Results: Machine Translation
To showcase the generality of Trime to other generation tasks, we evaluate our approach on the IWSLT’14 De-En translation task. Since it is a sentence-level task with few repetitive tokens, we do not use any local or long-term memory ($\mathcal{M}_{\text{local}}$, $\mathcal{M}_{\text{long}}$). We denote our model as TrimeMT_ext.
As shown in Table 6, our approach improves the BLEU score of the vanilla Transformer and outperforms kNN-MT Khandelwal et al. (2021). This demonstrates that our approach can improve performance on other language generation tasks with different memory access.
Analysis
We conduct ablation studies and analysis to further understand individual components of our approach. Due to the limited computation budget, some experiments on WikiText-103 in this section are conducted with a small 7M Transformer model (8 layers, hidden dimension 128); the trends are generally similar for larger models (see Appendix D and Appendix F for details).
We first study how different data batching and memory construction strategies affect performance when different testing memories are used. We compare our three models (TrimeLM, TrimeLM_long, TrimeLM_ext) in Table 7. This ablation study clearly shows that packing consecutive segments, or segments with high BM25 scores, into the same training batch and constructing memories accordingly improves performance when the long-range and external memories are used. This demonstrates the importance of closing the gap between training and inference.
Leveraging long-range contexts
We study whether our model is able to handle a large long-term memory. As Figure 3 shows, our model can effectively handle long-range context (more than 10k tokens), which goes beyond the typical attention context. Compared to continuous cache Grave et al. (2017b, a), the improvement of our approach becomes larger when more long-term memory is incorporated. This suggests that our model is able to leverage long-range context much more effectively.
Additional analysis
We conduct more ablation studies and analysis in Appendix G, summarized as follows. (1) Our ablation studies show that the BM25 batching method and back-propagation to update memory representations are both important for our approach (Table 11). (2) TrimeLM is able to leverage local memory effectively to improve performance with different segment lengths (Table 12). (3) TrimeLM_ext outperforms kNN-LM in terms of top-K retrieval accuracy given the external memory set (Table 13). (4) We study the perplexity of tokens in different frequency groups and find that TrimeLM and TrimeLM_long achieve larger improvements on rare words, while TrimeLM_ext improves results across the board (Table 14).
Related Work
We have discussed continuous cache, kNN-LM, and models that leverage representations from long-range context in the previous sections. Yogatama et al. (2021) also aim to combine several types of memories by learning an adaptive gating function; however, their external memory uses a pre-trained vanilla language model. Borgeaud et al. (2021) demonstrate remarkable performance by augmenting LMs with an external datastore of trillions of tokens, and their datastore is built from chunks of text using off-the-shelf BERT embeddings Devlin et al. (2019). Our approach differs from prior works in the following aspects: (1) we update the memory representations through back-propagation from the end loss; (2) our model does not modify the base architecture; (3) we consider different types of memories in a unified framework. GNN-LM Meng et al. (2022) augments LMs with a graph neural network to aggregate information from retrieved items in external memory, which makes an orthogonal contribution to our paper.
Transformers for long inputs
A large body of research has investigated how to scale the self-attention mechanism to long contexts, either through sparse attention Liu et al. (2018); Child et al. (2019); Beltagy et al. (2020); Zaheer et al. (2020) or sub-quadratic-time attention Wang et al. (2020); Choromanski et al. (2020); Peng et al. (2021); Katharopoulos et al. (2020). See Tay et al. (2020) for a comprehensive survey of efficient Transformers. Our approach is orthogonal, as we only change the training objective and data batching to enable models to use large contexts during inference.
Memory-augmented models for downstream tasks
Prior works have also improved models for downstream tasks with a retrieval component, such as question answering Kumar et al. (2016); de Masson D’Autume et al. (2019); Karpukhin et al. (2020); Guu et al. (2020); Zemlyanskiy et al. (2021); de Jong et al. (2022); Chen et al. (2022); Izacard and Grave (2021); Singh et al. (2021), dialogue Fan et al. (2021), and other knowledge-intensive NLP tasks Lewis et al. (2020); Petroni et al. (2021). Notably, recent works de Jong et al. (2022); Chen et al. (2022) explore a similar idea for question answering and leverage in-batch memories to train memory representations for entity mentions or QA pairs, which are further incorporated into Transformers at a second stage.
Conclusion
In this work, we propose Trime, a training approach for language modeling. We present three model instantiations, TrimeLM, TrimeLM_long, and TrimeLM_ext: through carefully designed data batching and memory construction during training, we show that our models can leverage long-range contexts and external memory effectively at testing time. Our approach adds little computational overhead and does not modify model architectures, making it compatible with other neural models and techniques. For future work, we are interested in training Trime with large language models and applying it to other text generation tasks.
Limitations
We discuss limitations of our research as follows.
Despite the strong performance achieved by our approach when incorporating a large set of external memory, it also reduces inference efficiency due to the nearest neighbor search. For example, the model is slower when incorporating external memory. This issue becomes more severe as the external memory grows. Potential solutions to this issue include (1) constructing the memory at a coarser granularity (e.g., text blocks) Borgeaud et al. (2021); (2) compressing the external memory set and reducing the dimension of memory representations He et al. (2021).
We mainly experiment with Transformer-based models and additionally adapt our approach to SRU++ Lei (2021). We believe our approach is compatible with other architectures or techniques such as Transformer-XL Dai et al. (2019) and Compressive Transformer Rae et al. (2020). We plan to explore them as future work.
We evaluate our approach on machine translation to test the generality of Trime to other generation tasks. However, due to compute limitation, we only evaluate it on a small dataset (i.e., IWSLT’14), which consists of 4M tokens in the external memory. We leave the evaluation on larger machine translation datasets as future work.
Our paper mainly studies language modeling tasks and machine translation tasks. Although we believe our approach is compatible with all language generation tasks, how to adapt Trime to natural language understanding tasks such as text classification still remains an open question.
The biggest model we experimented with consists of 247M parameters due to our compute limit. The state-of-the-art auto-regressive LMs contain hundreds of billions of parameters Brown et al. (2020). We hope to see future efforts in scaling up our approach and evaluating the effectiveness on large LMs.
Ethical Considerations
Our proposed approach leverages external memory to achieve strong results on multiple language modeling benchmarks. In our experiments, we construct the external memory from the corpus on which the model is trained, although it can be constructed from any corpus. In general, we suggest that practitioners construct external memory from a public corpus, as retrieving from the external datastore can leak information from the corpus. We acknowledge this ethical consideration and caution those who apply our approach to privacy-sensitive domains.
Acknowledgments
We thank Jane Pan, Howard Chen, Alexander Wettig, Tianyu Gao, Kaiyu Yang, Mengzhou Xia, Jinhyuk Lee, and the members of Princeton NLP group for helping with proofreading and providing valuable feedback. This research is partially supported by the James Mi *91 Research Innovation Fund for Data Science and a gift from Apple. ZZ is also supported by a JP Morgan PhD fellowship.
References
Appendix A Inference Method
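Formally speaking, our testing objective is basically the same as the training objective:
$$P(w \mid c) \propto \exp\big(E_w^\top f_\theta(c)\big) + \sum_{(c_j, x_j) \in \mathcal{M}_{\text{test}}:\; x_j = w} \exp\big(\text{sim}(g_\theta(c), g_\theta(c_j)) / \tau\big),$$
except that we take $\mathcal{M}_{\text{test}}$ as a combination of $\mathcal{M}_{\text{local}}$, $\mathcal{M}_{\text{long}}$, and $\mathcal{M}_{\text{ext}}$. As $\mathcal{M}_{\text{ext}}$ can be very large, we approximate it by retrieving the top-$K$ closest terms to $g_\theta(c)$. Formally, $\mathcal{M}_{\text{test}}$ for the three instantiations of Trime is constructed as follows:
$$\begin{aligned}
\text{TrimeLM}: \quad & \mathcal{M}_{\text{test}}(c) = \mathcal{M}_{\text{local}}(c), \\
\text{TrimeLM}_{\text{long}}: \quad & \mathcal{M}_{\text{test}}(c) = \mathcal{M}_{\text{local}}(c) \cup \mathcal{M}_{\text{long}}(c), \\
\text{TrimeLM}_{\text{ext}}: \quad & \mathcal{M}_{\text{test}}(c) = \mathcal{M}_{\text{local}}(c) \cup \mathcal{M}_{\text{long}}(c) \cup \text{kNN}\big(\mathcal{M}_{\text{ext}}, g_\theta(c)\big),
\end{aligned}$$
where $\text{kNN}(\mathcal{M}, g_\theta(c))$ returns the top-$K$ closest terms to $g_\theta(c)$ in the memory set $\mathcal{M}$. Additionally, because $\mathcal{M}_{\text{test}}$ may differ from the training memories, we tune the temperature term $\tau$ on the development set to adjust the weight of the memory component when calibrating the distribution.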
We find that when a large external memory is used during inference, performance can be improved by calibrating a separate distribution over the memory and interpolating the output distribution with the memory distribution, similarly to kNN-LM Khandelwal et al. (2020). We hypothesize that this is because the distribution of similarity values shifts significantly during inference, while their relative ranking is preserved. As a result, putting values from two different distributions into one softmax normalization is sub-optimal compared to computing two separate probabilities and interpolating them.
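Thus, we apply an additional linear interpolation to our output probability distribution. Specifically, we first use the testing objective above to compute the distribution $P(w \mid c)$. Then, we compute a probability distribution over the tokens in memory as follows:

$$P_{\text{kNN}}(w \mid c) \propto \sum_{(c_j, x_j) \in \mathcal{M}_{\text{test}}:\, x_j = w} \exp\left(\mathrm{sim}\left(g_\theta(c), g_\theta(c_j)\right) / \tau''\right)$$

We linearly interpolate these two probability distributions with a coefficient $\lambda$ to obtain the final output $P_{\text{final}}(w \mid c)$:

$$P_{\text{final}}(w \mid c) = (1 - \lambda)\, P(w \mid c) + \lambda\, P_{\text{kNN}}(w \mid c)$$

We tune both temperature terms ($\tau'$ in the testing objective and $\tau''$ in $P_{\text{kNN}}$) on the development set.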
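The calibration-and-interpolation step can be sketched as follows. This is a hedged illustration: `knn_probs` and `interpolate` are hypothetical names, similarity is assumed to be a dot product, and `lam` is assumed to weight the memory distribution, following the kNN-LM convention. `p_model` stands for the distribution produced by the testing objective.

```python
import math


def knn_probs(g_c, memory, tau):
    """Separate softmax over retrieved (key, token) entries, aggregated by token."""
    logits = [(tok, sum(a * b for a, b in zip(g_c, key)) / tau) for key, tok in memory]
    m = max(l for _, l in logits)  # stabilize the exponentials
    probs = {}
    for tok, l in logits:
        probs[tok] = probs.get(tok, 0.0) + math.exp(l - m)
    z = sum(probs.values())
    return {w: p / z for w, p in probs.items()}


def interpolate(p_model, p_knn, lam):
    """(1 - lam) * p_model + lam * p_knn over the union of both supports."""
    support = set(p_model) | set(p_knn)
    return {w: (1 - lam) * p_model.get(w, 0.0) + lam * p_knn.get(w, 0.0) for w in support}
```

Normalizing the memory scores on their own means a shift in the scale of test-time similarities only changes the sharpness of `p_knn` (controlled by `tau`), rather than distorting how memory competes with the token embeddings inside a single softmax.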
Appendix B Packing Segments Using BM25 Scores
In §4.3, we construct training memories by packing segments with large lexical overlap into the same batch using BM25 Robertson and Zaragoza (2009). Algorithm 1 shows the process of packing segments into training batches: we start with a single segment and repeatedly add the segments with the highest BM25 scores to the same batch.
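A greedy packing procedure in this spirit can be sketched with a small Okapi BM25 scorer. This is not Algorithm 1 itself: the names `BM25` and `pack_batches`, the parameters `k1=1.2` and `b=0.75`, and the choice to query with the most recently added segment are all assumptions made for illustration.

```python
import math
from collections import Counter


class BM25:
    """Okapi BM25 over a fixed collection of tokenized segments."""

    def __init__(self, docs, k1=1.2, b=0.75):
        self.tfs = [Counter(d) for d in docs]
        self.lens = [len(d) for d in docs]
        self.avgdl = sum(self.lens) / len(docs)
        self.k1, self.b = k1, b
        n = len(docs)
        df = Counter(t for d in docs for t in set(d))
        self.idf = {t: math.log(1 + (n - c + 0.5) / (c + 0.5)) for t, c in df.items()}

    def score(self, query, i):
        """BM25 score of segment i against the tokenized query."""
        tf, dl = self.tfs[i], self.lens[i]
        s = 0.0
        for t in set(query):
            if tf[t]:
                norm = tf[t] + self.k1 * (1 - self.b + self.b * dl / self.avgdl)
                s += self.idf[t] * tf[t] * (self.k1 + 1) / norm
        return s


def pack_batches(segments, batch_size):
    """Greedily group lexically similar segments into batches of segment indices."""
    bm25 = BM25(segments)
    unassigned = list(range(len(segments)))
    batches = []
    while unassigned:
        batch = [unassigned.pop(0)]  # start a new batch from a single segment
        while len(batch) < batch_size and unassigned:
            query = segments[batch[-1]]  # assumption: query with the last added segment
            best = max(unassigned, key=lambda i: bm25.score(query, i))
            unassigned.remove(best)
            batch.append(best)
        batches.append(batch)
    return batches
```

For example, with segments about cats and dogs, `pack_batches([["cat", "sat", "mat"], ["dog", "ran", "park"], ["cat", "mat", "hat"], ["dog", "park", "ball"]], 2)` groups the two cat segments and the two dog segments together.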
Appendix C Dataset Statistics and Tasks
We evaluate our approach on three benchmarks: WikiText-103, Enwik8, and IWSLT’14. We also evaluate our approach on BooksCorpus for domain adaptation (§5.3). Table 8 shows the dataset statistics.
WikiText-103 Merity et al. (2017) is a word-level language modeling dataset consisting of 103M training tokens. Following standard practice, we use adaptive softmax and adaptive token embeddings Baevski and Auli (2019) in our model and report perplexity. To compare directly with previous work, we evaluate two model configurations: a 247M Transformer model with a segment length of 3,072, following Baevski and Auli (2019); Khandelwal et al. (2020), and a 150M Transformer model with the segment length used by Dai et al. (2019). More details are provided in Appendix D.
Enwik8 Mahoney (2009) is a character-level language modeling dataset that contains a total of 100M characters. Following previous work, we report bits per character (bpc) on this dataset. We use a 12-layer Transformer model; its hidden dimension and segment length are listed in Appendix D.
We also evaluate on the IWSLT’14 De→En machine translation task, which consists of 170K translation pairs. Following Khandelwal et al. (2021), we build an external memory by taking all the translation contexts and their corresponding target tokens in the training set. We use the output representation as $f_\theta(c)$ and the input representation of the last FFN layer as $g_\theta(c)$ to compute the loss. Similarly, we use BM25 to batch the training data: we encourage two target sentences with a high BM25 score to be in the same training batch (see Algorithm 1). We use the default model configuration in the Fairseq library Ott et al. (2019) and SacreBLEU Post (2018) to compute BLEU scores Papineni et al. (2002).
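A translation datastore of this kind can be sketched as a list of (key, value) pairs, one per target position. In this sketch, `build_datastore` is a hypothetical name, and `encode` is a hypothetical stand-in for the model's representation $g_\theta(c)$ of a translation context (the source sentence plus the target prefix); it is not a Fairseq API.

```python
def build_datastore(pairs, encode):
    """Build external memory entries (context key, target token) from parallel data.

    pairs: list of (source_tokens, target_tokens).
    encode: hypothetical function (source, target_prefix) -> key vector.
    """
    memory = []
    for src, tgt in pairs:
        for t, tok in enumerate(tgt):
            # The translation context for position t is the source plus tgt[:t];
            # the value stored with it is the token the model should predict next.
            memory.append((encode(src, tgt[:t]), tok))
    return memory
```

With a toy encoder such as `lambda s, p: (len(s), len(p))`, a single pair `(["a", "b"], ["x", "y"])` yields one entry per target token, keyed by its context.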
We evaluate our approach for domain adaptation on the BooksCorpus dataset Zhu et al. (2015), a word-level language modeling dataset. The complete BooksCorpus dataset consists of 0.7B tokens. We build our own train/dev/test splits, which consist of 100M/250K/250K tokens, respectively. The train set is only used to build external memory. On this dataset, we evaluate the models trained on WikiText-103 to study how well our approach adapts to a new domain without re-training or fine-tuning. We use the 247M Transformer model with a segment length of 3,072 on this dataset.
Appendix D Model Configurations and Hyperparameters
Table 9 shows the model configurations and hyperparameters used in our experiments. Following Baevski and Auli (2019), we train the model on fixed-length segments; during evaluation, we score only the tokens at the end of each segment, so evaluation segments can overlap with one another.
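The overlapping evaluation scheme can be sketched as a list of windows. This is a hypothetical helper: `eval_windows`, `seg_len`, and `stride` are assumed names, and the number of tokens actually scored per segment is not specified here. Each window keeps up to `seg_len` tokens of context, but only its last `stride` tokens, which no earlier window has scored, contribute to perplexity, so every token is scored exactly once.

```python
def eval_windows(num_tokens, seg_len, stride):
    """Return (start, end, score_from) triples for overlapping evaluation windows.

    The model reads tokens[start:end] and only tokens[score_from:end] are scored.
    """
    if not 0 < stride <= seg_len:
        raise ValueError("stride must be in (0, seg_len]")
    windows = []
    scored_until = 0
    while scored_until < num_tokens:
        end = min(scored_until + stride, num_tokens)
        start = max(0, end - seg_len)  # keep as much left context as fits
        windows.append((start, end, scored_until))
        scored_until = end
    return windows
```

A smaller `stride` gives each scored token more left context at the cost of more forward passes.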
When evaluating with large external memory, we always retrieve the top-$K$ ($K$ = 1,024) context-target pairs for language modeling. For machine translation, we tune $K$ following Zheng et al. (2021).
We apply our approach to SRU++ Lei (2021), and we believe it is also compatible with other architectures such as Transformer-XL Dai et al. (2019). SRU++ is a language model that combines recurrent units with the attention mechanism. SRU++ uses hidden representations from the previous segment at attention layers to incorporate long-range context, similar to Dai et al. (2019).
To apply our approach to SRU++, we follow its data-batching method, which the recurrence of the model architecture requires. We construct the training memory using all the contexts in the current segment (i.e., local memory) and all the contexts in the previous segment (i.e., long-term memory). Note that the memory representations from the previous segment are stale, so we do not back-propagate into them. During training, we update the model for 400K steps with a batch size of . For the other hyperparameters and the optimizer, we follow the defaults in the SRU++ implementation.
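Which memory entries are visible at each position can be sketched symbolically. This is an illustration under stated assumptions: `training_memory` is a hypothetical name, entries are tagged tuples rather than vectors, and local memory is assumed to be strictly causal (position t sees only earlier positions of its own segment). In a real implementation the "long" entries would be detached from the computation graph, since they are stale.

```python
def training_memory(prev_segment, segment):
    """Memory visible when predicting each position of `segment`.

    Each entry is (kind, index, target_token):
      "long"  entries come from the previous segment (stale, no back-propagation);
      "local" entries come from earlier positions of the current segment.
    """
    long_mem = [("long", j, tok) for j, tok in enumerate(prev_segment)]
    return [
        long_mem + [("local", j, segment[j]) for j in range(t)]
        for t in range(len(segment))
    ]
```

The first position of a segment therefore sees only long-term memory, and each later position additionally sees every earlier context in its own segment.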
During inference, we can use more contexts to construct memory. We train models with two different segment lengths; depending on the segment length, the model can leverage a long-term memory of size 6,144 or 12,228 during inference.
Appendix F Performance of the 7M model on WikiText-103
We conduct our ablation studies and analyses in §6 with an 8-layer Transformer model due to our limited computation budget. The model has 7M parameters, 8 layers, and 4 heads per layer. The embedding dimension is 128 and the intermediate dimension of the FFN is 512. The model takes a segment of 3,072 tokens as input. We compare our approach with the baselines on this model architecture. As shown in Table 10, our approach improves over the baselines by a large margin, which suggests that modeling memory explicitly is especially important when model capacity is limited.
Appendix G Additional Analysis
We study the importance of packing segments with high BM25 scores into the same training batch, as well as the effectiveness of enabling back-propagation to memory representations during training. As shown in Table 11, when we batch training segments randomly (instead of using BM25 scores), the perplexity increases to 45.71. Enabling back-propagation to memory is also crucial for our approach: performance is much worse if we disable it.
Effectiveness of using local memory
We study the effectiveness of our model TrimeLM that uses only local memory with different segment lengths . As shown in Table 12, our model significantly outperforms the baselines in all the settings. This suggests that our model can leverage local memory very effectively to improve performance.
Retrieval performance on external memory
When external memory is used in our experiments, we perform nearest-neighbor search over the entire memory set to retrieve the top-$K$ keys (we use $K$ = 1,024). Table 13 compares the retrieval accuracy of our approach and kNN-LM Khandelwal et al. (2020) for different values of $K$. Our approach achieves better retrieval results than kNN-LM, which may partly explain why our final perplexity surpasses that of kNN-LM when incorporating external memory.
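A retrieval-quality measure of this kind can be sketched as follows. The exact definition of retrieval accuracy behind Table 13 is not given here, so this sketch assumes one plausible reading: the fraction of the top-$K$ retrieved entries whose target token equals the gold next token, averaged over queries. Dot-product similarity and the names `retrieve_top_k` and `retrieval_precision` are also assumptions.

```python
import heapq


def retrieve_top_k(query, memory, k):
    """Top-k (key, token) entries by dot-product similarity to the query."""
    return heapq.nlargest(k, memory, key=lambda kv: sum(q * x for q, x in zip(query, kv[0])))


def retrieval_precision(queries, gold_tokens, memory, k):
    """Average fraction of retrieved entries whose token matches the gold next token."""
    total = 0.0
    for query, gold in zip(queries, gold_tokens):
        hits = retrieve_top_k(query, memory, k)
        if hits:
            total += sum(tok == gold for _, tok in hits) / len(hits)
    return total / len(queries)
```

Exact search with `heapq.nlargest` is only practical for small memories; a datastore with hundreds of millions of entries would need an approximate nearest-neighbor index instead.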
Perplexity breakdown for different frequencies
We aim to understand which types of memory improve the perplexity of tokens in different frequency groups. We group tokens into 5 buckets according to their frequency on the development set. Table 14 shows the results for different models. TrimeLM and TrimeLM_long improve the perplexity of rare words (i.e., tokens in the lower-frequency buckets, up to 1k) while achieving similar or slightly worse results for frequent words compared to the Transformer baseline. TrimeLM_ext improves perplexity in all the buckets. Interestingly, kNN-LM with continuous cache does not perform significantly better than TrimeLM and TrimeLM_long, although these two models do not use external memory. This suggests that jointly training memory representations and the language model particularly helps improve performance on rare words.
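A per-bucket perplexity breakdown can be computed from per-token negative log-likelihoods. In this sketch, `bucket_perplexity` is a hypothetical name and the bucket boundaries passed as `edges` are assumptions (the actual five buckets are defined by Table 14). Frequencies are counted over the evaluated tokens themselves, matching the development-set bucketing described above.

```python
import math
from collections import Counter


def bucket_perplexity(tokens, nlls, edges):
    """Perplexity per frequency bucket.

    tokens: evaluated tokens; nlls: their negative log-likelihoods (natural log).
    edges: increasing frequency upper bounds; a token falls into the first bucket
           whose bound is >= its frequency, or into a final overflow bucket.
    """
    freq = Counter(tokens)
    sums, counts = Counter(), Counter()
    for tok, nll in zip(tokens, nlls):
        bucket = next((i for i, e in enumerate(edges) if freq[tok] <= e), len(edges))
        sums[bucket] += nll
        counts[bucket] += 1
    # Perplexity of a bucket is exp of the mean NLL over the tokens in it.
    return {b: math.exp(sums[b] / counts[b]) for b in sorted(counts)}
```

Averaging NLL inside each bucket before exponentiating keeps the numbers comparable to corpus-level perplexity.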
Appendix H Tuning p for training with external memory
When training the model with local and external memory, we disable the local memory with a probability of $p$ to prevent the model from relying only on the high-quality local memory. Here we study how $p$ affects the final performance of our model; the results for different values of $p$ are shown in Table 15. We find that when $p$ = 0, the model performs poorly with external memory, as it learns to leverage only local memory and ignores external memory during training. Increasing $p$ mitigates this issue. We set $p$ accordingly in our main experiments.
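The local-memory dropout can be sketched in a few lines. This is a minimal illustration under assumptions: `sample_training_memory` is a hypothetical name, the coin is assumed to be flipped once per call (for example, once per batch), and `ext_mem` stands for the training-time approximation of external memory, taken here to be memory drawn from the other segments in the same batch.

```python
import random


def sample_training_memory(local_mem, ext_mem, p, rng=random):
    """Return the memory used for one training step.

    With probability p the local memory is dropped, so the loss can only be
    lowered through the external (in-batch) memory; otherwise both are used.
    """
    if rng.random() < p:
        return list(ext_mem)
    return list(local_mem) + list(ext_mem)
```

With `p = 0` the local memory is always present, which matches the failure mode described above; with `p = 1` the local memory is never used.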