An Explanation of In-context Learning as Implicit Bayesian Inference

Sang Michael Xie, Aditi Raghunathan, Percy Liang, Tengyu Ma

Introduction

Large language models (LMs) such as GPT-3 (Brown et al., 2020, Lieber et al., 2021, Wang and Komatsuzaki, 2021, Radford et al., 2019) are pretrained on massive text corpora to predict the next word given previous words. They demonstrate the surprising ability to do in-context learning, where an LM “learns” to do a task simply by conditioning on a prompt containing input-output pairs, achieving SOTA results on LAMBADA (Paperno et al., 2016) and TriviaQA (Joshi et al., 2017) tasks (18% and 3% over previous SOTA (Brown et al., 2020)). For example, consider the task of predicting nationalities from names. A prompt (Figure 1) is constructed by concatenating independent “training” examples (e.g., “Albert Einstein was German”) followed by a “test example” (“Marie Curie was”). Conditioning on this prompt, GPT-3 places the largest probability on the correct output by inferring the task from examples. Intriguingly, GPT-3 was not explicitly pretrained to learn from examples, and the distribution of prompts (which concatenate independent examples) is quite different from natural language. Our understanding of in-context learning is limited since (i) real pretraining data is messy and (ii) in-context learning has so far required large-scale datasets and models.

In this paper, we introduce a simple pretraining distribution where in-context learning emerges. To generate a document, we first draw a latent concept $\theta$, which parameterizes the transitions of a Hidden Markov Model (HMM) (Baum and Petrie, 1966), then sample a sequence of tokens from the HMM (Figure 9). This latent variable structure is common in topic models such as LDA (Blei et al., 2003, Gruber et al., 2007). During pretraining, the LM must infer the latent concept across multiple sentences to generate coherent continuations. When conditioning on a prompt, in-context learning occurs when the LM also infers a shared prompt concept across examples to make a prediction. We assume the LM fits the pretraining distribution $p$ exactly with enough data and expressivity, so that the question of in-context learning becomes characterizing the conditional distribution of completions given prompts $p(\text{output}|\text{prompt})$ under the pretraining distribution, where the prompt is generated from a different distribution $p_{\text{prompt}}$. This conditional distribution, which is the posterior predictive distribution, marginalizes out the latent concepts:
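Written out in the notation of this paragraph, the marginalization takes the following form (a sketch; the integral over concepts is the assumed form of the sum over latent concepts):

$$p(\text{output}|\text{prompt})=\int_{\text{concept}}p(\text{output}|\text{concept},\text{prompt})\,p(\text{concept}|\text{prompt})\,d(\text{concept}).$$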

If $p(\text{concept}|\text{prompt})$ concentrates on the prompt concept with more examples, then the LM learns via marginalization by “selecting” the prompt concept. Thus, in-context learning can be viewed as the LM implicitly performing Bayesian inference.

The main challenge is that prompts are sampled from a different distribution than the pretraining distribution. The canonical Bayesian asymptotic tool is the Bernstein-von Mises theorem (van der Vaart, 1998, Kleijn and van der Vaart, 2012, Gunst and Shcherbakova, 2008), which asserts (under regularity conditions) that the posterior distribution of a latent variable concentrates on the maximum likelihood estimate. However, Bernstein-von Mises typically assumes observations are independent and/or drawn from the same distribution as the model, and neither assumption holds in our setting. We prove that despite the distribution mismatch, the asymptotic prediction error of in-context learning is optimal when the signal about the latent concept in each prompt example is larger than the error due to the distribution mismatch. Additionally, we prove that the in-context learning error decreases with the length of each example. Thus, information in the inputs, not just the input-output mapping, can be useful for in-context learning.

As a companion to this theory, we created the Generative IN-Context learning dataset (GINC), which is a small-scale synthetic dataset for studying in-context learning. We find that both Transformers (Vaswani et al., 2017) and LSTMs (Hochreiter and Schmidhuber, 1997) trained on GINC exhibit in-context learning. We verify intuitions from the theory, showing that the accuracy of in-context learning improves with the number of examples and example length. Ablations of the GINC dataset show that the latent concept structure in the pretraining distribution is crucial to the emergence of in-context learning.

The experiments also bring up open questions which go beyond our theory, which only studies the pretraining distribution. We find that scaling up the number of model parameters steadily improves the in-context accuracy despite achieving the same pretraining loss, showing that larger models may improve in-context learning beyond increasing the capacity for memorizing the training data better. Previously observed in-context learning phenomena such as sensitivity to example ordering (Zhao et al., 2021) and the existence of settings where zero-shot is better than one/few-shot learning (Brown et al., 2020) are also mirrored in GINC.

In-context learning setting

In our framework, a latent concept $\theta$ from a family of concepts $\Theta$ defines a distribution over observed tokens $o$ from a vocabulary $\mathcal{O}$. To generate a document, we first sample a concept from a prior $p(\theta)$ and then sample the document given the concept. Each pretraining document is a length $T$ sequence:
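Under this two-stage sampling, the marginal probability of a document can be sketched as a mixture over concepts (assuming the prior admits an integral over $\Theta$; for a discrete family the integral becomes a sum):

$$p(o_{1},\dots,o_{T})=\int_{\theta\in\Theta}p(o_{1},\dots,o_{T}|\theta)\,p(\theta)\,d\theta.$$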

We assume $p(o_{1},\dots,o_{T}|\theta)$ is defined by a Hidden Markov Model (HMM). The concept $\theta$ determines the transition probability matrix of the HMM hidden states $h_{1},\dots,h_{T}$ from a hidden state set $\mathcal{H}$.

The prompt distribution $p_{\text{prompt}}$ generates prompts for in-context learning. The prompt is a concatenation of $n$ independent training examples and 1 test input $x_{\text{test}}$, which are all conditioned on a shared prompt concept $\theta^{*}$. The goal is to predict the test output $y_{\text{test}}$ by predicting the next token.

A prompt example is composed of an input token sequence $x$ (e.g., Albert Einstein was) followed by an output token $y$ (e.g., German). In particular, the $i$-th training example $O_{i}$ consists of an input $x_{i}=O_{i}[1\colon k-1]$ (the first $k-1$ tokens) followed by an output token $y_{i}=O_{i}[k]$ at the end. (Footnote 2: The example length $k$ is fixed for simplicity; we leave extending our analysis to variable $k$ as future work.) The $i$-th training example is independently generated as follows:

1. Generate a start hidden state $h^{\text{start}}_{i}$ from a prompt start distribution $p_{\text{prompt}}$.

2. Given $h^{\text{start}}_{i}$, generate the example sequence $O_{i}=[x_{i},y_{i}]$ from $p(O_{i}|h^{\text{start}}_{i},\theta^{*})$, the pretraining distribution conditioned on a prompt concept $\theta^{*}$.
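The two generation steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not the GINC generator: the function name `sample_example`, the toy matrix sizes, and the randomly drawn transition and emission matrices are all assumptions made for the example.

```python
import numpy as np

def sample_example(rng, start_dist, transition, emission, k):
    """Sample one prompt example O_i = [x_i, y_i] of length k from an HMM.

    start_dist: (H,) prompt start distribution over hidden states (assumed given)
    transition: (H, H) row-stochastic transition matrix of the prompt concept
    emission:   (H, V) row-stochastic emission matrix
    """
    num_hidden, vocab_size = emission.shape
    h = rng.choice(num_hidden, p=start_dist)        # step 1: draw the start hidden state
    tokens = []
    for _ in range(k):                              # step 2: emit a token, then transition
        tokens.append(int(rng.choice(vocab_size, p=emission[h])))
        h = rng.choice(num_hidden, p=transition[h])
    return tokens[:-1], tokens[-1]                  # x_i is the first k-1 tokens, y_i the last

# Toy (hypothetical) prompt concept with 4 hidden states and 6 tokens.
rng = np.random.default_rng(0)
H, V, k = 4, 6, 5
transition = rng.dirichlet(np.ones(H), size=H)
emission = rng.dirichlet(np.ones(V), size=H)
start_dist = np.full(H, 1.0 / H)
x, y = sample_example(rng, start_dist, transition, emission, k)
```

Each call draws a fresh start state, which is what makes consecutive examples independent of one another given the concept.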

The test input $x_{\text{test}}=x_{n+1}$ is sampled similarly. Between consecutive examples, there is a special delimiter token $o^{\text{delim}}$. The prompt consists of a sequence of training examples ($S_{n}$) followed by the test example $x_{\text{test}}$:
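Following this description, the prompt can be written as the concatenation below (a sketch consistent with the definitions above; placing a delimiter after every training example, including the last, is an assumption):

$$[S_{n},x_{\text{test}}]=[x_{1},y_{1},o^{\text{delim}},x_{2},y_{2},o^{\text{delim}},\dots,x_{n},y_{n},o^{\text{delim}},x_{\text{test}}]\sim p_{\text{prompt}}.$$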

Since transitions between independent examples can be unnatural, the prompts are low probability sequences under the pretraining distribution. We provide a simple illustration using the names to nationalities example. Suppose that wiki bio documents in the pretraining data typically transition between name $\rightarrow$ nationality $\rightarrow$ occupation $\rightarrow\dots$. In the prompt, the examples transition between name $\rightarrow$ nationality $\rightarrow$ name $\rightarrow$ nationality $\rightarrow\dots$, which contains low-probability transitions such as “German” $\rightarrow$ “Mahatma Gandhi”. The prompt formatting (e.g., choice of delimiter) can also be a source of mismatch. We aim to show that despite this mismatch, large LMs can infer the prompt concept from examples.

For in-context learning, the output target $y$ for each example $x$ is sampled according to $p_{\text{prompt}}(y|x)$:
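For the test example, one way to write this sampling rule is as an average over the start hidden state of the test input (a sketch; conditioning on $\theta^{*}$ and on the start state follows the generation steps above, while the exact form of the expectation is an assumption):

$$y_{\text{test}}\sim p_{\text{prompt}}(y|x_{\text{test}})=\mathbb{E}_{h^{\text{start}}_{\text{test}}\sim p_{\text{prompt}}(h^{\text{start}}_{\text{test}}|x_{\text{test}})}\left[p(y|x_{\text{test}},h^{\text{start}}_{\text{test}},\theta^{*})\right]$$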

where $h^{\text{start}}_{\text{test}}$ denotes the hidden state corresponding to the first token of $x_{\text{test}}$.

We analyze the in-context predictor $f_{n}(x_{\text{test}})=\operatorname*{arg\,max}_{y}p(y|S_{n},x_{\text{test}})$, which outputs the most likely prediction over the pretraining distribution conditioned on the prompt from the prompt distribution. (Footnote 3: In practice, greedy decoding or nucleus sampling (Holtzman et al., 2020) is used for likely completions.) We study the in-context predictor and its expected 0-1 error with $n$ examples, $L_{\text{0-1}}(f_{n})=\mathbb{E}_{x_{\text{test}},y_{\text{test}}\sim p_{\text{prompt}}}[\mathbf{1}[f_{n}(x_{\text{test}})\neq y_{\text{test}}]]$.
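For a finite concept family, the in-context predictor can be computed exactly with the HMM forward algorithm. The sketch below is illustrative only: the names `forward_filter` and `in_context_predict` are hypothetical, the two toy concepts are invented for the example, delimiters are treated as ordinary tokens, and every prompt token is assumed to have non-zero probability under each concept.

```python
import numpy as np

def forward_filter(tokens, start, transition, emission):
    """Return log p(tokens | theta) and the filtered distribution of the last hidden state."""
    log_lik = 0.0
    alpha = None
    for t, tok in enumerate(tokens):
        predicted = start if t == 0 else alpha @ transition   # p(h_t | o_{1:t-1})
        alpha = predicted * emission[:, tok]                   # joint with o_t
        norm = alpha.sum()                                     # p(o_t | o_{1:t-1})
        log_lik += np.log(norm)
        alpha = alpha / norm
    return log_lik, alpha

def in_context_predict(prompt_tokens, concepts, prior):
    """argmax_y p(y | prompt), marginalizing over a finite family of HMM concepts."""
    log_post = np.log(np.asarray(prior, dtype=float))
    next_token_dists = []
    for i, (start, transition, emission) in enumerate(concepts):
        log_lik, alpha = forward_filter(prompt_tokens, start, transition, emission)
        log_post[i] += log_lik                                 # log p(theta) + log p(prompt | theta)
        next_token_dists.append((alpha @ transition) @ emission)  # p(y | prompt, theta)
    post = np.exp(log_post - log_post.max())
    post /= post.sum()                                         # p(theta | prompt)
    predictive = post @ np.array(next_token_dists)             # posterior predictive over y
    return int(np.argmax(predictive)), post

# Two toy concepts that share emissions but differ in transitions.
emission = np.array([[0.9, 0.1], [0.1, 0.9]])
start = np.array([0.5, 0.5])
sticky = np.array([[0.9, 0.1], [0.1, 0.9]])
alternating = np.array([[0.1, 0.9], [0.9, 0.1]])
concepts = [(start, sticky, emission), (start, alternating, emission)]
prediction, posterior = in_context_predict([0, 0, 0, 0, 0, 0], concepts, [0.5, 0.5])
```

On a prompt of repeated tokens, the posterior mass moves to the "sticky" concept, so the prediction is dominated by that concept's next-token distribution.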

2.1 Assumptions

We detail the assumptions in our framework, including the structure of delimiters and regularity assumptions. We first assume that there exists a subset of delimiter hidden states $\mathcal{D}$ which generates the special delimiter token $o^{\text{delim}}$ deterministically.

Assumption 1. Let the delimiter hidden states $\mathcal{D}$ be a subset of $\mathcal{H}$. For any $h^{\text{delim}}\in\mathcal{D}$ and $\theta\in\Theta$, $p(o^{\text{delim}}|h^{\text{delim}},\theta)=1$ and for any $h\notin\mathcal{D}$, $p(o^{\text{delim}}|h,\theta)=0$.

Thus, observing the delimiter $o^{\text{delim}}$ reveals that the corresponding hidden state is in $\mathcal{D}$, but does not reveal which element of $\mathcal{D}$ it is. The delimiter is usually a token that can appear in a broad range of contexts (e.g., newline). The delimiter ideally does not distract from the examples; for example, an adversarial delimiter could look like part of the input $x$. To mitigate these scenarios, we assume that no delimiter (e.g., newline) is significantly more likely under one concept than under another.

Assumption 2. For any delimiter state $h^{\text{delim}}\in\mathcal{D}$ and any hidden state $h\in\mathcal{H}$, the probability of transitioning to a delimiter hidden state under $\theta$ is upper bounded $p(h^{\text{delim}}|h,\theta)<c_{2}$ for any $\theta\in\Theta\setminus\{\theta^{*}\}$, and is lower bounded $p(h^{\text{delim}}|h,\theta^{*})>c_{1}>0$ for $\theta^{*}$. Additionally, the start hidden state distribution for delimiter hidden states is bounded as $p(h^{\text{delim}}|\theta)\in[c_{3},c_{4}]$.

The choice of prompt start distribution can be a source of distribution shift which is separate from the distribution shift from concatenating independent examples. We make an assumption that limits how much distribution shift is introduced by the prompt start distribution.

Assumption 3. We assume that the prompt start distribution $p_{\text{prompt}}$ is close in TV distance to all hidden transition distributions (under $\theta^{*}$) starting from a delimiter hidden state: $\max_{h^{\text{delim}}\in\mathcal{D}}TV(p_{\text{prompt}}(h)\|p(h|h^{\text{delim}},\theta^{*}))<\Delta/4$. Here, $\Delta=p_{\text{prompt}}(y_{\text{max}}|x_{\text{test}})-\max_{y\neq y_{\text{max}}}p_{\text{prompt}}(y|x_{\text{test}})$ is the margin between the most likely label $y_{\text{max}}=\operatorname*{arg\,max}_{y}p_{\text{prompt}}(y|x_{\text{test}})$ and the second most likely label.
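The condition on the prompt start distribution is easy to check numerically once the relevant distributions are available. The following sketch assumes they are given as probability vectors; the helper names (`tv_distance`, `margin`, `start_shift_ok`) and the toy numbers are hypothetical.

```python
import numpy as np

def tv_distance(p, q):
    """Total variation distance between two discrete distributions."""
    return 0.5 * float(np.abs(np.asarray(p) - np.asarray(q)).sum())

def margin(label_probs):
    """Delta: gap between the most likely and second most likely label."""
    top_two = np.sort(np.asarray(label_probs))[-2:]
    return float(top_two[1] - top_two[0])

def start_shift_ok(prompt_start, delim_transitions, label_probs):
    """Check max over delimiter states of TV(p_prompt, p(. | h_delim, theta*)) < Delta / 4."""
    worst_tv = max(tv_distance(prompt_start, row) for row in delim_transitions)
    return worst_tv < margin(label_probs) / 4

# Toy numbers: two hidden states, two delimiter states, three labels.
prompt_start = [0.5, 0.5]
delim_transitions = [[0.5, 0.5], [0.45, 0.55]]   # rows: p(h | h_delim, theta*)
label_probs = [0.7, 0.2, 0.1]                    # p_prompt(y | x_test)
ok = start_shift_ok(prompt_start, delim_transitions, label_probs)
```

Here the worst-case TV distance is 0.05 and the margin is 0.5, so the bound $\Delta/4=0.125$ is satisfied.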

Note that even when the maximum TV distance is 0, there is still distribution shift from concatenating independent examples.

We also assume the prompt concept $\theta^{*}$ is in the family $\Theta$, which is a broad set of concepts.

Assumption 4. The prompt concept $\theta^{*}$ is in $\Theta$.

Even though the pretraining distribution is broad, the prompt is still low probability under the pretraining distribution since it concatenates independent examples.

Finally, if the prompt has zero probability under the prompt concept $\theta^{*}$, then Bayesian inference will not be able to infer the prompt concept as in Section 3.1. The following are regularity assumptions which mainly ensure that the prompt is not zero probability under $\theta^{*}$.

Assumption 5. The pretraining distribution $p$ satisfies: 1) Lower bound on transition probability for the prompt concept $\theta^{*}$: for any pair of hidden states $h,h^{\prime}\in\mathcal{H}$, $p(h|h^{\prime},\theta^{*})>c_{5}>0$. 2) Start hidden state is lower bounded: for any $h\in\mathcal{H}$, $p(h|\theta^{*})\geq c_{8}>0$. 3) All tokens can be emitted: for every symbol $o$, there is some hidden state $h\in\mathcal{H}$ such that $p(o|h,\theta^{*})>c_{6}>0$. 4) The prior $p(\theta)$ has support over the entire concept family $\Theta$ and is bounded above everywhere.

Theoretical analysis

We prove that in the limit of infinite examples, the error of the in-context predictor is optimal if a distinguishability condition holds: the prompt concept $\theta^{*}$ is distinct enough from the other concepts in $\Theta$ (e.g., when $\Theta$ is a discrete set). When distinguishability does not hold (e.g., $\Theta$ is continuous-valued), we show that the expected error still decreases with the length of each example, showing that information in both the inputs and the input-output mapping contributes to in-context learning.

Our goal is to show that $\operatorname*{arg\,max}_{y}p(y|S_{n},x_{\text{test}})\rightarrow\operatorname*{arg\,max}_{y}p_{\text{prompt}}(y|x_{\text{test}})$ as the number of examples $n$ grows. In the following, assume that the prompt has non-zero probability under the pretraining distribution $p$ given $\theta^{*}$, meaning that $p(S_{n},x_{\text{test}}|\theta^{*})>0$. We expand $p(y|S_{n},x_{\text{test}})$ to analyze its limit:
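One way to carry out this expansion is sketched below. It uses only Bayes' rule over concepts followed by division by $p(S_{n},x_{\text{test}}|\theta^{*})$, which is positive by the assumption above and constant in $y$; a further decomposition over the hidden state of the test input is omitted from this sketch.

$$\begin{aligned}
p(y|S_{n},x_{\text{test}})&=\int_{\theta}p(y|S_{n},x_{\text{test}},\theta)\,p(\theta|S_{n},x_{\text{test}})\,d\theta\\
&\propto\int_{\theta}p(y|S_{n},x_{\text{test}},\theta)\,p(S_{n},x_{\text{test}}|\theta)\,p(\theta)\,d\theta\\
&\propto\int_{\theta}p(y|S_{n},x_{\text{test}},\theta)\,\frac{p(S_{n},x_{\text{test}}|\theta)}{p(S_{n},x_{\text{test}}|\theta^{*})}\,p(\theta)\,d\theta\\
&=\int_{\theta}p(y|S_{n},x_{\text{test}},\theta)\,\exp(n\cdot r_{n}(\theta))\,p(\theta)\,d\theta
\end{aligned}$$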

where $r_{n}(\theta)=\frac{1}{n}\log\frac{p(S_{n},x_{\text{test}}|\theta)}{p(S_{n},x_{\text{test}}|\theta^{*})}$. In Theorem 1, we prove that under a distinguishability condition, $\exp(n\cdot r_{n}(\theta))\rightarrow 0$ for all concepts $\theta$ except the prompt concept $\theta^{*}$, where $\exp(n\cdot r_{n}(\theta^{*}))=1$. The only nonzero term in the integral is when $\theta=\theta^{*}$, and thus the prompt concept is “selected” as a consequence of Bayesian inference. (Footnote 4: We can exchange limits and integrals since the probabilities are bounded, by dominated convergence.) Lemma 1 shows that the argmax after restricting to $\theta^{*}$ is the same as the most likely label under $p_{\text{prompt}}(y|x_{\text{test}})$ (using Assumption 3). Putting these together with Equation 6, the in-context predictor infers the prompt concept $\theta^{*}$:

$$\operatorname*{arg\,max}_{y}p(y|S_{n},x_{\text{test}})\rightarrow\operatorname*{arg\,max}_{y}p_{\text{prompt}}(y|x_{\text{test}}).$$

Thus, the in-context predictor is optimal as the number of in-context examples increases.

2 Heuristic derivation

Recall from Section 3.1 that if $\exp(n\cdot r_{n}(\theta))\rightarrow 0$ for all $\theta\neq\theta^{*}$, then Bayesian inference “selects” the prompt concept through marginalization. To establish this, we focus on showing that $r_{n}(\theta)$, the average log-likelihood ratio between $\theta$ and $\theta^{*}$, converges to a negative constant, and thus $nr_{n}$ goes to $-\infty$.

The main technical challenge is to handle the sequence-of-examples structure of the prompt, which makes all the examples dependent with respect to the pretraining distribution. Our approach uses properties of delimiter tokens to approximately factorize the examples, with constant error per example. We let $O^{\text{ex}}_{i}=[o^{\text{delim}}_{i-1},O_{i}]$ be the $i$-th input-output pair and the previous delimiter together for $i>1$ and define $O^{\text{ex}}_{1}=O_{1}$. Expanding the likelihood term inside $r_{n}(\theta)$, our goal is to show

To show this, we expand $p(S_{n}|\theta)$ with the chain rule, and with Assumption 5 (to bound $p(x_{\text{test}}|S_{n},\theta)$ by $O(1)$) it can be shown that

We then marginalize $p(O^{\text{ex}}_{i}|O^{\text{ex}}_{1:i-1},\theta)$ over the hidden state $h^{\text{delim}}_{i-1}$ corresponding to the delimiter in $O^{\text{ex}}_{i}=[o^{\text{delim}}_{i-1},O_{i}]$:

While summing over $\mathcal{H}$ above would be a trivial equality, we can replace $\mathcal{H}$ with the set of delimiter hidden states $\mathcal{D}$ since $p(h|O^{\text{ex}}_{1:i-1},\theta)=0$ for non-delimiter hidden states $h\notin\mathcal{D}$ (Assumption 1). We used in the first equality that $O^{\text{ex}}_{1:i-1}\rightarrow h^{\text{delim}}_{i-1}\rightarrow O^{\text{ex}}_{i}$ forms a Markov chain and $p(o^{\text{delim}}_{i-1}|h^{\text{delim}}_{i-1})=1$ (Assumption 1) to change $O^{\text{ex}}_{i}$ to $O_{i}$. Finally, we can show using properties of delimiter hidden states (Assumption 2) that $p(h^{\text{delim}}_{i-1}|O^{\text{ex}}_{1:i-1},\theta)=O(1)$ and $\sum_{h^{\text{delim}}_{i-1}\in\mathcal{D}}p(O_{i}|h^{\text{delim}}_{i-1},\theta)\approx O(1)p(O_{i}|\theta)$ in the second step. Therefore, we can upper bound $r_{n}(\theta)$ as

$$r_{n}(\theta)\leq\frac{1}{n}\left(O(n)+\sum_{i=1}^{n}\log\frac{p(O_{i}|\theta)}{p(O_{i}|\theta^{*})}\right)\rightarrow O(1)+\mathbb{E}_{O\sim p_{\text{prompt}}}\left[\log\frac{p(O|\theta)}{p(O|\theta^{*})}\right].\tag{11}$$

The expectation term can be written as the difference of two KL divergences, $KL(p_{\text{prompt}}(O)\|p(O|\theta^{*}))-KL(p_{\text{prompt}}(O)\|p(O|\theta))$. We bound the first KL term by a constant using Assumption 5; intuitively, for one example, $p_{\text{prompt}}$ and $p(\cdot|\theta^{*})$ are close. We break the second term into a sum of negative KL divergences over $k$ tokens. There are $O(k)$ KL terms and only $O(1)$ other error terms, which come from the distribution mismatch between the prompt and pretraining distributions. If the KL terms are larger than the error terms, then $r_{n}(\theta)$ has a negative limit. If this holds for all $\theta\neq\theta^{*}$, then we have $\exp(n\cdot r_{n}(\theta))\rightarrow 0$ for all $\theta\neq\theta^{*}$, enabling in-context learning.

3 Formal results

We define a distinguishability condition which formalizes when in-context learning occurs. Letting $p^{j}_{\theta}(o)\coloneqq p(O[j]=o|O[1:j-1],\theta)$ be the output distribution of the $j$-th token given the previous tokens and $p_{\text{prompt}}^{j}(o)\coloneqq p_{\text{prompt}}(O[j]=o|O[1:j-1])$ be the analogous distribution under the prompt distribution, the distinguishability condition depends on the KL divergence between $p_{\text{prompt}}^{j}$ (which represents $\theta^{*}$) and $p^{j}_{\theta}$ as well as error terms $\epsilon^{\theta}_{\text{start}}$ and $\epsilon^{\theta}_{\text{delim}}$ coming from the distribution mismatch between the prompt and pretraining distributions at the start and delimiter token for each example:
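A natural reading of the per-token KL term $KL_{j}$ used in the condition, given the two conditional distributions just defined, averages the divergence over prefixes drawn from the prompt distribution (the expectation over $O[1:j-1]$ is an assumption of this sketch):

$$KL_{j}(\theta^{*}\|\theta)\coloneqq\mathbb{E}_{O[1:j-1]\sim p_{\text{prompt}}}\left[KL(p_{\text{prompt}}^{j}\|p^{j}_{\theta})\right].$$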

$$\epsilon^{\theta}_{\text{delim}}\coloneqq 2(\log(c_{2})-\log(c_{1}))+\log(c_{4})-\log(c_{3}),\qquad\epsilon^{\theta}_{\text{start}}\coloneqq\log(1/c_{8}).\tag{13}$$

Condition 1 (Distinguishability). We define $\theta^{*}$ to be distinguishable if for all $\theta\in\Theta,\theta\neq\theta^{*}$,

$$\sum_{j=1}^{k}KL_{j}(\theta^{*}\|\theta)>\epsilon^{\theta}_{\text{start}}+\epsilon^{\theta}_{\text{delim}}.\tag{14}$$

When the signal from KL divergence (LHS) is larger than the error terms, Equation 14 is satisfied (Figure 2). For larger example lengths $k$, the LHS increases, improving distinguishability. Intuitively, larger example lengths increase the proportion of the prompt sampled from the pretraining distribution by providing more evidence for Bayesian inference. Under Condition 1, the in-context predictor asymptotically achieves the optimal expected error.
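As a quick numerical illustration of Equations 13 and 14, the sketch below evaluates the two error terms from the constants $c_{1},c_{2},c_{3},c_{4},c_{8}$ and compares them with a sum of per-token KL values. The function names and all numbers are hypothetical and chosen only to show how a longer example can tip the inequality.

```python
import math

def error_terms(c1, c2, c3, c4, c8):
    """Error terms of Equation 13: returns (eps_start, eps_delim)."""
    eps_delim = 2 * (math.log(c2) - math.log(c1)) + math.log(c4) - math.log(c3)
    eps_start = math.log(1 / c8)
    return eps_start, eps_delim

def is_distinguishable(kl_per_token, c1, c2, c3, c4, c8):
    """Equation 14 for one alternative concept: sum_j KL_j > eps_start + eps_delim."""
    eps_start, eps_delim = error_terms(c1, c2, c3, c4, c8)
    return sum(kl_per_token) > eps_start + eps_delim

# Hypothetical constants: eps_delim = 3*log(2), eps_start = log(4).
constants = dict(c1=0.1, c2=0.2, c3=0.1, c4=0.2, c8=0.25)
short_example = is_distinguishable([0.5] * 5, **constants)   # total KL 2.5
long_example = is_distinguishable([0.5] * 8, **constants)    # total KL 4.0
```

With these numbers the error terms sum to about 3.47, so the length-5 example fails the condition while the length-8 example satisfies it.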

Theorem 1. Assume the assumptions in Section 2.1 hold. If Condition 1 holds, then as $n\rightarrow\infty$ the prediction according to the pretraining distribution is

$$\operatorname*{arg\,max}_{y}p(y|S_{n},x_{\text{test}})\rightarrow\operatorname*{arg\,max}_{y}p_{\text{prompt}}(y|x_{\text{test}}).\tag{15}$$

Thus, the in-context predictor $f_{n}$ achieves the optimal 0-1 risk: $\lim_{n\rightarrow\infty}L_{\text{0-1}}(f_{n})=\inf_{f}L_{\text{0-1}}(f)$.

3.2 Non-distinguishable case

The distinguishability condition (Condition 1) fails when there is some $\theta\neq\theta^{*}$ for which the KL divergence between $\theta$ and $\theta^{*}$ is less than the error terms. However, this also means that the output distributions of $\theta$ and $\theta^{*}$ are close in KL. We leverage this to prove that the expected 0-1 error decreases with the example length $k$ under two different settings where distinguishability does not hold.

Our first result relies on a continuity assumption between the concept parameter and its corresponding output distribution. Our assumption is based on prior works (Kleijn and van der Vaart, 2012), where the KL divergence is assumed to have a 2nd-order Taylor expansion.

Theorem 2. Let $\mathcal{B}$ be the set of $\theta$ that do not satisfy Equation 14 in Condition 1. Assume that KL divergences have a 2nd-order Taylor expansion around $\theta^{*}$:

$$\forall j>1,~~KL_{j}(\theta^{*}\|\theta)=\frac{1}{2}(\theta-\theta^{*})^{\top}I_{j,\theta^{*}}(\theta-\theta^{*})+O(\|\theta-\theta^{*}\|^{3})\tag{16}$$

where $I_{j,\theta^{*}}$ is the Fisher information matrix of the $j$-th token distribution with respect to $\theta^{*}$. Let $\gamma_{\theta^{*}}=\frac{\max_{j}\lambda_{\text{max}}(I_{j,\theta^{*}})}{\min_{j}\lambda_{\text{min}}(I_{j,\theta^{*}})}$ where $\lambda_{\text{max}},\lambda_{\text{min}}$ return the largest and smallest eigenvalues. Then for $k\geq 2$ and as $n\rightarrow\infty$, the 0-1 risk of the in-context learning predictor $f_{n}$ is bounded as

$$\lim_{n\rightarrow\infty}L_{\text{0-1}}(f_{n})\leq\inf_{f}L_{\text{0-1}}(f)+g^{-1}\left(O\left(\frac{\gamma_{\theta^{*}}\sup_{\theta\in\mathcal{B}}(\epsilon^{\theta}_{\text{start}}+\epsilon^{\theta}_{\text{delim}})}{k-1}\right)\right)\tag{17}$$

where $g(\delta)=\frac{1}{2}((1-\delta)\log(1-\delta)+(1+\delta)\log(1+\delta))$ is a calibration function (Steinwart, 2007, Ávila Pires and Szepesvári, 2016) for the multiclass logistic loss for $\delta\in[0,1)$, assuming that the minimizers of the 0-1 risk and multiclass logistic risk are the same.

Since the inverse calibration function $g^{-1}$ is roughly linear in $\epsilon$ for $\epsilon\leq 0.7$, the excess risk roughly decreases as $O(1/k)$. When the “worst-case condition number” $\gamma_{\theta^{*}}$ of the Fisher information matrices is smaller (well-conditioned), the error decreases. Intuitively, this means that there is no direction to vary $\theta^{*}$ in which the output distribution will sharply change. As a consequence, the concepts $\theta$ that are not distinguishable from the prompt concept $\theta^{*}$ parameterize distributions that produce similar outputs to the prompt concept and thus achieve a small error.
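The calibration function has no closed-form inverse, but it is increasing on $[0,1)$, so $g^{-1}$ can be evaluated by bisection. The sketch below is illustrative: `g_inverse` and the constant `0.5` standing in for the quantity hidden inside the $O(\cdot)$ of Equation 17 are assumptions, not values from the analysis.

```python
import math

def g(delta):
    """Calibration function g(delta) for 0 <= delta < 1."""
    return 0.5 * ((1 - delta) * math.log(1 - delta) + (1 + delta) * math.log(1 + delta))

def g_inverse(eps, tol=1e-10):
    """Invert g by bisection; g is increasing on [0, 1) with supremum log(2)."""
    upper = 1.0 - 1e-12
    if eps >= g(upper):
        return 1.0                      # bound is vacuous beyond the range of g
    lo, hi = 0.0, upper
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if g(mid) < eps:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

# Excess-risk term g^{-1}(C / (k - 1)) for a hypothetical constant C = 0.5.
excess = {k: g_inverse(0.5 / (k - 1)) for k in (3, 5, 8, 10)}
```

The resulting values shrink as the example length $k$ grows, which is the qualitative behavior the bound describes.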

In the setting where the length of $x_{\text{test}}$ is random (uniformly from 2 to $k$), we can give a similar error guarantee without continuity.

Theorem 3. Let $\mathcal{B}$ be the set of $\theta$ that do not satisfy Equation 14 in Condition 1. Let the length of the test example $x_{\text{test}}$ be uniformly distributed between 2 and $k$, for $k\geq 2$. Then for $k\geq 2$ and as $n\rightarrow\infty$, the 0-1 risk of the in-context learning predictor $f_{n}$ is bounded as

$$\lim_{n\rightarrow\infty}L_{\text{0-1}}(f_{n})\leq\inf_{f}L_{\text{0-1}}(f)+g^{-1}\left(O\left(\frac{\sup_{\theta\in\mathcal{B}}(\epsilon^{\theta}_{\text{start}}+\epsilon^{\theta}_{\text{delim}})}{k-1}\right)\right),\tag{18}$$

assuming that the minimizers of the 0-1 risk and multiclass logistic risk are the same.

Instead of measuring only the error at the $k$-th token, we average the prediction error on the 2nd to $k$-th tokens. However, we leave bridging the mismatch between training examples, which are consistently length $k$, and test examples, which have random length, to future work.

Simulations

We generate the GINC dataset and show that Transformers (Vaswani et al., 2017) and LSTMs (Hochreiter and Schmidhuber, 1997) trained on GINC exhibit in-context learning. In the theory, we assumed that the pretrained LM fits the pretraining distribution exactly. Here, we pretrain LMs to approximate the pretraining distribution, showing that the in-context learning properties of the pretraining distribution transfer to the LM.

We construct the GINC dataset according to our theory (see Appendix F.1). For pretraining, we define a uniform mixture of HMMs over a family $\Theta$ of 5 concepts to generate 1000 pretraining documents with $\sim$10 million tokens total. For prompting, we generate prompts with 0 to 64 training examples and example lengths $k\in\{3,5,8,10\}$ (2500 prompts for each setting). The target token $y_{\text{test}}$ is taken to be the most likely output $\operatorname*{arg\,max}_{y}p_{\text{prompt}}(y|x_{\text{test}})$ instead of sampling so that the intrinsic error is 0.

We train GPT-2-based Transformers (Radford et al., 2019) and LSTMs on three versions of the GINC dataset with vocabulary sizes 50, 100, and 150, then evaluate the in-context accuracy (see Appendix F.2, F.3). We average all results over 5 pretraining runs. Figure 3 shows that for both Transformers and LSTMs, in-context accuracy improves as the number of prompt examples $n$ and the example length $k$ increase, verifying our theory.

We ablate the role of the mixture-of-concepts structure in GINC. In Figure 4 (left), we pretrain a 4 layer Transformer on data with only one concept (removing the prior) from $\Theta$, resulting in flat in-context learning curves. Figure 4 (middle) shows that in-context learning also fails when pretraining on random data, which contains all possible token transitions. Therefore, the mixture-of-concepts structure is important and simply seeing diverse token transitions does not enable in-context learning.

Full generative control of GINC allows for experimentation with latent variables in the pretraining distribution. For example, in large-scale datasets, it is difficult to test whether a concept or task is in the pretraining data. We test this in GINC by testing the in-context accuracy of a 4 layer Transformer on prompts generated from 5 random concepts that are not in the pretraining family of concepts. Figure 4 (right) shows that in-context learning also fails for these novel concepts.

Figure 6 shows that increasing the size of the Transformer (4, 12, 16 layers) steadily increases the in-context accuracy, corroborating the results of Brown et al. (2020). Table 6 shows that even though larger Transformers may have the same pretraining loss (e.g., 12 and 16 layer Transformers both get 1.33 validation loss for vocab size 50), the in-context accuracy still improves (81% to 85% from 12 to 16 layers), suggesting that larger models can improve in-context learning beyond improving pretraining perplexity. This may be related to phenomena from overparameterization and overtraining (Zhang et al., 2017, Power et al., 2021). Finally, the model architecture also plays a role — LSTMs consistently outperform Transformers on GINC despite having fewer parameters, perhaps due to the similarity between HMMs and LSTMs. We leave analysis of the effect of model scaling and model architecture as open questions.

In Figure 7 (left), we test the sensitivity of in-context accuracy on GINC to the ordering of the prompt examples, following Zhao et al. (2021). For this experiment, we consider prompts generated from a single concept and prompt start distribution. We sample 10 different sets (leading to 10 training set IDs) of 4 examples and generate all 24 possible permutations for each example set. We consider the in-context accuracy of the 4 layer Transformer trained on GINC with vocabulary size 50. Similarly to the behavior of GPT-3 (Zhao et al., 2021), there is a significant variation (10–40% difference) between permutations of the same set of examples.

In some settings in GINC, we find that zero-shot performance can be better than few-shot performance. This mirrors GPT-3 on some datasets (e.g., LAMBADA, HellaSwag, PhysicalQA, RACE-m, CoQA/SAT analogies for smaller models (Brown et al., 2020)). This occurs especially when the transition probabilities in GINC are lower entropy (controlled via a temperature parameter). For this experiment, we consider GINC with transition matrix temperature parameter 0.01 (instead of 0.1), 12 concepts, and vocabulary size 100. Figure 7 (right) shows that here, few-shot accuracy is initially worse than zero-shot accuracy, but can recover with more examples. We hypothesize that the distracting prompt structure initially decreases the accuracy in this setting.

Discussion and related work

The canonical Bernstein-von Mises theorem (van der Vaart, 1998) does not apply for in-context learning since the prompt examples are not independent under the pretraining distribution. Gunst and Shcherbakova (2008) show a Bernstein-von Mises-type result for observations from an HMM, but do not handle observations from a different distribution. Future directions include more precise asymptotic results about the posterior distribution and results under misspecification/extrapolation (Kleijn and van der Vaart, 2012). A possible avenue for extrapolation to some types of unseen concepts is to factorize the latent concept into semantics and syntax. While the pretraining data may contain only some semantics-syntax pairs, the language model could generalize to unseen pairs if it learns generalizable syntactical operations such as copying or reordering.

Topic models such as LDA (Blei et al., 2003) also have document-level latent variables, but learning typically relies on algorithms such as EM (Dempster et al., 1977), variational inference (Jordan et al., 1999), or MCMC (Metropolis et al., 1953, Hastings, 1970). We focus on learning as a natural result of Bayesian inference without an explicit inference algorithm. Wei et al. (2021a) also use an HMM model in their pretraining analysis. However, they analyze how pre-trained representations learned with masked LMs (Devlin et al., 2019, Liu et al., 2019, Lewis et al., 2020, Clark et al., 2020) can improve optimization-based downstream learning (Li and Liang, 2021, Lester et al., 2021) rather than in-context learning.

Prior works support our theoretical intuition that reducing the prompt distribution mismatch would improve in-context learning. Finetuning LMs on text with a prompting format improves their zero-shot performance (Wei et al., 2021b, Sanh et al., 2021) and optimizing prompt templates improves few-shot finetuning (Jiang et al., 2020, Schick and Schütze, 2021, Shin et al., 2020, Gao et al., 2021). Zhao et al. (2021) and Holtzman et al. (2021) improve in-context accuracy via calibration or renormalization, a form of adaptation to the prompt distribution.

Meta-learning methods can also train a sequence model to learn from examples (Ravi and Larochelle, 2017). However, meta-learning models are trained to learn, while in-context learning emerges from LM pretraining.

We can study in-context learning, a large scale phenomenon, at a small scale in GINC because the complexity of the pretraining distribution (HMM hidden state size, number of latent concepts) is small, such that the data and models are relatively larger. Since GINC is synthetic, we can also control the latent data properties (e.g., unseen concepts) to make predictions about large LMs while working at a small scale.

Conclusion

We cast in-context learning as implicit Bayesian inference, where the pretrained LM implicitly infers a concept when making a prediction. We show that in-context learning occurs when the pre-training distribution is a mixture of HMMs. Our work provides a first step towards understanding in-context learning, which we hope will provide insight for improving pretraining and prompting.

Acknowledgements

We thank Tianyi Zhang, Frieda Rong, Lisa Li, Colin Wei, Shibani Santurkar, Tri Dao, Ananya Kumar, and Shivam Garg for helpful discussions and feedback. SMX is supported by an NDSEG Fellowship. The work is partially supported by an Open Philanthropy Project Award, SDSI, and SAIL at Stanford University. TM acknowledges support of Google Faculty Award, NSF IIS 2045685, the Sloan Fellowship, and JD.com. Toyota Research Institute provided funds to support this work.

References

Appendix A Framework details

For in-context learning, we sample a prompt from a new distribution $p_{\text{prompt}}$, which consists of $n$ independent training examples and 1 test example. We first sample $n$ hidden segments $H$ of length $k$ by sampling the first element $h^{\text{start}}=H[1]$ from a prompt start distribution $p_{\text{prompt}}$. Then, we sample the rest of the segment $H^{\text{seg}}=H[2:k]$ from the hidden transition distribution of the pretraining distribution $p$ corresponding to a particular concept $\theta^{*}$:

To end each example (except the test example), we sample nn delimiters hdelimDh^{\text{delim}}\in\mathcal{D} from ppromptdelimp_{\text{prompt}}^{\text{delim}}:

Conditioned on hidden variables HiH_{i} and hidelimh^{\text{delim}}_{i}, we sample the observed tokens Oi=[oi,1,,oi,k]O_{i}=[o_{i,1},\dots,o_{i,k}] and oidelimo^{\text{delim}}_{i} respectively from the pre-training distribution:

The “input” for each example is xi=Oi[1:k1]x_{i}=O_{i}[1:k-1] and the “output” is yi=Oi[k]y_{i}=O_{i}[k]. Taking SS to be the sequence of training examples (without the test example), the resulting prompt sequence is

where xtest=xn+1=On+1[1:k1]x_{\text{test}}=x_{n+1}=O_{n+1}[1:k-1] is sampled via the same process but with k1k-1 elements.

Appendix B Propositions for Theorem 1

The following propositions, which lower bound the probability of a delimiter token and probability of an example under θ{\theta^{*}}, are direct corollaries of the assumptions.

For all ii, we have p(hidelimO1,o1delim,,Oi,θ)>c1p(h^{\text{delim}}_{i}|O_{1},o^{\text{delim}}_{1},\dots,O_{i},{\theta^{*}})>c_{1} and p(hidelimO1,o1delim,,Oi,θ)<c2p(h^{\text{delim}}_{i}|O_{1},o^{\text{delim}}_{1},\dots,O_{i},\theta)<c_{2}.

The probability of an example is lower bounded for θ{\theta^{*}}: there is some c7>0c_{7}>0 such that p(Oihistart,hj,l,θ)>c7p(O_{i}|h^{\text{start}}_{i},h_{j,l},{\theta^{*}})>c_{7} for all ii and future hidden states hj,lh_{j,l}, for any ll and j>ij>i.

which lower bounds the terms in the numerator by $c_{5}$ (marginalizing over previous hidden states) and upper bounds the denominator by 1. Setting $c_{7}=(c_{6})^{k}c_{5}^{2}$ finishes the proof. ∎

Appendix C Convergence of the in-context predictor

Under Assumption 3, we show that the in-context predictor fn(xtest)=arg maxyp(ySn,xtest)f_{n}(x_{\text{test}})=\operatorname*{arg\,max}_{y}p(y|S_{n},x_{\text{test}}) converges when abstracting away the Bayesian inference component (the selection of θ{\theta^{*}} from Θ\Theta) of the in-context predictor. We will complete the argument for the convergence of the in-context predictor in the proof of Theorem 1.

Suppose the prompt SnS_{n} and the test input xtestx_{\text{test}} are given. Under Assumption 3, we show that the argmax of the averaged predictive distribution conditioned on θ{\theta^{*}} and a prompt SnS_{n} is the same as the argmax of the prompt predictive distribution:

which is proportional to a constant in xtestx_{\text{test}}.

On the other hand, analyzing one term inside the LHS of the lemma statement, we have

which is proportional to a constant in xtestx_{\text{test}} and SnS_{n}. The quantities differ in the last term, which we expand below and put in matrix form. Let TRH×DT\in\mathbb{R}^{|\mathcal{H}|\times|\mathcal{D}|} be the matrix that represents the transition probabilities starting from a delimiter state: p(hteststarthdelim)p(h^{\text{start}}_{\text{test}}|h^{\text{delim}}) for hteststartHh^{\text{start}}_{\text{test}}\in\mathcal{H} and hdelimDh^{\text{delim}}\in\mathcal{D}. As a result,

where hndelimh^{\text{delim}}_{n} is the delimiter hidden state before hteststarth^{\text{start}}_{\text{test}}.

Let WRY×HW\in\mathbb{R}^{|\mathcal{Y}|\times|\mathcal{H}|} be the matrix that represents the probabilities p(yxtest,hteststart,θ)p(xtesthteststart,θ)p(y|x_{\text{test}},h^{\text{start}}_{\text{test}},{\theta^{*}})p(x_{\text{test}}|h^{\text{start}}_{\text{test}},{\theta^{*}}) for all the possible yYy\in\mathcal{Y} and hteststartHh^{\text{start}}_{\text{test}}\in\mathcal{H}. Overall, we can write

where uRHu\in\mathbb{R}^{|\mathcal{H}|} is the vector of probabilities that corresponds to the prompt start distribution ppromptp_{\text{prompt}}.

Bounding the difference between the two predictive distributions,

Using Assumption 3, we can further bound this by Δ/2\Delta/2:

Since the probability of any output does not change by more than Δ/2\Delta/2 and the margin between the most likely label and the second most likely label is Δ\Delta, the argmax’s are the same, showing the result. ∎
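The margin argument in this last step can be sanity-checked numerically. The following minimal sketch uses a made-up probability vector `p` and a made-up perturbation scheme (both are illustrative assumptions, not quantities from the paper) and confirms that moving every output probability by strictly less than $\Delta/2$ leaves the argmax unchanged.

```python
import random

def argmax(p):
    """Index of the largest entry of p."""
    return max(range(len(p)), key=lambda i: p[i])

def margin(p):
    """Gap between the largest and second-largest entries of p."""
    top, second = sorted(p, reverse=True)[:2]
    return top - second

random.seed(0)
p = [0.5, 0.3, 0.2]      # hypothetical predictive distribution
delta = margin(p)        # margin between the top two labels

unchanged = True
for _ in range(1000):
    # Perturb each entry by strictly less than delta / 2 in absolute value.
    # The perturbed vector need not be normalized for the argmax comparison.
    q = [pi + random.uniform(-0.49 * delta, 0.49 * delta) for pi in p]
    unchanged = unchanged and (argmax(q) == argmax(p))
```

The top entry can drop by at most just under $\Delta/2$ and any other entry can rise by at most just under $\Delta/2$, so the ordering of the top label is preserved in every trial.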

Appendix D Proof of Theorem 1

We analyze the most likely prediction over the pretraining distribution conditioned on the prompt arg maxyp(ySn,xtest)\operatorname*{arg\,max}_{y}p(y|S_{n},x_{\text{test}}).

we will show that under distinguishability for all θθ\theta\neq{\theta^{*}}, rn(θ)r_{n}(\theta) converges to a negative constant such that

for θθ\theta\neq{\theta^{*}}, whereas this ratio is always 1 for θ=θ\theta={\theta^{*}}. This will then “select” the desired prompt concept through marginalization.

Supposing that Equation 53 holds, we show that the theorem statement holds. Let

and let ϵ<(Δ/2Δ)p(θ)\epsilon<(\Delta/2-\Delta^{\prime})p({\theta^{*}}). Then for nn large enough (due to Equation 53),

\begin{align}
&=\sum_{h^{\text{start}}_{\text{test}}\in\mathcal{H}}p(y|x_{\text{test}},h^{\text{start}}_{\text{test}},\theta^{*})\,p(h^{\text{start}}_{\text{test}}|S_{n},x_{\text{test}},\theta^{*})\,p(\theta^{*})+\int_{\theta\neq\theta^{*}}\epsilon_{\theta}(y)\,p(\theta)\,d\theta \tag{56}\\
&\propto\sum_{h^{\text{start}}_{\text{test}}\in\mathcal{H}}p(y|x_{\text{test}},h^{\text{start}}_{\text{test}},\theta^{*})\,p(h^{\text{start}}_{\text{test}}|S_{n},x_{\text{test}},\theta^{*})+\frac{1}{p(\theta^{*})}\int_{\theta\neq\theta^{*}}\epsilon_{\theta}(y)\,p(\theta)\,d\theta \tag{57}
\end{align}
where $\epsilon_{\theta}(y)\leq\epsilon/2$ for all $y\in\mathcal{Y}$.

By Lemma 1, the argmax of the first term of Equation 57 is the same as arg maxypprompt(yxtest)\operatorname*{arg\,max}_{y}p_{\text{prompt}}(y|x_{\text{test}}), where the margin between the most likely label and the second most likely is at least Δ/2Δ\Delta/2-\Delta^{\prime}. Since

for all yYy\in\mathcal{Y}, the argmax of Equation 57 is also the same as arg maxpprompt(yxtest)\operatorname*{arg\,max}p_{\text{prompt}}(y|x_{\text{test}}).

Now it remains to show that rn(θ)r_{n}(\theta) converges to a negative constant for θθ\theta\neq{\theta^{*}}. Let Oiex=[oi1delim,Oi]O^{\text{ex}}_{i}=[o^{\text{delim}}_{i-1},O_{i}] be the ii-th observation segment and the previous delimiter together for i>1i>1 and define O1ex=O1O^{\text{ex}}_{1}=O_{1}. Expanding the numerator of the ratio in rn(θ)r_{n}(\theta), we have

Note that in the last line, the inner sum is over the set of delimiter states D\mathcal{D} by using the assumption that observing a delimiter odelimo^{\text{delim}} implies that the corresponding hidden state hdelimh^{\text{delim}} must be in D\mathcal{D}. We also see that hndelimp(hndelimO1:nex,θ)=1\sum_{h^{\text{delim}}_{n}}p(h^{\text{delim}}_{n}|O^{\text{ex}}_{1:n},\theta)=1.

We restrict our attention to θ\theta where p(Sn,xtestθ)>0p(S_{n},x_{\text{test}}|\theta)>0, since otherwise θ\theta does not affect the prediction. Expanding rn(θ)r_{n}(\theta), we have the following upper bound:

\begin{align}
&=\frac{1}{n}\bigg(\log\frac{\sum_{h^{\text{start}}_{\text{test}}}p(x_{\text{test}}|h^{\text{start}}_{\text{test}},\theta)\,p(h^{\text{start}}_{\text{test}}\mid S_{n},\theta)}{\sum_{h^{\text{start}}_{\text{test}}}p(x_{\text{test}}|h^{\text{start}}_{\text{test}},\theta^{*})\,p(h^{\text{start}}_{\text{test}}\mid S_{n},\theta^{*})}+\sum_{i=1}^{n}\log\frac{\sum_{h^{\text{delim}}_{i-1}\in\mathcal{D}}p(O_{i}|h^{\text{delim}}_{i-1},\theta)\,p(h^{\text{delim}}_{i-1}|O^{\text{ex}}_{1:i-1},\theta)}{\sum_{h^{\text{delim}}_{i-1}\in\mathcal{D}}p(O_{i}|h^{\text{delim}}_{i-1},\theta^{*})\,p(h^{\text{delim}}_{i-1}|O^{\text{ex}}_{1:i-1},\theta^{*})}\bigg) \tag{67}\\
&\leq\frac{1}{n}\bigg(\log\frac{\sum_{h^{\text{start}}_{\text{test}}}1\cdot p(h^{\text{start}}_{\text{test}}\mid S_{n},\theta)}{\sum_{h^{\text{start}}_{\text{test}}}c_{7}\cdot p(h^{\text{start}}_{\text{test}}\mid S_{n},\theta^{*})}+n(\log(c_{2})-\log(c_{1}))+\sum_{i=1}^{n}\log\frac{\sum_{h^{\text{delim}}_{i-1}\in\mathcal{D}}p(O_{i}|h^{\text{delim}}_{i-1},\theta)}{\sum_{h^{\text{delim}}_{i-1}\in\mathcal{D}}p(O_{i}|h^{\text{delim}}_{i-1},\theta^{*})}\bigg) \tag{68}\\
&=\frac{1}{n}\bigg(-\log(c_{7})+n(\log(c_{2})-\log(c_{1}))+\sum_{i=1}^{n}\log\frac{\sum_{h^{\text{delim}}_{i-1}\in\mathcal{D}}p(O_{i}|h^{\text{delim}}_{i-1},\theta)}{\sum_{h^{\text{delim}}_{i-1}\in\mathcal{D}}p(O_{i}|h^{\text{delim}}_{i-1},\theta^{*})}\bigg) \tag{69}
\end{align}
In the above steps, we used both Propositions 1 and 2 in the terms involving $c_{2},c_{1}$ (bounding the probability of $h^{\text{delim}}$ hidden states) and $c_{7}$ (bounding the probability of $x_{\text{test}}$). Note that in the second line, the sum can be restricted to the set of delimiter states $\mathcal{D}$ by using the assumption that observing a delimiter $o^{\text{delim}}$ implies that the corresponding hidden state $h^{\text{delim}}$ must be in $\mathcal{D}$.

Focusing on the numerator of the ratio term and summing over the start hidden state for the ii-th example,

where the last step applies Bayes’ rule. We can lower and upper bound the following quantity for any θ\theta using Assumption 2:

\begin{align}
&\leq\frac{1}{n}\bigg(-\log(c_{7})+2n(\log(c_{2})-\log(c_{1}))+n(\log(c_{4})-\log(c_{3}))+\sum_{i=1}^{n}\log\frac{\sum_{h^{\text{start}}_{i}}p(O_{i}|h^{\text{start}}_{i},\theta)\,p(h^{\text{start}}_{i}|\theta)}{\sum_{h^{\text{start}}_{i}}p(O_{i}|h^{\text{start}}_{i},\theta^{*})\,p(h^{\text{start}}_{i}|\theta^{*})}\bigg) \tag{77}\\
&=\frac{1}{n}\bigg(-\log(c_{7})+2n(\log(c_{2})-\log(c_{1}))+n(\log(c_{4})-\log(c_{3}))+\sum_{i=1}^{n}\log\frac{p(O_{i}|\theta)}{p(O_{i}|\theta^{*})}\bigg) \tag{78}\\
&\rightarrow_{n\rightarrow\infty}\mathbb{E}_{O\sim p_{\text{prompt}}}\left[\log\frac{p(O|\theta)}{p(O|\theta^{*})}\right]+\epsilon^{\theta}_{\text{delim}} \tag{79}
\end{align}
where we set

\begin{align}
\epsilon^{\theta}_{\text{delim}}=2(\log(c_{2})-\log(c_{1}))+\log(c_{4})-\log(c_{3}). \tag{80}
\end{align}
Next, we convert the expectation in the bound into a KL divergence. We have

\begin{align}
&=\mathbb{E}_{O\sim p_{\text{prompt}}}\left[\log\frac{p(O|\theta)}{p_{\text{prompt}}(O)}+\log\frac{p_{\text{prompt}}(O)}{p(O|\theta^{*})}\right] \tag{81}\\
&=KL(p_{\text{prompt}}\|p(\cdot|\theta^{*}))-KL(p_{\text{prompt}}\|p(\cdot|\theta)). \tag{82}
\end{align}
We will upper bound the first KL term:

Expanding the numerator and denominator of the ratio inside, we have

which differ in only the hidden start distribution. Using Assumption 5, we have that p(hθ)c8p(h|{\theta^{*}})\geq c_{8} for any hHh\in\mathcal{H}, which implies that

Finally, this implies that the KL term is bounded as

This term is non-negative since c81c_{8}\leq 1.
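The rewriting of the expected log-likelihood ratio as a difference of two KL divergences (Equations 81 and 82) can be verified on a small discrete example. The sketch below is only an illustration: the three distributions are made-up numbers standing in for $p_{\text{prompt}}$, $p(\cdot|\theta)$, and $p(\cdot|\theta^{*})$ over three possible examples.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical distributions over three possible examples O.
p_prompt     = [0.6, 0.3, 0.1]
p_theta      = [0.2, 0.5, 0.3]   # p(O | theta) for some other concept
p_theta_star = [0.5, 0.3, 0.2]   # p(O | theta*), the prompt concept

# Left-hand side: E_{O ~ p_prompt}[ log p(O|theta) / p(O|theta*) ].
expected_log_ratio = sum(
    pp * math.log(pt / ps)
    for pp, pt, ps in zip(p_prompt, p_theta, p_theta_star)
)

# Right-hand side: KL(p_prompt || p(.|theta*)) - KL(p_prompt || p(.|theta)).
kl_difference = kl(p_prompt, p_theta_star) - kl(p_prompt, p_theta)
```

In this toy case the two quantities agree up to floating-point error, and the value is negative because $p_{\text{prompt}}$ is closer in KL to the stand-in for $\theta^{*}$ than to the other concept.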

Aiming to decompose the second KL term into a sum over the kk tokens, we write pθj(o)=p(O[j]=oO[1:j1],θ)p^{j}_{\theta}(o)=p(O[j]=o|O[1:j-1],\theta) and ppromptj(o)=pprompt(O[j]=oO[1:j1])p_{\text{prompt}}^{j}(o)=p_{\text{prompt}}(O[j]=o|O[1:j-1]). We have

\begin{align}
&<-\sum_{j=1}^{k}\mathbb{E}_{O[1:j-1]\sim p_{\text{prompt}}}[KL(p_{\text{prompt}}^{j}\|p^{j}_{\theta})]+\epsilon^{\theta}_{\text{start}}+\epsilon^{\theta}_{\text{delim}} \tag{93}
\end{align}
The second term (set $\epsilon^{\theta}_{\text{start}}=\log(\frac{1}{c_{8}})$) is an error term that depends on how different the starting prompt distribution $p_{\text{prompt}}$ (which is part of $p_{\text{prompt}}$) is from the pretraining distribution. The third term is an error term that comes from the delimiter transitions. The bound is negative when the sum of KL terms is larger in magnitude than the error terms. Note that as $k$ becomes larger, the number of observations of $\theta^{*}$ “overpowers” the distracting transitions in the prompt distribution. This condition is equivalent to the distinguishability condition (Condition 1).

By assumption, for θθ\theta\neq{\theta^{*}} the Condition 1 holds, and thus

since rn(θ)r_{n}(\theta) has a negative, constant limit. Note that exp(nrn(θ))=1\exp(n\cdot r_{n}({\theta^{*}}))=1 for θ{\theta^{*}}.
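To illustrate how a negative limit of $r_{n}(\theta)$ "selects" the prompt concept, the sketch below uses a deliberately simplified setting rather than the HMM of the paper: each hypothetical concept is a categorical distribution over three tokens, and the prompt is an i.i.d. sample from `theta_star`. The concept names, probabilities, and sample size are all assumptions made for illustration.

```python
import math
import random

random.seed(0)

# Hypothetical concepts: categorical distributions over 3 tokens.
concepts = {
    "theta_star": [0.6, 0.3, 0.1],
    "theta_a":    [0.3, 0.4, 0.3],
    "theta_b":    [0.1, 0.3, 0.6],
}

def r_n(sample, theta, theta_star):
    """(1/n) * log [ p(sample | theta) / p(sample | theta_star) ]."""
    n = len(sample)
    return sum(math.log(theta[o] / theta_star[o]) for o in sample) / n

# Draw n i.i.d. tokens from the prompt concept.
n = 2000
sample = random.choices(range(3), weights=concepts["theta_star"], k=n)

r = {name: r_n(sample, th, concepts["theta_star"]) for name, th in concepts.items()}

# Posterior over concepts under a uniform prior, proportional to exp(n * r_n).
weights = {name: math.exp(n * value) for name, value in r.items()}
total = sum(weights.values())
posterior = {name: w / total for name, w in weights.items()}
```

Here `r["theta_star"]` is exactly zero while the other two values are negative, so `exp(n * r_n)` vanishes for the wrong concepts and the posterior mass concentrates on `theta_star`.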

Appendix E Non-distinguishable case

When Condition 1 is unsatisfied, Equation 14 gives an upper bound on the sum of KL divergences for the next token distributions given different-length histories. In contrast, the in-context task only measures the accuracy of the last ($k$-th) token. The main challenge is to relate the different-length histories to each other to give a more precise bound for the error on the in-context task (last token).

Before addressing this challenge, we give the following lemma, which leverages the result of Ávila Pires and Szepesvári (2016), Steinwart (2007) to relate a bound on the KL divergence to 0-1 loss.

Let $\mathcal{B}$ be the set of $\theta$ that do not satisfy Condition 1. Assume that $KL(p_{\text{prompt}}(y_{\text{test}}|x_{\text{test}})\|p(y_{\text{test}}|x_{\text{test}},\theta))$ is bounded above for all $\theta$ and that $\theta^{*}$ minimizes the multiclass logistic risk $L_{\text{CE}}(\theta)=-\mathbb{E}_{x_{\text{test}}\sim p_{\text{prompt}}}[p_{\text{prompt}}(y_{\text{test}}|x_{\text{test}})\log p(y_{\text{test}}|x_{\text{test}},\theta)]$. If

\begin{align}
\lim_{n\rightarrow\infty}L_{\text{0-1}}(f_{n})\leq\inf_{f}L_{\text{0-1}}(f)+g^{-1}\left(\sup_{\theta\in\mathcal{B}}\epsilon_{\theta}\right) \tag{96}
\end{align}
where

\begin{align}
g(\delta)=\frac{1}{2}((1-\delta)\log(1-\delta)+(1+\delta)\log(1+\delta)) \tag{97}
\end{align}
is a calibration function for the multiclass logistic loss for $\delta\in[0,1]$.

First, we note that we can study the 0-1 risk of the limiting predictor:

where in the last step we use that, since the output space of $f_{n}$ is discrete and the probabilities that the in-context predictor takes an argmax over converge, for $N$ large enough, $f_{N}(x_{\text{test}})=\lim_{n\rightarrow\infty}f_{n}(x_{\text{test}})$.

Note that for every input xtestx_{\text{test}}, the limiting in-context learning predictor outputs the argmax of a predictive distribution which can be a mixture of predictive distributions over B\mathcal{B}:

for some distribution qq over B\mathcal{B}. The KL divergence between this mixture and the prompt concept is bounded by the KL divergence of any one θB\theta\in\mathcal{B}, due to the convexity of KL:

where we can exchange the order of expectations since the KL is bounded (dominated convergence).

From the KL bound on $KL(p_{\text{prompt}}(y_{\text{test}}|x_{\text{test}})\|p(y_{\text{test}}|x_{\text{test}},\theta))$, we thus have

where LCE(θ)=Extestpprompt[pprompt(ytestxtest)logp(ytestxtest,θ)]L_{\text{CE}}(\theta)=-\mathbb{E}_{x_{\text{test}}\sim p_{\text{prompt}}}[p_{\text{prompt}}(y_{\text{test}}|x_{\text{test}})\log p(y_{\text{test}}|x_{\text{test}},\theta)] is the multiclass logistic risk, and LCE(θ)L_{\text{CE}}({\theta^{*}}) is the optimal risk over θΘ\theta\in\Theta by assumption. Applying Theorem 2.2 and 5.11 of Ávila Pires and Szepesvári (2016), gg is a calibration function for the multiclass logistic loss, and allows us to convert the surrogate risk bound to a bound on the 0-1 loss, giving the result. Note that we have zero approximation error here, since θΘ{\theta^{*}}\in\Theta. ∎

Note that g1g^{-1} is roughly linear in ϵ\epsilon for ϵ\epsilon smaller than 0.7, where the bound is non-vacuous.
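For readers who want to evaluate the bound, the following minimal sketch computes the calibration function $g$ from Equation 97 and inverts it numerically. It assumes the domain $\delta\in[0,1]$; the function names and the bisection tolerance are choices made here for illustration, not taken from the paper.

```python
import math

def g(delta):
    """Calibration function g(delta) for 0 <= delta <= 1 (Equation 97)."""
    def xlogx(x):
        # Use the convention 0 * log(0) = 0 at the boundary delta = 1.
        return x * math.log(x) if x > 0 else 0.0
    return 0.5 * (xlogx(1 - delta) + xlogx(1 + delta))

def g_inverse(eps, tol=1e-12):
    """Invert g on [0, 1] by bisection; g is increasing on this interval."""
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if g(mid) < eps:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2

g_at_one = g(1.0)            # largest value g attains on [0, 1]
delta_03 = g_inverse(0.3)    # excess 0-1 risk bound for a surrogate gap of 0.3
```

Since $g(1)=\log 2\approx 0.693$, the inverse is only defined for surrogate gaps below roughly 0.7, which matches the range in which the bound is stated to be non-vacuous.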

By the continuity assumption, we have for any θ\theta in B\mathcal{B} that

\begin{align}
&\geq\frac{1}{2}\sum_{j=2}^{k}(\theta-\theta^{*})^{\top}I_{j,\theta^{*}}(\theta-\theta^{*})+(k-1)O(\|\theta-\theta^{*}\|^{3}) \tag{107}\\
&\geq\frac{1}{2}(k-1)\lambda_{\text{min}}(I_{j,\theta^{*}})\|\theta-\theta^{*}\|^{2} \tag{108}\\
\implies\|\theta-\theta^{*}\|^{2}&\leq\frac{\epsilon^{\theta}_{\text{start}}+\epsilon^{\theta}_{\text{delim}}}{\frac{1}{2}(k-1)(\min_{j}\lambda_{\text{min}}(I_{j,\theta^{*}}))}. \tag{109}
\end{align}
We use this to bound the last KL term by plugging it in below:

\begin{align}
&=\frac{1}{2}(\theta-\theta^{*})^{\top}I_{k,\theta^{*}}(\theta-\theta^{*})+O(\|\theta-\theta^{*}\|^{3}) \tag{110}\\
&\leq\frac{1}{2}(\max_{j}\lambda_{\text{max}}(I_{j,\theta^{*}}))\|\theta-\theta^{*}\|^{2}+O(\|\theta-\theta^{*}\|^{2}) \tag{111}\\
&\leq\frac{(\epsilon^{\theta}_{\text{start}}+\epsilon^{\theta}_{\text{delim}})(\max_{j}\lambda_{\text{max}}(I_{j,\theta^{*}})+O(1))}{(k-1)\min_{j}\lambda_{\text{min}}(I_{j,\theta^{*}})}. \tag{112}
\end{align}
Rearranging and noting that $KL_{k}(\theta^{*}\|\theta)=\mathbb{E}_{x_{\text{test}}\sim p_{\text{prompt}}}[KL(p_{\text{prompt}}(y_{\text{test}}|x_{\text{test}})\|p(y_{\text{test}}|x_{\text{test}},\theta))]$, we have

Plugging into Lemma 2 gives the result. ∎
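The proof relies on the second-order expansion of the KL divergence around $\theta^{*}$, whose quadratic coefficient is the Fisher information. As a sanity check of that expansion only, the sketch below uses a one-parameter Bernoulli family (an assumption chosen for simplicity; it is not the HMM family of the paper) and compares the exact KL divergence with $\frac{1}{2}I_{\theta^{*}}(\theta-\theta^{*})^{2}$.

```python
import math

def kl_bernoulli(a, b):
    """KL(Bernoulli(a) || Bernoulli(b)) for 0 < a, b < 1."""
    return a * math.log(a / b) + (1 - a) * math.log((1 - a) / (1 - b))

def fisher_bernoulli(a):
    """Fisher information of a Bernoulli with success probability a."""
    return 1.0 / (a * (1 - a))

theta_star = 0.3
rows = []
for step in (0.1, 0.01, 0.001):
    theta = theta_star + step
    exact = kl_bernoulli(theta_star, theta)
    quadratic = 0.5 * fisher_bernoulli(theta_star) * step ** 2
    # The relative gap shrinks linearly in the step, consistent with a
    # cubic remainder term.
    rows.append((step, exact, quadratic, abs(exact - quadratic) / exact))
```

The relative error in `rows` decreases by about a factor of ten each time the step shrinks by ten, which is the behavior expected from an $O(\|\theta-\theta^{*}\|^{3})$ remainder.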

E.2 Proof of Theorem 3

Note that Condition 1 ensures that the sum of KL divergences between positions within a $k$-length input is bounded. This means that we have a bound not only on the last-position KL divergence, but also on all the intermediate tokens. Intuitively, the random-length test example allows the in-context predictor to “take credit” for fitting the intermediate tokens. The proof is immediate given the KL bound and Lemma 2, given that the length of $x_{\text{test}}$ is uniformly random between 2 and $k$.

Let $\mathcal{B}$ be the set of $\theta$ that do not satisfy Condition 1. We have for any $\theta$ in $\mathcal{B}$ that

\begin{align}
&\leq\frac{\sup_{\theta}(\epsilon^{\theta}_{\text{start}}+\epsilon^{\theta}_{\text{delim}})}{k-1} \tag{116}
\end{align}
by Theorem 1 and Condition 1. Plugging this into Lemma 2 gives the result. ∎

Appendix F Experimental details

We consider a pretraining distribution from a mixture of HMMs with an interpretable hidden state structure and emission distribution. The HMM hidden state $h_{t}=[s_{t},v_{t}]$ at time $t$ is composed of an entity $v_{t}\in\{1,\dots,|\mathcal{V}|\}$ (e.g., Einstein) and a property $s_{t}\in\{1,\dots,|\mathcal{S}|\}$ (e.g., nationality, first name, last name, other grammatical tokens). We model the entities and properties as independent Markov chains (i.e., a factorial HMM (Ghahramani and Jordan, 1997)), while the emissions depend on both. In pretraining documents, we expect that the entities (e.g., Einstein) change slowly over time while the properties of the entity (e.g., their nationality) change quickly with some pattern to generate natural sentences. We implement this by ensuring that the probability of transitioning to the same entity index in the next step is at least 0.9. The emission distribution depends on a memory matrix $M$ with $|\mathcal{V}|$ rows and $|\mathcal{S}|$ columns (Figure 9). At step $t$, we use the entity $v_{t}$ and property $s_{t}$ to index into the memory matrix. In particular, the observed tokens are deterministic with $p(o_{t}|h_{t})=1$ if $o_{t}=M[v_{t},s_{t}]$. This construction satisfies the structure on delimiter states (Assumption 1). We ensure that all the transitions have nonzero probability and use a uniform prior over concepts, satisfying Assumptions 2 and 5.

The concept parameter is the property transition matrix, while the entity transition matrix is fixed for all concepts. The prompt start distribution and the concept together determine the in-context task. We define a uniform mixture of HMMs over a family Θ\Theta of 5 concepts to generate 1000 documents with \sim10 million tokens total.

The GINC dataset is generated from a mixture of HMMs. These HMMs output tokens from a vocabulary of size in $\{50,100,150\}$. The vocabulary contains a special delimiter token (backslash; see Figure 8), designated to be index 1. The vocabulary is generated as combinations of letters starting from a to z, then aa to az, and so on. All sequences are tokenized by splitting on whitespace.
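The letter-combination scheme for the vocabulary can be sketched as follows. This is an illustrative reconstruction, not the released generation code: the function name `letter_tokens` is hypothetical, and how the delimiter token is placed in the vocabulary and whether it counts toward the vocabulary size are left out because the text does not specify them.

```python
from itertools import count, product
from string import ascii_lowercase

def letter_tokens(n):
    """Return the first n tokens in the order a..z, aa..az, ba..bz, and so on."""
    tokens = []
    for length in count(1):
        for combo in product(ascii_lowercase, repeat=length):
            tokens.append("".join(combo))
            if len(tokens) == n:
                return tokens

vocab = letter_tokens(50)

# Whitespace tokenization round-trips a sequence of such tokens.
document = " ".join(vocab[:5])
round_trip = document.split()
```

With 50 tokens the list runs from `a` through `z` and then `aa` through `ax`.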

The shared memory matrix has 10 entities and 10 properties, totaling 100 entries (corresponding to 100 hidden states). The first column of the memory matrix is fixed to be the delimiter token, while each remaining entry of the shared memory matrix is populated with a token sampled uniformly from the vocabulary.

We generate 5 property transition matrices, one for each component of the HMM mixture. We generate each transition matrix via a convex combination of 100 random permutation matrices. The weights of the convex combination are randomly generated as

\begin{align}
\text{softmax}((u-0.5)/t)
\end{align}

where $u\in\mathbb{R}^{100}$ has uniform random entries in $[0,1]$ and $t$ is a temperature parameter, set to 0.1.
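A transition matrix built as a convex combination of permutation matrices can be sketched in a few lines. The code below is an assumption-laden illustration: it takes the weights to be $\text{softmax}((u-0.5)/t)$ with $u$ uniform on $[0,1]$ (the same form the text gives for the start distribution), and the function names and the use of Python's `random` module are choices made here, not the authors' implementation.

```python
import math
import random

def softmax(xs):
    """Numerically stable softmax of a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def random_transition_matrix(n_states, n_perms=100, temperature=0.1, rng=random):
    """Convex combination of n_perms random permutation matrices."""
    u = [rng.random() for _ in range(n_perms)]
    weights = softmax([(ui - 0.5) / temperature for ui in u])
    T = [[0.0] * n_states for _ in range(n_states)]
    for w in weights:
        perm = list(range(n_states))
        rng.shuffle(perm)
        for i, j in enumerate(perm):
            T[i][j] += w      # add w times the permutation matrix
    return T

# Example: one property transition matrix over 10 properties.
T = random_transition_matrix(10, rng=random.Random(0))
```

Because every permutation matrix has exactly one 1 in each row and column, the result is doubly stochastic; the low temperature makes a few permutations dominate the mixture.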

The entity transition matrix is shared between all the HMMs that constitute the mixture. The entity transition matrix is generated in the same way as the property transition matrices, except with one additional step. Letting $T$ be a transition matrix sampled in the same way as a property transition matrix,

In pretraining documents, we expect that the entities (e.g., Einstein) change slowly over time while the properties of the entity (e.g., their occupation) change quickly with some pattern to generate natural sentences. We implement this by ensuring that the probability of transitioning to the same entity index in the next step is at least 0.9. The final entity transition matrix is then $0.1T+0.9I$ where $I$ is the identity matrix. Although we add the diagonal component for added realism, we also consider not adding this component. Figure 10 shows in-context learning curves for a small (4 layer) Transformer trained on data that does not add the diagonal component (we check this for vocabulary sizes 50, 100, and 150). In-context learning still works in this case, although not as well for the vocabulary size of 50.
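The mixing step with the identity can be checked directly. In the sketch below, the helper name `add_self_loops` and the uniform stand-in for the sampled matrix $T$ are assumptions for illustration; any row-stochastic $T$ would behave the same way.

```python
def add_self_loops(T, stay=0.9):
    """Return (1 - stay) * T + stay * I for a square row-stochastic matrix T."""
    n = len(T)
    return [
        [(1 - stay) * T[i][j] + (stay if i == j else 0.0) for j in range(n)]
        for i in range(n)
    ]

# Hypothetical stand-in for a sampled transition matrix over 10 entities.
n_entities = 10
T = [[1.0 / n_entities] * n_entities for _ in range(n_entities)]
entity_T = add_self_loops(T)
```

Every diagonal entry of the result is at least 0.9 and each row still sums to 1, which is the slow-changing entity behavior described above.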

The starting distribution for the hidden states in all HMMs in the mixture is close to uniform. We generate the start distribution as $\text{softmax}((u-0.5)/t)$ for a random vector $u$ with entries drawn uniformly from $[0,1]$ and temperature $t=10$. In the pretraining documents, we only sample from the start distribution at the beginning of the document.
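A quick calculation shows why a temperature of 10 gives a nearly uniform start distribution. The sketch assumes the entries of $u$ are uniform on $[0,1]$ and uses 100 hidden states (10 entities times 10 properties); the variable names are illustrative.

```python
import math
import random

def softmax(xs):
    """Numerically stable softmax of a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

rng = random.Random(0)
n_hidden = 100     # 10 entities x 10 properties
t = 10.0           # temperature for the start distribution

u = [rng.random() for _ in range(n_hidden)]
start = softmax([(ui - 0.5) / t for ui in u])

# The logits lie in an interval of width at most 0.1, so the largest and
# smallest probabilities differ by a factor of at most exp(0.1).
ratio = max(start) / min(start)
```

Since $\exp(0.1)\approx 1.105$, no hidden state is more than about 10.5% more likely than any other at the start of a document.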

We generate prompts with 0 to 64 training examples and example lengths k{3,5,8,10}k\in\{3,5,8,10\} (2500 prompts for each setting). The target token ytesty_{\text{test}} is taken to be the most likely output arg maxypprompt(yxtest)\operatorname*{arg\,max}_{y}p_{\text{prompt}}(y|x_{\text{test}}) instead of sampling so that the intrinsic error is 0.

To generate the prompts, we first sample a concept θ\theta uniformly at random from Θ\Theta (well-specification, Assumption 4), then use it to generate all the prompt examples. The prompt start distribution is chosen to be uniform over entities but with a fixed starting property that is chosen randomly for each prompt, for consistency in the task. This may not satisfy Assumption 3, but we found this to still work empirically and is simpler. Given the starting property, we sample kk tokens from the HMM defined by the concept θ\theta. Finally, we append the delimiter token for the example. We repeat this process for each example in the prompt, concatenating all examples. The label is generated as

under the prompt concept θ{\theta^{*}}. This differs from the theory, which samples ytesty_{\text{test}} instead of taking it to be the most likely token. However, there can be a large amount of intrinsic error that sampling introduces. We define the label this way in the simulations to remove the intrinsic error from sampling.

In the example in Figure 8 (right), the starting property is fixed to be 5 (for example). The first token (l) is generated by sampling a random entity index (3), and indexing into the memory matrix returns l. Running the hidden state chain of the HMM forward gives the next pair of property and entity. Since the entity Markov chain changes slowly, the entity is still 3 in the next step – however, the property has changed to 4, and indexing into the memory matrix outputs the next token (aw). Following this same process to generate the third token (the output for the first example), we finish generating one example. To end the example, we append a delimiter (backslash). We repeat this example generation process for all the examples, except for the test example at the end, where we do not generate the last token. We condition the HMM on the generated prompt to compute the posterior distribution over the next token pprompt(yxtest)p_{\text{prompt}}(y|x_{\text{test}}). We take the argmax of this distribution to be the ground truth label.
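The walk-through above can be sketched as a small prompt generator. Everything concrete here is an assumption for illustration: the toy memory matrix, the transition matrices, and the function names are made up, the delimiter column of the memory matrix is omitted, and the final step of computing the label from the HMM posterior is not implemented.

```python
import random

def sample_next(row, rng):
    """Sample the next state index from one row of a transition matrix."""
    return rng.choices(range(len(row)), weights=row, k=1)[0]

def generate_example(memory, prop_T, ent_T, start_prop, length, rng):
    """Emit `length` tokens, starting from a uniform entity and a fixed property."""
    entity = rng.randrange(len(memory))      # uniform over entities
    prop = start_prop                        # fixed starting property
    tokens = []
    for _ in range(length):
        tokens.append(memory[entity][prop])  # deterministic emission M[v, s]
        entity = sample_next(ent_T[entity], rng)
        prop = sample_next(prop_T[prop], rng)
    return tokens

def generate_prompt(memory, prop_T, ent_T, start_prop, k, n_examples, rng,
                    delimiter="\\"):
    """Concatenate n_examples delimited examples and a test input of k - 1 tokens."""
    prompt = []
    for _ in range(n_examples):
        prompt += generate_example(memory, prop_T, ent_T, start_prop, k, rng)
        prompt.append(delimiter)
    # The test example stops one token early; the k-th token is the label.
    prompt += generate_example(memory, prop_T, ent_T, start_prop, k - 1, rng)
    return prompt

rng = random.Random(0)
memory = [["a", "b", "c"], ["d", "e", "f"]]   # toy: 2 entities x 3 properties
prop_T = [[0.1, 0.8, 0.1], [0.1, 0.1, 0.8], [0.8, 0.1, 0.1]]
ent_T = [[0.95, 0.05], [0.05, 0.95]]          # entities change slowly
prompt = generate_prompt(memory, prop_T, ent_T, start_prop=0, k=3,
                         n_examples=4, rng=rng)
```

Each training example contributes $k$ tokens plus one delimiter, and the test input contributes $k-1$ tokens, so this toy prompt has $4\cdot 4+2=18$ tokens.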

The dataset contains 1000 training documents and 100 validation documents, where training documents have 10240 tokens and validation documents have 1024 tokens. Each document is generated by first selecting one of the HMMs from the mixture uniformly at random, then generating 10240 tokens from the HMM.

We also generate 2500 in-context prompts for each (example length, number of examples) pair, for example lengths $k\in\{3,5,8,10\}$ and numbers of examples $n$ ranging from 0 to 64. Each prompt is generated using a random HMM in the mixture.

F.2 Transformer details

Our Transformer models are based on the GPT-2 architectures with 4, 12, and 16 layers respectively, with 12 attention heads, 768 dimensional embeddings, residual/embedding/attention dropout set to 0.1, and a context window of 1024. Other than the number of layers, the other parameters are the default settings from the HuggingFace library (Wolf et al., 2019). We train for 5 epochs using the AdamW optimizer (Loshchilov and Hutter, 2019, Kingma and Ba, 2015) with a batch size of 8 and a linear learning rate schedule (with 1000 step warmup) up to a learning rate of 8e-4 for the 4 layer and 12 layer model, while for the 16 layer model we start with a constant learning rate of 8e-4 and reduce by a factor of 0.25 whenever the best validation loss does not improve. We tried both learning rate strategies for all models and take the most consistent. We tuned these models so that the training loss curves between seeds have smaller variability between the runs in terms of the curve shape and when the loss decreases – we found that this is an important indication of stable results. The models took 50 minutes, 2 hours, 3 hours to train respectively. The hardware was mainly Titan Xp GPUs, trained and evaluated using 16-bit precision. All the results are reported with 5 pretraining runs (5 different seeds).

F.3 LSTM details

We train an LSTM language model with embedding size 768, hidden layer size 768, and 6 layers. We use dropout 0.2 and weight decay 1e-5. The optimizer is AdamW starting with a learning rate of 1e-3, then reducing by a factor of 0.25 whenever the best validation loss does not go down. We train for a total of 10 epochs, with gradient clipping at norm 1.0. We use a batch size of 8 and backpropagate through time for 1024 steps (each pretraining data segment is also 1024 tokens). Each model takes roughly 2 hours to train on Titan Xp GPUs.

F.4 Varying the vocabulary size

To do well on the in-context learning task, the model must infer both the prompt concept and the last HMM hidden state. In general, increasing the number of observable symbols makes the in-context task easier by making the inference of the HMM hidden state easier. With more symbols, each hidden state is more likely to output a different symbol, making the inference problem easier. This improvement comes despite the number of output classes in the problem (same as the vocabulary size) increasing. Figures 11, 12, 13, 14 show in-context learning curves for vocabulary sizes 50, 100, and 150, keeping other hyperparameters of the dataset the same.

F.5 Experiment on GPT-3

We conduct an additional experiment which shows that longer examples improve in-context learning in GPT-3 on the LAMBADA (Paperno et al., 2016) completion task.

In this experiment, we define a short version of the LAMBADA test dataset (LAMBADA test-short) which contains only test examples that are 200–300 characters in length. We also define two “training” datasets from which to sample examples for the in-context prompts. The short training dataset (LAMBADA train-short) contains examples from the training set that are 200–300 characters in length, which matches the distribution of test-short. The long training dataset (LAMBADA train-long) contains training examples that are 500–600 characters long. We cut the number of examples in the larger of the two training datasets so that the two training datasets are equally sized (47 examples). For each test example, we sample 5 random training examples (5-shot learning).
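The dataset construction above can be sketched as follows. This is a minimal illustration on synthetic strings, not the preprocessing script used for the experiment: the inclusive length bounds, the random subsampling used to equalize the pools, and all function names are assumptions.

```python
import random


def filter_by_length(examples, min_chars, max_chars):
    """Keep examples whose character length lies in [min_chars, max_chars]."""
    return [ex for ex in examples if min_chars <= len(ex) <= max_chars]


def equalize_sizes(pool_a, pool_b, rng):
    """Randomly subsample the larger pool so both pools have equal size."""
    size = min(len(pool_a), len(pool_b))
    return rng.sample(pool_a, size), rng.sample(pool_b, size)


def sample_shots(pool, num_shots, rng):
    """Draw num_shots distinct in-context examples for one test example."""
    return rng.sample(pool, num_shots)


# Synthetic stand-in for the training set: strings of various lengths.
rng = random.Random(0)
train = ["x" * n for n in (150, 210, 250, 290, 299, 300, 520, 560, 600, 700)]
train_short = filter_by_length(train, 200, 300)
train_long = filter_by_length(train, 500, 600)
train_short, train_long = equalize_sizes(train_short, train_long, rng)
shots = sample_shots(train_long, 2, rng)  # the experiment uses 5 shots
```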

We also consider equalizing the total length of the prompts in two ways. First, we consider duplicating the 5 short examples (if the examples are [1,2,3,4,5], duplicating refers to [1,2,3,4,5,1,2,3,4,5]). This allows for equalizing the total length without increasing the number of distinct examples. As a skyline comparison, we also consider sampling 10 independent short examples, which contain more input-output pairs for the task.
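The duplicated and independent prompt variants can be sketched as below. The separator between examples and the helper name build_prompt are assumptions for illustration; the text does not specify how examples are delimited in the prompt.

```python
def build_prompt(train_examples, test_context, separator="\n\n", repeats=1):
    """Concatenate in-context examples followed by the test context.

    repeats=2 reproduces the duplication baseline: the same examples are
    listed twice in the same order, so the prompt gets longer without
    gaining any new input-output pairs.
    """
    shots = list(train_examples) * repeats
    return separator.join(shots + [test_context])


# Duplicating 5 short examples versus using 10 independent short examples.
five = ["ex1", "ex2", "ex3", "ex4", "ex5"]
ten = five + ["ex6", "ex7", "ex8", "ex9", "ex10"]
duplicated_prompt = build_prompt(five, "test", repeats=2)
independent_prompt = build_prompt(ten, "test")
```

Both prompts contain 10 example slots, but only the second contains 10 distinct input-output pairs.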

Table 1 shows that when evaluating only on LAMBADA test-short, 5-shot in-context learning using LAMBADA train-long improves the test accuracy by almost 1% compared to LAMBADA train-short, despite the long/short distribution mismatch between train and test. This supports intuitions from our theory.

In comparison, simply increasing the total prompt length by duplicating the short examples does not improve the accuracy. Intuitively, the longer examples have additional information that is not directly related to the mapping between the input and output, but can be leveraged to improve in-context learning by helping the model infer the latent concept. Using 5 long examples (as opposed to 5 short examples) closes about 56% of the gap between using 5 short examples and 10 independent short examples, despite not adding additional examples or task-related information.
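The fraction of the gap closed can be written as the ratio below. The symbols are introduced here only for illustration: $a_{\text{5-short}}$, $a_{\text{5-long}}$, and $a_{\text{10-short}}$ denote the test-short accuracies in Table 1 when prompting with 5 short, 5 long, and 10 independent short examples respectively.

```latex
\text{fraction of gap closed}
  = \frac{a_{\text{5-long}} - a_{\text{5-short}}}
         {a_{\text{10-short}} - a_{\text{5-short}}}
  \approx 0.56
```

A value of 0 would mean that long examples give no benefit over short ones, while a value of 1 would mean that 5 long examples match 10 independent short examples.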