Gradient Estimation Using Stochastic Computation Graphs

John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

1 Introduction

The great success of neural networks is due in part to the simplicity of the backpropagation algorithm, which allows one to efficiently compute the gradient of any loss function defined as a composition of differentiable functions. This simplicity has allowed researchers to search in the space of architectures for those that are both highly expressive and conducive to optimization, yielding, for example, convolutional neural networks in vision and LSTMs for sequence data. However, the backpropagation algorithm is only sufficient when the loss function is a deterministic, differentiable function of the parameter vector.

A rich class of problems arising throughout machine learning requires optimizing loss functions that involve an expectation over random variables. Two broad categories of these problems are (1) likelihood maximization in probabilistic models with latent variables, and (2) policy gradients in reinforcement learning. Combining ideas from these two perennial topics, recent models of attention and memory have used networks that involve a combination of stochastic and deterministic operations.

In most of these problems, from probabilistic modeling to reinforcement learning, the loss functions and their gradients are intractable, as they involve either a sum over an exponential number of latent variable configurations or high-dimensional integrals that have no analytic solution. Prior work (see Section 6) has provided problem-specific derivations of Monte Carlo gradient estimators; however, to our knowledge, no previous work addresses the general case.

Appendix C recalls several classic and recent techniques in variational inference and reinforcement learning, where the loss functions can be straightforwardly described using the formalism of stochastic computation graphs that we introduce. For these examples, the variance-reduced gradient estimators derived in prior work are special cases of the results in Sections 3 and 4.

The contributions of this work are as follows:

We introduce a formalism of stochastic computation graphs, and in this general setting, we derive unbiased estimators for the gradient of the expected loss.

We show how this estimator can be computed as the gradient of a certain differentiable function (which we call the surrogate loss); hence, it can be computed efficiently using the backpropagation algorithm. This observation enables a practitioner to write an efficient implementation using automatic differentiation software.

We describe variance reduction techniques that can be applied to the setting of stochastic computation graphs, generalizing prior work from reinforcement learning and variational inference.

We briefly describe how to generalize some other optimization techniques to this setting: majorization-minimization algorithms, by constructing an expression that bounds the loss function; and quasi-Newton / Hessian-free methods, by computing estimates of Hessian-vector products.

The main practical result of this article is that to compute the gradient estimator, one just needs to make a simple modification to the backpropagation algorithm, where extra gradient signals are introduced at the stochastic nodes. Equivalently, the resulting algorithm is just the backpropagation algorithm, applied to the surrogate loss function, which has extra terms introduced at the stochastic nodes. The modified backpropagation algorithm is presented in Section 5.

2 Preliminaries

2.1 Gradient Estimators for a Single Random Variable

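Suppose that $x$ is a random variable, $f$ is a function (say, the cost), and we want to compute $\frac{\partial}{\partial\theta}\mathbb{E}_x[f(x)]$. First, we might be given a parameterized probability distribution $x\sim p(\cdot;\theta)$. In this case, we can use the score function (SF) estimator:

$$\frac{\partial}{\partial\theta}\mathbb{E}_x[f(x)] = \mathbb{E}_x\left[f(x)\frac{\partial}{\partial\theta}\log p(x;\theta)\right]. \tag{1}$$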

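This classic equation is derived as follows:

$$\begin{aligned}\frac{\partial}{\partial\theta}\mathbb{E}_x[f(x)] &= \frac{\partial}{\partial\theta}\int dx\,p(x;\theta)f(x) = \int dx\,\frac{\partial}{\partial\theta}p(x;\theta)f(x)\\ &= \int dx\,p(x;\theta)\frac{\partial}{\partial\theta}\log p(x;\theta)f(x) = \mathbb{E}_x\left[f(x)\frac{\partial}{\partial\theta}\log p(x;\theta)\right].\end{aligned} \tag{2}$$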

This equation is valid if and only if $p(x;\theta)$ is a continuous function of $\theta$; however, it does not need to be a continuous function of $x$.
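As a minimal sketch of Equation 1 (not code from this paper), the Python below estimates $\frac{\partial}{\partial\theta}\mathbb{E}_x[f(x)]$ for the assumed case $x\sim\mathcal{N}(\theta,1)$, whose score is $\frac{\partial}{\partial\theta}\log p(x;\theta)=x-\theta$. The function name `sf_gradient_estimate`, the Gaussian choice, and the sample counts are illustrative assumptions. The second cost is a step function, which is discontinuous in $x$; SF still applies because only $p(x;\theta)$ is differentiated.

```python
import math
import random


def sf_gradient_estimate(f, theta, n_samples, rng):
    """Score function (SF) estimate of d/dtheta E_{x ~ N(theta, 1)}[f(x)].

    Assumes a unit-variance Gaussian (illustrative choice), for which
    d/dtheta log p(x; theta) = x - theta.
    """
    total = 0.0
    for _ in range(n_samples):
        x = rng.gauss(theta, 1.0)  # sample x ~ p(.; theta)
        score = x - theta          # d/dtheta log p(x; theta)
        total += f(x) * score      # only the value f(x) is needed, not f'(x)
    return total / n_samples


rng = random.Random(0)
theta = 1.5

# Smooth cost: E[x^2] = theta^2 + 1, so the exact gradient is 2 * theta.
sf_square = sf_gradient_estimate(lambda x: x * x, theta, 200_000, rng)

# Step cost (discontinuous in x): E[1{x > 0}] = Phi(theta),
# so the exact gradient is the standard normal density at theta.
sf_step = sf_gradient_estimate(lambda x: 1.0 if x > 0 else 0.0, theta, 200_000, rng)
exact_step = math.exp(-0.5 * theta**2) / math.sqrt(2.0 * math.pi)

print(f"x^2:  SF {sf_square:.3f}  exact {2 * theta:.3f}")
print(f"step: SF {sf_step:.4f}  exact {exact_step:.4f}")
```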

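Alternatively, $x$ may be a deterministic, differentiable function of $\theta$ and another random variable $z$, i.e., we can write $x(z,\theta)$. Then we can use the pathwise derivative (PD) estimator, defined as follows:

$$\frac{\partial}{\partial\theta}\mathbb{E}_z[f(x(z,\theta))] = \mathbb{E}_z\left[\frac{\partial}{\partial\theta}f(x(z,\theta))\right]. \tag{3}$$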

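This equation, which merely swaps the derivative and expectation, is valid if and only if $f(x(z,\theta))$ is a continuous function of $\theta$ for all $z$. That is not true if, for example, $f$ is a step function. (For the pathwise derivative estimator, $f(x(z,\theta))$ merely needs to be a continuous function of $\theta$; it is sufficient that this function is almost-everywhere differentiable. A similar statement can be made about $p(x;\theta)$ and the score function estimator. See Glasserman for a detailed discussion of the technical requirements for these gradient estimators to be valid.)

Finally, $\theta$ might appear both in the probability distribution and inside the expectation, as in $\frac{\partial}{\partial\theta}\mathbb{E}_{z\sim p(\cdot;\theta)}[f(x(z,\theta))]$. Then the gradient estimator has two terms:

$$\frac{\partial}{\partial\theta}\mathbb{E}_{z\sim p(\cdot;\theta)}[f(x(z,\theta))] = \mathbb{E}_{z\sim p(\cdot;\theta)}\left[\frac{\partial}{\partial\theta}f(x(z,\theta)) + \left(\frac{\partial}{\partial\theta}\log p(z;\theta)\right)f(x(z,\theta))\right]. \tag{4}$$

This formula can be derived by writing the expectation as an integral and differentiating, as in Equation 2.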
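A matching sketch of the pathwise derivative estimator (Equation 3), again under illustrative assumptions: $z\sim\mathcal{N}(0,1)$ and $x(z,\theta)=\theta+z$, so $\partial x/\partial\theta=1$ and the estimator simply averages $f'(x)$. The helper name `pd_gradient_estimate` is hypothetical. The step-function case shows the failure mode noted above: the derivative of a step is zero wherever it exists, so PD returns 0 even though the true gradient (about 0.13 at $\theta=1.5$) is not zero.

```python
import random


def pd_gradient_estimate(f_prime, theta, n_samples, rng):
    """Pathwise derivative (PD) estimate of d/dtheta E_z[f(x(z, theta))].

    Assumes x(z, theta) = theta + z with z ~ N(0, 1), so dx/dtheta = 1.
    """
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)    # noise distribution does not depend on theta
        x = theta + z              # deterministic, differentiable in theta
        total += f_prime(x) * 1.0  # chain rule: f'(x) * dx/dtheta
    return total / n_samples


rng = random.Random(0)
theta = 1.5

# f(x) = x^2 has f'(x) = 2x; the exact gradient is 2 * theta = 3.0.
pd_square = pd_gradient_estimate(lambda x: 2.0 * x, theta, 20_000, rng)

# A step function has f'(x) = 0 wherever it is differentiable, so PD is biased here.
pd_step = pd_gradient_estimate(lambda x: 0.0, theta, 20_000, rng)

print(f"x^2:  PD {pd_square:.3f}  exact {2 * theta:.3f}")
print(f"step: PD {pd_step:.4f}  (true gradient is nonzero)")
```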

In some cases, it is possible to reparameterize a probabilistic model, moving $\theta$ from the distribution to inside the expectation or vice versa. This idea has been discussed in general terms in prior work, and it has recently been applied to variational inference (see Section 6).

The SF and PD estimators are applicable in different scenarios and have different properties.

SF is valid under more permissive mathematical conditions than PD. SF can be used if $f$ is discontinuous, or if $x$ is a discrete random variable.

SF only requires sample values $f(x)$, whereas PD requires the derivatives $f'(x)$. In the context of control (reinforcement learning), SF can be used to obtain unbiased policy gradient estimators in the “model-free” setting, where we have no model of the dynamics and only have access to sample trajectories.

SF tends to have higher variance than PD when both estimators are applicable (see, for instance, the comparison of the two estimators discussed in Section 6). The variance of SF increases (often linearly) with the dimensionality of the sampled variables. Hence, PD is usually preferable when $x$ is high-dimensional. On the other hand, PD has high variance if the function $f$ is rough, which occurs in many time-series problems due to an “exploding gradient problem” / “butterfly effect”.

PD allows for a deterministic limit; SF does not. This idea is exploited by the deterministic policy gradient algorithm.
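The variance comparison above can be checked with a small experiment. This sketch assumes $x\in\mathbb{R}^d$ with $x\sim\mathcal{N}(\theta\mathbf{1}, I)$ for a scalar $\theta$ and the cost $f(x)=\frac{1}{d}\sum_i x_i^2$, whose exact gradient $2\theta$ does not depend on $d$; the setup and the function name are illustrative. In this particular example the SF variance grows roughly linearly in $d$, while the PD variance shrinks.

```python
import random
import statistics


def estimator_samples(theta, dim, n_samples, rng):
    """Per-sample SF and PD gradient estimates for the illustrative cost
    f(x) = mean(x_i^2), where x_i = theta + z_i and z_i ~ N(0, 1) independently."""
    sf, pd = [], []
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        x = [theta + zi for zi in z]
        f = sum(xi * xi for xi in x) / dim
        # SF: f(x) * d/dtheta log p(x; theta), where the score is sum_i (x_i - theta).
        sf.append(f * sum(z))
        # PD: d/dtheta f(theta + z) = (1/dim) * sum_i 2 * x_i.
        pd.append(sum(2.0 * xi for xi in x) / dim)
    return sf, pd


rng = random.Random(0)
results = {}
for dim in (1, 10, 50):
    sf, pd = estimator_samples(1.0, dim, 5_000, rng)
    results[dim] = {
        "sf_mean": statistics.fmean(sf), "sf_var": statistics.variance(sf),
        "pd_mean": statistics.fmean(pd), "pd_var": statistics.variance(pd),
    }
    r = results[dim]
    print(f"d={dim:3d}  SF mean {r['sf_mean']:.2f} var {r['sf_var']:8.2f}  "
          f"PD mean {r['pd_mean']:.2f} var {r['pd_var']:.3f}")
```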

The methods of estimating gradients of expectations have been independently proposed in several different fields, which use differing terminology. What we call the score function estimator is alternatively called the likelihood ratio estimator and REINFORCE. We chose this term because the score function is a well-known object in statistics. What we call the pathwise derivative estimator (from the mathematical finance literature and reinforcement learning) is alternatively called infinitesimal perturbation analysis and stochastic backpropagation. We chose this term because pathwise derivative is evocative of propagating a derivative through a sample path.

2.2 Stochastic Computation Graphs

The results of this article will apply to stochastic computation graphs, which are defined as follows.

Definition 1 (Stochastic Computation Graph). A directed, acyclic graph, with three types of nodes:

Input nodes, which are set externally, including the parameters we differentiate with respect to.

Deterministic nodes, which are functions of their parents.

Stochastic nodes, which are distributed conditionally on their parents.

Each parent $v$ of a non-input node $w$ is connected to it by a directed edge $(v,w)$. In the subsequent diagrams of this article, we will use circles to denote stochastic nodes and squares to denote deterministic nodes. The structure of the graph fully specifies which estimator we will use: SF, PD, or a combination thereof. This graphical notation is shown below, along with the single-variable estimators from Section 2.1.

2.3 Simple Examples

Several simple examples that illustrate the stochastic computation graph formalism are shown below. The gradient estimators can be described by writing the expectations as integrals and differentiating, as with the simpler estimators from Section 2.1. However, they are also implied by the general results that we will present in Section 3.

These simple examples illustrate several important motifs, where stochastic and deterministic nodes are arranged in series or in parallel. For example, note that in (2) the derivative of $y$ does not appear in the estimator, since the path from $\theta$ to $f$ is “blocked” by $x$. Similarly, in (3), $p(y\mid x)$ does not appear (this type of behavior is particularly useful if we only have access to a simulator of a system, but not to the actual likelihood function). On the other hand, (4) has a direct path from $\theta$ to $f$, which contributes a term to the gradient estimator. (5) resembles a parameterized Markov reward process, and it illustrates that we obtain score function terms of the form grad log-probability $\times$ future costs.

The examples above all have one input $\theta$, but the formalism accommodates models with multiple inputs, for example a stochastic neural network with multiple layers of weights and biases, which may influence different subsets of the stochastic and cost nodes. See Appendix C for nontrivial examples with stochastic nodes and multiple inputs. The figure on the right shows a deterministic computation graph representing classification loss for a two-layer neural network, which has four parameters $(W_1,b_1,W_2,b_2)$ (weights and biases). Of course, this deterministic computation graph is a special type of stochastic computation graph.

3 Main Results on Stochastic Computation Graphs

3.1 Gradient Estimators

This section will consider a general stochastic computation graph, in which a certain set of nodes is designated as costs, and we would like to compute the gradient of the sum of costs with respect to some input node $\theta$.

In brief, the main results of this section are as follows:

We derive a gradient estimator for an expected sum of costs in a stochastic computation graph. This estimator contains two parts: (1) a score function part, which is a sum of terms of the form grad log-prob of a variable $\times$ sum of costs influenced by that variable; and (2) a pathwise derivative term, which propagates the dependence through differentiable functions.

This gradient estimator can be computed efficiently by differentiating an appropriate “surrogate” objective function.

Let $\Theta$ denote the set of input nodes, $\mathcal{D}$ the set of deterministic nodes, and $\mathcal{S}$ the set of stochastic nodes. Further, we will designate a set of cost nodes $\mathcal{C}$, which are scalar-valued and deterministic. (Note that there is no loss of generality in assuming that the costs are deterministic: if a cost is stochastic, we can simply append a deterministic node that applies the identity function to it.) We will use $\theta$ to denote an input node ($\theta\in\Theta$) that we differentiate with respect to. In the context of machine learning, we will usually be most concerned with differentiating with respect to a parameter vector (or tensor); however, the theory we present does not make any assumptions about what $\theta$ represents.

For the results that follow, we need to define the notion of “influence”, for which we will introduce two relations $\prec$ and $\prec^D$. The relation $v\prec w$ (“$v$ influences $w$”) means that there exists a sequence of nodes $a_1,a_2,\dots,a_K$, with $K\geq 0$, such that $(v,a_1),(a_1,a_2),\dots,(a_{K-1},a_K),(a_K,w)$ are edges in the graph (for $K=0$, this is simply the edge $(v,w)$). The relation $v\prec^D w$ (“$v$ deterministically influences $w$”) is defined similarly, except that now we require that each $a_k$ is a deterministic node. For example, in Figure 1, diagram (5) above, $\theta$ influences $\{x_1,x_2,f_1,f_2\}$, but it only deterministically influences $\{x_1,x_2\}$.
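The relations $\prec$ and $\prec^D$ can be computed by a graph search. The sketch below assumes one plausible encoding of diagram (5), with edges $\theta\to x_1$, $\theta\to x_2$, $x_0\to x_1\to x_2$, $x_1\to f_1$, $x_2\to f_2$; the dictionary layout and the function name are hypothetical. The only difference between the two relations is whether the walk may continue through a stochastic node. Because only the intermediate nodes $a_k$ are restricted, the stochastic children $x_1,x_2$ still count as deterministically influenced by $\theta$.

```python
from collections import deque

# Hypothetical encoding of diagram (5): node -> (node type, children).
GRAPH = {
    "theta": ("input", ["x1", "x2"]),
    "x0": ("stochastic", ["x1"]),
    "x1": ("stochastic", ["x2", "f1"]),
    "x2": ("stochastic", ["f2"]),
    "f1": ("deterministic", []),
    "f2": ("deterministic", []),
}


def influenced(graph, v, deterministic_only=False):
    """Return {w : v < w}, or {w : v <^D w} when deterministic_only is True.

    Every node reachable by a directed path is collected. For <^D the search
    only continues *through* deterministic nodes, because each intermediate
    node a_k must be deterministic, while the endpoint w may be of any type.
    """
    result = set()
    queue = deque(graph[v][1])
    while queue:
        w = queue.popleft()
        if w in result:
            continue
        result.add(w)
        if not deterministic_only or graph[w][0] == "deterministic":
            queue.extend(graph[w][1])
    return result


print(sorted(influenced(GRAPH, "theta")))                           # ['f1', 'f2', 'x1', 'x2']
print(sorted(influenced(GRAPH, "theta", deterministic_only=True)))  # ['x1', 'x2']
```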

Next, we will establish a condition that is sufficient for the existence of the gradient. Namely, we will stipulate that every edge $(v,w)$ with $w$ lying in the “influenced” set of $\theta$ corresponds to a differentiable dependency: if $w$ is deterministic, then the Jacobian $\frac{\partial w}{\partial v}$ must exist; if $w$ is stochastic, then the probability mass function $p(w\mid v,\dots)$ must be differentiable with respect to $v$.

More formally:

Condition 1 (Differentiability Requirements). Given input node $\theta\in\Theta$, for all edges $(v,w)$ which satisfy $\theta\prec^D v$ and $\theta\prec^D w$, the following condition holds: if $w$ is deterministic, the Jacobian $\frac{\partial w}{\partial v}$ exists, and if $w$ is stochastic, the derivative of the probability mass function $\frac{\partial}{\partial v}p(w\mid\textsc{parents}_w)$ exists.

Note that Condition 1 does not require that all the functions in the graph are differentiable. If the path from an input $\theta$ to deterministic node $v$ is blocked by stochastic nodes, then $v$ may be a nondifferentiable function of its parents. If a path from input $\theta$ to stochastic node $v$ is blocked by other stochastic nodes, the likelihood of $v$ given its parents need not be differentiable; in fact, it does not need to be known. (This fact is particularly important for reinforcement learning, allowing us to compute policy gradient estimates despite having a discontinuous dynamics function or reward function.)

We need a few more definitions to state the main theorems. Let $\textsc{deps}_v := \{w\in\Theta\cup\mathcal{S}\mid w\prec^D v\}$, the “dependencies” of node $v$, i.e., the set of nodes that deterministically influence it. Note the following:

If $v\in\mathcal{S}$, the probability mass function of $v$ is a function of $\textsc{deps}_v$, i.e., we can write $p(v\mid\textsc{deps}_v)$.

If $v\in\mathcal{D}$, $v$ is a deterministic function of $\textsc{deps}_v$, so we can write $v(\textsc{deps}_v)$.

Let $\hat{Q}_v := \sum_{c\succ v,\; c\in\mathcal{C}}\hat{c}$, i.e., the sum of costs downstream of node $v$. These costs will be treated as constant, fixed to the values obtained during sampling. In general, we will use the hat symbol $\hat{v}$ to denote a sample value of variable $v$, which will be treated as constant in the gradient formulae.
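Continuing with a hypothetical encoding of diagram (5) (edges $\theta\to x_1$, $\theta\to x_2$, $x_0\to x_1\to x_2$, $x_1\to f_1$, $x_2\to f_2$), the sketch below computes $\textsc{deps}_v$ and $\hat{Q}_v$. The sampled cost values are made up for illustration, and all names are hypothetical. The downstream sum for $x_1$ picks up both costs, while $x_2$ only sees $f_2$, which is the grad log-probability $\times$ future costs pattern of that diagram.

```python
from collections import deque

# Hypothetical encoding of diagram (5): node -> (node type, children).
GRAPH = {
    "theta": ("input", ["x1", "x2"]),
    "x0": ("stochastic", ["x1"]),
    "x1": ("stochastic", ["x2", "f1"]),
    "x2": ("stochastic", ["f2"]),
    "f1": ("deterministic", []),
    "f2": ("deterministic", []),
}
COSTS = {"f1", "f2"}


def influenced(graph, v, deterministic_only=False):
    """Nodes w with v < w (or v <^D w): walk edges, passing through
    stochastic nodes only when deterministic_only is False."""
    result, queue = set(), deque(graph[v][1])
    while queue:
        w = queue.popleft()
        if w in result:
            continue
        result.add(w)
        if not deterministic_only or graph[w][0] == "deterministic":
            queue.extend(graph[w][1])
    return result


def deps(graph, v):
    """deps_v = {w in inputs or stochastic nodes : w <^D v}."""
    return {
        w for w, (kind, _) in graph.items()
        if kind in ("input", "stochastic") and v in influenced(graph, w, True)
    }


def q_hat(graph, v, sampled_costs):
    """Q-hat_v: sum of sampled cost values downstream of v (held constant)."""
    return sum(sampled_costs[c] for c in influenced(graph, v) if c in COSTS)


sampled_costs = {"f1": 2.0, "f2": 5.0}  # made-up sample values
for node in ("x0", "x1", "x2"):
    print(node, sorted(deps(GRAPH, node)), q_hat(GRAPH, node, sampled_costs))
```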

Now we can write down a general expression for the gradient of the expected sum of costs in a stochastic computation graph:

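Theorem 1. Suppose that $\theta\in\Theta$ satisfies Condition 1. Then the following two equivalent equations hold:

$$\frac{\partial}{\partial\theta}\mathbb{E}\left[\sum_{c\in\mathcal{C}}c\right] = \mathbb{E}\left[\sum_{\substack{w\in\mathcal{S},\\ \theta\prec^D w}}\left(\frac{\partial}{\partial\theta}\log p(w\mid\textsc{deps}_w)\right)\hat{Q}_w + \sum_{\substack{c\in\mathcal{C},\\ \theta\prec^D c}}\frac{\partial}{\partial\theta}c(\textsc{deps}_c)\right] \tag{5}$$

$$\frac{\partial}{\partial\theta}\mathbb{E}\left[\sum_{c\in\mathcal{C}}c\right] = \mathbb{E}\left[\sum_{c\in\mathcal{C}}\hat{c}\sum_{\substack{w\in\mathcal{S},\ w\prec c,\\ \theta\prec^D w}}\frac{\partial}{\partial\theta}\log p(w\mid\textsc{deps}_w) + \sum_{\substack{c\in\mathcal{C},\\ \theta\prec^D c}}\frac{\partial}{\partial\theta}c(\textsc{deps}_c)\right] \tag{6}$$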

The estimator expressions above have two terms. The first term is due to the influence of $\theta$ on probability distributions. The second term is due to the influence of $\theta$ on the cost variables through a chain of differentiable functions. The first term in Equation 5 involves a sum of gradients times “downstream” costs, whereas the first term in Equation 6 has a sum of costs times “upstream” gradients.
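To see why both terms are needed, consider a minimal graph chosen only for illustration (it is not one of the paper's figures): input $\theta$, stochastic node $x\sim\mathcal{N}(\theta,1)$, and cost $c=\theta x$, so that $\theta$ deterministically influences both $x$ and $c$. Here $\mathbb{E}[c]=\theta^2$, and Equation 5 gives the per-sample estimate $(\hat{x}-\theta)\,\hat{Q}_x + \hat{x}$ with $\hat{Q}_x=\theta\hat{x}$. The sketch checks that the full estimator recovers $2\theta$, while either term alone recovers only $\theta$.

```python
import random
import statistics


def theorem1_terms(theta, n_samples, rng):
    """Per-sample score-function and pathwise terms of Equation 5 for the
    illustrative graph theta -> x ~ N(theta, 1), cost c(theta, x) = theta * x."""
    sf_terms, pd_terms = [], []
    for _ in range(n_samples):
        x_hat = rng.gauss(theta, 1.0)
        q_hat = theta * x_hat                     # downstream cost, held constant
        sf_terms.append((x_hat - theta) * q_hat)  # d/dtheta log p(x_hat | theta) * Q-hat
        pd_terms.append(x_hat)                    # d/dtheta c(theta, x_hat) with x_hat fixed
    return sf_terms, pd_terms


rng = random.Random(0)
theta = 1.5
sf_terms, pd_terms = theorem1_terms(theta, 100_000, rng)
full = statistics.fmean(s + p for s, p in zip(sf_terms, pd_terms))
sf_only = statistics.fmean(sf_terms)
pd_only = statistics.fmean(pd_terms)
print(f"full {full:.3f}  SF only {sf_only:.3f}  PD only {pd_only:.3f}  exact {2 * theta:.3f}")
```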

3.2 Surrogate Loss Functions

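The next corollary lets us write down a “surrogate” objective $L$, which is a function of the inputs that we can differentiate to obtain an unbiased gradient estimator.

Corollary 1. Let $L(\Theta,\mathcal{S}) := \sum_{w\in\mathcal{S}}\log p(w\mid\textsc{deps}_w)\,\hat{Q}_w + \sum_{c\in\mathcal{C}}c(\textsc{deps}_c)$. Then differentiating $L$ gives an unbiased gradient estimate:

$$\frac{\partial}{\partial\theta}\mathbb{E}\left[\sum_{c\in\mathcal{C}}c\right] = \mathbb{E}\left[\frac{\partial}{\partial\theta}L(\Theta,\mathcal{S})\right].$$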

One practical consequence of this result is that we can apply a standard automatic differentiation procedure to LL to obtain an unbiased gradient estimator. In other words, we convert the stochastic computation graph into a deterministic computation graph, to which we can apply the backpropagation algorithm.

There are several alternative ways to define the surrogate objective function that give the same gradient as $L$ from Corollary 1. We could also write $L(\Theta,\mathcal{S}) := \sum_{w}\frac{p(\hat{w}\mid\textsc{deps}_w)}{\hat{P}_w}\hat{Q}_w + \sum_{c\in\mathcal{C}}c(\textsc{deps}_c)$, where $\hat{P}_w$ is the probability $p(w\mid\textsc{deps}_w)$ obtained during sampling, which is viewed as a constant.
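Corollary 1 can be illustrated without an automatic differentiation library by using a tiny forward-mode dual-number class, written only for this sketch as a stand-in for real autodiff software. For an illustrative graph with input $\theta$, stochastic node $x\sim\mathcal{N}(\theta,1)$, and cost $c=\theta x$, the surrogate is $L(\theta)=\log p(\hat{x}\mid\theta)\,\hat{Q}_x + \theta\hat{x}$. Because $\hat{x}$ and $\hat{Q}_x$ are plain floats, differentiation treats them as constants, and the derivative of $L$ matches the hand-derived Equation 5 estimate $(\hat{x}-\theta)\hat{Q}_x+\hat{x}$ sample by sample. The class and function names are hypothetical.

```python
import math
import random
from dataclasses import dataclass


@dataclass
class Dual:
    """Minimal forward-mode dual number a + b*eps with eps^2 = 0 (illustrative)."""
    value: float
    deriv: float = 0.0

    @staticmethod
    def lift(other):
        return other if isinstance(other, Dual) else Dual(float(other))

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.value - other.value, self.deriv - other.deriv)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def gaussian_log_prob(x, mean):
    """log N(x; mean, 1); `mean` may be a Dual so the result carries d/dmean."""
    diff = x - mean
    return -0.5 * (diff * diff) - 0.5 * math.log(2.0 * math.pi)


def surrogate(theta, x_hat, q_hat):
    """Surrogate L = log p(x_hat | theta) * Q-hat + c(theta, x_hat), with c = theta * x."""
    return gaussian_log_prob(x_hat, theta) * q_hat + theta * x_hat


rng = random.Random(0)
theta_value = 1.5
max_abs_diff = 0.0
grads = []
for _ in range(50_000):
    x_hat = rng.gauss(theta_value, 1.0)
    q_hat = theta_value * x_hat                          # sampled cost, treated as constant
    L = surrogate(Dual(theta_value, 1.0), x_hat, q_hat)  # seed dtheta/dtheta = 1
    by_hand = (x_hat - theta_value) * q_hat + x_hat      # Equation 5, per sample
    max_abs_diff = max(max_abs_diff, abs(L.deriv - by_hand))
    grads.append(L.deriv)

print(f"max |dL/dtheta - Eq. 5| = {max_abs_diff:.2e}")
print(f"mean dL/dtheta = {sum(grads) / len(grads):.3f} (exact {2 * theta_value:.3f})")
```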

The surrogate objective from Corollary 1 is actually an upper bound on the true objective (plus a constant) in the case that (1) all costs $c\in\mathcal{C}$ are negative, and (2) the costs are not deterministically influenced by the parameters $\Theta$. This construction allows majorization-minimization algorithms (similar to EM) to be applied to general stochastic computation graphs. See Appendix B for details.

3.3 Higher-Order Derivatives

The gradient estimator for a stochastic computation graph is itself a stochastic computation graph. Hence, it is possible to compute the gradient yet again (for each component of the gradient vector) and get an estimator of the Hessian. For most problems of interest, it is not efficient to compute this dense Hessian. On the other hand, one can also differentiate the gradient-vector product to get a Hessian-vector product; this computation is usually not much more expensive than the gradient computation itself. The Hessian-vector product can be used to implement a quasi-Newton algorithm via the conjugate gradient algorithm. A variant of this technique, called Hessian-free optimization, has been used to train large neural networks.
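The quasi-Newton step described above only needs Hessian-vector products, which the conjugate gradient (CG) method can consume without ever forming the Hessian. The plain-Python CG solver below is a sketch. Its `hvp` function multiplies by an explicit 2x2 matrix, standing in for the Hessian-vector product that would be obtained by differentiating the gradient-vector product. The matrix `H`, vector `b`, and all function names are illustrative assumptions. For a quadratic objective, one Newton step solved this way lands on the minimizer $H^{-1}b$.

```python
import math


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def conjugate_gradient(hvp, rhs, max_iters=50, tol=1e-10):
    """Approximately solve H d = rhs for symmetric positive-definite H,
    accessing H only through hvp(v) = H v (no dense Hessian is formed)."""
    d = [0.0] * len(rhs)
    r = list(rhs)  # residual rhs - H d, with d = 0 initially
    p = list(r)    # search direction
    rs_old = dot(r, r)
    for _ in range(max_iters):
        if math.sqrt(rs_old) < tol:
            break
        hp = hvp(p)
        curvature = dot(p, hp)
        if curvature <= 0.0:  # H not positive definite along p; stop early
            break
        alpha = rs_old / curvature
        d = [di + alpha * pi for di, pi in zip(d, p)]
        r = [ri - alpha * hpi for ri, hpi in zip(r, hp)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return d


# Stand-in problem: quadratic 0.5 * x^T H x - b^T x, whose gradient is H x - b.
H = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]


def hvp(v):
    # In a stochastic computation graph this product would instead come from
    # differentiating (gradient estimate) . v; here it is an explicit matrix product.
    return [dot(row, v) for row in H]


x0 = [0.0, 0.0]
g0 = [dot(row, x0) - bi for row, bi in zip(H, b)]
step = conjugate_gradient(hvp, [-gi for gi in g0])  # Newton step: H step = -g0
x1 = [xi + si for xi, si in zip(x0, step)]
print(x1)  # approximately [1/11, 7/11]
```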

4 Variance Reduction

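We can make a general statement for the case of stochastic computation graphs: we can add a baseline to every stochastic node, which depends on all of the nodes it does not influence. Let $\textsc{NonInfluenced}(v) := \{w\mid v\nprec w\}$.

Theorem 2. Suppose that $\theta\in\Theta$ satisfies Condition 1. Then

$$\frac{\partial}{\partial\theta}\mathbb{E}\left[\sum_{c\in\mathcal{C}}c\right] = \mathbb{E}\left[\sum_{\substack{v\in\mathcal{S},\\ \theta\prec^D v}}\left(\frac{\partial}{\partial\theta}\log p(v\mid\textsc{deps}_v)\right)\left(\hat{Q}_v - b(\textsc{NonInfluenced}(v))\right) + \sum_{\substack{c\in\mathcal{C},\\ \theta\prec^D c}}\frac{\partial}{\partial\theta}c(\textsc{deps}_c)\right].$$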
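A minimal numerical illustration of Theorem 2 for a single stochastic node, assuming $x\sim\mathcal{N}(\theta,1)$ and the cost $f(x)=x^2+10$ (both chosen only for this sketch). With no other nodes in the graph, any constant qualifies as a baseline; the sketch uses $b=\mathbb{E}[f(x)]=\theta^2+11$, a common but not necessarily variance-optimal choice. Both estimators average to $2\theta$, but subtracting the baseline removes most of the variance contributed by the large constant offset.

```python
import random
import statistics


def sf_samples(theta, baseline, n_samples, rng):
    """Per-sample SF estimates (f(x) - baseline) * d/dtheta log p(x; theta)
    for x ~ N(theta, 1) and the illustrative cost f(x) = x^2 + 10."""
    out = []
    for _ in range(n_samples):
        x = rng.gauss(theta, 1.0)
        f = x * x + 10.0
        out.append((f - baseline) * (x - theta))  # score of N(theta, 1) is x - theta
    return out


rng = random.Random(0)
theta = 1.0
plain = sf_samples(theta, 0.0, 100_000, rng)
with_baseline = sf_samples(theta, theta**2 + 11.0, 100_000, rng)  # b = E[f(x)]

print(f"no baseline:   mean {statistics.fmean(plain):.3f}  var {statistics.variance(plain):.1f}")
print(f"with baseline: mean {statistics.fmean(with_baseline):.3f}  var {statistics.variance(with_baseline):.1f}")
```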

5 Algorithms

As shown in Section 3, the gradient estimator can be obtained by differentiating a surrogate objective function $L$. Hence, this derivative can be computed by performing the backpropagation algorithm on $L$. That is likely to be the most practical and efficient method, and can be facilitated by automatic differentiation software.

6 Related Work

As discussed in Section 2, the score function and pathwise derivative estimators have been used in a variety of different fields, under different names. A review of gradient estimation, drawn mostly from the simulation optimization literature, is available in prior work. Glasserman’s textbook provides an extensive treatment of various gradient estimators and Monte Carlo estimators in general. Griewank and Walther’s textbook is a comprehensive reference on computation graphs and automatic differentiation (of deterministic programs). The notation and nomenclature we use is inspired by Bayes nets and influence diagrams. (In fact, a stochastic computation graph is a type of Bayes network, where the deterministic nodes correspond to degenerate probability distributions.)

The topic of gradient estimation has drawn significant recent interest in machine learning. Gradients for networks with stochastic units were investigated by Bengio et al., though they are concerned with differentiating through individual units and layers, not with how to deal with arbitrarily structured models and loss functions. Kingma and Welling consider a similar framework, although only with continuous latent variables, and point out that reparameterization can be used to convert hierarchical Bayesian models into neural networks, which can then be trained by backpropagation.

The score function method is used to perform variational inference in general models (in the context of probabilistic programming) by Wingate and Weber, and similarly by Ranganath et al.; both papers mostly focus on mean-field approximations without amortized inference. It is used to train generative models using neural networks with discrete stochastic units by Mnih and Gregor and by Gregor et al.; both amortize inference by using an inference network.

Generative models with continuous-valued latent variables are trained (again using an inference network) with the reparameterization method by Rezende, Mohamed, and Wierstra and by Kingma and Welling. Rezende et al. also provide a detailed discussion of reparameterization, including a comparison of the variance of the SF and PD estimators.

Bengio, Leonard, and Courville have recently written a paper about gradient estimation in neural networks with stochastic units or non-differentiable activation functions, including Monte Carlo estimators and heuristic approximations. The notion that policy gradients can be computed in multiple ways was pointed out in early work on policy gradients by Williams. However, all of this prior work deals with specific structures of the stochastic computation graph and does not address the general case.

7 Conclusion

We have developed a framework for describing a computation with stochastic and deterministic operations, called a stochastic computation graph. Given a stochastic computation graph, we can automatically obtain a gradient estimator, given that the graph satisfies the appropriate conditions on differentiability of the functions at its nodes. The gradient can be computed efficiently in a backwards traversal through the graph: one approach is to apply the standard backpropagation algorithm to one of the surrogate loss functions from Section 3; another approach (which is roughly equivalent) is to apply a modified backpropagation procedure shown in Algorithm 1. The results we have presented are sufficiently general to automatically reproduce a variety of gradient estimators that have been derived in prior work in reinforcement learning and probabilistic modeling, as we show in Appendix C. We hope that this work will facilitate further development of interesting and expressive models.

Acknowledgements

We would like to thank Shakir Mohamed, Dave Silver, Yuval Tassa, Andriy Mnih, and Paul Horsfall for insightful comments.


Appendix A Proofs

We will consider the case that all of the random variables are continuous-valued, thus the expectations can be written as integrals. For discrete random variables, the integrals should be changed to sums.

Equation 9 requires that the integrand is differentiable, which is satisfied if all of the PDFs and $c(\textsc{deps}_c)$ are differentiable. Equation 6 follows by summing over all costs $c\in\mathcal{C}$. Equation 5 follows from rearrangement of the terms in this equation.

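It suffices to show that for a particular node $v\in\mathcal{S}$, the following expectation (taken over all variables) vanishes:

$$\mathbb{E}\left[\left(\frac{\partial}{\partial\theta}\log p(v\mid\textsc{deps}_v)\right)b(\textsc{NonInfluenced}(v))\right].$$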

Analogously to $\textsc{NonInfluenced}(v)$, define $\textsc{Influenced}(v) := \{w\mid w\succ v\}$. Note that the nodes can be ordered so that $\textsc{NonInfluenced}(v)$ all come before $v$ in the ordering. Thus, we can write

Appendix B Surrogate as an Upper Bound, and MM Algorithms

$L$ has additional significance besides allowing us to estimate the gradient of the expected sum of costs. Under certain conditions, $L$ is an upper bound on the true objective (plus a constant).

where the second line used the inequality $x\geq\log x+1$, and the sign is reversed since $\hat{c}$ is negative. Summing over $c\in\mathcal{C}$ and rearranging, we get

Equation 20 allows majorization-minimization algorithms (like the EM algorithm) to be used to optimize with respect to $\theta$. In fact, similar equations have been derived by interpreting rewards (negative costs) as probabilities, and then taking the variational lower bound on log-probability.

Appendix C Examples

This section considers two settings where the formalism of stochastic computation graphs can be applied. First, we consider the generalized EM algorithm for maximum likelihood estimation in probabilistic models with latent variables. Second, we consider reinforcement learning in Markov Decision Processes. In both cases, the objective function is given by an expectation; writing it out as a composition of stochastic and deterministic steps yields a stochastic computation graph.

C.1 Generalized EM Algorithm and Variational Inference

The generalized EM algorithm maximizes likelihood in a probabilistic model with latent variables. We start with a parameterized probability density $p(x,z;\theta)$, where $x$ is observed, $z$ is a latent variable, and $\theta$ is a parameter of the distribution. The generalized EM algorithm maximizes the variational lower bound, which is defined by an expectation over $z$ for each sample $x$:

$$\mathcal{L}(\theta,q) = \mathbb{E}_{z\sim q}\left[\log\frac{p(x,z;\theta)}{q(z)}\right].$$

As parameters will appear both in the probability density and inside the expectation, stochastic computation graphs provide a convenient route for deriving the gradient estimators.

Mnih and Gregor propose a generalized EM algorithm for multi-layered latent variable models that employs an inference network, an explicit parameterization of the posterior $q_\phi(z\mid x)\approx p(z\mid x)$, to allow for fast approximate inference. The generative model and inference network take the form

The inference model $q_\phi$ is used for sampling, i.e., we sample $h_1\sim q_{\phi_1}(\cdot\mid x)$, $h_2\sim q_{\phi_2}(\cdot\mid h_1)$, $h_3\sim q_{\phi_3}(\cdot\mid h_2)$. The stochastic computation graph is shown above.

Given a sample $h\sim q_\phi$, an unbiased estimate of the gradient is given by Theorem 2 as

where $\hat{Q}_1=r_1+r_2+r_3$, $\hat{Q}_2=r_2+r_3$, and $\hat{Q}_3=r_3$, and $b_1,b_2,b_3$ are baseline functions.

Variational Autoencoder, Deep Latent Gaussian Models and Reparameterization.

Here we note that in some cases, the stochastic computation graph can be transformed to give the same probability distribution for the observed variables, but one obtains a different gradient estimator. Kingma and Welling and Rezende et al. consider a model that is similar to the one proposed by Mnih et al. but with continuous latent variables, and they reparameterize their inference network to enable the use of the PD estimator. The original objective, the variational lower bound, is

The second term, the entropy of $q_\phi$, can be computed analytically for the parametric forms of $q$ considered in the paper (Gaussians). For $q_\phi$ being conditionally Gaussian, i.e., $q_\phi(h\mid x)=\mathcal{N}(h\mid\mu_\phi(x),\sigma_\phi(x))$, reparameterizing leads to $h=h_\phi(\epsilon;x)=\mu_\phi(x)+\epsilon\sigma_\phi(x)$, giving

The stochastic computation graph before and after reparameterization is shown above. Given $\epsilon\sim\rho$, an estimate of the gradient is obtained as
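A minimal sketch of the reparameterized (PD) estimator for a single Gaussian. It assumes $\sigma$ is a standard deviation, $\rho=\mathcal{N}(0,1)$, and a stand-in integrand $g(h)=h^2$ in place of the model's log-density terms, none of which is specified by the text above. Writing $h=\mu+\epsilon\sigma$, the chain rule gives per-sample gradients $g'(h)$ for $\mu$ and $g'(h)\,\epsilon$ for $\sigma$; since $\mathbb{E}[h^2]=\mu^2+\sigma^2$, the exact gradients are $2\mu$ and $2\sigma$.

```python
import random
import statistics


def reparam_gradients(mu, sigma, n_samples, rng):
    """PD estimates of d/dmu and d/dsigma of E_{h ~ N(mu, sigma^2)}[g(h)]
    for the stand-in g(h) = h^2, using h = mu + eps * sigma, eps ~ N(0, 1)."""
    d_mu, d_sigma = [], []
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)      # noise independent of (mu, sigma)
        h = mu + eps * sigma           # reparameterized sample
        g_prime = 2.0 * h              # derivative of the stand-in g(h) = h^2
        d_mu.append(g_prime * 1.0)     # dh/dmu = 1
        d_sigma.append(g_prime * eps)  # dh/dsigma = eps
    return statistics.fmean(d_mu), statistics.fmean(d_sigma)


rng = random.Random(0)
mu, sigma = 0.5, 2.0
grad_mu, grad_sigma = reparam_gradients(mu, sigma, 100_000, rng)
print(f"d/dmu {grad_mu:.3f} (exact {2 * mu}),  d/dsigma {grad_sigma:.3f} (exact {2 * sigma})")
```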

C.2 Policy Gradients in Reinforcement Learning.

In reinforcement learning, an agent interacts with an environment according to its policy $\pi$, and the goal is to maximize the expected sum of rewards, called the return. Policy gradient methods seek to directly estimate the gradient of expected return with respect to the policy parameters. In reinforcement learning, we typically assume that the environment dynamics are not available analytically and can only be sampled. Below we distinguish two important cases: the Markov decision process (MDP) and the partially observable Markov decision process (POMDP).

MDPs: In the MDP case, the expectation is taken with respect to the distribution over state ($s$) and action ($a$) sequences:

$$\mathbb{E}_{\tau\sim p_\theta}\left[\sum_{t=1}^{T}r(s_t,a_t)\right],$$

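where $\tau=(s_1,a_1,s_2,a_2,\dots)$ are trajectories and the distribution over trajectories is defined in terms of the environment dynamics $p_E(s_{t+1}\mid s_t,a_t)$ and the policy $\pi_\theta$: $p_\theta(\tau)=p_E(s_1)\prod_t\pi_\theta(a_t\mid s_t)p_E(s_{t+1}\mid s_t,a_t)$. Here $r$ denotes rewards (negative costs in the terminology of the rest of the paper). The classic REINFORCE estimate of the gradient is given by

$$\frac{\partial}{\partial\theta}\mathbb{E}_{\tau\sim p_\theta}\left[\sum_{t=1}^{T}r(s_t,a_t)\right] = \mathbb{E}_{\tau\sim p_\theta}\left[\sum_{t=1}^{T}\frac{\partial}{\partial\theta}\log\pi_\theta(a_t\mid s_t)\left(\sum_{t'=t}^{T}r(s_{t'},a_{t'}) - b_t(s_t)\right)\right],$$

where $b_t(s_t)$ is a baseline that does not depend on $a_t$ or later variables.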
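A minimal REINFORCE sketch under strong simplifying assumptions: a toy MDP with a single state (effectively a repeated bandit), $T=3$ steps, a Bernoulli policy $\pi_\theta(a_t=1)=\operatorname{sigmoid}(\theta)$, and reward $r_t=a_t$. The expected return is then $T\operatorname{sigmoid}(\theta)$, with exact gradient $T\operatorname{sigmoid}(\theta)(1-\operatorname{sigmoid}(\theta))$, and the score is $a_t-\operatorname{sigmoid}(\theta)$. The baseline is the expected reward-to-go, which does not depend on $a_t$ or later actions, as the equation above requires. All names and constants are illustrative.

```python
import math
import random


def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))


def reinforce_estimate(theta, horizon, n_episodes, rng, use_baseline=True):
    """REINFORCE with reward-to-go for an illustrative single-state MDP:
    a_t ~ Bernoulli(sigmoid(theta)), reward r_t = a_t."""
    p = sigmoid(theta)
    total = 0.0
    for _ in range(n_episodes):
        actions = [1.0 if rng.random() < p else 0.0 for _ in range(horizon)]
        rewards = actions                    # r_t = a_t in this toy problem
        estimate = 0.0
        for t in range(horizon):
            score = actions[t] - p           # d/dtheta log pi_theta(a_t)
            reward_to_go = sum(rewards[t:])  # sum over t' >= t of r_t'
            baseline = (horizon - t) * p if use_baseline else 0.0  # expected reward-to-go
            estimate += score * (reward_to_go - baseline)
        total += estimate
    return total / n_episodes


rng = random.Random(0)
theta, horizon = 0.3, 3
est = reinforce_estimate(theta, horizon, 50_000, rng)
exact = horizon * sigmoid(theta) * (1.0 - sigmoid(theta))
print(f"REINFORCE {est:.4f}  exact {exact:.4f}")
```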

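POMDPs differ from MDPs in that the state $s_t$ of the environment is not observed directly but, as in latent-variable time series models, only through stochastic observations $o_t$, which depend on the latent states $s_t$ via $p_E(o_t\mid s_t)$. The policy therefore has to be a function of the history of past observations, $\pi_\theta(a_t\mid o_1\dots o_t)$. Applying Theorem 2, we obtain a gradient estimator:

$$\frac{\partial}{\partial\theta}\mathbb{E}_{\tau\sim p_\theta}\left[\sum_{t=1}^{T}r(s_t,a_t)\right] = \mathbb{E}_{\tau\sim p_\theta}\left[\sum_{t=1}^{T}\frac{\partial}{\partial\theta}\log\pi_\theta(a_t\mid o_1\dots o_t)\left(\sum_{t'=t}^{T}r(s_{t'},a_{t'}) - b_t(o_1\dots o_t)\right)\right].$$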

Here, the baseline $b_t$ and the policy $\pi_\theta$ can depend on the observation history through time $t$, and these functions can be parameterized as recurrent neural networks.