I recently came across an ICML’17 paper, “Attentive Recurrent Comparators”, which proposes a simple yet powerful model for data-efficient learning. The paper presents the first super-human One-shot Classification performance on the Omniglot dataset using only raw pixel information!

In this blog post I am going to present my understanding of the main ideas of the paper. The authors also released an implementation in Theano, but I found it a bit difficult to follow. So, as an exercise, I went ahead and re-implemented it in PyTorch.

ARC comparing two similar characters from the Omniglot dataset. The ARC uses an attention mechanism to look back and forth between the two images and judge their similarity. Source: https://github.com/sanyam5/arc-pytorch

Motivation

Data-efficient machine learning is the buzz phrase right now. The idea is to design ML algorithms that perform well but don’t require hundreds of thousands of annotated data points. An ideal ML algorithm would require just one data point to learn an entire concept. Such a concept learning system is said to be a One-shot Learning system. The main challenge in One-shot Learning is the difficulty of forming Dynamic Representations.

Imagine a system trained on images of buildings, cars and animals. If the new concepts of fruits and faces are introduced, the current feature set, consisting of wheels, fur, windshields, trunks, etc., is virtually useless for identifying them. The authors call the representations that are formed by observing a fixed set of features Static Representations. It is clear that the current feature set must evolve to include features that can recognize the new entities. Ideally, the model should be in control of what features it observes and how lower-level features combine to form higher-level features. The authors call the representations formed by observing a continually evolving set of features Dynamic Representations.

Dynamic Representations

Dynamic Representation is basically coming up with a feature set on the fly, lazily. When you see an apple, you start thinking of its color and shape. When you see a face, you start thinking about the color of the eyes, the shape of the nose, and so on.

One clever way of having Dynamic Representations is by encoding a given sample in the context of other samples. For example:

When asked to differentiate between fruits like Apple, Orange and Guava, you try to form a representation of one fruit in terms of all the others: how is Apple different from Orange and Guava? Similarly, when asked to differentiate between faces, you try to understand how Face A is different from the other faces in the dataset.

ARCs

The paper first presents a simple model — “Attentive Recurrent Comparators” (ARCs) — for learning to differentiate between two given images. The model derives its motivation from observing how humans find points of difference between two given images.

“Spot the difference”. Source https://commons.wikimedia.org/wiki/File:Spot_the_difference.png

A human trying to differentiate between a pair of images will not try to understand everything about the first image before taking a look at the second image. That is just too much data to process all at once, most of which will be irrelevant to the task. The human instead takes alternating looks at the two images to figure out what to focus on.

You might notice something in the first image simply because it is not present in the other image, or vice versa. ARCs incorporate this behavior by using attention. The attention mechanism opens a pathway for a Neural Network to “ask” for a portion of the data.

The ARC architecture makes clever use of the attention mechanism. At every time step it takes “glimpses”, alternating between the first image (Image A) and the second image (Image B), similar to the way a human plays ‘Spot The Difference’.

Source: Attentive Recurrent Comparators, ICML’17

At the heart of ARC is a “controller”: a Recurrent Neural Network (RNN) which, at every time step t, takes as input a glimpse G(t) along with its previous hidden state h(t-1) and moves to a new hidden state h(t). The way this glimpse is generated is interesting. A small Neural Network (not shown in the diagram) converts h(t-1) to omega(t), which the paper calls the “glimpse parameters”. omega(t) can be thought of as a tuple (x, y, delta) specifying a glimpse centered at (x, y) with a zoom factor of delta.
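The controller loop can be sketched roughly as below. This is a minimal sketch, not the paper’s implementation: the class and helper names are my own, and `extract_glimpse` uses PyTorch’s `affine_grid`/`grid_sample` (a spatial-transformer-style differentiable crop) as a stand-in for the paper’s Cauchy attention.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def extract_glimpse(image, omega, size):
    # Differentiable glimpse via an affine grid -- a stand-in for the
    # paper's Cauchy attention. omega holds (x, y, delta) per example.
    x, y, delta = omega[:, 0], omega[:, 1], torch.sigmoid(omega[:, 2])
    zeros = torch.zeros_like(x)
    theta = torch.stack([
        torch.stack([delta, zeros, x], dim=1),   # zoom + shift in x
        torch.stack([zeros, delta, y], dim=1),   # zoom + shift in y
    ], dim=1)                                    # (batch, 2, 3) affine matrices
    grid = F.affine_grid(theta, (image.size(0), image.size(1), size, size),
                         align_corners=False)
    return F.grid_sample(image, grid, align_corners=False)

class ARCController(nn.Module):
    """At each step, map h(t-1) to glimpse parameters omega(t), extract a
    glimpse from image A on even steps and image B on odd steps, and feed
    it back into the RNN. Assumes single-channel images."""
    def __init__(self, glimpse_size=4, hidden_size=128):
        super().__init__()
        self.glimpse_size = glimpse_size
        self.rnn = nn.LSTMCell(glimpse_size * glimpse_size, hidden_size)
        self.to_omega = nn.Linear(hidden_size, 3)   # h(t-1) -> (x, y, delta)

    def forward(self, image_a, image_b, num_glimpses=8):
        batch = image_a.size(0)
        h = image_a.new_zeros(batch, self.rnn.hidden_size)
        c = image_a.new_zeros(batch, self.rnn.hidden_size)
        for t in range(num_glimpses):
            omega = torch.tanh(self.to_omega(h))        # glimpse params in [-1, 1]
            image = image_a if t % 2 == 0 else image_b  # alternate between A and B
            g = extract_glimpse(image, omega, self.glimpse_size)
            h, c = self.rnn(g.reshape(batch, -1), (h, c))
        return h   # final hidden state summarizes the comparison
```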

But there is more to it. For a Neural Network to learn through gradient descent, all functions it uses must be smooth (differentiable), so it does not suffice to simply crop a portion of the image and feed it in as the glimpse. Instead, soft (differentiable) attention, as used in modern Deep Learning, generates each pixel of the glimpse as a weighted sum of ALL pixels in the image, with weights that decay smoothly as one moves away from the region the glimpse is trying to encode. The paper proposes using Cauchy kernels for this decay instead of the traditional Gaussian kernels. The reason for this choice is that the Cauchy kernel is smoother than the Gaussian kernel (which decays too fast), and Neural Networks learn smoother functions faster.
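The weighted-sum idea can be made concrete with a bank of 1-D Cauchy filters, in the style of DRAW-like attention. This is a simplified sketch: the function names and exact parameterization are mine, not the paper’s.

```python
import torch

def cauchy_filterbank(centers, gamma, img_size):
    """Bank of 1-D Cauchy attention filters (a simplified sketch).

    centers: (batch, n) filter centers in pixel coordinates
    gamma:   (batch, 1) spread of the Cauchy kernel
    Returns a (batch, n, img_size) bank of normalized filters.
    """
    coords = torch.arange(img_size, dtype=torch.float32)     # pixel positions
    # Cauchy density: weights decay as 1 / (1 + ((u - mu)/gamma)^2),
    # with heavier tails than a Gaussian.
    d = (coords.view(1, 1, -1) - centers.unsqueeze(-1)) / gamma.unsqueeze(-1)
    f = 1.0 / (1.0 + d ** 2)
    return f / (f.sum(dim=-1, keepdim=True) + 1e-8)          # normalize each filter

def cauchy_glimpse(image, cx, cy, delta, gamma, n):
    """Extract an n x n glimpse as a weighted sum of ALL pixels.

    image: (batch, H, W); cx, cy: (batch,) glimpse center;
    delta: (batch,) stride between filter centers; gamma: (batch,) spread.
    """
    batch, H, W = image.shape
    offsets = torch.arange(n, dtype=torch.float32) - (n - 1) / 2.0
    mu_x = cx.unsqueeze(-1) + offsets * delta.unsqueeze(-1)  # (batch, n) centers
    mu_y = cy.unsqueeze(-1) + offsets * delta.unsqueeze(-1)
    Fx = cauchy_filterbank(mu_x, gamma.unsqueeze(-1), W)     # (batch, n, W)
    Fy = cauchy_filterbank(mu_y, gamma.unsqueeze(-1), H)     # (batch, n, H)
    return Fy @ image @ Fx.transpose(1, 2)                   # (batch, n, n)
```

Because every output pixel depends smoothly on every input pixel, gradients flow back through the glimpse parameters themselves.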

Though the RNN controller now has full control over where to focus, this comes at the cost of “pixelation” if it tries to see larger regions. This is because the number of pixels in the glimpse is fixed and smaller than the number of pixels in the given image. The RNN controller must carefully choose what to see.

ARC Binary Classifier

So we have defined the architecture of ARCs. Let’s test it out on a simple binary classification task as a sanity check.

How about we take the final hidden state H(T) as the encoding of the image pair and feed it to a simple Neural Network whose task is to tell whether the two images belong to the same class or not? Virtually any dataset with independent classes can be used for this type of training. This paper specifically uses the Omniglot and CASIA-WebFace datasets.
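A sketch of this setup, assuming an `arc_encoder` module that maps an image pair to H(T) (the class and parameter names here are hypothetical, not the paper’s):

```python
import torch
import torch.nn as nn

class ARCBinaryClassifier(nn.Module):
    """Linear head on the final ARC hidden state H(T) that predicts
    whether the two images belong to the same class."""
    def __init__(self, arc_encoder, hidden_size=128):
        super().__init__()
        self.arc = arc_encoder                 # any module mapping (A, B) -> H(T)
        self.head = nn.Linear(hidden_size, 1)  # H(T) -> same/different logit

    def forward(self, image_a, image_b):
        h_final = self.arc(image_a, image_b)   # (batch, hidden_size)
        return self.head(h_final).squeeze(-1)  # one logit per pair

# Training would use binary cross-entropy on same-class (1) /
# different-class (0) pairs, e.g.:
# loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
```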

Feeding the encoding H(T) from ARC (in red dotted box) to a Linear Layer Classifier

The Omniglot dataset has alphabets from 50 languages, and the challenge is to use only 30 of them for training and validation and to test on the remaining 20. This means the network must learn to differentiate between characters it has never even seen. At the risk of slight exaggeration, it’s like training a Neural Network on images of various types of Animals, various types of Cars and various types of Buildings, then giving it two photos of two different fruits (which may or may not belong to the same species) and expecting it to correctly predict whether they belong to the same species!
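Training pairs for this kind of same/different verification can be produced with a simple sampling scheme. A hypothetical sketch (the function name and data layout are mine, not from the paper):

```python
import random

def sample_pair(classes_to_images, p_same=0.5):
    """Sample a training pair for same/different verification.

    classes_to_images: dict mapping a class label to a list of images.
    Returns (image_1, image_2, label) with label 1 for a same-class pair.
    """
    if random.random() < p_same:
        # Same-class pair: two distinct images from one class.
        cls = random.choice(list(classes_to_images))
        img1, img2 = random.sample(classes_to_images[cls], 2)
        return img1, img2, 1
    # Different-class pair: one image from each of two distinct classes.
    cls1, cls2 = random.sample(list(classes_to_images), 2)
    return (random.choice(classes_to_images[cls1]),
            random.choice(classes_to_images[cls2]),
            0)
```

At test time, pairs would be drawn only from the 20 held-out alphabets, so every character is one the network has never seen.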

When I trained my PyTorch implementation of ARCs, it worked like a charm. It was so much fun to visualize the attention mechanism at work.