What learning algorithm is in-context learning? Investigations with linear models

Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, Denny Zhou

Introduction

One of the most surprising behaviors observed in large neural sequence models is in-context learning (ICL; Brown et al., 2020). When trained appropriately, models can map from sequences of $(x, f(x))$ pairs to accurate predictions $f(x')$ on novel inputs $x'$. This behavior occurs both in models trained on collections of few-shot learning problems (Chen et al., 2022; Min et al., 2022) and, surprisingly, in large language models trained on open-domain text (Brown et al., 2020; Zhang et al., 2022; Chowdhery et al., 2022). ICL requires a model to implicitly construct a map from in-context examples to a predictor without any updates to the model’s parameters themselves. How can a neural network with fixed parameters learn a new function from a new dataset on the fly?

This paper investigates the hypothesis that some instances of ICL can be understood as implicit implementation of known learning algorithms: in-context learners encode an implicit, context-dependent model in their hidden activations, and train this model on in-context examples in the course of computing these internal activations. As in recent investigations of empirical properties of ICL (Garg et al., 2022; Xie et al., 2022), we study the behavior of transformer-based predictors (Vaswani et al., 2017) on a restricted class of learning problems, here linear regression. Unlike in past work, our goal is not to understand what functions ICL can learn, but how it learns these functions: the specific inductive biases and algorithmic properties of transformer-based ICL.

In Section 3, we investigate theoretically what learning algorithms transformer decoders can implement. We prove by construction that they require only a modest number of layers and hidden units to train linear models: for $d$-dimensional regression problems, with $\mathcal{O}(d)$ hidden size and constant depth, a transformer can implement a single step of gradient descent; and with $\mathcal{O}(d^2)$ hidden size and constant depth, a transformer can update a ridge regression solution to include a single new observation. Intuitively, $n$ steps of these algorithms can be implemented with $n$ times more layers.

In Section 4, we investigate empirical properties of trained in-context learners. We begin by constructing linear regression problems in which learner behavior is under-determined by training data (so different valid learning rules will give different predictions on held-out data). We show that model predictions are closely matched by existing predictors (including those studied in Section 3), and that they transition between different predictors as model depth and training set noise vary, behaving like Bayesian predictors at large hidden sizes and depths. Finally, in Section 5, we present preliminary experiments showing how model predictions are computed algorithmically. We show that important intermediate quantities computed by learning algorithms for linear models, including parameter vectors and moment matrices, can be decoded from in-context learners’ hidden activations.

A complete characterization of which learning algorithms are (or could be) implemented by deep networks has the potential to improve both our theoretical understanding of their capabilities and limitations, and our empirical understanding of how best to train them. This paper offers first steps toward such a characterization: some in-context learning appears to involve familiar algorithms, discovered and implemented by transformers from sequence modeling tasks alone.

Preliminaries

Training a machine learning model involves many decisions, including the choice of model architecture, loss function and learning rule. Since the earliest days of the field, research has sought to understand whether these modeling decisions can be automated using the tools of machine learning itself. Such “meta-learning” approaches typically treat learning as a bi-level optimization problem (Schmidhuber et al., 1996; Andrychowicz et al., 2016; Finn et al., 2017): they define “inner” and “outer” models and learning procedures, then train an outer model to set parameters for an inner procedure (e.g. initializer or step size) to maximize inner model performance across tasks.

Recently, a more flexible family of approaches has gained popularity. In in-context learning (ICL), meta-learning is reduced to ordinary supervised learning: a large sequence model (typically implemented as a transformer network) is trained to map from sequences $[x_1, f(x_1), x_2, f(x_2), \ldots, x_n]$ to predictions $f(x_n)$ (Brown et al., 2020; Olsson et al., 2022; Laskin et al., 2022; Kirsch & Schmidhuber, 2021). ICL does not specify an explicit inner learning procedure; instead, this procedure exists only implicitly through the parameters of the sequence model. ICL has shown impressive results on synthetic tasks and naturalistic language and vision problems (Garg et al., 2022; Min et al., 2022; Zhou et al., 2022; Hollmann et al., 2022).

Past work has characterized what kinds of functions ICL can learn (Garg et al., 2022; Laskin et al., 2022; Müller et al., 2021) and the distributional properties of pretraining that can elicit in-context learning (Xie et al., 2021; Chan et al., 2022). But how ICL learns these functions has remained unclear. What learning algorithms (if any) are implementable by deep network models? Which algorithms are actually discovered in the course of training? This paper takes first steps toward answering these questions, focusing on a widely used model architecture (the transformer) and an extremely well-understood class of learning problems (linear regression).

Transformers (Vaswani et al., 2017) are neural network models that map a sequence of input vectors $\bm{x} = [x_1, \ldots, x_n]$ to a sequence of output vectors $\bm{y} = [y_1, \ldots, y_n]$. Each layer in a transformer maps a matrix $H^{(l)}$ (interpreted as a sequence of vectors) to a sequence $H^{(l+1)}$. To do so, a transformer layer processes each column $\bm{h}_i^{(l)}$ of $H^{(l)}$ in parallel. Here, we are interested in autoregressive (or “decoder-only”) transformer models in which each layer first computes a self-attention:

then applies a feed-forward transformation:

Here $\sigma$ denotes a nonlinearity, e.g. a Gaussian error linear unit (GeLU; Hendrycks & Gimpel, 2016):

$$\sigma(x) = \frac{x}{2}\left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right),$$

and $\lambda$ denotes layer normalization (Ba et al., 2016):

$$\lambda(\bm{x}) = \frac{\bm{x} - \mathbb{E}[\bm{x}]}{\sqrt{\operatorname{Var}[\bm{x}]}},$$

where the expectation and variance are computed across the entries of $\bm{x}$. To map from $\bm{x}$ to $\bm{y}$, a transformer applies a sequence of such layers, each with its own parameters. We use $\theta$ to denote a model’s full set of parameters (the complete collection of $W$ matrices across layers). The three main factors governing the computational capacity of a transformer are its depth (the number of layers), its hidden size (the dimension of the vectors $\bm{h}$), and the number of heads (denoted $m$ above).
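To make the layer structure concrete, the following plain-Python sketch implements GeLU, layer normalization, one causal self-attention head, and a residual feed-forward block. It is an illustration under stated assumptions, not the paper's exact Eq. 1 and Eq. 4: the single head, the placement of layer normalization before the feed-forward map, the absence of a $1/\sqrt{d}$ attention scaling, and all function names are our own choices.

```python
import math

def gelu(x):
    # Exact GeLU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def layer_norm(v, eps=1e-5):
    # Normalize across the entries of v; no learned gain or bias in this sketch.
    mu = sum(v) / len(v)
    var = sum((x - mu) ** 2 for x in v) / len(v)
    return [(x - mu) / math.sqrt(var + eps) for x in v]

def matvec(W, v):
    # W is given as a list of rows.
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def softmax(scores):
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    return [x / z for x in e]

def causal_attention(H, WQ, WK, WV):
    """H is a list of column vectors h_1..h_n; position i only attends to positions <= i."""
    out = []
    for i, h in enumerate(H):
        q = matvec(WQ, h)
        keys = [matvec(WK, H[j]) for j in range(i + 1)]
        vals = [matvec(WV, H[j]) for j in range(i + 1)]
        a = softmax([sum(qk * kj for qk, kj in zip(q, k)) for k in keys])
        out.append([sum(a[j] * vals[j][r] for j in range(i + 1)) for r in range(len(vals[0]))])
    return out

def decoder_layer(H, WQ, WK, WV, WF, W1, W2):
    """Residual self-attention followed by a residual GeLU feed-forward block."""
    att = causal_attention(H, WQ, WK, WV)
    out = []
    for h, b in zip(H, att):
        a = [x + y for x, y in zip(h, matvec(WF, b))]            # h + W^F b
        hidden = [gelu(z) for z in matvec(W2, layer_norm(a))]    # sigma(W2 lambda(a))
        out.append([x + y for x, y in zip(a, matvec(W1, hidden))])
    return out
```

Changing a later column of `H` leaves the outputs at earlier positions unchanged, which is the autoregressive property that the constructions in Section 3 rely on.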

2.2 Training for In-Context Learning

We study transformer models directly trained on an ICL objective. (Some past work has found that ICL also “emerges” in models trained on general text datasets; Brown et al., 2020.) To train a transformer $T$ with parameters $\theta$ to perform ICL, we first define a class of functions $\mathcal{F}$, a distribution $p(f)$ supported on $\mathcal{F}$, a distribution $p(x)$ over the domain of functions in $\mathcal{F}$, and a loss function $\mathcal{L}$. We then choose $\theta$ to optimize the autoregressive objective, where the resulting $T_\theta$ is an in-context learner:

$$\arg\min_{\theta}\; \mathbb{E}_{x_1,\ldots,x_n \sim p(x),\; f \sim p(f)}\left[\sum_{i=1}^{n} \mathcal{L}\Big(f(x_i),\; T_\theta\big([x_1, f(x_1), \ldots, x_{i-1}, f(x_{i-1}), x_i]\big)\Big)\right] \tag{8}$$
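A minimal sketch of the inner sum of Eq. 8 for one sampled function, assuming the squared loss for $\mathcal{L}$ and linear functions with $\bm{w}, \bm{x} \sim \mathcal{N}(0, I)$ as in the main experiments. The transformer $T_\theta$ is replaced by an arbitrary callable `predict(context_x, context_y, query)`, a hypothetical interface; fitting $\theta$ is out of scope.

```python
import random

def sample_linear_task(d, n, rng):
    """Draw f(x) = w^T x with w ~ N(0, I), and inputs x_1..x_n ~ N(0, I)."""
    w = [rng.gauss(0.0, 1.0) for _ in range(d)]
    xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    ys = [sum(wi * xi for wi, xi in zip(w, x)) for x in xs]
    return xs, ys

def icl_objective(predict, xs, ys):
    """Sum over prefixes of the squared error of predicting f(x_i) from [x_1, f(x_1), ..., x_i]."""
    total = 0.0
    for i in range(len(xs)):
        # Context pairs strictly before position i, then the query x_i.
        y_hat = predict(xs[:i], ys[:i], xs[i])
        total += (ys[i] - y_hat) ** 2
    return total
```

The expectation in Eq. 8 would be estimated by averaging `icl_objective` over many calls to `sample_linear_task`.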

2.3 Linear Regression

What learning algorithms can a transformer implement?

For a transformer-based model to solve Eq. 9 by implementing an explicit learning algorithm, that learning algorithm must be implementable via Eq. 1 and Eq. 4 with some fixed choice of transformer parameters $\theta$. In this section, we prove constructively that such parameterizations exist, giving concrete implementations of two standard learning algorithms. These proofs yield upper bounds on how many layers and hidden units suffice to implement (though not necessarily learn) each algorithm. Proofs are given in Appendices A and B; proofs for the elementary operations of Lemma 1 are given in Appendix C.

$\textbf{mov}(H; s, t, i, j, i', j')$: selects the entries of the $s^{\text{th}}$ column of $H$ between rows $i$ and $j$, and copies them into the $t^{\text{th}}$ column ($t \geq s$) of $H$ between rows $i'$ and $j'$, yielding a matrix identical to $H$ except that rows $i'$ to $j'$ of column $t$ now hold the copied entries.

$\textbf{mul}(H; a, b, c, (i,j), (i',j'), (i'',j''))$: in each column $\bm{h}$ of $H$, interprets the entries between $i$ and $j$ as an $a \times b$ matrix $A_1$, and the entries between $i'$ and $j'$ as a $b \times c$ matrix $A_2$, multiplies these matrices together, and stores the result between rows $i''$ and $j''$, yielding a matrix in which each column has the form $[\bm{h}_{:i''-1}, A_1 A_2, \bm{h}_{j'':}]^\top$.

$\textbf{div}(H; (i,j), i', (i'',j''))$: in each column $\bm{h}$ of $H$, divides the entries between $i$ and $j$ by the absolute value of the entry at $i'$, and stores the result between rows $i''$ and $j''$, yielding a matrix in which every column has the form $[\bm{h}_{:i''-1}, \bm{h}_{i:j}/|\bm{h}_{i'}|, \bm{h}_{j'':}]^\top$.

$\textbf{aff}(H; (i,j), (i',j'), (i'',j''), W_1, W_2, b)$: in each column $\bm{h}$ of $H$, applies an affine transformation to the entries between $i$ and $j$ and between $i'$ and $j'$, then stores the result between rows $i''$ and $j''$, yielding a matrix in which every column has the form $[\bm{h}_{:i''-1}, W_1\bm{h}_{i:j} + W_2\bm{h}_{i':j'} + b, \bm{h}_{j'':}]^\top$.

Lemma 1. Each of mov, mul, div and aff can be implemented by a single transformer decoder layer: in Eq. 1 and Eq. 4, there exist matrices $W^Q$, $W^K$, $W^V$, $W^F$, $W_1$ and $W_2$ such that, given a matrix $H$ as input, the layer’s output has the form of the corresponding function output above. We omit the trivial size preconditions, e.g. for mul: $j - i = ab$, $j' - i' = bc$, $j'' - i'' = ac$.
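The four operations can be mimicked on an explicit matrix to check index bookkeeping. The sketch below stores $H$ as a list of columns, edits it in place, and reads row ranges as 0-indexed and half-open (so a $d$-dimensional $\bm{x}$ occupies $(1, 1+d)$, as in Appendix A); that reading, and the function names, are our assumptions. It reproduces only the input-output behavior of each operation, not its transformer-layer implementation from Lemma 1.

```python
def mov(H, s, t, i, j, i2, j2):
    """Copy rows [i, j) of column s into rows [i2, j2) of column t, with t >= s."""
    assert t >= s and j - i == j2 - i2
    H[t][i2:j2] = H[s][i:j]

def mul(H, a, b, c, r1, r2, r3):
    """Per column: (a x b block at r1) times (b x c block at r2), stored row-major at r3."""
    assert r1[1] - r1[0] == a * b and r2[1] - r2[0] == b * c and r3[1] - r3[0] == a * c
    for h in H:
        A1, A2 = h[r1[0]:r1[1]], h[r2[0]:r2[1]]
        h[r3[0]:r3[1]] = [sum(A1[p * b + q] * A2[q * c + r] for q in range(b))
                          for p in range(a) for r in range(c)]

def div(H, r1, k, r3):
    """Per column: divide the entries in rows r1 by |h[k]| and store them in rows r3."""
    assert r1[1] - r1[0] == r3[1] - r3[0]
    for h in H:
        h[r3[0]:r3[1]] = [x / abs(h[k]) for x in h[r1[0]:r1[1]]]

def aff(H, r1, r2, r3, W1, W2, b):
    """Per column: W1 h[r1] + W2 h[r2] + b, stored in rows r3 (W1, W2 as lists of rows)."""
    for h in H:
        u, v = h[r1[0]:r1[1]], h[r2[0]:r2[1]]
        h[r3[0]:r3[1]] = [sum(w * x for w, x in zip(row1, u))
                          + sum(w * x for w, x in zip(row2, v)) + bk
                          for row1, row2, bk in zip(W1, W2, b)]
```

For example, with two columns `[[1, 2, 3, 0, 0], [4, 5, 6, 0, 0]]`, `mul(H, 1, 1, 1, (1, 2), (2, 3), (4, 5))` writes the per-column products 6 and 30 into row 4.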

With these operations, we can implement building blocks of two important learning algorithms.

3.2 Gradient descent

Rather than directly solving linear regression problems by evaluating Eq. 10, a standard approach to learning exploits a generic loss minimization framework, and optimizes the ridge-regression objective in Eq. 9 via gradient descent on parameters $\bm{w}$. This involves repeatedly computing updates:

$$\bm{w}' = \bm{w} - 2\alpha\big(\bm{x}_i(\bm{w}^\top\bm{x}_i - y_i) + \lambda\bm{w}\big)$$

for different examples $(\bm{x}_i, y_i)$, and finally predicting $\bm{w}'^\top\bm{x}_n$ on a new input $\bm{x}_n$. A step of this gradient descent procedure can be implemented by a transformer:

Theorem 1. A transformer can compute Eq. 11 (i.e. the prediction resulting from a single step of gradient descent on an in-context example) with a constant number of layers and $\mathcal{O}(d)$ hidden space, where $d$ is the dimension of the input $\bm{x}$. Specifically, there exist transformer parameters $\theta$ such that, given an input matrix of the form (non-zero rows shown, cf. Appendix A):

$$H^{(0)} = \begin{bmatrix} 0 & y_1 & 0 \\ \bm{x}_1 & 0 & \bm{x}_n \end{bmatrix}, \tag{12}$$

the transformer’s output matrix $H^{(L)}$ contains an entry equal to $\bm{w}'^\top\bm{x}_n$ (Eq. 11) at the column index where $\bm{x}_n$ is input.
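For reference, the quantity in Theorem 1 can be computed directly, without a transformer, for a single in-context example. This is a minimal sketch; the function name and argument order are hypothetical.

```python
def one_step_gd_prediction(x1, y1, x_query, w0, lam, alpha):
    """One gradient step from w0 on (w^T x1 - y1)^2 + lam * ||w||^2, then predict w'^T x_query."""
    err = sum(w * x for w, x in zip(w0, x1)) - y1                       # w^T x1 - y1
    w1 = [w - 2 * alpha * (x * err + lam * w) for w, x in zip(w0, x1)]  # the update above
    return sum(w * x for w, x in zip(w1, x_query)), w1
```

Starting from $\bm{w}_0 = 0$, the step gives $\bm{w}' = 2\alpha\, y_1\bm{x}_1$, so the prediction depends on the data only through $y_1\bm{x}_1$, the one-example case of the moment vector $X^\top Y$ probed in Section 5.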

3.3 Closed-form regression

Another way to solve the linear regression problem is to directly compute the closed-form solution Eq. 10. This is somewhat challenging computationally, as it requires inverting the regularized covariance matrix $X^\top X + \lambda I$. However, one can exploit the Sherman–Morrison formula (Sherman & Morrison, 1950) to reduce the inverse to a sequence of rank-one updates performed example-by-example. For any invertible square $A$ and vectors $\bm{u}, \bm{v}$ with $1 + \bm{v}^\top A^{-1}\bm{u} \neq 0$,

$$\left(A + \bm{u}\bm{v}^\top\right)^{-1} = A^{-1} - \frac{A^{-1}\bm{u}\bm{v}^\top A^{-1}}{1 + \bm{v}^\top A^{-1}\bm{u}}.$$

Because the covariance matrix $X^\top X$ in Eq. 10 can be expressed as a sum of rank-one terms $\bm{x}_i\bm{x}_i^\top$, each involving a single training example $\bm{x}_i$, this can be used to construct an iterative algorithm for computing the closed-form ridge-regression solution.
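The iterative algorithm can be sketched in a few lines, assuming $\lambda > 0$ so that the starting inverse $(\lambda I)^{-1}$ exists; the function name is hypothetical. Each example contributes one Sherman–Morrison update with $\bm{u} = \bm{v} = \bm{x}_i$, so no matrix is ever inverted explicitly.

```python
def ridge_via_sherman_morrison(X, Y, lam):
    """Return (X^T X + lam I)^{-1} X^T Y, built with one rank-one update per example."""
    d = len(X[0])
    # Start from A^{-1} = (lam I)^{-1}.
    A_inv = [[1.0 / lam if r == c else 0.0 for c in range(d)] for r in range(d)]
    b = [0.0] * d  # running moment vector X^T Y
    for x, y in zip(X, Y):
        # A^{-1} stays symmetric, so x^T A^{-1} = (A^{-1} x)^T.
        Ax = [sum(A_inv[r][c] * x[c] for c in range(d)) for r in range(d)]
        denom = 1.0 + sum(xi * ai for xi, ai in zip(x, Ax))
        A_inv = [[A_inv[r][c] - Ax[r] * Ax[c] / denom for c in range(d)] for r in range(d)]
        b = [bi + y * xi for bi, xi in zip(b, x)]
    return [sum(A_inv[r][c] * b[c] for c in range(d)) for r in range(d)]
```

After processing the first $k$ examples, `A_inv` equals $(X_{1:k}^\top X_{1:k} + \lambda I)^{-1}$; for `X = [[1, 0], [0, 1], [1, 1]]`, `Y = [1, 2, 3]` and `lam = 1` the result is `[0.875, 1.375]`, matching a direct solve.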

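Theorem 2. A transformer can predict according to a single Sherman–Morrison update:

$$\bm{w}' = \left((\lambda I)^{-1} - \frac{(\lambda I)^{-1}\bm{x}_1\bm{x}_1^\top(\lambda I)^{-1}}{1 + \bm{x}_1^\top(\lambda I)^{-1}\bm{x}_1}\right)\bm{x}_1 y_1 = \left(\bm{x}_1\bm{x}_1^\top + \lambda I\right)^{-1}\bm{x}_1 y_1,$$

with constant layers and $\mathcal{O}(d^2)$ hidden space. More precisely, there exists a set of transformer parameters $\theta$ such that, given an input matrix of the form in Eq. 12, the transformer’s output matrix $H^{(L)}$ contains an entry equal to $\bm{w}'^\top\bm{x}_n$ (Eq. 14) at the column index where $\bm{x}_n$ is input.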

There are various existing universality results for transformers (Yun et al., 2020; Wei et al., 2021), and for neural networks more generally (Hornik et al., 1989). These generally require very high precision, very deep models, or the use of an external “tape”, none of which appear to be important for in-context learning in the real world. Results in this section establish sharper upper bounds on the capacity required to implement learning algorithms specifically, bringing theory closer to the range where it can explain existing empirical findings. Different theoretical constructions, in the context of meta-learning, have been shown for linear self-attention models (Schlag et al., 2021), or for other neural architectures such as recurrent neural networks (Kirsch & Schmidhuber, 2021). We emphasize that Theorem 1 and Theorem 2 each show the implementation of a single step of an iterative algorithm; these results can be straightforwardly generalized to the multi-step case by “stacking” groups of transformer layers. As described next, it is these iterative algorithms that capture the behavior of real learners.

What computation does an in-context learner perform?

The previous section showed that the building blocks for two specific procedures (gradient descent on the least-squares objective and closed-form computation of its minimizer) are implementable by transformer networks. These constructions show that, in principle, fixed transformer parameterizations are expressive enough to simulate these learning algorithms. When trained on real datasets, however, in-context learners might implement other learning algorithms. In this section, we investigate the empirical properties of trained in-context learners in terms of their behavior. In the framework of Marr’s (2010) “levels of analysis”, we aim to explain ICL at the computational level by identifying the kinds of regression algorithms that transformer-based ICL implements.

Determining which learning algorithms best characterize ICL predictions requires first quantifying the degree to which two predictors agree. We use two metrics to do so:

Given any learning algorithm $\mathcal{A}$ that maps from a set of input–output pairs $D = [\bm{x}_1, y_1, \ldots, \bm{x}_n, y_n]$ to a predictor $f(\bm{x}) = \mathcal{A}(D)(\bm{x})$, we define the squared prediction difference (SPD):

$$\mathrm{SPD}(\mathcal{A}_1, \mathcal{A}_2) = \mathbb{E}_{D,\; \bm{x}' \sim p(\bm{x})}\Big(\mathcal{A}_1(D)(\bm{x}') - \mathcal{A}_2(D)(\bm{x}')\Big)^2,$$

where $D$ is sampled as in Eq. 8. SPD measures agreement at the output level, regardless of the algorithm used to compute this output.
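A Monte Carlo estimate of SPD can be sketched as follows, assuming the noiseless linear setup with $\bm{w}, \bm{x} \sim \mathcal{N}(0, I)$, one test input per sampled dataset, and no normalization. Each "algorithm" is a callable mapping a context dataset `(X, Y)` to a predictor; `spd`, `zero_alg` and `mean_alg` are hypothetical names, and `mean_alg` is a toy baseline rather than one of the paper's reference predictors.

```python
import random

def spd(alg1, alg2, d, n, num_datasets, rng):
    """Monte Carlo estimate of E_{D, x'} (alg1(D)(x') - alg2(D)(x'))^2."""
    total = 0.0
    for _ in range(num_datasets):
        w = [rng.gauss(0.0, 1.0) for _ in range(d)]
        X = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
        Y = [sum(wi * xi for wi, xi in zip(w, x)) for x in X]
        x_new = [rng.gauss(0.0, 1.0) for _ in range(d)]
        f1, f2 = alg1(X, Y), alg2(X, Y)
        total += (f1(x_new) - f2(x_new)) ** 2
    return total / num_datasets

def zero_alg(X, Y):
    return lambda x: 0.0

def mean_alg(X, Y):
    # Toy baseline: always predict the mean in-context label.
    m = sum(Y) / len(Y)
    return lambda x: m
```

Any predictor-producing function, such as the reference predictors listed in Section 4.3, can be plugged in for `alg1` or `alg2`.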

When ground-truth predictors all belong to a known, parametric function class (as with the linear functions here), we may also investigate the extent to which different learners agree on the parameters themselves. Given an algorithm $\mathcal{A}$, we sample a context dataset $D$ as above, and an additional collection of unlabeled test inputs $D_{\mathcal{X}'} = \{\bm{x}_i'\}$. We then compute $\mathcal{A}$’s prediction on each $\bm{x}_i'$, yielding a predictor-specific dataset $D_{\mathcal{A}} = \{(\bm{x}_i', \hat{y}_i)\} = \{(\bm{x}_i', \mathcal{A}(D)(\bm{x}_i'))\}$ encapsulating the function learned by $\mathcal{A}$. Next we compute the implied parameters:

$$\hat{\bm{w}}_{\mathcal{A}} = \arg\min_{\bm{w}} \sum_i \left(\hat{y}_i - \bm{w}^\top\bm{x}_i'\right)^2. \tag{16}$$

We can then quantify agreement between two predictors $\mathcal{A}_1$ and $\mathcal{A}_2$ by computing the distance between their implied weights in expectation over datasets:

$$\mathrm{ILWD}(\mathcal{A}_1, \mathcal{A}_2) = \mathbb{E}_{D,\; D_{\mathcal{X}'}}\left\|\hat{\bm{w}}_{\mathcal{A}_1} - \hat{\bm{w}}_{\mathcal{A}_2}\right\|_2^2.$$

When the predictors are not linear, ILWD measures the difference between the closest linear predictors (in the sense of Eq. 16) to each algorithm. For algorithms that have a linear hypothesis space (e.g. ridge regression), we use the actual value of $\hat{\bm{w}}_{\mathcal{A}}$ instead of the estimated value.
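The implied weights of Eq. 16 are an ordinary least-squares fit to a predictor's outputs. The sketch below solves the normal equations with a small Gaussian-elimination helper, which assumes the test inputs span $\mathbb{R}^d$; it computes the weight distance for a single context dataset and leaves the expectation over datasets to the caller. All names are hypothetical.

```python
def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting (A is a list of rows)."""
    n = len(b)
    M = [list(row) + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            M[r] = [a - f * p for a, p in zip(M[r], M[col])]
    z = [0.0] * n
    for r in reversed(range(n)):
        z[r] = (M[r][n] - sum(M[r][c] * z[c] for c in range(r + 1, n))) / M[r][r]
    return z

def implied_weights(predictor, X_test):
    """Weights w minimizing sum_i (predictor(x_i') - w^T x_i')^2, via the normal equations."""
    d = len(X_test[0])
    y_hat = [predictor(x) for x in X_test]
    XtX = [[sum(x[r] * x[c] for x in X_test) for c in range(d)] for r in range(d)]
    Xty = [sum(x[r] * y for x, y in zip(X_test, y_hat)) for r in range(d)]
    return solve(XtX, Xty)

def weight_distance(pred1, pred2, X_test):
    """Squared distance between implied weights for one dataset; ILWD averages this."""
    w1, w2 = implied_weights(pred1, X_test), implied_weights(pred2, X_test)
    return sum((a - b) ** 2 for a, b in zip(w1, w2))
```

For a predictor that is already linear, `implied_weights` recovers its weights exactly (up to rounding); for a nonlinear one, such as nearest neighbors, it returns the closest linear fit on the chosen test inputs.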

4.2 Experimental Setup

We train a transformer decoder autoregressively on the objective in Eq. 8. For all experiments, we perform a hyperparameter search over depth $L \in \{1, 2, 4, 8, 12, 16\}$, hidden size $H \in \{16, 32, 64, 256, 512, 1024\}$ and heads $M \in \{1, 2, 4, 8\}$. Other hyperparameters are noted in Appendix D. For our main experiments, we found that $L = 16$, $H = 512$, $M = 4$ minimized loss on a validation set. We follow the training guidelines in Garg et al. (2022), and train models for 500,000 iterations, with each in-context dataset consisting of 40 $(\bm{x}, y)$ pairs. For the main experiments we generate data according to $p(\bm{w}) = \mathcal{N}(0, I)$ and $p(\bm{x}) = \mathcal{N}(0, I)$.

4.3 Results

We begin by comparing a $(L = 16, H = 512, M = 4)$ transformer against a variety of reference predictors:

$k$-nearest neighbors: In the uniform variant, models predict $\hat{y}_i = \frac{1}{3}\sum_j y_j$, where $j$ ranges over the three data points closest to $\bm{x}_i$ among those with $j < i$. In the weighted variant, the same three labels are averaged with inverse-squared-distance weights, normalized by the total weight: $\hat{y}_i = \sum_j \|\bm{x}_i - \bm{x}_j\|^{-2} y_j \,/\, \sum_j \|\bm{x}_i - \bm{x}_j\|^{-2}$.

One-pass stochastic gradient descent: $\hat{y}_i = \bm{w}_i^\top\bm{x}_i$, where $\bm{w}_i$ is obtained by stochastic gradient descent on the previous examples with batch size 1: $\bm{w}_i = \bm{w}_{i-1} - 2\alpha\big(\bm{x}_{i-1}(\bm{w}_{i-1}^\top\bm{x}_{i-1} - y_{i-1}) + \lambda\bm{w}_{i-1}\big)$.

One-step batch gradient descent: $\hat{y}_i = \bm{w}_i^\top\bm{x}_i$, where $\bm{w}_i$ is obtained by one step of gradient descent on the batch of previous examples: $\bm{w}_i = \bm{w}_0 - 2\alpha\big(X^\top(X\bm{w}_0 - Y) + \lambda\bm{w}_0\big)$.

Ridge regression: We compute $\hat{y}_i = \bm{w}'^\top\bm{x}_i$, where $\bm{w}' = (X^\top X + \lambda I)^{-1}X^\top Y$. We denote the case of $\lambda = 0$ as OLS; when $X^\top X$ is singular, this is understood as the limit $\lambda \rightarrow 0^+$, i.e. the minimum-norm least-squares solution.
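Two of the reference predictors are sketched below. The zero initialization $\bm{w}_0 = 0$ for one-pass SGD, the handling of an exact input match in the weighted nearest-neighbor variant, and the use of all available neighbors when fewer than three precede $\bm{x}_i$ are our assumptions; the function names are hypothetical.

```python
def knn_predict(X_ctx, Y_ctx, x, k=3, weighted=False):
    """k-nearest-neighbor prediction from the earlier in-context examples."""
    nearest = sorted(
        (sum((a - b) ** 2 for a, b in zip(xj, x)), yj) for xj, yj in zip(X_ctx, Y_ctx)
    )[:k]
    if not weighted:
        return sum(y for _, y in nearest) / len(nearest)
    if nearest[0][0] == 0.0:
        return nearest[0][1]  # assumption: an exact match avoids a 1/0 weight
    weights = [1.0 / d2 for d2, _ in nearest]  # inverse squared distance
    return sum(wt * y for wt, (_, y) in zip(weights, nearest)) / sum(weights)

def one_pass_sgd_predictions(X, Y, alpha, lam):
    """Predict each y_i with w_i, where w_i has seen each earlier example once (batch size 1)."""
    w = [0.0] * len(X[0])  # assumption: w_0 = 0
    preds = []
    for x, y in zip(X, Y):
        pred = sum(wi * xi for wi, xi in zip(w, x))
        preds.append(pred)
        err = pred - y
        w = [wi - 2 * alpha * (xi * err + lam * wi) for wi, xi in zip(w, x)]
    return preds
```

The one-step batch variant follows from the same update applied once to all previous examples, and ridge regression from the Sherman–Morrison recursion of Section 3.3.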

The agreement between the transformer-based ICL and these predictors is shown in Fig. 1(a). As can be seen, there are clear differences in fit to predictors: for almost any number of examples, normalized SPD and ILWD are small between the transformer and OLS predictor (with squared error less than 0.01), while other predictors (especially nearest neighbors) agree considerably less well.

When the number of examples is less than the input dimension $d = 8$, the linear regression problem is under-determined, in the sense that multiple linear models can exactly fit the in-context training dataset. In these cases, OLS regression selects the minimum-norm weight vector, and (as shown in Fig. 1(a)), the in-context learner’s predictions are reliably consistent with this minimum-norm predictor. Why, when presented with an ambiguous dataset, should ICL behave like this particular predictor? One possibility is that, because the weights used to generate the training data are sampled from a Gaussian centered at zero, ICL learns to output the minimum Bayes risk solution when predicting under uncertainty (see Müller et al., 2021). Building on these initial findings, our next set of experiments investigates whether ICL is behaviorally equivalent to Bayesian inference more generally.

To more closely examine the behavior of ICL algorithms under uncertainty, we add noise to the training data: we now present the in-context dataset as a sequence $[\bm{x}_1, f(\bm{x}_1) + \epsilon_1, \ldots, \bm{x}_n, f(\bm{x}_n) + \epsilon_n]$, where each $\epsilon_i \sim \mathcal{N}(0, \sigma^2)$. Recall that ground-truth weight vectors are themselves sampled from a Gaussian distribution; together, this choice of prior and noise means that the learner cannot be certain about the target function with any number of examples.

Standard Bayesian statistics gives that the optimal predictor for minimizing the loss in Eq. 8 is the posterior-mean prediction:

$$\hat{y} = \mathbb{E}[\bm{w} \mid D]^\top\bm{x} = \left(\left(X^\top X + \frac{\sigma^2}{\tau^2} I\right)^{-1} X^\top Y\right)^\top \bm{x}.$$

Note that this predictor has the same form as the ridge predictor from Section 2.3, with the regularization parameter set to $\frac{\sigma^2}{\tau^2}$. In the presence of noisy labels, does ICL match this Bayesian predictor? We explore this by varying both the dataset noise $\sigma^2$ and the prior variance $\tau^2$ (sampling $\bm{w} \sim \mathcal{N}(0, \tau^2 I)$). For these experiments, the SPD values between the in-context learner and various regularized linear models are shown in Fig. 2. As predicted, as the variance ratio $\sigma^2/\tau^2$ increases, the value of the ridge parameter that best explains ICL behavior also increases. For all values of $\sigma^2, \tau^2$, the ridge parameter that gives the best fit to the transformer behavior is also the one that minimizes Bayes risk. These experiments clarify the finding above, showing that ICL in this setting behaviorally matches the minimum-Bayes-risk predictor. We also note that when the noise level $\sigma \rightarrow 0^+$, the Bayes predictor converges to the ordinary least squares predictor. Therefore, the results on noiseless datasets studied at the beginning of this subsection can be viewed as corroborating the finding here in the setting with $\sigma \rightarrow 0^+$.
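The equivalence between the Bayesian predictor and ridge regression can be checked numerically. Under the prior $\bm{w} \sim \mathcal{N}(0, \tau^2 I)$ and noise $\mathcal{N}(0, \sigma^2)$, the posterior mean solves $(I/\tau^2 + X^\top X/\sigma^2)\,\bm{m} = X^\top Y/\sigma^2$ (standard Gaussian conjugacy); multiplying by $\sigma^2$ gives the ridge normal equations with $\lambda = \sigma^2/\tau^2$. The sketch below computes both sides with a small elimination helper; the function names are hypothetical.

```python
def solve(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting (A is a list of rows)."""
    n = len(b)
    M = [list(row) + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            M[r] = [a - f * p for a, p in zip(M[r], M[col])]
    z = [0.0] * n
    for r in reversed(range(n)):
        z[r] = (M[r][n] - sum(M[r][c] * z[c] for c in range(r + 1, n))) / M[r][r]
    return z

def posterior_mean(X, Y, sigma2, tau2):
    """E[w | D] via the precision form (I / tau2 + X^T X / sigma2) m = X^T Y / sigma2."""
    d = len(X[0])
    P = [[(1.0 / tau2 if r == c else 0.0) + sum(x[r] * x[c] for x in X) / sigma2
          for c in range(d)] for r in range(d)]
    rhs = [sum(x[r] * y for x, y in zip(X, Y)) / sigma2 for r in range(d)]
    return solve(P, rhs)

def ridge(X, Y, lam):
    """(X^T X + lam I)^{-1} X^T Y."""
    d = len(X[0])
    A = [[(lam if r == c else 0.0) + sum(x[r] * x[c] for x in X) for c in range(d)]
         for r in range(d)]
    rhs = [sum(x[r] * y for x, y in zip(X, Y)) for r in range(d)]
    return solve(A, rhs)
```

For instance, `posterior_mean(X, Y, sigma2=0.25, tau2=1.0)` and `ridge(X, Y, lam=0.25)` agree up to floating-point rounding for any full-rank `X`.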

The two experiments above evaluated extremely high-capacity models in which (given findings in Section 3) computational constraints are not likely to play a role in the choice of algorithm implemented by ICL. But what about smaller models: does the size of an in-context learner play a role in determining the learning algorithm it implements? To answer this question, we run two final behavioral experiments: one in which we vary the hidden size (while optimizing the depth and number of heads as in Section 4.2), and one in which we vary the depth of the transformer (while optimizing the hidden size and number of heads). These experiments are conducted without dataset noise.

Results are shown in Fig. 3(a). When we vary the depth, learners occupy three distinct regimes: very shallow models (1L) are best approximated by a single step of gradient descent (though not well-approximated in an absolute sense); slightly deeper models (2L-4L) are best approximated by ridge regression; and the deepest models (8L and above) match OLS as observed in Fig. 3(a). A similar phase shift occurs when we vary hidden size in a 16D problem. Interestingly, the hidden size needed to come close to ridge-regression-like solutions is roughly $H \geq 16$ for 8D problems and $H \geq 32$ for 16D problems, suggesting that ICL discovers more efficient ways to use available hidden state than our theoretical constructions, which require $\mathcal{O}(d^2)$. Together, these results show that ICL does not necessarily involve minimum-risk prediction. However, even in models too computationally constrained to perform Bayesian inference, alternative interpretable computations can emerge.

Does ICL encode meaningful intermediate quantities?

Section 4 showed that transformers are a good fit to standard learning algorithms (including those constructed in Section 3) at the computational level. But these experiments leave open the question of how these computations are implemented at the algorithmic level. How do transformers arrive at the solutions in Section 4, and what quantities do they compute along the way? Research on extracting precise algorithmic descriptions of learned models is still in its infancy (Cammarata et al., 2020; Mu & Andreas, 2020). However, we can gain insight into ICL by inspecting learners’ intermediate states: asking what information is encoded in these states, and where.

To do so, we identify two intermediate quantities that we expect to be computed by gradient descent and ridge-regression variants: the moment vector $X^\top Y$ and the (min-norm) least-squares estimated weight vector $\bm{w}_{\text{OLS}}$, each calculated after feeding $n$ exemplars. We take a trained in-context learner, freeze its weights, then train an auxiliary probing model (Alain & Bengio, 2016) to attempt to recover the target quantities from the learner’s hidden representations. Specifically, the probe model takes the hidden states $H^{(l)}$ at a layer $l$ as input, then outputs a prediction for the target variable. We define a probe with position-attention that computes (Appendix E):

We train this probe to minimize the squared error between the predictions and targets $\bm{v}$: $\mathcal{L}(\bm{v}, \hat{\bm{v}}) = \|\bm{v} - \hat{\bm{v}}\|^2$. The probe performs two functions simultaneously: its prediction error on held-out representations determines the extent to which the target quantity is encoded, while its attention mask $\bm{\alpha}$ identifies the location in which the target quantity is encoded. For the FF term, we can insert the function approximator of our choosing; by changing this term we can determine the manner in which the target quantity is encoded, e.g. if FF is a linear model and the probe achieves low error, then we may infer that the target is encoded linearly.
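The forward pass of a position-attention probe can be sketched as follows. This does not reproduce the exact probe of Appendix E: we assume one learnable logit per position, softmax pooling of the frozen hidden states, and an arbitrary decoder `F` (linear or a small MLP); training the logits and `F` by minimizing the squared error is omitted, and all names are hypothetical.

```python
import math

def position_attention_probe(H, scores, F):
    """Pool hidden states with alpha = softmax(scores), then decode the pooled vector with F.

    H: hidden vectors from one frozen layer, one per position.
    scores: one learnable logit per position (hypothetical parameterization).
    Returns the prediction and the attention weights alpha.
    """
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    alpha = [v / z for v in e]
    pooled = [sum(a * h[k] for a, h in zip(alpha, H)) for k in range(len(H[0]))]
    return F(pooled), alpha

def probe_loss(v, v_hat):
    """Squared error ||v - v_hat||^2 used to train the probe."""
    return sum((a - b) ** 2 for a, b in zip(v, v_hat))
```

After training, the position where `alpha` concentrates indicates where the target is stored, and comparing a linear `F` with an MLP `F` indicates how it is encoded.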

For each target, we train a separate probe for the value of the target on each prefix of the dataset: i.e. one probe to decode the value of $\bm{w}$ computed from a single training example, a second probe to decode the value for two examples, etc. Results are shown in Fig. 4. For both targets, a 2-layer MLP probe outperforms a linear probe, meaning that these targets are encoded nonlinearly (unlike the constructions in Section 3). However, probing also reveals similarities. Both targets are decoded accurately deep in the network (but inaccurately in the input layer, indicating that probe success is non-trivial), and probes attend to the correct timestamps when decoding them. As in both constructions, $X^\top Y$ appears to be computed first, becoming predictable by the probe relatively early in the computation (layer 7), while $\bm{w}$ becomes predictable later (around layer 12). For comparison, we additionally report results on a control task in which the transformer predicts $y$s generated with a fixed weight vector $\bm{w} = \mathbf{1}$ (so no ICL is required). Probes applied to these models perform significantly worse at recovering moment matrices (see Appendix E for details).

Conclusion

We have presented a set of experiments characterizing the computations underlying in-context learning of linear functions in transformer sequence models. We showed that these models are capable in theory of implementing multiple linear regression algorithms, that they empirically implement this range of algorithms (transitioning between algorithms depending on model capacity and dataset noise), and finally that they can be probed for intermediate quantities computed by these algorithms.

While our experiments have focused on the linear case, they can be extended to many learning problems over richer function classes—e.g. to a network whose initial layers perform a non-linear feature computation. Even more generally, the experimental methodology here could be applied to larger-scale examples of ICL, especially language models, to determine whether their behaviors are also described by interpretable learning algorithms. While much work remains to be done, our results offer initial evidence that the apparently mysterious phenomenon of in-context learning can be understood with the standard ML toolkit, and that the solutions to learning problems discovered by machine learning researchers may be discovered by gradient descent as well.

Acknowledgements

We thank Evan Hernandez, Andrew Drozdov, and Ed Chi for their feedback on the early drafts of this paper. At MIT, Ekin Akyürek is supported by an MIT-Amazon ScienceHub fellowship and by the MIT-IBM Watson AI Lab.

References

Appendix A Theorem 1

The operations for 1-step SGD with a single exemplar can be expressed as the following chain (see Appendix C for proofs that each operation can be implemented by a transformer layer, Lemma 1):

$\texttt{mov}(;0,1,(1,1+d),(1,1+d))$ (move $\bm{x}$)

$\texttt{aff}(;(1,1+d),(),(1+d,2+d),W_1=\bm{w}^\top)$ ($\bm{w}^\top\bm{x}$)

$\texttt{aff}(;(1+d,2+d),(0,1),(2+d,3+d),W_1=I,W_2=-I)$ ($\bm{w}^\top\bm{x}-y$)

$\texttt{mul}(;d,1,1,(1,1+d),(2+d,3+d),(3+d,3+2d))$ ($\bm{x}(\bm{w}^\top\bm{x}-y)$)

$\texttt{aff}(;(),(),(3+2d,3+3d),b=\bm{w})$ (write $\bm{w}$)

$\texttt{aff}(;(3+d,3+2d),(3+2d,3+3d),(3+3d,3+4d),W_1=I,W_2=\lambda)$ ($\bm{x}(\bm{w}^\top\bm{x}-y)+\lambda\bm{w}$)

$\texttt{aff}(;(3+2d,3+3d),(3+3d,3+4d),(3+2d,3+3d),W_1=I,W_2=-2\alpha)$ ($\bm{w}'$)

$\texttt{mov}(;1,2,(3+2d,3+3d),(3+2d,3+3d))$ (move $\bm{w}'$)

$\texttt{mul}(;1,d,1,(3+2d,3+3d),(1,1+d),(3+3d,4+3d))$ (${\bm{w}'}^\top\bm{x}_2$)
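The chain can also be executed mechanically and compared with a direct one-step gradient-descent computation. The sketch stores the matrix as a list of columns, reads row ranges as 0-indexed and half-open, reads the mov arguments as (source column, target column) per the definition in Section 3, and lets a scalar $W$ stand for a scaled identity; these readings and the helper names are our assumptions. It checks only the arithmetic of the chain, not its transformer implementation.

```python
def mov(H, s, t, src, dst):
    # Copy rows src of column s into rows dst of column t (t >= s).
    assert t >= s
    H[t][dst[0]:dst[1]] = H[s][src[0]:src[1]]

def mul(H, a, b, c, r1, r2, r3):
    # Per column: (a x b block at r1) times (b x c block at r2), written row-major to r3.
    for h in H:
        A1, A2 = h[r1[0]:r1[1]], h[r2[0]:r2[1]]
        h[r3[0]:r3[1]] = [sum(A1[p * b + q] * A2[q * c + r] for q in range(b))
                          for p in range(a) for r in range(c)]

def linear(W, v):
    # A scalar W acts as a scaled identity; otherwise W is a list of rows.
    if isinstance(W, (int, float)):
        return [W * x for x in v]
    return [sum(wij * x for wij, x in zip(row, v)) for row in W]

def aff(H, r1, r2, r3, W1=None, W2=None, b=None):
    # Per column: W1 h[r1] + W2 h[r2] + b, written to r3 (absent terms are skipped).
    for h in H:
        out = [0.0] * (r3[1] - r3[0])
        if W1 is not None:
            out = [o + x for o, x in zip(out, linear(W1, h[r1[0]:r1[1]]))]
        if W2 is not None:
            out = [o + x for o, x in zip(out, linear(W2, h[r2[0]:r2[1]]))]
        if b is not None:
            out = [o + x for o, x in zip(out, b)]
        h[r3[0]:r3[1]] = out

def chain_prediction(x1, y1, x2, w, lam, alpha):
    d = len(x1)
    R = 3 + 4 * d
    H = [[0.0] + list(x1) + [0.0] * (R - 1 - d),   # column 0: [0; x1]
         [y1] + [0.0] * (R - 1),                   # column 1: [y1; 0]
         [0.0] + list(x2) + [0.0] * (R - 1 - d)]   # column 2: [0; x2]
    mov(H, 0, 1, (1, 1 + d), (1, 1 + d))                              # move x
    aff(H, (1, 1 + d), None, (1 + d, 2 + d), W1=[list(w)])            # w^T x
    aff(H, (1 + d, 2 + d), (0, 1), (2 + d, 3 + d), W1=1.0, W2=-1.0)   # w^T x - y
    mul(H, d, 1, 1, (1, 1 + d), (2 + d, 3 + d), (3 + d, 3 + 2 * d))   # x (w^T x - y)
    aff(H, None, None, (3 + 2 * d, 3 + 3 * d), b=list(w))             # write w
    aff(H, (3 + d, 3 + 2 * d), (3 + 2 * d, 3 + 3 * d), (3 + 3 * d, 3 + 4 * d),
        W1=1.0, W2=lam)                                               # x (w^T x - y) + lam w
    aff(H, (3 + 2 * d, 3 + 3 * d), (3 + 3 * d, 3 + 4 * d), (3 + 2 * d, 3 + 3 * d),
        W1=1.0, W2=-2 * alpha)                                        # w'
    mov(H, 1, 2, (3 + 2 * d, 3 + 3 * d), (3 + 2 * d, 3 + 3 * d))      # move w'
    mul(H, 1, d, 1, (3 + 2 * d, 3 + 3 * d), (1, 1 + d), (3 + 3 * d, 4 + 3 * d))  # w'^T x2
    return H[2][3 + 3 * d]

def direct_prediction(x1, y1, x2, w, lam, alpha):
    err = sum(wi * xi for wi, xi in zip(w, x1)) - y1
    w_new = [wi - 2 * alpha * (xi * err + lam * wi) for wi, xi in zip(w, x1)]
    return sum(wi * xi for wi, xi in zip(w_new, x2))
```

Both functions should agree up to floating-point rounding for any inputs of matching dimension.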

We can verify the chain of operations step by step. In each step, we show only the non-zero rows.

mov(;1,0,(1,1+d),(1,1+d))\texttt{mov}(;1,0,(1,1+d),(1,1+d)) (move x{\bm{x}})

$$\begin{bmatrix}0&y_1&0\\ x_1&0&x_2\end{bmatrix}\mapsto\begin{bmatrix}0&y_1&0\\ x_1&x_1&x_2\end{bmatrix}$$

$\texttt{aff}(;(1,1+d),(),(1+d,2+d),W_1=\bm{w}^\top)$ (computes $\bm{w}^\top\bm{x}$)

$$\mapsto\begin{bmatrix}0&y_1&0\\ x_1&x_1&x_2\\ w^\top x_1&w^\top x_1&w^\top x_2\end{bmatrix}$$

$\texttt{aff}(;(1+d,2+d),(0,1),(2+d,3+d),W_1=I,W_2=-I)$ (computes $\bm{w}^\top\bm{x}-y$)

$$\mapsto\begin{bmatrix}0&y_1&0\\ x_1&x_1&x_2\\ w^\top x_1&w^\top x_1&w^\top x_2\\ w^\top x_1&w^\top x_1-y_1&w^\top x_2\end{bmatrix}$$

$\texttt{mul}(;d,1,1,(1,1+d),(2+d,3+d),(3+d,3+2d))$ (computes $\bm{x}(\bm{w}^\top\bm{x}-y)$)

$$\mapsto\begin{bmatrix}0&y_1&0\\ x_1&x_1&x_2\\ w^\top x_1&w^\top x_1&w^\top x_2\\ w^\top x_1&w^\top x_1-y_1&w^\top x_2\\ x_1w^\top x_1&x_1(w^\top x_1-y_1)&x_2w^\top x_2\end{bmatrix}$$

$\texttt{aff}(;(),(),(3+2d,3+3d),b=\bm{w})$ (writes $\bm{w}$)

$$\mapsto\begin{bmatrix}0&y_1&0\\ x_1&x_1&x_2\\ w^\top x_1&w^\top x_1&w^\top x_2\\ w^\top x_1&w^\top x_1-y_1&w^\top x_2\\ x_1w^\top x_1&x_1(w^\top x_1-y_1)&x_2w^\top x_2\\ w&w&w\end{bmatrix}$$

$\texttt{aff}(;(3+d,3+2d),(3+2d,3+3d),(3+3d,3+4d),W_1=I,W_2=\lambda I)$ (computes $\bm{x}(\bm{w}^\top\bm{x}-y)+\lambda\bm{w}$)

$$\mapsto\begin{bmatrix}0&y_1&0\\ x_1&x_1&x_2\\ w^\top x_1&w^\top x_1&w^\top x_2\\ w^\top x_1&w^\top x_1-y_1&w^\top x_2\\ x_1w^\top x_1&x_1(w^\top x_1-y_1)&x_2w^\top x_2\\ w&w&w\\ x_1w^\top x_1+\lambda w&x_1(w^\top x_1-y_1)+\lambda w&x_2w^\top x_2+\lambda w\end{bmatrix}$$

$\texttt{aff}(;(3+2d,3+3d),(3+3d,3+4d),(3+2d,3+3d),W_1=I,W_2=-2\alpha I)$ (computes $\bm{w}'$ and writes it over $\bm{w}$)

$$\mapsto\begin{bmatrix}0&y_1&0\\ x_1&x_1&x_2\\ w^\top x_1&w^\top x_1&w^\top x_2\\ w^\top x_1&w^\top x_1-y_1&w^\top x_2\\ x_1w^\top x_1&x_1(w^\top x_1-y_1)&x_2w^\top x_2\\ w-2\alpha(x_1w^\top x_1+\lambda w)&w'&w-2\alpha(x_2w^\top x_2+\lambda w)\\ x_1w^\top x_1+\lambda w&x_1(w^\top x_1-y_1)+\lambda w&x_2w^\top x_2+\lambda w\end{bmatrix}$$

where $w'=w-2\alpha\big(x_1(w^\top x_1-y_1)+\lambda w\big)$.

$\texttt{mov}(;2,1,(3+2d,3+3d),(3+2d,3+3d))$ (moves $\bm{w}'$ to the third time step)

$$\mapsto\begin{bmatrix}0&y_1&0\\ x_1&x_1&x_2\\ w^\top x_1&w^\top x_1&w^\top x_2\\ w^\top x_1&w^\top x_1-y_1&w^\top x_2\\ x_1w^\top x_1&x_1(w^\top x_1-y_1)&x_2w^\top x_2\\ w-2\alpha(x_1w^\top x_1+\lambda w)&w'&w'\\ x_1w^\top x_1+\lambda w&x_1(w^\top x_1-y_1)+\lambda w&x_2w^\top x_2+\lambda w\end{bmatrix}$$

$\texttt{mul}(;1,d,1,(3+2d,3+3d),(1,1+d),(3+3d,4+3d))$ (computes ${\bm{w}'}^\top\bm{x}_2$)

$$\mapsto\begin{bmatrix}0&y_1&0\\ x_1&x_1&x_2\\ w^\top x_1&w^\top x_1&w^\top x_2\\ w^\top x_1&w^\top x_1-y_1&w^\top x_2\\ x_1w^\top x_1&x_1(w^\top x_1-y_1)&x_2w^\top x_2\\ w-2\alpha(x_1w^\top x_1+\lambda w)&w'&w'\\ x_1w^\top x_1+\lambda w&x_1(w^\top x_1-y_1)+\lambda w&x_2w^\top x_2+\lambda w\\ (w-2\alpha(x_1w^\top x_1+\lambda w))^\top x_1&{w'}^\top x_1&\mathbf{{w'}^\top x_2}\end{bmatrix}$$

We obtain the updated prediction in the last hidden unit of the third time-step. ∎

Since $\bm{w}'$ is written in the hidden states, we may repeat this iteration to obtain $\hat{y}_3={\bm{w}''}^\top\bm{x}_3$, where $\bm{w}''$ is the one-step update $\bm{w}'-2\alpha(\bm{x}_2{\bm{w}'}^\top\bm{x}_2-y_2\bm{x}_2+\lambda\bm{w}')$. A single pass through the dataset therefore requires a total of $\mathcal{O}(n)$ layers, where $n$ is the number of exemplars.
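The operator sequence above can be checked numerically. The sketch below is a hypothetical reference implementation, not the released code: we assume the hidden state is a matrix whose rows are hidden units and whose columns are time steps (0-indexed), that `mov(H, t, s, src, dst)` copies rows `src` of column `s` into rows `dst` of column `t`, that `aff` applies $W_1\bm{h}[\cdot]+W_2\bm{h}[\cdot]+b$ at every time step, and that `mul` reshapes its operands row-major before multiplying. We set the regularization step so that it matches the one-step update $\bm{w}'=\bm{w}-2\alpha(\bm{x}(\bm{w}^\top\bm{x}-y)+\lambda\bm{w})$ in the previous paragraph, i.e. $W_2=\lambda I$ followed by $W_2=-2\alpha I$.

```python
# Hypothetical reference semantics for the Lemma 1 primitives (not the paper's code).
# H[r][t] is hidden unit r at time step t; ranges (i, j) are half-open row ranges.

def mov(H, t, s, src, dst):
    """Copy rows src of column s into rows dst of column t."""
    (i, j), k = src, dst[0]
    for r in range(j - i):
        H[k + r][t] = H[i + r][s]

def aff(H, r1, r2, out, W1=None, W2=None, b=None):
    """At every time step: H[out] = W1 @ H[r1] + W2 @ H[r2] + b (empty ranges are skipped)."""
    i, j = out
    for t in range(len(H[0])):
        acc = list(b) if b is not None else [0.0] * (j - i)
        for rng, W in ((r1, W1), (r2, W2)):
            if rng:
                v = [H[r][t] for r in range(*rng)]
                for a in range(j - i):
                    acc[a] += sum(W[a][c] * v[c] for c in range(len(v)))
        for a in range(j - i):  # write only after all reads, so in-place updates are safe
            H[i + a][t] = acc[a]

def mul(H, a, b, c, r1, r2, out):
    """At every time step: (a x b) @ (b x c) product of row-major blocks, written row-major."""
    for t in range(len(H[0])):
        A = [H[r][t] for r in range(*r1)]
        B = [H[r][t] for r in range(*r2)]
        P = [sum(A[p * b + q] * B[q * c + s] for q in range(b)) for p in range(a) for s in range(c)]
        for n, val in enumerate(P):
            H[out[0] + n][t] = val

def one_sgd_step_prediction(x1, y1, x2, w, lam, alpha):
    """Run the Appendix A operator sequence and return w'^T x2 at the third time step."""
    d = len(w)
    H = [[0.0] * 3 for _ in range(3 + 4 * d)]  # columns: (0, x1), (y1, 0), (0, x2)
    H[0][1] = y1
    for r in range(d):
        H[1 + r][0], H[1 + r][2] = x1[r], x2[r]
    I = [[float(p == q) for q in range(d)] for p in range(d)]

    def scaled(M, k):
        return [[k * v for v in row] for row in M]

    mov(H, 1, 0, (1, 1 + d), (1, 1 + d))                                     # move x1
    aff(H, (1, 1 + d), (), (1 + d, 2 + d), W1=[list(w)])                      # w^T x
    aff(H, (1 + d, 2 + d), (0, 1), (2 + d, 3 + d), W1=[[1.0]], W2=[[-1.0]])   # w^T x - y
    mul(H, d, 1, 1, (1, 1 + d), (2 + d, 3 + d), (3 + d, 3 + 2 * d))           # x (w^T x - y)
    aff(H, (), (), (3 + 2 * d, 3 + 3 * d), b=list(w))                         # write w
    aff(H, (3 + d, 3 + 2 * d), (3 + 2 * d, 3 + 3 * d), (3 + 3 * d, 3 + 4 * d),
        W1=I, W2=scaled(I, lam))                                              # x (w^T x - y) + lam w
    aff(H, (3 + 2 * d, 3 + 3 * d), (3 + 3 * d, 3 + 4 * d), (3 + 2 * d, 3 + 3 * d),
        W1=I, W2=scaled(I, -2 * alpha))                                       # w' overwrites w
    mov(H, 2, 1, (3 + 2 * d, 3 + 3 * d), (3 + 2 * d, 3 + 3 * d))              # move w'
    mul(H, 1, d, 1, (3 + 2 * d, 3 + 3 * d), (1, 1 + d), (3 + 3 * d, 4 + 3 * d))  # w'^T x2
    return H[3 + 3 * d][2]
```

For example, `one_sgd_step_prediction([1.0, 2.0], 3.0, [0.5, -1.0], [0.1, 0.2], 0.5, 0.1)` should agree with computing $\bm{w}'$ directly and taking ${\bm{w}'}^\top\bm{x}_2$.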

As an empirical demonstration of this procedure, the accompanying code release contains a reference implementation of SGD defined in terms of the base primitive, provided at the anonymous links https://icl1.s3.us-east-2.amazonaws.com/theory/{primitives,sgd,ridge}.py (to preserve anonymity, we did not provide the library dependencies). This implementation predicts $\hat{y}_n=\bm{w}_n^\top\bm{x}_n$, where $\bm{w}_n$ is the weight vector resulting from $n-1$ consecutive SGD updates on the previous examples. It can be verified there that the procedure requires $\mathcal{O}(n+d)$ hidden space. Note that this is not $\mathcal{O}(nd)$, because the space used for intermediate variables can be reused in the next iteration; the $\bm{w}'$ step above, which overwrites $\bm{w}$ in place, is an example.
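For comparison, the learning algorithm that this construction unrolls can be written directly. The sketch below does not reproduce the released implementation or its interface; it assumes the ridge loss $(\bm{w}^\top\bm{x}-y)^2+\lambda\|\bm{w}\|^2$, a fixed step size `alpha`, and a zero initial weight vector unless `w0` is given, and it returns $\hat{y}_n=\bm{w}_n^\top\bm{x}_n$ computed before the $n$-th label is used. The function name is hypothetical.

```python
def sgd_icl_predictions(xs, ys, lam, alpha, w0=None):
    """Return [y_hat_1, ..., y_hat_n], where y_hat_n = w_n^T x_n and w_n results from
    n - 1 SGD steps on the ridge loss (w^T x - y)^2 + lam * ||w||^2."""
    d = len(xs[0])
    w = list(w0) if w0 is not None else [0.0] * d
    preds = []
    for x, y in zip(xs, ys):
        y_hat = sum(wi * xi for wi, xi in zip(w, x))   # predict before seeing y
        preds.append(y_hat)
        resid = y_hat - y
        # one SGD step: w <- w - 2 * alpha * (x (w^T x - y) + lam * w)
        w = [wi - 2 * alpha * (xi * resid + lam * wi) for wi, xi in zip(w, x)]
    return preds
```

Each example is consumed once, so the number of updates (and, in the construction, the number of layer blocks) grows linearly with $n$.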

Appendix B Theorem 2

We provide a construction similar to that of Theorem 1 (see Appendix C for the transformer implementation of these operations, and Section C.6 specifically for div):

$\texttt{mov}(;1,0,(1,1+d),(1,1+d))$ (move $\bm{x}_1$)

$\texttt{mul}(;d,1,1,(1,1+d),(0,1),(1+d,1+2d))$ ($\bm{x}_1y_1$)

$\texttt{aff}(;(),(),(1+2d,1+2d+d^2),b=\frac{I}{\lambda})$ ($A_0^{-1}=\frac{I}{\lambda}$)

$\texttt{mul}(;d,d,1,(1+2d,1+2d+d^2),(1,1+d),(1+2d+d^2,1+3d+d^2))$ ($A_0^{-1}\bm{u}=\frac{I}{\lambda}\bm{x}_1$)

$\texttt{mul}(;1,d,d,(1,1+d),(1+2d,1+2d+d^2),(1+3d+d^2,1+4d+d^2))$ ($\bm{v}^\top A_0^{-1}=\bm{x}_1^\top\frac{I}{\lambda}$)

$\texttt{mul}(;d,1,d,(1+2d+d^2,1+3d+d^2),(1+3d+d^2,1+4d+d^2),(1+4d+d^2,1+4d+2d^2))$ ($A_0^{-1}\bm{u}\bm{v}^\top A_0^{-1}=\frac{I}{\lambda}\bm{x}_1\bm{x}_1^\top\frac{I}{\lambda}$)

$\texttt{mul}(;1,d,1,(1+3d+d^2,1+4d+d^2),(1,1+d),(1+4d+2d^2,2+4d+2d^2))$ ($\bm{v}^\top A_0^{-1}\bm{u}=\bm{x}_1^\top\frac{I}{\lambda}\bm{x}_1$)

$\texttt{aff}(;(1+4d+2d^2,2+4d+2d^2),(),(1+4d+2d^2,2+4d+2d^2),W_1=1,b=1)$ ($1+\bm{v}^\top A_0^{-1}\bm{u}=1+\bm{x}_1^\top\frac{I}{\lambda}\bm{x}_1$)

$\texttt{div}(;(1+4d+d^2,1+4d+2d^2),1+4d+2d^2,(2+4d+2d^2,2+4d+3d^2))$ (right term, $\frac{A_0^{-1}\bm{u}\bm{v}^\top A_0^{-1}}{1+\bm{v}^\top A_0^{-1}\bm{u}}$)

$\texttt{aff}(;(1+2d,1+2d+d^2),(2+4d+2d^2,2+4d+3d^2),(1+2d,1+2d+d^2),W_1=I,W_2=-I)$ ($A_1^{-1}$)

$\texttt{mul}(;d,d,1,(1+2d,1+2d+d^2),(1,1+d),(2+4d+3d^2,2+5d+3d^2))$ ($A_1^{-1}\bm{x}_1$)

$\texttt{mul}(;d,1,1,(2+4d+3d^2,2+5d+3d^2),(0,1),(2+4d+3d^2,2+5d+3d^2))$ ($\bm{w}'=A_1^{-1}\bm{x}_1y_1$)

$\texttt{mov}(;2,1,(2+4d+3d^2,2+5d+3d^2),(2+4d+3d^2,2+5d+3d^2))$ (move $\bm{w}'$)

$\texttt{mul}(;1,d,1,(2+4d+3d^2,2+5d+3d^2),(1,1+d),(2+5d+3d^2,3+5d+3d^2))$ (${\bm{w}'}^\top\bm{x}_2$)

Note that, in contrast to Appendix A, we need $\mathcal{O}(d^2)$ space to store $d\times d$ matrices and implement matrix multiplications. The overall required hidden size is therefore $\mathcal{O}(d^2)$.

As with Theorem 1, generalizing this construction to multiple iterations requires at least $\mathcal{O}(n)$ layers, since we repeat the process for each subsequent exemplar.
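The operator sequence above computes one Sherman-Morrison rank-one update of $A^{-1}$, starting from $A_0^{-1}=I/\lambda$ with $\bm{u}=\bm{v}=\bm{x}_1$, followed by $\bm{w}'=A_1^{-1}\bm{x}_1y_1$. The pure-Python sketch below is our own illustration (the names `sherman_morrison` and `ridge_online` are hypothetical): assuming the update is repeated for each exemplar while accumulating $\bm{b}_n=\sum_{i\le n}\bm{x}_iy_i$, the weights $\bm{w}_n=A_n^{-1}\bm{b}_n$ with $A_n=\lambda I+\sum_{i\le n}\bm{x}_i\bm{x}_i^\top$ are the ridge solution after $n$ examples.

```python
def matvec(A, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]

def sherman_morrison(A_inv, u, v):
    """(A + u v^T)^{-1} = A^{-1} - (A^{-1} u)(v^T A^{-1}) / (1 + v^T A^{-1} u)."""
    n = len(u)
    Au = matvec(A_inv, u)                                               # A^{-1} u
    vA = [sum(v[k] * A_inv[k][j] for k in range(n)) for j in range(n)]  # v^T A^{-1}
    denom = 1.0 + sum(vi * ai for vi, ai in zip(v, Au))                  # 1 + v^T A^{-1} u
    return [[A_inv[i][j] - Au[i] * vA[j] / denom for j in range(n)] for i in range(n)]

def ridge_online(xs, ys, lam):
    """Return [w_1, ..., w_n], where w_n is the ridge solution on the first n examples."""
    d = len(xs[0])
    A_inv = [[1.0 / lam if i == j else 0.0 for j in range(d)] for i in range(d)]  # A_0^{-1} = I / lam
    b = [0.0] * d
    ws = []
    for x, y in zip(xs, ys):
        A_inv = sherman_morrison(A_inv, x, x)        # A_n = A_{n-1} + x x^T
        b = [bi + xi * y for bi, xi in zip(b, x)]    # b_n = b_{n-1} + x y
        ws.append(matvec(A_inv, b))                  # w_n = A_n^{-1} b_n
    return ws
```

Storing $A_n^{-1}$ takes $d^2$ numbers per step, which is where the $\mathcal{O}(d^2)$ hidden size comes from.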

Appendix C Lemma 1

All of the operators mentioned in this lemma share a common computational structure, and can in fact be implemented as special cases of a “base primitive” we call RAW (for Read-Arithmetic-Write). This operator may also be useful for future work aimed at implementing other algorithms.

The structure of our proof of Lemma 1 is as follows:

Definition of dot, aff, mov in terms of RAW.

Implementation of RAW in terms of transformer parameters.

Brief discussion of how to parallelize RAW, making it possible to implement mul.

Separate proof for div using LayerNorm.

At a high level, all of the primitives in Lemma 1 involve a similar sequence of operations:

dot and aff read from two subsets of indices in the current hidden state $\bm{h}_t$, while mov reads from a previous hidden state $\bm{h}_{t'}$. (For notational convenience, we use $\bm{h}$ to refer to the sequence of hidden states, instead of $H$ in Eq. 1; $\bm{h}_{t'}$ is the hidden state at time step $t'$.) This selection is straightforwardly implemented using the attention component of a transformer layer.

We may notate this reading operation as follows:

Here $\texttt{r}$ denotes a list of indices to read from, and $K$ denotes a map from current time steps to target time steps. For convenience, we use NumPy-like notation to denote indexing into a vector with another vector:

$\bm{x}[\cdot]$ denotes Python-style index notation: the resulting vector $\bm{x}'=\bm{x}[\texttt{r}]$ has entries $\bm{x}'_i=\bm{x}_{\texttt{r}_i}$ for $i=1,\dots,|\texttt{r}|$.

The first step of our proof below shows that the attention output $\bm{a}^{(l)}$ can compute the expression above.

This step takes different forms for aff and mul (mov ignores values at the current timestep altogether).

The second step of the proof below computes these operations inside the MLP component of the transformer layer.

Once the underlying elementwise operation is calculated, the operator needs to write these values to some indices of the current hidden state, defined by a list of indices $\texttt{w}$. Writing might be preceded by a reduction step (e.g., for computing dot products), which can be expressed generically as a linear operator $W_o$. The final form of the computation is thus:

Here, $\leftarrow$ means that the other indices $i\notin\texttt{w}$ are copied from $\bm{h}^{l-1}$.
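Putting the Read, Arithmetic, and Write steps together, RAW can be summarized by the reference semantics below. This is a sketch of our reading of the prose, not the paper's equation: we assume $\bm{h}'_i[\texttt{w}]=W_o\big((W_a\,\mathrm{mean}_{j\in K(i)}\bm{h}_j[\texttt{r}])\circledast(W\bm{h}_i[\texttt{s}])\big)$, with uniform attention over $K(i)$, a zero read when $K(i)$ is empty, and all other indices copied. The function name `raw` and its argument names are hypothetical. We only check $j\le i$ (causal attention), which also admits the self-attention pattern $K(i)=i$ used later in the construction.

```python
import operator  # pass operator.add or operator.mul as `op`

def matvec(W, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(a * b for a, b in zip(row, v)) for row in W]

def raw(h, op, r, s, w, K, W_a, W, W_o):
    """Hypothetical reference semantics of RAW (Read-Arithmetic-Write).

    h[i] is the hidden vector at time step i; r, s, w are lists of indices;
    K(i) returns the time steps that position i reads from.
    """
    out = [list(vec) for vec in h]                 # indices not in w are copied unchanged
    for i in range(len(h)):
        js = K(i)
        assert all(j <= i for j in js), "causal attention cannot read later time steps"
        if js:                                     # Read: uniform attention over K(i)
            read = [sum(h[j][k] for j in js) / len(js) for k in r]
        else:                                      # empty K(i): attention output is zero
            read = [0.0] * len(r)
        left = matvec(W_a, read)                   # W_a applied to the attended values
        right = matvec(W, [h[i][k] for k in s])    # W applied to the current hidden state
        vals = matvec(W_o, [op(x, y) for x, y in zip(left, right)])  # Arithmetic, then reduction W_o
        for idx, val in zip(w, vals):              # Write
            out[i][idx] = val
    return out
```

For instance, a dot product of $\bm{h}_i[0{:}2]$ and $\bm{h}_i[2{:}4]$ corresponds to `op=operator.mul`, `K=lambda i: [i]`, identity `W_a` and `W`, and the reduction `W_o=[[1.0, 1.0]]`.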

C.2 RAW Operator Definition

We additionally require that $j\in K(i)\implies j<i$ (since self-attention is causal).

(For simplicity, we do not include bias terms in the linear projections $W_o$, $W_a$, $W$; we can always assume accompanying bias parameters $\bm{b}_o,\bm{b}_a,\bm{b}$ when needed.)

C.3 Reducing Lemma 1 operators to RAW operator

Given this operator, we can define each primitive in Lemma 1 using a single RAW operator, except mul and div. Instead of the matrix multiplication operator mul, we first show the dot product dot (a special case of mul), and later in the proof (Section C.5) we argue that these dot products can be parallelized to obtain mul. We show how to implement div separately in Section C.6.

We can define mov, aff, and the dot-product case of mul in Lemma 1 using a single RAW operator.

Follows immediately by substituting parameters into Eq. 34. ∎

C.4 Implementing RAW

A single transformer layer can implement the RAW operator: there exist settings of the transformer parameters such that, given an arbitrary hidden matrix $\bm{h}$ as input, the transformer computes $\bm{h}'$ (Eq. 34) as output.

Our proof proceeds in stages. We begin by specifying the initial embedding and positional embedding layers, which construct inputs to the main transformer layer with the necessary positional information and scratch space. Next, we prove three useful procedures for bypassing (or exploiting) nonlinearities in the feed-forward component of the transformer. Finally, we provide values for the remaining parameters, showing that we can implement the Elementwise and Reduction steps described above.

Rather than inserting the input matrix $\bm{h}$ directly into the transformer layer, we assume (as is standard) the existence of a linear embedding layer. We can set this layer to pad the input, providing extra scratch space that will be used by later steps of our implementation.

We define the embedding matrix $W_e$ as:

Implementing RAW ultimately requires controlling which position attends to which position in each layer. For example, we may wish to have layers in which each position attends only to the previous position, or in which even positions attend to other even positions. We can use position embeddings $\bm{p}_i$ to control attention weights. In a standard transformer, the position embedding matrix is a constant matrix that is added to the inputs after the embedding layer (before the first layer), so the actual input to the transformer is:

We will use these position embeddings to encode the time-step map $K$. To do this, we use $2p$ units per layer ($p$ will be defined momentarily): $p$ units encode attention keys $\mathbf{k}_i$, and the other $p$ encode queries $\mathbf{q}_i$.

We define the position embedding matrix as follows:

With $K$ encoded in the positional embeddings, the transformer matrices $W_Q$ and $W_K$ are easy to define: they just need to retrieve the corresponding embedding values:

The constructions used in this paper rely on two specific time-step maps $K$, each of which can be implemented compactly in terms of $\bm{k}$ and $\bm{q}$:

where $N$ is a sufficiently large number. In this case, the output of the attention mechanism will be:

For simpler patterns, such as attention to a specific token:

from which it can be verified (using the same procedure as in Case 1) that the desired attention pattern is produced.

We can also make $K(i)$ the empty set by assuming the softmax has an extra (“imaginary”) time step, obtained by prepending a 0 to the attention vector post hoc (Chen et al., 2021).
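The role of the large $N$ can be checked with a plain softmax. Since the exact key and query encodings are not reproduced here, the sketch assumes that the position embeddings yield a logit of $+N$ for $j\in K(i)$ and $-N$ for the other causal positions, and it models the “imaginary” time step as an extra slot with logit 0 and zero value. Under these assumptions, attention is approximately uniform over $K(i)$, and when $K(i)$ is empty essentially all of it goes to the zero-valued slot. The function names are hypothetical.

```python
import math

def softmax(scores):
    """Numerically stable softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

def pattern_attention(i, K_i, N=50.0):
    """Attention weights of query position i, assuming logit +N for j in K(i) and
    -N for the other causal positions 0..i. Index 0 of the result is the prepended
    'imaginary' slot (logit 0, zero value); index j + 1 is position j."""
    scores = [0.0] + [N if j in K_i else -N for j in range(i + 1)]
    return softmax(scores)
```

With $N=50$, the weight leaking to positions outside $K(i)$ is of order $e^{-100}$, so the approximation error is far below floating-point precision.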

Cumulatively, the parameter matrices defined in this subsection implement the Read with Attention component of the RAW operator.

C.4.2 Handling & Utilizing Nonlinearities

The mul operator requires elementwise multiplication of quantities stored in hidden states. While transformers are often thought of as only straightforwardly implementing affine transformations on hidden vectors, their nonlinearities in fact allow elementwise multiplication to a high degree of approximation. We begin by observing the following property of the GeLU activation function in the MLP layers of the Transformer network:

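The GeLU nonlinearity can be used to perform multiplication: specifically, $\sqrt{\pi/2}\,\big(\operatorname{GeLU}(x+y)-\operatorname{GeLU}(x)-\operatorname{GeLU}(y)\big)\approx xy$ for small $x$ and $y$.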

A standard implementation of the GeLU nonlinearity is defined as follows:

For small $x$ and $y$, the terms beyond second order become negligible. By scaling inputs down by a constant before the GeLU layer, and scaling them up afterwards, models may use the GeLU operator to perform elementwise multiplication. ∎
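A minimal numerical check of this multiplication trick, assuming the exact (erf-based) GeLU rather than its tanh approximation: since $\operatorname{GeLU}(u)=u\Phi(u)=\frac{u}{2}+\frac{u^2}{\sqrt{2\pi}}+\mathcal{O}(u^4)$, the linear parts cancel and $\sqrt{\pi/2}\,(\operatorname{GeLU}(u+v)-\operatorname{GeLU}(u)-\operatorname{GeLU}(v))\approx uv$. The scale `c` is a hypothetical parameter that plays the role of the down- and up-scaling described above.

```python
import math

def gelu(x):
    """Exact GeLU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_mul(x, y, c=1e-3):
    """Approximate x * y with three GeLU evaluations.

    GeLU(u) = u/2 + u^2/sqrt(2*pi) + O(u^4), so the linear parts cancel in
    GeLU(u + v) - GeLU(u) - GeLU(v) ~= 2*u*v/sqrt(2*pi).
    """
    u, v = c * x, c * y                              # scale inputs into the small-input regime
    s = gelu(u + v) - gelu(u) - gelu(v)
    return math.sqrt(math.pi / 2.0) * s / (c * c)    # sqrt(pi/2) = sqrt(2*pi)/2; undo the scaling
```

Smaller `c` reduces the truncation error but increases relative rounding error, so in finite precision there is a practical lower limit on the scale.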

We can generalize this proof to other smooth functions, as we discuss further in Appendix G. Previous work also shows that, in practice, Transformers with ReLU activations use their nonlinearities to obtain multiplication in other settings.

When implementing the aff operator, we have the opposite problem: we would like the output of the addition to be transmitted to the output of the transformer layer without nonlinearities. Fortunately, for large inputs the GeLU nonlinearity is very close to linear, so to bypass it, it suffices to add a large $N$ to the inputs:

The GeLU nonlinearity can be bypassed: specifically, $\operatorname{GeLU}(N+x)-N\approx x$ for large $N$.

For all versions of the RAW operator, it is additionally necessary to bypass the LayerNorm operation. The following formula will be helpful for this:

Let $N$ be a large number and $\lambda$ the LayerNorm function. Then the following approximation holds:

By adding a large number $N$ to two padding locations and summing the part of the hidden state that we want to pass through LayerNorm, we let $\bm{x}$ pass through LayerNorm up to a known scale factor. This addition can be done in the transformer's feed-forward computation (with parameter $W^F$) prior to LayerNorm. The multiplication by $\sqrt{2/L}\,N$ that undoes the scaling can be done in the first layer of the MLP, after which the linear layer can output or use $\bm{x}$. For convenience, we will henceforth omit the LayerNorm operation when it is not needed.
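Both bypass tricks can be sketched numerically. For GeLU we use $\operatorname{GeLU}(N+x)-N\approx x$. For LayerNorm, we read “two padding locations” as entries $+N$ and $-N$, and “summing the part of the hidden state” as an extra entry $-\sum_k x_k$ that makes the mean zero; this layout, the function names, and the LayerNorm without learned scale or shift are assumptions of the sketch. The standard deviation is then approximately $N\sqrt{2/L}$, so rescaling by $\sqrt{2/L}\,N$ recovers $\bm{x}$.

```python
import math

def gelu(x):
    """Exact GeLU: x * Phi(x)."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_bypass(x, N=1e3):
    """GeLU(N + x) - N ~= x for large N, because Phi(N + x) ~= 1."""
    return gelu(N + x) - N

def layer_norm(v):
    """LayerNorm without learned scale/shift (and without epsilon)."""
    L = len(v)
    mu = sum(v) / L
    var = sum((a - mu) ** 2 for a in v) / L
    return [(a - mu) / math.sqrt(var) for a in v]

def layer_norm_bypass(x, N=1e6):
    """Pass x through LayerNorm using the assumed padding [x, -sum(x), N, -N]."""
    v = list(x) + [-sum(x), N, -N]               # zero mean; variance ~= 2 N^2 / L
    L = len(v)
    scale = math.sqrt(2.0 / L) * N               # undo the ~1 / (N sqrt(2/L)) normalization
    return [scale * o for o in layer_norm(v)[:len(x)]]
```

The relative error of the LayerNorm bypass is of order $\|\bm{x}\|^2/N^2$, so it shrinks quadratically as $N$ grows.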

We may make each of these operations as precise as desired (or allowed by system precision). With them defined, we are ready to specify the final components of the RAW operator.

C.4.3 Parameterizing RAW

We want to show that a Transformer layer as defined above, parameterized by $\theta=\{W_{\mathrm{f}},W_1,W_2,(W^Q,W^K,W^V)_m\}$, can approximate the RAW operator defined in Eq. 25 arbitrarily well. We provide a step-by-step construction and define the parameters in $\theta$. We begin by recalling the transformer layer definition:

We use only $m=2$ attention heads for this construction. We show in Eq. 40 that, by setting the key and query matrices, we can make attention attend uniformly according to a pattern. Assume that the first head's parameters $W_1^Q,W_1^K$ have been set in this way to obtain the pattern function $K$.

Now we set the remaining attention parameters $W_1^V,W_2^Q,W_2^K,W_2^V$ and show that we can make the $\bm{a}_i+\bm{h}^{(l)}_i$ term in Eq. 4 contain the corresponding term of Eq. 25 in some unused indices $\texttt{t}$, such that:

Then the corresponding term of the RAW operator can be obtained from the first head's output. To achieve this, we make $W_a$ part of the attention value network, such that $W_1^V$ is a sparse matrix that is 0 everywhere except:

Now our first head stores the right-hand term of Eq. 61 in the indices $\texttt{t}$. However, adding the residual term $\bm{h}_i^{(l)}$ would change this. To remove the residual term, we use another head to output $-\bm{h}_i^{(l)}$, by setting $W_2^Q,W_2^K$ such that $K(i)=i$, and setting $W_2^V$ (similar to Eq. 42):

We have now defined $(W^Q,W^K,W^V)_{1,2}$ and $W^{\mathrm{f}}$, and obtained the first term of Eq. 25 in $(\bm{a}_i+\bm{h}^{(l)}_i)_{t'\in\texttt{t}}$.

Now we want to calculate the term inside the parentheses in Eq. 25. We calculate it with the MLP layer, store it in $\bm{m}_i$, and subtract the first term. Let us denote the input to the MLP as $\bm{x}_i=\bm{a}_i+\bm{h}^{(l)}_i$, the output of the first layer as $\bm{u}_i$, the output of the nonlinearity as $\bm{v}_i$, and the final output as $\bm{m}_i$. The entries of $\bm{m}_i$ will be:

We define the MLP layer, through $W_1$ and $W_2$, to combine the attention term calculated above with part of the current hidden state. We assume the LayerNorm is bypassed using Lemma 6.

We show this separately for the $+$ and $\odot$ operators.

If the operator is $\circledast=+$, the first layer of the MLP calculates the second term in Eq. 25, overwrites the space where the attention output term (Eq. 61) is written, and adds a large positive bias term $N$ to bypass GeLU, as explained in Lemma 4. We use an available space $\hat{\texttt{t}}$ in $\bm{x}_i$ of the same size as $\texttt{t}$.

This can be done by setting $W_1$ (the weight matrix of the first MLP layer) to zero except at the indices below:

Note that the second term is added to make the unused indices (those outside $\texttt{t}\cup\texttt{w}\cup\hat{\texttt{t}}$) become zero after the GeLU, which outputs zero for large negative values. Since we added a large positive term, GeLU behaves like a linear function on the entries we use. Thus we have,

Therefore, $\bm{m}_i[\texttt{w}]=W_o\bm{x}_i[\texttt{t}]+W_oW\bm{h}_i^{l}[\texttt{s}]-\bm{x}_i[\texttt{w}]$, which is what we promised in Eq. 76 for the $+$ case. Adding back the residual term $\bm{x}_i$ (Eq. 61), the output of this layer can be written as:

If the operator is $\circledast=\odot$, we need three extra blocks of hidden units, each of size $|\texttt{t}|$; we name the extra indices $\texttt{t}_a$, $\texttt{t}_b$, $\texttt{t}_c$, and the output space $\texttt{w}$. The pre-activation $\bm{u}_i$ gets the entries below, which let us apply the GeLU multiplication approximation, where $N$ is a large number:

All of these operations are linear, so they can be implemented by setting $W_1$ to zero except at the entries below:

With these approximations, the resulting $\bm{v}$ becomes:

Now we can use the GeLU trick in Lemma 4 by setting $W_2$:

With this, $\bm{m}_i[\texttt{w}]=W_o\bm{x}_i[\texttt{t}]\odot W_oW\bm{h}_i^{l-1}[\texttt{s}]-\bm{x}_i[\texttt{w}]$, and

We have used $4|\texttt{t}|$ space for the internal computation of this operation and $|\texttt{w}|$ space to write the final result. This shows that the RAW operator can be implemented by setting the parameters of a Transformer.

C.5 Parallelizing the RAW operator

If $K$ is constant, the operators are independent (i.e., $(\texttt{r}_i\cup\texttt{s}_i\cup\texttt{w}_i)\cap\texttt{w}_{j\neq i}=\emptyset$), and there is $\sum_k(4|\texttt{t}_k|+|\texttt{w}_k|)$ available space in the hidden state, then a Transformer layer can apply $k$ such RAW operations in parallel by setting different regions of the $W_1$, $W_2$, $W_f$, and $(W^V)_k$ matrices.

From the construction above, it is straightforward to modify the definition of the RAW operator to perform $k$ operations, since under the conditions of the lemma the matrix indices used in Section C.4.3 for different operations do not overlap. ∎

This makes it possible to construct a Transformer layer that implements not only vector-vector dot products but general matrix-matrix products, as required by mul. Hence mul can be implemented with a single Transformer layer.
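The following sketch illustrates why mul needs only one layer once dot products can run in parallel: an $a\times b$ by $b\times c$ product decomposes into $ac$ dot products whose output indices are pairwise disjoint, which is the independence condition of the lemma. The row-major storage layout and the function names are assumptions of this illustration.

```python
def dot(u, v):
    """Dot product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))

def mul_via_dots(A_flat, B_flat, a, b, c):
    """Compute the a x c product of A (a x b) and B (b x c), both stored row-major as
    flat lists, as a*c independent dot products with disjoint output indices."""
    out = [0.0] * (a * c)
    for p in range(a):
        row = A_flat[p * b:(p + 1) * b]        # row p of A
        for q in range(c):
            col = B_flat[q::c]                 # column q of B in row-major layout
            out[p * c + q] = dot(row, col)     # each dot product writes its own index
    return out
```

Because no dot product reads another one's output, all $ac$ of them can be placed in separate regions of the same layer.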

C.6 LayerNorm for Division

Suppose we have the input $[c,\bm{y},\mathbf{0}]^\top$ calculated before the attention output in Eq. 61, and we want to divide $\bm{y}$ by $c$. This trick is very similar to the one in Lemma 6. We can use the following formula:

Using LayerNorm for division: let $N,M$ be large numbers and $\lambda$ the LayerNorm function; then the following approximation holds:

To get the input into the format used in this lemma, we can use $W_f$ to transform the head outputs. Then, after the LayerNorm, we can use $W_1$ to pull $\frac{\bm{y}}{c}$ back and write it to the attention output. In this way, we can approximate scalar division in one layer.
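One way to realize such a division with a single LayerNorm is sketched below. It is our own construction and not necessarily the omitted formula: it uses a single large $N$ and does not reproduce the role of $M$. Assuming $c>0$, the input $[Nc,-Nc,\bm{y},-\bm{y}]$ has zero mean and standard deviation approximately $Nc\sqrt{2/L}$, so rescaling the $\bm{y}$ block of the output by $\sqrt{2/L}\,N$ gives approximately $\bm{y}/c$. The function names are hypothetical.

```python
import math

def layer_norm(v):
    """LayerNorm without learned scale/shift (and without epsilon)."""
    L = len(v)
    mu = sum(v) / L
    var = sum((a - mu) ** 2 for a in v) / L
    return [(a - mu) / math.sqrt(var) for a in v]

def ln_divide(y, c, N=1e6):
    """Approximate y / c (assumes c > 0) with one LayerNorm.

    The input [N*c, -N*c, y, -y] has zero mean and std ~= N*c*sqrt(2/L),
    so the y block of the output is ~= y / (N*c*sqrt(2/L)).
    """
    v = [N * c, -N * c] + list(y) + [-a for a in y]
    L = len(v)
    out = layer_norm(v)
    return [math.sqrt(2.0 / L) * N * o for o in out[2:2 + len(y)]]
```

In the ridge construction of Appendix B, the divisor $1+\bm{v}^\top A_0^{-1}\bm{u}=1+\bm{x}_1^\top\frac{I}{\lambda}\bm{x}_1$ is positive for $\lambda>0$, so the $c>0$ assumption holds there.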

By Lemmas 2, 3, 7 and 8, we have constructed each operator in Lemma 1 using a single Transformer layer, which proves Lemma 1. ∎

Appendix D Details of Transformer Architecture and Training

We perform these experiments using the JAX framework on P100 GPUs. The major hyperparameters used in these experiments are presented in Table 1. The code repository used for reproducing these experiments will be open-sourced at the time of publication. Most of the hyperparameters are adapted from previous work (Garg et al., 2022) for compatibility, and we also adapted their Transformer architecture details. We use the Adam optimizer with a cosine learning rate schedule with warmup, where the number of warmup steps is set to 1/5 of the total iterations. We use learned absolute position embeddings.
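A sketch of a learning-rate schedule of this kind, assuming linear warmup over the first fifth of training followed by cosine decay to zero. The warmup shape and the zero floor are our assumptions, and `lr_at` is a hypothetical helper rather than part of the experiment code.

```python
import math

def lr_at(step, total_steps, peak_lr, warmup_frac=0.2):
    """Learning rate at `step`: linear warmup over the first warmup_frac of training,
    then cosine decay from peak_lr to 0 at total_steps."""
    warmup = max(1, int(warmup_frac * total_steps))
    if step < warmup:
        return peak_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total_steps - warmup))
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))
```

For example, with `total_steps=100` the rate rises linearly to `peak_lr` over the first 20 steps and then decays to zero at step 100.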

In the phase shift plots in Fig. 3(a), for each value on the x-axis we use the best setting over the parameters {number of layers, hidden size, number of heads, learning rate}.

Appendix E Details of Probe

We use the terms probe model and task model to distinguish the probe from the ICL model. Our probe is defined as:

In Fig. 4, dashed lines show probing results for a task model trained on a control task, in which $\bm{w}$ is always the all-ones vector $\mathbf{1}$. This problem structurally resembles our main experimental setup but does not require in-context learning. During probing, we feed this model data generated with $\bm{w}$ sampled from a normal distribution, as for the original task model. We observe that the control probe has a significantly higher error rate, showing that the probing accuracy obtained with the actual task model is non-trivial. We present detailed error values of the control probe in Fig. 5.

Appendix F Linearity of ICL

In Fig. 1(b), we compare the implicit linear weight of ICL against the linear algorithms using the ILWD measure. Note that this measure does not assume the predictors to be linear: when the predictors are not linear, ILWD measures the difference between the closest linear predictors (in the sense of Eq. 16) to each algorithm.

To gain more insight into ICL's algorithm, we can measure how linear ICL is in different regimes of the linear problem (underdetermined, determined) using the $R^2$ (coefficient of determination) measure. That is, instead of asking what the best linear fit in Eq. 16 is, we ask how good the linear fit is, which is the $R^2$ of the estimator. Interestingly, even though our model matches the min-norm least-squares solution in both metrics in LABEL:fig:fit, we show in Fig. 6 that ICL becomes linear only gradually in the underdetermined regime. This is an important result: it enables us to say that the in-context learner's hypothesis class is not purely linear.
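A minimal sketch of the $R^2$ measurement, assuming the linear fit is an ordinary least-squares fit, without intercept, of a predictor's outputs on a set of query inputs, and that $R^2=1-\mathrm{SS}_{\mathrm{res}}/\mathrm{SS}_{\mathrm{tot}}$ with $\mathrm{SS}_{\mathrm{tot}}$ taken around the mean prediction. This is not necessarily the exact fitting procedure of Eq. 16, and the function names are hypothetical.

```python
def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting (A square, nonsingular)."""
    n = len(A)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for k in range(col, n + 1):
                M[r][k] -= f * M[col][k]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (M[r][n] - sum(M[r][k] * x[k] for k in range(r + 1, n))) / M[r][r]
    return x

def linear_fit_r2(xs, preds):
    """Fit preds ~ w^T x by least squares (no intercept); return (w, R^2)."""
    d = len(xs[0])
    XtX = [[sum(x[i] * x[j] for x in xs) for j in range(d)] for i in range(d)]
    Xty = [sum(x[i] * p for x, p in zip(xs, preds)) for i in range(d)]
    w = solve(XtX, Xty)                                   # normal equations
    fitted = [sum(wi * xi for wi, xi in zip(w, x)) for x in xs]
    mean = sum(preds) / len(preds)
    ss_res = sum((p - f) ** 2 for p, f in zip(preds, fitted))
    ss_tot = sum((p - mean) ** 2 for p in preds)
    return w, 1.0 - ss_res / ss_tot
```

An $R^2$ close to 1 means the predictor is well described by a single linear map on those inputs; lower values indicate a nonlinear component.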

Appendix G Multiplicative interactions with Other Non-linearities

We can show that for a real-valued, smooth nonlinearity $f(x)$, we can apply the same trick as in the main text. In particular, we can write its Taylor expansion as:

which converges in some sufficiently small neighborhood $\mathcal{X}=[-\epsilon,\epsilon]$. First, assume that the second-order term $a_2$ dominates the higher-order terms in this domain, such that:

It is easy to verify that the following holds:

So, given the expansion for GeLU in Eq. 45, we can use this generic formula to obtain the multiplication approximation:

We plot this approximation against $x^2$ over the range $[-0.1,0.1]$ in Fig. 7(a).

If $a_2$ is zero, we cannot obtain any second-order term, and if $a_2$ is negligible, the $\mathcal{O}(x^3+y^3)$ term will dominate Eq. 151, so we cannot obtain a good approximation of $xy$. In this case, we can resort to numerical derivatives and use the $a_3$ term:

If $a_3$ is not negligible, i.e., $a_{i>3}x^i\ll a_3x^3$ in the same domain, we can use numerical derivatives to obtain a multiplication term:

For example, $\tanh$ has no second-order term in its Taylor expansion:

Using the above formula, we obtain the following expression:

Similar to our construction in Eq. 126, we can construct a Transformer layer that calculates these quantities (noting that δ\delta is a small, input-independent scalar).
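One concrete reading of this construction, derived by us rather than copied from the omitted expression: with central differences $D(u)=\frac{f(u+\delta)-f(u-\delta)}{2\delta}\approx f'(u)=a_1+3a_3u^2+\dots$, we have $D(x+y)-D(x)-D(y)+D(0)\approx 6a_3xy$. For $\tanh$, $a_3=-1/3$; the three input-dependent differences use six $\tanh$ evaluations, and $D(0)$ is an input-independent constant. The function name and the default `delta` are illustrative.

```python
import math

def tanh_mul(x, y, delta=1e-3):
    """Approximate x * y from six tanh evaluations plus an input-independent constant."""
    def D(u):                                   # central difference: tanh'(u) + O(delta^2)
        return (math.tanh(u + delta) - math.tanh(u - delta)) / (2.0 * delta)
    a3 = -1.0 / 3.0                             # tanh(u) = u - u^3/3 + ..., so a_3 = -1/3
    s = D(x + y) - D(x) - D(y) + D(0.0)         # ~= 6 * a3 * x * y; D(0.0) is a constant
    return s / (6.0 * a3)
```

As with GeLU, inputs would be scaled into the small-input regime first; the leading error term here is of fourth order in $x$ and $y$.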

We plot this approximation against $x^2$ over the range $[-0.1,0.1]$ in Fig. 7(b). Note that if we use this approximation in our constructions, we need more hidden space, as there are 6 different $\tanh$ terms as opposed to 3 GeLU terms in Eq. 126.

ReLU is another commonly used nonlinearity, one that is not smooth (it is not differentiable at zero). With ReLU, we can only hope to obtain piecewise-linear approximations. For example, we can try to approximate $x^2$ with the following function:

We plot this approximation against $x^2$ over the range $[-0.1,0.1]$ in Fig. 7(c).
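One possible ReLU construction, not necessarily the function plotted in Fig. 7(c): the piecewise-linear interpolant of $x^2$ on $K$ equal segments of $[-a,a]$ can be written as a sum of ReLU hinges, because with $h=2a/K$ the slope increases by exactly $2h$ at each interior knot. It is exact at the knots and its error is at most $h^2/4$ on $[-a,a]$; `a` and `K` are illustrative parameters.

```python
def relu(z):
    return max(0.0, z)

def relu_square(x, a=0.1, K=20):
    """Piecewise-linear interpolant of x^2 on [-a, a] with K equal segments,
    written as a sum of ReLU hinges (exact at the knots, error <= h^2 / 4)."""
    h = 2.0 * a / K
    knots = [-a + k * h for k in range(K + 1)]
    out = a * a + (knots[0] + knots[1]) * relu(x + a)   # value and slope of the first segment
    for t in knots[1:-1]:
        out += 2.0 * h * relu(x - t)                     # slope grows by 2h at each interior knot
    return out
```

The number of hinges, and hence of hidden units, grows as $K$ grows, while the error shrinks only as $1/K^2$.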

Appendix H Empirical Scaling Analysis with Dimensionality

In Figs. 3(a) and 3(b), we showed that ICL needs different hidden sizes to enter the “Ridge regression phase” (orange background) or the “OLS phase” (green background) depending on the dimensionality $d$ of the inputs $x$. However, we cannot reliably infer the relationship between size requirements and problem dimension from only two dimensions. To better understand size requirements, we ask the following empirical question for each dimension: how many layers, how large a hidden size, and how many heads are needed to fit the least-squares solution better than the $\operatorname{Ridge}(\lambda=\epsilon)$ regression solution (the green phase in Figs. 3(a) and 3(b))?

To answer this question, we experiment with $d\in\{1,2,4,8,12,16,20\}$ and run a sweep for each dimension over:

number of layers ($L$): $\{1,2,4,8,12,16\}$,

hidden size ($H$): $\{16,32,64,256,512,1024\}$,

For each feature that affects the computational capacity of the transformer ($L$, $H$, $M$), we optimize the other features and find the minimum value of that feature that satisfies $\operatorname{SPD}(\operatorname{OLS},\operatorname{ICL})<\operatorname{SPD}(\operatorname{Ridge}(\lambda=\epsilon),\operatorname{ICL})$. We plot our results with $\epsilon=0.1$ in LABEL:fig:empreqs. We find that a single head is enough for all problem dimensions, while the other parameters exhibit a step-function-like dependence on input size.
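The selection rule above can be sketched as follows, assuming a hypothetical record layout in which each run stores its dimension `d`, layer count `L`, hidden size `H`, head count `M`, and the two distances `spd_ols` $=\operatorname{SPD}(\operatorname{OLS},\operatorname{ICL})$ and `spd_ridge` $=\operatorname{SPD}(\operatorname{Ridge}(\lambda=\epsilon),\operatorname{ICL})$. Optimizing over the other features amounts to accepting a value as soon as any setting of the remaining features satisfies the condition.

```python
def min_requirement(runs, feature):
    """For each input dimension d, the smallest value of `feature` ("L", "H" or "M")
    for which some setting of the other hyperparameters gives
    SPD(OLS, ICL) < SPD(Ridge(eps), ICL). Dimensions with no such run are omitted."""
    best = {}
    for run in runs:
        if run["spd_ols"] < run["spd_ridge"]:            # run is in the OLS (green) phase
            d, v = run["d"], run[feature]
            best[d] = v if d not in best else min(best[d], v)
    return best
```

Calling it once per feature on the same sweep yields one requirement curve per feature as a function of $d$.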

Note that the other hyperparameters discussed in Appendix D (e.g., weight initialization) were not optimized for each dimension independently.