English (unofficial) translations of posts at kexue.fm

The Road to Transformer Upgrades: 20. Why is MLA So Good? (Part 1)

Translated by Gemini Flash 3.0 Preview. Translations may be inaccurate; please refer to the original post for anything important.

Since DeepSeek shot to prominence, its proposed Attention variant, MLA (Multi-head Latent Attention), has received growing attention. Through a clever design, MLA can switch freely between an MHA form and an MQA form, allowing the model to choose the form best suited to the differing characteristics of training and inference (Compute-Bound vs. Memory-Bound) and thereby maximize efficiency.

Admittedly, MLA is very effective, but some argue it is not "elegant" enough. Consequently, efforts to find alternatives to MLA have always existed, including our own attempts. However, after a period of experimentation, we found that many Attention variants with the same or even larger KV Cache ultimately performed worse than MLA. This forced us to reflect: what exactly is the key reason behind MLA’s outstanding performance?

In this article, I will detail my thought process and related experimental results surrounding this question.

Observations

MLA was proposed in DeepSeek-V2. This article assumes the reader is already familiar with MLA, or at least understands the content introduced in the previous blog post "The Ultimate Tug-of-War Between Cache and Performance: From MHA, MQA, GQA to MLA". Therefore, the specific details of MLA itself will not be expanded upon excessively.

The main characteristics of MLA are as follows (a shape-level sketch follows the list):

1. During the training phase, MLA is an MHA with qk_head_dims=(128+64) and v_head_dims=128;

2. During the decoding phase, MLA is a KV-Shared MQA with qk_head_dims=(512+64) and v_head_dims=512;

3. The concatenation of [q_c, q_r] and [k_c, k_r] in MLA can be understood as a form of Partial RoPE.
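As a shape-level sketch of these two views (PyTorch-style, with hypothetical tensor names and the DeepSeek-V2 dimensions; this illustrates the shapes only, not the actual implementation):

```python
import torch

batch, seq, n_heads = 2, 4096, 16
d_latent, d_rope, d_nope = 512, 64, 128   # KV latent dims, RoPE dims, per-head NoPE dims

# Training view: an MHA with qk_head_dims = 128 + 64 and v_head_dims = 128.
q_c = torch.randn(batch, seq, n_heads, d_nope)   # NoPE part of Q
q_r = torch.randn(batch, seq, n_heads, d_rope)   # RoPE part of Q
k_c = torch.randn(batch, seq, n_heads, d_nope)   # NoPE part of K, decompressed from the latent
k_r = torch.randn(batch, seq, 1, d_rope)         # RoPE part of K, shared across all heads
v   = torch.randn(batch, seq, n_heads, d_nope)   # V, also decompressed from the latent

# Decoding view: the decompression matrices are absorbed into the Q and O projections,
# so every head attends directly to the cached latent -- a KV-Shared MQA with
# qk_head_dims = 512 + 64 and v_head_dims = 512.
kv_cache = torch.randn(batch, seq, 1, d_latent + d_rope)   # only 576 cached numbers per token
```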

Conjectures

The head_dims commonly used in MHA and GQA is 128, whereas MLA's head_dims, whether viewed from the training side (128+64) or the inference side (512+64), is larger than 128. Combining this with the experience from "Breaking the Bottleneck: Building a More Powerful Transformer", we have:

Conjecture 1: Increasing head_dims is one of the keys to MLA’s success.

Additionally, the KV-Shared feature allows for increasing the head_dims or num_groups of GQA under the same KV Cache size. Thus:

Conjecture 2: KV-Shared is one of the keys to MLA’s success.

Finally, some previous theories and experiments have shown that Partial RoPE might have a positive impact on performance (refer to "The Road to Transformer Upgrades: 18. Principles for Choosing the Base of RoPE"). Thus:

Conjecture 3: Partial RoPE is one of the keys to MLA’s success.

Experiments

We will now test the above conjectures one by one through experiments.

Setup

The hyperparameters for the common parts of all experiments are as follows:

1. Dense model similar to Llama 3;

2. hidden_size=2048, num_layers=12, num_heads=16;

3. The optimizer is Muon, with per-head updates for the Attention part;

4. Training length is 4096, total tokens are 16B, total training steps are 16k;

5. All experiments only change the Attention mechanism, so the number of parameters will not be strictly aligned.

Part I

The KV Cache size of MLA is 512+64, which is approximately equal to GQA2-128 (the first number is num_groups, the second is head_dims). Therefore, the baselines for comparison are GQA2-128 and GQA1-256. To verify Partial RoPE, we added GQA1-256-PR, where the 256 dims of Q and K are split into 192+64; RoPE is applied to the 64 dims, while the 192 dims remain without it.
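As a minimal sketch of this Partial RoPE split (hypothetical helper names, PyTorch; the post does not say whether the rotated 64 dims sit at the front or the back of the 256, here they are taken as the last 64):

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    # x: (..., seq, d); cos, sin: (seq, d) precomputed rotation tables
    return x * cos + rotate_half(x) * sin

def partial_rope(q, k, cos, sin, rope_dims=64):
    # GQA1-256-PR: the first 192 dims of Q/K stay position-free (NoPE),
    # only the last 64 dims are rotated.
    q_n, q_r = q[..., :-rope_dims], q[..., -rope_dims:]
    k_n, k_r = k[..., :-rope_dims], k[..., -rope_dims:]
    q = torch.cat([q_n, apply_rope(q_r, cos, sin)], dim=-1)
    k = torch.cat([k_n, apply_rope(k_r, cos, sin)], dim=-1)
    return q, k
```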

The results are as follows:

             Params   Loss    KV Cache
MLA          894M     2.721   576
GQA2-128     842M     2.75    512
GQA1-256     943M     2.72    512
GQA1-256-PR  943M     2.711   512

That is, ranked by performance (lower Loss is better): \text{GQA2-128} < \text{MLA} \lesssim \text{GQA1-256} < \text{GQA1-256-PR}

This preliminary result validates the roles of increasing head_dims and Partial RoPE. From this perspective, the seemingly "forced" design of splicing RoPE and NoPE in MLA is very likely the key reason for its superior performance! The original paper claims that MLA even outperforms MHA, which is likely because the MHA being compared only had a head_dims of 128.

Part II

To further verify the effect of increasing head_dims, we ran three additional experiments: MHA, GQA2-192, and MLA-256. MHA is a conventional MHA with head_dims=128. GQA2-192 directly increases the head_dims of GQA2 to 192. MLA-256 increases the MLA (128+64) to (192+64). The comparison is as follows:

             Params   Loss    KV Cache
MHA          931M     2.721   4096
MLA          894M     2.721   576
MLA-256      989M     2.705   576
GQA2-128     842M     2.75    512
GQA2-192     899M     2.729   768
GQA1-256     943M     2.72    512
GQA1-256-PR  943M     2.711   512

As can be seen, MHA has more total parameters and a KV Cache 7 times larger than MLA, yet its Loss barely matches MLA. This is consistent with the conclusions in DeepSeek-V2. Furthermore, GQA2-192 is better than GQA2-128 but worse than GQA1-256. After increasing MLA’s head_dims to (192+64), the performance improved further compared to (128+64). These phenomena indicate that increasing head_dims is far more effective than increasing num_groups.

Part III

Next, we verify KV-Shared, where K and V share all or most dimensions. Here, we mainly consider GQA alternatives with head_dims not exceeding 256, and control the total KV Cache size to be close to MLA. Thus, with KV-Shared, we can consider at most GQA2-256.

Since KV-Shared is not fully compatible with RoPE, following MLA’s approach, we split 256 into 192+64:

1. The 192-dim part has no RoPE and is shared between K and V;

2. The 64-dim part has RoPE and is used only for K;

3. V additionally projects another 64 dims, which are concatenated to the shared 192 dims.

In this way, the head_dims for both K and V is 256, and the total KV Cache per token is (192+64+64) \times 2 = 640 (the factor of 2 being the two groups of GQA2), slightly larger than MLA's 512+64=576. We denote this version as "GQA2-(192+64)-S1", where "S1" stands for "Shared-1".
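A minimal sketch of how K and V are assembled in GQA2-(192+64)-S1 and of the cache accounting (hypothetical tensor names; the RoPE rotation itself is omitted):

```python
import torch

seq, n_groups = 4096, 2
d_shared, d_rope, d_v_extra = 192, 64, 64

c_shared = torch.randn(seq, n_groups, d_shared)    # NoPE part, cached once, shared by K and V
k_rope   = torch.randn(seq, n_groups, d_rope)      # RoPE part, used only by K
v_extra  = torch.randn(seq, n_groups, d_v_extra)   # extra projection, used only by V

k = torch.cat([c_shared, k_rope], dim=-1)          # (seq, 2, 256)
v = torch.cat([c_shared, v_extra], dim=-1)         # (seq, 2, 256)

cache_per_token = n_groups * (d_shared + d_rope + d_v_extra)
assert cache_per_token == 640                      # vs. MLA's 512 + 64 = 576
```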

Part IV

Another KV-Shared scheme is:

1. The 192-dim part has no RoPE and is shared between K and V;

2. The 64-dim part has RoPE and is also shared between K and V;

3. Perform Attention. Since V carries RoPE, this results in an absolute positional encoding effect;

4. To ensure relative positional encoding, the output is split into 192+64 parts, and an inverse RoPE is applied to the 64-dim part.

In this approach, K and V are completely shared. The KV Cache per token is (192+64) \times 2 = 512 (again, the 2 being the two groups), slightly smaller than MLA's 576. We call this version "GQA2-(192+64)-S2", where "S2" stands for "Shared-2". The underlying principle is the VO-RoPE proposed by the author; refer to "The Road to Transformer Upgrades: 19. The Second Type of Rotary Positional Encoding".
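Below is a single-group, single-sequence sketch of this S2 scheme, including the inverse RoPE on the output; it is our reading of the description above rather than the author's code, and batch/head dimensions are dropped for clarity:

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    # Rotation by +theta; passing -sin rotates by -theta (the inverse rotation).
    return x * cos + rotate_half(x) * sin

def s2_attention(q, kv, cos, sin, rope_dims=64):
    # q, kv: (seq, 256); cos, sin: (seq, rope_dims)
    q = torch.cat([q[..., :-rope_dims],
                   apply_rope(q[..., -rope_dims:], cos, sin)], dim=-1)
    k = torch.cat([kv[..., :-rope_dims],
                   apply_rope(kv[..., -rope_dims:], cos, sin)], dim=-1)
    v = k  # fully KV-Shared: V carries the rotated 64 dims as well
    scores = (q @ k.T) / k.shape[-1] ** 0.5
    causal = torch.tril(torch.ones_like(scores)).bool()
    attn = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)
    out = attn @ v
    # VO-RoPE: rotate the 64 RoPE dims of the output back by the query position,
    # turning the absolute positional effect carried by V into a relative one.
    return torch.cat([out[..., :-rope_dims],
                      apply_rope(out[..., -rope_dims:], cos, -sin)], dim=-1)
```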

Part V

Additionally, several experiments for GQA4 and GQA1 were added following the same logic. All experimental results are summarized below:

                   Params   Loss    KV Cache   Remarks
MLA                894M     2.721   576
MLA-256            989M     2.705   576
GQA2-(192+64)-S1   946M     2.714   640
GQA2-(192+64)-S2   943M     2.708   512        Includes VO-RoPE
GQA4-(64+64)-S2    842M     2.738   512
GQA4-(128+64)-S2   899M     2.713   768        Largest KV Cache
GQA1-(512+64)-S3   1171M    2.677   576        Largest head_dims

Here, "GQA1-(512+64)-S3" is an MQA implemented according to MLA’s inference form, with a structure between S1 and S2. Its main feature is the large head_dims.

Interpretation of results:

1. KV-Shared GQA inherently includes Partial RoPE;

2. KV-Shared GQA2-256 can also outperform MLA;

3. The introduction of VO-RoPE seems beneficial to performance (S1 \lesssim S2);

4. Under the same KV Cache, larger head_dims is better;

5. GQA2-(192+64)-S2 slightly outperforms GQA1-256-PR;

6. GQA4-(128+64)-S2 has the largest KV Cache, but its performance is not optimal, again indicating that head_dims is more critical.

Two more observations regarding KV-Shared:

1. During training, GQA1-256-PR was significantly ahead of GQA2-(192+64)-S2 in the early stages, but was caught up or even slightly overtaken in the later stages. It is conjectured that GQA1-256-PR might lack "stamina";

2. Without KV-Shared, GQA is at most GQA1-256, meaning head_dims caps at 256. But with KV-Shared, GQA can reach GQA1-512-S. Purely from the perspective of head_dims, KV-Shared has a higher ceiling.

Part VI

Since the parameter counts were not strictly aligned, readers might wonder whether increasing parameters or increasing head_dims is more fundamental. Therefore, we added several experiments with aligned parameter counts.

We consider three ways to align parameter counts:

1. double-heads: Taking "GQA2-128 vs GQA1-256" as an example, doubling the num_heads of GQA2-128 makes its parameter count the same as GQA1-256;

2. Shrinking MLP: Reducing the intermediate_size of the MLP (SwiGLU) can make the parameter count of GQA1-256 roughly the same as GQA2-128;

3. Q&O LoRA: The main parameter count of GQA comes from the Query and Output projection matrices. Applying LoRA (a low-rank factorization) to these two matrices can also reduce the parameter count of GQA1-256 (see the sketch after this list).
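A minimal sketch of what "Q&O LoRA" could look like, i.e., a plain low-rank factorization of the Query and Output projection matrices (the class name and the rank value are our own choices, not taken from the post):

```python
import torch.nn as nn

class LowRankLinear(nn.Module):
    """Replaces a full d_in x d_out projection with two thin matrices of rank r,
    cutting the parameter count from d_in * d_out to r * (d_in + d_out)."""
    def __init__(self, d_in, d_out, rank=512):
        super().__init__()
        self.down = nn.Linear(d_in, rank, bias=False)
        self.up = nn.Linear(rank, d_out, bias=False)

    def forward(self, x):
        return self.up(self.down(x))

hidden_size, num_heads, head_dims = 2048, 16, 256
q_proj = LowRankLinear(hidden_size, num_heads * head_dims)   # Query projection of GQA1-256
o_proj = LowRankLinear(num_heads * head_dims, hidden_size)   # Output projection
```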

Experimental results are as follows:

                   Params   Loss    KV Cache   heads   inter_size   qo_lora
MLA                894M     2.721   576        16      5456         No
GQA2-128           842M     2.75    512        16      5456         No
GQA1-256           943M     2.72    512        16      5456         No
GQA2-128           943M     2.723   512        32      5456         No
GQA1-256           843M     2.747   512        16      4096         No
GQA1-256           842M     2.726   512        16      5456         Yes
GQA4-(64+64)-S2    842M     2.738   512        16      5456         No
GQA2-(192+64)-S2   943M     2.708   512        16      5456         No
GQA4-(64+64)-S2    943M     2.711   512        32      5456         No
GQA2-(192+64)-S2   843M     2.733   512        16      4096         No
GQA2-(192+64)-S2   842M     2.708   512        16      5456         Yes

The results fall into three comparisons:

1. At matched parameter count, doubling num_heads instead of doubling head_dims gives a Loss that is consistently worse by about 0.003;

2. At matched parameter count, shrinking the MLP to afford the larger head_dims, rather than halving head_dims, gives a Loss that is consistently better by about 0.004;

3. Q&O LoRA loses the least relative to the full-parameter version: it lets head_dims double with essentially no increase in parameter count while still reducing Loss significantly.

Conclusion: among the ways of spending extra parameters, increasing head_dims is likely the direction with the largest performance gain, and combined with Q&O LoRA the gain is largely retained with almost no increase in parameter count.

Summary

The preliminary conclusions are:

1. Increasing head_dims yields the highest returns;

2. Partial RoPE also helps reduce Loss;

3. KV-Shared likely plays a role as well.

It seems that our previous attempts to find MLA alternatives under head_dims=128 were fundamentally disadvantaged from the start. No wonder they couldn’t beat MLA. To match MLA, head_dims should start at 192, supplemented by Partial RoPE. As for KV-Shared, it may be useful, but likely requires larger-scale verification.

Significance

The significance here depends on how strong our determination is to replace MLA.

Suppose GQA2-(192+64)-S2 could replace MLA. But MLA itself can also be upgraded to 256 head_dims, and GQA2-(192+64)-S2 currently does not match MLA-256. Thus, the only two benefits of the replacement would be:

1. The structure is simpler, making it easier to add QK-Norm;

2. The head_dims in the decoding stage changes from 512+64 to 256, while num_groups becomes 2, which allows for Tensor Parallelism (TP).

Reprinting: Please include the original address of this article: https://kexue.fm/archives/10907

For more detailed reprinting matters, please refer to: "Scientific Space FAQ"