English (unofficial) translations of posts at kexue.fm

Calculating Singular Value Clipping (mclip) via msign (Part 2)

Translated by Gemini Flash 3.0 Preview. Translations may be inaccurate; please refer to the original post for anything important.

Previously, in "Calculating Singular Value Clipping (mclip) via msign (Part 1)", we discussed the numerical calculation of singular value clipping (\operatorname{mclip}). The core idea originated from @leloykun’s article "Numerically Stable Spectral Clipping Via Newton-Schulz Iteration" (since revised and renamed): by finding an expression based on \operatorname{msign}, we avoid having to search for a separate Newton-Schulz iteration. In that post, I proposed a nested \operatorname{msign} scheme with lower computational cost.

However, a few days ago, @leloykun pointed out on Twitter that my scheme suffers from relatively large errors in practical calculations. This article analyzes this issue in detail and provides a more efficient new scheme with lower error.

Basic Concepts

As per convention, let’s first organize the basic concepts. First is the \operatorname{clip} operator for a scalar x, which we define generally as: \begin{equation} \operatorname{clip}_{[\alpha,\beta]}(x) = \max(\min(x, \beta), \alpha) = \left\{\begin{aligned}\beta, & \quad x \geq \beta \\ x, & \quad x \in (\alpha, \beta)\\ \alpha, & \quad x \leq \alpha \end{aligned}\right. \end{equation} When the interval is not specified, the default is [-1,1], i.e., \operatorname{clip}(x) = \operatorname{clip}_{[-1,1]}(x). Let the SVD of matrix \boldsymbol{M} \in \mathbb{R}^{n \times m} be \boldsymbol{M} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}, where \boldsymbol{U} \in \mathbb{R}^{n \times n} and \boldsymbol{V} \in \mathbb{R}^{m \times m} are orthogonal matrices, and \boldsymbol{\Sigma} \in \mathbb{R}^{n \times m} is the diagonal matrix of singular values. Then we define: \begin{equation} \operatorname{mclip}_{[\alpha,\beta]}(\boldsymbol{M}) = \boldsymbol{U} \operatorname{clip}_{[\alpha,\beta]}(\boldsymbol{\Sigma}) \boldsymbol{V}^{\top} \end{equation} Applying \operatorname{clip} to a diagonal matrix means applying \operatorname{clip} to each of its diagonal elements. Simply put, \operatorname{mclip}_{[\alpha,\beta]} clips the singular values of \boldsymbol{M} into the range [\alpha, \beta].
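
For reference, this definition can be realized directly with an SVD. Here is a minimal sketch in JAX (the helper name mclip_svd is mine, not from the original post):

import jax.numpy as jnp

def mclip_svd(m, alpha=0.0, beta=1.0):
    """Reference mclip via an explicit SVD (accurate but expensive)."""
    u, s, vh = jnp.linalg.svd(m, full_matrices=False)
    # clip each singular value into [alpha, beta]; since s >= 0, a negative alpha acts like 0
    return u @ jnp.diag(jnp.clip(s, alpha, beta)) @ vh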

Since singular values are non-negative, when \alpha < 0, we have \operatorname{mclip}_{[\alpha,\beta]}(\boldsymbol{M}) = \operatorname{mclip}_{[0,\beta]}(\boldsymbol{M}). However, as we will see later, due to practical calculation errors, considering a negative \alpha can have a magical error-canceling effect.

Theoretical General Solution

The goal of this section is to express \operatorname{mclip} using \operatorname{msign}. The starting point is the scalar identity: \begin{equation} \operatorname{clip}_{[\alpha,\beta]}(x) = \frac{\alpha + \beta + (\alpha - x)\operatorname{sign}(\alpha - x) - (\beta - x)\operatorname{sign}(\beta - x)}{2} \end{equation} The key to finding this identity is expressing \operatorname{clip} as a linear combination of absolute values and the variable itself, then transitioning to the \operatorname{sign} operation via |x| = x \operatorname{sign}(x). We won’t expand on that here.
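
The identity is easy to check numerically; a throwaway script (the interval [0.3, 1] is an arbitrary choice of mine):

import numpy as np

alpha, beta = 0.3, 1.0
x = np.linspace(-2, 3, 1001)
lhs = np.clip(x, alpha, beta)
rhs = (alpha + beta + (alpha - x) * np.sign(alpha - x)
       - (beta - x) * np.sign(beta - x)) / 2
assert np.allclose(lhs, rhs)  # holds even at x = alpha or x = beta, where sign(0) = 0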

For simplicity, assume \boldsymbol{M} is a full-rank square matrix. Based on this identity, we have: \begin{equation} 2\operatorname{mclip}_{[\alpha,\beta]}(\boldsymbol{M}) = \boldsymbol{U}\Big((\alpha + \beta)\boldsymbol{I} + (\alpha \boldsymbol{I} - \boldsymbol{\Sigma})\operatorname{sign}(\alpha \boldsymbol{I} - \boldsymbol{\Sigma}) - (\beta \boldsymbol{I} - \boldsymbol{\Sigma})\operatorname{sign}(\beta \boldsymbol{I} - \boldsymbol{\Sigma})\Big)\boldsymbol{V}^{\top} \end{equation} Expanding the right side, it contains several types of terms (where \gamma \in \{\alpha, \beta\}):

\begin{equation} \begin{aligned} \boldsymbol{U}\boldsymbol{V}^{\top} &= \operatorname{msign}(\boldsymbol{M}) \\ \boldsymbol{U}\operatorname{sign}(\gamma \boldsymbol{I} - \boldsymbol{\Sigma})\boldsymbol{V}^{\top} &= \operatorname{msign}(\gamma \boldsymbol{U}\boldsymbol{V}^{\top} - \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}) = \operatorname{msign}(\gamma \operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \\ \boldsymbol{U}\boldsymbol{\Sigma}\operatorname{sign}(\gamma \boldsymbol{I} - \boldsymbol{\Sigma})\boldsymbol{V}^{\top} &= \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}\boldsymbol{V}\boldsymbol{U}^{\top}\boldsymbol{U}\operatorname{sign}(\gamma \boldsymbol{I} - \boldsymbol{\Sigma})\boldsymbol{V}^{\top} = \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top}\operatorname{msign}(\gamma \operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \end{aligned} \end{equation}

Substituting and rearranging, we get: \begin{equation} \operatorname{mclip}_{[\alpha,\beta]}(\boldsymbol{M}) = \frac{1}{2}\left\{\begin{aligned}&\,(\alpha + \beta)\operatorname{msign}(\boldsymbol{M}) \\ + &\, (\alpha \boldsymbol{I} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top})\operatorname{msign}(\alpha \operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M})\\ - &\, (\beta \boldsymbol{I} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top})\operatorname{msign}(\beta \operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \end{aligned}\right\} \label{eq:general} \end{equation} For non-square or non-full-rank matrices, one can verify this directly by substituting \operatorname{msign}(\boldsymbol{M}) = \boldsymbol{U}_{[:,:r]}\boldsymbol{V}_{[:,:r]}^{\top}, where r is the rank of \boldsymbol{M}. Thus, the above equation is the theoretical general solution for \operatorname{mclip}.
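
Transcribed directly into code, the general solution reads roughly as follows (a sketch, assuming the msign routine from the Comparison Code section at the end of this post; mclip_general is a name I am introducing):

import jax.numpy as jnp

def mclip_general(m, alpha, beta):
    """Sketch of the general solution [eq:general]: three msign calls, two of them nested."""
    ms1 = msign(m)          # msign(M)
    proj = m @ ms1.mT       # M msign(M)^T (= U Sigma U^T for full-rank square M)
    eye = jnp.eye(m.shape[-2])
    term_a = (alpha * eye - proj) @ msign(alpha * ms1 - m)
    term_b = (beta * eye - proj) @ msign(beta * ms1 - m)
    return ((alpha + beta) * ms1 + term_a - term_b) / 2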

Initial Form

Equation [eq:general] appears to require at least three \operatorname{msign} calculations, and the inputs of the latter two depend on the result of the first, making it a nested \operatorname{msign} form. When we set \alpha=0, \beta=1, the number of \operatorname{msign} operations can be reduced to two: \begin{equation} \operatorname{mclip}(\boldsymbol{M}) = \frac{1}{2}\Big[\boldsymbol{M} + \operatorname{msign}(\boldsymbol{M}) + (\boldsymbol{I} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top}) \operatorname{msign}(\boldsymbol{M} - \operatorname{msign}(\boldsymbol{M}))\Big] \label{eq:mclip-1} \end{equation} This is the result I provided in the previous article.

However, empirical tests show that when the singular values of \boldsymbol{M} are large and the precision of the \operatorname{msign} calculation is low, this formula produces significant errors, much larger than the scheme provided by @leloykun. But @leloykun’s scheme requires calculating \operatorname{msign} of the block matrix \begin{bmatrix}\boldsymbol{I} & \boldsymbol{M} \\ \boldsymbol{M}^{\top} & \boldsymbol{I}\end{bmatrix}, which has roughly four times as many entries as \boldsymbol{M} and is therefore costly. So I wanted to see whether there was room for improvement in my scheme.

Removing Nesting

Intuitively, the source of the error is the cumulative error caused by nested \operatorname{msign} operations. Thus, I tried to find a way to remove the nesting. Fortunately, using a simple trick, the nesting can indeed be removed!

First, it can be proven that: \begin{equation} \begin{aligned} &\,(\boldsymbol{I} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top}) \operatorname{msign}(\boldsymbol{M} - \operatorname{msign}(\boldsymbol{M})) \\[6pt] =&\, (\operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \operatorname{msign}(\operatorname{msign}(\boldsymbol{M})^{\top}\boldsymbol{M} - \boldsymbol{I}) \end{aligned} \end{equation} Then we have: \begin{equation} \operatorname{msign}(\boldsymbol{M})^{\top}\boldsymbol{M} - \boldsymbol{I} = \boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} - \boldsymbol{I} = \boldsymbol{V}(\boldsymbol{\Sigma}-\boldsymbol{I})\boldsymbol{V}^{\top} \end{equation} Based on this, we assert: \begin{equation} \operatorname{msign}(\operatorname{msign}(\boldsymbol{M})^{\top}\boldsymbol{M} - \boldsymbol{I}) = \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I}) = \operatorname{msign}(\boldsymbol{V}(\boldsymbol{\Sigma}^2-\boldsymbol{I})\boldsymbol{V}^{\top}) \end{equation} This utilizes a very simple property: \forall x \geq 0, \operatorname{sign}(x-1) = \operatorname{sign}(x^2-1). Using this result, we obtain: \begin{equation} \operatorname{mclip}(\boldsymbol{M}) = \frac{1}{2}\Big[\boldsymbol{M} + \operatorname{msign}(\boldsymbol{M}) + (\operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I})\Big] \label{eq:mclip-2} \end{equation} This still uses two \operatorname{msign} operations, but they are no longer nested, meaning there is theoretically no cumulative error from nested \operatorname{msign} operations. Empirical tests show that the error of Equation [eq:mclip-2] is indeed about half that of Equation [eq:mclip-1], but in extreme cases, it is still not as good as @leloykun’s scheme. This suggests that nesting is not the primary source of error.
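
Incidentally, the de-nesting identity is easy to verify numerically in high precision. A minimal sketch (all names here are mine), using an exact SVD-based msign:

import numpy as np

def msign_exact(a):
    """Exact msign via SVD: the orthogonal polar factor U V^T."""
    u, s, vh = np.linalg.svd(a, full_matrices=False)
    return u @ vh

rng = np.random.default_rng(0)
m = rng.standard_normal((8, 8))
ms1 = msign_exact(m)
nested = (np.eye(8) - m @ ms1.T) @ msign_exact(m - ms1)
denested = (ms1 - m) @ msign_exact(m.T @ m - np.eye(8))
assert np.allclose(nested, denested)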

Mutual Cancellation

Is there any further room for improvement? @leloykun’s scheme is built from an odd matrix function, so it actually computes \operatorname{mclip}_{[-1,1]} instead of \operatorname{mclip}_{[0,1]}. Could this choice cause certain error components to cancel each other out, resulting in better computational precision?

To verify this, we substitute \alpha=-1, \beta=1 into Equation [eq:general], obtaining: \begin{equation} \operatorname{mclip}(\boldsymbol{M}) = \frac{1}{2}\left\{\begin{aligned} &\,(\boldsymbol{I} + \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top})\operatorname{msign}(\operatorname{msign}(\boldsymbol{M}) + \boldsymbol{M}) \\ - &\,(\boldsymbol{I} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top})\operatorname{msign}(\operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \end{aligned}\right\} \end{equation} Using the same de-nesting trick from the previous section, we get: \begin{equation} \operatorname{mclip}(\boldsymbol{M}) = \frac{1}{2}\left\{\begin{aligned} &\,(\operatorname{msign}(\boldsymbol{M}) + \boldsymbol{M})\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I}) \\ + &\,(\operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M})\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I}) \end{aligned}\right\} \label{eq:mclip-3} \end{equation} Note that \boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I} is always a positive-definite symmetric matrix, so theoretically \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I}) = \boldsymbol{I}, which would recover Equation [eq:mclip-2]. However, in practical calculations, the error between \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I}) and \boldsymbol{I} might cancel out the error introduced by \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I}). Thus, we decide whether to keep it based on experiments.
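
As a quick sanity check on the claim that \operatorname{msign} of a positive-definite symmetric matrix is the identity (here in float64, where the computation is essentially exact):

import numpy as np

rng = np.random.default_rng(0)
m = rng.standard_normal((6, 6))
a = m.T @ m + np.eye(6)                # symmetric positive-definite
u, s, vh = np.linalg.svd(a)
assert np.allclose(u @ vh, np.eye(6))  # msign(a) = U V^T = I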

As expected, the numerical error of Equation [eq:mclip-3] is even smaller than @leloykun’s scheme! This confirms our suspicion: setting \alpha=-1 and \beta=1 to make \operatorname{mclip} an odd function helps cancel out errors.

Reasoning

Why does this happen to cancel errors? We can perform a simple quantitative analysis. Large errors occur under two conditions: first, \boldsymbol{M} has very large singular values; second, the number of \operatorname{msign} iterations is low, resulting in low precision for \operatorname{msign} itself.

Observing Equation [eq:mclip-3], it can be split into four terms. The terms \operatorname{msign}(\boldsymbol{M})\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} \pm \boldsymbol{I}) are bounded; even if \operatorname{msign} precision is low, they generally won’t diverge. Therefore, the main error comes from: \begin{equation} \boldsymbol{M}\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I}) - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I}) \label{eq:error-1} \end{equation} This is proportional to \boldsymbol{M}, which is most likely to amplify errors. Correspondingly, the main error term in Equation [eq:mclip-2] is: \begin{equation} \boldsymbol{M} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I}) \label{eq:error-2} \end{equation} Consider singular values much larger than 1. If \operatorname{msign} were exact, it would map each of them to exactly 1, and the parts of the above expressions corresponding to large singular values would both be the expected 0.

However, if the number of \operatorname{msign} iterations is low, it might result in values like 0.6 or 1.4. In Equation [eq:error-2], the corresponding part would show a huge error of \sim \pm 0.4 \boldsymbol{M}. But in Equation [eq:error-1], when singular values are very large, the relative difference between \boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I} and \boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I} is small. Therefore, the difference between \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} \pm \boldsymbol{I}) is very small, allowing Equation [eq:error-1] to still cancel out most of the error.
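
This cancellation can be probed numerically. Below is a small sketch using the msign routine from the Comparison Code section with deliberately few iterations; the diagonal test matrix is my own arbitrary choice, with all singular values above 1 so that in exact arithmetic both expressions would vanish:

import jax.numpy as jnp

d = jnp.diag(jnp.array([100.0, 30.0, 8.0, 3.0]))  # all singular values > 1
ms_plus = msign(d.T @ d + jnp.eye(4), steps=2)    # low-precision msign(M^T M + I)
ms_minus = msign(d.T @ d - jnp.eye(4), steps=2)   # low-precision msign(M^T M - I)
err1 = d @ ms_plus - d @ ms_minus  # Equation [eq:error-1]: the two deviations largely cancel
err2 = d - d @ ms_minus            # Equation [eq:error-2]: the deviation is amplified by d
# the analysis above predicts that err1 stays small while err2 blows up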

But remember, this always assumes that \boldsymbol{M} has singular values significantly larger than 1 and that the number of iterations is low. If these conditions are not met, the error of Equation [eq:mclip-2] is already small, and Equation [eq:mclip-3] might actually increase error due to the extra \operatorname{msign} calculation. Therefore, which formula performs best depends on the specific situation.

Comparison Code

We construct a matrix whose singular values lie both above and below 1, with a maximum singular value near 1000, and test each algorithm in bfloat16 precision. The reference code is as follows:

import numpy as np
import jax.numpy as jnp
import jax.lax as lax

def msign(x, steps=4, eps=1e-20):
    """The coefficients come from https://kexue.fm/archives/10996
    """
    abc = [
        (8.287212018145622, -23.59588651909882, 17.300387312530923),
        (4.107059111542197, -2.9478499167379084, 0.54484310829266),
        (3.9486908534822938, -2.908902115962947, 0.5518191394370131),
        (3.3184196573706055, -2.488488024314878, 0.5100489401237208),
        (2.3006520199548186, -1.6689039845747518, 0.4188073119525678),
        (1.8913014077874002, -1.2679958271945908, 0.37680408948524996),
        (1.875, -1.25, 0.375)
    ]
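    # work in the orientation with rows <= cols so the Gram matrix y @ y.mT is the smaller one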
    y = x.mT if x.shape[-2] > x.shape[-1] else x
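    # scale by the Frobenius norm so every singular value starts in [0, 1]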
    y = y * lax.rsqrt((y**2).sum(axis=[-2, -1], keepdims=True) + eps)
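    # beyond 7 steps, keep repeating the final (fixed-point) coefficient triple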
    for a, b, c in abc[:steps] + max(steps - 7, 0) * abc[-1:]:
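        # equivalent to evaluating the quintic at y/1.01: a small safety margin against overshoot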
        a, b, c = a / 1.01, b / 1.01**3, c / 1.01**5
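        # one iteration: y <- a*y + b*(y @ y.mT) @ y + c*(y @ y.mT)^2 @ y, pushing singular values toward 1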
        y = a * y + (b * (u := y @ y.mT) + c * u @ u) @ y
    return y.mT if x.shape[-2] > x.shape[-1] else y

def mclip1(m):
    """1st version (2 nested msign)
    """
    ms2 = msign(m - (ms1 := msign(m)))
    return (m + ms1 + ms2 - m @ ms1.mT @ ms2) / 2

def mclip2(m):
    """2nd version (2 non-nested msign)
    """
    ms1 = msign(m)
    ms2 = msign(m.mT @ m - jnp.eye(m.shape[-1]))
    return (m + ms1 + (ms1 - m) @ ms2) / 2

def mclip3(m):
    """3rd version (3 non-nested msign)
    """
    ms1 = msign(m)
    ms2 = msign(m.mT @ m + jnp.eye(m.shape[-1]))
    ms3 = msign(m.mT @ m - jnp.eye(m.shape[-1]))
    return ((ms1 + m) @ ms2 + (ms1 - m) @ ms3) / 2

def spectral_clip(W):
    """@leloykun version: https://leloykun.github.io/ponder/spectral-clipping/
    """
    m, n = W.shape
    H = jnp.block([[jnp.eye(m), W], [W.T, jnp.eye(n)]])
    OH = msign(H)
    P, Q = OH[:m, :m], OH[:m, m:]
    return Q + P @ W

m = np.random.randn(4096, 1024)
u, s, vh = jnp.linalg.svd(m, full_matrices=False)
s = np.concatenate([np.linspace(1, 1000, 128), np.linspace(0, 1, 896)])
s = np.sort(s)[::-1]
m = u @ jnp.diag(s) @ vh  # matrix with large singular values

result0 = u @ np.diag(s.clip(0, 1)) @ vh  # exact result via SVD
result1 = mclip1(m.astype('bfloat16'))
result2 = mclip2(m.astype('bfloat16'))
result3 = mclip3(m.astype('bfloat16'))
result4 = spectral_clip(m.astype('bfloat16'))

# spectral norm of the resulting matrix, closer to 1 is better.
# result0: 1
# result1: approx 700
# result2: approx 250
# result3: approx 1.5
# result4: approx 13

# mean absolute error of singular values, closer to 0 is better.
# result1: approx 20
# result2: approx 10
# result3: approx 0.5
# result4: approx 0.7

# mean absolute error of total matrix, closer to 0 is better.
# result1: approx 1
# result2: approx 0.5
# result3: approx 0.01
# result4: approx 0.02

Summary

This article continues to refine the scheme for calculating \operatorname{mclip} using \operatorname{msign} from the previous post. By removing the nesting of \operatorname{msign} and introducing an additional error-canceling term, we have successfully reduced the computational error.

Reprinting: Please include the original link: https://kexue.fm/archives/11059

For more details on reprinting/citation, please refer to: "Scientific Space FAQ"