I’ve spent most of 2018 training neural networks that push the limits of my GPUs. Whether it was a 150-million-parameter language model like OpenAI’s huge Generative Pre-trained Transformer (or the recent and similar BERT model) or a meta-learning neural net fed with 30-million-element inputs like the one from our ICLR ‘18 paper, I could barely fit more than a few training samples on a GPU.

But most of the time stochastic gradient descent algorithms require larger batches than just a handful of examples to get decent results.

How can you train your model on large batches when your GPU can’t hold more than a few samples?

There are several tools, tips and tricks you can use to do that, and I thought it would be nice to gather everything I use and have learned in one post.

In this post I will mainly talk about the PyTorch framework. Some of these tools are not in PyTorch yet (as of 1.0) so I include some custom code as well.

In particular, we’ll talk about:

How you can train a model on a single- or multi-GPU server with batches larger than the GPUs’ memory, or when even a single training sample won’t fit (!),

How you can make the most efficient use of a multi-GPU machine, and

The simplest way to train a model using several machines in a distributed setup.

Let’s start with the simplest trick: gradient accumulation.

⌛️Large batches on one or several GPU(s)

So, you’ve built a nice model that might be the new SOTA on this neat task, but every time you try to stack more than a few samples in a batch you get a CUDA RuntimeError: out of memory.

Adam confirms your predicament! 😱 Oh no!

But you’re pretty sure that doubling the batch size will improve the results.

How can you do that?

There is an easy solution to this problem: accumulating gradients. Here is a quick reminder on how stochastic gradient descent works from my earlier post on meta-learning:

The 5 steps of a gradient descent optimization algorithm

The PyTorch code equivalent of these 5 steps can also be written in 5 lines:
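A minimal sketch of those 5 lines might look like this (model, loss_function, optimizer, inputs and labels are placeholder names for your own objects):

```python
predictions = model(inputs)                # 1. forward pass: compute the model's predictions
loss = loss_function(predictions, labels)  # 2. compute the loss from predictions and targets
loss.backward()                            # 3. backward pass: compute gradients into parameter.grad
optimizer.step()                           # 4. update the parameters using the gradients
optimizer.zero_grad()                      # 5. reset the gradients before the next iteration
```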

During the loss.backward() operation, gradients are computed for each parameter (in green on our animation) and stored in a tensor associated with each parameter: parameter.grad (the middle graph in our animation).

Accumulating gradients simply means that, before calling optimizer.step() to perform a step of gradient descent, we sum the gradients of several backward operations in the parameter.grad tensors. This is straightforward to do in PyTorch because the gradient tensors are not reset unless we call model.zero_grad() or optimizer.zero_grad(). We also need to divide by the number of accumulation steps if our loss is averaged over the training samples.

Here is a simple example of a training loop using gradient accumulation. With it, we can train with a batch size that is accumulation_steps times larger than the maximum size that fits on our GPU(s):
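One way to write such a loop, as a rough sketch (training_set, loss_function, accumulation_steps and optimizer stand in for your own data loader, loss and settings):

```python
model.zero_grad()                                   # reset gradient tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # forward pass
    loss = loss_function(predictions, labels)       # compute the loss
    loss = loss / accumulation_steps                # normalize the loss if it is averaged over samples
    loss.backward()                                 # backward pass: gradients accumulate in parameter.grad
    if (i + 1) % accumulation_steps == 0:           # wait until several backward passes have been summed
        optimizer.step()                            # one optimizer step on the accumulated gradients
        model.zero_grad()                           # reset gradient tensors for the next accumulation window
```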

Grzegorz Chlebus made a nice post describing how to do gradient accumulation in TensorFlow, check it out here.

😱 Pushing that to the extreme

Can you train a model for which not even a single sample can fit on a GPU?

Well, if your architecture doesn’t have too many skip connections, yes, it’s possible! The solution is to trade compute for memory using gradient checkpointing.

Basically, the idea is to back-propagate the gradients in small chunks along the model, trading the memory needed to store a full back-propagation graph for the additional compute of a partial forward pass associated with each chunk. This is a rather slow method, as we add extra compute to reduce the memory requirements, but it can be interesting in some settings, e.g. to train RNN models over very long sequences (see for example my previous introduction to meta-learning).

I won’t go into more details here and will just refer you to the relevant links:

A “Memory-poor” strategy that needs O(1) memory (but requires O(n²) computation steps) — From Yaroslav Bulatov’s nice post: https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
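To give a quick taste of what this looks like in practice, PyTorch ships a checkpointing helper for purely sequential models in torch.utils.checkpoint. Here is a minimal sketch (the layer sizes, number of layers and number of segments are arbitrary examples; the number of segments controls the memory/compute trade-off):

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

# A toy, purely sequential model that would normally store activations for all 100 layers
model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(100)])

inputs = torch.randn(8, 1024, requires_grad=True)  # the input must require grad so checkpointing can track it
segments = 4                                       # number of chunks the model is split into

# Only the activations at the segment boundaries are kept during the forward pass;
# the intermediate activations are recomputed chunk by chunk during the backward pass.
output = checkpoint_sequential(model, segments, inputs)
output.sum().backward()
```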

🕰 Making the best of a multi-GPU machine

Now let’s talk more specifically about training a model on multiple GPUs.

The go-to strategy to train a PyTorch model on a multi-GPU server is to use torch.nn.DataParallel. It’s a container which parallelizes the application of a module by splitting the input across the specified devices, chunking along the batch dimension.

DataParallel is very easy to use; we just need one line to encapsulate the model:
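A minimal sketch of what this looks like (model, inputs, labels, loss_function and optimizer are assumed to already be defined):

```python
import torch

parallel_model = torch.nn.DataParallel(model)  # the one extra line: wrap the model
parallel_model = parallel_model.cuda()         # move the parameters to the GPUs

predictions = parallel_model(inputs.cuda())       # forward pass: the batch is split across the visible GPUs
loss = loss_function(predictions, labels.cuda())  # outputs are gathered and the loss computed on GPU-0
loss.backward()                                   # gradients flow back and accumulate on the original model
optimizer.step()
```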

However one issue can arise with DataParallel: unbalanced GPU usage.

Under some settings GPU-1 will be used a lot more than the other GPUs.

Where does this come from? I made an illustration to better explain what DataParallel does under the hood: