Functional Variational Bayesian Neural Networks

Shengyang Sun, Guodong Zhang, Jiaxin Shi, Roger Grosse

1 Introduction

Bayesian neural networks (BNNs) (Hinton & Van Camp, 1993; Neal, 1995) have the potential to combine the scalability, flexibility, and predictive performance of neural networks with principled Bayesian uncertainty modelling. However, the practical effectiveness of BNNs is limited by our ability to specify meaningful prior distributions and by the intractability of posterior inference. Choosing a meaningful prior distribution over network weights is difficult because the weights have a complicated relationship to the function computed by the network. Stochastic variational inference is appealing because the update rules resemble ordinary backprop (Graves, 2011; Blundell et al., 2015), but fitting accurate posterior distributions is difficult due to strong and complicated posterior dependencies (Louizos & Welling, 2016; Zhang et al., 2018; Shi et al., 2018a).

In a classic result, Neal (1995) showed that, under certain assumptions, as the width of a shallow BNN is increased, the limiting distribution is a Gaussian process (GP). Lee et al. (2018) recently extended this result to deep BNNs. Deep Gaussian processes (DGPs) (Cutajar et al., 2017; Salimbeni & Deisenroth, 2017) are closely connected to BNNs because of their similar deep structure. However, the relationship of finite BNNs to GPs is unclear, and practical variational BNN approximations fail to match the predictions of the corresponding GP. Furthermore, because previous analyses related specific BNN architectures to specific GP kernels, it is not clear how to design BNN architectures for a given kernel. Given the rich variety of structural assumptions that GP kernels can represent (Rasmussen & Williams, 2006; Lloyd et al., 2014; Sun et al., 2018), there remains a significant gap in expressive power between BNNs and GPs (not to mention stochastic processes more broadly).

In this paper, we perform variational inference directly on the distribution of functions. Specifically, we introduce functional variational BNNs (fBNNs), where a BNN is trained to produce a distribution of functions with small KL divergence to the true posterior over functions. We prove that the KL divergence between stochastic processes can be expressed as the supremum of marginal KL divergences at finite sets of points. Based on this, we present the functional ELBO (fELBO) training objective. We then introduce a GAN-like minimax formulation and a sampling-based approximation for functional variational inference. To approximate the marginal KL divergence gradients, we adopt the recently proposed spectral Stein gradient estimator (SSGE) (Shi et al., 2018b).

Our fBNNs make it possible to specify stochastic process priors which encode richly structured dependencies between function values. This includes stochastic processes with explicit densities, such as GPs, which can model various structures like smoothness and periodicity (Lloyd et al., 2014; Sun et al., 2018). We can also use stochastic processes with implicit densities, such as distributions over piecewise linear or piecewise constant functions. Furthermore, in contrast with GPs, fBNNs efficiently yield explicit posterior samples of the function. This enables fBNNs to be used in settings that require explicit minimization of sampled functions, such as Thompson sampling (Thompson, 1933; Russo & Van Roy, 2016) or predictive entropy search (Hernández-Lobato et al., 2014; Wang & Jegelka, 2017).

One desideratum of Bayesian models is that they behave gracefully as their capacity is increased (Rasmussen & Ghahramani, 2001). Unfortunately, ordinary BNNs don't meet this basic requirement: unless the asymptotic regime is chosen very carefully (e.g., Neal (1995)), BNN priors may have undesirable behaviors as more units or layers are added. Furthermore, larger BNNs entail more difficult posterior inference and a larger description length for the posterior, causing degeneracy for large networks, as shown in Figure 1. In contrast, the prior of fBNNs is defined directly over the space of functions, so the BNN can be made arbitrarily large without changing the functional variational inference problem. Hence, the predictions behave well as the capacity increases.

Empirically, we demonstrate that fBNNs generate sensible extrapolations for both explicit periodic priors and implicit piecewise priors. We show that fBNNs outperform competing approaches on both small-scale and large-scale regression datasets. fBNNs' reliable uncertainty estimates enable state-of-the-art performance on the contextual bandits benchmark of Riquelme et al. (2018).

2 Background

Given a dataset $\mathcal{D}=\{(\mathbf{x}_i,y_i)\}_{i=1}^{n}$, a Bayesian neural network (BNN) is defined in terms of a prior $p(\mathbf{w})$ on the weights, as well as the likelihood $p(\mathcal{D}\mid\mathbf{w})$. Variational Bayesian methods (Hinton & Van Camp, 1993; Graves, 2011; Blundell et al., 2015) attempt to fit an approximate posterior $q(\mathbf{w})$ to maximize the evidence lower bound (ELBO):

$$\mathcal{L}(q) = \mathbb{E}_{q}[\log p(\mathcal{D}\mid\mathbf{w})] - \mathrm{KL}[q(\mathbf{w})\,\|\,p(\mathbf{w})]. \tag{1}$$

The most commonly used variational BNN training method is Bayes By Backprop (BBB) (Blundell et al., 2015), which uses a fully factorized Gaussian approximation to the posterior, i.e. $q(\mathbf{w})=\mathcal{N}(\mathbf{w};\boldsymbol{\mu},\operatorname{diag}(\boldsymbol{\sigma}^2))$. Using the reparameterization trick (Kingma & Welling, 2013), the gradients of the ELBO with respect to $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}$ can be computed by backpropagation and then used for stochastic gradient updates.
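As a minimal sketch of the reparameterized ELBO estimate used by BBB, the NumPy code below assumes a Bayesian linear model $y=\mathbf{x}^\top\mathbf{w}+\epsilon$ in place of a neural network, a $\mathcal{N}(0,I)$ weight prior, and known observation noise; the function names are hypothetical. In an actual BBB implementation, an autodiff framework would backpropagate through `mu + sigma * eps` to obtain the gradients.

```python
import numpy as np

def kl_diag_gaussian(mu, sigma, prior_std=1.0):
    """KL[N(mu, diag(sigma^2)) || N(0, prior_std^2 I)], summed over all weights."""
    ratio = (sigma / prior_std) ** 2
    return 0.5 * np.sum(ratio + (mu / prior_std) ** 2 - 1.0 - np.log(ratio))

def elbo_estimate(mu, sigma, X, y, noise_std, n_samples, rng):
    """Reparameterized Monte Carlo ELBO for a Bayesian linear model y = X w + noise."""
    # w = mu + sigma * eps is a deterministic, differentiable function of (mu, sigma),
    # which is what lets backprop compute ELBO gradients in BBB.
    eps = rng.standard_normal((n_samples, mu.size))
    w = mu + sigma * eps                                # (S, D) weight samples
    preds = w @ X.T                                     # (S, N) predictions
    log_lik = (-0.5 * np.log(2 * np.pi * noise_std ** 2)
               - 0.5 * ((y - preds) / noise_std) ** 2)  # (S, N) Gaussian log-likelihoods
    expected_log_lik = log_lik.sum(axis=1).mean()       # Monte Carlo E_q[log p(D | w)]
    return expected_log_lik - kl_diag_gaussian(mu, sigma)
```

For this linear-Gaussian toy model the expected log-likelihood is also available in closed form, which makes it easy to check that the Monte Carlo estimate is unbiased.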

Most commonly, the prior $p(\mathbf{w})$ is chosen for computational convenience; for instance, independent Gaussian or Gaussian mixture distributions. Other priors, including log-uniform priors (Kingma et al., 2015; Louizos et al., 2017) and horseshoe priors (Ghosh et al., 2018; Louizos et al., 2017), were proposed for specific purposes such as model compression and model selection. But the relationships of weight-space priors to the functions computed by networks are difficult to characterize.

2.2 Stochastic Processes

A stochastic process (Lamperti, 2012) $F$ is typically defined as a collection of random variables on a probability space $(\Omega,\mathcal{F},P)$. The random variables, indexed by some set $\mathcal{X}$, all take values in the same mathematical space $\mathcal{Y}$. In other words, given a probability space $(\Omega,\mathcal{F},P)$, a stochastic process can be written simply as $\{F(\mathbf{x}):\mathbf{x}\in\mathcal{X}\}$. For any point $\omega\in\Omega$, $F(\cdot,\omega)$ is a sample function mapping the index space $\mathcal{X}$ to the space $\mathcal{Y}$, which we denote as $f$ for notational simplicity.

For any finite index set $\mathbf{x}_{1:n}=\{\mathbf{x}_1,\ldots,\mathbf{x}_n\}$, we can define the finite-dimensional marginal joint distribution over the function values $\{F(\mathbf{x}_1),\ldots,F(\mathbf{x}_n)\}$. For example, the finite-dimensional marginals of a Gaussian process are multivariate Gaussians.

The Kolmogorov Extension Theorem (Øksendal, 2003) shows that a stochastic process can be characterized by its marginals over all finite index sets. Specifically, for a collection of joint distributions $\rho_{\mathbf{x}_{1:n}}$, we can define a stochastic process $F$ such that for all $\mathbf{x}_{1:n}$, $\rho_{\mathbf{x}_{1:n}}$ is the marginal joint distribution of $F$ at $\mathbf{x}_{1:n}$, as long as $\rho$ satisfies the following two conditions:

Exchangeability. For any permutation $\pi$ of $\{1,\ldots,n\}$, $\rho_{\pi(\mathbf{x}_{1:n})}(\pi(y_{1:n}))=\rho_{\mathbf{x}_{1:n}}(y_{1:n})$.

Consistency. For any $1\leq m\leq n$, $\rho_{\mathbf{x}_{1:m}}(y_{1:m})=\int\rho_{\mathbf{x}_{1:n}}(y_{1:n})\,dy_{m+1:n}$.
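For a GP, both conditions can be checked numerically. The sketch below (assuming an RBF kernel and a hypothetical helper name) verifies that evaluating the kernel at permuted inputs permutes the covariance, and, by Monte Carlo, that discarding coordinates of a joint Gaussian sample leaves exactly the GP marginal at the remaining points.

```python
import numpy as np

def rbf_gram(x1, x2, lengthscale=1.0):
    """RBF kernel matrix between 1-D input arrays x1 and x2."""
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / lengthscale ** 2)

rng = np.random.default_rng(0)
x = rng.uniform(-3.0, 3.0, size=5)
K = rbf_gram(x, x)

# Exchangeability: the kernel at permuted inputs equals the permuted covariance,
# so the joint density is unchanged when inputs and outputs are permuted together.
perm = rng.permutation(5)
exchangeable = np.allclose(rbf_gram(x[perm], x[perm]), K[np.ix_(perm, perm)])

# Consistency: draw from the 5-dimensional joint N(0, K), discard the last two
# coordinates, and compare with the GP covariance at the first m points.
m = 3
samples = rng.multivariate_normal(np.zeros(5), K + 1e-9 * np.eye(5), size=200_000)
emp_cov = np.cov(samples[:, :m], rowvar=False)
consistent = np.allclose(emp_cov, rbf_gram(x[:m], x[:m]), atol=0.02)
```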

2.3 Spectral Stein Gradient Estimator (SSGE)

When applying Bayesian methods to modern probabilistic models, especially those with neural networks as components (e.g., BNNs and deep generative models), it is often the case that we have to deal with intractable densities. Examples include the marginal distribution of a non-conjugate model (e.g., the output distribution of a BNN), and neural samplers such as GANs (Goodfellow et al., 2014). A shared property of these distributions is that they are defined through a tractable sampling process, despite the intractable density. Such distributions are called implicit distributions (Huszár, 2017).

The Spectral Stein Gradient Estimator (SSGE) (Shi et al., 2018b) is a recently proposed method for estimating the log-density derivative function of an implicit distribution, requiring only samples from the distribution. Specifically, given a continuously differentiable density $q(\mathbf{x})$ and a positive definite kernel $k(\mathbf{x},\mathbf{x}')$ in the Stein class (Liu et al., 2016) of $q$, they show

$$\nabla_{x_i}\log q(\mathbf{x}) = -\sum_{j=1}^{\infty}\mathbb{E}_{q}\big[\nabla_{x_i}\psi_j(\mathbf{x})\big]\,\psi_j(\mathbf{x}), \tag{2}$$

where $\{\psi_j\}_{j\geq1}$ is a sequence of eigenfunctions of $k$ given by Mercer's theorem: $k(\mathbf{x},\mathbf{x}')=\sum_j\mu_j\psi_j(\mathbf{x})\psi_j(\mathbf{x}')$. The Nyström method (Baker, 1997; Williams & Seeger, 2001) is used to approximate the eigenfunctions $\psi_j(\mathbf{x})$ and their derivatives. The final estimator is given by truncating the sum in Equation 2 and replacing the expectation by Monte Carlo estimates.
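The sketch below is a minimal NumPy version of this estimator under stated assumptions: an RBF kernel, a median-distance bandwidth, and a fixed number $J$ of eigenfunctions (all choices made here for illustration, not necessarily those of Shi et al., 2018b). It computes Nyström eigenfunctions from the kernel matrix on the samples, estimates $\beta_j=-\mathbb{E}_q[\nabla\psi_j]$ by Monte Carlo, and returns $\sum_{j\le J}\beta_j\psi_j(\mathbf{x})$.

```python
import numpy as np

def rbf_and_grad(x, y, lengthscale):
    """RBF kernel k(x, y) and its gradient w.r.t. x. Shapes: x (N, d), y (M, d)."""
    diff = x[:, None, :] - y[None, :, :]                               # (N, M, d)
    k = np.exp(-0.5 * np.sum(diff ** 2, axis=-1) / lengthscale ** 2)  # (N, M)
    grad_x = -diff / lengthscale ** 2 * k[..., None]                   # (N, M, d)
    return k, grad_x

def ssge(samples, query, n_eigen=6, lengthscale=None):
    """Estimate grad_x log q(x) at `query` using only `samples` drawn from q."""
    M = samples.shape[0]
    if lengthscale is None:
        # Median heuristic for the bandwidth (an assumption made for this sketch).
        dists = np.sqrt(np.sum((samples[:, None] - samples[None]) ** 2, axis=-1))
        lengthscale = np.median(dists)
    K, grad_K = rbf_and_grad(samples, samples, lengthscale)
    eigval, eigvec = np.linalg.eigh(K)              # eigenvalues in ascending order
    eigval = eigval[::-1][:n_eigen]                 # keep the J largest
    eigvec = eigvec[:, ::-1][:, :n_eigen]           # (M, J), unit-norm columns
    # Nystrom: psi_j(x) ~= sqrt(M) / lambda_j * sum_n k(x, x_n) u_nj;
    # its gradient replaces k(x, x_n) by grad_x k(x, x_n).
    grad_psi = np.sqrt(M) / eigval * np.einsum("mnd,nj->mdj", grad_K, eigvec)  # (M, d, J)
    beta = -grad_psi.mean(axis=0)                   # beta_j = -E_q[grad psi_j], (d, J)
    K_query, _ = rbf_and_grad(query, samples, lengthscale)
    psi_query = np.sqrt(M) / eigval * (K_query @ eigvec)  # (N_query, J)
    return psi_query @ beta.T                       # (N_query, d)

rng = np.random.default_rng(0)
samples = rng.standard_normal((1000, 1))            # q = N(0, 1), true score is -x
query = np.array([[-1.0], [0.0], [1.0]])
est = ssge(samples, query)
```

For these standard normal samples, the estimate should roughly follow the true score $-x$ near the origin; accuracy degrades in the tails, where few samples fall.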

3 Functional Variational Bayesian Neural Networks

We introduce function space variational inference analogously to weight space variational inference (see Section 2.1), except that the distributions are over functions rather than weights. We assume a stochastic process prior $p$ over functions $f:\mathcal{X}\to\mathcal{Y}$. This could be a GP, but we also allow stochastic processes without closed-form marginal densities, such as distributions over piecewise linear functions. For the variational posterior $q_\phi\in\mathcal{Q}$, we consider a neural network architecture with stochastic weights and/or stochastic inputs. Specifically, we sample a function from $q$ by sampling a random noise vector $\boldsymbol{\xi}$ and defining $f(\mathbf{x})=g_\phi(\mathbf{x},\boldsymbol{\xi})$ for some function $g_\phi$. For example, standard weight space BNNs with factorial Gaussian posteriors can be viewed this way using the reparameterization trick (Kingma & Welling, 2013; Blundell et al., 2015). (In this case, $\phi$ corresponds to the means and variances of all the weights.) Note that because a single vector $\boldsymbol{\xi}$ is shared among all input locations, it corresponds to randomness in the function, rather than observation noise; hence, the sampling of $\boldsymbol{\xi}$ corresponds to epistemic, rather than aleatoric, uncertainty (Depeweg et al., 2017).

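Functional variational inference maximizes the functional ELBO (fELBO), akin to the weight space ELBO in Equation 1, except that the distributions are over functions rather than weights:

$$\mathcal{L}(q) := \mathbb{E}_{q}[\log p(\mathcal{D}\mid f)] - \mathrm{KL}[q\,\|\,p] = \sum_{(\mathbf{x}_i,y_i)\in\mathcal{D}}\mathbb{E}_{q}[\log p(y_i\mid f(\mathbf{x}_i))] - \mathrm{KL}[q\,\|\,p],$$

where $\mathrm{KL}[q\,\|\,p]$ is the KL divergence between two stochastic processes.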

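Theorem 1. For two stochastic processes $P$ and $Q$,

$$\mathrm{KL}[P\,\|\,Q] = \sup_{n\in\mathbb{N},\ \mathbf{X}\in\mathcal{X}^n}\mathrm{KL}[P_{\mathbf{X}}\,\|\,Q_{\mathbf{X}}],$$

where $P_{\mathbf{X}}$ and $Q_{\mathbf{X}}$ denote the marginal distributions of $P$ and $Q$ at the measurement set $\mathbf{X}$.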

Roughly speaking, this result follows because the $\sigma$-algebra constructed with the Kolmogorov Extension Theorem (Section 2.2) is generated by cylinder sets which depend only on finite sets of points. A full proof is given in Appendix A.
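One consequence of this characterization is that the marginal KL can only grow as points are added to the measurement set, because marginalization cannot increase a KL divergence. The sketch below illustrates this for two zero-mean GPs that differ only in their RBF lengthscale; the lengthscales, jitter and helper names are illustrative assumptions.

```python
import numpy as np

def rbf_gram(x, lengthscale):
    """RBF kernel matrix on 1-D inputs."""
    return np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / lengthscale ** 2)

def gauss_kl_zero_mean(cov_p, cov_q):
    """KL[N(0, cov_p) || N(0, cov_q)]."""
    n = cov_p.shape[0]
    trace_term = np.trace(np.linalg.solve(cov_q, cov_p))
    _, logdet_p = np.linalg.slogdet(cov_p)
    _, logdet_q = np.linalg.slogdet(cov_q)
    return 0.5 * (trace_term - n + logdet_q - logdet_p)

rng = np.random.default_rng(0)
x = rng.uniform(0.0, 5.0, size=12)
jitter = 1e-4 * np.eye(x.size)
K_p = rbf_gram(x, lengthscale=0.5) + jitter   # marginal covariance of process P
K_q = rbf_gram(x, lengthscale=1.0) + jitter   # marginal covariance of process Q

# Marginal KL on nested measurement sets: x[:1] subset of x[:2] ... subset of x[:12].
kls = [gauss_kl_zero_mean(K_p[:n, :n], K_q[:n, :n]) for n in range(1, x.size + 1)]
```

With a single point both marginals are $\mathcal{N}(0, 1+10^{-4})$, so the first value is zero; the sequence is then nondecreasing, consistent with the functional KL being a supremum over measurement sets.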

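Using this characterization of the functional KL divergence, we rewrite the fELBO:

$$\mathcal{L}(q) = \inf_{n\in\mathbb{N},\ \mathbf{X}\in\mathcal{X}^n}\Big(\sum_{(\mathbf{x}_i,y_i)\in\mathcal{D}}\mathbb{E}_{q}[\log p(y_i\mid f(\mathbf{x}_i))] - \mathrm{KL}[q(\mathbf{f}^{\mathbf{X}})\,\|\,p(\mathbf{f}^{\mathbf{X}})]\Big) =: \inf_{n\in\mathbb{N},\ \mathbf{X}\in\mathcal{X}^n}\mathcal{L}_{\mathbf{X}}(q),$$

where $\mathbf{f}^{\mathbf{X}}$ denotes the vector of function values at $\mathbf{X}$.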

We also denote $\mathcal{L}_n(q):=\inf_{\mathbf{X}\in\mathcal{X}^n}\mathcal{L}_{\mathbf{X}}(q)$ for the restriction to sets of $n$ points. This casts maximizing the fELBO as a two-player zero-sum game analogous to a generative adversarial network (GAN) (Goodfellow et al., 2014): one player chooses the stochastic network, and the adversary chooses the measurement set. Note that the infimum may not be attainable, because the size of the measurement sets is unbounded. In fact, the function space KL divergence may be infinite, for instance if the prior assigns measure zero to the set of functions representable by a neural network (Arjovsky & Bottou, 2017). Observe that GANs face the same issue: because a generator network is typically limited to a submanifold of the input domain, an ideal discriminator could discriminate real and fake images perfectly. However, by limiting the capacity of the discriminator, one obtains a useful training objective. By analogy, we obtain a well-defined and practically useful training objective by restricting the measurement sets to a fixed finite size. This is discussed further in the next section.

3.2 Choosing the Measurement Set

As discussed above, we approximate the fELBO using finite measurement sets to have a well-defined and practical optimization objective. We now discuss how to choose the measurement sets.

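The minimax formulation of the fELBO naturally suggests a two-player zero-sum game, analogous to GANs, whereby one player chooses the stochastic network representing the posterior, and the adversary chooses the measurement set:

$$\max_{q\in\mathcal{Q}}\mathcal{L}_n(q) = \max_{q\in\mathcal{Q}}\ \min_{\mathbf{X}\in\mathcal{X}^n}\mathcal{L}_{\mathbf{X}}(q). \tag{6}$$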

We adopt concurrent optimization akin to GANs (Goodfellow et al., 2014). In the inner loop, we minimize $\mathcal{L}_{\mathbf{X}}(q)$ with respect to $\mathbf{X}$; in the outer loop, we maximize $\mathcal{L}_{\mathbf{X}}(q)$ with respect to $q$.

Unfortunately, this approach did not perform well in terms of generalization. Because the log-likelihood term does not depend on $\mathbf{X}$, minimizing $\mathcal{L}_{\mathbf{X}}(q)$ over $\mathbf{X}$ amounts to maximizing the KL term, and the measurement set which maximizes the KL term is likely to be close to the training data, since these are the points where one has the most information about the function. But the KL term is the only part of the fELBO encouraging the network to match the prior structure. Hence, if the measurement set is close to the training data, then nothing will encourage the network to exploit the structured prior for extrapolation.

Instead, we adopt a sampling-based approach. In order to use a structured prior for extrapolation, the network needs to match the prior structure both near the training data and in the regions where it must make predictions. Therefore, we sample measurement sets which include both (a) random training inputs, and (b) random points from the domain where one is interested in making predictions. We replace the minimization in Equation 6 with a sampling distribution $c$, and then maximize the expected $\mathcal{L}_{\mathbf{X}}(q)$:

$$\max_{q\in\mathcal{Q}}\ \mathbb{E}_{\mathcal{D}_s}\mathbb{E}_{\mathbf{X}^M\sim c}\,\mathcal{L}_{\mathbf{X}^M,\mathbf{X}^{D_s}}(q), \tag{7}$$

where $\mathbf{X}^{D_s}$ are the inputs of a random training mini-batch $\mathcal{D}_s$, and $\mathbf{X}^M$ are $M$ points independently drawn from $c$.

With the restriction to finite measurement sets, $\mathcal{L}_n(q)$ (or $\mathcal{L}_{\mathbf{X}}(q)$ for a particular $\mathbf{X}$) is only an upper bound on the true fELBO $\mathcal{L}(q)$, since the infimum is taken over a smaller family of sets. Unfortunately, this means the approximation is not a lower bound on the log marginal likelihood (log-ML) $\log p(\mathcal{D})$. Interestingly, if the measurement set is chosen to include all of the training inputs, then $\mathcal{L}_{\mathbf{X}}(q)$ is in fact a log-ML lower bound:

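Theorem 2. If the measurement set $\mathbf{X}$ contains all the training inputs $\mathbf{X}^D$, then

$$\mathcal{L}_{\mathbf{X}}(q) = \log p(\mathcal{D}) - \mathrm{KL}[q(\mathbf{f}^{\mathbf{X}})\,\|\,p(\mathbf{f}^{\mathbf{X}}\mid\mathcal{D})] \leq \log p(\mathcal{D}).$$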
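The identity can be checked numerically in a conjugate setting. The sketch below assumes a zero-mean GP prior with an RBF kernel, Gaussian observation noise, and an arbitrary Gaussian $q(\mathbf{f}^{\mathbf{X}})$ (a stand-in for the network's marginal, which for an fBNN has no closed form). It computes $\mathcal{L}_{\mathbf{X}}(q)$ directly and, separately, $\log p(\mathcal{D})$ minus the KL to the exact GP posterior.

```python
import numpy as np

def gauss_kl(m0, S0, m1, S1):
    """KL[N(m0, S0) || N(m1, S1)]."""
    diff = m1 - m0
    _, ld0 = np.linalg.slogdet(S0)
    _, ld1 = np.linalg.slogdet(S1)
    return 0.5 * (np.trace(np.linalg.solve(S1, S0)) + diff @ np.linalg.solve(S1, diff)
                  - m0.size + ld1 - ld0)

def gauss_logpdf(y, mean, cov):
    """log N(y; mean, cov)."""
    diff = y - mean
    _, ld = np.linalg.slogdet(cov)
    return -0.5 * (diff @ np.linalg.solve(cov, diff) + ld + y.size * np.log(2 * np.pi))

rng = np.random.default_rng(1)
x_train = np.array([-1.0, 0.0, 1.5])
x_meas = np.array([-3.0, 2.5])                    # extra measurement points X^M
X = np.concatenate([x_train, x_meas])             # X contains all training inputs
K = np.exp(-0.5 * (X[:, None] - X[None, :]) ** 2) + 1e-6 * np.eye(X.size)
noise_var = 0.1
n_train = x_train.size
y = np.sin(x_train) + np.sqrt(noise_var) * rng.standard_normal(n_train)

# An arbitrary Gaussian q(f^X) (hypothetical variational marginal).
m_q = rng.standard_normal(X.size)
A = rng.standard_normal((X.size, X.size))
S_q = 0.1 * A @ A.T + 0.05 * np.eye(X.size)

# Left-hand side: L_X(q) = E_q[log p(y | f^D)] - KL[q(f^X) || p(f^X)].
m_d, s_d = m_q[:n_train], np.diag(S_q)[:n_train]
exp_loglik = np.sum(-0.5 * np.log(2 * np.pi * noise_var)
                    - ((y - m_d) ** 2 + s_d) / (2 * noise_var))
lhs = exp_loglik - gauss_kl(m_q, S_q, np.zeros(X.size), K)

# Right-hand side: log p(D) - KL[q(f^X) || p(f^X | D)] with the exact GP posterior.
H = np.eye(X.size)[:n_train]                      # selects f^D from f^X
post_cov = np.linalg.inv(np.linalg.inv(K) + H.T @ H / noise_var)
post_mean = post_cov @ H.T @ y / noise_var
log_ml = gauss_logpdf(y, np.zeros(n_train), K[:n_train, :n_train] + noise_var * np.eye(n_train))
rhs = log_ml - gauss_kl(m_q, S_q, post_mean, post_cov)
```

Both sides agree up to floating-point error, and since the KL term is nonnegative, `lhs` never exceeds `log_ml`.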

To better understand the relationship between adversarial and sampling-based inference, we consider the idealized scenario where the measurement points in both approaches include all training locations, i.e., $\mathbf{X}=\{\mathbf{X}^D,\mathbf{X}^M\}$. Let $\mathbf{f}^M,\mathbf{f}^D$ be the function values at $\mathbf{X}^M,\mathbf{X}^D$, respectively. By Theorem 2,

$$\mathcal{L}_{\mathbf{X}^M,\mathbf{X}^D}(q) = \log p(\mathcal{D}) - \mathrm{KL}[q(\mathbf{f}^M,\mathbf{f}^D)\,\|\,p(\mathbf{f}^M,\mathbf{f}^D\mid\mathcal{D})].$$

So maximizing $\mathcal{L}_{\mathbf{X}^M,\mathbf{X}^D}(q)$ is equivalent to minimizing the KL divergence from the true posterior on points $\mathbf{X}^M,\mathbf{X}^D$. Based on this, we have the following consistency theorem that helps justify the use of adversarial and sampling-based objectives with finite measurement points.

The proof is given in Section B.2. While it is usually impractical for the measurement set to contain all the training inputs, it is still reassuring that a proper lower bound can be obtained with a finite measurement set.

3.3 KL Divergence Gradients

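While the likelihood term of the fELBO is tractable, the KL divergence term remains intractable because we don't have an explicit formula for the variational posterior density $q_\phi(\mathbf{f}^{\mathbf{X}})$. (Even if $q_\phi$ is chosen to have a tractable density in weight space (Louizos & Welling, 2017), the marginal distribution over $\mathbf{f}^{\mathbf{X}}$ is likely intractable.) To derive an approximation, we first observe that

$$\nabla_\phi\mathrm{KL}[q_\phi(\mathbf{f}^{\mathbf{X}})\,\|\,p(\mathbf{f}^{\mathbf{X}})] = \mathbb{E}_{q}\big[\nabla_\phi\log q_\phi(\mathbf{f}^{\mathbf{X}})\big] + \mathbb{E}_{\boldsymbol{\xi}}\Big[\nabla_\phi\mathbf{f}^{\mathbf{X}}\big(\nabla_{\mathbf{f}}\log q(\mathbf{f}^{\mathbf{X}}) - \nabla_{\mathbf{f}}\log p(\mathbf{f}^{\mathbf{X}})\big)\Big]. \tag{11}$$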

The first term (expected score function) in Equation 11 is zero, so we discard it, although it may be useful as a control variate (Roeder et al., 2017). The Jacobian $\nabla_\phi\mathbf{f}^{\mathbf{X}}$ can be exactly multiplied by a vector using backpropagation. Therefore, it remains to estimate the log-density derivatives $\nabla_{\mathbf{f}}\log q(\mathbf{f}^{\mathbf{X}})$ and $\nabla_{\mathbf{f}}\log p(\mathbf{f}^{\mathbf{X}})$.
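A toy check of the pathwise part of Equation 11, assuming a one-dimensional Gaussian $q_\phi=\mathcal{N}(\mu,\sigma^2)$ with $f=\mu+\sigma\xi$ and a $\mathcal{N}(0,1)$ prior, so that both scores are known exactly and the KL gradient has a closed form. In fBNNs these scores would instead come from the SSGE (or from a tractable prior), and the Jacobian-vector products from backpropagation; here the Jacobians $\partial f/\partial\mu=1$ and $\partial f/\partial\sigma=\xi$ are written out by hand.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.0, 0.5               # variational parameters phi = (mu, sigma)
xi = rng.standard_normal(200_000)
f = mu + sigma * xi                # reparameterized samples: df/dmu = 1, df/dsigma = xi

score_q = -(f - mu) / sigma ** 2   # grad_f log q(f), exact here (SSGE would estimate it)
score_p = -f                       # grad_f log p(f) for the prior p = N(0, 1)
diff = score_q - score_p

# Monte Carlo estimate of E_xi[ df/dphi * (grad_f log q - grad_f log p) ].
grad_mu = np.mean(1.0 * diff)
grad_sigma = np.mean(xi * diff)

# Closed form: KL[N(mu, sigma^2) || N(0, 1)] = 0.5 * (sigma^2 + mu^2 - 1) - log(sigma).
exact_mu, exact_sigma = mu, sigma - 1.0 / sigma
```

The estimates match the analytic gradients $\mu$ and $\sigma-1/\sigma$ without ever differentiating $\log q_\phi$ with respect to $\phi$ directly, which is what makes dropping the expected-score term possible.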

The entropy derivative $\nabla_{\mathbf{f}}\log q(\mathbf{f}^{\mathbf{X}})$ is generally intractable. For priors with tractable marginal densities, such as GPs (Rasmussen & Williams, 2006) and Student-t processes (Shah et al., 2014), $\nabla_{\mathbf{f}}\log p(\mathbf{f}^{\mathbf{X}})$ is tractable (Section D.1 introduces an additional fix to deal with the stability of the GP kernel matrix). However, we are also interested in implicit stochastic process priors, for which $\nabla_{\mathbf{f}}\log p(\mathbf{f}^{\mathbf{X}})$ is also intractable. Because the SSGE (see Section 2.3) can estimate score functions for both in-distribution and out-of-distribution samples, we use it to estimate the entropy derivative in all our experiments, and the prior derivative whenever the prior is implicit. (We compute $\nabla_{\mathbf{f}}\log p(\mathbf{f}^{\mathbf{X}})$ exactly whenever it is tractable.)

3.4 The Algorithm

Now we present the whole algorithm for fBNNs in Algorithm 1. In each iteration, our measurement points include a mini-batch $\mathcal{D}_s$ from the training data and random points $\mathbf{X}^M$ from a distribution $c$. We forward $\mathbf{X}^{D_s}$ and $\mathbf{X}^M$ together through the network $g(\cdot;\phi)$ which defines the variational posterior $q_\phi$. Then we try to maximize the following objective corresponding to the fELBO:

4 Related work

Variational inference was first applied to neural networks by Peterson (1987) and Hinton & Van Camp (1993). More recently, Graves (2011) proposed a practical method for variational inference with fully factorized Gaussian posteriors which used a simple (but biased) gradient estimator. Improving on that work, Blundell et al. (2015) proposed an unbiased gradient estimator using the reparameterization trick of Kingma & Welling (2013). There has also been much work (Louizos & Welling, 2016; Sun et al., 2017; Zhang et al., 2018; Bae et al., 2018) on modelling the correlations between weights using more complex Gaussian variational posteriors. Some non-Gaussian variational posteriors have been proposed, such as multiplicative normalizing flows (Louizos & Welling, 2017) and implicit distributions (Shi et al., 2018a). Neural networks with dropout were also interpreted as BNNs (Gal & Ghahramani, 2016; Gal et al., 2017). The local reparameterization trick (Kingma et al., 2015) and Flipout (Wen et al., 2018) decorrelate the gradients within a mini-batch to reduce gradient variance during training. However, all these methods place priors over the network parameters, often spherical Gaussian priors chosen for convenience. As discussed in Section 2.1, other weight-space priors, such as log-uniform priors (Kingma et al., 2015; Louizos et al., 2017) and horseshoe priors (Ghosh et al., 2018; Louizos et al., 2017), target specific purposes such as model compression and model selection, but their relationship to the functions computed by the network remains difficult to characterize.

There have been other recent attempts to train BNNs in the spirit of functional priors. Flam-Shepherd et al. (2017) trained a BNN prior to mimic a GP prior, but they still required variational inference in weight space. Noise Contrastive Priors (Hafner et al., 2018) are somewhat similar in spirit to our work in that they use a random noise prior in the function space. The prior is incorporated by adding a regularization term to the weight-space ELBO, and is not rich enough to encourage extrapolation and pattern discovery. Neural Processes (NPs) (Garnelo et al., 2018) try to model any conditional distribution given arbitrary data points, with the prior specified implicitly through prior samples. However, in high-dimensional spaces, conditional distributions become increasingly complicated to model. Variational Implicit Processes (VIP) (Ma et al., 2018) are, in a sense, the reverse of fBNNs: they specify BNN priors and use GPs to approximate the posterior. But the use of BNN priors means they cannot exploit richly structured GP priors or other stochastic processes.

Gaussian processes are difficult to apply exactly to large datasets since the computational requirements scale as $O(N^3)$ in time and $O(N^2)$ in memory, where $N$ is the number of training cases. Multiple approaches have been proposed to reduce the computational complexity. However, sparse GP methods (Lázaro-Gredilla et al., 2010; Snelson & Ghahramani, 2006; Titsias, 2009; Hensman et al., 2013; 2015; Krauth et al., 2016) still struggle with very large datasets, while random feature methods (Rahimi & Recht, 2008; Le et al., 2013) and KISS-GP (Wilson & Nickisch, 2015; Izmailov et al., 2017) must be hand-tailored to a given kernel.

5 Experiments

Our experiments had two main aims: (1) to test the ability of fBNNs to extrapolate using various structural motifs, including both implicit and explicit priors, and (2) to test if they perform competitively with other BNNs on standard benchmark tasks such as regression and contextual bandits.

In all of our experiments, the variational posterior is represented as a stochastic neural network with independent Gaussian distributions over the weights, i.e. $q(\mathbf{w})=\mathcal{N}(\mathbf{w};\boldsymbol{\mu},\operatorname{diag}(\boldsymbol{\sigma}^2))$. One could also use stochastic activations, but we did not find that this gave any improvement. We always used the ReLU activation function unless otherwise specified. Measurement points were sampled uniformly from a rectangle containing the training inputs. More precisely, each coordinate was sampled from the interval $[x_{\min}-\tfrac{d}{2},\,x_{\max}+\tfrac{d}{2}]$, where $x_{\min}$ and $x_{\max}$ are the minimum and maximum input values along that coordinate, and $d=x_{\max}-x_{\min}$. For experiments where we used GP priors, we first fit the GP hyperparameters to maximize the marginal likelihood on subsets of the training examples, and then fixed those hyperparameters to obtain the prior for the fBNNs.
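The measurement-set construction described above can be sketched as follows: a random mini-batch of training inputs concatenated with points drawn uniformly from the enlarged bounding box. The function name and signature are hypothetical; the returned indices let a caller look up the targets for the mini-batch part.

```python
import numpy as np

def sample_measurement_set(x_train, batch_size, n_measure, rng):
    """Return [mini-batch inputs; n_measure points from c] and the mini-batch indices.

    c is uniform on the box whose coordinates range over [x_min - d/2, x_max + d/2],
    with d = x_max - x_min computed per coordinate from the training inputs.
    """
    idx = rng.choice(x_train.shape[0], size=batch_size, replace=False)
    x_batch = x_train[idx]
    x_min, x_max = x_train.min(axis=0), x_train.max(axis=0)
    d = x_max - x_min
    x_meas = rng.uniform(x_min - d / 2, x_max + d / 2, size=(n_measure, x_train.shape[1]))
    return np.concatenate([x_batch, x_meas], axis=0), idx

rng = np.random.default_rng(0)
x_train = rng.uniform([0.0, 10.0], [1.0, 20.0], size=(100, 2))
X, idx = sample_measurement_set(x_train, batch_size=32, n_measure=8, rng=rng)
```

Both parts are then forwarded through the network together, so the KL term sees the function both near the data and in the surrounding region.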

Making sensible predictions outside the range of the observed data requires exploiting the underlying structure. In this section, we consider some illustrative examples where fBNNs are able to use structured priors to make sensible extrapolations. Section C.2 also shows the extrapolation of fBNNs for a time-series problem.

Figure 2 (caption excerpt): … $\text{PER}+\text{RBF}$ (which does). In each case, the fBNN makes similar predictions to the exact GP. In contrast, the standard BBB (BBB-1) cannot even fit the training data, while BBB with the KL term scaled down by 0.001 (BBB-0.001) manages to fit the training data, but fails to provide sensible extrapolations.

Gaussian processes can model periodic structure using a periodic kernel plus an RBF kernel:

$$k(x,x') = \sigma_1^2\exp\Big\{-\frac{2\sin^2(\pi|x-x'|/p)}{l_1^2}\Big\} + \sigma_2^2\exp\Big\{-\frac{(x-x')^2}{2l_2^2}\Big\}, \tag{13}$$

where $p$ is the period. In this experiment, we consider 20 inputs randomly sampled from $[-2,-0.5]\cup[0.5,2]$, and targets $y$ which are noisy observations of a periodic function: $y=2\sin(4x)+\epsilon$ with $\epsilon\sim\mathcal{N}(0,0.04)$. We compared our method with Bayes By Backprop (BBB) (Blundell et al., 2015) (with a spherical Gaussian prior on $\mathbf{w}$) and Gaussian processes. For fBNNs and GPs, we considered both a single RBF kernel (which does not capture the periodic structure) and $\text{PER}+\text{RBF}$ as in Equation 13 (which does).

Details: we used a BNN with five hidden layers, each with 500 units. The inputs and targets were normalized to have zero mean and unit variance. For all methods, the observation noise variance was set to the true value. We used the trained GP as the prior of our fBNNs. In each iteration, measurement points included all training examples, plus 40 points randomly sampled from $$. We used a training budget of 80,000 iterations, and annealed the weighting factor of the KL term linearly from 0 to 1 for the first 50,000 iterations.
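A sketch of this toy setup, assuming the standard periodic and RBF kernel forms (Rasmussen & Williams, 2006), unit variances and lengthscales as placeholders rather than fitted values, and reading $\mathcal{N}(0,0.04)$ as a noise variance of 0.04. The period is set to $\pi/2$, the true period of $\sin(4x)$.

```python
import numpy as np

def per_rbf_kernel(x1, x2, period, var_per=1.0, ls_per=1.0, var_rbf=1.0, ls_rbf=1.0):
    """PER + RBF kernel on 1-D inputs; hyperparameter defaults are placeholders."""
    r = np.abs(x1[:, None] - x2[None, :])
    k_per = var_per * np.exp(-2.0 * np.sin(np.pi * r / period) ** 2 / ls_per ** 2)
    k_rbf = var_rbf * np.exp(-0.5 * r ** 2 / ls_rbf ** 2)
    return k_per + k_rbf

rng = np.random.default_rng(0)
# 20 inputs from [-2, -0.5] U [0.5, 2]: draw |x| uniformly from [0.5, 2] and a random
# sign; this is uniform on the union because the two intervals have equal length.
x = rng.uniform(0.5, 2.0, size=20) * rng.choice([-1.0, 1.0], size=20)
y = 2.0 * np.sin(4.0 * x) + rng.normal(0.0, np.sqrt(0.04), size=20)
period = np.pi / 2.0                     # sin(4x) repeats every 2*pi/4
K = per_rbf_kernel(x, x, period)         # 20 x 20 prior covariance of f at the inputs
```

The periodic component alone is exactly invariant to shifts by the period, which is what lets a GP (and an fBNN trained against it) extrapolate the oscillation into the gap and beyond the data.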

As shown in Fig. 2, BBB failed to fit the training data, let alone recover the periodic pattern (since its prior does not encode any periodic structure). For this example, we view the GP with PER+RBF\text{PER}+\text{RBF} as the gold standard, since its kernel structure is designed to model periodic functions. Reassuringly, the fBNNs made very similar predictions to the GPs with the corresponding kernels, though they predicted slightly smaller uncertainty. We emphasize that the extrapolation results from the functional prior, rather than the network architecture, which does not encode periodicity, and which is not well suited to model smooth functions due to the ReLU activation function.

5.1.2 Implicit Priors

Because the KL term in the fELBO is estimated using the SSGE, an implicit variational inference algorithm (as discussed in Section 2.3), the functional prior need not have a tractable marginal density. In this section, we examine approximate posterior samples and marginals for two implicit priors: a distribution over piecewise constant functions, and a distribution over piecewise linear functions. Prior samples are shown in Figure 3; see Section D.2 for the precise definitions. In each run of the experiment, we first sampled a random function from the prior, and then sampled 20 points from $[0,0.2]$ and another 20 points from $[0.8,1]$, giving a training set of 40 data points. To make the task more difficult for the fBNN, we used the tanh activation function, which is not well suited for piecewise constant or piecewise linear functions.

Details: the standard deviation of the observation noise was chosen to be 0.02. In each iteration, we took all training examples, together with 40 points randomly sampled from $$]. We used a fully connected network with 2 hidden layers of 100 units, and tanh activations. The network was trained for 20,000 iterations.
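To make the notion of an implicit prior concrete, the sketch below draws from a hypothetical piecewise constant process: a Poisson number of change points placed uniformly on $[0,1]$, with an independent $\mathcal{N}(0,1)$ level on each segment. This construction is an assumption for illustration and is not necessarily the definition given in Section D.2. Drawing functions is straightforward, but no closed-form marginal density is used anywhere, which is the setting a sample-based score estimator such as the SSGE is meant for.

```python
import numpy as np

def sample_piecewise_constant(x, rng, rate=3.0):
    """Draw one function from a hypothetical piecewise constant prior on [0, 1].

    Assumed construction: Poisson(rate) change points placed uniformly on [0, 1],
    and an independent N(0, 1) level on each resulting segment.
    """
    n_jumps = rng.poisson(rate)
    jumps = np.sort(rng.uniform(0.0, 1.0, size=n_jumps))
    levels = rng.standard_normal(n_jumps + 1)
    segment = np.searchsorted(jumps, x)     # index of the segment containing each input
    return levels[segment]

rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 200)
draws = np.stack([sample_piecewise_constant(x, rng) for _ in range(5)])  # (5, 200)
```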

Posterior predictive samples and marginals are shown for three different runs in Figure 3. We observe that fBNNs made predictions with roughly piecewise constant or piecewise linear structure, although their posterior samples did not seem to capture the full diversity of possible explanations of the data. Even though the tanh activation function encourages smoothness, the network learned to generate functions with sharp transitions.

5.2 Predictive Performance

Following previous work (Hernández-Lobato & Adams, 2015), we then experimented with standard regression benchmark datasets from the UCI collection (Asuncion & Newman, 2007). In particular, we only used the datasets with fewer than 2,000 data points so that we could fit GP hyperparameters by maximizing the marginal likelihood exactly. Each dataset was randomly split into training and test sets, comprising 90% and 10% of the data respectively. This splitting process was repeated 10 times to reduce variability.

Details: For all datasets, we used networks with one hidden layer of 50 hidden units. We first fit GP hyperparameters using the marginal likelihood with a budget of 10,000 iterations. We then trained the observation variance, keeping it lower-bounded by the GP observation variance. fBNNs were trained for 2,000 epochs, and in each iteration, the measurement points included 20 training examples, plus 5 randomly sampled points.

We compared our fBNNs with Bayes By Backprop (BBB) (Blundell et al., 2015) and Noisy K-FAC (NNG) (Zhang et al., 2018). In accordance with Zhang et al. (2018), we report root mean square error (RMSE) and test log-likelihood. The results are shown in Table 1. On most datasets, our fBNNs outperformed both BBB and NNG, sometimes by a significant margin.

5.2.2 Large Scale Datasets

Observe that fBNNs are naturally scalable to large datasets because they access the data only through the expected log-likelihood term, which can be estimated stochastically. In this section, we verify this experimentally. We compared fBNNs and BBB on large-scale UCI datasets, including Naval, Protein Structures, Video Transcoding (Memory, Time) and GPU kernel performance. We randomly split the datasets into 80% training, 10% validation, and 10% test. We used the validation set to select the hyperparameters and to perform early stopping.

Both methods were trained for 80,000 iterations. We tuned the learning rate within $[0.001,0.01]$, and tuned between not annealing the learning rate and annealing it by a factor of 0.1 at 40,000 iterations. We evaluated on the validation set in each epoch, and selected the epoch for testing based on the validation performance. To control overfitting, we placed a $\text{Gamma}(6,6)$ prior on the observation precision, following Hernández-Lobato & Adams (2015), and performed inference over it. We used 1 hidden layer with 100 hidden units for all datasets. For the prior of fBNNs, we used a GP with Neural Kernel Network (NKN) kernels as used in Sun et al. (2018). We note that GP hyperparameters were fit using mini-batches of size 1000 for 10,000 iterations. In each iteration, measurement sets consisted of 500 training samples and 5 or 50 points from the sampling distribution $c$, tuned by validation performance. We ran each experiment 5 times, and report the mean and standard deviation in Table 2. More large-scale regression results with bigger networks can be found in Section C.4 and Section C.5.

5.3 Contextual Bandits

One of the most important applications of uncertainty modelling is to guide exploration in settings such as bandits, Bayesian optimization (BO), and reinforcement learning. In this section, we evaluate fBNNs on a recently introduced contextual bandits benchmark (Riquelme et al., 2018). In contextual bandits problems, the agent tries to select the action with highest reward given some input context. Because the agent learns about the model gradually, it should balance between exploration and exploitation to maximize the cumulative reward. Thompson sampling (Thompson, 1933) is one promising approach which repeatedly samples from the posterior distribution over parameters, choosing the optimal action according to the posterior sample.
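A minimal sketch of Thompson sampling on a linear contextual bandit, assuming a separate Bayesian linear regression per arm with a Gaussian prior and known noise. This conjugate model is a simple stand-in for an fBNN posterior, and the setup and names are hypothetical. At each step, one weight vector is drawn per arm from its posterior, the arm that is best under the draw is played, and only that arm's posterior is updated.

```python
import numpy as np

def run_bandit(policy, true_w, noise_std, n_steps, rng, prior_var=1.0):
    """Play a linear contextual bandit and return the cumulative regret.

    policy: "thompson" (per-arm Bayesian linear regression) or "random".
    """
    n_arms, dim = true_w.shape
    # Natural parameters of each arm's Gaussian posterior over its weights.
    precision = np.stack([np.eye(dim) / prior_var for _ in range(n_arms)])
    b = np.zeros((n_arms, dim))
    regret = 0.0
    for _ in range(n_steps):
        ctx = rng.standard_normal(dim)
        expected = true_w @ ctx                       # mean reward of every arm
        if policy == "thompson":
            scores = np.empty(n_arms)
            for a in range(n_arms):
                cov = np.linalg.inv(precision[a])
                chol = np.linalg.cholesky(cov)
                w_draw = cov @ b[a] + chol @ rng.standard_normal(dim)  # posterior sample
                scores[a] = w_draw @ ctx
            action = int(np.argmax(scores))           # greedy w.r.t. the sampled model
        else:
            action = int(rng.integers(n_arms))        # uniform random baseline
        reward = expected[action] + noise_std * rng.standard_normal()
        # Conjugate Gaussian update for the played arm only.
        precision[action] += np.outer(ctx, ctx) / noise_std ** 2
        b[action] += ctx * reward / noise_std ** 2
        regret += expected.max() - expected[action]
    return regret

true_w = np.random.default_rng(0).standard_normal((4, 5))
ts_regret = run_bandit("thompson", true_w, 0.5, 1000, np.random.default_rng(1))
random_regret = run_bandit("random", true_w, 0.5, 1000, np.random.default_rng(1))
```

With a well-specified model, the cumulative regret of Thompson sampling should be far below that of the uniform random policy; the quality of the posterior samples is what drives this, which is why reliable uncertainty matters here.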

We compared our fBNNs with the algorithms benchmarked by Riquelme et al. (2018). We ran the experiments for all algorithms and tasks using the default settings open-sourced by Riquelme et al. (2018). For fBNNs, we kept the same settings, including batch size (512), training epochs (100) and training frequency (50). For the prior, we used the multi-task GP of Riquelme et al. (2018). Measurement sets consisted of training batches, combined with 10 points sampled from data regions. We ran each experiment 10 times; the mean and standard deviation are reported in Table 3 (Section C.1 has the full results for all experiments). Similarly to Riquelme et al. (2018), we also report the mean rank and mean regret.

As shown in Table 3, fBNNs outperformed other methods by a wide margin. Additionally, fBNNs maintained consistent performance even with deeper and wider networks. By comparison, BBB suffered significant performance degradation when the hidden size was increased from 50 to 500. This is consistent with our hypothesis that functional variational inference can gracefully handle networks with high capacity.

5.4 Bayesian Optimization

Another domain where efficient exploration requires accurate uncertainty modeling is Bayesian optimization. Our experiments with Bayesian optimization are described in Appendix C.3. We compared BBB, RBF random features (Rahimi & Recht, 2008) and our fBNNs in the context of Max-value Entropy Search (MES) (Wang & Jegelka, 2017), which requires explicit function samples for Bayesian optimization. We performed BO over functions sampled from Gaussian processes corresponding to RBF, Matérn-1/2 and ArcCosine kernels, and found that our fBNNs achieved comparable or better performance than RBF random features.

Conclusions

In this paper we investigated variational inference between stochastic processes. We proved that the KL divergence between stochastic processes equals the supremum of KL divergence for marginal distributions over all finite measurement sets. Then we presented two practical functional variational inference approaches: adversarial and sampling-based. Adopting BNNs as the variational posterior yields our functional variational Bayesian neural networks. Empirically, we demonstrated that fBNNs extrapolate well over various structures, estimate reliable uncertainties, and scale to large datasets.

Acknowledgements

We thank Ricky Chen, Kevin Luk and Xuechen Li for their helpful comments on this project. SS was supported by a Connaught New Researcher Award and a Connaught Fellowship. GZ was supported by an MRIS Early Researcher Award. RG acknowledges funding from the CIFAR Canadian AI Chairs program.

References

Appendix A Functional KL Divergence

We begin with some basic terminology and classical results. See Gray (2011) and Folland (2013) for more details.

Given a probability measure space $(\Omega,\mathcal{F},P)$ and another probability measure $M$ on the same space, the KL divergence of $P$ with respect to $M$ is defined as

$$\mathrm{KL}[P\,\|\,M] = \sup_{\mathcal{Q}} \mathrm{KL}[P_{\mathcal{Q}}\,\|\,M_{\mathcal{Q}}] = \sup_{\mathcal{Q}} \sum_{i=1}^{n} P(Q_i)\log\frac{P(Q_i)}{M(Q_i)},$$

where the supremum is taken over all finite measurable partitions $\mathcal{Q}=\{Q_i\}_{i=1}^{n}$ of $\Omega$, and $P_{\mathcal{Q}}, M_{\mathcal{Q}}$ represent the discrete measures induced by $P$ and $M$ on the partition $\mathcal{Q}$, respectively.
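The partition-based definition can be illustrated for two univariate Gaussians, where the KL divergence also has a closed form. The sketch below (helper names are hypothetical) evaluates the KL between the discrete measures induced on successively refined partitions of the real line. Because each partition refines the previous one, the values are non-decreasing and approach the closed-form KL from below, as the supremum in the definition suggests.

```python
import math


def normal_cdf(x, mu, sigma):
    # math.erf handles +/-inf, so the unbounded tail cells work directly.
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def partition_kl(p, m, edges):
    """KL[P_Q || M_Q] for the partition (-inf, e_0), [e_0, e_1), ..., [e_last, inf).

    p and m are (mean, std) pairs of two univariate Gaussians P and M.
    """
    cuts = [-math.inf] + list(edges) + [math.inf]
    kl = 0.0
    for a, b in zip(cuts[:-1], cuts[1:]):
        p_cell = normal_cdf(b, *p) - normal_cdf(a, *p)
        m_cell = normal_cdf(b, *m) - normal_cdf(a, *m)
        if p_cell > 0.0:  # cells with zero P-mass contribute nothing
            kl += p_cell * math.log(p_cell / m_cell)
    return kl


def gaussian_kl_1d(p, m):
    """Closed-form KL[N(mu_p, s_p^2) || N(mu_m, s_m^2)]."""
    (mu_p, s_p), (mu_m, s_m) = p, m
    return math.log(s_m / s_p) + (s_p ** 2 + (mu_p - mu_m) ** 2) / (2 * s_m ** 2) - 0.5


P, M = (0.0, 1.0), (1.0, 2.0)
# Nested refinements of [-6, 6]: every partition refines the previous one.
kls = [partition_kl(P, M, [-6 + 12 * k / n for k in range(n + 1)]) for n in (2, 8, 32, 128)]
print(kls, gaussian_kl_1d(P, M))
```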

Given probability spaces $(X,\mathcal{F}_X,\mu)$ and $(Y,\mathcal{F}_Y,\nu)$, we say that the measure $\nu$ is a pushforward of $\mu$ if $\nu(A)=\mu(f^{-1}(A))$ for a measurable $f:X\to Y$ and any $A\in\mathcal{F}_Y$. This relationship is denoted by $\nu=\mu\circ f^{-1}$.

Let $T$ be an arbitrary index set, and $\{(\Omega_t,\mathcal{F}_t)\}_{t\in T}$ be some collection of measurable spaces. For each subset $J\subset I\subset T$, define $\Omega^J=\prod_{t\in J}\Omega_t$. We call $\pi_{I\to J}$ the canonical projection map from $I$ to $J$ if

$$\pi_{I\to J}(w) = w|_J \in \Omega^J,\quad \forall\, w\in\Omega^I,$$

where $w|_J$ is defined as follows: if $w=(w_i)_{i\in I}$, then $w|_J=(w_i)_{i\in J}$.

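Let $T$ be an arbitrary index set and $(\Omega,\mathcal{F})$ be a measurable space. Suppose

$$\Omega^T = \{f : f(t)\in\Omega,\ t\in T\}$$

is the set of $\Omega$-valued functions on $T$. A cylinder subset is a finitely restricted set defined as

$$C_{t_1,\dots,t_n}(B_1,\dots,B_n) = \{f\in\Omega^T : f(t_i)\in B_i,\ 1\le i\le n\},\quad t_i\in T,\ B_i\in\mathcal{F},$$

and we denote by $\mathcal{G}_{\Omega^T}$ the collection of all cylinder subsets of $\Omega^T$.

We call the $\sigma$-algebra $\mathcal{F}^T:=\sigma(\mathcal{G}_{\Omega^T})$ the cylindrical $\sigma$-algebra of $\Omega^T$, and $(\Omega^T,\mathcal{F}^T)$ the cylindrical measurable space.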

The Kolmogorov Extension Theorem is the foundational result used to construct many stochastic processes, such as Gaussian processes. A particularly relevant fact for our purposes is that this theorem defines a measure on a cylindrical measurable space, using only canonical projection measures on finite sets of points.

Let $T$ be an arbitrary index set and $(\Omega,\mathcal{F})$ a standard measurable space, whose cylindrical measurable space on $T$ is $(\Omega^T,\mathcal{F}^T)$. Suppose that for each finite subset $I\subset T$ we have a probability measure $\mu_I$ on $\Omega^I$, and these measures satisfy the following compatibility relationship: for each subset $J\subset I$,

$$\mu_J = \mu_I\circ\pi_{I\to J}^{-1}.$$

Then there exists a unique probability measure $\mu$ on $\Omega^T$ such that for all finite subsets $I\subset T$,

$$\mu_I = \mu\circ\pi_{T\to I}^{-1}.$$

In the context of Gaussian processes, $\mu$ is a Gaussian measure on a separable Banach space, and the $\mu_I$ are marginal Gaussian measures at finite sets of input positions (Mallasto & Feragen, 2017).
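For a Gaussian process the compatibility condition is easy to see: the finite marginal is $\mu_I = \mathcal{N}(0, K(X_I, X_I))$, and the canonical projection $\pi_{I\to J}$ simply keeps a subset of coordinates. The Monte Carlo sketch below assumes a zero-mean GP with an RBF kernel and uses hypothetical helper names. It samples from $\mu_I$, projects each sample onto $J$, and compares the empirical covariance of the projected samples with the covariance of $\mu_J$.

```python
import math
import random


def rbf(x, y, lengthscale=1.0):
    return math.exp(-0.5 * (x - y) ** 2 / lengthscale ** 2)


def cholesky(a):
    """Lower-triangular L with L L^T = a (a must be symmetric positive definite)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def gp_marginal_cov(points, jitter=1e-8):
    """Covariance of the zero-mean GP marginal mu_I at the finite index set `points`."""
    return [[rbf(x, y) + (jitter if i == j else 0.0) for j, y in enumerate(points)]
            for i, x in enumerate(points)]


def project(sample, points, subset):
    """Canonical projection pi_{I -> J}: keep the coordinates indexed by J."""
    return [sample[points.index(t)] for t in subset]


rng = random.Random(0)
I = [0.0, 0.5, 1.3, 2.0]
J = [0.5, 2.0]
L_I = cholesky(gp_marginal_cov(I))
n_samples = 20000
proj = []
for _ in range(n_samples):
    z = [rng.gauss(0.0, 1.0) for _ in I]
    f_I = [sum(L_I[i][k] * z[k] for k in range(i + 1)) for i in range(len(I))]  # f_I ~ mu_I
    proj.append(project(f_I, I, J))

# Empirical covariance (zero mean) of the projected samples vs. the covariance of mu_J.
emp = [[sum(s[a] * s[b] for s in proj) / n_samples for b in range(len(J))] for a in range(len(J))]
print(emp)
print(gp_marginal_cov(J))
```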

Suppose that $M$ and $P$ are measures on the sequence space corresponding to outcomes of a sequence of random variables $X_0,X_1,\cdots$ with alphabet $A$. Let $\mathcal{F}_n=\sigma(X_0,\cdots,X_{n-1})$; these $\sigma$-algebras asymptotically generate $\sigma(X_0,X_1,\cdots)$. Then

$$\mathrm{KL}[P\,\|\,M] = \lim_{n\to\infty}\mathrm{KL}[P_{\mathcal{F}_n}\,\|\,M_{\mathcal{F}_n}],$$

where $P_{\mathcal{F}_n}, M_{\mathcal{F}_n}$ denote the pushforward measures of $P$ and $M$ under the truncation map $f(X_0,X_1,\cdots)=(X_0,\cdots,X_{n-1})$, respectively.
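A simple instance of this result, chosen for illustration: let $M$ make the coordinates i.i.d. $\mathcal{N}(0,1)$ and let $P$ make them independent $\mathcal{N}(a_i,1)$ with $a_i = 2^{-i}$. Each coordinate then adds $a_i^2/2$ to the KL divergence, so $\mathrm{KL}[P_{\mathcal{F}_n}\|M_{\mathcal{F}_n}]$ increases with $n$ toward $\sum_i 4^{-i}/2 = 2/3$. The function names are hypothetical.

```python
def truncated_kl(n, a):
    """KL[P_{F_n} || M_{F_n}] for P = prod_i N(a(i), 1) and M = prod_i N(0, 1).

    The coordinates are independent, so the KL is a sum over the first n of them,
    and a unit-variance Gaussian with mean shift a(i) contributes a(i)**2 / 2.
    """
    return sum(0.5 * a(i) ** 2 for i in range(n))


def mean_shift(i):
    return 2.0 ** -i  # a_i = 2^{-i}


kls = [truncated_kl(n, mean_shift) for n in (1, 2, 5, 10, 30)]
print(kls)  # non-decreasing, converging to 2/3
```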

A.2 Functional KL divergence

Here we clarify what it means for a measurable set to depend only on the values at a countable set of points. We first introduce some definitions.

For $f\in\Omega^T$, $t_0\in T$, $v\in\Omega$, a replacing function $f^r_{t_0,v}$ is

$$f^r_{t_0,v}(t) = \begin{cases} v, & t = t_0,\\ f(t), & t\neq t_0.\end{cases}$$

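For $H\in\mathcal{F}^T$, we define the free indices $\tau^c(H)$:

$$\tau^c(H) = \{t\in T : f\in H \iff f^r_{t,v}\in H,\ \forall f\in\Omega^T,\ \forall v\in\Omega\},$$

and the restricted indices $\tau(H) := T\setminus\tau^c(H)$ as its complement.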

Restricted index sets satisfy the following properties:

1. For any $H\in\mathcal{F}^T$, $\tau(H)=\tau(H^c)$.

2. For any index set $I$ and measurable sets $\{H_i : H_i\in\mathcal{F}^T\}_{i\in I}$, $\tau\left(\bigcup_{i\in I}H_i\right)\subseteq\bigcup_{i\in I}\tau(H_i)$.
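Restricted and free indices can be computed by brute force when $T$ and $\Omega$ are finite. The toy sketch below assumes $T=\{0,1,2\}$ and $\Omega=\{0,1\}$ (in the paper $T$ is arbitrary and $H$ ranges over $\mathcal{F}^T$); it computes $\tau(H)$ for two cylinder sets and checks the two properties listed above, including a case where the inclusion in property 2 is strict. All names are hypothetical.

```python
from itertools import product

T = (0, 1, 2)                                # toy finite index set
OMEGA = (0, 1)                               # toy finite value space
ALL_F = set(product(OMEGA, repeat=len(T)))   # Omega^T: each function f stored as a tuple


def replace(f, t0, v):
    """Replacing function f^r_{t0,v}: f with its value at index t0 set to v."""
    return f[:t0] + (v,) + f[t0 + 1:]


def restricted_indices(H):
    """tau(H) = T minus the free indices tau^c(H)."""
    free = {t for t in T
            if all((f in H) == (replace(f, t, v) in H) for f in ALL_F for v in OMEGA)}
    return set(T) - free


H1 = {f for f in ALL_F if f[0] == 1}                # cylinder set constrained at index 0
H2 = {f for f in ALL_F if f[1] == 0 and f[2] == 1}  # cylinder set constrained at indices 1, 2

print(restricted_indices(H1), restricted_indices(H2))  # {0} {1, 2}
print(restricted_indices(ALL_F - H1))                  # {0}: same as H1 (property 1)
print(restricted_indices(H1 | (ALL_F - H1)))           # set(): the inclusion can be strict
```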

Having defined restricted indices, a key step in our proof is to show that, for any measurable set $H$ in a cylindrical measurable space $(\Omega^T,\mathcal{F}^T)$, its set of restricted indices $\tau(H)$ is countable.

Given a cylindrical measurable space $(\Omega^T,\mathcal{F}^T)$, for any $H\in\mathcal{F}^T$, $\tau(H)$ is countable.

Define $\mathcal{H}=\{H \mid H\in\mathcal{F}^T,\ \tau(H)\text{ is countable}\}\subseteq\mathcal{F}^T$. By the two properties of restricted indices above, $\mathcal{H}$ is a $\sigma$-algebra on $\Omega^T$.

On the other hand, $\mathcal{F}^T=\sigma(\mathcal{G}_{\Omega^T})$. Because any set in $\mathcal{G}_{\Omega^T}$ has finitely many restricted indices, $\mathcal{G}_{\Omega^T}\subseteq\mathcal{H}$. Therefore $\mathcal{H}$ is a $\sigma$-algebra containing $\mathcal{G}_{\Omega^T}$, and thus $\mathcal{H}\supseteq\sigma(\mathcal{G}_{\Omega^T})=\mathcal{F}^T$.

Overall, we conclude $\mathcal{H}=\mathcal{F}^T$: for any $H\in\mathcal{F}^T$, $\tau(H)$ is countable. ∎

For two stochastic processes $P,M$ on a cylindrical measurable space $(\Omega^T,\mathcal{F}^T)$, the KL divergence of $P$ with respect to $M$ satisfies

$$\mathrm{KL}[P\,\|\,M] = \sup_{T_d}\mathrm{KL}[P_{T_d}\,\|\,M_{T_d}],$$

where the supremum is over all finite index subsets $T_d\subseteq T$, and $P_{T_d},M_{T_d}$ represent the pushforward measures of $P,M$ under the canonical projection map $\pi_{T\to T_d}$, respectively.
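The supremum over finite index sets can be checked numerically in a simple case. The sketch below assumes two GPs, $P=\mathcal{GP}(\sin, k)$ and $M=\mathcal{GP}(0, k)$, that share an RBF kernel $k$ (plus a small diagonal jitter), and computes the marginal KL on nested finite sets $T_d$ with the closed-form Gaussian KL. Because marginalizing cannot increase the KL divergence, the values are non-decreasing as $T_d$ grows, consistent with $\mathrm{KL}[P\|M]$ being their supremum. All helper names are hypothetical, and matrices are plain nested lists so only the standard library is needed.

```python
import math


def rbf(x, y, lengthscale=1.0):
    return math.exp(-0.5 * (x - y) ** 2 / lengthscale ** 2)


def cholesky(a):
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def chol_solve(L, b):
    """Solve (L L^T) x = b by forward then backward substitution."""
    n = len(L)
    y = [0.0] * n
    for i in range(n):
        y[i] = (b[i] - sum(L[i][k] * y[k] for k in range(i))) / L[i][i]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (y[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


def gaussian_kl(m0, S0, m1, S1):
    """KL[N(m0, S0) || N(m1, S1)] for small dense matrices given as nested lists."""
    n = len(m0)
    L0, L1 = cholesky(S0), cholesky(S1)
    trace = sum(chol_solve(L1, [S0[i][j] for i in range(n)])[j] for j in range(n))
    d = [a - b for a, b in zip(m1, m0)]
    maha = sum(di * xi for di, xi in zip(d, chol_solve(L1, d)))
    logdet0 = 2.0 * sum(math.log(L0[i][i]) for i in range(n))
    logdet1 = 2.0 * sum(math.log(L1[i][i]) for i in range(n))
    return 0.5 * (trace + maha - n + logdet1 - logdet0)


def marginal(mean_fn, xs, jitter=1e-6):
    """Mean vector and covariance matrix of a GP with an RBF kernel at inputs xs."""
    K = [[rbf(x, y) + (jitter if i == j else 0.0) for j, y in enumerate(xs)]
         for i, x in enumerate(xs)]
    return [mean_fn(x) for x in xs], K


# Nested finite measurement sets T_d.
sets = [[0.0], [0.0, 1.0], [0.0, 1.0, 2.5], [0.0, 1.0, 2.5, 4.0]]
kls = [gaussian_kl(*marginal(math.sin, xs), *marginal(lambda x: 0.0, xs)) for xs in sets]
print(kls)
```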

Recall that stochastic processes are defined over a cylindrical $\sigma$-algebra $\mathcal{F}^T$. By the lemma above, for every set $H\in\mathcal{F}^T$, the restricted index set $\tau(H)$ is countable. Our proof proceeds in three steps:

1. Any finite measurable partition of $\Omega^T$ corresponds to a finite measurable partition over some $\Omega^{T_c}$, where $T_c$ is a countable index set.

2. Correspondence between partitions implies correspondence between KL divergences.

3. KL divergences over a countable index set can be represented as the supremum of KL divergences over finite index sets.

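Step 1. By definition,

$$\mathrm{KL}[P\,\|\,M] = \sup_{\mathcal{Q}_{\Omega^T}}\sum_{i=1}^{k} P(Q^{(i)}_{\Omega^T})\log\frac{P(Q^{(i)}_{\Omega^T})}{M(Q^{(i)}_{\Omega^T})},$$

where the supremum is over all finite measurable partitions of the function space $\Omega^T$, denoted by $\mathcal{Q}_{\Omega^T}$:

$$\mathcal{Q}_{\Omega^T}=\{Q^{(i)}_{\Omega^T}\}_{i=1}^{k},\quad \bigcup_{i=1}^{k}Q^{(i)}_{\Omega^T}=\Omega^T,\quad Q^{(i)}_{\Omega^T}\cap Q^{(j)}_{\Omega^T}=\emptyset\ (i\neq j).$$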

By the lemma above, each $\tau(Q^{(i)}_{\Omega^T})$ is countable, so the combined restricted index set $T_c:=\bigcup_{i=1}^{k}\tau(Q^{(i)}_{\Omega^T})$ is countable.

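Consider the canonical projection map $\pi_{T\to T_c}$, which induces a partition on $\Omega^{T_c}$, denoted by $\mathcal{Q}_{\Omega^{T_c}}$:

$$\mathcal{Q}_{\Omega^{T_c}}=\{Q^{(i)}_{\Omega^{T_c}}\}_{i=1}^{k},\quad Q^{(i)}_{\Omega^{T_c}}=\pi_{T\to T_c}\big(Q^{(i)}_{\Omega^T}\big).$$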

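The pushforward measures defined by this mapping are

$$P_{T_c}=P\circ\pi_{T\to T_c}^{-1},\qquad M_{T_c}=M\circ\pi_{T\to T_c}^{-1}.$$

Step 2. Since each cell of $\mathcal{Q}_{\Omega^T}$ is determined by the values at indices in $T_c$, we have $P(Q^{(i)}_{\Omega^T})=P_{T_c}(Q^{(i)}_{\Omega^{T_c}})$ and $M(Q^{(i)}_{\Omega^T})=M_{T_c}(Q^{(i)}_{\Omega^{T_c}})$. The KL divergence over the partition $\mathcal{Q}_{\Omega^T}$ therefore equals the KL divergence over $\mathcal{Q}_{\Omega^{T_c}}$, which is at most $\mathrm{KL}[P_{T_c}\,\|\,M_{T_c}]$; conversely, $\mathrm{KL}[P_{T_c}\,\|\,M_{T_c}]\le\mathrm{KL}[P\,\|\,M]$ because a pushforward cannot increase the KL divergence. Hence

$$\mathrm{KL}[P\,\|\,M]=\sup_{T_c}\mathrm{KL}[P_{T_c}\,\|\,M_{T_c}],$$

where $T_c=T_c(\mathcal{Q})$ ranges over the countable index sets induced by finite measurable partitions $\mathcal{Q}$ of $\Omega^T$.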

Step 3. Denote by $\mathcal{D}(T_c)$ the collection of all finite subsets of $T_c$. For any finite set $T_d\in\mathcal{D}(T_c)$, we denote by $P_{T_d}$ the pushforward measure of $P_{T_c}$ on $\Omega^{T_d}$. From the Kolmogorov Extension Theorem, we know that $P_{T_d}$ corresponds to the finite marginal of $P$ on $\Omega^{T_d}$. Because $T_c$ is countable, the sequence-space result above gives

$$\mathrm{KL}[P_{T_c}\,\|\,M_{T_c}]=\sup_{T_d\in\mathcal{D}(T_c)}\mathrm{KL}[P_{T_d}\,\|\,M_{T_d}].$$

We are left with one last question: is every finite index set $T_d$ contained in $\mathcal{D}(T_c)$ for some partition-induced $T_c$?

For any finite index set $T_d$, we build a finite measurable partition $\mathcal{Q}$. Let $\Omega=\Omega_0\cup\Omega_1$ with $\Omega_0\cap\Omega_1=\emptyset$, where $\Omega_0,\Omega_1$ are measurable and non-empty. Assume $|T_d|=K$, $T_d=\{T_d(k)\}_{k=1}^{K}$, and let $I=\{I^i \mid I^i=(I^i_1,I^i_2,\cdots,I^i_K)\}_{i=1}^{2^K}$ be the set of all $K$-length binary vectors. We define the partition

$$\mathcal{Q}=\{Q^{(i)}\}_{i=1}^{2^K},\quad Q^{(i)}=\{f\in\Omega^T : f(T_d(k))\in\Omega_{I^i_k},\ k=1,\dots,K\}.$$

With this construction, $\mathcal{Q}$ is a finite partition of $\Omega^T$, and $T_c(\mathcal{Q})=T_d$. Therefore $T_d$ in the theorem can range over all finite index sets, and we have proven the theorem.
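The construction can be made concrete in a toy finite case, assuming $T=\{0,1,2,3\}$, $\Omega=\{0,1,2\}$, $\Omega_0=\{0\}$ and $\Omega_1=\{1,2\}$ (all illustrative). The sketch enumerates the $2^K$ binary vectors, builds one cell per vector, and checks by brute force that the cells partition $\Omega^T$ and that their combined restricted indices equal $T_d$. Helper names are hypothetical.

```python
from itertools import product

T = (0, 1, 2, 3)                             # toy finite index set
OMEGA = (0, 1, 2)                            # toy finite value space
OMEGA_SPLIT = ({0}, {1, 2})                  # Omega_0, Omega_1: disjoint, non-empty, covering OMEGA
ALL_F = set(product(OMEGA, repeat=len(T)))   # Omega^T, each f stored as a tuple


def binary_partition(t_d):
    """Cells Q^(i) = {f : f(t_d[k]) in Omega_{I^i_k} for every k}, one per binary vector I^i."""
    return [{f for f in ALL_F if all(f[t] in OMEGA_SPLIT[bit] for t, bit in zip(t_d, bits))}
            for bits in product((0, 1), repeat=len(t_d))]


def restricted_indices(H):
    """tau(H): indices where replacing f's value can change membership in H."""
    free = {t for t in T
            if all((f in H) == (f[:t] + (v,) + f[t + 1:] in H) for f in ALL_F for v in OMEGA)}
    return set(T) - free


t_d = (1, 3)
cells = binary_partition(t_d)
print(len(cells), sorted(set().union(*(restricted_indices(c) for c in cells))))  # 4 [1, 3]
```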

A.3 KL Divergence between Conditional Stochastic Processes

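In this section, we give an example of computing the KL divergence between two conditional stochastic processes. Given two datasets $\mathcal{D}_1,\mathcal{D}_2$, write $p(f|\mathcal{D}_j)\propto p(f)\,p(\mathbf{y}^{\mathcal{D}_j}|\mathbf{f}^{\mathbf{X}^{\mathcal{D}_j}})$, so the density ratio between the two conditional processes depends on $f$ only through its values at the observed inputs $\mathbf{X}^{D}=\mathbf{X}^{\mathcal{D}_1}\cup\mathbf{X}^{\mathcal{D}_2}$. The KL divergence between the two conditional stochastic processes is then

$$\mathrm{KL}[p(f|\mathcal{D}_1)\,\|\,p(f|\mathcal{D}_2)]=\mathrm{KL}[p(\mathbf{f}^{\mathbf{X}^{D}}|\mathcal{D}_1)\,\|\,p(\mathbf{f}^{\mathbf{X}^{D}}|\mathcal{D}_2)].$$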

Therefore, the KL divergence between these two stochastic processes equals the marginal KL divergence on the observed locations. When $\mathcal{D}_2=\emptyset$, $p(f|\mathcal{D}_2)=p(f)$, which shows that the KL divergence between the posterior process and the prior process is the marginal KL divergence on the observed locations.
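A small numerical check of this statement, under assumed settings: a zero-mean GP prior with an RBF kernel, a single noisy observation at $x=0$, and $\mathcal{D}_2=\emptyset$, so the second process is the prior. The KL between the posterior and prior marginals is computed once on the observed location alone and once with extra unobserved points added. The two numbers should agree up to floating-point error, because the likelihood depends on $f$ only through its value at the observed input. Helper names are hypothetical.

```python
import math


def rbf(x, y, lengthscale=1.0):
    return math.exp(-0.5 * (x - y) ** 2 / lengthscale ** 2)


def cholesky(a):
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def chol_solve(L, b):
    """Solve (L L^T) x = b by forward then backward substitution."""
    n = len(L)
    y = [0.0] * n
    for i in range(n):
        y[i] = (b[i] - sum(L[i][k] * y[k] for k in range(i))) / L[i][i]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (y[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


def gaussian_kl(m0, S0, m1, S1):
    """KL[N(m0, S0) || N(m1, S1)] for small dense matrices given as nested lists."""
    n = len(m0)
    L0, L1 = cholesky(S0), cholesky(S1)
    trace = sum(chol_solve(L1, [S0[i][j] for i in range(n)])[j] for j in range(n))
    d = [a - b for a, b in zip(m1, m0)]
    maha = sum(di * xi for di, xi in zip(d, chol_solve(L1, d)))
    logdet0 = 2.0 * sum(math.log(L0[i][i]) for i in range(n))
    logdet1 = 2.0 * sum(math.log(L1[i][i]) for i in range(n))
    return 0.5 * (trace + maha - n + logdet1 - logdet0)


def prior_marginal(xs):
    """Zero-mean GP prior marginal p(f^X) at inputs xs."""
    return [0.0] * len(xs), [[rbf(a, c) for c in xs] for a in xs]


def posterior_marginal(xs, x_obs, y_obs, noise_var):
    """GP posterior marginal p(f^X | D) after one observation y_obs = f(x_obs) + noise."""
    denom = rbf(x_obs, x_obs) + noise_var
    k_o = [rbf(x, x_obs) for x in xs]
    mean = [k * y_obs / denom for k in k_o]
    cov = [[rbf(a, c) - k_o[i] * k_o[j] / denom for j, c in enumerate(xs)]
           for i, a in enumerate(xs)]
    return mean, cov


x_obs, y_obs, noise_var = 0.0, 1.0, 0.1
# KL on the observed location only ...
kl_obs = gaussian_kl(*posterior_marginal([x_obs], x_obs, y_obs, noise_var), *prior_marginal([x_obs]))
# ... and on the observed location plus extra, unobserved measurement points.
xs = [x_obs, 1.2, -0.7]
kl_more = gaussian_kl(*posterior_marginal(xs, x_obs, y_obs, noise_var), *prior_marginal(xs))
print(kl_obs, kl_more)
```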

This also justifies our use of $M$ measurement points in the adversarial functional VI and sampling-based functional VI of Section 3.

Appendix B Additional Proofs

Let $\mathbf{X}^M=\mathbf{X}\backslash\mathbf{X}^D$ be measurement points that are not in the training data.

B.2 Consistency for Gaussian Processes

This section provides the proof of the consistency result in 3.

By the assumption that both $q(f)$ and $p(f|\mathcal{D})$ are Gaussian processes:

$$q(f)=\mathcal{GP}(m_q,k_q),\qquad p(f|\mathcal{D})=\mathcal{GP}(m_p,k_p),$$

where $m$ and $k$ denote the mean and covariance functions, respectively.

In this theorem, we also assume the measurement points cover all training locations as in Equation 9, where we have (based on 2):

Here $m(\mathbf{X})=[m(\mathbf{x}_1),\dots,m(\mathbf{x}_M)]^\top$, and $\left[k(\mathbf{X}^M,\mathbf{X}^M)\right]_{ij}=k(\mathbf{x}_i,\mathbf{x}_j)$.

So Equation 37 holds for any $\mathbf{X}^M\in\mathcal{X}^M$. Given $M>1$, this means that for all $1\le i\le j\le M$ we have $m_p(\mathbf{x}_i)=m_q(\mathbf{x}_i)$ and $k_p(\mathbf{x}_i,\mathbf{x}_j)=k_q(\mathbf{x}_i,\mathbf{x}_j)$. Since the measurement points are arbitrary, this implies

$$m_p=m_q,\qquad k_p=k_q.$$

Because GPs are uniquely determined by their mean and covariance functions, we arrive at the conclusion. ∎

Appendix C Additional Experiments

Here we present the full table for the contextual bandits experiment.

C.2 Time-series Extrapolation

As Figure 4 shows, the predictions of fBNN closely match the exact GP predictions. Both give visually good extrapolations that capture the long-term trend, local variations, and periodic structure. In contrast, weight-space prior and inference (BBB) neither captures the periodic structure nor gives meaningful uncertainty estimates.

C.3 Bayesian Optimization

In this section, we use Bayesian optimization to explore the advantage of coherent posteriors. Specifically, we use Max-value Entropy Search (MES) (Wang & Jegelka, 2017), which tries to maximize the information gain about the minimum value $y^\star$:

$$\alpha_t(\mathbf{x})=\frac{1}{|Y^\star|}\sum_{y^\star\in Y^\star}\left[\frac{\gamma_{y^\star}(\mathbf{x})\,\phi(\gamma_{y^\star}(\mathbf{x}))}{2\Psi(\gamma_{y^\star}(\mathbf{x}))}-\log\Psi(\gamma_{y^\star}(\mathbf{x}))\right],$$

where $\phi$ and $\Psi$ are the probability density function and cumulative distribution function of a standard normal distribution, respectively, and $Y^\star$ is a set of sampled minima. Each $y^\star$ is the minimum of a random function drawn from the posterior, and $\gamma_{y^\star}(\mathbf{x})=\frac{\mu_t(\mathbf{x})-y^\star}{\sigma_t(\mathbf{x})}$.
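The acquisition can be evaluated directly once $\mu_t(\mathbf{x})$, $\sigma_t(\mathbf{x})$ and a set of sampled minima are available. The sketch below assumes the minimization form of MES used here: the average over sampled minima of $\gamma\phi(\gamma)/(2\Psi(\gamma))-\log\Psi(\gamma)$, with $\gamma=(\mu_t(\mathbf{x})-y^\star)/\sigma_t(\mathbf{x})$. The function names are hypothetical.

```python
import math


def std_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def mes_acquisition(mu, sigma, y_star_samples):
    """MES acquisition at one input x, minimization form.

    mu, sigma: posterior mean and standard deviation of f(x).
    y_star_samples: sampled minima y*, e.g. from minimizing posterior function samples.
    For very negative gamma the CDF underflows; a robust implementation would
    work with log-CDF values instead.
    """
    total = 0.0
    for y_star in y_star_samples:
        gamma = (mu - y_star) / sigma
        cdf = std_normal_cdf(gamma)
        total += gamma * std_normal_pdf(gamma) / (2.0 * cdf) - math.log(cdf)
    return total / len(y_star_samples)


# Inputs whose predicted value is close to the sampled minima are more informative.
print(mes_acquisition(0.2, 0.5, [0.0, -0.1]), mes_acquisition(3.0, 0.5, [0.0, -0.1]))
```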

With a probabilistic model, we can compute or estimate the mean $\mu_t(\mathbf{x})$ and the standard deviation $\sigma_t(\mathbf{x})$. However, computing the MES acquisition function also requires samples $y^\star$ of function minima, which is difficult. Typically, when we model the data with a GP, we can obtain the posterior on a specific set of points, but we do not have access to the extrema of the underlying function. In comparison, if the function posterior is represented in a parametric form, we can easily perform gradient descent on a sampled function to search for its minimum.
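The parametric route to sampling $y^\star$ can be sketched with random features, the simplest of the parametric models compared here. The block below assumes RBF random Fourier features on $[0,1]^3$ (Rahimi & Recht, 2008) and, for brevity, minimizes a draw from the feature-space prior rather than from a posterior fitted to data. It uses multi-start projected gradient descent with a simple backtracking rule. The function names, the step-size rule and the reduced sizes (100 features, 10 starts) are illustrative assumptions, not the experiment's settings.

```python
import math
import random


def sample_rff_function(dim, n_features, lengthscale, rng):
    """Draw f(x) = sqrt(2/D) * sum_j theta_j * cos(w_j . x + b_j) with theta_j ~ N(0, 1).

    With w_j ~ N(0, lengthscale^-2 I) and b_j ~ U[0, 2*pi], this approximates a
    draw from a GP prior with an RBF kernel. Returns f and its analytic gradient.
    """
    W = [[rng.gauss(0.0, 1.0 / lengthscale) for _ in range(dim)] for _ in range(n_features)]
    b = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(n_features)]
    theta = [rng.gauss(0.0, 1.0) for _ in range(n_features)]
    scale = math.sqrt(2.0 / n_features)

    def phase(w, bj, x):
        return sum(wi * xi for wi, xi in zip(w, x)) + bj

    def f(x):
        return scale * sum(t * math.cos(phase(w, bj, x)) for t, w, bj in zip(theta, W, b))

    def grad(x):
        g = [0.0] * dim
        for t, w, bj in zip(theta, W, b):
            c = -scale * t * math.sin(phase(w, bj, x))  # d/dx cos(w.x + b) = -sin(w.x + b) * w
            for d in range(dim):
                g[d] += c * w[d]
        return g

    return f, grad


def minimize_multistart(f, grad, dim, n_starts, rng, lr=0.05, steps=100):
    """Projected gradient descent on [0, 1]^dim from several random starts.

    A step is accepted only if it lowers f; otherwise the step size is halved.
    Returns the best point found and its value (an estimate of the minimum y*).
    """
    best_x, best_val = None, math.inf
    for _ in range(n_starts):
        x = [rng.random() for _ in range(dim)]
        val, step = f(x), lr
        for _ in range(steps):
            g = grad(x)
            cand = [min(1.0, max(0.0, xi - step * gi)) for xi, gi in zip(x, g)]
            cand_val = f(cand)
            if cand_val < val:
                x, val = cand, cand_val
            else:
                step *= 0.5
        if val < best_val:
            best_x, best_val = x, val
    return best_x, best_val


rng = random.Random(0)
f, grad = sample_rff_function(dim=3, n_features=100, lengthscale=0.2, rng=rng)
x_best, y_star = minimize_multistart(f, grad, dim=3, n_starts=10, rng=rng)
print(x_best, y_star)
```

Repeating this for several posterior function samples would give the set of $y^\star$ values consumed by the MES acquisition.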

We use 3-dimensional functions sampled from Gaussian process priors for Bayesian optimization. Concretely, we experiment with samples from RBF, order-1 ArcCosine and Matern12 kernels. We compare three parametric approaches: fBNN, BBB and random features (Rahimi & Recht, 2008). For fBNN, we use the true kernel as the functional prior. In contrast, ArcCosine and Matern12 kernels do not have simple explicit random feature expressions, so we use RBF random features for all three kernels. When looking for minima, we sample 10 values of $y^\star$. For each $y^\star$, we perform gradient descent on the sampled parametric function posterior from 30 different starting points. We use 500 random features. For fBNN, we use a 5×100 network. For BBB, we select the network from {1×100, 3×100}; because of the issue illustrated in Figure 1, larger networks do not help BBB. We use batch size 30 for both fBNN and BBB. The measurement points consist of 30 training points and 30 points sampled uniformly from the known input domain of the functions. For training, we rescale the inputs to a fixed interval and normalize the outputs to have zero mean and unit variance. We train fBNN and BBB for 20000 iterations and anneal the coefficient of the log-likelihood term linearly from 0 to 1 over the first 10000 iterations. The results over 10 runs are shown in Figure 5.

As Figure 5 shows, fBNN and random features outperform BBB by a large margin on all three functions. We also observe that fBNN performs slightly worse than random features on functions drawn with the RBF prior. Because the random feature method asymptotically recovers a GP with the RBF kernel, it sets a high standard for parametric approaches in this case. In contrast, fBNN outperforms random features on both ArcCosine and Matern12 functions, because of the large discrepancy between these kernels and RBF random features: since fBNN uses the true kernels, it models the function structures better. This experiment highlights a key advantage of fBNN: it can learn parametric function posteriors for various priors.

C.4 Varying depth

To compare with Variational Free Energy (VFE) (Titsias, 2009), we experimented with two medium-size datasets so that we could afford to run VFE with full batches. For VFE, we used 1000 inducing points initialized by k-means on the training points. For BBB and fBNNs, we used batch size 500 with a budget of 2000 epochs. As shown in Table 5, fBNNs performed slightly worse than VFE, but the gap became smaller as we used larger networks. By contrast, BBB failed completely with large networks (5 hidden layers with 500 hidden units each). Finally, we note that the gap between fBNNs and VFE diminishes if VFE uses fewer inducing points (e.g., 300 inducing points).

C.5 Large scale regression with deeper networks

In this section we experimented on large-scale regression datasets with deeper networks. For BBB and fBNNs, we used a network with 5 hidden layers of 100 units, and kept all other settings the same as in Section 5.2.2. We also compared with stochastic variational Gaussian processes (SVGP) (Hensman et al., 2013), which provide principled mini-batch training for sparse GP methods and thus enable GPs to scale to large datasets. For SVGP, we used 1000 inducing points initialized by k-means on the training points (we could not afford more inducing points because of the cubic computational cost). We used batch size 2000 and 60000 iterations to match the training time of fBNNs. As for the BNNs, we used a validation set to tune the learning rate from $\{0.01,0.001\}$. We also tuned between not annealing the learning rate and annealing it by a factor of 0.1 at 30000 iterations. We evaluated on the validation set after each epoch and selected the epoch for testing based on validation performance. The results averaged over 5 runs are shown in Table 6.

As shown in Table 6, SVGP performs better than BBB and fBNNs on the smallest dataset, naval. However, as the dataset size increases, SVGP performs worse than BBB and fBNNs by a large margin. This stems from the limited capacity of 1000 inducing points, which cannot act as sufficient statistics for large datasets. In contrast, BNNs, including BBB and fBNNs, can use larger networks freely without incurring intractable computational cost.

Appendix D Implementation Details

For Gaussian process priors, $p(\mathbf{f}^{\mathbf{X}})$ is a multivariate Gaussian distribution, which has an explicit density. Therefore, we can compute the gradients $\nabla_{\mathbf{f}}\log p_{\phi}(\mathbf{f}^{\mathbf{X}})$ analytically.
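For a Gaussian prior the gradient has the closed form $\nabla_{\mathbf{f}}\log\mathcal{N}(\mathbf{f};\mathbf{m},\mathbf{K})=-\mathbf{K}^{-1}(\mathbf{f}-\mathbf{m})$. The sketch below computes it with a Cholesky solve and compares it with central finite differences. It assumes an RBF kernel with fixed hyperparameters (so $\phi$ is not differentiated), a zero mean and a small jitter; its helper names are hypothetical.

```python
import math


def rbf(x, y, lengthscale=1.0):
    return math.exp(-0.5 * (x - y) ** 2 / lengthscale ** 2)


def cholesky(a):
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def chol_solve(L, b):
    """Solve (L L^T) x = b by forward then backward substitution."""
    n = len(L)
    y = [0.0] * n
    for i in range(n):
        y[i] = (b[i] - sum(L[i][k] * y[k] for k in range(i))) / L[i][i]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (y[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


def gp_log_density(f, mean, K):
    """log N(f; mean, K), evaluated via a Cholesky factorization of K."""
    L = cholesky(K)
    r = [fi - mi for fi, mi in zip(f, mean)]
    quad = sum(ri * ai for ri, ai in zip(r, chol_solve(L, r)))
    log_det = 2.0 * sum(math.log(L[i][i]) for i in range(len(f)))
    return -0.5 * (quad + log_det + len(f) * math.log(2.0 * math.pi))


def gp_log_density_grad(f, mean, K):
    """Analytic gradient: grad_f log N(f; mean, K) = -K^{-1} (f - mean)."""
    L = cholesky(K)
    return [-a for a in chol_solve(L, [fi - mi for fi, mi in zip(f, mean)])]


xs = [0.0, 0.4, 1.1]
K = [[rbf(a, c) + (1e-6 if i == j else 0.0) for j, c in enumerate(xs)] for i, a in enumerate(xs)]
mean = [0.0] * len(xs)
f = [0.3, -0.2, 0.5]
analytic = gp_log_density_grad(f, mean, K)

# Central finite differences as a sanity check of the analytic gradient.
h = 1e-5
numeric = []
for i in range(len(f)):
    up = f[:i] + [f[i] + h] + f[i + 1:]
    down = f[:i] + [f[i] - h] + f[i + 1:]
    numeric.append((gp_log_density(up, mean, K) - gp_log_density(down, mean, K)) / (2 * h))
print(analytic, numeric)
```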

D.2 Implicit Priors