In this article, we will derive the formula for the derivative of the \operatorname{msign} operator. If you are looking to combine TTT and Muon, as done in "Test-Time Training Done Right", then this post may be helpful to you.
Two Definitions
This article assumes that readers are already familiar with \operatorname{msign}. If not, you may first refer to "Appreciation of the Muon Optimizer: An Essential Leap from Vectors to Matrices" and "Newton-Schulz Iteration for the msign Operator (Part 1)". Given a matrix \boldsymbol{M} \in \mathbb{R}^{n \times m}, we have: \begin{equation} \boldsymbol{U}, \boldsymbol{\Sigma}, \boldsymbol{V}^{\top} = \text{SVD}(\boldsymbol{M}) \quad \Rightarrow \quad \operatorname{msign}(\boldsymbol{M}) = \boldsymbol{U}_{[:,:r]} \boldsymbol{V}_{[:,:r]}^{\top} \end{equation} where \boldsymbol{U} \in \mathbb{R}^{n \times n}, \boldsymbol{\Sigma} \in \mathbb{R}^{n \times m}, \boldsymbol{V} \in \mathbb{R}^{m \times m}, and r is the rank of \boldsymbol{M}. Simply put, \operatorname{msign} is the matrix obtained by setting all non-zero singular values of \boldsymbol{M} to 1.

Based on the SVD, we can also prove: \begin{equation} \operatorname{msign}(\boldsymbol{M}) = (\boldsymbol{M}\boldsymbol{M}^{\top})^{-1/2}\boldsymbol{M} = \boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2} \end{equation} Here, ^{-1/2} denotes the inverse of the matrix square root. Since \boldsymbol{M}\boldsymbol{M}^{\top} and \boldsymbol{M}^{\top}\boldsymbol{M} are symmetric positive semi-definite, the square root always exists, but the inverse may not; when it does not, we use the pseudo-inverse instead.

The name \operatorname{msign} comes from the similarity of the above expression to the scalar sign function \operatorname{sign}(x) = x/\sqrt{x^2}. However, as mentioned previously, there is another matrix version of the sign function, which we denote here as \operatorname{mcsgn}: \begin{equation} \operatorname{mcsgn}(\boldsymbol{M}) = \boldsymbol{M}(\boldsymbol{M}^2)^{-1/2} \end{equation} That is, the \boldsymbol{M}^{\top}\boldsymbol{M} of \operatorname{msign} is replaced by \boldsymbol{M}^2. Since only square matrices can be squared, this definition applies only to square matrices. Introducing two similar but different definitions in one article can be confusing, but unfortunately both are needed in the calculations that follow.
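For concreteness, here is a minimal NumPy sketch of the SVD-based definition, cross-checked against the square-root form; the function name msign_svd and the rank tolerance are illustrative choices of mine rather than anything fixed by the text:

```python
import numpy as np

def msign_svd(M, tol=1e-12):
    """msign(M): set every non-zero singular value of M to 1 (via SVD)."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    r = int(np.sum(s > tol * s.max()))   # numerical rank
    return U[:, :r] @ Vt[:r, :]

M = np.random.randn(5, 3)                # full column rank almost surely
O = msign_svd(M)

# cross-check against msign(M) = M (M^T M)^{-1/2}, computed by eigendecomposition
w, Q = np.linalg.eigh(M.T @ M)
inv_sqrt = Q @ np.diag(w ** -0.5) @ Q.T
print(np.allclose(O, M @ inv_sqrt))                # True
print(np.linalg.svd(O, compute_uv=False))          # all singular values ~ 1
```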
\operatorname{mcsgn} possesses similarity invariance: if \boldsymbol{M} = \boldsymbol{P}\boldsymbol{\Lambda}\boldsymbol{P}^{-1}, then \operatorname{mcsgn}(\boldsymbol{M}) = \boldsymbol{P}\operatorname{mcsgn}(\boldsymbol{\Lambda})\boldsymbol{P}^{-1}. Furthermore, if \boldsymbol{\Lambda} is a diagonal matrix (which is almost always possible in the complex field), then: \begin{equation} \operatorname{mcsgn}(\boldsymbol{M}) = \boldsymbol{P}\operatorname{csgn}(\boldsymbol{\Lambda})\boldsymbol{P}^{-1} \end{equation} \operatorname{csgn}(\boldsymbol{\Lambda}) indicates that \operatorname{csgn} is applied to each diagonal element, where \operatorname{csgn}(z) = z/\sqrt{z^2} is the complex version of the sign function. If the real part of z is non-zero, it equals \operatorname{sign}(\operatorname{Re}[z]). Thus, the difference between \operatorname{msign} and \operatorname{mcsgn} is that the former applies the sign function to the singular values based on SVD, while the latter applies it to the eigenvalues based on eigenvalue decomposition. When \boldsymbol{M} is a symmetric matrix, they are equivalent.
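Likewise, a small sketch of \operatorname{mcsgn} via eigenvalue decomposition (the helper name mcsgn_eig is mine, and the sketch assumes \boldsymbol{M} is diagonalizable with eigenvalues away from the imaginary axis):

```python
import numpy as np

def mcsgn_eig(M):
    """mcsgn(M): apply csgn to the eigenvalues (assumes M is diagonalizable)."""
    lam, P = np.linalg.eig(M)                 # M = P diag(lam) P^{-1}
    csgn = lam / np.sqrt(lam ** 2 + 0j)       # csgn(z) = z / sqrt(z^2)
    return (P @ np.diag(csgn) @ np.linalg.inv(P)).real

# for a symmetric matrix, msign and mcsgn coincide
A = np.random.randn(4, 4); A = A + A.T
U, s, Vt = np.linalg.svd(A)
print(np.allclose(mcsgn_eig(A), U @ Vt))      # True
```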
Unified Calculation
Currently, the numerical calculation of \operatorname{msign} mainly relies on a "Newton-Schulz iteration" of the following form: \begin{equation} \boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\|\boldsymbol{M}\|_F}, \qquad \boldsymbol{X}_{t+1} = a_{t+1}\boldsymbol{X}_t + b_{t+1}\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) + c_{t+1}\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^2 \end{equation} Regarding the choice of coefficients, we discussed this in detail in "Newton-Schulz Iteration for the msign Operator (Part 1)" and "Newton-Schulz Iteration for the msign Operator (Part 2)". The more recent coefficients from Part 2 are:
| t | a \times 1.01 | b \times 1.01^3 | c \times 1.01^5 |
|---|---|---|---|
| 1 | 8.28721 | -23.5959 | 17.3004 |
| 2 | 4.10706 | -2.94785 | 0.544843 |
| 3 | 3.94869 | -2.9089 | 0.551819 |
| 4 | 3.31842 | -2.48849 | 0.510049 |
| 5 | 2.30065 | -1.6689 | 0.418807 |
| 6 | 1.8913 | -1.268 | 0.376804 |
| 7 | 1.875 | -1.25 | 0.375 |
| 8 | 1.875 | -1.25 | 0.375 |
The advantage of these coefficients is that the schedule can be truncated or extended at will: keeping only the first 5 rows gives the optimal 5-step iteration, keeping the first 6 rows gives the optimal 6-step iteration (whose approximation is guaranteed to be no worse than the 5-step one), and so on.
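As a reference, here is a minimal NumPy sketch of this iteration with the coefficients above. I plug the listed values in directly as (a_t, b_t, c_t); exactly how the 1.01 safety factors in the column headers are meant to be applied should be checked against the Part 2 post:

```python
import numpy as np

# coefficients from the table above, taken verbatim as (a_t, b_t, c_t)
ABC = [
    (8.28721, -23.5959, 17.3004), (4.10706, -2.94785, 0.544843),
    (3.94869, -2.9089, 0.551819), (3.31842, -2.48849, 0.510049),
    (2.30065, -1.6689, 0.418807), (1.8913, -1.268, 0.376804),
    (1.875, -1.25, 0.375), (1.875, -1.25, 0.375),
]

def msign_ns(M, steps=8):
    """msign via Newton-Schulz: X <- a X + b X (X^T X) + c X (X^T X)^2."""
    X = M / np.linalg.norm(M)             # Frobenius normalization
    for a, b, c in ABC[:steps]:
        G = X.T @ X
        X = a * X + X @ (b * G + c * G @ G)
    return X

M = np.random.randn(5, 3)
U, s, Vt = np.linalg.svd(M, full_matrices=False)
print(np.max(np.abs(msign_ns(M) - U @ Vt)))   # approximation error of the 8-step iteration
```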
As for \operatorname{mcsgn}, it simply replaces the \boldsymbol{M}^{\top}\boldsymbol{M} of \operatorname{msign} with \boldsymbol{M}^2, so in principle a Newton-Schulz iteration can be used as well. However, since eigenvalues can be complex, convergence in the general case is much harder to guarantee. Nevertheless, if we can confirm beforehand that the eigenvalues of \boldsymbol{M} are all real (as is the case for the block triangular matrices to which we will apply \operatorname{mcsgn} later), then we can reuse the iteration and coefficients of \operatorname{msign}: \begin{equation} \boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\sqrt{\operatorname{tr}(\boldsymbol{M}^2)}}, \qquad \boldsymbol{X}_{t+1} = a_{t+1}\boldsymbol{X}_t + b_{t+1}\boldsymbol{X}_t^3 + c_{t+1}\boldsymbol{X}_t^5 \end{equation}
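The corresponding sketch for \operatorname{mcsgn} under the all-eigenvalues-real assumption (again with the table's coefficients taken verbatim; mcsgn_ns is my own naming):

```python
import numpy as np

ABC = [  # same coefficients as in the msign table above
    (8.28721, -23.5959, 17.3004), (4.10706, -2.94785, 0.544843),
    (3.94869, -2.9089, 0.551819), (3.31842, -2.48849, 0.510049),
    (2.30065, -1.6689, 0.418807), (1.8913, -1.268, 0.376804),
    (1.875, -1.25, 0.375), (1.875, -1.25, 0.375),
]

def mcsgn_ns(M, steps=8):
    """mcsgn via Newton-Schulz; assumes all eigenvalues of M are real."""
    X = M / np.sqrt(np.trace(M @ M))      # tr(M^2) = sum of squared eigenvalues
    for a, b, c in ABC[:steps]:
        X2 = X @ X
        X = a * X + b * X @ X2 + c * X @ X2 @ X2
    return X

A = np.random.randn(4, 4); A = A + A.T    # symmetric => real eigenvalues
w, Q = np.linalg.eigh(A)
print(np.max(np.abs(mcsgn_ns(A) - Q @ np.diag(np.sign(w)) @ Q.T)))  # truncation error
```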
Derivation Process
Now we formally enter the main topic: calculating the derivative of \boldsymbol{O} = \operatorname{msign}(\boldsymbol{M}). If you only use Muon as a regular optimizer, this article probably has nothing to do with you. The derivative of \operatorname{msign} is needed only when, following TTT, we use the Muon update rule to construct RNN-style models. In that case, \operatorname{msign} appears in the model's forward pass, so backpropagating through the whole model naturally involves the derivative of \operatorname{msign}.
Since \operatorname{msign} is calculated via Newton-Schulz iteration, it can actually be backpropagated directly. Thus, numerical differentiation of \operatorname{msign} itself is not a problem. However, backpropagation based on iteration means many intermediate states must be stored, which often leads to memory explosion. Therefore, we hope to obtain an analytical solution for the derivative to simplify things. On the other hand, in "The Derivative of SVD", we actually calculated the derivative of \operatorname{msign}, but that was based on the SVD expression, and SVD is not a GPU-efficient algorithm.
Therefore, our goal is to seek a result that does not depend on SVD and can be calculated efficiently. We start from the identity: \begin{equation} \boldsymbol{M} = \boldsymbol{O}\boldsymbol{M}^{\top}\boldsymbol{O} \end{equation} (which can be proven by the definition of \operatorname{msign}). Differentiating both sides gives: \begin{equation} d\boldsymbol{M} = (d\boldsymbol{O})\boldsymbol{M}^{\top}\boldsymbol{O} + \boldsymbol{O}(d\boldsymbol{M}^{\top})\boldsymbol{O} + \boldsymbol{O}\boldsymbol{M}^{\top}(d\boldsymbol{O}) \label{eq:dm-do} \end{equation} The difficulty of this result is that it is not easy to isolate d\boldsymbol{M} = f(d\boldsymbol{O}) or d\boldsymbol{O} = f(d\boldsymbol{M}). Consequently, it is hard to see the relationship between \nabla_{\boldsymbol{O}}\mathcal{L} and \nabla_{\boldsymbol{M}}\mathcal{L} (\mathcal{L} is the loss function). In this case, the best approach is to return to the fundamental method of matrix differentiation—the "trace trick":
Trace Trick: If we can find a matrix \boldsymbol{G} of the same shape as \boldsymbol{M} such that \begin{equation} d\mathcal{L} = \langle \boldsymbol{G}, d\boldsymbol{M} \rangle_F = \operatorname{tr}(\boldsymbol{G}^{\top} (d\boldsymbol{M})) \end{equation} then \boldsymbol{G} = \nabla_{\boldsymbol{M}}\mathcal{L}.
The essence of the trace trick is to transform matrices/vectors into scalars, and then scalars into traces, allowing the use of trace identities: \begin{equation} \operatorname{tr}(\boldsymbol{A}\boldsymbol{B}) = \operatorname{tr}(\boldsymbol{B}\boldsymbol{A}) = \operatorname{tr}(\boldsymbol{A}^{\top}\boldsymbol{B}^{\top}) = \operatorname{tr}(\boldsymbol{B}^{\top}\boldsymbol{A}^{\top}) \end{equation} Now let \boldsymbol{X} be any matrix of the same shape as \boldsymbol{M}. Multiply both sides of Eq. [eq:dm-do] by \boldsymbol{X}^{\top} and take the trace: \begin{equation} \begin{aligned} \operatorname{tr}(\boldsymbol{X}^{\top}(d\boldsymbol{M})) &= \operatorname{tr}(\boldsymbol{X}^{\top}(d\boldsymbol{O})\boldsymbol{M}^{\top}\boldsymbol{O}) + \operatorname{tr}(\boldsymbol{X}^{\top}\boldsymbol{O}(d\boldsymbol{M}^{\top})\boldsymbol{O}) + \operatorname{tr}(\boldsymbol{X}^{\top}\boldsymbol{O}\boldsymbol{M}^{\top}(d\boldsymbol{O})) \\ &= \operatorname{tr}(\boldsymbol{M}^{\top}\boldsymbol{O}\boldsymbol{X}^{\top}(d\boldsymbol{O})) + \operatorname{tr}(\boldsymbol{O}\boldsymbol{X}^{\top}\boldsymbol{O}(d\boldsymbol{M}^{\top})) + \operatorname{tr}(\boldsymbol{X}^{\top}\boldsymbol{O}\boldsymbol{M}^{\top}(d\boldsymbol{O})) \\ &= \operatorname{tr}(\boldsymbol{M}^{\top}\boldsymbol{O}\boldsymbol{X}^{\top}(d\boldsymbol{O})) + \operatorname{tr}(\boldsymbol{O}^{\top}\boldsymbol{X}\boldsymbol{O}^{\top}(d\boldsymbol{M})) + \operatorname{tr}(\boldsymbol{X}^{\top}\boldsymbol{O}\boldsymbol{M}^{\top}(d\boldsymbol{O})) \end{aligned} \end{equation} From this, we obtain: \begin{equation} \operatorname{tr}((\boldsymbol{X}^{\top} - \boldsymbol{O}^{\top}\boldsymbol{X}\boldsymbol{O}^{\top})(d\boldsymbol{M})) = \operatorname{tr}((\boldsymbol{M}^{\top}\boldsymbol{O}\boldsymbol{X}^{\top} + \boldsymbol{X}^{\top}\boldsymbol{O}\boldsymbol{M}^{\top})(d\boldsymbol{O})) \end{equation} If we let \boldsymbol{M}^{\top}\boldsymbol{O}\boldsymbol{X}^{\top} + \boldsymbol{X}^{\top}\boldsymbol{O}\boldsymbol{M}^{\top} = (\nabla_{\boldsymbol{O}}\mathcal{L})^{\top}, then the above equation takes on the meaning of d\mathcal{L}. According to the trace trick, \boldsymbol{X}^{\top} - \boldsymbol{O}^{\top}\boldsymbol{X}\boldsymbol{O}^{\top} = (\nabla_{\boldsymbol{M}}\mathcal{L})^{\top}. This indicates that the relationship between \nabla_{\boldsymbol{M}}\mathcal{L} and \nabla_{\boldsymbol{O}}\mathcal{L} is described by the following system of equations: \begin{gather} \boldsymbol{X} - \boldsymbol{O}\boldsymbol{X}^{\top}\boldsymbol{O} = \nabla_{\boldsymbol{M}}\mathcal{L} \label{eq:g-m} \\[7pt] \boldsymbol{X}\boldsymbol{O}^{\top}\boldsymbol{M} + \boldsymbol{M}\boldsymbol{O}^{\top}\boldsymbol{X} = \nabla_{\boldsymbol{O}}\mathcal{L} \label{eq:g-o} \end{gather}
Theoretical Form
So, the problem now becomes solving for \boldsymbol{X} from Eq. [eq:g-o] and then substituting it into Eq. [eq:g-m] to obtain \nabla_{\boldsymbol{M}}\mathcal{L}, thereby expressing \nabla_{\boldsymbol{M}}\mathcal{L} as a function of \nabla_{\boldsymbol{O}}\mathcal{L} and avoiding the direct calculation of \nabla_{\boldsymbol{M}}\boldsymbol{O}. Obviously, the only difficulty is solving Eq. [eq:g-o].
In this section, we first derive a theoretical solution based on SVD, which is not very practical but helps us understand the properties of Eq. [eq:g-o] and align it with previous results. Let \boldsymbol{X} = \boldsymbol{U}\boldsymbol{Y}\boldsymbol{V}^{\top}. We also have \boldsymbol{O}^{\top}\boldsymbol{M} = (\boldsymbol{M}^{\top}\boldsymbol{M})^{1/2} = \boldsymbol{V}(\boldsymbol{\Sigma}^{\top}\boldsymbol{\Sigma})^{1/2}\boldsymbol{V}^{\top} and \boldsymbol{M}\boldsymbol{O}^{\top} = (\boldsymbol{M}\boldsymbol{M}^{\top})^{1/2} = \boldsymbol{U}(\boldsymbol{\Sigma}\boldsymbol{\Sigma}^{\top})^{1/2}\boldsymbol{U}^{\top}. Substituting these into Eq. [eq:g-o] yields: \begin{equation} \boldsymbol{U}\boldsymbol{Y}(\boldsymbol{\Sigma}^{\top}\boldsymbol{\Sigma})^{1/2}\boldsymbol{V}^{\top} + \boldsymbol{U}(\boldsymbol{\Sigma}\boldsymbol{\Sigma}^{\top})^{1/2}\boldsymbol{Y}\boldsymbol{V}^{\top} = \nabla_{\boldsymbol{O}}\mathcal{L} \end{equation} which is: \begin{equation} \boldsymbol{Y}(\boldsymbol{\Sigma}^{\top}\boldsymbol{\Sigma})^{1/2} + (\boldsymbol{\Sigma}\boldsymbol{\Sigma}^{\top})^{1/2}\boldsymbol{Y} = \boldsymbol{U}^{\top}(\nabla_{\boldsymbol{O}}\mathcal{L})\boldsymbol{V} \label{eq:g-o-2} \end{equation} The left side of the above equation, written in component form, is \boldsymbol{Y}_{i,j}\sigma_j + \sigma_i \boldsymbol{Y}_{i,j} = (\sigma_i + \sigma_j)\boldsymbol{Y}_{i,j}, where \sigma_1, \sigma_2, \dots, \sigma_r are the non-zero singular values of \boldsymbol{M}, and 0 = \sigma_{r+1} = \sigma_{r+2} = \dots. Clearly, if \boldsymbol{M} is a full-rank square matrix, we can solve for: \begin{equation} \boldsymbol{Y} = (\boldsymbol{U}^{\top}(\nabla_{\boldsymbol{O}}\mathcal{L})\boldsymbol{V}) \oslash \boldsymbol{S} \end{equation} where \boldsymbol{S}_{i,j} = \sigma_i + \sigma_j and \oslash is the Hadamard division (element-wise division). Substituting \boldsymbol{X} = \boldsymbol{U}\boldsymbol{Y}\boldsymbol{V}^{\top} into Eq. [eq:g-m] yields results consistent with those in "The Derivative of SVD". This convergence of different methods strengthens our confidence that our derivation so far is correct.
What if \boldsymbol{M} is not full-rank or not square? In this case, if the right side \boldsymbol{U}^{\top}(\nabla_{\boldsymbol{O}}\mathcal{L})\boldsymbol{V} does not "cooperate," Eq. [eq:g-o-2] has no solution. However, since Eq. [eq:g-o-2] is derived from a real problem, it must have a solution, so the right side "must cooperate"! What does cooperation mean? If the rank of \boldsymbol{M} is r, then \boldsymbol{S}_{i,j} = \sigma_i + \sigma_j vanishes exactly on the block [r:, r:], where both singular values are zero. For Eq. [eq:g-o-2] to have a solution, the [r:, r:] block of \boldsymbol{U}^{\top}(\nabla_{\boldsymbol{O}}\mathcal{L})\boldsymbol{V} must therefore also be zero. Under this condition, we can write: \begin{equation} \boldsymbol{Y} = \lim_{\epsilon \to 0} (\boldsymbol{U}^{\top}(\nabla_{\boldsymbol{O}}\mathcal{L})\boldsymbol{V}) \oslash (\boldsymbol{S} + \epsilon) \end{equation} This is equivalent to saying that we can add a perturbation to the singular values so that they are all non-zero, and then let the perturbation tend to zero after the calculation to obtain the correct result.
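To make the theoretical route concrete, here is a hedged NumPy sketch that assembles \nabla_{\boldsymbol{M}}\mathcal{L} from \nabla_{\boldsymbol{O}}\mathcal{L} via the SVD formula, together with a finite-difference check on a toy loss \mathcal{L} = \langle \boldsymbol{W}, \operatorname{msign}(\boldsymbol{M})\rangle_F; the function names and the small \epsilon are my own choices:

```python
import numpy as np

def msign_grad_svd(M, grad_O, eps=1e-12):
    """dL/dM for O = msign(M), given dL/dO, via the SVD-based solution."""
    n, m = M.shape
    U, s, Vt = np.linalg.svd(M)                          # full SVD: U is n*n, Vt is m*m
    V = Vt.T
    r = int(np.sum(s > 1e-12))                           # numerical rank
    O = U[:, :r] @ Vt[:r, :]
    sig_row = np.concatenate([s, np.zeros(n - s.size)])  # length n
    sig_col = np.concatenate([s, np.zeros(m - s.size)])  # length m
    S = sig_row[:, None] + sig_col[None, :]              # S_{ij} = sigma_i + sigma_j
    Y = (U.T @ grad_O @ V) / (S + eps)                   # Hadamard division
    X = U @ Y @ V.T
    return X - O @ X.T @ O                               # eq. (g-m)

def msign_svd(M):
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    r = int(np.sum(s > 1e-12))
    return U[:, :r] @ Vt[:r, :]

# finite-difference check on L = <W, msign(M)>, so that dL/dO = W
M = np.random.randn(4, 3)
W = np.random.randn(4, 3)
g_analytic = msign_grad_svd(M, W)
g_fd = np.zeros_like(M)
h = 1e-6
for i in range(M.shape[0]):
    for j in range(M.shape[1]):
        E = np.zeros_like(M); E[i, j] = h
        g_fd[i, j] = (np.sum(W * msign_svd(M + E)) - np.sum(W * msign_svd(M - E))) / (2 * h)
print(np.max(np.abs(g_analytic - g_fd)))  # should be small for a full-rank M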
Efficient Solution
The SVD solution in the previous section often has only theoretical value. For efficient calculation on GPUs, we need to seek other forms of solutions. Introducing the notation \boldsymbol{M}\boldsymbol{O}^{\top} = \boldsymbol{A}, \boldsymbol{O}^{\top}\boldsymbol{M} = \boldsymbol{B}, and \nabla_{\boldsymbol{O}}\mathcal{L} = \boldsymbol{C}, Eq. [eq:g-o] is actually a Sylvester equation: \begin{equation} \boldsymbol{A}\boldsymbol{X} + \boldsymbol{X}\boldsymbol{B} = \boldsymbol{C} \end{equation} There are many methods to solve the Sylvester equation. Among them, the most ingenious and efficient for GPUs might be the solution scheme based on \operatorname{mcsgn} (not \operatorname{msign}) (referencing "Fast Differentiable Matrix Square Root"). First, starting from the above equation, we can verify that the following holds: \begin{equation} \begin{bmatrix} \boldsymbol{A} & -\boldsymbol{C} \\ \boldsymbol{0} & -\boldsymbol{B}\end{bmatrix} = \begin{bmatrix} \boldsymbol{I} & \boldsymbol{X} \\ \boldsymbol{0} & \boldsymbol{I}\end{bmatrix}\begin{bmatrix} \boldsymbol{A} & \boldsymbol{0} \\ \boldsymbol{0} & -\boldsymbol{B}\end{bmatrix}\begin{bmatrix} \boldsymbol{I} & \boldsymbol{X} \\ \boldsymbol{0} & \boldsymbol{I}\end{bmatrix}^{-1} \end{equation} Taking \operatorname{mcsgn} on both sides, according to the properties of \operatorname{mcsgn}, we have: \begin{equation} \operatorname{mcsgn}\left(\begin{bmatrix} \boldsymbol{A} & -\boldsymbol{C} \\ \boldsymbol{0} & -\boldsymbol{B}\end{bmatrix}\right) = \begin{bmatrix} \boldsymbol{I} & \boldsymbol{X} \\ \boldsymbol{0} & \boldsymbol{I}\end{bmatrix}\begin{bmatrix} \operatorname{mcsgn}(\boldsymbol{A}) & \boldsymbol{0} \\ \boldsymbol{0} & -\operatorname{mcsgn}(\boldsymbol{B})\end{bmatrix}\begin{bmatrix} \boldsymbol{I} & \boldsymbol{X} \\ \boldsymbol{0} & \boldsymbol{I}\end{bmatrix}^{-1} \end{equation} Note that \boldsymbol{A} = \boldsymbol{M}\boldsymbol{O}^{\top} = (\boldsymbol{M}\boldsymbol{M}^{\top})^{1/2} and \boldsymbol{B} = \boldsymbol{O}^{\top}\boldsymbol{M} = (\boldsymbol{M}^{\top}\boldsymbol{M})^{1/2}. Assuming \boldsymbol{M} is a full-rank square matrix, then \boldsymbol{A} and \boldsymbol{B} are both positive definite and symmetric. The \operatorname{mcsgn} of a positive definite symmetric matrix is the identity matrix \boldsymbol{I}. Therefore: \begin{equation} \operatorname{mcsgn}\left(\begin{bmatrix} \boldsymbol{A} & -\boldsymbol{C} \\ \boldsymbol{0} & -\boldsymbol{B}\end{bmatrix}\right) = \begin{bmatrix} \boldsymbol{I} & \boldsymbol{X} \\ \boldsymbol{0} & \boldsymbol{I}\end{bmatrix}\begin{bmatrix} \boldsymbol{I} & \boldsymbol{0} \\ \boldsymbol{0} & -\boldsymbol{I}\end{bmatrix}\begin{bmatrix} \boldsymbol{I} & \boldsymbol{X} \\ \boldsymbol{0} & \boldsymbol{I}\end{bmatrix}^{-1} = \begin{bmatrix} \boldsymbol{I} & -2\boldsymbol{X} \\ \boldsymbol{0} & -\boldsymbol{I}\end{bmatrix} \end{equation} The final simplification uses the identity \begin{bmatrix} \boldsymbol{I} & \boldsymbol{X} \\ \boldsymbol{0} & \boldsymbol{I}\end{bmatrix}^{-1} = \begin{bmatrix} \boldsymbol{I} & -\boldsymbol{X} \\ \boldsymbol{0} & \boldsymbol{I}\end{bmatrix}. From this result, we can see that we only need to calculate \operatorname{mcsgn} for the block matrix \begin{bmatrix} \boldsymbol{A} & -\boldsymbol{C} \\ \boldsymbol{0} & -\boldsymbol{B}\end{bmatrix}, and then \boldsymbol{X} can be read from the upper-right corner of the result. Since \operatorname{mcsgn} can be efficiently calculated via Newton-Schulz iteration, this scheme is GPU-friendly.
When \boldsymbol{M} is not full-rank or not square, \boldsymbol{A} and \boldsymbol{B} are only positive semi-definite, and their \operatorname{mcsgn} is not \boldsymbol{I}. However, the experience from the previous section tells us that since \nabla_{\boldsymbol{O}}\mathcal{L} "must cooperate," we only need to add a small perturbation to \boldsymbol{\Sigma} to make everything positive definite. Adding a perturbation to \boldsymbol{\Sigma} is equivalent to adding \epsilon \boldsymbol{I} to \boldsymbol{A} and \boldsymbol{B}, so: \begin{equation} \boldsymbol{X} = -\frac{1}{2} \left(\lim_{\epsilon \to 0} \operatorname{mcsgn}\left(\begin{bmatrix} \boldsymbol{A} + \epsilon \boldsymbol{I} & -\boldsymbol{C} \\ \boldsymbol{0} & -\boldsymbol{B} - \epsilon \boldsymbol{I}\end{bmatrix}\right)\right)_{[:n,n:]} \end{equation} In actual computation we can only pick a small \epsilon > 0 and accept an approximation; \epsilon = 10^{-3} is a reasonable choice, since singular values of that size still fall within the lower-bound range that the Newton-Schulz coefficients we derived earlier are designed to handle.
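Putting the pieces together, here is a hedged NumPy sketch of the whole backward step: build the perturbed block matrix, apply the \operatorname{mcsgn} Newton-Schulz iteration, read \boldsymbol{X} off the upper-right block, and assemble \nabla_{\boldsymbol{M}}\mathcal{L} = \boldsymbol{X} - \boldsymbol{O}\boldsymbol{X}^{\top}\boldsymbol{O}. The function names and the toy usage are illustrative; in practice \boldsymbol{O} would come from the forward Newton-Schulz pass rather than an SVD:

```python
import numpy as np

ABC = [  # Newton-Schulz coefficients from the table above
    (8.28721, -23.5959, 17.3004), (4.10706, -2.94785, 0.544843),
    (3.94869, -2.9089, 0.551819), (3.31842, -2.48849, 0.510049),
    (2.30065, -1.6689, 0.418807), (1.8913, -1.268, 0.376804),
    (1.875, -1.25, 0.375), (1.875, -1.25, 0.375),
]

def mcsgn_ns(M, steps=8):
    """mcsgn via Newton-Schulz (assumes real eigenvalues)."""
    X = M / np.sqrt(np.trace(M @ M))
    for a, b, c in ABC[:steps]:
        X2 = X @ X
        X = a * X + b * X @ X2 + c * X @ X2 @ X2
    return X

def msign_backward(M, O, grad_O, eps=1e-3):
    """dL/dM for O = msign(M), given dL/dO, via the block-matrix mcsgn trick."""
    n, m = M.shape
    A = M @ O.T + eps * np.eye(n)          # (M M^T)^{1/2} + eps I
    B = O.T @ M + eps * np.eye(m)          # (M^T M)^{1/2} + eps I
    blk = np.block([[A, -grad_O], [np.zeros((m, n)), -B]])
    X = -0.5 * mcsgn_ns(blk)[:n, n:]       # upper-right block
    return X - O @ X.T @ O                 # eq. (g-m)

# toy usage: here O is computed by SVD for convenience only
M = np.random.randn(4, 3)
U, s, Vt = np.linalg.svd(M, full_matrices=False)
O = U @ Vt
W = np.random.randn(4, 3)                  # pretend dL/dO = W
print(msign_backward(M, O, W))
```

Under these assumptions the output should roughly agree with the SVD-based msign_grad_svd sketch from the previous section, with deviations on the order of \epsilon plus the truncation error of the eight-step iteration.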
Conclusion
This article discussed the calculation of the derivative of the \operatorname{msign} operator. If you are interested in the combination of "TTT + Muon," then this article may be of help to you.
When reposting, please include the original address of this article: https://kexue.fm/archives/11025
For more detailed reposting matters, please refer to: "Scientific Space FAQ"