English (unofficial) translations of posts at kexue.fm

Muon Implementation Based on Streaming Power Iteration: 2. Acceleration

Translated by DeepSeek V4 Pro. Translations can be inaccurate; please refer to the original post for important details.

Streaming Iteration

We will continue to use all concepts and notations from the first article; readers with related questions should refer back to it first. Recall that the update formula of Muon is \begin{aligned} \boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt] \boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t [\mathop{\text{msign}}(\boldsymbol{M}_t) + \lambda \boldsymbol{W}_{t-1}] \\ \end{aligned} where \mathop{\text{msign}} is usually implemented by Newton-Schulz iteration, which is also the most expensive computation in the Muon optimizer. In contrast, the update formula of the streaming power iteration scheme is \begin{aligned} \boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt] \boldsymbol{V}_t =&\, \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}) \\[5pt] \boldsymbol{U}_t =&\, \mathop{\text{ColNorm}}(\boldsymbol{M}_t\boldsymbol{V}_t) \\[5pt] \boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t (\boldsymbol{U}_t\boldsymbol{V}_t^{\top} + \lambda \boldsymbol{W}_{t-1}) \\ \end{aligned} If we repeatedly executed \boldsymbol{V} \leftarrow \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}) with \boldsymbol{M}_t held fixed, that would be standard power iteration: the result converges to the right singular matrix of \boldsymbol{M}_t, which gives the SVD of \boldsymbol{M}_t and hence \mathop{\text{msign}}(\boldsymbol{M}_t). However, running power iteration to convergence at every training step is too costly. Instead, we cache the result \boldsymbol{V}_{t-1} from the previous step and perform only a single iteration (one \mathop{\text{QR}}) per step as an approximation; this is what “streaming” means.
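To make the update concrete, here is a minimal sketch of one streaming step. It assumes NumPy in float64 instead of the JAX used elsewhere in this post, uses NumPy's built-in Householder QR (`np.linalg.qr`), and the helper names `col_norm`, `msign_svd` and `streaming_step` (with `wd` playing the role of the weight-decay \lambda) are hypothetical. The last few lines hold \boldsymbol{M} fixed and repeat only the \boldsymbol{V} update, i.e. plain power iteration, and compare \mathop{\text{ColNorm}}(\boldsymbol{M}\boldsymbol{V})\boldsymbol{V}^{\top} against an SVD-based \mathop{\text{msign}}.

```python
import numpy as np
# NumPy (float64) stand-in for the JAX setting of the text; all helper names are hypothetical.

def col_norm(X, eps=1e-12):
    """ColNorm: rescale every column of X to unit L2 norm."""
    return X / (np.linalg.norm(X, axis=0, keepdims=True) + eps)

def msign_svd(M):
    """Reference msign(M) = U_s V_s^T from a reduced SVD."""
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

def streaming_step(W, M, G, V, lr, beta=0.95, wd=0.0):
    """One streaming power-iteration update (hypothetical helper).

    W, M, G: weights, momentum and gradient, shape (m, n).
    V: cached V_{t-1}, shape (n, n), orthonormal columns.
    """
    M = beta * M + G                      # M_t = beta M_{t-1} + G_t
    V = np.linalg.qr(M.T @ (M @ V))[0]    # a single power-iteration step from V_{t-1}
    U = col_norm(M @ V)                   # U_t = ColNorm(M_t V_t)
    W = W - lr * (U @ V.T + wd * W)       # U_t V_t^T stands in for msign(M_t)
    return W, M, V

# Holding M fixed and repeating only the V update is plain power iteration.
rng = np.random.default_rng(0)
P = np.linalg.qr(rng.standard_normal((8, 6)))[0]
O = np.linalg.qr(rng.standard_normal((6, 6)))[0]
M = P @ np.diag([5.0, 4.0, 3.0, 2.0, 1.5, 1.0]) @ O.T   # well-separated singular values
V = np.linalg.qr(rng.standard_normal((6, 6)))[0]
for _ in range(300):
    V = np.linalg.qr(M.T @ (M @ V))[0]
err = np.abs(col_norm(M @ V) @ V.T - msign_svd(M)).max()   # close to 0
```

Column sign flips from the QR do not matter here: flipping a column of \boldsymbol{V} flips the matching column of \boldsymbol{U}, so \boldsymbol{U}\boldsymbol{V}^{\top} is unchanged.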

Now the most expensive operation becomes the \mathop{\text{QR}} decomposition. The most naive implementation is to call the built-in QR function, which is based on Householder transformations: numerically stable, but relatively slow.

First Speedup

To accelerate, in the previous article we introduced Cholesky QR, which splits the QR decomposition of a matrix \boldsymbol{A} into two steps: 1. perform Cholesky decomposition on \boldsymbol{A}^{\top}\boldsymbol{A} to obtain the upper triangular matrix \boldsymbol{R}; 2. solve the equation \boldsymbol{Q}\boldsymbol{R}=\boldsymbol{A} to obtain the orthogonal matrix \boldsymbol{Q}. Both steps are very efficient in theory, but in practice the Cholesky step may fail when the condition number is too large: \boldsymbol{A}^{\top}\boldsymbol{A} has the square of the condition number of \boldsymbol{A}, and rounding errors can make it lose positive definiteness numerically. To address this, we also introduced the Shift technique, which adds a regularization term \lambda \boldsymbol{I} (with \lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F) to \boldsymbol{A}^{\top}\boldsymbol{A} to reduce the condition number.

We call the combination of the two “SCQR” (Shifted Cholesky QR). A reference implementation in Jax is as follows:

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax import lax

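def shift(A, eps=1e-9):
    """Return A + eps * ||A||_F * I; applied to A^T A this adds lambda * I."""
    return A + eps * jnp.linalg.matrix_norm(A, keepdims=True) * jnp.eye(A.shape[-1])

def scqr(A, eps=1e-9):
    """First try Shifted Cholesky QR; if it fails, fall back to default QR
    """
    # Step 1: upper-triangular R with R^T R = A^T A + lambda * I
    R = jnp.linalg.cholesky(shift(A.mT @ A, eps), upper=True)
    # Step 2: Q = A R^{-1}, solved as the lower-triangular system R^T Q^T = A^T
    Q = solve_triangular(R.mT, A.mT, lower=True).mT
    # A failed Cholesky shows up as non-finite values rather than an exception
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])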

Note that the smaller \lambda is, the more likely SCQR is to fail, while the larger \lambda is, the further the result deviates from orthogonality and the worse the training effect becomes. \lambda therefore has to be “just right”, and even then the current scheme still falls back to standard QR relatively often. In addition, the previous article also mentioned that adding \mathop{\text{ColNorm}} to the power iteration (i.e., changing it to \boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\mathop{\text{ColNorm}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}))) stabilizes training, and this effect is even more pronounced under SCQR.
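The trade-off on \lambda can be checked numerically. The sketch below is a NumPy (float64) port of the SCQR above, with a few assumptions worth flagging: `np.linalg.cholesky` returns the lower factor \boldsymbol{L}=\boldsymbol{R}^{\top} and raises `LinAlgError` on failure (instead of producing NaNs as JAX does), so the fallback is a `try`/`except`; a general `np.linalg.solve` stands in for the triangular solver; and `orth_err` is a hypothetical helper. In exact arithmetic \boldsymbol{Q}^{\top}\boldsymbol{Q} has eigenvalues \sigma_i^2/(\sigma_i^2+\lambda), where \sigma_i are the singular values of \boldsymbol{A}, so the loss of orthogonality is about \lambda/(\sigma_{\min}^2+\lambda).

```python
import numpy as np
# NumPy (float64) port of the JAX SCQR; orth_err is a hypothetical helper.

def shift(X, eps):
    """X + eps * ||X||_F * I (the Shift regularization)."""
    return X + eps * np.linalg.norm(X) * np.eye(X.shape[-1])

def scqr(A, eps=1e-9):
    """Shifted Cholesky QR with a Householder fallback."""
    try:
        L = np.linalg.cholesky(shift(A.T @ A, eps))   # L = R^T, lower triangular
    except np.linalg.LinAlgError:
        return np.linalg.qr(A)[0]                     # fall back to standard QR
    return np.linalg.solve(L, A.T).T                  # Q = A R^{-1}

def orth_err(Q):
    """Spectral-norm distance of Q^T Q from the identity."""
    return np.linalg.norm(Q.T @ Q - np.eye(Q.shape[1]), 2)

# Test matrix with singular values from 1 down to 1e-3 (condition number 1e3).
rng = np.random.default_rng(0)
U = np.linalg.qr(rng.standard_normal((64, 16)))[0]
V = np.linalg.qr(rng.standard_normal((16, 16)))[0]
A = U @ np.diag(np.logspace(0, -3, 16)) @ V.T

for eps in (1e-2, 1e-5, 1e-8):   # the orthogonality error shrinks as eps shrinks
    print(f"eps={eps:.0e}  orthogonality error={orth_err(scqr(A, eps)):.2e}")
```

All three shifts succeed here in float64; the point is that larger shifts buy robustness at a visible cost in orthogonality.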

Full Precision

The above summarizes the first article. That article got the whole pipeline running and verified its feasibility, and SCQR did give some speedup over directly calling the framework’s built-in QR decomposition, but it is still noticeably slower than \mathop{\text{msign}} implemented via Newton-Schulz iteration. So we still need to find ways to speed it up.

This section introduces the first acceleration technique: enabling “full” FP32 precision matrix multiplication. First, note that the new steps added by streaming power iteration are all computed in FP32. However, since the A100 introduced the TF32 format, some frameworks (such as Jax, which the author uses for small experiments, or certain versions of Torch) by default convert FP32 inputs to TF32 inside matrix multiplications for speed, so true FP32 multiplication has to be enabled manually.

Some readers may wonder: shouldn’t higher multiplication precision slow things down? How can it speed things up instead? This is indeed counter-intuitive, but not hard to understand. Lower-precision multiplication introduces larger rounding errors, which often effectively increases the condition number (small eigenvalues of the Gram matrix can be pushed to zero or below). SCQR then fails more often and falls back to the slow standard QR more often, which increases the total time. Conversely, higher precision increases the success rate of SCQR, and since QR is precisely the most time-consuming part, the total time actually becomes shorter.

As far as the author knows, Jax has always defaulted to TF32 for FP32 multiplication, so full FP32 must be enabled manually via jax.config.update('jax_default_matmul_precision', 'highest'). Torch is more complicated: versions 1.7 to 1.11 default to TF32 multiplication, while from 1.12 onward the default is FP32. Given that Torch is now at version 2.11, most users probably no longer need to enable it manually.
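For convenience, the settings above can be wrapped in a small startup helper. This is a sketch under assumptions: the function name `enable_full_fp32_matmul` is hypothetical, the JAX line is the one quoted in the text, and the PyTorch call `torch.set_float32_matmul_precision("highest")` is PyTorch's documented switch for keeping float32 matmuls out of TF32, offered here as an explicit alternative rather than something the original uses. Each framework is only touched if it is installed.

```python
def enable_full_fp32_matmul():
    """Request true FP32 matmuls from JAX and/or PyTorch (hypothetical helper)."""
    enabled = []
    try:
        import jax
        # The JAX setting quoted above; set it at startup, before jit-compiling the optimizer.
        jax.config.update("jax_default_matmul_precision", "highest")
        enabled.append("jax")
    except ImportError:
        pass
    try:
        import torch
        # "highest" keeps float32 matmuls in full FP32 instead of TF32.
        torch.set_float32_matmul_precision("highest")
        enabled.append("torch")
    except ImportError:
        pass
    return enabled

print(enable_full_fp32_matmul())
```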

Double Orthogonalization

The second acceleration technique the author came up with is to add an extra orthogonalization step on the left factor \boldsymbol{M}_t\boldsymbol{V}_{t-1} (which approximates the left singular vectors), i.e., to change the power iteration step to \boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})) This step is also counter-intuitive: adding one more \mathop{\text{QR}} actually makes the overall computation faster. The reason is similar to the previous section: it reduces the condition number of the matrix to be decomposed, which raises the success rate of SCQR. Since SCQR itself is very fast, running it twice adds little time, whereas greatly reducing the number of fallbacks to standard QR saves a lot.
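As a sketch of where the extra \mathop{\text{QR}} sits in the full update, here is the streaming step with double orthogonalization. Assumptions: NumPy in float64 rather than JAX, a NumPy SCQR with a Householder fallback standing in for the JAX `scqr` above, and a hypothetical helper name `streaming_step_double`; only the \boldsymbol{V} line differs from the single-QR step.

```python
import numpy as np
# NumPy (float64) sketch; streaming_step_double is a hypothetical helper name.

def scqr(A, eps=1e-9):
    """Shifted Cholesky QR; falls back to Householder QR if Cholesky fails."""
    G = A.T @ A
    G = G + eps * np.linalg.norm(G) * np.eye(G.shape[0])   # the Shift
    try:
        L = np.linalg.cholesky(G)                           # G = L L^T, i.e. R = L^T
    except np.linalg.LinAlgError:
        return np.linalg.qr(A)[0]
    return np.linalg.solve(L, A.T).T                        # Q = A R^{-1}

def col_norm(X, eps=1e-12):
    """Rescale every column of X to unit L2 norm."""
    return X / (np.linalg.norm(X, axis=0, keepdims=True) + eps)

def streaming_step_double(W, M, G, V, lr, beta=0.95, wd=0.0):
    """Streaming update with the extra QR on M_t V_{t-1} (hypothetical helper)."""
    M = beta * M + G
    V = scqr(M.T @ scqr(M @ V))      # V_t = QR(M_t^T QR(M_t V_{t-1}))
    U = col_norm(M @ V)              # U_t = ColNorm(M_t V_t)
    W = W - lr * (U @ V.T + wd * W)
    return W, M, V
```

In exact arithmetic this produces the same \boldsymbol{V}_t as the single-QR step up to column signs; only the matrices handed to Cholesky change.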

Understanding this acceleration technique involves two steps: first, adding this extra \mathop{\text{QR}} theoretically does not change the power iteration; second, it does reduce the condition number. The first point is easy to see. If \boldsymbol{A}=\boldsymbol{Q}\boldsymbol{R}, then \boldsymbol{Q}=\boldsymbol{A}\boldsymbol{R}^{-1}, where \boldsymbol{R}^{-1} is also an upper triangular matrix. That is, QR decomposition can be written as right-multiplication by an upper triangular matrix. Then \boldsymbol{M}_t^{\top}\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}) = \boldsymbol{M}_t^{\top}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}\times \text{some upper triangular matrix}) = \boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}\times \text{some upper triangular matrix} Moreover, if \boldsymbol{X}=\boldsymbol{Q}\boldsymbol{R} and \boldsymbol{T} is upper triangular, then \boldsymbol{X}\boldsymbol{T}=\boldsymbol{Q}(\boldsymbol{R}\boldsymbol{T}) is again a QR decomposition. By the uniqueness of QR decomposition (with \boldsymbol{R} normalized to a positive diagonal, as Cholesky QR does automatically), right-multiplying by an upper triangular matrix with positive diagonal does not change the result of \mathop{\text{QR}}, so the modified step is theoretically equivalent to \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}).
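The equivalence argument can be checked numerically. This sketch assumes NumPy in float64 and uses a hypothetical helper `qr_pos` that normalizes the Householder QR so that \mathop{\text{diag}}(\boldsymbol{R})>0, which is the normalization under which the \boldsymbol{Q} factor is unique (Cholesky QR satisfies it automatically).

```python
import numpy as np
# NumPy (float64) check; qr_pos is a hypothetical helper.

def qr_pos(A):
    """Householder QR normalized to diag(R) > 0, which makes the Q factor unique."""
    Q, R = np.linalg.qr(A)
    return Q * np.sign(np.diag(R))   # flip column signs so that diag(R) becomes positive

rng = np.random.default_rng(1)
M = rng.standard_normal((10, 6))
V = qr_pos(rng.standard_normal((6, 6)))

V_single = qr_pos(M.T @ M @ V)            # QR(M^T M V)
V_double = qr_pos(M.T @ qr_pos(M @ V))    # QR(M^T QR(M V))
print(np.abs(V_single - V_double).max())  # rounding-level difference
```

Without the sign normalization the two results would still agree, but only up to per-column signs.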

As for the condition number: it equals the ratio of the largest to the smallest singular value of a matrix. With a single \mathop{\text{QR}}, the matrix passed to the Cholesky decomposition is \boldsymbol{V}_{t-1}^{\top}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t)^2\boldsymbol{V}_{t-1}. Orthogonal transformations do not change singular values and hence do not change the condition number (taking \boldsymbol{V}_{t-1} to be a square orthogonal matrix), so the condition number of this matrix is the fourth power of that of \boldsymbol{M}_t! With two \mathop{\text{QR}}s, the first Cholesky input is \boldsymbol{V}_{t-1}^{\top}\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1} and the second is \boldsymbol{Q}_t^{\top}(\boldsymbol{M}_t\boldsymbol{M}_t^{\top})\boldsymbol{Q}_t, where \boldsymbol{Q}_t is the orthogonal matrix from the first \mathop{\text{QR}}. Both have a condition number equal to only the square of that of \boldsymbol{M}_t, a significant reduction.
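The fourth-power versus square claim is easy to verify on a matrix with known singular values. This sketch assumes NumPy in float64 and a square orthogonal \boldsymbol{V}_{t-1}; the target condition number `k` and the matrix shapes are arbitrary illustrative choices.

```python
import numpy as np
# NumPy (float64) sketch; M is built with known singular values so that cond(M) = k.
rng = np.random.default_rng(2)
k = 10.0
P = np.linalg.qr(rng.standard_normal((12, 6)))[0]
O = np.linalg.qr(rng.standard_normal((6, 6)))[0]
M = P @ np.diag(np.linspace(k, 1.0, 6)) @ O.T     # singular values k, ..., 1
V = np.linalg.qr(rng.standard_normal((6, 6)))[0]  # a square orthogonal V_{t-1}

A1 = M.T @ M @ V              # input of the single QR
B = M @ V                     # input of the first QR in the double scheme
Q = np.linalg.qr(B)[0]
A2 = M.T @ Q                  # input of the second QR

print(np.linalg.cond(A1.T @ A1))   # about k**4
print(np.linalg.cond(B.T @ B))     # about k**2
print(np.linalg.cond(A2.T @ A2))   # about k**2
```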

Translation Invariance

The third acceleration technique came out of discussions between the author and @YouJiacheng. It exploits the translation (shift) invariance of eigenmatrices. The power iteration \boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}) can also be understood as finding the eigenmatrix of the positive semidefinite matrix \boldsymbol{M}_t^{\top}\boldsymbol{M}_t. Adding a multiple of the identity matrix to such a symmetric matrix shifts every eigenvalue by the same amount and leaves the eigenmatrix unchanged.

In other words, \boldsymbol{M}_t^{\top}\boldsymbol{M}_t and \boldsymbol{M}_t^{\top}\boldsymbol{M}_t + \lambda \boldsymbol{I} have the same eigenmatrix, so we can change the power iteration to \boldsymbol{V}_t = \mathop{\text{QR}}((\boldsymbol{M}_t^{\top}\boldsymbol{M}_t + \lambda \boldsymbol{I})\boldsymbol{V}_{t-1}) = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1} + \lambda \boldsymbol{V}_{t-1}) without altering the result the power iteration converges to. What is the benefit of adding \lambda \boldsymbol{I} to \boldsymbol{M}_t^{\top}\boldsymbol{M}_t? Once again, it reduces the condition number: writing \sigma for the eigenvalues of \boldsymbol{M}_t^{\top}\boldsymbol{M}_t, we have (\sigma_{\max} + \lambda)/(\sigma_{\min} + \lambda) < \sigma_{\max}/\sigma_{\min}, which improves the success rate of Cholesky QR. Note that we mean plain Cholesky QR here, not SCQR: choosing an appropriate \lambda outside the QR already bounds the condition number, so the internal Shift is unnecessary, and the resulting matrix is orthogonal up to rounding error, which is another nice property.
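A minimal sketch of the shifted power step with plain (unshifted) Cholesky QR, assuming NumPy in float64; `cholesky_qr` and `shifted_power_step` are hypothetical names, and the default `eps=1e-4` is the value the author reports working well (\lambda = \epsilon\Vert\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\Vert_F). Note that \lambda is added to the power-iteration input, not inside the Cholesky factorization.

```python
import numpy as np
# NumPy (float64) sketch; function names are hypothetical.

def cholesky_qr(A):
    """Plain Cholesky QR; raises LinAlgError if A^T A is not numerically PD."""
    L = np.linalg.cholesky(A.T @ A)          # A^T A = L L^T, i.e. R = L^T
    return np.linalg.solve(L, A.T).T         # Q = A R^{-1}

def shifted_power_step(M, V, eps=1e-4):
    """V_t = QR(M^T M V_{t-1} + lam V_{t-1}) with lam = eps * ||M^T M||_F."""
    MtM = M.T @ M
    lam = eps * np.linalg.norm(MtM)          # Frobenius norm
    return cholesky_qr(MtM @ V + lam * V)
```

Because no shift enters the Cholesky step itself, the output has orthonormal columns up to rounding, and with \boldsymbol{M} held fixed the iteration converges to the same right singular vectors as the unshifted one.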

But don’t celebrate too early. A larger \lambda does make Cholesky QR more likely to succeed, but it also slows down the convergence of the power iteration! The convergence speed of power iteration on \boldsymbol{M}_t^{\top}\boldsymbol{M}_t depends on the ratio of adjacent eigenvalues: sorting the eigenvalues \sigma_i of \boldsymbol{M}_t^{\top}\boldsymbol{M}_t (the squared singular values of \boldsymbol{M}_t) in descending order, the smaller \sigma_{i+1}/\sigma_i is, the faster the convergence. However, for \lambda > 0 we have (\sigma_{i+1} + \lambda)/(\sigma_i + \lambda) > \sigma_{i+1}/\sigma_i, so the larger \lambda is, the slower the power iteration converges, and the final result deteriorates as well.
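Both effects of \lambda can be read off directly from the spectrum. In the sketch below (NumPy; the helper `tradeoff` and the example spectrum are made up for illustration), `eigs` are eigenvalues of \boldsymbol{M}_t^{\top}\boldsymbol{M}_t, \lambda = \epsilon\Vert\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\Vert_F, the condition number is (\sigma_{\max}+\lambda)/(\sigma_{\min}+\lambda), and the slowest convergence factor is the largest adjacent ratio (\sigma_{i+1}+\lambda)/(\sigma_i+\lambda).

```python
import numpy as np
# NumPy sketch; tradeoff() and the spectrum below are illustrative, not from the text.

def tradeoff(eigs, lam):
    """Condition number and slowest adjacent ratio of M^T M + lam * I.

    eigs: eigenvalues of M^T M, i.e. squared singular values of M, in any order.
    """
    e = np.sort(np.asarray(eigs, dtype=float))[::-1] + lam   # descending, shifted
    cond = e[0] / e[-1]                   # smaller helps Cholesky QR
    rate = np.max(e[1:] / e[:-1])         # closer to 1 means slower power iteration
    return cond, rate

eigs = np.array([1.0, 0.5, 0.1, 1e-3, 1e-6])   # made-up spectrum for illustration
norm_f = np.sqrt(np.sum(eigs ** 2))             # ||M^T M||_F from its eigenvalues
for eps in (0.0, 1e-4, 1e-2):
    cond, rate = tradeoff(eigs, eps * norm_f)
    print(f"eps={eps:g}  cond={cond:.3g}  slowest ratio={rate:.4f}")
```

As \epsilon grows the condition number falls while the slowest ratio creeps toward 1, which is exactly the tension described above.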

Therefore, \lambda must be tuned carefully to balance the success rate of Cholesky QR against the convergence speed of the power iteration. In the author’s tests, \lambda = \epsilon\Vert\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\Vert_F with \epsilon=10^{-4} works relatively well. Another approach is to use a larger \lambda to secure the success of Cholesky QR, and then run two iteration steps per training step to make up for the slower convergence, i.e., \boldsymbol{V}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\tilde{\boldsymbol{V}}_t + \lambda \tilde{\boldsymbol{V}}_t),\qquad \tilde{\boldsymbol{V}}_t = \mathop{\text{QR}}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1} + \lambda \boldsymbol{V}_{t-1}) This serves both Cholesky QR and the power iteration, at the cost of two \mathop{\text{QR}}s per step.
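The two-step variant can be sketched the same way. Assumptions: NumPy in float64, hypothetical names `cholesky_qr` and `two_step_shifted_power`, and `eps=1e-2` as a mere placeholder for “a larger \lambda” since the text gives no specific value.

```python
import numpy as np
# NumPy (float64) sketch; eps=1e-2 is a placeholder, not a value from the text.

def cholesky_qr(A):
    """Plain Cholesky QR: R = L^T from A^T A = L L^T, then Q = A R^{-1}."""
    L = np.linalg.cholesky(A.T @ A)
    return np.linalg.solve(L, A.T).T

def two_step_shifted_power(M, V, eps=1e-2):
    """Two shifted power-iteration steps per update, each ending in a Cholesky QR."""
    MtM = M.T @ M
    lam = eps * np.linalg.norm(MtM)
    V_tilde = cholesky_qr(MtM @ V + lam * V)              # tilde V_t
    return cholesky_qr(MtM @ V_tilde + lam * V_tilde)     # V_t
```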

Multi-step Correction

The fourth acceleration technique, which we call “SCQR2”, is a general multi-step correction for SCQR. Let us review the two steps of SCQR for a matrix \boldsymbol{A} to be decomposed: \begin{aligned} 1)\quad&\, \boldsymbol{R}^{\top}\boldsymbol{R}= \boldsymbol{A}^{\top}\boldsymbol{A} + \lambda \boldsymbol{I} &\,(\text{Cholesky decomposition of }\boldsymbol{A}^{\top}\boldsymbol{A}+\lambda\boldsymbol{I}) \\[5pt] 2)\quad&\, \boldsymbol{Q} = \boldsymbol{A}\boldsymbol{R}^{-1}&\,(\text{solve triangular linear system }\boldsymbol{Q}\boldsymbol{R}=\boldsymbol{A}) \end{aligned} The problem with SCQR is that the larger \lambda is, the more easily the Cholesky decomposition succeeds, but the less orthogonal \boldsymbol{Q} = \boldsymbol{A}\boldsymbol{R}^{-1} becomes. The idea of SCQR2 is to first run SCQR with a relatively large \lambda: the result is not yet orthogonal, but it is closer to orthogonal than the original \boldsymbol{A}, i.e., its condition number is smaller. We can then run SCQR again on this result with a smaller \lambda to restore orthogonality. A rough implementation is as follows:

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax import lax

def shift(A, eps=1e-9):
    return A + eps * jnp.linalg.matrix_norm(A, keepdims=True) * jnp.eye(A.shape[-1])

def scqr(A, eps=1e-9):
    """Shifted Cholesky QR (no fallback; a failure yields non-finite values)
    """
    R = jnp.linalg.cholesky(shift(A.mT @ A, eps), upper=True)
    return solve_triangular(R.mT, A.mT, lower=True).mT

def scqr2(A, eps1=1e-4, eps2=1e-8):
    """SCQR twice; if it fails, fall back to default QR
    """
    # Large shift first so Cholesky succeeds, then a small shift to restore orthogonality
    Q = scqr(scqr(A, eps1), eps2)
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])

As for why the two-step correction is valid: suppose the first SCQR yields \boldsymbol{Q}_1 = \boldsymbol{A}\boldsymbol{R}_1^{-1}. Although it deviates from orthogonality, it has the form “\boldsymbol{A}\times \text{upper triangular matrix}”. As noted earlier, right-multiplying by an upper triangular matrix does not change the QR result, so we may run another SCQR on \boldsymbol{Q}_1; the output \boldsymbol{Q}_2 = \boldsymbol{Q}_1\boldsymbol{R}_2^{-1} = \boldsymbol{A}\boldsymbol{R}_1^{-1}\boldsymbol{R}_2^{-1} is still \boldsymbol{A} times an upper triangular matrix. In principle, more correction steps could be applied in the same way.
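The effect of the two passes can be observed on an ill-conditioned matrix. This is a NumPy (float64) sketch of the `scqr2` above without the fallback; `orth_err` is a hypothetical helper, the test matrix is synthetic, and the defaults `eps1=1e-4`, `eps2=1e-8` simply mirror the JAX code rather than being tuned for float64.

```python
import numpy as np
# NumPy (float64) port of scqr/scqr2 without the fallback; orth_err is hypothetical.

def scqr(A, eps):
    """One shifted Cholesky QR pass: returns A R^{-1}."""
    G = A.T @ A
    L = np.linalg.cholesky(G + eps * np.linalg.norm(G) * np.eye(G.shape[0]))
    return np.linalg.solve(L, A.T).T

def scqr2(A, eps1=1e-4, eps2=1e-8):
    """Large shift first (robust), then a small shift to restore orthogonality."""
    return scqr(scqr(A, eps1), eps2)

def orth_err(Q):
    """Spectral-norm distance of Q^T Q from the identity."""
    return np.linalg.norm(Q.T @ Q - np.eye(Q.shape[1]), 2)

rng = np.random.default_rng(3)
U = np.linalg.qr(rng.standard_normal((64, 16)))[0]
V = np.linalg.qr(rng.standard_normal((16, 16)))[0]
A = U @ np.diag(np.logspace(0, -4, 16)) @ V.T    # condition number 1e4

Q1 = scqr(A, 1e-4)    # far from orthogonal, but much better conditioned than A
Q2 = scqr2(A)         # second pass with the smaller shift
print(orth_err(Q1), orth_err(Q2), np.linalg.cond(A), np.linalg.cond(Q1))
```

The first pass alone is nowhere near orthogonal, yet its output is far better conditioned than \boldsymbol{A}, which is what lets the second, lightly shifted pass finish the job.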

Method Summary

We have discussed four acceleration techniques above. Here is a brief summary of their characteristics.

The first technique is to use full-precision FP32 matrix multiplication. It applies universally: Jax requires enabling it manually, while newer versions of Torch enable it by default. The second, third, and fourth techniques are independent alternatives and cannot be stacked on one another. Intuitively, Technique 2 has the higher ceiling, because Techniques 3 and 4 both take \boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1} as input, whose condition number has already been amplified, and then try to remedy it. Technique 2, on the other hand, changes the input to \boldsymbol{M}_t^{\top}\mathop{\text{QR}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}), reducing the condition number at the source.

Interestingly, Techniques 2, 3, and 4 all seem to point toward two \mathop{\text{QR}}s: apart from Technique 3 with a carefully tuned \lambda, which can get by with one \mathop{\text{QR}}, all of them need at least two. Two \mathop{\text{QR}}s seem to be the most reliable choice. In terms of speed, Technique 3 is the fastest if it can be tuned to use only one \mathop{\text{QR}}; otherwise it is about as fast as Technique 2. Technique 4 is somewhat unstable: applying SCQR2 to \boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1} is very fast but gives poor results, while switching the input to \boldsymbol{M}_t^{\top}\mathop{\text{ColNorm}}(\boldsymbol{M}_t\boldsymbol{V}_{t-1}) secures the results at the cost of speed.

The author recommends combining Techniques 1 and 2, which is relatively reliable in both quality and efficiency. In standalone tests it runs at about half the speed of the Newton-Schulz iteration for \mathop{\text{msign}}. Some readers might think, “All that effort for only half the speed?” In fact this is already quite good, considering that everything is computed in FP32 and two \mathop{\text{QR}}s are needed. Moreover, the \mathop{\text{msign}} step accounts for only about 1% of end-to-end training time, so doubling it adds only about another 1%, which is still acceptable.

Furthermore, the efficiency of Newton-Schulz iteration depends on the number of iteration steps. If we use the coefficients from Polar Express to further increase the number of steps to improve precision, then the speed gap with our approach will further narrow. In summary, streaming power iteration is indeed slower, but it also yields richer and more accurate results (SVD), enabling more possibilities.

Conclusion

This article introduced techniques to further accelerate streaming power iteration. The essence is to find ways to reduce the condition number of the matrix, thereby increasing the success rate of Cholesky QR.

To reprint, please include the address of this article: https://kexue.fm/archives/11673

For more detailed reprint matters, please refer to: Science Space FAQ