Structured Variational Learning of Bayesian Neural Networks with Horseshoe Priors
Soumya Ghosh, Jiayu Yao, Finale Doshi-Velez
Introduction
Bayesian Neural Networks (BNNs) are increasingly the de-facto approach for modeling stochastic functions. By treating the weights in a neural network as random variables, and performing posterior inference on these weights, BNNs can avoid overfitting in the regime of small data, provide well-calibrated posterior uncertainty estimates, and model a large class of stochastic functions with heteroskedastic and multi-modal noise. These properties have resulted in BNNs being adopted in applications ranging from active learning (Hernández-Lobato & Adams, 2015; Gal et al., 2016a) to reinforcement learning (Blundell et al., 2015; Depeweg et al., 2017).
While there have been many recent advances in training BNNs (Hernández-Lobato & Adams, 2015; Blundell et al., 2015; Rezende et al., 2014; Louizos & Welling, 2016; Hernandez-Lobato et al., 2016), model selection in BNNs has received relatively less attention. Unfortunately, the consequences of a poor choice of architecture are severe: too few nodes, and the BNN will not be flexible enough to model the function of interest; too many nodes, and the BNN predictions will have large variance. We note that these Bayesian model selection concerns are subtly different from the overfitting and underfitting concerns that arise from maximum likelihood training: here, more expressive models (e.g. those with more nodes) require more data to concentrate the posterior. When there is insufficient data, the posterior uncertainty over the BNN weights will remain large, resulting in large variances in the BNN’s predictions. We illustrate this issue in Figure 1, where we see that a BNN trained with too many parameters has higher variance around its predictions than one with fewer. Thus, the core concern of Bayesian model selection is to identify a model class expressive enough that it can explain the observed data set, but not so expressive that it can explain everything (Rasmussen & Ghahramani, 2001; Murray & Ghahramani, 2005).
Model selection in BNNs is challenging because the number of nodes in a layer is a discrete quantity. Recently, Ghosh & Doshi-Velez (2017) and Louizos et al. (2017) independently proposed performing model selection in Bayesian neural networks by placing Horseshoe priors (Carvalho et al., 2009) over the weights incident to each node in the network. This prior can be interpreted as a continuous relaxation of a spike-and-slab approach that would assign a discrete on-off variable to each node, allowing for computationally efficient optimization via variational inference.
In this work, we expand upon this idea with several innovations and careful experiments. By combining regularized horseshoe priors for the node-specific weights with variational approximations that retain critical posterior structure, we both improve upon the statistical properties of the earlier works and provide improved generalization, especially for smaller data sets and in sample-limited settings such as reinforcement learning. We also present a new thresholding rule for pruning away nodes. Unlike previous work, our rule does not require computing a point summary of the inferred posteriors. We compare the various model and inference combinations on a diverse set of regression and reinforcement learning tasks. We find that the proposed innovations consistently improve upon the compactness of the models learned without sacrificing predictive performance.
Bayesian Neural Networks
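Given observation response pairs $\mathcal{D} = \{x_n, y_n\}_{n=1}^{N}$ and a prior $p(\mathcal{W})$ over the network weights $\mathcal{W}$, we are interested in the posterior distribution $p(\mathcal{W} \mid \mathcal{D}) \propto p(\mathcal{W}) \prod_{n=1}^{N} p(y_n \mid f(\mathcal{W}, x_n))$, and in using it for predicting responses to unseen data $x_*$, $p(y_* \mid x_*) = \int p(y_* \mid f(\mathcal{W}, x_*))\, p(\mathcal{W} \mid \mathcal{D})\, d\mathcal{W}$. The prior $p(\mathcal{W})$ allows one to encode problem-specific beliefs as well as general properties about weights.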
Bayesian Neural Networks with Regularized Horseshoe Priors
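While the horseshoe prior has some good properties, when the amount of training data is limited, units with essentially no shrinkage can produce large weights that adversely affect the generalization performance of HS-BNNs, with minor perturbations of the data leading to vastly different predictions. To deal with this issue, here we consider the regularized horseshoe prior (Piironen & Vehtari, 2017). Under this prior, the vector $w_{kl}$ of weights incident to unit $k$ of layer $l$ is drawn from
$$w_{kl} \mid \tau_{kl}, \upsilon_l, c \sim \mathcal{N}\big(0, (\tilde{\tau}_{kl}^2 \upsilon_l^2)\, I\big), \qquad \tilde{\tau}_{kl}^2 = \frac{c^2 \tau_{kl}^2}{c^2 + \tau_{kl}^2 \upsilon_l^2},$$
where $\tau_{kl}$ and $\upsilon_l$ are the unit-specific and layer-wide scales and $c$ is a weight decay parameter.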
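Half-Cauchy re-parameterization for variational learning. Instead of directly parameterizing the Half-Cauchy random variables in Equations 1 and 2, we use a convenient auxiliary variable parameterization (Wand et al., 2011) of the distribution, $a \sim C^{+}(0, b) \iff a^2 \mid \lambda \sim \text{Inv-Gamma}(\tfrac{1}{2}, \tfrac{1}{\lambda}),\ \lambda \sim \text{Inv-Gamma}(\tfrac{1}{2}, \tfrac{1}{b^2})$, where $\text{Inv-Gamma}(\alpha, \theta)$ is the Inverse Gamma distribution with density $\text{Inv-Gamma}(z \mid \alpha, \theta) \propto z^{-\alpha-1} \exp(-\theta / z)$ for $z > 0$. This avoids the challenges posed by the direct approximation during variational learning: standard exponential family variational approximations struggle to capture the thick Cauchy tails, while a Cauchy approximating family leads to high variance gradients.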
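The auxiliary variable construction can be checked numerically. The following is a minimal sketch, not the authors' implementation: the function names are hypothetical, and it only relies on the standard fact that the reciprocal of a Gamma draw is Inverse Gamma distributed. It draws Half-Cauchy variables through the two-level Inverse Gamma hierarchy, so that the sample median can be compared with the Half-Cauchy median, which equals the scale $b$.

```python
import math
import random
import statistics


def inv_gamma(shape, rate, rng):
    """Draw z ~ Inv-Gamma(shape, rate), density proportional to z^(-shape-1) exp(-rate / z)."""
    # If g ~ Gamma(shape, scale=1/rate), then 1/g ~ Inv-Gamma(shape, rate).
    return 1.0 / rng.gammavariate(shape, 1.0 / rate)


def half_cauchy_via_aux(b, rng):
    """Draw a ~ Half-Cauchy(0, b) through the auxiliary variable lambda."""
    lam = inv_gamma(0.5, 1.0 / (b * b), rng)   # lambda ~ Inv-Gamma(1/2, 1/b^2)
    a_sq = inv_gamma(0.5, 1.0 / lam, rng)      # a^2 | lambda ~ Inv-Gamma(1/2, 1/lambda)
    return math.sqrt(a_sq)


if __name__ == "__main__":
    rng = random.Random(1)
    draws = [half_cauchy_via_aux(2.0, rng) for _ in range(20000)]
    # The median of a Half-Cauchy(0, b) variable is b.
    print("sample median:", statistics.median(draws))
```

Both conditionals are Inverse Gamma, which is what later allows closed-form updates for the auxiliary variables.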
where is the likelihood function and , with , .
Non-Centered Parameterization. The regularized horseshoe and the horseshoe priors both exhibit strong correlations between the weights $w_{kl}$ and the scales $\tau_{kl}\upsilon_l$. While their favorable sparsity-inducing properties stem from this coupling, it also gives rise to coupled posteriors that exhibit pathological funnel-shaped geometries (Betancourt & Girolami, 2015; Ingraham & Marks, 2016) that are difficult to reliably sample or approximate.
Adopting non-centered parameterizations (Ingraham & Marks, 2016) helps alleviate the issue. Consider a reformulation of Equation 2,
$$\beta_{kl} \sim \mathcal{N}(0, I), \qquad w_{kl} = \tilde{\tau}_{kl}\, \upsilon_l\, \beta_{kl},$$
where $\tilde{\tau}_{kl}$ denotes the regularized unit-specific scale and the distributions on the scales are left unchanged. Since the scales and the non-centered weights $\beta_{kl}$ are sampled from independent prior distributions and are marginally uncorrelated, such a parameterization is referred to as non-centered. The likelihood is now responsible for introducing the coupling between the two, when conditioning on observed data. Non-centered parameterizations are known to lead to simpler posterior geometries (Betancourt & Girolami, 2015). Empirically, Ghosh & Doshi-Velez (2017) have shown that adopting a non-centered parameterization significantly improves the quality of the posterior approximation for BNNs with Horseshoe priors. Thus, we also adopt non-centered parameterizations for the regularized Horseshoe BNNs.
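As an illustration of the non-centered construction, the sketch below draws the weights incident to one unit from a regularized horseshoe prior. It is a simplified example under stated assumptions: the function and argument names (`b0`, `bg`, `c`) are hypothetical, the weight decay `c` is held fixed rather than given its own prior, and the layer-wide scale is sampled per call for brevity. The regularized scale follows the form given by Piironen & Vehtari (2017), under which the effective scale behaves like $\tau\upsilon$ for strongly shrunk units and saturates at $c$ otherwise.

```python
import math
import random


def half_cauchy(scale, rng):
    """Draw from Half-Cauchy(0, scale) using the inverse CDF."""
    return scale * math.tan(0.5 * math.pi * rng.random())


def regularized_scale(tau, upsilon, c):
    """Effective scale tau_tilde * upsilon, with
    tau_tilde^2 = c^2 tau^2 / (c^2 + tau^2 upsilon^2)."""
    tau_tilde = math.sqrt(c * c * tau * tau / (c * c + tau * tau * upsilon * upsilon))
    return tau_tilde * upsilon


def sample_unit_weights_non_centered(n_inputs, b0, bg, c, rng):
    """Sample scales and standard normal beta independently, then rescale."""
    tau = half_cauchy(b0, rng)          # unit-specific scale
    upsilon = half_cauchy(bg, rng)      # layer-wide scale (sampled here for brevity)
    beta = [rng.gauss(0.0, 1.0) for _ in range(n_inputs)]
    scale = regularized_scale(tau, upsilon, c)
    weights = [scale * b for b in beta]  # w = tau_tilde * upsilon * beta
    return weights, beta, scale


if __name__ == "__main__":
    rng = random.Random(0)
    w, beta, scale = sample_unit_weights_non_centered(5, 1.0, 1.0, 2.0, rng)
    print("effective scale:", scale)
    print("weights:", w)
```

Because `beta` is drawn independently of the scales, any dependence between the two in the posterior is introduced only through the likelihood.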
Structured Variational Learning of Regularized Horseshoe BNNs
The more flexible the approximating family, the better it can approximate the true posterior. Below, we first describe a straightforward fully factorized approximation and then a more sophisticated structured approximation that we demonstrate has better statistical properties.
The simplest possibility is to use a fully factorized variational family,
Restricting the variational distribution for the non-centered weight between units in layer and in layer , to the Gaussian family , and the non-negative scale parameters and and the variance of the output layer weights to the log-Normal family, , , and , allows for the development of straightforward inference algorithms (Ghosh & Doshi-Velez, 2017; Louizos et al., 2017). It is not necessary to impose distributional constraints on the variational approximations of the auxiliary variables , , or . Conditioned on the other variables, the optimal variational distributions for these latent variables are inverse Gamma distributions. We refer to this approximation as the factorized approximation.
Structured Variational Approximations
Although computationally convenient, the factorized approximations fail to capture posterior correlations among the network weights, and more pertinently, between weights and scales.
Table 1 summarizes the variational approximations introduced in this section.
Black Box Variational Inference
Irrespective of the variational family choice, the resulting evidence lower bound (ELBO),
is challenging to evaluate. Here we have used to denote the set of all non-centered weights in the network. The non-linearities introduced by the neural network and the potential lack of conjugacy between the neural network parameterized likelihoods and the Horseshoe priors render the first expectation in Equation 6 intractable.
Recent progress in black box variational inference (Kingma & Welling, 2014; Rezende et al., 2014; Ranganath et al., 2014; Titsias & Lázaro-Gredilla, 2014) circumvents this difficulty. These techniques compute noisy unbiased estimates of the gradient $\nabla_\phi \mathcal{L}(\phi)$ of the ELBO with respect to the variational parameters $\phi$, by approximating the offending expectations with unbiased Monte-Carlo estimates and relying on either score function estimators (Williams, 1992; Ranganath et al., 2014) or reparameterization gradients (Kingma & Welling, 2014; Rezende et al., 2014; Titsias & Lázaro-Gredilla, 2014) to differentiate through the sampling process. With the unbiased gradients in hand, stochastic gradient ascent can be used to optimize the ELBO. In practice, reparameterization gradients exhibit significantly lower variances than their score function counterparts and are typically favored for differentiable models. The reparameterization gradients rely on the existence of a parameterization that separates the source of randomness from the parameters with respect to which the gradients are sought. For our Gaussian variational approximations, the well known non-centered parameterization, $\zeta \sim \mathcal{N}(\mu, \sigma^2) \Leftrightarrow \epsilon \sim \mathcal{N}(0, 1),\ \zeta = \mu + \sigma\epsilon$, allows us to compute Monte-Carlo gradients,
$$\nabla_{\mu,\sigma}\, \mathbb{E}_{\mathcal{N}(\zeta \mid \mu, \sigma^2)}[g(\zeta)] = \mathbb{E}_{\mathcal{N}(\epsilon \mid 0, 1)}[\nabla_{\mu,\sigma}\, g(\mu + \sigma\epsilon)] \approx \frac{1}{S} \sum_{s=1}^{S} \nabla_{\mu,\sigma}\, g(\mu + \sigma\epsilon^{(s)}),$$
for any differentiable function $g$ and $\epsilon^{(s)} \sim \mathcal{N}(0, 1)$. Furthermore, all practical implementations of variational Bayesian neural networks use a further re-parameterization to lower the variance of the gradient estimator. They sample from the implied variational distribution over a layer’s pre-activations instead of directly sampling the much higher dimensional weights (Kingma et al., 2015).
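The reparameterization gradient estimator described above can be sketched for a scalar Gaussian. This is an illustrative example only: the function name and the test function $g(\zeta) = \zeta^2$ are assumptions chosen because the exact gradients, $2\mu$ and $2\sigma$, are known in closed form.

```python
import random


def reparam_gradients(grad_g, mu, sigma, num_samples, rng):
    """Monte Carlo estimate of the gradients of E_{z ~ N(mu, sigma^2)}[g(z)]
    with respect to mu and sigma.

    grad_g is the derivative g'(z). Writing z = mu + sigma * eps, the chain
    rule gives dg/dmu = g'(z) and dg/dsigma = g'(z) * eps.
    """
    d_mu, d_sigma = 0.0, 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)   # noise source, independent of mu and sigma
        z = mu + sigma * eps        # differentiable transform of the noise
        d_mu += grad_g(z)
        d_sigma += grad_g(z) * eps
    return d_mu / num_samples, d_sigma / num_samples


if __name__ == "__main__":
    rng = random.Random(0)
    # g(z) = z^2, so E[g] = mu^2 + sigma^2 with exact gradients (2 mu, 2 sigma).
    est = reparam_gradients(lambda z: 2.0 * z, 1.5, 0.5, 50000, rng)
    print("estimated gradients:", est, "exact:", (3.0, 1.0))
```

In a BNN the role of `grad_g` is played by backpropagation through the network and likelihood, but the separation of noise from parameters is the same.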
Computational Considerations. The primary computational bottleneck for the structured approximation arises in computing the pre-activations in Equation 8. While computing in the factorized approximation involves a single inner product, in the structured case it requires the computation of the quadratic form and a point-wise multiplication with the elements of . Owing to the diagonal plus rank-one structure of , we only need two inner products, followed by a scalar squaring and addition to compute the quadratic form and scalar multiplications for the point-wise multiplication with . Thus the structured approximation is only marginally more expensive. Further, it uses only weight variance parameters per layer, instead of parameters used by the factorized approximation. Not having to compute gradients and update these additional parameters further mitigates the performance difference.
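The cost argument can be made concrete with a small sketch. The names below (`psi` for the diagonal, `h` for the rank-one factor, `v` for the per-output-unit variances) are hypothetical, and the example assumes that the pre-activation variance of each output unit is the quadratic form scaled by that unit's entry of a diagonal matrix. Under that assumption the quadratic form needs only two inner products.

```python
def quad_form_diag_rank_one(a, psi, h):
    """Return a^T (diag(psi) + h h^T) a without forming the matrix."""
    diag_term = sum(ai * ai * p for ai, p in zip(a, psi))  # a^T diag(psi) a
    proj = sum(ai * hi for ai, hi in zip(a, h))            # h^T a
    return diag_term + proj * proj


def quad_form_dense(a, psi, h):
    """Reference O(n^2) computation used only to check the fast version."""
    n = len(a)
    total = 0.0
    for i in range(n):
        for j in range(n):
            u_ij = h[i] * h[j] + (psi[i] if i == j else 0.0)
            total += a[i] * u_ij * a[j]
    return total


def preactivation_variances(a, psi, h, v):
    """Scale the shared quadratic form by each output unit's variance."""
    q = quad_form_diag_rank_one(a, psi, h)
    return [q * vk for vk in v]


if __name__ == "__main__":
    a, psi, h = [1.0, 2.0, 3.0], [0.5, 1.0, 2.0], [1.0, -1.0, 0.5]
    print(quad_form_diag_rank_one(a, psi, h), quad_form_dense(a, psi, h))
    print(preactivation_variances(a, psi, h, [1.0, 2.0]))
```

The fast version costs time linear in the number of inputs, compared with quadratic time for the dense reference.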
Pruning Rule
The Horseshoe and its regularized variant provide strong shrinkage towards zero for small $\tau_{kl}$. However, the shrunk weights, although tiny, are never actually zero. A user-defined thresholding rule is required to prune away the shrunk weights. One could first summarize the inferred posterior distributions using a point estimate and then use the summary to define a thresholding rule (Louizos et al., 2017). We propose an alternate thresholding rule that obviates the need for a point summary. We prune away a unit if $p(\tau_{kl}\upsilon_l < \delta) > p_0$, where $\delta$ and $p_0$ are user-defined parameters, with $\tau_{kl} \sim q(\tau_{kl})$ and $\upsilon_l \sim q(\upsilon_l)$. Since both $\tau_{kl}$ and $\upsilon_l$ are constrained to the log-Normal variational family, their product follows another log-Normal distribution, and implementing the thresholding rule simply amounts to computing the cumulative distribution function of the log-Normal distribution. To see why this rule is sensible, recall that for units which experience strong shrinkage the regularized Horseshoe tends to the Horseshoe. Under the Horseshoe prior, $\tau_{kl}\upsilon_l$ governs the (non-negative) scale of the weight node vector $w_{kl}$. Therefore, under our thresholding rule, we prune away nodes whose posterior scales $\tau_{kl}\upsilon_l$ place probability greater than $p_0$ below a sufficiently small threshold $\delta$. In our experiments, we set and to either or .
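A minimal sketch of such a thresholding rule follows. It assumes that the two scales are independent under the variational approximation, with $\log$ of each scale Gaussian, so that the log of their product is Gaussian with summed means and variances. The function names and the threshold values in the usage example are hypothetical placeholders, not the settings used in the experiments.

```python
import math


def lognormal_cdf(x, mu, sigma):
    """P(Z < x) for log Z ~ N(mu, sigma^2), x > 0."""
    z = (math.log(x) - mu) / (sigma * math.sqrt(2.0))
    return 0.5 * (1.0 + math.erf(z))


def should_prune(mu_tau, sigma_tau, mu_ups, sigma_ups, delta, p0):
    """Prune a unit when P(tau * upsilon < delta) > p0.

    Assumes log tau ~ N(mu_tau, sigma_tau^2) and log upsilon ~ N(mu_ups,
    sigma_ups^2) are independent, so log(tau * upsilon) is Gaussian with
    mean mu_tau + mu_ups and variance sigma_tau^2 + sigma_ups^2.
    """
    mu = mu_tau + mu_ups
    sigma = math.sqrt(sigma_tau ** 2 + sigma_ups ** 2)
    return lognormal_cdf(delta, mu, sigma) > p0


if __name__ == "__main__":
    # Hypothetical values for illustration only.
    delta, p0 = 1e-3, 0.9
    print(should_prune(-10.0, 0.5, -1.0, 0.5, delta, p0))  # strongly shrunk unit
    print(should_prune(0.0, 0.5, 0.0, 0.5, delta, p0))     # active unit
```

The rule uses the whole variational distribution over the scale rather than a point summary such as its mean or mode.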
Related Work
Bayesian neural networks have a long history. Early work can be traced back to Buntine & Weigend (1991), MacKay (1992), and Neal (1993). These early approaches do not scale well to modern architectures or the large datasets required to learn them. Recent advances in stochastic MCMC methods (Li et al., 2016; Welling & Teh, 2011), stochastic variational methods (Blundell et al., 2015; Rezende et al., 2014), black-box variational and alpha-divergence minimization (Hernandez-Lobato et al., 2016; Ranganath et al., 2014), and probabilistic backpropagation (Hernández-Lobato & Adams, 2015) have reinvigorated interest in BNNs by allowing scalable inference. Work on learning structure in BNNs has received less attention. Blundell et al. (2015) introduce a mixture-of-Gaussians prior on the weights, with one mixture component tightly concentrated around zero, thus approximating a spike and slab prior over weights. Others (Kingma et al., 2015; Gal & Ghahramani, 2016) have noticed connections between Dropout (Srivastava et al., 2014) and approximate variational inference. In particular, Molchanov et al. (2017) show that the interpretation of Gaussian dropout as performing variational inference in a network with log uniform priors over weights leads to sparsity in weights. The goal of turning off edges is very different from the approach considered here, which performs model selection over the appropriate number of nodes. More closely related to our work are the recent works of Ghosh & Doshi-Velez (2017) and Louizos et al. (2017). The authors consider group Horseshoe priors for unit pruning. We improve upon these works by using regularized Horseshoe priors that improve generalization, structured variational approximations that provide more accurate inferences, and by proposing a new thresholding rule to prune away units with small scales. Yet others (Neklyudov et al., 2017) have proposed pruning units via truncated log-normal priors over unit scales. However, they do not place priors over network weights and are unable to infer posterior weight uncertainty. Related but orthogonal research (Adams et al., 2010; Song et al., 2017) has focused on the problem of structure learning in deep belief networks. There is also a body of work on learning structure in non-Bayesian neural networks. Early work (LeCun et al., 1990; Hassibi et al., 1993) pruned networks by analyzing second-order derivatives of the objectives. More recently, Wen et al. (2016) describe applications of structured sparsity not only for optimizing filters and layers but also computation time. Closer to our work in spirit are Ochiai et al. (2016), Scardapane et al. (2017), Alvarez & Salzmann (2016), and Murray & Chiang (2015), who use group sparsity to prune groups of weights, e.g., weights incident to a node. However, these approaches do not model weight uncertainty and provide uniform shrinkage to all weights.
Experiments
In this section, we present experiments that evaluate various aspects of the proposed regularized Horseshoe Bayesian neural network (reg-HS) and the structured variational approximation. In all experiments, we use a learning rate of , the global horseshoe scale , a batch size of , , and . For the structured approximation, we also found that constraining , , and to unit-norms resulted in better predictive performance. Additional experimental details are in the supplement.
We begin by comparing reg-HS against BNNs using the standard Horseshoe (HS) prior on a collection of diverse datasets from the UCI repository. We follow the protocol of Hernández-Lobato & Adams (2015) to compare the two models. To provide a controlled comparison, and to tease apart the effects of model versus inference enhancements, we employ factorized variational approximations for both models. In Figure 2, the UCI datasets are sorted from left to right, with the smallest on the left. We find that the regularized Horseshoe leads to consistent improvements in predictive performance. As expected, the gains are more prominent for the smaller datasets, for which the regularization afforded by the regularized Horseshoe is crucial for avoiding over-fitting. In the remainder, all reported experimental results use the reg-HS prior.
Structured variational approximations provide greater shrinkage.
Next, we evaluate the effect of utilizing structured variational approximations. In preliminary experiments, we found that of the approximations described in Section 4.1, the structured approximation outperformed the semi-structured variant while the factorized approximation provided better predictive performance than the tied approximation. In this section we only report results comparing models employing these two variational families.
Toy Data. First, we explore the effects of structured and factorized variational approximations on predictive uncertainties. Following Ghosh & Doshi-Velez (2017), we consider a noisy regression problem: , , and explore the relationship between predictive uncertainty and model capacity. We compare a single layer unit BNN using a standard normal prior against BNNs with the regularized horseshoe prior utilizing factorized and structured variational approximations. Figures 1 and 3 show that while a BNN with a standard normal prior severely over-estimates the predictive uncertainty, models using the reg-HS priors, by pruning away excess capacity, significantly improve the estimated uncertainty. Furthermore, we observe that the structured approximation best alleviates the under-fitting issues.
Controlled comparisons on UCI benchmarks. We return to the UCI benchmark to carefully vet the different variational approximations. We deviate from prior work by using networks with significantly more capacity than previously considered for this benchmark. In particular, we use single layer networks with an order of magnitude more hidden units () than considered in previous work (). This additional capacity is more than that needed to explain the UCI benchmark datasets well. With this experimental setup, we are able to evaluate how well the proposed methods perform at pruning away extra modeling capacity. For all but the ‘year’ dataset, we report results from five trials, each trained on a random split of the data. For the large ‘year’ dataset, we ran a single trial (details in the supplement). Figure 2 shows that the structured approximation provides consistently stronger shrinkage than the factorized approximation.
Comparison against competing methods. We compare the reg-HS model with structured variational approximation against the variational matrix Gaussian (VMG) approach of Louizos & Welling (2016), which has previously been shown to outperform other variational approaches to learning BNNs. We used the pruning rule with for all but the ‘year’ dataset, for which we set . Figure 2 demonstrates that structured reg-HS is competitive with VMG in terms of predictive performance. We perform similarly to or better than VMG on the majority of the datasets. More interestingly, structured reg-HS achieves competitive performance while pruning away excess capacity and achieving significant compression. We also fine-tuned the pruned model by updating the weight means while holding the other parameters fixed. However, this did not significantly affect predictive performance. Finally, we evaluate how reg-HS compares against VMG in the low data regime. For the three smallest UCI datasets we use ten percent of the data for training. In such limited data regimes (Figure 2), the shrinkage afforded by reg-HS leads to clear improvements in predictive performance over VMG.
HS-BNNs improve reinforcement learning performance.
So far, we have focused on using BNNs simply for prediction. One application area in which having good predictive uncertainty estimates is crucial is model-based reinforcement learning (e.g., Depeweg et al., 2017; Gal et al., 2016b; Killian et al., 2017): here, it is essential not only to have an estimate of what state an agent may be in after taking a particular action, but also an accurate sense of all the states the agent may end up in. In the following, we apply our regularized HS-BNN with structured approximations to two domains: the 2D map of Killian et al. (2017) and acrobot (Sutton & Barto, 1998). For each domain, we focused on one instance dynamic setting. In each domain, we collected training samples by training a DDQN (van Hasselt et al., 2016) online (updated every episode). The DDQN was trained with an epsilon-greedy policy that started at one and decayed to with decay rate , for 500 episodes. This procedure ensured that we had a wide variety of samples that were still biased in coverage toward the optimal policy. To simulate resource constrained scenarios, we limited ourselves to 10 of DDQN training batches ( samples for the 2D map and training samples for acrobot). We considered two architectures, a single hidden layer network with units, and a two layer network with units per layer, as the transition function for each domain. Then we simulated from each BNN to learn a DDQN policy (two layers of width , ; learning rate ) and tested this policy on the original simulator.
As in our prediction results, training a moderately-sized BNN with so little data results in severe underfitting, which, in turn, adversely affects the quality of the policy that is learned. We see in Table 2 that the better fit of the structured reg-HS-BNN results in higher task performance across domains and model architectures.
Discussion and Conclusion
We demonstrated that the regularized horseshoe prior, combined with a structured variational distribution, is a computationally efficient tool for model selection in Bayesian neural networks. By retaining crucial posterior dependencies, the structured approximation provided, to our knowledge, state of the art shrinkage for BNNs while being competitive in predictive performance with existing approaches. We found model re-parameterizations (decomposition of the Half-Cauchy priors into inverse gamma distributions, and non-centered representations) essential for avoiding poor local optima. There remain several interesting follow-on directions, including modeling enhancements that use layer-, node-, or even weight-specific weight decay, or layer-specific global shrinkage parameters, to provide different levels of shrinkage to different parts of the BNN.
References
Appendix A Conditional variational pre-activations
Recall from Section 4.2, that the variational pre-activation distribution is given by , where , and is diagonal. To equation requires and . The expressions for these follow directly from the properties of partitioned Gaussians.
Rearranging, we can see that, is made up of the columns and .
Appendix B Algorithmic details
The ELBO corresponding to the non-centered regularized HS model is,
The entropy of is given by . We can exploit the structure of and to compute this efficiently. We note that . Since is diagonal . Using the matrix determinant lemma we can efficiently compute . Owing to the diagonal structure of , computing its determinant and inverse is particularly efficient.
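The use of the matrix determinant lemma for a diagonal plus rank-one matrix can be sketched as follows. The function name and the symbols `d` (positive diagonal entries) and `h` (rank-one factor) are hypothetical; the identity used is $\det(D + hh^\top) = \det(D)\,(1 + h^\top D^{-1} h)$.

```python
import math


def logdet_diag_rank_one(d, h):
    """log det(diag(d) + h h^T) via the matrix determinant lemma.

    Requires every d_i > 0. Runs in time linear in len(d), since the
    determinant and inverse of a diagonal matrix are element-wise.
    """
    log_det_d = sum(math.log(di) for di in d)                      # log det(D)
    correction = 1.0 + sum(hi * hi / di for hi, di in zip(h, d))   # 1 + h^T D^{-1} h
    return log_det_d + math.log(correction)


if __name__ == "__main__":
    # 2x2 check: det([[2 + 1, 2], [2, 3 + 4]]) = 3 * 7 - 2 * 2 = 17.
    print(math.exp(logdet_diag_rank_one([2.0, 3.0], [1.0, 2.0])))
```

Working with the log-determinant directly avoids overflow when the number of units is large.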
Fixed point updates
The auxiliary variables , and all follow inverse Gamma distributions. Here we derive for , the others follow analogously. Consider,
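A minimal sketch of one such fixed point update is given below, for a single scale $\tau$ with Half-Cauchy prior scale $b$ and auxiliary variable $\lambda$. It rests on two assumptions that should be checked against the full derivation: the prior is written as $\tau^2 \mid \lambda \sim \text{Inv-Gamma}(\tfrac{1}{2}, \tfrac{1}{\lambda})$, $\lambda \sim \text{Inv-Gamma}(\tfrac{1}{2}, \tfrac{1}{b^2})$, and the variational distribution satisfies $\log \tau \sim \mathcal{N}(\mu, \sigma^2)$. Collecting the terms in $\lambda$ of the expected log joint then gives $q(\lambda) = \text{Inv-Gamma}(1, \mathbb{E}[\tau^{-2}] + b^{-2})$. The function names are hypothetical.

```python
import math


def expected_inverse_square(mu, sigma):
    """E[tau^-2] when log tau ~ N(mu, sigma^2)."""
    return math.exp(-2.0 * mu + 2.0 * sigma * sigma)


def update_auxiliary(mu, sigma, b):
    """Return (shape, rate) of the optimal Inv-Gamma q(lambda).

    The expected log joint is, up to constants in lambda,
    -2 log(lambda) - (E[tau^-2] + 1 / b^2) / lambda,
    which is the log density of Inv-Gamma(1, E[tau^-2] + 1 / b^2).
    """
    shape = 1.0
    rate = expected_inverse_square(mu, sigma) + 1.0 / (b * b)
    return shape, rate


if __name__ == "__main__":
    shape, rate = update_auxiliary(mu=0.0, sigma=0.0, b=1.0)
    # E[1 / lambda] = shape / rate is the quantity the ELBO needs.
    print(shape, rate, shape / rate)
```

If the log-Normal family is instead placed on $\tau^2$, the expectation $\mathbb{E}[\tau^{-2}]$ changes accordingly, but the form of the update is the same.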
Appendix C Algorithm
Algorithm 1 provides pseudocode summarizing the overall algorithm for training regularized HS-BNNs (with structured variational approximations).
Appendix D Experimental details
For regression problems we use Gaussian likelihoods with an unknown precision , . We place a vague prior on the precision, and approximate the posterior over using another variational distribution . The corresponding variational parameters are learned via a gradient update during learning.
For comparing the reg-HS and HS models, we followed the protocol of Hernández-Lobato & Adams (2015) and trained a single hidden layer network with rectified linear units for all but the larger “Protein” and “Year” datasets, for which we train a unit network. For the smaller datasets, we train on a randomly subsampled subset, evaluate on the remainder, and repeat this process times. For “Protein” we perform 5 replications and for “Year” we evaluate on a single split. For VMG, we used pseudo-inputs, a learning rate of and a batch size of 128.
Reinforcement learning Experiments
We used a learning rate of . For the map domain we trained for epochs and for acrobot we trained for epochs.
Appendix E Additional Experimental results
In Figure 4 we provide further shrinkage results from the experiments described in the main text comparing regularized Horseshoe models utilizing factorized and structured approximations.
Figure 5 illustrates the shrinkage afforded by 50 unit HS-BNNs using fully factorized approximations. Similar to factorized regularized Horseshoe BNNs, limited compression is achieved. On some datasets, we do not achieve much compression and all 50 units are used. This is a consequence of the fully factorized approximations providing weaker shrinkage, as well as of the network not being large enough to model the complexity of the dataset.
Appendix F Prior samples from networks with HS and regularized Horseshoe priors
To provide further intuition into the behavior of networks with Horseshoe and regularized Horseshoe priors we provide functions drawn from networks endowed with these priors. Figure 6 plots five random functions sampled from one layer networks with varying widths. Observe that the regularized horseshoe distribution leads to smoother functions, thus affording stronger regularization. As demonstrated in the main paper, this stronger regularization leads to improved predictive performance when the amount of training data is limited.