Categorical Reparameterization with Gumbel-Softmax

Eric Jang, Shixiang Gu, Ben Poole

Introduction

Stochastic neural networks with discrete random variables are a powerful technique for representing distributions encountered in unsupervised learning, language modeling, attention mechanisms, and reinforcement learning domains. For example, discrete variables have been used to learn probabilistic latent representations that correspond to distinct semantic classes (Kingma et al., 2014), image regions (Xu et al., 2015), and memory locations (Graves et al., 2014; 2016). Discrete representations are often more interpretable (Chen et al., 2016) and more computationally efficient (Rae et al., 2016) than their continuous analogues.

However, stochastic networks with discrete variables are difficult to train because the backpropagation algorithm — while permitting efficient computation of parameter gradients — cannot be applied to non-differentiable layers. Prior work on stochastic gradient estimation has traditionally focused on either score function estimators augmented with Monte Carlo variance reduction techniques (Paisley et al., 2012; Mnih & Gregor, 2014; Gu et al., 2016; Gregor et al., 2013), or biased path derivative estimators for Bernoulli variables (Bengio et al., 2013). However, no existing gradient estimator has been formulated specifically for categorical variables. The contributions of this work are threefold:

We introduce Gumbel-Softmax, a continuous distribution on the simplex that can approximate categorical samples, and whose parameter gradients can be easily computed via the reparameterization trick.

We show experimentally that Gumbel-Softmax outperforms all single-sample gradient estimators on both Bernoulli variables and categorical variables.

We show that this estimator can be used to efficiently train semi-supervised models (e.g. Kingma et al. (2014)) without costly marginalization over unobserved categorical latent variables.

The practical outcome of this paper is a simple, differentiable approximate sampling mechanism for categorical variables that can be integrated into neural networks and trained using standard backpropagation.

The Gumbel-Softmax distribution

The Gumbel-Max trick (Gumbel, 1954; Maddison et al., 2014) provides a simple and efficient way to draw samples $z$ from a categorical distribution with class probabilities $\pi$:

$$z = \operatorname{one\_hot}\Big(\operatorname*{arg\,max}_{i}\,\left[g_i + \log \pi_i\right]\Big)$$

where $g_1, \dots, g_k$ are i.i.d. samples drawn from $\text{Gumbel}(0,1)$. (The $\text{Gumbel}(0,1)$ distribution can be sampled using inverse transform sampling by drawing $u \sim \text{Uniform}(0,1)$ and computing $g = -\log(-\log u)$.) We use the softmax function as a continuous, differentiable approximation to $\operatorname*{arg\,max}$, and generate $k$-dimensional sample vectors $y \in \Delta^{k-1}$ where

$$y_i = \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}{\sum_{j=1}^{k} \exp\big((\log \pi_j + g_j)/\tau\big)} \qquad \text{for } i = 1, \dots, k.$$
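Both sampling procedures can be sketched in a few lines of standard-library Python. This is a minimal illustration, not the authors' code; the function names are hypothetical. It draws Gumbel noise by inverse transform sampling, checks empirically that the Gumbel-Max index follows $\pi$, and produces one relaxed sample $y$ on the simplex.

```python
import math
import random


def sample_gumbel(rng):
    """Draw g ~ Gumbel(0, 1) by inverse transform sampling: g = -log(-log(u))."""
    u = rng.random()          # u in [0, 1)
    while u == 0.0:           # log(0) is undefined, so resample the (rare) exact zero
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(log_pi, rng):
    """Gumbel-Max trick: the index of max_i (log pi_i + g_i) is distributed as Categorical(pi)."""
    scores = [lp + sample_gumbel(rng) for lp in log_pi]
    return max(range(len(scores)), key=scores.__getitem__)


def gumbel_softmax_sample(log_pi, tau, rng):
    """Relaxed sample y = softmax((log pi + g) / tau), a point on the simplex."""
    scores = [(lp + sample_gumbel(rng)) / tau for lp in log_pi]
    m = max(scores)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
pi = [0.2, 0.5, 0.3]
log_pi = [math.log(p) for p in pi]

n = 20000
counts = [0] * len(pi)
for _ in range(n):
    counts[gumbel_max_sample(log_pi, rng)] += 1
freqs = [c / n for c in counts]              # empirical class frequencies, close to pi

y = gumbel_softmax_sample(log_pi, tau=0.5, rng=rng)   # soft sample, sums to 1
```

Subtracting the maximum score before exponentiating does not change the softmax output, but it avoids overflow when $\tau$ is small.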

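The density of the Gumbel-Softmax distribution (derived in Appendix B) is:

$$p_{\pi,\tau}(y_1, \dots, y_k) = \Gamma(k)\,\tau^{k-1}\left(\sum_{i=1}^{k} \frac{\pi_i}{y_i^{\tau}}\right)^{-k} \prod_{i=1}^{k} \left(\frac{\pi_i}{y_i^{\tau+1}}\right)$$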
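As a sanity check on the density, the sketch below evaluates its logarithm and, for $k = 2$, integrates it over $y_1 \in (0, 1)$ with the midpoint rule (the density is with respect to the first $k-1$ coordinates, as in Appendix B). Two properties should hold: the total mass is 1, and $P(y_1 > y_2) = \pi_1$ for any $\tau$, since $y_1 > y_2$ exactly when the Gumbel-Max index is 1. The helper names are illustrative, and the test temperatures ($\tau \ge 1$) are chosen so that the density stays bounded near the simplex edges.

```python
import math


def gumbel_softmax_log_density(y, pi, tau):
    """log p_{pi,tau}(y) for a point y in the interior of the simplex.

    p = Gamma(k) * tau^(k-1) * (sum_i pi_i / y_i^tau)^(-k) * prod_i pi_i / y_i^(tau + 1)
    """
    k = len(y)
    s = sum(p / yi ** tau for p, yi in zip(pi, y))
    log_prod = sum(math.log(p) - (tau + 1.0) * math.log(yi) for p, yi in zip(pi, y))
    return math.lgamma(k) + (k - 1) * math.log(tau) - k * math.log(s) + log_prod


def integrate_k2(pi, tau, lo=0.0, hi=1.0, n=20000):
    """Midpoint-rule integral of the k = 2 density over y_1 in [lo, hi], with y_2 = 1 - y_1."""
    h = (hi - lo) / n
    total = 0.0
    for j in range(n):
        y1 = lo + (j + 0.5) * h
        total += math.exp(gumbel_softmax_log_density([y1, 1.0 - y1], pi, tau))
    return total * h


pi = [0.3, 0.7]
mass = integrate_k2(pi, tau=1.0)                  # should be close to 1
p_first_wins = integrate_k2(pi, tau=1.0, lo=0.5)  # P(y_1 > y_2), should be close to pi_1
```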

This distribution was independently discovered by Maddison et al. (2016), where it is referred to as the concrete distribution. As the softmax temperature $\tau$ approaches $0$, samples from the Gumbel-Softmax distribution become one-hot and the Gumbel-Softmax distribution becomes identical to the categorical distribution $p(z)$.

The Gumbel-Softmax distribution is smooth for $\tau > 0$, and therefore has a well-defined gradient $\partial y / \partial \pi$ with respect to the parameters $\pi$. Thus, by replacing categorical samples with Gumbel-Softmax samples, we can use backpropagation to compute gradients (see Section 3.1). We denote this procedure of replacing non-differentiable categorical samples with a differentiable approximation during training as the Gumbel-Softmax estimator.
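With the Gumbel noise $g$ held fixed, $y$ is a deterministic, differentiable function of the logits $x_i = \log \pi_i$, with Jacobian $\partial y_i / \partial x_j = y_i(\delta_{ij} - y_j)/\tau$; the gradient with respect to $\pi_j$ then follows from the chain rule, since $\partial x_j / \partial \pi_j = 1/\pi_j$. The sketch below, with hypothetical helper names and arbitrary noise values, compares that analytic Jacobian against central finite differences.

```python
import math


def relaxed_sample(logits, g, tau):
    """y = softmax((logits + g) / tau) for a fixed Gumbel noise vector g."""
    s = [(x + gi) / tau for x, gi in zip(logits, g)]
    m = max(s)
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    return [v / total for v in e]


def relaxed_jacobian(logits, g, tau):
    """Analytic Jacobian: dy_i / dlogit_j = y_i * (delta_ij - y_j) / tau."""
    y = relaxed_sample(logits, g, tau)
    k = len(y)
    return [[y[i] * ((1.0 if i == j else 0.0) - y[j]) / tau for j in range(k)]
            for i in range(k)]


def finite_difference_jacobian(logits, g, tau, eps=1e-6):
    """Central differences in each logit, used only to check the analytic formula."""
    k = len(logits)
    jac = [[0.0] * k for _ in range(k)]
    for j in range(k):
        up = list(logits)
        dn = list(logits)
        up[j] += eps
        dn[j] -= eps
        y_up, y_dn = relaxed_sample(up, g, tau), relaxed_sample(dn, g, tau)
        for i in range(k):
            jac[i][j] = (y_up[i] - y_dn[i]) / (2.0 * eps)
    return jac


logits = [math.log(0.2), math.log(0.5), math.log(0.3)]
g = [0.1, -0.4, 0.7]   # one fixed draw of Gumbel noise (arbitrary values for illustration)
analytic = relaxed_jacobian(logits, g, tau=0.7)
numeric = finite_difference_jacobian(logits, g, tau=0.7)
```

Each row of the Jacobian sums to zero, reflecting that adding the same constant to every logit leaves $y$ unchanged.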

While Gumbel-Softmax samples are differentiable, they are not identical to samples from the corresponding categorical distribution for non-zero temperature. For learning, there is a tradeoff between small temperatures, where samples are close to one-hot but the variance of the gradients is large, and large temperatures, where samples are smooth but the variance of the gradients is small (Figure 1). In practice, we start at a high temperature and anneal to a small but non-zero temperature.

In our experiments, we find that the softmax temperature $\tau$ can be annealed according to a variety of schedules and still perform well. If $\tau$ is a learned parameter (rather than annealed via a fixed schedule), this scheme can be interpreted as entropy regularization (Szegedy et al., 2015; Pereyra et al., 2016), where the Gumbel-Softmax distribution can adaptively adjust the “confidence” of proposed samples during the training process.
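The temperature tradeoff and an annealing schedule can be illustrated as follows. The sketch measures how close samples are to one-hot via the average of $\max_i y_i$ (1 for one-hot vectors, $1/k$ at the center of the simplex) at several temperatures, reusing the same Gumbel noise at each $\tau$. The schedule `annealed_tau` is a hypothetical example (exponential decay clipped at a small positive floor); its parameters `tau0`, `tau_min`, and `rate` are assumptions, not values taken from this section.

```python
import math
import random


def gumbel_softmax_sample(log_pi, tau, rng):
    """y = softmax((log pi + g) / tau) with g_i ~ Gumbel(0, 1)."""
    scores = []
    for lp in log_pi:
        u = rng.random()
        while u == 0.0:                      # avoid log(0)
            u = rng.random()
        scores.append((lp - math.log(-math.log(u))) / tau)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mean_max_coordinate(log_pi, tau, n=2000, seed=0):
    """Average of max_i y_i: 1 for one-hot samples, 1/k at the center of the simplex."""
    rng = random.Random(seed)   # same seed => same Gumbel noise at every tau
    return sum(max(gumbel_softmax_sample(log_pi, tau, rng)) for _ in range(n)) / n


def annealed_tau(step, tau0=5.0, tau_min=0.5, rate=1e-4):
    """Hypothetical schedule: exponential decay from tau0, clipped at tau_min > 0."""
    return max(tau_min, tau0 * math.exp(-rate * step))


log_pi = [math.log(1.0 / 3.0)] * 3
sharpness = {tau: mean_max_coordinate(log_pi, tau) for tau in (5.0, 1.0, 0.1)}
schedule = [annealed_tau(t) for t in (0, 5000, 20000, 100000)]
```

Because the noise is shared across temperatures, $\max_i y_i$ decreases with $\tau$ sample by sample, so the averages are ordered deterministically rather than only in expectation.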

2.1 Straight-Through Gumbel-Softmax Estimator

Continuous relaxations of one-hot vectors are suitable for problems such as learning hidden representations and sequence modeling. For scenarios in which we are constrained to sampling discrete values (e.g. from a discrete action space for reinforcement learning, or quantized compression), we discretize $y$ using $\operatorname*{arg\,max}$ but use our continuous approximation in the backward pass by approximating $\nabla_\theta z \approx \nabla_\theta y$. We call this the Straight-Through (ST) Gumbel Estimator, as it is reminiscent of the biased path derivative estimator described in Bengio et al. (2013). ST Gumbel-Softmax allows samples to be sparse even when the temperature $\tau$ is high.
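A framework-free sketch of the straight-through variant: the forward pass emits the hard one-hot $z$, and the backward pass treats the upstream gradient $\partial L / \partial z$ as if it were $\partial L / \partial y$, then applies the softmax Jacobian of the soft sample. The function name and the manual backward closure are illustrative assumptions; an autodiff framework would normally provide the backward pass.

```python
import math


def st_gumbel_softmax(logits, g, tau):
    """Straight-through Gumbel-Softmax for one fixed Gumbel noise vector g.

    Returns the hard one-hot sample z (forward pass) and a function that maps an
    upstream gradient dL/dz to dL/dlogits using the soft sample y (backward pass).
    """
    s = [(x + gi) / tau for x, gi in zip(logits, g)]
    m = max(s)
    e = [math.exp(v - m) for v in s]
    total = sum(e)
    y = [v / total for v in e]                             # soft sample, used only backward
    idx = max(range(len(y)), key=y.__getitem__)
    z = [1.0 if i == idx else 0.0 for i in range(len(y))]  # hard sample, used forward

    def backward(grad_z):
        # Approximate dL/dy by dL/dz, then apply the softmax Jacobian:
        # dL/dlogit_j = y_j * (grad_j - sum_i y_i * grad_i) / tau
        dot = sum(yi * gz for yi, gz in zip(y, grad_z))
        return [yj * (gj - dot) / tau for yj, gj in zip(y, grad_z)]

    return z, backward


# High temperature: y is close to uniform, yet the forward sample is still one-hot.
z, backward = st_gumbel_softmax([0.0, 1.0, -1.0], g=[0.5, 0.0, 0.2], tau=5.0)
grad_logits = backward([1.0, 2.0, 3.0])
```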

Related Work

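For distributions that are reparameterizable, we can compute the sample $z$ as a deterministic function $g$ of the parameters $\theta$ and an independent random variable $\epsilon$, so that $z = g(\theta, \epsilon)$. The path-wise gradients from the objective $f$ to $\theta$ can then be computed without encountering any stochastic nodes:

$$\frac{\partial}{\partial\theta}\, \mathbb{E}_{z \sim p_\theta}\left[f(z)\right] = \frac{\partial}{\partial\theta}\, \mathbb{E}_{\epsilon}\left[f(g(\theta, \epsilon))\right] = \mathbb{E}_{\epsilon \sim p_\epsilon}\left[\frac{\partial f}{\partial g}\,\frac{\partial g}{\partial \theta}\right]$$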

For example, the normal distribution $z \sim \mathcal{N}(\mu, \sigma)$ can be rewritten as $\mu + \sigma \cdot \mathcal{N}(0, 1)$, making it trivial to compute $\partial z / \partial \mu$ and $\partial z / \partial \sigma$. This reparameterization trick is commonly applied to training variational autoencoders with continuous latent variables using backpropagation (Kingma & Welling, 2013; Rezende et al., 2014b). As shown in Figure 2, we exploit such a trick in the construction of the Gumbel-Softmax estimator.
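A small worked instance of the path-wise estimator for the normal case, assuming a toy objective $f(z) = z^2$ (chosen only because $\mathbb{E}[z^2] = \mu^2 + \sigma^2$ has known gradients $2\mu$ and $2\sigma$). Each sample contributes $f'(z)\,\partial z/\partial\mu = 2z$ and $f'(z)\,\partial z/\partial\sigma = 2z\epsilon$; the function name is illustrative.

```python
import random


def pathwise_grad_normal(mu, sigma, n, seed=0):
    """Monte Carlo path-wise gradient of E[z**2] for z = mu + sigma * eps, eps ~ N(0, 1)."""
    rng = random.Random(seed)
    g_mu = g_sigma = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps          # reparameterized sample
        g_mu += 2.0 * z               # f'(z) * dz/dmu,    with dz/dmu = 1
        g_sigma += 2.0 * z * eps      # f'(z) * dz/dsigma, with dz/dsigma = eps
    return g_mu / n, g_sigma / n


g_mu, g_sigma = pathwise_grad_normal(mu=1.5, sigma=0.5, n=100000)   # approx (3.0, 1.0)
```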

Biased path derivative estimators can be utilized even when $z$ is not reparameterizable. In general, we can approximate $\nabla_\theta z \approx \nabla_\theta m(\theta)$, where $m$ is a differentiable proxy for the stochastic sample. For Bernoulli variables with mean parameter $\theta$, the Straight-Through (ST) estimator (Bengio et al., 2013) approximates $m = \mu_\theta(z)$, implying $\nabla_\theta m = 1$. For $k = 2$ (Bernoulli), ST Gumbel-Softmax is similar to the slope-annealed Straight-Through estimator proposed by Chung et al. (2016), but uses a softmax instead of a hard sigmoid to determine the slope. Rolfe (2016) considers an alternative approach where each binary latent variable parameterizes a continuous mixture model. Reparameterization gradients are obtained by backpropagating through the continuous variables and marginalizing out the binary variables.
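The bias of the Bernoulli ST estimator can be computed exactly for a toy objective. Assuming $f(z) = (z - 0.45)^2$ (an arbitrary illustrative choice), the true gradient of $\mathbb{E}[f(z)] = \theta f(1) + (1-\theta) f(0)$ is $f(1) - f(0) = 0.1$, while ST replaces $dz/d\theta$ with 1, so its expected estimate is $\theta f'(1) + (1-\theta) f'(0)$. The helper names are hypothetical.

```python
def f(z):
    """Toy objective on a binary sample z in {0, 1} (an arbitrary illustrative choice)."""
    return (z - 0.45) ** 2


def f_prime(z):
    return 2.0 * (z - 0.45)


def true_gradient():
    """E[f(z)] = theta * f(1) + (1 - theta) * f(0), so its derivative in theta is f(1) - f(0)."""
    return f(1.0) - f(0.0)


def expected_st_gradient(theta):
    """ST: the forward pass samples z ~ Bernoulli(theta); the backward pass uses dz/dtheta = 1,
    so each sample contributes f'(z). This returns the exact expectation of that estimate."""
    return theta * f_prime(1.0) + (1.0 - theta) * f_prime(0.0)


exact = true_gradient()                 # 0.1, independent of theta for this f
st = expected_st_gradient(theta=0.2)    # -0.5: biased, with the wrong sign here
```

For this quadratic $f$, the two agree only at $\theta = 1/2$, which is consistent with the remark about DARN's choice of $\bar{z} = 1/2$ in Section 3.2.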

One limitation of the ST estimator is that backpropagating with respect to the sample-independent mean may cause discrepancies between the forward and backward pass, leading to higher variance. Gumbel-Softmax avoids this problem because each sample $y$ is a differentiable proxy of the corresponding discrete sample $z$.

3.2 Score Function-Based Gradient Estimators

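The score function estimator (SF, also referred to as REINFORCE (Williams, 1992) and the likelihood ratio estimator (Glynn, 1990)) uses the identity $\nabla_\theta p_\theta(z) = p_\theta(z) \nabla_\theta \log p_\theta(z)$ to derive the following unbiased estimator:

$$\nabla_\theta\, \mathbb{E}_z\left[f(z)\right] = \nabla_\theta \int f(z)\, p_\theta(z)\, dz = \int f(z)\, \nabla_\theta p_\theta(z)\, dz = \int f(z)\, p_\theta(z)\, \nabla_\theta \log p_\theta(z)\, dz = \mathbb{E}_z\left[f(z)\, \nabla_\theta \log p_\theta(z)\right]$$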

SF only requires that $p_\theta(z)$ is continuous in $\theta$, and does not require backpropagating through $f$ or the sample $z$. However, SF suffers from high variance and is consequently slow to converge. In particular, the variance of SF scales linearly with the number of dimensions of the sample vector (Rezende et al., 2014a), making it especially challenging to use for categorical distributions.
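For a categorical variable parameterized by logits, $\nabla_{\text{logits}} \log p(z) = \operatorname{one\_hot}(z) - p$, so the SF estimator only needs $f$ evaluated at sampled classes. The sketch below, with hypothetical names and an arbitrary table of $f$ values, averages single-sample SF estimates and compares them with the closed form $\partial \mathbb{E}[f] / \partial x_j = p_j (f_j - \mathbb{E}[f])$.

```python
import math
import random


def softmax(x):
    m = max(x)
    e = [math.exp(v - m) for v in x]
    total = sum(e)
    return [v / total for v in e]


def sf_samples(logits, f_values, n, seed=0):
    """Per-sample score-function estimates of d/dlogits E_{z ~ Cat(softmax(logits))}[f(z)].

    f_values[i] is the (black-box) objective evaluated at class i.
    For a categorical with logits, grad log p(z) = onehot(z) - p.
    """
    rng = random.Random(seed)
    p = softmax(logits)
    k = len(p)
    out = []
    for _ in range(n):
        z = rng.choices(range(k), weights=p)[0]
        out.append([f_values[z] * ((1.0 if j == z else 0.0) - p[j]) for j in range(k)])
    return out


def exact_gradient(logits, f_values):
    """Closed form for comparison: d/dlogit_j E[f] = p_j * (f_j - E[f])."""
    p = softmax(logits)
    mean_f = sum(pi * fi for pi, fi in zip(p, f_values))
    return [pj * (fj - mean_f) for pj, fj in zip(p, f_values)]


logits = [0.2, -0.5, 1.0]
f_values = [1.0, 3.0, -2.0]
samples = sf_samples(logits, f_values, n=50000)
estimate = [sum(s[j] for s in samples) / len(samples) for j in range(3)]
target = exact_gradient(logits, f_values)
```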

We briefly summarize recent stochastic gradient estimators that utilize control variates. We direct the reader to Gu et al. (2016) for further detail on these techniques.

NVIL (Mnih & Gregor, 2014) uses two baselines: (1) a moving average $\bar{f}$ of $f$ to center the learning signal, and (2) an input-dependent baseline computed by a 1-layer neural network fitted to $f - \bar{f}$ (a control variate for the centered learning signal itself). Finally, variance normalization divides the learning signal by $\max(1, \sigma_f)$, where $\sigma_f^2$ is a moving average of $\text{Var}[f]$.
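The effect of the first NVIL baseline alone can be seen in a simplified sketch: the learning signal $f$ is centered by an exponential moving average $\bar{f}$ before multiplying the score. This omits the input-dependent baseline network and the variance normalization; `decay` and the initial $\bar{f} = 0$ are assumptions. The baseline is updated only after it is used, so it depends on past samples only and the estimator stays unbiased.

```python
import math
import random


def softmax(x):
    m = max(x)
    e = [math.exp(v - m) for v in x]
    total = sum(e)
    return [v / total for v in e]


def sf_samples(logits, f_values, n, decay=None, seed=0):
    """Per-sample score-function estimates for a categorical with the given logits.

    If decay is given, the learning signal is centered by a moving average f_bar of f
    (the first of NVIL's two baselines); otherwise f is used uncentered.
    """
    rng = random.Random(seed)
    p = softmax(logits)
    k = len(p)
    f_bar = 0.0                       # hypothetical initial value of the moving average
    out = []
    for _ in range(n):
        z = rng.choices(range(k), weights=p)[0]
        f = f_values[z]
        signal = f - f_bar if decay is not None else f
        out.append([signal * ((1.0 if j == z else 0.0) - p[j]) for j in range(k)])
        if decay is not None:
            # Update only after use, so the baseline never depends on the current sample.
            f_bar = decay * f_bar + (1.0 - decay) * f
    return out


def variance(values):
    mu = sum(values) / len(values)
    return sum((v - mu) ** 2 for v in values) / len(values)


logits = [0.2, -0.5, 1.0]
f_values = [10.0, 12.0, 8.0]          # large common offset: the case centering helps most
plain = sf_samples(logits, f_values, n=20000)
centered = sf_samples(logits, f_values, n=20000, decay=0.9)
var_plain = variance([s[0] for s in plain])
var_centered = variance([s[0] for s in centered])
```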

DARN (Gregor et al., 2013) uses $b = f(\bar{z}) + f'(\bar{z})(z - \bar{z})$, where the baseline corresponds to the first-order Taylor approximation of $f(z)$ around $\bar{z}$. $\bar{z}$ is chosen to be $1/2$ for Bernoulli variables, which makes the estimator biased for non-quadratic $f$, since it ignores the correction term $\mu_b$ in the estimator expression.

VIMCO (Mnih & Rezende, 2016) is a gradient estimator for multi-sample objectives that uses the mean of the other samples, $b = \frac{1}{m-1}\sum_{j \neq i} f(z_j)$, to construct a baseline for each sample $z_i \in z_{1:m}$. We exclude VIMCO from our experiments because we are comparing estimators for single-sample objectives, although Gumbel-Softmax can be easily extended to multi-sample objectives.
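The leave-one-out baseline can be computed for all $m$ samples in linear time by subtracting each value from the total. This sketch shows only that averaging step on arbitrary scalar values of $f$; it is not a full VIMCO implementation, which applies the idea inside a multi-sample objective. The function name is hypothetical.

```python
def leave_one_out_baselines(f_values):
    """Baseline for sample i: the mean of f over the other m - 1 samples."""
    m = len(f_values)
    if m < 2:
        raise ValueError("a leave-one-out baseline needs at least two samples")
    total = sum(f_values)
    return [(total - fi) / (m - 1) for fi in f_values]


f_values = [1.0, 2.0, 3.0, 6.0]
baselines = leave_one_out_baselines(f_values)       # [11/3, 10/3, 3.0, 2.0]
centered = [fi - bi for fi, bi in zip(f_values, baselines)]
```

Since no sample's baseline involves its own value, each baseline is independent of the sample it centers.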

3.3 Semi-Supervised Generative Models

Semi-supervised learning considers the problem of learning from both labeled data $(x, y) \sim \mathcal{D}_L$ and unlabeled data $x \sim \mathcal{D}_U$, where $x$ are observations (e.g. images) and $y$ are corresponding labels (e.g. semantic class). For semi-supervised classification, Kingma et al. (2014) propose a variational autoencoder (VAE) whose latent state is the joint distribution over a Gaussian “style” variable $z$ and a categorical “semantic class” variable $y$ (Figure 6, Appendix). The VAE objective trains a discriminative network $q_\phi(y \mid x)$, inference network $q_\phi(z \mid x, y)$, and generative network $p_\theta(x \mid y, z)$ end-to-end by maximizing a variational lower bound on the log-likelihood of the observation under the generative model. For labeled data, the class $y$ is observed, so inference is only done on $z \sim q(z \mid x, y)$. The variational lower bound on labeled data is given by:

$$\log p_\theta(x, y) \geq -\mathcal{L}(x, y) = \mathbb{E}_{z \sim q_\phi(z \mid x, y)}\left[\log p_\theta(x \mid y, z)\right] - \mathrm{KL}\left[q_\phi(z \mid x, y)\,\|\,p_\theta(y)\,p(z)\right]$$

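For unlabeled data, difficulties arise because the categorical distribution is not reparameterizable. Kingma et al. (2014) approach this by marginalizing out $y$ over all classes, so that for unlabeled data, inference is still on $q_\phi(z \mid x, y)$ for each $y$. The lower bound on unlabeled data is:

$$\log p_\theta(x) \geq -\mathcal{U}(x) = \mathbb{E}_{(y, z) \sim q_\phi(y, z \mid x)}\left[\log p_\theta(x \mid y, z) + \log p_\theta(y) + \log p(z) - \log q_\phi(y, z \mid x)\right] = \sum_{y} q_\phi(y \mid x)\left(-\mathcal{L}(x, y)\right) + \mathcal{H}\left(q_\phi(y \mid x)\right)$$

The full maximization objective is:

$$\mathcal{J} = \mathbb{E}_{(x,y) \sim \mathcal{D}_L}\left[-\mathcal{L}(x, y)\right] + \mathbb{E}_{x \sim \mathcal{D}_U}\left[-\mathcal{U}(x)\right] + \alpha \cdot \mathbb{E}_{(x,y) \sim \mathcal{D}_L}\left[\log q_\phi(y \mid x)\right]$$

where $\alpha$ is the scalar trade-off between the generative and discriminative objectives.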

One limitation of this approach is that marginalization over all $k$ class values becomes prohibitively expensive for models with a large number of classes. If $D, I, G$ are the computational costs of sampling from $q_\phi(y \mid x)$, $q_\phi(z \mid x, y)$, and $p_\theta(x \mid y, z)$ respectively, then training the unsupervised objective requires $\mathcal{O}(D + k(I + G))$ for each forward/backward step. In contrast, Gumbel-Softmax allows us to backpropagate through $y \sim q_\phi(y \mid x)$ for single-sample gradient estimation, and achieves a cost of $\mathcal{O}(D + I + G)$ per training step. Experimental comparisons in training speed are shown in Figure 5.

Experimental Results

In our first set of experiments, we compare Gumbel-Softmax and ST Gumbel-Softmax to other stochastic gradient estimators: Score-Function (SF), DARN, MuProp, Straight-Through (ST), and Slope-Annealed ST. Each estimator is evaluated on two tasks: (1) structured output prediction and (2) variational training of generative models. We use the MNIST dataset with fixed binarization for training and evaluation, which is common practice for evaluating stochastic gradient estimators (Salakhutdinov & Murray, 2008; Larochelle & Murray, 2011).

We trained a stochastic binary network (SBN) with two hidden layers of 200 units each. This corresponds to either 200 Bernoulli variables (denoted as 392-200-200-392) or 20 categorical variables (each with 10 classes) with binarized activations (denoted as 392-(20×10)-(20×10)-392).

As shown in Figure 3, ST Gumbel-Softmax is on par with the other estimators for Bernoulli variables and outperforms them on categorical variables. Meanwhile, Gumbel-Softmax outperforms the other estimators on both Bernoulli and categorical variables. We found that it was not necessary to anneal the softmax temperature for this task, and used a fixed $\tau = 1$.

4.2 Generative Modeling with Variational Autoencoders

We train variational autoencoders (Kingma & Welling, 2013), where the objective is to learn a generative model of binary MNIST images. In our experiments, we modeled the latent variable as a single hidden layer with 200 Bernoulli variables or 20 categorical variables ($20 \times 10$). We use a learned categorical prior rather than a Gumbel-Softmax prior in the training objective. Thus, the minimization objective during training is no longer a variational bound if the samples are not discrete. In practice, we find that optimizing this objective in combination with temperature annealing still minimizes actual variational bounds on validation and test sets. As in the structured output prediction task, we use a multi-sample bound for evaluation with $m = 1000$.
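The multi-sample evaluation bound is not written out in this section; assuming the standard importance-weighted form $\log \frac{1}{m}\sum_{i=1}^{m} p(x, z_i)/q(z_i \mid x)$ with $z_i \sim q(z \mid x)$, it can be computed stably from per-sample log-weights with the log-sum-exp trick. The function name is illustrative, and producing the log-weights (running the encoder and decoder) is outside this sketch.

```python
import math


def multi_sample_bound(log_weights):
    """log((1/m) * sum_i exp(log_w_i)), computed with the log-sum-exp trick.

    Assumed form: log_w_i = log p(x, z_i) - log q(z_i | x) with z_i ~ q(z | x).
    """
    m = len(log_weights)
    a = max(log_weights)                   # factor out the largest term to avoid underflow
    return a + math.log(sum(math.exp(lw - a) for lw in log_weights)) - math.log(m)


equal = multi_sample_bound([-3.0] * 5)               # identical weights: the bound is -3
tiny = multi_sample_bound([-1000.0, -1000.0])        # no underflow despite exp(-1000)
mixed = multi_sample_bound([-1.0, -2.0, -3.0])       # at least the mean log-weight, -2
exact_case = multi_sample_bound([0.0, math.log(3.0)])  # log((1 + 3) / 2) = log 2
```

By Jensen's inequality the estimate is never below the average log-weight, which is why it tightens as $m$ grows.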

As shown in Figure 4, ST Gumbel-Softmax outperforms other estimators for categorical variables, and Gumbel-Softmax drastically outperforms other estimators for both Bernoulli and categorical variables.

4.3 Generative Semi-Supervised Classification

We apply the Gumbel-Softmax estimator to semi-supervised classification on the binary MNIST dataset. We compare the original marginalization-based inference approach (Kingma et al., 2014) to single-sample inference with Gumbel-Softmax and ST Gumbel-Softmax.

We trained on a dataset consisting of 100 labeled examples (distributed evenly among each of the 10 classes) and 50,000 unlabeled examples, with dynamic binarization of the unlabeled examples for each minibatch. The discriminative model $q_\phi(y \mid x)$ and inference model $q_\phi(z \mid x, y)$ are each implemented as 3-layer convolutional neural networks with ReLU activation functions. The generative model $p_\theta(x \mid y, z)$ is a 4-layer convolutional-transpose network with ReLU activations. Experimental details are provided in Appendix A.

In Kingma et al. (2014), inference over the latent state is done by marginalizing out $y$ and using the reparameterization trick for sampling from $q_\phi(z \mid x, y)$. However, this approach has a computational cost that scales linearly with the number of classes. Gumbel-Softmax allows us to backpropagate directly through single samples from the joint $q_\phi(y, z \mid x)$, achieving drastic speedups in training without compromising generative or classification performance (Table 2, Figure 5).

In Figure 5, we show how Gumbel-Softmax versus marginalization scales with the number of categorical classes. For these experiments, we use MNIST images with randomly generated labels. Training the model with the Gumbel-Softmax estimator is $2\times$ as fast for $10$ classes and $9.9\times$ as fast for $100$ classes.

Discussion

The primary contribution of this work is the reparameterizable Gumbel-Softmax distribution, whose corresponding estimator affords low-variance path derivative gradients for the categorical distribution. We show that Gumbel-Softmax and Straight-Through Gumbel-Softmax are effective on structured output prediction and variational autoencoder tasks, outperforming existing stochastic gradient estimators for both Bernoulli and categorical latent variables. Finally, Gumbel-Softmax enables dramatic speedups in inference over discrete latent variables.

We sincerely thank Luke Vilnis, Vincent Vanhoucke, Luke Metz, David Ha, Laurent Dinh, George Tucker, and Subhaneil Lahiri for helpful discussions and feedback.

References

Appendix A Semi-Supervised Classification Model

Figures 6 and 7 describe the architecture used in our experiments for semi-supervised classification (Section 4.3).

Appendix B Deriving the density of the Gumbel-Softmax distribution

Here we derive the probability density function of the Gumbel-Softmax distribution with probabilities $\pi_1, \dots, \pi_k$ and temperature $\tau$. We first define the logits $x_i = \log \pi_i$, and Gumbel samples $g_1, \dots, g_k$, where $g_i \sim \text{Gumbel}(0,1)$. A sample from the Gumbel-Softmax can then be computed as:

$$y_i = \frac{\exp\big((x_i + g_i)/\tau\big)}{\sum_{j=1}^{k} \exp\big((x_j + g_j)/\tau\big)} \qquad \text{for } i = 1, \dots, k$$

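The mapping from the Gumbel samples $g$ to the Gumbel-Softmax sample $y$ is not invertible, as the normalization of the softmax operation removes one degree of freedom. To compensate for this, we define an equivalent sampling process that subtracts off the last element, $(x_k + g_k)/\tau$, before the softmax (the softmax is unchanged when the same constant is subtracted from all of its inputs):

$$y_i = \frac{\exp\big((x_i + g_i - (x_k + g_k))/\tau\big)}{\sum_{j=1}^{k} \exp\big((x_j + g_j - (x_k + g_k))/\tau\big)} \qquad \text{for } i = 1, \dots, k$$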

To derive the density of this equivalent sampling process, we first derive the density of the “centered” multivariate Gumbel distribution corresponding to:

$$u_i = x_i + g_i - (x_k + g_k) \qquad \text{for } i = 1, \dots, k-1$$

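where $g_i \sim \text{Gumbel}(0,1)$. Note that the probability density of a Gumbel distribution with scale parameter $\beta = 1$ and location $\mu$, evaluated at $z$, is $f(z, \mu) = e^{\mu - z - e^{\mu - z}}$. We can now compute the density of this distribution by marginalizing out the last Gumbel sample, $g_k$:

$$
\begin{aligned}
p(u_1, \dots, u_{k-1}) &= \int_{-\infty}^{\infty} dg_k\, p(u_1, \dots, u_{k-1} \mid g_k)\, p(g_k) \\
&= \int_{-\infty}^{\infty} dg_k\, p(g_k) \prod_{i=1}^{k-1} p(u_i \mid g_k) \\
&= \int_{-\infty}^{\infty} dg_k\, f(g_k, 0) \prod_{i=1}^{k-1} f(x_k + g_k, x_i - u_i) \\
&= \int_{-\infty}^{\infty} dg_k\, e^{-g_k - e^{-g_k}} \prod_{i=1}^{k-1} e^{x_i - u_i - x_k - g_k - e^{x_i - u_i - x_k - g_k}}
\end{aligned}
$$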

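We perform a change of variables with $v = e^{-g_k}$, so $dv = -e^{-g_k}\, dg_k$ and $dg_k = -e^{g_k}\, dv = -dv/v$; the minus sign is absorbed by reversing the limits of integration, since $g_k \to -\infty$ corresponds to $v \to \infty$ and $g_k \to \infty$ to $v \to 0$. We also define $u_k = 0$ to simplify notation:

$$
\begin{aligned}
p(u_1, \dots, u_{k-1}) &= \int_{0}^{\infty} \frac{dv}{v}\, v e^{-v} \prod_{i=1}^{k-1} v e^{x_i - u_i - x_k} \exp\!\left(-v e^{x_i - u_i - x_k}\right) \\
&= \left(\prod_{i=1}^{k-1} e^{x_i - u_i - x_k}\right) \int_{0}^{\infty} dv\, v^{k-1} \exp\!\left(-v \sum_{i=1}^{k} e^{x_i - u_i - x_k}\right) \\
&= \left(\prod_{i=1}^{k-1} e^{x_i - u_i - x_k}\right) \Gamma(k) \left(\sum_{i=1}^{k} e^{x_i - u_i - x_k}\right)^{-k} \\
&= \Gamma(k) \left(\prod_{i=1}^{k} e^{x_i - u_i}\right) \left(\sum_{i=1}^{k} e^{x_i - u_i}\right)^{-k}
\end{aligned}
\tag{15}
$$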
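A quick check of Eq. 15 in the simplest case: for $k = 2$, $u_1 = (x_1 + g_1) - (x_2 + g_2)$ is a difference of two Gumbel variables, which is logistic with location $x_1 - x_2$, so the formula should coincide with the logistic density $\sigma(w)(1 - \sigma(w))$ at $w = u_1 - (x_1 - x_2)$. The helper names are illustrative.

```python
import math


def centered_gumbel_log_density(u, x):
    """log p(u_1, ..., u_{k-1}) for u_i = x_i + g_i - (x_k + g_k), g_i ~ Gumbel(0, 1).

    Uses p = Gamma(k) * prod_i exp(x_i - u_i) * (sum_i exp(x_i - u_i))^(-k), with u_k = 0.
    """
    k = len(x)
    if len(u) != k - 1:
        raise ValueError("u must have k - 1 entries")
    t = [xi - ui for xi, ui in zip(x, list(u) + [0.0])]
    m = max(t)
    log_sum = m + math.log(sum(math.exp(ti - m) for ti in t))   # stable log-sum-exp
    return math.lgamma(k) + sum(t) - k * log_sum


def logistic_density(w):
    """Standard logistic density sigma(w) * (1 - sigma(w))."""
    s = 1.0 / (1.0 + math.exp(-w))
    return s * (1.0 - s)


x = [math.log(0.3), math.log(0.7)]
pairs = [(math.exp(centered_gumbel_log_density([u], x)),
          logistic_density(u - (x[0] - x[1])))
         for u in (-3.0, -0.5, 0.0, 1.2, 4.0)]
```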

B.2 Transforming to a Gumbel-Softmax

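Given samples $u_1, \dots, u_{k-1}$ from the centered Gumbel distribution, we can apply a deterministic transformation $h$ to yield the first $k-1$ coordinates of the sample from the Gumbel-Softmax:

$$y_{1:k-1} = h(u_{1:k-1}), \qquad h_i(u_{1:k-1}) = \frac{\exp(u_i/\tau)}{1 + \sum_{j=1}^{k-1} \exp(u_j/\tau)} \quad \text{for } i = 1, \dots, k-1$$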

Note that the final coordinate $y_k$ is fixed given the first $k-1$, as $\sum_{i=1}^{k} y_i = 1$:

$$y_k = 1 - \sum_{i=1}^{k-1} y_i = \frac{1}{1 + \sum_{j=1}^{k-1} \exp(u_j/\tau)}$$

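We can thus compute the probability of a sample from the Gumbel-Softmax using the change of variables formula on only the first $k-1$ variables:

$$p(y_1, \dots, y_k) = p\big(h^{-1}(y_{1:k-1})\big)\, \left|\det\left(\frac{\partial h^{-1}(y_{1:k-1})}{\partial y_{1:k-1}}\right)\right| \tag{21}$$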

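Thus we need to compute two more pieces: the inverse of $h$ and its Jacobian determinant. The inverse of $h$ is:

$$h^{-1}(y_{1:k-1})_i = \tau\left(\log y_i - \log\Big(1 - \sum_{j=1}^{k-1} y_j\Big)\right) = \tau\,(\log y_i - \log y_k) \quad \text{for } i = 1, \dots, k-1 \tag{22}$$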

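Next, we compute the determinant of the Jacobian:

$$
\begin{aligned}
\det\left(\frac{\partial h^{-1}(y_{1:k-1})}{\partial y_{1:k-1}}\right) &= \det\left(\tau\left[\operatorname{diag}\!\left(\frac{1}{y_{1:k-1}}\right) + \frac{1}{y_k}\, e e^{T}\right]\right) \\
&= \tau^{k-1} \det\left(\operatorname{diag}\!\left(\frac{1}{y_{1:k-1}}\right)\right) \det\left(I + \frac{1}{y_k}\, y_{1:k-1}\, e^{T}\right) \\
&= \tau^{k-1} \left(\prod_{i=1}^{k-1} y_i^{-1}\right) \left(1 + \frac{1}{y_k} \sum_{i=1}^{k-1} y_i\right) \\
&= \tau^{k-1} \prod_{i=1}^{k} y_i^{-1}
\end{aligned}
\tag{26}
$$

where $e$ is a $(k-1)$-dimensional vector of ones, and we've used the identities $\det(AB) = \det(A)\det(B)$, $\det(\operatorname{diag}(x)) = \prod_i x_i$, and $\det(I + uv^{T}) = 1 + u^{T}v$.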

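We can then plug into the change of variables formula (Eq. 21) using the density of the centered Gumbel (Eq. 15), the inverse of $h$ (Eq. 22) and its Jacobian determinant (Eq. 26):

$$
\begin{aligned}
p(y_1, \dots, y_k) &= \Gamma(k) \left(\prod_{i=1}^{k} \exp(x_i) \left(\frac{y_k}{y_i}\right)^{\tau}\right) \left(\sum_{i=1}^{k} \exp(x_i) \left(\frac{y_k}{y_i}\right)^{\tau}\right)^{-k} \tau^{k-1} \prod_{i=1}^{k} y_i^{-1} \\
&= \Gamma(k)\, \tau^{k-1} \left(\sum_{i=1}^{k} \frac{\pi_i}{y_i^{\tau}}\right)^{-k} \prod_{i=1}^{k} \left(\frac{\pi_i}{y_i^{\tau+1}}\right)
\end{aligned}
$$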
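The last two steps can be checked numerically at a single point. The sketch below, a minimal illustration with hypothetical helper names and arbitrary values of $y$, $\pi$, and $\tau$, (1) estimates the Jacobian of $h^{-1}$ by central differences and compares its log-determinant with Eq. 26, and (2) confirms that the centered Gumbel density at $h^{-1}(y)$ times the Jacobian determinant equals the closed-form Gumbel-Softmax density.

```python
import math


def gumbel_softmax_log_density(y, pi, tau):
    """Closed form: Gamma(k) tau^(k-1) (sum pi_i / y_i^tau)^(-k) prod pi_i / y_i^(tau+1)."""
    k = len(y)
    s = sum(p / yi ** tau for p, yi in zip(pi, y))
    log_prod = sum(math.log(p) - (tau + 1.0) * math.log(yi) for p, yi in zip(pi, y))
    return math.lgamma(k) + (k - 1) * math.log(tau) - k * math.log(s) + log_prod


def centered_gumbel_log_density(u, x):
    """Gamma(k) prod exp(x_i - u_i) (sum exp(x_i - u_i))^(-k), with u_k = 0."""
    k = len(x)
    t = [xi - ui for xi, ui in zip(x, list(u) + [0.0])]
    m = max(t)
    return math.lgamma(k) + sum(t) - k * (m + math.log(sum(math.exp(v - m) for v in t)))


def h_inverse(y_head, tau):
    """u_i = tau * (log y_i - log y_k), where y_k = 1 - sum of the first k - 1 coordinates."""
    y_k = 1.0 - sum(y_head)
    return [tau * (math.log(yi) - math.log(y_k)) for yi in y_head]


def det(a):
    """Determinant by Gaussian elimination with partial pivoting."""
    a = [row[:] for row in a]
    n, d = len(a), 1.0
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(a[r][c]))
        if a[p][c] == 0.0:
            return 0.0
        if p != c:
            a[c], a[p] = a[p], a[c]
            d = -d
        d *= a[c][c]
        for r in range(c + 1, n):
            factor = a[r][c] / a[c][c]
            for j in range(c, n):
                a[r][j] -= factor * a[c][j]
    return d


def numeric_log_abs_det(y_head, tau, eps=1e-6):
    """log |det J| of h^{-1}, with J estimated by central differences."""
    n = len(y_head)
    jac = [[0.0] * n for _ in range(n)]
    for j in range(n):
        up, dn = list(y_head), list(y_head)
        up[j] += eps
        dn[j] -= eps
        f_up, f_dn = h_inverse(up, tau), h_inverse(dn, tau)
        for i in range(n):
            jac[i][j] = (f_up[i] - f_dn[i]) / (2.0 * eps)
    return math.log(abs(det(jac)))


y = [0.2, 0.5, 0.3]
pi = [0.1, 0.6, 0.3]
tau = 0.7
k = len(y)
analytic_log_det = (k - 1) * math.log(tau) - sum(math.log(v) for v in y)   # Eq. 26
numeric_log_det = numeric_log_abs_det(y[:-1], tau)
via_change_of_variables = (centered_gumbel_log_density(h_inverse(y[:-1], tau),
                                                       [math.log(p) for p in pi])
                           + analytic_log_det)
closed_form = gumbel_softmax_log_density(y, pi, tau)
```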