In the article “Muon Optimizer Guide: Quick Start and Key Details”, we listed several versions of Muon, which differ only in a scaling factor on the learning rate that depends on the matrix shape. Among them, the “official version (KellerJordan version)” differs from the “MuP version” only by an extra \max(1,\cdot) truncation. This article addresses one specific question: where does this truncation come from?
Several Versions
The update rule of Muon can be uniformly written as \begin{aligned} \boldsymbol{M}_t &= \beta \boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt] \boldsymbol{W}_t &= \boldsymbol{W}_{t-1} - \eta_t \bigl(\alpha \mathop{\text{msign}}(\boldsymbol{M}_t) + \lambda \boldsymbol{W}_{t-1}\bigr) \end{aligned} The differences among several versions lie in \alpha, which are: \alpha = \left\{ \begin{aligned} &1 && \textcolor{skyblue}{(\text{Naive version})} \\[5pt] &\sqrt{\max(1, d_{\text{out}}/d_{\text{in}})} && \textcolor{skyblue}{(\text{KellerJordan version})} \\[5pt] &\sqrt{d_{\text{out}}/d_{\text{in}}} && \textcolor{skyblue}{(\text{MuP version})} \\[5pt] &0.2 \times \sqrt{\max(d_{\text{out}}, d_{\text{in}})} && \textcolor{skyblue}{(\text{Moonlight version})} \end{aligned}\right. Here the matrix \boldsymbol{W}\in\mathbb{R}^{d_{\text{in}}\times d_{\text{out}}} represents the training parameters of a linear layer \boldsymbol{y}=\boldsymbol{x}\boldsymbol{W}, where the input \boldsymbol{x}\in\mathbb{R}^{d_{\text{in}}} is a row vector.
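To make the four choices of \alpha concrete, here is a minimal sketch in Python. The function name `muon_alpha` and the version strings are illustrative assumptions, not part of any Muon implementation; the shape convention follows the text, with \boldsymbol{W} of shape (d_in, d_out).

```python
import math


def muon_alpha(d_in: int, d_out: int, version: str = "keller_jordan") -> float:
    """Shape-dependent scale alpha for a weight W of shape (d_in, d_out).

    The function name and the version strings are hypothetical labels
    chosen for this sketch; the formulas are the ones listed in the text.
    """
    ratio = d_out / d_in
    if version == "naive":
        return 1.0
    if version == "keller_jordan":
        return math.sqrt(max(1.0, ratio))
    if version == "mup":
        return math.sqrt(ratio)
    if version == "moonlight":
        return 0.2 * math.sqrt(max(d_out, d_in))
    raise ValueError(f"unknown version: {version}")


if __name__ == "__main__":
    # Compare the versions on an up-projection and a down-projection.
    for d_in, d_out in [(1024, 4096), (4096, 1024)]:
        for version in ("naive", "keller_jordan", "mup", "moonlight"):
            print(d_in, d_out, version, muon_alpha(d_in, d_out, version))
```

The KellerJordan and MuP versions agree whenever d_out >= d_in and differ only for down-projections (d_in > d_out), which is exactly the case the rest of the article examines.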
This article mainly concerns the “KellerJordan version” and the “MuP version”; the former adds an extra \max(1,\cdot) on top of the latter. According to our analysis in “Higher-Order MuP: Simpler but Smarter Spectral Condition Scaling” and “Above MuP: 2. Linear Layers and Steepest Descent”, under the spectral condition constraints relevant to MuP, the steepest descent should be the MuP version of Muon. How, then, should we explain the extra \max(1,\cdot)?
Feature Increment
For simplicity, we omit the subscript t in the following discussion. Without loss of generality, assume the momentum \boldsymbol{M} is full-rank. Then the singular values of \boldsymbol{\Phi} = \mathop{\text{msign}}(\boldsymbol{M}) are all 1, so \boldsymbol{\Phi} \boldsymbol{\Phi}^{\top} = \boldsymbol{I}_{d_{\text{in}}} when d_{\text{in}} \leq d_{\text{out}}, and \boldsymbol{\Phi}^{\top} \boldsymbol{\Phi} = \boldsymbol{I}_{d_{\text{out}}} when d_{\text{in}} > d_{\text{out}}.
Denote \Delta \boldsymbol{W} = \eta\alpha \boldsymbol{\Phi}. Our goal is to find the relationship between \alpha and d_{\text{in}}, d_{\text{out}}. From “Why Do We Prefer Isotropy? An Understanding Based on Steepest Descent” we know that parameters are merely by-products of the model, and changes at the feature level may be more fundamental. Transforming \Delta \boldsymbol{W} to the feature level gives \Delta \boldsymbol{y} = \boldsymbol{x} \Delta\boldsymbol{W} = \eta\alpha \boldsymbol{x}\boldsymbol{\Phi}, thus \Vert\Delta \boldsymbol{y}\Vert_{\mathrm{RMS}} = \eta\alpha \Vert\boldsymbol{x}\boldsymbol{\Phi}\Vert_{\mathrm{RMS}}.
Next we discuss case by case. First, when d_{\text{in}} \leq d_{\text{out}}, \boldsymbol{\Phi} can be written as \boldsymbol{U}[\boldsymbol{I}_{d_{\text{in}}}, \boldsymbol{0}_{d_{\text{in}}\times (d_{\text{out}}-d_{\text{in}})}]\boldsymbol{V}^{\top}, where \boldsymbol{U}\in\mathbb{R}^{d_{\text{in}}\times d_{\text{in}}}, \boldsymbol{V}\in\mathbb{R}^{d_{\text{out}}\times d_{\text{out}}} are orthogonal matrices. Then \begin{aligned} \Vert\Delta \boldsymbol{y}\Vert_{\mathrm{RMS}} &= \eta\alpha\Big\Vert\boldsymbol{x}\boldsymbol{U}[\boldsymbol{I}_{d_{\text{in}}}, \boldsymbol{0}_{d_{\text{in}}\times (d_{\text{out}}-d_{\text{in}})}]\boldsymbol{V}^{\top}\Big\Vert_{\mathrm{RMS}} \\ &= \eta\alpha\Big\Vert\boldsymbol{x}\boldsymbol{U}[\boldsymbol{I}_{d_{\text{in}}}, \boldsymbol{0}_{d_{\text{in}}\times (d_{\text{out}}-d_{\text{in}})}]\Big\Vert_{\mathrm{RMS}} \\ &= \eta\alpha\Big\Vert[\boldsymbol{x}\boldsymbol{U}, \boldsymbol{0}_{d_{\text{out}}-d_{\text{in}}}]\Big\Vert_{\mathrm{RMS}} \\ &= \eta\alpha\sqrt{\frac{d_{\text{in}}}{d_{\text{out}}}}\Vert\boldsymbol{x}\boldsymbol{U}\Vert_{\mathrm{RMS}} \\ &= \eta\alpha\sqrt{\frac{d_{\text{in}}}{d_{\text{out}}}}\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}} \end{aligned} Note that all steps are equalities, so we only need to set \alpha = \sqrt{d_{\text{out}}/d_{\text{in}}} to make every \Delta \boldsymbol{y} have RMS equal to \eta\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}}, i.e., uniform relative update magnitude across all tokens.
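The equality chain above can be checked numerically. The following is a minimal NumPy sketch, assuming msign is computed from a reduced SVD (valid for a full-rank matrix); the helper names `rms` and `msign` and the specific sizes are choices made for this illustration only.

```python
import numpy as np


def rms(v):
    # Root mean square over the last axis.
    return np.sqrt(np.mean(np.square(v), axis=-1))


def msign(m):
    # For full-rank m = U S V^T (reduced SVD), msign(m) = U V^T.
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


rng = np.random.default_rng(0)
d_in, d_out, eta = 64, 256, 0.01  # the d_in <= d_out case
phi = msign(rng.standard_normal((d_in, d_out)))
alpha = np.sqrt(d_out / d_in)  # MuP / KellerJordan scale for this shape

# Deliberately anisotropic inputs: per-coordinate scales span four orders
# of magnitude, to show that no isotropy assumption is needed in this case.
x = rng.standard_normal((1000, d_in)) * np.logspace(-2, 2, d_in)

delta_y = eta * alpha * (x @ phi)
ratio = rms(delta_y) / rms(x)  # relative update magnitude, one per token

if __name__ == "__main__":
    print(ratio.min(), ratio.max())  # both should match eta up to rounding
```

Because every step in the derivation is an equality, the ratio is the same for every row of `x`, however the inputs are distributed.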
Isotropy
Unfortunately, for the second case d_{\text{in}} > d_{\text{out}}, the goal of “complete uniformity” cannot be achieved. Specifically, now the SVD of \boldsymbol{\Phi} must be written as \boldsymbol{U}\begin{bmatrix}\boldsymbol{I}_{d_{\text{out}}} \\ \boldsymbol{0}_{(d_{\text{in}}-d_{\text{out}})\times d_{\text{out}}}\end{bmatrix}\boldsymbol{V}^{\top}, hence \begin{aligned} \Vert\Delta \boldsymbol{y}\Vert_{\mathrm{RMS}} &= \eta\alpha\Big\Vert\boldsymbol{x}\boldsymbol{U}\begin{bmatrix}\boldsymbol{I}_{d_{\text{out}}} \\ \boldsymbol{0}_{(d_{\text{in}}-d_{\text{out}})\times d_{\text{out}}}\end{bmatrix}\boldsymbol{V}^{\top}\Big\Vert_{\mathrm{RMS}} \\ &= \eta\alpha\Big\Vert\boldsymbol{x}\boldsymbol{U}\begin{bmatrix}\boldsymbol{I}_{d_{\text{out}}} \\ \boldsymbol{0}_{(d_{\text{in}}-d_{\text{out}})\times d_{\text{out}}}\end{bmatrix}\Big\Vert_{\mathrm{RMS}} \\ &= \eta\alpha\big\Vert(\boldsymbol{x}\boldsymbol{U})_{[:d_{\text{out}}]}\big\Vert_{\mathrm{RMS}} \end{aligned} \boldsymbol{x}\boldsymbol{U} is a d_{\text{in}}-dimensional vector, and d_{\text{in}} > d_{\text{out}}, so (\boldsymbol{x}\boldsymbol{U})_{[:d_{\text{out}}]} only takes the first d_{\text{out}} dimensions of \boldsymbol{x}\boldsymbol{U} to compute RMS. Its RMS is uncertain; the maximum can be \sqrt{d_{\text{in}}/d_{\text{out}}}\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}} (worst case), and the minimum can be 0.
We know that orthogonal matrices do not change RMS, so \Vert\boldsymbol{x}\boldsymbol{U}\Vert_{\mathrm{RMS}} = \Vert\boldsymbol{x}\Vert_{\mathrm{RMS}}. When the distribution of \boldsymbol{x} is sufficiently isotropic, we can assume that the average scale of each component of \boldsymbol{x}\boldsymbol{U} is \Vert\boldsymbol{x}\Vert_{\mathrm{RMS}}, so taking the first d_{\text{out}} components and computing their RMS also yields approximately \Vert\boldsymbol{x}\Vert_{\mathrm{RMS}} on average, i.e., \Vert\Delta \boldsymbol{y}\Vert_{\mathrm{RMS}} \approx \eta\alpha\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}}. Therefore, we only need to set \alpha = 1 to obtain \Vert\Delta \boldsymbol{y}\Vert_{\mathrm{RMS}} \approx \eta\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}}, the same effect as in the previous section, though now only on average.
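The three regimes of the d_in > d_out case (isotropic average, maximum, and minimum) can be illustrated with a small NumPy sketch. As before, `rms` and `msign` are helper names assumed for this example, msign is computed from a reduced SVD, and the Gaussian inputs are just one convenient stand-in for an isotropic distribution.

```python
import numpy as np


def rms(v):
    # Root mean square over the last axis.
    return np.sqrt(np.mean(np.square(v), axis=-1))


def msign(m):
    # For full-rank m = U S V^T (reduced SVD), msign(m) = U V^T.
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


rng = np.random.default_rng(0)
d_in, d_out = 256, 64  # the d_in > d_out case
phi = msign(rng.standard_normal((d_in, d_out)))  # orthonormal columns

# Isotropic inputs: standard Gaussian vectors (an illustrative choice).
x_iso = rng.standard_normal((1000, d_in))
ratio_iso = rms(x_iso @ phi) / rms(x_iso)

# Worst case: inputs lying entirely in the span of the columns of phi.
x_worst = rng.standard_normal((1000, d_out)) @ phi.T
ratio_worst = rms(x_worst @ phi) / rms(x_worst)

# Opposite extreme: inputs orthogonal to every column of phi.
z = rng.standard_normal((1000, d_in))
x_null = z - (z @ phi) @ phi.T
ratio_null = rms(x_null @ phi) / rms(x_null)

if __name__ == "__main__":
    print("isotropic mean:", ratio_iso.mean())
    print("worst case:", ratio_worst.mean(), "bound:", np.sqrt(d_in / d_out))
    print("null case max:", ratio_null.max())
```

The ratios here are \Vert\boldsymbol{x}\boldsymbol{\Phi}\Vert_{\mathrm{RMS}} / \Vert\boldsymbol{x}\Vert_{\mathrm{RMS}}, i.e. the relative update magnitude with \eta\alpha factored out, so the isotropic average sits near 1 while the worst case reaches \sqrt{d_{\text{in}}/d_{\text{out}}}.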
Anisotropy
Combining the results of the previous two sections, we obtain \alpha = \sqrt{\max\left(1, \frac{d_{\text{out}}}{d_{\text{in}}}\right)} which is exactly the \max(1,\cdot) appearing in the KellerJordan version of Muon.
However, the conclusion of the previous section relies on the assumption that the input \boldsymbol{x} is sufficiently isotropic, which may hold approximately in the early stages of training. But as training progresses, the feature distribution gradually becomes anisotropic, concentrating on the “worst case” that maximizes \Vert\Delta \boldsymbol{y}\Vert_{\mathrm{RMS}}. At this point, the average approximation \Vert\Delta \boldsymbol{y}\Vert_{\mathrm{RMS}} \approx \eta\alpha\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}} is no longer accurate; instead, the maximum value \eta\alpha\sqrt{d_{\text{in}}/d_{\text{out}}}\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}} becomes more accurate.
In this situation, the \alpha that makes \Vert\Delta \boldsymbol{y}\Vert_{\mathrm{RMS}} \approx \eta\Vert\boldsymbol{x}\Vert_{\mathrm{RMS}} is \sqrt{d_{\text{out}}/d_{\text{in}}}, consistent with the d_{\text{in}} \leq d_{\text{out}} case, which recovers the MuP version. In other words, for the mid-to-late phase of training, the MuP version of Muon is better justified. There are two ways to handle this inconsistency. One is to always use the MuP version of Muon, which slightly slows early convergence, but the mid-to-late phase is, after all, the “most critical” part of training. The other is to change the scaling factor to \alpha = \sqrt{\max\left(\tau_t, \frac{d_{\text{out}}}{d_{\text{in}}}\right)} where \tau_t decays monotonically from 1 to 0, giving a gradual transition from the KellerJordan version to the MuP version at the cost of tuning one more schedule.
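As an illustration of the second strategy, the sketch below uses a linear decay for \tau_t. The text only requires that \tau_t decrease monotonically from 1 to 0, so the linear form, the function names `tau_linear` and `scheduled_alpha`, and the `total_steps` parameter are all assumptions made for this example.

```python
import math


def tau_linear(step: int, total_steps: int) -> float:
    """Hypothetical schedule: linear decay from 1 at step 0 to 0 at total_steps."""
    return max(0.0, 1.0 - step / total_steps)


def scheduled_alpha(d_in: int, d_out: int, step: int, total_steps: int) -> float:
    """alpha = sqrt(max(tau_t, d_out / d_in)) with the linear tau_t above."""
    return math.sqrt(max(tau_linear(step, total_steps), d_out / d_in))


if __name__ == "__main__":
    # Down-projection: alpha moves from the KellerJordan value to the MuP value.
    for step in (0, 25, 50, 75, 100):
        print(step, scheduled_alpha(4096, 1024, step, 100))
```

For d_in <= d_out the schedule has no effect, since d_out / d_in >= 1 >= \tau_t at every step. For d_in > d_out the factor reaches the MuP value as soon as \tau_t falls to d_out / d_in, which is before the end of the schedule.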
Conclusion
This article mainly explains the origin of the \max(1,\cdot) in the KellerJordan version from the perspective of the uniformity of “feature increment”.
Reprint Notice: For reprinting, please include the URL of this article: https://kexue.fm/archives/11772. For more detailed reprinting instructions, please refer to: “Science Space FAQ”.