In this article, we explore a concept known as "Gradient Flow." Simply put, gradient flow joins the successive points visited by gradient descent while it searches for a minimum into a trajectory that evolves over (virtual) time. This trajectory is called the "gradient flow." In the latter part of the article, we focus on how to extend gradient flow to probability spaces, obtaining the "Wasserstein Gradient Flow," which offers a new perspective on the continuity equation, the Fokker-Planck equation, and related topics.
Gradient Descent
Suppose we want to find the minimum of a smooth function f(\boldsymbol{x}). A common approach is Gradient Descent (GD), which iterates as follows: \begin{equation} \boldsymbol{x}_{t+1} = \boldsymbol{x}_t -\alpha \nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\label{eq:gd-d} \end{equation} If f(\boldsymbol{x}) is convex, gradient descent with a suitable learning rate can usually find the global minimum; otherwise, it typically converges to a "stationary point", that is, a point where the gradient is zero, and in the ideal case to a local minimum. Here we do not strictly distinguish between local and global minima, because in deep learning even converging to a local minimum is quite an achievement.
If we write \alpha as \Delta t and \boldsymbol{x}_{t+1} as \boldsymbol{x}_{t+\Delta t}, and take the limit \Delta t \to 0, then equation [eq:gd-d] becomes an Ordinary Differential Equation (ODE): \begin{equation} \frac{d\boldsymbol{x}_t}{dt} = -\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\label{eq:gd-c} \end{equation} The trajectory \boldsymbol{x}_t obtained by solving this ODE is what we call "Gradient Flow." In other words, gradient flow is the continuous-time limit of the trajectory that gradient descent traces while seeking the minimum. Along any trajectory satisfying [eq:gd-c], we also have: \begin{equation} \frac{df(\boldsymbol{x}_t)}{dt} = \left\langle\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\frac{d\boldsymbol{x}_t}{dt}\right\rangle = -\Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert^2 \leq 0 \end{equation} This means that as long as \nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t) \neq \boldsymbol{0}, f(\boldsymbol{x}_t) strictly decreases along the gradient flow; correspondingly, each gradient descent step decreases f(\boldsymbol{x}) provided the learning rate is small enough.
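To make the limit \Delta t \to 0 concrete, here is a minimal sketch, assuming a hypothetical quadratic f(\boldsymbol{x}) = \frac{1}{2}\boldsymbol{x}^{\top}\boldsymbol{A}\boldsymbol{x} (the matrix \boldsymbol{A}, the starting point, and the horizon T are illustrative choices). For this f the ODE [eq:gd-c] has the closed-form solution \boldsymbol{x}_t = e^{-\boldsymbol{A}t}\boldsymbol{x}_0, so we can compare gradient descent run for T/\alpha steps against the exact gradient flow at time T, and check that f never increases along the iterates.

```python
import numpy as np

# Illustrative quadratic f(x) = 0.5 * x^T A x with symmetric positive-definite A
# (an assumption, chosen because the gradient flow then has a closed form).
A = np.array([[3.0, 0.5],
              [0.5, 1.0]])

def f(x):
    return 0.5 * x @ A @ x

def grad_f(x):
    return A @ x

def gradient_descent(x0, alpha, steps):
    """Iterate x_{t+1} = x_t - alpha * grad f(x_t); return every iterate."""
    xs = [x0]
    for _ in range(steps):
        xs.append(xs[-1] - alpha * grad_f(xs[-1]))
    return xs

def gradient_flow(x0, t):
    """Exact solution of dx/dt = -A x, i.e. x(t) = exp(-A t) x0, via eigendecomposition."""
    w, V = np.linalg.eigh(A)
    return V @ (np.exp(-w * t) * (V.T @ x0))

x0 = np.array([1.0, -2.0])
T = 2.0
errors, monotone = {}, {}
for alpha in [0.1, 0.01, 0.001]:
    xs = gradient_descent(x0, alpha, steps=int(round(T / alpha)))
    errors[alpha] = np.linalg.norm(xs[-1] - gradient_flow(x0, T))
    # f(x_t) should be non-increasing along the iterates for these small step sizes
    monotone[alpha] = all(f(a) >= f(b) for a, b in zip(xs[:-1], xs[1:]))
    print(f"alpha={alpha:<6} |x_GD(T) - x_flow(T)| = {errors[alpha]:.2e}  "
          f"f non-increasing: {monotone[alpha]}")
```

The gap shrinks roughly in proportion to \alpha, which is what one expects from reading [eq:gd-d] as an explicit Euler discretization of [eq:gd-c].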
For related discussions, see the earlier posts in the series on optimization algorithms, such as "Optimization Algorithms from a Dynamics Perspective (I): From SGD to Momentum Acceleration" and "Optimization Algorithms from a Dynamics Perspective (III): A More Holistic View".
Steepest Direction
Why use gradient descent? A common explanation is that "the negative gradient direction is the direction of steepest local descent," a phrase that is repeated widely. The statement is not wrong, but it is imprecise because it omits its prerequisite: the "steepest" in "steepest descent" is a comparison, and the outcome of a comparison depends on the metric used to make it. Only after the metric is fixed does "steepest" have a definite answer.
If we only care about the direction of steepest descent, the objective of gradient descent should be: \begin{equation} \boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x},\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert = \epsilon} f(\boldsymbol{x})\label{eq:gd-min-co} \end{equation} Assuming a first-order approximation is sufficient, the Cauchy-Schwarz inequality gives: \begin{equation} \begin{aligned} f(\boldsymbol{x})&\,\approx f(\boldsymbol{x}_t) + \langle \nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\boldsymbol{x} - \boldsymbol{x}_t\rangle\\ &\,\geq f(\boldsymbol{x}_t) - \Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert \Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert\\ &\,= f(\boldsymbol{x}_t) - \Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert \epsilon\\ \end{aligned} \end{equation} Equality in the Cauchy-Schwarz step holds when: \begin{equation} \boldsymbol{x} - \boldsymbol{x}_t = -\epsilon\frac{\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)}{\Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert}\quad\Rightarrow\quad\boldsymbol{x}_{t+1} = \boldsymbol{x}_t - \epsilon\frac{\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)}{\Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert}\label{eq:gd-d-norm} \end{equation} As we can see, the update direction is exactly the negative gradient direction, so it is indeed the direction of steepest local descent. However, this result relies on the constraint \Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert = \epsilon, where \Vert\cdot\Vert is the Euclidean norm. If we change the norm, or change the constraint altogether, the result will be different. Therefore, strictly speaking, the statement should read: "In Euclidean space, the negative gradient direction is the direction of steepest local descent."
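The dependence on the norm can be checked numerically. The sketch below, with an arbitrary illustrative gradient \boldsymbol{g} and step length \epsilon, brute-forces the minimizer of the first-order model \langle\boldsymbol{g},\boldsymbol{d}\rangle over all steps of length \epsilon. It does this once for the Euclidean norm, recovering [eq:gd-d-norm], and once for a hypothetical weighted norm \Vert\boldsymbol{d}\Vert_{\boldsymbol{M}} = \sqrt{\boldsymbol{d}^{\top}\boldsymbol{M}\boldsymbol{d}}, whose steepest direction works out (via a Lagrange multiplier) to -\boldsymbol{M}^{-1}\boldsymbol{g} rather than -\boldsymbol{g}.

```python
import numpy as np

g = np.array([2.0, 1.0])   # gradient of f at x_t (arbitrary illustrative value)
eps = 0.1                  # step length

# Candidate steps u on the Euclidean circle ||u|| = eps.
theta = np.linspace(0.0, 2 * np.pi, 200_001)
U = eps * np.stack([np.cos(theta), np.sin(theta)], axis=1)

# First-order model: f(x_t + d) - f(x_t) ≈ <g, d>; pick the d that decreases it most.
best_euclid = U[np.argmin(U @ g)]
closed_euclid = -eps * g / np.linalg.norm(g)

# Same search under a different (hypothetical) norm ||d||_M = sqrt(d^T M d).
M = np.array([[4.0, 0.0],
              [0.0, 1.0]])
L = np.linalg.cholesky(M)                 # M = L L^T
D_M = np.linalg.solve(L.T, U.T).T         # d = L^{-T} u satisfies d^T M d = ||u||^2 = eps^2
best_M = D_M[np.argmin(D_M @ g)]
Minv_g = np.linalg.solve(M, g)
closed_M = -eps * Minv_g / np.sqrt(g @ Minv_g)  # Lagrange-multiplier solution

print("Euclidean norm: brute force", best_euclid, " closed form", closed_euclid)
print("M-norm:         brute force", best_M, " closed form", closed_M)
```

The two "steepest" directions point noticeably differently, even though the function and its gradient are the same; only the metric changed.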
Optimization Perspective
Equation [eq:gd-min-co] is a constrained optimization problem, which is inconvenient to solve and to generalize. Furthermore, its solution is [eq:gd-d-norm], not the original gradient descent [eq:gd-d]. In fact, one can show that, to first order in \alpha, the original gradient descent [eq:gd-d] corresponds to the following optimization objective: \begin{equation} \boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert^2}{2\alpha} + f(\boldsymbol{x})\label{eq:gd-min} \end{equation} In other words, the constraint is moved into the objective as a penalty term, so there is no constraint to handle and the formulation is easy to generalize. Moreover, adding the extra term \frac{\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert^2}{2\alpha} never makes the result worse: substituting \boldsymbol{x} = \boldsymbol{x}_t shows that the objective equals f(\boldsymbol{x}_t) there, so its minimum is at most f(\boldsymbol{x}_t); since the penalty term is nonnegative, it follows that f(\boldsymbol{x}_{t+1}) \leq f(\boldsymbol{x}_t).
When \alpha is small, the penalty term \frac{\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert^2}{2\alpha} is very large unless \boldsymbol{x} is close to \boldsymbol{x}_t, so the optimal point must lie very close to \boldsymbol{x}_t. We can therefore expand f(\boldsymbol{x}) to first order around \boldsymbol{x}_t: \begin{equation} \boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert^2}{2\alpha} + f(\boldsymbol{x}_t)+\langle\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\boldsymbol{x}-\boldsymbol{x}_t\rangle \end{equation} This is simply a quadratic minimization problem: setting the gradient with respect to \boldsymbol{x} to zero gives \frac{\boldsymbol{x}-\boldsymbol{x}_t}{\alpha} + \nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t) = \boldsymbol{0}, whose solution is exactly equation [eq:gd-d].
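The "first order in \alpha" qualification can be seen directly. The exact minimizer of [eq:gd-min] satisfies \boldsymbol{x}_{t+1} = \boldsymbol{x}_t - \alpha\nabla f(\boldsymbol{x}_{t+1}), an implicit step, while [eq:gd-d] evaluates the gradient at \boldsymbol{x}_t. The sketch below assumes a hypothetical quadratic f(\boldsymbol{x}) = \frac{1}{2}\boldsymbol{x}^{\top}\boldsymbol{A}\boldsymbol{x} + \boldsymbol{b}^{\top}\boldsymbol{x}, for which the penalized problem is solvable in closed form. A short calculation gives the gap between the two updates as \alpha^2\boldsymbol{A}(\boldsymbol{I}+\alpha\boldsymbol{A})^{-1}\nabla f(\boldsymbol{x}_t), so halving \alpha should divide it by roughly 4. The code also checks the descent property f(\boldsymbol{x}_{t+1}) \leq f(\boldsymbol{x}_t) argued above.

```python
import numpy as np

# Illustrative quadratic f(x) = 0.5 x^T A x + b^T x (an assumption that lets the
# penalized problem be solved exactly).
A = np.array([[2.0, 0.3],
              [0.3, 0.5]])
b = np.array([-1.0, 0.5])
x_t = np.array([1.0, 1.0])

def f(x):
    return 0.5 * x @ A @ x + b @ x

def grad_f(x):
    return A @ x + b

def penalized_step(x_t, alpha):
    """Exact argmin_x ||x - x_t||^2 / (2 alpha) + f(x).
    Zero gradient: (x - x_t) / alpha + A x + b = 0  =>  (I + alpha A) x = x_t - alpha b."""
    return np.linalg.solve(np.eye(len(x_t)) + alpha * A, x_t - alpha * b)

def gd_step(x_t, alpha):
    """Plain gradient descent step."""
    return x_t - alpha * grad_f(x_t)

gaps = {}
for alpha in [0.1, 0.05, 0.025]:
    x_pen = penalized_step(x_t, alpha)
    gaps[alpha] = np.linalg.norm(x_pen - gd_step(x_t, alpha))
    print(f"alpha={alpha:<6} gap={gaps[alpha]:.3e}  gap/alpha^2={gaps[alpha] / alpha**2:.3f}  "
          f"f: {f(x_t):.4f} -> {f(x_pen):.4f}")
```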
Clearly, besides the squared norm, we can use other regularization terms, which lead to different gradient descent schemes. For example, Natural Gradient Descent uses the KL divergence as the regularization term: \begin{equation} \boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{KL(p(\boldsymbol{y}|\boldsymbol{x})\Vert p(\boldsymbol{y}|\boldsymbol{x}_t))}{\alpha} + f(\boldsymbol{x}) \end{equation} where p(\boldsymbol{y}|\boldsymbol{x}) is some probability distribution related to f(\boldsymbol{x}). To solve this problem, we again expand around \boldsymbol{x}_t. f(\boldsymbol{x}) is expanded to first order, but the KL divergence needs special care: it is nonnegative and equals zero at \boldsymbol{x} = \boldsymbol{x}_t, so its first-order term vanishes and it must be expanded to at least second order. The result is: \begin{equation} \boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{(\boldsymbol{x}-\boldsymbol{x}_t)^{\top}\boldsymbol{F}(\boldsymbol{x}-\boldsymbol{x}_t)}{2\alpha} + f(\boldsymbol{x}_t)+\langle\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\boldsymbol{x}-\boldsymbol{x}_t\rangle \end{equation} Here \boldsymbol{F} is the Fisher Information Matrix; we omit the calculation details. The above is again a quadratic minimization problem, and its solution is: \begin{equation} \boldsymbol{x}_{t+1} = \boldsymbol{x}_t -\alpha \boldsymbol{F}^{-1}\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t) \end{equation} This is the so-called "Natural Gradient Descent."
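As a sanity check on the second-order expansion, consider a hypothetical model p(\boldsymbol{y}|\boldsymbol{x}) = \mathcal{N}(\boldsymbol{y};\boldsymbol{x},\boldsymbol{\Sigma}) with a fixed covariance \boldsymbol{\Sigma}. Its Fisher information with respect to the mean \boldsymbol{x} is \boldsymbol{\Sigma}^{-1}. The sketch below estimates the Hessian of KL(p(\boldsymbol{y}|\boldsymbol{x})\Vert p(\boldsymbol{y}|\boldsymbol{x}_t)) at \boldsymbol{x}=\boldsymbol{x}_t by central finite differences, compares it with \boldsymbol{F}, and then takes one natural gradient step for an arbitrary illustrative gradient. The choice of model, \boldsymbol{\Sigma}, \boldsymbol{x}_t, and the gradient are all assumptions made for the example.

```python
import numpy as np

# Hypothetical model: p(y|x) = N(y; x, Sigma) with a fixed covariance Sigma.
Sigma = np.array([[1.0, 0.8],
                  [0.8, 2.0]])

def kl_gauss(mu0, S0, mu1, S1):
    """Closed-form KL(N(mu0, S0) || N(mu1, S1))."""
    k = mu0.shape[0]
    S1_inv = np.linalg.inv(S1)
    d = mu1 - mu0
    return 0.5 * (np.trace(S1_inv @ S0) + d @ S1_inv @ d - k
                  + np.log(np.linalg.det(S1) / np.linalg.det(S0)))

x_t = np.array([0.5, -1.0])

def kl(x):
    """KL(p(y|x) || p(y|x_t)) as a function of x."""
    return kl_gauss(x, Sigma, x_t, Sigma)

# Central finite-difference Hessian of the KL at x = x_t.
h = 1e-3
E = np.eye(2)
H = np.array([[(kl(x_t + h * E[i] + h * E[j]) - kl(x_t + h * E[i] - h * E[j])
                - kl(x_t - h * E[i] + h * E[j]) + kl(x_t - h * E[i] - h * E[j])) / (4 * h * h)
               for j in range(2)] for i in range(2)])
F = np.linalg.inv(Sigma)  # Fisher information of N(x, Sigma) w.r.t. the mean x
print("finite-difference Hessian of KL:\n", H)
print("Fisher matrix:\n", F)

# One natural gradient step for an illustrative gradient of f at x_t.
grad = np.array([1.0, 0.0])
alpha = 0.1
x_nat = x_t - alpha * np.linalg.solve(F, grad)  # x_t - alpha F^{-1} grad
print("natural gradient step:", x_nat)
print("plain gradient step  :", x_t - alpha * grad)
```

Because \boldsymbol{F}^{-1} = \boldsymbol{\Sigma} here has off-diagonal entries, the natural gradient step also moves the second coordinate even though the gradient has no component there.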
Introduction to Functionals
Equation [eq:gd-min] allows us to generalize not only the regularization term but also the optimization objective itself, for instance by replacing the function f with a functional.
The term "functional" might sound intimidating, but regular readers of this site have encountered it many times. Simply put, whereas an ordinary multivariate function takes a vector as input and outputs a scalar, a functional takes a function as input and outputs a scalar. For example, the definite integral: \begin{equation} \mathcal{I}[f] = \int_a^b f(x)dx \end{equation} For any function f, \mathcal{I}[f] is a scalar, so \mathcal{I}[f] is a functional. Another example is the KL divergence mentioned earlier, defined as: \begin{equation} KL(p\Vert q) = \int p(\boldsymbol{x})\log \frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}d\boldsymbol{x} \end{equation} where the integral is over the entire space. If p(\boldsymbol{x}) is fixed, this is a functional of q(\boldsymbol{x}). More generally, the f-divergence introduced in "Introduction to f-GAN: The Production Workshop of GAN Models" is also a functional. These are relatively simple functionals; more complex ones may also involve derivatives of the input function, such as the action in the principle of least action in theoretical physics.
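To make "function in, scalar out" concrete, the sketch below implements the two example functionals numerically with the trapezoidal rule. The helper names are hypothetical, and truncating the real line to a finite interval for the KL divergence is an assumption that is harmless for the rapidly decaying Gaussians used as test inputs. For reference, \int_0^{\pi}\sin x\,dx = 2 and KL(\mathcal{N}(0,1)\Vert\mathcal{N}(1,1)) = \frac{1}{2}.

```python
import numpy as np

def trapezoid(y, x):
    """Trapezoidal rule for samples y on a uniform grid x."""
    return (x[1] - x[0]) * (y.sum() - 0.5 * (y[0] + y[-1]))

def definite_integral(f, a, b, n=10_001):
    """The functional I[f] = ∫_a^b f(x) dx: a whole function in, one number out."""
    x = np.linspace(a, b, n)
    return trapezoid(f(x), x)

def gaussian_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def kl_divergence(p, q, lo=-12.0, hi=12.0, n=24_001):
    """KL(p || q) = ∫ p log(p / q) dx for density functions p and q.
    Truncating the real line to [lo, hi] is an assumption that is harmless
    for the rapidly decaying Gaussians used below."""
    x = np.linspace(lo, hi, n)
    px, qx = p(x), q(x)
    return trapezoid(px * np.log(px / qx), x)

p = lambda x: gaussian_pdf(x, 0.0, 1.0)
q = lambda x: gaussian_pdf(x, 1.0, 1.0)
print(definite_integral(np.sin, 0.0, np.pi))  # close to the exact value 2
print(kl_divergence(p, q))                     # close to the exact value 0.5
```

With p held fixed, `kl_divergence(p, ·)` is exactly the kind of map from a function q to a number that the text calls a functional of q.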
In the following, we mainly focus on functionals whose domain is the set of all probability density functions.
Probability Flow
Suppose we have a functional \mathcal{F}[q] and we want to find its minimum. Following the idea of gradient descent, if we can find some kind of gradient for it, we can iterate along its negative direction.
To determine the iteration scheme, we extend equation [eq:gd-min] by replacing f(\boldsymbol{x}) with \mathcal{F}[q]. What should replace the first (regularization) term? In equation [eq:gd-min] it is the squared Euclidean distance, so the natural choice is the squared distance between probability distributions. A well-behaved distance for this purpose is the Wasserstein distance (specifically, the 2-Wasserstein distance): \begin{equation} \mathcal{W}_2[p,q]=\sqrt{\inf_{\gamma\in \Pi[p,q]} \iint \gamma(\boldsymbol{x},\boldsymbol{y}) \Vert\boldsymbol{x}-\boldsymbol{y}\Vert^2 d\boldsymbol{x}d\boldsymbol{y}} \end{equation} where \Pi[p,q] is the set of all joint distributions (couplings) \gamma(\boldsymbol{x},\boldsymbol{y}) whose marginals are p and q. Replacing the Euclidean distance in [eq:gd-min] with the Wasserstein distance, the objective becomes: \begin{equation} q_{t+1} = \mathop{\text{argmin}}_{q} \frac{\mathcal{W}_2^2[q,q_t]}{2\alpha} + \mathcal{F}[q] \end{equation} Deriving the solution of this objective is involved. Following literature such as "Introduction to Gradient Flows in the 2-Wasserstein Space", to first order in \alpha the solution is: \begin{equation} q_{t+1}(\boldsymbol{x}) = q_t(\boldsymbol{x}) + \alpha \nabla_{\boldsymbol{x}}\cdot\left(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\delta \mathcal{F}[q_t(\boldsymbol{x})]}{\delta q_t(\boldsymbol{x})}\right) \end{equation} Taking the limit \alpha \to 0, we obtain: \begin{equation} \frac{\partial q_t(\boldsymbol{x})}{\partial t} = \nabla_{\boldsymbol{x}}\cdot\left(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\delta \mathcal{F}[q_t(\boldsymbol{x})]}{\delta q_t(\boldsymbol{x})}\right) \end{equation} This is the "Wasserstein Gradient Flow," where \frac{\delta \mathcal{F}[q]}{\delta q} is the variational derivative of \mathcal{F}[q].
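A simple case shows how the Wasserstein gradient flow ties back to the ordinary gradient flow of the first section. For the potential-energy functional \mathcal{F}[q] = \int V(\boldsymbol{x})q(\boldsymbol{x})d\boldsymbol{x}, the variational derivative is V(\boldsymbol{x}). The flow is then \frac{\partial q_t}{\partial t} = \nabla_{\boldsymbol{x}}\cdot(q_t\nabla_{\boldsymbol{x}}V), a continuity equation with velocity -\nabla_{\boldsymbol{x}}V, so every sample simply follows \frac{d\boldsymbol{x}}{dt} = -\nabla_{\boldsymbol{x}}V(\boldsymbol{x}). The 1D sketch below assumes V(x) = x^2/2 and q_0 = \mathcal{N}(2, 0.5^2) (both hypothetical). It integrates the PDE with a first-order upwind finite-volume scheme (a deliberately simple, somewhat diffusive choice) and compares its mean with particles moved by the ODE and with the exact mean 2e^{-t}.

```python
import numpy as np

# Hypothetical potential V(x) = x^2 / 2, so F[q] = ∫ V q dx has delta F / delta q = V
# and the Wasserstein gradient flow is dq/dt = d/dx (q dV/dx) with dV/dx = x.
def grad_V(x):
    return x

m0, s0, T, dt = 2.0, 0.5, 1.0, 1e-3    # q_0 = N(m0, s0^2); explicit Euler in time
steps = int(round(T / dt))

# Particle picture: every sample follows the ordinary gradient flow dx/dt = -V'(x).
rng = np.random.default_rng(0)
particles = rng.normal(m0, s0, size=50_000)
for _ in range(steps):
    particles = particles - dt * grad_V(particles)

# Density picture: first-order upwind finite-volume scheme for the same PDE.
x = np.linspace(-6.0, 6.0, 1201)                 # cell centres
dx = x[1] - x[0]
q = np.exp(-0.5 * ((x - m0) / s0) ** 2) / (s0 * np.sqrt(2 * np.pi))
v = -grad_V(0.5 * (x[:-1] + x[1:]))              # velocity -V'(x) at interior faces
for _ in range(steps):
    upwind = np.where(v > 0, q[:-1], q[1:])      # density taken from the upstream cell
    flux = np.concatenate(([0.0], v * upwind, [0.0]))  # no flux through the outer walls
    q = q - dt / dx * (flux[1:] - flux[:-1])

exact_mean = m0 * np.exp(-T)                     # x(t) = x0 e^{-t} for this V
pde_mass = q.sum() * dx
pde_mean = (x * q).sum() * dx
print("exact mean   :", exact_mean)
print("particle mean:", particles.mean())
print("PDE mean     :", pde_mean, "  PDE mass:", pde_mass)
```

The finite-volume form conserves total probability by construction, mirroring the fact that the flow moves mass around without creating or destroying it.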
For a functional given by a definite integral whose integrand depends on \boldsymbol{x} only through q(\boldsymbol{x}) (and not through its derivatives), the variational derivative is simply the derivative of the integrand with respect to q: \begin{equation} \mathcal{F}[q] = \int F(q(\boldsymbol{x}))d\boldsymbol{x} \quad\Rightarrow\quad \frac{\delta \mathcal{F}[q(\boldsymbol{x})]}{\delta q(\boldsymbol{x})} = \frac{\partial F(q(\boldsymbol{x}))}{\partial q(\boldsymbol{x})} \end{equation}
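On a grid, this rule has a direct interpretation: bump q up and down in a single cell, and divide the resulting change in \mathcal{F}[q] by the mass of the bump. The sketch below does this for the hypothetical choice F(q) = q\log q (the negative-entropy integrand), whose derivative is \log q + 1, using a standard normal test density on a uniform grid (also an assumption).

```python
import numpy as np

# Grid and a test density q = N(0, 1) (illustrative assumptions).
x = np.linspace(-8.0, 8.0, 1601)
dx = x[1] - x[0]
q = np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi)

def integrand(q):
    """F(q) = q log q, the negative-entropy integrand; dF/dq = log q + 1."""
    return q * np.log(q)

def functional(q):
    """Riemann-sum discretization of F[q] = ∫ F(q(x)) dx."""
    return integrand(q).sum() * dx

def variational_derivative_fd(q, i, eta=1e-4):
    """Perturb q in grid cell i only and divide the change in F[q] by the
    perturbation's mass (2 * eta * dx): a discrete delta F / delta q at x_i."""
    q_plus, q_minus = q.copy(), q.copy()
    q_plus[i] += eta
    q_minus[i] -= eta
    return (functional(q_plus) - functional(q_minus)) / (2 * eta * dx)

for xi in [-1.0, 0.0, 0.5, 2.0]:
    i = int(np.argmin(np.abs(x - xi)))
    fd = variational_derivative_fd(q, i)
    exact = np.log(q[i]) + 1.0
    print(f"x={x[i]:+.2f}  finite difference={fd:.6f}  log q + 1={exact:.6f}")
```

The additive constant 1 in \log q + 1 is irrelevant for the Wasserstein gradient flow, since only \nabla_{\boldsymbol{x}}\frac{\delta\mathcal{F}}{\delta q} enters the equation.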
Some Examples
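By definition, the f-divergence is: \begin{equation} \mathcal{D}_f(p\Vert q) = \int q(\boldsymbol{x}) f\left(\frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}\right)d\boldsymbol{x} \end{equation} Fixing p and setting \mathcal{F}[q]=\mathcal{D}_f(p\Vert q), the integrand is q f(p/q), and its derivative with respect to q is \begin{equation} \frac{\partial}{\partial q}\left[q f\left(\frac{p}{q}\right)\right] = f\left(\frac{p}{q}\right) - \frac{p}{q}f'\left(\frac{p}{q}\right) \end{equation} Substituting this variational derivative into the Wasserstein gradient flow gives: \begin{equation} \frac{\partial q_t(\boldsymbol{x})}{\partial t} = \nabla_{\boldsymbol{x}}\cdot\Big(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\big(f(r_t(\boldsymbol{x})) - r_t(\boldsymbol{x}) f'(r_t(\boldsymbol{x}))\big)\Big)\label{eq:wgd} \end{equation} where r_t(\boldsymbol{x}) = \frac{p(\boldsymbol{x})}{q_t(\boldsymbol{x})}. This has the form of a continuity equation \frac{\partial q_t}{\partial t} + \nabla_{\boldsymbol{x}}\cdot(q_t\boldsymbol{v}_t) = 0 with velocity field \boldsymbol{v}_t = -\nabla_{\boldsymbol{x}}\big(f(r_t) - r_t f'(r_t)\big). Therefore, starting from samples of q_0 and moving each sample according to the ODE: \begin{equation} \frac{d\boldsymbol{x}}{dt} = -\nabla_{\boldsymbol{x}}\big(f(r_t(\boldsymbol{x})) - r_t(\boldsymbol{x}) f'(r_t(\boldsymbol{x}))\big) \end{equation} yields samples of q_t. As t \to \infty, q_t \to p, so the ODE eventually samples from p. However, this is often computationally difficult, since the velocity field requires knowing q_t.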
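Applying the chain rule to the velocity field of this ODE gives the equivalent form \boldsymbol{v}_t(\boldsymbol{x}) = r_t^2 f''(r_t)\left(\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) - \nabla_{\boldsymbol{x}}\log q_t(\boldsymbol{x})\right), which makes the dependence on the unknown q_t explicit. The sketch below sidesteps that dependence with a crude assumption of our own, not part of the method described above: at every step, q_t is replaced by a Gaussian fitted to the current particles. The target p = \mathcal{N}(0,1), the initial q_0 = \mathcal{N}(3, 0.5^2), the step size, and the helper names are all hypothetical. It is run with f(r) = -\log r, for which r^2 f''(r) = 1.

```python
import numpy as np

# Hypothetical target p = N(0, 1).
def log_p(x):
    return -0.5 * x**2 - 0.5 * np.log(2 * np.pi)

def score_p(x):
    return -x                      # d/dx log p(x)

def f_second_reverse_kl(r):
    return 1.0 / r**2              # f(r) = -log r  =>  f''(r) = 1 / r^2

def velocity(x, f_second):
    """v_t(x) = r^2 f''(r) (d/dx log p - d/dx log q_t), with q_t replaced by a
    Gaussian fitted to the current particles (an assumption: q_t is unknown)."""
    m, s = x.mean(), x.std()
    log_q = -0.5 * ((x - m) / s) ** 2 - np.log(s * np.sqrt(2 * np.pi))
    score_q = -(x - m) / s**2
    r = np.exp(log_p(x) - log_q)   # density ratio r_t = p / q_t at the particles
    return r**2 * f_second(r) * (score_p(x) - score_q)

rng = np.random.default_rng(0)
x = rng.normal(3.0, 0.5, size=10_000)  # samples from q_0 = N(3, 0.5^2)
dt = 0.01
for _ in range(1000):                  # explicit Euler up to t = 10
    x = x + dt * velocity(x, f_second_reverse_kl)

print("mean:", x.mean(), " std:", x.std())  # should move toward p's mean 0 and std 1
```

The Gaussian fit happens to be self-consistent here, because with a Gaussian target and f(r) = -\log r the velocity is linear in x; for other targets or other choices of f it is only an approximation.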
A simpler example is the (reverse) KL divergence \mathcal{D}_f(p\Vert q) = KL(q\Vert p), obtained with f(r) = -\log r. In this case f(r) - r f'(r) = 1 - \log r, and the constant 1 vanishes under the gradient. Substituting into [eq:wgd] yields: \begin{equation} \begin{aligned}\frac{\partial q_t(\boldsymbol{x})}{\partial t} =&\, - \nabla_{\boldsymbol{x}}\cdot\left(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log \frac{p(\boldsymbol{x})}{q_t(\boldsymbol{x})}\right)\\ =&\, - \nabla_{\boldsymbol{x}}\cdot\Big(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\big(\log p(\boldsymbol{x}) - \log q_t(\boldsymbol{x})\big)\Big)\\ =&\, - \nabla_{\boldsymbol{x}}\cdot\big(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x})\big) + \nabla_{\boldsymbol{x}}\cdot\nabla_{\boldsymbol{x}} q_t(\boldsymbol{x}) \end{aligned} \end{equation} where the last step uses q_t\nabla_{\boldsymbol{x}}\log q_t = \nabla_{\boldsymbol{x}} q_t. This is exactly the Fokker-Planck equation of the SDE (Langevin dynamics): \begin{equation} d\boldsymbol{x} = \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) dt + \sqrt{2}\,d\boldsymbol{w} \end{equation} where \boldsymbol{w} is a standard Wiener process. If we know the score \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}), which does not depend on the normalizing constant of p, we can simulate this SDE to sample from p(\boldsymbol{x}) without ever computing q_t(\boldsymbol{x}).
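The following is a minimal sketch of that sampler, assuming a hypothetical 1D target p = \mathcal{N}(2, 0.5^2) so the result is easy to check. It discretizes the SDE with the Euler-Maruyama scheme, which is one standard choice among several. Only the score function is used, and the particles start from an arbitrary \mathcal{N}(0,1).

```python
import numpy as np

# Hypothetical target p = N(mu, sigma^2); only its score is needed, not its normalizer.
mu, sigma = 2.0, 0.5

def score(x):
    return -(x - mu) / sigma**2        # grad log p(x)

rng = np.random.default_rng(42)
h = 0.01                               # Euler-Maruyama step size
x = rng.normal(0.0, 1.0, size=20_000)  # arbitrary initial distribution q_0
for _ in range(2000):                  # integrate up to t = 20
    # dx = score(x) dt + sqrt(2) dw, with dw ~ N(0, h) per step
    x = x + h * score(x) + np.sqrt(2 * h) * rng.standard_normal(x.shape)

print("sample mean:", x.mean(), " sample var:", x.var())
```

The discretization introduces a bias of order h; for this Gaussian target the stationary variance of the discrete chain is \sigma^2/(1 - h/(2\sigma^2)), slightly above \sigma^2, and it shrinks toward \sigma^2 as h decreases.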
Summary
This article introduced the concept of "Gradient Flow" in the process of finding minima, extending from vector spaces to the Wasserstein Gradient Flow in probability spaces. We also discussed their connections to the continuity equation, the Fokker-Planck equation, and ODE/SDE sampling.
Reprinting: Please include the original address of this article: https://kexue.fm/archives/9660