Efficient Computation of Matrix $r$-th Roots and Inverse $r$-th Roots

Translated by Gemini Flash 3.0 Preview. Translations may be inaccurate; please refer to the original post for anything important.

In the previous article "Efficient Computation of Matrix Square Roots and Inverse Square Roots", the author proposed an elegant method for computing matrix square roots and inverse square roots starting from the \mathop{\mathrm{mcsgn}} operator. Curiously, after simplification the final formula no longer shows the original \mathop{\mathrm{mcsgn}} form at all. This naturally invites deeper reflection: what is the more fundamental principle that makes this scheme work? And can it be generalized to an arbitrary r-th root?

Analyzing from this perspective, the author was pleasantly surprised to find that we can understand the previous iterative algorithm from a simpler angle, and in this new light, it can be easily generalized to the computation of arbitrary r-th roots and inverse r-th roots. In the following, we will share this process.

Review of Previous Work

Let \boldsymbol{G} \in \mathbb{R}^{m \times n} be an arbitrary matrix, and \boldsymbol{P} \in \mathbb{R}^{n \times n} be an arbitrary matrix with eigenvalues all within [0, 1]. The previous article gave: \begin{gather} \boldsymbol{G}_0 = \boldsymbol{G}, \quad \boldsymbol{P}_0 = \boldsymbol{P} \notag\\[6pt] \boldsymbol{G}_{t+1} = \boldsymbol{G}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2) \label{eq:r2-rsqrt}\\[6pt] \boldsymbol{P}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^2\boldsymbol{P}_t \label{eq:r3-rsqrt}\\[6pt] \lim_{t\to\infty} \boldsymbol{G}_t = \boldsymbol{G}\boldsymbol{P}^{-1/2}\notag \end{gather} By substituting \boldsymbol{G}=\boldsymbol{P}, we can obtain \boldsymbol{P}^{1/2}, and by substituting \boldsymbol{G}=\boldsymbol{I}, we can obtain \boldsymbol{P}^{-1/2}. Upon careful observation, we find that the above iteration is actually a manifestation of the following limit: \begin{equation} \prod_{t=0}^{\infty}(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2) = \boldsymbol{P}^{-1/2}\label{eq:prod-rsqrt} \end{equation} Interestingly, proving this limit directly is not complicated. By taking the square root of both sides of Eq. [eq:r3-rsqrt] and substituting it into the above expression, we get: \begin{equation} \prod_{t=0}^{\infty}(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2) = \prod_{t=0}^{\infty} \boldsymbol{P}_{t+1}^{1/2}\boldsymbol{P}_t^{-1/2} = \lim_{t\to\infty} \boldsymbol{P}_t^{1/2}\boldsymbol{P}_0^{-1/2} = \lim_{t\to\infty} \boldsymbol{P}_t^{1/2}\boldsymbol{P}^{-1/2} \end{equation} It can be seen that as long as the sequence \{\boldsymbol{P}_t\} remains invertible and its final limit is \boldsymbol{I}, the limit [eq:prod-rsqrt] automatically holds. As for how the iteration [eq:r3-rsqrt] allows \{\boldsymbol{P}_t\} to maintain these two conditions, we will discuss that shortly.
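
As a quick sanity check of this telescoping limit, here is a minimal NumPy sketch of the square-root iteration above that simply reuses the fixed coefficients (15/8, -5/4, 3/8) at every step (these are the convergence-step coefficients for r=2 in the table further below, a deliberately simple and non-optimal choice) and compares the result against \boldsymbol{P}_0^{-1/2} computed by eigendecomposition:

import numpy as np

np.random.seed(0)
n = 50
A = np.random.randn(n, n) / n**0.5
P = A @ A.T + 0.1 * np.eye(n)              # symmetric positive definite test matrix
P0 = P / np.trace(P @ P)**0.5              # eigenvalues now lie in (0, 1]
G, Pt = np.eye(n), P0.copy()
a, b, c = 15 / 8, -5 / 4, 3 / 8            # convergence-step coefficients for r = 2
for _ in range(30):
    W = a * np.eye(n) + b * Pt + c * Pt @ Pt
    G, Pt = G @ W, W @ W @ Pt              # G_{t+1} = G_t W,  P_{t+1} = W^2 P_t
S, Q = np.linalg.eigh(P0)                  # reference: P0^{-1/2} via eigendecomposition
print(np.abs(G - Q @ np.diag(S**-0.5) @ Q.T).max())  # close to zero (well below 1e-6 in float64)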

General Form

Let us consider the iteration in a general sense: \begin{gather} \boldsymbol{G}_0 = \boldsymbol{G}, \quad \boldsymbol{P}_0 = \boldsymbol{P} \notag\\[6pt] \boldsymbol{G}_{t+1} = \boldsymbol{G}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^s\\[6pt] \boldsymbol{P}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^r\boldsymbol{P}_t \end{gather} Similarly, if the sequence \{\boldsymbol{P}_t\} remains invertible and its final limit is \boldsymbol{I}, it can be proven that: \begin{equation} \lim_{t\to\infty} \boldsymbol{G}_t = \boldsymbol{G}\boldsymbol{P}^{-s/r} \end{equation} Thus, we have obtained a general iterative form for calculating an arbitrary -s/r power of a matrix. Based on this result, we only need to choose \boldsymbol{G}=\boldsymbol{P}, s=r-1 to obtain \boldsymbol{P}^{1/r}. Therefore, it suffices to focus on negative powers \boldsymbol{P}^{-s/r} whose exponent magnitude s/r lies between 0 and 1.
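
The proof mirrors the telescoping argument above: taking the r-th root of the \boldsymbol{P}-update and raising it to the s-th power (all factors are polynomials in \boldsymbol{P} and hence commute) gives \begin{equation} \prod_{t=0}^{\infty}(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^s = \prod_{t=0}^{\infty} \boldsymbol{P}_{t+1}^{s/r}\boldsymbol{P}_t^{-s/r} = \lim_{t\to\infty} \boldsymbol{P}_t^{s/r}\boldsymbol{P}^{-s/r} = \boldsymbol{P}^{-s/r} \end{equation} which is exactly the claimed limit for \boldsymbol{G}_t.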

In this way, the problem becomes how to choose appropriate \{a_t, b_t, c_t\} so that the sequence \{\boldsymbol{P}_t\} converges to \boldsymbol{I} as quickly as possible. Faster convergence means we can reach the specified precision with fewer iteration steps.

Iteration Coefficients

According to the assumption, \boldsymbol{P}_0 = \boldsymbol{P} is a matrix with eigenvalues in [0, 1], and the target matrix \boldsymbol{I} is a matrix with eigenvalues all equal to 1. Thus, the sequence \{\boldsymbol{P}_t\} is essentially the process of transforming eigenvalues from any value in [0, 1] to 1. This is exactly what \mathop{\mathrm{mcsgn}} does!

Let \boldsymbol{X}_t = \boldsymbol{P}_t^{1/r}. Then \boldsymbol{X}_0 = \boldsymbol{P}^{1/r} is also a matrix with eigenvalues in [0, 1], and the iteration equation becomes: \begin{equation} \boldsymbol{X}_{t+1} = a_{t+1}\boldsymbol{X}_t + b_{t+1}\boldsymbol{X}_t^{r+1} + c_{t+1}\boldsymbol{X}_t^{2r+1} \end{equation} Now the problem becomes how to make \boldsymbol{X}_0 transform into \boldsymbol{I} as quickly as possible. This is essentially the same problem we discussed in "Newton-Schulz Iteration for the msign Operator (Part 1)" and "Newton-Schulz Iteration for the msign Operator (Part 2)". Specifically, "Part 2" gave the theoretical optimal solution for r=2, but its derivation and conclusions can be generalized to any r.
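
Spelling out the substitution (all factors are polynomials in \boldsymbol{X}_t and commute): \begin{equation} \boldsymbol{P}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{X}_t^r + c_{t+1}\boldsymbol{X}_t^{2r})^r\boldsymbol{X}_t^r = \big[(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{X}_t^r + c_{t+1}\boldsymbol{X}_t^{2r})\boldsymbol{X}_t\big]^r = \boldsymbol{X}_{t+1}^r \end{equation}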

Specifically, we first transform the problem into a scalar iteration: \begin{equation} x_{t+1} = f_t(x_t) = a_{t+1}x_t + b_{t+1}x_t^{r+1} + c_{t+1}x_t^{2r+1} \end{equation} Then we prove that the greedy solution is the optimal solution, and finding the greedy solution reduces to solving the following system, where [l_t, u_t] is the interval currently containing the eigenvalues of \boldsymbol{X}_t and x_1 < x_2 are the interior stationary points of f_t: \begin{equation} \begin{gathered} f_t(l_t) = 1 - \mathcal{E}, \quad f_t(u_t) = 1 + \mathcal{E} \\ f_t(x_1) = 1 + \mathcal{E}, \quad f_t(x_2) = 1 - \mathcal{E} \\ f_t'(x_1) = 0, \quad f_t'(x_2) = 0 \end{gathered} \end{equation} For simplicity, parameterize f_t' as: \begin{equation} f_t'(x) = k(x^r-x_1^r)(x^r-x_2^r) \end{equation} Then, just like in "Part 2", we can use Mathematica to solve it.
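
For reference, integrating this parameterization from 0 gives f_t explicitly in the required form, \begin{equation} f_t(x) = k\left(\frac{x^{2r+1}}{2r+1} - \frac{(x_1^r + x_2^r)\,x^{r+1}}{r+1} + x_1^r x_2^r\,x\right) \end{equation} so that a_{t+1} = k\,x_1^r x_2^r, \; b_{t+1} = -k(x_1^r + x_2^r)/(r+1), \; c_{t+1} = k/(2r+1).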

Initial Analysis

However, before formally solving, we must analyze the initialization. In the previous article "Efficient Computation of Matrix Square Roots and Inverse Square Roots", we mentioned that under the assumption that the eigenvalues of \boldsymbol{P} are non-negative, we can compress the eigenvalues into [0, 1] by dividing by \mathop{\mathrm{tr}}(\boldsymbol{P}). However, this compression ratio is often too large. In this article, we change it to: \begin{equation} \boldsymbol{P}_0 = \frac{\boldsymbol{P}}{\sqrt{\mathop{\mathrm{tr}}(\boldsymbol{P}^2)}} \end{equation} We know that \mathop{\mathrm{tr}}(\boldsymbol{P}^2) equals the sum of the squares of all eigenvalues, while \mathop{\mathrm{tr}}(\boldsymbol{P})^2 equals the square of the sum of all eigenvalues. When eigenvalues are non-negative, \mathop{\mathrm{tr}}(\boldsymbol{P}^2) \leq \mathop{\mathrm{tr}}(\boldsymbol{P})^2 always holds, so the above formula provides a tighter initial value. In particular, calculating \mathop{\mathrm{tr}}(\boldsymbol{P}^2) does not require explicitly computing \boldsymbol{P}^2, as we have the identity: \begin{equation} \mathop{\mathrm{tr}}(\boldsymbol{P}^2) = \langle \boldsymbol{P}, \boldsymbol{P}^{\top}\rangle_F \end{equation}
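
A one-line NumPy check of this identity (P here is just a random test matrix):

import numpy as np

P = np.random.randn(5, 5)
print(np.trace(P @ P), (P * P.T).sum())  # the two values coincide: tr(P^2) = <P, P^T>_F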

Next, we need to analyze how small the eigenvalues we must handle can be, which parallels the initial singular value analysis in "Newton-Schulz Iteration for the msign Operator (Part 1)". After dividing by \sqrt{\mathop{\mathrm{tr}}(\boldsymbol{P}^2)}, the eigenvalues of \boldsymbol{P}_0 form a unit vector. If the eigenvalues were all equal, each would be 1/\sqrt{n}; by a pigeonhole-style argument (their squares sum to 1), in general the smallest eigenvalue is at most 1/\sqrt{n}. For safety, we allow for eigenvalues as small as 0.01/\sqrt{n}.

Considering that in large LLMs n has reached the order of 100^2 = 10^4, we need to handle eigenvalues down to 0.01/\sqrt{10^4} = 0.0001. Note that these are eigenvalues of \boldsymbol{P}_0, and \boldsymbol{X}_0 = \boldsymbol{P}_0^{1/r}, so for \boldsymbol{X}_0 we only need to handle values down to 0.0001^{1/r}. This is more favorable than the \mathop{\mathrm{mcsgn}} and \mathop{\mathrm{msign}} cases: there the input itself plays the role of \boldsymbol{X}_0, so its small eigenvalues or singular values must be handled directly, whereas here the input is \boldsymbol{P}_0 and the analysis only needs to start from the eigenvalues of \boldsymbol{P}_0.

Calculation Results

Based on the above considerations, our final solving code (in Mathematica) is as follows:

r = 4;
(* parameterized derivative f'(x) = k (x^r - x1^r)(x^r - x2^r) *)
df[x_] = k*(x^r - x1^r) (x^r - x2^r);
(* f(x) = integral of f' from 0 to x *)
f[x_] = Integrate[df[x], {x, 0, x}];
(* solve the equioscillation conditions on [l, u] *)
sol[l_, u_] := 
 NSolve[{f[l] == 1 - e, f[x1] == 1 + e, f[x2] == 1 - e, f[u] == 1 + e,
    l < x1 < x2 < u, e > 0, k > 0}, {k, x1, x2, e}]
(* rescale f so that f(l) and f(u) average to 1 *)
ff[x_, l_, u_] = f[x]*2/(f[l] + f[u]) // Expand;
(* [lt, ut]: current eigenvalue interval; lambda clips the lower endpoint passed to sol *)
lt = 0.0001^(1/r); ut = 1; lambda = 0.1;
While[1 - lt > 0.0001,
 fff[x_] = ff[x, lt, ut] /. sol[Max[lt, lambda*ut], ut][[1]];
 Print[fff[x]];
 lt = fff[lt]; ut = 2 - lt]
(* limiting (convergence-step) coefficients: x1 = x2 = 1 with f(1) = 1 *)
f[x] /. Solve[f[1] == 1, k][[1]] /. {x1 -> 1, x2 -> 1}

The calculation results for r=1 \sim 5 are as follows:

r      t        a           b           c
1      1        14.2975     -31.2203    18.9214
       2        7.12258     -7.78207    2.35989
       3        6.9396      -7.61544    2.3195
       4        5.98456     -6.77016    2.12571
       5        3.79109     -4.18664    1.39555
       \geq 6   3           -3          1
2      1        7.42487     -18.3958    12.8967
       2        3.48773     -2.33004    0.440469
       3        2.77661     -2.07064    0.463023
       4        1.99131     -1.37394    0.387593
       \geq 5   15/8        -5/4        3/8
3      1        5.05052     -13.5427    10.2579
       2        2.31728     -1.06581    0.144441
       3        1.79293     -0.913562   0.186699
       4        1.56683     -0.786609   0.220008
       \geq 5   14/9        -7/9        2/9
4      1        3.85003     -10.8539    8.61893
       2        1.80992     -0.587778   0.0647852
       3        1.50394     -0.594516   0.121161
       \geq 4   45/32       -9/16       5/32
5      1        3.11194     -8.28217    6.67716
       2        1.5752      -0.393327   0.0380364
       3        1.3736      -0.44661    0.0911259
       \geq 4   33/25       -11/25      3/25

The coefficients for the final (convergence) step, shown in the \geq rows, are derived by setting x_1 = x_2 = 1 and f(1) = 1.
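
Concretely, x_1 = x_2 = 1 gives f'(x) = k(x^r - 1)^2, and integrating and imposing f(1) = 1 yields \begin{equation} f(x) = k\left(\frac{x^{2r+1}}{2r+1} - \frac{2x^{r+1}}{r+1} + x\right), \qquad k = \frac{(r+1)(2r+1)}{2r^2} \end{equation} i.e. a = \frac{(r+1)(2r+1)}{2r^2}, \; b = -\frac{2r+1}{r^2}, \; c = \frac{r+1}{2r^2}; for r = 2 this is (15/8, -5/4, 3/8), and likewise it reproduces the other \geq rows of the table.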

Testing

A simple test code is as follows:

import numpy as np
import jax.numpy as jnp

coefs = [
    None,
    [
        (14.2975, -31.2203, 18.9214),
        (7.12258, -7.78207, 2.35989),
        (6.9396, -7.61544, 2.3195),
        (5.98456, -6.77016, 2.12571),
        (3.79109, -4.18664, 1.39555),
        (3, -3, 1),
    ],
    [
        (7.42487, -18.3958, 12.8967),
        (3.48773, -2.33004, 0.440469),
        (2.77661, -2.07064, 0.463023),
        (1.99131, -1.37394, 0.387593),
        (15 / 8, -5 / 4, 3 / 8),
    ],
    [
        (5.05052, -13.5427, 10.2579),
        (2.31728, -1.06581, 0.144441),
        (1.79293, -0.913562, 0.186699),
        (1.56683, -0.786609, 0.220008),
        (14 / 9, -7 / 9, 2 / 9),
    ],
    [
        (3.85003, -10.8539, 8.61893),
        (1.80992, -0.587778, 0.0647852),
        (1.50394, -0.594516, 0.121161),
        (45 / 32, -9 / 16, 5 / 32),
    ],
    [
        (3.11194, -8.28217, 6.67716),
        (1.5752, -0.393327, 0.0380364),
        (1.3736, -0.44661, 0.0911259),
        (33 / 25, -11 / 25, 3 / 25),
    ],
]

def abc(r=1, steps=None, scale=1):
    # Yield per-step coefficients (a, b, c); extra steps reuse the last
    # (convergence-step) coefficients. Dividing by powers of `scale` evaluates
    # the update polynomial at x/scale, a small margin against overshooting 1.
    w, steps = coefs[r], steps or len(coefs[r])
    for a, b, c in w[:steps] + w[-1:] * max(steps - len(w), 0):
        yield a / scale, b / scale**(r + 1), c / scale**(2 * r + 1)

def matmul_invroot(G, P, r, s=1, steps=None, eps=1e-5):
    """return G @ P^(-s/r)
    """
    I = jnp.eye(P.shape[0], dtype=P.dtype)
    # Normalize so the eigenvalues fall in [0, 1], using tr(P^2) = <P, P^T>_F;
    # the eps * I term shifts tiny eigenvalues away from zero for stability.
    P = P / (t := (P * P.mT).sum()**0.5) + eps * I
    for a, b, c in abc(r, steps, 1.001):
        W = a * I + b * P + c * P @ P
        W1, W2 = jnp.linalg.matrix_power(W, s), jnp.linalg.matrix_power(W, r)
        G, P = G @ W1, P @ W2  # G_{t+1} = G_t W^s, P_{t+1} = P_t W^r
    # Undo the normalization: (P/t)^(-s/r) = t^(s/r) P^(-s/r).
    return G * t**(-s / r)

def matmul_invroot_by_eigh(G, P, r, s=1):
    """return G @ P^(-s/r)
    """
    # Reference implementation via eigendecomposition, used for accuracy comparison.
    S, Q = jnp.linalg.eigh(P)
    return G @ Q @ jnp.diag(S**(-s / r)) @ jnp.linalg.inv(Q)

d = 1000
s, r = 1, 4
G = np.random.randn(2 * d, d) / d**0.5
P = (x := np.random.randn(d, d) / d**0.5) @ x.T + 0.001 * np.eye(d)

X1 = matmul_invroot_by_eigh(G, P, r, s)
X2 = matmul_invroot(G, P, r, s, eps=0)
print(jnp.abs(X1 - X2).mean())  # ~= 1e-3

X2 = matmul_invroot(jnp.array(G, dtype='bfloat16'), jnp.array(P, dtype='bfloat16'), r, s, eps=0)
print(jnp.abs(X1 - X2).mean())  # ~= 2e-3

There are several points to note. First, the minimum eigenvalue of the input \boldsymbol{P} cannot be too small, otherwise the iteration is extremely prone to exploding, even if we only want to compute positive powers like \boldsymbol{P}^{1/2}. This is understandable, because \sqrt{x} is ill-conditioned at x=0: once an eigenvalue "accidentally" crosses into the negative half-axis due to rounding error, a real root no longer exists and the behavior of the iteration becomes unpredictable.
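
As a scalar caricature of this failure mode, iterating the r = 2 update p \to (a + bp + cp^2)^2 p with the convergence-step coefficients on a slightly negative "eigenvalue" shows the drift: its magnitude grows by roughly a^2 \approx 3.5 per step and then explodes.

a, b, c = 15 / 8, -5 / 4, 3 / 8
p = -1e-3
for t in range(8):
    p = (a + b * p + c * p * p)**2 * p
    print(t, p)  # |p| grows every step, reaching magnitude ~1e17 by the final print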

How small is "too small"? Roughly, the minimum eigenvalue of \boldsymbol{P}/\sqrt{\mathop{\mathrm{tr}}(\boldsymbol{P}^2)} should not be significantly smaller than the minimum eigenvalue we considered, which is 0.0001. If this cannot be guaranteed, it is recommended to set: \begin{equation} \boldsymbol{P}_0 = \frac{\boldsymbol{P}}{\sqrt{\mathop{\mathrm{tr}}(\boldsymbol{P}^2)}} + \epsilon \cdot \boldsymbol{I} \end{equation} where \epsilon \sim 0.0001. This will lose a bit of precision but significantly increase numerical stability.

Furthermore, in most cases the number of iteration steps does not need to exceed the recommended value len(coefs[r]), especially in low-precision scenarios, because more steps increase the risk of explosion due to accumulated errors. As long as the eigenvalues are within the considered range, the recommended number of steps is sufficient to achieve the intended precision. Only when using fp32 or higher precision might one consider setting \epsilon=0, scale=1, and running more iteration steps.
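
For instance, one possible way to do this with the code above (note that the safety scale 1.001 is hard-coded inside matmul_invroot, so scale = 1 would require editing its call to abc):

X = matmul_invroot(jnp.array(G, dtype='float32'), jnp.array(P, dtype='float32'), r, s,
                   steps=len(coefs[r]) + 2, eps=0)  # extra steps reuse the final coefficients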

Summary

This article generalizes the results of the previous article to the calculation of arbitrary r-th roots and inverse r-th roots, obtaining a general iterative format for any -s/r power of a matrix.

Reprinting: Please include the original address of this article: https://kexue.fm/archives/11175

For more details on reprinting, please refer to: "Scientific Space FAQ"