English (unofficial) translations of posts at kexue.fm
Source

VQ the Key, and Transformer Complexity Becomes Linear

Translated by Gemini Flash 3.0 Preview. Translations can be inaccurate, please refer to the original post for important stuff.

Efficient Transformer is a general term for all efforts dedicated to reducing the quadratic complexity of the Transformer. Initially, it referred specifically to improvements of the Attention mechanism, but later more general ideas, such as Fourier transforms and linear RNNs, were also included in the category. It must be said that in the quest to reduce the Transformer's quadratic complexity, researchers have, like the Eight Immortals crossing the sea, each shown their own prowess, and all sorts of ingenious ideas have flourished; I have learned a great deal of theory from them. However, although Efficient Transformers are brilliant in theory, the field has in practice remained lukewarm, with no model performing exceptionally well. In today's LLM-dominated era, they have gradually faded from public view, and from my own range of interests.

However, a recent paper titled "Transformer-VQ: Linear-Time Transformers via Vector Quantization" has truly impressed me. The authors brilliantly observed that by simply applying VQ (Vector Quantization) to the Keys of standard Attention, the complexity automatically becomes linear! This linearization approach preserves the form of standard Attention and serves as a perfect transition from standard Attention to linear Attention, while retaining the capabilities of standard Attention to the greatest extent possible.

The Challenge of Efficiency

Speaking of which, this site was among the early followers of Efficient Transformer work, dating back to a 2019 blog post interpreting Sparse Transformers: "Born for Savings: From Standard Attention to Sparse Attention". Since then, I have written several other blog posts on Efficient Transformers:

However, as mentioned at the beginning of this article, although there has been much work on Efficient Transformers and they were once highly anticipated, the field has not produced many "breakout" works. The reasons might be:

  1. Many Efficient Transformers achieve speedups at the cost of performance;

  2. The complexity reduction of many Efficient Transformers is only theoretical, with no significant improvement in actual use;

  3. Some Efficient Transformers are difficult to use for training Causal LMs, making them less useful in the current LLM era;

  4. The emergence of Flash Attention shows that even standard Transformers still have significant room for speedup.

Just VQ it

So, why does Transformer-VQ have the potential to "break out"?

Simply put, Transformer-VQ "clusters" the Key vector sequence of Attention and approximates the original vectors with their respective cluster centers, which then makes the Attention complexity linear. In other words, Transformer-VQ only changes the form of the Key, while the rest (theoretically) remains completely unchanged. Therefore, this is a linearization scheme with very minimal changes to Attention, and it clearly shows where the precision is lost after linearization (i.e., the gap between the original vector and the cluster center).

With that preamble, let’s formally introduce Transformer-VQ. First, assume Q, K \in \mathbb{R}^{n \times d_k} and V \in \mathbb{R}^{n \times d_v}. Standard Attention is: \begin{equation} \text{softmax}\left(QK^{\top}\right)V \end{equation} For simplicity, the scale factor is omitted here. Transformer-VQ changes this to: \begin{equation} \text{softmax}\left(Q\hat{K}^{\top}\right)V, \quad \hat{K} = \mathcal{VQ}(K, C) \label{eq:vq-att} \end{equation} where C \in \mathbb{R}^{c \times d_k} is a trainable parameter and also the Codebook for VQ. By the way, "VQ" here refers to the VQ in VQ-VAE. Readers unfamiliar with this can refer to "A Brief Introduction to VQ-VAE: Quantized Autoencoders" and "Embarrassingly Simple FSQ: ’Rounding’ Surpasses VQ-VAE". In short, after \mathcal{VQ}, the most direct result is that each vector in K becomes the one in C that is closest to it. This means each vector in \hat{K} is one of the vectors in C; in mathematical terms, K \in \mathbb{R}^{n \times d_k} becomes \hat{K} \in C^n.
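To make \mathcal{VQ}(K, C) concrete, here is a minimal NumPy sketch of the forward quantization step (names are mine, not from the paper's code; the straight-through gradient is deliberately omitted here and discussed later in the post):

```python
import numpy as np

def vq(K, C):
    """Forward pass of VQ: replace each row of K by its nearest codebook vector.

    K: (n, d_k) key sequence; C: (c, d_k) codebook.
    Returns (K_hat, delta), where delta is the (n, c) one-hot assignment matrix.
    """
    # Squared Euclidean distance from every key to every code vector: (n, c)
    d2 = ((K[:, None, :] - C[None, :, :]) ** 2).sum(-1)
    idx = d2.argmin(axis=1)            # index of the nearest code per key
    delta = np.eye(C.shape[0])[idx]    # one-hot rows, so K_hat = delta @ C
    return C[idx], delta
```

Each row of the returned `K_hat` is literally a row of `C`, which is exactly the property the linearization exploits.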

Encoder

Of course, if we implement Transformer-VQ directly according to Equation [eq:vq-att], the complexity is still quadratic. However, since every vector in \hat{K} is one of the vectors in C, we can first calculate \exp(QC^{\top}) and then "pick out" the results corresponding to \exp(Q\hat{K}^{\top}). Since the size of C is fixed, the complexity of the key operation QC^{\top} is linear. This is the principle behind the linearization of Transformer-VQ (which we might call the "picking out" trick).

As a foundation, let’s first consider the bidirectional attention case for an Encoder. Since: \begin{equation} \text{softmax}\left(QK^{\top}\right)V = \frac{\exp\left(QK^{\top}\right)V}{\exp\left(QK^{\top}\right)1_{n\times 1}} \label{eq:softmax-qkv} \end{equation} where 1_{n\times 1} refers to an n \times 1 matrix of all ones. The denominator can be seen as a special form of the numerator, so we only need to consider the numerator \exp(QK^{\top})V. Since each vector in \hat{K} is one of C, we can construct a one-hot matrix \Delta \in \{0,1\}^{n \times c}, where \Delta_i \in \{0,1\}^c is a one-hot vector. If the dimension where 1 is located is j, then \hat{K}_i = C_j, so \hat{K} = \Delta C.

Thus, for Transformer-VQ, we have: \begin{equation} \exp\left(Q\hat{K}^{\top}\right)V = \exp\left(QC^{\top}\Delta^{\top}\right)V = \exp\left(QC^{\top}\right)\Delta^{\top}V = \exp\left(QC^{\top}\right)(\Delta^{\top}V) \end{equation} Obviously, the most critical part is the second equal sign! For the one-hot matrix \Delta, right-multiplying by its transpose can be separated from the \exp function. This is the mathematical expression of the "picking out" trick. After separation, due to the associative law of matrix multiplication, \Delta^{\top} can first be multiplied by V to obtain a c \times d_v matrix. Since \exp(QC^{\top}) is an n \times c matrix, multiplying it by \Delta^{\top}V yields an n \times d_v matrix. The total theoretical complexity is \mathcal{O}(ncd_k + ncd_v + ncd_v) = \mathcal{O}(n).

Finally, according to Equation [eq:softmax-qkv], by substituting the result of \exp(Q\hat{K}^{\top})V, the complete Attention result can be calculated (possibly with some details to avoid overflow). The entire process can be completed within linear complexity.
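The encoder-side derivation can be checked numerically in a few lines. A hedged NumPy sketch (random data and my own variable names) comparing the quadratic form of the numerator with the linear "picked-out" form:

```python
import numpy as np

rng = np.random.default_rng(0)
n, c, d_k, d_v = 16, 4, 8, 8
Q = rng.normal(size=(n, d_k))
C = rng.normal(size=(c, d_k))
V = rng.normal(size=(n, d_v))
delta = np.eye(c)[rng.integers(0, c, size=n)]  # stand-in for the VQ assignments
K_hat = delta @ C                              # each key is a codebook vector

# Quadratic form: exp(Q K_hat^T) V, O(n^2) in the sequence length.
quadratic = np.exp(Q @ K_hat.T) @ V
# Linear form: exp(Q C^T) (Delta^T V), O(nc) -- the "picking out" trick.
linear = np.exp(Q @ C.T) @ (delta.T @ V)

assert np.allclose(quadratic, linear)
```

The denominator is the same computation with V replaced by an all-ones column, as in Equation [eq:softmax-qkv].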

Decoder

Now let’s consider the unidirectional attention for a Decoder, which is key to training generative models and the basis of current LLMs. With the Encoder foundation, the Decoder is not difficult to understand. Suppose Q_i, \hat{K}_j \in \mathbb{R}^{1 \times d_k} and V_j \in \mathbb{R}^{1 \times d_v} are row vectors of the sequences Q, \hat{K}, V. For the numerator of the Decoder: \begin{equation} \begin{aligned} O_i =& \sum_{j\leq i}\exp\left(Q_i\hat{K}_j^{\top}\right)V_j = \sum_{j\leq i}\exp\left(Q_i C^{\top}\Delta_j^{\top}\right)V_j \\ =& \sum_{j\leq i}\exp\left(Q_i C^{\top}\right)\Delta_j^{\top}V_j = \exp\left(Q_i C^{\top}\right)\sum_{j\leq i}\Delta_j^{\top}V_j \end{aligned} \end{equation} If c \times d_v is not large, the final expression can be calculated directly using the \text{cumsum} operator. However, in general, especially with Multi-Head Attention, to save memory, it is usually converted into an RNN for recursive calculation, similar to the "Autoregressive Generation" section in "Exploring Linear Attention". Let U_i = \sum_{j\leq i}\Delta_j^{\top}V_j \in \mathbb{R}^{c \times d_v}, then: \begin{equation} O_i = \exp\left(Q_i C^{\top}\right)U_i, \quad U_i = U_{i-1} + \Delta_i^{\top}V_i \end{equation} In the inference stage, this step-by-step recursive calculation is fine, but in the training stage, it might be slow. We can change it to block-by-block calculation to accelerate: without loss of generality, let n=lm, where l is the block size and m is the number of blocks. The block slice [il:(i+1)l] is abbreviated as [i]. 
Then: \begin{equation} \begin{aligned} O_{[i]} =& \exp\left(Q_{[i]}\hat{K}_{[i]}^{\top} + M\right)V_{[i]} + \sum_{j < i}\exp\left(Q_{[i]}\hat{K}_{[j]}^{\top}\right)V_{[j]} \\ =& \exp\left(Q_{[i]}\hat{K}_{[i]}^{\top} + M\right)V_{[i]} + \sum_{j < i}\exp\left(Q_{[i]}C^{\top}\Delta_{[j]}^{\top}\right)V_{[j]} \\ =& \exp\left(Q_{[i]}\hat{K}_{[i]}^{\top} + M\right)V_{[i]} + \exp\left(Q_{[i]}C^{\top}\right)\sum_{j < i}\Delta_{[j]}^{\top}V_{[j]} \\ \end{aligned} \end{equation} where M \in \{-\infty, 0\}^{l \times l} is the lower triangular Attention Mask, i.e., M_{i,j}=0 when i \geq j, otherwise M_{i,j}=-\infty. Denoting U_i = \sum_{j \leq i}\Delta_{[j]}^{\top}V_{[j]} (so that \sum_{j < i}\Delta_{[j]}^{\top}V_{[j]} = U_{i-1}), we have: \begin{equation} O_{[i]} = \exp\left(Q_{[i]}\hat{K}_{[i]}^{\top} + M\right)V_{[i]} + \exp\left(Q_{[i]}C^{\top}\right)U_{i-1}, \quad U_i = U_{i-1} + \Delta_{[i]}^{\top}V_{[i]} \end{equation} This reduces the number of recursive steps to m, allowing for linear efficiency while fully utilizing hardware parallelism. The denominator can be calculated in the same way, and the complete Attention result is obtained by division.
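As a sanity check on the block recursion, here is a NumPy sketch (my own code, no claim to match the paper's implementation; numerically naive, with no overflow safeguards) that computes the causal numerator block by block and compares it against the quadratic form:

```python
import numpy as np

def causal_numerator_blocked(Q, delta, C, V, l):
    """Numerator sum_{j<=i} exp(Q_i K_hat_j^T) V_j via the block recursion.

    Q: (n, d_k); delta: (n, c) one-hot VQ assignments; C: (c, d_k) codebook;
    V: (n, d_v); l: block size (must divide n).
    """
    n, c = delta.shape
    K_hat = delta @ C
    # Causal mask within one block: 0 on/below the diagonal, -inf above.
    M = np.where(np.tril(np.ones((l, l))) > 0, 0.0, -np.inf)
    U = np.zeros((c, V.shape[1]))  # running cache of Delta_[j]^T V_[j] over past blocks
    out = np.empty_like(V)
    for i in range(n // l):
        s = slice(i * l, (i + 1) * l)
        # Diagonal block: ordinary masked quadratic attention ...
        diag = np.exp(Q[s] @ K_hat[s].T + M) @ V[s]
        # ... plus all earlier blocks via the linear "picked out" form.
        out[s] = diag + np.exp(Q[s] @ C.T) @ U
        U = U + delta[s].T @ V[s]  # U_i = U_{i-1} + Delta_[i]^T V_[i]
    return out
```

Only m = n/l recursion steps are needed, and each step is a fixed-size matrix product, which is what makes the scheme hardware-friendly.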

Local Enhancement

Is that all? Not quite. If it were just this, Transformer-VQ might not differ much from previous matrix decomposition-based Kernelized Attention like Performer. When the sequence length n is much larger than the codebook size c, the pigeonhole principle tells us that some code vectors will inevitably reappear, and we can even reasonably guess that all code vectors should be evenly distributed throughout the sequence. Consequently, the Attention of nearby tokens would be the same as that of some distant tokens, meaning the model cannot distinguish between near and far. This is essentially the low-rank problem inherent in all Kernelized Attention.

Experience tells us that for language models, nearby tokens are often more important than distant ones, so a good language model architecture should be able to distinguish distance. To this end, Transformer-VQ adds a Sliding-Window-shaped Attention Bias (denoted B) to Q\hat{K}^{\top} to weight nearby tokens, as shown below:

Window Attention Bias Diagram

As seen from the diagram, if the window size is set to the block size l, i.e., B_{i,j}=0 when i < j or i - j \geq l, then in block-wise calculation, the matrix B affects at most the two nearest blocks. Further blocks can still be linearized using the "picking out" trick. For the following derivation, let B_{[i,j]} = B_{[il:(i+1)l, jl:(j+1)l]}, then: \begin{equation} \begin{aligned} O_{[i]} =& \exp\left(Q_{[i]}\hat{K}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \sum_{j < i-1}\exp\left(Q_{[i]}\hat{K}_{[j]}^{\top}\right)V_{[j]} \\ =& \exp\left(Q_{[i]}\hat{K}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \sum_{j < i-1}\exp\left(Q_{[i]}C^{\top}\Delta_{[j]}^{\top}\right)V_{[j]} \\ =& \exp\left(Q_{[i]}\hat{K}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \exp\left(Q_{[i]}C^{\top}\right)\sum_{j < i-1}\Delta_{[j]}^{\top}V_{[j]} \\ \end{aligned} \end{equation} So clearly, we have (assuming V_{[-1]}, U_{[-1]}, U_{[-2]} are all zero matrices): \begin{equation} \begin{aligned} O_{[i]} =& \exp\left(Q_{[i]}\hat{K}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \exp\left(Q_{[i]}C^{\top}\right)U_{i-2}\\[5pt] U_i =& U_{i-1} + \Delta_{[i]}^{\top}V_{[i]} \end{aligned} \label{eq:tvq} \end{equation} I believe the introduction of B is the key to Transformer-VQ pulling ahead of other Kernelized Attention methods. To reduce the number of parameters and support variable-length generation, we constrain the non-zero part of B to be a "Toeplitz matrix," meaning B_{i,j} is a function of i-j. In this case, B is equivalent to an additive relative position encoding. Alternatively, one could consider using ReRoPE, which is a windowed version of Rotary Position Embedding and shares the same relative position encoding shape as B.
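Equation [eq:tvq] can again be checked with a small NumPy sketch (my own code and names; I keep the causal mask M explicit in the diagonal block, and assume the bias B is supported on 0 \leq i-j < l, consistent with the derivation above):

```python
import numpy as np

def numerator_with_bias(Q, delta, C, V, B, l):
    """Block recursion in the style of Eq. (tvq): B touches only the two nearest blocks.

    B: (n, n) additive bias, nonzero only where 0 <= i - j < l.
    delta: (n, c) one-hot VQ assignments; C: (c, d_k); V: (n, d_v).
    """
    n, c = delta.shape
    K_hat = delta @ C
    M = np.where(np.tril(np.ones((l, l))) > 0, 0.0, -np.inf)  # in-block causal mask
    U1 = np.zeros((c, V.shape[1]))   # U_{i-1}
    U2 = np.zeros((c, V.shape[1]))   # U_{i-2}
    out = np.empty_like(V)
    for i in range(n // l):
        s = slice(i * l, (i + 1) * l)
        # Diagonal block: quadratic, masked, biased.
        acc = np.exp(Q[s] @ K_hat[s].T + M + B[s, s]) @ V[s]
        if i > 0:
            p = slice((i - 1) * l, i * l)
            # Previous block: quadratic and biased, fully visible (no mask).
            acc = acc + np.exp(Q[s] @ K_hat[p].T + B[s, p]) @ V[p]
        # Blocks older than the previous one: linear "picked out" form.
        out[s] = acc + np.exp(Q[s] @ C.T) @ U2
        # Shift the caches: U_{i-2} <- U_{i-1}, U_{i-1} <- U_i.
        U2, U1 = U1, U1 + delta[s].T @ V[s]
    return out
```

Because B vanishes beyond the window, only the two quadratic block terms ever see it; everything older collapses into the cached U_{i-2}, exactly as in the derivation.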

Gradient Backpropagation

Wait, we seem to have forgotten something. Readers familiar with VQ-VAE know that "each vector in \hat{K} is one of the vectors in C" is only the behavior in the forward pass; the backward pass uses the original K. This means that even if \hat{K}_j at different positions equals the same C_k, their gradients are not equal. This is called STE (Straight-Through Estimator). Due to the existence of STE, the "picking out" trick can theoretically only be used in the inference stage; it cannot be linearized in the training stage.

Is there no other way? Indeed, if we insist on obtaining exact gradient results, there is no linear-efficiency scheme. However, considering that the VQ gradient itself is an approximation, obtaining exact gradients for Attention might not be strictly necessary. Thus, the authors devised a compromise: continue to perform recursive calculations according to Equation [eq:tvq], using STE only for the first two terms (allowing the Key sequence to receive gradients), while stopping the gradient for U_{i-1} (using the stop_gradient operator). This maintains the linearity of the model while preserving the most important gradients (the two nearest blocks), serving as a reasonable approximation. From this perspective, Transformer-VQ is similar to Transformer-XL, which also stops gradients for historical windows during recursion.

After solving the gradient backpropagation problem, the complete training objective is obtained by adding the auxiliary loss brought by VQ (used to update the codebook) to the autoregressive cross-entropy loss. For the codebook update, Transformer-VQ uses an exponential moving average (EMA) scheme. These details will be clear to readers familiar with VQ-VAE after a quick look at the original paper.
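For readers who want the flavor of the EMA codebook update without opening the paper, here is the standard VQ-VAE-style recipe in NumPy (a generic sketch; the paper's exact variant and hyperparameters may differ):

```python
import numpy as np

def ema_codebook_update(counts, sums, K, delta, gamma=0.99, eps=1e-5):
    """One VQ-VAE-style EMA codebook update (generic recipe, not the paper's exact code).

    counts: (c,) EMA of how often each code is used;
    sums:   (c, d_k) EMA of the keys assigned to each code;
    K: (n, d_k) keys in the batch; delta: (n, c) one-hot assignments.
    Returns (codebook, counts, sums).
    """
    counts = gamma * counts + (1 - gamma) * delta.sum(axis=0)
    sums = gamma * sums + (1 - gamma) * (delta.T @ K)
    C = sums / (counts[:, None] + eps)   # eps keeps rarely-used codes finite
    return C, counts, sums
```

Each code vector drifts toward the running mean of the keys assigned to it, which is why K-Means is a natural mental model for VQ here.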

Experimental Results

In this section, we look at the experimental results from the original paper. The authors have open-sourced the code:

Github: https://github.com/transformer-vq/transformer_vq

It is worth noting that the base architecture the authors used for VQ is not the conventional MHA (Multi-Head Attention), but the GAU (Gated Attention Unit) + Softmax, which I have always highly recommended. Transformer-VQ should more accurately be named "GAU-VQ." Readers unfamiliar with GAU can refer to "FLASH: Perhaps the Most Interesting Efficient Transformer Design Recently" and "It Seems Attention and Softmax Go Better Together". Simply put, GAU itself is more efficient than MHA, and with the VQ trick, it becomes even more powerful.

In terms of experiments, the authors tested language models (ENWIK8, PG-19) and image generation (IMAGENET64). In all experiments, the codebook size was c=512. The maximum parameter count was 1.3B, which, while not as large as mainstream large models, is quite significant for research. The experimental results are generally excellent:

Experimental results on IMAGENET64

Finally, it is surprising to note that Transformer-VQ has only one author, whose affiliation is "Independent Researcher."

Divergent Thinking

I find that starting from Transformer-VQ, one can connect to many research topics, which is one of the reasons I appreciate it so much.

First, I once again applaud the author’s amazing insight. The discovery that "simply VQing the Key makes Transformer complexity linear" is truly beautiful. It achieves a natural transition from standard Attention to linear Attention and can be more effective than many Kernelized Attentions by adding Attention Bias. Furthermore, the "clustering" approach via VQ is more sophisticated than Linformer or Nyströmformer, as it prevents future information leakage and can naturally be used for Causal language models.

We know that VQ is essentially an operation that converts a sequence into discrete IDs, which is very similar to the function of a Tokenizer. From this perspective, Transformer-VQ, like models such as MegaByte, builds the Tokenizer into the model. Compared to MegaByte, the VQ operation is more similar and intuitive to our traditional understanding of a Tokenizer. Therefore, Transformer-VQ is actually very suitable for training "No Tokenizer" models that directly take Bytes as input. In fact, the ENWIK8 experiment mentioned above used Byte input, and Transformer-VQ significantly outperformed MegaByte.

Compared to the recently released RetNet, Transformer-VQ has no explicit long-range decay, so its Long Context capability might be better. At the same time, because the Keys are VQ-ed and belong to a finite set, there will be no "unseen" Keys, so its length extrapolation capability is likely to be better. Although the base architecture of Transformer-VQ, GAU, is single-headed, its recurrent memory state is U_i = \sum_{j\leq i}\Delta_j^{\top}V_j \in \mathbb{R}^{c \times d_v}. In the default settings, this is larger than that of the Multi-Head RetNet (RetNet's memory state size is nd_k^2, and in default settings d_v = 2nd_k). Thus, the memory capacity is theoretically sufficient.

Since the previous article was about "Embarrassingly Simple FSQ: 'Rounding' Surpasses VQ-VAE", some readers might wonder whether the simpler FSQ could replace VQ here. I think it would be difficult, for the reasons given in that article: first, c=512 is within the range where VQ outperforms FSQ; second, since the Key of every Attention layer must be VQ-ed, the sub-networks playing the role of "encoder" and "decoder" around each quantization are, on average, not very strong, which is precisely the situation where VQ's approximation is more accurate (FSQ is better suited to scenarios where both encoder and decoder are strong enough); third, Transformer-VQ needs the center vectors produced by quantization, whereas FSQ directly yields discrete IDs, from which approximate center vectors are harder to recover.

Furthermore, using VQ instead of FSQ gives Transformer-VQ the hope of being fine-tuned from existing pre-trained models like LLAMA2, rather than just being trained from scratch. Because VQ has clear geometric meaning and many similarities with K-Means, we could start from an existing pre-trained model, calculate Keys for some samples, perform K-Means on the Keys to get center vectors for codebook initialization, and then fine-tune the original model with VQ added. However, Transformer-VQ does not adapt well to RoPE, so as mentioned earlier, it would be better to replace RoPE with ReRoPE before VQ, in which case the Bias might not be needed.
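The codebook-initialization idea sketched above is essentially Lloyd's K-Means over pre-computed Keys; a minimal NumPy version (a hypothetical helper of my own, not from any released fine-tuning recipe):

```python
import numpy as np

def kmeans_codebook(keys, c, iters=20, seed=0):
    """Lloyd's K-Means over pre-computed Keys to initialize a VQ codebook.

    keys: (N, d_k) Keys collected from a pretrained model on sample data.
    Returns a (c, d_k) codebook of cluster centers.
    """
    rng = np.random.default_rng(seed)
    C = keys[rng.choice(len(keys), size=c, replace=False)]  # init from the data
    for _ in range(iters):
        # Assign every key to its nearest current center.
        d2 = ((keys[:, None, :] - C[None, :, :]) ** 2).sum(-1)
        assign = d2.argmin(axis=1)
        # Move each center to the mean of its assigned keys.
        for k in range(c):
            members = keys[assign == k]
            if len(members) > 0:       # keep the old center for empty clusters
                C[k] = members.mean(axis=0)
    return C
```

The resulting centers can seed the codebook C before VQ-augmented fine-tuning, so the quantized Keys start out close to the pretrained model's actual Key distribution.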

In short, in my eyes, Transformer-VQ is one of the most unique, outstanding, and high-potential works among many Efficient Transformer efforts.

Summary

This article introduced an Efficient Transformer scheme called Transformer-VQ. It is based on the observation that "simply VQing the Key makes Transformer complexity linear." Personally, I find this a very unique and brilliant linearization idea, and the experimental results are also excellent. It can be understood as a more sophisticated linear Attention/RNN model, or as an Attention model with a "trainable Tokenizer."

Reprinting: Please include the original address of this article: https://kexue.fm/archives/9844

For more details on reprinting, please refer to: "Scientific Space FAQ"