When seeking the minimum of a function, we usually first find the derivative and look for its roots. In fortunate cases, one of these roots happens to be the minimum point of the original function. For vector-valued functions, the derivative is replaced by the gradient, and we seek its zeros. When the gradient’s roots are difficult to obtain, we can use gradient descent to gradually approach the minimum point.
The above are fundamental results of unconstrained optimization, which I believe many readers are familiar with. However, the theme of this article is optimization in probability space—that is, the input to the objective function is a probability distribution. Optimization of this type is more complex because its search space is no longer unconstrained. If we were to simply solve for gradient zeros or perform gradient descent, the resulting outcome might not necessarily be a valid probability distribution. Therefore, we need to find new analytical and computational methods to ensure that the optimization results conform to the characteristics of a probability distribution.
I have personally found this topic quite challenging for a long time. Recently, I decided to "reflect on the pain" and systematically study the problem of optimization over probability distributions. I have organized my findings here for your reference.
Gradient Descent
Let’s first revisit the relevant content of unconstrained optimization. Suppose our goal is: \begin{equation} \boldsymbol{x}_* = \mathop{\text{argmin}}_{\boldsymbol{x}\in\mathbb{R}^n} F(\boldsymbol{x}) \end{equation} Even high school students know that to find the extremum of a function, one often takes the derivative and sets it to zero to find critical points. For many, this has become "common sense." But let me test the readers: how many can prove this conclusion? In other words, why is the extremum of a function related to the "derivative being zero"?
Search Perspective
We can explore this problem from a search perspective. Suppose our current candidate point is \boldsymbol{x}_t. How do we determine whether \boldsymbol{x}_t is the minimum point? We can think about this in reverse: if we can find \boldsymbol{x}_{t+\eta} such that F(\boldsymbol{x}_{t+\eta}) < F(\boldsymbol{x}_t), then \boldsymbol{x}_t cannot be the minimum point. To this end, we can search for \boldsymbol{x}_{t+\eta} in the following form: \begin{equation} \boldsymbol{x}_{t+\eta} = \boldsymbol{x}_t + \eta \boldsymbol{u}_t,\quad 0 < \eta \ll 1 \end{equation} When F(\boldsymbol{x}) is sufficiently smooth and \eta is sufficiently small, we consider the first-order approximation to be accurate enough, so we use: \begin{equation} F(\boldsymbol{x}_{t+\eta}) = F(\boldsymbol{x}_t + \eta \boldsymbol{u}_t) \approx F(\boldsymbol{x}_t) + \eta \boldsymbol{u}_t \cdot \nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t) \end{equation} As long as \nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t)\neq 0, we can choose \boldsymbol{u}_t = -\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t), such that: \begin{equation} F(\boldsymbol{x}_{t+\eta}) \approx F(\boldsymbol{x}_t) - \eta \Vert\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t)\Vert^2 < F(\boldsymbol{x}_t) \end{equation} This means that for a sufficiently smooth function, its minimum can only be achieved at points where \nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t) = 0 or at infinity. This is why the first step in finding an extremum is usually "setting the derivative to zero." If \nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t) \neq 0, we can always choose a sufficiently small \eta and use: \begin{equation} \boldsymbol{x}_{t+\eta} = \boldsymbol{x}_t-\eta\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t)\label{eq:gd} \end{equation} to obtain a point that makes F smaller; this is gradient descent. If we let \eta\to 0, we obtain the ODE: \begin{equation} \frac{d\boldsymbol{x}_t}{dt} = -\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t) \end{equation} This is the "Gradient Flow" introduced in "Gradient Flow: Exploring the Path to the Minimum", which can be viewed as the trajectory of our search for the minimum point using gradient descent.
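To make the iteration \eqref{eq:gd} concrete, here is a minimal NumPy sketch; the quadratic toy objective, step size, and stopping rule are illustrative choices of mine, not anything prescribed above:

```python
import numpy as np

def gradient_descent(grad_F, x0, eta=0.1, steps=1000, tol=1e-8):
    """Iterate x <- x - eta * grad_F(x) until the gradient is (numerically) zero."""
    x = np.asarray(x0, dtype=float)
    for _ in range(steps):
        g = grad_F(x)
        if np.linalg.norm(g) < tol:   # stationary point reached
            break
        x = x - eta * g
    return x

# Toy example: F(x) = ||x - a||^2, whose unique minimum is x = a.
a = np.array([1.0, -2.0, 3.0])
grad_F = lambda x: 2 * (x - a)
print(gradient_descent(grad_F, np.zeros(3)))  # approximately [1., -2., 3.]
```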
Projected Descent
What we have discussed so far is unconstrained optimization. Now, let’s briefly discuss a simple generalization of gradient descent in constrained optimization. Suppose the problem we face is: \begin{equation} \boldsymbol{x}_* = \mathop{\text{argmin}}_{\boldsymbol{x}\in\mathbb{X}} F(\boldsymbol{x})\label{eq:c-loss} \end{equation} where \mathbb{X} is a subset of \mathbb{R}^n. For theoretical analysis, \mathbb{X} is usually required to be a "bounded convex set," but for a simple understanding, we can ignore these details for now.
If we still use gradient descent \eqref{eq:gd} at this point, the biggest problem is that we cannot guarantee \boldsymbol{x}_{t+\eta}\in\mathbb{X}. However, we can add a projection operation: \begin{equation} \Pi_{\mathbb{X}} (\boldsymbol{y}) = \mathop{\text{argmin}}_{\boldsymbol{x}\in\mathbb{X}}\Vert\boldsymbol{x}-\boldsymbol{y}\Vert\label{eq:project} \end{equation} thereby forming "Projected Gradient Descent": \begin{equation} \boldsymbol{x}_{t+\eta} = \Pi_{\mathbb{X}}(\boldsymbol{x}_t-\eta\nabla_{\boldsymbol{x}_t}F(\boldsymbol{x}_t))\label{eq:pgd} \end{equation} Simply put, projected gradient descent first performs gradient descent and then finds the point in \mathbb{X} closest to the result of the gradient descent as the output, ensuring the output is within \mathbb{X}. In "Making Alchemy More Scientific (I): Average Loss Convergence of SGD", we proved that under certain assumptions, projected gradient descent can find the optimal solution to the constrained optimization problem \eqref{eq:c-loss}.
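Here is a minimal sketch of \eqref{eq:pgd}; it assumes \mathbb{X} is the unit Euclidean ball purely because that projection has a simple closed form, so both the set and the toy objective are my own illustrative choices:

```python
import numpy as np

def project_unit_ball(y):
    """Pi_X for X = {x : ||x|| <= 1}: rescale y if it lies outside the ball."""
    norm = np.linalg.norm(y)
    return y if norm <= 1 else y / norm

def projected_gradient_descent(grad_F, project, x0, eta=0.1, steps=1000):
    x = project(np.asarray(x0, dtype=float))
    for _ in range(steps):
        x = project(x - eta * grad_F(x))  # gradient step, then projection back into X
    return x

# Toy example: minimize ||x - a||^2 over the unit ball with a outside the ball;
# the constrained minimizer is a / ||a||.
a = np.array([3.0, 4.0])
grad_F = lambda x: 2 * (x - a)
print(projected_gradient_descent(grad_F, project_unit_ball, np.zeros(2)))  # approximately [0.6, 0.8]
```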
In effect, projected gradient descent breaks the constrained optimization \eqref{eq:c-loss} into two steps: "gradient descent + projection." The projection \eqref{eq:project} is itself a constrained optimization problem; although its objective has a fixed form, it still has to be worked out case by case for each specific \mathbb{X}, so further analysis is needed there.
Discrete Distributions
This article focuses on optimization in probability space, meaning the search object must be a probability distribution. In this section, we first focus on discrete distributions. We denote the search space as \Delta^{n-1}, which is the set of all n-dimensional discrete probability distributions: \begin{equation} \Delta^{n-1} = \left\{\boldsymbol{p}=(p_1,p_2,\cdots,p_n)\left|\, p_1,p_2,\cdots,p_n\geq 0,\sum_{i=1}^n p_i = 1\right.\right\} \end{equation} Our optimization goal is: \begin{equation} \boldsymbol{p}_* = \mathop{\text{argmin}}_{\boldsymbol{p}\in\Delta^{n-1}} F(\boldsymbol{p})\label{eq:p-loss} \end{equation}
Lagrange Multipliers
For optimization problems under equality or inequality constraints, the standard method is usually the "Lagrange Multiplier Method", which transforms the constrained optimization problem \eqref{eq:p-loss} into a weakly constrained \text{min-max} problem: \begin{equation} \min_{\boldsymbol{p}\in\Delta^{n-1}} F(\boldsymbol{p}) = \min_{\boldsymbol{p}\in\mathbb{R}^n} \max_{\mu_i \geq 0,\lambda\in\mathbb{R}}F(\boldsymbol{p}) - \sum_{i=1}^n \mu_i p_i + \lambda\left(\sum_{i=1}^n p_i - 1\right)\label{eq:min-max} \end{equation} Note that in this \text{min-max} optimization, we have removed the constraint \boldsymbol{p}\in\Delta^{n-1}, leaving only a simple \mu_i \geq 0 constraint in the \max step. How do we prove the right side is equivalent to the left? It’s not difficult, and can be understood in three steps:
1. First, we must understand the meaning of \text{min-max} on the right: \min is on the left and \max is on the right, meaning we ultimately want to find a result as small as possible, but this objective function must first be maximized with respect to certain variables;
2. When p_i < 0, the \max step will inevitably have \mu_i\to\infty, making the objective function value \infty. If p_i \geq 0, then the \max step will necessarily have \mu_i p_i = 0, and the objective function value will be finite. Clearly, the latter is smaller. Thus, when the right side reaches its optimum, p_i\geq 0 must hold. Similarly, if \sum_{i=1}^n p_i \neq 1, the \max step can take \lambda\to\pm\infty to drive the objective to \infty, so \sum_{i=1}^n p_i = 1 must also hold at the optimum;
3. From the analysis in step 2, it is clear that when the right side reaches its optimum, it must satisfy \boldsymbol{p}\in\Delta^{n-1}, and the extra terms become zero, making it equivalent to the optimization problem on the left.
Next, we use the "Minimax Theorem":
If \mathbb{X}, \mathbb{Y} are two convex sets, \boldsymbol{x}\in\mathbb{X}, \boldsymbol{y}\in\mathbb{Y}, and f(\boldsymbol{x}, \boldsymbol{y}) is convex with respect to \boldsymbol{x} (for any fixed \boldsymbol{y}) and concave with respect to \boldsymbol{y} (for any fixed \boldsymbol{x}), then: \begin{equation} \min_{\boldsymbol{x}\in\mathbb{X}}\max_{\boldsymbol{y}\in\mathbb{Y}} f(\boldsymbol{x},\boldsymbol{y}) = \max_{\boldsymbol{y}\in\mathbb{Y}}\min_{\boldsymbol{x}\in\mathbb{X}} f(\boldsymbol{x},\boldsymbol{y}) \end{equation}
The Minimax Theorem provides a sufficient condition for swapping \min and \max. A new term "convex set" appears here, referring to a set where the weighted average of any two points in the set remains within the set: \begin{equation} (1-\lambda)\boldsymbol{x}_1 + \lambda \boldsymbol{x}_2\in \mathbb{X},\qquad\forall \boldsymbol{x}_1,\boldsymbol{x}_2\in \mathbb{X},\quad\forall \lambda\in [0, 1] \end{equation} It can be seen that the condition for a convex set is not too restrictive; \mathbb{R}^n, \Delta^{n-1}, and the set of all non-negative numbers are all convex sets.
For the objective function on the right side of Eq. \eqref{eq:min-max}, it is a linear function with respect to \mu_i, \lambda, thus satisfying the condition of being concave with respect to \mu_i, \lambda. Furthermore, the terms other than F(\boldsymbol{p}) are linear with respect to \boldsymbol{p}. Therefore, the convexity of the entire objective function with respect to \boldsymbol{p} is equivalent to the convexity of F(\boldsymbol{p}) with respect to \boldsymbol{p}. That is, if F(\boldsymbol{p}) is a convex function of \boldsymbol{p}, then the \min and \max in Eq. \eqref{eq:min-max} can be swapped: \begin{equation} \small\min_{\boldsymbol{p}\in\mathbb{R}^n} \max_{\mu_i \geq 0,\lambda\in\mathbb{R}}F(\boldsymbol{p}) - \sum_{i=1}^n \mu_i p_i + \lambda\left(\sum_{i=1}^n p_i - 1\right) = \max_{\mu_i \geq 0,\lambda\in\mathbb{R}} \min_{\boldsymbol{p}\in\mathbb{R}^n} F(\boldsymbol{p}) - \sum_{i=1}^n \mu_i p_i + \lambda\left(\sum_{i=1}^n p_i - 1\right) \end{equation} In this way, we can first minimize with respect to \boldsymbol{p}. This is an unconstrained minimization problem that can be completed by solving the system of equations where the gradient equals zero. The result will contain parameters \lambda and \mu_i, which are finally determined by p_i \geq 0, \mu_i p_i = 0, and \sum_{i=1}^n p_i = 1.
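As a quick, concrete check (the specific objective here is my own choice, not something used later in the article), take the convex function F(\boldsymbol{p}) = \sum_{i=1}^n c_i p_i + \sum_{i=1}^n p_i\log p_i for some fixed vector \boldsymbol{c}. Setting the gradient of the inner objective with respect to p_i to zero gives: \begin{equation} c_i + \log p_i + 1 - \mu_i + \lambda = 0\quad\Rightarrow\quad p_i = e^{\mu_i - \lambda - 1 - c_i} \end{equation} This p_i is automatically positive, so the condition \mu_i p_i = 0 forces \mu_i = 0, and the normalization \sum_{i=1}^n p_i = 1 then fixes \lambda, leaving p_i = e^{-c_i}\big/\sum_{j=1}^n e^{-c_j}, a softmax-type solution.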
Convex Set Search
However, although the Lagrange multiplier method is considered the standard method for solving constrained optimization problems, it is not very intuitive. Moreover, it can only obtain exact solutions by solving equations and cannot derive an iterative approximation algorithm similar to gradient descent. Therefore, we cannot be satisfied with just the Lagrange multiplier method.
From a search perspective, the key to solving optimization problems in probability space is ensuring that the trial points remain within the set \Delta^{n-1} during the search process. In other words, assuming the current probability distribution is \boldsymbol{p}_t\in \Delta^{n-1}, how do we construct the next trial point \boldsymbol{p}_{t+\eta}? It has two requirements: first, \boldsymbol{p}_{t+\eta}\in \Delta^{n-1}; second, its proximity to \boldsymbol{p}_t can be controlled by the size of \eta. This is where the "convex set" property of \Delta^{n-1} comes in handy. Using this property, we can define \boldsymbol{p}_{t+\eta} as: \begin{equation} \boldsymbol{p}_{t+\eta} = (1-\eta)\boldsymbol{p}_t + \eta \boldsymbol{q}_t,\quad \boldsymbol{q}_t\in \Delta^{n-1} \end{equation} Then we have: \begin{equation} F(\boldsymbol{p}_{t+\eta}) = F((1-\eta)\boldsymbol{p}_t + \eta \boldsymbol{q}_t) \approx F(\boldsymbol{p}_t) + \eta(\boldsymbol{q}_t - \boldsymbol{p}_t)\cdot\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t) \end{equation} Assuming the first-order approximation is accurate enough, obtaining the direction of steepest descent is equivalent to solving: \begin{equation} \mathop{\text{argmin}}_{\boldsymbol{q}_t\in\Delta^{n-1}}\,\boldsymbol{q}_t\cdot\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t) \end{equation} This objective function is quite simple, and the answer is: \begin{equation} \boldsymbol{q}_t = \text{onehot}(\text{argmin}(\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t))) \end{equation} Here \nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t) is a vector, and performing \text{argmin} on a vector means finding the position of the smallest component. Thus, the above equation says that \boldsymbol{q}_t is a one-hot distribution where the 1 is located at the position of the smallest component of the gradient \nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t).
From this, we see that the form of gradient descent in probability space is: \begin{equation} \boldsymbol{p}_{t+\eta} = (1 - \eta)\boldsymbol{p}_t + \eta\, \text{onehot}(\text{argmin}(\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t))) \end{equation} And the condition for \boldsymbol{p}_t to be a local minimum of F(\boldsymbol{p}_t) is: \begin{equation} \boldsymbol{p}_t\cdot\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t) = (\nabla_{\boldsymbol{p}_t} F(\boldsymbol{p}_t))_{\min}\label{eq:p-min} \end{equation} Here, the \min of a vector refers to returning the smallest component.
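This update rule is essentially the probability-simplex instance of what is usually called the Frank–Wolfe (conditional gradient) method. Below is a minimal sketch of the iteration; the toy quadratic objective, step size, and iteration count are my own illustrative choices:

```python
import numpy as np

def simplex_descent(grad_F, p0, eta=0.01, steps=5000):
    """Iterate p <- (1 - eta) * p + eta * onehot(argmin grad_F(p)); p stays in the simplex."""
    p = np.asarray(p0, dtype=float)
    for _ in range(steps):
        g = grad_F(p)
        q = np.zeros_like(p)
        q[np.argmin(g)] = 1.0          # one-hot at the smallest gradient component
        p = (1 - eta) * p + eta * q    # convex combination, so p remains a valid distribution
    return p

# Toy example: F(p) = ||p - x||^2 with x already in the simplex, so the minimum is p = x.
x = np.array([0.2, 0.3, 0.5])
grad_F = lambda p: 2 * (p - x)
print(simplex_descent(grad_F, np.ones(3) / 3))  # close to [0.2, 0.3, 0.5], up to O(eta) error
```

With a fixed \eta the iterate only settles into an O(\eta) neighborhood of the minimizer; letting \eta decay over iterations (e.g. \eta_k = 2/(k+2), the classical conditional-gradient schedule) would give convergence to the minimizer itself.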
An Example
Take Sparsemax, introduced in "The Road to Probability Distributions: A Review of Softmax and its Alternatives", as an example. Its original definition is: \begin{equation} \text{Sparsemax}(\boldsymbol{x}) = \mathop{\text{argmin}}\limits_{\boldsymbol{p}\in\Delta^{n-1}}\Vert \boldsymbol{p} - \boldsymbol{x}\Vert^2 \end{equation} where \boldsymbol{x}\in\mathbb{R}^n. It is not hard to see that from the perspective of projected gradient descent discussed earlier, Sparsemax is exactly the "projection" operation from \mathbb{R}^n to \Delta^{n-1}.
We denote F(\boldsymbol{p})=\Vert \boldsymbol{p} - \boldsymbol{x}\Vert^2. Its gradient with respect to \boldsymbol{p} is 2(\boldsymbol{p} - \boldsymbol{x}). According to Eq. \eqref{eq:p-min}, the equation satisfied by the minimum point is: \begin{equation} \boldsymbol{p}\cdot(\boldsymbol{p}-\boldsymbol{x}) = (\boldsymbol{p}-\boldsymbol{x})_{\min} \end{equation} We assume x_i = x_j \Leftrightarrow p_i = p_j. Here, a non-bold subscript like p_i represents the i-th component of vector \boldsymbol{p} (a scalar), while the bold subscript in the previous section like \boldsymbol{p}_t represents the t-th iteration result of \boldsymbol{p} (still a vector). Please distinguish them carefully.
Under this convention, from the above equation, we can obtain: \begin{equation} p_i > 0 \quad \Leftrightarrow \quad p_i-x_i = (\boldsymbol{p}-\boldsymbol{x})_{\min} \end{equation} Since \boldsymbol{p} is determined by \boldsymbol{x}, (\boldsymbol{p}-\boldsymbol{x})_{\min} is a function of \boldsymbol{x}, which we denote as -\lambda(\boldsymbol{x}). Then p_i = x_i - \lambda(\boldsymbol{x}), but this only holds for p_i > 0. For p_i=0, we have p_i-x_i > (\boldsymbol{p}-\boldsymbol{x})_{\min}, i.e., x_i - \lambda(\boldsymbol{x}) < 0. Based on these two points, we can uniformly write: \begin{equation} p_i = \text{relu}(x_i - \lambda(\boldsymbol{x})) \end{equation} where \lambda(\boldsymbol{x}) is determined by the condition that the sum of the components of \boldsymbol{p} is 1. For other details, please refer to "The Road to Probability Distributions: A Review of Softmax and its Alternatives".
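For reference, \lambda(\boldsymbol{x}) can be computed in closed form by sorting, since only the largest components of \boldsymbol{x} end up with p_i > 0 and the normalization \sum_i \text{relu}(x_i - \lambda) = 1 fixes the threshold over that support. The following NumPy sketch implements this standard thresholding procedure (variable names are mine):

```python
import numpy as np

def sparsemax(x):
    """p_i = relu(x_i - lam), with lam chosen so that the p_i sum to 1."""
    z = np.sort(x)[::-1]                      # components in decreasing order
    cumsum = np.cumsum(z)
    k = np.arange(1, len(x) + 1)
    support = k * z > cumsum - 1              # prefix sizes whose components stay above the threshold
    k_max = k[support][-1]                    # size of the support
    lam = (cumsum[k_max - 1] - 1) / k_max     # threshold lambda(x)
    return np.maximum(x - lam, 0.0)

x = np.array([2.0, 1.0, 0.1])
p = sparsemax(x)
print(p, p.sum())  # a sparse probability vector summing to 1
```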
Continuous Distributions
Having discussed discrete distributions, we now turn to continuous distributions. At first glance, continuous distributions seem to be just the limit version of discrete ones, and the results shouldn’t differ much. However, in reality, their characteristics are fundamentally different, to the extent that we need to construct a completely new methodology for continuous distributions.
Functional Objectives
First, let’s talk about the objective function. We know that continuous distributions are described by probability density functions. Thus, the input to the objective function is a probability density function. At this point, the objective function is no longer an ordinary function; we usually call it a "functional"—a mapping from an entire function to a scalar. In other words, we need to find a probability density function that minimizes a certain objective functional.
Although many people find "functional analysis makes one’s heart turn cold," most of us have actually encountered functionals because mappings that satisfy "input a function, output a scalar" are very common. For example, the definite integral: \begin{equation} \mathcal{I}[f]\triangleq \int_a^b f(x) dx \end{equation} is a mapping from a function to a scalar, so it is a functional. In fact, the functionals we encounter in practical applications are basically constructed from definite integrals, such as the KL divergence of probability distributions: \begin{equation} \mathcal{KL}[p\Vert q] = \int p(\boldsymbol{x})\log \frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}d\boldsymbol{x} \end{equation} where the integral is by default over the entire space (the whole \mathbb{R}^n). More general functionals might include derivative terms in the integrand, such as the principle of least action in theoretical physics: \begin{equation} \mathcal{A}[x] = \int_{t_a}^{t_b} L(x(t),x'(t),t)dt \end{equation}
The objective functional we want to minimize can generally be written as: \begin{equation} \mathcal{F}[p] = \int F(p(\boldsymbol{x}))d\boldsymbol{x} \end{equation} For convenience, we can also define the functional derivative: \begin{equation} \frac{\delta\mathcal{F}[p]}{\delta p}(\boldsymbol{x}) = \frac{\partial F(p(\boldsymbol{x}))}{\partial p(\boldsymbol{x})} \end{equation}
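As a concrete instance (my example; here q is held fixed, so the integrand depends on \boldsymbol{x} both through p(\boldsymbol{x}) and through the fixed q(\boldsymbol{x}), which does not change the pointwise differentiation): for the KL divergence above, the integrand is F(p) = p\log(p/q), so \begin{equation} \frac{\delta\,\mathcal{KL}[p\Vert q]}{\delta p}(\boldsymbol{x}) = \log\frac{p(\boldsymbol{x})}{q(\boldsymbol{x})} + 1 \end{equation}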
Compact Support
Furthermore, we need a notation for the continuous probability space. Its basic definition is: \begin{equation} \mathbb{P} = \left\{p(\boldsymbol{x}) \,\Bigg|\, p(\boldsymbol{x})\geq 0(\forall\boldsymbol{x}\in\mathbb{R}^n),\int p(\boldsymbol{x})d\boldsymbol{x} = 1\right\} \end{equation} It is not difficult to prove that if the limit of the probability density function p(\boldsymbol{x}) exists as \Vert\boldsymbol{x}\Vert\to\infty, then it must be \lim_{\Vert\boldsymbol{x}\Vert\to\infty} p(\boldsymbol{x}) = 0. This is a property we will use in later proofs.
However, it can be shown by example that not all probability density functions have a limit at infinity. To avoid theoretical difficulties, we usually assume during theoretical proofs that the support of p(\boldsymbol{x}) is a compact set. There are two concepts here: Support and Compact Set. The support refers to the set of all \boldsymbol{x} such that p(\boldsymbol{x}) > 0: \begin{equation} \text{supp}(p) = \{\boldsymbol{x} | p(\boldsymbol{x}) > 0\} \end{equation} The general definition of a compact set is complex, but in \mathbb{R}^n, a compact set is equivalent to a bounded closed set. So, simply put, the assumption that the support of p(\boldsymbol{x}) is compact directly grants p(\boldsymbol{x}) the property that "there exists a constant C such that p(\boldsymbol{x}) = 0 for all \Vert\boldsymbol{x}\Vert > C." This simplifies the behavior of p(\boldsymbol{x}) at infinity and fundamentally avoids the discussion of \lim_{\Vert\boldsymbol{x}\Vert\to\infty} p(\boldsymbol{x}) = 0.
Theoretically, this is a very strong assumption; it even excludes simple distributions like the normal distribution (the support of a normal distribution is \mathbb{R}^n). However, practically speaking, this assumption is not unreasonable. As we said, if the limit \lim_{\Vert\boldsymbol{x}\Vert\to\infty} p(\boldsymbol{x}) exists, it must be zero, so beyond a certain range, it is not much different from being zero. Examples where the limit does not exist do exist, but they usually require deliberate construction. For the data we encounter in practice, the condition that the limit exists is basically satisfied.
The Old Path Fails
Intuitively, the optimization of continuous distributions should follow the logic of discrete distributions, i.e., let p_{t+\eta}(\boldsymbol{x}) = (1 - \eta)p_t(\boldsymbol{x}) + \eta q_t(\boldsymbol{x}), because like discrete distributions, the set of probability density functions \mathbb{P} is also a convex set. Now we substitute this into the objective functional: \begin{equation} \begin{aligned} \mathcal{F}[p_{t+\eta}] =&\, \int F(p_{t+\eta}(\boldsymbol{x}))d\boldsymbol{x} \\ =&\, \int F((1 - \eta)p_t(\boldsymbol{x}) + \eta q_t(\boldsymbol{x}))d\boldsymbol{x} \\ \approx&\,\int \left[F(p_t(\boldsymbol{x})) + \eta\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\Big(q_t(\boldsymbol{x}) - p_t(\boldsymbol{x})\Big)\right]d\boldsymbol{x} \end{aligned} \end{equation} Assuming the first-order approximation is sufficient, the problem transforms into: \begin{equation} \mathop{\text{argmin}}_{q_t\in \mathbb{P}}\int\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}q_t(\boldsymbol{x})d\boldsymbol{x} \end{equation} This problem is also not hard to solve. The answer is similar to the one-hot distribution in the discrete case: \begin{equation} q_t(\boldsymbol{x}) = \delta\left(\boldsymbol{x} - \mathop{\text{argmin}}_{\boldsymbol{x}'} \frac{\partial F(p_t(\boldsymbol{x}'))}{\partial p_t(\boldsymbol{x}')}\right) \end{equation} where \delta(\cdot) is the Dirac delta function, representing the probability density of a single-point distribution.
It looks smooth, but in reality, this path is blocked. First, the Dirac delta function is not a function in the conventional sense; it is a generalized function (also a type of functional). Second, if we look at it from the perspective of ordinary functions, the Dirac delta function has an infinite value at a certain point. Since it is infinite, the assumption that "the first-order approximation is sufficient" in the derivation cannot possibly hold.
Variable Substitution
We could consider patching the derivation from the previous section, for example, by adding a constraint q_t(\boldsymbol{x}) \leq C to obtain meaningful results. However, such patching ultimately feels inelegant. But if we don’t use the properties of convex sets, how else can we construct the next trial distribution p_{t+\eta}(\boldsymbol{x})?
This is where we must fully utilize the characteristics of probability density functions—we can transform one probability density function into another through variable substitution. This is a unique property of continuous distributions. Specifically, if p(\boldsymbol{x}) is a probability density function and \boldsymbol{y}=\boldsymbol{T}(\boldsymbol{x}) is an invertible transformation, then p(\boldsymbol{T}(\boldsymbol{x}))\left|\frac{\partial \boldsymbol{T}(\boldsymbol{x})}{\partial\boldsymbol{x}}\right| is also a probability density function, where |\cdot| denotes the absolute value of the determinant of the matrix.
Based on this property, we define the next trial probability distribution as: \begin{equation} \begin{aligned} p_{t+\eta}(\boldsymbol{x}) =&\, p_t(\boldsymbol{x} + \eta\boldsymbol{\mu}_t(\boldsymbol{x}))\left|\boldsymbol{I} + \eta\frac{\partial \boldsymbol{\mu}_t(\boldsymbol{x})}{\partial\boldsymbol{x}}\right| \\ \approx &\, \Big[p_t(\boldsymbol{x}) + \eta\boldsymbol{\mu}_t(\boldsymbol{x})\cdot\nabla_{\boldsymbol{x}} p_t(\boldsymbol{x})\Big]\left[1 + \eta\,\text{Tr}\frac{\partial \boldsymbol{\mu}_t(\boldsymbol{x})}{\partial\boldsymbol{x}}\right] \\[3pt] \approx &\, p_t(\boldsymbol{x}) + \eta\boldsymbol{\mu}_t(\boldsymbol{x})\cdot\nabla_{\boldsymbol{x}} p_t(\boldsymbol{x}) + \eta\, p_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\cdot\boldsymbol{\mu}_t(\boldsymbol{x}) \\[5pt] = &\, p_t(\boldsymbol{x}) + \eta\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] \\ \end{aligned} \end{equation} We derived the same result in "Talk on Generative Diffusion Models (XII): ’Hard-Core’ Diffusion ODE". For the approximate expansion of the determinant, one can refer to the article "Derivatives of Determinants".
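The expansion can be checked numerically; in the one-dimensional sketch below, the test density, the velocity field \mu_t, and the grid are all illustrative choices of mine:

```python
import numpy as np

# 1D check of p_{t+eta}(x) = p_t(x + eta*mu(x)) * |1 + eta*mu'(x)|
#                          ~ p_t(x) + eta * d/dx[p_t(x) * mu(x)]
p  = lambda x: np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)  # test density p_t
dp = lambda x: -x * p(x)                                # its derivative
mu, dmu = np.sin, np.cos                                # test velocity field mu_t and its derivative

x = np.linspace(-6, 6, 2001)
for eta in [1e-1, 1e-2, 1e-3]:
    exact = p(x + eta * mu(x)) * np.abs(1 + eta * dmu(x))
    first_order = p(x) + eta * (dp(x) * mu(x) + p(x) * dmu(x))  # d/dx[p*mu] = p'*mu + p*mu'
    print(f"eta={eta:g}, max error={np.max(np.abs(exact - first_order)):.2e}")
# the error shrinks roughly like eta^2, consistent with a valid first-order expansion
```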
Integral Transformation
Using this new p_{t+\eta}(\boldsymbol{x}), we can obtain: \begin{equation} \begin{aligned} \mathcal{F}[p_{t+\eta}] =&\, \int F(p_{t+\eta}(\boldsymbol{x}))d\boldsymbol{x} \\ \approx&\, \int F\Big(p_t(\boldsymbol{x}) + \eta\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big]\Big)d\boldsymbol{x} \\ \approx&\, \int \left[F(p_t(\boldsymbol{x})) + \eta\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big]\right]d\boldsymbol{x} \\ =&\, \mathcal{F}[p_t] + \eta\int \frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] d\boldsymbol{x} \\ \end{aligned}\label{eq:px-approx} \end{equation} Next, as in "Deriving the Continuity Equation and Fokker-Planck Equation via the Test Function Method", we need to derive an integral identity related to probability density. First, we have:
\begin{equation} \begin{aligned} &\,\int \nabla_{\boldsymbol{x}}\cdot\left[\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})} p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\right] d\boldsymbol{x} \\[5pt] =&\, \int \frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] d\boldsymbol{x} + \int \left(\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\right)\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] d\boldsymbol{x} \end{aligned} \end{equation} According to the divergence theorem, we have \begin{equation} \int_{\Omega} \nabla_{\boldsymbol{x}}\cdot\left[\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})} p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\right] d\boldsymbol{x} = \int_{\partial\Omega} \left[\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})} p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\right]\cdot \hat{\boldsymbol{n}} dS \end{equation} where \Omega is the integration region, which in this case is the entire \mathbb{R}^n; \partial\Omega is the boundary of the region, and the boundary of \mathbb{R}^n is naturally at infinity; \hat{\boldsymbol{n}} is the outward unit normal vector of the boundary; and dS is the area element. Under the assumption of compact support, p_t(\boldsymbol{x})=0 at infinity, so the right-hand side of the above equation is effectively the integral of zero, and the result is zero. Therefore, we have \begin{equation} \int \frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\nabla_{\boldsymbol{x}}\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] d\boldsymbol{x} = - \int \left(\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\right)\cdot\big[p_t(\boldsymbol{x})\boldsymbol{\mu}_t(\boldsymbol{x})\big] d\boldsymbol{x} \end{equation} Substituting this into Eq. \eqref{eq:px-approx} yields \begin{equation} \mathcal{F}[p_{t+\eta}] \approx \mathcal{F}[p_t] - \eta\int \left(p_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\right)\cdot \boldsymbol{\mu}_t(\boldsymbol{x}) d\boldsymbol{x} \label{eq:px-approx-2} \end{equation}
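This integration-by-parts identity is also easy to verify numerically in one dimension; in the sketch below, the density, the velocity field, and the smooth function standing in for \partial F(p_t(\boldsymbol{x}))/\partial p_t(\boldsymbol{x}) are arbitrary choices of mine:

```python
import numpy as np

# 1D check of: int phi(x) d/dx[p(x) mu(x)] dx = -int phi'(x) p(x) mu(x) dx,
# valid because p (and hence p*mu) vanishes at the ends of the interval.
x = np.linspace(-8, 8, 100001)
dx = x[1] - x[0]

p   = np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)   # test density, ~0 at the boundary
phi = np.tanh(x) + 0.3 * x**2                  # stand-in for dF/dp: any smooth function
mu  = np.cos(x)                                # test velocity field

lhs = np.sum(phi * np.gradient(p * mu, x)) * dx
rhs = -np.sum(np.gradient(phi, x) * p * mu) * dx
print(lhs, rhs)  # the two values agree up to discretization error
```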
Gradient Flow
According to Eq. \eqref{eq:px-approx-2}, a simple choice to ensure \mathcal{F}[p_{t+\eta}] \leq \mathcal{F}[p_t] is \begin{equation} \boldsymbol{\mu}_t(\boldsymbol{x}) = \nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})} \end{equation} The corresponding iterative scheme is \begin{equation} p_{t+\eta}(\boldsymbol{x}) \approx p_t(\boldsymbol{x}) + \eta\nabla_{\boldsymbol{x}}\cdot\left[p_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\right] \end{equation} Taking the limit \eta\to 0, we obtain \begin{equation} \frac{\partial}{\partial t}p_t(\boldsymbol{x}) = \nabla_{\boldsymbol{x}}\cdot\left[p_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})}\right] \end{equation} Or written more concisely as \begin{equation} \frac{\partial p_t}{\partial t} = \nabla\cdot\left[p_t\nabla\frac{\delta \mathcal{F}[p_t]}{\delta p_t}\right] \end{equation} This is the Wasserstein gradient flow introduced in "Gradient Flow: Exploring the Path to the Minimum", but here we have obtained the same result without explicitly introducing the concept of the Wasserstein distance.
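A classical special case (my own illustration, consistent with the formula above): take \mathcal{F}[p]=\int p(\boldsymbol{x})\log p(\boldsymbol{x})\,d\boldsymbol{x}, the negative entropy. Then \begin{equation} \frac{\delta \mathcal{F}[p_t]}{\delta p_t} = \log p_t + 1,\qquad p_t\nabla\frac{\delta \mathcal{F}[p_t]}{\delta p_t} = p_t\nabla\log p_t = \nabla p_t \end{equation} so the gradient flow reduces to the heat equation \frac{\partial p_t}{\partial t} = \nabla\cdot\nabla p_t = \Delta p_t: following the steepest descent of negative entropy means the distribution simply diffuses.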
Since p_{t+\eta}(\boldsymbol{x}) is obtained from p_t(\boldsymbol{x}) through the transformation \boldsymbol{x}\to \boldsymbol{x} + \eta \boldsymbol{\mu}_t(\boldsymbol{x}), we can also write the Ordinary Differential Equation (ODE) for the trajectory of \boldsymbol{x}: \begin{equation} \boldsymbol{x}_t = \boldsymbol{x}_{t+\eta} + \eta \boldsymbol{\mu}_t(\boldsymbol{x}_{t+\eta})\quad\Rightarrow\quad \frac{d\boldsymbol{x}_t}{dt} = -\boldsymbol{\mu}_t(\boldsymbol{x}_t) = -\nabla_{\boldsymbol{x}}\frac{\partial F(p_t(\boldsymbol{x}))}{\partial p_t(\boldsymbol{x})} \end{equation} The significance of this ODE is that, starting from a sample \boldsymbol{x}_0 of the distribution p_0(\boldsymbol{x}), when it moves to \boldsymbol{x}_t according to this ODE, the distribution followed by \boldsymbol{x}_t is exactly p_t(\boldsymbol{x}).
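Continuing the negative-entropy example above: for p_0=\mathcal{N}(0,\sigma_0^2), the heat equation gives p_t=\mathcal{N}(0,\sigma_0^2+2t), and the ODE becomes \frac{dx_t}{dt} = -\nabla_x\log p_t(x_t) = x_t/(\sigma_0^2+2t). The sketch below (all parameters are arbitrary choices of mine) integrates this ODE with Euler steps and checks that the particle variance indeed tracks \sigma_0^2 + 2t:

```python
import numpy as np

np.random.seed(0)
sigma0_sq, T, n_steps, n_particles = 1.0, 2.0, 2000, 100_000
dt = T / n_steps

x = np.sqrt(sigma0_sq) * np.random.randn(n_particles)  # samples from p_0 = N(0, sigma0^2)
t = 0.0
for _ in range(n_steps):
    # dx/dt = -d/dx log p_t(x) = x / (sigma0^2 + 2t), with p_t the heat-equation solution
    x += dt * x / (sigma0_sq + 2 * t)
    t += dt

print("empirical variance:", x.var())            # ~ sigma0^2 + 2T = 5.0 (up to sampling noise)
print("predicted variance:", sigma0_sq + 2 * T)
```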
Summary
This article systematically organizes the minimization methods for objective functions in probability space, including the necessary conditions for reaching a minimum and iterative methods similar to gradient descent. These results are frequently used in scenarios such as optimization and generative models (especially diffusion models).
Original link: https://kexue.fm/archives/10330