English (unofficial) translations of posts at kexue.fm

Efficient Computation of Matrix Square Roots and Inverse Square Roots

Translated by Gemini Flash 3.0 Preview. Translations may be inaccurate; please refer to the original post for anything important.

Let \boldsymbol{P} \in \mathbb{R}^{n \times n} be a square matrix of order n whose eigenvalues are all non-negative real numbers. This article discusses the calculation of its square root \boldsymbol{P}^{1/2} and inverse square root \boldsymbol{P}^{-1/2}.

Basic Concepts

The square root of a matrix \boldsymbol{P} refers to a matrix \boldsymbol{X} that satisfies \boldsymbol{X}^2 = \boldsymbol{P}. We know that positive numbers have two square roots, so it is not hard to imagine that matrix square roots are generally not unique. However, the "arithmetic square root" is unique. Just as the arithmetic square root of a positive number is the positive one, we refer to the square root of \boldsymbol{P} whose eigenvalues are all non-negative as the arithmetic square root. In this article, the matrix square roots we seek are, by default, arithmetic square roots.
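
For instance, in the diagonal case \begin{equation} \boldsymbol{P} = \begin{bmatrix}4 & 0 \\ 0 & 9\end{bmatrix}, \qquad \boldsymbol{X} = \begin{bmatrix}\pm 2 & 0 \\ 0 & \pm 3\end{bmatrix} \end{equation} all four sign choices satisfy \boldsymbol{X}^2 = \boldsymbol{P}, but only \operatorname{diag}(2, 3) has non-negative eigenvalues, so it is the arithmetic square root \boldsymbol{P}^{1/2}.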

The calculations in this article rely on the matrix sign function discussed in "What can the matrix sign function mcsgn calculate?": \begin{equation} \operatorname{mcsgn}(\boldsymbol{M}) = (\boldsymbol{M}^2)^{-1/2}\boldsymbol{M} = \boldsymbol{M}(\boldsymbol{M}^2)^{-1/2} \end{equation} Simply put, it transforms any matrix \boldsymbol{M} \in \mathbb{R}^{n \times n} into a new matrix where the eigenvalues are replaced by their corresponding sign function values. Assuming the eigenvalues of \boldsymbol{M} are all real, \operatorname{mcsgn} can be efficiently computed via Newton-Schulz iteration: \begin{equation} \boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\sqrt{\operatorname{tr}(\boldsymbol{M}^2)}}, \qquad \boldsymbol{X}_{t+1} = a_{t+1}\boldsymbol{X}_t + b_{t+1}\boldsymbol{X}_t^3 + c_{t+1}\boldsymbol{X}_t^5 \end{equation} where \frac{\boldsymbol{M}}{\sqrt{\operatorname{tr}(\boldsymbol{M}^2)}} is used to scale the eigenvalues of \boldsymbol{X}_0 into the range [-1, 1], and a_t, b_t, c_t are the coefficients derived in "Newton-Schulz iteration for the msign operator (Part 2)":

t    a \times 1.01    b \times 1.01^3    c \times 1.01^5
1    8.28721          -23.5959           17.3004
2    4.10706          -2.94785           0.544843
3    3.94869          -2.9089            0.551819
4    3.31842          -2.48849           0.510049
5    2.30065          -1.6689            0.418807
6    1.8913           -1.268             0.376804
7    1.875            -1.25              0.375
8    1.875            -1.25              0.375

In fact, when the eigenvalues of \boldsymbol{M} are all real, the calculation principle of \operatorname{mcsgn} is consistent with another matrix sign function \operatorname{msign}.
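
The original post does not spell out code for \operatorname{mcsgn} itself; as a point of reference, a minimal NumPy sketch of the above iteration might look as follows (the helper name mcsgn_ns is hypothetical, and the coefficients are the rounded values from the table):

import numpy as np

def mcsgn_ns(M, steps=6):
    # Tuned quintic Newton-Schulz coefficients: rounded values from the table above,
    # still carrying the 1.01 safety factors.
    coefs = [
        (8.28721, -23.5959, 17.3004),
        (4.10706, -2.94785, 0.544843),
        (3.94869, -2.9089, 0.551819),
        (3.31842, -2.48849, 0.510049),
        (2.30065, -1.6689, 0.418807),
        (1.8913, -1.268, 0.376804),
        (1.875, -1.25, 0.375)
    ]
    coefs = coefs[:steps] + max(steps - 7, 0) * [coefs[-1]]
    X = M / np.trace(M @ M)**0.5                      # scale eigenvalues into [-1, 1]
    for a, b, c in coefs:
        a, b, c = a / 1.01, b / 1.01**3, c / 1.01**5  # undo the safety factors
        X2 = X @ X
        X = a * X + b * X2 @ X + c * X2 @ X2 @ X      # X <- aX + bX^3 + cX^5
    return X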

Calculation Principle

The starting point for the following calculation is the identity: \begin{equation} \operatorname{mcsgn}\left(\begin{bmatrix}\boldsymbol{0} & \boldsymbol{A} \\ \boldsymbol{B} & \boldsymbol{0}\end{bmatrix}\right) = \begin{bmatrix}\boldsymbol{0} & \boldsymbol{A}(\boldsymbol{B}\boldsymbol{A})^{-1/2} \\ \boldsymbol{B}(\boldsymbol{A}\boldsymbol{B})^{-1/2} & \boldsymbol{0}\end{bmatrix} \label{eq:core} \end{equation} This can be verified by directly substituting into the definition of \operatorname{mcsgn} (Note: \boldsymbol{A}, \boldsymbol{B} are not necessarily square matrices). Next, we need to determine under what conditions the eigenvalues of the matrix on the left side are all real. Let \lambda be a non-zero eigenvalue; then: \begin{equation} 0 = \det\left(\lambda\boldsymbol{I} - \begin{bmatrix}\boldsymbol{0} & \boldsymbol{A} \\ \boldsymbol{B} & \boldsymbol{0} \end{bmatrix}\right) = \det\left(\begin{bmatrix}\lambda\boldsymbol{I} & -\boldsymbol{A} \\ -\boldsymbol{B} & \lambda\boldsymbol{I} \end{bmatrix}\right) = \det(\lambda^2 \boldsymbol{I} - \boldsymbol{A}\boldsymbol{B}) \end{equation} That is, \lambda^2 is an eigenvalue of the matrix \boldsymbol{A}\boldsymbol{B}. This means that all eigenvalues of the above block matrix are real if and only if all eigenvalues of \boldsymbol{A}\boldsymbol{B} are non-negative.
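
As a quick numerical sanity check (a sketch, not from the original post), the identity can be verified directly against the definition of \operatorname{mcsgn}, e.g. with square \boldsymbol{A} and \boldsymbol{B} = \boldsymbol{A}^{\top}, so that \boldsymbol{A}\boldsymbol{B} has non-negative eigenvalues:

import numpy as np
from scipy.linalg import sqrtm, inv

n = 4
A = np.random.randn(n, n)
B = A.T                                    # AB = A A^T has non-negative eigenvalues
M = np.block([[np.zeros((n, n)), A], [B, np.zeros((n, n))]])
S = inv(sqrtm(M @ M)) @ M                  # mcsgn(M) = (M^2)^{-1/2} M by definition
print(np.abs(S[:n, n:] - A @ inv(sqrtm(B @ A))).max())  # top-right block, close to 0
print(np.abs(S[n:, :n] - B @ inv(sqrtm(A @ B))).max())  # bottom-left block, close to 0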

While iterating directly on the original matrix is possible, it is computationally wasteful. We can exploit its anti-diagonal structure to reduce the computational load. Since: \begin{equation} \begin{bmatrix}\boldsymbol{0} & \boldsymbol{Y} \\ \boldsymbol{Z} & \boldsymbol{0}\end{bmatrix}^3 = \begin{bmatrix}\boldsymbol{0} & (\boldsymbol{Y}\boldsymbol{Z})\boldsymbol{Y} \\ \boldsymbol{Z}(\boldsymbol{Y}\boldsymbol{Z}) & \boldsymbol{0}\end{bmatrix}, \quad \begin{bmatrix}\boldsymbol{0} & \boldsymbol{Y} \\ \boldsymbol{Z} & \boldsymbol{0}\end{bmatrix}^5 = \begin{bmatrix}\boldsymbol{0} & (\boldsymbol{Y}\boldsymbol{Z})^2\boldsymbol{Y} \\ \boldsymbol{Z}(\boldsymbol{Y}\boldsymbol{Z})^2 & \boldsymbol{0}\end{bmatrix} \end{equation} We obtain the iterations: \begin{gather} \boldsymbol{Y}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)\boldsymbol{Y}_t \label{eq:r1} \\[6pt] \boldsymbol{Z}_{t+1} = \boldsymbol{Z}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2) \label{eq:r2} \end{gather} Then \boldsymbol{Y}_t \to \boldsymbol{A}(\boldsymbol{B}\boldsymbol{A})^{-1/2} and \boldsymbol{Z}_t \to \boldsymbol{B}(\boldsymbol{A}\boldsymbol{B})^{-1/2}. Specifically, multiplying the two equations above yields the recursion for \boldsymbol{Y}_t\boldsymbol{Z}_t: \begin{equation} \boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t \label{eq:r3} \end{equation}
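
For reference, here is a minimal generic sketch of this reduced iteration, assuming the eigenvalues of \boldsymbol{A}\boldsymbol{B} have already been scaled into [0, 1] and reusing the coefficient generator abc() from the reference code further below (the helper name pair_iteration is hypothetical):

import numpy as np

def pair_iteration(A, B, steps=6):
    # Iterate Eq. [eq:r1]/[eq:r2]: Y_t -> A(BA)^{-1/2}, Z_t -> B(AB)^{-1/2}.
    Y, Z = A, B
    I = np.eye(A.shape[0])
    for a, b, c in abc(steps):
        YZ = Y @ Z
        W = a * I + b * YZ + c * YZ @ YZ
        Y, Z = W @ Y, Z @ W
    return Y, Z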

Calculating the Square Root

Now we formally enter the calculation of the square root. Since we assumed the eigenvalues of \boldsymbol{P} are non-negative, we can always further compress its eigenvalues into the range [0, 1] by dividing by \operatorname{tr}(\boldsymbol{P}). Therefore, without loss of generality, we assume the eigenvalues of \boldsymbol{P} are in [0, 1], allowing us to use Newton-Schulz iteration to compute \operatorname{mcsgn} directly.

Substituting \boldsymbol{A}=\boldsymbol{P}, \boldsymbol{B}=\boldsymbol{I} into Eq. [eq:core], we get: \begin{equation} \operatorname{mcsgn}\left(\begin{bmatrix}\boldsymbol{0} & \boldsymbol{P} \\ \boldsymbol{I} & \boldsymbol{0}\end{bmatrix}\right) = \begin{bmatrix}\boldsymbol{0} & \boldsymbol{P}^{1/2} \\ \boldsymbol{P}^{-1/2} & \boldsymbol{0}\end{bmatrix} \end{equation} It is quite remarkable that, theoretically, with just one \operatorname{mcsgn} operation, both the square root and the inverse square root can be obtained. By iterating according to Eq. [eq:r1] and [eq:r2], we can complete both tasks simultaneously!

However, in practice, it is not always ideal. If \boldsymbol{P} has singular values very close to 0, then \boldsymbol{P}^{-1/2} will suffer from numerical explosion (equivalent to 1/\sqrt{0}), whereas \boldsymbol{P}^{1/2} will not. Thus, if we only care about the value of \boldsymbol{P}^{1/2}, calculating both \boldsymbol{P}^{1/2} and \boldsymbol{P}^{-1/2} simultaneously might increase numerical instability. In this case, a better approach is to iterate using Eq. [eq:r1] and [eq:r3] to calculate only \boldsymbol{P}^{1/2}: \begin{gather} \boldsymbol{Y}_0 = \boldsymbol{P}, \quad \boldsymbol{Y}_0\boldsymbol{Z}_0 = \boldsymbol{P} \notag\\[6pt] \boldsymbol{Y}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)\boldsymbol{Y}_t \\[6pt] \boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t \\[6pt] \lim_{t\to\infty} \boldsymbol{Y}_t = \boldsymbol{P}^{1/2}\notag \end{gather} Since the limit of \boldsymbol{Z}_t is \boldsymbol{P}^{-1/2}, the limit of \boldsymbol{Y}_t\boldsymbol{Z}_t is \boldsymbol{I}. Therefore, iterating \boldsymbol{Y}_t\boldsymbol{Z}_t is less prone to numerical risks. Reference code is as follows:

import numpy as np

def abc(steps):
    # Tuned Newton-Schulz coefficients (a, b, c), still carrying the 1.01 safety
    # factors listed in the table above.
    coefs = [
        (8.287212018145622, -23.59588651909882, 17.300387312530923),
        (4.107059111542197, -2.9478499167379084, 0.54484310829266),
        (3.9486908534822938, -2.908902115962947, 0.5518191394370131),
        (3.3184196573706055, -2.488488024314878, 0.5100489401237208),
        (2.3006520199548186, -1.6689039845747518, 0.4188073119525678),
        (1.8913014077874002, -1.2679958271945908, 0.37680408948524996),
        (1.875, -1.25, 0.375)
    ]
    # Take the first `steps` rows, padding with the final row if steps > 7,
    # then divide out the 1.01 safety factors.
    for a, b, c in coefs[:steps] + max(steps - 7, 0) * [coefs[-1]]:
        yield a / 1.01, b / 1.01**3, c / 1.01**5

def msqrt(P, steps=6):
    # Scale by the trace so that the eigenvalues of P/t lie in [0, 1].
    Y = YZ = P / (t := np.trace(P))
    I = np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W = a * I + b * YZ + c * YZ @ YZ   # W = aI + b(YZ) + c(YZ)^2
        Y, YZ = W @ Y, W @ W @ YZ          # Eq. [eq:r1] and [eq:r3]
    return Y * t**0.5                      # Y -> (P/t)^{1/2}, so undo the scaling

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
print(np.abs(msqrt(P) @ msqrt(P) - P).mean())  # ~= 2e-4
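
For completeness, here is a minimal sketch of the "both at once" variant mentioned in the previous section, iterating Eq. [eq:r1] and [eq:r2] to return \boldsymbol{P}^{1/2} and \boldsymbol{P}^{-1/2} together (the helper name msqrt_and_mrsqrt is hypothetical, and the caveat about near-singular \boldsymbol{P} applies):

def msqrt_and_mrsqrt(P, steps=6):
    # Y_0 = P / t, Z_0 = I; Y_t -> (P/t)^{1/2} and Z_t -> (P/t)^{-1/2}.
    Y = P / (t := np.trace(P))
    Z = I = np.eye(P.shape[0])
    for a, b, c in abc(steps):
        YZ = Y @ Z
        W = a * I + b * YZ + c * YZ @ YZ
        Y, Z = W @ Y, Z @ W                # Eq. [eq:r1] and [eq:r2]
    return Y * t**0.5, Z / t**0.5          # (P^{1/2}, P^{-1/2})

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
S, R = msqrt_and_mrsqrt(P)
print(np.abs(S @ S - P).mean(), np.abs(R @ R @ P - np.eye(d)).mean())  # both small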

Calculating the Inverse Square Root

If we must explicitly calculate the inverse square root \boldsymbol{P}^{-1/2}, there is no easy way around it; what is destined to explode will explode. In this case, whether using the combination of Eq. [eq:r2] and [eq:r1] or Eq. [eq:r2] and [eq:r3], the results should be similar, though the latter might be relatively more stable: \begin{gather} \boldsymbol{Z}_0 = \boldsymbol{I}, \quad \boldsymbol{Y}_0\boldsymbol{Z}_0 = \boldsymbol{P} \notag\\[6pt] \boldsymbol{Z}_{t+1} = \boldsymbol{Z}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2) \label{eq:r2-rsqrt} \\[6pt] \boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t \label{eq:r3-rsqrt} \\[6pt] \lim_{t\to\infty} \boldsymbol{Z}_t = \boldsymbol{P}^{-1/2}\notag \end{gather}

Reference code:

def mrsqrt(P, steps=6):
    # Scale by the trace so that the eigenvalues of P/t lie in [0, 1].
    YZ = P / (t := np.trace(P))
    Z = I = np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W = a * I + b * YZ + c * YZ @ YZ
        Z, YZ = Z @ W, W @ W @ YZ          # Eq. [eq:r2-rsqrt] and [eq:r3-rsqrt]
    return Z / t**0.5                      # Z -> (P/t)^{-1/2}, so undo the scaling

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
print(np.abs(mrsqrt(P) @ mrsqrt(P) @ P - np.eye(d)).mean())  # ~= 5e-4

Matrix Multiplication

In most cases, calculating \boldsymbol{P}^{-1/2} is just an intermediate step, usually followed by multiplication with another matrix. Let \boldsymbol{G} \in \mathbb{R}^{m \times n}; we need to compute \boldsymbol{G}\boldsymbol{P}^{-1/2}. If we can treat \boldsymbol{G}\boldsymbol{P}^{-1/2} as a single iterative object, it often provides better numerical stability compared to calculating \boldsymbol{P}^{-1/2} separately and then performing the matrix multiplication.

Looking closely at Eq. [eq:r2-rsqrt] and [eq:r3-rsqrt], it is clear that when we treat \boldsymbol{Y}_t\boldsymbol{Z}_t as a whole, its iteration Eq. [eq:r3-rsqrt] is independent of \boldsymbol{Z}_t. Thus, Eq. [eq:r2-rsqrt] for \boldsymbol{Z}_t is essentially just a linear recursion! Multiplying it by a matrix on the left does not change the iterative form; we only need to modify the initial value: \begin{gather} \boldsymbol{Z}_0 = \boldsymbol{G}, \quad \boldsymbol{Y}_0\boldsymbol{Z}_0 = \boldsymbol{P} \notag\\[6pt] \boldsymbol{Z}_{t+1} = \boldsymbol{Z}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2) \label{eq:r2-final} \\[6pt] \boldsymbol{Y}_{t+1}\boldsymbol{Z}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Y}_t\boldsymbol{Z}_t + c_{t+1}(\boldsymbol{Y}_t\boldsymbol{Z}_t)^2)^2\boldsymbol{Y}_t\boldsymbol{Z}_t \label{eq:r3-final}\\[6pt] \lim_{t\to\infty} \boldsymbol{Z}_t = \boldsymbol{G}\boldsymbol{P}^{-1/2}\notag \end{gather}

Reference code:

import scipy as sp
import scipy.linalg  # ensure sp.linalg is available on SciPy versions without lazy subpackage loading

def matmul_mrsqrt(G, P, steps=6):
    # Only P needs rescaling; the iteration for Z is linear, so Z_0 = G is used as-is.
    YZ = P / (t := np.trace(P))
    Z, I = G, np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W = a * I + b * YZ + c * YZ @ YZ
        Z, YZ = Z @ W, W @ W @ YZ          # Eq. [eq:r2-final] and [eq:r3-final]
    return Z / t**0.5                      # Z -> G (P/t)^{-1/2}, so undo the scaling

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
G = np.random.randn(2 * d, d) / d**0.5
X = matmul_mrsqrt(G, P)
print(np.abs(X @ sp.linalg.sqrtm(P) - G).mean())  # ~= 1e-4

Now, looking back at the algorithm for the square root, it is easy to see that it is just another equivalent way of writing the iteration in this section when \boldsymbol{G}=\boldsymbol{P}, i.e., \boldsymbol{P}^{1/2} = \boldsymbol{P}\boldsymbol{P}^{-1/2}. So, although we discussed three iterations in three separate sections, they are all essentially special cases of the last iteration!
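
This equivalence is easy to check numerically with the functions defined above; since each iteration factor is a polynomial in \boldsymbol{P} and therefore commutes with it, the two computations agree up to floating-point error:

d = 100
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
# P^{1/2} = P P^{-1/2}: msqrt(P) and matmul_mrsqrt(P, P) run the same iteration.
print(np.abs(msqrt(P) - matmul_mrsqrt(P, P)).mean())  # close to 0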

Ultimate Generalization

Finally, we can generalize this to the calculation of \boldsymbol{Q}^{-1/2}\boldsymbol{G}\boldsymbol{P}^{-1/2}, where \boldsymbol{Q} \in \mathbb{R}^{m \times m} is another matrix with non-negative eigenvalues. The result is as follows: \begin{gather} \boldsymbol{G}_0 = \boldsymbol{G}, \quad \boldsymbol{Q}_0 = \boldsymbol{Q}, \quad \boldsymbol{P}_0 = \boldsymbol{P} \notag\\[6pt] \boldsymbol{G}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Q}_t + c_{t+1}\boldsymbol{Q}_t^2)\boldsymbol{G}_t(a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2) \\[6pt] \boldsymbol{Q}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{Q}_t + c_{t+1}\boldsymbol{Q}_t^2)^2\boldsymbol{Q}_t \\[6pt] \boldsymbol{P}_{t+1} = (a_{t+1}\boldsymbol{I} + b_{t+1}\boldsymbol{P}_t + c_{t+1}\boldsymbol{P}_t^2)^2\boldsymbol{P}_t \\[6pt] \lim_{t\to\infty} \boldsymbol{G}_t = \boldsymbol{Q}^{-1/2}\boldsymbol{G}\boldsymbol{P}^{-1/2}\notag \end{gather}

Reference code:

def mrsqrt_matmul_mrsqrt(Q, G, P, steps=6):
    # Scale Q and P by their traces so that both have eigenvalues in [0, 1].
    Q = Q / (t1 := np.trace(Q))
    P = P / (t2 := np.trace(P))
    I1, I2 = np.eye(Q.shape[0]), np.eye(P.shape[0])
    for a, b, c in abc(steps):
        W1 = a * I1 + b * Q + c * Q @ Q
        W2 = a * I2 + b * P + c * P @ P
        G, Q, P = W1 @ G @ W2, W1 @ W1 @ Q, W2 @ W2 @ P
    return G / (t1 * t2) ** 0.5            # G -> (Q/t1)^{-1/2} G (P/t2)^{-1/2}

d = 100
Q = (x := np.random.randn(2 * d, 2 * d) / (2 * d)**0.5) @ x.T
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
G = np.random.randn(2 * d, d) / d**0.5
X = mrsqrt_matmul_mrsqrt(Q, G, P)
print(np.abs(sp.linalg.sqrtm(Q) @ X @ sp.linalg.sqrtm(P) - G).mean())  # ~= 2e-3

Readers are invited to complete the proof themselves based on the results of the previous sections.

For the Shampoo optimizer, we need to compute \boldsymbol{Q}^{-1/4}\boldsymbol{G}\boldsymbol{P}^{-1/4}. Currently, a feasible approach seems to be to first calculate \boldsymbol{Q}^{1/2} and \boldsymbol{P}^{1/2} separately, and then substitute them into the above iteration to obtain (\boldsymbol{Q}^{1/2})^{-1/2}\boldsymbol{G}(\boldsymbol{P}^{1/2})^{-1/2}. This looks computationally expensive, but in the update phase of an optimizer, compute is usually not the bottleneck as long as the algorithm can be fully parallelized. Since \boldsymbol{Q}^{1/2} and \boldsymbol{P}^{1/2} can be computed in parallel, and the two matrices W_1 and W_2 within each iteration step can likewise be computed in parallel, the cost should be acceptable.
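
A minimal sketch of this two-stage approach (the helper name mrsqrt4_matmul_mrsqrt4 and the final check against SciPy are not from the original post):

def mrsqrt4_matmul_mrsqrt4(Q, G, P, steps=6):
    # Q^{-1/4} G P^{-1/4} = (Q^{1/2})^{-1/2} G (P^{1/2})^{-1/2}; msqrt(Q) and msqrt(P)
    # are independent of each other and could be computed in parallel.
    return mrsqrt_matmul_mrsqrt(msqrt(Q, steps), G, msqrt(P, steps), steps)

d = 100
Q = (x := np.random.randn(2 * d, 2 * d) / (2 * d)**0.5) @ x.T
P = (x := np.random.randn(d, d) / d**0.5) @ x.T
G = np.random.randn(2 * d, d) / d**0.5
X = mrsqrt4_matmul_mrsqrt4(Q, G, P)
Q4 = sp.linalg.fractional_matrix_power(Q, 0.25)
P4 = sp.linalg.fractional_matrix_power(P, 0.25)
print(np.abs(Q4 @ X @ P4 - G).mean())  # small; the two nested approximations compound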

Of course, it will still be slower than Muon; Shampoo is simply the more complex algorithm, and one cannot expect to pay no price for that (see "Efficient calculation of matrix r-th roots and inverse r-th roots").

Summary

This article proposes converting the matrix square root and inverse square root into \operatorname{mcsgn} form and utilizing its Newton-Schulz iteration to achieve an efficient calculation process.

Original Address: https://kexue.fm/archives/11158
