The Inductive Bias of In-Context Learning: Rethinking Pretraining Example Design

Yoav Levine, Noam Wies, Daniel Jannai, Dan Navon, Yedid Hoshen, Amnon Shashua

1 Introduction

Beyond excelling in their core task of pure language modeling, modern Neural Language Models (NLMs) show impressive zero- and few-shot abilities in more general Natural Language Understanding (NLU) tasks (Brown et al., 2020). This implies that the training corpus contains the information required for performing such tasks, and moreover it implies that the common pretraining process grants the trained NLM some access to these higher level capabilities. In this paper, we highlight a connection between the quality of the emergent NLU capabilities and a basic component in the NLM training scheme: the process of segmenting the corpus into training examples.

Specifically, NLMs self-train over huge training corpora (typically, billions to trillions of words). A basic, automatic operation in the training pipeline is to segment these corpora into training examples: contiguous text chunks of sizes processable by the neural architecture (typically, up to thousands of words). We formalize an expressivity bias that this segmentation process introduces, to be referred to as the in-context bias, which directly affects the NLM’s ability to integrate cross-corpus information. We show that the NLM can model much stronger dependencies between sentences that were shown together at least once in-context, i.e., in the same training example, than between sentences that were never shown together in the same input. This inductive bias may be good for language modeling, but it implies that NLU capabilities that involve integrating information from different examples across the corpus (see, e.g., figure 1) are under-favored by design in the current setting. Thus, if one sentence in the corpus can elucidate the meaning of another sentence (e.g., it defines a hard concept or provides auxiliary information), our result implies that a model that saw them in different training examples will benefit less from this elucidation than a model that saw them in the same training example.

While standard approximation results examine the expressivity of an architecture over a single input, our theoretical approach pertains to the entire training process, and examines the expressive capacity of the resultant NLM with respect to the training set. Therefore, our approximation result ties an optimization parameter (the learning rate) to the regular NLM architecture expressivity parameters (depth, width). Intuitively, sentences that were never shown in the same input can only access each other via the weights of the network during training. The mechanism for “storing” information in the network involves a very small learning-rate term η; our analysis formalizes and quantifies an “expressivity toll” that the model pays when making use of such harder-to-access stored information.

We employ the tool of a function’s separation rank with respect to subsets of its variables, which quantifies its ability to model input dependencies between these subsets. The separation rank was employed for analyzing the dependencies modeled by convolutional (Cohen & Shashua, 2017), recurrent (Levine et al., 2018a), and self-attention (Levine et al., 2020) networks with respect to a single input example. In order to analyze an NLM’s ability to model dependencies between different training examples, we refine the usage of this measure in two ways: (1) we introduce the ε-separation rank, which measures the effective ability of a function to model dependencies in a finite-precision setting, and (2) we modify the separation rank such that it can account for the more intricate mechanism of mixing between variables that occurs in the sequential case.

Several recent works intuitively rely on the in-context expressivity bias formalized above, in different ways, and significantly improve both task-specific training and pretraining of NLMs. Gao et al. (2020) advance the frontier in k-shot learning via finetuning. They show that by concatenating several related training examples per input, instead of following the standard fine-tuning practice of one example per input, the k-shot performance on sentence similarity tasks is considerably boosted. Another example was pointed out in Humeau et al. (2020); Thakur et al. (2020): when training for sentence similarity tasks, including both sentences in the same input leads to a performance gain of around 10 points relative to separately encoding each sentence. In the challenging setting of open-domain question answering, Izacard & Grave (2020) jointly attend to all documents that may contain the answer, and show large gains relative to prior methods that consider these documents in separate forward passes.

Turning our focus to methods that leverage the in-context bias for improved pretraining, the most straightforward effort is a body of work aimed at reducing the quadratic dependence of the Transformer computation on input sequence length (Tay et al., 2020). While allowing for more text in-context during training, this does not improve the model’s ability to integrate text across different documents in the corpus. The following approaches take a further step and enable direct cross-corpus connections during pretraining. Lewis et al. (2020) attend to related documents when maximizing the likelihood of a target document. The scope of related documents is restricted by meta-data: taken from the same Wikipedia entry as the input, or published on the same date. Guu et al. (2020) expand the scope of the related documents, by training a Knowledge-Retrieval model that has access to the entire Wikipedia corpus. They retrieve several related documents per target document, but condition on each related document independently. Outside of the natural language domain, Rao et al. (2021) train a Transformer-based protein-LM that receives multiple related protein sequences in-context. Their protein-LM surpasses previous methods, which process one sequence per input, by a wide margin, with significant parameter efficiency.

1.2 Leveraging the in-context bias for NLU oriented training

Though the in-context bias is intuitive, the works surveyed above leverage it in non-trivial ways. Once the theoretical advantage of in-context integration of related text is formalized, the roots of these successes can be unified and, importantly, new methods for tilting the pretraining bias towards NLU tasks are indicated. Following the presentation of our theoretical results in section 2, we detail in section 3 two controlled experiments that exemplify new methods directly leveraging the in-context bias.

Our first experiment augments the Task Adaptive PreTraining (TAPT) setting of Gururangan et al. (2020), in which an NLM that was pretrained on a general corpus continues pretraining (with its original objective) on the training set of an NLU task. We perform TAPT on the SentEval sentence similarity benchmark (Conneau & Kiela, 2018), and during TAPT introduce the following augmentation: along with SentEval sentences, we simultaneously pretrain on related sentences from Wikipedia, the general pretraining corpus. The related sentences are found via k-Nearest Neighbors (kNN) search between the embeddings of SentEval examples and all Wikipedia sentences; we thus dub this approach kNN-TAPT. Importantly, during kNN-TAPT, each input includes a training example from the task, followed in-context by its Wikipedia neighbors. We demonstrate significant gains of kNN-TAPT over regular TAPT on SentEval sentence similarity tasks. A dedicated ablation study shows the significance of adding the general corpus neighbors in-context, versus in separate training examples, during kNN-TAPT.

Our second experiment introduces a task-independent pretraining phase, dubbed kNN-Pretraining. As in kNN-TAPT, we group together sentences with similar sentence representations in the same training example, but in kNN-Pretraining we use only sentences from the general pretraining corpus. This can be viewed as a sentence-focused variation of the pretraining schemes of Lewis et al. (2020) and Guu et al. (2020) surveyed above, who operate on full documents (up to 512 each), and is very similar to RETRO by Borgeaud et al. (2021) (DeepMind), who show the benefits of this approach given much larger resources. Figure 1 shows that after regular pretraining for 200K steps on Wikipedia, the zero-shot closed-book performance of 3 different randomly initialized GPT2-medium models (345M parameters) on open-domain questions from Wikipedia (Kwiatkowski et al., 2019) is very low (correct on fewer than 50 questions out of 20,000 in the evaluation set). Adding kNN-Pretraining for 20K steps raises performance significantly (correct on roughly 250 questions in the evaluation set), reflecting the enhanced ability to integrate knowledge from related sentences, acquired via the in-context bias.

We formally establish the in-context bias: information within pretraining examples is better represented than information integrated across pretraining examples.

We ask and answer a new type of network expressivity question: how expressive is a network with respect to examples seen during its training process?

We demonstrate that in-context bias motivated “pretraining example design" elicits better representations from the same data: kNN-Pretraining improves on several NLU tasks.

2 Theoretical analysis: The in-context bias of self-attention

In this section, we consider the entire NLM training procedure as a functional that receives an unlabeled training corpus and outputs a trained NLM. Our analysis focuses on the corpus segmentation into training examples as a hyper-parameter of this functional. We reduce the high-level notion of representing “cross-corpus correlations" to a quintessential case study: we quantify the NLM’s ability to model dependencies between two sentences that appear in the same training example (the in-context representation) and in different training examples (the sequential representation).

We believe that the in-context bias can be shown to exist in a broad range of architectures, but we focus on self-attention since almost all modern NLMs are based on the Transformer architecture of Vaswani et al. (2017). Our theoretical framework is based on that of Levine et al. (2020); Wies et al. (2021), who analyze a simplified, theoretically accessible, self-attention network. They study the expressivity of this self-attention architecture with respect to its input, and use a measure of a multivariate function’s ability to correlate two subsets of its variable set, referred to as the separation rank. The analyzed framework captures the connectivity of self-attention but omits its softmax and ReLU non-linearities (see eq. 1 below). We refer the reader to Levine et al. (2020); Wies et al. (2021) for a discussion on the impact of these relaxations. Essentially, they are shown to weaken the overall network power but still allow a meaningful comparison of the self-attention integration abilities. Importantly, both works derive unforeseen theoretical conclusions from analyses of the separation rank measure for this architecture class, and then provide extensive empirical corroboration for their manifestation in common Transformer architectures, reinforcing the relevance of this setting. In the following, we describe in section 2.1 the analyzed in-context and sequential self-attention representations of two sentences. Then, in section 2.2, we present the separation rank, which we use in section 2.3 for quantifying the advantage of in-context representations versus sequential ones.

For simplicity of presentation, we examine two sentences $S_{1}$ and $S_{2}$ of equal length $N$: $S_{1}=\{w_{1}^{j}\}_{j=1}^{N}$ and $S_{2}=\{w_{2}^{j}\}_{j=1}^{N}$. The in-context representation simply concatenates both in the input:

For the sequential approach, we consider a setup in which sentence $S_{1}$ is inserted into the network at training step $t$ and sentence $S_{2}$ is inserted into the network at training step $t+1$. The output of the network at training step $t$ is therefore $\mathbf{y}^{i,L,d_{x}}_{\mathcal{W}_{t},M^{\textrm{V}}_{t}}(S_{1})$, where $\mathcal{W}_{t},M^{\textrm{V}}_{t}$ stand for all the learned weights before training step $t$. Focusing on autoregressive NLMs for simplicity of presentation (the analysis holds for bidirectional NLMs as well), the log-likelihood loss is given by $\mathcal{L}(S_{1})=-\sum_{j=1}^{N}\log\left[\left(\mathrm{softmax}\left\{\left(M_{t}^{\textrm{V}}\right)^{\top}\mathbf{y}^{j,L,d_{x}}_{\mathcal{W}_{t},M^{\textrm{V}}_{t}}(S_{1})\right\}\right)_{w_{1}^{j+1}}\right]$, and the gradient update for any learned weight $\theta\in\{\mathcal{W}_{t},M^{\textrm{V}}_{t}\}$ is $\theta_{t+1}(S_{1};\eta)=\theta_{t}-\eta\cdot\partial\mathcal{L}(S_{1})/\partial\theta_{t}$, where η is the learning rate. Accordingly, the analyzed sequential representation is the network output after training step $t+1$:

In practice, two relevant non-neighboring sentences are not necessarily shown in consecutive pretraining steps. In comparison to the realistic scenario of $S_{1}$ and $S_{2}$ appearing at arbitrary training steps, this simplification tilts the representation in favor of modeling high correlations between $S_{1}$ and $S_{2}$. Thus, by upper bounding the ability to correlate $S_{1}$ and $S_{2}$ in the setting of eq. 5 (as we do in section 2.3), we establish an inherent limitation of the network to access information that was stored in its weights via the gradient update mechanism. In the next subsection, we present our approach for measuring a network’s ability to correlate two sentences seen during training, which we use to separate the in-context and sequential settings.
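The loss and single gradient step above can be illustrated with a minimal NumPy sketch. It rests on a simplifying assumption that is not part of the analyzed architecture: the self-attention stack is replaced by the identity, so the output at position j is just the embedding column `M[:, w^j]`, and one matrix `M` plays the role of both the input embedding and the output projection $M^{\textrm{V}}$. The sum runs only over positions that have a next token. The function names are hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # numerical stability
    e = np.exp(z)
    return e / e.sum()

def nll_and_grad(M, sentence):
    """Next-token negative log-likelihood of `sentence` (list of token ids)
    under a toy model whose network is the identity: y^j = M[:, w^j] and
    logits^j = M^T y^j. Returns (loss, dL/dM)."""
    loss = 0.0
    grad = np.zeros_like(M)
    for j in range(len(sentence) - 1):     # positions that have a next token
        w, w_next = sentence[j], sentence[j + 1]
        e = M[:, w]                        # input embedding of token w
        p = softmax(M.T @ e)               # predicted next-token distribution
        loss -= np.log(p[w_next])
        delta = p.copy()
        delta[w_next] -= 1.0               # dLoss/dlogits = p - onehot(w_next)
        grad += np.outer(e, delta)         # path through the output projection M^T
        grad[:, w] += M @ delta            # path through the input embedding M[:, w]
    return loss, grad

def sgd_step(M, sentence, eta):
    """theta_{t+1}(S1; eta) = theta_t - eta * dL(S1)/dtheta_t."""
    _, grad = nll_and_grad(M, sentence)
    return M - eta * grad

# Usage: one update on a hypothetical S1 with d_x = 8 and a 20-token vocabulary.
rng = np.random.default_rng(0)
M_t = rng.normal(scale=0.1, size=(8, 20))
S1 = [3, 7, 1, 7, 12]
M_t1 = sgd_step(M_t, S1, eta=1e-4)
```

In this toy, `M_t1 - M_t` is exactly η times the negative gradient, so anything $S_{1}$ contributes to the later processing of $S_{2}$ enters through a term scaled by η.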

2.2 A measure for modeling in-context and sequential dependencies

In this section, we refine the separation rank, used in prior work to analyze the dependencies between two sentences appended in-context. In section 2.2.1 we present the separation rank and introduce a finite-precision refinement of it, referred to as the effective separation rank, which helps to elucidate the degradation in integration ability caused by the gradient update mechanism. In section 2.2.2 we point out a structural problem in employing the separation rank in the same manner as prior work that analyzed only architecture expressivity, and introduce the sequential separation rank, which is meaningful for both the in-context and sequential cases.

The separation rank, introduced in Beylkin & Mohlenkamp (2002) for high-dimensional numerical analysis, was employed for various applications, e.g., chemistry (Harrison et al., 2003), particle engineering (Hackbusch, 2006), and machine learning (Beylkin et al., 2009). More recently, the separation rank has been established as a measure of dependencies modeled by deep convolutional and recurrent networks w.r.t. their inputs (Cohen & Shashua, 2017; Cohen et al., 2017; Levine et al., 2018a), and tied to quantum entanglement measures for proving that these deep learning architectures can model elaborate many-body quantum particle correlations (Levine et al., 2018b; 2019; Sharir et al., 2020). Recently, Levine et al. (2020); Wies et al. (2021) employed this measure for studying the expressivity of a self-attention architecture with respect to its input.

In words, if a function has a high separation rank, but it can be approximated up to error ε by a function with a low separation rank, then it has a low ε-separation rank.

Prior works compare two functions by establishing the differences between their separation ranks. In principle, these differences could manifest only in irrelevant magnitudes (for example, if many of the summands in the separation rank definition are negligibly small for the function with the higher separation rank). The effective separation rank is key to our analysis because we rely on the fact that information on past examples is stored in the network weights in a small magnitude (due to a small learning rate). We show in section 2.3 that much of the integration between text segments from different training examples occurs in very small magnitudes due to high powers of the learning rate, limiting the effective integration, as measured by the ε-separation rank. Our techniques for bounding the ε-separation rank are extendable to prior works, and while these did not examine the gradient update mechanism, their results can be reinforced by the guarantees of this introduced measure.

2.2.2 The sequential separation rank

Levine et al. (2020), who were the first to apply the separation rank to functions realized by Transformer architectures, studied classical architecture expressivity questions which apply only to the in-context representation. Accordingly, they analyzed only the separation rank of $\mathbf{g}^{i,L,d_{x}}$, defined in eq. 1, and the input variables considered for calculating the separation rank were the word embedding vectors. A fundamental difficulty arises when attempting to directly apply this method to the sequential representation: the word embedding vectors are learned parameters of the architecture. In the sequential case, when the second sentence $S_{2}$ is introduced after the calculation at time step $t$, the vectors used to describe it, if we were to follow prior practice, would already depend on $S_{1}$.

where ⊙ denotes element-wise multiplication. We define the sentence association operation over the analyzed representations, denoted $\mathcal{Z}_{y}(\mathbf{a},\mathbf{b})$ with $y\in\{\mathbf{y}^{i,L,d_{x}}_{\textrm{in-context}}(S_{1},S_{2}),\mathbf{y}^{i,L,d_{x},\eta}_{\textrm{sequential}}(S_{1},S_{2})\}$ (eqs. 4 or 5), to be the application of the sentence association layer of eq. 8 to all uses of the input embedding layer during the computation of $y$. That is, for both mechanisms, the chosen word embeddings are marked with the identity of the sentence that invoked them. Finally, we define the following specialization of the separation rank measure to our setting, referred to as the sequential separation rank of $y\in\{\mathbf{y}^{i,L,d_{x}}_{\textrm{in-context}}(S_{1},S_{2}),\mathbf{y}^{i,L,d_{x},\eta}_{\textrm{sequential}}(S_{1},S_{2})\}$:

Clearly, when the introduced variables are all-ones vectors $\mathbf{1}$, the auxiliary layer in eq. 8 is the identity operation, and so $\mathcal{Z}_{y}(\mathbf{1},\mathbf{1})=y$ for both representations. More deeply, our expressivity questions query the ability of the in-context and sequential mechanisms to integrate two sets of variables, and $\mathcal{Z}_{y}$ captures the essence of this ability by explicating where each set enters the computation.

In the next subsection, we show that for the in-context case, analyzed in prior work, the introduced measure of the sequential separation rank is asymptotically equal to the previously employed separation rank w.r.t. a partition of the input word embeddings (Levine et al., 2020). Thus, the properties of the existing framework are unchanged under the new definition. At the same time, for the sequential case brought forth in this paper, the sequential separation rank considers both the effect of $S_{1}$ on the gradient-updated word embedding and the introduction of $S_{2}$ into the computation (to see this, note that for $s=2$ the operation in eq. 8 includes both $M^{\textrm{V}}_{t+1}(\mathbf{a})$ and variables from $\mathbf{b}$). In the following subsection, we make use of both extensions to the separation rank in eqs. 7 and 9 in order to establish the in-context bias.

2.3 The expressive advantage of in-context learning

We show below that the function computed by a self-attention-based NLM when inserting sentences $S_{1}$ and $S_{2}$ together in its input (the in-context representation) can model more elaborate dependencies between $S_{1}$ and $S_{2}$ than the function attained when showing $S_{1}$ in the input, modifying the network’s weights according to its loss, and then showing $S_{2}$ in a subsequent input (the sequential representation). We begin by stating the following corollary, which follows from theorem 2 in Levine et al. (2020) and proposition 1 in appendix A, and which upper bounds the sequential separation rank of the in-context representation:

However, the ε\varepsilon-separation rank of the sequential representation is upper bounded by a lower term:

Theorem 1 (see proof in appendix B). Let $y^{(p,i),L,d_{x},\eta}_{\textrm{sequential}}$ be the $p\in[d_{x}]$ entry of the analyzed sequential representation defined in eq. 5. Assume that all learned parameters and all gradients are bounded, $\forall\theta\in\{\mathcal{W},M^{\textrm{V}}\}:\ 0<\Lambda_{\min}\leq\lvert\theta\rvert,\ \left\lvert\partial\mathcal{L}(S_{1})/\partial\theta\right\rvert\leq\Lambda_{\max}$ (the upper-boundedness assumption resembles the practices of gradient clipping and weight decay, and the lower-boundedness assumption resembles finite precision), that $N<d_{x}$, and that $L>\log_{3}d_{x}$. Then, $\forall\varepsilon>0$:

Therefore, a gap between the upper bounds on the ability to model dependencies between $S_{1}$ and $S_{2}$ is indicated. Since the learning rate η is small, its logarithm is negative and the gap is in favor of the in-context representation. The following theorem guarantees that this gap is meaningful, by showing that the higher upper bound (of the in-context case) is tight in terms of effective rank:

Theorem 2 (see proof in appendix C). For $y^{(p,i),L,d_{x}}_{\textrm{in-context}}$ as defined in corollary 1, there exists an assignment of the network weights for which the following holds:

Notably, corollary 1 and theorem 2 show that for the in-context case, the sequential separation rank asymptotically equals the regular separation rank, validating the relevance of this measure.

We now provide a high-level proof sketch that captures the manner in which the theoretical framework of sections 2.1 and 2.2 is used for establishing the above gap (full proof in the appendix). For the in-context case, notice that each self-attention layer, defined in eq. 1, is a degree-3 polynomial over its $2N\cdot d_{x}$ inputs, rendering the whole network a degree-$3^{L}$ polynomial. We write this polynomial as a sum over many monomials, and by definition, the separation rank of any monomial composing the polynomial is 1. Since the separation rank of a sum of functions is upper bounded by the sum of their separation ranks, we upper bound the separation rank by the number of these monomials, yielding eq. 10. The main difference in the sequential representation case is that the $S_{1}$ variables affect the computation only via the gradient, so their impact is expected to be limited. However, considering that $S_{2}$ first encounters gradient-updated vocabulary matrix entries $(m^{\textrm{V}})_{t+1}=(m^{\textrm{V}})_{t}-\eta\,\partial\mathcal{L}(S_{1})/\partial(m^{\textrm{V}})_{t}$, it appears that both $S_{1}$ and $S_{2}$ variables enter the self-attention stack via its input, similarly to the in-context case. So the integration between $S_{1}$ and $S_{2}$ occurs right from the start, and indeed we show that the separation rank of both representations is similar. However, since any function of $S_{1}$ is accompanied by the learning rate η, the monomials with many $S_{1}$ variables are multiplied by high powers of η. This causes many monomials to be negligibly small, and accordingly not to contribute to the ε-separation rank. By combinatorial considerations, we show that the number of monomials that are not attenuated by η (i.e., that have sufficiently large magnitude) yields eq. 11. ∎
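The counting step of the sketch can be illustrated with a toy calculation. The assumptions are ours and do not come from the proof: the polynomial is taken to be homogeneous of a fixed degree, every monomial containing k variables from $S_{1}$ is given magnitude exactly η^k (ignoring the Λ bounds and all other coefficients), and a monomial is dropped when η^k < ε. The resulting numbers are therefore illustrative and are not eqs. 10 and 11; the function names are hypothetical.

```python
from math import comb, floor, log

def num_monomials(n_vars, degree):
    """Monomials of exactly `degree` in `n_vars` variables, i.e. multisets
    of size `degree`: C(n_vars + degree - 1, degree)."""
    return comb(n_vars + degree - 1, degree)

def max_s1_degree(eta, eps):
    """Largest k with eta**k >= eps, assuming 0 < eta < 1 and 0 < eps < 1.
    (Exact ties are subject to floating-point rounding.)"""
    assert 0 < eta < 1 and 0 < eps < 1
    return floor(log(eps) / log(eta))   # both logs negative, so the ratio is > 0

def surviving_monomials(n_s1, n_s2, degree, k_max):
    """Degree-`degree` monomials over n_s1 S1-variables and n_s2 S2-variables
    that contain at most k_max S1-variables (those not attenuated below eps)."""
    return sum(
        num_monomials(n_s1, k) * num_monomials(n_s2, degree - k)
        for k in range(0, min(k_max, degree) + 1)
    )

# Usage: eta = 1e-4, eps = 1e-10 keeps monomials with at most 2 S1-variables.
k_max = max_s1_degree(eta=1e-4, eps=1e-10)
print(k_max, surviving_monomials(4, 5, 6, k_max), num_monomials(9, 6))  # 2 1414 3003
```

Without attenuation (k_max equal to the degree) the sum recovers every monomial, by the multiset Vandermonde identity; the smaller η is relative to ε, the fewer terms can contribute to the effective rank.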

The above theorems establish that, from an expressivity perspective, the small magnitude of commonly employed learning rates hinders the ability to integrate information across different training examples. Specifically, the established gap implies that the power of the joint representation of two sentences shown in different training examples is upper bounded by that of a network shallower by $0.5\log_{3}(\eta^{-1})$ layers that has seen them in the same context. Common learning-rate values are on the order of $\eta\in[10^{-6},10^{-4}]$, implying a deficit of roughly 4 to 6 layers in the sequential case (about 6.3 layers for $\eta=10^{-6}$). As shown in Levine et al. (2020); Tay et al. (2021), in many practical regimes of network size, depth is crucial for expressivity, reinforcing the implications of this gap.
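The depth-gap arithmetic can be checked directly. This short sketch evaluates 0.5·log₃(1/η) for learning rates across the quoted range; `layer_deficit` is a hypothetical name.

```python
import math

def layer_deficit(eta: float) -> float:
    """Depth gap 0.5 * log_3(1 / eta) implied by the in-context vs. sequential bounds."""
    return 0.5 * math.log(1.0 / eta, 3)

for eta in (1e-4, 1e-5, 1e-6):
    print(f"eta = {eta:.0e}: {layer_deficit(eta):.2f} layers")
# eta = 1e-04: 4.19 layers
# eta = 1e-05: 5.24 layers
# eta = 1e-06: 6.29 layers
```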

The weaker upper bound, of the sequential case, is not guaranteed to be tight. This means that, theoretically, the sequential representation may in fact be much weaker than what we have proven; e.g., showing two sentences in the same context may yield a representation that cannot be matched merely by showing them in separate contexts and adding a realistic number of layers. However, Roberts et al. (2020) show evidence supporting our indicated link between architectural parameters and the in-context bias. They show that when performing open-domain question answering tasks (their defined “closed book” setting), a large T5 model that sees only the question performs comparably to smaller models that are allowed to attend to the documents that contain the answer. This directly implies a certain strength of the sequential mechanism, namely, that information which was seen during training can be accessed via the weights when the model is realistically stronger, as implied by our bounds. Notably, the large T5 model is 2-4 times the depth of the contrasted smaller models (48 versus 12-24 layers), suggesting that the upper bound can be tightened to a fraction of $L$, or that factors beyond expressivity also contribute to the in-context bias (e.g., optimization, generalization). Investigation of these aspects is left for future work.

3 kNN based pretraining example design

Our theoretical analysis quantifies the relation between the small magnitude of the learning rate, and the deficiency in the ability to model dependencies between different training examples. Clearly, small learning-rates are critical for optimization purposes, so the formalized phenomenon should not be solved via high learning-rates during training. Instead, our analysis makes it clear that if correlations between specific sentences are important for a given task, appending them in-context yields better representations for the task. Below, we describe two controlled experiments that demonstrate the importance of this indicated “pretraining example design" degree of freedom. In both experiments, correlated sentences are identified via kNN search in their RoBERTa-large sentence representation space (Reimers & Gurevych, 2019), performed using the FAISS library (Johnson et al., 2019).
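The neighbor search can be sketched as an exact cosine-similarity kNN. The paper embeds sentences with a RoBERTa-large sentence encoder and searches with FAISS; the sketch below skips the encoder (it takes arbitrary vectors) and uses brute-force NumPy, which is only practical for small corpora. With FAISS, an analogous exact search would L2-normalize the vectors with `faiss.normalize_L2` and query a `faiss.IndexFlatIP` inner-product index; which index type the authors actually used is not stated here. `cosine_knn` is a hypothetical name.

```python
import numpy as np

def cosine_knn(queries, corpus, k):
    """Exact k-nearest-neighbor search by cosine similarity.
    Returns (similarities, indices), each of shape (n_queries, k),
    sorted by decreasing similarity."""
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    c = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
    sims = q @ c.T                              # (n_queries, n_corpus)
    idx = np.argsort(-sims, axis=1)[:, :k]      # top-k per query
    return np.take_along_axis(sims, idx, axis=1), idx

# Usage with toy 2-d "sentence embeddings"
corpus = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
sims, idx = cosine_knn(np.array([[2.0, 0.1]]), corpus, k=2)
```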

The Task Adaptive PreTraining (TAPT) method, in which an NLM pretrains on the training set of an NLU task, leads to impressive gains (Gururangan et al., 2020). Notably, TAPT is most effective after the regular pretraining stage on a general corpus. This implies that during TAPT, the model generates improved representations by integrating the task-related text with the knowledge stored in its weights from the preceding general pretraining phase. Under this premise, we postulated that performance would improve if we made relevant sentences from the general corpus more available to the model during the TAPT phase. According to the above analysis, a simple and effective way to bias the model towards representing desired correlations between sentences is to append them in context.

We thus propose the kNN-TAPT phase, in which the training examples are composed of task examples concatenated with their general corpus neighbors in embedding space. We applied kNN-TAPT on the SentEval sentence similarity tasks. Showing similar sentences from Wikipedia is expected to be particularly useful on these tasks, so this is a good experimentation ground to search for effects of the in-context bias. For each SentEval example, we searched over 100M Wikipedia sentences and appended in-context neighbors whose embeddings have over 0.8 cosine similarity to the SentEval example embedding, with a special token inserted between different sentences. We continued until finding no more neighbors or reaching a maximum of 256 tokens in the RoBERTa vocabulary (Liu et al., 2019). This search yielded 170K examples, over which we continued training a pretrained RoBERTa-base model for 5 epochs, using the first epoch for learning-rate warmup and examining peak learning rates of $\{1,3,5,7\}\cdot 10^{-5}$. See appendix D for implementation details.
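The example-construction rule described above can be sketched as a greedy loop. Assumptions not fixed by the text: sentences arrive pre-tokenized, candidate neighbors come with precomputed cosine similarities, the string `</s>` stands in for the unspecified special separator token, and construction stops (rather than skipping ahead) at the first neighbor that would push the example past the token limit. `build_knn_example` is a hypothetical name.

```python
from typing import List, Sequence, Tuple

SEP = "</s>"  # hypothetical stand-in for the special separator token

def build_knn_example(
    task_tokens: List[str],
    neighbors: Sequence[Tuple[float, List[str]]],
    threshold: float = 0.8,
    max_tokens: int = 256,
) -> List[str]:
    """Append neighbors, most similar first, whose cosine similarity exceeds
    `threshold`, separated by SEP, while the example fits in `max_tokens`."""
    example = list(task_tokens[:max_tokens])
    for sim, tokens in sorted(neighbors, key=lambda pair: -pair[0]):
        if sim <= threshold:
            break  # all remaining neighbors are even less similar
        if len(example) + 1 + len(tokens) > max_tokens:
            break  # assumption: stop once the next neighbor no longer fits
        example += [SEP] + tokens
    return example

# Usage
ex = build_knn_example(["a", "b"], [(0.9, ["c"]), (0.95, ["d", "e"]), (0.5, ["f"])],
                       max_tokens=10)
# ex == ["a", "b", "</s>", "d", "e", "</s>", "c"]
```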

Table 1 shows zero-shot SentEval sentence similarity scores, attained by using the average word embedding of an inserted sentence as its sentence representation (shown by Reimers & Gurevych (2019) to be most meaningful in the zero-shot setting). All models were trained according to the above prescription, except for the baseline RoBERTa, which was simply evaluated. kNN-TAPT improves over regular TAPT by over 1 point on average, implying that the Wikipedia neighbors are indeed useful in the TAPT stage. We compared 4 kNN-TAPT variants as an ablation study. Importantly, all variants labeled kNN-TAPT train on the same training data during the TAPT stage (the SentEval sentence similarity training sets and their Wikipedia nearest neighbors), and differ only in the arrangement of the data into training examples. The “neighbors” flag relates a SentEval example to its actual neighbors from the kNN search, while the “random” flag relates it to random Wikipedia sentences from the overall neighbor pool attained in the search. The “in batch” flag implies that related sentences were shown in the same batch, where every training example includes only one sentence from either SentEval or Wikipedia. In contrast, the “in context” flag implies that related sentences were shown in the same training example.
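The four ablation arrangements differ only in how the same sentences are grouped, which a short sketch makes explicit. Everything here is illustrative: `arrange` is a hypothetical helper, the separator string is a stand-in, each returned group is meant to be placed in one batch (batch assembly and shuffling are omitted), and the “random” variant is modeled, as an assumption, by randomly re-dealing the pooled neighbor sentences while keeping each task example's neighbor count.

```python
import random
from typing import List

SEP = "</s>"  # hypothetical stand-in for the separator token

def arrange(task_examples: List[List[str]],
            neighbors: List[List[List[str]]],
            related: str = "neighbors",     # or "random"
            placement: str = "in_context",  # or "in_batch"
            seed: int = 0) -> List[List[List[str]]]:
    """Return one group per task example; each group is a list of training
    examples (token lists) meant to share a batch. All four variants use the
    same multiset of sentences; only their grouping differs."""
    rel = [list(group) for group in neighbors]
    if related == "random":
        # Re-deal the pooled neighbor sentences at random, keeping group sizes.
        pool = [s for group in rel for s in group]
        random.Random(seed).shuffle(pool)
        it = iter(pool)
        rel = [[next(it) for _ in group] for group in rel]
    groups = []
    for task, sents in zip(task_examples, rel):
        if placement == "in_context":
            example = list(task)
            for s in sents:
                example += [SEP] + list(s)
            groups.append([example])                                # one example
        else:
            groups.append([list(task)] + [list(s) for s in sents])  # one sentence each
    return groups

# Usage
tasks = [["q1"], ["q2"]]
nbrs = [[["a"], ["b"]], [["c"]]]
in_context = arrange(tasks, nbrs, "neighbors", "in_context")
# [[["q1", "</s>", "a", "</s>", "b"]], [["q2", "</s>", "c"]]]
```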

The weakness of “neighbors, in-batch” implies that the a-priori plausible approach of biasing the model to learn from these Wikipedia neighbors by placing them in the same batch is not nearly as effective as the theoretically motivated in-context approach. Leading sentence representation methods employ in-batch techniques (see, for example, the contrastive setting of Gao et al. (2021b)), and this signal strongly suggests developing in-context parallels. The fact that the original TAPT scheme outperforms the in-batch approaches implies that including the Wikipedia sentences in separate training examples is harmful. We postulate that this is because training examples that contain only Wikipedia sentences dilute the original TAPT signal. Indeed, by this view, the reason that “random, in-context” performs comparably to TAPT is that it does not dilute the original TAPT signal: every training example includes a SentEval example. Overall, the clear advantage of the “neighbors, in-context” kNN-TAPT variant encourages leveraging the in-context bias for TAPT in further tasks.

3.2 kNN Pretraining

We extended the above to more general kNN-Pretraining, designing pretraining examples with related non-neighboring sentences given only the general pretraining corpus. kNN-Pretraining is also motivated by the kNN-LM results of Khandelwal et al. (2019), who show significant benefits of using nearest neighbors in representation space at inference time. Their results exemplify the potential impact of integrating cross-corpus related examples; our kNN-Pretraining approach provably biases the model to learn these correlations at pretraining time, via the in-context bias.

Specifically, we performed a kNN search over Wikipedia sentences for every sentence in Wikipedia, and created each training example similarly to the protocol in the previous subsection. During kNN-Pretraining, half of each batch contained regular pretraining examples and half contained the prepared kNN examples, in order to retain longer-range LM abilities. To examine the effect of kNN-Pretraining, we pretrained GPT-base and GPT-medium architectures (110M and 345M parameters) from scratch over Wikipedia in the regular pretraining scheme, and switched to kNN-Pretraining at two different points during pretraining (200K and 400K). The training examples were of maximal size 256, and the batch size was 128 for the GPT-medium models and 256 for the GPT-base models.
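The data preparation just described might look roughly like the following sketch. It assumes sentence embeddings are already available (the embedding model used for the kNN search is not specified in this section), uses brute-force cosine similarity instead of an approximate nearest-neighbor index, and joins sentences with spaces; all function names are hypothetical, and truncation to 256 tokens is omitted.

```python
import math
import random
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def knn_indices(embeddings: List[Sequence[float]], k: int) -> List[List[int]]:
    """Brute-force k nearest neighbors by cosine similarity, excluding the query itself."""
    result = []
    for i, u in enumerate(embeddings):
        scored = [(cosine(u, v), j) for j, v in enumerate(embeddings) if j != i]
        scored.sort(reverse=True)
        result.append([j for _, j in scored[:k]])
    return result


def knn_example(i: int, sentences: List[str], nn: List[List[int]]) -> str:
    """Concatenate sentence i with its neighbors into one training example."""
    return " ".join([sentences[i]] + [sentences[j] for j in nn[i]])


def mixed_batch(regular: List[str], knn_examples: List[str], batch_size: int,
                rng: random.Random) -> List[str]:
    """Half regular pretraining examples, half kNN examples, shuffled together."""
    half = batch_size // 2
    batch = rng.sample(regular, batch_size - half) + rng.sample(knn_examples, half)
    rng.shuffle(batch)
    return batch
```

At Wikipedia scale the quadratic brute-force search would have to be replaced by an approximate index; the batch split mirrors the half-and-half mixing described above.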

In order to directly probe the acquired ability to integrate non-neighboring sentences, we evaluated the resulting models on the very challenging setup of zero-shot closed-book open-domain question answering. In this setup, the unidirectional pretrained model decodes an answer conditioned on the given open-ended question. We evaluated the models on questions from the Natural Questions (NQ) benchmark (Kwiatkowski et al., 2019), using the same phrasing employed in Brown et al. (2020), and employing the standard “open-domain” version as used, e.g., by Lee et al. (2019); Asai et al. (2019); Roberts et al. (2020). NQ is composed of questions whose answers appear within Wikipedia, our pretraining corpus. Intuitively, kNN-Pretraining can help in cases where the passage containing the answer has elucidating nearest neighbors from across Wikipedia, which help the model better internalize the answer so that it is more accessible to the model in the zero-shot setting. As figure 1 demonstrates, 3 baseline models, pretrained with the regular scheme, achieve very low scores on this task (F1 < $10^{-3}$). In contrast, kNN-Pretraining shows a low-scoring but significant improvement.

To increase the credibility of the signal, we evaluated our models on the first 20K examples from the NQ training set (since we tested zero-shot performance, the training set was not used earlier). The attained F1 scores are indeed low, but they correspond to 100s of correct answers that the kNN-Pretrained model provides after roughly 10% of the overall training time, versus far fewer in the 3 randomly initialized baseline models. Finally, we include in appendix E NQ scores of models of different sizes when starting kNN-Pretraining at different checkpoints, and in appendix F zero-shot scores on several GLUE tasks, which demonstrate clear gains of kNN-Pretraining over the baselines.
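For reference, a common definition of F1 for open-domain QA is the SQuAD-style token-overlap F1 with light answer normalization, taking the maximum over the gold answers and averaging over questions. The section does not specify the exact evaluation script, so the following Python sketch is an assumption, not the authors' code; the function names are hypothetical.

```python
import re
import string
from collections import Counter
from typing import List

_PUNCT = set(string.punctuation)


def normalize(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def token_f1(prediction: str, gold: str) -> float:
    pred_toks = normalize(prediction).split()
    gold_toks = normalize(gold).split()
    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)


def best_f1(prediction: str, golds: List[str]) -> float:
    """Score against the best-matching gold answer."""
    return max(token_f1(prediction, g) for g in golds)


def mean_f1(predictions: List[str], gold_answers: List[List[str]]) -> float:
    return sum(best_f1(p, g) for p, g in zip(predictions, gold_answers)) / len(predictions)
```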

Discussion

Modern NLM pretraining schemes have tremendously advanced the natural language landscape, since they allowed powerful models to train on huge amounts of unlabeled text. But NLMs are now challenged with tasks which require deeper and more nuanced understanding of text, and means of improving the basic pretraining process should be considered. For a given architecture, pretraining can be improved by adding more data or finding more sophisticated training objectives to apply over existing data. In this paper we highlight a parallel path for improvement, which employs the same data and objective, but redistributes the available strength of the Transformer architecture such that important connections within the pretraining corpus are learned more effectively. Specifically, we highlight the bias of the trained NLM towards modeling dependencies between chunks of text that appeared within the same training examples. In current pretraining schemes, this means that dependencies between non-neighboring chunks of text are under-favored. If such dependencies matter for the task at hand, we suggest rearranging the data into corresponding training examples.

We formalize the above notion. Our theoretical setup asks expressivity questions that pertain to the training set rather than to a single example. We thus tie the construction of the training example to the available expressivity of the architecture: we prove that the connections that can be modeled between different training examples are bounded by the connections that can be learned by a shallower and weaker architecture, if these examples were inserted within the same input.

The advantage of including related text in the input of the NLM has been noticed and leveraged in the empirical landscape. That said, it is clear that showing the model related data is meaningful even if it appears in different training examples, and many leading methods elect to do just that. Our quantification of this trade-off is intended to aid informed decisions and highlight the expressivity advantage to be gained by smarter training example designs. We follow up on these recommendations and demonstrate the immediately available gains to be achieved by designing training examples that include nearest neighbors in embedding space. This method can be enhanced, and other more explicit biases can be introduced. For example, multiple mentions of the same entity, event, or concept can be concatenated within the same training example.

The gains achieved by using similarity in representation space indicate a path for self-improving representations, left for future work. After a first cycle of kNN-Pretraining, the representation is refined, and applying a new kNN search over it can lead to a more informative next round of kNN-Pretraining. This way, deeper insight can be elicited from a given pretraining corpus.

Lastly, while this paper focused on leveraging the identified in-context bias for pretraining, it can also be tied to recent successes of in-context inference methods. From the in-context few-shot prompts of Brown et al. (2020), to in-context augmentations such as in Gao et al. (2020); Schick & Schütze (2020) and many others, the benefits of biasing the prediction by appending text in-context are now widely established. The tools brought forth here can assist in clarifying the theoretical advantages of such practices. Overall, our work aims to provide timely theoretical interpretations, to help guide the rapid empirical advances of our field.

References

Appendix A Proof of Corollary 1

Let $S_{1}=\left\{w_{1}^{j}\right\}_{j=1}^{N}$ and $S_{2}=\left\{w_{2}^{j}\right\}_{j=1}^{N}$ be two sentences, let $\mathbf{g}_{\mathcal{W}}^{i,L,d_{x}}$ be the Transformer operation of $\mathbf{y}_{\textrm{in-context}}^{i,L,d_{x}}\left(S_{1},S_{2}\right)$, and let $M^{\textrm{V}}$ be a vocabulary embedding matrix. Then:

Clearly, this form of presenting $\mathcal{Z}_{\mathbf{y}_{\textrm{in-context}}^{i,L,d_{x}}\left(S_{1},S_{2}\right)}\left(\mathbf{a},\mathbf{b}\right)$ is a sum of terms that are each separable with respect to $\left(\mathbf{a},\mathbf{b}\right)$, and since it has $R$ summands, we can conclude that:

Corollary 1 now follows from an upper bound on $\text{sep}_{\left(\left[N\right],\left[2N\right]\backslash\left[N\right]\right)}\left(\mathbf{g}_{\mathcal{W}}^{i,L,d_{x}}\right)$ given in Levine et al. (2020).

Appendix B Upper bound for the ε-separation rank

For an expression that can be represented as a sum of some terms,

denote by $f^{+}$ the corresponding sum, but with each term replaced by its absolute value, that is:

and note that by the triangle inequality it holds that:

Let $y_{\text{sequential}}^{(p,i),L,d_{x},\eta}$ be the $p\in[d_{x}]$ entry of the analyzed sequential representation defined in eq. 5. Assume that all learned parameters and all gradients are bounded, i.e., $\forall\theta\in\{\mathcal{W},M^{\textrm{V}}\}:\Lambda_{\min}\leq\left|\theta\right|,\left|\partial\mathcal{L}\left(S_{1}\right)/\partial\theta\right|\leq\Lambda_{\max}$ for some $0<\Lambda_{\min}\leq\Lambda_{\max}$, and that $N<d_{x}$, $\eta\in\left(0,1\right]$, $\frac{2\left(1+\eta\right)d_{x}}{\eta}<3^{L}$ and $2\left(1+\eta\right)d_{x}^{2}<3^{L}$. In addition, assume that there exists $M\geq 0$ for which $\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}^{+}<M$ on its domain. Then:

We start by finding a representation of $\mathcal{Z}_{\Theta_{t+1}\left(\mathbf{a},S_{1};\eta\right)}^{\left(S_{2}\right)}$ as a sum of terms, where each term is separable with respect to $\left(\mathbf{a},\mathbf{b}\right)$. We then turn to finding a subset of these terms, denoted $G$, such that the sum of all terms in $G$ is an $\varepsilon$-approximation of $\mathcal{Z}_{\Theta_{t+1}\left(\mathbf{a},S_{1};\eta\right)}^{\left(S_{2}\right)}$. Lastly, since it follows from the definition of the $\varepsilon$-separation rank and the construction of $G$ that $\varepsilon\text{-sep}_{\left(\mathbf{a},\mathbf{b}\right)}\left(\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}\right)$ is upper bounded by the cardinality of $G$ (which is the number of summands in the approximation), we find an upper bound on $\left|G\right|$. This is therefore also an upper bound on $\varepsilon\text{-sep}_{\left(\mathbf{a},\mathbf{b}\right)}\left(\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}\right)$, which by definition is equal to $\varepsilon\text{-seq-sep}\left(y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\right)$.

Following Levine et al. (2020); Wies et al. (2021), $\mathcal{Z}_{\Theta_{t+1}\left(\mathbf{a},S_{1};\eta\right)}^{\left(S_{2}\right)}\left(\mathbf{b}\right)$ can be written as:

Separating into vocab-gradient terms and vocab terms:

Compressing the summation to count variable powers:

Pushing in the summations over $N$, only the parity matters:

where each summand is separable with respect to $\left(\mathbf{a},\mathbf{b}\right)$.

Now that we have a representation of $\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}$ as a sum of $\left(\mathbf{a},\mathbf{b}\right)$-separable terms, we turn to finding a subsum that approximates $\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}$ up to $\varepsilon$-precision.

First, let us define the set of all legal indices in the last sum:

which is the sum of all terms with indices in $G$. Clearly, summing over all possible indices gives us the original expression:

Given $\varepsilon>0$, we wish to find a subset of the indices, $G\subseteq D$, such that the sum of all terms whose indices are in $G$ is an $\varepsilon$-approximation of $\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}\left(\mathbf{a},\mathbf{b}\right)$. That is, we are looking for $G\subseteq D$ such that for all $\mathbf{a},\mathbf{b}$:

Note that since we assume that $\mathcal{Z}_{D}^{+}\left(\mathbf{a},\mathbf{b}\right)=\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}^{+}<M$, we get:

and it follows that it is enough for us to show that:

which is what we show going forward. This will ensure that $\mathcal{Z}_{G}$ is an $\varepsilon$-approximation of $\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}\left(\mathbf{a},\mathbf{b}\right)$.

Now, we assume that $\forall\theta\in\Theta:\theta,-\frac{\partial\mathcal{L}\left(S_{1}\right)}{\partial\theta}\in\left[\Lambda_{\min},\Lambda_{\max}\right]$. By Levine et al. (2020), the $P$s and $Q$s in eq. 13 are products of up to $L$ matrices, so each of their coordinates is bounded in $\left[\Lambda_{\min}^{L},\Lambda_{\max}^{L}\right]$, and we assume without loss of generality that $\Lambda_{\min}\leq 1\leq\Lambda_{\max}$ (otherwise we could have picked a smaller $\Lambda_{\min}$ and a larger $\Lambda_{\max}$). Then for each $\left(N_{\text{A}},\bm{p},\bm{n},\bm{e}\right)\in D$ the following inequalities hold:

Relaxing the parity constraints inside the brackets and recalling that $\Lambda_{\max}\geq 1$ gives us an upper bound:

which we can further bound using lemmas 3 and 4 in Levine et al. (2020) until we are left with:

Combining the upper and lower bounds for $\lambda_{N_{\text{A}},\bm{p},\bm{n}}\cdot\phi_{A,\bm{p},\bm{e}}\cdot\Psi_{B,\bm{p},\bm{n},\bm{e}}$, we get that for all $G\subseteq D$:

so in order to show that (14) holds, it suffices to show that:

Now, we can limit ourselves to subsets $G\subseteq D$ of the form:

where the inequality is due to lemma 3 in Levine et al. (2020).

And after rearranging, we get that for each $T\geq 0$ such that:

it holds that $\mathcal{Z}_{G\left(T\right)}\left(\mathbf{a},\mathbf{b}\right)$ is an $\varepsilon$-approximation of $\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}\left(\mathbf{a},\mathbf{b}\right)$.

In the last step we found a condition on subsets of indices such that summing over any subset that meets this condition yields an $\varepsilon$-approximation of $\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}$. We will now find a specific subset that meets this condition, and use it to bound $\varepsilon\text{-sep}_{\left(\mathbf{a},\mathbf{b}\right)}\left(\mathcal{Z}_{y_{\text{sequential}}^{p,i,H,L,d_{x},\eta}\left(S_{1},S_{2}\right)}\right)$ from above.

We will focus our attention on values of $T$ of the form:

for some $s\in\left(0,e^{-1.5}\right]$ which we will determine later.
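The mechanism behind choosing $G(T)$ can be illustrated numerically: keeping only the terms whose magnitude is at least a threshold $T$ changes the sum by at most the absolute mass of the dropped terms, which is the $f^{+}$-style triangle-inequality bound used above. The Python sketch below is a simplified analogue for a fixed list of numbers (the proof instead bounds the terms uniformly over all $(\mathbf{a},\mathbf{b})$ through their coefficients, and scales the tolerance through $M$); the helper name is hypothetical.

```python
from typing import List, Tuple


def threshold_subsum(terms: List[float], eps: float) -> Tuple[float, List[float]]:
    """Pick a threshold T so that dropping all terms with |t| < T costs at most eps.

    Returns (T, kept) with kept = [t for t in terms if |t| >= T]. The number of
    kept terms plays the role of |G| in the separation-rank argument.
    """
    dropped_mass = 0.0
    threshold = float("inf")  # default: every term can be dropped
    for mag in sorted(abs(t) for t in terms):
        if dropped_mass + mag > eps:
            threshold = mag  # first magnitude we cannot afford to drop
            break
        dropped_mass += mag
    # Terms with |t| < threshold have total mass <= dropped_mass <= eps,
    # so |sum(terms) - sum(kept)| <= eps by the triangle inequality.
    kept = [t for t in terms if abs(t) >= threshold]
    return threshold, kept
```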

B.1 Lemmas for estimating the number of coefficients

For brevity and clarity we will use expressions of the form $\binom{K}{\frac{K}{M},\ldots,\frac{K}{M}}$, regardless of whether $K$ is divisible by $M$ or not. For the latter case, this expression should actually be:

expressions of the form $\binom{K}{\frac{K}{M},\ldots,\frac{K}{M}+1}$ should be read as:

and expressions of the form $\binom{K}{\frac{K}{M},\ldots,\frac{K}{M}-1}$ should be read as:

Let $a_{1},\ldots,a_{M}$ be a sequence of non-negative integers such that $a_{1}+\ldots+a_{M}=K$ and $\binom{K}{a_{1},\ldots,a_{M}}$ is maximal. Assume towards a contradiction that there exist $j_{1},j_{2}\in\left[M\right]$ such that $a_{j_{1}}-a_{j_{2}}>1\iff\frac{a_{j_{2}}+1}{a_{j_{1}}}<1$, then:

contrary to the maximality of $\binom{K}{a_{1},\ldots,a_{M}}$. Therefore, $\forall j_{1},j_{2}\in\left[M\right]$, $\left|a_{j_{1}}-a_{j_{2}}\right|\leq 1$. ∎
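Lemma 1 is easy to sanity-check by brute force for small $K$ and $M$. The Python sketch below (hypothetical helper names) enumerates all non-negative integer $M$-tuples summing to $K$ and confirms that every maximizer of the multinomial coefficient has entries differing by at most one.

```python
import math
from typing import Iterator, List, Tuple


def multinomial(parts: Tuple[int, ...]) -> int:
    """K! / (a_1! ... a_M!) with K = sum(parts), computed exactly."""
    result = math.factorial(sum(parts))
    for p in parts:
        result //= math.factorial(p)  # every partial quotient is an integer
    return result


def compositions(K: int, M: int) -> Iterator[Tuple[int, ...]]:
    """All M-tuples of non-negative integers summing to K."""
    if M == 1:
        yield (K,)
        return
    for first in range(K + 1):
        for rest in compositions(K - first, M - 1):
            yield (first,) + rest


def maximizers(K: int, M: int) -> List[Tuple[int, ...]]:
    tuples = list(compositions(K, M))
    best = max(multinomial(a) for a in tuples)
    return [a for a in tuples if multinomial(a) == best]


# Every maximizer is as balanced as possible: max(a) - min(a) <= 1.
for K in range(12):
    for M in range(1, 5):
        assert all(max(a) - min(a) <= 1 for a in maximizers(K, M))
```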

Let $K,M$ be two fixed natural numbers, $\eta\in\left(0,1\right]$, and denote

Then for all $n\in\left[K\right]\cup\left\{0\right\}$:

where the second equality is due to the fact that:

Induction step: Let $n\geq 0$ be such that (17) holds for $n$.

Thus, (17) holds for $n+1$, and the proof of the induction step is complete. Hence, by induction, (17) is correct for all $n\in\left[K\right]\cup\left\{0\right\}$.

Let $K,M$ be two fixed natural numbers, and $\eta\in\left(0,1\right]$. Then the maximum of:

and for all $i\in\left[M\right]$, $\left|a_{i}-\frac{n}{M}\right|\leq 1$ and $\left|b_{i}-\frac{K-n}{M}\right|\leq 1$.

From lemma 1 we know that a multinomial coefficient reaches its maximum when the sum is evenly distributed between all indices. Since the $a_{i}$s and $b_{i}$s can be chosen independently of each other given $K$ and $n$, we may assume without loss of generality that no matter the value of $n$, for all $i\in\left[M\right]$, it holds that $\left|a_{i}-\frac{n}{M}\right|\leq 1$ and $\left|b_{i}-\frac{K-n}{M}\right|\leq 1$.

is Gaussian-shaped and therefore unimodal, with a unique maximum. $\eta^{x}$ for $\eta\in\left(0,1\right]$ is monotonically non-increasing, and therefore their product is also unimodal.

So we get that $S\left(n\right)$ is monotonically increasing as long as $\left\lfloor\frac{n-1}{M}\right\rfloor M\leq\frac{\eta K-M}{1+\eta}$, and the largest integer for which this condition holds is:

Denote by $\mathcal{N}\left(\mathcal{B}_{R}^{d}\right)$ the number of integer lattice points in $\mathcal{B}_{R}^{d}$ (the $d$-dimensional zero-centered ball of radius $R$). Then:

Let $\mathcal{I}\left(\mathcal{B}_{R}^{d}\right)$ be the set of all integer lattice points in $\mathcal{B}_{R}^{d}$. For $\bm{x}\in\mathcal{I}\left(\mathcal{B}_{R}^{d}\right)$, define:

Let $\bm{y}\in\bigcup_{\bm{x}\in\mathcal{I}\left(\mathcal{B}_{R}^{d}\right)}C_{\bm{x}}$, so there exists $\bm{x}\in\mathcal{I}\left(\mathcal{B}_{R}^{d}\right)$ such that $\bm{y}\in C_{\bm{x}}$, and from the triangle inequality we get:

and therefore $\bm{y}\in\mathcal{B}_{R+\frac{\sqrt{d}}{2}}^{d}$. Since $\bm{y}$ was chosen arbitrarily, we get that:

Note that $Vol\left(C_{\bm{x}}\right)=1$ for all $\bm{x}\in\mathcal{I}\left(\mathcal{B}_{R}^{d}\right)$ and that for $\bm{x},\bm{x}^{\prime}\in\mathcal{I}\left(\mathcal{B}_{R}^{d}\right)$ such that $\bm{x}\neq\bm{x}^{\prime}$, $C_{\bm{x}}\cap C_{\bm{x}^{\prime}}$ is a set of measure zero, hence:

Assume for convenience that $d\equiv 0\pmod 2$, so $\Gamma\left(\frac{d}{2}+1\right)=\left(\frac{d}{2}\right)!$, and Stirling's approximation yields:

On the other hand, note that $\mathcal{B}_{R-\frac{\sqrt{d}}{2}}^{d}\subseteq\bigcup_{\bm{x}\in\mathcal{I}\left(\mathcal{B}_{R}^{d}\right)}C_{\bm{x}}$, and therefore:
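The lattice-point lemma sandwiches $\mathcal{N}(\mathcal{B}_{R}^{d})$ between the volumes of the balls of radii $R\mp\frac{\sqrt{d}}{2}$. Assuming, consistently with $Vol(C_{\bm{x}})=1$ and the $\frac{\sqrt{d}}{2}$ shifts, that $C_{\bm{x}}$ is the unit cube centered at $\bm{x}$ (whose half-diagonal is $\frac{\sqrt{d}}{2}$), the following Python sketch checks both inclusions numerically for small $d$ using the standard volume formula $\pi^{d/2}R^{d}/\Gamma(\frac{d}{2}+1)$; helper names are hypothetical.

```python
import math
from itertools import product


def ball_volume(d: int, r: float) -> float:
    """Volume of the d-dimensional Euclidean ball of radius r (0 for r <= 0)."""
    if r <= 0:
        return 0.0
    return math.pi ** (d / 2) * r ** d / math.gamma(d / 2 + 1)


def lattice_points_in_ball(d: int, r: float) -> int:
    """Brute-force count of integer points x with ||x|| <= r."""
    m = math.floor(r)
    return sum(1 for x in product(range(-m, m + 1), repeat=d)
               if sum(c * c for c in x) <= r * r)


def check_sandwich(d: int, r: float) -> bool:
    """Vol(B_{r - sqrt(d)/2}) <= N(B_r) <= Vol(B_{r + sqrt(d)/2})."""
    half_diag = math.sqrt(d) / 2  # half-diagonal of a unit cube
    n = lattice_points_in_ball(d, r)
    return ball_volume(d, r - half_diag) <= n <= ball_volume(d, r + half_diag)
```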

Let $K,M$ be two fixed natural numbers, $s\in\left(0,1\right)$ a constant sensitivity parameter, and let

By Stirling's approximation, it holds that:

and by plugging this approximation into the definition of a multinomial coefficient we get, after some rearranging:

Note that using this approximation and some rearrangements, we get that the following are equivalent:

We will start by finding a condition that assures us that $\left(a_{1},\ldots,a_{M}\right)\in T_{K,M}$ (i.e., we will characterize a subset of $T_{K,M}$).

First, note that by the AM-GM inequality,

Since we are interested in a subset of $T_{K,M}$, we can show that the last inequality holds when we replace the first term on the left-hand side ($\prod_{i=1}^{M}\sqrt{a_{i}+\frac{1}{6}}$) with $\left(\frac{K}{M}+\frac{1}{6}\right)^{\frac{M}{2}}$ (if the new inequality holds, (19) must hold as well), and we are left with:

Now, for $i\in\left[M\right]$, let $h_{i}\coloneqq\frac{M}{K}a_{i}-1$, so:

Observe that for all $x\geq-1$, it holds that:

so it suffices (again, we are only interested in a subset of $T_{K,M}$) to show that:

Let us now turn to finding a condition that assures us that $\left(a_{1},\ldots,a_{M}\right)\notin T_{K,M}$ (i.e., we will characterize a subset of the complement of $T_{K,M}$). Note that:

so if $\left(a_{1},\ldots,a_{M}\right)\notin T_{K,M}$, we must have that:

and using the same definition of $h_{i}$ as before, we get:

where the last equality is due to the fact that $\sum_{i=1}^{M}h_{i}=0$.

Observing the first-order Taylor polynomial of the function:

at $x=0$ with the remainder in the Lagrange form, we get:

Note that $f^{\prime\prime}\left(\xi_{i}\right)$ is monotonically decreasing with $\xi_{i}$ for $\xi_{i}\in\left(-1,M-1\right]$, and using this fact and the fact that $K\geq 1$, (22) is lower bounded by:

So we can limit ourselves to looking at the cases where:

Combining the two results together we get:

Let $K,M$ be two fixed natural numbers, and $s\in\left(0,1\right)$ a constant sensitivity parameter. Then the number of multinomial coefficients, $\binom{K}{a_{1},\ldots,a_{M}}$, which uphold:

Let $\binom{K}{a_{1},\ldots,a_{M}}$ be a multinomial coefficient for which it holds that:

so in order to bound the cardinality of $T_{K,M}$ (which is the quantity we are interested in), we can find an upper bound on the cardinality of $T_{K,M}^{U}$ and a lower bound on the cardinality of $T_{K,M}^{L}$.

Let $B\in\left\{L,U\right\}$ and $\bm{a}\in T_{K,M}^{B}$, and denote:

For $i\in\left[M\right]$, denote $x_{i}\coloneqq a_{i}-\frac{K}{M}$. So the problem has changed to finding the number of integer $M$-tuples $x_{1},\ldots,x_{M}$ such that $\sum_{i=1}^{M}x_{i}=0$ and $\left\|\bm{x}\right\|\leq R\left(B\right)$, which is the number of integer lattice points $\bm{x}$ in the zero-centered $M$-dimensional ball of radius $R\left(B\right)$ that uphold $\sum_{i=1}^{M}x_{i}=0$.

and the cardinality of $T_{K,M}^{L}$ is lower bounded by:
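To get a feel for the set $T_{K,M}$ counted by lemma 6, the Python sketch below enumerates all $M$-tuples summing to $K$, counts those whose multinomial coefficient is at least $s$ times the maximum, and compares the count with the explicit upper and lower bounds in the form in which lemma 6 is invoked later in this appendix (with the sums $\frac{\eta K}{1+\eta}$ and $\frac{K}{1+\eta}$ replaced by a generic $K$; the lower bound is used there under $K\geq M^{2}$). This is a numerical illustration under these assumptions, not part of the proof; helper names are hypothetical.

```python
import math
from typing import Iterator, Tuple


def multinomial(parts: Tuple[int, ...]) -> int:
    result = math.factorial(sum(parts))
    for p in parts:
        result //= math.factorial(p)
    return result


def compositions(K: int, M: int) -> Iterator[Tuple[int, ...]]:
    """All M-tuples of non-negative integers summing to K."""
    if M == 1:
        yield (K,)
        return
    for first in range(K + 1):
        for rest in compositions(K - first, M - 1):
            yield (first,) + rest


def count_near_max(K: int, M: int, s: float) -> int:
    """|{a : multinomial(a) >= s * max multinomial}|, by brute force."""
    values = [multinomial(a) for a in compositions(K, M)]
    best = max(values)
    return sum(1 for v in values if v >= s * best)


def upper_bound(K: int, M: int, s: float) -> float:
    c = (math.pi * math.e / 2) ** ((M - 1) / 2) / ((M - 1) * math.sqrt(math.pi))
    return c * (4 * math.sqrt(K * math.log(1 / s) / (M - 1)) + 1) ** (M - 1)


def lower_bound(K: int, M: int, s: float) -> float:
    c = (math.pi * math.e / 2) ** ((M - 1) / 2) / (M * math.sqrt(math.pi))
    return c * (2 * math.sqrt(K * math.log(1 / s)) / M - 1) ** (M - 1)
```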

Let $K$ be a fixed natural number, $\eta\in\left(0,1\right]$, and $s\in\left(0,1\right)$ a constant sensitivity parameter. Then the number of integers $n$ such that $\binom{K}{n}\eta^{n}\geq s\cdot\binom{K}{\frac{\eta K}{1+\eta}}\eta^{\frac{\eta K}{1+\eta}}$ is upper bounded by:

Denote $n\coloneqq\frac{\eta K}{1+\eta}+x$, so:

Recall that by Stirling’s approximation we know that:

approximates the number we are trying to quantify.
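Lemma 7 concerns the sequence $\binom{K}{n}\eta^{n}$, which up to normalization is a binomial distribution with success probability $\frac{\eta}{1+\eta}$ and therefore peaks at $n\approx\frac{\eta K}{1+\eta}$. The Python sketch below, with hypothetical helper names, computes the sequence exactly with rational arithmetic, locates its peak, and counts the $n$s within a factor $s$ of the peak, which is the quantity the lemma bounds when $\frac{\eta K}{1+\eta}$ is an integer.

```python
import math
from fractions import Fraction
from typing import List

DEFAULT_S = math.exp(-1.5)  # the largest s allowed by the later lemmas


def weighted_binomials(K: int, eta: Fraction) -> List[Fraction]:
    """Exact values of C(K, n) * eta**n for n = 0..K."""
    return [math.comb(K, n) * eta ** n for n in range(K + 1)]


def peak_index(K: int, eta: Fraction) -> int:
    vals = weighted_binomials(K, eta)
    return vals.index(max(vals))


def count_within_factor(K: int, eta: Fraction, s: float = DEFAULT_S) -> int:
    """Number of n with C(K,n) eta^n >= s * max_n C(K,n) eta^n."""
    vals = weighted_binomials(K, eta)
    peak = max(vals)
    return sum(1 for v in vals if v >= s * peak)
```

Comparing the count with $K$ for growing $K$ shows how concentrated the sequence is around its peak.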

After some rearranging, one can observe that (23) is equivalent to:

and since for all $t>-1$, $\frac{t}{1+t}\leq\ln\left(1+t\right)$, the number of $x$s which uphold (24) is upper bounded by the number of integers $x$ for which it holds that:

Recall that for inequalities of the form $ax^{2}+bx+c\leq 0$ where $a>0$, the set of all values of $x$ which satisfy the inequality is $\left[\frac{-b-\sqrt{b^{2}-4ac}}{2a},\frac{-b+\sqrt{b^{2}-4ac}}{2a}\right]$, and the number of integer values of $x$ which satisfy this condition is approximately $\frac{\sqrt{b^{2}-4ac}}{a}$ (the interval's length).
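The interval-length estimate in the last step can be checked directly: a closed interval of length $\ell$ contains either $\lfloor\ell\rfloor$ or $\lfloor\ell\rfloor+1$ integers, so the count differs from $\frac{\sqrt{b^{2}-4ac}}{a}$ by at most one. A minimal Python sketch (hypothetical helper names):

```python
import math
from typing import List


def integer_solutions(a: float, b: float, c: float) -> List[int]:
    """All integers x with a*x**2 + b*x + c <= 0, assuming a > 0."""
    if a <= 0:
        raise ValueError("a must be positive")
    disc = b * b - 4 * a * c
    if disc < 0:
        return []  # the parabola stays above zero
    root = math.sqrt(disc)
    lo = (-b - root) / (2 * a)
    hi = (-b + root) / (2 * a)
    return list(range(math.ceil(lo), math.floor(hi) + 1))


def interval_length(a: float, b: float, c: float) -> float:
    """Length sqrt(b^2 - 4ac) / a of the solution interval (0 if it is empty)."""
    disc = b * b - 4 * a * c
    return math.sqrt(disc) / a if disc >= 0 else 0.0
```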

Let $K,M$ be two fixed natural numbers, $\eta\in\left(0,1\right]$, and $s\in\left(0,e^{-1.5}\right]$ a constant sensitivity parameter. Denote:

and let $\mathbf{x^{\star}}\coloneqq\underset{\mathbf{x}\in D_{K,M}}{\arg\max}\,F\left(\mathbf{x}\right)$. If $\frac{\eta K}{1+\eta}\geq\left(M-1\right)$, then the number of $\mathbf{x}\in D_{K,M}$ which uphold $F\left(\mathbf{x}\right)\geq s\cdot F\left(\mathbf{x^{\star}}\right)$ is bounded from above by:

By lemma 3, $\mathbf{x^{\star}}\simeq\left(\frac{\eta K}{M\left(1+\eta\right)},\ldots,\frac{\eta K}{M\left(1+\eta\right)},\frac{K}{M\left(1+\eta\right)},\ldots,\frac{K}{M\left(1+\eta\right)}\right)$. By lemma 7, the number of $n$s with $0\leq n\leq K$ such that $\binom{K}{n}\eta^{n}\geq s\cdot\binom{K}{\frac{\eta K}{1+\eta}}\eta^{\frac{\eta K}{1+\eta}}$ is upper bounded by $\frac{K\sqrt{\left(2\ln\left(s^{-1}\right)-1\right)^{2}\left(1+\eta\right)^{2}-4\eta}}{2\left(1+\eta\right)\left(\ln\left(s^{-1}\right)-1\right)}$. By lemma 6, the number of non-negative integer $M$-tuples $a_{1},\ldots,a_{M}$ such that $a_{1}+\ldots+a_{M}=\frac{\eta K}{1+\eta}$ and $\binom{\frac{\eta K}{1+\eta}}{a_{1},\ldots,a_{M}}\geq s\cdot\binom{\frac{\eta K}{1+\eta}}{\frac{\eta K}{M\left(1+\eta\right)},\ldots,\frac{\eta K}{M\left(1+\eta\right)}}$ is bounded from above by $\frac{\left(\frac{\pi e}{2}\right)^{\frac{M-1}{2}}}{\left(M-1\right)\sqrt{\pi}}\left(4\sqrt{\frac{\eta K\ln\left(s^{-1}\right)}{\left(M-1\right)\left(1+\eta\right)}}+1\right)^{M-1}$, and the number of non-negative integer $M$-tuples $b_{1},\ldots,b_{M}$ such that $b_{1}+\ldots+b_{M}=\frac{K}{1+\eta}$ and $\binom{\frac{K}{1+\eta}}{b_{1},\ldots,b_{M}}\geq s\cdot\binom{\frac{K}{1+\eta}}{\frac{K}{M\left(1+\eta\right)},\ldots,\frac{K}{M\left(1+\eta\right)}}$ is bounded from above by $\frac{\left(\frac{\pi e}{2}\right)^{\frac{M-1}{2}}}{\left(M-1\right)\sqrt{\pi}}\left(4\sqrt{\frac{K\ln\left(s^{-1}\right)}{\left(M-1\right)\left(1+\eta\right)}}+1\right)^{M-1}$.
In total, without taking into consideration the interactions between the three multiplicands (so our bound is not tight), the number of $\mathbf{x}\in D_{K,M}$ which uphold $F\left(\mathbf{x}\right)\geq s\cdot F\left(\mathbf{x^{\star}}\right)$ is upper bounded by:

and since $\frac{\eta K}{\left(M-1\right)\left(1+\eta\right)}\geq 1$ and $s\leq e^{-1.5}$ (and hence $\frac{2\ln\left(s^{-1}\right)-1}{2\left(\ln\left(s^{-1}\right)-1\right)}\leq 2$), this can be further bounded by:

Let $K,M$ be two fixed natural numbers, $\eta\in\left(0,1\right]$, and $s\in\left(0,e^{-1.5}\right]$ a constant sensitivity parameter. Denote:

and let $\mathbf{x^{\star}}\coloneqq\underset{\mathbf{x}\in D_{K,M}}{\arg\max}\,F\left(\mathbf{x}\right)$. If $\frac{K}{1+\eta}\geq M^{2}$, then the number of $\mathbf{x}\in D_{K,M}$ which uphold $F\left(\mathbf{x}\right)\geq s\cdot F\left(\mathbf{x^{\star}}\right)$ is bounded from below by:

By lemma 6, the number of non-negative integer $M$-tuples $b_{1},\ldots,b_{M}$ such that $b_{1}+\ldots+b_{M}=\frac{K}{1+\eta}$ and $\binom{\frac{K}{1+\eta}}{b_{1},\ldots,b_{M}}\geq s\cdot\binom{\frac{K}{1+\eta}}{\frac{K}{M\left(1+\eta\right)},\ldots,\frac{K}{M\left(1+\eta\right)}}$ is bounded from below by $\frac{\left(\frac{\pi e}{2}\right)^{\frac{M-1}{2}}}{M\sqrt{\pi}}\left(2\frac{\sqrt{\frac{K}{1+\eta}\ln\left(s^{-1}\right)}}{M}-1\right)^{M-1}$. Since these $b$s are only a subset of the elements of $D_{K,M}$ which $F$ takes into consideration, this is also a (quite loose) lower bound on the number of $\mathbf{x}\in D_{K,M}$ which uphold $F\left(\mathbf{x}\right)\geq s\cdot F\left(\mathbf{x^{\star}}\right)$.

and since $s\leq e^{-1.5}$ implies $\ln\left(s^{-1}\right)>1$, we get:

Appendix C Lower bounds on the ε-separation rank

We begin by laying out basic concepts in tensor theory required for the upcoming analysis. The core concept of a tensor may be thought of as a multi-dimensional array. The order of a tensor is defined to be the number of indexing entries in the array, referred to as modes. The dimension of a tensor in a particular mode is defined as the number of values taken by the index in that mode. If $\mathcal{A}$ is a tensor of order $N$ and dimension $M_{i}$ in each mode $i\in[N]$, its entries are denoted $\mathcal{A}_{d_{1}\ldots d_{N}}$, where the index in each mode takes values $d_{i}\in[M_{i}]$.

We now present the concept of grid tensors, which are a form of function discretization (Hackbusch, 2012). Essentially, the function is evaluated for a set of points on an exponentially large grid in the input space and the outcomes are stored in a tensor. Formally, fixing a set of template vectors $\mathbf{x}^{(1)},\ldots,\mathbf{x}^{(Z)}\in\left[V\right]$, the points on the grid are the set $\{(\mathbf{x}^{(d_{1})},\ldots,\mathbf{x}^{(d_{N})})\}_{d_{1},\ldots,d_{N}=1}^{Z}$. Given a function $y(\mathbf{x}^{1},\ldots,\mathbf{x}^{N})$, the set of its values on the grid arranged in the form of a tensor is called the grid tensor induced by $y$, denoted $\mathcal{A}(y)_{d_{1},\ldots,d_{N}}\equiv y(\mathbf{x}^{1}=\mathbf{x}^{(d_{1})},\ldots,\mathbf{x}^{N}=\mathbf{x}^{(d_{N})})$.
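As a small illustration of grid tensors, the Python sketch below evaluates a function on all template combinations, arranges the values as a matrix whose rows are indexed by the inputs in one part of a partition and whose columns by the rest, and computes its rank exactly. If $y=\sum_{r=1}^{R}g_{r}\cdot h_{r}$ with each $g_{r}$ depending only on the first part and each $h_{r}$ only on the second, this matrix is a sum of $R$ rank-one matrices, so its rank lower-bounds the separation rank. The scalar templates and helper names are hypothetical simplifications of the template vectors above.

```python
from fractions import Fraction
from itertools import product
from typing import Callable, List, Sequence


def grid_matricization(y: Callable[..., float], templates: Sequence[float],
                       n_inputs: int, part_a: Sequence[int]) -> List[List[float]]:
    """Grid tensor of y, matricized with rows indexed by the inputs in part_a."""
    part_b = [i for i in range(n_inputs) if i not in part_a]
    rows = []
    for idx_a in product(range(len(templates)), repeat=len(part_a)):
        row = []
        for idx_b in product(range(len(templates)), repeat=len(part_b)):
            x = [None] * n_inputs
            for pos, d in zip(part_a, idx_a):
                x[pos] = templates[d]
            for pos, d in zip(part_b, idx_b):
                x[pos] = templates[d]
            row.append(y(*x))
        rows.append(row)
    return rows


def exact_rank(rows: List[List[float]]) -> int:
    """Matrix rank via Gaussian elimination over the rationals."""
    m = [[Fraction(v) for v in row] for row in rows]
    rank = 0
    for col in range(len(m[0]) if m else 0):
        pivot = next((r for r in range(rank, len(m)) if m[r][col] != 0), None)
        if pivot is None:
            continue
        m[rank], m[pivot] = m[pivot], m[rank]
        for r in range(len(m)):
            if r != rank and m[r][col] != 0:
                factor = m[r][col] / m[rank][col]
                m[r] = [u - factor * v for u, v in zip(m[r], m[rank])]
        rank += 1
    return rank
```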

C.1.2 ε-rank

We will make use of the concept of the $\varepsilon$-rank (Alon et al., 2013) of a matrix $A$, defined for any $\varepsilon>0$ as the minimum rank over matrices that approximate every entry of $A$ to within an additive $\varepsilon$. We will prove lower bounds on the values of $\varepsilon$ for which the $\varepsilon$-rank of a matrix remains high, using the following lemma:

Finally, we will use the following lemma for lower bounding the number of small eigenvalues of symmetric matrices:

Since the trace of a matrix equals both the sum of its eigenvalues and the sum of its diagonal entries, we get:

C.1.3 High-Dimensional Spheres

Finally, we will use a well known fact regarding the variation of the sphere volume for different radii (see for example Smith & Vamanamurthy (1989)):

C.2 Proof of the lower bound

In this subsection, we prove theorem 2 of the main text. We will follow the proofs of Levine et al. (2020); Wies et al. (2021), with important adjustments to the $\varepsilon$-sequential-separation rank definition.

We begin by showing that a high $\varepsilon$-rank (Alon et al., 2013) of the grid tensor matricization implies a high $\varepsilon$-sequential-separation rank (see section 2.2 of the main text) of the function. Essentially, we apply claim 1 from Levine et al. (2020) to $\varepsilon$-approximations obtained from the $\varepsilon$-separation-rank definition. This relation, which holds for all functions, is formulated below for functions realized by the analyzed Transformer network:

where $\mathcal{A}(\mathcal{Z}_{y^{(p,i),L,d_{x}}_{\text{in-context}}})$ is the grid tensor of $\mathcal{Z}_{y^{(p,i),L,d_{x}}_{\text{in-context}}}$ with respect to the above template vectors.

where $\phi(j)\equiv\left\lfloor (j-1)/d_{a}\right\rfloor\cdot(d_{a}-1)+(j-1\bmod d_{a})+1$.

Now we show that $\mathcal{Z}_{y^{(p,i),L,d_{x}}_{\text{in-context}}}$ is indeed able to produce vectors that do not change the analysis in Levine et al. (2020), so that the assumptions of corollary 2 hold.

where $\phi(j)\equiv\left\lfloor (j-1)/d_{a}\right\rfloor\cdot(d_{a}-1)+(j-1\bmod d_{a})+1$.
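The index map $\phi$ is easy to misread; the Python sketch below simply evaluates the formula for 1-based $j$ (the function name is ours). It shows that $\phi$ sends each block of $d_{a}$ consecutive indices to $d_{a}$ consecutive values, with adjacent blocks sharing one boundary value.

```python
def phi(j: int, d_a: int) -> int:
    """phi(j) = floor((j - 1) / d_a) * (d_a - 1) + ((j - 1) mod d_a) + 1, for j >= 1."""
    if j < 1 or d_a < 1:
        raise ValueError("j and d_a must be positive")
    return ((j - 1) // d_a) * (d_a - 1) + (j - 1) % d_a + 1


# For d_a = 3: blocks (1,2,3), (4,5,6), (7,8,9) map to (1,2,3), (3,4,5), (5,6,7).
print([phi(j, 3) for j in range(1, 10)])  # [1, 2, 3, 3, 4, 5, 5, 6, 7]
```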

We will ignore the element-wise multiplication of $\mathcal{Z}_{y^{(p,i),L,d_{x}}_{\text{in-context}}}$ with the vocabulary embedding matrix by choosing $M^{\textrm{V}}_{i,j}=1$ for all $i,j$ (by the terms of corollary 2, it suffices to find any assignment of the learned weights).

For any i[2((d3L2))+1]i\in\left[2\left(\kern-3.00003pt\left(\genfrac{}{}{0.0pt}{}{d}{3^{L-2}}\right)\kern-3.00003pt\right)+1\right] our templates vectors will be:

We will implement summation of the inputs embedding in the first self-attention layer, we will follow Levine et al. (2020) and set the first layer self-attention key and query weights to:

This assignment implements summation of the inputs embedding in the first self-attention layer since:

where (1)(1) is because WQ,1,h=WK,1,hW^{Q,1,h}=W^{K,1,h} are matrices that are zero everywhere except for entry (1,da)(1,d_{a}) and that all the entries in the vocabulary embedding matrix equals to 11, and (2)(2) because of linearity. Therefore, for any j1,j2[((d3L2))]j_{1},j_{2}\in\left[\left(\kern-3.00003pt\left(\genfrac{}{}{0.0pt}{}{d}{3^{L-2}}\right)\kern-3.00003pt\right)\right] the output of the first self-attention layer on j1,j2j_{1},j_{2} is:

Finally, we need to show that indeed for any j1,j2[((d3L2))]j_{1},j_{2}\in\left[\left(\kern-3.00003pt\left(\genfrac{}{}{0.0pt}{}{d}{3^{L-2}}\right)\kern-3.00003pt\right)\right] eq 44 give the desired u{\mathbf{u}}:

The third and forth cases are clear from xs{\mathbf{x}}^{\prime}s definition, so it remain to prove the first and second cases. For this we will examine d1,d2d_{1},d_{2}. d1=j1((d3L2))d_{1}=j_{1}\leq\left(\kern-3.00003pt\left(\genfrac{}{}{0.0pt}{}{d}{3^{L-2}}\right)\kern-3.00003pt\right) and therefore:

d2=j2+((d3L2))[1+((d3L2)),2((d3L2))]d_{2}=j_{2}+\left(\kern-3.00003pt\left(\genfrac{}{}{0.0pt}{}{d}{3^{L-2}}\right)\kern-3.00003pt\right)\in\left[1+\left(\kern-3.00003pt\left(\genfrac{}{}{0.0pt}{}{d}{3^{L-2}}\right)\kern-3.00003pt\right),2\left(\kern-3.00003pt\left(\genfrac{}{}{0.0pt}{}{d}{3^{L-2}}\right)\kern-3.00003pt\right)\right] and therefore:

Hence the first and second cases hold as well. ∎
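The mechanism behind the first-layer construction can be checked numerically. The following is a minimal NumPy sketch, not the paper's code. It assumes a softmax-free self-attention head of the form $y_i=\sum_j\langle W^{Q}x_i,W^{K}x_j\rangle W^{V}x_j$, and assumes that coordinate $d_a$ of every input embedding equals 1; the sizes `n`, `d_x`, `d_a` and the random value matrix are hypothetical. With $W^{Q,1,h}=W^{K,1,h}$ zero except at entry $(1,d_a)$, every query and key equals $e_1$, so all attention scores equal 1 and each output is the sum of the value vectors.

```python
import numpy as np

def linear_self_attention(X, W_Q, W_K, W_V):
    """Softmax-free attention head: y_i = sum_j <W_Q x_i, W_K x_j> W_V x_j.

    X holds one input embedding per row (shape n x d_x)."""
    Q = X @ W_Q.T      # row i is W_Q x_i
    K = X @ W_K.T      # row j is W_K x_j
    V = X @ W_V.T      # row j is W_V x_j
    scores = Q @ K.T   # scores[i, j] = <W_Q x_i, W_K x_j>
    return scores @ V  # row i is sum_j scores[i, j] * W_V x_j

n, d_x, d_a = 5, 8, 4  # hypothetical sizes (d_a <= d_x)
rng = np.random.default_rng(0)
X = rng.normal(size=(n, d_x))
X[:, d_a - 1] = 1.0    # assumption: coordinate d_a of every embedding equals 1

# W^{Q,1,h} = W^{K,1,h}: zero everywhere except entry (1, d_a), 1-indexed.
W_QK = np.zeros((d_a, d_x))
W_QK[0, d_a - 1] = 1.0
W_V = rng.normal(size=(d_x, d_x))  # arbitrary (hypothetical) value matrix

Y = linear_self_attention(X, W_QK, W_QK, W_V)
```

With a softmax over the scores, the same uniform scores would instead produce the average of the value vectors, i.e., the sum scaled by $1/n$.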

Returning to finding $B$ for which $\left\|M\right\|_{F}\leq\sqrt{d+1}\,n^{\frac{3}{4}}$, we use the probabilistic method to prove that such a $B$ exists, i.e., we show that for a random $B$ the expectation satisfies $\mathbb{E}\left[\left\|M\right\|_{F}\right]\leq\sqrt{d+1}\,n^{\frac{3}{4}}$, and therefore, in particular, such a $B$ exists.
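In outline, the argument combines two standard facts. The sketch below assumes that the squared-norm bound takes the form $\mathbb{E}_B\left[\left\|M\right\|_F^2\right]\leq(d+1)\,n^{\frac{3}{2}}$, which is what the target bound requires after squaring. The first inequality is Jensen's inequality for the concave square root, and the final implication uses the fact that a random variable with finite mean is at most its mean with positive probability.

```latex
\mathbb{E}_B\left[\left\|M\right\|_F\right]
  \leq \sqrt{\mathbb{E}_B\left[\left\|M\right\|_F^{2}\right]}
  \leq \sqrt{(d+1)\,n^{\frac{3}{2}}}
  = \sqrt{d+1}\,n^{\frac{3}{4}},
\qquad
\Pr_B\left[\left\|M\right\|_F\leq\mathbb{E}_B\left[\left\|M\right\|_F\right]\right]>0
\;\Longrightarrow\;
\exists B:\ \left\|M\right\|_F\leq\sqrt{d+1}\,n^{\frac{3}{4}}.
```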

We start by bounding the expectation of the squared norm:

Finally, by Jensen's inequality, we have:

C.3 Technical lemmas

where the first equality holds because $R$ is orthogonal. Therefore, by choosing $R$ such that $Rv=e_{1}$, we get:

Now we can calculate the last expectation directly:

Now, by lemma 12 and fact 1 we have that:

Finally, by lemma 17, each term in the integral is upper bounded by $\left(\!\left(\genfrac{}{}{0pt}{}{d}{\lambda}\right)\!\right)^{-\frac{1}{2}}$, and thus:

Note that since $x^{2\lambda}\left(1-x^{2}\right)^{\frac{d}{2}}=0$ at the boundaries $\left(x\in\left\{0,1\right\}\right)$, it suffices to prove the inequality at the critical points.
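The critical point can be located by differentiating the logarithm of $f(x)=x^{2\lambda}\left(1-x^{2}\right)^{\frac{d}{2}}$, which is positive on $(0,1)$. This is a short worked derivation of the step; the last expression records the value of $f$ at that point.

```latex
\frac{f'(x)}{f(x)} = \frac{2\lambda}{x} - \frac{d\,x}{1-x^{2}} = 0
\iff 2\lambda\left(1-x^{2}\right) = d\,x^{2}
\iff x^{2} = \frac{2\lambda}{2\lambda+d},
\qquad
f\left(\sqrt{\tfrac{2\lambda}{2\lambda+d}}\right)
  = \left(\frac{2\lambda}{2\lambda+d}\right)^{\lambda}\left(\frac{d}{2\lambda+d}\right)^{\frac{d}{2}}.
```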

Therefore, $x^{2}=\frac{2\lambda}{2\lambda+d}$ is the only critical point, and:

where the last inequality follows from the fact that:

Appendix D Experimental Details

We conducted the network training described in section 3.1 of the main text with the AdamW optimizer (with the parameters suggested in the original RoBERTa paper: $\beta_{1}=0.9$, $\beta_{2}=0.98$, $\varepsilon=10^{-6}$, and weight decay of 0.01), with batch sizes of 128 or 256 (depending on model size) and sequences of 256 tokens each. We started from pretrained RoBERTa-base weights from the HuggingFace Transformers repository (https://huggingface.co/transformers/) and continued training them on the MLM task with a masking probability of 15%, where each masked token had an 80% probability of being replaced with the special $\mathtt{[MASK]}$ token, a 10% probability of being replaced with a random token, and a 10% probability of being kept unchanged. The data used for this phase of training was created using the four different procedures described in section 3.1. After training, we evaluated the models’ performance using the SentEval toolkit.
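The 15% selection with the 80/10/10 replacement rule can be sketched in NumPy as follows. This is an illustrative re-implementation, not the code used for the experiments. The function name `mask_tokens`, the label value -100 for unselected positions (a common convention for positions ignored by the loss), and the toy vocabulary in any usage are hypothetical. Each position is selected independently with probability 0.15, and one extra uniform draw per position decides among the three replacement options.

```python
import numpy as np

def mask_tokens(token_ids, mask_id, vocab_size, rng, mlm_prob=0.15):
    """MLM corruption: select ~mlm_prob of the positions; of those,
    80% -> mask_id, 10% -> uniformly random token id, 10% -> unchanged.

    Returns (corrupted inputs, labels); labels hold the original token at
    selected positions and -100 (ignored by the loss) elsewhere."""
    token_ids = np.asarray(token_ids)
    inputs = token_ids.copy()
    labels = np.full_like(token_ids, -100)
    selected = rng.random(token_ids.shape) < mlm_prob
    labels[selected] = token_ids[selected]
    # One uniform draw per position, split at 0.8 and 0.9.
    draw = rng.random(token_ids.shape)
    to_mask = selected & (draw < 0.8)
    to_random = selected & (draw >= 0.8) & (draw < 0.9)
    inputs[to_mask] = mask_id
    inputs[to_random] = rng.integers(0, vocab_size, size=int(to_random.sum()))
    # Remaining selected positions (draw >= 0.9) keep their original token.
    return inputs, labels
```

Splitting the range of a single uniform draw at 0.8 and 0.9 makes the three outcomes mutually exclusive with the stated probabilities.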

D.2 kNN-Pretraining

We conducted the network training described in section 3.2 of the main text with the AdamW optimizer (with the parameters suggested in the original GPT-2 paper: $\beta_{1}=0.9$, $\beta_{2}=0.95$, $\varepsilon=10^{-8}$, and weight decay of 0.1), with a batch size of 512 and sequences of 256 tokens each. We pretrained a HuggingFace Transformers implementation of GPT-2 from scratch on Wikipedia with the standard LM objective, and switched to a mixture of the standard data and our generated kNN data at two different points during training. After training, we evaluated the models’ performance on the Natural Questions benchmark.
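Both training setups use AdamW. The following minimal NumPy sketch of a single update step shows how the listed hyperparameters enter the update. It assumes the variant in which the decoupled weight decay is scaled by the learning rate (as in PyTorch's `torch.optim.AdamW`). The function name `adamw_step`, the two preset dictionaries, and any learning rate passed in are hypothetical, since no learning rate is given here.

```python
import numpy as np

def adamw_step(theta, grad, m, v, t, lr, beta1, beta2, eps, weight_decay):
    """One AdamW update; t is the 1-based step count.

    Returns the updated parameters and first/second moment estimates."""
    m = beta1 * m + (1.0 - beta1) * grad       # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * grad ** 2  # second-moment estimate
    m_hat = m / (1.0 - beta1 ** t)             # bias corrections
    v_hat = v / (1.0 - beta2 ** t)
    # Adam direction plus weight decay applied directly to the parameters,
    # decoupled from the gradient-based moment estimates.
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta, m, v

# Hyperparameters as listed in the text (learning rate not specified there).
ROBERTA_STYLE = dict(beta1=0.9, beta2=0.98, eps=1e-6, weight_decay=0.01)
GPT2_STYLE = dict(beta1=0.9, beta2=0.95, eps=1e-8, weight_decay=0.1)
```

For example, `adamw_step(theta, grad, m, v, t=1, lr=1e-4, **GPT2_STYLE)` performs one step with the section 3.2 settings and an arbitrary learning rate.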

Appendix E kNN-Pretraining at different checkpoints

The following table includes F1 scores of zero-shot, closed-book evaluation on Natural Questions for different model sizes at different training checkpoints. Overall, further pretraining seems to improve the effectiveness of kNN-Pretraining.
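The exact F1 implementation is not specified here. A common choice for open-domain QA is the SQuAD-style token-overlap F1, sketched below under that assumption. Answers are normalized (lower-cased, punctuation and the articles a/an/the removed, whitespace collapsed), F1 is computed from the multiset overlap of tokens, and the maximum is taken over the reference answers. The function names are hypothetical.

```python
import re
import string
from collections import Counter
from typing import List

PUNCT = set(string.punctuation)

def normalize_answer(s: str) -> str:
    """Lower-case, drop punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token precision and recall after normalization."""
    pred = normalize_answer(prediction).split()
    ref = normalize_answer(reference).split()
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

def max_f1(prediction: str, references: List[str]) -> float:
    """Score a prediction against its best-matching reference answer."""
    return max(token_f1(prediction, r) for r in references)
```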

Appendix F kNN-Pretraining on additional benchmarks

The main text describes experiments on the Natural Questions dataset. Here we test how kNN-Pretraining affects other NLU tasks by examining several tasks from the GLUE benchmark (Wang et al., 2018): Multi-Genre Natural Language Inference (MNLI) (Williams et al., 2017), Recognizing Textual Entailment (RTE) (Dagan et al., 2010, and others), and the Winograd Schema Challenge (WNLI) (Levesque et al., 2012). As in the case of Natural Questions, we evaluate the zero-shot performance of our models, since it directly probes the abilities of the model right after pretraining. In contrast to Natural Questions, the GLUE tasks we examined are classification tasks rather than generation tasks, so assessing zero-shot performance on them is not straightforward. We therefore follow the template-based method of Gao et al. (2021a) for converting the tasks’ data into a format processable by unidirectional language models.
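The template-based conversion can be illustrated with a minimal sketch. It assumes a template of the form `<S1> ? [label word] , <S2>` with the label words Yes / Maybe / No for MNLI, in the style of the manual templates of Gao et al. (2021a). It also assumes that a unidirectional model classifies by scoring each filled-in sequence and picking the highest-scoring label. `score` is a hypothetical stand-in for the language model's log-likelihood, and the exact templates and scoring used in the experiments may differ.

```python
from typing import Callable, Dict

# Assumed MNLI label words, in the style of Gao et al. (2021a)'s manual templates.
MNLI_LABEL_WORDS: Dict[str, str] = {
    "entailment": "Yes",
    "neutral": "Maybe",
    "contradiction": "No",
}

def fill_template(premise: str, hypothesis: str, label_word: str) -> str:
    """Instantiate the assumed template "<S1> ? [label word] , <S2>"."""
    return f"{premise} ? {label_word} , {hypothesis}"

def zero_shot_classify(
    premise: str,
    hypothesis: str,
    label_words: Dict[str, str],
    score: Callable[[str], float],
) -> str:
    """Return the label whose filled-in template gets the highest score.

    `score` is a hypothetical stand-in for a unidirectional LM's
    log-likelihood of the whole sequence."""
    return max(
        label_words,
        key=lambda label: score(fill_template(premise, hypothesis, label_words[label])),
    )
```

Scoring the complete filled-in sequence, rather than only the next-token probability of the label word, lets the hypothesis that follows the label word influence the decision, which a left-to-right model could not otherwise take into account.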

Notably, the examined GLUE classification tasks are hard for the examined unidirectional models in the zero-shot setting. Table 3 includes the zero-shot scores of the 345M-parameter model that was trained regularly for 200K steps and then continued training for 20K steps of kNN-Pretraining, versus the average of 3 baselines that were trained regularly for the same overall number of steps (the same models used in figure 1). Similarly to the results on Natural Questions (figure 1), all examined models score only slightly better than random guessing on the examined GLUE tasks. However (and again similarly to the case of Natural Questions), we get a clear signal that kNN-Pretraining significantly moves the needle when applied for just 10% of the regular pretraining time. We conjecture that with stronger models (that train for longer and over more data), the positive effect of kNN-Pretraining will be enhanced, since as the model improves, it can better understand and utilize the various in-context hints that kNN-Pretraining provides.