Inductive Moment Matching

Linqi Zhou, Stefano Ermon, Jiaming Song

Introduction

Generative models for continuous domains have enabled numerous applications in images (Rombach et al., 2022; Saharia et al., 2022; Esser et al., 2024), videos (Ho et al., 2022a; Blattmann et al., 2023; OpenAI, 2024), and audio (Chen et al., 2020; Kong et al., 2020; Liu et al., 2023), yet achieving high-fidelity outputs, efficient inference, and stable training remains a core challenge, a trilemma that continues to motivate research in this domain. Diffusion models (Sohl-Dickstein et al., 2015; Ho et al., 2020; Song et al., 2020b), one of the leading techniques, require many inference steps for high-quality results, while step-reduction methods, such as diffusion distillation (Yin et al., 2024; Sauer et al., 2025; Zhou et al., 2024; Luo et al., 2024a) and Consistency Models (Song et al., 2023; Geng et al., 2024; Lu & Song, 2024; Kim et al., 2023), often risk training collapse without careful tuning and regularization (such as pre-generating data-noise pairs and early stopping).

To address the aforementioned trilemma, we introduce Inductive Moment Matching (IMM), a stable, single-stage training procedure that learns generative models from scratch for single- or multi-step inference. IMM operates on the time-dependent marginal distributions of stochastic interpolants (Albergo et al., 2023), continuous-time stochastic processes that connect two arbitrary probability density functions (data at $t=0$ and prior at $t=1$). By learning a (stochastic or deterministic) mapping from any marginal at time $t$ to any marginal at time $s<t$, it can naturally support one- or multi-step generation (Figure 2).

IMM models can be trained efficiently by mathematical induction. For time $s<r<t$, we form two distributions at $s$ by running a one-step IMM from samples at $r$ and $t$. We then minimize their divergence, enforcing that the distributions at $s$ are independent of the starting time-steps. This construction by induction guarantees convergence to the data distribution. To help with training stability, we model IMM based on certain stochastic interpolants and optimize the objective with stable sample-based divergence estimators such as moment matching (Gretton et al., 2012). Notably, we prove that Consistency Models (CMs) are a single-particle, first-moment matching special case of IMM, which partially explains the training instability of CMs.

On ImageNet-$256\times256$, IMM surpasses diffusion models and achieves 1.99 FID with only 8 inference steps using standard transformer architectures. On CIFAR-10, IMM similarly achieves a state-of-the-art 1.98 FID with 2-step generation for a model trained from scratch.

Preliminaries

When $\gamma_t\equiv 0$ and $\bm{I}_t(\mathbf{x},\bm{\epsilon})=\alpha_t\mathbf{x}+\sigma_t\bm{\epsilon}$ for $\alpha_t,\sigma_t$ defined in FM, the intermediate variable $\mathbf{x}_t=\alpha_t\mathbf{x}+\sigma_t\bm{\epsilon}$ becomes a deterministic interpolation and its interpolant velocity $\mathbf{v}_t=\alpha_t'\mathbf{x}+\sigma_t'\bm{\epsilon}$ reduces to FM velocity. Thus, its training and inference both reduce to that of FM. When $\bm{\epsilon}\sim\mathcal{N}(0,I)$, stochastic interpolants reduce to $v$-prediction diffusion.

2.2 Maximum Mean Discrepancy

Inductive Moment Matching

We introduce Inductive Moment Matching (IMM), a method that trains a model of both high quality and sampling efficiency in a single stage. To do so, we assume a time-augmented interpolation between data (distribution at $t=0$) and prior (distribution at $t=1$) and propose learning an implicit one-step model (i.e. a one-step sampler) that transforms the distribution at time $t$ to the distribution at time $s$ for any $s<t$ (Section 3.1). The model enables direct one-step sampling from $t=1$ to $s=0$ and few-step sampling via recursive application from any $t$ to any $r<t$ and then to any $s<r$ until $s=0$; this allows us to learn the model from its own samples via bootstrapping (Section 3.2).

Given data $\mathbf{x}\sim q(\mathbf{x})$ and prior $\bm{\epsilon}\sim p(\bm{\epsilon})$, the time-augmented interpolation $\mathbf{x}_t$ defined in Albergo et al. (2023) follows $\mathbf{x}_t\sim q_t(\mathbf{x}_t|\mathbf{x},\bm{\epsilon})$. This implies a marginal interpolating distribution

$$q_t(\mathbf{x}_t)=\iint q_t(\mathbf{x}_t|\mathbf{x},\bm{\epsilon})\,q(\mathbf{x})\,p(\bm{\epsilon})\,\mathrm{d}\mathbf{x}\,\mathrm{d}\bm{\epsilon}. \tag{2}$$

We learn a model distribution implicitly defined by a one-step sampler that transforms $q_t(\mathbf{x}_t)$ into $q_s(\mathbf{x}_s)$ for some $s\leq t$. This can be done via a special class of interpolants, which preserves the marginal distribution $q_s(\mathbf{x}_s)$ while interpolating between $\mathbf{x}$ and $\mathbf{x}_t$. We term these marginal-preserving interpolants among a class of generalized interpolants. Formally, we define $\mathbf{x}_s$ as a generalized interpolant between $\mathbf{x}$ and $\mathbf{x}_t$ if, for all $s\in[0,t]$, its distribution follows

and satisfies constraints $\bm{I}_{t|t}(\mathbf{x},\mathbf{x}_t)=\mathbf{x}_t$, $\bm{I}_{0|t}(\mathbf{x},\mathbf{x}_t)=\mathbf{x}$, $\gamma_{t|t}=\gamma_{0|t}=0$, and $q_{t|1}(\mathbf{x}_t|\mathbf{x},\bm{\epsilon})\equiv q_t(\mathbf{x}_t|\mathbf{x},\bm{\epsilon})$. When $t=1$, it reduces to regular stochastic interpolants. Next, we define marginal-preserving interpolants.

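A generalized interpolant $\mathbf{x}_s$ is marginal-preserving if for all $t\in[0,1]$ and for all $s\in[0,t]$, the following equality holds:

$$q_s(\mathbf{x}_s)=\iint q_{s|t}(\mathbf{x}_s|\mathbf{x},\mathbf{x}_t)\,q_t(\mathbf{x}|\mathbf{x}_t)\,q_t(\mathbf{x}_t)\,\mathrm{d}\mathbf{x}_t\,\mathrm{d}\mathbf{x}, \tag{4}$$

where $q_t(\mathbf{x}|\mathbf{x}_t)$ is the posterior of the data given $\mathbf{x}_t$ under the interpolant.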

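That is, this class of interpolants has the same marginal at $s$ regardless of $t$. For all $t\in[0,1]$, we define our noisy model distribution at $s\in[0,t]$ as

$$p^{\theta}_{s|t}(\mathbf{x}_s)=\iint q_{s|t}(\mathbf{x}_s|\mathbf{x},\mathbf{x}_t)\,p^{\theta}_{s|t}(\mathbf{x}|\mathbf{x}_t)\,q_t(\mathbf{x}_t)\,\mathrm{d}\mathbf{x}_t\,\mathrm{d}\mathbf{x}, \tag{6}$$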

where the interpolant is marginal-preserving and $p^{\theta}_{s|t}(\mathbf{x}|\mathbf{x}_t)$ is our clean model distribution implicitly parameterized as a one-step sampler. This definition also enables multistep sampling. To produce a clean sample $\mathbf{x}$ given $\mathbf{x}_t\sim q_t(\mathbf{x}_t)$ in two steps via an intermediate $s$: (1) we sample $\hat{\mathbf{x}}\sim p^{\theta}_{s|t}(\mathbf{x}|\mathbf{x}_t)$ followed by $\hat{\mathbf{x}}_s\sim q_{s|t}(\mathbf{x}_s|\hat{\mathbf{x}},\mathbf{x}_t)$ and (2) if the marginal of $\hat{\mathbf{x}}_s$ matches $q_s(\mathbf{x}_s)$, we can obtain $\mathbf{x}$ by $\mathbf{x}\sim p^{\theta}_{0|s}(\mathbf{x}|\hat{\mathbf{x}}_s)$. We are therefore motivated to minimize divergence between Eq. (4) and (6) using the objective below.

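Naïve objective. As one can easily draw samples from the model, it can be naïvely learned by directly minimizing

$$\mathcal{L}(\theta)=\mathbb{E}_{s,t}\left[D\left(q_s(\mathbf{x}_s),\,p^{\theta}_{s|t}(\mathbf{x}_s)\right)\right] \tag{7}$$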

with time distribution $p(s,t)$ and a sample-based divergence metric $D(\cdot,\cdot)$ such as MMD or GAN (Goodfellow et al., 2020). If an interpolant $\mathbf{x}_s$ is marginal-preserving, then the minimum loss is 0 (see Lemma 3). One might also notice the similarity between the right-hand sides of Eq. (4) and (6). However, $q_s(\mathbf{x}_s)=p^{\theta}_{s|t}(\mathbf{x}_s)$ does not necessarily imply $p^{\theta}_{s|t}(\mathbf{x}|\mathbf{x}_t)=q_t(\mathbf{x}|\mathbf{x}_t)$. In fact, the minimizer $p^{\theta}_{s|t}(\mathbf{x}|\mathbf{x}_t)$ is not unique and, under mild assumptions, a deterministic minimizer exists (see Section 4).

3.2 Learning via Inductive Bootstrapping

While sound, the naïve objective in Eq. (7) is difficult to optimize in practice because when $t$ is far from $s$, the input distribution $q_t(\mathbf{x}_t)$ can be far from the target $q_s(\mathbf{x}_s)$. Fortunately, our interpolant construction implies that the model definition in Eq. (6) satisfies the boundary condition $q_s(\mathbf{x}_s)=p^{\theta}_{s|s}(\mathbf{x}_s)$ regardless of $\theta$ (see Lemma 4), which indicates that $p^{\theta}_{s|t}(\mathbf{x}_s)\approx q_s(\mathbf{x}_s)$ when $t$ is close to $s$. Furthermore, the interpolant enforces $p^{\theta}_{s|t}(\mathbf{x}_s)\approx p^{\theta}_{s|r}(\mathbf{x}_s)$ for any $r<t$ close to $t$ as long as the model is continuous around $t$. Therefore, we can construct an inductive learning algorithm for $p^{\theta}_{s|t}(\mathbf{x}_s)$ by using samples from $p^{\theta}_{s|r}(\mathbf{x}_s)$.

For better analysis, we define a sequence number $n$ for parameter $\theta_n$ and a function $r(s,t)$ where $s\leq r(s,t)<t$ such that $p^{\theta_n}_{s|t}(\mathbf{x}_s)$ learns to match $p^{\theta_{n-1}}_{s|r}(\mathbf{x}_s)$. Note that $n$ is different from optimization steps: advancing from $n-1$ to $n$ can take an arbitrary number of optimization steps. We omit $r$'s arguments when context is clear and let $r(s,t)$ be a finite decrement from $t$ but truncated at $s\leq t$ (see Appendix B.3 for well-conditioned $r(s,t)$).

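General objective. With marginal-preserving interpolants and mapping $r(s,t)$, we learn $\theta_n$ in the following objective:

$$\mathcal{L}(\theta_n)=\mathbb{E}_{s,t}\left[w(s,t)\,\operatorname{MMD}^2\left(p^{\theta_{n-1}}_{s|r}(\mathbf{x}_s),\,p^{\theta_n}_{s|t}(\mathbf{x}_s)\right)\right] \tag{8}$$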

where $w(s,t)$ is a weighting function. We choose MMD as our objective due to its superior optimization stability and show that this objective learns the correct data distribution.
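
As a minimal sketch of the sample-based divergence used here, the snippet below estimates the squared MMD between two sets of samples. The Gaussian RBF kernel, the bandwidth, and the names `rbf_kernel` and `mmd2_biased` are illustrative assumptions, not choices taken from the paper.

```python
import numpy as np

def rbf_kernel(x, y, bandwidth=1.0):
    """Gaussian RBF kernel matrix between rows of x (n, d) and y (m, d)."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * bandwidth ** 2))

def mmd2_biased(x, y, bandwidth=1.0):
    """Biased (V-statistic) estimate of the squared MMD between samples x and y."""
    k_xx = rbf_kernel(x, x, bandwidth).mean()
    k_yy = rbf_kernel(y, y, bandwidth).mean()
    k_xy = rbf_kernel(x, y, bandwidth).mean()
    return k_xx + k_yy - 2.0 * k_xy

rng = np.random.default_rng(0)
same_a = rng.normal(0.0, 1.0, size=(256, 2))
same_b = rng.normal(0.0, 1.0, size=(256, 2))
shifted = rng.normal(2.0, 1.0, size=(256, 2))

mmd_same = mmd2_biased(same_a, same_b)      # two draws from one distribution
mmd_shifted = mmd2_biased(same_a, shifted)  # draws from different distributions
```

The estimate is close to zero when both sample sets come from the same distribution and grows as the distributions separate, which is the property the objective relies on.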

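Assuming $r(s,t)$ is well-conditioned, the interpolant is marginal-preserving, and $\theta_n^*$ is a minimizer of Eq. (8) for each $n$ with infinite data and network capacity, for all $t\in[0,1]$, $s\in[0,t]$,

$$\lim_{n\to\infty}\operatorname{MMD}^2\left(q_s(\mathbf{x}_s),\,p^{\theta_n^*}_{s|t}(\mathbf{x}_s)\right)=0. \tag{9}$$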

In other words, $\theta_n$ eventually learns the target distribution $q_s(\mathbf{x}_s)$ by parameterizing a one-step sampler $p^{\theta_n}_{s|t}(\mathbf{x}|\mathbf{x}_t)$.

Simplified Formulation and Practice

We present algorithmic and practical decisions below.

Despite theoretical soundness, it remains unclear how to empirically choose a marginal-preserving interpolant. First, we present a sufficient condition for marginal preservation.

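Given $s,t\in[0,1]$, $s\leq t$, an interpolant $\mathbf{x}_s\sim q_{s|t}(\mathbf{x}_s|\mathbf{x},\mathbf{x}_t)$ is self-consistent if for all $r\in[s,t]$, the following holds:

$$q_{s|t}(\mathbf{x}_s|\mathbf{x},\mathbf{x}_t)=\int q_{s|r}(\mathbf{x}_s|\mathbf{x},\mathbf{x}_r)\,q_{r|t}(\mathbf{x}_r|\mathbf{x},\mathbf{x}_t)\,\mathrm{d}\mathbf{x}_r. \tag{10}$$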

In other words, $\mathbf{x}_s$ has the same distribution if one (1) directly samples it by interpolating $\mathbf{x}$ and $\mathbf{x}_t$ and (2) first samples any $\mathbf{x}_r$ (given $\mathbf{x}$ and $\mathbf{x}_t$) and then samples $\mathbf{x}_s$ (given $\mathbf{x}$ and $\mathbf{x}_r$). Furthermore, self-consistency implies marginal preservation (Lemma 5).

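DDIM interpolant. Denoising Diffusion Implicit Models (Song et al., 2020a) was introduced as a fast ODE sampler for diffusion models, defined as

$$\operatorname{DDIM}(\mathbf{x}_t,\mathbf{x},s,t)=\left(\alpha_s-\frac{\sigma_s}{\sigma_t}\alpha_t\right)\mathbf{x}+\frac{\sigma_s}{\sigma_t}\mathbf{x}_t. \tag{11}$$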

(Informal) If $\gamma_{s|t}\equiv 0$ and $\bm{I}_{s|t}(\mathbf{x},\mathbf{x}_t)$ satisfies mild assumptions, there exists a deterministic $p^{\theta}_{s|t}(\mathbf{x}|\mathbf{x}_t)$ that attains 0 loss for Eq. (7).

See Appendix B.6 for formal statement and proof. This allows us to define $p^{\theta}_{s|t}(\mathbf{x}|\mathbf{x}_t)=\delta(\mathbf{x}-\bm{g}_{\theta}(\mathbf{x}_t,s,t))$ for a neural network $\bm{g}_{\theta}(\mathbf{x}_t,s,t)$ with parameter $\theta$ by default.

Eliminating stochasticity. We use the DDIM interpolant, a deterministic model, and prior $p(\bm{\epsilon})=\mathcal{N}(0,\sigma_d^2 I)$ where $\sigma_d$ is the data standard deviation (Lu & Song, 2024). As a result, one can draw $\mathbf{x}_s$ from the model via $\mathbf{x}_s=\bm{f}^{\theta}_{s,t}(\mathbf{x}_t):=\operatorname{DDIM}(\mathbf{x}_t,\bm{g}_{\theta}(\mathbf{x}_t,s,t),s,t)$ where $\mathbf{x}_t\sim q_t(\mathbf{x}_t)$.

Re-using $\mathbf{x}_t$ for $\mathbf{x}_r$. Inspecting Eq. (8) and (6), one requires $\mathbf{x}_r\sim q_r(\mathbf{x}_r)$ to generate samples from the target distribution. Instead of sampling $\mathbf{x}_r$ given a new $(\mathbf{x},\bm{\epsilon})$ pair, we can reduce variance by reusing $\mathbf{x}_t$ and $\mathbf{x}$ such that $\mathbf{x}_r=\operatorname{DDIM}(\mathbf{x}_t,\mathbf{x},r,t)$. This is justified because $\mathbf{x}_r$ derived from $\mathbf{x}_t$ preserves the marginal distribution $q_r(\mathbf{x}_r)$ (see Appendix C.2).
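
The reuse step can be checked numerically. The sketch below assumes the standard deterministic DDIM update in $\mathbf{x}$-prediction form and the cosine schedule; the function names, the value of `sigma_d`, and the chosen times are hypothetical. It verifies that $\mathbf{x}_r$ obtained from $\mathbf{x}_t$ coincides with interpolating the same $(\mathbf{x},\bm{\epsilon})$ pair at time $r$, and that stepping $t\to r\to s$ agrees with stepping $t\to s$ directly.

```python
import numpy as np

def cosine_schedule(t):
    """alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2)."""
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)

def ddim(x_t, x, s, t, schedule=cosine_schedule):
    """Deterministic DDIM step from time t to time s given a clean estimate x (assumed form)."""
    alpha_s, sigma_s = schedule(s)
    alpha_t, sigma_t = schedule(t)
    ratio = sigma_s / sigma_t
    return (alpha_s - ratio * alpha_t) * x + ratio * x_t

rng = np.random.default_rng(0)
sigma_d = 0.5  # illustrative data standard deviation
x = rng.normal(scale=sigma_d, size=(8, 3))
eps = rng.normal(scale=sigma_d, size=(8, 3))

s, r, t = 0.2, 0.6, 0.9
alpha_t, sigma_t = cosine_schedule(t)
alpha_r, sigma_r = cosine_schedule(r)

x_t = alpha_t * x + sigma_t * eps
x_r = ddim(x_t, x, r, t)         # reuse x_t and x instead of drawing a new pair
x_s_direct = ddim(x_t, x, s, t)  # t -> s in one step
x_s_via_r = ddim(x_r, x, s, r)   # t -> r -> s
```

Because `x_r` equals `alpha_r * x + sigma_r * eps` for the same noise draw, its marginal is $q_r(\mathbf{x}_r)$, which is what the reuse argument requires.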

Stop gradient. We set $n$ to the optimization step number, i.e. advancing from $n-1$ to $n$ is a single optimizer step where $\theta_n$ is initialized from $\theta_{n-1}$. Equivalently, we can omit $n$ from $\theta_n$ and write $\theta_{n-1}$ as the stop-gradient parameter $\theta^{-}$.

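Simplified objective. Let $\mathbf{x}_t,\mathbf{x}_t'$ be i.i.d. random variables from $q_t(\mathbf{x}_t)$ and let $\mathbf{x}_r,\mathbf{x}_r'$ be the variables obtained by reusing $\mathbf{x}_t,\mathbf{x}_t'$ respectively. The training objective can be derived from the MMD definition in Eq. (1) (see Appendix C.3) as

$$\mathcal{L}_{\text{IMM}}(\theta)=\mathbb{E}_{\mathbf{x}_t,\mathbf{x}_t',\mathbf{x}_r,\mathbf{x}_r',s,t}\Big[w(s,t)\Big(k(\mathbf{y}_{s,t},\mathbf{y}_{s,t}')+k(\mathbf{y}_{s,r},\mathbf{y}_{s,r}')-k(\mathbf{y}_{s,t},\mathbf{y}_{s,r}')-k(\mathbf{y}_{s,t}',\mathbf{y}_{s,r})\Big)\Big] \tag{12}$$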

where $\mathbf{y}_{s,t}=\bm{f}^{\theta}_{s,t}(\mathbf{x}_t)$, $\mathbf{y}_{s,t}'=\bm{f}^{\theta}_{s,t}(\mathbf{x}_t')$, $\mathbf{y}_{s,r}=\bm{f}^{\theta^{-}}_{s,r}(\mathbf{x}_r)$, $\mathbf{y}_{s,r}'=\bm{f}^{\theta^{-}}_{s,r}(\mathbf{x}_r')$, $k(\cdot,\cdot)$ is a kernel function, and $w(s,t)$ is a prior weighting function.

An empirical estimate of the above objective uses $M$ particle samples to approximate each distribution indexed by $t$. In practice, we divide a batch of model outputs of size $B$ into $B/M$ groups, each sharing the same $(s,t)$ sample, and approximate the objective by instantiating $B/M$ matrices of size $M\times M$. Note that the number of model passes does not change with respect to $M$ (see Appendix C.4). An $M=2$ version is visualized in Figure 3 and a simplified training algorithm is shown in Algorithm 1. A full training algorithm is shown in Appendix D.
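
The grouped estimate described above can be sketched as follows. This is an illustration under stated assumptions: the Laplace kernel, its bandwidth, and the names `laplace_kernel` and `grouped_mmd_loss` are hypothetical, the model outputs are passed in as plain arrays, and the stop-gradient on the $r$ branch is only indicated in a comment since NumPy has no autodiff.

```python
import numpy as np

def laplace_kernel(a, b, bandwidth=1.0):
    """Laplace kernel exp(-||a_i - b_j|| / bandwidth) for rows of a and b, shape (M, M)."""
    dists = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    return np.exp(-dists / bandwidth)

def grouped_mmd_loss(y_st, y_sr, num_particles, weights=None, kernel=laplace_kernel):
    """Average of per-group squared-MMD estimates.

    y_st: (B, d) outputs f_{s,t}(x_t) from the trainable parameters.
    y_sr: (B, d) outputs f_{s,r}(x_r) from the stop-gradient parameters.
    Rows are ordered so that each consecutive block of M rows shares one (s, t).
    """
    batch, dim = y_st.shape
    assert batch % num_particles == 0, "batch size must be divisible by M"
    groups = batch // num_particles
    y_st = y_st.reshape(groups, num_particles, dim)
    y_sr = y_sr.reshape(groups, num_particles, dim)
    if weights is None:
        weights = np.ones(groups)
    total = 0.0
    for g in range(groups):
        # Three M x M kernel matrices per group.
        k_tt = kernel(y_st[g], y_st[g]).mean()
        k_rr = kernel(y_sr[g], y_sr[g]).mean()
        k_tr = kernel(y_st[g], y_sr[g]).mean()
        total += weights[g] * (k_tt + k_rr - 2.0 * k_tr)
    return total / groups

rng = np.random.default_rng(0)
B, M, d = 16, 4, 3
outputs_t = rng.normal(size=(B, d))
outputs_r = outputs_t + 0.5 * rng.normal(size=(B, d))

loss_mismatch = grouped_mmd_loss(outputs_t, outputs_r, M)
loss_match = grouped_mmd_loss(outputs_t, outputs_t, M)
```

The loss vanishes when the two branches produce identical particles and is non-negative otherwise, since each group term is a V-statistic of a positive definite kernel.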

4.2 Other Implementation Choices

We defer detailed analysis of each decision to Appendix C.

Flow trajectories. We investigate the two most used flow trajectories (Nichol & Dhariwal, 2021; Lipman et al., 2022),

Cosine. $\alpha_t=\cos(\frac{1}{2}\pi t)$, $\sigma_t=\sin(\frac{1}{2}\pi t)$.

OT-FM. $\alpha_t=1-t$, $\sigma_t=t$.
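
A short sketch of the two trajectories follows. The cosine form is the one stated above; the OT-FM form $\alpha_t=1-t$, $\sigma_t=t$ is the usual Flow Matching choice and is assumed here. Function names are hypothetical.

```python
import numpy as np

def cosine_schedule(t):
    """Cosine trajectory: alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2)."""
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)

def ot_fm_schedule(t):
    """OT-FM trajectory (assumed form): alpha_t = 1 - t, sigma_t = t."""
    return 1.0 - t, t

def interpolate(x, eps, t, schedule):
    """x_t = alpha_t * x + sigma_t * eps."""
    alpha_t, sigma_t = schedule(t)
    return alpha_t * x + sigma_t * eps

ts = np.linspace(0.0, 1.0, 5)
cos_alpha, cos_sigma = cosine_schedule(ts)
fm_alpha, fm_sigma = ot_fm_schedule(ts)
```

Both schedules start at the data ($\alpha_0=1$, $\sigma_0=0$) and end at the prior ($\alpha_1=0$, $\sigma_1=1$); the cosine schedule additionally keeps $\alpha_t^2+\sigma_t^2=1$.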

Network $\bm{g}_{\theta}(\mathbf{x}_t,s,t)$. We set $\bm{g}_{\theta}(\mathbf{x}_t,s,t)=c_{\text{skip}}(t)\mathbf{x}_t+c_{\text{out}}(t)\bm{G}_{\theta}(c_{\text{in}}(t)\mathbf{x}_t,c_{\text{noise}}(s),c_{\text{noise}}(t))$ with a neural network $\bm{G}_{\theta}$, following EDM (Karras et al., 2022). For all choices we let $c_{\text{in}}(t)=1/\left(\sigma_d\sqrt{\alpha_t^2+\sigma_t^2}\right)$ (Lu & Song, 2024). Listed below are valid choices for the other coefficients.

Identity. $c_{\text{skip}}(t)=0$, $c_{\text{out}}(t)=1$.

Simple-EDM (Lu & Song, 2024). $c_{\text{skip}}(t)=\alpha_t/(\alpha_t^2+\sigma_t^2)$, $c_{\text{out}}(t)=-\sigma_d\sigma_t/\sqrt{\alpha_t^2+\sigma_t^2}$.

Euler-FM. $c_{\text{skip}}(t)=1$, $c_{\text{out}}(t)=-t\sigma_d$. This is specific to the OT-FM schedule.

We show in Appendix C.5 that $\bm{f}^{\theta}_{s,t}(\mathbf{x}_t)$ similarly follows the EDM parameterization of the form $\bm{f}^{\theta}_{s,t}(\mathbf{x}_t)=c_{\text{skip}}(s,t)\mathbf{x}_t+c_{\text{out}}(s,t)\bm{G}_{\theta}(c_{\text{in}}(t)\mathbf{x}_t,c_{\text{noise}}(s),c_{\text{noise}}(t))$.
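
To make the coefficient choices concrete, the sketch below implements $\bm{g}_{\theta}$ for the three parameterizations and checks, with hypothetical oracle networks, that the Euler-FM and Simple-EDM forms recover the clean sample when the network outputs the corresponding scaled velocity. The oracle targets, the value of `sigma_d`, and all function names are assumptions made for illustration.

```python
import numpy as np

def ot_fm_schedule(t):
    return 1.0 - t, t

def cosine_schedule(t):
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)

def coefficients(name, t, schedule, sigma_d):
    """Return (c_skip, c_out, c_in) for the named parameterization."""
    alpha_t, sigma_t = schedule(t)
    norm = np.sqrt(alpha_t ** 2 + sigma_t ** 2)
    c_in = 1.0 / (norm * sigma_d)
    if name == "identity":
        c_skip, c_out = 0.0, 1.0
    elif name == "simple_edm":
        c_skip, c_out = alpha_t / norm ** 2, -sigma_d * sigma_t / norm
    elif name == "euler_fm":  # intended for the OT-FM schedule only
        c_skip, c_out = 1.0, -t * sigma_d
    else:
        raise ValueError(f"unknown parameterization: {name}")
    return c_skip, c_out, c_in

def g_theta(network, x_t, s, t, name, schedule, sigma_d, c=1000.0):
    """g = c_skip * x_t + c_out * G(c_in * x_t, c_noise(s), c_noise(t)) with c_noise(t) = c * t."""
    c_skip, c_out, c_in = coefficients(name, t, schedule, sigma_d)
    return c_skip * x_t + c_out * network(c_in * x_t, c * s, c * t)

rng = np.random.default_rng(0)
sigma_d, s, t = 0.5, 0.2, 0.7
x = rng.normal(scale=sigma_d, size=4)
eps = rng.normal(scale=sigma_d, size=4)

# Euler-FM with OT-FM: an oracle that returns the scaled velocity (eps - x) / sigma_d.
alpha_t, sigma_t = ot_fm_schedule(t)
x_t_fm = alpha_t * x + sigma_t * eps
oracle_fm = lambda *_: (eps - x) / sigma_d
x_from_euler = g_theta(oracle_fm, x_t_fm, s, t, "euler_fm", ot_fm_schedule, sigma_d)

# Simple-EDM with cosine (alpha^2 + sigma^2 = 1): oracle returns (alpha*eps - sigma*x) / sigma_d.
alpha_c, sigma_c = cosine_schedule(t)
x_t_cos = alpha_c * x + sigma_c * eps
oracle_cos = lambda *_: (alpha_c * eps - sigma_c * x) / sigma_d
x_from_sedm = g_theta(oracle_cos, x_t_cos, s, t, "simple_edm", cosine_schedule, sigma_d)
```

In both cases the skip term and the scaled network output cancel the noise contribution exactly, which is why these parameterizations keep the network target at roughly unit scale.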

Noise conditioning $c_{\text{noise}}(\cdot)$. We choose $c_{\text{noise}}(t)=ct$ for some constant $c\geq 1$. We find our model convergence relatively insensitive to $c$ but recommend using larger $c$, e.g. $1000$ (Song et al., 2020b; Peebles & Xie, 2023), because it enables sufficient distinction between nearby $r$ and $t$.

Mapping function $r(s,t)$. We find that $r(s,t)$ via constant decrement in $\eta_t=\sigma_t/\alpha_t$ works well, where the decrement is chosen in the form of $(\eta_{\text{max}}-\eta_{\text{min}})/2^{k}$ for some appropriate $k$ (details in Appendix C.7).
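
A minimal sketch of this mapping under the OT-FM schedule, where $\eta_t=t/(1-t)$ can be inverted in closed form. The values of `eta_min`, `eta_max`, and `k` below are illustrative assumptions rather than the paper's settings, and the function names are hypothetical.

```python
import numpy as np

def eta_from_t(t):
    """eta_t = sigma_t / alpha_t for OT-FM (alpha_t = 1 - t, sigma_t = t); requires t < 1."""
    return t / (1.0 - t)

def t_from_eta(eta):
    """Inverse of eta_from_t."""
    return eta / (1.0 + eta)

def r_mapping(s, t, k=12, eta_min=0.0, eta_max=160.0):
    """Constant decrement in eta, truncated so that r(s, t) never falls below s."""
    delta = (eta_max - eta_min) / 2 ** k
    eta_r = np.maximum(eta_from_t(t) - delta, eta_from_t(s))
    return t_from_eta(eta_r)

r_far = r_mapping(0.2, 0.8)      # t far from s: r sits just below t
r_near = r_mapping(0.5, 0.5001)  # t within one decrement of s: r is truncated to s
```

Since $\eta_t$ is monotone in $t$, truncating in $\eta$-space is equivalent to truncating $r$ at $s$.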

Weighting $w(s,t)$ and distribution $p(s,t)$. We follow VDM (Kingma et al., 2021; Kingma & Gao, 2024) and define $p(t)=\mathcal{U}(\epsilon,T)$ and $p(s|t)=\mathcal{U}(\epsilon,t)$ for constants $\epsilon,T\in[0,1]$. Similarly, weighting is defined as

4.3 Sampling

Restart sampling. Similar to Xu et al. (2023) and Song et al. (2023), one can introduce stochasticity during sampling by re-noising a sample to a higher noise level before sampling again. For example, a two-step restart sampler from $\mathbf{x}_t$ requires $s\in(0,t)$ for drawing sample $\hat{\mathbf{x}}=\bm{f}^{\theta}_{0,s}(\mathbf{x}_s)$ where $\mathbf{x}_s\sim q_s(\mathbf{x}_s|\bm{f}^{\theta}_{0,t}(\mathbf{x}_t))$.

Classifier-free guidance. Given a data-label pair $(\mathbf{x},\mathbf{c})$, during inference time, classifier-free guidance (Ho & Salimans, 2022) with weight $w$ replaces conditional model output $\bm{G}_{\theta}(\mathbf{x}_t,s,t,\mathbf{c})$ by a reweighted model output via

where $\varnothing$ denotes the null-token indicating unconditional output. Similarly, we define our guided model as $\bm{f}^{\theta}_{s,t,w}(\mathbf{x}_t)=c_{\text{skip}}(s,t)\mathbf{x}_t+c_{\text{out}}(s,t)\bm{G}^{w}_{\theta}(\mathbf{x}_t,s,t,\mathbf{c})$ where $\bm{G}^{w}_{\theta}(\mathbf{x}_t,s,t,\mathbf{c})$ is as defined in Eq. (14) and we drop $c_{\text{in}}(\cdot)$ and $c_{\text{noise}}(\cdot)$ for notational simplicity. We justify this decision in Appendix E. Similar to diffusion models, $\mathbf{c}$ is randomly dropped with probability $p$ during training without special practices.
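
The reweighting can be sketched as below. The linear form $w\,\bm{G}_{\theta}(\cdot,\mathbf{c})+(1-w)\,\bm{G}_{\theta}(\cdot,\varnothing)$ is the common classifier-free guidance convention and is an assumption here, since Eq. (14) is not reproduced in this text; `toy_network` is a hypothetical stand-in for $\bm{G}_{\theta}$.

```python
import numpy as np

def guided_output(network, x_t, s, t, label, weight, null_label=None):
    """Classifier-free guidance: w * G(x_t, s, t, c) + (1 - w) * G(x_t, s, t, null)."""
    cond = network(x_t, s, t, label)
    uncond = network(x_t, s, t, null_label)
    return weight * cond + (1.0 - weight) * uncond

def toy_network(x_t, s, t, label):
    """Hypothetical network: shifts the input by the label value, no shift for the null token."""
    shift = 0.0 if label is None else float(label)
    return x_t + shift

x_t = np.array([0.5, -1.0])
no_guidance = guided_output(toy_network, x_t, 0.0, 1.0, label=2, weight=1.0)
with_guidance = guided_output(toy_network, x_t, 0.0, 1.0, label=2, weight=1.5)
```

Under this convention $w=1$ returns the plain conditional output, and $w>1$ extrapolates away from the unconditional output.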

We present pushforward sampling in Algorithm 2 and detail both samplers in Appendix F.
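
Both samplers can be sketched in a few lines. This is an illustration under stated assumptions: the OT-FM schedule, the deterministic DDIM update in $\mathbf{x}$-prediction form, a hypothetical oracle `g` that always returns a fixed clean sample (the ideal model for a point-mass data distribution), and an arbitrary time grid.

```python
import numpy as np

def ot_fm_schedule(t):
    return 1.0 - t, t

def ddim(x_t, x, s, t, schedule=ot_fm_schedule):
    """Deterministic DDIM step from t to s given a clean estimate x (assumed form)."""
    alpha_s, sigma_s = schedule(s)
    alpha_t, sigma_t = schedule(t)
    ratio = sigma_s / sigma_t
    return (alpha_s - ratio * alpha_t) * x + ratio * x_t

def pushforward_sample(g, x_start, times):
    """Apply f_{s,t}(x_t) = DDIM(x_t, g(x_t, s, t), s, t) along a decreasing time grid."""
    x = x_start
    for t, s in zip(times[:-1], times[1:]):
        x = ddim(x, g(x, s, t), s, t)
    return x

def restart_sample(g, x_start, times, rng, sigma_d=1.0, schedule=ot_fm_schedule):
    """Predict a clean sample with f_{0,t}, re-noise it to the next time, and repeat."""
    x = x_start
    for t, s in zip(times[:-1], times[1:]):
        x_hat = g(x, 0.0, t)  # f_{0,t}(x_t) = g(x_t, 0, t) because sigma_0 = 0
        alpha_s, sigma_s = schedule(s)
        noise = rng.normal(scale=sigma_d, size=x.shape)
        x = alpha_s * x_hat + sigma_s * noise  # at s = 0 this is x_hat itself
    return x

rng = np.random.default_rng(0)
x_star = np.array([1.0, -2.0, 0.5])
oracle_g = lambda x_t, s, t: x_star  # hypothetical ideal model for point-mass data
x_prior = rng.normal(size=3)
times = [1.0, 0.5, 0.0]

x_push = pushforward_sample(oracle_g, x_prior, times)
x_restart = restart_sample(oracle_g, x_prior, times, rng)
```

With the oracle model both samplers land on the data point; with a learned model they differ in that the restart sampler injects fresh noise at every intermediate time.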

Connection with Prior Works

Our work is closely connected with many prior works. Detailed analysis is found in Appendix G.

We show in the following Lemma that the CM objective with $\mathcal{L}_2$ distance is a single-particle estimate of the IMM objective with energy kernel.

This single-particle estimate ignores the repulsion force imposed by k(,)k(\cdot,\cdot). Energy kernel also only matches the first moment, ignoring all higher moments. These decisions can be significant contributors to training instability and performance degradation of CMs.

Improved CMs (Song & Dhariwal, 2023) propose pseudo-huber loss as $d(\cdot,\cdot)$, which we justify in the Lemma below.

Negative pseudo-huber loss $k_c(x,y)=c-\sqrt{\|x-y\|^2+c^2}$ for $c>0$ is a conditionally positive definite kernel that matches all moments of $x$ and $y$, where weights on higher moments depend on $c$.

From a moment-matching perspective, the improved performance is explained by the loss matching all moments of the distributions. In addition to pseudo-huber loss, many other kernels (Laplace, RBF, etc.) are all valid choices in the design space.
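
The single-particle reduction can be illustrated numerically. The sketch below evaluates the kernel form $k(a,a)+k(b,b)-2k(a,b)$ for one particle per distribution and compares it with the pseudo-huber and $\mathcal{L}_2$ losses. The constant factor of 2, the value of `c`, and the function names are specific to this illustration.

```python
import numpy as np

def pseudo_huber_kernel(x, y, c):
    """Negative pseudo-huber loss used as a kernel: c - sqrt(||x - y||^2 + c^2)."""
    return c - np.sqrt(np.sum((x - y) ** 2) + c ** 2)

def energy_kernel(x, y):
    """Energy kernel: -||x - y||."""
    return -np.linalg.norm(x - y)

def single_particle_mmd(kernel, y_st, y_sr):
    """Squared-MMD estimate with one particle per distribution."""
    return kernel(y_st, y_st) + kernel(y_sr, y_sr) - 2.0 * kernel(y_st, y_sr)

y_st = np.array([0.3, -1.2, 0.5])
y_sr = np.array([0.1, -1.0, 0.9])
c = 0.03

huber_loss = np.sqrt(np.sum((y_st - y_sr) ** 2) + c ** 2) - c
l2_loss = np.linalg.norm(y_st - y_sr)
mmd_huber = single_particle_mmd(lambda a, b: pseudo_huber_kernel(a, b, c), y_st, y_sr)
mmd_energy = single_particle_mmd(energy_kernel, y_st, y_sr)
```

With a single particle the self-terms are constant, so only the cross term remains and the estimate collapses to a pairwise distance with no interaction between different samples.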

We also extend the IMM loss to the differential limit by taking $r(s,t)\rightarrow t$, the result of which subsumes the continuous-time CM (Lu & Song, 2024) as a single-particle estimate (Appendix H). We leave experiments for this to future work.

Generative Moment Matching Network. GMMN (Li et al., 2015) directly applies MMD to train a generator $\bm{G}_{\theta}(\mathbf{z})$ where $\mathbf{z}\sim\mathcal{N}(0,I)$ to match the data distribution. It is a special case of IMM in that when $t=1$ and $r(s,t)\equiv s=0$ our loss reduces to the naïve GMMN objective.

Related Works

Diffusion, Flow Matching, and stochastic interpolants. Diffusion models (Sohl-Dickstein et al., 2015; Song et al., 2020b; Ho et al., 2020; Kingma et al., 2021) and Flow Matching (Lipman et al., 2022; Liu et al., 2022) are widely used generative frameworks that learn a score or velocity field of a noising process from data into a simple prior. They have been scaled successfully for text-to-image (Rombach et al., 2022; Saharia et al., 2022; Podell et al., 2023; Chen et al., 2023; Esser et al., 2024) and text-to-video (Ho et al., 2022a; Blattmann et al., 2023; OpenAI, 2024) tasks. Stochastic interpolants (Albergo et al., 2023; Albergo & Vanden-Eijnden, 2022) extend these ideas by explicitly defining a stochastic path between data and prior, then matching its velocity to facilitate distribution transfer. While IMM builds on top of the interpolant construction, it directly learns one-step mappings between any intermediate marginal distributions.

Diffusion distillation. To resolve diffusion models’ sampling inefficiency, recent methods (Salimans & Ho, 2022; Meng et al., 2023; Yin et al., 2024; Zhou et al., 2024; Luo et al., 2024a; Heek et al., 2024) focus on distilling one-step or few-step models from pre-trained diffusion models. Some approaches (Yin et al., 2024; Zhou et al., 2024) propose jointly optimizing two networks but the training requires careful tuning in practice and can lead to mode collapse (Yin et al., 2024). Another recent work (Salimans et al., 2024) explicitly matches the first moment of the data distribution available from pre-trained diffusion models. In contrast, our method implicitly matches all moments using MMD and can be trained from scratch with a single model.

Few-step generative models from scratch. Early one-step generative models primarily relied on GANs (Goodfellow et al., 2020; Karras et al., 2020; Brock, 2018) and MMD (Li et al., 2015, 2017) (or their combination) but scaling adversarial training remains challenging. Recent independent classes of few-step models, e.g. Consistency Models (CMs) (Song et al., 2023; Song & Dhariwal, 2023; Lu & Song, 2024), Consistency Trajectory Models (CTMs) (Kim et al., 2023; Heek et al., 2024) and Shortcut Models (SMs) (Frans et al., 2024) still face training instability and require specialized components (Lu & Song, 2024) (e.g., JVP for flash attention) or other special practices (e.g., high weight decay for SMs, combined LPIPS (Zhang et al., 2018) and GAN losses for CTMs, and special training schedules (Geng et al., 2024)) to remain stable. In contrast, our method trains stably with a single loss and achieves strong performance without special training practices.

| Family | Method | FID (↓) | Steps (↓) | #Params |
|---|---|---|---|---|
| GAN | BigGAN (Brock, 2018) | 6.95 | 1 | 112M |
| | GigaGAN (Kang et al., 2023) | 3.45 | 1 | 569M |
| | StyleGAN-XL (Karras et al., 2020) | 2.30 | 1 | 166M |
| Masked & AR | VQGAN (Esser et al., 2021) | 26.52 | 1024 | 227M |
| | MaskGIT (Chang et al., 2022) | 6.18 | 8 | 227M |
| | MAR (Li et al., 2024) | 1.98 | 100 | 400M |
| | VAR-$d20$ (Tian et al., 2024a) | 2.57 | 10 | 600M |
| | VAR-$d30$ (Tian et al., 2024a) | 1.92 | 10 | 2B |
| Diffusion & Flow | ADM (Dhariwal & Nichol, 2021) | 10.94 | 250 | 554M |
| | CDM (Ho et al., 2022b) | 4.88 | 8100 | - |
| | SimDiff (Hoogeboom et al., 2023) | 2.77 | 512 | 2B |
| | LDM-4-G (Rombach et al., 2022) | 3.60 | 250 | 400M |
| | U-DiT-L (Tian et al., 2024b) | 3.37 | 250 | 916M |
| | U-ViT-H (Bao et al., 2023) | 2.29 | 50 | 501M |
| | DiT-XL/2 ($w=1.0$) (Peebles & Xie, 2023) | 9.62 | 250 | 675M |
| | DiT-XL/2 ($w=1.25$) (Peebles & Xie, 2023) | 3.22 | 250 | 675M |
| | DiT-XL/2 ($w=1.5$) (Peebles & Xie, 2023) | 2.27 | 250 | 675M |
| | SiT-XL/2 ($w=1.0$) (Ma et al., 2024) | 9.35 | 250 | 675M |
| | SiT-XL/2 ($w=1.5$) (Ma et al., 2024) | 2.15 | 250 | 675M |
| Few-Step from Scratch | iCT (Song et al., 2023) | 34.24 | 1 | 675M |
| | | 20.3 | 2 | 675M |
| | Shortcut (Frans et al., 2024) | 10.60 | 1 | 675M |
| | | 7.80 | 4 | 675M |
| | | 3.80 | 128 | 675M |
| | IMM (ours) (XL/2, $w=1.25$) | 7.77 | 1 | 675M |
| | | 5.33 | 2 | 675M |
| | | 3.66 | 4 | 675M |
| | | 2.77 | 8 | 675M |
| | IMM (ours) (XL/2, $w=1.5$) | 8.05 | 1 | 675M |
| | | 3.99 | 2 | 675M |
| | | 2.51 | 4 | 675M |
| | | 1.99 | 8 | 675M |

Experiments

We evaluate IMM’s empirical performance (Section 7.1), training stability (Section 7.2), sampling choices (Section 7.3), scaling behavior (Section 7.4) and ablate our practical decisions (Section 7.5).

We present FID (Heusel et al., 2017) results for unconditional CIFAR-10 and class-conditional ImageNet-$256\times256$ in Tables 1 and 2. For CIFAR-10, we separate baselines into diffusion and flow models, distillation models, and few-step models from scratch. IMM belongs to the last category, in which it achieves state-of-the-art performance of 1.98 using the pushforward sampler. For ImageNet-$256\times256$, we use the popular DiT (Peebles & Xie, 2023) architecture because of its scalability, and compare it with GANs, masked and autoregressive models, diffusion and flow models, and few-step models trained from scratch.

We observe decreasing FID with more steps, and IMM achieves 1.99 FID with 8 steps (with $w=1.5$), surpassing DiT and SiT (Ma et al., 2024) using the same architecture except for trivially injecting time $s$ (see Appendix I). Notably, we also achieve better 8-step FID than the 10-step VAR (Tian et al., 2024a) of comparable size. At 16 steps, IMM also achieves 1.90 FID, outperforming VAR's 2B variant (see Appendix I.4). However, different from VAR, IMM grants the flexibility of a variable number of inference steps, and the large improvement in FID from 1 to 8 steps additionally demonstrates IMM's efficient inference-time scaling capability. Lastly, we similarly surpass Shortcut models' (Frans et al., 2024) best performance with only 8 steps. We defer inference details to Section 7.3 and Appendix I.2.

7.2 IMM Training is Stable

We show that IMM is stable and achieves reasonable performance across a range of parameterization choices.

Positional vs. Fourier embedding. A known issue for CMs (Song et al., 2023) is their training instability when using Fourier embedding with scale $16$, which forces reliance on positional embeddings for stability. We find that IMM does not face this problem (see Figure 4). For Fourier embedding we use the standard NCSN++ (Song et al., 2020b) architecture and set the embedding scale to $16$; for positional embeddings, we adopt DDPM++ (Song et al., 2020b). Both embedding types converge reliably, and we include samples from the Fourier embedding model in Figure 4.

Particle number. The particle number $M$ for estimating MMD is an important parameter for empirical success (Gretton et al., 2012; Li et al., 2015), where the estimate is more accurate with larger $M$. In our case, naïvely increasing $M$ can slow down convergence because we have a fixed batch size $B$ in which the samples are grouped into $B/M$ groups of $M$, where each group shares the same $t$. A larger $M$ means that fewer $t$'s are sampled. On the other hand, using extremely small numbers of particles, e.g. $M=2$, leads to training instability and performance degradation, especially at large scale with DiT architectures. We find that there exists a sweet spot where a few particles effectively help with training stability while further increasing $M$ slows down convergence (see Figure 5). We see that on ImageNet-$256\times256$, training collapses when $M=1$ (which is CM) and $M=2$, and achieves the lowest FID under the same computation budget with $M=4$. We hypothesize that $M<4$ does not allow sufficient mixing between particles and that larger $M$ means fewer $t$'s are sampled for each step, thus slowing convergence. A general rule of thumb is to use an $M$ large enough for stability, but not so large that convergence slows.
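
The trade-off between $M$ and the number of distinct times per batch can be seen in a small sketch of grouped time sampling. The uniform sampling ranges `t_min` and `t_max` and the function name are hypothetical placeholders.

```python
import numpy as np

def sample_grouped_times(batch_size, num_particles, rng, t_min=1e-3, t_max=0.999):
    """Draw one (s, t) pair per group of M particles and repeat it within the group."""
    assert batch_size % num_particles == 0, "batch size must be divisible by M"
    num_groups = batch_size // num_particles
    t = rng.uniform(t_min, t_max, size=num_groups)
    s = rng.uniform(t_min, t)  # s ~ U(t_min, t), drawn elementwise per group
    return np.repeat(s, num_particles), np.repeat(t, num_particles)

rng = np.random.default_rng(0)
B = 32
s_m4, t_m4 = sample_grouped_times(B, 4, rng)    # 8 distinct times per batch
s_m16, t_m16 = sample_grouped_times(B, 16, rng)  # only 2 distinct times per batch
```

At a fixed batch size, quadrupling $M$ divides the number of distinct $(s,t)$ pairs seen per step by four, which is the mechanism behind the slower convergence at large $M$.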

Noise embedding $c_{\text{noise}}(\cdot)$. We plot in Figure 9 the log absolute mean difference of $t$ and $r(s,t)$ in the positional embedding space. Increasing $c$ increases distinguishability of nearby distributions. We also observe similar convergence on ImageNet-$256\times256$ across different $c$, demonstrating the insensitivity of our framework w.r.t. the noise function.

7.3 Sampling

We investigate different sampling settings for best performance. One-step sampling is performed by simple pushforward from $T$ to $\epsilon$ (concrete values in Appendix I.2). On CIFAR-10 we use 2 steps and set the intermediate time $t_1$ such that $\eta_{t_1}=1.4$, a choice we find to work well empirically. On ImageNet-$256\times256$ we go beyond 2 steps and, for simplicity, investigate (1) uniform decrement in $t$ and (2) the EDM (Karras et al., 2024) schedule (detailed in Appendix I.2). We plot FID of all sampler settings in Figure 6 with guidance weight $w=1.5$. We find pushforward samplers with the uniform schedule to work the best on ImageNet-$256\times256$ and use this as our default setting for multi-step generation. Additionally, we concede that pushforward combined with restart samplers can achieve superior results. We leave such experiments to future work.

7.4 Scaling Behavior

Similar to diffusion models, IMM scales with training and inference compute as well as model size on ImageNet-$256\times256$. We plot in Figure 7 FID vs. training and inference compute in GFLOPs and find a strong correlation between compute used and performance. We further visualize samples in Figure 8 with increasing model size, i.e. DiT-S, DiT-B, DiT-L, DiT-XL, and increasing inference steps, i.e. 1, 2, 4, 8 steps. The sample quality increases along both axes as larger transformers with more inference steps capture more complex distributions. This also explains why more compute can sometimes yield different visual content from the same initial noise, as shown in the visual results.

7.5 Ablation Studies

All ablation studies are done with the DDPM++ architecture for CIFAR-10 and DiT-B for ImageNet-$256\times256$. FID comparisons use 2-step samplers by default.

Flow schedules and parameterization. We investigate all combinations of network parameterization and flow schedules: Simple-EDM + cosine (sEDM/cos), Simple-EDM + OT-FM (sEDM/FM), Euler-FM + OT-FM (eFM), Identity + cosine (id/cos), Identity + OT-FM (id/FM). The identity parameterization consistently falls behind the other types of parameterization, which all show similar performance across datasets (see Table 3). We see that on a smaller scale (CIFAR-10), sEDM/FM works the best, but on a larger scale (ImageNet-$256\times256$), eFM works the best, indicating that the OT-FM schedule and Euler parameterization may be more scalable than other choices.

Mapping function $r(s,t)$. Our choices for ablation are (1) constant decrement in $\eta_t$, (2) constant decrement in $t$, (3) constant decrement in $\lambda_t=\log(\alpha_t^2/\sigma_t^2)$, (4) constant increment in $1/\eta_t$ (see Appendix C.6). For fair comparison, we choose the decrement gap so that the minimum $t-r(s,t)$ is $\approx 10^{-3}$ and use the same network parameterization. FID progression in Figure 11 shows that (1) consistently outperforms other choices. We additionally ablate the mapping gap using $M=4$ in (1). The constant decrement is in the form of $(\eta_{\text{max}}-\eta_{\text{min}})/2^{k}$ for an appropriately chosen $k$. We show in Figure 10 that the performance is relatively stable across $k\in\{11,12,13\}$ but experiences instability for $k=14$. This suggests that, for a given particle number, there exists a largest $k$ for stable optimization.

Conclusion

We present Inductive Moment Matching, a framework that learns a few-step generative model from scratch. It trains by first leveraging self-consistent interpolants to interpolate between data and prior and then matching all moments of its own distribution interpolated to be closer to that of data. Our method guarantees convergence in distribution and generalizes many prior works. Our method also achieves state-of-the-art performance across benchmarks while achieving orders of magnitude faster inference. We hope it provides a new perspective on training few-step models from scratch and inspires a new generation of generative models.

Impact Statement

This paper advances research in diffusion models and generative AI, which enable new creative possibilities and democratize content creation but also raise important considerations. Potential benefits include expanding artistic expression, assisting content creators, and generating synthetic data for research. However, we acknowledge challenges around potential misuse for deepfakes, copyright concerns, and impacts on creative industries. While our work aims to advance technical capabilities, we encourage ongoing discussion about responsible development and deployment of these technologies.

Acknowledgement

We thank Wanqiao Xu, Bokui Shen, Connor Lin, and Samrath Sinha for additional technical discussions and helpful suggestions.

References

Appendix A Background: Properties of Stochastic Interpolants

We note some relevant properties of stochastic interpolants for our exposition.

Boundary satisfaction. For an interpolant distribution qt(xtx,ϵ)q_{t}({\mathbf{x}}_{t}|{\mathbf{x}},\bm{\epsilon}) defined in Albergo et al. (2023), and the marginal qt(xt)q_{t}({\mathbf{x}}_{t}) as defined in Eq. (2), we can check that q1(x1)=p(x1)q_{1}({\mathbf{x}}_{1})=p({\mathbf{x}}_{1}) and q0(x0)=q(x0)q_{0}({\mathbf{x}}_{0})=q({\mathbf{x}}_{0}) so that x1=ϵ{\mathbf{x}}_{1}=\bm{\epsilon} and x0=x{\mathbf{x}}_{0}={\mathbf{x}}.

Joint distribution. The joint distribution of x{\mathbf{x}} and xt{\mathbf{x}}_{t} is written as

in which case x1=ϵ{\mathbf{x}}_{1}=\bm{\epsilon}.

Appendix B Theorems and Derivations

Assuming marginal-preserving interpolant and metric D(,)D(\cdot,\cdot), a minimizer θ\theta^{*} of Eq. (7) exists, i.e. pstθ(xxt)=qt(xxt)p_{s|t}^{\theta^{*}}({\mathbf{x}}|{\mathbf{x}}_{t})=q_{t}({\mathbf{x}}|{\mathbf{x}}_{t}), and the minimum is 0.

We directly substitute qt(xxt)q_{t}({\mathbf{x}}|{\mathbf{x}}_{t}) into the objective to check. First,

where (a)(a) is due to definition of marginal preservation. So the objective becomes

In general, the minimizer qt(xxt)q_{t}({\mathbf{x}}|{\mathbf{x}}_{t}) exists. However, this does not show that the minimizer is unique. In fact, the minimizer is not unique in general because a deterministic minimizer can also exist under certain assumptions on the interpolant (see Appendix B.6). ∎

Failure Case without Marginal Preservation. We additionally show that marginal-preservation property of the interpolant qst(xsx,xt)q_{s|t}({\mathbf{x}}_{s}|{\mathbf{x}},{\mathbf{x}}_{t}) is important for the naïve objective in Eq. (7) to attain 0 loss (Lemma 3). Consider the failure case below where the constructed interpolant is a generalized interpolant but not necessarily marginal-preserving. Then we show that there exists a tt such that pstθ(xs)p_{s|t}^{\theta}({\mathbf{x}}_{s}) can never reach qs(xs)q_{s}({\mathbf{x}}_{s}) regardless of θ\theta.

Let $q({\mathbf{x}})=\delta({\mathbf{x}})$, $p(\bm{\epsilon})=\delta(\bm{\epsilon}-1)$, and suppose an interpolant ${\bm{I}}_{s|t}({\mathbf{x}},{\mathbf{x}}_{t})=(1-\frac{s}{t}){\mathbf{x}}+\frac{s}{t}{\mathbf{x}}_{t}$ and $\gamma_{s|t}=\frac{s}{t}\sqrt{1-\frac{s}{t}}\,\mathbb{1}(t<1)$. Then $D(q_{s}({\mathbf{x}}_{s}),p_{s|t}^{\theta}({\mathbf{x}}_{s}))>0$ for all $0<s<t<1$ regardless of the learned distribution $p_{s|t}^{\theta}({\mathbf{x}}|{\mathbf{x}}_{t})$, given any metric $D(\cdot,\cdot)$.

This example first implies the learning target

is a delta distribution. However, we show that if we select any t<1t<1 and 0<s<t0<s<t, pstθ(xs)p_{s|t}^{\theta}({\mathbf{x}}_{s}) can never be a delta distribution.

Now, we show the model distribution has non-zero variance under these choices of tt and ss. Expectations are over pstθ(xs)p_{s|t}^{\theta}({\mathbf{x}}_{s}) or conditional interpolant qst(xsx,xt)q_{s|t}({\mathbf{x}}_{s}|{\mathbf{x}},{\mathbf{x}}_{t}) for all equations below.

B.2 Boundary Satisfaction of Model Distribution

The operator output pstθ(xs)p_{s|t}^{\theta}({\mathbf{x}}_{s}) satisfies boundary condition.

For all $s\in[0,1]$ and all $\theta$, the following boundary condition holds.

B.3 Definition of Well-Conditioned $r(s,t)$

For simplicity, the mapping function r(s,t)r(s,t) is well-conditioned if

where Δ(t)ϵ>0\Delta(t)\geq\epsilon>0 is a positive function such that r(s,t)r(s,t) is increasing for ts+c0(s)t\geq s+c_{0}(s) where c0(s)c_{0}(s) is the largest tt that is mapped to ss. Formally, c0(s)=sup{t:r(s,t)=s}c_{0}(s)=\sup\{t:r(s,t)=s\}. For ts+c0(s)t\geq s+c_{0}(s), the inverse w.r.t. tt exists, i.e. r1(s,)r^{-1}(s,\cdot) and r1(s,r(s,t))=tr^{-1}(s,r(s,t))=t. All practical implementations follow this general form, and are detailed in Appendix C.6.

B.4 Main Theorem

We prove by induction on sequence number nn. First, r(s,t)r(s,t) is well-conditioned by following the definition in Eq. (46). Furthermore, for notational convenience, we let rn1(s,):=r1(s,r1(s,r1(s,)))r_{n}^{-1}(s,\cdot)\vcentcolon=r^{-1}(s,r^{-1}(s,r^{-1}(s,\dots))) be nn nested application of r1(s,)r^{-1}(s,\cdot) on the second argument. Additionally, r01(s,t)=tr_{0}^{-1}(s,t)=t.

Base case: n=1n=1. Given any s0s\geq 0, r(s,u)=sr(s,u)=s for all s<uc0(s)s<u\leq c_{0}(s), implying

for uc0(s)u\leq c_{0}(s) where (a)(a) is implied by Lemma 4 and (b)(b) is implied by Lemma 3.

Inductive assumption: n1n-1. We assume psuθn1(xs)=qs(xs)p_{s|u}^{\theta_{n-1}^{*}}({\mathbf{x}}_{s})=q_{s}({\mathbf{x}}_{s}) for all surn21(s,c0(s))s\leq u\leq r_{n-2}^{-1}(s,c_{0}(s)).

We inspect the target distribution psr(s,u)θn1(xs)p_{s|r(s,u)}^{\theta_{n-1}^{*}}({\mathbf{x}}_{s}) in Eq. (8) if optimized on surn11(s,c0(s))s\leq u\leq r_{n-1}^{-1}(s,c_{0}(s)). On this interval, we can apply r(s,)r(s,\cdot) to the inequality and get s=r(s,s)r(s,u)r(s,rn11(s,c0(s)))=rn21(s,c0(s))s=r(s,s)\leq r(s,u)\leq r(s,r_{n-1}^{-1}(s,c_{0}(s)))=r_{n-2}^{-1}(s,c_{0}(s)) since r(s,)r(s,\cdot) is increasing. And by inductive assumption psr(s,u)θn1(xs)=qs(xs)p_{s|r(s,u)}^{\theta_{n-1}^{*}}({\mathbf{x}}_{s})=q_{s}({\mathbf{x}}_{s}) for sr(s,u)rn21(s,c0(s))s\leq r(s,u)\leq r_{n-2}^{-1}(s,c_{0}(s)), this implies minimizing

on surn11(s,c0(s))s\leq u\leq r_{n-1}^{-1}(s,c_{0}(s)) is equivalent to minimizing

for surn11(s,c0(s))s\leq u\leq r_{n-1}^{-1}(s,c_{0}(s)). Lemma 3 implies that its minimum achieves psuθn(xs)=qs(xs)p_{s|u}^{\theta_{n}^{*}}({\mathbf{x}}_{s})=q_{s}({\mathbf{x}}_{s}).

B.5 Self-Consistency Implies Marginal Preservation

Without assuming marginal preservation, it is important to define the marginal distribution of xs{\mathbf{x}}_{s} under generalized interpolants qst(xsx,xt)q_{s|t}({\mathbf{x}}_{s}|{\mathbf{x}},{\mathbf{x}}_{t}) as

and we show that with self-consistent interpolants, this distribution is invariant of tt, i.e. qst(xs)=qs(xs)q_{s|t}({\mathbf{x}}_{s})=q_{s}({\mathbf{x}}_{s}).

If the interpolant qst(xsx,xt)q_{s|t}({\mathbf{x}}_{s}|{\mathbf{x}},{\mathbf{x}}_{t}) is self-consistent, the marginal distribution qst(xs)q_{s|t}({\mathbf{x}}_{s}) as defined in Eq. (49) satisfies qs(xs)=qst(xs)q_{s}({\mathbf{x}}_{s})=q_{s|t}({\mathbf{x}}_{s}) for all t[s,1]t\in[s,1].

where (a)(a) uses definition of self-consistent interpolants and (b)(b) uses definition of our generalized interpolant. ∎

We show in Appendix C.1 that DDIM is an example of a self-consistent interpolant. Furthermore, the DDPM posterior (Ho et al., 2020; Kingma et al., 2021) is also self-consistent (see Lemma 6).

B.6 Existence of Deterministic Minimizer

We present the formal statement for the deterministic minimizer.

which pushes forward the measure qt(xt)q_{t}({\mathbf{x}}_{t}) to qs(xs)q_{s}({\mathbf{x}}_{s}). We define:

Then, since γst0\gamma_{s|t}\equiv 0, qst(xsx,xt)=δ(xsIst(x,xt))q_{s|t}({\mathbf{x}}_{s}|{\mathbf{x}},{\mathbf{x}}_{t})=\delta({\mathbf{x}}_{s}-{\bm{I}}_{s|t}({\mathbf{x}},{\mathbf{x}}_{t})) where xδ(xhst(xt)){\mathbf{x}}\sim\delta({\mathbf{x}}-h_{s|t}({\mathbf{x}}_{t})). Therefore,

whose marginal follows qs(xs)q_{s}({\mathbf{x}}_{s}) due to it being the result of PF-ODE trajectories starting from qt(xt)q_{t}({\mathbf{x}}_{t}). ∎

Concretely, DDIM interpolant satisfies all of the deterministic assumption, the regularity condition, and the invertibility assumption because it is a linear function of x{\mathbf{x}} and xt{\mathbf{x}}_{t}. Therefore, any diffusion or FM schedule with DDIM interpolant will enjoy a deterministic minimizer pstθ(xxt)p_{s|t}^{\theta}({\mathbf{x}}|{\mathbf{x}}_{t}).

Appendix C Analysis of Simplified Parameterization

We check that DDIM interpolant is self-consistent. By definition, qst(xsx,xt)=δ(xsDDIM(xt,x,t,s))q_{s|t}({\mathbf{x}}_{s}|{\mathbf{x}},{\mathbf{x}}_{t})=\delta({\mathbf{x}}_{s}-\operatorname{DDIM}({\mathbf{x}}_{t},{\mathbf{x}},t,s)). We check that for all srts\leq r\leq t,

Therefore, δ(xsDDIM(DDIM(xt,x,t,r),x,r,s))=δ(xsDDIM(xt,x,t,s))\delta({\mathbf{x}}_{s}-\operatorname{DDIM}(\operatorname{DDIM}({\mathbf{x}}_{t},{\mathbf{x}},t,r),{\mathbf{x}},r,s))=\delta({\mathbf{x}}_{s}-\operatorname{DDIM}({\mathbf{x}}_{t},{\mathbf{x}},t,s)). So DDIM is self-consistent.
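The composition property above can be checked numerically. The sketch below assumes the closed form $\operatorname{DDIM}({\mathbf{x}}_{t},{\mathbf{x}},t,s)=(\alpha_{s}-\frac{\sigma_{s}}{\sigma_{t}}\alpha_{t}){\mathbf{x}}+\frac{\sigma_{s}}{\sigma_{t}}{\mathbf{x}}_{t}$ and a cosine schedule; the function names, the scalar inputs, and the chosen times are illustrative assumptions, not part of the original text.

```python
import math

def cosine_schedule(t):
    """Assumed cosine schedule: alpha_t = cos(pi*t/2), sigma_t = sin(pi*t/2)."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)

def ddim(x_t, x, t, s, schedule=cosine_schedule):
    """Deterministic DDIM step from time t to time s <= t, given the clean sample x.

    Assumed closed form: (alpha_s - sigma_s/sigma_t * alpha_t) * x + sigma_s/sigma_t * x_t.
    """
    alpha_t, sigma_t = schedule(t)
    alpha_s, sigma_s = schedule(s)
    ratio = sigma_s / sigma_t
    return (alpha_s - ratio * alpha_t) * x + ratio * x_t

# Arbitrary scalar test values with s <= r <= t.
x, x_t = 0.7, -1.3
s, r, t = 0.2, 0.5, 0.9

direct = ddim(x_t, x, t, s)                      # t -> s in one hop
two_hop = ddim(ddim(x_t, x, t, r), x, r, s)      # t -> r -> s
```

Under these assumptions `direct` and `two_hop` agree up to floating-point error, which is the self-consistency property stated above.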

It also implies a Gaussian forward process qt(xtx)=N(αtx,σt2σd2I)q_{t}({\mathbf{x}}_{t}|{\mathbf{x}})={\mathcal{N}}(\alpha_{t}{\mathbf{x}},\sigma_{t}^{2}\sigma_{d}^{2}{I}) as in diffusion models. By definition,

so that xt{\mathbf{x}}_{t} is a deterministic transform given x{\mathbf{x}} and ϵ\bm{\epsilon}, i.e., xt=DDIM(ϵ,x,t,1)=αtx+σtϵ{\mathbf{x}}_{t}=\operatorname{DDIM}(\bm{\epsilon},{\mathbf{x}},t,1)=\alpha_{t}{\mathbf{x}}+\sigma_{t}\bm{\epsilon}, which implies qt(xtx)=N(αtx,σt2σd2I)q_{t}({\mathbf{x}}_{t}|{\mathbf{x}})={\mathcal{N}}(\alpha_{t}{\mathbf{x}},\sigma_{t}^{2}\sigma_{d}^{2}{I}).

C.2 Reusing ${\mathbf{x}}_{t}$ for ${\mathbf{x}}_{r}$

We propose that instead of sampling xr{\mathbf{x}}_{r} via forward flow αrx+σrϵ\alpha_{r}{\mathbf{x}}+\sigma_{r}\bm{\epsilon} we reuse xt{\mathbf{x}}_{t} such that xr=DDIM(xt,x,r,t){\mathbf{x}}_{r}=\operatorname{DDIM}({\mathbf{x}}_{t},{\mathbf{x}},r,t) to reduce variance. In fact, for any self-consistent interpolant, one can reuse xt{\mathbf{x}}_{t} via xrqrt(xrx,xt){\mathbf{x}}_{r}\sim q_{r|t}({\mathbf{x}}_{r}|{\mathbf{x}},{\mathbf{x}}_{t}) and xr{\mathbf{x}}_{r} will follow qr(xr)q_{r}({\mathbf{x}}_{r}) marginally. We check

where (a)(a) is due to Lemma 5. We can see that sampling x,xt{\mathbf{x}},{\mathbf{x}}_{t} first then xrqrt(xrx,xt){\mathbf{x}}_{r}\sim q_{r|t}({\mathbf{x}}_{r}|{\mathbf{x}},{\mathbf{x}}_{t}) respects the marginal distribution qr(xr)q_{r}({\mathbf{x}}_{r}).

C.3 Simplified Objective

We derive our simplified objective. Given MMD defined in Eq. (1), we write our objective as

where ,\langle\cdot,\cdot\rangle is in RKHS, (a)(a) is due to the correlation between xr{\mathbf{x}}_{r} and xt{\mathbf{x}}_{t} by re-using xt{\mathbf{x}}_{t}.

C.4 Empirical Estimation

As proposed in Gretton et al. (2012), MMD is typically estimated with V-statistics by instantiating a matrix of size $M\times M$ such that a batch of $B$ samples of ${\mathbf{x}}$, $\{x^{(i)}\}_{i=1}^{B}$, is separated into groups of $M$ particles (assuming $B$ is divisible by $M$), $\{x^{(i,j)}\}_{i=1,j=1}^{B/M,M}$, where each group shares a $(s^{i},r^{i},t^{i})$ sample. The Monte Carlo estimate becomes

Computational efficiency. First, we note that regardless of $M$, we require only 2 model forward passes, one with and one without stop gradient, since the model takes in all $B$ instances together within the batch and produces outputs for the entire batch. For the calculation of our loss, although the need for $M$ particles may imply inefficient computation, the cost of this matrix computation is negligible in practice compared to the cost of the model forward pass. Suppose a forward pass for a single instance is ${\mathcal{O}}(K)$; then the total cost of computing the loss for a batch of $B$ instances is ${\mathcal{O}}(BK)+{\mathcal{O}}(BM)$. Deep neural networks often have $K\gg M$, so ${\mathcal{O}}(BK)$ dominates the computation.
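To make the grouping and the ${\mathcal{O}}(BM)$ cost concrete, the following is a minimal sketch of a grouped V-statistic estimate. It assumes scalar model outputs, a Laplace kernel with unit bandwidth, and uniform weighting; `y_t` and `y_r` stand for the model outputs from ${\mathbf{x}}_{t}$ and ${\mathbf{x}}_{r}$ respectively. All names are hypothetical, and the weighting function and stop-gradient are omitted.

```python
import math

def laplace_kernel(x, y, bandwidth=1.0):
    """Laplace kernel on scalars (an illustrative choice of kernel)."""
    return math.exp(-abs(x - y) / bandwidth)

def grouped_mmd2(y_t, y_r, M, kernel=laplace_kernel):
    """V-statistic MMD^2 between y_t and y_r, computed per group of M particles.

    The batch of size B is split into B/M consecutive groups; each group
    contributes one M x M kernel matrix per term, so the loss needs
    3 * (B/M) * M^2 = 3*B*M kernel evaluations in total.
    """
    B = len(y_t)
    assert B == len(y_r) and B % M == 0
    num_groups = B // M
    total = 0.0
    for g in range(num_groups):
        a = y_t[g * M:(g + 1) * M]
        b = y_r[g * M:(g + 1) * M]
        acc = 0.0
        for j in range(M):
            for k in range(M):
                acc += kernel(a[j], a[k]) + kernel(b[j], b[k]) - 2.0 * kernel(a[j], b[k])
        total += acc / (M * M)
    return total / num_groups

same = grouped_mmd2([0.1, 0.4, -0.2, 0.9], [0.1, 0.4, -0.2, 0.9], M=2)
diff = grouped_mmd2([0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0], M=2)
```

With identical inputs the estimate is zero, and it is strictly positive when the two groups differ, as expected of a V-statistic with a positive definite kernel.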

C.5 Simplified Parameterization

We derive fs,tθ(xt){\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}) for each parameterization, which now generally follows the form

Identity. This is simply DDIM with xx-prediction network.

When noise schedule is cosine, fs,tθ(xt)=cos(12π(ts))xtsin(12π(ts))σdGθ(xtσdαt2+σt2,cnoise(s),cnoise(t)){\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t})=\cos(\frac{1}{2}\pi(t-s)){\mathbf{x}}_{t}-\sin(\frac{1}{2}\pi(t-s))\sigma_{d}{\bm{G}}_{\theta}\left(\frac{{\mathbf{x}}_{t}}{\sigma_{d}\sqrt{\alpha_{t}^{2}+\sigma_{t}^{2}}},c_{\text{noise}}(s),c_{\text{noise}}(t)\right). And similar to Lu & Song (2024), we can show that predicting xs=DDIM(xt,x,s,t){\mathbf{x}}_{s}=\operatorname{DDIM}({\mathbf{x}}_{t},{\mathbf{x}},s,t) with L2{\mathcal{L}}_{2} loss is equivalent to vv-prediction with cosine schedule.

This reduces to vv-target if cosine schedule is used, and it deviates from vv-target if FM schedule is used instead.

This results in Euler ODE from xt{\mathbf{x}}_{t} to xs{\mathbf{x}}_{s}. We also show that the network output reduces to vv-prediction if matched with xs=DDIM(xt,x,s,t){\mathbf{x}}_{s}=\operatorname{DDIM}({\mathbf{x}}_{t},{\mathbf{x}},s,t). To see this,

which is vv-target under OT-FM schedule. This parameterization naturally allows zero-SNR sampling and satisfies boundary condition at s=0s=0, similar to Simple-EDM above. This is not true for Identity parametrization using Gθ{\bm{G}}_{\theta} as it satisfies boundary condition only at s>0s>0.

C.6 Mapping Function $r(s,t)$

We discuss below the concrete choices for r(s,t)r(s,t). We use a constant decrement ϵ>0\epsilon>0 in different spaces.

Constant decrement in η(t):=ηt=σt/αt\eta(t)\vcentcolon=\eta_{t}=\sigma_{t}/\alpha_{t}. This is the choice that we find to work better than other choices in practice. First, let its inverse be η1()\eta^{-1}(\cdot),

We choose $\epsilon=(\eta_{\text{max}}-\eta_{\text{min}})/2^{k}$ for some $k$. We generally choose $\eta_{\text{max}}\approx 160$ and $\eta_{\text{min}}\approx 0$. We find that $k\in\{10,\dots,15\}$ works well enough, depending on the dataset.
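As an illustration of the constant decrement in $\eta_{t}$, the sketch below assumes the OT-FM schedule ($\alpha_{t}=1-t$, $\sigma_{t}=t$, so $\eta_{t}=t/(1-t)$) and assumes the mapping takes the form $r(s,t)=\max(s,\eta^{-1}(\eta_{t}-\epsilon))$. The function names, the default $k=12$, and the clamp at $\eta_{\text{min}}$ are assumptions made for this example.

```python
def eta_otfm(t):
    """eta_t = sigma_t / alpha_t under the OT-FM schedule (alpha_t = 1 - t, sigma_t = t)."""
    return t / (1.0 - t)

def eta_inv_otfm(eta):
    """Inverse of eta_otfm: the time t at which sigma_t / alpha_t equals eta."""
    return eta / (1.0 + eta)

def r_const_eta(s, t, k=12, eta_max=160.0, eta_min=0.0):
    """Assumed mapping: step back by a constant epsilon in eta-space, never below s."""
    eps = (eta_max - eta_min) / 2 ** k
    eta_r = max(eta_otfm(t) - eps, eta_min)   # clamp is an assumption of this sketch
    return max(s, eta_inv_otfm(eta_r))

r_mid = r_const_eta(0.0, 0.5)      # strictly between s and t
r_floor = r_const_eta(0.499, 0.5)  # decrement would pass s, so r is clipped to s
gap_mid = 0.5 - r_const_eta(0.0, 0.5)
gap_late = 0.99 - r_const_eta(0.0, 0.99)
```

In this sketch the gap $t-r(s,t)$ shrinks as $t$ approaches 1 (`gap_late` is much smaller than `gap_mid`), which is the behavior that motivates restricting the maximum time in Appendix C.7.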

Constant decrement in λ(t):=log-SNRt=2log(αt/σt)\lambda(t)\vcentcolon=\text{log-SNR}_{t}=2\log(\alpha_{t}/\sigma_{t}).

Let its inverse be λ1()\lambda^{-1}(\cdot), then

We choose ϵ=(λmaxλmin)/2k\epsilon=(\lambda_{\text{max}}-\lambda_{\text{min}})/2^{k}. This choice comes close to the first choice, but we refrain from this because r(s,t)r(s,t) becomes close to tt both when t0t\approx 0 and t1t\approx 1 instead of just t1t\approx 1. This gives more chances for training instability than the first choice.

We choose ϵ=(1/η(t)min1/η(t)max)/2k\epsilon=(1/\eta(t)_{\text{min}}-1/\eta(t)_{\text{max}})/2^{k}.

C.7 Time Distribution $p(s,t)$

In all cases we choose p(t)=U(ϵ,T)p(t)={\mathcal{U}}(\epsilon,T) and p(st)=U(ϵ,t)p(s|t)={\mathcal{U}}(\epsilon,t) for some ϵ0\epsilon\geq 0 and T1T\leq 1. The decision for time distribution is coupled with r(s,t)r(s,t). We list the constraints on p(s,t)p(s,t) for each r(s,t)r(s,t) choice below.

Constant decrement in η(t)\eta(t). We need to choose T<1T<1 because, for example, assuming OT-FM schedule, ηt=t/(1t)\eta_{t}=t/(1-t), one can observe that constant decrement in ηt\eta_{t} when t1t\approx 1 results in r(s,t)r(s,t) that is too close to tt due to ηt\eta_{t}’s exploding gradient around 11. We need to define T<1T<1 such that r(s,T)r(s,T) is not too close to TT for ss reasonably far away. With ηmax160\eta_{\text{max}}\approx 160, we can choose T=0.994T=0.994 for OT-FM and T=0.996T=0.996 for VP-diffusion.
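The two values of $T$ can be recovered by inverting $\eta_{t}$ at $\eta_{\text{max}}$. The short check below assumes that OT-FM uses $\eta_{t}=t/(1-t)$ and that VP-diffusion refers to the cosine schedule with $\eta_{t}=\tan(\frac{\pi}{2}t)$; the latter identification is an assumption of this example.

```python
import math

eta_max = 160.0

# OT-FM: eta_t = t / (1 - t)  =>  t = eta / (1 + eta)
T_otfm = eta_max / (1.0 + eta_max)

# Cosine (VP) schedule: eta_t = tan(pi * t / 2)  =>  t = (2 / pi) * atan(eta)
T_vp = (2.0 / math.pi) * math.atan(eta_max)

print(round(T_otfm, 3), round(T_vp, 3))  # prints: 0.994 0.996
```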

Constant decrement in tt. No constraints needed. T=1T=1, ϵ=0\epsilon=0.

Constant decrement in λt\lambda_{t}. One can similarly observe exploding gradient causing r(s,t)r(s,t) to be too close to tt at both t0t\approx 0 and t1t\approx 1, so we can choose ϵ>0\epsilon>0, e.g. 0.0010.001, in addition to choosing T=0.994T=0.994 for OT-FM and T=0.996T=0.996 for VP-diffusion.

Constant increment in $1/\eta_{t}$. This experiences an exploding gradient for $t\approx 0$, so we require $\epsilon>0$, e.g. $0.005$, and $T=1$.

C.8 Kernel Function

whose magnitude can vary a lot depending on how far xx is from yy.

C.9 Weighting Function $w(s,t)$

To review VDM (Kingma et al., 2021), the negative ELBO loss for diffusion model is

where ϵθ\bm{\epsilon}_{\theta} is the noise-prediction network and λt=log-SNRt\lambda_{t}=\text{log-SNR}_{t}. The weighted-ELBO loss proposed in Kingma & Gao (2024) introduces an additional weighting function w(t)w(t) monotonically increasing in tt (monotonically decreasing in log-SNRt\text{log-SNR}_{t}) understood as a form of data augmentation. Specifically, they use sigmoid as the function such that the weighted ELBO is written as

where σ()\sigma(\cdot) is sigmoid function.

The αt\alpha_{t} is tailored towards the Simple-EDM and Euler-FM parameterization as we have shown in Appendix C.5 that the networks σdGθ\sigma_{d}{\bm{G}}_{\theta} amounts to vv-prediction in cosine and OT-FM schedules. Notice that ELBO diffusion loss matches ϵ\bm{\epsilon} instead of v{\mathbf{v}}. Inspecting the gradient of Laplace kernel, we have (again, for simplicity we let G^θ(xt,s,t)=Gθ(xtσdαt2+σt2,cnoise(s),cnoise(t))\hat{{\bm{G}}}_{\theta}({\mathbf{x}}_{t},s,t)={\bm{G}}_{\theta}(\frac{{\mathbf{x}}_{t}}{\sigma_{d}\sqrt{\alpha_{t}^{2}+\sigma_{t}^{2}}},c_{\text{noise}}(s),c_{\text{noise}}(t)))

for some constant G^target\hat{{\bm{G}}}_{\text{target}}. We can see that gradient \partialderivativeθfs,tθ(xt)\partialderivative{\theta}{\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}) is guided by vector G^θ(xt,s,t)G^target\hat{{\bm{G}}}_{\theta}({\mathbf{x}}_{t},s,t)-\hat{{\bm{G}}}_{\text{target}}. Assuming G^θ(xt,s,t)\hat{{\bm{G}}}_{\theta}({\mathbf{x}}_{t},s,t) is vv-prediction, as is the case for Simple-EDM parameterization with cosine schedule and Euler-FM parameterization with OT-FM schedule, we can reparameterize vv- to ϵ\epsilon-prediction with ϵθ\bm{\epsilon}_{\theta} as the new parameterization. We omit arguments to network for simplicity.

We show below that for both cases ϵθϵtarget=αt(G^θG^target)\bm{\epsilon}_{\theta}-\bm{\epsilon}_{\text{target}}=\alpha_{t}(\hat{{\bm{G}}}_{\theta}-\hat{{\bm{G}}}_{\text{target}}) for some constants ϵtarget\bm{\epsilon}_{\text{target}} and G^target\hat{{\bm{G}}}_{\text{target}}. For Simple-EDM, we know xx-prediction from vv-prediction parameterization (Salimans & Ho, 2022), xθ=αtxtσtG^θ{\mathbf{x}}_{\theta}=\alpha_{t}{\mathbf{x}}_{t}-\sigma_{t}\hat{{\bm{G}}}_{\theta}, and we also know xx-prediction from ϵ\epsilon-prediction, xθ=(xtσtϵθ)/αt{\mathbf{x}}_{\theta}=({\mathbf{x}}_{t}-\sigma_{t}\bm{\epsilon}_{\theta})/\alpha_{t}. We have

For Euler-FM, we know xx-prediction from vv-prediction parameterization, xθ=xttG^θ{\mathbf{x}}_{\theta}={\mathbf{x}}_{t}-t\hat{{\bm{G}}}_{\theta} and we also know xx-prediction from ϵ\epsilon-prediction, xθ=(xttϵθ)/(1t){\mathbf{x}}_{\theta}=({\mathbf{x}}_{t}-t\bm{\epsilon}_{\theta})/(1-t). We have

In both cases, (G^θ(xt,s,t)G^target)(\hat{{\bm{G}}}_{\theta}({\mathbf{x}}_{t},s,t)-\hat{{\bm{G}}}_{\text{target}}) can be rewritten to (ϵθ(xt,s,t)ϵtarget)(\bm{\epsilon}_{\theta}({\mathbf{x}}_{t},s,t)-\bm{\epsilon}_{\text{target}}) by multiplying a factor αt\alpha_{t}, and the guidance vector now matches that of the ELBO-diffusion loss. Therefore, we are motivated to incorporate αt\alpha_{t} into w(s,t)w(s,t) as proposed.
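The factor $\alpha_{t}$ can be verified numerically by solving the two $x$-prediction identities for $\bm{\epsilon}_{\theta}$. The sketch below works with scalars and sets $\sigma_{d}=1$; the function names and test values are illustrative assumptions.

```python
import math

def eps_from_G_cosine(x_t, G, t):
    """Simple-EDM with cosine schedule: x_pred = alpha*x_t - sigma*G = (x_t - sigma*eps)/alpha."""
    alpha, sigma = math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)
    x_pred = alpha * x_t - sigma * G
    return (x_t - alpha * x_pred) / sigma

def eps_from_G_otfm(x_t, G, t):
    """Euler-FM with OT-FM schedule: x_pred = x_t - t*G = (x_t - t*eps)/(1 - t)."""
    x_pred = x_t - t * G
    return (x_t - (1.0 - t) * x_pred) / t

x_t, t = 0.8, 0.3
G_theta, G_target = 1.1, -0.4

# Difference of eps-predictions versus alpha_t times the difference of G-predictions.
alpha_cos = math.cos(0.5 * math.pi * t)
diff_cos = eps_from_G_cosine(x_t, G_theta, t) - eps_from_G_cosine(x_t, G_target, t)

alpha_fm = 1.0 - t
diff_fm = eps_from_G_otfm(x_t, G_theta, t) - eps_from_G_otfm(x_t, G_target, t)
```

In both cases the difference equals $\alpha_{t}(\hat{{\bm{G}}}_{\theta}-\hat{{\bm{G}}}_{\text{target}})$ up to floating-point error, since the ${\mathbf{x}}_{t}$-dependent terms cancel.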

The exponent aa for αta\alpha_{t}^{a} takes a value of either 1 or 2. We explain the reason for each decision here. When a=1a=1, we guide the gradient \partialderivativeθfs,tθ(xt)\partialderivative{\theta}{\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}), with score difference (ϵθ(xt,s,t)ϵtarget)(\bm{\epsilon}_{\theta}({\mathbf{x}}_{t},s,t)-\bm{\epsilon}_{\text{target}}). To motivate a=2a=2, we first note that the weighted gradient

and as shown above that ϵ\epsilon-prediction is parameterized as

Multiplying an additional αt\alpha_{t} to \partialderivativeθG^θ(xt,s,t)\partialderivative{\theta}\hat{{\bm{G}}}_{\theta}({\mathbf{x}}_{t},s,t) therefore implicitly reparameterizes our model into an ϵ\epsilon-prediction model. The case of a=2a=2 therefore implicitly reparameterizes our model into an ϵ\epsilon-prediction model guided by the score difference (ϵθ(xt,s,t)ϵtarget)(\bm{\epsilon}_{\theta}({\mathbf{x}}_{t},s,t)-\bm{\epsilon}_{\text{target}}). Empirically, αt2\alpha_{t}^{2} additionally downweights loss for larger tt compared to αt\alpha_{t}, allowing the model to train on smaller time-steps more effectively.

Lastly, the division of αt2+σt2\alpha_{t}^{2}+\sigma_{t}^{2} is inspired by the increased weighting for middle time-steps (Esser et al., 2024) for Flow Matching training. This is purely an empirical decision.

Appendix D Training Algorithm

We present the training algorithm in Algorithm 3.

Appendix E Classifier-Free Guidance

We refer readers to Appendix C.5 for analysis of each parameterization. Most notably, the network Gθ{\bm{G}}_{\theta} in both (1) Simple-EDM with cosine diffusion schedule and (2) Euler-FM with OT-FM schedule are equivalent to vv-prediction parameterization in diffusion (Salimans & Ho, 2022) and FM (Lipman et al., 2022). When conditioned on label c{\mathbf{c}} during sampling, it is customary to use classifier-free guidance to reweight this vv-prediction network via

with guidance weight ww so that the classifier-free guided fs,t,wθ(xt){\bm{f}}_{s,t,w}^{\theta}({\mathbf{x}}_{t}) is

Appendix F Sampling Algorithms

Pushforward sampling. See Algorithm 4. We assume a series of NN time steps {ti}i=0N\{t_{i}\}_{i=0}^{N} with T=tN>tN1>>t2>t1>t0=ϵT=t_{N}>t_{N-1}>\dots>t_{2}>t_{1}>t_{0}=\epsilon for the maximum time TT and minimum time ϵ\epsilon. Denote σd\sigma_{d} as data standard deviation.

Restart sampling. See Algorithm 5. Unlike pushforward sampling, the $N$ time steps $\{t_{i}\}_{i=0}^{N}$ do not need to be strictly decreasing, e.g. $T=t_{N}\geq t_{N-1}\geq\dots\geq t_{2}\geq t_{1}\geq t_{0}=\epsilon$ (assuming $T>\epsilon$). Restart sampling first denoises to a clean sample and then resamples noise to add to this clean sample. A clean sample is then predicted again, and the process is iterated for $N$ steps.
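The two samplers described above can be sketched as follows. This is not a reproduction of Algorithms 4 and 5; it is a minimal reading of the prose, assuming a scalar state, a callable `f(x, s, t)` that maps a sample at time `t` to time `s`, and a hypothetical toy model used only for the check.

```python
import random

def pushforward_sample(f, times, x_T):
    """times = [t_0, ..., t_N] in increasing order; x_T is a sample at t_N."""
    x = x_T
    for i in range(len(times) - 1, 0, -1):
        x = f(x, times[i - 1], times[i])          # map t_i -> t_{i-1}
    return x

def restart_sample(f, times, x_T, schedule, sigma_d=1.0, rng=random):
    """Denoise to t_0, re-noise to the next time, and repeat (N denoising calls)."""
    t_min = times[0]
    x = x_T
    for i in range(len(times) - 1, 0, -1):
        x = f(x, t_min, times[i])                 # predict a (near-)clean sample
        if i > 1:
            alpha, sigma = schedule(times[i - 1])
            x = alpha * x + sigma * sigma_d * rng.gauss(0.0, 1.0)  # add fresh noise
    return x

def otfm_schedule(t):
    """OT-FM schedule: alpha_t = 1 - t, sigma_t = t."""
    return 1.0 - t, t

def toy_f(x_t, s, t):
    """Hypothetical stand-in model: DDIM under OT-FM with x-prediction fixed to 0."""
    return (s / t) * x_t

out_push = pushforward_sample(toy_f, [0.01, 0.5, 1.0], 2.0)
out_restart_1 = restart_sample(toy_f, [0.01, 1.0], 2.0, otfm_schedule)
out_restart_2 = restart_sample(toy_f, [0.01, 0.5, 1.0], 2.0, otfm_schedule,
                               rng=random.Random(0))
```

With a single step ($N=1$) the two samplers coincide, since no noise is re-injected; they differ only when intermediate time steps are present.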

Appendix G Connection with Prior Works

We show that CM loss is a special case of our simplified IMM objective.

Since xt=xt{\mathbf{x}}_{t}={\mathbf{x}}_{t}^{\prime}, xr=xr{\mathbf{x}}_{r}={\mathbf{x}}_{r}^{\prime}, we have fs,tθ(xt)=fs,tθ(xt){\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t})={\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}^{\prime}) and fs,rθ(xr)=fs,rθ(xr){\bm{f}}_{s,r}^{\theta}({\mathbf{x}}_{r})={\bm{f}}_{s,r}^{\theta}({\mathbf{x}}_{r}^{\prime}). So k(fs,tθ(xt),fs,tθ(xt))=k(fs,rθ(xr),fs,rθ(xr))=0k({\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}),{\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}^{\prime}))=k({\bm{f}}_{s,r}^{\theta}({\mathbf{x}}_{r}),{\bm{f}}_{s,r}^{\theta}({\mathbf{x}}_{r}^{\prime}))=0 by definition. Since k(x,y)=\normxy2k(x,y)=-\norm{x-y}^{2}, it is easy to see Eq. (12) reduces to

where w(s,t)w(s,t) is a weighting function. If ss is a small positive constant, we further have fs,tθ(xt)gθ(xt,t){\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t})\approx{\bm{g}}_{\theta}({\mathbf{x}}_{t},t) where we drop ss as input. If gθ(xt,t){\bm{g}}_{\theta}({\mathbf{x}}_{t},t) itself satisfies boundary condition at s=0s=0, we can directly take s=0s=0 in which case f0,tθ(xt)=gθ(xt,t){\bm{f}}_{0,t}^{\theta}({\mathbf{x}}_{t})={\bm{g}}_{\theta}({\mathbf{x}}_{t},t). And under these assumptions, our loss becomes

However, from a moment-matching perspective, this loss deviates significantly from a proper divergence between distributions and is problematic in two respects. First, it assumes a single-particle estimate, which ignores the entropy repulsion term in MMD that arises only in multi-particle estimation. This can contribute to mode collapse and training instability in CMs. Second, the energy kernel is not a proper positive definite kernel as required by MMD. At best, it matches only the first moment (its Taylor expansion cannot cover all moments as in RBF kernels), which is insufficient for matching two complex distributions. We should use kernels that match higher moments in practice. In fact, we show in the following Lemma that the pseudo-Huber loss proposed in Song & Dhariwal (2023) matches higher moments as a kernel.

We know that negative $L_{2}$ distance $-\norm{x-y}$ is conditionally positive definite. We prove this below for completeness. Due to the triangle inequality, $-\norm{x-y}\geq-\norm{x}-\norm{y}$. Then

where (a)(a) is due to i=1nci=0\sum_{i=1}^{n}c_{i}=0. Now since c\normz2+c2\normzc-\sqrt{\norm{z}^{2}+c^{2}}\geq-\norm{z} for all c>0c>0, we have

So the negative pseudo-Huber loss is a valid conditionally positive definite kernel.

Next, we analyze the pseudo-Huber loss's effect on higher-order moments by directly Taylor expanding $\sqrt{\norm{z}^{2}+c^{2}}-c$ at $z=0$

where we substitute z=xyz=x-y. Each higher order \normxyk\norm{x-y}^{k} for k>2k>2 expands to a polynomial containing up to kk-th moments, i.e., {x,x2,xk},{y,y2,yk}\{x,x^{2},\dots x^{k}\},\{y,y^{2},\dots y^{k}\}, thus the implicit feature map contains all higher moments where cc contributes to the weightings in front of each term. ∎
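For reference, the leading terms of this expansion follow from the binomial series: $\sqrt{\norm{z}^{2}+c^{2}}-c=\frac{\norm{z}^{2}}{2c}-\frac{\norm{z}^{4}}{8c^{3}}+\frac{\norm{z}^{6}}{16c^{5}}-\dots$, valid for $\norm{z}<c$. The short numerical check below compares the exact value against this three-term truncation; the function names and test point are illustrative.

```python
import math

def pseudo_huber(z_norm, c):
    """Exact pseudo-Huber value sqrt(|z|^2 + c^2) - c."""
    return math.sqrt(z_norm ** 2 + c ** 2) - c

def pseudo_huber_taylor3(z_norm, c):
    """First three terms of the binomial series in u = |z|^2."""
    u = z_norm ** 2
    return u / (2 * c) - u ** 2 / (8 * c ** 3) + u ** 3 / (16 * c ** 5)

exact = pseudo_huber(0.1, 1.0)
approx = pseudo_huber_taylor3(0.1, 1.0)
```

Every even power of $\norm{z}$ appears with a coefficient that depends on $c$, which is how $c$ controls the weighting of the higher moments.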

Furthermore, we extend our finite difference (between r(s,t)r(s,t) and tt) IMM objective to the differential limit by taking r(s,t)tr(s,t)\rightarrow t in Appendix H. This results in a new objective that similarly subsumes continuous-time CM (Song et al., 2023; Lu & Song, 2024) as a single-particle special case.

G.2 Diffusion GAN and Adversarial Consistency Distillation

Diffusion GAN (Xiao et al., 2021) parameterizes its generative distribution as

where Gθ{\bm{G}}_{\theta} is a neural network, p(z)p({\mathbf{z}}) is standard Gaussian distribution, and qst(xsx,xt)q_{s|t}({\mathbf{x}}_{s}|{\mathbf{x}},{\mathbf{x}}_{t}) is the DDPM posterior

Note that DDPM posterior is a stochastic interpolant, and more importantly, it is self-consistent, which we show in the Lemma below.

For all 0s<t10\leq s<t\leq 1, DDPM posterior distribution from tt to ss as defined in Eq. (97) is a self-consistent Gaussian interpolant between x{\mathbf{x}} and xt{\mathbf{x}}_{t}.

Let xrqrt(xrx,xt){\mathbf{x}}_{r}\sim q_{r|t}({\mathbf{x}}_{r}|{\mathbf{x}},{\mathbf{x}}_{t}) and xsqsr(xsx,xr){\mathbf{x}}_{s}\sim q_{s|r}({\mathbf{x}}_{s}|{\mathbf{x}},{\mathbf{x}}_{r}), we show that xs{\mathbf{x}}_{s} follows qst(xsx,xt)q_{s|t}({\mathbf{x}}_{s}|{\mathbf{x}},{\mathbf{x}}_{t}).

where ϵ1,ϵ2N(0,I)\bm{\epsilon}_{1},\bm{\epsilon}_{2}\sim{\mathcal{N}}(0,{I}) are i.i.d. Gaussian noise. Directly expanding

where (a)(a) is due to the fact that sum of two independent Gaussian variables with variance a2a^{2} and b2b^{2} is also Gaussian with variance a2+b2a^{2}+b^{2}, and ϵ3N(0,I)\bm{\epsilon}_{3}\sim{\mathcal{N}}(0,{I}) is another independent Gaussian noise. We show the calculation of the variance:

This shows xs{\mathbf{x}}_{s} follows qst(xsx,xt)q_{s|t}({\mathbf{x}}_{s}|{\mathbf{x}},{\mathbf{x}}_{t}) and completes the proof. ∎

This shows another possible design of the interpolant that can be used, and diffusion GAN’s formulation generally complies with our design of the generative distribution, except that it learns this conditional distribution of x{\mathbf{x}} given xt{\mathbf{x}}_{t} directly while we learn a marginal distribution. When they directly learn the conditional distribution by matching pstθ(xxt)p_{s|t}^{\theta}({\mathbf{x}}|{\mathbf{x}}_{t}) with qt(xxt)q_{t}({\mathbf{x}}|{\mathbf{x}}_{t}), the model is forced to learn qt(xxt)q_{t}({\mathbf{x}}|{\mathbf{x}}_{t}) and there only exists one minimizer. However, in our case, the model can learn multiple different solutions because we match the marginals instead.

GAN loss and MMD loss. We also want to draw attention to similarity between GAN loss used in Xiao et al. (2021); Sauer et al. (2025) and MMD loss. MMD is an integral probability metric over a set of functions F{\mathcal{F}} in the following form

where a supremum is taken on this set of functions. This naturally gives rise to an adversarial optimization algorithm if F{\mathcal{F}} is defined as the set of neural networks. However, MMD bypasses this by selecting F{\mathcal{F}} as the RKHS where the optimal ff can be analytically found. This eliminates the adversarial objective and gives a stable minimization objective in practice. However, this is not to say that RKHS is the best function set. With the right optimizers and training scheme, the adversarial objective may achieve better empirical performance, but this also makes the algorithm difficult to scale to large datasets.

G.3 Generative Moment Matching Network

It is trivial to check that GMMN is a special parameterization. We fix $t=1$, and due to the boundary condition, $r(s,t)\equiv s=0$ implies that the training target $p_{s|r(s,t)}^{\theta^{-}}({\mathbf{x}}_{s})=q_{s}({\mathbf{x}}_{s})$ is the data distribution. Additionally, $p_{s|t}^{\theta}({\mathbf{x}}_{s})$ is a simple pushforward of the prior $p(\bm{\epsilon})$ through the network ${\bm{g}}_{\theta}(\bm{\epsilon})$, where we drop the dependency on $t$ and $s$ since they are constant.

Appendix H Differential Inductive Moment Matching

Similar to the continuous-time CMs presented in Lu & Song (2024), our MMD objective can be taken to the differential limit. Considering the simplified loss and parameterization in Eq. (12), we use the RBF kernel as our kernel of choice for simplicity.

Let ${\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t})$ be a twice continuously differentiable function with bounded first and second derivatives, let $k(\cdot,\cdot)$ be the RBF kernel with unit bandwidth, ${\mathbf{x}},{\mathbf{x}}^{\prime}\sim q({\mathbf{x}})$, ${\mathbf{x}}_{t}\sim q_{t}({\mathbf{x}}_{t}|{\mathbf{x}})$, ${\mathbf{x}}_{t}^{\prime}\sim q_{t}({\mathbf{x}}_{t}^{\prime}|{\mathbf{x}}^{\prime})$, ${\mathbf{x}}_{r}=\operatorname{DDIM}({\mathbf{x}}_{t},{\mathbf{x}},t,r)$ and ${\mathbf{x}}_{r}^{\prime}=\operatorname{DDIM}({\mathbf{x}}_{t}^{\prime},{\mathbf{x}}^{\prime},t,r)$, the following objective

Putting it together, the above results imply

since it is easy to check that the remaining terms cancel.

Substituting x=fs,tθ(xt)x={\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}), a=fs,rθ(xr)a={\bm{f}}_{s,r}^{\theta}({\mathbf{x}}_{r}), y=fs,tθ(xt)y={\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}^{\prime}), b=fs,rθ(xr)b={\bm{f}}_{s,r}^{\theta}({\mathbf{x}}_{r}^{\prime}), we furthermore have

Therefore, LIMM-(θ,t){\mathcal{L}}_{\text{IMM-}\infty}(\theta,t) can be derived as

Due to the stop-gradient operation, we can similarly find a pseudo-objective whose gradient matches the gradient of LIMM-(θ,t){\mathcal{L}}_{\text{IMM-}\infty}(\theta,t) in the limit of rtr\rightarrow t.

Let fs,tθ(xt){\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}) be a twice continuously differentiable function with bounded first and second derivatives, k(,)k(\cdot,\cdot) be RBF kernel with unit bandwidth, x,xq(x){\mathbf{x}},{\mathbf{x}}^{\prime}\sim q({\mathbf{x}}), xtqt(xtx){\mathbf{x}}_{t}\sim q_{t}({\mathbf{x}}_{t}|{\mathbf{x}}), xtqt(xtx){\mathbf{x}}_{t}^{\prime}\sim q_{t}({\mathbf{x}}_{t}^{\prime}|{\mathbf{x}}^{\prime}), xr=DDIM(xt,x,t,r){\mathbf{x}}_{r}=\operatorname{DDIM}({\mathbf{x}}_{t},{\mathbf{x}},t,r) and xr=DDIM(xt,x,t,r){\mathbf{x}}_{r}^{\prime}=\operatorname{DDIM}({\mathbf{x}}_{t}^{\prime},{\mathbf{x}}^{\prime},t,r), the gradient of the following pseudo-objective

can be used to optimize θ\theta and can be analytically derived as

Similar to the derivation of LIMM-(θ,t){\mathcal{L}}_{\text{IMM-}\infty}(\theta,t), let x=fs,tθ(xt)x={\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}), a=fs,rθ(xr)a={\bm{f}}_{s,r}^{\theta^{-}}({\mathbf{x}}_{r}), y=fs,tθ(xt)y={\bm{f}}_{s,t}^{\theta}({\mathbf{x}}_{t}^{\prime}), b=fs,rθ(xr)b={\bm{f}}_{s,r}^{\theta^{-}}({\mathbf{x}}_{r}^{\prime}), we have

Note that $\frac{\mathrm{d}\bm{f}_{s,t}^{\theta^{-}}(\mathbf{x}_t)}{\mathrm{d}t}$ is now parameterized by $\theta^{-}$ instead of $\theta$ because the gradient is already taken w.r.t. $\bm{f}_{s,t}^{\theta}(\mathbf{x}_t)$ outside of the brackets, so $(x-a)$ and $(y-b)$ merely require evaluation at the current $\theta$ with no gradient information, which $\theta^{-}$ satisfies. The objective can be derived as

H.2 Connection with Continuous-Time CMs

Observing Eq. (105) and Eq. (110), we can see that when $\mathbf{x}_t=\mathbf{x}_t'$ and $\mathbf{x}_r=\mathbf{x}_r'$, with $s$ a small positive constant, we have $\bm{f}_{s,t}^{\theta}(\mathbf{x}_t')=\bm{f}_{s,t}^{\theta}(\mathbf{x}_t)$, so $\exp\left(-\frac{1}{2}\left\|\bm{f}_{s,t}^{\theta}(\mathbf{x}_t')-\bm{f}_{s,t}^{\theta}(\mathbf{x}_t)\right\|^{2}\right)=1$, and $\bm{f}_{s,t}^{\theta}(\mathbf{x}_t')\approx\bm{g}_{\theta}(\mathbf{x}_t,t)$, where, since $s$ is fixed, we drop the dependency on $s$ as an input. Then, Eq. (105) reduces to

which is the same as the differential consistency loss (Song et al., 2023; Geng et al., 2024). Likewise, Eq. (110) reduces to

which is the pseudo-objective for continuous-time CMs (Song et al., 2023; Lu & Song, 2024) (minus a weighting function of choice).
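The reduction above rests on the unit-bandwidth RBF kernel evaluating to exactly 1 when both particles coincide. The following is a minimal sketch of that single fact, assuming the kernel form $\exp(-\frac{1}{2}\|x-y\|^{2})$ used in the expression above; the vectors are hypothetical stand-ins for network outputs, not values from the paper.

```python
import math

def rbf_kernel(x, y):
    """Unit-bandwidth RBF kernel: exp(-0.5 * ||x - y||^2)."""
    sq_dist = sum((xi - yi) ** 2 for xi, yi in zip(x, y))
    return math.exp(-0.5 * sq_dist)

# Hypothetical network outputs for two particles (illustrative values only).
f_xt = [0.3, -1.2, 0.7]
f_xt_prime = [0.1, -0.9, 1.0]

# Two distinct particles: the kernel weight lies strictly between 0 and 1.
k_two_particles = rbf_kernel(f_xt, f_xt_prime)

# Single-particle case (x_t = x_t'): the kernel weight collapses to 1.
k_single_particle = rbf_kernel(f_xt, f_xt)

print(k_two_particles, k_single_particle)
```

With the kernel weight fixed at 1, the kernel no longer reweights the loss terms, which is the sense in which the single-particle case loses the interaction between samples.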

Appendix I Experiment Settings

We summarize our best runs in Table 5. Specifically, for ImageNet-$256\times256$, we adopt a latent-space paradigm for computational efficiency. For the autoencoder, we follow EDM2 (Karras et al., 2024): we pre-encode all ImageNet images into latents without flipping and calculate the channel-wise mean and std for normalization. We use the Stable Diffusion VAE (https://huggingface.co/stabilityai/sd-vae-ft-mse) and rescale the latents by channel mean $[0.86488, -0.27787343, 0.21616915, 0.3738409]$ and channel std $[4.85503674, 5.31922414, 3.93725398, 3.9870003]$. After this normalization, we further multiply the latents by $0.5$ so that they have std of roughly $0.5$. For DiT architectures of different sizes, we use the same hyperparameters in all experiments.
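The latent rescaling described above can be sketched as follows. This is an illustrative sketch, not the authors' code: it assumes that "rescale by channel mean and std" means subtracting the mean and dividing by the std per channel, and it represents a latent as four flat lists of floats. The function names are hypothetical.

```python
# Channel statistics quoted in the text for the Stable Diffusion VAE latents.
CHANNEL_MEAN = [0.86488, -0.27787343, 0.21616915, 0.3738409]
CHANNEL_STD = [4.85503674, 5.31922414, 3.93725398, 3.9870003]
LATENT_SCALE = 0.5  # final multiplier so latents have std of roughly 0.5

def normalize_latent(latent):
    """Assumed form: 0.5 * (z - mean) / std, applied per channel."""
    return [
        [LATENT_SCALE * (v - mean) / std for v in channel]
        for channel, mean, std in zip(latent, CHANNEL_MEAN, CHANNEL_STD)
    ]

def denormalize_latent(latent):
    """Inverse map, applied before decoding with the VAE."""
    return [
        [v / LATENT_SCALE * std + mean for v in channel]
        for channel, mean, std in zip(latent, CHANNEL_MEAN, CHANNEL_STD)
    ]

# Toy latent: per channel, one value at the mean and one a full std above it.
raw = [[m, m + s] for m, s in zip(CHANNEL_MEAN, CHANNEL_STD)]
normalized = normalize_latent(raw)
restored = denormalize_latent(normalized)
print(normalized[0])
```

Under this assumed form, a value at the channel mean maps to 0 and a value one std above the mean maps to 0.5.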

Choices for $T$ and $\epsilon$. By default, assuming the mapping function $r(s,t)$ is given by constant decrement in $\eta_t$, we keep $\eta_{\text{max}}\approx 160$. This implies that, for a time distribution of the form $\mathcal{U}(\epsilon,T)$, we set $T=0.996$ for cosine diffusion and $T=0.994$ for OT-FM. We set $\epsilon$ differently for pixel-space and latent-space models. For pixel space on CIFAR-10, we follow Nichol & Dhariwal (2021) and set $\epsilon$ to a small positive constant because pixel quantization makes smaller noise imperceptible; we find $0.006$ to work well. For latent space on ImageNet-$256\times256$, we have no such intuition as in pixel space, so we simply set $\epsilon=0$.

Exceptions occur when we ablate other choices of $r(s,t)$, e.g. constant decrement in $\lambda_t$, in which case we set $\epsilon=0.001$ to prevent $r(s,t)$ from being too close to $t$ when $t$ is small.
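The default values of $T$ can be sanity-checked numerically. The sketch below assumes $\eta_t=\sigma_t/\alpha_t$, with $\alpha_t=\cos(\pi t/2)$, $\sigma_t=\sin(\pi t/2)$ for cosine diffusion and $\alpha_t=1-t$, $\sigma_t=t$ for OT-FM. These definitions are not restated in this appendix, so they are assumptions of the example, and the function names are hypothetical.

```python
import math

def eta_cosine(t):
    # Assumed cosine schedule: alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2),
    # so eta_t = sigma_t / alpha_t = tan(pi t / 2).
    return math.tan(0.5 * math.pi * t)

def eta_otfm(t):
    # Assumed OT-FM schedule: alpha_t = 1 - t, sigma_t = t,
    # so eta_t = t / (1 - t).
    return t / (1.0 - t)

eta_max_cosine = eta_cosine(0.996)
eta_max_otfm = eta_otfm(0.994)

print(f"cosine, T=0.996: eta_max = {eta_max_cosine:.2f}")
print(f"OT-FM,  T=0.994: eta_max = {eta_max_otfm:.2f}")
```

Under these assumptions both values land in the neighbourhood of 160, consistent with the stated $\eta_{\text{max}}\approx 160$.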

Injecting time $s$. The designs for additionally injecting $s$ fall into two types: injecting $s$ directly and injecting the stride size $(t-s)$. In both cases, the architecture exactly follows the time injection of $t$. We extract the positional time embedding of $s$ (or $t-s$), feed it through a 2-layer MLP (same as for $t$), and add the result to the embedding of $t$ after its MLP. The summed embedding is then fed through all the Transformer blocks as in the standard DiT architecture.
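The injection scheme above can be sketched in plain Python. This is an illustrative toy, not the DiT implementation. It assumes a sinusoidal positional embedding, a bias-free 2-layer MLP with SiLU activation, and separate weights for the $t$ and $s$ branches ("same as for $t$" is read here as "same architecture"). All dimensions, names, and the random initialisation are hypothetical.

```python
import math
import random

def sinusoidal_embedding(time, dim=8, max_period=10000.0):
    """Sinusoidal positional embedding of a scalar time (dim must be even)."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.cos(time * f) for f in freqs] + [math.sin(time * f) for f in freqs]

def make_mlp(in_dim, hidden_dim, out_dim, rng):
    """Random weights for a bias-free 2-layer MLP (illustrative init only)."""
    def layer(n_in, n_out):
        return [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return layer(in_dim, hidden_dim), layer(hidden_dim, out_dim)

def silu(v):
    return v / (1.0 + math.exp(-v))

def apply_mlp(params, x):
    w1, w2 = params
    hidden = [silu(sum(w * xi for w, xi in zip(row, x))) for row in w1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w2]

def conditioning(t, s, mlp_t, mlp_s, use_stride=False):
    """Embed t and s (or the stride t - s) separately, then sum the results."""
    second = (t - s) if use_stride else s
    emb_t = apply_mlp(mlp_t, sinusoidal_embedding(t))
    emb_s = apply_mlp(mlp_s, sinusoidal_embedding(second))
    return [a + b for a, b in zip(emb_t, emb_s)]

rng = random.Random(0)
mlp_t = make_mlp(8, 16, 16, rng)
mlp_s = make_mlp(8, 16, 16, rng)

cond_direct = conditioning(0.8, 0.2, mlp_t, mlp_s)                   # inject s
cond_stride = conditioning(0.8, 0.2, mlp_t, mlp_s, use_stride=True)  # inject t - s
print(len(cond_direct), len(cond_stride))
```

The summed vector plays the role of the conditioning signal that the Transformer blocks would consume; the two variants differ only in which scalar feeds the second branch.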

Improved CT baseline. For ImageNet-$256\times256$, we implement the iCT baseline using our improved parameterization with Simple-EDM and the OT-FM schedule. We use the proposed pseudo-Huber loss for training but find that training often collapses when using the same $r(s,t)$ schedule as ours. We carefully tune the gap to achieve reasonable performance without collapse and present our results in Table 2.

I.2 Inference Settings

Inference schedules. For all one-step inference, we go directly from $\bm{\epsilon}\sim\mathcal{N}(0,\sigma_d^{2}I)$ at time $T$ to time $\epsilon$ through pushforward sampling. For all 2-step methods, we set the intermediate timestep $t_1$ such that $\eta_{t_1}=1.4$; this choice is purely empirical, and we find it to work well. For $N\geq 4$ steps we explore two types of time schedules: (1) uniform decrement in $t$ with $\eta_0<\eta_1<\dots<\eta_N$ where

and (2) the EDM (Karras et al., 2022) time schedule. The EDM schedule specifies $\eta_0<\eta_1<\dots<\eta_N$ where

We slightly modify the schedule so that $\eta_0=\eta_{\text{min}}$ is the endpoint, instead of $\eta_1=\eta_{\text{min}}$ and $\eta_0=0$ as originally proposed, since our $\eta_0$ can be set to $0$ without numerical issue.

We also specify the time schedule type used for our best runs in Table 5 and their results.
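For the 2-step setting, the intermediate timestep $t_1$ is specified through $\eta_{t_1}=1.4$ rather than through $t_1$ itself, so the schedule has to be inverted. The sketch below assumes $\eta_t=\sigma_t/\alpha_t$ with the cosine and OT-FM forms $\eta_t=\tan(\pi t/2)$ and $\eta_t=t/(1-t)$. These forms are assumptions of the example rather than statements from this appendix, and the function names are hypothetical.

```python
import math

def t_from_eta_cosine(eta):
    # Inverse of the assumed cosine relation eta_t = tan(pi t / 2).
    return 2.0 / math.pi * math.atan(eta)

def t_from_eta_otfm(eta):
    # Inverse of the assumed OT-FM relation eta_t = t / (1 - t).
    return eta / (1.0 + eta)

ETA_T1 = 1.4  # intermediate noise level quoted in the text for 2-step sampling

t1_cosine = t_from_eta_cosine(ETA_T1)
t1_otfm = t_from_eta_otfm(ETA_T1)

print(f"cosine: t_1 = {t1_cosine:.4f}")
print(f"OT-FM:  t_1 = {t1_otfm:.4f}")
```

A 2-step sampler under these assumptions would then move from $T$ to $t_1$ and from $t_1$ to $\epsilon$.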

I.3 Scaling Settings

Model GFLOPs. We reuse numbers from DiT (Peebles & Xie, 2023) for each model architecture.

Training compute. Following Peebles & Xie (2023), we use the formula $\text{model GFLOPs}\cdot\text{batch size}\cdot\text{training steps}\cdot 4$ for training compute. Different from DiT, the constant is 4 because each iteration has two forward passes and one backward pass, the latter estimated as twice the forward compute.

Inference compute. We calculate inference compute via $\text{model GFLOPs}\cdot\text{number of steps}$.
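The two compute formulas above amount to simple products. The sketch below spells out where the constant 4 comes from; the model size, batch size, and step counts are placeholder values chosen for illustration, not settings from the paper.

```python
FORWARD_PASSES_PER_ITER = 2      # two forward passes per training iteration
BACKWARD_COST_IN_FORWARDS = 2    # one backward pass, estimated as 2x a forward pass

def training_compute_gflops(model_gflops, batch_size, training_steps):
    # Constant 4 = 2 forward passes + 1 backward pass counted as 2 forwards.
    per_sample_factor = FORWARD_PASSES_PER_ITER + BACKWARD_COST_IN_FORWARDS
    return model_gflops * batch_size * training_steps * per_sample_factor

def inference_compute_gflops(model_gflops, num_steps):
    # One forward pass per sampling step.
    return model_gflops * num_steps

# Placeholder values for illustration only.
train_total = training_compute_gflops(model_gflops=100.0, batch_size=1024, training_steps=1000)
infer_total = inference_compute_gflops(model_gflops=100.0, num_steps=8)

print(train_total, infer_total)
```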

I.4 Scaling Beyond 8 Steps

We investigate when the performance saturates for ImageNet-$256\times256$ in Table 6. We see continued improvement beyond 8 steps, and at 16 steps our method already outperforms VAR with 2B parameters (1.92 FID).

I.5 Ablation on exponent $a$

We compare the performance between $a=1$ and $a=2$ on the full DiT-XL architecture in Table 7, which shows how $a$ affects results at different numbers of sampling steps. We observe that $a=2$ gives slightly higher $1$-step and $2$-step sampling FID but outperforms $a=1$ at $4$ and $8$ steps.

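Table 7: FID for $a=1$ and $a=2$ at different numbers of sampling steps.

| | 1-step | 2-step | 4-step | 8-step |
|---|---|---|---|---|
| $a=1$ | 7.97 | 4.01 | 2.61 | 2.13 |
| $a=2$ | 8.28 | 4.08 | 2.60 | 2.01 |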

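Table 8: FID for TF32 and FP16 training at different numbers of sampling steps.

| | 1-step | 2-step | 4-step | 8-step |
|---|---|---|---|---|
| TF32 w/ $a=1$ | 7.97 | 4.01 | 2.61 | 2.13 |
| FP16 w/ $a=1$ | 8.73 | 4.54 | 3.03 | 2.38 |
| FP16 w/ $a=2$ | 8.05 | 3.99 | 2.51 | 1.99 |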

I.6 Caveats for Lower-Precision Training

For all experiments, we follow the original works (Song et al., 2020b; Peebles & Xie, 2023) and use the default TF32 precision for training and evaluation. When switching to lower precision such as FP16, we find that our mapping function, i.e. constant decrement in $\eta_t$, can produce indistinguishable time embeddings after some MLP layers when $t$ is large. To mitigate this issue, we simply impose a minimum gap $\Delta$ between any $t$ and $r$, for example $\Delta=10^{-4}$. Our resulting mapping function becomes

Optionally, we can also increase distinguishability between nearby time-steps inside the network by injecting $(r-s)$ instead of $s$ as our second time condition. We use this as default for FP16 training. With these simple changes, we observe minimal impact on generation performance.

Lastly, if training from scratch with lower precision, we recommend FP16 over BF16 because higher precision is needed to distinguish between nearby $t$ and $r$.
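The precision gap between the two 16-bit formats can be illustrated with standard-library tools. The sketch below rounds two nearby scalars to FP16 (via the `struct` half-precision format) and to BF16 (emulated by rounding the float32 bit pattern to its top 16 bits). The values of $t$ and $r$ are chosen for illustration only, and the example concerns raw scalars rather than the embeddings inside the network that the text refers to.

```python
import struct

def to_fp16(x):
    """Round a Python float to IEEE half precision (10 mantissa bits) and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x):
    """Round to bfloat16 (7 mantissa bits) via round-to-nearest-even on float32 bits.

    Emulation for ordinary finite values; NaN and overflow are not handled.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round to nearest, ties to even
    bits &= 0xFFFF0000                    # keep only the top 16 bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]

# Two nearby times (illustrative values, not settings from the paper).
t = 0.75
r = t - 2.0 ** -10

print("FP16:", to_fp16(t), to_fp16(r))   # remain distinct
print("BF16:", to_bf16(t), to_bf16(r))   # collapse to the same value
```

In this example the pair survives in FP16 but collapses in BF16, because BF16 trades mantissa bits for exponent range.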

We show results in Table 8. For FP16, $a=1$ causes slight performance degradation because of the small gap issue at large $t$. This is effectively resolved by $a=2$, which downweights losses at large $t$ by focusing on smaller $t$ instead. At lower precision, while not necessary, $a=2$ is an effective solution to achieve good performance that matches or even surpasses that of TF32.

Appendix J Additional Visualization

We present additional visualization results on the following page.