The Inductive Bias of In-Context Learning: Rethinking Pretraining Example Design
Yoav Levine, Noam Wies, Daniel Jannai, Dan Navon, Yedid Hoshen, Amnon Shashua
1 Introduction
Beyond excelling in their core task of pure language modeling, modern Neural Language Models (NLMs) show impressive zero- and few-shot abilities in more general Natural Language Understanding (NLU) tasks (Brown et al., 2020). This implies that the training corpus contains the information required for performing such tasks, and moreover it implies that the common pretraining process grants the trained NLM some access to these higher level capabilities. In this paper, we highlight a connection between the quality of the emergent NLU capabilities and a basic component in the NLM training scheme: the process of segmenting the corpus into training examples.
Specifically, NLMs self-train over huge training corpora (typically, billions to trillions of words). A basic, automatic operation in the training pipeline is to segment these corpora into training examples: contiguous text chunks of sizes processable by the neural architecture (typically, up to thousands of words). We formalize an expressivity bias that this segmentation process introduces, to be referred to as the in-context bias, which directly affects the NLM’s ability to integrate cross-corpus information. We show that the NLM can model much stronger dependencies between sentences that were shown together at least once in-context, i.e., in the same training example, than between sentences that were never shown together in the same input. This inductive bias may be good for language modeling, but it implies that NLU capabilities that involve integrating information from different examples across the corpus (see, e.g., figure 1) are under-favored by design in the current setting. Thus, if one sentence in the corpus can elucidate the meaning of another sentence (e.g., it defines a hard concept or provides auxiliary information), our result implies that a model that saw them in different training examples will benefit less from this elucidation than a model that saw them in the same training example.
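The segmentation step can be made concrete with a minimal sketch, assuming a toy pre-tokenized corpus and a hypothetical fixed window size (real pipelines use the model's tokenizer and maximal context length). It cuts the concatenated token stream into contiguous training examples and records which sentence pairs were ever shown together in-context; under the in-context bias, pairs missing from that set (sentences 0 and 2 below) can interact only through the weights.

```python
from itertools import combinations
from typing import List, Set, Tuple

def segment_corpus(sentences: List[List[str]], window: int) -> List[Set[int]]:
    """Cut the concatenated token stream into contiguous chunks of `window` tokens.

    Returns, for each training example, the set of sentence indices whose
    tokens fall inside it (a sentence may straddle two examples).
    """
    # Tag every token with the index of the sentence it came from.
    token_owner = [i for i, sent in enumerate(sentences) for _ in sent]
    return [set(token_owner[start:start + window])
            for start in range(0, len(token_owner), window)]

def in_context_pairs(examples: List[Set[int]]) -> Set[Tuple[int, int]]:
    """All sentence pairs that were shown together in at least one training example."""
    pairs: Set[Tuple[int, int]] = set()
    for ex in examples:
        pairs.update(combinations(sorted(ex), 2))
    return pairs

corpus = [["a"] * 3, ["b"] * 4, ["c"] * 3, ["d"] * 2]   # four toy "sentences"
examples = segment_corpus(corpus, window=5)
print(examples)                            # [{0, 1}, {1, 2}, {3}]
print(sorted(in_context_pairs(examples)))  # [(0, 1), (1, 2)]
```

Changing only the window (or the order of sentences) changes which pairs end up in-context, even though the model sees exactly the same text.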
While standard approximation results examine the expressivity of an architecture over a single input, our theoretical approach pertains to the entire training process and examines the expressive capacity of the resultant NLM with respect to the training set. Therefore, our approximation result ties an optimization parameter (the learning rate) to the regular NLM architecture expressivity parameters (depth, width). Intuitively, sentences that were never shown in the same input can only access each other via the weights of the network during training. The mechanism for “storing” information in the network involves a very small learning-rate term; our analysis formalizes and quantifies an “expressivity toll” that the model pays when making use of such harder-to-access stored information.
We employ the tool of a function’s separation rank with respect to subsets of its variables, which quantifies its ability to model input dependencies between these subsets. The separation rank was employed for analyzing the dependencies modeled by convolutional (Cohen & Shashua, 2017), recurrent (Levine et al., 2018a), and self-attention (Levine et al., 2020) networks with respect to a single input example. In order to analyze an NLM’s ability to model dependencies between different training examples, we refine the usage of this measure in two ways: (1) we introduce the effective separation rank, which measures the effective ability of a function to model dependencies in a finite precision setting, and (2) we modify the separation rank such that it can account for the more intricate mechanism of mixing between variables that occurs in the sequential case.
Several recent works intuitively rely on the in-context expressivity bias formalized above in different manners, and significantly improve both task-specific training and pretraining of NLMs. Gao et al. (2020) advance the frontier in few-shot learning via finetuning. They show that by concatenating several related training examples per input, instead of following the standard fine-tuning practice of one example per input, the few-shot performance on sentence similarity tasks is considerably boosted. Another example was pointed out in Humeau et al. (2020); Thakur et al. (2020): when training for sentence similarity tasks, including both sentences in the same input leads to a performance gain relative to separately encoding each sentence. In the challenging setting of open-domain question answering, Izacard & Grave (2020) jointly attend to all documents that may contain the answer, and show large gains relative to prior methods that consider these documents in separate forward passes.
Turning our focus to methods that leverage the in-context bias for improved pretraining, the most straightforward effort is a body of work aimed at reducing the quadratic dependence of the Transformer computation on input sequence length (Tay et al., 2020). While allowing for more text in-context during training, this does not improve the model’s ability to integrate text across different documents in the corpus. The following approaches take a further step and enable direct cross-corpus connections during pretraining. Lewis et al. (2020) attend to related documents when maximizing the likelihood of a target document. The scope of related documents is restricted by meta-data: taken from the same Wikipedia entry as the input, or published on the same date. Guu et al. (2020) expand the scope of the related documents by training a Knowledge-Retrieval model that has access to the entire Wikipedia corpus. They retrieve several related documents per target document, but condition on each related document independently. Outside of the natural language domain, Rao et al. (2021) train a Transformer-based protein-LM that receives multiple related protein sequences in-context. Their protein-LM surpasses previous methods that process one sequence per input by a wide margin, with significant parameter efficiency.
1.2 Leveraging the in-context bias for NLU oriented training
Though the in-context bias is intuitive, the above subsection surveys recent advances that leverage it in non-trivial manners. By formalizing the theoretical advantage of in-context integration of related text, we can unify the roots of the above successes and, importantly, indicate new methods for tilting the pretraining bias towards NLU tasks. Following the presentation of our theoretical results in section 2, we detail in section 3 two controlled experiments that exemplify new methods which directly leverage the in-context bias.
Our first experiment augments the Task Adaptive PreTraining (TAPT) setting of Gururangan et al. (2020), in which an NLM that was pretrained on a general corpus continues pretraining (with its original objective) on the training set of an NLU task. We perform TAPT on the SentEval sentence similarity benchmark (Conneau & Kiela, 2018), and during TAPT introduce the following augmentation: along with SentEval sentences, we simultaneously pretrain on related sentences from Wikipedia, the general pretraining corpus. The related sentences are found via k-Nearest Neighbors (kNN) search between the embeddings of SentEval examples and all Wikipedia sentences; we thus dub this approach kNN-TAPT. Importantly, during kNN-TAPT, each input includes a training example from the task, followed in-context by its Wikipedia neighbors. We demonstrate significant gains of kNN-TAPT over regular TAPT on SentEval sentence similarity tasks. A dedicated ablation study shows the significance of adding the general corpus neighbors in-context, versus in separate training examples, during kNN-TAPT.
Our second experiment introduces a task-independent pretraining phase, dubbed kNN-Pretraining. As in kNN-TAPT, we group together sentences with similar sentence representations in the same training example, but in kNN-Pretraining we use only sentences from the general pretraining corpus. This can be viewed as a sentence-focused variation of the pretraining schemes of Lewis et al. (2020) and Guu et al. (2020) surveyed above, which operate on full documents, and is very similar to RETRO by Borgeaud et al. (2021) (DeepMind), who show the benefits of this approach given much larger resources. Figure 1 shows that after regular pretraining on Wikipedia, the zero-shot closed-book performance of different randomly initialized GPT2-medium models on open domain questions from Wikipedia (Kwiatkowski et al., 2019) is very low (correct on very few of the 20,000 questions in the evaluation set). Adding a phase of kNN-Pretraining raises performance significantly (correct on considerably more questions in the evaluation set), reflecting the enhanced ability to integrate knowledge from related sentences, acquired via the in-context bias.
We formally establish the in-context bias: information within pretraining examples is better represented than information integrated across pretraining examples.
We ask and answer a new type of network expressivity question: how expressive is a network with respect to examples seen during its training process?
We demonstrate that in-context bias motivated “pretraining example design" elicits better representations from the same data: kNN-Pretraining improves on several NLU tasks.
2 Theoretical analysis: The in-context bias of self-attention
In this section, we consider the entire NLM training procedure as a functional that receives an unlabeled training corpus and outputs a trained NLM. Our analysis focuses on the corpus segmentation into training examples as a hyper-parameter of this functional. We reduce the high-level notion of representing “cross-corpus correlations" to a quintessential case study: we quantify the NLM’s ability to model dependencies between two sentences that appear in the same training example (the in-context representation) and in different training examples (the sequential representation).
We believe that the in-context bias can be shown to exist in a broad range of architectures, but we focus on self-attention since almost all modern NLMs are based on the Transformer architecture of Vaswani et al. (2017). Our theoretical framework is based on that of Levine et al. (2020); Wies et al. (2021), who analyze a simplified, theoretically accessible, self-attention network. They study the expressivity of this self-attention architecture with respect to its input, and use a measure of a multivariate function’s ability to correlate two subsets of its variable set, referred to as the separation rank. The analyzed framework captures the connectivity of self-attention but omits its softmax and ReLU non-linearities (see eq. 1 below). We refer the reader to Levine et al. (2020); Wies et al. (2021) for a discussion on the impact of these relaxations. Essentially, they are shown to weaken the overall network power but still allow a meaningful comparison of the self-attention integration abilities. Importantly, both works derive unforeseen theoretical conclusions from analyses of the separation rank measure for this architecture class, and then provide extensive empirical corroboration for their manifestation in common Transformer architectures, reinforcing the relevance of this setting. In the following, we describe in section 2.1 the analyzed in-context and sequential self-attention representations of two sentences. Then, in section 2.2, we present the separation rank, which we use in section 2.3 for quantifying the advantage of in-context representations versus sequential ones.
For simplicity of presentation, we examine two sentences of equal length. The in-context representation (eq. 4) simply concatenates both sentences in the input.
For the sequential approach, we consider a setup in which the first sentence is inserted into the network at some training step and the second sentence is inserted at the following training step. The output of the network at the first of these steps is therefore a function of the first sentence and of all the weights learned before that step. Focusing on autoregressive NLMs for simplicity of presentation (the analysis holds for bidirectional NLMs as well), the loss at that step is the log-likelihood loss of the first sentence, and every learned weight is updated by a gradient step on this loss, scaled by the learning rate. Accordingly, the analyzed sequential representation (eq. 5) is the network output on the second sentence after this gradient update.
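A toy numerical sketch of the sequential setting follows. It is not the self-attention network analyzed here: it assumes a hypothetical bag-of-embeddings scorer with a learned vocabulary matrix `E` and readout `u`, and a squared loss standing in for the log-likelihood loss. It takes one gradient step on the first sentence and measures how the output on the second sentence changes. The change has no term that is independent of the learning rate (here it equals -4.5·lr + 2·lr²), so for small learning rates it shrinks roughly in proportion to the learning rate.

```python
import numpy as np

def score(tokens, E, u):
    """Toy 'network': a readout vector applied to the mean token embedding."""
    return float(u @ E[tokens].mean(axis=0))

def sgd_step(tokens, E, u, lr, target=0.0):
    """One gradient step on 0.5 * (score - target)**2 (a stand-in for the LM loss)."""
    n = len(tokens)
    h = E[tokens].mean(axis=0)
    r = float(u @ h) - target              # dLoss/dscore
    grad_u = r * h
    grad_E = np.zeros_like(E)
    for w in tokens:                        # each occurrence of w contributes r * u / n
        grad_E[w] += r * u / n
    return E - lr * grad_E, u - lr * grad_u

def sequential_shift(s1, s2, E, u, lr):
    """Change in the output on s2 caused by a gradient step on s1 one step earlier."""
    E_new, u_new = sgd_step(s1, E, u, lr)
    return score(s2, E_new, u_new) - score(s2, E, u)

E = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # vocabulary of 3 tokens, width 2
u = np.array([1.0, 2.0])
s1, s2 = [0, 2], [2, 1]
for lr in (1e-3, 1e-4):
    print(lr, sequential_shift(s1, s2, E, u, lr))
```

In this toy, the first sentence can influence the output on the second only through the updated `E` and `u`, which mirrors the sequential mechanism described above.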
In practice, two relevant non-neighboring sentences are not necessarily shown in consecutive pretraining steps. Compared with the realistic scenario in which the two sentences appear at arbitrary training steps, this simplification tilts the representation in favor of modeling strong correlations between them. Thus, by upper bounding the ability to correlate the two sentences in the setting of eq. 5 (as we do in section 2.3), we establish an inherent limitation on the network’s ability to access information that was stored in its weights via the gradient update mechanism. In the next subsection, we present our approach for measuring a network’s ability to correlate two sentences seen during training, which we will use to separate the in-context and sequential settings.
2.2 A measure for modeling in-context and sequential dependencies
In this subsection, we refine the separation rank, used in prior work to analyze the dependencies between two sentences appended in-context. In section 2.2.1 we present the separation rank and introduce a finite precision refinement of it, referred to as the effective separation rank, which helps to elucidate the degradation in integration ability caused by the gradient update mechanism. In section 2.2.2 we point out a structural problem in employing the separation rank in the same manner as prior work, which analyzed only architecture expressivity, and introduce the sequential separation rank, which is meaningful for both the in-context and sequential cases.
The separation rank, introduced in Beylkin & Mohlenkamp (2002) for high-dimensional numerical analysis, was employed for various applications, e.g., chemistry (Harrison et al., 2003), particle engineering (Hackbusch, 2006), and machine learning (Beylkin et al., 2009). More recently, the separation rank has been established as a measure of dependencies modeled by deep convolutional and recurrent networks w.r.t. their inputs (Cohen & Shashua, 2017; Cohen et al., 2017; Levine et al., 2018a), and tied to quantum entanglement measures for proving that these deep learning architectures can model elaborate many-body quantum particle correlations (Levine et al., 2018b; 2019; Sharir et al., 2020). Recently, Levine et al. (2020); Wies et al. (2021) employed this measure for studying the expressivity of a self-attention architecture with respect to its input.
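For reference, the standard definition of the separation rank of a multivariate function f with respect to a partition (A, B) of its variables (Beylkin & Mohlenkamp, 2002; Cohen & Shashua, 2017) is given below, followed by a finite-precision (effective) variant. The variant is written under our own assumption of a generic function norm and tolerance ε, so the exact formulation used in this work may differ in such details.

```latex
\operatorname{sep}(f; A, B) = \min\Big\{ R \in \mathbb{N} \cup \{0\} \;:\; f(\mathbf{x}_A, \mathbf{x}_B) = \sum_{r=1}^{R} g_r(\mathbf{x}_A)\, h_r(\mathbf{x}_B) \ \text{for some functions } g_r, h_r \Big\}

\operatorname{sep}_{\varepsilon}(f; A, B) = \min\big\{ \operatorname{sep}(f'; A, B) \;:\; \| f - f' \| \le \varepsilon \big\}
```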
In words, if a function has a high separation rank but can be approximated up to a small error by a function with a low separation rank, then it has a low effective separation rank.
Prior works compare two functions by establishing the differences between their separation ranks. In principle, these differences could manifest only in irrelevant magnitudes (for example, if many of the summands in the separation rank definition are negligibly small for the function with the higher separation rank). The effective separation rank is key to our analysis because we rely on the fact that information on past examples is stored in the network weights in a small magnitude (due to a small learning rate). We show in section 2.3 that much of the integration between text segments from different training examples occurs in very small magnitudes due to high powers of the learning rate, limiting the effective integration, as measured by the effective separation rank. Our techniques for bounding the effective separation rank extend to prior works; although these did not examine the gradient update mechanism, their results can be reinforced by the guarantees of this measure.
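The gap between the two notions can be seen numerically. On a finite grid of inputs, a function of two variable groups is just a matrix of values, and its separation rank is the rank of that matrix. In this minimal sketch we assume the Frobenius (grid ℓ2) norm for the approximation error, which is our choice rather than a norm stated in this section: by the Eckart-Young theorem, the effective rank at precision `eps` is the smallest R whose discarded singular values have norm at most `eps`. The example function x·y + 1e-8·x²y² has separation rank 2 on the grid, but its second component is too small to survive a moderate precision.

```python
import numpy as np

def eps_rank(M: np.ndarray, eps: float) -> int:
    """Smallest R such that some rank-R matrix lies within Frobenius distance eps of M."""
    s = np.linalg.svd(M, compute_uv=False)          # singular values, descending
    # tail[r] = sqrt(sum_{i >= r} s_i^2): error of the best rank-r approximation
    suffix = np.cumsum((s ** 2)[::-1])[::-1]
    tail = np.sqrt(np.append(suffix, 0.0))
    return int(np.argmax(tail <= eps))              # first r that meets the precision

x = np.linspace(0.1, 1.0, 10)
y = np.linspace(0.1, 1.0, 10)
X, Y = np.meshgrid(x, y, indexing="ij")
M = X * Y + 1e-8 * X**2 * Y**2                      # f(x, y) sampled on the grid

print(np.linalg.matrix_rank(M))   # 2: exact separation rank on this grid
print(eps_rank(M, 1e-6))          # 1: the tiny second term is absorbed by the tolerance
print(eps_rank(M, 1e-12))         # 2
```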
2.2.2 The sequential separation rank
Levine et al. (2020), who were the first to apply the separation rank to functions realized by Transformer architectures, studied classical architecture expressivity questions, which apply only to the in-context representation. Accordingly, they analyzed only the separation rank of the self-attention function defined in eq. 1, and the input variables considered for calculating the separation rank were the word embedding vectors. A fundamental difficulty arises when attempting to directly apply this method to the sequential representation: the word embedding vectors are learned parameters of the architecture. In the sequential case, the second sentence is introduced after the gradient update computed on the first sentence, so if we were to follow prior practice, the embedding vectors used to describe the second sentence would already depend on the first.
Here, the sentence association layer of eq. 8 combines word embeddings with the introduced auxiliary variables via element-wise multiplication. We define the sentence association operation over the analyzed representations (eqs. 4 or 5) to be the application of the sentence association layer of eq. 8 to all uses of the input embedding layer during their computation. That is, in both mechanisms, every chosen word embedding is marked with the identity of the sentence that invoked it. Finally, we define the following specialization of the separation rank measure to our setting, referred to as the sequential separation rank (eq. 9).
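A minimal sketch of how the sentence association layer can be read, assuming (as the description above suggests) that it multiplies each looked-up word embedding element-wise by an auxiliary variable vector belonging to the sentence, and position, that invoked it. The function name `associate`, the random toy vocabulary, and the array shapes are hypothetical. With all-ones variables the layer reduces to the identity, consistent with the observation that follows.

```python
import numpy as np

def associate(vocab: np.ndarray, tokens: list, sentence_vars: np.ndarray) -> np.ndarray:
    """Hypothetical sentence association layer: look up embeddings, then mark each
    one with the auxiliary variables of its sentence via an element-wise product."""
    return vocab[tokens] * sentence_vars            # (len(tokens), d) * (len(tokens), d)

rng = np.random.default_rng(0)
vocab = rng.normal(size=(5, 3))                     # toy 5-word vocabulary, width 3
s1, s2 = [0, 3], [4, 1]
x1 = rng.normal(size=(2, 3))                        # variables marking sentence 1
x2 = rng.normal(size=(2, 3))                        # variables marking sentence 2

# In-context input: both marked sentences concatenated into one example.
in_context = np.concatenate([associate(vocab, s1, x1), associate(vocab, s2, x2)])

# With all-ones variables the layer is the identity on the looked-up embeddings.
ones = np.ones((2, 3))
assert np.allclose(associate(vocab, s1, ones), vocab[s1])
```

In the sequential case, the same layer would also be applied when the second sentence looks up rows of the already gradient-updated vocabulary, which is how variables of both sentences can appear in a single embedding.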
Clearly, when the introduced variables are all-ones vectors, the auxiliary layer in eq. 8 is the identity operation, and so the sentence association operation leaves both representations unchanged. More deeply, our expressivity questions query the ability of the in-context and sequential mechanisms to integrate two sets of variables, and the sequential separation rank captures the essence of this ability by making explicit where each set enters the computation.
In the next subsection, we show that for the in-context case, analyzed in prior work, the sequential separation rank is asymptotically equal to the previously employed measure of separation w.r.t. a partition of the input word embeddings (Levine et al., 2020). Thus, the properties of the existing framework are unchanged under the new definition. At the same time, for the sequential case brought forth in this paper, the sequential separation rank accounts both for the effect of the first sentence on the gradient-updated word embeddings and for the introduction of the second sentence into the computation. To see this, note that when the second sentence is processed, the operation in eq. 8 involves variables of both sentences, since the gradient-updated embeddings it receives already depend on the first sentence. In the following subsection, we make use of both extensions to the separation rank in eqs. 7 and 9 in order to establish the in-context bias.
2.3 The expressive advantage of in-context learning
We show below that the function computed by a self-attention based NLM when inserting the two sentences together in its input (the in-context representation) can model more elaborate dependencies between them than the function attained when showing the first sentence in the input, modifying the network’s weights according to its loss, and then showing the second sentence in a subsequent input (the sequential representation). We begin with the following corollary, which follows from theorem 2 in Levine et al. (2020) and proposition 1 in appendix A, and upper bounds the sequential separation rank of the in-context representation (eq. 10).
However, the effective separation rank of the sequential representation is upper bounded by a lower term:
(See proof in appendix B.) Consider an entry of the analyzed sequential representation defined in eq. 5. Assume that all learned parameters and all gradients are bounded in magnitude from above and from below, and that … . Then the effective separation rank of this entry satisfies eq. 11. (The upper boundedness assumption resembles the practices of gradient clipping and weight decay, and the lower boundedness assumption resembles finite precision.)
Therefore, a gap between the upper bounds on the ability to model dependencies between the two sentences is indicated. Since the learning rate is small, its log is negative and the gap is in favor of the in-context representation. The following theorem guarantees that this gap is meaningful, by showing that the higher upper bound (of the in-context case) is tight in terms of effective rank:
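(See proof in appendix C.) Under the definitions of corollary 1, there exists an assignment of the network weights for which the effective separation rank of the in-context representation asymptotically matches the upper bound of corollary 1.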
Notably, corollary 1 and theorem 2 show that for the in-context case, the sequential separation rank asymptotically equals the regular separation rank, validating the relevance of this measure.
We now provide a high level proof sketch that captures the manner in which the theoretical framework of sections 2.1 and 2.2 is used for establishing the above gap (full proof in the appendix). For the in-context case, notice that each self-attention layer, defined in eq. 1, is a polynomial of fixed degree over its inputs, so the whole network, a composition of such layers, is a polynomial whose degree is the product of the per-layer degrees. We write this polynomial as a sum of many monomials; by definition, the separation rank of any monomial composing the polynomial is 1. Since the separation rank of a sum of functions is upper bounded by the sum of their separation ranks, we upper bound the separation rank by the number of these monomials, yielding eq. 10. The main difference in the sequential representation case is that the first-sentence variables affect the computation only via the gradient, so their impact is expected to be limited. However, since the second sentence first encounters gradient-updated vocabulary matrix entries, variables of both sentences enter the self-attention stack via its input, similarly to the in-context case. So the integration between the two sentences occurs right from the start, and indeed we show that the separation rank of both representations is similar. However, since any function of the first-sentence variables is accompanied by the learning rate, monomials that contain many such variables are multiplied by high powers of the learning rate. This causes many monomials to be negligibly small, and accordingly not to contribute to the effective separation rank. By combinatorial considerations we show that the number of monomials that are not attenuated by the learning rate (i.e., that have sufficiently large magnitude) yields eq. 11. ∎
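The attenuation step of the sketch can be illustrated with a toy count, which is not the paper's combinatorial bound. Assume, hypothetically, that every degree-D monomial over n1 first-sentence and n2 second-sentence variables has a coefficient bounded by lr**k, where k is the number of first-sentence factors, and discard the monomials whose bound falls below a precision eps. Every monomial contributes at most 1 to the separation rank, but in this toy only the survivors are counted toward the effective separation rank.

```python
from math import comb

def multisets(n_vars: int, size: int) -> int:
    """Number of monomials of degree `size` over `n_vars` variables."""
    return comb(n_vars + size - 1, size)

def count_monomials(n1: int, n2: int, degree: int, lr: float, eps: float):
    """Toy count: all degree-`degree` monomials over n1 + n2 variables, and the subset
    whose assumed coefficient bound lr**k (k = number of first-sentence factors) >= eps."""
    total = surviving = 0
    for k in range(degree + 1):
        n_terms = multisets(n1, k) * multisets(n2, degree - k)
        total += n_terms
        if lr ** k >= eps:                   # not attenuated below the precision
            surviving += n_terms
    return total, surviving

print(count_monomials(n1=4, n2=4, degree=10, lr=1e-3, eps=1e-10))  # (19448, 5216)
```

Here only monomials with at most three first-sentence factors survive (lr**3 = 1e-9 ≥ 1e-10 > lr**4), so roughly a quarter of the terms remain; smaller learning rates or larger degrees shrink this fraction further.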
The above theorems establish that, from an expressivity perspective, the small magnitude of commonly employed learning rates hinders the ability to integrate information across different training examples. Specifically, the established gap implies that the power of the joint representation of two sentences shown in different training examples is upper bounded by that of a shallower network that has seen them in the same context, where the number of missing layers grows as the learning rate decreases. For commonly used learning-rate values, the implied deficit amounts to a non-negligible number of layers in the sequential case. As shown in Levine et al. (2020); Tay et al. (2021), in many practical regimes of network size, depth is crucial for expressivity, reinforcing the implications of this gap.
The weaker upper bound, that of the sequential case, is not guaranteed to be tight. This means that, theoretically, the sequential representation may in fact be much weaker than what we have proven; e.g., showing two sentences in the same context may yield a representation that cannot be matched merely by showing them in separate contexts and adding a realistic number of layers. However, Roberts et al. (2020) show evidence supporting our indicated link between architectural parameters and the in-context bias. They show that when performing open domain question answering tasks (their defined “closed book” setting), a large T5 model that sees only the question performs comparably to smaller models that are allowed to attend to the documents that contain the answer. This directly implies a certain strength of the sequential mechanism, namely, that information which was seen during training can be accessed via the weights when the model is realistically stronger, as implied by our bounds. Notably, the large T5 model is 2-4 times the depth of the contrasted smaller models, suggesting that the upper bound can be tightened to a fraction of the indicated depth deficit, or that factors beyond expressivity also contribute to the in-context bias (e.g., optimization, generalization). Investigation of these aspects is left for future work.
3 kNN-based pretraining example design
Our theoretical analysis quantifies the relation between the small magnitude of the learning rate, and the deficiency in the ability to model dependencies between different training examples. Clearly, small learning-rates are critical for optimization purposes, so the formalized phenomenon should not be solved via high learning-rates during training. Instead, our analysis makes it clear that if correlations between specific sentences are important for a given task, appending them in-context yields better representations for the task. Below, we describe two controlled experiments that demonstrate the importance of this indicated “pretraining example design" degree of freedom. In both experiments, correlated sentences are identified via kNN search in their RoBERTa-large sentence representation space (Reimers & Gurevych, 2019), performed using the FAISS library (Johnson et al., 2019).
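The neighbor search itself can be sketched as exact cosine-similarity kNN. This is a brute-force NumPy stand-in rather than the FAISS setup used here: it L2-normalizes the vectors so that inner products equal cosine similarities. At Wikipedia scale one would instead use FAISS, e.g., an exact inner-product index (`IndexFlatIP`) over normalized vectors. The toy vectors below are placeholders for RoBERTa-large sentence embeddings.

```python
import numpy as np

def cosine_knn(queries: np.ndarray, corpus: np.ndarray, k: int):
    """Exact top-k neighbours by cosine similarity (brute force; rows are embeddings)."""
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    c = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
    sims = q @ c.T                                  # (n_queries, n_corpus) cosine similarities
    idx = np.argsort(-sims, axis=1)[:, :k]          # indices of the k most similar corpus rows
    return np.take_along_axis(sims, idx, axis=1), idx

corpus = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # placeholder "Wikipedia" embeddings
query = np.array([[2.0, 0.1]])                             # placeholder task-sentence embedding
sims, idx = cosine_knn(query, corpus, k=2)
print(idx)  # [[0 2]]
```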
The Task Adaptive PreTraining (TAPT) method, in which an NLM pretrains on the training set of an NLU task, leads to impressive gains (Gururangan et al., 2020). Notably, TAPT is most effective after the regular pretraining stage on a general corpus. This implies that during TAPT, the model generates improved representations by integrating the task-related text with the knowledge stored in its weights from the preceding general pretraining phase. Under this premise, we postulated that performance would improve if we made relevant sentences from the general corpus more available to the model during the TAPT phase. According to the above analysis, a simple and effective way to bias the model towards representing desired correlations between sentences is to append them in context.
We thus propose the kNN-TAPT phase, in which the training examples are composed of task examples concatenated with their general corpus neighbors in embedding space. We applied kNN-TAPT to the SentEval sentence similarity tasks. Showing similar sentences from Wikipedia is expected to be particularly useful on these tasks, so this is a good experimentation ground to search for effects of the in-context bias. For each SentEval example, we searched over Wikipedia sentences and appended in-context the neighbors whose embeddings have a cosine similarity to the SentEval example embedding above a fixed threshold, with a special token inserted between different sentences. We continued until finding no more neighbors or reaching a maximum number of tokens (counted in the RoBERTa vocabulary; Liu et al., 2019). This search yielded the kNN-TAPT training set, over which we continued training a pretrained RoBERTa-base model for several epochs, using the first epoch for learning-rate warmup and examining several peak learning rates. See appendix D for implementation details.
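The example-construction rule described above can be sketched as follows. The similarity threshold `min_sim`, the token budget `max_tokens`, and the separator string `"<sep>"` are hypothetical stand-ins for the settings used in the experiments, sentences are assumed to be already tokenized, and appending the most similar neighbors first is our assumption about the order.

```python
from typing import List, Sequence

def build_knn_example(task_tokens: List[str],
                      neighbor_tokens: Sequence[List[str]],
                      neighbor_sims: Sequence[float],
                      min_sim: float,
                      max_tokens: int,
                      sep: str = "<sep>") -> List[str]:
    """Task sentence followed in-context by its neighbours, most similar first.
    Stops at the first neighbour below min_sim or the first that would exceed max_tokens."""
    example = list(task_tokens)
    ranked = sorted(zip(neighbor_tokens, neighbor_sims), key=lambda pair: pair[1], reverse=True)
    for toks, sim in ranked:
        if sim < min_sim:
            break                                   # no more neighbours above the threshold
        if len(example) + 1 + len(toks) > max_tokens:
            break                                   # token budget reached
        example += [sep] + list(toks)               # separator between sentences
    return example

task = ["a", "b"]
nbrs = [["c", "d"], ["e"], ["f", "g", "h"]]
sims = [0.90, 0.95, 0.50]
print(build_knn_example(task, nbrs, sims, min_sim=0.8, max_tokens=7))
# ['a', 'b', '<sep>', 'e', '<sep>', 'c', 'd']
```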
Table 1 shows zero-shot SentEval sentence similarity scores, attained by using the average word embedding of an inserted sentence as its sentence representation (shown by Reimers & Gurevych (2019) to be most meaningful in the zero-shot setting). All models were trained according to the above prescription, except for the baseline RoBERTa, which was simply evaluated. kNN-TAPT improves over regular TAPT on average, implying that the Wikipedia neighbors are indeed useful to the TAPT stage. We compared kNN-TAPT variants as an ablation study. Importantly, all variants labeled with kNN-TAPT train on the same data during the TAPT stage (the training sets of the SentEval sentence similarity tasks and their Wikipedia nearest neighbors), and differ only in the arrangement of the data into training examples. The “neighbors” flag relates a SentEval example to its actual neighbors from the kNN search, while the “random” flag relates it to random Wikipedia sentences from the overall neighbors pool attained in the search. The “in batch” flag implies that related sentences were shown in the same batch, where every training example includes only one sentence from either SentEval or Wikipedia. In contrast, the “in context” flag implies that related sentences were shown in the same training example.
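The ablation arrangements can be sketched as different groupings of the same sentences. In this hypothetical helper, each task sentence yields one group of training examples (a stand-in for a batch); the sampling details of the “random” variant are our assumption, and real training would also mix groups across batches.

```python
import random
from typing import Dict, List

def arrange(task: List[str], neighbors: Dict[str, List[str]], mode: str,
            seed: int = 0) -> List[List[List[str]]]:
    """Group sentences into training examples; returns one group of examples per task sentence."""
    rng = random.Random(seed)
    pool = [s for t in task for s in neighbors[t]]   # overall neighbours pool from the search
    groups = []
    for t in task:
        nbrs = neighbors[t]
        if mode == "neighbors, in-context":      # one example: task sentence + its neighbours
            groups.append([[t] + nbrs])
        elif mode == "random, in-context":       # one example: task sentence + random pool sentences
            groups.append([[t] + rng.sample(pool, len(nbrs))])
        elif mode == "neighbors, in-batch":      # same batch, but one sentence per example
            groups.append([[t]] + [[n] for n in nbrs])
        else:
            raise ValueError(f"unknown mode: {mode}")
    return groups

task = ["t1", "t2"]
neighbors = {"t1": ["w1", "w2"], "t2": ["w3"]}
print(arrange(task, neighbors, "neighbors, in-context"))  # [[['t1', 'w1', 'w2']], [['t2', 'w3']]]
print(arrange(task, neighbors, "neighbors, in-batch"))    # [[['t1'], ['w1'], ['w2']], [['t2'], ['w3']]]
```

Both “neighbors” variants contain exactly the same sentences; only the second one puts each sentence in its own training example.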
The weakness of “neighbors, in-batch” implies that the a-priori plausible approach of biasing the model to learn from these Wikipedia neighbors by placing them in the same batch is not nearly as effective as the theoretically motivated in-context approach. Leading sentence representations employ in-batch techniques (see for example the contrastive setting of Gao et al. (2021b)), and this signal strongly suggests developing in-context parallels. The fact that the original TAPT scheme outperforms the in-batch approaches implies that including the Wikipedia sentences in separate training examples is harmful. We postulate that this is because training examples that contain only Wikipedia sentences dilute the original TAPT signal. Indeed, by this view, the reason that “random, in-context” performs comparably to TAPT is that it does not dilute the original TAPT signal: every training example includes a SentEval example. Overall, the clear advantage of the “neighbors, in-context” kNN-TAPT variant encourages leveraging the in-context bias for TAPT in further tasks.
3.2 kNN-Pretraining
We extended the above to more general kNN-Pretraining, designing pretraining examples with related non-neighboring sentences given only the general pretraining corpus. kNN-Pretraining is also motivated by the kNN-LM results of Khandelwal et al. (2019), who show significant benefits of using nearest neighbors in representation space at inference time. Their results exemplify the potential impact of integrating cross-corpus related examples; our kNN-Pretraining approach provably biases the model to learn these correlations at pretraining time, via the in-context bias.
Specifically, we performed kNN search over Wikipedia sentences for every sentence in Wikipedia, and created each training example similarly to the protocol in the previous subsection. During kNN-Pretraining, half of the batch contained regular pretraining examples and half contained the prepared kNN examples, in order to retain longer-range LM abilities. To examine the effect of kNN-Pretraining, we pretrained GPT-base and GPT-medium architectures from scratch over Wikipedia in the regular pretraining scheme, and switched to kNN-Pretraining at two different points during pretraining. The training examples had a fixed maximal size, and the batch size was 256 for the GPT-base models.
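The batch schedule described above can be sketched as follows. The switch step, batch size, and placeholder example pools are hypothetical, and sampling without replacement within a batch is our assumption: before the switch, batches contain only regular pretraining examples; afterwards, half of each batch comes from the kNN-grouped examples.

```python
import random
from typing import List, Sequence

def make_batch(step: int, switch_step: int, regular: Sequence[str], knn: Sequence[str],
               batch_size: int, rng: random.Random) -> List[str]:
    """Regular pretraining before switch_step; afterwards half regular, half kNN examples."""
    if step < switch_step:
        return rng.sample(list(regular), batch_size)
    n_knn = batch_size // 2
    batch = rng.sample(list(regular), batch_size - n_knn) + rng.sample(list(knn), n_knn)
    rng.shuffle(batch)                       # interleave the two halves
    return batch

rng = random.Random(0)
regular = [f"reg{i}" for i in range(100)]   # placeholder regular pretraining examples
knn = [f"knn{i}" for i in range(100)]       # placeholder kNN-grouped examples
early = make_batch(10, switch_step=1000, regular=regular, knn=knn, batch_size=8, rng=rng)
late = make_batch(2000, switch_step=1000, regular=regular, knn=knn, batch_size=8, rng=rng)
```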
In order to directly probe the acquired ability to integrate non-neighboring sentences, we evaluated the resultant models on the very challenging setup of zero-shot closed-book open domain question answering. In this setup, the unidirectional pretrained model decodes an answer conditioned on the given open-ended question. We evaluated the models on questions from the Natural Questions (NQ) benchmark (Kwiatkowski et al., 2019), using the same phrasing as Brown et al. (2020) and the standard “open-domain” version as used, e.g., by Lee et al. (2019); Asai et al. (2019); Roberts et al. (2020). NQ is composed of questions that have answers within Wikipedia, our pretraining corpus. kNN-Pretraining can intuitively help in cases where the passage containing the answer has elucidating nearest neighbors from across Wikipedia, which help the model to better internalize the answer so that it is more accessible to the model in the zero-shot setting. As figure 1 demonstrates, baseline models, pretrained with the regular scheme, achieve very low F1 scores on this task. In contrast, kNN-Pretraining yields scores that remain low but show a significant improvement.
To increase the credibility of the signal, we evaluated our models on the first 20,000 examples from the NQ training set (we tested zero-shot performance, so the training set was not used earlier). Indeed, the attained F1 scores are low, but they correspond to a meaningful number of correct answers that the kNN-Pretrained model provides after only a fraction of the overall training time, versus far fewer for the randomly initialized baseline models. Finally, we include in appendix E NQ scores of models of different sizes when starting kNN-Pretraining at different checkpoints, and in appendix F zero-shot scores on several GLUE tasks, which demonstrate clear gains of kNN-Pretraining over the baselines.
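For concreteness, a sketch of an F1 scorer of the kind such evaluations typically use. We assume the common SQuAD-style token-level F1 (lower-casing, stripping punctuation and English articles, then taking the best F1 over the gold answers); this is an assumption, not necessarily the exact scorer behind figure 1.

```python
import re
import string
from collections import Counter
from typing import List

PUNCT = set(string.punctuation)

def normalize(text: str) -> List[str]:
    """Lower-case, strip punctuation and English articles, split on whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return text.split()

def token_f1(prediction: str, gold_answers: List[str]) -> float:
    """Best token-level F1 of the prediction against any of the gold answers."""
    pred = normalize(prediction)
    best = 0.0
    for gold in gold_answers:
        ref = normalize(gold)
        if not pred or not ref:              # empty after normalization: exact match only
            best = max(best, float(pred == ref))
            continue
        overlap = sum((Counter(pred) & Counter(ref)).values())
        if overlap == 0:
            continue
        precision, recall = overlap / len(pred), overlap / len(ref)
        best = max(best, 2 * precision * recall / (precision + recall))
    return best

print(token_f1("The Eiffel Tower", ["eiffel tower"]))  # 1.0
print(token_f1("Paris, France", ["Paris"]))            # partial credit, about 0.67
```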
Discussion
Modern NLM pretraining schemes have tremendously advanced the natural language landscape, since they allow powerful models to train on huge amounts of unlabeled text. But NLMs are now challenged with tasks that require deeper and more nuanced understanding of text, and means of improving the basic pretraining process should be considered. For a given architecture, pretraining can be improved by adding more data or by finding more sophisticated training objectives to apply over existing data. In this paper we highlight a parallel path for improvement, which employs the same data and objective, but redistributes the available strength of the Transformer architecture such that important connections within the pretraining corpus are learned more effectively. Specifically, we highlight the bias of the trained NLM towards modeling dependencies between chunks of text that appeared within the same training examples. In current pretraining schemes, this means that dependencies between non-neighboring chunks of text are under-favored. If such dependencies matter for the task at hand, we suggest rearranging the data into corresponding training examples.
We formalize the above notion. Our theoretical setup asks expressivity questions that pertain to the training set rather than to a single example. We thus tie the construction of the training example to the available expressivity of the architecture: we prove that the connections that can be modeled between different training examples are bounded by the connections that can be learned by a shallower and weaker architecture, if these examples were inserted within the same input.
The advantage of including related text in the input of the NLM has been noticed and leveraged empirically. That said, showing the model related data is clearly meaningful even if the data appears in different training examples, and many leading methods elect to do just that. Our quantification of this trade-off is intended to aid informed decisions and to highlight the expressivity advantage to be gained by smarter training example designs. We follow up on these recommendations and demonstrate the immediately available gains to be achieved by designing training examples that include nearest neighbors in embedding space. This method can be enhanced, and other more explicit biases can be introduced. For example, multiple mentions of the same entity, event, or concept can be concatenated within the same training example.
The gains achieved by using similarity in representation space indicate a path for self-improving representations, which we leave for future work. After a first cycle of kNN-Pretraining, the representation is refined, and applying a new kNN search over it can lead to a more informative next round of kNN-Pretraining. This way, deeper insight can be elicited from a given pretraining corpus.
Lastly, while this paper focused on leveraging the identified in-context bias for pretraining, it can also be tied to recent successes of in-context inference methods. From the in-context few-shot prompts of Brown et al. (2020), to in-context augmentations such as in Gao et al. (2020); Schick & Schütze (2020) and many others, the benefits of biasing the prediction by appending text in-context are now widely established. The tools brought forth here can help clarify the theoretical advantages of such practices. Overall, our work aims to provide timely theoretical interpretations to help guide the rapid empirical advances of our field.
References
Appendix A Proof of Corollary 1
Let and be two sentences, the Transformer operation of and be a vocabulary embedding matrix. Then:
Clearly, this form of presenting is separable with respect to , and since it has summands, we can conclude that:
Corollary 1 now follows from an upper bound on given in Levine et al. (2020).
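The argument above uses the fact that a function written as a sum of R terms, each a product of a function of the first sentence and a function of the second, has separation rank at most R. A quick numerical illustration, assuming random stand-in functions tabulated on a finite set of inputs (so that separation rank shows up as the rank of a matrix of values):

```python
import numpy as np

rng = np.random.default_rng(0)
n_x, n_y, R = 40, 50, 3

# R separable terms g_r(x) * h_r(y), tabulated on n_x values of x and n_y values of y.
g = rng.standard_normal((R, n_x))
h = rng.standard_normal((R, n_y))

# F[i, j] = sum_r g_r(x_i) h_r(y_j): a sum of R rank-one matrices.
F = np.einsum("ri,rj->ij", g, h)
rank = np.linalg.matrix_rank(F)
print(rank)  # at most R
```

The bound is tight only when the R terms are linearly independent; coinciding terms would lower the rank.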
Appendix B Upper bound for the ε-separation rank
For an expression that can be represented as a sum of some terms,
denote by the corresponding sum, but with each term replaced by its absolute value, that is:
and note that by the triangle inequality it holds that:
Let be the entry of the analyzed sequential representation defined in eq. 5. Assume that all learned parameters and all gradients are bounded: for some , , , , , In addition, assume that there exists for which it holds that on its domain. Then:
We start by finding a representation of as a sum of terms, where each term is separable with respect to . We then turn to finding a subset of these terms, denoted , such that the sum of all terms in is an -approximation of . Lastly, since it follows from the definition of the -separation rank and the construction of that is upper bounded by the cardinality of (which is the number of summands in the approximation), we find an upper bound to , which is therefore an upper bound to as well, which by definition is equal to .
Following Levine et al. (2020); Wies et al. (2021), can be written as:
Separating to vocab-gradient terms and vocab terms:
Compressing summation to count variable powers:
Pushing in summations on , only the parity matters:
where each summand is separable with respect to .
Now that we have a representation of as a sum of -separable terms, we turn to finding a subsum that can approximate up to an -precision.
First, let us define the set of all legal indices in the last sum:
which is the sum of all terms with indices in . Clearly, summing over all possible indices gives us the original expression:
Given , we wish to find a subset of the indices, , such that the sum of all terms whose indices are in is an -approximation of . That is, we are looking for such that for all :
Note that since we assume that , we can get:
and it follows that it is enough for us to show that:
which is what we show going forward. This will ensure that is an -approximation of .
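The logic of this reduction can be checked numerically: if the terms dropped from a sum have absolute values summing to at most ε, the triangle inequality guarantees that the remaining partial sum is within ε of the full sum. The sketch below uses arbitrary random terms and a simple greedy rule (drop the smallest-magnitude terms) as a stand-in for the specific index subset constructed in this proof.

```python
import numpy as np

def eps_subsum(terms: np.ndarray, eps: float) -> np.ndarray:
    """Indices of kept terms, such that the dropped terms have |.|-sum <= eps.

    Greedy stand-in: drop the smallest-magnitude terms while their
    cumulative absolute sum stays within eps.
    """
    order = np.argsort(np.abs(terms))              # ascending magnitude
    dropped_mass = np.cumsum(np.abs(terms[order]))
    n_drop = int(np.searchsorted(dropped_mass, eps, side="right"))
    return np.sort(order[n_drop:])                 # keep everything else
```

By the triangle inequality, |sum(terms) − sum(terms[keep])| ≤ sum of |dropped terms| ≤ ε, regardless of the signs of the terms.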
Now, we assume that , and by Levine et al. (2020), the s and s in eq. 13 are products of up to matrices, so each of their coordinates is bounded in and we assume without loss of generality that (otherwise we could have picked a smaller and a larger ). Then for each the following inequalities hold:
Relaxing the parity constraints inside the brackets and recalling that gives us an upper bound:
which we can further bound using lemmas 3 and 4 in Levine et al. (2020) until we are left with:
Combining the upper and lower bound for , we get that for all :
so in order to show that (14) holds, it suffices to show that:
Now, we can limit ourselves to subsets of the form:
where the inequality is due to lemma 3 in Levine et al. (2020).
And after rearranging we get that for each such that:
In order for to be an -approximation of .
In the last step we found a condition on subsets of indices, such that summing over any subset that meets this condition yields an -approximation of . We will now find a specific subset that meets this condition, and use it in order to bound from above.
We will focus our attention on s of the form:
for some which we’ll determine later.
B.1 Lemmas for estimating the number of coefficients
For brevity and clarity we will use expressions of the form , regardless of whether is divisible by or not. For the latter case, this expression should actually be:
expressions of the form should be read as:
and expressions of the form should be read as:
Let be a sequence of non-negative integers such that and is maximal. Assume, towards a contradiction, that there exist such that , then:
contrary to the maximality of . Therefore, , . ∎
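Lemma 1 (a multinomial coefficient is maximized when its total is spread as evenly as possible) is easy to confirm by brute force for small cases. The enumeration below is only a sanity check on small numbers, not part of the proof.

```python
from itertools import product
from math import factorial

def multinomial(ks):
    """n! / (k_1! ... k_m!) computed with exact integer arithmetic."""
    out = factorial(sum(ks))
    for k in ks:
        out //= factorial(k)
    return out

def argmax_compositions(n, parts):
    """All (k_1, ..., k_parts) >= 0 summing to n that maximize the multinomial coefficient."""
    comps = [ks for ks in product(range(n + 1), repeat=parts) if sum(ks) == n]
    best = max(multinomial(ks) for ks in comps)
    return [ks for ks in comps if multinomial(ks) == best]

maximizers = argmax_compositions(7, 3)
# Every maximizer has parts differing by at most 1.
print(maximizers)
```

For n = 7 and three parts, the maximizers are exactly the permutations of (2, 2, 3), each giving 7!/(2!·2!·3!) = 210.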
Let be two fixed natural numbers, and denote
Then for all :
where the second equality is due to the fact that:
Induction step: Let such that (17) holds for .
Thus, (17) holds for , and the proof of the induction step is complete. Hence, by induction, (17) is correct for all .
Let be two fixed natural numbers, and . Then the maximum of:
and for all , and .
From lemma 1 we know that a multinomial coefficient reaches its maximum when the sum is evenly distributed between all indices. Since the s and can be chosen independently of each other given and , we may assume without loss of generality that no matter the value of , for all , it holds that and .
is Gaussian-shaped and therefore unimodal and has a unique maximum. for is monotonically decreasing, and therefore their product is also unimodal.
So we get that is monotonically increasing as long as , and the largest integer for which this condition holds is:
Denote by the number of integer lattice points in (the -dimensional zero-centered ball of radius ). Then:
Let be the set of all integer lattice points in . For , define:
Let , so there exists such that , and from the triangle inequality we get:
and therefore . Since was chosen arbitrarily, we get that:
Note that for all and that for such that , is a set of measure zero, hence:
Assume for convenience that , so , and Stirling’s approximation yields:
On the other hand, note that , and therefore:
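The counting argument above places a unit cube around each lattice point and uses the triangle inequality: a cube sticks out of the ball by at most half its diagonal, √d/2. In its standard form this gives V_d(r − √d/2) ≤ N ≤ V_d(r + √d/2), where V_d is the volume of the d-dimensional ball; we state this form as our reading of the elided inequalities. The brute-force check below verifies the sandwich for small d and r.

```python
import math
from itertools import product

def ball_volume(d: int, r: float) -> float:
    """Volume of the d-dimensional Euclidean ball of radius r (0 for r <= 0)."""
    if r <= 0:
        return 0.0
    return math.pi ** (d / 2) * r ** d / math.gamma(d / 2 + 1)

def lattice_points_in_ball(d: int, r: int) -> int:
    """Brute-force count of integer points z with ||z||_2 <= r."""
    coords = range(-r, r + 1)
    return sum(1 for z in product(coords, repeat=d) if sum(c * c for c in z) <= r * r)

d, r = 3, 6
n = lattice_points_in_ball(d, r)
half_diag = math.sqrt(d) / 2  # half the diagonal of a unit cube
lower, upper = ball_volume(d, r - half_diag), ball_volume(d, r + half_diag)
print(lower, n, upper)
```

Brute force is exponential in d, so this is only a check for small cases; the point of the lemma is to avoid such enumeration.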
Let be two fixed natural numbers, a constant sensitivity parameter, and let:
By Stirling’s approximation, it holds that:
and by plugging this approximation to the definition of a multinomial coefficient we get after some rearranging:
Note that using this approximation and some rearrangements, we get that the following are equivalent:
We’ll start by finding a condition that will assure us that (i.e., we will characterize a subset of ).
First, note that by the AM-GM inequality,
Since we’re interested in a subset of , we can show that the last inequality holds when we replace the first term in the left-hand side () with (if the new inequality holds, (19) must hold as well), and we’re left with:
Now, for , let , so:
Observe that for all , it holds that:
so it suffices (again, we’re only interested in a subset of ) to show that:
Let us now turn to finding a condition that will assure us that (i.e., we will characterize a subset of the complement of ). Note that:
so if , we must have that:
and using the same definition of as before, we get:
where the last equality is due to the fact that .
Observing the first order Taylor polynomial of the function:
at with the remainder in the Lagrange form, we get:
Note that is monotonically decreasing with for , and using this fact and the fact that , (22) is lower bounded by:
So we can limit ourselves to looking at the cases where:
Combining the two results together we get:
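The estimates above rely on plugging Stirling's approximation, n! ≈ √(2πn)(n/e)^n, into multinomial coefficients. This is easy to sanity-check in log space; the sketch compares the exact log-multinomial (via `math.lgamma`) with its Stirling form. It illustrates the approximation only, not the rearranged expressions of this lemma.

```python
import math

def log_multinomial_exact(ks):
    n = sum(ks)
    return math.lgamma(n + 1) - sum(math.lgamma(k + 1) for k in ks)

def log_stirling_factorial(n):
    """log of sqrt(2*pi*n) * (n/e)^n; requires n >= 1."""
    return 0.5 * math.log(2 * math.pi * n) + n * math.log(n) - n

def log_multinomial_stirling(ks):
    """Stirling form of log(n! / (k_1! ... k_m!)); every k_i must be >= 1."""
    n = sum(ks)
    return log_stirling_factorial(n) - sum(log_stirling_factorial(k) for k in ks)

ks = [300, 250, 450]
print(log_multinomial_exact(ks) - log_multinomial_stirling(ks))
```

The residual is of order Σ 1/(12·k_i), so the approximation is already very accurate for counts in the hundreds.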
Let be two fixed natural numbers, and a constant sensitivity parameter. Then the number of multinomial coefficients, , which uphold:
Let be a multinomial coefficient for which it holds that:
so in order to bound the cardinality of (which is the quantity we are interested in), we can find an upper bound on the cardinality of and a lower bound on the cardinality of .
Let and , and denote:
For , denote . So the problem has changed to finding the number of integer -tuples such that and , which is the number of integer lattice points in the zero-centered -dimensional ball of radius that uphold .
and the cardinality of is lower bounded by:
Let be a fixed natural number, , and a constant sensitivity parameter. Then the number of integer s such that is upper bounded by:
Denote , so:
Recall that by Stirling’s approximation we know that:
approximates the number we are trying to quantify.
After some rearranging, one can observe that (23) is equivalent to:
and since for all , , the number of s which uphold (24) is upper bounded by the number of integer s for which it holds that:
Recall that for inequalities of the form where , the set of all values of which satisfy this inequality is and the number of integer values of which satisfy this condition is approximately (the interval’s length).
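The interval-length estimate can be illustrated concretely. We assume here that the inequality is a quadratic a·x² + b·x + c ≤ 0 with a > 0 (the exact form is not legible in this version). Its solution set is the closed interval between the two real roots, of length √(b² − 4ac)/a, and the number of integers in a closed interval differs from its length by at most 1.

```python
import math

def quad_interval(a, b, c):
    """Solution interval of a*x^2 + b*x + c <= 0, assuming a > 0; None if empty."""
    disc = b * b - 4 * a * c
    if disc < 0:
        return None  # no real solutions
    root = math.sqrt(disc)
    return (-b - root) / (2 * a), (-b + root) / (2 * a)

def count_integer_solutions(a, b, c):
    """Number of integers in the solution interval."""
    interval = quad_interval(a, b, c)
    if interval is None:
        return 0
    lo, hi = interval
    return max(0, math.floor(hi) - math.ceil(lo) + 1)

a, b, c = 0.5, -3.0, -20.0
lo, hi = quad_interval(a, b, c)
length = hi - lo
n = count_integer_solutions(a, b, c)
print(length, n)
```

Here the roots are −4 and 10, so the interval has length 14 and contains 15 integers.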
Let be two fixed natural numbers, , and a constant sensitivity parameter. Denote:
and let . If , then the number of which uphold is bounded from above by:
By lemma 3, . By lemma 7, the number of s between such that is upper bounded by , by lemma 6 the number of non-negative integer -tuples such that and is bounded from above by , and the number of non-negative integer -tuples such that and is bounded from above by . In total, without taking into consideration the interactions between the three multiplicands (so our bound is not tight), the number of which uphold is upper bounded by:
and since and (and hence ), this can be further bounded by:
Let be two fixed natural numbers, , and a constant sensitivity parameter. Denote:
and let . If , then the number of which uphold is bounded from below by:
By lemma 6, the number of non-negative integer -tuples such that and is bounded from below by . Since these s are only a subset of the elements of which takes into consideration, this is also a (quite loose) lower bound on the number of which uphold .
and since , it holds that , and we get:
Appendix C Lower bounds on the ε-separation rank
We begin by laying out basic concepts in tensor theory required for the upcoming analysis. The core concept of a tensor may be thought of as a multi-dimensional array. The order of a tensor is defined to be the number of indexing entries in the array, referred to as modes. The dimension of a tensor in a particular mode is defined as the number of values taken by the index in that mode. If is a tensor of order and dimension in each mode , its entries are denoted , where the index in each mode takes values .
We now present the concept of grid tensors, which are a form of function discretization (Hackbusch, 2012). Essentially, the function is evaluated for a set of points on an exponentially large grid in the input space and the outcomes are stored in a tensor. Formally, fixing a set of template vectors , the points on the grid are the set . Given a function , the set of its values on the grid arranged in the form of a tensor are called the grid tensor induced by , denoted .
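Grid tensors can be computed directly for small cases. The sketch below assumes a toy function of N = 4 vector inputs and M = 3 random template vectors in R² (both our choices), builds the order-N grid tensor, and matricizes it with respect to the partition of inputs {1, 2} versus {3, 4}. The rank of this matrix lower-bounds the separation rank for that partition: a product-separable function gives rank 1, while an inner product coupling the two halves gives rank up to the embedding dimension.

```python
import numpy as np
from itertools import product

def grid_tensor(f, templates, n_inputs):
    """A[d1, ..., dN] = f(x^{d1}, ..., x^{dN}) over all combinations of templates."""
    M = len(templates)
    A = np.empty((M,) * n_inputs)
    for idx in product(range(M), repeat=n_inputs):
        A[idx] = f(*(templates[i] for i in idx))
    return A

def matricize(A, n_first):
    """Rows index the first n_first modes, columns index the remaining modes."""
    M = A.shape[0]
    return A.reshape(M ** n_first, -1)

rng = np.random.default_rng(0)
templates = rng.standard_normal((3, 2))  # M = 3 templates in R^2 (toy choice)

def separable(x1, x2, x3, x4):
    return np.tanh(x1 @ x2) * np.cos(x3 @ x4)

def coupled(x1, x2, x3, x4):
    return (x1 + x2) @ (x3 + x4)  # inner product couples the two halves

for f in (separable, coupled):
    print(f.__name__, np.linalg.matrix_rank(matricize(grid_tensor(f, templates, 4), 2)))
```

The number of grid points grows as M^N, which is why grid tensors serve as an analysis tool rather than something one computes for realistic N.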
C.1.2 ε-rank
We will make use of the concept of the ε-rank (Alon et al., 2013) of a matrix , defined for any ε as the minimum rank over matrices that approximate every entry of to within an additive ε. We will prove lower bounds on the s for which the ε-rank of a matrix remains high, using the following lemma:
Finally, we will use the following lemma for lower bounding the number of small eigenvalues of symmetric matrices:
Since the trace of a matrix equals both the sum of its eigenvalues and the sum of its diagonal entries, we get:
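The ε-rank defined above is hard to compute exactly, but a truncated SVD gives a simple upper bound: any rank-k matrix whose entries are all within ε of A certifies that the ε-rank of A is at most k. The sketch below, our illustration and not a method from this paper, returns the smallest k for which the rank-k SVD truncation happens to be entrywise ε-close. The true ε-rank may be smaller, since the SVD truncation is optimal in Frobenius norm, not entrywise.

```python
import numpy as np

def eps_rank_upper_bound(A: np.ndarray, eps: float) -> int:
    """Smallest k such that the rank-k truncated SVD of A is within eps of every entry.

    This upper-bounds the eps-rank (minimum rank over entrywise eps-approximations).
    """
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    for k in range(len(s) + 1):
        approx = (U[:, :k] * s[:k]) @ Vt[:k, :]  # k = 0 gives the zero matrix
        if np.max(np.abs(A - approx)) <= eps:
            return k
    return len(s)  # reached only if eps is below floating-point reconstruction error
```

Lower bounds, which are what this appendix needs, cannot be certified this way; they require arguments such as the eigenvalue lemma above.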
C.1.3 High-Dimensional Spheres
Finally, we will use a well-known fact regarding the variation of the sphere volume for different radii (see for example Smith & Vamanamurthy (1989)):
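The fact cited here presumably rests on the scaling V_d(r) = π^{d/2} r^d / Γ(d/2 + 1), which is proportional to r^d, so volume ratios between radii depend only on the radius ratio raised to the power d. The short check below, our illustration, shows how quickly the fraction of volume inside radius (1 − δ)r shrinks with the dimension.

```python
import math

def ball_volume(d, r):
    """Volume of the d-dimensional Euclidean ball of radius r."""
    return math.pi ** (d / 2) * r ** d / math.gamma(d / 2 + 1)

delta = 0.05
for d in (2, 10, 100):
    ratio = ball_volume(d, 1 - delta) / ball_volume(d, 1.0)
    print(d, ratio, (1 - delta) ** d)  # the two columns agree
```

In 100 dimensions, less than 1% of the volume lies inside 95% of the radius.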
C.2 Proof of the lower bound
In this subsection, we prove theorem 2 of the main text. We will follow the proofs of Levine et al. (2020); Wies et al. (2021), with important adjustments to the -sequential-separation rank definition.
We begin by showing that high -rank Alon et al. (2013) of the grid tensor matricization implies high -sequential-separation rank (see section 2.2 of the main text) of the function. Essentially, we apply claim 1 from Levine et al. (2020) to -approximations obtained from the -separation-rank definition. This relation, which holds for all functions, is formulated below for functions realized by the analyzed Transformer network:
where is the grid tensor of with respect to the above template vectors.
where .
Now we show that is indeed able to produce vectors that do not change the analysis in Levine et al. (2020), and that the assumptions of corollary 2 hold.
where .
We will ignore ’s element-wise multiplication with vocabulary embedding matrix by choosing (by the terms of corollary 2 it suffices to find any assignment of the learned weights).
For any , our template vectors will be:
To implement summation of the input embeddings in the first self-attention layer, we follow Levine et al. (2020) and set the first-layer self-attention key and query weights to:
This assignment implements summation of the input embeddings in the first self-attention layer since:
where holds because are matrices that are zero everywhere except for entry , and all the entries in the vocabulary embedding matrix equal , and holds because of linearity. Therefore, for any , the output of the first self-attention layer on is:
Finally, we need to show that for any , eq. 44 indeed gives the desired :
The third and fourth cases are clear from the definition, so it remains to prove the first and second cases. For this we will examine . and therefore:
and therefore:
So the first and second cases hold as well. ∎
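The summation step in the proof above can be sketched numerically. As we read the text, we assume a softmax-free (linear) self-attention, query and key matrices that are zero except for a single entry (placed here at (0, 0) with value 1), and a coordinate equal to 1 in every input embedding. Under these assumptions every attention score equals 1, so each position outputs the sum of all value vectors. The dimensions, the value matrix, and the position of the nonzero entry are arbitrary illustrative choices.

```python
import numpy as np

def linear_self_attention(X, W_Q, W_K, W_V):
    """Softmax-free self-attention: y_i = sum_j <W_Q x_i, W_K x_j> W_V x_j."""
    scores = (X @ W_Q.T) @ (X @ W_K.T).T  # scores[i, j] = <W_Q x_i, W_K x_j>
    return scores @ (X @ W_V.T)

d, n = 4, 5
rng = np.random.default_rng(0)
X = rng.standard_normal((n, d))
X[:, 0] = 1.0  # assumed constant coordinate in every embedding

E = np.zeros((d, d))
E[0, 0] = 1.0  # single nonzero entry (assumed position and value)
W_V = rng.standard_normal((d, d))

Y = linear_self_attention(X, E, E, W_V)
# Every position receives the same vector: the sum over positions of W_V x_j.
expected = (X @ W_V.T).sum(axis=0)
```

With a softmax in place of the linear scores, the same construction would produce the mean rather than the sum of the value vectors.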
Returning to finding for which , we will use the probabilistic method for proving the existence of such , i.e., we will show that for random the expectation of , and therefore in particular there exists such .
We start by bounding the expectation of the squared norm:
Finally, by Jensen's inequality we have that:
C.3 Technical lemmas
where the first equality holds because is orthogonal. Therefore, by choosing such that we will get that:
Now we can calculate the last expectation directly:
Now, by lemma 12 and fact 1 we have that:
Finally, by lemma 17 each term in the integral is upper bounded by and thus:
Note that since at the boundaries , it is enough to prove the inequality for critical points.
Therefore, is the only critical point and:
where the last inequality follows from the fact that:
Appendix D Experimental Details
We conducted the network training described in section 3.1 of the main text with the AdamW optimizer (with the parameters suggested in the original RoBERTa paper: , , and weight decay of 0.01), with batch sizes of 128 or 256 (depending on model size) and sequences of 256 tokens each. We started from pretrained RoBERTa-base weights from the HuggingFace Transformers repository https://huggingface.co/transformers/, and continued training them on the MLM task with a masking probability of 15%, where each masked token had an 80% probability of being replaced with the special mask token, a 10% probability of being replaced with a random token, and a 10% probability of being kept the same. The data used for this phase of training was created using the four different procedures described in section 3.1. After training finished, we evaluated the models' performance using the SentEval toolkit.
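The masking scheme described above (select 15% of the tokens; of those, 80% become the mask token, 10% a random token, and 10% stay unchanged) can be sketched in NumPy. The token ids, `mask_id`, `vocab_size`, and the −100 ignore label are illustrative assumptions; the −100 convention mirrors common PyTorch/HuggingFace practice but is not stated in the text.

```python
import numpy as np

def mask_tokens(ids, mask_id, vocab_size, rng, mlm_prob=0.15):
    """Return (corrupted_ids, labels); labels are -100 where no prediction is made."""
    ids = np.asarray(ids)
    selected = rng.random(ids.shape) < mlm_prob
    labels = np.where(selected, ids, -100)

    corrupted = ids.copy()
    to_mask = selected & (rng.random(ids.shape) < 0.8)  # 80% of selected
    corrupted[to_mask] = mask_id
    # Half of the remaining 20% of selected positions get a random token.
    to_random = selected & ~to_mask & (rng.random(ids.shape) < 0.5)
    corrupted[to_random] = rng.integers(0, vocab_size, size=int(to_random.sum()))
    # The final 10% of selected positions keep their original token.
    return corrupted, labels
```

Keeping some selected tokens unchanged means the model cannot assume that every unmasked input token is correct, which reduces the mismatch between pretraining and fine-tuning inputs.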
D.2 kNN-Pretraining
We conducted the network training described in section 3.2 of the main text with the AdamW optimizer (with the parameters suggested in the original GPT-2 paper: , , and weight decay of 0.1), with a batch size of 512 and sequences of 256 tokens each. We pretrained a HuggingFace Transformers implementation of GPT-2 from scratch on Wikipedia with the standard LM objective, and switched to a mixture of the standard data and our generated kNN data at two different points during training. After training finished, we evaluated the models' performance on the Natural Questions benchmark.
Appendix E kNN-Pretraining at different checkpoints
The following table includes F1 evaluation scores on zero-shot closed-book Natural Questions examples for different model sizes at different training checkpoints. Overall, further pretraining seems to improve the effectiveness of kNN-Pretraining.
Appendix F kNN-Pretraining on additional benchmarks
The main text describes experiments on the Natural Questions dataset. We test how kNN-Pretraining affects other NLU tasks by examining several tasks from the GLUE benchmark (Wang et al., 2018): Multi-Genre Natural Language Inference (MNLI) (Williams et al., 2017), Recognizing Textual Entailment (RTE) (Dagan et al. (2010) and others), and the Winograd Schema Challenge (WNLI) (Levesque et al., 2012). As in the case of Natural Questions, we evaluate the zero-shot performance of our models, since it is a direct probe of the abilities of the model straight after pretraining. In contrast to Natural Questions, the GLUE tasks we examined are classification tasks and not generation tasks, so assessing zero-shot performance on them is not straightforward. We therefore follow the template-based method of Gao et al. (2021a) for converting the tasks' data into a format processable by unidirectional language models.
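Template-based zero-shot classification with a unidirectional LM typically places the input in a prompt and compares the model's likelihood of each label word. The sketch below assumes a hypothetical `log_prob(context, continuation)` scorer, a made-up entailment template, and a made-up verbalizer; these are not the templates of Gao et al. (2021a). A toy word-overlap scorer stands in for the LM so the example runs.

```python
from typing import Callable, Dict

def zero_shot_classify(
    premise: str,
    hypothesis: str,
    verbalizer: Dict[str, str],
    log_prob: Callable[[str, str], float],
) -> str:
    """Pick the label whose verbalized word the LM finds most likely after the prompt."""
    prompt = f"{premise} Question: {hypothesis}? Answer:"  # hypothetical template
    scores = {label: log_prob(prompt, " " + word) for label, word in verbalizer.items()}
    return max(scores, key=scores.get)

def toy_log_prob(context: str, continuation: str) -> float:
    """Stand-in scorer: favors 'Yes' iff every hypothesis word occurs in the premise."""
    premise, _, rest = context.partition(" Question: ")
    hyp = rest.rsplit("? Answer:", 1)[0]
    supported = set(hyp.lower().split()) <= set(premise.lower().split())
    return 0.0 if (continuation.strip() == "Yes") == supported else -1.0

verbalizer = {"entailment": "Yes", "not_entailment": "No"}
print(zero_shot_classify("the cat sat on the mat", "the cat sat", verbalizer, toy_log_prob))
```

With a real LM, `log_prob` would sum token log-probabilities of the continuation; multi-token label words make the choice of verbalizer matter.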
Notably, the examined GLUE classification tasks are not easy for the examined unidirectional models in the zero-shot setting. Table 3 includes the zero-shot scores of the M parameter model that was trained regularly for K steps and then continued training for K steps of kNN-Pretraining, versus the average of 3 baselines that were trained regularly for the same overall number of steps (the same models used in figure 1). Similarly to the results on Natural Questions (figure 1), all examined models score only slightly better than random guessing on the examined GLUE tasks. However (and again similarly to the case of Natural Questions), we get a clear signal that kNN-Pretraining significantly moves the needle when applied for just of the regular pretraining time. We conjecture that with stronger models (trained for longer and over more data), the positive effect of kNN-Pretraining will be enhanced, since as the model improves, it can better understand and utilize the various in-context hints that kNN-Pretraining provides.