English (unofficial) translations of posts at kexue.fm

Rethinking Learning Rate and Batch Size (Part 3): Muon

Translated by Gemini Flash 3.0 Preview. Translations may be inaccurate; please refer to the original post for anything important.

In the previous two articles, "Rethinking Learning Rate and Batch Size (Part 1): Status Quo" and "Rethinking Learning Rate and Batch Size (Part 2): Mean Field", we primarily proposed the mean-field method to simplify calculations related to learning rate and batch size. At that time, the optimizers we analyzed were SGD, SignSGD, and SoftSignSGD, and our main goal was simplification; essentially, no new conclusions were drawn.

However, in today’s "feast of optimizers," how could Muon be left out? Therefore, in this article, we will attempt to calculate the relevant conclusions for Muon to see if the relationship between its learning rate and batch size exhibits any new patterns.

Basic Notation

As is well known, the main characteristic of Muon is its non-element-wise update rule. Consequently, the element-wise calculation methods used previously in "How Should the Learning Rate Change as the Batch Size Increases?" and "How Does Adam’s Epsilon Affect the Scaling Law of Learning Rate?" are completely inapplicable. Fortunately, the mean-field method introduced in the previous article remains effective, requiring only a slight adjustment of details.

First, let us introduce some notation. Let the loss function be \mathcal{L}(\bm{W}), where \bm{W} \in \mathbb{R}^{n \times m} is a matrix (assume n \geq m). Let \bm{G} be its gradient. The gradient of a single sample is denoted as \tilde{\bm{G}}, its mean is \bm{G}, and its variance is \sigma^2. When the batch size is B, the gradient is denoted as \tilde{\bm{G}}_B; its mean remains \bm{G}, but its variance becomes \sigma^2/B. Note that the variance here is treated as a scalar \sigma^2, rather than considering the full covariance matrix as done previously.
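To make the batch-size scaling of the noise concrete, here is a minimal numpy sketch (all names and values are hypothetical, and element-wise Gaussian noise is assumed) that simulates per-sample gradient matrices with mean \bm{G} and per-element variance \sigma^2, then checks that the batch gradient \tilde{\bm{G}}_B keeps the mean \bm{G} while its per-element variance shrinks to \sigma^2/B.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, sigma, B = 8, 4, 0.5, 64      # hypothetical sizes
G = rng.standard_normal((n, m))     # the mean gradient G
trials = 2000

# Per-sample gradients: G plus element-wise noise of variance sigma^2.
samples = G + sigma * rng.standard_normal((trials, B, n, m))
G_B = samples.mean(axis=1)          # batch gradient, averaged over B samples

print(np.abs(G_B.mean(axis=0) - G).max())    # ~0: the mean is still G
print(G_B.var(axis=0).mean(), sigma**2 / B)  # per-element variance ~ sigma^2 / B
```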

The core reason for this simplification is that the random variable itself is already a matrix, so its corresponding covariance matrix would actually be a 4th-order tensor, which is cumbersome to discuss. Does simplifying it to a single scalar significantly sacrifice accuracy? In fact, it does not. Although we considered the full covariance matrix \bm{\Sigma} in the previous two articles, a closer look reveals that the final results only depend on \mathop{\mathrm{tr}}(\bm{\Sigma}), which is equivalent to simplifying it to a scalar from the beginning.

Hessian Matrix

Similarly, let the update amount be -\eta\tilde{\bm{\Phi}}_B. Consider the second-order expansion of the loss function: \begin{equation} \mathcal{L}(\bm{W} - \eta\tilde{\bm{\Phi}}_B) \approx \mathcal{L}(\bm{W}) - \eta \mathop{\mathrm{tr}}(\tilde{\bm{\Phi}}_B^{\top}\bm{G}) + \frac{1}{2}\eta^2 \mathop{\mathrm{tr}}(\tilde{\bm{\Phi}}_B^{\top}\bm{H}\tilde{\bm{\Phi}}_B) \label{eq:loss-2} \end{equation} The first two terms should be straightforward; the third term is more difficult to understand. Like the covariance matrix, the Hessian matrix \bm{H} here is a 4th-order tensor, which is complex to interpret.

The simplest entry point here is the linear operator perspective, i.e., treating \bm{H} as a linear operator where both input and output are matrices. We do not need to know what \bm{H} looks like or how \bm{H} operates with \tilde{\bm{\Phi}}_B; we only need to know that \bm{H}\tilde{\bm{\Phi}}_B is linear with respect to \tilde{\bm{\Phi}}_B. In this way, the objects we handle remain matrices, avoiding additional cognitive load. Any linear operator that satisfies the conditions can serve as an approximation of the Hessian matrix, without needing to write out the specific high-order tensor form.
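To illustrate the operator view, here is a small numpy sketch. It is purely illustrative: the map \bm{X} \mapsto \bm{A}\bm{X}\bm{C} below is just one convenient linear operator on matrices (not the actual Hessian of any loss here), chosen to show that the quadratic term \mathop{\mathrm{tr}}(\bm{X}^{\top}\bm{H}\bm{X}) can be evaluated without ever materializing a 4th-order tensor.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 6, 3
A = rng.standard_normal((n, n)); A = A @ A.T   # symmetric "row-space" factor
C = rng.standard_normal((m, m)); C = C @ C.T   # symmetric "column-space" factor

# One convenient family of matrix-to-matrix linear operators: X -> A X C.
# The true Hessian is a 4th-order tensor, but as an operator we never build it.
def hess_op(X):
    return A @ X @ C

X = rng.standard_normal((n, m))
Y = rng.standard_normal((n, m))
# Linearity: H(2X + 3Y) = 2 H(X) + 3 H(Y)
print(np.allclose(hess_op(2 * X + 3 * Y), 2 * hess_op(X) + 3 * hess_op(Y)))
# The quadratic term tr(X^T H X) from the expansion, computed operator-style
print(np.trace(X.T @ hess_op(X)))
```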

The protagonist of this article is Muon, so for the calculation we take \tilde{\bm{\Phi}}_B = \mathop{\mathrm{msign}}(\tilde{\bm{G}}_B) as an approximation of its update. By definition, we can write \mathop{\mathrm{msign}}(\tilde{\bm{G}}_B) = \tilde{\bm{G}}_B(\tilde{\bm{G}}_B^{\top}\tilde{\bm{G}}_B)^{-1/2}. From a Newton's-method perspective, this is equivalent to assuming \bm{H}^{-1}\bm{X} = \eta_{\max}\bm{X}(\bm{G}^{\top}\bm{G})^{-1/2}, which implies \bm{H}\bm{X} = \eta_{\max}^{-1}\bm{X}(\bm{G}^{\top}\bm{G})^{1/2}. This will be used in the subsequent calculations.
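As a sanity check on these definitions, the sketch below (hypothetical helper names; full rank assumed) computes \mathop{\mathrm{msign}} via the thin SVD, verifies it against the closed form \bm{G}(\bm{G}^{\top}\bm{G})^{-1/2}, and confirms that under the assumed \bm{H}^{-1} the Newton step on \bm{G} is \eta_{\max}\mathop{\mathrm{msign}}(\bm{G}).

```python
import numpy as np

def sym_sqrt(M):
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    w, V = np.linalg.eigh(M)
    return V @ np.diag(np.sqrt(np.clip(w, 0, None))) @ V.T

def msign(G):
    """msign(G) = U V^T from the thin SVD G = U S V^T (full rank assumed)."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(2)
n, m = 8, 4
G = rng.standard_normal((n, m))

# Closed form used in the text: msign(G) = G (G^T G)^{-1/2}
closed = G @ np.linalg.inv(sym_sqrt(G.T @ G))
print(np.allclose(msign(G), closed))

# Newton's-method reading: with H^{-1} X = eta_max * X (G^T G)^{-1/2}
# (eta_max is a hypothetical constant), the Newton step H^{-1} G is eta_max * msign(G).
eta_max = 0.1
hess_inv = lambda X: eta_max * X @ np.linalg.inv(sym_sqrt(G.T @ G))
print(np.allclose(hess_inv(G), eta_max * msign(G)))
```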

Calculating Expectation

Taking the expectation of both sides of Eq. \eqref{eq:loss-2}, we get: \begin{equation} \mathbb{E}[\mathcal{L}(\bm{W} - \eta\tilde{\bm{\Phi}}_B)] \approx \mathcal{L}(\bm{W}) - \eta \mathop{\mathrm{tr}}(\mathbb{E}[\tilde{\bm{\Phi}}_B]^{\top}\bm{G}) + \frac{1}{2}\eta^2\mathbb{E}[\mathop{\mathrm{tr}}(\tilde{\bm{\Phi}}_B^{\top}\bm{H}\tilde{\bm{\Phi}}_B)] \end{equation} First, calculate \mathbb{E}[\tilde{\bm{\Phi}}_B]: \begin{equation} \mathbb{E}[\tilde{\bm{\Phi}}_B] = \mathbb{E}[\tilde{\bm{G}}_B(\tilde{\bm{G}}_B^{\top}\tilde{\bm{G}}_B)^{-1/2}] \approx \mathbb{E}[\tilde{\bm{G}}_B](\mathbb{E}[\tilde{\bm{G}}_B^{\top}\tilde{\bm{G}}_B])^{-1/2} = \bm{G}(\mathbb{E}[\tilde{\bm{G}}_B^{\top}\tilde{\bm{G}}_B])^{-1/2} \end{equation} We write out \mathbb{E}[\tilde{\bm{G}}_B^{\top}\tilde{\bm{G}}_B] by components and assume independence between different components: \begin{equation} \mathbb{E}[\tilde{\bm{G}}_B^{\top}\tilde{\bm{G}}_B]_{i,j} = \mathbb{E}\left[\sum_{k=1}^n (\tilde{G}_B)_{k,i}(\tilde{G}_B)_{k,j}\right] = \left\{\begin{aligned} &\mathbb{E}\left[\sum_{k=1}^n (\tilde{G}_B)_{k,i}^2\right] = \left(\sum_{k=1}^n G_{k,i}^2\right) + n\sigma^2/B, & (i=j) \\[6pt] &\sum_{k=1}^n \mathbb{E}[(\tilde{G}_B)_{k,i}] \mathbb{E}[(\tilde{G}_B)_{k,j}] = \sum_{k=1}^n G_{k,i}G_{k,j}, & (i\neq j) \end{aligned}\right. \end{equation} Combining these, we have \mathbb{E}[\tilde{\bm{G}}_B^{\top}\tilde{\bm{G}}_B] = \bm{G}^{\top}\bm{G} + (n\sigma^2/B) \bm{I}, so: \begin{equation} \mathbb{E}[\tilde{\bm{\Phi}}_B] \approx \bm{G}(\bm{G}^{\top}\bm{G} + (n\sigma^2/B) \bm{I})^{-1/2} = \mathop{\mathrm{msign}}(\bm{G})(\bm{I} + (n\sigma^2/B) (\bm{G}^{\top}\bm{G})^{-1})^{-1/2} \end{equation} To further simplify the dependency on B, we approximate \bm{G}^{\top}\bm{G} with \mathop{\mathrm{tr}}(\bm{G}^{\top}\bm{G})\bm{I}/m, which means keeping only the diagonal part of \bm{G}^{\top}\bm{G} and then replacing the diagonal elements with their average. Thus, we obtain: \begin{equation} \mathbb{E}[\tilde{\bm{\Phi}}_B] \approx \mathop{\mathrm{msign}}(\bm{G})(1 + \mathcal{B}_{\text{simple}}/B)^{-1/2} \end{equation} where \mathcal{B}_{\text{simple}} = mn\sigma^2/\mathop{\mathrm{tr}}(\bm{G}^{\top}\bm{G}) = mn\sigma^2/\|\bm{G}\|_F^2. This is actually the same as treating \bm{G} as a vector and calculating \mathcal{B}_{\text{simple}} as in the previous two articles. The form of the above equation is identical to that of SignSGD. From this, we can guess that Muon will not present many new results regarding the relationship between learning rate and batch size.
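The mean-field result \mathbb{E}[\tilde{\bm{\Phi}}_B] \approx \mathop{\mathrm{msign}}(\bm{G})(1 + \mathcal{B}_{\text{simple}}/B)^{-1/2} can be spot-checked numerically. The sketch below (hypothetical sizes; element-wise Gaussian noise assumed) estimates \mathbb{E}[\mathop{\mathrm{msign}}(\tilde{\bm{G}}_B)] by Monte Carlo and compares its overlap with \mathop{\mathrm{msign}}(\bm{G}) to the predicted shrinkage factor; the two should roughly agree when the approximation is reasonable.

```python
import numpy as np

def msign(G):
    """msign(G) = U V^T from the thin SVD of G (full rank assumed)."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(3)
n, m, sigma, B = 16, 4, 1.0, 8               # hypothetical sizes
G = rng.standard_normal((n, m))
B_simple = m * n * sigma**2 / np.sum(G**2)   # mn sigma^2 / ||G||_F^2

trials = 5000
Phi = np.zeros((n, m))
for _ in range(trials):
    # Batch gradient: G plus element-wise noise of variance sigma^2 / B.
    G_B = G + sigma / np.sqrt(B) * rng.standard_normal((n, m))
    Phi += msign(G_B)
Phi /= trials                                # Monte-Carlo estimate of E[msign(G_B)]

pred_scale = 1.0 / np.sqrt(1.0 + B_simple / B)
emp_scale = np.trace(Phi.T @ msign(G)) / m   # average overlap with msign(G)
print(emp_scale, pred_scale)                 # should be roughly comparable
```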

Same Patterns

As for \mathbb{E}[\mathop{\mathrm{tr}}(\tilde{\bm{\Phi}}_B^{\top}\bm{H}\tilde{\bm{\Phi}}_B)], we only calculate it under the Hessian assumption corresponding to Muon derived earlier, namely \bm{H}\bm{X} = \eta_{\max}^{-1}\bm{X}(\bm{G}^{\top}\bm{G})^{1/2}. Then: \begin{equation} \mathop{\mathrm{tr}}(\tilde{\bm{\Phi}}_B^{\top}\bm{H}\tilde{\bm{\Phi}}_B) = \eta_{\max}^{-1}\mathop{\mathrm{tr}}(\tilde{\bm{\Phi}}_B^{\top}\tilde{\bm{\Phi}}_B(\bm{G}^{\top}\bm{G})^{1/2}) \end{equation} Note that \tilde{\bm{\Phi}}_B is the result of \mathop{\mathrm{msign}}, so its columns are orthonormal (assuming \tilde{\bm{G}}_B has full rank), which means \tilde{\bm{\Phi}}_B^{\top}\tilde{\bm{\Phi}}_B = \bm{I}. In this case, \mathop{\mathrm{tr}}(\tilde{\bm{\Phi}}_B^{\top}\bm{H}\tilde{\bm{\Phi}}_B) is the fixed constant \eta_{\max}^{-1}\mathop{\mathrm{tr}}((\bm{G}^{\top}\bm{G})^{1/2}) = \eta_{\max}^{-1}\mathop{\mathrm{tr}}(\mathop{\mathrm{msign}}(\bm{G})^{\top}\bm{G}). Minimizing the expected second-order expansion with respect to \eta, we thus obtain: \begin{equation} \eta^* \approx \frac{\mathop{\mathrm{tr}}(\mathbb{E}[\tilde{\bm{\Phi}}_B]^{\top}\bm{G})}{\mathbb{E}[\mathop{\mathrm{tr}}(\tilde{\bm{\Phi}}_B^{\top}\bm{H}\tilde{\bm{\Phi}}_B)]} \approx \frac{\eta_{\max}}{\sqrt{1 + \mathcal{B}_{\text{simple}}/B}} \end{equation} As expected, the form is exactly the same as the result for SignSGD, with no new patterns.
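To see what this scaling rule implies in practice, the toy sketch below (eta_max and B_simple are hypothetical values, chosen only for illustration) tabulates \eta^*(B): for B \ll \mathcal{B}_{\text{simple}} it grows like \sqrt{B}, and for B \gg \mathcal{B}_{\text{simple}} it saturates at \eta_{\max}, exactly as for SignSGD.

```python
import numpy as np

eta_max, B_simple = 0.1, 4096.0   # hypothetical values, for illustration only

def eta_star(B):
    """Predicted optimal learning rate: eta_max / sqrt(1 + B_simple / B)."""
    return eta_max / np.sqrt(1.0 + B_simple / B)

for B in [64, 256, 1024, 4096, 16384, 65536]:
    print(f"B = {B:6d}   eta* = {eta_star(B):.4f}")
# Small B:  eta* ~ eta_max * sqrt(B / B_simple)   (square-root scaling)
# Large B:  eta* -> eta_max                       (saturation; no decrease for this H)
```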

Upon reflection, this is reasonable. SignSGD applies \mathop{\mathrm{sign}} directly to the gradient, while Muon's \mathop{\mathrm{msign}} applies \mathop{\mathrm{sign}} to the singular values; intuitively, this is just applying \mathop{\mathrm{sign}} in a different coordinate system. It yields a new matrix update rule, but the learning rate \eta^* and the batch size B are just scalars, and since the core of both optimizers is \mathop{\mathrm{sign}}, it is unsurprising that the asymptotic relationship between these scalars does not change significantly.

Of course, we have only calculated for a specific \bm{H}. If a more general \bm{H} is considered, it is possible that, like SignSGD, a "Surge" phenomenon might occur where "as the batch size increases, the learning rate should instead decrease." However, as mentioned in the "Reflections on Causes" section of the previous article, if a Surge phenomenon is truly observed, it might be more appropriate to change the optimizer rather than correcting the relationship between \eta^* and B.

Summary

In this article, we attempted a simple analysis of Muon using the mean-field approximation. The conclusion is that its relationship between learning rate and batch size is consistent with SignSGD, with no new patterns discovered.

When reposting, please include the original address of this article: https://kexue.fm/archives/11285
