SD-DiT: Unleashing the Power of Self-supervised Discrimination in Diffusion Transformer
Rui Zhu, Yingwei Pan, Yehao Li, Ting Yao, Zhenglong Sun, Tao Mei, Chang Wen Chen
Introduction
The computer vision field has recently witnessed the rise of diffusion models as powerful and scalable generative architectures for image generation. Such practical generative models push the limits of a range of CV applications, including text-to-image synthesis, video generation, and 3D generation.
A recent pioneering practice is the Diffusion Transformer (DiT), which inherits the impressive scaling properties of Transformers and significantly improves the capacity and scalability of diffusion models. Unfortunately, similar to Vision Transformers, the training of DiT usually suffers from slow convergence and a heavy computational burden. Recent works therefore focus on accelerating the training convergence of DiT. Many of them combine the Transformer-based diffusion process with an additional mask reconstruction objective via the popular mask strategy. In particular, MDT simultaneously encodes both the complete and masked image input, in order to enhance the intra-image contextual learning among the associated patches. MaskDiT integrates the generative diffusion process with a mask reconstruction auxiliary task to optimize the whole DiT encoder and decoder (see Fig. 1(b)).
Although training efficiency is significantly improved, these DiT architectures with a mask strategy still struggle with extremely high-fidelity image synthesis and suffer from several inherent limitations. (1) Training-inference discrepancy: the mask strategy inevitably introduces learnable mask tokens to trigger mask reconstruction during DiT training, but no artificial mask token is involved in the generative diffusion process at inference. This training-inference discrepancy severely limits the generative capacity of the learned DiT. Note that to alleviate such discrepancy, MDT introduces an additional dual-path interaction between complete and masked inputs during training, at the price of much higher computational and memory cost. (2) Fuzzy relations between mask reconstruction and the generative diffusion process: most mask-based DiT structures process both the visible and learnable mask tokens via the same DiT decoder to jointly enable mask reconstruction and the generative diffusion process, leaving the inherently different nature of each objective not fully exploited. It is noteworthy that such mask modeling can be regarded as intra-image reconstruction within the same data distribution (e.g., from $p_{\sigma}(\boldsymbol{x})$ to $p_{\sigma}(\boldsymbol{x})$ for noised data in MaskDiT). Instead, the generative diffusion process aims to model the translation between the real data distribution $p_{\text{data}}(\boldsymbol{x})$ and a different noised data distribution $p_{\sigma}(\boldsymbol{x})$. This issue is also observed in MaskDiT, where the mask reconstruction objective gradually overwhelms the generative objective at the late training stage. Accordingly, the joint training of the two distinct objectives with fuzzy relations results in sub-optimal training of DiT when applied to the generative task.
To address these limitations, our work paves a new way to frame mask modeling of DiT training on the basis of discrimination knowledge distilling in a self-supervised fashion. We propose a novel Diffusion Transformer model with Self-supervised Discrimination, namely SD-DiT, that pursues highly efficient learning of DiT with higher generative capacity. Technically, SD-DiT shapes the discrimination knowledge distilling in a teacher-student scheme. As shown in Fig. 1(a), the input discriminative pairs of the teacher and student DiT encoders are derived from different diffusion noises (i.e., a teacher sample and a student sample at different noise levels along the same Probability Flow Ordinary Differential Equation (PF-ODE) of EDM). More importantly, different from the typical mask strategy that triggers the mask reconstruction objective over both DiT encoder and decoder, SD-DiT decouples the DiT encoder and decoder to separately perform discrimination knowledge distilling and the generative diffusion process. Our launching point is to fully exploit the mutual but also fuzzy relations between self-supervised discrimination distillation and the generative diffusion process through such a decoupled DiT design. Specifically, we devise a new discriminative loss to enforce the inter-image alignment of encoded visible tokens between teacher and student DiT encoders in the joint embedding space. Next, SD-DiT only feeds student samples into the student DiT decoder to perform the conventional generative diffusion objective. Note that our discriminative loss can be interpreted as inter-image translation between the teacher sample (approximately the real data distribution $p_{\text{data}}(\boldsymbol{x})$) and the student sample (the noised data distribution $p_{\sigma}(\boldsymbol{x})$), which aligns better with the generative diffusion objective than the conventional intra-image mask reconstruction objective. As such, the joint optimization of the discriminative and generative diffusion objectives strengthens DiT training both effectively and efficiently.
In the meantime, the student branch (student DiT encoder plus decoder) in our decoupled DiT design completely retains the same regular noise in EDM and modules as in the generative modeling at inference. The additional teacher DiT encoder is simply updated as the Exponential Moving Average (EMA) of student DiT encoder in a light-weight fashion, without incurring a heavy computational burden for self-supervised discrimination. In this way, our SD-DiT not only preserves the training efficiency of mask modeling, but also elegantly circumvents the training-inference discrepancy issue.
The main contribution of this work is the proposal of a Diffusion Transformer structure that fully unleashes the power of self-supervised discrimination to facilitate DiT training. This also leads to an elegant view of how a training-efficient DiT architecture should be designed to fully exploit the mutual but also fuzzy relations between mask modeling and the generative diffusion process, and how to bridge the training-inference discrepancy for the generative task. Through extensive experiments on ImageNet 256×256, we demonstrate that our SD-DiT consistently achieves a better training speed-performance trade-off than state-of-the-art DiT models.
Related Work
Diffusion Models. Denoising diffusion probabilistic models (DDPMs) have greatly accelerated the development of generative models, especially for text-conditioned image synthesis, image editing, and personalized image generation. As score-based models, DDPMs introduce a forward process that gradually adds Gaussian noise to the data according to a Stochastic Differential Equation, and iterative denoising procedures are employed to generate high-quality samples. To tackle the time-consuming iterative nature of DDPMs, fast sampling strategies and training diffusion in the latent space have been proposed. Besides, several innovations in the network architecture of diffusion models have been introduced to handle various challenging generation tasks. The convolutional UNet is the de facto configuration of recent diffusion models, and ADM further boosts the generation quality of UNet with a scalable model size, adaptive group normalization, attention blocks, and the residual blocks from BigGAN.
Diffusion Transformers. Transformers provide a new paradigm that connects various domains across language, vision, and multi-modality, with remarkable scaling properties in terms of model size and pre-training efficiency. Recently, several Transformer-based diffusion models have been proposed to exploit the advantages of the Transformer architecture. For example, GenViT first shows that the Vision Transformer (ViT) has potential for image generation. Based on ViT with long skip connections, U-ViT is specifically designed for diffusion models and is characterized by treating the time, the specific condition, and the noisy image patches as tokens. DiT systematically studies the scaling behavior of Transformers under the Latent Diffusion Models (LDMs) framework, and achieves better generation quality than its U-Net counterparts with a scaled-up, high-capacity backbone. In this work, we take the conventional DiT blocks as the backbone network, and the generative diffusion task is implemented in the LDM framework.
Self-supervised Learning with Diffusion Models. With the dominant status of Transformers in vision and language, the mask strategy from self-supervised learning has greatly propelled the development of generative models. Following the paradigm of bidirectional generative modeling, MaskGiT and MUSE aim at predicting randomly masked visual tokens that are first tokenized from images by a discrete VQ-VAE. Iterative decoding is further utilized to rapidly generate an image. Moreover, MAGE employs such masked token modeling to unify representation pre-training and image generation. On the other hand, diffusion models built upon Transformers can be well integrated with masked image modeling. For example, inspired by MAE, MDT and MaskDiT take advantage of the asymmetric encoder-decoder of MAE and add a loss for reconstructing masked tokens (without discrete tokenizers) to the original generative diffusion loss. Such a combination with mask modeling remarkably improves the training efficiency and the contextual reasoning ability of the Diffusion Transformer (i.e., DiT). It is noteworthy that mask modeling is built upon intra-view reconstruction, while typical self-supervised methods with discriminative joint-embedding pretraining focus on inter-view alignment (invariance). Different from the existing mask strategy with intra-image contextual learning, our SD-DiT paves a new way to endow mask modeling in DiT with self-supervised discrimination ability via inter-image alignment.
Approach
In this paper, we devise a Diffusion Transformer with Self-supervised Discrimination (SD-DiT) to frame mask modeling in efficient DiT training as self-supervised discrimination knowledge distilling. This section starts with a brief review of the preliminaries of diffusion models. Then, the overall decoupled architecture for discriminative and generative objectives is elaborated. After that, two different kinds of objectives for generative diffusion process and mask modeling, i.e., generative loss and discriminative loss, are introduced. Finally, the overall objective of SD-DiT at the training stage is provided.
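Diffusion models introduce a forward process that progressively adds Gaussian noise to the data distribution $p_{\text{data}}(\boldsymbol{x})$ via a Stochastic Differential Equation (SDE) over time:
$$\mathrm{d}\boldsymbol{x}_t = \boldsymbol{\mu}(\boldsymbol{x}_t, t)\,\mathrm{d}t + \sigma(t)\,\mathrm{d}\boldsymbol{w}_t, \tag{1}$$
where $\boldsymbol{\mu}(\cdot,\cdot)$ and $\sigma(\cdot)$ are the drift and diffusion coefficients, and $\boldsymbol{w}_t$ is the standard Brownian motion. With the time flowing from 0 to $T$, we denote the marginal distribution of $\boldsymbol{x}_t$ as $p_t(\boldsymbol{x})$. Based on such an SDE, Song et al. define the probability flow ordinary differential equation (PF-ODE) in the reverse-time sample generation process:
$$\mathrm{d}\boldsymbol{x}_t = \Big[\boldsymbol{\mu}(\boldsymbol{x}_t, t) - \tfrac{1}{2}\sigma(t)^{2}\,\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x}_t)\Big]\mathrm{d}t. \tag{2}$$
Recent EDM proposes to add Gaussian noise with mean zero and standard deviation $\sigma$ into the data distribution. Specifically, EDM uses the noise level $\sigma$ in place of the time variable $t$ and configures $\boldsymbol{\mu}(\boldsymbol{x}_t, t)=\mathbf{0}$ and $\sigma(t)=\sqrt{2t}$ in Eq. 2. In this case, the resulting perturbed distribution is given by $p_{\sigma}(\boldsymbol{x})=p_{\text{data}}(\boldsymbol{x})\ast\mathcal{N}(\mathbf{0},\sigma^{2}\mathbf{I})$, where $\ast$ denotes the convolution operation. In other words, the real data $\boldsymbol{x}$ can be directly diffused as:
$$\boldsymbol{x}_{\sigma} = \boldsymbol{x} + \boldsymbol{n}, \quad \boldsymbol{n}\sim\mathcal{N}(\mathbf{0},\sigma^{2}\mathbf{I}). \tag{3}$$
And the corresponding PF-ODE in EDM is presented as:
$$\mathrm{d}\boldsymbol{x} = -\sigma\,\nabla_{\boldsymbol{x}}\log p_{\sigma}(\boldsymbol{x})\,\mathrm{d}\sigma, \quad \sigma\in[\sigma_{\min},\sigma_{\max}], \tag{4}$$
where $\nabla_{\boldsymbol{x}}\log p_{\sigma}(\boldsymbol{x})$ is the score function. As such, diffusion models are basically regarded as score-based generative models. To avoid numerical instability in ODE solving, $\sigma_{\min}$ is a small positive value and thus $p_{\sigma_{\min}}(\boldsymbol{x})\approx p_{\text{data}}(\boldsymbol{x})$, while $\sigma_{\max}$ is large enough so that $p_{\sigma_{\max}}(\boldsymbol{x})$ is close to a tractable Gaussian distribution. The training objective of EDM is to minimize the expected denoising loss for samples drawn from $p_{\text{data}}$ separately for each $\sigma$, by parameterizing a denoiser network as $D_{\theta}(\boldsymbol{x};\sigma)$:
$$\mathbb{E}_{\boldsymbol{x}\sim p_{\text{data}}}\,\mathbb{E}_{\boldsymbol{n}\sim\mathcal{N}(\mathbf{0},\sigma^{2}\mathbf{I})}\,\big\|D_{\theta}(\boldsymbol{x}+\boldsymbol{n};\sigma)-\boldsymbol{x}\big\|_{2}^{2}. \tag{5}$$
The estimated score function is thus measured as:
$$\nabla_{\boldsymbol{x}}\log p_{\sigma}(\boldsymbol{x}) = \big(D_{\theta}(\boldsymbol{x};\sigma)-\boldsymbol{x}\big)/\sigma^{2}. \tag{6}$$
Based on the formulation of EDM, Consistency Models propose to learn a consistency function whose outputs for arbitrary pairs $(\boldsymbol{x}_{\sigma},\sigma)$ on the PF-ODE trajectory (Eq. 4) are consistent with $\boldsymbol{x}_{\sigma_{\min}}$. Formally, the consistency function is defined as:
$$f:(\boldsymbol{x}_{\sigma},\sigma)\mapsto\boldsymbol{x}_{\sigma_{\min}}, \tag{7}$$
and reflects an important property of self-consistency:
$$f(\boldsymbol{x}_{\sigma},\sigma)=f(\boldsymbol{x}_{\sigma'},\sigma'), \quad \forall\,\sigma,\sigma'\in[\sigma_{\min},\sigma_{\max}]. \tag{8}$$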
The diffusion noising schedule in our SD-DiT follows the basic formulation of EDM. And the discrimination objective in our SD-DiT is framed on the basis of the theory of the consistency function (Eq. 7).
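To make the EDM preliminaries concrete, the following is a minimal sketch of the noising step, a single-sample estimate of the denoising loss, and the score implied by a denoiser. It is an illustration under stated assumptions rather than the implementation used in this work: data are plain Python lists, and `toy_denoiser` is a hypothetical closed-form denoiser that is only optimal for zero-mean Gaussian data with standard deviation `sigma_data`.

```python
import random


def add_noise(x, sigma, rng):
    """Diffuse clean data x to noise level sigma: x + n with n ~ N(0, sigma^2 I)."""
    return [xi + sigma * rng.gauss(0.0, 1.0) for xi in x]


def denoising_loss(denoiser, x, sigma, rng):
    """Single-sample estimate of the denoising loss ||D(x + n; sigma) - x||^2."""
    x_noisy = add_noise(x, sigma, rng)
    x_hat = denoiser(x_noisy, sigma)
    return sum((a - b) ** 2 for a, b in zip(x_hat, x))


def score_from_denoiser(denoiser, x_noisy, sigma):
    """Score estimate (D(x; sigma) - x) / sigma^2 implied by a denoiser."""
    x_hat = denoiser(x_noisy, sigma)
    return [(a - b) / sigma ** 2 for a, b in zip(x_hat, x_noisy)]


def toy_denoiser(x_noisy, sigma, sigma_data=0.5):
    """Hypothetical denoiser: linear shrinkage, optimal only for N(0, sigma_data^2 I) data."""
    weight = sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2)
    return [weight * v for v in x_noisy]


if __name__ == "__main__":
    rng = random.Random(0)
    x = [0.3, -0.1, 0.7]
    print(denoising_loss(toy_denoiser, x, sigma=1.0, rng=rng))
    print(score_from_denoiser(toy_denoiser, [1.0], sigma=0.5))
```

For Gaussian toy data the score has the closed form $-\boldsymbol{x}/(\sigma_{\text{data}}^{2}+\sigma^{2})$, which the shrinkage denoiser reproduces; a learned network would replace `toy_denoiser` in practice.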
2 Overall Architecture
The motivation of our SD-DiT is to exploit self-supervised discrimination to facilitate the efficient training of Diffusion Transformer. Fig. 2 illustrates the overall architecture of SD-DiT, which triggers mask modeling as discrimination knowledge distilling in a teacher-student scheme.
Decoupled Encoder-Decoder Structure. Technically, our SD-DiT consists of teacher/student DiT encoders and one DiT decoder, and the core generative objective is framed in the latent space as in LDM. The additional discriminative objective is shaped as inter-image alignment between the teacher and student DiT encoders in a self-supervised joint-embedding space. Considering the fuzzy relations between mask modeling and the generative diffusion process, we leverage a decoupled encoder-decoder structure to perform the joint training of the generative and discriminative objectives, rather than optimizing the whole encoder-decoder with a mask reconstruction objective as in existing methods. Specifically, SD-DiT feeds the discriminative pairs into the teacher and student DiT encoders to conduct discrimination knowledge distilling. After that, only student samples are fed into the student DiT decoder to perform the generative diffusion process. In this decoupled design, the discriminative objective only updates the DiT encoder, empowering it with inter-image discriminative capacity. Meanwhile, the DiT decoder is solely optimized with the generative objective and retains the same regular noise, so that it closely mimics the generative diffusion process at inference.
Discriminative Pairs. To trigger the discriminative objective, we construct the input discriminative pairs based on the EDM formulation (Eq. 3). Since the student branch (including the student DiT encoder and decoder) performs both the generative and discriminative objectives, the student view should be diffused regularly within a large noise range, similar to MaskDiT, i.e., with a noise level $\sigma_s$ drawn from the regular EDM training noise distribution. For the teacher view, we take inspiration from the InfoMin principle in self-supervised learning, and choose the fixed minimum noise of the consistency function to construct input samples: $\sigma_t=\sigma_{\min}$. As such, the noised distribution of the teacher view is the closest one to the original data distribution ($p_{\sigma_{\min}}(\boldsymbol{x})\approx p_{\text{data}}(\boldsymbol{x})$) and far away from the noised student view. Note that we empirically evaluate various teacher noise levels across $[\sigma_{\min},\sigma_{\max}]$ in Sec. 4.4, and reach observations similar to the InfoMin principle: a noised teacher view that is too close to the noised student view could be harmful to self-supervised discriminative learning. Accordingly, we use the fixed minimum noise for the teacher view in practice.
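The construction of a discriminative pair can be sketched as follows. This is a minimal illustration, not the authors' code: the log-normal parameters `p_mean` and `p_std` are EDM's default values and are assumed here, and sharing one noise direction between the two views (so that both lie on the same trajectory) is also an assumption made for the example.

```python
import math
import random

SIGMA_MIN = 0.002  # fixed minimum noise level used for the teacher view


def sample_student_sigma(rng, p_mean=-1.2, p_std=1.2):
    """Draw sigma_s with ln(sigma_s) ~ N(p_mean, p_std^2); defaults are assumed EDM values."""
    return math.exp(rng.gauss(p_mean, p_std))


def make_discriminative_pair(x0, rng, sigma_teacher=SIGMA_MIN):
    """Return (student_view, teacher_view, sigma_s) for one clean sample x0."""
    sigma_s = sample_student_sigma(rng)
    # Assumption: both views reuse the same unit Gaussian noise direction.
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    student = [x + sigma_s * n for x, n in zip(x0, noise)]
    teacher = [x + sigma_teacher * n for x, n in zip(x0, noise)]
    return student, teacher, sigma_s


if __name__ == "__main__":
    rng = random.Random(0)
    student, teacher, sigma_s = make_discriminative_pair([0.5, -0.5, 0.0, 1.0], rng)
    print(sigma_s, student, teacher)
```

Because the teacher noise is tiny, the teacher view stays almost identical to the clean sample, while the student view spans the full range of training noise levels.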
3 Generative Objective
Inspired by the training efficiency and positional contextual awareness brought by the mask strategy, we follow typical mask modeling techniques to frame the generative objective via an asymmetric encoder-decoder structure along the student branch.
Mask Strategy. The image is divided into $N$ non-overlapping patches through the patch embedding layer of DiT. Let $\boldsymbol{M}\in\{0,1\}^{N}$ denote the binary random mask with the same size as the $N$ non-overlapping patches, where an entry of 1 marks a masked patch. It is worth noting that MAE and MaskDiT additionally leverage the mask to learn extra mask tokens in a mask reconstruction auxiliary task. Instead, our SD-DiT solely utilizes the mask to separate the noised student view $\boldsymbol{x}_{\sigma_s}$ into visible patches $\boldsymbol{x}_{\sigma_s}\odot(1-\boldsymbol{M})$ and invisible patches $\boldsymbol{x}_{\sigma_s}\odot\boldsymbol{M}$, where $\odot$ indicates element-wise multiplication on patches.
Student Branch. Given the visible and invisible patches produced by the mask strategy, the student branch applies the typical asymmetric encoder-decoder architecture to improve the training efficiency. The student DiT encoder can be built with various DiT-Small/Base/XL backbones, while the lightweight student DiT decoder consists of a fixed number of blocks (i.e., 8 DiT blocks, similar to the configuration of MAE). The student DiT encoder operates only over the visible patches and produces the visible tokens. The student decoder is then fed with the complete token set. Such an asymmetric paradigm with a high mask ratio, as proposed by MAE, greatly reduces the training cost because the main computational burden lies in the large-scale encoder.
Generative Loss. Recall that in existing mask modeling techniques (e.g., MAE and MaskDiT), the input token set of the decoder commonly augments the visible tokens with learnable mask tokens, placed according to the positions given by the mask. A mask reconstruction auxiliary task is included to reconstruct the invisible patches at the positions of the learnable mask tokens. It is noteworthy that such a mask reconstruction objective can benefit representation learning, but leaves the inherently different nature of the mask modeling and generative objectives under-exploited. MaskDiT also points out the fuzzy relations between these two objectives, where the mask reconstruction loss gradually overwhelms the generative objective at the late training stage.
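To alleviate this limitation, we discard the mask reconstruction loss and optimize the DiT decoder with only the generative loss. Formally, for the complete token set, we remove the learnable mask tokens and directly insert the invisible patches alongside the visible tokens, according to the positions given by the mask. Next, the generative loss operates over the complete tokens, and is measured in the form of EDM (Eq. 5):
$$\mathcal{L}_{G} = \mathbb{E}_{\boldsymbol{x}\sim p_{\text{data}}}\,\mathbb{E}_{\boldsymbol{n}\sim\mathcal{N}(\mathbf{0},\sigma_s^{2}\mathbf{I})}\,\big\|D_{\theta}(\boldsymbol{x}+\boldsymbol{n};\sigma_s)-\boldsymbol{x}\big\|_{2}^{2}, \tag{9}$$
where $D_{\theta}$ denotes the student branch, including the student DiT encoder and the DiT decoder.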
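The mask-and-reassemble step described above can be sketched as follows. This is a minimal illustration under stated assumptions: patches are plain lists, a mask entry of 1 means "masked", and the helper names (`random_mask`, `split_patches`, `assemble_decoder_input`) are hypothetical. In a real model the invisible patches would need the same embedding width as the encoder output tokens.

```python
import random


def random_mask(num_patches, mask_ratio, rng):
    """Return a 0/1 list with round(mask_ratio * num_patches) entries set to 1 (masked)."""
    num_masked = round(mask_ratio * num_patches)
    masked = set(rng.sample(range(num_patches), num_masked))
    return [1 if i in masked else 0 for i in range(num_patches)]


def split_patches(patches, mask):
    """Split patches into (index, patch) lists of visible and invisible patches."""
    visible = [(i, p) for i, p in enumerate(patches) if mask[i] == 0]
    invisible = [(i, p) for i, p in enumerate(patches) if mask[i] == 1]
    return visible, invisible


def assemble_decoder_input(visible_tokens, invisible_patches, num_patches):
    """Put encoded visible tokens and raw invisible patches back at their original positions."""
    full = [None] * num_patches
    for i, token in visible_tokens:
        full[i] = token
    for i, patch in invisible_patches:
        full[i] = patch
    return full


if __name__ == "__main__":
    rng = random.Random(0)
    patches = [[float(i)] for i in range(16)]
    mask = random_mask(len(patches), 0.2, rng)
    visible, invisible = split_patches(patches, mask)
    # Stand-in encoder: scales each visible patch; a DiT encoder would go here.
    visible_tokens = [(i, [2.0 * v for v in p]) for i, p in visible]
    print(assemble_decoder_input(visible_tokens, invisible, len(patches)))
```

No learnable mask token appears anywhere in this path, which is the point of the design: the decoder always sees real (noised) content at every position, as it does at inference.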
4 Discriminative Objective
Unlike typical mask modeling with a mask reconstruction loss, our SD-DiT paves a new way to frame mask modeling of DiT training on the basis of discrimination knowledge distilling in a self-supervised manner. Inspired by the self-distillation loss in ViT-based self-supervised methods (i.e., DINO and iBOT), we design a discriminative loss to enforce the inter-image alignment of encoded visible tokens between the teacher and student DiT encoders.
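Specifically, the teacher sample is fed into the teacher DiT encoder, yielding the teacher output tokens. Next, SD-DiT applies the discriminative loss over the visible tokens of teacher and student in the joint embedding space. A three-layer projection head operates on the student and teacher tokens and outputs a softmax probability distribution over $K$ dimensions. Denoting the projected outputs of the $i$-th student and teacher token as $\boldsymbol{z}_{s}^{i}$ and $\boldsymbol{z}_{t}^{i}$, and the corresponding distributions as $P_{s}^{i}$ and $P_{t}^{i}$ ($i$ indicates the index of visible tokens), the softmax probability distribution of the student is measured as:
$$P_{s}^{i}(k) = \frac{\exp\big(z_{s,k}^{i}/\tau_{s}\big)}{\sum_{k'=1}^{K}\exp\big(z_{s,k'}^{i}/\tau_{s}\big)}, \tag{10}$$
where the student temperature $\tau_{s}$ controls the sharpness of the softmax distribution. A similar formulation also holds for the teacher distribution $P_{t}^{i}$ with teacher temperature $\tau_{t}$. For each visible token $i$, the discrimination loss aligns the distributions of teacher and student by minimizing the cross-entropy loss:
$$\mathcal{L}_{D}^{i} = -\sum_{k=1}^{K} P_{t}^{i}(k)\,\log P_{s}^{i}(k). \tag{11}$$
The final discrimination loss is calculated over all visible patch tokens and the [CLS] token:
$$\mathcal{L}_{D} = \mathcal{L}_{D}^{\text{[CLS]}} + \sum_{i}\mathcal{L}_{D}^{i}. \tag{12}$$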
Besides, we adopt the centering technique of DINO to avoid feature collapse, where the batch mean statistic is used to center the features before the softmax during each training iteration. For simplicity, we leave the details of centering and the complete pseudo-code to the supplementary materials.
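A DINO-style version of the discriminative loss with centering can be sketched as follows. This is an illustrative sketch rather than the authors' pseudo-code: the temperatures `tau_s=0.1` and `tau_t=0.04`, the center momentum of 0.9, applying the center to the teacher logits only, and averaging (rather than summing) over tokens are all assumptions borrowed from common DINO practice.

```python
import math


def softmax(logits, temperature):
    """Numerically stable softmax of logits / temperature."""
    scaled = [v / temperature for v in logits]
    peak = max(scaled)
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def token_loss(student_logits, teacher_logits, center, tau_s=0.1, tau_t=0.04):
    """Cross-entropy between the centered, sharpened teacher and the student for one token."""
    p_s = softmax(student_logits, tau_s)
    p_t = softmax([t - c for t, c in zip(teacher_logits, center)], tau_t)
    return -sum(pt * math.log(ps + 1e-12) for pt, ps in zip(p_t, p_s))


def discriminative_loss(student_tokens, teacher_tokens, center):
    """Average the per-token loss over the [CLS] token and all visible patch tokens."""
    losses = [token_loss(s, t, center) for s, t in zip(student_tokens, teacher_tokens)]
    return sum(losses) / len(losses)


def update_center(center, teacher_tokens, momentum=0.9):
    """Exponential moving average of the batch-mean teacher logits."""
    count = len(teacher_tokens)
    mean = [sum(tok[j] for tok in teacher_tokens) / count for j in range(len(center))]
    return [momentum * c + (1.0 - momentum) * m for c, m in zip(center, mean)]


if __name__ == "__main__":
    center = [0.0, 0.0, 0.0]
    teacher = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    student = [[1.5, 0.1, 0.0], [0.0, 0.8, 0.2]]
    print(discriminative_loss(student, teacher, center))
    print(update_center(center, teacher))
```

In a training loop, gradients would flow only through the student logits; the teacher distribution is treated as a fixed target.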
In summary, the overall training loss combines the discrimination loss and the generative loss. The parameters of the student branch (student DiT encoder and decoder) are optimized by this overall loss, while the teacher DiT encoder (parameterized as $\theta_{t}$) is updated as the exponential moving average (EMA) of the student DiT encoder (parameterized as $\theta_{s}$): $\theta_{t}\leftarrow\lambda\theta_{t}+(1-\lambda)\theta_{s}$. Here $\lambda$ is a momentum coefficient. During training, the teacher is updated by EMA without SGD back-propagation, thereby incurring only an extremely lightweight computational cost. At inference, the teacher is completely removed and no extra burden is introduced.
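The EMA update of the teacher encoder amounts to a per-parameter interpolation. The sketch below assumes, purely for illustration, that parameters are stored as a dictionary mapping names to lists of floats; `ema_update` is a hypothetical helper name.

```python
def ema_update(teacher, student, momentum):
    """Update teacher in place: teacher <- momentum * teacher + (1 - momentum) * student."""
    for name, student_values in student.items():
        teacher[name] = [
            momentum * t + (1.0 - momentum) * s
            for t, s in zip(teacher[name], student_values)
        ]
    return teacher


if __name__ == "__main__":
    teacher = {"encoder.weight": [1.0, 1.0]}
    student = {"encoder.weight": [0.0, 2.0]}
    print(ema_update(teacher, student, momentum=0.996))
```

No gradient is computed for the teacher, so the extra training cost is one forward pass through the teacher encoder plus this interpolation.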
Experiments
In this section, we provide the settings of model architecture, training setup, and evaluation details. We list the detailed configurations in supplementary material.
Model Architecture. The basic Transformer blocks in our backbone network fully adopt the DiT block, which fuses the conditional time and class embeddings with adaptive layer normalization. We follow the paradigm of LDM and DiT to perform diffusion generation in the latent space of the frozen pre-trained VAE model, which downsamples a 256×256×3 image into a 32×32×4 latent variable. Inspired by prior mask modeling works, we adopt the asymmetric encoder-decoder for the generative diffusion process. The student DiT encoder employs DiT-Small/Base/XL-2 (patch size: 2) and the small-scale DiT decoder contains 8 DiT blocks, similar to the configuration of MAE. For the discriminative objective, we mainly follow the settings of iBOT and DINO. The teacher DiT encoder is the EMA of the student encoder, and the momentum coefficient is increased from 0.996 to 0.999 by the end of training. The three-layer projection head maps the [CLS] and patch tokens to $K$ dimensions for the softmax probability distribution in the discriminative loss (Eq. 11).
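The momentum schedule can be sketched as below. The text only states the two endpoints (0.996 and 0.999); the cosine shape used here is an assumption borrowed from common DINO-style practice, and `momentum_at` is a hypothetical helper name.

```python
import math


def momentum_at(step, total_steps, base=0.996, final=0.999):
    """Cosine ramp from `base` at step 0 to `final` at `total_steps` (shape is an assumption)."""
    progress = min(1.0, step / total_steps)
    return final - (final - base) * 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for step in (0, 100_000, 200_000, 400_000):
        print(step, momentum_at(step, total_steps=400_000))
```

A larger momentum late in training makes the teacher change more slowly, which is the usual motivation for ramping it up.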
Training Setup. Following previous Transformer-based diffusion models, we conduct all the experiments on ImageNet-1K at 256×256 resolution with a batch size of 256. We adopt the most common settings of DiT, e.g., the AdamW optimizer with a constant learning rate and no weight decay. Unless otherwise specified, the mask ratio is set to 0.2 on the student view, and no mask is applied on the teacher view. No data augmentation is employed for either student or teacher inputs, since our model learns the discrimination among various noised views. Notice that mixed precision might lead to a NaN loss during training, so we only apply mixed precision for the evaluation on the small-scale backbone (DiT-S) and switch to full precision for the large-scale backbones (DiT-B and DiT-XL). All experiments are conducted on 8 A100 GPUs (80GB).
Evaluations. To evaluate both the diversity and quality of our generative model, we use the most commonly adopted Fréchet Inception Distance (FID) as the evaluation metric. For a fair comparison with previous works, we report FID-50K from ADM’s TensorFlow evaluation suite with the reference batch. We report the FID scores of class-conditional sampling. Besides, we provide more supporting metrics, including Inception Score (IS), sFID, and Precision/Recall.
2 Training Speed vs. Performance
Here we evaluate our SD-DiT with regard to both training speed and generative performance. Fig. 3 shows the training speed (i.e., training steps per second) and FID-50K score of SD-DiT in comparison to state-of-the-art DiT models (DiT, MDT, and MaskDiT) on 8 A100 GPUs. For a fair comparison, each run uses the same backbone scale (DiT-S/2), the same batch size (256), and the same number of training iterations (400k). For SD-DiT and MaskDiT, we follow MDT and implement them with the same Float32 precision. For a comprehensive analysis, we also label each run with the number of input patches. As shown in Fig. 3, our SD-DiT (FID: 48.39; speed: 9.2 steps/sec or 0.11 sec/step) obtains better generative performance with faster training speed than MDT (FID: 53.46; speed: 2.4 steps/sec or 0.42 sec/step) and DiT (FID: 68.40; speed: 5.03 steps/sec or 0.20 sec/step). This is because MDT simultaneously forwards and backwards both the complete (100%) and visible patches (70%), and DiT operates over the complete (100%) patches, thereby resulting in slower training speed. In contrast, our SD-DiT and MaskDiT only forward and backward partial patches (80%/50%), leading to faster training speed. Furthermore, unlike MaskDiT, which optimizes the whole encoder-decoder with a mask reconstruction objective, our SD-DiT adopts a decoupled encoder-decoder structure to better exploit the mutual but also fuzzy relations between the generative and discriminative objectives, leading to the best FID-50K score. The results demonstrate the effectiveness of our SD-DiT, which achieves a competitive training speed-performance trade-off.
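As a quick check of the throughput figures quoted above, the short sketch below converts the reported steps per second into seconds per step and into relative speedups of SD-DiT. It uses only the numbers stated in the text; the variable names are illustrative.

```python
# Reported training throughput (steps per second) for DiT-S/2 runs on 8 A100 GPUs.
steps_per_sec = {"SD-DiT": 9.2, "MDT": 2.4, "DiT": 5.03}

# Seconds per step is the reciprocal of steps per second.
sec_per_step = {name: 1.0 / rate for name, rate in steps_per_sec.items()}

# Relative speedup of SD-DiT over each baseline.
speedup_of_sd_dit = {
    name: steps_per_sec["SD-DiT"] / rate
    for name, rate in steps_per_sec.items()
    if name != "SD-DiT"
}

if __name__ == "__main__":
    for name, value in sec_per_step.items():
        print(f"{name}: {value:.2f} sec/step")
    for name, value in speedup_of_sd_dit.items():
        print(f"SD-DiT vs {name}: {value:.2f}x steps/sec")
```

Under these figures, SD-DiT runs roughly 3.8 times as many steps per second as MDT and roughly 1.8 times as many as DiT.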
3 Performance Comparison
Comparison among Backbones in Different Scales. Tab. 1 provides comprehensive comparisons between our SD-DiT and several state-of-the-art DiT-based models under three different model sizes (DiT-S/B/XL). Notice that MaskDiT only conducts experiments on the DiT-XL backbone, so we do not report its results on the DiT-S and DiT-B backbones. The batch size of all models is set to 256 for a fair comparison. Specifically, under the same small-scale backbone (DiT-S), our SD-DiT-S (FID: 48.39) outperforms DiT-S (68.40) and MDT-S (53.46) by a large margin. This significant improvement in FID is consistently observed when moving to the larger-scale backbones (DiT-B, DiT-XL). The results clearly validate the advantage of self-supervised discrimination knowledge distilling for mask modeling in Diffusion Transformer.
Comparison on Convergence Speed in Large Scale Backbone. Here we evaluate the convergence speed of our SD-DiT-XL/2 based on the large-scale backbone. Fig. 4 compares the convergence speed by showing the FID scores at different training steps for our SD-DiT and various baselines. The batch size of each run is set to 256 for fair comparisons, and the maximum training step is 2400k. Note that the results of DiT and MaskDiT at different steps are directly copied from the results reported in MaskDiT. As shown in Fig. 4, SD-DiT consistently shows better training convergence than DiT and MaskDiT throughout training. The detailed performance comparisons against MDT are listed in Tab. 1, where our SD-DiT (FID: 9.01) achieves better results than MDT (FID: 9.60) and MaskDiT (FID: 12.15) with 1300k training steps. It is worth noting that SD-DiT trained with 1300k steps outperforms the typical DiT with 7000k steps (FID: 9.01 vs. 9.62), achieving about 5× faster training progress (7000k/1300k ≈ 5.4). In addition, SD-DiT (1100k steps) achieves an FID comparable to MDT (1300k steps) (9.66 vs. 9.60). Such fast convergence again confirms the power of self-supervised discrimination for facilitating DiT training.
Comparison with State-of-the-Art Generative Methods. Tab. 2 summarizes the performance comparison against state-of-the-art generative methods. We strictly follow MDT and list the cost comparison column as “Iter×Batchsize”. We follow most DiT-based approaches and report the results on the DiT-XL backbone with more training iterations (2400k). Generally, under the same batch size of 256, our SD-DiT-XL/2 achieves a better FID score than DiT-XL/2 and MDT-XL/2. Although MaskDiT-XL/2 obtains the best FID score among all DiT-based methods, it benefits from an extremely large batch size of 1024. A fairer comparison between our SD-DiT and MaskDiT is given in Fig. 4, where each run is trained with the same batch size (256). In that figure, SD-DiT-XL/2 shows a consistent performance boost over MaskDiT-XL/2, which clearly validates our proposal.
4 Ablation Study
We conduct an ablation study to examine each component in SD-DiT. Considering that DiT training is computationally expensive, we adopt a lightweight setting for efficient evaluation: a small-scale backbone (DiT-S) with 400k training steps, a batch size of 256, and a fixed mask ratio unless otherwise specified.
Effect of Discriminative Objective. Tab. 3 details the performance of the ablated runs of our SD-DiT. Specifically, the first row shows the FID score (53.7) of our complete SD-DiT-S/2 under the ablation mask ratio. Next, by removing the discriminative objective and the corresponding teacher branch from SD-DiT (2nd row), the generative performance drops by a large margin. This demonstrates the merit of our self-supervised discrimination tailored to Diffusion Transformer. In addition, when removing the mask strategy of SD-DiT (3rd row), a clear performance drop is observed, which highlights the effectiveness of the mask strategy in triggering the learning of intra-image contextual awareness.
Effect of Mask Ratio. To further seek the sweet spot in the balance between the generative and discriminative tasks, we vary the mask ratio from 0 to 1 and show the corresponding FID scores in Fig. 5. As shown in Fig. 5, the best performance of our SD-DiT is attained when the mask ratio is 20%, and thus we adopt this ratio in all experiments of Sec. 4.3. We additionally show the performance of MDT-S/2 trained with 600k steps under its optimal 30% mask ratio (the fixed green dashed line in Fig. 5, FID: 50.3), which is inferior to our SD-DiT-S/2 with 400k steps (20% mask ratio, FID: 48.4). Moreover, MaskDiT points out one interesting observation with regard to the mask ratio: MaskDiT with a 75% mask ratio yields an extremely degraded FID score (121.16 for MaskDiT-XL/2). In other words, when 75% of the patches participate in the mask reconstruction task and only 25% of the local patches focus on the generative task, the generative ability of MaskDiT is significantly weakened. This reveals the fuzzy relations between mask reconstruction and the generative task. In contrast, in our SD-DiT-S/2, even when the mask ratio is increased to 90%, the corresponding FID score (61.0) is still better (lower) than that of DiT-S/2 with 400k steps (68.40 in Tab. 1). These findings verify that our design can alleviate the negative effect of the fuzzy relations between mask modeling and the generative task.
Effect of Noise of Teacher View. Recall that in our SD-DiT, the noise level $\sigma_s$ of the student view is sampled within $[\sigma_{\min},\sigma_{\max}]$ based on the EDM formulation (Eq. 3). Following EDM and Consistency Models, we set $\sigma_{\min}=0.002$ and $\sigma_{\max}=80$. Here we further test the effect of the noise of the teacher view. Specifically, we first draw the noise of the teacher view from the same distribution as the student view. As depicted by the yellow dashed line in Fig. 6, the corresponding FID (63.4) is unsatisfying. This result shows that teacher noise derived from the same distribution as the student noise cannot make the discriminative loss useful for the generative task. Such an observation aligns with the InfoMin principle in self-supervised learning: reducing the mutual information between two views helps learn a good pre-trained model with sufficient view invariance. That is why we choose the fixed minimum noise as in Consistency Models for the teacher view, i.e., $\sigma_t=\sigma_{\min}$. In this way, the noise distribution of the teacher view is the closest one to the original data distribution ($p_{\sigma_{\min}}(\boldsymbol{x})\approx p_{\text{data}}(\boldsymbol{x})$) and far away from the student view. We empirically evaluate various teacher noise levels within $[\sigma_{\min},\sigma_{\max}]$ (see the red curve in Fig. 6), and the fixed minimum noise (scale: 0.002) gives the best performance (FID: 53.7). Furthermore, we draw the approximate log-normal probability density function (PDF) of $\sigma_s$ based on EDM (see the black dashed line in Fig. 6). When the fixed $\sigma_t$ is set within the high-density range (e.g., 0.3 and 0.5, close to the mean of $\sigma_s$), the corresponding FID of SD-DiT degrades drastically (e.g., to 66.2) and is even worse than the case of the same distribution (yellow dashed line). This again reveals that the noise scale of the teacher view should be far away from the distribution of $\sigma_s$.
Conclusions
In this work, we propose a Diffusion Transformer architecture, namely SD-DiT, to facilitate the training process by unleashing the power of self-supervised discrimination. SD-DiT frames mask modeling in a novel teacher-student manner to jointly execute the discriminative and generative diffusion objectives in a decoupled encoder-decoder structure. Such a design exploits the mutual but also fuzzy relations between mask modeling and the generative objective, leading to both effective and efficient DiT training. Experiments conducted on ImageNet validate the competitiveness of SD-DiT when compared to state-of-the-art DiT-based approaches.
Appendix A Implementation Details
In contrast to DiT and MDT, whose settings are derived from the ADM formulation, our SD-DiT employs the formulation of EDM in order to construct the discriminative pairs according to the theory of the consistency function (Eq. (7) in the main paper) based on the PF-ODE (Eq. (4) in the main paper) of EDM. Specifically, we adopt the EDM preconditioning parameterization by using a $\sigma$-dependent skip connection (please refer to EDM for more comprehensive details):
$$D_{\theta}(\boldsymbol{x};\sigma) = c_{\text{skip}}(\sigma)\,\boldsymbol{x} + c_{\text{out}}(\sigma)\,F_{\theta}\big(c_{\text{in}}(\sigma)\,\boldsymbol{x};\,c_{\text{noise}}(\sigma)\big). \tag{13}$$
This preconditioning parameterization is a common practice to avoid the large variation in gradient magnitudes brought by various noise levels. As shown in Eq. 13, the denoiser $D_{\theta}$ is not directly implemented as a neural network. Instead, a different network $F_{\theta}$ is trained, from which $D_{\theta}$ is derived. In our SD-DiT, the student branch is wrapped as $F_{\theta}$ in Eq. 13 with skip connection preconditioning. For simplicity, we did not introduce this parameterization in the main paper. We follow the default hyper-parameters of EDM for the skip connection $c_{\text{skip}}$, the noise level $c_{\text{noise}}$, and the input and output magnitudes $c_{\text{in}}$, $c_{\text{out}}$. Besides, the student noise distribution follows the training noise distribution in EDM’s setting:
$$\ln(\sigma_s)\sim\mathcal{N}\big(P_{\text{mean}},\,P_{\text{std}}^{2}\big), \tag{14}$$
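where $P_{\text{mean}}=-1.2$ and $P_{\text{std}}=1.2$. Note that we draw the approximate log-normal probability density function (i.e., the black dashed line in Fig. 6 in the main paper) of the corresponding $\sigma_s$ according to Eq. 14. During the sampling stage, we use the default time step schedule of EDM:
$$\sigma_{i} = \Big(\sigma_{\max}^{1/\rho} + \tfrac{i}{N-1}\big(\sigma_{\min}^{1/\rho}-\sigma_{\max}^{1/\rho}\big)\Big)^{\rho}, \quad i=0,\dots,N-1, \tag{15}$$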
where $N$ is the number of sampling steps, $\rho$ controls the spacing of the noise levels, and $\sigma_{\min}$ and $\sigma_{\max}$ bound the noise range. Following EDM, we utilize the second-order Heun ODE solver for sampling. We follow the paradigm of LDM to perform diffusion generation in the latent space of the frozen pre-trained VAE model, which downsamples a 256×256×3 image into a 32×32×4 latent variable. More implementation details can be found in Tab. 4.
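The preconditioning wrapper and the sampling noise schedule discussed in this appendix can be sketched as follows. This is an illustration based on the published EDM formulas, not the code of this work: the values `sigma_data=0.5`, `sigma_min=0.002`, `sigma_max=80`, and `rho=7` are EDM defaults assumed here, the number of steps is left as a free argument, and `network` is a hypothetical stand-in for the wrapped student branch.

```python
import math


def edm_precond_coeffs(sigma, sigma_data=0.5):
    """Return (c_skip, c_out, c_in, c_noise) of the EDM preconditioning at noise level sigma."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    c_noise = math.log(sigma) / 4.0
    return c_skip, c_out, c_in, c_noise


def precondition(network, x, sigma, sigma_data=0.5):
    """Denoiser D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise) for a raw network F."""
    c_skip, c_out, c_in, c_noise = edm_precond_coeffs(sigma, sigma_data)
    raw = network([c_in * v for v in x], c_noise)
    return [c_skip * v + c_out * r for v, r in zip(x, raw)]


def edm_sigma_steps(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min (requires num_steps >= 2)."""
    low = sigma_min ** (1.0 / rho)
    high = sigma_max ** (1.0 / rho)
    return [
        (high + i / (num_steps - 1) * (low - high)) ** rho
        for i in range(num_steps)
    ]


if __name__ == "__main__":
    print(edm_precond_coeffs(0.5))
    print(edm_sigma_steps(8))
```

With `rho=7` the schedule places more steps at low noise levels, where fine image detail is resolved.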
Network parameters. The teacher-student design doubles the parameters of a typical DiT during training. At inference, however, the teacher network is removed, and thus no parameter burden is introduced. In this sense, the model size of the learned SD-DiT-XL/2 is 740.6M, which is comparable to MaskDiT-XL/2 (730.1M). During training, the additional teacher network is directly updated by EMA without SGD backward propagation, thereby requiring only an extremely lightweight computational cost compared to standard backward propagation.
Appendix B Additional Experimental Results
MaskDiT-XL/2 attains the best FID score with fewer training steps, which is attributed to its large batch size of 1024. For a more comprehensive comparison, we train SD-DiT-XL/2 with a batch size of 1024, and the FID is 16.78 (150k steps), which is better than MaskDiT-XL/2 (FID: 17.22 at 150k steps).
Comparison at higher iterations.
We train SD-DiT-XL/2 with more iterations (3500k), and the FID is 6.74, which is comparable to MDT-XL/2 (FID: 6.65, 3500k). It is worth noting that, compared to MDT-XL/2, our SD-DiT-XL/2 only uses 45GB of memory per GPU with faster training speed (much lower than the memory requirement of MDT-XL/2), leading to a better computational cost-performance trade-off.
Classifier-free guidance (CFG) results.
We also upgrade our SD-DiT with CFG, and the FID of SD-DiT-XL/2 (+CFG) is 3.23, which is better than MaskDiT-XL/2 (without the unmask tuning stage) with CFG (FID: 4.54).