English (unofficial) translations of posts at kexue.fm

Rethinking Learning Rate and Batch Size (Part 1): Current Status

Translated by Gemini Flash 3.0 Preview. Translations can be inaccurate; please refer to the original post for anything important.

In the previous articles "How should the learning rate change as the Batch Size increases?" and "How does Adam’s epsilon affect the Scaling Law of the learning rate?", we discussed the theoretical laws governing how the learning rate should vary with the Batch Size. A classic part of that discussion is the second-order expansion analysis proposed by OpenAI. However, for optimizers other than SGD, the calculations this method requires often become quite complex, and it can be hard to know where to begin.

In the next few articles, I will reorganize and rethink the relevant details from the aforementioned articles, attempt to simplify some of the derivation steps, provide a more general and lightweight derivation path, and explore the possibility of extending it to the Muon optimizer.

Main Methodology

First, let’s review the previous analysis method. In "How should the learning rate change as the Batch Size increases?", we introduced several approaches to analyzing the relationship between learning rate and Batch Size. Among them, the second-order approximation analysis proposed by OpenAI in "An Empirical Model of Large-Batch Training" received the most attention, and this article follows the same line of reasoning.

Next, we need to introduce some notation. Let the loss function be \mathcal{L}(\boldsymbol{w}), where \boldsymbol{w}\in\mathbb{R}^N is the parameter vector and \boldsymbol{g} is its gradient. Note that the ideal loss function is defined as an expectation over all training samples, but in practice we can only sample one Batch at a time to estimate it, which makes the gradient random. We denote the gradient of a single sample as \tilde{\boldsymbol{g}}; its mean is \boldsymbol{g} and its covariance matrix is \boldsymbol{\Sigma}. When the Batch Size is B, the gradient (averaged over the Batch) is denoted \tilde{\boldsymbol{g}}_B; its mean is still \boldsymbol{g}, but, assuming the samples in the Batch are drawn independently, its covariance matrix becomes \boldsymbol{\Sigma}/B.
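The claim that averaging over a Batch keeps the mean at \boldsymbol{g} while shrinking the covariance to \boldsymbol{\Sigma}/B can be checked numerically. The following is a minimal NumPy sketch with assumed toy values: per-sample gradients are drawn i.i.d. from a Gaussian N(g, Σ) purely for convenience, and the function name `minibatch_gradient_stats` is hypothetical.

```python
import numpy as np


def minibatch_gradient_stats(g, Sigma, B, num_batches=200_000, seed=0):
    """Empirical mean and covariance of the Batch-averaged gradient g_B.

    Per-sample gradients are drawn i.i.d. from N(g, Sigma). The Gaussian is
    only a convenience for this toy check; the Sigma / B result itself needs
    just i.i.d. samples with finite covariance.
    """
    rng = np.random.default_rng(seed)
    # Shape (num_batches, B, N): B independent per-sample gradients per batch.
    per_sample = rng.multivariate_normal(g, Sigma, size=(num_batches, B))
    g_B = per_sample.mean(axis=1)  # average within each batch
    return g_B.mean(axis=0), np.cov(g_B, rowvar=False)


# Hypothetical toy values for the true gradient g and per-sample covariance Sigma.
g = np.array([1.0, -0.5, 0.2])
Sigma = np.array([[2.0, 0.3, 0.0],
                  [0.3, 1.0, 0.1],
                  [0.0, 0.1, 0.5]])
B = 8
mean_B, cov_B = minibatch_gradient_stats(g, Sigma, B)
print(np.round(mean_B, 3))     # close to g
print(np.round(cov_B * B, 2))  # close to Sigma, i.e. cov(g_B) is about Sigma / B
```

Rescaling the empirical covariance by B, rather than comparing against Σ/B directly, makes the agreement easy to read off entry by entry.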

Furthermore, let the current learning rate be \eta and the update vector be \tilde{\boldsymbol{\varphi}}_B. Then the updated loss function will be: \begin{equation} \begin{aligned} \mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{\varphi}}_B) \approx&\, \mathcal{L}(\boldsymbol{w}) - \eta \tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2\tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{\varphi}}_B \\ =&\, \mathcal{L}(\boldsymbol{w}) - \eta \tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2\mathop{\mathrm{tr}}(\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{H}) \end{aligned} \end{equation} On the right side, we have Taylor-expanded to the second order, where \boldsymbol{H} is the Hessian matrix and \mathop{\mathrm{tr}} is the trace of the matrix. The second equality uses the identity \mathop{\mathrm{tr}}(\boldsymbol{A}\boldsymbol{B})=\mathop{\mathrm{tr}}(\boldsymbol{B}\boldsymbol{A}). To obtain a deterministic result, we take the expectation of both sides: \begin{equation} \mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{\varphi}}_B)] \approx \mathcal{L}(\boldsymbol{w}) - \eta\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \mathop{\mathrm{tr}}(\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]\boldsymbol{H}) \end{equation} We view the right side as a quadratic function of \eta and assume that the quadratic coefficient is positive (a stronger assumption is that the \boldsymbol{H} matrix is positive definite). 
Then we can find the minimum point: \begin{equation} \eta^* \approx \frac{\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]^{\top}\boldsymbol{g}}{\mathop{\mathrm{tr}}(\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]\boldsymbol{H})} \end{equation} This is the learning rate that, on average, makes the loss decrease the most in a single step, i.e., the theoretically optimal learning rate. What we need to do is calculate \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B] and \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}] for a specific \tilde{\boldsymbol{\varphi}}_B, and then extract from the above formula how \eta^* depends on the Batch Size B.
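As a sanity check on the formula for \eta^*, one can estimate \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B] and \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}] from samples of an update vector and confirm that the resulting \eta^* minimizes the second-order loss estimate. This is a sketch, not code from the original article: g, H and the noisy update samples are hypothetical toy values, and the helper names are made up for illustration.

```python
import numpy as np


def second_order_terms(phi_samples, g, H):
    """Estimate E[phi]^T g and tr(E[phi phi^T] H) from samples of phi (shape (n, N))."""
    mean_phi = phi_samples.mean(axis=0)                             # E[phi]
    second_moment = phi_samples.T @ phi_samples / len(phi_samples)  # E[phi phi^T]
    return mean_phi @ g, np.trace(second_moment @ H)


def optimal_lr(linear, quadratic):
    """Minimizer of -eta * linear + 0.5 * eta^2 * quadratic (requires quadratic > 0)."""
    return linear / quadratic


# Toy setup (hypothetical values): SGD-like updates phi = g + Gaussian noise.
rng = np.random.default_rng(0)
g = np.array([1.0, -0.5, 0.2])
H = np.diag([2.0, 1.0, 0.5])
phi = g + rng.normal(scale=0.8, size=(50_000, 3))

linear, quadratic = second_order_terms(phi, g, H)
eta_star = optimal_lr(linear, quadratic)

# Second-order estimate of E[L(w - eta * phi)] - L(w) on a grid of learning rates.
etas = np.linspace(0.0, 2 * eta_star, 2001)
loss_change = -etas * linear + 0.5 * etas**2 * quadratic
eta_grid = etas[np.argmin(loss_change)]
print(eta_star, eta_grid)  # the grid minimizer should coincide with eta*
```

Because these updates are plain SGD with isotropic noise of variance 0.64, the sampled \eta^* should also land near the exact value \boldsymbol{g}^{\top}\boldsymbol{g}/(\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + 0.64\mathop{\mathrm{tr}}(\boldsymbol{H})) = 1.29/4.51.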

Warm-up Exercise

As a first example, we naturally consider the simplest case, SGD, where \tilde{\boldsymbol{\varphi}}_B=\tilde{\boldsymbol{g}}_B. Then it is easily obtained that \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]=\boldsymbol{g} and \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]=\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B. Thus we have: \begin{equation} \eta^* \approx \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\mathop{\mathrm{tr}}((\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B)\boldsymbol{H})} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \mathop{\mathrm{tr}}(\boldsymbol{\Sigma}\boldsymbol{H})/B} = \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \label{eq:eta-sgd} \end{equation} where \begin{equation} \eta_{\max} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}},\qquad\mathcal{B}_{\text{noise}} = \frac{\mathop{\mathrm{tr}}(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}} \end{equation}
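The simplification from the trace form to \eta_{\max}/(1 + \mathcal{B}_{\text{noise}}/B) can be verified numerically. The sketch below assumes hypothetical toy values for g, H and Σ (with Σ chosen large so that small B falls in the linear regime discussed next) and compares the two forms for several Batch Sizes; the function names are illustrative only.

```python
import numpy as np


def sgd_lr_stats(g, H, Sigma):
    """eta_max = g^T g / g^T H g and B_noise = tr(Sigma H) / g^T H g."""
    gHg = g @ H @ g
    return (g @ g) / gHg, np.trace(Sigma @ H) / gHg


def sgd_optimal_lr_trace_form(g, H, Sigma, B):
    """eta* = g^T g / tr((g g^T + Sigma / B) H), the un-simplified form."""
    second_moment = np.outer(g, g) + Sigma / B
    return (g @ g) / np.trace(second_moment @ H)


# Hypothetical toy values; a large noise level puts small B in the linear regime.
g = np.array([1.0, -0.5, 0.2])
H = np.diag([2.0, 1.0, 0.5])
Sigma = 100.0 * np.eye(3)
eta_max, B_noise = sgd_lr_stats(g, H, Sigma)

for batch in [1, 16, 256, 4096, 10**6]:
    closed_form = eta_max / (1 + B_noise / batch)
    print(batch, closed_form, sgd_optimal_lr_trace_form(g, H, Sigma, batch))
```

Both columns should match for every Batch Size; at B = 1 the value is close to \eta_{\max} B/\mathcal{B}_{\text{noise}}, and at very large B it approaches \eta_{\max}.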

The result in [eq:eta-sgd] can be interpreted in several ways. First, it is a monotonically increasing function of B with an upper bound \eta_{\max}, indicating that the learning rate cannot increase indefinitely. Compared to simple linear or square-root laws, this is more consistent with intuition. When B \ll \mathcal{B}_{\text{noise}}, we have: \begin{equation} \eta^* \approx \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \approx \frac{\eta_{\max}}{\mathcal{B}_{\text{noise}}/B} = \eta_{\max} B / \mathcal{B}_{\text{noise}} \end{equation} This shows that when the Batch Size is relatively small, the optimal SGD learning rate is indeed approximately linear in the Batch Size, and it also suggests that \mathcal{B}_{\text{noise}} is a key statistic. However, the definition of \mathcal{B}_{\text{noise}} depends on the Hessian matrix \boldsymbol{H}, which is almost impossible to compute precisely for LLMs. Therefore, in practice, we usually assume \boldsymbol{H} is (a multiple of) the identity matrix, which yields the simplified form: \begin{equation} \mathcal{B}_{\text{simple}} = \frac{\mathop{\mathrm{tr}}(\boldsymbol{\Sigma})}{\boldsymbol{g}^{\top}\boldsymbol{g}} \end{equation} This result takes the form of noise intensity (\mathop{\mathrm{tr}}(\boldsymbol{\Sigma})) divided by signal intensity (\boldsymbol{g}^{\top}\boldsymbol{g}), which is essentially the reciprocal of the signal-to-noise ratio (SNR). It indicates that the smaller the SNR, the larger the Batch Size required for the learning rate to approach \eta_{\max}, which also aligns with intuition. Moreover, \mathop{\mathrm{tr}}(\boldsymbol{\Sigma}) depends only on the diagonal elements of \boldsymbol{\Sigma}, so we only need to estimate the mean and variance of each gradient component independently, which is feasible in practice.
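Since \mathcal{B}_{\text{simple}} only needs per-component means and variances, it can be estimated from a matrix of per-sample gradients. The sketch below is one plausible estimator under stated assumptions (synthetic Gaussian gradients, hypothetical function name `estimate_b_simple`). The optional `debias` step is an addition not found in the original text: it uses the fact that the squared norm of a sample mean over n samples overestimates \boldsymbol{g}^{\top}\boldsymbol{g} by \mathop{\mathrm{tr}}(\boldsymbol{\Sigma})/n on average.

```python
import numpy as np


def estimate_b_simple(per_sample_grads, debias=True):
    """Estimate B_simple = tr(Sigma) / (g^T g) from per-sample gradients.

    per_sample_grads has shape (n, N): one row per training sample.
    Only per-component means and variances are needed, because tr(Sigma)
    is the sum of the diagonal of Sigma.
    """
    n = per_sample_grads.shape[0]
    mean = per_sample_grads.mean(axis=0)        # estimate of g
    var = per_sample_grads.var(axis=0, ddof=1)  # estimate of diag(Sigma)
    trace_sigma = var.sum()
    signal = mean @ mean
    if debias:
        # E[|mean|^2] = |g|^2 + tr(Sigma) / n, so remove the noise contribution.
        # This can become tiny or negative when n is small relative to B_simple.
        signal -= trace_sigma / n
    return trace_sigma / signal


# Synthetic per-sample gradients (hypothetical): true g and per-component std.
rng = np.random.default_rng(0)
g_true = np.array([1.0, -0.5, 0.2])
std_true = np.array([3.0, 2.0, 1.0])
grads = g_true + std_true * rng.standard_normal((200_000, 3))

b_true = (std_true**2).sum() / (g_true @ g_true)  # 14 / 1.29
b_est = estimate_b_simple(grads)
print(b_true, b_est)
```

In a real model the rows would be per-sample (or per-micro-batch) gradients; storing them all is rarely affordable, so running per-component sums would typically replace the full matrix.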

Data Efficiency

Beyond the direct relationship between learning rate and Batch Size, I believe the asymptotic relationship it implies between training data volume and training steps is another elegant result well worth studying. In particular, this conclusion appears to be more general than the learning rate formula [eq:eta-sgd]: as we will see later, SignSGD yields a conclusion of the same form, even though its learning rate law is not [eq:eta-sgd].

The original paper’s discussion of this part is quite complex; the following derivation is my own simplification. Specifically, substituting \eta^* back into the expected loss \mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{g}}_B)], we get: \begin{equation} \overline{\Delta\mathcal{L}} = \mathcal{L}(\boldsymbol{w}) - \mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta^*\tilde{\boldsymbol{g}}_B)] \approx \frac{\Delta\mathcal{L}_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \end{equation} where \Delta\mathcal{L}_{\max} = \frac{(\boldsymbol{g}^{\top}\boldsymbol{g})^2}{2\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}. How should we understand this result? First, it is a monotonically increasing function of B, and as B\to\infty it approaches \Delta\mathcal{L}_{\max}. In other words, if we could use an infinitely large Batch Size, the loss reduction per step would be \Delta\mathcal{L}_{\max} and the required number of training steps would be minimized; we denote this minimum by S_{\min}.
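Substituting \eta^* into the expected second-order loss can be checked numerically. This sketch uses hypothetical toy values for g, H and Σ and compares the direct evaluation with \Delta\mathcal{L}_{\max}/(1 + \mathcal{B}_{\text{noise}}/B); `expected_decrease` is an illustrative helper, not from the original.

```python
import numpy as np


def expected_decrease(eta, g, H, Sigma, B):
    """L(w) - E[L(w - eta * g_B)] under the second-order expansion for SGD."""
    quadratic = g @ H @ g + np.trace(Sigma @ H) / B
    return eta * (g @ g) - 0.5 * eta**2 * quadratic


# Hypothetical toy values.
g = np.array([1.0, -0.5, 0.2])
H = np.diag([2.0, 1.0, 0.5])
Sigma = 4.0 * np.eye(3)

gHg = g @ H @ g
B_noise = np.trace(Sigma @ H) / gHg
delta_max = (g @ g) ** 2 / (2 * gHg)

for batch in [1, 8, 64, 512]:
    eta_star = (g @ g) / (gHg + np.trace(Sigma @ H) / batch)
    direct = expected_decrease(eta_star, g, H, Sigma, batch)
    predicted = delta_max / (1 + B_noise / batch)
    # delta_max / direct is the step multiplier S / S_min = 1 + B_noise / B.
    print(batch, direct, predicted, delta_max / direct)
```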

If the Batch Size is finite, the average loss reduction per step is only \overline{\Delta\mathcal{L}}. This means that, on average, we need to take 1 + \mathcal{B}_{\text{noise}}/B steps to achieve the same reduction as 1 step with an infinite Batch Size. Thus, to achieve the same loss, we must train for S = (1 + \mathcal{B}_{\text{noise}}/B)S_{\min} steps.

Since each step consumes B samples, the total amount of training data consumed is E = BS = (B + \mathcal{B}_{\text{noise}})S_{\min}. This shows that if we increase the Batch Size, we also need to increase the data volume E accordingly to reach the same loss. As B\to 0, the required data volume is minimized at E_{\min} = \mathcal{B}_{\text{noise}}S_{\min}. Using this notation, we can write: \begin{equation} \left(\frac{S}{S_{\min}} - 1\right)\left(\frac{E}{E_{\min}} - 1\right) = 1 \end{equation} This is the classic relationship between training data volume and training steps, with two parameters S_{\min} and E_{\min}. Conversely, we can measure several (S, E) pairs experimentally (for example, the steps and data needed to reach a fixed target loss at different Batch Sizes), fit them to the above equation to estimate S_{\min} and E_{\min}, and from these estimate \mathcal{B}_{\text{noise}} = E_{\min} / S_{\min}. For more analysis details, please refer back to the previous article "How should the learning rate change as the Batch Size increases?" or OpenAI’s original paper "An Empirical Model of Large-Batch Training".
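One simple way to fit S_{\min} and E_{\min}, assuming measured (S, E) pairs are available: multiplying out the relation gives SE = S E_{\min} + E S_{\min}, i.e. S_{\min}/S + E_{\min}/E = 1, which is linear in the two unknowns, so ordinary least squares applies. This linearization is chosen here for illustration and is not necessarily how the original paper performs the fit; the data below are synthetic, generated from the model itself with hypothetical S_{\min} and \mathcal{B}_{\text{noise}}, not experimental results.

```python
import numpy as np


def fit_smin_emin(S, E):
    """Fit (S / S_min - 1) * (E / E_min - 1) = 1 to (S, E) pairs.

    Expanding the product gives S * E = S * E_min + E * S_min, i.e.
    S_min / S + E_min / E = 1, which is linear in (S_min, E_min),
    so ordinary least squares is enough.
    """
    S = np.asarray(S, dtype=float)
    E = np.asarray(E, dtype=float)
    A = np.column_stack([1.0 / S, 1.0 / E])
    (s_min, e_min), *_ = np.linalg.lstsq(A, np.ones_like(S), rcond=None)
    return s_min, e_min


# Synthetic (S, E) pairs generated from the model itself (not measurements):
# hypothetical S_min = 1000 steps and B_noise = 256.
S_min_true, B_noise_true = 1000.0, 256.0
batch_sizes = np.array([32, 64, 128, 256, 512, 1024, 2048], dtype=float)
S = (1 + B_noise_true / batch_sizes) * S_min_true
E = batch_sizes * S

s_min, e_min = fit_smin_emin(S, E)
print(s_min, e_min, e_min / s_min)  # recovered S_min, E_min and B_noise
```

With noisy real measurements, a least-squares fit in the original (S, E) coordinates, or in log space, may be preferable; the linear form is mainly convenient for getting a starting estimate.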

Difficulties in Analysis

Everything written so far has remained within the scope of SGD. From a computational perspective, SGD is trivial; the real complexity arises when \tilde{\boldsymbol{\varphi}}_B depends non-linearly on \tilde{\boldsymbol{g}}_B, such as SignSGD where \tilde{\boldsymbol{\varphi}}_B=\mathop{\mathrm{sign}}(\tilde{\boldsymbol{g}}_B). In theoretical analysis, it is often used as an approximation for Adam, while a more accurate approximation is SoftSignSGD, which considers \epsilon, as we attempted to analyze in "How does Adam’s epsilon affect the Scaling Law of the learning rate?".

In these non-linear scenarios, calculating \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B] and \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}] is often quite difficult, even if we assume \tilde{\boldsymbol{g}}_B follows a simple normal distribution (note that the SGD analysis did not require assuming any particular distribution). For example, in previous articles, for SignSGD with \tilde{\boldsymbol{\varphi}}_B=\mathop{\mathrm{sign}}(\tilde{\boldsymbol{g}}_B), we went through the following steps to calculate \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]:

  1. Assume the components of \tilde{\boldsymbol{g}}_B are independent, simplifying the problem to the expectation of a single component \tilde{\varphi}_B=\mathop{\mathrm{sign}}(\tilde{g}_B);

  2. Assume \tilde{g}_B (now a scalar) follows a normal distribution, allowing us to calculate \mathbb{E}[\tilde{\varphi}_B], with the answer expressed using the \mathop{\mathrm{erf}} function;

  3. Approximate the \mathop{\mathrm{erf}} function with a function of the form x/\sqrt{x^2+c} to simplify the result.
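Steps 2 and 3 can be illustrated for a single component. For x ~ N(μ, σ²), E[sign(x)] = P(x > 0) - P(x < 0) = erf(μ/(σ√2)). The sketch below checks this against Monte Carlo and compares erf with the surrogate x/√(x² + c). The choice c = π/4, which matches the slope of erf at zero, is an assumption here; the earlier articles may use a different constant. Function names are hypothetical.

```python
import math
import numpy as np


def expected_sign(mu, sigma):
    """E[sign(x)] for x ~ N(mu, sigma^2) = P(x > 0) - P(x < 0) = erf(mu / (sigma * sqrt(2)))."""
    return math.erf(mu / (sigma * math.sqrt(2.0)))


def erf_approx(x, c=math.pi / 4):
    """Surrogate x / sqrt(x^2 + c); c = pi/4 matches erf'(0) = 2 / sqrt(pi)."""
    return x / math.sqrt(x * x + c)


def expected_sign_approx(mu, sigma):
    """Step 3: erf_approx at x = mu / (sigma * sqrt(2)) equals mu / sqrt(mu^2 + pi * sigma^2 / 2)."""
    return mu / math.sqrt(mu * mu + math.pi * sigma * sigma / 2)


rng = np.random.default_rng(0)
sigma = 1.0
for mu in [-2.0, -0.5, 0.0, 0.3, 1.0, 3.0]:
    monte_carlo = np.sign(rng.normal(mu, sigma, size=200_000)).mean()
    print(mu, monte_carlo, expected_sign(mu, sigma), expected_sign_approx(mu, sigma))

# Worst-case gap between erf and its x / sqrt(x^2 + c) surrogate on [-5, 5].
grid = np.linspace(-5, 5, 1001)
max_err = max(abs(math.erf(x) - erf_approx(x)) for x in grid)
print(max_err)
```

The surrogate is exact at zero and at ±∞ and is loosest at moderate signal-to-noise ratios, which is the price paid for a form that is easy to manipulate analytically.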

In other words, we have to go through a series of convoluted steps just to barely calculate an approximate result that can be analyzed (this process first appeared in Tencent’s paper "Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling"). And this is already considered simple; SoftSignSGD is even more complex:

  1. Assume the components of \tilde{\boldsymbol{g}}_B are independent, simplifying the problem to the expectation of a single component \tilde{\varphi}_B=\mathop{\mathrm{softsign}}(\tilde{g}_B, \epsilon);

  2. Approximate the \mathop{\mathrm{softsign}} function with a piecewise linear function to calculate the integral;

  3. Assume \tilde{g}_B follows a normal distribution, and combined with the approximation in step 2, calculate \mathbb{E}[\tilde{\varphi}_B], resulting in a complex function containing \mathop{\mathrm{erf}};

  4. Approximate the complex function with a function of the form x/\sqrt{x^2+c} to simplify the result.
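Step 3 can be made concrete if we assume the piecewise-linear surrogate in step 2 is clip(x/ε, -1, 1), linear with slope 1/ε near zero and saturating at ±1; this exact form is an assumption, not taken from the earlier article. Under a Gaussian x ~ N(μ, σ²), its expectation has a closed form built from erf and the Gaussian density, which the sketch below checks against Monte Carlo. As ε → 0 it reduces to erf(μ/(σ√2)), the SignSGD result. Function names are hypothetical.

```python
import math
import numpy as np


def norm_cdf(z):
    """Standard normal CDF Phi(z), written via erf."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def norm_pdf(z):
    """Standard normal density phi(z)."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def expected_clip(mu, sigma, eps):
    """Closed form of E[clip(x / eps, -1, 1)] for x ~ N(mu, sigma^2)."""
    a = (-eps - mu) / sigma     # standardized lower kink
    b = (eps - mu) / sigma      # standardized upper kink
    p_plus = 1.0 - norm_cdf(b)  # P(x > eps), where the surrogate is +1
    p_minus = norm_cdf(a)       # P(x < -eps), where the surrogate is -1
    # E[x ; -eps < x < eps] = mu * (Phi(b) - Phi(a)) + sigma * (phi(a) - phi(b))
    middle = mu * (norm_cdf(b) - norm_cdf(a)) + sigma * (norm_pdf(a) - norm_pdf(b))
    return p_plus - p_minus + middle / eps


rng = np.random.default_rng(0)
mu, sigma, eps = 0.4, 1.0, 0.5
samples = rng.normal(mu, sigma, size=400_000)
print(expected_clip(mu, sigma, eps), np.clip(samples / eps, -1, 1).mean())
print(expected_clip(mu, sigma, 1e-4), math.erf(mu / (sigma * math.sqrt(2.0))))
```

Even with the simplest possible surrogate, the result already mixes erf and exponential terms, which is why a further x/√(x² + c)-style approximation (step 4) is needed before \eta^* can be analyzed.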

The trouble doesn’t end there. After all that effort and all those assumptions, we have only managed to calculate \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]; we still need \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}], which is often even more complex (SignSGD is an exception, because \mathop{\mathrm{sign}}(x)^2 = 1 for every x \neq 0, which makes things simpler). However, the computational burden is secondary; the main issue is that these steps don’t seem to follow any generalizable pattern, so every optimizer seems to require its own case-by-case analysis, which is quite exhausting.
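For SignSGD, the second moment is indeed simple once the independence and normality assumptions from the steps above are in place: diagonal entries are 1, and off-diagonal entries factor into products of the means, so \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}] = \boldsymbol{m}\boldsymbol{m}^{\top} + \mathop{\mathrm{diag}}(1 - m_i^2) with \boldsymbol{m} = \mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]. The sketch below (hypothetical function names and toy values; each component \tilde{g}_{B,i} ~ N(g_i, σ_i²/B)) checks this against Monte Carlo and plugs it into the \eta^* formula.

```python
import math
import numpy as np


def signsgd_moments(g, sigma, B):
    """E[phi] and E[phi phi^T] for phi = sign(g_B), assuming independent
    Gaussian components g_B,i ~ N(g_i, sigma_i^2 / B)."""
    m = np.array([math.erf(gi * math.sqrt(B) / (si * math.sqrt(2.0)))
                  for gi, si in zip(g, sigma)])
    # Off-diagonal entries: E[sign_i * sign_j] = m_i * m_j by independence.
    second = np.outer(m, m)
    # Diagonal entries: sign(x)^2 = 1 whenever x != 0 (probability one here).
    np.fill_diagonal(second, 1.0)
    return m, second


def signsgd_optimal_lr(g, sigma, H, B):
    """eta* = E[phi]^T g / tr(E[phi phi^T] H) under the assumptions above."""
    m, second = signsgd_moments(g, sigma, B)
    return (m @ g) / np.trace(second @ H)


# Hypothetical toy values.
g = np.array([1.0, -0.5, 0.2])
sigma = np.array([3.0, 2.0, 1.0])  # per-sample gradient std of each component
H = np.diag([2.0, 1.0, 0.5])

B = 4
rng = np.random.default_rng(0)
phi = np.sign(rng.normal(g, sigma / math.sqrt(B), size=(200_000, 3)))
m, second = signsgd_moments(g, sigma, B)
print(np.round(m, 3), np.round(phi.mean(axis=0), 3))
print(np.round(second, 3))
print(np.round(phi.T @ phi / len(phi), 3))

for batch in [1, 16, 256]:
    print(batch, signsgd_optimal_lr(g, sigma, H, batch))
```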

To be continued

To avoid making the article too long, this post ends here, primarily reviewing existing analysis results and computational difficulties. In the next article, I will introduce some attempts I made to reduce the mental burden during the derivation process.

When reprinting, please include the address of this article: https://kexue.fm/archives/11260

For more details on reprinting, please refer to: Scientific Space FAQ