Knowledge Distillation

Intuition

The red plot shows the large model (teacher), the blue plot shows the student network trained without distillation, and the purple plot shows the student network trained with distillation.

Our main objective is to train a model that performs well on the test dataset. To do this, we first train a cumbersome model, which could be an ensemble of many separately trained models or a single very large model trained with a strong regularizer such as dropout. Once the cumbersome model (teacher) is ready, we use a technique called distillation to transfer the knowledge it has learned to a small model (student) that is better suited for deployment on devices with limited memory and computation.

By complex knowledge, I mean the following: the cumbersome model learns to discriminate between a large number of classes, and in doing so it also assigns probabilities to all of the incorrect classes. Although these probabilities are very small, their relative sizes reveal a lot about how the cumbersome model generalizes. For example, an image of a dog is very unlikely to be mistaken for a cat, but that mistake is still many times more likely than mistaking it for a pigeon.

The cumbersome model needs to generalize well, and when we distill its knowledge into the smaller model, we want the smaller model to generalize in the same way as the large one. For example, if the large model is an ensemble of many different models, a small model trained by distillation will typically generalize better on the test dataset than a small model trained in the normal way on the same training set that was used to train the ensemble.

An obvious way to transfer the generalization ability of the cumbersome model to a small model is to use the class probabilities produced by the cumbersome model (generated with a technique called distillation) as “soft targets” for training the small model. For this transfer stage, we can use the same training set or a separate transfer set, which may consist entirely of unlabeled data.

Scattered dots are the data points in the dataset. The blue plot shows the cumbersome model (teacher) trained on a complex dataset, and the red plot shows the small model (student) generalizing from the soft targets generated by the teacher model.

When the cumbersome model is an ensemble, the soft target can be the geometric (or arithmetic) mean of the individual models' predictive distributions. When the soft targets have high entropy, they provide much more information per training case than hard targets and much less variance in the gradient between training cases, so the small model can often be trained on much less data than the original cumbersome model and with a much higher learning rate.
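A minimal NumPy sketch of both ideas in this paragraph, assuming a hypothetical three-member ensemble over the classes [dog, cat, pigeon] with made-up probabilities: the soft target is the normalized geometric mean of the members' distributions (computed in log space), and its entropy is compared with that of a one-hot hard target. The function names are illustrative, not from any library.

```python
import numpy as np

def geometric_mean_soft_target(member_probs):
    """Normalized geometric mean of ensemble members' class distributions.

    member_probs has shape (n_members, n_classes); each row sums to 1.
    The mean is taken in log space, then renormalized to sum to 1.
    """
    log_mean = np.log(member_probs).mean(axis=0)
    unnormalized = np.exp(log_mean)
    return unnormalized / unnormalized.sum()

def entropy(p):
    """Shannon entropy in nats; zero-probability classes contribute nothing."""
    p = np.asarray(p, dtype=float)
    nz = p[p > 0]
    return float(np.sum(nz * np.log(1.0 / nz)))

# Hypothetical predictions of a three-member ensemble for one dog image,
# over the classes [dog, cat, pigeon].
member_probs = np.array([
    [0.80, 0.15, 0.05],
    [0.70, 0.25, 0.05],
    [0.75, 0.20, 0.05],
])
soft_target = geometric_mean_soft_target(member_probs)
hard_target = np.array([1.0, 0.0, 0.0])  # one-hot label: "dog"

print("soft target:", soft_target.round(3))
print("entropy of hard target:", entropy(hard_target))
print("entropy of soft target:", round(entropy(soft_target), 3))
```

The hard target has zero entropy: it only says "dog". The soft target additionally encodes that "cat" is a far more plausible confusion than "pigeon", which is the extra per-case information the paragraph refers to.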

Distillation

Distillation is a general method for producing soft-target probability distributions. Class probabilities are produced by a softmax function that converts the logit z_i computed for each class into a probability q_i by comparing z_i with the other logits:

$$ q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \qquad \text{(Equation 1)} $$

Here T is the temperature, which is normally set to 1. Using a higher value of T produces a softer probability distribution over the classes. The z_i are the logits computed for each class, for every data point in the dataset.
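Equation 1 translates directly into code. This is a minimal sketch, assuming NumPy and hypothetical teacher logits for the [dog, cat, pigeon] example; subtracting the largest scaled logit before exponentiating is a standard numerical-stability trick that leaves q unchanged.

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """Equation 1: q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max()  # stability shift; cancels out in the ratio
    e = np.exp(z)
    return e / e.sum()

# Hypothetical teacher logits for one image, classes [dog, cat, pigeon]
logits = [9.0, 4.0, 1.0]
for T in (1.0, 5.0, 20.0):
    print(f"T={T:>4}: q = {softmax_with_temperature(logits, T).round(4)}")
```

With these logits, T = 1 gives cat and pigeon probabilities of roughly 0.007 and 0.0003, too small to influence a cross-entropy loss much, even though cat is e³ ≈ 20 times more likely than pigeon. At T = 20 all three probabilities are of comparable size while the ranking is preserved, which is why distillation raises the temperature.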

In the simplest variant of distillation, knowledge is transferred to the small (distilled) model by training it on a transfer set, using as targets a soft target distribution for each case in the transfer set. These soft targets are produced by running the cumbersome model with a high temperature in its softmax. The same high temperature is used while training the distilled model, but after it has been trained, it uses a temperature of 1.
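The loss used in this simplest variant can be sketched as below. This is an illustrative NumPy version, not the paper's code; `log_softmax` and `soft_target_loss` are hypothetical helper names, and the logits are made up. The teacher and student softmaxes use the same temperature T, as the paragraph describes.

```python
import numpy as np

def log_softmax(logits, T=1.0):
    """Log of Equation 1 along the last axis, computed stably."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def soft_target_loss(student_logits, teacher_logits, T):
    """Mean cross-entropy between the teacher's soft targets and the
    student's predictions, both taken at the same temperature T."""
    p = np.exp(log_softmax(teacher_logits, T))   # soft targets from the teacher
    log_q = log_softmax(student_logits, T)       # student log-probabilities
    return float(-(p * log_q).sum(axis=-1).mean())

# Hypothetical logits for a batch of one image, classes [dog, cat, pigeon]
teacher = np.array([[9.0, 4.0, 1.0]])
student = np.array([[6.0, 1.0, 2.0]])

print(soft_target_loss(student, teacher, T=20.0))
print(soft_target_loss(teacher, teacher, T=20.0))  # minimum: entropy of the soft targets
```

The second call, where the student reproduces the teacher's logits exactly, gives the smallest possible value (the entropy of the soft targets), since cross-entropy is minimized when the two distributions match.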

This technique can be improved significantly if the correct labels are known for all or part of the transfer set. One way to use them is to take a weighted average of two objective functions. The first is the cross-entropy with the soft targets, computed using the same high temperature in the softmax of the distilled model as was used to generate the soft targets from the cumbersome model. The second is the cross-entropy with the correct labels, computed using exactly the same logits in the softmax of the distilled model but at a temperature of 1. The best results were generally obtained with a considerably lower weight on the second objective function.
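A sketch of the weighted objective, assuming integer class labels and NumPy. The `hard_weight` value of 0.1 is a hypothetical choice reflecting the advice to keep the second objective's weight small, and the T² factor on the soft term is the gradient-scaling correction described in the following paragraph. Both terms are computed from the same student logits, at temperature T and at temperature 1 respectively.

```python
import numpy as np

def log_softmax(logits, T=1.0):
    """Numerically stable log-softmax at temperature T along the last axis."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def distillation_loss(student_logits, teacher_logits, labels, T, hard_weight=0.1):
    """Weighted average of two objectives computed from the SAME student logits:
    (1) cross-entropy with the teacher's soft targets, both softmaxes at temperature T;
    (2) cross-entropy with the true labels, student softmax at temperature 1.
    The soft term is multiplied by T**2 to keep its gradients on the same scale.
    """
    p_soft = np.exp(log_softmax(teacher_logits, T))
    soft_ce = -(p_soft * log_softmax(student_logits, T)).sum(axis=-1).mean()

    log_q_hard = log_softmax(student_logits, 1.0)
    hard_ce = -log_q_hard[np.arange(len(labels)), labels].mean()

    return float((1.0 - hard_weight) * T**2 * soft_ce + hard_weight * hard_ce)

# Hypothetical batch of one image whose true class is 0 ("dog")
teacher = np.array([[9.0, 4.0, 1.0]])
student = np.array([[6.0, 1.0, 2.0]])
labels = np.array([0])
print(distillation_loss(student, teacher, labels, T=20.0))
```

Setting `hard_weight=1.0` recovers ordinary training on the labels, and `hard_weight=0.0` recovers the soft-target-only variant (scaled by T²).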

Since the magnitudes of the gradients produced by the soft targets scale as 1/T², it is important to multiply them by T² when using both hard and soft targets. This ensures that the relative contributions of the hard and soft targets remain roughly unchanged if the temperature used for distillation is changed while experimenting with meta-parameters.
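The 1/T² scaling can be checked numerically. This sketch, assuming NumPy and made-up zero-mean logits, uses the analytic gradient of the soft-target cross-entropy with respect to the student logits, (q_i - p_i)/T, where p and q are the teacher's and student's softmax outputs at temperature T. For high T and zero-mean logits this is approximately (z_i - v_i)/(N T²), with v_i the teacher logits and N the number of classes, which is where the 1/T² comes from.

```python
import numpy as np

def softmax(logits, T):
    """Equation 1 for a single vector of logits."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def soft_ce(student_logits, teacher_logits, T):
    """Cross-entropy between the teacher's soft targets p and the student's q at temperature T."""
    p = softmax(teacher_logits, T)
    return float(-(p * np.log(softmax(student_logits, T))).sum())

def soft_ce_grad(student_logits, teacher_logits, T):
    """Analytic gradient of soft_ce w.r.t. the student logits: (q - p) / T."""
    return (softmax(student_logits, T) - softmax(teacher_logits, T)) / T

# Hypothetical zero-mean logits for a three-class problem
teacher = np.array([2.0, 0.0, -2.0])
student = np.array([1.0, 0.5, -1.5])

g10 = np.linalg.norm(soft_ce_grad(student, teacher, 10.0))
g20 = np.linalg.norm(soft_ce_grad(student, teacher, 20.0))
print("gradient ratio T=10 vs T=20:", g10 / g20)               # roughly (20/10)**2 = 4
print("after multiplying by T**2:", g10 * 10**2, g20 * 20**2)  # roughly equal
```

Doubling T shrinks the soft-target gradient by roughly a factor of four; multiplying by T² brings the two magnitudes back to about the same size.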

Running Through An Example

Process of deploying deep learning models in embedded devices.

In this example, suppose we are working with a very complex dataset X, on which a shallow network (one small enough to be deployed on mobile devices) cannot reach good accuracy using traditional model-building techniques. To get good accuracy, we use knowledge distillation:

First, we train a large model (teacher) on dataset X in a datacenter; the large model has enough capacity to capture the complexity of X and achieve good accuracy.

Next, we feed the original dataset X, or a transfer set (for example, a smaller version of X), through the large model and apply distillation, i.e. a softmax with a high temperature, to its logits to obtain a soft-target probability distribution for each case.

The next step is to build a shallow network (student) that fits within the memory and computation constraints of the mobile device.

This shallow network is then trained on the soft-target probability distributions produced in the second step, so that it learns to generalize the way the large model does.

Once this model is trained (and its softmax temperature is set back to 1), it is ready to be deployed on mobile devices.
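The five steps can be strung together in a toy, end-to-end sketch. Everything here is a hypothetical stand-in: the "trained teacher" is just a fixed random linear map (in practice it would be a large network trained on X in a datacenter), the transfer set is random data, and the student is a softmax-regression model trained by plain gradient descent. The temperature T = 5 and the learning rate are arbitrary choices. The point is only the data flow: teacher logits at a high T become soft targets, the student is trained at the same T, and it is used at T = 1 afterwards.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(logits, T=1.0):
    """Row-wise Equation 1 at temperature T."""
    z = logits / T
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Step 1 (stand-in): a hypothetical "trained teacher", here a fixed random linear map.
n_features, n_classes = 20, 5
W_teacher = rng.normal(size=(n_features, n_classes))

def teacher_logits(x):
    return x @ W_teacher

# Step 2: run the transfer set through the teacher and soften with a high temperature.
T = 5.0
X_transfer = rng.normal(size=(2000, n_features))
soft_targets = softmax(teacher_logits(X_transfer), T)

# Step 3: a small student (softmax regression with weights W and bias b).
W = np.zeros((n_features, n_classes))
b = np.zeros(n_classes)

# Step 4: gradient descent on T**2 * mean soft-target cross-entropy at the same T.
lr = 0.5  # hypothetical learning rate
for _ in range(500):
    q = softmax(X_transfer @ W + b, T)
    # Gradient w.r.t. the student logits: T**2 * (q - p) / T / N = T * (q - p) / N
    grad_logits = (q - soft_targets) * T / len(X_transfer)
    W -= lr * X_transfer.T @ grad_logits
    b -= lr * grad_logits.sum(axis=0)

# Step 5: "deploy" the student with its temperature set back to 1.
X_test = rng.normal(size=(500, n_features))
student_pred = softmax(X_test @ W + b, T=1.0).argmax(axis=1)
teacher_pred = teacher_logits(X_test).argmax(axis=1)
agreement = (student_pred == teacher_pred).mean()
print(f"student/teacher agreement on held-out data: {agreement:.2%}")
```

Because this student has the same functional form as the stand-in teacher, it can match it closely; with a real deep teacher and a genuinely smaller student, agreement would be lower, and the true labels can be added through the weighted objective described earlier.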

To get an estimate of how good the distilled model's results are, we will look at the numbers that the authors of the paper “Distilling the Knowledge in a Neural Network” (Hinton, Vinyals, and Dean) achieved on the MNIST dataset using knowledge distillation.

Results

To see how well distillation works, a very large neural network (teacher) with 2 hidden layers of 1200 rectified linear hidden units each was trained on all 60,000 MNIST training cases. The network was strongly regularized using dropout and weight constraints, and the input images were jittered by up to two pixels in any direction. It achieved 67 test errors, whereas a smaller network (student) with two hidden layers of 800 rectified linear hidden units each, trained without regularization, achieved 146 test errors.

But if the smaller net was regularized solely by adding the additional task of matching the soft targets produced by the large net at a temperature of 20, it achieved 74 test errors. This shows that soft targets can transfer a great deal of knowledge to the distilled model, including the knowledge about how to generalize that is learned from translated training data even though the transfer set does not contain any translations.
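The error counts are easier to compare as rates. A small sketch of the arithmetic, assuming the standard MNIST test set of 10,000 images: the distilled student closes about 91% of the gap between the normally trained student and the teacher.

```python
N_TEST = 10_000  # size of the standard MNIST test set

teacher_errors = 67     # 2 x 1200 ReLU, dropout + jittered inputs
student_errors = 146    # 2 x 800 ReLU, trained normally
distilled_errors = 74   # 2 x 800 ReLU, soft targets at T = 20

for name, errors in [("teacher", teacher_errors),
                     ("student (normal)", student_errors),
                     ("student (distilled)", distilled_errors)]:
    print(f"{name}: {errors} errors = {100 * errors / N_TEST:.2f}% test error")

# Fraction of the teacher/student gap recovered by distillation: 72 / 79
gap_closed = (student_errors - distilled_errors) / (student_errors - teacher_errors)
print(f"gap closed: {gap_closed:.1%}")  # gap closed: 91.1%
```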