In the blog post "Asymptotic Estimation of Weight RMS of AdamW (Part 1)", we derived the asymptotic expression for the RMS of model weights trained by AdamW. However, at that time, we assumed that the Weight Decay and learning rate were fixed throughout the training process, which does not perfectly align with actual training. Therefore, in this article, we will generalize the previous conclusions into a dynamic version.
The so-called dynamic version allows both Weight Decay and the learning rate to change as the number of training steps increases, such as the classic Cosine Decay, WSD (Warmup Stable Decay), etc., thereby making the conclusions more general.
Step 1
Our starting point is still the definition of AdamW: \begin{equation} \text{Adam}\textcolor{skyblue}{\text{W}} := \left\{ \begin{aligned} &\bm{m}_t = \beta_1 \bm{m}_{t-1} + (1 - \beta_1) \bm{g}_t \\ &\bm{v}_t = \beta_2 \bm{v}_{t-1} + (1 - \beta_2) \bm{g}_t^2 \\ &\hat{\bm{m}}_t = \bm{m}_t / (1 - \beta_1^t) \\ &\hat{\bm{v}}_t = \bm{v}_t / (1 - \beta_2^t) \\ &\bm{u}_t = \hat{\bm{m}}_t / (\sqrt{\hat{\bm{v}}_t} + \epsilon) \\ &\bm{\theta}_t = \bm{\theta}_{t-1} - \eta_t (\bm{u}_t \textcolor{skyblue}{+ \lambda_t \bm{\theta}_{t-1}}) \end{aligned} \right. \end{equation}
Since \eta_t \lambda_t \ll 1, we can write: \begin{equation} \bm{\theta}_t = (1 - \eta_t \lambda_t) \bm{\theta}_{t-1} - \eta_t \bm{u}_t \approx e^{-\eta_t \lambda_t} \bm{\theta}_{t-1} - \eta_t \bm{u}_t \end{equation} Let \kappa_t = \sum_{i=1}^t \eta_i \lambda_i. Unrolling the recursion gives: \begin{equation} \bm{\theta}_t \approx e^{-\kappa_t} \bm{\theta}_0 - \sum_{i=1}^t e^{-(\kappa_t - \kappa_i)} \eta_i \bm{u}_i = e^{-\kappa_t} \left( \bm{\theta}_0 - \sum_{i=1}^t e^{\kappa_i} \eta_i \bm{u}_i \right) \end{equation}
Next, let z_t = \sum_{i=1}^t e^{\kappa_i} \eta_i, and define the weighted averages \begin{equation} \bar{\bm{m}}_t \triangleq \frac{1}{z_t} \sum_{i=1}^t e^{\kappa_i} \eta_i \bm{m}_i, \qquad \bar{\bm{v}}_t \triangleq \frac{1}{z_t} \sum_{i=1}^t e^{\kappa_i} \eta_i \bm{v}_i \end{equation} Ignoring \epsilon and the bias corrections (which only matter for small t), the mean field approximation gives: \begin{equation} \bar{\bm{u}}_t \triangleq \frac{1}{z_t} \sum_{i=1}^t e^{\kappa_i} \eta_i \bm{u}_i \approx \frac{1}{z_t} \sum_{i=1}^t e^{\kappa_i} \eta_i \frac{\bm{m}_i}{\sqrt{\bm{v}_i}} \approx \frac{\bar{\bm{m}}_t}{\sqrt{\bar{\bm{v}}_t}} \end{equation} Since \bm{\theta}_0 is randomly initialized, it can be treated as approximately orthogonal to \bar{\bm{u}}_t, so the cross term drops out: \begin{equation} \|\bm{\theta}_t\|_{RMS}^2 \approx e^{-2\kappa_t} \|\bm{\theta}_0 - z_t \bar{\bm{u}}_t\|_{RMS}^2 \approx e^{-2\kappa_t} \|\bm{\theta}_0\|_{RMS}^2 + e^{-2\kappa_t} z_t^2 \|\bar{\bm{u}}_t\|_{RMS}^2 \end{equation}
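As a quick sanity check of the unrolled recursion, the sketch below (an illustration, not part of the original derivation) compares the exact update \bm{\theta}_t = (1 - \eta_t \lambda_t) \bm{\theta}_{t-1} - \eta_t \bm{u}_t with the closed form e^{-\kappa_t}(\bm{\theta}_0 - \sum_i e^{\kappa_i} \eta_i \bm{u}_i). The cosine-shaped \eta_t, the constant \lambda_t = 0.1, and the random sign vectors standing in for \bm{u}_t are arbitrary assumptions; the only thing that matters is \eta_t \lambda_t \ll 1.

```python
import numpy as np

# Hypothetical schedules and stand-in updates, purely for illustration.
rng = np.random.default_rng(0)
T, N = 2000, 1000
eta = 1e-3 * (1 + np.cos(np.arange(1, T + 1) / T * np.pi)) / 2 + 1e-4  # eta_1..eta_T
lam = np.full(T, 0.1)                                                    # lambda_1..lambda_T
u = np.sign(rng.standard_normal((T, N)))                                 # stand-in for u_t
theta0 = 0.1 * rng.standard_normal(N)

# Exact recursion: theta_t = (1 - eta_t * lambda_t) * theta_{t-1} - eta_t * u_t
theta = theta0.copy()
for t in range(T):
    theta = (1 - eta[t] * lam[t]) * theta - eta[t] * u[t]

# Closed form: theta_T ~= e^{-kappa_T} (theta_0 - sum_i e^{kappa_i} eta_i u_i)
kappa = np.cumsum(eta * lam)
theta_closed = np.exp(-kappa[-1]) * (theta0 - (np.exp(kappa) * eta) @ u)

rel_err = np.linalg.norm(theta - theta_closed) / np.linalg.norm(theta)
print(f"relative error: {rel_err:.2e}")
```

The only approximation here is 1 - x \approx e^{-x}, so the relative error should be tiny whenever \eta_t \lambda_t is small.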
Step 2
Following the previous logic, to estimate \|\bar{\bm{u}}_t\|_{RMS}^2 we assume that the \bm{g}_j are independent and identically distributed as \mathcal{N}(\bm{\mu}, \bm{\sigma}^2), and then compute: \begin{equation} \mathbb{E}[\bar{\bm{u}}_t^2] \approx \mathbb{E}\left[ \frac{\bar{\bm{m}}_t^2}{\bar{\bm{v}}_t} \right] \approx \frac{\mathbb{E}[\bar{\bm{m}}_t^2]}{\mathbb{E}[\bar{\bm{v}}_t]} \end{equation} Finally, averaging the components of \mathbb{E}[\bar{\bm{u}}_t^2] gives an approximation of \|\bar{\bm{u}}_t\|_{RMS}^2.
Expanding \bm{m}_t and \bm{v}_t: \begin{equation} \bm{m}_t = (1 - \beta_1) \sum_{i=1}^t \beta_1^{t-i} \bm{g}_i, \qquad \bm{v}_t = (1 - \beta_2) \sum_{i=1}^t \beta_2^{t-i} \bm{g}_i^2 \label{eq:mv-roll} \end{equation} We also have the identity: \begin{equation} \sum_{i=1}^t \sum_{j=1}^i a_i b_j = \sum_{j=1}^t \sum_{i=j}^t a_i b_j \end{equation} Using these results, we can write: \begin{gather} \bar{\bm{m}}_t = \frac{1}{z_t} \sum_{i=1}^t e^{\kappa_i} \eta_i \bm{m}_i = \frac{1 - \beta_1}{z_t} \sum_{i=1}^t e^{\kappa_i} \eta_i \sum_{j=1}^i \beta_1^{i-j} \bm{g}_j = \sum_{j=1}^t \bm{g}_j \underbrace{\frac{1 - \beta_1}{z_t} \sum_{i=j}^t e^{\kappa_i} \beta_1^{i-j} \eta_i}_{\text{denoted as } \bar{\beta}_1(j,t)} \\ \bar{\bm{v}}_t = \frac{1}{z_t} \sum_{i=1}^t e^{\kappa_i} \eta_i \bm{v}_i = \frac{1 - \beta_2}{z_t} \sum_{i=1}^t e^{\kappa_i} \eta_i \sum_{j=1}^i \beta_2^{i-j} \bm{g}_j^2 = \sum_{j=1}^t \bm{g}_j^2 \underbrace{\frac{1 - \beta_2}{z_t} \sum_{i=j}^t e^{\kappa_i} \beta_2^{i-j} \eta_i}_{\text{denoted as } \bar{\beta}_2(j,t)} \end{gather}
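The summation swap can be checked numerically. This sketch assumes scalar Gaussian gradients and a linearly decaying \eta_t (both hypothetical choices) and computes \bar{\bm{m}}_t twice: once as the weighted average of the EMA \bm{m}_i, and once as \sum_j \bar{\beta}_1(j,t) \bm{g}_j. The inner sum of \bar{\beta}_1(j,t) is evaluated with the backward recursion S_j = e^{\kappa_j} \eta_j + \beta_1 S_{j+1}, which is an implementation choice of this sketch.

```python
import numpy as np

# Illustrative setup (assumed): scalar gradients, linear learning-rate decay.
rng = np.random.default_rng(1)
T, beta1, lam = 500, 0.9, 0.1
eta = np.linspace(1e-3, 1e-4, T)   # eta_1..eta_T
kappa = lam * np.cumsum(eta)       # kappa_i
w = np.exp(kappa) * eta            # weights e^{kappa_i} eta_i
z = w.sum()                        # z_T
g = rng.standard_normal(T)         # g_1..g_T

# Left side: weighted average of the EMA m_i
m, m_hist = 0.0, np.empty(T)
for i in range(T):
    m = beta1 * m + (1 - beta1) * g[i]
    m_hist[i] = m
m_bar_direct = (w * m_hist).sum() / z

# Right side: sum_j g_j * beta1_bar(j, T), where
# beta1_bar(j, T) = (1 - beta1) / z * sum_{i=j}^{T} e^{kappa_i} beta1^{i-j} eta_i
S, acc = np.empty(T), 0.0
for j in range(T - 1, -1, -1):
    acc = w[j] + beta1 * acc       # S_j = w_j + beta1 * S_{j+1}
    S[j] = acc
beta1_bar = (1 - beta1) / z * S
m_bar_swapped = (g * beta1_bar).sum()

print(m_bar_direct, m_bar_swapped)
```

Both expressions are algebraically identical, so they should agree up to floating-point rounding.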
Step 3
First, consider the denominator. Swapping the order of summation gives \sum_{j=1}^t \bar{\beta}_1(j,t) = \frac{1}{z_t} \sum_{i=1}^t e^{\kappa_i} \eta_i (1 - \beta_1^i), which is a weighted average of 1 - \beta_1^i, and likewise for \bar{\beta}_2(j,t). Hence, when t is sufficiently large (so that \beta_1^t, \beta_2^t are negligible), both \sum_{j=1}^t \bar{\beta}_1(j,t) and \sum_{j=1}^t \bar{\beta}_2(j,t) are close to 1, and: \begin{equation} \mathbb{E}[\bar{\bm{v}}_t] = \sum_{j=1}^t \bar{\beta}_2(j,t) \mathbb{E}[\bm{g}_j^2] = \sum_{j=1}^t \bar{\beta}_2(j,t) (\bm{\mu}^2 + \bm{\sigma}^2) \approx \bm{\mu}^2 + \bm{\sigma}^2 \end{equation} Similarly, \mathbb{E}[\bar{\bm{m}}_t] \approx \bm{\mu}, and \mathbb{E}[\bar{\bm{m}}_t^2] = \mathbb{E}[\bar{\bm{m}}_t]^2 + \operatorname{Var}[\bar{\bm{m}}_t]. Since the \bm{g}_j are independent, their variances add: \begin{equation} \operatorname{Var}[\bar{\bm{m}}_t] = \sum_{j=1}^t \bar{\beta}_1(j,t)^2 \operatorname{Var}[\bm{g}_j] = \sum_{j=1}^t \bar{\beta}_1(j,t)^2 \bm{\sigma}^2 \end{equation} So: \begin{equation} \mathbb{E}[\bar{\bm{u}}_t^2] \approx \frac{\bm{\mu}^2 + \bm{\sigma}^2 \sum_{j=1}^t \bar{\beta}_1(j,t)^2}{\bm{\mu}^2 + \bm{\sigma}^2} \end{equation} Averaging over components: \begin{equation} \|\bar{\bm{u}}_t\|_{RMS}^2 \approx \frac{\|\bm{\mu}\|^2/\|\bm{\sigma}\|^2 + \sum_{j=1}^t \bar{\beta}_1(j,t)^2}{\|\bm{\mu}\|^2/\|\bm{\sigma}\|^2 + 1} \end{equation} Finally: \begin{equation} \|\bm{\theta}_t\|_{RMS}^2 \approx e^{-2\kappa_t} \|\bm{\theta}_0\|_{RMS}^2 + e^{-2\kappa_t} z_t^2 \frac{\|\bm{\mu}\|^2/\|\bm{\sigma}\|^2 + \sum_{j=1}^t \bar{\beta}_1(j,t)^2}{\|\bm{\mu}\|^2/\|\bm{\sigma}\|^2 + 1} \end{equation}
Readers coming to this article directly may find some of these steps abrupt. In that case, it is worth revisiting "Asymptotic Estimation of Weight RMS of AdamW (Part 1)" to see the reasoning behind each approximation.
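A minimal sketch of the Step 3 estimate under assumed settings: a cosine schedule from 10^{-3} to 10^{-4} over 10^4 steps, \lambda = 0.1, \beta_1 = 0.9, \|\bm{\theta}_0\|_{RMS} = 0.1, and a hypothetical parameter snr_sq standing for \|\bm{\mu}\|^2/\|\bm{\sigma}\|^2. It also checks that \sum_j \bar{\beta}_1(j,t) is close to 1 for large t. The helper names beta_bar and weight_rms_sq are illustrative, not from any library.

```python
import numpy as np

def beta_bar(beta, eta, lam):
    """Return (beta_bar(j, t) for j = 1..t, kappa, z_t) with t = len(eta); lam is a scalar or array."""
    kappa = np.cumsum(lam * eta)
    w = np.exp(kappa) * eta
    S, acc = np.zeros_like(w), 0.0
    for j in range(len(w) - 1, -1, -1):
        acc = w[j] + beta * acc            # S_j = w_j + beta * S_{j+1}
        S[j] = acc
    return (1 - beta) / w.sum() * S, kappa, w.sum()

def weight_rms_sq(eta, lam, beta1, init_rms, snr_sq):
    """Step 3 estimate of ||theta_t||_RMS^2; snr_sq stands for ||mu||^2 / ||sigma||^2."""
    b1, kappa, z = beta_bar(beta1, eta, lam)
    u_sq = (snr_sq + (b1 ** 2).sum()) / (snr_sq + 1)
    return np.exp(-2 * kappa[-1]) * (init_rms ** 2 + z ** 2 * u_sq)

T = 10000
eta = 1e-4 + 9e-4 * (1 + np.cos(np.arange(T) / T * np.pi)) / 2   # assumed cosine schedule
b1, _, _ = beta_bar(0.9, eta, 0.1)
print("sum of beta1_bar:", b1.sum())
print("estimated weight RMS (mu = 0):", weight_rms_sq(eta, 0.1, 0.9, 0.1, 0.0) ** 0.5)
```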
Example 1
First, consider \bm{\mu} = \bm{0}. Substituting the expression for \bar{\beta}_1(j,t) into the above equation gives: \begin{equation} \|\bm{\theta}_t\|_{RMS}^2 \approx e^{-2\kappa_t} \|\bm{\theta}_0\|_{RMS}^2 + e^{-2\kappa_t} (1 - \beta_1)^2 \sum_{j=1}^t \left( \sum_{i=j}^t e^{\kappa_i} \beta_1^{i-j} \eta_i \right)^2 \label{eq:w-rms-mu0} \end{equation} Consider the simple case where \lambda_t = 0 (no Weight Decay). Then: \begin{equation} \|\bm{\theta}_t\|_{RMS}^2 \approx \|\bm{\theta}_0\|_{RMS}^2 + (1 - \beta_1)^2 \sum_{j=1}^t \left( \sum_{i=j}^t \beta_1^{i-j} \eta_i \right)^2 \end{equation} If \beta_1 \to 0, this immediately reduces to \|\bm{\theta}_t\|_{RMS}^2 \approx \|\bm{\theta}_0\|_{RMS}^2 + \sum_{j=1}^t \eta_j^2. This indicates that without Weight Decay, for the Weight RMS not to explode as t \to \infty, the sum of the squares of the learning rates must converge, which is one of the classic conditions in traditional optimization theory. In fact, for any 0 \leq \beta_1 < 1 (and \eta_j \geq 0), this condition is necessary and sufficient: \begin{equation} \sum_{j=1}^{\infty} \left( \sum_{i=j}^{\infty} \beta_1^{i-j} \eta_i \right)^2 < \infty \iff \sum_{j=1}^{\infty} \eta_j^2 < \infty \end{equation} The proof is not difficult. We transform the left side: \begin{equation} \begin{aligned} \sum_{j=1}^{\infty} \left( \sum_{i=j}^{\infty} \beta_1^{i-j} \eta_i \right)^2 &= \sum_{j=1}^{\infty} \left( \sum_{i=0}^{\infty} \beta_1^i \eta_{i+j} \right)^2 = \sum_{j=1}^{\infty} \left( \sum_{i_1=0}^{\infty} \beta_1^{i_1} \eta_{i_1+j} \right) \left( \sum_{i_2=0}^{\infty} \beta_1^{i_2} \eta_{i_2+j} \right) \\ &= \sum_{i_1=0}^{\infty} \sum_{i_2=0}^{\infty} \beta_1^{i_1 + i_2} \sum_{j=1}^{\infty} \eta_{i_1+j} \eta_{i_2+j} \end{aligned} \end{equation} All terms are non-negative, so the left side is at least its i_1 = i_2 = 0 term, which is exactly \sum_{j=1}^{\infty} \eta_j^2. Hence, if the left side converges, \sum_{j=1}^{\infty} \eta_j^2 converges as well, which proves necessity.
Sufficiency follows from the Cauchy-Schwarz inequality, together with \sum_{j=1}^{\infty} \eta_{i+j}^2 \leq \sum_{j=1}^{\infty} \eta_j^2 for every i \geq 0: \begin{equation} \begin{aligned} \sum_{i_1=0}^{\infty} \sum_{i_2=0}^{\infty} \beta_1^{i_1 + i_2} \sum_{j=1}^{\infty} \eta_{i_1+j} \eta_{i_2+j} &\leq \sum_{i_1=0}^{\infty} \sum_{i_2=0}^{\infty} \beta_1^{i_1 + i_2} \sqrt{\left( \sum_{j=1}^{\infty} \eta_{i_1+j}^2 \right) \left( \sum_{j=1}^{\infty} \eta_{i_2+j}^2 \right)} \\ &\leq \sum_{i_1=0}^{\infty} \sum_{i_2=0}^{\infty} \beta_1^{i_1 + i_2} \sum_{j=1}^{\infty} \eta_j^2 = \frac{1}{(1 - \beta_1)^2} \sum_{j=1}^{\infty} \eta_j^2 \end{aligned} \end{equation}
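The two-sided bound behind this argument, \sum_j \eta_j^2 \leq \text{LHS} \leq \sum_j \eta_j^2/(1-\beta_1)^2, also holds for finite truncations, so it can be checked numerically. The sketch below assumes two illustrative sequences: \eta_j = 1/j (square-summable) and \eta_j = 1/\sqrt{j} (not square-summable, so both sides keep growing as the truncation length increases).

```python
import numpy as np

def lhs(eta, beta1):
    """sum_j (sum_{i >= j} beta1^(i - j) eta_i)^2, truncated at len(eta)."""
    inner, total = 0.0, 0.0
    for e in eta[::-1]:                 # backward recursion: S_j = eta_j + beta1 * S_{j+1}
        inner = e + beta1 * inner
        total += inner ** 2
    return total

beta1 = 0.9
j = np.arange(1, 100_001, dtype=float)
results = {}
for name, eta in [("eta_j = 1/j", 1 / j), ("eta_j = 1/sqrt(j)", 1 / np.sqrt(j))]:
    sq = (eta ** 2).sum()
    val = lhs(eta, beta1)
    results[name] = (sq, val)
    print(f"{name}: sum eta^2 = {sq:.3f}, LHS = {val:.3f}, upper bound = {sq / (1 - beta1) ** 2:.3f}")
```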
Example 2
Next, consider the case where Weight Decay is constant and the learning rate varies. Here \kappa_t = \lambda \sum_{i=1}^t \eta_i. If we want to train indefinitely and get as close as possible to the theoretical optimum, the learning rate should satisfy \sum_{i=1}^{\infty} \eta_i = \infty, so that the first term e^{-2\kappa_t} \|\bm{\theta}_0\|_{RMS}^2 vanishes and the initialization is completely "forgotten". Interestingly, this is also a classic condition in traditional optimization theory.
For the general case, calculating [eq:w-rms-mu0] is difficult, but we can consider an approximation. In practice, \lambda_t \eta_t \ll 1, so the growth rate of e^{\kappa_i} is much slower than the decay rate of \beta_1^i. Also, \eta_i is usually slowly varying compared to \beta_1^i. Thus: \begin{equation} \sum_{i=j}^t e^{\kappa_i} \beta_1^{i-j} \eta_i \approx \sum_{i=j}^t e^{\kappa_j} \beta_1^{i-j} \eta_j \approx e^{\kappa_j} \eta_j \sum_{i=j}^{\infty} \beta_1^{i-j} = \frac{e^{\kappa_j} \eta_j}{1 - \beta_1} \end{equation} Substituting this back into [eq:w-rms-mu0]: \begin{equation} \|\bm{\theta}_t\|_{RMS}^2 \approx e^{-2\kappa_t} \|\bm{\theta}_0\|_{RMS}^2 + e^{-2\kappa_t} \sum_{j=1}^t e^{2\kappa_j} \eta_j^2 \label{eq:w-rms-simp} \end{equation} For constant \lambda and \eta, \kappa_t = \lambda \eta t, and: \begin{equation} \|\bm{\theta}_t\|_{RMS}^2 \approx e^{-2\lambda \eta t} \|\bm{\theta}_0\|_{RMS}^2 + \frac{e^{2\lambda \eta} (1 - e^{-2\lambda \eta t})}{e^{2\lambda \eta} - 1} \eta^2 \approx e^{-2\lambda \eta t} \|\bm{\theta}_0\|_{RMS}^2 + (1 - e^{-2\lambda \eta t}) \frac{\eta}{2\lambda} \end{equation} This matches the result from the previous article.
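To see how much the simplification costs, the sketch below evaluates the full double sum [eq:w-rms-mu0] (with \bm{\mu} = \bm{0}) and the simplified [eq:w-rms-simp] for constant \lambda and \eta, and compares them with the closed form and the limit \eta/(2\lambda). The values \eta = 10^{-3}, \lambda = 0.1, \beta_1 = 0.9, \|\bm{\theta}_0\|_{RMS} = 0.1 and T = 10^5 steps (so \lambda \eta T = 10) are assumptions chosen for illustration.

```python
import numpy as np

def rms_sq_full(eta, lam, beta1, init_rms):
    """[eq:w-rms-mu0]: full double sum, via S_j = e^{kappa_j} eta_j + beta1 * S_{j+1}."""
    kappa = np.cumsum(lam * eta)
    w = np.exp(kappa) * eta
    acc, total = 0.0, 0.0
    for wj in w[::-1]:
        acc = wj + beta1 * acc
        total += acc ** 2
    return np.exp(-2 * kappa[-1]) * (init_rms ** 2 + (1 - beta1) ** 2 * total)

def rms_sq_simp(eta, lam, init_rms):
    """[eq:w-rms-simp]: inner sum replaced by e^{kappa_j} eta_j / (1 - beta1)."""
    kappa = np.cumsum(lam * eta)
    return np.exp(-2 * kappa[-1]) * (init_rms ** 2 + (np.exp(2 * kappa) * eta ** 2).sum())

T, eta0, lam, beta1, init_rms = 100_000, 1e-3, 0.1, 0.9, 0.1
eta = np.full(T, eta0)
full = rms_sq_full(eta, lam, beta1, init_rms)
simp = rms_sq_simp(eta, lam, init_rms)
# Closed form of the geometric series for constant lambda and eta
x = lam * eta0
closed = (np.exp(-2 * x * T) * init_rms ** 2
          + np.exp(2 * x) * (1 - np.exp(-2 * x * T)) / (np.exp(2 * x) - 1) * eta0 ** 2)
print(full, simp, closed, eta0 / (2 * lam))
```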
Differential Equation
Equation [eq:w-rms-simp] is convenient for numerical evaluation, but to obtain analytical results for general \lambda_t, \eta_t, we can use an integral approximation: \begin{equation} \|\bm{\theta}_t\|_{RMS}^2 \approx e^{-2\kappa_t} \|\bm{\theta}_0\|_{RMS}^2 + e^{-2\kappa_t} \int_0^t e^{2\kappa_s} \eta_s^2 ds \label{eq:w-rms-int} \end{equation} where \kappa_t = \int_0^t \lambda_s \eta_s ds. Let \rho_t = \|\bm{\theta}_t\|_{RMS}^2. Multiplying both sides by e^{2\kappa_t}, differentiating with respect to t, and using d\kappa_t/dt = \lambda_t \eta_t gives: \begin{equation} \frac{d}{dt} \rho_t \approx -2\lambda_t \eta_t \rho_t + \eta_t^2 \end{equation} This is the differential equation for the squared RMS. If \rho_t converges to a constant as t \to \infty, then setting d\rho_t/dt = 0 gives: \begin{equation} \lim_{t \to \infty} \rho_t \approx \lim_{t \to \infty} \frac{\eta_t}{2\lambda_t} \end{equation} This suggests that for decaying learning rate schedules, the final learning rate should not be set to 0, otherwise the Weight RMS risks collapsing toward 0; alternatively, one can set \lambda_t \propto \eta_t (e.g., AdamC).
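The differential equation can also be integrated directly. This sketch uses forward Euler with a step of 100 training steps and a hypothetical, slowly decaying schedule \eta_t = \eta_0/\sqrt{1 + t/\tau}, and contrasts (a) constant \lambda with (b) \lambda_t = c\,\eta_t. Case (b) is only meant to mimic the \lambda_t \propto \eta_t idea mentioned above, not AdamC's actual implementation. In (a), \rho_t should follow \eta_t/(2\lambda) down toward 0; in (b), it should settle near 1/(2c).

```python
import numpy as np

def integrate_rho(eta_fn, lam_fn, rho0, t_end, dt):
    """Forward-Euler integration of d(rho)/dt = -2 lam_t eta_t rho + eta_t^2."""
    rho, t = rho0, 0.0
    while t < t_end:
        eta, lam = eta_fn(t), lam_fn(t)
        rho += dt * (-2 * lam * eta * rho + eta ** 2)
        t += dt
    return rho

# Assumed, purely illustrative schedule: eta_t decays slowly to 0.
eta0, tau, t_end, dt = 1e-3, 1e4, 1e6, 100.0
eta_fn = lambda t: eta0 / np.sqrt(1 + t / tau)

# (a) constant weight decay: the target eta_t / (2 lam) itself goes to 0
rho_const = integrate_rho(eta_fn, lambda t: 0.1, 0.01, t_end, dt)

# (b) lam_t proportional to eta_t: the target eta_t / (2 lam_t) = 1 / (2c) stays fixed
c = 100.0
rho_prop = integrate_rho(eta_fn, lambda t: c * eta_fn(t), 0.01, t_end, dt)

print("constant lambda:", rho_const, "vs eta_T / (2 lam) =", eta_fn(t_end) / 0.2)
print("lambda prop. to eta:", rho_prop, "vs 1 / (2c) =", 1 / (2 * c))
```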
Mean Field
In pre-training scenarios (single-epoch), \kappa_t is often \Theta(1). Under this assumption, we can use a mean field approximation. From [eq:w-rms-int], since \kappa_s increases monotonically from 0 to \kappa_t: \begin{equation} e^{-2\kappa_t} \int_0^t \eta_s^2 ds \leq \int_0^t e^{2\kappa_s - 2\kappa_t} \eta_s^2 ds \leq \int_0^t \eta_s^2 ds \end{equation} Let \nu_t = \int_0^t \eta_s^2 ds. If \kappa_t = \Theta(1), the two bounds differ only by the constant factor e^{-2\kappa_t}, so \nu_t alone already gives the right order of magnitude. More precisely, replacing \eta_s^2 by its average \nu_t/t and approximating \kappa_s by the linear function (\kappa_t/t)s: \begin{equation} e^{-2\kappa_t} \int_0^t e^{2\kappa_s} \eta_s^2 ds \approx \frac{\nu_t e^{-2\kappa_t}}{t} \int_0^t e^{2(\kappa_t/t)s} ds = \frac{\nu_t}{2\kappa_t} (1 - e^{-2\kappa_t}) \end{equation} Substituting into [eq:w-rms-int]: \begin{equation} \|\bm{\theta}_t\|_{RMS}^2 \approx e^{-2\kappa_t} \|\bm{\theta}_0\|_{RMS}^2 + (1 - e^{-2\kappa_t}) \frac{\nu_t}{2\kappa_t} \end{equation}
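A quick numerical comparison of the integral term with its bounds and with the mean field estimate, assuming constant \lambda = 0.1 and a learning rate decaying linearly from 10^{-3} to 10^{-4} over t = 10^4 steps (so \kappa_t = 0.55). All integrals are approximated by Riemann sums on a fine grid; the specific numbers are illustrative only.

```python
import numpy as np

# Assumed example: lambda = 0.1, eta decays linearly 1e-3 -> 1e-4 over t = 1e4 steps.
lam, t, n = 0.1, 1e4, 200_000
ds = t / n
s = (np.arange(n) + 0.5) * ds                 # midpoints of a fine grid on [0, t]
eta = 1e-3 + (1e-4 - 1e-3) * s / t
kappa_s = lam * np.cumsum(eta) * ds           # Riemann-sum approximation of kappa_s
kappa_t = kappa_s[-1]
nu_t = (eta ** 2).sum() * ds                  # nu_t = int_0^t eta_s^2 ds

integral = np.exp(-2 * kappa_t) * (np.exp(2 * kappa_s) * eta ** 2).sum() * ds
mean_field = (1 - np.exp(-2 * kappa_t)) * nu_t / (2 * kappa_t)
lower, upper = np.exp(-2 * kappa_t) * nu_t, nu_t

print(f"kappa_t = {kappa_t:.3f}, bounds = [{lower:.3e}, {upper:.3e}]")
print(f"integral = {integral:.3e}, mean field = {mean_field:.3e}")
```

Because the learning rate is largest early on, where e^{2\kappa_s} is smallest, the mean field estimate should land somewhat above the actual integral for this schedule.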
Example 3
For fixed Weight Decay \lambda and a variable learning rate, we compute \kappa_t and \nu_t for common schedules (with s \in [0, t]).
Linear Schedule: \eta_s = \eta_a + (\eta_b - \eta_a)s/t. \begin{gather} \kappa_t = \lambda (\eta_a + \eta_b) t / 2, \qquad \nu_t = (\eta_a^2 + \eta_a \eta_b + \eta_b^2) t / 3 \end{gather}
Cosine Decay: \eta_s = \eta_{\min} + (\eta_{\max} - \eta_{\min})(\frac{1}{2} + \frac{1}{2}\cos \frac{s\pi}{t}). \begin{gather} \kappa_t = \lambda (\eta_{\min} + \eta_{\max}) t / 2, \qquad \nu_t = (3\eta_{\min}^2 + 2\eta_{\min} \eta_{\max} + 3\eta_{\max}^2) t / 8 \end{gather}
WSD (Warmup Stable Decay): \eta_s rises linearly from 0 to \eta_{\max} over [0, t_1], stays at \eta_{\max} over [t_1, t_2], and decays linearly to 0 over [t_2, t]. \begin{gather} \kappa_t = \lambda \eta_{\max} (t + t_2 - t_1) / 2, \qquad \nu_t = \eta_{\max}^2 (t + 2t_2 - 2t_1) / 3 \end{gather}
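The closed forms above can be double-checked by numerical quadrature. This sketch assumes that the WSD schedule warms up linearly from 0 to \eta_{\max} over [0, t_1], stays at \eta_{\max} until t_2, and decays linearly to 0 at t (the shape these formulas correspond to); the breakpoints t_1 = 500, t_2 = 8000 and the other constants are arbitrary choices for illustration.

```python
import numpy as np

def moments(eta_fn, t, n=200_000):
    """Midpoint-rule estimates of int_0^t eta_s ds and int_0^t eta_s^2 ds."""
    ds = t / n
    e = eta_fn((np.arange(n) + 0.5) * ds)
    return e.sum() * ds, (e ** 2).sum() * ds

t, lam = 10_000.0, 0.1
lr_max, lr_min = 1e-3, 1e-4
t1, t2 = 500.0, 8_000.0   # hypothetical WSD breakpoints

# name -> (eta_s, closed-form kappa_t, closed-form nu_t)
schedules = {
    "linear": (
        lambda s: lr_max + (lr_min - lr_max) * s / t,
        lam * (lr_max + lr_min) * t / 2,
        (lr_max**2 + lr_max * lr_min + lr_min**2) * t / 3,
    ),
    "cosine": (
        lambda s: lr_min + (lr_max - lr_min) * (0.5 + 0.5 * np.cos(s * np.pi / t)),
        lam * (lr_min + lr_max) * t / 2,
        (3 * lr_min**2 + 2 * lr_min * lr_max + 3 * lr_max**2) * t / 8,
    ),
    "wsd": (
        lambda s: lr_max * np.minimum(np.minimum(s / t1, 1.0), (t - s) / (t - t2)),
        lam * lr_max * (t + t2 - t1) / 2,
        lr_max**2 * (t + 2 * t2 - 2 * t1) / 3,
    ),
}

errors = {}
for name, (eta_fn, kappa_formula, nu_formula) in schedules.items():
    int_eta, int_eta_sq = moments(eta_fn, t)
    errors[name] = (abs(lam * int_eta / kappa_formula - 1), abs(int_eta_sq / nu_formula - 1))
    print(f"{name}: kappa {lam * int_eta:.4f} vs {kappa_formula:.4f}, "
          f"nu {int_eta_sq:.4e} vs {nu_formula:.4e}")
```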
Simulation Verification
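We can verify these approximations with a numerical simulation that runs AdamW on i.i.d. standard Gaussian gradients (\bm{\mu} = \bm{0}) under Cosine Decay, and compares the resulting Weight RMS with the series approximation [eq:w-rms-simp] and the mean field approximation: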
```python
import numpy as np

N, T = 10000, 10000          # number of parameters, number of training steps
beta1, beta2 = 0.9, 0.95
m, v = 0, 0
w = np.random.randn(N) * (init_std := 0.1)
lr_max, lr_min, wd = 0.001, 0.0001, 0.1
# Cosine Decay from lr_max to lr_min
lr = lr_min + (lr_max - lr_min) * (1 + np.cos(np.arange(T) / T * np.pi)) / 2

# AdamW on i.i.d. N(0, 1) gradients; bias correction and epsilon are omitted,
# as in the derivation above
for i in range(T):
    g = np.random.randn(N)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    w = w - lr[i] * (m / v**0.5 + wd * w)

# Direct calculation: approx 0.0744
weight_rms = (w**2).mean()**0.5

# Series approximation [eq:w-rms-simp]: approx 0.0742
kappa = wd * lr.cumsum()
approx1 = ((np.exp(kappa * 2) * lr**2).sum() + init_std**2)**0.5 * np.exp(-kappa[-1])

# Mean field approximation with the Cosine Decay kappa_t, nu_t: approx 0.0760
kappa_val = wd * (lr_max + lr_min) / 2 * T
nu_val = (3 * lr_max**2 + 2 * lr_max * lr_min + 3 * lr_min**2) / 8 * T
approx2 = ((np.exp(kappa_val * 2) - 1) * nu_val / kappa_val / 2 + init_std**2)**0.5 * np.exp(-kappa_val)

print(f"Weight RMS: {weight_rms}")
print(f"Approx 1: {approx1}")
print(f"Approx 2: {approx2}")
```
Summary
This article generalizes the results of the previous post to a dynamic version, allowing us to estimate the Weight RMS of AdamW under time-varying learning rates and Weight Decay.
Reprinting: Please include the original address: https://kexue.fm/archives/11404