TRACT: Denoising Diffusion Models with Transitive Closure Time-Distillation
David Berthelot, Arnaud Autef, Jierui Lin, Dian Ang Yap, Shuangfei Zhai, Siyuan Hu, Daniel Zheng, Walter Talbott, Eric Gu
Introduction
Diffusion models are state-of-the-art generative models for many domains and applications. They work by learning to estimate the score of a given data distribution, which in practice can be implemented with a denoising autoencoder following a noise schedule. Training a diffusion model is arguably much simpler than training many alternative generative models, e.g., GANs, normalizing flows and auto-regressive models. The loss is well-defined and stable; there is a large degree of flexibility in designing the architecture; and it works directly with continuous inputs without the need for discretization. Thanks to these properties, diffusion models have demonstrated excellent scalability to large models and datasets, as shown in recent works in diverse domains such as image generation, image or audio super-resolution, audio and music synthesis, language models, and cross-domain applications such as text-to-image and text-to-speech.
Despite this empirical success, inference efficiency remains a major challenge for diffusion models. As shown in prior work, the inference process of diffusion models can be cast as solving a neural ODE, where sampling quality improves as the discretization error decreases. As a result, up to thousands of denoising steps are used in practice to achieve high sampling quality. This dependency on a large number of inference steps makes diffusion models less favorable than one-shot sampling methods, e.g., GANs, especially in resource-constrained deployment settings.
Existing efforts for speeding up inference of diffusion models can be categorized into three classes: (1) reducing the dimensionality of inputs; (2) improving the ODE solver; and (3) progressively distilling the output of a teacher diffusion model to a student model with fewer steps. Among these, the progressive distillation approach is of special interest to us. It relies on the fact that with a Denoising Diffusion Implicit Model (DDIM) inference schedule, there is a deterministic mapping between the initial noise and the final generated result. This allows one to learn an efficient student model that approximates a given teacher model. A naive implementation of such distillation would be prohibitive, as for each student update the teacher network needs to be called $T$ times, where $T$, the number of teacher sampling steps, is typically large. Salimans and Ho bypass this issue by performing progressive binary time distillation (BTD). In BTD, the distillation is divided into phases, and in each phase the student model learns the inference result of two consecutive teacher model inference steps. Experimentally, BTD can reduce the number of inference steps to four with minor performance loss on CIFAR-10 and 64x64 ImageNet.
In this paper, we aim to push the inference efficiency of diffusion models to the extreme: one-step inference with high-quality samples. We first identify two critical drawbacks of BTD that prevent it from achieving this goal: 1) objective degeneracy, where the approximation error accumulates from one distillation phase to the next, and 2) the inability to use aggressive stochastic weight averaging (SWA) to achieve good generalization, because the training course is divided into distinct phases.
Motivated by these observations, we propose a novel diffusion model distillation scheme named TRAnsitive Closure Time-Distillation (TRACT). In a nutshell, TRACT trains a student model to distill a teacher model's inference output from step $t$ to step $t'$ with $t' < t$. The training target is computed in a bootstrapping fashion: one inference step of the teacher model takes $t \to t-1$, followed by a call to the student model for $t-1 \to t'$. At the end of distillation, one can perform one-step inference with the student model by setting $t = T$ and $t' = 0$. We show that TRACT can be trained with only one or two phases, which avoids BTD's objective degeneracy and incompatibility with SWA.
Experimentally, we verify that TRACT drastically improves upon the state-of-the-art results with one and two steps of inference. Notably, it achieves single-step FID scores of 7.4 and 3.8 for 64x64 ImageNet and CIFAR-10, respectively.
Related Work
DDIMs are a subclass of Denoising Diffusion Probabilistic Models (DDPM) where the original noise is reused at every step of the inference process. Typically DDIMs use a $T$-steps noise schedule $\gamma_t \in [0, 1)$ for $t \in \{1, \dots, T\}$. By convention, $t = 0$ denotes the noise-free step and therefore $\gamma_0 = 1$. In the variance preserving (VP) noisification setting, a noisy sample $x_t$ is produced from the original sample $x_0$ and some Gaussian noise $\epsilon$ as follows:

$$x_t = x_0\sqrt{\gamma_t} + \epsilon\sqrt{1-\gamma_t} \tag{1}$$
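A neural network is trained to predict either the signal, the noise, or both. The estimates of $x_0$ and $\epsilon$ at step $t$ are denoted as $\hat{x}(x_t)$ and $\hat{\epsilon}(x_t)$. For the sake of conciseness, we only detail the signal prediction case. During the denoisification phase, the predicted $\hat{x}(x_t)$ is used to estimate $\hat{\epsilon}(x_t)$ by substitution in Equation (1):

$$\hat{\epsilon}(x_t) = \frac{x_t - \hat{x}(x_t)\sqrt{\gamma_t}}{\sqrt{1-\gamma_t}}$$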
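These estimates allow inference, by substitution in Equation (1), of $x_{t'}$ for any $t' \in \{0, \dots, T\}$:

$$x_{t'} = \delta(f, x_t, t, t') = \hat{x}(x_t)\sqrt{\gamma_{t'}} + \hat{\epsilon}(x_t)\sqrt{1-\gamma_{t'}}$$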
Here we introduced the step function $\delta(f, x_t, t, t')$ to denote DDIM inference from $x_t$ to $x_{t'}$ using the signal-predicting network $f$.
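To make the notation concrete, here is a minimal scalar sketch of Equation (1) and of the step function $\delta$ in the VP setting. The helper names and the toy schedule `gammas` are hypothetical, and the "oracle" network that returns the true $x_0$ is an assumption used only to check that a perfect signal estimate makes $\delta$ exact.

```python
import math


def vp_noisify(x0, eps, gamma_t):
    """Equation (1): x_t = x0 * sqrt(gamma_t) + eps * sqrt(1 - gamma_t)."""
    return x0 * math.sqrt(gamma_t) + eps * math.sqrt(1.0 - gamma_t)


def ddim_step(predict_x0, x_t, t, t_prime, gammas):
    """Step function delta(f, x_t, t, t') for a signal-predicting network.

    predict_x0(x_t, t) plays the role of x_hat(x_t); gammas[t] is gamma_t with
    gammas[0] == 1.0 (noise-free). Requires t >= 1 so that 1 - gamma_t > 0.
    """
    x_hat = predict_x0(x_t, t)
    # Noise estimate obtained by substituting x_hat into Equation (1).
    eps_hat = (x_t - x_hat * math.sqrt(gammas[t])) / math.sqrt(1.0 - gammas[t])
    # Re-noise both estimates at level t'.
    return x_hat * math.sqrt(gammas[t_prime]) + eps_hat * math.sqrt(1.0 - gammas[t_prime])


# Toy example: a 3-step schedule and a hypothetical oracle that knows x0.
gammas = [1.0, 0.9, 0.5, 0.1]
x0, eps = 0.7, -1.3
oracle = lambda x_t, t: x0
x3 = vp_noisify(x0, eps, gammas[3])
x1 = ddim_step(oracle, x3, 3, 1, gammas)      # matches vp_noisify(x0, eps, gammas[1])
x0_rec = ddim_step(oracle, x3, 3, 0, gammas)  # matches x0, since gamma_0 = 1
```

Note that $\delta(f, x_t, t, t) = x_t$ for any prediction, so the same helper covers every pair $t' \le t$.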
A common framework for the denoisification process is to use stochastic differential equations (SDEs) that maintain the desired distribution as the sample evolves over time. Song et al. presented a corresponding probability flow ordinary differential equation (ODE) with the initial generated noise as the only source of stochasticity. Compared to SDEs, ODEs can be solved with larger step sizes as there is no randomness between steps.
Another advantage of solving probability flow ODEs is that existing numerical ODE solvers can be used to accelerate sampling in the denoisification phase. However, numerical solvers only approximate the true solution trajectory, because of their truncation error. Popular numerical ODE solvers include the first-order Euler method and higher-order methods such as Runge-Kutta (RK). Karras et al. apply Heun's 2nd order method, a member of the family of explicit second-order RK methods, to strike a tradeoff between truncation error and the number of function evaluations (NFEs).
However, existing ODE solvers are unable to generate high-quality samples in the few-step sampling regime (loosely, a handful of sampling steps). RK methods may suffer from numerical issues with large step sizes. Our work provides a direction orthogonal to these ODE solvers, and TRACT outputs can be further refined with higher-order methods.
The idea of distilling a pretrained diffusion model into a single-step student was first introduced in prior work. Despite encouraging results, it suffers from high training costs and sampling quality degradation. This idea was later extended by Salimans and Ho, who progressively distill a teacher model into a student by halving its total number of steps.
Specifically, in Binary Time-Distillation (BTD), a student network $g$ is trained to replace two denoising steps of the teacher $f$. Using the step function notation, $g$ is modeled to hold this equality:

$$\delta(g, x_t, t, t-2) = \delta\big(f, \delta(f, x_t, t, t-1), t-1, t-2\big)$$
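From this definition, we can determine the target $\tilde{x}$ that makes the equality hold (see Appendix A.1):

$$\tilde{x} = \frac{x_{t-2}\sqrt{1-\gamma_t} - x_t\sqrt{1-\gamma_{t-2}}}{\sqrt{\gamma_{t-2}}\sqrt{1-\gamma_t} - \sqrt{\gamma_t}\sqrt{1-\gamma_{t-2}}}, \qquad x_{t-2} = \delta\big(f, \delta(f, x_t, t, t-1), t-1, t-2\big)$$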
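The sketch below checks the BTD target numerically, under simple assumptions: it runs two DDIM steps of a hypothetical linear `teacher` (a stand-in, not a trained model), solves for $\tilde{x}$, and confirms that a student predicting $\tilde{x}$ reaches $x_{t-2}$ in a single step. Helper names are illustrative.

```python
import math


def ddim_step(predict_x0, x_t, t, t_prime, gammas):
    # DDIM step delta(f, x_t, t, t') for a signal-predicting model (VP setting).
    x_hat = predict_x0(x_t, t)
    eps_hat = (x_t - x_hat * math.sqrt(gammas[t])) / math.sqrt(1.0 - gammas[t])
    return x_hat * math.sqrt(gammas[t_prime]) + eps_hat * math.sqrt(1.0 - gammas[t_prime])


def signal_target(x_t, x_tp, t, t_prime, gammas):
    """Signal x_tilde such that one DDIM step from x_t lands exactly on x_tp."""
    a_t, s_t = math.sqrt(gammas[t]), math.sqrt(1.0 - gammas[t])
    a_p, s_p = math.sqrt(gammas[t_prime]), math.sqrt(1.0 - gammas[t_prime])
    return (x_tp * s_t - x_t * s_p) / (a_p * s_t - a_t * s_p)


# Hypothetical imperfect teacher: shrinks its input toward zero.
teacher = lambda x, t: 0.8 * x
gammas = [1.0, 0.95, 0.7, 0.3]
x_t, t = 1.5, 3

# Two teacher steps t -> t-1 -> t-2 give the BTD target point x_{t-2}.
x_tm2 = ddim_step(teacher, ddim_step(teacher, x_t, t, t - 1, gammas), t - 1, t - 2, gammas)
x_tilde = signal_target(x_t, x_tm2, t, t - 2, gammas)

# A student whose prediction is x_tilde replaces both teacher steps with one.
student_x_tm2 = ddim_step(lambda x, s: x_tilde, x_t, t, t - 2, gammas)
```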
The signal loss is inferred by rewriting the noise prediction error (see Appendix A.2), yielding:

$$\mathcal{L}(g) = \frac{\gamma_t}{1-\gamma_t}\,\big\lVert g(x_t) - \tilde{x} \big\rVert_2^2$$
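As a small sanity check of this equivalence, the following scalar sketch compares the squared error between the noises implied by a prediction and by the target against the SNR-weighted signal error. The function names and values are hypothetical.

```python
def implied_eps(x_signal, x_t, gamma_t):
    """Noise implied by a signal value at step t (substitution in Equation (1))."""
    return (x_t - x_signal * gamma_t ** 0.5) / (1.0 - gamma_t) ** 0.5


def signal_loss(x_hat, x_tilde, gamma_t):
    """SNR-weighted signal loss: gamma_t / (1 - gamma_t) * (x_hat - x_tilde)^2."""
    return gamma_t / (1.0 - gamma_t) * (x_hat - x_tilde) ** 2


def noise_loss(x_hat, x_tilde, x_t, gamma_t):
    """Squared error between the noises implied by the prediction and by the target."""
    return (implied_eps(x_hat, x_t, gamma_t) - implied_eps(x_tilde, x_t, gamma_t)) ** 2


gamma_t, x_t, x_hat, x_tilde = 0.6, 0.9, 0.2, -0.4
```

The $x_t$ terms cancel in the noise difference, which is why the weight $\gamma_t / (1-\gamma_t)$ is all that remains.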
Once a student has been trained to completion, it becomes the teacher, and the process is repeated until the final model has the desired number of steps. Distilling a $T$-steps teacher into a single-step model therefore requires $\log_2 T$ training phases, and each trained student requires half the sampling steps of its teacher to generate high-quality samples.
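For illustration, a minimal sketch of the BTD step-count sequence, assuming $T$ is a power of two; `btd_schedule` is a hypothetical helper, not code from the BTD release.

```python
def btd_schedule(T):
    """Step counts visited by BTD: every phase halves the number of sampling steps."""
    if T < 1 or T & (T - 1):
        raise ValueError("this sketch assumes T is a power of two")
    steps = [T]
    while steps[-1] > 1:
        steps.append(steps[-1] // 2)
    return steps


schedule = btd_schedule(1024)
num_phases = len(schedule) - 1  # one training phase per halving: log2(1024) = 10
```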
Method
We propose TRAnsitive Closure Time-Distillation (TRACT), an extension of BTD that reduces the number of distillation phases from $\log_2 T$ to a small constant, typically 1 or 2. We first focus on the VP setting used in BTD, but the method itself is independent of it, and we illustrate it in the Variance Exploding (VE) setting at the end of the section. While TRACT also works for noise-predicting objectives, we demonstrate it on signal prediction, where the neural network predicts an estimate of $x_0$.
We conjecture that the final quality of samples from a distilled model is influenced by the number of distillation phases and the length of each phase. We consider two potential explanations for why this is the case, both later validated in the experiments section.
In BTD, the student in the previous distillation phase becomes the teacher for the next phase. The student from the previous phase retains a nonzero loss, which makes it an imperfect teacher for the next phase. These imperfections accumulate over successive generations of students, which leads to objective degeneracy.
Generalization
Stochastic Weight Averaging (SWA) has been used to improve the performance of neural networks trained for DDPMs. With an Exponential Moving Average (EMA), the momentum parameter is limited by the training length: high momentum yields high-quality results but leads to over-regularized models if the training length is too short. This ties in with the time-distillation problem, since the total training length is directly proportional to the number of training phases.
3.2 TRACT
TRACT is a multi-phase method where each phase distills a $T$-steps schedule to $T' < T$ steps, and phases are repeated until the desired number of steps is reached. In a phase, the $T$-steps schedule is partitioned into $T'$ contiguous groups. The partitioning strategy is left open; for example, in our experiments we used equally-sized groups, as demonstrated in Algorithm (1).
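Algorithm (1) is not reproduced in this text; the following is a hedged sketch of the equal-size partitioning it describes, under the assumption that $T$ is divisible by $T'$. The function names are hypothetical. Each step $t$ is mapped to the start $t_i$ of its group, which is where the student jumps to; with $T' = T/2$ the groups are pairs of steps, as in BTD.

```python
def group_size(T, T_prime):
    """Size of each contiguous group when a T-step schedule is split into T' equal groups."""
    if T % T_prime:
        raise ValueError("equal-sized groups require T to be divisible by T_prime")
    return T // T_prime


def segment_start(t, T, T_prime):
    """Segment start t_i for a step 1 <= t <= T: the student jumps from t to t_i."""
    size = group_size(T, T_prime)
    return ((t - 1) // size) * size


def student_schedule(T, T_prime):
    """Steps visited by the distilled student at inference: T, T - size, ..., 0."""
    return list(range(T, -1, -group_size(T, T_prime)))


# A 1024 -> 32 phase followed by a 32 -> 1 phase.
phase1 = student_schedule(1024, 32)
phase2 = student_schedule(32, 1)
```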
Our method can be seen as an extension of BTD that is not constrained by $T' = T/2$. However, relaxing this constraint has computational implications, such as the estimation of $x_{t'}$ from $x_t$ for arbitrary $t' < t$.
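For a contiguous segment $\{t_i, \dots, t_j\}$, we model the student $g$ to jump to step $t_i$ from any step $t_i < t \le t_j$, as illustrated in Figure (1):

$$\delta(g, x_t, t, t_i) = \delta\Big(f, \cdots \delta\big(f, \delta(f, x_t, t, t-1), t-1, t-2\big) \cdots, t_i+1, t_i\Big) \tag{6}$$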
The student $g$ is specified to encompass $t - t_i$ denoising steps of $f$. However, this formulation could require multiple calls of $f$ during training, leading to prohibitive computational costs.
To avoid this cost, the transitive closure operator can be modeled with self-teaching by rewriting the closure in Equation (6) as a recurrence:

$$\delta(g, x_t, t, t_i) = \delta\big(g, \delta(f, x_t, t, t-1), t-1, t_i\big)$$
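From this definition, we can determine the target $\tilde{x}$ that makes the equality hold, using the same method as for the BTD target; see Appendix A.1 for details:

$$\tilde{x} = \frac{x_{t_i}\sqrt{1-\gamma_t} - x_t\sqrt{1-\gamma_{t_i}}}{\sqrt{\gamma_{t_i}}\sqrt{1-\gamma_t} - \sqrt{\gamma_t}\sqrt{1-\gamma_{t_i}}}, \qquad x_{t_i} = \delta\big(g, \delta(f, x_t, t, t-1), t-1, t_i\big)$$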
For the special case $t_i = 0$, we have $\tilde{x} = x_{t_i}$, since $\gamma_0 = 1$.
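The self-teaching target can be sketched as follows, assuming scalar samples and hypothetical stand-ins for the teacher and self-teacher (in TRACT the self-teacher is an EMA of the student). One teacher step covers $t \to t-1$, the self-teacher covers $t-1 \to t_i$, and the target is solved in closed form. The explicit branch for $t - 1 = t_i$ is a design choice of this sketch: it skips an identity step and avoids evaluating the self-teacher at $t = 0$, where $1 - \gamma_0 = 0$.

```python
import math


def ddim_step(predict_x0, x_t, t, t_prime, gammas):
    # DDIM step delta(f, x_t, t, t') for a signal-predicting model (VP setting).
    x_hat = predict_x0(x_t, t)
    eps_hat = (x_t - x_hat * math.sqrt(gammas[t])) / math.sqrt(1.0 - gammas[t])
    return x_hat * math.sqrt(gammas[t_prime]) + eps_hat * math.sqrt(1.0 - gammas[t_prime])


def signal_target(x_t, x_ti, t, t_i, gammas):
    # x_tilde such that one DDIM step from x_t to t_i lands exactly on x_ti.
    a_t, s_t = math.sqrt(gammas[t]), math.sqrt(1.0 - gammas[t])
    a_i, s_i = math.sqrt(gammas[t_i]), math.sqrt(1.0 - gammas[t_i])
    return (x_ti * s_t - x_t * s_i) / (a_i * s_t - a_t * s_i)


def tract_target(teacher, self_teacher, x_t, t, t_i, gammas):
    """Self-teaching target for a student jumping from t to the segment start t_i < t."""
    # One teacher step: t -> t - 1.
    x_tm1 = ddim_step(teacher, x_t, t, t - 1, gammas)
    # Self-teacher covers the rest of the segment: t - 1 -> t_i (nothing left if t - 1 == t_i).
    x_ti = x_tm1 if t - 1 == t_i else ddim_step(self_teacher, x_tm1, t - 1, t_i, gammas)
    return signal_target(x_t, x_ti, t, t_i, gammas)


# Hypothetical stand-ins for the teacher and the EMA self-teacher.
teacher = lambda x, t: 0.8 * x
self_teacher = lambda x, t: 0.7 * x + 0.05
gammas = [1.0, 0.95, 0.8, 0.6, 0.4, 0.2]

x_t = 1.2
x_tilde = tract_target(teacher, self_teacher, x_t, 5, 2, gammas)
```

A student that predicts `x_tilde` at step 5 reaches, in one DDIM step, the same point $x_{t_i}$ that the teacher plus self-teacher reach in two calls.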
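The loss is the standard signal-predicting DDIM distillation training loss, i.e., for a target value $\tilde{x}$:

$$\mathcal{L}(g) = \frac{\gamma_t}{1-\gamma_t}\,\big\lVert g(x_t) - \tilde{x} \big\rVert_2^2$$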
3.3 Adapting TRACT to a Runge-Kutta teacher and Variance Exploding noise schedule
To illustrate its generality, we apply TRACT to teachers from Elucidating the Design Space of Diffusion-Based Generative Models (EDM), which use a VE noise schedule and an RK sampler.
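A VE noisification process is parameterized by a sequence of noise standard deviations $\sigma_t$ for $t \in \{1, \dots, T\}$ with $\sigma_{\min} = \sigma_1 \le \sigma_t \le \sigma_T = \sigma_{\max}$, and $t = 0$ denotes the noise-free step, $\sigma_0 = 0$. A noisy sample $x_t$ is produced from an original sample $x_0$ and Gaussian noise $\epsilon$ as follows:

$$x_t = x_0 + \sigma_t\,\epsilon \tag{10}$$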
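Following the EDM approach, we use an RK sampler for the teacher and distill it into a DDIM sampler for the student. The corresponding step functions are $\delta_{RK}$ and $\delta_{DDIM}$, respectively. The step function to estimate $x_{t'}$ from $x_t$ with the student's sampler, $\delta_{DDIM}$, is defined as:

$$\delta_{DDIM}(f, x_t, t, t') = \hat{x}(x_t) + \sigma_{t'}\,\hat{\epsilon}(x_t)$$

where $\hat{\epsilon}(x_t) = \dfrac{x_t - \hat{x}(x_t)}{\sigma_t}$.
The step function to estimate $x_{t'}$ from $x_t$ with the teacher's RK sampler, $\delta_{RK}$, is a second-order (Heun) step:

$$\delta_{RK}(f, x_t, t, t') = x_t + \frac{\sigma_{t'} - \sigma_t}{2}\big(\hat{\epsilon}(x_t) + \hat{\epsilon}(x')\big), \qquad x' = x_t + (\sigma_{t'} - \sigma_t)\,\hat{\epsilon}(x_t)$$

where $\hat{\epsilon}(x')$ is evaluated at noise level $\sigma_{t'}$; when $\sigma_{t'} = 0$, the correction is skipped and $\delta_{RK}(f, x_t, t, t') = x'$.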
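Then, learning the transitive closure operator via self-teaching requires:

$$\delta_{DDIM}(g, x_t, t, t_i) = \delta_{DDIM}\big(g, \delta_{RK}(f, x_t, t, t-1), t-1, t_i\big)$$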
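From this definition, we can again determine the target $\tilde{x}$ that makes the equality hold:

$$\tilde{x} = \frac{\sigma_t\,x_{t_i} - \sigma_{t_i}\,x_t}{\sigma_t - \sigma_{t_i}}, \qquad x_{t_i} = \delta_{DDIM}\big(g, \delta_{RK}(f, x_t, t, t-1), t-1, t_i\big)$$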
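The following is a hedged sketch of how the VE target could be computed, assuming scalar samples, a noise-level list indexed by step (`sigmas[0] = 0`), and hypothetical stand-ins for the EDM teacher and the EMA self-teacher. As in the VP case, the step $t - 1 = t_i$ is handled explicitly, which also avoids dividing by $\sigma_0 = 0$.

```python
def ve_ddim_step(predict_x0, x_t, sigma_t, sigma_tp):
    # delta_DDIM: x_hat + sigma_{t'} * (x_t - x_hat) / sigma_t
    x_hat = predict_x0(x_t, sigma_t)
    return x_hat + sigma_tp * (x_t - x_hat) / sigma_t


def ve_heun_step(predict_x0, x_t, sigma_t, sigma_tp):
    # delta_RK: Heun step, with the correction skipped when sigma_tp == 0
    d_t = (x_t - predict_x0(x_t, sigma_t)) / sigma_t
    x_next = x_t + (sigma_tp - sigma_t) * d_t
    if sigma_tp == 0:
        return x_next
    d_tp = (x_next - predict_x0(x_next, sigma_tp)) / sigma_tp
    return x_t + (sigma_tp - sigma_t) * 0.5 * (d_t + d_tp)


def ve_target(x_t, x_ti, sigma_t, sigma_ti):
    """x_tilde such that delta_DDIM(g, x_t, t, t_i) lands exactly on x_ti."""
    return (sigma_t * x_ti - sigma_ti * x_t) / (sigma_t - sigma_ti)


def tract_edm_target(teacher, self_teacher, x_t, t, t_i, sigmas):
    """RK teacher step t -> t-1, DDIM self-teacher step t-1 -> t_i, then solve for x_tilde."""
    x_tm1 = ve_heun_step(teacher, x_t, sigmas[t], sigmas[t - 1])
    if t - 1 == t_i:  # already at the segment start (and avoids dividing by sigma = 0)
        x_ti = x_tm1
    else:
        x_ti = ve_ddim_step(self_teacher, x_tm1, sigmas[t - 1], sigmas[t_i])
    return ve_target(x_t, x_ti, sigmas[t], sigmas[t_i])


# Hypothetical stand-ins for the EDM teacher and the EMA self-teacher.
teacher = lambda x, sigma: x / (1.0 + sigma * sigma)
self_teacher = lambda x, sigma: 0.9 * x / (1.0 + sigma * sigma)
sigmas = [0.0, 0.5, 1.0, 2.0, 4.0]  # sigmas[0] = 0 is the noise-free step

x_t = 3.0
x_tilde = tract_edm_target(teacher, self_teacher, x_t, 4, 1, sigmas)
```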
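The loss is then a weighted $\ell_2$ loss between the student network prediction and the target. We follow the weighting and network preconditioning strategies introduced in the EDM paper:

$$\mathcal{L}(g) = \lambda(\sigma_t)\,\big\lVert g(x_t) - \tilde{x} \big\rVert_2^2, \qquad \lambda(\sigma) = \frac{\sigma^2 + \sigma_{data}^2}{(\sigma\,\sigma_{data})^2}$$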
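For reference, a scalar sketch of the EDM preconditioning and loss weight as published by Karras et al., with $\sigma_{data} = 0.5$ (EDM's default). How exactly these are wired into TRACT's training code is an assumption here, and `raw_net` is a hypothetical stand-in for the unpreconditioned network $F$.

```python
import math

SIGMA_DATA = 0.5  # EDM's default data standard deviation


def edm_preconditioning(sigma, sigma_data=SIGMA_DATA):
    """(c_skip, c_out, c_in, c_noise) as defined in the EDM paper."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    c_noise = 0.25 * math.log(sigma)
    return c_skip, c_out, c_in, c_noise


def edm_loss_weight(sigma, sigma_data=SIGMA_DATA):
    """lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2


def student_prediction(raw_net, x_t, sigma):
    """Preconditioned signal estimate: c_skip * x + c_out * F(c_in * x, c_noise)."""
    c_skip, c_out, c_in, c_noise = edm_preconditioning(sigma)
    return c_skip * x_t + c_out * raw_net(c_in * x_t, c_noise)


def tract_edm_loss(raw_net, x_t, sigma, x_tilde):
    """Weighted squared error between the student's prediction and the TRACT target."""
    return edm_loss_weight(sigma) * (student_prediction(raw_net, x_t, sigma) - x_tilde) ** 2
```

Since $\lambda(\sigma)\,c_{out}(\sigma)^2 = 1$, the loss equals $\lVert F - (\tilde{x} - c_{skip}x_t)/c_{out} \rVert^2$: the raw network regresses a unit-weighted target at every noise level.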
The resulting distillation algorithm and details on the derivation of the step functions, as well as of the training target, can be found in Appendix A.8.
Experiments
We present results with TRACT on two image generation benchmarks: CIFAR-10 and class-conditional 64x64 ImageNet. On each dataset, we measure the performance of our distilled models using the Fréchet Inception Distance (FID), computed from 50,000 generated samples. We run each experiment with three seeds to compute the mean and standard deviation. Using the exact same architectures and teacher models, 1-step TRACT models improve FID over their BTD counterparts from 9.1 to 4.5 on CIFAR-10 and from 17.5 to 7.4 on 64x64 ImageNet. We also present results with TRACT when distilling EDM teacher models that use an RK sampler and a VE noise schedule: they further improve our FID results to 3.8 on CIFAR-10 (see Table 1).
We follow up with ablations of the key components of our method: the momentum parameters of the self-teaching and inference EMAs, and the distillation schedules.
The teacher model in each TRACT distillation experiment is initialized from the teacher checkpoints of the BTD paper (https://github.com/google-research/google-research/tree/master/diffusion_distillation), so that our results are directly comparable.
We use a two-phase distillation schedule. At the start of each phase, the student’s weights are initialized from the current teacher being distilled. In the first phase, the teacher model uses a 1024-step sampling schedule and the student learns to generate samples in 32 steps. In the second phase, the teacher is initialized as the student from the previous phase, and the student learns to generate images in a single step.
We experimented with two training lengths: 96M samples to match the BTD paper, and 256M samples to showcase the benefits of longer training with TRACT. Our 1-step TRACT-96M model obtains an FID of 5.02, nearly halving the previous state-of-the-art of 9.12 with the same architecture and training budget. TRACT-256M further improves our 1-step FID result to 4.45. For both training budgets, we also run distillation experiments ending with a larger number of steps and obtain state-of-the-art models at all step counts. 1- and 2-step results are presented in Table 1, while 4- and 8-step results are presented in Table 7. More experimental details can be found in Appendix A.3.
On class-conditional 64x64 ImageNet, our single-step TRACT-96M student achieves an FID of 7.43, a 2.4x improvement over its BTD counterpart. Due to resource constraints, we did not distill a TRACT model with as many training samples (1.2B) as BTD. Our new state-of-the-art on the same model architecture is therefore obtained with less than a tenth of BTD's training budget. 1- and 2-step results are presented in Table 2, while 4- and 8-step results are presented in Table 8. More experimental details can be found in Appendix A.3.
4.2 Image generation results with EDM teachers
EDM models are initialized from the checkpoints released with the paper (https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/), which are based on the NCSN++ architecture for CIFAR-10 and the ADM architecture for 64x64 ImageNet. Results for TRACT-EDM models are presented in Tables 1 and 7 for CIFAR-10, and in Tables 2 and 8 for 64x64 ImageNet. Experimental details can be found in Appendix A.4.
4.3 Stochastic Weight Averaging ablations
TRACT uses two different EMAs: one for the self-teacher and one for the student model used at inference time. The self-teacher uses a fast-moving (low momentum) EMA with momentum $\mu_S$, and the inference model uses a slow-moving (high momentum) EMA with momentum $\mu_I$. We study both momentum parameters through ablations on CIFAR-10.
We use the same implementation for the inference model weights.
The momentum parameter $\mu_S$ for the self-teaching EMA strikes a balance between convergence speed and training stability. With low $\mu_S$, the self-teacher weights adapt rapidly to training updates but incorporate noise from the optimization process, leading to unstable self-teaching. On the other hand, higher $\mu_S$ values yield stable self-teaching targets but introduce latency between the student model state and that of its self-teacher. This, in turn, results in outdated self-teacher targets and slower convergence.
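The two averages can be pictured with a minimal sketch of the EMA updates, assuming weights stored as a flat list of floats and a constant drift standing in for optimizer updates. The momentum values are hypothetical toy values, not the settings used in our experiments; the point is only that a low-momentum EMA tracks the student closely while a high-momentum EMA lags far behind.

```python
def ema_update(ema_params, params, momentum):
    """In-place EMA update: ema <- momentum * ema + (1 - momentum) * params."""
    for i, p in enumerate(params):
        ema_params[i] = momentum * ema_params[i] + (1.0 - momentum) * p


MU_S, MU_I = 0.5, 0.999  # hypothetical toy values, not the settings used in the paper

student = [0.0, 0.0]
self_teacher = list(student)  # fast EMA: supplies the self-teaching targets
inference = list(student)     # slow EMA: used to generate samples at test time

for step in range(100):
    student = [w + 0.01 for w in student]  # stand-in for an optimizer update
    ema_update(self_teacher, student, MU_S)
    ema_update(inference, student, MU_I)

lag_self_teacher = student[0] - self_teacher[0]
lag_inference = student[0] - inference[0]
```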
For the ablation study of $\mu_S$, we fixed the distillation schedule to the two-phase schedule described above, the training length to 48M samples per phase, and $\mu_I$ to 0.99995. Results are presented in Table 4. (The best result in this table does not match our best overall result: throughout the ablations, for simplicity and at the cost of performance, we did not allocate a larger share of the training budget to the distillation phase.) Performance decreases monotonically as the self-teaching EMA momentum grows above a certain threshold, which supports the slower convergence hypothesis for high values of this parameter. Results are also worse, with high variance, for values at or below 0.01. Similarly to observations made in BYOL, we found that a wide range of momentum parameter values gives good performance. In light of this, we use a single fixed $\mu_S$ from this range for all other experiments.
We use a slow-moving EMA of student weights at inference time, which has been shown empirically to yield better test-time performance. For the ablation study of $\mu_I$, we fix the distillation schedule to the two-phase schedule described above, the training length to 48M samples per phase, and $\mu_S$ to a fixed value. Results are presented in Table 4: we observe that the value of $\mu_I$ strongly affects performance. In Appendix A.7 we share a heuristic to compute $\mu_I$ values yielding high-quality results across experiments and for varying training lengths.
4.4 Influence of the number of distillation phases
In the VP setting, we find that TRACT performs best with a 2-phase distillation schedule. Confirming our original conjecture, we observe that schedules with more phases suffer more from objective degeneracy. However, the worst results were obtained with single-phase distillation. In that case, we suspect that, due to the long chain of time steps, a phenomenon similar to vanishing gradients occurs. We present ablation results on CIFAR-10 with distillation schedules of increasing number of phases, from 1 to 5.
We fix the EMA momentum parameters $\mu_S$ and $\mu_I$ and set the overall training length to 96M samples. Single-step FID results are presented in Table 5. Results clearly get worse with more distillation phases, providing support for the objective degeneracy hypothesis.
TRACT with 3-, 4- and 5-phase distillation schedules is trained again with an increased training budget, now set to 48M samples per phase. 1-step FID results are presented in Table 6. Many-phase schedules improve their performance, but their FID scores are still worse than with the 2-phase schedule, despite leveraging the same training budget per distillation phase. This suggests that the objective degeneracy problem cannot be fully solved by a reasonable increase in training budget. Meanwhile, as seen in previous experiments (see Table 1), 2-phase results with 256M samples improved markedly over 96M samples. Therefore, with a fixed training budget, training a 2-phase TRACT for longer might be the best choice.
To further confirm that objective degeneracy is the reason why TRACT outperforms BTD, we compare BTD to TRACT on the same BTD-compatible schedule: the 10-phase schedule $1024 \to 512 \to \dots \to 2 \to 1$. We use the same $\mu_I$ and 48M training samples per distillation phase for both experiments. In this setting, BTD outperforms TRACT with an FID of 5.95 versus 6.8. This is additional evidence that BTD's inferior overall performance may come from its inability to leverage 2-phase distillation schedules. Besides the schedule, the other difference between BTD and TRACT is TRACT's use of self-teaching. This experiment also suggests that self-teaching may result in less efficient objectives than supervised training.
4.5 Beyond time distillation
In addition to reducing quality degradation with fewer sampling steps, TRACT can be used for knowledge distillation to other architectures, in particular smaller ones. Distilling a 60.0M-parameter model into a 19.4M-parameter student degrades 1-step FID on CIFAR-10 from 5.02 (TRACT-96M) to 6.47. For more details, refer to Appendix A.9.
Conclusion
Generating samples in a single step can greatly improve the tractability of diffusion models. We introduce TRAnsitive Closure Time-distillation (TRACT), a new method that significantly improves the quality of samples generated by a diffusion model in a few steps. This result is achieved by distilling a model in fewer phases and with stronger stochastic weight averaging than prior methods. Experimentally, we show that without architecture changes to prior work, TRACT improves single-step FID by up to 2.4x. Further experiments demonstrate that TRACT can also effectively distill to other architectures, in particular to smaller student architectures. While demonstrated on image datasets, our method is general and makes no particular assumption about the type of data. We leave its application to other types of data to future work.
An interesting extension of TRACT could further improve the quality-efficiency trade-off: typically, the number of teacher steps used for distillation in DDIMs/DDPMs has maxed out at 8192 due to the computational cost of sampling. Since TRACT allows arbitrary reductions in steps between training phases, we could feasibly distill from teachers with much higher step counts, which prior methods could not. This unexplored avenue could open new research into difficult tasks where diffusion models could not previously be applied.
We would like to thank Josh Susskind, Xiaoying Pang, Miguel Angel Bautista Martin and Russ Webb for their feedback and suggestions.
Contributions
Here are the authors' contributions to the work: David Berthelot led the research and came up with the transitive closure method and working code prototypes. Arnaud Autef obtained CIFAR-10 results, designed and ran ablation experiments, and set up multi-GPU and multi-node training via DDP. Walter Talbott helped with ablation experiments and with writing. Daniel Zheng worked on cloud compute infrastructure, set up multi-GPU and multi-node training via DDP, and ran experiments. Siyuan Hu implemented the FID, integrated the BTD paper's model into the transitive closure framework, and conducted the experiments of distillation to smaller architectures. Jierui Lin finalized the data, training and evaluation pipelines, obtained 64x64 ImageNet results, integrated BTD's teacher models and noise schedule into our pipeline, and reproduced binary distillation and its variants for ablation. Dian Ang Yap implemented EDM variants and designed experiments for TRACT (VE-EDM) on CIFAR-10 and ImageNet. Shuangfei Zhai contributed to the discussions, writing and ablation studies. Eric Gu contributed to writing and conducted experiments for distillation to smaller architectures.
References
Appendix A
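We want our student network to match the closure of the teacher steps via self-teaching. If the student network is perfect we have:

$$x_{t_i} = \delta(g, x_t, t, t_i) = \tilde{x}\sqrt{\gamma_{t_i}} + \frac{x_t - \tilde{x}\sqrt{\gamma_t}}{\sqrt{1-\gamma_t}}\sqrt{1-\gamma_{t_i}}$$

Multiplying both sides by $\sqrt{1-\gamma_t}$ and solving for $\tilde{x}$ gives the training target:

$$\tilde{x} = \frac{x_{t_i}\sqrt{1-\gamma_t} - x_t\sqrt{1-\gamma_{t_i}}}{\sqrt{\gamma_{t_i}}\sqrt{1-\gamma_t} - \sqrt{\gamma_t}\sqrt{1-\gamma_{t_i}}}$$

which is well-defined whenever $\gamma_{t_i} \neq \gamma_t$. The BTD target follows in the same way with $t_i = t - 2$.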
A.2 Deriving the distillation loss
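The signal training loss is derived from the noise training loss between the expected noise $\tilde{\epsilon}$ and the predicted noise $\hat{\epsilon}$ as follows:

$$\lVert \tilde{\epsilon} - \hat{\epsilon} \rVert_2^2 = \left\lVert \frac{x_t - \tilde{x}\sqrt{\gamma_t}}{\sqrt{1-\gamma_t}} - \frac{x_t - g(x_t)\sqrt{\gamma_t}}{\sqrt{1-\gamma_t}} \right\rVert_2^2 = \frac{\gamma_t}{1-\gamma_t}\,\big\lVert g(x_t) - \tilde{x} \big\rVert_2^2$$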
In code implementation, it is common for numerical reasons to use:
A.3 Experimental details for TRACT
On CIFAR-10, to obtain our best performing TRACT-96M and TRACT-256M, we use a global batch size of 256 split across 8 GPUs and the Adam optimizer with a constant learning rate of , no weight decay, no dropout, and gradient clipping to a norm of 1.0.
On 64x64 ImageNet, we use a global batch size of 256 split across 8 GPUs for our best performing TRACT-96M. We use the Adam optimizer with a constant learning rate of , no weight decay, no dropout, and gradient clipping to a norm of 1.0.
Unlike in our CIFAR-10 experiments, we predict both signal and noise during distillation, following the setup of BTD.
A.4 Experimental details for TRACT with EDM teachers
For CIFAR-10, we use NCSN++ by Song et al. and pretrained weights from Karras et al. Without augmentation, our model contains 56 million trainable parameters. We use a global batch size of 512 split across 8 GPUs for our best performing TRACT-96M and TRACT-256M. We use the Adam optimizer with a constant learning rate of for one-step distillation, and a learning rate of for multi-step distillation. For all settings, we disable weight decay, learning rate warmup, dropout, augmentation, and gradient clipping.
As EDM samplers require fewer NFEs than typical DDIM samplers, we tune the number of teacher timesteps to distill from. We find that distilling from 40 steps (79 NFEs with Runge-Kutta) to 1 step gives a good balance between generating high-quality targets from the teacher and ease of learning for the student.
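The NFE counts follow from the Heun sampler skipping its correction on the final step to $\sigma = 0$, so $N$ steps cost $2N - 1$ evaluations. The sketch below checks this by running a deterministic Heun sampler on EDM-style noise levels; the $\sigma_{\min} = 0.002$, $\sigma_{\max} = 80$, $\rho = 7$ defaults are EDM's published sampler settings and are assumptions for our teachers, although the NFE count does not depend on them. The Gaussian denoiser is a toy stand-in for a trained network.

```python
def edm_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM noise levels (Karras et al.), decreasing from sigma_max to sigma_min, plus a final 0."""
    lo, hi = sigma_min ** (1.0 / rho), sigma_max ** (1.0 / rho)
    sigmas = [(hi + i / (num_steps - 1) * (lo - hi)) ** rho for i in range(num_steps)]
    return sigmas + [0.0]


def heun_sample(predict_x0, x, sigmas):
    """Deterministic Heun (2nd-order RK) sampler; returns the sample and the NFE count."""
    nfes = 0
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        d_cur = (x - predict_x0(x, s_cur)) / s_cur
        nfes += 1
        x_next = x + (s_next - s_cur) * d_cur  # Euler predictor
        if s_next > 0:  # the 2nd-order correction is skipped on the final step to sigma = 0
            d_next = (x_next - predict_x0(x_next, s_next)) / s_next
            nfes += 1
            x_next = x + (s_next - s_cur) * 0.5 * (d_cur + d_next)
        x = x_next
    return x, nfes


def gaussian_denoiser(x, sigma):
    # Exact E[x0 | x] for x0 ~ N(0, 1); stands in for a trained teacher network.
    return x / (1.0 + sigma * sigma)


_, nfes_40 = heun_sample(gaussian_denoiser, 80.0, edm_sigmas(40))    # CIFAR-10 teacher setting
_, nfes_128 = heun_sample(gaussian_denoiser, 80.0, edm_sigmas(128))  # 64x64 ImageNet teacher setting
```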
For class-conditional 64x64 ImageNet, we use the ADM model by Dhariwal and Nichol with 296 million parameters and weights from Karras et al., with no changes. We train the model with a global batch size of 512 split across 8 A100 GPUs, using the Adam optimizer with a constant learning rate of , with a linear learning rate warm-up of 4M samples. We disable weight decay, dropout, augmentation, and gradient clipping, as we did not observe any impact of these on the final FID scores. As 64x64 ImageNet uses a different model targeted at higher resolutions, we apply a one-phase distillation from 128 timesteps (255 NFEs with Runge-Kutta) to one-step or multi-step students, for improved targets from the teacher.
A.5 More experimental results
In Tables 7 and 8, we show the results of distilling to 4-step and 8-step student models on CIFAR-10 and 64x64 ImageNet. Performance is slightly worse than BTD, possibly due to self-teaching being less efficient than supervised training, as discussed in Section 4.4.
A.6 Generated samples
We present random samples from our distilled models with varying numbers of sampling steps. As shown in Figures 2, 3 and 4, the deterministic mapping between noise and samples is mostly preserved in these distilled models. We can see a slight degradation of image quality for the 1-step student compared with students distilled to more sampling steps.
A.7 A heuristic to pick the EMA momentum parameter
In our experiments, we found that $\mu_I$, the EMA momentum parameter for our evaluation model, is key to good performance and must be tuned carefully. This motivated us to come up with a heuristic to pick it efficiently.
We hypothesize that this momentum parameter should be as high as possible while keeping the EMA unbiased with respect to model parameter values at initialization: at the end of training, the weight of the initial student parameters in the EMA should be small. Formally, training a model over $N$ steps, we obtain a sequence of model weights $(\theta_0, \theta_1, \dots, \theta_N)$, where $\theta_0$ represents the model weights at initialization. At the end of training, the resulting EMA of model weights used for inference is obtained from:

$$\bar{\theta}_N = \mu_I^N\,\theta_0 + (1-\mu_I)\sum_{i=1}^{N}\mu_I^{N-i}\,\theta_i$$

In that context, we want $\mu_I^N$ to be small at the end of training, and we call this small value $\kappa$:

$$\mu_I^N = \kappa \iff \mu_I = \kappa^{1/N}$$

We parameterize our heuristic as follows: when varying the training length, $N$ changes, but we keep $\kappa$ fixed to derive an appropriate $\mu_I$ value for the inference-time EMA. We used a single fixed $\kappa$ value across our experiments and obtained good results with it.
We compare this heuristic to a direct grid search for a fixed $\mu_I$ parameter. We carry out experiments on CIFAR-10 with 5 different $\kappa$ values. For each value, we train 1-step TRACT models over varying training lengths. We vary the number of samples $S$ with a fixed batch size $B$, so the number of training steps varies as $N = S/B$. We compare results to a grid search over 5 different $\mu_I$ values, which are also fixed across training lengths. Results are summarized in Figure 5. First, we note that training for longer always improves performance. Second, the heuristic appears more effective than a direct grid search at finding a well-performing EMA momentum parameter: all $\kappa$ values reach reasonable performance, and the best performance for a given training length is often obtained with one of the $\kappa$ values rather than with a $\mu_I$ found by the direct grid search.
A.8 TRACT with EDM teachers
In this section, we expand derivations of the TRACT distillation algorithm with EDM teacher models. We keep their VE noise schedule, an RK sampler with the teacher model, and distill them into a DDIM student network.
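With a VE noise schedule, starting from a noisy image $x_t$ and a signal-predicting network $f$, we obtain estimates for the signal and the noise as:

$$\hat{x}(x_t) = f(x_t), \qquad \hat{\epsilon}(x_t) = \frac{x_t - \hat{x}(x_t)}{\sigma_t}$$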
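The corresponding DDIM step from $t$ to any $t'$ is therefore:

$$\delta_{DDIM}(f, x_t, t, t') = \hat{x}(x_t) + \sigma_{t'}\,\hat{\epsilon}(x_t) = \hat{x}(x_t) + \frac{\sigma_{t'}}{\sigma_t}\big(x_t - \hat{x}(x_t)\big)$$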
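The ODE defined by the VE noise process (10) can be solved by numerical integration from step $t$ to step $t'$ with an RK solver, following the steps below, which define $\delta_{RK}(f, x_t, t, t')$:

$$\begin{aligned}
d_t &= \frac{x_t - f(x_t)}{\sigma_t} \\
x' &= x_t + (\sigma_{t'} - \sigma_t)\,d_t \\
d' &= \frac{x' - f(x')}{\sigma_{t'}} \\
\delta_{RK}(f, x_t, t, t') &= x_t + \frac{\sigma_{t'} - \sigma_t}{2}\,(d_t + d')
\end{aligned}$$

where $f(x')$ is evaluated at noise level $\sigma_{t'}$, and the last two steps are skipped when $\sigma_{t'} = 0$, in which case $\delta_{RK}(f, x_t, t, t') = x'$.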
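We want our student network to match the closure of the teacher steps via self-teaching. If the student network is perfect we have:

$$\delta_{DDIM}(g, x_t, t, t_i) = \tilde{x} + \frac{\sigma_{t_i}}{\sigma_t}\,(x_t - \tilde{x}) = x_{t_i}, \qquad \text{where } x_{t_i} = \delta_{DDIM}\big(g, \delta_{RK}(f, x_t, t, t-1), t-1, t_i\big)$$

Solving for $\tilde{x}$ gives the training target:

$$\tilde{x} = \frac{\sigma_t\,x_{t_i} - \sigma_{t_i}\,x_t}{\sigma_t - \sigma_{t_i}}$$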
We now have all the ingredients to write the TRACT distillation algorithm with EDM teachers, which is presented in Algorithm 2.
A.9 Knowledge Distillation
TRACT can also be used for knowledge distillation from an architecture A to another architecture B using only one additional training phase. Of particular interest is the case where the target architecture B has lower computational complexity than A. The additional training phase is a standard distillation (i.e., not time-based), in which the number of steps is held constant while the teacher and student have different architectures. For instance, a given distillation schedule leads to two possibilities, depending on where this architecture-changing phase is placed. In our experiments, we created architecture B, with 67% fewer parameters, by keeping the structure of architecture A (from the BTD paper) and reducing the number of channels in each layer from 256 to 128. We show comparable performance in Table 9 on CIFAR-10 with 1 sampling step.
It could be interesting to validate in future work whether TRACT works with heterogeneous model architectures for students and teachers, such as Transformers.