Jump to Conclusions: Short-Cutting Transformers With Linear Transformations
Alexander Yom Din, Taelin Karidi, Leshem Choshen, Mor Geva
Introduction
Transformer-based language models (LMs) process an input sequence of tokens by first representing it as a sequence of vectors and then repeatedly transforming it through a fixed number of attention and feed-forward network (FFN) layers Vaswani et al. (2017). While each transformation creates new token representations, only the final representations are used to obtain model predictions. Correspondingly, LM loss minimization directly optimizes the final representations, while hidden representations are only optimized implicitly, thus making their interpretation and usefulness more obscure.
However, utilizing hidden representations is highly desirable; a successful interpretation of them can shed light on the “decision-making process” in the course of transformer inference Tenney et al. (2019); Voita et al. (2019); Slobodkin et al. (2021); Geva et al. (2022b), and obtaining predictions from them can substantially reduce compute resources Schwartz et al. (2020); Xu et al. (2021).
Previous attempts to exploit hidden representations viewed the hidden representations of an input token as a sequence of approximations of its final representation Elhage et al. (2021); Geva et al. (2022b). This view is motivated by the additive updates induced via the residual connections He et al. (2016) around each layer in the network. Indeed, previous works Geva et al. (2021, 2022a); Ram et al. (2022); Alammar (2021) followed a simplifying assumption that representations at any layer can be transformed into a distribution over the output vocabulary by the output embeddings. While this approach has proven to be surprisingly effective for interpretability Geva et al. (2022a); Dar et al. (2022) and computation efficiency Schuster et al. (2022); Xin et al. (2020); Schwartz et al. (2020), it oversimplifies the model’s computation and assumes that all the hidden layers in the network operate in the same space.
We further test our approach in the context of language modeling (§4), checking how often predictions from final representation substitutes produced by mat agree with those of actual final representations. Experiments across two data sources and different scales of GPT-2 Radford et al. (2019) and BERT Devlin et al. (2019) show large accuracy gains (- at most layers) in prediction estimation by mat over naive projections (id). Moreover, we observe that our mappings often (in of the cases) produce correct predictions when applied to the very early layers in the network.
We then leverage these findings for improving model efficiency and demonstrate our method’s utility in the setting of early exiting, a strategy according to which one dynamically decides at which layer to stop the inference pass and use that layer’s representation as a final layer representation substitute. While previous works have utilized these hidden representations intact (i.e. using id), we transform them using mat, showing that our method performs better than the baseline in this setting as well (§5), allowing for the saving of additional (resp. ) of the layers for GPT-2 (resp. BERT) when aiming at accuracy.
Last, we analyze how well different parts of the transformer computation can be estimated linearly (§6). To this end, we apply the same methodology to replace the sub-modules of attention, FFN, and layer normalization with linear mappings. Interestingly, we find that linearly approximating attention, the only sub-module in the network that has contextual processing, results in the least reduction of precision. This hints at an interesting possibility of compute time reduction, since non-contextual inference is parallelizable.
To conclude, we propose a method for casting hidden representations across transformer layers that is light to train, cheap to infer, and provides more accurate representation approximations than the hitherto implicitly accepted baseline of identical propagation. The method is not only appealing for model analysis but also has concrete applications for efficiency.
Background and Notation
These representations are then repeatedly transformed through transformer blocks, where each block outputs hidden representations that are the inputs to the next block:
are considered as the transformer stack’s output. These representations are used to form various predictions. In this work, we investigate whether and how hidden representations from earlier layers can be utilized for this purpose instead.
Linear Shortcut Across Blocks
3.2 Baseline
Notably, this commonly-used baseline assumes that representations at different layers operate in the same linear space.
3.3 Quality of Fit
We first evaluate our method by measuring how well the learned linear mappings approximate the representations at the target layer. To this end, we calculate the (coordinate-averaged) -score of our mapping’s outputs with respect to the “real” representations obtained from a full inference pass, and compare to the same for the id baseline.
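As an illustration of this evaluation, the following minimal sketch fits a linear map by least squares on synthetic vectors and compares its coordinate-averaged score with that of the identity baseline. It assumes that the score is the standard coefficient of determination, computed per coordinate and then averaged, and that the mapping has no bias term. The synthetic data and the helpers `fit_linear_map` and `coordinate_averaged_r2` are hypothetical stand-ins and not the code used for the reported results.

```python
import numpy as np

def fit_linear_map(src, tgt):
    """Least-squares matrix A such that src @ A approximates tgt (no bias; an assumption)."""
    A, *_ = np.linalg.lstsq(src, tgt, rcond=None)
    return A

def coordinate_averaged_r2(pred, target):
    """Coefficient of determination per coordinate, averaged over coordinates."""
    ss_res = ((target - pred) ** 2).sum(axis=0)
    ss_tot = ((target - target.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))

# Synthetic stand-ins for hidden representations at a source and a target layer.
rng = np.random.default_rng(0)
d = 8
true_map = rng.normal(size=(d, d))
h_src_train = rng.normal(size=(500, d))
h_tgt_train = h_src_train @ true_map + 0.1 * rng.normal(size=(500, d))
h_src_val = rng.normal(size=(200, d))
h_tgt_val = h_src_val @ true_map + 0.1 * rng.normal(size=(200, d))

A = fit_linear_map(h_src_train, h_tgt_train)
r2_mat = coordinate_averaged_r2(h_src_val @ A, h_tgt_val)  # learned mapping
r2_id = coordinate_averaged_r2(h_src_val, h_tgt_val)       # identity baseline
print(f"mat: {r2_mat:.3f}  id: {r2_id:.3f}")
```

In this toy setting the target is, by construction, a noisy linear function of the source, so the fitted map scores close to 1 while the identity baseline does not; real hidden representations need not behave this cleanly.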
We use GPT-2 Radford et al. (2019), a decoder-only auto-regressive LM, and BERT Devlin et al. (2019), an encoder-only model trained with masked language modeling. We conduct our evaluation over multiple scales of these models, with , , , and layers for GPT-2, and and layers for BERT. The plots presented in this section are for -layered GPT-2 and -layered BERT; in §D we gather the plots for the rest of the models.
Data.
As we observed similar results for the two data sources across all our experiments, throughout the paper we will report results for Wikipedia and provide the results for the news articles in §B. Further details on the data and models are provided in §A.
Evaluation.
Results.
Results for -layered GPT-2 and -layered BERT are presented in Fig. 2 and 3, respectively. In both models, mat consistently yields better approximations than id, as it obtains higher -scores (in blue) across the network.
This gap between mat and id is especially evident in BERT, where id completely fails to map the representations between most layers, suggesting that hidden representations are modified substantially by every transformer block. Overall, this highlights the shortcoming of existing practices to inspect representations in the same linear space, and the gains from using our method to approximate future layers in the network.
Linear Shortcut for Language Modeling
We saw that our method approximates future hidden representations substantially better than a naive propagation. In this section, we will show that this improvement also translates to better predictive abilities from earlier layers. Specifically, we will use our method to estimate how often intermediate representations encode the final prediction, in the context of two fundamental LM tasks: next token prediction and masked token prediction.
Precision@k (higher is better): This checks whether the token with the highest probability according to appears in the top-k tokens according to . Namely, we sort and assign a score of 1 if appears in the top-k tokens by , and 0 otherwise.
Surprisal (lower is better): We measure the minus log-probability according to , of the highest-probability token according to . Intuitively, low values mean that the model sees the substitute result as probable and hence not surprising.
We report the average Precision@ and Surprisal over the validation set .
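The two metrics can be sketched as follows. Since the distribution symbols are not recoverable from the text above, the sketch assumes that the highest-probability token is taken from the substitute distribution (`p_sub`) and is then looked up in the distribution induced by the actual final representation (`p_final`), which matches the intuition given for Surprisal. The function and variable names are hypothetical.

```python
import numpy as np

def precision_at_k(p_final, p_sub, k):
    """1 if the top token of p_sub is among the k most probable tokens of p_final, else 0."""
    top_sub = int(np.argmax(p_sub))
    top_k_final = np.argsort(p_final)[::-1][:k]
    return int(top_sub in set(top_k_final.tolist()))

def surprisal(p_final, p_sub):
    """Minus log-probability, under p_final, of the top token of p_sub."""
    top_sub = int(np.argmax(p_sub))
    return float(-np.log(p_final[top_sub]))

# Toy distributions over a four-token vocabulary.
p_final = np.array([0.50, 0.30, 0.15, 0.05])
p_sub = np.array([0.20, 0.60, 0.10, 0.10])

p_at_1 = precision_at_k(p_final, p_sub, k=1)
p_at_2 = precision_at_k(p_final, p_sub, k=2)
s = surprisal(p_final, p_sub)
print(p_at_1, p_at_2, round(s, 3))
```

Here the substitute's top token is the second most probable token under the final distribution, so it scores 0 at k=1 and 1 at k=2, and its Surprisal is the minus log of 0.30.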
4.1 Next Token Prediction
Auto-regressive LMs output for every position a probability distribution over the vocabulary for the next token. Specifically, the output distribution for every position is given by , where:
For some LMs, including GPT-2, a layer normalization ln_f is applied to the final layer representation before this conversion (i.e., computing rather than ).
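A minimal sketch of this conversion is given below, assuming that the distribution is a softmax over the inner products of the (optionally layer-normalized) representation with the output embeddings. The names `E`, `gain`, `bias`, and `to_distribution`, the value of `eps`, and the random inputs are illustrative assumptions rather than values taken from the models.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

def layer_norm(h, gain, bias, eps=1e-5):
    """Layer normalization with learned gain and bias (eps is an assumed value)."""
    mu = h.mean()
    var = h.var()
    return (h - mu) / np.sqrt(var + eps) * gain + bias

def to_distribution(h, E, gain=None, bias=None):
    """Map a representation h (shape d) to a distribution over the vocabulary.

    E is the output embedding matrix (shape vocab x d). If gain and bias are
    given, a final layer normalization is applied first, as described for GPT-2.
    """
    if gain is not None:
        h = layer_norm(h, gain, bias)
    return softmax(E @ h)

rng = np.random.default_rng(0)
d, vocab = 6, 10
h = rng.normal(size=d)
E = rng.normal(size=(vocab, d))
p_plain = to_distribution(h, E)
p_ln = to_distribution(h, E, gain=np.ones(d), bias=np.zeros(d))
print(p_plain.argmax(), p_ln.argmax())
```

The same function can be applied to an actual final representation or to a substitute obtained from an earlier layer, which is what makes the two resulting distributions directly comparable.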
Figs. 4 and 5 show the average Precision@ and Surprisal scores per layer in -layered GPT-2, respectively (the plots for the other GPT-2 models are presented in §D). Across all layers, mat outperforms id in terms of both scores, often by a large margin (e.g. till layer the Precision@ achieved by mat is bigger than that of id by more than ). This shows that linear mappings enable not just better estimation of final layer representations, but also of the predictions they induce. Moreover, the relatively high Precision@ scores of mat in early layers (- for , - for , and - for ) suggest that early representations already encode a good estimation of the final prediction. Also, the substantially lower Surprisal scores of mat compared to id imply that our method allows for a more representative reading into the layer-wise prediction-formation of the model than allowed through direct projection to the vocabulary.
4.2 Masked Token Prediction
We now conduct the same experiment for the task of masked language modeling, where the model predicts a probability distribution of a masked token in the input rather than the token that follows the input. Unlike next token prediction, where the output distribution is computed from representations of varying input tokens, in masked token prediction the output is always obtained from representations of the same input token (i.e. [MASK]).
Figs. 6 and 7 present the average Precision@ and Surprisal scores per layer in -layered BERT (the plots for the -layered BERT model are presented in §D), overall showing trends similar to those observed for next token prediction in GPT-2 (§4.1). This is despite the differences between the two tasks and the considerable architectural differences between BERT and GPT-2. Notably, the superiority of mat over id in this setting is even more prominent; while mat’s precision is between in the first ten layers (Fig. 6), id’s precision for all values of is close to zero, again strongly indicating that our method allows for better reading into early layer hidden representations. More generally, mat improves the Precision@ of id by more than at most layers, and unveils that a substantial amount of predictions ( starting from layer ) appear already in the very first layers. Interestingly, the (rough) divide between the first half of layers and last half of layers for id in Figs. 6, 7 seems to align with the two-hump shape of the blue region for mat in Fig. 3.
Analysis.
Implication to Early Exiting
The fact that it is often possible to approximate the final prediction already from early layers in the network has important implications for efficiency. Concretely, applying our linear mapping instead of executing transformer blocks of quadratic time complexity could potentially save a substantial portion of the computation. In this section, we demonstrate this in the context of early exiting.
where is the average length of the input until position for , and is a confidence hyper-parameter.
We evaluate each variant in terms of both prediction accuracy, using the Precision@ metric (see §4), and efficiency, measured as the average number of layers processed during inference.
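To make the accuracy and efficiency trade-off concrete, here is a minimal sketch of a confidence-based exit. The exact exit criterion is not recoverable from the text above, so the sketch assumes a simple rule: stop at the first layer whose read-out distribution assigns its top token a probability of at least a threshold `lam`, and fall back to the last layer otherwise. The function `early_exit` and the toy distributions are hypothetical.

```python
import numpy as np

def early_exit(layer_probs, lam):
    """Return (number of layers processed, predicted token id).

    layer_probs[i] is the distribution read out after layer i + 1, for example
    after mapping that layer's representation to a final-layer substitute.
    Assumed rule: exit at the first layer whose top probability is >= lam.
    """
    for layer, probs in enumerate(layer_probs, start=1):
        if probs.max() >= lam:
            return layer, int(probs.argmax())
    return len(layer_probs), int(layer_probs[-1].argmax())

# Toy read-outs for a four-layer model over a three-token vocabulary.
layer_probs = np.array([
    [0.40, 0.30, 0.30],
    [0.50, 0.30, 0.20],
    [0.10, 0.80, 0.10],
    [0.05, 0.90, 0.05],
])

exit_low = early_exit(layer_probs, lam=0.4)    # exits immediately
exit_mid = early_exit(layer_probs, lam=0.7)    # exits once confident
exit_high = early_exit(layer_probs, lam=0.95)  # never confident, uses the last layer
avg_layers = np.mean([exit_low[0], exit_mid[0], exit_high[0]])
print(exit_low, exit_mid, exit_high, avg_layers)
```

A lower threshold saves more layers but may return a token that differs from the final-layer prediction, as the first call does in this toy example; sweeping the threshold traces out the kind of curve that plots precision against the average number of layers processed.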
Results.
Figs. 8 and 9 plot the average Precision@ score against the average number of layers processed, for -layered GPT-2 and -layered BERT, respectively. For both models, under an early exit strategy our mapping mat again provides a substantial improvement over the baseline id. For example, aiming at average precision, mat saves (%) layers in GPT-2 compared to only (%) layers by id, and (%) layers in BERT versus (%) layers by id. These results highlight the potential gains that prominent early exit methods can obtain by using our method. Notably, in both models and using each of the two mapping methods, early exit obtains better results than fixed layer exit, as expected.
Linear Shortcut Across Sub-modules
Our experiments show that, despite the commonly-applied simplification by interpretability works, transformer layers do not operate in the same linear space and there is a major gap in approximating future representations using an identity mapping (§3, §4). In this section, we investigate whether discrepancies across layers result from specific sub-modules or are a general behaviour of all sub-modules in the network.
This is done by extending our approach to test how well particular components in transformer blocks can be linearly approximated.
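The idea can be sketched on a toy residual block: fit a linear map from a sub-module's inputs to its outputs, substitute the map for the sub-module, and measure how well the block output is preserved. Everything below is an illustrative assumption. The toy sub-module is a token-wise `tanh` layer rather than a real attention or FFN sub-module, and real attention additionally mixes information across positions, whereas the linear replacement acts on each position independently.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
W = rng.normal(size=(d, d)) / np.sqrt(d)

def sub_module(x):
    """Toy non-linear sub-module (hypothetical stand-in for attention or FFN)."""
    return np.tanh(x @ W)

def block(x):
    """Residual block: the sub-module's output is added to its input."""
    return x + sub_module(x)

# Fit a linear replacement on sampled (input, output) pairs of the sub-module.
x_train = 0.5 * rng.normal(size=(1000, d))
A, *_ = np.linalg.lstsq(x_train, sub_module(x_train), rcond=None)

def block_with_linear_sub_module(x):
    """Same residual block, with the sub-module replaced by the fitted linear map."""
    return x + x @ A

# Compare block outputs on held-out inputs via a coordinate-averaged r2.
x_val = 0.5 * rng.normal(size=(300, d))
full = block(x_val)
approx = block_with_linear_sub_module(x_val)
ss_res = ((full - approx) ** 2).sum(axis=0)
ss_tot = ((full - full.mean(axis=0)) ** 2).sum(axis=0)
r2 = float(np.mean(1.0 - ss_res / ss_tot))
print(f"block output r2 with linear sub-module: {r2:.3f}")
```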
Discussing GPT-2 for definiteness, we have
and define a replacement of the attention sub-module (Eq. 3) by
Evaluation.
Results.
Related Work
Recently, there has been considerable interest in utilizing intermediate representations in transformer-based LMs, both for interpretability and for efficiency.
In the direction of interpretability, one seeks to understand the prediction construction process of the model Tenney et al. (2019); Voita et al. (2019).
More recent works use mechanistic interpretability and view the inference pass as a residual stream of information Dar et al. (2022); Geva et al. (2022b). Additionally, there are works on probing, attempting to understand what features are stored in the hidden representations Adi et al. (2017); Conneau et al. (2018); Liu et al. (2019). Our work is different in that it attempts to convert intermediate representations into a final-layer form, which is interpretable by design.
In the direction of efficiency, there is the thread of work on early exit, where computation is cut at a dynamically-decided earlier stage Schwartz et al. (2020); Xin et al. (2020); Schuster et al. (2022). Other works utilize a fixed early stage network to parallelize inference (Leviathan et al., 2022; Chen et al., 2023). However, intermediate representations are directly propagated in these works, which we show is substantially worse than our approach. Moreover, our method requires training considerably fewer parameters than methods such as Schuster et al. (2021), which learn a different output softmax for each intermediate layer.
More broadly, skipping transformer layers and analyzing the linearity properties of transformer components have been discussed in prior works Zhao et al. (2021); Mickus et al. (2022); Wang et al. (2022); Lamparth and Reuel (2023).
Conclusion and Future Work
We present a simple method for inspection of hidden representations in transformer models, by using pre-fitted context-free and token-uniform linear mappings. Through a series of experiments on different data sources, model architectures and scales, we show that our method consistently outperforms the prevalent practice of interpreting representations in the final-layer space of the model, yielding better approximations of succeeding representations and the predictions they induce.
We also demonstrate the practicality of our method for improving computation efficiency, saving a substantial amount of compute on top of prominent early exiting approaches.
Last, by extending our method to sub-modules, more specifically the attention sub-modules, we observe that in some cases replacing a part of transformer inference by a non-contextual linear computation only results in a small deterioration of the prediction. This opens new research directions for improving model efficiency, including breaking the computation into several parallelizable tasks.
Limitations
Although we see in this work that there is more linear structure to transformer inference than could be explained solely by the residual connection, we do not elucidate a reason for that. We also do not try to formulate formal criteria according to which to judge, in principle, the quality of ways of short-cutting transformer inference in-between layers. In addition, our experiments cover only English data.
Acknowledgements
We thank Tal Schuster for constructive comments.
References
Appendix A Experimental Setup Details
Here we specify the experimental setup details.
We use the four versions of the GPT-2 model (https://huggingface.co/gpt2) from Huggingface Wolf et al. (2020), having hidden layers and hidden dimensions respectively.
BERT.
We use the bert-large-uncased model (https://huggingface.co/bert-large-uncased) from Huggingface, having hidden layers and hidden dimension , and the bert-base-uncased model (https://huggingface.co/bert-base-uncased) from Huggingface, having hidden layers and hidden dimension . We use the BertForMaskedLM heads from Huggingface, pretrained for these models.
A.2 Data
When experimenting on the Wikipedia dataset, to generate an input sequence for our training and validation sets, we pick a random document from the Wikipedia dataset (https://huggingface.co/datasets/wikipedia) from Huggingface, use spaCy (https://spacy.io/) to break the document into sentences, and pick a random sentence ending with a newline character among those.
When experimenting on the news article sentences dataset, we use the 10K English 2020 news sentences corpus (https://downloads.wortschatz-leipzig.de/corpora/eng_news_2020_10K.tar.gz) from the Leipzig Corpora Collection, which we randomly divide into a training set consisting of 9,000 examples and a validation set consisting of 1,000 examples.
In the GPT-2 (resp. BERT) experiment, a tokenized sentence with more than 1024 (resp. 512) tokens was truncated to have 1024 (resp. 512) tokens.
A.3 Values of λ Used in §5
In §5, to have a plot which is gradual enough, we use the following values of λ:
.
For the baseline method id, to the values above we also add
Appendix B Results on News Articles Data
Here we record the results of the main experiments when evaluated on the Leipzig Corpora news article sentences dataset, in Figs. 12, 13 (quality of fit), Figs. 14, 15, 16, 17 (linear shortcut for language modeling), Figs. 18, 19 (linear shortcut across sub-modules in blocks) and Figs. 20, 21 (early exit).
Appendix D Results for Models of Various Sizes
Here we record results of some of the experiments when performed with , and -layered versions of GPT-2 and -layered version of BERT, on the Wikipedia dataset; Figs. 22, 23, 24, 25 (quality of fit), Figs. 26, 27, 28, 29 (layer linear shortcut, Precision@), Figs. 30, 31, 32, 33 (layer linear shortcut, Surprisal),
Appendix E More Figures
Here we record Figs. 34, 35 mentioned in §6.