Fast Sampling of Diffusion Models via Operator Learning
Hongkai Zheng, Weili Nie, Arash Vahdat, Kamyar Azizzadenesheli, Anima Anandkumar
Introduction
Diffusion models (Sohl-Dickstein et al., 2015; Ho et al., 2020), also known as score-based generative models (Song et al., 2020b), have emerged as a powerful generative modeling framework in various areas. They have achieved state-of-the-art (SOTA) performance in many applications including image generation (Dhariwal & Nichol, 2021), molecule generation (Xu et al., 2022), audio synthesis (Kong et al., 2021) and model robustness (Nie et al., 2022). However, sampling from diffusion models requires hundreds of neural network evaluations, making them slower by orders of magnitude compared to other generative models such as generative adversarial networks (GANs) (Goodfellow et al., 2020). Accelerating sampling in diffusion models remains a challenging but important problem, especially when applying them to time-sensitive downstream applications such as AI for art and design (Ramesh et al., 2022) or generative models for decision making (Ajay et al., 2022).
Existing methods for fast sampling of diffusion models fall into two main categories: 1) training-free sampling methods (Song et al., 2020a; Lu et al., 2022) and 2) training-based sampling methods (Luhman & Luhman, 2021; Salimans & Ho, 2021; Xiao et al., 2021). Training-free methods focus on reducing the number of discretization steps from a numerical perspective while solving the stochastic differential equations (SDE) or probability flow ordinary differential equations (ODE). However, even the best-designed numerical solvers (Lu et al., 2022; Karras et al., 2022) still need 10 to 30 model evaluations to keep the approximation error small enough for acceptable sampling quality. Training-based methods, on the other hand, train a surrogate network to replace some parts of the numerical solver or even the whole solver. In particular, progressive distillation (Salimans & Ho, 2021) has made a big step towards real-time sampling (e.g., decent results with 4 steps), but it still has a sequential nature like conventional numerical solvers.
The goal of this work is to develop a fast and parallel sampling method for diffusion models with only one model evaluation. By parallel, we mean that our method can decode images at different time locations in the trajectory in parallel and hence, generate the final solution using only one model evaluation. The major challenge here arises from the difficulty of solving a complicated and large-scale differential equation, which typically requires many discrete time steps to emulate accurately from a numerical approximation perspective.
In this paper, we employ the recent advances in neural operators for solving differential equations to overcome this challenge. Neural operators (Li et al., 2020b; Kovachki et al., 2021b), especially the Fourier neural operator (FNO) (Li et al., 2020a), have shown several orders of magnitude speedup over conventional solvers. This class of models learns maps between spaces of functions and has been shown to be discretization invariant, so it can work with different resolutions of data without changing the model parameters, and it can approximate any given nonlinear continuous operator (Kovachki et al., 2021a).
The FNO allows for parallel decoding: i.e. the outputs at all locations of the trajectory can be simultaneously evaluated. This is a property that none of the previous sampling methods for diffusion models enjoy. In this work, we propose a neural operator for diffusion model sampling (DSNO) that maps the initial conditions (i.e. Gaussian distribution) to the solution trajectories and we show its effectiveness in both unconditional and class-conditional image generation.
We propose a neural operator for the fast sampling of diffusion models (DSNO) that can sample high-quality images with one model evaluation.
We introduce temporal convolution blocks parameterized in Fourier space, which can be easily combined with any existing neural architectures of diffusion models to build a neural operator backbone for DSNO. Furthermore, our proposed temporal convolution blocks are lightweight and only increase the model size by 10%.
For the first time, we propose a parallel decoding method to generate the trajectories of images using continuous function representation, which enables generation of the final solution in one model evaluation.
Our proposed DSNO achieves new state-of-the-art FID scores of 3.78 for CIFAR-10 and 7.83 for ImageNet-64 in the setting of single-step-generation of diffusion models.
Finally, we note that DSNO leverages parallel decoding temporally to generate the solution trajectory by evaluating the output function at different time steps in parallel. This is in contrast to the prior training-based methods that have a sequential nature and predict the trajectory step by step. We believe that DSNO with parallel decoding is a key step for the real-time sampling of diffusion models, potentially benefiting many interactive applications.
Background
Fourier neural operator.
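The Fourier neural operator $\mathcal{G}_\phi$ is a composition of layers of the form
$$\mathcal{G}_\phi := \mathcal{Q} \circ \sigma(W_L + \mathcal{K}_L) \circ \cdots \circ \sigma(W_1 + \mathcal{K}_1) \circ \mathcal{P}, \tag{4}$$
where the lifting operator $\mathcal{P}$, projection operator $\mathcal{Q}$, and residual connections $W_l$ are pointwise operators parameterized with neural networks, and $\sigma$ is a fixed nonlinear activation function. $\mathcal{K}_l$ is an integral kernel operator parameterized in Fourier space such that for a given $v_l$, an input function to the $l$'th layer, we have
$$(\mathcal{K}_l v_l)(t) = \mathcal{F}^{-1}\big(R_l \cdot (\mathcal{F} v_l)\big)(t), \quad \forall t \in D, \tag{5}$$
where $\mathcal{F}$ and $\mathcal{F}^{-1}$ are the Fourier transform and inverse Fourier transform on the domain $D$, and $R_l$ is a trainable parameter that parameterizes a kernel function in Fourier space. Given an input function $a$, we first apply the pointwise lifting operator $\mathcal{P}$ that expands the co-dimension of the input function $a$, followed by $L$ layers of global integral operators, each accompanied by the pointwise nonlinearity $\sigma$. The result of the global integration layers is passed to the local and pointwise projection layer $\mathcal{Q}$ to compute the output function. This architecture is shown to possess the crucial discretization invariance and universal approximation properties of universal operators (Kovachki et al., 2021a, b).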
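To make the Fourier-space kernel layer concrete, the following is a minimal NumPy sketch of a single operator of the form "inverse transform of (weights times transform of the input)" along one periodic axis. The function name `fourier_kernel_layer`, the weight layout `(k_max, channels_out, channels_in)`, and the truncation to the lowest `k_max` modes are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def fourier_kernel_layer(v, R):
    """Apply F^{-1}(R . F v) along the last axis of v.

    v: real array, shape (channels_in, n_points), uniform samples of the input.
    R: complex array, shape (k_max, channels_out, channels_in); one channel-mixing
       matrix per retained frequency mode (hypothetical layout).
    """
    n_points = v.shape[-1]
    k_max = R.shape[0]
    v_hat = np.fft.rfft(v, axis=-1)                      # (channels_in, n_points // 2 + 1)
    out_hat = np.zeros((R.shape[1], v_hat.shape[-1]), dtype=complex)
    # Mix channels mode by mode; modes at index k_max and above are set to zero.
    out_hat[:, :k_max] = np.einsum("koi,ik->ok", R, v_hat[:, :k_max])
    return np.fft.irfft(out_hat, n=n_points, axis=-1)    # (channels_out, n_points)
```

Because `R` is indexed by frequency mode rather than by grid point, the same weights can be applied to inputs sampled at different resolutions, which is the mechanism behind the discretization invariance mentioned above.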
Learning the trajectory with neural operator
The class of neural operators defined in equation 4 approximates the solution map of the diffusion ODE, i.e., the mapping from the initial condition to the probability flow trajectory, arbitrarily well.
This implies that the proposed architecture has the required capacity to learn to output the continuous time probability flow trajectory in one model call.
Temporal convolution block in Fourier space.
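Given an input function $u$ on the temporal domain $D$, the temporal convolution layer $\mathcal{T}$ is defined as
$$(\mathcal{T} u)(t) = \sigma\big((\mathcal{K} u)(t)\big) + u(t),$$
where $\sigma$ is a point-wise nonlinear function, and $\mathcal{K}$ is a Fourier convolution operator defined in equation 5 parameterized by $R$. Note that our proposed temporal convolution layer differs slightly from the FNO layer given in equation 4. Specifically, we move the nonlinear activation function right after the Fourier convolution operator and replace the linear pointwise operator with an identity shortcut, which preserves the high-frequency information without extra cost and also leads to a better optimization landscape (He et al., 2016). We did not observe any advantage from using a more general linear layer. The identity map is shown to be sufficient and more attractive because it is computationally efficient. Furthermore, we note that, by the convolution theorem, with $\mathcal{F}$ denoting the Fourier transform on $D$, we have
$$(\mathcal{K} u)(t) = \mathcal{F}^{-1}\big(R \cdot (\mathcal{F} u)\big)(t) = \int_D (\mathcal{F}^{-1} R)(t - \tau)\, u(\tau)\, \mathrm{d}\tau. \tag{8}$$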
Notably, the integral form in equation 8 inherently possesses a structural similarity to the core diffusion process in equation 3, meaning that the temporal convolution layer implicitly parameterizes the ODE solution trajectory.
Accordingly, the Fourier transform and its inverse are realized by the fast Fourier transform algorithm. Figure 1 demonstrates the implementation details of the temporal convolution layers. Note that the temporal convolution layer only operates over the temporal dimension and the hidden feature channel dimension, and thus treats the pixel dimension the same as the batch dimension. In other words, the co-dimension of the function in the above example corresponds to the number of channels in practice.
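The description above (Fourier convolution over time, activation applied right after it, identity shortcut, pixels treated like batch entries) can be sketched in NumPy as follows. The tensor layout `(batch, channels, time, height, width)`, the ReLU activation, and the name `temporal_conv_block` are assumptions made for illustration; the actual implementation is the one shown in Figure 1 and in the released code.

```python
import numpy as np

def temporal_conv_block(x, R):
    """Sketch of a temporal convolution block: out = x + relu(K x), K acting over time.

    x: real array, shape (batch, channels, time, height, width).
    R: complex array, shape (k_max, channels, channels), Fourier-space weights
       for the lowest k_max temporal modes (hypothetical layout).
    """
    b, c, m, h, w = x.shape
    k_max = R.shape[0]
    # Fold pixel locations into the batch axis: only time and channels are mixed.
    u = x.transpose(0, 3, 4, 1, 2).reshape(b * h * w, c, m)
    u_hat = np.fft.rfft(u, axis=-1)                       # FFT over the temporal axis
    out_hat = np.zeros_like(u_hat)
    out_hat[..., :k_max] = np.einsum("koi,nik->nok", R, u_hat[..., :k_max])
    conv = np.fft.irfft(out_hat, n=m, axis=-1)
    out = u + np.maximum(conv, 0.0)                       # identity shortcut + activation
    return out.reshape(b, h, w, c, m).transpose(0, 3, 4, 1, 2)
```

In this sketch a change at one pixel location never affects the output at another pixel location, which reflects the statement that the block operates only over the temporal and channel dimensions.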
Architecture of DSNO.
As demonstrated in Figure 1, the architecture of DSNO is built on top of any existing architecture of diffusion models by adding our proposed temporal convolution layers to each level of the U-Net structure. The dark blue blocks are the modules in the existing diffusion model backbone, which treat the temporal dimension the same as the batch dimension and only work on the pixel and channel dimensions. The yellow blocks are the Fourier temporal convolution blocks, which operate only on the temporal and channel dimensions. Therefore, our model is highly parallelizable and adds minimal computational cost to the original backbone. Again, suppose the temporal domain is discretized into a finite set of time points. DSNO takes as input the time embeddings at these times and the initial condition. The feature map of the first convolution layer is repeated once per time point over the temporal dimension to form the initial feature at the different times. Each feature representation is combined with the corresponding time embedding in the following ResNet blocks.
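The data flow at the input of the network can be illustrated with a short shape-tracing sketch. Here the time embedding is simply added to the repeated feature map; that additive combination, the embedding shape `(num_times, channels)`, and the name `init_temporal_features` are assumptions for illustration, since the ResNet blocks of a real backbone combine the embedding in their own way.

```python
import numpy as np

def init_temporal_features(feat, time_emb):
    """Repeat first-layer features over time and fold time into the batch axis.

    feat: (batch, channels, height, width), output of the first convolution layer.
    time_emb: (num_times, channels), one embedding per discretized time (assumed shape).
    Returns: (batch * num_times, channels, height, width), the layout seen by the
    spatial backbone modules, which treat time like batch.
    """
    b, c, h, w = feat.shape
    m = time_emb.shape[0]
    rep = np.broadcast_to(feat[:, None], (b, m, c, h, w))   # same feature at every time
    rep = rep + time_emb[None, :, :, None, None]            # per-time embedding (additive, assumed)
    return rep.reshape(b * m, c, h, w)
```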
Training of DSNO
Training DSNO is a standard operator learning setting. The training objective is a weighted integral of the error:
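$$\min_{\phi}\ \mathbb{E}_{x_T} \int_{D} \lambda(t)\, \big\| \mathcal{G}_\phi(x_T)(t) - \mathcal{G}^\dagger(x_T)(t) \big\|\, \mathrm{d}t,$$
where $\phi$ is the parameter of DSNO, $\lambda(t)$ is the weighting function, $x_T$ is the initial condition, $\mathcal{G}^\dagger$ denotes the solution operator of the diffusion ODE, and $\|\cdot\|$ is a norm. In practice, we optimize over $\phi$ to minimize the empirical risk, similar to Kovachki et al. (2021b):
$$\min_{\phi}\ \frac{1}{N} \sum_{j=1}^{N} \sum_{i=1}^{M} \lambda(t_i)\, \big\| \mathcal{G}_\phi(x_T^{(j)})(t_i) - \mathcal{G}^\dagger(x_T^{(j)})(t_i) \big\|,$$
where $\{t_i\}_{i=1}^{M}$ are discrete points in the temporal domain, $\{x_T^{(j)}\}_{j=1}^{N}$ are sampled initial conditions, and the target trajectories $\mathcal{G}^\dagger(x_T^{(j)})(t_i)$ can be generated from any existing solver or sampling method.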
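As a rough illustration of the empirical objective, the sketch below computes a time-weighted error between predicted and reference trajectories. The L1 norm, the array layout `(batch, num_times, ...)`, and the function name `dsno_empirical_loss` are assumptions; the text above leaves the norm generic.

```python
import numpy as np

def dsno_empirical_loss(pred, target, weights):
    """Time-weighted trajectory error averaged over a batch.

    pred, target: arrays of shape (batch, num_times, ...), predicted and
                  solver-generated trajectories at the same discrete times.
    weights: array of shape (num_times,), the weighting function at those times.
    """
    b, m = pred.shape[:2]
    diff = (pred - target).reshape(b, m, -1)
    per_time = np.abs(diff).mean(axis=-1)              # L1 error per sample and time (assumed norm)
    return float((per_time * weights).sum(axis=1).mean())
```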
Parallel decoding
As shown in the top two yellow blocks in Figure 1, the proposed Fourier temporal convolution block can predict images at different times in parallel. Given any input function, we can compute its Fourier coefficients and then evaluate the inverse Fourier transform at all times in parallel to generate the outputs for different times at once. In addition, the other modules of DSNO treat the temporal dimension like the batch dimension and can run in parallel for different times. Therefore, DSNO is capable of efficient parallel decoding. Note that the effectiveness of our parallel decoding is based on the fact that the solutions of the diffusion ODE at different times are conditionally independent given the initial condition. Parallel decoding has shown its efficiency in transformer-based models (Chang et al., 2023) and language models (Ghazvininejad et al., 2019) for discrete token generation in the spatial domain. DSNO is the first parallel decoding method for continuous diffusion ODE trajectories, operating in the temporal domain.
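The idea of evaluating the inverse Fourier transform at many times at once can be sketched as a single matrix product between Fourier coefficients and a basis evaluated at the query times. The sketch assumes a real-valued, band-limited function on a periodic domain $[0, 1)$ whose coefficients come from `np.fft.rfft` over `n_train` uniform samples; the periodic domain and the name `decode_at_times` are illustrative assumptions.

```python
import numpy as np

def decode_at_times(coeffs, n_train, query_times):
    """Evaluate a band-limited real function at arbitrary times in one matrix product.

    coeffs: (channels, k_max) complex rfft coefficients from n_train uniform
            samples on [0, 1); requires k_max - 1 < n_train / 2 (no Nyquist mode).
    query_times: (num_queries,) times in [0, 1) at which to evaluate.
    """
    k = np.arange(coeffs.shape[-1])
    basis = np.exp(2j * np.pi * np.outer(k, query_times))   # (k_max, num_queries)
    # Non-zero modes count twice because their conjugate partners are implicit.
    scale = np.where(k == 0, 1.0, 2.0) / n_train
    return np.real((coeffs * scale) @ basis)                 # (channels, num_queries)
```

All query times are handled by the same product, so no time step has to wait for another; the query grid also does not have to coincide with the grid used to compute the coefficients.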
Compact power spectrum.
We examine the spectrum of the probability flow ODE trajectories generated from several publicly available pre-trained diffusion models in the literature, and observe that the ODE trajectories always have a compact energy spectrum over the temporal dimension. See more details in Appendix A.1. The smoothness of the diffusion ODE trajectory means the high-frequency modes do not contribute much to the learning objective. Therefore, DSNO built upon the stacks of Fourier temporal convolution layers can model the underlying solution operator of diffusion ODEs more efficiently with a relatively small number of discretization steps.
Experiments
In our experiments, we examine the proposed method on both unconditional and conditional image generation tasks. We show that our method dramatically accelerates the sampling process of diffusion models, compared to existing fast sampling methods including both training-free and training-based approaches. Our code is available at https://github.com/devzhk/DSNO-pytorch.
We use the Fréchet inception distance (FID) (Heusel et al., 2017) to evaluate the quality of generated images. The FID score is computed by comparing 50,000 generated images against the corresponding reference statistics of the dataset. We use the ADM’s TensorFlow evaluation suite (Dhariwal & Nichol, 2021) and EDM’s evaluation code (Karras et al., 2022) to compute FID-50K with the same reference statistics. We also report Recall (Kynkäänniemi et al., 2019) as the secondary metric of mode coverage for the experiments on ImageNet-64.
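For readers unfamiliar with the metric, FID is the Fréchet distance between two Gaussians fitted to Inception features of generated and reference images: $\|\mu_1 - \mu_2\|^2 + \mathrm{Tr}\big(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\big)$. The sketch below computes only this distance from given means and covariances; feature extraction is omitted, and the function name `frechet_distance` is illustrative rather than taken from the evaluation suites cited above.

```python
import numpy as np

def frechet_distance(mu1, cov1, mu2, cov2):
    """Fréchet distance between N(mu1, cov1) and N(mu2, cov2).

    Assumes cov1 and cov2 are symmetric positive semi-definite.
    """
    diff = mu1 - mu2
    # Tr((cov1 cov2)^{1/2}) equals the sum of square roots of the eigenvalues
    # of cov1 @ cov2, which are real and non-negative for PSD inputs.
    eigvals = np.linalg.eigvals(cov1 @ cov2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(cov1) + np.trace(cov2) - 2.0 * tr_sqrt)
```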
Unconditional generation: CIFAR-10
We first generate 1 million trajectories with 512-step DDIM (Song et al., 2020a) using the pre-trained diffusion model proposed by Salimans & Ho (2021), and use them to train DSNO. The FID score of the training set is 2.51, computed over the first 50k data points in the training set.
Sampling quality and speed.
Table 1 compares the proposed DSNO trained with a temporal resolution of 4 against both training-based and training-free sampling methods in terms of FID and the corresponding number of model evaluations. DSNO clearly outperforms all the baselines with only one model evaluation and even achieves a better FID score than 2-step progressive distillation models. Furthermore, we compare the cost of a single forward pass of both DSNO and the original backbone on a V100 in a standard AWS p3.2xlarge instance. (Progressive distillation only has a JAX implementation. We implement its backbone in PyTorch and port the pre-trained weights from the official JAX checkpoint so that we can make a fair speed comparison within the same framework.) For the speed test, we do 20 warm-up runs to avoid the potential inconsistency arising from the built-in cuDNN autotuner. Since the time cost of progressive distillation grows linearly with the number of sampling steps, we can easily calculate the speedup of DSNO over progressive distillation from Table 3. DSNO is 2.6 times faster than the 4-step progressive distillation model and 1.3 times faster than the 2-step progressive distillation model. Compared to hybrid models that combine GAN and diffusion models, DSNO achieves comparable performance with at most one-fourth the number of model evaluations.
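The warm-up protocol described above can be sketched as a small timing helper. This is a generic illustration, not the benchmarking script used for Table 3: the name `benchmark` and the `repeats` default are assumptions, and on a GPU the timed callable would additionally need to block until the device finishes its work (for example by synchronizing the device) for the wall-clock numbers to be meaningful.

```python
import time

def benchmark(fn, warmup=20, repeats=50):
    """Return the mean wall-clock time of fn() in seconds after warm-up runs."""
    for _ in range(warmup):          # warm-up runs are executed but not timed
        fn()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats
```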
Conditional generation: ImageNet-64
We generate 2.3 million trajectories with 16-step progressive distillation (Salimans & Ho, 2021) using the pre-trained diffusion model from its official code base. The FID score of the generated training set is 2.70, computed over the first 50k training data points.
Sampling quality and speed.
Table 2 compares DSNO trained with a temporal resolution of 4 against recent fast sampling methods for diffusion models. DSNO clearly outperforms the 1-step progressive distillation model and achieves an FID comparable to that of the 2-step progressive distillation model with only one model evaluation. From Table 3, DSNO has a 1.7 times speedup over progressive distillation. The recall of DSNO is comparable to ADM’s, showing that DSNO inherits the original diffusion model’s diversity/mode coverage as it learns to solve the probability flow ODE.
Trajectory prediction and reconstruction.
Figure 2 compares the trajectories predicted by DSNO and the original ODE solver for the same random seed with a temporal resolution of 4. We see that the trajectory predicted by DSNO closely matches the ground-truth ODE trajectory, which demonstrates the effectiveness of DSNO with parallel decoding. In addition, Figure 3 shows random samples from DSNO and the original pre-trained diffusion model with the same random seed. It is clear that the mapping from Gaussian noise to the output image is well preserved.
Ablation study
We first investigate the impact of temporal convolution by comparing the performance of architectures with and without temporal convolution blocks. All the other settings are kept the same, namely temporal resolution 4, the quadratic time discretization scheme, the square root of the SNR as the weighting function, and batch size 256. As reported in Table 4, the temporal convolution design is crucial to DSNO’s performance because its nature as a kernel integral operator provides a better inductive bias for modeling the trajectory in time.
Loss weighting.
The loss weighting function used in the training objective of diffusion models (Ho et al., 2020; Song et al., 2020b) typically assigns more weight to small times, which is important for training diffusion models. We also adopt such a weighting function since it is generally harder to control the error at small times. We observe that such a loss weighting function benefits the training of DSNO. As reported in Table 5, using the square root of the SNR as the weighting function slightly improves the FID by 0.35.
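To illustrate why a square-root-of-SNR weighting emphasizes small times, the sketch below evaluates it under an assumed cosine noise schedule with signal scale $\cos(\pi t / 2)$ and noise scale $\sin(\pi t / 2)$, so that the square root of the SNR is their ratio. The schedule is an assumption made for this example, and the name `sqrt_snr_weight` is hypothetical.

```python
import numpy as np

def sqrt_snr_weight(t):
    """sqrt(SNR(t)) = alpha_t / sigma_t for an assumed cosine schedule on t in (0, 1]."""
    t = np.asarray(t, dtype=float)
    alpha = np.cos(0.5 * np.pi * t)   # signal scale
    sigma = np.sin(0.5 * np.pi * t)   # noise scale
    return alpha / sigma
```

Under this schedule the weight decreases monotonically in $t$, so errors near the data end of the trajectory are penalized more heavily than errors near the noise end.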
Time discretization scheme.
The choice of temporal discretization is important to the performance of numerical solvers. Small changes to the time discretization scheme can lead to very different sample quality, as shown by Karras et al. (2022) and Zhang & Chen (2022). DSNO also needs a way to discretize the temporal domain. Here we consider the two most common time discretization schemes in the literature: uniform time steps and quadratic time steps. As shown in Table 5, the quadratic time step is slightly better than the uniform time step, by 0.12 FID, showing that DSNO is not sensitive to the time discretization schemes used in existing solvers and can work well with different solvers.
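The two schemes can be compared with a small helper. The quadratic grid below is uniform in $\sqrt{t}$, which is one common way to realize quadratic spacing; whether it matches the exact grid used in the experiments is an assumption, as are the name `time_grid` and its arguments.

```python
import numpy as np

def time_grid(num_steps, t_min, t_max, scheme="quadratic"):
    """Return num_steps increasing times in [t_min, t_max]."""
    s = np.linspace(0.0, 1.0, num_steps)
    if scheme == "uniform":
        return t_min + (t_max - t_min) * s
    if scheme == "quadratic":
        # Uniform in sqrt(t), so the points cluster near t_min.
        lo, hi = np.sqrt(t_min), np.sqrt(t_max)
        return (lo + (hi - lo) * s) ** 2
    raise ValueError(f"unknown scheme: {scheme}")
```

Both grids share the same endpoints; the quadratic grid places its interior points closer to small times, where the error is generally harder to control.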
Temporal resolution.
We study the effect of temporal resolution (i.e., the number of discretization steps), given the square root of the SNR as the weighting function and the quadratic time discretization scheme. As reported in Table 6, the FID improves as we increase the temporal resolution. Since a higher temporal resolution introduces more supervision into the training, it is reasonable to expect better FID scores. However, higher resolution also results in higher computation costs. Since increasing the resolution from 4 to 8 only provides a marginal benefit (due to the compact spectrum we observed in Appendix A.1), one may use temporal resolution 4 for better efficiency.
Loss function.
Related work
ODE-based samplers are much more widely used in practice (Rombach et al., 2022) than SDE-based methods because they can take large time steps by leveraging some useful structures of the underlying ODE such as semi-linear structure and the form of exponentially weighted integral (Lu et al., 2022; Zhang & Chen, 2022). Existing works (Song et al., 2021; Bao et al., 2021; Zhang & Chen, 2022; Dockhorn et al., 2022) have greatly reduced the number of discretization steps to 10-50 in time while keeping the approximation error small to generate high-quality samples. The exponentially weighted integral structure of the solution trajectory revealed by prior works also inspired our design of the temporal convolution block.
Operator learning for solving PDEs.
Neural operators are deep learning models that are designed for mappings between function spaces, i.e., continuous functions (Li et al., 2020b; Kovachki et al., 2021a). They are widely deployed as the de facto deep learning models in scientific computing when dealing with partial differential equations (PDE). Among these methods, Fourier neural operator stands out and is one of the most efficient machine learning methods for scientific computing problems involving PDE (Yang et al., 2021; Wen et al., 2022). It is shown to possess the crucial discretization invariance and universal approximation properties of universal operators (Kovachki et al., 2021a, b), which motivates our design of the temporal convolution block in our method.
Training-based sampling.
Training-based methods typically train a neural network surrogate to replace some parts of the numerical solver or even the whole solver. This category includes methods from diverse perspectives such as knowledge distillation (Luhman & Luhman, 2021; Salimans & Ho, 2021), learning the noise schedule (Lam et al., 2021; Watson et al., 2021), and learning the reverse covariance (Bao et al., 2022), all of which require extra training. Training-based methods usually work in the few-step regime with fewer than 10 steps. Luhman & Luhman (2021) is the first work to get decent sample quality on CIFAR-10 with one model evaluation, but it suffers from overfitting and its sampling quality drops dramatically compared to the original sampling methods of diffusion models. The current SOTA, progressive distillation (Salimans & Ho, 2021), reduces the number of steps down to 4-8 without losing much sample quality. However, it has the same issue as knowledge distillation in the limit of one function evaluation. Some other methods (Xiao et al., 2021; Vahdat et al., 2021; Zheng et al., 2022) combine diffusion models with other generative models such as GAN and VAE to enable fast sampling.
Conclusion and discussion
In this paper, we propose diffusion model sampling with neural operator (DSNO) that maps the initial condition, i.e., Gaussian distribution, to the continuous-time solution trajectory of the reverse diffusion process. To better model the temporal correlations along the trajectory, we introduce temporal convolution layers into the given diffusion model backbone. Experiments show that our method achieves the SOTA FID score of 3.78 for CIFAR-10 and 7.83 for ImageNet-64 with only one model evaluation. Our method is a big step toward real-time sampling of diffusion models, which can potentially benefit many time-sensitive applications of diffusion models.
Acknowledgements
We would like to thank the reviewers and the area chair for their constructive comments. Anima Anandkumar is supported in part by Bren professorship. This work was done partly during Hongkai Zheng’s internship at NVIDIA.
References
Appendix A
The discrete-time Fourier transform of the signal $x[n]$ with period $N$ is given by
$$X[k] = \sum_{n=0}^{N-1} x[n]\, e^{-i 2\pi k n / N},$$
where $k = 0, \dots, N-1$ is called the frequency mode. Let $\Delta t$ be the time step; then $f_k = k / (N \Delta t)$ is the frequency. The spectrum is defined as the product of the Fourier transform of $x$ with its conjugate:
$$S[k] = X[k]\, X^{*}[k],$$
where $X^{*}$ is the complex conjugate. In practice, the statistics are computed over all pixel locations and channels of randomly generated trajectories, and the sampling frequency is 1000 Hz to avoid aliasing. Figure 4 visualizes the energy spectrum of ODE trajectories sampled from the diffusion model "DDPM++ cont. (VP)" trained by Song et al. (2020b) on CIFAR-10. We observe that most power concentrates in the regime where the frequency mode is less than 5.
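A spectrum of this kind can be estimated with a few lines of NumPy. The sketch below is a generic illustration rather than the script behind Figure 4: it assumes trajectories stored as an array of shape `(num_trajectories, num_times, ...)` sampled with a uniform time step `dt`, and the name `temporal_power_spectrum` is hypothetical.

```python
import numpy as np

def temporal_power_spectrum(traj, dt):
    """Mean power per temporal frequency mode.

    traj: real array, shape (num_trajectories, num_times, ...); trailing axes
          (e.g. channels and pixels) are averaged over.
    dt: time step between consecutive samples (1 / sampling frequency).
    Returns (freqs, mean_power), both of length num_times // 2 + 1.
    """
    n = traj.shape[1]
    x_hat = np.fft.rfft(traj, axis=1)                   # transform over the temporal axis
    power = (x_hat * np.conj(x_hat)).real               # X times its complex conjugate
    mean_power = power.reshape(power.shape[0], power.shape[1], -1).mean(axis=(0, 2))
    freqs = np.fft.rfftfreq(n, d=dt)
    return freqs, mean_power
```

With 1000 samples at a 1000 Hz sampling frequency (`dt = 1e-3`), frequency mode `k` corresponds to `k` Hz, so a compact spectrum shows up as power concentrated in the first few entries of `mean_power`.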
A.2 Background: neural operators
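Let $\mathcal{A}$ and $\mathcal{U}$ be two Banach spaces and $\mathcal{G}^\dagger: \mathcal{A} \to \mathcal{U}$ be a non-linear map. Suppose we have a finite collection of data $\{(a_i, u_i)\}_{i=1}^{N}$ where $a_i \sim \mu$ are i.i.d. samples from the distribution $\mu$ supported on $\mathcal{A}$ and $u_i = \mathcal{G}^\dagger(a_i)$. Neural operators aim to learn $\mathcal{G}_\phi$ parameterized by $\phi$ to approximate $\mathcal{G}^\dagger$ from the observed data by minimizing the empirical risk given by
$$\min_{\phi}\ \frac{1}{N} \sum_{i=1}^{N} \big\| u_i - \mathcal{G}_\phi(a_i) \big\|_{\mathcal{U}}.$$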
The architecture of neural operators is constructed as a stack of kernel integration layers where the kernel function is parameterized by learnable weights. This architecture utilizes the convolution theorem on abelian groups. Among different neural operator architectures, Fourier neural operator (Li et al., 2020a) stands out and is one of the most efficient machine learning methods for scientific computing problems involving PDE (Yang et al., 2021; Wen et al., 2022). It is shown to possess the crucial discretization invariance and universal approximation properties of universal operators (Kovachki et al., 2021a, b).
A.3 Extended set of generated samples
We provide an extended set of randomly generated samples from our ImageNet-64 model in Figure 5.
A.4 Generalization to different resolution
Figure 6 visualizes the trajectory predicted by DSNO at temporal resolution 8 on ImageNet-64, although the model is trained at temporal resolution 4. Although the resulting trajectories do not look perfectly smooth, they still demonstrate the generalization ability of DSNO to unseen temporal resolutions.
A.5 Further discussion
There are several directions we leave as future work. First, guided sampling of diffusion models is widely used in various applications, but accelerating guided sampling is also more challenging (Meng et al., 2022). Adapting DSNO to sampling from guided diffusion models is an interesting next step. Second, the temporally continuous output of DSNO provides another level of flexibility compared to distillation-based methods and is readily available for applications such as DiffPure (Nie et al., 2022) that require fast forward/backward sampling from diffusion models at various temporal locations. DSNO could potentially reduce the inference time in those applications. We leave the exploration of those applications to future work. Last but not least, transformer-based architectures have shown promising capacity for diffusion models (Peebles & Xie, 2022; Bao et al., 2023) in high-resolution image generation. It is natural to integrate our temporal convolution layers into these diffusion transformers, as the temporal blocks operate solely on the temporal dimension regardless of how the pixel space is modeled. The resulting architecture could also potentially serve as a new design for other problems where the dynamics are continuous in time.
Reducing data collection cost with advanced solvers.
While we primarily use the DDIM solver to collect data for a fair comparison in this paper, it is worth noting that advanced numerical solvers like the DPM solver (Lu et al., 2022) can approximate the solution operator at lower computational cost, which will greatly speed up our training data generation process. Our final implementation includes examples of using DPM solvers in our GitHub repository https://github.com/devzhk/DSNO-pytorch.