What Can Transformers Learn In-Context? A Case Study of Simple Function Classes

Shivam Garg, Dimitris Tsipras, Percy Liang, Gregory Valiant

Introduction

Large language models such as GPT-3 [Brown et al., 2020] are able to perform in-context learning: given a prompt containing examples from a task (input-output pairs) and a new query input, the language model can generate the corresponding output. For example, these models are able to produce English translations of French words after being prompted on a few such translations.

This capability is quite intriguing as it allows models to adapt to a wide range of downstream tasks on-the-fly, i.e., without the need to perform any parameter updates after the model is trained [Brown et al., 2020, Lieber et al., 2021, Rae et al., 2021, Black et al., 2022]. However, it is unclear to what extent these models have developed the ability to learn new tasks from in-context examples alone as opposed to simply indexing into a vast set of known tasks from the training data (e.g., see Min et al.). The term “in-context learning” has also been used to refer to a more general notion of learning from a prompt [Olsson et al., 2022]. In this work, we focus on the standard notion, which refers to learning a task/function given in-context examples [Brown et al., 2020].

To make progress towards understanding in-context learning, we consider the well-defined problem of learning a function class from in-context examples. That is, we say that a model can in-context learn a function class $\mathcal{F}$ if, for “most” functions $f\in\mathcal{F}$, the model can approximate $f(x_{\text{query}})$ for a new query input $x_{\text{query}}$ by conditioning on a prompt sequence $(x_{1},f(x_{1}),\ldots,x_{k},f(x_{k}),x_{\text{query}})$ containing in-context examples and the query input.

Formally, let $D_{\mathcal{X}}$ be a distribution over inputs and $D_{\mathcal{F}}$ be a distribution over functions in $\mathcal{F}$. A prompt $P$ is a sequence $(x_{1},f(x_{1}),\ldots,x_{k},f(x_{k}),x_{\text{query}})$ where inputs (i.e., $x_{i}$ and $x_{\text{query}}$) are drawn i.i.d. from $D_{\mathcal{X}}$ and $f$ is drawn from $D_{\mathcal{F}}$. We say that a model $M$ can in-context learn the function class $\mathcal{F}$ up to $\epsilon$, with respect to $(D_{\mathcal{F}},D_{\mathcal{X}})$, if it can predict $f(x_{\text{query}})$ with an average error

$$\mathbb{E}_{P}\left[\ell\left(M(P),f(x_{\text{query}})\right)\right]\leq\epsilon, \tag{1}$$

where $\ell(\cdot,\cdot)$ is an appropriate loss function, such as the squared error.

Within this framework, we can now concretely ask:

Can we train a model to in-context learn a certain function class?

Note that, here, being able to in-context learn a function class is a property of model $M$ alone, independent of how it was trained. Training such a model can be viewed as an instance of meta-learning [Schmidhuber, 1987, Naik and Mammone, 1992, Thrun and Pratt, 2012], a general paradigm for learning a model or method that can learn from data.

We empirically study this question, focusing on Transformer models [Vaswani et al., 2017, Radford et al., 2018], the architecture behind recent large language models, trained from scratch to in-context learn a range of simple, well-defined function classes (e.g., linear functions). Specifically, we sample prompts containing in-context examples (input-output pairs) generated using functions in a given class and train models to predict the function value at the corresponding query inputs (see illustration in Figure 1). Our findings are as follows.

We show empirically that we can train a standard Transformer from scratch to in-context learn the class of linear functions, with respect to the input distribution $D_{\mathcal{X}}$ being an isotropic Gaussian in 20 dimensions, and $D_{\mathcal{F}}$ being the distribution over linear functions with weight vectors drawn from an isotropic Gaussian (the model was trained on prompts generated from the same distributions $D_{\mathcal{X}}$ and $D_{\mathcal{F}}$). Specifically, the trained model achieves error comparable to the optimal least squares estimator, suggesting that it encodes an effective learning algorithm, at least for the distribution used to generate the training prompts.

Generalization to out-of-distribution prompts.

To understand the extent to which the trained model encodes an algorithm that works beyond the training distribution, we consider in-context learning under two types of distribution shifts: (a) a shift between the prompts encountered during training and inference (e.g., training on prompts without any noise in the in-context example outputs but testing with noisy outputs), (b) a shift between the in-context examples and the query input during inference (e.g., in-context examples lie in one orthant and the query input lies in another). We find that the performance of our model is quite robust to such shifts, indicating that it has learned to perform linear regression with some generality.

More complex function classes.

We also consider the function classes of 3-sparse linear functions, two-layer ReLU neural networks with 100 hidden units, and decision trees of depth 4, all with 20-dimensional inputs. We show that we can again train Transformer models that can in-context learn these classes (with respect to isotropic Gaussian inputs and appropriately defined distributions over functions). For sparse linear functions, the trained model is able to exploit sparsity, obtaining performance better than least squares and comparable to Lasso. For neural networks, the corresponding model is able to obtain performance comparable to neural networks of the same architecture trained using gradient descent on in-context examples. Moreover, it is also able to in-context learn linear functions. For decision trees, the trained model can learn unseen trees with as few as 100 in-context examples, whereas greedy learning and tree boosting algorithms are unable to achieve competitive performance (for the distribution of prompts studied here). Note that learning these function classes requires involved algorithms (e.g., gradient descent with the Lasso objective), and our results show that Transformers can encode algorithms with similar performance in a single forward pass.

Role of model capacity and problem dimension.

Finally, we explore how the ability of Transformers to in-context learn linear functions scales with model capacity and problem dimensionality. We find that increasing the capacity of the model improves performance significantly, and also allows the model to in-context learn higher-dimensional functions. Moreover, increasing the capacity often significantly improves performance with distribution shifts, even when the absolute improvement in the standard error is small.

Training models for in-context learning

We now describe a general methodology for training a model that can in-context learn a function class $\mathcal{F}$ with respect to a distribution $D_{\mathcal{F}}$ over functions, and $D_{\mathcal{X}}$ over inputs. To do so, we start by constructing random training prompts as follows. For each prompt, we first sample a random function $f$ from the class according to $D_{\mathcal{F}}$, then create a set of random inputs $x_{1},\ldots,x_{k+1}$ drawn independently from $D_{\mathcal{X}}$, and finally evaluate $f$ on these inputs to produce the prompt $P=(x_{1},f(x_{1}),\ldots,x_{k+1},f(x_{k+1}))$. For example, in the case of linear functions, inputs could be drawn from the isotropic Gaussian distribution $N(0,I_{d})$, and a random function chosen by sampling weight vector $w$ from $N(0,I_{d})$ and setting $f(x)=w^{\top}x$.
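As a minimal sketch of the prompt construction for linear functions (not the authors' code; the function name `sample_linear_prompt` and the choice of NumPy are assumptions made for illustration), one could write:

```python
import numpy as np

def sample_linear_prompt(k, d, rng):
    """Return (xs, ys, w) for a random linear function f(x) = w^T x.

    xs has shape (k + 1, d): k in-context inputs plus one final input.
    """
    w = rng.standard_normal(d)             # weight vector w ~ N(0, I_d)
    xs = rng.standard_normal((k + 1, d))   # inputs x_i ~ N(0, I_d), drawn i.i.d.
    ys = xs @ w                            # outputs f(x_i) = w^T x_i
    return xs, ys, w

rng = np.random.default_rng(0)
xs, ys, w = sample_linear_prompt(k=40, d=20, rng=rng)
```

Each row of `xs` paired with the matching entry of `ys` gives one input-output pair of the prompt.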

Now, given such prompts, we train a model to predict $f(x_{i})$ for a given $x_{i}$ based on a set of preceding in-context examples. Concretely, let $P^{i}$ denote the prompt prefix containing $i$ in-context examples (the first $i$ input-output pairs) and the $(i+1)^{\text{th}}$ input: $P^{i}=(x_{1},f(x_{1}),x_{2},f(x_{2}),\ldots,x_{i},f(x_{i}),x_{i+1})$. Then, we train a model $M_{\theta}$ parameterized by $\theta$ aiming to minimize the expected loss over all the prompt prefixes:

$$\min_{\theta}\ \mathbb{E}_{P}\left[\frac{1}{k+1}\sum_{i=0}^{k}\ell\left(M_{\theta}(P^{i}),f(x_{i+1})\right)\right], \tag{2}$$

where $\ell(\cdot,\cdot)$ is an appropriate loss function.
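To illustrate what averaging a loss over all prompt prefixes means, the sketch below evaluates the per-prompt average squared error of an arbitrary predictor. It is an illustration under stated assumptions: `prefix_loss` and `least_squares_predict` are hypothetical names, and a minimum-norm least squares fit stands in for the trained model $M_{\theta}$.

```python
import numpy as np

def prefix_loss(xs, ys, predict):
    """Average squared error over all k + 1 prompt prefixes.

    predict(ctx_x, ctx_y, x_next) plays the role of the model: it sees the
    first i input-output pairs and the (i + 1)-th input.
    """
    losses = []
    for i in range(len(xs)):
        pred = predict(xs[:i], ys[:i], xs[i])
        losses.append((pred - ys[i]) ** 2)
    return float(np.mean(losses))

def least_squares_predict(ctx_x, ctx_y, x_next):
    """Stand-in predictor (assumption): minimum-norm least squares fit."""
    if len(ctx_x) == 0:
        return 0.0                      # no examples yet: predict 0
    w_hat = np.linalg.pinv(ctx_x) @ ctx_y
    return float(w_hat @ x_next)
```

A training loop would average this quantity over a batch of random prompts and take a gradient step on the model parameters; the sketch only evaluates the inner sum for one prompt.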

We use a decoder-only Transformer architecture [Vaswani et al., 2017] from the GPT-2 family [Radford et al., 2019]. Our model consists of 12 layers, 8 attention heads, and a 256-dimensional embedding space (9.5M parameters). This architecture takes as input a sequence of vectors in its embedding space and predicts the next vector in the sequence within the same space (in language modeling, these vectors correspond to input tokens). We apply this architecture to our prompt format of $(x_{1},f(x_{1}),\ldots,x_{k+1},f(x_{k+1}))$ as follows. We map each prompt output $f(x_{i})$ to the same dimension as prompt inputs $x_{i}$ by appending zeros, and map the prompt inputs and outputs into the latent embedding space of the Transformer through a (learnable) linear transformation. We then use another (learnable) linear transformation to map the vector predicted by the model to a scalar. Note that the Transformer architecture allows us to compute the prediction ($M_{\theta}(P^{i})$) for all prompt prefixes in a single forward pass.
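The input formatting step can be sketched as follows. This is an illustrative reading of the description, not the authors' implementation: the name `interleave_prompt` is hypothetical, and placing the scalar output in the first coordinate before the appended zeros is an assumption.

```python
import numpy as np

def interleave_prompt(xs, ys):
    """Build the sequence (x_1, f(x_1), ..., x_n, f(x_n)) as an array.

    Each scalar output is padded with zeros to dimension d, so the result
    has shape (2 * n, d) and can be fed through a linear read-in layer.
    """
    n, d = xs.shape
    ys_padded = np.zeros((n, d))
    ys_padded[:, 0] = ys        # scalar output followed by d - 1 zeros
    seq = np.empty((2 * n, d))
    seq[0::2] = xs              # even positions hold inputs
    seq[1::2] = ys_padded       # odd positions hold padded outputs
    return seq
```

The learnable read-in and read-out maps would then be ordinary linear layers of shape $d \times 256$ and $256 \times 1$ respectively, given the embedding size stated above.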

Training.

We train the model according to the training objective in (2) using squared error as the loss function. We do so by sampling a batch of random prompts at each training step and then updating the model through a gradient update (we use a batch size of 64 and train for 500k total steps). This training is done from scratch, that is, we do not fine-tune a pre-trained language model, nor do we train on actual text.

Curriculum learning.

Many natural function classes contain functions of varying complexity. We exploit this by training our model using a curriculum [Bengio et al., 2009, Elman, 1993, Sanger, 1994, Wu et al., 2020], where we train on a simpler distribution of functions in the beginning (e.g., linear functions with weight vectors restricted to a low-dimensional subspace) and gradually increase the function complexity. This speeds up training drastically, often allowing us to train models that would be significantly more expensive to train without a curriculum (see Section 6 for details).

In-context learning of linear functions

In the previous section, we described a general methodology for training Transformer models to in-context learn a class of functions. Here, we focus on a simple function class, namely linear functions, and study how well models trained using our methodology can in-context learn this class.

Baselines.

To contextualize the performance of our trained model, we compare it to other learning algorithms: (a) the least squares estimator, computing the minimum-norm linear fit to the in-context examples $(x_{i},y_{i})$, (b) $n$-Nearest Neighbors, averaging the $y_{i}$ values for the $n$ nearest neighbors of $x_{\text{query}}$, and (c) averaging the values $y_{i}x_{i}$ to estimate $w$ and computing the inner product of this estimate with $x_{\text{query}}$. Least squares is the optimal estimator for this problem and thus serves as a lower bound to the best error one can achieve. The other two baselines are consistent (but sub-optimal) estimators that are easier to compute and thus provide an estimate of the performance achieved by simple approaches. See Appendix A.3 for more details.
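The three baselines are simple enough to sketch directly. The versions below are a minimal NumPy illustration (function names are hypothetical, and details such as tie-breaking in the nearest-neighbor search are assumptions rather than the setup of Appendix A.3).

```python
import numpy as np

def least_squares(xs, ys, x_query):
    """(a) Minimum-norm least squares fit, then predict at x_query."""
    w_hat = np.linalg.pinv(xs) @ ys
    return float(w_hat @ x_query)

def nearest_neighbors(xs, ys, x_query, n=3):
    """(b) Average the outputs of the n inputs closest to x_query."""
    dists = np.linalg.norm(xs - x_query, axis=1)
    nearest = np.argsort(dists)[:n]
    return float(ys[nearest].mean())

def averaging(xs, ys, x_query):
    """(c) Estimate w by the mean of y_i * x_i, then predict at x_query."""
    w_hat = (ys[:, None] * xs).mean(axis=0)
    return float(w_hat @ x_query)
```

The averaging estimator is consistent for isotropic Gaussian inputs because $\mathbb{E}[y\,x]=\mathbb{E}[xx^{\top}]w=w$ when $x\sim N(0,I_{d})$.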

1 Transformers can in-context learn linear functions

We show the in-context learning ability of the resulting model along with the relevant baselines in Figure 2. The trained Transformer is able to in-context learn the class of linear functions with respect to the prompt distribution specified above, performing comparably to the optimal least squares estimator for any number of in-context examples considered. While the simpler baselines achieve non-trivial error, they are far worse, indicating that the trained model encodes a more complex algorithm.

Note that the probability of the model encountering a training prompt similar to the one used for testing is astronomically low: the prompt inputs alone lie in an 800-dimensional space when predicting with $2d$ in-context examples ($d=20$). Moreover, even considering the possibility that the model encountered a similar weight vector during training cannot explain its performance. That is, the model encounters 32 million random weight vectors during training and even using the best of these vectors would lead to an expected error of around $0.2$ (computed empirically, see Appendix B.7 for details). However, the model is able to achieve an error of less than $0.001$ for a prompt with $2d$ in-context examples. Further, in Section 6, we show that the model is able to obtain a similar error even when trained on prompts generated using only $10{,}000$ distinct weight vectors, in which case the best weight vector seen during training would yield an even worse error of around $0.5$. Thus, the model cannot be relying on memorization of training prompts or weight vectors, and instead encodes an algorithm capable of in-context learning linear functions that are very different from those seen during training.

2 What functions is the model learning in-context?

Recall that the goal of our model is: given the prompt $P=(x_{1},w^{\top}x_{1},\ldots,x_{k},w^{\top}x_{k},x_{\text{query}})$, output $w^{\top}x_{\text{query}}$. Thus, if we fix the prefix given by the $k$ in-context examples, we can view the output of the model as a function $\hat{f}_{w,x_{1:k}}(x_{\text{query}})$ that approximates $w^{\top}x_{\text{query}}$. When $k<d$ (fewer in-context examples than dimensions), the ground truth cannot be recovered perfectly and the ideal model should approximate $(\text{proj}_{x_{1:k}}(w))^{\top}x_{\text{query}}$, where $\text{proj}_{x_{1:k}}(w)$ is the projection of $w$ onto the subspace spanned by $x_{1},\ldots,x_{k}$. Here, we will evaluate how accurately the model approximates this.
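The projected ground truth can be computed explicitly. The sketch below (the name `project_onto_span` is hypothetical) projects $w$ onto the span of the in-context inputs using a QR factorization and checks that it coincides with the minimum-norm least squares solution when $k<d$.

```python
import numpy as np

def project_onto_span(w, xs):
    """Orthogonal projection of w onto span{x_1, ..., x_k} (rows of xs)."""
    q, _ = np.linalg.qr(xs.T)   # columns of q: orthonormal basis of the span
    return q @ (q.T @ w)

rng = np.random.default_rng(0)
d, k = 20, 10
w = rng.standard_normal(d)
xs = rng.standard_normal((k, d))

w_proj = project_onto_span(w, xs)
w_min_norm = np.linalg.pinv(xs) @ (xs @ w)   # minimum-norm fit to (x_i, w^T x_i)
```

Here `w_proj` and `w_min_norm` agree up to floating-point error, and `xs @ w_proj` reproduces the observed outputs `xs @ w`: the projection is the part of $w$ that the in-context examples determine.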

For a randomly sampled fixed prefix, we visualize $\hat{f}_{w,x_{1:k}}(x_{\text{query}})$ as we vary the query input along a random direction $x$ (Figure 3(a)). That is, we pick a random unit vector $x$, and evaluate $\hat{f}_{w,x_{1:k}}(\lambda x)$ as we vary $\lambda$, the distance of the query input from origin. We observe that $\hat{f}_{w,x_{1:d}}(\lambda x)$ and $\hat{f}_{w,x_{1:2d}}(\lambda x)$ closely match the ground truth and $\hat{f}_{w,x_{1:d/2}}(\lambda x)$ matches the projected ground truth, when the distance from origin is not too large compared to the norm of a typical randomly sampled input. In fact, in Appendix B.1, we show that the model is quite robust to scaling the query input: the error does not increase much as we scale up the query input by a factor of up to 2, or scale down by a factor of up to 16, and degrades slowly after that.

Local correctness.

So far, we have seen that the model is able to make predictions close to the ground truth for randomly drawn query inputs and in-context examples. We will now turn our attention to the local change of $\hat{f}$ around $x_{\text{query}}$ by considering the gradient of the function $\hat{f}_{w,x_{1:k}}(x_{\text{query}})$ with respect to $x_{\text{query}}$ (our model is fully differentiable so we can compute the gradient directly). Since $\hat{f}$ computed by the model should ideally approximate $\text{proj}_{x_{1:k}}(w)^{\top}x$, this gradient should lie in the direction of the projected ground truth $\text{proj}_{x_{1:k}}(w)$. In Figure 3(b), we show the inner product between the gradient and $\text{proj}_{x_{1:k}}(w)$ (both normalized), averaged over 1280 random prompts, and observe that they align almost perfectly. Since $\text{proj}_{x_{1:k}}(w)=w$ almost surely when $k\geq d$, we observe that the gradient also aligns with $w$ perfectly in this regime. Thus the model is locally correct with respect to changes in the query input.

Extrapolating beyond the training distribution

In the previous section, we demonstrated that we can train a model to in-context learn linear functions with respect to the distribution of prompts encountered during training. That is, we evaluated the in-context learning ability of the model with respect to distributions $D_{\mathcal{X}}$ and $D_{\mathcal{F}}$ that were also used to train the model.

Here, we evaluate the in-context learning performance of our model on prompt distributions different from the one used for training. Our overarching goal here is to better understand the learning algorithm encoded by our model by analysing how it responds to different prompt distributions.

Formally, we will refer to the distribution of functions used during training as $D_{\mathcal{F}}^{\text{train}}$ and the corresponding distribution of prompt inputs as $D_{\mathcal{X}}^{\text{train}}$. Then, during inference, functions are sampled from a (potentially different) distribution $D_{\mathcal{F}}^{\text{test}}$, while prompt inputs are sampled from a distribution $D_{\mathcal{X}}^{\text{test}}$. Moreover, deviating again from our analysis so far, we also consider a separate distribution $D_{\text{query}}^{\text{test}}$, from which the query input is sampled, potentially dependent on the rest of the in-context inputs $x_{1},\ldots,x_{k}$ (which are still sampled from $D_{\mathcal{X}}^{\text{test}}$).

Within this framework, we consider the same model as last section, and evaluate its performance on prompts that deviate from those encountered during training, either by

sampling prompt inputs or functions from a different distribution, that is $D_{\mathcal{X/F}}^{\text{train}}\neq D_{\mathcal{X/F}}^{\text{test}}$, or

introducing a mismatch between in-context examples and the query input, that is $D_{\text{query}}^{\text{test}}\neq D_{\mathcal{X}}^{\text{test}}$.

We describe each such prompt structure below and present a subset of the results in Figure 4 (see Appendix B.2 for additional details and full results). Overall, the model performs reasonably accurate in-context learning with respect to these prompt distributions, indicating that it has indeed learnt to perform linear regression to some generality.

Recall that we generate a training prompt $P=(x_{1},w^{\top}x_{1},\ldots,x_{k},w^{\top}x_{k},x_{\text{query}})$ by drawing the prompt inputs ($x_{i}$ and $x_{\text{query}}$), and the weight vector ($w$) i.i.d. from $N(0,I_{d})$, with $d=20$. For all the settings below, except prompt scaling, we normalize the inputs so that their expected squared norm is equal to that of inputs encountered during training.

We sample prompt inputs from $N(0,\Sigma)$ where $\Sigma$ is a skewed covariance matrix with eigenbasis chosen uniformly at random and $i^{\text{th}}$ eigenvalue proportional to $1/i^{2}$. The model matches the performance of least squares until $k=10$, mimicking the sharp drop in the error in this regime, but its error plateaus afterwards (see Figure 4(a)). Thus, it is not perfectly robust to this distribution mismatch but still does relatively well, achieving less than half the error of the nearest neighbor baseline in most cases.
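One way to generate such inputs is sketched below. The function name is hypothetical, the eigenbasis is obtained from the QR factorization of a Gaussian matrix (an assumption about how a random basis is drawn), and the eigenvalues are rescaled so that the expected squared norm equals $d$, which is our reading of the normalization described earlier in this section.

```python
import numpy as np

def skewed_gaussian_inputs(n, d, rng):
    """Sample n inputs from N(0, Sigma) with eigenvalues proportional to 1/i^2."""
    # Random orthonormal eigenbasis (columns of `basis`).
    basis, _ = np.linalg.qr(rng.standard_normal((d, d)))
    eigvals = 1.0 / np.arange(1, d + 1) ** 2
    eigvals *= d / eigvals.sum()          # trace(Sigma) = d, as for N(0, I_d)
    # x = basis @ (sqrt(eigvals) * z) with z ~ N(0, I_d) has covariance Sigma.
    xs = (rng.standard_normal((n, d)) * np.sqrt(eigvals)) @ basis.T
    return xs, basis, eigvals

rng = np.random.default_rng(0)
xs, basis, eigvals = skewed_gaussian_inputs(n=20000, d=20, rng=rng)
```

With this scaling most of the variance lies along the first few eigendirections, so the inputs are effectively low-dimensional even though $\Sigma$ has full rank.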

Low-dimensional subspace.

We sample prompt inputs from a random $10$-dimensional subspace. In this case, the model achieves low error after 10 in-context examples, closely matching the behavior of the optimal least squares estimator (the model achieves an error of $0.036$, $0.0014$, and $0.00057$ at 10, 20, and 40 in-context examples respectively; see Appendix Figure 8(b)). Crucially, unlike the training prompts, when $k$ is between $10$ and $20$, the prompt inputs are linearly dependent, and a model achieving low error in this regime indicates that it encodes a valid orthogonalization procedure for these inputs.
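The linear dependence in this setting is easy to reproduce. In the sketch below (hypothetical names; the input normalization is omitted for brevity), 15 inputs are drawn from a random 10-dimensional subspace of $\mathbb{R}^{20}$, and a minimum-norm least squares fit still predicts exactly for a query in the same subspace.

```python
import numpy as np

def sample_subspace_inputs(n, basis, rng):
    """Sample n inputs in the span of the orthonormal columns of `basis`."""
    r = basis.shape[1]
    return rng.standard_normal((n, r)) @ basis.T

rng = np.random.default_rng(0)
d, r, k = 20, 10, 15
basis, _ = np.linalg.qr(rng.standard_normal((d, r)))   # random r-dim subspace

w = rng.standard_normal(d)
xs = sample_subspace_inputs(k, basis, rng)             # 15 inputs, rank only 10
x_query = sample_subspace_inputs(1, basis, rng)[0]

rank = np.linalg.matrix_rank(xs, tol=1e-8)
# Explicit cutoff so numerically zero singular values are discarded.
w_hat = np.linalg.pinv(xs, rcond=1e-10) @ (xs @ w)
prediction, target = w_hat @ x_query, w @ x_query
```

The fit recovers only the component of $w$ inside the subspace, which is all that matters for queries drawn from that subspace.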

Noisy linear regression.

Prompt scale.

We consider the setting where the prompt scale between training and inference is different. We either scale the prompt inputs or the weight vectors, by a factor in $\{1/3,1/2,2,3\}$. The model is relatively robust when scaling the weight vector, but not as robust when scaling the prompt inputs, especially for the more extreme scales $1/3$ and $3$. Specifically, for $40$ in-context examples, the model achieves errors $0.0012$, $0.0008$, $0.0016$, $0.0278$ when scaling the weights, and errors $0.30$, $0.013$, $0.043$, $0.58$ when scaling the inputs, by factors $1/3$, $1/2$, $2$, and $3$ respectively (Appendix Figure 9). For context, recall that with 40 in-context examples, the least squares estimator achieves an error of $0$ whereas the model achieves an error of $0.0006$ at the original scale.

Different orthants for in-context and query inputs.

We fix the sign of each coordinate to be positive or negative for all in-context inputs $x_{i}$ (at random). As a result, all in-context inputs lie in the same orthant, while the query input lies in another orthant with high probability. The model is not affected by the mismatch between in-context and query inputs and closely matches the performance of least squares. In this case, the model achieves errors $0.062$ and $0.004$ for $20$ and $40$ in-context examples respectively (see Figure 4(c)), whereas recall that it achieves errors $0.02$ and $0.0006$ on standard prompts. This indicates that the model is not relying on some variant of nearest neighbor search as in that case, its error would have been significantly larger (see the 3-nearest neighbor baseline).
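A minimal sketch of this prompt construction, assuming the signs are imposed by taking absolute values of Gaussian samples (the function name and this particular mechanism are assumptions):

```python
import numpy as np

def sample_orthant_prompt(k, d, rng):
    """In-context inputs share one random orthant; the query is unrestricted."""
    signs = rng.choice([-1.0, 1.0], size=d)            # one random sign per coordinate
    xs = np.abs(rng.standard_normal((k, d))) * signs   # every x_i has these signs
    x_query = rng.standard_normal(d)                   # query drawn from N(0, I_d)
    return xs, x_query, signs

rng = np.random.default_rng(0)
xs, x_query, signs = sample_orthant_prompt(k=40, d=20, rng=rng)
same_orthant = bool(np.all(np.sign(x_query) == signs))
```

For $d=20$ the query falls in the same orthant as the in-context inputs with probability $2^{-20}$, which is why the mismatch occurs with high probability. Taking absolute values leaves the second moment of each coordinate unchanged, so the expected squared norm of the inputs matches that of the training inputs.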

Query input orthogonal to in-context inputs.

We sample the query input from the subspace orthogonal to the subspace spanned by in-context example inputs. Here, there is no information relevant to the query input in the in-context examples and thus the model would ideally predict something close to $0$ to minimize the error. Indeed, the model outputs such a prediction, achieving an error close to $1$ (Appendix Figure 8(d)).
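The following sketch constructs such a query for $k<d$ by removing the component of a Gaussian vector that lies in the span of the in-context inputs (the function name is hypothetical and the input normalization is omitted). It also shows that the minimum-norm least squares estimator, used here as a reference and not as the trained model, predicts exactly zero in this setting.

```python
import numpy as np

def orthogonal_query(xs, rng):
    """Sample a query orthogonal to the span of the rows of xs (needs k < d)."""
    d = xs.shape[1]
    q, _ = np.linalg.qr(xs.T)            # orthonormal basis of the span
    z = rng.standard_normal(d)
    return z - q @ (q.T @ z)             # remove the in-span component

rng = np.random.default_rng(0)
d, k = 20, 10
w = rng.standard_normal(d)
xs = rng.standard_normal((k, d))
x_query = orthogonal_query(xs, rng)

w_hat = np.linalg.pinv(xs) @ (xs @ w)    # lies in the span of the x_i
prediction = w_hat @ x_query
```

Since `w_hat` lies in the span of the in-context inputs and `x_query` is orthogonal to that span, `prediction` vanishes up to floating-point error.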

Query input matches an in-context example.

We choose the query input to match one of the in-context example inputs chosen uniformly at random. In this case, the model achieves errors $0.001$, $0.001$, and $0.0005$ for $10$, $20$, and $40$ examples respectively, thus making close to the correct prediction without being affected by the additional in-context examples present (Appendix Figure 8(e)).

More complex function classes

We now turn our attention to in-context learning for more complex function classes, namely sparse linear functions, decision trees, and two-layer ReLU neural networks. Here, we are back in the setting where the distribution of prompts during inference is the same as that during training (except the setting of neural networks where we evaluate on linear functions as well). The overall methodology remains the same: we sample random functions from these families and train a Transformer from scratch to approximate these functions given in-context examples. (See Appendix A.3 for more details and baselines.)

Decision trees.

Next, we consider the class of depth 4 decision trees with 20-dimensional inputs. A function $f$ in this class is represented by a full binary tree (with 16 leaf nodes) where each non-leaf node is associated with a coordinate, and each leaf node is associated with a target value. To evaluate $f$ on an input $x$, we traverse the tree starting from the root node, and go to the right child if the coordinate associated with the current node is positive and go to the left child otherwise (that is, the threshold at each node is $0$). $f(x)$ is given by the value associated with the leaf node reached at the end. To sample a random prompt $P=(x_{1},f(x_{1}),\ldots,x_{k},f(x_{k}),x_{\text{query}})$, we draw the prompt inputs $x_{i}$ and $x_{\text{query}}$ from $N(0,I_{d})$, and $f$ corresponds to a tree where the coordinates associated with the non-leaf nodes are drawn uniformly at random from $\{1,2,\ldots,d\}$ and the values associated with the leaf nodes are drawn from $N(0,1)$. In Figure 5(b), we show that Transformers can be trained to in-context learn this class, with performance much better than greedy tree learning and boosting (via XGBoost [Chen and Guestrin, 2016]). With $k=100$ in-context examples, the Transformer achieves an error of $0.12$ whereas greedy learning achieves an error of $0.80$ and XGBoost achieves an error of $0.62$.
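The function class can be made concrete with a short sketch. The array-based (heap) layout of the tree and the names `sample_tree` and `eval_tree` are assumptions made for illustration; coordinates are 0-indexed here, whereas the text indexes them from 1.

```python
import numpy as np

def sample_tree(depth, d, rng):
    """Random full binary tree: one coordinate per internal node, one value per leaf."""
    coords = rng.integers(0, d, size=2 ** depth - 1)   # uniform over the d coordinates
    leaves = rng.standard_normal(2 ** depth)           # leaf values ~ N(0, 1)
    return coords, leaves

def eval_tree(coords, leaves, x):
    """Traverse from the root: go right if the node's coordinate of x is positive."""
    n_internal = len(coords)
    node = 0                                           # heap layout: children of i are 2i+1, 2i+2
    while node < n_internal:
        go_right = x[coords[node]] > 0                 # threshold 0 at every node
        node = 2 * node + 2 if go_right else 2 * node + 1
    return leaves[node - n_internal]

rng = np.random.default_rng(0)
coords, leaves = sample_tree(depth=4, d=20, rng=rng)
x = rng.standard_normal(20)
y = eval_tree(coords, leaves, x)
```

For depth 4 this gives 15 internal nodes and 16 leaves, and the output depends on `x` only through the signs of its coordinates.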

Since the decision trees in our function class predict solely based on the sign of each coordinate of $x_{i}$, we also consider a baseline where we provide the greedy learning and XGBoost algorithms with the signs of each $x_{i}$ instead. This significantly improves their performance (at 100 in-context examples, greedy achieves an error of 0.50 and XGBoost an error of 0.31), but they still perform much worse than the trained Transformer.

Note that, in general, we do not have a good understanding of the space of efficient algorithms for learning decision trees, and the conditions under which known heuristics work [Blanc et al., 2021, Brutzkus et al., 2020]. At the same time, we found that Transformers can be trained to directly discover such an algorithm for the prompt distribution we considered. This suggests an intriguing possibility where we might be able to reverse engineer the algorithm encoded by a Transformer to obtain new sample efficient algorithms for existing learning problems.

Two-layer ReLU neural networks.

Moreover, the model trained to in-context learn two-layer neural networks is also able to in-context learn linear functions (for which it is not explicitly trained), albeit with a rate slower than least squares, but comparable to a neural network trained on in-context examples generated using a linear function (Figure 5(d)). For $k=20$, $50$, and $100$ in-context examples respectively, the Transformer achieves error $0.34$, $0.05$, and $0.01$, and the two-layer network achieves error $0.37$, $0.04$, and $0.003$ (the least squares estimator achieves error $0$ for $k\geq 20$).

Investigating what matters for in-context learning

We now return to the setting of training models to in-context learn linear functions and explore different factors that lead to successful in-context learning.

In Sections 3 and 4, we saw that Transformer models can be trained to in-context learn 20-dimensional linear functions accurately and relatively robustly. To explore the interplay between problem dimensionality and capacity, we also consider models with fewer parameters (see Appendix A.1) and train each architecture on {10, 30, 40, 50}-dimensional problems. In Figure 6, we plot the model error with $2d$ in-context examples as we vary the problem dimension $d$ and the model capacity. In the standard setting, i.e., when the training and inference time prompt distributions are the same, we observe that the error decreases as we increase the capacity or reduce the problem dimensionality (see Figure 6(a)). Thus, model capacity helps perform accurate in-context learning. For out-of-distribution prompts, we observe that the settings where the input covariance is skewed or where in-context example inputs and query inputs lie in different orthants are particularly challenging, especially for higher dimensional problems. However, the error decreases considerably (in most cases) as we increase the model capacity, even when the absolute decrease in the standard error is small (see Figure 6(b) and 6(c)). See Appendix B.3 for additional plots.

Curriculum.

We train our models using curriculum learning. That is, we initially draw the prompt inputs from a fixed $5$-dimensional subspace (by setting some of the coordinates to $0$) with prompt length $11$ (number of input-output pairs), and increase the subspace dimension by $1$ and prompt length by $2$ every $2{,}000$ training steps, until the subspace dimension reaches the ambient dimension $d$ and prompt length reaches $2d+1$ (see Appendix A.2 for details). This process can also be viewed as gradually increasing the complexity of the function class. This speeds up training drastically, especially for higher dimensional problems: for dimension 50, the loss barely decreases through the 500k training steps without curriculum but reaches close to the optimum with curriculum. For the 20-dimensional problem where we were able to train the model without curriculum within the training (step count) budget, we did not observe any qualitative difference in accuracy or robustness compared to the model trained with curriculum. We include plots comparing the speed and accuracy of training with and without curriculum in Appendix B.5.
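The schedule described above can be sketched as a function of the training step. The function names, the convention that the first increment happens exactly at step 2,000, and the choice to zero out the trailing coordinates are assumptions; Appendix A.2 is the authoritative description.

```python
import numpy as np

def curriculum(step, d, start_dim=5, start_len=11, interval=2000):
    """Return (subspace dimension, prompt length) at a given training step."""
    n_increments = step // interval
    dim = min(start_dim + n_increments, d)                 # +1 per interval, capped at d
    length = min(start_len + 2 * n_increments, 2 * d + 1)  # +2 per interval, capped at 2d + 1
    return dim, length

def restrict_to_subspace(xs, dim):
    """Zero out all but the first `dim` coordinates (assumed choice of coordinates)."""
    xs = xs.copy()
    xs[:, dim:] = 0.0
    return xs

schedule = [curriculum(step, d=20) for step in (0, 2000, 30000, 500000)]
```

For $d=20$ both caps are reached after 15 increments, i.e. at step 30,000, since $5+15=20$ and $11+2\cdot 15=41=2d+1$.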

Notably, when training Transformers without curriculum, there is an initial—relatively long—period in training where the loss does not decrease, followed by a period of sharp decrease. The length of this period varies with training randomness and seems to increase on average with problem dimension. Understanding the model just before and after this transition moment is a promising future direction, which can give insights into the emergence of in-context learning. Interestingly, Olsson et al. observe a similar jump in the in-context learning ability of a language model which they attribute to the formation of “induction heads”.

Number of distinct prompts or functions seen during training.

To estimate the amount of training data required for in-context learning, we perform two ablation studies. In the first study, we limit the number of distinct prompts seen during training. That is, we create a set of $n_{p}$ randomly generated prompts (as described in Section 2), and sample prompts from this set during training (here, we train without curriculum, as it would introduce additional prompts during the warmup phase). In the second study, we only limit the number of distinct functions used for training. That is, we create a set of $n_{w}$ randomly chosen vectors (corresponding to $n_{w}$ linear functions) and sample weight vectors uniformly from that set to generate the training prompts (the inputs are still sampled from $N(0,I_{d})$ for each training prompt). We find that the amount of training data required is relatively small: non-trivial in-context learning is possible with $n_{p}=100\text{k}$ or $n_{w}=1\text{k}$, and the error drops close to that of the unrestricted model (discussed in Section 3) with $n_{p}=1\text{M}$ or $n_{w}=10\text{k}$ (details in Appendix B.6). For context, in Section 3, the model is trained on fresh prompts each step, thus encountering 32M distinct functions and prompts (500k training steps with 64 prompts/batch).
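The second ablation and the prompt count can be illustrated with a short sketch. The name `sample_batch_from_pool` is hypothetical, and the batching details are assumptions rather than the setup of Appendix B.6.

```python
import numpy as np

# Fresh prompts at every step: 500k steps with 64 prompts per batch.
total_prompts = 500_000 * 64          # 32,000,000 distinct prompts and functions

def sample_batch_from_pool(pool, batch_size, k, rng):
    """Draw weight vectors from a fixed pool, but fresh inputs for every prompt."""
    n_w, d = pool.shape
    idx = rng.integers(0, n_w, size=batch_size)       # uniform over the pool
    ws = pool[idx]                                    # shape (batch, d)
    xs = rng.standard_normal((batch_size, k, d))      # inputs ~ N(0, I_d)
    ys = np.einsum("bkd,bd->bk", xs, ws)              # y = w^T x per prompt
    return xs, ys, idx

rng = np.random.default_rng(0)
pool = rng.standard_normal((10_000, 20))              # n_w = 10k fixed weight vectors
xs, ys, idx = sample_batch_from_pool(pool, batch_size=64, k=41, rng=rng)
```

Limiting the number of distinct prompts instead would amount to precomputing $n_{p}$ full `(xs, ys)` pairs once and indexing into that set at every step.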

Related work

Since Brown et al. demonstrated the in-context learning ability of GPT-3, there has been a significant interest in improving and understanding this capability [Liu et al., 2021, Min et al., 2021a, Zhao et al., 2021, Lu et al., 2021b, Rubin et al., 2021, Min et al., 2021b, Chen et al., 2021, Mishra et al., 2021, Lampinen et al., 2022]. The works most relevant to ours are as follows. Xie et al. propose a Bayesian inference framework explaining how in-context learning works despite formatting differences between training and inference distributions. Razeghi et al. show that in-context learning for numerical reasoning tasks is better for instances whose terms are more prevalent in training data. Min et al. [2021a] demonstrate tasks where in-context learning works even when the prompt outputs are chosen randomly, questioning to what extent these models are truly learning new tasks on-the-fly, while Rong gives examples of novel tasks on which these models demonstrate on-the-fly learning ability. Chan et al. demonstrate that distributional properties such as long-tailedness are crucial for in-context learning on an image-based few-shot dataset. Olsson et al. and Elhage et al. consider a different framing of in-context learning, referring to any model behavior that utilizes information in a prompt to make predictions that improve with prompt size. They hypothesize the existence of special circuits inside Transformer models responsible for in-context learning, that can complete prompts by copying previous similar patterns in the prompt sequence. Pesut and Dinh et al. [2022, Table 16] consider in-context learning for small tabular datasets and learning problems in one and two dimensions, and show that GPT-3 can obtain non-trivial accuracy. 
Our work contributes to and complements this line of work, by posing in-context learning as a well-defined problem of learning function classes at inference time, and empirically investigating training models that in-context learn simple function classes.

Transformers.

There is a long line of work investigating the capabilities [Vaswani et al., 2017, Dehghani et al., 2018, Yun et al., 2019, Pérez et al., 2019, Yao et al., 2021, Bhattamishra et al., 2020b, Zhang et al., 2022], limitations [Hahn, 2020, Bhattamishra et al., 2020a], applications [Lu et al., 2021a, Dosovitskiy et al., 2020, Parmar et al., 2018], and internal workings [Elhage et al., 2021, Snell et al., 2021, Weiss et al., 2021, Edelman et al., 2022, Olsson et al., 2022] of Transformer models. Most similar to our work, Müller et al. and Nguyen and Grover demonstrate the ability of Transformer models to solve prediction tasks using the input context, albeit in different settings. Müller et al. introduce a “Prior-data fitted transformer network” that is trained to approximate Bayesian inference with priors such as Gaussian processes and Bayesian neural networks, and use it to perform downstream tasks such as tabular dataset classification and few-shot image classification. Nguyen and Grover introduce Transformer neural processes, building on prior work on neural processes [Garnelo et al., 2018b, a, Kim et al., 2019], and show that they achieve state-of-the-art performance on tasks such as image completion and contextual multi-armed bandits. Our work complements these works, focusing on understanding the in-context learning ability of Transformers for various simple function classes and the extent to which this ability extrapolates beyond the training distribution.

Meta learning.

Training a model to perform in-context learning can be viewed as an instance of the more general learning-to-learn or meta-learning paradigm [Schmidhuber, 1987, Naik and Mammone, 1992, Thrun and Pratt, 2012]. Typical approaches from this extensive line of work (see [Hospedales et al., 2020] for a survey) include training a meta-learner on how to update the parameters of a downstream learner [Bengio et al., 1995, Li and Malik, 2016], learning parameter initializations from which one can quickly train for many downstream tasks [Finn et al., 2017, Ravi and Larochelle, 2017], and learning latent embeddings that allow for effective similarity search [Snell et al., 2017]. Most relevant to our setting are approaches that directly take as input examples from a downstream task and a query input and produce the corresponding output [Hochreiter et al., 2001, Mishra et al., 2018, Santoro et al., 2016, Garnelo et al., 2018b, a, Kirsch and Schmidhuber, 2021]. Our work contributes to this line of work by investigating the learning-to-learn abilities of Transformer models in a well-defined setting.

Data-driven algorithm design.

Another line of work aims to discover algorithms that perform well on a distribution of inputs [Horvitz et al., 2001, Xu et al., 2008, Vinyals et al., 2015, Bello et al., 2016, Khalil et al., 2017, Selsam et al., 2018, Schwarzschild et al., 2021] (as opposed to algorithms with guarantees on their worst-case performance). See Balcan for a survey on advancements on the theoretical foundations of such algorithms. Our work can be viewed as part of this line of work, as we train Transformer models to discover algorithms for different learning problems.

Discussion

In this work, we formalize and study the question: can we train models that learn different classes of functions in-context? We show that Transformer models trained from scratch can in-context learn the class of linear functions, with performance comparable to the optimal least squares estimator, even under distribution shifts. Moreover, we show that in-context learning is also possible for sparse linear functions, decision trees, and two-layer neural networks; learning problems which are solved in practice with involved iterative algorithms such as gradient descent.

At the same time, understanding the implications of our results for language models requires further investigation. A pertinent question regarding the in-context learning capabilities of language models is how they leverage in-context examples [Min et al., 2022]. Our results demonstrate that Transformers can encode complex learning algorithms that utilize in-context examples in a far-from-trivial manner. In fact, this is the case for standard Transformer architectures trained with standard optimization procedures. The extent to which such non-trivial in-context learning behavior exists in large language models is still open, but we believe that our work takes a step towards formalizing and investigating this question.

Our work lays the groundwork for several future directions.

We empirically show that model capacity helps in performing in-context learning accurately and robustly. This raises the question: how does the in-context learning loss (1) depend on the complexity of the function class $\mathcal{F}$, the capacity of the model $M$, and the number of prompts used to train $M$? Even the right notion of complexity of $\mathcal{F}$ is unclear and may depend on the model family. Understanding this question for models explicitly trained to perform in-context learning may suggest an upper bound for the in-context learning performance of models such as GPT-3 that have not been explicitly trained for this purpose.

Curriculum learning.

Within our framework, there is a natural notion of curriculum learning where, during training, we gradually increase the complexity of the function class learned in-context. This leads to drastic speed-ups in training. What is the reason behind such a speedup? Are similar speedups also possible for training large language models? Understanding these questions can have implications for training of models on large real-world datasets, potentially reducing the time and energy used for training.

Inductive bias of model families.

Our framework presents an opportunity to understand and compare the inductive biases of different model families (e.g., Transformers vs. LSTMs) in a well-defined setting. For instance, a concrete question is: Are there function classes that are easier to in-context learn using Transformers but harder for LSTMs and vice-versa?

Understanding the learning algorithms encoded in Transformers.

The models we train are able to perform in-context learning, and are thus themselves encoding learning algorithms. A worthwhile research direction would be to investigate the internal workings of these models and better understand the exact learning algorithms that they encode. Moreover, for settings such as decision trees, we do not have a good understanding of what the optimal learning algorithms are. Nevertheless, in Section 5 we found that Transformers are able to discover sample efficient algorithms when being trained to perform in-context learning. This suggests an intriguing possibility where we might be able to reverse engineer the Transformer to obtain better learning algorithms for such problems.

Acknowledgements

We thank Niladri Chatterji, Micah Goldblum, Rohith Kuditipudi, Shibani Santurkar, Carmen Strassle, Mirac Sugzun, and Li-Yang Tan for helpful conversations, and anonymous reviewers for helpful comments.

SG was funded by a Stanford Interdisciplinary Graduate Fellowship. DT was funded by Open Philanthropy, and partially supported by NSF Award CCF-1813049. GV was supported by NSF Awards CCF-1704417, CCF-1813049, Frontier Award 1804222 and DOE award DE-SC0019205. We performed our experiments on the Stanford NLP cluster.

References

Appendix A Experimental setup

Here, we provide additional details on our experimental setup.

We use architectures from the GPT-2 family [Radford et al., 2018] as implemented by HuggingFace [Wolf et al., 2020] (https://huggingface.co/docs/transformers/model_doc/gpt2). Specifically, we consider the following set of configurations.

We use the Standard model for the bulk of our experiments and only consider the smaller models for the capacity explorations in Section 6 and Appendix B.3. Since we train on each input once (we sample new inputs at each training step), overfitting to the training data is not an issue. Therefore, we set the Dropout probability to 0.

Out of the box, these models take as input a sequence of vectors in embedding space and output a sequence of vectors in the same space. However, the tasks we study are functions from a lower dimensional vector space (e.g., 10-50 dimensions) to a scalar value. Thus, in order to use a prompt such as $x_1, f(x_1), x_2, f(x_2), \ldots$, we need to map each $x_i$ and $f(x_i)$ to a vector in embedding space. We do so by first turning the scalars $f(x_i)$ into vectors of the same dimension as $x_i$ by appending 0s and then applying a learnable linear transformation to map all these vectors into the embedding space. Finally, we map the model output into a scalar value through a dot product with a learnable vector.

We treat the prediction of the model at the position corresponding to $x_i$ (that is, absolute position $2i-1$) as the prediction of $f(x_i)$. Due to the structure of these models, this prediction only depends on $(x_j, f(x_j))$ for $j < i$ and $x_i$. We ignore the model predictions at positions corresponding to $f(x_i)$.
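The prompt layout described above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the helper names `build_token_sequence` and `prediction_positions` are hypothetical, the scalar is placed in the first coordinate before the appended zeros (one reading of "appending 0s"), and the learnable linear read-in and read-out maps are omitted.

```python
def build_token_sequence(xs, ys):
    """Interleave inputs and zero-padded outputs: x_1, f(x_1), x_2, f(x_2), ..."""
    d = len(xs[0])
    tokens = []
    for x, y in zip(xs, ys):
        tokens.append(list(x))                # input token, kept as is
        tokens.append([y] + [0.0] * (d - 1))  # scalar output padded with zeros to dimension d
    return tokens


def prediction_positions(k):
    """1-indexed absolute positions 2i - 1 where the output is read as the prediction of f(x_i)."""
    return [2 * i - 1 for i in range(1, k + 1)]


# Toy prompt with k = 2 examples in d = 3 dimensions.
xs = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
ys = [0.5, -1.5]
tokens = build_token_sequence(xs, ys)
positions = prediction_positions(len(xs))
```

With a causal model, the output at position $2i-1$ can only attend to tokens at positions up to $2i-1$, which matches the dependence on $(x_j, f(x_j))$ for $j < i$ and on $x_i$ stated above.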

A.2 Training

Each training prompt is produced by sampling a random function $f$ from the function class we are training on, then sampling inputs $x_i$ from the isotropic Gaussian distribution $N(0, I_d)$ and constructing a prompt as $(x_1, f(x_1), \ldots, x_k, f(x_k))$. Given a prompt, we obtain model predictions $\hat{y}_i$ (meant to approximate $f(x_i)$) for each input, and compute the loss

At each training step, we average the loss over a batch of randomly generated prompts (with different functions and prompt inputs), and perform an update step. We use the Adam optimizer [Kingma and Ba, 2014], and train for 500,000 total steps with a batch size of 64. We use a learning rate of $10^{-4}$ for all function classes and models.

To accelerate training, we start by training on prompt inputs $x_i$ lying in a smaller dimensional subspace, and with fewer inputs per prompt, and gradually increase the subspace dimension and number of prompt inputs. Specifically, we zero out all but the first $d_{\text{cur}}$ coordinates of $x_i$, sample prompts of size $k_{\text{cur}}$ and leave the rest of the training process the same. We use the same schedule for all training runs for the function classes of linear functions and sparse linear functions, starting with $d_{\text{cur}} = 5$, $k_{\text{cur}} = 11$, and increasing $d_{\text{cur}}$ and $k_{\text{cur}}$ by 1 and 2 respectively, every 2000 steps, until $d_{\text{cur}} = d$, $k_{\text{cur}} = 2d+1$. We use a slightly different schedule for two-layer neural networks and decision trees as we want prompts with more inputs for these function classes. For these classes, we start with $d_{\text{cur}} = 5$, $k_{\text{cur}} = 26$, and increase $d_{\text{cur}}$ and $k_{\text{cur}}$ by 1 and 5 respectively, every 2000 steps, until $d_{\text{cur}} = d$, $k_{\text{cur}} = 5d+1$.
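The two schedules can be written as one small function. The sketch below is illustrative only: the names `curriculum` and `mask_input` are hypothetical, it assumes $d \geq 5$, and it assumes increments land exactly at multiples of 2000 steps, which the text does not pin down.

```python
def curriculum(step, d, k_start, k_inc, interval=2000, d_start=5):
    """Return (d_cur, k_cur) for a training step.

    Linear / sparse linear: k_start=11, k_inc=2  (ends at k_cur = 2d + 1).
    Two-layer NN / trees:   k_start=26, k_inc=5  (ends at k_cur = 5d + 1).
    """
    increments = step // interval                   # number of increases applied so far
    d_cur = min(d, d_start + increments)
    k_end = k_start + k_inc * (d - d_start)         # equals 2d + 1 or 5d + 1 for the two schedules
    k_cur = min(k_end, k_start + k_inc * increments)
    return d_cur, k_cur


def mask_input(x, d_cur):
    """Zero out all but the first d_cur coordinates of an input."""
    return list(x[:d_cur]) + [0.0] * (len(x) - d_cur)
```

For $d = 20$, both schedules reach their final values after 15 increments, that is, after 30,000 of the 500,000 training steps under the assumption above.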

Overall, with curriculum, a training prompt $(x_1, f(x_1), \ldots, x_{k_{\text{cur}}}, f(x_{k_{\text{cur}}}))$ is generated by sampling a random function $f$ from the function class, drawing inputs $x_i$ by sampling i.i.d. from $N(0, I_d)$ and zeroing out all but the first $d_{\text{cur}}$ coordinates. Given model predictions $\hat{y}_i$, the loss is given by

Sampling random functions.

For the class of linear functions, we sample a random function $f(x) = w^\top x$ by drawing $w \sim N(0, I_d)$. For our main setting (Sections 3 and 4), we set $d = 20$.

For the class of two-layer neural networks, we sample $f(x) = \sum_{i=1}^{r} \alpha_i \sigma(w_i^\top x)$, where the $\alpha_i$ and $w_i$ are drawn i.i.d. from $N(0, 2/r)$ and $N(0, I_d)$ respectively. We set $d = 20$ and $r = 100$.

For the class of $k$-sparse linear functions, we sample $f(x) = w^\top x$ by drawing $w \sim N(0, I_d)$ and zeroing out all but $k$ coordinates of $w$ chosen uniformly at random from the first $d_{\text{cur}}$ coordinates (as defined in the curriculum learning description above). We set $d = 20$ and $k = 3$.
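The three samplers above can be sketched with the standard library alone. All function names are hypothetical. The sketch assumes that the second argument of $N(0, 2/r)$ is a variance (so the standard deviation is $\sqrt{2/r}$) and that $\sigma$ is the ReLU, as stated for the gradient descent baseline in Appendix A.3.

```python
import math
import random


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def sample_linear_weights(d, rng):
    """w ~ N(0, I_d); the function is f(x) = w^T x."""
    return [rng.gauss(0.0, 1.0) for _ in range(d)]


def sample_sparse_weights(d, k, d_cur, rng):
    """Draw w ~ N(0, I_d), then keep only k coordinates chosen among the first d_cur."""
    w = sample_linear_weights(d, rng)
    keep = set(rng.sample(range(d_cur), k))
    return [wi if j in keep else 0.0 for j, wi in enumerate(w)]


def sample_two_layer_relu(d, r, rng):
    """Return f(x) = sum_i alpha_i * relu(w_i^T x) with alpha_i ~ N(0, 2/r), w_i ~ N(0, I_d)."""
    alphas = [rng.gauss(0.0, math.sqrt(2.0 / r)) for _ in range(r)]  # variance 2/r (assumed)
    ws = [sample_linear_weights(d, rng) for _ in range(r)]

    def f(x):
        return sum(a * max(0.0, dot(w, x)) for a, w in zip(alphas, ws))

    return f


rng = random.Random(0)
w_sparse = sample_sparse_weights(d=20, k=3, d_cur=5, rng=rng)
f_nn = sample_two_layer_relu(d=20, r=100, rng=rng)
```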

For the class of decision trees, the random function $f$ is represented by a decision tree of depth 4 (with 16 leaf nodes), with 20-dimensional inputs. Each non-leaf node of the tree is associated with a coordinate selected uniformly at random from $\{1, 2, \ldots, d\}$, and each leaf node is associated with a value drawn randomly from $N(0, 1)$. To evaluate $f$ on an input $x$, we traverse the tree starting from the root node, and go to the right child if the coordinate associated with the current node is positive and go to the left child otherwise. $f(x)$ is given by the value associated with the leaf node reached at the end.
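A compact way to picture this function class is a complete binary tree stored in heap order. The sketch below is an illustration under that storage assumption; `sample_tree` and `evaluate_tree` are hypothetical names, and coordinates are 0-indexed rather than drawn from $\{1, \ldots, d\}$.

```python
import random


def sample_tree(d, depth, rng):
    """Complete binary tree in heap order: node n has children 2n + 1 (left) and 2n + 2 (right)."""
    n_internal = 2 ** depth - 1
    coords = [rng.randrange(d) for _ in range(n_internal)]      # one coordinate per non-leaf node
    leaves = [rng.gauss(0.0, 1.0) for _ in range(2 ** depth)]   # one N(0, 1) value per leaf
    return coords, leaves


def evaluate_tree(tree, x, depth):
    """Go right if the node's coordinate of x is positive, left otherwise; return the leaf value."""
    coords, leaves = tree
    node = 0
    for _ in range(depth):
        go_right = x[coords[node]] > 0
        node = 2 * node + (2 if go_right else 1)
    return leaves[node - (2 ** depth - 1)]   # shift heap index to a leaf index in [0, 2^depth)


rng = random.Random(0)
tree = sample_tree(d=20, depth=4, rng=rng)
value = evaluate_tree(tree, [rng.gauss(0.0, 1.0) for _ in range(20)], depth=4)
```

Because only the sign of each queried coordinate matters, the function is constant on each orthant of the input space, which is the fact used by the sign-based baselines in Appendix A.3.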

Computational resources.

We train using a single NVIDIA GeForce RTX 3090 GPU and most training runs take 5-20 hours depending on model size and context length. For instance, for the class of linear functions, training the standard model takes 17 hours for $d = 50$, 7 hours for $d = 20$ and 5.5 hours for $d = 10$. For decision trees, training the standard model takes 17 hours. The time it takes for decision trees and 50-dimensional linear functions is higher due to larger context lengths (we train for $d$-dimensional linear functions with $2d+1$ input-output pairs per prompt).

A.3 Baselines

Minimum norm least squares is the optimal estimator for the linear regression problem. Given a prompt $P = (x_1, y_1, \ldots, x_k, y_k, x_{\text{query}})$, let $X$ be a $k \times d$ matrix with row $i$ given by $x_i$, and let $y$ be a $k$-dimensional vector with the $i^{\text{th}}$ entry $y_i$. Set $\hat{w} = X^{+} y$, where $X^{+}$ denotes the Moore-Penrose pseudoinverse of $X$. The estimator predicts $M(P) = \hat{w}^{T} x_{\text{query}}$.

Averaging estimator.

This corresponds to $M(P) = \hat{w}^{T} x_{\text{query}}$ where $\hat{w} = \frac{1}{k}\sum_{i=1}^{k} x_i y_i$. This estimator is consistent (yet sub-optimal) when the $x_i$ are drawn from $N(0, I_d)$. Unlike least squares, this estimator does not involve an inverse computation, and might be easier for a model to encode.
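The consistency claim follows from a one-line calculation, sketched here under the assumption of noiseless outputs $y_i = w^\top x_i$ with $x_i \sim N(0, I_d)$ drawn i.i.d.:

```latex
\begin{align*}
\mathbb{E}[x_i y_i]
  &= \mathbb{E}\big[x_i x_i^\top w\big]
   = \mathbb{E}\big[x_i x_i^\top\big]\, w
   = I_d\, w
   = w, \\
\hat{w} = \frac{1}{k}\sum_{i=1}^{k} x_i y_i
  &\xrightarrow{\;k \to \infty\;} w
  \quad \text{(law of large numbers)}.
\end{align*}
```

For finite $k$ the average still fluctuates around $w$, whereas in the noiseless case least squares recovers $w$ exactly once the $x_i$ span $\mathbb{R}^d$ (almost surely for $k \geq d$). This is one way to see why the averaging estimator is consistent yet sub-optimal.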

Nearest neighbors.

This corresponds to setting $M(P) = \frac{1}{n}\sum_{i \in S} y_i$. Here, $S$ is the set of indices of the $n$ nearest neighbors of $x_{\text{query}}$ among $x_1$ to $x_k$. For $k < n$, we average over all the $y_i$ from 1 to $k$, and for $k = 0$, we set $M(P) = 0$. We consider the nearest neighbors baseline as it might be easier for a Transformer model to encode using self-attention compared to least squares.
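The nearest neighbors rule, including its two edge cases, can be sketched as follows. The function name `knn_predict` is hypothetical, Euclidean distance is an assumption (the text does not name the metric), and the default `n=3` is only an illustrative value.

```python
def knn_predict(xs, ys, x_query, n=3):
    """Average the outputs of the n in-context inputs closest to x_query."""
    k = len(xs)
    if k == 0:
        return 0.0                       # no in-context examples: predict 0

    def sq_dist(i):
        return sum((a - b) ** 2 for a, b in zip(xs[i], x_query))

    nearest = sorted(range(k), key=sq_dist)[:n]   # all k indices when k < n
    return sum(ys[i] for i in nearest) / len(nearest)


# Toy 1-dimensional example: the two nearest neighbors of 0.4 are 0.0 and 1.0.
prediction = knn_predict([[0.0], [1.0], [10.0]], [1.0, 3.0, 100.0], [0.4], n=2)
```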

Lasso.

We try different values of $\alpha \in \{1, 10^{-1}, 10^{-2}, 10^{-3}, 10^{-4}\}$, and report the best solution (achieving the smallest error with 10 in-context examples) corresponding to $\alpha = 10^{-2}$. To solve the optimization problem, we use the Lasso implementation from Scikit-learn [Pedregosa et al., 2011] (https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html).

Greedy Tree Learning.

We use this baseline for the class of decision trees. This corresponds to greedily learning a decision tree using the in-context examples, and using it to classify the query input. To construct the tree, at each node (starting from a root node), we choose a coordinate for partitioning the examples into two sets, so as to minimize the variance of the $y_i$ in each set, averaged across the two sets. The value associated with a leaf node is the average $y_i$ value of the examples belonging to it. We use Scikit-learn’s decision tree regressor [Pedregosa et al., 2011] (https://scikit-learn.org/stable/modules/tree.html#regression) implementation for this, with all the arguments set to their default value except the max_depth argument, which is set to 2. We considered values $\{1, 2, 3, 4, 5, 6, \text{unbounded}\}$ for the maximum depth and chose the value that performs best at 100 in-context examples, which was 2 (this differs from the decision trees sampled from the function class, which have depth 4). We also considered a baseline where we learn this tree using only the signs of each $x_i$ coordinate, since the decision tree we are trying to learn depends only on the signs of $x_i$. In this case, we found the optimal depth to be 4.

Tree boosting.

For the class of decision trees, we also consider a tree boosting baseline that corresponds to learning an ensemble of decision trees (see Friedman for a description of the general framework). Specifically, we use the XGBoost library [Chen and Guestrin, 2016] (https://github.com/dmlc/xgboost), an implementation commonly used for a wide range of real-world machine learning tasks.

We performed a hyperparameter search by considering {1, 2, 5, 10, 50, 100, 200, 400} estimators in the ensemble (equivalent to number of boosting rounds), a learning rate of {0.001, 0.01, 0.1, 0.3, 0.6, 1, 3}, and a maximum depth of {1, 2, 3, 4, 6, 10, 16}. In general, we found the performance of the learning algorithm to be quite robust. We chose the hyperparameters obtaining the best performance with 100 training examples, corresponding to 50 estimators, a maximum depth of 4, and a learning rate of 0.1. We found these hyperparameters to also be optimal when learning based on the signs of each $x_i$.

Learning neural networks with gradient descent.

We use this baseline for the class of two-layer neural networks (Section 5). This corresponds to training a two-layer neural network on the in-context examples, and outputting its prediction on the query point. That is, $M(P) = \hat{f}(x_{\text{query}})$, where $\hat{f}(x) = \sum_{i=1}^{r} \hat{\alpha}_i \sigma(\hat{w}_i^\top x)$.

Here, $\sigma(\cdot)$ is the ReLU activation. We find parameters $\hat{\alpha}_i, \hat{w}_i$ by minimizing the squared error of the prediction for the in-context examples

using the Adam optimizer. We use a batch size of 10 (we use full batch when the number of in-context examples is less than 10) with 5000 optimization steps, and set $r = 100$. We use a learning rate of $5 \cdot 10^{-3}$ in the case when the data is generated using a neural network, and a learning rate of $5 \cdot 10^{-2}$ when the data is generated using a linear function. We consider the setting with 100 in-context examples, and do a hyperparameter grid search over learning rate $\in \{5 \cdot 10^{-4}, 5 \cdot 10^{-3}, 5 \cdot 10^{-2}, 5 \cdot 10^{-1}, 5\}$, $r \in \{100, 400\}$, batch size $\in \{10, 100\}$, optimization algorithm $\in \{\text{adam}, \text{sgd}\}$. All the hyperparameter settings in this grid led to a similar or worse performance compared to the hyperparameter setting we chose.

Appendix B Additional experimental results

In Figure 7, we show that the trained model is quite robust to scaling the query input (while keeping the in-context examples fixed): the error does not increase much as we scale up the query input by a factor of up to 2, or scale down by a factor of up to 16, and degrades slowly after that.

B.2 Out-of-distribution prompts

Here, we describe the structure of our out-of-distribution prompts (cf. Section 4), and show the corresponding plots (Figure 8). To avoid conflating factors, we normalize the prompt inputs such that their expected norm is equal to the expected norm of inputs during training and investigate the role of scaling these inputs separately. We summarize how these prompts deviate from those seen during training in the table below.

(Figure 8(a)) We sample inputs from $N(0, \Sigma)$ where $\Sigma$ is a skewed covariance matrix with eigenbasis chosen uniformly at random and $i^{\text{th}}$ eigenvalue proportional to $1/i^2$.

Low-dimensional subspace.

(Figure 8(b)) We sample prompt inputs from a random $d/2$-dimensional subspace. That is, we pick a random $d/2$-dimensional subspace, and draw the prompt inputs from an isotropic Gaussian distribution restricted to this subspace. As a result, it is possible to achieve zero error after $d/2$ in-context examples.

Prompt scale.

(Figure 9) We consider the setting where the prompt scale differs between training and inference. We scale either the prompt inputs or the weight vectors by a factor in $\{1/3, 1/2, 2, 3\}$.

Noisy linear regression.

(Figure 8(c)) We add noise to each prompt output, that is, the $i^{\text{th}}$ output is equal to $w^{T} x_i + \epsilon_i$ where $\epsilon_i \sim N(0, d/20)$.

Different orthants for in-context and query inputs.

(Figure 8(f)) We fix the sign of each coordinate to be positive or negative for all in-context inputs $x_i$ (at random), and draw $x_{\text{query}}$ (as before) i.i.d. from $N(0, I_d)$. As a result, all in-context inputs lie in the same orthant, while the query input lies in another orthant with high probability.
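One way to generate such a prompt is sketched below. The name `sample_orthant_prompt` is hypothetical, and imposing the fixed sign on the absolute value of a Gaussian draw is an assumption about how the signs are fixed; the text does not spell out the mechanism.

```python
import random


def sample_orthant_prompt(d, k, rng):
    """In-context inputs share one random sign pattern; the query is an unconstrained Gaussian."""
    signs = [rng.choice((-1.0, 1.0)) for _ in range(d)]
    xs = [[s * abs(rng.gauss(0.0, 1.0)) for s in signs] for _ in range(k)]
    x_query = [rng.gauss(0.0, 1.0) for _ in range(d)]
    return signs, xs, x_query


rng = random.Random(0)
signs, xs, x_query = sample_orthant_prompt(d=20, k=40, rng=rng)

# By symmetry, the query lands in the in-context orthant with probability 2^(-d).
same_orthant_probability = 2.0 ** -20
```

For $d = 20$ this probability is below $10^{-6}$, which is the sense in which the query lies in another orthant with high probability.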

Query input orthogonal to in-context inputs.

(Figure 8(d)) We choose the query input randomly in the space orthogonal to the space spanned by in-context example inputs. That is, we draw the query input from an isotropic Gaussian distribution restricted to the subspace orthogonal to the space spanned by the in-context examples. Thus, the optimal normalized error is 1 for any number of in-context examples (there can be at most $d-1$ in-context examples for an orthogonal query to exist).
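Drawing from an isotropic Gaussian restricted to the orthogonal complement is equivalent to drawing a full Gaussian vector and projecting out the span of the in-context inputs. The sketch below does this with modified Gram-Schmidt; the function names are hypothetical and the norm normalization mentioned at the start of this section is omitted.

```python
import math
import random


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def orthonormal_basis(vectors, tol=1e-10):
    """Modified Gram-Schmidt: orthonormal basis of the span of the given vectors."""
    basis = []
    for v in vectors:
        u = list(v)
        for b in basis:
            c = dot(u, b)
            u = [ui - c * bi for ui, bi in zip(u, b)]
        norm = math.sqrt(dot(u, u))
        if norm > tol:                       # skip vectors already in the span
            basis.append([ui / norm for ui in u])
    return basis


def sample_orthogonal_query(xs, d, rng):
    """Gaussian draw with its component in span(xs) removed."""
    q = [rng.gauss(0.0, 1.0) for _ in range(d)]
    for b in orthonormal_basis(xs):
        c = dot(q, b)
        q = [qi - c * bi for qi, bi in zip(q, b)]
    return q


rng = random.Random(0)
d = 5
xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(d - 1)]  # at most d - 1 examples
x_query = sample_orthogonal_query(xs, d, rng)
```

Since the query is orthogonal to every in-context input, the examples carry no information about $w^\top x_{\text{query}}$, which is why no estimator can beat the trivial normalized error of 1 in this setting.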

Query input matches an in-context example.

(Figure 8(e)) We set the query input equal to one of the in-context examples chosen uniformly at random. Thus it’s possible to achieve zero error since the in-context examples include the correct prediction for the query input already.

B.3 Effect of problem dimension and model capacity

We plot the model error for additional out-of-distribution prompts in Figure 10 for $2d$ in-context examples (with the exception of orthogonal queries where we use $d-1$ in-context examples).

Similar to the settings in Section 6 (skewed covariance and different orthants), accuracy improves with capacity in most cases. One exception is scaling $x$ (Figure 10(e)), in which case we do not see any clear trend. In the case of noisy output (Figure 10(b)), the accuracy almost saturates at 1.2M parameters, close to the error of the least squares estimator. In the case of orthogonal query input (Figure 10(c)), the model achieves the optimum error of 1 even with the tiny model with 0.2M parameters.

B.4 Training variance

In Figure 12, we show the variance in error across training runs for the standard Transformer model (9.5M parameters). We plot the squared error for 3 models (with different random seeds) each for $d \in \{10, 20, 30, 40, 50\}$, trained to in-context learn linear functions. The error is quite concentrated in the standard setting as well as for most out-of-distribution prompts. In the different-orthants and skewed-covariance settings, we observe a high variance for higher dimensional problems ($d \geq 30$). However, in Section 6, we saw that the error in these settings usually decreases as we increase the model size. In the setting where we scale $x$, there is high variance even when $d = 10$.

B.5 Curriculum

In Figure 13, we show the training loss of the Transformer model trained to in-context learn linear functions, with and without a curriculum. Specifically, given a random training prompt sequence $(x_1, f(x_1), x_2, f(x_2), \ldots, x_{k_{\text{cur}}}, f(x_{k_{\text{cur}}}))$, let $\hat{y}_i$ be the model’s prediction for the $i^{\text{th}}$ input (meant to approximate $f(x_i)$). For each such prompt, we consider the loss given by the normalized squared error averaged over all prompt prefixes

At each training step, we plot the loss averaged over a batch of 64 random prompts. For training with curriculum, $k_{\text{cur}}$ is gradually increased to $2d+1$ as described in Section A.2. For training without curriculum, $k_{\text{cur}} = 2d+1$ at all times.

Note that the loss often increases in the beginning as we train the model with curriculum. This is due to a sharp increase in the loss at steps where we increase the effective dimensionality ($d_{\text{cur}}$) of prompt inputs ($x_i$). There are two reasons for this increase: (i) the variance of the target output ($f(x_i) = w^\top x_i$) increases, so even the optimum loss is larger, and (ii) the model performance is worse for the prompt inputs with increased effective dimension. After each such step where we increment $d_{\text{cur}}$, the loss starts to decrease again until the next increment. The overall trend in the loss looks upward when the sharp increase dominates the decrease that follows. Some observations worth highlighting are as follows.

For functions in 20 or more dimensions, curriculum allows us to train a low-error model often 4 times faster. Moreover, training without curriculum does not always succeed within our training budget (500k steps), e.g., for one run with $d = 30$ and all runs with $d = 50$, the loss does not decrease at all without curriculum.

Initial lull without curriculum.

For training without curriculum, we observe that the loss does not decrease for a relatively long period in the beginning, and starts to decrease sharply thereafter. There is a large variance in the length of this period for any fixed dimension, and the average length seems to increase with dimension. This period is almost non-existent for smaller dimensions (e.g., see the plot for $d = 10$), and therefore we do not observe such a period while training with curriculum, where we start training with inputs lying in a 5-dimensional subspace.

Curriculum does not affect final performance significantly.

For our core setting ($d = 20$), where we are able to train the model to low error even without curriculum, we do not observe any qualitative differences in the error in most cases (both with and without distribution shifts). One exception is the case with skewed covariance, where the model trained without curriculum seems to do slightly better. We plot the error curves for the standard, different orthants and skewed covariance cases in Figure 14.

B.6 Effect of number of distinct prompts/functions seen during training

Here, we investigate the amount of training data required for in-context learning of linear functions.

First, we consider the effect of the number of distinct prompts encountered during training. For this, we create a set $S_p$ of $n_p$ randomly generated prompts, where each prompt in $S_p$ is generated by sampling a weight vector and prompt inputs from $N(0, I_d)$. We generate random prompts during training by sampling uniformly from this set. As before, we train the model for 500k steps with a batch size of 64. We observe (see Figure 15) that a model trained with $n_p = 100\text{k}$ is able to achieve non-trivial error and a model trained with $n_p = 1\text{M}$ achieves error close to that of the unrestricted model (trained with 32M distinct prompts). Recall that with curriculum learning, we zero out some of the coordinates of prompt inputs in the beginning of training, which will increase the total number of prompts the model sees during training. Therefore we do not use curriculum learning for this study to avoid inflating the number of distinct prompts seen during training.

Second, we consider the effect of the number of distinct weight vectors (equivalently, distinct functions) encountered during training. For this, we create a set $S_w$ of $n_w$ weight vectors where each weight $w$ is drawn i.i.d. from $N(0, I_d)$. To generate a training prompt $(x_1, w^\top x_1, \ldots, x_k, w^\top x_k)$, we draw prompt inputs $x_i$ i.i.d. from $N(0, I_d)$ as in the unrestricted setting, and sample $w$ uniformly at random from $S_w$. Thus while we sample from a finite set of weight vectors, we sample fresh inputs at each step. As before, we train the model for 500k steps with a batch size of 64. Here, we observe (see Figure 15) that the model trained with as few as 10k distinct weight vectors achieves error close to the unrestricted model (trained with 32M distinct functions). We use curriculum learning for this study as in our standard setting. Recall that with curriculum learning, we only zero out some coordinates of prompt inputs in the beginning, so this does not change the number of distinct weight vectors seen by the model during training.
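The difference between the two ablations lies in what is frozen before training: whole prompts in the first study, only weight vectors in the second. A minimal sketch of both data pipelines follows; all function names are hypothetical and the pool sizes used at the bottom are toy values, not those of the study.

```python
import random


def gaussian_vector(d, rng):
    return [rng.gauss(0.0, 1.0) for _ in range(d)]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def make_prompt(w, k, d, rng):
    """Fresh inputs x_i ~ N(0, I_d) paired with outputs w^T x_i."""
    xs = [gaussian_vector(d, rng) for _ in range(k)]
    return xs, [dot(w, x) for x in xs]


def build_prompt_pool(n_p, k, d, rng):
    """First study: n_p complete prompts are fixed once, before training."""
    return [make_prompt(gaussian_vector(d, rng), k, d, rng) for _ in range(n_p)]


def build_weight_pool(n_w, d, rng):
    """Second study: only n_w weight vectors are fixed."""
    return [gaussian_vector(d, rng) for _ in range(n_w)]


def training_prompt_from_prompt_pool(pool, rng):
    return rng.choice(pool)                         # the same prompt can recur


def training_prompt_from_weight_pool(weights, k, d, rng):
    return make_prompt(rng.choice(weights), k, d, rng)   # reused w, fresh inputs


# Unrestricted setting for comparison: one fresh prompt per batch element per step.
total_unrestricted_prompts = 500_000 * 64   # 32,000,000

rng = random.Random(0)
prompt_pool = build_prompt_pool(n_p=5, k=3, d=4, rng=rng)
weight_pool = build_weight_pool(n_w=1, d=4, rng=rng)
```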

B.7 Can memorization explain model performance?

In Section 3.1, we discussed that memorization of prompts seen during training cannot explain model performance. This is because the probability of the model encountering a training prompt similar to the one used for testing is astronomically low: the prompt inputs alone lie in an 800-dimensional space when predicting with $2d$ in-context examples ($d = 20$).

Moreover, even the possibility that the model encountered a similar weight vector during training cannot explain its performance. Let $S_w$ be the set of weight vectors used to generate training prompts. At inference time, given a prompt with in-context examples generated using a weight vector $w_\star$, suppose the model is somehow able to find the best weight vector $\hat{w}$ in $S_w$ minimizing the normalized squared error on query inputs:

Taking expectation over the weight vector $w_\star$, we get the expected normalized squared error of the model (with respect to randomly drawn in-context examples and query inputs):

To empirically estimate this quantity, we sample $n_w$ weight vectors from $N(0, I_d)$ (with $d = 20$) that form the set $S_w$, and 500 weight vectors from $N(0, I_d)$ to estimate the outer expectation. We do this 20 times, freshly sampling the 500 weight vectors and the vectors comprising $S_w$ each time, and compute the mean of the 20 estimates obtained. When $n_w = 32\text{M}$ (number of weight vectors encountered in our standard training setup), we get a mean of 0.216 (standard deviation 0.004). However, our model is able to achieve an expected error of less than 0.001 for prompts with $2d$ in-context examples. Similarly, when $n_w = 10{,}000$, we get a mean of 0.505 (standard deviation 0.006), while a model trained on prompts generated using 10,000 distinct weight vectors is able to achieve a much smaller error (see Figure 15).
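The estimation procedure can be sketched as a small Monte Carlo loop. The sketch rests on an assumption that is not spelled out in the surviving text: for $x_{\text{query}} \sim N(0, I_d)$, the normalized squared error of a fixed $\hat{w}$ is taken to be $\mathbb{E}[(w_\star^\top x - \hat{w}^\top x)^2]/d = \|w_\star - \hat{w}\|^2 / d$, so the best memorized vector is simply the nearest one in Euclidean distance. The function name and the toy sizes are hypothetical; pure Python is far too slow for $n_w = 32\text{M}$.

```python
import random


def nearest_weight_error(n_w, d, n_test, rng):
    """Monte Carlo estimate of E_{w*}[ min_{w in S_w} ||w* - w||^2 / d ]."""
    s_w = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n_w)]   # memorized set S_w
    total = 0.0
    for _ in range(n_test):
        w_star = [rng.gauss(0.0, 1.0) for _ in range(d)]                  # unseen test function
        best = min(sum((a - b) ** 2 for a, b in zip(w_star, w)) for w in s_w)
        total += best / d
    return total / n_test


rng = random.Random(0)
error_one_vector = nearest_weight_error(n_w=1, d=5, n_test=50, rng=rng)
error_many_vectors = nearest_weight_error(n_w=200, d=5, n_test=50, rng=rng)
```

Under this reading, a single memorized vector gives an expected error of 2 (since $w_\star - w \sim N(0, 2 I_d)$), and the error shrinks only slowly as $S_w$ grows because random Gaussian vectors in 20 dimensions are rarely close to one another.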

Thus we can conclude that the model cannot be relying on memorization of the training prompts or weight vectors, and is encoding a more sophisticated algorithm capable of in-context learning linear functions that are very different from those seen during training.