English (unofficial) translations of posts at kexue.fm

Reflections on Spectral Norm Gradients and a New Type of Weight Decay

Translated by Gemini Flash 3.0 Preview. Translations can be inaccurate, please refer to the original post for important stuff.

In the article "Appreciating the Muon Optimizer: A Substantial Leap from Vectors to Matrices", we introduced a new optimizer called "Muon." One way to understand it is as steepest descent under a spectral norm constraint, which seems to reveal a more fundamental optimization direction for matrix parameters. As is well known, we often apply weight decay to matrix parameters, which can be understood as gradient descent on the squared Frobenius norm (F-norm). From the perspective of Muon, would constructing a new weight decay from the gradient of the squared spectral norm yield better results?

So the question arises: what does the gradient or derivative of the spectral norm look like? And what would a new weight decay designed with it look like? Next, we will explore these questions.

Basic Review

The spectral norm, also known as the "2-norm," is one of the most commonly used matrix norms. Compared to the simpler Frobenius norm (F-norm), it often reveals more fundamental signals, because its definition is inherently tied to matrix multiplication. For a matrix parameter \boldsymbol{W} \in \mathbb{R}^{n \times m}, its spectral norm is defined as: \begin{equation} \Vert\boldsymbol{W}\Vert_2 \triangleq \max_{\Vert\boldsymbol{x}\Vert=1} \Vert\boldsymbol{W}\boldsymbol{x}\Vert \end{equation} Here \boldsymbol{x} \in \mathbb{R}^m is a column vector, and \Vert\cdot\Vert on the right side denotes the vector length (Euclidean norm). From another perspective, the spectral norm is the smallest constant C such that the following inequality holds for all \boldsymbol{x} \in \mathbb{R}^m: \begin{equation} \Vert\boldsymbol{W}\boldsymbol{x}\Vert \leq C\Vert\boldsymbol{x}\Vert \end{equation} It is not difficult to prove that when C is taken as the F-norm \Vert \boldsymbol{W}\Vert_F, the above inequality also holds. Therefore, we can write \Vert \boldsymbol{W}\Vert_2 \leq \Vert \boldsymbol{W}\Vert_F (because \Vert \boldsymbol{W}\Vert_F is just one possible C that makes the inequality hold, while \Vert \boldsymbol{W}\Vert_2 is the smallest such C). This conclusion also indicates that if we want to control the magnitude of the output, using the spectral norm as a regularization term is more precise than using the F-norm.
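As a quick numerical illustration of these relations (a minimal NumPy sketch, not part of the original post; np.linalg.norm(W, 2) returns the largest singular value, i.e. the spectral norm):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((5, 3))
x = rng.standard_normal(3)

spec = np.linalg.norm(W, 2)       # spectral norm (largest singular value)
fro = np.linalg.norm(W, 'fro')    # Frobenius norm

# ||Wx|| <= ||W||_2 ||x||  and  ||W||_2 <= ||W||_F
print(np.linalg.norm(W @ x) <= spec * np.linalg.norm(x) + 1e-12)  # True
print(spec <= fro)                                                # True
print(np.isclose(spec, np.linalg.svd(W, compute_uv=False)[0]))    # True
```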

As early as six years ago, in "Lipschitz Constraints in Deep Learning: Generalization and Generative Models", we discussed the spectral norm. At that time, there were two main application scenarios: first, WGAN explicitly proposed Lipschitz constraints for the discriminator, and one implementation method was normalization based on the spectral norm; second, some work indicated that the spectral norm as a regularization term performs better than F-norm regularization.

Gradient Derivation

Now let’s get to the main topic and attempt to derive the gradient of the spectral norm \nabla_{\boldsymbol{W}} \Vert\boldsymbol{W}\Vert_2. We know that the spectral norm of a matrix is numerically equal to its largest singular value, which we proved in the "Matrix Norms" section of "The Path to Low-Rank Approximation (II): SVD". This means that if \boldsymbol{W} can be decomposed via SVD as \sum_{i=1}^{\min(n,m)}\sigma_i \boldsymbol{u}_i\boldsymbol{v}_i^{\top}, then: \begin{equation} \Vert\boldsymbol{W}\Vert_2 = \sigma_1 = \boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1 \end{equation} where \sigma_1 \geq \sigma_2 \geq \cdots \geq \sigma_{\min(n,m)} \geq 0 are the singular values of \boldsymbol{W}. Taking the differential of both sides, we get: \begin{equation} d\Vert\boldsymbol{W}\Vert_2 = d\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1 + \boldsymbol{u}_1^{\top}d\boldsymbol{W}\boldsymbol{v}_1 + \boldsymbol{u}_1^{\top}\boldsymbol{W}d\boldsymbol{v}_1 \end{equation} Note that: \begin{equation} d\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1 = d\boldsymbol{u}_1^{\top}\sum_{i=1}^{\min(n,m)}\sigma_i \boldsymbol{u}_i\boldsymbol{v}_i^{\top}\boldsymbol{v}_1 = d\boldsymbol{u}_1^{\top}\sigma_1 \boldsymbol{u}_1 = \frac{1}{2}\sigma_1 d(\Vert\boldsymbol{u}_1\Vert^2)=0 \end{equation} Similarly, \boldsymbol{u}_1^{\top}\boldsymbol{W}d\boldsymbol{v}_1=0, so: \begin{equation} d\Vert\boldsymbol{W}\Vert_2 = \boldsymbol{u}_1^{\top}d\boldsymbol{W}\boldsymbol{v}_1 = \text{Tr}((\boldsymbol{u}_1 \boldsymbol{v}_1^{\top})^{\top} d\boldsymbol{W}) \quad\Rightarrow\quad \nabla_{\boldsymbol{W}}\Vert\boldsymbol{W}\Vert_2 = \boldsymbol{u}_1 \boldsymbol{v}_1^{\top} \end{equation} Note that a key condition in this proof is \sigma_1 > \sigma_2. If \sigma_1 = \sigma_2, then \Vert\boldsymbol{W}\Vert_2 can be represented as both \boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1 and \boldsymbol{u}_2^{\top}\boldsymbol{W}\boldsymbol{v}_2. The gradients calculated using the same method would be \boldsymbol{u}_1 \boldsymbol{v}_1^{\top} and \boldsymbol{u}_2 \boldsymbol{v}_2^{\top} respectively. The non-uniqueness of the result implies that the gradient does not exist. However, from a practical standpoint, the probability of two singular values being exactly equal is very small, so this point can be ignored.

(Note: The proof process here refers to an answer on Stack Exchange, but that answer did not prove d\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1=0 and \boldsymbol{u}_1^{\top}\boldsymbol{W}d\boldsymbol{v}_1=0; this part was completed by the author.)
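The conclusion \nabla_{\boldsymbol{W}}\Vert\boldsymbol{W}\Vert_2 = \boldsymbol{u}_1 \boldsymbol{v}_1^{\top} can also be sanity-checked numerically with finite differences (a small sketch, not from the original post):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((5, 3))

# Analytical gradient from the derivation: u1 v1^T
U, S, Vt = np.linalg.svd(W, full_matrices=False)
grad_analytic = np.outer(U[:, 0], Vt[0])

# Central finite differences of the spectral norm, entry by entry
eps = 1e-6
grad_fd = np.zeros_like(W)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        E = np.zeros_like(W)
        E[i, j] = eps
        grad_fd[i, j] = (np.linalg.norm(W + E, 2) - np.linalg.norm(W - E, 2)) / (2 * eps)

print(np.max(np.abs(grad_fd - grad_analytic)))  # tiny, provided sigma_1 > sigma_2
```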

Weight Decay

Based on this result and the chain rule, we have: \begin{equation} \nabla_{\boldsymbol{W}}\left(\frac{1}{2}\Vert\boldsymbol{W}\Vert_2^2\right) = \Vert\boldsymbol{W}\Vert_2\nabla_{\boldsymbol{W}}\Vert\boldsymbol{W}\Vert_2 = \sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top} \label{eq:grad-2-2} \end{equation} Comparing this with the result under the F-norm: \begin{equation} \nabla_{\boldsymbol{W}}\left(\frac{1}{2}\Vert\boldsymbol{W}\Vert_F^2\right) = \boldsymbol{W} = \sum_{i=1}^{\min(n,m)}\sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top} \end{equation} The comparison is very clear: weight decay derived from the squared F-norm as a regularization term penalizes all singular values simultaneously, while weight decay corresponding to the squared spectral norm only penalizes the largest singular value. If our goal is to compress the magnitude of the output, then compressing the largest singular value is "just right." Compressing all singular values might achieve a similar goal but could also compress the representational capacity of the parameters.

According to the "Eckart-Young-Mirsky Theorem," the rightmost result in equation \eqref{eq:grad-2-2} also has another meaning: it is the "optimal rank-1 approximation" of the matrix \boldsymbol{W}. In other words, spectral norm weight decay replaces the operation of subtracting (a multiple of) the matrix itself at each step with subtracting its optimal rank-1 approximation. This weakens the penalty intensity and, to some extent, makes the penalty strike closer to the essence.
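To make the comparison concrete, the following sketch (NumPy, not part of the original post) builds both decay directions from an SVD and checks the rank-1-approximation reading of \sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top}:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((5, 3))
U, S, Vt = np.linalg.svd(W, full_matrices=False)

decay_fro = W                                  # grad of (1/2)||W||_F^2: hits every singular value
decay_spec = S[0] * np.outer(U[:, 0], Vt[0])   # grad of (1/2)||W||_2^2: hits only sigma_1

# decay_spec is also the best rank-1 approximation of W (Eckart-Young-Mirsky):
# its Frobenius error equals sqrt(sigma_2^2 + ... + sigma_m^2)
print(np.isclose(np.linalg.norm(W - decay_spec, 'fro'),
                 np.sqrt(np.sum(S[1:] ** 2))))  # True
```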

Numerical Calculation

For practical application, the most critical question is: how do we calculate \sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top}? SVD is certainly the most straightforward solution, but its computational complexity is undoubtedly the highest. We must find a more efficient calculation method.

Without loss of generality, let n \geq m. First, note that: \begin{equation} \sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top} = \sum_{i=1}^m\sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top} \boldsymbol{v}_1 \boldsymbol{v}_1^{\top} = \boldsymbol{W}\boldsymbol{v}_1 \boldsymbol{v}_1^{\top} \end{equation} Thus, calculating \sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top} only requires knowing \boldsymbol{v}_1. According to our discussion in "The Path to Low-Rank Approximation (II): SVD," \boldsymbol{v}_1 is actually the eigenvector corresponding to the largest eigenvalue of the matrix \boldsymbol{W}^{\top}\boldsymbol{W}. In this way, we have transformed the problem from the SVD of a general matrix \boldsymbol{W} into the eigenvalue decomposition of a real symmetric matrix \boldsymbol{W}^{\top}\boldsymbol{W}, which already reduces complexity because eigenvalue decomposition is usually significantly faster than SVD.
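A minimal sketch of this reduction (NumPy, assuming n \geq m; not part of the original post):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 8, 4
W = rng.standard_normal((n, m))

# v1 = eigenvector of W^T W with the largest eigenvalue
# (np.linalg.eigh returns eigenvalues in ascending order)
_, eigvecs = np.linalg.eigh(W.T @ W)
v1 = eigvecs[:, -1]

decay_spec = W @ np.outer(v1, v1)   # sigma_1 u1 v1^T = W v1 v1^T (the sign of v1 cancels)

# Reference via full SVD
U, S, Vt = np.linalg.svd(W, full_matrices=False)
print(np.max(np.abs(decay_spec - S[0] * np.outer(U[:, 0], Vt[0]))))  # tiny
```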

If this is still too slow, we can turn to the principle behind many eigenvalue decomposition algorithms, "Power Iteration":

When \sigma_1 > \sigma_2, the iteration \begin{equation} \boldsymbol{x}_{t+1} = \frac{\boldsymbol{W}^{\top}\boldsymbol{W}\boldsymbol{x}_t}{\Vert\boldsymbol{W}^{\top}\boldsymbol{W}\boldsymbol{x}_t\Vert} \end{equation} converges to \boldsymbol{v}_1 at a rate of (\sigma_2/\sigma_1)^{2t}.

Each step of power iteration only requires two "matrix-vector" multiplications, with a complexity of \mathcal{O}(nm). The total complexity for t iterations is \mathcal{O}(tnm), which is very favorable. The disadvantage is that convergence is slow when \sigma_1 and \sigma_2 are close. In practice, however, power iteration often performs better than the theory suggests. Many early works even achieved good results with just one iteration: when \sigma_1 and \sigma_2 are close, they and their corresponding eigenvectors are somewhat interchangeable, so even if power iteration has not fully converged, the result is a mixture of the two leading eigenvectors, which is often sufficient.
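Putting the pieces together, here is a hedged sketch of what a spectral-norm weight decay step might look like (NumPy; the function and variable names are illustrative assumptions, not from the original post). It further assumes the power-iteration vector v is stored per parameter and reused across training steps, so that a single iteration per step is usually enough, and it applies the decay in a decoupled, AdamW-style form:

```python
import numpy as np

def spectral_decay_direction(W, v, n_iters=1):
    """Approximate sigma_1 u1 v1^T by power iteration on W^T W.

    v is a persistent per-parameter vector, reused across training steps so
    that one iteration per step is usually enough (an assumption of this
    sketch, borrowed from spectral normalization practice).
    """
    for _ in range(n_iters):
        v = W.T @ (W @ v)                   # two matrix-vector products: O(nm)
        v = v / (np.linalg.norm(v) + 1e-12)
    return W @ np.outer(v, v), v            # sigma_1 u1 v1^T = W v1 v1^T

# Hypothetical usage inside a training loop (decoupled, AdamW-style decay):
rng = np.random.default_rng(0)
W = rng.standard_normal((8, 4))
v = rng.standard_normal(4)
v /= np.linalg.norm(v)
lr, wd = 1e-2, 1e-2

grad = rng.standard_normal(W.shape)         # stand-in for the loss gradient
decay, v = spectral_decay_direction(W, v)
W = W - lr * grad - lr * wd * decay         # spectral-norm weight decay step
```

Replacing decay with W in the last line recovers ordinary F-norm weight decay, which makes the two variants easy to compare in experiments.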

Iteration Proof

In this section, we complete the proof of power iteration. It is not hard to see that power iteration can be equivalently written as: \begin{equation} \lim_{t\to\infty} \frac{(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0}{\Vert(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0\Vert} = \boldsymbol{v}_1 \end{equation} To prove this limit, we start from \boldsymbol{W}=\sum_{i=1}^m\sigma_i \boldsymbol{u}_i\boldsymbol{v}_i^{\top}. Substituting this, we get: \begin{equation} \boldsymbol{W}^{\top}\boldsymbol{W} = \sum_{i=1}^m\sigma_i^2 \boldsymbol{v}_i\boldsymbol{v}_i^{\top},\qquad(\boldsymbol{W}^{\top}\boldsymbol{W})^t = \sum_{i=1}^m\sigma_i^{2t} \boldsymbol{v}_i\boldsymbol{v}_i^{\top} \end{equation} Since \boldsymbol{v}_1, \boldsymbol{v}_2, \dots, \boldsymbol{v}_m form an orthonormal basis for \mathbb{R}^m, \boldsymbol{x}_0 can be written as \sum_{j=1}^m c_j \boldsymbol{v}_j. Thus, we have: \begin{equation} (\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0 = \sum_{i=1}^m\sigma_i^{2t} \boldsymbol{v}_i\boldsymbol{v}_i^{\top}\sum_{j=1}^m c_j \boldsymbol{v}_j = \sum_{i=1}^m\sum_{j=1}^m c_j\sigma_i^{2t} \boldsymbol{v}_i\underbrace{\boldsymbol{v}_i^{\top} \boldsymbol{v}_j}_{=\delta_{i,j}} = \sum_{i=1}^m c_i\sigma_i^{2t} \boldsymbol{v}_i \end{equation} And: \begin{equation} \Vert(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0\Vert = \left\Vert \sum_{i=1}^m c_i\sigma_i^{2t} \boldsymbol{v}_i\right\Vert = \sqrt{\sum_{i=1}^m c_i^2\sigma_i^{4t}} \end{equation} Due to random initialization, the probability of c_1=0 is very small, so we can assume c_1 \neq 0. Then: \begin{equation} \frac{(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0}{\Vert(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0\Vert} = \frac{\sum_{i=1}^m c_i\sigma_i^{2t} \boldsymbol{v}_i}{\sqrt{\sum_{i=1}^m c_i^2\sigma_i^{4t}}} = \frac{\boldsymbol{v}_1 + \sum_{i=2}^m (c_i/c_1)(\sigma_i/\sigma_1)^{2t} \boldsymbol{v}_i}{\sqrt{1 + \sum_{i=2}^m (c_i/c_1)^2(\sigma_i/\sigma_1)^{4t}}} \end{equation} When \sigma_1 > \sigma_2, all \sigma_i/\sigma_1 (i \geq 2) are less than 1. Therefore, as t \to \infty, the corresponding terms tend to 0, and the final limit is \boldsymbol{v}_1. (Strictly speaking, the second equality above assumes c_1 > 0; in general the limit is \mathrm{sign}(c_1)\,\boldsymbol{v}_1, but the sign is immaterial here because the decay direction only involves \boldsymbol{v}_1\boldsymbol{v}_1^{\top}.)
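A quick numerical illustration of this convergence rate (a sketch, not part of the original post; the error shrinks roughly in proportion to (\sigma_2/\sigma_1)^{2t}):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((6, 4))
U, S, Vt = np.linalg.svd(W, full_matrices=False)
v1, rate = Vt[0], (S[1] / S[0]) ** 2        # per-step contraction factor (sigma_2/sigma_1)^2

x = rng.standard_normal(4)
x /= np.linalg.norm(x)
for t in range(1, 6):
    x = W.T @ (W @ x)
    x /= np.linalg.norm(x)
    err = np.linalg.norm(x - np.sign(x @ v1) * v1)  # compare up to sign
    print(t, err, rate ** t)                        # err decays roughly like rate**t
```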

Summary

This article derived the gradient of the spectral norm, leading to a new type of weight decay, and shared the author’s reflections on it.

When reprinting, please include the original address of this article: https://kexue.fm/archives/10648

For more detailed reprinting matters, please refer to: "Scientific Space FAQ"