Improved Techniques for Training Consistency Models
Yang Song, Prafulla Dhariwal
Introduction
Consistency models (song2023consistency) are an emerging family of generative models that produce high-quality samples using a single network evaluation. Unlike GANs (goodfellow2014generative), consistency models are not trained with adversarial optimization and thus sidestep the associated training difficulty. Compared to score-based diffusion models (sohl2015deep; song2019generative; song2020improved; ho2020denoising; song2021scorebased), consistency models do not require numerous sampling steps to generate high-quality samples. They are trained to generate samples in a single step, but still retain important advantages of diffusion models, such as the flexibility to exchange compute for sample quality through multistep sampling, and the ability to perform zero-shot data editing.
We can train consistency models using either consistency distillation (CD) or consistency training (CT). The former requires pre-training a diffusion model and distilling the knowledge therein into a consistency model. The latter allows us to train consistency models directly from data, establishing them as an independent family of generative models. Previous work (song2023consistency) demonstrates that CD significantly outperforms CT. However, CD adds computational overhead to the training process since it requires learning a separate diffusion model. Additionally, distillation limits the sample quality of the consistency model to that of the diffusion model. To avoid the downsides of CD and to position consistency models as an independent family of generative models, we aim to improve CT to either match or exceed the performance of CD.
In previous work (song2023consistency), both CD and CT rely on learned metrics such as the Learned Perceptual Image Patch Similarity (LPIPS) (zhang2018unreasonable) for optimal sample quality. However, depending on LPIPS has two primary downsides. First, it can bias evaluation, since the same ImageNet dataset (deng2009imagenet) is used to train both LPIPS and the Inception network in Fréchet Inception Distance (FID) (heusel2017gans), which is the predominant metric for image quality. As analyzed in kynkanniemi2023the, improvements in FID can come from accidental leakage of ImageNet features from LPIPS, making FID scores look better than the actual sample quality warrants. Second, learned metrics require pre-training auxiliary networks for feature extraction. Training with these metrics requires backpropagating through extra neural networks, which increases the demand for compute.
To tackle these challenges, we introduce improved techniques for CT that not only surpass CD in sample quality but also eliminate the dependence on learned metrics like LPIPS. Our techniques are motivated by both theoretical analysis and comprehensive experiments on the CIFAR-10 dataset (krizhevsky2014cifar). Specifically, we perform an in-depth study of the empirical impact of weighting functions, noise embeddings, and dropout in CT. Additionally, we identify an overlooked flaw in prior theoretical analysis for CT and propose a simple fix by removing the Exponential Moving Average (EMA) from the teacher network. We adopt Pseudo-Huber losses from robust statistics to replace LPIPS. Furthermore, we study how sample quality improves as the number of discretization steps increases, and use the insights to propose a simple but effective curriculum for total discretization steps. Finally, we propose a new schedule for sampling noise levels in the CT objective based on lognormal distributions.
Taken together, these techniques allow CT to attain FID scores of 2.51 and 3.25 on CIFAR-10 and ImageNet $64\times 64$, respectively, in one sampling step. These scores not only surpass CD but also represent improvements of $3.5\times$ and $4\times$ over previous CT methods. Furthermore, they significantly outperform the best few-step distillation techniques for diffusion models, even without the need for distillation. With two-step generation, we achieve improved FID scores of 2.24 and 2.77 on CIFAR-10 and ImageNet $64\times 64$, surpassing the scores from CD in both one-step and two-step settings. These results rival many top-tier diffusion models and GANs, showcasing the strong promise of consistency models as a new independent family of generative models.
Consistency models
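Consistency models build on the probability flow ordinary differential equation (ODE) of song2021scorebased. Let $p_\mathrm{data}(\mathbf{x})$ denote the data distribution, and let $p_\sigma(\mathbf{x})$ denote the distribution obtained by adding zero-mean Gaussian noise with standard deviation $\sigma$ to the data. In the formulation of Karras2022edm, the probability flow ODE is
$$\frac{\mathrm{d}\mathbf{x}_\sigma}{\mathrm{d}\sigma} = -\sigma \nabla_{\mathbf{x}} \log p_\sigma(\mathbf{x}_\sigma), \qquad \sigma \in [\sigma_{\min}, \sigma_{\max}], \tag{1}$$
where the term $\nabla_{\mathbf{x}} \log p_\sigma(\mathbf{x})$ is known as the score function of $p_\sigma(\mathbf{x})$ (song2019sliced; song2019generative; song2020improved; song2021scorebased). Here $\sigma_{\min}$ is a small positive value such that $p_{\sigma_{\min}}(\mathbf{x}) \approx p_\mathrm{data}(\mathbf{x})$, introduced to avoid numerical issues in ODE solving. Meanwhile, $\sigma_{\max}$ is sufficiently large so that $p_{\sigma_{\max}}(\mathbf{x}) \approx \mathcal{N}(\mathbf{0}, \sigma_{\max}^2 \mathbf{I})$. Following Karras2022edm; song2023consistency, we adopt $\sigma_{\min} = 0.002$ and $\sigma_{\max} = 80$ throughout the paper. Crucially, solving the probability flow ODE from noise level $\sigma_1$ to $\sigma_2$ allows us to transform a sample $\mathbf{x}_{\sigma_1} \sim p_{\sigma_1}(\mathbf{x})$ into $\mathbf{x}_{\sigma_2} \sim p_{\sigma_2}(\mathbf{x})$.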
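The ODE in Eq. 1 establishes a bijective mapping between a noisy data sample $\mathbf{x}_\sigma \sim p_\sigma(\mathbf{x})$ and $\mathbf{x}_{\sigma_{\min}} \sim p_{\sigma_{\min}}(\mathbf{x}) \approx p_\mathrm{data}(\mathbf{x})$. This mapping, denoted as $f^*: (\mathbf{x}_\sigma, \sigma) \mapsto \mathbf{x}_{\sigma_{\min}}$, is termed the consistency function. By its very definition, the consistency function satisfies the boundary condition $f^*(\mathbf{x}, \sigma_{\min}) = \mathbf{x}$. A consistency model, which we denote by $f_\theta(\mathbf{x}, \sigma)$, is a neural network trained to approximate the consistency function $f^*(\mathbf{x}, \sigma)$. To meet the boundary condition, we follow song2023consistency to parameterize the consistency model as
$$f_\theta(\mathbf{x}, \sigma) = c_\mathrm{skip}(\sigma)\, \mathbf{x} + c_\mathrm{out}(\sigma)\, F_\theta(\mathbf{x}, \sigma), \tag{2}$$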
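where $F_\theta(\mathbf{x}, \sigma)$ is a free-form neural network, while $c_\mathrm{skip}(\sigma)$ and $c_\mathrm{out}(\sigma)$ are differentiable functions such that $c_\mathrm{skip}(\sigma_{\min}) = 1$ and $c_\mathrm{out}(\sigma_{\min}) = 0$.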
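The boundary condition can be checked numerically with a small sketch. The specific forms of `c_skip` and `c_out` below, and the value `SIGMA_DATA = 0.5`, are assumptions for illustration: they are one choice that satisfies the two constraints, not something this section specifies. `toy_net` is a hypothetical stand-in for the free-form network, and scalars are used in place of images.

```python
import math

SIGMA_MIN = 0.002   # smallest noise level
SIGMA_DATA = 0.5    # assumed data standard deviation (illustrative)


def c_skip(sigma):
    """Skip coefficient; equals 1 at sigma = SIGMA_MIN."""
    return SIGMA_DATA**2 / ((sigma - SIGMA_MIN) ** 2 + SIGMA_DATA**2)


def c_out(sigma):
    """Output coefficient; equals 0 at sigma = SIGMA_MIN."""
    return SIGMA_DATA * (sigma - SIGMA_MIN) / math.sqrt(SIGMA_DATA**2 + sigma**2)


def consistency_model(free_form_net, x, sigma):
    """Skip-connection parameterization: c_skip * x + c_out * F(x, sigma)."""
    return c_skip(sigma) * x + c_out(sigma) * free_form_net(x, sigma)


def toy_net(x, sigma):
    """Hypothetical free-form network; any function of (x, sigma) would do."""
    return math.tanh(x) - 0.1 * sigma
```

Because the output coefficient vanishes at the smallest noise level, `consistency_model(toy_net, x, SIGMA_MIN)` returns `x` regardless of what `toy_net` computes, so the boundary condition holds by construction rather than by training.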
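To train the consistency model, we discretize the probability flow ODE using a sequence of noise levels $\sigma_{\min} = \sigma_1 < \sigma_2 < \cdots < \sigma_N = \sigma_{\max}$, where we follow Karras2022edm; song2023consistency in setting $\sigma_i = \big(\sigma_{\min}^{1/\rho} + \frac{i-1}{N-1}(\sigma_{\max}^{1/\rho} - \sigma_{\min}^{1/\rho})\big)^\rho$ for $i \in \llbracket 1, N \rrbracket$, and $\rho = 7$, where $\llbracket a, b \rrbracket$ denotes the set of integers $\{a, a+1, \ldots, b\}$. The model is trained by minimizing the following consistency matching (CM) loss over $\theta$:
$$\mathcal{L}^N(\theta, \theta^-) = \mathbb{E}\big[\lambda(\sigma_i)\, d\big(f_\theta(\mathbf{x}_{\sigma_{i+1}}, \sigma_{i+1}),\, f_{\theta^-}(\breve{\mathbf{x}}_{\sigma_i}, \sigma_i)\big)\big], \tag{3}$$
where $\breve{\mathbf{x}}_{\sigma_i} = \mathbf{x}_{\sigma_{i+1}} - (\sigma_i - \sigma_{i+1})\, \sigma_{i+1} \nabla_{\mathbf{x}} \log p_{\sigma_{i+1}}(\mathbf{x})\big|_{\mathbf{x} = \mathbf{x}_{\sigma_{i+1}}}$. Here $d(\cdot, \cdot)$ is a metric function, $\lambda(\sigma) > 0$ is a weighting function, and $\theta^-$ denotes the parameters of the teacher network.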
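As a minimal sketch of the discretization, the function below computes noise levels that are evenly spaced in $\sigma^{1/\rho}$, assuming the schedule of Karras2022edm with defaults $\sigma_{\min} = 0.002$, $\sigma_{\max} = 80$, and $\rho = 7$. The function name `karras_sigmas` is hypothetical.

```python
def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Return n increasing noise levels from sigma_min to sigma_max.

    Levels are evenly spaced in sigma**(1/rho), which concentrates
    them near sigma_min when rho > 1.
    """
    if n < 2:
        raise ValueError("need at least two noise levels")
    lo = sigma_min ** (1.0 / rho)
    hi = sigma_max ** (1.0 / rho)
    # i runs over 1..n to match the 1-based indexing used in the text.
    return [(lo + (i - 1) / (n - 1) * (hi - lo)) ** rho for i in range(1, n + 1)]
```

For example, `karras_sigmas(18)` gives 18 levels whose first and last entries are (up to rounding) 0.002 and 80, with most of the levels falling at small noise scales.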
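Given that $\breve{\mathbf{x}}_{\sigma_i}$ in Eq. 3 relies on the unknown score function $\nabla_{\mathbf{x}} \log p_{\sigma_{i+1}}(\mathbf{x})$, directly optimizing the consistency matching objective in Eq. 3 is infeasible. To circumvent this challenge, song2023consistency propose two training algorithms: consistency distillation (CD) and consistency training (CT). For consistency distillation, we first train a diffusion model $s_\phi(\mathbf{x}, \sigma)$ to estimate $\nabla_{\mathbf{x}} \log p_\sigma(\mathbf{x})$ via score matching (hyvarinen-EstimationNonNormalizedStatistical-2005; vincent2011connection; song2019sliced; song2019generative), then approximate $\breve{\mathbf{x}}_{\sigma_i}$ with $\hat{\mathbf{x}}_{\sigma_i} = \mathbf{x}_{\sigma_{i+1}} - (\sigma_i - \sigma_{i+1})\, \sigma_{i+1}\, s_\phi(\mathbf{x}_{\sigma_{i+1}}, \sigma_{i+1})$. On the other hand, consistency training employs a different approximation method. Recall that $\mathbf{x}_{\sigma_{i+1}} = \mathbf{x} + \sigma_{i+1} \mathbf{z}$ with $\mathbf{x} \sim p_\mathrm{data}(\mathbf{x})$ and $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. Using the same $\mathbf{x}$ and $\mathbf{z}$, song2023consistency define $\check{\mathbf{x}}_{\sigma_i} = \mathbf{x} + \sigma_i \mathbf{z}$ as an approximation to $\breve{\mathbf{x}}_{\sigma_i}$, which leads to the consistency training objective below:
$$\mathcal{L}^N_\mathrm{CT}(\theta, \theta^-) = \mathbb{E}\big[\lambda(\sigma_i)\, d\big(f_\theta(\mathbf{x} + \sigma_{i+1} \mathbf{z}, \sigma_{i+1}),\, f_{\theta^-}(\mathbf{x} + \sigma_i \mathbf{z}, \sigma_i)\big)\big]. \tag{4}$$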
As analyzed in song2023consistency, this objective is asymptotically equivalent to consistency matching in the limit of $N \to \infty$. We will revisit this analysis when we discuss removing the EMA from the teacher network.
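The consistency training objective can be sketched as a Monte Carlo estimate in which the same data point and the same noise draw are evaluated at two adjacent noise levels. The version below is a scalar toy, not the training code used in the paper: the squared $\ell_2$ metric, the constant weighting, the uniform sampling of the noise-level index, and all function names are assumptions made for illustration. In real training the teacher output would be treated as a fixed target with no gradient flowing through it.

```python
import random


def squared_l2(a, b):
    """Squared L2 distance for scalars; stands in for the metric d."""
    return (a - b) ** 2


def ct_loss_term(student, teacher, x, z, sigma_lo, sigma_hi,
                 metric=squared_l2, weight=1.0):
    """One term of the CT objective for adjacent levels sigma_lo < sigma_hi.

    The same x and z are reused at both noise levels, so no score
    function or pre-trained diffusion model is needed.
    """
    student_out = student(x + sigma_hi * z, sigma_hi)
    teacher_out = teacher(x + sigma_lo * z, sigma_lo)  # fixed target
    return weight * metric(student_out, teacher_out)


def ct_loss_estimate(student, teacher, data, sigmas, rng, num_samples=256):
    """Average CT loss over random data points, noise, and level pairs."""
    total = 0.0
    for _ in range(num_samples):
        x = rng.choice(data)
        z = rng.gauss(0.0, 1.0)
        i = rng.randrange(len(sigmas) - 1)  # uniform over adjacent pairs
        total += ct_loss_term(student, teacher, x, z, sigmas[i], sigmas[i + 1])
    return total / num_samples


def point_mass_consistency_fn(x, sigma, sigma_min=0.002):
    """Exact consistency function for the toy case where all data equal 0."""
    return x * sigma_min / sigma
```

As a sanity check, when every data point is 0 the perturbed distribution is Gaussian with standard deviation $\sigma$, the ODE trajectories are straight lines through the origin, and `point_mass_consistency_fn` maps both noisy inputs to the same output. The loss is then zero up to rounding, whereas a model that ignores the noise level, such as the identity, incurs a positive loss.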
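After training a consistency model $f_\theta(\mathbf{x}, \sigma)$ through CD or CT, we can directly generate a sample $\mathbf{x}$ by starting with $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \sigma_{\max}^2 \mathbf{I})$ and computing $\mathbf{x} = f_\theta(\mathbf{z}, \sigma_{\max})$. Notably, these models also enable multistep generation. For a sequence of indices $1 = i_1 < i_2 < \cdots < i_K = N$, we start by sampling $\mathbf{x}_K \sim \mathcal{N}(\mathbf{0}, \sigma_{\max}^2 \mathbf{I})$ and then iteratively compute $\mathbf{x}_k \leftarrow f_\theta(\mathbf{x}_{k+1}, \sigma_{i_{k+1}}) + \sqrt{\sigma_{i_k}^2 - \sigma_{\min}^2}\, \mathbf{z}_k$ for $k = K-1, K-2, \ldots, 1$, where $\mathbf{z}_k \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. The resulting sample $\mathbf{x}_1$ approximates the distribution $p_\mathrm{data}(\mathbf{x})$. In our experiments, setting $K = 3$ (two-step generation) often improves considerably on the quality of one-step generation, though increasing the number of sampling steps further provides diminishing benefits.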
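The multistep procedure alternates between denoising with the model and re-injecting a smaller amount of noise. The scalar sketch below illustrates that loop under stated assumptions: `multistep_sample` and `point_mass_model` are hypothetical names, the model is any callable `model(x, sigma)`, and the toy model is the exact consistency function for data concentrated at 0 rather than a trained network.

```python
import math
import random


def multistep_sample(model, sigmas, rng, sigma_min=0.002):
    """Draw one sample with len(sigmas) - 1 model evaluations.

    sigmas lists the chosen noise levels in increasing order, with
    sigmas[0] == sigma_min and sigmas[-1] == sigma_max.
    """
    x = rng.gauss(0.0, sigmas[-1])  # start from pure noise at sigma_max
    for k in range(len(sigmas) - 2, -1, -1):
        denoised = model(x, sigmas[k + 1])
        # Re-noise to level sigmas[k]; the scale is 0 on the final step.
        noise_scale = math.sqrt(max(sigmas[k] ** 2 - sigma_min**2, 0.0))
        x = denoised + noise_scale * rng.gauss(0.0, 1.0)
    return x


def point_mass_model(x, sigma, sigma_min=0.002):
    """Exact consistency function for the toy case where all data equal 0."""
    return x * sigma_min / sigma
```

Passing `[0.002, 80.0]` gives one-step generation, and `[0.002, 0.8, 80.0]` gives two-step generation; the intermediate level 0.8 is an arbitrary illustrative choice. With the toy model, samples from either setting are Gaussian with standard deviation equal to the smallest noise level, which is the expected distribution at the end of the ODE trajectory for this toy data.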
Improved techniques for consistency training
Below we re-examine the design choices of CT in song2023consistency and pinpoint modifications that improve its performance, which we summarize in Section 3. We focus on CT without learned metric functions. For our experiments, we employ the Score SDE architecture in song2021scorebased and train the consistency models for 400,000 iterations on the CIFAR-10 dataset (krizhevsky2014cifar) without class labels. While our primary focus remains on CIFAR-10 in this section, we observe similar improvements on other datasets, including ImageNet (deng2009imagenet). We measure sample quality using Fréchet Inception Distance (FID) (heusel2017gans).