English (unofficial) translations of posts at kexue.fm
Source

Generative Diffusion Models (Part 8): Optimal Diffusion Variance Estimation (II)

Translated by DeepSeek V4 Pro. Translations can be inaccurate, please refer to the original post for important stuff.

In the previous article "Generative Diffusion Models (Part 7): Optimal Diffusion Variance Estimation (I)", we introduced and derived the optimal variance estimation results for diffusion models in Analytic-DPM. It directly provides an analytical estimate of the optimal variance for a pre-trained generative diffusion model. Experiments showed that this estimation result can effectively improve the generation quality of diffusion models.

In this article, we continue to introduce the upgraded version of Analytic-DPM, from the same author team, titled "Estimating the Optimal Covariance with Imperfect Mean in Diffusion Probabilistic Models". It is referred to as "Extended-Analytic-DPM" in the official GitHub repository, and we will use this name below as well.

Review of Results

The previous article, building on DDIM, derived that the optimal variance for the DDIM generation process should be
\sigma_t^2 + \gamma_t^2\bar{\sigma}_t^2
where \bar{\sigma}_t^2 is the variance of the distribution p(\boldsymbol{x}_0|\boldsymbol{x}_t). This variance has the following estimate (taking the result from "Variance Estimation 2"):
\bar{\sigma}_t^2 = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\left(1 - \frac{1}{d}\mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\left[ \Vert\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\Vert^2\right]\right) \label{eq:basic}

In hindsight, the estimation logic is not particularly difficult. Assume that
\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) = \frac{1}{\bar{\alpha}_t}\left(\boldsymbol{x}_t - \bar{\beta}_t \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right) \label{eq:bar-mu}
predicts the mean vector of the distribution p(\boldsymbol{x}_0|\boldsymbol{x}_t) exactly. Then, by definition, the covariance is
\begin{aligned} \boldsymbol{\Sigma}(\boldsymbol{x}_t) =&\, \mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left(\boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\right)\left(\boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\right)^{\top}\right] \\ =&\, \frac{1}{\bar{\alpha}_t^2}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left(\boldsymbol{x}_t - \bar{\alpha}_t\boldsymbol{x}_0\right)\left(\boldsymbol{x}_t - \bar{\alpha}_t\boldsymbol{x}_0\right)^{\top}\right] - \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2} \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)^{\top} \\ \end{aligned} \label{eq:full-cov}
Averaging both sides over \boldsymbol{x}_t\sim p(\boldsymbol{x}_t) eliminates the dependence on \boldsymbol{x}_t. Since averaging with \mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)} is the same as averaging with \mathbb{E}_{\boldsymbol{x}_0\sim \tilde{p}(\boldsymbol{x}_0)}\mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t|\boldsymbol{x}_0)}, and \boldsymbol{x}_t - \bar{\alpha}_t\boldsymbol{x}_0 = \bar{\beta}_t\boldsymbol{\varepsilon} with \boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I}), the first term averages to \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\boldsymbol{I}, giving
\boldsymbol{\Sigma}_t = \mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}[\boldsymbol{\Sigma}(\boldsymbol{x}_t)] = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\left(\boldsymbol{I} - \mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\left[ \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)^{\top}\right]\right) \label{eq:uncond-var-2}
Finally, averaging the diagonal elements to obtain a scalar (equivalently, assuming the covariance is a multiple of the identity matrix), i.e., \bar{\sigma}_t^2 = \text{Tr}(\boldsymbol{\Sigma}_t)/d, yields the estimate [eq:basic].
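As a sanity check of [eq:basic], here is a minimal Monte Carlo sketch for the scalar case d = 1, assuming a hypothetical Gaussian toy data distribution x_0 ~ N(0, s^2). In that toy, p(x_0|x_t) is Gaussian, so both the perfect noise predictor E[ε|x_t] and the true Var(x_0|x_t) have closed forms. The function names and the values s = 2, ᾱ_t = 0.6 are illustrative assumptions, not part of Analytic-DPM.

```python
import math
import random


def analytic_dpm_sigma2(alpha_bar, beta_bar, eps_model, sample_x0, n=200_000, seed=0):
    """Monte Carlo estimate of eq:basic for d = 1:
    (beta_bar^2 / alpha_bar^2) * (1 - E_{x_t ~ p(x_t)}[eps_model(x_t)^2])."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        # Draw x_t ~ p(x_t) by sampling x_0 first, then x_t = alpha_bar*x_0 + beta_bar*eps.
        x_t = alpha_bar * sample_x0(rng) + beta_bar * rng.gauss(0.0, 1.0)
        total += eps_model(x_t) ** 2
    return beta_bar ** 2 / alpha_bar ** 2 * (1.0 - total / n)


# Hypothetical toy problem: x_0 ~ N(0, s^2), so p(x_0|x_t) is Gaussian and both
# the perfect noise predictor E[eps | x_t] and Var(x_0 | x_t) are known exactly.
s, alpha_bar = 2.0, 0.6
beta_bar = math.sqrt(1.0 - alpha_bar ** 2)        # alpha_bar^2 + beta_bar^2 = 1
var_xt = alpha_bar ** 2 * s ** 2 + beta_bar ** 2  # Var(x_t)


def perfect_eps(x_t):
    return beta_bar * x_t / var_xt                 # E[eps | x_t]


def sample_x0(rng):
    return rng.gauss(0.0, s)


estimate = analytic_dpm_sigma2(alpha_bar, beta_bar, perfect_eps, sample_x0)
exact = s ** 2 * beta_bar ** 2 / var_xt            # Var(x_0 | x_t), constant in this toy
print(f"eq:basic estimate: {estimate:.4f}, exact posterior variance: {exact:.4f}")
```

With a perfect ε_θ the two numbers should agree up to Monte Carlo error (the exact value here is 2.56/2.08 ≈ 1.2308), which is exactly the Perfect Mean situation the derivation assumes.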

How to Improve

Before formally introducing Extended-Analytic-DPM, we can first think about what room for improvement remains for Analytic-DPM.

Actually, with a little thought, several improvements suggest themselves. For example, Analytic-DPM takes the covariance matrix of the normal distribution approximating p(\boldsymbol{x}_0|\boldsymbol{x}_t) to be \bar{\sigma}_t^2\boldsymbol{I}, a diagonal matrix whose diagonal elements are all equal. A direct improvement is to allow the diagonal elements to differ, i.e., \text{diag}(\bar{\boldsymbol{\sigma}}_t^2). Here, multiplication of vectors is understood elementwise (the Hadamard product), e.g., \boldsymbol{x}^2 = \boldsymbol{x} \odot \boldsymbol{x}. The corresponding estimate uses only the diagonal of \boldsymbol{\Sigma}_t; starting from equation [eq:uncond-var-2], it is
\bar{\boldsymbol{\sigma}}_t^2 = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\left(\boldsymbol{1}_d - \mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\left[ \boldsymbol{\epsilon}_{\boldsymbol{\theta}}^2(\boldsymbol{x}_t, t)\right]\right)
where \boldsymbol{1}_d is the d-dimensional all-ones vector. A further improvement is to retain the dependence of \bar{\boldsymbol{\sigma}}_t^2 on \boldsymbol{x}_t, i.e., to consider \bar{\boldsymbol{\sigma}}_t^2(\boldsymbol{x}_t). Like \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t), this would have to be learned by a model that takes \boldsymbol{x}_t as input.
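The diagonal estimate can be sketched the same way. The assumption here is a hypothetical 3-dimensional toy with independent Gaussian coordinates of different scales, so each coordinate has its own exact posterior variance. The per-dimension estimate recovers each of them, while the scalar Tr(Σ_t)/d from [eq:basic] only reports their average.

```python
import math
import random


def diagonal_sigma2(alpha_bar, beta_bar, eps_model, sample_x0, d, n=100_000, seed=0):
    """Per-dimension estimate (beta_bar^2/alpha_bar^2) * (1_d - E[eps_theta(x_t)^2]),
    with the square taken elementwise (Hadamard product)."""
    rng = random.Random(seed)
    sq_sums = [0.0] * d
    for _ in range(n):
        x0 = sample_x0(rng)
        x_t = [alpha_bar * a + beta_bar * rng.gauss(0.0, 1.0) for a in x0]
        for i, e in enumerate(eps_model(x_t)):
            sq_sums[i] += e * e
    scale = beta_bar ** 2 / alpha_bar ** 2
    return [scale * (1.0 - q / n) for q in sq_sums]


scales = [0.5, 1.0, 2.0]   # hypothetical per-dimension data standard deviations
alpha_bar = 0.6
beta_bar = math.sqrt(1.0 - alpha_bar ** 2)
var_xt = [alpha_bar ** 2 * s ** 2 + beta_bar ** 2 for s in scales]


def perfect_eps(x_t):
    # Exact E[eps | x_t], computed coordinate by coordinate.
    return [beta_bar * x / v for x, v in zip(x_t, var_xt)]


def sample_x0(rng):
    return [rng.gauss(0.0, s) for s in scales]


diag = diagonal_sigma2(alpha_bar, beta_bar, perfect_eps, sample_x0, d=len(scales))
exact = [s ** 2 * beta_bar ** 2 / v for s, v in zip(scales, var_xt)]
scalar = sum(diag) / len(diag)   # Tr(Sigma_t)/d, i.e. the eq:basic value
print("diagonal:", [round(v, 3) for v in diag])
print("exact:   ", [round(v, 3) for v in exact])
print("scalar:  ", round(scalar, 3))
```

The scalar value sits between the smallest and largest per-dimension variances, which is the information the diagonal version keeps and the isotropic version throws away.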

Can we consider the full \boldsymbol{\Sigma}_t? In theory yes, but in practice it is almost infeasible, because the full \boldsymbol{\Sigma}_t is a d \times d matrix, and for images d is the total number of pixel values (height × width × channels). Even for CIFAR-10, d = 32^2 \times 3 = 3072, let alone higher-resolution images. The storage and computational costs of a d \times d matrix are therefore prohibitive in the experimental settings of interest.
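To make the cost concrete, here is a small calculation of the memory needed for one dense Σ_t, assuming float32 storage (the post does not fix a precision) and a hypothetical 256×256 RGB image as the "higher resolution" case. One such matrix would be needed per timestep, and sampling from a Gaussian with a dense covariance generally also requires a factorization such as Cholesky, which costs O(d³).

```python
def full_cov_bytes(height, width, channels, bytes_per_entry=4):
    """Return (d, bytes) for one dense d x d covariance with d = height*width*channels.
    bytes_per_entry=4 assumes float32 storage."""
    d = height * width * channels
    return d, d * d * bytes_per_entry


# CIFAR-10 as in the text; 256x256 RGB is a hypothetical higher-resolution example.
for name, shape in [("CIFAR-10", (32, 32, 3)), ("256x256 RGB", (256, 256, 3))]:
    d, nbytes = full_cov_bytes(*shape)
    print(f"{name}: d = {d}, one dense Sigma_t = {nbytes / 2**20:,.0f} MiB")
```

This gives 36 MiB per timestep for CIFAR-10 and 147,456 MiB (144 GiB) per timestep for 256×256 RGB, since the cost grows with the square of the pixel count.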

Besides this, there is a problem that many readers may not have noticed: all the analytical derivations above rely on \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) = \mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}[\boldsymbol{x}_0]. In fact, \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) is learned by a model and need not equal the true mean \mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}[\boldsymbol{x}_0] exactly. This is what "Imperfect Mean" in the title of the Extended-Analytic-DPM paper refers to, and improving the estimate under an imperfect mean is of greater practical significance.

Maximum Likelihood

Assuming the mean model \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) has been pre-trained, the only remaining parameter of the distribution \mathcal{N}(\boldsymbol{x}_0;\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t),\bar{\sigma}_t^2\boldsymbol{I}) is \bar{\sigma}_t^2. The corresponding negative log-likelihood is
\begin{aligned} &\, \mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[-\log \mathcal{N}(\boldsymbol{x}_0;\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t),\bar{\sigma}_t^2\boldsymbol{I})\right] \\ =&\, \frac{\mathbb{E}_{\boldsymbol{x}_t,\boldsymbol{x}_0\sim p(\boldsymbol{x}_t|\boldsymbol{x}_0)\tilde{p}(\boldsymbol{x}_0)}\left[\Vert\boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\Vert^2\right]}{2\bar{\sigma}_t^2} + \frac{d}{2}\log \bar{\sigma}_t^2 + \frac{d}{2}\log 2\pi \\ \end{aligned} \label{eq:neg-log}
Setting its derivative with respect to \bar{\sigma}_t^2 to zero shows that it is minimized exactly at
\bar{\sigma}_t^2 = \frac{1}{d}\mathbb{E}_{\boldsymbol{x}_t,\boldsymbol{x}_0\sim p(\boldsymbol{x}_t|\boldsymbol{x}_0)\tilde{p}(\boldsymbol{x}_0)}\left[\Vert\boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\Vert^2\right]
The key point here is that \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) is not necessarily the exact mean, so the second equality in equation [eq:full-cov] no longer holds; only the first does. Substituting equation [eq:bar-mu] and writing \boldsymbol{x}_t = \bar{\alpha}_t\boldsymbol{x}_0 + \bar{\beta}_t\boldsymbol{\varepsilon}, we get
\bar{\sigma}_t^2 = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2 d}\mathbb{E}_{\boldsymbol{x}_0\sim \tilde{p}(\boldsymbol{x}_0),\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I})}\left[\left\Vert\boldsymbol{\varepsilon} - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bar{\alpha}_t\boldsymbol{x}_0 + \bar{\beta}_t\boldsymbol{\varepsilon}, t)\right\Vert^2\right]
Of course, this only covers the simple case where the covariance matrix is \bar{\sigma}_t^2\boldsymbol{I}. We can also consider a more general diagonal covariance, i.e., \mathcal{N}(\boldsymbol{x}_0;\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t),\text{diag}(\bar{\boldsymbol{\sigma}}_t^2)), which gives
\bar{\boldsymbol{\sigma}}_t^2 = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2 }\mathbb{E}_{\boldsymbol{x}_0\sim \tilde{p}(\boldsymbol{x}_0),\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I})}\left[\left(\boldsymbol{\varepsilon} - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bar{\alpha}_t\boldsymbol{x}_0 + \bar{\beta}_t\boldsymbol{\varepsilon}, t)\right)^2\right]
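A minimal sketch of why the imperfect-mean estimate matters, again on a hypothetical 1-D Gaussian toy. The noise predictor is deliberately miscalibrated by a hypothetical factor k = 1.2 (it over-predicts the noise by 20%). The maximum-likelihood estimate (β̄_t²/ᾱ_t²)E[(ε − ε_θ)²] then equals E[(x_0 − μ̄(x_t))²], i.e. the true posterior variance plus the squared bias of μ̄, whereas [eq:basic], which silently assumes a perfect mean, lands below the true posterior variance in this setup. For this linear toy all three quantities have closed forms, which the sketch also computes.

```python
import math
import random


def ml_sigma2(alpha_bar, beta_bar, eps_model, sample_x0, n=200_000, seed=0):
    """Imperfect-mean ML estimate for d = 1:
    (beta_bar^2/alpha_bar^2) * E[(eps - eps_theta(alpha_bar*x0 + beta_bar*eps))^2]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x0, eps = sample_x0(rng), rng.gauss(0.0, 1.0)
        total += (eps - eps_model(alpha_bar * x0 + beta_bar * eps)) ** 2
    return beta_bar ** 2 / alpha_bar ** 2 * total / n


def perfect_mean_sigma2(alpha_bar, beta_bar, eps_model, sample_x0, n=200_000, seed=0):
    """eq:basic for d = 1; only valid if eps_model is the exact E[eps | x_t]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x_t = alpha_bar * sample_x0(rng) + beta_bar * rng.gauss(0.0, 1.0)
        total += eps_model(x_t) ** 2
    return beta_bar ** 2 / alpha_bar ** 2 * (1.0 - total / n)


s, alpha_bar = 2.0, 0.6                       # hypothetical toy: x_0 ~ N(0, s^2)
beta_bar = math.sqrt(1.0 - alpha_bar ** 2)
var_xt = alpha_bar ** 2 * s ** 2 + beta_bar ** 2
k = 1.2                                       # hypothetical miscalibration factor


def imperfect_eps(x_t):
    return k * beta_bar * x_t / var_xt        # k times the exact E[eps | x_t]


def sample_x0(rng):
    return rng.gauss(0.0, s)


ml = ml_sigma2(alpha_bar, beta_bar, imperfect_eps, sample_x0)
naive = perfect_mean_sigma2(alpha_bar, beta_bar, imperfect_eps, sample_x0)

# Closed forms for this linear toy model.
scale = beta_bar ** 2 / alpha_bar ** 2
ml_exact = scale * (1 - (2 * k - k * k) * beta_bar ** 2 / var_xt)  # E[(x_0 - mu_bar)^2]
naive_exact = scale * (1 - k * k * beta_bar ** 2 / var_xt)
true_post = scale * (1 - beta_bar ** 2 / var_xt)                    # Var(x_0 | x_t)
print(f"ML (imperfect mean): {ml:.4f} (closed form {ml_exact:.4f})")
print(f"eq:basic:            {naive:.4f} (closed form {naive_exact:.4f})")
print(f"true Var(x_0|x_t):   {true_post:.4f}")
```

The ML value is what the negative log-likelihood [eq:neg-log] actually prefers when μ̄ is biased, which is the quantity Extended-Analytic-DPM targets.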

Conditional Variance

If we want the covariance \text{diag}(\bar{\boldsymbol{\sigma}}_t^2(\boldsymbol{x}_t)) conditioned on \boldsymbol{x}_t, this amounts to carrying out the same calculation separately for each \boldsymbol{x}_t, so the result is obtained by simply dropping the averaging over \boldsymbol{x}_t\sim p(\boldsymbol{x}_t):
\bar{\boldsymbol{\sigma}}_t^2(\boldsymbol{x}_t) = \mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[(\boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t))^2\right] = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left(\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)^2\right]
where \boldsymbol{\epsilon}_t = \frac{\boldsymbol{x}_t - \bar{\alpha}_t \boldsymbol{x}_0}{\bar{\beta}_t}. As in the previous article, using
\mathbb{E}_{\boldsymbol{x}}[\boldsymbol{x}] = \mathop{\text{argmin}}_{\boldsymbol{\mu}}\mathbb{E}_{\boldsymbol{x}}\left[\Vert \boldsymbol{x} - \boldsymbol{\mu}\Vert^2\right] \label{eq:mean-opt}
we get
\begin{aligned} \bar{\boldsymbol{\sigma}}_t^2(\boldsymbol{x}_t) =&\, \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left(\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)^2\right] \\ =&\, \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathop{\text{argmin}}_{\boldsymbol{g}}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left\Vert\left(\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)^2-\boldsymbol{g}\right\Vert^2\right] \\ =&\, \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathop{\text{argmin}}_{\boldsymbol{g}(\boldsymbol{x}_t)}\mathbb{E}_{\boldsymbol{x}_t\sim p(\boldsymbol{x}_t)}\mathbb{E}_{\boldsymbol{x}_0\sim p(\boldsymbol{x}_0|\boldsymbol{x}_t)}\left[\left\Vert\left(\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)^2-\boldsymbol{g}(\boldsymbol{x}_t)\right\Vert^2\right] \\ =&\, \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\mathop{\text{argmin}}_{\boldsymbol{g}(\boldsymbol{x}_t)}\mathbb{E}_{\boldsymbol{x}_t,\boldsymbol{x}_0\sim p(\boldsymbol{x}_t|\boldsymbol{x}_0)\tilde{p}(\boldsymbol{x}_0)}\left[\left\Vert\left(\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right)^2-\boldsymbol{g}(\boldsymbol{x}_t)\right\Vert^2\right] \\ \end{aligned}
The third step holds because, once \boldsymbol{g} may depend on \boldsymbol{x}_t, the averaged objective decouples and is minimized separately at each \boldsymbol{x}_t; the fourth simply rewrites the joint expectation by sampling \boldsymbol{x}_0 first. This is the "NPR-DPM" scheme for learning the conditional variance in Extended-Analytic-DPM. The original paper also proposes an "SN-DPM" scheme, which is based on the Perfect Mean assumption rather than the Imperfect Mean. However, the experimental results in the paper show that SN-DPM outperforms NPR-DPM. In other words, although the paper sets out to solve the Imperfect Mean problem, its experiments show that the scheme assuming a Perfect Mean works better. This suggests that the Perfect Mean assumption is actually quite close to what happens in practice; put another way, the Imperfect Mean problem can essentially be treated as non-existent.
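The NPR-DPM regression can be illustrated without a neural network. If g is restricted to piecewise-constant functions of x_t (a histogram regressor, used here as a stand-in for the paper's learned network), then by [eq:mean-opt] the MSE minimizer in each bin is just the mean of the targets (ε_t − ε_θ(x_t))². The sketch assumes a hypothetical two-point data distribution x_0 ∈ {−1, +1}, for which E[x_0|x_t] = tanh(ᾱ_t x_t/β̄_t²) and Var(x_0|x_t) = 1 − tanh²(ᾱ_t x_t/β̄_t²) genuinely depend on x_t. ε_θ is taken to be perfect here only so the fit can be checked against the exact conditional variance; with an imperfect ε_θ the same regression targets E[(x_0 − μ̄(x_t))²|x_t] instead.

```python
import math
import random


def npr_binned_sigma2(alpha_bar, beta_bar, eps_model, sample_x0,
                      width=0.2, half_bins=15, n=300_000, seed=0):
    """Fit a piecewise-constant g(x_t) to the targets (eps - eps_theta(x_t))^2.
    The MSE minimizer over such functions is the per-bin mean of the targets;
    scaling by beta_bar^2/alpha_bar^2 turns it into an estimate of sigma_bar_t^2(x_t)."""
    rng = random.Random(seed)
    bins = 2 * half_bins + 1
    sums, counts = [0.0] * bins, [0] * bins
    for _ in range(n):
        x0, eps = sample_x0(rng), rng.gauss(0.0, 1.0)
        x_t = alpha_bar * x0 + beta_bar * eps
        b = math.floor(x_t / width + 0.5) + half_bins   # bin centred at (b - half_bins)*width
        if 0 <= b < bins:
            sums[b] += (eps - eps_model(x_t)) ** 2
            counts[b] += 1
    scale = beta_bar ** 2 / alpha_bar ** 2
    centers = [(b - half_bins) * width for b in range(bins)]
    g = [scale * sums[b] / counts[b] if counts[b] else float("nan") for b in range(bins)]
    return centers, g


alpha_bar = 0.6
beta_bar = math.sqrt(1.0 - alpha_bar ** 2)


def sample_x0(rng):
    return rng.choice((-1.0, 1.0))            # hypothetical two-point data distribution


def posterior_mean_x0(x_t):
    return math.tanh(alpha_bar * x_t / beta_bar ** 2)              # exact E[x_0 | x_t]


def perfect_eps(x_t):
    return (x_t - alpha_bar * posterior_mean_x0(x_t)) / beta_bar   # exact E[eps | x_t]


centers, g = npr_binned_sigma2(alpha_bar, beta_bar, perfect_eps, sample_x0)
for c, v in zip(centers[::5], g[::5]):
    print(f"x_t = {c:+.1f}: fitted {v:.3f}, exact {1 - posterior_mean_x0(c) ** 2:.3f}")
```

The fitted variance is close to 1 near x_t = 0, where the two data points are hard to tell apart, and drops sharply further out, which is exactly the x_t-dependence a single scalar \bar{\sigma}_t^2 cannot express.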

Two Stages

Readers might wonder: didn’t we say earlier that the learnable variance in "Improved Denoising Diffusion Probabilistic Models" increased training difficulty? Why does Extended-Analytic-DPM go back to making a trainable variance model?

We know that DDPM provides two variance schemes: \sigma_t = \frac{\bar{\beta}_{t-1}}{\bar{\beta}_t}\beta_t and \sigma_t = \beta_t. Both simple schemes actually perform quite well, which indirectly shows that fine adjustments to the variance have little impact on the generation results (at least when all T diffusion steps are used); the main factor is how well \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) is learned, while the variance is just "the icing on the cake." If the variance is treated as a learnable parameter or model and trained jointly with the mean model \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t), its changes during training can seriously interfere with the learning of the mean model, violating the principle of "mean model first, variance second."
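One way to see why the two DDPM choices behave so similarly is to compute their ratio \bar{\beta}_{t-1}/\bar{\beta}_t along a schedule. The sketch below assumes this series' convention \alpha_t^2 + \beta_t^2 = 1, \bar{\alpha}_t = \alpha_1\cdots\alpha_t, \bar{\beta}_t^2 = 1 - \bar{\alpha}_t^2, and the linear schedule of the original DDPM paper (T = 1000, with \beta_t^2 rising linearly from 10^{-4} to 0.02); the function name is hypothetical.

```python
import math


def ddpm_sigma_ratio(T=1000, b_start=1e-4, b_end=0.02):
    """Return r_t = bar_beta_{t-1} / bar_beta_t for t = 1..T, i.e. the ratio of the
    choice sigma_t = (bar_beta_{t-1}/bar_beta_t) * beta_t to the choice sigma_t = beta_t.
    Assumes beta_t^2 rises linearly from b_start to b_end (DDPM's linear schedule)."""
    ratios = []
    alpha_bar_sq = 1.0                              # bar_alpha_0^2
    for t in range(1, T + 1):
        beta_sq = b_start + (b_end - b_start) * (t - 1) / (T - 1)
        prev_bar_beta_sq = 1.0 - alpha_bar_sq       # bar_beta_{t-1}^2
        alpha_bar_sq *= 1.0 - beta_sq               # bar_alpha_t^2 = prod_s (1 - beta_s^2)
        ratios.append(math.sqrt(prev_bar_beta_sq / (1.0 - alpha_bar_sq)))
    return ratios


r = ddpm_sigma_ratio()
for t in (1, 2, 10, 100, 1000):
    print(f"t = {t:4d}: ratio = {r[t - 1]:.4f}")
```

Under these assumptions the two standard deviations differ substantially only in the first few steps (the ratio is 0 at t = 1 and about 0.67 at t = 2) and are within about 1% of each other by t = 100, which is consistent with both schemes working well when all T steps are used.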

The cleverness of Extended-Analytic-DPM lies in its two-stage training scheme: first, train the mean model \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) with the original fixed variance; then freeze this model and reuse most of its parameters to learn a variance model. This kills three birds with one stone:

1. Reduces the number of parameters and training costs;

2. Allows the reuse of already trained mean models;

3. Makes the training process more stable.

Personal Thoughts

At this point, the introduction to Extended-Analytic-DPM is basically complete. Attentive readers might feel that if the results of the previous Analytic-DPM were "stunning," then Extended-Analytic-DPM seems quite ordinary, with nothing particularly soul-stirring. It can be said that Extended-Analytic-DPM is a trivial generalization of Analytic-DPM. Although experimental results show it still brings decent improvements, the overall feeling is quite flat. This is largely because Analytic-DPM was such a "gem" that this one seems a bit dim by comparison, though it is a solid piece of work in its own right.

Furthermore, as mentioned earlier, the experiments show that SN-DPM, based on the Perfect Mean assumption, performs better than NPR-DPM, which is based on the Imperfect Mean assumption. This makes the original paper's title somewhat "mismatched" with its findings: if the Perfect Mean scheme is better, the Imperfect Mean problem can apparently be ignored. The original paper does not analyze or discuss this result further. Could it be related to bias in variance estimation? As we know, estimating a variance with the "divide by n" formula is biased, and NPR-DPM is based on this kind of operation. In contrast, SN-DPM estimates the second moment directly, and the sample second moment is an unbiased estimator. This seems to make some sense, but it does not fully explain things. It remains a bit of a mystery.
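The statistical fact invoked above is easy to check numerically. This sketch only illustrates the textbook bias of the "divide by n" variance estimator versus the sample second moment on i.i.d. N(0, 1) draws (sample size n = 5 is an arbitrary choice); it does not reproduce NPR-DPM or SN-DPM, whose estimators are learned regressions, so it supports the intuition rather than settling the question.

```python
import random


def mean_of_estimates(n, trials=100_000, seed=0):
    """Average, over many trials, of two estimators computed from n i.i.d. N(0, 1) draws:
    the 'divide by n' variance (centred on the sample mean) and the raw second moment."""
    rng = random.Random(seed)
    var_n_total = second_moment_total = 0.0
    for _ in range(trials):
        xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
        m = sum(xs) / n
        var_n_total += sum((x - m) ** 2 for x in xs) / n   # biased: E = (n-1)/n * sigma^2
        second_moment_total += sum(x * x for x in xs) / n  # unbiased for E[x^2] = 1
    return var_n_total / trials, second_moment_total / trials


biased, unbiased = mean_of_estimates(n=5)
print(f"divide-by-n variance: {biased:.3f} (theory (n-1)/n = 0.8)")
print(f"second moment:        {unbiased:.3f} (theory 1.0)")
```

The "divide by n" estimator is systematically low by the factor (n − 1)/n because it is centred on an estimated mean, while the second-moment estimator has no such bias.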

Finally, I wonder if readers have the same question as I do: given \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t), why not directly use the negative log-likelihood like equation [eq:neg-log] as the loss function to learn the variance, instead of redesigning the NPR-DPM or SN-DPM losses in MSE form? Is there any special benefit to the MSE form loss? I haven’t thought of an answer yet.

Summary

This article introduced the optimal diffusion variance estimation results from "Extended-Analytic-DPM," the upgraded version of Analytic-DPM. It primarily focuses on derivations for the imperfect-mean case and proposes a learning scheme for the conditional variance.

Reprinting: Please include the original address of this article: https://kexue.fm/archives/9246

Further details on reprinting: Please refer to "Scientific Space FAQ".