MDTv2: Masked Diffusion Transformer is a Strong Image Synthesizer
Shanghua Gao, Pan Zhou, Ming-Ming Cheng, Shuicheng Yan
Introduction
Diffusion probabilistic models (DPMs) have been at the forefront of recent advances in image-level generative models, often surpassing the previous state-of-the-art (SoTA) generative adversarial networks (GANs). DPMs have also succeeded in numerous other applications, including text-to-image generation and speech generation. DPMs adopt a time-inverted Stochastic Differential Equation (SDE) to gradually map Gaussian noise into a sample over multiple time steps, with each step corresponding to one network evaluation. In practice, generating a sample is time-consuming because the SDE needs thousands of time steps to converge. To address this issue, various sampling strategies have been proposed to accelerate inference. Improving the training speed of DPMs, however, is less explored yet highly desirable. Training DPMs also unavoidably requires a large number of time steps to ensure the convergence of the SDE, which makes it computationally expensive, especially in the current era, where large-scale models and data are often used to improve generation performance.
In this work, we first observe that DPMs often struggle to learn the associated relations among object parts in an image, which slows down their training. Specifically, as illustrated by \figref{fig:converge_comp}, the classical DPM (DDPM with a DiT backbone) has learned the shape of a dog by the 50k-th training step, but learns one eye and the mouth only by the 200k-th step while still missing the other eye. Moreover, the relative position of the two ears is still inaccurate at the 300k-th step. This learning process reveals that DPMs fail to learn the associated relations among semantic parts and instead learn each semantic part independently. The reason behind this phenomenon is that DPMs maximize the log probability of real data by minimizing a per-pixel prediction loss, which ignores the associated relations among object parts in an image and thus slows down learning.
Inspired by this observation, we propose an effective Masked Diffusion Transformer (MDT) to improve the training efficiency of DPMs. MDT introduces a mask latent modeling scheme designed for transformer-based DPMs that explicitly enhances contextual learning and improves the learning of associated relations among semantics in an image. Specifically, following LatentDiffusion, MDT operates the diffusion process in the latent space to save computational cost. It masks certain image tokens and designs an asymmetric masking diffusion transformer (AMDT) to predict the masked tokens from the unmasked ones in a diffusion generation manner. To this end, AMDT contains an encoder, a side-interpolater, and a decoder. The encoder and decoder modify the DiT transformer block by inserting global and local token position information to help predict masked tokens. The encoder processes only the unmasked tokens during training but handles all tokens during inference, since there are no masks. To ensure that the decoder always processes all tokens, for both training prediction and inference generation, a side-interpolater implemented as a small network predicts the masked tokens from the encoder output during training; it is removed during inference.
With this mask latent modeling scheme, MDT learns to reconstruct the full information of an image from its incomplete contextual input, and thereby learns the associated relations among semantics in an image. As shown in \figref{fig:converge_comp}, MDT typically generates both eyes (and both ears) of the dog at almost the same training step, indicating that it correctly learns the associated semantics of an image through mask latent modeling. In contrast, DiT cannot easily synthesize a dog with correct semantic relations among its parts. This comparison shows that MDT models relations better and learns faster than DiT. Experimental results demonstrate that MDT achieves superior image synthesis performance and sets a new SoTA for class-conditional image synthesis on the ImageNet dataset, as shown in \figref{fig:sample} and \tabref{tab:sota_comp}. MDT also learns about 3× faster during training than the SoTA DPM, DiT, as demonstrated by \figref{fig:converge_comp} and \tabref{tab:ditvsmdt}. We hope our work inspires further research on speeding up diffusion training with unified representation learning.
The main contributions are summarised as follows:
Our proposed masked diffusion transformer introduces an effective mask latent modeling scheme into DPMs and correspondingly designs an asymmetric masking diffusion transformer. Our method is the first that aims to explicitly enhance contextual learning ability and improve relation learning among image semantics for DPMs.
Experiments show that our masked diffusion transformer achieves higher image synthesis performance and greatly speeds up learning during training. It achieves a new SoTA for image synthesis.
Related works
Diffusion probabilistic models (DPMs), also known as score-based models, are a competitive approach for image synthesis. DPMs first use a forward Stochastic Differential Equation (SDE) to gradually add Gaussian noise to real data, transforming a complex data distribution into a Gaussian distribution. They then adopt a time-inverted SDE to gradually map Gaussian noise into a sample over multiple steps. At each sampling time step, a network guides the sample along the gradient of the log probability, also known as the score function. The iterative nature of diffusion models can result in high training and inference costs. Efficient sampling strategies, latent space diffusion, and multi-resolution cascaded generation have been proposed to reduce the inference cost. Additionally, some training schemes have been introduced to improve diffusion model training, e.g., approximate maximum likelihood training and training loss weighting. In contrast to these works, which optimize the diffusion training process, we identify the lack of contextual modeling ability in diffusion models. To address it, we propose the mask latent modeling scheme as a complementary approach to enhancing the contextual representation of diffusion models, which is orthogonal to existing diffusion training schemes.
2 Networks for Diffusion Models
The UNet-like network, enhanced by spatial self-attention and group normalization, was first used for diffusion models. Several design improvements, e.g., more attention heads, BigGAN residual blocks, and adaptive group normalization, were later proposed to further enhance the generation ability of the UNet. Recently, owing to the broad applicability of transformer networks, several works have attempted to use the vision transformer (ViT) structure for diffusion models. GenViT demonstrates that ViT is capable of image generation but performs worse than UNet. U-ViT improves ViT by adding long skip connections and convolutional layers, achieving performance competitive with UNet. DiT verifies the scaling ability of ViT on large model sizes and feature resolutions. Our MDT is orthogonal to these diffusion networks, as it focuses on contextual representation learning. Moreover, the position-aware designs in MDT reveal that the mask latent modeling scheme benefits from a stronger diffusion network. We will further explore how MDT can unlock the potential of these networks.
3 Mask Modeling
Mask modeling has proven effective in both recognition learning and generative modeling. In natural language processing (NLP), mask modeling was first introduced to enable representation pretraining and language generation. It was subsequently shown to be feasible for vision recognition and generation tasks as well. In vision recognition, pretraining schemes based on mask modeling yield good representation quality, scalability, and faster convergence. In generative modeling, following bi-directional generative modeling in NLP, MaskGIT and MUSE use a masked generative transformer to predict randomly masked image tokens for image generation. Similarly, VQ-Diffusion presents a mask-replace diffusion strategy to generate images. In contrast, our MDT aims to enhance the contextual representation of the denoising diffusion transformer with mask latent modeling. This preserves the detail refinement ability of denoising diffusion models, since the diffusion process is maintained during inference. To ensure that mask latent modeling in MDT focuses on representation learning rather than reconstruction, we propose an asymmetric structure for mask modeling training. As an extra benefit, this structure has lower training costs than masked generative models, because it skips the masked patches during training instead of replacing them with a mask token.
Revisitation of Diffusion Probabilistic Model
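For diffusion probabilistic models such as DDPM and DDIM, training involves a forward noising process and a reverse denoising process. In the forward noising process, Gaussian noise is gradually added to the real sample $x_0$ via a discrete SDE of the form $q(x_t \mid x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}\, x_{t-1}, \beta_t \mathbf{I})$, where $\beta_t$ denotes the noise magnitude. If the time step $t$ is large, $x_t$ is approximately Gaussian noise. Similarly, the reverse denoising process is a discrete SDE that gradually maps Gaussian noise into a sample. At each time step, given $x_t$, it predicts the next reverse step $x_{t-1}$ via a network. The network is trained by optimizing the variational lower bound $\mathcal{L}_{\mathrm{vlb}}$ of $\log p_\theta(x_0)$, where $p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))$. Following prior work, we obtain $\Sigma_\theta$ by optimizing $\mathcal{L}_{\mathrm{vlb}}$, and we reparameterize $\mu_\theta$ as a noise prediction network $\epsilon_\theta$ trained with a simple mean-squared error loss, i.e., $\mathcal{L}_{\mathrm{simple}} = \mathbb{E}_{t, x_0, \epsilon}\big[\lVert \epsilon_\theta(x_t, t) - \epsilon \rVert_2^2\big]$, where $\epsilon$ is the ground-truth Gaussian noise. During inference, one samples Gaussian noise $x_T \sim \mathcal{N}(0, \mathbf{I})$ and gradually reverses it to a sample $x_0$.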
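As an illustration of the training objective, the NumPy sketch below implements the forward noising step and $\mathcal{L}_{\mathrm{simple}}$. It assumes the standard DDPM closed form $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ with $\bar\alpha_t = \prod_{s \le t}(1-\beta_s)$, obtained by composing the per-step Gaussians, and a linear schedule of 1000 steps from $10^{-4}$ to $2\times10^{-2}$; the helper names are hypothetical and the network $\epsilon_\theta$ is replaced by a placeholder.

```python
import numpy as np

def linear_beta_schedule(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Linear variance schedule beta_1, ..., beta_T (hypothetical helper)."""
    return np.linspace(beta_start, beta_end, num_steps)

def q_sample(x0, t, noise, alpha_bar):
    """Draw x_t ~ q(x_t | x_0) in closed form; t is a 0-based step index."""
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise

def simple_loss(eps_pred, eps):
    """L_simple: mean-squared error between predicted and true noise."""
    return float(np.mean((eps_pred - eps) ** 2))

betas = linear_beta_schedule()
alpha_bar = np.cumprod(1.0 - betas)          # alpha_bar_t = prod_{s<=t} (1 - beta_s)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 32, 32))        # toy latent, C x H x W
eps = rng.standard_normal(x0.shape)          # ground-truth noise
t = int(rng.integers(0, len(betas)))         # uniformly sampled time step
x_t = q_sample(x0, t, eps, alpha_bar)
# A trained eps_theta(x_t, t) would go here; a zero predictor is only a placeholder.
loss = simple_loss(np.zeros_like(eps), eps)
```

With this schedule, $\bar\alpha_T$ is tiny, so $x_T$ is essentially pure Gaussian noise, which is what makes starting inference from $\mathcal{N}(0, \mathbf{I})$ reasonable.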
As in prior work, we train the diffusion model conditioned on the class label $c$, i.e., $p_\theta(x_{t-1} \mid x_t, c)$. By default, our experiments use class-conditional image generation.
Masked Diffusion Transformer
As shown in \figref{fig:converge_comp}, DPMs with a DiT backbone converge slowly in training because they are slow to learn the associated relations among semantics in an image. To relieve this issue, we propose the Masked Diffusion Transformer (MDT), which introduces a mask latent modeling scheme to explicitly enhance contextual learning and to improve the ability to establish associated relations among different semantics in an image. To this end, as depicted in \figref{fig:overall}, MDT consists of 1) a latent masking operation that masks the input image in the latent space, and 2) an asymmetric masking diffusion transformer that performs the vanilla diffusion process of DPMs, but with masked input. To reduce computational costs, MDT follows LatentDiffusion and performs generative learning in the latent space instead of the raw pixel space.
In the training phase, MDT first encodes an image into a latent space with a pre-trained VAE encoder. The latent masking operation then adds Gaussian noise to the image latent embedding, patchifies the resulting noisy latent embedding into a sequence of tokens, and masks certain tokens. The remaining unmasked tokens are fed into the asymmetric masking diffusion transformer, which contains an encoder, a side-interpolater, and a decoder, to predict the masked tokens from the unmasked ones. During inference, MDT replaces the side-interpolater with a position embedding addition operation. MDT takes the latent embedding of Gaussian noise as input and generates a denoised latent embedding, which is then passed to a pre-trained VAE decoder to produce the image.
During training, this mask latent modeling scheme forces the diffusion model to reconstruct the full information of an image from its incomplete contextual input. The model is thereby encouraged to learn the relations among image latent tokens, particularly the associated relations among semantics in an image. For example, as illustrated in \figref{fig:overall}, the model must first understand the correct associated relations among the small image parts (tokens) of the dog image, and then generate the masked “eye” tokens using the other unmasked tokens as context. Furthermore, \figref{fig:converge_comp} shows that MDT often learns to generate the associated semantics of an image at nearly the same pace, e.g., generating both eyes (both ears) of the dog at almost the same training step, whereas DiT (DDPM with a transformer backbone) first learns to generate one eye (one ear) and only learns the other eye (ear) roughly 100k training steps later. This demonstrates that MDT learns the associated relations among image semantics better than DiT.
Below, we introduce the two key components of MDT: 1) the latent masking operation and 2) the asymmetric masking diffusion transformer.
2 Latent Masking
Following the latent diffusion model (LDM), MDT performs generative learning in the latent space instead of the raw pixel space to reduce computational costs. In the following, we briefly recall LDM and then introduce our latent masking operation on the latent input.
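The latent masking step (patchify the noisy latent, then hide a fraction of the tokens) can be sketched as follows. This is an illustrative sketch only: it assumes MAE-style uniform random masking implemented by an argsort of random scores, and the names `patchify` and `random_mask` are hypothetical; MDT's exact sampling procedure may differ.

```python
import numpy as np

def patchify(latent, patch_size=2):
    """Split a C x H x W latent into (H/p)*(W/p) tokens of p*p*C values each."""
    c, h, w = latent.shape
    p = patch_size
    x = latent.reshape(c, h // p, p, w // p, p)
    x = x.transpose(1, 3, 2, 4, 0)             # (H/p, W/p, p, p, C)
    return x.reshape((h // p) * (w // p), p * p * c)

def random_mask(tokens, mask_ratio, rng):
    """Keep a random subset of tokens; return kept tokens, their indices, and the mask."""
    n = tokens.shape[0]
    num_keep = int(n * (1 - mask_ratio))
    order = np.argsort(rng.random(n))          # random permutation of token indices
    keep_idx = np.sort(order[:num_keep])       # preserve the original raster order
    mask = np.ones(n, dtype=bool)              # True marks a masked token
    mask[keep_idx] = False
    return tokens[keep_idx], keep_idx, mask

rng = np.random.default_rng(0)
noisy_latent = rng.standard_normal((4, 32, 32))            # C x H x W noisy latent
tokens = patchify(noisy_latent)                            # 256 tokens of 16 values
visible, keep_idx, mask = random_mask(tokens, 0.3, rng)    # tokens seen by the encoder
```

Returning `keep_idx` alongside the kept tokens is a design choice that lets later stages scatter encoder outputs back to their original grid positions.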
3 Asymmetric Masking Diffusion Transformer
We now introduce our asymmetric masking diffusion transformer, which jointly trains mask latent modeling and the diffusion process. As shown in \figref{fig:block}, it consists of three components, an encoder, a side-interpolater, and a decoder, each described in detail below.
Position-aware encoder and decoder. In MDT, predicting the masked latent tokens from the unmasked ones requires the positional relations of all tokens. To enhance the position information in the model, we propose a position-aware encoder and decoder that facilitate learning the masked latent tokens. Specifically, the encoder and decoder tailor the standard DiT block by adding two kinds of token position information, and contain $N_1$ and $N_2$ tailored blocks, respectively.
Firstly, as illustrated in \figref{fig:block}, the encoder adds the conventional learnable global position embedding to the noisy latent embedding input. The decoder also introduces the learnable position embedding into its input, but differently in the training and inference phases. During training, the side-interpolater already uses the learnable global position embedding, as introduced below, which delivers global position information to the decoder. During inference, since the side-interpolater is discarded (see below), the decoder explicitly adds the position embedding to its input to enhance positional information.
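Secondly, as shown in \figref{fig:block}, the encoder and decoder add a local relative positional bias to each head in each block when computing the self-attention scores:
$$\mathrm{Attention}(Q, K, V) = \mathrm{Softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + B_r\right)V,$$
where $Q$, $K$, and $V$ denote the query, key, and value of the self-attention module, $d_k$ is the key dimension, and $B_r$ is the learnable relative positional bias indexed by the relative position between each pair of tokens.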
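The attention with a relative positional bias, $\mathrm{Softmax}(QK^\top/\sqrt{d_k} + B_r)V$, can be sketched for a single head as below. We assume a Swin-style bias table with one learnable entry per 2D relative offset on an $h \times w$ token grid; the table layout, the helper names, and the random initialization are illustrative assumptions rather than MDT's exact implementation.

```python
import numpy as np

def relative_position_index(h, w):
    """Map each (query, key) pair on an h x w grid to an index into a bias table."""
    coords = np.stack(np.meshgrid(np.arange(h), np.arange(w), indexing="ij")).reshape(2, -1)
    rel = coords[:, :, None] - coords[:, None, :]    # (2, N, N) row/col offsets
    rel[0] += h - 1                                  # shift row offsets to start at 0
    rel[1] += w - 1                                  # shift col offsets to start at 0
    return rel[0] * (2 * w - 1) + rel[1]             # (N, N) flat table index

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)          # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_with_rel_bias(q, k, v, bias_table, rel_index):
    """Single-head Softmax(Q K^T / sqrt(d_k) + B_r) V."""
    d_k = q.shape[-1]
    b_r = bias_table[rel_index]                      # gather the (N, N) bias B_r
    scores = q @ k.T / np.sqrt(d_k) + b_r
    return softmax(scores) @ v

rng = np.random.default_rng(0)
h, w, d = 4, 4, 8
n = h * w
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
bias_table = 0.02 * rng.standard_normal((2 * h - 1) * (2 * w - 1))  # learnable in practice
rel_index = relative_position_index(h, w)
out = attention_with_rel_bias(q, k, v, bias_table, rel_index)
```

Because the bias depends only on relative offsets, it is shared across positions. For an encoder that sees only unmasked tokens, one would presumably gather `rel_index[np.ix_(keep_idx, keep_idx)]`, with `keep_idx` holding the indices of the kept tokens, so that each token keeps its true grid offsets.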
The encoder takes the unmasked noisy latent embedding provided by our latent masking operation and feeds its output into the side-interpolater during training, or into the decoder during inference. The decoder's input is the output of the side-interpolater during training, or the sum of the encoder output and the learnable position embedding during inference. Since the encoder and decoder respectively handle the unmasked tokens and all tokens during training, we call our model “asymmetric”.
Side-interpolater. As shown in Fig. 3, during training, the encoder processes only the unmasked tokens, for efficiency and better performance. In the inference phase, however, the encoder handles all tokens because there are no masks. The encoder output (i.e., the decoder input) therefore differs substantially between training and inference, at least in the number of tokens. To ensure that the decoder always processes all tokens, for both training prediction and inference generation, a side-interpolater implemented as a small network predicts the masked tokens from the encoder output during training and is removed during inference.
Since there are no masks during inference, the side-interpolater is replaced by a position embedding operation that adds the side-interpolater's learnable position embeddings, which are learned during training. This ensures that the decoder always processes all tokens and uses the same learnable position embeddings for both training prediction and inference generation, which improves image generation performance.
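The two paths into the decoder can be contrasted in a short sketch. The text does not spell out the side-interpolater's internals, so this sketch assumes that masked slots are first filled with a learnable mask token, that the side-interpolater's position embedding is added, and that a masked shortcut (evaluated in the ablation) keeps the encoder features at unmasked positions and uses the network's prediction only at masked ones. `mask_token` and `toy_block` are hypothetical stand-ins.

```python
import numpy as np

def side_interpolater_train(enc_out, keep_idx, mask, pos_embed, mask_token, block):
    """Training path: rebuild a full-length token sequence for the decoder."""
    n = mask.shape[0]
    full = np.tile(mask_token, (n, 1))       # mask token in every slot
    full[keep_idx] = enc_out                 # scatter encoder outputs back
    x = full + pos_embed                     # side-interpolater position embedding
    pred = block(x)                          # small network predicts all tokens
    # Masked shortcut: keep unmasked positions, use predictions only where masked.
    return np.where(mask[:, None], pred, x)

def decoder_input_inference(enc_out_all, pos_embed):
    """Inference path: no masks, so only the position embedding is added."""
    return enc_out_all + pos_embed

def toy_block(x):
    """Element-wise stand-in; a real side-interpolater would mix tokens (e.g. attention)."""
    return np.tanh(x)

rng = np.random.default_rng(0)
n, d = 16, 8
mask = np.zeros(n, dtype=bool)
mask[[1, 4, 7, 10, 13]] = True               # 5 masked positions
keep_idx = np.flatnonzero(~mask)             # 11 kept positions
enc_out = rng.standard_normal((keep_idx.size, d))
pos_embed = 0.02 * rng.standard_normal((n, d))  # learned during training in practice
mask_token = np.zeros(d)                     # hypothetical learnable mask token
dec_in_train = side_interpolater_train(enc_out, keep_idx, mask, pos_embed, mask_token, toy_block)
```

In both paths the unmasked positions reach the decoder as encoder output plus the same position embedding, which is one way to read the training/inference consistency argued above.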
4 Training
During training, we feed both the full latent embedding and the masked latent embedding to the diffusion model, since we observe that using only the masked latent embedding makes the model focus too much on reconstructing masked regions while neglecting diffusion training. The training objectives for the full and masked latent inputs both follow the description in \secref{sec:ddpm}. Owing to the asymmetric masking structure, the extra cost of using the masked latent embedding is small. This is also demonstrated by \figref{fig:converge_comp}, which shows that MDT still learns about 3× faster than the previous SoTA DiT in terms of total training hours.
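A minimal sketch of this objective follows, assuming a hypothetical `model(x_t, t, mask)` callable that returns a noise prediction for every token and the DDPM closed form $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$. Adding the two losses with equal weight is our assumption, since the text does not state how they are combined; following the "Loss on all tokens" ablation, the masked-input loss is computed over all tokens.

```python
import numpy as np

def mdt_training_loss(model, x0_tokens, t, alpha_bar, rng, mask_ratio=0.3):
    """Noise-prediction loss on the full input plus the loss on a masked input."""
    eps = rng.standard_normal(x0_tokens.shape)
    x_t = np.sqrt(alpha_bar[t]) * x0_tokens + np.sqrt(1.0 - alpha_bar[t]) * eps
    n = x_t.shape[0]
    # Full latent input: nothing is masked.
    no_mask = np.zeros(n, dtype=bool)
    loss_full = np.mean((model(x_t, t, no_mask) - eps) ** 2)
    # Masked latent input: hide a random subset of tokens from the encoder.
    num_masked = n - int(n * (1 - mask_ratio))
    mask = np.zeros(n, dtype=bool)
    mask[rng.permutation(n)[:num_masked]] = True
    # The loss covers all tokens, not only the masked ones.
    loss_masked = np.mean((model(x_t, t, mask) - eps) ** 2)
    return float(loss_full + loss_masked)

alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 2e-2, 1000))
rng = np.random.default_rng(0)
x0_tokens = rng.standard_normal((256, 16))
seen_masks = []

def zero_model(x_t, t, mask):
    """Toy stand-in that records how many tokens were masked and predicts zero noise."""
    seen_masks.append(int(mask.sum()))
    return np.zeros_like(x_t)

loss = mdt_training_loss(zero_model, x0_tokens, t=500, alpha_bar=alpha_bar, rng=rng)
```

Both passes share the same noisy input and target noise here, so the masked pass only changes what the encoder is allowed to see.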
Experiments
We give the implementation details of MDT, including model architecture, training details, and evaluation metrics.
Model architecture. We follow DiT to set the total block number (i.e., $N_1 + N_2$), token number, and channel numbers of the diffusion transformer of MDT. As DiT shows stronger synthesis performance with a smaller patch size, we also use a patch size of 2 by default, denoted by the suffix /2 (e.g., MDT-S/2). Moreover, we follow DiT's parameters to build small-, base-, and xlarge-sized MDT models, denoted MDT-S/B/XL. As in LatentDiffusion and DiT, MDT by default adopts the fixed VAE provided by Stable Diffusion (downloaded from https://huggingface.co/stabilityai/sd-vae-ft-mse) to encode images into latent tokens and decode latent tokens into images. The VAE encoder has a downsampling ratio of 8 and a feature channel dimension of 4, e.g., an image of size 256×256×3 is encoded into a latent embedding of size 32×32×4.
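The sizes above can be checked with a small worked calculation. This sketch assumes the 256×256×3 to 32×32×4 VAE mapping stated here (an 8× downsampling with 4 channels), patch size 2, and the default mask ratio of 0.3; `latent_token_counts` is a hypothetical helper, and rounding the visible-token count down is our assumption.

```python
def latent_token_counts(image_size=256, vae_downsample=8, latent_channels=4,
                        patch_size=2, mask_ratio=0.3):
    """Worked shape arithmetic for the default patch-size-2 setting (hypothetical helper)."""
    latent_hw = image_size // vae_downsample                 # 256 / 8 = 32
    tokens_per_side = latent_hw // patch_size                # 32 / 2 = 16
    num_tokens = tokens_per_side ** 2                        # 16 * 16 = 256
    token_dim = patch_size * patch_size * latent_channels    # values per token before embedding
    num_visible = int(num_tokens * (1 - mask_ratio))         # rounded down
    return latent_hw, num_tokens, token_dim, num_visible

latent_hw, num_tokens, token_dim, num_visible = latent_token_counts()
# Rough estimate only: self-attention cost grows with the square of the token count.
attn_cost_ratio = (num_visible / num_tokens) ** 2
```

With these defaults the latent is 32×32, it splits into 256 tokens of 16 values each, and 179 tokens remain visible for a masked input. By the quadratic estimate, the encoder's attention on a masked input costs roughly 0.49 of that on a full input, which is one way to see why the asymmetric design keeps the extra training cost small.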
Training details. Following prior work, all models are trained on ImageNet at an image resolution of 256×256 with the AdamW optimizer, a learning rate of 3e-4, a batch size of 256, and no weight decay. We set the mask ratio to 0.3 and $N_2 = 2$. Following the training settings of DiT, we set the maximum diffusion step in training to 1000 and use a linear variance schedule ranging from $10^{-4}$ to $2 \times 10^{-2}$. Other settings are also aligned with DiT.
Evaluation. We evaluate models with commonly used metrics, i.e., Fréchet Inception Distance (FID), sFID, Inception Score (IS), Precision, and Recall. FID is used as the major metric because it measures both diversity and fidelity. sFID improves upon FID by evaluating at the spatial level. As complements, IS and Precision measure fidelity, and Recall measures diversity. For fair comparisons, we follow prior work in using the TensorFlow evaluation suite from ADM and report FID-50K with 250 DDPM sampling steps. Unless otherwise specified, we report FID scores without classifier-free guidance.
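For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images, $\lVert\mu_r-\mu_g\rVert^2 + \mathrm{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2}\big)$. The NumPy sketch below computes it from precomputed statistics only; it is not the ADM TensorFlow suite used for the reported numbers, feature extraction is omitted, and `frechet_distance` is a hypothetical name.

```python
import numpy as np

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr((sigma1 sigma2)^{1/2}) is the sum of square roots of the eigenvalues of
    # sigma1 @ sigma2, which are real and non-negative for PSD covariances.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)

# Two 3-D Gaussians that differ only by a unit shift of the mean.
mu_real, mu_fake = np.zeros(3), np.array([1.0, 0.0, 0.0])
fid_shift = frechet_distance(mu_real, np.eye(3), mu_fake, np.eye(3))
```

The eigenvalue route avoids an explicit matrix square root; production evaluation code typically uses a dedicated `sqrtm` routine on 2048-dimensional Inception statistics.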
2 Comparison Results
Performance comparison. \tabref{tab:ditvsmdt} compares our MDT with the SoTA DiT under different model sizes. MDT achieves better (lower) FID scores at all model scales with lower training cost. The parameters and inference cost of MDT are similar to those of DiT, since the extra modules in MDT are negligible, as introduced in \secref{sec:mdt_method}. For small models, MDT-S/2 trained for 300k steps outperforms DiT-S/2 trained for 400k steps by a large margin in FID (57.01 vs. 68.40). More importantly, MDT-S/2 trained for 2000k steps performs similarly to the larger DiT-B/2 trained with a similar computational budget. For the largest model, MDT-XL/2 trained for 1300k steps outperforms DiT-XL/2 trained for 7000k steps in FID (9.60 vs. 9.62), i.e., about 5× faster training progress.
We also compare the class-conditional image generation performance of MDT with existing methods in \tabref{tab:sota_comp}. For fair comparison with DiT, we also use the EMA weights of the VAE decoder in this table. Under class-conditional settings, MDT with half the training iterations outperforms DiT by a large margin, e.g., 6.83 vs. 9.62 in FID. Following previous works, we use an improved classifier-free guidance with a power-cosine weight scaling to trade off precision and recall during class-conditional sampling. MDT outperforms the previous SoTA DiT and other methods with an FID of 1.81, setting a new SoTA for class-conditional image generation. As with DiT, we do not observe FID saturation as training continues.
Convergence speed. \figref{fig:converge_comp} compares the DiT-S/2 baseline and MDT-S/2 at different training steps and training times on 8×A100 GPUs. Owing to its stronger contextual learning ability, MDT achieves better performance and learns generation about 3× faster, in terms of both training steps and training time. For example, MDT-S/2 trained for about 33 hours (400k steps) outperforms DiT-S/2 trained for about 100 hours (1500k steps). This reveals that contextual learning is vital for faster generation learning in diffusion models.
3 Ablation
In this part, we conduct ablation studies to verify the designs of MDT. Unless otherwise stated, we report results of the MDT-S/2 model with FID-50k as the evaluation metric.
Masking ratio. The masking ratio determines how many input patches are processed during training. \tabref{tab:maskratio} compares different masking ratios. The best masking ratio for MDT-S/2 is 30%, which is quite different from the ratios used for recognition models, e.g., 75% in MAE. We hypothesize that image generation requires learning details from more patches for high-quality synthesis, while recognition models need only the most essential patches to infer semantics.
Asymmetric vs. Symmetric architecture in masking. Unlike masked generation works such as MaskGIT, which use the masking scheme to generate images, MDT focuses on improving the contextual learning ability of diffusion models via mask latent modeling. Therefore, we use an asymmetric architecture in which the diffusion model encoder processes only the unmasked tokens. We compare this asymmetric architecture with a symmetric architecture that processes the full input, with masked tokens replaced by a learnable mask token. As shown in \tabref{tab:asymmetric_mask}, the asymmetric architecture in MDT achieves an FID of 50.26, outperforming the 51.56 FID of the symmetric architecture. The asymmetric architecture also reduces the training cost and allows the diffusion model to focus on learning contextual information instead of reconstructing masked tokens.
Full and masked latent tokens. In MDT, both the full and masked latent embeddings are fed into the diffusion model during training. For comparison, \tabref{tab:used_latent} reports results of training with only the full or only the masked latent embeddings, with the computational cost aligned for fair comparison. Training with both full and masked latents yields a clear gain over both alternatives. Using only the masked latent embeddings results in slow convergence, which we attribute to training/inference inconsistency, as inference in MDT is a diffusion process rather than a masked reconstruction process.
Loss on all tokens. By default, we calculate the loss on both masked and unmasked latent embeddings. In contrast, mask modeling for recognition models commonly calculates the loss only on masked tokens. \tabref{tab:mask_sup} shows that calculating the loss on all tokens is much better than calculating it only on masked tokens. We hypothesize that this is because generative models require stronger consistency among patches than recognition models do, since details are vital for high-quality image synthesis.
Effect of side-interpolater. The side-interpolater in MDT predicts the masked tokens, allowing the diffusion model to learn more semantics and keeping the decoder inputs consistent between training and inference. We compare performance with and without the side-interpolater in \tabref{tab:side_interpolater} and observe an FID gain of 1.34 with the side-interpolater, confirming its effectiveness.
Masked shortcut in side-interpolater. The masked shortcut ensures that the side-interpolater predicts only the masked tokens from the unmasked ones. \tabref{tab:masking_shortcut} shows that using the masked shortcut improves the FID from 50.91 to 50.26, indicating that restricting the side-interpolater to predicting only the masked tokens helps the diffusion model achieve stronger performance.
Side-interpolater position. To meet the high-quality generation requirements of the diffusion model, the side-interpolater is placed in the middle of the network rather than at the end, as is done in recognition models. \tabref{tab:decoder_pos} compares placing the side-interpolater at different positions of the MDT-S model with 12 blocks. Placing the side-interpolater before the last two blocks achieves the best FID score, whereas placing it at the end of the network, as in recognition models, impairs performance. Placing it at early stages of the network also harms performance, indicating that mask latent modeling benefits most stages of the diffusion model.
Block number in side-interpolater. We compare different numbers of blocks in the side-interpolater in \tabref{tab:num_block_si}. The default setting of one block achieves the best performance, and FID worsens as the block number increases. This result is consistent with our motivation that the side-interpolater should not learn much beyond interpolating the masked representations.
Position-aware enhancement. To further release the potential of mask latent modeling, we enhance the DiT baseline with stronger positional awareness, i.e., learnable positional embeddings and a relative positional bias in the basic blocks. \tabref{tab:pos_embed_si} shows that the positional embeddings in the side-interpolater improve the FID from 51.58 to 50.26, indicating that the positional embedding is vital for the side-interpolater. Making the positional embeddings learnable also improves FID, as revealed in \tabref{tab:learn_pos}. In \tabref{tab:rel_bias}, the relative positional bias in the basic blocks significantly improves the FID from 53.56 to 50.26, showing that relative positional modeling is essential for diffusion models to obtain contextual representations and generate high-quality images. Therefore, positional awareness in the diffusion model structure should accompany masked latent modeling, and it plays a key role in improving performance.
Conclusion
This work proposes a masked diffusion transformer to enhance contextual representation and improve relation learning among image semantics for DPMs. We introduce an effective mask latent modeling scheme into DPMs and correspondingly design an asymmetric masking diffusion transformer. Experiments show that our masked diffusion transformer achieves higher image synthesis performance and substantially speeds up learning during training, achieving a new SoTA for image synthesis on the ImageNet dataset.
Appendix A Model Details
Network configurations. We follow the network configurations described in DiT to set the total block number (i.e., $N_1 + N_2$), token number, and channel numbers for the masked diffusion transformer of MDT. The configurations of MDT models are given in \tabref{tab:model_config}. Following DiT, MDT comes in different sizes, denoted S/B/XL.
Network parameters and costs. \tabref{tab:model_config} lists the network parameters and training costs of MDT at different model scales. Compared with the DiT baselines, MDT introduces negligible extra inference parameters and cost.
Appendix B Comparison of VAE decoders
To ensure fair comparisons with DiT, we use both the MSE and EMA versions of the pretrained VAE decoders (downloaded from https://huggingface.co/stabilityai/sd-vae-ft-mse and https://huggingface.co/stabilityai/sd-vae-ft-ema, respectively) for image sampling. \tabref{tab:emavsmse} shows that the EMA version performs slightly better than the MSE version. Except for the results in Table 1 of the manuscript, which use the EMA VAE decoder, we use the MSE VAE decoder by default.
Appendix C Inpainting with MDT
By default, MDT uses mask latent modeling during training and becomes a standard diffusion model during inference. When the side-interpolater is kept during inference, MDT naturally supports image inpainting. As shown in \figref{fig:inpainting}, we mask images with different ratios and inpaint the masked parts with MDT. Although MDT is trained with a mask ratio of 30%, it can handle much larger mask ratios, such as 70%. We attribute this ability to the combination of our proposed mask latent modeling and the diffusion model.
Appendix D Improved Classifier-free Guidance
Classifier-free guidance sampling enables a trade-off between sample quality and diversity. It achieves this by combining the class-conditional and unconditional estimates:
$$\hat{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t) + w\,\big(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t)\big),$$
where $\epsilon_\theta(x_t, c)$ is the class-conditional estimate, $\epsilon_\theta(x_t)$ is the unconditional estimate, and $w$ is the guidance scale. Generally, a larger $w$ yields higher sample quality at the cost of diversity. MUSE replaces the fixed guidance scale with a linearly increasing schedule during sampling, so that the model samples with more diversity at early steps and higher fidelity at late steps. Inspired by this, we present a power-cosine schedule for the guidance scale during sampling:
$$w_t = \frac{1 - \cos\!\left(\pi \left(\frac{t}{t_{\max}}\right)^{s}\right)}{2}\, w,$$
where $t$ is the time step during sampling, $t_{\max}$ is the maximum sampling step, $w$ is the maximum guidance scale, and $s$ is a factor that controls how fast the guidance scale increases. As revealed in \figref{fig:powercos}, the power-cosine schedule keeps the guidance scale low at early steps and increases it quickly at late steps. Increasing $s$ makes the guidance scale rise more slowly at early steps and faster at late steps. The improved classifier-free guidance sampling with the power-cosine schedule lets the model sample with high diversity at early steps and high quality at late steps. In this work, $s$ is set to 4, and the corresponding $w$ is set to 3.8 to ensure that the model generates images with high fidelity at late steps.
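The schedule $w_t = w\,\big(1 - \cos(\pi (t/t_{\max})^{s})\big)/2$ with $s = 4$ and $w = 3.8$ can be tabulated with a few lines of Python. This sketch assumes that $t$ counts sampling steps upward from 0 (first step) to $t_{\max}$ (last step) and plugs $w_t$ into the guidance form $\epsilon_u + w(\epsilon_c - \epsilon_u)$; both the step convention and the guidance form are our assumptions, and the function names are hypothetical.

```python
import math

def power_cosine_scale(t, t_max, w_max=3.8, s=4.0):
    """Guidance scale w_t = w_max * (1 - cos(pi * (t / t_max) ** s)) / 2.

    Assumes t counts sampling steps upward, from 0 (first) to t_max (last).
    """
    return w_max * (1.0 - math.cos(math.pi * (t / t_max) ** s)) / 2.0

def guided_eps(eps_cond, eps_uncond, w):
    """Classifier-free guidance in the form eps_u + w * (eps_c - eps_u)."""
    return eps_uncond + w * (eps_cond - eps_uncond)

t_max = 250
scales = [power_cosine_scale(t, t_max) for t in range(t_max + 1)]
```

The scale rises monotonically from 0 to 3.8 but is only about 0.037 at the midpoint of sampling, so almost all of the guidance is applied in the final steps. Whether $w_t = 0$ should mean unconditional or purely conditional sampling depends on the guidance form an implementation uses, which the text does not pin down.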
Appendix E Visualization
We provide more visualized examples of images generated by MDT-XL/2 in \figref{fig:samplemore}. In \figref{fig:inc}, we show more visualized examples of MDT-S/2 across training steps.