English (unofficial) translations of posts at kexue.fm

Steepest Descent on Manifolds: 5. Dual Gradient Descent

Translated by Gemini Flash 3.0 Preview. Translations may be inaccurate; please refer to the original post for anything important.

In the previous four articles, we solved several specific steepest descent problems with equality constraints on parameters. Among them, the problems in the third and fourth articles could not be solved analytically, so I proposed corresponding fixed-point iteration methods. Specifically, the "Muon + Stiefel" problem studied in the third article, "Steepest Descent on Manifolds: 3. Muon + Stiefel", originated from Jeremy Bernstein’s article "Orthogonal manifold".

For this problem, Jeremy Bernstein eventually provided his own solution, which I call "Dual Gradient Descent," and it is well worth studying.

Basic Concepts

Jeremy Bernstein’s solution was finally published in the Thinking Machines Lab blog post "Modular Manifolds". It is the second blog post from that lab, where they refer to it as "Dual Ascent," but here I will continue to call it "Dual Gradient Descent" to maintain consistency with the previous four articles.

In fact, dual gradient descent can be seen as a natural consequence of the method of Lagrange multipliers. However, a rigorous discussion of Lagrange multipliers is quite cumbersome, requiring the introduction of the Minimax theorem, for instance. To avoid these complications in this series, we adopted a derivation method using "undetermined coefficients," which makes dual gradient descent seem less natural. Nevertheless, we can still derive it following our line of reasoning, though it may take a bit more space.

First, let’s review the notation. \boldsymbol{W}\in\mathbb{R}^{n\times m} is a matrix parameter; without loss of generality, assume n\geq m. \boldsymbol{G}\in\mathbb{R}^{n\times m} is its gradient. \Vert\boldsymbol{G}\Vert_2 is the spectral norm of matrix \boldsymbol{G}, equal to the largest singular value; \Vert\boldsymbol{G}\Vert_* is the nuclear norm of matrix \boldsymbol{G}, equal to the sum of all singular values. Specifically, according to the conclusions in the article "Derivatives of SVD", we have: \begin{equation} \nabla_{\boldsymbol{G}}\Vert\boldsymbol{G}\Vert_* = \sum_i \nabla_{\boldsymbol{G}} \sigma_i = \sum_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top} = \boldsymbol{U}\boldsymbol{V}^{\top} = \mathop{\mathrm{msign}}(\boldsymbol{G}) \label{eq:nuclear-grad} \end{equation} where \boldsymbol{G}=\sum_i \sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} is the SVD of \boldsymbol{G}. In other words, the gradient of the nuclear norm is exactly the \mathop{\mathrm{msign}} operator, which is an important foundation for the following derivation.
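As a concrete reference, here is a minimal NumPy sketch of the \mathop{\mathrm{msign}} operator (my illustration, not code from the original post; it assumes \boldsymbol{G} has full column rank):

```python
import numpy as np

def msign(G: np.ndarray) -> np.ndarray:
    """msign(G) = U @ V^T from the reduced SVD, i.e. G with all of its
    singular values replaced by 1 (assumes full column rank)."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt
```

By the identity above, this matrix is exactly the gradient of the nuclear norm \Vert\boldsymbol{G}\Vert_*.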

Problem Description

We will continue to introduce dual gradient descent along our previous derivation path, so this section restates the problems and existing results.

In "Steepest Descent on Manifolds: 3. Muon + Stiefel", the problem we wanted to solve was: \begin{equation} \max_{\boldsymbol{\Phi}} \mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0} \label{eq:muon-stiefel} \end{equation} The solution is \boldsymbol{\Phi} = \mathop{\mathrm{msign}}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}), where \boldsymbol{X}\in\mathbb{R}^{m\times m} is an undetermined symmetric matrix such that \boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0}.

In "Steepest Descent on Manifolds: 4. Muon + Spectral Sphere", the problem we wanted to solve was: \begin{equation} \max_{\boldsymbol{\Phi}} \mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\,\, \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0 \label{eq:muon-spectral} \end{equation} The answer is \boldsymbol{\Phi} = \mathop{\mathrm{msign}}(\boldsymbol{G} + \lambda\boldsymbol{\Theta}), where \lambda is an undetermined coefficient such that \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0.

As we can see, our ultimate task has become finding the undetermined coefficients that satisfy the additional equality constraints. This is essentially solving a system of nonlinear equations. Dual gradient descent transforms the task of solving equations into the minimization of a certain objective function, which is then solved using gradient descent.

Dual Objective

The key to this transformation is the nuclear norm gradient equality [eq:nuclear-grad]. For simplicity, let’s first look at the "Muon + Spectral Sphere" problem [eq:muon-spectral], where the undetermined coefficient is just a scalar, making it easier to observe. It is not difficult to verify that: \begin{equation} \nabla_{\lambda} \Vert\boldsymbol{G} + \lambda\boldsymbol{\Theta}\Vert_* = \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top}\mathop{\mathrm{msign}}(\boldsymbol{G} + \lambda\boldsymbol{\Theta})) = \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}) \end{equation} This means that solving the equation \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0 is equivalent to finding a point where the gradient of \Vert\boldsymbol{G} + \lambda\boldsymbol{\Theta}\Vert_* is zero. This could be its (local) minimum or maximum point. Since \Vert\boldsymbol{G} + \lambda\boldsymbol{\Theta}\Vert_* clearly has no maximum value, we transform it into finding its minimum point: \begin{equation} \lambda^* = \mathop{\mathrm{argmin}}_{\lambda} \Vert\boldsymbol{G} + \lambda\boldsymbol{\Theta}\Vert_* \label{eq:muon-spectral-obj} \end{equation}
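As a quick sanity check of this identity, the following sketch (my illustration, reusing the msign helper above) compares a central finite difference of the nuclear norm with \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}):

```python
rng = np.random.default_rng(0)
G, Theta = rng.standard_normal((5, 3)), rng.standard_normal((5, 3))
lam, eps = 0.7, 1e-6

nuc = lambda A: np.linalg.svd(A, compute_uv=False).sum()  # nuclear norm
fd = (nuc(G + (lam + eps) * Theta) - nuc(G + (lam - eps) * Theta)) / (2 * eps)
analytic = np.trace(Theta.T @ msign(G + lam * Theta))
print(fd, analytic)  # the two values should agree closely
```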

Let’s summarize the steps here:

1. Our goal is to solve the equation \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0, finding any solution will do;

2. \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}) happens to be the gradient of \Vert\boldsymbol{G} + \lambda\boldsymbol{\Theta}\Vert_* with respect to \lambda;

3. This transforms into the problem of finding a (local) minimum or maximum point, because the gradient at such points is usually zero;

4. We can easily determine there is no maximum, so we can only look for the minimum.

Gradient Descent

After determining the objective [eq:muon-spectral-obj], we can solve it using gradient descent. Since the gradient is already available as \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}), the gradient descent update is: \begin{equation} \lambda \quad \leftarrow\quad \lambda - \eta \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}) \end{equation} Of course, we could also apply \mathop{\mathrm{sign}} to \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}), i.e., use SignSGD; there is room for flexibility here. As an update rule, dual gradient descent is much simpler than the fixed-point iteration we proposed earlier. However, in many cases it requires significantly more iteration steps and may need careful tuning of the learning rate, or the introduction of momentum, to converge.
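As a sketch, solving the dual objective to convergence with plain gradient descent might look like this (step size, initialization, and stopping rule are my own illustrative choices, not the author's):

```python
def solve_lambda(G, Theta, eta=0.1, steps=500, tol=1e-8):
    """Minimize ||G + lam*Theta||_* over the scalar lam by gradient descent."""
    lam = 0.0
    for _ in range(steps):
        Phi = msign(G + lam * Theta)
        g = np.trace(Theta.T @ Phi)   # gradient of the dual objective
        if abs(g) < tol:              # tr(Theta^T Phi) ~ 0: equation solved
            break
        lam -= eta * g
    return lam
```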

Therefore, as far as solving the equation \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0 is concerned, dual gradient descent is not a particularly ideal scheme. However, our ultimate goal is not just to solve the equation \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0, but to calculate \boldsymbol{\Phi} as the optimization direction for the model. Model optimization is itself an iterative process. We can cache the historical \lambda and adopt an approximation strategy where \lambda is updated synchronously with the model parameters: \begin{equation} \boldsymbol{\Phi} = \mathop{\mathrm{msign}}(\boldsymbol{G} + \lambda\boldsymbol{\Theta}), \quad \boldsymbol{W}\leftarrow\boldsymbol{W}- \eta_1 \boldsymbol{\Phi},\quad \lambda \leftarrow\lambda - \eta_2 \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}) \end{equation} In this way, each training step only requires calculating an almost free step of \lambda - \eta_2 \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}) to obtain an approximate implementation of the original objective [eq:muon-spectral]. Formally, it acts as a kind of adaptive Weight Decay for Muon.
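In code, one training step of this synchronized scheme could look as follows (a sketch under my own naming; eta1 and eta2 correspond to \eta_1 and \eta_2 above):

```python
def muon_spectral_step(W, G, Theta, lam, eta1=0.01, eta2=0.1):
    """One step: lam is cached across training steps instead of
    being solved to convergence at every step."""
    Phi = msign(G + lam * Theta)                     # approximate steepest direction
    W_new = W - eta1 * Phi                           # parameter update
    lam_new = lam - eta2 * np.trace(Theta.T @ Phi)   # almost-free dual update
    return W_new, lam_new
```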

On Stiefel

Having discussed the relatively simple "Muon + Spectral Sphere," let’s look at "Muon + Stiefel," i.e., objective [eq:muon-stiefel]. Here, the undetermined matrix \boldsymbol{X} carries the constraint \boldsymbol{X}=\boldsymbol{X}^{\top}. We remove this constraint by setting \boldsymbol{X}=\boldsymbol{\Lambda}+\boldsymbol{\Lambda}^{\top}, where \boldsymbol{\Lambda}\in\mathbb{R}^{m\times m} is an arbitrary matrix. We then find: \begin{equation} \nabla_{\boldsymbol{\Lambda}}\Vert\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}\Vert_* = \boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} \end{equation} where \boldsymbol{\Phi} = \mathop{\mathrm{msign}}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}). Thus, solving the system of equations \boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}=\boldsymbol{0} can likewise be transformed into finding a minimum point of \Vert\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}\Vert_* over \boldsymbol{\Lambda}, and then solved using gradient descent: \begin{equation} \boldsymbol{\Lambda} \quad\leftarrow\quad \boldsymbol{\Lambda} - \eta(\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}) \end{equation} Since \boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} is necessarily symmetric, the update preserves the symmetry of \boldsymbol{X}, so directly updating \boldsymbol{X} \leftarrow\boldsymbol{X} - \eta(\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}) is also feasible. Iterating this synchronously with \boldsymbol{W}, we get: \begin{equation} \boldsymbol{\Phi} = \mathop{\mathrm{msign}}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}), \quad \boldsymbol{W}\leftarrow\boldsymbol{W}- \eta_1 \boldsymbol{\Phi},\quad \boldsymbol{X} \leftarrow\boldsymbol{X} - \eta_2(\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}) \end{equation} This achieves an approximation of the objective [eq:muon-stiefel], and the extra step \boldsymbol{X} - \eta_2(\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W}) in each iteration is also almost free.
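A sketch of one such synchronized step (again with my own illustrative names; I compute the constraint residual with the current \boldsymbol{W} before it is updated):

```python
def muon_stiefel_step(W, G, X, eta1=0.01, eta2=0.1):
    """One step: the symmetric matrix X is cached across training steps."""
    Phi = msign(G + W @ X)            # approximate steepest direction
    R = W.T @ Phi + Phi.T @ W         # residual of the tangency constraint
    W_new = W - eta1 * Phi
    X_new = X - eta2 * R              # R is symmetric, so X stays symmetric
    return W_new, X_new
```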

Lagrange Multipliers

In these two examples, the equations to be solved happen to be equal to the gradient of a certain nuclear norm objective. Is this a mere coincidence? Of course not. As mentioned in the "Basic Concepts" section, this is a natural result of the method of Lagrange multipliers. In this section, we will expand on this discussion.

To facilitate understanding, let’s take the relatively simple objective [eq:muon-spectral] as an example. It can be equivalently written as: \begin{equation} \max_{\Vert\boldsymbol{\Phi}\Vert_2\leq 1} \min_{\lambda\in\mathbb{R}}\mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) + \lambda\mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}) \end{equation} To understand this transformation, one only needs to realize that the maximizing \boldsymbol{\Phi} must satisfy \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0; otherwise, the \min step could drive the value to negative infinity, making the final \max result negative infinity as well. As for relaxing \Vert\boldsymbol{\Phi}\Vert_2 = 1 to \Vert\boldsymbol{\Phi}\Vert_2\leq 1, it does not change the maximum (which is always attained on the boundary), but it makes the feasible region of \boldsymbol{\Phi} a convex set.
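Spelled out (my restatement of the argument above), for any fixed \boldsymbol{\Phi} the inner minimization evaluates to \begin{equation} \min_{\lambda\in\mathbb{R}}\big[\mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) + \lambda\mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})\big] = \begin{cases} \mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}), & \mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi})=0 \\ -\infty, & \text{otherwise} \end{cases} \end{equation} so the outer \max automatically restricts attention to feasible \boldsymbol{\Phi}.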

With this equivalent form, we can use the Minimax theorem to swap the positions of \min and \max: \begin{equation} \begin{aligned} &\,\max_{\Vert\boldsymbol{\Phi}\Vert_2\leq 1} \min_{\lambda\in\mathbb{R}}\mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) + \lambda\mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}) \\ =&\, \min_{\lambda\in\mathbb{R}}\max_{\Vert\boldsymbol{\Phi}\Vert_2\leq 1}\mathop{\mathrm{tr}}(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) + \lambda\mathop{\mathrm{tr}}(\boldsymbol{\Theta}^{\top} \boldsymbol{\Phi}) \\ =&\, \min_{\lambda\in\mathbb{R}} \Vert\boldsymbol{G} + \lambda \boldsymbol{\Theta}\Vert_* \end{aligned} \end{equation} The step of taking the \max over \Vert\boldsymbol{\Phi}\Vert_2\leq 1 is a basic result of the Muon derivation, so solving for the \max first presents no difficulty. Thus, we obtain the dual objective \Vert\boldsymbol{G} + \lambda \boldsymbol{\Theta}\Vert_* of the original problem [eq:muon-spectral].
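For reference, the last step uses the standard duality between the spectral and nuclear norms: \begin{equation} \max_{\Vert\boldsymbol{\Phi}\Vert_2\leq 1}\mathop{\mathrm{tr}}(\boldsymbol{A}^{\top}\boldsymbol{\Phi}) = \Vert\boldsymbol{A}\Vert_* \end{equation} with the maximum attained at \boldsymbol{\Phi} = \mathop{\mathrm{msign}}(\boldsymbol{A}); substituting \boldsymbol{A} = \boldsymbol{G} + \lambda\boldsymbol{\Theta} yields the dual objective.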

Some readers might wonder: why does this Lagrange multiplier method seem different from the one I learned? This is because the Lagrange multiplier method here is generalized to arbitrary convex sets, and the interchangeability of \min and \max is discussed rigorously to ensure that the final result is what we want. The Lagrange multiplier method we usually learn is just a set of heuristic procedures for solving constrained optimization problems in \mathbb{R}^n, without much discussion of the theoretical guarantees.

Summary

In this article, we introduced the idea of finding the steepest descent direction on manifolds through dual gradient descent. This is also the method used in the Thinking Machines Lab blog post "Modular Manifolds" to solve for Muon on the Stiefel manifold.

Reprinting: Please include the address of this article: https://kexue.fm/archives/11388

For more details on reprinting, please refer to: "Scientific Space FAQ"