Training generative neural networks via Maximum Mean Discrepancy optimization
Gintare Karolina Dziugaite, Daniel M. Roy, Zoubin Ghahramani
Introduction
In this paper, we consider the problem of learning generative models from i.i.d. data with unknown distribution $P$. We formulate the learning problem as one of finding a function $G$, called the generator, such that, given an input $W$ drawn from some fixed noise distribution $\mathcal N$, the distribution of the output $G(W)$ is close to the data's distribution $P$. Note that, given $G$ and $\mathcal N$, we can easily generate new samples despite not having an explicit representation for the underlying density.
We are particularly interested in the case where the generator is a deep neural network whose parameters we must learn. Rather than being used to classify or predict, these networks transport input randomness to output randomness, thus inducing a distribution. The first direct instantiation of this idea is due to MacKay, although MacKay draws connections even further back to earlier work on autoencoders, suggesting that generators can be understood as decoders. MacKay's proposal, called density networks, uses multilayer perceptrons (MLPs) as generators and learns the parameters by approximating Bayesian inference.
Since MacKay's proposal, there has been a great deal of progress on learning generative models, especially over high-dimensional spaces like images. Some of the most successful approaches have been based on restricted Boltzmann machines and deep Boltzmann networks. A recent example is the Neural Autoregressive Density Estimator (NADE). An in-depth survey, however, is beyond the scope of this article.
This work builds on a proposal due to Goodfellow et al. Their adversarial nets framework takes an indirect approach to learning deep generative neural networks: a discriminator network is trained to recognize the difference between training data and generated samples, while the generator is trained to confuse the discriminator. The resulting two-player game is cast as a minimax optimization of a differentiable objective and solved greedily by iteratively performing gradient descent steps to improve the generator and then the discriminator.
Given the greedy nature of the algorithm, Goodfellow et al. give a careful prescription for balancing the training of the generator and the discriminator. In particular, two gradient steps are taken on the discriminator's parameters for every gradient step on the generator's parameters. It is not clear at this point how sensitive this balance is as the data set and network vary. In this paper, we describe an approximation to adversarial learning that replaces the adversary with a closed-form nonparametric two-sample test statistic based on the Maximum Mean Discrepancy (MMD), which we adopt from the kernel two-sample test of Gretton et al. We call our proposal MMD nets. (In independent work reported in a recent preprint, Li, Swersky, and Zemel also propose to use MMD as a training objective for generative neural networks; we leave a comparison to future work.) We give bounds on the estimation error incurred by optimizing an empirical estimator rather than the true population MMD, and give some illustrations on synthetic and real data.
Learning to sample as optimization
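For a given family $\{G_\theta\}$ of generators $\mathcal W \to \mathcal X$, mapping the noise space to the data space, we cast the problem of learning a generative model as the optimization
$$\arg\min_\theta \, \delta(P, G_\theta^{\#}\mathcal N), \qquad (1)$$
where $\delta$ is some measure of discrepancy and $G_\theta^{\#}\mathcal N$ is the distribution of $G_\theta(W)$ when $W \sim \mathcal N$. In practice, we only have i.i.d. samples $X_1, X_2, \dots$ from $P$, and so we optimize an empirical estimate of $\delta(P, G_\theta^{\#}\mathcal N)$.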
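Adversarial nets fit into this framework: let $\{D_\phi\}$ be a family of functions $\mathcal X \to [0,1]$, called discriminators, and take the discrepancy
$$\delta_{\mathrm{AN}}(P, G_\theta^{\#}\mathcal N) = \max_\phi \, \mathbb E\big[\log D_\phi(X) + \log(1 - D_\phi(Y))\big],$$
where $X \sim P$ and $Y \sim G_\theta^{\#}\mathcal N$. In this case, Eq. 1 becomes
$$\min_\theta \max_\phi \, V(G_\theta, D_\phi), \qquad V(G_\theta, D_\phi) = \mathbb E\big[\log D_\phi(X) + \log(1 - D_\phi(G_\theta(W)))\big],$$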
for $X \sim P$ and $W \sim \mathcal N$. The output $D_\phi(x)$ of the discriminator can be interpreted as the probability it assigns to its input $x$ being drawn from $P$, and so $-\tfrac12 V(G_\theta, D_\phi)$ is the expected log loss incurred when classifying the origin of a point equally likely to have been drawn from $P$ or $G_\theta^{\#}\mathcal N$. Therefore, optimizing $\phi$ maximizes the probability of distinguishing samples from $P$ and $G_\theta^{\#}\mathcal N$. Assuming that the optimal discriminator exists for every $\theta$, the optimal generator $G_\theta$ is that whose output distribution is closest to $P$, as measured by the Jensen–Shannon divergence, which is minimized when $G_\theta^{\#}\mathcal N = P$.
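As a quick numerical check of the Jensen–Shannon claim, the sketch below works on a hypothetical four-point sample space instead of with neural networks. For fixed distributions `p` (standing in for $P$) and `q` (standing in for the generator's distribution), the discriminator $D^*(x) = p(x)/(p(x)+q(x))$ maximizes $V$, and the maximum equals $2\,\mathrm{JS}(p, q) - \log 4$. The distributions and function names are illustrative assumptions.

```python
import numpy as np

def value_function(p, q, d):
    """V = E_{x~p}[log d(x)] + E_{y~q}[log(1 - d(y))] on a finite sample space."""
    return np.sum(p * np.log(d)) + np.sum(q * np.log1p(-d))

def kl(a, b):
    """Kullback-Leibler divergence KL(a || b) in nats, for strictly positive vectors."""
    return np.sum(a * np.log(a / b))

def jensen_shannon(p, q):
    """Jensen-Shannon divergence in nats."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Hypothetical distributions on a four-point space.
p = np.array([0.1, 0.2, 0.3, 0.4])      # plays the role of the data distribution P
q = np.array([0.25, 0.25, 0.25, 0.25])  # plays the role of the generator's distribution

d_star = p / (p + q)                          # optimal discriminator for this (p, q) pair
print(value_function(p, q, d_star))           # maximum of V over discriminators
print(2 * jensen_shannon(p, q) - np.log(4))   # the same number
```

When `q` equals `p`, the optimal discriminator is the constant 1/2 and $V = -\log 4$, the smallest possible value of the inner maximum.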
In Goodfellow et al., the generators $G_\theta$ and discriminators $D_\phi$ are chosen to be multilayer perceptrons (MLPs). In order to find a minimax solution, they propose taking alternating gradient steps along $\nabla_\theta V$ and $\nabla_\phi V$. Note that the composition $D_\phi \circ G_\theta$ that appears in the value function is yet another (larger) MLP. This fact permits the use of the backpropagation algorithm to take gradient steps.
MMD as an adversary
In their paper introducing adversarial nets, Goodfellow et al. remark that a balance must be struck between optimizing the generator and optimizing the discriminator. In particular, the authors suggest $k$ maximization steps for every one minimization step to ensure that $D_\phi$ is well synchronized with $G_\theta$ during training. A large value for $k$, however, can lead to overfitting. In their experiments, for every step taken along the gradient with respect to $\theta$, they take two gradient steps with respect to $\phi$ to bring $D_\phi$ closer to the desired optimum (Goodfellow, pers. comm.).
It is unclear how sensitive this balance is. Regardless, while adversarial networks deliver impressive sampling performance, the optimization takes approximately 7.5 hours to train on the MNIST dataset running on an NVIDIA GeForce GTX TITAN GPU with 6 GB of RAM. Can we speed up the process with a more tractable choice of adversary?
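Our proposal replaces the adversary with the maximum mean discrepancy. For a class $\mathcal F$ of functions $\mathcal X \to \mathbb R$, define
$$\delta_{\mathrm{MMD}_{\mathcal F}}(P, G_\theta^{\#}\mathcal N) = \sup_{f \in \mathcal F} \big(\mathbb E[f(X)] - \mathbb E[f(Y)]\big), \qquad (2)$$
where $X \sim P$ and $Y \sim G_\theta^{\#}\mathcal N$. See Fig. 1 for a comparison of the architectures of adversarial and MMD nets.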
While Eq. 2 involves a maximization over a family of functions, Gretton et al. show that it can be solved in closed form when $\mathcal F$ is the unit ball of a reproducing kernel Hilbert space (RKHS).
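More carefully, let $\mathcal H$ be an RKHS of real-valued functions on $\mathcal X$ and let $\langle\cdot,\cdot\rangle_{\mathcal H}$ denote its inner product. Because point evaluation is a bounded linear functional on $\mathcal H$, the Riesz representation theorem yields a reproducing kernel $k$ such that every $f \in \mathcal H$ can be expressed as
$$f(x) = \langle f, k(\cdot, x)\rangle_{\mathcal H}, \qquad x \in \mathcal X. \qquad (3)$$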
The functions induced by a kernel $k$ are those functions in the closure of the span of the set $\{k(\cdot, x) : x \in \mathcal X\}$, which is necessarily an RKHS. Note that for every positive definite kernel $k$ there is a unique RKHS $\mathcal H$ such that every function in $\mathcal H$ satisfies Eq. 3.
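If $\mathcal F$ is chosen to be the unit ball of an RKHS $\mathcal H$, then
$$\delta_{\mathrm{MMD}_{\mathcal F}}^2(P, G_\theta^{\#}\mathcal N) = \|\mu_X - \mu_Y\|_{\mathcal H}^2,$$
where $\mu_X \in \mathcal H$ is the mean embedding of $P$, given by
$$\mu_X = \mathbb E[k(\cdot, X)],$$
and satisfying, for all $f \in \mathcal H$,
$$\mathbb E[f(X)] = \langle f, \mu_X\rangle_{\mathcal H},$$
and $\mu_Y$ is defined analogously for $G_\theta^{\#}\mathcal N$.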
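The closed form follows in two steps. First, for $f$ in the unit ball, $\mathbb E[f(X)] - \mathbb E[f(Y)] = \langle f, \mu_X - \mu_Y\rangle_{\mathcal H}$, and by Cauchy–Schwarz the supremum is attained at $f = (\mu_X - \mu_Y)/\|\mu_X - \mu_Y\|_{\mathcal H}$, giving $\|\mu_X - \mu_Y\|_{\mathcal H}$. Second, writing $\mu_X, \mu_Y$ for the mean embeddings of $P$ and $G_\theta^{\#}\mathcal N$ and $X', Y'$ for independent copies of $X \sim P$ and $Y \sim G_\theta^{\#}\mathcal N$, expanding the squared norm and applying the mean-embedding and reproducing properties gives:

$$
\begin{aligned}
\|\mu_X - \mu_Y\|_{\mathcal H}^2
&= \langle \mu_X, \mu_X\rangle_{\mathcal H} - 2\langle \mu_X, \mu_Y\rangle_{\mathcal H} + \langle \mu_Y, \mu_Y\rangle_{\mathcal H} \\
&= \mathbb E[\mu_X(X')] - 2\,\mathbb E[\mu_X(Y)] + \mathbb E[\mu_Y(Y')] \\
&= \mathbb E[k(X, X')] - 2\,\mathbb E[k(X, Y)] + \mathbb E[k(Y, Y')].
\end{aligned}
$$

Every term is an expectation of the kernel over pairs of points, which is what makes sample-based estimates of the MMD straightforward.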
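In practice, we often do not have access to $P$ or $G_\theta^{\#}\mathcal N$. Instead, we are given independent i.i.d. samples $X_1, \dots, X_N$ and $Y_1, \dots, Y_M$ from $P$ and $G_\theta^{\#}\mathcal N$, respectively, and would like to estimate the MMD. Gretton et al. showed that
$$\mathrm{MMD}_U^2 = \frac{1}{N(N-1)} \sum_{i \ne i'} k(X_i, X_{i'}) - \frac{2}{NM} \sum_{i=1}^{N} \sum_{j=1}^{M} k(X_i, Y_j) + \frac{1}{M(M-1)} \sum_{j \ne j'} k(Y_j, Y_{j'})$$
is an unbiased estimator of $\delta_{\mathrm{MMD}_{\mathcal F}}^2(P, G_\theta^{\#}\mathcal N)$.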
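A minimal NumPy sketch of this unbiased estimator, assuming the Gaussian RBF kernel $k(x, y) = \exp(-\|x - y\|^2 / (2h^2))$ with a hypothetical bandwidth argument `bandwidth`; the function names are illustrative and not taken from any released code.

```python
import numpy as np

def rbf_gram(A, B, bandwidth):
    """Gram matrix with entries exp(-||a_i - b_j||^2 / (2 * bandwidth^2))."""
    sq = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2.0 * bandwidth ** 2))

def mmd2_unbiased(X, Y, bandwidth=1.0):
    """Unbiased estimate of MMD^2 from samples X (N x d) and Y (M x d), with N, M >= 2."""
    N, M = len(X), len(Y)
    Kxx = rbf_gram(X, X, bandwidth)
    Kyy = rbf_gram(Y, Y, bandwidth)
    Kxy = rbf_gram(X, Y, bandwidth)
    # Exclude the i == i' and j == j' terms so the within-sample averages are unbiased.
    xx = (Kxx.sum() - np.trace(Kxx)) / (N * (N - 1))
    yy = (Kyy.sum() - np.trace(Kyy)) / (M * (M - 1))
    xy = Kxy.sum() / (N * M)
    return xx - 2.0 * xy + yy

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 2))
Y_same = rng.standard_normal((500, 2))         # same distribution as X
Y_shift = rng.standard_normal((500, 2)) + 1.0  # mean shifted by 1 in each coordinate
print(mmd2_unbiased(X, Y_same), mmd2_unbiased(X, Y_shift))
```

The first value should be close to zero and may even be slightly negative, since the estimator is unbiased rather than nonnegative; the second is clearly positive.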
MMD Nets
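Given data $X_1, \dots, X_N$ from $P$ and noise inputs $W_1, \dots, W_M$ drawn i.i.d. from $\mathcal N$, write $Y_j = G_\theta(W_j)$. MMD nets choose $\theta$ to minimize
$$C_{\mathrm{MMD}}(\theta) = \frac{1}{M(M-1)} \sum_{j \ne j'} k(Y_j, Y_{j'}) - \frac{2}{MN} \sum_{i=1}^{N} \sum_{j=1}^{M} k(X_i, Y_j).$$
Note that $C_{\mathrm{MMD}}$ consists of only those parts of the unbiased estimator that depend on $\theta$.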
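In practice, the minimization is solved by gradient descent, possibly on subsets of the data. More carefully, the chain rule gives us
$$\frac{\partial C_{\mathrm{MMD}}}{\partial \theta} = \frac{1}{M(M-1)} \sum_{j \ne j'} \frac{\partial k(Y_j, Y_{j'})}{\partial \theta} - \frac{2}{MN} \sum_{i=1}^{N} \sum_{j=1}^{M} \frac{\partial k(X_i, Y_j)}{\partial \theta}, \qquad \frac{\partial k(X_i, Y_j)}{\partial \theta} = \frac{\partial k(X_i, y)}{\partial y}\Big|_{y = Y_j} \frac{\partial G_\theta(W_j)}{\partial \theta},$$
and analogously for $k(Y_j, Y_{j'})$, in which both arguments depend on $\theta$.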
Each derivative $\partial k(x, y)/\partial y$ is easily computed for standard kernels like the RBF kernel. Our gradient depends on the partial derivatives of the generator with respect to its parameters, which we can compute using backpropagation.
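The gradient computation can be sketched in NumPy for the RBF kernel $k(x, y) = \exp(-\|x-y\|^2/(2h^2))$, for which $\partial k(x, y)/\partial y = (x - y)\,k(x, y)/h^2$. To keep the example self-contained, the generator is a hypothetical linear map $G_\theta(w) = Aw + b$ with $\theta = (A, b)$, standing in for an MLP whose Jacobian backpropagation would supply; all names are illustrative.

```python
import numpy as np

def rbf_gram(A, B, h):
    """k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 h^2)) for all pairs of rows."""
    return np.exp(-np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1) / (2.0 * h ** 2))

def mmd_objective(X, Y, h):
    """C_MMD: the terms of the unbiased MMD^2 estimate that depend on the generated points Y."""
    N, M = len(X), len(Y)
    Kyy, Kxy = rbf_gram(Y, Y, h), rbf_gram(X, Y, h)
    return (Kyy.sum() - np.trace(Kyy)) / (M * (M - 1)) - 2.0 * Kxy.sum() / (M * N)

def mmd_objective_grad_Y(X, Y, h):
    """dC_MMD/dY_j for every generated point, returned as an (M x d) array."""
    N, M = len(X), len(Y)
    Kyy, Kxy = rbf_gram(Y, Y, h), rbf_gram(X, Y, h)
    # sum_{j'} k(Y_j, Y_j') (Y_j - Y_j'); the j' = j term is zero automatically.
    pull_yy = Y * Kyy.sum(axis=1, keepdims=True) - Kyy @ Y
    # sum_i k(X_i, Y_j) (Y_j - X_i)
    pull_xy = Y * Kxy.sum(axis=0)[:, None] - Kxy.T @ X
    # The factor 2 in the first term arises because Y_j appears in both slots of k(Y_j, Y_j').
    return -2.0 / (M * (M - 1) * h ** 2) * pull_yy + 2.0 / (M * N * h ** 2) * pull_xy

def linear_generator(W, A, b):
    """Hypothetical generator G_theta(w) = A w + b with theta = (A, b), for illustration only."""
    return W @ A.T + b

def grad_theta(X, W, A, b, h):
    """Chain rule: dC/dtheta = sum_j (dY_j/dtheta)^T dC/dY_j."""
    G = mmd_objective_grad_Y(X, linear_generator(W, A, b), h)
    return G.T @ W, G.sum(axis=0)  # dC/dA (d_out x d_in), dC/db (d_out,)

rng = np.random.default_rng(1)
X = rng.standard_normal((50, 2)) + 1.0   # toy "data"
W = rng.standard_normal((40, 3))         # noise inputs
A, b = 0.5 * rng.standard_normal((2, 3)), np.zeros(2)
gA, gb = grad_theta(X, W, A, b, h=1.0)
```

Splitting the gradient into `mmd_objective_grad_Y` (kernel-specific) and the generator Jacobian mirrors the chain rule above: only the second factor changes when the linear map is replaced by a deep network.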
Generalization bounds for MMD
MMD nets operate by minimizing an empirical estimate of the MMD. This estimate is subject to Monte Carlo error, and so the network weights (parameters) that minimize the empirical MMD may do a poor job of minimizing the exact population MMD. We show that, for sufficiently large data sets, this estimation error is small with high probability, despite the space of parameters being continuous and high dimensional.
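Let $\Theta$ denote the space of possible parameters for the generator $G_\theta$, let $\mathcal N$ be the distribution on $\mathcal W$ for the noisy inputs, and let $P_\theta = G_\theta^{\#}\mathcal N$ be the distribution of $G_\theta(W)$ when $W \sim \mathcal N$, for $\theta \in \Theta$. Writing $X = (X_1, \dots, X_N)$ and $W = (W_1, \dots, W_M)$, let $\hat\theta$ be the value optimizing the unbiased empirical MMD estimate, i.e.,
$$\mathrm{MMD}_U^2(X, G_{\hat\theta}(W)) = \inf_{\theta \in \Theta} \mathrm{MMD}_U^2(X, G_\theta(W)),$$
and let $\theta^*$ be the value optimizing the population MMD, i.e.,
$$\mathrm{MMD}^2(P, P_{\theta^*}) = \inf_{\theta \in \Theta} \mathrm{MMD}^2(P, P_\theta).$$
We are interested in bounding the difference
$$\mathrm{MMD}^2(P, P_{\hat\theta}) - \mathrm{MMD}^2(P, P_{\theta^*}).$$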
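One standard way to see why a uniform concentration bound suffices: let $P_\theta$ denote the generator's output distribution, let $\hat\theta$ and $\theta^*$ be the empirical and population minimizers, and write $\mathrm{MMD}_U^2(\theta)$ as shorthand for the unbiased estimate computed from the data and $G_\theta(W_1), \dots, G_\theta(W_M)$. Adding and subtracting the empirical objective at both parameter values gives

$$
\begin{aligned}
\mathrm{MMD}^2(P, P_{\hat\theta}) - \mathrm{MMD}^2(P, P_{\theta^*})
&= \big[\mathrm{MMD}^2(P, P_{\hat\theta}) - \mathrm{MMD}_U^2(\hat\theta)\big]
 + \big[\mathrm{MMD}_U^2(\hat\theta) - \mathrm{MMD}_U^2(\theta^*)\big]
 + \big[\mathrm{MMD}_U^2(\theta^*) - \mathrm{MMD}^2(P, P_{\theta^*})\big] \\
&\le 2 \sup_{\theta \in \Theta} \big| \mathrm{MMD}_U^2(\theta) - \mathrm{MMD}^2(P, P_\theta) \big|,
\end{aligned}
$$

because the middle bracket is nonpositive by the definition of $\hat\theta$. It therefore suffices to show that the empirical estimate concentrates around the population MMD uniformly over $\Theta$.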
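To that end, for a measure space $\mathcal X$, write $L_\infty(\mathcal X)$ for the space of essentially bounded functions on $\mathcal X$ and write $B(L_\infty(\mathcal X))$ for the unit ball under the sup norm, i.e.,
$$B(L_\infty(\mathcal X)) = \{f \in L_\infty(\mathcal X) : \|f\|_\infty \le 1\}.$$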
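The bounds we obtain will depend on a notion of complexity captured by the fat-shattering dimension. Let $\mathcal F \subseteq B(L_\infty(\mathcal X))$. A set $\{x_1, \dots, x_n\} \subseteq \mathcal X$ is said to be $\epsilon$-shattered by $\mathcal F$ if there exist $r_1, \dots, r_n \in \mathbb R$ such that, for every $b \in \{0,1\}^n$, there is an $f \in \mathcal F$ with $f(x_i) \ge r_i + \epsilon$ whenever $b_i = 1$ and $f(x_i) \le r_i - \epsilon$ whenever $b_i = 0$. For every $\epsilon > 0$, the fat-shattering dimension of $\mathcal F$, written $\mathrm{fat}_\epsilon(\mathcal F)$, is defined as
$$\mathrm{fat}_\epsilon(\mathcal F) = \sup\{\, n : \text{some } \{x_1, \dots, x_n\} \subseteq \mathcal X \text{ is } \epsilon\text{-shattered by } \mathcal F \,\}.$$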
We then have the following bound on the estimation error:
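Assume the kernel is bounded by one. Define
$$\mathcal G_+ = \{\, k(G_\theta(w), G_\theta(\cdot)) : w \in \mathcal W,\ \theta \in \Theta \,\}, \qquad \mathcal G_- = \{\, k(x, G_\theta(\cdot)) : x \in \mathcal X,\ \theta \in \Theta \,\},$$
and assume there exist $\gamma_1, \gamma_2 > 1$ and $p_1, p_2 \in \mathbb N$ such that, for all $\epsilon > 0$, $\mathrm{fat}_\epsilon(\mathcal G_+) \le \gamma_1 \epsilon^{-p_1}$ and $\mathrm{fat}_\epsilon(\mathcal G_-) \le \gamma_2 \epsilon^{-p_2}$. Then, with probability at least $1 - \delta$,
$$\mathrm{MMD}^2(P, P_{\hat\theta}) < \mathrm{MMD}^2(P, P_{\theta^*}) + \epsilon,$$
where $\epsilon$ is the sum of a rate term $r(p_1, \gamma_1, \cdot)$ for $\mathcal G_+$, a rate term $r(p_2, \gamma_2, \cdot)$ for $\mathcal G_-$, and a confidence term of order $M^{-1/2}\sqrt{\log(1/\delta)}$. Each rate $r(p, \gamma, M)$ scales as $C_p \sqrt{\gamma}\, M^{-1/2}$ when $p < 2$ (with an additional logarithmic factor in $M$ when $p = 2$) and as $C_p \sqrt{\gamma}\, M^{-1/p}$ when $p > 2$, for constants $C_{p_1}$ and $C_{p_2}$ depending on $p_1$ and $p_2$ alone.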
Empirical evaluation
In this section, we demonstrate the approach on an illustrative synthetic example as well as the standard MNIST digits and Toronto Face Dataset (TFD) benchmarks. We show that MMD-based optimization of the generator rapidly delivers a generator that performs well in maximizing the density of a held-out test set under a kernel-density estimator.
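The evaluation metric used throughout this section is the mean log density of held-out data under a kernel density estimator. A minimal sketch, assuming an isotropic Gaussian (Parzen window) KDE centred on generated samples with a bandwidth `sigma` chosen separately (e.g., on a validation set); the function name and toy arrays are hypothetical, and this is not the evaluation code behind the reported numbers.

```python
import numpy as np

def parzen_mean_log_density(test, centers, sigma):
    """Mean log density of the rows of `test` under an isotropic Gaussian KDE.

    Each row of `centers` (e.g. a generated sample) contributes one Gaussian
    bump with standard deviation `sigma` in every coordinate.
    """
    n, d = centers.shape
    sq = np.sum((test[:, None, :] - centers[None, :, :]) ** 2, axis=-1)  # (T, n)
    a = -sq / (2.0 * sigma ** 2)
    a_max = a.max(axis=1, keepdims=True)
    log_sum = a_max[:, 0] + np.log(np.exp(a - a_max).sum(axis=1))  # stable log-sum-exp
    log_norm = np.log(n) + 0.5 * d * np.log(2.0 * np.pi * sigma ** 2)
    return float(np.mean(log_sum - log_norm))

rng = np.random.default_rng(0)
generated = rng.standard_normal((1000, 2))  # stand-in for generator output
held_out = rng.standard_normal((200, 2))    # stand-in for held-out test data
print(parzen_mean_log_density(held_out, generated, sigma=0.3))
```

The log-sum-exp shift avoids underflow when test points are far from every center, which becomes the common case in high dimensions.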
Under an RBF kernel and a Gaussian generator with parameters $\theta = \{\mu, \sigma\}$, it is straightforward to find the gradient of $C_{\mathrm{MMD}}$ by applying the chain rule. Using fixed random standard normal numbers $W_1, \dots, W_M$, we have $Y_i = \mu + \sigma W_i$ for $i \in \{1, \dots, M\}$. The results of these illustrative synthetic experiments can be found in Fig. 1. The dataset consisted of samples from a standard normal, and the noise inputs were generated from a standard normal with a fixed random seed. The algorithm was initialized at fixed values of $\mu$ and $\sigma$, and we ran gradient descent with a fixed learning rate.
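The one-dimensional experiment can be sketched as follows. The sample sizes (200 each), bandwidth $h = 1$, initialization $(\mu, \sigma) = (1.0, 0.5)$, learning rate 1.0, and 1000 iterations are hypothetical choices for illustration, not the settings behind Fig. 1. The chain rule reduces to $\partial Y_j/\partial\mu = 1$ and $\partial Y_j/\partial\sigma = W_j$.

```python
import numpy as np

def rbf(a, b, h):
    """Kernel matrix exp(-(a_i - b_j)^2 / (2 h^2)) for 1-D samples a and b."""
    return np.exp(-(a[:, None] - b[None, :]) ** 2 / (2.0 * h ** 2))

def objective_and_grad(x, w, mu, sigma, h=1.0):
    """C_MMD for the generator y = mu + sigma * w, and its partial derivatives in mu and sigma."""
    n, m = len(x), len(w)
    y = mu + sigma * w
    kyy, kxy = rbf(y, y, h), rbf(x, y, h)
    c = (kyy.sum() - np.trace(kyy)) / (m * (m - 1)) - 2.0 * kxy.sum() / (m * n)
    # dC/dy_j for the RBF kernel; the j' = j terms vanish because y_j - y_j = 0.
    dy = ((-2.0 / (m * (m - 1) * h ** 2)) * (kyy * (y[:, None] - y[None, :])).sum(axis=1)
          + (2.0 / (m * n * h ** 2)) * (kxy * (y[None, :] - x[:, None])).sum(axis=0))
    # Chain rule with dy_j/dmu = 1 and dy_j/dsigma = w_j.
    return c, dy.sum(), (dy * w).sum()

rng = np.random.default_rng(0)
x = rng.standard_normal(200)   # data drawn from a standard normal
w = rng.standard_normal(200)   # fixed noise inputs (fixed random seed)
mu, sigma, lr = 1.0, 0.5, 1.0  # hypothetical initialization and learning rate
for _ in range(1000):
    c, g_mu, g_sigma = objective_and_grad(x, w, mu, sigma)
    mu, sigma = mu - lr * g_mu, sigma - lr * g_sigma
print(mu, sigma)
```

With finite samples the iterates settle at the minimizer of the empirical objective, which lies near, but not exactly at, $(\mu, \sigma) = (0, 1)$.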
MNIST digits
We trained our generative network on the MNIST digits dataset. The generator was chosen to be a fully connected neural network with 3 hidden layers and sigmoidal activation functions. We used a radial basis function (RBF) kernel; we also evaluated the rational quadratic (RQ) and Laplacian kernels, but found that the RBF kernel performed best in the parameter ranges we evaluated. We used Bayesian optimization (WHETLab) to set the bandwidth of the RBF kernel and the number of neurons in each layer on initial test runs of 50,000 iterations. We used the median heuristic suggested by Gretton et al. for the kernel two-sample test to choose the kernel bandwidth. The learning rate was adjusted during optimization by RMSProp.
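One common form of the median heuristic sets the RBF bandwidth to the median pairwise Euclidean distance in the pooled sample. Conventions differ (some variants use squared distances or an extra constant factor), and the exact variant used in these experiments is not stated here, so treat the sketch below as an assumption; the function name and toy arrays are hypothetical.

```python
import numpy as np

def median_heuristic_bandwidth(Z):
    """Median Euclidean distance over all distinct pairs of rows of Z."""
    dist = np.sqrt(np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=-1))
    upper = np.triu_indices(len(Z), k=1)  # each unordered pair once, diagonal excluded
    return float(np.median(dist[upper]))

rng = np.random.default_rng(0)
X = rng.standard_normal((300, 5))          # stand-in for a minibatch of data
Y = rng.standard_normal((300, 5)) + 0.5    # stand-in for a minibatch of generated points
h = median_heuristic_bandwidth(np.vstack([X, Y]))  # pooled sample
```

Excluding the diagonal matters: the zero self-distances would otherwise pull the median down, especially for small minibatches.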
Fig. 2 presents the digits learned after 1,000,000 iterations. We performed minibatch stochastic gradient descent, resampling the generated digits every 300 iterations, using minibatches of size 500 with equal numbers of training and generated points. It is clear that the digits produced have many artifacts not appearing in the MNIST data set. Despite this, the mean log density of the held-out test data under MMD nets is higher than the mean log density reported for adversarial nets.
There are several possible explanations for this. First, kernel density estimation is known to perform poorly in high dimensions. Second, the MMD objective can itself be understood as the integrated squared difference of two kernel density estimates, and so, in a sense, the objective being optimized is directly related to the subsequent mean test log density evaluation. There is no clear connection for adversarial networks, which might explain why they suffer under this test. Our experience suggests that the RBF kernel delivers baseline performance but that an image-specific kernel, capturing, e.g., shift invariance, might lead to better images.
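The KDE connection can be made precise for a Gaussian kernel, which is an assumption of this illustration. Let $\varphi_h$ be the $d$-dimensional isotropic Gaussian density with standard deviation $h$, and let $\hat p(x) = \frac1N \sum_i \varphi_h(x - X_i)$ and $\hat q(x) = \frac1M \sum_j \varphi_h(x - Y_j)$. Since the convolution of two such Gaussians satisfies $\int \varphi_h(x - a)\,\varphi_h(x - b)\,dx = \varphi_{\sqrt2 h}(a - b)$, expanding the square gives

$$
\int (\hat p - \hat q)^2\,dx = \frac{1}{N^2}\sum_{i,i'} \varphi_{\sqrt2 h}(X_i - X_{i'}) - \frac{2}{NM}\sum_{i,j} \varphi_{\sqrt2 h}(X_i - Y_j) + \frac{1}{M^2}\sum_{j,j'} \varphi_{\sqrt2 h}(Y_j - Y_{j'}),
$$

which is the biased (V-statistic) MMD² estimate for the kernel $k(x, y) = \varphi_{\sqrt2 h}(x - y)$, a rescaled RBF kernel. The unbiased estimate differs only in excluding the diagonal terms and renormalizing.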
Toronto face dataset
We have also trained the generative MMD network on the Toronto Face Dataset (TFD). The parameters were adapted from the MNIST experiment: we again used a 3-hidden-layer sigmoidal MLP with a similar architecture (1000, 600, and 1000 units) and an RBF kernel for the cost function with the same hyperparameters. The training minibatch size was equal to the number of generated points (500). The generated points were resampled every 500 iterations. The network was optimized for 500,000 iterations.
The samples from the resulting network are plotted in Fig. 3. The mean log density of the held-out test set is 2283 ± 39. Although this figure is higher than the mean log density of 2057 ± 26 reported by Goodfellow et al. for adversarial nets, the samples from the MMD network are again clearly distinguishable from the training dataset. Thus the high test score suggests that kernel density estimation does not perform well at evaluating models on these high-dimensional datasets.
Conclusion
MMD offers a closed-form surrogate for the discriminator in the adversarial nets framework. After using Bayesian optimization for the hyperparameters, we found that the network outperformed the adversarial network in terms of the density of the held-out test set under kernel density estimation. On the other hand, there is a clear discrepancy between the digits produced by MMD nets and the MNIST digits, which might suggest that KDE is not up to the task of evaluating these models. Given how quickly MMD nets achieve this level of performance, it is worth considering their use as an initialization for more costly procedures.
Acknowledgments
The authors would like to thank Bharath Sriperumbudur for technical discussions.
Appendix A Proofs
We begin with some preliminaries and known results:
A random variable $\sigma$ is said to be a Rademacher random variable if it takes values in $\{-1, 1\}$, each with probability $1/2$.
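Let $\mu$ be a probability measure on $\mathcal X$, and let $\mathcal F$ be a class of uniformly bounded functions on $\mathcal X$. Then the Rademacher complexity of $\mathcal F$ (with respect to $\mu$) is
$$R_n(\mathcal F) = \mathbb E \sup_{f \in \mathcal F} \Big| \frac{1}{n} \sum_{i=1}^{n} \sigma_i f(X_i) \Big|,$$
where $\sigma_1, \sigma_2, \dots$ is a sequence of independent Rademacher random variables, and $X_1, X_2, \dots$ are independent, $\mu$-distributed random variables, independent also from the $\sigma_i$.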
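For intuition, the inner expectation over the signs can be estimated by Monte Carlo when the class is finite and the sample points are held fixed; averaging further over fresh draws of the sample points would give $R_n(\mathcal F)$ itself. A minimal sketch with hypothetical names, where `values[k, i]` holds the $k$-th function evaluated at the $i$-th point:

```python
import numpy as np

def empirical_rademacher(values, n_draws=2000, seed=0):
    """Monte Carlo estimate of E_sigma sup_f |(1/n) sum_i sigma_i f(x_i)| for a finite class.

    values[k, i] = f_k(x_i): the k-th function evaluated at the i-th (fixed) sample point.
    """
    rng = np.random.default_rng(seed)
    n = values.shape[1]
    sigma = rng.choice([-1.0, 1.0], size=(n_draws, n))  # i.i.d. Rademacher signs
    corr = np.abs(sigma @ values.T) / n                  # shape (n_draws, num_functions)
    return float(corr.max(axis=1).mean())                # sup over f, then average over sigma

single = np.ones((1, 4))  # one constant function evaluated at four points
print(empirical_rademacher(single))  # exact value: E|sigma_1 + ... + sigma_4| / 4 = 0.375
```

Adding functions to the class can only increase the supremum, which is why richer classes (larger fat-shattering dimension) yield larger complexity and looser bounds.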
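(McDiarmid's inequality.) Let $f : \mathcal X^n \to \mathbb R$ satisfy the bounded-differences condition
$$\sup_{x_1, \dots, x_n, x_i' \in \mathcal X} \big| f(x_1, \dots, x_i, \dots, x_n) - f(x_1, \dots, x_i', \dots, x_n) \big| \le c_i, \qquad i = 1, \dots, n.$$
Then, for all $\epsilon > 0$ and independent random variables $\xi_1, \dots, \xi_n$ in $\mathcal X$,
$$\Pr\big\{ f(\xi_1, \dots, \xi_n) - \mathbb E f(\xi_1, \dots, \xi_n) \ge \epsilon \big\} \le \exp\Big( -\frac{2\epsilon^2}{\sum_{i=1}^n c_i^2} \Big).$$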
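As a sanity check on the form of McDiarmid's bounded-differences inequality, take $\mathcal X = [0,1]$ and $f(x_1, \dots, x_n) = \frac1n \sum_i x_i$. Changing one coordinate moves $f$ by at most $c_i = 1/n$, so $\sum_i c_i^2 = 1/n$ and the inequality specializes to Hoeffding's bound for bounded averages:

$$
\Pr\Big\{ \frac1n\sum_{i=1}^n \xi_i - \mathbb E\Big[\frac1n\sum_{i=1}^n \xi_i\Big] \ge \epsilon \Big\} \le \exp(-2n\epsilon^2).
$$

The proofs below use the same mechanism, with the bound on the kernel supplying the constants $c_i$.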
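Assume $\mathcal F \subseteq B(L_\infty(\mathcal X))$ and $\mathrm{fat}_\epsilon(\mathcal F) \le \gamma \epsilon^{-p}$ for all $\epsilon > 0$, for some $\gamma > 1$ and $p \in \mathbb N$. Then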
The case where is a finite set is elementary:
Let be the distribution of , with taking values in some finite set . Then, with probability at least , where is defined as in Theorem 4, we have
Note that the upper bound stated in Theorem 4 holds for the parameter value $\theta^*$, i.e.,
Because $\hat\theta$ depends on the training data $X$ and generator data $Y$, we use a uniform bound that holds over all $\theta \in \Theta$. Specifically,
This yields that with probability at least ,
Since $\hat\theta$ was chosen to minimize the empirical estimate $\mathrm{MMD}_U^2$, we know that
In order to prove the general result, we begin with some technical lemmas. The development here owes much to earlier work.
Then there exists a constant that depends on , such that
Let us introduce , where and have the same distribution and are independent for all . Then the following is true:
Using Jensen’s inequality and the independence of and , we have
Introducing conditional expectations allows us to rewrite the equation with the sum over outside the expectations. That is, Eq. 36 equals
The second equality follows by symmetry of random variables . Note that we also added Rademacher random variables before each term in the sum since has the same distribution as for all and therefore the ’s do not affect the expectation of the sum.
Note that and are identically distributed. Thus the triangle inequality implies that Eq. 39 is less than or equal to
where is the Rademacher complexity of . Then by Theorem 3, we have
Then there exists that depends on , such that
The proof is very similar to that of Lemma 1. ∎
The proof follows the same steps as the proof of Theorem 5, apart from requiring a stronger uniform bound. That is, we need to show:
For all , does not depend on and therefore the first two terms of the equation above can be taken out of the supremum. Also, note that since , we have
and is an unbiased estimate of . Then from McDiarmid’s inequality on , we have
Therefore, the left-hand side of the uniform bound we need to show is bounded by the sum of the bound on Eq. 45 and the following:
Thus the next step is to find the bound for the supremum above.
We will first find the upper bound on , i.e., for every , we will show that there exists , such that
since the kernel is bounded by one, and therefore is bounded by for all . The conditions of Theorem 2 are satisfied, and thus we can use McDiarmid’s inequality on :
To show Eq. 48, we need to bound the expectation of . We can apply Lemma 1 on the function classes and . The resulting bound is
where and are the parameters associated with the fat-shattering dimension of as stated in the assumptions of the theorem, and is a constant depending on .
Similarly, has bounded differences:
for some constant that depends on . The final bound on is then
Summing up the bounds from Eq. 54 and Eq. 57, it follows that
Using the bound in Eq. 45, we obtain the uniform bound we were looking for:
Since it was assumed that and , we get
To finish, we proceed as in the proof of Theorem 5. We can rearrange some of the terms to get a different form of Eq. 28:
All of the above implies that for any , there exists , such that
The rate is given by Eq. 53 and Eq. 59:
where the constants and depend on and alone. ∎
We close by noting that the approximation error is zero in the nonparametric limit.