Readers who follow visual generative models will know that the Fréchet Inception Distance (FID) is one of the key evaluation metrics: the smaller it is, the more realistic the generated results. A natural question is why we don't use FID directly as the loss function for training generative models. Is it because FID is non-differentiable? In fact, FID is differentiable, and using it as a loss is fine in theory; the difficulties are practical ones that arise when computing it.
Recently, the paper Representation Fréchet Loss for Visual Generation made several attempts to overcome these difficulties, successfully applied FID to fine-tuning generative models, and significantly improved single-step generation performance. This article briefly discusses the mathematical principles and implementation techniques involved.
Generation Metrics
FID stands for “Fréchet Inception Distance.” We can understand it in two parts: “Fréchet Distance (FD)” and “Inception (I).”
Suppose we have two distributions p and q, representing real and generated samples respectively. We encode each sample \boldsymbol{x} through some pretrained encoder \phi into a feature vector \boldsymbol{z}=\phi(\boldsymbol{x})\in\mathbb{R}^d, and estimate their respective mean vectors \boldsymbol{\mu}_p,\boldsymbol{\mu}_q and covariance matrices \boldsymbol{\Sigma}_p,\boldsymbol{\Sigma}_q. Then, assuming the encoded results follow a multivariate normal distribution, we can use a discrepancy function between normal distributions to measure their difference. The Fréchet Distance chooses the Wasserstein-2 distance: \begin{aligned} \mathcal{F}\triangleq\mathcal{W}_2^2[p,q]=&\,\Vert \boldsymbol{\mu}_p - \boldsymbol{\mu}_q\Vert^2 + \mathop{\mathrm{tr}}(\boldsymbol{\Sigma}_p + \boldsymbol{\Sigma}_q - 2(\boldsymbol{\Sigma}_p\boldsymbol{\Sigma}_q)^{1/2})\\[4pt] =&\,\Vert \boldsymbol{\mu}_p - \boldsymbol{\mu}_q\Vert^2 + \mathop{\mathrm{tr}}(\boldsymbol{\Sigma}_p + \boldsymbol{\Sigma}_q - 2(\boldsymbol{\Sigma}_p^{1/2}\boldsymbol{\Sigma}_q\boldsymbol{\Sigma}_p^{1/2})^{1/2}) \end{aligned}\label{eq:w-p-q} Plugging the mean vectors and covariance matrices of the respective encodings into the above formula yields the Fréchet Distance. For the derivation, interested readers can refer to “KL Divergence, Bhattacharyya Distance, and Wasserstein Distance between Two Multivariate Normal Distributions”.
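As a quick numerical sanity check that the two lines of Eq. [eq:w-p-q] agree, here is a minimal NumPy sketch that evaluates the Fréchet Distance from given means and covariances in both forms. The helper names (sym_sqrt, frechet_distance, frechet_distance_first_form, random_spd) are illustrative, and random symmetric positive definite matrices stand in for real feature statistics. For the first form, it relies on the fact that \boldsymbol{\Sigma}_p\boldsymbol{\Sigma}_q is similar to \boldsymbol{\Sigma}_p^{1/2}\boldsymbol{\Sigma}_q\boldsymbol{\Sigma}_p^{1/2}, so the trace of its square root is the sum of the square roots of its (real, positive) eigenvalues.

```python
import numpy as np


def sym_sqrt(A):
    """Principal square root of a symmetric positive semidefinite matrix via eigh."""
    w, U = np.linalg.eigh(A)
    w = np.clip(w, 0.0, None)  # clip tiny negative round-off
    return (U * np.sqrt(w)) @ U.T


def frechet_distance(mu_p, Sigma_p, mu_q, Sigma_q):
    """Second line of Eq. [eq:w-p-q], via the symmetric matrix S = Sigma_p^{1/2} Sigma_q Sigma_p^{1/2}."""
    sp = sym_sqrt(Sigma_p)
    S = sp @ Sigma_q @ sp
    tr_sqrt_S = np.sqrt(np.clip(np.linalg.eigvalsh(S), 0.0, None)).sum()
    return float(np.sum((mu_p - mu_q) ** 2) + np.trace(Sigma_p) + np.trace(Sigma_q) - 2.0 * tr_sqrt_S)


def frechet_distance_first_form(mu_p, Sigma_p, mu_q, Sigma_q):
    """First line: tr((Sigma_p Sigma_q)^{1/2}) is the sum of sqrt of the eigenvalues of Sigma_p Sigma_q."""
    ev = np.linalg.eigvals(Sigma_p @ Sigma_q).real  # real and positive up to round-off
    tr_sqrt = np.sqrt(np.clip(ev, 0.0, None)).sum()
    return float(np.sum((mu_p - mu_q) ** 2) + np.trace(Sigma_p) + np.trace(Sigma_q) - 2.0 * tr_sqrt)


def random_spd(rng, d):
    """Random symmetric positive definite matrix (stand-in for a feature covariance)."""
    A = rng.standard_normal((d, 2 * d))
    return A @ A.T / (2 * d) + 0.1 * np.eye(d)


rng = np.random.default_rng(0)
d = 5
mu_p, mu_q = rng.standard_normal(d), rng.standard_normal(d)
Sigma_p, Sigma_q = random_spd(rng, d), random_spd(rng, d)
print(frechet_distance(mu_p, Sigma_p, mu_q, Sigma_q))
print(frechet_distance_first_form(mu_p, Sigma_p, mu_q, Sigma_q))  # agrees up to round-off
```

The second form only ever takes square roots of symmetric matrices, which is why it is the convenient one to differentiate below.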
If the encoder \phi is chosen as InceptionV3 (I), the resulting metric is called Fréchet Inception Distance (FID). This evaluation metric was first proposed in the 2017 paper GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium, which in some sense is already a product of the “ancient era”.
Of course, nowadays we do not necessarily have to use InceptionV3 for either training or evaluation; we can use more advanced feature models, such as SigLIP, or compute Fréchet Distance with multiple different encoders and then sum them up, etc. We can uniformly refer to these approaches as “FD Loss.”
Gradient Computation
Now let us derive things step by step and see what difficulties FD actually encounters as a loss. The first task is computing the gradient. Since p represents the real distribution, its \boldsymbol{\mu}_p,\boldsymbol{\Sigma}_p are fixed, and we only need the gradients with respect to \boldsymbol{\mu}_q,\boldsymbol{\Sigma}_q. The gradient w.r.t. \boldsymbol{\mu}_q is relatively simple: \nabla_{\boldsymbol{\mu}_q}\mathcal{F} = \nabla_{\boldsymbol{\mu}_q}\Vert \boldsymbol{\mu}_p - \boldsymbol{\mu}_q\Vert^2 = 2(\boldsymbol{\mu}_q - \boldsymbol{\mu}_p) The gradient w.r.t. \boldsymbol{\Sigma}_q is \nabla_{\boldsymbol{\Sigma}_q}\mathcal{F} = \nabla_{\boldsymbol{\Sigma}_q}\mathop{\mathrm{tr}}(\boldsymbol{\Sigma}_p + \boldsymbol{\Sigma}_q - 2(\boldsymbol{\Sigma}_p^{1/2}\boldsymbol{\Sigma}_q\boldsymbol{\Sigma}_p^{1/2})^{1/2}) = \boldsymbol{I} - 2\nabla_{\boldsymbol{\Sigma}_q}\mathop{\mathrm{tr}}((\boldsymbol{\Sigma}_p^{1/2}\boldsymbol{\Sigma}_q\boldsymbol{\Sigma}_p^{1/2})^{1/2})
Here we use the second line of Eq. [eq:w-p-q], which looks more complicated but has one advantage: the matrix \boldsymbol{S} = \boldsymbol{\Sigma}_p^{1/2}\boldsymbol{\Sigma}_q\boldsymbol{\Sigma}_p^{1/2} is symmetric positive definite (assuming \boldsymbol{\Sigma}_p,\boldsymbol{\Sigma}_q are both positive definite), which simplifies the computation. Let the eigendecomposition of \boldsymbol{S} (which coincides with its SVD here) be \boldsymbol{U}\boldsymbol{\Lambda}\boldsymbol{U}^{\top}, with eigenvalues \lambda_i and eigenvectors \boldsymbol{u}_i (the columns of \boldsymbol{U}); then \boldsymbol{S}^{1/2}=\boldsymbol{U}\boldsymbol{\Lambda}^{1/2}\boldsymbol{U}^{\top}, and we have \begin{aligned} \mathop{\mathrm{tr}}(\boldsymbol{S}^{1/2})=&\,\mathop{\mathrm{tr}}(\boldsymbol{\Lambda}^{1/2})=\sqrt{\lambda_1}+\sqrt{\lambda_2}+\cdots+\sqrt{\lambda_d} \\[4pt] \nabla_{\boldsymbol{S}}\mathop{\mathrm{tr}}(\boldsymbol{S}^{1/2}) =&\, \frac{1}{2}\sum_{i=1}^d\frac{\nabla_{\boldsymbol{S}} \lambda_i}{\sqrt{\lambda_i}} = \frac{1}{2}\sum_{i=1}^d\frac{\boldsymbol{u}_i\boldsymbol{u}_i^{\top}}{\sqrt{\lambda_i}} = \frac{1}{2}\boldsymbol{U}\boldsymbol{\Lambda}^{-1/2}\boldsymbol{U}^{\top} = \frac{1}{2}\boldsymbol{S}^{-1/2} \end{aligned} where we used the eigenvalue derivative \nabla_{\boldsymbol{S}}\lambda_i = \boldsymbol{u}_i\boldsymbol{u}_i^{\top} (for distinct eigenvalues), which can be found in “Derivative of SVD”. The final result resembles \frac{d}{dx}\sqrt{x} = \frac{1}{2\sqrt{x}}, which looks intuitive but is not trivial: if \boldsymbol{S} is not a symmetric positive definite matrix, it generally does not hold. Finally, since d\boldsymbol{S} = \boldsymbol{\Sigma}_p^{1/2}\,d\boldsymbol{\Sigma}_q\,\boldsymbol{\Sigma}_p^{1/2}, the chain rule gives \nabla_{\boldsymbol{\Sigma}_q} \mathop{\mathrm{tr}}(\boldsymbol{S}^{1/2}) = \boldsymbol{\Sigma}_p^{1/2}[\nabla_{\boldsymbol{S}}\mathop{\mathrm{tr}}(\boldsymbol{S}^{1/2})] \boldsymbol{\Sigma}_p^{1/2} = \frac{1}{2}\boldsymbol{\Sigma}_p^{1/2}\boldsymbol{S}^{-1/2}\boldsymbol{\Sigma}_p^{1/2}
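The identity \nabla_{\boldsymbol{S}}\mathop{\mathrm{tr}}(\boldsymbol{S}^{1/2}) = \frac{1}{2}\boldsymbol{S}^{-1/2} can be checked numerically: take a central finite difference of \mathop{\mathrm{tr}}(\boldsymbol{S}^{1/2}) along a random symmetric direction \boldsymbol{E} and compare it with the Frobenius inner product \langle \frac{1}{2}\boldsymbol{S}^{-1/2}, \boldsymbol{E}\rangle_F. This is a minimal NumPy sketch; the test matrix, seed, and step size are arbitrary choices.

```python
import numpy as np


def trace_sqrt(S):
    """tr(S^{1/2}) = sum_i sqrt(lambda_i) for symmetric positive definite S."""
    return np.sqrt(np.linalg.eigvalsh(S)).sum()


def half_inv_sqrt(S):
    """Claimed gradient (1/2) S^{-1/2} = (1/2) U Lambda^{-1/2} U^T."""
    w, U = np.linalg.eigh(S)
    return 0.5 * (U / np.sqrt(w)) @ U.T


rng = np.random.default_rng(1)
d = 4
A = rng.standard_normal((d, d))
S = A @ A.T + d * np.eye(d)    # symmetric positive definite test matrix
E = rng.standard_normal((d, d))
E = (E + E.T) / 2              # random symmetric perturbation direction

h = 1e-6
fd = (trace_sqrt(S + h * E) - trace_sqrt(S - h * E)) / (2 * h)  # central difference
analytic = np.sum(half_inv_sqrt(S) * E)                          # <(1/2) S^{-1/2}, E>_F
print(fd, analytic)
```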
Putting everything together, we obtain \nabla_{\boldsymbol{\Sigma}_q}\mathcal{W}_2^2[p,q] = \boldsymbol{I} - \boldsymbol{\Sigma}_p^{1/2}(\boldsymbol{\Sigma}_p^{1/2}\boldsymbol{\Sigma}_q\boldsymbol{\Sigma}_p^{1/2})^{-1/2}\boldsymbol{\Sigma}_p^{1/2}\label{eq:Sigma-grad}
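Likewise, here is a small sketch of the full gradient, Eq. [eq:Sigma-grad], together with \nabla_{\boldsymbol{\mu}_q}\mathcal{F}. It checks the analytic \boldsymbol{\Sigma}_q-gradient against a finite difference of the Fréchet Distance itself. The helper names (sym_fn, fd_value, fd_grads) are hypothetical, and toy 5-dimensional statistics replace real encoder features.

```python
import numpy as np


def sym_fn(A, f):
    """Apply a scalar function f to the eigenvalues of a symmetric matrix A."""
    w, U = np.linalg.eigh(A)
    return (U * f(w)) @ U.T


def fd_value(mu_p, Sigma_p, mu_q, Sigma_q):
    """Frechet Distance, second line of Eq. [eq:w-p-q]."""
    sp = sym_fn(Sigma_p, np.sqrt)
    S = sp @ Sigma_q @ sp
    return (np.sum((mu_p - mu_q) ** 2) + np.trace(Sigma_p) + np.trace(Sigma_q)
            - 2.0 * np.sqrt(np.linalg.eigvalsh(S)).sum())


def fd_grads(mu_p, Sigma_p, mu_q, Sigma_q):
    """Analytic gradients 2(mu_q - mu_p) and I - Sigma_p^{1/2} S^{-1/2} Sigma_p^{1/2}."""
    sp = sym_fn(Sigma_p, np.sqrt)                 # fixed for the real data: can be precomputed
    S = sp @ Sigma_q @ sp
    S_inv_sqrt = sym_fn(S, lambda w: 1.0 / np.sqrt(w))
    g_mu = 2.0 * (mu_q - mu_p)
    g_Sigma = np.eye(len(mu_q)) - sp @ S_inv_sqrt @ sp
    return g_mu, g_Sigma


def random_spd(rng, d):
    A = rng.standard_normal((d, 2 * d))
    return A @ A.T / (2 * d) + 0.1 * np.eye(d)


rng = np.random.default_rng(0)
d = 5
mu_p, mu_q = rng.standard_normal(d), rng.standard_normal(d)
Sigma_p, Sigma_q = random_spd(rng, d), random_spd(rng, d)
g_mu, g_Sigma = fd_grads(mu_p, Sigma_p, mu_q, Sigma_q)

# Directional derivative along a symmetric perturbation E of Sigma_q.
E = rng.standard_normal((d, d))
E = (E + E.T) / 2
h = 1e-6
num = (fd_value(mu_p, Sigma_p, mu_q, Sigma_q + h * E)
       - fd_value(mu_p, Sigma_p, mu_q, Sigma_q - h * E)) / (2 * h)
print(num, np.sum(g_Sigma * E))
```

A useful consistency check: when \boldsymbol{\Sigma}_q = \boldsymbol{\Sigma}_p, we get \boldsymbol{S}=\boldsymbol{\Sigma}_p^2, and the gradient reduces to \boldsymbol{I} - \boldsymbol{\Sigma}_p^{1/2}\boldsymbol{\Sigma}_p^{-1}\boldsymbol{\Sigma}_p^{1/2} = \boldsymbol{0}, as expected at the minimum.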
Eq. [eq:Sigma-grad] may look complicated, but \boldsymbol{\Sigma}_p^{1/2} can be precomputed, so obtaining both the FID and its gradient only requires the square root and inverse square root of the symmetric positive definite matrix \boldsymbol{S} = \boldsymbol{\Sigma}_p^{1/2}\boldsymbol{\Sigma}_q\boldsymbol{\Sigma}_p^{1/2}. These can be computed with the eigh function, or with the Newton–Schulz iteration introduced in “Efficient Computation of Matrix Square Root and Inverse Square Root” and “Efficient Computation of Matrix r-th Root and Inverse r-th Root”.
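The sketch below contrasts the two routes. The eigh route is exact up to round-off. The iterative route is the classical coupled Newton–Schulz iteration with fixed coefficients, which uses only matrix multiplications and is therefore GPU-friendly. It is not the specific accelerated scheme from the cited posts. The Frobenius-norm scaling and the iteration count are assumptions suited to the well-conditioned toy matrix used here.

```python
import numpy as np


def sqrt_and_inv_sqrt_eigh(S):
    """S^{1/2} and S^{-1/2} of a symmetric positive definite S via eigh."""
    w, U = np.linalg.eigh(S)
    return (U * np.sqrt(w)) @ U.T, (U / np.sqrt(w)) @ U.T


def sqrt_and_inv_sqrt_newton_schulz(S, iters=30):
    """Classical coupled Newton-Schulz iteration (fixed coefficients), matmuls only."""
    d = S.shape[0]
    I = np.eye(d)
    c = np.linalg.norm(S)        # Frobenius norm >= largest eigenvalue, so eig(S/c) lies in (0, 1]
    Y, Z = S / c, I.copy()
    for _ in range(iters):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y, Z = Y @ T, T @ Z      # Y -> (S/c)^{1/2}, Z -> (S/c)^{-1/2}
    return np.sqrt(c) * Y, Z / np.sqrt(c)


rng = np.random.default_rng(0)
A = rng.standard_normal((6, 12))
S = A @ A.T / 12 + 0.5 * np.eye(6)   # well-conditioned SPD test matrix

R_eig, Rinv_eig = sqrt_and_inv_sqrt_eigh(S)
R_ns, Rinv_ns = sqrt_and_inv_sqrt_newton_schulz(S)
print(np.abs(R_eig - R_ns).max(), np.abs(Rinv_eig - Rinv_ns).max())
```

In float32, or when \boldsymbol{S} is ill-conditioned, this plain iteration converges slowly or loses accuracy, so more iterations or a small multiple of \boldsymbol{I} added to \boldsymbol{S} would likely be needed.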
Very Large Batch Size
Introduce the notation \begin{gathered} \boldsymbol{\mu}_p = \mathbb{E}[\boldsymbol{z}_p], \qquad \boldsymbol{V}_p = \mathbb{E}[\boldsymbol{z}_p \boldsymbol{z}_p^{\top}], \qquad \boldsymbol{z}_p = \phi(\boldsymbol{x}_p),\qquad \boldsymbol{x}_p\sim p \\[4pt] \boldsymbol{\mu}_q = \mathbb{E}[\boldsymbol{z}_q], \qquad \boldsymbol{V}_q = \mathbb{E}[\boldsymbol{z}_q \boldsymbol{z}_q^{\top}], \qquad \boldsymbol{z}_q = \phi(\boldsymbol{x}_q),\qquad \boldsymbol{x}_q\sim q \end{gathered} Then \boldsymbol{\Sigma}_p = \boldsymbol{V}_p - \boldsymbol{\mu}_p \boldsymbol{\mu}_p^{\top},\qquad\boldsymbol{\Sigma}_q = \boldsymbol{V}_q - \boldsymbol{\mu}_q \boldsymbol{\mu}_q^{\top} Note that \boldsymbol{z}=\phi(\boldsymbol{x}) typically has thousands of dimensions (InceptionV3 is 2048), so to obtain accurate estimates we usually need tens of thousands of samples. The real distribution is fixed; its \boldsymbol{\mu}_p,\boldsymbol{\Sigma}_p can be computed in advance without problems. However, the generated distribution changes in real time. If we always use tens of thousands of samples per step, it means the batch size must be tens of thousands, which is quite expensive in many scenarios.
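Here is a minimal sketch of these statistics for a batch of features \boldsymbol{Z}\in\mathbb{R}^{b\times d}, whose rows are \boldsymbol{z}=\phi(\boldsymbol{x}). Random features stand in for encoder outputs, and the helper names are illustrative. The sketch also shows why tracking the uncentered \boldsymbol{V} rather than \boldsymbol{\Sigma} is convenient: the \boldsymbol{\mu} and \boldsymbol{V} of equally sized sub-batches simply average to those of the full batch, whereas \boldsymbol{\Sigma} does not.

```python
import numpy as np


def batch_moments(Z):
    """Z: (b, d) array whose rows are features z = phi(x). Returns mu = E[z], V = E[z z^T]."""
    mu = Z.mean(axis=0)
    V = Z.T @ Z / Z.shape[0]
    return mu, V


def covariance_from_moments(mu, V):
    """Sigma = V - mu mu^T (the 1/b-normalized, i.e. biased, covariance)."""
    return V - np.outer(mu, mu)


rng = np.random.default_rng(2)
Z = rng.standard_normal((1000, 8)) @ rng.standard_normal((8, 8)) + 3.0  # stand-in features
mu, V = batch_moments(Z)
Sigma = covariance_from_moments(mu, V)
print(np.allclose(Sigma, np.cov(Z, rowvar=False, bias=True)))  # True

# mu and V of two equal halves average to those of the whole batch.
mu1, V1 = batch_moments(Z[:500])
mu2, V2 = batch_moments(Z[500:])
print(np.allclose((mu1 + mu2) / 2, mu), np.allclose((V1 + V2) / 2, V))  # True True
```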
On the other hand, the gradient formula [eq:Sigma-grad] also shows why a large batch is necessary. If the batch size b does not exceed the feature dimension d, the estimated \boldsymbol{\Sigma}_q has rank at most b-1 < d and is therefore singular, and so is \boldsymbol{S} = \boldsymbol{\Sigma}_p^{1/2}\boldsymbol{\Sigma}_q\boldsymbol{\Sigma}_p^{1/2}; computing (\boldsymbol{\Sigma}_p^{1/2}\boldsymbol{\Sigma}_q\boldsymbol{\Sigma}_p^{1/2})^{-1/2} then becomes impossible (one would face 0^{-1/2}). Therefore, FD as a loss imposes a lower bound on the training batch size, which is arguably its central practical difficulty. Given limited compute, we can only try to simulate the effect of a large batch with small batches, similar to the need for gradient accumulation in contrastive learning.
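The rank bound is easy to observe numerically. With d = 16 and batch sizes b = 8, 16, 64, the estimated covariance should have rank \min(b-1, d), so it stays singular until b exceeds d. Gaussian features are a stand-in, and the rank tolerance tol=1e-10 is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(3)
d = 16
ranks = {}
for b in (8, 16, 64):
    Z = rng.standard_normal((b, d))                  # stand-in batch of features
    mu, V = Z.mean(axis=0), Z.T @ Z / b
    Sigma = V - np.outer(mu, mu)                     # batch covariance estimate
    ranks[b] = int(np.linalg.matrix_rank(Sigma, tol=1e-10, hermitian=True))
    # A (numerically) zero smallest eigenvalue means S^{-1/2} is undefined.
    print(b, ranks[b], np.linalg.eigvalsh(Sigma).min())
print(ranks)
```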
Equivalent Loss
Suppose that a batch size of B yields sufficiently accurate \boldsymbol{\mu}_q,\boldsymbol{V}_q, but we can only run a small batch size b at a time, so we need k = B/b steps to simulate the large-batch effect. Denote the statistics from each step as \tilde{\boldsymbol{\mu}}_q^{(1)},\tilde{\boldsymbol{V}}_q^{(1)}, \tilde{\boldsymbol{\mu}}_q^{(2)},\tilde{\boldsymbol{V}}_q^{(2)}, …, \tilde{\boldsymbol{\mu}}_q^{(k)},\tilde{\boldsymbol{V}}_q^{(k)}; then \boldsymbol{\mu}_q = \frac{1}{k}\sum_{i=1}^k\tilde{\boldsymbol{\mu}}_q^{(i)},\qquad \boldsymbol{V}_q = \frac{1}{k}\sum_{i=1}^k\tilde{\boldsymbol{V}}_q^{(i)} We want an equivalent loss whose per-batch gradients sum exactly to the large-batch gradient. To this end, take the differential of \mathcal{F} in Eq. [eq:w-p-q], viewed as a function of \boldsymbol{\mu}_q,\boldsymbol{V}_q: \begin{aligned} d\mathcal{F}(\boldsymbol{\mu}_q,\boldsymbol{V}_q) =&\, \langle\nabla_{\boldsymbol{\mu}_q}\mathcal{F}, d\boldsymbol{\mu}_q \rangle + \langle\nabla_{\boldsymbol{V}_q}\mathcal{F}, d\boldsymbol{V}_q \rangle_F \\ =&\, \sum_{i=1}^k \left[\langle\nabla_{\boldsymbol{\mu}_q}\mathcal{F}, d\tilde{\boldsymbol{\mu}}_q^{(i)}/k \rangle + \langle\nabla_{\boldsymbol{V}_q}\mathcal{F}, d\tilde{\boldsymbol{V}}_q^{(i)}/k \rangle_F\right] \\ =&\, d\sum_{i=1}^k \mathcal{F}(\textcolor{skyblue}{[}\boldsymbol{\mu}_q - \tilde{\boldsymbol{\mu}}_q^{(i)}/k\textcolor{skyblue}{]_{\text{sg}}} + \tilde{\boldsymbol{\mu}}_q^{(i)}/k,\textcolor{skyblue}{[}\boldsymbol{V}_q - \tilde{\boldsymbol{V}}_q^{(i)}/k\textcolor{skyblue}{]_{\text{sg}}} + \tilde{\boldsymbol{V}}_q^{(i)}/k) \end{aligned} where \textcolor{skyblue}{[\cdot]_{\text{sg}}} is the stop-gradient operator. The last step holds because the arguments of each term equal \boldsymbol{\mu}_q,\boldsymbol{V}_q in value, while only \tilde{\boldsymbol{\mu}}_q^{(i)}/k,\tilde{\boldsymbol{V}}_q^{(i)}/k carry gradients. This equality means we can run forward passes on the small batches one after another to obtain \tilde{\boldsymbol{\mu}}_q^{(i)},\tilde{\boldsymbol{V}}_q^{(i)}, average them to get sufficiently accurate \boldsymbol{\mu}_q,\boldsymbol{V}_q, and then for each batch compute the loss \mathcal{F}_i = \mathcal{F}(\textcolor{skyblue}{[}\boldsymbol{\mu}_q - \tilde{\boldsymbol{\mu}}_q^{(i)}/k\textcolor{skyblue}{]_{\text{sg}}} + \tilde{\boldsymbol{\mu}}_q^{(i)}/k,\textcolor{skyblue}{[}\boldsymbol{V}_q - \tilde{\boldsymbol{V}}_q^{(i)}/k\textcolor{skyblue}{]_{\text{sg}}} + \tilde{\boldsymbol{V}}_q^{(i)}/k)\label{eq:Fi} and backpropagate normally. Accumulating these gradients then yields the gradient equivalent to batch size B. Alternatively, instead of accumulating the gradients, we can update the model after each small batch with a correspondingly smaller learning rate; the effect is similar.
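To see Eq. [eq:Fi] at work without an autodiff framework, the sketch below emulates the stop-gradient by freezing the bracketed terms as constants inside a closure. In JAX one would use jax.lax.stop_gradient for this, and in PyTorch .detach(). The sketch then compares finite-difference gradients of the per-batch losses \mathcal{F}_i with respect to each batch's features against those of the full-batch loss. The toy sizes d=4, b=16, k=4, the Gaussian stand-in features, and the helpers (moments, fd_from_moments, num_grad, make_batch_loss) are all illustrative assumptions.

```python
import numpy as np


def sym_sqrt(A):
    w, U = np.linalg.eigh(A)
    return (U * np.sqrt(np.clip(w, 0.0, None))) @ U.T


def moments(Z):
    """mu = E[z], V = E[z z^T] over the rows of Z."""
    return Z.mean(axis=0), Z.T @ Z / Z.shape[0]


def fd_from_moments(mu, V, mu_p, sp, tr_Sigma_p):
    """F(mu_q, V_q) with Sigma_q = V_q - mu_q mu_q^T; sp = Sigma_p^{1/2} is precomputed."""
    Sigma_q = V - np.outer(mu, mu)
    S = sp @ Sigma_q @ sp
    S = (S + S.T) / 2
    tr_sqrt = np.sqrt(np.clip(np.linalg.eigvalsh(S), 0.0, None)).sum()
    return np.sum((mu - mu_p) ** 2) + tr_Sigma_p + np.trace(Sigma_q) - 2.0 * tr_sqrt


def num_grad(f, Z, h=1e-6):
    """Central finite-difference gradient of a scalar function f w.r.t. every entry of Z."""
    G = np.zeros_like(Z)
    for idx in np.ndindex(*Z.shape):
        Zp, Zm = Z.copy(), Z.copy()
        Zp[idx] += h
        Zm[idx] -= h
        G[idx] = (f(Zp) - f(Zm)) / (2 * h)
    return G


rng = np.random.default_rng(4)
d, b, k = 4, 16, 4                                   # B = k * b = 64 > d
Z_real = rng.standard_normal((4096, d)) + 0.5        # stand-in real features
mu_p = Z_real.mean(axis=0)
Sigma_p = np.cov(Z_real, rowvar=False, bias=True)
sp, tr_Sigma_p = sym_sqrt(Sigma_p), np.trace(Sigma_p)

Zq = 1.3 * rng.standard_normal((k * b, d)) - 0.2     # stand-in generated features
batches = np.split(Zq, k)
mu_q, V_q = moments(Zq)                              # global statistics from k forward passes


def full_loss(Z):
    return fd_from_moments(*moments(Z), mu_p, sp, tr_Sigma_p)


def make_batch_loss(i):
    """F_i of Eq. [eq:Fi]: the bracketed [.]_sg terms are frozen as constants."""
    mu_t, V_t = moments(batches[i])
    c_mu, c_V = mu_q - mu_t / k, V_q - V_t / k

    def loss(Zi):
        mu_i, V_i = moments(Zi)
        return fd_from_moments(c_mu + mu_i / k, c_V + V_i / k, mu_p, sp, tr_Sigma_p)

    return loss


G_full = num_grad(full_loss, Zq)
G_eq = np.concatenate([num_grad(make_batch_loss(i), batches[i]) for i in range(k)])
print(np.abs(G_full - G_eq).max())   # only finite-difference noise remains
```

Each \mathcal{F}_i also has the same value as the full-batch loss, so logging any one of them reports the large-batch FD.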
Making Up with History
Although the above scheme is theoretically feasible, it requires k forward passes to compute an accurate \boldsymbol{\mu}_q,\boldsymbol{V}_q before the gradients for each step can be computed, which makes the whole process not very “smooth.” The bottleneck is that we must know the global \boldsymbol{\mu}_q,\boldsymbol{V}_q in order to compute an unbiased local gradient.
A natural idea is to ask whether we can approximate \boldsymbol{\mu}_q,\boldsymbol{V}_q instead. Since the learning rate is small and the parameters change slowly, \boldsymbol{\mu}_q,\boldsymbol{V}_q should also change slowly: after incorporating the current batch, the new \boldsymbol{\mu}_q,\boldsymbol{V}_q should be only a slight correction of the old ones. We can approximate this with an Exponential Moving Average (EMA): \boldsymbol{\mu}_q^{(t)} = \beta \boldsymbol{\mu}_q^{(t-1)} + (1-\beta) \tilde{\boldsymbol{\mu}}_q^{(t)},\qquad \boldsymbol{V}_q^{(t)} = \beta \boldsymbol{V}_q^{(t-1)} + (1-\beta) \tilde{\boldsymbol{V}}_q^{(t)} This maintains an averaging window of roughly \mathcal{O}(1/(1-\beta)) steps, effectively enlarging the batch size behind the statistics \boldsymbol{\mu}_q,\boldsymbol{V}_q by a factor of \mathcal{O}(1/(1-\beta)). In this way, at each step we can compute gradients from the loss \mathcal{F}_t = \mathcal{F}(\underbrace{\beta \textcolor{skyblue}{[}\boldsymbol{\mu}_q^{(t-1)}\textcolor{skyblue}{]_{\text{sg}}} + (1-\beta) \tilde{\boldsymbol{\mu}}_q^{(t)}}_{\boldsymbol{\mu}_q^{(t)}},\underbrace{\beta \textcolor{skyblue}{[}\boldsymbol{V}_q^{(t-1)}\textcolor{skyblue}{]_{\text{sg}}} + (1-\beta) \tilde{\boldsymbol{V}}_q^{(t)}}_{\boldsymbol{V}_q^{(t)}}) and update the model. The only extra cost is caching \boldsymbol{\mu}_q,\boldsymbol{V}_q, which is negligible. This trick of “making up for an insufficient batch with history” embodies the same streaming idea as streaming power iteration. The paper also discusses a queue-based approach: maintain a queue of k historical batches, incorporate the current batch according to Eq. [eq:Fi] to compute the gradient, and evict the oldest batch. This approach is relatively naive, takes much more memory than EMA, and empirically performs worse than EMA.
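Below is a minimal sketch of the EMA bookkeeping, built around a hypothetical EMAMoments helper. Its mix method forms the arguments of \mathcal{F}_t, where the cached term is a plain constant playing the role of [\cdot]_{\text{sg}}, and its commit method caches the result for the next step. The loss and backward pass are elided. To illustrate the enlarged effective batch, the sketch feeds stationary Gaussian batches with b=8 < d=16 (a real generator drifts slowly instead). Each batch covariance is singular, while the EMA covariance becomes full rank and close to the true one. Initializing the cache with the first batch's statistics is an assumption, since the text does not specify initialization.

```python
import numpy as np


class EMAMoments:
    """Hypothetical helper that caches the EMA statistics mu_q^{(t-1)}, V_q^{(t-1)}."""

    def __init__(self, beta):
        self.beta = beta
        self.mu = None   # cached values are plain constants, i.e. stop-gradient
        self.V = None

    def mix(self, mu_t, V_t):
        """Arguments of F_t: beta * [cache]_sg + (1 - beta) * current-batch statistics."""
        if self.mu is None:           # assumption: the first step uses the batch statistics as-is
            return mu_t, V_t
        return (self.beta * self.mu + (1 - self.beta) * mu_t,
                self.beta * self.V + (1 - self.beta) * V_t)

    def commit(self, mu, V):
        """Store mu_q^{(t)}, V_q^{(t)} for the next step (only d + d^2 numbers)."""
        self.mu, self.V = mu.copy(), V.copy()


rng = np.random.default_rng(5)
d, b, beta = 16, 8, 0.99                 # b < d: every single batch is rank-deficient
L = rng.standard_normal((d, d)) / np.sqrt(d)
true_Sigma = L @ L.T + 0.1 * np.eye(d)
chol = np.linalg.cholesky(true_Sigma)

ema = EMAMoments(beta)
for t in range(2000):
    Z = rng.standard_normal((b, d)) @ chol.T   # stand-in generated features ~ N(0, true_Sigma)
    mu_t, V_t = Z.mean(axis=0), Z.T @ Z / b
    mu, V = ema.mix(mu_t, V_t)
    # ... here one would evaluate F(mu, V), backpropagate through mu_t, V_t, and update ...
    ema.commit(mu, V)

Sigma_batch = V_t - np.outer(mu_t, mu_t)
Sigma_ema = V - np.outer(mu, mu)
print(np.linalg.matrix_rank(Sigma_batch, tol=1e-10, hermitian=True))  # at most b - 1
print(np.linalg.eigvalsh(Sigma_ema).min())                            # positive: full rank
print(np.linalg.norm(Sigma_ema - true_Sigma) / np.linalg.norm(true_Sigma))
```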
Experimental Results
The paper’s experiments mainly focus on post-training of generative models, aiming to improve the original single-step generation models or to fine-tune multi-step generation models into single-step ones by training with FD Loss. When using multiple different encoders to compute FD Loss, the paper applied a loss normalization technique to balance losses of different magnitudes: \mathcal{L} = \sum_i \frac{\mathcal{F}[\phi_i]}{\textcolor{skyblue}{[}\mathcal{F}[\phi_i]\textcolor{skyblue}{]_{\text{sg}}} + \epsilon} This technique was also discussed in “Multi-task Learning Ramblings (I): In the Name of Loss”.
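Because the denominator is frozen, each term contributes \nabla\mathcal{F}[\phi_i]/(\mathcal{F}[\phi_i]+\epsilon). This is exactly the gradient of \log(\mathcal{F}[\phi_i]+\epsilon), so rescaling any single loss leaves its contribution unchanged (up to \epsilon). The sketch below checks both facts on toy quadratic losses standing in for per-encoder FD losses. The centers, scales, and \epsilon are arbitrary, and the helper names are illustrative.

```python
import numpy as np

CENTERS = [np.array([1.0, 0.0]), np.array([-2.0, 3.0])]


def losses(x, scales):
    """Toy stand-ins for per-encoder losses F[phi_i]: quadratics with very different scales."""
    return [s * np.sum((x - c) ** 2) for s, c in zip(scales, CENTERS)]


def normalized_grad(x, scales, eps=1e-8):
    """Gradient of sum_i F_i / ([F_i]_sg + eps): the denominator is a constant,
    so each term contributes grad(F_i) / (F_i + eps)."""
    g = np.zeros_like(x)
    for s, c, F in zip(scales, CENTERS, losses(x, scales)):
        g += 2.0 * s * (x - c) / (F + eps)
    return g


def log_objective(x, scales, eps=1e-8):
    """sum_i log(F_i + eps), whose gradient coincides with normalized_grad."""
    return sum(np.log(F + eps) for F in losses(x, scales))


x = np.array([0.3, -0.7])
g_a = normalized_grad(x, [1e3, 1e-2])
g_b = normalized_grad(x, [1.0, 1.0])
print(g_a, g_b)  # nearly identical: the scales of the individual losses drop out
```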
The paper’s core achievement is pushing the FID of single-step generation to an entirely new level, surpassing all other single-step and multi-step generation models; the result appears to have reached the ceiling. Some of the plots are shown below.
Summary
This article analyzed, from a theoretical angle, the difficulties of using FID as a loss function for training generative models, and showed how the derivation itself points to techniques for overcoming them.
For reprint, please include the original article address: https://kexue.fm/archives/11738
For more detailed reprint matters, refer to: Science Space FAQ