In the previous article, Beyond MuP: 1. Three Characteristics of Good Models, we proposed three core metrics (forward stability, dependency stability, and update stability) and gave their mathematical definitions. We also proposed judging the quality of a model by whether these metrics are \Theta(1), which will serve as the theoretical foundation for our subsequent analysis and computation. Next, we combine them with the idea of steepest descent to customize update rules that are both stable and fast for each parameter.
\begin{aligned} &\text{Forward Stability:}\quad\max_{\boldsymbol{x}} \Vert \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega})\Vert_{RMS} = \Theta(1) \label{eq:c1} \\[5pt] &\text{Dependency Stability:}\quad\max_{\boldsymbol{x}_1,\boldsymbol{x}_2} \frac{\Vert \boldsymbol{f}(\boldsymbol{x}_1;\boldsymbol{\omega}) - \boldsymbol{f}(\boldsymbol{x}_2;\boldsymbol{\omega})\Vert_{RMS}}{\Vert\boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{RMS}} = \Theta(1) \label{eq:c2} \\[5pt] &\text{Update Stability:}\quad\max_{\boldsymbol{x}} \Vert \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega} + \Delta\boldsymbol{\omega}) - \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega})\Vert_{RMS} = \Theta(1) \label{eq:c3} \end{aligned}
We take the linear layer as the first example. The result should be familiar to some readers: it is the Muon optimizer that has gradually gained popularity over the past year. Of course, our goal is not to rediscover Muon, but to demonstrate the process of designing models and optimizers from first principles, providing a unified methodology for handling other parameters in the future.
Linear Transformations
For a linear layer, the input is a vector \boldsymbol{x}\in\mathbb{R}^{d_{in}}, the parameter is a matrix \boldsymbol{W}\in\mathbb{R}^{d_{in}\times d_{out}}, and the model is \boldsymbol{f}(\boldsymbol{x};\boldsymbol{W})=\boldsymbol{x}\boldsymbol{W}. Note that in the definitions of the three metrics, we did not restrict \boldsymbol{x} to be bounded, so for a naive linear layer, the three metrics may not exist. For example, \max\limits_{\boldsymbol{x}}\Vert\boldsymbol{x}\boldsymbol{W}\Vert_{RMS} is generally infinite. To address this, we simply add some operations to the model that make the results bounded, such as: \begin{aligned} &\text{In Norm:}\quad \mathop{\mathrm{Norm}}(\boldsymbol{x})\boldsymbol{W} \\[5pt] &\text{Out Norm:}\quad \mathop{\mathrm{Norm}}(\boldsymbol{x}\boldsymbol{W}) \end{aligned} where \mathop{\mathrm{Norm}}(\boldsymbol{x}) = \boldsymbol{x} / \Vert\boldsymbol{x}\Vert_{RMS}, and we omit the gamma parameter of RMS Norm, assuming its effect is secondary. We know that there are two common ways to use residuals: Pre Norm and Post Norm. Pre Norm obviously corresponds to In Norm. However, here we point out that Post Norm is actually also In Norm: \begin{aligned} &\text{Pre Norm:}\quad \boldsymbol{x}_{t+1} = \boldsymbol{x}_t + \boldsymbol{F}_t(\mathop{\mathrm{Norm}}(\boldsymbol{x}_t)) \\[5pt] &\text{Post Norm:} \quad \boldsymbol{x}_{t+1} = \mathop{\mathrm{Norm}}(\underbrace{\boldsymbol{x}_t + \boldsymbol{F}_t(\boldsymbol{x}_t)}_{\text{denoted }\boldsymbol{y}_{t+1}}) \quad \Rightarrow\quad \boldsymbol{y}_{t+1} = \mathop{\mathrm{Norm}}(\boldsymbol{y}_t) + \boldsymbol{F}_t(\mathop{\mathrm{Norm}}(\boldsymbol{y}_t)) \end{aligned} So compared to Pre Norm, Post Norm merely replaces \boldsymbol{x}_t + \boldsymbol{F}_t(\mathop{\mathrm{Norm}}(\boldsymbol{x}_t)) with \mathop{\mathrm{Norm}}(\boldsymbol{x}_t) + \boldsymbol{F}_t(\mathop{\mathrm{Norm}}(\boldsymbol{x}_t)). For \boldsymbol{F}_t, both are In Norm; this article takes In Norm as the example.
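The claim that Post Norm is also In Norm can be checked numerically. The following is a minimal sketch, assuming hypothetical residual branches `F(t, x)` built from random matrices and `tanh`; these are illustrative choices and not part of the original derivation.

```python
import numpy as np

def rms(x):
    """Root-mean-square norm: ||x||_2 / sqrt(d)."""
    return np.sqrt(np.mean(x ** 2))

def norm(x):
    """RMS Norm without the gamma parameter."""
    return x / rms(x)

rng = np.random.default_rng(0)
d, depth = 16, 4
# Hypothetical residual branches F_t: a linear map followed by tanh.
Ws = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(depth)]

def F(t, x):
    return np.tanh(x @ Ws[t])

# Post Norm in its usual form: x_{t+1} = Norm(x_t + F_t(x_t)).
x = norm(rng.standard_normal(d))
xs = [x]
for t in range(depth):
    x = norm(x + F(t, x))
    xs.append(x)

# The same network written with the pre-normalization state y_t:
# y_{t+1} = Norm(y_t) + F_t(Norm(y_t)), where x_t = Norm(y_t).
y = xs[0]  # x_0 already has unit RMS, so Norm(y_0) = x_0
ys = []
for t in range(depth):
    y = norm(y) + F(t, norm(y))
    ys.append(y)

# ys[t] holds y_{t+1}; normalizing it should recover x_{t+1}.
post_norm_matches = all(np.allclose(norm(ys[t]), xs[t + 1]) for t in range(depth))
print(post_norm_matches)
```

In both formulations every branch \boldsymbol{F}_t only ever sees a unit-RMS input, which is the sense in which both are In Norm.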
Compared with Out Norm, In Norm has an additional advantage: more room for speedup, because (\boldsymbol{x} / \Vert\boldsymbol{x}\Vert_{RMS})\boldsymbol{W}=\boldsymbol{x}\boldsymbol{W} / \Vert\boldsymbol{x}\Vert_{RMS}. In theory, \boldsymbol{x}\boldsymbol{W} and \Vert\boldsymbol{x}\Vert_{RMS} can be computed in parallel and then divided at the end, reducing latency. This idea is reflected in works such as FlashNorm: fast normalization for LLMs, Block-level AI Operator Fusion, and Superoptimizing RMSNorm and Linear.
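The identity behind this speedup is plain algebra, as the short sketch below illustrates. It only checks the two evaluation orders against each other in NumPy; it is not a fused kernel, and the shapes are arbitrary assumptions.

```python
import numpy as np

def rms(x):
    """Row-wise RMS norm, kept as a column so it broadcasts."""
    return np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True))

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 64))    # batch of 8 inputs, d_in = 64
W = rng.standard_normal((64, 32))   # d_in x d_out

# Order 1: normalize first, then multiply.
normalize_then_matmul = (x / rms(x)) @ W
# Order 2: x @ W and rms(x) do not depend on each other, so they
# could be computed concurrently and combined by one division.
matmul_then_divide = (x @ W) / rms(x)

fused_matches = np.allclose(normalize_then_matmul, matmul_then_divide)
print(fused_matches)
```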
Initial Variance
According to the discussion in the previous section, we agree to only consider linear layers with In Norm. Then by the definition of spectral norm, we can compute the three metrics: \begin{aligned} &\text{Forward Stability:}\quad\max_{\Vert\boldsymbol{x}\Vert_{RMS}=1} \Vert \boldsymbol{x}\boldsymbol{W}\Vert_{RMS} = \sqrt{\frac{d_{in}}{d_{out}}}\Vert\boldsymbol{W}\Vert_2 \\[5pt] &\text{Dependency Stability:}\quad\max_{\Vert\boldsymbol{x}_1\Vert_{RMS}=\Vert\boldsymbol{x}_2\Vert_{RMS}=1} \frac{\Vert \boldsymbol{x}_1\boldsymbol{W} - \boldsymbol{x}_2\boldsymbol{W}\Vert_{RMS}}{\Vert \boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{RMS}} = \sqrt{\frac{d_{in}}{d_{out}}}\Vert\boldsymbol{W}\Vert_2 \\[5pt] &\text{Update Stability:}\quad\max_{\Vert\boldsymbol{x}\Vert_{RMS}=1} \Vert \boldsymbol{x}(\boldsymbol{W} + \Delta\boldsymbol{W}) - \boldsymbol{x}\boldsymbol{W}\Vert_{RMS} = \sqrt{\frac{d_{in}}{d_{out}}}\Vert\Delta\boldsymbol{W}\Vert_2 \end{aligned} where \Vert\cdot\Vert_2 on a matrix denotes its spectral norm. As can be seen, all three metrics are some variants of the spectral norm, or more precisely, the three metrics we proposed are essentially generalizations starting from the spectral norm.
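As a sanity check on the forward stability formula, the sketch below compares \sqrt{d_{in}/d_{out}}\Vert\boldsymbol{W}\Vert_2 with the value attained by the top left singular vector rescaled to unit RMS, and with the largest value found among random unit-RMS inputs. The dimensions and sample count are arbitrary assumptions.

```python
import numpy as np

def rms(x):
    return np.sqrt(np.mean(x ** 2))

rng = np.random.default_rng(0)
d_in, d_out = 64, 32
W = rng.standard_normal((d_in, d_out))

# Closed-form value of the metric.
predicted_max = np.sqrt(d_in / d_out) * np.linalg.norm(W, ord=2)

# The maximizer is the top left singular vector, rescaled so that
# its RMS norm is 1 (i.e. its Euclidean norm is sqrt(d_in)).
U, S, Vt = np.linalg.svd(W, full_matrices=False)
x_star = U[:, 0] * np.sqrt(d_in)
attained = rms(x_star @ W)

# Random unit-RMS inputs should never exceed the predicted maximum.
xs = rng.standard_normal((1000, d_in))
xs /= np.sqrt(np.mean(xs ** 2, axis=1, keepdims=True))
sampled_max = np.sqrt(np.mean((xs @ W) ** 2, axis=1)).max()

print(predicted_max, attained, sampled_max)
```

The update stability metric follows from the same computation with \Delta\boldsymbol{W} in place of \boldsymbol{W}.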
The first two metrics are functions of \boldsymbol{W}; for linear layers they happen to be the same. If we want them to be \Theta(1), then we have \Vert\boldsymbol{W}\Vert_2 = \Theta(\sqrt{d_{out}/d_{in}}), which at least imposes a requirement on the initialization of \boldsymbol{W}. According to Fast Estimation of the Spectral Norm of Random Matrices, for a standard normal matrix of size d_{in}\times d_{out}, its spectral norm is roughly \sqrt{d_{in}} + \sqrt{d_{out}}. Therefore, for the initialization to satisfy \Vert\boldsymbol{W}\Vert_2 = \Theta(\sqrt{d_{out}/d_{in}}), the initial variance \sigma^2 should satisfy \sigma = \Theta\left(\sqrt{\frac{d_{out}}{d_{in}}}\frac{1}{\sqrt{d_{in}} + \sqrt{d_{out}}}\right)
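The initialization scale can be checked empirically. The following sketch uses a hypothetical helper `init_std` (the name is ours) and draws Gaussian matrices of a few assumed shapes; the resulting metric \sqrt{d_{in}/d_{out}}\Vert\boldsymbol{W}\Vert_2 should come out close to 1 regardless of shape.

```python
import numpy as np

def init_std(d_in, d_out):
    """Standard deviation giving ||W||_2 ~ sqrt(d_out / d_in)."""
    return np.sqrt(d_out / d_in) / (np.sqrt(d_in) + np.sqrt(d_out))

rng = np.random.default_rng(0)
metrics = {}
for d_in, d_out in [(256, 256), (1024, 256), (256, 1024)]:
    W = init_std(d_in, d_out) * rng.standard_normal((d_in, d_out))
    # Forward / dependency stability metric of the initialized layer.
    metrics[(d_in, d_out)] = np.sqrt(d_in / d_out) * np.linalg.norm(W, ord=2)

for shape, value in metrics.items():
    print(shape, round(float(value), 3))
```

Because \sqrt{d_{in}} + \sqrt{d_{out}} is only an approximation of the spectral norm of a standard normal matrix, the values are near 1 rather than exactly 1, which is all that \Theta(1) requires.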
In addition, we could also consider constraining \Vert\boldsymbol{W}\Vert_2 throughout the optimization process. This has inspired some works, such as Steepest Descent on Manifolds: 4. Muon + Spectral Sphere and Controlled LLM Training on Spectral Sphere, which we will discuss later.
Steepest Descent
Next we focus on the “update stability” metric \sqrt{d_{in}/d_{out}}\Vert\Delta\boldsymbol{W}\Vert_2, a variant of the spectral norm of the parameter increment \Delta\boldsymbol{W}. The update is determined by the optimizer, so this metric provides guidance for optimizer design. Under the principle of “stable yet fast”, stability is now in place; the remaining question is which update is fastest.
This is the question that steepest descent seeks to answer. Previously we have discussed this in articles like Sequel to Muon: Why Did We Choose to Try Muon?, Steepest Descent on Manifolds: 1. SGD + Hypersphere, and Steepest Descent on Manifolds: 2. Muon + Orthogonality. But for the completeness of this series, we still go through it again. Steepest descent refers to the update that makes the loss decrease the fastest under some constraint, formally defined as \min_{\Delta \boldsymbol{W}} \mathcal{L}(\boldsymbol{W} +\Delta\boldsymbol{W}) \qquad \text{s.t.}\qquad \rho(\Delta\boldsymbol{W})\leq \eta where \mathcal{L} is the loss function, and \rho(\Delta\boldsymbol{W}) is a stability metric for the increment \Delta\boldsymbol{W}, which we already have: \sqrt{d_{in}/d_{out}}\Vert\Delta\boldsymbol{W}\Vert_2. However, directly solving this problem is still too complicated. We need to replace \mathcal{L}(\boldsymbol{W} +\Delta\boldsymbol{W}) with its first-order approximation \mathcal{L}(\boldsymbol{W}) + \langle \boldsymbol{G}, \Delta\boldsymbol{W}\rangle_F to make the solution feasible. Then the problem becomes \min_{\Delta \boldsymbol{W}} \mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\Delta\boldsymbol{W}) \qquad \text{s.t.}\qquad \Vert\Delta\boldsymbol{W}\Vert_2\leq\eta\sqrt{\frac{d_{out}}{d_{in}}} where \boldsymbol{G}=\nabla_{\boldsymbol{W}}\mathcal{L}(\boldsymbol{W}) is the gradient of the loss, and we used the identity \langle \boldsymbol{G}, \Delta\boldsymbol{W}\rangle_F=\mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\Delta\boldsymbol{W}).
Solution Process
Further, we set \Delta\boldsymbol{W}=-\kappa \boldsymbol{\Phi}, rewriting the optimization objective as \max_{\kappa,\boldsymbol{\Phi}}\kappa\mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad 0\leq \kappa \leq \eta\sqrt{\frac{d_{out}}{d_{in}}}, \quad\Vert\boldsymbol{\Phi}\Vert_2=1 Clearly, the optimization of \kappa can be done separately: the optimal \boldsymbol{\Phi} makes \mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) nonnegative, so the maximum is attained at \kappa = \eta\sqrt{d_{out}/d_{in}}. So we only need to solve \max_{\boldsymbol{\Phi}} \mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2=1 Next, let \boldsymbol{G} have an SVD \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} = \sum_{i=1}^r \sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top}, where r is the rank of \boldsymbol{G}. We have \mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi})=\mathop{\mathrm{tr}}\left(\sum_{i=1}^r \sigma_i \boldsymbol{v}_i \boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\right) = \sum_{i=1}^r \sigma_i \boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\boldsymbol{v}_i By definition, when \Vert\boldsymbol{\Phi}\Vert_2=1, we have \Vert\boldsymbol{\Phi}\boldsymbol{v}_i\Vert_2\leq \Vert\boldsymbol{v}_i\Vert_2=1. Since \Vert\boldsymbol{u}_i\Vert_2=1, the Cauchy-Schwarz inequality gives \boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\boldsymbol{v}_i\leq \Vert\boldsymbol{u}_i\Vert_2\Vert\boldsymbol{\Phi}\boldsymbol{v}_i\Vert_2\leq 1. Therefore \mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi})\leq \sum_{i=1}^r \sigma_i = \Vert \boldsymbol{G}\Vert_* where \Vert\cdot\Vert_* is called the nuclear norm of the matrix. Equality holds when all \boldsymbol{u}_i^{\top}\boldsymbol{\Phi}\boldsymbol{v}_i equal 1, in which case \boldsymbol{\Phi} = \sum_{i=1}^r \boldsymbol{u}_i \boldsymbol{v}_i^{\top} = \boldsymbol{U}_{[:,:r]}\boldsymbol{V}_{[:,:r]}^{\top} = \mathop{\mathrm{msign}}(\boldsymbol{G})
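The optimality of \mathop{\mathrm{msign}}(\boldsymbol{G}) can be verified numerically. Below is a minimal sketch that computes msign through a reduced SVD, with a hypothetical relative tolerance `tol` for deciding the rank, and compares the objective against random feasible directions of unit spectral norm.

```python
import numpy as np

def msign(G, tol=1e-10):
    """Matrix sign via SVD: sum of u_i v_i^T over nonzero singular values."""
    U, S, Vt = np.linalg.svd(G, full_matrices=False)
    r = int(np.sum(S > tol * S.max()))  # numerical rank (tol is an assumption)
    return U[:, :r] @ Vt[:r, :]

rng = np.random.default_rng(0)
d_in, d_out = 48, 32
G = rng.standard_normal((d_in, d_out))

Phi = msign(G)
nuclear_norm = np.linalg.norm(G, ord="nuc")
objective_at_msign = np.trace(G.T @ Phi)

# Random directions rescaled to spectral norm 1 are feasible,
# so none of them should beat the nuclear norm.
best_random = -np.inf
for _ in range(200):
    R = rng.standard_normal((d_in, d_out))
    R /= np.linalg.norm(R, ord=2)
    best_random = max(best_random, np.trace(G.T @ R))

print(objective_at_msign, nuclear_norm, best_random)
```

An SVD is used here only because it mirrors the derivation directly; it is not meant as an efficient way to evaluate msign.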
Summary of Results
In summary, so far, starting from the three stability metrics, we have obtained at least two conclusions: first, the initialization variance \sigma^2 of parameter \boldsymbol{W} should satisfy \sigma = \Theta\left(\sqrt{\frac{d_{out}}{d_{in}}}\frac{1}{\sqrt{d_{in}} + \sqrt{d_{out}}}\right) second, its increment \Delta\boldsymbol{W} should take the following form \Delta\boldsymbol{W} = -\eta\sqrt{\frac{d_{out}}{d_{in}}}\mathop{\mathrm{msign}}(\boldsymbol{G}) This is exactly the MuP version of Muon (for differences between versions, refer to Muon Optimizer Guide: Quick Start and Key Details). In standard Muon, \boldsymbol{G} is replaced by its momentum, which can be seen as a smoother gradient estimate. In addition, regarding the constraint on \boldsymbol{W}, there is still work to be done, which we will explore in future articles.
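Putting the update rule into code, the sketch below performs one step and confirms that the update stability metric equals the learning rate. The function name `muon_step`, the momentum rule `beta * momentum + G`, and the default values of `lr` and `beta` are illustrative assumptions rather than details taken from this article; msign is computed with an SVD for clarity, whereas practical Muon implementations typically approximate it with a Newton-Schulz iteration.

```python
import numpy as np

def msign(M):
    """Matrix sign via SVD (assumes M has full rank)."""
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

def muon_step(W, G, momentum, lr=0.02, beta=0.95):
    """One update W <- W - lr * sqrt(d_out / d_in) * msign(momentum).

    The momentum rule and the defaults for lr and beta are
    illustrative assumptions, not values from the article.
    """
    d_in, d_out = W.shape
    momentum = beta * momentum + G
    delta = -lr * np.sqrt(d_out / d_in) * msign(momentum)
    return W + delta, momentum, delta

rng = np.random.default_rng(0)
d_in, d_out, lr = 128, 64, 0.02
W = rng.standard_normal((d_in, d_out)) / np.sqrt(d_in)
momentum = np.zeros_like(W)
G = rng.standard_normal((d_in, d_out))  # stand-in for a real gradient

W_new, momentum, delta = muon_step(W, G, momentum, lr=lr)

# Update stability metric: sqrt(d_in / d_out) * ||Delta W||_2.
update_metric = np.sqrt(d_in / d_out) * np.linalg.norm(delta, ord=2)
print(update_metric)
```

Since \Vert\mathop{\mathrm{msign}}(\cdot)\Vert_2 = 1, the metric equals \eta for any layer shape, which is what allows one learning rate to be shared across layers of different sizes.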
Since we have already introduced MuP and Muon in several previous blog posts, these two results are not new. This article therefore serves as a first case study demonstrating the soundness of metrics \eqref{eq:c1}, \eqref{eq:c2}, and \eqref{eq:c3}; they provide a unified stability metric for the parameters and increments of any layer, thereby generalizing the conclusions of Muon.
Remaining Issues
Before generalizing, we still need to answer one question: all the previous derivations were based on the In Norm design. Do we then need to add In Norm to every linear layer? Without In Norm, can we still use Muon? To answer this, let’s borrow a passage from the previous article:
Here \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega}) can be a layer, a block composed of several layers, or even the entire model. In theory, the coarser the granularity, the looser, or rather the more accurate, the resulting constraints become, but solving the \max also becomes more difficult. So this depends on our ability to compute the \max.
Simply put, the more accurate the calculation of the stability metrics, the better, but approximations are allowed. Therefore, without In Norm, the extent to which Muon is usable depends on how well the condition “\Vert\boldsymbol{x}\Vert_{RMS}=\text{some constant}” holds. For example, consider an FFN layer \boldsymbol{y}=\phi(\boldsymbol{x}\boldsymbol{W}_{up})\boldsymbol{W}_{down}. If we assume the activation function \phi has Lipschitz constant 1 and satisfies \phi(0)=0 (as ReLU does), then we still have \Vert\boldsymbol{y}\Vert_{RMS} \leq \Vert\boldsymbol{x}\Vert_{RMS} \times\sqrt{\frac{d_{in}}{d_{mid}}}\Vert\boldsymbol{W}_{up}\Vert_2\times \sqrt{\frac{d_{mid}}{d_{out}}}\Vert\boldsymbol{W}_{down}\Vert_2 where \boldsymbol{W}_{up}\in\mathbb{R}^{d_{in}\times d_{mid}},\boldsymbol{W}_{down}\in\mathbb{R}^{d_{mid}\times d_{out}}. In this way, even if we only apply RMS Norm to \boldsymbol{x}, the input to the second parameter \boldsymbol{W}_{down} still has bounded RMS, so the same stability metric approximately holds for it and Muon is also usable.
Similarly, even when no RMS Norm is applied at all, as long as we believe that “\Vert\boldsymbol{x}\Vert_{RMS}=\text{some constant}” holds to some extent, we can still try the Muon optimizer for the subsequent linear layer.
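The FFN bound above can be checked with a short experiment. This is a sketch under stated assumptions: ReLU stands in for \phi (it has Lipschitz constant 1 and maps 0 to 0), and the dimensions and weight scales are arbitrary.

```python
import numpy as np

def rms(x):
    """Row-wise RMS norm."""
    return np.sqrt(np.mean(x ** 2, axis=-1))

def scaled_spectral_norm(W):
    """sqrt(d_in / d_out) * ||W||_2, the stability metric of a linear layer."""
    d_in, d_out = W.shape
    return np.sqrt(d_in / d_out) * np.linalg.norm(W, ord=2)

def relu(z):
    return np.maximum(z, 0.0)

rng = np.random.default_rng(0)
d_in, d_mid, d_out = 64, 256, 64
W_up = rng.standard_normal((d_in, d_mid)) / np.sqrt(d_in)
W_down = rng.standard_normal((d_mid, d_out)) / np.sqrt(d_mid)

x = rng.standard_normal((1000, d_in))
h = relu(x @ W_up)   # input seen by W_down
y = h @ W_down

# Bounds on the RMS norm of the hidden state and of the output.
hidden_bound = rms(x) * scaled_spectral_norm(W_up)
output_bound = hidden_bound * scaled_spectral_norm(W_down)

print(np.all(rms(h) <= hidden_bound), np.all(rms(y) <= output_bound))
```

The first bound is the one that matters for \boldsymbol{W}_{down}: its input stays within a constant factor of \Vert\boldsymbol{x}\Vert_{RMS} as long as the metric of \boldsymbol{W}_{up} is \Theta(1).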
Conclusion
Starting from the three stability metrics of the previous article, this article demonstrates the process of “reproducing” conclusions related to MuP and Muon for linear layers. Next, we will apply this methodology to “customize” initialization and optimizer settings for parameters beyond linear layers.
For reprints, please include the address of this article: https://kexue.fm/archives/11605