Transformers learn in-context by gradient descent

Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, Max Vladymyrov

Introduction

In recent years, Transformers (TFs; Vaswani et al., 2017) have demonstrated their superiority on numerous benchmarks and in various fields of modern machine learning, and have emerged as the de facto neural network architecture for modern AI (Dosovitskiy et al., 2021; Yun et al., 2019; Carion et al., 2020; Gulati et al., 2020). It has been hypothesised that their success is due in part to a phenomenon called in-context learning (Brown et al., 2020; Liu et al., 2021): an ability to flexibly adjust their predictions based on additional data given in context (i.e., in the input sequence itself). In-context learning offers a seemingly different approach to few-shot and meta-learning (Brown et al., 2020), but as of today the exact mechanisms by which it works are not fully understood. It is thus of great interest to understand what makes Transformers pay attention to their context, what the mechanisms are, and under which circumstances they come into play (Chan et al., 2022b; Olsson et al., 2022).

In this paper, we aim to bridge the gap between in-context learning and meta-learning, and show that in-context learning in Transformers can be an emergent property approximating gradient-based few-shot learning within their forward pass, see Figure 1. For this to be realized, we show how Transformers (1) construct a loss function dependent on the data given in sequence and (2) learn based on gradients of that loss. We first focus on the latter, the more elaborate learning task, in sections 2 and 3, after which we provide evidence for the former in section 4.

We summarize our contributions as follows (the main experiments can be reproduced with the notebooks provided at https://github.com/google-research/self-organising-systems/tree/master/transformers_learn_icl_by_gd):

We construct explicit weights for a linear self-attention layer that induces an update identical to a single step of gradient descent (GD) on a mean squared error loss. Additionally, we show how several self-attention layers can iteratively perform curvature correction, improving on plain gradient descent.

When optimized on linear regression datasets, linear self-attention-only Transformers either converge to our weight construction, and therefore implement gradient descent, or generate linear models that closely align with models trained by GD, on both in- and out-of-distribution validation tasks.

By incorporating multi-layer perceptrons (MLPs) into the Transformer architecture, we enable Transformers to solve nonlinear regression tasks, by showing that this is equivalent to learning a linear model on deep representations. We discuss connections to kernel regression as well as to nonparametric kernel smoothing methods. Empirically, we compare meta-learned MLPs combined with a single step of GD on their output layer against trained Transformers, and demonstrate striking similarities between the identified solutions.

We resolve the dependency on the specific token construction by providing evidence that learned Transformers first encode incoming tokens into a format amenable to the in-context gradient descent learning that occurs in the later layers of the Transformer.

These findings allow us to connect learning Transformer weights to the concept of meta-learning a learning algorithm (Schmidhuber, 1987; Hinton & Plaut, 1987; Bengio et al., 1990; Chalmers, 1991; Schmidhuber, 1992; Thrun & Pratt, 1998; Hochreiter et al., 2001; Andrychowicz et al., 2016; Ba et al., 2016; Kirsch & Schmidhuber, 2021). In this extensive research field, meta-learning is typically regarded as learning that takes place on multiple time scales, namely fast and slow. The slowly changing parameters control and prepare for fast adaptation, reacting to sudden changes in the incoming data, e.g., a context switch. Notably, we build heavily on the concept of fast weights (Schmidhuber, 1992), which has been shown to be equivalent to linear self-attention (Schlag et al., 2021), and show how optimized Transformers implement interpretable learning algorithms within their weights.

Another related meta-learning concept, termed MAML (Finn et al., 2017), aims to meta-learn a deep neural network initialization that allows for fast adaptation to novel tasks. It has been shown that, in many circumstances, the solution found can be approximated well when only adapting the output layer, i.e., learning a linear model on meta-learned deep data representations (Finn et al., 2017; Finn & Levine, 2018; Gordon et al., 2019; Lee et al., 2019; Rusu et al., 2019; Raghu et al., 2020; von Oswald et al., 2021). In section 3, we show the equivalence of this framework to in-context learning implemented in a common Transformer block, i.e., when combining self-attention layers with a multi-layer perceptron.

In light of meta-learning, we show how optimizing Transformer weights can be regarded as learning on two time scales. More concretely, we find that, solely through the pressure to predict correctly, Transformers discover learning algorithms inside their forward computations, effectively meta-learning a learning algorithm. Recently, this concept of an emergent optimizer within a learned neural network, such as a Transformer, has been termed “mesa-optimization” (Hubinger et al., 2019). We find and describe one possible realization of this concept and hypothesize that the in-context learning capabilities of language models emerge through mechanisms similar to the ones we discuss here.

Transformers come in different “shapes and sizes”, operate on vastly different domains, and exhibit varying forms of phase transitions of in-context learning (Kirsch et al., 2022; Chan et al., 2022a), suggesting variance and significant complexity of the underlying learning mechanisms. As a result, we expect our findings on linear self-attention-only Transformers to explain only a limited part of a complex process, and what we describe may be one of many possible mechanisms giving rise to in-context learning. Nevertheless, our approach provides an intriguing perspective on, and novel evidence for, an in-context learning mechanism that differs significantly from existing mechanisms based on associative memory (Ramsauer et al., 2020) or on the copying mechanism termed induction heads identified by Olsson et al. (2022). We therefore state the following hypothesis:

When training Transformers on auto-regressive tasks, in-context learning in the Transformer forward pass is implemented by gradient-based optimization of an implicit auto-regressive inner loss constructed from its in-context data.

We acknowledge work done in parallel investigating the same hypothesis. Akyürek et al. (2023) put forward a weight construction based on a chain of Transformer layers (including MLPs) that together implement a single step of gradient descent with weight decay. Similar to work done by Garg et al. (2022), they then show that trained Transformers match the performance of models obtained by gradient descent. Nevertheless, it is not clear whether optimization finds Transformer weights that coincide with their construction.

Here, we present a much simpler construction that builds on Schlag et al. (2021) and requires only a single linear self-attention layer to implement a step of gradient descent. This allows us to (1) show that optimizing self-attention-only Transformers finds weights that match our weight construction (Proposition 1), demonstrating its practical relevance, and (2) explain in-context learning in the shallow two-layer Transformers intensively studied by Olsson et al. (2022). Thus, while related work provides comprehensive empirical evidence that Transformers indeed seem to implement gradient-descent-based learning on the data given in context, we present in the following a mechanistic verification of this hypothesis and provide compelling evidence that our construction, which implements GD in a Transformer forward pass, is found in practice.

Linear self-attention can emulate gradient descent on a linear regression task

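We start by reviewing a standard multi-head self-attention (SA) layer with parameters $\theta$. A SA layer updates each element $e_j$ of a set of tokens $\{e_1, \ldots, e_N\}$ according to

$$e_j \leftarrow e_j + \text{SA}_\theta(j, \{e_1, \ldots, e_N\}) = e_j + \sum_h P_h V_h \,\text{softmax}(K_h^T q_{h,j}), \tag{1}$$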

with $P_h$, $V_h$, $K_h$ the projection, value and key matrices, respectively, and $q_{h,j}$ the query, all for the $h$-th head. To simplify the presentation, we omit bias terms here and throughout. The columns of the value matrix $V_h = [v_{h,1}, \ldots, v_{h,N}]$ and the key matrix $K_h = [k_{h,1}, \ldots, k_{h,N}]$ consist of vectors $v_{h,i} = W_{h,V} e_i$ and $k_{h,i} = W_{h,K} e_i$; likewise, the query is produced by linearly projecting the tokens, $q_{h,j} = W_{h,Q} e_j$. The parameters $\theta = \{P_h, W_{h,V}, W_{h,K}, W_{h,Q}\}_h$ of a SA layer consist of all the projection matrices of all heads.

The self-attention layer described above corresponds to the one used in the standard Transformer model. Following Schlag et al. (2021), we now introduce our first (and only) departure from the standard model and omit the softmax operation in equation 1, leading to the linear self-attention (LSA) layer

$$e_j \leftarrow e_j + \text{LSA}_\theta(j, \{e_1, \ldots, e_N\}) = e_j + \sum_h P_h V_h K_h^T q_{h,j}.$$

We next show that, with some simple manipulations, we can relate the update performed by an LSA layer to one step of gradient descent on a linear regression loss.
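To make the LSA update concrete, here is a minimal NumPy sketch of the multi-head linear self-attention layer above, assuming tokens are stored as the columns of a matrix `E`. The function name `lsa`, the `heads` tuple layout and the optional `n_ctx` argument are our own illustrative choices, not part of the paper.

```python
import numpy as np


def lsa(E, heads, n_ctx=None):
    """Linear self-attention: e_j <- e_j + sum_h P_h V_h K_h^T q_{h,j}.

    E     : (d, T) array whose columns are the tokens e_1, ..., e_T.
    heads : list of (P, W_V, W_K, W_Q) tuples, one per head (hypothetical layout).
    n_ctx : if given, only the first n_ctx tokens provide keys and values.
    """
    ctx = E if n_ctx is None else E[:, :n_ctx]
    update = np.zeros_like(E)
    for P, W_V, W_K, W_Q in heads:
        V = W_V @ ctx                  # values v_i = W_V e_i, one column per context token
        K = W_K @ ctx                  # keys   k_i = W_K e_i
        Q = W_Q @ E                    # queries q_j = W_Q e_j, one column per token
        update += P @ V @ (K.T @ Q)    # no softmax: the scores k_i^T q_j enter linearly
    return E + update


rng = np.random.default_rng(0)
d, T = 4, 6
E = rng.normal(size=(d, T))
heads = [tuple(rng.normal(size=(d, d)) for _ in range(4)) for _ in range(2)]
E_new = lsa(E, heads)
```

Because there is no softmax, each head's contribution is a product of matrices, which is what makes the algebraic comparison with gradient descent in the following paragraphs possible.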

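Consider a linear model $\hat{y} = W x$ and a training set $\{(x_i, y_i)\}_{i=1}^N$ with the squared-error loss

$$L(W) = \frac{1}{2N} \sum_{i=1}^N \|W x_i - y_i\|^2. \tag{2}$$

One step of gradient descent on $L$ with learning rate $\eta$ yields the weight change

$$\Delta W = -\eta \nabla_W L(W) = -\frac{\eta}{N} \sum_{i=1}^N (W x_i - y_i) x_i^T. \tag{3}$$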

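Considering the loss after changing the weights, we obtain

$$L(W + \Delta W) = \frac{1}{2N} \sum_{i=1}^N \|(W + \Delta W) x_i - y_i\|^2 = \frac{1}{2N} \sum_{i=1}^N \|W x_i - (y_i - \Delta y_i)\|^2, \tag{4}$$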

where we introduced the transformed targets $y_i - \Delta y_i$ with $\Delta y_i = \Delta W x_i$. Thus, we can view the outcome of a gradient descent step as an update to our regression loss (equation 2) in which the data, and not the weights, are updated. Note that this formulation is closely linked to prediction based on nonparametric kernel smoothing; see Appendix A.8 for a discussion.
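The equivalence between updating the weights and updating the targets can be checked numerically. The NumPy sketch below, assuming random data and illustrative helper names `loss` and `gd_delta`, compares $L(W + \Delta W)$ with the original loss evaluated on the transformed targets $y_i - \Delta W x_i$.

```python
import numpy as np


def loss(W, X, Y):
    """L(W) = 1/(2N) sum_i ||W x_i - y_i||^2 with X: (n_in, N), Y: (n_out, N)."""
    N = X.shape[1]
    return 0.5 / N * np.sum((W @ X - Y) ** 2)


def gd_delta(W, X, Y, eta):
    """One GD step: Delta W = -eta * dL/dW = -(eta / N) sum_i (W x_i - y_i) x_i^T."""
    N = X.shape[1]
    return -eta / N * (W @ X - Y) @ X.T


rng = np.random.default_rng(0)
n_in, n_out, N, eta = 10, 1, 10, 0.5
X = rng.uniform(-1, 1, size=(n_in, N))
Y = rng.normal(size=(n_out, N))
W = rng.normal(size=(n_out, n_in))

dW = gd_delta(W, X, Y, eta)
loss_weights = loss(W + dW, X, Y)      # update the weights ...
loss_data = loss(W, X, Y - dW @ X)     # ... or, equivalently, the targets
```

The two numbers agree for any $W$, data and $\eta$, since the identity is purely algebraic.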

Transformations induced by gradient descent and a linear self-attention layer can be equivalent

We have recast the task of learning a linear model as directly modifying the data, instead of explicitly computing and returning the weights of the model (equation 4). We proceed to establish a connection between self-attention and gradient descent. We provide a construction where learning takes place simultaneously by directly updating all tokens, including the test token, through a linear self-attention layer. In other words, the token produced in response to a query (test) token is transformed from its initial value $W_0 x_\text{test}$, where $W_0$ is the initial value of $W$, to the post-learning prediction $\hat{y} = (W_0 + \Delta W) x_\text{test}$ obtained after one gradient descent step.

Proposition 1. Given a 1-head linear attention layer and the tokens $e_j = (x_j, y_j)$, for $j = 1, \ldots, N$, one can construct key, query and value matrices $W_K, W_Q, W_V$ as well as the projection matrix $P$ such that a Transformer step on every token $e_j$ is identical to the gradient-induced dynamics $e_j \leftarrow (x_j, y_j) + (0, -\Delta W x_j) = (x_j, y_j) + P V K^T q_j$, such that $e_j = (x_j, y_j - \Delta y_j)$. For the test data token $(x_{N+1}, y_{N+1})$ the dynamics are identical.

The simple construction can be found in Appendix A.1; we denote the corresponding self-attention weights by $\theta_\text{GD}$.
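The following NumPy sketch gives one weight choice consistent with Proposition 1, assuming tokens $e_j = (x_j, y_j)$ stored as columns: keys and queries read out $x$, values carry $W_0 x_i - y_i$, and $P = \frac{\eta}{N} I$. It is our own reading of the construction rather than a copy of Appendix A.1, and `make_theta_gd` and `lsa` are hypothetical helper names.

```python
import numpy as np


def make_theta_gd(W0, eta, N):
    """Hypothetical theta_GD for tokens e = (x, y): returns (P, W_V, W_K, W_Q)."""
    n_out, n_in = W0.shape
    d = n_in + n_out
    W_K = np.zeros((d, d))
    W_K[:n_in, :n_in] = np.eye(n_in)       # keys    k_i = (x_i, 0)
    W_Q = W_K.copy()                       # queries q_j = (x_j, 0)
    W_V = np.zeros((d, d))
    W_V[n_in:, :n_in] = W0                 # values  v_i = (0, W0 x_i - y_i)
    W_V[n_in:, n_in:] = -np.eye(n_out)
    P = eta / N * np.eye(d)                # learning rate folded into the projection
    return P, W_V, W_K, W_Q


def lsa(E, theta, n_ctx):
    """Single-head LSA; only the first n_ctx tokens (columns of E) give keys/values."""
    P, W_V, W_K, W_Q = theta
    ctx = E[:, :n_ctx]
    return E + P @ (W_V @ ctx) @ ((W_K @ ctx).T @ (W_Q @ E))


rng = np.random.default_rng(0)
n_in, n_out, N, eta = 5, 2, 8, 0.3
X = rng.uniform(-1, 1, size=(n_in, N))
Y = rng.normal(size=(n_out, N))
W0 = rng.normal(size=(n_out, n_in))

E = np.vstack([X, Y])                      # tokens e_j = (x_j, y_j) as columns
E_new = lsa(E, make_theta_gd(W0, eta, N), n_ctx=N)

dW = -eta / N * (W0 @ X - Y) @ X.T         # one GD step on L(W) at W = W0
```

After the layer, the inputs are untouched and every target has become $y_j - \Delta W x_j$, i.e., the transformed targets of equation 4.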

Below, we provide some additional insights on what is needed to implement the provided LSA-layer weight construction, and further details on what it can achieve:

Full self-attention. The dynamics above model training on in-context tokens only, i.e., only $e_1, \ldots, e_N$ are used for computing the key and value matrices; the query token $e_{N+1}$ (containing test data) is excluded. This leads to a function that is linear in $x_\text{test}$, as well as to the correct $\Delta W$ induced by gradient descent on a loss consisting only of the training data. This is a minor deviation from full self-attention. In practice, this modification can be dropped, which corresponds to assuming that the underlying initial weight matrix is zero, $W_0 \approx 0$; this makes $\Delta W$ in equation 37 independent of the test token even when the test token is incorporated in the key and value matrices. In our experiments, we see that these assumptions are met when initializing the attention weights $\theta$ to small values.
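A small NumPy check of the $W_0 \approx 0$ argument, assuming the same hypothetical `make_theta_gd` and `lsa` helpers as a sketch of $\theta_\text{GD}$ (our reading, not Appendix A.1 verbatim): with $W_0 = 0$ the query token $(x_\text{test}, 0)$ has a zero value vector, so adding it to the keys and values changes nothing, whereas with $W_0 \neq 0$ it does.

```python
import numpy as np


def make_theta_gd(W0, eta, N):
    """Hypothetical theta_GD: keys/queries read x, values give (0, W0 x - y), P = eta/N * I."""
    n_out, n_in = W0.shape
    d = n_in + n_out
    W_K = np.zeros((d, d))
    W_K[:n_in, :n_in] = np.eye(n_in)
    W_V = np.zeros((d, d))
    W_V[n_in:, :n_in] = W0
    W_V[n_in:, n_in:] = -np.eye(n_out)
    return eta / N * np.eye(d), W_V, W_K, W_K.copy()


def lsa(E, theta, n_ctx):
    P, W_V, W_K, W_Q = theta
    ctx = E[:, :n_ctx]
    return E + P @ (W_V @ ctx) @ ((W_K @ ctx).T @ (W_Q @ E))


rng = np.random.default_rng(2)
n_in, n_out, N, eta = 4, 1, 12, 0.4
X = rng.uniform(-1, 1, size=(n_in, N))
Y = rng.normal(size=(n_out, N))
x_test = rng.uniform(-1, 1, size=(n_in, 1))


def run(W0, n_ctx):
    """Apply theta_GD to the context plus the query token (x_test, -W0 x_test)."""
    E = np.hstack([np.vstack([X, Y]), np.vstack([x_test, -W0 @ x_test])])
    return lsa(E, make_theta_gd(W0, eta, N), n_ctx=n_ctx)


W0_zero = np.zeros((n_out, n_in))
W0_rand = rng.normal(size=(n_out, n_in))
same_when_zero = np.allclose(run(W0_zero, N), run(W0_zero, N + 1))   # n_ctx = N + 1: full SA
same_when_rand = np.allclose(run(W0_rand, N), run(W0_rand, N + 1))
```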

Reading out predictions. When initializing the $y$-entry of the test-data token with $-W_0 x_{N+1}$, i.e., $e_\text{test} = (x_\text{test}, -W_0 x_\text{test})$, the test-data prediction $\hat{y}$ can easily be read out by multiplying the $y$-entry of the updated token by $-1$ again, since $-(y_{N+1} - \Delta y_{N+1}) = W_0 x_{N+1} + \Delta W x_{N+1} = \hat{y}$ for $y_{N+1} = -W_0 x_{N+1}$. This can easily be done by a final projection matrix, which incidentally is usually found in Transformer architectures. Importantly, we see that a single head of self-attention is sufficient to transform our training targets as well as the test prediction simultaneously.
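The read-out can be sketched in NumPy as follows, assuming the same hypothetical `make_theta_gd` and `lsa` helpers (our reading of $\theta_\text{GD}$) and a nonzero $W_0$: after one layer, negating the $y$-entry of the test token gives the one-step GD prediction $(W_0 + \Delta W) x_\text{test}$.

```python
import numpy as np


def make_theta_gd(W0, eta, N):
    """Hypothetical theta_GD: keys/queries read x, values give (0, W0 x - y), P = eta/N * I."""
    n_out, n_in = W0.shape
    d = n_in + n_out
    W_K = np.zeros((d, d))
    W_K[:n_in, :n_in] = np.eye(n_in)
    W_V = np.zeros((d, d))
    W_V[n_in:, :n_in] = W0
    W_V[n_in:, n_in:] = -np.eye(n_out)
    return eta / N * np.eye(d), W_V, W_K, W_K.copy()


def lsa(E, theta, n_ctx):
    P, W_V, W_K, W_Q = theta
    ctx = E[:, :n_ctx]
    return E + P @ (W_V @ ctx) @ ((W_K @ ctx).T @ (W_Q @ E))


rng = np.random.default_rng(1)
n_in, n_out, N, eta = 6, 2, 15, 0.2
X = rng.uniform(-1, 1, size=(n_in, N))
Y = rng.normal(size=(n_out, N))
W0 = rng.normal(size=(n_out, n_in))
x_test = rng.uniform(-1, 1, size=(n_in, 1))

test_token = np.vstack([x_test, -W0 @ x_test])       # e_test = (x_test, -W0 x_test)
E = np.hstack([np.vstack([X, Y]), test_token])
E_new = lsa(E, make_theta_gd(W0, eta, N), n_ctx=N)   # keys/values from the context only
y_hat = -E_new[n_in:, N:]                            # read-out: flip the sign of the y-entry

dW = -eta / N * (W0 @ X - Y) @ X.T                   # one GD step at W = W0
```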

Uniqueness. We note that the construction is not unique; in particular, it is only required that the products $P W_V$ and $W_K^T W_Q$ match the construction. Furthermore, since no nonlinearity is present, any rescaling $s$ of the matrix products, i.e., $P W_V s$ and $W_K^T W_Q / s$, leads to an equivalent result. If we correct for these equivalent formulations, we can experimentally verify that the weights of our learned Transformers indeed match the presented construction.
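A quick NumPy check of this non-uniqueness, assuming a generic single-head LSA with random weights (all names are illustrative): any factorization that preserves $P W_V$ and $W_K^T W_Q$, or that rescales them by $s$ and $1/s$, yields the same layer output.

```python
import numpy as np


def lsa(E, theta, n_ctx):
    """Single-head LSA; depends on theta only through P W_V and W_K^T W_Q."""
    P, W_V, W_K, W_Q = theta
    ctx = E[:, :n_ctx]
    return E + P @ (W_V @ ctx) @ ((W_K @ ctx).T @ (W_Q @ E))


rng = np.random.default_rng(3)
d, N, s = 5, 7, 3.0
E = rng.normal(size=(d, N + 1))
P, W_V, W_K, W_Q = (rng.normal(size=(d, d)) for _ in range(4))
M = rng.normal(size=(d, d))            # generic (almost surely invertible) mixing matrix
M_inv = np.linalg.inv(M)

base = lsa(E, (P, W_V, W_K, W_Q), n_ctx=N)
# Different factors, same products P W_V and W_K^T W_Q.
mixed = lsa(E, (P @ M_inv, M @ W_V, M @ W_K, M_inv.T @ W_Q), n_ctx=N)
# P W_V scaled by s and W_K^T W_Q scaled by 1/s.
scaled = lsa(E, (s * P, W_V, W_K, W_Q / s), n_ctx=N)
```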

Meta-learned task-shared learning rates. When training self-attention parameters $\theta$ across a family of in-context learning tasks $\tau$, where the data $(x_{\tau,i}, y_{\tau,i})$ follow a certain distribution, the learning rate can be implicitly (meta-)learned such that an optimal loss reduction (averaged over tasks) is achieved given a fixed number of update steps. In our experiments, we find this to be the case. This kind of meta-learning to improve upon plain gradient descent has been leveraged in numerous previous approaches for deep neural networks (Li et al., 2017; Lee & Choi, 2018; Park & Oliva, 2019; Zhao et al., 2020; Flennerhag et al., 2020).

Task-specific data transformations. Beyond modeling task-shared curvature information in $\theta$, a self-attention layer is in principle also capable of exploiting statistics of the current training data samples. More concretely, when neglecting influences of the target data $y_i$, an LSA layer updates an input sample according to the data transformation $x_j \leftarrow x_j + \Delta x_j = (I + P(X) V(X) K(X)^T W_Q) x_j = H_\theta(X) x_j$, with $X$ the $N_x \times N$ input training data matrix. Through $H_\theta(X)$, an LSA layer can encode in $\theta$ an algorithm for carrying out data transformations that depend on the actual input training samples in $X$. In our experiments, we see that trained self-attention learners employ a simple form of $H(X)$ and that this leads to substantial speed-ups for both GD and TF learning.

Trained Transformers do mimic gradient descent on linear regression tasks

We now experimentally investigate whether trained attention-based models implement gradient-based in-context learning in their forward passes. We gradually build up from single linear self-attention layers to multi-layer nonlinear models, approaching full Transformers. In this section, we closely follow the assumption of Proposition 1 and construct our tokens by concatenating input and target data, $e_j = (x_j, y_j)$ for $1 \le j \le N$, and our query token by concatenating the test input and a zero vector, $e_{N+1} = (x_\text{test}, 0)$. We show how to lift this assumption in the last section of the paper. The prediction $\hat{y}_\theta(\{e_{\tau,1}, \ldots, e_{\tau,N}\}, e_{\tau,N+1})$ of the attention-based model, which depends on all tokens and on the parameters $\theta$, is read out from the $y$-entry of the updated $(N+1)$-th token as explained in the previous section. We optimize $\theta$ by minimizing the squared prediction error averaged over a batch of $T$ tasks,

$$\mathcal{L}(\theta) = \frac{1}{T} \sum_{\tau=1}^{T} \|\hat{y}_\theta(\{e_{\tau,1}, \ldots, e_{\tau,N}\}, e_{\tau,N+1}) - y_{\tau,N+1}\|^2, \tag{5}$$

where each task (context) $\tau$ consists of in-context training data $D_\tau = \{(x_{\tau,i}, y_{\tau,i})\}_{i=1}^N$ and a test point $(x_{\tau,N+1}, y_{\tau,N+1})$, which we use to construct our tokens $\{e_{\tau,i}\}_{i=1}^{N+1}$ as described above. We denote the optimal parameters found by this optimization process by $\theta^*$. In our setup, finding $\theta^*$ may be thought of as meta-learning, while learning a particular task $\tau$ corresponds to simply evaluating the model $\hat{y}_\theta(\{e_{\tau,1}, \ldots, e_{\tau,N}\}, e_{\tau,N+1})$. Note that, since tasks are sampled afresh throughout training, we never see the exact same training task twice. See Appendix A.12, especially Figure 16, for an analysis when using a fixed dataset size which we cycle over during training.

We focus on solvable tasks and, similarly to Garg et al. (2022), generate data for each task using a teacher model with parameters $W_\tau \sim \mathcal{N}(0, I)$. We then sample $x_{\tau,i} \sim U(-1, 1)^{n_I}$ and construct targets using the task-specific teacher model, $y_{\tau,i} = W_\tau x_{\tau,i}$. In the majority of our experiments we set the dimensions to $N = n_I = 10$ and $n_O = 1$. Since we use a noiseless teacher for simplicity, we can expect our regression tasks to be well-posed and analytically solvable. Note also that we only compute a loss on the Transformer's last token, which stands in contrast to usual autoregressive training and to the training setup of Garg et al. (2022). Full details and results for training with a fixed training set size may be found in Appendix A.12.
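A NumPy sketch of this task distribution and of an equation-5-style objective, assuming the default dimensions $N = n_I = 10$, $n_O = 1$. To keep it self-contained, the one-step-GD learner (with an arbitrary learning rate) stands in for a trained Transformer; `sample_tasks`, `one_step_gd_predict` and `meta_loss` are hypothetical names.

```python
import numpy as np


def sample_tasks(rng, T, N=10, n_in=10, n_out=1):
    """Noiseless teacher tasks: W_tau ~ N(0, I), x ~ U(-1, 1)^{n_in}, y = W_tau x.

    Returns X: (T, n_in, N + 1) and Y: (T, n_out, N + 1); the last column is the test point.
    """
    W = rng.normal(size=(T, n_out, n_in))
    X = rng.uniform(-1.0, 1.0, size=(T, n_in, N + 1))
    return X, W @ X


def one_step_gd_predict(X, Y, eta):
    """Prediction at the test point after one GD step from W0 = 0 on the N context pairs."""
    Xc, Yc, x_test = X[..., :-1], Y[..., :-1], X[..., -1:]
    N = Xc.shape[-1]
    dW = eta / N * Yc @ np.swapaxes(Xc, -1, -2)   # -eta * grad L(W) evaluated at W = 0
    return (dW @ x_test)[..., 0]                  # (T, n_out)


def meta_loss(pred, Y):
    """Squared error on the test point, averaged over tasks (cf. equation 5)."""
    return np.mean(np.sum((pred - Y[..., -1]) ** 2, axis=-1))


rng = np.random.default_rng(0)
X, Y = sample_tasks(rng, T=1000)
loss_gd = meta_loss(one_step_gd_predict(X, Y, eta=0.5), Y)
```

In the actual experiments the predictor would be the attention model $\hat{y}_\theta$ and the loss would be minimized over $\theta$; here only the data pipeline and the objective are illustrated.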

Our first goal is to investigate whether a trained single linear self-attention layer can be explained by the provided weight construction that implements GD. To that end, we compare the predictions made by an LSA layer with trained weights $\theta^*$ (which minimize equation 5) and with constructed weights $\theta_\text{GD}$ (which satisfy Proposition 1).

Recall that an LSA layer yields the prediction $\hat{y}_\theta(x_\text{test}) = e_{N+1} + \text{LSA}_\theta(\{e_1, \ldots, e_N\}, e_{N+1}) = \Delta W_{\theta,D}\, x_\text{test}$, which is linear in $x_\text{test}$. We denote by $\Delta W_{\theta,D}$ the matrix generated by the LSA layer following the construction provided in Proposition 1, with the query token $e_{N+1}$ set such that the initial prediction is zero, $\hat{y}_\text{test} = 0$. We compare $\hat{y}_\theta(x_\text{test})$ to the prediction of the control LSA, $\hat{y}_{\theta_\text{GD}}(x_\text{test})$, which under our token construction corresponds to a linear model trained by one step of gradient descent starting from $W_0 = 0$. For this control model, we determine the optimal learning rate $\eta$ by line search, minimizing $\mathcal{L}(\eta)$ over a training set of $10^4$ tasks, with $\mathcal{L}(\eta)$ defined analogously to equation 5.
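For the one-step control model the line search can be sketched in closed form, assuming the task distribution above with $10^4$ tasks. Starting from $W_0 = 0$, the prediction is $\hat{y} = \eta\, g$ with $g = \frac{1}{N}\sum_i y_i x_i^T x_\text{test}$, so $\mathcal{L}(\eta)$ is a one-dimensional quadratic. Replacing the search by its closed-form minimizer is our own simplification; a grid search is included for comparison.

```python
import numpy as np

rng = np.random.default_rng(0)
T, N, n_in = 10_000, 10, 10
W = rng.normal(size=(T, 1, n_in))                   # teachers W_tau ~ N(0, I)
X = rng.uniform(-1.0, 1.0, size=(T, n_in, N + 1))   # N context inputs + 1 test input
Y = W @ X                                           # noiseless targets

# One GD step from W0 = 0 predicts y_hat = eta * g, with g = (1/N) sum_i y_i x_i^T x_test.
g = (Y[..., :N] @ np.swapaxes(X[..., :N], 1, 2) @ X[..., N:])[:, 0, 0] / N
y_test = Y[:, 0, N]


def meta_loss(eta):
    """L(eta): squared test error averaged over the T tasks."""
    return np.mean((eta * g - y_test) ** 2)


# L(eta) is quadratic in eta, so the line-search optimum has a closed form.
eta_star = np.sum(g * y_test) / np.sum(g * g)

# A plain grid line search, for comparison.
etas = np.linspace(0.0, 2.0 * eta_star, 201)
eta_grid = etas[np.argmin([meta_loss(e) for e in etas])]
```

The closed form only applies because a single step from $W_0 = 0$ makes the prediction linear in $\eta$; for several steps a numerical line search is needed.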

More concretely, to compare trained and constructed LSA layers, we sample $T_\text{val} = 10^4$ validation tasks and record the following quantities, averaged over validation tasks: (1) the difference in predictions measured with the L2 norm, $\|\hat{y}_\theta(x_{\tau,\text{test}}) - \hat{y}_{\theta_\text{GD}}(x_{\tau,\text{test}})\|$, (2) the cosine similarity between the sensitivities $\frac{\partial \hat{y}_{\theta_\text{GD}}(x_{\tau,\text{test}})}{\partial x_\text{test}}$ and $\frac{\partial \hat{y}_\theta(x_{\tau,\text{test}})}{\partial x_\text{test}}$, and (3) their difference $\left\| \frac{\partial \hat{y}_{\theta_\text{GD}}(x_{\tau,\text{test}})}{\partial x_\text{test}} - \frac{\partial \hat{y}_\theta(x_{\tau,\text{test}})}{\partial x_\text{test}} \right\|$, again measured with the L2 norm. Since both models are linear in $x_\text{test}$, these sensitivities are exactly the explicit models computed by the two algorithms. We show the results of these comparisons in Figure 2. We find excellent agreement between the two models over a wide range of hyperparameters. We note that, as we do not have direct access to the initialization of $W$ in the attention-based learners (it is hidden in $\theta$), we cannot expect the models to agree exactly.
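The three alignment quantities can be sketched for a single task as follows, assuming both learners are given as NumPy callables mapping $x_\text{test}$ to a prediction. Sensitivities are obtained by central finite differences, which are exact up to rounding for the linear models considered here. The helper names are hypothetical, and in the experiments these numbers are averaged over validation tasks.

```python
import numpy as np


def sensitivity(f, x, h=1e-3):
    """d f(x) / d x by central differences; returns an (n_out, n_in) matrix."""
    cols = [(f(x + h * e) - f(x - h * e)) / (2 * h) for e in np.eye(x.size)]
    return np.stack(cols, axis=-1)


def alignment(f, g, x_test):
    """(1) prediction difference, (2) cosine similarity and (3) L2 difference of sensitivities."""
    pred_diff = np.linalg.norm(f(x_test) - g(x_test))
    Jf = sensitivity(f, x_test).ravel()
    Jg = sensitivity(g, x_test).ravel()
    cos = Jf @ Jg / (np.linalg.norm(Jf) * np.linalg.norm(Jg))
    sens_diff = np.linalg.norm(Jf - Jg)
    return pred_diff, cos, sens_diff


rng = np.random.default_rng(0)
A = rng.normal(size=(1, 10))
x_test = rng.uniform(-1, 1, size=10)
same = alignment(lambda x: A @ x, lambda x: A @ x, x_test)
rescaled = alignment(lambda x: A @ x, lambda x: 2 * A @ x, x_test)
```

The rescaled example illustrates that the cosine similarity is blind to a pure rescaling of the model, which the norm difference still picks up.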

Although the above metrics are important to show similarities between the resulting learned models (in-context vs. gradient-based), the underlying algorithms could still be different. We therefore carry out an extended set of analyses:

Interpolation. We take inspiration from recent work (Benzing et al., 2022; Entezari et al., 2021) that showed approximate equivalence of models found by SGD after permuting weights within the trained neural networks. Since our models are deep linear networks with respect to $x_\text{test}$, we only correct for scaling mismatches between the two models, in this case the construction that implements GD and the trained weights. As shown in Figure 2, we observe (and can actually inspect by eye, see Appendix Figure 9) that a simple scaling correction on the trained weights is enough to recover the weight construction implementing GD. Consequently, GD, the trained Transformer and the linearly interpolated weights $\theta_\text{I} = (\theta + \theta_\text{GD})/2$ reach an identical loss. See Appendix A.3 for details on how our weight correction and interpolation are obtained.

Out-of-distribution validation tasks. To test whether our in-context learner has found a generalizable update rule, we investigate how GD, the trained LSA layer and its interpolation behave when provided with in-context data from regimes different from those used during training. We therefore visualize the loss increase when (1) sampling the input data from $U(-\alpha, \alpha)^{N_x}$ or (2) scaling the teacher weights as $\alpha W$ when sampling validation tasks. In both cases, we set $\alpha = 1$ during training. We again observe that, for both interventions, a single trained linear self-attention layer performs on par with gradient descent outside of the training setup, see Figure 2 as well as Appendix Figure 6. Note that the loss obtained through gradient descent also starts degrading quickly outside the training regime: since we tune the learning rate for the input range $[-1, 1]$ and one gradient step, tasks with a larger input range have higher curvature, so the learning rate that is optimal for the smaller range leads to divergence and a drastic increase in loss for GD as well.

Repeating the LSA update. Since we claim that a single trained LSA layer implements a GD-like learning rule, we further test its behavior when applying it repeatedly, not only once as in training. After correcting the learning rate of both algorithms, i.e., of GD and of the trained Transformer, with a dampening parameter $\lambda = 0.75$ (details in Appendix A.6), we see an identical loss decrease for GD and the Transformer, see Figure 1.
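Repeating the construction can be sketched in NumPy, assuming $W_0 = 0$, a query token $(x_\text{test}, 0)$ and a hypothetical `make_theta_gd` for $\theta_\text{GD}$ (our reading of the construction). Applying the same layer $K$ times turns the context $y$-entries into GD residuals $y_i - W_K x_i$ and reproduces $K$ steps of plain GD at the test point. This illustrates the constructed layer only; it does not model the trained layer or the dampening of Appendix A.6.

```python
import numpy as np


def make_theta_gd(n_in, n_out, eta, N):
    """Hypothetical theta_GD with W0 = 0: keys/queries read x, values read -y, P = eta/N * I."""
    d = n_in + n_out
    W_K = np.zeros((d, d))
    W_K[:n_in, :n_in] = np.eye(n_in)
    W_V = np.zeros((d, d))
    W_V[n_in:, n_in:] = -np.eye(n_out)
    return eta / N * np.eye(d), W_V, W_K, W_K.copy()


def lsa(E, theta, n_ctx):
    P, W_V, W_K, W_Q = theta
    ctx = E[:, :n_ctx]
    return E + P @ (W_V @ ctx) @ ((W_K @ ctx).T @ (W_Q @ E))


rng = np.random.default_rng(0)
n_in, n_out, N, eta, K = 4, 1, 20, 0.5, 5
X = rng.uniform(-1, 1, size=(n_in, N))
Y = rng.normal(size=(n_out, n_in)) @ X
x_test = rng.uniform(-1, 1, size=(n_in, 1))

# K applications of the same constructed layer; the query token starts as (x_test, 0).
E = np.hstack([np.vstack([X, Y]), np.vstack([x_test, np.zeros((n_out, 1))])])
theta = make_theta_gd(n_in, n_out, eta, N)
for _ in range(K):
    E = lsa(E, theta, n_ctx=N)
y_hat_tf = -E[n_in:, N]

# K steps of plain GD on L(W) starting from W0 = 0.
W = np.zeros((n_out, n_in))
for _ in range(K):
    W = W - eta / N * (W @ X - Y) @ X.T
y_hat_gd = (W @ x_test)[:, 0]
```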

To conclude, we present evidence that optimizing a single LSA layer to solve linear regression tasks finds weights that (approximately) coincide with the LSA-layer weight construction of Proposition 1, hence implementing a step of gradient descent, leading to the same learning capabilities on in- and out-of-distribution tasks. We comment on the seed-dependent phase transition of the loss during training in Appendix A.11.

Multiple steps of gradient descent vs. multiple layers of self-attention

We now turn to deep linear self-attention-only Transformers. The construction we put forth in Proposition 1 can be immediately stacked over $K$ layers; in this case, the final prediction can be read out from the last layer as before by negating the $y$-entry of the last test token: $-\bigl(y_{N+1} - \sum_{k=1}^K \Delta y_{k,N+1}\bigr) = W_0 x_{N+1} + \sum_{k=1}^K \Delta W_k x_{N+1}$, where $y_{N+1} = -W_0 x_{N+1}$ is the initial $y$-entry of the test token, $\Delta y_{k,N+1} = \Delta W_k x_{N+1}$ is the amount subtracted from that $y$-entry by the $k$-th self-attention layer, and $\Delta W_k$ is the $k$-th implicit change in the underlying linear model parameters $W$. When optimizing such Transformers with $K$ layers, we observe that these models generally outperform $K$ steps of plain gradient descent, see Figure 3. Their behavior is, however, well described by a variant of gradient descent for which we tune a single parameter $\gamma$, defined through the transformation function $H(X)$ that transforms the input data according to $x_j \leftarrow H(X) x_j$, with $H(X) = I - \gamma X X^T$. We term this gradient descent variant GD++ and explain and analyze it in Appendix A.10.
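A minimal NumPy sketch of GD++ as token dynamics, under our own assumptions about details deferred to Appendix A.10: $W_0 = 0$, one shared $\gamma$ and $\eta$ for all layers, and both the target update and the input transformation $H(X) = I - \gamma X X^T$ computed from the tokens before each layer. `gd_plus_plus` is a hypothetical helper, and the $\gamma$ used below is arbitrary rather than tuned.

```python
import numpy as np


def gd_plus_plus(X, Y, x_test, eta, gamma, K):
    """Sketch of K layers of GD++ token dynamics starting from W0 = 0.

    Each layer (i) takes a GD step on the current residual targets R with the current
    inputs and (ii) transforms all inputs, including x_test, by H(X) = I - gamma X X^T.
    Both updates use the tokens from before the layer; gamma = 0 recovers plain GD.
    """
    N = X.shape[1]
    X, R, x_t = X.copy(), Y.copy(), x_test.copy()
    y_t = np.zeros((Y.shape[0], 1))          # y-entry of the query token
    for _ in range(K):
        dW = eta / N * R @ X.T               # GD step on 1/(2N) sum ||W x_i - r_i||^2 at W = 0
        H = np.eye(X.shape[0]) - gamma * X @ X.T
        R, y_t = R - dW @ X, y_t - dW @ x_t
        X, x_t = H @ X, H @ x_t
    return -y_t[:, 0]                        # read-out by negating the y-entry


rng = np.random.default_rng(0)
n_in, N, eta, K = 5, 25, 0.5, 4
X = rng.uniform(-1, 1, size=(n_in, N))
Y = rng.normal(size=(1, n_in)) @ X
x_test = rng.uniform(-1, 1, size=(n_in, 1))

W = np.zeros((1, n_in))                      # reference: K plain GD steps
for _ in range(K):
    W = W - eta / N * (W @ X - Y) @ X.T
y_plain = gd_plus_plus(X, Y, x_test, eta, gamma=0.0, K=K)
y_pp = gd_plus_plus(X, Y, x_test, eta, gamma=0.01, K=K)
```

Setting `gamma=0.0` reproduces plain GD exactly, which is a useful sanity check before tuning $\gamma$.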

To analyze the effect of adding more layers to the architecture, we first turn to arguably the simplest extension of a single SA layer and analyze a recurrent, or looped, 2-layer LSA model. Here, we simply apply the same layer (with the same weights) repeatedly, drawing the analogy to learning an iterative algorithm that applies the same logic multiple times.

Somewhat surprisingly, we find that the trained model surpasses plain gradient descent, which also results in decreasing alignment between the two models (see center left column), while the recurrent Transformer realigns perfectly with GD++ and matches its performance on in- and out-of-distribution tasks. Again, we can interpolate between the Transformer weights found by optimization and the LSA-weight construction with learned $\eta, \gamma$, see Figures 3 and 6.

We next consider deeper, non-recurrent 5-layer LSA-only Transformers with different parameters per layer (i.e., no weight tying). We see that a different GD learning rate as well as a different $\gamma$ per step (layer) need to be tuned to match the Transformer performance. This slight modification again leads to almost perfect alignment between the trained TF and GD++, with, in this case, 10 additional parameters and a loss close to 0, see Figure 3. Nevertheless, the naive correction used for model interpolation in the aforementioned experiments is not enough to interpolate without a loss increase. We leave a search for better weight corrections to future work. We further study Transformers of different depths, with recurrent as well as non-recurrent architectures, multiple heads and MLPs, and find qualitatively equivalent results, see Appendix Figure 7 and Figure 8. Additionally, in Appendix A.9, we provide results obtained when using softmax SA layers as well as LayerNorm, thus essentially retrieving the standard Transformer architecture. We again observe, and are able to explain (after slight architectural modifications), good learning performance as well as alignment with the construction of Proposition 1, though to a lesser degree than when using linear self-attention. These findings suggest that the in-context learning abilities of the standard Transformer with these common architecture choices can be explained by the gradient-based learning hypothesis explored here. Our findings also question the ubiquitous use of softmax attention, and suggest that further investigation is warranted into the performance of linear vs. softmax SA layers in real-world learning tasks, as initiated by Schlag et al. (2021).

Transformers solve nonlinear regression tasks by gradient descent on deep data representations

It is unreasonable to assume that the astonishing in-context learning flexibility observed in large Transformers is explained by gradient descent on linear models. We now show that this limitation can be resolved by incorporating one additional element of fully-fledged Transformers: preceding self-attention layers with MLPs enables learning linear models by gradient descent on deep representations, which motivates our illustration in Figure 1. Empirically, we demonstrate this by solving nonlinear sine-wave regression tasks, see Figure 4. Experimental details can be found in Appendix A.7. We state the following proposition:

Proposition 2. Given a Transformer block, i.e., an MLP $m(e)$ that transforms the tokens $e_j = (x_j, y_j)$, followed by an attention layer, we can construct weights that lead to gradient descent dynamics descending $\frac{1}{2N} \sum_{i=1}^N \|W m(x_i) - y_i\|^2$. Iteratively applying Transformer blocks can therefore solve kernelized least-squares regression problems with kernel function $k(x, y) = m(x)^\top m(y)$ induced by the MLP $m(\cdot)$.
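The kernel reading of this statement can be checked numerically, assuming a fixed random tanh feature map as a hypothetical stand-in for the MLP $m(\cdot)$ and $W_0 = 0$: one GD step on the features gives the same test prediction as the kernel expression $\hat{y} = \frac{\eta}{N} \sum_i y_i\, k(x_i, x_\text{test})$ with $k(x, x') = m(x)^\top m(x')$. The feature width, input range and learning rate are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_feat, N, eta = 1, 64, 30, 0.1
A = rng.normal(size=(n_feat, n_in))
b = rng.normal(size=(n_feat, 1))


def m(x):
    """Hypothetical stand-in for the MLP: fixed random tanh features of x with shape (n_in, n)."""
    return np.tanh(A @ x + b)


X = rng.uniform(-5, 5, size=(n_in, N))
Y = np.sin(X)                             # a sine-wave regression task
x_test = np.array([[0.7]])

# Weight view: one GD step from W0 = 0 on 1/(2N) sum_i ||W m(x_i) - y_i||^2.
dW = eta / N * Y @ m(X).T
y_weight = dW @ m(x_test)

# Kernel view: the same prediction written with k(x, x') = m(x)^T m(x').
k_test = m(X).T @ m(x_test)               # (N, 1) values k(x_i, x_test)
y_kernel = eta / N * Y @ k_test
```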

A detailed discussion of this form of kernel regression, as well as of kernel smoothing with or without softmax nonlinearity through gradient descent on the data, can be found in Appendix A.8. The way MLPs transform data in Transformers diverges from the standard meta-learning approach, where a task-shared input embedding network is optimized by backpropagation through training to improve the learning performance of a task-specific readout (e.g., Raghu et al., 2020; Lee et al., 2019; Bertinetto et al., 2019). In contrast, given our token construction in Proposition 1, MLPs in Transformers intriguingly process both inputs and targets. The output of this transformation is then processed by a single linear self-attention layer, which, according to our theory, is capable of implementing gradient descent learning. We compare the performance of this Transformer model, where all weights are learned, to a control Transformer where the final LSA weights are set to the construction $\theta_\text{GD}$; this control is therefore identical to training an MLP by backpropagation through a GD-updated output layer.

Intriguingly, the two obtained functions again show surprising similarity in (1) the initial (meta-learned) prediction, read out after the MLP, and (2) the final prediction, after the output of the MLP has been altered through GD or the self-attention layer. This is again reflected in our alignment measures which, since the obtained models are now nonlinear w.r.t. $x_\text{test}$, only capture the first two terms of the Taylor expansion of the obtained functions. Our results serve as a first demonstration of how MLPs and self-attention layers can interplay to support nonlinear in-context learning, allowing deep data representations to be fine-tuned by gradient descent. Investigating the interplay between MLPs and SA layers in deep TFs is left for future work.

Do self-attention layers build regression tasks?

The construction provided in Proposition 1 and the previous experimental section relied on a token structure where both input and output data are concatenated into a single token. This design differs from the way tokens are typically built in most of the related work dealing with simple few-shot learning problems, as well as in, e.g., language modeling. We therefore ask: can we overcome the assumption required in Proposition 1 and allow a Transformer to build the required token construction on its own? This motivates the following proposition:

Proposition 3. Given a 1-head linear or softmax attention layer and the token construction $e_{2j} = (x_j)$, $e_{2j+1} = (0, y_j)$, with a zero vector of dimension $N_x - N_y$ and concatenated positional encodings, one can construct key, query and value matrices $W_K, W_Q, W_V$ as well as the projection matrix $P$ such that all tokens $e_j$ are transformed into tokens equivalent to the ones required in Proposition 1.
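To illustrate what such a copying layer can do, here is a NumPy sketch of one hypothetical softmax attention head (our own construction, not the one in Appendix A.5). With one-hot positional encodings, each token attends to its right neighbour, so every $x$-token $e_{2j}$ receives $y_j$ in a dedicated slot and becomes equivalent to $(x_j, y_j)$. Tokens are indexed from 0, and the large inverse temperature `beta` that makes the softmax nearly one-hot is an assumption.

```python
import numpy as np


def softmax(z, axis=0):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
n_x, n_y, J, beta = 3, 1, 4, 50.0
T = 2 * J
X = rng.normal(size=(n_x, J))
Y = rng.normal(size=(n_y, J))

# Token layout (columns): [content (n_x) | empty y-slot (n_y) | one-hot position (T)].
D = n_x + n_y + T
E = np.zeros((D, T))
E[:n_x, 0::2] = X                      # e_{2j}   = (x_j)
E[n_x - n_y:n_x, 1::2] = Y             # e_{2j+1} = (0, y_j), zero-padded to n_x entries
E[n_x + n_y:, :] = np.eye(T)           # positional encodings

shift = np.eye(T, k=-1)                # maps one-hot p_t to p_{t+1}
W_K = np.zeros((T, D))
W_K[:, n_x + n_y:] = np.eye(T)         # keys    k_u = p_u
W_Q = np.zeros((T, D))
W_Q[:, n_x + n_y:] = beta * shift      # queries q_t = beta * p_{t+1}
W_V = np.zeros((D, D))
W_V[n_x:n_x + n_y, n_x - n_y:n_x] = np.eye(n_y)   # copy the y part into the slot

A = softmax((W_K @ E).T @ (W_Q @ E), axis=0)      # A[u, t]: weight of key u for query t
E_new = E + (W_V @ E) @ A
merged = E_new[:n_x + n_y, 0::2]                  # columns approximately (x_j, y_j)
```

The merged even tokens can then be fed to a layer implementing $\theta_\text{GD}$, which is the two-layer picture tested in the experiment below.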

The construction and its discussion can be found in Appendix A.5. To provide evidence that copying is performed in trained Transformers, we optimize a two-layer self-attention circuit on in-context data where alternating tokens include input or output data, i.e., $e_{2j} = (x_j)$ and $e_{2j+1} = (0, y_j)$. We again measure the loss, as well as the mean norm of the partial derivative of the first layer's output w.r.t. the input tokens, during training, see Figure 5. First, training speed varies strongly across training seeds, as also reported by Garg et al. (2022). Nevertheless, the Transformer is able to match the performance of a single (not two) step of gradient descent. Interestingly, before the Transformer performance jumps to that of GD, a token $e_j$ transformed by the first self-attention layer becomes notably dependent on the neighboring token $e_{j+1}$ while staying independent of the others, which we denote as $e_\text{other}$ in Figure 5.

We interpret this as evidence for a copying mechanism of the Transformer’s first layer to merge input and output data into single tokens as required by Proposition 1. Then, in the second layer the Transformer performs a single step of GD. Notably, we were not able to train the Transformer with linear self-attention layers, but had to incorporate the softmax operation in the first layer. These preliminary findings support the study of Olsson et al. (2022) showing that softmax self-attention layers easily learn to copy; we confirm this claim, and further show that such copying allows the Transformer to proceed by emulating gradient-based learning in the second or deeper attention layers.

We conclude that copying through (softmax) attention layers is the second crucial mechanism for in-context learning in Transformers. This operation enables Transformers to merge data from different tokens and then to compute dot products of input and target data downstream, allowing for in-context learning by gradient descent to emerge.

Discussion

Transformers show remarkable in-context learning behavior. Mechanisms based on attention, associative memory and copying by induction heads are currently the leading explanations for this ability to learn within the Transformer forward pass. In this paper, we put forward the hypothesis, similar to Garg et al. (2022) and Akyürek et al. (2023), that Transformers' in-context learning is driven by gradient descent; in short, Transformers learn to learn by gradient descent based on their context. Viewed through the lens of meta-learning, learning the Transformer weights corresponds to the outer loop, which then enables the forward pass to transform tokens by gradient-based optimization.

To provide evidence for this hypothesis, we build on Schlag et al. (2021), who already provide a linear self-attention layer variant with (fast) inner-loop learning by the error-correcting delta rule (Widrow & Hoff, 1960). We diverge from their setting and focus on (in-context) learning, where we specifically construct a dataset by considering neighboring elements in the input sequence as input and target training pairs, see the assumptions of Proposition 1. This construction could be realized, for example, by the model learning to implement a copying layer, see Section 4 and Proposition 3. It allows us to provide a simple construction, different from that of Schlag et al. (2021), that is built solely on the standard linear, and approximately softmax, self-attention layer but still implements gradient-descent-based learning dynamics. We are therefore able to explain gradient-descent-based learning in these standard architectures. Furthermore, we extend this construction beyond a single self-attention layer and explain how deeper K-layer Transformer models implement principled K-step gradient descent learning, which again deviates from Schlag et al. (2021) and allows us to identify that deep Transformers implement GD++, an accelerated version of gradient descent.

We highlight that our construction of gradient descent and GD++ is not merely suggestive: when training multi-layer self-attention-only Transformers on simple regression tasks, we provide strong evidence that the construction is actually found. This allows us, at least in our restricted problem settings, to mechanistically explain in-context learning in trained Transformers and its close resemblance to GD observed by related work. Further work is needed to incorporate regression problems with noisy data and weight regularization into our hypothesis. We speculate that aspects of learning in these settings are meta-learned, e.g., the weight magnitudes to be encoded in the self-attention weights. Additionally, we did not analyze logistic regression, for which one possible weight construction is already presented in Zhmoginov et al. (2022).

Our refined understanding of in-context learning based on gradient descent motivates us to investigate how to improve it. We are excited about several avenues of future research. First, to go beyond a single step of gradient descent in every self-attention layer, it could be advantageous to incorporate so-called declarative nodes (Amos & Kolter, 2017; Bai et al., 2019; Gould et al., 2021; Zucchet & Sacramento, 2022) into Transformer architectures. This way, we would treat a single self-attention layer as the solution of a fully optimized regression loss, possibly leading to more efficient architectures. Second, our findings are restricted to small Transformers and simple regression problems. We are excited to delve deeper into understanding whether, and to what extent, further mechanistic understanding of Transformers and in-context learning is possible in larger models. Third, we are excited about targeted modifications to Transformer architectures, or their training protocols, that lead to improved gradient-descent-based learning algorithms or allow alternative in-context learners to be implemented within Transformer weights, augmenting their functionality, as e.g. in Dai et al. (2023). Finally, it would be interesting to analyze in-context learning in HyperTransformers (Zhmoginov et al., 2022), which produce weights for target networks and already offer a different perspective on merging Transformers and meta-learning. There, Transformers transform weights instead of data and could potentially allow for gradient computations of weights deep inside the target network, lifting the limitation of GD on linear models analyzed here.

João Sacramento and Johannes von Oswald deeply thank Angelika Steger for her support and guidance. The authors also thank Seijin Kobayashi, Marc Kaufmann, Nicolas Zucchet, Yassir Akram, Guillaume Obozinski and Mark Sandler for many valuable insights throughout the project and Dale Schuurmans and Timothy Nguyen for their valuable comments on the manuscript. João Sacramento was supported by an Ambizione grant (PZ00P3_186027) from the Swiss National Science Foundation and an ETH Research Grant (ETH-23 21-1).

References

Appendix A

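First, we highlight the dependency of the linear self-attention operation on the tokens $e_i$:

$$e_j \leftarrow e_j + \sum_h P_h V_h K_h^T q_{h,j} = e_j + \sum_h P_h \sum_i v_{h,i} \otimes k_{h,i}\, q_{h,j} = e_j + \sum_h P_h W_{h,V} \Big(\sum_i e_i \otimes e_i\Big) W_{h,K}^T W_{h,Q}\, e_j,$$

with $\otimes$ the outer product between two vectors. With this we can now easily draw connections to one step of gradient descent on $L(W)=\frac{1}{2N}\sum_{i=1}^{N}\|Wx_i-y_i\|^2$ with learning rate $\eta$, which yields the weight change

$$\Delta W = -\eta\nabla_W L(W) = -\frac{\eta}{N}\sum_{i=1}^{N}(Wx_i-y_i)\,x_i^T.$$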

Given a 1-head linear attention layer and the tokens $e_j=(x_j,y_j)$, for $j=1,\ldots,N$, one can construct key, query and value matrices $W_K,W_Q,W_V$ as well as the projection matrix $P$ such that a Transformer step on every token $e_j$ is identical to the gradient-induced dynamics $e_j\leftarrow(x_j,y_j)+(0,-\Delta Wx_j)=(x_j,y_j)+P\,VK^Tq_j$, such that $e_j=(x_j,y_j-\Delta y_j)$. For the test data token $(x_{N+1},y_{N+1})$ the dynamics are identical.

for every token $e_j=(x_j,y_j)$, including the query token $e_{N+1}=e_{\text{test}}=(x_{\text{test}},-W_0x_{\text{test}})$, which will give us the desired result.
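The construction can be checked numerically. The NumPy sketch below is a minimal illustration, assuming one concrete choice of weights consistent with the statement above: $W_K=W_Q$ selecting the $x$ part, $W_V$ writing $W_0x_i-y_i$ into the $y$ slot, and $P=\frac{\eta}{N}I$. Keys and values are taken from the $N$ context tokens only (an assumption of this sketch, which lets it work for any $W_0$), while queries come from every token including $e_{\text{test}}$. The function names `lsa_gd_step` and `gd_delta_w` are hypothetical.

```python
import numpy as np


def gd_delta_w(X, Y, W0, eta):
    """Delta W = -eta * grad of L(W) = 1/(2N) sum_i ||W x_i - y_i||^2, evaluated at W0."""
    N = X.shape[0]
    grad = (X @ W0.T - Y).T @ X / N          # (1/N) sum_i (W0 x_i - y_i) x_i^T
    return -eta * grad


def lsa_gd_step(X, Y, x_test, W0, eta):
    """One linear self-attention step with Proposition-1-style weights (sketch)."""
    N, nx = X.shape
    ny = Y.shape[1]
    d = nx + ny
    # W_K = W_Q = [[I_x, 0], [0, 0]]: keys and queries only see the x part.
    W_K = np.zeros((d, d))
    W_K[:nx, :nx] = np.eye(nx)
    W_Q = W_K.copy()
    # W_V = [[0, 0], [W0, -I_y]]: the value of token i is (0, W0 x_i - y_i).
    W_V = np.zeros((d, d))
    W_V[nx:, :nx] = W0
    W_V[nx:, nx:] = -np.eye(ny)
    P = (eta / N) * np.eye(d)

    E = np.concatenate([X, Y], axis=1)                 # context tokens e_i = (x_i, y_i)
    e_test = np.concatenate([x_test, -W0 @ x_test])    # query token (x_test, -W0 x_test)
    tokens = np.vstack([E, e_test])                    # shape (N + 1, d)

    K = E @ W_K.T                                      # rows k_i = W_K e_i (context only)
    V = E @ W_V.T                                      # rows v_i = W_V e_i (context only)
    Q = tokens @ W_Q.T                                 # rows q_j = W_Q e_j (all tokens)
    # e_j <- e_j + P sum_i v_i (k_i^T q_j)
    return tokens + (Q @ K.T @ V) @ P.T


rng = np.random.default_rng(0)
N, nx, ny, eta = 10, 10, 1, 0.1
X = rng.uniform(-1, 1, size=(N, nx))
Y = X @ rng.standard_normal((ny, nx)).T                # linear teacher
W0 = rng.standard_normal((ny, nx))                      # arbitrary initial weights
x_test = rng.uniform(-1, 1, size=nx)

out = lsa_gd_step(X, Y, x_test, W0, eta)
dW = gd_delta_w(X, Y, W0, eta)
print(np.allclose(-out[-1, nx:], (W0 + dW) @ x_test))  # True: query holds minus the GD prediction
print(np.allclose(out[:N, nx:], Y - X @ dW.T))         # True: context y slots become y_j - Delta W x_j
```

With these weights the $x$ part of every token is untouched, each context token's $y$ slot becomes $y_j-\Delta Wx_j$, and the query's $y$ slot becomes $-(W_0+\Delta W)x_{\text{test}}$, i.e. minus the prediction of the model after one GD step.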

A.2 Comparing the out-of-distribution behavior of trained Transformers and GD

We provide more experimental results comparing GD, with tuned learning rate $\eta$ and data transformation scalar $\gamma$, to the trained Transformer on data distributions other than the one seen during training, see Figure 6. We do so by changing the in-context data distribution and measuring the loss of both methods, averaged over 10,000 tasks, when changing a scalar $\alpha$ that either 1) affects the input data range $x\sim U(-\alpha,\alpha)^{N_x}$ or 2) scales the teacher to $\alpha W$ with $W\sim\mathcal{N}(0,I)$. This setup leads to the results shown in the main text, in the first two columns of Figure 6 and in the corresponding plots of Figure 7. Although the match for deeper architectures starts to become worse, overall the trained Transformers behave remarkably similarly to GD and GD++ for layer depth greater than 1.

Furthermore, we test GD and the trained Transformer on input distributions never seen during training. Here, we choose, each with probability $1/3$, a normal, exponential or Laplace distribution (with JAX default parameters) and depict the average loss value over 10,000 tasks, where the $\alpha$ value now simply scales the input values $\alpha x$ sampled from one of these distributions. The teacher scaling is identical to the one described above. See the two right columns of Figure 6 for results: we observe almost identical behavior for recurrent architectures, and a worse match for deeper non-recurrent architectures far away from the training range of $\alpha=1$. Note that for deeper Transformers ($K>2$) and the corresponding GD and GD++ versions (see Appendix A.12 for more experimental details), we apply a harsh clipping of the token values between $$ after every transformation step (for the trained TF and GD) to improve training stability. Therefore, the loss increase is bounded and plateaus.
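The two kinds of distribution shift can be sketched as task samplers. The NumPy code below is an illustration only: the sizes $N=10$, $N_x=10$, $N_y=1$ are placeholders rather than values stated in this section, NumPy's standard normal, standard exponential and Laplace(0, 1) samplers stand in for the JAX defaults, and the helper names are hypothetical.

```python
import numpy as np


def sample_shifted_task(rng, alpha_x=1.0, alpha_w=1.0, N=10, nx=10, ny=1):
    """Shift 1): inputs x ~ U(-alpha_x, alpha_x)^{N_x}; shift 2): teacher alpha_w * W, W ~ N(0, I)."""
    W = alpha_w * rng.standard_normal((ny, nx))
    X = rng.uniform(-alpha_x, alpha_x, size=(N + 1, nx))   # N context points + 1 query
    return X, X @ W.T


def sample_unseen_task(rng, alpha=1.0, alpha_w=1.0, N=10, nx=10, ny=1):
    """Inputs from a normal, exponential or Laplace distribution (prob. 1/3 each), scaled by alpha."""
    kind = rng.integers(3)
    if kind == 0:
        X = rng.standard_normal((N + 1, nx))
    elif kind == 1:
        X = rng.standard_exponential((N + 1, nx))
    else:
        X = rng.laplace(0.0, 1.0, size=(N + 1, nx))
    X = alpha * X
    W = alpha_w * rng.standard_normal((ny, nx))            # teacher scaled as in the first sampler
    return X, X @ W.T


rng = np.random.default_rng(0)
X, Y = sample_shifted_task(rng, alpha_x=2.0)
X2, Y2 = sample_unseen_task(rng, alpha=3.0)
print(X.shape, Y.shape, np.abs(X).max() <= 2.0)
```

Averaging a method's loss over many such tasks (10,000 in the text) for each value of $\alpha$ gives one point of an out-of-distribution curve.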

A.3 Linear mode connectivity between the weight construction of Prop 1 and trained Transformers

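In order to interpolate between the construction $\theta_{\text{GD}}$ and the trained weights of the Transformer $\theta$, we need to correct for some scaling ambiguity. For clarification, we restate here the linear self-attention operation for a single head:

$$e_j \leftarrow e_j + PW_V\Big(\sum_i e_ie_i^T\Big)W_K^TW_Q\,e_j = e_j + W_{PV}\Big(\sum_i e_ie_i^T\Big)W_{KQ}\,e_j,$$

with $W_{KQ}=W_K^TW_Q$ and $W_{PV}=PW_V$.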

Now, to match the weight construction of Prop. 1, we aim for the matrix product $W_{KQ}$ to match an identity matrix (except for the last diagonal entry) after rescaling. We therefore compute the mean of the diagonal of the trained Transformer's matrix product $W_{KQ}$, which we denote by $\beta$. After rescaling both matrix products, i.e. $W_{KQ}\leftarrow W_{KQ}/\beta$ and $W_{PV}\leftarrow W_{PV}\beta$, which leaves the layer's output unchanged, we interpolate linearly between the matrix products of GD and these rescaled trained matrix products, i.e. $W_{I,KQ}=(W_{GD,KQ}+W_{TF,KQ})/2$ and $W_{I,PV}=(W_{GD,PV}+W_{TF,PV})/2$. We use these parameters to obtain the results denoted Interpolated throughout the paper. We do so for GD as well as GD++ when comparing to recurrent Transformers. Note that for non-recurrent Transformers we face more ambiguity to correct for since, e.g., scalings influence each other across layers. We also see this in practice: with our simple correction from above, we are able to interpolate between weights only for some seeds. We leave the search for more elaborate corrections for future work.
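The rescaling step works because a single linear head depends on $W_{PV}$ and $W_{KQ}$ only through the product $W_{PV}\big(\sum_i e_ie_i^T\big)W_{KQ}$, so dividing one factor by $\beta$ and multiplying the other by $\beta$ cancels. A minimal NumPy sketch of the rescaling and the midpoint interpolation follows; the function names are hypothetical, random matrices stand in for trained weights, and the GD matrices assume $W_0=0$ with a hypothetical $\eta/N=0.1$.

```python
import numpy as np


def lsa_single_head(E, W_KQ, W_PV):
    """e_j <- e_j + W_PV (sum_i e_i e_i^T) W_KQ e_j for every row e_j of E."""
    M = E.T @ E                                   # sum_i e_i e_i^T
    return E + E @ W_KQ.T @ M @ W_PV.T


def rescale(W_KQ, W_PV):
    """Divide W_KQ by beta and multiply W_PV by beta; the layer output is unchanged."""
    beta = np.mean(np.diag(W_KQ))                 # mean of the diagonal of the trained W_KQ
    return W_KQ / beta, W_PV * beta


def interpolate(W_gd, W_tf):
    """Midpoint of the linear path between the construction and the rescaled trained matrix."""
    return (W_gd + W_tf) / 2


rng = np.random.default_rng(0)
d = 4
E = rng.standard_normal((6, d))                   # 6 random tokens of dimension d
W_KQ_tf = 3.0 * np.eye(d) + 0.01 * rng.standard_normal((d, d))   # stand-ins for trained weights
W_PV_tf = 0.05 * rng.standard_normal((d, d))
W_KQ_r, W_PV_r = rescale(W_KQ_tf, W_PV_tf)

W_KQ_gd = np.diag([1.0, 1.0, 1.0, 0.0])           # identity except the last diagonal entry
W_PV_gd = np.zeros((d, d))
W_PV_gd[-1, -1] = -0.1                            # -eta/N in the y slot (hypothetical value)
W_KQ_i, W_PV_i = interpolate(W_KQ_gd, W_KQ_r), interpolate(W_PV_gd, W_PV_r)
```

Without the rescaling, the trained $W_{KQ}$ here would sit at roughly three times the scale of the construction, and the straight-line path between the two would pass through weights with a different effective learning rate.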

A.4 Visualizing the trained Transformer weights

The simplicity of our construction enables us to visually compare trained Transformers and the construction put forward in Proposition A.1 in weight space. As discussed in the previous section A.3, there is redundancy in the way the trained Transformer can construct the matrix products leading to the weights corresponding to gradient descent. We therefore visualize $W_{KQ}=W_K^TW_Q$ as well as $W_{PV}=PW_V$ in Figure 9.

A.5 Proof and discussion of Proposition 3

We state here again Proposition 3, provide the necessary construction and a short discussion.

Given a 1-head linear or softmax attention layer and the token construction $e_{2j}=(x_j)$, $e_{2j+1}=(0,y_j)$ with a zero vector of dimension $N_x-N_y$ and concatenated positional encodings, one can construct key, query and value matrices $W_K,W_Q,W_V$ as well as the projection matrix $P$ such that all tokens $e_j$ are transformed into tokens equivalent to the ones required in Proposition 1.

This means that a token replaces its own positional encoding by copying the target data of the next token to itself, leading to $e_j=(x_{j/2},0,y_{j/2+1})$, with slight abuse of notation. This can be realized, for example, by setting $P=I$,

$$W_V=\begin{pmatrix}0&0\\ I_x&-I_{x,\text{off}}\end{pmatrix},\quad W_K=\begin{pmatrix}0&0\\ 0&I_x\end{pmatrix},\quad W_Q=\begin{pmatrix}0&0\\ 0&I_{x,\text{off}}^T\end{pmatrix},$$

with $I_{x,\text{off}}$ the lower-diagonal identity matrix of size $N_x$. Note that then simply $K^TW_Qe_j=p_{j+1}$, i.e. it chooses the $(j+1)$-th element of $V$, which stays $p_{j+1}$ if we apply the softmax operation on $K^Tq_j$. Since the $(j+1)$-th entry of $V$ is $(0,y_{j/2+1}-p_j)$, we obtain the desired result.
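The copying mechanism can be illustrated with a single softmax head that attends from each token to its successor through shifted positional encodings. The NumPy sketch below is a simplified variant, not the exact matrices above: it assumes one-hot positional encodings, keeps them in place, and writes the copied target into a dedicated $y$ slot rather than overwriting the positional encoding. The inverse temperature `beta` and the function names are hypothetical.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)      # numerically stable softmax
    w = np.exp(z)
    return w / w.sum(axis=axis, keepdims=True)


def copy_layer(X, Y, beta=50.0):
    """Merge alternating tokens (x_j) and (y_j) into (x_j, y_j) with one softmax head.

    Hypothetical layout: each token is (x-slot, y-slot, one-hot position p).
    """
    N, nx = X.shape
    ny = Y.shape[1]
    T = 2 * N
    tokens = np.zeros((T, nx + ny + T))
    tokens[0::2, :nx] = X                        # e_{2j}   = (x_j, 0, p_{2j})
    tokens[1::2, nx:nx + ny] = Y                 # e_{2j+1} = (0, y_j, p_{2j+1})
    tokens[:, nx + ny:] = np.eye(T)              # concatenated one-hot positional encodings

    pos = tokens[:, nx + ny:]
    shift = np.eye(T, k=-1)                      # shift @ p_j = p_{j+1}
    keys = pos                                   # k_i = p_i
    queries = pos @ shift.T                      # q_j = p_{j+1}
    values = np.zeros_like(tokens)
    values[:, nx:nx + ny] = tokens[:, nx:nx + ny]   # v_i carries only the y slot

    attn = softmax(beta * queries @ keys.T, axis=-1)   # row j is close to one-hot at j + 1
    return tokens + attn @ values


rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(5, 3))
Y = rng.uniform(-1, 1, size=(5, 1))
out = copy_layer(X, Y)
merged = out[0::2, :4]                           # even tokens now hold (x_j, y_j)
```

A large inverse temperature makes the softmax almost a hard selection of the successor token, which is one way to read the stronger copying bias of softmax attention discussed in Section 4.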

For the (toy) regression problems considered in this manuscript, the provided result would give $N/2$ tokens for which we also copy (parts of) $x_j$ underneath $y_j$. This is desired for modalities such as language, where every two tokens could be considered an input-output pair for the implicit autoregressive inner-loop loss. These tokens do not necessarily have to be next to each other; see the experimental findings of Olsson et al. (2022) for this behavior. For the experiments conducted here, one solution is to zero out these tokens, which could be constructed by a two-head self-attention layer that, given odd $j$, simply subtracts the token from itself, resulting in a zero token. For all even tokens, we use the construction from above, which effectively coincides with the token construction required in Proposition 1.

A.6 Dampening the self-attention layer

As an additional out-of-distribution experiment, we test the behavior when repeating a single LSA layer trained to lower our objective, see equation 5, with the aim of repeating the learned learning/update rule. Note that GD as well as the self-attention layer were optimized to be optimal for one step. For GD we line-search the optimal learning rate $\eta$ on 10,000 tasks. Interestingly, for both methods we observe quick divergence when they are applied multiple times, see left plot of Figure 10. Nevertheless, both of our update functions are described by a linear self-attention layer whose norm we can control, post training, by a simple scalar which we denote as $\lambda$. This results in the new update $y_{\text{test}}+\lambda\Delta Wx_{\text{test}}$ for GD and $y_{\text{test}}+\lambda PVK^TW_Qx_{\text{test}}$ for the trained self-attention layer, which effectively re-tunes the learning rate for GD and the trained self-attention layer. Intriguingly, both methods generalize similarly well (or poorly) on this out-of-distribution experiment when changing $\lambda$, see again Figure 10. We show in Figure 1 the behavior for $\lambda=0.75$, for which we see both methods steadily decreasing the loss within 50 steps.
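Because the damping factor $\lambda$ multiplies the whole update, it acts as a rescaled learning rate $\lambda\eta$. The NumPy sketch below illustrates this for plain GD on a least-squares task, repeating the same update 50 times from $W_0=0$. The learning rate $\eta=2.4/\mu_{\max}$, with $\mu_{\max}$ the largest eigenvalue of $\frac{1}{N}\sum_ix_ix_i^T$, is a hypothetical choice made to be too large for repeated application; it is not the line-searched value from the text, and the function name is hypothetical.

```python
import numpy as np


def repeated_gd_losses(X, Y, eta, lam, steps=50):
    """Apply the damped GD update W <- W + lam * Delta W `steps` times, starting from W_0 = 0."""
    N = X.shape[0]
    W = np.zeros((Y.shape[1], X.shape[1]))
    losses = []
    for _ in range(steps):
        delta_w = -eta * (X @ W.T - Y).T @ X / N        # one GD step on 1/(2N) sum ||W x_i - y_i||^2
        W = W + lam * delta_w                            # damping rescales the learning rate to lam * eta
        losses.append(0.5 * np.mean(np.sum((X @ W.T - Y) ** 2, axis=1)))
    return np.array(losses)


rng = np.random.default_rng(0)
N, nx = 10, 10
X = rng.uniform(-1, 1, size=(N, nx))
Y = X @ rng.standard_normal((1, nx)).T
mu_max = np.linalg.eigvalsh(X.T @ X / N).max()
eta = 2.4 / mu_max                                       # deliberately too large for repeated use
initial_loss = 0.5 * np.mean(np.sum(Y ** 2, axis=1))     # loss at W_0 = 0
undamped = repeated_gd_losses(X, Y, eta, lam=1.0)        # effective step 2.4 / mu_max > 2 / mu_max
damped = repeated_gd_losses(X, Y, eta, lam=0.75)         # effective step 1.8 / mu_max < 2 / mu_max
```

For a quadratic loss, repeated GD converges only if the effective step is below $2/\mu_{\max}$; with $\lambda=1$ the error along the top eigendirection is multiplied by $-1.4$ per step and blows up, while $\lambda=0.75$ brings every factor inside $(-1,1)$.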

A.7 Sine wave regression

For the sine wave regression tasks, we follow Finn et al. (2017) and other meta-learning literature and sample for each task an amplitude $a\sim U(0.1,5)$ and a phase $\rho\sim U(0,\pi)$. Each task consists of $N=10$ data points, where inputs are sampled as $x\sim U(-5,5)$ and targets are computed by $y=a\sin(\rho+x)$. Here, for the first time, we choose, for GD as well as for the Transformer, an input embedding emb that maps tokens $e_i=(x_i,y_i)$ into a 40-dimensional space, $\text{emb}(e_i)=W_{\text{emb}}e_i$, through a linear projection without bias. We skip the first self-attention layer but, as usually done in Transformers, then transform the embedded tokens through an MLP $m$ with a single hidden layer, widening factor of 4 (160 hidden neurons) and GELU nonlinearity (Hendrycks & Gimpel, 2016), i.e. $e_j\leftarrow m(\text{emb}(e_j))+\text{emb}(e_j)$.
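The task sampler and token embedding described above can be sketched in NumPy. This is an illustration under assumptions not stated in the text: the MLP has biases, GELU uses the common tanh approximation, the weights are randomly initialized rather than meta-trained, the last of the $N$ points serves as the query, and all names (`sample_sine_task`, `embed_tokens`, `params`) are hypothetical.

```python
import numpy as np


def sample_sine_task(rng, N=10):
    """Amplitude a ~ U(0.1, 5), phase rho ~ U(0, pi), inputs x ~ U(-5, 5), targets a sin(rho + x)."""
    a = rng.uniform(0.1, 5.0)
    rho = rng.uniform(0.0, np.pi)
    x = rng.uniform(-5.0, 5.0, size=N)
    return x, a * np.sin(rho + x)


def gelu(z):
    """tanh approximation of the GELU nonlinearity."""
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z ** 3)))


def embed_tokens(x, y, params):
    """e_i = (x_i, y_i) -> emb(e_i) + m(emb(e_i)), applied to each token independently."""
    E = np.stack([x, y], axis=1)                          # tokens, shape (N, 2)
    h = E @ params["W_emb"].T                             # emb(e_i) = W_emb e_i, no bias: (N, 40)
    hidden = gelu(h @ params["W1"].T + params["b1"])      # hidden layer, 4 x 40 = 160 units
    return h + hidden @ params["W2"].T + params["b2"]     # residual MLP output, (N, 40)


rng = np.random.default_rng(0)
d, width = 40, 160
params = {
    "W_emb": rng.normal(0, 1 / np.sqrt(2), size=(d, 2)),
    "W1": rng.normal(0, 1 / np.sqrt(d), size=(width, d)),
    "b1": np.zeros(width),
    "W2": rng.normal(0, 1 / np.sqrt(width), size=(d, width)),
    "b2": np.zeros(d),
}
x, y = sample_sine_task(rng)
y_in = y.copy()
y_in[-1] = 0.0                                            # the query token carries y_test = 0
tokens = embed_tokens(x, y_in, params)
```

Each token is embedded without looking at the others, which matches the remark that the embedding does not depend on an interplay between tokens; the last of the 40 output dimensions would then be read as the transformed target.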

We interpret the last entry of the transformed tokens as the (transformed) targets and the rest as a higher-dimensional input data representation on which we train a model with a single gradient descent step. We compare the obtained meta-learned GD solution with training a Transformer on the same token embeddings but instead learning a self-attention layer. Note that the embeddings of the tokens, including the transformation through the MLP, do not depend on an interplay between the tokens. Furthermore, the initial transformation depends on $e_i=(x_i,y_i)$, i.e., on the input as well as on the target data, except for the query token, for which $y_{\text{test}}=0$. This means that this construction is, except for the additional dependency on targets, close to a large corpus of meta-learning literature that aims to find a deep representation optimized for (fast) fine-tuning and few-shot learning. In order to compare the meta-training of the MLP and the Transformer, we use the same seed to initialize the network weights of the MLPs and the input embedding, which are trained either by meta-learning (i.e. backpropagation through training) or as part of the Transformer. This leads to the plots and the almost identical learned initial and updated functions shown in Figure 4.

A.8 Proposition 2 and connections between gradient descent, kernelized regression and kernel smoothing

Interestingly, for the test token $e_{\text{test}}=(x_{\text{test}},0)$ this induces, after a multiplication with $-1$, an initial prediction after a single Transformer block given by

A.9 Linear vs. softmax self-attention as well as LayerNorm Transformers

Although linear Transformers and their variants have been shown to be competitive with their softmax counterpart (Irie et al., 2021), the removal of this nonlinearity is still a major departure from classic Transformers and more importantly from the Transformers used in related studies analyzing in-context learning. In this section we investigate whether and when gradient-based learning emerges in trained softmax self-attention layers, and we provide an analytical argument to back our findings.

First, we show, see Figure 12, that a single layer of softmax self-attention is not able to match GD performance. We tuned the learning rate as well as the weight initialization but found no significant difference over the hyperparameters we used throughout this study. In general, we hypothesize that GD is an optimal update given the limited capacity of a single layer of (single-head) self-attention. We therefore argue that the softmax induces (at best) a linear offset of the matrix product of training data and query vector

proportional to a factor dependent on all $\{x_{\tau,i}\}_{i=1}^{N+1}$. We speculate that the dependency on the specific task $\tau$ vanishes for large $N_x$, or that the $x$-dependent value matrix could introduce a correcting effect. In this case the softmax operation introduces an additive error w.r.t. the optimal GD update. To overcome this disadvantageous offset, the Transformer can (approximately) introduce a correction with a second self-attention head by a simple subtraction, i.e.

Here we assume that $PV$ 1) subsumes the dividing factor of the softmax and 2) is the same (up to scaling) for each head. Note that if $(W_{1,KQ}-W_{2,KQ})$ is diagonal, and $P$ and $V$ are chosen as in the Proposition of Appendix A.1, we recover our gradient descent construction.

We base this derivation on empirical findings, see Figure 12. First, they show that the softmax self-attention performance increases drastically when using two heads instead of one. Nevertheless, the self-attention layer still has difficulty matching the loss values of a model trained with GD. Furthermore, this architecture change leads to a much improved alignment between the trained model and GD. Second, we observe that when training a two-headed softmax self-attention layer on regression tasks, the correction proposed above actually appears in weight space, see Figure 11. Here, we visualize the matrix product $W_{h,KQ}$ within the softmax operation per head, scaled by the last diagonal entry of $P_hW_{h,V}$, which we denote by $\eta_h=P_hW_{h,V}(-1,-1)$. Intriguingly, this results in an almost perfect cancellation (right plot) of the off-diagonal terms and therefore, in sum, in an improved approximation of our construction, see the derivation above.

We would like to reiterate that the stronger inductive bias of the softmax layer for copying data remains and is not invalidated by the analysis above. Therefore, even for our shallow and simple constructions, softmax layers fulfill an important role in support of our hypothesis: the ability to merge or copy input and target data into single tokens, allowing for the dot product computation necessary for the construction in Proposition 1, see Section 4 in the main text.

We end this section by analysing Transformers equipped with LayerNorm, which we apply, as usually done, before the self-attention layer. Overall, we observe qualitatively similar results to Transformers with softmax self-attention layers, i.e. a decrease in performance compared to GD accompanied by a decrease in alignment between models generated by the Transformer and models trained with GD, see Figure 14. Here, we again test a single linear self-attention layer succeeding LayerNorm, as well as two layers where we skip the first LayerNorm and only include a LayerNorm between the two. Including more heads does not help substantially. We again assume the optimality of GD and argue that information about targets and inputs present in the tokens is lost by averaging when applying LayerNorm. This naturally leads to decreasing performance compared to GD, see first row of Figure 14. Although the alignment to GD and GD++, especially for two layers, is high, we overall see inferior performance compared to one or two steps of GD or two steps of GD++. Nevertheless, we speculate that LayerNorm might not only stabilize Transformer training but could also act as a form of data normalization that implicitly enables better generalization for larger inputs as well as targets provided in-context, see the OOD experiments in Figure 14.

Overall, we conclude that common architecture choices like softmax and LayerNorm seem suboptimal for the constructed in-context learning settings when compared to GD or linear self-attention. Nevertheless, we speculate that the potentially small performance drops of in-context learning are negligible when turning to deep and wide Transformers, for which these architecture choices have empirically proven to be superior.

A.10 Details of curvature correction

We give here a precise construction showing how to implement, in a single head, a step of GD together with the discussed data transformation, resulting in GD++. Recall again the linear self-attention operation with a single head

for every token $e_j=(x_j,y_j)$, including the query token $e_{N+1}=e_{\text{test}}=(x_{\text{test}},0)$, which will give us the desired result.

Why does GD++ perform better? We give here one possible explanation of the superior performance of GD++ compared to GD. Note that there is a close resemblance between the GD transformation and a heavily truncated Neumann series approximation of the inverse of $XX^T$. Here we provide a more heuristic explanation for the observed acceleration.

Given the original Hessian $H=XX^T=U\Sigma U^T$ with its set of sorted eigenvalues $\{\lambda_1,\dots,\lambda_n\}$, $\lambda_i\geq 0$, on the diagonal matrix $\Sigma$, we can express the new Hessian through $U,\Sigma$, i.e. $H^{++}=(I-\gamma XX^T)X((I-\gamma XX^T)X)^T=(I-\gamma U\Sigma U^T)U\Sigma U^T(I-\gamma U\Sigma U^T)^T$.

Given the eigenspectrum $\{\lambda_1,\dots,\lambda_n\}$ of $H$, and since $H^{++}=U(I-\gamma\Sigma)\Sigma(I-\gamma\Sigma)U^T$, we obtain an (unsorted) eigenspectrum for $H^{++}$ with $\{\lambda_1-2\gamma\lambda_1^2+\gamma^2\lambda_1^3,\dots,\lambda_n-2\gamma\lambda_n^2+\gamma^2\lambda_n^3\}$, i.e. each eigenvalue is mapped by $f(\lambda,\gamma)=\lambda(1-\gamma\lambda)^2$, which we visualize in Figure 13 for different $\gamma$ observed in practice. We hypothesize that the Transformer chooses $\gamma$ such that, on average across the distribution of tasks, the data transformation (iteratively) decreases the condition number $\lambda_1/\lambda_n$, leading to accelerated learning. This could be achieved, for example, by keeping the smallest eigenvalue $\lambda_n\approx\lambda_n^{++}$ fixed and choosing $\gamma$ such that the largest eigenvalue of the transformed data $\lambda_1^{++}$ is reduced, while the original $\lambda_1$ stays within $[\lambda_1^{++},\lambda_n^{++}]$.
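The eigenvalue map $\lambda\mapsto\lambda-2\gamma\lambda^2+\gamma^2\lambda^3=\lambda(1-\gamma\lambda)^2$ can be checked numerically. The NumPy sketch below assumes a random data matrix $X$ with data points as columns (so that $H=XX^T$) and an arbitrary $\gamma$; it compares the eigenvalues of $H^{++}$ with the predicted ones. The function names are hypothetical.

```python
import numpy as np


def gdpp_hessian_eigs(X, gamma):
    """Eigenvalues of H++ = (I - gamma H) X ((I - gamma H) X)^T with H = X X^T."""
    H = X @ X.T
    X_pp = (np.eye(X.shape[0]) - gamma * H) @ X          # GD++ data transformation
    return np.linalg.eigvalsh(X_pp @ X_pp.T)              # ascending order


def predicted_eigs(X, gamma):
    """Apply lambda -> lambda - 2 gamma lambda^2 + gamma^2 lambda^3 to the eigenvalues of H."""
    lam = np.linalg.eigvalsh(X @ X.T)
    return np.sort(lam - 2 * gamma * lam ** 2 + gamma ** 2 * lam ** 3)


rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(5, 25))                     # N_x = 5 features, N = 25 data points
gamma = 0.02
print(np.allclose(gdpp_hessian_eigs(X, gamma), predicted_eigs(X, gamma)))  # True
```

The predicted values are sorted before comparison because the map is not monotone, which is why the text calls the transformed spectrum unsorted.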

Given the derived function of eigenvalue change $f(\lambda,\gamma)$, we compute the condition number of $H^{++}$ by dividing the new maximum eigenvalue $\lambda_1^{++}=f(1/(3\gamma),\gamma)$, where $\lambda=1/(3\gamma)$ is the local maximum of $f(\lambda,\gamma)$ for fixed $\gamma$, by the new minimum eigenvalue $\lambda_n^{++}=\min(f(\lambda_1,\gamma),f(\lambda_n,\gamma))$. Note that with too small $\gamma$, we move the original $\lambda_n$ closer to the root of $f(\lambda,\gamma)$, i.e. $\lambda=1/\gamma$, and therefore can change the smallest eigenvalue.
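The condition-number estimate can be written as a small function. The NumPy sketch below implements the formula from the text and checks it against the exact ratio of extreme eigenvalues of $H^{++}$, assuming all eigenvalues of $H$ lie below the root $1/\gamma$ (here enforced by the hypothetical choice $\gamma=1/(2\lambda_1)$). In that regime $f$ rises to its maximum $4/(27\gamma)$ at $\lambda=1/(3\gamma)$ and then falls, so the minimum over the spectrum sits at an endpoint and the estimate is an upper bound on the exact condition number. Function names are hypothetical.

```python
import numpy as np


def f(lam, gamma):
    """Eigenvalue change under the GD++ data transformation: lam (1 - gamma lam)^2."""
    return lam * (1.0 - gamma * lam) ** 2


def kappa_pp_estimate(lam_max, lam_min, gamma):
    """f at its local maximum 1/(3 gamma), divided by the smaller endpoint value."""
    top = f(1.0 / (3.0 * gamma), gamma)                   # equals 4 / (27 gamma)
    bottom = min(f(lam_max, gamma), f(lam_min, gamma))
    return top / bottom


rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(5, 25))
lam = np.linalg.eigvalsh(X @ X.T)
gamma = 1.0 / (2.0 * lam.max())                          # keeps every eigenvalue below 1/gamma
kappa = lam.max() / lam.min()
kappa_exact = f(lam, gamma).max() / f(lam, gamma).min()
kappa_est = kappa_pp_estimate(lam.max(), lam.min(), gamma)
print(kappa, kappa_exact, kappa_est)
```

Setting $\partial f/\partial\lambda=(1-\gamma\lambda)(1-3\gamma\lambda)=0$ gives the two stationary points $1/(3\gamma)$ (local maximum) and $1/\gamma$ (root), which is where the expressions in the text come from.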

Given the task distribution and its corresponding eigenvalue distribution, we see that choosing $\gamma$ appropriately reduces the new condition number $\kappa^{++}=\lambda_1^{++}/\lambda_n^{++}$, which leads to better conditioned learning, see center plot of Figure 13. Note that the optimal $\gamma$ from our derivation above is based on the maximum and minimum eigenvalue across all tasks and does not take the change of the eigenvalue distribution into account. We therefore argue that the simple arguments above do not entirely capture the task statistics and distribution shifts, and thus yield a slightly larger $\gamma$ as the optimal value. We furthermore visualize the condition number change for $N=25$ and 10,000 tasks in the right plot of Figure 13 and observe the distribution moving to desirable $\kappa$ values close to 1. For $\gamma$ values larger than 0.1, the distribution quickly exhibits exploding condition numbers.

A.11 Phase transitions

We comment briefly on the curious-looking phase transitions of the training loss observed in many of our experiments, see Figure 2. Simply switching from a single-headed to a two-headed self-attention layer mitigates the random-seed-dependent training instabilities in our experiments presented in the main text, see left and center plot of Figure 15.

Furthermore, these transitions are reminiscent of the recently observed "grokking" behaviour (Power et al., 2022). Interestingly, when carefully tuning the learning rate and batch size, we can also make the Transformers trained on these linear regression tasks grok. For this, we train a single Transformer block (self-attention layer and MLP) on a limited amount of data (8192 tasks), see right plot of Figure 15, and observe grokking-like phase transitions in the train and test loss, where the test loss first increases drastically before experiencing a sudden drop, almost matching the desired GD loss of roughly 0.2. We leave a thorough investigation of these phenomena for future study.

A.12 Experimental details

For most experiments we use identical hyperparameters, tuned by hand, which we list here:

Optimizer: Adam (Kingma & Ba, 2014) with default parameters and a learning rate of 0.001 for Transformers with depth $K<3$ and 0.0005 otherwise. We use a batch size of 2048 and apply gradient clipping to a global norm of 10. We use the Optax library.

Haiku weight initialisation (fan-in) with truncated normal and std $0.002/K$, where $K$ is the number of layers.

We did not use any regularisation and observed instabilities for deeper Transformers with $K>2$ when reaching GD performance. We speculate that this occurs because GD performance is, for the given training tasks, already close to divergence, as seen when providing tasks with larger input ranges. Therefore, training Transformers also becomes unstable when we approach GD with an optimal learning rate. To stabilize training, we simply clipped the token values to be in the range of $$.

When applicable, we use standard positional encodings of size 20, which we concatenate to all tokens.

For simplicity, and to follow the provided weight construction closely, we use square key, value and query parameter matrices in all experiments.

The training length varied throughout our experimental setups and can be read off our training plots in the article.

When training meta-parameters for gradient descent, i.e. $\eta$ and $\gamma$, we used an identical training setup, but training usually required far fewer iterations.

In all experiments we choose an initial $W_0=0$ for gradient-descent-trained models.
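Global-norm gradient clipping, listed above among the optimizer settings, rescales all gradient arrays jointly so that their combined L2 norm does not exceed the threshold. The text does not state which Optax transform was used; chaining `optax.clip_by_global_norm(10.0)` before `optax.adam(1e-3)` with `optax.chain` would be one plausible setup. To stay dependency-light, the sketch below re-implements only the clipping rule in NumPy; the function name is hypothetical.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm=10.0):
    """Scale a list of gradient arrays so that their joint L2 norm is at most max_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / global_norm) if global_norm > 0 else 1.0
    return [g * scale for g in grads]


grads = [np.full(3, 10.0), np.full(4, 10.0)]              # joint norm sqrt(700), about 26.5
clipped = clip_by_global_norm(grads)
print(np.sqrt(sum(np.sum(g ** 2) for g in clipped)))     # about 10.0
```

Because every array is multiplied by the same factor, the update direction is preserved and only its magnitude is capped; gradients already below the threshold pass through unchanged.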

Inspired by Garg et al. (2022), we additionally provide results when training a single linear self-attention layer on a fixed number of training tasks. To this end, we iterate over a single fixed batch of size $B$ instead of drawing a new batch of tasks at every iteration. Results can be found in Figure 16. Intriguingly, we find that (meta-)gradient descent finds Transformer weights that align remarkably well with the provided construction, and therefore with gradient descent, even when provided with an arguably very small number of training tasks. We argue that this again highlights the strong inductive bias of the LSA layer to (approximately) match gradient descent learning in its forward pass.