VEEGAN: Reducing Mode Collapse in GANs using Implicit Variational Learning
Akash Srivastava, Lazar Valkov, Chris Russell, Michael U. Gutmann, Charles Sutton
Introduction
Deep generative models are a topic of enormous recent interest, providing a powerful class of tools for the unsupervised learning of probability distributions over difficult manifolds such as natural images. Deep generative models are usually implicit statistical models, also called implicit probability distributions, meaning that they do not induce a density function that can be tractably computed, but rather provide a simulation procedure to generate new data points. Generative adversarial networks (GANs) are one attractive method of this kind and have seen promising recent successes. GANs train two deep networks in concert: a generator network that maps random noise, usually drawn from a multivariate Gaussian, to data items; and a discriminator network that estimates the likelihood ratio of the generator network to the data distribution and is trained using an adversarial principle. Despite an enormous amount of recent work, GANs are notoriously fickle to train, and it has been observed that they often suffer from mode collapse, in which the generator network learns how to generate samples from a few modes of the data distribution but misses many other modes, even though samples from the missing modes occur throughout the training data.
To address this problem, we introduce VEEGAN (a Variational Encoder Enhancement to Generative Adversarial Networks; project page: https://akashgit.github.io/VEEGAN/), a variational principle for estimating implicit probability distributions that avoids mode collapse. While the generator network maps Gaussian random noise to data items, VEEGAN introduces an additional reconstructor network that maps the true data distribution to Gaussian random noise. We train the generator and reconstructor networks jointly by introducing an implicit variational principle, which encourages the reconstructor network not only to map the data distribution to a Gaussian, but also to approximately reverse the action of the generator. Intuitively, if the reconstructor learns both to map all of the true data to the noise distribution and to approximately invert the generator network, this encourages the generator network to map from the noise distribution to the entirety of the true data distribution, thus resolving mode collapse.
Background
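Concretely, a GAN trains a generator network $G_\gamma$, which maps noise $z \sim p_0 = \mathcal{N}(0, I)$ to data space, together with a discriminator network $D_\omega$, by solving $\min_\gamma \max_\omega \mathcal{O}_{\text{GAN}}(\omega,\gamma)$ with
$$\mathcal{O}_{\text{GAN}}(\omega,\gamma) = \mathbb{E}_z\left[\log\sigma\left(D_\omega(G_\gamma(z))\right)\right] + \mathbb{E}_x\left[\log\left(1-\sigma\left(D_\omega(x)\right)\right)\right],$$
where $\mathbb{E}_z$ indicates an expectation over the standard normal $z$, $\mathbb{E}_x$ indicates an expectation over the data distribution $p(x)$, and $\sigma$ denotes the sigmoid function. At the optimum, in the limit of infinite data and arbitrarily powerful networks, we will have $D_\omega(x) = \log q_\gamma(x)/p(x)$, where $q_\gamma$ is the density that is induced by running the network $G_\gamma$ on normally distributed input, and hence that $q_\gamma = p$.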
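To make the stated optimum concrete, the sketch below checks numerically that, at a fixed point $x$, the discriminator output maximizing the integrand of the GAN objective is the log density ratio $\log q(x)/p(x)$, i.e. $\sigma(D) = q/(q+p)$. It assumes toy 1D densities chosen only for illustration, $q = \mathcal{N}(1, 1)$ for the generator and $p = \mathcal{N}(0, 1)$ for the data, for which $\log q(x)/p(x) = x - 1/2$; the helper names are hypothetical.

```python
import math

def normal_pdf(x, mu, sigma=1.0):
    """Density of N(mu, sigma^2) at x."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def integrand(d, q_x, p_x):
    """Pointwise GAN objective at x when the discriminator outputs d:
    generated mass q(x) scores log sigma(d), data mass p(x) scores log(1 - sigma(d))."""
    return q_x * math.log(sigmoid(d)) + p_x * math.log(1.0 - sigmoid(d))

def best_discriminator_output(q_x, p_x, lo=-10.0, hi=10.0, steps=20001):
    """Brute-force grid search (spacing 0.001) for the d maximizing the integrand."""
    grid = [lo + (hi - lo) * i / (steps - 1) for i in range(steps)]
    return max(grid, key=lambda d: integrand(d, q_x, p_x))

# Toy densities (assumption): generator q = N(1, 1), data p = N(0, 1).
for x in [-1.0, 0.0, 0.5, 2.0]:
    q_x, p_x = normal_pdf(x, 1.0), normal_pdf(x, 0.0)
    d_star = best_discriminator_output(q_x, p_x)
    print(f"x={x:+.1f}  grid optimum={d_star:+.3f}  log q/p={math.log(q_x / p_x):+.3f}")
```

The grid maximizer agrees with $x - 1/2$ up to the grid resolution, which is the pointwise version of the claim that the optimal discriminator recovers the log likelihood ratio.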
Unfortunately, GANs can be difficult and unstable to train. One common pathology that arises in GAN training is mode collapse, which is when samples from $q_\gamma(x)$ capture only a few of the modes of $p(x)$. An intuition behind why mode collapse occurs is that the only information that the objective function provides about $\gamma$ is mediated by the discriminator network $D_\omega$. For example, if $D_\omega$ is a constant, then $\mathcal{O}_{\text{GAN}}$ is constant with respect to $\gamma$, and so learning the generator is impossible. When this situation occurs in a localized region of input space, for example, when there is a specific type of image that the generator cannot replicate, this can cause mode collapse.
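The constant-discriminator argument can be checked with a hypothetical 1D generator $G(z) = \gamma + z$ and fixed Monte Carlo samples (all toy choices of ours, not the paper's setup). With a constant discriminator, a central finite-difference estimate of the gradient of the GAN objective with respect to $\gamma$ is exactly zero, whereas a discriminator that reacts to its input (here $D(x) = x$, also an assumption) yields a nonzero gradient.

```python
import math
import random

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def gan_objective(gamma, disc, zs, xs):
    """Monte Carlo estimate of E_z[log sigma(D(G(z)))] + E_x[log(1 - sigma(D(x)))]
    for the toy generator G(z) = gamma + z (an illustrative assumption)."""
    gen_term = sum(math.log(sigmoid(disc(gamma + z))) for z in zs) / len(zs)
    data_term = sum(math.log(1.0 - sigmoid(disc(x))) for x in xs) / len(xs)
    return gen_term + data_term

def grad_gamma(gamma, disc, zs, xs, eps=1e-5):
    """Central finite-difference derivative of the objective with respect to gamma."""
    up = gan_objective(gamma + eps, disc, zs, xs)
    down = gan_objective(gamma - eps, disc, zs, xs)
    return (up - down) / (2.0 * eps)

def constant_disc(x):
    return 0.7      # ignores its input entirely

def linear_disc(x):
    return x        # responds to where the sample lies

rng = random.Random(0)
zs = [rng.gauss(0.0, 1.0) for _ in range(1000)]                              # noise
xs = [rng.choice([-3.0, 3.0]) + rng.gauss(0.0, 0.5) for _ in range(1000)]   # two-mode data

print(grad_gamma(0.0, constant_disc, zs, xs))   # 0.0: no learning signal for the generator
print(grad_gamma(0.0, linear_disc, zs, xs))     # nonzero
```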
Method
The main idea of VEEGAN is to introduce a second network $F_\theta$, which we call the reconstructor network, that is learned both to map the true data distribution $p(x)$ to a Gaussian and to approximately invert the generator network.
To understand why this might prevent mode collapse, consider the example in Figure 1. In both columns of the figure, the middle vertical panel represents the data space, where in this example the true distribution $p(x)$ is a mixture of two Gaussians. The bottom panel depicts the input to the generator, which is drawn from a standard normal distribution $p_0 = \mathcal{N}(0, I)$, and the top panel depicts the result of applying the reconstructor network to the generated and the true data. The arrows labeled $G_\gamma$ show the action of the generator. The purple arrows labeled $F_\theta$ show the action of the reconstructor on the true data, whereas the green arrows show the action of the reconstructor on data from the generator. In this example, the generator has captured only one of the two modes of $p(x)$. The difference between Figure 1(a) and 1(b) is that the reconstructor networks are different.
First, let us suppose (Figure 1(a)) that we have successfully trained $F_\theta$ so that it is approximately the inverse of $G_\gamma$. As we have assumed mode collapse, however, the training data for the reconstructor network does not include data items from the "forgotten" mode of $p(x)$; therefore the action of $F_\theta$ on data from that mode is ill-specified. This means that $F_\theta(X)$, $X \sim p(x)$, is unlikely to be Gaussian, and we can use this mismatch as an indicator of mode collapse.
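A minimal numeric version of the Figure 1(a) scenario, under toy assumptions of our own: the data are an equal mixture of $\mathcal{N}(-4, 1)$ and $\mathcal{N}(4, 1)$, the collapsed generator is $G(z) = 4 + z$ (it only covers the right mode), and the reconstructor is its exact inverse $F(x) = x - 4$. Autoencoding the noise is perfect, yet the reconstructor's outputs on the true data are far from $\mathcal{N}(0, 1)$, which is the mismatch the text proposes to use as a signal of mode collapse.

```python
import random
import statistics

rng = random.Random(0)

def sample_data(n):
    """Equal mixture of N(-4, 1) and N(4, 1): a toy stand-in for p(x)."""
    return [rng.choice([-4.0, 4.0]) + rng.gauss(0.0, 1.0) for _ in range(n)]

def generator(z):
    """Mode-collapsed generator: only ever produces the right-hand mode."""
    return 4.0 + z

def reconstructor(x):
    """Exact inverse of the generator."""
    return x - 4.0

zs = [rng.gauss(0.0, 1.0) for _ in range(20000)]
xs = sample_data(20000)

# Autoencoding the noise works: F(G(z)) recovers z up to floating-point rounding.
recon_error = statistics.fmean((z - reconstructor(generator(z))) ** 2 for z in zs)

# But F applied to the true data is far from the N(0, 1) it should match.
fx = [reconstructor(x) for x in xs]
print(recon_error)                                      # essentially zero
print(statistics.fmean(fx), statistics.pvariance(fx))   # roughly -4 and 17, not 0 and 1
```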
Conversely, let us suppose (Figure 1(b)) that $F_\theta$ is successful at mapping the true data distribution to a Gaussian. In that case, if $G_\gamma$ mode collapses, then $F_\theta$ will not map all $G_\gamma(z)$ back to the original $z$, and the resulting penalty provides us with a strong learning signal for both $\gamma$ and $\theta$.
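The converse scenario of Figure 1(b) can be sketched with a toy data distribution of our choosing, an equal mixture of $\mathcal{N}(-4, 1)$ and $\mathcal{N}(4, 1)$. Here the reconstructor is assumed to be the exact Gaussianizing map $F(x) = \Phi^{-1}(P(x))$, where $P$ is the mixture CDF and $\Phi$ the standard normal CDF, so $F$ maps the true data to $\mathcal{N}(0, 1)$. The generator is collapsed onto the right mode, $G(z) = 4 + z$, and the reconstruction penalty $\mathbb{E}[(z - F(G(z)))^2]$ is clearly nonzero. `statistics.NormalDist` (Python 3.8+) supplies $\Phi$ and $\Phi^{-1}$.

```python
import random
import statistics
from statistics import NormalDist

STD_NORMAL = NormalDist(0.0, 1.0)
rng = random.Random(0)

def mixture_cdf(x):
    """CDF of the toy data distribution 0.5 N(-4, 1) + 0.5 N(4, 1)."""
    return 0.5 * STD_NORMAL.cdf(x + 4.0) + 0.5 * STD_NORMAL.cdf(x - 4.0)

def reconstructor(x):
    """Gaussianizing map Phi^{-1}(P(x)); clamped so inv_cdf stays inside (0, 1)."""
    u = min(max(mixture_cdf(x), 1e-12), 1.0 - 1e-12)
    return STD_NORMAL.inv_cdf(u)

def generator(z):
    """Mode-collapsed generator covering only the right-hand mode."""
    return 4.0 + z

xs = [rng.choice([-4.0, 4.0]) + rng.gauss(0.0, 1.0) for _ in range(5000)]
zs = [rng.gauss(0.0, 1.0) for _ in range(5000)]

# F maps the true data to (approximately) standard normal samples.
fx = [reconstructor(x) for x in xs]
print(statistics.fmean(fx), statistics.pvariance(fx))   # close to 0 and 1

# ...but it cannot send G(z) back to z, so the reconstruction penalty is large.
penalty = statistics.fmean((z - reconstructor(generator(z))) ** 2 for z in zs)
print(penalty)
```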
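Together, these two requirements suggest the objective
$$\mathcal{O}_{\text{entropy}}(\gamma,\theta) = \mathbb{E}\left[\lVert z - F_\theta(G_\gamma(z))\rVert_2^2\right] + H(Z, F_\theta(X)), \qquad (1)$$
where the first term penalizes the reconstructor for failing to recover the noise from generated data, and $H(Z, F_\theta(X))$ is the cross-entropy between $Z \sim p_0$ and the reconstructor's outputs on the true data. While this objective captures the main idea of our paper, it cannot easily be computed and minimized. We next transform it into a computable version and derive theoretical guarantees.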
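Let us denote the distribution of the outputs of the reconstructor network when applied to a fixed data item $x$ by $p_\theta(z\mid x)$, and when applied to all $X \sim p(x)$ by $p_\theta(z) = \int p_\theta(z\mid x)\,p(x)\,dx$. The conditional distribution $p_\theta(z\mid x)$ is Gaussian with unit variance and, with a slight abuse of notation, (deterministic) mean function $F_\theta(x)$. The entropy term $H(Z, F_\theta(X))$ can thus be written as
$$H(Z, F_\theta(X)) = -\int p_0(z)\log p_\theta(z)\,dz = -\int p_0(z)\log\int p(x)\,p_\theta(z\mid x)\,dx\,dz. \qquad (2)$$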
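This cross-entropy is minimized with respect to $\theta$ when $p_\theta(z) = p_0(z)$. Unfortunately, the integral on the right-hand side of (2) cannot usually be computed in closed form. We thus introduce a variational distribution $q_\gamma(x\mid z)$, and by Jensen's inequality we have
$$-\log p_\theta(z) = -\log\int q_\gamma(x\mid z)\,\frac{p_\theta(z\mid x)\,p(x)}{q_\gamma(x\mid z)}\,dx \le -\int q_\gamma(x\mid z)\log\frac{p_\theta(z\mid x)\,p(x)}{q_\gamma(x\mid z)}\,dx, \qquad (3)$$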
which we use to bound the cross-entropy in (2). In variational inference, strong parametric assumptions are typically made on $q_\gamma$. Importantly, we here relax that assumption, instead representing $q_\gamma$ implicitly as a deep generative model, which enables us to learn very complex distributions. The variational distribution $q_\gamma(x\mid z)$ plays exactly the same role as the generator in a GAN, and for that reason we parameterize $q_\gamma(x\mid z)$ as the output of a stochastic neural network $G_\gamma(z)$.
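The Jensen's-inequality bound can be checked on a toy problem of our own in which $x$ takes three discrete values, so that $p_\theta(z) = \sum_x p(x)\,\mathcal{N}(z; F(x), 1)$ is computable exactly. The values of $p(x)$, the reconstructor means $F(x)$, and the variational distribution $q(x\mid z)$ are arbitrary hypothetical choices. The bound $-\log p_\theta(z) \le -\sum_x q(x\mid z)\log[p_\theta(z\mid x)\,p(x)/q(x\mid z)]$ holds for every $q$, with equality when $q$ is the exact posterior $p_\theta(z\mid x)\,p(x)/p_\theta(z)$.

```python
import math

P_X = [0.2, 0.5, 0.3]        # toy data distribution over x in {0, 1, 2}
F_MEAN = [-1.0, 0.0, 2.0]    # hypothetical reconstructor means F(x)

def gauss(z, mean):
    """Unit-variance Gaussian density N(z; mean, 1), i.e. p_theta(z | x)."""
    return math.exp(-0.5 * (z - mean) ** 2) / math.sqrt(2.0 * math.pi)

def neg_log_marginal(z):
    """-log p_theta(z), with p_theta(z) = sum_x p(x) p_theta(z | x)."""
    return -math.log(sum(p * gauss(z, m) for p, m in zip(P_X, F_MEAN)))

def jensen_bound(z, q):
    """Right-hand side: -sum_x q(x|z) log[p_theta(z|x) p(x) / q(x|z)] (0 log 0 = 0)."""
    return -sum(qx * math.log(gauss(z, m) * p / qx)
                for qx, p, m in zip(q, P_X, F_MEAN) if qx > 0)

def posterior(z):
    """Exact posterior p_theta(z|x) p(x) / p_theta(z), which makes the bound tight."""
    joint = [p * gauss(z, m) for p, m in zip(P_X, F_MEAN)]
    total = sum(joint)
    return [j / total for j in joint]

for z in [-2.0, 0.0, 1.5]:
    uniform_q = [1 / 3, 1 / 3, 1 / 3]
    print(z, neg_log_marginal(z), jensen_bound(z, uniform_q), jensen_bound(z, posterior(z)))
```

The gap between the bound and $-\log p_\theta(z)$ is the KL divergence between $q(x\mid z)$ and the exact posterior, which is why a flexible (here: implicit) $q_\gamma$ is desirable.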
In practice, minimizing this bound is difficult if $q_\gamma$ is specified implicitly. For instance, it is challenging to train a discriminator network that accurately estimates the unknown likelihood ratio $\log p(x)/q_\gamma(x\mid z)$, because $q_\gamma(x\mid z)$, as a conditional distribution, is much more peaked than the marginal data distribution $p(x)$, making it too easy for a discriminator to tell the two distributions apart. Intuitively, the discriminator in a GAN works well when it is presented with a difficult pair of distributions to distinguish. To circumvent this problem, we write (see supplementary material)
$$-\int p_0(z)\log p_\theta(z)\,dz \le \mathrm{KL}\left[q_\gamma(x\mid z)\,p_0(z)\,\middle\|\,p_\theta(z\mid x)\,p(x)\right] - \mathbb{E}\left[\log p_0(z)\right]. \qquad (4)$$
Here all expectations are taken with respect to the joint distribution $p_0(z)\,q_\gamma(x\mid z)$.
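Combining this bound with the reconstruction penalty on the noise vectors, we obtain the final objective function
$$\mathcal{O}(\gamma,\theta) = \mathrm{KL}\left[q_\gamma(x\mid z)\,p_0(z)\,\middle\|\,p_\theta(z\mid x)\,p(x)\right] - \mathbb{E}\left[\log p_0(z)\right] + \mathbb{E}\left[d(z, F_\theta(x))\right], \qquad (5)$$
where $d$ is a loss function in the representation space, such as the $\ell_2$ loss, so that the last term is an autoencoder over noise vectors rather than over data items.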
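Proposition. Suppose that there exist parameters $\theta^*, \gamma^*$ such that $\mathcal{O}(\gamma^*,\theta^*) = H[p_0]$, where $H$ denotes Shannon entropy. Then $(\gamma^*,\theta^*)$ minimizes $\mathcal{O}$, and further
$$p_{\theta^*}(z) := \int p_{\theta^*}(z\mid x)\,p(x)\,dx = p_0(z), \quad\text{and}\quad q_{\gamma^*}(x) := \int q_{\gamma^*}(x\mid z)\,p_0(z)\,dz = p(x).$$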
Because neural networks are universal approximators, the conditions in the proposition can be achieved when the networks $G_\gamma$ and $F_\theta$ are sufficiently powerful.
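The proposition rests on a short argument that can be spelled out as follows. This is our own expansion, assuming (as for the $\ell_2$ loss) that $d \ge 0$, and writing $q_\gamma(x\mid z)$ for the generator's conditional, $p_\theta(z\mid x)$ for the reconstructor's, and $p_0$ for the noise distribution. Because the $z$-marginal of the joint distribution $p_0(z)\,q_\gamma(x\mid z)$ is $p_0$, the middle term of (5) is exactly the entropy $H[p_0]$:

```latex
\begin{align*}
\mathcal{O}(\gamma,\theta)
  &= \underbrace{\mathrm{KL}\left[q_\gamma(x\mid z)\,p_0(z)\,\middle\|\,p_\theta(z\mid x)\,p(x)\right]}_{\ge\, 0}
     \;\underbrace{-\;\mathbb{E}\left[\log p_0(z)\right]}_{=\, H[p_0]}
     \;+\;\underbrace{\mathbb{E}\left[d(z, F_\theta(x))\right]}_{\ge\, 0}
   \;\ge\; H[p_0].
\end{align*}
```

Hence any $(\gamma^*, \theta^*)$ attaining $\mathcal{O} = H[p_0]$ is a global minimizer. At such a point the KL term is zero, so $q_{\gamma^*}(x\mid z)\,p_0(z) = p_{\theta^*}(z\mid x)\,p(x)$ as joint distributions; integrating this identity over $x$ gives $p_{\theta^*}(z) = p_0(z)$, and integrating over $z$ gives $q_{\gamma^*}(x) = p(x)$.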
2 Learning with Implicit Probability Distributions
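This subsection describes how to approximate $\mathcal{O}$ when we have implicit representations for $q_\gamma$ and $p(x)$ rather than explicit densities. In this case, we cannot optimize $\mathcal{O}$ directly, because the KL divergence in (5) depends on a density ratio which is unknown, both because $q_\gamma$ is implicit and also because $p(x)$ is unknown. Following prior work on adversarial training, we estimate this ratio using a discriminator network $D_\omega(z, x)$, which we will train to encourage
$$D_\omega(z, x) = \log\frac{q_\gamma(x\mid z)\,p_0(z)}{p_\theta(z\mid x)\,p(x)}. \qquad (6)$$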
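This will allow us to estimate $\mathcal{O}$ as
$$\hat{\mathcal{O}}(\omega,\gamma,\theta) = \frac{1}{N}\sum_{i=1}^{N} D_\omega(z^i, x_g^i) + \frac{1}{N}\sum_{i=1}^{N} d\left(z^i, F_\theta(x_g^i)\right), \qquad (7)$$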
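where $(z^i, x_g^i) \sim p_0(z)\,q_\gamma(x\mid z)$. In this equation, note that $x_g^i$ is a function of $\gamma$; although we suppress this in the notation, we do take this dependency into account in the algorithm. We use an auxiliary objective function to estimate $\omega$. As mentioned earlier, we omit the entropy term $-\mathbb{E}[\log p_0(z)]$ from $\hat{\mathcal{O}}$ as it is constant with respect to all parameters. In principle, any method for density ratio estimation could be used here. In this work, we use the logistic regression loss, much as in other methods for deep adversarial training, such as GANs, or for noise contrastive estimation. We will train $D_\omega$ to distinguish samples from the joint distribution $q_\gamma(x\mid z)\,p_0(z)$ from samples from $p_\theta(z\mid x)\,p(x)$. The objective function for this is
$$\mathcal{O}_{\text{LR}}(\omega,\gamma,\theta) = -\mathbb{E}_\gamma\left[\log\sigma\left(D_\omega(z, x)\right)\right] - \mathbb{E}_\theta\left[\log\left(1-\sigma\left(D_\omega(z, x)\right)\right)\right], \qquad (8)$$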
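where $\mathbb{E}_\gamma$ denotes expectation with respect to the joint distribution $q_\gamma(x\mid z)\,p_0(z)$ and $\mathbb{E}_\theta$ with respect to $p_\theta(z\mid x)\,p(x)$. We write $\hat{\mathcal{O}}_{\text{LR}}$ to indicate the Monte Carlo estimate of $\mathcal{O}_{\text{LR}}$. Our learning algorithm optimizes this pair of equations with respect to $\gamma$, $\omega$, and $\theta$ using stochastic gradient descent. In particular, the algorithm aims to find a simultaneous solution to $\min_\omega \hat{\mathcal{O}}_{\text{LR}}(\omega,\gamma,\theta)$ and $\min_{\theta,\gamma} \hat{\mathcal{O}}(\omega,\gamma,\theta)$. This training procedure is described in Algorithm 1. When this procedure converges, we will have that $\omega^* = \arg\min_\omega \mathcal{O}_{\text{LR}}(\omega,\gamma^*,\theta^*)$, which means that $D_{\omega^*}$ has converged to the likelihood ratio (6). Therefore $(\gamma^*, \theta^*)$ have also converged to a minimum of $\mathcal{O}$.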
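Why minimizing the logistic loss drives the discriminator toward the log density ratio in (6) can be seen in one dimension. The sketch below is a stand-in of our own, not the paper's setup: a scalar $t$ replaces the pair $(z, x)$, samples from $\mathcal{N}(1, 1)$ play the role of the generator's joint distribution and samples from $\mathcal{N}(0, 1)$ that of the reconstructor's, and the discriminator is linear, $D(t) = w t + b$, trained by full-batch gradient descent on the logistic loss. For these densities the true log ratio is $t - 1/2$, so $w$ should approach 1 and $b$ should approach $-1/2$.

```python
import math
import random

def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def fit_log_ratio(pos, neg, lr=0.5, steps=200):
    """Minimize -mean log sigma(D(pos)) - mean log(1 - sigma(D(neg))) over D(t) = w t + b."""
    w, b = 0.0, 0.0
    for _ in range(steps):
        gw = gb = 0.0
        for t in pos:                       # "generator" samples (class 1)
            r = 1.0 - sigmoid(w * t + b)    # d/dD of -log sigma(D) is -r
            gw -= r * t / len(pos)
            gb -= r / len(pos)
        for t in neg:                       # "reconstructor/data" samples (class 0)
            s = sigmoid(w * t + b)          # d/dD of -log(1 - sigma(D)) is s
            gw += s * t / len(neg)
            gb += s / len(neg)
        w -= lr * gw
        b -= lr * gb
    return w, b

rng = random.Random(0)
pos = [rng.gauss(1.0, 1.0) for _ in range(2000)]   # stand-in for q(x|z) p0(z)
neg = [rng.gauss(0.0, 1.0) for _ in range(2000)]   # stand-in for p_theta(z|x) p(x)
w, b = fit_log_ratio(pos, neg)
print(w, b)   # close to the true log-ratio coefficients 1 and -0.5
```

In the method itself, both the discriminator and the two distributions are neural networks over $(z, x)$, and the ratio estimate is refreshed at every step of Algorithm 1 rather than fitted to convergence.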
We also found that pre-training the reconstructor network on samples from $p(x)$ helps in some cases.
Relationships to Other Methods
An enormous amount of attention has been devoted recently to improved methods for GAN training, and we compare ourselves to the most closely related work in detail.
BiGAN and Adversarially Learned Inference (ALI) are two essentially identical recent adversarial methods for learning both a deep generative network $G_\gamma$ and a reconstructor network $F_\theta$. Likelihood-free variational inference (LFVI) extends this idea to a hierarchical Bayesian setting. Like VEEGAN, all of these methods use a discriminator on the joint space of $(z, x)$. However, the VEEGAN objective function provides significant benefits over the logistic regression loss over $\theta$ and $\gamma$ that is used in ALI/BiGAN, or the KL divergence used in LFVI.
In all of these methods, just as in vanilla GANs, the objective function depends on $\theta$ and $\gamma$ only via the output of the discriminator; therefore, if there is a mode of data space in which the discriminator is insensitive to changes in $\theta$ and $\gamma$, there will be mode collapse. In VEEGAN, by contrast, the reconstruction term does not depend on the discriminator, and so can provide a learning signal to $\gamma$ or $\theta$ even when the discriminator is constant. We will show in Section 5 that VEEGAN is indeed dramatically less prone to mode collapse than ALI.
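The contrast can be made concrete with a minibatch estimate of the form mean $D(z_i, G(z_i))$ + mean $d(z_i, F(G(z_i)))$, with $d$ the squared error, under hypothetical choices of ours: a constant discriminator, a one-parameter generator $G(z) = \gamma z$, and a fixed reconstructor $F(x) = x$. The discriminator term contributes exactly zero gradient with respect to $\gamma$, whereas the reconstruction term contributes $-2(1-\gamma)$ times the mean of $z^2$.

```python
import random

def constant_discriminator(z, x):
    return 0.3                       # ignores its inputs

def toy_generator(z, gamma):
    return gamma * z                 # hypothetical one-parameter generator

def toy_reconstructor(x):
    return x                         # hypothetical fixed reconstructor

def minibatch_objective(gamma, zs, use_reconstruction=True):
    """(1/N) sum D(z_i, x_i) + (1/N) sum (z_i - F(x_i))^2 with x_i = G(z_i)."""
    xs = [toy_generator(z, gamma) for z in zs]
    disc_term = sum(constant_discriminator(z, x) for z, x in zip(zs, xs)) / len(zs)
    if not use_reconstruction:
        return disc_term
    recon_term = sum((z - toy_reconstructor(x)) ** 2 for z, x in zip(zs, xs)) / len(zs)
    return disc_term + recon_term

def finite_diff_grad(gamma, zs, eps=1e-5, **kwargs):
    """Central finite difference with respect to the generator parameter gamma."""
    up = minibatch_objective(gamma + eps, zs, **kwargs)
    down = minibatch_objective(gamma - eps, zs, **kwargs)
    return (up - down) / (2.0 * eps)

rng = random.Random(1)
zs = [rng.gauss(0.0, 1.0) for _ in range(500)]
print(finite_diff_grad(0.5, zs, use_reconstruction=False))   # 0.0: no signal from D alone
print(finite_diff_grad(0.5, zs))                              # nonzero: about -mean(z^2)
```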
InfoGAN
Although it is motivated differently, by the goal of obtaining disentangled representations of the data, InfoGAN also uses a latent-code reconstruction penalty in its cost function. But unlike VEEGAN, InfoGAN reconstructs only a part of the latent code. Thus, InfoGAN is similar to VEEGAN in that it also includes an autoencoder over the latent codes, but the key difference is that InfoGAN does not also train the reconstructor network on the true data distribution. We suggest that this may be the reason that InfoGAN was observed to require some of the same stabilization tricks as vanilla GANs, which are not required for VEEGAN.
Adversarial Methods for Autoencoders
Experiments
Quantitative evaluation of GANs is problematic because implicit distributions do not have a tractable likelihood term to quantify generative accuracy. Quantifying mode collapse is also not straightforward, except in the case of synthetic data with known modes. For this reason, several indirect metrics have recently been proposed to evaluate GANs specifically for their mode collapsing behavior. However, none of these metrics is reliable on its own, and therefore we need to compare across a number of different metrics. In this section we evaluate VEEGAN on several synthetic and real datasets and compare its performance against vanilla GANs, Unrolled GAN, and ALI on five different metrics. Our results strongly suggest that VEEGAN does indeed resolve mode collapse in GANs to a large extent. Generally, we found that VEEGAN performed well with default hyperparameter values, so we did not tune these. Full details are provided in the supplementary material.
Mode collapse can be accurately measured on synthetic datasets, since the true distribution and its modes are known. In this section we compare all four competing methods on three synthetic datasets of increasing difficulty: a mixture of eight 2D Gaussian distributions arranged in a ring, a mixture of twenty-five 2D Gaussian distributions arranged in a grid (this experiment follows prior work; note that for certain settings of parameters, vanilla GAN can also recover all 25 modes, as was pointed out to us by Paulina Grnarova), and a mixture of ten 700-dimensional Gaussian distributions embedded in a 1200-dimensional space. This mixture arrangement was chosen to mimic the higher-dimensional manifolds of natural images. All of the mixture components were isotropic Gaussians. For a fair comparison of the different learning methods for GANs, we use the same network architectures for the reconstructors and the generators for all methods, namely, fully connected MLPs with two hidden layers. For the discriminator we use a two-layer MLP without dropout or normalization layers. VEEGAN works for both deterministic and stochastic generator networks. To allow the generator to be a stochastic map, we add an extra dimension of noise to the generator input that is not reconstructed.
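The two 2D mixtures can be generated as in the sketch below. The ring radius, grid spacing, and component standard deviations are not stated in this section, so the defaults used here (radius 2, grid coordinates in {-4, -2, 0, 2, 4}, standard deviations 0.02 and 0.05) are placeholders of our own rather than the paper's settings; the function names are hypothetical.

```python
import math
import random

def ring_means(n_modes=8, radius=2.0):
    """Means of n_modes components evenly spaced on a circle (radius is a placeholder)."""
    return [(radius * math.cos(2 * math.pi * k / n_modes),
             radius * math.sin(2 * math.pi * k / n_modes)) for k in range(n_modes)]

def grid_means(side=5, spacing=2.0):
    """Means of side x side components on a centred square grid (spacing is a placeholder)."""
    offset = (side - 1) / 2.0
    return [((i - offset) * spacing, (j - offset) * spacing)
            for i in range(side) for j in range(side)]

def sample_mixture(means, std, n, rng):
    """Draw n points from an equal-weight mixture of isotropic 2D Gaussians."""
    points = []
    for _ in range(n):
        mx, my = rng.choice(means)
        points.append((rng.gauss(mx, std), rng.gauss(my, std)))
    return points

rng = random.Random(0)
ring = sample_mixture(ring_means(), std=0.02, n=1000, rng=rng)
grid = sample_mixture(grid_means(), std=0.05, n=1000, rng=rng)
print(len(ring_means()), len(grid_means()))   # 8 25
```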
To quantify the mode collapsing behavior we report two metrics. We sample points from the generator network and count a sample as high quality if it is within three standard deviations of the nearest mode for the 2D datasets, or within 10 standard deviations of the nearest mode for the 1200D dataset. We then report the number of modes captured as the number of mixture components whose mean is nearest to at least one high-quality sample. We also report the percentage of high-quality samples as a measure of sample quality. We generate samples from each trained model and average the numbers over five runs. For the unrolled GAN, we set the number of unrolling steps to five, as suggested in the authors' reference implementation.
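The two metrics can be computed as in the sketch below (the function name is hypothetical). For isotropic components, "within $k$ standard deviations of the nearest mode" is taken here to mean a Euclidean distance of at most $k\sigma$ to the closest component mean; `math.dist` requires Python 3.8+.

```python
import math

def mode_metrics(samples, means, std, k=3.0):
    """Return (number of modes captured, percentage of high-quality samples).

    A sample is high quality if it lies within k standard deviations of its nearest
    mode (k = 3 for the 2D data, 10 for the 1200D data); a mode counts as captured
    if it is the nearest mode of at least one high-quality sample.
    """
    captured = set()
    n_high_quality = 0
    for s in samples:
        dists = [math.dist(s, m) for m in means]
        nearest = min(range(len(means)), key=dists.__getitem__)
        if dists[nearest] <= k * std:
            n_high_quality += 1
            captured.add(nearest)
    return len(captured), 100.0 * n_high_quality / len(samples)

means = [(0.0, 0.0), (10.0, 0.0)]
samples = [(0.0, 0.0), (0.1, 0.0), (5.0, 5.0)]   # two near mode 0, one far from both
print(mode_metrics(samples, means, std=0.1))      # one mode captured; 2 of 3 samples high quality
```

Reporting both numbers together guards against the two failure modes discussed next: a collapsed generator scores high on quality but low on modes, while a dispersed one can touch many modes with few high-quality samples.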
As shown in Table 1, VEEGAN captures the greatest number of modes on all the synthetic datasets, while consistently generating higher-quality samples. This is visually apparent in Figure 2, which plots the generator distributions for each method; the generators learned by VEEGAN are sharper and closer to the true distribution. This figure also shows why it is important to measure sample quality and mode collapse simultaneously, as either alone can be misleading. For instance, the GAN on the 2D ring has high sample quality, but only because it collapses all of its samples onto one mode (Figure 2(b)). At the other extreme, the unrolled GAN on the 2D grid captures almost all the modes in the true distribution, but only because it generates highly dispersed samples (Figure 2(i)) that do not accurately represent the true distribution, hence its low sample quality. All methods had approximately the same running time, except for unrolled GAN, which is a few orders of magnitude slower due to the unrolling overhead.
2 Stacked MNIST
Following prior work, we evaluate our methods on the stacked MNIST dataset, a variant of the MNIST data specifically designed to increase the number of discrete modes. The data is synthesized by stacking three randomly sampled MNIST digits along the color channel, resulting in a 28x28x3 image. We now expect 1000 modes in this data set, corresponding to the number of possible triples of digits.
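The construction can be sketched with plain nested lists as follows; MNIST loading is omitted, and the function name and the mapping from a digit triple to an index in 0..999 are our own conventions for counting the $10^3 = 1000$ modes.

```python
def stack_digits(images, labels):
    """Stack three 28x28 grayscale digits along a channel axis.

    images: three 28x28 nested lists (rows of pixel intensities).
    labels: the three digit labels, each in 0..9.
    Returns (28x28x3 nested list, mode index in 0..999).
    """
    assert len(images) == 3 and len(labels) == 3
    stacked = [[[images[c][r][col] for c in range(3)] for col in range(28)]
               for r in range(28)]
    mode = 100 * labels[0] + 10 * labels[1] + labels[2]   # one of 10**3 = 1000 modes
    return stacked, mode

# Hypothetical blank images, only to show the shapes involved.
blank = [[0.0] * 28 for _ in range(28)]
img, mode = stack_digits([blank, blank, blank], [3, 0, 7])
print(len(img), len(img[0]), len(img[0][0]), mode)   # 28 28 3 307
```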
Again, to focus the evaluation on the difference in the learning algorithms, we use the same generator architecture for all methods. In particular, the generator architecture is an off-the-shelf standard implementation of DCGAN (https://github.com/carpedm20/DCGAN-tensorflow).
For Unrolled GAN, we used a standard implementation of the DCGAN discriminator network. For ALI and VEEGAN, the discriminator architecture is described in the supplementary material. For the reconstructor in ALI and VEEGAN, we use a simple two-layer MLP without any regularization layers.
Finally, for VEEGAN we pretrain the reconstructor by taking a few stochastic gradient steps with respect to $\theta$ before running Algorithm 1. For all methods other than VEEGAN, we use the enhanced generator loss function suggested in prior work, since we were not able to get sufficient learning signals for the generator without it. VEEGAN did not require this adjustment for successful training.
As the true locations of the modes in this data are unknown, the number of modes is estimated using a trained classifier, as described in prior work. We drew the same total number of samples from every model, and the results are averaged over five runs. As a measure of quality, again following prior work, we also report the KL divergence between the generator distribution and the data distribution. As reported in Table 2, VEEGAN not only captures the most modes, but also consistently matches the data distribution more closely than any other method. Generated samples from each of the models are shown in the supplementary material.
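Given the digit triples that a pretrained classifier predicts for each generated image, both numbers can be computed as sketched below. Two assumptions are ours and are not stated in this section: the data distribution over triples is taken to be uniform over all 1000 triples, and the KL divergence is computed as KL(generated ‖ data); the function name is hypothetical.

```python
import math
from collections import Counter

def stacked_mnist_mode_stats(predicted_triples, n_modes=1000):
    """Count covered modes and KL(generated || uniform) over digit triples.

    predicted_triples: (d0, d1, d2) digit predictions, one per generated image.
    """
    counts = Counter(100 * a + 10 * b + c for a, b, c in predicted_triples)
    total = sum(counts.values())
    # Modes with zero generated mass contribute 0 to KL(generated || uniform).
    kl = sum((k / total) * math.log((k / total) * n_modes) for k in counts.values())
    return len(counts), kl

collapsed = [(1, 2, 3)] * 500                                        # one mode only
spread = [(i // 100, (i // 10) % 10, i % 10) for i in range(1000)]  # one sample per mode
print(stacked_mnist_mode_stats(collapsed))   # 1 mode, KL = log(1000)
print(stacked_mnist_mode_stats(spread))      # 1000 modes, KL = 0
```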
3 CIFAR
Finally, we evaluate the learning methods on the CIFAR-10 dataset, a well-studied and diverse dataset of natural images. We use the same discriminator, generator, and reconstructor architectures as in the previous section. However, the previous mode collapsing metric is inappropriate here, owing to CIFAR’s greater diversity. Even within one of the 10 classes of CIFAR, the intra-group diversity is very high compared to any of the 10 classes of MNIST. Therefore, for CIFAR it is inappropriate to assume, as the metrics of the previous subsection do, that each labelled class corresponds to a single mode of the data distribution.
As shown in Table 2, ALI and VEEGAN achieve the best IvOM. Qualitatively, however, generated samples from VEEGAN seem better than those from other methods. In particular, the samples from VEEGAN+DAE are meaningless. Generated samples from VEEGAN are shown in Figure 3(b); samples from other methods are shown in the supplementary material. As another illustration of this, Figure 3 illustrates the IvOM metric by showing the nearest neighbors to real images that each of the GANs was able to generate; in general, the nearest neighbors will be more semantically meaningful than randomly generated images. We omit VEEGAN+DAE from this table because it did not produce plausible images. Across the methods, we see in Figure 3 that VEEGAN captures small details, such as the face of the poodle, that other methods miss.
Conclusion
We have presented VEEGAN, a new training principle for GANs that combines a KL divergence in the joint space of representation and data points with an autoencoder over the representation space, motivated by a variational argument. Experimental results on synthetic data and real images show that our approach is much more effective than several state-of-the-art GAN methods at avoiding mode collapse while still generating good quality samples.
Acknowledgement
We thank Martin Arjovsky, Nicolas Collignon, Luke Metz, Casper Kaae Sønderby, Lucas Theis, Soumith Chintala, Stanisław Jastrzębski, Harrison Edwards, Amos Storkey and Paulina Grnarova for their helpful comments. We would like to specially thank Ferenc Huszár for insightful discussions and feedback.