Previously, in "Calculating Singular Value Clipping (mclip) via msign (Part 1)", we discussed the numerical computation of singular value clipping (\operatorname{mclip}). The core idea originated from @leloykun’s article "Numerically Stable Spectral Clipping Via Newton-Schulz Iteration" (now revised and renamed): by finding an expression based on \operatorname{msign}, we avoid having to search for a separate Newton-Schulz iteration. In Part 1, I proposed a nested \operatorname{msign} scheme with lower computational cost.
However, a few days ago, @leloykun pointed out on Twitter that my scheme suffers from relatively large errors in practice. This article analyzes the issue in detail and presents a new scheme that is both cheaper and more accurate.
Basic Concepts
As per convention, let’s first organize the basic concepts. First is the \operatorname{clip} operator for a scalar x, which we define generally as: \begin{equation} \operatorname{clip}_{[\alpha,\beta]}(x) = \max(\min(x, \beta), \alpha) = \left\{\begin{aligned}\beta, & \quad x \geq \beta \\ x, & \quad x \in (\alpha, \beta)\\ \alpha, & \quad x \leq \alpha \end{aligned}\right. \end{equation} When the interval is not specified, the default is [-1,1], i.e., \operatorname{clip}(x) = \operatorname{clip}_{[-1,1]}(x). Let the SVD of matrix \boldsymbol{M} \in \mathbb{R}^{n \times m} be \boldsymbol{M} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}, where \boldsymbol{U} \in \mathbb{R}^{n \times n} and \boldsymbol{V} \in \mathbb{R}^{m \times m} are orthogonal matrices, and \boldsymbol{\Sigma} \in \mathbb{R}^{n \times m} is the diagonal matrix of singular values. Then we define: \begin{equation} \operatorname{mclip}_{[\alpha,\beta]}(\boldsymbol{M}) = \boldsymbol{U} \operatorname{clip}_{[\alpha,\beta]}(\boldsymbol{\Sigma}) \boldsymbol{V}^{\top} \end{equation} Applying \operatorname{clip} to a diagonal matrix means applying \operatorname{clip} to each of its diagonal elements. Simply put, \operatorname{mclip}_{[\alpha,\beta]} clips the singular values of \boldsymbol{M} into the range [\alpha, \beta].
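For concreteness, this definition can be implemented directly via an explicit SVD. Below is a minimal numpy sketch (the helper name `mclip_svd` is mine); it is exact but expensive, and serves only as a reference point for the \operatorname{msign}-based schemes that follow:

```python
import numpy as np

def mclip_svd(M, alpha=-1.0, beta=1.0):
    """Reference mclip: clip the singular values of M into [alpha, beta] via SVD."""
    U, s, Vh = np.linalg.svd(M, full_matrices=False)
    return U @ np.diag(np.clip(s, alpha, beta)) @ Vh

# all singular values of the result lie in [0, 1]
M = np.random.randn(8, 5)
print(np.linalg.svd(mclip_svd(M, 0, 1), compute_uv=False))
```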
Since singular values are non-negative, when \alpha < 0, we have \operatorname{mclip}_{[\alpha,\beta]}(\boldsymbol{M}) = \operatorname{mclip}_{[0,\beta]}(\boldsymbol{M}). However, as we will see later, due to practical calculation errors, considering a negative \alpha can have a magical error-canceling effect.
Theoretical General Solution
The goal of this section is to express \operatorname{mclip} using \operatorname{msign}. The starting point is the scalar identity: \begin{equation} \operatorname{clip}_{[\alpha,\beta]}(x) = \frac{\alpha + \beta + (\alpha - x)\operatorname{sign}(\alpha - x) - (\beta - x)\operatorname{sign}(\beta - x)}{2} \end{equation} The key to finding this identity is to express \operatorname{clip} as a linear combination of absolute values and the variable itself, then pass to the \operatorname{sign} operation via |x| = x \operatorname{sign}(x). We won’t expand on that here.
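The identity is also easy to check numerically; a quick sanity test:

```python
import numpy as np

def clip_via_sign(x, alpha=-1.0, beta=1.0):
    # clip(x) = [alpha + beta + (alpha - x) sign(alpha - x) - (beta - x) sign(beta - x)] / 2
    return (alpha + beta + (alpha - x) * np.sign(alpha - x)
            - (beta - x) * np.sign(beta - x)) / 2

x = np.linspace(-3, 3, 101)
assert np.allclose(clip_via_sign(x, -1, 1), np.clip(x, -1, 1))
```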
For simplicity, assume \boldsymbol{M} is a full-rank square matrix. Based on this identity, we have: \begin{equation} 2\operatorname{mclip}_{[\alpha,\beta]}(\boldsymbol{M}) = \boldsymbol{U}\Big((\alpha + \beta)\boldsymbol{I} + (\alpha \boldsymbol{I} - \boldsymbol{\Sigma})\operatorname{sign}(\alpha \boldsymbol{I} - \boldsymbol{\Sigma}) - (\beta \boldsymbol{I} - \boldsymbol{\Sigma})\operatorname{sign}(\beta \boldsymbol{I} - \boldsymbol{\Sigma})\Big)\boldsymbol{V}^{\top} \end{equation} Expanding the right side yields several types of terms (where \gamma \in \{\alpha, \beta\}):
| Original | Simplified |
|---|---|
| \boldsymbol{U}\boldsymbol{V}^{\top} | \operatorname{msign}(\boldsymbol{M}) |
| \boldsymbol{U}\operatorname{sign}(\gamma \boldsymbol{I} - \boldsymbol{\Sigma})\boldsymbol{V}^{\top} | \begin{aligned} & \operatorname{msign}(\gamma \boldsymbol{U}\boldsymbol{V}^{\top} - \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}) \\ = & \operatorname{msign}(\gamma \operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \end{aligned} |
| \boldsymbol{U}\boldsymbol{\Sigma}\operatorname{sign}(\gamma \boldsymbol{I} - \boldsymbol{\Sigma})\boldsymbol{V}^{\top} | \begin{aligned} & \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}\boldsymbol{V}\boldsymbol{U}^{\top}\boldsymbol{U}\operatorname{sign}(\gamma \boldsymbol{I} - \boldsymbol{\Sigma})\boldsymbol{V}^{\top} \\ = & \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top}\operatorname{msign}(\gamma \operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \end{aligned} |
Substituting and rearranging, we get: \begin{equation} \operatorname{mclip}_{[\alpha,\beta]}(\boldsymbol{M}) = \frac{1}{2}\left\{\begin{aligned}&\,(\alpha + \beta)\operatorname{msign}(\boldsymbol{M}) \\ + &\, (\alpha \boldsymbol{I} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top})\operatorname{msign}(\alpha \operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M})\\ - &\, (\beta \boldsymbol{I} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top})\operatorname{msign}(\beta \operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \end{aligned}\right\} \label{eq:general} \end{equation} For non-square or non-full-rank matrices, one can verify this by substituting \operatorname{msign}(\boldsymbol{M}) = \boldsymbol{U}_{[:,:r]}\boldsymbol{V}_{[:,:r]}^{\top}. Thus, the above equation is the theoretical general solution for \operatorname{mclip}.
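Written as code, the general solution costs three \operatorname{msign} calls. Here is a sketch (my own helper; `msign` is passed in as a callable, e.g. the Newton-Schulz routine from the comparison code at the end of this post, or an exact SVD-based stand-in for testing):

```python
import jax.numpy as jnp

def mclip_general(M, alpha, beta, msign):
    """Sketch of the theoretical general solution, Equation [eq:general]."""
    ms = msign(M)                  # msign(M) = U V^T
    P = M @ ms.mT                  # U Sigma U^T
    I = jnp.eye(M.shape[-2])
    term_a = (alpha * I - P) @ msign(alpha * ms - M)
    term_b = (beta * I - P) @ msign(beta * ms - M)
    return ((alpha + beta) * ms + term_a - term_b) / 2
```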
Initial Form
Equation [eq:general] appears to require at least three \operatorname{msign} calculations, and the inputs of the latter two depend on the result of the first, making it a nested \operatorname{msign} form. When we set \alpha=0, \beta=1, the count can be reduced to two: \begin{equation} \operatorname{mclip}(\boldsymbol{M}) = \frac{1}{2}\Big[\boldsymbol{M} + \operatorname{msign}(\boldsymbol{M}) + (\boldsymbol{I} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top}) \operatorname{msign}(\boldsymbol{M} - \operatorname{msign}(\boldsymbol{M}))\Big] \label{eq:mclip-1} \end{equation} This is the result I provided in the previous article; it requires only two \operatorname{msign} operations, though they are still nested.
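To spell out the reduction: with \alpha = 0, the \alpha-term of Equation [eq:general] collapses to \boldsymbol{M}, because \operatorname{msign}(-\boldsymbol{M}) = -\operatorname{msign}(\boldsymbol{M}) and \begin{equation} \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top}\operatorname{msign}(\boldsymbol{M}) = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{U}^{\top}\boldsymbol{U}\boldsymbol{V}^{\top} = \boldsymbol{M} \end{equation}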
However, empirical tests show that when the singular values of \boldsymbol{M} are large and the \operatorname{msign} calculation is low-precision, this formula produces errors much larger than those of the scheme provided by @leloykun. On the other hand, @leloykun’s scheme requires computing \operatorname{msign} of the block matrix \begin{bmatrix}\boldsymbol{I} & \boldsymbol{M} \\ \boldsymbol{M}^{\top} & \boldsymbol{I}\end{bmatrix}, which has roughly four times as many entries and is correspondingly costly. I therefore wanted to see whether my scheme had room for improvement.
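As a quick sanity check of why the block construction works (my own sketch of the argument, not taken from @leloykun’s write-up): with \boldsymbol{M} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}, the vectors \begin{bmatrix}\boldsymbol{u}_i \\ \pm\boldsymbol{v}_i\end{bmatrix} are eigenvectors of the block matrix with eigenvalues 1 \pm \sigma_i, so applying \operatorname{msign} maps these eigenvalues to \operatorname{sign}(1 \pm \sigma_i). Reading off the top two blocks \boldsymbol{P}, \boldsymbol{Q} of the result then gives \begin{equation} \boldsymbol{Q} + \boldsymbol{P}\boldsymbol{M} = \sum_i \left[\frac{1 - \operatorname{sign}(1 - \sigma_i)}{2} + \frac{1 + \operatorname{sign}(1 - \sigma_i)}{2}\,\sigma_i\right] \boldsymbol{u}_i \boldsymbol{v}_i^{\top} = \sum_i \min(\sigma_i, 1)\,\boldsymbol{u}_i \boldsymbol{v}_i^{\top} \end{equation} which is exactly \operatorname{mclip}(\boldsymbol{M}).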
Removing Nesting
Intuitively, the source of the error is the cumulative error caused by nested \operatorname{msign} operations. Thus, I tried to find a way to remove the nesting. Fortunately, using a simple trick, the nesting can indeed be removed!
First, it can be proven that: \begin{equation} \begin{aligned} &\,(\boldsymbol{I} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top}) \operatorname{msign}(\boldsymbol{M} - \operatorname{msign}(\boldsymbol{M})) \\[6pt] =&\, (\operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \operatorname{msign}(\operatorname{msign}(\boldsymbol{M})^{\top}\boldsymbol{M} - \boldsymbol{I}) \end{aligned} \end{equation} (substituting the SVD, both sides reduce to \boldsymbol{U}(\boldsymbol{I} - \boldsymbol{\Sigma})\operatorname{sign}(\boldsymbol{\Sigma} - \boldsymbol{I})\boldsymbol{V}^{\top}). Then we have: \begin{equation} \operatorname{msign}(\boldsymbol{M})^{\top}\boldsymbol{M} - \boldsymbol{I} = \boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} - \boldsymbol{I} = \boldsymbol{V}(\boldsymbol{\Sigma}-\boldsymbol{I})\boldsymbol{V}^{\top} \end{equation} Based on this, we assert: \begin{equation} \operatorname{msign}(\operatorname{msign}(\boldsymbol{M})^{\top}\boldsymbol{M} - \boldsymbol{I}) = \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I}) = \operatorname{msign}(\boldsymbol{V}(\boldsymbol{\Sigma}^2-\boldsymbol{I})\boldsymbol{V}^{\top}) \end{equation} This uses a very simple property: \forall x \geq 0, \operatorname{sign}(x-1) = \operatorname{sign}(x^2-1). Using this result, we obtain: \begin{equation} \operatorname{mclip}(\boldsymbol{M}) = \frac{1}{2}\Big[\boldsymbol{M} + \operatorname{msign}(\boldsymbol{M}) + (\operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I})\Big] \label{eq:mclip-2} \end{equation} This still uses two \operatorname{msign} operations, but they are no longer nested, so there is in principle no error accumulation across nested \operatorname{msign} calls. Empirical tests show that the error of Equation [eq:mclip-2] is indeed about half that of Equation [eq:mclip-1], but in extreme cases it is still not as good as @leloykun’s scheme. This suggests that nesting is not the primary source of error.
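Before moving on, a quick numerical check of Equation [eq:mclip-2], using an exact SVD-based \operatorname{msign} stand-in so that only the formula itself is being tested:

```python
import numpy as np

def msign_exact(M):
    # exact msign via SVD; also valid for symmetric inputs like M^T M - I
    U, s, Vh = np.linalg.svd(M)
    return U @ Vh

M = np.random.randn(6, 6) * 3
ms1 = msign_exact(M)
clipped = (M + ms1 + (ms1 - M) @ msign_exact(M.T @ M - np.eye(6))) / 2

U, s, Vh = np.linalg.svd(M)
assert np.allclose(clipped, U @ np.diag(np.clip(s, 0, 1)) @ Vh)
```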
Mutual Cancellation
Is there any further room for improvement? @leloykun’s scheme requires an odd function, so it actually considers \operatorname{mclip}_{[-1,1]} instead of \operatorname{mclip}_{[0,1]}. Is it possible that this choice causes certain error components to cancel each other out, resulting in better computational precision?
To verify this, we substitute \alpha=-1, \beta=1 into Equation [eq:general], obtaining: \begin{equation} \operatorname{mclip}(\boldsymbol{M}) = \frac{1}{2}\left\{\begin{aligned} &\,(\boldsymbol{I} + \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top})\operatorname{msign}(\operatorname{msign}(\boldsymbol{M}) + \boldsymbol{M}) \\ - &\,(\boldsymbol{I} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M})^{\top})\operatorname{msign}(\operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M}) \end{aligned}\right\} \end{equation} Using the same de-nesting trick from the previous section, we get: \begin{equation} \operatorname{mclip}(\boldsymbol{M}) = \frac{1}{2}\left\{\begin{aligned} &\,(\operatorname{msign}(\boldsymbol{M}) + \boldsymbol{M})\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I}) \\ + &\,(\operatorname{msign}(\boldsymbol{M}) - \boldsymbol{M})\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I}) \end{aligned}\right\} \label{eq:mclip-3} \end{equation} Note that \boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I} is always a positive-definite symmetric matrix, so theoretically \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I}) = \boldsymbol{I}, which would recover Equation [eq:mclip-2]. However, in practical calculations, the error between \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I}) and \boldsymbol{I} might cancel out the error introduced by \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I}). Thus, we decide whether to keep it based on experiments.
As expected, the numerical error of Equation [eq:mclip-3] is even smaller than @leloykun’s scheme! This confirms our suspicion: setting \alpha=-1 and \beta=1 to make \operatorname{mclip} an odd function helps cancel out errors.
Reasoning
Why does this happen to cancel errors? We can perform a simple quantitative analysis. Large errors occur under two conditions: first, \boldsymbol{M} has very large singular values; second, the number of \operatorname{msign} iterations is low, resulting in low precision for \operatorname{msign} itself.
Observing Equation [eq:mclip-3], it can be split into four terms. The terms \operatorname{msign}(\boldsymbol{M})\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} \pm \boldsymbol{I}) are bounded; even if \operatorname{msign} precision is low, they generally won’t diverge. Therefore, the main error comes from: \begin{equation} \boldsymbol{M}\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I}) - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I}) \label{eq:error-1} \end{equation} This is proportional to \boldsymbol{M} and is therefore the most likely to amplify errors. Correspondingly, the main error term in Equation [eq:mclip-2] is: \begin{equation} \boldsymbol{M} - \boldsymbol{M}\operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I}) \label{eq:error-2} \end{equation} Consider singular values much larger than 1. If \operatorname{msign} were computed exactly, the corresponding singular values of its output would all equal 1, and the parts of both expressions above associated with large singular values would be 0, as expected.
However, if the number of \operatorname{msign} iterations is low, it might result in values like 0.6 or 1.4. In Equation [eq:error-2], the corresponding part would show a huge error of \sim \pm 0.4 \boldsymbol{M}. But in Equation [eq:error-1], when singular values are very large, the relative difference between \boldsymbol{M}^{\top}\boldsymbol{M} - \boldsymbol{I} and \boldsymbol{M}^{\top}\boldsymbol{M} + \boldsymbol{I} is small. Therefore, the difference between \operatorname{msign}(\boldsymbol{M}^{\top}\boldsymbol{M} \pm \boldsymbol{I}) is very small, allowing Equation [eq:error-1] to still cancel out most of the error.
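A scalar caricature makes the cancellation visible (my own toy model: a deliberately under-converged quintic sign iteration using the (1.875, -1.25, 0.375) coefficients from the code below, with a normalization constant dominated by a largest singular value around 1000):

```python
def ns_sign(t, steps=4):
    # under-converged scalar analogue of msign
    for _ in range(steps):
        t = 1.875 * t - 1.25 * t**3 + 0.375 * t**5
    return t

N = 1000.0**2           # stand-in for the normalization applied to M^T M
s = 50.0                # a singular value well above 1 but far below the maximum
tm = (s**2 - 1) / N     # normalized eigenvalue of M^T M - I in this direction
tp = (s**2 + 1) / N     # normalized eigenvalue of M^T M + I in this direction

print(s * (1 - ns_sign(tm)))             # error term of Eq. [eq:error-2]: huge (~48)
print(s * (ns_sign(tp) - ns_sign(tm)))   # error term of Eq. [eq:error-1]: tiny (~1e-3)
```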
But remember, this always assumes that \boldsymbol{M} has singular values significantly larger than 1 and that the number of iterations is low. If these conditions are not met, the error of Equation [eq:mclip-2] is already small, and Equation [eq:mclip-3] might actually increase error due to the extra \operatorname{msign} calculation. Therefore, which formula performs best depends on the specific situation.
Comparison Code
We construct a matrix whose singular values lie both above and below 1, with the largest near 1000, and test each algorithm in bfloat16 precision. The reference code is as follows:
```python
import numpy as np
import jax.numpy as jnp
import jax.lax as lax


def msign(x, steps=4, eps=1e-20):
    """The coefficients come from https://kexue.fm/archives/10996"""
    abc = [
        (8.287212018145622, -23.59588651909882, 17.300387312530923),
        (4.107059111542197, -2.9478499167379084, 0.54484310829266),
        (3.9486908534822938, -2.908902115962947, 0.5518191394370131),
        (3.3184196573706055, -2.488488024314878, 0.5100489401237208),
        (2.3006520199548186, -1.6689039845747518, 0.4188073119525678),
        (1.8913014077874002, -1.2679958271945908, 0.37680408948524996),
        (1.875, -1.25, 0.375),
    ]
    # transpose so that rows <= cols, making y @ y.mT the smaller Gram matrix
    y = x.mT if x.shape[-2] > x.shape[-1] else x
    # normalize by the Frobenius norm so that all singular values are <= 1
    y = y * lax.rsqrt((y**2).sum(axis=[-2, -1], keepdims=True) + eps)
    for a, b, c in abc[:steps] + max(steps - 7, 0) * abc[-1:]:
        # rescaling the coefficients is equivalent to evaluating at y / 1.01 (safety margin)
        a, b, c = a / 1.01, b / 1.01**3, c / 1.01**5
        y = a * y + (b * (u := y @ y.mT) + c * u @ u) @ y
    return y.mT if x.shape[-2] > x.shape[-1] else y


def mclip1(m):
    """1st version (2 nested msign)"""
    ms2 = msign(m - (ms1 := msign(m)))
    return (m + ms1 + ms2 - m @ ms1.mT @ ms2) / 2


def mclip2(m):
    """2nd version (2 non-nested msign)"""
    ms1 = msign(m)
    ms2 = msign(m.mT @ m - jnp.eye(m.shape[-1]))
    return (m + ms1 + (ms1 - m) @ ms2) / 2


def mclip3(m):
    """3rd version (3 non-nested msign)"""
    ms1 = msign(m)
    ms2 = msign(m.mT @ m + jnp.eye(m.shape[-1]))
    ms3 = msign(m.mT @ m - jnp.eye(m.shape[-1]))
    return ((ms1 + m) @ ms2 + (ms1 - m) @ ms3) / 2


def spectral_clip(W):
    """@leloykun version: https://leloykun.github.io/ponder/spectral-clipping/"""
    m, n = W.shape
    H = jnp.block([[jnp.eye(m), W], [W.T, jnp.eye(n)]])
    OH = msign(H)
    P, Q = OH[:m, :m], OH[:m, m:]
    return Q + P @ W


m = np.random.randn(4096, 1024)
u, s, vh = jnp.linalg.svd(m, full_matrices=False)
s = np.concatenate([np.linspace(1, 1000, 128), np.linspace(0, 1, 896)])
s = np.sort(s)[::-1]
m = u @ jnp.diag(s) @ vh  # matrix with large singular values

result0 = u @ np.diag(s.clip(0, 1)) @ vh  # exact result via SVD
result1 = mclip1(m.astype('bfloat16'))
result2 = mclip2(m.astype('bfloat16'))
result3 = mclip3(m.astype('bfloat16'))
result4 = spectral_clip(m.astype('bfloat16'))

# spectral norm of the resulting matrix, closer to 1 is better:
# result0: 1
# result1: approx 700
# result2: approx 250
# result3: approx 1.5
# result4: approx 13

# mean absolute error of singular values, closer to 0 is better:
# result1: approx 20
# result2: approx 10
# result3: approx 0.5
# result4: approx 0.7

# mean absolute error of total matrix, closer to 0 is better:
# result1: approx 1
# result2: approx 0.5
# result3: approx 0.01
# result4: approx 0.02
```
Summary
This article continues to refine the previous post’s scheme for calculating \operatorname{mclip} via \operatorname{msign}. By removing the nesting of \operatorname{msign} and introducing an extra error-canceling term, we have substantially reduced the computational error.