Recasting Gradient-Based Meta-Learning as Hierarchical Bayes

Erin Grant, Chelsea Finn, Sergey Levine, Trevor Darrell, Thomas Griffiths

Introduction

A remarkable aspect of human intelligence is the ability to quickly solve a novel problem and to be able to do so even in the face of limited experience in a novel domain. Such fast adaptation is made possible by leveraging prior learning experience in order to improve the efficiency of later learning. This capacity for meta-learning also has the potential to enable an artificially intelligent agent to learn more efficiently in situations with little available data or limited computational resources (schmidhuber1987evolutionary; bengio1991learning; naik1992meta).

In machine learning, meta-learning is formulated as the extraction of domain-general information that can act as an inductive bias to improve learning efficiency in novel tasks (caruana1998multitask; thrun1998learning). This inductive bias has been implemented in various ways: as learned hyperparameters in a hierarchical Bayesian model that regularize task-specific parameters (heskes1998solving), as a learned metric space in which to group neighbors (bottou1992local), as a trained recurrent neural network that allows encoding and retrieval of episodic information (santoro2016meta), or as an optimization algorithm with learned parameters (schmidhuber1987evolutionary; bengio1992optimization).

The model-agnostic meta-learning (MAML) algorithm of finn2017model is an instance of a learned optimization procedure that directly optimizes the standard gradient descent rule. The algorithm estimates an initial parameter set to be shared among the task-specific models; the intuition is that gradient descent from the learned initialization provides a favorable inductive bias for fast adaptation. However, this inductive bias has been evaluated only empirically in prior work (finn2017model).

In this work, we present a novel derivation of and a novel extension to MAML, illustrating that this algorithm can be understood as inference for the parameters of a prior distribution in a hierarchical Bayesian model. The learned prior allows for quick adaptation to unseen tasks on the basis of an implicit predictive density over task-specific parameters. The reinterpretation as hierarchical Bayes gives a principled statistical motivation for MAML as a meta-learning algorithm and sheds light on the reasons for its favorable performance even among methods with significantly more parameters. More importantly, by casting gradient-based meta-learning within a Bayesian framework, we are able to improve MAML by taking insights from Bayesian posterior estimation as novel augmentations to the gradient-based meta-learning procedure. We experimentally demonstrate that this enables better performance on a few-shot learning benchmark.

Meta-Learning Formulation

The goal of a meta-learner is to extract task-general knowledge through the experience of solving a number of related tasks. By using this learned prior knowledge, the learner has the potential to quickly adapt to novel tasks even in the face of limited data or limited computation time.

Formally, we consider a dataset $\mathscr{D}$ that defines a distribution over a family of tasks $\mathcal{T}$. These tasks share some common structure such that learning to solve a single task has the potential to aid in solving another. Each task $\mathcal{T}$ defines a distribution over data points $\mathbf{x}$, which we assume in this work to consist of inputs and either regression targets or classification labels $\mathbf{y}$ in a supervised learning problem (although this assumption can be relaxed to include reinforcement learning problems; e.g., see finn2017model). The objective of the meta-learner is to be able to minimize a task-specific performance metric associated with any given unseen task from the dataset given only a small amount of data from the task; i.e., to be capable of fast adaptation to a novel task.

In the following subsections, we discuss two ways of formulating a solution to the meta-learning problem: gradient-based meta-learning and probabilistic inference in a hierarchical Bayesian model. These approaches were developed orthogonally, but, in Section 3.1, we draw a novel connection between the two.

A parametric meta-learner aims to find some shared parameters $\bm{\theta}$ that make it easier to find the right task-specific parameters $\bm{\phi}$ when faced with a novel task. A variety of meta-learners that employ gradient methods for task-specific fast adaptation have been proposed (e.g., andrychowicz2016learning; li2017alearning; li2017blearning; wichrowska2017learned). MAML (finn2017model) is distinct in that it provides a gradient-based meta-learning procedure that employs a single additional parameter (the meta-learning rate) and operates on the same parameter space for both meta-learning and fast adaptation. These are necessary features for the equivalence we show in Section 3.1.

To address the meta-learning problem, MAML estimates the parameters $\bm{\theta}$ of a set of models so that when one or a few batch gradient descent steps are taken from the initialization at $\bm{\theta}$ given a small sample of task data $\mathbf{x}_{j_{1}},\dots,\mathbf{x}_{j_{N}}\sim p_{\mathcal{T}_{j}}(\mathbf{x})$, each model has good generalization performance on another sample $\mathbf{x}_{j_{N+1}},\dots,\mathbf{x}_{j_{N+M}}\sim p_{\mathcal{T}_{j}}(\mathbf{x})$ from the same task. The objective in a maximum likelihood setting is

$$\mathcal{L}(\bm{\theta})=\frac{1}{J}\sum_{j}\left[\frac{1}{M}\sum_{m}-\log p\left(\mathbf{x}_{j_{N+m}}\mid\bm{\theta}-\alpha\nabla_{\bm{\theta}}\frac{1}{N}\sum_{n}-\log p\left(\mathbf{x}_{j_{n}}\mid\bm{\theta}\right)\right)\right] \qquad (1)$$

where we use $\bm{\phi}_{j}$ to denote the updated parameters after taking a single batch gradient descent step from the initialization at $\bm{\theta}$ with step size $\alpha$ on the negative log-likelihood associated with the task $\mathcal{T}_{j}$. Note that since $\bm{\phi}_{j}$ is an iterate of a gradient descent procedure that starts from $\bm{\theta}$, each $\bm{\phi}_{j}$ is of the same dimensionality as $\bm{\theta}$. We refer to the inner gradient descent procedure that computes $\bm{\phi}_{j}$ as fast adaptation. The computational graph of MAML is given in Figure 1 (left).
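For concreteness, the one-step objective in (1) can be sketched for a linear regression model in which the negative log-likelihood is taken to be the mean squared error (a unit-variance Gaussian with the factor of 1/2 and additive constants dropped). The task family, sizes, and helper names (`adapt`, `maml_objective`, `make_tasks`) are illustrative assumptions, not the reference MAML implementation. Because the adapted parameters are linear in $\bm{\theta}$ here, the meta-gradient through the fast adaptation step has a closed form, which the sketch checks against central finite differences.

```python
import numpy as np

def nll(phi, X, y):
    # Mean squared error: the Gaussian negative log-likelihood with unit noise
    # variance, up to a factor of 1/2 and additive constants.
    r = X @ phi - y
    return float(r @ r) / len(y)

def grad_nll(phi, X, y):
    # Gradient of nll with respect to phi.
    return 2.0 * X.T @ (X @ phi - y) / len(y)

def adapt(theta, X_s, y_s, alpha):
    # Fast adaptation: a single batch gradient step from the initialization theta.
    return theta - alpha * grad_nll(theta, X_s, y_s)

def maml_objective(theta, tasks, alpha):
    # Average over tasks of the query-set loss at the adapted parameters, as in (1).
    return float(np.mean([nll(adapt(theta, Xs, ys, alpha), Xq, yq)
                          for Xs, ys, Xq, yq in tasks]))

def maml_meta_gradient(theta, tasks, alpha):
    # Exact gradient of maml_objective. For a linear model the Jacobian of the
    # adapted parameters is d(phi)/d(theta) = I - (2 * alpha / N) X_s^T X_s.
    g = np.zeros_like(theta)
    for Xs, ys, Xq, yq in tasks:
        phi = adapt(theta, Xs, ys, alpha)
        jac = np.eye(theta.size) - 2.0 * alpha * Xs.T @ Xs / len(ys)
        g += jac.T @ grad_nll(phi, Xq, yq)
    return g / len(tasks)

def make_tasks(rng, num_tasks=8, dim=3, n_support=5, n_query=10):
    # Hypothetical task family: task weights scattered around a shared center.
    center = np.ones(dim)
    tasks = []
    for _ in range(num_tasks):
        w = center + 0.3 * rng.standard_normal(dim)
        X = rng.standard_normal((n_support + n_query, dim))
        y = X @ w + 0.1 * rng.standard_normal(n_support + n_query)
        tasks.append((X[:n_support], y[:n_support], X[n_support:], y[n_support:]))
    return tasks

rng = np.random.default_rng(0)
tasks = make_tasks(rng)
theta = np.zeros(3)
alpha = 0.05
eps = 1e-6
finite_diff = np.array([
    (maml_objective(theta + eps * e, tasks, alpha)
     - maml_objective(theta - eps * e, tasks, alpha)) / (2 * eps)
    for e in np.eye(3)
])
print(np.allclose(finite_diff, maml_meta_gradient(theta, tasks, alpha), atol=1e-6))
```

For nonlinear models the Jacobian term would normally come from automatic differentiation rather than a closed form.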

Meta-Learning as Hierarchical Bayesian Inference

An alternative way to formulate meta-learning is as a problem of probabilistic inference in the hierarchical model depicted in Figure 1 (right). In particular, in the case of meta-learning, each task-specific parameter $\bm{\phi}_{j}$ is distinct from but should influence the estimation of the parameters $\{\bm{\phi}_{j'}\mid j'\neq j\}$ from other tasks. We can capture this intuition by introducing a meta-level parameter $\bm{\theta}$ on which each task-specific parameter is statistically dependent. With this formulation, the mutual dependence of the task-specific parameters $\bm{\phi}_{j}$ is realized only through their individual dependence on the meta-level parameters $\bm{\theta}$. As such, estimating $\bm{\theta}$ provides a way to constrain the estimation of each of the $\bm{\phi}_{j}$.

Given some data in a multi-task setting, we may estimate $\bm{\theta}$ by integrating out the task-specific parameters to form the marginal likelihood of the data. Formally, grouping all of the data from each of the tasks as $\mathbf{X}$ and again denoting by $\mathbf{x}_{j_{1}},\dots,\mathbf{x}_{j_{N}}$ a sample from task $\mathcal{T}_{j}$, the marginal likelihood of the observed data is given by

$$p(\mathbf{X}\mid\bm{\theta})=\prod_{j}\left(\int p\left(\mathbf{x}_{j_{1}},\dots,\mathbf{x}_{j_{N}}\mid\bm{\phi}_{j}\right)p\left(\bm{\phi}_{j}\mid\bm{\theta}\right)\,d\bm{\phi}_{j}\right) \qquad (2)$$

Maximizing (2) as a function of $\bm{\theta}$ gives a point estimate for $\bm{\theta}$, an instance of a method known as empirical Bayes (bernardo2006bayesian; gelman2014bayesian) due to its use of the data to estimate the parameters of the prior distribution.
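As a minimal illustration of (2) and of empirical Bayes, consider a hypothetical model in which each task has a scalar parameter $\phi_j\sim\mathcal{N}(\theta,\tau^2)$ and observations $y_{jn}\mid\phi_j\sim\mathcal{N}(\phi_j,\sigma^2)$, with $\sigma^2$ and $\tau^2$ assumed known. In this case the integral in (2) has a closed form, so the sketch below checks it against direct numerical quadrature and then maximizes the summed log marginal over $\theta$. Setting the derivative to zero yields a precision-weighted mean of the observations; this closed form is specific to this toy model.

```python
import math

def log_marginal(ys, theta, sigma2, tau2):
    # log p(y_1, ..., y_N | theta) for phi ~ N(theta, tau2), y_n | phi ~ N(phi, sigma2),
    # with phi integrated out: y ~ N(theta * 1, sigma2 * I + tau2 * 1 1^T).
    # Sherman-Morrison gives the inverse and determinant of that covariance.
    n = len(ys)
    r = [y - theta for y in ys]
    s = sum(r)
    quad = (sum(v * v for v in r) - tau2 / (sigma2 + n * tau2) * s * s) / sigma2
    logdet = n * math.log(sigma2) + math.log1p(n * tau2 / sigma2)
    return -0.5 * (n * math.log(2.0 * math.pi) + logdet + quad)

def log_marginal_quadrature(ys, theta, sigma2, tau2, num=20001, width=12.0):
    # Trapezoid-rule evaluation of the integral over phi in (2), used only as a check.
    lo = theta - width * math.sqrt(tau2)
    h = 2.0 * width * math.sqrt(tau2) / (num - 1)
    total = 0.0
    for i in range(num):
        phi = lo + i * h
        log_lik = sum(-0.5 * (math.log(2.0 * math.pi * sigma2) + (y - phi) ** 2 / sigma2)
                      for y in ys)
        log_prior = -0.5 * (math.log(2.0 * math.pi * tau2) + (phi - theta) ** 2 / tau2)
        weight = 0.5 if i in (0, num - 1) else 1.0
        total += weight * math.exp(log_lik + log_prior)
    return math.log(total * h)

def empirical_bayes_theta(tasks, sigma2, tau2):
    # d/dtheta sum_j log_marginal = sum_j sum_n (y_jn - theta) / (sigma2 + N_j * tau2);
    # setting it to zero gives a precision-weighted mean of all observations.
    weights = [1.0 / (sigma2 + len(ys) * tau2) for ys in tasks]
    num = sum(w * sum(ys) for w, ys in zip(weights, tasks))
    den = sum(w * len(ys) for w, ys in zip(weights, tasks))
    return num / den

tasks = [[0.3, 1.1, 0.7], [2.0, 1.5], [-0.4, 0.1, 0.2, 0.0]]
print(log_marginal(tasks[0], 0.5, 0.25, 1.0), log_marginal_quadrature(tasks[0], 0.5, 0.25, 1.0))
print(empirical_bayes_theta(tasks, sigma2=0.25, tau2=1.0))
```

With equal task sizes the weights cancel and the estimate reduces to the grand mean of the observations.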

Hierarchical Bayesian models have a long history of use in both transfer learning and domain adaptation (e.g., lawrence2004learning; yu2005learning; gao2008knowledge; daume2009bayesian; wan2012sparse). However, the formulation of meta-learning as hierarchical Bayes does not automatically provide an inference procedure, and furthermore, there is no guarantee that inference is tractable for expressive models with many parameters such as deep neural networks.

Linking Gradient-Based Meta-Learning & Hierarchical Bayes

In this section, we connect the two independent approaches of Section 2.1 and Section 2.2 by showing that MAML can be understood as empirical Bayes in a hierarchical probabilistic model. Furthermore, we build on this understanding by showing that a choice of update rule for the task-specific parameters $\bm{\phi}_{j}$ (i.e., a choice of inner-loop optimizer) corresponds to a choice of prior over task-specific parameters, $p(\bm{\phi}_{j}\mid\bm{\theta})$.

In general, when performing empirical Bayes, the marginalization over task-specific parameters $\bm{\phi}_{j}$ in (2) is not tractable to compute exactly. To avoid this issue, we can consider an approximation that makes use of a point estimate $\hat{\bm{\phi}}_{j}$ instead of performing the integration over $\bm{\phi}$ in (2). Using $\hat{\bm{\phi}}_{j}$ as an estimator for each $\bm{\phi}_{j}$, we may write the negative logarithm of the marginal likelihood as

$$-\log p(\mathbf{X}\mid\bm{\theta})\approx\sum_{j}\left[-\log p\left(\mathbf{x}_{j_{N+1}},\dots,\mathbf{x}_{j_{N+M}}\mid\hat{\bm{\phi}}_{j}\right)\right] \qquad (3)$$

Setting $\hat{\bm{\phi}}_{j}=\bm{\theta}+\alpha\nabla_{\bm{\theta}}\log p(\mathbf{x}_{j_{1}},\dots,\mathbf{x}_{j_{N}}\mid\bm{\theta})$ for each $j$ in (3) recovers the unscaled form of the one-step MAML objective in (1). This tells us that the MAML objective is equivalent to a maximization with respect to the meta-level parameters $\bm{\theta}$ of the marginal likelihood $p(\mathbf{X}\mid\bm{\theta})$, where a point estimate for each task-specific parameter $\bm{\phi}_{j}$ is computed via one or a few steps of gradient descent. By taking only a few steps from the initialization at $\bm{\theta}$, the point estimate $\hat{\bm{\phi}}_{j}$ trades off minimizing the fast adaptation objective $-\log p(\mathbf{x}_{j_{1}},\dots,\mathbf{x}_{j_{N}}\mid\bm{\theta})$ with staying close in value to the parameter initialization $\bm{\theta}$.

We can formalize this trade-off by considering the linear regression case. Recall that the maximum a posteriori (MAP) estimate of $\bm{\phi}_{j}$ corresponds to the global mode of the posterior $p(\bm{\phi}_{j}\mid\mathbf{x}_{j_{1}},\dots,\mathbf{x}_{j_{N}},\bm{\theta})\propto p(\mathbf{x}_{j_{1}},\dots,\mathbf{x}_{j_{N}}\mid\bm{\phi}_{j})\,p(\bm{\phi}_{j}\mid\bm{\theta})$. In the case of a linear model, early stopping of an iterative gradient descent procedure to estimate $\bm{\phi}_{j}$ is exactly equivalent to MAP estimation of $\bm{\phi}_{j}$ under the assumption of a prior that depends on the number of descent steps as well as the direction in which each step is taken. In particular, write the input examples as $\mathbf{X}$ and the vector of regression targets as $\mathbf{y}$, omit the task index from $\bm{\phi}$, and consider the gradient descent update

$$\bm{\phi}_{(k)}=\bm{\phi}_{(k-1)}-\alpha\nabla_{\bm{\phi}}\left[\lVert\mathbf{y}-\mathbf{X}\bm{\phi}\rVert_{2}^{2}\right]_{\bm{\phi}=\bm{\phi}_{(k-1)}} \qquad (4)$$

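for iteration index $k$ and learning rate $\alpha\in\mathbb{R}^{+}$. santos1996equivalence shows that, starting from $\bm{\phi}_{(0)}=\bm{\theta}$, $\bm{\phi}_{(k)}$ in (4) solves the regularized linear least squares problem

$$\min_{\bm{\phi}}\left(\lVert\mathbf{y}-\mathbf{X}\bm{\phi}\rVert_{2}^{2}+\lVert\bm{\theta}-\bm{\phi}\rVert_{\mathbf{Q}}^{2}\right) \qquad (5)$$

with the $\mathbf{Q}$-norm defined by $\lVert\mathbf{z}\rVert_{\mathbf{Q}}^{2}=\mathbf{z}^{\top}\mathbf{Q}^{-1}\mathbf{z}$ for a symmetric positive definite matrix $\mathbf{Q}$ that depends on the step size $\alpha$, the iteration index $k$, and the covariance structure of $\mathbf{X}$. The minimization in (5) can be expressed as a posterior maximization problem given a conditional Gaussian likelihood over $\mathbf{y}$ and a Gaussian prior over $\bm{\phi}$:

$$p(\bm{\phi}\mid\mathbf{X},\mathbf{y},\bm{\theta})\propto\mathcal{N}(\mathbf{y};\mathbf{X}\bm{\phi},\mathbf{I})\,\mathcal{N}(\bm{\phi};\bm{\theta},\mathbf{Q}) \qquad (6)$$
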
Since $\bm{\phi}_{(k)}$ in (4) maximizes (6), we may conclude that $k$ iterations of gradient descent in a linear regression model with squared error exactly compute the MAP estimate of $\bm{\phi}$, given a Gaussian-noised observation model and a Gaussian prior over $\bm{\phi}$ with parameters $\bm{\mu}_{0}=\bm{\theta}$ and $\bm{\Sigma}_{0}=\mathbf{Q}$. Therefore, in the case of linear regression with squared error, MAML is exactly empirical Bayes using the MAP estimate as the point estimate of $\bm{\phi}$.
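The equivalence can be checked numerically. The sketch below assumes the gradient convention $\nabla_{\bm{\phi}}\lVert\mathbf{y}-\mathbf{X}\bm{\phi}\rVert_{2}^{2}=2\mathbf{X}^{\top}(\mathbf{X}\bm{\phi}-\mathbf{y})$ in (4). Under that convention, solving for the matrix that makes the two estimates agree gives $\mathbf{Q}=\mathbf{O}\,\mathrm{diag}\big(((1-2\alpha\lambda_i)^{-k}-1)/\lambda_i\big)\,\mathbf{O}^{\top}$, where $\mathbf{X}^{\top}\mathbf{X}=\mathbf{O}\,\mathrm{diag}(\lambda_i)\,\mathbf{O}^{\top}$ and $0<2\alpha\lambda_i<1$. This explicit form is derived here for illustration; its constants would change if the gradient were scaled differently, and the data are synthetic.

```python
import numpy as np

def gradient_descent(X, y, theta, alpha, k):
    # k iterations of (4), phi <- phi - alpha * 2 X^T (X phi - y), from phi_0 = theta.
    phi = theta.copy()
    for _ in range(k):
        phi = phi - alpha * 2.0 * X.T @ (X @ phi - y)
    return phi

def implied_prior_covariance(X, alpha, k):
    # Q = O diag(((1 - 2 alpha lam)^(-k) - 1) / lam) O^T, where X^T X = O diag(lam) O^T.
    # Derived for the gradient convention above; assumes 0 < 2 * alpha * lam < 1.
    lam, O = np.linalg.eigh(X.T @ X)
    r = 1.0 - 2.0 * alpha * lam
    return O @ np.diag((r ** (-k) - 1.0) / lam) @ O.T

def map_estimate(X, y, theta, Q):
    # Mode of N(y; X phi, I) N(phi; theta, Q), i.e. the minimizer of (5).
    Q_inv = np.linalg.inv(Q)
    return np.linalg.solve(X.T @ X + Q_inv, X.T @ y + Q_inv @ theta)

rng = np.random.default_rng(0)
X = rng.standard_normal((30, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.standard_normal(30)
theta = np.zeros(3)
alpha = 0.2 / np.linalg.eigvalsh(X.T @ X).max()   # keeps 2 * alpha * lam <= 0.4

for k in (1, 5, 20):
    Q = implied_prior_covariance(X, alpha, k)
    print(k, np.allclose(gradient_descent(X, y, theta, alpha, k), map_estimate(X, y, theta, Q)))
```

As $k$ grows, the entries of $\mathbf{Q}$ grow, so the implied prior weakens and the early-stopped estimate moves toward the unregularized least-squares solution.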

In the nonlinear case, MAML is again equivalent to an empirical Bayes procedure to maximize the marginal likelihood that uses a point estimate for $\bm{\phi}$ computed by one or a few steps of gradient descent. However, this point estimate is not necessarily the global mode of a posterior. We can instead understand the point estimate given by truncated gradient descent as the value of the mode of an implicit posterior over $\bm{\phi}$ resulting from an empirical loss interpreted as a negative log-likelihood, and regularization penalties and the early stopping procedure jointly acting as priors (for similar interpretations, see sjoberg1995overtraining; bishop1995regularization; duvenaud2016early).

The exact equivalence between early stopping and a Gaussian prior on the weights in the linear case, as well as the implicit regularization to the parameter initialization in the nonlinear case, tells us that every iterate of truncated gradient descent is a mode of an implicit posterior. In particular, we are not required to take the gradient descent procedure of fast adaptation that computes $\hat{\bm{\phi}}$ to convergence in order to establish a connection between MAML and empirical Bayes. MAML can therefore be understood to approximate an expectation of the marginal negative log-likelihood for each task $\mathcal{T}_{j}$ as

$$\mathbb{E}_{\mathbf{x}\sim p_{\mathcal{T}_{j}}(\mathbf{x})}\left[-\log p\left(\mathbf{x}\mid\bm{\theta}\right)\right]\approx\frac{1}{M}\sum_{m}-\log p\left(\mathbf{x}_{j_{N+m}}\mid\hat{\bm{\phi}}_{j}\right) \qquad (7)$$

using the point estimate $\hat{\bm{\phi}}_{j}=\bm{\theta}+\alpha\nabla_{\bm{\theta}}\log p(\mathbf{x}_{j_{1}},\dots,\mathbf{x}_{j_{N}}\mid\bm{\theta})$ for single-step fast adaptation.

The algorithm for MAML as probabilistic inference is given in Algorithm 2; Subroutine 3 computes each marginal using the point estimate $\hat{\bm{\phi}}$ as just described. Formulating MAML in this way, as probabilistic inference in a hierarchical Bayesian model, motivates the interpretation in Section 3.2 of using various meta-optimization algorithms to induce a prior over task-specific parameters.
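Algorithm 2 and Subroutine 3 are not reproduced here; the sketch below is one possible reading of them under simplifying assumptions: scalar task parameters, a unit-variance Gaussian likelihood, one step of fast adaptation $\hat{\phi}_j=\theta+\alpha\,\partial_\theta\log p(\mathbf{x}_{j_1},\dots,\mathbf{x}_{j_N}\mid\theta)$, and plain gradient descent on $\theta$ for the summed query-set negative log-likelihood, i.e. the point-estimate approximation (3). All function names, the task family, and the hyperparameters are hypothetical.

```python
import random

def fast_adapt(theta, support, alpha):
    # Point estimate phi_hat = theta + alpha * d/dtheta log p(support | theta) for a
    # unit-variance Gaussian likelihood, where d/dtheta log p = sum(x - theta).
    return theta + alpha * sum(x - theta for x in support)

def task_nll_and_grad(theta, support, query, alpha):
    # Query-set negative log-likelihood at phi_hat (constants dropped) and its
    # derivative with respect to theta, using d(phi_hat)/d(theta) = 1 - alpha * N.
    phi = fast_adapt(theta, support, alpha)
    residual = [x - phi for x in query]
    nll = 0.5 * sum(r * r for r in residual)
    grad = -(1.0 - alpha * len(support)) * sum(residual)
    return nll, grad

def meta_train(tasks, alpha, lr, steps, theta=0.0):
    # Outer loop: gradient descent on theta for the summed per-task NLL, which
    # approximates the negative log marginal likelihood as in (3).
    history = []
    for _ in range(steps):
        total_nll, total_grad = 0.0, 0.0
        for support, query in tasks:
            nll, grad = task_nll_and_grad(theta, support, query, alpha)
            total_nll += nll
            total_grad += grad
        history.append(total_nll)
        theta -= lr * total_grad / len(tasks)
    return theta, history

def make_tasks(rng, num_tasks=20, n_support=5, n_query=10):
    # Hypothetical task family: task means drawn around 2.0, unit observation noise.
    tasks = []
    for _ in range(num_tasks):
        mean = rng.gauss(2.0, 0.5)
        support = [rng.gauss(mean, 1.0) for _ in range(n_support)]
        query = [rng.gauss(mean, 1.0) for _ in range(n_query)]
        tasks.append((support, query))
    return tasks

rng = random.Random(0)
tasks = make_tasks(rng)
theta, history = meta_train(tasks, alpha=0.1, lr=0.2, steps=100)
print(theta, history[0], history[-1])
```

Because $\hat{\phi}_j$ is linear in $\theta$ in this toy model, the outer objective is quadratic and the loop converges to a closed-form fixed point; with the task family above, $\theta$ should end up roughly near the center of the sampled task means, up to sampling noise.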

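The Prior Over Task-Specific Parameters

Section 3.1 shows that early stopping during fast adaptation corresponds to a particular choice of prior over task-specific parameters. To see how other update rules induce other priors, return to the linear regression setting of (4) and replace the scalar step size $\alpha$ with a matrix $\mathcal{B}$:

$$\bm{\phi}_{(k)}=\bm{\phi}_{(k-1)}-\mathcal{B}\,\nabla_{\bm{\phi}}\left[\lVert\mathbf{y}-\mathbf{X}\bm{\phi}\rVert_{2}^{2}\right]_{\bm{\phi}=\bm{\phi}_{(k-1)}} \qquad (8)$$
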
If $\mathcal{B}$ is diagonal, we can identify (8) as a Newton method with a diagonal approximation to the inverse Hessian; using the inverse Hessian evaluated at the point $\bm{\phi}_{(k-1)}$ recovers Newton's method itself. On the other hand, meta-learning the matrix $\mathcal{B}$ via gradient descent provides a method to incorporate task-general information into the covariance of the prior, $p(\bm{\phi}\mid\bm{\theta})$. For instance, the meta-learned matrix $\mathcal{B}$ may encode correlations between parameters that dictate how such parameters are updated relative to each other.
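The sketch below illustrates the two limiting cases on a least-squares loss, whose Hessian is $2\mathbf{X}^{\top}\mathbf{X}$: with $\mathcal{B}$ set to the full inverse Hessian, a single step of the preconditioned update lands on the least-squares solution (Newton's method), while the diagonal approximation does so only when the Hessian is itself diagonal, i.e. when there are no correlations between parameters to account for. The data, seeds, and helper name `preconditioned_step` are illustrative assumptions.

```python
import numpy as np

def preconditioned_step(phi, X, y, B):
    # One step of phi <- phi - B * grad ||y - X phi||^2.
    return phi - B @ (2.0 * X.T @ (X @ phi - y))

rng = np.random.default_rng(1)
theta = rng.standard_normal(4)
y = rng.standard_normal(40)

# Correlated inputs: the Hessian 2 X^T X has nonzero off-diagonal entries.
X_corr = rng.standard_normal((40, 4)) @ rng.standard_normal((4, 4))
H = 2.0 * X_corr.T @ X_corr
phi_newton = preconditioned_step(theta, X_corr, y, np.linalg.inv(H))          # full inverse Hessian
phi_diag = preconditioned_step(theta, X_corr, y, np.diag(1.0 / np.diag(H)))   # diagonal approximation
phi_ls = np.linalg.lstsq(X_corr, y, rcond=None)[0]

# Orthogonal columns: the Hessian is diagonal, so both preconditioners coincide.
U, _ = np.linalg.qr(rng.standard_normal((40, 4)))
X_orth = U @ np.diag([1.0, 2.0, 0.5, 3.0])
H_orth = 2.0 * X_orth.T @ X_orth
phi_orth_diag = preconditioned_step(theta, X_orth, y, np.diag(1.0 / np.diag(H_orth)))
phi_orth_ls = np.linalg.lstsq(X_orth, y, rcond=None)[0]

print(np.allclose(phi_newton, phi_ls), np.allclose(phi_diag, phi_ls),
      np.allclose(phi_orth_diag, phi_orth_ls))
```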

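Formally, taking $k$ steps of gradient descent from $\bm{\phi}_{(0)}=\bm{\theta}$ using the update rule in (8) gives a $\bm{\phi}_{(k)}$ that solves

$$\min_{\bm{\phi}}\left(\lVert\mathbf{y}-\mathbf{X}\bm{\phi}\rVert_{2}^{2}+\lVert\bm{\theta}-\bm{\phi}\rVert_{\mathbf{Q}}^{2}\right) \qquad (9)$$

where $\lVert\mathbf{z}\rVert_{\mathbf{Q}}^{2}=\mathbf{z}^{\top}\mathbf{Q}^{-1}\mathbf{z}$ as in (5), and $\mathbf{Q}$ now depends on $\mathcal{B}$ as well as on $k$ and $\mathbf{X}$ (for a symmetric positive definite $\mathcal{B}$ small enough that every eigenvalue of $2\mathcal{B}\mathbf{X}^{\top}\mathbf{X}$ lies in $(0,1)$).
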
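The claim that $\mathcal{B}$ shapes the prior covariance can also be checked numerically. Assuming a symmetric positive definite $\mathcal{B}$ with every eigenvalue of $2\mathcal{B}\mathbf{A}$ in $(0,1)$, where $\mathbf{A}=\mathbf{X}^{\top}\mathbf{X}$, and writing $\mathbf{M}=(\mathbf{I}-2\mathcal{B}\mathbf{A})^{k}$, matching the early-stopped iterate to the minimizer of the regularized problem gives the implied prior precision $\mathbf{Q}^{-1}=\mathbf{A}\mathbf{M}(\mathbf{I}-\mathbf{M})^{-1}$. This expression is derived here for illustration rather than taken from the text, and the two preconditioners below are hypothetical.

```python
import numpy as np

def preconditioned_gd(X, y, theta, B, k):
    # k iterations of phi <- phi - B grad ||y - X phi||^2, starting from phi_0 = theta.
    phi = theta.copy()
    for _ in range(k):
        phi = phi - B @ (2.0 * X.T @ (X @ phi - y))
    return phi

def implied_prior_precision(X, B, k):
    # Q^{-1} = A M (I - M)^{-1} with A = X^T X and M = (I - 2 B A)^k.
    # Assumes B is symmetric positive definite and every eigenvalue of 2 B A is in (0, 1).
    A = X.T @ X
    I = np.eye(A.shape[0])
    M = np.linalg.matrix_power(I - 2.0 * B @ A, k)
    return A @ M @ np.linalg.inv(I - M)

def map_estimate(X, y, theta, Q_inv):
    # Minimizer of ||y - X phi||^2 + (phi - theta)^T Q^{-1} (phi - theta).
    return np.linalg.solve(X.T @ X + Q_inv, X.T @ y + Q_inv @ theta)

def scaled(B, A, target=0.4):
    # Rescale B so that the largest eigenvalue of B A equals target (so 2 B A <= 0.8).
    return B * (target / np.linalg.eigvals(B @ A).real.max())

rng = np.random.default_rng(2)
X = rng.standard_normal((30, 3))
y = rng.standard_normal(30)
theta = rng.standard_normal(3)
A = X.T @ X
L = rng.standard_normal((3, 3))
B_iso = scaled(np.eye(3), A)                    # plain gradient descent
B_corr = scaled(L @ L.T + 0.5 * np.eye(3), A)   # hypothetical correlated preconditioner
k = 4

results = {}
for name, B in [("isotropic", B_iso), ("correlated", B_corr)]:
    Q_inv = implied_prior_precision(X, B, k)
    results[name] = (preconditioned_gd(X, y, theta, B, k), map_estimate(X, y, theta, Q_inv), Q_inv)
    print(name, np.allclose(results[name][0], results[name][1]))
```

The two preconditioners start from the same initialization but reach different iterates, because each implies a different prior precision around $\bm{\theta}$.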
Improving Model-Agnostic Meta-Learning

Identifying MAML as a method for probabilistic inference in a hierarchical model allows us to develop novel improvements to the algorithm. In Section 4.1, we consider an approach from Bayesian parameter estimation to improve the MAML algorithm, and in Section 4.2, we discuss how to make this procedure computationally tractable for high-dimensional models.

We have shown that the MAML algorithm is an empirical Bayes procedure that employs a point estimate for the mid-level, task-specific parameters in a hierarchical Bayesian model. However, the use of this point estimate may lead to an inaccurate point approximation of the integral in (2) if the posterior over the task-specific parameters, $p(\bm{\phi}_{j}\mid\mathbf{x}_{j_{N+1}},\dots,\mathbf{x}_{j_{N+M}},\bm{\theta})$, is not sharply peaked at the value of the point estimate. Laplace's method (laplace1986memoir; mackay1992evidence; mackay1992practical) is applicable in this case as it replaces a point estimate of an integral with the volume of a Gaussian centered at a mode of the integrand, thereby forming a local quadratic approximation.

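We can make use of this approximation to incorporate uncertainty about the task-specific parameters into the MAML algorithm at fast adaptation time. In particular, suppose that each integrand in (2) has a mode $\bm{\phi}^{*}_{j}$ at which it is locally well-approximated by a quadratic function. Laplace's method uses a second-order Taylor expansion of the negative log posterior in order to approximate each integral in the product in (2) as

$$\int p\left(\mathbf{x}_{j_{1}},\dots,\mathbf{x}_{j_{N}}\mid\bm{\phi}_{j}\right)p\left(\bm{\phi}_{j}\mid\bm{\theta}\right)d\bm{\phi}_{j}\approx p\left(\mathbf{x}_{j_{1}},\dots,\mathbf{x}_{j_{N}}\mid\bm{\phi}^{*}_{j}\right)p\left(\bm{\phi}^{*}_{j}\mid\bm{\theta}\right)\det\left(\mathbf{H}_{j}/2\pi\right)^{-\frac{1}{2}}$$
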
where $\mathbf{H}_{j}$ is the Hessian matrix of second derivatives of the negative log posterior.
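In a linear-Gaussian model like that of (6), the integrand in (2) is exactly Gaussian in $\bm{\phi}$, so Laplace's method is exact there; this makes it a convenient sanity check for an implementation. The sketch below assumes an observation noise variance $\sigma^2$ and a prior covariance $\mathbf{Q}$ (both hypothetical inputs), computes the mode and Hessian of the negative log integrand, and compares the approximation with the closed-form marginal $\mathcal{N}(\mathbf{y};\mathbf{X}\bm{\theta},\sigma^2\mathbf{I}+\mathbf{X}\mathbf{Q}\mathbf{X}^{\top})$. For a nonlinear model the two would generally differ.

```python
import numpy as np

def log_gauss(x, mean, cov):
    # log N(x; mean, cov) for a full covariance matrix.
    d = x - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (len(x) * np.log(2.0 * np.pi) + logdet + d @ np.linalg.solve(cov, d))

def laplace_log_marginal(X, y, theta, Q, sigma2):
    # Laplace's method for log of the integral of N(y; X phi, sigma2 I) N(phi; theta, Q) dphi.
    Q_inv = np.linalg.inv(Q)
    H = X.T @ X / sigma2 + Q_inv                                      # Hessian of the negative log integrand
    phi_star = np.linalg.solve(H, X.T @ y / sigma2 + Q_inv @ theta)   # mode of the integrand
    log_peak = log_gauss(y, X @ phi_star, sigma2 * np.eye(len(y))) + log_gauss(phi_star, theta, Q)
    _, logdet = np.linalg.slogdet(H / (2.0 * np.pi))
    return log_peak - 0.5 * logdet

def exact_log_marginal(X, y, theta, Q, sigma2):
    # Closed-form marginal of the linear-Gaussian model: y ~ N(X theta, sigma2 I + X Q X^T).
    return log_gauss(y, X @ theta, sigma2 * np.eye(len(y)) + X @ Q @ X.T)

rng = np.random.default_rng(3)
X = rng.standard_normal((8, 3))
y = rng.standard_normal(8)
theta = rng.standard_normal(3)
Q = np.diag([0.5, 1.0, 2.0])
sigma2 = 0.3
print(laplace_log_marginal(X, y, theta, Q, sigma2), exact_log_marginal(X, y, theta, Q, sigma2))
```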