English (unofficial) translations of posts at kexue.fm

Newton-Schulz Iteration for the msign Operator (Part 2)

Translated by Gemini Flash 3.0 Preview. Translations can be inaccurate; please refer to the original post for anything important.

In the previous post "Newton-Schulz Iteration for the msign Operator (Part 1)", we attempted to find a better Newton-Schulz iteration for the \operatorname{msign} operator to achieve the highest possible approximation within a limited number of iteration steps. This process can be transformed into finding a polynomial iteration of the same form for the scalar function \operatorname{sign}(x). At that time, our approach was to use the Adam optimizer to find a local optimal solution in an end-to-end manner, which was effective but somewhat crude.

A few days ago, a new paper appeared on arXiv titled "The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm". The authors utilized a series of exquisite mathematical conclusions to provide a more elegant and hardcore answer. In this article, let us appreciate and learn from this brilliant paper.

Problem Description

We will not repeat the background and transformation process; we directly state the problem to be solved as: \begin{equation} \operatorname{argmin}_f d(f(x), 1) \end{equation} where f = f_T \circ \dots \circ f_2 \circ f_1, \circ represents function composition, f_t(x) is an odd polynomial (containing only odd powers of x), and d(f(x), 1) is a metric measuring the distance between the function f(x) and 1. In the previous article, we uniformly selected a finite number of points in [0, 1] and took the average of the k largest values of |f(x) - 1| as the metric. In this article, we directly take the maximum value of |f(x) - 1| within the interval as the metric, i.e., \begin{equation} \operatorname{argmin}_f \max_{x \in [l, u]} |f(x) - 1| \label{eq:opt} \end{equation} where [l, u] \subset [0, 1]. Note that while u can be taken as 1, l cannot be 0 because f(0) is always 0, which means the above expression would always be at least 1 and unable to converge. Therefore, l must be a number very close to 0. According to the analysis in the previous article, for universality, we should account for singular values as small as 0.001, so we can consider l = 0.001.
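To make the objective concrete, here is a small numerical sketch of the minimax metric above. The coefficients are arbitrary placeholders for illustration only, not the optimal ones derived later in the post:

```python
import numpy as np

def compose_odd_polys(coeff_list, x):
    """Apply f = f_T ∘ ... ∘ f_1, where each f_t(x) = a x + b x^3 + c x^5."""
    for a, b, c in coeff_list:
        x = a * x + b * x**3 + c * x**5
    return x

# Placeholder coefficients, repeated for T = 5 steps (not the optimal ones).
coeffs = [(3.0, -3.0, 1.0)] * 5

l, u = 0.001, 1.0
grid = np.linspace(l, u, 100001)
max_err = np.max(np.abs(compose_odd_polys(coeffs, grid) - 1))
print(max_err)  # the quantity the optimization drives down
```

The grid maximum is only an estimate of the true supremum, but it is the natural discretized version of the objective.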

Before starting the analysis, let’s briefly explain the meaning of the word "Polar" in the paper’s title. It actually represents the "Polar Decomposition" of a matrix:

Polar Decomposition: For a square matrix \boldsymbol{M} \in \mathbb{R}^{n \times n}, its polar decomposition is \boldsymbol{M} = \boldsymbol{Q}\boldsymbol{S}, where \boldsymbol{Q} is an orthogonal matrix and \boldsymbol{S} is a positive semi-definite matrix.

If the SVD of \boldsymbol{M} is \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}, then we have: \begin{equation} \boldsymbol{M} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} = (\boldsymbol{U}\boldsymbol{V}^{\top})(\boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}) \end{equation} Thus \boldsymbol{Q} = \boldsymbol{U}\boldsymbol{V}^{\top} and \boldsymbol{S} = \boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} is one solution for the polar decomposition. We know that when \boldsymbol{M} is a full-rank matrix, \boldsymbol{U}\boldsymbol{V}^{\top} is exactly \operatorname{msign}(\boldsymbol{M}). This is why \operatorname{msign} is associated with "Polar"; once it is calculated, the "Polar Decomposition" of the matrix can be obtained. In other words, the essential difficulty of polar decomposition is calculating \operatorname{msign}, which is consistent with the Muon algorithm.
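This identity is easy to verify numerically. A sketch with NumPy, computing the polar factors directly from the SVD:

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))  # a generic matrix, full rank almost surely

# Polar decomposition via SVD: M = (U V^T)(V Σ V^T) = Q S
U, sigma, Vt = np.linalg.svd(M)
Q = U @ Vt                       # orthogonal factor = msign(M) for full-rank M
S = Vt.T @ np.diag(sigma) @ Vt   # symmetric positive semi-definite factor

assert np.allclose(Q @ S, M)                    # M = Q S
assert np.allclose(Q.T @ Q, np.eye(5))          # Q is orthogonal
assert np.all(np.linalg.eigvalsh(S) >= -1e-12)  # S is PSD
```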

Greedy is Enough

Returning to the main topic. For problem \eqref{eq:opt}, the first conclusion of the original paper, and arguably its most central one, is: the greedy solution is exactly the global optimal solution! Formally, this means that the solution to problem \eqref{eq:opt} can be obtained as: \begin{equation} \begin{gathered} f^* = f_T^* \circ \dots \circ f_2^* \circ f_1^* \\[12pt] f_1^* = \operatorname{argmin}_{f_1} \max_{x \in [l_1, u_1]} |f_1(x) - 1| \\ f_2^* = \operatorname{argmin}_{f_2} \max_{x \in [l_2, u_2]} |f_2(x) - 1| \\ \vdots \\ f_T^* = \operatorname{argmin}_{f_T} \max_{x \in [l_T, u_T]} |f_T(x) - 1| \\[24pt] l_1 = l, \quad u_1 = u, \\[8pt] l_{t+1} = \min_{x \in [l_t, u_t]} f_t^*(x), \quad u_{t+1} = \max_{x \in [l_t, u_t]} f_t^*(x) \end{gathered} \end{equation}

I believe this conclusion will surprise many readers; I was also quite astonished and struck by its brilliance when I first saw it. It not only greatly reduces the difficulty of the solution—transforming a T-step composite function problem into a step-by-step solution for single polynomials—but also allows us to push the solution forward step by step while maintaining optimality (i.e., the optimal solution for T+1 steps only requires calculating one more step based on the T-step optimal solution, without recalculating from scratch).

It is worth noting that this conclusion allows each f_t to have a different degree (where "degree" refers to the highest power of the polynomial). For example, f_1 could be 3rd degree, f_2 could be 5th degree, and so on, yet the conclusion that "the greedy solution is the global optimal solution" remains unchanged. However, for simplicity, we will keep all f_t at the same degree below, primarily considering 3rd and 5th-degree results.

The complete proof of this conclusion is slightly complex, so we will place it at the end and first complete the subsequent operations based on this conclusion.
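To see the greedy recursion in action, here is a sketch that composes degree-3 steps, using the closed-form single-step optimum derived in the "Solving the Equations" section below (the helper name is ours, not from the paper):

```python
import numpy as np

def greedy_cubic_composition(l, u, T):
    """Greedy T-step composition of optimal cubics (the n = 1 closed form)."""
    steps = []
    for _ in range(T):
        x1 = np.sqrt((l**2 + l*u + u**2) / 3)
        k = -6 / (l**2 * u + l * u**2 + 2 * x1**3)
        f = lambda x, k=k, x1=x1: k * (x**3 / 3 - x1**2 * x)
        steps.append(f)
        # Range of f on [l, u]: minimum at l, maximum at the interior extremum x1
        l, u = f(l), f(x1)
    return steps, l, u

steps, l_final, u_final = greedy_cubic_composition(0.001, 1.0, 12)
print(1 - l_final)  # max error of the 12-step greedy cubic composition
```

Each iteration only needs the previous interval [l_t, u_t], which is exactly why the T+1-step optimum extends the T-step one.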

Equioscillation

Since we have transformed the original problem into finding greedy solutions, we now only need to focus on solving: \begin{equation} \operatorname{argmin}_{f_t} \max_{x \in [l_t, u_t]} |f_t(x) - 1| \label{eq:local} \end{equation} To solve the above, we need to understand the "Equioscillation Theorem" for odd polynomials introduced in "Equioscillation Theorem: Necessary and Sufficient Conditions for Optimal Polynomial Approximation":

Equioscillation Theorem (Odd): Let f(x) be an odd polynomial of degree at most 2n+1, and g(x) be a continuous function on the interval [a, b] \subset (0, \infty). Then \begin{equation} f^* = \operatorname{argmin}_f \max_{x \in [a, b]} |f(x) - g(x)| \end{equation} if and only if there exist a \leq x_0 < x_1 < \dots < x_{n+1} \leq b and \sigma \in \{0, 1\} such that \begin{equation} f^*(x_k) - g(x_k) = (-1)^{k+\sigma} \max_{x \in [a, b]} |f^*(x) - g(x)| \end{equation}

Now we are solving for f_t, and the target g is identically 1. The Equioscillation Theorem tells us that |f_t^*(x) - 1| reaches the maximum error (denoted \mathcal{E}) at least n+2 times in [l_t, u_t]. The maximum points of |f_t^*(x) - 1| can only be boundary points or local extrema of f_t^*(x). Since a (2n+1)-th degree odd polynomial has at most n extrema in (0, \infty), gathering the required n+2 points forces both boundary points to be included. This determines x_0 = l_t and x_{n+1} = u_t, while x_1, \dots, x_n are the zeros of \frac{d}{dx}f_t^*(x).

Furthermore, since the target function is 1 and the slope of f_t^*(x) at x=0 is greater than zero, l_t must be a minimum point of f_t^*(x), hence \sigma=1. Combining these results, we are actually solving the system of equations: \begin{equation} f_t(l_t) = 1 - \mathcal{E}, \quad f_t(u_t) = 1 + (-1)^n \mathcal{E}, \quad f_t(x_i) = 1 + (-1)^{i+1}\mathcal{E}, \quad f_t'(x_i) = 0 \end{equation} where i = 1, 2, 3, \dots, n. This gives 2n+2 equations in 2n+2 unknowns (the n+1 coefficients of f_t, the n points x_i, and \mathcal{E}). By adding the constraints l_t < x_1 < \dots < x_n < u_t and \mathcal{E} > 0, the solution can theoretically be determined.

Solving the Equations

For a 3rd-degree odd polynomial (n=1), the original paper provides an analytical solution. For a 5th-degree odd polynomial (n=2), the paper provides an iterative algorithm: first fix x_1, x_2 to solve for coefficients a, b, c, then fix a, b, c to solve for x_1, x_2, and iterate repeatedly. This is essentially a simplified version of the Remez algorithm.

However, the iteration in the original paper relies on root-finding formulas for x_1, x_2, which is not easy for larger n. Therefore, I will change the solving approach here. First, parameterize f_t'(x) using x_1, x_2, \dots, x_n: \begin{equation} f_t'(x) = k(x^2 - x_1^2)(x^2 - x_2^2) \dots (x^2 - x_n^2) \end{equation} Then f_t(x) = \int_0^x f_t'(s)\, ds. Thus, we have expressed f_t(x) using k and x_1, x_2, \dots, x_n. We then only need to solve the system: \begin{equation} f_t(l_t) = 1 - \mathcal{E}, \quad f_t(u_t) = 1 + (-1)^n \mathcal{E}, \quad f_t(x_i) = 1 + (-1)^{i+1}\mathcal{E} \end{equation} while avoiding solving f_t'(x) = 0. When n=1, we can solve: \begin{equation} x_1 = \sqrt{\frac{l_t^2 + l_t u_t + u_t^2}{3}}, \quad k = -\frac{6}{l_t^2 u_t + l_t u_t^2 + 2x_1^3} \end{equation} When n > 1, we can hand it over to Mathematica. For example, when n=2:

df[x_] = k*(x^2 - x1^2) (x^2 - x2^2);
f[x_] = Integrate[df[x], {x, 0, x}];
sol = NSolve[{f[l] == 1 - e, f[x1] == 1 + e, f[x2] == 1 - e, 
    f[u] == 1 + e, l < x1 < x2 < u, e > 0} /. {l -> 0.001, 
    u -> 1}, {k, x1, x2, e}, Reals]
f[x] /. sol
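The n=1 closed form can also be checked directly in Python: integrating f_t'(x) = k(x^2 - x_1^2) gives f_t(x) = k(x^3/3 - x_1^2 x), and the equioscillation conditions should then hold exactly. A sketch:

```python
import numpy as np

def optimal_cubic(l, u):
    """Closed-form n = 1 solution: f(x) = k * (x**3 / 3 - x1**2 * x)."""
    x1 = np.sqrt((l**2 + l*u + u**2) / 3)
    k = -6 / (l**2 * u + l * u**2 + 2 * x1**3)
    return k, x1

l, u = 0.001, 1.0
k, x1 = optimal_cubic(l, u)
f = lambda x: k * (x**3 / 3 - x1**2 * x)

# Equioscillation for n = 1: f(l) = f(u) = 1 - E and f(x1) = 1 + E
E = 1 - f(l)
assert np.isclose(f(u), 1 - E)
assert np.isclose(f(x1), 1 + E)
```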

Finite Precision

At this point, it seems we have completed the solution to the original problem. Theoretically yes, but only for infinite precision. In actual calculations, precision is finite—especially since the Muon optimizer uses bfloat16, where precision loss is more severe—which brings about some issues.

The first issue is that each f_t^* is theoretically responsible only for the interval [l_t, u_t], but under finite precision, singular values might drift outside this interval. When n is even (i.e., f_t^* is a 5th, 9th, ... degree odd polynomial), there is a risk of divergence for x > u_t, because f_t^*(x) increases monotonically to infinity there, so a careless iteration can blow up. There are two remedies: one is to leave a slightly loose margin on [l_t, u_t] when solving for f_t^*; the other is to keep the interval unchanged but divide the input by a number slightly greater than 1 after obtaining f_t^*.

The original paper uses the latter, changing f_t^*(x) to f_t^*(x / 1.01). The number 1.01 is approximately the first number after 1 in bfloat16 precision (the exact value is 1.00781). Obviously, this is to prevent singular values from expanding from 1 to the next representable value due to numerical error. If calculating in higher precision, this value can be appropriately reduced.
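The value 1.00781 comes from the bfloat16 format itself: bfloat16 keeps 7 explicit mantissa bits, so the spacing just above 1 is 2^{-7}. A one-line check:

```python
# bfloat16 keeps 7 explicit mantissa bits, so the next representable
# number after 1.0 is 1 + 2**-7
assert 1 + 2**-7 == 1.0078125
```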

The second issue is more subtle. Let’s introduce it with a specific example. Suppose n=2, l_1=0.001, u_1=1. We can find f_1^* to be: \begin{equation} f_1^*(x) = 8.4703 x - 25.1081 x^3 + 18.6293 x^5 \end{equation} where x_1 = 0.3674, x_2 = 0.8208, \mathcal{E} = 0.9915. What is the problem with this solution? According to the Equioscillation Theorem, we know f_1^*(x_2) = 1 - \mathcal{E} = 0.0085, meaning it maps 0.8208 to 0.0085. However, our ultimate goal is to turn all numbers in (0, 1] into 1. Thus, f_1^* maps 0.8208, which is already quite close to the target, to 0.0085, which is very far from it. Although f_2^*, f_3^*, \dots will theoretically pull it back gradually, repeatedly shrinking and then expanding a number under finite precision leads to significant cumulative error.
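A quick check of this mapping with the stated coefficients (which are rounded, so the result is only approximate):

```python
# f_1^*(x) = 8.4703 x - 25.1081 x^3 + 18.6293 x^5, with x_2 = 0.8208
f1 = lambda x: 8.4703 * x - 25.1081 * x**3 + 18.6293 * x**5
print(f1(0.8208))  # close to 1 - E = 0.0085: a near-converged value is thrown back toward 0
```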

Of course, from the Equioscillation Theorem, we know this oscillatory behavior is unavoidable. We can only hope that the maximum error \mathcal{E} is not too close to 1, thereby slowing down this cumulative error. It is easy to see that the larger the interval [l_t, u_t], the harder it is to fit theoretically, and the closer \mathcal{E} will be to 1. Therefore, the paper introduces a hyperparameter \lambda \in (0, 1), changing the optimization interval from [l_t, u_t] to [\max(l_t, \lambda u_t), u_t]. By limiting the interval size, we ensure \mathcal{E} is not too large. (Note: the paper uses \lambda = 0.1 in the text, but the appendix code actually uses \lambda \approx 0.024.)

But then, wouldn’t the original l_t, especially the l we set at the beginning, be easily ignored? To solve this, the paper introduces the "Recenter" trick: if the optimization interval is [l_t, u_t], then f_t^*(l_t) + f_t^*(u_t) = 2 will be satisfied. After changing the interval to [\max(l_t, \lambda u_t), u_t], this might not hold. At this point, we multiply f_t^* by \gamma to satisfy this equality: \begin{equation} \gamma f_t^*(l_t) + \gamma f_t^*(u_t) = 2 \quad \Rightarrow \quad \gamma = \frac{2}{f_t^*(l_t) + f_t^*(u_t)} \end{equation} This incorporates the original l_t back into consideration.

Reference Code

Here is the complete Mathematica code for n=2:

df[x_] = k*(x^2 - x1^2) (x^2 - x2^2);
f[x_] = Integrate[df[x], {x, 0, x}];
sol[l_, u_] := 
 NSolve[{f[l] == 1 - e, f[x1] == 1 + e, f[x2] == 1 - e, f[u] == 1 + e,
    l < x1 < x2 < u, e > 0, k > 0}, {k, x1, x2, e}]
ff[x_, l_, u_] = f[x]*2/(f[l] + f[u]) // Expand;
lt = 0.001; ut = 1; lambda = 0.02407327424182761;
While[1 - lt > 0.0001,
 fff[x_] = ff[x, lt, ut] /. sol[Max[lt, lambda*ut], ut][[1]];
 Print[fff[x]];
 lt = fff[lt]; ut = 2 - lt]

The results are as follows (f_t(x) = a_t x + b_t x^3 + c_t x^5):

t    a \times 1.01    b \times 1.01^3    c \times 1.01^5
1    8.28721          -23.5959           17.3004
2    4.10706          -2.94785           0.544843
3    3.94869          -2.9089            0.551819
4    3.31842          -2.48849           0.510049
5    2.30065          -1.6689            0.418807
6    1.8913           -1.268             0.376804
7    1.875            -1.25              0.375
8    1.875            -1.25              0.375

Note that the results given here are before the f_t^*(x / 1.01) processing, so the actual a, b, c should be divided by 1.01^1, 1.01^3, 1.01^5 respectively. The reason for not giving the results after division is that the convergence values 1.875, -1.25, 0.375 (for t \geq 7) are much cleaner and easier to appreciate. (Thought exercise: Please prove that the final convergence values can be solved from x_1=x_2=1 and f(1)=1.)
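For the thought exercise: with x_1 = x_2 = 1 we get f'(x) = k(x^2-1)^2, hence f(x) = k(x^5/5 - 2x^3/3 + x), and f(1) = 1 forces k \cdot 8/15 = 1, i.e. k = 15/8. An exact-arithmetic check:

```python
from fractions import Fraction

# f'(x) = k (x^2 - 1)^2  =>  f(x) = k (x^5/5 - 2 x^3/3 + x)
# f(1) = 1  =>  k (1/5 - 2/3 + 1) = k * 8/15 = 1  =>  k = 15/8
k = Fraction(15, 8)
a, b, c = k, -2 * k / 3, k / 5
assert (a, b, c) == (Fraction(15, 8), Fraction(-10, 8), Fraction(3, 8))
assert float(a) == 1.875 and float(b) == -1.25 and float(c) == 0.375
```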

The code from the author’s appendix is organized as follows:

import numpy as np

def optimal_quintic(l, u):
    assert 0 <= l <= u
    if 1 - 5e-6 <= l / u:
        # Above this threshold, the equioscillating polynomial
        # is numerically equal to...
        return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
    # This initialization becomes exact as l -> u
    q = (3 * l + 1) / 4
    r = (l + 3) / 4
    E, old_E = np.inf, None
    while not old_E or abs(old_E - E) > 1e-15:
        old_E = E
        LHS = np.array([
            [l, l**3, l**5, 1],
            [q, q**3, q**5, -1],
            [r, r**3, r**5, 1],
            [u, u**3, u**5, -1],
        ])
        a, b, c, E = np.linalg.solve(LHS, np.ones(4))
        q, r = np.sqrt(
            (-3 * b + np.array([-1, 1]) * np.sqrt(9 * b**2 - 20 * a * c)) /
            (10 * c)
        )
    return float(a), float(b), float(c)

def optimal_composition(l, num_iters, cushion=0.02407327424182761):
    u = 1
    coefficients = []
    for _ in range(num_iters):
        a, b, c = optimal_quintic(max(l, cushion * u), u)
        # Due to cushioning, this may be centered around 1 with
        # respect to 0.024*u, u. Recenter it around 1 with respect to
        # l, u, i.e. find gamma so that 1 - gamma*p(l) = gamma*p(u) - 1:
        pl = a * l + b * l**3 + c * l**5
        pu = a * u + b * u**3 + c * u**5
        rescalar = 2 / (pl + pu)
        a *= rescalar
        b *= rescalar
        c *= rescalar
        # Optionally incorporate safety factor here:
        # a /= 1.01; b /= 1.01**3; c /= 1.01**5
        coefficients.append((a, b, c))
        l = a * l + b * l**3 + c * l**5
        u = 2 - l
    return coefficients

print(*optimal_composition(1e-3, 10), sep="\n")
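As a sanity check of the table above, applying the eight f_t^* triples (the values before the 1/1.01 safety division) to a grid of singular values in [10^{-3}, 1] should leave everything close to 1:

```python
import numpy as np

# The eight f_t^* coefficient triples from the table above
coeffs = [
    (8.28721, -23.5959, 17.3004),
    (4.10706, -2.94785, 0.544843),
    (3.94869, -2.9089, 0.551819),
    (3.31842, -2.48849, 0.510049),
    (2.30065, -1.6689, 0.418807),
    (1.8913, -1.268, 0.376804),
    (1.875, -1.25, 0.375),
    (1.875, -1.25, 0.375),
]

x = np.linspace(1e-3, 1, 100001)
for a, b, c in coeffs:
    x = a * x + b * x**3 + c * x**5
print(np.max(np.abs(x - 1)))  # tiny: every singular value is pulled close to 1
```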

Completing the Proof

In the final section, let’s provide the proof that "the greedy solution is exactly the global optimal solution."

According to the Equioscillation Theorem, we know the range of f_t^* is [l_{t+1}, u_{t+1}], where l_{t+1} = f_t^*(l_t) and u_{t+1} = 2 - l_{t+1}. From this, the maximum error of the T-step greedy solution is \mathcal{E}_T = 1 - l_{T+1} = 1 - f_T^*(l_T). To obtain the conclusion, we only need to prove that the maximum error of the T-step global optimal solution cannot be less than 1 - f_T^*(l_T).

The proof is by mathematical induction; the base case T=1 is immediate, since for a single step the greedy and global problems coincide. Assume the conclusion holds for t = 1, 2, \dots, T-1. Then \hat{f} = f_{T-1}^* \circ \dots \circ f_2^* \circ f_1^* is the global optimal solution for T-1 steps, with range [l_T, u_T] and maximum error \mathcal{E}_{T-1} = 1 - l_T = u_T - 1. On the other hand, let \tilde{f} = \tilde{f}_{T-1} \circ \dots \circ \tilde{f}_2 \circ \tilde{f}_1 be any (T-1)-step solution with range [a, b], and let c = \frac{2}{a+b}, so that the range of c\tilde{f} is [ca, cb] with ca \leq 1 \leq cb. This choice of c makes 1 - ca = cb - 1 = \frac{b-a}{a+b}, so the induction hypothesis (applied to c\tilde{f}, which is still a valid (T-1)-step solution) gives both inequalities: \begin{equation} \begin{aligned} 1 - ca &\geq \mathcal{E}_{T-1} \\ cb - 1 &\geq \mathcal{E}_{T-1} \end{aligned} \quad \Rightarrow \quad \frac{a}{b} \leq \frac{1 - \mathcal{E}_{T-1}}{1 + \mathcal{E}_{T-1}} = \frac{l_T}{u_T} \end{equation} That is, after normalizing the upper endpoint to 1, the range [a/b, 1] of any (T-1)-step solution is at least as wide as the normalized range [l_T/u_T, 1] of the (T-1)-step optimal solution. Then we have: \begin{equation} \begin{aligned} \min_{f_T} \max_{x \in [l, u]} |f_T(\tilde{f}(x)) - 1| &= \min_{f_T} \max_{x \in [a, b]} |f_T(x) - 1| \\ &= \min_{f_T} \max_{x \in [a/b, 1]} |f_T(x) - 1| \\ &\geq \min_{f_T} \max_{x \in [l_T/u_T, 1]} |f_T(x) - 1| \\ &= \min_{f_T} \max_{x \in [l_T, u_T]} |f_T(x) - 1| \\ &= \mathcal{E}_T \end{aligned} \end{equation} In other words, no choice of the first T-1 steps can make the T-step maximum error smaller than that of the greedy solution, so the greedy solution is already globally optimal; this completes the inductive step. The key step in the chain above is: \begin{equation} \min_{f_T} \max_{x \in [a, b]} |f_T(x) - 1| = \min_{f_T} \max_{x \in [a/b, 1]} |f_T(x) - 1| \end{equation} This is because we can always set g_T(y) = f_T(b y).
g_T still represents any odd polynomial of the same degree, so g_T and f_T are in the same function space, and the notation can be swapped: \begin{equation} \min_{f_T} \max_{x \in [a, b]} |f_T(x) - 1| = \min_{g_T} \max_{y \in [a/b, 1]} |g_T(y) - 1| = \min_{f_T} \max_{x \in [a/b, 1]} |f_T(x) - 1| \end{equation}
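The scale-invariance used in this last step can also be checked numerically with the n=1 closed form from "Solving the Equations": the optimal minimax error depends only on the ratio of the interval endpoints. A sketch:

```python
import numpy as np

def cubic_error(l, u):
    """Max error of the optimal cubic on [l, u] (closed form for n = 1)."""
    x1 = np.sqrt((l**2 + l*u + u**2) / 3)
    k = -6 / (l**2 * u + l * u**2 + 2 * x1**3)
    f = lambda x: k * (x**3 / 3 - x1**2 * x)
    return 1 - f(l)  # equioscillation: E = 1 - f(l)

# The min-max error is invariant under rescaling [a, b] -> [a/b, 1]
a, b = 0.2, 0.8
assert np.isclose(cubic_error(a, b), cubic_error(a / b, 1.0))
```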

Summary

This article introduced the latest progress in finding better Newton-Schulz iterations for the \operatorname{msign} operator. By using the Equioscillation Theorem and greedy transformation, it directly solves for the theoretical optimal solution. The entire process is quite hardcore and well worth learning.

Reprinting: Please include the original address of this article: https://kexue.fm/archives/10996

For more detailed reprinting matters, please refer to: "Scientific Space FAQ"