English (unofficial) translations of posts at kexue.fm
Source

Beyond MuP: 3. Special Cases Require Special Handling

Translated by DeepSeek V4 Pro. Translations can be inaccurate, please refer to the original post for important stuff.

After so many related blog posts, many readers are likely familiar with the Muon optimizer—even if not with the theoretical details, they probably have the impression that it is an “optimizer specifically designed for matrix parameters.” However, this statement is not entirely accurate—for instance, for the input Embedding layer and the output LM Head, although their parameters are also matrices, they are not suitable for Muon (refer to Muon Optimizer Guide: Quick Start and Key Details).

Why should they be “treated differently”? This article will use the three stability criteria proposed in the first post to explore the initialization patterns of different types of layers and their corresponding steepest descent directions, thereby answering this question.

Review

In the first article Beyond MuP: 1. Three Characteristics of a Good Model, we proposed three stability criteria: \begin{aligned} &\text{Forward Stability:}\quad\max_{\boldsymbol{x}} \Vert \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega})\Vert_{\mathrm{RMS}} = \Theta(1) \label{eq:cc1} \\[5pt] &\text{Dependence Stability:}\quad\max_{\boldsymbol{x}_1,\boldsymbol{x}_2} \frac{\Vert \boldsymbol{f}(\boldsymbol{x}_1;\boldsymbol{\omega}) - \boldsymbol{f}(\boldsymbol{x}_2;\boldsymbol{\omega})\Vert_{\mathrm{RMS}}}{\Vert\boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{\mathrm{RMS}}} = \Theta(1) \label{eq:cc2} \\[5pt] &\text{Update Stability:}\quad\max_{\boldsymbol{x}} \Vert \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega} + \Delta\boldsymbol{\omega}) - \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega})\Vert_{\mathrm{RMS}} = \Theta(1) \label{eq:cc3} \end{aligned} The three criteria share a unified format: compute the RMS of the output, then take \max over the input. Here \boldsymbol{x} denotes the input, \boldsymbol{\omega} the parameters, and \boldsymbol{f}(\boldsymbol{x};\boldsymbol{\omega}) can represent a layer, a block, or even the entire model, depending on our ability to compute the maximum.

Since the domain of \boldsymbol{x} is not bounded, the maximum does not always exist; sometimes we need to supplement the model with additional operations, which in turn guides model design. For example, in the previous article Beyond MuP: 2. Linear Layers and Steepest Descent, to compute the stability criteria of linear layers, we added an In Norm. Moreover, by combining the steepest descent idea, we recovered the derivation of the Muon optimizer.

Steepest descent is not a new concept; it answers the question “given the stability criterion, which optimizer should be used?” The core contribution of the “Beyond MuP” series is to answer the question “which stability criterion should be used?” and to provide a formula for computing the stability criterion applicable to arbitrary layers.

The Embedding Layer

Now consider the Embedding layer, which is perhaps the simplest layer: the input is an index i, and the output is the corresponding vector, i.e., \boldsymbol{f}(i;\boldsymbol{E}) = \boldsymbol{E}_i, where \boldsymbol{E} is a |V|\times d matrix and \boldsymbol{E}_i \triangleq \boldsymbol{E}_{i,:} denotes its i-th row. It is easy to compute \begin{aligned} &\text{Forward Stability:}\quad\max_i \Vert\boldsymbol{E}_i\Vert_{\mathrm{RMS}} = \Theta(1)\\[5pt] &\text{Update Stability:}\quad\max_i \Vert \Delta \boldsymbol{E}_i\Vert_{\mathrm{RMS}} = \Theta(1) \label{eq:ec3} \end{aligned} Note that there is no “dependence stability” here, because the input to the Embedding layer is a discrete token ID, and two IDs cannot be subtracted. Of course, one could convert the IDs to one-hot vectors and compute it anyway, but that yields no new information. Moreover, from the perspective of backpropagation, gradients are never propagated further through token IDs, so their stability does not need to be considered.

The remaining two criteria both concern the maximum row norm of \boldsymbol{E} or \Delta\boldsymbol{E} (times 1/\sqrt{d}). We use forward stability only as an initialization guideline: it tells us to initialize \boldsymbol{E} with zero mean and \Theta(1) variance. As for update stability, Eq.[eq:ec3] tells us that, although \boldsymbol{E} is also a matrix, the stability metric for the Embedding layer should not be the spectral norm but rather the maximum row norm, which implies that its steepest descent is not Muon.

To find the steepest descent for the Embedding layer, we need to solve the optimization problem \min_{\Delta \boldsymbol{E}} \langle\boldsymbol{G},\Delta\boldsymbol{E}\rangle \qquad \text{s.t.}\qquad \max_i \underbrace{\Vert\Delta\boldsymbol{E}_i\Vert_{\mathrm{RMS}}}_{\Vert\Delta\boldsymbol{E}_i\Vert_2/\sqrt{d}}\leq\eta This problem is not difficult; we simply use the Cauchy-Schwarz inequality: \langle\boldsymbol{G},\Delta\boldsymbol{E}\rangle = \sum_{i=1}^{|V|}\langle\boldsymbol{G}_i,\Delta\boldsymbol{E}_i\rangle \geq -\sum_{i=1}^{|V|}\Vert\boldsymbol{G}_i\Vert_2 \times \Vert\Delta\boldsymbol{E}_i\Vert_2 \geq -\eta\sqrt{d}\sum_{i=1}^{|V|}\Vert\boldsymbol{G}_i\Vert_2 Equality holds when \Delta\boldsymbol{E}_i = - \eta\boldsymbol{G}_i / \Vert\boldsymbol{G}_i\Vert_{\mathrm{RMS}}. That is, the suitable steepest descent for the Embedding layer is performing row-wise RMS Norm on the gradient (Normalized SGD).
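The row-wise Normalized SGD step can be sketched in a few lines of NumPy and checked against the Cauchy-Schwarz bound above. This is only an illustration under assumptions: the helper name `embedding_update`, the `eps` guard for all-zero rows, and the random sizes are not from the original post.

```python
import numpy as np


def embedding_update(G: np.ndarray, eta: float, eps: float = 1e-12) -> np.ndarray:
    """Row-wise Normalized SGD step for a |V| x d Embedding gradient G (hypothetical helper).

    Row i becomes dE_i = -eta * G_i / ||G_i||_RMS, so every row with a nonzero
    gradient has RMS exactly eta; all-zero rows (tokens absent from the batch) stay zero.
    """
    d = G.shape[1]
    row_rms = np.linalg.norm(G, axis=1, keepdims=True) / np.sqrt(d)
    return -eta * G / np.maximum(row_rms, eps)


rng = np.random.default_rng(0)
V, d, eta = 10, 16, 0.1
G = rng.standard_normal((V, d))
G[3] = 0.0  # pretend token 3 did not occur in the batch

dE = embedding_update(G, eta)

# The constraint max_i ||dE_i||_RMS <= eta is active on every nonzero row
dE_row_rms = np.linalg.norm(dE, axis=1) / np.sqrt(d)
# The objective <G, dE> reaches the lower bound -eta * sqrt(d) * sum_i ||G_i||_2
lower_bound = -eta * np.sqrt(d) * np.linalg.norm(G, axis=1).sum()
print(dE_row_rms.round(6))
print(np.isclose(np.sum(G * dE), lower_bound))
```

In a real vocabulary most rows receive no gradient in a given step; leaving them untouched via `eps` is one possible convention, not something the derivation prescribes.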

The Output Head

Next, consider the LM Head. On the surface, it is also a linear layer: the input is \boldsymbol{x}\in\mathbb{R}^d, the weight is \boldsymbol{W}\in\mathbb{R}^{d\times |V|}, and the output is \boldsymbol{x}\boldsymbol{W}\in\mathbb{R}^{|V|}. Usually \boldsymbol{x} also undergoes RMS Norm, so in every respect it looks like a linear layer. Why is it not suitable for Muon?

Responsible for the Loss

The answer is, the LM Head needs to be “responsible” for the Loss.

Keep in mind that steepest descent serves training. From the inference perspective, the model takes several tokens to predict the next token; but from the training perspective, the true “model” is: input several tokens together with the next token to compute the Loss. In other words, both the data and the label are essentially inputs, and the real output is the Loss. For earlier layers, we can ignore the label and the Loss, but the LM Head, being the last layer adjacent to the Loss, must account for the influence of the label and the Loss.

Thus, the input to the LM Head becomes \boldsymbol{x} and the next token ID t, and the output becomes the cross-entropy loss, i.e., \ell(\boldsymbol{x},t;\boldsymbol{W}) = \log\sum_{i=1}^{|V|} e^{\langle \boldsymbol{x},\boldsymbol{w}_i\rangle} - \langle \boldsymbol{x},\boldsymbol{w}_t\rangle = \log\sum_{i=1}^{|V|} e^{\langle \boldsymbol{x},\boldsymbol{w}_i - \boldsymbol{w}_t\rangle} where \boldsymbol{w}_i\triangleq \boldsymbol{W}_{:, i} is the i-th column of \boldsymbol{W}. Because \ell is a complex nonlinear function of \boldsymbol{x},t,\boldsymbol{W}, its three criteria cannot be computed exactly; our goal is to obtain a bound that is as tight as possible.

Forward Stability

First, the relatively simple forward stability. A simple bound gives \begin{aligned} \ell(\boldsymbol{x},t;\boldsymbol{W}) = \log\sum_{i=1}^{|V|} e^{\langle \boldsymbol{x},\boldsymbol{w}_i - \boldsymbol{w}_t\rangle} \leq&\, \log \left(|V| \max_i e^{\langle \boldsymbol{x},\boldsymbol{w}_i - \boldsymbol{w}_t\rangle}\right) \\ =&\, \log |V| + \max_i \langle \boldsymbol{x},\boldsymbol{w}_i - \boldsymbol{w}_t\rangle \\ \leq &\, \log |V| + \max_i \Vert\boldsymbol{x}\Vert_2 \Vert\boldsymbol{w}_i - \boldsymbol{w}_t\Vert_2 \end{aligned} Hence \begin{aligned} \text{Forward Stability:}\quad\max_{t, \Vert\boldsymbol{x}\Vert_{\mathrm{RMS}}=1} \ell(\boldsymbol{x},t;\boldsymbol{W}) \leq&\, \log |V| + d\max_{i,t} \Vert\boldsymbol{w}_i - \boldsymbol{w}_t\Vert_{\mathrm{RMS}} \\ \leq&\, \log |V| + 2d\max_i \Vert\boldsymbol{w}_i\Vert_{\mathrm{RMS}} \end{aligned} Since a log-sum-exp is never smaller than its largest term, dropping the constant \log|V| from the first bound turns it into a lower bound (take \boldsymbol{x} parallel to \boldsymbol{w}_i - \boldsymbol{w}_t for the maximizing pair), so the bound is fairly tight asymptotically. To make it \Theta(1), the initialization standard deviation of the LM Head should be \Theta(1/d), i.e., its variance should be \Theta(1/d^2).
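To see the \Theta(1/d^2) variance rule and the tightness claim numerically, here is a small NumPy sketch. The vocabulary size, the widths, the random Gaussian weights, and the helper names `lm_head_loss` and `pairwise_col_diff_rms` are illustrative assumptions. With standard deviation 1/d, the quantity d\max_{i,t}\Vert\boldsymbol{w}_i-\boldsymbol{w}_t\Vert_{\mathrm{RMS}} should stay roughly constant as d grows, whereas with 1/\sqrt{d} it should grow like \sqrt{d}.

```python
import numpy as np


def lm_head_loss(x: np.ndarray, W: np.ndarray, t: int) -> float:
    """Cross-entropy l(x, t; W) = log sum_i exp(<x, w_i - w_t>), with w_i = W[:, i]."""
    z = x @ W - x @ W[:, t]
    m = z.max()
    return float(m + np.log(np.exp(z - m).sum()))


def pairwise_col_diff_rms(W: np.ndarray) -> np.ndarray:
    """Matrix of ||w_i - w_t||_RMS for every pair of columns (i, t)."""
    diffs = W[:, :, None] - W[:, None, :]  # shape (d, |V|, |V|)
    return np.sqrt((diffs ** 2).mean(axis=0))


rng = np.random.default_rng(0)
V = 50

# d * max_{i,t} ||w_i - w_t||_RMS under two candidate initializations
for d in (64, 256, 1024):
    for label, std in (("std 1/d", 1 / d), ("std 1/sqrt(d)", 1 / np.sqrt(d))):
        W = rng.standard_normal((d, V)) * std
        print(f"d={d:5d}  {label:14s}  d*max||w_i-w_t||_RMS = {d * pairwise_col_diff_rms(W).max():6.2f}")

# Tightness: an x parallel to w_i - w_t for the worst pair attains the lower bound
d = 256
W = rng.standard_normal((d, V)) / d
rms = pairwise_col_diff_rms(W)
i, t = np.unravel_index(rms.argmax(), rms.shape)
direction = W[:, i] - W[:, t]
x = np.sqrt(d) * direction / np.linalg.norm(direction)  # ||x||_RMS = 1
loss = lm_head_loss(x, W, int(t))
lower, upper = d * rms.max(), np.log(V) + d * rms.max()
print(f"{lower:.3f} <= loss = {loss:.3f} <= {upper:.3f}")
```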

An Important Inequality

For the remaining two criteria, which involve differences, the computation is a bit more involved. We first prove an inequality we will need: \left|\log\sum_{i=1}^n e^{a_i} - \log\sum_{i=1}^n e^{b_i}\right| \leq \max_i |a_i - b_i|\label{leq:lse-ab} The proof is not difficult but requires a small trick. Let the right-hand side be M. Since \log, \sum and \exp are all monotonically increasing, \log\sum_{i=1}^n e^{a_i} = \log\sum_{i=1}^n e^{(a_i - b_i)+b_i} \leq \log\sum_{i=1}^n e^{M + b_i} = M + \log\sum_{i=1}^n e^{b_i} This proves \log\sum_{i=1}^n e^{a_i} - \log\sum_{i=1}^n e^{b_i} \leq M Swapping the roles of a_i and b_i, the same argument gives the reverse inequality, which proves the original claim.
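A quick randomized check of inequality [leq:lse-ab] is sketched below; the sampling distribution and vector lengths are arbitrary choices, not from the original. It also illustrates that the bound is attained when \boldsymbol{b} is a constant shift of \boldsymbol{a}, so it cannot be improved in general.

```python
import numpy as np


def logsumexp(v: np.ndarray) -> float:
    """Numerically stable log(sum(exp(v)))."""
    m = v.max()
    return float(m + np.log(np.exp(v - m).sum()))


rng = np.random.default_rng(0)
worst_ratio = 0.0
for _ in range(1000):
    n = int(rng.integers(1, 50))
    a = 5 * rng.standard_normal(n)
    b = 5 * rng.standard_normal(n)
    gap = abs(logsumexp(a) - logsumexp(b))
    M = np.abs(a - b).max()
    worst_ratio = max(worst_ratio, gap / M)
print(f"max observed |LSE(a) - LSE(b)| / max_i |a_i - b_i| = {worst_ratio:.4f}")

# Equality case: shifting every entry by the same constant c shifts LSE by exactly c
a = rng.standard_normal(8)
c = 2.5
shift_gap = logsumexp(a + c) - logsumexp(a)
print(np.isclose(shift_gap, c))
```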

Dependence Stability

Using inequality [leq:lse-ab] and Cauchy-Schwarz, \begin{aligned} \frac{|\ell(\boldsymbol{x}_1,t;\boldsymbol{W}) - \ell(\boldsymbol{x}_2,t;\boldsymbol{W})|}{\Vert \boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{\mathrm{RMS}}} \leq&\, \frac{\max_i |\langle \boldsymbol{x}_1,\boldsymbol{w}_i - \boldsymbol{w}_t\rangle - \langle \boldsymbol{x}_2,\boldsymbol{w}_i - \boldsymbol{w}_t\rangle|}{\Vert \boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{\mathrm{RMS}}} \\ =&\, \frac{\max_i |\langle \boldsymbol{x}_1 - \boldsymbol{x}_2,\boldsymbol{w}_i - \boldsymbol{w}_t\rangle|}{\Vert \boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{\mathrm{RMS}}} \\ \leq&\, \frac{\max_i d\, \Vert \boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{\mathrm{RMS}} \Vert \boldsymbol{w}_i - \boldsymbol{w}_t\Vert_{\mathrm{RMS}}}{\Vert \boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{\mathrm{RMS}}} \\ =&\, d\max_i \Vert \boldsymbol{w}_i - \boldsymbol{w}_t\Vert_{\mathrm{RMS}} \end{aligned} Therefore \begin{aligned} \text{Dependence Stability:}\quad\max_{\begin{gathered}t \\ \Vert\boldsymbol{x}_1\Vert_{\mathrm{RMS}}=1 \\ \Vert\boldsymbol{x}_2\Vert_{\mathrm{RMS}}=1\end{gathered}} \frac{|\ell(\boldsymbol{x}_1,t;\boldsymbol{W}) - \ell(\boldsymbol{x}_2,t;\boldsymbol{W})|}{\Vert \boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{\mathrm{RMS}}} \leq&\, d\max_{i,t} \Vert\boldsymbol{w}_i - \boldsymbol{w}_t\Vert_{\mathrm{RMS}}\\ \leq&\, 2d\max_i \Vert\boldsymbol{w}_i\Vert_{\mathrm{RMS}} \end{aligned} The result coincides with forward stability, apart from the constant \log|V|. As with the Embedding layer, the label t is a discrete ID through which we do not backpropagate, so the denominator only involves the difference between \boldsymbol{x}_1 and \boldsymbol{x}_2.
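As a sanity check, the following NumPy sketch samples random unit-RMS inputs and labels and confirms that the observed ratio never exceeds the per-label bound d\max_i\Vert\boldsymbol{w}_i-\boldsymbol{w}_t\Vert_{\mathrm{RMS}}. The sizes, the Gaussian weights with standard deviation 1/d, and the helper names are assumptions made for illustration; random sampling only probes typical inputs, not the worst case.

```python
import numpy as np


def lm_head_loss(x: np.ndarray, W: np.ndarray, t: int) -> float:
    """Cross-entropy l(x, t; W) = log sum_i exp(<x, w_i - w_t>), with w_i = W[:, i]."""
    z = x @ W - x @ W[:, t]
    m = z.max()
    return float(m + np.log(np.exp(z - m).sum()))


def rms(v: np.ndarray, axis=None):
    return np.sqrt(np.mean(v ** 2, axis=axis))


rng = np.random.default_rng(0)
d, V = 128, 40
W = rng.standard_normal((d, V)) / d  # std 1/d, as suggested by forward stability

worst_ratio, min_slack = 0.0, np.inf
for _ in range(2000):
    t = int(rng.integers(V))
    x1 = rng.standard_normal(d)
    x2 = rng.standard_normal(d)
    x1, x2 = x1 / rms(x1), x2 / rms(x2)  # ||x1||_RMS = ||x2||_RMS = 1
    ratio = abs(lm_head_loss(x1, W, t) - lm_head_loss(x2, W, t)) / rms(x1 - x2)
    bound_t = d * rms(W - W[:, [t]], axis=0).max()  # d * max_i ||w_i - w_t||_RMS
    worst_ratio = max(worst_ratio, ratio)
    min_slack = min(min_slack, bound_t - ratio)

overall_bound = 2 * d * rms(W, axis=0).max()  # 2d * max_i ||w_i||_RMS
print(f"largest ratio {worst_ratio:.3f}, smallest slack {min_slack:.3f}, 2d*max||w_i||_RMS = {overall_bound:.3f}")
```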

Update Stability

Finally, the update stability, again using inequality [leq:lse-ab] and Cauchy-Schwarz: \begin{aligned} |\ell(\boldsymbol{x},t;\boldsymbol{W} + \Delta\boldsymbol{W}) - \ell(\boldsymbol{x},t;\boldsymbol{W})| \leq&\, \max_i |\langle \boldsymbol{x},\boldsymbol{w}_i + \Delta\boldsymbol{w}_i - \boldsymbol{w}_t - \Delta\boldsymbol{w}_t\rangle - \langle \boldsymbol{x},\boldsymbol{w}_i - \boldsymbol{w}_t\rangle| \\ =&\, \max_i |\langle \boldsymbol{x},\Delta\boldsymbol{w}_i - \Delta\boldsymbol{w}_t\rangle| \\ \leq &\, d \max_i \Vert\boldsymbol{x}\Vert_{\mathrm{RMS}} \Vert\Delta\boldsymbol{w}_i - \Delta\boldsymbol{w}_t\Vert_{\mathrm{RMS}} \end{aligned} Thus \begin{aligned} \text{Update Stability:}\quad\max_{t,\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}}=1} |\ell(\boldsymbol{x},t;\boldsymbol{W} + \Delta\boldsymbol{W}) - \ell(\boldsymbol{x},t;\boldsymbol{W})| \leq&\, d \max_{i, t} \Vert\Delta\boldsymbol{w}_i - \Delta\boldsymbol{w}_t\Vert_{\mathrm{RMS}} \\ \leq&\, 2d\max_i \Vert\Delta\boldsymbol{w}_i\Vert_{\mathrm{RMS}} \end{aligned}
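The update bound can be probed the same way. In this sketch the perturbation \Delta\boldsymbol{W} is an arbitrary small random matrix chosen only for illustration (it is not the optimizer's update), and the sizes and helper name are assumptions:

```python
import numpy as np


def lm_head_loss(x: np.ndarray, W: np.ndarray, t: int) -> float:
    """Cross-entropy l(x, t; W) = log sum_i exp(<x, w_i - w_t>), with w_i = W[:, i]."""
    z = x @ W - x @ W[:, t]
    m = z.max()
    return float(m + np.log(np.exp(z - m).sum()))


rng = np.random.default_rng(0)
d, V = 128, 40
W = rng.standard_normal((d, V)) / d
dW = 1e-2 * rng.standard_normal((d, V)) / d  # an arbitrary small update, for illustration only

min_slack = np.inf
for _ in range(2000):
    t = int(rng.integers(V))
    x = rng.standard_normal(d)
    x /= np.sqrt(np.mean(x ** 2))  # ||x||_RMS = 1
    change = abs(lm_head_loss(x, W + dW, t) - lm_head_loss(x, W, t))
    # per-label bound: d * max_i ||dw_i - dw_t||_RMS
    bound_t = d * np.sqrt(np.mean((dW - dW[:, [t]]) ** 2, axis=0)).max()
    min_slack = min(min_slack, bound_t - change)

coarse_bound = 2 * d * np.sqrt(np.mean(dW ** 2, axis=0)).max()  # 2d * max_i ||dw_i||_RMS
print(f"smallest slack {min_slack:.2e}; coarse bound 2d*max||dw_i||_RMS = {coarse_bound:.4f}")
```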

It is easy to see that the three stability criteria of the LM Head are essentially the same as those of the Embedding layer: all of them are governed by the maximum norm of a slice of the parameter matrix or of its update, rows for the Embedding and columns for the LM Head. The steepest descent for the LM Head is therefore also Normalized SGD, except that the RMS Norm is applied column-wise. Moreover, all three criteria for the LM Head carry a factor of d, so its initialization standard deviation and learning rate must scale as \Theta(1/d), whereas for the Embedding they are \Theta(1). The two layers therefore behave slightly differently when scaling the width.
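A minimal sketch of the column-wise rule follows. Here `lm_head_update` is a hypothetical helper and `eta` a hypothetical width-independent base learning rate; putting the 1/d factor directly into the step size is just one way to realize the \Theta(1/d) scaling. The sketch also confirms that the rule is the Embedding rule applied to \boldsymbol{W}^{\top}.

```python
import numpy as np


def lm_head_update(G: np.ndarray, eta: float, eps: float = 1e-12) -> np.ndarray:
    """Column-wise Normalized SGD step for a d x |V| LM Head gradient G (hypothetical helper).

    Column i becomes dw_i = -(eta / d) * g_i / ||g_i||_RMS. The 1/d factor keeps the
    update-stability bound 2d * max_i ||dw_i||_RMS equal to 2 * eta at every width d.
    """
    d = G.shape[0]
    col_rms = np.linalg.norm(G, axis=0, keepdims=True) / np.sqrt(d)
    return -(eta / d) * G / np.maximum(col_rms, eps)


rng = np.random.default_rng(0)
V, eta = 30, 0.1
bounds = {}
for d in (64, 256, 1024):
    G = rng.standard_normal((d, V))
    dW = lm_head_update(G, eta)
    col_rms = np.linalg.norm(dW, axis=0) / np.sqrt(d)
    bounds[d] = 2 * d * col_rms.max()
    print(f"d={d:5d}  2d*max||dw_i||_RMS = {bounds[d]:.4f}")

# Column-wise on W is the same as row-wise on W^T, i.e. the Embedding rule transposed
G = rng.standard_normal((16, V))
dW = lm_head_update(G, eta)
row_wise = -(eta / 16) * G.T / (np.linalg.norm(G.T, axis=1, keepdims=True) / np.sqrt(16))
print(np.allclose(dW, row_wise.T))
```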

Other Modules

Besides linear layers, Embedding, and LM Head, typical Transformer models also have some other parameters or layers that require separate analysis. Let us go through them one by one.

Hadamard Product

As we know, RMS Norm is often followed by a multiplicative vector \boldsymbol{\gamma} (Hadamard product), i.e., (\boldsymbol{x} / \Vert\boldsymbol{x}\Vert_{\mathrm{RMS}})\odot\boldsymbol{\gamma}, to control the output scale. This parameter is not a matrix, so the preceding Muon or Normalized SGD does not directly apply.

One can follow the original definition to compute the three stability criteria for \boldsymbol{\gamma} step by step and then analyze its initialization and steepest descent. But there is a more clever way: notice that \boldsymbol{x}\odot\boldsymbol{\gamma}=\boldsymbol{x}\mathop{\mathrm{diag}}(\boldsymbol{\gamma}), i.e., the Hadamard product of \boldsymbol{x} and \boldsymbol{\gamma} equals the matrix product of \boldsymbol{x} and the diagonal matrix \mathop{\mathrm{diag}}(\boldsymbol{\gamma}). This turns it into a special linear layer with \boldsymbol{W}=\mathop{\mathrm{diag}}(\boldsymbol{\gamma}), allowing us to reuse the conclusions for linear layers.

According to the previous article, the initial spectral norm of \boldsymbol{W} should be \Theta(\sqrt{d_{\mathrm{out}}/d_{\mathrm{in}}}). Here \boldsymbol{W} is a square matrix, so this gives exactly \Theta(1). Since \boldsymbol{W} is a diagonal matrix, we can simply initialize it as the identity matrix to satisfy this requirement, which corresponds to initializing \boldsymbol{\gamma} to all ones.

As for the optimizer, let the gradient of \boldsymbol{\gamma} be \boldsymbol{g}; then the gradient of \boldsymbol{W} is \boldsymbol{G}=\mathop{\mathrm{diag}}(\boldsymbol{g}). We already know that the steepest descent for a linear layer is Muon, i.e., \Delta\boldsymbol{W}=-\eta\mathop{\mathrm{msign}}(\boldsymbol{G}). For a diagonal matrix, \mathop{\mathrm{msign}}(\boldsymbol{G})=\mathop{\mathrm{sign}}(\boldsymbol{G})=\mathop{\mathrm{diag}}(\mathop{\mathrm{sign}}(\boldsymbol{g})), so the steepest descent for the \boldsymbol{\gamma} parameter is SignSGD.
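The identity \mathop{\mathrm{msign}}(\mathop{\mathrm{diag}}(\boldsymbol{g}))=\mathop{\mathrm{diag}}(\mathop{\mathrm{sign}}(\boldsymbol{g})) can be checked numerically. This sketch computes msign through the SVD, which is one standard way to evaluate it (Muon implementations usually approximate it with a Newton-Schulz iteration instead, not reproduced here); the dimension and learning rate are arbitrary assumptions.

```python
import numpy as np


def msign(G: np.ndarray) -> np.ndarray:
    """Matrix sign via SVD: msign(G) = U V^T for G = U diag(s) V^T (G assumed full rank)."""
    U, _, Vt = np.linalg.svd(G)
    return U @ Vt


rng = np.random.default_rng(0)
d = 6
gamma = np.ones(d)  # initialization derived above: diag(gamma) = I
print(np.linalg.norm(np.diag(gamma), 2))  # spectral norm of diag(gamma)

g = rng.standard_normal(d)  # gradient of gamma (almost surely no exact zeros)
G = np.diag(g)              # gradient of W = diag(gamma)
print(np.allclose(msign(G), np.diag(np.sign(g))))

eta = 0.01
gamma = gamma - eta * np.sign(g)  # the resulting SignSGD step on gamma
```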

Linear Bias Term

Traditional linear layers often include a bias vector \boldsymbol{b}, i.e., the full linear operation is \boldsymbol{f}(\boldsymbol{x};\boldsymbol{W},\boldsymbol{b}) = \boldsymbol{x}\boldsymbol{W}+\boldsymbol{b}. However, most open-source models in recent years have removed the bias term, so it is rarely seen nowadays. For completeness, we still discuss it briefly.

After adding the bias vector, the three stability criteria become \begin{aligned} &\text{Forward Stability:}\quad\max_{\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}}=1} \Vert \boldsymbol{x}\boldsymbol{W} + \boldsymbol{b}\Vert_{\mathrm{RMS}} \\[5pt] &\text{Dependence Stability:}\quad\max_{\Vert\boldsymbol{x}_1\Vert_{\mathrm{RMS}}=\Vert\boldsymbol{x}_2\Vert_{\mathrm{RMS}}=1} \frac{\Vert \boldsymbol{x}_1\boldsymbol{W} - \boldsymbol{x}_2\boldsymbol{W}\Vert_{\mathrm{RMS}}}{\Vert \boldsymbol{x}_1 - \boldsymbol{x}_2\Vert_{\mathrm{RMS}}}\\[5pt] &\text{Update Stability:}\quad\max_{\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}}=1} \Vert \boldsymbol{x} \Delta\boldsymbol{W} + \Delta\boldsymbol{b}\Vert_{\mathrm{RMS}} \end{aligned} The bias cancels in the difference, so dependence stability is the same as without bias, and we only need to examine forward and update stability. For simplicity, we use the triangle inequality \Vert \boldsymbol{x}\boldsymbol{W} + \boldsymbol{b}\Vert_{\mathrm{RMS}}\leq \Vert \boldsymbol{x}\boldsymbol{W}\Vert_{\mathrm{RMS}} + \Vert\boldsymbol{b}\Vert_{\mathrm{RMS}}. Assuming \boldsymbol{W} keeps its original initialization, the \Vert \boldsymbol{x}\boldsymbol{W}\Vert_{\mathrm{RMS}} part is already \Theta(1), so we only need \Vert\boldsymbol{b}\Vert_{\mathrm{RMS}}=\mathcal{O}(1). In practice, for simplicity, \boldsymbol{b} is usually initialized to zeros.

Similarly, \Vert \boldsymbol{x}\Delta\boldsymbol{W} + \Delta\boldsymbol{b}\Vert_{\mathrm{RMS}}\leq \Vert \boldsymbol{x}\Delta\boldsymbol{W}\Vert_{\mathrm{RMS}} + \Vert\Delta\boldsymbol{b}\Vert_{\mathrm{RMS}}, so it suffices to require \Vert\Delta\boldsymbol{b}\Vert_{\mathrm{RMS}}=\mathcal{O}(1), i.e., to perform steepest descent on \boldsymbol{b} with \Vert\Delta\boldsymbol{b}\Vert_{\mathrm{RMS}} as the stability metric. Exactly as for a single row of the Embedding, denoting the gradient of \boldsymbol{b} by \boldsymbol{g}, the solution is \Delta\boldsymbol{b} = -\eta\boldsymbol{g}/\Vert\boldsymbol{g}\Vert_{\mathrm{RMS}}, which is again Normalized SGD.

Attention Scaling

Using the forward stability criterion, we can also re-derive the scaling factor of the Attention mechanism. Let \boldsymbol{q}=\boldsymbol{x}\boldsymbol{W}_q,\boldsymbol{k}=\boldsymbol{x}\boldsymbol{W}_k. If \boldsymbol{W}_q,\boldsymbol{W}_k are treated as linear layers, we can assume they already achieve \Vert\boldsymbol{q}\Vert_{\mathrm{RMS}}=\Theta(1) and \Vert\boldsymbol{k}\Vert_{\mathrm{RMS}}=\Theta(1). Then by Cauchy-Schwarz, |\langle\boldsymbol{q},\boldsymbol{k}\rangle| \leq \Vert\boldsymbol{q}\Vert_2 \Vert\boldsymbol{k}\Vert_2 = d\Vert\boldsymbol{q}\Vert_{\mathrm{RMS}} \Vert\boldsymbol{k}\Vert_{\mathrm{RMS}} where d is the dimensionality of \boldsymbol{q},\boldsymbol{k}, i.e., the head dimension. Clearly the right-hand side is \Theta(d); to make it \Theta(1), we need to multiply \boldsymbol{q}\cdot\boldsymbol{k} by a scaling factor on the order of \Theta(1/d). This differs from the previously common 1/\sqrt{d} (refer to A Brief Discussion on Transformer Initialization, Parameterization and Normalization).
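To illustrate why the two factors differ, the sketch below compares the typical size of \langle\boldsymbol{q},\boldsymbol{k}\rangle, which scales like \sqrt{d}, with the worst case allowed by Cauchy-Schwarz, which scales like d. Independent Gaussian \boldsymbol{q} and \boldsymbol{k} are an assumption standing in for random initialization; trained queries and keys need not behave this way.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4000  # number of random (q, k) pairs per head dimension
results = {}
for d in (64, 128, 256):
    q = rng.standard_normal((n, d))      # entries N(0, 1), so ||q||_RMS is about 1
    k = rng.standard_normal((n, d))
    dots = np.einsum("nd,nd->n", q, k)   # <q, k> for each pair
    typical = dots.std() / np.sqrt(d)    # typical size relative to sqrt(d)
    q0 = q[0] / np.sqrt(np.mean(q[0] ** 2))  # rescale to ||q0||_RMS = 1 exactly
    worst = (q0 @ q0) / d                # k = q0 attains Cauchy-Schwarz equality
    results[d] = (typical, worst)
    print(f"d={d:4d}  std<q,k>/sqrt(d) = {typical:.3f}   worst<q,k>/d = {worst:.3f}")
```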

Which one is correct? Actually, both are. 1/\sqrt{d} describes the typical magnitude under random initialization, while \Theta(1/d) comes from the worst case, which holds throughout training. This does not mean we should simply change the scaling factor to 1/d; rather, when the head dimension changes, rescaling the factor in proportion to 1/d may transfer better, so the two views are compatible. For example, suppose a scaling factor of 1/\sqrt{128} works well at d=128; when transferring to d=256, one may consider changing the factor to 1/(2\sqrt{128}) instead of 1/\sqrt{256}.
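The arithmetic of this transfer rule is simple enough to state as a tiny sketch; `transferred_scale` is a hypothetical helper, and the base factor 1/\sqrt{128} is the example value assumed in the text.

```python
import math


def transferred_scale(base_scale: float, base_d: int, new_d: int) -> float:
    """Rescale an attention factor tuned at head dim base_d in proportion to 1/d."""
    return base_scale * base_d / new_d


base = 1 / math.sqrt(128)  # assumed to work well at d = 128
s_256 = transferred_scale(base, 128, 256)
print(f"1/d rule: {s_256:.5f}   1/(2*sqrt(128)): {1 / (2 * math.sqrt(128)):.5f}   1/sqrt(256): {1 / math.sqrt(256):.5f}")
```

Under the 1/d rule the factor at d=256 comes out smaller than the conventional 1/\sqrt{256}.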

In fact, because of Flash Attention constraints, the range of head dimensions is narrow, typically 128 and at most 256, so in practice hyperparameters are almost never transferred across head dimensions. This result is therefore mostly of theoretical interest.

Summary

Finally, the main results of these two articles are summarized in the following table:

Here, the steepest descent directions for the Embedding and the LM Head are row-wise and column-wise Normalized SGD respectively, consistent with works such as Scion, and the scaling rules for the initialization variance and learning rate agree with the conclusions of MuP. In these two articles, however, all of them are derived from the three stability criteria we proposed, indicating that we have indeed found a unified way to measure the stability of arbitrary layers.

For reprints, please include the address of this article: https://kexue.fm/archives/11647

For more detailed reprint information, please refer to: Science Space FAQ