MMD GAN: Towards Deeper Understanding of Moment Matching Network
Chun-Liang Li, Wei-Cheng Chang, Yu Cheng, Yiming Yang, Barnabás Póczos
Introduction
In this work, we aim to improve GMMN by using MMD with adversarially learned kernels, instead of fixed Gaussian kernels, to achieve better hypothesis-testing power. The main contributions of this work are:
In Section 3, we propose a practical realization called MMD GAN that learns the generator with an adversarially trained kernel. We further propose a feasible set reduction to speed up and stabilize the training of MMD GAN.
In Section 5, we show that MMD GAN is computationally more efficient than GMMN, since it can be trained with a much smaller batch size. We also demonstrate that MMD GAN has promising results on challenging datasets, including CIFAR-10, CelebA and LSUN, where GMMN fails. To the best of our knowledge, this is the first MMD-based work to achieve results comparable to other GAN works on these datasets.
Finally, we also study the connection to existing works in Section 4. Interestingly, we show that Wasserstein GAN is a special case of the proposed MMD GAN under certain conditions. The unified view shows more connections between moment matching and GAN, which can potentially inspire new algorithms based on well-developed tools in statistics. Our experiment code is available at https://github.com/OctoberChang/MMD-GAN.
GAN, Two-Sample Test and GMMN
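GMMN: One example of a characteristic kernel is the Gaussian kernel $k(x, x') = \exp(-\|x - x'\|^2)$. Based on Theorem 1, the generative moment-matching network (GMMN) was proposed, which trains the generator $g_\theta$ by
$$\min_\theta \; M_k(\mathbb{P}_{\mathcal{X}}, \mathbb{P}_\theta)$$
with a fixed Gaussian kernel $k$, rather than training an additional discriminator as in GAN.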
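GMMN trains the generator by minimizing an empirical estimate of MMD². The following is a minimal plain-Python sketch of the biased (V-statistic) estimator with a Gaussian kernel; the explicit bandwidth parameter `sigma` and the function names are assumptions for illustration, not the GMMN reference code.

```python
import math


def gaussian_kernel(x, y, sigma=1.0):
    """Gaussian kernel exp(-||x - y||^2 / (2 * sigma^2)) on equal-length vectors."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd2_biased(xs, ys, kernel=gaussian_kernel):
    """Biased estimate of MMD^2 between samples xs ~ P_X and ys ~ P_theta:
    mean k(x, x') - 2 * mean k(x, y) + mean k(y, y'), with all pairs included."""
    def mean_kernel(a, b):
        return sum(kernel(u, v) for u in a for v in b) / (len(a) * len(b))

    return mean_kernel(xs, xs) - 2.0 * mean_kernel(xs, ys) + mean_kernel(ys, ys)
```

Identical samples give an estimate of 0, and well-separated samples give a clearly positive value; GMMN lowers this quantity with respect to the generator parameters while the kernel stays fixed.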
2.2 Properties of MMD with Kernel Learning
If $g_\theta$ is parameterized by a feed-forward neural network, it satisfies Assumption 2 and can be trained via gradient descent and backpropagation, since the objective is continuous and differentiable almost everywhere by Theorem 3. More technical discussion is given in Appendix B.
MMD GAN
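To approximate (3), we use neural networks to parameterize $g_\theta$ and $f_\phi$ with expressive power. For $g_\theta$, the assumption is that it is locally Lipschitz, which commonly used feed-forward neural networks satisfy. Also, the gradient $\nabla_\theta \big(\max_\phi f_\phi \circ g_\theta\big)$ has to be bounded, which can be done by clipping $\phi$ or by a gradient penalty. The non-trivial part is that $f_\phi$ has to be injective. For an injective function $f$, there exists a function $f^{-1}$ such that $f^{-1}(f(x)) = x, \forall x \in \mathcal{X}$ and $f^{-1}(f(g(z))) = g(z), \forall z \in \mathcal{Z}$ (note that an injective function is not necessarily invertible), which can be approximated by an autoencoder. In the following, we denote by $\phi = \{\phi_e, \phi_d\}$ the parameters of the discriminator networks, which consist of an encoder $f_{\phi_e}$, and we train the corresponding decoder $f_{\phi_d} \approx f^{-1}$ to regularize $f$. The objective (3) is relaxed to be
$$\min_\theta \max_\phi \; M_{f_{\phi_e}}(\mathbb{P}_{\mathcal{X}}, \mathbb{P}_\theta) - \lambda \, \mathbb{E}_{y \in \mathcal{X} \cup g(\mathcal{Z})} \big\| y - f_{\phi_d}(f_{\phi_e}(y)) \big\|^2. \quad (4)$$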
Note that we ignore the autoencoder objective when we train $\theta$, but we use (4) for a concise presentation. The empirical study in Section 5 suggests that the autoencoder objective is not necessary for successful GAN training, even though the injective property is required in Theorem 1.
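For fixed networks, the value of the relaxed objective (4) can be written out directly. In the sketch below, which is an illustration rather than the released implementation, hypothetical plain-Python callables `encode` and `decode` stand in for the encoder and decoder, a single Gaussian kernel is used, and MMD² is estimated with the biased estimator. During training, the discriminator parameters would ascend this value, while $g_\theta$ descends only the MMD term.

```python
import math


def gaussian_kernel(u, v, sigma=1.0):
    """Gaussian kernel exp(-||u - v||^2 / (2 * sigma^2))."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd2(xs, ys, kernel):
    """Biased MMD^2 estimate: mean k(x, x') - 2 mean k(x, y) + mean k(y, y')."""
    def mean_k(a, b):
        return sum(kernel(u, v) for u in a for v in b) / (len(a) * len(b))

    return mean_k(xs, xs) - 2.0 * mean_k(xs, ys) + mean_k(ys, ys)


def relaxed_objective(real, fake, encode, decode, lam=1.0):
    """Value of objective (4) for fixed networks: MMD^2 between encoded real and
    encoded generated samples, minus lam times the mean squared reconstruction
    error over the union of real and generated samples."""
    mmd_term = mmd2([encode(x) for x in real], [encode(y) for y in fake], gaussian_kernel)
    samples = list(real) + list(fake)
    recon = sum(
        sum((a - b) ** 2 for a, b in zip(y, decode(encode(y)))) for y in samples
    ) / len(samples)
    return mmd_term - lam * recon
```

With an identity encoder and decoder the reconstruction term vanishes and the objective reduces to data-space MMD²; a lossy decoder lowers the value by `lam` times its mean reconstruction error.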
The proposed algorithm is similar to GAN, which optimizes two neural networks $g_\theta$ and $f_\phi$ in a minmax formulation, although the meaning of the objective is different. In GAN, $f_\phi$ is a discriminator (binary) classifier that distinguishes two distributions. In the proposed algorithm, distinguishing the two distributions is still done by a two-sample test via MMD, but with an adversarially learned kernel parametrized by $f_\phi$; $g_\theta$ is then trained to pass the hypothesis test. More connections to and differences from related works are discussed in Section 4. Because of the similarity to GAN, we call the proposed algorithm MMD GAN. We present an implementation with weight clipping in Algorithm 1, but one can easily extend it to other Lipschitz approximations, such as gradient penalty.
Encoding Perspective of MMD GAN: Besides using kernel selection to explain MMD GAN, another way to view the proposed MMD GAN is to treat $f_\phi$ as a feature transformation function, with the kernel two-sample test performed in this transformed feature space (i.e., the code space of the autoencoder). The optimization finds a manifold with stronger signals for the MMD two-sample test. From this perspective, GMMN is a special case of MMD GAN in which $f_\phi$ is the identity mapping. In that case, the kernel two-sample test is conducted in the original data space.
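The encoding perspective can be illustrated on toy data. In this hypothetical sketch, `feature_mmd2` computes the biased MMD² estimate under the composed kernel $k(f(x), f(y))$: the identity map gives the data-space test used by GMMN, while a map that discards a noisy coordinate on which the two samples agree yields a larger statistic for the same samples, i.e., a stronger signal for the two-sample test.

```python
import math


def gaussian_kernel(u, v, sigma=1.0):
    """Gaussian kernel exp(-||u - v||^2 / (2 * sigma^2))."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def feature_mmd2(xs, ys, feature_map):
    """Biased MMD^2 under the composed kernel k(f(x), f(y))."""
    fx = [feature_map(x) for x in xs]
    fy = [feature_map(y) for y in ys]

    def mean_k(a, b):
        return sum(gaussian_kernel(u, v) for u in a for v in b) / (len(a) * len(b))

    return mean_k(fx, fx) - 2.0 * mean_k(fx, fy) + mean_k(fy, fy)


def identity(v):
    """GMMN case: test in the original data space."""
    return list(v)


def keep_first(v):
    """A feature map that drops the second (noisy) coordinate."""
    return [v[0]]


# The samples differ only in coordinate 0; coordinate 1 is large noise shared by both.
real = [[0.0, -5.0], [0.0, 5.0]]
fake = [[1.0, -5.0], [1.0, 5.0]]
```

Here `feature_mmd2(real, fake, keep_first)` is roughly twice `feature_mmd2(real, fake, identity)`, because the projection removes within-sample spread that the data-space kernel cannot ignore.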
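With Theorem 5, we can reduce the feasible set of $\phi$ during the optimization by solving
$$\max_\phi \; M_{f_\phi}(\mathbb{P}_{\mathcal{X}}, \mathbb{P}_\theta) \quad \text{s.t.} \quad \mathbb{E}_{x \sim \mathbb{P}_{\mathcal{X}}}[f_\phi(x)] \succeq \mathbb{E}_{z \sim \mathbb{P}_{\mathcal{Z}}}[f_\phi(g_\theta(z))],$$
whose optimal solution is still equivalent to that of (2).
However, it is hard to solve this constrained optimization problem with backpropagation. We relax the constraint via ordinal regression to obtain
$$\max_\phi \; M_{f_\phi}(\mathbb{P}_{\mathcal{X}}, \mathbb{P}_\theta) + \lambda \min\Big(\mathbb{E}[f_\phi(x)] - \mathbb{E}[f_\phi(g_\theta(z))],\, 0\Big),$$
which only penalizes the objective when the constraint is violated. In practice, we observe that reducing the feasible set makes training faster and more stable.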
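The relaxed objective can be sketched as follows, assuming the code-space means are estimated from a mini-batch of encoded real and generated samples. How the per-dimension terms $\min(\cdot, 0)$ are aggregated over the $h$ code dimensions is not specified here; averaging them is an assumption of this sketch, and the function names are hypothetical.

```python
def one_sided_penalty(codes_real, codes_fake):
    """min(E[f(x)] - E[f(g(z))], 0) taken per code dimension and averaged over the
    h dimensions (averaging is an assumption). It is zero whenever the constraint
    holds in every dimension, so it only acts when the constraint is violated."""
    h = len(codes_real[0])
    mean_real = [sum(c[j] for c in codes_real) / len(codes_real) for j in range(h)]
    mean_fake = [sum(c[j] for c in codes_fake) / len(codes_fake) for j in range(h)]
    return sum(min(r - f, 0.0) for r, f in zip(mean_real, mean_fake)) / h


def reduced_objective(mmd2_value, codes_real, codes_fake, lam=1.0):
    """Relaxed feasible-set objective that the discriminator maximizes."""
    return mmd2_value + lam * one_sided_penalty(codes_real, codes_fake)
```

Because the penalty is one-sided, satisfying the constraint earns no bonus; it only subtracts from the MMD term when the fake-code mean exceeds the real-code mean in some dimension.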
Related Works
There has been a recent surge of work on improving GAN. We review some related works here.
Connection with WGAN: If we compose $f_\phi$ with a linear kernel instead of a Gaussian kernel, and restrict the output dimension $h$ to be 1, we then have the objective
$$\min_\theta \max_\phi \; \big\| \mathbb{E}[f_\phi(x)] - \mathbb{E}[f_\phi(g_\theta(z))] \big\|^2.$$
Parameterizing $f_\phi$ and $g_\theta$ with neural networks, and assuming there exists $\phi' \in \Phi$ such that $f_{\phi'} = -f_\phi$ for all $\phi \in \Phi$, recovers Wasserstein GAN (WGAN). (Theoretically, the two objectives are not equivalent, but the practical neural network approximation results in the same algorithm.) If we treat $f_\phi(x)$ as the data transform function, WGAN can be interpreted as first-order moment matching (linear kernel), while MMD GAN aims to match moments of infinite order via the Taylor expansion of the Gaussian kernel. Theoretically, the Wasserstein distance has guarantees similar to Theorems 1, 3 and 4. In practice, however, neural networks have been shown not to have enough capacity to approximate the Wasserstein distance. In Section 5, we demonstrate that matching high-order moments benefits the results. McGAN has also been proposed to match second-order moments from the primal-dual norm perspective. However, that algorithm requires matrix (tensor) decompositions because of exact moment matching, which is hard to scale to higher-order moment matching. On the other hand, by giving up exact moment matching, MMD GAN can match high-order moments with kernel tricks. More detailed discussions are in Appendix B.3.
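The WGAN connection can be checked numerically: with a linear kernel and one-dimensional codes, the biased MMD² estimate equals the squared gap between the mean critic outputs on real and generated samples, and it is unchanged when the critic is negated. That sign invariance is why a function class closed under negation reduces the squared objective to WGAN's mean-gap objective. The toy code values and function names below are hypothetical.

```python
def linear_kernel(u, v):
    """Linear kernel <u, v>."""
    return sum(a * b for a, b in zip(u, v))


def mmd2(xs, ys, kernel):
    """Biased MMD^2 estimate: mean k(x, x') - 2 mean k(x, y) + mean k(y, y')."""
    def mean_k(a, b):
        return sum(kernel(u, v) for u in a for v in b) / (len(a) * len(b))

    return mean_k(xs, xs) - 2.0 * mean_k(xs, ys) + mean_k(ys, ys)


def mean_gap(codes_real, codes_fake):
    """WGAN-style critic gap E[f(x)] - E[f(g(z))] for one-dimensional codes."""
    mean_real = sum(c[0] for c in codes_real) / len(codes_real)
    mean_fake = sum(c[0] for c in codes_fake) / len(codes_fake)
    return mean_real - mean_fake


# Toy one-dimensional critic outputs f_phi(x) and f_phi(g_theta(z)) (hypothetical values).
codes_real = [[1.0], [2.0], [4.0]]
codes_fake = [[0.0], [1.0]]
# The same outputs under the negated critic -f_phi.
negated_real = [[-c[0]] for c in codes_real]
negated_fake = [[-c[0]] for c in codes_fake]
```

Both critics give the same MMD² but gaps of opposite sign, so maximizing the squared gap over a negation-closed class picks the critic whose gap is positive, which is exactly the WGAN critic objective.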
Difference from Other Works with Autoencoders: Energy-based GANs also use an autoencoder (AE) in the discriminator, from the energy-model perspective, minimizing the reconstruction error of real samples $x$ while maximizing the reconstruction error of generated samples $g_\theta(z)$. In contrast, MMD GAN uses the AE to approximate invertible functions by minimizing the reconstruction errors of both real samples $x$ and generated samples $g_\theta(z)$. Also, EBGAN has been shown to approximate the total variation, with the drawback of discontinuity, while MMD GAN optimizes the MMD distance. Another line of work aims to match the AE code space $f(x)$ and uses the decoder $f_{dec}(\cdot)$. Some of these works match the distributions of $f(x)$ and $z$ via different distribution distances and generate data (e.g., images) by $f_{dec}(z)$; others use MMD to match $f(x)$ and $g(z)$, and generate data via $f_{dec}(g(z))$. The proposed MMD GAN matches $f(x)$ and $f(g(z))$, and generates data via $g(z)$ directly, as in GAN. One closely related method is similar to MMD GAN but considers the KL-divergence, without showing the continuity and weak∗ topology guarantees that we prove in Section 2.
Other GAN Works: In addition to the works discussed above, there are several other extensions of GAN. One proposes using the linear kernel to match the first moment of the discriminator's latent features. Another considers the variance of the empirical MMD score during training. A further work only improves that latent feature matching by using kernel MMD, instead of proposing an adversarial training framework as we study in Section 2. Another uses the Wasserstein distance to match the distribution of the autoencoder loss instead of the data; one could consider extending it to higher-order matching based on the proposed MMD GAN. A parallel work uses the energy distance, which can be treated as MMD GAN with a different kernel. However, its critic has some potential problems, which are discussed further in the literature.
Experiment
We train MMD GAN for image generation on the MNIST, CIFAR-10, CelebA, and LSUN bedrooms datasets, where the sizes of the training sets are 50K, 50K, 160K, and 3M, respectively. All the sample images are generated from fixed random noise vectors and are not cherry-picked.
Network architecture: In our experiments, we follow the DCGAN architecture, designing $g_\theta$ after its generator and $f_\phi$ after its discriminator, except that the output layer of $f_\phi$ is expanded to $h$ dimensions.
Kernel designs: The loss function of MMD GAN is implicitly associated with a family of characteristic kernels. Similar to the prior seminal MMD papers, we consider a mixture of $K$ RBF kernels $k(x, x') = \sum_{q=1}^{K} k_{\sigma_q}(x, x')$, where $k_{\sigma_q}$ is a Gaussian kernel with bandwidth parameter $\sigma_q$. Tuning the kernel bandwidth optimally remains an open problem. In this work, we fix $K$ and the bandwidths $\{\sigma_q\}_{q=1}^{K}$ and let $f_\phi$ learn the kernel (feature representation) under these $\sigma_q$.
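A minimal sketch of the mixture-of-RBF kernel, and of composing it with a learned feature map $f_\phi$, follows. The bandwidth grid (1, 2, 4, 8, 16) is an illustrative assumption, as are the function names; summing several bandwidths lets the kernel respond to both small and large distances without tuning a single $\sigma$.

```python
import math

# Bandwidths sigma_q: an illustrative assumption for this sketch.
SIGMAS = (1.0, 2.0, 4.0, 8.0, 16.0)


def mixture_rbf_kernel(u, v, sigmas=SIGMAS):
    """Sum over q of exp(-||u - v||^2 / (2 * sigma_q^2))."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return sum(math.exp(-sq_dist / (2.0 * s ** 2)) for s in sigmas)


def learned_kernel(x, y, feature_map, sigmas=SIGMAS):
    """Composed kernel k(f_phi(x), f_phi(y)); feature_map stands in for f_phi."""
    return mixture_rbf_kernel(feature_map(x), feature_map(y), sigmas)
```

On identical inputs the mixture evaluates to $K$ (here 5), since each Gaussian term equals 1; with a single bandwidth it reduces to an ordinary Gaussian kernel.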
Hyper-parameters: We use RMSProp with the learning rate suggested in the original WGAN paper, for a fair comparison with WGAN. We ensure the boundedness of the discriminator's model parameters by clipping the weights point-wise to a fixed range $[-c, c]$, as required by Assumption 2. The dimensionality $h$ of the latent space is set manually according to the complexity of the dataset, with separate values for MNIST and CelebA and a shared value for CIFAR-10 and LSUN bedrooms. The batch size $B$ is the same for all datasets.
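In this setup the discriminator is updated with RMSProp and its weights are then clipped point-wise. The plain-Python sketch below shows one such update on a flat list of weights. The learning rate and clipping range are passed in explicitly because their values are not given above; the RMSProp defaults `rho=0.9` and `eps=1e-8`, as well as the function names, are assumptions. Because the discriminator maximizes its objective, the gradients passed in would be those of the negated objective.

```python
import math


def rmsprop_step(params, grads, cache, lr, rho=0.9, eps=1e-8):
    """One RMSProp descent step on flat lists (params and cache updated in place).
    cache holds the running average of squared gradients."""
    for i, g in enumerate(grads):
        cache[i] = rho * cache[i] + (1.0 - rho) * g * g
        params[i] -= lr * g / (math.sqrt(cache[i]) + eps)
    return params


def clip_weights(params, c):
    """Point-wise clip every discriminator weight into [-c, c] (in place)."""
    for i, w in enumerate(params):
        params[i] = max(-c, min(c, w))
    return params
```

Clipping after every step keeps the parameters in a bounded set, which is the property the text relies on; the gradient-penalty variant discussed in Section 5.5 replaces this step.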
We start by comparing MMD GAN with GMMN on two standard benchmarks, MNIST and CIFAR-10. We consider two variants of GMMN. The first is the original GMMN, which trains the generator by minimizing the MMD distance in the original data space; we call it GMMN-D. To compare with MMD GAN, we also pretrain an autoencoder that projects data onto a manifold, then fix the autoencoder as a feature transformation, and train the generator by minimizing the MMD distance in the code space; we call it GMMN-C.
The results are shown in Figure 1. Both GMMN-D and GMMN-C are able to generate meaningful digits on MNIST because of its simple data structure. On closer inspection, however, the boundaries and shapes of the digits in Figures 1(a) and 1(b) are often irregular and non-smooth. In contrast, the sample digits in Figure 1(c) are more natural, with smooth outlines and sharper strokes. On CIFAR-10, both GMMN variants fail to generate meaningful images and produce only some low-level visual features. We observe similar results on other complex large-scale datasets such as CelebA and LSUN bedrooms, so those results are omitted. On the other hand, the proposed MMD GAN successfully outputs natural images with sharp boundaries and high diversity. The results in Figure 1 confirm that the proposed adversarially learned kernels enrich the statistical testing power, which is the key difference between GMMN and MMD GAN.
If we increase the batch size of GMMN, the image quality improves; however, it is still not competitive with MMD GAN trained with a much smaller batch size. The images are shown in Appendix C. This demonstrates that the proposed MMD GAN can be trained more efficiently than GMMN, with a smaller batch size.
Comparisons with GANs: There are several representative extensions of GANs. We consider the recent state-of-the-art WGAN based on the DCGAN structure, because of its connection with MMD GAN discussed in Section 4. The results are shown in Figure 2. For MNIST, the digits generated by WGAN in Figure 2(a) are more unnatural, with peculiar strokes. In contrast, the digits from MMD GAN in Figure 2(d) have smoother contours. Furthermore, both WGAN and MMD GAN generate diverse digits, avoiding the mode collapse problem reported in the GAN training literature. For CelebA, we can see differences between the samples generated by WGAN and MMD GAN. Specifically, we observe varied poses, expressions, genders, skin colors and light exposure in Figures 2(b) and 2(e). On closer inspection (viewing on-screen and zooming in), faces from WGAN are more likely to be blurry and twisted, while faces from MMD GAN are more spontaneous, with sharp and acute facial outlines. For the LSUN dataset, we could not distinguish salient differences between the samples generated by MMD GAN and WGAN.
5.2 Quantitative Analysis
To quantitatively measure the quality and diversity of generated samples, we compute the inception score on CIFAR-10 images. The inception score is commonly used for GANs to measure sample quality and diversity with a pretrained Inception model. Models that generate collapsed samples have a relatively low score. Table 1 lists the results for samples generated by various unsupervised generative models trained on the CIFAR-10 dataset. The inception scores of some baselines are taken directly from the corresponding references.
Although both WGAN and MMD GAN can generate sharp images, as we show in Section 5.1, our score is better than those of the other GAN techniques except for DFM. This seems to confirm empirically that higher-order moment matching between the real data and fake sample distributions benefits generating more diverse sample images. Also note that DFM appears compatible with our method, and combining training techniques from DFM is a possible avenue for future work.
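For reference, the inception score is $\exp\big(\mathbb{E}_x\, \mathrm{KL}(p(y \mid x) \,\|\, p(y))\big)$, where $p(y \mid x)$ is the class distribution predicted by the pretrained Inception model for a generated image and $p(y)$ is its marginal over generated samples. The sketch below assumes the class probabilities are already available as lists; the standard evaluation protocol (the Inception network itself, the number of samples and the averaging over splits) is not reproduced, and the function name is hypothetical.

```python
import math


def inception_score(class_probs):
    """exp(E_x KL(p(y|x) || p(y))) from per-sample class distributions p(y|x)."""
    n = len(class_probs)
    num_classes = len(class_probs[0])
    # Marginal p(y): average of the per-sample distributions.
    marginal = [sum(p[c] for p in class_probs) / n for c in range(num_classes)]
    mean_kl = 0.0
    for p in class_probs:
        # Terms with p(y|x) = 0 contribute 0 to the KL divergence.
        mean_kl += sum(pc * math.log(pc / mc) for pc, mc in zip(p, marginal) if pc > 0) / n
    return math.exp(mean_kl)
```

The score is 1 when every sample yields the same prediction (collapsed or uninformative samples) and reaches the number of classes when predictions are confident and spread evenly across classes, which is why collapsed generators score low.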
5.3 Stability of MMD GAN
5.4 Computation Issue
We conduct a time complexity analysis with respect to the batch size $B$. The time complexity of each iteration is $O(B)$ for WGAN and $O(KB^2)$ for our proposed MMD GAN with a mixture of $K$ RBF kernels. The quadratic complexity of MMD GAN comes from computing the kernel matrix, which is sometimes criticized as impractical for large batch sizes. However, we point out that several recent works, such as EBGAN, also match pairwise relations between samples in a batch, leading to $O(B^2)$ complexity as well.
Empirically, we find that in a GPU environment, highly parallelized matrix operations tremendously reduce the quadratic time to almost linear time for modest $B$. Figure 3 compares the computational time per generator iteration for different $B$ on a Titan X. When $B$ is set to the value used for training MMD GAN in our experiments, the time per iteration of WGAN and MMD GAN is 0.268 and 0.676 seconds, respectively. When $B$ is set to the value used for training GMMN in its references, the time per iteration becomes 4.431 and 8.565 seconds, respectively. This result supports our argument that, with powerful GPU parallel computation, the empirical computational time of MMD GAN is not quadratically more expensive than that of WGAN.
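The near-linear scaling on GPUs comes from computing all pairwise kernel values with dense matrix operations. As a hedged illustration (plain Python, so nothing here actually runs in parallel), pairwise squared distances can be rewritten as $\|x\|^2 + \|y\|^2 - 2\langle x, y\rangle$, so the $B \times B$ work reduces to one inner-product (Gram) matrix, which maps onto a single matrix multiplication on a GPU. The function name is hypothetical.

```python
def pairwise_sq_dists(xs, ys):
    """All pairwise squared distances via ||x||^2 + ||y||^2 - 2 <x, y>.
    The Gram matrix of inner products is the matrix-multiplication part;
    the clamp at 0 guards against tiny negative round-off."""
    x_norms = [sum(a * a for a in x) for x in xs]
    y_norms = [sum(b * b for b in y) for y in ys]
    gram = [[sum(a * b for a, b in zip(x, y)) for y in ys] for x in xs]
    return [
        [max(x_norms[i] + y_norms[j] - 2.0 * gram[i][j], 0.0) for j in range(len(ys))]
        for i in range(len(xs))
    ]
```

A Gaussian kernel matrix then follows elementwise as $\exp(-D / (2\sigma^2))$ for each bandwidth, so a mixture of $K$ kernels reuses the same distance matrix $K$ times.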
5.5 Better Lipschitz Approximation and Necessity of Auto-Encoder
We used weight clipping for the Lipschitz constraint in Assumption 2. Another approach for obtaining a discriminator with similar constraints that approximates a Wasserstein distance is the gradient penalty, where the norm of the discriminator's gradient is constrained to be 1 at points between the generated and real data points. Inspired by this, an alternative approach is to apply the gradient constraint as a regularizer to the witness function of a different IPM, such as the MMD. This idea was first proposed for the Energy Distance (in parallel to our submission), which was subsequently shown to correspond to a gradient penalty on the witness function for any RKHS-based MMD. Here we undertake a preliminary investigation of this approach, in which we also drop the requirement in Algorithm 1 for $f_\phi$ to be injective, since we observe that it is not necessary in practice. We show some preliminary results of training MMD GAN with gradient penalty and without the auto-encoder in Figure 5. This preliminary study indicates that MMD GAN can generate satisfactory results with other Lipschitz-constraint approximations. One potential direction for future work is a more thorough empirical comparison of different approximations.
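The gradient-penalty idea can be illustrated on the MMD witness function. For a Gaussian kernel, the (unnormalized) empirical witness is $f(t) = \frac{1}{n}\sum_i k(x_i, t) - \frac{1}{m}\sum_j k(y_j, t)$, whose gradient has the closed form used below. This is a hypothetical sketch in the style of the WGAN gradient penalty (random interpolates between paired real and generated samples, penalty $(\|\nabla f\| - 1)^2$); it is not the exact regularizer of the works cited above, and it omits the witness normalization and any learned feature map $f_\phi$.

```python
import math
import random


def gaussian_kernel(u, v, sigma=1.0):
    """Gaussian kernel exp(-||u - v||^2 / (2 * sigma^2))."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def witness_grad(t, xs, ys, sigma=1.0):
    """Gradient at t of the unnormalized empirical witness
    f(t) = mean_i k(x_i, t) - mean_j k(y_j, t), using
    d/dt k(p, t) = k(p, t) * (p - t) / sigma^2 for the Gaussian kernel."""
    grad = [0.0] * len(t)
    for points, sign in ((xs, 1.0), (ys, -1.0)):
        for p in points:
            weight = sign * gaussian_kernel(p, t, sigma) / (len(points) * sigma ** 2)
            for j in range(len(t)):
                grad[j] += weight * (p[j] - t[j])
    return grad


def gradient_penalty(xs, ys, sigma=1.0, rng=None):
    """Mean of (||grad f(x_hat)|| - 1)^2 over interpolates
    x_hat = eps * x + (1 - eps) * y of paired real and generated samples."""
    rng = rng or random.Random(0)  # fixed seed by default for reproducibility
    pairs = list(zip(xs, ys))
    total = 0.0
    for x, y in pairs:
        eps = rng.random()
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(x, y)]
        grad_norm = math.sqrt(sum(g * g for g in witness_grad(x_hat, xs, ys, sigma)))
        total += (grad_norm - 1.0) ** 2
    return total / len(pairs)
```

When the real and generated samples coincide the witness is flat, so the penalty equals 1; in training, such a term would be subtracted from the discriminator's objective in place of weight clipping.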
Discussion
We introduce a new deep generative model trained via MMD with adversarially learned kernels. We further study its theoretical properties and propose a practical realization, MMD GAN, which can be trained with a much smaller batch size than GMMN and has performance competitive with state-of-the-art GANs. We can view MMD GAN as the first practical step toward connecting moment matching networks and GAN. One important direction is applying tools developed for moment matching to general GAN works, based on the connections shown by MMD GAN. Also, in Section 4, we connect WGAN and MMD GAN through first-order and infinite-order moment matching. Prior work shows that finite-order moment matching achieves the best performance on domain adaptation. One could extend MMD GAN in this direction by using polynomial kernels. Last, in theory, an injective mapping $f_\phi$ is necessary for the theoretical guarantees. However, we observe that it is not mandatory in practice, as we show in Section 5.5. One conjecture is that $f_\phi$ usually learns an injective mapping with high probability when parameterized by neural networks, which is worth further study in future work.
Acknowledgments
We thank the reviewers for their helpful comments. We also thank Dougal Sutherland and Arthur Gretton for their valuable feedback and for pointing out an error in a previous draft. This work is supported in part by the National Science Foundation (NSF) under grants IIS-1546329 and IIS-1563887.
Appendix A Technical Proof
Since MMD is a probabilistic metric, the triangle inequality holds for every $f_\phi$. Therefore,
In this proof, we consider the Gaussian kernel $k$; therefore
for all and . Similarly, . Combining the above claim with (7) and bounded convergence theorem, we have
By the Mean Value Theorem, . Incorporating it with (8) and the triangle inequality, we have
Now let $g$ be locally Lipschitz. For a given pair $(\theta, z)$, there is a constant $L(\theta, z)$ and an open set $U_\theta$ such that for every $(\theta', z') \in U_\theta$ we have
A.2 Proof of Theorem 4
The proof utilizes parts of results from .
(⇐)
A.3 Proof of Theorem 5
Appendix B Property of MMD with Fixed and Learned Kernels
One can simplify Theorem 3 and its proof for standard MMD distance to show MMD is also continuous and differentiable almost everywhere. In , they propose a counterexample to show the discontinuity of MMD by assuming . However, it is known that is not in RKHS, so the discussed counterexample is not appropriate.
B.2 IPM Framework
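From the integral probability metric (IPM) perspective, the probabilistic distance can be defined as
$$d_{\mathcal{F}}(\mathbb{P}, \mathbb{Q}) = \sup_{f \in \mathcal{F}} \; \mathbb{E}_{x \sim \mathbb{P}}[f(x)] - \mathbb{E}_{y \sim \mathbb{Q}}[f(y)]. \quad (10)$$
By changing the function class $\mathcal{F}$, we can recover several distances, such as total variation, Wasserstein distance and MMD distance. As noted in prior work, the discriminators in different existing GAN works can be interpreted as solving for different probabilistic metrics based on (10). For MMD, the function class is $\{h \in \mathcal{H}_k : \|h\|_{\mathcal{H}_k} \le 1\}$, where $\mathcal{H}_k$ is the RKHS associated with kernel $k$. Unlike many other distances, such as total variation and Wasserstein distance, MMD has an analytical representation, as we show in Section 2, which is
$$d_{\mathcal{F}}(\mathbb{P}, \mathbb{Q}) = \big\|\mu_{\mathbb{P}} - \mu_{\mathbb{Q}}\big\|_{\mathcal{H}_k} = \sqrt{M_k(\mathbb{P}, \mathbb{Q})}.$$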
Because of the analytical representation of (10), GMMN does not need an additional network for estimating the distance.
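Here we also provide an explanation of the proposed MMD with adversarially learned kernel under the IPM framework. The MMD distance with adversarially learned kernel is represented as
$$\max_{k \in \mathcal{K}} \; \sup_{h \in \mathcal{H}_k,\, \|h\|_{\mathcal{H}_k} \le 1} \mathbb{E}_{x \sim \mathbb{P}}[h(x)] - \mathbb{E}_{y \sim \mathbb{Q}}[h(y)] = \sup_{h \in \mathcal{F}'} \mathbb{E}_{x \sim \mathbb{P}}[h(x)] - \mathbb{E}_{y \sim \mathbb{Q}}[h(y)],$$
where $\mathcal{K}$ is the set of learnable kernels (e.g., $\tilde{k} = k \circ f_\phi$) and $\mathcal{F}' = \bigcup_{k \in \mathcal{K}} \{h \in \mathcal{H}_k : \|h\|_{\mathcal{H}_k} \le 1\}$. From this perspective, the proposed MMD distance with adversarially learned kernel is still defined by an IPM, but with a larger function class.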
B.3 MMD is an Efficient Moment Matching
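In MMD, with the polynomial kernel $k(x, y) = (1 + x^\top y)^2$, the MMD distance is
$$M_k(\mathbb{P}, \mathbb{Q}) = 2\,\big\|\mathbb{E}[x] - \mathbb{E}[y]\big\|^2 + \big\|\mathbb{E}[x x^\top] - \mathbb{E}[y y^\top]\big\|_F^2,$$
which is inexact moment matching, because the second term contains the quadratic of the first moment ($\mathbb{E}[x x^\top] = \mathrm{Cov}(x) + \mathbb{E}[x]\mathbb{E}[x]^\top$). Exact moment matching is difficult to extend to high-order moments, because one has to deal with high-order tensors directly. On the other hand, MMD can easily match high-order moments (even infinite-order moments, by using the Gaussian kernel) with kernel tricks, and it enjoys strong theoretical guarantees.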
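The expansion for a degree-2 polynomial kernel can be checked numerically. Expanding $(1 + x^\top y)^2 = 1 + 2x^\top y + (x^\top y)^2$ inside the MMD estimator gives $2\|\mathbb{E}[x] - \mathbb{E}[y]\|^2 + \|\mathbb{E}[xx^\top] - \mathbb{E}[yy^\top]\|_F^2$, where the second-moment matrix carries the quadratic of the mean. The sketch below (function names hypothetical) compares the biased kernel estimate with this moment form on toy data; the two agree exactly for empirical distributions, up to floating-point error.

```python
def poly2_kernel(u, v):
    """Degree-2 polynomial kernel (1 + <u, v>)^2."""
    return (1.0 + sum(a * b for a, b in zip(u, v))) ** 2


def mmd2(xs, ys, kernel):
    """Biased MMD^2 estimate: mean k(x, x') - 2 mean k(x, y) + mean k(y, y')."""
    def mean_k(a, b):
        return sum(kernel(u, v) for u in a for v in b) / (len(a) * len(b))

    return mean_k(xs, xs) - 2.0 * mean_k(xs, ys) + mean_k(ys, ys)


def moment_form(xs, ys):
    """2 * ||E[x] - E[y]||^2 + ||E[x x^T] - E[y y^T]||_F^2 with empirical moments."""
    d = len(xs[0])

    def first(s):
        return [sum(v[i] for v in s) / len(s) for i in range(d)]

    def second(s):
        return [[sum(v[i] * v[j] for v in s) / len(s) for j in range(d)] for i in range(d)]

    m_x, m_y = first(xs), first(ys)
    s_x, s_y = second(xs), second(ys)
    mean_term = sum((a - b) ** 2 for a, b in zip(m_x, m_y))
    second_term = sum((s_x[i][j] - s_y[i][j]) ** 2 for i in range(d) for j in range(d))
    return 2.0 * mean_term + second_term
```

The kernel side needs only inner products, whereas the moment side builds explicit $d \times d$ matrices; for order $p$ that explicit object becomes a $d^p$ tensor, which is the scaling issue the paragraph describes.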