This post assumes you have a CS231n-ish level of understanding of neural networks (i.e. you have taken a university-level introductory course on deep learning). If you are completely new to neural nets, I highly recommend exactly that course as a perfect resource to quickly get up to speed. Honestly, it was while following those lectures that I developed most of the intuitions about CNNs that I still rely on every day. (Thanks Andrej! Thanks Justin!)

The first image that comes up if you google for batch renormalization. Source: http://bit.ly/2mXnNCq

If you’re like me, you enjoy throwing CNNs at every pictorial problem that comes your way. You feel confident explaining to your MBA friends how¹ neural nets work and you like to complain with your CS² friends about the price tag of Nvidia’s new GPUs. If you’re like me, then you have heard of BatchNorm. You have probably used it³, googled for an explanation of what internal covariate shift means and, somewhat satisfied, returned to your daily business of tuning some more hyperparameters.

Now, chances are you haven’t heard of batch renormalization⁴, Ioffe’s follow-up paper.

The tl;dr for most people reads something like this:

BatchRenorm is superior to BatchNorm and already implemented in TensorFlow, but comes at the cost of a few extra hyperparameters to tune. If your batch size is very small (say 2 or 4, most likely due to constrained GPU memory), you should probably use it.

Below, I’ll try to give you a refresher on how BatchNorm is related to transfer learning, why you should be just a little paranoid when using BatchNorm (i.e. how it can break down in unexpected ways) and how BatchRenorm will help you go back to sleeping like a baby, but only — to stretch the simile a bit — if you are willing to do a few extra push-ups before going to bed.

Transfer Learning

The transfer learning scenario you are most familiar with is probably this: you have a CNN pre-trained on ImageNet that you now want to use to distinguish between your left-foot socks and your right-foot socks (or something of the sort). Depending on how much data you have, you unfreeze the last few layers or re-train the whole CNN on your very own sock dataset; so far, so good. But there are other transfer learning scenarios that are, if not as omnipresent, just as important⁵.
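In code, that recipe looks roughly like the sketch below (Keras, with ResNet50 and a two-class sock head purely as placeholders; swap in whatever backbone and labels you actually have):

```python
import tensorflow as tf

# Load an ImageNet-pretrained backbone without its classification head.
base = tf.keras.applications.ResNet50(
    weights="imagenet", include_top=False, pooling="avg")
base.trainable = False  # freeze everything; unfreeze the last layers if you have more data

# A hypothetical two-class "left sock vs. right sock" head on top.
model = tf.keras.Sequential([
    base,
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(sock_images, sock_labels, ...)  # your very own sock dataset
```

With enough sock data, you would unfreeze the top of the backbone as well and fine-tune it with a smaller learning rate.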

A train and test set that would benefit from some domain adaptation magic.

When, instead of generalizing from one task to the next (like classifying socks instead of dogs and cats), we want our model to generalize from a source domain to a target domain with a different data distribution, this is called domain adaptation. To get an intuitive understanding of why this might be a difficult problem, consider two datasets with only cats and dogs. Both include your normal variety of cats, but the first only has dogs that are brown. If we are unlucky, our model trained on the first dataset will not learn differences such as pointy vs. floppy ears, but only that if it sees a brown thing, it must be a dog. If so, we shouldn’t be surprised when our classification accuracy takes a nosedive on the second dataset⁶.

The takeaway is a universal truth of machine learning: if your data distribution changes under your nose, you are probably in trouble.

BatchNorm is not having any of it

Back to BatchNorm and googling what “reducing internal covariate shift” means. “Internal covariate shift” is just a fancy term for the fact that the input (“data”) distributions of intermediate layers of a neural network change during training. This is not surprising: the input of an intermediate layer is simply the output of the layer before it, and as the parameters of this “pre”-layer get updated over time, its output distribution changes too.

Instead of trying to find a clever “internal domain adaptation” technique, Ioffe and Szegedy’s ingenious solution to this problem of changing input distributions is to simply sidestep it: they use BatchNorm to force every layer input to be normalized, and voilà, no more mess of shape-shifting distributions.

For quick reference, here is the algorithm.

BatchNorm algorithm — during training the inputs are normalized over each mini-batch. The scale and shift at the end is meant to give the model some flexibility to unlearn the normalization if necessary. During inference the inputs are normalized using a moving average of the mini-batch means and variances seen during training. Source: Original BatchNorm paper by I-S.
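To make the caption above concrete, here is a minimal NumPy sketch of the forward pass for 2-D inputs; all the names are mine, not any framework’s:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, running_mean, running_var,
                      momentum=0.99, eps=1e-5, training=True):
    # x has shape (batch, features); gamma and beta are the learnable
    # scale and shift from the last line of the algorithm.
    if training:
        mu = x.mean(axis=0)                    # mini-batch mean
        var = x.var(axis=0)                    # mini-batch variance
        x_hat = (x - mu) / np.sqrt(var + eps)  # normalize over the mini-batch
        # Keep moving averages around for inference time.
        running_mean = momentum * running_mean + (1 - momentum) * mu
        running_var = momentum * running_var + (1 - momentum) * var
    else:
        # Inference: normalize with the moving averages instead.
        x_hat = (x - running_mean) / np.sqrt(running_var + eps)
    return gamma * x_hat + beta, running_mean, running_var
```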

It turns out that using BatchNorm also makes your model more robust to less careful weight initialization and larger learning rates⁷. And another goodie:

I-S report that the noise introduced by computing the mean and variance over each mini-batch instead of over the entire training set⁸ isn’t just bad news, but acts as regularization and can remove the need to add extra dropout layers.

Why BatchNorm should make you paranoid

You know what I hate? When my code compiles and my model trains, but for some well-hidden reason the model’s performance is much worse than expected. Unfortunately, under certain circumstances BatchNorm can be that well-hidden reason. To understand when this happens, I highly recommend reading Alex Irpan’s post on the perils of BatchNorm. In any case, here is my executive summary:

When the mini-batch mean (µB) and mini-batch standard deviation (σB) diverge too often from the mean and standard deviation over the entire training set, BatchNorm breaks. Remember that at inference time we use the moving averages of µB and σB (as an estimate of the statistics of the entire training set) to do the normalization step. Naturally, if your means and standard deviations during training and testing differ, so do your activations, and you can’t be surprised if your results are different (read: worse), too. This can happen when your mini-batch samples are non-i.i.d. (or in plain language: when your sampling procedure is biased; think first sampling only brown dogs and then sampling only black dogs) or, more commonly, when your batch size is very small⁹. In both cases: welcome back to “shape-shifting distributions”-land.
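To see the small-batch half of the problem in numbers, here is a toy experiment: the smaller the batch, the further the per-batch means stray from the true mean that the moving average is trying to estimate.

```python
import numpy as np

rng = np.random.default_rng(0)
# Pretend these are the activations of one unit over the whole training set,
# with true mean 0 and true std 1.
activations = rng.normal(loc=0.0, scale=1.0, size=100_000)

for batch_size in (4, 32, 256):
    batches = activations[:batch_size * 100].reshape(100, batch_size)
    spread = batches.mean(axis=1).std()  # how far batch means stray from 0
    print(f"batch size {batch_size:>3}: per-batch means have std {spread:.3f}")
```

The spread shrinks roughly with the square root of the batch size, which is exactly why the batch statistics you normalize with at batch size 2 or 4 are such unreliable stand-ins for the moving averages you will use at inference.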

Enter BatchRenorm

BatchRenorm tackles this issue of differing statistics at train and inference time head-on. The key insight to bridge the difference is this:

Source: Ioffe’s BatchRenorm paper.

The normalization step at inference time (using estimates of the training set statistics µ and σ) can actually be rewritten as an affine transformation of the normalization step at training time (using mini-batch statistics µB and σB)! And that’s basically all there is to it. Using mini-batch statistics plus this affine correction at train time and moving averages at inference time ensures that the output of BatchRenorm is the same during both phases, even when σB != σ and µB != µ.
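If you want to convince yourself of that identity, here is a quick numerical check (the moving-average estimates µ and σ are made-up numbers for the sake of the example):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(2.0, 3.0, size=8)      # one activation across a mini-batch

mu_b, sigma_b = x.mean(), x.std()     # mini-batch statistics
mu, sigma = 1.5, 2.5                  # stand-ins for the moving-average estimates

r = sigma_b / sigma                   # the scale of the affine correction...
d = (mu_b - mu) / sigma               # ...and its shift

# Normalizing with the (estimated) training set statistics equals an affine
# transformation of the mini-batch-normalized values:
assert np.allclose((x - mu) / sigma, (x - mu_b) / sigma_b * r + d)
```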

Here is the algorithm in its non-Japanese entirety.

The stop_gradient operation is there to ensure that r and d are treated as constants during the gradient computation¹⁰. r_max and d_max are two hyperparameters introduced to control the transition from BatchNorm to BatchRenorm. In the paper they set r_max = 1 and d_max = 0 (which is equivalent to using plain BatchNorm) for the first 5k steps and then slowly increase both values. Think of it as waiting until the moving averages σ and µ have warmed up before using batch renormalization. Source: Ibid.
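Translated into a minimal NumPy sketch (training step only; at inference you normalize with the moving averages exactly as in BatchNorm above), it looks like this. The clipping of r and d follows the paper; the names are mine:

```python
import numpy as np

def batchrenorm_forward(x, gamma, beta, running_mean, running_std,
                        r_max, d_max, momentum=0.99, eps=1e-5):
    # x has shape (batch, features).
    mu_b = x.mean(axis=0)
    sigma_b = np.sqrt(x.var(axis=0) + eps)
    # r and d bridge the gap between mini-batch and moving-average statistics.
    # In a real framework both would sit behind a stop_gradient.
    r = np.clip(sigma_b / running_std, 1.0 / r_max, r_max)
    d = np.clip((mu_b - running_mean) / running_std, -d_max, d_max)
    x_hat = (x - mu_b) / sigma_b * r + d
    # The moving averages are updated just as in BatchNorm.
    running_mean = momentum * running_mean + (1 - momentum) * mu_b
    running_std = momentum * running_std + (1 - momentum) * sigma_b
    return gamma * x_hat + beta, running_mean, running_std
```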

An interesting thing to note is that batch renormalization is really just a generalization of BatchNorm and reverts to its predecessor when σB == σ and µB == µ (i.e. r = 1 and d = 0). This leads us to the question: when should you use BatchRenorm, and when is BatchNorm enough?

BatchRenorm ≥ BatchNorm?

The good news is that, in terms of model performance, you can count on BatchRenorm to always match or beat BatchNorm.

A head-to-head comparison of training an Inception-v3 model first with BatchNorm and then BatchRenorm using (a) batch size 32 and (b) batch size 4. Source: Ibid.

However, using BatchRenorm comes at the added cost of two hyperparameters (discussed in the caption beneath the BatchRenorm algorithm) for which you have to find the right schedule to get the best performance. So there you have it. As with most things in life, it is a trade-off between your time and your model’s performance. If you have outsourced your hyperparameter tuning to something like Bayesian optimization¹¹, it’s at least still a trade-off between computing resources and performance.
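If you want something to copy-paste as a starting point, a schedule in the spirit of the paper’s could look like the sketch below. The 5k-step warm-up is from the paper; the ramp lengths and end values here are illustrative and worth tuning:

```python
def renorm_schedule(step, warmup=5_000, r_ramp=35_000, d_ramp=20_000):
    # Pure BatchNorm (r_max = 1, d_max = 0) during warm-up...
    if step < warmup:
        return 1.0, 0.0
    # ...then relax the clipping bounds linearly, here towards
    # r_max = 3 and d_max = 5.
    r_max = 1.0 + 2.0 * min(1.0, (step - warmup) / r_ramp)
    d_max = 5.0 * min(1.0, (step - warmup) / d_ramp)
    return r_max, d_max
```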

Personally, I will be using BatchRenorm with the fixed schedule mentioned in the paper¹² from now on. If I have very small batch sizes (or some weird mini-batch sampling curiosity, as in Irpan’s post), I might bring myself to do some hyperparameter tuning of my own.

Let me know how it goes for you.