A few days ago, in "VQ the Key, and Transformer Complexity Becomes Linear", we introduced "Transformer-VQ." This is a scheme that achieves linear complexity for Attention by applying a VQ (Vector Quantization) transformation to the Key sequence. Admittedly, Transformer-VQ provides a very elegant transition from standard Attention to linear Attention, embodying the beauty of "great truths are always simple." However, readers familiar with VQ might sense that as the codebook size or model parameters increase, VQ could likely become a bottleneck for performance. This is because the gradients estimated via the STE (Straight-Through Estimator) are highly likely to be sub-optimal (experimental results from FSQ provide some supporting evidence). Furthermore, the gradient truncation performed by Transformer-VQ to ensure linear training efficiency might also become a performance bottleneck in the future.
To this end, I spent some time thinking about linearization ideas that could replace VQ. From the \exp\left(QC^{\top}\right) form in Transformer-VQ, I was reminded of the Performer. Subsequently, by "following the clues," I discovered that the Performer can actually be viewed as a "Soft" version of Transformer-VQ. Furthermore, I attempted to re-derive Transformer-VQ by analogy with the Performer’s derivation method, providing some reference results for future optimizations.
A Brief Recap
First, let us take a moment to review Transformer-VQ. Let Q, K \in \mathbb{R}^{n \times d_k} and V \in \mathbb{R}^{n \times d_v}. The key to Transformer-VQ is applying the following VQ approximation to K: \begin{equation} K \approx \hat{K} \triangleq \Delta C \end{equation} Here, \Delta \in \{0,1\}^{n \times c} and C \in \mathbb{R}^{c \times d_k} are matrices, where C contains trainable parameters (the codebook) and \Delta is defined as: \begin{equation} \Delta_{i,j} = \left\{\begin{aligned}& 1, \quad j=\mathop{\text{argmin}}_{k=1,2,\cdots,c} \Vert K_i - C_k\Vert \\ & 0, \quad\text{otherwise}\end{aligned}\right. \end{equation} Simply put, VQ approximates K_i using the C_j that is closest to it. Under this approximation, we have (taking the Encoder as an example for simplicity): \begin{equation} \exp\left(Q\hat{K}{}^{\top}\right)V = \exp\left(QC^{\top}\Delta^{\top}\right)V = \exp\left(QC^{\top}\right)\Delta^{\top}V = \exp\left(QC^{\top}\right)(\Delta^{\top}V) \label{eq:transformer-vq} \end{equation} Readers familiar with linear Attention will easily recognize that the operation in the last expression has linear complexity. This is one of the main subjects of this article, Transformer-VQ (specifically the numerator; the denominator follows the same logic).
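The key step here, \exp\left(QC^{\top}\Delta^{\top}\right)V = \exp\left(QC^{\top}\right)(\Delta^{\top}V), works because elementwise \exp commutes with the column selection performed by the one-hot matrix \Delta. This is easy to verify numerically; below is a minimal numpy sketch (all sizes and variable names are illustrative choices, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
n, c, d_k, d_v = 16, 4, 8, 5
Q = rng.normal(size=(n, d_k))
K = rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_v))
C = rng.normal(size=(c, d_k))  # codebook

# Delta: each row is one-hot, selecting the codeword nearest to K_i
dists = np.linalg.norm(K[:, None, :] - C[None, :, :], axis=-1)  # (n, c)
Delta = np.eye(c)[dists.argmin(axis=1)]                          # (n, c)
K_hat = Delta @ C                                                # VQ'd keys

# Quadratic form exp(Q K_hat^T) V: O(n^2) attention numerator
lhs = np.exp(Q @ K_hat.T) @ V
# Linear form exp(Q C^T) (Delta^T V): O(nc); exp commutes with selection
rhs = np.exp(Q @ C.T) @ (Delta.T @ V)
assert np.allclose(lhs, rhs)
```

The two computations agree exactly, but the second never materializes an n-by-n matrix, which is the whole point of Eq. [eq:transformer-vq].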
No complex derivation is required: linear Attention simply emerges. It feels as though, merely by approximating the Key, we reduced the complexity of Attention to linear "inadvertently," which is full of aesthetic appeal. Hence the assessment already repeated several times: Transformer-VQ provides a very beautiful transition from standard Attention to linear Attention.
A Sense of Familiarity
The \exp\left(QC^{\top}\right) term in Transformer-VQ reminded me of a previous article "The Road to Transformer Upgrade: 3. From Performer to Linear Attention". In that article, I simplified the results of the Performer and asserted that the optimal activation function for Q, K in linear Attention is \exp. Since \exp also appears in Transformer-VQ, there might be some correlation between them.
To explore this connection, let us bring in the Performer, which is based on a beautiful approximation: \begin{equation} e^{\boldsymbol{q}\cdot \boldsymbol{k}}=\mathbb{E}_{\boldsymbol{\omega}\sim \mathcal{N}(\boldsymbol{\omega};0,\boldsymbol{1}_d)}\left[e^{\boldsymbol{\omega}\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \,e^{\boldsymbol{\omega}\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\right]\approx\underbrace{\frac{1}{\sqrt{m}}\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{q}}} \cdot \underbrace{\frac{1}{\sqrt{m}}\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}} \label{eq:performer} \end{equation} Since the final step involves normalizing the attention over all \boldsymbol{k}, removing the \frac{1}{\sqrt{m}} and -\Vert \boldsymbol{q}\Vert^2/2 terms from the above equation does not affect the final result. At the same time, if we assume that the lengths of \boldsymbol{\omega}_1, \boldsymbol{\omega}_2, \cdots, \boldsymbol{\omega}_m are all equal (refer to the JL Lemma), then subtracting \Vert\boldsymbol{\omega}_i\Vert^2/2 from the exponent of \boldsymbol{k} will also not affect the result. 
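The expectation identity behind Eq. [eq:performer] can be sanity-checked by Monte Carlo: averaging the product of the two random features over many sampled \boldsymbol{\omega} recovers e^{\boldsymbol{q}\cdot\boldsymbol{k}}. A small numpy sketch (vector scale and sample count are arbitrary choices made to keep the estimator's variance low):

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 4, 1_000_000  # vector dim, number of random features

q = rng.normal(size=d) * 0.3
k = rng.normal(size=d) * 0.3
omega = rng.normal(size=(m, d))  # omega_i ~ N(0, I_d)

# Random feature map: phi(x)_i = e^{omega_i . x - ||x||^2 / 2} / sqrt(m)
def phi(x):
    return np.exp(omega @ x - x @ x / 2) / np.sqrt(m)

exact = np.exp(q @ k)
approx = phi(q) @ phi(k)  # Monte Carlo estimate of e^{q.k}
assert abs(approx - exact) / exact < 0.02
```

Note the estimator's variance grows quickly with \Vert\boldsymbol{q}+\boldsymbol{k}\Vert, which is one practical weakness of the Performer's random features.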
Thus, the Performer is equivalent to using the following format for \tilde{\boldsymbol{q}}, \tilde{\boldsymbol{k}}: \begin{equation} \underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix}}_{\tilde{\boldsymbol{q}}} \cdot \underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2-\Vert \boldsymbol{\omega}_1\Vert^2 / 2} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2-\Vert \boldsymbol{\omega}_2\Vert^2 / 2}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2-\Vert \boldsymbol{\omega}_m\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}} = \underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix}}_{\tilde{\boldsymbol{q}}} \cdot \underbrace{\begin{pmatrix}e^{-\Vert \boldsymbol{k}-\boldsymbol{\omega}_1\Vert^2 / 2} \\ e^{-\Vert \boldsymbol{k} - \boldsymbol{\omega}_2\Vert^2 / 2}\\ \vdots\\ e^{-\Vert \boldsymbol{k} - \boldsymbol{\omega}_m\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}} \propto \underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\ e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\ \vdots\\ e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix}}_{\tilde{\boldsymbol{q}}} \cdot \underbrace{\text{softmax}\begin{pmatrix}-\Vert \boldsymbol{k}-\boldsymbol{\omega}_1\Vert^2 / 2 \\ -\Vert \boldsymbol{k} - \boldsymbol{\omega}_2\Vert^2 / 2\\ \vdots\\ -\Vert \boldsymbol{k} - \boldsymbol{\omega}_m\Vert^2 / 2 \end{pmatrix}}_{\tilde{\boldsymbol{k}}} \end{equation} Comparing the last expression with Eq.
[eq:transformer-vq], one finds many similarities: aren’t \boldsymbol{\omega}_1, \boldsymbol{\omega}_2, \cdots, \boldsymbol{\omega}_m equivalent to the codebook C? Isn’t \tilde{\boldsymbol{q}} equivalent to \exp\left(QC^{\top}\right)? As for the final \tilde{\boldsymbol{k}}, it uses -\Vert \boldsymbol{k} - \boldsymbol{\omega}_i\Vert^2 / 2 as logits for a softmax; doesn’t this highlight the \boldsymbol{\omega}_i closest to \boldsymbol{k}? Since the limit of softmax is one-hot, doesn’t this correspond exactly to the \Delta matrix in Transformer-VQ? Therefore, while not identical, they are remarkably similar.
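The "softmax tends to one-hot" claim is easy to check numerically: sharpening the logits -\Vert \boldsymbol{k} - \boldsymbol{\omega}_i\Vert^2 / 2 with a large inverse temperature drives the soft assignment toward VQ's one-hot \Delta row. A minimal numpy sketch (sizes, seed, and the temperature value are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
c, d = 4, 8
k = rng.normal(size=d)
omega = rng.normal(size=(c, d))  # plays the role of the codebook C

# Logits -||k - omega_i||^2 / 2, as described in the text
logits = -0.5 * np.sum((k - omega) ** 2, axis=-1)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

soft = softmax(logits)              # Performer-style "soft" assignment
hard = softmax(1000.0 * logits)     # sharpened: tends to one-hot
delta = np.eye(c)[logits.argmax()]  # VQ's Delta row for this k
assert np.allclose(hard, delta, atol=1e-6)
```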
Following the Pattern
Of course, the above result is more of a figurative analogy than an equivalence, because the Performer is essentially based on a completely different approximation logic. For instance, the \boldsymbol{\omega}_1, \boldsymbol{\omega}_2, \cdots, \boldsymbol{\omega}_m in the Performer are randomly sampled and fixed, which means their approximation as center vectors is actually quite poor. However, this similarity triggers a thought: can we imitate the Performer’s logic to re-derive Transformer-VQ? That is, like Eq. [eq:performer], first construct an exact identity and then transform it into a sampled approximation to obtain the linear version.
After a few days of reflection, I discovered a scheme that can construct the desired derivation. First, we use the Dirac delta function to write: \begin{equation} e^{\boldsymbol{q}\cdot \boldsymbol{k}} = \int e^{\boldsymbol{q}\cdot \boldsymbol{\omega}}\delta(\boldsymbol{\omega} - \boldsymbol{k})d\boldsymbol{\omega} \end{equation} This is a pure identity given by the definition of the Dirac delta function, involving no sophisticated operations or approximations yet. However, when we substitute it into Attention (the numerator), some interesting results emerge: \begin{equation} \sum_j e^{\boldsymbol{q}\cdot \boldsymbol{k}_j} \boldsymbol{v}_j = \sum_j \boldsymbol{v}_j\int e^{\boldsymbol{q}\cdot \boldsymbol{\omega}}\delta(\boldsymbol{\omega} - \boldsymbol{k}_j)d\boldsymbol{\omega} = \int e^{\boldsymbol{q}\cdot \boldsymbol{\omega}} \left[\sum_j \delta(\boldsymbol{\omega} - \boldsymbol{k}_j) \boldsymbol{v}_j\right]d\boldsymbol{\omega}\label{eq:inf-vq} \end{equation} The last equality is exactly in the form of linear Attention! Of course, because it requires integration over \boldsymbol{\omega}, this expression, like the one in "The Road to Transformer Upgrade: 5. Linear Attention as Infinite Dimensions", is an "infinite-dimensional" linear Attention, which currently holds only formal value.
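Though formal, Eq. [eq:inf-vq] can be sanity-checked numerically in one dimension by replacing \delta(\boldsymbol{\omega}-\boldsymbol{k}_j) with a narrow Gaussian \mathcal{N}(\omega;k_j,\sigma^2) and integrating on a grid. A small numpy sketch (the grid range, \sigma, and all sizes are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
q = 0.3                     # scalar query (d = 1 for simplicity)
k = rng.normal(size=n)      # keys
v = rng.normal(size=n)      # scalar values

exact = np.sum(np.exp(q * k) * v)   # sum_j e^{q k_j} v_j

# Replace delta(omega - k_j) by N(omega; k_j, sigma^2) with small sigma
sigma = 1e-3
omega = np.linspace(-6.0, 6.0, 200_001)
dw = omega[1] - omega[0]
bump = np.exp(-(omega[None, :] - k[:, None]) ** 2 / (2 * sigma**2)) \
       / np.sqrt(2 * np.pi * sigma**2)       # (n, grid): approximate deltas
integrand = np.exp(q * omega) * (v @ bump)   # e^{q w} * sum_j delta(w - k_j) v_j
approx = np.sum(integrand) * dw              # numerical integral over omega
assert np.isclose(approx, exact, rtol=1e-4, atol=1e-6)
```

In higher dimensions the grid becomes intractable, which is exactly why the expression "currently holds only formal value" until the integral is approximated by something finite.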
Typically, we understand \delta(\boldsymbol{\omega} - \boldsymbol{k}_j) as the limit of a normal distribution \mathcal{N}(\boldsymbol{\omega};\boldsymbol{k}_j,\sigma^2\boldsymbol{I}) as \sigma\to 0, which also means \delta(\boldsymbol{\omega} - \boldsymbol{k}_j) carries the meaning of a conditional distribution p(\boldsymbol{\omega}|\boldsymbol{k}_j). However, from the perspective of generative models, the Dirac delta function is a single-point distribution—essentially "memorizing" the training set—so it lacks abstraction and generalization capabilities. To alleviate this, we approximate p(\boldsymbol{\omega}|\boldsymbol{k}_j) using a GMM (Gaussian Mixture Model): \begin{equation} p(\boldsymbol{\omega}|\boldsymbol{k}_j) \approx \sum_{y=1}^m \mathcal{N}(\boldsymbol{\omega};\boldsymbol{c}_y,\sigma^2\boldsymbol{I}) \,p(y|\boldsymbol{k}_j) \end{equation} Substituting this into Eq. [eq:inf-vq] and taking the limit \sigma\to 0, we obtain: \begin{equation} \sum_j e^{\boldsymbol{q}\cdot \boldsymbol{k}_j} \boldsymbol{v}_j \approx \sum_{y=1}^m e^{\boldsymbol{q}\cdot \boldsymbol{c}_y} \left[\sum_j p(y|\boldsymbol{k}_j) \boldsymbol{v}_j\right] \end{equation} This yields a finite-dimensional linear Attention. If we align p(y|\boldsymbol{k}_j) with the definition of the one-hot distribution \Delta in Transformer-VQ, the resulting expression is exactly Eq. [eq:transformer-vq] of Transformer-VQ.
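The reduction above can also be checked numerically: soft responsibilities p(y|\boldsymbol{k}_j) give a finite-dimensional linear Attention, and replacing them with their one-hot argmax reproduces the Transformer-VQ numerator exactly. A small numpy sketch (the codebook construction, \sigma, and all sizes are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d = 12, 6, 4
q = rng.normal(size=d) * 0.5
K = rng.normal(size=(n, d)) * 0.5
V = rng.normal(size=(n, 3))
codebook = K[rng.choice(n, size=m, replace=False)]  # centers c_y

# Soft responsibilities p(y|k_j) of an equal-weight GMM with centers c_y
sigma = 0.5
logits = -0.5 * np.sum((K[:, None, :] - codebook[None, :, :]) ** 2, axis=-1) / sigma**2
p = np.exp(logits - logits.max(axis=1, keepdims=True))
p /= p.sum(axis=1, keepdims=True)                    # (n, m), rows sum to 1

target = np.exp(K @ q) @ V               # sum_j e^{q.k_j} v_j  (exact numerator)
soft = np.exp(codebook @ q) @ (p.T @ V)  # finite-dimensional linear Attention

# Hardening p(y|k_j) to one-hot recovers the Transformer-VQ numerator exactly
delta = np.eye(m)[logits.argmax(axis=1)]             # the matrix Delta
vq = np.exp(codebook @ q) @ (delta.T @ V)
k_hat = delta @ codebook                             # quantized keys
assert np.allclose(vq, np.exp(k_hat @ q) @ V)
```

The last assertion is just Eq. [eq:transformer-vq] restated for a single query: attending over quantized keys equals the linear-complexity form.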
Summary
This article introduced a discovery: the early linear Attention work "Performer" can be viewed as a "Soft" version of Transformer-VQ. Based on this observation, a new derivation of Transformer-VQ was obtained: using the Dirac delta function to transform standard Attention into infinite-dimensional linear Attention, and then applying a GMM approximation to arrive at Transformer-VQ.
Reprinting should include the original address of this article: https://kexue.fm/archives/9862
For more detailed reprinting matters, please refer to: "Scientific Space FAQ"