In the previous article "Making 'Alchemy' More Scientific (III): Convergence of Endpoint Loss in SGD", we successfully transformed the convergence conclusion from average loss to endpoint loss, obtaining a convergence rate of \mathcal{O}(\sqrt{\ln T/T}). However, upon careful reflection, we find that this result is somewhat counter-intuitive: according to experience, the endpoint loss should be closer to the optimal value. If the convergence rate of the average loss can reach \mathcal{O}(1/\sqrt{T}), why would the endpoint convergence rate be slower?
The latest progress on this issue is found in the paper "Optimal Linear Decay Learning Rate Schedules and Further Refinements". The paper first generalizes the key identity proved previously and then points out the importance of learning rate scheduling for endpoint convergence, thereby accelerating the convergence of endpoint loss to \mathcal{O}(1/\sqrt{T}).
New Identity
The results in the original paper are quite rich, and we will introduce them across multiple articles. This article primarily follows the logic of the previous one to provide a preliminary introduction. To convert the convergence conclusion of the average loss into endpoint loss, the key identity introduced in the previous article was:
\begin{equation} q_T = \frac{1}{T}\sum_{t=1}^T q_t + \sum_{k=1}^{T-1} \frac{1}{k(k+1)}\sum_{t=T-k}^T (q_t - q_{T-k}) \end{equation}
In this article, we generalize it to a weighted average version: define w_{k:T} \triangleq \sum_{t=k}^T w_t, then we have:
\begin{equation} q_T = \frac{1}{w_{1:T}}\sum_{t=1}^T w_t q_t + \sum_{k=1}^{T-1}\left(\frac{1}{w_{k+1:T}} - \frac{1}{w_{k:T}}\right)\sum_{t=k}^T w_t (q_t - q_k) \label{eq:qt-g} \end{equation}
The proof follows essentially the same logic. Let \lambda_k = \sum_{t=T-k+1}^T w_t and S_k = \frac{1}{\lambda_k}\sum_{t=T-k+1}^T w_t q_t, then:
\begin{equation} \begin{aligned} \lambda_k S_k =&\, \lambda_{k+1} S_{k+1} - w_{T-k} q_{T-k} \\[5pt] =&\, \lambda_k S_{k+1} + w_{T-k} (S_{k+1} - q_{T-k}) \\ =&\, \lambda_k S_{k+1} + \frac{w_{T-k}}{\lambda_{k+1}}\sum_{t=T-k}^T w_t (q_t - q_{T-k}) \end{aligned} \end{equation}
Dividing both sides by \lambda_k and summing over k=1,\dots,T-1, the S_k terms telescope, giving:
\begin{equation} S_1 = S_T + \sum_{k=1}^{T-1}\frac{w_{T-k}}{\lambda_k\lambda_{k+1}}\sum_{t=T-k}^T w_t (q_t - q_{T-k}) \end{equation}
Noting that \frac{w_{T-k}}{\lambda_k\lambda_{k+1}} = \frac{1}{\lambda_k} - \frac{1}{\lambda_{k+1}}, and substituting the definitions S_1 = q_T and S_T = \frac{1}{w_{1:T}}\sum_{t=1}^T w_t q_t, we get:
\begin{equation} q_T = \frac{1}{\sum_{t=1}^T w_t}\sum_{t=1}^T w_t q_t + \sum_{k=1}^{T-1}\left(\frac{1}{\sum_{t=T-k+1}^T w_t} - \frac{1}{\sum_{t=T-k}^T w_t}\right)\sum_{t=T-k}^T w_t (q_t - q_{T-k}) \end{equation}
Finally, changing the summation variable k to T-k yields equation [eq:qt-g].
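As a sanity check (not part of the original paper), the identity [eq:qt-g] can be verified numerically on random sequences. Below is a minimal numpy sketch; the function name check_weighted_identity and the array layout are our own choices for illustration only.

```python
import numpy as np

def check_weighted_identity(q, w):
    """Numerically evaluate both sides of the weighted identity
        q_T = (1/w_{1:T}) * sum_t w_t q_t
              + sum_{k=1}^{T-1} (1/w_{k+1:T} - 1/w_{k:T}) * sum_{t=k}^T w_t (q_t - q_k).
    q and w hold q_1..q_T and w_1..w_T as 0-indexed arrays."""
    T = len(q)
    suffix = np.cumsum(w[::-1])[::-1]            # suffix[k-1] = w_{k:T}
    rhs = np.dot(w, q) / suffix[0]               # weighted-average term
    for k in range(1, T):                        # k = 1 .. T-1 (1-indexed)
        coeff = 1.0 / suffix[k] - 1.0 / suffix[k - 1]
        rhs += coeff * np.dot(w[k - 1:], q[k - 1:] - q[k - 1])
    return q[-1], rhs                            # (q_T, right-hand side)

rng = np.random.default_rng(0)
q = rng.normal(size=50)                          # arbitrary sequence q_t
w = rng.uniform(0.1, 1.0, size=50)               # arbitrary positive weights w_t
print(check_weighted_identity(q, w))             # the two numbers should coincide
```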
General Conclusion
Next, we start again from the core inequality of the second article "Making 'Alchemy' More Scientific (II): Extending Conclusions to Unbounded Domains":
\begin{equation} \sum_{t=1}^T \eta_t \mathbb{E}[L(\boldsymbol{\theta}_t) - L(\boldsymbol{\varphi})]\leq \frac{\Vert\boldsymbol{\theta}_1 - \boldsymbol{\varphi}\Vert^2}{2} + \frac{G^2}{2}\sum_{t=1}^T \eta_t^2 \label{leq:avg-2-mid3} \end{equation}
Following the logic of the "Preparation" section in "Making 'Alchemy' More Scientific (III): Convergence of Endpoint Loss in SGD", we change the starting point to k and substitute \boldsymbol{\varphi}=\boldsymbol{\theta}_k. This time, however, we neither assume that \eta_t is monotonically decreasing nor divide both sides by \eta_T; we simply obtain:
\begin{equation} \sum_{t=k}^T \eta_t\mathbb{E}[L(\boldsymbol{\theta}_t) - L(\boldsymbol{\theta}_k)] \leq \frac{G^2}{2} \sum_{t=k}^T \eta_t^2 \label{leq:avg-2-mid4} \end{equation}
Substituting w_t = \eta_t and q_t = \mathbb{E}[L(\boldsymbol{\theta}_t)] - L(\boldsymbol{\theta}^*) into the identity [eq:qt-g], then bounding the first sum by \eqref{leq:avg-2-mid3} with \boldsymbol{\varphi}=\boldsymbol{\theta}^* and the remaining sums by \eqref{leq:avg-2-mid4}, we get:
\begin{equation} \begin{aligned} \mathbb{E}[L(\boldsymbol{\theta}_T) - L(\boldsymbol{\theta}^*)] =&\, \frac{1}{\eta_{1:T}}\underbrace{\sum_{t=1}^T \eta_t \mathbb{E}[L(\boldsymbol{\theta}_t) - L(\boldsymbol{\theta}^*)]}_{\eqref{leq:avg-2-mid3}} + \sum_{k=1}^{T-1}\left(\frac{1}{\eta_{k+1:T}} - \frac{1}{\eta_{k:T}}\right)\underbrace{\sum_{t=k}^T \eta_t \mathbb{E}[L(\boldsymbol{\theta}_t) - L(\boldsymbol{\theta}_k)]}_{\eqref{leq:avg-2-mid4}} \\ \leq&\, \frac{\Vert\boldsymbol{\theta}_1 - \boldsymbol{\theta}^*\Vert^2}{2\eta_{1:T}} + \frac{G^2}{2\eta_{1:T}}\sum_{t=1}^T \eta_t^2 + \frac{G^2}{2}\sum_{k=1}^{T-1}\left(\frac{1}{\eta_{k+1:T}} - \frac{1}{\eta_{k:T}}\right)\sum_{t=k}^T \eta_t^2 \end{aligned} \label{leq:avg-2-mid5} \end{equation}
For the last term, swapping the order of summation via \sum_{k=1}^{T-1}\sum_{t=k}^T = \sum_{t=1}^{T} \sum_{k=1}^{\min(t, T-1)} and telescoping the inner sum gives:
\begin{equation} \begin{aligned} \sum_{k=1}^{T-1}\left(\frac{1}{\eta_{k+1:T}} - \frac{1}{\eta_{k:T}}\right)\sum_{t=k}^T \eta_t^2 =&\, \sum_{t=1}^T\eta_t^2\sum_{k=1}^{\min(t, T-1)}\left(\frac{1}{\eta_{k+1:T}} - \frac{1}{\eta_{k:T}}\right) = \sum_{t=1}^T\eta_t^2\left(\frac{1}{\eta_{\min(t+1, T):T}} - \frac{1}{\eta_{1:T}}\right) \end{aligned} \end{equation}
Substituting this back into equation [leq:avg-2-mid5], the \frac{G^2}{2\eta_{1:T}}\sum_{t=1}^T \eta_t^2 terms cancel, leaving:
\begin{equation} \mathbb{E}[L(\boldsymbol{\theta}_T) - L(\boldsymbol{\theta}^*)] \leq \frac{\Vert\boldsymbol{\theta}_1 - \boldsymbol{\theta}^*\Vert^2}{2\eta_{1:T}} + \frac{G^2}{2}\sum_{t=1}^T\frac{\eta_t^2}{\eta_{\min(t+1, T):T}} \label{leq:avg-3} \end{equation}
This is a strengthened result for endpoint loss convergence. It neither depends on the learning rate decreasing monotonically nor relies on dividing both sides by \eta_T, and therefore leaves more room for learning rate scheduling. The previous article was equivalent to simply replacing \eta_{1:T} and \eta_{\min(t+1, T):T} with T\eta_T and \max(1, T-t)\eta_T, which is clearly a coarser bound. This result may be new; it serves as an intermediate conclusion between the previous blog post and the original paper.
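Before specializing the schedule, it may help to see [leq:avg-3] as something that can be evaluated numerically for an arbitrary positive schedule. The numpy sketch below is our own illustration (not code from the paper); the name endpoint_bound is ours, with D standing in for \Vert\boldsymbol{\theta}_1 - \boldsymbol{\theta}^*\Vert and G for the gradient-norm bound.

```python
import numpy as np

def endpoint_bound(eta, D, G):
    """Right-hand side of the bound for a positive schedule eta_1..eta_T:
        D^2 / (2 * eta_{1:T}) + (G^2 / 2) * sum_t eta_t^2 / eta_{min(t+1, T):T}."""
    eta = np.asarray(eta, dtype=float)
    T = len(eta)
    suffix = np.cumsum(eta[::-1])[::-1]          # suffix[k-1] = eta_{k:T}
    # denominator eta_{min(t+1, T):T} for t = 1..T
    denom = np.array([suffix[min(t, T - 1)] for t in range(1, T + 1)])
    return D**2 / (2 * suffix[0]) + (G**2 / 2) * np.sum(eta**2 / denom)

# example: a constant schedule alpha = D / (G * sqrt(T)), as in the previous articles
T, D, G = 10_000, 1.0, 1.0
print(endpoint_bound(np.full(T, D / (G * np.sqrt(T))), D, G))
```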
Accelerating Convergence
In this section, we will see that under appropriate settings, equation [leq:avg-3] can achieve a convergence rate of \mathcal{O}(1/\sqrt{T}). Here, "appropriate settings" mainly refers to the learning rate schedule. Unlike the constant learning rates used before, or endpoint-free schedules such as \alpha/\sqrt{t} and \alpha/t, this time we choose "Linear Decay":
\begin{equation} \eta_t = \alpha\left(1 - \frac{t}{T+1}\right) \label{eq:liner-decay} \end{equation}
This schedule deserves a numbered line of its own for emphasis, because it is one of the current best practices among learning rate strategies. For example, "Straight to Zero: Why Linearly Decaying the Learning Rate to Zero Works Best for LLMs" argues that it is even superior to Cosine Decay, a sign that our discussion is moving closer to practical scenarios.
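For concreteness, here is what [eq:liner-decay] looks like as a standalone scheduling function. This is only a minimal sketch; the name linear_decay_lr is our own, and real training code would typically wrap such a function into the optimizer's scheduler.

```python
def linear_decay_lr(t, T, alpha):
    """Linear decay schedule: eta_t = alpha * (1 - t / (T + 1)), for t = 1..T.
    Note eta_T = alpha / (T + 1): the rate decays almost, but not exactly, to zero."""
    return alpha * (1 - t / (T + 1))

# the schedule over a short 10-step run
T, alpha = 10, 0.1
print([round(linear_decay_lr(t, T, alpha), 4) for t in range(1, T + 1)])
```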
Calculating step by step:
\begin{gather} \eta_{1:T} = \sum_{\tau=1}^T \alpha\left(1 - \frac{\tau}{T+1}\right) = \frac{\alpha T}{2} \\[6pt] \eta_{t+1:T} = \sum_{\tau=t+1}^T \alpha\left(1 - \frac{\tau}{T + 1}\right) = \frac{\alpha(T-t)(T+1-t)}{2(T+1)} \\[6pt] \frac{\eta_t^2}{\eta_{t+1:T}} = \frac{2\alpha (T+1-t)}{(T-t)(T+1)}\leq \frac{4\alpha}{T+1} \qquad (1\leq t\leq T-1) \\[6pt] \sum_{t=1}^T\frac{\eta_t^2}{\eta_{\min(t+1, T):T}} = \eta_T + \sum_{t=1}^{T-1}\frac{\eta_t^2}{\eta_{t+1:T}} \leq \frac{\alpha}{T+1} + \sum_{t=1}^{T-1} \frac{4\alpha}{T+1} \leq 4\alpha \end{gather}
Substituting these results into equation [leq:avg-3] gives:
\begin{equation} \mathbb{E}[L(\boldsymbol{\theta}_T) - L(\boldsymbol{\theta}^*)] \leq \frac{\Vert\boldsymbol{\theta}_1 - \boldsymbol{\theta}^*\Vert^2}{\alpha T} + 2 G^2 \alpha \end{equation}
Taking \alpha=\frac{\Vert\boldsymbol{\theta}_1 - \boldsymbol{\theta}^*\Vert}{G\sqrt{2T}} minimizes the right-hand side and yields an endpoint convergence rate of \mathcal{O}(1/\sqrt{T}). Note that without stronger assumptions, \mathcal{O}(1/\sqrt{T}) cannot be improved; this is guaranteed by an information-theoretic lower bound (see "Information-theoretic lower bounds on the oracle complexity of stochastic convex optimization"). Thus, this is already the theoretically optimal convergence rate.
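As a quick numerical sanity check of the three closed forms above, the following sketch (again only illustrative; the variable names are ours) compares them against direct summation under linear decay.

```python
import numpy as np

alpha, T = 0.1, 1000
t = np.arange(1, T + 1)
eta = alpha * (1 - t / (T + 1))                          # linear decay schedule
suffix = np.cumsum(eta[::-1])[::-1]                      # suffix[k-1] = eta_{k:T}

# eta_{1:T} should equal alpha * T / 2
print(suffix[0], alpha * T / 2)

# eta_{t+1:T} should equal alpha (T - t)(T + 1 - t) / (2 (T + 1)) for t = 1..T-1
k = np.arange(1, T)
closed_form = alpha * (T - k) * (T + 1 - k) / (2 * (T + 1))
print(np.max(np.abs(suffix[1:] - closed_form)))          # ~ 0 up to rounding error

# sum_t eta_t^2 / eta_{min(t+1, T):T} should not exceed 4 * alpha
denom = np.array([suffix[min(i, T - 1)] for i in range(1, T + 1)])
print(np.sum(eta**2 / denom), 4 * alpha)
```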
Extended Reflections
The derivations and conclusions above contain several key points that deserve special attention; in a sense, they mark a milestone in the convergence theory of stochastic optimization.
First, if the learning rate is set to a constant, the conclusion [leq:avg-3] reduces to that of the previous article, which we have already shown can achieve at best a convergence rate of \mathcal{O}(\sqrt{\ln T/T}); this is not optimal. The linear decay schedule in equation [eq:liner-decay], however, achieves \mathcal{O}(1/\sqrt{T}). This demonstrates, on the one hand, that learning rate decay is necessary for endpoint convergence and, on the other hand, provides theoretical support for the linear decay strategy.
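To get a feel for the gap, one can plug both schedules into [leq:avg-3] with the learning rate tuned for each. For a constant rate, a quick calculation (our own, not from the paper) gives a minimized bound of G D \sqrt{(1 + H_{T-1})/T}, where D = \Vert\boldsymbol{\theta}_1 - \boldsymbol{\theta}^*\Vert and H_{T-1} is the harmonic number, hence the \sqrt{\ln T} factor; for linear decay, the estimate above gives 2\sqrt{2}\, G D/\sqrt{T}. The back-of-the-envelope script below shows how slowly, but steadily, the ratio between the two grows with T.

```python
import numpy as np

G = D = 1.0
for T in [10**3, 10**4, 10**5, 10**6]:
    H = 1.0 + np.sum(1.0 / np.arange(1, T))                  # 1 + harmonic number H_{T-1}
    const_bound = G * D * np.sqrt(H / T)                     # constant LR, alpha tuned to minimize the bound
    linear_bound = 2 * np.sqrt(2) * G * D / np.sqrt(T)       # linear decay, alpha = D / (G sqrt(2T))
    print(T, round(const_bound, 4), round(linear_bound, 4), round(const_bound / linear_bound, 3))
```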
It is not hard to verify that, in the first three articles, the best convergence rates were attained with a constant learning rate, but one whose value depends on the total number of training steps T, such as \alpha/\sqrt{T}. Many works regard this as a drawback and prefer schedules like \alpha/\sqrt{t} or \alpha/t, which require no prior knowledge of T: they are endpoint-free strategies that let you stop at any time, resume at any time, and train for as many steps as you like.
However, such strategies usually do not perform well in practice. The new conclusions in this article show that once we switch to the endpoint loss, a new picture emerges: the fastest convergence is achieved neither by a T-dependent constant learning rate nor by a T-independent dynamic one, but by a schedule that is both, such as linear decay. Cosine-type decay, which is also common in practice, fits the same pattern. The shared keywords are: a better endpoint, dependence on T, and dynamic change.
In other words, there is no "one-size-fits-all" learning rate strategy: only by tuning the schedule according to the number of training steps can we obtain the best endpoint convergence. This aligns well with current Scaling Law practice. For instance, Step Law found that the optimal learning rate and optimal batch size should both be adjusted finely according to the data volume. Note that once the data volume and batch size are given, the number of training steps T is determined, so these optimal settings can equally be regarded as functions of T.
In a future article, we will also briefly discuss the connection between this series of conclusions and today’s Scaling Laws. Please stay tuned.
Summary
In this article, we generalized the core identity from the previous post and obtained the theoretically optimal endpoint loss convergence rate. Interestingly, the learning rate strategy that achieved this result was not a constant learning rate, nor the traditional inverse-step or inverse-square-root-step learning rates, but rather linear decay, which is closer to our daily practice. Next, we will continue to explore the profound significance behind this conclusion.
Original Address: https://kexue.fm/archives/11494