All are Worth Words: A ViT Backbone for Diffusion Models
Fan Bao, Shen Nie, Kaiwen Xue, Yue Cao, Chongxuan Li, Hang Su, Jun Zhu
Introduction
Diffusion models are powerful deep generative models that have recently emerged for high-quality image generation. They have developed rapidly and found applications in text-to-image generation, image-to-image generation, video generation, speech synthesis, and 3D synthesis.
Along with the development of algorithms, the evolution of backbones plays a central role in diffusion models. A representative example is the U-Net based on a convolutional neural network (CNN) employed in prior work. The CNN-based U-Net is characterized by a group of down-sampling blocks, a group of up-sampling blocks, and long skip connections between the two groups, and it dominates diffusion models for image generation tasks. On the other hand, vision transformers (ViT) have shown promise in various vision tasks, where ViT is comparable or even superior to CNN-based approaches. Therefore, a natural question arises: is the reliance on the CNN-based U-Net necessary in diffusion models?
In this paper, we design a simple and general ViT-based architecture called U-ViT (Figure 1). Following the design methodology of transformers, U-ViT treats all inputs, including the time, condition and noisy image patches, as tokens. Crucially, inspired by U-Net, U-ViT employs long skip connections between shallow and deep layers. Intuitively, low-level features are important to the pixel-level prediction objective in diffusion models, and such connections can ease the training of the corresponding prediction network. Besides, U-ViT optionally adds an extra 3×3 convolutional block before the output for better visual quality. Figure 2 presents a systematic ablation study of all these elements.
We evaluate U-ViT in three popular tasks: unconditional image generation, class-conditional image generation and text-to-image generation. In all settings, U-ViT is comparable if not superior to a CNN-based U-Net of a similar size. In particular, latent diffusion models with U-ViT achieve record-breaking FID scores of 2.29 in class-conditional image generation on ImageNet 256×256, and 5.48 in text-to-image generation on MS-COCO, among methods without access to large external datasets during the training of generative models.
Our results suggest that the long skip connection is crucial while the down/up-sampling operators in CNN-based U-Net are not always necessary for image diffusion models. We believe that U-ViT can provide insights for future research on diffusion model backbones and benefit generative modeling on large scale cross-modality datasets.
Background
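Diffusion models gradually inject noise into data, and then reverse this process to generate data from noise. The noise-injection process, also called the forward process, is formalized as a Markov chain:
$$q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1}).$$
Here $x_0$ is the data, $q(x_t \mid x_{t-1}) = \mathcal{N}(x_t \mid \sqrt{\alpha_t}\, x_{t-1}, \beta_t I)$, and $\alpha_t$ and $\beta_t$ represent the noise schedule such that $\alpha_t + \beta_t = 1$. To reverse this process, a Gaussian model $p(x_{t-1} \mid x_t) = \mathcal{N}(x_{t-1} \mid \mu_t(x_t), \sigma_t^2 I)$ is adopted to approximate the ground truth reverse transition $q(x_{t-1} \mid x_t)$, and the optimal mean is
$$\mu_t^*(x_t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \mathbb{E}[\epsilon \mid x_t]\right),$$
where $\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i$ and $\epsilon$ is the standard Gaussian noise injected into $x_t$. Learning the reverse process is therefore equivalent to a noise prediction task: a noise prediction network $\epsilon_\theta$ is trained to estimate $\mathbb{E}[\epsilon \mid x_t]$. For conditional diffusion models, the condition information is also fed into the network, giving the objective
$$\min_\theta \; \mathbb{E}_{t, x_0, c, \epsilon} \left\| \epsilon - \epsilon_\theta(x_t, t, c) \right\|_2^2, \tag{1}$$
where $t$ is sampled uniformly from $\{1, \dots, T\}$ and $c$ is the condition or its continuous embedding. In prior work on image modeling, the success of diffusion models relies heavily on the CNN-based U-Net, a convolutional backbone characterized by a group of down-sampling blocks, a group of up-sampling blocks and long skip connections between the two groups; $c$ is fed into the U-Net through mechanisms such as adaptive group normalization and cross attention.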
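The forward process and the noise prediction objective in Eq. (1) can be illustrated with a small pure-Python sketch. It assumes data stored as flat Python lists, a user-supplied callable `eps_model` standing in for $\epsilon_\theta(x_t, t, c)$, and the closed form $q(x_t \mid x_0) = \mathcal{N}(\sqrt{\bar\alpha_t}\, x_0, (1 - \bar\alpha_t) I)$ with $\bar\alpha_t = \prod_{i \le t} \alpha_i$, which follows from composing the Gaussian transitions. All function names are hypothetical; this is an illustration, not the authors' training code.

```python
import math
import random

def alpha_bars(betas):
    """alpha_bar_t = prod_{i<=t} alpha_i, using alpha_i = 1 - beta_i."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out

def q_sample(x0, alpha_bar_t, eps):
    """Sample x_t ~ q(x_t | x_0) by reparameterization: sqrt(ab) * x0 + sqrt(1 - ab) * eps."""
    a, s = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [a * xi + s * ei for xi, ei in zip(x0, eps)]

def noise_prediction_loss(eps_model, x0, c, betas, rng):
    """Single-sample Monte Carlo estimate of the objective in Eq. (1)."""
    t = rng.randint(1, len(betas))                  # t ~ Uniform{1, ..., T}
    eps = [rng.gauss(0.0, 1.0) for _ in x0]         # standard Gaussian noise
    x_t = q_sample(x0, alpha_bars(betas)[t - 1], eps)
    pred = eps_model(x_t, t, c)                     # stand-in for eps_theta(x_t, t, c)
    return sum((e - p) ** 2 for e, p in zip(eps, pred))  # squared L2 error

# Usage: DDPM-style linear schedule and a trivial model that always predicts zero noise.
betas = [1e-4 + (0.02 - 1e-4) * i / 999 for i in range(1000)]
loss = noise_prediction_loss(lambda x, t, c: [0.0] * len(x), [0.5] * 8, None, betas, random.Random(0))
```

Because $\bar\alpha_t$ decreases monotonically, larger $t$ means $x_t$ carries less signal and more noise, which is what the network must learn to undo.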
Vision Transformer (ViT) is a pure transformer architecture that treats an image as a sequence of tokens (words). ViT rearranges an image into a sequence of flattened patches. Then, ViT adds learnable 1D position embeddings to linear embeddings of these patches before feeding them into a transformer encoder. ViT has shown promise in various vision tasks, but it is not yet clear whether it is suitable for diffusion-based image modeling.
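The patch rearrangement can be sketched in pure Python. The sketch assumes an image stored as nested lists of shape H × W × C, with H and W divisible by the patch size p, and orders patches row by row; `patchify` is a hypothetical helper name.

```python
def patchify(image, p):
    """Split an H x W x C image (nested lists) into flattened p x p patches, row-major order."""
    H, W = len(image), len(image[0])
    if H % p or W % p:
        raise ValueError("image height and width must be divisible by the patch size")
    patches = []
    for i in range(0, H, p):            # top row of the current patch
        for j in range(0, W, p):        # left column of the current patch
            patch = []
            for di in range(p):
                for dj in range(p):
                    patch.extend(image[i + di][j + dj])  # append the C channel values
            patches.append(patch)
    return patches

# Usage: a 4 x 4 image with 3 channels and p = 2 gives 4 patches of length 2 * 2 * 3 = 12.
image = [[[r, c, 0] for c in range(4)] for r in range(4)]
patches = patchify(image, 2)
```

In general an H × W × C image yields (H/p)·(W/p) patches of length p²·C; for example, a 256 × 256 × 3 image with p = 16 gives 256 patches of length 768.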
Method
U-ViT is a simple and general backbone for diffusion models in image generation (Figure 1). In particular, U-ViT parameterizes the noise prediction network $\epsilon_\theta(x_t, t, c)$ in Eq. (1). (U-ViT can also parameterize other types of prediction.) It takes the time $t$, the condition $c$ and the noisy image $x_t$ as inputs and predicts the noise injected into $x_t$. Following the design methodology of ViT, the image is split into patches, and U-ViT treats all inputs, including the time, condition and image patches, as tokens (words).
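A minimal sketch of treating every input as a token follows. It assumes the patches have already been embedded to width D, uses a sinusoidal embedding of the timestep directly as the time token, and looks up the class label in a hypothetical table `class_table`; position embeddings are omitted. These choices are assumptions for illustration, not taken from the U-ViT code.

```python
import math

def timestep_embedding(t, dim, max_period=10000.0):
    """Sinusoidal embedding of a scalar timestep (dim assumed even): [cos(t*f_k)..., sin(t*f_k)...]."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    return [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]

def build_token_sequence(t, class_id, patch_tokens, class_table):
    """Return [time token, condition token, patch tokens...], all of width D."""
    dim = len(patch_tokens[0])
    time_token = timestep_embedding(t, dim)
    cond_token = class_table[class_id]   # hypothetical learnable class-embedding table
    return [time_token, cond_token] + list(patch_tokens)

# Usage: width D = 8, four embedded patches, a toy table for 10 classes.
D = 8
class_table = [[float(k)] * D for k in range(10)]
tokens = build_token_sequence(500, 3, [[0.1] * D for _ in range(4)], class_table)
```

Because time and condition are ordinary tokens, the transformer's self-attention lets every patch interact with them in every layer, without a separate conditioning mechanism.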
Inspired by the success of the CNN-based U-Net in diffusion models, U-ViT also employs similar long skip connections between shallow and deep layers. Intuitively, the objective in Eq. (1) is a pixel-level prediction task and is sensitive to low-level features. The long skip connections provide shortcuts for the low-level features and therefore ease the training of the noise prediction network.
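The long-skip pattern can be sketched independently of the transformer blocks themselves. In this sketch the blocks are arbitrary callables (hypothetical stand-ins for transformer blocks), and each skip branch is assumed to be combined by concatenating features per token and applying a linear projection from 2D back to D, i.e., the "concatenation" setting discussed in Appendix D. It illustrates the wiring only.

```python
def linear(x, W):
    """Vector-matrix product: x has length n, W is n rows of length m; returns length m."""
    return [sum(xi * W[i][j] for i, xi in enumerate(x)) for j in range(len(W[0]))]

def uvit_forward(tokens, in_blocks, mid_block, out_blocks, skip_projs):
    """Long-skip pattern: shallow outputs are stacked, then popped in reverse order and
    concatenated (per token) with the hidden state before a 2D -> D linear projection."""
    if not (len(in_blocks) == len(out_blocks) == len(skip_projs)):
        raise ValueError("need one out-block and one projection per in-block")
    h, skips = tokens, []
    for blk in in_blocks:
        h = blk(h)
        skips.append(h)
    h = mid_block(h)
    for blk, W in zip(out_blocks, skip_projs):
        skip = skips.pop()                                  # mirror-image pairing
        h = [linear(a + b, W) for a, b in zip(h, skip)]     # concat features, project
        h = blk(h)
    return h

# Usage: identity stand-ins for transformer blocks, D = 2, projection that adds the halves.
identity = lambda h: h
W_add = [[1, 0], [0, 1], [1, 0], [0, 1]]
out = uvit_forward([[1.0, 2.0]], [identity] * 2, identity, [identity] * 2, [W_add] * 2)
```

The stack makes the pairing mirror-symmetric: the first deep block receives the output of the last shallow block, and the final deep block receives the shallowest features.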
Additionally, U-ViT optionally adds a 3×3 convolutional block before the output. This is intended to prevent potential artifacts in images produced by transformers. According to our experiments, the block improves the visual quality of the samples generated by U-ViT.
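The extra block can be pictured as a convolution applied after the output tokens are rearranged back into an image grid. The single-channel, zero-padded sketch below is a simplification under stated assumptions (a real block acts on all channels with learned weights). One plausible reading of why it helps is that each output pixel then mixes values from neighboring patches, smoothing discontinuities at patch borders.

```python
def conv3x3(img, kernel):
    """Single-channel 3x3 convolution (cross-correlation) with zero padding; keeps H x W."""
    H, W = len(img), len(img[0])
    out = [[0.0] * W for _ in range(H)]
    for r in range(H):
        for c in range(W):
            acc = 0.0
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    rr, cc = r + dr, c + dc
                    if 0 <= rr < H and 0 <= cc < W:   # zero padding outside the grid
                        acc += kernel[dr + 1][dc + 1] * img[rr][cc]
            out[r][c] = acc
    return out

# Usage: an identity kernel leaves the image unchanged; a box kernel mixes neighbors.
img = [[1.0, 2.0], [3.0, 4.0]]
same = conv3x3(img, [[0, 0, 0], [0, 1, 0], [0, 0, 0]])
```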
In Section 3.1, we present the implementation details of U-ViT. In Section 3.2, we present the scaling properties of U-ViT by studying the effect of depth, width and patch size.
3.1 Implementation Details
Although U-ViT is conceptually simple, we carefully design its implementation. To this end, we perform a systematic empirical study of key elements in U-ViT. In particular, we ablate on CIFAR10, evaluate the FID score every 50K training iterations on 10K generated samples (instead of 50K samples, for efficiency), and determine the default implementation details.
Variants of the patch embedding. We consider two variants of the patch embedding. (1) The original patch embedding adopts a linear projection that maps a patch to a token embedding, as illustrated in Figure 1. (2) Alternatively, a stack of 3×3 convolutional blocks followed by a 1×1 convolutional block maps an image to token embeddings. We compare them in Figure 2 (d); the original patch embedding performs better.
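Variant (1) amounts to one shared linear map applied to every flattened patch. A pure-Python sketch, assuming flattened patches of length p·p·C, a weight matrix `W` of shape (p·p·C) × D and a bias `b` of length D (hypothetical names):

```python
def embed_patches(patches, W, b):
    """Shared linear patch embedding: token_j = sum_i patch_i * W[i][j] + b[j] for j < D."""
    if len(W) != len(patches[0]) or len(W[0]) != len(b):
        raise ValueError("W must be (patch length) x D and b must have length D")
    return [[sum(x * W[i][j] for i, x in enumerate(patch)) + b[j] for j in range(len(b))]
            for patch in patches]

# Usage: two flattened patches of length 2 mapped to D = 3.
tokens = embed_patches([[1, 2], [3, 4]], [[1, 0, 1], [0, 1, 1]], [0.0, 0.0, 0.5])
```

The layer has p²·C·D + D parameters, independent of the image size, since the same weights are reused for every patch.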
Variants of the position embedding. We consider two variants of the position embedding. (1) The first is the 1-dimensional learnable position embedding proposed in the original ViT, which is the default setting in this paper. (2) The second is the 2-dimensional sinusoidal position embedding, obtained by concatenating the sinusoidal embeddings of $i$ and $j$ for a patch at position $(i, j)$. As shown in Figure 2 (e), the 1-dimensional learnable position embedding performs better. We also tried using no position embedding at all and found that the model fails to generate meaningful images, which implies that position information is critical in image generation.
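Variant (2) can be sketched as follows. The sketch assumes the standard Transformer frequency schedule (base 10000) and splits the D channels evenly between the two coordinates, with D divisible by 4; these details are assumptions, since the text does not specify them.

```python
import math

def sinusoid_1d(pos, dim, max_period=10000.0):
    """1D sinusoidal embedding of an integer position (dim assumed even)."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    return [math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs]

def sinusoid_2d(i, j, dim):
    """2D embedding of patch (i, j): concatenate 1D embeddings of i and j, dim/2 channels each."""
    if dim % 4:
        raise ValueError("dim must be divisible by 4")
    return sinusoid_1d(i, dim // 2) + sinusoid_1d(j, dim // 2)

# Usage: patches in the same row share the first half of their embedding.
e = sinusoid_2d(3, 5, 8)
```

Unlike the learnable 1D variant, this embedding has no parameters, so its values are fixed by the formula rather than fitted to the data.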
3.2 Effect of Depth, Width and Patch Size
We present scaling properties of U-ViT by studying the effect of the depth (i.e., the number of layers), width (i.e., the hidden size) and patch size on CIFAR10. As shown in Figure 3, the performance improves as the depth increases from 9 to 13. Nevertheless, U-ViT does not gain from a larger depth such as 17 within 50K training iterations. Similarly, increasing the width from 256 to 512 improves the performance, while a further increase to 768 brings no gain; decreasing the patch size from 8 to 2 improves the performance, while a further decrease to 1 brings no gain. Note that a small patch size such as 2 is required for good performance. We hypothesize that this is because the noise prediction task in diffusion models is low-level and requires small patches, unlike high-level tasks (e.g., classification). Since using a small patch size is costly for high-resolution images, we first convert such images to low-dimensional latent representations and model these latent representations using U-ViT.
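The cost argument can be made concrete by counting tokens. Using the latent resolutions adopted later in the experiments (256×256 images mapped to 32×32 latents), a patch size of 2 yields 16,384 tokens in pixel space but only 256 tokens in latent space. Since the self-attention matrix grows quadratically with sequence length, it shrinks by a factor of (16384/256)² = 4096. A minimal helper, assuming square inputs:

```python
def num_tokens(resolution, patch_size):
    """Number of patch tokens for a square input with the given side length."""
    if resolution % patch_size:
        raise ValueError("resolution must be divisible by the patch size")
    return (resolution // patch_size) ** 2

def attention_cost_ratio(n_a, n_b):
    """Ratio of self-attention matrix sizes, which grow as (sequence length) ** 2."""
    return (n_a / n_b) ** 2

pixel_tokens = num_tokens(256, 2)    # 256x256 pixels, patch size 2
latent_tokens = num_tokens(32, 2)    # 32x32 latents, patch size 2
ratio = attention_cost_ratio(pixel_tokens, latent_tokens)
```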
Related Work
Transformers in diffusion models. A related work is GenViT. GenViT employs a smaller ViT that uses neither long skip connections nor the 3×3 convolutional block, and incorporates time before normalization layers for image diffusion models. Empirically, our U-ViT performs much better than GenViT (see Table 1) thanks to carefully designed implementation details. Another related work is VQ-Diffusion and its variants. VQ-Diffusion first obtains a sequence of discrete image tokens via a VQ-GAN, and then models these tokens using a discrete diffusion model with a transformer as its backbone. Time and condition are fed into the transformer through cross attention or adaptive layer normalization. In contrast, our U-ViT simply treats all inputs as tokens and employs long skip connections between shallow and deep layers, which achieves a better FID (see Table 1 and Table 4). Beyond images, transformers in diffusion models are also employed to encode texts, decode texts and generate CLIP embeddings.
U-Net in diffusion models. CNN-based U-Net was initially introduced to model the gradient of the log-likelihood function for continuous image data. Subsequently, improvements to the CNN-based U-Net for (continuous) image diffusion models were made, including group normalization, multi-head attention, improved residual blocks and cross attention. In contrast, our U-ViT is a ViT-based backbone with a conceptually simple design, and it performs comparably to, if not better than, a CNN-based U-Net of a similar size (see Table 1 and Table 4).
Improvements of diffusion models. In addition to the backbone, there are also improvements on other aspects, such as fast sampling , improved training methodology and controllable generation .
Experiments
We evaluate the proposed U-ViT in unconditional and class-conditional image generation (Section 5.2), as well as text-to-image generation (Section 5.3). Before presenting these results, we list main experimental setup below, and more details such as the sampling hyperparameters are provided in Appendix A.
5.1 Experimental Setup
Datasets. For unconditional learning, we consider CIFAR10, which contains 50K training images, and CelebA 64×64, which contains 162,770 training images of human faces. For class-conditional learning, we consider ImageNet at 64×64, 256×256 and 512×512 resolutions, which contains 1,281,167 training images from 1K different classes. For text-to-image learning, we consider MS-COCO at 256×256 resolution, which contains 82,783 training images and 40,504 validation images, each annotated with 5 captions.
High-resolution image generation. We follow latent diffusion models (LDM) for images at 256×256 and 512×512 resolutions. We first convert them to latent representations at 32×32 and 64×64 resolutions, respectively, using a pretrained image autoencoder provided by Stable Diffusion (https://github.com/CompVis/stable-diffusion). Then we model these latent representations using the proposed U-ViT.
Text-to-image learning. On MS-COCO, we convert discrete texts to a sequence of embeddings using a CLIP text encoder following Stable Diffusion. Then these embeddings are fed into U-ViT as a sequence of tokens.
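How text enters the sequence can be sketched under stated assumptions: the CLIP text encoder is assumed to emit a fixed number of embeddings per caption (77 for the encoder used by Stable Diffusion v1), and a hypothetical projection `proj` maps each of them to the transformer width D. With 32×32 latents and patch size 2, the sequence would then hold 1 + 77 + 256 = 334 tokens, all attending to each other in every layer.

```python
def build_t2i_tokens(time_token, text_embeddings, patch_tokens, proj):
    """One sequence for all inputs: [time token] + projected text tokens + latent patch tokens.
    `proj` is a hypothetical map from the text-encoder width to the transformer width D."""
    text_tokens = [proj(e) for e in text_embeddings]
    seq = [time_token] + text_tokens + list(patch_tokens)
    if any(len(tok) != len(time_token) for tok in seq):
        raise ValueError("all tokens must share the transformer width D")
    return seq

# Usage with toy sizes: D = 4, 77 text embeddings of width 6, 16 x 16 = 256 latent patches.
D = 4
seq = build_t2i_tokens([0.0] * D, [[1.0] * 6] * 77, [[0.5] * D] * 256, lambda e: e[:D])
```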
U-ViT configurations. We identify several configurations of U-ViT in Table 2. In the rest of the paper, we use brief notation to represent the U-ViT configuration and the input patch size (for instance, U-ViT-H/2 means the U-ViT-Huge configuration with an input patch size of 2×2).
Training. We use the AdamW optimizer with a weight decay of 0.03 for all datasets. We use a learning rate of 2e-4 for most datasets, except ImageNet 64×64, where we use 3e-4. We train for 500K iterations on CIFAR10 and CelebA 64×64 with a batch size of 128. We train for 300K iterations on ImageNet 64×64 and ImageNet 256×256, and for 500K iterations on ImageNet 512×512, with a batch size of 1024. We train for 1M iterations on MS-COCO with a batch size of 256. On ImageNet 256×256, ImageNet 512×512 and MS-COCO, we adopt classifier-free guidance. We provide more details, such as the training time and how we choose hyperparameters, in Appendix A.
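For reference, the training schedule can be transcribed into a configuration dictionary. The keys and dataset labels are hypothetical; only the numbers come from the text, and optimizer settings are left out. The helper counts training examples processed: for instance, 300K iterations at batch size 1024 on ImageNet 256×256 is 307.2M examples, roughly 240 passes over the 1,281,167 training images.

```python
# Numbers transcribed from the training paragraph; dictionary keys are hypothetical labels.
TRAIN_SCHEDULE = {
    "cifar10":     {"iterations": 500_000,   "batch_size": 128,  "lr": 2e-4, "cfg": False},
    "celeba64":    {"iterations": 500_000,   "batch_size": 128,  "lr": 2e-4, "cfg": False},
    "imagenet64":  {"iterations": 300_000,   "batch_size": 1024, "lr": 3e-4, "cfg": False},
    "imagenet256": {"iterations": 300_000,   "batch_size": 1024, "lr": 2e-4, "cfg": True},
    "imagenet512": {"iterations": 500_000,   "batch_size": 1024, "lr": 2e-4, "cfg": True},
    "mscoco256":   {"iterations": 1_000_000, "batch_size": 256,  "lr": 2e-4, "cfg": True},
}

def examples_seen(name):
    """Total training examples processed: iterations x batch size."""
    cfg = TRAIN_SCHEDULE[name]
    return cfg["iterations"] * cfg["batch_size"]

def approx_epochs(name, train_size):
    """Approximate number of passes over a training set of the given size."""
    return examples_seen(name) / train_size
```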
5.2 Unconditional and Class-Conditional Image Generation
We compare U-ViT with prior diffusion models based on U-Net. We also compare with GenViT, a smaller ViT that does not employ long skip connections and incorporates time before normalization layers. Consistent with previous literature, we report the FID score on 50K generated samples to measure image quality.
As shown in Table 1, U-ViT is comparable to U-Net on unconditional CIFAR10 and CelebA 6464, and meanwhile performs much better than GenViT.
On class-conditional ImageNet 6464, we initially try the U-ViT-M configuration with 131M parameters. As shown in Table 1, it gets a FID of 5.85, which is better than 6.92 of IDDPM that employs a U-Net with 100M parameters. To further improve the performance, we employ the U-ViT-L configuration with 287M parameters, and the FID improves from 5.85 to 4.26.
Meanwhile, we find that our U-ViT performs especially well in the latent space, where images are first converted to their latent representations before applying diffusion models. On class-conditional ImageNet 256×256, our U-ViT obtains a state-of-the-art FID of 2.29, which outperforms all prior diffusion models. Table 3 further demonstrates that our U-ViT outperforms LDM under different numbers of sampling steps using the same sampler. Note that our U-ViT also outperforms VQ-Diffusion, a discrete diffusion model that employs a transformer as its backbone. We also tried replacing our U-ViT with a U-Net with a similar number of parameters and computational cost, and our U-ViT still outperforms the U-Net (see details in Appendix E). On class-conditional ImageNet 512×512, our U-ViT outperforms ADM-G, which directly models the pixels of images. In Figure 4, we provide selected samples on ImageNet 256×256 and ImageNet 512×512, and random samples on other datasets, which have good quality and clear semantics. We provide more generated samples, including class-conditional and random ones, in Appendix F.
In Section 3.1, we demonstrated the importance of the long skip connection on a small-scale dataset (i.e., CIFAR10). Figure 5 further shows that it is also critical for a large-scale dataset such as ImageNet.
In Appendix C, we present results of other metrics (e.g., sFID, inception score, precision and recall) as well as the computational cost (GFLOPs) with more U-ViT configurations on ImageNet. Our U-ViT is still comparable to state-of-the-art diffusion models on other metrics, and meanwhile has comparable if not smaller GFLOPs.
5.3 Text-to-Image Generation on MS-COCO
We evaluate U-ViT for text-to-image generation on the standard benchmark dataset MS-COCO. We train our U-ViT in the latent space of images as detailed in Section 5.1. We also train another latent diffusion model that employs a U-Net of comparable model size to U-ViT-S, and leave other parts unchanged. Its hyperparameters and training details are provided in Appendix B. We report the FID score to measure the image quality. Consistent with previous literature, we randomly draw 30K prompts from the MS-COCO validation set, and generate samples on these prompts to compute FID.
As shown in Table 4, our U-ViT-S already achieves a state-of-the-art FID among methods without access to large external datasets during the training of generative models. By further increasing the number of layers from 13 to 17, our U-ViT-S (Deep) achieves an even better FID of 5.48. Figure 6 shows generated samples of U-Net and U-ViT using the same random seed for a fair comparison. We find that U-ViT generates higher-quality samples whose semantics match the text better. For example, given the text "a baseball player swinging a bat at a ball", U-Net generates neither the bat nor the ball. In contrast, our U-ViT-S generates the ball with even fewer parameters, and our U-ViT-S (Deep) further generates the bat. We hypothesize that this is because texts and images interact at every layer in our U-ViT, which is more frequent than in U-Net, where they interact only at cross-attention layers. We provide more samples in Appendix F.
Conclusion
This work presents U-ViT, a simple and general ViT-based architecture for image generation with diffusion models. U-ViT treats all inputs including the time, condition and noisy image patches as tokens and employs long skip connections between shallow and deep layers. We evaluate U-ViT in tasks including unconditional and class-conditional image generation, as well as text-to-image generation. Experiments demonstrate U-ViT is comparable if not superior to a CNN-based U-Net of a similar size. These results suggest that, for diffusion-based image modeling, the long skip connection is crucial while the down/up-sampling operators in CNN-based U-Net are not always necessary. We believe that U-ViT can provide insights for future research on backbones in diffusion models and benefit generative modeling on large scale cross-modality datasets.
Acknowledgments
This work was supported by NSF of China Projects (Nos. 62061136001, 61620106010, 62076145, U19B2034, U1811461, U19A2081, 6197222); Beijing Outstanding Young Scientist Program NO. BJJWZYJH012019100020098; a grant from Tsinghua Institute for Guo Qiang; the High Performance Computing Center, Tsinghua University; the Fundamental Research Funds for the Central Universities, and the Research Funds of Renmin University of China (22XNKJ13). C. Li was also sponsored by Beijing Nova Program. J.Z was also supported by the XPlorer Prize.
Appendix A Experimental Setup
We list the experimental setup for U-ViT presented in the main paper in Table 5.
In our early experiments, we try learning rates between 1e-4 and 5e-4, and find that a learning rate of 2e-4 performs well for all datasets. On ImageNet 64×64, a learning rate of 3e-4 could further improve the performance. We try weight decay values between 0.01 and 0.05, and find that a weight decay of 0.03 performs well for all datasets. We also try several settings of the AdamW running coefficients $\beta_1, \beta_2$ and find one setting that performs well for all datasets; on CIFAR10 and MS-COCO, dataset-specific settings could further improve the performance. We train with mixed precision for efficiency, and the training time and devices are listed in Table 6. Besides, the training memory of U-ViT can be greatly reduced with the gradient checkpointing trick. For example, when training U-ViT-L/2 with a batch size of 128 on ImageNet 256×256, the memory for the forward and backward passes on a single A100 can be reduced from 53GB to 10GB.
During inference on a single A100, generating 500 samples with DPM-Solver takes around 19, 34, 59 and 89 seconds with U-ViT-S, U-ViT-M, U-ViT-L and U-ViT-H, respectively. The time would double if classifier-free guidance is used.
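The doubling follows from how classifier-free guidance is usually computed: each sampling step evaluates the network twice, once with the condition and once with a null condition. The sketch assumes the common form $\hat\epsilon = (1 + s)\,\epsilon_\theta(x_t, t, c) - s\,\epsilon_\theta(x_t, t, \varnothing)$ with guidance scale $s$; `eps_model` and `null_cond` are hypothetical names (in practice the null condition is typically a learned "empty" embedding, represented here by `None`), and the exact parameterization in U-ViT's sampler is an assumption.

```python
def guided_noise(eps_model, x_t, t, c, scale, null_cond=None):
    """Classifier-free guidance: (1 + s) * eps(x_t, t, c) - s * eps(x_t, t, null)."""
    eps_cond = eps_model(x_t, t, c)             # conditional pass
    eps_uncond = eps_model(x_t, t, null_cond)   # unconditional pass (second network call)
    return [(1.0 + scale) * ec - scale * eu for ec, eu in zip(eps_cond, eps_uncond)]

# Usage: a toy model that records its inputs and returns 1 when conditioned, 0 otherwise.
calls = []
def toy_model(x_t, t, c):
    calls.append(c)
    return [1.0 if c is not None else 0.0 for _ in x_t]

guided = guided_noise(toy_model, [0.0, 0.0], 10, "class-7", scale=0.4)
```

With s = 0 this reduces to the plain conditional prediction; larger s pushes samples further toward the condition.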
Appendix B Details of the U-Net Baseline on MS-COCO
We employ the U-Net with cross attention provided by LDM as the baseline. The U-Net operates on the 32×32 latent representation and down-samples it to 16×16, 8×8 and 4×4 resolutions. The number of channels is 128 at 32×32 resolution and 256 at the other resolutions. Each resolution has 2 residual blocks. The U-Net performs self attention and cross attention at 16×16 and 8×8 resolutions. This configuration leads to a total of 53M parameters, comparable to the 45M of U-ViT-Small, for a fair comparison. We use the AdamW optimizer with weight decay set to 0.01 and running coefficients $\beta_1, \beta_2$ set to (0.9, 0.999), which are the settings used throughout LDM. We tune the learning rate of the U-Net and find that 2e-4 performs best. The training iterations and the batch size of the U-Net are the same as those of U-ViT for a fair comparison.
Appendix C Results of Other Metrics and Configurations on ImageNet
We present results of FID, sFID, inception score (IS), precision and recall on ImageNet in Table 7. Our U-ViT remains comparable to state-of-the-art U-Net-based diffusion models on these metrics, while having comparable if not smaller GFLOPs.
Appendix D CKA Analysis
We find that the “addition” and “no long skip connection” settings share a similar phenomenon that neighboring blocks in the network have similar representations, e.g., blocks 0-3, 6-11 in Figure 7 (b), and blocks 0-5, 6-11 in Figure 7 (c). In contrast, the representations of neighboring blocks under the “concatenation” setting have low similarity, as shown in Figure 7 (a). Thus, the “concatenation” setting significantly changes the representations in the transformer, while the “addition” setting does not.
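The text does not state which CKA (centered kernel alignment) variant is used. The sketch below assumes linear CKA, $\mathrm{CKA}(X, Y) = \|Y^\top X\|_F^2 / (\|X^\top X\|_F\,\|Y^\top Y\|_F)$ on column-centered features, where rows are examples (e.g., pooled token features of one block) and columns are features. Function names are hypothetical; applying `cka_matrix` to the outputs of every block would give a block-by-block similarity map of the kind shown in Figure 7.

```python
import math

def _center(X):
    """Subtract per-column (feature) means from an n x d matrix stored as a list of rows."""
    n, d = len(X), len(X[0])
    means = [sum(row[k] for row in X) / n for k in range(d)]
    return [[row[k] - means[k] for k in range(d)] for row in X]

def _fro2_cross(A, B):
    """Squared Frobenius norm of A^T B, where A (n x p) and B (n x q) share the same n rows."""
    total = 0.0
    for a in range(len(A[0])):
        for b in range(len(B[0])):
            s = sum(A[i][a] * B[i][b] for i in range(len(A)))
            total += s * s
    return total

def linear_cka(X, Y):
    """Linear CKA between two representations of the same n examples."""
    Xc, Yc = _center(X), _center(Y)
    return _fro2_cross(Xc, Yc) / math.sqrt(_fro2_cross(Xc, Xc) * _fro2_cross(Yc, Yc))

def cka_matrix(reps):
    """Pairwise CKA between per-block representations."""
    return [[linear_cka(A, B) for B in reps] for A in reps]

X = [[1.0, 2.0], [3.0, 1.0], [0.0, 5.0], [2.0, 2.0]]
M = cka_matrix([X, [[3 * v for v in row] for row in X],
                [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]])
```

Linear CKA is invariant to isotropic scaling and orthogonal transformations of the features, so it compares what blocks represent rather than how the features happen to be scaled or ordered.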
Appendix E Compare with U-Net Under Similar Amount of Parameters and Computational Cost
On ImageNet 256×256, we also tried replacing our U-ViT with a U-Net with a similar number of parameters and computational cost. The U-Net uses the implementation from ADM. We set the model channels to 320, the channel multiplier to (2, 2, 4) and the number of residual blocks to 3, and apply attention at 2× and 4× down-sampling. This leads to a U-Net with 646M parameters and 135 GFLOPs, while our U-ViT has 501M parameters and 133 GFLOPs. We use the same optimizer configuration as ADM. As shown in Figure 8, our U-ViT consistently outperforms the U-Net at different training iterations without classifier-free guidance. We also evaluate FID with 50K samples at 500K training iterations. Without classifier-free guidance, U-ViT obtains a FID of 6.58 and U-Net obtains a FID of 10.69. With a classifier-free guidance scale of 0.4, U-ViT obtains a FID of 2.29 and U-Net obtains a FID of 2.66. Under both settings, our U-ViT outperforms the U-Net.