Multistep Distillation of Diffusion Models via Moment Matching

Tim Salimans, Thomas Mensink, Jonathan Heek, Emiel Hoogeboom

Introduction

Diffusion models (Ho et al., 2020; Song & Ermon, 2019; Sohl-Dickstein et al., 2015) have recently become the state-of-the-art model class for generating images, video, audio, and other modalities. By casting the generation of high dimensional outputs as an iterative denoising process, these models have made the problem of learning to synthesize complex outputs tractable. Although this decomposition simplifies the training objective compared to alternatives like GANs, it shifts the computational burden to inference: Sampling from diffusion models usually requires hundreds of neural network evaluations, making these models expensive to use in applications.

To reduce the cost of inference, recent work has moved towards distilling diffusion models into generators that are faster to sample. The methods proposed so far can be subdivided into two classes: deterministic methods that aim to directly approximate the output of the iterative denoising process in fewer steps, and distributional methods that try to generate output with approximately the same distribution as learned by the diffusion model. Here we propose a new method for distilling diffusion models of the second type: We cast the problem of distribution matching in terms of matching conditional expectations of the clean data given the noisy data along the sampling trajectory of the diffusion process. The proposed method is closely related to previous approaches applying score matching with an auxiliary model to distilled one-step generators, but the moment matching perspective allows us to generalize these methods to the few-step setting where we obtain large improvements in output quality, even outperforming the many-step base models our distilled generators are learned from. Finally, the moment matching perspective allows us to also propose a second variant of our algorithm that eliminates the need for the auxiliary model in exchange for processing two independent minibatches per parameter update.

Background

2.2 Generalized method of moments

Moment Matching Distillation

In words: The conditional expectation of clean data should be identical between the data distribution $q$ and the sampling distribution $g$ of the distilled model.
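As a toy illustration of this condition (not taken from the paper), consider scalar Gaussian clean data $x \sim N(0, v)$ and a noisy observation $z = \alpha x + \sigma \epsilon$. The conditional expectation is then available in closed form, $E[x \mid z] = \alpha v z / (\alpha^2 v + \sigma^2)$, so it is easy to check that a generator with the wrong variance violates the moment condition. The names `posterior_mean` and `moment_gap`, and the values of `alpha`, `sigma`, and `z_values`, are assumptions made for this sketch.

```python
def posterior_mean(z, alpha, sigma, var_x):
    """E[x | z] for x ~ N(0, var_x) and z = alpha * x + sigma * eps, eps ~ N(0, 1)."""
    return alpha * var_x * z / (alpha ** 2 * var_x + sigma ** 2)


# Assumed noise level for the illustration (alpha^2 + sigma^2 = 1).
alpha, sigma = 0.8, 0.6
z_values = [-2.0, -0.5, 0.0, 1.0, 2.5]


def moment_gap(var_data, var_gen):
    """Largest absolute gap between the two conditional expectations over z_values."""
    return max(
        abs(posterior_mean(z, alpha, sigma, var_data)
            - posterior_mean(z, alpha, sigma, var_gen))
        for z in z_values
    )


gap_matched = moment_gap(1.0, 1.0)      # generator variance equals data variance
gap_mismatched = moment_gap(1.0, 0.25)  # generator produces too little variance
```

In this toy setting the gap vanishes only when the generator variance equals the data variance, which is the sense in which the conditional moments pin down the distribution.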

Equation 3 gives us a set of moment conditions that uniquely identifies the target distribution, similar to how the regular diffusion training loss identifies the data distribution (Song et al., 2021b). These moment conditions can be used as the basis of a distillation method to finetune $g_{\eta}(\mathbf{z}_{t},t)$ from the denoising model $g_{\theta}$. In particular, we can fit $g_{\eta}$ to $q$ by minimizing the L2-distance between these moments:

The resulting algorithm resembles the alternating optimization of a GAN (Goodfellow et al., 2020), and like a GAN is generally not guaranteed to converge. In practice, we find that Algorithm 2 is stable for the right choice of hyperparameters, especially when taking $k\geq 8$ sampling steps. The algorithm also closely resembles Variational Score Distillation as previously used for distilling 1-step generators $g_{\eta}$ in Diff-Instruct. We discuss this relationship in Section 4.

3.2 Parameter-space moment matching

Alternating optimization of the moment matching objective (Algorithm 2) is difficult to analyze theoretically, and the requirement to keep track of two different models adds engineering complexity. We therefore also experiment with an instantaneous version of the auxiliary denoising model $g_{\phi^{*}}$, where $\phi^{*}$ is determined using a single infinitesimal gradient descent step on $L(\phi)$ (defined in Algorithm 2), evaluated on a single minibatch. Starting from teacher parameters $\theta$, and preconditioning the loss gradient with a pre-determined scaling matrix $\Lambda$, we can define:

Now we use $\phi(\lambda)$ in calculating $L(\eta)$ from Algorithm 2, take the first-order Taylor expansion for $g_{\phi(\lambda)}(\mathbf{z}_{s})-g_{\theta}(\mathbf{z}_{s})\approx\lambda\frac{\partial g_{\theta}(\mathbf{z}_{s})}{\partial\theta}(\phi(\lambda)-\theta)=\lambda\frac{\partial g_{\theta}(\mathbf{z}_{s})}{\partial\theta}\Lambda\nabla_{\phi}L(\phi)|_{\phi=\theta}$, and scale the loss with the inverse of $\lambda$ to get:

The instantaneous version of our moment matching loss can thus be interpreted as trying to match teacher gradients between the training data and generated data. This makes it a special case of the Efficient Method of Moments (Gallant & Tauchen, 1996), a classic method in statistics where a teacher model $p_{\theta}$ is first estimated using maximum likelihood, after which its gradient is used to define a moment matching loss for learning a second model $g_{\eta}$. Under certain conditions, the second model then attains the statistical efficiency of the maximum likelihood teacher model. The difference between our version of this method and that proposed by Gallant & Tauchen (1996) is that in our case the loss of the teacher model is a weighted denoising loss, rather than the log-likelihood of the data.
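The gradient-matching interpretation can be made concrete with a small numerical sketch. It is a toy under stated assumptions, not the loss used in the paper: the teacher is a scalar linear denoiser $g_{\theta}(z)=\theta z$ trained with a plain squared denoising loss, there is no weighting or preconditioning, and the "generator" is simply a Gaussian with the wrong standard deviation. For a teacher at its optimum, the expected loss gradient on training data is zero, so matching gradients amounts to driving the expected teacher gradient on generated data to zero. The sketch also shows the standard statistical fact that the product of two independent minibatch estimates is an unbiased estimate of a squared expectation. All names (`teacher_grad`, `mean_teacher_grad`, `ALPHA`, `SIGMA`) are hypothetical.

```python
import random

ALPHA, SIGMA = 0.8, 0.6                      # assumed noise level
THETA = ALPHA / (ALPHA ** 2 + SIGMA ** 2)    # optimal linear teacher for unit-variance data


def teacher_grad(x, z, theta=THETA):
    """Gradient w.r.t. theta of the toy denoising loss 0.5 * (x - theta * z)^2."""
    return -(x - theta * z) * z


def mean_teacher_grad(std_x, n, rng):
    """Monte Carlo estimate of the expected teacher gradient for x ~ N(0, std_x^2)."""
    total = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, std_x)
        z = ALPHA * x + SIGMA * rng.gauss(0.0, 1.0)
        total += teacher_grad(x, z)
    return total / n


rng = random.Random(0)

# On the training distribution the optimal teacher has (near) zero expected gradient.
grad_data = mean_teacher_grad(1.0, 200_000, rng)

# Two independent minibatches from a mismatched "generator" with std 0.5.
# Analytically the expected gradient is 0.216 for these settings.
grad_gen_a = mean_teacher_grad(0.5, 200_000, rng)
grad_gen_b = mean_teacher_grad(0.5, 200_000, rng)

# Product of independent estimates: unbiased for the squared expected gradient.
loss_estimate = grad_gen_a * grad_gen_b
```

Squaring a single minibatch estimate would instead add the variance of that estimate to the result, which is one reason an estimator based on two independent minibatches is attractive here.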

3.3 Hyperparameter choices

In our choice of hyperparameters we choose to stick as closely as possible to the values recommended in EDM (Karras et al., 2022), some of which were also used in Diff-Instruct (Luo et al., 2024) and DMD (Yin et al., 2023). We use the EDM test time noise schedule for $p(s)$, as well as their training loss weighting for $w(s)$, but we shift all log-signal-to-noise ratios with the resolution of the data following Hoogeboom et al. (2023). For our gradient preconditioner $\Lambda$, as used in Section 3.2, we use the preconditioner defined in Adam (Kingma & Ba, 2014), which can be loaded from the teacher checkpoint or calculated fresh by running a few training steps before starting distillation. During distillation, $\Lambda$ is not updated.
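A minimal sketch of such a frozen diagonal preconditioner is given below, assuming it is formed in the usual Adam way as $1/(\sqrt{\hat{v}}+\epsilon)$ from a bias-corrected running average $\hat{v}$ of squared gradients. The helper names and the warm-up gradients are hypothetical, and how the statistics are actually stored in a teacher checkpoint depends on the training framework.

```python
import math


def adam_second_moment(grads, beta2=0.99):
    """Bias-corrected running average of squared gradients, one value per parameter."""
    v = [0.0] * len(grads[0])
    for g in grads:
        v = [beta2 * vi + (1.0 - beta2) * gi * gi for vi, gi in zip(v, g)]
    correction = 1.0 - beta2 ** len(grads)
    return [vi / correction for vi in v]


def adam_preconditioner(v_hat, eps=1e-12):
    """Diagonal of the preconditioner: 1 / (sqrt(v_hat) + eps) per parameter."""
    return [1.0 / (math.sqrt(vi) + eps) for vi in v_hat]


# Hypothetical gradients from a few warm-up training steps (two parameters).
warmup_grads = [[0.1, -2.0], [0.1, 2.0], [-0.1, -2.0]]

# Computed once before distillation and then kept fixed.
lam = adam_preconditioner(adam_second_moment(warmup_grads))

# Applying the frozen preconditioner to a new gradient rescales each coordinate.
preconditioned = [l * g for l, g in zip(lam, [0.1, 2.0])]
```

Coordinates with consistently large gradients are scaled down and those with small gradients are scaled up, so both toy coordinates end up with a preconditioned gradient of about 1.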

To get stable results for small numbers of sampling steps ($k=1,2$) we find that we need to use a weighting function $w(s)$ with less emphasis on high-signal (low $s$) data than in the EDM weighting. Using a flat weight $w(s)=1$ or the adaptive weight from DMD (Yin et al., 2023) works well.

As with previous methods, it’s possible to enable classifier-free guidance (Ho & Salimans, 2022) when evaluating the teacher model $g_{\theta}$. We find that guidance is typically not necessary if output quality is measured by FID, though it does increase Inception Score and CLIP score. To enable classifier-free guidance and prediction clipping for the teacher model in Algorithm 3, we need to define how to take gradients through these modifications: Here we find that a simple straight-through approximation works well, using the backward pass of the unmodified teacher model.
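The straight-through idea can be sketched without an autodiff library as follows. The forward pass applies guidance and clipping, while the backward pass ignores both and reuses the vector-Jacobian product of the unmodified teacher. This is an illustration under assumptions, not the paper's implementation: the guidance rule $(1+w)\,\hat{x}_{\text{cond}} - w\,\hat{x}_{\text{uncond}}$ follows the usual convention of Ho & Salimans (2022), the clipping range $[-1, 1]$ is assumed, and the toy teacher and all function names are hypothetical.

```python
def guided_clipped(x_cond, x_uncond, w, lo=-1.0, hi=1.0):
    """Forward pass: classifier-free guidance followed by elementwise clipping."""
    return [min(max((1.0 + w) * c - w * u, lo), hi) for c, u in zip(x_cond, x_uncond)]


def straight_through_vjp(cotangent, unmodified_vjp):
    """Backward pass: skip guidance and clipping, reuse the unmodified teacher's VJP."""
    return unmodified_vjp(cotangent)


# Toy teacher x_hat_i = theta * z_i with a scalar parameter theta, so the
# vector-Jacobian product w.r.t. theta is sum_i cotangent_i * z_i.
z = [0.5, -1.5, 2.0]
theta_cond, theta_uncond = 0.9, 0.6
x_cond = [theta_cond * zi for zi in z]
x_uncond = [theta_uncond * zi for zi in z]

# Values used in the forward computation (two of three entries get clipped).
forward = guided_clipped(x_cond, x_uncond, w=0.5)

# Gradient used in the backward computation, as if no modification had been applied.
grad_theta = straight_through_vjp(
    [1.0, 1.0, 1.0],
    lambda ct: sum(c * zi for c, zi in zip(ct, z)),
)
```

An exact gradient through the clipping would be zero for the two clipped entries. The straight-through version keeps a usable signal there, which is the practical motivation for the approximation.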

Related Work

In the case of one-step sampling, our method in Algorithm 2 is a special case of Variational Score Distillation, Diff-Instruct, and related methods (Wang et al., 2024; Luo et al., 2024; Yin et al., 2023; Nguyen & Tran, 2023) which distill a diffusion model by approximately minimizing the KL divergence between the distilled generator and the teacher model:

The methods differ in their exact formulation of the adversarial divergence $D_{\textrm{adv}}$, in the sampling of time steps, and in the use of additional losses. For example, Xu et al. (2023a) train unconditional discriminators $D_{\phi}(\cdot,t)$ and decompose the adversarial objective into a marginal (used in the discriminator) and a conditional distribution approximated with an additional regression model. Xiao et al. (2021) instead use a conditional discriminator of the form $D_{\phi}(\cdot,\mathbf{z}_{t},t)$.

Experiments

We evaluate our proposed methods in the class-conditional generation setting on the ImageNet dataset (Deng et al., 2009), which is the most well-established benchmark for comparing image quality. On this dataset we also run several ablations to show the effect of classifier-free guidance and other hyperparameter choices on our method. Finally, we present an experiment with a large text-to-image model to show our approach can also be scaled to this setting.

We begin by evaluating on class-conditional ImageNet generation, at the $64\times 64$ and $128\times 128$ resolutions (Tables 1 and 2). Our results here are for a relatively small model with 400 million parameters based on Simple Diffusion (Hoogeboom et al., 2023). We distill our models for a maximum of 200,000 steps at batch size 2048, calculating FID every 5,000 steps. We report the optimal FID seen during the distillation process, keeping evaluation data and random seeds fixed across evaluations to minimize bias.

For our base models we report results with slight classifier-free guidance of $w=0.1$, which gives the optimal FID. We also use an optimized amount of sampling noise, following Salimans & Ho (2022), which is slightly higher compared to equation 2. For our distilled models we obtained better results without classifier-free guidance, and we use standard ancestral sampling without tuning the sampling noise. We compare against various distillation methods from the literature, including both distillation methods that produce deterministic samplers (progressive distillation, consistency distillation) and stochastic samplers (Diff-Instruct, adversarial methods).

Ranking the different methods by FID, we find that our moment matching distillation method is especially competitive when using 8+ sampling steps, where it sets new state-of-the-art results, beating out even the best undistilled models using more than 1000 sampling steps, as well as its teacher model. For 1 sampling step some of the other methods show better results: improving our results in this setting we leave for future work. For 8+ sampling steps we get similar results for our alternating optimization version (Section 3.1) and the instant 2-batch version (Section 3.2) of our method. For fewer sampling steps, the alternating version performs better.

We find that our distilled models also perform very well in terms of Inception Score (Salimans et al., 2016) even though we did not optimize for this. By using classifier-free guidance the Inception Score can be improved further, as we show in Section 5.3.

How can a distilled model improve upon its teacher? On ImageNet our distilled diffusion model with 8 sampling steps and no classifier-free guidance outperforms its 512-step teacher with optimized guidance level, for both the $64\times 64$ and $128\times 128$ resolution. This result might be surprising since the many-step teacher model is often seen as the gold standard for sampling quality. However, even the teacher model has prediction error that makes it possible to improve upon it. In theory, predictions of the clean data at different diffusion times are all linked and should be mutually consistent, but since the diffusion model is implemented with an unconstrained neural network this generally will not be the case in practice. Prediction errors will thus be different across timesteps, which opens up the possibility of improving the results by averaging over these predictions in the right way.

Similarly, prediction error will not be constant over the model inputs $\mathbf{z}_{t}$, and biasing generation away from areas of large error could also yield sampling improvements. Although many-step ancestral sampling typically gives good results, and is often better than deterministic samplers like DDIM, it’s not necessarily optimal. In future work we hope to study the improvement of moment matching over our base sampler in further detail, and test our hypotheses about its causes.

5.2 Ablating conditional sampling

5.3 Effect of classifier-free guidance

Our distillation method can be used with or without guidance. For the alternating optimization version of our method we only apply guidance in the teacher model, but not in the generator or auxiliary denoising model. For the instant 2-batch version we apply guidance and clipping to the teacher model and then calculate its gradient with a straight-through approximation. Experimenting with different levels of guidance, we find that increasing guidance typically increases Inception Score and CLIP Score, while worsening (that is, increasing) FID, as shown in the adjacent figure.

5.4 Distillation loss is informative for moment matching

A unique advantage of the instant 2-batch version of our moment matching approach is that, unlike most other distillation methods, it has a simple loss function (equation 9) that is minimized without adversarial techniques, bootstrapping, or other tricks. This means that the value of the loss is useful for monitoring the progress of the distillation algorithm. We show this for ImageNet $128\times 128$ in the adjacent figure: The typical behavior we see is that the loss tends to go up slightly for the first few optimization steps, after which it falls exponentially to zero with an increasing number of parameter updates.

5.5 Text to image

To investigate our proposed method’s potential to scale to large text-to-image models we train a pixel-space model (no encoder/decoder) on a licensed dataset of text-image pairs at a resolution of $512\times 512$, using the UViT model and shifted noise schedule from Simple Diffusion (Hoogeboom et al., 2023) and using a T5 XXL text encoder following Imagen (Saharia et al., 2022). We compare the performance of our base model against an 8-step distilled model obtained with our moment matching method. In Table 3 we report zero-shot FID (Heusel et al., 2017) and CLIP Score (Radford et al., 2021) on MS-COCO (Lin et al., 2014): Also in this setting we find that our distilled model with alternating optimization exceeds the metrics for our base model. The instant 2-batch version of our algorithm performs somewhat less well at 8 sampling steps. Samples from our distilled text-to-image model are shown in Figure 1 and in Figure 7 in the appendix.

Conclusion

We presented Moment Matching Distillation, a method for making diffusion models faster to sample. The method distills many-step diffusion models into few-step models by matching conditional expectations of the clean data given noisy data along the sampling trajectory. The moment matching framework provides a new perspective on related recently proposed distillation methods and allows us to extend these methods to the multi-step setting. Using multiple sampling steps, our distilled models consistently outperform their one-step versions, and often even exceed their many-step teachers, setting new state-of-the-art results on the ImageNet dataset. However, automated metrics of image quality are highly imperfect, and in future work we plan to run a full set of human evaluations on the outputs of our distilled models to complement the metrics reported here.

We presented two different versions of our algorithm: One based on alternating updates of a distilled generator and an auxiliary denoising model, and another using two minibatches to allow only updating the generator. In future work we intend to further explore the space of algorithms spanned by these choices, and gain additional insight into the costs and benefits of both approaches.

References

Appendix A Instant moment matching = matching expected teacher gradients

In Section 3.2 we propose an instantaneous version of our moment matching loss that does not require alternating optimization of an auxiliary denoising model $g_{\phi}$. This alternative version of our algorithm uses the loss in equation 8, which we reproduce here for readability:

Appendix B Experimental details

All experiments were run on TPUv5e, using 256 chips per experiment. For ImageNet we used a global batch size of 2048, while for text-to-image we used a global batch size of 512. The base models were trained for 1M steps, requiring between 2 days (ImageNet 64) and 2 weeks (text-to-image). We use the UViT architecture from Hoogeboom et al. (2023). Configurations largely correspond to those in the appendix of Hoogeboom et al. (2023), where we used their small model variant for our ImageNet experiments.

For ImageNet we distill the trained base models for a maximum of 200,000 steps, and for text-to-image we use a maximum of 50,000 steps. We report the best FID obtained during distillation, evaluating every 5,000 steps. We fix the random seed and data used in each evaluation to minimize biasing our results. We use the Adam optimizer (Kingma & Ba, 2014) with $\beta_{1}=0$, $\beta_{2}=0.99$, $\epsilon=10^{-12}$. We use learning rate warmup for the first 1,000 steps and then linearly anneal the learning rate to zero over the remainder of the optimization steps. We use gradient clipping with a maximum norm of 1. We don’t use an EMA, weight decay, or dropout.
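The learning rate schedule and gradient clipping described above can be sketched as follows. The peak learning rate is not stated here, so `PEAK_LR` is a hypothetical placeholder. The warmup is assumed to be linear, and the clipping is assumed to act on the global gradient norm.

```python
import math


def learning_rate(step, peak_lr, total_steps, warmup_steps=1_000):
    """Linear warmup to peak_lr (assumed), then linear decay to zero at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / (total_steps - warmup_steps)


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale the whole gradient vector so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / norm) if norm > 0.0 else 1.0
    return [g * scale for g in grads]


PEAK_LR = 1e-4          # hypothetical value, not taken from the text
TOTAL_STEPS = 200_000   # maximum number of ImageNet distillation steps

lr_start = learning_rate(0, PEAK_LR, TOTAL_STEPS)
lr_peak = learning_rate(1_000, PEAK_LR, TOTAL_STEPS)
lr_mid = learning_rate(100_500, PEAK_LR, TOTAL_STEPS)
lr_end = learning_rate(TOTAL_STEPS, PEAK_LR, TOTAL_STEPS)

clipped_large = clip_by_global_norm([3.0, 4.0])   # norm 5 is rescaled to norm 1
clipped_small = clip_by_global_norm([0.3, 0.4])   # norm 0.5 is left unchanged
```

For the text-to-image runs the same functions would be called with a total of 50,000 steps.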

Appendix C More model samples

Appendix D Relationship to score matching

However, moment matching can still be seen to match the proper score expression (equation 19) approximately, if we assume that the forward processes match, meaning $p_{\eta}(\mathbf{z}_{t}|\mathbf{z}_{s})\approx q(\mathbf{z}_{t}|\mathbf{z}_{s})$. This then gives: