Recently, I came across a paper on arXiv titled "Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention". The experimental phenomena described in it align closely with some observations we made while training Kimi K2, such as issues starting from the second layer of Attention. The paper attributes this to the inherent biased errors in low-precision Attention. This analytical perspective was quite unexpected to me, so I read it with great interest.
However, the paper’s presentation is somewhat difficult to follow, partly because I am not very familiar with low-precision arithmetic myself. In any case, after consulting the authors several times, I managed to grasp the essence of the paper, and I have recorded my understanding here for everyone’s reference.
Brief Conclusion
It should be noted that although the paper’s title mentions "Flash Attention," according to the description, the same problem would occur even if the block_size were as large as the training sequence length. Therefore, the block-wise calculation of Flash Attention is not the cause of the problem. We can simplify the analysis by considering a naive low-precision Attention implementation.
For simplicity, let’s analyze single-head Attention. Let \bm{Q}, \bm{K}, \bm{V} \in \mathbb{R}^{n \times d}, and let \bm{S} = \bm{Q}\bm{K}^{\top}. Let the bold \bm{1} denote an n \times 1 matrix of ones, and \bm{S}_{\max} denote the n \times 1 matrix obtained by taking the maximum value of each row of \bm{S}. Then: \begin{equation} \bm{O} = \frac{\exp(\bm{S})\bm{V}}{\exp(\bm{S})\bm{1}} = \frac{\exp(\bm{S} - \bm{S}_{\max})\bm{V}}{\exp(\bm{S}- \bm{S}_{\max})\bm{1}} \end{equation} We denote \bar{\bm{P}} = \exp(\bm{S} - \bm{S}_{\max}). The key calculation in Attention is the matrix multiplication \bar{\bm{P}}\bm{V}, which is generally performed in BF16 precision. The conclusion given by the paper is: In low-precision calculations, the step \bar{\bm{P}}\bm{V} contains a biased rounding error. That is to say, in the long-term average, the expectation of the difference between the low-precision calculation of \bar{\bm{P}}\bm{V} and the exact value is not zero.
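To fix the notation, here is a minimal single-head sketch in PyTorch (my own illustration, not code from the paper): it verifies that subtracting \bm{S}_{\max} row-wise leaves \bm{O} unchanged in exact arithmetic, and marks where the BF16 product \bar{\bm{P}}\bm{V} enters.

```python
import torch

torch.manual_seed(0)
n, d = 8, 4
Q, K, V = (torch.randn(n, d) for _ in range(3))

S = Q @ K.T                                     # S = Q K^T, shape (n, n)
S_max = S.max(dim=-1, keepdim=True).values      # row-wise maxima, the n x 1 matrix S_max

# The two expressions in the equation above agree in exact arithmetic:
O_plain = (torch.exp(S) @ V) / torch.exp(S).sum(dim=-1, keepdim=True)
P_bar = torch.exp(S - S_max)                    # P-bar = exp(S - S_max)
O_stable = (P_bar @ V) / P_bar.sum(dim=-1, keepdim=True)
assert torch.allclose(O_plain, O_stable, atol=1e-6)

# In training, the product P-bar @ V is computed with BF16 inputs
# (FP32 accumulation); this is where the rounding error discussed below enters:
PV_bf16 = P_bar.bfloat16().float() @ V.bfloat16().float()
```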
Consequently, the bias between different training steps may continuously accumulate, leading to problems such as MaxLogit explosion and Loss Spikes, until the training collapses. Of course, strictly speaking, this is only one possible mechanism for issues like MaxLogit explosion, not necessarily the only one, but it is still worth studying and reflecting upon.
Round-to-Even
To understand the paper’s conclusion, let’s first review some basics of rounding error. As mentioned at the beginning, I am not very familiar with low-precision arithmetic myself, so this section is mostly my own foundational notes; readers who are already familiar with this can skip it.
We know that the most common rounding method is "rounding half up" (standard rounding). In base 10, suppose a positive number with one decimal digit is rounded to the nearest integer. If the decimal digit is 0–4, it is simply dropped, producing errors of 0, -0.1, -0.2, -0.3, -0.4; if it is 5–9, the number is rounded up (carrying into the units digit), producing errors of +0.5, +0.4, +0.3, +0.2, +0.1. Notice that the average of these errors is not 0 but +0.05. That is, "rounding half up" on average tends to increase the original number, creating a positive bias.
Of course, the relative bias decreases as the number of discarded digits increases. For example, if a 2-decimal number is rounded to an integer, the average error is 0.005. Regardless, this positive bias in standard rounding always exists; it just varies in magnitude. The root of the bias lies at the midpoint. For instance, 0.51 and 0.49 round up and down respectively, and their errors cancel out. But for 0.50, whether it is rounded up or down, there is no other number to cancel its error.
To eliminate this bias, IEEE 754 adopted the "Round-to-Even" rule: in the midpoint case, round towards the nearest even number. For example, 2.5 rounded to the nearest integer becomes 2, but 3.5 becomes 4. The midpoint digit "5" then rounds up or down roughly half the time each (depending on whether the digit before it is odd or even), so the \pm 0.5 errors cancel on average and the bias is eliminated.
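As a quick sanity check of the two rules in base 10, here is a self-contained sketch using exact fractions (the function names are my own):

```python
import math
from fractions import Fraction

def round_half_up(x: Fraction) -> int:
    # "rounding half up": the midpoint always goes up
    return math.floor(x + Fraction(1, 2))

def round_half_even(x: Fraction) -> int:
    # "Round-to-Even": the midpoint goes to the nearest even integer
    lo = math.floor(x)
    frac = x - lo
    if frac < Fraction(1, 2):
        return lo
    if frac > Fraction(1, 2):
        return lo + 1
    return lo if lo % 2 == 0 else lo + 1

values = [Fraction(k, 10) for k in range(100)]   # 0.0, 0.1, ..., 9.9
for rule in (round_half_up, round_half_even):
    avg_err = sum(rule(v) - v for v in values) / len(values)
    print(rule.__name__, float(avg_err))
# round_half_up    0.05  -> the positive bias described above
# round_half_even  0.0   -> unbiased (Python's built-in round() on floats
#                           already uses this rule)
```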
Back to computers. Computers use binary, which has only 0 and 1, so "1" plays the role that "5" plays in base 10. The bias of "rounding half up" is even more apparent in binary when a single trailing bit is dropped: if that bit is 0, the value is unchanged; if it is 1, the value sits exactly at the midpoint and triggers a "round up," carrying into the bit above. Rounding a binary number this way therefore always yields a value greater than or equal to the original. Hence, "Round-to-Even" is also required here to eliminate the bias.
BF16 Addition
Next, let’s review the BF16 format. BF16 uses 16 binary bits to represent a floating-point number: 1 sign bit, 7 mantissa bits, and 8 exponent bits. The 8-bit exponent design allows it to have the same range as FP32 (1 sign, 23 mantissa, 8 exponent), which has made it the primary floating-point format for LLM training today.
BF16 retains more exponent bits at the cost of fewer mantissa bits, resulting in lower precision. To mitigate the cumulative errors caused by low precision, BF16 arithmetic employs a "FP32 accumulation" strategy. This means that the addition of BF16 numbers involves first converting them to FP32, adding them in FP32 space to get an FP32 result, and finally converting it back to BF16.
Now consider the addition of two BF16 numbers with the same sign and exponent. Why choose the same exponent for analysis? Because we want to estimate the error, and the same exponent means the two numbers are of the same order of magnitude, where addition is most likely to produce the maximum error. For example, if one number is 100 times larger than the other, even if I just return the larger one, the error is only 1%. Thus, the maximum error often occurs when adding numbers of the same magnitude.
When two BF16 numbers with the same sign and exponent are added, a carry will inevitably occur. For example, 1.0000001 + 1.0000100 = 10.0000101 = 1.00000101 \times 10 (in binary). At this point, the exponent needs to increase by 1, and the last bit "1" must be discarded to convert back to BF16. As described in the previous section, if we follow "rounding half up," a positive bias will be generated. However, as we know, scientists discovered this bias long ago and proposed "Round-to-Even" to eliminate it.
Two Large, One Small
So far, everything is within control and expected; no bias has been generated. However, as the saying goes, if something can go wrong, it will.
Now let’s consider adding three numbers of the same sign. These three numbers have a specific characteristic: two of them have the same large exponent, and the third is very small. For example, based on our previous example 1.0000001 + 1.0000100, let’s add 0.0000000001. We get 1.0000001 + 1.0000100 + 0.0000000001 = 10.0000101001 = 1.00000101001 \times 10. Originally, adding the two numbers resulted in 1.00000101 \times 10. When discarding the last bit, "Round-to-Even" would be triggered, resulting in 1.0000010 \times 10. But now, with the addition of an extremely small number, the mantissa to be discarded when converting to BF16 becomes 1001, which is greater than the midpoint. This triggers the "round up" rule, resulting in 1.0000011 \times 10. From the perspective of the original two-number addition, the presence of the third tiny number has broken the "Round-to-Even" rule, causing the positive bias to reappear!
Of course, the conditions for this situation seem quite stringent. First, the three numbers must have the same sign. Second, they must satisfy the "two large, one small" condition, where the two large numbers just happen to trigger a carry, and the small number is small enough to only affect the FP32 mantissa (i.e., the 9th to 23rd mantissa bits). In this way, the small number itself is so tiny that discarding it wouldn’t cause much error, but its existence happens to break the "Round-to-Even" rule of the two large numbers, thereby bringing about a one-sided bias.
Tailored for Attention
Can such stringent conditions actually occur in practice? Generally, it’s not easy, but for Attention, it seems like a "tailored" bug!
Let’s take a specific row and column (i.e., an element) of \bar{\bm{P}}\bm{V}. It can be written as: \begin{equation} \sum_{i=1}^n \bar{p}_i v_i \label{eq:sum-pi-vi} \end{equation} where \bar{p}_i = \exp(s_i - \max_j s_j) \leq 1. We know that the characteristic of Softmax Attention is its ability to "concentrate attention," meaning that attention may be focused on a few tokens. In terms of \bar{p}_i, this means the \bar{p}_i of a few tokens are close to 1, while the rest are very close to 0, though they cannot be exactly 0 because of the \exp function (unless they underflow the BF16 representable range).
As layers stack and training progresses, the input \bm{V} may exhibit "anisotropy." One manifestation is that the distribution of positive and negative signs in certain dimensions becomes uneven. Without loss of generality, assume most v_i are positive (the same applies to negative) and of similar magnitude. Then, the sum in [eq:sum-pi-vi] can be divided into two parts: a few main terms where \bar{p}_i is close to 1 multiplied by v_i, and the remaining terms where most \bar{p}_i are close to 0 multiplied by v_i.
The paper considers a special case: the \bar{p}_i of the main terms are not merely close to 1 but exactly equal to 1, i.e., some rows of \bm{S} have multiple maxima. This special case is harder to satisfy but easier to analyze: the main terms \bar{p}_i v_i = v_i are then exactly BF16 numbers, carrying no precision beyond BF16. Thus, all the conditions for triggering the bug described in the previous section are met:
(1) most terms are positive; (2) the main terms have only BF16 precision and their sum triggers a carry; (3) the remaining terms are extremely small, touching only the tail of the FP32 mantissa, which breaks "Round-to-Even" and produces a bias; (4) because attention is concentrated, the number of main terms is small, so there are not many carries (more discarded bits would reduce the bias), keeping the bias at a significant magnitude!
This combination is essentially a "proprietary bug" tailored for Attention.
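To see the bias statistically rather than on a single example, here is a small Monte-Carlo sketch of one output element \sum_i \bar{p}_i v_i under the conditions above (two \bar{p}_i equal to 1, all v_i positive and of similar size; the construction is my own toy setup, not the paper's experiment):

```python
import torch

torch.manual_seed(0)
n, trials, bias = 4096, 2000, 0.0

for _ in range(trials):
    p = torch.full((n,), 1e-9)           # "remaining terms": tiny but nonzero
    p[:2] = 1.0                          # two main terms with p-bar exactly 1
    v = 0.5 + 0.5 * torch.rand(n)        # positive v_i of similar magnitude

    prods = p.bfloat16().float() * v.bfloat16().float()  # BF16 inputs, products exact in FP32
    acc = prods.double().sum().float()   # the FP32 accumulator (emulated with an exact sum)
    out = acc.bfloat16().float()         # the output is finally stored in BF16
    bias += (out - acc).item()           # error introduced by that final rounding

print("average signed rounding error:", bias / trials)
# Comes out clearly positive (on the order of 1e-3 in this toy setup); with
# Round-to-Even working as intended it would hover around zero.
```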
Eliminating the Remainder
After understanding the cause and effect, let’s think about how to solve the problem.
On the surface, the bias is caused by tiny remainders breaking "Round-to-Even." However, thinking deeper, the root cause is that the "rounding half up" rule has a mutation point at the midpoint, where small perturbations easily cause bias. While "Round-to-Even" eliminates the bias, it doesn’t eliminate the mutation point. The ideal radical cure is Stochastic Rounding, which rounds up or down probabilistically, thus avoiding bias from small perturbations to the greatest extent.
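For intuition, stochastic rounding is easy to emulate in software with a standard bit trick (a sketch of my own, not something the paper or current matmul hardware provides; NaN/Inf and values near the exponent limit are not handled):

```python
import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    # Reinterpret the FP32 bits, add uniform noise below the 16 bits that will
    # be truncated, then truncate: each value rounds up or down with probability
    # proportional to its distance from the two neighbouring BF16 values.
    bits = x.float().contiguous().view(torch.int32)
    noise = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32)
    rounded = (bits + noise) & -65536          # keep only the high 16 bits
    return rounded.view(torch.float32).to(torch.bfloat16)

x = torch.full((100000,), 2.0400390625)        # the "two large, one small" sum from above
print(stochastic_round_to_bf16(x).float().mean().item())  # ≈ 2.0400 on average: unbiased
print(x.bfloat16().float().mean().item())                  # 2.046875: the biased deterministic result
```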
However, it is said that Stochastic Rounding is difficult to implement efficiently at the hardware level, so most current hardware matrix multiplication operators do not support it. Therefore, the original paper chose another path, which I call "eliminating the remainder." Specifically, when a certain trigger condition is detected, the Attention formula is modified to: \begin{equation} \bm{O} = \frac{\exp(\bm{S})\bm{V}}{\exp(\bm{S})\bm{1}} = \frac{\exp(\bm{S} - \beta\bm{S}_{\max})\bm{V}}{\exp(\bm{S}- \beta\bm{S}_{\max})\bm{1}} \end{equation} where \beta > 1. In this way, each term is divided by an additional \exp((\beta-1)\bm{S}_{\max}), which is a non-negligible number (the paper sets \beta \geq 2). Consequently, the already tiny remainders are more likely to underflow to zero and disappear, allowing "Round-to-Even" to function correctly and eliminate the bias.
What is the detection condition? The original paper’s approach is simple: the modification is triggered when a row in matrix \bm{S} has two or more maximum values, meaning there are at least two 1s in \bar{p}_i. However, I believe there is significant room for adjustment here, which remains a direction for improvement. Additionally, since Flash Attention is calculated in blocks, this detection and modification are also performed per block; details can be found in the code in the original paper’s appendix.
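Putting the detection condition and the \beta rescaling together, here is a naive (non-Flash) sketch of the idea; it is my own simplification based on the description above, not the paper's per-block code, which lives in the appendix:

```python
import torch

def attention_remainder_elimination(Q, K, V, beta=2.0):
    # When a row of S has >= 2 entries equal to its maximum, subtract beta * S_max
    # instead of S_max for that row, so the near-zero exp() terms underflow and can
    # no longer break Round-to-Even. Assumes the triggered rows have S_max > 0 so
    # that the extra factor exp((beta - 1) * S_max) actually shrinks the tail terms.
    S = Q.float() @ K.float().T                          # (n, n) logits
    S_max = S.max(dim=-1, keepdim=True).values           # (n, 1) row maxima
    ties = (S == S_max).sum(dim=-1, keepdim=True) >= 2   # detection condition
    scale = 1.0 + (beta - 1.0) * ties.float()            # beta on triggered rows, 1 otherwise
    P = torch.exp(S - scale * S_max)
    num = P.bfloat16().float() @ V.bfloat16().float()    # BF16 inputs, FP32 accumulation
    den = P.bfloat16().float().sum(dim=-1, keepdim=True)
    return (num / den).bfloat16()
```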
Further Reflections
Overall, the paper provides a unique perspective for understanding phenomena like MaxLogit explosion. It explains some things but does not cover everything, leaving many points for reflection.
First, the analysis of Attention bias relies on the anisotropy of \bm{V}. This might explain why MaxLogit explosion and other anomalies only appear starting from the second layer of Attention: the input to the first layer is the Embedding, which is relatively less likely to be anisotropic; however, inputs to the second and subsequent layers have passed through previous Attention layers and may inherently become anisotropic (reference).
However, this does not explain why MaxLogit explosion only occurs in specific layers (e.g., the paper observes it only in the 2nd layer, while K2 sees it in layers 2–4). Similarly, it doesn’t explain why Muon is more prone to MaxLogit explosion than Adam (observed in Moonlight and K2). Therefore, this is likely a combined result of architecture, optimizer, and low precision; looking at precision alone is incomplete.
Furthermore, there is a deeper question of causality. Another condition for the Attention bias is that attention be concentrated on a few tokens, and intervening in the Attention calculation at that point does prevent the subsequent anomalies. However, in a normally trained small model I observed that attention was not as concentrated as one might imagine (e.g., at a sequence length of 4096, the average Top-1 probability was below 0.2, and the cumulative probability only reached 0.9 at around the Top-400 tokens).
So, is Attention bias the "cause" or the "effect" of training collapse? In other words, when "attention concentrates on a few tokens," does it mean the model has already entered a collapse range? Is intervening at that point "too late"? For instance, while it might prevent some anomalies in metrics, is it possible the model can no longer scale? These questions remain unanswered for now.
Summary
This article shared a paper analyzing the bias in low-precision Attention calculations and took the opportunity to review the basics of low-precision arithmetic.
When reprinting, please include the original address: https://kexue.fm/archives/11371
For more details on reprinting, please refer to: "Scientific Space FAQ"