English (unofficial) translations of posts at kexue.fm

Asymptotic Estimation of Weight RMS for AdamW

Translated by Gemini Flash 3.0 Preview. Translations may be inaccurate; please refer to the original post for anything important.

In "Why is Adam’s Update RMS 0.2?", we used the mean field approximation to estimate the Update RMS of Adam. Shortly after, reader @EIFY pointed out that the same result had already appeared in the paper "Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks". Upon reading it, I discovered that it contains not only the estimation of Update RMS but also the estimation of Weight RMS.

In other words, for a model trained with AdamW, the RMS of its weights can be estimated as an asymptotic result in advance. Does this conclusion seem surprising? I was quite surprised when I first saw it. Intuitively, the weight magnitude is something the model learns from the training set; the idea that it is already hidden within the optimizer’s hyperparameters is quite counter-intuitive.

In this article, we will again use the mean field approximation method to reproduce the asymptotic estimation of Weight RMS.

Sliding Perspective

First, let’s review the update rules for AdamW: \begin{equation} \text{Adam}\color{cyan}{\text{W}}:=\left\{\begin{aligned} &\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + \left(1 - \beta_1\right) \boldsymbol{g}_t\\ &\boldsymbol{v}_t = \beta_2 \boldsymbol{v}_{t-1} + \left(1 - \beta_2\right) \boldsymbol{g}_t^2\\ &\hat{\boldsymbol{m}}_t = \boldsymbol{m}_t\left/\left(1 - \beta_1^t\right)\right.\\ &\hat{\boldsymbol{v}}_t = \boldsymbol{v}_t\left/\left(1 - \beta_2^t\right)\right.\\ &\boldsymbol{u}_t =\hat{\boldsymbol{m}}_t\left/\left(\sqrt{\hat{\boldsymbol{v}}_t} + \epsilon\right)\right.\\ &\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t \color{cyan}{ + \lambda_t \boldsymbol{\theta}_{t-1}}) \end{aligned}\right. \end{equation} Again, bold symbols denote vectors in \mathbb{R}^d by default, and vector multiplication/division (including squares and square roots) refers to element-wise Hadamard products/quotients.
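
To make the notation concrete, here is a minimal NumPy sketch of a single AdamW step following the update rule above (the function name adamw_step and the default hyperparameter values are just for this sketch, not prescriptions):

import numpy as np

def adamw_step(theta, m, v, g, t, eta=1e-3, beta1=0.9, beta2=0.95,
               eps=1e-8, lam=0.1):
    # One AdamW step; t is the 1-based step counter used for bias correction.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    u = m_hat / (np.sqrt(v_hat) + eps)
    theta = theta - eta * (u + lam * theta)
    return theta, m, v

All operations here are element-wise, matching the Hadamard convention above.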

As in "Why is Adam’s Update RMS 0.2?", we consider t \to \infty (relative to \beta_1, \beta_2) and \epsilon \to 0, so \boldsymbol{u}_t = \boldsymbol{m}_t / \sqrt{\boldsymbol{v}_t}. For now, let’s consider the case where \eta_t and \lambda_t are constants, so their subscripts can be omitted. By defining \beta_3 = 1 - \eta\lambda, we have: \begin{equation} \boldsymbol{\theta}_t = \beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\boldsymbol{u}_t/\lambda) \label{eq:ema-wd} \end{equation} This equation shows that we can understand Weight Decay from the perspective of an Exponential Moving Average (EMA) of the update amounts. This is a meaningful perspective shift and serves as the foundation for works such as "How to set AdamW’s weight decay as you scale model and dataset size" and "Power Lines: Scaling Laws for Weight Decay and Batch Size in LLM Pre-training".

Weighted Average

According to Equation [eq:ema-wd], we can expand \boldsymbol{\theta}_t into a weighted average form: \begin{equation} \boldsymbol{\theta}_t = \beta_3^t\boldsymbol{\theta}_0 + (1-\beta_3)\sum_{i=1}^t \beta_3^{t-i} (-\boldsymbol{u}_i/\lambda) \label{eq:theta-t} \end{equation} Similarly, \boldsymbol{m}_t and \boldsymbol{v}_t can be expanded as: \begin{equation} \boldsymbol{m}_t = (1 - \beta_1)\sum_{i=1}^t \beta_1^{t-i}\boldsymbol{g}_i,\qquad \boldsymbol{v}_t = (1 - \beta_2)\sum_{i=1}^t \beta_2^{t-i}\boldsymbol{g}_i^2 \label{eq:mv-roll} \end{equation} There is a small detail here: we retained \boldsymbol{\theta}_0 in the expression for \boldsymbol{\theta}_t, but we did not retain \boldsymbol{m}_0 and \boldsymbol{v}_0 for \boldsymbol{m}_t and \boldsymbol{v}_t. This is for two reasons: 1. \boldsymbol{m} and \boldsymbol{v} are typically initialized to zero; 2. Even if they were not zero, the corresponding \beta_1^t and \beta_2^t would become close enough to zero that the influence of initialization can be ignored.

However, \boldsymbol{\theta} represents the model weights, and its initialization is usually not zero. Furthermore, \beta_3 is often very close to 1, so over a full training run \beta_3^t is not necessarily close to zero. Therefore, we keep the \beta_3^t\boldsymbol{\theta}_0 term explicit and decide whether to drop it depending on the situation.
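
As a sanity check on Equation [eq:theta-t], one can iterate the recursion numerically and compare it with the closed-form expansion (the update vectors u_i below are just random stand-ins):

import numpy as np

eta, lam, t = 1e-3, 0.1, 500
beta3 = 1 - eta * lam
theta0 = np.random.randn(4)
us = [np.random.randn(4) for _ in range(t)]

theta = theta0.copy()
for u in us:     # iterate theta_i = beta3 * theta_{i-1} + (1 - beta3) * (-u_i / lam)
    theta = beta3 * theta + (1 - beta3) * (-u / lam)

closed = beta3**t * theta0 + (1 - beta3) * sum(
    beta3**(t - i) * (-us[i - 1] / lam) for i in range(1, t + 1))
print(np.allclose(theta, closed))   # True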

Fast Estimation

Our task is to estimate the Weight RMS, denoted as \|\boldsymbol{\theta}_t\|_{RMS}. As the name suggests, it is the Root Mean Square of the individual components: \begin{equation} \|\boldsymbol{\theta}\|_{RMS} = \sqrt{\frac{1}{d}\sum_{i=1}^d \theta_i^2},\qquad\qquad \text{where } \boldsymbol{\theta} = (\theta_1,\theta_2,\cdots,\theta_d) \end{equation} The difference between it and the norm is just the division by \sqrt{d}, so most properties of the norm apply to the RMS as well. For \|\boldsymbol{\theta}_t\|_{RMS}, there is a fast but not entirely accurate derivation: by taking the \|\cdot\|_{RMS}^2 of both sides of Equation [eq:ema-wd], we get: \begin{equation} \begin{aligned} \|\boldsymbol{\theta}_t\|_{RMS}^2 =&\, \|\beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\boldsymbol{u}_t/\lambda)\|_{RMS}^2 \\[5pt] =&\, \beta_3^2\|\boldsymbol{\theta}_{t-1}\|_{RMS}^2 + (1-\beta_3)^2\|\boldsymbol{u}_t\|_{RMS}^2/\lambda^2 - 2\beta_3(1-\beta_3)\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t/(\lambda d) \end{aligned} \end{equation} Assuming \boldsymbol{\theta}_{t-1} and \boldsymbol{u}_t are nearly orthogonal, then \boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t \approx 0. This is usually a good approximation in high-dimensional spaces (refer to "Distribution of the angle between two random vectors in n-dimensional space"). We have already calculated \|\boldsymbol{u}_t\|_{RMS}, which is approximately \sqrt{\frac{1-\beta_1}{1+\beta_1}}. Finally, considering the steady-state result where \|\boldsymbol{\theta}_t\|_{RMS}^2 = \|\boldsymbol{\theta}_{t-1}\|_{RMS}^2, we have: \begin{equation} (1-\beta_3^2)\|\boldsymbol{\theta}_t\|_{RMS}^2 \approx (1-\beta_3)^2 \frac{1-\beta_1}{1+\beta_1} /\lambda^2 \qquad\Rightarrow\qquad \|\boldsymbol{\theta}_t\|_{RMS} \approx \sqrt{\frac{1-\beta_1}{1+\beta_1}\frac{\eta}{2\lambda}} \end{equation} The approximation \beta_3 \approx 1 was used in the transition from the left to the right expression. The final result will have some error because \boldsymbol{\theta}_t\cdot\boldsymbol{u}_t \approx 0 is not strictly true, but the conclusion \|\boldsymbol{\theta}_t\|_{RMS} \propto \sqrt{\eta/\lambda} is correct. A similar derivation also appears in "Why Gradients Rapidly Increase Near the End of Training".
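
Plugging in some illustrative hyperparameters (chosen only for the example) gives a feel for the magnitude of this rough estimate:

import numpy as np

eta, lam, beta1 = 1e-3, 0.1, 0.9
rough = np.sqrt((1 - beta1) / (1 + beta1) * eta / (2 * lam))
print(rough)    # about 0.0162

As noted above, only the proportionality to \sqrt{\eta/\lambda} should be trusted here; the prefactor is refined in the next section.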

Better Approximation

In many cases, knowing \|\boldsymbol{\theta}_t\|_{RMS} \propto \sqrt{\eta/\lambda} is sufficient; it is a fairly general conclusion. For readers seeking a more accurate conclusion, we can use the mean field method to obtain a better approximation. The cost is a more complex calculation process, but the benefit is a clearer and more detailed understanding.

Step One

Starting from Equation [eq:theta-t], the summation term itself has the form of a weighted average. We first apply the mean field approximation: \begin{equation} \underbrace{\frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i} \boldsymbol{u}_i}_{\text{denoted as } \bar{\boldsymbol{u}}_t} = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i} \frac{\hat{\boldsymbol{m}}_i}{\sqrt{\hat{\boldsymbol{v}}_i}} \approx \frac{\bar{\boldsymbol{m}}_t \,\,\triangleq\,\, \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{m}}_i}{\sqrt{\bar{\boldsymbol{v}}_t \,\,\triangleq\,\, \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{v}}_i}} \label{eq:u-bar} \end{equation} Now returning to Equation [eq:theta-t], since \boldsymbol{\theta}_0 is a random initialization vector, we can assume \boldsymbol{\theta}_0 is orthogonal to \bar{\boldsymbol{u}}_t, thus: \begin{equation} \|\boldsymbol{\theta}_t\|_{RMS}^2 \approx \beta_3^{2t}\|\boldsymbol{\theta}_0\|_{RMS}^2 + (1-\beta_3^t)^2 \lambda^{-2}\| \bar{\boldsymbol{u}}_t\|_{RMS}^2 \end{equation} Now we need to find \| \bar{\boldsymbol{u}}_t\|_{RMS}^2. Based on previous experience, we assume \boldsymbol{g}_j are i.i.d. following \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2), and then calculate: \begin{equation} \mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \mathbb{E}\left[\frac{\bar{\boldsymbol{m}}_t^2}{\bar{\boldsymbol{v}}_t}\right] \approx \frac{\mathbb{E}[\bar{\boldsymbol{m}}_t^2]}{\mathbb{E}[\bar{\boldsymbol{v}}_t]} \end{equation} Finally, averaging over the components of \mathbb{E}[\bar{\boldsymbol{u}}_t^2] serves as an approximation for \| \bar{\boldsymbol{u}}_t\|_{RMS}^2.

Step Two

Combining with Equation [eq:mv-roll], we get: \begin{gather} \sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{m}}_i = (1 - \beta_1)\sum_{i=1}^t \beta_3^{t-i} \sum_{j=1}^i \beta_1^{i-j}\boldsymbol{g}_j = (1 - \beta_1)\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\boldsymbol{g}_j\\ \sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{v}}_i = (1 - \beta_2)\sum_{i=1}^t \beta_3^{t-i} \sum_{j=1}^i \beta_2^{i-j}\boldsymbol{g}_j^2 = (1 - \beta_2)\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}\boldsymbol{g}_j^2 \end{gather} The simplification of the double summation can be handled by tools like Kimi (refer to link). As seen above, \bar{\boldsymbol{m}}_t and \bar{\boldsymbol{v}}_t are weighted averages of the gradient and the squared gradient, respectively. Thus, calculating \| \bar{\boldsymbol{u}}_t\|_{RMS}^2 is essentially the same as calculating \| \boldsymbol{u}_t\|_{RMS}^2 in "Why is Adam’s Update RMS 0.2?", just with different weighting coefficients.
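
The inner geometric sum can also be spot-checked numerically (the values of \beta_1, \beta_3, t, j below are arbitrary):

beta1, beta3 = 0.9, 0.9999
t, j = 50, 7
lhs = sum(beta3**(t - i) * beta1**(i - j) for i in range(j, t + 1))
rhs = (beta3**(t - j + 1) - beta1**(t - j + 1)) / (beta3 - beta1)
print(lhs, rhs)   # the two values agree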

Step Three

First, we calculate the denominator: \begin{equation} \begin{aligned} \mathbb{E}[\bar{\boldsymbol{v}}_t] =&\, \frac{(1 - \beta_3)(1 - \beta_2)}{1 - \beta_3^t}\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}\mathbb{E}[\boldsymbol{g}_j^2] \\ =&\, \frac{(1 - \beta_3)(1 - \beta_2)}{1 - \beta_3^t}\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}(\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2) \\ =&\, \frac{(1 - \beta_3)(1 - \beta_2)}{(1 - \beta_3^t)(\beta_3 - \beta_2)}\left(\frac{\beta_3 - \beta_3^{t+1}}{1 - \beta_3} - \frac{\beta_2 - \beta_2^{t+1}}{1 - \beta_2}\right)(\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2) \\[5pt] \approx &\, \boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2 \end{aligned} \end{equation} The final approximation is because in actual training, \beta_3 is close enough to 1, and \beta_2^{t+1} is close enough to 0, while \beta_3^{t+1} might not be. Thus, we replace \beta_2^{t+1} with zero, replace independent \beta_3 terms with 1 after simplification, and use the approximation \beta_3^{t+1} \approx \beta_3^t.
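
To see how good this approximation is, the exact coefficient in front of \boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2 can be evaluated directly (the particular \beta_2, \beta_3, t below are illustrative):

beta2, beta3, t = 0.95, 0.9999, 100000
coef = (1 - beta3) * (1 - beta2) / ((1 - beta3**t) * (beta3 - beta2)) * (
    (beta3 - beta3**(t + 1)) / (1 - beta3) - (beta2 - beta2**(t + 1)) / (1 - beta2))
print(coef)   # very close to 1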

Step Four

Next is \mathbb{E}[\bar{\boldsymbol{m}}_t^2] = \mathbb{E}[\bar{\boldsymbol{m}}_t]^2 + \mathbb{V}ar[\bar{\boldsymbol{m}}_t]. The calculation of \mathbb{E}[\bar{\boldsymbol{m}}_t] is similar to \mathbb{E}[\bar{\boldsymbol{v}}_t], resulting in \boldsymbol{\mu}. For \mathbb{V}ar[\bar{\boldsymbol{m}}_t], we use the additivity of variance: \begin{equation} \begin{aligned} \mathbb{V}ar[\bar{\boldsymbol{m}}_t] =&\, \frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2}\sum_{j=1}^t \left(\frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\right)^2\mathbb{V}ar[\boldsymbol{g}_j] \\ =&\, \frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2}\sum_{j=1}^t \left(\frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\right)^2 \boldsymbol{\sigma}^2 \\ =&\, \frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2(\beta_3 - \beta_1)^2}\left(\frac{\beta_3^2 - \beta_3^{2(t+1)}}{1 - \beta_3^2} + \frac{\beta_1^2 - \beta_1^{2(t+1)}}{1 - \beta_1^2} - 2\frac{\beta_1\beta_3 - \beta_1^{t+1}\beta_3^{t+1}}{1 - \beta_1\beta_3}\right) \boldsymbol{\sigma}^2 \\[5pt] \approx &\, (1 - \beta_3)(1 + \beta_3^t)\boldsymbol{\sigma}^2/2(1 - \beta_3^t) \end{aligned} \end{equation} The reasoning for the approximation is the same as above.
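
The same kind of spot check works here: the exact coefficient multiplying \boldsymbol{\sigma}^2 can be compared with the approximation on the last line, again for illustrative \beta_1, \beta_3, t:

beta1, beta3, t = 0.9, 0.9999, 100000
pref = (1 - beta3)**2 * (1 - beta1)**2 / ((1 - beta3**t)**2 * (beta3 - beta1)**2)
exact = pref * ((beta3**2 - beta3**(2 * (t + 1))) / (1 - beta3**2)
                + (beta1**2 - beta1**(2 * (t + 1))) / (1 - beta1**2)
                - 2 * (beta1 * beta3 - (beta1 * beta3)**(t + 1)) / (1 - beta1 * beta3))
approx = (1 - beta3) * (1 + beta3**t) / (2 * (1 - beta3**t))
print(exact, approx)   # the two agree to within a fraction of a percent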

Step Five

Substituting the results from the previous two sections, we have: \begin{equation} \mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \frac{\boldsymbol{\mu}^2 + (1 - \beta_3)(1 + \beta_3^t)\boldsymbol{\sigma}^2/2(1 - \beta_3^t)}{\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2} \end{equation} Then: \begin{equation} \|\bar{\boldsymbol{u}}_t\|_{RMS}^2 \approx \frac{\|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2 + (1 - \beta_3)(1 + \beta_3^t)/2(1 - \beta_3^t)}{\|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2 + 1} \end{equation} Finally, we obtain: \begin{equation} \|\boldsymbol{\theta}_t\|_{RMS}^2 \approx \beta_3^{2t}\|\boldsymbol{\theta}_0\|_{RMS}^2 + (1-\beta_3^t)^2 \frac{\|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2 + (1 - \beta_3)(1 + \beta_3^t)/2(1 - \beta_3^t)}{\lambda^2(\|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2 + 1)} \label{eq:theta-rms} \end{equation}
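
For convenience, Equation [eq:theta-rms] can be wrapped in a small helper (the name predicted_weight_rms and its argument names are just for this sketch; snr denotes \|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2):

import numpy as np

def predicted_weight_rms(t, eta, lam, theta0_rms, snr=0.0):
    # Evaluate the estimate above; assumes t >= 1.
    beta3 = 1 - eta * lam
    num = snr + (1 - beta3) * (1 + beta3**t) / (2 * (1 - beta3**t))
    rms2 = beta3**(2 * t) * theta0_rms**2 + (1 - beta3**t)**2 * num / (lam**2 * (snr + 1))
    return np.sqrt(rms2)

For example, predicted_weight_rms(100000, 0.001, 0.1, 0.1) returns roughly 0.0707, the same value as \sqrt{\eta/2\lambda} for these settings.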

Analysis of Results

Equation [eq:theta-rms] looks complicated, so let’s examine a few special cases. First, consider the case where \boldsymbol{\mu}=\boldsymbol{0}: \begin{equation} \|\boldsymbol{\theta}_t\|_{RMS}^2 \approx \beta_3^{2t}\|\boldsymbol{\theta}_0\|_{RMS}^2 + (1-\beta_3^{2t}) (1 - \beta_3)/2\lambda^2 = \beta_3^{2t}\|\boldsymbol{\theta}_0\|_{RMS}^2 + (1-\beta_3^{2t}) \eta/2\lambda \label{eq:theta-rms-mu0} \end{equation} Specifically, if we consider t \to \infty, or if \|\boldsymbol{\theta}_0\|_{RMS}^2 is initialized to \eta/2\lambda, then: \begin{equation} \|\boldsymbol{\theta}_t\|_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}} \label{eq:theta-rms-simple} \end{equation} This is the result given in the paper "Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks". Consistent with the original paper’s assumptions, it is the steady-state result of a random walk with zero mean. If we do not consider t \to \infty but instead consider the limit \lambda \to 0, then from Equation [eq:theta-rms-mu0] we get: \begin{equation} \|\boldsymbol{\theta}_t\|_{RMS}^2 \approx \|\boldsymbol{\theta}_0\|_{RMS}^2 + \eta^2 t \end{equation} This indicates that without Weight Decay, \|\boldsymbol{\theta}_t\|_{RMS} grows roughly at a rate of \eta\sqrt{t}. This also suggests that in the absence of Weight Decay, we can achieve Weight RMS stability by setting a specific learning rate schedule. On the other hand, if the Batch Size is large enough such that the signal-to-noise ratio term \|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2 dominates, then from Equation [eq:theta-rms]: \begin{equation} \|\boldsymbol{\theta}_t\|_{RMS}^2 \approx \beta_3^{2t}\|\boldsymbol{\theta}_0\|_{RMS}^2 + (1-\beta_3^t)^2 \frac{\|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2}{\lambda^2(\|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2 + 1)} \end{equation} This might apply to special cases where the model needs to actively increase Weight RMS. However, from experience, the probability of this happening is generally small.
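
To get a feel for how quickly the \boldsymbol{\mu}=\boldsymbol{0} estimate approaches its steady state, Equation [eq:theta-rms-mu0] can be evaluated at a few step counts (the hyperparameters are illustrative and match the simulation below):

import numpy as np

eta, lam, theta0_rms = 0.001, 0.1, 0.1
beta3 = 1 - eta * lam
for t in [0, 10000, 30000, 100000]:
    rms = np.sqrt(beta3**(2 * t) * theta0_rms**2 + (1 - beta3**(2 * t)) * eta / (2 * lam))
    print(t, rms)   # decays from 0.1 towards sqrt(eta / (2 * lam)) = 0.0707...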

Simulation Experiment

We can use the following simulation script to run a simple check on the accuracy of the estimates above:

import numpy as np

N, T = 10000, 100000            # number of parameters and training steps
beta1, beta2 = 0.9, 0.95        # Adam's momentum coefficients
eta, lam = 0.001, 0.1           # learning rate eta and Weight Decay lambda
m, v = 0, 0
w = np.random.randn(N) * 0.1    # initial Weight RMS is about 0.1
for i in range(T):
    g = np.random.randn(N)      # i.i.d. standard normal gradients (mu = 0, sigma = 1)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    w = w - eta * (m / v**0.5 + lam * w)

weight_rms = (w**2).mean()**0.5
print(weight_rms)               # close to sqrt(eta / (2 * lam)) = 0.0707...

You can try changing the weight initialization or the mean and variance of the gradients to see how well the final result matches Equation [eq:theta-rms]. I tried it myself, and overall it is quite reliable.
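
For instance, here is a variant with a nonzero gradient mean (the value of mu is arbitrary); the measured Weight RMS can be compared with Equation [eq:theta-rms] using snr = \|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2 = mu^2:

import numpy as np

N, T = 10000, 100000
beta1, beta2 = 0.9, 0.95
eta, lam, mu = 0.001, 0.1, 0.3
m, v = 0, 0
w = np.random.randn(N) * 0.1
for i in range(T):
    g = mu + np.random.randn(N)         # gradients with mean mu and unit variance
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    w = w - eta * (m / v**0.5 + lam * w)

beta3 = 1 - eta * lam
snr = mu**2                             # ||mu||^2 / ||sigma||^2 with sigma = 1
num = snr + (1 - beta3) * (1 + beta3**T) / (2 * (1 - beta3**T))
pred = (beta3**(2 * T) * 0.1**2 + (1 - beta3**T)**2 * num / (lam**2 * (snr + 1)))**0.5
print((w**2).mean()**0.5, pred)         # the two should be roughly equal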

Sign Version

By adjusting the previous proof slightly, it can be applied to the combination of "SignSGDM + Weight Decay": \begin{equation} \text{SignSGDM}\color{cyan}{\text{W}}:=\left\{\begin{aligned} &\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + \left(1 - \beta_1\right) \boldsymbol{g}_t\\ &\boldsymbol{u}_t = \operatorname{sign}(\boldsymbol{m}_t)\\ &\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t \color{cyan}{ + \lambda_t \boldsymbol{\theta}_{t-1}}) \end{aligned}\right. \end{equation} The only modification needed comes from \operatorname{sign}(\boldsymbol{m}_t) = \boldsymbol{m}_t / \sqrt{\boldsymbol{m}_t^2}, so the definition of \bar{\boldsymbol{v}}_t should be changed to: \begin{equation} \bar{\boldsymbol{v}}_t \triangleq \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\boldsymbol{m}_i^2 \end{equation} Then: \begin{equation} \mathbb{E}[\bar{\boldsymbol{v}}_t] = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\mathbb{E}[\boldsymbol{m}_i^2] \approx \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\left(\boldsymbol{\mu}^2 + \frac{1-\beta_1}{1 + \beta_1}\boldsymbol{\sigma}^2\right) = \boldsymbol{\mu}^2 + \frac{1-\beta_1}{1 + \beta_1}\boldsymbol{\sigma}^2 \end{equation} For the calculation of \mathbb{E}[\boldsymbol{m}_i^2], refer to "Why is Adam’s Update RMS 0.2?" or "Rethinking Learning Rate and Batch Size (IV): EMA". Using these results, we get: \begin{equation} \|\boldsymbol{\theta}_t\|_{RMS}^2 \approx \beta_3^{2t}\|\boldsymbol{\theta}_0\|_{RMS}^2 + (1-\beta_3^t)^2 \frac{\|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2 + (1 - \beta_3)(1 + \beta_3^t)/2(1 - \beta_3^t)}{\lambda^2\left(\|\boldsymbol{\mu}\|^2/\|\boldsymbol{\sigma}\|^2 + \frac{1-\beta_1}{1 + \beta_1}\right)} \end{equation} In particular, in the limit \boldsymbol{\mu}=\boldsymbol{0}, t \to \infty, we have: \begin{equation} \|\boldsymbol{\theta}_t\|_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}\frac{1+\beta_1}{1 - \beta_1}} \end{equation} This result is also reasonable: the Update RMS of SignSGDMW is \sqrt{\frac{1+\beta_1}{1 - \beta_1}} times that of AdamW, so for the same \eta, \lambda, its Weight RMS is also \sqrt{\frac{1+\beta_1}{1 - \beta_1}} times larger.

Summary

In this article, we used the mean field approximation to derive an interesting and perhaps surprising conclusion: the RMS of the weights of a model trained with AdamW can be estimated asymptotically. In typical cases, it depends only on the learning rate and the Weight Decay coefficient.

When reposting, please include the original address: https://kexue.fm/archives/11307

For more details on reposting, please refer to: "Scientific Space FAQ"