Recasting Gradient-Based Meta-Learning as Hierarchical Bayes
Erin Grant, Chelsea Finn, Sergey Levine, Trevor Darrell, Thomas Griffiths
Introduction
A remarkable aspect of human intelligence is the ability to quickly solve a novel problem, even with limited experience in a novel domain. Such fast adaptation is made possible by leveraging prior learning experience to improve the efficiency of later learning. This capacity for meta-learning also has the potential to enable an artificially intelligent agent to learn more efficiently in situations with little available data or limited computational resources (schmidhuber1987evolutionary; bengio1991learning; naik1992meta).
In machine learning, meta-learning is formulated as the extraction of domain-general information that can act as an inductive bias to improve learning efficiency in novel tasks (caruana1998multitask; thrun1998learning). This inductive bias has been implemented in various ways: as learned hyperparameters in a hierarchical Bayesian model that regularize task-specific parameters (heskes1998solving), as a learned metric space in which to group neighbors (bottou1992local), as a trained recurrent neural network that allows encoding and retrieval of episodic information (santoro2016meta), or as an optimization algorithm with learned parameters (schmidhuber1987evolutionary; bengio1992optimization).
The model-agnostic meta-learning (MAML) algorithm of finn2017model is an instance of a learned optimization procedure that directly optimizes the standard gradient descent rule. The algorithm estimates an initial parameter set to be shared among the task-specific models; the intuition is that gradient descent from the learned initialization provides a favorable inductive bias for fast adaptation. However, this inductive bias has been evaluated only empirically in prior work (finn2017model).
In this work, we present a novel derivation of and a novel extension to MAML, illustrating that this algorithm can be understood as inference for the parameters of a prior distribution in a hierarchical Bayesian model. The learned prior allows for quick adaptation to unseen tasks on the basis of an implicit predictive density over task-specific parameters. The reinterpretation as hierarchical Bayes gives a principled statistical motivation for MAML as a meta-learning algorithm, and sheds light on the reasons for its favorable performance even among methods with significantly more parameters. More importantly, by casting gradient-based meta-learning within a Bayesian framework, we are able to improve MAML by taking insights from Bayesian posterior estimation as novel augmentations to the gradient-based meta-learning procedure. We experimentally demonstrate that this enables better performance on a few-shot learning benchmark.
Meta-Learning Formulation
The goal of a meta-learner is to extract task-general knowledge through the experience of solving a number of related tasks. By using this learned prior knowledge, the learner has the potential to quickly adapt to novel tasks even in the face of limited data or limited computation time.
Formally, we consider a dataset $\mathscr{D}$ that defines a distribution over a family of tasks $\mathcal{T}$. These tasks share some common structure such that learning to solve a single task has the potential to aid in solving another. Each task $\mathcal{T}_j$ defines a distribution $p_{\mathcal{T}_j}(\mathbf{x})$ over data points $\mathbf{x}$, which we assume in this work to consist of inputs and either regression targets or classification labels in a supervised learning problem (although this assumption can be relaxed to include reinforcement learning problems; e.g., see finn2017model). The objective of the meta-learner is to minimize a task-specific performance metric associated with any given unseen task from the dataset, given only a small amount of data from the task; i.e., to be capable of fast adaptation to a novel task.
In the following subsections, we discuss two ways of formulating a solution to the meta-learning problem: gradient-based hyperparameter optimization and probabilistic inference in a hierarchical Bayesian model. These approaches were developed independently, but, in Section 3.1, we draw a novel connection between the two.
A parametric meta-learner aims to find some shared parameters $\theta$ that make it easier to find the right task-specific parameters $\phi$ when faced with a novel task. A variety of meta-learners that employ gradient methods for task-specific fast adaptation have been proposed (e.g., andrychowicz2016learning; li2017alearning; li2017blearning; wichrowska2017learned). MAML (finn2017model) is distinct in that it provides a gradient-based meta-learning procedure that employs a single additional parameter (the meta-learning rate) and operates on the same parameter space for both meta-learning and fast adaptation. These are necessary features for the equivalence we show in Section 3.1.
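To address the meta-learning problem, MAML estimates the parameters $\theta$ of a set of models so that when one or a few batch gradient descent steps are taken from the initialization at $\theta$ given a small sample of task data $\mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N} \sim p_{\mathcal{T}_j}(\mathbf{x})$, each model has good generalization performance on another sample $\mathbf{x}_{j_{N+1}}, \dots, \mathbf{x}_{j_{N+M}} \sim p_{\mathcal{T}_j}(\mathbf{x})$ from the same task. The MAML objective in a maximum likelihood setting is

$$\mathcal{L}(\theta) = \frac{1}{J} \sum_{j} \left[ \frac{1}{M} \sum_{m} -\log p\Big( \mathbf{x}_{j_{N+m}} \,\Big|\, \underbrace{\theta - \alpha \nabla_{\theta} \frac{1}{N} \sum_{n} -\log p\big(\mathbf{x}_{j_n} \mid \theta\big)}_{\phi_j} \Big) \right] \tag{1}$$

where we use $\phi_j$ to denote the updated parameters after taking a single batch gradient descent step from the initialization at $\theta$ with step size $\alpha$ on the negative log-likelihood associated with the task $\mathcal{T}_j$. Note that since $\phi_j$ is an iterate of a gradient descent procedure that starts from $\theta$, each $\phi_j$ is of the same dimensionality as $\theta$. We refer to the inner gradient descent procedure that computes $\phi_j$ as fast adaptation. The computational graph of MAML is given in Figure 1 (left).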
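The one-step objective in (1) can be sketched for a small linear regression model. This is a minimal illustration, not the authors' implementation: the Gaussian likelihood with unit noise, the synthetic task generator `make_tasks`, and all function names are assumptions made for the example.

```python
import numpy as np

def nll(phi, X, y):
    """Mean negative log-likelihood of y under N(X @ phi, I), up to a constant."""
    r = X @ phi - y
    return 0.5 * np.mean(r ** 2)

def nll_grad(phi, X, y):
    """Gradient of nll with respect to phi."""
    return X.T @ (X @ phi - y) / len(y)

def fast_adapt(theta, X, y, alpha):
    """One batch gradient descent step from theta, giving the task-specific phi_j."""
    return theta - alpha * nll_grad(theta, X, y)

def maml_objective(theta, tasks, alpha):
    """Average held-out NLL after one adaptation step per task, as in (1)."""
    losses = []
    for X_tr, y_tr, X_val, y_val in tasks:
        phi = fast_adapt(theta, X_tr, y_tr, alpha)
        losses.append(nll(phi, X_val, y_val))
    return float(np.mean(losses))

def make_tasks(num_tasks=4, n=5, m=5, dim=3, seed=0):
    """Hypothetical synthetic tasks: linear functions with task-specific weights."""
    rng = np.random.default_rng(seed)
    tasks = []
    for _ in range(num_tasks):
        w = rng.normal(size=dim)
        X = rng.normal(size=(n + m, dim))
        y = X @ w + 0.1 * rng.normal(size=n + m)
        # First n points are used for adaptation, the remaining m for evaluation.
        tasks.append((X[:n], y[:n], X[n:], y[n:]))
    return tasks

tasks = make_tasks()
theta = np.zeros(3)
loss_adapted = maml_objective(theta, tasks, alpha=0.1)
loss_unadapted = maml_objective(theta, tasks, alpha=0.0)  # alpha = 0 leaves phi_j = theta
```

Setting `alpha=0.0` disables fast adaptation, which gives a simple baseline against which to compare the adapted loss.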
2.2 Meta-Learning as Hierarchical Bayesian Inference
An alternative way to formulate meta-learning is as a problem of probabilistic inference in the hierarchical model depicted in Figure 1 (right). In particular, in the case of meta-learning, each task-specific parameter $\phi_j$ is distinct from but should influence the estimation of the parameters $\{\phi_{j'} \mid j' \neq j\}$ from other tasks. We can capture this intuition by introducing a meta-level parameter $\theta$ on which each task-specific parameter is statistically dependent. With this formulation, the mutual dependence of the task-specific parameters $\phi_j$ is realized only through their individual dependence on the meta-level parameters $\theta$. As such, estimating $\theta$ provides a way to constrain the estimation of each of the $\phi_j$.
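Given some data in a multi-task setting, we may estimate $\theta$ by integrating out the task-specific parameters to form the marginal likelihood of the data. Formally, grouping all of the data from each of the tasks as $\mathbf{X}$ and again denoting by $\mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N}$ a sample from task $\mathcal{T}_j$, the marginal likelihood of the observed data is given by

$$p(\mathbf{X} \mid \theta) = \prod_{j} \left( \int p\big(\mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N} \mid \phi_j\big)\, p\big(\phi_j \mid \theta\big)\, \mathrm{d}\phi_j \right). \tag{2}$$

Maximizing (2) as a function of $\theta$ gives a point estimate for $\theta$, an instance of a method known as empirical Bayes (bernardo2006bayesian; gelman2014bayesian) due to its use of the data to estimate the parameters of the prior distribution.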
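To make the integral in (2) concrete, the sketch below evaluates one task's factor for a toy model in which the integral is available in closed form. The model is an assumption chosen for illustration (scalar $\phi$, likelihood $\mathcal{N}(y_n; x_n \phi, 1)$, prior $\mathcal{N}(\phi; \theta, \tau^2)$), as are the names `tau2`, `log_marginal_quadrature`, and `log_marginal_closed_form`; the grid-based integration is only feasible because $\phi$ is one-dimensional.

```python
import numpy as np

def log_gauss(z, mean, var):
    """Elementwise log density of a univariate Gaussian."""
    return -0.5 * (np.log(2 * np.pi * var) + (z - mean) ** 2 / var)

def log_marginal_quadrature(x, y, theta, tau2, num=20001, width=12.0):
    """Integrate p(y | phi) p(phi | theta) over scalar phi on a uniform grid."""
    sd = np.sqrt(tau2)
    phi = np.linspace(theta - width * sd, theta + width * sd, num)
    # Log-likelihood of all observations for every grid value of phi.
    log_lik = log_gauss(y[None, :], phi[:, None] * x[None, :], 1.0).sum(axis=1)
    log_integrand = log_lik + log_gauss(phi, theta, tau2)
    shift = log_integrand.max()  # subtract the max for numerical stability
    dphi = phi[1] - phi[0]
    return shift + np.log(np.sum(np.exp(log_integrand - shift)) * dphi)

def log_marginal_closed_form(x, y, theta, tau2):
    """For this linear-Gaussian toy model, y | theta ~ N(x * theta, I + tau2 * x x^T)."""
    cov = np.eye(len(y)) + tau2 * np.outer(x, x)
    resid = y - x * theta
    _, logdet = np.linalg.slogdet(cov)
    quad = resid @ np.linalg.solve(cov, resid)
    return -0.5 * (len(y) * np.log(2 * np.pi) + logdet + quad)

rng = np.random.default_rng(0)
x = rng.normal(size=5)
y = 0.8 * x + rng.normal(size=5)

log_ml_grid = log_marginal_quadrature(x, y, theta=0.5, tau2=1.0)
log_ml_exact = log_marginal_closed_form(x, y, theta=0.5, tau2=1.0)
```

The two values should agree to high precision here; for expressive models with many parameters neither route is available, which is the tractability issue the text goes on to raise.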
Hierarchical Bayesian models have a long history of use in both transfer learning and domain adaptation (e.g., lawrence2004learning; yu2005learning; gao2008knowledge; daume2009bayesian; wan2012sparse). However, the formulation of meta-learning as hierarchical Bayes does not automatically provide an inference procedure, and furthermore, there is no guarantee that inference is tractable for expressive models with many parameters such as deep neural networks.
Linking Gradient-Based Meta-Learning & Hierarchical Bayes
In this section, we connect the two independent approaches of Section 2.1 and Section 2.2 by showing that MAML can be understood as empirical Bayes in a hierarchical probabilistic model. Furthermore, we build on this understanding by showing that a choice of update rule for the task-specific parameters $\phi_j$ (i.e., a choice of inner-loop optimizer) corresponds to a choice of prior over task-specific parameters, $p(\phi_j \mid \theta)$.
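In general, when performing empirical Bayes, the marginalization over task-specific parameters $\phi_j$ in (2) is not tractable to compute exactly. To avoid this issue, we can consider an approximation that makes use of a point estimate $\hat{\phi}_j$ instead of performing the integration over $\phi_j$ in (2). Using $\hat{\phi}_j$ as an estimator for each $\phi_j$, we may write the negative logarithm of the marginal likelihood as

$$-\log p(\mathbf{X} \mid \theta) \approx \sum_{j} \left[ -\log p\big(\mathbf{x}_{j_{N+1}}, \dots, \mathbf{x}_{j_{N+M}} \mid \hat{\phi}_j\big) \right]. \tag{3}$$

Setting $\hat{\phi}_j = \theta + \alpha \nabla_{\theta} \log p(\mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N} \mid \theta)$ for each $j$ in (3) recovers the unscaled form of the one-step MAML objective in (1). This tells us that the MAML objective is equivalent to a maximization with respect to the meta-level parameters $\theta$ of the marginal likelihood $p(\mathbf{X} \mid \theta)$, where a point estimate for each task-specific parameter $\phi_j$ is computed via one or a few steps of gradient descent. By taking only a few steps from the initialization at $\theta$, the point estimate $\hat{\phi}_j$ trades off minimizing the fast adaptation objective $-\log p(\mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N} \mid \theta)$ with staying close in value to the parameter initialization $\theta$.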
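We can formalize this trade-off by considering the linear regression case. Recall that the maximum a posteriori (MAP) estimate of $\phi_j$ corresponds to the global mode of the posterior $p(\phi_j \mid \mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N}, \theta) \propto p(\mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N} \mid \phi_j)\, p(\phi_j \mid \theta)$. In the case of a linear model, early stopping of an iterative gradient descent procedure to estimate $\phi_j$ is exactly equivalent to MAP estimation of $\phi_j$ under the assumption of a prior that depends on the number of descent steps as well as the direction in which each step is taken. In particular, write the input examples as $\mathbf{X}$ and the vector of regression targets as $\mathbf{y}$, omit the task index from $\phi$, and consider the gradient descent update

$$\phi_{(k)} = \phi_{(k-1)} - \alpha \nabla_{\phi} \Big[ \tfrac{1}{2} \|\mathbf{y} - \mathbf{X}\phi\|_2^2 \Big]_{\phi = \phi_{(k-1)}} = \phi_{(k-1)} - \alpha\, \mathbf{X}^\top \big( \mathbf{X}\phi_{(k-1)} - \mathbf{y} \big) \tag{4}$$

for iteration index $k$ and learning rate $\alpha \in \mathbb{R}^{+}$. santos1996equivalence shows that, starting from $\phi_{(0)} = \theta$, $\phi_{(k)}$ in (4) solves the regularized linear least squares problem

$$\min_{\phi} \Big( \|\mathbf{y} - \mathbf{X}\phi\|_2^2 + \|\theta - \phi\|_{Q}^2 \Big) \tag{5}$$

with the $Q$-norm defined by $\|z\|_{Q}^2 = z^\top Q^{-1} z$ for a symmetric positive definite matrix $Q$ that depends on the step size $\alpha$ and the iteration index $k$ as well as on the covariance structure of $\mathbf{X}$. The minimization in (5) can be expressed as a posterior maximization problem given a conditional Gaussian likelihood over $\mathbf{y}$ and a Gaussian prior over $\phi$, with the posterior taking the form

$$p(\phi \mid \mathbf{X}, \mathbf{y}, \theta) \propto \mathcal{N}(\mathbf{y};\, \mathbf{X}\phi,\, \mathbb{I})\; \mathcal{N}(\phi;\, \theta,\, Q). \tag{6}$$

Since $\phi_{(k)}$ in (4) maximizes (6), we may conclude that $k$ iterations of gradient descent in a linear regression model with squared error exactly compute the MAP estimate of $\phi$, given a Gaussian-noised observation model and a Gaussian prior over $\phi$ with parameters $\mu_0 = \theta$ and $\Sigma_0 = Q$. Therefore, in the case of linear regression with squared error, MAML is exactly empirical Bayes using the MAP estimate as the point estimate of $\phi$.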
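The equivalence between early stopping and MAP estimation in the linear case can be checked numerically. The sketch below is an illustration under stated assumptions: the closed form used for $Q$ is obtained for this example by matching the gradient descent iterate to the ridge-type solution in the eigenbasis of $\mathbf{X}^\top\mathbf{X}$, the data are synthetic, and the step size is chosen so that every factor $1 - \alpha\lambda_i$ lies strictly between 0 and 1.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 20, 3, 5
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + 0.1 * rng.normal(size=n)
theta = rng.normal(size=d)  # initialization, playing the role of the prior mean

# Eigendecomposition X^T X = O diag(lam) O^T.
lam, O = np.linalg.eigh(X.T @ X)
alpha = 0.5 / lam.max()  # keeps 0 < 1 - alpha * lam_i < 1 for all i

# k steps of gradient descent on 0.5 * ||y - X phi||^2, starting from theta.
phi_gd = theta.copy()
for _ in range(k):
    phi_gd = phi_gd - alpha * X.T @ (X @ phi_gd - y)

# Assumed closed form for the implied prior covariance (derived for this sketch):
# Q = O diag(((1 - alpha * lam)^(-k) - 1) / lam) O^T.
q = ((1.0 - alpha * lam) ** (-k) - 1.0) / lam
Q = O @ np.diag(q) @ O.T
Q_inv = O @ np.diag(1.0 / q) @ O.T

# MAP estimate under likelihood N(y; X phi, I) and prior N(phi; theta, Q).
phi_map = np.linalg.solve(X.T @ X + Q_inv, X.T @ y + Q_inv @ theta)

max_abs_diff = float(np.max(np.abs(phi_gd - phi_map)))
```

As $k$ grows, the entries of `q` grow without bound, so the implied prior flattens and the iterate approaches the unregularized least squares solution; small $k$ corresponds to a tight prior around `theta`.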
In the nonlinear case, MAML is again equivalent to an empirical Bayes procedure to maximize the marginal likelihood that uses a point estimate for $\phi$ computed by one or a few steps of gradient descent. However, this point estimate is not necessarily the global mode of a posterior. We can instead understand the point estimate given by truncated gradient descent as the value of the mode of an implicit posterior over $\phi$ resulting from an empirical loss interpreted as a negative log-likelihood, and regularization penalties and the early stopping procedure jointly acting as priors (for similar interpretations, see sjoberg1995overtraining; bishop1995regularization; duvenaud2016early).
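The exact equivalence between early stopping and a Gaussian prior on the weights in the linear case, as well as the implicit regularization to the parameter initialization in the nonlinear case, tells us that every iterate of truncated gradient descent is a mode of an implicit posterior. In particular, we are not required to take the gradient descent procedure of fast adaptation that computes $\hat{\phi}$ to convergence in order to establish a connection between MAML and hierarchical Bayes. MAML can therefore be understood to approximate an expectation of the marginal negative log-likelihood (NLL) for each task $\mathcal{T}_j$ as

$$\mathbb{E}_{\mathbf{x} \sim p_{\mathcal{T}_j}(\mathbf{x})} \big[ -\log p(\mathbf{x} \mid \theta) \big] \approx \frac{1}{M} \sum_{m} -\log p\big(\mathbf{x}_{j_{N+m}} \mid \hat{\phi}_j\big) \tag{7}$$

using the point estimate $\hat{\phi}_j = \theta + \alpha \nabla_{\theta} \log p(\mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N} \mid \theta)$ for single-step fast adaptation.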
The algorithm for MAML as probabilistic inference is given in Algorithm 2; Subroutine 3 computes each marginal negative log-likelihood (NLL) using the point estimate $\hat{\phi}$ as just described. Formulating MAML in this way, as probabilistic inference in a hierarchical Bayesian model, motivates the interpretation in Section 3.2 of using various meta-optimization algorithms to induce a prior over task-specific parameters.
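The overall procedure (compute a point estimate per task, evaluate the marginal NLL at that estimate, and descend on the meta-level parameters) can be sketched as follows. This is a hedged illustration rather than a transcription of Algorithm 2 or Subroutine 3: it assumes a linear regression model with unit-variance Gaussian noise so that the meta-gradient has a closed form, uses full-batch updates over a fixed set of synthetic tasks, and all names (`marginal_nll_and_grad`, `meta_train`, `beta`) are hypothetical.

```python
import numpy as np

def nll(phi, X, y):
    """Mean NLL under N(X @ phi, I), up to an additive constant."""
    return 0.5 * np.mean((X @ phi - y) ** 2)

def nll_grad(phi, X, y):
    return X.T @ (X @ phi - y) / len(y)

def marginal_nll_and_grad(theta, task, alpha):
    """Held-out NLL at the one-step point estimate, and its gradient w.r.t. theta."""
    X_tr, y_tr, X_val, y_val = task
    phi_hat = theta - alpha * nll_grad(theta, X_tr, y_tr)
    # For a linear model, d(phi_hat)/d(theta) = I - alpha * X_tr^T X_tr / N.
    jac = np.eye(len(theta)) - alpha * X_tr.T @ X_tr / len(y_tr)
    return nll(phi_hat, X_val, y_val), jac.T @ nll_grad(phi_hat, X_val, y_val)

def meta_train(tasks, dim, alpha=0.1, beta=0.05, steps=200):
    """Gradient descent on theta over the summed per-task marginal NLLs."""
    theta = np.zeros(dim)
    history = []
    for _ in range(steps):
        vals, grads = zip(*(marginal_nll_and_grad(theta, t, alpha) for t in tasks))
        history.append(float(np.mean(vals)))
        theta = theta - beta * np.mean(grads, axis=0)
    return theta, history

def make_tasks(num_tasks=8, n=5, m=10, dim=3, seed=0):
    """Hypothetical tasks whose weights cluster around a shared centre."""
    rng = np.random.default_rng(seed)
    centre = np.linspace(-1.0, 1.0, dim)
    tasks = []
    for _ in range(num_tasks):
        w = centre + 0.3 * rng.normal(size=dim)
        X = rng.normal(size=(n + m, dim))
        y = X @ w + 0.1 * rng.normal(size=n + m)
        tasks.append((X[:n], y[:n], X[n:], y[n:]))
    return tasks

tasks = make_tasks()
theta_hat, history = meta_train(tasks, dim=3)
```

With a nonlinear model the Jacobian of the adaptation step is no longer constant, and the meta-gradient would typically be obtained by automatic differentiation through the inner update instead of the closed form used here.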
3.2 The Prior Over Task-Specific Parameters
If the curvature matrix $B$ is diagonal, we can identify (8) as a Newton method with a diagonal approximation to the inverse Hessian; using the inverse Hessian evaluated at the point $\phi_{(k-1)}$ recovers Newton’s method itself. On the other hand, meta-learning the matrix $B$ via gradient descent provides a method to incorporate task-general information into the covariance of the fast adaptation prior, $p(\phi \mid \theta)$. For instance, the meta-learned matrix $B$ may encode correlations between parameters that dictate how such parameters are updated relative to each other.
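Formally, taking $k$ steps of gradient descent from $\phi_{(0)} = \theta$ using the update rule in (8) gives a $\phi_{(k)}$ that solves

$$\min_{\phi} \Big( \|\phi - \phi^{*}\|_{H^{-1}}^2 + \|\phi_{(0)} - \phi\|_{Q}^2 \Big), \tag{9}$$

where $\|z\|_{Q}^2 = z^\top Q^{-1} z$, $\phi^{*}$ is a minimum of the fast adaptation objective under a second-order approximation, $H$ is the Hessian of that objective at $\phi^{*}$, and $Q$ depends on $B$, $k$, and $H$.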
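The effect of a curvature matrix on the implied prior can be illustrated on an exactly quadratic objective. The sketch below rests on several assumptions that go beyond the text: the objective is $\tfrac{1}{2}(\phi - \phi^{*})^\top H (\phi - \phi^{*})$, the matrices $B$ and $H$ share eigenvectors (which covers the diagonal case), the update is $\phi \leftarrow \phi - B H (\phi - \phi^{*})$, and the expression for $Q$ is derived for this example by matching the iterate to the regularized solution.

```python
import numpy as np

rng = np.random.default_rng(1)
d, k = 4, 3

# Shared orthonormal eigenvectors for H and B (an assumption of this sketch).
O, _ = np.linalg.qr(rng.normal(size=(d, d)))
h = rng.uniform(0.5, 2.0, size=d)      # eigenvalues of the Hessian H
b = rng.uniform(0.1, 0.9, size=d) / h  # eigenvalues of B, so 0 < b_i * h_i < 1
H = O @ np.diag(h) @ O.T
B = O @ np.diag(b) @ O.T

phi_star = rng.normal(size=d)  # minimum of the quadratic objective
theta = rng.normal(size=d)     # initialization

# k preconditioned gradient steps on 0.5 * (phi - phi_star)^T H (phi - phi_star).
phi_k = theta.copy()
for _ in range(k):
    phi_k = phi_k - B @ (H @ (phi_k - phi_star))

# Implied prior covariance in the shared eigenbasis (derived for this sketch).
q = ((1.0 - b * h) ** (-k) - 1.0) / h
Q_inv = O @ np.diag(1.0 / q) @ O.T

# Minimizer of the two-term regularized quadratic with prior mean theta.
phi_map = np.linalg.solve(H + Q_inv, H @ phi_star + Q_inv @ theta)

# Using B = H^{-1} gives Newton's method, which reaches phi_star in one step here.
phi_newton = theta - np.linalg.inv(H) @ (H @ (theta - phi_star))
```

Larger eigenvalues of $B$ shrink the corresponding factors $1 - b_i h_i$, which inflates the matching entries of `q`: the prior is loosened along directions in which the update moves quickly.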
Improving Model-Agnostic Meta-Learning
Identifying MAML as a method for probabilistic inference in a hierarchical model allows us to develop novel improvements to the algorithm. In Section 4.1, we consider an approach from Bayesian parameter estimation to improve the MAML algorithm, and in Section 4.2, we discuss how to make this procedure computationally tractable for high-dimensional models.
We have shown that the MAML algorithm is an empirical Bayes procedure that employs a point estimate for the mid-level, task-specific parameters in a hierarchical Bayesian model. However, the use of this point estimate may lead to an inaccurate point approximation of the integral in (2) if the posterior over the task-specific parameters, $p(\phi_j \mid \mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N}, \theta)$, is not sharply peaked at the value of the point estimate. The Laplace approximation (laplace1986memoir; mackay1992evidence; mackay1992practical) is applicable in this case as it replaces a point estimate of an integral with the volume of a Gaussian centered at a mode of the integrand, thereby forming a local quadratic approximation.
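We can make use of this approximation to incorporate uncertainty about the task-specific parameters into the MAML algorithm at fast adaptation time. In particular, suppose that each integrand in (2) has a mode $\phi_j^{*}$ at which it is locally well-approximated by a quadratic function. The Laplace approximation uses a second-order Taylor expansion of the negative log posterior in order to approximate each integral in the product in (2) as

$$\int p\big(\mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N} \mid \phi_j\big)\, p\big(\phi_j \mid \theta\big)\, \mathrm{d}\phi_j \approx p\big(\mathbf{x}_{j_1}, \dots, \mathbf{x}_{j_N} \mid \phi_j^{*}\big)\, p\big(\phi_j^{*} \mid \theta\big)\, \det\!\big(H_j / 2\pi\big)^{-\frac{1}{2}} \tag{10}$$

where $H_j$ is the Hessian matrix of second derivatives of the negative log posterior, evaluated at $\phi_j^{*}$.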
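As a sanity check on (10), the sketch below evaluates the Laplace approximation in a linear-Gaussian model, where the integrand is itself Gaussian and the approximation is therefore exact. The model (likelihood $\mathcal{N}(\mathbf{y}; \mathbf{X}\phi, \mathbb{I})$, prior $\mathcal{N}(\phi; \theta, Q)$ with a fixed $Q$), the synthetic data, and the helper `log_gauss` are assumptions made for illustration; the example does not show the behaviour of the approximation for nonlinear models.

```python
import numpy as np

def log_gauss(z, mean, cov):
    """Log density of a multivariate Gaussian N(z; mean, cov)."""
    r = z - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (len(z) * np.log(2 * np.pi) + logdet + r @ np.linalg.solve(cov, r))

rng = np.random.default_rng(2)
n, d = 8, 3
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + rng.normal(size=n)
theta = rng.normal(size=d)  # prior mean
Q = 0.5 * np.eye(d)         # prior covariance (fixed for this sketch)
Q_inv = np.linalg.inv(Q)

# Hessian of the negative log posterior and the posterior mode phi_star.
H = X.T @ X + Q_inv
phi_star = np.linalg.solve(H, X.T @ y + Q_inv @ theta)

# Laplace approximation: log p(y | phi*) + log p(phi* | theta) - 0.5 * log det(H / 2 pi).
_, logdet_H = np.linalg.slogdet(H / (2 * np.pi))
log_laplace = (
    log_gauss(y, X @ phi_star, np.eye(n))
    + log_gauss(phi_star, theta, Q)
    - 0.5 * logdet_H
)

# Exact marginal for the linear-Gaussian model: y ~ N(X theta, I + X Q X^T).
log_exact = log_gauss(y, X @ theta, np.eye(n) + X @ Q @ X.T)

# Point-estimate version without the determinant term, for comparison.
log_point = log_gauss(y, X @ phi_star, np.eye(n))
```

Comparing `log_point` with `log_laplace` shows what the prior and determinant terms contribute beyond the likelihood evaluated at the mode.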