English (unofficial) translations of posts at kexue.fm
Source

The Tug-of-War Between Cache and Performance: From MHA, MQA, GQA to MLA

Translated by Gemini Flash 3.0 Preview. Translations can be inaccurate, please refer to the original post for important stuff.

A few days ago, the release of DeepSeek-V2 by DeepSeek sparked heated discussion. The most shocking aspect was, first of all, the price: 1 RMB per 1 million tokens, roughly two orders of magnitude cheaper than existing competing APIs. This led some to joke, “At this price, even if it outputs gibberish, I would consider that gibberish a form of art.” Secondly, according to the model’s technical report, one of the key technologies behind such a low price is the newly proposed MLA (Multi-head Latent Attention), an improvement over GQA that is claimed to be both more economical and more effective, which has also drawn widespread attention from readers.

In the following, we will trace the evolution from MHA, MQA, and GQA to MLA, focusing on the design philosophy behind MLA.

MHA

MHA (Multi-Head Attention) is the form of attention proposed in the seminal work “Attention is All You Need.” It can be said to be the foundational work for current mainstream LLMs. Mathematically, MHA is equivalent to the concatenation of multiple independent single-head attentions. Assuming the input (row) vector sequence is \boldsymbol{x}_1, \boldsymbol{x}_2, \dots, \boldsymbol{x}_l, where \boldsymbol{x}_i \in \mathbb{R}^d, MHA can be formally denoted as:

\begin{equation} \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \dots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = \text{Attention}\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{k}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d\times d_k} \\ \boldsymbol{v}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d\times d_v} \end{gathered} \end{equation}

For simplicity, the scaling factor of the Attention matrix is omitted here. In practice, common settings are d_k = d_v = d / h. For Llama2-7b, d=4096, h=32, d_k = d_v = 128; for Llama2-70b, d=8192, h=64, d_k = d_v = 128.
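To make the notation concrete, here is a minimal numpy sketch of the MHA formula above. It is illustrative only: the function name, the toy shapes, and the omission of the scaling factor are my own choices, not code from any particular model.

```python
import numpy as np

def mha(x, Wq, Wk, Wv):
    """Causal multi-head attention matching the formula above.
    x: (L, d); Wq/Wk/Wv: lists of h per-head matrices of shape (d, d_k) or (d, d_v)."""
    L, _ = x.shape
    future = np.triu(np.ones((L, L), dtype=bool), 1)     # mask out positions i > t
    outs = []
    for Wq_s, Wk_s, Wv_s in zip(Wq, Wk, Wv):
        q, k, v = x @ Wq_s, x @ Wk_s, x @ Wv_s           # (L, d_k), (L, d_k), (L, d_v)
        scores = np.where(future, -np.inf, q @ k.T)      # scaling factor omitted, as in the text
        w = np.exp(scores - scores.max(-1, keepdims=True))
        w /= w.sum(-1, keepdims=True)
        outs.append(w @ v)                               # o_t^{(s)} for all t
    return np.concatenate(outs, axis=-1)                 # concat over heads: (L, h*d_v)

# toy shapes: d = 64, h = 4, d_k = d_v = d // h = 16
rng = np.random.default_rng(0)
d, h, dk = 64, 4, 16
x = rng.standard_normal((10, d))
heads = lambda: [rng.standard_normal((d, dk)) for _ in range(h)]
print(mha(x, heads(), heads(), heads()).shape)  # (10, 64)
```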

Since we only consider the Causal Attention used in mainstream autoregressive LLMs, during token-by-token recursive generation, the newly predicted (t+1)-th token does not affect the already computed \boldsymbol{k}_{\leq t}^{(s)}, \boldsymbol{v}_{\leq t}^{(s)}. Therefore, these results can be cached for subsequent generation steps to avoid unnecessary redundant computation. This is known as the KV Cache.
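As a sketch of what caching buys us, the following hypothetical single-head decode step appends the new k, v to the cache and reuses everything computed earlier, so only one new row of attention is evaluated per generated token. Names and shapes are illustrative assumptions, not any model's actual implementation.

```python
import numpy as np

def decode_step(x_new, cache_k, cache_v, Wq, Wk, Wv):
    """One autoregressive step for a single head.
    x_new: (d,) hidden state of the newest token; cache_k: (t, d_k); cache_v: (t, d_v)."""
    cache_k = np.vstack([cache_k, x_new @ Wk])   # append k_{t+1}; older rows are untouched
    cache_v = np.vstack([cache_v, x_new @ Wv])   # append v_{t+1}
    q = x_new @ Wq                               # (d_k,)
    scores = cache_k @ q                         # (t+1,): one new row of the attention matrix
    w = np.exp(scores - scores.max())
    w /= w.sum()
    out = w @ cache_v                            # (d_v,)
    return out, cache_k, cache_v                 # the caches are carried to the next step
```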

The subsequent MQA, GQA, and MLA are all products developed around the theme of “how to reduce the KV Cache while maintaining performance as much as possible.”

The Bottleneck

A natural question is: Why is reducing the size of the KV Cache so important?

As is well known, LLM inference is generally performed on GPUs. The video memory (VRAM) of a single GPU is limited. One part is used to store model parameters and forward computation activations; this part depends on the model size and is a constant once the model is chosen. The other part is used to store the model’s KV Cache. This part depends not only on the model size but also on the input length of the model—meaning it grows dynamically during the inference process. When the context length is long enough, its size becomes dominant and may exceed the total VRAM of a single card or even a single machine (8 cards).
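A quick back-of-the-envelope calculation shows how fast this grows. It assumes a Llama2-7b-like configuration (32 layers, h=32, d_k=128 as quoted above) and fp16 caches; the exact numbers are only illustrative.

```python
# KV Cache size for an MHA model: 2 (K and V) x layers x heads x head_dim x bytes.
layers, h, d_k, bytes_per_value = 32, 32, 128, 2      # Llama2-7b-like, fp16
per_token = 2 * layers * h * d_k * bytes_per_value
print(per_token / 2**20, "MiB per token")             # 0.5 MiB
print(per_token * 4096 / 2**30, "GiB at 4k context")  # 2.0 GiB for a single sequence
```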

The principle of deploying models on GPUs is: if it can be deployed on one card, do not span multiple cards; if it can be deployed on one machine, do not span multiple machines. This is because “intra-card communication bandwidth > inter-card communication bandwidth > inter-machine communication bandwidth.” Due to the “bottleneck effect,” the more devices a model spans during deployment, the more it is “dragged down” by the communication bandwidth between devices. In fact, even though the bandwidth between SRAM and HBM within a single H100 card has reached 3TB/s, this speed is still the bottleneck for inference even for short contexts, let alone the slower inter-card and inter-machine communication.

Therefore, the purpose of reducing the KV Cache is to enable inference on longer contexts with fewer devices, or to allow for a larger batch size at the same context length, thereby achieving faster inference speeds or higher total throughput. Ultimately, the goal is to achieve lower inference costs.

To understand this issue in more detail, readers can further read “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, “A guide to LLM inference and performance”, and “LLM inference speed of light.” I will not expand further here (mainly because my own level is limited, and I fear saying more might lead to more errors).

MQA

MQA, or Multi-Query Attention, is a very straightforward attempt to reduce the KV Cache. It was first proposed in “Fast Transformer Decoding: One Write-Head is All You Need.” This is already a 2019 paper, which means that long before LLMs became popular, reducing the KV Cache was already a topic of great concern to researchers.

The idea of MQA is simple: directly let all Attention Heads share the same K and V. In terms of formulas, this means removing the superscript {}^{(s)} from all \boldsymbol{k} and \boldsymbol{v} in MHA:

\begin{equation} \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \dots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = \text{Attention}\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{\cancel{(s)}} ,\boldsymbol{v}_{\leq t}^{\cancel{(s)}}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\cancel{(s)}\top}\right)\boldsymbol{v}_i^{\cancel{(s)}}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\cancel{(s)}\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{k}_i^{\cancel{(s)}} = \boldsymbol{x}_i\boldsymbol{W}_k^{\cancel{(s)}}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{\cancel{(s)}}\in\mathbb{R}^{d\times d_k} \\ \boldsymbol{v}_i^{\cancel{(s)}} = \boldsymbol{x}_i\boldsymbol{W}_v^{\cancel{(s)}}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{\cancel{(s)}}\in\mathbb{R}^{d\times d_v} \end{gathered} \end{equation}

Models using MQA include PaLM, StarCoder, and Gemini. Obviously, MQA directly reduces the KV Cache to 1/h of its original size, which is very significant. From the perspective of saving VRAM alone, this is already the ceiling.
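Below is a minimal numpy sketch of MQA under the same toy setup as before (illustrative names and shapes): every head keeps its own W_q, while a single k, v pair is computed, cached, and reused by all heads.

```python
import numpy as np

def mqa(x, Wq_heads, Wk, Wv):
    """Causal MQA: per-head queries, one shared K and V."""
    L, _ = x.shape
    k, v = x @ Wk, x @ Wv                            # the only tensors that need caching
    future = np.triu(np.ones((L, L), dtype=bool), 1)
    outs = []
    for Wq_s in Wq_heads:
        scores = np.where(future, -np.inf, (x @ Wq_s) @ k.T)
        w = np.exp(scores - scores.max(-1, keepdims=True))
        w /= w.sum(-1, keepdims=True)
        outs.append(w @ v)
    return np.concatenate(outs, axis=-1)             # KV Cache: 1/h of the MHA version
```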

In terms of performance, it currently appears that the loss on most tasks is relatively limited, and supporters of MQA believe this loss can be compensated for through further training. Additionally, note that because MQA shares K and V, the number of parameters in the Attention mechanism is reduced by nearly half. To keep the total number of model parameters constant, the scale of the FFN/GLU is usually increased accordingly, which can also compensate for some performance loss.

GQA

However, some worry that MQA compresses the KV Cache too aggressively, which might hurt the model’s learning efficiency and final performance. To address this, GQA (Grouped-Query Attention) emerged as a transitional version between MHA and MQA. It comes from the paper “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,” published in 2023.

In hindsight, the idea of GQA is also very simple: it divides all the Heads into g groups (where g divides h), and each group shares the same pair of K and V. Mathematically, this is expressed as:

\begin{equation} \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \dots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = \text{Attention}\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(\lceil sg/h\rceil)} ,\boldsymbol{v}_{\leq t}^{(\lceil sg/h\rceil)}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(\lceil sg/h\rceil)\top}\right)\boldsymbol{v}_i^{(\lceil sg/h\rceil)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(\lceil sg/h\rceil)\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{k}_i^{(\lceil sg/h\rceil)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(\lceil sg/h\rceil)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(\lceil sg/h\rceil)}\in\mathbb{R}^{d\times d_k} \\ \boldsymbol{v}_i^{(\lceil sg/h\rceil)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(\lceil sg/h\rceil)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(\lceil sg/h\rceil)}\in\mathbb{R}^{d\times d_v} \end{gathered} \end{equation}

Here \lceil\cdot\rceil is the ceiling function. GQA provides a natural transition from MHA to MQA: when g=h it is MHA, and when g=1 it is MQA. For 1 < g < h, it compresses the KV Cache to g/h of its original size. The compression rate is not as high as MQA’s, but it offers greater flexibility and a better guarantee of performance. The best-known users of GQA are probably Meta’s open-source Llama2-70B and the entire Llama3 series. Other models using GQA include TigerBot, DeepSeek-V1, StarCoder2, Yi, ChatGLM2, and ChatGLM3, a longer list than the models using MQA (although ChatGLM2/3 describe themselves as using MQA, they are actually GQA with g=2).
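The corresponding numpy sketch only changes how a head picks its K/V pair; the ceiling-function head-to-group mapping from the formula becomes an integer division in 0-indexed code. Again, names and shapes are illustrative.

```python
import numpy as np

def gqa(x, Wq_heads, Wk_groups, Wv_groups):
    """Causal GQA: h query heads, g shared K/V groups (g divides h)."""
    L, _ = x.shape
    h, g = len(Wq_heads), len(Wk_groups)
    ks = [x @ W for W in Wk_groups]                  # g cached K tensors
    vs = [x @ W for W in Wv_groups]                  # g cached V tensors
    future = np.triu(np.ones((L, L), dtype=bool), 1)
    outs = []
    for s in range(h):
        grp = s * g // h                             # 0-indexed form of ceil(s*g/h) for s = 1..h
        scores = np.where(future, -np.inf, (x @ Wq_heads[s]) @ ks[grp].T)
        w = np.exp(scores - scores.max(-1, keepdims=True))
        w /= w.sum(-1, keepdims=True)
        outs.append(w @ vs[grp])
    return np.concatenate(outs, axis=-1)             # cache holds g pairs instead of h
```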

In Llama2/3-70B, g=8, and other models of similar scale that use GQA mostly keep this setting. This is not accidental; it is also motivated by inference efficiency. A 70B-scale model cannot be deployed on a single card (A100/H100 80G) without extreme quantization; if a single card is not enough, the next step is a single machine, which is generally equipped with 8 cards. As we just mentioned, each Attention Head is computed independently and then concatenated. With g=8, each card can handle exactly the Attention Heads corresponding to one pair of K and V, minimizing inter-card communication while preserving as much diversity in K and V as possible.

MLA

With the groundwork of MHA, MQA, and GQA, understanding MLA (Multi-head Latent Attention) becomes relatively easier. The DeepSeek-V2 technical report introduces MLA from the perspective of low-rank projection, leading some readers to ask questions like, “Why has it taken so long since LoRA was proposed for someone to apply low-rank decomposition to the KV Cache?”

However, I believe the low-rank projection perspective is not the most essential one. After all, if low-rank projection were the point, then simply stacking all of GQA’s K and V together shows that GQA itself is already equivalent to a low-rank projection:

\begin{equation} \underbrace{\left[\boldsymbol{k}_i^{(1)},\dots,\boldsymbol{k}_i^{(g)},\boldsymbol{v}_i^{(1)},\dots,\boldsymbol{v}_i^{(g)}\right]}_{\boldsymbol{c}_i\in\mathbb{R}^{g(d_k+d_v)}} = \boldsymbol{x}_i \underbrace{\left[\boldsymbol{W}_k^{(1)},\dots,\boldsymbol{W}_k^{(g)},\boldsymbol{W}_v^{(1)},\dots,\boldsymbol{W}_v^{(g)}\right]}_{\boldsymbol{W}_c\in\mathbb{R}^{d\times g(d_k+d_v)}} \end{equation}

Here we concatenate all \boldsymbol{k}_i^{(s)}, \boldsymbol{v}_i^{(s)} together and denote them as \boldsymbol{c}_i, and the corresponding projection matrices are also concatenated and denoted as \boldsymbol{W}_c. Note that generally d_c = g(d_k+d_v) < d, so the transformation from \boldsymbol{x}_i to \boldsymbol{c}_i is a low-rank projection. Therefore, the essential improvement of MLA is not the low-rank projection itself, but what happens after the low-rank projection.
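This equivalence is easy to verify numerically; the sketch below (toy dimensions chosen so that g(d_k+d_v) < d, illustrative names) just checks that stacking the per-group projections into one W_c gives the same c_i as concatenating the individual k_i, v_i.

```python
import numpy as np

rng = np.random.default_rng(0)
d, g, dk, dv = 256, 4, 16, 16
x = rng.standard_normal((1, d))
Wk = [rng.standard_normal((d, dk)) for _ in range(g)]
Wv = [rng.standard_normal((d, dv)) for _ in range(g)]

Wc = np.concatenate(Wk + Wv, axis=1)                       # (d, g*(dk+dv)), here 256 x 128
c = x @ Wc                                                 # the stacked "latent" c_i
kv = np.concatenate([x @ W for W in Wk + Wv], axis=1)      # [k^(1),...,k^(g),v^(1),...,v^(g)]
print(np.allclose(c, kv))                                  # True: GQA's cache is a projection of x
```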

Part 1

What does GQA do after projection? First, it splits the vector in half to serve as K and V, then each half is further divided into g parts, and each part is copied h/g times to “make up” the K and V needed for the h Attention Heads. We know that splitting and copying are simple linear transformations. So, the first idea of MLA is to replace these simple linear transformations with general linear transformations to enhance the model’s capability:

\begin{equation} \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \dots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = \text{Attention}\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_c\times d_k} \\ \boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt] \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \end{gathered} \end{equation}

However, while this theoretically increases the model’s capability, remember that the main purpose of GQA is to reduce the KV Cache. To save on computation and communication costs, we generally cache the projected \boldsymbol{k}_i, \boldsymbol{v}_i rather than the pre-projection \boldsymbol{c}_i or \boldsymbol{x}_i. In this MLA approach, using different projection matrices makes all K and V Heads different again, so the KV Cache size returns to being as large as MHA, defeating the original purpose of GQA.

To address this, MLA discovered that we can combine the specific form of Dot-Product Attention and use a simple but clever identity transformation to circumvent this problem. First, training proceeds as usual, where there isn’t much room for optimization. Then, during the inference stage, we utilize:

\begin{equation} \boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)\top} = \left(\boldsymbol{x}_t\boldsymbol{W}_q^{(s)}\right) \left(\boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\right)^{\top} = \boldsymbol{x}_t\left(\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)\top}\right)\boldsymbol{c}_i^{\top} \end{equation}

This means that during the inference stage, we can merge \boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)\top} into a single projection matrix for Q. Then \boldsymbol{c}_i replaces the original \boldsymbol{k}_i. Similarly, there is another projection matrix after \boldsymbol{o}_t, so the \boldsymbol{W}_v^{(s)} in \boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)} can also be absorbed into the subsequent projection matrix. Thus, equivalently, \boldsymbol{v}_i can also be replaced by \boldsymbol{c}_i. In other words, the KV Cache only needs to store all \boldsymbol{c}_i, rather than storing all \boldsymbol{k}_i^{(s)} and \boldsymbol{v}_i^{(s)}. Note that \boldsymbol{c}_i is independent of {}^{(s)}, meaning it is shared across all heads. Thus, during inference, MLA can be transformed via an identity mapping into an MQA.
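Both absorptions are just associativity of matrix multiplication, and can be checked numerically. The sketch below uses toy dimensions; in a real implementation the merging would be done once offline.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_c, d_k = 64, 32, 16
x  = rng.standard_normal((1, d))       # x_t, the current query-side input
c  = rng.standard_normal((5, d_c))     # cached latents c_{<=t}
Wq = rng.standard_normal((d, d_k))
Wk = rng.standard_normal((d_c, d_k))
Wv = rng.standard_normal((d_c, d_k))

# query/key side: (x W_q)(c W_k)^T == x (W_q W_k^T) c^T
scores_train = (x @ Wq) @ (c @ Wk).T
scores_infer = (x @ (Wq @ Wk.T)) @ c.T              # attend directly to the cached c_i
print(np.allclose(scores_train, scores_infer))      # True (up to floating-point error)

# value side: W_v commutes with the attention-weighted sum, so it can be moved
# after attention and absorbed into the subsequent output projection.
w = rng.random(5); w /= w.sum()                     # some attention weights over 5 positions
print(np.allclose(w @ (c @ Wv), (w @ c) @ Wv))      # True
```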

To emphasize again, the theme of this article is reducing the KV Cache. So far, what has MLA achieved? The answer is that it enhances the capability of GQA through different projection matrices while maintaining the same KV Cache size during inference. Conversely, if we only need capability similar to GQA, can we reduce the KV Cache even further? In other words, d_c doesn’t need to be g(d_k+d_v); it can take a smaller value (DeepSeek-V2 chose 512), thereby further compressing the KV Cache. This is the core idea of MLA.

Supplementary Notes:

1. The identity transformation of merging \boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)\top} into one matrix theoretically only holds under infinite precision. In practice, if we use single precision, especially BF16, the precision loss after transformation is often quite noticeable and may accumulate across multiple layers to a significant degree.

2. In practice, we generally do not calculate Q as \boldsymbol{x}_t\left(\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)\top}\right), but rather as \left(\boldsymbol{x}_t\boldsymbol{W}_q^{(s)}\right)\boldsymbol{W}_k^{(s)\top}. Although this is serial, the amount of computation is less under the low-rank assumption, and the theoretical precision loss is also smaller. However, in this article, we still introduce it as merging \boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)\top} into one matrix.
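A rough multiply-accumulate count per query head and per token illustrates the point; the dimensions below are the DeepSeek-V2 values quoted later in this article (d=5120, d_c=512, d_k=128), used here purely for illustration.

```python
# One query head, one token, multiply-accumulates counted once per output element:
d, d_c, d_k = 5120, 512, 128
merged = d * d_c                     # x_t @ (W_q W_k^T): a (1, d) x (d, d_c) product
serial = d * d_k + d_k * d_c         # (x_t @ W_q) @ W_k^T: (1, d) x (d, d_k), then (1, d_k) x (d_k, d_c)
print(merged, serial)                # 2621440 vs 720896: the serial route is cheaper
```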

Part 2

Everything seems perfect, and it looks like an ideal design that is both good and efficient is about to emerge. But wait, when we think deeper, we find that the MLA described so far has a difficult-to-avoid flaw—it is incompatible with RoPE (Rotary Positional Embedding).

As we just said, the key step for MLA to maintain a KV Cache size similar to GQA is “merging \boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)\top} into a single (position-independent) matrix as the projection matrix for Q.” But if RoPE is added, this step cannot be realized. This is because RoPE is a position-dependent d_k \times d_k block-diagonal matrix \boldsymbol{\mathcal{R}}_m that satisfies \boldsymbol{\mathcal{R}}_m\boldsymbol{\mathcal{R}}_n^{\top}=\boldsymbol{\mathcal{R}}_{m-n}. After adding RoPE to MLA, an extra term \boldsymbol{\mathcal{R}}_{t-i} is inserted between \boldsymbol{W}_q^{(s)} and \boldsymbol{W}_k^{(s)\top}:

\begin{equation} \begin{gathered} \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\textcolor[rgb]{0.235,0.886,0.969}{\boldsymbol{\mathcal{R}}_i}\quad,\quad\boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\textcolor[rgb]{0.235,0.886,0.969}{\boldsymbol{\mathcal{R}}_i} \\ \boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)\top} = \left(\boldsymbol{x}_t\boldsymbol{W}_q^{(s)}\textcolor[rgb]{0.235,0.886,0.969}{\boldsymbol{\mathcal{R}}_t}\right) \left(\boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\textcolor[rgb]{0.235,0.886,0.969}{\boldsymbol{\mathcal{R}}_i}\right)^{\top} = \boldsymbol{x}_t\left(\boldsymbol{W}_q^{(s)}\textcolor[rgb]{0.235,0.886,0.969}{\boldsymbol{\mathcal{R}}_{t-i}}\boldsymbol{W}_k^{(s)\top}\right)\boldsymbol{c}_i^{\top} \end{gathered} \end{equation}

Here \boldsymbol{W}_q^{(s)}\textcolor[rgb]{0.235,0.886,0.969}{\boldsymbol{\mathcal{R}}_{t-i}}\boldsymbol{W}_k^{(s)\top} cannot be merged into a fixed projection matrix (as it depends on the position difference t-i), so the MLA idea cannot be implemented in conjunction with RoPE.
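The block-diagonal structure makes this easy to see on a single 2x2 block. The sketch below (with an illustrative rope_block helper of my own) checks the identity R_m R_n^T = R_{m-n} and notes why the merged matrix ends up depending on t-i.

```python
import numpy as np

def rope_block(m, theta=0.1):
    """One 2x2 block of the (row-vector) RoPE rotation matrix R_m."""
    c, s = np.cos(m * theta), np.sin(m * theta)
    return np.array([[c, s], [-s, c]])

m, n = 7, 3
print(np.allclose(rope_block(m) @ rope_block(n).T, rope_block(m - n)))  # True: R_m R_n^T = R_{m-n}

# Consequently W_q R_{t-i} W_k^T is a different matrix for every offset t-i,
# so there is no single position-independent matrix to merge into the Q projection.
```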

Some time ago, I had the honor of discussing this issue with the DeepSeek team, but this problem is very fundamental, so at the time, I couldn’t actually offer any effective suggestions. The simplest way is to abandon RoPE and use other attention-bias-based positional encodings like ALiBi, but DeepSeek’s experiments showed it was significantly inferior to RoPE (note that MLA is not unable to use RoPE, but after adding RoPE, it cannot use the identity transformation trick to reduce the KV Cache). I also suggested using Sandwich; it doesn’t decay monotonically to negative infinity like ALiBi, so the effect might be better, but it felt like treating the symptoms rather than the root cause. Another compromise is to change the input of \boldsymbol{q}_i to \boldsymbol{c}_i as well, and then add RoPE after \boldsymbol{c}_i, i.e.,

\begin{equation} \boldsymbol{q}_i^{(s)} = \boldsymbol{c}_i\textcolor[rgb]{0.235,0.886,0.969}{\boldsymbol{\mathcal{R}}_i}\boldsymbol{W}_q^{(s)},\quad\boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\textcolor[rgb]{0.235,0.886,0.969}{\boldsymbol{\mathcal{R}}_i}\boldsymbol{W}_k^{(s)} \end{equation}

In this way, \boldsymbol{\mathcal{R}}_i can be absorbed into \boldsymbol{c}_i, but then there is no \boldsymbol{\mathcal{R}}_m\boldsymbol{\mathcal{R}}_n^{\top}=\boldsymbol{\mathcal{R}}_{m-n} operation. At this point, RoPE no longer achieves relative positions through absolute positions, but simply adds absolute positions to Q and K, letting the model find its own way to extract relative position information.

The final released MLA adopts a hybrid approach: it appends d_r extra dimensions to the Q and K of each Attention Head to carry RoPE, with the appended K dimensions shared across all heads:

\begin{equation} \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \dots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = \text{Attention}\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right) \triangleq \frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \left[\boldsymbol{x}_i\boldsymbol{W}_{qc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{qr}^{(s)}\textcolor[rgb]{0.235,0.886,0.969}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k + d_r},\quad \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d\times d_r}\\ \boldsymbol{k}_i^{(s)} = \left[\boldsymbol{c}_i\boldsymbol{W}_{kc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\cancel{(s)}}\textcolor[rgb]{0.235,0.886,0.969}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k+d_r},\quad \boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k}, \boldsymbol{W}_{kr}^{\cancel{(s)}}\in\mathbb{R}^{d\times d_r} \\ \boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt] \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \end{gathered} \end{equation}

In this way, the dimensions without RoPE can repeat the operations of “Part 1.” During inference, the KV Cache only needs to store \boldsymbol{c}_i. The newly added dimensions with RoPE can be used to supplement positional information. Since they are shared across all heads, only d_r dimensions are added to the K Cache. The original paper took d_r = d_k / 2 = 64, which is a small increase compared to the original d_c=512.
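Putting the pieces together, the per-token cache now consists of c_i plus the shared d_r-dimensional RoPE key part, regardless of the number of heads. The sketch below uses toy dimensions in place of DeepSeek-V2's (d=5120, d_c=512, d_k=128, d_r=64) and an illustrative pairwise-rotation rope helper of my own; it is meant only to show what gets cached.

```python
import numpy as np

def rope(v, pos, base=10000.0):
    """Apply a pairwise RoPE rotation to a 1-D vector v (even length) at position pos."""
    pairs = v.reshape(-1, 2)
    freqs = base ** (-np.arange(pairs.shape[0]) / pairs.shape[0])
    cos, sin = np.cos(pos * freqs), np.sin(pos * freqs)
    x1, x2 = pairs[:, 0], pairs[:, 1]
    return np.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1).reshape(v.shape)

rng = np.random.default_rng(0)
d, d_c, d_k, d_r = 64, 16, 8, 4                     # toy stand-ins for 5120 / 512 / 128 / 64
x_i = rng.standard_normal(d)
W_c = rng.standard_normal((d, d_c))
W_qc, W_qr = rng.standard_normal((d, d_k)), rng.standard_normal((d, d_r))
W_kc, W_kr = rng.standard_normal((d_c, d_k)), rng.standard_normal((d, d_r))

c_i = x_i @ W_c                                     # cached: d_c values per token
k_rope_i = rope(x_i @ W_kr, pos=5)                  # cached: d_r values, shared by every head
q_i = np.concatenate([x_i @ W_qc, rope(x_i @ W_qr, pos=5)])   # per head, not cached
k_i = np.concatenate([c_i @ W_kc, k_rope_i])        # per head, rebuilt from the cache as needed
print("cache per token:", c_i.size + k_rope_i.size, "values, independent of the number of heads")
```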

Part 3

Finally, there is a detail: the final version of MLA also changes the input of Q to a low-rank projection form. This is unrelated to reducing the KV Cache and is mainly to reduce the VRAM occupied by the parameter count and corresponding gradients (the original paper mentions activations, which I personally don’t quite understand) during training:

\begin{equation} \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = \text{Attention}\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \left[\boldsymbol{c}_i'\boldsymbol{W}_{qc}^{(s)}, \boldsymbol{c}_i'\boldsymbol{W}_{qr}^{(s)}{\color[HTML]{3CE2F7}\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k + d_r},\quad \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c'\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c'\times d_r}\\ \boldsymbol{k}_i^{(s)} = \left[\boldsymbol{c}_i\boldsymbol{W}_{kc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color[HTML]{CCCCCC}{\smash{\bcancel{(s)}}}}{\color[HTML]{3CE2F7}\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k+d_r},\quad \boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k}, \boldsymbol{W}_{kr}^{\color[HTML]{CCCCCC}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\ \boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt] \boldsymbol{c}_i' = \boldsymbol{x}_i \boldsymbol{W}_c'\in\mathbb{R}^{d_c'},\quad \boldsymbol{W}_c'\in\mathbb{R}^{d\times d_c'} \\ \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\ \end{gathered} \label{eq:mla-mha} \end{equation}

Note that in the second term of \boldsymbol{k}_i^{(s)}, the part with RoPE, the input is still \boldsymbol{x}_i rather than \boldsymbol{c}_i. This follows the setting of the original paper and is not a typo. The value of d_c' in the original paper is 1536, which is different from d_c=512. At the same time, we place the MHA with RoPE below for easy comparison:

\begin{equation} \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = \text{Attention}\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}{\color[HTML]{3CE2F7}\boldsymbol{\mathcal{R}}_i}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{k}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(s)}{\color[HTML]{3CE2F7}\boldsymbol{\mathcal{R}}_i}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d\times d_k} \\ \boldsymbol{v}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d\times d_v} \end{gathered} \end{equation}

It can be observed that during the training phase, except for the additional low-rank projection step and applying RoPE only to partial dimensions, MLA is basically the same as MHA where the Head Size of Q and K is changed from d_k to d_k + d_r.

In the decoding phase, MLA is converted into an MQA form:

\begin{equation} \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}\boldsymbol{W}_v^{(1)}, \boldsymbol{o}_t^{(2)}\boldsymbol{W}_v^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\boldsymbol{W}_v^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = \text{Attention}\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{\color[HTML]{CCCCCC}{\smash{\bcancel{(s)}}}} ,\boldsymbol{c}_{\leq t}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color[HTML]{CCCCCC}{\smash{\bcancel{(s)}}}}{}^{\top}\right)\boldsymbol{c}_i}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color[HTML]{CCCCCC}{\smash{\bcancel{(s)}}}}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \left[\boldsymbol{c}_i'\boldsymbol{W}_{qc}^{(s)}\boldsymbol{W}_{kc}^{(s)}{}^{\top}, \boldsymbol{c}_i'\boldsymbol{W}_{qr}^{(s)}{\color[HTML]{3CE2F7}\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_c + d_r}\\ \boldsymbol{k}_i^{\color[HTML]{CCCCCC}{\smash{\bcancel{(s)}}}} = \left[\boldsymbol{c}_i, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color[HTML]{CCCCCC}{\smash{\bcancel{(s)}}}}{\color[HTML]{3CE2F7}\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_c+d_r}\\ \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c'\times d_k},\boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c'\times d_r},\boldsymbol{W}_{kr}^{\color[HTML]{CCCCCC}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\[10pt] \boldsymbol{c}_i' = \boldsymbol{x}_i \boldsymbol{W}_c'\in\mathbb{R}^{d_c'},\quad \boldsymbol{W}_c'\in\mathbb{R}^{d\times d_c'} \\ \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\ \end{gathered} \label{eq:mla-mqa} \end{equation}

At this point, the Head Size of Q and K becomes d_c + d_r, and the Head Size of V becomes d_c. According to the original paper’s settings, this is 4 times the size of d_k and d_v. Therefore, although this transformation performed by MLA during the decoding phase effectively reduces the KV Cache, the computational cost of decoding actually increases.

So why can it still improve inference efficiency? This goes back to the issues discussed in “The Bottleneck” section above. We can divide LLM inference into two parts: the generation of the first token (Prefill) and the generation of each subsequent token (Generation). The Prefill phase computes over all input tokens in parallel and then stores the corresponding KV Cache; this part is constrained by compute, bandwidth, and memory alike, and we can use the MHA form of MLA [eq:mla-mha] for the calculation. The Generation phase, however, computes only one token per step, so it is primarily bandwidth- and memory-bound. Here we can switch to the MQA form of MLA [eq:mla-mqa], thereby significantly improving Generation speed.

Another detail fully reflects this characteristic. In general LLM architectures, the parameters satisfy h \times d_k = d, i.e., num_heads * head_size = hidden_size. DeepSeek-V2 is different: it has d_k=128 and d=5120, yet h=128, roughly 3 times the usual setting of h = d / d_k = 40. This is because the size of the MLA KV Cache is independent of h: increasing h only increases the computational load and the model’s capability without increasing the KV Cache, so it does not create a new speed bottleneck.
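A tiny calculation makes the contrast explicit, using d_k=128, d_c=512, and d_r=64 from the article; the per-layer element counts are purely illustrative.

```python
# Per-token, per-layer cached values: MHA scales with h, MLA does not.
d_k, d_c, d_r = 128, 512, 64
for h in (40, 128):                                  # 40 = 5120/128 "usual" vs DeepSeek-V2's 128
    print(f"h={h:3d}  MHA: {2 * h * d_k:6d}  MLA: {d_c + d_r}")
# h= 40  MHA:  10240  MLA: 576
# h=128  MHA:  32768  MLA: 576
```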

Summary

This article has given a brief overview of the evolution of multi-head attention, in particular the chain of ideas leading from MHA to MQA and GQA and finally to MLA, followed by a more detailed introduction to MLA itself. Here, MLA is viewed as a generalization of GQA that replaces GQA’s splitting and repetition with general projection matrices, introduces an identity-transformation trick to further compress the KV Cache, and adopts a hybrid design to remain compatible with RoPE. Overall, MLA is a very practical attention variant.

Decoupled RoPE

A final challenge remains: the integration of RoPE (Rotary Positional Embedding). RoPE is a position-dependent linear transformation. If we apply RoPE to the up-projected keys k_n = c_n^{KV} W_{UK}, the associative property used to absorb W_{UK} into W_{UQ} is broken. Specifically, the rotation matrix R_{t-n} becomes "trapped" between the two projection matrices: q_t k_n^\top = (c_t^Q W_{UQ}) R_{t-n} (c_n^{KV} W_{UK})^\top. In this case, W_{UK} cannot be merged with W_{UQ} into a single fixed matrix because of the intervening, position-dependent R_{t-n}.

To solve this, DeepSeek-V2 introduces a Decoupled RoPE strategy. The query and key are split into two distinct parts: a "content" part and a "position" part.

  • Content Part: This part undergoes the low-rank compression and matrix absorption described earlier. It does not use RoPE.

  • Position Part: This part is dedicated to carrying positional information via RoPE and is not compressed in the same way.

The final formulation for head h, with query position t and key position n, is: \begin{align*} q_{t,h} &= [ \underbrace{c_t^Q W_{UQ,h}}_{\text{Content}}, \underbrace{\text{RoPE}(c_t^Q W_{QR,h})}_{\text{Position}} ] \\ k_{n,h} &= [ \underbrace{c_n^{KV} W_{UK,h}}_{\text{Content}}, \underbrace{\text{RoPE}(x_n W_{KR})}_{\text{Position}} ] \end{align*} Note that, as in the formulation above, the positional part of the key is computed from the raw input x_n rather than from the compressed latent, and it is shared across all heads to further minimize the KV cache. This decoupled design lets the model enjoy the benefits of KV compression while keeping the performance gains provided by rotary embeddings.

Comparison of KV Cache Efficiency

The primary motivation for MLA is to relieve the KV cache bottleneck during inference. The following table compares the KV cache requirements (number of elements stored per token, per layer) for different attention architectures:

KV Cache Size Comparison

Mechanism            KV Cache Size (per token, per layer)
MHA                  2 \cdot n_{heads} \cdot d_{head}
MQA                  2 \cdot d_{head}
GQA                  2 \cdot n_{groups} \cdot d_{head}
MLA (DeepSeek-V2)    d_{latent} + d_{rope}

In the case of DeepSeek-V2, d_{latent} = 512 and d_{rope} = 64. This results in a KV cache significantly smaller than that of Multi-Head Attention (MHA) and even Grouped-Query Attention (GQA), while, according to the paper, the model’s performance even exceeds that of standard MHA configurations.

Conclusion

The evolution from MHA to MQA, GQA, and finally MLA represents a continuous effort to navigate the "tug-of-war" between model performance and inference efficiency. By rethinking the attention mechanism through the lens of low-rank matrix factorization and decoupled positional embeddings, MLA (Multi-Head Latent Attention) provides a path forward for scaling Large Language Models without being crippled by the memory overhead of the KV cache.

For more technical details, readers are encouraged to refer to the original DeepSeek-V2 paper: DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model.