Why Can GPT Learn In-Context? Language Models Implicitly Perform Gradient Descent as Meta-Optimizers
Damai Dai, Yutao Sun, Li Dong, Yaru Hao, Shuming Ma, Zhifang Sui, Furu Wei
Introduction
In recent years, large pretrained language models, especially those with Transformer-based architectures (e.g., GPT; Brown et al. 2020), have shown strong emergent in-context learning (ICL) ability (Wei et al., 2022; Dong et al., 2023). Unlike finetuning, which requires additional parameter updates, ICL needs only several demonstration examples prepended to the query input, after which the model can predict labels for unseen inputs. On numerous downstream tasks, large GPT models achieve surprisingly strong performance, even exceeding smaller models trained with supervised finetuning. However, although ICL performs well, its working mechanism remains an open question.
In this paper, we explain in-context learning as a process of meta-optimization and analyze connections between GPT-based in-context learning and finetuning. Concentrating on the attention modules, we find that Transformer attention has a dual form of gradient descent. Building on this, we propose a novel perspective to explain in-context learning: (1) a pretrained GPT serves as a meta-optimizer; (2) it produces meta-gradients according to the demonstration examples through forward computation; (3) the meta-gradients are applied to the original language model through attention to build an ICL model. As illustrated in Figure 1, in-context learning and explicit finetuning share a dual view of gradient descent, where ICL produces meta-gradients through forward computation, while finetuning computes gradients by back-propagation. Therefore, it is reasonable to understand in-context learning as implicit finetuning.
To provide empirical evidence for this understanding, we conduct comprehensive experiments on real tasks. On six classification tasks, we compare the model predictions, attention outputs, attention weights to query tokens, and attention weights to training tokens between in-context learning and finetuning. The experimental results show that in-context learning behaves similarly to explicit finetuning from multiple perspectives. These results provide strong evidence that it is reasonable to understand in-context learning as implicit finetuning.
Further, inspired by the dual form between Transformer attention and gradient descent, we design a momentum-based attention, which regards the attention values as meta-gradients and applies the momentum mechanism (Polyak, 1964; Sutskever et al., 2013) to them. Experiments on both language modeling and in-context learning show that our momentum-based attention consistently outperforms vanilla attention, which supports our understanding of meta-optimization again from another perspective. We note that beyond this preliminary attempt, our understanding may have more potential to enlighten model design, which is worth investigating in the future.
Our contributions are summarized as follows:
We identify a dual form between Transformer attention and gradient descent, and explain ICL as a process of meta-optimization.
We analyze connections between in-context learning and explicit finetuning and propose to understand ICL as implicit finetuning.
We provide several lines of empirical evidence showing that ICL and explicit finetuning behave similarly from multiple perspectives.
We design a momentum-based attention and validate its effectiveness, which further supports our understanding of meta-optimization and shows its potential to inform future model design.
Background
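In this paper, we focus on ICL for classification tasks using GPT (Brown et al., 2020). A GPT model is a stack of identical Transformer (Vaswani et al., 2017) decoder layers, where each layer consists of an attention module and a feed-forward network. For a classification task, given a query input text $x$ and a candidate answer set $Y = \{y_1, y_2, \dots, y_m\}$, we need to predict a label $\hat{y}$ conditional on $n$ demonstration examples $C = \{(x'_1, y'_1), (x'_2, y'_2), \dots, (x'_n, y'_n)\}$, where $(x'_i, y'_i)$ is an input-label pair different from the query one. Formally, given a GPT model $\mathcal{M}$, we first compute the probability of each answer $y_j$:
$$P_{\mathcal{M}}(y_j \mid C, x).$$
Since the label space is restricted for classification, we predict the final answer $\hat{y}$ by selecting the answer with the highest probability from the candidate answer set $Y$:
$$\hat{y} = y_{\arg\max_j P_{\mathcal{M}}(y_j \mid C, x)}.$$
In practice, we usually use a pre-defined template to format the demonstrations and prepend them to the query input. Let $\mathcal{T}(\cdot)$ be the function that formats an example; the templates used for each task are listed in Appendix A.
The contextual model input $I$ is organized like
$$I = \mathcal{T}(x'_1, y'_1)\ \mathcal{T}(x'_2, y'_2)\ \dots\ \mathcal{T}(x'_n, y'_n)\ \mathcal{T}(x, \_).$$
Feeding this contextual input into $\mathcal{M}$, the probability of an answer $y_j$ is computed as
$$l_j = \mathcal{M}(I) \cdot e_{y_j}, \qquad P_{\mathcal{M}}(y_j \mid C, x) = \mathrm{softmax}(l_j),$$
where $\mathcal{M}(I)$ denotes the output hidden state at the last token position; $e_{y_j}$ denotes the output word embedding of $y_j$; and $l_j$ is the logit corresponding to the $j$-th answer.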
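The prediction procedure described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the template string, the function names `build_context` and `predict_answer`, and the toy embeddings are hypothetical and do not come from the paper; in practice the last hidden state would be produced by the GPT model itself.

```python
import numpy as np

def build_context(demos, query_text, template="Input: {x} Label: {y}"):
    """Format demonstrations and the query with one shared (hypothetical) template."""
    parts = [template.format(x=x, y=y) for x, y in demos]
    # The query is formatted with an empty label slot for the model to fill.
    parts.append(template.format(x=query_text, y="").rstrip())
    return "\n".join(parts)

def predict_answer(last_hidden, answer_embeddings, answers):
    """Score each candidate answer by a dot product, then take the argmax."""
    logits = answer_embeddings @ last_hidden      # one logit per candidate answer
    shifted = logits - logits.max()               # stabilise the softmax
    probs = np.exp(shifted) / np.exp(shifted).sum()
    return answers[int(np.argmax(probs))], probs

context = build_context([("good movie", "positive")], "bad plot")
label, probs = predict_answer(
    np.array([1.0, 0.0]),                         # stand-in for the last hidden state
    np.array([[2.0, 0.0], [0.0, 5.0]]),           # stand-in output embeddings
    ["positive", "negative"],
)
```

Because the softmax is restricted to the candidate answers, the argmax over probabilities is the same as the argmax over the raw logits.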
Dual Form Between Attention and Linear Layers Optimized by Gradient Descent
where $e_i$ is derived from the historic output gradients by multiplying $-\gamma$, the negative learning rate. Combining Equation (7) and Equation (8), we can derive the dual form of linear layers optimized by gradient descent:
$$\mathcal{F}(x) = (W_0 + \Delta W)\,x = W_0 x + \sum_i (e_i \otimes x'_i)\,x = W_0 x + \sum_i e_i\,({x'_i}^{\top} x) = W_0 x + \mathrm{LinearAttn}(E, X', x), \qquad (9)$$
where $\mathrm{LinearAttn}(V, K, q)$ denotes the linear attention operation, in which we regard the historic output error signals $E$ as values, the historic inputs $X'$ as keys, and the current input $x$ as the query.
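The duality can be checked numerically. The sketch below is a minimal illustration, not the authors' code: it assumes random toy data, stores historic inputs and error signals as rows, and uses a helper name `linear_attn` chosen here for illustration. It verifies that applying an update matrix built from outer products gives the same output as unnormalized linear attention over the training history.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n_steps = 4, 3, 5

W0 = rng.normal(size=(d_out, d_in))         # initial weights
X_hist = rng.normal(size=(n_steps, d_in))   # historic inputs (keys), one per row
E = rng.normal(size=(n_steps, d_out))       # error signals (values), already scaled by -lr
x = rng.normal(size=d_in)                   # current input (query)

# Primal view: accumulate outer products into an update matrix, then apply it.
delta_W = sum(np.outer(e, x_prev) for e, x_prev in zip(E, X_hist))
primal = (W0 + delta_W) @ x

def linear_attn(values, keys, query):
    """Unnormalized linear attention: sum_i values[i] * (keys[i] . query)."""
    scores = keys @ query                   # shape (n_steps,)
    return values.T @ scores                # shape (d_out,)

# Dual view: the same update expressed as attention over the history.
dual = W0 @ x + linear_attn(E, X_hist, x)
```

The two views agree up to floating-point error because the outer product applied to `x` factors into a value vector times a key-query dot product.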
Understanding In-Context Learning (ICL) as Implicit Finetuning
We first qualitatively analyze Transformer attention under a relaxed linear attention form to identify a dual form between it and gradient descent. Then, we compare in-context learning with explicit finetuning to analyze connections between these two optimization forms. Based on these theoretical findings, we propose to understand in-context learning as implicit finetuning.
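We define $W_{\mathrm{ZSL}} = W_V X (W_K X)^{\top}$ as the initialized parameters to be updated since $W_{\mathrm{ZSL}}\, q$ is the attention result in the zero-shot learning (ZSL) setting, where no demonstrations are given. Here $X'$ and $X$ denote the input representations of the demonstration tokens and of the query tokens, $W_K$ and $W_V$ are the key and value projection matrices, and $q$ is the attention query vector. Following the reverse direction of Equation (9), we derive a dual form of the Transformer attention:
$$\tilde{\mathcal{F}}_{\mathrm{ICL}}(q) = W_{\mathrm{ZSL}}\, q + W_V X' (W_K X')^{\top} q = W_{\mathrm{ZSL}}\, q + \mathrm{LinearAttn}(W_V X', W_K X', q) = W_{\mathrm{ZSL}}\, q + \sum_i \big((W_V x'_i) \otimes (W_K x'_i)\big)\, q = (W_{\mathrm{ZSL}} + \Delta W_{\mathrm{ICL}})\, q. \qquad (12)$$
As shown in the above equations, the attention to the demonstration tokens is equivalent to parameter updates $\Delta W_{\mathrm{ICL}}$ that take effect on $W_{\mathrm{ZSL}}$. In addition, by analogy with $E$ in Equation (9), we regard $W_V X'$ as meta-gradients, which are used to compute the update matrix $\Delta W_{\mathrm{ICL}}$.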
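As a sanity check on this decomposition, the following sketch (an illustration on random toy matrices, with variable names chosen here rather than taken from the paper) compares relaxed linear attention over the concatenated demonstration and query tokens against a zero-shot matrix plus an update matrix built only from the demonstration tokens. Token representations are stored as columns.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_head, n_demo, n_query = 6, 4, 5, 3

W_K = rng.normal(size=(d_head, d))          # key projection
W_V = rng.normal(size=(d_head, d))          # value projection
X_demo = rng.normal(size=(d, n_demo))       # demonstration tokens, one per column
X_query = rng.normal(size=(d, n_query))     # query-side tokens, one per column
q = rng.normal(size=d_head)                 # attention query vector

# Relaxed linear attention (no softmax, no scaling) over all tokens.
X_all = np.concatenate([X_demo, X_query], axis=1)
attn_out = (W_V @ X_all) @ ((W_K @ X_all).T @ q)

# Dual form: zero-shot "weights" plus an update contributed by the demonstrations.
W_zsl = (W_V @ X_query) @ (W_K @ X_query).T
delta_W_icl = (W_V @ X_demo) @ (W_K @ X_demo).T   # sum of value-key outer products
dual_out = (W_zsl + delta_W_icl) @ q
```

The equality holds only for the relaxed linear form; with softmax attention the demonstration and query contributions no longer separate additively.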
In summary, we explain in-context learning as a process of meta-optimization: (1) a pretrained GPT model serves as a meta-optimizer; (2) it produces meta-gradients according to the demonstration examples through forward computation; (3) through attention, the meta-gradients are applied to the original language model to build an ICL model.
Comparing ICL with Finetuning
Based on the above understanding of in-context learning, we further compare the meta-optimization of in-context learning with the explicit optimization of finetuning to analyze connections between them. Considering that ICL directly takes effect on only the attention keys and values, we design a specific finetuning setting as the compared baseline, which also updates only the parameters for the key and value projection. Also in the relaxed linear attention form, the attention result of a finetuned head is formulated as
$$\tilde{\mathcal{F}}_{\mathrm{FT}}(q) = (W_V + \Delta W_V)\, X X^{\top} (W_K + \Delta W_K)^{\top} q = (W_{\mathrm{ZSL}} + \Delta W_{\mathrm{FT}})\, q, \qquad (13)$$
where $\Delta W_K$ and $\Delta W_V$ denote the parameter updates to $W_K$ and $W_V$, respectively, which are acquired by back-propagation from task-specific training objectives; and $\Delta W_{\mathrm{FT}}$ is the update to $W_{\mathrm{ZSL}}$ introduced by finetuning.
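To make the finetuning-side update concrete, the sketch below expands the effect of key and value projection updates under the relaxed linear attention form. The three-term expansion is plain algebra added here for illustration and is not stated in the paper; the matrices are random toy data, and the names `dW_K` and `dW_V` are stand-ins for updates that would come from back-propagation.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_head, n_query = 6, 4, 3

W_K = rng.normal(size=(d_head, d))
W_V = rng.normal(size=(d_head, d))
dW_K = 0.01 * rng.normal(size=(d_head, d))  # stand-in for a back-propagated update
dW_V = 0.01 * rng.normal(size=(d_head, d))  # stand-in for a back-propagated update
X = rng.normal(size=(d, n_query))           # query-side tokens, one per column
q = rng.normal(size=d_head)

G = X @ X.T                                 # Gram-like matrix shared by all terms
W_zsl = W_V @ G @ W_K.T                     # zero-shot attention "weights"

# Attention output of the finetuned head in relaxed linear form.
ft_out = (W_V + dW_V) @ G @ (W_K + dW_K).T @ q

# Expanding the product isolates the update applied on top of W_zsl.
delta_W_ft = dW_V @ G @ W_K.T + W_V @ G @ dW_K.T + dW_V @ G @ dW_K.T
dual_out = (W_zsl + delta_W_ft) @ q
```

Seen this way, finetuning and ICL both add an update matrix to the same zero-shot term; they differ in how that update is obtained.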
For a fairer comparison with in-context learning, we further restrict the finetuning setting as follows: (1) we specify the training examples as the demonstration examples for in-context learning; (2) we train each example for only one step in the same order as demonstrated for in-context learning; (3) we format each training example with the same template used for ICL and use the causal language modeling objective for finetuning.
Comparing in-context learning and this finetuning setting, we find that ICL has many properties in common with finetuning. We organize these common properties into the following four aspects.
Both Perform Gradient Descent
Comparing Equation (12) and Equation (13), we find that both in-context learning and finetuning introduce updates ($\Delta W_{\mathrm{ICL}}$ vs. $\Delta W_{\mathrm{FT}}$) to $W_{\mathrm{ZSL}}$, which derive from implicit and explicit gradient descent, respectively. The main difference is that ICL produces meta-gradients by forward computation while finetuning acquires real gradients by back-propagation.
Same Training Information
The meta-gradients of ICL are produced according to the demonstration examples. The gradients of finetuning are also derived from the same training examples. That is to say, in-context learning and finetuning share the same source of training information.
Same Causal Order of Training Examples
In-context learning and our finetuning setting share the same causal order of training examples. ICL uses decoder-only Transformers so the subsequent tokens in the demonstrations will not affect the preceding ones. For our finetuning setting, we use the same order of training examples and train only one epoch, so we can also guarantee that the subsequent examples have no effect on the preceding ones.
Both Aim at Attention
Compared with zero-shot learning, the direct effect of in-context learning and our finetuning are both restricted to the computation of attention keys and values. For ICL, the model parameters are unchanged and it encodes demonstration information into additional keys and values to change the attention behavior. For finetuning, due to our restriction, the training information can be introduced to only the projection matrices for attention keys and values as well.
Considering the above common properties between in-context learning and finetuning, we show that it is reasonable to understand in-context learning as implicit finetuning. In the rest of this paper, we compare ICL and explicit finetuning empirically from multiple perspectives to provide quantitative results to support this understanding.
Experiments
We analyze two off-the-shelf pretrained GPT models with 1.3 billion and 2.7 billion model parameters, respectively, which are released by fairseq (https://github.com/facebookresearch/fairseq). In the rest of this paper, we call them GPT 1.3B and GPT 2.7B for short. All experiments are conducted on NVIDIA V100 GPUs with 32 GB memory.
For each task, we use the same template to format examples for zero-shot learning (ZSL), finetuning (FT), and in-context learning (ICL). Details of the templates used for each task are provided in Appendix A. The answer prediction processes for ZSL and finetuning are the same as for ICL, as described in Section 2.1, except that they do not have demonstration examples.
For in-context learning, we fix the max number of demonstration examples to 32 and tune the random seed for each task to find a set of demonstration examples that achieves the best validation performance. For explicit finetuning, we use the same demonstration examples for in-context learning as the training examples and use SGD as the optimizer. For a fair comparison, we fine-tune the model for only one epoch and the training examples are provided in the same order as demonstrated for in-context learning. We tune the learning rate for finetuning and select the one that achieves the best validation performance. Details of the search range and selected value for the random seeds and learning rates are shown in Appendix B.
Evaluation Datasets
We compare in-context learning and finetuning based on six datasets spanning three sorts of classification tasks. SST2 (Socher et al., 2013), SST5 (Socher et al., 2013), MR (Pang and Lee, 2005) and Subj (Pang and Lee, 2004) are four datasets for sentiment classification; AGNews (Zhang et al., 2015) is a topic classification dataset; and CB (De Marneffe et al., 2019) is used for natural language inference. Statistics of the number of validation examples and label types are summarized in Table 1.
For reference, we present the validation accuracy in the ZSL, finetuning, and ICL settings on six classification datasets in Table 1. Compared with ZSL, ICL and finetuning both achieve considerable improvements, which means the optimizations they make are both helpful to these downstream tasks.
ICL Covers Most of Correct Predictions of Finetuning
We compute a recall to finetuning prediction (Rec2FTP) to measure how much of the behavior of finetuning ICL can cover from the perspective of the model prediction. We first count $N_{(\mathrm{FT}>\mathrm{ZSL})}$, the number of query examples that finetuning can predict correctly but ZSL cannot. Then, among these examples, we count $N_{(\mathrm{FT}>\mathrm{ZSL}) \wedge (\mathrm{ICL}>\mathrm{ZSL})}$, the number that ICL can also predict correctly. Finally, we compute the Rec2FTP score as $N_{(\mathrm{FT}>\mathrm{ZSL}) \wedge (\mathrm{ICL}>\mathrm{ZSL})} / N_{(\mathrm{FT}>\mathrm{ZSL})}$. A higher Rec2FTP score suggests that ICL covers more correct behavior of finetuning from the perspective of the model prediction.
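The counting procedure above translates directly into code. The following is a minimal sketch, assuming per-example correctness flags are available as Boolean lists; the function name `rec2ftp` and the toy flags are illustrative, not taken from the paper.

```python
def rec2ftp(zsl_correct, ft_correct, icl_correct):
    """Fraction of examples fixed by finetuning (wrong under ZSL, right under FT)
    that ICL also predicts correctly."""
    n_ft_fixed = 0
    n_icl_also = 0
    for zsl_ok, ft_ok, icl_ok in zip(zsl_correct, ft_correct, icl_correct):
        if ft_ok and not zsl_ok:
            n_ft_fixed += 1
            if icl_ok:
                n_icl_also += 1
    # The score is undefined when finetuning fixes no example at all.
    return n_icl_also / n_ft_fixed if n_ft_fixed else float("nan")

# Toy correctness flags for five query examples.
score = rec2ftp(
    zsl_correct=[False, False, False, True, False],
    ft_correct=[True, True, True, True, False],
    icl_correct=[True, True, False, True, False],
)
```

In this toy case finetuning fixes three examples and ICL also gets two of them right, so the score is 2/3. Examples that ZSL already predicts correctly do not enter the score.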
We show the Rec2FTP scores for two GPT models on six datasets in Table 2. As shown in the table, on average, ICL correctly predicts more than 85% of the examples that finetuning corrects relative to ZSL. These results indicate that, from the perspective of model prediction, ICL can cover most of the correct behavior of finetuning.
ICL Tends to Change Attention Outputs in the Same Direction as Finetuning
From the perspective of representation, we compute a similarity of the attention output updates (SimAOU) to measure the similarity between the updates that ICL and finetuning make. For a query example, let $h_X^{(l)}$ denote the normalized output representation of the last token at the $l$-th attention layer in setting X. The updates of ICL and finetuning compared with ZSL are $h_{\mathrm{ICL}}^{(l)} - h_{\mathrm{ZSL}}^{(l)}$ and $h_{\mathrm{FT}}^{(l)} - h_{\mathrm{ZSL}}^{(l)}$, respectively. We compute the cosine between these two updates to get SimAOU (FT) at the $l$-th layer. A higher SimAOU (FT) means ICL is more inclined to update the attention output in the same direction as finetuning. For comparison, we also compute a baseline metric called SimAOU (Random Δ) that computes the similarity between ICL updates and randomly generated updates.
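A per-layer SimAOU computation might look like the sketch below. It assumes the normalized last-token attention outputs for the three settings are already available as vectors; the function names and the Gaussian random baseline are assumptions made for illustration, since the text does not specify how the random updates are generated.

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def sim_aou(h_zsl, h_icl, h_ft):
    """Cosine between the ICL update and the finetuning update, both relative to ZSL."""
    return cosine(h_icl - h_zsl, h_ft - h_zsl)

def sim_aou_random(h_zsl, h_icl, rng):
    """Baseline: cosine between the ICL update and a random update (assumed Gaussian)."""
    return cosine(h_icl - h_zsl, rng.normal(size=h_zsl.shape))

h_zsl = np.array([0.0, 0.0, 1.0])
same_direction = sim_aou(h_zsl, np.array([1.0, 0.0, 1.0]), np.array([2.0, 0.0, 1.0]))
opposite_direction = sim_aou(h_zsl, np.array([1.0, 0.0, 1.0]), np.array([-1.0, 0.0, 1.0]))
baseline = sim_aou_random(h_zsl, np.array([1.0, 0.0, 1.0]), np.random.default_rng(0))
```

Averaging these per-layer values across layers and examples gives a table-level score; in high dimensions the random baseline concentrates near zero.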
We present the SimAOU scores averaged across examples and layers for two GPT models on six datasets in Table 3. From the table, we find that SimAOU (Random Δ) is always around zero, while SimAOU (FT) remains much more positive. These results indicate that ICL updates are much more similar to finetuning updates than to random updates. From the perspective of representation, this shows that ICL tends to change the attention outputs in the same direction as finetuning.
ICL Is Inclined to Generate Similar Attention Weights to Finetuning
From the perspective of attention behavior, we compute a similarity of the attention map (SimAM) to measure the similarity of the attention map to query tokens for ICL and finetuning. For a query example, let $m_X^{(l,h)}$ denote the attention weights before softmax of the last token at the $h$-th attention head in the $l$-th attention layer in setting X. For ICL, we omit the attention to the demonstration tokens and only monitor the attention weights to the query tokens. First, before finetuning, we compute the cosine between $m_{\mathrm{ICL}}^{(l,h)}$ and $m_{\mathrm{ZSL}}^{(l,h)}$ and then average the similarity across attention heads to get SimAM (Before Finetuning) at each layer. Similarly, after finetuning, we compute the cosine between $m_{\mathrm{ICL}}^{(l,h)}$ and $m_{\mathrm{FT}}^{(l,h)}$ to get SimAM (After Finetuning). A higher SimAM (After Finetuning) than SimAM (Before Finetuning) indicates that the attention behavior of ICL is more similar to a finetuned model than to a non-finetuned one.
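The per-layer SimAM score can be sketched as a head-averaged cosine. The example assumes the pre-softmax attention weights to the query tokens are stored as an array of shape (heads, query tokens) for each setting; the function name `sim_am` and the toy arrays are illustrative.

```python
import numpy as np

def sim_am(m_icl, m_ref):
    """Average over heads of the cosine between ICL attention weights and
    reference attention weights (ZSL before finetuning, or FT after)."""
    dots = (m_icl * m_ref).sum(axis=1)
    norms = np.linalg.norm(m_icl, axis=1) * np.linalg.norm(m_ref, axis=1)
    return float((dots / norms).mean())

# Two heads, two query tokens (demonstration-token weights already dropped for ICL).
m_icl = np.array([[1.0, 0.0], [0.0, 1.0]])
m_zsl = np.array([[1.0, 0.0], [1.0, 0.0]])   # matches head 0 only
m_ft = np.array([[2.0, 0.0], [0.0, 3.0]])    # matches both heads up to scale

before_ft = sim_am(m_icl, m_zsl)
after_ft = sim_am(m_icl, m_ft)
```

Because cosine similarity ignores scale, a finetuned head that sharpens the same attention pattern still counts as fully similar.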
Table 4 reports the SimAM scores averaged across examples and layers for two GPT models on six datasets. We observe that the attention weights of ICL are more similar to the attention weights after finetuning than to those before finetuning. Again, from the perspective of attention behavior, this shows that ICL behaves similarly to finetuning.
ICL and Finetuning Tend to Pay Similar Attention to Training Tokens
Table 5 shows the Kendall correlation coefficients averaged across examples and layers for two GPT models on six datasets. We find that Kendall (ICL, Random) is always near zero, while Kendall (ICL, FT) always maintains a distinctly positive value. These results suggest that ICL and finetuning tend to pay similar attention to training tokens.
Momentum-Based Attention Inspired by Dual Form of Transformer Attention
We have figured out the dual form between Transformer attention and gradient descent. As illustrated in Figure 2, inspired by this dual view, we investigate whether we can utilize momentum (Polyak, 1964; Sutskever et al., 2013), a widely used technique for optimization algorithms, to improve Transformer attention.
Gradient descent with momentum averages gradients among timestamps:
$$\Theta_t = \Theta_{t-1} - \gamma \sum_{i=1}^{t-1} \eta^{\,t-i}\, \nabla f_{\Theta_i},$$
where $\gamma$ is the learning rate and $\eta$ is a scalar between 0 and 1. As stated in Section 3.1, the attention values serve as meta-gradients. By analogy with gradient descent with momentum, we try to use Exponential Moving Average (EMA; Hunter 1986) to average the attention values to build the momentum-based attention:
$$\mathrm{MoAttn}(V, K, q_t) = \mathrm{Attn}(V, K, q_t) + \mathrm{EMA}(V) = V\, \mathrm{softmax}\!\left(\frac{K^{\top} q_t}{\sqrt{d}}\right) + \sum_{i=1}^{t-1} \eta^{\,t-i}\, v_i,$$
where $v_i$ is the $i$-th attention value vector. The momentum of attention value vectors explicitly strengthens the recency bias of attention, which has been shown to be helpful for language modeling (Press et al., 2022). Therefore, we hypothesize that introducing momentum into attention will contribute to faster convergence and better performance.
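A single-query version of this idea can be sketched as follows. This is an illustration under explicit assumptions rather than the authors' implementation: value and key vectors are stored as rows, the decay exponent for the i-th of n preceding values is assumed to be n + 1 - i so that the most recent value is weighted by `eta`, and the function name `momentum_attention` is hypothetical.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(z - z.max())
    return e / e.sum()

def momentum_attention(V, K, q, eta):
    """Standard scaled dot-product attention for one query, plus an
    exponentially decayed sum of the value vectors (assumed weighting)."""
    n, d = K.shape
    attn = softmax(K @ q / np.sqrt(d)) @ V       # vanilla attention output
    weights = eta ** np.arange(n, 0, -1)         # eta^n, ..., eta^1 (oldest to newest)
    ema = weights @ V                            # momentum term over value vectors
    return attn + ema

V = np.array([[1.0, 0.0], [0.0, 1.0]])
K = np.zeros((2, 2))                             # zero keys give uniform attention
out = momentum_attention(V, K, q=np.zeros(2), eta=0.5)
```

With uniform attention the vanilla term is [0.5, 0.5] and the momentum term is [0.25, 0.5], so the more recent value vector receives the larger extra weight, which is the recency bias the text describes.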
First, we evaluate the effect of momentum-based attention on language modeling. We train two GPT models with 350M parameters from scratch, where one is the vanilla Transformer and the other applies momentum to attention. More training details are provided in Appendix C. We evaluate the perplexity of these two models on the training set and three validation sets with input lengths of 256, 512, and 1024, respectively. The results are shown in Table 6. On all of the validation sets, applying momentum to attention yields a consistent perplexity improvement over the vanilla Transformer.
Experiments on In-Context Learning
We also evaluate the in-context learning ability of the above language models to verify the effectiveness of momentum-based attention on downstream tasks. We consider six datasets for sentiment analysis (SST5 (Socher et al., 2013), IMDB (Maas et al., 2011), and MR (Pang and Lee, 2005)), natural language inference (CB (De Marneffe et al., 2019)), and multi-choice selection (ARC-E (Clark et al., 2018) and PIQA (Bisk et al., 2020)). For all of these datasets, we use up to 32 examples as demonstrations. As shown in Table 7, compared with vanilla Transformer, using momentum-based attention achieves consistently higher accuracy on all of these datasets.
The performance improvements on both language modeling and in-context learning support our hypothesis that introducing momentum improves Transformer attention. From another perspective, these results further support our understanding of Transformer attention as meta-optimization.
Related Work
Recently, some pieces of work have attempted to understand the inference mechanism of in-context learning. Xie et al. (2022) explain in-context learning as implicit Bayesian inference. They state that in-context learning emerges when language models can infer the shared latent concept among the demonstration examples, which is learned during pretraining. On another aspect, Olsson et al. (2022) focus on specific modules in Transformers. They find some induction heads in Transformers that refer to abstract patterns in previous sequences to help predict the next token. They indicate that the induction heads drive the ability of in-context learning. Different from them, we concentrate on the learning algorithm of ICL and explain it as a process of meta-optimization.
Some other work also studies the learning algorithm of ICL. As a case study, Garg et al. (2022) show that Transformers can be trained to in-context learn a class of linear functions and the performance is comparable to the least squares estimator. Based on linear regression, Akyürek et al. (2022) prove that they can construct parameters of Transformers to implement gradient-descent-based learning algorithms. Further, they show that models trained with an in-context learning objective tend to match the behavior of models computed by explicit learning algorithms. Also based on regression tasks, von Oswald et al. (2022) show that linear attention-only Transformers with constructed parameters that implement gradient descent and models learned by an in-context learning objective are highly related. Compared with them, we are the first ones to explain in-context learning in real scenarios. To be specific, (1) we analyze in-context learning for off-the-shelf GPT models, instead of models trained from scratch by an ICL objective; (2) our experiments are based on real NLP tasks, instead of toy ones like linear regression.
Conclusion
In this paper, we aim to explain the working mechanism of GPT-based ICL. Theoretically, we identify a dual form between Transformer attention and gradient descent, and propose to understand ICL as a process of meta-optimization. Further, we analyze connections between ICL and explicit finetuning and show that it is reasonable to regard ICL as implicit finetuning. Empirically, we comprehensively compare ICL and finetuning based on six real NLP tasks. The results show that ICL behaves similarly to explicit finetuning from multiple perspectives. Further, inspired by our understanding of meta-optimization, we design a momentum-based attention that achieves consistent performance improvements over vanilla attention. We believe our understanding has further potential to inform ICL applications and model design in the future.
Limitations
Although in-context learning ability has been found in different architectures (e.g., Transformer and LSTM), we consider only Transformer-based in-context learning in this paper because the Transformer is the current mainstream architecture in NLP. However, understanding how in-context learning works in other architectures is also a meaningful problem, which we encourage future work to study.
As for the dual form we point out between Transformer attention and gradient descent, we consider a relaxed form of linear attention for qualitative analysis. Although the experimental results support our understanding well, the mechanism of standard Transformer attention without approximation may be more complex and should be studied more clearly in the future.
As for empirical experiments, our analysis needs to record a large number of intermediate results (e.g., attention output representations, and attention weights to query tokens and demonstration tokens) for thousands of validation examples. Considering the storage space and computational cost of analysis, we only analyze GPT models with up to 2.7B parameters and leave larger models such as GPT 13B for future work. In addition, for the clarity of the problem definition and the convenience of experiments, our analysis is based on only classification tasks. Although classification is a representative application of in-context learning, other tasks like multiple choice and open-ended generation are not considered in this paper and could be investigated in the future.
Acknowledgement
Damai Dai and Zhifang Sui are supported by the National Key Research and Development Program of China 2020AAA0106700 and NSFC project U19A2065.
References
Appendix
Appendix A Templates for In-Context Learning
We demonstrate the templates used to format examples and the candidate answer sets for six classification datasets used in our experiments in Table 8.
Appendix B Hyper-Parameters for In-Context Learning and Finetuning
We perform grid search to find the best random seed for ICL and the best learning rate for finetuning. The search range for all the datasets is the same. For random seeds, we search in . For learning rates, the search base values are and we scale them to , , , and times, i.e., we have $9 \times 4 = 36$ values to search. As an exception, for GPT 1.3B finetuned on SST5, we perform a more fine-grained search and finally set its learning rate to 0.00016 since the finetuned model cannot outperform zero-shot learning with the above 36 learning rates.
In Table 9, we present the details of the selected random seeds and learning rates for two GPT models on six classification datasets.
Appendix C Hyper-Parameters for Training Language Models from Scratch
The hyper-parameters for training two language models from scratch are summarized in Table 10.