English (unofficial) translations of posts at kexue.fm
Source

Beyond MuP: 4. Maintaining Parameter Stability

Translated by DeepSeek V4 Pro. Translations can be inaccurate; please refer to the original post for anything important.

Through the derivations and calculations in the previous articles, we can see that the three stability indicators proposed in the first article Beyond MuP: 1. Three Characteristics of a Good Model can generally be divided into “parameter stability” and “increment stability”. In Beyond MuP: 2. Linear Layers and Steepest Descent and Beyond MuP: 3. Special Cases, Special Treatment, we demonstrated the process of combining increment stability with steepest descent to derive new update rules (optimizers).

However, for parameter stability we have so far only dealt with initialization. The task of this article is to explore how to maintain parameter stability throughout training, thereby closing the loop from theory to practice.

Problem Background

Taking Beyond MuP: 2. Linear Layers and Steepest Descent as an example, the three stability indicators are:

\begin{aligned} &\text{Forward stability:}\quad\max_{\Vert\boldsymbol{x}\Vert_{RMS}=1} \Vert \boldsymbol{x}\boldsymbol{W}\Vert_{RMS} = \sqrt{\frac{d_{in}}{d_{out}}}\Vert\boldsymbol{W}\Vert_2 \\[5pt] &\text{Dependency stability:}\quad\max_{\Vert\boldsymbol{x}_1\Vert_{RMS}=\Vert\boldsymbol{x}_2\Vert_{RMS}=1} \frac{\Vert \boldsymbol{x}_1\boldsymbol{W} - \boldsymbol{x}_2\boldsymbol{W}\Vert_{RMS}}{\Vert \boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{RMS}} = \sqrt{\frac{d_{in}}{d_{out}}}\Vert\boldsymbol{W}\Vert_2 \\[5pt] &\text{Update stability:}\quad\max_{\Vert\boldsymbol{x}\Vert_{RMS}=1} \Vert \boldsymbol{x}(\boldsymbol{W} + \Delta\boldsymbol{W}) - \boldsymbol{x}\boldsymbol{W}\Vert_{RMS} = \sqrt{\frac{d_{in}}{d_{out}}}\Vert\Delta\boldsymbol{W}\Vert_2 \end{aligned}

where \boldsymbol{W}\in\mathbb{R}^{d_{in}\times d_{out}} are the parameters of the linear layer. We want all three indicators to be \Theta(1), which means we want the parameters and their increments to satisfy \Vert\boldsymbol{W}\Vert_2 = \Theta(\sqrt{d_{out}/d_{in}}) and \Vert\Delta\boldsymbol{W}\Vert_2 = \Theta(\sqrt{d_{out}/d_{in}}), respectively. In Beyond MuP: 3. Special Cases, Special Treatment, we computed the analogous quantities for Embedding, LM Head, and other layers and reached similar conclusions, only with different norms.
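As a quick numerical sanity check of the forward-stability formula (not part of the original article), the NumPy sketch below uses arbitrary sizes and a random matrix. It treats \boldsymbol{x} as a row vector, as in the text, and confirms that the maximum is attained at \boldsymbol{x}=\sqrt{d_{in}}\,\boldsymbol{u}_1^{\top}, where \boldsymbol{u}_1 is the top left singular vector of \boldsymbol{W}, while random unit-RMS inputs stay below it.

```python
import numpy as np


def rms(a):
    """RMS norm: Frobenius/L2 norm divided by sqrt(number of entries)."""
    return np.sqrt(np.mean(np.square(a)))


rng = np.random.default_rng(0)
d_in, d_out = 64, 256  # arbitrary sizes chosen for illustration
W = rng.standard_normal((d_in, d_out)) / np.sqrt(d_in)

# Closed form: sqrt(d_in / d_out) * ||W||_2 (ord=2 gives the largest singular value).
closed_form = np.sqrt(d_in / d_out) * np.linalg.norm(W, 2)

# Maximiser: x proportional to the top left singular vector, scaled to unit RMS.
U, S, Vt = np.linalg.svd(W, full_matrices=False)
x_star = np.sqrt(d_in) * U[:, 0]
attained = rms(x_star @ W)

# Random unit-RMS inputs should never exceed the closed form.
xs = rng.standard_normal((1000, d_in))
xs /= np.sqrt(np.mean(xs**2, axis=1, keepdims=True))
worst_random = max(rms(x @ W) for x in xs)
```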

Treating the increment condition as a stability constraint and applying steepest descent under the principle of “seeking speed while maintaining stability”, we derived theoretically optimal update rules. For linear layers, for example, this yields the Muon optimizer:

\mathop{\mathrm{argmin}}_{\Vert\Delta\boldsymbol{W}\Vert_2\leq\eta\sqrt{\frac{d_{out}}{d_{in}}}} \mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\Delta\boldsymbol{W}) \qquad \Rightarrow \qquad \Delta\boldsymbol{W} = -\eta\sqrt{\frac{d_{out}}{d_{in}}}\mathop{\mathrm{msign}}(\boldsymbol{G})

For parameter stability, by contrast, we previously only required the initialization to satisfy \Vert\boldsymbol{W}\Vert_2 = \Theta(\sqrt{d_{out}/d_{in}}). How to keep the model equally stable in this sense throughout training has not yet been addressed.
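The Muon step above can be sketched in NumPy as follows. This is only an illustration under stated assumptions: \mathop{\mathrm{msign}} is computed here by an exact SVD (practical Muon implementations typically use a Newton-Schulz iteration instead), and the sizes and step size are arbitrary. The check confirms that \Delta\boldsymbol{W} lies on the boundary of the spectral-norm ball and attains the minimal value -\eta\sqrt{d_{out}/d_{in}}\,\Vert\boldsymbol{G}\Vert_* of the linear objective, where \Vert\cdot\Vert_* denotes the nuclear norm.

```python
import numpy as np


def msign(G):
    """Matrix sign via SVD: U V^T, i.e. every nonzero singular value set to 1."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt


def muon_delta(G, eta):
    """Steepest-descent step under ||dW||_2 <= eta * sqrt(d_out / d_in)."""
    d_in, d_out = G.shape
    return -eta * np.sqrt(d_out / d_in) * msign(G)


rng = np.random.default_rng(1)
d_in, d_out, eta = 32, 128, 0.02
G = rng.standard_normal((d_in, d_out))
dW = muon_delta(G, eta)
radius = eta * np.sqrt(d_out / d_in)

spec_dW = np.linalg.norm(dW, 2)
objective = np.trace(G.T @ dW)                 # tr(G^T dW)
optimum = -radius * np.linalg.norm(G, "nuc")   # -radius * nuclear norm of G

# Any other direction on the same ball does no better.
D = rng.standard_normal((d_in, d_out))
D *= radius / np.linalg.norm(D, 2)
other = np.trace(G.T @ D)
```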

General Framework

How can we ensure that \boldsymbol{W} maintains \Vert\boldsymbol{W}\Vert_2 = \Theta(\sqrt{d_{out}/d_{in}})? More generally, given a parameter \boldsymbol{\omega}, which could be a vector, matrix, or even higher-order tensor, and given a norm \Vert\cdot\Vert, which is typically an indicator induced by forward stability or dependency stability, and given a scale \tau, the question is: How can we keep \boldsymbol{\omega} satisfying \Vert\boldsymbol{\omega}\Vert=\Theta(\tau) throughout training?

Initial Thoughts

A naive idea is to directly enforce \Vert\boldsymbol{\omega}\Vert=\tau (one can also replace \tau with a constant multiple, but this does not affect the following discussion). The simplest way to achieve this is to rescale the norm to \tau after each optimization step via normalization (e.g., Hyperball and Nemotron-Flash). Another approach is to use normalization to reparameterize the original model, i.e., replace \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega}) with \boldsymbol{f}(\boldsymbol{x};\tau\boldsymbol{\omega}/\Vert\boldsymbol{\omega}\Vert), which in theory can achieve a similar effect.

A further step is to build the steepest-descent idea into the update rule itself, as discussed in the articles Steepest Descent on Manifolds: 1. SGD + Hypersphere, Steepest Descent on Manifolds: 4. Muon + Spectral Sphere, and the paper Controlled LLM Training on Spectral Sphere. This approach is more principled, but more complex in practice: obtaining the exact update usually requires solving a nonlinear equation.

However, should we really strictly control a certain norm of a parameter to a specific value? Intuitively, the norm of a parameter should be determined by the training process itself; at most we can set a prior range for it. Although some works show that, when properly set, fixing the parameter norm to a preset value does not harm performance, this still disrupts the original training dynamics and may require more effort to understand and adapt to it.

Therefore, the viewpoint proposed in this article is that we only need to ensure \Vert\boldsymbol{\omega}\Vert = \mathcal{O}(\tau), specifically, to guarantee that each step satisfies \Vert\boldsymbol{\omega}\Vert \leq \tau. As for what specific value it takes, and whether it can achieve \Theta(\tau), we leave that to the training algorithm to decide, without further intervention.

Post Clip

The next question is naturally: how do we enforce \Vert\boldsymbol{\omega}\Vert \leq \tau? More concretely, suppose the original update rule for \boldsymbol{\omega} is

\boldsymbol{\omega}_t = \boldsymbol{\omega}_{t-1} - \eta \boldsymbol{\phi}_t\label{eq:base-update}

How should we modify it so that \boldsymbol{\omega}_t always satisfies \Vert\boldsymbol{\omega}_t\Vert\leq\tau? There are certainly many ways, and the normalization mentioned in the previous section is one viable option. Among them, we want the one that interferes least with the optimization process. That is, given the parameter \boldsymbol{\omega} and the norm \Vert\cdot\Vert, we want to modify \boldsymbol{\omega} as little as possible so that its norm does not exceed \tau. Formally,

\textcolor{skyblue}{\lfloor}\boldsymbol{\omega}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert\leq\tau}} = \mathop{\mathrm{argmin}}_{\Vert\tilde{\boldsymbol{\omega}}\Vert\leq\tau} \Vert \boldsymbol{\omega} - \tilde{\boldsymbol{\omega}}\Vert_{RMS}\label{eq:nclip}

Readers familiar with convex optimization will recognize this as the projection of \boldsymbol{\omega} onto the norm ball \{\tilde{\boldsymbol{\omega}}:\Vert\tilde{\boldsymbol{\omega}}\Vert\leq\tau\}, with distance measured in the RMS norm. The key point is that we want the norm to stay within \tau while disturbing the original parameter \boldsymbol{\omega} as little as possible; minimizing the discrepancy \Vert \boldsymbol{\omega} - \tilde{\boldsymbol{\omega}}\Vert_{RMS} is what induces a specific projection, or clipping, operation.

How to compute \textcolor{skyblue}{\lfloor}\boldsymbol{\omega}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert\leq\tau}} depends on the norm and must be worked out case by case; we do this below. With this operation in hand, one possible approach is to clip the parameter norm after each update step, i.e., to modify Equation [eq:base-update] to

\boldsymbol{\omega}_t = \textcolor{skyblue}{\lfloor}\boldsymbol{\omega}_{t-1} - \eta \boldsymbol{\phi}_t\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert\leq\tau}}

We tentatively call this scheme “Post Clip”. It is simple and intuitive, but may feel “non-smooth”. This is easy to see: suppose the initialization has norm less than \tau and the norm grows slowly once training starts. When it reaches \tau, clipping suddenly kicks in. The process is continuous but not smooth, much like the function \max(x,0).
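A minimal sketch of Post Clip, assuming some projection `clip(omega, tau)` for the chosen norm is available. For concreteness it borrows the max-absolute-value (infinity) norm, whose projection is plain element-wise clipping (as discussed later for the gamma parameter); the constant update direction is a contrived choice that makes the norm grow steadily. The recorded norms rise linearly, hit \tau, and then stay pinned there, which is the \max(x,0)-like kink described above.

```python
import numpy as np


def clip_inf(omega, tau):
    """Projection onto {||w||_inf <= tau}: element-wise clipping to [-tau, tau]."""
    return np.clip(omega, -tau, tau)


def post_clip_step(omega, phi, eta, tau, clip):
    """Post Clip: ordinary update omega - eta*phi, then project back into the ball."""
    return clip(omega - eta * phi, tau)


tau, eta = 1.0, 0.1
omega = np.zeros(8)
phi = -np.ones(8)          # contrived update that keeps pushing the norm upward
norms = []
for t in range(30):
    omega = post_clip_step(omega, phi, eta, tau, clip_inf)
    norms.append(np.max(np.abs(omega)))
```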

Pre Decay

If this non-smoothness is a concern, we can imitate weight decay and spread the penalty across every update step. Starting again from the update rule [eq:base-update], assume \boldsymbol{\phi}_t satisfies \Vert\boldsymbol{\phi}_t\Vert\leq\tau. The triangle inequality then gives \Vert\boldsymbol{\omega}_t\Vert = \Vert\boldsymbol{\omega}_{t-1} - \eta \boldsymbol{\phi}_t\Vert\leq \Vert\boldsymbol{\omega}_{t-1}\Vert + \eta \tau, meaning that in the worst case the norm grows by \eta\tau per step and, accumulated over a long run, can “run out of control”.

To prevent this, we can preprocess \boldsymbol{\omega}_{t-1} slightly before subtracting \eta \boldsymbol{\phi}_t, reducing its norm just enough to offset the growth caused by the update. Drawing on the experience of weight decay, we consider

\boldsymbol{\omega}_t = \textcolor{skyblue}{\lfloor}\boldsymbol{\omega}_{t-1}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert\leq (1-\eta)\Vert\boldsymbol{\omega}_{t-1}\Vert}} - \eta \boldsymbol{\phi}_t\label{eq:pre-decay}

That is, we first reduce the norm of \boldsymbol{\omega}_{t-1} to 1-\eta times its original value and then apply the update. Consequently, for 0\leq\eta\leq 1,

\Vert\boldsymbol{\omega}_t\Vert \leq (1-\eta)\Vert\boldsymbol{\omega}_{t-1}\Vert + \eta \tau \leq \max(\Vert\boldsymbol{\omega}_{t-1}\Vert,\tau)

where the last step holds because the middle expression is a convex combination of \Vert\boldsymbol{\omega}_{t-1}\Vert and \tau. Applying this recursively gives \Vert\boldsymbol{\omega}_t\Vert \leq \max(\Vert\boldsymbol{\omega}_{t-1}\Vert,\tau) \leq \cdots \leq \max(\Vert\boldsymbol{\omega}_0\Vert,\tau); i.e., as long as the initialization satisfies \Vert\boldsymbol{\omega}_0\Vert\leq\tau, every iterate automatically satisfies \Vert\boldsymbol{\omega}_t\Vert\leq \tau. This conclusion does not depend on the specific norm; it relies only on the triangle inequality. Since the clip operator defined in Equation [eq:nclip] reduces the norm with minimal modification, it is the natural choice for this norm reduction.
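The bound above can be checked numerically. The sketch below (an illustration, not the article's code) again uses the infinity norm with element-wise clipping, since the argument only needs the triangle inequality and therefore works for any norm. The random updates are an arbitrary assumption: they are constrained to \Vert\boldsymbol{\phi}_t\Vert_\infty\leq\tau but biased so that the norm keeps trying to grow.

```python
import numpy as np


def norm_inf(omega):
    return np.max(np.abs(omega))


def clip_inf(omega, tau):
    """Projection onto {||w||_inf <= tau}."""
    return np.clip(omega, -tau, tau)


def pre_decay_step(omega, phi, eta, norm, clip):
    """Pre Decay: shrink ||omega|| to (1 - eta)||omega|| via clip, then update."""
    shrunk = clip(omega, (1.0 - eta) * norm(omega))
    return shrunk - eta * phi


rng = np.random.default_rng(3)
tau, eta = 1.0, 0.05
omega = rng.uniform(-tau, tau, size=32)   # initialization with ||omega_0|| <= tau
history = [norm_inf(omega)]
for t in range(2000):
    # Any phi with ||phi||_inf <= tau is allowed; bias it so the norm wants to grow.
    phi = -tau * np.clip(rng.standard_normal(32) + 1.0, -1.0, 1.0)
    omega = pre_decay_step(omega, phi, eta, norm_inf, clip_inf)
    history.append(norm_inf(omega))
```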

We call this scheme “Pre Decay”. Unlike “Post Clip”, whose threshold is the static \tau so that clipping does not necessarily trigger, Pre Decay uses the dynamic threshold (1-\eta)\Vert\boldsymbol{\omega}_{t-1}\Vert, so clipping triggers at every step. The process is smoother, which is why we call it decay rather than clip. It can be viewed as a generalization of weight decay to arbitrary norms.

Basic Results

So far, we have established a general framework for constraining parameter norms, with two schemes: “Post Clip” and “Pre Decay”. The core operation is the clip operator defined in Equation [eq:nclip]: \textcolor{skyblue}{\lfloor}\boldsymbol{\omega}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert\leq\tau}}. Currently we only have a formal definition; in practice we need to compute it for specific norms. Below we provide some basic computational results.

Simple Example

In this section we first work through a simple example, with the norm \Vert\cdot\Vert_{RMS}. For vectors this is the L2 norm up to a constant factor; for matrices, it is the Frobenius norm up to a constant factor. It is not difficult to obtain

\textcolor{skyblue}{\lfloor}\boldsymbol{\omega}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert_{RMS}\leq\tau}} = \mathop{\mathrm{argmin}}_{\Vert\tilde{\boldsymbol{\omega}}\Vert_{RMS}\leq\tau} \Vert \boldsymbol{\omega} - \tilde{\boldsymbol{\omega}}\Vert_{RMS} = \min\left(1,\,\frac{\tau}{\Vert\boldsymbol{\omega}\Vert_{RMS}}\right)\boldsymbol{\omega}

The proof is left to the reader (if you really cannot figure it out, you can ask Kimi). In particular, substituting \tau = (1 - \eta) \Vert\boldsymbol{\omega}\Vert_{RMS} yields

\textcolor{skyblue}{\lfloor}\boldsymbol{\omega}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert_{RMS}\leq (1 - \eta) \Vert\boldsymbol{\omega}\Vert_{RMS}}} = \min\left(1,\,\frac{(1 - \eta) \Vert\boldsymbol{\omega}\Vert_{RMS}}{\Vert\boldsymbol{\omega}\Vert_{RMS}}\right)\boldsymbol{\omega} = (1-\eta)\boldsymbol{\omega}

Plugging this into Equation [eq:pre-decay] gives

\boldsymbol{\omega}_t = (1-\eta)\boldsymbol{\omega}_{t-1} - \eta \boldsymbol{\phi}_t

which is exactly conventional weight decay. In other words, Pre Decay under the RMS norm is precisely the weight decay we commonly use: it is the minimal-modification Pre Decay scheme for keeping the RMS norm (equivalently, the L2 norm of a vector or the Frobenius norm of a matrix) bounded.
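The RMS-norm clip and its reduction to weight decay can be sketched as follows (illustrative NumPy with arbitrary sizes, not the article's own code). The last part compares the clip against random feasible candidates; this is a quick empirical sanity check of the projection property, not a proof.

```python
import numpy as np


def rms(a):
    return np.sqrt(np.mean(np.square(a)))


def clip_rms(omega, tau):
    """Projection onto {||w||_RMS <= tau}: rescale by min(1, tau / ||w||_RMS)."""
    r = rms(omega)
    return omega if r == 0 else min(1.0, tau / r) * omega


rng = np.random.default_rng(4)
eta = 0.1
omega = rng.standard_normal((16, 16))
phi = rng.standard_normal((16, 16))

# Pre Decay with the RMS norm ...
pre_decay = clip_rms(omega, (1 - eta) * rms(omega)) - eta * phi
# ... coincides with ordinary weight decay.
weight_decay = (1 - eta) * omega - eta * phi

# Empirical projection check: no random point in the ball is closer to omega.
tau = 0.5
best = rms(omega - clip_rms(omega, tau))
cands = []
for _ in range(1000):
    c = rng.standard_normal(omega.shape)
    c *= rng.uniform(0, tau) / rms(c)     # random point with ||c||_RMS <= tau
    cands.append(rms(omega - c))
```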

Singular Value Clipping

Now we come to the “main event” of this article: matrix parameters and Muon. We return to the notation \boldsymbol{W} and rewrite Muon’s original update rule as

\boldsymbol{W}_t = \boldsymbol{W}_{t-1} - \eta\lambda\boldsymbol{\Phi}_t,\quad \boldsymbol{\Phi}_t=\frac{1}{\lambda}\sqrt{\frac{d_{out}}{d_{in}}}\mathop{\mathrm{msign}}(\boldsymbol{G}_t)

Let \tau = \frac{1}{\lambda}\sqrt{\frac{d_{out}}{d_{in}}}, so that \Vert\boldsymbol{\Phi}_t\Vert_2=\tau; the product \eta\lambda now plays the role that \eta played in the general framework. The two schemes for making \boldsymbol{W}_t satisfy \Vert\boldsymbol{W}_t\Vert_2\leq\tau are

\begin{aligned} \text{Post Clip:}\quad\boldsymbol{W}_t =&\, \textcolor{skyblue}{\lfloor}\boldsymbol{W}_{t-1} - \eta\lambda\boldsymbol{\Phi}_t\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert_2\leq\tau}} \\[5pt] \text{Pre Decay:}\quad\boldsymbol{W}_t =&\, \textcolor{skyblue}{\lfloor}\boldsymbol{W}_{t-1}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert_2\leq(1-\eta\lambda)\Vert\boldsymbol{W}_{t-1}\Vert_2}} - \eta\lambda\boldsymbol{\Phi}_t \end{aligned}

The next task is to compute \textcolor{skyblue}{\lfloor}\boldsymbol{W}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert_2\leq\tau}}. Since the RMS norm of a matrix is its Frobenius norm divided by a constant, it can equivalently be written as

\textcolor{skyblue}{\lfloor}\boldsymbol{W}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert_2\leq\tau}} = \mathop{\mathrm{argmin}}_{\Vert\tilde{\boldsymbol{W}}\Vert_2\leq\tau} \Vert\boldsymbol{W} - \tilde{\boldsymbol{W}}\Vert_F\label{eq:mclip-loss}

Some readers will already know the solution to this problem. It is the “Singular Value Clipping (SVC)” we mentioned in Higher-Order MuP: Simpler but More Elegant Spectral Condition Scaling, denoted \mathop{\mathrm{mclip}} in Computing Singular Value Clipping mclip via msign (Part 1) and Computing Singular Value Clipping mclip via msign (Part 2):

\textcolor{skyblue}{\lfloor}\boldsymbol{W}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert_2\leq\tau}} = \mathop{\mathrm{mclip}}(\boldsymbol{W};\tau) = \boldsymbol{U}\min(\boldsymbol{\Sigma},\tau)\boldsymbol{V}^{\top}\label{eq:2-to-mclip}

where \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} is the SVD of \boldsymbol{W} and \min(\boldsymbol{\Sigma},\tau) clips the singular values so that none exceeds \tau. We prove this in the next section. With this notation, the two schemes become

\begin{aligned} \text{Post Clip:}\quad\boldsymbol{W}_t =&\, \mathop{\mathrm{mclip}}(\boldsymbol{W}_{t-1} - \eta\lambda\boldsymbol{\Phi}_t;\tau) \\[5pt] \text{Pre Decay:}\quad\boldsymbol{W}_t =&\, \mathop{\mathrm{mclip}}(\boldsymbol{W}_{t-1};(1-\eta\lambda)\Vert\boldsymbol{W}_{t-1}\Vert_2) - \eta\lambda\boldsymbol{\Phi}_t \end{aligned}
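Below is a reference sketch of \mathop{\mathrm{mclip}} via a full SVD and of the two schemes for Muon, in NumPy. It is meant only to make the formulas concrete: the SVD-based \mathop{\mathrm{msign}}, the biased random gradients, the sizes, and the values of \eta and \lambda are all arbitrary assumptions, and computing a full SVD every step is exactly the cost that the following sections try to avoid.

```python
import numpy as np


def msign(G):
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt


def mclip(W, tau):
    """Singular value clipping: U min(Sigma, tau) V^T."""
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    return (U * np.minimum(S, tau)) @ Vt


def spec(W):
    return np.linalg.norm(W, 2)  # spectral norm (largest singular value)


rng = np.random.default_rng(5)
d_in, d_out, eta, lam = 32, 64, 0.1, 0.5
tau = np.sqrt(d_out / d_in) / lam
G_bias = rng.standard_normal((d_in, d_out))   # persistent gradient component

W_post = W_pre = 0.1 * rng.standard_normal((d_in, d_out))  # ||W_0||_2 < tau
post_norms, pre_norms = [], []
for t in range(300):
    G = G_bias + 0.5 * rng.standard_normal((d_in, d_out))
    Phi = np.sqrt(d_out / d_in) / lam * msign(G)      # ||Phi||_2 = tau
    W_post = mclip(W_post - eta * lam * Phi, tau)                          # Post Clip
    W_pre = mclip(W_pre, (1 - eta * lam) * spec(W_pre)) - eta * lam * Phi  # Pre Decay
    post_norms.append(spec(W_post))
    pre_norms.append(spec(W_pre))
```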

Derivation Process

In this section we prove [eq:2-to-mclip]. Let the SVD of \boldsymbol{W} be \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}, where \boldsymbol{U}\in\mathbb{R}^{d_{in}\times d_{in}}, \boldsymbol{\Sigma}\in\mathbb{R}^{d_{in}\times d_{out}}, \boldsymbol{V}\in\mathbb{R}^{d_{out}\times d_{out}}. Then

\Vert\boldsymbol{W} - \tilde{\boldsymbol{W}}\Vert_F = \Vert\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} - \tilde{\boldsymbol{W}}\Vert_F = \Vert\boldsymbol{U}(\boldsymbol{\Sigma} - \boldsymbol{U}^{\top}\tilde{\boldsymbol{W}}\boldsymbol{V})\boldsymbol{V}^{\top}\Vert_F = \Vert\boldsymbol{\Sigma} - \boldsymbol{U}^{\top}\tilde{\boldsymbol{W}}\boldsymbol{V}\Vert_F

The last equality holds because multiplication by orthogonal matrices does not change the Frobenius norm. Orthogonal matrices also leave the spectral norm unchanged. Hence, setting \tilde{\boldsymbol{\Sigma}}=\boldsymbol{U}^{\top}\tilde{\boldsymbol{W}}\boldsymbol{V}, the objective [eq:mclip-loss] simplifies equivalently to

\mathop{\mathrm{argmin}}_{\Vert\tilde{\boldsymbol{\Sigma}}\Vert_2\leq\tau} \Vert\boldsymbol{\Sigma} - \tilde{\boldsymbol{\Sigma}}\Vert_F

Note that \boldsymbol{\Sigma} is diagonal, with diagonal entries \sigma_1,\sigma_2,\cdots \geq 0, whereas \tilde{\boldsymbol{\Sigma}} is as yet undetermined, so for the proof we must treat it as a general matrix. In components,

\Vert\boldsymbol{\Sigma} - \tilde{\boldsymbol{\Sigma}}\Vert_F^2 = \sum_i \sigma_i^2 + \sum_{i,j} \tilde{\Sigma}_{i,j}^2 - 2\sum_i \sigma_i \tilde{\Sigma}_{i,i} \geq \sum_i \sigma_i^2 + \sum_i (\tilde{\Sigma}_{i,i}^2 - 2 \sigma_i \tilde{\Sigma}_{i,i})

Each term \tilde{\Sigma}_{i,i}^2 - 2 \sigma_i \tilde{\Sigma}_{i,i} is a quadratic function of \tilde{\Sigma}_{i,i} whose unconstrained minimum is at \tilde{\Sigma}_{i,i}=\sigma_i. However, we also have the constraint \Vert\tilde{\boldsymbol{\Sigma}}\Vert_2\leq\tau. Since the spectral norm is at least the absolute value of any matrix entry, this implies in particular |\tilde{\Sigma}_{i,i}|\leq\tau. Under this constraint, and since \sigma_i\geq 0, the minimum of \tilde{\Sigma}_{i,i}^2 - 2 \sigma_i \tilde{\Sigma}_{i,i} is attained at \tilde{\Sigma}_{i,i}^* = \min(\sigma_i,\tau).

The lower bound is attained when the inequality above is an equality, i.e., \tilde{\Sigma}_{i,j}^*=0 for i\neq j, and each diagonal entry takes its optimal value. Thus \tilde{\boldsymbol{\Sigma}}^* is diagonal and can be written simply as \tilde{\boldsymbol{\Sigma}}^*=\min(\boldsymbol{\Sigma},\tau). Its spectral norm is \max_i\min(\sigma_i,\tau)\leq\tau, so it is feasible and therefore optimal; the corresponding solution is \tilde{\boldsymbol{W}}^*=\boldsymbol{U}\min(\boldsymbol{\Sigma},\tau)\boldsymbol{V}^{\top}. This completes the proof of [eq:2-to-mclip].
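The optimality claim can also be probed empirically. The sketch below is illustrative only (the SVD-based \mathop{\mathrm{mclip}} and the random perturbations are assumptions of the check, not part of the proof): it perturbs \tilde{\boldsymbol{W}}^* in random directions, rescales any candidate that leaves the spectral-norm ball back onto it, and confirms that none of these feasible candidates is closer to \boldsymbol{W} in Frobenius norm.

```python
import numpy as np


def mclip(W, tau):
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    return (U * np.minimum(S, tau)) @ Vt


rng = np.random.default_rng(6)
W = rng.standard_normal((10, 14))
tau = np.median(np.linalg.svd(W, compute_uv=False))  # clip roughly half the spectrum
W_star = mclip(W, tau)
best = np.linalg.norm(W - W_star)  # Frobenius norm (default for matrices)

gaps = []
for _ in range(2000):
    cand = W_star + 0.05 * rng.standard_normal(W.shape)
    s = np.linalg.norm(cand, 2)
    if s > tau:
        cand *= tau / s  # shrink back into the feasible set {||.||_2 <= tau}
    gaps.append(np.linalg.norm(W - cand) - best)
```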

Clipping the Principal Singular Value

How, then, can \mathop{\mathrm{mclip}} be computed efficiently? Performing an SVD at every training step is clearly too expensive. We explored this problem systematically in Computing Singular Value Clipping mclip via msign (Part 1) and Computing Singular Value Clipping mclip via msign (Part 2). The approach there was to implement it with the aid of \mathop{\mathrm{msign}}, but it required 2 to 3 \mathop{\mathrm{msign}} operations, which is not cheap. For example, one identity found in Part 2 is

\mathop{\mathrm{mclip}}(\boldsymbol{W};\tau) =\frac{1}{2}\Bigl\{\boldsymbol{W}+\tau\mathop{\mathrm{msign}}(\boldsymbol{W})-(\tau\boldsymbol{I}-\boldsymbol{W}\mathop{\mathrm{msign}}(\boldsymbol{W})^{\top})\mathop{\mathrm{msign}}(\tau\mathop{\mathrm{msign}}(\boldsymbol{W})-\boldsymbol{W})\Bigr\}

which requires two \mathop{\mathrm{msign}} operations. Since parameters are typically kept in FP32, running two \mathop{\mathrm{msign}} operations is still relatively expensive, so this is not yet particularly practical.
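The identity can be checked numerically against the SVD definition. In the sketch below (illustrative; \mathop{\mathrm{msign}} is again computed by an exact SVD rather than by the iterative schemes used in practice, and \tau is set strictly between two singular values so that \tau\mathop{\mathrm{msign}}(\boldsymbol{W})-\boldsymbol{W} has no zero singular values), both sides agree to floating-point precision for a random full-rank wide matrix.

```python
import numpy as np


def msign(M):
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt


def mclip_svd(W, tau):
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    return (U * np.minimum(S, tau)) @ Vt


def mclip_via_msign(W, tau):
    """1/2 {W + tau*msign(W) - (tau*I - W msign(W)^T) msign(tau*msign(W) - W)}."""
    M = msign(W)
    I = np.eye(W.shape[0])
    return 0.5 * (W + tau * M - (tau * I - W @ M.T) @ msign(tau * M - W))


rng = np.random.default_rng(7)
W = rng.standard_normal((16, 24))
S = np.linalg.svd(W, compute_uv=False)
tau = 0.5 * (S[5] + S[6])   # strictly between two singular values
lhs = mclip_via_msign(W, tau)
rhs = mclip_svd(W, tau)
```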

Here we mainly consider the one-at-a-time clipping approach discussed in Streaming Power Iteration-based Muon Implementation: 5. Extensions. Specifically, \mathop{\mathrm{mclip}} turns every singular value greater than \tau into \tau; in particular, it must clip the principal (largest) singular value to \tau whenever that value exceeds \tau. After this, if some singular values are still greater than \tau, the largest of them becomes the new principal singular value. Hence, repeatedly “clipping the principal singular value to \tau” realizes \mathop{\mathrm{mclip}}.
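The repeated-clipping idea can be sketched as below, under explicit assumptions: `svd1` is a plain power iteration written for this example (not the streaming implementation from the referenced series), and the test matrix is built with well-separated singular values so that a fixed number of iterations converges to machine precision. The loop clips the current principal singular value to \tau until none exceeds \tau, and the result matches the SVD-based \mathop{\mathrm{mclip}}.

```python
import numpy as np


def svd1(W, iters=300, seed=0):
    """Power iteration for the top singular triple (sigma_1, u_1, v_1) of W."""
    v = np.random.default_rng(seed).standard_normal(W.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(iters):
        u = W @ v
        u /= np.linalg.norm(u)
        v = W.T @ u
        sigma = np.linalg.norm(v)
        v /= sigma
    return sigma, u, v


def mclip_by_repeated_svd1(W, tau, max_clips=100):
    W = W.copy()
    clips = 0
    for _ in range(max_clips):
        sigma, u, v = svd1(W)
        if sigma <= tau * (1 + 1e-12):
            break
        W -= (sigma - tau) * np.outer(u, v)   # clip sigma_1 down to tau
        clips += 1
    return W, clips


rng = np.random.default_rng(8)
U, _ = np.linalg.qr(rng.standard_normal((12, 8)))
V, _ = np.linalg.qr(rng.standard_normal((10, 8)))
S = np.array([8.0, 6.0, 4.5, 3.0, 1.5, 1.0, 0.5, 0.2])
W = (U * S) @ V.T
tau = 2.0
W_clipped, n_clips = mclip_by_repeated_svd1(W, tau)
W_ref = (U * np.minimum(S, tau)) @ V.T
```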

Since the principal singular value and its singular vectors can be obtained efficiently by power iteration (denoted \mathop{\text{SVD1}}), clipping the principal singular value is cheap. Moreover, if training is smooth enough, performing a single principal-singular-value clip per step achieves approximately the same effect. With this strategy, the two schemes for constraining singular values become

\begin{aligned} \text{Post Clip:}\quad\boldsymbol{W}_t =&\, \tilde{\boldsymbol{W}}_t - \max(\sigma_1 - \tau, 0) \boldsymbol{u}_1 \boldsymbol{v}_1^{\top},\quad\sigma_1, \boldsymbol{u}_1, \boldsymbol{v}_1 = \mathop{\text{SVD1}}(\tilde{\boldsymbol{W}}_t),\quad\tilde{\boldsymbol{W}}_t = \boldsymbol{W}_{t-1} - \eta\lambda \boldsymbol{\Phi}_t \\[5pt] \text{Pre Decay:}\quad\boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta\lambda\sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top} - \eta\lambda \boldsymbol{\Phi}_t,\quad\sigma_1, \boldsymbol{u}_1, \boldsymbol{v}_1 = \mathop{\text{SVD1}}(\boldsymbol{W}_{t-1}) \end{aligned}

Here the Pre Decay term comes from clipping \sigma_1 to (1-\eta\lambda)\sigma_1, which subtracts \eta\lambda\sigma_1\boldsymbol{u}_1\boldsymbol{v}_1^{\top}. The “Pre Decay” version is exactly the spectral weight decay introduced in From Spectral Norm Gradients to a New Type of Weight Decay; more than a year later, we have reached the same result by a different route. As for the “Post Clip” version, @_arohan_ once mentioned it on X under the name “Wion”. In practice, because only one singular value is clipped per step, some “ambitious” matrices may end up with spectral norms noticeably above the set threshold. This is normal, and their spectral norms gradually decrease during the learning-rate decay phase.
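The two single-clip rules can be sketched as follows (illustrative NumPy; `svd1` is a simple power iteration assumed for this example, with no warm start or streaming, and the matrices are constructed with a clear gap between the top two singular values). When only the largest singular value exceeds the threshold, one clip per step coincides with exact \mathop{\mathrm{mclip}}, which the checks confirm for both rules.

```python
import numpy as np


def svd1(W, iters=200, seed=0):
    """Power iteration for the top singular triple of W."""
    v = np.random.default_rng(seed).standard_normal(W.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(iters):
        u = W @ v
        u /= np.linalg.norm(u)
        v = W.T @ u
        sigma = np.linalg.norm(v)
        v /= sigma
    return sigma, u, v


def post_clip_single(W_prev, Phi, eta, lam, tau):
    """'Post Clip' with one principal-singular-value clip per step."""
    W_tilde = W_prev - eta * lam * Phi
    sigma, u, v = svd1(W_tilde)
    return W_tilde - max(sigma - tau, 0.0) * np.outer(u, v)


def pre_decay_single(W_prev, Phi, eta, lam):
    """'Pre Decay' (spectral weight decay) with one principal-singular-value clip."""
    sigma, u, v = svd1(W_prev)
    return W_prev - eta * lam * sigma * np.outer(u, v) - eta * lam * Phi


def mclip(W, tau):
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    return (U * np.minimum(S, tau)) @ Vt


rng = np.random.default_rng(9)
U, _ = np.linalg.qr(rng.standard_normal((16, 6)))
V, _ = np.linalg.qr(rng.standard_normal((12, 6)))
W = (U * np.array([5.0, 3.0, 2.0, 1.0, 0.5, 0.1])) @ V.T
Phi = rng.standard_normal((16, 12))
eta, lam = 0.1, 0.5   # eta * lam = 0.05, so only sigma_1 exceeds 0.95 * 5.0

pre = pre_decay_single(W, Phi, eta, lam)
pre_ref = mclip(W, (1 - eta * lam) * 5.0) - eta * lam * Phi

post = post_clip_single(W, Phi, eta, lam, tau=4.0)   # only sigma_1 exceeds 4.0
post_ref = mclip(W - eta * lam * Phi, 4.0)
```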

Other Details

If one wants more precise clipping, we can also use power iteration to simultaneously compute the Top-k singular values and singular vectors, clipping at most k singular values per step. The cost is replacing the L2 normalization in power iteration with QR decomposition. There are also some acceleration techniques for QR decomposition; the relevant principles can be found in the streaming power iteration series of articles, such as Streaming Power Iteration-based Muon Implementation: 1. First Acquaintance.
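A sketch of the Top-k variant follows: the vector normalization in power iteration is replaced by a QR step (block power, or subspace, iteration), and a small k\times k SVD then extracts the singular triples. This is an illustrative NumPy version under assumptions (fixed iteration count, random start, a test matrix with a clear gap after the k-th singular value), not the streaming implementation from the referenced series. With k large enough to cover every singular value above \tau, a single step reproduces exact \mathop{\mathrm{mclip}}.

```python
import numpy as np


def svdk(W, k, iters=300, seed=0):
    """Top-k singular triples via subspace iteration with QR re-orthonormalization."""
    V = np.linalg.qr(np.random.default_rng(seed).standard_normal((W.shape[1], k)))[0]
    for _ in range(iters):
        U = np.linalg.qr(W @ V)[0]
        V = np.linalg.qr(W.T @ U)[0]
    # Rayleigh-Ritz: SVD of the small k x k projection U^T W V.
    P, S, Qt = np.linalg.svd(U.T @ W @ V)
    return S, U @ P, V @ Qt.T


def clip_top_k(W, tau, k):
    """Clip at most k singular values (the k largest) down to tau."""
    S, U, V = svdk(W, k)
    return W - (U * np.maximum(S - tau, 0.0)) @ V.T


rng = np.random.default_rng(10)
U0, _ = np.linalg.qr(rng.standard_normal((20, 8)))
V0, _ = np.linalg.qr(rng.standard_normal((15, 8)))
S0 = np.array([9.0, 7.0, 5.0, 3.0, 2.0, 1.0, 0.5, 0.2])
W = (U0 * S0) @ V0.T
tau = 2.5
W_k = clip_top_k(W, tau, k=4)
W_ref = (U0 * np.minimum(S0, tau)) @ V0.T
```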

Apart from the spectral norm of linear layer matrices, in Beyond MuP: 3. Special Cases, Special Treatment we encountered other layers with different norms. For example, Embedding and LM Head correspond to the maximum row RMS and maximum column RMS, respectively, while the gamma parameter of the RMS Norm layer corresponds to the maximum absolute value, also known as the infinity norm of a vector.

Fortunately, the clip operator \textcolor{skyblue}{\lfloor}\boldsymbol{\omega}\textcolor{skyblue}{\rfloor}_{\textcolor{skyblue}{\Vert\cdot\Vert\leq\tau}} under these norms is relatively easy to compute. For instance, for the Embedding layer, whose norm is the maximum over row RMS, the clipping operator clips the RMS of each row vector to not exceed \tau. The LM Head is analogous, just with rows replaced by columns. As for the gamma parameter, it is even simpler: it is directly element-wise clipping \mathop{\text{clip}}(\boldsymbol{\gamma};-\tau,\tau) = \max(\min(\boldsymbol{\gamma},\tau),-\tau).

These conclusions are all intuitive and their proofs are relatively straightforward, so we will not elaborate; they serve as exercises for the reader.
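For completeness, here is a sketch of these three clip operators in NumPy (an illustration, not code from the article). It assumes the Embedding matrix is stored as (vocab, d), so its rows are token vectors, and the LM Head as (d, vocab), so its columns are; both layouts are storage assumptions, and only the row/column convention from the text matters.

```python
import numpy as np


def clip_rows_rms(E, tau):
    """Max-row-RMS ball: rescale each row whose RMS exceeds tau back to RMS tau."""
    r = np.sqrt(np.mean(E**2, axis=1, keepdims=True))
    return E * np.minimum(1.0, tau / np.maximum(r, 1e-30))


def clip_cols_rms(H, tau):
    """Max-column-RMS ball: the same operation applied column by column."""
    return clip_rows_rms(H.T, tau).T


def clip_gamma(gamma, tau):
    """Infinity-norm ball: element-wise clip to [-tau, tau]."""
    return np.clip(gamma, -tau, tau)


rng = np.random.default_rng(11)
E = rng.standard_normal((100, 32)) * rng.uniform(0.2, 3.0, size=(100, 1))
E_c = clip_rows_rms(E, 1.0)
row_rms = np.sqrt(np.mean(E**2, axis=1))
row_rms_c = np.sqrt(np.mean(E_c**2, axis=1))

H = rng.standard_normal((32, 100)) * 2.0
col_rms_c = np.sqrt(np.mean(clip_cols_rms(H, 1.0) ** 2, axis=0))

gamma_c = clip_gamma(np.array([-2.0, -0.5, 0.3, 1.7]), 1.0)
```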

Necessary Guarantee

Some readers may wonder: is all this complexity really necessary? Why not just use ordinary weight decay, as in Training Deep Learning Models with Norm-Constrained LMOs? For example,

\boldsymbol{W}_t = (1-\eta\lambda)\boldsymbol{W}_{t-1} - \eta\sqrt{\frac{d_{out}}{d_{in}}}\mathop{\mathrm{msign}}(\boldsymbol{G}_t)\label{eq:muon-wd}

By the same triangle-inequality argument as for Pre Decay (the update term has spectral norm \eta\lambda\tau), this also keeps the spectral norm within \tau = \frac{1}{\lambda}\sqrt{\frac{d_{out}}{d_{in}}}, provided \Vert\boldsymbol{W}_0\Vert_2\leq\tau. Why not use such a simple form?

The answer is: to avoid excessive intervention. By definition [eq:nclip], the clip operator is the operation that achieves the required norm bound while modifying the original parameters as little as possible. For the spectral norm, directly multiplying \boldsymbol{W}_{t-1} by 1-\eta\lambda also brings its spectral norm down to (1-\eta\lambda)\Vert\boldsymbol{W}_{t-1}\Vert_2, but since it differs from the minimal-modification \mathop{\mathrm{mclip}}, it necessarily over-intervenes: it shrinks every singular value, whereas \mathop{\mathrm{mclip}} only touches those above the threshold.
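The over-intervention is easy to quantify. The illustrative sketch below (random matrix and \eta\lambda chosen arbitrarily, SVD used for clarity) brings the spectral norm of \boldsymbol{W} down to the same value (1-\eta\lambda)\Vert\boldsymbol{W}\Vert_2 in two ways and compares how far each moves \boldsymbol{W} in Frobenius norm: uniform scaling moves it by \eta\lambda\Vert\boldsymbol{W}\Vert_F, whereas \mathop{\mathrm{mclip}} only removes the excess of the singular values above the threshold.

```python
import numpy as np


def mclip(W, tau):
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    return (U * np.minimum(S, tau)) @ Vt


rng = np.random.default_rng(12)
W = rng.standard_normal((64, 64)) / 8.0
eta_lam = 0.01
target = (1 - eta_lam) * np.linalg.norm(W, 2)

W_scaled = (1 - eta_lam) * W      # ordinary weight decay
W_clipped = mclip(W, target)      # minimal modification

move_scaled = np.linalg.norm(W - W_scaled)    # Frobenius distance
move_clipped = np.linalg.norm(W - W_clipped)
```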

Excessive intervention has one of two consequences: either one chooses a smaller \lambda to preserve performance, which makes \tau too large, so the spectral norm is no longer guaranteed to stay in the desired range; or one chooses a larger \lambda to keep the spectral norm under control, which significantly degrades performance. For example, if d_{in}=d_{out} and we want the spectral norm not to exceed 5, then \tau = 1/\lambda gives \lambda=0.2. For the Muon of Equation [eq:muon-wd], a weight decay coefficient of 0.2 is extremely large (typical values are around 0.01).

Note that we have repeatedly emphasized “guarantee”, which is crucial. Suppose we use weight decay with a coefficient of 0.01; theoretically the spectral norm can reach up to 100, but experiments on small models may find it stays below 5. This is quite common. However, the safety of small models does not imply the safety of large models. As we have said before, large models are powerful enough to amplify any subtle bug. If the theoretical upper bound is 100, small models may not have the chance to hit it, but large models really might.
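To see that the theoretical bound is not merely a loose estimate, consider a contrived worst case (an assumption for illustration, not an experiment from the article): the gradient direction never changes, so \mathop{\mathrm{msign}}(\boldsymbol{G}_t) is the same matrix at every step. With d_{in}=d_{out}, \lambda=0.01 and a learning rate of 0.1, the weight-decayed Muon update [eq:muon-wd] then drives the spectral norm toward the bound 1/\lambda=100.

```python
import numpy as np


def msign(G):
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt


rng = np.random.default_rng(13)
d = 8                      # d_in = d_out, so sqrt(d_out / d_in) = 1
eta, lam = 0.1, 0.01       # weight decay coefficient 0.01, bound 1 / lam = 100
M = msign(rng.standard_normal((d, d)))   # fixed update direction, ||M||_2 = 1

W = np.zeros((d, d))
for t in range(10_000):
    W = (1 - eta * lam) * W - eta * M    # Muon with ordinary weight decay

final_norm = np.linalg.norm(W, 2)
```

In closed form the iterate is -\frac{1}{\lambda}\bigl(1-(1-\eta\lambda)^t\bigr)\boldsymbol{M}, so its spectral norm approaches 100 from below; a small model whose gradients keep changing direction may simply never get there.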

Therefore, it is essential to ensure that the key norm of parameters remains theoretically within a reasonable bound, which is an embodiment of the “stability” in the principle of “seeking speed while maintaining stability”. The clip operator defined by Equation [eq:nclip] is the most “lightweight” operation that guarantees boundedness; in other words, it is probably the operation with the least impact on performance while ensuring the same bound.

Summary

This article, based on the idea of minimal modification, proposes a general framework for maintaining parameter stability during training, encompassing two schemes: Post Clip and Pre Decay. Under the spectral norm, they further evolve into singular value clipping and spectral weight decay. These operations aim to ensure that key parameter norms remain bounded while minimizing interference with the training dynamics.

To reprint, please include the address of this article: https://kexue.fm/archives/11729
