Composing graphical models with neural networks for structured representations and fast inference

Matthew J. Johnson, David Duvenaud, Alexander B. Wiltschko, Sandeep R. Datta, Ryan P. Adams

Introduction

Modeling often has two goals: first, to learn a flexible representation of complex high-dimensional data, such as images or speech recordings, and second, to find structure that is interpretable and generalizes to new tasks. Probabilistic graphical models provide many tools to build structured representations, but often make rigid assumptions and may require significant feature engineering. Alternatively, deep learning methods allow flexible data representations to be learned automatically, but may not directly encode interpretable or tractable probabilistic structure. Here we develop a general modeling and inference framework that combines these complementary strengths.

Consider learning a generative model for video of a mouse. Learning interpretable representations for such data, and comparing them as the animal’s genes are edited or its brain chemistry altered, gives useful behavioral phenotyping tools for neuroscience and for high-throughput drug discovery. Even though each image is encoded by hundreds of pixels, the data lie near a low-dimensional nonlinear manifold. A useful generative model must not only learn this manifold but also provide an interpretable representation of the mouse’s behavioral dynamics. A natural representation from ethology is that the mouse’s behavior is divided into brief, reused actions, such as darts, rears, and grooming bouts. Therefore an appropriate model might switch between discrete states, with each state representing the dynamics of a particular action. These two learning tasks, identifying an image manifold and learning a structured dynamics model, are complementary: we want to learn the image manifold in terms of coordinates in which the structured dynamics fit well. A similar challenge arises in speech, where high-dimensional spectrographic data lie near a low-dimensional manifold because they are generated by a physical system with relatively few degrees of freedom but also include the discrete latent dynamical structure of phonemes, words, and grammar.

To address these challenges, we propose a new framework to design and learn models that couple nonlinear likelihoods with structured latent variable representations. Our approach uses graphical models for representing structured probability distributions while enabling fast exact inference subroutines, and uses ideas from variational autoencoders for learning not only the nonlinear feature manifold but also bottom-up recognition networks to improve inference. Thus our method enables the combination of flexible deep learning feature models with structured Bayesian (and even nonparametric) priors. Our approach yields a single variational inference objective in which all components of the model are learned simultaneously. Furthermore, we develop a scalable fitting algorithm that combines several advances in efficient inference, including stochastic variational inference, graphical model message passing, and backpropagation with the reparameterization trick. Thus our algorithm can leverage conjugate exponential family structure where it exists to efficiently compute natural gradients with respect to some variational parameters, enabling effective second-order optimization, while using backpropagation to compute gradients with respect to all other parameters. We refer to our general approach as the structured variational autoencoder (SVAE).

Latent graphical models with neural net observations

In this paper we propose a broad family of models. Here we develop three specific examples.

One particularly natural structure used frequently in graphical models is the discrete mixture model. By fitting a discrete mixture model to data, we can discover natural clusters or units. These discrete structures are difficult to represent directly in neural network models.

Consider the problem of modeling the data $y=\{y_{n}\}_{n=1}^{N}$ shown in Fig. 1(a). A standard approach to finding the clusters in data is to fit a Gaussian mixture model (GMM) with a conjugate prior:

However, the fit GMM does not represent the natural clustering of the data (Fig. 1(b)). Its inflexible Gaussian observation model limits its ability to parsimoniously fit the data and their natural semantics.

Instead of using a GMM, a more flexible alternative would be a neural network density model:

where $\mu(x_{n};\gamma)$ and $\Sigma(x_{n};\gamma)$ depend on $x_{n}$ through some smooth parametric function, such as a multilayer perceptron (MLP), and where $p(\gamma)$ is a Gaussian prior. This model fits the data density well (Fig. 1(c)) but does not explicitly represent discrete mixture components, which might provide insights into the data or natural units for generalization. See Fig. 2(a) for a graphical model.
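For concreteness, a density network of this kind can be written as the generative process below. This is a sketch consistent with the definitions above rather than a verbatim statement of Eq. (1); in particular, the standard normal prior on $x_{n}$ is an assumption, suggested by the later description of the latent variables as independent Gaussians.

```latex
\begin{align}
\gamma &\sim p(\gamma), \\
x_{n} &\overset{\text{iid}}{\sim} \mathcal{N}(0, I), \\
y_{n} \,|\, x_{n}, \gamma &\sim \mathcal{N}\!\left(\mu(x_{n};\gamma),\, \Sigma(x_{n};\gamma)\right),
  \qquad n = 1, \ldots, N.
\end{align}
```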

By composing a latent GMM with nonlinear observations, we can combine the modeling strengths of both, learning discrete clusters along with non-Gaussian cluster shapes:

This combination of flexibility and structure is shown in Fig. 1(d). See Fig. 2(b) for a graphical model.
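The composition described above can be illustrated by ancestral sampling: draw a discrete cluster label, draw a Gaussian latent point for that cluster, then push it through a nonlinear map to obtain an observation. The sketch below is only illustrative. The helper names (`decoder_mean`, `sample_latent_gmm`), the single tanh hidden layer, the diagonal cluster covariances, and the fixed observation noise are all assumptions made to keep the example short.

```python
import math
import random

def decoder_mean(x, gamma):
    """Toy nonlinear map standing in for mu(x; gamma): one tanh hidden layer."""
    w1, b1, w2, b2 = gamma
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(w1, b1)]
    return [sum(w * h for w, h in zip(row, hidden)) + b
            for row, b in zip(w2, b2)]

def sample_latent_gmm(n, pi, means, stds, gamma, obs_std, rng):
    """Ancestral sampling: z ~ pi, x ~ N(means[z], stds[z]^2), y ~ N(mu(x), obs_std^2)."""
    samples = []
    for _ in range(n):
        z = rng.choices(range(len(pi)), weights=pi)[0]
        x = [rng.gauss(m, s) for m, s in zip(means[z], stds[z])]
        y = [rng.gauss(m, obs_std) for m in decoder_mean(x, gamma)]
        samples.append((z, x, y))
    return samples

rng = random.Random(0)
gamma = (
    [[1.0, -0.5], [0.3, 0.8], [-0.7, 0.2]],  # w1: 3 hidden units, 2D latent
    [0.0, 0.1, -0.1],                         # b1
    [[0.5, -1.0, 0.7], [1.2, 0.4, -0.3]],     # w2: 2D observation
    [0.0, 0.0],                               # b2
)
data = sample_latent_gmm(
    n=5, pi=[0.5, 0.5],
    means=[[-2.0, 0.0], [2.0, 0.0]], stds=[[0.5, 0.5], [0.5, 0.5]],
    gamma=gamma, obs_std=0.1, rng=rng)
```

Each cluster is Gaussian in the latent space, but after the nonlinear map its image in observation space is generally not Gaussian, which is the source of the extra flexibility.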

2.2 Latent linear dynamical systems for modeling video

Now we consider a harder problem: generatively modeling video. Since a video is a sequence of image frames, a natural place to start is with a model for images. Prior work shows that the density network of Eq. (1) can accurately represent a dataset of high-dimensional images $\{y_{n}\}_{n=1}^{N}$ in terms of the low-dimensional latent variables $\{x_{n}\}_{n=1}^{N}$, each with independent Gaussian distributions.

To extend this image model into a model for videos, we can introduce dependence through time between the latent Gaussian samples $\{x_{n}\}_{n=1}^{N}$. For instance, we can make each latent variable $x_{n}$ depend on the previous latent variable $x_{n-1}$ through a Gaussian linear dynamical system, writing

where the matrices $A$ and $B$ have a conjugate prior. This model has low-dimensional latent states and dynamics as well as a rich nonlinear generative model of images. In addition, the timescales of the dynamics are represented directly in the eigenvalue spectrum of $A$, providing both interpretability and a natural way to encode prior information. See Fig. 2(c) for a graphical model.
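To make the link between the eigenvalues of $A$ and timescales concrete, the sketch below simulates a two-dimensional latent system and reads a decay time and an oscillation period off one eigenvalue. It assumes the dynamics take the form $x_{n}=Ax_{n-1}+Bu_{n}$ with standard normal $u_{n}$, which is one common reading of "Gaussian linear dynamical system". The particular matrices and the helper names `lds_step` and `eigenvalues_2x2` are illustrative choices, not taken from the paper.

```python
import cmath
import math
import random

def lds_step(A, B, x, rng):
    """One step of x_n = A x_{n-1} + B u_n with u_n ~ N(0, I)."""
    u = [rng.gauss(0.0, 1.0) for _ in range(len(B[0]))]
    return [sum(a * xi for a, xi in zip(A[i], x)) +
            sum(b * ui for b, ui in zip(B[i], u))
            for i in range(len(A))]

def eigenvalues_2x2(A):
    """Eigenvalues of a 2x2 matrix from its trace and determinant."""
    tr = A[0][0] + A[1][1]
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    disc = cmath.sqrt(tr * tr - 4.0 * det)
    return (tr + disc) / 2.0, (tr - disc) / 2.0

# A damped rotation: eigenvalues are rho * exp(+/- i * omega).
rho, omega = 0.95, 0.3
A = [[rho * math.cos(omega), -rho * math.sin(omega)],
     [rho * math.sin(omega),  rho * math.cos(omega)]]
B = [[0.1, 0.0], [0.0, 0.1]]

rng = random.Random(0)
x = [1.0, 0.0]
trajectory = [x]
for _ in range(100):
    x = lds_step(A, B, x, rng)
    trajectory.append(x)

lam = eigenvalues_2x2(A)[0]
decay_steps = -1.0 / math.log(abs(lam))                 # e-folding time of the mode
period_steps = 2.0 * math.pi / abs(cmath.phase(lam))    # oscillation period in steps
```

The modulus of an eigenvalue sets how quickly the corresponding mode decays and its phase sets how fast it rotates, so a prior on $A$ that favors particular moduli or phases encodes a preference for particular timescales.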

2.3 Latent switching linear dynamical systems for parsing behavior from video

As a final example that combines both time series structure and discrete latent units, consider again the behavioral phenotyping problem described in Section 1. Drawing on graphical modeling tools, we can construct a latent switching linear dynamical system (SLDS) to represent the data in terms of continuous latent states that evolve according to a discrete library of linear dynamics, and drawing on deep learning methods we can generate video frames with a neural network image model.

Structured mean field inference and recognition networks

Why aren’t such rich hybrid models used more frequently? The main difficulty with combining rich latent variable structure and flexible likelihoods is inference. The most efficient inference algorithms used in graphical models, like structured mean field and message passing, depend on conjugate exponential family likelihoods to preserve tractable structure. When the observations are more general, like neural network models, inference must either fall back to general algorithms that do not exploit the model structure or else rely on bespoke algorithms developed for one model at a time.

In this section, we review inference ideas from conjugate exponential family probabilistic graphical models and variational autoencoders, which we combine and generalize in the next section.

Graphical models and exponential families provide many algorithmic tools for efficient inference. Given an exponential family latent variable model, when the observation model is a conjugate exponential family, the conditional distributions stay in the same exponential families as in the prior and hence allow for the same efficient inference algorithms.

For example, consider learning a Gaussian linear dynamical system model with linear Gaussian observations. The generative model for latent states $x=\{x_{n}\}_{n=1}^{N}$ and observations $y=\{y_{n}\}_{n=1}^{N}$ is

given parameters $\theta=(A,B,C,D)$ with a conjugate prior $p(\theta)$. To approximate the posterior $p(\theta,x\,|\,y)$, consider the mean field family $q(\theta)q(x)$ and the variational inference objective

where we can optimize the variational family $q(\theta)q(x)$ to approximate the posterior $p(\theta,x\,|\,y)$ by maximizing Eq. (7). Because the observation model $p(y\,|\,x,\theta)$ is conjugate to the latent variable model $p(x\,|\,\theta)$, for any fixed $q(\theta)$ the optimal factor $q^{*}(x)\triangleq\operatorname*{arg\,max}_{q(x)}\mathcal{L}[\,q(\theta)q(x)\,]$ is itself a Gaussian linear dynamical system with parameters that are simple functions of the expected statistics of $q(\theta)$ and the data $y$. As a result, for fixed $q(\theta)$ we can easily compute $q^{*}(x)$ and use message passing algorithms to perform exact inference in it. However, when the observation model is not conjugate to the latent variable model, these algorithmically exploitable structures break down.
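As a minimal illustration of exact message passing in a Gaussian linear dynamical system, the sketch below runs the forward (filtering) pass for a scalar model. It is a simplification under stated assumptions: the parameters are fixed numbers rather than expected statistics under $q(\theta)$, only the forward messages are computed (no backward smoothing pass), and the function name `kalman_filter` and its argument names are hypothetical.

```python
def kalman_filter(ys, a, q, c, r, m0, v0):
    """Forward messages for x_1 ~ N(m0, v0), x_n = a*x_{n-1} + N(0, q), y_n = c*x_n + N(0, r)."""
    means, variances = [], []
    m_pred, v_pred = m0, v0
    for y in ys:
        # Condition on y_n: combine the predicted message with the Gaussian likelihood.
        gain = v_pred * c / (c * c * v_pred + r)
        m = m_pred + gain * (y - c * m_pred)
        v = (1.0 - gain * c) * v_pred
        means.append(m)
        variances.append(v)
        # Predict x_{n+1}: push the filtered message through the linear dynamics.
        m_pred, v_pred = a * m, a * a * v + q
    return means, variances

filtered_means, filtered_vars = kalman_filter(
    [1.0, 0.8, 1.1], a=0.9, q=0.1, c=1.0, r=0.5, m0=0.0, v0=1.0)
```

Every step stays within the Gaussian family because the likelihood is linear Gaussian. With a neural network likelihood the conditioning step has no closed form, which is the breakdown the text refers to.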

3.2 Recognition networks in variational autoencoders

The variational autoencoder (VAE) handles general non-conjugate observation models by introducing recognition networks. For example, when a Gaussian latent variable model $p(x)$ is paired with a general nonlinear observation model $p(y\,|\,x,\gamma)$, the posterior $p(x\,|\,y,\gamma)$ is non-Gaussian, and it is difficult to compute an optimal Gaussian approximation. The VAE instead learns to directly output a suboptimal Gaussian factor $q(x\,|\,y)$ by fitting a parametric map from data $y$ to a mean and covariance, $\mu(y;\phi)$ and $\Sigma(y;\phi)$, such as an MLP with parameters $\phi$. By optimizing over $\phi$, the VAE effectively learns how to condition on non-conjugate observations $y$ and produce a good approximating factor.
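The recognition idea can be sketched in one dimension: a parametric map turns $y$ into the mean and variance of $q(x\,|\,y)$, a sample is drawn by shifting and scaling standard normal noise (the reparameterization trick), and the sample is used in a Monte Carlo estimate of the variational objective. The affine recognition map, the standard normal prior, and the names `recognize` and `elbo_estimate` are assumptions for illustration. A practical implementation would use an MLP and differentiate the estimate with respect to $\phi$ by automatic differentiation, which is omitted here.

```python
import math
import random

def log_normal(v, mean, var):
    """Log density of N(mean, var) evaluated at v."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)

def recognize(y, phi):
    """Hypothetical recognition map: affine in y, returns mean and variance of q(x | y)."""
    w, b, log_var = phi
    return w * y + b, math.exp(log_var)

def elbo_estimate(y, phi, decoder, obs_var, rng, num_samples=1):
    """Monte Carlo ELBO for p(x) = N(0, 1) and p(y | x) = N(decoder(x), obs_var)."""
    mean, var = recognize(y, phi)
    total = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        x = mean + math.sqrt(var) * eps   # reparameterized sample from q(x | y)
        total += (log_normal(x, 0.0, 1.0)
                  + log_normal(y, decoder(x), obs_var)
                  - log_normal(x, mean, var))
    return total / num_samples

rng = random.Random(0)
nonlinear_elbo = elbo_estimate(0.7, (0.5, 0.0, -0.5), math.tanh, 0.1, rng,
                               num_samples=100)
```

As a sanity check on the estimator, with an identity decoder and unit observation variance the exact posterior is $\mathcal{N}(y/2, 1/2)$. Setting the recognition parameters to produce exactly that factor makes every single-sample estimate equal to $\log p(y)$.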

Structured variational autoencoders

We can combine the tractability of conjugate graphical model inference with the flexibility of variational autoencoders. The main idea is to use a conditional random field (CRF) variational family. We learn recognition networks that output conjugate graphical model potentials instead of outputting the complete variational distribution’s parameters directly. These potentials are then used in graphical model inference algorithms in place of the non-conjugate observation likelihoods.
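The main idea can be illustrated with a single Gaussian latent variable. If the recognition network outputs the natural parameters of a Gaussian potential on $x$, then combining that potential with a Gaussian prior is just addition of natural parameters. The sketch below assumes a scalar latent with statistic $(x, x^{2})$, and the affine `recognition_potential` with parameters `phi` is a hypothetical stand-in for a neural network. In the structured models of Section 2 the same potentials would instead be fed into message passing over the whole latent graph.

```python
import math

def gaussian_to_natural(mean, var):
    """Natural parameters of N(mean, var) for the statistic t(x) = (x, x^2)."""
    return (mean / var, -0.5 / var)

def natural_to_gaussian(eta):
    var = -0.5 / eta[1]
    return eta[0] * var, var

def recognition_potential(y, phi):
    """Hypothetical recognition network r(y; phi): natural parameters of a Gaussian potential on x."""
    w, b, log_precision = phi
    precision = math.exp(log_precision)   # positive by construction
    return (precision * (w * y + b), -0.5 * precision)

def combine(prior_eta, potential_eta):
    """Conjugate update: natural parameters of prior and potential add."""
    return tuple(p + r for p, r in zip(prior_eta, potential_eta))

prior = gaussian_to_natural(0.0, 1.0)
potential = recognition_potential(y=2.0, phi=(1.0, 0.0, math.log(2.0)))
post_mean, post_var = natural_to_gaussian(combine(prior, potential))
```

Because the potential has the same exponential family form as a conjugate likelihood, the resulting factor stays Gaussian no matter how nonlinear the true observation model is.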

The SVAE algorithm computes stochastic gradients of a mean field variational inference objective. It can be viewed as a generalization both of the natural gradient SVI algorithm for conditionally conjugate models and of the AEVB algorithm for variational autoencoders. Intuitively, it proceeds by sampling a data minibatch, applying the recognition model to compute graphical model potentials, and using graphical model inference algorithms to compute the variational factor, combining the evidence from the potentials with the prior structure in the model. This variational factor is then used to compute gradients of the mean field objective. See Fig. 3 for graphical models of the variational families with recognition networks for the models developed in Section 2.

In this section, we outline the SVAE model class more formally, write the mean field variational inference objective, and show how to efficiently compute unbiased stochastic estimates of its gradients. The resulting algorithm for computing gradients of the mean field objective, shown in Algorithm 1, is simple and efficient and can be readily applied to a variety of learning problems and graphical model structures. See the supplementals for details and proofs.

To set up notation for a general SVAE, we first define a conjugate pair of exponential family densities on global latent variables $\theta$ and local latent variables $x=\{x_{n}\}_{n=1}^{N}$. Let $p(x\,|\,\theta)$ be an exponential family and let $p(\theta)$ be its corresponding natural exponential family conjugate prior, writing

where we used exponential family conjugacy to write $t_{\theta}(\theta)=\left(\eta^{0}_{x}(\theta),-\log Z_{x}(\eta^{0}_{x}(\theta))\right)$. The local latent variables $x$ could have additional structure, like including both discrete and continuous latent variables or tractable graph structure, but here we keep the notation simple.

Next, we define a general likelihood function. Let $p(y\,|\,x,\gamma)$ be a general family of densities and let $p(\gamma)$ be an exponential family prior on its parameters. For example, each observation $y_{n}$ may depend on the latent value $x_{n}$ through an MLP, as in the density network model of Section 2. This generic non-conjugate observation model provides modeling flexibility, yet the SVAE can still leverage conjugate exponential family structure in inference, as we show next.

4.2 Stochastic variational inference algorithm

Though the general observation model $p(y\,|\,x,\gamma)$ means that conjugate updates and natural gradient SVI cannot be directly applied, we show that by generalizing the recognition network idea we can still approximately optimize out the local variational factors leveraging conjugacy structure.

For fixed $y$, consider the mean field family $q(\theta)q(\gamma)q(x)$ and the variational inference objective

Without loss of generality we can take the global factor $q(\theta)$ to be in the same exponential family as the prior $p(\theta)$, and we denote its natural parameters by $\eta_{\theta}$. We restrict $q(\gamma)$ to be in the same exponential family as $p(\gamma)$ with natural parameters $\eta_{\gamma}$. Finally, we restrict $q(x)$ to be in the same exponential family as $p(x\,|\,\theta)$, writing its natural parameter as $\eta_{x}$. Using these explicit variational parameters, we write the mean field variational inference objective in Eq. (10) as $\mathcal{L}(\eta_{\theta},\eta_{\gamma},\eta_{x})$.

To perform efficient optimization of the objective $\mathcal{L}(\eta_{\theta},\eta_{\gamma},\eta_{x})$, we consider choosing the variational parameter $\eta_{x}$ as a function of the other parameters $\eta_{\theta}$ and $\eta_{\gamma}$. One natural choice is to set $\eta_{x}$ to be a local partial optimizer of $\mathcal{L}$. However, without conjugacy structure, finding a local partial optimizer may be computationally expensive for general densities $p(y\,|\,x,\gamma)$, and in the large data setting this expensive optimization would have to be performed for each stochastic gradient update. Instead, we choose $\eta_{x}$ by optimizing over a surrogate objective $\widehat{\mathcal{L}}$ with conjugacy structure, given by

As with the variational autoencoder of Section 3.2, the resulting variational factor $q^{*}(x)$ is suboptimal for the variational objective $\mathcal{L}$. However, because the surrogate objective has the same form as a variational inference objective for a conjugate observation model, the factor $q^{*}(x)$ not only is easy to compute but also inherits exponential family and graphical model structure for tractable inference.
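In the notation above, a surrogate objective with this conjugacy structure can be written as follows. This is a sketch that follows the description in the text: the nonconjugate likelihood is replaced by a potential that is linear in the statistic $t_{x}(x)$, whose coefficients are produced by a recognition network with parameters $\phi$. The symbols $\psi$ and $r$ and the name $\mathcal{L}_{\mathrm{SVAE}}$ are assumptions introduced here, and the arg max should be read as a local partial optimizer.

```latex
\begin{align}
\widehat{\mathcal{L}}(\eta_{\theta},\eta_{x},\phi)
  &\triangleq \mathbb{E}_{q(\theta)q(x)}\!\left[
     \log\frac{p(\theta)\,p(x\,|\,\theta)\,\exp\{\psi(x;y,\phi)\}}{q(\theta)\,q(x)}\right],
  \qquad \psi(x;y,\phi) \triangleq \langle r(y;\phi),\, t_{x}(x)\rangle, \\
\eta_{x}^{*}(\eta_{\theta},\phi)
  &\triangleq \operatorname*{arg\,max}_{\eta_{x}}\,
     \widehat{\mathcal{L}}(\eta_{\theta},\eta_{x},\phi), \\
\mathcal{L}_{\mathrm{SVAE}}(\eta_{\theta},\eta_{\gamma},\phi)
  &\triangleq \mathcal{L}\!\left(\eta_{\theta},\eta_{\gamma},
     \eta_{x}^{*}(\eta_{\theta},\phi)\right).
\end{align}
```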

Choosing $\eta^{*}_{x}(\eta_{\theta},\phi)$ to be a local partial optimizer of $\widehat{\mathcal{L}}$ provides two computational advantages. First, it allows $\eta^{*}_{x}(\eta_{\theta},\phi)$ and expectations with respect to $q^{*}(x)$ to be computed efficiently by exploiting exponential family graphical model structure. Second, it provides computationally efficient ways to estimate the natural gradient with respect to the latent model parameters, as we summarize next.

where $F(\eta_{\theta}^{\prime})=\mathcal{L}(\eta_{\theta},\eta_{\gamma},\eta_{x}^{*}(\eta_{\theta}^{\prime},\phi))$. When there is only one local variational factor $q(x)$, we can simplify the estimator to

Note that the first term in Eq. (13) is the same as the expression for the natural gradient in SVI for conjugate models, while a stochastic estimate of $\nabla F(\eta_{\theta})$ in the first expression or, alternatively, a stochastic estimate of $\nabla_{\eta_{\theta}}\mathcal{L}(\eta_{\theta},\eta_{\gamma},\eta_{x}^{*}(\eta_{\theta},\phi))$ in the second expression is computed automatically as part of the backward pass for computing the gradients with respect to the other parameters, as described next. Thus we have an expression for the natural gradient with respect to the latent model’s parameters that is almost as simple as the one for conjugate models, differing only by a term involving the neural network likelihood function. Natural gradients are invariant to smooth invertible reparameterizations of the variational family and provide effective second-order optimization updates.

The KL divergence terms are between members of the same tractable exponential families. An unbiased estimate of the first term can be computed by sampling $\hat{x}\sim q^{*}(x)$ and $\hat{\gamma}\sim q(\gamma)$ and computing $\nabla_{\eta_{\gamma},\phi}\log p(y\,|\,\hat{x},\hat{\gamma})$ with automatic differentiation.

Related work

In addition to the papers already referenced, there are several recent papers to which this work is related.

Two prior papers are closest to this work. The first considers combining variational autoencoders with continuous state-space models, emphasizing the relationship to linear dynamical systems (also called Kalman filter models). It primarily focuses on nonlinear dynamics and an RNN-based variational family, as well as allowing control inputs. However, the approach does not extend to general graphical models or discrete latent variables. It also does not leverage natural gradients or exact inference subroutines.

The second also considers the problem of variational inference in general continuous state space models but focuses on using a structured Gaussian variational family without considering parameter learning. As with the first, this approach does not include discrete latent variables (or any latent variables other than the continuous states). However, the method it develops could be used with an SVAE to handle inference with nonlinear dynamics.

In addition, other recent work extends the variational autoencoder framework to sequential models, though it focuses on RNNs rather than probabilistic graphical models.

Finally, there is much related work on handling nonconjugate model terms in mean field variational inference. One line of work presents a general scheme that is able to exploit conjugate exponential family structure while also handling arbitrary nonconjugate model factors, including the nonconjugate observation models we consider here. In particular, it proposes using a proximal gradient framework and splitting the variational inference objective into a difficult term to be linearized (with respect to mean parameters) and a tractable concave term, so that the resulting proximal gradient update is easy to compute, just like in a fully conjugate model. Another approach performs natural gradient descent with respect to natural parameters on each of the variational factors in turn, and focuses on approximating expectations of nonconjugate energy terms in the objective with model-specific lower bounds (rather than estimating them with generic Monte Carlo). As in conjugate SVI, its authors observe that, on conjugate factors and with an undamped update (i.e. a unit step size), the natural gradient update reduces to the standard conjugate mean field update.

In contrast to these approaches, rather than linearizing intractable terms around the current iterate, in this work we handle intractable terms via recognition networks and amortized inference (and the remaining tractable objective terms are multi-concave in general, analogous to SVI). That is, we use parametric function approximators to learn to condition on evidence in a conjugate form. We expect these approaches to handling nonconjugate objective terms may be complementary, and the best choice may be situation-dependent. For models with local latent variables and datasets where minibatch-based updating is important, using inference networks to compute local variational parameters in a fixed-depth circuit (as in the VAE) or optimizing out the local variational factors using fast conjugate updates (as in conjugate SVI) can be advantageous because in both cases local variational parameters for the entire dataset need not be maintained across updates. The SVAE we propose here is a way to combine the inference network and conjugate SVI approaches.

Experiments

We apply the SVAE to both synthetic and real data and demonstrate its ability to learn feature representations and latent structure. Code is available at github.com/mattjj/svae.

Consider a sequence of 1D images representing a dot bouncing from one side of the image to the other, as shown at the top of Fig. 4. We use an LDS SVAE to find a low-dimensional latent state space representation along with a nonlinear image model. The model is able to represent the image accurately and to make long-term predictions with uncertainty. See supplementals for details.

This experiment also demonstrates the optimization advantages that can be provided by the natural gradient updates. In Fig. 5(a) we compare natural gradient updates with standard gradient updates at three different learning rates. The natural gradient algorithm not only learns much faster but also is less dependent on parameterization details: while the natural gradient update used an untuned step size of 0.1, the standard gradient dynamics at step sizes of both 0.1 and 0.05 caused some matrix parameters to be updated to indefinite values.

6.2 LDS SVAE for modeling video

We also apply an LDS SVAE to model depth video recordings of mouse behavior. We use a previously collected dataset in which a mouse is recorded from above using a Microsoft Kinect. We used a subset consisting of 8 recordings, each of a distinct mouse, 20 minutes long at 30 frames per second, for a total of 288,000 video frames downsampled to $30\times 30$ pixels.

We use MLP observation and recognition models with two hidden layers of 200 units each and a 10D latent space. Fig. 5(b) shows images corresponding to a regular grid on a random 2D subspace of the latent space, illustrating that the learned image manifold accurately captures smooth variation in the mouse’s body pose. Fig. 6 shows predictions from the model paired with real data.

6.3 SLDS SVAE for parsing behavior

Finally, because the LDS SVAE can accurately represent the depth video over short timescales, we apply the latent switching linear dynamical system (SLDS) model to discover the natural units of behavior. Fig. 7 and Fig. 8 in the appendix show some of the discrete states that arise from fitting an SLDS SVAE with 30 discrete states to the depth video data. The discrete states that emerge show a natural clustering of short-timescale patterns into behavioral units. See the supplementals for more.

Conclusion

Structured variational autoencoders provide a general framework that combines some of the strengths of probabilistic graphical models and deep learning methods. In particular, they use graphical models both to give models rich latent representations and to enable fast variational inference with CRF-like structured approximating distributions. To complement these structured representations, SVAEs use neural networks to produce not only flexible nonlinear observation models but also fast recognition networks that map observations to conjugate graphical model potentials.

References

Appendix A Optimization

In this section we fix our notation for gradients and establish some basic definitions and results that we use in the sequel.

The transpose of $\nabla f$ is the Jacobian matrix of $f$, in which the $ij$th entry is the function $\partial f_{i}/\partial x_{j}$.

A.2 Local and partial optimizers

In this section we state the definitions of local partial optimizer and necessary conditions for optimality that we use in the sequel.

and we call $y^{*}$ an unconstrained local partial optimizer of $f$ given $x$ if there exists an $\epsilon>0$ such that

where $\|\,\cdot\,\|$ is any vector norm.

and hence the cost gradient $\nabla_{y}f(x,y^{*})$ is orthogonal to the first-order feasible variations in $y$ given by the null space of $\nabla_{y}h(x,y^{*})^{\mathsf{T}}$.

Note that the regularity condition on the constraints is not needed if the constraints are linear [26, Prop. 3.3.7].

A.3 Partial optimization and the Implicit Function Theorem

and using the chain rule write its gradient as

If $y^{*}$ is an unconstrained local partial optimizer of $f$ given $x$ then it satisfies $\nabla_{y}f(x,y^{*})=0$, and if $y^{*}$ is a regularly-constrained local partial optimizer then the feasible variation $\nabla y^{*}(x)$ is orthogonal to the cost gradient $\nabla_{y}f(x,y^{*})$. In both cases the second term in the expression for $\nabla g(x)$ in Eq. (25) is zero. ∎
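The consequence of this argument, that the gradient of the partially optimized function equals the partial gradient with the optimizer held fixed, can be checked numerically on a small example. The quadratic `f`, its closed-form optimizer `y_star`, and the finite-difference helper below are illustrative assumptions chosen so that the answer is easy to verify by hand.

```python
def f(x, y):
    """Toy objective, concave in y for each fixed x."""
    return x * y - 0.5 * y * y - x * x

def y_star(x):
    """Unconstrained maximizer of f(x, .): solves df/dy = x - y = 0."""
    return x

def g(x):
    """Partially optimized objective g(x) = f(x, y*(x))."""
    return f(x, y_star(x))

def central_difference(fn, x, h=1e-5):
    return (fn(x + h) - fn(x - h)) / (2.0 * h)

x0 = 0.8
# Derivative of g, which accounts for the dependence of y* on x.
total_derivative = central_difference(g, x0)
# Partial derivative in x with y held fixed at the optimizer for x0.
partial_derivative = central_difference(lambda x: f(x, y_star(x0)), x0)
```

Here $g(x)=-x^{2}/2$, so both quantities equal $-x_{0}$. The term involving the derivative of the optimizer drops out because $\nabla_{y}f$ vanishes at $y^{*}$.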

$h$ is continuous and has a continuous nonsingular gradient matrix $\nabla_{y}h(x,y)$ in an open set containing $(\bar{x},\bar{y})$.

Appendix B Exponential families

In this section we set up notation for exponential families and outline some basic results. Throughout this section we take all densities to be absolutely continuous with respect to the appropriate Lebesgue measure (when the underlying set $\mathcal{X}$ is Euclidean space) or counting measure (when $\mathcal{X}$ is discrete), and denote the Borel $\sigma$-algebra of a set $\mathcal{X}$ as $\mathcal{B}(\mathcal{X})$ (generated by Euclidean and discrete topologies, respectively). We assume measurability of all functions as necessary.

We can write the normalized probability density as

and take $\Theta=\eta_{x}^{-1}(H)$ to be the open set of parameters that correspond to normalizable densities. We summarize this notation in the following definition.

The next proposition shows that the log partition function of an exponential family generates cumulants of the statistic.

The gradient of the log partition function of an exponential family gives the expected sufficient statistic,

where the expectation is over the random variable $x$ with density $p(x\,|\,\eta)$. More generally, the moment generating function of $t(x)$ can be written

and so derivatives of $\log Z$ give cumulants of $t(x)$, where the first cumulant is the mean and the second and third cumulants are the second and third central moments, respectively.
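The first two cumulant identities can be checked numerically for a simple family. The sketch below uses the Bernoulli family with $t(x)=x$ and $\log Z(\eta)=\log(1+e^{\eta})$, and compares finite-difference derivatives of the log partition function against the mean and variance computed by direct enumeration. The choice of family, the value of $\eta$, and the helper names are illustrative assumptions.

```python
import math

def log_partition(eta):
    """Bernoulli with t(x) = x on x in {0, 1}: log Z(eta) = log(1 + exp(eta))."""
    return math.log1p(math.exp(eta))

def moments_by_enumeration(eta):
    """Mean and variance of t(x) = x computed directly from the density."""
    probs = [math.exp(eta * x - log_partition(eta)) for x in (0, 1)]
    mean = sum(p * x for p, x in zip(probs, (0, 1)))
    var = sum(p * (x - mean) ** 2 for p, x in zip(probs, (0, 1)))
    return mean, var

def first_derivative(fn, x, h=1e-4):
    return (fn(x + h) - fn(x - h)) / (2.0 * h)

def second_derivative(fn, x, h=1e-4):
    return (fn(x + h) - 2.0 * fn(x) + fn(x - h)) / (h * h)

eta = 0.4
mean, var = moments_by_enumeration(eta)
grad = first_derivative(log_partition, eta)    # approximates the first cumulant
hess = second_derivative(log_partition, eta)   # approximates the second cumulant
```

Up to finite-difference error, `grad` matches `mean` and `hess` matches `var`.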

Given an exponential family of densities on $\mathcal{X}$ as in Definition B.1, we can define a related exponential family of densities on $\Theta$ by defining a statistic function $t_{\theta}(\theta)$ in terms of the functions $\eta_{x}(\theta)$ and $\log Z_{x}(\eta_{x}(\theta))$.

where the first $n$ coordinates of $t_{\theta}(\theta)$ are given by $\eta_{x}(\theta)$ and the last coordinate is given by $-\log Z_{x}(\eta_{x}(\theta))$. We call the exponential family with statistic $t_{\theta}(\theta)$ the natural exponential family conjugate prior to the density $p(x\,|\,\theta)$ and write

Notice that using $t_{\theta}(\theta)$ we can rewrite the original density $p(x\,|\,\theta)$ as

This relationship is useful in Bayesian inference: when the exponential family $p(x\,|\,\theta)$ is a likelihood function and the family $p(\theta)$ is used as a prior, the pair enjoy a convenient conjugacy property, as summarized in the next proposition.

Let the densities $p(x\,|\,\theta)$ and $p(\theta)$ be defined as in Definitions B.1 and B.3, respectively. We have the relations

and hence in particular the posterior $p(\theta\,|\,x)$ is in the same exponential family as $p(\theta)$ with the natural parameter $\eta_{\theta}+(t_{x}(x),1)$. Similarly, with multiple likelihood terms $p(x_{i}\,|\,\theta)$ for $i=1,2,\ldots,N$ we have
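The conjugate update can be made concrete for Bernoulli observations. Taking $\theta$ to be the success probability $p$, the statistic is $t_{\theta}(p)=(\log\frac{p}{1-p},\,\log(1-p))$ and each observation adds $(x_{i},1)$ to the prior natural parameter. The sketch below checks, at a given value of $p$, that the updated parameter reproduces prior times likelihood up to normalization. The parameterization by $p$, the prior values, and the function names are illustrative assumptions.

```python
import math

def t_theta(p):
    """t_theta(theta) = (eta_x(theta), -log Z_x(eta_x(theta))) for a Bernoulli with mean p."""
    return (math.log(p / (1.0 - p)), math.log(1.0 - p))

def log_unnormalized(eta_theta, p):
    """Log of the unnormalized conjugate density <eta_theta, t_theta(p)>."""
    return sum(e * t for e, t in zip(eta_theta, t_theta(p)))

def log_likelihood(xs, p):
    return sum(math.log(p) if x == 1 else math.log(1.0 - p) for x in xs)

def posterior_natural_parameter(eta_theta, xs):
    """Conjugate update: add (t_x(x_i), 1) for each observation, with t_x(x) = x."""
    return (eta_theta[0] + sum(xs), eta_theta[1] + len(xs))

prior = (1.0, 2.0)
xs = [1, 0, 1, 1]
posterior = posterior_natural_parameter(prior, xs)
```

No integration is needed: conditioning on data amounts to accumulating sufficient statistics and a count.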

Finally, we give a few more exponential family properties that are useful for gradient-based optimization algorithms and variational inference. In particular, we note that the Fisher information matrix of an exponential family can be computed as the Hessian matrix of its log partition function, and that the KL divergence between two members of the same exponential family has a simple expression.

Given a family of densities $p(x\,|\,\theta)$ indexed by a parameter $\theta$, the score vector $v(x,\theta)$ is the gradient of the log density with respect to the parameter,

and the Fisher information matrix for the parameter $\theta$ is the covariance of the score,

Given an exponential family of densities $p(x\,|\,\eta)$ indexed by the natural parameter $\eta$, as in Eq. (32), the score with respect to the natural parameter is given by

and the Fisher information matrix is given by

Given an exponential family of densities $p(x\,|\,\eta)$ indexed by the natural parameter $\eta$, as in Eq. (32), and two particular members with natural parameters $\eta_{1}$ and $\eta_{2}$, respectively, the KL divergence from one to the other is
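For two members of the same exponential family, taking the expectation of the log density ratio under the first member gives the standard closed form $\operatorname{KL}(p(x\,|\,\eta_{1})\,\|\,p(x\,|\,\eta_{2}))=\langle\eta_{1}-\eta_{2},\,\mathbb{E}_{p(x\,|\,\eta_{1})}[t(x)]\rangle-(\log Z(\eta_{1})-\log Z(\eta_{2}))$. The sketch below checks this expression against a direct computation for the Bernoulli family. The family, the parameter values, and the function names are illustrative assumptions.

```python
import math

def log_partition(eta):
    """Bernoulli log partition function with t(x) = x."""
    return math.log1p(math.exp(eta))

def mean_statistic(eta):
    """E[t(x)] for the Bernoulli family, i.e. the gradient of log Z."""
    return 1.0 / (1.0 + math.exp(-eta))

def kl_exponential_family(eta1, eta2):
    """KL via natural parameters, expected statistics, and log partition functions."""
    return ((eta1 - eta2) * mean_statistic(eta1)
            - (log_partition(eta1) - log_partition(eta2)))

def kl_direct(eta1, eta2):
    """KL computed by summing over the two outcomes."""
    p, q = mean_statistic(eta1), mean_statistic(eta2)
    return p * math.log(p / q) + (1.0 - p) * math.log((1.0 - p) / (1.0 - q))

kl_closed = kl_exponential_family(0.5, -1.0)
kl_check = kl_direct(0.5, -1.0)
```

The closed form needs only quantities that are already available when working with natural parameters, which is why KL terms between exponential family factors are cheap to evaluate.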

Appendix C Natural gradient SVI for exponential families

In this section we give a derivation of the natural gradient stochastic variational inference (SVI) method using our notation. We extend the algorithm in Section D.

Let $p(x,y\,|\,\theta)$ be an exponential family and $p(\theta)$ be its corresponding natural exponential family prior as in Definitions B.1 and B.3, writing

where we have used $t_{\theta}(\theta)=\left(\eta^{0}_{xy}(\theta),-\log Z_{xy}(\eta^{0}_{xy}(\theta))\right)$ in Eq. (52).

Given a fixed observation $y$, for any density $q(\theta,x)=q(\theta)q(x)$ we have

where we have used the fact that the KL divergence is always nonnegative. Therefore to choose $q(\theta)q(x)$ to minimize the KL divergence to the posterior $p(\theta,x\,|\,y)$ we define the mean field variational inference objective as

and the mean field variational inference problem as
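Written out in the notation of this section, the bound, the objective, and the optimization problem take the standard form below. This is a sketch of the usual evidence lower bound argument for a mean field family, not a verbatim restoration of the numbered equations.

```latex
\begin{align}
\log p(y)
  &= \mathbb{E}_{q(\theta)q(x)}\!\left[
       \log\frac{p(\theta)\,p(x,y\,|\,\theta)}{q(\theta)\,q(x)}\right]
     + \operatorname{KL}\!\left(q(\theta)q(x)\,\|\,p(\theta,x\,|\,y)\right) \\
  &\geq \mathbb{E}_{q(\theta)q(x)}\!\left[
       \log\frac{p(\theta)\,p(x,y\,|\,\theta)}{q(\theta)\,q(x)}\right]
   \;\triangleq\; \mathcal{L}[\,q(\theta)q(x)\,], \\
\max_{q(\theta)q(x)}\;& \mathcal{L}[\,q(\theta)q(x)\,].
\end{align}
```

Since $\log p(y)$ does not depend on $q$, maximizing $\mathcal{L}$ is equivalent to minimizing the KL divergence from $q(\theta)q(x)$ to the posterior.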

The following proposition shows that because of the exponential family conjugacy structure, we can fix the parameterization of $q(\theta)$ and still optimize over all possible densities without loss of generality.

Given the mean field optimization problem Eq. (56), for any fixed $q(x)$ the optimal factor $q(\theta)$ is determined ($\nu_{\Theta}$-a.e.) by

In particular, the optimal $q(\theta)$ is in the same exponential family as the prior $p(\theta)$.

This proposition follows immediately from a more general lemma, which we reuse in the sequel.

Let $p(a,b,c)$ be a joint density and let $q(a)$, $q(b)$, and $q(c)$ be mean field factors. Consider the mean field variational inference objective

For fixed $q(a)$ and $q(c)$, the partially optimal factor $q^{*}(b)$ over all possible densities,

In particular, if $p(c\,|\,b,a)$ is an exponential family with $p(b\,|\,a)$ its natural exponential family conjugate prior, and $\log p(b,c\,|\,a)$ is a multilinear polynomial in the statistics $t_{b}(b)$ and $t_{c}(c)$, written

for some matrix $\eta^{0}_{c}(a)$, then the optimal factor can be written

As a special case, when $c$ is conditionally independent of $a$ given $b$, so that $p(c\,|\,b,a)=p(c\,|\,b)$, then

Rewrite the objective in Eq. (59), dropping terms that are constant with respect to $q(b)$, as

Proposition C.1 justifies parameterizing the density $q(\theta)$ with variational natural parameters $\eta_{\theta}$ as

where the statistic function $t_{\theta}$ and the log partition function $\log Z_{\theta}$ are the same as in the prior family $p(\theta)$. Using this parameterization, we can define the mean field objective as a function of the parameters $\eta_{\theta}$, partially optimizing over $q(x)$,

The partial optimization over $q(x)$ in Eq. (71) should be read as choosing $q(x)$ to be a local partial optimizer of Eq. (55); in general, it may be intractable to find a global partial optimizer, and the results that follow use only first-order stationary conditions on $q(x)$. We refer to this objective function, where we locally partially optimize the mean field objective Eq. (55) over $q(x)$, as the SVI objective.

C.2 Easy natural gradients of the SVI objective

By again leveraging the conjugate exponential family structure, we can write a simple expression for the gradient of the SVI objective, and even for its natural gradient.

Let the SVI objective $\mathcal{L}(\eta_{\theta})$ be defined as in Eq. (71). Then the gradient $\nabla\mathcal{L}(\eta_{\theta})$ is

where ${q}^{*}\!(x)$ is a local partial optimizer of the mean field objective Eq. (55) for fixed global variational parameters $\eta_{\theta}$.

First, note that because ${q}^{*}\!(x)$ is a local partial optimizer for Eq. (55) by Proposition A.3, we have

Next, we use the conjugate exponential family structure and Proposition B.4, Eq. (42), to expand

As an immediate result of Proposition C.3, the natural gradient defined by

The natural gradient of the SVI objective Eq. (71) is

The natural gradient corrects for a kind of curvature in the variational family and is invariant to reparameterization of the family. As a result, natural gradient ascent is effectively a second-order quasi-Newton optimization algorithm, and using natural gradients can greatly accelerate the convergence of gradient-based optimization algorithms. It is a remarkable consequence of the exponential family structure that natural gradients of the partially optimized mean field objective with respect to the global variational parameters can be computed efficiently (without any backward pass as would be required in generic reverse-mode differentiation). Indeed, the exponential family conjugacy structure makes the natural gradient of the SVI objective even easier to compute than the flat gradient.
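The following is a minimal sketch of why the natural gradient is cheap to compute. It assumes the natural gradient takes the standard SVI form, namely the prior natural parameters plus the expected sufficient statistics under ${q}^{*}\!(x)$ minus the current parameters $\eta_{\theta}$. The expected statistics are treated as given numbers here (in practice they depend on $\eta_{\theta}$ through ${q}^{*}\!(x)$), and all names and values are illustrative.

```python
def svi_natural_gradient(eta, prior_eta, expected_stats):
    """Assumed form: prior + expected statistics - current parameters."""
    return [p + s - e for p, s, e in zip(prior_eta, expected_stats, eta)]


def natural_gradient_step(eta, nat_grad, step_size):
    """One ascent step on the global variational natural parameters."""
    return [e + step_size * g for e, g in zip(eta, nat_grad)]


prior_eta = [1.0, 2.0]
expected_stats = [1.6, 3.0]  # hypothetical E_q*[(t_xy(x, y), 1)], three data points
eta = [0.5, 4.0]             # current global variational parameters

grad = svi_natural_gradient(eta, prior_eta, expected_stats)
full_step = natural_gradient_step(eta, grad, 1.0)
half_step = natural_gradient_step(eta, grad, 0.5)
```

Under this assumed form no log partition function has to be differentiated, and a step of size one lands exactly on the coordinate ascent update, the prior parameters plus the expected statistics.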

C.3 Stochastic natural gradients for large datasets

The real utility of natural gradient SVI is in its application to large datasets. Consider the model composed of global latent variables $\theta$, local latent variables $x=\{x_{n}\}_{n=1}^{N}$, and data $y=\{y_{n}\}_{n=1}^{N}$,

where each $p(x_{n},y_{n}\,|\,\theta)$ is a copy of the same likelihood function with conjugate prior $p(\theta)$. For fixed observations $y=\{y_{n}\}_{n=1}^{N}$, let

be a variational family to approximate the posterior $p(\theta,x\,|\,y)$ and consider the SVI objective given by Eq. (71). Using Eq. (44) of Proposition B.4, it is straightforward to extend the natural gradient expression in Corollary C.4 to an unbiased Monte Carlo estimate which samples terms in the sum over data points.

where $p(\theta)$ and $p(x_{n},y_{n}\,|\,\theta)$ are a conjugate pair of exponential families, define $\mathcal{L}(\eta_{\theta})$ as in Eq. (71). Let the random index $\hat{n}$ be sampled from the set $\{1,2,\ldots,N\}$ and let $p_{n}>0$ be the probability it takes value $n$. Then

where ${q}^{*}\!(x_{\hat{n}})$ is a local partial optimizer of $\mathcal{L}$ given $q(\theta)$.

Taking expectation over the index $\hat{n}$, we have

The remainder of the proof follows from Proposition B.4 and the same argument as in Proposition C.3. ∎

The unbiased stochastic gradient developed in Corollary C.5 can be used in a scalable stochastic gradient ascent algorithm. To simplify notation, in the following sections we drop the notation for multiple likelihood terms $p(x_{n},y_{n}\,|\,\theta)$ for $n=1,2,\ldots,N$ and return to working with a single likelihood term $p(x,y\,|\,\theta)$. The extension to multiple likelihood terms is immediate.
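The unbiasedness argument can be checked numerically. The sketch below assumes the stochastic estimate scales the sampled datum's expected statistics by $1/p_{\hat{n}}$ and otherwise has the same form as the full natural gradient. The per-datum statistics, probabilities, and function names are hypothetical. The expectation over $\hat{n}$ is computed exactly by enumeration, not by sampling.

```python
import random


def full_natural_gradient(eta, prior_eta, per_datum_stats):
    """Natural gradient using the sum over all data points."""
    total = [sum(column) for column in zip(*per_datum_stats)]
    return [p + t - e for p, t, e in zip(prior_eta, total, eta)]


def stochastic_natural_gradient(eta, prior_eta, per_datum_stats, n, probs):
    """Estimate from datum n alone, reweighted by 1 / p_n."""
    scaled = [s / probs[n] for s in per_datum_stats[n]]
    return [p + s - e for p, s, e in zip(prior_eta, scaled, eta)]


def expected_stochastic_gradient(eta, prior_eta, per_datum_stats, probs):
    """Exact expectation of the estimate over the random index."""
    expectation = [0.0] * len(eta)
    for n, p_n in enumerate(probs):
        estimate = stochastic_natural_gradient(eta, prior_eta, per_datum_stats, n, probs)
        for i, value in enumerate(estimate):
            expectation[i] += p_n * value
    return expectation


prior_eta = [1.0, 2.0]
eta = [0.5, 4.0]
per_datum_stats = [[0.9, 1.0], [0.2, 1.0], [0.5, 1.0]]  # hypothetical E_q*[(t(x_n, y_n), 1)]
probs = [0.5, 0.25, 0.25]                                # nonuniform sampling is allowed

exact = full_natural_gradient(eta, prior_eta, per_datum_stats)
averaged = expected_stochastic_gradient(eta, prior_eta, per_datum_stats, probs)

# A single noisy estimate, as would be used in one stochastic update.
rng = random.Random(0)
n_hat = rng.choices(range(len(probs)), weights=probs)[0]
one_sample = stochastic_natural_gradient(eta, prior_eta, per_datum_stats, n_hat, probs)
```

Any sampling distribution with $p_{n}>0$ gives the same expectation; the choice affects only the variance of the estimate.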

C.4 Conditionally conjugate models and block updating

The model classes often considered for natural gradient SVI, and the main model classes we consider here, have additional conjugacy structure in the local latent variables. In this section we introduce notation for this extra structure in terms of the additional local latent variables $z$ and discuss the local block coordinate optimization that is often performed to compute the factor ${q}^{*}\!(z){q}^{*}\!(x)$ for use in the natural gradient expression.

Let $p(z,x,y\,|\,\theta)$ be an exponential family and $p(\theta)$ be its corresponding natural exponential family conjugate prior, writing

where we have used $t_{\theta}(\theta)=\left(\eta^{0}_{zxy}(\theta),-\log Z_{zxy}(\eta^{0}_{zxy}(\theta))\right)$ in Eq. (87). Additionally, let $t_{zxy}(z,x,y)$ be a multilinear polynomial in the statistics functions $t_{x}(x)$, $t_{y}(y)$, and $t_{z}(z)$, let $p(z\,|\,\theta)$, $p(x\,|\,z,\theta)$, and $p(y\,|\,x,z,\theta)=p(y\,|\,x,\theta)$ be exponential families, and let $p(z\,|\,\theta)$ be a conjugate prior to $p(x\,|\,z,\theta)$ and $p(x\,|\,z,\theta)$ be a conjugate prior to $p(y\,|\,x,\theta)$, so that

for some matrices $\eta^{0}_{x}(\theta)$ and $\eta^{0}_{y}(\theta)$.

This model class includes many common models, including latent Dirichlet allocation, switching linear dynamical systems with linear-Gaussian emissions, and mixture models and hidden Markov models with exponential family emissions. The conditionally conjugate structure is both powerful and restrictive: while it potentially limits the expressiveness of the model class, it enables block coordinate optimization with very simple and fast updates, as we show next. When conditionally conjugate structure is not present, these local optimizations can instead be performed with generic gradient-based methods and automatic differentiation.

Let $p(\theta,z,x,y)$ be a model as in Eqs. (85)-(92), and for fixed data $y$ let $q(\theta)q(z)q(x)$ be a corresponding mean field variational family for approximating the posterior $p(\theta,z,x\,|\,y)$, with

and with the mean field variational inference objective

Fixing the other factors, the partial optimizers ${q}^{*}\!(z)$ and ${q}^{*}\!(x)$ for $\mathcal{L}$ over all possible densities are given by

This proposition is a consequence of Lemma C.2 and the conjugacy structure. ∎

Proposition C.6 gives an efficient block coordinate ascent algorithm: for fixed $\eta_{\theta}$, by alternately updating $\eta_{z}$ and $\eta_{x}$ according to Eqs. (99)-(100) we are guaranteed to converge to a stationary point that is partially optimal in the parameters of each factor. In addition, performing each update requires only computing expected sufficient statistics in the variational factors, which means evaluating $\nabla\log Z_{\theta}(\eta_{\theta})$, $\nabla\log Z_{z}(\eta_{z})$, and $\nabla\log Z_{x}(\eta_{x})$, quantities that must be computed anyway in a gradient-based optimization routine. The block coordinate ascent procedure leveraging this conditional conjugacy structure is thus not only efficient but also does not require a choice of step size.

Note in particular that this procedure produces parameters $\eta_{z}^{*}(\eta_{\theta})$ and $\eta_{x}^{*}(\eta_{\theta})$ that are partially optimal (and hence stationary) for the objective. That is, defining the parameterized mean field variational inference objective as $L(\eta_{\theta},\eta_{z},\eta_{x})=\mathcal{L}[\,q(\theta)q(z)q(x)\,]$, for fixed $\eta_{\theta}$ the block coordinate ascent procedure has limit points $\eta_{z}^{*}$ and $\eta_{x}^{*}$ that satisfy
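The block updates can be sketched on a toy conditionally conjugate model, assumed here purely for illustration: a discrete label $z$ with prior probabilities `prior`, a scalar latent $x\,|\,z \sim \mathcal{N}(m_z, 1)$, and an observation $y\,|\,x \sim \mathcal{N}(x, \sigma^2)$. The global parameters are held at fixed values, standing in for expectations under $q(\theta)$. Each update sets one factor's natural parameters from expected statistics of the other, and no step size appears anywhere.

```python
import math


def update_qx(qz, means, y, obs_var):
    """Gaussian q(x): add expected prior and likelihood natural parameters."""
    h = sum(q * m for q, m in zip(qz, means)) + y / obs_var  # linear term
    J = 1.0 + 1.0 / obs_var                                  # precision
    return h / J, 1.0 / J                                    # mean, variance


def update_qz(x_mean, means, log_prior):
    """Categorical q(z): log prior plus E_q(x)[log N(x; m_k, 1)] up to a constant.

    The E[x^2] term is the same for every k and is dropped before normalizing.
    """
    logits = [lp + m * x_mean - 0.5 * m * m for lp, m in zip(log_prior, means)]
    top = max(logits)
    weights = [math.exp(l - top) for l in logits]
    total = sum(weights)
    return [w / total for w in weights]


def block_coordinate_ascent(y, means, prior, obs_var, num_iters=50):
    """Alternate the two closed-form updates for a fixed number of sweeps."""
    qz = list(prior)
    log_prior = [math.log(p) for p in prior]
    for _ in range(num_iters):
        x_mean, x_var = update_qx(qz, means, y, obs_var)
        qz = update_qz(x_mean, means, log_prior)
    return qz, x_mean, x_var


qz, x_mean, x_var = block_coordinate_ascent(
    y=1.5, means=[-2.0, 2.0], prior=[0.5, 0.5], obs_var=0.5
)
```

At the returned point, applying either update again leaves the parameters essentially unchanged, which is the stationarity property described above.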

Appendix D The SVAE objective and its gradients

In this section we define the SVAE variational lower bound and show how to efficiently compute unbiased stochastic estimates of its gradients, including an unbiased estimate of the natural gradient with respect to the variational parameters with conjugacy structure. The setup here parallels the setup for natural gradient SVI in Section C, but while SVI is restricted to complete-data conjugate models, here we consider more general likelihood models.

Let $p(x\,|\,\theta)$ be an exponential family and let $p(\theta)$ be its corresponding natural exponential family conjugate prior, as in Definitions B.1 and B.3, writing

where we have used $t_{\theta}(\theta)=\left(\eta^{0}_{x}(\theta),-\log Z_{x}(\eta^{0}_{x}(\theta))\right)$ in Eq. (104). Let $p(y\,|\,x,\gamma)$ be a general family of densities (not necessarily an exponential family) and let $p(\gamma)$ be an exponential family prior on its parameters of the form

For fixed $y$, consider the mean field family of densities $q(\theta,\gamma,x)=q(\theta)q(\gamma)q(x)$ and the mean field variational inference objective

By the same argument as in Proposition C.1, without loss of generality we can take the global factor $q(\theta)$ to be in the same exponential family as the prior $p(\theta)$, and we denote its natural parameters by $\eta_{\theta}$, writing

We restrict $q(\gamma)$ to be in the same exponential family as $p(\gamma)$ with natural parameters $\eta_{\gamma}$, writing

Finally, we restrict $q(x)$ to be in the same exponential family as $p(x\,|\,\theta)$, writing its natural parameter as $\eta_{x}$. (The parametric form for $q(x)$ need not be restricted a priori; rather, without loss of generality given the surrogate objective Eq. (111) and the form of $\psi$ used in Eq. (112), the optimal factor $q(x)$ is in the same family as $p(x\,|\,\theta)$. We treat it as a restriction here so that we can proceed with more concrete notation.) Using these explicit variational natural parameters, we rewrite the mean field variational inference objective in Eq. (106) as

To perform efficient optimization in the objective $\mathcal{L}$ defined in Eq. (109), we consider choosing the variational parameter $\eta_{x}$ as a function of the other parameters $\eta_{\theta}$ and $\eta_{\gamma}$. One natural choice is to set $\eta_{x}$ to be a local partial optimizer of $\mathcal{L}$, as in Section C. However, finding a local partial optimizer may be computationally expensive for general densities $p(y\,|\,x,\gamma)$, and in the large data setting this expensive optimization would have to be performed for each stochastic gradient update. Instead, we choose $\eta_{x}$ by optimizing over a surrogate objective $\widehat{\mathcal{L}}$, which we design using exponential family structure to be both easy to optimize and to share curvature properties with the mean field objective $\mathcal{L}$. The surrogate objective $\widehat{\mathcal{L}}$ is

where the constant does not depend on $\eta_{x}$. We define the function $\psi(x;y,\phi)$ to have a form related to the exponential family $p(x\,|\,\theta)$,

where the notation above should be interpreted as choosing $\eta^{*}_{x}(\eta_{\theta},\phi)$ to be a local argument of maximum. The results to follow rely only on necessary first-order conditions for unconstrained local optimality.

Given this choice of function $\eta^{*}_{x}(\eta_{\theta},\phi)$, we define the SVAE objective to be

where $\mathcal{L}$ is the mean field variational inference objective defined in Eq. (109), and we define the SVAE optimization problem to be
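To make the role of the surrogate concrete, the sketch below assumes that $\psi$ is linear in the sufficient statistics, $\psi(x;y,\phi)=\langle r(y;\phi),\,t_{x}(x)\rangle$, so that the stationary point of the surrogate is the expected prior natural parameter plus the recognition output, $\eta^{*}_{x}=\mathbb{E}_{q(\theta)}\eta^{0}_{x}(\theta)+r(y;\phi)$. It uses a scalar Gaussian $q(x)$ with natural parameters $(h, J)$, meaning a density proportional to $\exp(hx-\tfrac{1}{2}Jx^{2})$. The tiny "recognition network" and all numerical values are hypothetical stand-ins for a learned model.

```python
import math


def recognition_potential(y, phi):
    """Hypothetical recognition model r(y; phi) returning Gaussian natural parameters."""
    w, b, log_prec = phi
    return (w * y + b, math.exp(log_prec))  # exp keeps the precision positive


def surrogate_optimal_qx(prior_natparams, y, phi):
    """Assumed optimizer of the surrogate: prior natural parameters plus r(y; phi)."""
    h0, J0 = prior_natparams
    hr, Jr = recognition_potential(y, phi)
    return (h0 + hr, J0 + Jr)


def mean_and_variance(natparams):
    """Convert Gaussian natural parameters (h, J) to mean and variance."""
    h, J = natparams
    return h / J, 1.0 / J


prior_natparams = (0.0, 1.0)     # stands in for E_q(theta)[eta_x^0(theta)]
phi = (2.0, 0.0, math.log(3.0))  # hypothetical recognition parameters
eta_x = surrogate_optimal_qx(prior_natparams, y=1.5, phi=phi)
mean, var = mean_and_variance(eta_x)
```

Under this assumption the local optimization is a single addition of natural parameters, which is why the surrogate is much cheaper to optimize than the original objective with a general likelihood.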

We summarize these definitions in the following.

Let $\mathcal{L}$ denote the mean field variational inference objective

where the densities $p(\theta)$, $p(\gamma)$, and $p(x\,|\,\theta)$ are exponential families and $p(\theta)$ is the natural exponential family conjugate prior to $p(x\,|\,\theta)$, as in Eqs. (102)-(104). Given a parameterization of the variational factors as

let $\mathcal{L}(\eta_{\theta},\eta_{\gamma},\eta_{x})$ denote the mean field variational inference objective Eq. (116) as a function of these variational parameters. We define the SVAE objective as

where $\eta_{x}^{*}(\eta_{\theta},\phi)$ is defined as a local partial optimizer of the surrogate objective $\widehat{\mathcal{L}}$,

where the surrogate objective $\widehat{\mathcal{L}}$ is defined as

then the bound can be made tight in the sense that

When there is only one local latent variational factor $q(x)$ (and no further factorization structure), the natural gradient of the SVAE objective Eq. (114) with respect to the conjugate global variational parameters $\eta_{\theta}$ is

where the first term is the SVI natural gradient from Corollary C.4, using

and where a stochastic estimate of the second term is computed as part of the backward pass for the gradient $\nabla_{\phi}\mathcal{L}(\eta_{\theta},\eta_{\gamma},\eta^{*}_{x}(\eta_{\theta},\phi))$.

First we use the chain rule, analogously to Eq. (25), to write the gradient as

where the first term is the same as the SVI gradient derived in Proposition C.3. In the case of SVI, the second term is zero because $\eta^{*}_{x}$ is chosen as a partial optimizer of $\mathcal{L}$, but for the SVAE objective the second term is nonzero in general, and the remainder of this proof amounts to deriving a simple expression for it.

We compute the term $\nabla_{\eta_{\theta}}\eta^{*}_{x}(\eta_{\theta},\phi)$ in Eq. (127) in terms of the gradients of the surrogate objective $\widehat{\mathcal{L}}$ using the Implicit Function Theorem given in Corollary A.5, which yields

First, we compute the gradient of $\widehat{\mathcal{L}}$ with respect to $\eta_{x}$, writing

When there is only one local latent variational factor $q(x)$ (and no further factorization structure), as a consequence of the first-order stationary condition $\nabla_{\eta_{x}}\widehat{\mathcal{L}}(\eta_{\theta},\eta_{x}^{*}(\eta_{\theta},\phi),\phi)=0$ and the fact that $\nabla^{2}\log Z_{x}(\eta_{x})$ is always positive definite for minimal exponential families, we have

which is useful in simplifying the expressions to follow.

Continuing with the calculation of the terms in Eq. (128), we compute $\nabla^{2}_{\eta_{x}\eta_{x}}\widehat{\mathcal{L}}$ by differentiating the expression in Eq. (130) again, writing

where the last line follows from using the first-order stationary condition Eq. (131). Next, we compute the other term $\nabla^{2}_{\eta_{\theta}\eta_{x}}\widehat{\mathcal{L}}$ by differentiating Eq. (130) with respect to $\eta_{\theta}$ to yield

where the latter matrix is $\nabla^{2}\log Z_{x}(\eta_{x}^{*}(\eta_{\theta},\phi))$ padded by a row of zeros.

Plugging these expressions back into Eq. (128) and cancelling, we arrive at

and so we have an expression for the gradient of the SVAE objective as

When we compute the natural gradient, the Fisher information matrix factors on the left of each term cancel, yielding the result in the proposition. ∎

The proof of Proposition D.3 uses the necessary condition for unconstrained local optimality to simplify the expression in Eq. (132). This simplification does not necessarily hold if $\eta_{x}$ is constrained; for example, if the factor $q(x)$ has additional factorization structure, then there are additional (linear) coordinate subspace constraints on $\eta_{x}$. Note also that when $q(x)$ is a Gaussian family with fixed covariance (that is, with sufficient statistics $t_{x}(x)=x$) the same simplification always applies because third and higher-order cumulants are zero for such families and hence $\nabla^{3}\log Z_{x}(\eta_{x})=0$.
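The fixed-covariance claim can be verified with a short worked calculation. The notation below is introduced only for this check: $\Sigma$ denotes the fixed covariance, and the quadratic term in $x$ is treated as part of the base measure so that the only sufficient statistic is $t_{x}(x)=x$.

```latex
\begin{align*}
q(x) &\propto \exp\left\{ \langle \eta_x, x \rangle - \tfrac{1}{2} x^\top \Sigma^{-1} x \right\}, \\
\log Z_x(\eta_x) &= \tfrac{1}{2}\, \eta_x^\top \Sigma\, \eta_x + \tfrac{1}{2} \log \det (2\pi\Sigma), \\
\nabla \log Z_x(\eta_x) &= \Sigma\, \eta_x, \qquad
\nabla^2 \log Z_x(\eta_x) = \Sigma, \qquad
\nabla^3 \log Z_x(\eta_x) = 0.
\end{align*}
```

Because the log partition function is quadratic in $\eta_{x}$, its Hessian is constant and every higher derivative vanishes, so the simplification holds whether or not $\eta_{x}$ is constrained.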

More generally, when the local latent variables have additional factorization structure, as in the Gaussian mixture model (GMM) and switching linear dynamical system (SLDS) examples, the natural gradient with respect to $\eta_{\theta}$ can be estimated efficiently by writing Eq. (127) as

where we can recover the second term in Eq. (127) by using the chain rule. We can estimate this second term directly using the reparameterization trick. Note that to compute the natural gradient estimate in this case, we need to apply ${(\nabla^{2}\log Z_{\theta}(\eta_{\theta}))}^{-1}$ to this term because the convenient cancellation from Proposition D.3 does not apply. When $\eta_{\theta}$ is of small dimension compared to $\eta_{\gamma}$, $\phi$, and even $\eta_{x}$, this additional computational cost is not large.

and so the dependence of the expression in Eq. (138) on $\phi$ is through $\eta^{*}_{x}(\eta_{\theta},\phi)$. Only the first term in Eq. (138) needs to be estimated with the reparameterization trick.
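For readers unfamiliar with the reparameterization trick mentioned here, the following generic sketch shows the idea on a scalar Gaussian. It is not the SVAE estimator itself: the test function $f(x)=x^{2}$, the parameter names, and the sample size are all chosen for illustration. A sample is written as $x=\mu+\sigma\epsilon$ with $\epsilon\sim\mathcal{N}(0,1)$, so gradients with respect to $\mu$ and $\sigma$ pass through the sample by the chain rule.

```python
import random


def reparam_gradient_estimate(mu, sigma, grad_f, num_samples, seed=0):
    """Monte Carlo estimate of d/dmu and d/dsigma of E[f(x)], x ~ N(mu, sigma^2)."""
    rng = random.Random(seed)
    total_mu = 0.0
    total_sigma = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)
        x = mu + sigma * eps    # reparameterized sample
        g = grad_f(x)           # df/dx evaluated at the sample
        total_mu += g           # chain rule: dx/dmu = 1
        total_sigma += g * eps  # chain rule: dx/dsigma = eps
    return total_mu / num_samples, total_sigma / num_samples


# For f(x) = x^2, E[f] = mu^2 + sigma^2, so the exact gradients are 2*mu and 2*sigma.
d_mu, d_sigma = reparam_gradient_estimate(
    mu=1.5, sigma=0.5, grad_f=lambda x: 2.0 * x, num_samples=20000
)
```

With these settings the estimates should land near the exact values 3.0 and 1.0, up to Monte Carlo noise.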

We summarize this procedure in the following proposition.

with respect to $\eta_{\gamma}$ and $\phi$, respectively.

D.4 Partially optimizing $\widehat{\mathcal{L}}$ using conjugacy structure

In Section D.1 we defined the SVAE objective in terms of a function $\eta_{x}^{*}(\eta_{\theta},\phi)$, which was itself implicitly defined in terms of first-order stationary conditions for an auxiliary objective $\widehat{\mathcal{L}}(\eta_{\theta},\eta_{x},\phi)$. Here we show how $\widehat{\mathcal{L}}$ admits efficient local partial optimization in the same way as the conditionally conjugate model of Section C.4.

In this section we consider additional structure in the local latent variables. Specifically, as in Section C.4, we introduce to the notation another set of local latent variables $z$ in addition to the local latent variables $x$. However, unlike Section C.4, we still consider general likelihood families $p(y\,|\,x,\gamma)$.

Let $p(z,x\,|\,\theta)$ be an exponential family and $p(\theta)$ be its corresponding natural exponential family conjugate prior, writing

where we have used $t_{\theta}(\theta)=\left(\eta^{0}_{zx}(\theta),-\log Z_{zx}(\eta^{0}_{zx}(\theta))\right)$ in Eq. (87). Additionally, let $t_{zx}(z,x)$ be a multilinear polynomial in the statistics $t_{z}(z)$ and $t_{x}(x)$, and let $p(z\,|\,\theta)$ and $p(x\,|\,z,\theta)$ be a conjugate pair of exponential families, writing

Let $p(y\,|\,x,\gamma)$ be a general family of densities (not necessarily an exponential family) and let $p(\gamma)$ be an exponential family prior on its parameters of the form

The corresponding variational factors are

As in Section D.1, we construct the surrogate objective $\widehat{\mathcal{L}}$ to allow us to exploit exponential family and conjugacy structure. In particular, we construct $\widehat{\mathcal{L}}$ to resemble the mean field objective, namely

but in $\widehat{\mathcal{L}}$ we replace the $\log p(y\,|\,x,\gamma)$ likelihood term, which may be a general family of densities without much structure, with a more tractable approximation,

where $\psi(x;y,\phi)$ is a function on $x$ that resembles a conjugate likelihood for $p(x\,|\,z,\theta)$,

We then define $\eta^{*}_{z}(\eta_{\theta},\phi)$ and $\eta^{*}_{x}(\eta_{\theta},\phi)$ to be local partial optimizers of $\widehat{\mathcal{L}}$ given fixed values of the other parameters $\eta_{\theta}$ and $\phi$, and in particular they satisfy the first-order necessary optimality conditions

The structure of the surrogate objective $\widehat{\mathcal{L}}$ is chosen so that it resembles the mean field variational inference objective for the conditionally conjugate model of Section C.4, and as a result we can use the same block coordinate ascent algorithm to efficiently find partial optimizers $\eta_{z}^{*}(\eta_{\theta},\phi)$ and $\eta_{x}^{*}(\eta_{\theta},\phi)$.

with the other arguments fixed, are given by

and by alternating the expressions in Eq. (155) as updates we can compute $\eta_{z}^{*}(\eta_{\theta},\phi)$ and $\eta_{x}^{*}(\eta_{\theta},\phi)$ as local partial optimizers of $\widehat{\mathcal{L}}$.

These updates follow immediately from Lemma C.2. Note in particular that the stationary conditions $\nabla_{\eta_{z}}\widehat{\mathcal{L}}=0$ and $\nabla_{\eta_{x}}\widehat{\mathcal{L}}=0$ yield each expression in Eq. (155), respectively. ∎

The other properties developed in Propositions D.2, D.3, and D.4 also hold true for this model because it is a special case in which we have separated out the local variables, denoted $x$ in earlier sections, into two groups, denoted $z$ and $x$ here, to match the exponential family structure in $p(z\,|\,\theta)$ and $p(x\,|\,z,\theta)$, and performed unconstrained optimization in each of the variational parameters. However, the expression for the natural gradient is slightly simpler for this model than the corresponding version of Proposition D.3.

Appendix E Experiment details and expanded figures

For the synthetic 1D dot video data, we trained an LDS SVAE on 80 random image sequences each of length 50, using one sequence per update, and show the model’s future predictions given a prefix of a longer sequence. We used MLP image and recognition models each with one hidden layer of 50 units and a latent state space of dimension 8.