GAN and VAE from an Optimal Transport Point of View
Aude Genevay, Gabriel Peyré, Marco Cuturi
Minimum Kantorovitch Estimators
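Given some empirical distribution $\nu=\frac{1}{n}\sum_{j=1}^{n}\delta_{y_j}$ where $y_j\in\mathcal{X}$, and a parametric family of probability distributions $(\mu_\theta)_{\theta\in\Theta}$, $\Theta\subset\mathbb{R}^q$, a Minimum Kantorovitch Estimator (MKE) for $\theta$ is defined as any solution of the problem
$$\underset{\theta}{\min}\;W_c(\mu_\theta,\nu), \qquad \text{(MKE)}$$
where $W_c$ is the Wasserstein cost on $\mathcal{X}$ for some ground cost function $c:\mathcal{X}\times\mathcal{X}\to\mathbb{R}$, defined as
$$W_c(\mu,\nu)=\underset{\gamma\in\mathcal{P}(\mathcal{X}\times\mathcal{X})}{\min}\;\left\{\int_{\mathcal{X}\times\mathcal{X}}c(x,y)\mathrm{d}\gamma(x,y)\;;\;P_{1\sharp}\gamma=\mu,\;P_{2\sharp}\gamma=\nu\right\}, \qquad (1)$$
where $P_1(x,y)=x$ and $P_2(x,y)=y$, and $P_{1\sharp}$ and $P_{2\sharp}$ are marginalization operators that return for a given coupling $\gamma$ its first and second marginal, respectively.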
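The notations $P_{1\sharp}$ and $P_{2\sharp}$ above agree with the more general notion of pushforward measures. Given a measurable map $g:\mathcal{Z}\to\mathcal{X}$, which can be interpreted as a function “moving” points from one measurable space to another, one can naturally extend $g$ to a more general map $g_\sharp$ that “moves” an entire probability measure on $\mathcal{Z}$ towards a new probability measure on $\mathcal{X}$. The operator $g_\sharp$ “pushes forward” each elementary mass of a measure $\zeta$ in $\mathcal{P}(\mathcal{Z})$ by applying the map $g$ to obtain a mass in $\mathcal{X}$; in aggregate, these masses build a new measure in $\mathcal{P}(\mathcal{X})$ written $g_\sharp\zeta$. More rigorously, the pushforward measure of a measure $\zeta\in\mathcal{P}(\mathcal{Z})$ by a map $g:\mathcal{Z}\to\mathcal{X}$ is the measure denoted $g_\sharp\zeta$ in $\mathcal{P}(\mathcal{X})$ such that for any measurable set $B\subset\mathcal{X}$, $g_\sharp\zeta(B)=\zeta(g^{-1}(B))=\zeta(\{z\in\mathcal{Z}\;:\;g(z)\in B\})$.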
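For a discrete measure, the pushforward can be computed atom by atom. The following is a minimal sketch in plain Python, under the assumption that the measure is stored as a dictionary mapping atoms to masses; the names `pushforward` and `mass` and the toy data are hypothetical and only serve to illustrate the definition.

```python
from collections import defaultdict


def pushforward(points, weights, g):
    """Return g#zeta for zeta = sum_i weights[i] * delta_{points[i]}, as {x: mass}."""
    image = defaultdict(float)
    for z, w in zip(points, weights):
        image[g(z)] += w  # masses sent to the same point x are aggregated
    return dict(image)


def mass(measure, in_set):
    """Mass that a discrete measure {point: mass} gives to the set {x : in_set(x)}."""
    return sum(w for x, w in measure.items() if in_set(x))


# Hypothetical toy data: uniform weights on five points of Z, pushed through g(z) = z^2.
zeta = {-2: 0.2, -1: 0.2, 0: 0.2, 1: 0.2, 2: 0.2}
g = lambda z: z * z
mu = pushforward(list(zeta), list(zeta.values()), g)

# Defining property: (g#zeta)(B) = zeta({z : g(z) in B}), here with B = [1, +inf).
in_B = lambda x: x >= 1
lhs = mass(mu, in_B)
rhs = mass(zeta, lambda z: in_B(g(z)))
```

Here `lhs` and `rhs` agree up to floating-point rounding, and the total mass of `mu` is that of `zeta`: the pushforward moves mass without creating or destroying it.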
MKE-GM.
The MKE approach can be used directly in the case where $(\mu_\theta)_\theta$ is a statistical model, namely a parameterized family of probability distributions with a given density with respect to a dominant base measure, as considered for instance with exponential families on discrete spaces in . However, the MKE approach can also be used in a generative model setting, where $\mu_\theta$ is defined instead as the pushforward of a fixed distribution $\zeta$ supported on a low dimensional space $\mathcal{Z}$, $\zeta\in\mathcal{P}(\mathcal{Z})$, where the parameterization now lies in choosing a map $g_\theta:\mathcal{Z}\to\mathcal{X}$, i.e. $\mu_\theta=g_{\theta\sharp}\zeta$, resulting in the following special case of the original (MKE) problem:
$$\underset{\theta}{\min}\;E(\theta), \quad\text{where}\quad E(\theta)=W_c(g_{\theta\sharp}\zeta,\nu). \qquad \text{(MKE-GM)}$$
The map $g_\theta$ should therefore be thought of as a “decoding” map from a low dimensional space to a high dimensional space. In such a setting, the maximum likelihood estimator is in general undefined or difficult to compute (because the supports of the measures are singular) while MKEs are attractive because they are always well defined.
Dual Formulation and GAN
Because (1) is a linear program, it has a dual formulation, known as the Kantorovich problem [13, Thm. 5.9]:
$$E(\theta)=\underset{h,\tilde{h}}{\max}\;\left\{\int_{\mathcal{Z}}h(g_{\theta}(z))\mathrm{d}\zeta(z)+\int_{\mathcal{X}}\tilde{h}(y)\mathrm{d}\nu(y)\;;\;h(x)+\tilde{h}(y)\leqslant c(x,y)\right\}, \qquad (2)$$
where $(h,\tilde{h})$ are continuous functions on $\mathcal{X}$, often called Kantorovich potentials in the literature.
In the dual formulation (2), $\theta$ no longer appears in the constraints. Therefore, the gradient of $E$ can be computed as
$$\nabla E(\theta)=\int_{\mathcal{Z}}[\partial_{\theta}g_{\theta}(z)]^{\top}\nabla h^{\star}(g_{\theta}(z))\mathrm{d}\zeta(z), \qquad (3)$$
where $h^{\star}$ is an optimal dual function solving (2). Here $[\partial_{\theta}g_{\theta}(z)]^{\top}$, a matrix with $q$ rows, is the adjoint of the Jacobian of $\theta\mapsto g_{\theta}(z)$, where $q$ is the dimension of the parameter space $\Theta$.
A key remark in Kantorovich’s formulation is that the objective value of any pair $(h,\tilde{h})$ can always be improved by replacing $\tilde{h}$ in (2) by the $c$-transform $h^{c}$ of $h$, defined as
$$h^{c}(y)=\underset{x\in\mathcal{X}}{\min}\;c(x,y)-h(x),$$
which is indeed, given a candidate potential $h$ for the first variable, the best possible potential that can be paired with $h$ while satisfying the constraints of (2) (see [13, Thm. 5.9]). For this reason, one can parameterize problem (2) as depending on one potential function only.
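On a finite set of points the $c$-transform reduces to a row-wise minimum. The sketch below assumes a finite grid `xs` standing in for $\mathcal{X}$, a squared-distance cost and arbitrary toy values for $h$ (all hypothetical choices); it checks that the pair $(h,h^{c})$ satisfies the constraint of (2) and that the constraint is tight for every $y$, which is why no larger potential can be paired with $h$.

```python
def c_transform(h_values, xs, ys, cost):
    """Discrete c-transform: h^c(y) = min over x in xs of cost(x, y) - h(x)."""
    return [min(cost(x, y) - h for x, h in zip(xs, h_values)) for y in ys]


# Hypothetical toy data: three source points, two target points, squared cost.
xs = [0.0, 1.0, 2.0]
ys = [0.5, 1.5]
cost = lambda x, y: (x - y) ** 2
h = [0.0, 0.3, -0.2]

h_c = c_transform(h, xs, ys, cost)

# Feasibility: h(x) + h^c(y) <= c(x, y) for every pair (x, y).
feasible = all(
    hx + hy <= cost(x, y) + 1e-12
    for x, hx in zip(xs, h)
    for y, hy in zip(ys, h_c)
)

# Tightness: for each y the constraint is active for at least one x,
# so h^c(y) cannot be increased without violating a constraint.
tight = all(
    any(abs(hx + hy - cost(x, y)) < 1e-12 for x, hx in zip(xs, h))
    for y, hy in zip(ys, h_c)
)
```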
A first approach to solving (2) is to remark that since $\nu$ is discrete, one can replace the continuous potential $\tilde{h}$ by the discrete vector $(\tilde{h}(y_j))_{j=1}^{n}\in\mathbb{R}^{n}$ and impose $h=(\tilde{h})^{c}$. As shown in , the optimization over $\tilde{h}$ can then be achieved using stochastic gradient descent.
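A minimal sketch of this semi-discrete approach is given below. It is not the cited implementation: the averaged stochastic gradient ascent, the step-size schedule `step0 / sqrt(k)`, and the toy problem (uniform samples on $[0,1]$ standing in for $g_{\theta\sharp}\zeta$, two data points, squared cost) are all assumptions made for illustration. The vector `v` plays the role of $(\tilde{h}(y_j))_j$, and $h$ is implicitly its $c$-transform.

```python
import random


def sgd_semi_discrete(sample_x, ys, weights, cost, n_iter=20000, step0=0.1, seed=0):
    """Averaged stochastic gradient ascent on the dual vector v_j = h~(y_j)."""
    rng = random.Random(seed)
    n = len(ys)
    v = [0.0] * n
    v_avg = [0.0] * n
    for k in range(1, n_iter + 1):
        x = sample_x(rng)  # x = g_theta(z) with z drawn from zeta
        # h(x) = min_j c(x, y_j) - v_j is attained at index j_star.
        j_star = min(range(n), key=lambda j: cost(x, ys[j]) - v[j])
        step = step0 / k ** 0.5
        for j in range(n):
            # Stochastic gradient of  min_j [c(x, y_j) - v_j] + sum_j weights[j] * v_j.
            grad_j = weights[j] - (1.0 if j == j_star else 0.0)
            v[j] += step * grad_j
            v_avg[j] += (v[j] - v_avg[j]) / k  # running average of the iterates
    return v_avg


# Hypothetical toy problem: x uniform on [0, 1], nu uniform on {0.2, 0.5}, squared cost.
ys = [0.2, 0.5]
weights = [0.5, 0.5]
cost = lambda x, y: (x - y) ** 2
v_opt = sgd_semi_discrete(lambda rng: rng.random(), ys, weights, cost)
gap = v_opt[0] - v_opt[1]
```

For this toy problem the optimum can be computed by hand: each data point must receive half of the mass, so the boundary between the two cells sits at $x=0.5$, which requires $v_0-v_1=(0.5-0.2)^2=0.09$. The averaged iterate `gap` should approach this value as the number of iterations grows.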
Similarly to , another approach is to approximate (2) by restricting the dual potential $h$ to have a parametric form $h=h_{\xi}:\mathcal{X}\to\mathbb{R}$, where $h_{\xi}$ is a discriminative deep network (see Figure 1, center). This map $h_{\xi}$ is often referred to as an “adversarial” map. Plugging this ansatz into (2) leads to the Wasserstein-GAN problem
$$\underset{\theta}{\min}\;\underset{\xi}{\max}\;\int_{\mathcal{Z}}h_{\xi}\circ g_{\theta}(z)\mathrm{d}\zeta(z)+\frac{1}{n}\sum_{j=1}^{n}h_{\xi}^{c}(y_{j}). \qquad \text{(WGAN)}$$
In the special case where $c(x,y)=\|x-y\|$, one can prove that the mechanics of $c$-transforms result in the additional constraint that $h_{\xi}^{c}=-h_{\xi}$, subject to $h_{\xi}$ being a $1$-Lipschitz function, see [13, Particular case 5.4]. This is used in to replace $h_{\xi}^{c}$ by $-h_{\xi}$ in (WGAN) and to use a deep network made of ReLU units whose Lipschitz constant is upper-bounded by $1$.
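The identity between the $c$-transform and the negated potential can be checked numerically on the real line. The sketch below assumes the cost $c(x,y)=|x-y|$, a grid that contains the test points, and two hypothetical potentials: `h_lip`, which is 1-Lipschitz, and `h_steep`, which is not.

```python
import math


def c_transform(h, xs, y, cost):
    """h^c(y) = min over grid points x of cost(x, y) - h(x)."""
    return min(cost(x, y) - h(x) for x in xs)


cost = lambda x, y: abs(x - y)              # ground cost c(x, y) = |x - y|
xs = [i / 10 for i in range(-50, 51)]       # grid on [-5, 5] containing each y below
ys = [-2.0, 0.3, 4.1]

h_lip = lambda x: 0.5 * math.sin(x)         # 1-Lipschitz: |h'| <= 0.5
h_steep = lambda x: 3.0 * math.sin(x)       # not 1-Lipschitz: |h'| reaches 3

# |h^c(y) + h(y)| vanishes for the 1-Lipschitz potential ...
gaps_lip = [abs(c_transform(h_lip, xs, y, cost) + h_lip(y)) for y in ys]
# ... but not for the steep one, where the minimum is attained away from x = y.
gaps_steep = [abs(c_transform(h_steep, xs, y, cost) + h_steep(y)) for y in ys]
```

When the potential is 1-Lipschitz, $|x-y|-h(x)\geqslant -h(y)$ for all $x$, with equality at $x=y$, so the minimum defining $h^{c}(y)$ is $-h(y)$. This is what makes the $c$-transform free to evaluate in that special case.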
As a side note, and as previously commented in the literature, there is at this point no empirical evidence supporting the idea that using discriminative deep networks in this way yields accurate approximations of Wasserstein distances. These alternative formulations instead provide a very useful proxy for a quantity directly related to the Wasserstein distance.
Primal Formulation and VAE
Following , in the special case of a generative model $\mu_{\theta}=g_{\theta\sharp}\zeta$, formula (1) can be conveniently rewritten as
$$E(\theta)=\underset{\pi\in\mathcal{P}(\mathcal{Z}\times\mathcal{X})}{\min}\;\left\{\int_{\mathcal{Z}\times\mathcal{X}}c(g_{\theta}(z),y)\mathrm{d}\pi(z,y)\;;\;P_{1\sharp}\pi=\zeta,\;P_{2\sharp}\pi=\nu\right\}. \qquad (4)$$
This is advantageous because $\pi$ is now defined over $\mathcal{Z}\times\mathcal{X}$, which is lower-dimensional than $\mathcal{X}\times\mathcal{X}$, and also because, as in Equation (2), $\theta$ does not appear in the constraints either. This provides an alternative formula for the gradient of $E$:
$$\nabla E(\theta)=\int_{\mathcal{Z}\times\mathcal{X}}[\partial_{\theta}g_{\theta}(z)]^{\top}\nabla_{1}c(g_{\theta}(z),y)\mathrm{d}\pi^{\star}(z,y), \qquad (5)$$
where $\pi^{\star}$ is an optimal coupling solving (4). Here $\nabla_{1}c(x,y)$ denotes the gradient of $c$ with respect to the first variable.
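Formula (5) can be sanity-checked on a one-dimensional toy problem where the optimal coupling is known in closed form. The sketch below makes several simplifying assumptions that are not in the text: $\zeta$ is replaced by a uniform empirical measure on a few latent samples `zs`, the decoder is the scalar map $g_{\theta}(z)=\theta z$, and the cost is $c(x,y)=(x-y)^2$. Under these assumptions an optimal coupling matches the sorted atoms, and the gradient given by (5) can be compared with a finite difference of $E$.

```python
def monotone_coupling(xs, ys):
    """Optimal coupling between two uniform empirical measures on the line
    (same number of atoms, convex cost): match the sorted atoms."""
    ix = sorted(range(len(xs)), key=lambda i: xs[i])
    iy = sorted(range(len(ys)), key=lambda j: ys[j])
    return list(zip(ix, iy))


def energy(theta, zs, ys):
    """E(theta) for g_theta(z) = theta * z and c(x, y) = (x - y)^2."""
    xs = [theta * z for z in zs]
    pairs = monotone_coupling(xs, ys)
    return sum((xs[i] - ys[j]) ** 2 for i, j in pairs) / len(zs)


def primal_gradient(theta, zs, ys):
    """Formula (5): integrate [d g_theta / d theta] * grad_1 c against pi_star."""
    xs = [theta * z for z in zs]
    pairs = monotone_coupling(xs, ys)
    # d g_theta(z) / d theta = z   and   grad_1 c(x, y) = 2 (x - y)
    return sum(zs[i] * 2.0 * (xs[i] - ys[j]) for i, j in pairs) / len(zs)


# Hypothetical latent samples and data points.
zs = [0.1, 0.7, -0.4, 1.2]
ys = [0.0, 2.0, -1.0, 0.5]
theta = 1.0
eps = 1e-6
finite_difference = (energy(theta + eps, zs, ys) - energy(theta - eps, zs, ys)) / (2 * eps)
gradient = primal_gradient(theta, zs, ys)
```

The coupling is held fixed while differentiating, which is what (5) expresses: since $\theta$ does not appear in the constraints of (4), the optimal coupling does not need to be differentiated.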
suggests looking for couplings $\pi$ with a parametric form. A simple way to achieve this is to restrict couplings to those of the form
$$\pi_{\xi}=(f_{\xi},\mathrm{Id}_{\mathcal{X}})_{\sharp}\nu\in\mathcal{P}(\mathcal{Z}\times\mathcal{X}),$$
where $f_{\xi}:\mathcal{X}\to\mathcal{Z}$ is a parametric “encoding” map (typically a deep network), see Figure 1, right. This satisfies by construction the marginal constraint $P_{2\sharp}\pi_{\xi}=\nu$, but in general it cannot satisfy the other constraint $P_{1\sharp}\pi_{\xi}=\zeta$ (because $\nu$ is discrete while $\zeta$ is not). So following , it makes sense to consider a relaxed “unbalanced” formulation (in the sense of ) of the form
$$E_{\lambda}(\theta)=\underset{\pi}{\min}\;\left\{\int_{\mathcal{Z}\times\mathcal{X}}c(g_{\theta}(z),y)\mathrm{d}\pi(z,y)+\lambda D(P_{1\sharp}\pi|\zeta)\;;\;P_{2\sharp}\pi=\nu\right\}, \qquad (6)$$
where $D(\cdot|\cdot)$ is some distance or divergence between positive measures on $\mathcal{Z}$ and $\lambda>0$ is a relaxation parameter.
Plugging the ansatz $\pi=\pi_{\xi}$ into (6), one obtains the Wasserstein-VAE formulation
$$\underset{(\theta,\xi)}{\min}\;\Delta_{\nu}(g_{\theta}\circ f_{\xi},\mathrm{Id}_{\mathcal{X}})+\lambda D(f_{\xi\sharp}\nu|\zeta), \qquad \text{(WVAE)}$$
where $\Delta_{\nu}(\cdot,\mathrm{Id}_{\mathcal{X}})$ is the cost measuring the deviation of a map from the identity, defined for two maps $\varphi,\psi$ on $\mathcal{X}$ as
$$\Delta_{\nu}(\varphi,\psi)=\int_{\mathcal{X}}c(\varphi(x),\psi(x))\mathrm{d}\nu(x).$$
Such a cost is usually associated with the Monge formulation of optimal transport , whose original motivation was to find an optimal map under that cost that would be able to push forward a given measure onto another [12, §1.1].
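The (WVAE) objective can be evaluated on samples once a divergence $D$ is chosen. The sketch below is an illustration under stated assumptions, not the formulation of the cited works: $D$ is taken to be a biased estimate of the squared maximum mean discrepancy with a Gaussian kernel, $\zeta$ is represented by a few fixed samples, and the scalar encoder and decoder are hypothetical affine maps.

```python
import math


def delta_nu(decode, encode, ys, cost):
    """Delta_nu(g o f, Id) = (1/n) sum_j c(g(f(y_j)), y_j): the reconstruction cost."""
    return sum(cost(decode(encode(y)), y) for y in ys) / len(ys)


def mmd2(a, b, bandwidth=1.0):
    """Biased estimate of the squared MMD between two samples (Gaussian kernel)."""
    k = lambda s, t: math.exp(-((s - t) ** 2) / (2.0 * bandwidth ** 2))
    mean_k = lambda u, v: sum(k(s, t) for s in u for t in v) / (len(u) * len(v))
    return mean_k(a, a) + mean_k(b, b) - 2.0 * mean_k(a, b)


def wvae_objective(decode, encode, ys, prior_samples, cost, lam):
    codes = [encode(y) for y in ys]  # atoms of the encoded measure f#nu
    return delta_nu(decode, encode, ys, cost) + lam * mmd2(codes, prior_samples)


# Hypothetical data, latent samples standing in for zeta, and affine maps.
ys = [0.0, 1.0, 2.0, 3.0]
prior_samples = [-1.0, -0.3, 0.3, 1.0]
cost = lambda x, y: (x - y) ** 2
encode = lambda y: (y - 1.5) / 1.5
decode = lambda z: 1.5 * z + 1.5          # inverts the encoder: zero reconstruction cost
constant_decode = lambda z: 1.5           # ignores the code: positive reconstruction cost

lam = 10.0
value = wvae_objective(decode, encode, ys, prior_samples, cost, lam)
reconstruction_only = delta_nu(constant_decode, encode, ys, cost)
```

The first term rewards a decoder that inverts the encoder on the data, while the second penalizes encoded points whose distribution departs from $\zeta$; the parameter `lam` trades one against the other.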
Conclusions
The WGAN and WVAE formulations are very different, and are in some sense dual to each other. For GAN, the couple $(g_{\theta},h_{\xi})$ should be thought of as a (primal, dual) pair (often referred to as an adversarial pair, which is reminiscent of game-theoretic saddle points). For VAE, the couple $(f_{\xi},g_{\theta})$ is rather an (encoding, decoding) pair, and both have the flavour of transportation maps.
In sharp contrast to the primal gradient formula (5), which only requires integrating against an optimal coupling $\pi^{\star}$, the dual gradient formula (3) involves integrating the gradient of an optimal potential $h^{\star}$. The latter tends to be more unstable and thus necessitates accurate optimization sub-iterations to obtain an optimal dual potential $h^{\star}$ or an approximation $h_{\xi}$ within a restricted parametric class. This is somewhat in line with the empirical observation that training VAE is more stable than training GAN. One should however bear in mind that, although both formulations can be motivated by the same minimum Kantorovitch estimation problem (MKE-GM), they define quite different estimators. In particular, GAN is often credited with producing less blurry outputs when used for image generation.
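Comparing the optimal values of (MKE-GM), (WGAN) and (WVAE), one has in the limit $\lambda\to+\infty$ (to cancel the bias due to the marginal constraint relaxation) that the value of (WGAN) is at most that of (MKE-GM), which is in turn at most that of (WVAE): restricting the dual potential to a parametric class can only lower the maximum in (2), while restricting the coupling to a parametric class can only raise the minimum in (4).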
furthermore mentions that in the “non-parametric limit” (i.e. when the number of parameters tends to $+\infty$, and also letting $\lambda\to+\infty$), the gap between the estimators should vanish. Indeed, the parametric maps should capture the desired optimal map in the limit, and one thus recovers the true solution to (MKE-GM). While it would be interesting from a theoretical perspective to prove and quantify such a claim, it is unclear whether it would be useful for the practitioner. Indeed, the convergence rate might be slow, so that in practice one can be quite far from this non-parametric limit. One could even argue that this limit may give poor estimators for complicated datasets, so that parameterizing the maps and using non-convex optimization solvers leads instead to a beneficial and implicit regularization of these estimators.