Improved Techniques for Training Consistency Models

Yang Song, Prafulla Dhariwal

Introduction

Consistency models (song2023consistency) are an emerging family of generative models that produce high-quality samples using a single network evaluation. Unlike GANs (goodfellow2014generative), consistency models are not trained with adversarial optimization and thus sidestep the associated training difficulty. Compared to score-based diffusion models (sohl2015deep; song2019generative; song2020improved; ho2020denoising; song2021scorebased), consistency models do not require numerous sampling steps to generate high-quality samples. They are trained to generate samples in a single step, but still retain important advantages of diffusion models, such as the flexibility to exchange compute for sample quality through multistep sampling, and the ability to perform zero-shot data editing.

We can train consistency models using either consistency distillation (CD) or consistency training (CT). The former requires pre-training a diffusion model and distilling the knowledge therein into a consistency model. The latter allows us to train consistency models directly from data, establishing them as an independent family of generative models. Previous work (song2023consistency) demonstrates that CD significantly outperforms CT. However, CD adds computational overhead to the training process since it requires learning a separate diffusion model. Additionally, distillation limits the sample quality of the consistency model to that of the diffusion model. To avoid the downsides of CD and to position consistency models as an independent family of generative models, we aim to improve CT to either match or exceed the performance of CD.

For optimal sample quality, both CD and CT rely on learned metrics like the Learned Perceptual Image Patch Similarity (LPIPS) (zhang2018unreasonable) in previous work (song2023consistency). However, depending on LPIPS has two primary downsides. Firstly, there could be potential bias in evaluation since the same ImageNet dataset (deng2009imagenet) trains both LPIPS and the Inception network in Fréchet Inception Distance (FID) (heusel2017gans), which is the predominant metric for image quality. As analyzed in kynkanniemi2023the, improvements of FIDs can come from accidental leakage of ImageNet features from LPIPS, causing inflated FID scores. Secondly, learned metrics require pre-training auxiliary networks for feature extraction. Training with these metrics requires backpropagating through extra neural networks, which increases the demand for compute.

To tackle these challenges, we introduce improved techniques for CT that not only surpass CD in sample quality but also eliminate the dependence on learned metrics like LPIPS. Our techniques are motivated by both theoretical analysis and comprehensive experiments on the CIFAR-10 dataset (krizhevsky2014cifar). Specifically, we perform an in-depth study on the empirical impact of weighting functions, noise embeddings, and dropout in CT. Additionally, we identify an overlooked flaw in prior theoretical analysis for CT and propose a simple fix by removing the Exponential Moving Average (EMA) from the teacher network. We adopt Pseudo-Huber losses from robust statistics to replace LPIPS. Furthermore, we study how sample quality improves as the number of discretization steps increases, and use the insights to propose a simple but effective curriculum for total discretization steps. Finally, we propose a new schedule for sampling noise levels in the CT objective based on lognormal distributions.

Taken together, these techniques allow CT to attain FID scores of 2.51 and 3.25 for CIFAR-10 and ImageNet $64\times 64$ in one sampling step, respectively. These scores not only surpass CD but also represent improvements of $3.5\times$ and $4\times$ over previous CT methods. Furthermore, they significantly outperform the best few-step distillation techniques for diffusion models, even without the need for distillation. With two-step generation, we achieve improved FID scores of 2.24 and 2.77 on CIFAR-10 and ImageNet $64\times 64$, surpassing the scores from CD in both one-step and two-step settings. These results rival many top-tier diffusion models and GANs, showcasing the strong promise of consistency models as a new independent family of generative models.

Consistency models
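
Consistency models build on the probability flow ODE of song2021scorebased. The display for Eq. 1 is sketched below under an assumption: it follows the formulation of Karras2022edm, where $p_{\sigma}(\mathbf{x})$ denotes the data distribution perturbed by zero-mean Gaussian noise of standard deviation $\sigma$. This form is chosen because its one-step Euler discretization reproduces the update used for consistency distillation later in this section.

```latex
\[
\frac{\mathrm{d}\mathbf{x}}{\mathrm{d}\sigma}
  = -\sigma\,\nabla_{\mathbf{x}}\log p_{\sigma}(\mathbf{x}),
\qquad \sigma\in[\sigma_{\text{min}},\sigma_{\text{max}}]
\tag{1}
\]
```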

where the term $\nabla_{\mathbf{x}}\log p_{\sigma}(\mathbf{x})$ is known as the score function of $p_{\sigma}(\mathbf{x})$ (song2019sliced; song2019generative; song2020improved; song2021scorebased). Here $\sigma_{\text{min}}$ is a small positive value such that $p_{\sigma_{\text{min}}}(\mathbf{x})\approx p_{\text{data}}(\mathbf{x})$, introduced to avoid numerical issues in ODE solving. Meanwhile, $\sigma_{\text{max}}$ is sufficiently large so that $p_{\sigma_{\text{max}}}(\mathbf{x})\approx\mathcal{N}(\bm{0},\sigma_{\text{max}}^{2}\mathbf{I})$. Following Karras2022edm; song2023consistency, we adopt $\sigma_{\text{min}}=0.002$ and $\sigma_{\text{max}}=80$ throughout the paper. Crucially, solving the probability flow ODE from noise level $\sigma_{1}$ to $\sigma_{2}$ allows us to transform a sample $\mathbf{x}_{\sigma_{1}}\sim p_{\sigma_{1}}(\mathbf{x})$ into $\mathbf{x}_{\sigma_{2}}\sim p_{\sigma_{2}}(\mathbf{x})$.

The ODE in Eq. 1 establishes a bijective mapping between a noisy data sample $\mathbf{x}_{\sigma}\sim p_{\sigma}(\mathbf{x})$ and $\mathbf{x}_{\sigma_{\text{min}}}\sim p_{\sigma_{\text{min}}}(\mathbf{x})\approx p_{\text{data}}(\mathbf{x})$. This mapping, denoted as $\bm{f}^{*}:(\mathbf{x}_{\sigma},\sigma)\mapsto\mathbf{x}_{\sigma_{\text{min}}}$, is termed the consistency function. By its very definition, the consistency function satisfies the boundary condition $\bm{f}^{*}(\mathbf{x},\sigma_{\text{min}})=\mathbf{x}$. A consistency model, which we denote by $\bm{f}_{\bm{\theta}}(\mathbf{x},\sigma)$, is a neural network trained to approximate the consistency function $\bm{f}^{*}(\mathbf{x},\sigma)$. To meet the boundary condition, we follow song2023consistency to parameterize the consistency model as
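
A skip-connection form consistent with the conditions on $c_{\text{skip}}$ and $c_{\text{out}}$ stated next is sketched below. The equation number and exact typesetting are assumptions.

```latex
\[
\bm{f}_{\bm{\theta}}(\mathbf{x},\sigma)
  = c_{\text{skip}}(\sigma)\,\mathbf{x}
  + c_{\text{out}}(\sigma)\,\bm{F}_{\bm{\theta}}(\mathbf{x},\sigma)
\tag{2}
\]
```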

where $\bm{F}_{\bm{\theta}}(\mathbf{x},\sigma)$ is a free-form neural network, while $c_{\text{skip}}(\sigma)$ and $c_{\text{out}}(\sigma)$ are differentiable functions such that $c_{\text{skip}}(\sigma_{\text{min}})=1$ and $c_{\text{out}}(\sigma_{\text{min}})=0$.
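
To make the boundary condition concrete, here is a minimal Python sketch of one possible choice of $c_{\text{skip}}$ and $c_{\text{out}}$. The specific functional forms and the constant `SIGMA_DATA = 0.5` are assumptions taken from the setup commonly attributed to song2023consistency; the text above only requires the two conditions at $\sigma_{\text{min}}$. The names `consistency_model` and `free_form_net` are hypothetical.

```python
import math

SIGMA_MIN = 0.002   # smallest noise level, as stated in the text
SIGMA_DATA = 0.5    # assumption: data standard deviation, not given in the text


def c_skip(sigma):
    """Skip coefficient; equals 1 at sigma = SIGMA_MIN."""
    return SIGMA_DATA ** 2 / ((sigma - SIGMA_MIN) ** 2 + SIGMA_DATA ** 2)


def c_out(sigma):
    """Output coefficient; equals 0 at sigma = SIGMA_MIN."""
    return SIGMA_DATA * (sigma - SIGMA_MIN) / math.sqrt(SIGMA_DATA ** 2 + sigma ** 2)


def consistency_model(free_form_net, x, sigma):
    """Combine the input x (a list of floats) with the network output.

    free_form_net(x, sigma) stands in for the free-form network and must
    return a list with the same length as x.
    """
    out = free_form_net(x, sigma)
    return [c_skip(sigma) * xi + c_out(sigma) * oi for xi, oi in zip(x, out)]
```

With these coefficients the model returns its input unchanged at $\sigma_{\text{min}}$, whatever the free-form network outputs, so the boundary condition holds by construction rather than having to be learned.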

To train the consistency model, we discretize the probability flow ODE using a sequence of noise levels $\sigma_{\text{min}}=\sigma_{1}<\sigma_{2}<\cdots<\sigma_{N}=\sigma_{\text{max}}$, where we follow Karras2022edm; song2023consistency in setting $\sigma_{i}=(\sigma_{\text{min}}^{1/\rho}+\frac{i-1}{N-1}(\sigma_{\text{max}}^{1/\rho}-\sigma_{\text{min}}^{1/\rho}))^{\rho}$ for $i\in\llbracket 1,N\rrbracket$, and $\rho=7$, where $\llbracket a,b\rrbracket$ denotes the set of integers $\{a,a+1,\cdots,b\}$. The model is trained by minimizing the following consistency matching (CM) loss over $\bm{\theta}$:
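
The display for this loss is sketched below under assumed notation: $d(\cdot,\cdot)$ is a metric function, $\lambda(\sigma)>0$ is a weighting function, $\bm{\theta}^{-}$ denotes the parameters of the teacher network, and the expectation is over the index $i$ and $\mathbf{x}_{\sigma_{i+1}}\sim p_{\sigma_{i+1}}(\mathbf{x})$. The second line is the one-step Euler update of the probability flow ODE, written to match the approximation used for consistency distillation.

```latex
\begin{align}
\mathcal{L}^{N}(\bm{\theta},\bm{\theta}^{-})
  &= \mathbb{E}\Big[\lambda(\sigma_{i})\,
     d\big(\bm{f}_{\bm{\theta}}(\mathbf{x}_{\sigma_{i+1}},\sigma_{i+1}),\,
           \bm{f}_{\bm{\theta}^{-}}(\breve{\mathbf{x}}_{\sigma_{i}},\sigma_{i})\big)\Big],
  \tag{3}\\
\breve{\mathbf{x}}_{\sigma_{i}}
  &= \mathbf{x}_{\sigma_{i+1}}
   - (\sigma_{i}-\sigma_{i+1})\,\sigma_{i+1}\,
     \nabla_{\mathbf{x}}\log p_{\sigma_{i+1}}(\mathbf{x})
     \big|_{\mathbf{x}=\mathbf{x}_{\sigma_{i+1}}}.
  \notag
\end{align}
```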

Given that $\breve{\mathbf{x}}_{\sigma_{i}}$ relies on the unknown score function $\nabla_{\mathbf{x}}\log p_{\sigma_{i+1}}(\mathbf{x})$, directly optimizing the consistency matching objective in Eq. 3 is infeasible. To circumvent this challenge, song2023consistency propose two training algorithms: consistency distillation (CD) and consistency training (CT). For consistency distillation, we first train a diffusion model $\bm{s}_{\bm{\phi}}(\mathbf{x},\sigma)$ to estimate $\nabla_{\mathbf{x}}\log p_{\sigma}(\mathbf{x})$ via score matching (hyvarinen-EstimationNonNormalizedStatistical-2005; vincent2011connection; song2019sliced; song2019generative), then approximate $\breve{\mathbf{x}}_{\sigma_{i}}$ with $\hat{\mathbf{x}}_{\sigma_{i}}=\mathbf{x}_{\sigma_{i+1}}-(\sigma_{i}-\sigma_{i+1})\sigma_{i+1}\bm{s}_{\bm{\phi}}(\mathbf{x}_{\sigma_{i+1}},\sigma_{i+1})$. On the other hand, consistency training employs a different approximation method. Recall that $\mathbf{x}_{\sigma_{i+1}}=\mathbf{x}+\sigma_{i+1}\mathbf{z}$ with $\mathbf{x}\sim p_{\text{data}}(\mathbf{x})$ and $\mathbf{z}\sim\mathcal{N}(\bm{0},\mathbf{I})$. Using the same $\mathbf{x}$ and $\mathbf{z}$, song2023consistency define $\check{\mathbf{x}}_{\sigma_{i}}=\mathbf{x}+\sigma_{i}\mathbf{z}$ as an approximation to $\breve{\mathbf{x}}_{\sigma_{i}}$, which leads to the consistency training objective below:
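
Substituting $\mathbf{x}+\sigma_{i+1}\mathbf{z}$ and $\mathbf{x}+\sigma_{i}\mathbf{z}$ for the two network inputs gives the form sketched below. The notation is assumed: $d(\cdot,\cdot)$ is a metric function, $\lambda(\sigma)>0$ is a weighting function, $\bm{\theta}^{-}$ denotes the teacher parameters, and the expectation is over $i$, $\mathbf{x}\sim p_{\text{data}}(\mathbf{x})$, and $\mathbf{z}\sim\mathcal{N}(\bm{0},\mathbf{I})$.

```latex
\[
\mathcal{L}_{\text{CT}}^{N}(\bm{\theta},\bm{\theta}^{-})
  = \mathbb{E}\Big[\lambda(\sigma_{i})\,
    d\big(\bm{f}_{\bm{\theta}}(\mathbf{x}+\sigma_{i+1}\mathbf{z},\sigma_{i+1}),\,
          \bm{f}_{\bm{\theta}^{-}}(\mathbf{x}+\sigma_{i}\mathbf{z},\sigma_{i})\big)\Big]
\]
```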

As analyzed in song2023consistency, this objective is asymptotically equivalent to consistency matching in the limit of $N\to\infty$. We will revisit this analysis when we discuss removing EMA from the teacher network.
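
The following Python sketch evaluates a single-sample version of the consistency training loss, to show how the two network inputs share the same data point and the same noise vector. Several choices are assumptions made for illustration only: the index is drawn uniformly over adjacent pairs of noise levels, the metric is the squared $\ell_2$ distance, the weighting defaults to 1, and the function name `ct_loss_single` is hypothetical.

```python
import random


def ct_loss_single(f_theta, f_teacher, x, sigmas, rng, weight=lambda s: 1.0):
    """One-sample consistency training loss for a data point x (list of floats).

    sigmas is the increasing list of noise levels sigma_1 < ... < sigma_N.
    f_theta and f_teacher map (noisy_x, sigma) to a list like x; the teacher
    output is treated as a constant target (no gradient would flow through it).
    """
    i = rng.randrange(len(sigmas) - 1)          # 0-based index of the lower level
    sigma_lo, sigma_hi = sigmas[i], sigmas[i + 1]

    # The same Gaussian noise z perturbs x at both adjacent noise levels.
    z = [rng.gauss(0.0, 1.0) for _ in x]
    x_hi = [xi + sigma_hi * zi for xi, zi in zip(x, z)]
    x_lo = [xi + sigma_lo * zi for xi, zi in zip(x, z)]

    student = f_theta(x_hi, sigma_hi)
    teacher = f_teacher(x_lo, sigma_lo)

    # Assumption: squared l2 distance stands in for the metric function.
    distance = sum((s - t) ** 2 for s, t in zip(student, teacher))
    return weight(sigma_lo) * distance
```

As a sanity check, if the data distribution is a point mass at the origin, the exact consistency function is $\bm{f}^{*}(\mathbf{x},\sigma)=\mathbf{x}\,\sigma_{\text{min}}/\sigma$, and passing it as both networks gives a loss of zero up to floating-point error.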

After training a consistency model $\bm{f}_{\bm{\theta}}(\mathbf{x},\sigma)$ through CD or CT, we can directly generate a sample $\mathbf{x}$ by starting with $\mathbf{z}\sim\mathcal{N}(\bm{0},\sigma_{\text{max}}^{2}\mathbf{I})$ and computing $\mathbf{x}=\bm{f}_{\bm{\theta}}(\mathbf{z},\sigma_{\text{max}})$. Notably, these models also enable multistep generation. For a sequence of indices $1=i_{1}<i_{2}<\cdots<i_{K}=N$, we start by sampling $\mathbf{x}_{K}\sim\mathcal{N}(\bm{0},\sigma_{\text{max}}^{2}\mathbf{I})$ and then iteratively compute $\mathbf{x}_{k}\leftarrow\bm{f}_{\bm{\theta}}(\mathbf{x}_{k+1},\sigma_{i_{k+1}})+\sqrt{\sigma_{i_{k}}^{2}-\sigma_{\text{min}}^{2}}\,\mathbf{z}_{k}$ for $k=K-1,K-2,\cdots,1$, where $\mathbf{z}_{k}\sim\mathcal{N}(\bm{0},\mathbf{I})$. The resulting sample $\mathbf{x}_{1}$ approximately follows the distribution $p_{\text{data}}(\mathbf{x})$. In our experiments, setting $K=3$ (two-step generation) often improves considerably on the sample quality of one-step generation, though increasing the number of sampling steps further provides diminishing benefits.
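
The sampling procedure above can be sketched in plain Python as follows. The function names `karras_sigmas` and `multistep_sample` are hypothetical, vectors are represented as lists of floats, and `f_theta` stands in for a trained consistency model; this is an illustration of the update rule under those assumptions, not the authors' implementation.

```python
import math
import random


def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels sigma_1 < ... < sigma_N from the discretization in the text."""
    lo, hi = sigma_min ** (1.0 / rho), sigma_max ** (1.0 / rho)
    return [(lo + (i - 1) / (n - 1) * (hi - lo)) ** rho for i in range(1, n + 1)]


def multistep_sample(f_theta, sigmas, indices, dim, rng):
    """Multistep generation with 1-based indices i_1 = 1 < ... < i_K = N.

    f_theta(x, sigma) maps a list of floats and a noise level to a list of
    floats. K = len(indices); the network is evaluated K - 1 times.
    """
    sigma_min = sigmas[0]
    num = len(indices)

    def level(k):
        # sigma_{i_k} for 1-based k
        return sigmas[indices[k - 1] - 1]

    # x_K ~ N(0, sigma_max^2 I)
    x = [rng.gauss(0.0, level(num)) for _ in range(dim)]
    for k in range(num - 1, 0, -1):
        denoised = f_theta(x, level(k + 1))
        # Noise added back so that the next input sits at level sigma_{i_k};
        # this is zero at the last step because sigma_{i_1} = sigma_min.
        scale = math.sqrt(max(level(k) ** 2 - sigma_min ** 2, 0.0))
        x = [d + scale * rng.gauss(0.0, 1.0) for d in denoised]
    return x
```

For example, with `sigmas = karras_sigmas(18)`, passing `indices=[1, 18]` corresponds to one-step generation, while `indices=[1, 9, 18]` corresponds to $K=3$, which evaluates the network twice. The intermediate index 9 is an arbitrary illustrative choice.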

Improved techniques for consistency training

Below we re-examine the design choices of CT in song2023consistency and pinpoint modifications that improve its performance, which we summarize in Section 3. We focus on CT without learned metric functions. For our experiments, we employ the Score SDE architecture in song2021scorebased and train the consistency models for 400,000 iterations on the CIFAR-10 dataset (krizhevsky2014cifar) without class labels. While our primary focus remains on CIFAR-10 in this section, we observe similar improvements on other datasets, including ImageNet $64\times 64$ (deng2009imagenet). We measure sample quality using Fréchet Inception Distance (FID) (heusel2017gans).