English (unofficial) translations of posts at kexue.fm

Muon Implementation Based on Streaming Power Iteration: 1. First Look

Translated by DeepSeek V4 Pro. Translations can be inaccurate, please refer to the original post for important stuff.

The core operation of Muon is \mathop{\text{msign}}, and its current standard implementation is the Newton-Schulz iteration. Admittedly, it is a very efficient and GPU-friendly algorithm, and at least half the credit for Muon’s popularity belongs to it. However, it also feels like a single-purpose tool, since it seems limited to computing \mathop{\text{msign}}. As soon as we want to modify Muon (e.g., replace \mathop{\text{msign}} with \mathop{\text{mclip}}), the corresponding computation becomes troublesome.

This article proposes another implementation route: approximating the SVD via Streaming Power Iteration. The idea is not entirely new, as it has appeared in some earlier optimizer works, but here we extract it and present it as a standalone algorithm.

Content Review

We won’t go into the details of Muon here; interested readers can refer to earlier articles such as “Muon Optimizer Appreciation: The Essential Leap from Vectors to Matrices”, “Muon Sequel: Why We Choose to Try Muon?” and “Muon Optimizer Guide: Quick Start and Key Details”. Its update rule is
\begin{aligned} \boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt] \boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t [\mathop{\text{msign}}(\boldsymbol{M}_t) + \lambda \boldsymbol{W}_{t-1}] \\ \end{aligned}
where \mathop{\text{msign}} is defined as
\mathop{\text{msign}}(\boldsymbol{M})=\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}=\boldsymbol{U}_{[:, :r]}\boldsymbol{V}_{[:, :r]}^{\top}
with \boldsymbol{M}=\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} the SVD of \boldsymbol{M} and r its rank. Here \boldsymbol{M}\in\mathbb{R}^{n\times m}; without loss of generality we assume n\geq m, and for simplicity we mostly assume r=m (i.e., full rank). The rank-deficient case is discussed only when strictly necessary.
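As a reference point for what the later sketches approximate, the snippet below computes \mathop{\text{msign}} two ways in JAX: from the reduced SVD, and directly from the definition \boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}. It is a minimal sketch assuming a full-rank \boldsymbol{M} with n\geq m; the function names msign_svd and msign_def are hypothetical.

```python
import jax
import jax.numpy as jnp

def msign_svd(M):
    # Reduced SVD M = U diag(S) Vh; for full-rank M, msign(M) = U Vh
    U, S, Vh = jnp.linalg.svd(M, full_matrices=False)
    return U @ Vh

def msign_def(M):
    # Definition M (M^T M)^{-1/2}, with the inverse square root taken through the
    # eigendecomposition M^T M = P diag(w) P^T (needs full column rank, so w > 0)
    w, P = jnp.linalg.eigh(M.T @ M)
    return M @ ((P * w ** -0.5) @ P.T)

M = jax.random.normal(jax.random.PRNGKey(0), (8, 5))   # n = 8 >= m = 5
O = msign_svd(M)
err_def = jnp.abs(O - msign_def(M)).max()        # the two formulas agree up to round-off
err_orth = jnp.abs(O.T @ O - jnp.eye(5)).max()   # msign(M) has orthonormal columns
```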

Since SVD is relatively expensive, in most cases we use the Newton-Schulz iteration to compute \mathop{\text{msign}}, which we have discussed in detail in “Newton-Schulz Iteration for the msign Operator (Part 1)” and “Newton-Schulz Iteration for the msign Operator (Part 2)”. Overall, the Newton-Schulz iteration is very clever and is the main contributor to Muon’s success, but its extensibility is relatively weak.

To expand the application scenarios of the Newton-Schulz iteration, the author has also done some work, such as “Computing Singular Value Clipping mclip via msign (Part 1)”, “Computing Singular Value Clipping mclip via msign (Part 2)”, “Efficient Computation of Matrix Square Root and Inverse Square Root”, “Efficient Computation of Matrix r-th Root and Inverse r-th Root”, etc., but what can be done is still relatively limited.

Obviously, the once-and-for-all method is to directly compute the SVD, which is the idea we will focus on next.

Power Iteration

In articles such as “Lipschitz Constraints in Deep Learning: Generalization and Generative Models” and “From Spectral Norm Gradient to New Weight Decay”, we have already encountered Power Iteration, which we used to find the principal eigenvector of \boldsymbol{M}^{\top}\boldsymbol{M}, i.e., the right principal singular vector of \boldsymbol{M}. The iteration is
\boldsymbol{v}_1^{(t)} = \frac{\boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{v}_1^{(t-1)}}{\Vert\boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{v}_1^{(t-1)}\Vert_2}
Once we have the principal eigenvector \boldsymbol{v}_1, we can add an orthogonalization step to the power iteration to find the next eigenvector:
\boldsymbol{v}_2^{(t)} = \frac{\tilde{\boldsymbol{v}}_2^{(t)} - \langle\tilde{\boldsymbol{v}}_2^{(t)},\boldsymbol{v}_1\rangle\boldsymbol{v}_1}{\Vert\tilde{\boldsymbol{v}}_2^{(t)} - \langle\tilde{\boldsymbol{v}}_2^{(t)},\boldsymbol{v}_1\rangle\boldsymbol{v}_1\Vert_2},\qquad \tilde{\boldsymbol{v}}_2^{(t)} = \boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{v}_2^{(t-1)}
Because every iterate is kept orthogonal to \boldsymbol{v}_1, this converges to the next eigenvector \boldsymbol{v}_2.
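The single-vector power iteration and its deflated version for \boldsymbol{v}_2 can be sketched in JAX as follows. This is illustrative only: the helper name power_iteration is hypothetical, and the test matrix is built with assumed, well-separated singular values (5, 3, 2, 1, 0.5) so that a fixed 100 iterations is clearly enough.

```python
import jax
import jax.numpy as jnp

def power_iteration(M, v, steps=100, basis=None):
    """Power iteration on M^T M. If `basis` (m x j, orthonormal columns) is given,
    each iterate is projected onto the orthogonal complement of span(basis)."""
    MtM = M.T @ M
    for _ in range(steps):
        v = MtM @ v
        if basis is not None:
            v = v - basis @ (basis.T @ v)   # Gram-Schmidt against known eigenvectors
        v = v / jnp.linalg.norm(v)
    return v

# Hypothetical test matrix M = U diag(s) V^T with known singular values
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
U, _ = jnp.linalg.qr(jax.random.normal(k1, (8, 5)))
V, _ = jnp.linalg.qr(jax.random.normal(k2, (5, 5)))
s = jnp.array([5.0, 3.0, 2.0, 1.0, 0.5])
M = (U * s) @ V.T

v0 = jax.random.normal(k3, (5,))
v1 = power_iteration(M, v0)                      # ~ +/- V[:, 0]
v2 = power_iteration(M, v0, basis=v1[:, None])   # ~ +/- V[:, 1]
```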
Similarly, given \boldsymbol{v}_1,\boldsymbol{v}_2,\cdots,\boldsymbol{v}_{k-1}, we can use Gram-Schmidt orthogonalization to find the k-th eigenvector:
\boldsymbol{v}_k^{(t)} = \frac{\tilde{\boldsymbol{v}}_k^{(t)} - \langle\tilde{\boldsymbol{v}}_k^{(t)},\boldsymbol{v}_1\rangle\boldsymbol{v}_1 - \cdots - \langle\tilde{\boldsymbol{v}}_k^{(t)},\boldsymbol{v}_{k-1}\rangle\boldsymbol{v}_{k-1}}{\Vert\tilde{\boldsymbol{v}}_k^{(t)} - \langle\tilde{\boldsymbol{v}}_k^{(t)},\boldsymbol{v}_1\rangle\boldsymbol{v}_1 - \cdots - \langle\tilde{\boldsymbol{v}}_k^{(t)},\boldsymbol{v}_{k-1}\rangle\boldsymbol{v}_{k-1}\Vert_2},\qquad \tilde{\boldsymbol{v}}_k^{(t)} = \boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{v}_k^{(t-1)}\label{eq:vk-pi}
In fact, we don’t need to wait until \boldsymbol{v}_1,\boldsymbol{v}_2,\cdots,\boldsymbol{v}_{k-1} are all computed before computing \boldsymbol{v}_k; the whole set \boldsymbol{V}=[\boldsymbol{v}_1,\boldsymbol{v}_2,\cdots,\boldsymbol{v}_m] can be iterated in parallel. Specifically, starting from the current approximation \boldsymbol{V}_{t-1}, we compute \boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{V}_{t-1} in one batch and then re-orthogonalize its columns (e.g., with a QR decomposition), which yields a better approximation, denoted \boldsymbol{V}_t. Repeating this iteration eventually converges to the target \boldsymbol{V}:
\boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{V}_{t-1})
Here \mathop{\text{QR}} denotes the orthogonal factor \boldsymbol{Q} of the QR decomposition. Once we have \boldsymbol{V}, since \boldsymbol{M}\boldsymbol{V}=\boldsymbol{U}\boldsymbol{\Sigma}, we get \boldsymbol{U} = \mathop{\text{ColNorm}}(\boldsymbol{M}\boldsymbol{V}), where \mathop{\text{ColNorm}} L2-normalizes each column (axis=0), and \boldsymbol{\Sigma}=\mathop{\text{diag}}(\boldsymbol{U}^{\top}\boldsymbol{M}\boldsymbol{V}). This gives an approximate SVD scheme based on power iteration and QR decomposition.
Of course, when n > m this yields only the reduced (thin) decomposition, with \boldsymbol{U}\in\mathbb{R}^{n\times m} and \boldsymbol{\Sigma},\boldsymbol{V}\in\mathbb{R}^{m\times m}, but that is all we need.
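Putting these pieces together, the batched iteration (classically known as orthogonal or subspace iteration) plus the recovery of \boldsymbol{U} and \boldsymbol{\Sigma} might look like the JAX sketch below. It assumes jnp.linalg.qr as the \mathop{\text{QR}} operator and starts from \boldsymbol{V}=\boldsymbol{I}; the helper names and the synthetic test matrix with singular values (5, 3, 2, 1, 0.5) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def col_norm(X):
    # ColNorm: L2-normalize each column (axis=0)
    return X / jnp.linalg.norm(X, axis=0, keepdims=True)

def power_svd(M, V, steps=100):
    """Approximate reduced SVD of M (n x m, n >= m) by iterating V <- QR(M^T M V)."""
    for _ in range(steps):
        V = jnp.linalg.qr(M.T @ (M @ V))[0]   # keep only the orthogonal factor Q
    U = col_norm(M @ V)                        # because M V = U Sigma
    S = jnp.diag(U.T @ M @ V)                  # Sigma = diag(U^T M V)
    return U, S, V

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U0, _ = jnp.linalg.qr(jax.random.normal(k1, (8, 5)))
V0, _ = jnp.linalg.qr(jax.random.normal(k2, (5, 5)))
M = (U0 * jnp.array([5.0, 3.0, 2.0, 1.0, 0.5])) @ V0.T

U, S, V = power_svd(M, jnp.eye(5))            # V_0 = I
```

The columns come out ordered by decreasing singular value, since the first column is a plain power iteration and each later one is implicitly deflated against the earlier ones.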

Streaming Update

However, computing the SVD by power iteration is extremely inefficient in practice, far slower than simply calling the framework’s built-in SVD, so on its own it is not practical. But training is itself a long iterative process, so we can assume that \boldsymbol{V} changes little from one step to the next. We can therefore save the previous step’s \boldsymbol{V} as the initialization for the current step and perform only one power-iteration step per training step:
\begin{aligned} \boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt] \boldsymbol{V}_t =&\, \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}) \\[5pt] \boldsymbol{U}_t =&\, \mathop{\text{ColNorm}}(\boldsymbol{M}_t\boldsymbol{V}_t) \\[5pt] \boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t (\boldsymbol{U}_t\boldsymbol{V}_t^{\top} + \lambda \boldsymbol{W}_{t-1}) \\ \end{aligned}\label{eq:muon-qr}
where \boldsymbol{V}_0=\boldsymbol{I}. Experiments show that Muon implemented via this streaming power iteration does produce an LM loss curve that almost coincides with that of the Newton-Schulz version, so the approach is feasible. This is largely thanks to momentum and the small learning rate, which keep the assumption that “\boldsymbol{V} changes little at each step” approximately valid, so that the cost of the power iteration is effectively amortized across training steps.
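One training step of this streaming variant could be sketched in JAX as below. This is a sketch, not the author’s implementation: it uses jnp.linalg.qr for the \mathop{\text{QR}} operator, and the function name and argument names (lr for \eta_t, wd for \lambda, beta for \beta) and their defaults are illustrative assumptions. The usage lines freeze the momentum (beta=0, lr=0) on a synthetic gradient so that only \boldsymbol{V} is refined, then take one step from \boldsymbol{W}=\boldsymbol{0} to read off the update direction.

```python
import jax
import jax.numpy as jnp

def streaming_muon_step(W, M, G, V, lr, beta=0.95, wd=0.0):
    """One Muon step with streaming power iteration (assumes n >= m).
    W, M, G: (n, m) weight, momentum, gradient; V: (m, m), carried between steps."""
    M = beta * M + G                                      # momentum
    V = jnp.linalg.qr(M.T @ (M @ V))[0]                   # one power-iteration step + QR
    MV = M @ V
    U = MV / jnp.linalg.norm(MV, axis=0, keepdims=True)   # U = ColNorm(M V)
    W = W - lr * (U @ V.T + wd * W)                       # U V^T approximates msign(M)
    return W, M, V

# Synthetic gradient with singular values (5, 3, 2, 1, 0.5)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U0, _ = jnp.linalg.qr(jax.random.normal(k1, (8, 5)))
V0, _ = jnp.linalg.qr(jax.random.normal(k2, (5, 5)))
G = (U0 * jnp.array([5.0, 3.0, 2.0, 1.0, 0.5])) @ V0.T

W, M, V = jnp.zeros((8, 5)), jnp.zeros((8, 5)), jnp.eye(5)   # V_0 = I
for _ in range(100):
    W, M, V = streaming_muon_step(W, M, G, V, lr=0.0, beta=0.0)
W, _, _ = streaming_muon_step(jnp.zeros((8, 5)), M, G, V, lr=1.0, beta=0.0)
direction = -W                        # should be close to msign(G) = U0 V0^T
```

Computing \boldsymbol{M}^{\top}(\boldsymbol{M}\boldsymbol{V}) rather than (\boldsymbol{M}^{\top}\boldsymbol{M})\boldsymbol{V} is just a choice of parenthesization; both give the same matrix.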

Since we now compute an (approximate) SVD explicitly, we can also manipulate the singular values and build that into the optimizer, for example:
\begin{aligned} \boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt] \boldsymbol{V}_t =&\, \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}) \\[5pt] \boldsymbol{U}_t =&\, \mathop{\text{ColNorm}}(\boldsymbol{M}_t\boldsymbol{V}_t) \\[5pt] \boldsymbol{\Sigma}_t =&\, \mathop{\text{diag}}(\boldsymbol{U}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_t) \\[5pt] \boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t (\boldsymbol{U}_t f(\boldsymbol{\Sigma}_t)\boldsymbol{V}_t^{\top} + \lambda \boldsymbol{W}_{t-1}) \\ \end{aligned}
where f acts on the singular values on the diagonal of \boldsymbol{\Sigma}_t. This makes it much easier to implement \mathop{\text{mclip}} or Muon variants based on general Schatten norms. In short, having explicit \boldsymbol{U}_t,\boldsymbol{\Sigma}_t,\boldsymbol{V}_t (even approximate ones) makes small modifications easy to try, which greatly improves extensibility and leaves much more room for experimentation.
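A corresponding JAX sketch with a singular-value map f, under the same assumptions as before. Because \boldsymbol{U}_t=\mathop{\text{ColNorm}}(\boldsymbol{M}_t\boldsymbol{V}_t), the diagonal of \boldsymbol{U}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_t is simply the column norms of \boldsymbol{M}_t\boldsymbol{V}_t, which the code uses instead of forming the full product. The clipping map f(\sigma)=\min(\sigma,1) in the usage lines is a hypothetical choice in the spirit of \mathop{\text{mclip}}, used for illustration only; passing jnp.ones_like instead gives the msign-style update.

```python
import jax
import jax.numpy as jnp

def streaming_step_f(W, M, G, V, lr, f, beta=0.95, wd=0.0):
    """Streaming-power-iteration step with update U f(Sigma) V^T (assumes n >= m)."""
    M = beta * M + G
    V = jnp.linalg.qr(M.T @ (M @ V))[0]
    MV = M @ V
    S = jnp.linalg.norm(MV, axis=0)            # diag(U^T M V) = column norms of M V
    U = MV / S                                 # ColNorm(M V), broadcast over columns
    W = W - lr * ((U * f(S)) @ V.T + wd * W)   # U f(Sigma) V^T without forming diag(.)
    return W, M, V, S

# Synthetic gradient with known singular values
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U0, _ = jnp.linalg.qr(jax.random.normal(k1, (8, 5)))
V0, _ = jnp.linalg.qr(jax.random.normal(k2, (5, 5)))
s0 = jnp.array([5.0, 3.0, 2.0, 1.0, 0.5])
G = (U0 * s0) @ V0.T

clip = lambda s: jnp.minimum(s, 1.0)     # hypothetical f, for illustration only
M, V = jnp.zeros((8, 5)), jnp.eye(5)
for _ in range(100):                     # lr=0, beta=0: only V is refined
    _, M, V, S = streaming_step_f(jnp.zeros((8, 5)), M, G, V, 0.0, clip, beta=0.0)
W, _, _, S = streaming_step_f(jnp.zeros((8, 5)), M, G, V, 1.0, clip, beta=0.0)
update = -W                              # ~ U0 diag(min(s0, 1)) V0^T
```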

Accelerating the Decomposition

The bottleneck now shifts to the QR decomposition, which is the most time-consuming step of the streaming update above. The standard implementation is Householder QR. Although it is much faster than SVD, it is still slower than computing \mathop{\text{msign}} with the Newton-Schulz iteration (a polynomial iteration running on BF16 matrix multiplications is practically cheating). To make the new approach competitive, we therefore need to speed up the QR decomposition.

For a given matrix \boldsymbol{A}\in\mathbb{R}^{n\times m} (n\geq m), QR decomposition aims to find an orthogonal matrix \boldsymbol{Q}\in\mathbb{R}^{n\times m} and an upper triangular matrix \boldsymbol{R}\in\mathbb{R}^{m\times m} such that \boldsymbol{A}=\boldsymbol{Q}\boldsymbol{R} (here the orthogonal matrix only needs to satisfy \boldsymbol{Q}^{\top}\boldsymbol{Q}=\boldsymbol{I}, more accurately called a Stiefel matrix). Note that \boldsymbol{A}^{\top}\boldsymbol{A}=\boldsymbol{R}^{\top}\boldsymbol{R}, meaning that we only need to decompose \boldsymbol{A}^{\top}\boldsymbol{A} into the product of a lower triangular matrix and its transpose to obtain \boldsymbol{R}, and this is exactly what Cholesky decomposition does!

Cholesky decomposition is very efficient, so in the first step we can use it to obtain \boldsymbol{R}, and then solve the equation \boldsymbol{Q}\boldsymbol{R}=\boldsymbol{A} to get \boldsymbol{Q}. The equation can also be written as \boldsymbol{R}^{\top}\boldsymbol{Q}^{\top}=\boldsymbol{A}^{\top}, which can be solved using solve_triangular, also very efficient. These two steps together form the QR decomposition algorithm known as “Cholesky QR”. If numerical stability is not a concern, it may be the fastest QR decomposition method.
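The two steps of Cholesky QR translate almost directly into JAX. In this sketch (no shift, no fallback), the lower-triangular Cholesky factor \boldsymbol{L}=\boldsymbol{R}^{\top} is passed straight to solve_triangular; the name cholesky_qr is ours, not a library API. The usage lines compare it with jnp.linalg.qr on a random, well-conditioned matrix, up to the row-sign ambiguity of Householder QR.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def cholesky_qr(A):
    """Plain Cholesky QR of A (n x m, n >= m, full column rank)."""
    L = jnp.linalg.cholesky(A.T @ A)             # A^T A = L L^T, L lower triangular
    R = L.T                                      # so A^T A = R^T R, R upper triangular
    Q = solve_triangular(L, A.T, lower=True).T   # solve R^T Q^T = A^T for Q^T
    return Q, R

A = jax.random.normal(jax.random.PRNGKey(0), (16, 6))
Q, R = cholesky_qr(A)
Q_ref, R_ref = jnp.linalg.qr(A)                  # rows of R_ref may differ in sign
orth_err = jnp.abs(Q.T @ Q - jnp.eye(6)).max()
recon_err = jnp.abs(Q @ R - A).max()
r_err = jnp.abs(jnp.abs(R) - jnp.abs(R_ref)).max()
```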

Unfortunately, Cholesky QR is much less stable than standard QR decomposition: it is extremely sensitive to the condition number of \boldsymbol{A}^{\top}\boldsymbol{A}, which is the square of that of \boldsymbol{A}. To address this, the paper “Shifted CholeskyQR for computing the QR factorization of ill-conditioned matrices” proposed Shifted Cholesky QR (abbreviated “SCQR” below), which adds a regularization term \lambda \boldsymbol{I}, with \lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F, to \boldsymbol{A}^{\top}\boldsymbol{A}. However, this is a double-edged sword: the larger \epsilon is, the more stable SCQR becomes, but the less orthogonal the resulting \boldsymbol{Q} is, and the worse the final performance.
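To see the trade-off concretely: if \boldsymbol{R}^{\top}\boldsymbol{R}=\boldsymbol{A}^{\top}\boldsymbol{A}+\lambda\boldsymbol{I} and \boldsymbol{Q}=\boldsymbol{A}\boldsymbol{R}^{-1}, then \boldsymbol{Q}^{\top}\boldsymbol{Q}=\boldsymbol{I}-\lambda\boldsymbol{R}^{-\top}\boldsymbol{R}^{-1}, whose deviation from \boldsymbol{I} has spectral norm \lambda/(\sigma_{\min}^2+\lambda). The JAX sketch below measures this deviation for two values of \epsilon; the test matrix (singular values spread from 1 down to 0.1) and the \epsilon values are arbitrary illustrative choices, and there is no fallback here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def shifted_cholesky_qr(A, eps):
    """Cholesky QR on A^T A + lam * I with lam = eps * ||A^T A||_F (no fallback)."""
    G = A.T @ A
    lam = eps * jnp.linalg.norm(G)               # Frobenius norm for a 2-D array
    L = jnp.linalg.cholesky(G + lam * jnp.eye(G.shape[0]))
    return solve_triangular(L, A.T, lower=True).T

def orth_error(Q):
    return jnp.abs(Q.T @ Q - jnp.eye(Q.shape[1])).max()

# Illustrative test matrix with singular values from 1 down to 0.1
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U, _ = jnp.linalg.qr(jax.random.normal(k1, (32, 8)))
V, _ = jnp.linalg.qr(jax.random.normal(k2, (8, 8)))
A = (U * jnp.linspace(1.0, 0.1, 8)) @ V.T

err_small = orth_error(shifted_cholesky_qr(A, 1e-6))   # nearly orthogonal Q
err_large = orth_error(shifted_cholesky_qr(A, 1e-2))   # stable, visibly non-orthogonal Q
```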

Moreover, even with the introduction of \epsilon, there is no guarantee that SCQR will always succeed, so we need an additional check: if it fails, fall back to standard QR.

Reference Implementation

A simple reference implementation in JAX is as follows:

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax import lax

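def shift(A, eps=1e-9):
    # Add eps * ||A||_F * I to the square (Gram) matrix A; batched over leading axes
    return A + eps * jnp.linalg.matrix_norm(A, keepdims=True) * jnp.eye(A.shape[-1])

def scqr(A, eps=1e-9):
    """First try Shifted Cholesky QR, fall back to default QR if it fails
    """
    # Upper-triangular R with R^T R = A^T A + lambda * I
    R = jnp.linalg.cholesky(shift(A.mT @ A, eps), upper=True)
    # Solve the lower-triangular system R^T Q^T = A^T for Q^T, then transpose back
    Q = solve_triangular(R.mT, A.mT, lower=True).mT
    # A failed Cholesky shows up as non-finite entries; fall back to standard QR then
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])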

Simple tests show that, when it succeeds, SCQR is about as fast as the Newton-Schulz version of \mathop{\text{msign}}. However, to preserve approximation quality, \epsilon cannot be too large, otherwise performance degrades significantly. In practice, \epsilon=10^{-9} is usually required for reasonable performance, and at this value SCQR still falls back to standard QR fairly often, so overall it remains slower than the Newton-Schulz iteration.

In addition to directly improving the QR decomposition algorithm, there are other acceleration techniques, such as keeping only the first k eigenvectors, so that \boldsymbol{V} only needs to be initialized as m\times k instead of m\times m, which can also reduce some computation. Further acceleration of QR decomposition is left for readers to explore, and we will not elaborate here.
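A rank-k version of the streaming step only changes the shape of the carried state: \boldsymbol{V} becomes m\times k (initialized here with the first k columns of the identity, an assumption mirroring \boldsymbol{V}_0=\boldsymbol{I}), and the direction \boldsymbol{U}\boldsymbol{V}^{\top} becomes a rank-k approximation. The JAX sketch below, with the hypothetical helper streaming_topk_direction and a synthetic matrix, only shows the mechanics and says nothing about training behavior.

```python
import jax
import jax.numpy as jnp

def streaming_topk_direction(M, V):
    """One power-iteration step with a truncated basis V (m x k, k <= m).
    Returns the rank-k direction U V^T and the refreshed V."""
    V = jnp.linalg.qr(M.T @ (M @ V))[0]        # reduced QR: Q is m x k
    MV = M @ V                                 # n x k instead of n x m
    U = MV / jnp.linalg.norm(MV, axis=0, keepdims=True)
    return U @ V.T, V

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U0, _ = jnp.linalg.qr(jax.random.normal(k1, (8, 5)))
V0, _ = jnp.linalg.qr(jax.random.normal(k2, (5, 5)))
M = (U0 * jnp.array([5.0, 3.0, 2.0, 1.0, 0.5])) @ V0.T

k = 2
V = jnp.eye(5, k)                    # m x k initialization
for _ in range(100):                 # fixed M, so V converges to the top-k subspace
    O, V = streaming_topk_direction(M, V)
```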

Other Details

Additionally, there are some details that require special attention, as they are closely related to training stability and final performance.

First, according to the convention \boldsymbol{M}_t\in\mathbb{R}^{n\times m}, we need to ensure n\geq m; otherwise we should transpose it. If n < m, then the matrix \boldsymbol{M}_t^{\top}\boldsymbol{M}_t \boldsymbol{V}_{t-1} is necessarily rank-deficient, and performing QR decomposition on a rank-deficient matrix is ill-posed. SCQR, in particular, is more prone to various pathological phenomena, ultimately degrading performance. Therefore, ensuring n\geq m not only guarantees numerical stability and improves performance but also accelerates computation, killing multiple birds with one stone.
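One way to enforce n\geq m is to transpose wide matrices before the power iteration and transpose the result back, using \mathop{\text{msign}}(\boldsymbol{M}^{\top})=\mathop{\text{msign}}(\boldsymbol{M})^{\top}. The JAX sketch below is a hypothetical helper, streaming_msign, in which the carried \boldsymbol{V} is sized by \min(n, m); the wide test matrix is synthetic.

```python
import jax
import jax.numpy as jnp

def streaming_msign(M, V):
    """Approximate msign(M) with one power-iteration step.
    V has shape (min(n, m), min(n, m)) and is carried between calls."""
    transposed = M.shape[0] < M.shape[1]
    if transposed:
        M = M.T                          # work in the tall orientation (n >= m)
    V = jnp.linalg.qr(M.T @ (M @ V))[0]  # M^T M is now the smaller Gram matrix
    MV = M @ V
    U = MV / jnp.linalg.norm(MV, axis=0, keepdims=True)
    O = U @ V.T
    return (O.T if transposed else O), V   # msign(M^T) = msign(M)^T

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
U0, _ = jnp.linalg.qr(jax.random.normal(k1, (8, 5)))
V0, _ = jnp.linalg.qr(jax.random.normal(k2, (5, 5)))
M_wide = ((U0 * jnp.array([5.0, 3.0, 2.0, 1.0, 0.5])) @ V0.T).T   # 5 x 8, n < m

V = jnp.eye(5)
for _ in range(100):
    O, V = streaming_msign(M_wide, V)
```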

Second, experiments show that adding an extra \mathop{\text{ColNorm}} inside the \mathop{\text{QR}} step significantly improves performance:
\boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}) \qquad\to\qquad \boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\mathop{\text{ColNorm}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}))
This amounts to changing \tilde{\boldsymbol{v}}_k^{(t)} = \boldsymbol{M}^{\top}\boldsymbol{M}\boldsymbol{v}_k^{(t-1)} in the Gram-Schmidt power iteration for \boldsymbol{v}_k above to \tilde{\boldsymbol{v}}_k^{(t)} = \boldsymbol{M}^{\top}(\boldsymbol{M}\boldsymbol{v}_k^{(t-1)} / \Vert\boldsymbol{M}\boldsymbol{v}_k^{(t-1)}\Vert_2), which does not alter the convergence of the power iteration itself. Even so, the extra \mathop{\text{ColNorm}} noticeably helps training, especially under SCQR, where it substantially narrows the gap with standard QR.

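It can be proven that, in exact arithmetic, this operation does not change the power iteration at all and only affects the numerics of the QR step. Indeed, \mathop{\text{ColNorm}} rescales each column of \boldsymbol{M}_t\boldsymbol{V}_{t-1} by a positive factor, i.e., right-multiplies it by a positive diagonal matrix \boldsymbol{D}; and if \boldsymbol{A}=\boldsymbol{Q}\boldsymbol{R}, then \boldsymbol{A}\boldsymbol{D}=\boldsymbol{Q}(\boldsymbol{R}\boldsymbol{D}), where \boldsymbol{R}\boldsymbol{D} is still upper triangular with a positive diagonal, so the \boldsymbol{Q} factor is unchanged. Experimentally, it actually slightly increases the probability of SCQR falling back to standard QR (not enough to cause a noticeable slowdown), and it also helps standard QR a little. It therefore appears to improve the conditioning of the matrix being decomposed, yielding a better-quality QR.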
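The exact-arithmetic invariance can be checked numerically. The JAX sketch below compares the plain \boldsymbol{V} update with the \mathop{\text{ColNorm}} variant, after fixing the sign ambiguity of jnp.linalg.qr so that \mathop{\text{diag}}(\boldsymbol{R})>0 (the convention under which \boldsymbol{Q} is unique). The helper names and the synthetic matrix are illustrative assumptions. With standard QR the two results agree up to float32 round-off, consistent with the claim that only the numerics differ; the snippet does not model SCQR, where the source reports the larger benefit.

```python
import jax
import jax.numpy as jnp

def qr_q(A):
    # Q factor with the sign convention diag(R) > 0, which makes Q unique
    Q, R = jnp.linalg.qr(A)
    return Q * jnp.sign(jnp.diag(R))

def v_update(M, V):
    return qr_q(M.T @ (M @ V))                    # V_t = QR(M^T M V_{t-1})

def v_update_colnorm(M, V):
    MV = M @ V
    MV = MV / jnp.linalg.norm(MV, axis=0, keepdims=True)   # ColNorm(M V_{t-1})
    return qr_q(M.T @ MV)                         # V_t = QR(M^T ColNorm(M V_{t-1}))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
U0, _ = jnp.linalg.qr(jax.random.normal(k1, (8, 5)))
V0, _ = jnp.linalg.qr(jax.random.normal(k2, (5, 5)))
M = (U0 * jnp.array([5.0, 3.0, 2.0, 1.0, 0.5])) @ V0.T
V = jnp.linalg.qr(jax.random.normal(k3, (5, 5)))[0]   # some orthogonal V_{t-1}

diff = jnp.abs(v_update(M, V) - v_update_colnorm(M, V)).max()
```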

Summary

This article mainly introduced the idea of computing SVD via Streaming Power Iteration and then implementing Muon. It only requires one QR decomposition per step, and compared to the standard Newton-Schulz iteration implementation, this approach offers more flexible extensibility.

For reprinting, please include the article address: https://kexue.fm/archives/11654

For more detailed reprinting guidelines, please refer to: Science Space FAQ