Since the explosion of DeepSeek, its proposed Attention variant, MLA (Multi-head Latent Attention), has received increasing attention. Through clever design, MLA can switch freely between MHA and MQA, allowing the model to choose the optimal form for the differing characteristics of training and inference (Compute-Bound vs. Memory-Bound) and thereby maximize efficiency.
Admittedly, MLA is very effective, but some argue it is not "elegant" enough. Consequently, efforts to find alternatives to MLA have always existed, including our own attempts. However, after a period of experimentation, we found that many Attention variants with the same or even larger KV Cache ultimately performed worse than MLA. This forced us to reflect: what exactly is the key reason behind MLA’s outstanding performance?
In this article, I will detail my thought process and related experimental results surrounding this question.
Observations
MLA was proposed in DeepSeek-V2. This article assumes the reader is already familiar with MLA, or at least understands the content introduced in the previous blog post "The Ultimate Tug-of-War Between Cache and Performance: From MHA, MQA, GQA to MLA". Therefore, the specific details of MLA itself will not be expanded upon excessively.
The main characteristics of MLA are as follows:
1. During the training phase, MLA is an MHA with qk_head_dims=(128+64) and v_head_dims=128;
2. During the decoding phase, MLA is a KV-Shared MQA with qk_head_dims=(512+64) and v_head_dims=512;
3. The concatenation of [qc, qr] and [kc, kr] in MLA can be understood as a form of Partial RoPE.
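To make point 3 concrete, below is a minimal sketch (not DeepSeek's implementation; the RoPE helper and per-head shapes are illustrative) of how the per-head attention score in MLA's training form splits into a position-free NoPE term plus a RoPE term:

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Standard RoPE on the last (even) dim: rotate pairs (x[2i], x[2i+1]) by pos-dependent angles."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)
    ang = pos * theta
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

dc, dr = 128, 64                                   # NoPE (content) and RoPE dims per head
m, n = 10, 3                                       # query and key positions
qc, qr = np.random.randn(dc), np.random.randn(dr)  # query parts
kc, kr = np.random.randn(dc), np.random.randn(dr)  # key parts

q = np.concatenate([qc, rope(qr, m)])              # qk_head_dims = 128 + 64
k = np.concatenate([kc, rope(kr, n)])
score = q @ k / np.sqrt(dc + dr)

# The score is a NoPE term plus a RoPE term; the RoPE term depends only on m - n,
# which is why the concatenation behaves as Partial RoPE.
assert np.allclose(score, (qc @ kc + rope(qr, m) @ rope(kr, n)) / np.sqrt(dc + dr))
```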
Conjectures
The head_dims commonly used in MHA and GQA is 128. For
MLA, whether viewed from training (128+64) or inference (512+64), it is
larger than 128. Combining this with the experience from "Breaking the Bottleneck: Building
a More Powerful Transformer", we have:
Conjecture 1: Increasing head_dims is one of the keys to MLA’s success.
Additionally, the KV-Shared feature allows for increasing the
head_dims or num_groups of GQA under the same
KV Cache size. Thus:
Conjecture 2: KV-Shared is one of the keys to MLA’s success.
Finally, some previous theories and experiments have shown that Partial RoPE might have a positive impact on performance (refer to "The Road to Transformer Upgrades: 18. Principles for Choosing the Base of RoPE"). Thus:
Conjecture 3: Partial RoPE is one of the keys to MLA’s success.
Experiments
We will now test the above conjectures one by one through experiments.
Setup
The hyperparameters for the common parts of all experiments are as follows:
1. Dense model similar to Llama 3;
2. hidden_size=2048, num_layers=12, num_heads=16;
3. The optimizer is Muon, with per-head updates for the Attention part;
4. Training length is 4096, total tokens are 16B, total training steps are 16k;
5. All experiments only change the Attention mechanism, so the number of parameters will not be strictly aligned.
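For reference, the shared setup can be summarized as a small config object; the field names below are illustrative and not taken from the actual training code:

```python
from dataclasses import dataclass

@dataclass
class CommonConfig:
    # Dense model similar to Llama 3; only the Attention mechanism varies across runs.
    hidden_size: int = 2048
    num_layers: int = 12
    num_heads: int = 16
    seq_len: int = 4096                   # training length
    total_tokens: int = 16_000_000_000    # 16B tokens
    train_steps: int = 16_000
    optimizer: str = "Muon"               # with per-head updates for the Attention part
```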
Part I
The KV Cache size of MLA is 512+64, which is approximately equal to
GQA2-128 (the first number is num_groups, the second is
head_dims). Therefore, the baselines for comparison are
GQA2-128 and GQA1-256. To verify Partial RoPE, we added
GQA1-256-PR, where the 256 dims of Q and
K are split into 192+64; RoPE is applied to the 64 dims, while the 192
dims remain without it.
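For concreteness, the "Cache" column in the tables below counts cached floats per token per layer; a quick back-of-the-envelope sketch (assuming K and V are cached separately for GQA, and a 512-dim latent plus a 64-dim RoPE key for MLA):

```python
# Cached floats per token per layer (the "Cache" column below).
mla_cache = 512 + 64              # compressed KV latent + shared RoPE key = 576
gqa2_128  = 2 * (128 + 128)       # 2 groups, K and V of 128 dims each    = 512
gqa1_256  = 1 * (256 + 256)       # 1 group, K and V of 256 dims each     = 512
print(mla_cache, gqa2_128, gqa1_256)  # 576 512 512
```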
The results are as follows:
| Model | Params | Loss | Cache |
|---|---|---|---|
| MLA | 894M | 2.721 | 576 |
| GQA2-128 | 842M | 2.75 | 512 |
| GQA1-256 | 943M | 2.72 | 512 |
| GQA1-256-PR | 943M | 2.711 | 512 |
That is, ranking by performance from worst to best (i.e., from highest to lowest Loss): \text{GQA2-128} < \text{MLA} \lesssim \text{GQA1-256} < \text{GQA1-256-PR}
This preliminary result validates the roles of increasing
head_dims and Partial RoPE. From this perspective, the
seemingly "forced" design of splicing RoPE and NoPE in MLA is very
likely the key reason for its superior performance! The original paper
claims that MLA even outperforms MHA, which is likely because the MHA
being compared only had a head_dims of 128.
Part II
To further verify the effect of increasing head_dims, we
ran three additional experiments: MHA, GQA2-192, and MLA-256. MHA is a
conventional MHA with head_dims=128. GQA2-192 directly
increases the head_dims of GQA2 to 192. MLA-256 increases
the MLA (128+64) to (192+64). The comparison is as follows:
| Model | Params | Loss | Cache |
|---|---|---|---|
| MHA | 931M | 2.721 | 4096 |
| MLA | 894M | 2.721 | 576 |
| MLA-256 | 989M | 2.705 | 576 |
| GQA2-128 | 842M | 2.75 | 512 |
| GQA2-192 | 899M | 2.729 | 768 |
| GQA1-256 | 943M | 2.72 | 512 |
| GQA1-256-PR | 943M | 2.711 | 512 |
As can be seen, MHA has more total parameters and a KV Cache 7 times
larger than MLA, yet its Loss barely matches MLA. This is consistent
with the conclusions in DeepSeek-V2. Furthermore, GQA2-192 is better
than GQA2-128 but worse than GQA1-256. After increasing MLA’s
head_dims to (192+64), the performance improved further
compared to (128+64). These phenomena indicate that increasing
head_dims is far more effective than increasing
num_groups.
Part III
Next, we verify KV-Shared, where K and V share all or most
dimensions. Here, we mainly consider GQA alternatives with
head_dims not exceeding 256, and control the total KV Cache
size to be close to MLA. Thus, with KV-Shared, we can consider at most
GQA2-256.
Since KV-Shared is not fully compatible with RoPE, following MLA’s approach, we split 256 into 192+64:
1. The 192-dim part has no RoPE and is shared between K and V;
2. The 64-dim part has RoPE and is used only for K;
3. V additionally projects another 64 dims, which are concatenated to the shared 192 dims.
In this way, the head_dims for both K and V is 256, and
the total KV Cache size is (192+64+64) \times
2 = 640, slightly larger than MLA’s 512+64=576. We denote this version as "GQA2-(192+64)-S1", where "S1" stands for
"Shared-1".
Part IV
Another KV-Shared scheme is:
1. The 192-dim part has no RoPE and is shared between K and V;
2. The 64-dim part has RoPE and is also shared between K and V;
3. Perform Attention. Since V carries RoPE, this results in an absolute positional encoding effect;
4. To ensure relative positional encoding, the output is split into 192+64 parts, and an inverse RoPE is applied to the 64-dim part.
In this approach, K and V are completely shared. The KV Cache size is (192+64) \times 2 = 512, slightly smaller than MLA. We call this version "GQA2-(192+64)-S2", where "S2" stands for "Shared-2". The underlying principle is the VO-RoPE proposed by the author, refer to "The Road to Transformer Upgrades: 19. The Second Type of Rotary Positional Encoding".
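Below is a minimal sketch (assumed shapes; a single query position with stand-in attention weights) of the S2 value path: because V is fully shared with K, the 64-dim RoPE slice also enters V, so after attention that slice is rotated back by the query position, which is the inverse-RoPE (VO-RoPE) step in item 4:

```python
import numpy as np

def rope(x, pos, base=10000.0, inverse=False):
    """RoPE on the last (even) dim; inverse=True rotates by -pos instead of pos."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)
    ang = (-pos if inverse else pos) * theta
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

T, t = 8, 5                                        # sequence length, query position
kv = np.random.randn(T, 192 + 64)                  # fully shared K = V heads (one group)
# RoPE is applied to the trailing 64 dims of every cached position:
kv[:, 192:] = np.stack([rope(kv[s, 192:], s) for s in range(T)])

attn = np.random.dirichlet(np.ones(t + 1))         # stand-in attention weights over s <= t
o = attn @ kv[: t + 1]                             # attention output, dims 192 + 64

# V carried RoPE (absolute positions), so rotate the 64-dim slice back by the query
# position t; each contribution then depends only on s - t (relative positions).
o = np.concatenate([o[:192], rope(o[192:], t, inverse=True)])
```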
Part V
Additionally, several experiments for GQA4 and GQA1 were added following the same logic. All experimental results are summarized below:
| Model | Params | Loss | Cache | Remarks |
|---|---|---|---|---|
| MLA | 894M | 2.721 | 576 | |
| MLA-256 | 989M | 2.705 | 576 | |
| GQA2-(192+64)-S1 | 946M | 2.714 | 640 | |
| GQA2-(192+64)-S2 | 943M | 2.708 | 512 | Includes VO-RoPE |
| GQA4-(64+64)-S2 | 842M | 2.738 | 512 | |
| GQA4-(128+64)-S2 | 899M | 2.713 | 768 | Largest KV Cache |
| GQA1-(512+64)-S3 | 1171M | 2.677 | 576 | Largest head_dims |
Here, "GQA1-(512+64)-S3" is an MQA
implemented according to MLA’s inference form, with a structure between
S1 and S2. Its main feature is the large head_dims.
Interpretation of results:
1. KV-Shared GQA inherently includes Partial RoPE;
2. KV-Shared GQA2-256 can also outperform MLA;
3. The introduction of VO-RoPE seems beneficial to performance (S1 \lesssim S2);
4. Under the same KV Cache, larger head_dims is better;
5. GQA2-(192+64)-S2 slightly outperforms GQA1-256-PR;
6. GQA4-(128+64)-S2 has the largest KV Cache, but its performance is not optimal, again indicating that head_dims is more critical.
Two more observations regarding KV-Shared:
1. During training, GQA1-256-PR was significantly ahead of GQA2-(192+64)-S2 in the early stages, but was caught up or even slightly overtaken in the later stages. It is conjectured that GQA1-256-PR might lack "stamina";
2. Without KV-Shared, GQA is at most GQA1-256, meaning head_dims caps at 256. But with KV-Shared, GQA can reach GQA1-512-S. Purely from the perspective of head_dims, KV-Shared has a higher ceiling.
Part VI
Since the parameter counts were not strictly aligned, readers might
wonder whether increasing parameters or increasing
head_dims is more fundamental. Therefore, we added several
experiments with aligned parameter counts.
We consider three ways to align parameter counts:
1. Double-heads: Taking "GQA2-128 vs GQA1-256" as an example, doubling the num_heads of GQA2-128 makes its parameter count the same as GQA1-256;
2. Shrinking MLP: Reducing the intermediate_size of the MLP (SwiGLU) can make the parameter count of GQA1-256 roughly the same as GQA2-128;
3. Q&O LoRA: The main parameter count of GQA comes from the Query and Output projection matrices. Using LoRA for these two matrices can also reduce the parameter count of GQA1-256 (see the sketch after this list).
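As a sketch of the third option, here is one way to read "Q&O LoRA": the full-rank Query/Output projections are replaced by low-rank factorizations (the additive-adapter form of LoRA would not reduce parameters). The rank below is illustrative, not the value used in the experiments:

```python
import torch.nn as nn

hidden_size, num_heads, head_dims, rank = 2048, 16, 256, 512   # rank is an assumption

# Low-rank Query projection: 2048 -> rank -> 16 * 256
q_proj = nn.Sequential(
    nn.Linear(hidden_size, rank, bias=False),
    nn.Linear(rank, num_heads * head_dims, bias=False),
)
# Low-rank Output projection: 16 * 256 -> rank -> 2048
o_proj = nn.Sequential(
    nn.Linear(num_heads * head_dims, rank, bias=False),
    nn.Linear(rank, hidden_size, bias=False),
)

full_rank = 2 * hidden_size * num_heads * head_dims            # full Q + O params per layer
low_rank  = 2 * rank * (hidden_size + num_heads * head_dims)   # factorized Q + O params per layer
print(full_rank, low_rank)                                     # 16777216 6291456
```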
Experimental results are as follows:
| Model | Params | Loss | Cache | heads | inter_size | qo_lora |
|---|---|---|---|---|---|---|
| MLA | 894M | 2.721 | 576 | 16 | 5456 | No |
| GQA2-128 | 842M | 2.75 | 512 | 16 | 5456 | No |
| GQA1-256 | 943M | 2.72 | 512 | 16 | 5456 | No |
| GQA2-128 | 943M | 2.723 | 512 | 32 | 5456 | No |
| GQA1-256 | 843M | 2.747 | 512 | 16 | 4096 | No |
| GQA1-256 | 842M | 2.726 | 512 | 16 | 5456 | Yes |
| GQA4-(64+64)-S2 | 842M | 2.738 | 512 | 16 | 5456 | No |
| GQA2-(192+64)-S2 | 943M | 2.708 | 512 | 16 | 5456 | No |
| GQA4-(64+64)-S2 | 943M | 2.711 | 512 | 32 | 5456 | No |
| GQA2-(192+64)-S2 | 843M | 2.733 | 512 | 16 | 4096 | No |
| GQA2-(192+64)-S2 | 842M | 2.708 | 512 | 16 | 5456 | Yes |
The results are mainly divided into three parts:
1. At the same parameter count, doubling num_heads is consistently worse than doubling head_dims, by a Loss gap of about 0.003;
2. At the same parameter count, keeping the larger head_dims and shrinking the MLP is consistently better than halving head_dims and keeping the full MLP, by a Loss gap of about 0.004;
3. Q&O LoRA has the smallest performance cost: it lets head_dims double without increasing the parameter count, while still significantly reducing the Loss.
Conclusion: From the perspective of increasing parameter count,
increasing head_dims is likely the direction with the
largest performance gain. Combined with Q&O LoRA, it can achieve
almost no increase in parameters while maintaining significant
gains.
Summary
The preliminary conclusions are:
1. Increasing head_dims yields the highest returns;
2. Partial RoPE also helps reduce Loss;
3. KV-Shared likely plays a role as well.
It seems that our previous attempts to find MLA alternatives under
head_dims=128 were fundamentally disadvantaged from the
start. No wonder they couldn’t beat MLA. To match MLA,
head_dims should start at 192, supplemented by Partial
RoPE. As for KV-Shared, it may be useful, but likely requires
larger-scale verification.
Significance
The significance here depends on how strong our determination is to replace MLA.
Suppose GQA2-(192+64)-S2 can replace MLA; but MLA itself can also be upgraded to MLA-256, and GQA2-(192+64)-S2 currently does not match MLA-256. Thus, the only two benefits of replacing MLA are:
1. The structure is simpler, making it easier to add QK-Norm;
2. The head_dims in the decoding stage changes from 512+64 to 256, while num_groups becomes 2, which allows for Tensor Parallelism (TP).