The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables

Chris J. Maddison, Andriy Mnih, Yee Whye Teh

Introduction

Software libraries for automatic differentiation (AD) (Abadi et al., 2015; Theano Development Team, 2016) are enjoying broad use, spurred on by the success of neural networks on some of the most challenging problems of machine learning. The dominant mode of development in these libraries is to define a forward parametric computation, in the form of a directed acyclic graph, that computes the desired objective. If the components of the graph are differentiable, then a backwards computation for the gradient of the objective can be derived automatically with the chain rule. The ease of use and unreasonable effectiveness of gradient descent have led to an explosion in the diversity of architectures and objective functions. Thus, expanding the range of useful continuous operations can have an outsized impact on the development of new models. For example, a topic of recent attention has been the optimization of stochastic computation graphs from samples of their states. Here, the observation that AD “just works” when stochastic nodes can be reparameterized into deterministic functions of their parameters and a fixed noise distribution (Kingma & Welling, 2013; Rezende et al., 2014) has liberated researchers in the development of large complex stochastic architectures (e.g. Gregor et al., 2015). (For our purposes, a stochastic node of a computation graph is just a random variable whose distribution depends in some deterministic way on the values of the parent nodes.)

Computing with discrete stochastic nodes still poses a significant challenge for AD libraries. Deterministic discreteness can be relaxed and approximated reasonably well with sigmoidal functions or the softmax (see e.g., Grefenstette et al., 2015; Graves et al., 2016), but, if a distribution over discrete states is needed, there is no clear solution. There are well known unbiased estimators for the gradients of the parameters of a discrete stochastic node from samples. While these can be made to work with AD, they involve special casing and defining surrogate objectives (Schulman et al., 2015), and even then they can have high variance. Still, reasoning about discrete computation comes naturally to humans, and so, despite the difficulty associated, many modern architectures incorporate discrete stochasticity (Mnih et al., 2014; Xu et al., 2015; Kočiský et al., 2016).

This work is inspired by the observation that many architectures treat discrete nodes continuously, and gradients rich with counterfactual information are available for each of their possible states. We introduce a continuous relaxation of discrete random variables, Concrete for short, which allows gradients to flow through their states. The Concrete distribution is a new parametric family of continuous distributions on the simplex with closed form densities. Sampling from the Concrete distribution is as simple as taking the softmax of logits perturbed by fixed additive noise. This reparameterization means that Concrete stochastic nodes are quick to implement in a way that “just works” with AD. Crucially, every discrete random variable corresponds to the zero temperature limit of a Concrete one. In this view, optimizing an objective over an architecture with discrete stochastic nodes can be accomplished by gradient descent on the samples of the corresponding Concrete relaxation. When the objective depends, as in variational inference, on the log-probability of discrete nodes, the Concrete density is used during training in place of the discrete mass. At test time, the graph with discrete nodes is evaluated.

The paper is organized as follows. We provide a background on stochastic computation graphs and their optimization in Section 2. Section 3 reviews a reparameterization for discrete random variables, introduces the Concrete distribution, and discusses its application as a relaxation. Section 4 reviews related work. In Section 5 we present results on a density estimation task and a structured prediction task on the MNIST and Omniglot datasets. In Appendices C and F we provide details on the practical implementation and use of Concrete random variables. When comparing the effectiveness of gradients obtained via Concrete relaxations to a state-of-the-art method (VIMCO, Mnih & Rezende, 2016), we find that they are competitive, occasionally outperforming and occasionally underperforming, all the while being implemented in an AD library without special casing.

Background

Stochastic computation graphs (SCGs) provide a formalism for specifying input-output mappings, potentially stochastic, with learnable parameters using directed acyclic graphs (see Schulman et al. (2015) for a review). The state of each non-input node in such a graph is obtained from the states of its parent nodes by either evaluating a deterministic function or sampling from a conditional distribution. Many training objectives in supervised, unsupervised, and reinforcement learning can be expressed in terms of SCGs.

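In general, both the objective and its gradients are intractable. We will side-step this issue by estimating them with samples from $p_{\phi}(x)$. Writing the objective as $\mathcal{L}(\theta,\phi)=\mathbb{E}_{X\sim p_{\phi}(x)}[f_{\theta}(X)]$, the gradient w.r.t. the parameters $\theta$ has the form

$$\nabla_{\theta}\mathcal{L}(\theta,\phi)=\nabla_{\theta}\mathbb{E}_{X\sim p_{\phi}(x)}[f_{\theta}(X)]=\mathbb{E}_{X\sim p_{\phi}(x)}[\nabla_{\theta}f_{\theta}(X)] \tag{1}$$

and can be easily estimated using Monte Carlo sampling:

$$\nabla_{\theta}\mathcal{L}(\theta,\phi)\simeq\frac{1}{S}\sum_{s=1}^{S}\nabla_{\theta}f_{\theta}(X^{s}), \tag{2}$$

where $X^{s}\sim p_{\phi}(x)$ i.i.d. The more challenging task is to compute the gradient w.r.t. the parameters $\phi$ of $p_{\phi}(x)$. The expression obtained by differentiating the expected objective,

$$\nabla_{\phi}\mathcal{L}(\theta,\phi)=\nabla_{\phi}\int p_{\phi}(x)f_{\theta}(x)\,dx=\int f_{\theta}(x)\nabla_{\phi}p_{\phi}(x)\,dx, \tag{3}$$

does not have the form of an expectation w.r.t. $x$ and thus does not directly lead to a Monte Carlo gradient estimator. However, there are two ways of getting around this difficulty which lead to the two classes of estimators we will now discuss.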

2.2 Score Function Estimators

The score function estimator (SFE, Fu, 2006), also known as the REINFORCE (Williams, 1992) or likelihood-ratio estimator (Glynn, 1990), is based on the identity $\nabla_{\phi}p_{\phi}(x)=p_{\phi}(x)\nabla_{\phi}\log p_{\phi}(x)$, which allows the gradient in Eq. 3 to be written as an expectation:

$$\nabla_{\phi}\mathcal{L}(\theta,\phi)=\mathbb{E}_{X\sim p_{\phi}(x)}[f_{\theta}(X)\nabla_{\phi}\log p_{\phi}(X)]. \tag{4}$$

Estimating this expectation using naive Monte Carlo gives the estimator

$$\nabla_{\phi}\mathcal{L}(\theta,\phi)\simeq\frac{1}{S}\sum_{s=1}^{S}f_{\theta}(X^{s})\nabla_{\phi}\log p_{\phi}(X^{s}), \tag{5}$$

where $X^{s}\sim p_{\phi}(x)$ i.i.d. This is a very general estimator that is applicable whenever $\log p_{\phi}(x)$ is differentiable w.r.t. $\phi$. As it does not require $f_{\theta}(x)$ to be differentiable or even continuous as a function of $x$, the SFE can be used with both discrete and continuous random variables.
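As an illustration of the score function estimator, the following minimal sketch estimates the gradient of an expected objective with respect to the logit of a single Bernoulli node and compares it with the closed-form gradient. The function names, the choice of a Bernoulli node, and the sample count are assumptions made for this example only.

```python
import math
import random


def sfe_gradient(logit, f, num_samples=20000, seed=0):
    """Score function estimate of d/d(logit) E[f(X)], X ~ Bernoulli(sigmoid(logit))."""
    rng = random.Random(seed)
    p = 1.0 / (1.0 + math.exp(-logit))
    total = 0.0
    for _ in range(num_samples):
        x = 1.0 if rng.random() < p else 0.0
        # For a Bernoulli with a logit parameter, d/d(logit) log p(x) = x - p.
        total += f(x) * (x - p)
    return total / num_samples


def exact_gradient(logit, f):
    """Closed-form gradient, available here because X has only two states."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return (f(1.0) - f(0.0)) * p * (1.0 - p)
```

Calling `sfe_gradient(0.5, f)` with a discontinuous `f`, such as a step function, still gives a usable estimate, since only the log-probability is differentiated.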

Though the basic version of the estimator can suffer from high variance, various variance reduction techniques can be used to make the estimator much more effective (Greensmith et al., 2004). Baselines are the most important and widely used of these techniques (Williams, 1992). A number of score function estimators have been developed in machine learning (Paisley et al., 2012; Gregor et al., 2013; Ranganath et al., 2014; Mnih & Gregor, 2014; Titsias & Lázaro-Gredilla, 2015; Gu et al., 2016), which differ primarily in the variance reduction techniques used.

2.3 Reparameterization Trick

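Having reparameterized $p_{\phi}(x)$, that is, having expressed a sample $X\sim p_{\phi}(x)$ as a deterministic function $X=g_{\phi}(Z)$ of a sample $Z$ from a fixed distribution $q(z)$, we can now express the objective as an expectation w.r.t. $q(z)$:

$$\mathcal{L}(\theta,\phi)=\mathbb{E}_{X\sim p_{\phi}(x)}[f_{\theta}(X)]=\mathbb{E}_{Z\sim q(z)}[f_{\theta}(g_{\phi}(Z))]. \tag{6}$$

As $q(z)$ does not depend on $\phi$, we can estimate the gradient w.r.t. $\phi$ in exactly the same way we estimated the gradient w.r.t. $\theta$ in Eq. 1. Assuming differentiability of $f_{\theta}(x)$ w.r.t. $x$ and of $g_{\phi}(z)$ w.r.t. $\phi$ and using the chain rule gives

$$\nabla_{\phi}\mathcal{L}(\theta,\phi)=\mathbb{E}_{Z\sim q(z)}[\nabla_{\phi}f_{\theta}(g_{\phi}(Z))]=\mathbb{E}_{Z\sim q(z)}[f^{\prime}_{\theta}(g_{\phi}(Z))\nabla_{\phi}g_{\phi}(Z)]. \tag{7}$$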

The reparameterization trick, introduced in the context of variational inference independently by Kingma & Welling (2014), Rezende et al. (2014), and Titsias & Lázaro-Gredilla (2014), is usually the estimator of choice when it is applicable. For continuous latent variables which are not directly reparameterizable, new hybrid estimators have also been developed, by combining partial reparameterizations with score function estimators (Ruiz et al., 2016; Naesseth et al., 2016).
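The reparameterization gradient can be sketched for a Gaussian node, where a sample is written as $x=\mu+\sigma z$ with $z$ drawn from a standard Normal. The example below is an assumption-laden illustration (the quadratic objective $f(x)=x^{2}$ and all names are chosen for this sketch), for which the exact gradients are $2\mu$ and $2\sigma$.

```python
import random


def reparam_gradients(mu, sigma, num_samples=20000, seed=0):
    """Monte Carlo reparameterization gradients of E[f(X)] for f(x) = x**2."""
    rng = random.Random(seed)
    grad_mu = 0.0
    grad_sigma = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)   # fixed noise distribution q(z)
        x = mu + sigma * z        # deterministic function g(z) of the parameters
        df_dx = 2.0 * x           # derivative of the objective w.r.t. the sample
        grad_mu += df_dx * 1.0    # chain rule: dx/dmu = 1
        grad_sigma += df_dx * z   # chain rule: dx/dsigma = z
    return grad_mu / num_samples, grad_sigma / num_samples
```

For `mu=1.0` and `sigma=0.5` the two averages should be close to 2 and 1 respectively.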

2.4 Application: Variational Training of Latent Variable Models

The multi-sample variational objective (Burda et al., 2016),

$$\mathcal{L}_{m}(\theta,\phi)=\mathbb{E}_{Z^{i}\sim q_{\phi}(z\,|\,x)}\left[\log\left(\frac{1}{m}\sum_{i=1}^{m}\frac{p_{\theta}(Z^{i},x)}{q_{\phi}(Z^{i}\,|\,x)}\right)\right], \tag{8}$$

provides a convenient alternative which has precisely the form we considered in Section 2.1. This approach relies on introducing an auxiliary distribution $q_{\phi}(z\,|\,x)$ with its own parameters, which serves as an approximation to the intractable posterior $p_{\theta}(z\,|\,x)$. The model is trained by jointly maximizing the objective w.r.t. the parameters of $p$ and $q$. The number of samples $m$ used inside the objective allows trading off the computational cost against the tightness of the bound. For $m=1$, $\mathcal{L}_{m}(\theta,\phi)$ becomes the widely used evidence lower bound (ELBO, Hoffman et al., 2013) on $\log p_{\theta}(x)$, while for $m>1$, it is known as the importance weighted bound (Burda et al., 2016).
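In practice the quantity inside the expectation of the $m$-sample objective is the log of an average of $m$ importance weights, which is usually computed in log space. The sketch below assumes the caller already has the log joint and log variational densities of the $m$ samples; the function names are hypothetical.

```python
import math


def log_mean_exp(log_weights):
    """Numerically stable log((1/m) * sum_i exp(log_weights[i]))."""
    top = max(log_weights)
    total = sum(math.exp(w - top) for w in log_weights)
    return top + math.log(total) - math.log(len(log_weights))


def multisample_bound_estimate(log_joint, log_q):
    """One Monte Carlo draw of the m-sample bound from m pairs of log-densities."""
    log_weights = [lp - lq for lp, lq in zip(log_joint, log_q)]
    return log_mean_exp(log_weights)
```

With a single sample this reduces to the usual one-sample ELBO term, and by Jensen's inequality the log of the average weight is never smaller than the average of the log weights.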

The Concrete Distribution

Consider a discrete random variable $D$ represented as a one-hot vector in $\{0,1\}^{n}$ with unnormalized probabilities $\alpha_{k}\in(0,\infty)$. The Gumbel-Max trick samples $U_{k}\sim\operatorname{Uniform}(0,1)$ i.i.d. for each $k$, finds the $k$ that maximizes $\log\alpha_{k}-\log(-\log U_{k})$, and sets $D_{k}=1$ and the remaining $D_{i}=0$ for $i\neq k$. Then

$$\mathbb{P}(D_{k}=1)=\frac{\alpha_{k}}{\sum_{i=1}^{n}\alpha_{i}}. \tag{9}$$

In other words, the sampling of a discrete random variable can be refactored into a deterministic function (componentwise addition followed by $\operatorname{argmax}$) of the parameters $\log\alpha_{k}$ and fixed distribution $-\log(-\log U_{k})$. See Figure 1(a) for a visualization.

The apparently arbitrary choice of noise gives the trick its name, as $-\log(-\log U)$ has a Gumbel distribution. This distribution features in extreme value theory (Gumbel, 1954) where it plays a central role similar to the Normal distribution: the Gumbel distribution is stable under $\max$ operations, and for some distributions, the order statistics (suitably normalized) of i.i.d. draws approach the Gumbel in distribution. The Gumbel can also be recognized as a $-\log$-transformed exponential random variable. So, the correctness of (9) also reduces to a well known result regarding the $\operatorname{argmin}$ of exponential random variables. See Hazan et al. (2016) for a collection of related work, and particularly the chapter by Maddison (2016) for a proof and generalization of this trick.
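The Gumbel-Max trick can be sketched in a few lines: perturb each $\log\alpha_{k}$ with independent Gumbel noise $-\log(-\log U_{k})$ and return the index of the maximum. The helper names below are assumptions of this sketch.

```python
import math
import random


def sample_gumbel(rng):
    """Draw -log(-log U) with U ~ Uniform(0, 1), avoiding U == 0."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(alphas, rng):
    """Return the index k that maximizes log(alpha_k) + G_k."""
    perturbed = [math.log(a) + sample_gumbel(rng) for a in alphas]
    return max(range(len(alphas)), key=lambda k: perturbed[k])
```

Over many draws, the empirical frequency of index $k$ should approach $\alpha_{k}/\sum_{i}\alpha_{i}$.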

3.2 Concrete Random Variables

The derivative of the $\operatorname{argmax}$ is 0 everywhere except at the boundary of state changes, where it is undefined. For this reason the Gumbel-Max trick is not a suitable reparameterization for use in SCGs with AD. Here we introduce the Concrete distribution motivated by considering a graph, which is the same as Figure 1(a) up to a continuous relaxation of the $\operatorname{argmax}$ computation, see Figure 1(b). This will ultimately allow the optimization of parameters $\alpha_{k}$ via gradients. Given $\alpha\in(0,\infty)^{n}$ and a temperature $\lambda\in(0,\infty)$, we sample $G_{k}\sim\operatorname{Gumbel}$ i.i.d. and set

$$X_{k}=\frac{\exp((\log\alpha_{k}+G_{k})/\lambda)}{\sum_{i=1}^{n}\exp((\log\alpha_{i}+G_{i})/\lambda)}. \tag{10}$$

The softmax computation of (10) smoothly approaches the discrete $\operatorname{argmax}$ computation as $\lambda\to 0$ while preserving the relative order of the Gumbels $\log\alpha_{k}+G_{k}$. So, imagine making a series of forward passes on the graphs of Figure 1. Both graphs return a stochastic value for each forward pass, but for smaller temperatures the outputs of Figure 1(b) become more discrete and eventually indistinguishable from a typical forward pass of Figure 1(a).
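A Concrete sample can be sketched as the softmax of the Gumbel-perturbed logits divided by the temperature. The following is a minimal illustration with hypothetical function names; the maximum is subtracted before exponentiating purely for numerical stability.

```python
import math
import random


def sample_gumbel(rng):
    """Draw -log(-log U) with U ~ Uniform(0, 1), avoiding U == 0."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))


def concrete_sample(alphas, temperature, rng):
    """Softmax of (log(alpha_k) + G_k) / temperature; returns a point on the simplex."""
    logits = [(math.log(a) + sample_gumbel(rng)) / temperature for a in alphas]
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]
```

At low temperatures the largest coordinate of a sample is typically close to 1, while at high temperatures the samples move toward the interior of the simplex.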

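The distribution of $X$ sampled via (10) has a closed form density on the simplex. Because there may be other ways to sample a Concrete random variable, we take the density to be its definition: a random variable $X\in\Delta^{n-1}$ has a Concrete distribution $X\sim\operatorname{Concrete}(\alpha,\lambda)$ with location $\alpha\in(0,\infty)^{n}$ and temperature $\lambda\in(0,\infty)$ if its density is

$$p_{\alpha,\lambda}(x)=(n-1)!\,\lambda^{n-1}\prod_{k=1}^{n}\left(\frac{\alpha_{k}x_{k}^{-\lambda-1}}{\sum_{i=1}^{n}\alpha_{i}x_{i}^{-\lambda}}\right). \tag{11}$$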
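The closed form density is most safely evaluated in log space. The sketch below assumes the density has the form $(n-1)!\,\lambda^{n-1}\prod_{k}\alpha_{k}x_{k}^{-\lambda-1}/\sum_{i}\alpha_{i}x_{i}^{-\lambda}$ and uses a log-sum-exp for the normalizing sum; the function name is hypothetical.

```python
import math


def concrete_log_density(x, alphas, temperature):
    """Log density of a Concrete random variable at a point x in the open simplex."""
    n = len(x)
    # terms[k] = log(alpha_k) - lambda * log(x_k), the summands of the normalizer.
    terms = [math.log(a) - temperature * math.log(v) for a, v in zip(alphas, x)]
    top = max(terms)
    log_norm = top + math.log(sum(math.exp(t - top) for t in terms))
    # log((n-1)!) + (n-1) * log(lambda)
    log_constant = math.lgamma(n) + (n - 1) * math.log(temperature)
    # sum_k [log(alpha_k) - (lambda + 1) * log(x_k)]
    log_numerator = sum(t - math.log(v) for t, v in zip(terms, x))
    return log_constant + log_numerator - n * log_norm
```

As a sanity check, for $n=2$, $\alpha=(1,1)$ and $\lambda=1$ the density is uniform on the unit interval, so the log density is zero everywhere.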

Proposition 1 lists a few properties of the Concrete distribution. (a) is confirmation that our definition corresponds to the sampling routine (10). (b) confirms that rounding a Concrete random variable results in the discrete random variable whose distribution is described by the logits $\log\alpha_{k}$. (c) confirms that taking the zero temperature limit of a Concrete random variable is the same as rounding. Finally, (d) is a convexity result on the density. We prove these results in Appendix A.

(d) (Convex eventually) If $\lambda\leq(n-1)^{-1}$, then $p_{\alpha,\lambda}(x)$ is log-convex in $x$.

The binary case of the Gumbel-Max trick simplifies to passing additive noise through a step function. The corresponding Concrete relaxation is implemented by passing additive noise through a sigmoid—see Figure 3. We cover this more thoroughly in Appendix B, along with a cheat sheet (Appendix F) on the density and implementation of all the random variables discussed in this work.
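The binary case can be sketched by passing the same additive Logistic noise through either a step function (the discrete sample) or a sigmoid scaled by the temperature (the relaxed sample). Function names are assumptions of this sketch, and the sigmoid is written in a branch-wise form to avoid overflow at low temperatures.

```python
import math
import random


def stable_sigmoid(t):
    """Logistic function that never exponentiates a large positive number."""
    if t >= 0.0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def binary_discrete_and_concrete(alpha, temperature, rng):
    """Return (hard, relaxed) samples that share the same Logistic noise."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    noisy_logit = math.log(alpha) + math.log(u) - math.log(1.0 - u)
    hard = 1 if noisy_logit > 0.0 else 0                 # step function
    relaxed = stable_sigmoid(noisy_logit / temperature)  # sigmoid relaxation
    return hard, relaxed
```

The hard sample equals 1 with probability $\alpha/(1+\alpha)$, and rounding the relaxed sample at $1/2$ recovers the hard one.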

3.3 Concrete Relaxations

Concrete random variables may have some intrinsic value, but we investigate them simply as surrogates for optimizing a SCG with discrete nodes. When it is computationally feasible to integrate over the discreteness, that will always be a better choice. Thus, we consider the use case of optimizing a large graph with discrete stochastic nodes from samples.

First, we outline our proposal for how to use Concrete relaxations by considering a variational autoencoder with a single discrete latent variable. Let $P_{a}(d)$ be the mass function of some $n$-dimensional one-hot discrete random variable with unnormalized probabilities $a\in(0,\infty)^{n}$ and $p_{\theta}(x|d)$ some distribution over a data point $x$ given a one-hot $d\in\{0,1\}^{n}$. The generative model is then $p_{\theta,a}(x,d)=p_{\theta}(x|d)P_{a}(d)$. Let $Q_{\alpha}(d|x)$ be an approximating posterior over one-hot $d\in\{0,1\}^{n}$ whose unnormalized probabilities $\alpha(x)\in(0,\infty)^{n}$ depend on $x$. All together the variational lowerbound we care about stochastically optimizing is

$$\mathcal{L}_{1}(\theta,a,\alpha)=\mathbb{E}_{D\sim Q_{\alpha}(d|x)}\left[\log\frac{p_{\theta}(x|D)P_{a}(D)}{Q_{\alpha}(D|x)}\right]. \tag{12}$$

To relax this objective, we replace the discrete sample $D$ with a Concrete sample $Z\sim q_{\alpha,\lambda_{1}}(z|x)$ at temperature $\lambda_{1}$ and replace the discrete masses with the corresponding Concrete densities, which gives

$$\mathbb{E}_{Z\sim q_{\alpha,\lambda_{1}}(z|x)}\left[\log\frac{p_{\theta}(x|Z)\,p_{a,\lambda_{2}}(Z)}{q_{\alpha,\lambda_{1}}(Z|x)}\right], \tag{13}$$

where $p_{a,\lambda_{2}}(z)$ is a Concrete density with location $a$ and temperature $\lambda_{2}$. At test time we evaluate the discrete lowerbound $\mathcal{L}_{1}(\theta,a,\alpha)$. Naively implementing Eq. 13 will result in numerical issues. We discuss this and other details in Appendix C.

Thus, the basic paradigm we propose is the following: during training replace every discrete node with a Concrete node at some fixed temperature (or with an annealing schedule). The graphs are identical up to the $\operatorname{softmax}$ / $\operatorname{argmax}$ computations, so the parameters of the relaxed graph and discrete graph are the same. When an objective depends on the log-probability of discrete variables in the SCG, as the variational lowerbound does, we propose that the log-probability terms are also “relaxed” to represent the true distribution of the relaxed node. At test time the original discrete loss is evaluated. This is possible because the discretization of any Concrete distribution has a closed form mass function, and the relaxation of any discrete distribution into a Concrete distribution has a closed form density. Other relaxations do not always have this property. For example, the multinomial probit model (the Gumbel-Max trick with Gaussians replacing Gumbels) does not have a closed form mass.

The success of Concrete relaxations will depend on the choice of temperature during training. It is important that the relaxed nodes are not able to represent a precise real valued mode in the interior of the simplex as in Figure 2(d). If this is the case, it is possible for the relaxed random variable to communicate much more than $\log_{2}(n)$ bits of information about its $\alpha$ parameters. This might lead the relaxation to prefer the interior of the simplex to the vertices, and as a result there will be a large integrality gap in the overall performance of the discrete graph. Therefore Proposition 1 (d) is a conservative guideline for generic $n$-ary Concrete relaxations; at temperatures lower than $(n-1)^{-1}$ we are guaranteed not to have any modes in the interior for any $\alpha\in(0,\infty)^{n}$. We discuss the subtleties of choosing the temperatures in more detail in Appendix C. Ultimately the best choice of $\lambda$ and the performance of the relaxation for any specific $n$ will be an empirical question.

Related Work

Perhaps the most common distribution over the simplex is the Dirichlet with density $p_{\alpha}(x)\propto\prod_{k=1}^{n}x_{k}^{\alpha_{k}-1}$ on $x\in\Delta^{n-1}$. The Dirichlet can be characterized by strong independence properties, and a great deal of work has been done to generalize it (Connor & Mosimann, 1969; Aitchison, 1985; Rayens & Srinivasan, 1994; Favaro et al., 2011). Of note is the Logistic Normal distribution (Aitchison & Shen, 1980), which can be simulated by taking the softmax of $n-1$ normal random variables and an $n$th logit that is deterministically zero. The Logistic Normal is an important distribution, because it can effectively model correlations within the simplex (Blei & Lafferty, 2006). To our knowledge the Concrete distribution does not fall completely into any family of distributions previously described. For $\lambda\leq 1$ the Concrete is in a class of normalized infinitely divisible distributions (S. Favaro, personal communication), and the results of Favaro et al. (2011) apply.

The idea of using a softmax of Gumbels as a relaxation for a discrete random variable was concurrently considered by Jang et al. (2016), where it was called the Gumbel-Softmax. They do not use the density in the relaxed objective, opting instead to compute all aspects of the graph, including discrete log-probability computations, with the relaxed stochastic state of the graph. In the case of variational inference, this relaxed objective is not a lower bound on the marginal likelihood of the observations, and care needs to be taken when optimizing it. The idea of using sigmoidal functions with additive input noise to approximate discreteness is also not new. Frey (1997) introduced nonlinear Gaussian units which computed their activation by passing Gaussian noise with the mean and variance specified by the input to the unit through a nonlinearity, such as the logistic function. Salakhutdinov & Hinton (2009) binarized real-valued codes of an autoencoder by adding (Gaussian) noise to the logits before passing them through the logistic function. Most recently, to avoid the difficulty associated with likelihood-ratio methods, Kočiský et al. (2016) relaxed the discrete sampling operation by sampling a vector of Gaussians instead and passing those through a softmax.

There is another family of gradient estimators that have been studied in the context of training neural networks with discrete units. These are usually collected under the umbrella of straight-through estimators (Bengio et al., 2013; Raiko et al., 2014). The basic idea they use is passing forward discrete values, but taking gradients through the expected value. They have good empirical performance, but have not been shown to be the estimators of any loss function. This is in contrast to gradients from Concrete relaxations, which are biased with respect to the discrete graph, but unbiased with respect to the continuous one.

Experiments

The aim of our experiments was to evaluate the effectiveness of the gradients of Concrete relaxations for optimizing SCGs with discrete nodes. We considered the tasks in Mnih & Rezende (2016): structured output prediction and density estimation. Both tasks are difficult optimization problems involving fitting probability distributions with hundreds of latent discrete nodes. We compared the performance of Concrete reparameterizations to two state-of-the-art score function estimators: VIMCO (Mnih & Rezende, 2016) for optimizing the multisample variational objective ($m>1$) and NVIL (Mnih & Gregor, 2014) for optimizing the single-sample one ($m=1$). We performed the experiments using the MNIST and Omniglot datasets. These are datasets of $28\times 28$ images of handwritten digits (MNIST) or letters (Omniglot). For MNIST we used the fixed binarization of Salakhutdinov & Murray (2008) and the standard 50,000/10,000/10,000 split into training/validation/testing sets. For Omniglot we sampled a fixed binarization and used the standard 24,345/8,070 split into training/testing sets. We report the negative log-likelihood (NLL) of the discrete graph on the test data as the performance metric.

5.2 Density Estimation

Density estimation, or generative modelling, is the problem of fitting the distribution of data. We took the latent variable approach described in Section 2.4 and trained the models by optimizing the variational objective $\mathcal{L}_{m}(\theta,\phi)$ given by Eq. 8 averaged uniformly over minibatches of data points $x$. Both our generative models $p_{\theta}(z,x)$ and variational distributions $q_{\phi}(z\,|\,x)$ were parameterized with neural networks as described above. We trained models with $\mathcal{L}_{m}(\theta,\phi)$ for $m\in\{1,5,50\}$ and approximated the NLL with $\mathcal{L}_{50,000}(\theta,\phi)$ averaged uniformly over the whole dataset.

The results are shown in Table 1. In general, VIMCO outperformed Concrete relaxations for linear models and Concrete relaxations outperformed VIMCO for non-linear models. We also tested the effectiveness of Concrete relaxations on generative models with $n$-ary layers on the $\mathcal{L}_{5}(\theta,\phi)$ objective. With Concrete relaxations, the best 4-ary model achieved test/train NLL 86.7/83.3 and the best 8-ary model achieved 87.4/84.6; more complete results are given in Appendix E. The relatively poor performance of the 8-ary model may be because moving from 4 to 8 results in a more difficult objective without much added capacity. As a control we trained $n$-ary models using logistic normals as relaxations of discrete distributions (with retuned temperature hyperparameters). Because the discrete zero temperature limit of logistic Normals is a multinomial probit whose mass function is not known, we evaluated the discrete model by sampling from the discrete distribution parameterized by the logits learned during training. The best 4-ary model achieved test/train NLL of 88.7/85.0, and the best 8-ary model achieved 89.1/85.1.

5.3 Structured Output Prediction

Structured output prediction is concerned with modelling the high-dimensional distribution of the observation given a context and can be seen as conditional density estimation. We considered the task of predicting the bottom half $x_{1}$ of an image of an MNIST digit given its top half $x_{2}$, as introduced by Raiko et al. (2014). We followed Raiko et al. (2014) in using a model with layers of discrete stochastic units between the context and the observation. Conditioned on the top half $x_{2}$ the network samples from a distribution $p_{\phi}(z\,|\,x_{2})$ over layers of stochastic units $z$, then predicts $x_{1}$ by sampling from a distribution $p_{\theta}(x_{1}\,|\,z)$. The training objective for a single pair $(x_{1},x_{2})$ is

$$\mathcal{L}^{SP}_{m}(\theta,\phi)=\mathbb{E}_{Z_{i}\sim p_{\phi}(z|x_{2})}\left[\log\left(\frac{1}{m}\sum_{i=1}^{m}p_{\theta}(x_{1}\,|\,Z_{i})\right)\right].$$

This objective is a special case of $\mathcal{L}_{m}(\theta,\phi)$ (Eq. 8) where we use the prior $p_{\phi}(z|x_{2})$ as the variational distribution. Thus, the objective is a lower bound on $\log p_{\theta,\phi}(x_{1}\,|\,x_{2})$.

We trained the models by optimizing $\mathcal{L}^{SP}_{m}(\theta,\phi)$ for $m\in\{1,5,50\}$ averaged uniformly over minibatches and evaluated them by computing $\mathcal{L}^{SP}_{100}(\theta,\phi)$ averaged uniformly over the entire dataset. The results are shown in Figure 4. Concrete relaxations more uniformly outperformed VIMCO in this instance. We also trained $n$-ary (392V–240H–240H–240H–392V) models on the $\mathcal{L}^{SP}_{1}(\theta,\phi)$ objective using the best temperature hyperparameters from density estimation. 4-ary achieved a test/train NLL of 55.4/46.0 and 8-ary achieved 54.7/44.8. As opposed to density estimation, increasing arity uniformly improved the models. We also investigated the hypothesis that for higher temperatures Concrete relaxations might prefer the interior of the interval to the boundary points $\{-1,1\}$. Figure 4 was generated with a binary (392V–240H–240H–240H–392V) model trained on $\mathcal{L}^{SP}_{1}(\theta,\phi)$.

Conclusion

We introduced the Concrete distribution, a continuous relaxation of discrete random variables. The Concrete distribution is a new distribution on the simplex with a closed form density parameterized by a vector of positive location parameters and a positive temperature. Crucially, the zero temperature limit of every Concrete distribution corresponds to a discrete distribution, and any discrete distribution can be seen as the discretization of a Concrete one. The application we considered was training stochastic computation graphs with discrete stochastic nodes. The gradients of Concrete relaxations are biased with respect to the original discrete objective, but they are low variance unbiased estimators of a continuous surrogate objective. We showed in a series of experiments that stochastic nodes with Concrete distributions can be used effectively to optimize the parameters of a stochastic computation graph with discrete stochastic nodes. We did not find that annealing or automatically tuning the temperature was important for these experiments, but it remains interesting and possibly valuable future work.

We thank Jimmy Ba for the excitement and ideas in the early days, and Stefano Favaro for some analysis of the distribution. We also thank Gabriel Barth-Maron and Roger Grosse.

References

Appendix A Proof of Proposition 1

(a) Let $Z_{k}=\log\alpha_{k}+G_{k}$, which has density

$$\alpha_{k}\exp(-z_{k})\exp(-\alpha_{k}\exp(-z_{k})).$$

We will consider the invertible transformation

where $y_{n}=1-\sum_{i=1}^{n-1}y_{i}$. This has Jacobian

by adding $y_{i}/y_{n}$ times each of the top $n-1$ rows to the bottom row we see that this Jacobian has the same determinant as

with $r=\log c$ change of variables we have density

(b) Follows directly from (a) and the Gumbel-Max trick (Maddison, 2016).

(c) Follows directly from (a) and the Gumbel-Max trick (Maddison, 2016).

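(d) Let $\lambda\leq(n-1)^{-1}$. The density of $X$, evaluated at a point $y$ of the simplex, can be rewritten as

$$p_{\alpha,\lambda}(y)\propto\prod_{k=1}^{n}\left(\frac{\alpha_{k}y_{k}^{-\lambda-1}}{\sum_{i=1}^{n}\alpha_{i}y_{i}^{-\lambda}}\right)=\prod_{k=1}^{n}\left(\frac{\alpha_{k}y_{k}^{\lambda(n-1)-1}}{\sum_{i=1}^{n}\alpha_{i}\prod_{j\neq i}y_{j}^{\lambda}}\right).$$

Thus, the log density is up to an additive constant $C$

$$\log p_{\alpha,\lambda}(y)=\sum_{k=1}^{n}(\lambda(n-1)-1)\log y_{k}-n\log\left(\sum_{k=1}^{n}\alpha_{k}\prod_{j\neq k}y_{j}^{\lambda}\right)+C.$$

If $\lambda\leq(n-1)^{-1}$, then the first $n$ terms are convex, because $-\log$ is convex. For the last term, $-\log$ is convex and non-increasing and $\prod_{j\neq k}y_{j}^{\lambda}$ is concave for $\lambda\leq(n-1)^{-1}$. Thus, their composition is convex. The sum of convex terms is convex, finishing the proof.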
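The convexity claim can be probed numerically with a midpoint test. The sketch below assumes the log density, up to an additive constant, is $\sum_{k}(\lambda(n-1)-1)\log y_{k}-n\log\sum_{k}\alpha_{k}\prod_{j\neq k}y_{j}^{\lambda}$, and the function names are hypothetical. A non-negative gap on a segment is consistent with convexity along it, while a negative gap exhibits a violation.

```python
import math


def rewritten_log_density(y, alphas, temperature):
    """Log density up to an additive constant, in the product form used in the proof."""
    n = len(y)
    first = sum((temperature * (n - 1) - 1.0) * math.log(v) for v in y)
    inner = 0.0
    for k in range(n):
        prod = 1.0
        for j in range(n):
            if j != k:
                prod *= y[j] ** temperature
        inner += alphas[k] * prod
    return first - n * math.log(inner)


def midpoint_gap(a, b, alphas, temperature):
    """Average of the endpoint values minus the midpoint value (>= 0 if convex)."""
    mid = [(u + v) / 2.0 for u, v in zip(a, b)]
    ends = 0.5 * (rewritten_log_density(a, alphas, temperature)
                  + rewritten_log_density(b, alphas, temperature))
    return ends - rewritten_log_density(mid, alphas, temperature)
```

For $n=3$ and a segment through the centre of the simplex, the gap is non-negative at $\lambda=1/2$ but negative at $\lambda=2$, where the density has an interior mode.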

Appendix B The Binary Special Case

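Bernoulli random variables are an important special case of discrete distributions taking states in $\{0,1\}$. Here we consider the binary special case of the Gumbel-Max trick from Figure 1(a) along with the corresponding Concrete relaxation.

Thus, $D_{1}\overset{d}{=}H(\log\alpha+\log U-\log(1-U))$, where $H$ is the unit step function.

We define the Binary Concrete random variable $X$ by its density on the unit interval: $X\in(0,1)$ has a Binary Concrete distribution with location $\alpha\in(0,\infty)$ and temperature $\lambda\in(0,\infty)$ if its density is

$$p_{\alpha,\lambda}(x)=\frac{\lambda\alpha x^{-\lambda-1}(1-x)^{-\lambda-1}}{(\alpha x^{-\lambda}+(1-x)^{-\lambda})^{2}}.$$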

We state without proof the special case of Proposition 1 for Binary Concrete distributions

(d) (Convex eventually) If $\lambda\leq 1$, then $p_{\alpha,\lambda}(x)$ is log-convex in $x$.

If we want this to have a Bernoulli distribution with probability $\alpha/(1+\alpha)$, then we should solve the equation

$$1-\Phi(0)=\frac{\alpha}{1+\alpha}.$$

This gives $\Phi(0)=1/(1+\alpha)$, which can be accomplished by relocating the random variable $Y$ with CDF $\Phi$ to be $X=Y-\Phi^{-1}(1/(1+\alpha))$.
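The relocation argument can be checked for a concrete choice of noise. The sketch below assumes standard Normal noise purely for illustration and uses `statistics.NormalDist` from the Python standard library; the function names are hypothetical.

```python
from statistics import NormalDist


def relocation_shift(alpha, noise):
    """Phi^{-1}(1 / (1 + alpha)) for noise with CDF Phi."""
    return noise.inv_cdf(1.0 / (1.0 + alpha))


def probability_of_one(alpha, noise):
    """P(X > 0) for the relocated variable X = Y - shift, with Y ~ noise."""
    shift = relocation_shift(alpha, noise)
    # X > 0 exactly when Y > shift, so P(X > 0) = 1 - Phi(shift).
    return 1.0 - noise.cdf(shift)
```

For any $\alpha>0$ the returned probability should equal $\alpha/(1+\alpha)$ up to floating point error, for example 0.75 when $\alpha=3$.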

Appendix C Using Concrete Relaxations

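In this section we include some tips for implementing and using the Concrete distribution as a relaxation. We use the following notation

$$\sigma(x)=\frac{1}{1+\exp(-x)},\qquad\operatorname*{L\Sigma E}_{k=1}^{n}\{x_{k}\}=\log\left(\sum_{k=1}^{n}\exp(x_{k})\right).$$

Both sigmoid and log-sum-exp are common operations in libraries like TensorFlow or Theano.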
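When these two operations are not provided by a library, they can be written in a numerically stable way. The following is a minimal standard-library sketch with hypothetical function names, not the implementation of any particular framework.

```python
import math


def sigmoid(x):
    """Logistic function, written so that exp is only called on non-positive inputs."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def log_sum_exp(values):
    """log(sum_k exp(values[k])), shifted by the maximum to avoid overflow."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))
```

Subtracting the maximum before exponentiating keeps every exponent non-positive, so large logits cannot overflow.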

For the sake of exposition, we consider a simple variational autoencoder with a single discrete random variable and objective $\mathcal{L}_{1}(\theta,a,\alpha)$ given by Eq. 8 for a single data point $x$. This scenario will allow us to discuss all of the decisions one might make when using Concrete relaxations.

We are interested in optimizing this objective with respect to $\theta$, $a$, and any parameters in $\alpha$ from samples of the SCG required to simulate an estimator of $\mathcal{L}_{1}(\theta,a,\alpha)$.

C.2 What you might relax and why

This choice allows us to take derivatives through the stochastic computations of the graph.

The second consideration is which objective to put in place of $[\,\cdot\,]$ in Eq. 19. We will consider the ideal scenario irrespective of numerical issues. In Subsection C.3 we address those numerical issues. The central question is how to treat the expectation of the ratio $P_{a}(D)/Q_{\alpha}(D|x)$ (which is the KL component of the loss) when $Z$ replaces $D$.

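There are at least three options for how to modify the objective. They are, (20) replace the discrete mass with Concrete densities, (21) relax the computation of the discrete log mass, (22) replace it with the analytic discrete KL.

$$\mathbb{E}_{Z\sim q_{\alpha,\lambda_{1}}(z|x)}\left[\log p_{\theta}(x|Z)+\log\frac{p_{a,\lambda_{2}}(Z)}{q_{\alpha,\lambda_{1}}(Z|x)}\right] \tag{20}$$

$$\mathbb{E}_{Z\sim q_{\alpha,\lambda_{1}}(z|x)}\left[\log p_{\theta}(x|Z)+\sum_{i=1}^{n}Z_{i}\log\frac{P_{a}(d^{(i)})}{Q_{\alpha}(d^{(i)}|x)}\right] \tag{21}$$

$$\mathbb{E}_{Z\sim q_{\alpha,\lambda_{1}}(z|x)}\left[\log p_{\theta}(x|Z)\right]+\sum_{i=1}^{n}Q_{\alpha}(d^{(i)}|x)\log\frac{P_{a}(d^{(i)})}{Q_{\alpha}(d^{(i)}|x)} \tag{22}$$

where $d^{(i)}$ is a one-hot binary vector with $d_{i}^{(i)}=1$ and $p_{a,\lambda_{2}}(z)$ is the density of some Concrete random variable with temperature $\lambda_{2}$ and location parameters $a$. Although (22) or (21) is tempting, we emphasize that these are NOT necessarily lower bounds on $\log p(x)$ in the relaxed model. (20) is the only objective guaranteed to be a lower bound:

$$\mathbb{E}_{Z\sim q_{\alpha,\lambda_{1}}(z|x)}\left[\log p_{\theta}(x|Z)+\log\frac{p_{a,\lambda_{2}}(Z)}{q_{\alpha,\lambda_{1}}(Z|x)}\right]\leq\log\int p_{\theta}(x|z)\,p_{a,\lambda_{2}}(z)\,dz.$$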

For this reason we consider objectives of the form (20). Choosing (22) or (21) is possible, but the value of these objectives is not interpretable, and one should stop early; otherwise the model will overfit to the spurious “KL” component of the loss. We now consider practical issues with (20) and how to address them. All together we can interpret $q_{\alpha,\lambda_{1}}(z|x)$ as the Concrete relaxation of the variational posterior and $p_{a,\lambda_{2}}(z)$ as the relaxation of the prior.

C.3 Which random variable to treat as the stochastic node

When implementing a SCG like the variational autoencoder example, we need to compute log-probabilities of Concrete random variables. This computation can suffer from underflow, so where possible it’s better to take a different node on the relaxed graph as the stochastic node on which log-likelihood terms are computed. For example, it’s tempting in the case of Concrete random variables to treat the Gumbels as the stochastic node on which the log-likelihood terms are evaluated and the softmax as downstream computation. This will be a looser bound in the context of variational inference than the corresponding bound when treating the Concrete relaxed states as the node.

and the objective on the RHS is fully reparameterizable and what we chose to optimize.

C.3.2 Binary Concrete

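In the binary case, the logistic function is invertible, so it makes most sense to treat the logit plus noise as the stochastic node. In particular, the binary random node was sampled from:

$$Y=\frac{\log\alpha+\log U-\log(1-U)}{\lambda}, \tag{26}$$

where $U\sim\operatorname{Uniform}(0,1)$, and the relaxed state is $Z=\sigma(Y)$.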

All together the relaxation in the binary special case would be

where $f_{a,\lambda_{2}}(y)$ is the density of a Logistic random variable sampled via Eq. 26 with location $a$ and temperature $\lambda_{2}$.
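Treating the logit plus noise as the stochastic node means evaluating a Logistic log-density rather than a Binary Concrete one. The sketch below assumes the node is $Y=(\log\alpha+L)/\lambda$ with $L$ a standard Logistic variable, which by a change of variables gives $\log f(y)=\log\lambda-\lambda y+\log\alpha-2\log(1+\exp(-\lambda y+\log\alpha))$; the function names are hypothetical.

```python
import math


def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def logit_node_log_density(y, alpha, temperature):
    """Log density of Y = (log(alpha) + L) / temperature, with L standard Logistic."""
    t = -temperature * y + math.log(alpha)
    return math.log(temperature) + t - 2.0 * softplus(t)
```

Because this density lives on the whole real line, it avoids evaluating logarithms of relaxed states that are numerically indistinguishable from 0 or 1.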

This section had a dense array of densities, so we summarize the relevant ones, along with how to sample from them, in Appendix F.

C.4 Choosing the temperature

For $n=2$ (the binary case), taking $\lambda\leq(n-1)^{-1}=1$ is a good guideline. For $n>2$, taking $\lambda\leq(n-1)^{-1}$ is not necessarily a good guideline, although it will depend on $n$ and the specific application. As $n\to\infty$ the Concrete distribution becomes peakier, because the random normalizing constant $\sum_{k=1}^{n}\exp((\log\alpha_{k}+G_{k})/\lambda)$ grows. This means that practically speaking the optimization can tolerate much higher temperatures than $(n-1)^{-1}$. We found in the case $n=4$ that $\lambda=1$ was the best temperature and in the case $n=8$ that $\lambda=2/3$ was the best. Yet $\lambda=2/3$ was the best single performing temperature across the $n\in\{2,4,8\}$ cases that we considered. We recommend starting in that ballpark and exploring for any specific application.

When the loss depends on a KL divergence between two Concrete nodes, it’s possible to give the nodes distinct temperatures. We found this to improve results quite dramatically. In the context of our original problem and its relaxation:

$$\mathbb{E}_{Z\sim q_{\alpha,\lambda_{1}}(z|x)}\left[\log p_{\theta}(x|Z)+\log\frac{p_{a,\lambda_{2}}(Z)}{q_{\alpha,\lambda_{1}}(Z|x)}\right]. \tag{28}$$

Both $\lambda_{1}$ for the posterior temperature and $\lambda_{2}$ for the prior temperature are tunable hyperparameters.

Appendix D Experimental Details

The basic model architectures we considered are exactly analogous to those in Burda et al. (2016) with Concrete/discrete random variables replacing Gaussians.

The conditioning functions we used were either linear or non-linear. Non-linear consisted of two $\tanh$ layers of the same size as the preceding stochastic layer in the computation graph.

D.2 $n$-ary layers

D.3 Bias Initialization

All biases were initialized to 0 with the exception of the biases in the prior decoder distribution over the 784 or 392 observed units. These were initialized to the logit of the base rate averaged over the respective dataset (MNIST or Omniglot).
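The bias initialization for the observed units can be sketched as follows. The sketch assumes the base rate is computed per observed unit over a list of binarized images, and it clips the rate away from 0 and 1 so that the logit stays finite; the clipping constant and the function name are assumptions of this example, not details from the experiments.

```python
import math


def base_rate_logit_bias(binary_images, eps=1e-3):
    """Per-unit bias equal to the logit of the unit's mean activation over the dataset."""
    num_images = len(binary_images)
    num_units = len(binary_images[0])
    biases = []
    for j in range(num_units):
        rate = sum(image[j] for image in binary_images) / num_images
        rate = min(max(rate, eps), 1.0 - eps)  # clipping is an added assumption
        biases.append(math.log(rate) - math.log(1.0 - rate))
    return biases
```

A unit that is on in half of the images receives a bias of zero, while always-on and always-off units receive large positive and negative biases respectively.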

D.4 Centering

We also found it beneficial to center the layers of the inference network during training. The activity in $(-1,1)^{d}$ of each stochastic layer was centered during training by maintaining an exponentially decaying average with rate 0.9 over minibatches. This running average was subtracted from the activity of the layer before it was updated. Gradients did not flow through this computation, so it simply amounted to a dynamic offset. The averages were not updated during the evaluation.
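The centering step can be sketched as a small stateful helper. The sketch assumes the running average starts at zero, is subtracted before being updated, and is updated as `0.9 * average + 0.1 * batch_mean`; the class name and the exact update form are assumptions. The average is treated as a constant, which corresponds to not passing gradients through it.

```python
class RunningCenter:
    """Subtracts an exponentially decaying average of past minibatch means."""

    def __init__(self, size, decay=0.9):
        self.decay = decay
        self.average = [0.0] * size

    def __call__(self, batch, training=True):
        # Subtract the current average first (a constant offset, no gradient).
        centered = [[a - m for a, m in zip(row, self.average)] for row in batch]
        if training:
            # Then update the average with this minibatch's mean activity.
            n = len(batch)
            batch_mean = [sum(row[j] for row in batch) / n
                          for j in range(len(self.average))]
            self.average = [self.decay * m + (1.0 - self.decay) * b
                            for m, b in zip(self.average, batch_mean)]
        return centered
```

Calling the helper with `training=False` applies the stored offset without changing it, matching the description that averages are frozen during evaluation.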

D.5 Hyperparameter Selection

All models were initialized with the heuristic of Glorot & Bengio (2010) and optimized using Adam (Kingma & Ba, 2014) with parameters $\beta_{1}=0.9,\beta_{2}=0.999$ for $10^{7}$ steps on minibatches of size 64. Hyperparameters were selected on the MNIST dataset by grid search taking the values that performed best on the validation set. Learning rates were chosen from $\{10^{-4},3\cdot 10^{-4},10^{-3}\}$ and weight decay from $\{0,10^{-2},10^{-1},1\}$. Two sets of hyperparameters were selected, one for linear models and one for non-linear models. The linear models’ hyperparameters were selected with the 200H–200H–784V density model on the $\mathcal{L}_{5}(\theta,\phi)$ objective. The non-linear models’ hyperparameters were selected with the 200H$\sim$200H$\sim$784V density model on the $\mathcal{L}_{5}(\theta,\phi)$ objective. For density estimation, the Concrete relaxation hyperparameters were (weight decay = 0, learning rate = $3\cdot 10^{-4}$) for linear and (weight decay = 0, learning rate = $10^{-4}$) for non-linear. For structured prediction Concrete relaxations used (weight decay = $10^{-3}$, learning rate = $3\cdot 10^{-4}$).

In addition to tuning learning rate and weight decay, we tuned temperatures for the Concrete relaxations on the density estimation task. We found it valuable to have different values for the prior and posterior distributions, see Eq. 28. In particular, for binary we found that (prior $\lambda_{2}=1/2$, posterior $\lambda_{1}=2/3$) was best, for 4-ary we found (prior $\lambda_{2}=2/3$, posterior $\lambda_{1}=1$) was best, and (prior $\lambda_{2}=2/5$, posterior $\lambda_{1}=2/3$) for 8-ary. No temperature annealing was used. For structured prediction we used just the corresponding posterior $\lambda_{1}$ as the temperature for the whole graph, as there was no variational posterior.

We performed early stopping when training with the score function estimators (VIMCO/NVIL) as they were much more prone to overfitting.

Appendix E Extra Results

Appendix F Cheat Sheet