Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time

Zichang Liu, Jue Wang, Tri Dao, Tianyi Zhou, Binhang Yuan, Zhao Song, Anshumali Shrivastava, Ce Zhang, Yuandong Tian, Christopher Re, Beidi Chen

Introduction

Large language models (LLMs), such as GPT-3, PaLM, and OPT have demonstrated that an immense number of parameters unleashes impressive performance and emergent in-context-learning abilities—they can perform a task by conditioning on input-output examples, without updating their parameters (Bommasani et al., 2021; Liang et al., 2022; Brown et al., 2020; Min et al., 2022; Chan et al., 2022). However, they are very expensive at inference time, especially for latency-sensitive applications (Pope et al., 2022). An ideal inference-time model should use less computation and memory while maintaining the performance and special abilities of pre-trained LLMs. The simplest and most natural approach is sparsification or pruning, which has a long history before the LLM era (LeCun et al., 1989). Unfortunately, speeding up inference-time sparse LLMs in wall-clock time while maintaining quality and in-context learning abilities remains a challenging problem.

While sparsity and pruning have been well-studied, they have not seen wide adoption on LLMs due to the poor quality and efficiency trade-offs on modern hardware such as GPUs. First, it is infeasible to retrain or iteratively prune models at the scale of hundreds of billions of parameters. Thus, methods in iterative pruning and lottery ticket hypothesis (Lee et al., 2018; Frankle & Carbin, 2018) can only be applied to smaller-scale models. Second, it is challenging to find sparsity that preserves the in-context learning ability of LLMs. Many works have shown the effectiveness of task-dependent pruning (Michel et al., 2019; Bansal et al., 2022), but maintaining different models for each task conflicts with the task independence goal of LLMs. Lastly, it is hard to achieve wall-clock time speed-up with unstructured sparsity due to its well-known difficulty with modern hardware (Hooker, 2021). For example, recent development in zero-shot pruning like SparseGPT (Frantar & Alistarh, 2023) finds 60% unstructured sparsity but does not yet lead to any wall-clock time speedup.

An ideal sparsity for LLMs should (i) not require model retraining, (ii) preserve quality and in-context learning ability, and (iii) lead to speed-up in wall-clock time on modern hardware. To achieve such demanding requirements, we go beyond static sparsity in previous works (e.g., structured/unstructured weight pruning). We instead envision contextual sparsity, which are small, input-dependent sets of attention heads and MLP parameters that lead to (approximately) the same output as the full model for an input. Inspired by the connections between LLMs, Hidden Markov Models (Xie et al., 2022; Baum & Petrie, 1966), and the classic Viterbi algorithm (Viterbi, 1967), we hypothesize that for pre-trained LLMs,

contextual sparsity exists given any input.

The hypothesis, if true, would enable us to cut off specific attention heads and MLP parameters (structured sparsity) on the fly for inference-time, without modifying pre-trained models. However, there are three challenges.

Existence: It is nontrivial to verify if such contextual sparsity exists, and naive verification can be prohibitively expensive.

Prediction: Even if contextual sparsity exists, it is challenging to predict the sparsity for a given input in advance.

Efficiency: Even if the sparsity can be predicted, it might be difficult to achieve end-to-end wall-clock time speedup. Taking OPT-175B as an example, the latency of one MLP block is only 0.2 ms on an $8\times$A100 80GB machine. Without a fast prediction and optimized implementation, the overhead can easily increase the LLM latency rather than reduce it.

In this work, we address these challenges as follows:

Existence: Fortunately, we verify the existence of contextual sparsity with a surprisingly simple approach. To achieve essentially the same output, contextual sparsity is on average 85% structured sparse and thereby potentially leads to a $7\times$ parameter reduction for each specific input while maintaining accuracy (Figure 1(a)). During explorations of contextual sparsity, we make important empirical observations and build a theoretical understanding of major components in LLMs that help address the prediction and efficiency challenges.

Prediction: We discover that contextual sparsity depends not only on individual input tokens (i.e., non-contextual dynamic sparsity) but also on their interactions (contextual dynamic sparsity). Figure 1(b) shows that with pure dynamic information, sparsity prediction is inaccurate. Only with token embeddings with sufficient contextual information can we predict sparsity accurately. Another finding is that contextual dynamic sparsity for every layer can be predicted based on the “similarity” between layer parameters (heads/MLP) and the output from the previous layer, which carries the immediate contextual mixture of token embeddings.

Efficiency: Because model parameters are static at inference time, and inspired by the classical nearest neighbor search (NNS) literature and its applications in efficient deep learning, we can formulate the above similarity-based prediction as an NNS problem (Indyk & Motwani, 1998b; Zhang et al., 2018; Chen et al., 2020a). However, as mentioned, the overhead might be difficult to overcome as we would need to perform on-the-fly predictions before every layer. Luckily, we exploit a phenomenon of LLMs where token embeddings change slowly across layers due to residual connections (well-known in computer vision (He et al., 2016)). Since the inputs to a few consecutive layers are very similar, we can design an asynchronous lookahead predictor (Figure 2).

Based on our findings, we present a system, dejavu, that exploits contextual sparsity and realizes efficient LLMs for latency-sensitive applications.

In Section 4.1 and Section 4.2, we present a low-cost learning-based algorithm to predict sparsity on the fly. Given the input to a specific layer, it predicts a relevant subset of attention (heads) or MLP parameters in the next layer and only loads them for the computation.

In Section 4.3, we propose an asynchronous predictor (similar to classic branch predictor (Smith, 1998)) to avoid the sequential overhead. A theoretical guarantee justifies that the cross-layer design suffices for accurate sparsity prediction.

After integrating a hardware-aware implementation of sparse matrix multiply (Section 4.4), dejavu (written mostly in Python) can reduce the latency of open-source LLMs such as OPT-175B by over $2\times$ end-to-end without quality degradation compared to the state-of-the-art library FasterTransformer from Nvidia (written entirely in C++/CUDA), and by over $6\times$ compared to the widely used Hugging Face implementation at small batch sizes. Furthermore, we show several ablations on different components of dejavu and its compatibility with quantization techniques.

Related Work and Problem Formulation

We first briefly discuss the rich literature on efficient inference. Then, we introduce the latency breakdown in our setting. Last, we provide a formal problem formulation.

Various relaxations have been studied for decades for model inference in machine learning. There are three main techniques: quantization (Han et al., 2015; Jacob et al., 2018; Nagel et al., 2019; Zhao et al., 2019), pruning or sparsity (Molchanov et al., 2016; Liu et al., 2018; Hoefler et al., 2021), and distillation (Hinton et al., 2015; Tang et al., 2019; Touvron et al., 2021). They are orthogonal areas and usually excel in different settings. Recently, there is active research attempting to apply one or a combination of such techniques in LLM inference (Yao et al., 2022; Park et al., 2022; Dettmers et al., 2022; Frantar et al., 2022; Frantar & Alistarh, 2023; Bansal et al., 2022; Xiao et al., 2022). More discussion is presented in Appendix A.

2.2 LLM Inference Latency Breakdown

The generative procedure of LLMs consists of two phases: (i) the prompt phase takes an input sequence to generate the keys and values (KV cache) for each transformer block, which is similar to the forward pass of LLM training; and (ii) the token generation phase utilizes and updates the KV cache to generate tokens step by step, where the current token generation depends on previously generated tokens.

This paper studies the setting where the token generation phase easily dominates the end-to-end inference time. As shown in Table 1, generating a sequence of length 128 takes much longer than processing a sequence of length 128 as a prompt, due to the I/O latency of loading model parameters. In addition, Table 2 shows that attention and MLP are both bottlenecks in LLMs; e.g., in 175B models, loading MLP parameters takes around $\frac{2}{3}$ of the total I/O and attention heads take the other $\frac{1}{3}$. Further, in the tensor-parallel regime, there are two communications between GPUs, one after the attention block and the other after the MLP block. As shown in Table 3, communication between GPUs takes around 15% of the token generation latency. This paper focuses on making attention and MLP more efficient. The communication cost implies that the upper bound of such speed-up is around $6\times$ when skipping all transformer blocks.
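As a rough check of the upper bound quoted above, here is a back-of-the-envelope sketch. It assumes that communication is the only cost that remains once all transformer blocks are skipped, and that it is exactly 15% of the dense per-token latency $T$ (the symbol $T$ is introduced here for illustration only):

```latex
\text{speed-up} \;\le\; \frac{T}{0.15\,T} \;=\; \frac{1}{0.15} \;\approx\; 6.7
```

This is consistent with the stated bound of around $6\times$.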

2.3 Problem Formulation

The goal is to reduce the generation latency of LLMs by exploiting contextual sparsity. In the following, we formally define the sparsified attention and MLP blocks.

where $\sigma$ is the activation function, e.g., ReLU, GeLU. Note that since the computation in the first linear layer results in sparse activations, the second linear layer is also sparsified.

For both MLP and Attention, given a compute budget, the goal is to find $S_M$ and $S_A$ that minimize the error between the sparse approximation and full computation.
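For concreteness, one way to write the sparsified blocks and the objective is sketched below. It reuses the notation that appears in later sections ($W^1$, $W^2$, $S_M$, $S_A$, $H_i$, $W^O$, and the input $y$). The operator names $\mathsf{MLP}_{S_M}$ and $\mathsf{MHA}_{S_A}$, the per-head slice $W^O_i$, and the budgets $k_M$, $k_A$ are assumptions made for illustration and may differ from the exact definitions.

```latex
\begin{aligned}
\mathsf{MLP}_{S_M}(y) &= \sigma\!\left(y\, W^1_{S_M}\right)\left(W^2_{S_M}\right)^{\top},
&\qquad
\mathsf{MHA}_{S_A}(y) &= \sum_{i \in S_A} H_i(y)\, W^O_i, \\
S_M &\in \arg\min_{|S| \le k_M} \left\| \mathsf{MLP}_{S}(y) - \mathsf{MLP}(y) \right\|_2,
&\qquad
S_A &\in \arg\min_{|S| \le k_A} \left\| \mathsf{MHA}_{S}(y) - \mathsf{MHA}(y) \right\|_2 .
\end{aligned}
```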

Pre-trained LLMs are Contextually Sparse

In this section, we present several key observations and theoretical understandings of sparsity in LLMs, upon which the dejavu design is based. We first test the contextual sparsity hypothesis and verify that contextual sparsity exists in pre-trained LLMs in Section 3.1. Then, we build an understanding of why contextual sparsity happens naturally even when LLMs are densely trained in Section 3.2. Finally, we present an observation on residual connections and explain their relationship to contextual sparsity analytically in Section 3.3.

Inspired by prior pruning literature (Molchanov et al., 2016), we find a surprisingly simple method is sufficient to study and verify our hypothesis. In this section, we describe the testing procedure, observation details, and insights of this study.

Verification: Our test is performed on OPT-175B, 66B, and 30B models and various downstream datasets such as OpenBookQA (Mihaylov et al., 2018) and Wiki-Text (Merity et al., 2016). We find the contextual sparsity for every input example with two forward passes of the model. In the first pass, we record a subset of parameters, specifically which attention heads and MLP neurons yield large output norms for the input. In the second pass, each input example only uses the recorded subset of parameters for the computation. Surprisingly, these two forward passes lead to similar prediction or performance on all in-context learning and language modeling tasks.
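The two-pass test can be sketched on a toy MLP block as follows. This is a minimal illustration in plain Python, not the experimental code: the weights, the helper names (`neuron_output`, `mlp`, `record_active_neurons`), and the choice to rank neurons by the norm of their output contribution are assumptions made for the example.

```python
import math

def neuron_output(y, w_in):
    """ReLU activation of a single neuron with input weights w_in."""
    return max(0.0, sum(a * b for a, b in zip(y, w_in)))

def mlp(y, w1, w2, active=None):
    """Two-layer MLP; w1[i] / w2[i] are the input / output weights of neuron i.
    If `active` is given, only those neurons are evaluated."""
    idx = range(len(w1)) if active is None else active
    out = [0.0] * len(w2[0])
    for i in idx:
        h = neuron_output(y, w1[i])
        for j, w in enumerate(w2[i]):
            out[j] += h * w
    return out

def record_active_neurons(y, w1, w2, keep):
    """Pass 1: keep the `keep` neurons whose output contribution has the largest norm."""
    scored = []
    for i in range(len(w1)):
        h = neuron_output(y, w1[i])
        norm = h * math.sqrt(sum(w * w for w in w2[i]))
        scored.append((norm, i))
    scored.sort(reverse=True)
    return sorted(i for _, i in scored[:keep])

# Toy input and weights (made up for illustration).
y = [1.0, -1.0]
w1 = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [2.0, -1.0]]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]

active = record_active_neurons(y, w1, w2, keep=2)  # pass 1: record the subset
dense = mlp(y, w1, w2)                             # full computation
sparse = mlp(y, w1, w2, active)                    # pass 2: recorded subset only
```

In this toy case two of the four neurons are inactive under ReLU, so the second pass reproduces the dense output exactly while touching half of the weights.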

Observation: Figure 3 shows that on average, we can impose up to 80% sparsity on attention heads and 95% sparsity on MLP neurons. As mentioned in Section 2, the OPT-175B model has $2\times$ as many MLP parameters as attention parameters. Therefore total sparsity here is around 85%. Since these are all structured sparsity (heads and neurons), predicting them accurately could potentially lead to a $7\times$ speedup.

Insight: It is intuitive that we can find contextual sparsity in MLP blocks at inference time because of their activation functions, e.g., ReLU or GeLU (Kurtz et al., 2020). Similar observations were made by Li et al. (2022). However, it is surprising that we can find contextual sparsity in attention layers. Note that finding contextual sparsity in attention is not the same as head pruning. We cross-check that different examples have different contextual sparsity. Although 80% of the parameters are not included in the paths for a given example, they might be used by other examples. Next, we will try to understand why contextual sparsity exists in attention blocks.

3.2 Token Clustering in Attention Layers

In the previous section, we have verified that there exists contextual sparsity for a given input in LLMs. In this section, we try to understand the reason for such phenomena, especially in attention layers. We first show an in-depth observation of attention. Then we present a hypothesis that self-attentions are conceptually clustering algorithms. Last we show analytical evidence to support this hypothesis.

Observation: Figure 4 shows the attention map of three different heads from the same layer for an example input. The next token it should predict is “Truck”. Darker color represents higher attention scores. We observe that the middle head is a relatively uniform token-mixing head while the top and bottom ones are “heavy hitter” attention heads (with high attention to “like” and “shipping”). Unsurprisingly, only selecting heavy hitter heads but not uniform heads does not affect the prediction, since uniform heads do not model or encode important token interactions. In the next section, we will also explain in detail how the criteria for selecting uniform attention heads and heads with small output norms are highly correlated.

Hypothesis: We hypothesize that the attention head is performing mean-shift clustering (Derpanis, 2005).

Therefore, the self-attention head can be regarded as one mean-shift step to push input embeddings of different tokens together, if they are already neighbors in a projection space specified by $W_i^Q (W_i^K)^\top$. Different heads learn different projection spaces to perform clustering. These dynamics explain why token embeddings tend to cluster after going through more layers, resulting in high attention scores among cluster members and low scores for non-members. Furthermore, the cluster patterns are different at different heads (more details in Appendix K).
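The clustering intuition can be illustrated with a small sketch. It is a simplification rather than the derivation behind the hypothesis: the learned projection $W_i^Q (W_i^K)^\top$ is replaced by the identity, there is no value projection or residual term, and the four 2-D "token embeddings" are made up. Under these assumptions, one softmax-weighted averaging step pulls neighboring points together while leaving distant groups apart.

```python
import math

def attention_step(points, scale=1.0):
    """Replace every point by the softmax-weighted average of all points,
    with weights given by scaled inner products (identity projections)."""
    new_points = []
    for q in points:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in points]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        new_points.append(
            [sum(w * k[d] for w, k in zip(weights, points)) / z
             for d in range(len(q))]
        )
    return new_points

def spread(cluster):
    """Largest pairwise Euclidean distance within a group of points."""
    return max(math.dist(a, b) for a in cluster for b in cluster)

# Two loose groups of tokens (made-up coordinates).
points = [[3.0, 0.2], [3.0, -0.2], [-3.0, 0.2], [-3.0, -0.2]]
after = attention_step(points)

spread_before = spread(points[:2])
spread_after = spread(after[:2])
```

After one step the spread inside each group shrinks, while the two groups remain far apart, which mirrors the high within-cluster and low cross-cluster attention scores described above.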

The above analysis not only provides an understanding of why contextual sparsity exists naturally in pre-trained LLMs, but also inspires our design of “similarity”-based sparsity prediction for dejavu in Section 4.

3.3 Slowly Changing Embeddings across Layers

Highly similar embeddings in consecutive layers: In Figure 5(a), we show that for the same given input, the cosine similarity between embeddings or activations in two consecutive layers is exceptionally high on 7 different sizes of OPT models. Specifically, we collect activations from each layer while performing OPT model inference on the C4 validation set (Raffel et al., 2019). Taking OPT-175B as an example, starting from the second layer, the similarity between any two consecutive layers is around 0.99, which indicates that when an input is passed through the model, the direction of its embedding changes slowly. Interestingly, the most drastic change happens in the first layer. Furthermore, we increase the gap and investigate the similarity between the embedding at layer $l$ and at layer $l+n$, shown in Figure 5(b). As we increase the gap, the similarity decreases as expected, while the differences in cosine similarity between various choices of $n$ are smaller at the shallower layers. We plot the mean similarity, and the standard deviation is indicated by the shading. Similar plots on more models are presented in Appendix B.

Connection to residuals: We verify that the high similarity in embeddings in LLM inference is due to the residual connection. We first dissect the computation graph inside each transformer layer to understand the cause behind this phenomenon. There are two residual connections inside a transformer layer, one around the attention block, and the other one around the MLP block. The residual connection can be written as $X+F(X)$, where $F$ is either the Multi-Head Attention or two MLP Layers. In Figure 5(c) and Figure 5(d), indeed we can see that $\|X\|$ is significantly greater than $\|F(X)\|$, confirming that embeddings are changing slowly because the residual norm is large.
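The link between a dominant residual and high cosine similarity can be checked numerically. The sketch below uses made-up 4-dimensional vectors and takes the update to be orthogonal to the input, which is the case where only the norm ratio matters; `cosine` and `residual` are helper names introduced for the example.

```python
import math

def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def residual(x, fx):
    """Output of a residual connection, x + F(x)."""
    return [a + b for a, b in zip(x, fx)]

x = [10.0, 0.0, 0.0, 0.0]             # input with norm 10
small_update = [0.0, 1.0, 0.0, 0.0]   # ||F(x)|| = 1, much smaller than ||x||
large_update = [0.0, 10.0, 0.0, 0.0]  # ||F(x)|| = 10, comparable to ||x||

sim_small = cosine(x, residual(x, small_update))  # 10 / sqrt(101)
sim_large = cosine(x, residual(x, large_update))  # 10 / sqrt(200)
```

With the small update the similarity stays above 0.99, whereas an update of comparable norm drops it to roughly 0.71.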

Connection to Contextual Sparsity: We go a step further and try to understand the reason behind the large residual norm with mathematical modeling. We discover that one possible reason for the small $\|F(X)\|$ is high sparsity. For the MLP Block, high sparsity may contribute to the small norm of $F(X)$ because a large portion of outputs have small norms. Similar reasoning applies to the Attention Block, where a large number of attention heads yield small-norm outputs.

Residual Two Sides Bound: Besides empirical reasoning, we formally define the computation of LLMs mathematically. Under our computation model, we can show a shrinking property that is also observed in our practical experiments. Proofs are in Appendix G, H, I.

Let $0<\epsilon_1<\epsilon_2<1$ be the lower and upper bounds of the shrinking factor. Let $x$ be the input and $y$ be the output. We have the residual connection $y=x+F(x)$. For the MLP block $F(x)$, we have $\epsilon_1 \leq \|y-x\|_2 \leq \epsilon_2$. For the attention block $F(x)$, we have $\epsilon_1 \leq \|y-x\|_2 \leq \epsilon_2$.

dejavu

In this section, we present our framework for inference-time contextual sparsity search for LLMs. We introduce the sparsity predictor for MLPs in Section 4.1 and for attention heads in Section 4.2. dejavu’s workflow is shown in Figure 2. Section 4.3 discusses exploiting our observation on LLMs to avoid the sparse prediction overhead with theoretical guarantees. In Section 4.4, we present our optimized implementation that enables end-to-end latency reduction. More details are presented in Section D.

As explained in Section 2, MLP blocks are one of the major bottlenecks for LLM generation ($\frac{2}{3}$ of the FLOPs and I/Os). In this section, we discuss how we achieve wall-clock time speed-up with contextual sparsity in the MLP blocks.

Challenge: Figure 3(b) shows that for a given token, a contextual sparsity of 95% is possible. The contextual sparsity in the MLP block can be identified after computing the activation. However, this only demonstrates the existence of contextual sparsity but brings no benefits in terms of efficiency. A fast and precise prediction is needed to exploit contextual sparsity for end-to-end efficiency. The naive way is to select a subset of neurons randomly. Unsurprisingly, random selection fails to identify the accurate contextual sparsity, resulting in drastic model degradation.

A Near-Neighbor Search Problem: Recall that we verify the existence of contextual sparsity by recording which neurons yield significant norms. Essentially, given the input, the goal is to search for the neurons that have high inner products with the input, because the activation function “filters” low activations. Thus, we formulate the contextual sparsity prediction of an MLP layer as the classical near-neighbor search problem under the inner product metric.

Our $W^1$ (first linear layer) and $y$ (input embedding) in MLP blocks can be viewed as the dataset and query in Definition 4.1 respectively.

Design: The standard state-of-the-art near-neighbor search methods and implementations slow down the computation. Take OPT-175B, where $d$ is 12288, as an example. HNSW (Malkov & Yashunin, 2018) requires more than 10ms, and FAISS (Johnson et al., 2019) requires more than 4ms, while the MLP computation is only 0.2ms. The high dimensionality and complications of data structure implementation on GPU make the search time longer than the MLP computation. Therefore, we choose a neural network classifier as our near-neighbor search method to exploit the fast matrix multiplication on GPU. For each MLP block, we train a small two-layer fully connected network to predict contextual sparsity. Collecting training data is straightforward because we know the contextual sparsity using dense computation. The training algorithm is summarized in Algorithm 1. The sparsified computation in $W^1$ has two steps: (1) Given $y$, the sparsity predictor $\mathsf{SP}_M$ predicts a set $S_M$ of important neurons in weights $W^1$. (2) Compute the sparsified MLP defined in Eq. (1). Note here the sparsity in MLP is highly structured.
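The two pieces described above, labels obtained from dense computation and a two-layer predictor that outputs a neuron set, can be sketched as follows. This is an illustrative toy in plain Python, not Algorithm 1: the function names (`make_labels`, `predict_active_set`), the predictor weights `a` and `b`, the hidden width, and the top-$k$ selection rule are all assumptions, and the training loop itself is omitted.

```python
def make_labels(y, w1):
    """Training target for one input: 1 for every neuron whose pre-activation
    is positive under dense computation, else 0 (w1[i] = weights of neuron i)."""
    return [1 if sum(a * b for a, b in zip(y, w)) > 0 else 0 for w in w1]

def predict_active_set(y, a, b, k):
    """Two-layer fully connected predictor.
    hidden = relu(y @ a), logits = hidden @ b (one logit per MLP neuron);
    returns the indices of the k highest-scoring neurons."""
    hidden = [
        max(0.0, sum(yi * a[i][j] for i, yi in enumerate(y)))
        for j in range(len(a[0]))
    ]
    logits = [
        sum(h * b[j][n] for j, h in enumerate(hidden))
        for n in range(len(b[0]))
    ]
    top = sorted(range(len(logits)), key=lambda n: logits[n], reverse=True)[:k]
    return sorted(top)

# Toy predictor: input dim 2, hidden dim 2, 4 MLP neurons (all values made up).
a = [[1.0, 0.0], [0.0, 1.0]]
b = [[1.0, 0.0, 0.0, -1.0], [0.0, 1.0, -1.0, 0.0]]
s_m = predict_active_set([1.0, 2.0], a, b, k=2)  # logits are [1, 2, -2, -1]
labels = make_labels([1.0, -1.0], [[1.0, 0.0], [0.0, 1.0]])
```

The predictor costs two small matrix multiplications, which is the property that makes it cheaper on a GPU than a graph- or index-based search at this dimensionality.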

4.2 Contextual Sparsity Prediction in Attention Blocks

Attention blocks account for around 30% of the I/Os during generation. In this section, we describe how dejavu exploits contextual sparsity to speed up the Attention blocks.

Challenge: As discussed in Section 3.1, only a few heads perform important computations for a given input token. Similar to the MLP blocks, a fast selection of attention heads without full computation is required to reduce end-to-end latency. Furthermore, one particular challenge of sparse prediction in attention blocks is attention’s dependence on previous tokens. On the one hand, it is unclear whether the past token’s key and value caches are needed for sparse prediction. On the other hand, it is unclear how to handle the missing KV cache of past tokens for the current token computation at the selected head.

A Near-Neighbor Search Problem: Head prediction can also be formulated as a near-neighbor search problem based on our understanding in Section 3.2. Since each head is performing mean-shift clustering, after the first few layers, the current token embedding alone is sufficient for the prediction thanks to the token-mixing nature of the transformer. Therefore, the prediction can be based on the similarity between $y$ and head parameters.

Approach: We design our attention sparse predictor to be the same architecture as the MLP sparse predictor. Each head is regarded as one class and a similar training process is used (Algorithm 1). Then, similar to how MLP prediction is performed, the attention sparsity predictor $\mathsf{SP}_A$ selects a set $S_A$ of heads $H_i$ (see Eq. (2)). To address the problem of missing KV cache for a past token, we exploit the fact that the generation latency is I/O bounded while computation is essentially “free”. Specifically, for the predicted attention heads of input $y$, we compute the corresponding keys and values and store them in the KV cache. But we also save a copy of $y$ for all the other non-selected heads. Then during future token generation, if there is a missing KV cache entry in the selected heads, we can load the stored token embeddings and compute the keys and values together. This requires minimal extra memory access (the main cost is loading the weight matrices).
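The deferred key/value computation can be sketched with a small per-head cache. This is a hypothetical data structure written for illustration, not the system's implementation: `LazyHeadCache` and `project` are made-up names, the projection weights are toy values, and a real system would keep a single shared copy of $y$ per token rather than one per head.

```python
def project(y, w):
    """Row vector times matrix, with w given as a list of rows."""
    return [sum(yi * w[i][j] for i, yi in enumerate(y)) for j in range(len(w[0]))]

class LazyHeadCache:
    """KV cache for one head that stores the raw token embedding whenever
    the head was not selected, and backfills keys/values on demand."""

    def __init__(self, wk, wv):
        self.wk, self.wv = wk, wv
        self.entries = []  # each entry is ("kv", key, value) or ("raw", y)

    def append(self, y, selected):
        if selected:
            self.entries.append(("kv", project(y, self.wk), project(y, self.wv)))
        else:
            self.entries.append(("raw", y))  # defer the projection

    def materialize(self):
        """Called when the head is selected: compute missing keys/values
        from the stored embeddings, then return all keys and values."""
        for t, entry in enumerate(self.entries):
            if entry[0] == "raw":
                y = entry[1]
                self.entries[t] = ("kv", project(y, self.wk), project(y, self.wv))
        keys = [e[1] for e in self.entries]
        values = [e[2] for e in self.entries]
        return keys, values

# Toy head with 2-D embeddings (weights are made up).
cache = LazyHeadCache(wk=[[1.0, 0.0], [0.0, 1.0]], wv=[[2.0, 0.0], [0.0, 2.0]])
cache.append([1.0, 2.0], selected=False)  # head skipped for token 0
cache.append([3.0, 4.0], selected=True)   # head selected for token 1
keys, values = cache.materialize()        # token 0 is backfilled here
```

Because the projection weights for a selected head have to be loaded anyway, the backfill adds computation but little extra weight traffic, which is the trade-off the paragraph above relies on.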

4.3 Reducing Overhead with Asynchronous Execution

Sparse prediction overhead may easily increase the end-to-end latency rather than reduce it despite the reduction in FLOPs. Therefore, we introduce a look-ahead sparse prediction method, inspired by our observations in Section 3.3.

where set $S_A^l$ is the contextual sparsity for the Attention block, and set $S_M^l$ is the contextual sparsity for the MLP block at the $l$-th layer. Note that the computation at the Attention and MLP blocks has to wait for the sparse predictor decision. This overhead potentially outweighs the saving from the Attention and MLP blocks in terms of latency.

Approach: In Section 3.3, we present the slowly evolving embedding phenomenon, which provides opportunities to relax the sequential computation to parallel computation. Along with the observation of low computation intensity during generation, we parallelize the sparse prediction with the computation of each block (see Figure 2). The computation can be written as follows:

We remark that $S_A^{l+1}$ and $S_M^{l+1}$ can be computed in parallel with $\widetilde{y}_l$ or $\widehat{y}_l$, while the previous 4 steps are sequential.
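The dependency structure of the look-ahead scheme can be sketched as below. This is a sequential simulation with toy callables, intended only to show which input each prediction depends on. The function `lookahead_forward`, the dictionary layout of a layer, and the choice to attach the predictor to the layer it predicts for are assumptions made for the example; in a real system the prediction would overlap with the block computation (for example on a separate stream) instead of running in order.

```python
def lookahead_forward(y, layers, init_sets):
    """Run the layers with look-ahead sparsity prediction.

    layers: list of dicts with callables
        "attn"(y, S_A), "mlp"(y, S_M), "sp_attn"(y), "sp_mlp"(y).
    init_sets: (S_A, S_M) for the first layer, computed without look-ahead.
    Returns the final output and the sets actually used at each layer.
    """
    s_a, s_m = init_sets
    used = []
    for l, layer in enumerate(layers):
        next_sets = None
        if l + 1 < len(layers):
            nxt = layers[l + 1]
            # Depends only on y (the input to layer l), so it does not have
            # to wait for the attention/MLP computation of layer l.
            next_sets = (nxt["sp_attn"](y), nxt["sp_mlp"](y))
        y_tilde = layer["attn"](y, s_a)    # sparse attention at layer l
        y_hat = layer["mlp"](y_tilde, s_m)  # sparse MLP at layer l
        used.append((s_a, s_m))
        y = y_hat
        if next_sets is not None:
            s_a, s_m = next_sets
    return y, used

# Toy layers: each block adds 1, predictors just tag the input they saw.
toy_layers = [
    {
        "attn": lambda y, s: y + 1,
        "mlp": lambda y, s: y + 1,
        "sp_attn": lambda y: ("A", y),
        "sp_mlp": lambda y: ("M", y),
    }
    for _ in range(3)
]
y_out, used_sets = lookahead_forward(0, toy_layers, (("A", "init"), ("M", "init")))
```

In the toy run, the sets used at layer 1 were predicted from the input to layer 0, which is exactly the relaxation that the slowly changing embeddings make acceptable.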

Let $\epsilon \in (0,1)$. Let $y_l$ be the input at the $l$-th layer and $y_{l-1}$ be the input at the $(l-1)$-th layer. Suppose that $\|y_l - y_{l-1}\|_2 \leq \epsilon$. Then, for any parameters $c, \tau$ such that $\epsilon < O(c\tau)$, solving $\mathsf{MaxIP}(c,\tau)$ is sufficient to solve $\mathsf{MaxIP}(0.99c,\tau)$.
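The intuition behind this guarantee can be seen from a one-line bound. This is a sketch, not the proof from the appendix; $w$ denotes an arbitrary weight vector in the search dataset and is notation introduced here. By the Cauchy-Schwarz inequality,

```latex
\left| \langle y_l, w \rangle - \langle y_{l-1}, w \rangle \right|
= \left| \langle y_l - y_{l-1},\, w \rangle \right|
\le \| y_l - y_{l-1} \|_2 \, \| w \|_2
\le \epsilon \, \| w \|_2 .
```

So when $\epsilon$ is small, any $w$ with a large inner product against $y_{l-1}$ keeps an almost equally large inner product against $y_l$, and a search answered with the previous layer's input remains an approximate answer for the current one.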

4.4 Hardware-efficient Implementation

We describe how dejavu is implemented in a hardware-efficient manner to realize the theoretical speedup of contextual sparsity. Taking into account hardware characteristics leads to over $2\times$ speedup compared to an optimized dense model, and $4\times$ faster than a standard sparse implementation.

We highlight some hardware characteristics of GPUs:

Small-batch generation is bottlenecked by GPU memory I/Os (NVIDIA, 2022; Ivanov et al., 2021; Dao et al., 2022). This is because of low arithmetic intensity. For each element loaded from GPU memory, only a small number of floating point operations are performed.

GPUs are block-oriented devices: loading a single byte of memory takes the same time as loading a block of memory around that same address (Harris, 2013). The block size is usually 128 bytes for NVIDIA GPUs (Cook, 2012).
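The low arithmetic intensity of small-batch generation can be made concrete with a short estimate. The sketch assumes 2-byte (fp16) weights, counts one multiply and one add per weight per input row, and ignores activation traffic; the function name and the example shapes are illustrative.

```python
def weight_arithmetic_intensity(rows, cols, batch=1, bytes_per_weight=2):
    """FLOPs per byte of weight traffic for a (batch x rows) @ (rows x cols)
    product, assuming every weight is read from memory exactly once."""
    flops = 2 * batch * rows * cols          # one multiply + one add per weight
    weight_bytes = rows * cols * bytes_per_weight
    return flops / weight_bytes

# First MLP layer of a model with d = 12288 (weight shape d x 4d).
d = 12288
intensity_b1 = weight_arithmetic_intensity(d, 4 * d, batch=1)
intensity_b64 = weight_arithmetic_intensity(d, 4 * d, batch=64)
```

Under these assumptions, batch size 1 performs about one floating point operation per byte loaded, and the ratio grows linearly with the batch size, which is why small-batch generation is limited by memory I/O rather than by compute.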

These characteristics present some challenges in implementing contextual sparsity. However, they can be addressed with classical techniques in GPU programming.

Kernel fusion: A standard implementation of sparse matrix-vector multiply (e.g., in PyTorch) that separately indexes a subset of the matrix $W^1_{S_M}$ before multiplying with the input $y$ would incur $3\times$ the amount of memory I/Os. Therefore, to avoid such overhead, we fuse the indexing and the multiplication steps. Specifically, we load a subset of $W^1_{S_M}$ to memory, along with $y$, perform the multiply, then write the result back. This fused implementation (in Triton (Tillet et al., 2019)) yields up to $4\times$ speedup compared to a standard PyTorch implementation (Appendix E).
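One plausible way to account for the $3\times$ figure is to count how often the selected weights cross the GPU memory interface. The accounting below is an assumption made for illustration (gather reads the slice, writes a temporary copy, and the multiply reads that copy again), and it ignores the much smaller traffic for $y$ and the output.

```python
def weight_traffic(k, d, fused):
    """Weight elements moved to or from GPU memory for a k-by-d slice."""
    if fused:
        return k * d                  # selected rows are read once in the kernel
    gather_read = k * d               # indexing reads the selected rows
    gather_write = k * d              # and writes them to a temporary buffer
    matmul_read = k * d               # the multiply reads the temporary buffer
    return gather_read + gather_write + matmul_read

# Example: 100 selected neurons with d = 12288 (values chosen for illustration).
unfused = weight_traffic(100, 12288, fused=False)
fused = weight_traffic(100, 12288, fused=True)
```

Under this accounting the unfused path moves three times as many weight elements as the fused kernel.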

Memory coalescing: In the dense implementation, the weight matrices of the two linear layers in the MLP are stored as $(W^1)^\top$ and $W^2$ so that no extra transpose operation is needed. They are conventionally stored in row-major format. In the sparse implementation, this layout allows us to load $(W^1_{S_M})^\top$ optimally (the second dimension is contiguous in memory). However, when we need to load $W^2_{S_M}$, this format significantly slows down memory loading, as the indices in $S_M$ point to non-contiguous memory. We simply store these matrices in column-major format (i.e., store $(W^2)^\top$ in row-major format), then use the same fused kernel as above. Similarly, in attention blocks, we store the attention output projection $W^O$ in column-major format.
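The effect of the storage layout on contiguity can be illustrated by listing the flat memory offsets that a selection touches. The helpers below are written for this example only and use a tiny 4-by-8 row-major matrix; they model address patterns, not actual GPU memory transactions.

```python
def flat_offsets_row_major(n_rows, n_cols, selected, axis):
    """Flat element offsets touched when selecting `selected` indices along
    `axis` (0 = rows, 1 = columns) of a row-major n_rows x n_cols matrix."""
    if axis == 0:
        return [r * n_cols + c for r in selected for c in range(n_cols)]
    return [r * n_cols + c for r in range(n_rows) for c in selected]

def contiguous_runs(offsets):
    """Number of maximal runs of consecutive offsets."""
    runs = 1
    for prev, cur in zip(offsets, offsets[1:]):
        if cur != prev + 1:
            runs += 1
    return runs

selected = [1, 3]
row_select = flat_offsets_row_major(4, 8, selected, axis=0)  # whole rows
col_select = flat_offsets_row_major(4, 8, selected, axis=1)  # scattered columns

runs_rows = contiguous_runs(row_select)
runs_cols = contiguous_runs(col_select)
```

Selecting along the row axis yields one long contiguous run per selected index, while selecting along the column axis yields many single-element accesses. Storing the transposed matrix turns the second pattern into the first.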

These two techniques (kernel fusion and memory coalescing) make dejavu hardware-efficient, yielding up to $2\times$ speedup end-to-end compared to the state-of-the-art FasterTransformer (Section 5.1).

Empirical Evaluation

In Section 5.1, we present the end-to-end results that show dejavu achieves over $2\times$ reduction in token generation latency compared to the state-of-the-art FasterTransformer and over $6\times$ compared to Hugging Face with no accuracy loss. In Section 5.2, we perform a list of ablation studies such as independent evaluation on the inference-time contextual sparsity of the MLP block and the Attention block (details are presented in Section C). Finally, we present additional results to demonstrate the future possibility of sparsifying entire LLMs via layer skipping in Section C.3.

Experiment Setting: We compare the accuracy of dejavu-OPT against the original OPT model on two language modeling datasets Wiki-Text (Merity et al., 2016) and C4 (Raffel et al., 2019) and seven few-shot downstream tasks: CB (de Marneffe et al., 2019), COPA (Gordon et al., 2012), Lambada (Radford et al., 2019), OpenBookQA (Mihaylov et al., 2018), PIQA (Bisk et al., 2020), RTE (Giampiccolo et al., 2007), Winogrande (ai2, 2019). We use lm-eval-harness (Gao et al., 2021) for zero-shot and five-shot tasks. We collect training data for the sparsity predictor using 500 random data points from the C4 training dataset. Our experiments are conducted on NVIDIA A100 80GB GPU servers.

No accuracy drop until 75% sparsity: In Figure 6, we present dejavu-OPT-175B’s accuracy trend. In a zero-shot setting, the average accuracy across tasks does not drop until 75% sparsity. A similar trend can be observed for the five-shot setting, which verifies the model’s ability for in-context learning. This result is exceptionally encouraging given our observation in Figure 1(a), where we could impose 85% sparsity when allowed full computation.

Over $2\times$ latency reduction: Figure 7 presents the latency speed-up for token generation with OPT-175B at batch size 1, where dejavu achieves the best performance. At around 75% sparsity, dejavu speeds up generation by $1.8$-$2\times$ compared to the state-of-the-art FasterTransformer (FT, http://github.com/NVIDIA/FasterTransformer) and by $4.8$-$6\times$ compared to the Hugging Face (HF) implementation (http://github.com/huggingface/transformers).

5.2 Ablation Results

Contextual Sparsity for Larger Batches: Although this paper focuses on latency-sensitive settings, we demonstrate that dejavu generalizes to larger batches. We present the union contextual sparsity (the fraction of neurons/heads that are not used by any of the inputs in the batch) for different batch sizes for MLP and Attention blocks in Figures 8 and 11, respectively. The union operation is essential to realize a fast sparse GEMM. Surprisingly, the number of MLP neurons and Attention heads that dejavu activates does not grow linearly with the batch size. This suggests a power-law distribution rather than a uniform distribution of parameter access across input examples. This provides an opportunity to extend dejavu to the high-throughput setting. For example, we can first pre-process the inputs and batch similar inputs to enjoy a higher level of union contextual sparsity.
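Union contextual sparsity, as defined in the parenthesis above, can be computed directly from the per-input active sets. The sketch below uses made-up active sets over 8 units; `union_sparsity` is a name introduced for the example.

```python
def union_sparsity(active_sets, n_units):
    """Fraction of units (neurons or heads) not used by any input in the batch."""
    used = set().union(*active_sets)
    return 1.0 - len(used) / n_units

# Three inputs whose active sets overlap heavily (toy values).
batch = [{0, 1}, {1, 2}, {2}]
single = union_sparsity(batch[:1], n_units=8)  # one input: 2 of 8 units used
whole = union_sparsity(batch, n_units=8)       # three inputs: 3 of 8 units used
```

Here tripling the batch size raises the number of used units from 2 to 3 rather than to 6, which is the kind of sub-linear growth that overlapping (non-uniform) parameter access produces.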

Contextual sparsity on MLP blocks: We study the contextual sparsification of the MLP block in OPT-175B. We leave the Attention block as dense computation. Table 4 shows the model performance at 85% sparsity. The MLP sparse predictor introduces no accuracy loss on both zero-shot tasks and language modeling. In the training of the MLP sparse predictor, we observe that the sparse predictor achieves high validation accuracy. The shallow layer seems easier to model because the predictor has validation accuracy over 99% in the shallow layers and drops to around 93% in the ending layers.

Contextual sparsity on attention blocks: In this section, we study the sparse predictor for the Attention block on OPT-175B and leave the MLP block as dense computation. Table 4 displays the test accuracy on zero-shot tasks and perplexity on the language modeling datasets. In summary, the Attention sparse predictor introduces no accuracy loss at around 50% sparsity. During the training of the Attention sparse predictor, we observe different trends compared to the MLP sparse predictor. The validation accuracy is around 93% in the middle layers and near 99% in the shallow and deep layers.

Contextual Sparsity on Smaller Models: Our main experiments focus on OPT-175B. Here, we verify dejavu’s effectiveness on a smaller model, specifically OPT-66B. In Table 5, we summarize the accuracy on zero-shot tasks at 50% sparsity. Similar to dejavu-OPT-175B, we notice no accuracy loss.

Contextual Sparsity on Other Models: We expand the evaluation to another model family. In Table 6, we summarize the accuracy at attention sparsity 50% and MLP sparsity 30%. Similar to OPT family, we notice no accuracy loss. The lower sparsity level in MLP is due to the difference in activation function.

Non-Contextual Sparsity: As we mentioned in Section 1, one could predict sparsity without contextual information. For non-contextual sparsity, we rely on the original embedding at the input layer. At every block, we first pass the original embedding to record a subset of parameters yielding a large norm. In the second pass, the embedding at every layer only uses the recorded subset. As shown in Figure 1, non-contextual prediction is not sufficient and leads to accuracy losses even at 50% sparsity. This result verifies our design choices of relying on the activation at every layer as input to make contextual sparsity predictions.

Compatibility with Quantization: Quantization is another promising direction for efficient language models. We investigate the possibility of combining contextual sparsity with quantization techniques. For dejavu-OPT-175B, we set the entire model sparsity at 75%. For quantization, we apply 4-bit quantization on model weights (W4A16). As shown in Table 7, the combination of quantization and dejavu almost always achieves better accuracy than dejavu or quantization alone. This suggests that the approximation errors from these two directions do not get compounded.

Conclusion

Our main goal is to make LLM inference efficient so that their powerful in-context learning abilities can be used in more application domains. We observe that contextual sparsity can be accurately predicted with lightweight learning-based algorithms. This motivated us to design dejavu, which uses asynchronous lookahead predictors and hardware-efficient sparsity to speed up LLM inference in wall-clock time. Our encouraging empirical results validate that contextual sparsity can reduce inference latency by over 2× compared to the state-of-the-art FasterTransformer without model quality drops. Our method is a step towards making LLMs more accessible to the general community, which could unlock exciting new AI applications.

Acknowledgements

We would like to thank Ryan Spring, Laurel Orr, Guangxuan Xiao, Eric Han, Xun Huang, Daniel Y. Fu, Benjamin Spector, Ruan Silva, Diana Liskovich, and the anonymous reviewers for helpful discussions and feedback. We acknowledge the generous support by Together Computer, which enabled the necessary partial computations in this work.

References

Appendix A Related Work

Generative LLM inference. Taking OPT-175B as an example and assuming 6 A100 80GB PCIe GPUs, we use the hardware specifications to compare the two main phases of LLM inference, namely prompting and token generation, in Table 1, and the two major components, namely the Multi-Head Attention block and the MLP block, in Table 2. In practice, the token generation phase usually dominates the end-to-end test latency due to IO latency. Generating only two tokens takes about the same latency as prompting. Further, during token generation, the MLP block is 2× more expensive in both FLOPs and IO access. The hardware is often at low utilization because memory reads and writes are more limited on modern hardware than tensor core computation.

Given the rapid development of LLM, there is an emergence of systems that are specialized for LLM inference, such as Faster Transformer (NVIDIA, ), Orca (Yu et al., 2022), LightSeq (Wang et al., 2021), PaLM inference (Pope et al., 2022), TurboTransformers (Fang et al., 2021), and Deepspeed-Inference (Aminabadi et al., 2022). In practice, the token generation phase usually dominates the end-to-end inference time. Although the state-of-the-art systems introduce some helpful system optimizations for speedup, there is a lack of careful algorithm and system co-design to unleash the full potential of hardware efficiency during the LLM inference computation.

Near-neighbor Search for Efficient Deep Neural Networks. Near-neighbor Search is a well-studied problem with wide applications in recommendation system (Xue et al., 2017; Hall & Attenberg, 2015), question answering (Boytsov et al., 2016; Seo et al., 2019; Chang et al., 2020) and natural language processing (Bengio et al., 2003; Lee et al., 2016). There has been a line of work using Near-neighbor Search techniques such as Locality-sensitive hashing (Gionis et al., 1999) and Graph-based indexing (Malkov et al., 2014) for efficient deep neural network training or inference (Zhang et al., 2018; Chen et al., 2019, 2020a; Kitaev et al., 2020; Chen et al., 2021b, a; Liu et al., 2022).

Quantization, pruning, distillation for LLM inference. Various system relaxations have been studied for decades for model inference in machine learning. For example, quantization (Han et al., 2015; Jacob et al., 2018; Nagel et al., 2019; Zhao et al., 2019), pruning (Molchanov et al., 2016; Liu et al., 2018; He et al., 2019; Hoefler et al., 2021), and distillation (Hinton et al., 2015; Cho & Hariharan, 2019; Tang et al., 2019; Touvron et al., 2021) have been applied to speed up the inference of machine learning models. Active research has recently attempted to apply such techniques to LLM inference. For example, ZeroQuant (Yao et al., 2022) and nuQmm (Park et al., 2022) implement customized CUDA kernels to support tensor-wise or group-wise quantization for LLM inference; LLM.int8 (Dettmers et al., 2022) adopts a mixed INT8/FP16 computation to diminish the influence of activation outliers; SmoothQuant (Xiao et al., 2022) enables efficient 8-bit weight and activation for LLM inference; GPTQ (Frantar et al., 2022) adopts a one-shot weight quantization method based on approximate second-order information for accuracy and efficiency; SparseGPT (Frantar & Alistarh, 2023) introduces an approximate sparse regression solver to enable sparsity in LLM inference; (Bansal et al., 2022) reports that a small set of attention heads can perform primitive induction operations associated with in-context learning, and uses this property to prune LLMs for acceleration.

Residual connections in neural networks. Residual connections show great advantages for neural network generalization: they provide additional paths for activations to reach the latter parts of the neural network by skipping some layers (He et al., 2016). Networks with residual connections can be viewed as ensembles of multiple shallow neural networks (Veit et al., 2016). Plenty of active research has discussed the effectiveness of residual connections (Balduzzi et al., 2017; Bello et al., 2021; Allen-Zhu & Li, 2019; Frei et al., 2019). However, as far as we know, there is no prior work that leverages the property of residual connections to improve the efficiency of LLM inference.

Appendix B Additional Observation on Slowly Changing Observation

First, we present more plots on the cosine similarity between representations. Figure 9 plots the cosine similarity between activation across layers on OPT family. It is evident that similarity is high for the larger models.

There are two residual connections inside a transformer layer, one around the attention block and the other around the MLP block. The residual connection can be written as $X + F(X)$, where $F$ is either the Multi-Head Attention or the two-layer MLP. Figure 10 plots the cosine similarity between $X$ and $X + F(X)$, which is close to 1.0, and the cosine similarity between $X$ and $F(X)$, which is close to 0.0. This happens because $\|X\|$ is significantly greater than $\|F(X)\|$, shown in purple. In the first layer, $\|F(X)\|$ is larger, which explains the low cosine similarity. The magnitude of the $L_2$ norm differs across models; however, we observe a similar trend with models of different sizes. There is a normalization layer before $F(X)$, and the layer normalization scales $\|X\|$ to a consistent magnitude across layers (e.g., 85 for OPT-30B, 110 for OPT-175B), but does not necessarily scale down $\|X\|$.
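The norm argument above can be illustrated numerically. The following is a minimal sketch, assuming random Gaussian vectors stand in for $X$ and $F(X)$; the helper names, the dimension, and the norms 100 and 5 are illustrative choices, not values from the paper.

```python
import math
import random


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(p * q for p, q in zip(a, b))
    return dot / (math.sqrt(sum(p * p for p in a)) * math.sqrt(sum(q * q for q in b)))


def rescale(v, target_norm):
    """Return v rescaled to have the requested L2 norm."""
    norm = math.sqrt(sum(p * p for p in v))
    return [p * target_norm / norm for p in v]


def residual_similarity(dim=512, x_norm=100.0, f_norm=5.0, seed=0):
    """Compare cos(X, X + F(X)) and cos(X, F(X)) when ||X|| >> ||F(X)||.

    X and F(X) are hypothetical stand-ins drawn as independent Gaussian
    vectors; real activations are not random, so this only shows the
    effect of the norm gap.
    """
    rng = random.Random(seed)
    x = rescale([rng.gauss(0, 1) for _ in range(dim)], x_norm)
    f = rescale([rng.gauss(0, 1) for _ in range(dim)], f_norm)
    out = [p + q for p, q in zip(x, f)]
    return cosine(x, out), cosine(x, f)


sim_with_output, sim_with_update = residual_similarity()
```

With a norm ratio of 20, the update barely rotates $X$, so the first similarity stays near 1 even though the second is near 0.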

Appendix C Additional Experiment Detail

To help understand where the speed-up comes from when the batch size is greater than 1, we present the Union Contextual Sparsity (the fraction of neurons/heads that are not used by any of the inputs in the batch) for different batch sizes for MLP and Attention blocks, respectively, in Figure 11. Union Contextual Sparsity is calculated as 1.0 minus the ratio of the size of the union of activated MLP neurons or Attention heads in the batch to the total number of neurons or heads. The union operation is essential to realize a fast sparse GEMM.
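The definition above can be written as a short function. This is a minimal sketch; the function name and the representation of each input's activated units as a Python set of indices are assumptions made for illustration.

```python
def union_contextual_sparsity(active_sets, total_units):
    """1.0 - |union of activated units over the batch| / total_units.

    active_sets: one set of activated neuron (or head) indices per input.
    total_units: total number of neurons (or heads) in the block.
    """
    union = set().union(*active_sets)
    return 1.0 - len(union) / total_units


# Three inputs that share most of their active units.
batch = [{0, 1, 2}, {1, 2, 3}, {2, 7}]
sparsity = union_contextual_sparsity(batch, total_units=10)
```

When inputs activate overlapping units, the union grows more slowly than the batch size, which is the behaviour the paragraph describes.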

Surprisingly, the number of MLP neurons/Attention heads that dejavu activates does not grow linearly with the batch size. This suggests a power-law distribution rather than a uniform distribution of parameter access across input examples. Further, a larger batch size can easily lead to out-of-memory errors in long-sequence settings due to the limited GPU memory, the large model size, and the stored KV cache. For example, the total GPU memory of eight 80GB A100 GPUs is 640GB. Model parameters take around 350GB for OPT-175B. The KV cache for a batch size of 32 with a sequence longer than 1920 tokens already fills up the GPU memory.
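The memory figures above can be checked with simple arithmetic. The sketch below assumes the publicly documented OPT-175B shape (96 layers, hidden size 12288), 2-byte (FP16) keys and values, and decimal gigabytes; these constants are assumptions not stated in this section.

```python
NUM_LAYERS = 96        # assumed OPT-175B depth
HIDDEN_SIZE = 12288    # assumed OPT-175B hidden dimension
BYTES_PER_VALUE = 2    # assumed FP16 storage


def kv_cache_bytes(batch_size, seq_len):
    """Bytes needed to store keys and values for every layer and token."""
    per_token = 2 * NUM_LAYERS * HIDDEN_SIZE * BYTES_PER_VALUE  # key + value
    return per_token * batch_size * seq_len


param_bytes = 350e9                   # "around 350GB" from the text
cache_bytes = kv_cache_bytes(32, 1920)
total_gb = (param_bytes + cache_bytes) / 1e9
```

Under these assumptions the cache is roughly 290GB, so parameters plus cache come to just under the 640GB of the eight-GPU node, consistent with the statement above.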

C.2 Near Neighbor classifier

In the dejavu framework, any near-neighbor search method under the inner product metric would be sufficient to predict a sparsity pattern. The purpose of training a predictor is to reduce the cost of on-the-fly prediction; it does not train the model itself.
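To make the inner-product view concrete, here is a minimal sketch of the exact (brute-force) search that any approximate near-neighbor method or learned predictor would stand in for. The function name and the row-per-neuron weight layout are assumptions for illustration, not the paper's implementation.

```python
def top_k_neurons(neuron_weights, y, k):
    """Return the indices of the k neurons with the largest inner product with y.

    neuron_weights: list of weight vectors, one per neuron (hypothetical layout).
    y: the input activation vector.
    """
    scores = [sum(w * v for w, v in zip(row, y)) for row in neuron_weights]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


weights = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.5, 0.5]]
selected = top_k_neurons(weights, y=[1.0, 0.2], k=2)
```

This exhaustive search costs as much as the dense layer itself, which is why a cheaper search structure or predictor is needed in practice.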

For example, in our exploration stage mentioned in Section 4.1, we adopt HNSW, a state-of-the-art near-neighbor search method, to predict the MLP sparsity pattern, and we can see from the following table that there is no drop in perplexity at a 90% sparsity ratio. However, due to the high dimensionality of the embedding and HNSW’s reliance on the CPU, the time HNSW takes to identify the sparsity pattern is 10ms, which is longer than the MLP computation.

In our paper, we choose a neural network classifier as our near-neighbor search method to take advantage of fast matrix multiplication on the GPU. Training such classifiers to predict sparsity patterns is not only much cheaper than training the model, but also conceptually different from it.

C.3 Future Possibility: Skipping Layer

Deja Vu currently sparsifies from the perspective of model width. Here, we explore the possibility of sparsification along model depth. As shown in Section 3, the activations of large language models change slowly across blocks. This property can be leveraged to increase the efficiency of a trained model by parallelizing, reordering, or skipping certain intermediate sub-blocks without significantly impacting the overall accuracy.

Improving the inference efficiency of Transformer models is a challenging task due to the sequential execution of Transformer layers. Each sub-block depends on the output of the previous one, leading to low hardware efficiency, particularly during the token generation phase, where each forward pass is computed for only one token. Moreover, the sequential execution of blocks and sub-blocks yields computation bubbles, and the latter involves a large amount of communication overhead. Here, we present an interesting observation that can potentially alleviate these challenges. We found that the activation of the model changes slowly across blocks. Specifically, the cosine similarity of activations between adjacent blocks is often above 0.99. This suggests that a block might take an earlier activation as input without significantly affecting the output, so it may be possible to parallelize, reorder, or even skip blocks while maintaining a similar output. Some existing models, such as GPT-J (Wang & Komatsuzaki, 2021), GPT-NeoX (Black et al., 2022), and PaLM (Chowdhery et al., 2022), already place the Attention block and MLP block in parallel during training to facilitate parallel computation and reduce the communication overhead.

Here we investigate the possibility at inference time. Surprisingly, we found that parallelizing those blocks for models that are trained in a sequential manner does not significantly hurt the performance on downstream tasks. Table C.3 presents some preliminary results for OPT-175B and Bloom.

Given the activation $y$ and Transformer layer $l$, we have:

Parallelizing two blocks refers to placing the Attention and MLP blocks in parallel, i.e.:

Parallelizing four blocks then parallelizes the blocks of two Transformer layers, defined as follows:

Skipping layers is straightforward: it drops an entire Transformer layer for every $n$ layers.
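The three variants described above can be sketched in code. This is an assumed reading of the verbal descriptions: each block is treated as a residual update, the function names are hypothetical, normalization layers are omitted, and "skipping" is interpreted as dropping every $n$-th layer.

```python
def sequential_layer(y, mha, mlp):
    """Standard layer: the MLP sees the output of the attention residual."""
    y = y + mha(y)
    return y + mlp(y)


def parallel_two(y, mha, mlp):
    """Attention and MLP both read the same input activation."""
    return y + mha(y) + mlp(y)


def parallel_four(y, layer_a, layer_b):
    """All four blocks of two consecutive layers read the same input."""
    (mha_a, mlp_a), (mha_b, mlp_b) = layer_a, layer_b
    return y + mha_a(y) + mlp_a(y) + mha_b(y) + mlp_b(y)


def skip_every_n(y, layers, n):
    """Run layers sequentially, dropping every n-th layer entirely."""
    for index, (mha, mlp) in enumerate(layers, start=1):
        if index % n == 0:
            continue
        y = sequential_layer(y, mha, mlp)
    return y


# Toy blocks with small outputs, mimicking slowly changing activations.
toy_mha = lambda v: 0.01 * v
toy_mlp = lambda v: 0.02 * v
gap = abs(sequential_layer(1.0, toy_mha, toy_mlp) - parallel_two(1.0, toy_mha, toy_mlp))
```

When each block's update is small relative to its input, the sequential and parallel outputs differ only by a second-order term, which is the intuition behind the observation in this section.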

We are surprised to find that parallelizing two layers preserves accuracy on a series of tasks across models. Besides, randomly skipping 25% of the layers does not lead to a catastrophic loss in quality. Our findings suggest that, from the downstream task perspective, the activation patterns within the model are relatively consistent across different blocks, providing a potential avenue for future research on model compression and optimization.

Appendix D Implementation Details

Figure 12 presents a more detailed workflow of dejavu. The left diagram shows how an input $y$ performs the sparse MHA with selected indices $\{0, 3\}$, predicted by the head predictor. Similarly, the right diagram shows how an input $y$ performs the sparse MLP with selected indices $\{0, 2\}$, predicted by the neuron predictor of that layer.

Appendix E Benchmarking Sparse MLP and Sparse Attention

We validate that our hardware-aware implementation of sparse MLP and sparse attention (Section 4.4) yields wall-clock speed up compared to both dense MLP/attention and compared to the standard implementation in PyTorch.

Recall that our implementation fuses the sparse indexing and the multiplication $(W_{S_M}^{1})^{\top} y$ for weight matrices $(W^{1})^{\top}$ and input vector $y$. In cases where we need to index $W^{2}_{S_M}$, we store the transpose of $W^{2}$ to ensure memory coalescing. For the baseline implementation in PyTorch, we index $(W_{S_M}^{1})^{\top}$ as a separate operation before multiplying with $y$, which incurs more memory reads/writes.
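The idea of fusing indexing with multiplication can be illustrated on the CPU. The sketch below is not the paper's GPU kernel: it assumes a ReLU MLP, a hypothetical layout in which both weight matrices are stored with one row per neuron, and it loops over the selected neurons directly instead of first materializing an indexed copy of the weights.

```python
def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def sparse_mlp(w1_rows, w2_rows, y, active):
    """Compute the MLP output using only the neurons listed in `active`.

    w1_rows[i]: input weights of neuron i (row i of the transposed W^1).
    w2_rows[i]: output weights of neuron i (row i of the stored transpose of W^2).
    Reading one contiguous row per neuron mirrors the coalescing argument.
    """
    out = [0.0] * len(w2_rows[0])
    for i in active:
        hidden = max(dot(w1_rows[i], y), 0.0)  # ReLU activation of neuron i
        for j, w in enumerate(w2_rows[i]):
            out[j] += hidden * w
    return out


w1 = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
w2 = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
y = [1.0, 2.0]
dense = sparse_mlp(w1, w2, y, active=[0, 1, 2])
sparse = sparse_mlp(w1, w2, y, active=[0, 1])
```

Neuron 2 has a negative pre-activation for this input, so skipping it leaves the output unchanged while touching fewer weights.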

Similarly, we fuse the sparse indexing and the multiplications $(W_{S_A}^{Q})^{\top} y$, $(W_{S_A}^{K})^{\top} y$, $(W_{S_A}^{V})^{\top} y$ for weight matrices $(W^{Q})^{\top}$, $(W^{K})^{\top}$, $(W^{V})^{\top}$ and input vector $y$. Note that we usually concatenate all three matrices in the standard implementation, but we separate them here for clarity. In cases where we need to index $W^{O}_{S_A}$, we store the transpose of $W^{O}$ to ensure memory coalescing.

In Figure 13 and Figure 14, our sparse MLP and attention implementations are 4-5× faster than the baseline implementation in PyTorch, and remain faster than the dense version for density up to 0.8.

Appendix F Notations and Basic Definitions

The connection between eigenvalues and singular values is

We use the notation $A \succeq 0$ to denote that the matrix $A$ is positive semidefinite (psd). Mathematically, $A \succeq 0$ means that for all vectors $x$, we have $x^{\top} A x \geq 0$.

Similarly, for two square matrices $A$ and $B$, we use $A \succeq B$ to denote the case where for all vectors $x$, $x^{\top} A x \geq x^{\top} B x$.

Appendix G Subspace Embeddings and Norm Preserving

Further, if $d = O(k + \log(1/\delta))$, then we have

The above condition implies that $f$ is a shrinking operator, but does not shrink arbitrarily small.

Given $d \geq \Omega(\epsilon^{-2}(k + \log(1/\delta)))$, by using Lemma G.11, we have

By the property of subspace embedding, we know that if $d \geq \Omega(\epsilon^{-2}(s + \log(1/\delta)))$,

By the property of the function $f$, we only need to consider $\|y\|_{1} = 1$; this implies that

where the first step follows from $\|\overline{V} y\|_{2} \leq (1+\epsilon)\|y\|_{2}$, the second step follows from $\|y\|_{2} \leq \|y\|_{1}$ and the last step follows from $\|y\|_{1} = 1$.

where the first step follows from $(1-\epsilon)\|y\|_{2} \leq \|\overline{V} y\|_{2}$, the second step follows from $\frac{1}{\sqrt{s}}\|y\|_{1} \leq \|y\|_{2}$ and the last step follows from $\|y\|_{1} = 1$.

Combining Eq. (G.1) and Eq. (G.1) together, we have

By $V = \frac{1}{2}\tau\overline{V}$ and $\|x\|_{2} = 1$, we have

G.2 ReLU Functions

Suppose $s \geq \Omega(\epsilon^{-2}\log(1/\delta))$.

Using Lemma G.6, Fact G.7, we can show that for each fixed

By a standard $\epsilon$-net argument (Lemma G.9), the number of net points in $\mathcal{X}$ is at most $(10/\epsilon)^{O(k)}$.

Taking a union bound over all the net points, we can show that for all $x \in \mathcal{X}$

holds with probability $1 - \delta/2$ and $s \geq \Omega(\epsilon^{-2} k \log(1/(\delta\epsilon)))$.

Further, using Lemma G.11, we can show that

Rescaling $\epsilon$, we complete the proof.

G.3 Folded Gaussian Distribution

We state a standard tool from literature,

Let $X \sim \mathcal{X}_{k}^{2}$ be a chi-squared distributed random variable with $k$ degrees of freedom. Each one has zero mean and variance $\sigma^{2}$.

Further, if $k \geq \Omega(\epsilon^{-2} t)$ and $t \geq \Omega(\log(1/\delta))$, then we have

$|v_{i}|$ follows i.i.d. from the following distribution: with half probability $|v_{i}| = 0$, and with the other half probability $|v_{i}|$ follows the folded Gaussian distribution $|\mathcal{N}(0, \frac{2\|h\|^{2}}{m})|$.

$\frac{m\|v\|^{2}}{2\|h\|^{2}}$ is in distribution identical to $\chi_{\omega}^{2}$ (chi-square distribution of order $\omega$), where $\omega$ follows the binomial distribution $\mathcal{B}(m, 1/2)$.

We assume each vector $W_{i}$ is generated by first generating a Gaussian vector $g \sim \mathcal{N}(0, \frac{2I}{m})$ and then setting $W_{i} = \pm g$, where the sign is chosen with half-half probability. Now, $|\langle W_{i}, h\rangle| = |\langle g, h\rangle|$ only depends on $g$, and is in distribution identical to $|\mathcal{N}(0, \frac{2\|h\|^{2}}{m})|$. Next, after the sign is determined, the indicator $\mathbf{1}_{\langle W_{i}, h+q\rangle \geq 0}$ is $1$ with half probability and $0$ with the other half. Therefore, $|v_{i}|$ satisfies the aforementioned distribution. As for $\|v\|^{2}$, letting $\omega \in \{0, 1, \ldots, m\}$ be the variable indicating how many indicators are $1$, then $\omega \sim \mathcal{B}(m, 1/2)$ and $\frac{m\|v\|^{2}}{2\|h\|^{2}} \sim \chi_{\omega}^{2}$. ∎

We define a standard notion in sketching techniques. We remark that sketching techniques have been widely applied to many applications such as linear regression, low-rank approximation (Clarkson & Woodruff, 2013; Nelson & Nguyên, 2013; Lu et al., 2013; Boutsidis et al., 2016; Cohen, 2016; Razenshteyn et al., 2016; Song et al., 2017, 2019), linear programming (Song & Yu, 2021; Dong et al., 2021; Jiang et al., 2021; Gu & Song, 2022), semi-definite programming (Gu & Song, 2022; Song et al., 2023b), empirical risk minimization (Lee et al., 2019; Qin et al., 2023b), and training over-parameterized neural networks (Brand et al., 2021; Song et al., 2021; Alman et al., 2022; Hu et al., 2022; Zhang, 2022).

where $U$ is the orthonormal basis of $A$.

For the reason why the above conditions are equivalent, we refer the reader to the survey (Woodruff, 2014).

In (Wang & Woodruff, 2018), they show that for every $1 \leq p < 2$, any oblivious subspace embedding with dimension $r$ has distortion $\kappa = \Omega(\frac{1}{(\frac{1}{d})^{1/p} \cdot \log^{2/p} r + (\frac{r}{n})^{1/p - 1/2}})$. They also give sparse oblivious subspace embeddings for every $1 \leq p < 2$ which are optimal in dimension and distortion, up to $\mathrm{poly}(\log d)$ factors. Importantly, for $p = 1$, they achieve $r = O(d \log d)$, $\kappa = O(d \log d)$ and $s = O(\log d)$ non-zero entries per column.

G.6 Random Matrices

We consider several standard sketching matrices:

Subsampled randomized Hadamard/Fourier transform (SRHT) matrices (Lu et al., 2013).

AMS sketch matrices (Alon et al., 1996), random $\{-1, +1\}$ per entry.

Count-Sketch matrices (Charikar et al., 2002), each column only has one non-zero entry, which is $-1$ or $+1$ with half probability each.

Sparse embedding matrices (Nelson & Nguyên, 2013), each column only has $s$ non-zero entries, and each entry is $-\frac{1}{\sqrt{s}}$ or $+\frac{1}{\sqrt{s}}$ with half probability each.
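Two of the matrices in this list can be generated in a few lines. This is a minimal sketch under stated assumptions: the function names are hypothetical, matrices are plain nested lists, and the non-zero positions in each column of the sparse embedding are drawn uniformly without replacement, which is one common construction and may differ from the cited one.

```python
import math
import random


def count_sketch_matrix(rows, cols, seed=0):
    """Each column has exactly one non-zero entry, -1 or +1 with equal probability."""
    rng = random.Random(seed)
    matrix = [[0.0] * cols for _ in range(rows)]
    for j in range(cols):
        matrix[rng.randrange(rows)][j] = rng.choice((-1.0, 1.0))
    return matrix


def sparse_embedding_matrix(rows, cols, s, seed=0):
    """Each column has s non-zero entries, each -1/sqrt(s) or +1/sqrt(s)."""
    rng = random.Random(seed)
    scale = 1.0 / math.sqrt(s)
    matrix = [[0.0] * cols for _ in range(rows)]
    for j in range(cols):
        for i in rng.sample(range(rows), s):  # assumed: uniform positions
            matrix[i][j] = rng.choice((-scale, scale))
    return matrix


def column(matrix, j):
    return [row[j] for row in matrix]


cs = count_sketch_matrix(rows=8, cols=20)
se = sparse_embedding_matrix(rows=8, cols=20, s=4)
```

With this scaling every column has unit Euclidean norm, and a Count-Sketch matrix is the special case $s = 1$.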

Appendix H Distances, Angles, and Inner Product

Most of the properties in this section are very standard in literature, e.g., see (Gu et al., 2023).

For any matrix $X$, and for an orthogonal matrix $Y$ ($Y^{\top} Y = I_{k}$), we define

$\tan\theta(Y, X) := \|Y_{\bot}^{\top} X (Y^{\top} X)^{-1}\|$

For orthogonal matrices $Y$ and $X$ ($Y^{\top} Y = I_{k}$ and $X^{\top} X = I_{k}$), we define

$\cos\theta(Y, X) := \sigma_{\min}(Y^{\top} X)$.

It is obvious that $\cos\theta(Y, X) = 1/\|(Y^{\top} X)^{-1}\|$ and $\cos\theta(Y, X) \leq 1$.

It is obvious that $\sin\theta(Y, X) = \|Y_{\bot} Y_{\bot}^{\top} X\| = \|Y_{\bot}^{\top} X\|$ and $\sin\theta(Y, X) \leq 1$.

$\operatorname{dist}(Y, X) := \min_{Q \in O_{k}} \|YQ - X\|$

where $O_{k}$ is the set of $k \times k$ orthogonal matrices.
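For intuition, the definitions above can be evaluated in the special case $k = 1$, where $Y$ and $X$ are single unit vectors. This is a minimal sketch with hypothetical helper names; it assumes $Y^{\top} X \neq 0$ so that the inverse in the tangent exists, and uses $O_{1} = \{+1, -1\}$.

```python
import math


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def norm(a):
    return math.sqrt(dot(a, a))


def angles_k1(y, x):
    """Return (cos, sin, tan) of the angle between unit vectors y and x."""
    inner = dot(y, x)
    cos_theta = abs(inner)                       # sigma_min of the 1x1 matrix y^T x
    residual = [xi - inner * yi for xi, yi in zip(x, y)]
    sin_theta = norm(residual)                   # norm of the part of x orthogonal to y
    return cos_theta, sin_theta, sin_theta / cos_theta


def dist_k1(y, x):
    """min over Q in {+1, -1} of ||y Q - x||."""
    same = norm([yi - xi for yi, xi in zip(y, x)])
    flipped = norm([-yi - xi for yi, xi in zip(y, x)])
    return min(same, flipped)


cos_t, sin_t, tan_t = angles_k1([1.0, 0.0], [0.6, 0.8])
```

In this case the three quantities reduce to the familiar trigonometric functions of the angle between two lines, and $\sin^{2}\theta + \cos^{2}\theta = 1$.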

Let us first compute the Gram of $Y^{\top} X$, which is

where the first step follows from $Y_{\bot} Y_{\bot}^{\top} + Y Y^{\top} = I$, the second step follows from simple algebra, and the last step follows from the fact that $X$ is an orthogonal matrix, so $X^{\top} = X^{-1}$.

This means that $(Y^{\top} X)_{\bot} = Y_{\bot}^{\top} X$. ∎

We argue that $x$ corresponds to the smallest singular value of $A_{\bot}$ via contradiction. Suppose there exists some unit vector $y$ with $\|A_{\bot} y\|_{2} < \|A_{\bot} x\|_{2}$. By definition, we know that $\|A_{\bot} y\|_{2}^{2} + \|A y\|_{2}^{2} = 1$, which means that $\|A y\|_{2} > \|A x\|_{2} = \|A\|$, contradicting the definition of the spectral norm. Similarly, if $z$ is the unit vector that realizes the spectral norm of $A_{\bot}$, then it is also the singular vector corresponding to the smallest singular value of $A$, or equivalently, the spectral norm of $A^{-1}$. Our above argument essentially implies that $A_{\bot}$ and $A^{-1}$ have the same set of singular vectors. The proof is then straightforward: suppose $A_{\bot} z = \lambda z$ and $A^{-1} z = \mu z$, then

where the first step follows from our assumption, the second step follows from the fact that $\mu$ is a real number, so multiplying a matrix by it is commutative and associative, and the third step follows from our assumption.

Thus, we have $\|A_{\bot} A^{-1}\| = \|A_{\bot}\| \|A^{-1}\|$, and we have proved the assertion. ∎

Due to Lemma H.2, we have $(Y^{\top} X)_{\bot} = Y_{\bot}^{\top} X$. Thus, $\tan\theta(Y, X) = \|(Y^{\top} X)_{\bot} (Y^{\top} X)^{-1}\|$. The proof then follows straightforwardly from Lemma H.3. ∎

Let $\epsilon \in (0, 0.1)$. Let $x$ denote a unit vector, i.e., $\|x\|_{2} = 1$.

If $\|y\|_{2} \leq \epsilon \cdot \|x\|_{2}$, then

where the first step follows from the triangle inequality.

where the first step follows from Eq. (8) and Eq. (9) and the rest of them follow from simple algebra.

where the first step follows from the definition of $z$, the second step follows from reorganization, the third step follows from the definition of inner product, the fourth step follows from $\|x\|_{2} = 1$, the fifth step follows from Eq. (H.1), the sixth step follows from $1 + \langle x, y\rangle \geq 1 - |\langle x, y\rangle| \geq 1 - \|x\|_{2} \cdot \|y\|_{2} \geq 1 - \epsilon$, the seventh step follows from Eq. (H.1) and the last step follows from simple algebra.

Appendix I Function Approximations

We first show the function approximation for two operators in Section I.1, which means that there are two functions. Then we show the function approximations for four operators in Section I.2.

Condition 1a. $f_{1}$ is a linear function

Condition 1b. $\|f_{1}(x)\|_{2} \leq \epsilon_{1}\|x\|_{2}$ ($f_{1}$ is shrinking)

Condition 1c. $\|f_{1}(x) - f_{1}(y)\|_{2} \leq L_{1}\|x - y\|_{2}$ ($f_{1}$ is Lipschitz)

Condition 2a. $f_{2}$ is a linear function

Condition 2b. $\|f_{2}(x)\|_{2} \leq \epsilon_{2}\|x\|_{2}$ ($f_{2}$ is shrinking)

Condition 2c. $\|f_{2}(x) - f_{2}(y)\|_{2} \leq L_{2}\|x - y\|_{2}$ ($f_{2}$ is Lipschitz)

Part 1. $\|g_{1}(x) - g_{2}(x)\|_{2} \leq 2\epsilon_{1}\epsilon_{2}\|x\|_{2}$ (if $f_{1}$ and $f_{2}$ are linear functions)

Part 2. $\|g_{1}(x) - g_{2}(x)\|_{2} \leq (\epsilon_{2} \cdot L_{1} + \epsilon_{1} \cdot L_{2})\|x\|_{2}$ (if $f_{1}$ and $f_{2}$ are Lipschitz functions)

Part 3. $\|g_{1}(x) - g_{3}(x)\|_{2} \leq \epsilon_{1}\epsilon_{2}\|x\|_{2}$ (if $f_{1}$ is a linear function)

Part 4. $\|g_{1}(x) - g_{3}(x)\|_{2} \leq \epsilon_{2} \cdot L_{1}\|x\|_{2}$ (if $f_{1}$ is a Lipschitz function)

Part 5. $\|g_{2}(x) - g_{3}(x)\|_{2} \leq \epsilon_{1}\epsilon_{2}\|x\|_{2}$ (if $f_{2}$ is a linear function)

Part 6. $\|g_{2}(x) - g_{3}(x)\|_{2} \leq \epsilon_{1} \cdot L_{2}\|x\|_{2}$ (if $f_{2}$ is a Lipschitz function)
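The linear-case bounds (Parts 1, 3 and 5) can be checked numerically on a small example. The definitions of $g_{1}$, $g_{2}$, $g_{3}$ for two operators are not shown in this section; the sketch assumes, by analogy with the four-operator definitions in Section I.2, that $g_{1} = (I + f_{1})(I + f_{2})$, $g_{2} = (I + f_{2})(I + f_{1})$ and $g_{3} = I + f_{1} + f_{2}$. The two linear maps are arbitrary illustrative choices.

```python
import math

EPS1 = 0.1
EPS2 = 0.05


def add(a, b):
    return [p + q for p, q in zip(a, b)]


def norm(a):
    return math.sqrt(sum(p * p for p in a))


def f1(x):
    """Rotation by 90 degrees scaled by EPS1, so ||f1(x)|| = EPS1 * ||x||."""
    return [-EPS1 * x[1], EPS1 * x[0]]


def f2(x):
    """Diagonal map with entries at most EPS2, so ||f2(x)|| <= EPS2 * ||x||."""
    return [EPS2 * x[0], 0.5 * EPS2 * x[1]]


def g1(x):  # assumed: (I + f1)(I + f2)
    u = add(x, f2(x))
    return add(u, f1(u))


def g2(x):  # assumed: (I + f2)(I + f1)
    u = add(x, f1(x))
    return add(u, f2(u))


def g3(x):  # assumed: I + f1 + f2
    return add(x, add(f1(x), f2(x)))


def gap(ga, gb, x):
    return norm([p - q for p, q in zip(ga(x), gb(x))])


x = [3.0, 4.0]
bound = EPS1 * EPS2 * norm(x)
```

For linear maps the gap between the sequential and parallel forms is exactly the cross term $f_{1}(f_{2}(x))$ or $f_{2}(f_{1}(x))$, which is why it is second order in the shrinking factors.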

where the first step follows from triangular inequality, the second step follows from Part 3 and Part 5 and the last step follows from simple algebra.

where the first step follows from triangular inequality, the second step follows from Part 4 and Part 6 and the last step follows from simple algebra.

where the first step follows from the definition of $g_{1}$ and $g_{3}$, the second step follows from the fact that $f_{1}$ is a linear function, the third step follows from simple algebra, the fourth step follows from Condition 1b and the last step follows from Condition 2b.

where the first step follows from the definition of $g_{1}$ and $g_{3}$, the second step follows from Condition 1c, the third step follows from simple algebra and the last step follows from Condition 2b. ∎

where the first step follows from the definition of $g_{2}$ and $g_{3}$, the second step follows from the fact that $f_{2}$ is a linear function, the third step follows from simple algebra, the fourth step follows from Condition 2b and the last step follows from Condition 1b.

where the first step follows from the definition of $g_{2}$ and $g_{3}$, the second step follows from Condition 2c, the third step follows from simple algebra and the last step follows from Condition 1b.

I.2 Function Approximations for Four Operators

For each $i \in [4]$, we assume the following conditions

i(b) $\|f_{i}(x)\|_{2} \leq \epsilon_{i}\|x\|_{2}$ ($f_{i}$ is shrinking)

i(c) $\|f_{i}(x) - f_{i}(y)\|_{2} \leq L_{i}\|x - y\|_{2}$ ($f_{i}$ is Lipschitz)

$g_{1}(x) := (I + f_{1}) \cdot (I + f_{2}) \cdot (I + f_{3}) \cdot (I + f_{4})(x)$

$g_{2}(x) := (I + f_{1}) \cdot (I + f_{3}) \cdot (I + f_{2}) \cdot (I + f_{4})(x)$

$g_{3}(x) := (I + f_{1} + f_{2} + f_{3} + f_{4})(x)$

where the first step follows from triangular inequality and the last step follows from Part 3 and Part 5.

where the first step follows from triangular inequality and the last step follows from Part 4 and Part 6.

where the first step follows from the definition of $g_{1}$ and $g_{3}$, the second step follows from simple algebra, the third step follows from reorganization, the fourth step follows from the fact that all $f_{i}, \forall i \in [4]$ are linear functions, the fifth step follows from the triangle inequality, the sixth step follows from i(b) and the last step follows from reorganization.

where the first step follows from the definition of $g_{1}$ and $g_{3}$, the second step follows from simple algebra, the third step follows from simple algebra, the fourth step follows from reorganization, the fifth step follows from the fact that all $f_{i}, \forall i \in [4]$ are Lipschitz functions, the sixth step follows from simple algebra, the seventh step follows from i(b), the eighth step follows from the triangle inequality, the ninth step follows from i(b), the tenth step follows from i(b) and the last step follows from reorganization.

where the first step follows from the definition of $g_{2}$ and $g_{3}$, the second step follows from simple algebra, the third step follows from the fact that all $f_{i}, \forall i \in [4]$ are linear functions, the fourth step follows from the triangle inequality and i(b), and the last step follows from reorganization.

where the first step follows from the definition of $g_{2}$ and $g_{3}$, the second step follows from simple algebra, the third step follows from reorganization, the fourth step follows from the triangle inequality, the fifth step follows from the fact that all $f_{i}, \forall i \in [4]$ are Lipschitz functions and i(b), the sixth step follows from the triangle inequality, and the last step follows from reorganization.

Appendix J Nearest Neighbor Search Data Structure

We use the reduction-based approximate $\mathsf{MaxIP}$ method with an $\mathsf{LSH}$ data structure to achieve sublinear iteration cost. Note that we choose this method due to its clear theoretical guarantee on the retrieval results. It is well known that $\mathsf{LSH}$ data structures are used for the approximate nearest neighbor problem. The following definition of approximate nearest neighbor search is very standard in the literature (Arya & Mount, 1993; Indyk & Motwani, 1998a; Datar et al., 2004; Andoni et al., 2014, 2015; Andoni & Razenshteyn, 2015; Indyk & Wagner, 2018; Andoni et al., 2017, 2018; Dong et al., 2019; Chen et al., 2020b; Li & Li, 2022; Li et al., 2019).

We start by defining the Approximate Nearest Neighbor ($\mathsf{ANN}$) problem (Arya & Mount, 1993; Indyk & Motwani, 1998a; Datar et al., 2004; Andoni et al., 2014, 2015; Andoni & Razenshteyn, 2015; Indyk & Wagner, 2018; Andoni et al., 2017, 2018; Dong et al., 2019; Chen et al., 2020b) as:

The $\mathsf{ANN}$ problem can be solved via locality sensitive hashing ($\mathsf{LSH}$) (Indyk & Motwani, 1998a; Datar et al., 2004; Indyk & Wagner, 2018). In this paper, we use the standard definitions of $\mathsf{LSH}$ (see Indyk and Motwani (Indyk & Motwani, 1998a)).

if $\|x - y\|_{2} \leq r$, then $\Pr_{h \sim \mathcal{H}}[h(x) = h(y)] \geq p_{1}$,

if $\|x - y\|_{2} \geq \overline{c} \cdot r$, then $\Pr_{h \sim \mathcal{H}}[h(x) = h(y)] \leq p_{2}$.
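The two conditions above say that close points collide under a random hash more often than far points. The sketch below illustrates this with random-hyperplane hashing (SimHash), a well-known family chosen here only because it is easy to simulate; it is not necessarily the family behind the guarantees cited in this appendix, and all names are hypothetical.

```python
import math
import random


def random_hyperplane_hash(dim, rng):
    """Sample h(v) = [<g, v> >= 0] for a random Gaussian direction g."""
    g = [rng.gauss(0, 1) for _ in range(dim)]
    return lambda v: sum(a * b for a, b in zip(g, v)) >= 0


def collision_rate(x, y, trials=2000, seed=0):
    """Empirical estimate of Pr_h[h(x) = h(y)] over freshly sampled hashes."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        h = random_hyperplane_hash(len(x), rng)
        hits += h(x) == h(y)
    return hits / trials


x = [1.0, 0.0]
near = [math.cos(0.1), math.sin(0.1)]   # small angle, small distance to x
far = [math.cos(2.0), math.sin(2.0)]    # large angle, large distance to x
p_near = collision_rate(x, near)
p_far = collision_rate(x, far)
```

For this family the collision probability is $1 - \theta/\pi$ for two vectors at angle $\theta$, so on the unit sphere the near pair plays the role of $p_{1}$ and the far pair the role of $p_{2}$.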

Next, we show that $\mathsf{LSH}$ solves the $\mathsf{ANN}$ problem with sublinear query time complexity.

Let $\overline{c} > 1$ and $r \in (0, 2)$ denote two parameters. One can solve $(\overline{c}, r)$-$\mathsf{ANN}$ on a unit sphere in query time $O(d \cdot n^{\rho})$ using preprocessing time $O(dn^{1+o(1)})$ and space $O(n^{1+o(1)} + dn)$, where $\rho = \frac{2}{\overline{c}^{2}} - \frac{1}{\overline{c}^{4}} + o(1)$.

Here $o(1)$ is equivalent to $O(1/\sqrt{\log n})$. Note that we could reduce $d$ to $n^{o(1)}$ with the Johnson–Lindenstrauss Lemma (Johnson & Lindenstrauss, 1984). Besides, we could achieve a better $\rho$ using the $\mathsf{LSH}$ in (Andoni & Razenshteyn, 2015) if we allow more preprocessing time.

In this work, we focus on a well-known problem in computational complexity: approximate $\mathsf{MaxIP}$. We follow the standard notation in (Chen, 2018) and define the approximate $\mathsf{MaxIP}$ problem as follows:

In many applications, it is more convenient to do inner product search in a transformed/projected space than in the original space. Thus, we propose the following definitions (Definition J.5 and Definition J.6).

J.2 Connections

Let $\widetilde{x}$ denote a vector such that $\langle \widetilde{x}, x\rangle \geq 1 - \frac{1}{2}\epsilon^{2}$, where both $\widetilde{x}$ and $x$ are unit vectors. We have

Let $\widetilde{x}$ denote a vector such that $\langle \widetilde{x}, x\rangle \geq 1 - \frac{1}{2}\epsilon^{2}$, where both $\widetilde{x}$ and $x$ are unit vectors. Let $0.01 c \cdot \tau > \epsilon$.

Suppose there is a $z \in Y$, where $\|z\|_{2} = 1$, such that

Note that $\max_{y \in Y}\langle x, y\rangle \geq \tau$. Then, we can find a $z \in Y$ such that

where the first step follows from simple algebra, the second step follows from the fact that $\langle x, y\rangle \geq -|\langle x, y\rangle|$, the third step follows from the property of inner product, the fourth step follows from Fact J.7, the fifth step follows from $\langle x, z\rangle \geq c \cdot \max_{y \in Y}\langle x, y\rangle$ and the final step follows from the fact that $0.01 c \cdot \tau > \epsilon$.

J.3 Efficient Transformations

Therefore, we could transform the direction search problem into a $\mathsf{MaxIP}$ problem.

Here $D_{x}$ and $D_{y}$ are constants chosen so that both $x/D_{x}$ and $y/D_{y}$ have norms less than $1$. Under these transformations, both $\phi_{1}(x)$ and $\psi_{1}(y)$ have norm $1$ and $\arg\max_{y\in Y}\langle\phi_{1}(x),\psi_{1}(y)\rangle=\arg\max_{y\in Y}\langle x,y\rangle$.
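The two properties stated above (unit norm after transformation, and an unchanged $\arg\max$) can be illustrated with a small sketch. It assumes one common asymmetric construction that pads each scaled vector with two extra coordinates; the exact form of $\phi_{1}$ and $\psi_{1}$ intended in this appendix may differ, and the function names `phi`, `psi`, and `argmax_ip` are hypothetical.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def phi(x, d_x):
    """Query transform (assumed form): [x / D_x, 0, sqrt(1 - ||x / D_x||^2)]."""
    s = [xi / d_x for xi in x]
    return s + [0.0, math.sqrt(1.0 - dot(s, s))]


def psi(y, d_y):
    """Data transform (assumed form): [y / D_y, sqrt(1 - ||y / D_y||^2), 0]."""
    s = [yi / d_y for yi in y]
    return s + [math.sqrt(1.0 - dot(s, s)), 0.0]


def argmax_ip(x, ys):
    """Index of the vector in ys with the largest inner product with x."""
    return max(range(len(ys)), key=lambda i: dot(x, ys[i]))


if __name__ == "__main__":
    x = [3.0, -1.0, 2.0]
    Y = [[1.0, 2.0, 0.0], [4.0, 0.0, 1.0], [-2.0, 1.0, 3.0]]
    # D_x and D_y must exceed the largest norm on each side.
    px = phi(x, 4.0)
    PY = [psi(y, 5.0) for y in Y]
    print(argmax_ip(x, Y), argmax_ip(px, PY))
```

Because the padded coordinates of the query and the data never overlap, $\langle\phi(x),\psi(y)\rangle=\langle x,y\rangle/(D_{x}D_{y})$, so the ordering of inner products is preserved while both sides land on the unit sphere.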

J.4 Data Structures

In this section, we present a formal statement that solves the $(c,\tau)$-$\mathsf{MaxIP}$ problem on the unit sphere using $\mathsf{LSH}$ for $(\overline{c},r)$-$\mathsf{ANN}$.

Let $c\in(0,1)$ and $\tau\in(0,1)$. Given a set $Y\subset{\cal S}^{d-1}$ of $n$ vectors on the unit sphere, there exists a data structure with $O(dn^{1+o(1)})$ preprocessing time and $O(n^{1+o(1)}+dn)$ space so that, for any query $x\in{\cal S}^{d-1}$, we take $O(d\cdot n^{\rho})$ query time to retrieve the $(c,\tau)$-$\mathsf{MaxIP}$ of $x$ in $Y$ with probability at least $0.9$, where $\rho:=\frac{2(1-\tau)^{2}}{(1-c\tau)^{2}}-\frac{(1-\tau)^{4}}{(1-c\tau)^{4}}+o(1)$. (Footnote: it is straightforward to boost the success probability from a constant to $1-\delta$ by repeating the data structure $\log(1/\delta)$ times.)

We know that $\|x-y\|_{2}^{2}=2-2\langle x,y\rangle$ for all $x,y\in{\cal S}^{d-1}$. In this way, if we have an $\mathsf{LSH}$ data structure for $(\overline{c},r)$-$\mathsf{ANN}$, it can be used to solve $(c,\tau)$-$\mathsf{MaxIP}$ with $\tau=1-0.5r^{2}$ and $c=\frac{1-0.5\overline{c}^{2}r^{2}}{1-0.5r^{2}}$. Next, we write $\overline{c}^{2}$ as $\overline{c}^{2}=\frac{1-c(1-0.5r^{2})}{0.5r^{2}}=\frac{1-c\tau}{1-\tau}$.

Next, we show that if the $\mathsf{LSH}$ is initialized following Theorem J.3, it takes query time $O(d\cdot n^{\rho})$, space $O(n^{1+o(1)}+dn)$ and preprocessing time $O(dn^{1+o(1)})$ to solve $(c,\tau)$-$\mathsf{MaxIP}$ through solving $(\overline{c},r)$-$\mathsf{ANN}$, where

In practice, $c$ increases as we set the parameter $\tau$ close to $\mathsf{MaxIP}(x,Y)$. There is also another $\mathsf{LSH}$ data structure (Andoni & Razenshteyn, 2015) with longer preprocessing time and larger space that could solve $(c,\tau)$-$\mathsf{MaxIP}$ with similar query time complexity. We refer readers to Section 8.2 in (Shrivastava et al., 2021) for more details. (Recently, a line of work has used fast $\mathsf{MaxIP}$ data structures to speed up iterative-type optimization algorithms (Shrivastava et al., 2021; Song & Ye, 2023; Qin et al., 2023a; Song et al., 2023a).) Moreover, Corollary J.9 could be applied to the projected $\mathsf{MaxIP}$ problem.
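The parameter mapping used in the reduction from $(c,\tau)$-$\mathsf{MaxIP}$ to $(\overline{c},r)$-$\mathsf{ANN}$ can be sketched as follows. This is an illustrative helper, not part of the original data structure: it inverts $\tau=1-0.5r^{2}$ and $c=\frac{1-0.5\overline{c}^{2}r^{2}}{1-0.5r^{2}}$, and then evaluates the leading term of the exponent directly from the $\overline{c}$ form of the $\mathsf{ANN}$ theorem, dropping the $o(1)$ term. The function names `maxip_to_ann` and `ann_exponent` are hypothetical.

```python
import math


def maxip_to_ann(c, tau):
    """Map (c, tau)-MaxIP parameters to (c_bar, r)-ANN parameters on the unit sphere."""
    if not (0.0 < c < 1.0 and 0.0 < tau < 1.0):
        raise ValueError("expected c and tau in (0, 1)")
    # tau = 1 - 0.5 r^2  =>  r = sqrt(2 (1 - tau))
    r = math.sqrt(2.0 * (1.0 - tau))
    # c (1 - 0.5 r^2) = 1 - 0.5 c_bar^2 r^2  =>  c_bar^2 = (1 - c tau) / (1 - tau)
    c_bar = math.sqrt((1.0 - c * tau) / (1.0 - tau))
    return c_bar, r


def ann_exponent(c_bar):
    """Leading term of rho = 2 / c_bar^2 - 1 / c_bar^4 (o(1) term omitted)."""
    return 2.0 / c_bar ** 2 - 1.0 / c_bar ** 4


if __name__ == "__main__":
    c_bar, r = maxip_to_ann(c=0.5, tau=0.8)
    print("c_bar:", c_bar, "r:", r, "exponent:", ann_exponent(c_bar))
```

Since $c<1$ implies $1-c\tau>1-\tau$, the mapping always yields $\overline{c}>1$ and $r\in(0,2)$, so the resulting $\mathsf{ANN}$ instance satisfies the assumptions of the theorem and the exponent stays below $1$.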

The preprocessing phase can be decomposed into two parts.

It takes $O({\cal T}_{\psi}n)$ time to transform every $y\in Y$ into $\psi(y)$.

It takes $O(dn^{1+o(1)})$ time and $O(dn^{1+o(1)}+dn)$ space to index every $\psi(y)$ into $\mathsf{LSH}$ using Theorem J.9.

The query phase can be decomposed into two parts.

It takes $O(d\cdot n^{\rho})$ time to perform the query for $\phi(x)$ in $\mathsf{LSH}$ using Theorem J.9.

Appendix K Self-attention layer as a clustering algorithm

The self-attention layer in the Transformer resembles mean-shift clustering. Suppose $\{({\bm{x}}_{j},{\bm{v}}_{j})\}$ is a collection of key and value pairs and ${\bm{q}}$ is the query. Note that ${\bm{q}}=W_{q}{\bm{x}}$, ${\bm{k}}=W_{k}{\bm{x}}$ and ${\bm{v}}=W_{v}{\bm{x}}$ are computed by three projection matrices $W_{k}$, $W_{q}$ and $W_{v}$ from a common ${\bm{x}}$. Then from self-attention we have:

where $\mathrm{sim}({\bm{q}},{\bm{k}}_{j}):=\exp({\bm{q}}^{\intercal}{\bm{k}}_{j})=\exp({\bm{x}}^{\intercal}W_{q}^{\intercal}W_{k}{\bm{x}}_{j})$ and $p_{j}=\mathrm{sim}({\bm{q}},{\bm{k}}_{j})/\sum_{j}\mathrm{sim}({\bm{q}},{\bm{k}}_{j})$.

On the other hand, mean-shift clustering looks like the following:

where $K({\bm{x}}_{j},{\bm{x}})$ is a kernel function that measures the similarity between ${\bm{x}}_{j}$ and ${\bm{x}}$. According to the mean-shift algorithm, in the next iteration, we simply replace ${\bm{x}}$ with $m({\bm{x}})$.

So in some sense, self-attention performs a kind of clustering of the input embeddings ${\bm{q}}$ and ${\bm{k}}$, plus a transformation of the embedding to another place. The term “projection” is due to the fact that there is a projection matrix $W_{v}$ on ${\bm{x}}$ for the next level.
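The analogy can be made concrete with a small sketch. It assumes identity projections ($W_{q}=W_{k}=W_{v}=I$) and the exponential inner-product kernel $K({\bm{x}}_{j},{\bm{x}})=\exp({\bm{x}}^{\intercal}{\bm{x}}_{j})$, together with the usual mean-shift update $m({\bm{x}})=\sum_{j}K({\bm{x}}_{j},{\bm{x}}){\bm{x}}_{j}/\sum_{j}K({\bm{x}}_{j},{\bm{x}})$. Both choices are made here for illustration only; under them, one attention step and one mean-shift step return the same vector.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(x, xs):
    """Single-query attention with identity projections: q = x, k_j = v_j = x_j."""
    p = softmax([dot(x, xj) for xj in xs])
    return [sum(pj * xj[i] for pj, xj in zip(p, xs)) for i in range(len(x))]


def mean_shift(x, xs, kernel):
    """One mean-shift update m(x) as a kernel-weighted average of the points xs."""
    w = [kernel(xj, x) for xj in xs]
    total = sum(w)
    return [sum(wj * xj[i] for wj, xj in zip(w, xs)) / total for i in range(len(x))]


def exp_kernel(xj, x):
    """Exponential inner-product kernel (an assumption of this sketch)."""
    return math.exp(dot(xj, x))


if __name__ == "__main__":
    points = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [-0.6, 0.8]]
    query = points[1]
    print(attention(query, points))
    print(mean_shift(query, points, exp_kernel))
```

With non-identity $W_{q}$ and $W_{k}$ the kernel becomes $\exp({\bm{x}}^{\intercal}W_{q}^{\intercal}W_{k}{\bm{x}}_{j})$, and $W_{v}$ adds the extra transformation of the averaged point described in the text.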

Residual connection and LayerNorm. Compared to mean-shift, the Transformer layer has a residual connection. Therefore, for single-headed attention, what you actually get is ${\bm{v}}+{\bm{x}}$, followed by a LayerNorm. For the residual connection, the mean-shift analog already shows that the output $m({\bm{x}})$ contains an ${\bm{x}}$ term. The reason we need the residual connection is that the self-attention part might only model the “change” of ${\bm{x}}$ in the mean-shift picture, rather than the full update of ${\bm{x}}$.

Appendix L The role of self-attention

Consider a vocabulary of size $m$ and a $d$-dimensional embedding space. In practice, many papers in NLP have reported clustering behavior of word embeddings: such clustering naturally occurs after training.

An explanation for this phenomenon is that, by grouping word embeddings together, the model might generalize better, since similarity between words can now transfer (e.g., if A is linked to B and B is linked to C, then A might be linked to C as well).

Let’s treat it as a fact and focus on how this is achieved and how self-attention plays a role here.

First, let us look at the following pairwise distance constraints between word embeddings (e.g., some words should be close to each other, some should be far apart):

where $D(i,j)$ is large for $i$ and $j$ that should be far apart and $D(i,j)$ is small for $i$ and $j$ that are close to each other. In visualization, this is called Multidimensional Scaling (MDS) (Cox & Cox, 2008).

Note that in neural network training, the constraint (Eqn. 16) is not directly enforced, but the clustering happens naturally. Since we are discussing capacity, how we achieve Eqn. 16 does not matter for now.

In general, we cannot find a fixed low-dimensional embedding ($d\ll m$) that satisfies these constraints, since we only have $md$ parameters ($m$ vectors, each with $d$ entries) but $m^{2}$ constraints. So two vectors that are supposed to be close may not be close enough (but hopefully they remain close to each other).

L.2 The role of self-attention

For this, the self-attention mechanism comes to the rescue, trading model size for additional computation. It fulfills what a (static) embedding cannot achieve: to further group the embedding vectors together in a multi-layer structure.

Note that one sentence never covers all $m$ words in the vocabulary. Once the words in the sentence are picked, they are grouped together via self-attention layers to collectively represent a concept that can be useful for the task.

L.3 How the clustering happens through self-attention?

Now one fundamental question arises: how does the static clustering of embeddings happen during end-to-end training? In practice, no one explicitly enforces the MDS constraint (Eqn. 16).

Let’s start with a simple example. We have two unit embeddings ${\bm{x}}$ and ${\bm{y}}$, with the normalization condition $\|{\bm{x}}\|_{2}=1$ and $\|{\bm{y}}\|_{2}=1$, and a simple self-attention layer (without projection) that outputs ${\bm{z}}$:

Note that here we attend to ${\bm{x}}$, so $0<p<1/2$ always. The last two equalities are due to the normalization condition.

Now we consider a loss function $L=-\frac{1}{2}\|{\bm{z}}\|_{2}^{2}$. The intuition is that “for some reason, we found that ${\bm{z}}$ is a good representation for our task, and want to make sure its length is as long as possible”.

In this context, what would be the gradient rule for ${\bm{x}}$ and ${\bm{y}}$? Will they cluster together?

Let $t:=1-{\bm{x}}^{\intercal}{\bm{y}}$ and define the following function with respect to $t$:

Therefore, we can compute the gradient for ${\bm{x}}$ and the gradient for ${\bm{y}}$:

Note that since ${\bm{x}}$ and ${\bm{y}}$ are kept normalized, the term $(1-p)^{2}{\bm{x}}$ in $\partial L/\partial{\bm{x}}$ is gone (and similarly $p^{2}{\bm{y}}$ for ${\bm{g}}_{\bm{y}}$). So how ${\bm{x}}$ and ${\bm{y}}$ move depends on the sign of $1-f(t)$.

With some computation, we can see that $0<f(t)<1$ when $t<1.5424$. In summary, since $t=1-{\bm{x}}^{\intercal}{\bm{y}}$, if ${\bm{x}}^{\intercal}{\bm{y}}>-0.5424$, then the (negative) gradient of ${\bm{x}}$ pushes it towards ${\bm{y}}$ and the (negative) gradient of ${\bm{y}}$ pushes it towards ${\bm{x}}$, and the clustering of static embeddings happens during training. Note that since both ${\bm{x}}$ and ${\bm{y}}$ are normalized, $-1\leq{\bm{x}}^{\intercal}{\bm{y}}\leq 1$, so this is quite a loose condition and can be easily satisfied.
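The clustering behavior in the two-embedding example can be observed numerically. The sketch below assumes the attention output has the form ${\bm{z}}=(1-p){\bm{x}}+p{\bm{y}}$ with $p=1/(1+e^{1-{\bm{x}}^{\intercal}{\bm{y}}})$, which is consistent with $0<p<1/2$ but is an assumption of this example. It uses finite-difference gradients and a projected step on the unit sphere rather than the closed-form gradients of the text; the function names and the step size are hypothetical.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def normalize(a):
    n = math.sqrt(dot(a, a))
    return [ai / n for ai in a]


def loss(x, y):
    """L = -0.5 ||z||^2 with z = (1 - p) x + p y (assumed form of the layer)."""
    p = 1.0 / (1.0 + math.exp(1.0 - dot(x, y)))
    z = [(1.0 - p) * xi + p * yi for xi, yi in zip(x, y)]
    return -0.5 * dot(z, z)


def num_grad(f, v, eps=1e-6):
    """Central finite-difference gradient of f at v."""
    g = []
    for i in range(len(v)):
        vp, vm = list(v), list(v)
        vp[i] += eps
        vm[i] -= eps
        g.append((f(vp) - f(vm)) / (2.0 * eps))
    return g


def sphere_step(v, g, lr):
    """Drop the radial part of g, take a descent step, and renormalize."""
    radial = dot(g, v)
    tangent = [gi - radial * vi for gi, vi in zip(g, v)]
    return normalize([vi - lr * ti for vi, ti in zip(v, tangent)])


def simulate(x, y, steps=200, lr=0.5):
    """Return the history of x^T y under simultaneous projected gradient descent."""
    x, y = normalize(x), normalize(y)
    history = [dot(x, y)]
    for _ in range(steps):
        gx = num_grad(lambda v: loss(v, y), x)
        gy = num_grad(lambda v: loss(x, v), y)
        x, y = sphere_step(x, gx, lr), sphere_step(y, gy, lr)
        history.append(dot(x, y))
    return history


if __name__ == "__main__":
    hist = simulate([1.0, 0.0], [0.0, 1.0])
    print("initial x^T y:", hist[0], "final x^T y:", hist[-1])
```

Starting from orthogonal embeddings (${\bm{x}}^{\intercal}{\bm{y}}=0$, well inside the regime discussed above), the inner product grows over the iterations, meaning the two embeddings move towards each other.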

L.4 Multiple embeddings

One might wonder what happens with multiple unit embeddings ${\bm{x}},{\bm{y}}_{1},{\bm{y}}_{2},\ldots,{\bm{y}}_{K}$. In this case, we can similarly define the self-attention probability $p_{i}$ (note that here we consider the case where every embedding attends to ${\bm{x}}$):

Define $p_{S}:=\sum_{i=1}^{K}p_{i}=1-\frac{1}{1+\sum_{j}e^{{\bm{x}}^{\intercal}{\bm{y}}_{j}}}<1$ and we have:

Now we can still compute the partial derivative:

Similar to the two-embedding case, we want to check $-{\bm{g}}_{\bm{x}}$ to see how the embedding ${\bm{x}}$ changes over time.

If things are already quite clustered, then $\|\bar{\bm{y}}\|_{2}\approx 1$ (usually $\|\bar{\bm{y}}\|_{2}<1$ since the unit ball is a convex set), $Q{\bm{z}}\approx 0$ (since $Q$ spans the tangent space of ${\bm{z}}$ at the sphere and ${\bm{z}}$ is perpendicular to it), and we have:

It is clear that ${\bm{x}}^{\intercal}\bar{\bm{y}}<1$. When $p_{S}>1/2$, which is highly likely for large $K$, $-{\bm{g}}_{\bm{x}}$ has a positive component along $\bar{\bm{y}}$ and ${\bm{x}}$ will move towards $\bar{\bm{y}}$.

which gives an expression of $-{\bm{g}}_{\bm{y}}$:

With the same argument, it moves towards $\bar{\bm{y}}$ (so all ${\bm{y}}_{i}$ will cluster together) and towards ${\bm{x}}$.

When there is a $W_{k}$ and $W_{q}$ before the embedding, following the same logic, only the column subspace of $W_{k}$ (or $W_{q}$) will be clustered together. On the other hand, the value part will be different in order to enable encoding of more complicated concepts based on the co-occurrence of multiple tokens.

Appendix M Link self-attention with generative models.

Here $\phi({\bm{x}}_{i};{\bm{x}}_{j}):={\bm{x}}_{i}+\beta_{ij}({\bm{x}}_{j}-{\bm{x}}_{i})$ is the self-attention operation. More properties of this operator $\phi$ need to be explored. Then we want to maximize the following objective:

or more formally, using a softmax to avoid the trivial solution ${\bm{x}}_{i}\equiv{\bm{x}}$, we have:

We can compute its gradient update. Here we assume the index $k$ never appears in index $i$ and $j$ (encoding and decoding matrices are decoupled); then, by the gradient rule, we have:

where $P^{\perp}_{{\bm{x}}_{k}}$ is the projection matrix that projects a vector onto the orthogonal complement of ${\bm{x}}_{k}$. The projection is due to the constraint $\|{\bm{x}}_{k}\|_{2}=1$. If the training converges ($\dot{\bm{x}}_{k}=0$), then we know that

for some $\gamma>0$ (note that $\gamma<0$ would be an unstable stationary point).

Depending on the structure of the generative model specified by $P(k|i,j)$, we might end up learning a different embedding matrix $X$.

And we could possibly show that $\sum_{j}\beta_{ij}({\bm{x}}_{j}-{\bm{x}}_{i})\approx 0$, since $\beta_{ij}=1/(1+e^{1-{\bm{x}}_{i}^{\intercal}{\bm{x}}_{j}})$ applies equal weights to embeddings around ${\bm{x}}_{i}$ and they cancel out. Therefore, ${\bm{x}}_{k}$ is aligned with ${\bm{x}}_{i}$.

If we have $W_{q}$, $W_{k}$ and $W_{v}$, then the formulation does not change much. The only difference is that now

and ${\bm{y}}_{ij}^{\intercal}{\bm{x}}_{k}$ now becomes ${\bm{y}}_{ij}^{\intercal}W_{v}{\bm{x}}_{k}$.