Recently I came across the idea of center loss, described in this paper. The outputs of the second-to-last layer of the neural network are defined as embeddings. For this loss, you define a per-class center, which serves as the centroid of the embeddings belonging to that class. The center loss term is defined as:

Equation for center loss
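Written out in the paper's notation, with x_i the embedding of the i-th sample in the mini-batch, y_i its label, c_{y_i} the center of that class, and m the mini-batch size, the center loss is:

L_C = \frac{1}{2} \sum_{i=1}^{m} \left\lVert x_i - c_{y_i} \right\rVert_2^2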

As the network gets updated with gradient descent, the per-class centers need to be updated as well. Ideally each update would involve going through the entire training set, but that is not feasible in practice. So the update is computed over the mini-batch, and a hyperparameter ‘alpha’ controls the learning rate of the centers. The update is given by:

Update equation for centers
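In the paper's notation, where \delta(\text{condition}) is 1 if the condition holds and 0 otherwise (so only the mini-batch samples of class j contribute to the update of center c_j):

\Delta c_j = \frac{\sum_{i=1}^{m} \delta(y_i = j) \, (c_j - x_i)}{1 + \sum_{i=1}^{m} \delta(y_i = j)}, \qquad c_j^{t+1} = c_j^{t} - \alpha \, \Delta c_j^{t}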

Another scalar, ‘lambda’, is used to balance the two loss terms. The total loss used for training the neural network is given by:

Total loss used for training the neural network
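With L_S denoting the standard softmax (cross-entropy) loss and L_C the center loss above, this is simply:

L = L_S + \lambda \, L_C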

To see how the distribution of the learned features changes with the addition of this loss term, the authors trained a neural network with an embedding size of 2 on the MNIST dataset, using different values of lambda. This is what the features looked like when plotted:

plot of embeddings trained using different values of lambda

As can be seen, as lambda increases the features of each class cluster more tightly around their center, and the different classes become more clearly separated from each other.
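To make this concrete, here is a minimal PyTorch-style sketch of a center loss term. The class name CenterLoss and the variable lambda_c are mine, and for simplicity the centers are kept as ordinary learnable parameters updated by the optimizer, rather than using the paper's manual alpha update shown above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CenterLoss(nn.Module):
    """Sketch of a center loss: one learnable center per class, penalizing the
    squared distance between each embedding and the center of its class."""

    def __init__(self, num_classes: int, embedding_dim: int):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, embedding_dim))

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Select the center c_{y_i} for each sample and average
        # 0.5 * ||x_i - c_{y_i}||^2 over the mini-batch (the paper sums over
        # the batch instead; the constant factor can be absorbed into lambda).
        batch_centers = self.centers[labels]
        return 0.5 * ((embeddings - batch_centers) ** 2).sum(dim=1).mean()

# Usage inside a training step (logits, embeddings and labels come from your model):
# center_loss_fn = CenterLoss(num_classes=10, embedding_dim=2)
# lambda_c = 0.01
# loss = F.cross_entropy(logits, labels) + lambda_c * center_loss_fn(embeddings, labels)
```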

Results

The authors of the paper used various face recognition datasets to test their approach. The following were the results on the different datasets:

Model A was trained using the standard softmax loss, model B using softmax loss combined with contrastive loss, and model C using softmax loss combined with center loss.

Read this paper to understand contrastive loss.

Results on the Labeled Faces in the Wild (LFW) and YouTube Faces (YTF) datasets

Identification rates on MegaFace with 1M distractors

Verification TAR at 10⁻⁶ FAR on MegaFace with 1M distractors

As can be seen, on MegaFace (which is a very challenging benchmark) center loss smashed all the previously published results.

My experiments on different datasets

Face recognition typically involves training on datasets with a large number of classes (~10k).

I wanted to check whether this idea would also work on datasets with a smaller number of classes. I conducted experiments on 4 datasets in that direction: cluttered-mnist, fashion-mnist, cifar-10 and cifar-100.

Results on cluttered-mnist