English (unofficial) translations of posts at kexue.fm

A First Look at MuP: Cross-Model Scaling Laws for Hyperparameter Transfer

Translated by Gemini Flash 3.0 Preview. Translations can be inaccurate; please refer to the original post for anything important.

As is well known, training a large LLM (Large Language Model) from scratch is extremely expensive, so we cannot afford to tune hyperparameters through repeated trial runs on the large model itself. A natural idea is to search carefully for hyperparameters on a small model with the same architecture and then transfer the optimal combination directly to the large model. Simple as this idea is, implementing it is non-trivial: it requires understanding how the optimal values of common hyperparameters scale with model size. MuP is a practical implementation of exactly this idea.

MuP, sometimes written as \mu P, stands for Maximal Update Parametrization. It originates from the paper "Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer". As LLM training has become widespread, it has gradually become one of the de facto standards for scientific model training.

Main Idea

Before diving in, I must complain that the original MuP paper is written rather obscurely and its conclusions are not stated clearly, which makes it considerably harder to understand. In what follows, I will therefore try to reproduce MuP's conclusions in a way that I (personally) find concise and clear.

To state the conclusion first: MuP primarily studies the transfer laws of hyperparameters across model scales. There are several keywords here:

  1. Hyperparameters: Currently, this mainly refers to the learning rate.

  2. Model Scale: Currently, this mainly refers to the model width.

  3. Transfer: this is the core of the whole idea.

Please note that MuP does not study what the optimal hyperparameters are; it only studies the scaling laws of the optimal hyperparameters as the model scale changes. Therefore, we need to search for the optimal hyperparameter combination on a small model and then transfer it to a large model. This is the use case and methodology of MuP.

The principle for deriving MuP is to ensure that the forward propagation, backward propagation, loss increment, and feature variation of the model do not change significantly with the model scale:

  • Concretely, we analyze orders of magnitude at initialization and assume that the resulting conclusions continue to describe the subsequent optimization.

  • Simply put, it assumes that if the initialization is done correctly, the subsequent process will automatically follow the correct trajectory (a good start is half the battle?).

  • Of course, one could justify this assumption by appealing to the Law of Large Numbers or the Central Limit Theorem, but I personally do not think that is strictly necessary.

Forward Propagation

We begin our discussion with forward propagation, as it is a relatively simple and mature part. First, consider a linear layer \boldsymbol{Y}=\boldsymbol{X}\boldsymbol{W}, where \boldsymbol{X}\in\mathbb{R}^{b\times d_{in}} and \boldsymbol{W}\in\mathbb{R}^{d_{in}\times d_{out}}. We use RMS (Root Mean Square) as an indicator of matrix scale, for example: \begin{equation} \text{RMS}(\boldsymbol{W}) = \sqrt{\frac{1}{d_{in} d_{out}}\sum_{i=1}^{d_{in}} \sum_{j=1}^{d_{out}} W_{i,j}^2} \end{equation}

It is well known that to keep the RMS of \boldsymbol{Y} roughly equal to the RMS of \boldsymbol{X} at initialization (a property we call "stable"), \boldsymbol{W} should use:

LeCun Initialization: Random initialization with "mean 0 and variance 1/d_{in}".

This is already one of the fundamental conclusions in deep learning, so we will not expand on its derivation. Readers who are not familiar with it can refer to previous blog posts such as "Understanding Model Parameter Initialization Strategies from a Geometric Perspective" and "A Brief Discussion on Initialization, Parameterization, and Normalization in Transformers".

Next, we consider a non-linear layer \boldsymbol{Y}=\phi(\boldsymbol{X}\boldsymbol{W}), where \phi is an element-wise activation function. If we still want to maintain the RMS of \boldsymbol{X} approximately equal to the RMS of \boldsymbol{Y}, the result will be slightly different. For example, with \text{relu} activation, we get:

Kaiming Initialization: Random initialization with "mean 0 and variance 2/d_{in}".

Compared with LeCun initialization, Kaiming initialization differs only by a constant factor of 2 in the variance, and this factor does not depend on model scale. One can show that other activation functions likewise change only this constant. Thus, we can draw a conclusion:

fan_in Initialization: To ensure the stability of forward propagation, one should use random initialization with "mean 0 and variance proportional to 1/d_{in}".

This conclusion can also be understood as "the influence of the activation function is independent of the model scale." Therefore, if we only want to analyze the effect of model scale, we can ignore the existence of (element-wise) activation functions and directly obtain the scaling law \propto 1/d_{in} from LeCun initialization.
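A quick numerical check of the fan_in rule, written as a sketch in plain Python (standard library only, so it is slow but self-contained; all helper names are my own). It draws a unit-RMS input \boldsymbol{X}, applies one layer with weights of variance gain/d_{in}, and reports the output RMS for several input widths. For contrast, it also tries variance 1 with no fan_in scaling.

```python
import math
import random


def randn_matrix(rng, rows, cols, std):
    """rows x cols matrix (list of lists) with i.i.d. N(0, std^2) entries."""
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    """Plain (n x k) @ (k x m) product."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def rms(m):
    """Root mean square over all entries of a matrix."""
    count = sum(len(r) for r in m)
    return math.sqrt(sum(x * x for r in m for x in r) / count)


def relu(m):
    return [[max(x, 0.0) for x in r] for r in m]


def output_rms(d_in, d_out, gain, activation=None, b=16, seed=0):
    """RMS of phi(X W) for unit-RMS X and W ~ N(0, gain / d_in)."""
    rng = random.Random(seed)
    x = randn_matrix(rng, b, d_in, 1.0)
    w = randn_matrix(rng, d_in, d_out, math.sqrt(gain / d_in))
    y = matmul(x, w)
    return rms(activation(y) if activation else y)


for d in (64, 256, 1024):
    lecun = output_rms(d, 64, gain=1.0)                     # linear layer, variance 1/d_in
    kaiming = output_rms(d, 64, gain=2.0, activation=relu)  # relu layer, variance 2/d_in
    naive = output_rms(d, 64, gain=float(d))                # variance 1, no fan_in scaling
    print(f"d_in={d:5d}  LeCun {lecun:.2f}  Kaiming+relu {kaiming:.2f}  unscaled {naive:.1f}")
```

The first two columns should stay close to 1 at every width, while the unscaled column grows roughly like \sqrt{d_{in}}.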

Backward Propagation

Now we turn to backward propagation (gradients). Here we adopt the convention that a gradient has the same shape as the variable it is taken with respect to. We can compute: \begin{align} \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}} =&\, \boldsymbol{X}^{\top}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\otimes \phi'(\boldsymbol{X}\boldsymbol{W})\right) \\[5pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{X}} =&\, \left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\otimes \phi'(\boldsymbol{X}\boldsymbol{W})\right)\boldsymbol{W}^{\top} \end{align} The first formula is the gradient of the current layer's parameters, and the second is the gradient propagated back to the previous layer. \otimes denotes the Hadamard (element-wise) product, and \phi' is the derivative of \phi.
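The two gradient formulas can be checked against central finite differences. This is a minimal sketch under illustrative assumptions: \phi=\tanh, and a toy loss \mathcal{L}=\frac{1}{b}\langle\boldsymbol{C},\phi(\boldsymbol{X}\boldsymbol{W})\rangle_F with a fixed random \boldsymbol{C}, so that \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}=\boldsymbol{C}/b. The loss and all helper names are hypothetical, chosen only to exercise the formulas.

```python
import math
import random


def matmul(a, b):
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def toy_loss(x, w, c):
    """Hypothetical loss L = <C, tanh(X W)>_F / b, so that dL/dY = C / b."""
    y = matmul(x, w)
    return sum(ci * math.tanh(yi) for rc, ry in zip(c, y) for ci, yi in zip(rc, ry)) / len(x)


def analytic_grads(x, w, c):
    """dL/dW = X^T (dL/dY ⊗ phi'(XW)) and dL/dX = (dL/dY ⊗ phi'(XW)) W^T."""
    b = len(x)
    pre = matmul(x, w)
    delta = [[(ci / b) * (1.0 - math.tanh(p) ** 2) for ci, p in zip(rc, rp)]
             for rc, rp in zip(c, pre)]
    return matmul(transpose(x), delta), matmul(delta, transpose(w))


def numeric_grad(f, m, eps=1e-6):
    """Central differences of the scalar f() w.r.t. each entry of m (m is perturbed in place)."""
    grad = [[0.0] * len(m[0]) for _ in m]
    for i, row in enumerate(m):
        for j, old in enumerate(row):
            row[j] = old + eps
            up = f()
            row[j] = old - eps
            down = f()
            row[j] = old  # restore the entry
            grad[i][j] = (up - down) / (2 * eps)
    return grad


def max_abs_diff(a, b):
    return max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))


rng = random.Random(0)
b, d_in, d_out = 3, 4, 5
x = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(b)]
w = [[rng.gauss(0, 1 / math.sqrt(d_in)) for _ in range(d_out)] for _ in range(d_in)]
c = [[rng.gauss(0, 1) for _ in range(d_out)] for _ in range(b)]

gw, gx = analytic_grads(x, w, c)
err_w = max_abs_diff(gw, numeric_grad(lambda: toy_loss(x, w, c), w))
err_x = max_abs_diff(gx, numeric_grad(lambda: toy_loss(x, w, c), x))
print(f"max error dL/dW: {err_w:.1e}, dL/dX: {err_x:.1e}")
```

Both errors should be tiny (on the order of finite-difference round-off), confirming the shapes and the placement of \boldsymbol{X}^{\top} and \boldsymbol{W}^{\top}.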

For the activation functions we commonly use, the derivative is bounded by a (scale-independent) constant. Therefore, at least in terms of order of magnitude, we can write: \begin{align} \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}} \sim&\, \boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} \label{eq:grad-w}\\[5pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{X}} \sim&\, \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\boldsymbol{W}^{\top}\label{eq:grad-x} \end{align} Look at the second formula: compared with \boldsymbol{Y}=\boldsymbol{X}\boldsymbol{W}, the matrix multiplied on the right is now \boldsymbol{W}^{\top}. By the conclusion of the previous section, to keep the RMS stable in backward propagation, \boldsymbol{W} should be initialized with:

fan_out Initialization: Random initialization with "mean 0 and variance 1/d_{out}".

When d_{in} \neq d_{out}, the requirements of forward and backward propagation conflict, and a compromise was proposed:

Xavier Initialization: Random initialization with "mean 0 and variance 2/(d_{in} + d_{out})".

This is also called "fan_avg initialization" because it simply takes the arithmetic mean of d_{in} and d_{out}; other averages are also possible, see "Thinking on Dimension Averaging Strategies for Non-Square Matrices in Initialization Methods". Xavier initialization seems to account for both forward and backward propagation, but one could equally say it accounts for neither. A better approach is to design the model so that most parameters are square matrices, as in the model family of Equation [eq:model] introduced below.
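To see the forward/backward conflict concretely, the sketch below (standard-library Python; helper names are my own) draws a non-square \boldsymbol{W} with d_{in}=64, d_{out}=256 under the three schemes and measures how much each pass multiplies the RMS: forward via \boldsymbol{X}\boldsymbol{W}, backward via \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\boldsymbol{W}^{\top}, with \phi' dropped as in the order-of-magnitude formulas above and random unit-RMS entries standing in for \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}.

```python
import math
import random


def randn_matrix(rng, rows, cols, std):
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def rms(m):
    return math.sqrt(sum(x * x for r in m for x in r) / sum(len(r) for r in m))


def rms_gains(d_in, d_out, var, b=16, seed=0):
    """(RMS(XW)/RMS(X), RMS(G W^T)/RMS(G)) for W ~ N(0, var) and unit-RMS X, G."""
    rng = random.Random(seed)
    x = randn_matrix(rng, b, d_in, 1.0)
    g = randn_matrix(rng, b, d_out, 1.0)   # stand-in for dL/dY (assumed random)
    w = randn_matrix(rng, d_in, d_out, math.sqrt(var))
    forward = rms(matmul(x, w)) / rms(x)
    backward = rms(matmul(g, transpose(w))) / rms(g)
    return forward, backward


d_in, d_out = 64, 256
schemes = {
    "fan_in": 1 / d_in,
    "fan_out": 1 / d_out,
    "fan_avg (Xavier)": 2 / (d_in + d_out),
}
for name, var in schemes.items():
    fwd, bwd = rms_gains(d_in, d_out, var)
    print(f"{name:>16}: forward x{fwd:.2f}, backward x{bwd:.2f}")
```

With these shapes one expects roughly forward ×1 and backward ×2 for fan_in, ×0.5 and ×1 for fan_out, and about ×0.63 and ×1.26 for Xavier: each scheme fixes at most one direction.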

Loss Increment

With the groundwork of forward and backward propagation, we can attempt to analyze the increment of the loss function. Consider the change in the loss function when \boldsymbol{W} \to \boldsymbol{W} + \Delta\boldsymbol{W}: \begin{equation} \Delta \mathcal{L} = \mathcal{L}(\boldsymbol{W} + \Delta\boldsymbol{W}) - \mathcal{L}(\boldsymbol{W})\approx \left\langle\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}, \Delta\boldsymbol{W}\right\rangle_F \end{equation} Here \langle\cdot,\cdot\rangle_F is the Frobenius inner product, which is the inner product of the matrices after flattening them into vectors. Considering gradient descent \Delta\boldsymbol{W} = -\eta \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}, where \eta is naturally the learning rate, and combining this with Equation [eq:grad-w], we have: \begin{equation} \Delta \mathcal{L}\approx -\eta\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_F^2\sim -\eta \left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2 \end{equation} In fact, this equation already tells us why the same learning rate \eta cannot be used across different model scales:

  1. \boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} is a d_{in}\times d_{out} matrix;

  2. \left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2 is the sum of squares of d_{in}\times d_{out} numbers;

  3. \boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} is exactly the product of forward and backward terms;

  4. If both forward and backward are stable, then each element of \boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} is \Theta(1) (\Theta is the Big Theta Notation);

  5. Therefore, \left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2 is \Theta(d_{in} d_{out}).

Point 4 deserves more comment. \boldsymbol{X}^{\top} is a d_{in}\times b matrix and \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} is a b\times d_{out} matrix, so their product consists of d_{in} d_{out} inner products of b-dimensional vectors. Each inner product is a sum of b terms, but since the loss \mathcal{L} is usually averaged over samples, \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} carries a factor 1/b that turns this sum into an average. Thus, as long as the entries of \boldsymbol{X} and b\,\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} do not depend on the model scale, neither do the entries of their product (i.e., its RMS is \Theta(1)).

Point 5 means that if we directly reuse the learning rate of a small model on a large model, the per-step loss increment grows with the parameter count d_{in} d_{out} and, for a sufficiently large model, explodes. The convergence behavior of the small model then cannot be replicated, and training may even diverge because the steps are too large.
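The growth in point 5 is easy to reproduce. In the sketch below (standard-library Python; names are mine), \boldsymbol{X} and b\,\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} are assumed to have random unit-RMS entries. Each entry of the gradient \boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} then keeps the same typical size at every width, while \Vert\cdot\Vert_F^2, and hence |\Delta\mathcal{L}| at a fixed \eta, grows roughly like d_{in}d_{out}.

```python
import random


def randn_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def grad_sq_norm(d_in, d_out, b=8, seed=0):
    """||X^T dL/dY||_F^2 with dL/dY = G / b (batch-mean loss) and unit-RMS X, G."""
    rng = random.Random(seed)
    x = randn_matrix(rng, b, d_in)
    g = randn_matrix(rng, b, d_out)
    total = 0.0
    for a in range(d_in):
        for c in range(d_out):
            entry = sum(x[i][a] * g[i][c] for i in range(b)) / b
            total += entry * entry
    return total


eta = 0.1  # one fixed learning rate reused at every width
for d in (64, 128, 256):
    sq = grad_sq_norm(d, d)
    print(f"d={d:4d}  per-entry mean square {sq / d**2:.3f}  eta*||grad||_F^2 = {eta * sq:.0f}")
```

The per-entry column should hover around 1/b, while the last column should roughly quadruple each time d doubles.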

At this point, one might think of setting \eta\propto 1/(d_{in} d_{out}) so that \Delta\mathcal{L} stays \Theta(1). This idea is already in the spirit of MuP, but in practice, because forward and backward stability cannot both be guaranteed (as discussed earlier), Point 4 does not always hold, so the actual situation is more complex.

Model Assumption

Now let us consider a scenario closer to practice. Our task is to train a model \mathbb{R}^{d_{in}}\mapsto \mathbb{R}^{d_{out}}, where d_{in}, d_{out} are determined by the data and cannot be changed. As stated at the beginning, MuP aims to study the scaling laws of hyperparameters with model scale, so all fixed quantities are treated as constants or \Theta(1). For example, an initialization variance of 1/d_{in} is equivalent to saying the variance is \Theta(1).

In principle we could vary the architecture, the number of parameters, and so on, but MuP primarily studies how things scale with width, so we fix the architecture. The model family considered here is: \begin{equation} \begin{gathered} \boldsymbol{Y}_{in} = \boldsymbol{X} \boldsymbol{W}_{in} \\[5pt] \boldsymbol{Y}_{out} = \text{NN}(\boldsymbol{Y}_{in},\boldsymbol{\Omega}) \\[5pt] \boldsymbol{Z} = \boldsymbol{Y}_{out} \boldsymbol{W}_{out} \end{gathered}\label{eq:model} \end{equation} where:

  1. \boldsymbol{X}\in\mathbb{R}^{b\times d_{in}} (including batch size);

  2. \boldsymbol{W}_{in} \in \mathbb{R}^{d_{in}\times d}, \boldsymbol{W}_{out} \in \mathbb{R}^{d\times d_{out}};

  3. \text{NN} is any \mathbb{R}^d\mapsto \mathbb{R}^d neural network;

  4. d is what we usually call the hidden size;

  5. We can freely increase d to improve the model’s parameter count and potential;

  6. MuP wants to study the scaling laws of hyperparameters with respect to d.

More specifically, we consider \text{NN} to be a K-layer MLP: \begin{equation} \begin{aligned} \boldsymbol{Y}_0 =&\, \boldsymbol{Y}_{in} \\[5pt] \boldsymbol{Y}_{k+1} =&\, \phi(\boldsymbol{Y}_k \boldsymbol{W}_{k+1}) \\[5pt] \boldsymbol{Y}_{out} =&\, \boldsymbol{Y}_K \end{aligned} \end{equation} Here \boldsymbol{\Omega}=\{\boldsymbol{W}_1,\boldsymbol{W}_2,\cdots,\boldsymbol{W}_K\}, and every \boldsymbol{W}_k\in\mathbb{R}^{d\times d} is a square matrix using fan_in initialization (which, for square matrices, coincides with fan_out initialization).
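As a sanity check on this model family, the sketch below (standard-library Python; helper names are my own) runs a forward pass with \boldsymbol{W}_{in} and every \boldsymbol{W}_k fan_in-initialized, assuming \phi=\text{relu} with the Kaiming gain of 2, and reports the RMS after each layer for two widths. \boldsymbol{W}_{out} is left out because its initialization is exactly what the following sections determine.

```python
import math
import random


def randn_matrix(rng, rows, cols, std):
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def rms(m):
    return math.sqrt(sum(x * x for r in m for x in r) / sum(len(r) for r in m))


def layer_rms(d, d_in=8, depth=3, b=8, seed=0):
    """RMS of Y_in, Y_1, ..., Y_K for Y_in = X W_in and Y_{k+1} = relu(Y_k W_{k+1})."""
    rng = random.Random(seed)
    x = randn_matrix(rng, b, d_in, 1.0)
    y = matmul(x, randn_matrix(rng, d_in, d, math.sqrt(1.0 / d_in)))  # W_in: variance 1/d_in
    history = [rms(y)]
    for _ in range(depth):
        w = randn_matrix(rng, d, d, math.sqrt(2.0 / d))               # W_k: variance 2/d (relu)
        y = [[max(v, 0.0) for v in row] for row in matmul(y, w)]
        history.append(rms(y))
    return history


for d in (128, 512):
    print(d, [round(r, 2) for r in layer_rms(d)])
```

Every entry should stay around 1, up to finite-width noise, at both widths and all depths.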

To clarify, assuming that all parameter matrices are d\times d square matrices is purely to simplify the analysis. What is really needed is that every dimension of every parameter in \text{NN} grows in proportion to d, i.e., no dimension stays fixed as the width changes. For example, a shape like d\times 64 is not allowed because 64 is a constant, but d\times 4d is allowed, because under fan_in, fan_out, and fan_avg initialization alike, the variance is proportional to 1/d.
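The shape condition can be checked mechanically: for a parameter of shape n\times m (fan_in n, fan_out m), compute the fan_in, fan_out, and fan_avg variances and see whether variance × d stays constant as d grows. The helper below is a hypothetical illustration of that test, not part of MuP itself.

```python
def init_variances(fan_in, fan_out):
    """Initialization variance of a (fan_in x fan_out) matrix under three common schemes."""
    return {
        "fan_in": 1 / fan_in,
        "fan_out": 1 / fan_out,
        "fan_avg": 2 / (fan_in + fan_out),
    }


def variance_scales_as_one_over_d(shape_of, widths=(256, 1024, 4096)):
    """True if variance * d is width-independent under every scheme for shape_of(d)."""
    for scheme in ("fan_in", "fan_out", "fan_avg"):
        scaled = [init_variances(*shape_of(d))[scheme] * d for d in widths]
        if max(scaled) / min(scaled) > 1 + 1e-9:
            return False
    return True


print(variance_scales_as_one_over_d(lambda d: (d, 4 * d)))  # True: allowed
print(variance_scales_as_one_over_d(lambda d: (d, 64)))     # False: fan_out variance stays 1/64
```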

Assembling Everything

After establishing the specific model, we can assemble the previous conclusions. The parameters to be updated are divided into three parts: \boldsymbol{W}_{in}, \boldsymbol{\Omega}, \boldsymbol{W}_{out}. We calculate the gradients for each: \begin{align} \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}} =&\, \boldsymbol{Y}_{out}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}} \\[6pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k} =&\, \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} \cdot\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{out}} = \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} \cdot\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}\right) \\[6pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}} =&\, \boldsymbol{X}^{\top} \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{in}} = \boldsymbol{X}^{\top} \left(\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}\cdot\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}\right)\right) \end{align}

The \cdot operation here requires a brief explanation: \boldsymbol{Y}_{in}, \boldsymbol{Y}_{out} are matrices, so \frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}} is in principle a fourth-order tensor. The chain rule \frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}\cdot\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{out}} is actually a high-order tensor multiplication. We won’t expand on it here; just know it is a generalization of matrix multiplication.

Now observe the patterns:

  1. All three formulas contain \frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}};

  2. The latter two formulas contain \boldsymbol{W}_{out}^{\top};

  3. Since \boldsymbol{W}_k are square matrices, \frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}} and \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} are stable [RMS is \Theta(1)];

  4. If \boldsymbol{W}_{in} also uses fan_in initialization, then \boldsymbol{Y}_{out} is also stable;

  5. For \frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top} to be stable, the initialization variance of \boldsymbol{W}_{out} should be 1/d_{out}; but d_{out} is scale-independent and acts as a constant, so this condition does not fix how the variance should depend on d.

Consequently:

  1. The RMS of \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}} is \Theta(1). \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right\Vert_F^2 is the sum of squares of d\times d_{out} numbers, so its size is \Theta(d\times d_{out}). Since d_{out} is constant, it is \Theta(d). To obtain \Delta\mathcal{L} = \Theta(1), its learning rate must satisfy \eta_{out}\propto 1/d.

  2. \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\Vert_F^2 is a sum of d^2 numbers. The RMS of \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} and \frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}} are both \Theta(1). If we set the initialization variance of \boldsymbol{W}_{out} to \propto 1/d^2, then the RMS of \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k} becomes \Theta(1/d). The sum of squares then becomes exactly \Theta(1), so the learning rate does not need to change.

  3. In this case, the RMS of \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}} is also \Theta(1/d), but \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right\Vert_F^2 is only a sum of d_{in}\times d squares, so the result is \Theta(1/d). To get \Delta\mathcal{L} = \Theta(1), the learning rate needs to be increased by a factor of d to cancel this effect, i.e., \eta_{in}\propto d.
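Points 1 to 3 can be checked numerically on a stripped-down instance of the model family. The sketch below assumes K=1 (a single hidden matrix \boldsymbol{W}_1), an identity activation, and random \Theta(1) entries standing in for \frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}; these simplifications and all helper names are assumptions of the sketch. It initializes \boldsymbol{W}_{out} with variance 1/d^2, computes the three gradients by hand, and converts the growth of \Vert\cdot\Vert_F^2 between d=32 and d=512 into the SGD learning-rate exponent.

```python
import math
import random


def randn_matrix(rng, rows, cols, std):
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def fro_sq(m):
    return sum(x * x for r in m for x in r)


def three_part_grads(d, d_in=4, d_out=4, b=8, seed=0):
    """Gradients of W_in, W_1, W_out for Z = X W_in W_1 W_out (identity activation, K = 1)."""
    rng = random.Random(seed)
    x = randn_matrix(rng, b, d_in, 1.0)
    g_z = randn_matrix(rng, b, d_out, 1.0)                    # stand-in for dL/dZ, Theta(1)
    w_in = randn_matrix(rng, d_in, d, math.sqrt(1.0 / d_in))  # fan_in
    w_1 = randn_matrix(rng, d, d, math.sqrt(1.0 / d))         # fan_in (= fan_out, square)
    w_out = randn_matrix(rng, d, d_out, 1.0 / d)              # std 1/d, i.e. variance 1/d^2
    y_in = matmul(x, w_in)
    y_out = matmul(y_in, w_1)
    g_yout = matmul(g_z, transpose(w_out))                    # dL/dY_out = dL/dZ W_out^T
    g_yin = matmul(g_yout, transpose(w_1))                    # dL/dY_in = dL/dY_out W_1^T
    return {
        "W_in": matmul(transpose(x), g_yin),
        "W_1": matmul(transpose(y_in), g_yout),
        "W_out": matmul(transpose(y_out), g_z),
    }


def sgd_lr_exponents(d_small=32, d_large=512):
    """Exponent e in eta ∝ d^e that keeps eta * ||grad||_F^2 (the SGD Delta L) at Theta(1)."""
    small, large = three_part_grads(d_small), three_part_grads(d_large)
    log_ratio = math.log(d_large / d_small)
    return {name: -math.log(fro_sq(large[name]) / fro_sq(small[name])) / log_ratio
            for name in small}


print({name: round(e, 2) for name, e in sgd_lr_exponents().items()})
```

The printed exponents should land near +1, 0 and -1 for \boldsymbol{W}_{in}, \boldsymbol{W}_1 and \boldsymbol{W}_{out}, i.e. \eta_{in}\propto d, \eta_k\propto 1 and \eta_{out}\propto 1/d. In pure Python the run takes a second or two.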

Feature Variation

The above results are correct, but a careful look reveals a gap in the derivation: points 2 and 3 rely on the choice "set the initialization variance of \boldsymbol{W}_{out} to \propto 1/d^2", for which we have given no justification so far. Without one, the derivation is incomplete.

In fact, if we only look at the requirement \Delta \mathcal{L}=\Theta(1), we cannot rule out other possibilities. For example, if the initialization variance of \boldsymbol{W}_{out} is set to \propto 1/d, then the RMS of \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k} is \Theta(1/\sqrt{d}), and the sum of squares is \Theta(d). Then a learning rate \eta\propto 1/d could also achieve \Delta \mathcal{L}=\Theta(1). Therefore, to explain the necessity of "setting the initialization variance of \boldsymbol{W}_{out} to \propto 1/d^2", we need to introduce a new condition.

The loss function \mathcal{L} is a macroscopic, model-level quantity, and its change alone cannot pin down every choice; we need to look inside the model. Specifically, we want the change in the output of each layer (usually called features or activations) to be scale-invariant as well. For a linear layer \boldsymbol{Y}_k = \boldsymbol{Y}_{k-1} \boldsymbol{W}_k, the change in output caused by \boldsymbol{W}_k\to \boldsymbol{W}_k + \Delta \boldsymbol{W}_k is: \begin{equation} \Delta\boldsymbol{Y}_k = \boldsymbol{Y}_{k-1} (\boldsymbol{W}_k + \Delta \boldsymbol{W}_k) - \boldsymbol{Y}_{k-1} \boldsymbol{W}_k = \boldsymbol{Y}_{k-1} \Delta\boldsymbol{W}_k \end{equation} Note that \boldsymbol{Y}_{k-1}\in\mathbb{R}^{b\times d} and \Delta\boldsymbol{W}_k\in\mathbb{R}^{d\times d}, so \boldsymbol{Y}_{k-1} \Delta\boldsymbol{W}_k consists of b\times d inner products of d-dimensional vectors. Unlike the random initialization, \Delta\boldsymbol{W}_k is computed from the data (for SGD it is built from \boldsymbol{Y}_{k-1} itself), so it is generally not independent of \boldsymbol{Y}_{k-1}, and these inner products tend to be \Theta(d) rather than the \Theta(\sqrt{d}) of independent vectors. Therefore, if the RMS of \boldsymbol{Y}_{k-1} is \Theta(1), we can expect the RMS of \Delta\boldsymbol{Y}_k to be \Theta(d\times \text{RMS}(\Delta \boldsymbol{W}_k)).
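The difference between a data-dependent update and an independent one can be seen numerically. The sketch below (standard-library Python; helper names are mine) compares a gradient-shaped update \Delta\boldsymbol{W}=\boldsymbol{Y}^{\top}\boldsymbol{G}, with random \boldsymbol{G} assumed as a stand-in for \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_k}, against an i.i.d. random \Delta\boldsymbol{W}, and measures how \text{RMS}(\boldsymbol{Y}\Delta\boldsymbol{W})/\text{RMS}(\Delta\boldsymbol{W}) grows when the width goes from 32 to 512.

```python
import math
import random


def randn_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def rms(m):
    return math.sqrt(sum(x * x for r in m for x in r) / sum(len(r) for r in m))


def feature_gain(d, aligned, b=4, seed=0):
    """RMS(Y dW) / RMS(dW) for a gradient-shaped update dW = Y^T G or an independent one."""
    rng = random.Random(seed)
    y = randn_matrix(rng, b, d)              # Y_{k-1}, Theta(1) entries
    if aligned:
        g = randn_matrix(rng, b, d)          # stand-in for dL/dY_k (assumed random)
        dw = matmul(transpose(y), g)         # same form as the SGD gradient Y^T G
    else:
        dw = randn_matrix(rng, d, d)         # independent of Y, like an initialization
    return rms(matmul(y, dw)) / rms(dw)


for aligned in (True, False):
    growth = feature_gain(512, aligned) / feature_gain(32, aligned)
    label = "gradient-shaped" if aligned else "independent"
    print(f"{label:>15}: width x16 -> RMS(Y dW)/RMS(dW) grows x{growth:.1f}")
```

One should see growth close to 16 (the \Theta(d) case) for the gradient-shaped update and close to 4 (the \Theta(\sqrt{d}) case) for the independent one.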

Thus, to keep the RMS of \Delta\boldsymbol{Y}_k at \Theta(1), we get an additional requirement for \Delta \boldsymbol{W}_k: \begin{equation} \text{RMS}(\Delta \boldsymbol{W}_k) = \Theta(1 / d)\label{eq:dw-rms} \end{equation}

Combining \Delta \boldsymbol{W}_k = -\eta\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k} and \Delta\mathcal{L}=\Theta(1), we obtain the result that the initialization variance of \boldsymbol{W}_{out} must be \propto 1/d^2.
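The "combining" step can be made explicit with a few lines of order-of-magnitude arithmetic (SGD case). Writing r_k for \text{RMS}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right), the two requirements read \begin{equation} \text{RMS}(\Delta\boldsymbol{W}_k) = \eta_k r_k = \Theta(1/d),\qquad |\Delta\mathcal{L}|\approx \eta_k\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\Vert_F^2 = \eta_k d^2 r_k^2 = \Theta(1) \end{equation} Dividing the second by the first gives d^2 r_k = \Theta(d), i.e. r_k = \Theta(1/d) and then \eta_k = \Theta(1). Since \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} is stable, r_k has the same order as \text{RMS}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}\right), and each entry of \frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top} is a sum of only d_{out} (a constant number of) terms, so this RMS is proportional to the standard deviation of \boldsymbol{W}_{out}: \begin{equation} \text{std}(\boldsymbol{W}_{out}) = \Theta(1/d) \quad\Longrightarrow\quad \text{Var}(\boldsymbol{W}_{out}) \propto 1/d^2 \end{equation} By contrast, the alternative considered above (variance \propto 1/d with \eta\propto 1/d) gives \text{RMS}(\Delta\boldsymbol{W}_k) = \Theta(1/d)\cdot\Theta(1/\sqrt{d}) = \Theta(d^{-3/2}), so \Delta\boldsymbol{Y}_k would vanish like \Theta(1/\sqrt{d}), violating Equation [eq:dw-rms].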

(Note: This section relies on guidance from @Chenyu Zheng; many thanks!)

Adam Version

The above is MuP for SGD. For Adam, we usually use SignSGD as an approximation for order-of-magnitude analysis:

  1. \Delta \boldsymbol{W} = -\eta \mathop{\text{sign}}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right);

  2. \Delta \mathcal{L} \approx -\eta \left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right|_1;

  3. Here |\cdot|_1 refers to the sum of absolute values of all elements.

Regarding the SignSGD approximation itself, readers can refer to "How Should Learning Rate Scale with Batch Size?" and "How Adam’s Epsilon Affects the Learning Rate Scaling Law?". In short, SignSGD is a common approximation for analyzing Adam-related scaling laws.
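To see the first-order relation \Delta\mathcal{L}\approx -\eta\left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right|_1 in action, the sketch below takes one SignSGD step on a toy least-squares loss \mathcal{L}=\frac{1}{2b}\Vert\boldsymbol{X}\boldsymbol{W}-\boldsymbol{T}\Vert_F^2 (an illustrative choice, with random \boldsymbol{X} and \boldsymbol{T}; helper names are mine) and compares the actual loss change with the prediction. It checks only the Taylor approximation, not Adam itself.

```python
import math
import random


def randn_matrix(rng, rows, cols, std=1.0):
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def sign(v):
    return (v > 0) - (v < 0)


def loss_and_grad(x, w, t):
    """Toy loss L = ||X W - T||_F^2 / (2b) and its gradient X^T (X W - T) / b."""
    b = len(x)
    resid = [[p - q for p, q in zip(rp, rt)] for rp, rt in zip(matmul(x, w), t)]
    loss = sum(r * r for row in resid for r in row) / (2 * b)
    grad = [[v / b for v in row] for row in matmul(transpose(x), resid)]
    return loss, grad


rng = random.Random(0)
b, d_in, d_out = 8, 6, 5
x = randn_matrix(rng, b, d_in)
w = randn_matrix(rng, d_in, d_out, 1 / math.sqrt(d_in))
t = randn_matrix(rng, b, d_out)

eta = 1e-4
loss0, grad = loss_and_grad(x, w, t)
w_new = [[wv - eta * sign(gv) for wv, gv in zip(rw, rg)] for rw, rg in zip(w, grad)]
loss1, _ = loss_and_grad(x, w_new, t)
predicted = -eta * sum(abs(g) for row in grad for g in row)  # -eta * |grad|_1
print(f"actual dL = {loss1 - loss0:.6e}, predicted = {predicted:.6e}")
```

For a small \eta the two numbers should agree to several digits; the gap is the second-order term, which is \Theta(\eta^2).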

Now we can mimic the SGD analysis:

  1. The RMS of \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}} is \Theta(1). \left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right|_1 is a sum of d\times d_{out} terms, which is \Theta(d\times d_{out}) = \Theta(d). Thus, its learning rate must satisfy \eta_{out}\propto 1/d to cancel the scale effect.

  2. \left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right|_1 is a sum of d^2 terms. The RMS of \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} and \frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}} are both \Theta(1). If we set the initial variance of \boldsymbol{W}_{out} to \propto 1/d^2, then the RMS of \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k} is \Theta(1/d). The sum of d^2 terms is \Theta(d), so the learning rate scales as \eta_k\propto 1/d to cancel the scale effect.

  3. In this case, the RMS of \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}} is also \Theta(1/d), but \left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right|_1 is only a sum of d_{in}\times d terms, so it is already \Theta(1). Thus, the learning rate does not need to change with scale.

(Note: Readers can verify that Equation [eq:dw-rms] is satisfied.)
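For instance, every entry of \mathop{\text{sign}}(\cdot) is \pm 1 (assuming no gradient entry is exactly zero), so its RMS is exactly 1 and, for the hidden matrices, \begin{equation} \text{RMS}(\Delta\boldsymbol{W}_k) = \eta_k\,\text{RMS}\left(\mathop{\text{sign}}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right)\right) = \eta_k = \Theta(1/d), \end{equation} which is precisely Equation [eq:dw-rms].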

Muon Version

Next is the analysis for Muon. We have introduced Muon in detail in previous posts. Similar to using SignSGD for Adam, we use MSignSGD to approximate Muon:

  1. \Delta \boldsymbol{W} = -\eta \mathop{\text{msign}}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right);

  2. \Delta \mathcal{L} \approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_*;

  3. Here \Vert\cdot\Vert_* is the nuclear norm, the sum of all singular values;

  4. The nuclear norm is hard to compute, but the Frobenius norm (F-norm) is easy; it is the square root of the sum of squares of all singular values;

  5. We use the F-norm as an approximation for the Nuclear norm: \Delta \mathcal{L} \approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_*\approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_F;

  6. The F-norm is also the square root of the sum of squares of all elements.
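One way to see why step 5 does not distort the scaling exponents here: for any matrix \boldsymbol{A} of rank r with singular values \sigma_1,\dots,\sigma_r, the Cauchy-Schwarz inequality gives \begin{equation} \Vert\boldsymbol{A}\Vert_F = \sqrt{\textstyle\sum_{i=1}^r \sigma_i^2} \;\le\; \Vert\boldsymbol{A}\Vert_* = \textstyle\sum_{i=1}^r \sigma_i \;\le\; \sqrt{r}\,\Vert\boldsymbol{A}\Vert_F \end{equation} Every gradient in the MLP above has the form (layer input)^{\top} times (backpropagated signal), e.g. \boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}, whose inner dimension is the batch size b, so its rank is at most b. Assuming b is held fixed while d grows, \Vert\cdot\Vert_* and \Vert\cdot\Vert_F differ by at most the width-independent factor \sqrt{b}.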

The analysis follows:

  1. The RMS of \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}} is \Theta(1), so \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right\Vert_* is \Theta(\sqrt{d\times d_{out}}) = \Theta(\sqrt{d}). To eliminate the scale effect, its learning rate must satisfy \eta_{out}\propto 1/\sqrt{d}.

  2. \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\Vert_F is the square root of the sum of d^2 squares. The RMS of \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} and \frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}} are \Theta(1). If we set the initial variance of \boldsymbol{W}_{out} to \propto 1/d^2, then the RMS of \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k} is \Theta(1/d). The square root of the sum of squares is \Theta(1), so the learning rate does not change.

  3. The RMS of \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}} is also \Theta(1/d), but \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right\Vert_F is the square root of the sum of d_{in}\times d squares, so it is \Theta(1/\sqrt{d}). The learning rate needs to be increased by \sqrt{d} to cancel this, i.e., \eta_{in}\propto \sqrt{d}.

(Note: The Muon conclusion here is correct, but it does not satisfy Equation [eq:dw-rms] because that condition assumes an element-wise update, which Muon is not. We bypassed this by sticking to the \boldsymbol{W}_{out} variance conclusion.)

Conclusion Summary

The summary of the above conclusions is:

| Optimizer | \boldsymbol{W}_{in} Var | \boldsymbol{W}_{in} LR | \boldsymbol{W}_k Var | \boldsymbol{W}_k LR | \boldsymbol{W}_{out} Var | \boldsymbol{W}_{out} LR |
|---|---|---|---|---|---|---|
| SGD | 1/d_{in} | d | 1/d | 1 | 1/d^2 | 1/d |
| Adam | 1/d_{in} | 1 | 1/d | 1/d | 1/d^2 | 1/d |
| Muon | 1/d_{in} | \sqrt{d} | 1/d | 1 | 1/d^2 | 1/\sqrt{d} |
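The learning-rate columns can be recomputed from just two numbers per parameter, both taken from the analysis above (with \text{Var}(\boldsymbol{W}_{out})\propto 1/d^2): the exponent of d in the number of gradient entries, and the exponent of d in the gradient RMS. The sketch below redoes that bookkeeping; the names are my own, and for Adam it assumes, as the SignSGD analysis implicitly does, that the mean absolute gradient entry scales like the RMS.

```python
from fractions import Fraction

# For each parameter: (exponent of d in the number of gradient entries,
#                      exponent of d in the gradient RMS), with Var(W_out) ∝ 1/d^2.
GRAD_SCALING = {
    "W_in": (Fraction(1), Fraction(-1)),  # d_in x d entries,  RMS Theta(1/d)
    "W_k": (Fraction(2), Fraction(-1)),   # d x d entries,     RMS Theta(1/d)
    "W_out": (Fraction(1), Fraction(0)),  # d x d_out entries, RMS Theta(1)
}


def lr_exponent(optimizer, count_exp, rms_exp):
    """Exponent e such that eta ∝ d^e keeps the per-step Delta L at Theta(1)."""
    if optimizer == "SGD":     # |Delta L| ~ eta * ||g||_F^2 = eta * N * RMS^2
        growth = count_exp + 2 * rms_exp
    elif optimizer == "Adam":  # SignSGD: |Delta L| ~ eta * |g|_1 ~ eta * N * RMS
        growth = count_exp + rms_exp
    elif optimizer == "Muon":  # MSignSGD: |Delta L| ~ eta * ||g||_F = eta * sqrt(N) * RMS
        growth = count_exp / 2 + rms_exp
    else:
        raise ValueError(f"unknown optimizer: {optimizer}")
    return -growth


for opt in ("SGD", "Adam", "Muon"):
    row = {name: str(lr_exponent(opt, *s)) for name, s in GRAD_SCALING.items()}
    print(opt, row)
```

The printed exponents reproduce the LR columns of the table, e.g. d^{1} for SGD's \boldsymbol{W}_{in} and d^{1/2} for Muon's.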

Here \boldsymbol{W}_k refers to all parameters except \boldsymbol{W}_{in} and \boldsymbol{W}_{out}. It is important to emphasize that these relationships are "proportional to" rather than "equal to". In practice, one might make slight adjustments. For example, when using Muon, \boldsymbol{W}_{in} and \boldsymbol{W}_{out} are often optimized with Adam, leading to:

  1. \eta_{out}\propto 1/d;

  2. \eta_{in} remains constant.

If combined with the "Adjust LR" mentioned in "Muon is Scalable for LLM Training", the learning rate is multiplied by \sqrt{\max(n, m)}, where n\times m is the matrix shape. Since we assumed \text{NN} parameters scale proportionally, \sqrt{\max(n, m)}\propto \sqrt{d}. To cancel this effect:

  1. \eta_k\propto 1/\sqrt{d}.
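In practice these proportionalities are applied by tuning at a small base width and rescaling. A minimal sketch of that step for the Muon-plus-Adam setup just described; the function name, base width, and learning-rate values are made up for illustration, and the exponents apply to the base learning rate before any \sqrt{\max(n, m)} factor the optimizer applies itself.

```python
def transfer_lrs(base_lrs, base_width, target_width, exponents):
    """Rescale learning rates tuned at base_width to target_width via eta ∝ d^e."""
    ratio = target_width / base_width
    return {name: lr * ratio ** exponents[name] for name, lr in base_lrs.items()}


# Exponents for: W_in and W_out on Adam, hidden matrices on Muon with the
# sqrt(max(n, m)) "Adjust LR" factor applied inside the optimizer.
EXPONENTS = {"W_in": 0.0, "W_k": -0.5, "W_out": -1.0}

# Hypothetical learning rates found by a sweep at width 256.
base = {"W_in": 3e-3, "W_k": 2e-2, "W_out": 3e-3}

for name, lr in transfer_lrs(base, 256, 4096, EXPONENTS).items():
    print(f"{name}: {lr:.3g}")
```

Widening 16× leaves \eta_{in} unchanged, divides the base \eta_k by 4 and divides \eta_{out} by 16.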

Summary

This article introduces MuP (Maximal Update Parametrization) in a concise and clear manner. MuP is a framework for studying the transfer laws of hyperparameters across model scales. Based on MuP, we can search for hyperparameters (mainly learning rate and initialization) on small models at a relatively low cost and then transfer them to large models, reducing the cost of training large models.

Objectively speaking, the introduction and analysis here are still preliminary. For instance, we did not consider bias terms, did not check whether the conclusions hold for architectures beyond the MLP, and did not carefully examine the roles of normalization and residual connections. Omitting bias terms was purely out of laziness; consider it an exercise for the reader. As for MuP under other architectures, the analysis is more involved, but owing to the common structure of neural networks the conclusions are generally similar. The main point worth refining is the impact of normalization and residual connections; normalization in particular makes forward propagation stable without special initialization, which leaves more freedom.

These are left for future analysis.

Original address: https://kexue.fm/archives/10770
