English (unofficial) translations of posts at kexue.fm

Steepest Descent on Manifolds: 3. Muon + Stiefel

Translated by Gemini Flash 3.0 Preview. Translations can be inaccurate, please refer to the original post for important stuff.

As discussed previously, when we move the optimization target from vector parameters to matrix parameters and choose the spectral norm constraint, which is better suited to matrices, the Muon optimizer emerges naturally. We then considered the steepest descent direction when the parameters are additionally subject to an orthogonality constraint. This was split into two cases, square matrices and non-square matrices: the square case was solved in the previous article, while the non-square case remained open.

The goal of this article is to provide the solution for the non-square case, thereby fully resolving optimization under orthogonal constraints.

Task Information

Let’s briefly review the results from the previous article "Steepest Descent on Manifolds: 2. Muon + Orthogonality". The objective we want to solve is: \begin{equation} \max_{\boldsymbol{\Phi}} \operatorname{tr}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \|\boldsymbol{\Phi}\|_2 = 1,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,(\boldsymbol{W} - \eta \boldsymbol{\Phi})^{\top}(\boldsymbol{W} - \eta \boldsymbol{\Phi})=\boldsymbol{I} \end{equation} where \boldsymbol{W},\boldsymbol{\Phi}\in\mathbb{R}^{n\times m}(n \geq m), and \|\cdot\|_2 is the spectral norm. Based on the principle that "first-order approximation is sufficient," this can be simplified to: \begin{equation} \max_{\boldsymbol{\Phi}} \operatorname{tr}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \|\boldsymbol{\Phi}\|_2 = 1,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0} \label{eq:ori-obj} \end{equation} The set of all \boldsymbol{\Phi} satisfying \boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0} is also known as the "tangent space" of \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}. In the previous article, we derived the general form of the solution: \begin{equation} \boldsymbol{\Phi} = \operatorname{msign}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) \end{equation} where \boldsymbol{X}\in\mathbb{R}^{m\times m} is a symmetric matrix to be determined.
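The first-order simplification can be checked numerically. The sketch below uses random data with an arbitrary seed and sizes (assumptions for illustration only). It projects a random direction onto the tangent space by subtracting \boldsymbol{W}[\boldsymbol{W}^{\top}\boldsymbol{Z}]_{\text{sym}}, then confirms that \boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} vanishes and that the orthogonality defect of \boldsymbol{W}-\eta\boldsymbol{\Phi} is exactly \eta^2\boldsymbol{\Phi}^{\top}\boldsymbol{\Phi}, i.e. second order in \eta.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 8, 4

# W with orthonormal columns (W^T W = I), taken from a thin QR of a random matrix
w, _ = np.linalg.qr(rng.standard_normal((n, m)))

def sym(a):
    """Symmetric part [A]_sym."""
    return (a + a.T) / 2

# Removing W [W^T Z]_sym from an arbitrary Z gives a tangent direction:
# W^T Phi = W^T Z - [W^T Z]_sym = [W^T Z]_skew, so W^T Phi + Phi^T W = 0.
z = rng.standard_normal((n, m))
phi = z - w @ sym(w.T @ z)
tangent_residual = w.T @ phi + phi.T @ w

# For a tangent direction, (W - eta Phi)^T (W - eta Phi) - I = eta^2 Phi^T Phi,
# so the first-order constraint only leaves an O(eta^2) violation.
eta = 1e-3
w_step = w - eta * phi
defect = w_step.T @ w_step - np.eye(m)
print(np.abs(tangent_residual).max(), np.abs(defect - eta**2 * phi.T @ phi).max())
```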

The remaining task is to compute the symmetric matrix \boldsymbol{X} such that \boldsymbol{W}^{\top}\boldsymbol{\Phi} is skew-symmetric; the corresponding \boldsymbol{\Phi} is then the optimal solution. For n=m, we already obtained the closed-form solution \boldsymbol{X}=-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}. The real difficulty is the case n > m, where the set of \boldsymbol{W} satisfying \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I} is known as the "Stiefel manifold"; this case was left as an open problem in the "Orthogonal manifold" discussion.
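For n=m, the closed form is easy to confirm numerically. This sketch uses random data (an assumption) and an SVD-based msign that treats singular values below a relative threshold as zero (also an assumption, added only for robustness). It checks that \boldsymbol{X}=-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}} makes \boldsymbol{W}^{\top}\boldsymbol{\Phi} skew-symmetric and that \|\boldsymbol{\Phi}\|_2=1.

```python
import numpy as np

def msign(a, rtol=1e-10):
    """msign via SVD; singular values below rtol * max are treated as zero."""
    u, s, vh = np.linalg.svd(a, full_matrices=False)
    return (u * (s > rtol * s.max())) @ vh

def sym(a):
    return (a + a.T) / 2

rng = np.random.default_rng(1)
m = 4
w, _ = np.linalg.qr(rng.standard_normal((m, m)))  # square orthogonal W (n = m)
g = rng.standard_normal((m, m))

x = -sym(w.T @ g)              # closed-form X for the square case
phi = msign(g + w @ x)         # here G + W X = W [W^T G]_skew
wt_phi = w.T @ phi

print(np.abs(wt_phi + wt_phi.T).max())   # W^T Phi is skew-symmetric
print(np.linalg.norm(phi, 2))            # spectral norm of Phi
```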

Equation Transformation

Simply put, our current task is to solve the system of equations: \begin{equation} \boldsymbol{W}^{\top}\operatorname{msign}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})+\operatorname{msign}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\boldsymbol{W} = \boldsymbol{0} \label{eq:start} \end{equation} When n=m, \boldsymbol{W} is orthogonal, so \operatorname{msign}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) = \boldsymbol{W}\operatorname{msign}(\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X}); the factor \boldsymbol{W}^{\top} is absorbed into the \operatorname{msign} function, which makes the equation easy to solve. When n > m, this absorption is not possible, and that is where the difficulty lies. I suspect there is no simple explicit solution for n > m, so we seek a numerical algorithm instead.

By the definition \operatorname{msign}(\boldsymbol{M})=\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2} and \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}, we can write: \begin{equation} \boldsymbol{W}^{\top}\operatorname{msign}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) = \boldsymbol{W}^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})\boldsymbol{Q}^{-1} = (\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X})\boldsymbol{Q}^{-1} \end{equation} where \boldsymbol{Q} = ((\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}))^{1/2}. In this notation, and using the symmetry of \boldsymbol{X} and \boldsymbol{Q}, the system of equations becomes: \begin{equation} (\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X})\boldsymbol{Q}^{-1} + \boldsymbol{Q}^{-1}(\boldsymbol{G}^{\top}\boldsymbol{W} + \boldsymbol{X}) = \boldsymbol{0} \end{equation} Multiplying on the left and on the right by \boldsymbol{Q}, we get: \begin{equation} \boldsymbol{Q}(\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X}) + (\boldsymbol{G}^{\top}\boldsymbol{W} + \boldsymbol{X})\boldsymbol{Q} = \boldsymbol{0} \label{eq:r-x} \end{equation} and \boldsymbol{Q} can equivalently be written as: \begin{equation} \boldsymbol{Q} = (\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\operatorname{msign}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) \label{eq:r-q} \end{equation}
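Both identities in this derivation can be spot-checked numerically. In the sketch below (random \boldsymbol{W}, \boldsymbol{G} and an arbitrary symmetric \boldsymbol{X}, all assumptions for illustration), a thin SVD of \boldsymbol{Z}=\boldsymbol{G}+\boldsymbol{W}\boldsymbol{X} gives \operatorname{msign}(\boldsymbol{Z}); we then check that \boldsymbol{Z}^{\top}\operatorname{msign}(\boldsymbol{Z}) equals (\boldsymbol{Z}^{\top}\boldsymbol{Z})^{1/2} and that \boldsymbol{W}^{\top}\operatorname{msign}(\boldsymbol{Z})=(\boldsymbol{W}^{\top}\boldsymbol{G}+\boldsymbol{X})\boldsymbol{Q}^{-1}.

```python
import numpy as np

def sym(a):
    return (a + a.T) / 2

rng = np.random.default_rng(2)
n, m = 10, 4
w, _ = np.linalg.qr(rng.standard_normal((n, m)))   # W^T W = I
g = rng.standard_normal((n, m))
x = sym(rng.standard_normal((m, m)))               # an arbitrary symmetric X

z = g + w @ x
u, s, vh = np.linalg.svd(z, full_matrices=False)
phi = u @ vh                  # msign(Z), assuming Z has full column rank
q = z.T @ phi                 # eq. [eq:r-q]

q_sqrt = vh.T @ np.diag(s) @ vh                  # (Z^T Z)^{1/2}
lhs = w.T @ phi                                  # W^T msign(G + W X)
rhs = (w.T @ g + x) @ np.linalg.inv(q)           # (W^T G + X) Q^{-1}
print(np.abs(q - q_sqrt).max(), np.abs(lhs - rhs).max())
```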

Iterative Solution

My idea is to start from some initial value of \boldsymbol{X}, substitute it into equation [eq:r-q] to obtain \boldsymbol{Q}, then substitute \boldsymbol{Q} into the system [eq:r-x] to solve for a new \boldsymbol{X}, and iterate until convergence. Since \operatorname{msign} can be computed directly, equation [eq:r-q] is explicit, so the only difficulty is solving the system [eq:r-x].

We can rearrange equation [eq:r-x] as: \begin{equation} \boldsymbol{Q}\boldsymbol{X} + \boldsymbol{X}\boldsymbol{Q} = -2[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}} \label{eq:r-xx} \end{equation} Given \boldsymbol{Q}, this is a linear system in \boldsymbol{X}, known as the "continuous Lyapunov equation", which is a special case of the "Sylvester equation". If we only compute on the CPU, SciPy already provides a solver for it, scipy.linalg.solve_continuous_lyapunov, which can be called directly: it solves \boldsymbol{A}\boldsymbol{X}+\boldsymbol{X}\boldsymbol{A}^{H}=\boldsymbol{C}, so we pass \boldsymbol{Q} as the first argument and the right-hand side of [eq:r-xx] as the second.
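Putting the two steps together, a CPU-only version of the iteration might look like the sketch below. The function name `stiefel_direction_cpu`, the step count, and the random test data are illustrative assumptions, and msign is computed by SVD assuming full column rank. We only check properties that hold after any number of steps (symmetric \boldsymbol{X}, the Lyapunov residual, orthonormal columns of \boldsymbol{\Phi}); the printed tangent error shows how close the run gets to a fixed point.

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov

def sym(a):
    return (a + a.T) / 2

def msign(a):
    u, _, vh = np.linalg.svd(a, full_matrices=False)
    return u @ vh                     # assumes full column rank

def stiefel_direction_cpu(g, w, steps=30):
    """Hypothetical helper: fixed-point iteration X -> Q -> X using SciPy's Lyapunov solver."""
    x = -sym(w.T @ g)                 # square-case solution as the initial value
    for _ in range(steps):
        z = g + w @ x
        q = z.T @ msign(z)            # eq. [eq:r-q]
        # solve_continuous_lyapunov(a, c) solves a X + X a^H = c; here a = Q
        x = sym(solve_continuous_lyapunov(q, -2 * sym(q @ w.T @ g)))
    return msign(g + w @ x), x, q

rng = np.random.default_rng(3)
n, m = 20, 6
w, _ = np.linalg.qr(rng.standard_normal((n, m)))
g = rng.standard_normal((n, m))
phi, x, q = stiefel_direction_cpu(g, w)
print('tangent error:', np.abs(sym(w.T @ phi)).mean())
print('inner product:', (phi * g).sum())
```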

For the initial value, a natural choice is the square-case solution -[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}, which provides a smooth transition from square to non-square matrices. Another equivalent form of equation [eq:r-xx] also justifies this choice: \begin{equation} \boldsymbol{Q}(\boldsymbol{X} + [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}) + (\boldsymbol{X} + [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}})\boldsymbol{Q} =[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}\boldsymbol{Q} -\boldsymbol{Q}[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}} \end{equation} If [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}} and \boldsymbol{Q} commute, the right-hand side vanishes and -[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}} is the exact solution; in general, the closer they are to commuting, the more accurate this initial value is. In practice, however, the experiments show that the iteration is not particularly sensitive to the initial value; even starting from the zero matrix works fine.
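The rewritten form and the commuting special case can both be verified with small random matrices. In this sketch, a random symmetric positive definite matrix stands in for \boldsymbol{Q} and a random square matrix stands in for \boldsymbol{W}^{\top}\boldsymbol{G} (assumptions for illustration). For the commuting case we pick \boldsymbol{Q}=\boldsymbol{I}+\boldsymbol{K}^{\top}\boldsymbol{K} with \boldsymbol{K}=[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}, which commutes with \boldsymbol{K} because \boldsymbol{K}^{\top}\boldsymbol{K}=-\boldsymbol{K}^2.

```python
import numpy as np

def sym(a):
    return (a + a.T) / 2

def skew(a):
    return (a - a.T) / 2

rng = np.random.default_rng(4)
m = 5
b = rng.standard_normal((m, m))
q = b @ b.T + m * np.eye(m)            # stand-in for Q: symmetric positive definite
a = rng.standard_normal((m, m))        # stand-in for W^T G
x = sym(rng.standard_normal((m, m)))   # arbitrary symmetric X
s, k = sym(a), skew(a)

# eq. [eq:r-x] and its rewritten form agree for any symmetric X
lhs = q @ (a + x) + (a.T + x) @ q
rewritten = q @ (x + s) + (x + s) @ q - (k @ q - q @ k)

# If K commutes with Q, X = -[A]_sym solves eq. [eq:r-x] exactly
q_comm = np.eye(m) + k.T @ k
x0 = -s
residual = q_comm @ (a + x0) + (a.T + x0) @ q_comm
print(np.abs(lhs - rewritten).max(), np.abs(residual).max())
```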

Doing It Yourself

We mentioned that SciPy has a built-in Lyapunov equation solver, so one can call it without worrying about the solution process. However, this is limited to SciPy on the CPU. I checked, and neither PyTorch nor JAX has a similar function, so for GPU computation we must be "self-reliant."

There are two ways to program the solution of equation [eq:r-xx]. The first follows the idea in "What can the matrix sign function mcsgn calculate?" and uses \operatorname{mcsgn} (not \operatorname{msign}): \begin{equation} \boldsymbol{X} = \operatorname{mcsgn}\left(\begin{bmatrix}-\boldsymbol{Q} & -[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}} \\ \boldsymbol{0} & \boldsymbol{Q}\end{bmatrix}\right)_{[:m,m:]} \end{equation} where the subscript [:m,m:] takes the top-right m\times m block. The second is based on SVD, a method we used in "Derivative of msign" when computing the gradient of \operatorname{msign}; let us restate it for equation [eq:r-xx]. By definition, \boldsymbol{Q} is symmetric and positive semidefinite (positive definite when \boldsymbol{G} + \boldsymbol{W}\boldsymbol{X} has full column rank), so it can be eigendecomposed as \boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}, where \boldsymbol{V} is an orthogonal matrix and \boldsymbol{\Sigma}=\operatorname{diag}(\sigma_1,\cdots,\sigma_m) is a diagonal matrix. Substituting this into equation [eq:r-xx], we get: \begin{equation} \boldsymbol{\Sigma}(\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V}) + (\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V})\boldsymbol{\Sigma} = -2\boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V} \end{equation} The left side can be written as (\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V})\odot \boldsymbol{S}, where \odot is the Hadamard (elementwise) product and \boldsymbol{S}_{i,j} = \sigma_i + \sigma_j. From this, we can solve for \boldsymbol{X}: \begin{equation} \boldsymbol{X} = -2\boldsymbol{V}((\boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V})\oslash \boldsymbol{S})\boldsymbol{V}^{\top} \end{equation} where \oslash is Hadamard division. Interestingly, eigendecomposing \boldsymbol{Q} is essentially equivalent to performing an SVD of \boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}: if \boldsymbol{G} + \boldsymbol{W}\boldsymbol{X} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}, then \boldsymbol{Q} = \boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}. Since the same SVD also gives \operatorname{msign}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) = \boldsymbol{U}\boldsymbol{V}^{\top}, a single SVD yields both \operatorname{msign} and the solution of equation [eq:r-xx].
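The single-SVD route can be sketched as follows. With random \boldsymbol{W}, \boldsymbol{G} (an assumption) and \boldsymbol{X} set to the square-case initial value, one thin SVD of \boldsymbol{Z}=\boldsymbol{G}+\boldsymbol{W}\boldsymbol{X} provides both \operatorname{msign}(\boldsymbol{Z}) and the eigendecomposition of \boldsymbol{Q}. The Hadamard-division formula is then compared against SciPy's solver as a CPU reference.

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov

def sym(a):
    return (a + a.T) / 2

rng = np.random.default_rng(5)
n, m = 12, 5
w, _ = np.linalg.qr(rng.standard_normal((n, m)))
g = rng.standard_normal((n, m))
z = g + w @ (-sym(w.T @ g))                # Z = G + W X at the initial X

# One thin SVD Z = U diag(sigma) V^T gives msign(Z) = U V^T and Q = V diag(sigma) V^T
u, sigma, vh = np.linalg.svd(z, full_matrices=False)
phi = u @ vh
v = vh.T
q = v @ np.diag(sigma) @ v.T
rhs = -2 * sym(q @ w.T @ g)                # right-hand side of eq. [eq:r-xx]

s_mat = sigma[:, None] + sigma[None, :]    # S_ij = sigma_i + sigma_j
x_svd = v @ ((v.T @ rhs @ v) / s_mat) @ v.T    # Hadamard division by S
x_ref = solve_continuous_lyapunov(q, rhs)  # SciPy reference (CPU only)
print(np.abs(x_svd - x_ref).max())
```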

Each approach has its own trade-offs. The first requires computing \operatorname{msign} of an n\times m matrix and then \operatorname{mcsgn} of a 2m\times 2m matrix. Although both can be computed efficiently with Newton-Schulz iterations, the cost is still significant. In addition, one must choose coefficients that ensure convergence and high precision (the results in "Newton-Schulz iteration for the msign operator (Part 2)" are recommended); otherwise \operatorname{mcsgn} and \operatorname{msign} will not converge accurately, let alone \boldsymbol{X}.
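To illustrate the mcsgn route without tuning Newton-Schulz coefficients, the sketch below swaps in the classical Newton iteration for the matrix sign function, \boldsymbol{A}\leftarrow(\boldsymbol{A}+\boldsymbol{A}^{-1})/2. This iteration needs matrix inverses, so it is not the GPU-friendly Newton-Schulz scheme recommended above; it only demonstrates that the top-right block of the block-triangular \operatorname{mcsgn} solves the Lyapunov equation. The helper names are hypothetical, and \boldsymbol{Q} and the right-hand side are random stand-ins. For a general right-hand side \boldsymbol{C}, the off-diagonal block is \boldsymbol{C}/2, which reduces to the formula in the text when \boldsymbol{C}=-2[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}.

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov

def mcsgn_newton(a, tol=1e-13, max_iter=100):
    """Matrix sign via the classical Newton iteration A <- (A + A^{-1}) / 2.
    Swapped in for illustration; assumes no eigenvalues on the imaginary axis."""
    y = a.copy()
    for _ in range(max_iter):
        y_next = 0.5 * (y + np.linalg.inv(y))
        if np.linalg.norm(y_next - y) <= tol * np.linalg.norm(y_next):
            return y_next
        y = y_next
    return y

def solve_lyap_mcsgn(q, c):
    """Hypothetical helper: solve Q X + X Q = C as the top-right block of
    mcsgn([[-Q, C/2], [0, Q]])."""
    m = q.shape[0]
    block = np.block([[-q, c / 2], [np.zeros_like(q), q]])
    return mcsgn_newton(block)[:m, m:]

rng = np.random.default_rng(6)
m = 6
b = rng.standard_normal((m, m))
q = b @ b.T + np.eye(m)                # symmetric positive definite stand-in for Q
c = rng.standard_normal((m, m))
c = c + c.T                            # symmetric right-hand side
x = solve_lyap_mcsgn(q, c)
print(np.abs(q @ x + x @ q - c).max())
```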

The second approach requires SVD. Although SVD is more expensive and usually needs FP32 precision, in this problem each iteration needs only one SVD to obtain both \operatorname{msign} and \boldsymbol{X}, so the overall efficiency is acceptable. If only a few matrix parameters carry orthogonality constraints, SVD may be the simplest choice.

Testing

Below we test these methods in NumPy, together with the two approaches by @leloy (leloy_v1 and leloy_v2) for comparison. The main purpose is to verify the correctness of the methods themselves, so we implement \operatorname{msign} and \operatorname{mcsgn} exactly via SVD and eigendecomposition.

import numpy as np
import scipy as sp
import scipy.linalg  # makes sp.linalg available on older SciPy versions

def mcsgn(x):
    """Accurate calculation of mcsgn via eigendecomposition"""
    s, v = np.linalg.eig(x)
    return v @ np.diag(np.sign(s)) @ np.linalg.inv(v)

def msign(g):
    """Accurate calculation of msign via SVD"""
    u, s, vh = np.linalg.svd(g, full_matrices=False)
    return u @ np.diag(np.sign(s)) @ vh

def sym(x):
    """Symmetrization"""
    return (x + x.T) * 0.5

def skew(x):
    """Skew-symmetrization"""
    return (x - x.T) * 0.5

def proj(g, w):
    """Project onto the tangent space of the orthogonal manifold"""
    return g - w @ sym(w.T @ g)

def jianlin_by_mcsgn(g, w, steps=20):
    """Construct the iteration in this article via mcsgn"""
    n, m = g.shape
    x = -sym(w.T @ g)
    for i in range(1, steps + 1):
        phi = msign(z := g + w @ x)
        print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
        if i == steps:
            return phi
        q = z.T @ phi
        x = mcsgn(np.block([[-q, -sym(q @ w.T @ g)], [np.zeros_like(q), q]]))[:m, m:]
        # x = -2 * sp.linalg.solve_continuous_lyapunov(q, sym(q @ w.T @ g))

def jianlin_by_svd(g, w, steps=20):
    """Construct the iteration in this article via SVD"""
    x = -sym(w.T @ g)
    for i in range(1, steps + 1):
        u, s, vh = np.linalg.svd(z := g + w @ x, full_matrices=False)
        phi = (u * np.sign(s)) @ vh
        print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
        if i == steps:
            return phi
        x = -2 * vh.T @ (vh @ sym(z.T @ phi @ w.T @ g) @ vh.T / (s + s[:, None])) @ vh

def leloy_v1(g, w, steps=20):
    """Alternating projection between tangent space and orthogonal space"""
    phi = g
    for i in range(1, steps + 1):
        phi = msign(proj(phi, w))
        print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
    return phi

def leloy_v2(g, w, steps=20):
    """Partial greedy solution + line search (simplified by the author)"""
    n, m = g.shape
    taus = np.linspace(0, 1, steps + 2)[1:-1]
    p_max, tau_opt, phi_opt = 0, 0, None
    for tau in taus:
        b = (b := skew(w.T @ g)) * tau / max(np.linalg.norm(b, ord=2), 1e-8)
        r = np.linalg.cholesky(np.eye(m) - b.T @ b)
        c = msign((np.eye(n) - w @ w.T) @ g @ r) @ r
        phi = w @ b + c
        print('tau:', tau, ', inner product:', p := (phi * g).sum())
        if p > p_max:
            p_max, tau_opt, phi_opt = p, tau, phi
    print('best inner product:', p_max, ', tau:', tau_opt)
    return phi_opt

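# Test 1: a hand-picked W (orthonormal columns) and G chosen to highlight the differences between methods
w = np.array([[ 0.69453734, -0.26590866, -0.44721806,  0.2753041 ],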
              [-0.11738148, -0.5588003 , -0.17580748,  0.3218624 ],
              [-0.4515288 , -0.23489913, -0.26683152, -0.25739142],
              [ 0.02392521,  0.02664689,  0.48423648,  0.6193399 ],
              [ 0.45194831, -0.25206333,  0.27654836, -0.60242337],
              [ 0.21197332, -0.09174792,  0.24521762, -0.08484317],
              [-0.15496767, -0.26446804, -0.34942415, -0.01877318],
              [-0.16181251, -0.6474956 ,  0.45243263, -0.01776086]])

g = np.array([[-17.85745   , -10.758921  ,  -2.9583392 ,   6.245008  ],
              [-28.883093  ,  19.772121  ,   8.086545  , -21.564013  ],
              [ -1.6274693 , -14.96859   ,   3.4465332 ,   3.1070817 ],
              [ -7.8890743 ,   1.5304767 ,  -8.949573  ,   9.579629  ],
              [  2.246596  ,  14.46572   ,  12.8451    ,  -2.7370298 ],
              [ -0.9496974 ,   6.9879804 ,   2.849277  ,   1.1148484 ],
              [ -8.115278  , -18.054405  ,  -0.19287404,   7.0389237 ],
              [-15.062008  , -15.02901   ,   2.9083247 ,  21.706533  ]])

phi1 = jianlin_by_mcsgn(g, w, steps=100)
phi2 = jianlin_by_svd(g, w, steps=100)
phi3 = leloy_v1(g, w, steps=100)
phi4 = leloy_v2(g, w, steps=100)
assert np.allclose(phi1, phi2)

# Test 2: random W with orthonormal columns (via thin QR) and random G
w = np.linalg.qr(np.random.randn(100, 50))[0]
g = np.random.randn(100, 50)

phi1 = jianlin_by_mcsgn(g, w, steps=10)
phi2 = jianlin_by_svd(g, w, steps=10)
phi3 = leloy_v1(g, w, steps=10)
phi4 = leloy_v2(g, w, steps=10)
assert np.allclose(phi1, phi2)

For the first set of \boldsymbol{W},\boldsymbol{G} given in the code, my method yields an optimal \operatorname{tr}(\boldsymbol{G}^{\top} \boldsymbol{\Phi}) of approximately 90, and the results from \operatorname{mcsgn} and SVD are identical. @leloy’s first method yields approximately 70, and the second method yields approximately 80, both falling short of the optimal solution.

However, the first set of \boldsymbol{W},\boldsymbol{G} was deliberately chosen to highlight the differences between the three methods. With random values (as in the second test), my solution and @leloy’s first solution are very close, and far fewer iterations suffice (5–10 steps); in these cases, @leloy’s second solution deviates more from the optimum. Readers can construct their own examples to test this.

Further Thoughts

Regarding the solution to the original problem [eq:ori-obj], we’ll pause here. Next, let’s discuss a few potentially confusing details.

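First, for convenience of description, the iterative procedure above assumes that \boldsymbol{G} + \boldsymbol{W}\boldsymbol{X} always has full rank (rank m); otherwise \boldsymbol{S} would have zero entries and the division \oslash\boldsymbol{S} would be ill-defined. This difficulty is not fundamental, because equation [eq:start] must have a solution: whenever a denominator is zero, the corresponding numerator is zero as well. Indeed, \boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V} = [\boldsymbol{\Sigma}\boldsymbol{V}^{\top}\boldsymbol{W}^{\top}\boldsymbol{G}\boldsymbol{V}]_{\text{sym}}, whose (i,j) entry vanishes when \sigma_i=\sigma_j=0, which is exactly when \boldsymbol{S}_{i,j}=0. We can therefore simply replace the zero entries of \boldsymbol{S} with a small positive number and still obtain a correct result.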
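A sketch of this safeguard, assuming a rank-deficient positive semidefinite stand-in for \boldsymbol{Q} and a right-hand side with the structure -2[\boldsymbol{Q}\boldsymbol{B}]_{\text{sym}} of eq. [eq:r-xx], where the random \boldsymbol{B} stands in for \boldsymbol{W}^{\top}\boldsymbol{G}. The helper name and the relative floor of 1e-8 are illustrative choices, not tuned values. Because the numerators vanish wherever \boldsymbol{S} does, the floored division still satisfies the equation.

```python
import numpy as np

def sym(a):
    return (a + a.T) / 2

def solve_lyap_psd(q, c, rel_floor=1e-8):
    """Hypothetical helper: solve Q X + X Q = C for symmetric PSD Q, replacing
    (near-)zero entries of S_ij = sigma_i + sigma_j with a small positive floor."""
    sigma, v = np.linalg.eigh(q)
    sigma = np.clip(sigma, 0.0, None)            # drop tiny negative round-off
    s_mat = sigma[:, None] + sigma[None, :]
    s_mat = np.maximum(s_mat, rel_floor * sigma.max())
    return v @ ((v.T @ c @ v) / s_mat) @ v.T

rng = np.random.default_rng(7)
m = 6
a = rng.standard_normal((m, 3))
q = a @ a.T                                      # rank-deficient PSD stand-in for Q (rank 3)
bmat = rng.standard_normal((m, m))               # stand-in for W^T G
c = -2 * sym(q @ bmat)                           # right-hand side shaped like eq. [eq:r-xx]
x = solve_lyap_psd(q, c)
print(np.isfinite(x).all(), np.abs(q @ x + x @ q - c).max())
```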

In numerical computation we rarely encounter singular values that are exactly zero, so there is little need to worry; we can simply assume \boldsymbol{G} + \boldsymbol{W}\boldsymbol{X} has full rank. Under this assumption, the retraction operation becomes very simple because: \begin{equation} (\boldsymbol{W} - \eta\boldsymbol{\Phi})^{\top}(\boldsymbol{W} - \eta\boldsymbol{\Phi}) = \boldsymbol{W}^{\top} \boldsymbol{W} - \eta(\boldsymbol{W}^{\top} \boldsymbol{\Phi} + \boldsymbol{\Phi}^{\top}\boldsymbol{W}) + \eta^2 \boldsymbol{\Phi}^{\top}\boldsymbol{\Phi} \end{equation} By the definition of the Stiefel manifold, the first term on the right is \boldsymbol{I}. By the tangent space condition, the second term is \boldsymbol{0}. Finally, in the full-rank case the output of \operatorname{msign} also lies on the Stiefel manifold, so the third term is \eta^2 \boldsymbol{I}. The total is (1+\eta^2)\boldsymbol{I}, so retraction amounts to dividing by \sqrt{1+\eta^2}: \begin{equation} \boldsymbol{W}\quad\leftarrow\quad\frac{\boldsymbol{W} - \eta\boldsymbol{\Phi}}{\sqrt{1+\eta^2}} \end{equation}
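The retraction can be checked on a directly constructed direction with the two properties used above: \boldsymbol{W}^{\top}\boldsymbol{\Phi} skew-symmetric and \boldsymbol{\Phi}^{\top}\boldsymbol{\Phi}=\boldsymbol{I}. Here, as an assumption for illustration, \boldsymbol{\Phi}=\boldsymbol{W}\boldsymbol{B}+\boldsymbol{W}_{\perp}\boldsymbol{R} with a random skew \boldsymbol{B} of spectral norm 0.5 and \boldsymbol{R}^{\top}\boldsymbol{R}=\boldsymbol{I}-\boldsymbol{B}^{\top}\boldsymbol{B}, rather than the output of the iteration. This construction requires n\geq 2m.

```python
import numpy as np

rng = np.random.default_rng(8)
n, m, eta = 10, 4, 0.1
basis, _ = np.linalg.qr(rng.standard_normal((n, n)))   # random orthogonal n x n matrix
w, w_perp = basis[:, :m], basis[:, m:2 * m]              # W and m directions orthogonal to it

# Phi with W^T Phi = B (skew-symmetric) and Phi^T Phi = B^T B + R^T R = I
k = rng.standard_normal((m, m))
b = (k - k.T) / 2
b *= 0.5 / np.linalg.norm(b, 2)                          # ||B||_2 = 0.5, so I - B^T B is PD
r = np.linalg.cholesky(np.eye(m) - b.T @ b).T            # R^T R = I - B^T B
phi = w @ b + w_perp @ r

w_new = (w - eta * phi) / np.sqrt(1 + eta**2)            # retraction by rescaling
print(np.abs(w_new.T @ w_new - np.eye(m)).max())
```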

At this point, a more fundamental question arises: whether we work on the relatively simple orthogonal manifold or on the more complex Stiefel manifold, what numerical precision should the computation use? "Orthogonality" is a precise quantitative constraint: \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I} consists of m(m+1)/2 equality constraints. One can expect that iterating the above formula in low precision will eventually deviate seriously from orthogonality, on top of the errors incurred while solving for \boldsymbol{\Phi}.

Therefore, I believe that unless we periodically apply an orthogonalization operation (i.e., \boldsymbol{W}\leftarrow\operatorname{msign}(\boldsymbol{W})) to pull the parameters back to the orthogonal manifold, the calculation precision during the solution process should be at least FP32. Since the number of parameters requiring orthogonal constraints is usually not large, this is generally not too high a price to pay.
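A toy sketch of the precision issue and the periodic fix follows. As deliberately extreme assumptions, \boldsymbol{W} is stored in float16 between steps, and each step uses a constructed direction (via the hypothetical helper `feasible_phi`) instead of solving for \boldsymbol{\Phi}. At the end, \boldsymbol{W}\leftarrow\operatorname{msign}(\boldsymbol{W}) is applied in float64. We only assert that msign restores orthogonality; the size of the float16 error is whatever the run produces.

```python
import numpy as np

def msign(a):
    u, _, vh = np.linalg.svd(a, full_matrices=False)
    return u @ vh

def ortho_error(w):
    return np.abs(w.T @ w - np.eye(w.shape[1])).max()

def feasible_phi(w, rng):
    """Hypothetical helper: Phi = W B + W_perp R with B skew, R^T R = I - B^T B (needs n >= 2m)."""
    n, m = w.shape
    basis, _ = np.linalg.qr(w, mode='complete')      # basis[:, m:] is orthogonal to range(W)
    k = rng.standard_normal((m, m))
    b = (k - k.T) / 2
    b *= 0.5 / np.linalg.norm(b, 2)
    r = np.linalg.cholesky(np.eye(m) - b.T @ b).T
    return w @ b + basis[:, m:2 * m] @ r

rng = np.random.default_rng(9)
n, m, eta = 16, 4, 0.05
w = np.linalg.qr(rng.standard_normal((n, m)))[0].astype(np.float16)   # low-precision storage
for _ in range(200):
    w64 = w.astype(np.float64)
    w = ((w64 - eta * feasible_phi(w64, rng)) / np.sqrt(1 + eta**2)).astype(np.float16)

drifted = ortho_error(w.astype(np.float64))
restored = ortho_error(msign(w.astype(np.float64)))   # W <- msign(W) in float64
print('float16 storage:', drifted, ' after msign:', restored)
```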

Summary

This article generalizes "Muon + Orthogonal Manifold" from the previous post to the more general "Muon + Stiefel Manifold". The main result is an iterative algorithm for computing the corresponding update direction.

Original address: https://kexue.fm/archives/11221