Scalable Diffusion Models with Transformers
William Peebles, Saining Xie
Introduction
Machine learning is experiencing a renaissance powered by transformers. Over the past five years, neural architectures for natural language processing, vision and several other domains have largely been subsumed by transformers. Many classes of image-level generative models remain holdouts to the trend, though: while transformers see widespread use in autoregressive models, they have seen less adoption in other generative modeling frameworks. For example, diffusion models have been at the forefront of recent advances in image-level generative models; yet, they all adopt a convolutional U-Net architecture as the de facto choice of backbone.
The seminal work of Ho et al. first introduced the U-Net backbone for diffusion models. Having initially seen success within pixel-level autoregressive models and conditional GANs, the U-Net was inherited from PixelCNN++ with a few changes. The model is convolutional and composed primarily of ResNet blocks. In contrast to the standard U-Net, additional spatial self-attention blocks, which are essential components in transformers, are interspersed at lower resolutions. Dhariwal and Nichol ablated several architecture choices for the U-Net, such as the use of adaptive normalization layers to inject conditional information and the channel counts of convolutional layers. However, the high-level design of the U-Net from Ho et al. has largely remained intact.
With this work, we aim to demystify the significance of architectural choices in diffusion models and offer empirical baselines for future generative modeling research. We show that the U-Net inductive bias is not crucial to the performance of diffusion models and that the U-Net can be readily replaced with standard designs such as transformers. As a result, diffusion models are well-poised to benefit from the recent trend of architecture unification, e.g., by inheriting best practices and training recipes from other domains, as well as retaining favorable properties like scalability, robustness and efficiency. A standardized architecture would also open up new possibilities for cross-domain research.
In this paper, we focus on a new class of diffusion models based on transformers. We call them Diffusion Transformers, or DiTs for short. DiTs adhere to the best practices of Vision Transformers (ViTs), which have been shown to scale more effectively for visual recognition than traditional convolutional networks (e.g., ResNet).
More specifically, we study the scaling behavior of transformers with respect to network complexity vs. sample quality. We show that by constructing and benchmarking the DiT design space under the Latent Diffusion Models (LDMs) framework, where diffusion models are trained within a VAE's latent space, we can successfully replace the U-Net backbone with a transformer. We further show that DiTs are scalable architectures for diffusion models: there is a strong correlation between network complexity (measured by Gflops) and sample quality (measured by FID). By simply scaling up DiT and training an LDM with a high-capacity backbone (118.6 Gflops), we achieve a state-of-the-art FID of 2.27 on the class-conditional $256 \times 256$ ImageNet generation benchmark.
Related Work
Transformers have replaced domain-specific architectures across language, vision, reinforcement learning and meta-learning. They have shown remarkable scaling properties under increasing model size, training compute and data in the language domain, as generic autoregressive models and as ViTs. Beyond language, transformers have been trained to autoregressively predict pixels. They have also been trained on discrete codebooks as both autoregressive models and masked generative models; the former has shown excellent scaling behavior up to 20B parameters. Finally, transformers have been explored in DDPMs to synthesize non-spatial data, e.g., to generate CLIP image embeddings in DALLE 2. In this paper, we study the scaling properties of transformers when used as the backbone of diffusion models of images.
Denoising diffusion probabilistic models (DDPMs).
Diffusion and score-based generative models have been particularly successful as generative models of images, in many cases outperforming generative adversarial networks (GANs), which had previously been state-of-the-art. Improvements in DDPMs over the past two years have largely been driven by improved sampling techniques (most notably classifier-free guidance), by reformulating diffusion models to predict noise instead of pixels, and by cascaded DDPM pipelines in which low-resolution base diffusion models are trained in parallel with upsamplers. For all the diffusion models listed above, convolutional U-Nets are the de facto choice of backbone architecture. Concurrent work introduced a novel, efficient attention-based architecture for DDPMs; we explore pure transformers.
Architecture complexity.
When evaluating architecture complexity in the image generation literature, it is fairly common practice to use parameter counts. In general, parameter counts can be poor proxies for the complexity of image models since they do not account for, e.g., image resolution, which significantly impacts performance. Instead, much of the model complexity analysis in this paper is through the lens of theoretical Gflops. This brings us in line with the architecture design literature, where Gflops are widely used to gauge complexity. In practice, the golden complexity metric is still up for debate, as it frequently depends on particular application scenarios. Nichol and Dhariwal's seminal work improving diffusion models is most related to ours: there, they analyzed the scalability and Gflop properties of the U-Net architecture class. In this paper, we focus on the transformer class.
Diffusion Transformers
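Before introducing our architecture, we briefly review some basic concepts needed to understand diffusion models (DDPMs). Gaussian diffusion models assume a forward noising process which gradually applies noise to real data $x_0$: $q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}\,x_0, (1-\bar{\alpha}_t)\mathbf{I})$, where the constants $\bar{\alpha}_t$ are hyperparameters. By applying the reparameterization trick, we can sample $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon_t$, where $\epsilon_t \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$.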
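The forward process can be sketched in a few lines of NumPy. This is an illustrative sketch, not the authors' JAX code. The helper name `q_sample` and the 1000-step linear schedule used to build `alpha_bar` are assumptions made for the example.

```python
import numpy as np


def q_sample(x0, t, alpha_bar, rng):
    """Draw x_t ~ q(x_t | x_0) with the reparameterization trick.

    Returns both x_t and the noise eps_t that produced it.
    """
    eps = rng.standard_normal(x0.shape)  # eps_t ~ N(0, I)
    a = alpha_bar[t]
    x_t = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps
    return x_t, eps


# Assumed schedule for illustration: 1000 linearly spaced betas.
betas = np.linspace(1e-4, 2e-2, 1000)
alpha_bar = np.cumprod(1.0 - betas)  # alpha_bar_t = prod over s <= t of (1 - beta_s)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 32, 32))  # stand-in for a latent z
x_early, _ = q_sample(x0, 0, alpha_bar, rng)  # almost clean
x_late, eps = q_sample(x0, 999, alpha_bar, rng)  # almost pure noise
```

Because $\bar{\alpha}_t$ shrinks toward zero, early timesteps stay close to $x_0$ while the last timestep is almost pure noise. This is why sampling can start from $\mathcal{N}(\mathbf{0}, \mathbf{I})$.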
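Diffusion models are trained to learn the reverse process that inverts forward process corruptions: $p_\theta(x_{t-1}|x_t) = \mathcal{N}(\mu_\theta(x_t), \Sigma_\theta(x_t))$, where neural networks are used to predict the statistics of $p_\theta$. The reverse process model is trained with the variational lower bound of the log-likelihood of $x_0$, which reduces to $\mathcal{L}(\theta) = -\log p_\theta(x_0|x_1) + \sum_t \mathcal{D}_{KL}(q^*(x_{t-1}|x_t, x_0) \,\|\, p_\theta(x_{t-1}|x_t))$, excluding an additional term irrelevant for training. Since both $q^*$ and $p_\theta$ are Gaussian, $\mathcal{D}_{KL}$ can be evaluated with the mean and covariance of the two distributions. By reparameterizing $\mu_\theta$ as a noise prediction network $\epsilon_\theta$, the model can be trained using simple mean-squared error between the predicted noise $\epsilon_\theta(x_t)$ and the ground truth sampled Gaussian noise $\epsilon_t$: $\mathcal{L}_{simple}(\theta) = \|\epsilon_\theta(x_t) - \epsilon_t\|_2^2$. However, in order to train diffusion models with a learned reverse process covariance $\Sigma_\theta$, the full $\mathcal{D}_{KL}$ term needs to be optimized. We follow Nichol and Dhariwal's approach: train $\epsilon_\theta$ with $\mathcal{L}_{simple}$, and train $\Sigma_\theta$ with the full $\mathcal{L}$. Once $p_\theta$ is trained, new images can be sampled by initializing $x_{t_{\max}} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and sampling $x_{t-1} \sim p_\theta(x_{t-1}|x_t)$ via the reparameterization trick.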
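A minimal sketch of the $\mathcal{L}_{simple}$ objective follows. It assumes a generic callable `eps_model(x_t, t)` in place of the DiT network and the same assumed linear schedule. It covers only the $\epsilon_\theta$ term, not the full $\mathcal{L}$ used for $\Sigma_\theta$. Two toy predictors bracket the loss: one that always outputs zero, and a hypothetical oracle that knows $x_0$.

```python
import numpy as np


def l_simple(eps_model, x0, alpha_bar, rng):
    """One-sample Monte Carlo estimate of L_simple = ||eps_theta(x_t) - eps_t||^2.

    The squared error is averaged over elements (a common implementation
    choice); eps_model(x_t, t) stands in for the noise prediction network.
    """
    t = int(rng.integers(0, len(alpha_bar)))  # timestep drawn uniformly
    eps = rng.standard_normal(x0.shape)  # ground-truth noise eps_t
    x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    return float(np.mean((eps_model(x_t, t) - eps) ** 2))


alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 2e-2, 1000))  # assumed schedule
rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 32, 32))

# A predictor that always outputs zero: the loss is about E[eps^2] = 1.
zero_loss = l_simple(lambda x_t, t: np.zeros_like(x_t), x0, alpha_bar, rng)


def oracle(x_t, t):
    """Hypothetical predictor that knows x0 and inverts the forward process."""
    return (x_t - np.sqrt(alpha_bar[t]) * x0) / np.sqrt(1.0 - alpha_bar[t])


oracle_loss = l_simple(oracle, x0, alpha_bar, rng)  # ~0 up to round-off
```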
Classifier-free guidance.
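Conditional diffusion models take extra information as input, such as a class label $c$. In this case, the reverse process becomes $p_\theta(x_{t-1}|x_t, c)$, where $\epsilon_\theta$ and $\Sigma_\theta$ are conditioned on $c$. In this setting, classifier-free guidance can be used to encourage the sampling procedure to find $x$ such that $\log p(c|x)$ is high. By Bayes' rule, $\log p(c|x) = \log p(x|c) - \log p(x) + \log p(c)$, and hence $\nabla_x \log p(c|x) = \nabla_x \log p(x|c) - \nabla_x \log p(x)$. By interpreting the output of diffusion models as the score function (up to a negative scale factor), the DDPM sampling procedure can be guided toward samples $x$ with high $p(c|x)$ by using $\hat{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t, \emptyset) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \emptyset))$, where $s > 1$ indicates the scale of the guidance (note that $s = 1$ recovers standard sampling). Evaluating the diffusion model with $c = \emptyset$ is done by randomly dropping out $c$ during training and replacing it with a learned “null” embedding $\emptyset$. Classifier-free guidance is widely known to yield significantly improved samples over generic sampling techniques, and the trend holds for our DiT models.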
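The guidance rule and the training-time label dropout can be sketched as follows. `toy_eps` is a hypothetical stand-in for $\epsilon_\theta$. The null label id (1000, one past the ImageNet classes) and the 10% dropout probability are also assumptions; the text above does not specify them.

```python
import numpy as np


def guided_eps(eps_model, x_t, t, c, null_c, s):
    """Classifier-free guidance: eps(x_t, null) + s * (eps(x_t, c) - eps(x_t, null))."""
    eps_uncond = eps_model(x_t, t, null_c)
    eps_cond = eps_model(x_t, t, c)
    return eps_uncond + s * (eps_cond - eps_uncond)


def drop_labels(labels, null_c, p_drop, rng):
    """Training-time label dropout: replace each label by the null id with prob p_drop."""
    mask = rng.random(labels.shape) < p_drop
    return np.where(mask, null_c, labels)


def toy_eps(x_t, t, c):
    """Hypothetical stand-in for eps_theta so the arithmetic can be checked."""
    return x_t * (1.0 + 0.01 * c)


NULL = 1000  # hypothetical id of the learned null embedding
rng = np.random.default_rng(0)
x_t = rng.standard_normal((4, 32, 32))
standard = guided_eps(toy_eps, x_t, 10, 3, NULL, s=1.0)  # s = 1: plain conditional
guided = guided_eps(toy_eps, x_t, 10, 3, NULL, s=4.0)
train_labels = drop_labels(np.arange(8), NULL, p_drop=0.1, rng=rng)
```

Each guided step needs both a conditional and an unconditional evaluation. Guidance therefore roughly doubles per-step sampling compute, even when the two passes are batched together.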
Latent diffusion models.
Training diffusion models directly in high-resolution pixel space can be computationally prohibitive. Latent diffusion models (LDMs) tackle this issue with a two-stage approach: (1) learn an autoencoder that compresses images into smaller spatial representations with a learned encoder $E$; (2) train a diffusion model of representations $z = E(x)$ instead of a diffusion model of images $x$ ($E$ is frozen). New images can then be generated by sampling a representation $z$ from the diffusion model and subsequently decoding it to an image with the learned decoder $x = D(z)$.
As shown in Figure 2, LDMs achieve good performance while using a fraction of the Gflops of pixel space diffusion models like ADM. Since we are concerned with compute efficiency, this makes them an appealing starting point for architecture exploration. In this paper, we apply DiTs to latent space, although they could be applied to pixel space without modification as well. This makes our image generation pipeline a hybrid-based approach; we use off-the-shelf convolutional VAEs and transformer-based DDPMs.
Diffusion Transformer Design Space
We introduce Diffusion Transformers (DiTs), a new architecture for diffusion models. We aim to be as faithful to the standard transformer architecture as possible to retain its scaling properties. Since our focus is training DDPMs of images (specifically, spatial representations of images), DiT is based on the Vision Transformer (ViT) architecture, which operates on sequences of patches. DiT retains many of the best practices of ViTs. Figure 3 shows an overview of the complete DiT architecture. In this section, we describe the forward pass of DiT, as well as the components of the design space of the DiT class.
The input to DiT is a spatial representation $z$ of shape $I \times I \times C$ (for $256 \times 256 \times 3$ images, $z$ has shape $32 \times 32 \times 4$). The first layer of DiT is “patchify,” which converts the spatial input into a sequence of $T$ tokens, each of dimension $d$, by linearly embedding each $p \times p$ patch in the input. Following patchify, we apply standard ViT frequency-based positional embeddings (the sine-cosine version) to all input tokens. The number of tokens $T = (I/p)^2$ created by patchify is determined by the patch size hyperparameter $p$. As shown in Figure 4, halving $p$ will quadruple $T$, and thus at least quadruple total transformer Gflops. Although it has a significant impact on Gflops, note that changing $p$ has no meaningful impact on downstream parameter counts.
We add $p = 2, 4, 8$ to the DiT design space.
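Here is a sketch of patchify and the 2D sine-cosine positional embedding in NumPy, assuming a channel-first $32 \times 32 \times 4$ latent. Several pieces are assumptions:

- The random matrix `W_embed` is a hypothetical stand-in for the learned linear patch embedding.
- $d = 384$ is an assumed hidden size.
- The MAE-style frequency layout in `sincos_2d` is one common choice, not necessarily the exact one used here.

```python
import numpy as np


def patchify(z, p):
    """Split a channel-first latent (C, I, I) into T = (I/p)^2 flattened patches.

    Patches are ordered row-major; each has p*p*C values ordered (row, col, channel).
    """
    C, H, W = z.shape
    assert H % p == 0 and W % p == 0, "patch size must divide the spatial size"
    x = z.reshape(C, H // p, p, W // p, p)
    x = x.transpose(1, 3, 2, 4, 0)  # (H/p, W/p, p, p, C)
    return x.reshape((H // p) * (W // p), p * p * C)


def sincos_1d(dim, pos):
    """1-D sine-cosine embedding of integer positions into `dim` channels."""
    omega = 1.0 / 10000 ** (np.arange(dim // 2) / (dim / 2))
    angles = np.outer(pos, omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


def sincos_2d(d, grid):
    """2-D variant: half the channels encode the patch row, half the column."""
    rows, cols = np.meshgrid(np.arange(grid), np.arange(grid), indexing="ij")
    return np.concatenate(
        [sincos_1d(d // 2, rows.ravel()), sincos_1d(d // 2, cols.ravel())], axis=1)


rng = np.random.default_rng(0)
z = rng.standard_normal((4, 32, 32))  # latent of a 256x256 image
d = 384  # assumed hidden size
for p in (2, 4, 8):
    patches = patchify(z, p)  # (T, p*p*C)
    W_embed = 0.02 * rng.standard_normal((p * p * 4, d))  # hypothetical embedding
    tokens = patches @ W_embed + sincos_2d(d, 32 // p)  # (T, d)
    print(f"p={p}: T={tokens.shape[0]} tokens of dimension {tokens.shape[1]}")
```

`W_embed` has only $p^2 C d$ entries, which is small next to the transformer blocks. The parameter count therefore barely moves with $p$, while $T$ changes by a factor of 4 per halving.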
DiT block design.
Following patchify, the input tokens are processed by a sequence of transformer blocks. In addition to noised image inputs, diffusion models sometimes process additional conditional information such as noise timesteps $t$, class labels $c$, natural language, etc. We explore four variants of transformer blocks that process conditional inputs differently. The designs introduce small but important modifications to the standard ViT block design. The designs of all blocks are shown in Figure 3.
In-context conditioning. We simply append the vector embeddings of $t$ and $c$ as two additional tokens in the input sequence, treating them no differently from the image tokens. This is similar to cls tokens in ViTs, and it allows us to use standard ViT blocks without modification. After the final block, we remove the conditioning tokens from the sequence. This approach introduces negligible new Gflops to the model.
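In-context conditioning needs no change to the blocks themselves, as the following NumPy sketch shows. `toy_block` is a hypothetical stand-in for a standard ViT block; it just adds the mean token, so information flows between tokens. Appending rather than prepending the two tokens is an arbitrary choice.

```python
import numpy as np


def in_context_forward(img_tokens, t_emb, c_emb, blocks):
    """In-context conditioning: treat t and c embeddings as two extra tokens.

    img_tokens: (T, d); t_emb, c_emb: (d,); blocks: standard (unconditional) blocks.
    """
    x = np.concatenate([img_tokens, t_emb[None, :], c_emb[None, :]], axis=0)  # (T + 2, d)
    for block in blocks:
        x = block(x)  # conditioning tokens are processed like image tokens
    return x[:-2]  # drop the two conditioning tokens after the final block


def toy_block(x):
    """Hypothetical stand-in for a ViT block: adds the token mean (mixes information)."""
    return x + x.mean(axis=0, keepdims=True)


rng = np.random.default_rng(0)
img = rng.standard_normal((256, 8))
t_emb, c_emb = rng.standard_normal(8), rng.standard_normal(8)
out = in_context_forward(img, t_emb, c_emb, [toy_block] * 2)
```

The only added cost is two extra tokens per attention layer, which is why the Gflops overhead is negligible.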
Cross-attention block. We concatenate the embeddings of $t$ and $c$ into a length-two sequence, separate from the image token sequence. The transformer block is modified to include an additional multi-head cross-attention layer following the multi-head self-attention block, similar to the original design from Vaswani et al., and also similar to the one used by LDM for conditioning on class labels. Cross-attention adds the most Gflops to the model, roughly a 15% overhead.
Adaptive layer norm (adaLN) block. Following the widespread usage of adaptive normalization layers in GANs and diffusion models with U-Net backbones, we explore replacing standard layer norm layers in transformer blocks with adaptive layer norm (adaLN). Rather than directly learn dimension-wise scale and shift parameters $\gamma$ and $\beta$, we regress them from the sum of the embedding vectors of $t$ and $c$. Of the three block designs described so far, adaLN adds the least Gflops and is thus the most compute-efficient. It is also the only conditioning mechanism that is restricted to apply the same function to all tokens.
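A minimal adaLN sketch follows, assuming a (T, d) token matrix and a parameter-free layer norm. The SiLU before the regression layer follows the appendix description of adaLN layers. `W` and `b` are hypothetical learned weights. Setting two tokens equal illustrates that every token receives the same $\gamma$ and $\beta$.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    """Layer norm over the last axis without a learned affine transform."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def silu(x):
    return x / (1.0 + np.exp(-x))


def ada_ln(x, t_emb, c_emb, W, b):
    """Adaptive layer norm: gamma and beta are regressed from t_emb + c_emb.

    x: (T, d) tokens; W: (d, 2d); b: (2d,). The same (gamma, beta) is applied
    to every token, so the conditioning cannot differ across positions.
    """
    gamma, beta = np.split(silu(t_emb + c_emb) @ W + b, 2)
    return layer_norm(x) * gamma + beta


rng = np.random.default_rng(0)
d = 8
x = rng.standard_normal((5, d))
x[1] = x[0]  # two identical tokens
t_emb, c_emb = rng.standard_normal(d), rng.standard_normal(d)
W, b = 0.1 * rng.standard_normal((d, 2 * d)), rng.standard_normal(2 * d)
y = ada_ln(x, t_emb, c_emb, W, b)  # y[0] == y[1]: same function for all tokens
```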
adaLN-Zero block. Prior work on ResNets has found that initializing each residual block as the identity function is beneficial. For example, Goyal et al. found that zero-initializing the final batch norm scale factor $\gamma$ in each block accelerates large-scale training in the supervised learning setting. Diffusion U-Net models use a similar initialization strategy, zero-initializing the final convolutional layer in each block prior to any residual connections. We explore a modification of the adaLN DiT block which does the same. In addition to regressing $\gamma$ and $\beta$, we also regress dimension-wise scaling parameters $\alpha$ that are applied immediately prior to any residual connections within the DiT block. We initialize the MLP to output the zero-vector for all $\alpha$; this initializes the full DiT block as the identity function. As with the vanilla adaLN block, adaLN-Zero adds negligible Gflops to the model.
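The sketch below builds a simplified adaLN-Zero block in NumPy and confirms that it is the identity at initialization. It makes several simplifications and assumptions:

- single-head attention instead of multi-head;
- a `1 + gamma` scale parameterization;
- random values standing in for trained weights.

Only the zero-initialized regression layer (`W_ada`, `b_ada`) mirrors the text directly.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def silu(x):
    return x / (1.0 + np.exp(-x))


def gelu_tanh(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def self_attention(x, P):
    """Single-head self-attention (a simplification of multi-head attention)."""
    q, k, v = x @ P["Wq"], x @ P["Wk"], x @ P["Wv"]
    logits = q @ k.T / np.sqrt(q.shape[-1])
    w = np.exp(logits - logits.max(axis=-1, keepdims=True))  # stable softmax
    w /= w.sum(axis=-1, keepdims=True)
    return (w @ v) @ P["Wo"]


def mlp(x, P):
    return gelu_tanh(x @ P["W1"] + P["b1"]) @ P["W2"] + P["b2"]


def adaln_zero_block(x, cond, P):
    """x: (T, d) tokens; cond: (d,) sum of timestep and class embeddings."""
    mod = silu(cond) @ P["W_ada"] + P["b_ada"]  # (6d,)
    beta1, gamma1, alpha1, beta2, gamma2, alpha2 = np.split(mod, 6)
    # Scale is applied as (1 + gamma), an assumed parameterization in which a
    # zero regression output means "unit scale, zero shift".
    x = x + alpha1 * self_attention(layer_norm(x) * (1 + gamma1) + beta1, P)
    x = x + alpha2 * mlp(layer_norm(x) * (1 + gamma2) + beta2, P)
    return x


def init_params(d, rng, mlp_ratio=4, std=0.02):
    h = mlp_ratio * d
    P = {name: std * rng.standard_normal((d, d)) for name in ("Wq", "Wk", "Wv", "Wo")}
    P.update(W1=std * rng.standard_normal((d, h)), b1=np.zeros(h),
             W2=std * rng.standard_normal((h, d)), b2=np.zeros(d))
    # adaLN-Zero: the regression layer starts at zero, so every alpha is 0.
    P.update(W_ada=np.zeros((d, 6 * d)), b_ada=np.zeros(6 * d))
    return P


rng = np.random.default_rng(0)
d = 16
x = rng.standard_normal((10, d))
cond = rng.standard_normal(d)
P = init_params(d, rng)
y_init = adaln_zero_block(x, cond, P)  # identity at initialization

P_moved = dict(P, W_ada=0.1 * rng.standard_normal((d, 6 * d)))  # pretend training moved it
y_moved = adaln_zero_block(x, cond, P_moved)
```

Every $\alpha$ starts at zero, so each residual branch initially contributes nothing. The block therefore passes its input through unchanged, whatever the attention and MLP weights are.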
We include the in-context, cross-attention, adaptive layer norm and adaLN-Zero blocks in the DiT design space.
Model size.
We apply a sequence of $N$ DiT blocks, each operating at the hidden dimension size $d$. Following ViT, we use standard transformer configs that jointly scale $N$, $d$ and attention heads. Specifically, we use four configs: DiT-S, DiT-B, DiT-L and DiT-XL. They cover a wide range of model sizes and flop allocations, from 0.3 to 118.6 Gflops, allowing us to gauge scaling performance. Table 1 gives details of the configs.
We add B, S, L and XL configs to the DiT design space.
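The four configs can be written down as data, together with a back-of-the-envelope parameter estimate. The (depth, hidden size, heads) values below are assumed from the publicly released DiT configurations and should be checked against Table 1. The estimate counts only block weight matrices (no biases, embeddings or final layer).

```python
# Assumed (N, d, heads) per config; verify against Table 1.
CONFIGS = {
    "S": dict(depth=12, hidden=384, heads=6),
    "B": dict(depth=12, hidden=768, heads=12),
    "L": dict(depth=24, hidden=1024, heads=16),
    "XL": dict(depth=28, hidden=1152, heads=16),
}


def approx_block_params(depth, hidden, mlp_ratio=4):
    """Rough parameter count of the DiT blocks only (weight matrices, no biases).

    Per block: attention Q/K/V/output (4 d^2), MLP (2 * mlp_ratio * d^2),
    adaLN-Zero regression to 6 vectors of size d (6 d^2).
    """
    d2 = hidden * hidden
    return depth * (4 * d2 + 2 * mlp_ratio * d2 + 6 * d2)


for name, cfg in CONFIGS.items():
    head_dim = cfg["hidden"] // cfg["heads"]
    millions = approx_block_params(cfg["depth"], cfg["hidden"]) / 1e6
    print(f"DiT-{name}: {cfg['depth']} blocks, d={cfg['hidden']}, "
          f"head dim {head_dim}, ~{millions:.0f}M block params")
```

The patch size does not appear anywhere in this estimate. This matches the earlier observation that changing $p$ barely affects parameter counts.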
Transformer decoder.
After the final DiT block, we need to decode our sequence of image tokens into an output noise prediction and an output diagonal covariance prediction. Both of these outputs have shape equal to the original spatial input. We use a standard linear decoder to do this; we apply the final layer norm (adaptive if using adaLN) and linearly decode each token into a $p \times p \times 2C$ tensor, where $C$ is the number of channels in the spatial input to DiT. Finally, we rearrange the decoded tokens into their original spatial layout to get the predicted noise and covariance.
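The linear decoder and the inverse of patchify can be sketched as follows. The sketch assumes channel-first arrays and the same row-major patch order as a forward `patchify`, which is redefined here so the block is self-contained. The random `tokens` and `W_out` are placeholders. The text initializes the final linear layer to zero, and with adaLN the final norm would be adaptive rather than the plain `layer_norm` used here.

```python
import numpy as np


def patchify(z, p):
    """(C, I, I) -> (T, p*p*C), row-major patches, values ordered (row, col, channel)."""
    C, H, W = z.shape
    x = z.reshape(C, H // p, p, W // p, p).transpose(1, 3, 2, 4, 0)
    return x.reshape((H // p) * (W // p), p * p * C)


def unpatchify(tokens, p, C):
    """Inverse of patchify: (T, p*p*C) -> (C, I, I) with I = p * sqrt(T)."""
    g = int(round(np.sqrt(tokens.shape[0])))  # tokens per side
    x = tokens.reshape(g, g, p, p, C).transpose(4, 0, 2, 1, 3)  # (C, g, p, g, p)
    return x.reshape(C, g * p, g * p)


def layer_norm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def linear_decoder(tokens, W_out, b_out, p, C):
    """Decode (T, d) tokens into noise and diagonal-covariance maps, each (C, I, I)."""
    out = layer_norm(tokens) @ W_out + b_out  # (T, p*p*2C)
    spatial = unpatchify(out, p, 2 * C)  # (2C, I, I)
    return spatial[:C], spatial[C:]  # predicted noise, predicted covariance


rng = np.random.default_rng(0)
C, p, d = 4, 2, 32
z = rng.standard_normal((C, 32, 32))
tokens = rng.standard_normal((256, d))  # stand-in for the final DiT block output
W_out = 0.02 * rng.standard_normal((d, p * p * 2 * C))  # hypothetical weights
eps_pred, sigma_pred = linear_decoder(tokens, W_out, np.zeros(p * p * 2 * C), p, C)
```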
The complete DiT design space we explore is patch size, transformer block architecture and model size.
Experimental Setup
We explore the DiT design space and study the scaling properties of our model class. Our models are named according to their configs and latent patch sizes $p$; for example, DiT-XL/2 refers to the XLarge config and $p = 2$.
We train class-conditional latent DiT models at $256 \times 256$ and $512 \times 512$ image resolution on the ImageNet dataset, a highly competitive generative modeling benchmark. We initialize the final linear layer with zeros and otherwise use standard weight initialization techniques from ViT. We train all models with AdamW. We use a constant learning rate of $1 \times 10^{-4}$, no weight decay and a batch size of 256. The only data augmentation we use is horizontal flips. Unlike much prior work with ViTs, we did not find learning rate warmup or regularization necessary to train DiTs to high performance. Even without these techniques, training was highly stable across all model configs, and we did not observe any of the loss spikes commonly seen when training transformers. Following common practice in the generative modeling literature, we maintain an exponential moving average (EMA) of DiT weights over training with a decay of 0.9999. All reported results use the EMA model. We use identical training hyperparameters across all DiT model sizes and patch sizes. Our training hyperparameters are almost entirely retained from ADM. We did not tune learning rates, decay/warm-up schedules, Adam $\beta_1$/$\beta_2$ or weight decays.
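The EMA of the weights is a one-line update per parameter. Below is a minimal sketch with a dictionary of scalar parameters. This layout is hypothetical; real models use arrays or parameter trees.

```python
def ema_update(ema, params, decay=0.9999):
    """One EMA step: ema <- decay * ema + (1 - decay) * params, per parameter."""
    return {k: decay * ema[k] + (1.0 - decay) * params[k] for k in params}


# Toy example: the online weights jump to 1.0 and stay there.
ema = {"w": 0.0}
for _ in range(10_000):
    ema = ema_update(ema, {"w": 1.0})
```

With a decay of 0.9999, the average has an effective horizon of roughly $1/(1 - 0.9999) = 10{,}000$ steps. After 10,000 steps of constant weights, the EMA has covered about 63% ($1 - e^{-1}$) of the gap.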
Diffusion.
We use an off-the-shelf pre-trained variational autoencoder (VAE) model from Stable Diffusion. The VAE encoder has a downsample factor of 8: given an RGB image $x$ with shape $256 \times 256 \times 3$, $z = E(x)$ has shape $32 \times 32 \times 4$. Across all experiments in this section, our diffusion models operate in this $z$-space. After sampling a new latent from our diffusion model, we decode it to pixels using the VAE decoder $x = D(z)$. We retain diffusion hyperparameters from ADM; specifically, we use a linear variance schedule ranging from $1 \times 10^{-4}$ to $2 \times 10^{-2}$, ADM's parameterization of the covariance $\Sigma_\theta$ and their method for embedding input timesteps and labels.
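The latent shapes and the linear variance schedule can be computed directly. This sketch assumes 1000 diffusion steps, a number not stated in this section, and a channels-last shape convention for `latent_shape`.

```python
import numpy as np


def latent_shape(image_size, downsample=8, latent_channels=4):
    """Shape of z = E(x) for a square RGB image (height, width, channels)."""
    assert image_size % downsample == 0
    s = image_size // downsample
    return (s, s, latent_channels)


def linear_beta_schedule(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Linear variance schedule and the cumulative products alpha_bar_t."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    alpha_bar = np.cumprod(1.0 - betas)
    return betas, alpha_bar


betas, alpha_bar = linear_beta_schedule()
print(latent_shape(256), latent_shape(512))  # (32, 32, 4) (64, 64, 4)
```

At $512 \times 512$ the latent is $64 \times 64 \times 4$. A patch size of 2 then yields $(64/2)^2 = 1024$ tokens, matching the token count reported for the $512 \times 512$ XL/2 model.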
Evaluation metrics.
We measure scaling performance with Fréchet Inception Distance (FID), the standard metric for evaluating generative models of images.
We follow convention when comparing against prior works and report FID-50K using 250 DDPM sampling steps. FID is known to be sensitive to small implementation details; to ensure accurate comparisons, all values reported in this paper are obtained by exporting samples and using ADM's TensorFlow evaluation suite. FID numbers reported in this section do not use classifier-free guidance except where otherwise stated. We additionally report Inception Score, sFID and Precision/Recall as secondary metrics.
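For reference, FID compares Gaussian fits of Inception features for real and generated images: $\mathrm{FID} = \|\mu_1 - \mu_2\|_2^2 + \mathrm{Tr}(S_1 + S_2 - 2(S_1 S_2)^{1/2})$, where $S_1, S_2$ are the feature covariances. The NumPy sketch below computes that distance from feature statistics. It is not ADM's TensorFlow evaluation suite, which produced all reported numbers, and it leaves feature extraction out.

```python
import numpy as np


def _sqrtm_psd(a):
    """Principal square root of a symmetric positive semi-definite matrix."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def fid_from_stats(mu1, cov1, mu2, cov2):
    """Frechet distance between N(mu1, cov1) and N(mu2, cov2).

    Tr((cov1 cov2)^{1/2}) is computed as Tr((S cov2 S)^{1/2}) with S = cov1^{1/2};
    the two products have the same eigenvalues.
    """
    diff = mu1 - mu2
    s = _sqrtm_psd(cov1)
    m = s @ cov2 @ s
    eig = np.linalg.eigvalsh((m + m.T) / 2.0)  # symmetrize against round-off
    tr_covmean = np.sqrt(np.clip(eig, 0.0, None)).sum()
    return float(diff @ diff + np.trace(cov1) + np.trace(cov2) - 2.0 * tr_covmean)


def fid_from_features(feats1, feats2):
    """feats: (num_samples, feature_dim) activations, e.g. Inception pool features."""
    return fid_from_stats(feats1.mean(0), np.cov(feats1, rowvar=False),
                          feats2.mean(0), np.cov(feats2, rowvar=False))
```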
Compute.
We implement all models in JAX and train them using TPU-v3 pods. DiT-XL/2, our most compute-intensive model, trains at roughly 5.7 iterations/second on a TPU v3-256 pod with a global batch size of 256.
Experiments
We train four of our highest Gflop DiT-XL/2 models, each using a different block design: in-context (119.4 Gflops), cross-attention (137.6 Gflops), adaptive layer norm (adaLN, 118.6 Gflops) or adaLN-Zero (118.6 Gflops). We measure FID over the course of training. Figure 5 shows the results. The adaLN-Zero block yields lower FID than both cross-attention and in-context conditioning while being the most compute-efficient. At 400K training iterations, the FID achieved with the adaLN-Zero model is nearly half that of the in-context model, demonstrating that the conditioning mechanism critically affects model quality. Initialization is also important: adaLN-Zero, which initializes each DiT block as the identity function, significantly outperforms vanilla adaLN. For the rest of the paper, all models use adaLN-Zero DiT blocks.
Scaling model size and patch size.
We train 12 DiT models, sweeping over model configs (S, B, L, XL) and patch sizes (8, 4, 2). Note that DiT-L and DiT-XL are significantly closer to each other in terms of relative Gflops than other configs. Figure 2 (left) gives an overview of the Gflops of each model and their FID at 400K training iterations. In all cases, we find that increasing model size and decreasing patch size yields considerably improved diffusion models.
Figure 6 (top) demonstrates how FID changes as model size is increased and patch size is held constant. Across all four configs, significant improvements in FID are obtained over all stages of training by making the transformer deeper and wider. Similarly, Figure 6 (bottom) shows FID as patch size is decreased and model size is held constant. We again observe considerable FID improvements throughout training by simply scaling the number of tokens processed by DiT, holding parameters approximately fixed.
DiT Gflops are critical to improving performance.
The results of Figure 6 suggest that parameter counts do not uniquely determine the quality of a DiT model. As model size is held constant and patch size is decreased, the transformer’s total parameters are effectively unchanged (actually, total parameters slightly decrease), and only Gflops are increased. These results indicate that scaling model Gflops is actually the key to improved performance. To investigate this further, we plot the FID-50K at 400K training steps against model Gflops in Figure 8. The results demonstrate that different DiT configs obtain similar FID values when their total Gflops are similar (e.g., DiT-S/2 and DiT-B/4). We find a strong negative correlation between model Gflops and FID-50K, suggesting that additional model compute is the critical ingredient for improved DiT models. In Figure 12 (appendix), we find that this trend holds for other metrics such as Inception Score.
Larger DiT models are more compute-efficient. In Figure 9, we plot FID as a function of total training compute for all DiT models. We estimate training compute as model Gflops × batch size × training steps × 3, where the factor of 3 roughly approximates the backward pass as being twice as compute-heavy as the forward pass. We find that small DiT models, even when trained longer, eventually become compute-inefficient relative to larger DiT models trained for fewer steps. Similarly, we find that models that are identical except for patch size have different performance profiles even when controlling for training Gflops. For example, XL/4 is outperformed by XL/2 once total training compute exceeds a certain budget.
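The training-compute estimate is simple arithmetic. The sketch below uses DiT-XL/2's 118.6 Gflops, batch size 256 and 400K steps. The quarter-size model in the second example is hypothetical.

```python
def training_gflops(model_gflops, batch_size, steps, fwd_bwd_factor=3):
    """Total training compute in Gflops: forward cost x 3 (backward ~ 2x forward)."""
    return model_gflops * batch_size * steps * fwd_bwd_factor


def steps_for_budget(budget_gflops, model_gflops, batch_size, fwd_bwd_factor=3):
    """Training steps affordable under a fixed total compute budget."""
    return budget_gflops / (model_gflops * batch_size * fwd_bwd_factor)


# DiT-XL/2 (118.6 Gflops per forward pass) after 400K steps at batch size 256.
xl2 = training_gflops(118.6, 256, 400_000)
print(f"{xl2:.3e} Gflops")  # 3.643e+10 Gflops

# A hypothetical model with a quarter of XL/2's Gflops could train 4x as many steps.
quarter_steps = steps_for_budget(xl2, 118.6 / 4, 256)
```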
Visualizing scaling.
We visualize the effect of scaling on sample quality in Figure 7. At 400K training steps, we sample an image from each of our 12 DiT models using identical starting noise $x_{t_{\max}}$, sampling noise and class labels. This lets us visually interpret how scaling affects DiT sample quality. Indeed, scaling both model size and the number of tokens yields notable improvements in visual quality.
State-of-the-Art Diffusion Models
Following our scaling analysis, we continue training our highest Gflop model, DiT-XL/2, for 7M steps. We show samples from the model in Figure 1, and we compare against state-of-the-art class-conditional generative models. We report results in Table 2. When using classifier-free guidance, DiT-XL/2 outperforms all prior diffusion models, decreasing the previous best FID-50K of 3.60 achieved by LDM to 2.27. Figure 2 (right) shows that DiT-XL/2 (118.6 Gflops) is compute-efficient relative to latent space U-Net models like LDM-4 (103.6 Gflops) and substantially more efficient than pixel space U-Net models such as ADM (1120 Gflops) or ADM-U (742 Gflops). Our method achieves a lower FID than all prior generative models, including the previous state-of-the-art StyleGAN-XL. Finally, we also observe that DiT-XL/2 achieves higher recall values at all tested classifier-free guidance scales compared to LDM-4 and LDM-8. When trained for only 2.35M steps (similar to ADM), XL/2 still outperforms all prior diffusion models with an FID of 2.55.
$512 \times 512$ ImageNet.
We train a new DiT-XL/2 model on ImageNet at $512 \times 512$ resolution for 3M iterations with identical hyperparameters as the $256 \times 256$ model. With a patch size of 2, this XL/2 model processes a total of 1024 tokens after patchifying the $64 \times 64 \times 4$ input latent (524.6 Gflops). Table 3 shows comparisons against state-of-the-art methods. XL/2 again outperforms all prior diffusion models at this resolution, improving the previous best FID of 3.85 achieved by ADM to 3.04. Even with the increased number of tokens, XL/2 remains compute-efficient. For example, ADM uses 1983 Gflops and ADM-U uses 2813 Gflops; XL/2 uses 524.6 Gflops. We show samples from the high-resolution XL/2 model in Figure 1 and the appendix.
Scaling Model vs. Sampling Compute
Diffusion models are unique in that they can use additional compute after training by increasing the number of sampling steps when generating an image. Given the impact of model Gflops on sample quality, in this section we study whether DiTs with less model compute can outperform larger ones by using more sampling compute. We compute FID for all 12 of our DiT models after 400K training steps, using varying numbers of sampling steps per image. The main results are in Figure 10. Consider DiT-L/2 using 1000 sampling steps versus DiT-XL/2 using 128 steps. In this case, XL/2 uses less compute to sample each image (roughly 118.6 Gflops × 128 ≈ 15.2 Tflops) than L/2 does with 1000 steps. Nonetheless, XL/2 has the better FID-10K (23.7 vs. 25.9). In general, scaling up sampling compute cannot compensate for a lack of model compute.
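Per-image sampling compute is model Gflops times the number of sampling steps; guidance, which would roughly double it, is ignored here. In the small sketch below, only XL/2's 118.6 Gflops comes from the text. The 50-Gflop model is hypothetical.

```python
def sampling_tflops(model_gflops, num_steps):
    """Per-image sampling cost in Tflops (no classifier-free guidance)."""
    return model_gflops * num_steps / 1000.0


def max_steps_within(budget_tflops, model_gflops):
    """Largest number of sampling steps a model can afford under a per-image budget."""
    return int(budget_tflops * 1000.0 // model_gflops)


xl2_budget = sampling_tflops(118.6, 128)  # DiT-XL/2 with 128 steps, ~15.2 Tflops
steps_50 = max_steps_within(xl2_budget, 50.0)  # hypothetical 50-Gflop model
```

Under XL/2's 128-step budget, a hypothetical 50-Gflop model could afford only about 300 sampling steps.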
Conclusion
We introduce Diffusion Transformers (DiTs), a simple transformer-based backbone for diffusion models that outperforms prior U-Net models and inherits the excellent scaling properties of the transformer model class. Given the promising scaling results in this paper, future work should continue to scale DiTs to larger models and token counts. DiT could also be explored as a drop-in backbone for text-to-image models like DALLE 2 and Stable Diffusion.
We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for helpful discussions. William Peebles is supported by the NSF GRFP.
Appendix A Additional Implementation Details
We include detailed information about all of our DiT models in Table 4, including both $256 \times 256$ and $512 \times 512$ models. In Figure 13, we report DiT training loss curves. Finally, we also include Gflop counts for DDPM U-Net models from ADM and LDM in Table 6.
To embed input timesteps, we use a 256-dimensional frequency embedding followed by a two-layer MLP with dimensionality equal to the transformer's hidden size and SiLU activations. Each adaLN layer feeds the sum of the timestep and class embeddings into a SiLU nonlinearity and a linear layer with output neurons equal to either 4× (adaLN) or 6× (adaLN-Zero) the transformer's hidden size. We use GELU nonlinearities (approximated with tanh) in the core transformer.
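The timestep pathway can be sketched in NumPy. The `max_period` of 10,000 and the cosine-then-sine ordering follow a common ADM-style convention and are assumptions here. So are the random MLP weights and the hidden size $d = 384$.

```python
import numpy as np


def timestep_embedding(t, dim=256, max_period=10000.0):
    """Sinusoidal frequency embedding of timesteps, shape (len(t), dim).

    Frequencies are geometrically spaced from 1 down to roughly 1/max_period.
    """
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)


def silu(x):
    return x / (1.0 + np.exp(-x))


def timestep_mlp(emb, W1, b1, W2, b2):
    """Two-layer MLP with SiLU mapping the 256-d embedding to the hidden size d."""
    return silu(emb @ W1 + b1) @ W2 + b2


rng = np.random.default_rng(0)
d = 384  # assumed hidden size
emb = timestep_embedding(np.array([0, 250, 999]))
W1, W2 = 0.02 * rng.standard_normal((256, d)), 0.02 * rng.standard_normal((d, d))
t_emb = timestep_mlp(emb, W1, np.zeros(d), W2, np.zeros(d))  # (3, d)
# An adaLN-Zero layer would then regress 6 * d values from silu(t_emb + c_emb).
```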
Classifier-free guidance on a subset of channels.
In our experiments using classifier-free guidance, we applied guidance only to the first three channels of the latents instead of all four channels. Upon investigating, we found that three-channel guidance and four-channel guidance give similar results (in terms of FID) when simply adjusting the scale factor. Specifically, three-channel guidance with a scale of appears reasonably well-approximated by four-channel guidance with a scale of (e.g., three-channel guidance with a scale of gives an FID-50K of 2.27, and four-channel guidance with a scale of gives an FID-50K of 2.20). It is somewhat interesting that applying guidance to a subset of elements can still yield good performance, and we leave it to future work to explore this phenomenon further.
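Guidance on a channel subset is a small change to the standard guidance rule. In this NumPy sketch we assume the unguided fourth channel simply keeps the conditional prediction, which is equivalent to $s = 1$ on that channel. The text does not spell out what that channel receives.

```python
import numpy as np


def guided_eps_subset(eps_cond, eps_uncond, s, guided_channels=3):
    """Classifier-free guidance applied only to the first `guided_channels` channels.

    Arrays are channel-first, e.g. (4, 32, 32) latents. The remaining channels keep
    the conditional prediction, i.e. they are sampled as if s = 1.
    """
    out = eps_cond.copy()
    g = slice(0, guided_channels)
    out[g] = eps_uncond[g] + s * (eps_cond[g] - eps_uncond[g])
    return out


rng = np.random.default_rng(0)
eps_c, eps_u = rng.standard_normal((2, 4, 32, 32))
three_ch = guided_eps_subset(eps_c, eps_u, s=1.5)
four_ch = guided_eps_subset(eps_c, eps_u, s=1.5, guided_channels=4)
```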
Appendix B Model Samples
We show samples from our two DiT-XL/2 models at $512 \times 512$ and $256 \times 256$ resolution trained for 3M and 7M steps, respectively. Figures 1 and 11 show selected samples from both models. Figures 14 through 33 show uncurated samples from the two models across a range of classifier-free guidance scales and input class labels (generated with 250 DDPM sampling steps and the ft-EMA VAE decoder). As with prior work using guidance, we observe that larger scales increase visual fidelity and decrease sample diversity.
Appendix C Additional Scaling Results
In Figure 12, we show the effects of DiT scale on a suite of evaluation metrics: FID, sFID, Inception Score, Precision and Recall. We find that our FID-driven analysis in the main paper generalizes to the other metrics: across every metric, scaled-up DiT models are more compute-efficient and model Gflops are highly correlated with performance. In particular, Inception Score and Precision benefit heavily from increased model scale.
Impact of scaling on training loss.
We also examine the impact of scale on training loss in Figure 13. Increasing DiT model Gflops (via transformer size or number of input tokens) causes the training loss to decrease more rapidly and saturate at a lower value. This phenomenon is consistent with trends observed with language models, where scaled-up transformers demonstrate both improved loss curves and improved performance on downstream evaluation suites.
Appendix D VAE Decoder Ablations
We used off-the-shelf, pre-trained VAEs across our experiments. The VAE models (ft-MSE and ft-EMA) are fine-tuned versions of the original LDM “f8” model (only the decoder weights are fine-tuned). We monitored metrics for our scaling analysis in Section 5 using the ft-MSE decoder, and we used the ft-EMA decoder for our final metrics reported in Tables 2 and 3. In this section, we ablate three different choices of the VAE decoder: the original one used by LDM and the two fine-tuned decoders used by Stable Diffusion. Because the encoders are identical across models, the decoders can be swapped in without retraining the diffusion model. Table 5 shows results; XL/2 continues to outperform all prior diffusion models when using the LDM decoder.