On First-Order Meta-Learning Algorithms

Alex Nichol, Joshua Achiam, John Schulman

Introduction

While machine learning systems have surpassed humans at many tasks, they generally need far more data to reach the same level of performance. For example, Schmidt et al. showed that human subjects can recognize new object categories based on a few example images. Lake et al. noted that on the Atari game of Frostbite, human novices were able to make significant progress on the game after 15 minutes, but double-dueling-DQN required more than 1000 times as much experience to attain the same score.

It is not completely fair to compare humans to algorithms learning from scratch, since humans enter the task with a large amount of prior knowledge, encoded in their brains and DNA. Rather than learning from scratch, they are fine-tuning and recombining a set of pre-existing skills. The work cited above, by Tenenbaum and collaborators, argues that humans’ fast-learning abilities can be explained as Bayesian inference, and that the key to developing algorithms with human-level learning speed is to make our algorithms more Bayesian. However, in practice, it is challenging to develop (from first principles) Bayesian machine learning algorithms that make use of deep neural networks and are computationally feasible.

Meta-learning has emerged recently as an approach for learning from small amounts of data. Rather than trying to emulate Bayesian inference (which may be computationally intractable), meta-learning seeks to directly optimize a fast-learning algorithm, using a dataset of tasks. Specifically, we assume access to a distribution over tasks, where each task is, for example, a classification problem. From this distribution, we sample a training set and a test set of tasks. Our algorithm is fed the training set, and it must produce an agent that has good average performance on the test set. Since each task corresponds to a learning problem, performing well on a task corresponds to learning quickly.

A variety of different approaches to meta-learning have been proposed, each with its own pros and cons. In one approach, the learning algorithm is encoded in the weights of a recurrent network, but gradient descent is not performed at test time. This approach was proposed by Hochreiter et al., who used LSTMs for next-step prediction, and has been followed up by a burst of recent work, for example, Santoro et al. on few-shot classification and Duan et al. in the POMDP setting.

A second approach is to learn the initialization of a network, which is then fine-tuned at test time on the new task. A classic example of this approach is pretraining on a large dataset (such as ImageNet) and fine-tuning on a smaller dataset (such as a dataset of different species of bird). However, this classic pretraining approach has no guarantee of learning an initialization that is good for fine-tuning, and ad-hoc tricks are required for good performance. More recently, Finn et al. proposed an algorithm called MAML, which directly optimizes performance with respect to this initialization by differentiating through the fine-tuning process. In this approach, the learner falls back on a sensible gradient-based learning algorithm even when it receives out-of-sample data, thus allowing it to generalize better than the RNN-based approaches. On the other hand, since MAML needs to differentiate through the optimization process, it’s not a good match for problems where we need to perform a large number of gradient steps at test time. The authors also proposed a variant called first-order MAML (FOMAML), which is defined by ignoring the second derivative terms. This avoids the problem but loses some gradient information. Surprisingly, though, they found that FOMAML worked nearly as well as MAML on the Mini-ImageNet dataset. (This result was foreshadowed by prior work in meta-learning that ignored second derivatives when differentiating through gradient descent, without ill effect.) In this work, we expand on that insight and explore the potential of meta-learning algorithms based on first-order gradient information. Our motivation is their applicability to problems where it’s too cumbersome to apply techniques that rely on higher-order gradients (like full MAML).

We point out that first-order MAML is simpler to implement than was widely recognized prior to this article.

We introduce Reptile, an algorithm closely related to FOMAML, which is equally simple to implement. Reptile is so similar to joint training (i.e., training to minimize loss on the expectation over training tasks) that it is especially surprising that it works as a meta-learning algorithm. Unlike FOMAML, Reptile doesn’t need a training-test split for each task, which may make it a more natural choice in certain settings. It is also related to the older idea of fast weights/slow weights.

We provide a theoretical analysis that applies to both first-order MAML and Reptile, showing that they both optimize for within-task generalization.

On the basis of empirical evaluation on the Mini-ImageNet and Omniglot datasets, we provide some insights for best practices in implementation.

Meta-Learning an Initialization

We consider the optimization problem of MAML: find an initial set of parameters, $\phi$, such that for a randomly sampled task $\tau$ with corresponding loss $L_{\tau}$, the learner will have low loss after $k$ updates. That is:

$$\underset{\phi}{\text{minimize}}\;\; \mathbb{E}_{\tau}\big[L_{\tau}\big(U^{k}_{\tau}(\phi)\big)\big]$$

where $U^{k}_{\tau}$ is the operator that updates $\phi$ $k$ times using data sampled from $\tau$. In few-shot learning, $U$ corresponds to performing gradient descent or Adam on batches of data sampled from $\tau$.
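The following sketch makes the objective concrete. It estimates $\mathbb{E}_{\tau}[L_{\tau}(U^{k}_{\tau}(\phi))]$ by Monte Carlo for a hypothetical family of one-dimensional quadratic tasks, $L_{\tau}(\phi) = \tfrac{1}{2}(\phi - c_{\tau})^{2}$ with $c_{\tau} \sim U([-1, 1])$. Here $U^{k}_{\tau}$ is taken to be $k$ steps of plain gradient descent with step size $\alpha$. The task family, the names `sample_task`, `update` and `meta_objective`, and all constants are assumptions made for illustration, not the paper's setup.

```python
import random


def sample_task(rng):
    # Hypothetical task: L_tau(phi) = 0.5 * (phi - c)^2 with c ~ U([-1, 1]).
    return rng.uniform(-1.0, 1.0)


def loss(phi, c):
    return 0.5 * (phi - c) ** 2


def grad(phi, c):
    return phi - c


def update(phi, c, k, alpha):
    # U^k_tau: k steps of gradient descent on the task loss, starting at phi.
    for _ in range(k):
        phi = phi - alpha * grad(phi, c)
    return phi


def meta_objective(phi, k, alpha, num_tasks=10_000, seed=0):
    # Monte Carlo estimate of E_tau[ L_tau(U^k_tau(phi)) ].
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_tasks):
        c = sample_task(rng)
        total += loss(update(phi, c, k, alpha), c)
    return total / num_tasks


# A centered initialization scores better than an off-center one.
print(meta_objective(0.0, k=3, alpha=0.5) < meta_objective(2.0, k=3, alpha=0.5))  # True
```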

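MAML works by optimizing this loss through stochastic gradient descent, i.e., computing

$$g_{\text{MAML}} = \frac{\partial}{\partial\phi} L_{\tau}\big(U^{k}_{\tau}(\phi)\big) = U^{k\,\prime}_{\tau}(\phi)\, L_{\tau}^{\prime}(\widetilde{\phi}), \qquad \text{where } \widetilde{\phi} = U^{k}_{\tau}(\phi).$$

Here $U^{k\,\prime}_{\tau}(\phi)$ is the Jacobian matrix of the update operator (applied in transposed form), and computing it requires differentiating through all $k$ updates. First-order MAML treats the gradients inside the update as constants, which replaces this Jacobian by the identity, so it uses $g_{\text{FOMAML}} = L_{\tau}^{\prime}(\widetilde{\phi})$.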
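To show where the Jacobian term comes from, the sketch below computes $g_{\text{MAML}}$ for a hypothetical scalar task with loss $L_{\tau}(\phi) = -\cos(\phi - c)$. Here $U^{k}_{\tau}$ is $k$ steps of plain gradient descent. For a scalar parameter, each step multiplies the Jacobian by $1 - \alpha L_{\tau}^{\prime\prime}$, so the MAML gradient is that product times $L_{\tau}^{\prime}(\widetilde{\phi})$. The first-order variant simply drops the product. A central finite difference of $L_{\tau}(U^{k}_{\tau}(\phi))$ checks the result. The loss, constants and function names are illustrative assumptions.

```python
import math


def inner_loop(phi, c, k, alpha):
    # k gradient steps on L(phi) = -cos(phi - c), accumulating d(phi_final)/d(phi_initial).
    jac = 1.0
    for _ in range(k):
        jac *= 1.0 - alpha * math.cos(phi - c)  # dU/dphi = 1 - alpha * L''(phi)
        phi = phi - alpha * math.sin(phi - c)   # U(phi) = phi - alpha * L'(phi)
    return phi, jac


def maml_grad(phi, c, k, alpha):
    phi_tilde, jac = inner_loop(phi, c, k, alpha)
    return jac * math.sin(phi_tilde - c)        # U'(phi) * L'(phi_tilde)


def fomaml_grad(phi, c, k, alpha):
    phi_tilde, _ = inner_loop(phi, c, k, alpha)
    return math.sin(phi_tilde - c)              # Jacobian replaced by the identity


def post_update_loss(phi, c, k, alpha):
    phi_tilde, _ = inner_loop(phi, c, k, alpha)
    return -math.cos(phi_tilde - c)


phi, c, k, alpha, h = 0.3, 1.2, 5, 0.1, 1e-5
numeric = (post_update_loss(phi + h, c, k, alpha)
           - post_update_loss(phi - h, c, k, alpha)) / (2 * h)
print(abs(maml_grad(phi, c, k, alpha) - numeric) < 1e-6)  # True
print(maml_grad(phi, c, k, alpha), fomaml_grad(phi, c, k, alpha))
```

The two printed gradients differ only by the Jacobian factor. FOMAML gives up that factor in exchange for never differentiating through the inner loop.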

Reptile

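In this section, we describe a new first-order gradient-based meta-learning algorithm called Reptile. Like MAML, Reptile learns an initialization for the parameters of a neural network model, such that when we optimize these parameters at test time, learning is fast, i.e., the model generalizes from a small number of examples from the test task. The Reptile algorithm is as follows:

Algorithm 1: Reptile (serial version)
Initialize $\phi$, the vector of initial parameters
for iteration $= 1, 2, \dots$ do
    Sample task $\tau$, corresponding to loss $L_{\tau}$ on weight vectors $\widetilde{\phi}$
    Compute $\widetilde{\phi} = U^{k}_{\tau}(\phi)$, denoting $k$ steps of SGD or Adam
    Update $\phi \leftarrow \phi + \epsilon(\widetilde{\phi} - \phi)$
end for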
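Here is a minimal sketch of the serial loop. It assumes a hypothetical family of noiseless two-dimensional linear-regression tasks, $y = w_{\tau} \cdot x$, whose weights $w_{\tau}$ are scattered around a shared center. The inner loop $U^{k}_{\tau}$ is $k$ plain SGD steps on fresh minibatches. The task family, helper names and every constant are assumptions made for illustration, not the paper's experimental setup.

```python
import random


def sample_task(rng, center=(1.0, -2.0)):
    # Hypothetical task: noiseless linear regression y = w_tau . x, with w_tau near a center.
    return [ci + rng.uniform(-0.5, 0.5) for ci in center]


def sgd_inner_loop(phi, w_tau, rng, k=5, alpha=0.1, batch_size=5):
    # U^k_tau: k SGD steps, each on a fresh minibatch drawn from the task.
    phi = list(phi)
    for _ in range(k):
        g = [0.0, 0.0]
        for _ in range(batch_size):
            x = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
            y = w_tau[0] * x[0] + w_tau[1] * x[1]
            err = phi[0] * x[0] + phi[1] * x[1] - y  # gradient of 0.5 * err^2 is err * x
            g = [gi + err * xi / batch_size for gi, xi in zip(g, x)]
        phi = [p - alpha * gi for p, gi in zip(phi, g)]
    return phi


def reptile(num_iterations=2000, epsilon=0.1, seed=0):
    rng = random.Random(seed)
    phi = [0.0, 0.0]                                  # initial parameter vector
    for _ in range(num_iterations):
        w_tau = sample_task(rng)                      # sample task tau
        phi_tilde = sgd_inner_loop(phi, w_tau, rng)   # phi_tilde = U^k_tau(phi)
        phi = [p + epsilon * (pt - p) for p, pt in zip(phi, phi_tilde)]
    return phi


print(reptile())
```

With these settings the returned initialization ends up close to the shared center $(1, -2)$. For this linear family, the expected update $\widetilde{\phi}-\phi$ points toward the mean task weights, so that mean is the outer loop's fixed point.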

In the last step, instead of simply updating $\phi$ in the direction $\widetilde{\phi}-\phi$, we can treat $(\phi-\widetilde{\phi})$ as a gradient and plug it into an adaptive algorithm such as Adam. (Actually, as we will discuss in Section 5.1, it is most natural to define the Reptile gradient as $(\phi-\widetilde{\phi})/\alpha$, where $\alpha$ is the step size used by the SGD operation.) We can also define a parallel or batch version of the algorithm that evaluates on $n$ tasks each iteration and updates the initialization to

$$\phi \leftarrow \phi + \epsilon\,\frac{1}{n}\sum_{i=1}^{n}\big(\widetilde{\phi}_{i} - \phi\big),$$

where $\widetilde{\phi}_{i} = U^{k}_{\tau_{i}}(\phi)$ is the updated parameter vector on the $i$-th task.
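The batched update and the $(\phi-\widetilde{\phi})/\alpha$ gradient can be sketched as follows. The sketch assumes a hypothetical quadratic loss $\tfrac{1}{2}\lVert\phi - c\rVert^{2}$ per task, with plain gradient descent as $U^{k}_{\tau}$; names and constants are illustrative. The usage lines check one identity: a batched Reptile step with step size $\epsilon$ is the same as an SGD step with learning rate $\epsilon\alpha$ on the averaged Reptile gradient. That identity is why the gradient can be handed to a different outer optimizer, such as Adam, instead.

```python
def adapt(phi, c, k=5, alpha=0.1):
    # U^k_tau for a hypothetical task with loss 0.5 * ||phi - c||^2 (plain gradient descent).
    for _ in range(k):
        phi = [p - alpha * (p - ci) for p, ci in zip(phi, c)]
    return phi


def reptile_gradient(phi, phi_tilde, alpha):
    # (phi - phi_tilde) / alpha, usable as a gradient by any outer-loop optimizer.
    return [(p - pt) / alpha for p, pt in zip(phi, phi_tilde)]


def batched_reptile_step(phi, tasks, epsilon, k=5, alpha=0.1):
    # phi <- phi + epsilon * (1/n) * sum_i (phi_tilde_i - phi)
    adapted = [adapt(phi, c, k, alpha) for c in tasks]
    n = len(adapted)
    return [p + epsilon * sum(a[j] - p for a in adapted) / n for j, p in enumerate(phi)]


phi = [0.0, 0.0]
tasks = [[1.0, 2.0], [3.0, -2.0], [-1.0, 0.0]]
alpha, epsilon = 0.1, 0.5
new_phi = batched_reptile_step(phi, tasks, epsilon, alpha=alpha)

# The same step written as SGD (learning rate epsilon * alpha) on the mean Reptile gradient.
grads = [reptile_gradient(phi, adapt(phi, c, alpha=alpha), alpha) for c in tasks]
mean_grad = [sum(g[j] for g in grads) / len(grads) for j in range(len(phi))]
sgd_phi = [p - epsilon * alpha * g for p, g in zip(phi, mean_grad)]
print(all(abs(x - y) < 1e-12 for x, y in zip(new_phi, sgd_phi)))  # True
```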

Other than the step size parameter $\epsilon$ and task sampling, the batched version of Reptile is the same as the SimuParallelSGD algorithm. SimuParallelSGD is a method for communication-efficient distributed optimization, where workers perform gradient updates locally and infrequently average their parameters, rather than the standard approach of averaging gradients.
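The sketch below illustrates the correspondence using hypothetical scalar workers with quadratic losses; the `WORKERS` constants are made up for illustration. With $\epsilon = 1$, a batched Reptile step is exactly one round of local SGD followed by parameter averaging. After several local steps, this generally differs from averaging gradients at every step.

```python
# Hypothetical workers / tasks: loss_i(phi) = 0.5 * h_i * (phi - c_i)^2 on a scalar phi.
WORKERS = [(1.0, 2.0), (4.0, -1.0), (0.5, 3.0)]  # (h_i, c_i)


def grad(phi, h, c):
    return h * (phi - c)


def local_sgd(phi, h, c, steps, lr):
    # A worker runs several SGD steps on its own loss without communicating.
    for _ in range(steps):
        phi = phi - lr * grad(phi, h, c)
    return phi


def parameter_averaging(phi, steps, lr):
    # One SimuParallelSGD-style round: local updates, then average the parameters.
    return sum(local_sgd(phi, h, c, steps, lr) for h, c in WORKERS) / len(WORKERS)


def batched_reptile(phi, steps, lr, epsilon):
    # phi + epsilon * mean_i(phi_tilde_i - phi); epsilon = 1 gives parameter averaging.
    deltas = [local_sgd(phi, h, c, steps, lr) - phi for h, c in WORKERS]
    return phi + epsilon * sum(deltas) / len(deltas)


def gradient_averaging(phi, steps, lr):
    # Standard data-parallel SGD: average the workers' gradients at every step.
    for _ in range(steps):
        phi = phi - lr * sum(grad(phi, h, c) for h, c in WORKERS) / len(WORKERS)
    return phi


phi0, steps, lr = 0.0, 4, 0.1
pa = parameter_averaging(phi0, steps, lr)
print(abs(pa - batched_reptile(phi0, steps, lr, 1.0)) < 1e-12)  # True
print(pa, gradient_averaging(phi0, steps, lr))
```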

Case Study: One-Dimensional Sine Wave Regression

As a simple case study, let’s consider the 1D sine wave regression problem, which is slightly modified from Finn et al. This problem is instructive since by design, joint training can’t learn a very useful initialization; however, meta-learning methods can.

The task $\tau=(a,b)$ is defined by the amplitude $a$ and phase $b$ of a sine wave function $f_{\tau}(x)=a\sin(x+b)$. The task distribution is defined by sampling $a\sim U([0.1,5.0])$ and $b\sim U([0,2\pi])$.
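Sampling from this task distribution takes two lines. The phase $b$ is uniform over a full period, so $\mathbb{E}_{\tau}[a\sin(x+b)] = \mathbb{E}[a]\,\mathbb{E}[\sin(x+b)] = 0$ for every $x$. A model trained to minimize expected squared error over tasks (joint training) is therefore pushed toward $f(x)=0$. The Monte Carlo check below is a sketch of this argument; the function names are illustrative.

```python
import math
import random


def sample_sine_task(rng):
    # tau = (a, b): amplitude a ~ U([0.1, 5.0]), phase b ~ U([0, 2*pi]).
    return rng.uniform(0.1, 5.0), rng.uniform(0.0, 2.0 * math.pi)


def target(task, x):
    # f_tau(x) = a * sin(x + b)
    a, b = task
    return a * math.sin(x + b)


def average_target(tasks, x):
    # Pointwise average of f_tau(x) over the sampled tasks: the squared-error
    # joint-training optimum, if the model could represent it exactly.
    return sum(target(t, x) for t in tasks) / len(tasks)


rng = random.Random(0)
tasks = [sample_sine_task(rng) for _ in range(100_000)]
for x in (-2.0, 0.0, 1.5):
    print(f"x = {x:+.1f}: mean of f_tau(x) over tasks = {average_target(tasks, x):+.3f}")
```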

Sample $p$ points $x_{1},x_{2},\dots,x_{p}\sim U()$

Learner sees $(x_{1},y_{1}),(x_{2},y_{2}),\dots,(x_{p},y_{p})$, where $y_{i}=f_{\tau}(x_{i})$, and predicts the whole function $f(x)$

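The loss is the squared error between the predicted function and $f_{\tau}$, integrated over the sampling interval. We calculate this integral using $50$ equally spaced points $x$.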
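This is a sketch of the evaluation loss. The text does not state the interval endpoints here, so `lo` and `hi` are parameters, and their default of $[-5, 5]$ is an assumption. The function returns the mean squared error over the 50 grid points, which approximates the integral up to the constant factor $(\text{hi} - \text{lo})$.

```python
import math


def grid(lo=-5.0, hi=5.0, num=50):
    # num equally spaced points covering [lo, hi], endpoints included.
    return [lo + (hi - lo) * i / (num - 1) for i in range(num)]


def sine_task_loss(predict, a, b, lo=-5.0, hi=5.0, num=50):
    # Mean squared error over the grid; multiply by (hi - lo) to approximate the integral.
    xs = grid(lo, hi, num)
    return sum((predict(x) - a * math.sin(x + b)) ** 2 for x in xs) / len(xs)


# The zero predictor (the joint-training solution) pays roughly a^2 / 2 on every task.
print(sine_task_loss(lambda x: 0.0, a=2.0, b=0.3))
```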

Joint training on this task distribution learns to output approximately $f(x)=0$ everywhere, because the random phase $b$ makes the average target zero at every $x$. MAML and Reptile also give us an initialization that outputs approximately $f(x)=0$ before training on a task $\tau$. However, the network's internal feature representations are such that, after training on the sampled datapoints $(x_{1},y_{1}),(x_{2},y_{2}),\dots,(x_{p},y_{p})$, it closely approximates the target function $f_{\tau}$. This learning progress is shown in Figure 1: after Reptile training, the network can quickly converge to a sampled sine wave and infer the values away from the sampled points. As points of comparison, we also show the behaviors of MAML and a randomly-initialized network on the same task.

Analysis

In this section, we provide two alternative explanations of why Reptile works.

Here, we will use a Taylor series expansion to approximate the update performed by Reptile and MAML. We will show that both algorithms contain the same leading-order terms. The first term minimizes the expected loss (joint training), and the second, more interesting term maximizes within-task generalization. Specifically, it maximizes the inner product between the gradients on different minibatches from the same task. If gradients from different batches have positive inner product, then taking a gradient step on one batch improves performance on the other batch.

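Unlike in the discussion and analysis of MAML, we won’t consider a training set and test set from each task. Instead, we’ll just assume that each task gives us a sequence of $k$ loss functions $L_{1},L_{2},\dots,L_{k}$, for example, classification loss on different minibatches. We will use the following definitions:

$$g_{i} = L_{i}^{\prime}(\phi_{i}) \qquad \text{(gradient obtained during SGD)}$$
$$\phi_{i+1} = \phi_{i} - \alpha g_{i} \qquad \text{(sequence of parameter vectors)}$$
$$\overline{g}_{i} = L_{i}^{\prime}(\phi_{1}) \qquad \text{(gradient at initial point)}$$
$$\overline{H}_{i} = L_{i}^{\prime\prime}(\phi_{1}) \qquad \text{(Hessian at initial point)}$$

For each of these definitions, $i\in[1,k]$.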

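First, let’s calculate the SGD gradients to $O(\alpha^{2})$ as follows:

$$\begin{aligned}
g_{i} = L_{i}^{\prime}(\phi_{i}) &= L_{i}^{\prime}(\phi_{1}) + L_{i}^{\prime\prime}(\phi_{1})(\phi_{i}-\phi_{1}) + O(\lVert\phi_{i}-\phi_{1}\rVert^{2}) \\
&= \overline{g}_{i} + \overline{H}_{i}(\phi_{i}-\phi_{1}) + O(\alpha^{2}) \\
&= \overline{g}_{i} - \alpha\overline{H}_{i}\sum_{j=1}^{i-1} g_{j} + O(\alpha^{2}) \\
&= \overline{g}_{i} - \alpha\overline{H}_{i}\sum_{j=1}^{i-1}\overline{g}_{j} + O(\alpha^{2})
\end{aligned}$$

The second and third lines use $\phi_{i}-\phi_{1} = -\alpha\sum_{j=1}^{i-1} g_{j} = O(\alpha)$. The last line replaces each $g_{j}$ by $\overline{g}_{j}$, which changes the result only by $O(\alpha^{2})$ because $g_{j}-\overline{g}_{j} = O(\alpha)$.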
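The $O(\alpha^{2})$ claim can be checked numerically. The sketch below uses hypothetical scalar minibatch losses $L_{i}(x) = -\cos(x - c_{i})$, whose derivatives are known in closed form. It runs SGD exactly and compares each $g_{i}$ with $\overline{g}_{i} - \alpha\overline{H}_{i}\sum_{j<i}\overline{g}_{j}$. If the remainder really is second order, halving $\alpha$ should divide the worst-case error by roughly four.

```python
import math

# Hypothetical minibatch losses L_i(x) = -cos(x - c_i) on a scalar parameter.
CENTERS = [-0.5, 1.0, 0.2]


def d1(x, c):  # L_i'(x)
    return math.sin(x - c)


def d2(x, c):  # L_i''(x)
    return math.cos(x - c)


def exact_sgd_gradients(phi1, alpha):
    # g_i = L_i'(phi_i) along the actual trajectory phi_{i+1} = phi_i - alpha * g_i.
    phi, grads = phi1, []
    for c in CENTERS:
        g = d1(phi, c)
        grads.append(g)
        phi -= alpha * g
    return grads


def first_order_gradients(phi1, alpha):
    # g_i ~= gbar_i - alpha * Hbar_i * sum_{j<i} gbar_j, everything evaluated at phi_1.
    gbar = [d1(phi1, c) for c in CENTERS]
    hbar = [d2(phi1, c) for c in CENTERS]
    return [gbar[i] - alpha * hbar[i] * sum(gbar[:i]) for i in range(len(CENTERS))]


def max_error(phi1, alpha):
    exact = exact_sgd_gradients(phi1, alpha)
    approx = first_order_gradients(phi1, alpha)
    return max(abs(e - a) for e, a in zip(exact, approx))


print(max_error(0.3, 0.01) / max_error(0.3, 0.005))  # close to 4 for an O(alpha^2) remainder
```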

Next, we will approximate the MAML gradient. Define $U_{i}$ as the operator that updates the parameter vector on minibatch $i$: $U_{i}(\phi)=\phi-\alpha L_{i}^{\prime}(\phi)$.

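For simplicity of exposition, let’s consider the $k=2$ case; later we’ll provide the general formulas. The three meta-gradients are

$$\begin{aligned}
g_{\text{MAML}} &= \frac{\partial}{\partial\phi_{1}} L_{2}(\phi_{2}) = \frac{\partial}{\partial\phi_{1}} L_{2}\big(U_{1}(\phi_{1})\big) = U_{1}^{\prime}(\phi_{1})\, L_{2}^{\prime}(\phi_{2}) \\
g_{\text{FOMAML}} &= L_{2}^{\prime}(\phi_{2}) = g_{2} \\
g_{\text{Reptile}} &= (\phi_{1}-\phi_{3})/\alpha = g_{1}+g_{2}
\end{aligned}$$

Since $U_{1}^{\prime}(\phi_{1}) = I - \alpha\overline{H}_{1}$, substituting the expansion of the SGD gradients gives

$$\begin{aligned}
g_{\text{MAML}} &= (I-\alpha\overline{H}_{1})(\overline{g}_{2}-\alpha\overline{H}_{2}\overline{g}_{1}) + O(\alpha^{2}) = \overline{g}_{2} - \alpha\overline{H}_{2}\overline{g}_{1} - \alpha\overline{H}_{1}\overline{g}_{2} + O(\alpha^{2}) \\
g_{\text{FOMAML}} &= \overline{g}_{2} - \alpha\overline{H}_{2}\overline{g}_{1} + O(\alpha^{2}) \\
g_{\text{Reptile}} &= \overline{g}_{1} + \overline{g}_{2} - \alpha\overline{H}_{2}\overline{g}_{1} + O(\alpha^{2})
\end{aligned}$$

The terms like $\overline{H}_{2}\overline{g}_{1}$ serve to maximize the inner products between the gradients computed on different minibatches. Lone gradient terms like $\overline{g}_{1}$ take us to the minimum of the joint training problem. To see why, take expectations over the task and the two minibatches. Because the minibatches are exchangeable, $\mathbb{E}[\overline{g}_{1}] = \mathbb{E}[\overline{g}_{2}]$ and $\mathbb{E}[\overline{H}_{2}\overline{g}_{1}] = \mathbb{E}[\overline{H}_{1}\overline{g}_{2}]$. Define

$$\text{AvgGrad} = \mathbb{E}_{\tau,1}[\overline{g}_{1}], \qquad \text{AvgGradInner} = \mathbb{E}_{\tau,1,2}[\overline{H}_{2}\overline{g}_{1}] = \tfrac{1}{2}\,\mathbb{E}_{\tau,1,2}\Big[\frac{\partial}{\partial\phi_{1}}\big(\overline{g}_{1}\cdot\overline{g}_{2}\big)\Big],$$

where the last equality uses $\frac{\partial}{\partial\phi_{1}}(\overline{g}_{1}\cdot\overline{g}_{2}) = \overline{H}_{1}\overline{g}_{2} + \overline{H}_{2}\overline{g}_{1}$. AvgGrad is the gradient of the joint-training loss. The outer loop steps in the direction $-g$, so a term $-\alpha\,\text{AvgGradInner}$ in the meta-gradient moves $\phi$ uphill on the inner product $\overline{g}_{1}\cdot\overline{g}_{2}$.

Recalling our gradient expressions, we get the following expressions for the meta-gradients, for SGD with $k=2$:

$$\begin{aligned}
\mathbb{E}[g_{\text{MAML}}] &= (1)\,\text{AvgGrad} - (2\alpha)\,\text{AvgGradInner} + O(\alpha^{2}) \\
\mathbb{E}[g_{\text{FOMAML}}] &= (1)\,\text{AvgGrad} - (\alpha)\,\text{AvgGradInner} + O(\alpha^{2}) \\
\mathbb{E}[g_{\text{Reptile}}] &= (2)\,\text{AvgGrad} - (\alpha)\,\text{AvgGradInner} + O(\alpha^{2})
\end{aligned}$$

Finally, we can extend these calculations to the general $k\geq 2$ case:

$$\begin{aligned}
\mathbb{E}[g_{\text{MAML}}] &= (1)\,\text{AvgGrad} - (2(k-1)\alpha)\,\text{AvgGradInner} + O(\alpha^{2}) \\
\mathbb{E}[g_{\text{FOMAML}}] &= (1)\,\text{AvgGrad} - ((k-1)\alpha)\,\text{AvgGradInner} + O(\alpha^{2}) \\
\mathbb{E}[g_{\text{Reptile}}] &= (k)\,\text{AvgGrad} - \big(\tfrac{1}{2}k(k-1)\alpha\big)\,\text{AvgGradInner} + O(\alpha^{2})
\end{aligned}$$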
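The per-minibatch expansions behind these expectations can be checked numerically. This sketch uses hypothetical scalar losses $L_{i}(x)=-\cos(x-c_{i})$ and illustrative names. It computes the exact $k=2$ meta-gradients and compares them with their leading-order forms. It also verifies the identity $\frac{\partial}{\partial\phi}(\overline{g}_{1}\cdot\overline{g}_{2}) = \overline{H}_{1}\overline{g}_{2}+\overline{H}_{2}\overline{g}_{1}$ by finite differences.

```python
import math

C1, C2 = -0.5, 1.0  # hypothetical minibatch losses L_i(x) = -cos(x - C_i)


def grad(x, c):  # L_i'(x)
    return math.sin(x - c)


def hess(x, c):  # L_i''(x)
    return math.cos(x - c)


def exact_meta_gradients(phi1, alpha):
    # Exact k = 2 meta-gradients: one SGD step on minibatch 1, then minibatch 2.
    g1 = grad(phi1, C1)
    phi2 = phi1 - alpha * g1
    g2 = grad(phi2, C2)
    g_maml = (1.0 - alpha * hess(phi1, C1)) * g2  # U_1'(phi_1) * L_2'(phi_2)
    g_fomaml = g2
    g_reptile = g1 + g2                           # (phi_1 - phi_3) / alpha
    return g_maml, g_fomaml, g_reptile


def leading_order_meta_gradients(phi1, alpha):
    gb1, gb2 = grad(phi1, C1), grad(phi1, C2)
    hb1, hb2 = hess(phi1, C1), hess(phi1, C2)
    return (gb2 - alpha * hb2 * gb1 - alpha * hb1 * gb2,
            gb2 - alpha * hb2 * gb1,
            gb1 + gb2 - alpha * hb2 * gb1)


def max_error(phi1, alpha):
    exact = exact_meta_gradients(phi1, alpha)
    approx = leading_order_meta_gradients(phi1, alpha)
    return max(abs(e - a) for e, a in zip(exact, approx))


def inner_product(x):
    # gbar_1 . gbar_2 as a function of the initial point
    return grad(x, C1) * grad(x, C2)


phi1, step = 0.3, 1e-6
numeric = (inner_product(phi1 + step) - inner_product(phi1 - step)) / (2 * step)
analytic = hess(phi1, C1) * grad(phi1, C2) + hess(phi1, C2) * grad(phi1, C1)
print(max_error(phi1, 0.02) / max_error(phi1, 0.01))  # close to 4 for an O(alpha^2) remainder
print(numeric, analytic)
```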

Finding a Point Near All Solution Manifolds

Here, we argue that Reptile converges towards a solution $\phi$ that is close (in Euclidean distance) to each task $\tau$’s manifold of optimal solutions. This is an informal argument and should be taken much less seriously than the preceding Taylor series analysis.

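Let $\phi$ denote the network initialization, and let $\mathcal{W}_{\tau}$ denote the set of optimal parameters for task $\tau$. We want to find $\phi$ such that the distance $D(\phi,\mathcal{W}_{\tau})$ is small for all tasks:

$$\underset{\phi}{\text{minimize}}\;\; \mathbb{E}_{\tau}\Big[\tfrac{1}{2} D(\phi,\mathcal{W}_{\tau})^{2}\Big]$$

We will show that Reptile corresponds to performing SGD on that objective.

For a non-pathological set $\mathcal{S} \subset \mathbb{R}^{d}$ and almost all points $\phi$, the gradient of the squared distance is $\nabla_{\phi}\,\tfrac{1}{2} D(\phi,\mathcal{S})^{2} = \phi - P_{\mathcal{S}}(\phi)$, where $P_{\mathcal{S}}(\phi)$ is the projection (closest point) of $\phi$ onto $\mathcal{S}$. Each iteration of Reptile corresponds to sampling a task $\tau$ and performing a stochastic gradient update

$$\phi \leftarrow \phi - \epsilon\nabla_{\phi}\,\tfrac{1}{2} D(\phi,\mathcal{W}_{\tau})^{2} = \phi - \epsilon\big(\phi - P_{\mathcal{W}_{\tau}}(\phi)\big) = (1-\epsilon)\phi + \epsilon P_{\mathcal{W}_{\tau}}(\phi).$$

In practice, we can’t exactly compute $P_{\mathcal{W}_{\tau}}(\phi)$, which is defined as the minimizer of $L_{\tau}$ closest to $\phi$. However, we can partially minimize this loss using gradient descent. Hence, in Reptile we replace $P_{\mathcal{W}_{\tau}}(\phi)$ by the result of running $k$ steps of gradient descent on $L_{\tau}$ starting with initialization $\phi$.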
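The projection argument can be watched directly in a toy case, offered here as an illustrative sketch. Each hypothetical task's solution set $\mathcal{W}_{\tau}$ is a line $a\cdot w = b$ in the plane, with loss $L_{\tau}(w) = \tfrac{1}{2}(a\cdot w - b)^{2}$. Gradient descent on this loss moves only along $a$, so it converges exactly to $P_{\mathcal{W}_{\tau}}(\phi)$, and $k$ steps give a partial projection. To keep the result deterministic, the sketch averages the update over all three tasks each iteration (the batched variant) rather than sampling one task at a time.

```python
import math

# Hypothetical tasks: W_tau is the line a . w = b in the plane (a has unit norm),
# and L_tau(w) = 0.5 * (a . w - b)^2 is minimized exactly on that line.
S = 1.0 / math.sqrt(2.0)
TASKS = [((1.0, 0.0), 0.0),          # W_1: x = 0
         ((0.0, 1.0), 0.0),          # W_2: y = 0
         ((S, S), math.sqrt(2.0))]   # W_3: x + y = 2


def residual(w, a, b):
    return a[0] * w[0] + a[1] * w[1] - b


def project(w, a, b):
    # Exact Euclidean projection P_W(w) onto the line (valid because |a| = 1).
    r = residual(w, a, b)
    return (w[0] - r * a[0], w[1] - r * a[1])


def gradient_descent(w, a, b, k, alpha):
    # k steps of gradient descent on L_tau: Reptile's stand-in for P_W(w).
    for _ in range(k):
        r = residual(w, a, b)
        w = (w[0] - alpha * r * a[0], w[1] - alpha * r * a[1])
    return w


def reptile(phi=(3.0, 1.0), iterations=200, epsilon=0.5, k=5, alpha=0.5):
    for _ in range(iterations):
        # Batched update averaged over all three tasks (deterministic).
        adapted = [gradient_descent(phi, a, b, k, alpha) for a, b in TASKS]
        phi = tuple(p + epsilon * sum(w[j] - p for w in adapted) / len(adapted)
                    for j, p in enumerate(phi))
    return phi


def mean_half_sq_distance(phi):
    # E_tau[0.5 * D(phi, W_tau)^2] over the three tasks (distance = |residual| here).
    return sum(0.5 * residual(phi, a, b) ** 2 for a, b in TASKS) / len(TASKS)


phi = reptile()
print(phi)  # close to (0.5, 0.5)
print(gradient_descent((3.0, 1.0), *TASKS[2], k=60, alpha=0.5),
      project((3.0, 1.0), *TASKS[2]))  # both close to (2.0, 0.0)
```

The three lines bound a triangle, and the iterate settles at $(0.5, 0.5)$. That point minimizes the mean of $\tfrac{1}{2}D(\phi,\mathcal{W}_{\tau})^{2}$ even though each inner loop stops after $k=5$ steps. In this toy case, every task's partial projection is scaled by the same factor $1-(1-\alpha)^{k}$, which does not move the fixed point.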

Experiments

We evaluate our method on two popular few-shot classification tasks: Omniglot and Mini-ImageNet. These datasets make it easy to compare our method to other few-shot learning approaches like MAML.

In few-shot classification tasks, we have a meta-dataset $D$ containing many classes $C$, where each class is itself a set of example instances $\{c_{1},c_{2},\dots,c_{n}\}$. If we are doing $K$-shot, $N$-way classification, then we sample tasks by selecting $N$ classes from $C$ and then selecting $K+1$ examples for each class. We split these examples into a training set and a test set, where the test set contains a single example for each class. The model gets to see the entire training set, and then it must classify a randomly chosen sample from the test set. For example, if you trained a model for 5-shot, 5-way classification, then you would show it 25 examples (5 per class) and ask it to classify a 26th example.
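The sampling procedure can be sketched as follows. The meta-dataset is a plain dictionary from class names to examples. Relabeling the chosen classes $0,\dots,N-1$ within each task is our assumption, and the names and toy data are illustrative.

```python
import random


def sample_task(meta_dataset, num_classes, num_shots, rng):
    # Pick N classes, then K + 1 distinct examples from each class.
    classes = rng.sample(sorted(meta_dataset), num_classes)
    train, test = [], []
    for label, cls in enumerate(classes):
        examples = rng.sample(meta_dataset[cls], num_shots + 1)
        # Labels are reassigned 0..N-1 inside each task (an assumption).
        train += [(x, label) for x in examples[:num_shots]]
        test.append((examples[num_shots], label))  # one held-out example per class
    rng.shuffle(train)
    return train, test


rng = random.Random(0)
toy = {f"class_{i}": [f"c{i}_ex{j}" for j in range(20)] for i in range(20)}
train, test = sample_task(toy, num_classes=5, num_shots=5, rng=rng)
query = rng.choice(test)  # non-transductive: classify one test example at a time
print(len(train), len(test))  # 25 5
```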

In addition to the above setup, we also experimented with the transductive setting, where the model classifies the entire test set at once. In our transductive experiments, information was shared between the test samples via batch normalization. In our non-transductive experiments, batch normalization statistics were computed using all of the training samples and a single test sample. We note that Finn et al. use transduction for evaluating MAML.
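The difference between the two evaluation modes comes down to which samples contribute batch-normalization statistics. This sketch uses pure Python, per-feature statistics and illustrative names. It assumes the transductive mode computes statistics over the test batch alone; the text only says that information is shared between test samples.

```python
def batch_stats(rows):
    # Per-feature mean and (biased) variance over a list of feature vectors.
    n, dim = len(rows), len(rows[0])
    mean = [sum(r[d] for r in rows) / n for d in range(dim)]
    var = [sum((r[d] - mean[d]) ** 2 for r in rows) / n for d in range(dim)]
    return mean, var


def normalize(row, mean, var, eps=1e-5):
    return [(x - m) / (v + eps) ** 0.5 for x, m, v in zip(row, mean, var)]


def non_transductive(train_rows, test_rows):
    # Each test sample is normalized with statistics from all training samples plus
    # itself, so test samples cannot influence one another.
    outputs = []
    for row in test_rows:
        mean, var = batch_stats(train_rows + [row])
        outputs.append(normalize(row, mean, var))
    return outputs


def transductive(test_rows):
    # Statistics come from the whole test batch (assumed), so every test sample
    # affects how the others are normalized.
    mean, var = batch_stats(test_rows)
    return [normalize(row, mean, var) for row in test_rows]


train = [[0.0, 1.0], [2.0, 3.0], [4.0, 1.0]]
test_a = [[1.0, 2.0], [3.0, 0.0]]
test_b = [[1.0, 2.0], [9.0, 9.0]]  # same first sample, different second sample
print(non_transductive(train, test_a)[0] == non_transductive(train, test_b)[0])  # True
print(transductive(test_a)[0] == transductive(test_b)[0])                        # False
```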

For our experiments, we used the same CNN architectures and data preprocessing as Finn et al. We used the Adam optimizer in the inner loop, and vanilla SGD in the outer loop, throughout our experiments. For Adam we set $\beta_{1}=0$ because we found that momentum reduced performance across the board. This finding also matches our analysis from Section 5.1, which suggests that Reptile works because sequential steps come from different mini-batches. With momentum, a mini-batch has influence over the next few steps, reducing this effect. During training, we never reset or interpolated Adam’s rolling moment data; instead, we let it update automatically at every inner-loop training step. However, we did back up and reset the Adam statistics when evaluating on the test set to avoid information leakage.
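The following sketch shows the optimizer handling described above. It pairs a minimal hand-written Adam (not any library's implementation) with $\beta_{1}=0$, whose moment estimates persist across inner-loop steps. A hypothetical `evaluate` helper backs up that state before test-time fine-tuning and restores it afterwards.

```python
import copy
import math


class Adam:
    """Minimal Adam for a flat list of floats; beta1 = 0 turns off momentum."""

    def __init__(self, num_params, lr=1e-3, beta1=0.0, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * num_params  # first moment (just the latest gradient when beta1 = 0)
        self.v = [0.0] * num_params  # rolling average of squared gradients
        self.t = 0                   # step counter for bias correction

    def step(self, params, grads):
        self.t += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g
            m_hat = self.m[i] / (1.0 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1.0 - self.beta2 ** self.t)
            new_params.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return new_params

    def state(self):
        return copy.deepcopy((self.m, self.v, self.t))

    def load_state(self, state):
        self.m, self.v, self.t = copy.deepcopy(state)


def evaluate(optimizer, params, grad_fn, steps):
    # Back up the rolling moments, fine-tune on the test task, then restore them
    # so nothing seen at evaluation time leaks into training.
    saved = optimizer.state()
    for _ in range(steps):
        params = optimizer.step(params, grad_fn(params))
    optimizer.load_state(saved)
    return params


opt = Adam(num_params=2, lr=0.1)
params = [1.0, -1.0]
for _ in range(3):  # training-time inner-loop steps keep updating the moments
    params = opt.step(params, [2.0 * p for p in params])
before = opt.state()
adapted = evaluate(opt, params, lambda ps: [2.0 * (p - 3.0) for p in ps], steps=5)
print(opt.state() == before)  # True
```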

The results on Omniglot and Mini-ImageNet are shown in Tables 1 and 2. While MAML, FOMAML, and Reptile have very similar performance on all of these tasks, Reptile does slightly better than the alternatives on Mini-ImageNet and slightly worse on Omniglot. It also seems that transduction gives a performance boost in all cases, suggesting that further research should pay close attention to how batch normalization is used during testing.

Comparing Different Inner-Loop Gradient Combinations

For this experiment, we used four non-overlapping mini-batches in each inner loop, yielding gradients $g_{1}$, $g_{2}$, $g_{3}$, and $g_{4}$. We then compared learning performance when using different linear combinations of the $g_{i}$’s for the outer-loop update. Note that two-step Reptile corresponds to $g_{1}+g_{2}$, and two-step FOMAML corresponds to $g_{2}$.
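The bookkeeping for this experiment can be sketched in three steps. First, run SGD over four mini-batches. Second, record each $g_{i}$ at the point where it was computed. Third, form the outer-loop direction as a weighted sum. The quadratic mini-batch losses, the weight vectors in `COMBINATIONS`, and the helper names are illustrative assumptions.

```python
def make_minibatch_grad(center):
    # Hypothetical mini-batch loss 0.5 * ||phi - center||^2, whose gradient is phi - center.
    return lambda phi: [p - c for p, c in zip(phi, center)]


def inner_loop_gradients(phi, grad_fns, alpha):
    # Plain SGD over the mini-batches in order, recording g_i where it was computed.
    grads = []
    for grad_fn in grad_fns:
        g = grad_fn(phi)
        grads.append(g)
        phi = [p - alpha * gi for p, gi in zip(phi, g)]
    return grads, phi


def combine(grads, weights):
    # Outer-loop direction: sum_i weights[i] * g_i.
    return [sum(w * g[j] for w, g in zip(weights, grads)) for j in range(len(grads[0]))]


COMBINATIONS = {
    "g2 (two-step FOMAML)": [0, 1, 0, 0],
    "g1 + g2 (two-step Reptile)": [1, 1, 0, 0],
    "g4 (four-step FOMAML)": [0, 0, 0, 1],
    "g1 + g2 + g3 + g4 (four-step Reptile, sum)": [1, 1, 1, 1],
    "mean of g1..g4 (four-step Reptile, average)": [0.25, 0.25, 0.25, 0.25],
}

phi0, alpha = [0.0, 0.0], 0.5
fns = [make_minibatch_grad(c) for c in ([1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0])]
grads, _ = inner_loop_gradients(phi0, fns, alpha)
for name, weights in COMBINATIONS.items():
    print(name, combine(grads, weights))
```

The averaged combination corresponds to the step-size correction described next: it has the same direction as the sum but a quarter of its length.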

To make it easier to get an apples-to-apples comparison between different linear combinations, we simplified our experimental setup in several ways. First, we used vanilla SGD in the inner- and outer-loops. Second, we did not use meta-batches. Third, we restricted our experiments to 5-shot, 5-way Omniglot. With these simplifications, we did not have to worry as much about the effects of hyper-parameters or optimizers.

Figure 3 shows the learning curves for various inner-loop gradient combinations. For gradient combinations with more than one term, we ran both a sum and an average of the inner gradients; the average corrects for the larger effective step size that comes from summing several gradients.

Overlap Between Inner-Loop Mini-Batches

Both Reptile and FOMAML use stochastic optimization in their inner-loops. Small changes to this optimization procedure can lead to large changes in final performance. This section explores the sensitivity of Reptile and FOMAML to the inner loop hyperparameters, and also shows that FOMAML’s performance significantly drops if mini-batches are selected the wrong way.

Figure 4(b) shows a similar phenomenon, but here we fixed the inner-loop to four iterations and instead varied the batch size. For batch sizes greater than 25, the final inner-loop batch for shared-tail FOMAML necessarily contains samples from the previous batches. Similar to Figure 4(a), here we observe that shared-tail FOMAML with random sampling degrades more gradually than shared-tail FOMAML with cycling.
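One reading of the two FOMAML variants is sketched below; it is an assumption, not the paper's exact procedure. Both variants draw their gradient-step mini-batches by cycling through the task's 25 training examples (5-shot, 5-way). The final mini-batch supplies the outer-loop gradient. In shared-tail FOMAML it comes from the same examples, and in separate-tail FOMAML from a disjoint pool. The sketch counts how many examples of the final batch were already used for inner-loop steps. Names and the held-out pool are hypothetical.

```python
import random


def cycling_batches(examples, num_batches, batch_size, rng):
    # Form mini-batches by cycling through a shuffled copy of the examples.
    order = list(examples)
    rng.shuffle(order)
    stream = (order[i % len(order)] for i in range(num_batches * batch_size))
    return [[next(stream) for _ in range(batch_size)] for _ in range(num_batches)]


def fomaml_batches(train, held_out, num_iters, batch_size, rng, tail="shared"):
    # The first num_iters - 1 batches drive inner-loop steps; the last one supplies the
    # FOMAML outer gradient ("shared": same examples, "separate": disjoint pool).
    if tail == "shared":
        batches = cycling_batches(train, num_iters, batch_size, rng)
        return batches[:-1], batches[-1]
    steps = cycling_batches(train, num_iters - 1, batch_size, rng)
    final = cycling_batches(held_out, 1, batch_size, rng)[0]
    return steps, final


def overlap(step_batches, final_batch):
    # Number of distinct examples in the final batch already used for inner steps.
    seen = {x for batch in step_batches for x in batch}
    return len(seen.intersection(final_batch))


rng = random.Random(0)
train = [f"train_{i}" for i in range(25)]    # 5-shot, 5-way: 25 training examples
held_out = [f"held_{i}" for i in range(25)]  # hypothetical disjoint pool
for size in (5, 30):
    shared = overlap(*fomaml_batches(train, held_out, 4, size, rng, "shared"))
    separate = overlap(*fomaml_batches(train, held_out, 4, size, rng, "separate"))
    print(size, shared, separate)  # 5 0 0, then 30 25 0
```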

In both of these parameter sweeps, separate-tail FOMAML and Reptile do not degrade in performance as the number of inner-loop iterations or batch size changes.

There are several possible explanations for the above findings. For example, one might hypothesize that shared-tail FOMAML is only worse in these experiments because its effective step size is much lower than that of separate-tail FOMAML. However, Figure 4(c) suggests that this is not the case: performance was equally poor for every choice of step size in a thorough sweep. A different hypothesis is that shared-tail FOMAML performs poorly because, after a few inner-loop steps on a sample, the gradient of the loss for that sample does not contain very much useful information about the sample. In other words, the first few SGD steps might bring the model close to a local optimum, and then further SGD steps might simply bounce around this local optimum.

Discussion

Meta-learning algorithms that perform gradient descent at test time are appealing because of their simplicity and generalization properties. The effectiveness of fine-tuning (e.g., from models trained on ImageNet) gives us additional faith in these approaches. This paper proposed a new algorithm called Reptile, whose training process is only subtly different from joint training and only uses first-order gradient information (like first-order MAML).

We gave two theoretical explanations for why Reptile works. First, by approximating the update with a Taylor series, we showed that SGD automatically gives us the same kind of second-order term that MAML computes. This term adjusts the initial weights to maximize the dot product between the gradients of different minibatches on the same task—i.e., it encourages the gradients to generalize between minibatches of the same task. We also provided a second informal argument, which is that Reptile finds a point that is close (in Euclidean distance) to all of the optimal solution manifolds of the training tasks.

While this paper studies the meta-learning setting, the Taylor series analysis in Section 5.1 may have some bearing on stochastic gradient descent in general. It suggests that when doing stochastic gradient descent, we are automatically performing a MAML-like update that maximizes the generalization between different minibatches. This observation partly explains why fine-tuning (e.g., from ImageNet to a smaller dataset) works well. This hypothesis would suggest that joint training plus fine-tuning will continue to be a strong baseline for meta-learning in various machine learning problems.

Future Work

We see several promising directions for future work:

Understanding to what extent SGD automatically optimizes for generalization, and whether this effect can be amplified in the non-meta-learning setting.

Applying Reptile in the reinforcement learning setting. So far, we have obtained negative results, since joint training is a strong baseline; some modifications to Reptile might be necessary.

Exploring whether Reptile’s few-shot learning performance can be improved by deeper architectures for the classifier.

Exploring whether regularization can improve few-shot learning performance, as currently there is a large gap between training and testing error.

Evaluating Reptile on the task of few-shot density modeling.

References

Appendix A Hyper-parameters

For all experiments, we linearly annealed the outer step size to 0. We ran each experiment with three different random seeds, and computed the confidence intervals using the standard deviation across the runs.
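This sketch covers the two bookkeeping details above. The text does not say whether the standard deviation is the $n-1$ sample estimate, so using `statistics.stdev` is an assumption. The function names and toy values are also illustrative.

```python
import statistics


def annealed_step_size(initial, iteration, total_iterations):
    # Outer-loop step size, decayed linearly from `initial` to 0 over training.
    return initial * (1.0 - iteration / total_iterations)


def summarize(per_seed_results):
    # Mean and standard deviation across runs with different random seeds.
    # statistics.stdev is the n - 1 sample estimate; this choice is an assumption.
    return statistics.mean(per_seed_results), statistics.stdev(per_seed_results)


schedule = [annealed_step_size(1.0, t, 4) for t in range(5)]
print(schedule)                    # [1.0, 0.75, 0.5, 0.25, 0.0]
print(summarize([0.5, 0.6, 0.7]))  # toy values standing in for three seeds
```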

Initially, we tried optimizing the Reptile hyper-parameters using CMA-ES . However, we found that most hyper-parameters had little effect on the resulting performance. After seeing this result, we simplified all of the hyper-parameters and shared hyper-parameters between experiments when it made sense.