Facial Similarity with Siamese Networks in PyTorch

This is Part 2 of a two-part article. You should read Part 1 before continuing here.

In the last article, we discussed the class of problems that one-shot learning aims to solve, and how siamese networks are a good candidate for such problems. We went over a special loss function that measures the similarity of the two images in a pair. We will now implement everything we discussed previously in PyTorch.

You can find the full code as a Jupyter Notebook at the end of this article.

The Architecture

We will use a standard convolutional neural network architecture. We use batch normalisation after each convolution layer, followed by dropout.

There is nothing special about this network. It accepts an input of 100px × 100px and has 3 fully connected layers after the convolution layers.
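A minimal sketch of a network matching this description is given below. The number of convolution stages, the channel counts (4 and 8), the dropout rate, the hidden width of 500, the 5-dimensional output and the single-channel (grayscale) input are all illustrative assumptions. Only the 100×100 input size, the batch norm and dropout after each convolution, and the three fully connected layers come from the text.

```python
import torch
import torch.nn as nn


class SiameseNetwork(nn.Module):
    """Convolutional embedding network for 1x100x100 grayscale images."""

    def __init__(self, embedding_dim=5):  # embedding_dim is an assumed value
        super().__init__()
        # Each conv stage: pad -> conv -> ReLU -> batch norm -> dropout.
        # Reflection padding of 1 with a 3x3 kernel keeps the 100x100 size.
        self.cnn = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(4),
            nn.Dropout2d(p=0.2),
            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=0.2),
        )
        # Three fully connected layers map the feature maps to an embedding.
        self.fc = nn.Sequential(
            nn.Linear(8 * 100 * 100, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, embedding_dim),
        )

    def forward(self, x):
        features = self.cnn(x)
        features = features.flatten(start_dim=1)  # (batch, 8*100*100)
        return self.fc(features)


net = SiameseNetwork()
batch = torch.randn(2, 1, 100, 100)
embedding = net(batch)
print(embedding.shape)  # torch.Size([2, 5])
```

The output is a small vector per image. The distance between two such vectors is what the loss function works on.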

But where is the other Siamese?

In the previous post, I showed how a pair of networks each process one image of a pair. But in this post, there is just one network. Because the weights are constrained to be identical for both networks, we use one model and feed it the two images in succession. After that, we calculate the loss value using the outputs for both images, and then back propagate. This saves a lot of memory at absolutely no cost to other metrics (like accuracy).
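The idea can be illustrated with a deliberately tiny stand-in model. In this sketch the single linear layer and the placeholder loss are assumptions chosen for brevity, not the network or loss used in this article. The point is only that one module handles both images, and one set of weights receives the gradient from both passes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# A tiny stand-in for the embedding network (assumption: any nn.Module that
# maps a batch of images to a batch of vectors could be used here).
embed = nn.Sequential(nn.Flatten(), nn.Linear(100 * 100, 5))

img0 = torch.randn(4, 1, 100, 100)
img1 = torch.randn(4, 1, 100, 100)

# The same module, and therefore the same weights, processes both images.
out0 = embed(img0)
out1 = embed(img1)

# Placeholder loss that depends on both outputs (not the contrastive loss).
loss = F.pairwise_distance(out0, out1).pow(2).mean()
loss.backward()

# Only one set of parameters exists; its gradient combines both passes.
n_params = sum(p.numel() for p in embed.parameters())
print(n_params)  # 50005
```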

Contrastive Loss

We defined contrastive loss to be

Equation 1.0

And we defined Dw (which is just the Euclidean distance) as:

Equation 1.1

Gw is the output of our network for one image.
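For reference, the standard form of contrastive loss (as introduced by Hadsell, Chopra and LeCun) and the distance Dw can be written as follows. The symbols X_1 and X_2 (the two images), Y (the pair label) and m (the margin) are assumed names. The label convention matches the one used in this post: Y = 0 for a genuine pair and Y = 1 for an imposter pair.

```latex
% Equation 1.0: contrastive loss
L(W, Y, X_1, X_2) = (1 - Y)\,\tfrac{1}{2}\,D_W^2
                  + Y\,\tfrac{1}{2}\,\bigl\{\max(0,\; m - D_W)\bigr\}^2

% Equation 1.1: Euclidean distance between the two network outputs
D_W(X_1, X_2) = \lVert G_W(X_1) - G_W(X_2) \rVert_2
```

The first term penalises genuine pairs that are far apart. The second term penalises imposter pairs only while their distance is smaller than the margin m.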

The contrastive loss in PyTorch looks like this:
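A minimal sketch of such a loss module, assuming a margin of 2.0 (an illustrative value) and the label convention used in this post (0 for a genuine pair, 1 for an imposter pair). It keeps the 1/2 factor of the standard formulation and uses `torch.nn.functional.pairwise_distance` for Dw. The class and argument names are my own choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveLoss(nn.Module):
    """Contrastive loss with label 0 = genuine pair, 1 = imposter pair."""

    def __init__(self, margin=2.0):  # the margin value is an assumption
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Dw: Euclidean distance between the two embeddings, shape (batch,).
        distance = F.pairwise_distance(output1, output2)
        # Flatten labels so (batch, 1) inputs do not broadcast to a matrix.
        label = label.view(-1).float()
        # Genuine pairs are pulled together.
        genuine_term = (1 - label) * distance.pow(2)
        # Imposter pairs are pushed apart until they clear the margin.
        imposter_term = label * torch.clamp(self.margin - distance, min=0).pow(2)
        return 0.5 * (genuine_term + imposter_term).mean()


criterion = ContrastiveLoss()
a = torch.zeros(2, 5)
b = torch.zeros(2, 5)
loss_genuine = criterion(a, b, torch.zeros(2))   # identical and genuine
loss_imposter = criterion(a, b, torch.ones(2))   # identical but imposter
```

Identical embeddings labelled as genuine give a loss very close to zero. The same embeddings labelled as imposters are penalised by roughly half the squared margin.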

The Dataset

In the previous post I wanted to use MNIST, but some readers suggested I instead use the facial similarity example I discussed in the same post. Therefore I switched from MNIST/OmniGlot to the AT&T faces dataset.

The dataset contains images of 40 subjects from various angles. I held out the last 3 subjects from training so that they can be used to test our model.

Figure 1.0. Left: Samples from different classes. Right: All Samples of one subject

Data Loading

Our architecture requires an input pair, along with the label (similar/dissimilar). Therefore I created my own custom data loader to do the job. It uses torchvision's ImageFolder to read images from folders, so you can use it with any dataset arranged as one folder per class.

The Siamese Network dataset generates a pair of images, along with their similarity label (0 if genuine, 1 if imposter). To prevent imbalances, I ensure that roughly half of the pairs are from the same class, while the other half are not.
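The pairing logic can be sketched as below. This is a simplified illustration rather than the exact loader: it assumes the images are already in memory as (tensor, class index) tuples and that there are at least two classes, and the class name `SiamesePairDataset` is hypothetical. A genuine pair may occasionally pair an image with itself.

```python
import random

import torch
from torch.utils.data import Dataset


class SiamesePairDataset(Dataset):
    """Yields (image0, image1, label): 0 = same class, 1 = different class."""

    def __init__(self, samples):
        # samples: list of (image_tensor, class_index) tuples (assumed format).
        self.samples = samples
        # Group sample indices by class so pairs can be drawn quickly.
        self.by_class = {}
        for idx, (_, cls) in enumerate(samples):
            self.by_class.setdefault(cls, []).append(idx)
        self.classes = list(self.by_class)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img0, cls0 = self.samples[index]
        # Flip a fair coin so that about half of the pairs are genuine.
        if random.random() < 0.5:
            other = random.choice(self.by_class[cls0])
            label = 0.0
        else:
            other_cls = random.choice([c for c in self.classes if c != cls0])
            other = random.choice(self.by_class[other_cls])
            label = 1.0
        img1, _ = self.samples[other]
        return img0, img1, torch.tensor(label)


# Usage with synthetic data: 40 images spread over 4 classes.
samples = [(torch.randn(1, 100, 100), i % 4) for i in range(40)]
dataset = SiamesePairDataset(samples)
img0, img1, label = dataset[0]
```

Because the second image is drawn at random on every access, each epoch sees a different set of pairs.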

Training the Siamese Network

The training process of a siamese network is as follows:

1. Pass the first image of the image pair through the network.
2. Pass the second image of the image pair through the network.
3. Calculate the loss using the outputs from steps 1 and 2.
4. Back propagate the loss to calculate the gradients.
5. Update the weights using an optimiser. We will use Adam for this example.
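These steps can be sketched as a short loop. To keep the example self-contained, it uses a tiny stand-in network, a single synthetic batch and an assumed margin of 2.0. Only the Adam optimiser and the learning rate of 0.0005 come from this article.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

torch.manual_seed(0)

# Tiny stand-in embedding network (assumption; the real one is convolutional).
net = nn.Sequential(nn.Flatten(), nn.Linear(100 * 100, 5))
optimizer = optim.Adam(net.parameters(), lr=0.0005)
margin = 2.0  # assumed margin for the contrastive loss


def contrastive_loss(out0, out1, label):
    # label: 0 = genuine pair, 1 = imposter pair.
    distance = F.pairwise_distance(out0, out1)
    genuine = (1 - label) * distance.pow(2)
    imposter = label * torch.clamp(margin - distance, min=0).pow(2)
    return 0.5 * (genuine + imposter).mean()


# One synthetic batch of 8 pairs, reused on every step for illustration.
img0 = torch.randn(8, 1, 100, 100)
img1 = torch.randn(8, 1, 100, 100)
label = torch.randint(0, 2, (8,)).float()

losses = []
for step in range(20):
    optimizer.zero_grad()                        # clear old gradients
    out0 = net(img0)                             # 1. first image
    out1 = net(img1)                             # 2. second image, same network
    loss = contrastive_loss(out0, out1, label)   # 3. loss from both outputs
    loss.backward()                              # 4. back propagate
    optimizer.step()                             # 5. update weights with Adam
    losses.append(loss.item())
```

In a real run the batches would come from a DataLoader over the pair dataset, and the loop would repeat over many epochs.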

The network was trained for 100 epochs, using Adam and a learning rate of 0.0005. The graph of the loss over time is shown below:

Figure 2.0. Loss value over time. The x-axis is the number of iterations.

Testing the Network

We held out 3 subjects for the test set, which will be used to evaluate the performance of our model.

To calculate the similarity, we just calculate Dw (Equation 1.1). The distance directly corresponds to the dissimilarity between the images in a pair: a higher value of Dw indicates higher dissimilarity.
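As a rough sketch, the dissimilarity score can be computed with `torch.nn.functional.pairwise_distance` on the two embeddings. The network here is an untrained stand-in (an assumption made so the example runs on its own), so only the mechanics are meaningful, not the actual values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for the trained network (assumption: it maps images to embeddings).
net = nn.Sequential(nn.Flatten(), nn.Linear(100 * 100, 5))
net.eval()  # evaluation mode: disables dropout, freezes batch-norm statistics

x0 = torch.randn(1, 1, 100, 100)
x1 = torch.randn(1, 1, 100, 100)

with torch.no_grad():  # no gradients are needed at test time
    # Dw between an image and itself, and between two different images.
    same = F.pairwise_distance(net(x0), net(x0)).item()
    different = F.pairwise_distance(net(x0), net(x1)).item()

print(f"same image: {same:.4f}, different image: {different:.4f}")
```

An image compared with itself gives a distance of essentially zero, while two different inputs give a larger value.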

Figure 3.0 Some outputs of the model. Lower values indicate more similarity, and higher values indicate less similarity.

The results are quite good. The network is able to recognise the same person even when the images are taken from different angles. It also does a good job of telling apart images of different people.

Conclusion

We discussed and implemented a siamese network to discriminate between pairs of faces for facial recognition. This is useful when there are few (or just one) training examples of a particular face. We used a discriminative loss function (contrastive loss) to train the neural network.

You can find the entire code in my repo:

P.S.

If you liked this article, please ❤ it down below to share it. Suggestions are welcome, and if you did not understand something, feel free to ask me.
