If training a model is likened to “alchemy,” then the optimizer is clearly the “furnace.” The AdamW optimizer is reputed to be the fastest option currently available for training neural networks. While I have not compared every alternative myself, it is true that most pre-training tasks today use AdamW or its variant LAMB. However, just as owning a furnace does not guarantee a high-quality elixir, even once we settle on AdamW, many questions remain without definitive answers, such as:
How should the learning rate adapt to different initializations and parameterizations?
How should the weight decay rate be adjusted?
What strategy should be used for the learning rate schedule?
Can the memory footprint of the optimizer be reduced?
In practical applications, we often simply adopt parameters and strategies tuned by predecessors. However, the lack of systematic guidance for parameter tuning always leaves a sense of uncertainty during “alchemy.” In this article, based on the ideas of the Amos optimizer recently proposed by Google, we provide some reference results.
Background Review
The Amos optimizer comes from Google’s recent paper “Amos: An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale.” It provides a relatively complete derivation for the aforementioned questions and confirms its effectiveness through experiments. However, the original paper’s derivation is quite difficult to read; various notations and estimations are somewhat haphazard, giving a sense of “disorder.” Fortunately, the core idea of Amos is not overly complex, and we can borrow from it.
Before starting the derivation, let us review the existing solutions for the questions mentioned above.
First, regarding the first question, some may not fully understand what “initialization” and “parameterization” mean. These are two ways of setting model weights. A common example is an n \times n matrix, typically initialized with a “mean of 0 and variance of 1/n.” For a detailed introduction, readers can refer to my previous articles “Understanding Model Parameter Initialization Strategies from a Geometric Perspective” and “A Brief Discussion on Initialization, Parameterization, and Normalization in Transformers.” From the “variance of 1/n,” we can see that different parameters have different scales (or magnitudes). If we use the same learning rate to update all parameters, the relative update magnitude will differ across parameters of different scales. I believe a more elegant solution to this problem is the LAMB optimizer, where the norm of each update depends directly on the norm of the parameter itself, so that the learning rate describes the relative update size.
As for the weight decay rate, at least in the field of pre-training, I have observed that the original choice of 0.01 is almost always followed, with little work dedicated to tuning this parameter. Regarding the learning rate schedule, it is well known that the learning rate should gradually decrease to zero, but there is little theoretical guidance on which specific decay strategy to choose; most results are summarized from experiments. Finally, concerning the reduction of memory usage, a classic work is the AdaFactor optimizer, which I introduced in “A Brief Analysis of the AdaFactor Optimizer (with Open Source Implementation).” There are two main ideas for reducing optimizer memory: removing momentum and performing low-rank decomposition on the second moment. Amos essentially follows these two ideas.
Problem Setting
This article primarily focuses on the first three questions mentioned at the beginning, hoping to derive some “plug-and-play” results. First, we simplify the optimizer’s update rule as: \begin{equation} \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \alpha_t \boldsymbol{u}_t \end{equation} Here, \boldsymbol{\theta}_t and \boldsymbol{\theta}_{t+1} represent the parameter values at times t and t+1, respectively. \boldsymbol{u}_t represents the update vector at time t (depending on the task and data), and \alpha_t > 0 represents the learning rate at time t; it is usually a scalar, but it may also be a vector with every element positive, applied elementwise.
Since AdamW, mainstream optimizers have tended to decouple the weight decay term from \boldsymbol{u}_t, i.e.: \begin{equation} \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - (\alpha_t \boldsymbol{u}_t + \rho_t\boldsymbol{\theta}_t) \end{equation} where \rho_t > 0 is the weight decay rate. The main task of this article is to solve how \alpha_t and \rho_t should be set.
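As a quick illustration, a single decoupled update step can be sketched in a few lines of NumPy (the names and numbers here are purely illustrative; in practice \boldsymbol{u}_t would come from the base optimizer, e.g. Adam’s moment-normalized gradient):

```python
import numpy as np

def decoupled_step(theta, u, alpha, rho):
    """theta_{t+1} = theta_t - (alpha_t * u_t + rho_t * theta_t):
    the weight-decay term rho_t * theta_t is decoupled from the
    task-dependent update direction u_t."""
    return theta - (alpha * u + rho * theta)

theta = np.array([1.0, -2.0, 0.5])
u = np.array([0.1, 0.1, -0.2])
theta_next = decoupled_step(theta, u, alpha=1e-3, rho=1e-5)
```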
Weight Decay
We know that weight decay, like L2 regularization, is independent of the training objective itself; it is an auxiliary term intended to improve the model’s generalization ability. Since it is auxiliary, a basic requirement is that it should not “overshadow the main goal.” To this end, we might introduce a constraint: \begin{equation} \mathcal{O}(\alpha_t^2) = \mathcal{O}(\rho_t) \end{equation} That is, throughout the update process, the update magnitude brought by weight decay should always be one order higher than the update magnitude related to the objective. Since \alpha_t and \rho_t are generally less than 1, a higher order implies a smaller value.
Let the optimal parameter point be \boldsymbol{\theta}^*, and we denote \boldsymbol{\varepsilon}_t = \boldsymbol{\theta}_t - \boldsymbol{\theta}^*. According to the update rule, we have: \begin{equation} \begin{aligned} \Vert\boldsymbol{\varepsilon}_{t+1}\Vert^2 =&\, \Vert\boldsymbol{\theta}_{t+1} - \boldsymbol{\theta}^*\Vert^2 \\ =&\, \Vert\boldsymbol{\theta}_t - (\alpha_t \boldsymbol{u}_t + \rho_t\boldsymbol{\theta}_t) - \boldsymbol{\theta}^*\Vert^2 \\ \approx&\, \Vert\boldsymbol{\varepsilon}_t\Vert^2 - 2 \alpha_t \boldsymbol{u}_t \cdot \boldsymbol{\varepsilon}_t + \left(\alpha_t^2 \Vert\boldsymbol{u}_t\Vert^2 - 2 \rho_t \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t\right) \end{aligned}\label{eq:base-approx} \end{equation} The final approximation only retains terms up to \mathcal{O}(\alpha_t^2).
Clearly, \Vert\boldsymbol{\varepsilon}_t\Vert is the distance between the current result and the target, which we naturally want to minimize. Therefore, we hope each update reduces this distance, i.e., \Vert\boldsymbol{\varepsilon}_{t+1}\Vert < \Vert\boldsymbol{\varepsilon}_t\Vert. Looking at Eq. [eq:base-approx], - 2 \alpha_t \boldsymbol{u}_t \cdot \boldsymbol{\varepsilon}_t can be positive or negative; if it is negative, it helps achieve \Vert\boldsymbol{\varepsilon}_{t+1}\Vert < \Vert\boldsymbol{\varepsilon}_t\Vert. However, \alpha_t^2 \Vert\boldsymbol{u}_t\Vert^2 is necessarily positive, which is detrimental to achieving \Vert\boldsymbol{\varepsilon}_{t+1}\Vert < \Vert\boldsymbol{\varepsilon}_t\Vert. But with the introduction of weight decay, an additional term - 2 \rho_t \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t appears. If this term can cancel out the negative effect of \alpha_t^2 \Vert\boldsymbol{u}_t\Vert^2, then the introduction of weight decay not only enhances generalization but also aids model convergence.
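The approximation in the expansion above is easy to verify numerically. Below is a small sanity check (dimensions and constants chosen arbitrarily) confirming that when \rho_t is of order \alpha_t^2, the dropped cross terms are negligible:

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.standard_normal(1000)
theta_star = rng.standard_normal(1000)
u = rng.standard_normal(1000)

alpha = 1e-3
rho = alpha**2          # weight decay one order higher: rho = O(alpha^2)
eps = theta - theta_star

# exact squared distance after one decoupled update
exact = np.sum((theta - (alpha * u + rho * theta) - theta_star) ** 2)
# expansion retaining terms up to O(alpha^2)
approx = (eps @ eps - 2 * alpha * (u @ eps)
          + alpha**2 * (u @ u) - 2 * rho * (theta @ eps))
# the discarded terms are O(alpha^3), hence tiny compared to `exact`
```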
Feasibility Analysis
Next, we examine the feasibility of: \begin{equation} \alpha_t^2 \Vert\boldsymbol{u}_t\Vert^2 = 2 \rho_t \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t\label{eq:base-cond} \end{equation} Feasibility means whether \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t can be greater than 0; only then can both sides be equal. Using the definition of \boldsymbol{\varepsilon}_t, we have \boldsymbol{\theta}_t = \boldsymbol{\varepsilon}_t + \boldsymbol{\theta}^*, so: \begin{equation} \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t = (\boldsymbol{\varepsilon}_t + \boldsymbol{\theta}^*) \cdot \boldsymbol{\varepsilon}_t = \Vert \boldsymbol{\varepsilon}_t\Vert^2 + \boldsymbol{\theta}^* \cdot \boldsymbol{\varepsilon}_t \end{equation} Note that \boldsymbol{\theta}^* is our target, a fixed point, while \boldsymbol{\varepsilon}_t is the difference vector between the current state and the target. Generally, there is no necessary correlation between them, so we can approximately treat them as two random vectors in a high-dimensional space. According to “Angle Distribution of Two Random Vectors in n-dimensional Space,” we know that two random vectors in high-dimensional space are almost always orthogonal, thus \boldsymbol{\theta}^* \cdot \boldsymbol{\varepsilon}_t \approx 0, meaning \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t \approx \Vert \boldsymbol{\varepsilon}_t\Vert^2. To be safe, we can introduce a parameter q: \begin{equation} \boldsymbol{\theta}_t \cdot \boldsymbol{\varepsilon}_t \approx q\Vert \boldsymbol{\varepsilon}_t\Vert^2 \end{equation} At this point, Eq. [eq:base-cond] becomes: \begin{equation} \alpha_t^2 \Vert\boldsymbol{u}_t\Vert^2 \approx 2 \rho_t q\Vert \boldsymbol{\varepsilon}_t\Vert^2\label{eq:base-cond-approx} \end{equation} Both sides are greater than 0, so Eq. [eq:base-cond] is potentially valid.
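The near-orthogonality claim is a standard concentration fact and is easy to check empirically: in d dimensions, the cosine between two independent random vectors has standard deviation of roughly 1/\sqrt{d}. A quick demonstration (dimension chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 10_000
# two independent random vectors in a high-dimensional space
a = rng.standard_normal(d)
b = rng.standard_normal(d)
cos = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
# |cos| is typically of order 1/sqrt(d) = 0.01 here, i.e. nearly orthogonal
print(abs(cos))
```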
Asymptotic Estimation
If Eq. [eq:base-cond] holds, then Eq. [eq:base-approx] simplifies to: \begin{equation} \Vert\boldsymbol{\varepsilon}_{t+1}\Vert^2 \approx \Vert\boldsymbol{\varepsilon}_t\Vert^2 - 2 \alpha_t \boldsymbol{u}_t \cdot \boldsymbol{\varepsilon}_t = \Vert\boldsymbol{\varepsilon}_t\Vert^2 - 2 \alpha_t \Vert\boldsymbol{u}_t\Vert \Vert\boldsymbol{\varepsilon}_t\Vert \cos(\boldsymbol{u}_t, \boldsymbol{\varepsilon}_t) \end{equation} We stated that \boldsymbol{u}_t represents the task-related update magnitude. On average, it must be beneficial to the task (otherwise the original optimizer would be flawed), so on average, we should have \cos(\boldsymbol{u}_t, \boldsymbol{\varepsilon}_t) > 0. We further assume there exists a p > 0 such that \cos(\boldsymbol{u}_t, \boldsymbol{\varepsilon}_t) \sim p, thus: \begin{equation} \Vert\boldsymbol{\varepsilon}_{t+1}\Vert^2 \approx \Vert\boldsymbol{\varepsilon}_t\Vert^2 - 2 \alpha_t p\Vert\boldsymbol{u}_t\Vert \Vert\boldsymbol{\varepsilon}_t\Vert \end{equation} According to the approximation in Eq. [eq:base-cond-approx], we have \alpha_t \Vert\boldsymbol{u}_t \Vert \Vert \boldsymbol{\varepsilon}_t\Vert \approx \sqrt{2 \rho_t q}\Vert \boldsymbol{\varepsilon}_t\Vert^2. Substituting this into the above equation gives: \begin{equation} \Vert\boldsymbol{\varepsilon}_{t+1}\Vert^2 \approx \Vert\boldsymbol{\varepsilon}_t\Vert^2(1 - 2 p\sqrt{2 \rho_t q}) \approx \Vert\boldsymbol{\varepsilon}_t\Vert^2 \exp(- 2 p\sqrt{2 \rho_t q}) \end{equation} Iterating step by step, we obtain: \begin{equation} \Vert\boldsymbol{\varepsilon}_t\Vert^2 \approx \Vert\boldsymbol{\varepsilon}_0\Vert^2 \exp\left(- 2 \sum_{i=1}^{t-1} p\sqrt{2 \rho_i q}\right)\label{eq:varepsilon-t} \end{equation} It can be seen that the exponent on the right side is monotonically decreasing; it is a decay function. Now looking back at the approximation in Eq. [eq:base-cond-approx], there are two parameters \alpha_t and \rho_t to tune, but only one (approximate) equation. 
To allow \alpha_t and \rho_t to decay at the same rate, we set 2\rho_t q \approx \lambda^2 \Vert\boldsymbol{\varepsilon}_t\Vert^2, which yields: \begin{equation} \begin{aligned} \alpha_t \approx \frac{\lambda\Vert\boldsymbol{\varepsilon}_t\Vert^2}{\Vert\boldsymbol{u}_t\Vert} \approx&\, \frac{\lambda\Vert\boldsymbol{\varepsilon}_0\Vert^2}{\Vert\boldsymbol{u}_t\Vert} \exp\left(- 2 \sum_{i=1}^{t-1}p\sqrt{2 \rho_i q}\right) \\ \rho_t \approx \frac{\lambda^2\Vert\boldsymbol{\varepsilon}_t\Vert^2}{2q} \approx&\, \frac{\lambda^2\Vert\boldsymbol{\varepsilon}_0\Vert^2}{2q} \exp\left(- 2 \sum_{i=1}^{t-1}p\sqrt{2 \rho_i q}\right) \end{aligned}\label{eq:alpha-rho} \end{equation} This is the variation law for \alpha_t and \rho_t derived in this article. Of course, while we have the variation law, there are still four parameters \lambda, \Vert\boldsymbol{\varepsilon}_0\Vert, p, q to determine. Among them, q is relatively simple; setting q=1 is generally fine. However, three parameters still remain.
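Although Eq. [eq:alpha-rho] is implicit (each \rho_t depends on all earlier \rho_i through the exponent), it is straightforward to evaluate iteratively. A minimal sketch, with arbitrary illustrative constants:

```python
import numpy as np

def rho_schedule(lam, eps0_sq, p, q, steps):
    """Iterate the implicit law for rho_t: rho_t = c * exp(-2 * S_t),
    where S_t accumulates p * sqrt(2 * rho_i * q) over previous steps.
    All constants here are placeholders for illustration."""
    rhos, S = [], 0.0
    c = lam**2 * eps0_sq / (2 * q)
    for _ in range(steps):
        rho = c * np.exp(-2 * S)
        rhos.append(rho)
        S += p * np.sqrt(2 * rho * q)
    return np.array(rhos)

rhos = rho_schedule(lam=1e-3, eps0_sq=1.0, p=0.1, q=1.0, steps=1000)
# rho_0 = lam^2 * eps0_sq / (2q) = 5e-7, then monotonically decreasing
```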
Scale Prediction
By definition, \Vert\boldsymbol{\varepsilon}_0\Vert = \Vert\boldsymbol{\theta}_0 - \boldsymbol{\theta}^*\Vert, which is the distance between the initial parameters and the target parameters. This can be understood as the scale of parameter change, and it has several different cases.
First, for parameters that are matrix multiplication kernels, such as the kernel matrices of fully connected or convolutional layers, they are generally initialized with a “mean of 0 and variance of \sigma^2” (where \sigma depends on the shape). If \boldsymbol{\theta} \in \mathbb{R}^k, we can estimate \Vert\boldsymbol{\theta}_0\Vert^2 \approx k\sigma^2. Furthermore, for such parameters, under reasonable initialization, the mean and variance of the parameters after training will not change significantly—at least the magnitude remains consistent. Therefore, we can also assume \Vert\boldsymbol{\theta}^*\Vert^2 \approx k\sigma^2. Since the initialization is random, \boldsymbol{\theta}_0 \cdot \boldsymbol{\theta}^* \approx 0, thus: \begin{equation} \Vert\boldsymbol{\varepsilon}_0\Vert^2 = \Vert\boldsymbol{\theta}_0 - \boldsymbol{\theta}^*\Vert^2 = \Vert\boldsymbol{\theta}_0\Vert^2 + \Vert\boldsymbol{\theta}^*\Vert^2 - 2\boldsymbol{\theta}_0 \cdot \boldsymbol{\theta}^* \approx 2k\sigma^2 \end{equation}
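This estimate is also easy to verify by simulation (the dimension and σ below are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
k, sigma = 100_000, 0.02
# theta_0 and theta_* drawn independently with mean 0, variance sigma^2;
# the cross term theta_0 . theta_* vanishes in high dimension, so
# ||theta_0 - theta_*||^2 ≈ 2 * k * sigma^2.
theta0 = rng.normal(0, sigma, k)
theta_star = rng.normal(0, sigma, k)
eps0_sq = np.sum((theta0 - theta_star) ** 2)
print(eps0_sq / (2 * k * sigma**2))  # ≈ 1
```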
Second, for parameters that are additive bias terms, such as the bias vectors of fully connected or convolutional layers, and the \boldsymbol{\beta} vector of Normalization layers, these parameters are generally “zero-initialized,” so \Vert\boldsymbol{\varepsilon}_0\Vert^2 = \Vert\boldsymbol{\theta}^*\Vert^2. If we predict based on experience that the trained model’s bias terms are around \pm\sigma, we can also estimate \Vert\boldsymbol{\theta}^*\Vert^2 \approx k\sigma^2. The Amos paper takes \sigma=0.5. Finally, for the \boldsymbol{\gamma} vector of Normalization layers, it is generally “initialized to all ones” and remains around 1 after training. Assuming an error of \pm\sigma, we can also estimate \Vert\boldsymbol{\theta}^*\Vert^2 \approx k\sigma^2. Here, k refers to the vector dimension.
It can be seen that the results for \Vert\boldsymbol{\varepsilon}_0\Vert^2 share a commonality: they can all be written as k\sigma^2, where \sigma is our prediction of the parameter change scale. For multiplicative matrices, \sigma can be taken directly as the standard deviation of the initialization. For additive biases or \boldsymbol{\gamma} vectors, one can simply take \sigma=0.5, or handle other special parameters specifically.
Separating Scale
Now let’s look at the complete update amount. According to Eq. [eq:alpha-rho]: \begin{equation} \alpha_t \boldsymbol{u}_t \approx \lambda\Vert\boldsymbol{\varepsilon}_0\Vert^2 \times \frac{\boldsymbol{u}_t}{\Vert\boldsymbol{u}_t\Vert} \times \exp\left(- 2 \sum_{i=1}^{t-1}p\sqrt{2 \rho_i q}\right) \end{equation} Here, \frac{\boldsymbol{u}_t}{\Vert\boldsymbol{u}_t\Vert} is a unit vector controlling the update direction, and the \exp part is a decay term. We can ignore the decay term for a moment; thus, the norm of the update is controlled by \lambda\Vert\boldsymbol{\varepsilon}_0\Vert^2.
Returning to the first question at the beginning: “How should the learning rate adapt to different initializations and parameterizations?” Clearly, the intuitive idea is that parameters with a larger scale of change should have a larger update magnitude at each step, or simply be proportional to the scale of change. Since we estimated the scale of change using \Vert\boldsymbol{\varepsilon}_0\Vert, we assume \lambda\Vert\boldsymbol{\varepsilon}_0\Vert^2 = \alpha_0 \Vert\boldsymbol{\varepsilon}_0\Vert, where \alpha_0 is the global initial learning rate. Solving for \lambda gives \lambda = \alpha_0 / \Vert\boldsymbol{\varepsilon}_0\Vert. Substituting this into Eq. [eq:alpha-rho] yields: \begin{equation} \alpha_t \approx \frac{\alpha_0\Vert\boldsymbol{\varepsilon}_0\Vert}{\Vert\boldsymbol{u}_t\Vert} \exp\left(- 2 \sum_{i=1}^{t-1}p\sqrt{2 \rho_i q}\right),\quad \rho_t \approx \frac{\alpha_0^2}{2q} \exp\left(- 2 \sum_{i=1}^{t-1}p\sqrt{2 \rho_i q}\right)\label{eq:alpha-rho-2} \end{equation} Here, \alpha_0 represents the relative update magnitude per step (global learning rate). There isn’t much room for further derivation here; it is typically taken around 10^{-3}, or up to 10^{-2} for simple tasks. \Vert\boldsymbol{\varepsilon}_0\Vert was estimated in the previous section as roughly \sqrt{k}\sigma, where \sigma represents the average parameter change scale. By using it, we explicitly separate the parameter scale, achieving an adaptive effect (update magnitude proportional to \sigma). Notably, if we replace \Vert\boldsymbol{\varepsilon}_0\Vert in the above equation with \Vert\boldsymbol{\theta}_t\Vert, it becomes the LAMB optimizer. From this, we can also see that if the initialization mean of \boldsymbol{\theta} is not 0 (like the \boldsymbol{\gamma} vector), using \Vert\boldsymbol{\theta}_t\Vert instead of \Vert\boldsymbol{\varepsilon}_0\Vert would be problematic. 
Therefore, LAMB’s approach is to simply not transform the update magnitude for these parameters (i.e., keep the original update rule).
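Ignoring the decay factor, the resulting scale-adapted update can be sketched as follows. This is a simplification of Eq. [eq:alpha-rho-2], not the full Amos algorithm; σ is the predicted change scale from the previous section, and the constants are illustrative:

```python
import numpy as np

def scale_adapted_update(u, sigma, alpha0=1e-3):
    """Update direction u / ||u|| rescaled so that the step norm equals
    alpha0 * ||eps_0||, with ||eps_0|| estimated as sqrt(k) * sigma.
    The step norm is thus proportional to the parameter's change scale,
    independent of the raw gradient magnitude."""
    k = u.size
    eps0_norm = np.sqrt(k) * sigma
    return alpha0 * eps0_norm * u / np.linalg.norm(u)

u = np.random.default_rng(0).standard_normal(512)
step = scale_adapted_update(u, sigma=0.02)
# ||step|| = alpha0 * sqrt(512) * 0.02, regardless of ||u||
```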
Analytical Approximation
The current results are already suitable for programming, but the parameter p is difficult to tune. To further see how p affects the decay function, we can derive an analytical approximation for \rho_t.
Multiplying both sides of \rho_t in Eq. [eq:alpha-rho-2] by 2q and taking the square root gives: \begin{equation} \sqrt{2 \rho_t q} \approx \alpha_0 \exp\left(- \sum_{i=1}^{t-1}p\sqrt{2 \rho_i q}\right) \end{equation} Denoting the sum in the exponent \sum_{i=1}^{t-1}p\sqrt{2 \rho_i q} as S_t, the above equation corresponds to the difference equation: \begin{equation} \frac{S_t - S_{t-1}}{p} \approx \alpha_0 \exp\left(- S_{t-1}\right) \quad \Rightarrow \quad S_{t+1} - S_t \approx \alpha_0 p\exp\left(- S_t\right) \end{equation} The decay function is then \exp(-2S_t). To find the asymptotic approximation, we replace the difference with a derivative (refer to “Perturbation Methods for Difference Equations”): \begin{equation} \frac{dS_t}{dt} \approx \alpha_0 p \exp\left(- S_t\right) \end{equation} This is a simple differential equation. Solving it (with S_0=0) gives: \begin{equation} \exp\left(-2S_t\right) \approx \frac{1}{(\alpha_0 p t + 1)^2} \end{equation} This is the explicit solution for the decay function, indicating that hyperparameters should decay according to the inverse square of the number of steps. Substituting this back into Eq. [eq:alpha-rho-2] gives the complete result: \begin{equation} \alpha_t \approx \frac{\alpha_0\Vert\boldsymbol{\varepsilon}_0\Vert}{\Vert\boldsymbol{u}_t\Vert} \frac{1}{(\alpha_0 p t + 1)^2},\quad \rho_t \approx \frac{\alpha_0^2}{2q} \frac{1}{(\alpha_0 p t + 1)^2}\label{eq:alpha-rho-3} \end{equation} This explicit solution not only makes implementation easier but also clarifies the meaning of p. For example, if we want the learning rate to drop to half after T steps, then (\alpha_0 p T + 1)^2 = 2, which gives: \begin{equation} \alpha_0 p = \frac{\sqrt{2}-1}{T} \end{equation} As for what T should be, it depends on the task difficulty and data volume, leaving little room for further derivation.
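In code, the inverse-square decay factor with \alpha_0 p fixed via a target half-life T might look like this (T and the half-life convention are whatever you choose for your task):

```python
import numpy as np

def inverse_square_decay(t, T):
    """Decay factor 1/(alpha0*p*t + 1)^2 from Eq. (alpha-rho-3),
    with alpha0*p = (sqrt(2) - 1)/T so that the factor reaches
    exactly 1/2 at step t = T."""
    a0p = (np.sqrt(2) - 1) / T
    return 1.0 / (a0p * t + 1) ** 2

print(inverse_square_decay(0, T=10_000))       # 1.0
print(inverse_square_decay(10_000, T=10_000))  # ≈ 0.5
```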
Dynamic Convergence
The previous discussion assumed the existence of a constant p > 0 such that \cos(\boldsymbol{u}_t, \boldsymbol{\varepsilon}_t) \sim p. This can be understood as the model converging at a fixed speed, which rarely holds in practice. More commonly, convergence slows down as training progresses. To account for this, we can assume p is a function of the step t, denoted p_t. The previous derivation remains largely valid, with the constant p replaced by p_i: \begin{equation} \sqrt{2\rho_t q} \approx \alpha_0 \exp\left(- \sum_{i=1}^{t-1}p_i\sqrt{2 \rho_i q}\right) \end{equation} Repeating the derivation from the previous section: \begin{equation} \frac{S_t - S_{t-1}}{p_t} \approx \alpha_0 \exp\left(- S_{t-1}\right) \quad \Rightarrow \quad S_{t+1} - S_t \approx \alpha_0 p_t\exp\left(- S_t\right) \end{equation} The approximate differential equation is: \begin{equation} \frac{dS_t}{dt} \approx \alpha_0 p_t \exp\left(- S_t\right) \end{equation} Integrating gives: \begin{equation} \exp\left(-S_t\right) \approx \frac{1}{\alpha_0 \int_0^t p_{\tau} d\tau + 1} \end{equation} Now we need to determine p_t. To reduce tuning costs, we might assume that the convergence speed decreases at the same rate as \Vert\boldsymbol{\varepsilon}_t\Vert. According to Eq. [eq:varepsilon-t], the decay function for \Vert\boldsymbol{\varepsilon}_t\Vert is \exp(-S_t), so we set p_t = p_0 \exp(-S_t). Substituting this into the equation: \begin{equation} \exp\left(-S_t\right) \approx \frac{1}{\alpha_0 p_0 \int_0^t \exp\left(-S_{\tau}\right) d\tau + 1} \end{equation} Differentiating both sides reduces this to a simple differential equation, which yields: \begin{equation} \exp\left(-2S_t\right) \approx \frac{1}{2\alpha_0 p_0 t + 1} \end{equation} Substituting this back into Eq. [eq:alpha-rho-2] gives: \begin{equation} \alpha_t \approx \frac{\alpha_0\Vert\boldsymbol{\varepsilon}_0\Vert}{\Vert\boldsymbol{u}_t\Vert} \frac{1}{2\alpha_0 p_0 t + 1},\quad \rho_t \approx \frac{\alpha_0^2}{2q} \frac{1}{2\alpha_0 p_0 t + 1}\label{eq:alpha-rho-4} \end{equation} Looking at the decay strategy alone, this is exactly “inverse time decay,” a common learning rate decay strategy. Theoretically, this result rests on more reasonable assumptions than Eq. [eq:alpha-rho-3].
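For comparison, the inverse-time factor of Eq. [eq:alpha-rho-4] has a much heavier tail than the inverse-square law of Eq. [eq:alpha-rho-3]; a quick sketch (the constant 1e-4 is arbitrary):

```python
def inverse_time_decay(t, a0p0):
    """Decay factor 1/(2*alpha0*p0*t + 1) from Eq. (alpha-rho-4)."""
    return 1.0 / (2 * a0p0 * t + 1)

def inverse_square_decay(t, a0p):
    """Decay factor 1/(alpha0*p*t + 1)^2 from Eq. (alpha-rho-3)."""
    return 1.0 / (a0p * t + 1) ** 2

# with the same constant, inverse-time decay keeps the learning rate
# much larger late in training than inverse-square decay
t = 1_000_000
print(inverse_time_decay(t, 1e-4) > inverse_square_decay(t, 1e-4))  # True
```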
Summary
This article draws on the ideas of the Amos optimizer to derive results regarding the learning rate and weight decay rate, specifically Eq. [eq:alpha-rho-3] and Eq. [eq:alpha-rho-4]. These results can be applied as plug-and-play components to existing optimizers, helping to simplify the difficulty of parameter tuning to some extent.
When reposting, please include the original article address: https://kexue.fm/archives/9344