In this post I reproduce two recent papers in the field of metalearning: MAML and the similar Reptile. The full notebook for this reproduction can be found here.

The goal of both of these papers is to solve the K-shot learning problem. In K-shot learning, we need to train a neural network to generalize from a very small number of examples (often on the order of 10 or so) rather than the thousands of examples (or more) we see in datasets like ImageNet.

However, in preparation for K-shot learning, you are allowed to train on many similar K-shot problems to learn the best way to generalize based on only K examples.

This is learning to learn or metalearning. We have already seen metalearning in my post on “Learning to Learn by Gradient Descent by Gradient Descent”, which you can find here:

The metalearning approach of both Reptile and MAML is to come up with an initialization for neural networks that is easily generalizable to similar tasks. This is different to “Learning to Learn by Gradient Descent by Gradient Descent” in which we weren’t learning an initialization but rather an optimizer.

Transfer Learning

This approach is very similar to transfer learning, in which we train a network on, say, ImageNet, and it later turns out that fine-tuning this network makes it easy to learn another image dataset with much less data. Transfer learning can in fact be seen as a form of metalearning, and it can be used to learn from very small datasets, as you can see here.

The difference here is that the initial network was trained with the explicit purpose of being easily generalizable, whereas transfer learning just “accidentally” happens to work, and thus might not work optimally.

Indeed, it is fairly easy to find a problem in which transfer learning fails to learn a good initialization. For this we need to look at the 1D sine wave regression problem.

In this K-shot problem, each task consists in learning a modified sine function. Specifically, for each task, the underlying function will be of the form y = a sin(x + b), with both a and b chosen randomly, and the goal of our neural network is to learn to find y given x based on only 10 (x, y) pairs.
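To make the task definition concrete, here is a minimal sketch of a task sampler. The class name `SineTask` and the ranges for a, b and x are my assumptions (the text above only says a and b are chosen randomly); this is not the notebook's code.

```python
import numpy as np

class SineTask:
    """One K-shot regression task: y = a * sin(x + b)."""

    def __init__(self, rng):
        # Assumed ranges, not stated in the text above.
        self.a = rng.uniform(0.1, 5.0)         # amplitude
        self.b = rng.uniform(0.0, 2 * np.pi)   # phase

    def f(self, x):
        return self.a * np.sin(x + self.b)

    def sample(self, rng, k=10):
        """Draw k (x, y) pairs with x uniform in [-5, 5] (assumed range)."""
        x = rng.uniform(-5.0, 5.0, size=k)
        return x, self.f(x)

rng = np.random.default_rng(0)
task = SineTask(rng)
x_train, y_train = task.sample(rng)   # the 10 pairs the learner gets to see
```

Each new `SineTask` is a different function, and the learner only ever sees the 10 sampled pairs for it.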

Let’s plot a couple of example sine wave tasks:

3 random tasks

To understand why this is going to be a problem for transfer learning, let’s plot 1,000 of them:

1,000 random tasks

Looks like there is a lot of overlap at each x value, to say the least…

Since there are multiple possible values for each x across multiple tasks, if we train a single neural net to deal with multiple tasks at the same time, its best bet will simply be to return the average y value across all tasks for each x. So what is the average y value for each x?

Average value for each x, with random task shown for scale

The average is basically 0, which means a neural network trained on a lot of tasks would simply return 0 everywhere! It is unclear that this will actually help very much, and yet this is the transfer learning approach in this case…
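This is easy to check numerically. The sketch below assumes amplitudes in [0.1, 5] and phases drawn uniformly over a full period (both ranges are my assumptions). With a full-period phase the expected value of sin(x + b) is exactly zero at every x, so the empirical average should be close to zero too.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tasks = 10_000
xs = np.linspace(-5.0, 5.0, 50)

# Assumed task distribution: amplitude in [0.1, 5], phase over a full period.
a = rng.uniform(0.1, 5.0, size=(n_tasks, 1))
b = rng.uniform(0.0, 2 * np.pi, size=(n_tasks, 1))

ys = a * np.sin(xs + b)      # shape (n_tasks, 50), one row per task
mean_y = ys.mean(axis=0)     # average target at each x

print(np.abs(mean_y).max())  # small compared with amplitudes of up to 5
```

Note that this depends on the phase range: if b only covered half a period, the average would no longer vanish.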

Let’s see how well it does by actually implementing a simple model to solve these sine wave tasks and training it using transfer learning. First, the model itself:

You’ll notice that it is implemented in a weird way (what’s a “ModifiableModule”? What’s a “GradLinear”?). This is because we will later train it using MAML. For details on what some of these classes are, check out the notebook, but for now you can assume they are similar to nn.Module and nn.Linear.

Now, let’s train it for a while on a bunch of different random tasks in sequence:
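As a rough sketch of what such a baseline can look like (this is not the notebook's code), the loop below trains one network on a fresh random task at every step. It uses a plain `nn.Sequential` as a stand-in for the model; the layer sizes, the optimizer, the learning rate and the task ranges are all assumptions on my part.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)

def sample_task():
    """Return f(x) = a * sin(x + b) with assumed ranges for a and b."""
    a = torch.empty(1).uniform_(0.1, 5.0)
    b = torch.empty(1).uniform_(0.0, 2 * math.pi)
    return lambda x: a * torch.sin(x + b)

# Plain nn.Module stand-in for the post's model (layer sizes are assumptions).
model = nn.Sequential(
    nn.Linear(1, 40), nn.ReLU(),
    nn.Linear(40, 40), nn.ReLU(),
    nn.Linear(40, 1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(1000):
    f = sample_task()                            # a new random task each step
    x = torch.empty(10, 1).uniform_(-5.0, 5.0)   # K = 10 points
    loss = nn.functional.mse_loss(model(x), f(x))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Nothing in this loop tells the network which task it is looking at, so the only thing it can consistently reduce is the error against the average target.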

And here’s what happens when we try to fine-tune this transfer model to a specific random task:

Transfer learning on a specific random task

Basically it looks like our transfer model learns a constant function and that it is really hard to fine tune it to something better. It’s not even clear that our transfer learning is any better than random initialization… And indeed it isn’t! A random initialization ends up getting a better loss over time than fine tuning our transfer model.

Learning curve for transfer learning vs random initialization

MAML

We now come to MAML, the first of the two algorithms we will look at today.

As mentioned before, we are trying to find a set of weights such that running gradient descent on similar tasks makes progress as quickly as possible. MAML takes this extremely literally by running one iteration of gradient descent and then updating the initial weights based on how much progress that one iteration made towards the true task. More concretely it:

Creates a copy of the initialization weights

Runs an iteration of gradient descent for a random task on the copy

Backpropagates the loss on a test set through the iteration of gradient descent and back to the initial weights, so that we can update the initial weights in a direction in which they would have been easier to update.

We thus need to take a gradient of a gradient, aka a second-order derivative, in this process. Fortunately, this is something that PyTorch now supports. Unfortunately, PyTorch makes it a bit awkward to update the parameters of a model in a way that we can still run gradient descent through them (we already saw this in “Learning to Learn by Gradient Descent by Gradient Descent”), which explains the weird way in which the model is written.

Because we are going to use second derivatives, we need to make sure that the computational graph that allowed us to compute the original gradients stays around, which is why we pass create_graph=True to .backward() .
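Here is a minimal sketch of one way to write such a meta-update. It is not the notebook's implementation: instead of ModifiableModule and `.backward(create_graph=True)`, it keeps the weights in a plain list and uses `torch.autograd.grad`, which accepts the same `create_graph` flag. The helper names, layer sizes, learning rates and task ranges are assumptions.

```python
import math
import torch

torch.manual_seed(0)

def init_params(hidden=40):
    """Weights and biases of a small MLP, stored as a flat list of tensors."""
    sizes = [(1, hidden), (hidden, hidden), (hidden, 1)]
    params = []
    for n_in, n_out in sizes:
        w = torch.randn(n_in, n_out) / math.sqrt(n_in)
        params += [w.requires_grad_(), torch.zeros(n_out, requires_grad=True)]
    return params

def forward(params, x):
    """Run the MLP with an explicit parameter list instead of module state."""
    h = x
    for i in range(0, len(params), 2):
        h = h @ params[i] + params[i + 1]
        if i < len(params) - 2:
            h = torch.relu(h)
    return h

def mse(params, x, y):
    return ((forward(params, x) - y) ** 2).mean()

def maml_step(params, meta_opt, inner_lr=0.01):
    # Random task y = a * sin(x + b); the ranges are assumptions.
    a = torch.empty(1).uniform_(0.1, 5.0)
    b = torch.empty(1).uniform_(0.0, 2 * math.pi)
    f = lambda x: a * torch.sin(x + b)
    x_train = torch.empty(10, 1).uniform_(-5, 5)
    x_test = torch.empty(10, 1).uniform_(-5, 5)

    # Inner step: create_graph=True keeps the graph of the gradient itself.
    train_loss = mse(params, x_train, f(x_train))
    grads = torch.autograd.grad(train_loss, params, create_graph=True)
    adapted = [p - inner_lr * g for p, g in zip(params, grads)]

    # Outer step: the test loss is backpropagated through the inner update.
    meta_loss = mse(adapted, x_test, f(x_test))
    meta_opt.zero_grad()
    meta_loss.backward()
    meta_opt.step()
    return meta_loss.item()

params = init_params()
meta_opt = torch.optim.Adam(params, lr=1e-3)
losses = [maml_step(params, meta_opt) for _ in range(200)]
```

The key point is that `adapted` is built from `params` with differentiable operations, so the test loss can flow back to the initialization.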

So how does it work on a specific random task?

Training on a random function using MAML

Wow, that’s much better: even after a single step of gradient descent the sine shape starts being visible, and after 10 steps the center of the wave is almost fully correct. Is this reflected in the learning curve? Yes!

Unfortunately, it is a bit annoying that we have to use second order derivatives for this… it forces the code to be complicated and it also makes things a fair bit slower (around 33% according to the paper, which matches what we shall see here).

Is there an approximation of MAML that doesn’t use the second order derivatives? Of course! We can simply pretend that the gradients that we used for the inner gradient descent just came out of nowhere, and thus just improve the initial parameters without taking into account these second order derivatives. Let’s add a first_order parameter to our MAML training function to deal with this:
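To see exactly what the approximation throws away, here is a small worked sketch in NumPy rather than the notebook's PyTorch function. It assumes a linear model on the features sin(x) and cos(x) as a stand-in for the neural network, so that the gradient and the second derivative can be written by hand. The function names and the value of alpha are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def features(x):
    # a*sin(x + b) = a*cos(b)*sin(x) + a*sin(b)*cos(x), so two features suffice.
    return np.stack([np.sin(x), np.cos(x)], axis=1)

def grad(w, phi, y):
    """Gradient of the mean squared error of the linear model phi @ w."""
    return 2.0 / len(y) * phi.T @ (phi @ w - y)

def meta_grad(w, train, test, alpha=0.05, first_order=False):
    (phi_tr, y_tr), (phi_te, y_te) = train, test
    w_adapted = w - alpha * grad(w, phi_tr, y_tr)       # one inner SGD step
    g_test = grad(w_adapted, phi_te, y_te)
    if first_order:
        return g_test                                   # inner gradient treated as a constant
    hessian = 2.0 / len(y_tr) * phi_tr.T @ phi_tr       # second derivative of the train loss
    return (np.eye(len(w)) - alpha * hessian) @ g_test  # chain rule through the inner step

a, b = 2.0, 1.0
x_tr, x_te = rng.uniform(-5, 5, 10), rng.uniform(-5, 5, 10)
train = (features(x_tr), a * np.sin(x_tr + b))
test = (features(x_te), a * np.sin(x_te + b))
w = np.zeros(2)

full = meta_grad(w, train, test)
approx = meta_grad(w, train, test, first_order=True)
```

The only difference between `full` and `approx` is the factor involving the Hessian, which is the part that requires second-order derivatives.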

So how good is this first order approximation? Almost as good as the original MAML, as it turns out, and it is indeed about 33% faster.

Learning curve of MAML vs MAML first order

Reptile

The first order approximation for MAML tells us that something interesting is going on: after all, it seems like how the gradients were generated should be relevant for a good initialization, and yet it apparently isn’t so much.

Reptile takes this idea even further by telling us to do the following: run SGD for a few iterations on a given task, and then move your initialization weights a little bit in the direction of the weights you obtained after your k iterations of SGD. An algorithm so simple, it takes only a couple lines of pseudocode:
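In PyTorch, a minimal sketch of that update could look like the following. The hyperparameters, layer sizes and task ranges are assumptions, and for simplicity every inner step reuses the same 10 points instead of drawing fresh minibatches. Since no second derivatives are involved, a plain `nn.Sequential` is enough here.

```python
import copy
import math
import torch
from torch import nn

torch.manual_seed(0)

def reptile_step(model, k=5, inner_lr=0.01, meta_lr=0.1):
    # Sample a task y = a * sin(x + b); the ranges are assumptions.
    a = torch.empty(1).uniform_(0.1, 5.0)
    b = torch.empty(1).uniform_(0.0, 2 * math.pi)
    x = torch.empty(10, 1).uniform_(-5.0, 5.0)
    y = a * torch.sin(x + b)

    # Run k steps of plain SGD on a copy of the initialization.
    fast = copy.deepcopy(model)
    opt = torch.optim.SGD(fast.parameters(), lr=inner_lr)
    for _ in range(k):
        opt.zero_grad()
        nn.functional.mse_loss(fast(x), y).backward()
        opt.step()

    # Move the initialization a fraction meta_lr of the way toward the result.
    with torch.no_grad():
        for p, p_fast in zip(model.parameters(), fast.parameters()):
            p += meta_lr * (p_fast - p)

model = nn.Sequential(
    nn.Linear(1, 40), nn.ReLU(),
    nn.Linear(40, 40), nn.ReLU(),
    nn.Linear(40, 1),
)
for _ in range(200):
    reptile_step(model)
```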

When I first read this, I was quite consternated: isn’t this the same as training your weights alternately on each task, just like in transfer learning? How would this ever work?

Indeed, the Reptile paper anticipates this very reaction:

You might be thinking “isn’t this the same as training on the expected loss Eτ [Lτ]?” and then checking if the date is April 1st.

As it happens, I am writing this on April 2nd, so this is all serious. So what’s going on?

Well, if we ran SGD for a single iteration, we would indeed have something equivalent to the transfer learning described above. But we don’t: we run a few iterations, and so the weights we update towards each time actually depend indirectly on the second derivatives of the loss, similar to MAML.
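A tiny numerical sketch makes the difference between one and two inner steps visible. It assumes a linear model on sin(x) and cos(x) as a stand-in for the network, with full-batch gradient steps on a single fixed task, so the loss is quadratic and the expressions in the comments hold exactly. All names and values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-5, 5, 10)
phi = np.stack([np.sin(x), np.cos(x)], axis=1)   # linear stand-in model
y = 2.0 * np.sin(x + 1.0)

def grad(w):
    """Gradient of the mean squared error of phi @ w on this task."""
    return 2.0 / len(y) * phi.T @ (phi @ w - y)

hessian = 2.0 / len(y) * phi.T @ phi
alpha, w0 = 0.05, np.zeros(2)

def reptile_direction(k):
    """(w0 - w_k) / alpha after k full-batch gradient steps."""
    w = w0.copy()
    for _ in range(k):
        w = w - alpha * grad(w)
    return (w0 - w) / alpha

g = grad(w0)
one_step = reptile_direction(1)    # equals g: the plain gradient of the task loss
two_steps = reptile_direction(2)   # equals 2*g - alpha * hessian @ g
```

With a single step the update direction is just the task gradient, while from the second step onwards a term involving the second derivative appears.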

Ok, but still, why would this work? Well Reptile provides a compelling intuition for this: for each task, there are weights that are optimal. Indeed, there are probably many sets of weights that are optimal. This means that if you take several tasks, there should be a set of weights for which the distance to at least one optimal set of weights for each task is minimal. This set of weights is where we want to initialize our networks, since it is likely to be the one for which the least work is necessary to reach the optimum for any task. This is the set of weights that Reptile finds.

We can see this expressed visually in the following image: the two black lines represent the sets of optimal weights for two different tasks, while the gray line represents the initialization weights. Reptile tries to get the initialization weights closer and closer to the point where the optimal weights are nearest to each other.

Let’s now implement Reptile and compare it to MAML:

How does it look on a random problem? Beautiful:

Reptile performance on a random problem

What about the learning curve?

Learning curve of Reptile, MAML and MAML first-order

It looks like Reptile does indeed achieve performance similar to, or even slightly better than, MAML with a much simpler and slightly faster algorithm!

All this applies to many more problems than just this toy example of sine waves. For more details, I really do recommend you read the papers. At this point, you should have enough background to understand them quite easily.

What would be really interesting in the future would be to apply these approaches not just to the K-shot learning problem but to bigger problems: transfer learning has been extremely successful in the image classification world for training models based on medium-sized datasets (a few hundred or a few thousand examples, as opposed to the 10 or so that is common in K-shot learning). Would training a ResNet using Reptile produce something that is even more amenable to transfer learning than the models we have now?