In the article "Transformer Upgrade Road: 20. Why is MLA so good? (Part 1)", we conducted ablation experiments on several changes in MLA compared to common MHA, GQA, and MQA. These changes included "increasing head_dims", "Partial RoPE", and "KV sharing". The preliminary experimental results suggested that all three changes are likely reasons for MLA’s superior performance.
In this article, we will understand the success of MLA from a more theoretical perspective.
Partial Rotation
First, let’s put the final assertion upfront:
Under the same training and inference costs, MLA may be the best Full Attention variant.
Obviously, this judgment places MLA in a very high position. This conclusion is based on the experimental results of the previous article and the theoretical analysis in this article, under ideal and simplified assumptions. Since actual training and inference involve many complex factors, this conclusion will likely deviate somewhat, but we can at least conclude that MLA is on the right path of improvement.
The reason MLA can perform so well has a very large prerequisite: the effect of Partial RoPE is not inferior to, and may even be superior to, the full version of RoPE. Partial RoPE here can have two meanings: first, when we add RoPE to the Attention’s \boldsymbol{Q} and \boldsymbol{K}, we can add it to only a small portion of the dimensions, while the remaining dimensions remain unchanged; second, we can consider alternating RoPE and NoPE between layers, with NoPE layers potentially being the majority.
Simply put, RoPE can be added "just a little bit," but it cannot be omitted entirely; omitting it completely leads to poor performance. If a theory is needed, the author agrees with the explanation in "Transformer Upgrade Road: 18. Principles for Choosing the Base of RoPE", which roughly means that Partial RoPE makes retrieval results better balance position and semantics. In addition, new works such as FoX and SBA also show potential, but for MLA, these variants are equivalent to NoPE and thus do not change the conclusion.
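To make the first sense of Partial RoPE concrete, here is a minimal sketch (NumPy; the function names, the choice of rotating the first 32 of 128 dimensions, and the base are illustrative assumptions rather than the exact recipe from the experiments): RoPE is applied to only a small slice of each q/k head, and the remaining dimensions stay position-free.

```python
import numpy as np

def rope(x, positions, base=10000.0):
    """Apply standard RoPE to x of shape (seq_len, dims), with dims even."""
    dims = x.shape[-1]
    freqs = base ** (-np.arange(0, dims, 2) / dims)       # (dims/2,)
    angles = positions[:, None] * freqs[None, :]          # (seq_len, dims/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def partial_rope(x, positions, rope_dims):
    """Rotate only the first `rope_dims` dimensions; leave the rest untouched (NoPE)."""
    rotated = rope(x[..., :rope_dims], positions)
    return np.concatenate([rotated, x[..., rope_dims:]], axis=-1)

seq_len, head_dims, rope_dims = 16, 128, 32
pos = np.arange(seq_len, dtype=np.float64)
q = np.random.randn(seq_len, head_dims)
q_partial = partial_rope(q, pos, rope_dims)   # only 32 of the 128 dims carry position info
```

The second sense (alternating RoPE and NoPE layers) simply amounts to calling the full `rope` in some layers and skipping it entirely in the others.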
The conclusion that "Partial RoPE is not bad" allows us to place the main computational complexity of Attention on the NoPE part, which provides more room for maneuver, and MLA benefits from this.
Key-Value Sharing
The evolution of Full Attention roughly goes from MHA, MQA, GQA, and then to MLA. Although MQA can be seen as a special case of GQA, chronologically GQA came later. After MLA, two variants, MFA and TPA, also appeared. These variants essentially aim to squeeze the KV Cache as much as possible to improve generation speed while maintaining performance.
Briefly, the cost of an Attention model can be divided into three parts: Training, Prefill, and Decoding. Since Training and Prefill behave similarly, the split is essentially Prefill versus Decoding. Prefill is the stage in which the model processes the input up to emitting the first token; we will return to it in the next section. Decoding is the token-by-token generation stage, which can be accelerated by the KV Cache mechanism, but this also makes the KV Cache size almost the sole bottleneck for Decoding speed.
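To make the bottleneck concrete, here is a minimal single-head sketch of one Decoding step with a KV Cache (illustrative NumPy with assumed shapes, no masking or scaling, and no claim to being an optimized kernel): every new token must read the entire cache once, so per-step memory traffic grows with the cache size rather than with the amount of arithmetic.

```python
import numpy as np

def decode_step(q, k_cache, v_cache, k_new, v_new):
    """One decoding step: append the new token's K/V, then attend over the full cache."""
    k_cache = np.concatenate([k_cache, k_new[None, :]], axis=0)   # (t, d_k)
    v_cache = np.concatenate([v_cache, v_new[None, :]], axis=0)   # (t, d_v)
    scores = k_cache @ q                        # reads every cached K: O(t * d_k)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    out = weights @ v_cache                     # reads every cached V: O(t * d_v)
    return out, k_cache, v_cache

d_k = d_v = 128
k_cache, v_cache = np.zeros((0, d_k)), np.zeros((0, d_v))
for _ in range(4):                              # generate 4 tokens
    q, k_new, v_new = np.random.randn(d_k), np.random.randn(d_k), np.random.randn(d_v)
    out, k_cache, v_cache = decode_step(q, k_cache, v_cache, k_new, v_new)
```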
Therefore, compressing the KV Cache means increasing Decoding speed. Now, let me ask a question: In the context of NoPE, given a fixed KV Cache size, what is the best Attention? If we do not consider differences in parameter counts and only discuss within a single layer of MHA/GQA/MQA (we will discuss TPA and MFA later), the answer would be:
An MQA where head_dims is equal to the KV Cache size, and K and V are shared.
Does this seem surprising? It’s actually not hard to understand. Since MHA and MQA can both be seen as special cases of GQA, we only need to analyze GQA. As we showed in "The Ultimate Tug-of-War Between Cache and Performance: From MHA, MQA, GQA to MLA", GQA can be re-expressed as a model where K and V are concatenated: \begin{equation} \underbrace{\left[\boldsymbol{k}_i^{(1)},\cdots,\boldsymbol{k}_i^{(g)},\boldsymbol{v}_i^{(1)},\cdots,\boldsymbol{v}_i^{(g)}\right]}_{\boldsymbol{c}_i\in\mathbb{R}^{g(d_k+d_v)}} = \boldsymbol{x}_i \underbrace{\left[\boldsymbol{W}_k^{(1)},\cdots,\boldsymbol{W}_k^{(g)},\boldsymbol{W}_v^{(1)},\cdots,\boldsymbol{W}_v^{(g)}\right]}_{\boldsymbol{W}_c\in\mathbb{R}^{d\times g(d_k+d_v)}} \end{equation} Here g(d_k+d_v) is exactly the total KV Cache size for a single token. Then, when we calculate Attention, the transformations from \boldsymbol{c} to \boldsymbol{k}, \boldsymbol{v} are absorbed into \boldsymbol{W}_q and \boldsymbol{W}_o, resulting in an MQA where both K and V are \boldsymbol{c}. Thus, "an MQA where head_dims equals the KV Cache size and K and V are shared" is actually a "superset" of MHA/GQA/MQA for a given KV Cache size, so it is theoretically the best choice.
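A small numerical check may make this re-expression more concrete. The sketch below (NumPy, toy sizes chosen arbitrarily; the causal mask and 1/sqrt(d_k) scaling are omitted since they do not affect the identity) builds a GQA layer directly, then rebuilds it as an MQA whose shared K and V are both \boldsymbol{c}_i, with the block that selects k^{(g)} absorbed into the query and the block that selects v^{(g)} absorbed into the output side:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, h, g, dk, dv = 5, 32, 8, 2, 16, 16          # toy sizes (assumptions)
X  = rng.standard_normal((T, d))
Wq = rng.standard_normal((h, d, dk)) / np.sqrt(d)
Wk = rng.standard_normal((g, d, dk)) / np.sqrt(d)
Wv = rng.standard_normal((g, d, dv)) / np.sqrt(d)

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Plain GQA: head s reads KV group s*g//h
gqa = []
for s in range(h):
    grp = s * g // h
    q, k, v = X @ Wq[s], X @ Wk[grp], X @ Wv[grp]
    gqa.append(softmax(q @ k.T) @ v)

# The same model as an MQA whose shared K = V = c_i = [k^(1..g), v^(1..g)]
C = np.concatenate([X @ Wk[j] for j in range(g)] +
                   [X @ Wv[j] for j in range(g)], axis=-1)        # (T, g*(dk+dv))
mqa = []
for s in range(h):
    grp = s * g // h
    q_abs = np.zeros((T, C.shape[-1]))
    q_abs[:, grp*dk:(grp+1)*dk] = X @ Wq[s]        # the query absorbs "select k^(grp)"
    o = softmax(q_abs @ C.T) @ C                   # attention over the shared c_i
    mqa.append(o[:, g*dk + grp*dv : g*dk + (grp+1)*dv])   # W_o absorbs "select v^(grp)"

assert all(np.allclose(a, b) for a, b in zip(gqa, mqa))
```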
Dual Projection
In summary, if we want the best performance at the same Decoding speed, we should train an MQA with the specified head_dims and shared KV. For example, if we agree that the per-token KV Cache should not exceed 512, then an MQA with head_dims=512 and shared KV is the best choice. In fact, in the Decoding stage MLA is exactly such a shared-KV MQA (on the NoPE part), which is one sign that it is moving in the right direction.
However, while increasing head_dims to 512 is fine for Decoding, it is hard to accept for Training and Prefill, because their bottleneck is compute, and the main factors determining compute are num_heads and head_dims. Since there is little room to change num_heads without hurting performance, head_dims is effectively the sole knob controlling compute, and raising it from 128 to 512 roughly quadruples the cost.
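As a rough sanity check of that factor of 4 (a simplified count that keeps only the dependence on head_dims and ignores constant factors), both the attention matmuls over a length-T sequence and the Q/K/V/O projections scale linearly in head_dims at fixed num_heads: \begin{equation}\text{FLOPs}_{\text{scores}} \propto \text{num\_heads}\times \text{head\_dims}\times T^2,\qquad \text{FLOPs}_{\text{proj}} \propto \text{num\_heads}\times \text{head\_dims}\times d\times T,\end{equation} so moving head_dims from 128 to 512 multiplies both terms by 512/128 = 4.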
Now let me ask another question: Also in the context of NoPE, given num_heads and head_dims, what is the best Attention? I believe everyone can accept the answer to this question: MHA, because it has the fewest constraints. Therefore, from the perspective of Training and Prefill costs alone, what we want is to train an MHA with head_dims=128.
How to reconcile the conflicting demands of Prefill and Decoding? This is MLA's "big move": it obtains K and V through two projection steps, first projecting the input into a single 512-dimensional vector, then projecting that vector into multiple 128-dimensional heads. Exploiting the identity-transformation property inherent to "Attention + NoPE," the model can then switch freely between an MHA with head_dims=128 and an MQA with head_dims=512:
\begin{array}{c|c} \text{Training/Prefill} & \text{Decoding} \\ \hline \\ \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_c\times d_k} \\ \boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt] \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \end{gathered} & \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}\boldsymbol{W}_v^{(1)}, \boldsymbol{o}_t^{(2)}\boldsymbol{W}_v^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\boldsymbol{W}_v^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{lightgray}{\smash{\bcancel{(s)}}}}{}^{\top}\right)\boldsymbol{v}_i^{\color{lightgray}{\smash{\bcancel{(s)}}}} }{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{lightgray}{\smash{\bcancel{(s)}}}}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}\in\mathbb{R}^{d_c}\\ \boldsymbol{k}_i^{\color{lightgray}{\smash{\bcancel{(s)}}}} = \boldsymbol{v}_i^{\color{lightgray}{\smash{\bcancel{(s)}}}} = \boldsymbol{c}_i= \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c} \end{gathered} \end{array}
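The identity between the two columns can be verified numerically. The sketch below (NumPy, toy sizes as an assumption; the causal mask and scaling are again omitted, since they do not affect the identity) computes the Training/Prefill form and the absorbed Decoding form and checks that they agree:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, dc, dk, dv, h = 6, 64, 512, 128, 128, 4
X  = rng.standard_normal((T, d))
Wc = rng.standard_normal((d, dc)) / np.sqrt(d)
Wq = rng.standard_normal((h, d, dk)) / np.sqrt(d)
Wk = rng.standard_normal((h, dc, dk)) / np.sqrt(dc)
Wv = rng.standard_normal((h, dc, dv)) / np.sqrt(dc)

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

C = X @ Wc                                        # (T, dc): the shared cache c_i

# Training/Prefill view: an ordinary MHA with head_dims = dk over k, v derived from c
prefill = np.concatenate(
    [softmax((X @ Wq[s]) @ (C @ Wk[s]).T) @ (C @ Wv[s]) for s in range(h)], axis=-1)

# Decoding view: an MQA with head_dims = dc, K = V = c, and Wk/Wv absorbed into q and o
decode = np.concatenate(
    [(softmax((X @ Wq[s] @ Wk[s].T) @ C.T) @ C) @ Wv[s] for s in range(h)], axis=-1)

assert np.allclose(prefill, decode)
```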
In Summary
We summarize the previous reasoning logic:
1. Major premise: the effect of Partial RoPE is not worse than, and may even be better than, full RoPE, which allows us to focus our main effort on the NoPE part;
2. The main bottleneck of Decoding is the KV Cache; the theoretically optimal model is a shared-KV MQA whose head_dims equals the KV Cache size;
3. The main bottleneck of Training and Prefill is head_dims; the theoretically optimal model is an MHA with the desired head_dims;
4. Under the NoPE premise, Attention admits an identity transformation that lets low-rank (LoRA-style) projections reconcile these two ideals as far as possible, and this is exactly what MLA does.
What remains is to concatenate a shared low-dimensional RoPE component onto K, supplementing MLA with position information at minimal cost while "killing two birds with one stone": the concatenation realizes "Partial RoPE" and further increases head_dims, both consistent with the conclusions of the previous article. In other words, the intentional or unintentional use of Partial RoPE and the larger head_dims are the main reasons why MLA can still rival MHA under extreme compression.
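A minimal sketch of this concatenation in the Decoding view (NumPy; the sizes 512 and 64 follow DeepSeek-style MLA, but the code itself is an illustrative assumption, and the actual rotation is elided with comments marking where RoPE would be applied):

```python
import numpy as np

dc, dr, h, T = 512, 64, 4, 8                 # NoPE content dims, RoPE dims, heads, length
rng = np.random.default_rng(0)
C      = rng.standard_normal((T, dc))        # shared NoPE cache c_i
K_rope = rng.standard_normal((T, dr))        # shared rotary key per token (RoPE applied here)
q_nope = rng.standard_normal((h, T, dc))     # per-head queries with W_k absorbed
q_rope = rng.standard_normal((h, T, dr))     # per-head rotary queries (RoPE applied likewise)

# Per-head score = NoPE content term + shared low-dimensional positional term;
# the cache stores only dc + dr = 576 numbers per token, independent of h.
scores = np.einsum('htd,sd->hts', q_nope, C) + np.einsum('htd,sd->hts', q_rope, K_rope)
print(scores.shape)                          # (h, T, T)
```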
From the perspective of MQA, MLA adds a rank=128 LoRA to Q; from the perspective of MHA, MLA adds a rank=512 LoRA to K and V. It can be said that MLA is an extreme "magic show" combining NoPE with LoRA, and MHA with MQA, successfully achieving a "two-way rush" between Prefill and Decoding.
Of course, the reasoning above is oversimplified in places. Actual training and inference involve many detailed factors, so it is not entirely accurate to reduce everything to head_dims and the KV Cache: for instance, MQA cannot use Tensor Parallelism (TP) during the Decoding stage, which may introduce new efficiency issues. We also did not pay special attention to aligning parameter counts: at head_dims=128, one could instead make the Q, K, V projections more complex to improve performance, rather than necessarily increasing head_dims; and so on.
In short, these two articles aim to provide some experiments and reflections to demonstrate the optimality of MLA within a certain range. Of course, MLA was first proposed by DeepSeek, and third-party use of MLA always gives a sense of copying DeepSeek. However, until a better variant appears or serious flaws are discovered, MLA remains a very competitive choice. If one avoids MLA simply to show that they are not "following" DeepSeek, that would be a rather unwise choice.
For example, hybrids of Linear Attention and Softmax Attention currently also look very competitive. But if we mix Linear Attention with the GQA8-128 used by LLaMA in a 3:1 ratio, the KV Cache only drops to roughly 1/4 of GQA8-128's, whereas MLA by itself already reduces the KV Cache to about 1/4 of GQA8-128's.
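The back-of-the-envelope arithmetic behind those 1/4 figures (per token and per layer, counting cached numbers; the MLA figure assumes a DeepSeek-style 512-dimensional latent plus a 64-dimensional shared RoPE key, and the constant-size state of the Linear Attention layers is ignored):

```python
gqa8_128 = 8 * 128 * 2        # 8 KV heads, head_dims=128, K and V: 2048 numbers/token/layer
mla      = 512 + 64           # shared latent c_i + shared rotary key: 576 numbers/token/layer
hybrid   = gqa8_128 / 4       # 3:1 linear:softmax mix keeps a full cache in 1/4 of the layers

print(mla / gqa8_128)         # ~0.28, i.e. roughly 1/4 of GQA8-128
print(hybrid / gqa8_128)      # 0.25
```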
Supplementary Discussion
We have been discussing MHA, GQA, MQA, and MLA. In this section, let’s briefly talk about two Attention variants that are less frequently mentioned: TPA and MFA.
TPA stands for Tensor Product Attention. Its authors gave it the rather "intimidating"-sounding name Tensor Product, but it is really an intermediate form between GQA and MLA. Taking a target KV Cache of 512 as an example, TPA first projects the input to a 512-dimensional vector, reshapes it into (4, 128), and then splits it into two (2, 128) halves serving as the K Cache and V Cache respectively. Up to this point, TPA's approach is consistent with GQA2-128.
Next, TPA borrows MLA's idea of re-projecting the (2, 128) K/V into multiple heads, but it does not project the entire vector as MLA does. Instead, it projects along the axis of size 2: simply put, it takes the two 128-dimensional vectors and forms a different linear combination of them for each head. Obviously, TPA's upper limit falls short of MLA, which projects directly from the full 512-dimensional vector. To alleviate this, TPA makes the combination coefficients data-dependent to enhance the expressive power of K and V, but even so, the author still believes its upper limit is below MLA's.
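An illustrative sketch of this construction (NumPy; the names, shapes, and the way the coefficients are produced are simplified assumptions, and details such as separate coefficients for K and V or RoPE on the head factors are omitted):

```python
import numpy as np

T, d, h, rank, hd = 8, 64, 16, 2, 128              # rank=2 K vectors and 2 V vectors per token
rng = np.random.default_rng(0)
X = rng.standard_normal((T, d))

Wkv = rng.standard_normal((d, 2 * rank * hd)) / np.sqrt(d)
cache = (X @ Wkv).reshape(T, 2, rank, hd)          # 512 cached numbers per token...
K_cache, V_cache = cache[:, 0], cache[:, 1]        # ...split into (2, 128) K and (2, 128) V

# Per-head K/V are linear combinations along the rank-2 axis, with coefficients
# that depend on x_i (here produced by a small extra projection).
Wa = rng.standard_normal((d, h * rank)) / np.sqrt(d)
A = (X @ Wa).reshape(T, h, rank)                   # data-dependent combination coefficients
K_heads = np.einsum('thr,trd->thd', A, K_cache)    # (T, h, 128): each head lives in a rank-2 span
V_heads = np.einsum('thr,trd->thd', A, V_cache)    # (MLA instead projects the full 512 dims)
```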
Why was TPA designed this way? Largely to be compatible with RoPE, which is its biggest "advantage" over MLA. The quotes are deliberate: in a setting where Partial RoPE is no worse and possibly better, being fully compatible with RoPE feels a bit ironic. TPA's design also closes off the option of increasing head_dims: if head_dims were raised to 256, the K Cache and V Cache would each have shape (1, 256), and a single vector leaves no freedom for forming linear combinations.
Now let's look at MFA, which stands for Multi-matrix Factorization Attention. The name also sounds a bit "intimidating," but it is really an MQA with Q-LoRA and head_dims=256. Does this configuration look familiar? It matches the conclusion of the previous article almost exactly: increase head_dims to 256 to improve MQA performance, keep the KV Cache close to MLA's, and control the parameter count through Q-LoRA.
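A minimal sketch of that configuration (NumPy; the sizes and names are illustrative assumptions rather than MFA's exact settings): a single shared K/V head with head_dims=256, and a low-rank factorization of the Q projection.

```python
import numpy as np

T, d, h, hd, q_rank = 8, 1024, 16, 256, 512
rng = np.random.default_rng(0)
X = rng.standard_normal((T, d))

Wq_a = rng.standard_normal((d, q_rank)) / np.sqrt(d)             # Q-LoRA down-projection
Wq_b = rng.standard_normal((q_rank, h * hd)) / np.sqrt(q_rank)   # Q-LoRA up-projection
Wk   = rng.standard_normal((d, hd)) / np.sqrt(d)                 # single shared K head
Wv   = rng.standard_normal((d, hd)) / np.sqrt(d)                 # single shared V head

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

Q = (X @ Wq_a @ Wq_b).reshape(T, h, hd)   # d*q_rank + q_rank*h*hd params vs d*h*hd for a full Wq
K, V = X @ Wk, X @ Wv                     # KV Cache per token: 2*hd = 512 numbers
O = np.einsum('hts,sd->thd', softmax(np.einsum('thd,sd->hts', Q, K)), V).reshape(T, -1)
```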
Therefore, it does not surprise the author that MFA can rival MLA; we experimented with a similar configuration in the previous article. In addition, the previous article proposed two further directions for improving MQA: one is Partial RoPE, mentioned many times here, and the other is achieving complete KV sharing through QKVO-RoPE, turning MQA into GQA2-256. With these two additions, MFA should be able to improve a bit further.
Article Summary
Based on the experimental results of the previous article, this article provides a theoretical thinking process to demonstrate the optimality of MLA within a certain range. Overall, in the context of Partial RoPE, MLA seems to be a very difficult Attention variant to surpass.