In the previous two articles, Muon Implementation Based on Streaming Power Iteration: 1. First Encounter (https://kexue.fm/archives/11654) and Muon Implementation Based on Streaming Power Iteration: 2. Acceleration (https://kexue.fm/archives/11673), we introduced the streaming power iteration implementation of Muon and preliminarily verified its feasibility. We then discussed how to accelerate its core operation, the QR decomposition, which brought its efficiency close to that of the Newton-Schulz iteration implementation.
In this article, we no longer limit ourselves to optimizing a single QR decomposition. Instead, we look at streaming power iteration as a whole and further “refine” its implementation details for the specific computational context. The aim is to remove computational bottlenecks wherever possible and push its efficiency toward the theoretical limit.
Existing Results
Streaming power iteration is essentially “SVD while training”. It computes the SVD via power iteration and, by caching the previous step’s result, amortizes the computation over training steps. This makes it feasible to embed an SVD inside the optimizer. Muon is just one basic application, because the most fundamental implementation of its core operation \mathop{\text{msign}} is precisely an SVD. Specifically, the Muon update formula is \begin{aligned} \boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt] \boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t [\mathop{\text{msign}}(\boldsymbol{M}_t) + \lambda \boldsymbol{W}_{t-1}] \\ \end{aligned} where all matrices are of size n\times m and we assume n\geq m. Let the reduced SVD of \boldsymbol{M} be \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} (where \boldsymbol{U}\in\mathbb{R}^{n\times m} and \boldsymbol{\Sigma},\boldsymbol{V}\in\mathbb{R}^{m\times m}). Then \mathop{\text{msign}}(\boldsymbol{M})=\boldsymbol{U}\boldsymbol{V}^{\top}, so implementing the SVD implements \mathop{\text{msign}}. A direct SVD is usually expensive, but streaming power iteration makes it feasible.
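For reference, here is a minimal NumPy sketch of the update formula above. It uses a plain SVD for \mathop{\text{msign}}, which is exactly the expensive baseline that streaming power iteration replaces. The helper names `msign_svd` and `muon_step` are hypothetical. The default values of `lr`, `beta` and `wd` (standing for \eta_t, \beta, \lambda) are placeholders, not recommendations.

```python
import numpy as np


def msign_svd(M):
    """msign(M) = U V^T from the reduced SVD M = U diag(s) V^T (assumes full column rank)."""
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt


def muon_step(W, M_prev, G, lr=0.02, beta=0.95, wd=0.01):
    """One Muon step following the update formula; lr, beta, wd stand for eta_t, beta, lambda.

    The default values are hypothetical placeholders.
    """
    M = beta * M_prev + G                     # M_t = beta M_{t-1} + G_t
    W = W - lr * (msign_svd(M) + wd * W)      # W_t = W_{t-1} - eta_t [msign(M_t) + lambda W_{t-1}]
    return W, M


rng = np.random.default_rng(0)
n, m = 8, 4
W0 = rng.standard_normal((n, m))
G = rng.standard_normal((n, m))
W1, M1 = muon_step(W0, np.zeros((n, m)), G)
print(np.linalg.svd(msign_svd(M1), compute_uv=False))  # all singular values equal 1 (up to rounding)
```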
In the previous article, we also discussed four ideas for speeding up streaming power iteration. The first, enabling full-precision FP32 multiplication, applies universally. The other three are to some extent mutually exclusive, so only one of them can be chosen. The author recommends the second one, which has a higher theoretical ceiling, and the refinements below also build on it. Plugging the second idea into the streaming power iteration version of Muon, the iteration becomes \begin{aligned} \boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt] \boldsymbol{V}_t =&\, \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})) \\[5pt] \boldsymbol{U}_t =&\, \mathop{\text{ColNorm}}(\boldsymbol{M}_t\boldsymbol{V}_t) \\[5pt] \boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t (\boldsymbol{U}_t\boldsymbol{V}_t^{\top} + \lambda \boldsymbol{W}_{t-1}) \\ \end{aligned} where \mathop{\text{ColNorm}} normalizes each column to unit length. Clearly, the most expensive step is now \boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})), which is exactly the target of our next optimization.
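To make the iteration concrete, here is a minimal NumPy sketch that uses the built-in QR (not yet SCQR). It runs the \boldsymbol{V}_t and \boldsymbol{U}_t updates repeatedly on a fixed \boldsymbol{M} to mimic the steady state, and checks that \boldsymbol{U}_t\boldsymbol{V}_t^{\top} converges to \mathop{\text{msign}}(\boldsymbol{M}). The test matrix with singular values 4, 3, 2, 1 and the helper names `col_norm` and `streaming_step` are assumptions made for illustration. In real training, \boldsymbol{M}_t changes every step and only one iteration is run per step.

```python
import numpy as np


def col_norm(X):
    """ColNorm: scale every column of X to unit Euclidean length."""
    return X / np.linalg.norm(X, axis=0, keepdims=True)


def streaming_step(M, V_prev):
    """One streaming power-iteration step using the built-in QR.

    Returns (U_t, V_t) such that U_t @ V_t.T approximates msign(M_t).
    """
    Q1, _ = np.linalg.qr(M @ V_prev)   # QR(M_t V_{t-1})
    V, _ = np.linalg.qr(M.T @ Q1)      # V_t = QR(M_t^T QR(M_t V_{t-1}))
    U = col_norm(M @ V)                # U_t = ColNorm(M_t V_t)
    return U, V


rng = np.random.default_rng(0)
n, m = 8, 4
# Hypothetical test matrix with well-separated singular values 4, 3, 2, 1
P, _ = np.linalg.qr(rng.standard_normal((n, m)))
O, _ = np.linalg.qr(rng.standard_normal((m, m)))
M = P @ np.diag([4.0, 3.0, 2.0, 1.0]) @ O.T

V = np.eye(m)  # cold start; during training V is the cached result of the previous step
for _ in range(100):
    U, V = streaming_step(M, V)

U_svd, _, Vt_svd = np.linalg.svd(M, full_matrices=False)
err = np.abs(U @ V.T - U_svd @ Vt_svd).max()
print(err)
```

Column signs returned by QR are arbitrary, but flipping a column of \boldsymbol{V}_t flips the matching column of \boldsymbol{U}_t, so \boldsymbol{U}_t\boldsymbol{V}_t^{\top} is unaffected.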
Accelerating the Decomposition
To ensure efficiency, we do not call the framework’s built-in QR decomposition for \mathop{\text{QR}} here. Instead we use “SCQR (Shifted Cholesky QR)”, which splits the QR decomposition of a matrix \boldsymbol{A} into two steps: 1. compute the Cholesky decomposition of \boldsymbol{A}^{\top}\boldsymbol{A} + \lambda \boldsymbol{I}, i.e., an upper triangular \boldsymbol{R} with \boldsymbol{R}^{\top}\boldsymbol{R} = \boldsymbol{A}^{\top}\boldsymbol{A} + \lambda \boldsymbol{I}; 2. solve the triangular system \boldsymbol{Q}\boldsymbol{R}=\boldsymbol{A} for the (approximately) column-orthogonal matrix \boldsymbol{Q}.
Both steps are efficient in theory, but SCQR does not always succeed. An extra check is therefore needed: if SCQR fails, we fall back to the built-in standard QR, which almost always succeeds. However, every fallback significantly reduces end-to-end efficiency. SCQR fails mainly because the Cholesky decomposition is extremely sensitive to the condition number of its input. The regularization term \lambda\boldsymbol{I} is there precisely to reduce the condition number of \boldsymbol{A}^{\top}\boldsymbol{A}.
This creates a dilemma. The larger \lambda is, the more easily SCQR succeeds, but the further the result deviates from orthogonality (i.e., the larger the error), which hurts performance. The smaller \lambda is, the higher the precision, but the more often we fall back to standard QR, which hurts efficiency. Empirical tests show that \lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F with \epsilon=10^{-9} strikes a good balance between effectiveness and efficiency.
The acceleration ideas in the previous article all revolve around reducing the condition number. The first version of streaming power iteration was \boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}). There, the matrix to be Cholesky-decomposed is \boldsymbol{V}_{t-1}^{\top}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t)^2\boldsymbol{V}_{t-1}, whose condition number is the fourth power of that of \boldsymbol{M}_t, an obvious blow-up. After switching to \boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})), \mathop{\text{QR}} is performed twice. However, the matrix fed to each Cholesky decomposition now has a condition number only the square of that of \boldsymbol{M}_t. This greatly raises the success rate of SCQR, so the overall speed actually increases.
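The two SCQR steps can be sketched in NumPy as follows. This sketch makes several substitutions, all of them assumptions for illustration. It uses NumPy instead of JAX. Because NumPy’s `cholesky` returns the lower factor \boldsymbol{L}, we take \boldsymbol{R}=\boldsymbol{L}^{\top}. It calls `np.linalg.solve` where a dedicated triangular solver would be used in practice. Finally, it catches NumPy’s `LinAlgError` to detect a failed Cholesky decomposition, whereas JAX signals failure by returning NaNs. The name `scqr_np` is hypothetical.

```python
import numpy as np


def scqr_np(A, eps=1e-9):
    """Shifted Cholesky QR (NumPy sketch); falls back to built-in QR if it fails.

    Returns (Q, R, used_fallback). The shift lambda = eps * ||A^T A||_F follows the text.
    """
    G = A.T @ A
    lam = eps * np.linalg.norm(G)  # Frobenius norm (default for 2-D input)
    try:
        # NumPy returns lower L with L L^T = G + lam I, so R = L^T is upper triangular
        R = np.linalg.cholesky(G + lam * np.eye(G.shape[0])).T
        # Q = A R^{-1}, i.e. solve R^T Q^T = A^T (general solve stands in for a triangular solve)
        Q = np.linalg.solve(R.T, A.T).T
        if np.isfinite(Q).all():
            return Q, R, False
    except np.linalg.LinAlgError:
        pass  # Cholesky failed: shifted Gram matrix not numerically positive definite
    Q, R = np.linalg.qr(A)
    return Q, R, True


rng = np.random.default_rng(0)
A = rng.standard_normal((64, 8))
Q, R, fell_back = scqr_np(A)
print(fell_back, np.abs(Q.T @ Q - np.eye(8)).max())
```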
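The size of the orthogonality error can be made precise with a short derivation. This is a sketch added for illustration, assuming exact arithmetic and using \boldsymbol{R}^{\top}\boldsymbol{R} = \boldsymbol{A}^{\top}\boldsymbol{A} + \lambda\boldsymbol{I} and \boldsymbol{Q} = \boldsymbol{A}\boldsymbol{R}^{-1} from SCQR. Here \sigma_{\min} denotes the smallest eigenvalue of \boldsymbol{A}^{\top}\boldsymbol{A}.

```latex
\begin{aligned}
\boldsymbol{Q}^{\top}\boldsymbol{Q} =&\, \boldsymbol{R}^{-\top}\boldsymbol{A}^{\top}\boldsymbol{A}\boldsymbol{R}^{-1}
= \boldsymbol{R}^{-\top}(\boldsymbol{R}^{\top}\boldsymbol{R} - \lambda\boldsymbol{I})\boldsymbol{R}^{-1}
= \boldsymbol{I} - \lambda(\boldsymbol{R}\boldsymbol{R}^{\top})^{-1} \\[5pt]
\Vert\boldsymbol{Q}^{\top}\boldsymbol{Q} - \boldsymbol{I}\Vert_2 =&\, \frac{\lambda}{\sigma_{\min} + \lambda}
\end{aligned}
```

The second line follows because \boldsymbol{R}\boldsymbol{R}^{\top} = \boldsymbol{R}(\boldsymbol{R}^{\top}\boldsymbol{R})\boldsymbol{R}^{-1} is similar to \boldsymbol{A}^{\top}\boldsymbol{A} + \lambda\boldsymbol{I}. As a result, \lambda(\boldsymbol{R}\boldsymbol{R}^{\top})^{-1} has eigenvalues \lambda/(\sigma_i + \lambda). The deviation from orthogonality therefore grows with \lambda and is worst along the direction of the smallest eigenvalue.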
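The condition-number argument can be checked numerically. The sketch below builds a hypothetical \boldsymbol{M}_t with singular values 10, 5, 2, 1 (so its condition number is 10) and takes an arbitrary orthogonal \boldsymbol{V}_{t-1}. It then compares the Cholesky inputs of the first version with the two Cholesky inputs of the second version. NumPy’s built-in QR stands in for SCQR here.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 16, 4
s = np.array([10.0, 5.0, 2.0, 1.0])               # singular values of M, so cond(M) = 10
P, _ = np.linalg.qr(rng.standard_normal((n, m)))
O, _ = np.linalg.qr(rng.standard_normal((m, m)))
M = P @ np.diag(s) @ O.T
V, _ = np.linalg.qr(rng.standard_normal((m, m)))  # an arbitrary orthogonal V_{t-1}

# First version: the Cholesky input is V^T (M^T M)^2 V
X = M.T @ M @ V
cond_v1 = np.linalg.cond(X.T @ X)

# Second version: two Cholesky inputs, each with condition number cond(M)^2
A1 = M @ V
Q1, _ = np.linalg.qr(A1)
A2 = M.T @ Q1
cond_a1 = np.linalg.cond(A1.T @ A1)
cond_a2 = np.linalg.cond(A2.T @ A2)
print(cond_v1, cond_a1, cond_a2)  # roughly 1e4, 1e2, 1e2
```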
Adjusting the Order
The above is still a recap of the previous two articles (sorry for the lengthy preamble, but sharpening the axe does not delay the cutting of firewood). From this section on, we discuss new optimization ideas. Looking closely, we have introduced two \mathop{\text{QR}} operations but treated them independently. However, @YouJiacheng (https://x.com/YouJiacheng) and @Kimi (https://www.kimi.com/) found that considering them jointly yields some acceleration tricks.
Following the default order, the computation flow of \boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})) is \begin{aligned} \boldsymbol{A}_{(1), t} =&\, \boldsymbol{M}_t\boldsymbol{V}_{t-1} \\ \boldsymbol{R}_{(1), t}^{\top}\boldsymbol{R}_{(1), t} =&\, \boldsymbol{A}_{(1), t}^{\top}\boldsymbol{A}_{(1), t} + \lambda \boldsymbol{I}\qquad(\text{Cholesky decomposition}) \\ \boldsymbol{Q}_{(1), t} =&\, \boldsymbol{A}_{(1), t} \boldsymbol{R}_{(1), t}^{-1} \qquad(\text{Triangular Solve}) \\ \boldsymbol{A}_{(2), t} =&\, \boldsymbol{M}_t^{\top}\boldsymbol{Q}_{(1), t} \\ \boldsymbol{R}_{(2), t}^{\top}\boldsymbol{R}_{(2), t} =&\, \boldsymbol{A}_{(2), t}^{\top}\boldsymbol{A}_{(2), t} + \lambda \boldsymbol{I}\quad(\text{Cholesky decomposition}) \\ \boldsymbol{Q}_{(2), t} =&\, \boldsymbol{A}_{(2), t} \boldsymbol{R}_{(2), t}^{-1} \qquad(\text{Triangular Solve}) \\ \end{aligned}\label{eq:qr2} Among these, the four steps \boldsymbol{M}_t\boldsymbol{V}_{t-1}, \boldsymbol{A}_{(1), t}^{\top}\boldsymbol{A}_{(1), t}, \boldsymbol{A}_{(1), t} \boldsymbol{R}_{(1), t}^{-1}, and \boldsymbol{M}_t^{\top}\boldsymbol{Q}_{(1), t} all have \mathcal{O}(nm^2) complexity, while the rest are \mathcal{O}(m^3). When n \gg m, \mathcal{O}(nm^2) may become a bottleneck. Interestingly, we can use an identity transformation to make \mathcal{O}(nm^2) appear only once! 
\begin{aligned} \boldsymbol{A}_{(1), t} =&\, (\boldsymbol{M}_t^{\top}\boldsymbol{M}_t)\boldsymbol{V}_{t-1} \\ \boldsymbol{R}_{(1), t}^{\top}\boldsymbol{R}_{(1), t} =&\, \boldsymbol{V}_{t-1}^{\top}\boldsymbol{A}_{(1), t} + \lambda \boldsymbol{I}\qquad(\text{Cholesky decomposition}) \\ \boldsymbol{A}_{(2), t} =&\, \boldsymbol{A}_{(1), t} \boldsymbol{R}_{(1), t}^{-1} \qquad(\text{Triangular Solve}) \\ \boldsymbol{R}_{(2), t}^{\top}\boldsymbol{R}_{(2), t} =&\, \boldsymbol{A}_{(2), t}^{\top}\boldsymbol{A}_{(2), t} + \lambda \boldsymbol{I}\qquad(\text{Cholesky decomposition}) \\ \boldsymbol{Q}_{(2), t} =&\, \boldsymbol{A}_{(2), t} \boldsymbol{R}_{(2), t}^{-1} \qquad(\text{Triangular Solve}) \\ \end{aligned}\label{eq:qr2-sim} This equivalent version is well worth savoring! First, it is mathematically exactly equivalent to the original version. The first Cholesky decomposition receives the same matrix, since \boldsymbol{V}_{t-1}^{\top}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t)\boldsymbol{V}_{t-1} = (\boldsymbol{M}_t\boldsymbol{V}_{t-1})^{\top}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}). The triangular solve then gives (\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1})\boldsymbol{R}_{(1), t}^{-1} = \boldsymbol{M}_t^{\top}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}\boldsymbol{R}_{(1), t}^{-1}) = \boldsymbol{M}_t^{\top}\boldsymbol{Q}_{(1), t}, which is exactly \boldsymbol{A}_{(2), t} in Equation [eq:qr2]. This equivalence does not depend on the strict orthogonality of \boldsymbol{V}_{t-1} or \boldsymbol{Q}_{(1), t}. Second, after the transformation only the product \boldsymbol{M}_t^{\top}\boldsymbol{M}_t costs \mathcal{O}(nm^2), and all remaining steps are \mathcal{O}(m^3). Moreover, the total number of steps drops by one, because the original \boldsymbol{Q}_{(1), t} = \boldsymbol{A}_{(1), t} \boldsymbol{R}_{(1), t}^{-1} and \boldsymbol{A}_{(2), t} = \boldsymbol{M}_t^{\top}\boldsymbol{Q}_{(1), t} merge into a single step!
Note 1: According to @YouJiacheng (https://x.com/YouJiacheng), Kimi discovered this ingenious transformation on its own after he showed it Equation [eq:qr2].
Note 2: There is a subtle difference between Equations [eq:qr2] and [eq:qr2-sim]. The input to the first Cholesky decomposition is computed as (\boldsymbol{M}_t\boldsymbol{V}_{t-1})^{\top}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}) in Equation [eq:qr2] but as \boldsymbol{V}_{t-1}^{\top}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t)\boldsymbol{V}_{t-1} in Equation [eq:qr2-sim]. The two are mathematically identical but differ under finite-precision floating-point arithmetic. The latter goes through \boldsymbol{M}_t^{\top}\boldsymbol{M}_t, whose condition number is already the square of that of \boldsymbol{M}_t, so it tends to accumulate more rounding error. As a result, the Cholesky decomposition may require slightly larger regularization.
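The claimed equivalence, including the fact that it does not need an orthogonal \boldsymbol{V}_{t-1}, can be checked directly. The NumPy sketch below implements both flows step by step. It deliberately uses a non-orthogonal \boldsymbol{V} and a fixed, hypothetical shift \lambda shared by both flows. The helper names are hypothetical, and `np.linalg.solve` stands in for a triangular solve. The comments mark which products cost \mathcal{O}(nm^2).

```python
import numpy as np


def chol_upper(S, lam):
    """Upper-triangular R with R^T R = S + lam * I (NumPy returns the lower factor)."""
    return np.linalg.cholesky(S + lam * np.eye(S.shape[0])).T


def right_solve(A, R):
    """Return A R^{-1} (general solve stands in for a triangular solve)."""
    return np.linalg.solve(R.T, A.T).T


def qr2_original(M, V, lam):
    A1 = M @ V                          # O(n m^2)
    R1 = chol_upper(A1.T @ A1, lam)     # O(n m^2) Gram matrix, O(m^3) Cholesky
    Q1 = right_solve(A1, R1)            # O(n m^2)
    A2 = M.T @ Q1                       # O(n m^2)
    R2 = chol_upper(A2.T @ A2, lam)     # A2 is m x m: O(m^3)
    return right_solve(A2, R2)


def qr2_reordered(M, V, lam):
    A1 = (M.T @ M) @ V                  # M^T M is the only O(n m^2) product
    R1 = chol_upper(V.T @ A1, lam)
    A2 = right_solve(A1, R1)            # equals M^T Q1 of the original flow
    R2 = chol_upper(A2.T @ A2, lam)
    return right_solve(A2, R2)


rng = np.random.default_rng(0)
M = rng.standard_normal((256, 8))
V = np.eye(8) + 0.1 * rng.standard_normal((8, 8))   # deliberately NOT orthogonal
lam = 1e-6                                           # hypothetical fixed shift
Q_a, Q_b = qr2_original(M, V, lam), qr2_reordered(M, V, lam)
print(np.abs(Q_a - Q_b).max())
```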
Simplifying the Regularization
For the matrix \boldsymbol{A}^{\top}\boldsymbol{A}, the regularization term we add before the Cholesky decomposition is \lambda\boldsymbol{I} with \lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F and \epsilon=10^{-9}, but we have not yet explained why it takes this form. Here we elaborate on it and, using the problem context, derive a simpler regularization term.
Since \boldsymbol{A}^{\top}\boldsymbol{A} is symmetric positive semi-definite, its SVD coincides with its eigendecomposition. Let it be \boldsymbol{A}^{\top}\boldsymbol{A} = \boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}; then \boldsymbol{A}^{\top}\boldsymbol{A} + \lambda\boldsymbol{I} = \boldsymbol{V}(\boldsymbol{\Sigma} + \lambda\boldsymbol{I})\boldsymbol{V}^{\top}. Denote the largest and smallest singular values of \boldsymbol{A}^{\top}\boldsymbol{A} by \sigma_{\max} and \sigma_{\min}. The corresponding values for \boldsymbol{A}^{\top}\boldsymbol{A} + \lambda\boldsymbol{I} are then \sigma_{\max} + \lambda and \sigma_{\min} + \lambda. The condition number is the ratio of the largest to the smallest singular value, so it drops from \sigma_{\max}/\sigma_{\min} to \frac{\sigma_{\max} + \lambda}{\sigma_{\min} + \lambda} \leq \frac{\sigma_{\max} + \lambda}{\lambda} = \frac{\sigma_{\max}}{\lambda} + 1 where the inequality uses \sigma_{\min}\geq 0. To keep the condition number below 1/\epsilon + 1, it therefore suffices to take \lambda \geq \epsilon \sigma_{\max}. In other words, ideally \lambda should be scaled by the largest singular value of \boldsymbol{A}^{\top}\boldsymbol{A}, i.e., its spectral norm. However, the spectral norm is relatively expensive to compute, so we switched to the simpler Frobenius norm \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F. The Frobenius norm is never smaller than the spectral norm, so \lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F still satisfies \lambda \geq \epsilon \sigma_{\max}. The price is possible over-regularization by up to a factor of \sqrt{m}, since \boldsymbol{A}^{\top}\boldsymbol{A} is m\times m. This is the origin of \lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F. As for \epsilon=10^{-9}, it is purely an empirical choice.
However, “the spectral norm is expensive, so use the Frobenius norm instead” is only a general-purpose conclusion for arbitrary matrices. Here, the streaming power iteration we are running is itself computing an SVD. As training progresses, \boldsymbol{V}_t gets closer and closer to the right singular matrix of \boldsymbol{M}_t. Since \boldsymbol{M}_t changes slowly, \boldsymbol{V}_{t-1} is also close to the right singular matrix of \boldsymbol{M}_t. Therefore, in theory, \boldsymbol{V}_{t-1}^{\top}\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1} becomes increasingly close to a diagonal matrix. Power iteration with QR orders the columns by dominance, with the first column converging to the leading singular vector. Consequently, the top-left element of this matrix becomes increasingly close to its spectral norm!
Similarly, \tilde{\boldsymbol{U}}_t = \mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}) gets closer and closer to the left singular matrix of \boldsymbol{M}_t. Therefore \tilde{\boldsymbol{U}}_t^{\top}\boldsymbol{M}_t\boldsymbol{M}_t^{\top}\tilde{\boldsymbol{U}}_t, the input to the second Cholesky decomposition, also becomes increasingly close to a diagonal matrix whose top-left element approaches its spectral norm. In our scenario, the simplest benchmark, which is also accurate in the steady state, is thus to use (\boldsymbol{A}^{\top}\boldsymbol{A})_{[0,0]} directly as an approximation of the spectral norm, i.e., \lambda=\epsilon \cdot (\boldsymbol{A}^{\top}\boldsymbol{A})_{[0,0]}. Empirical tests show that \epsilon=10^{-7} strikes a good balance between effectiveness and efficiency.
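A quick NumPy check of this reasoning, using a hypothetical positive semi-definite matrix `S` in the role of \boldsymbol{A}^{\top}\boldsymbol{A` with eigenvalues spanning 12 orders of magnitude. Shifting by \epsilon times the spectral norm keeps the condition number within 1/\epsilon + 1. Shifting by \epsilon times the Frobenius norm, which is at least as large, gives an even smaller condition number. The value \epsilon = 10^{-6} is chosen only so that the condition numbers stay comfortably computable in float64.

```python
import numpy as np

rng = np.random.default_rng(0)
m = 6
O, _ = np.linalg.qr(rng.standard_normal((m, m)))
sig = np.array([1.0, 0.5, 1e-3, 1e-6, 1e-9, 1e-12])  # hypothetical eigenvalues of A^T A
S = O @ np.diag(sig) @ O.T                          # plays the role of A^T A

eps = 1e-6
spec = np.linalg.norm(S, 2)  # spectral norm = sigma_max
fro = np.linalg.norm(S)      # Frobenius norm, never smaller than the spectral norm
cond_spec = np.linalg.cond(S + eps * spec * np.eye(m))
cond_fro = np.linalg.cond(S + eps * fro * np.eye(m))
print(cond_spec, cond_fro, 1 / eps + 1)
```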
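To illustrate why the top-left element is a good spectral-norm estimate here, the NumPy sketch below compares (\boldsymbol{V}^{\top}\boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{V})_{[0,0]} with \Vert\boldsymbol{M}^{\top}\boldsymbol{M}\Vert_2 for two cases. The first uses the exact right singular vectors. The second uses a slightly perturbed, re-orthonormalized \boldsymbol{V}, a hypothetical stand-in for \boldsymbol{V}_{t-1}. The Frobenius-based ratio is printed for contrast. The random \boldsymbol{M} and the 10^{-3} perturbation size are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 32, 6
M = rng.standard_normal((n, m))
_, s, Vt = np.linalg.svd(M, full_matrices=False)  # s is sorted in descending order
V_exact = Vt.T                                    # exact right singular vectors


def gram(M, V):
    """V^T M^T M V: the input to the first Cholesky decomposition in the reordered flow."""
    return V.T @ (M.T @ M) @ V


# Hypothetical previous-step V: exact singular vectors plus a small perturbation
V_prev, _ = np.linalg.qr(V_exact + 1e-3 * rng.standard_normal((m, m)))

spec = s[0] ** 2                                 # spectral norm of M^T M
top_exact = gram(M, V_exact)[0, 0] / spec
top_prev = gram(M, V_prev)[0, 0] / spec
fro_ratio = np.linalg.norm(gram(M, V_prev)) / spec
lam = 1e-7 * gram(M, V_prev)[0, 0]               # the simplified shift from the text
print(top_exact, top_prev, fro_ratio, lam)
```

The top-left entry is a Rayleigh quotient, so it never exceeds the spectral norm, and its error is second order in the perturbation of the first column.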
Reference Implementation
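Combining the modifications from the above two sections, the reference implementation of one iteration from \boldsymbol{V}_{t-1} to \boldsymbol{V}_t is given below. The functions shift_old, scqr and v_step_old implement the previous version for comparison, while shift and v_step implement the refined version: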
```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax import lax


# ---- Previous version: two independent SCQR calls ----
def shift_old(A, eps=1e-9):
    # A + eps * ||A||_F * I
    return A + eps * jnp.linalg.matrix_norm(A, keepdims=True) * jnp.eye(A.shape[-1])


def scqr(A, eps=1e-9):
    """First try Shifted Cholesky QR, fall back to default QR on failure"""
    R = jnp.linalg.cholesky(shift_old(A.mT @ A, eps), upper=True)
    Q = solve_triangular(R.mT, A.mT, lower=True).mT  # Q = A R^{-1}
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])


def v_step_old(M, V, eps=1e-9):
    return scqr(M.mT @ scqr(M @ V, eps), eps)


# ---- Refined version: reordered computation + top-left-entry shift ----
def shift(A, eps=1e-7):
    # A + eps * A[0, 0] * I, using the top-left entry as a spectral-norm estimate
    return A + eps * A[..., :1, :1] * jnp.eye(A.shape[-1])


def v_step(M, V, eps=1e-7):
    A = (M.mT @ M) @ V  # M^T M is the only O(n m^2) product
    R = jnp.linalg.cholesky(shift(V.mT @ A, eps), upper=True)
    B = solve_triangular(R.mT, A.mT, lower=True).mT  # B = A R^{-1} = M^T QR(M V)
    R = jnp.linalg.cholesky(shift(B.mT @ B, eps), upper=True)
    Q = solve_triangular(R.mT, B.mT, lower=True).mT  # Q = B R^{-1}
    # Fallback: QR((M^T M) V) has the same Q factor (up to column signs) as the two-step result
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])
```
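For readers without a JAX setup, here is a NumPy port of `v_step`, used the way an optimizer would use it: one step per training iteration, with \boldsymbol{V} cached across iterations. The names `shift_np` and `v_step_np` are hypothetical. `np.linalg.solve` stands in for the triangular solves, and a caught `LinAlgError` replaces the NaN check. The slowly drifting momentum is a made-up model rather than real training data. The final comparison against an SVD-based \mathop{\text{msign}} is only a sanity check.

```python
import numpy as np


def shift_np(S, eps=1e-7):
    """S + eps * S[0, 0] * I, mirroring `shift`."""
    return S + eps * S[0, 0] * np.eye(S.shape[0])


def v_step_np(M, V, eps=1e-7):
    """NumPy port of `v_step` (sketch: general solves stand in for triangular solves)."""
    A = (M.T @ M) @ V
    try:
        R = np.linalg.cholesky(shift_np(V.T @ A, eps)).T   # upper factor
        B = np.linalg.solve(R.T, A.T).T                    # B = A R^{-1}
        R = np.linalg.cholesky(shift_np(B.T @ B, eps)).T
        Q = np.linalg.solve(R.T, B.T).T                    # Q = B R^{-1}
        if np.isfinite(Q).all():
            return Q
    except np.linalg.LinAlgError:
        pass
    return np.linalg.qr(A)[0]                              # fallback, as in the JAX code


rng = np.random.default_rng(0)
n, m = 8, 4
P, _ = np.linalg.qr(rng.standard_normal((n, m)))
O, _ = np.linalg.qr(rng.standard_normal((m, m)))
M_base = P @ np.diag([4.0, 3.0, 2.0, 1.0]) @ O.T

# Simulated training: hypothetical slowly drifting momentum, one v_step per step
V = np.eye(m)
for t in range(200):
    M = M_base + 1e-4 * np.sin(0.01 * t) * np.ones((n, m))
    V = v_step_np(M, V)
MV = M @ V
U = MV / np.linalg.norm(MV, axis=0, keepdims=True)  # U_t = ColNorm(M_t V_t)

U_svd, _, Vt_svd = np.linalg.svd(M, full_matrices=False)
err = np.abs(U @ V.T - U_svd @ Vt_svd).max()
print(err)
```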
Concurrent Work
Between the publication of Muon Implementation Based on Streaming Power Iteration: 2. Acceleration (https://kexue.fm/archives/11673) and this article, some interesting optimization work has also appeared elsewhere. It shares ideas with the two modifications proposed in this article, so we study it here as well.
First, after the previous article was published, @Ji_Ha_Kim (https://x.com/Ji_Ha_Kim) also proposed some improvements. For example, in a discussion with GPT (link (https://x.com/Ji_Ha_Kim/status/2038282538452431071)) he pointed out that we might be able to save one \text{Triangular Solve}! Specifically, \begin{aligned} \boldsymbol{V}_t =&\, \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})) \\ =&\, \mathop{\text{QR}}(\boldsymbol{V}_{t-1}(\boldsymbol{M}_t \boldsymbol{V}_{t-1})^{\top}\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})) \\ =&\, \mathop{\text{QR}}(\boldsymbol{V}_{t-1}(\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})^{\top}\boldsymbol{M}_t \boldsymbol{V}_{t-1})^{\top}) \\ =&\, \boldsymbol{V}_{t-1}\mathop{\text{QR}}((\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})^{\top}\boldsymbol{M}_t \boldsymbol{V}_{t-1})^{\top}) \\ \end{aligned} where the second equality uses \boldsymbol{V}_{t-1}\boldsymbol{V}_{t-1}^{\top}=\boldsymbol{I} and the last one uses the orthogonality of \boldsymbol{V}_{t-1}. It is easy to see that \mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})^{\top}\boldsymbol{M}_t \boldsymbol{V}_{t-1} is exactly the \boldsymbol{R} factor of the QR decomposition of \boldsymbol{M}_t\boldsymbol{V}_{t-1}, which the Cholesky decomposition yields directly.
In other words, in theory, once the first Cholesky decomposition gives \boldsymbol{R}, we can move straight on to the second \mathop{\text{QR}} (of \boldsymbol{R}^{\top}), saving one \text{Triangular Solve}. However, this is only of theoretical interest. The result relies on the strict orthogonality of \boldsymbol{V}_{t-1} and \mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}), which holds only for exact QR (i.e., \lambda=0). In practice, efficiency forces us to use SCQR, whose results are not strictly orthogonal. Exploiting orthogonality prematurely in the algebraic manipulation therefore causes errors to accumulate over the iterations.
During this period, Tri Dao’s team also published Gram Newton-Schulz: A Fast, Hardware-Aware Newton-Schulz Algorithm for Muon (https://dao-lab.ai/blog/2026/gram-newton-schulz/), proposing a way to accelerate the \mathop{\text{msign}} operator. By definition, \mathop{\text{msign}}(\boldsymbol{M}) = \boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}. The team computes the m\times m matrix (\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2} via Newton-Schulz iteration, instead of iterating on the n\times m matrix to get \mathop{\text{msign}} directly, which significantly reduces the cost when n\gg m. Many researchers had tried this idea before without success; Tri Dao et al. cleverly solved the problem with restarts.
Clearly, this optimization direction for \mathop{\text{msign}} matches the transformation of streaming power iteration from Equation [eq:qr2] to [eq:qr2-sim]. Coincidentally, @Ji_Ha_Kim (https://x.com/Ji_Ha_Kim/status/2039043040233275804) suggested changing the Newton-Schulz iteration for \mathop{\text{msign}} from a polynomial form to a rational one, which achieves equally good results with fewer iterations. The drawback of a rational iteration is that it requires a matrix inverse. In the specific context of \mathop{\text{msign}}, however, only an m\times m symmetric positive definite matrix needs to be inverted. This can be done with a Cholesky decomposition plus two \text{Triangular Solve}s, so the cost remains acceptable.
However, implemented this way, the computational flow of the rational iteration largely overlaps with that of streaming power iteration while requiring one extra \text{Triangular Solve} per step. It therefore seems unlikely to be faster than streaming power iteration.
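The dependence on exact orthogonality can be seen numerically. The NumPy sketch below compares the direct route \mathop{\text{QR}}(\boldsymbol{M}^{\top}\mathop{\text{QR}}(\boldsymbol{M}\boldsymbol{V})) with the shortcut \boldsymbol{V}\mathop{\text{QR}}(\boldsymbol{R}^{\top}). It uses the exact built-in QR, normalized to a positive diagonal of \boldsymbol{R} so that the factors are unique. The helper names and the 10^{-2} perturbation are hypothetical choices for illustration. With an exactly orthogonal \boldsymbol{V} the two routes agree. With a slightly non-orthogonal \boldsymbol{V} the shortcut’s output is no longer orthogonal.

```python
import numpy as np


def qr_pos(X):
    """Built-in QR with the sign convention diag(R) > 0, which makes the factors unique."""
    Q, R = np.linalg.qr(X)
    d = np.sign(np.diag(R))
    return Q * d, R * d[:, None]


def both_routes(M, V):
    Q1, R1 = qr_pos(M @ V)
    direct = qr_pos(M.T @ Q1)[0]     # QR(M^T QR(M V)): needs Q1, i.e. a triangular solve
    shortcut = V @ qr_pos(R1.T)[0]   # V QR(R^T): uses only the R factor of the first QR
    return direct, shortcut


rng = np.random.default_rng(0)
n, m = 32, 4
M = rng.standard_normal((n, m))
V_orth, _ = np.linalg.qr(rng.standard_normal((m, m)))   # exactly orthogonal V
d1, s1 = both_routes(M, V_orth)
V_near = V_orth + 1e-2 * rng.standard_normal((m, m))     # slightly non-orthogonal V
d2, s2 = both_routes(M, V_near)
print(np.abs(d1 - s1).max(), np.abs(s2.T @ s2 - np.eye(m)).max())
```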
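The identity \mathop{\text{msign}}(\boldsymbol{M}) = \boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2} and the reason the Gram route is cheaper can be illustrated in NumPy. The sketch below computes the inverse square root with an eigendecomposition purely for illustration. It is not the Gram Newton-Schulz algorithm, which uses a Newton-Schulz iteration with restarts that is not reproduced here. The name `msign_gram` is hypothetical, and \boldsymbol{M} is assumed to have full column rank.

```python
import numpy as np


def msign_gram(M):
    """msign(M) = M (M^T M)^{-1/2}, computed through the m x m Gram matrix.

    The eigendecomposition is only an illustrative stand-in for the inverse square root.
    """
    G = M.T @ M                                 # O(n m^2)
    w, Q = np.linalg.eigh(G)                    # G = Q diag(w) Q^T, w > 0 for full column rank
    return M @ (Q @ np.diag(w ** -0.5) @ Q.T)   # O(n m^2); everything in between is O(m^3)


rng = np.random.default_rng(0)
M = rng.standard_normal((64, 8))
U, _, Vt = np.linalg.svd(M, full_matrices=False)
print(np.abs(msign_gram(M) - U @ Vt).max())
```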
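The “Cholesky plus two \text{Triangular Solve}s” way of applying the inverse of an m\times m symmetric positive definite matrix can be sketched in NumPy as follows. It does not reproduce the rational iteration itself, only the inverse it needs. The name `spd_inverse_apply` and the test matrix are hypothetical, and `np.linalg.solve` stands in for dedicated triangular solvers.

```python
import numpy as np


def spd_inverse_apply(S, X):
    """Compute S^{-1} X for symmetric positive definite S with one Cholesky and two triangular solves.

    With S = R^T R (R upper triangular): solve R^T Y = X, then R Z = Y.
    """
    R = np.linalg.cholesky(S).T   # NumPy returns L = R^T
    Y = np.linalg.solve(R.T, X)   # lower-triangular system
    return np.linalg.solve(R, Y)  # upper-triangular system


rng = np.random.default_rng(0)
B = rng.standard_normal((32, 6))
S = B.T @ B + 1e-3 * np.eye(6)    # hypothetical SPD test matrix, e.g. a shifted Gram matrix
X = rng.standard_normal((6, 3))
Z = spd_inverse_apply(S, X)
print(np.abs(S @ Z - X).max())
```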
Summary
This article further “refines” the implementation details of streaming power iteration. The main improvements are: 1. reordering the computation to reduce the number of \mathcal{O}(nm^2) operations from four to one; 2. simplifying the regularization term by exploiting the special context of streaming power iteration. Together, these optimizations further reduce the computational bottlenecks of streaming power iteration and push its efficiency close to its limit.
To reprint, please include the article address: https://kexue.fm/archives/11697 (https://kexue.fm/archives/11697)
For more detailed reprint guidelines, please refer to: Science Space FAQ (https://kexue.fm/archives/6508)