Broadly speaking, current length extrapolation techniques for Transformers fall into two categories. The first is post-hoc modification, such as NTK-RoPE, YaRN, and ReRoPE. These methods directly modify the model at inference time, achieving a certain degree of length extrapolation without fine-tuning; their disadvantage is that they cannot keep the model identical (consistent) within the original training length. The second category is pre-hoc modification, such as ALIBI, KERPLE, XPOS, and HWFA. These achieve length extrapolation without any inference-time change, but the modification must be introduced before training. Consequently, they cannot be applied to existing models without fine-tuning, and it has not yet been widely verified whether these methods scale up effectively.
In this article, I will introduce a length extrapolation scheme discovered by accident: “KeyNorm”—applying L2 Normalization to the Key sequence in Attention. It clearly belongs to the pre-hoc modification category, but the change to the Attention mechanism is so minimal that it appears very promising for scaling up.
Initial Motivation
The reason I call it an “accidental discovery” is that the original motivation for this change was not length extrapolation, but rather an attempt to replace the scaling method in Scaled Dot-Product Attention. As we know, the standard definition of Attention (considering the causal scenario in this article) is: \begin{equation} \boldsymbol{o}_i = \frac{\sum_{j = 1}^i\exp\left(\frac{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}{\sqrt{d}}\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\frac{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}{\sqrt{d}}\right)},\quad \boldsymbol{q}_i,\boldsymbol{k}_j\in\mathbb{R}^d \label{eq:sdpa} \end{equation} The scale factor \frac{1}{\sqrt{d}} has been explained and even generalized many times, for example in “On the Initialization, Parameterization, and Standardization of Transformers”, “Attention Scaling from the Perspective of Entropy Invariance”, and “Attention Scaling from the Perspective of Gradient Maximization”. The standard derivation is performed under the assumption that \boldsymbol{q}_i, \boldsymbol{k}_j are independently sampled from a distribution with mean 0 and variance 1. Under this assumption, we have: \begin{equation} \Vert\boldsymbol{q}_i\Vert\approx \sqrt{d},\quad \Vert\boldsymbol{k}_j\Vert\approx \sqrt{d} \end{equation} This is because: \begin{equation} \Vert\boldsymbol{x}\Vert^2 = \sum_{i=1}^d x_i^2 = d\times\frac{1}{d}\sum_{i=1}^d x_i^2\approx d\,\mathbb{E}_{x\sim\mathcal{N}(0,1)}[x^2] = d \end{equation} For related generalizations, one can refer to “The Amazing Johnson-Lindenstrauss Lemma: Theory Edition”. 
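This norm concentration is easy to check numerically. The snippet below is a quick illustration (my own, not from the original post) using vectors with i.i.d. standard normal components:

```python
import numpy as np

# For x with i.i.d. N(0, 1) components, ||x||^2 = sum_i x_i^2 concentrates
# around d, so ||x|| is close to sqrt(d) with high probability.
rng = np.random.default_rng(0)
d = 128
x = rng.standard_normal((10000, d))   # 10,000 sample vectors of dimension d
norms = np.linalg.norm(x, axis=-1)
print(norms.mean(), np.sqrt(d))       # mean norm is very close to sqrt(128) ≈ 11.31
```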
This approximation implies that at the initial stage of Attention, Equation [eq:sdpa] behaves the same as the following two variants: \begin{align} \text{\textbf{Q}uery\textbf{N}orm:}\quad\boldsymbol{o}_i =&\, \frac{\sum_{j = 1}^i\exp\left(\tilde{\boldsymbol{q}}_i\cdot \boldsymbol{k}_j\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\tilde{\boldsymbol{q}}_i\cdot \boldsymbol{k}_j\right)},\qquad \tilde{\boldsymbol{q}}_i = \frac{\boldsymbol{q}_i}{\Vert\boldsymbol{q}_i\Vert} \\[5pt] \text{\textbf{K}ey\textbf{N}orm:}\quad\boldsymbol{o}_i =&\, \frac{\sum_{j = 1}^i\exp\left(\boldsymbol{q}_i\cdot \tilde{\boldsymbol{k}}_j\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\boldsymbol{q}_i\cdot \tilde{\boldsymbol{k}}_j\right)},\qquad \tilde{\boldsymbol{k}}_j = \frac{\boldsymbol{k}_j}{\Vert\boldsymbol{k}_j\Vert} \end{align} The idea, then, was to test whether either of these two variants improves on the standard Equation [eq:sdpa]. For convenience, we call them “Query/Key-Normalized Dot-Product Attention,” abbreviated “QNA” and “KNA” respectively.
Furthermore, since we can have QueryNorm and KeyNorm, we can naturally consider normalizing both. Thus, we also include the following “Scaled Cosine Attention (CosA)” in our experiments: \begin{equation} \boldsymbol{o}_i = \frac{\sum_{j = 1}^i\exp\left(\lambda\,\tilde{\boldsymbol{q}}_i\cdot \tilde{\boldsymbol{k}}_j\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\lambda\,\tilde{\boldsymbol{q}}_i\cdot \tilde{\boldsymbol{k}}_j\right)} = \frac{\sum_{j = 1}^i\exp\left(\lambda\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\lambda\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)\right)} \end{equation} where \lambda follows the result in “Attention Scaling from the Perspective of Gradient Maximization”, i.e., \lambda = 4\log n (the original post used 3.5, but since the training length here is relatively small, 4 is more accurate), with n either fixed at half the training length or taken dynamically as the position ID plus 1.
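To make the variants concrete, here is a minimal NumPy sketch of causal attention covering the baseline, QNA, KNA, and CosA (my own illustration; the function name and interface are assumptions, and RoPE is omitted for brevity):

```python
import numpy as np

def causal_attention(q, k, v, mode="baseline", train_len=512):
    """Causal attention sketch for the four variants discussed above.

    q, k: (n, d) query/key sequences; v: (n, dv) value sequence.
    mode: "baseline" (1/sqrt(d) scaling), "qna" (QueryNorm),
          "kna" (KeyNorm), or "cosa" (Scaled Cosine Attention).
    """
    n, d = q.shape
    if mode == "baseline":
        scores = q @ k.T / np.sqrt(d)
    elif mode == "qna":
        scores = (q / np.linalg.norm(q, axis=-1, keepdims=True)) @ k.T
    elif mode == "kna":
        scores = q @ (k / np.linalg.norm(k, axis=-1, keepdims=True)).T
    elif mode == "cosa":
        qn = q / np.linalg.norm(q, axis=-1, keepdims=True)
        kn = k / np.linalg.norm(k, axis=-1, keepdims=True)
        lam = 4 * np.log(train_len / 2)  # lambda = 4 log n, n fixed at half the training length
        scores = lam * (qn @ kn.T)
    # Causal mask: position i only attends to positions j <= i.
    mask = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    # Numerically stable softmax over each row.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

# Usage: the first output row can only attend to the first token,
# so it equals v[0] exactly under every variant.
rng = np.random.default_rng(1)
q, k = rng.standard_normal((8, 16)), rng.standard_normal((8, 16))
v = rng.standard_normal((8, 4))
out = causal_attention(q, k, v, mode="kna")
```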
Experimental Results
Following the previous experimental setup for length extrapolation, we use a small model with 100 million parameters and the GAU architecture. All models are trained for the same number of steps (due to time constraints, they are not fully converged at this step count), with a training length of 512, and we consider extrapolation to a length of 4096. The experimental results are shown in the table below: “Baseline” refers to Equation [eq:sdpa], and “-\log n” denotes adding the length-dependent scaling factor introduced in “Attention Scaling from the Perspective of Entropy Invariance”. The evaluation metric is the per-token accuracy of the language model (higher is better).
| Test Length | 512 (Train) | 4096 (Repeat) | 4096 (Non-repeat) |
|---|---|---|---|
| Baseline | 49.41% | 24.17% | 23.16% |
| Baseline-\log n | 49.40% | 24.60% | 24.02% |
| QNA | 49.55% | 22.45% | 22.18% |
| QNA-\log n | 49.42% | 19.55% | 18.74% |
| KNA | 49.60% | 61.08% | 47.69% |
| KNA-\log n | 49.58% | 63.17% | 46.40% |
| CosA | 49.73% | 58.90% | 46.98% |
| CosA-\log n | 49.67% | 64.74% | 48.95% |
From the table, we can observe:

1. Both QueryNorm and KeyNorm achieve better results at the training length. Although this advantage is very slight and likely negligible as training progresses, it is very stable, suggesting the possibility of smoother training.
2. KeyNorm provides a very significant boost to length extrapolation, which is the “unexpected pleasant surprise” in these results!
Note that unlike NTK-RoPE and YaRN, which require modifying the model during the inference stage, the length extrapolation for KNA and CosA here involves no changes during inference. Therefore, readers might wonder: since KNA and CosA already perform so well without inference-time modifications, would the results be even better if combined with extrapolation techniques like NTK-RoPE or YaRN? I also tested this, and the results are as follows:
| Test Length | 512 (Train) | 4096 (Repeat) | 4096 (Non-repeat) |
|---|---|---|---|
| Baseline | 49.41% | 24.17% | 23.16% |
| Baseline-NTK | 49.41% | 60.57% | 42.20% |
| Baseline-YaRN | 49.41% | 80.10% | 47.45% |
| Baseline-ReRoPE | 49.41% | 76.11% | 47.82% |
| Baseline-\log n | 49.40% | 24.60% | 24.02% |
| Baseline-\log n-NTK | 49.40% | 75.86% | 47.06% |
| Baseline-\log n-YaRN | 49.40% | 82.57% | 46.52% |
| Baseline-\log n-ReRoPE | 49.40% | 85.47% | 48.87% |
| QNA | 49.55% | 22.45% | 22.18% |
| QNA-NTK | 49.55% | 52.28% | 39.88% |
| QNA-YaRN | 49.55% | 82.53% | 47.50% |
| QNA-ReRoPE | 49.55% | 78.22% | 47.72% |
| QNA-\log n | 49.42% | 19.55% | 18.74% |
| QNA-\log n-NTK | 49.42% | 57.44% | 41.56% |
| QNA-\log n-YaRN | 49.42% | 80.08% | 45.16% |
| QNA-\log n-ReRoPE | 49.42% | 84.71% | 48.31% |
| KNA | 49.60% | 61.08% | 47.69% |
| KNA-NTK | 49.60% | 64.44% | 43.02% |
| KNA-YaRN | 49.60% | 84.19% | 47.44% |
| KNA-ReRoPE | 49.60% | 77.76% | 47.73% |
| KNA-\log n | 49.58% | 63.17% | 46.40% |
| KNA-\log n-NTK | 49.58% | 79.05% | 47.43% |
| KNA-\log n-YaRN | 49.58% | 83.95% | 47.16% |
| KNA-\log n-ReRoPE | 49.58% | 85.48% | 48.78% |
| CosA | 49.73% | 58.90% | 46.98% |
| CosA-NTK | 49.73% | 62.50% | 42.77% |
| CosA-YaRN | 49.73% | 83.40% | 47.80% |
| CosA-ReRoPE | 49.73% | 77.82% | 47.80% |
| CosA-\log n | 49.67% | 64.74% | 48.39% |
| CosA-\log n-NTK | 49.67% | 78.97% | 47.46% |
| CosA-\log n-YaRN | 49.67% | 82.28% | 45.72% |
| CosA-\log n-ReRoPE | 49.67% | 85.67% | 48.39% |
This table is quite detailed, mainly to give a comprehensive sense of the performance differences between mainstream length extrapolation techniques; compare whichever dimensions interest you. Note that when judging length extrapolation, the “Non-repeat” column should be the primary focus, with the “Repeat” column as a secondary reference. The results are somewhat surprising: KeyNorm seems to be “immune” to existing RoPE extrapolation techniques. Adding NTK, YaRN, and similar techniques brings no significant improvement in the “Non-repeat” column and may even degrade it; the clear gains appear only in the “Repeat” column. These results indicate that KeyNorm still cannot effectively identify positions beyond the training length (hence its lower “Repeat” results on its own), but it effectively avoids the PPL explosion problem (hence the decent “Non-repeat” results).
This might be good news for those working on Long Context: on one hand, unlike ALIBI or KERPLE, KeyNorm’s length extrapolation does not require adding Local constraints and involves no modifications after training—it is essentially a “free lunch.” It even seems that training performance improves after adding KeyNorm. On the other hand, because it is non-Local, it can be used for continued training on longer texts, and there is no longer a need to agonize over whether to choose PI or ABF; for KeyNorm, you don’t have to change anything.
Theoretical Analysis
Although this was an accidental discovery, we still need to try to explain it; otherwise, it remains just an accident. In this section, we attempt to think about why KeyNorm helps with length extrapolation.
Let’s return to Equation [eq:sdpa]. The correlation score between the i-th and j-th tokens is determined by the dot product: \begin{equation} s(j|i) = \boldsymbol{q}_i\cdot \boldsymbol{k}_j = \Vert\boldsymbol{q}_i\Vert \Vert\boldsymbol{k}_j\Vert \cos(\boldsymbol{q}_i,\boldsymbol{k}_j),\quad p(j|i) = \frac{\exp\left(\frac{s(j|i)}{\sqrt{d}}\right)}{\sum_{j=1}^i \exp\left(\frac{s(j|i)}{\sqrt{d}}\right)} \end{equation} The second equality uses the geometric meaning of the dot product to decompose it into the product of the two norms and the cosine of the angle between the vectors. The attention p(j|i) is a conditional probability. \Vert\boldsymbol{q}_i\Vert depends only on the current position i; it does not change the relative magnitude of attention, only its degree of sparsity. \Vert\boldsymbol{k}_j\Vert, by contrast, can change the relative magnitude of p(j|i), but it involves no interaction between i and j, so it can express certain absolute signals; for example, Scissorhands shows that tokens at certain absolute positions always receive very high attention, which could be expressed through \Vert\boldsymbol{k}_j\Vert. The remaining term, \cos(\boldsymbol{q}_i,\boldsymbol{k}_j), expresses the interaction between i and j and has the greatest degree of freedom.
Clearly, to increase the relative importance of some position j, the model has two choices: (1) increase the norm \Vert\boldsymbol{k}_j\Vert; (2) increase \cos(\boldsymbol{q}_i,\boldsymbol{k}_j), i.e., reduce the angle between \boldsymbol{q}_i and \boldsymbol{k}_j. However, due to the “curse of dimensionality,” significantly changing an angle in high-dimensional space is relatively difficult, so whenever the goal can be achieved by increasing the norm \Vert\boldsymbol{k}_j\Vert, the model will prioritize doing so. The direct consequence is that the training of \cos(\boldsymbol{q}_i,\boldsymbol{k}_j) may be insufficient.
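The curse-of-dimensionality point above can be seen numerically: cosines between random high-dimensional vectors concentrate tightly around 0, so pushing \cos(\boldsymbol{q}_i,\boldsymbol{k}_j) far from 0 requires an atypical alignment, whereas a norm can be scaled freely. A small sketch (my own illustration, not from the article):

```python
import numpy as np

# The spread of the cosine between two random Gaussian vectors shrinks
# roughly like 1/sqrt(d) as the dimension d grows.
rng = np.random.default_rng(0)
stds = []
for d in (4, 64, 1024):
    x = rng.standard_normal((2000, d))
    y = rng.standard_normal((2000, d))
    cos = np.sum(x * y, axis=-1) / (np.linalg.norm(x, axis=-1) * np.linalg.norm(y, axis=-1))
    stds.append(cos.std())
print(stds)  # standard deviation of the cosine decreases as d grows
```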
Here, I make an assertion (hypothesis):
Insufficient training of \cos(\boldsymbol{q}_i,\boldsymbol{k}_j) is the primary reason why Attention fails to extrapolate in length.
Insufficient training of \cos(\boldsymbol{q}_i,\boldsymbol{k}_j) means that the angles of \boldsymbol{q}_i, \boldsymbol{k}_j encountered during training are only a finite set. When performing length extrapolation, the model faces a larger set and is thus unable to make correct predictions. Careful consideration of the derivation in the YaRN paper reveals that NTK and YaRN are effective because they modify the implementation of RoPE during inference, causing the angles of \boldsymbol{q}_i, \boldsymbol{k}_j to fall back into the finite set from the original training phase, avoiding unseen larger sets and turning extrapolation into interpolation. ReRoPE is even more direct, truncating relative positions outside the window, ensuring that position encodings during inference are never “unfamiliar.” These techniques, to some extent, indirectly validate this assertion.
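As a concrete instance of “turning extrapolation into interpolation,” here is a sketch of the commonly used NTK-aware base-stretching recipe (my assumption of the standard formulation, not code from this article): the RoPE base is enlarged so that rotation angles at the test length stay within the range seen during training.

```python
import numpy as np

def rope_freqs(d, base=10000.0):
    # Per-plane rotation frequencies theta_i of standard RoPE.
    return base ** (-np.arange(0, d, 2) / d)

def ntk_freqs(d, scale, base=10000.0):
    # "NTK-aware" scaling: enlarge the base by scale ** (d / (d - 2)), which
    # shrinks the lowest frequency by exactly 1/scale.
    return rope_freqs(d, base * scale ** (d / (d - 2)))

d, train_len, test_len = 64, 512, 4096
scale = test_len / train_len
# The lowest frequency's maximum rotation angle at the test length is held to
# roughly the angle the model already saw at the training length.
assert ntk_freqs(d, scale)[-1] * test_len <= rope_freqs(d)[-1] * train_len * 1.01
```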
Starting from this assertion, the reason for KeyNorm’s length extrapolation becomes simple. Whether it is KNA (only KeyNorm) or CosA (both QueryNorm and KeyNorm), they exclude \Vert\boldsymbol{k}_j\Vert from the definition of Attention. Consequently, to change the relative importance of j, the model has only one choice: “adjust \cos(\boldsymbol{q}_i,\boldsymbol{k}_j).” This forces the model to train and utilize \cos(\boldsymbol{q}_i,\boldsymbol{k}_j) more thoroughly, thereby indirectly promoting length extrapolation. Additionally, I have experimented with the combination of “KeyNorm + NoPE,” but did not find length extrapolation properties. This indicates that RoPE also plays an important role in KeyNorm’s length extrapolation. In fact, this is not hard to understand: RoPE rotates \boldsymbol{q}_i, \boldsymbol{k}_j, which helps expand the range of \cos(\boldsymbol{q}_i,\boldsymbol{k}_j) during training, making the training of \cos(\boldsymbol{q}_i,\boldsymbol{k}_j) more sufficient.
Has any prior work tried QueryNorm and KeyNorm? Yes. The 2020 paper “Query-Key Normalization for Transformers” experimented with CosA and also proposed a similar logarithmic length scale factor, but it did not discuss length extrapolation. Furthermore, Google’s paper from earlier this year, “Scaling Vision Transformers to 22 Billion Parameters”, also applied a Norm to Query and Key, but used LayerNorm. Both LayerNorm and RMSNorm have learnable gamma parameters, so the norm of the normalized vectors is not necessarily constant; it is therefore hard to say whether they achieve the same length extrapolation effect as described here.
Summary
This article introduced a length extrapolation scheme discovered by accident: “KeyNorm”—applying L2 normalization to the Key sequence in Attention. It achieved better results at the training length and showed significant improvements in length extrapolation. It belongs to the “pre-hoc modification” schemes. Compared to other pre-hoc schemes like ALIBI and KERPLE, it has no Local constraints and is therefore more promising for scaling up. Compared to “post-hoc modification” schemes like NTK-RoPE and YaRN, it does not lose performance within the training length during extrapolation.
Original Address: https://kexue.fm/archives/9859