Neural networks can learn complicated representations fairly easily. However, there are some tasks where new data (or categories of data) is constantly changing. For example, you may train a network to recognize pictures of 8 different types of cats. But in the future, you may want to change that to 12 breeds. If the network has to keep learning new data over time, it is called a continual learning problem. This article talks about a very recent technique that attempts to constantly adapt to new data at a fraction of the cost of retraining entire models.

This article is based on the paper Lifelong Learning with Dynamically Expandable Networks.

How this article is structured

This article follows the structure of the original paper. I add my own notes and explanations to simplify the material.

This article requires some knowledge about neural networks, including weights, regularization, neurons, etc.

What is Continual Learning

Continual learning simply means being able to learn continuously over time. Data arrives in a sequence, and the algorithm has to learn to predict the new data as it arrives. Techniques like transfer learning are usually used: the model is trained on previous data, and some of its features are reused to learn the new data. This reduces the time needed compared with training models from scratch, and it is also useful when the new data is sparse.

The Problem

The simplest way to perform such learning is to keep fine-tuning the model on newer data. However, if the new task is very different from the old tasks, the model will not perform well on it, because the features learned for the old tasks are not useful. For example, a model trained on a million images of animals will probably not work very well if it is fine-tuned on images of cars: the features learned from animals won’t be very useful for detecting cars.

Another problem is that after fine-tuning, the model may begin to perform the original task (in this example, recognizing animals) poorly. For example, the stripes on a zebra have a vastly different meaning than the stripes on a T-shirt or a fence, so fine-tuning the model on such images will degrade its performance at recognizing zebras.

Introducing Dynamically Expandable Networks

At a very high level, the idea of Dynamically Expandable Networks is very logical: train a model, and if it cannot predict well, increase its capacity to learn. If a new task arrives that is vastly different from the existing ones, extract whatever useful information you can from the old model and train a new one. The authors turned these intuitions into concrete techniques.

The authors introduce three techniques to make such a framework possible. Each is discussed in detail below, but at a very high level, they are:

Selective retraining: Find the neurons that are relevant to the new task and retrain only them.
Dynamic Network Expansion: If the model cannot learn the new task with step 1 (i.e. the loss stays above a threshold value), increase its capacity by adding more neurons.
Network Split/Duplication: If some units have begun to change drastically, duplicate them and retrain the duplicates, while keeping the old weights fixed.

Figure 1.0 Left: Selective training. Center: Dynamic Expansion. Right: Network Split.

In the figure above, t denotes the task number: t-1 is the previous task, and t is the current task.
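The three steps can be sketched as a single loop over tasks. This is a minimal sketch that only shows the order in which the steps run; the step functions (`train_first_task`, `selective_retrain`, `expand`, `split`) and `loss_threshold` are hypothetical placeholders that you would implement for a real network.

```python
from typing import Callable, List, Sequence

# Hypothetical signatures: each step takes (model, task_data). Only
# selective_retrain returns a value we use: the loss after retraining.
TrainFn = Callable[[object, object], object]
RetrainFn = Callable[[object, object], float]


def train_den(model: object,
              tasks: Sequence[object],
              train_first_task: TrainFn,
              selective_retrain: RetrainFn,
              expand: TrainFn,
              split: TrainFn,
              loss_threshold: float) -> List[str]:
    """Run the three DEN steps over a sequence of tasks.

    Returns a log of which steps ran for each task, so the control flow
    can be inspected.
    """
    log: List[str] = []
    for t, task_data in enumerate(tasks, start=1):
        if t == 1:
            # First task: ordinary training with L1 regularization.
            train_first_task(model, task_data)
            log.append(f"task {t}: train")
            continue
        steps = ["retrain"]
        # Step 1: retrain only the neurons relevant to the new task.
        loss = selective_retrain(model, task_data)
        # Step 2: add capacity only if retraining was not good enough.
        if loss > loss_threshold:
            expand(model, task_data)
            steps.append("expand")
        # Step 3: duplicate units whose weights drifted too far.
        split(model, task_data)
        steps.append("split")
        log.append(f"task {t}: " + ", ".join(steps))
    return log


# Usage with stub steps: task 2 retrains well, task 3 does not.
losses = iter([0.1, 0.9])
log = train_den(model={}, tasks=["task A", "task B", "task C"],
                train_first_task=lambda m, d: None,
                selective_retrain=lambda m, d: next(losses),
                expand=lambda m, d: None,
                split=lambda m, d: None,
                loss_threshold=0.5)
print(log)
# ['task 1: train', 'task 2: retrain, split', 'task 3: retrain, expand, split']
```

Passing the steps in as functions keeps the control flow separate from any particular network implementation.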

Selective retraining

The simplest way to train a new model would be to train the entire model every time a new task arrives. However, because deep neural networks can get very large, this method will become very expensive.

To avoid this, the authors present a novel technique. The model is trained on the first task with L1 regularization. This encourages sparsity in the network: many weights become exactly zero, so each neuron is connected to only a few neurons in the neighboring layers. We will see why this is useful in a moment.

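Equation 1.0 Loss function for initial task:

$$\underset{W^{t=1}}{\text{minimize}}\ \mathcal{L}\big(W^{t=1}; \mathcal{D}_1\big) + \mu \sum_{l=1}^{L} \lVert W_l^{t=1} \rVert_1$$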

Here W^t denotes the weights of the model at time t; in this case, t = 1. D_t denotes the training data at time t. The right half of the equation, starting from μ, is the L1 regularization term, and μ is the regularization strength. The sum runs over the layers of the network, from the first layer (l = 1) to the last (l = L). This regularization pushes the weights of the model toward zero, and many of them end up exactly zero.
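To see why an L1 penalty produces sparsity, here is a toy version of Equation 1.0 for a single linear layer rather than a deep network. It is a minimal sketch, assuming a squared-error loss and proximal gradient descent (soft-thresholding), which is one standard way to optimize an L1-regularized objective; the paper’s network and optimizer are not reproduced here, and `l1_objective`, `soft_threshold`, and `fit_l1` are illustrative names.

```python
import numpy as np


def l1_objective(W, X, y, mu):
    """Data loss (half the mean squared error) plus mu times the L1 norm of W."""
    residual = X @ W - y
    return 0.5 * np.mean(residual ** 2) + mu * np.sum(np.abs(W))


def soft_threshold(W, thresh):
    """Proximal operator of the L1 norm: shrink every weight, zero the small ones."""
    return np.sign(W) * np.maximum(np.abs(W) - thresh, 0.0)


def fit_l1(X, y, mu, lr=0.1, steps=500):
    """Proximal gradient descent (ISTA) on l1_objective."""
    W = np.zeros(X.shape[1])
    for _ in range(steps):
        grad = X.T @ (X @ W - y) / len(y)   # gradient of the data term only
        W = soft_threshold(W - lr * grad, lr * mu)
    return W


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
true_W = np.zeros(10)
true_W[:3] = [2.0, -1.5, 1.0]   # only the first 3 inputs matter
y = X @ true_W

W = fit_l1(X, y, mu=0.1)
print(np.round(W, 2))           # the 7 irrelevant weights should be exactly 0
print(l1_objective(W, X, y, mu=0.1))
```

A plain gradient step on |w| would make weights oscillate around zero; the soft-thresholding step is what sets them to exactly zero, which is the kind of sparsity the later steps rely on.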

When the next task t arrives, a sparse linear classifier for it is first fit on top of the last hidden layer of the model, by solving:

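Equation 1.1 Training the network for the next task:

$$\underset{W_{L,t}^{t}}{\text{minimize}}\ \mathcal{L}\big(W_{L,t}^{t}; W_{1:L-1}^{t-1}, \mathcal{D}_t\big) + \mu \lVert W_{L,t}^{t} \rVert_1$$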

The notation:

Equation 1.2: $W_{1:L-1}^{t-1}$

This denotes the weights of all layers except the last one, as they were after the previous task (t-1). These layers (from the first up to layer L-1) are kept fixed, and only the weights connecting the top hidden layer to the new task’s output are optimized, with the same L1 regularization to promote sparse connections.
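Fitting the sparse classifier on top of frozen layers can be sketched as follows. This is a minimal sketch under several assumptions: a single frozen ReLU layer stands in for layers 1 to L-1, the new task is binary with a logistic loss, and the L1 term is handled with a soft-thresholding step after each gradient step. The synthetic labels depend only on hidden units 2 and 7, so those are the units the sparse classifier should connect to.

```python
import numpy as np

rng = np.random.default_rng(1)

# Frozen lower layer (a stand-in for the fixed layers 1..L-1). Orthonormal
# columns keep the 16 hidden pre-activations independent for this demo.
W_frozen, _ = np.linalg.qr(rng.normal(size=(20, 16)))


def hidden(X):
    """Top hidden-layer activations; W_frozen is never updated."""
    return np.maximum(X @ W_frozen, 0.0)


def fit_sparse_output(H, y, mu, lr=0.5, steps=2000):
    """Train only the new output weights: logistic loss + mu * L1 penalty."""
    w = np.zeros(H.shape[1])
    b = 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(H @ w + b)))       # predicted probabilities
        grad_w = H.T @ (p - y) / len(y)
        grad_b = np.mean(p - y)
        w = w - lr * grad_w
        w = np.sign(w) * np.maximum(np.abs(w) - lr * mu, 0.0)  # L1 prox step
        b -= lr * grad_b                              # bias is not penalized
    return w, b


X_new = rng.normal(size=(500, 20))
H = hidden(X_new)
# Synthetic new task whose label depends only on hidden units 2 and 7.
y_new = (H[:, 2] > H[:, 7]).astype(float)

w_out, b_out = fit_sparse_output(H, y_new, mu=0.05)
relevant_units = np.flatnonzero(w_out)
print("hidden units connected to the new output:", relevant_units)
```

The nonzero entries of `w_out` are the sparse connections from the top hidden layer to the new output; they are the starting point for finding the rest of the relevant network.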

These sparse connections reveal which neurons in the rest of the model the new task actually depends on. Starting from the new output unit, the nonzero weights are followed backward with Breadth-First Search, a standard graph-search algorithm, to find the subnetwork relevant to the new task. Only the weights in that subnetwork are then updated, which saves a lot of computation; weights outside it are not touched. This also helps prevent negative transfer, where learning the new task degrades performance on old tasks.
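The Breadth-First Search step can be sketched as below. This is a minimal sketch assuming each layer’s weights are stored as a NumPy matrix of shape (units in layer l, units in layer l + 1), with layer 0 as the input; `find_subnetwork` is a hypothetical helper, and only nonzero weights count as connections. The weights to retrain would then be those between the selected units.

```python
from collections import deque

import numpy as np


def find_subnetwork(weights, output_unit):
    """Return {layer_index: set of unit ids} reachable backward from output_unit.

    weights[l] has shape (units in layer l, units in layer l + 1), so layer 0
    is the input and layer len(weights) is the output layer. Only nonzero
    weights count as edges.
    """
    top = len(weights)
    selected = {top: {output_unit}}
    queue = deque([(top, output_unit)])
    while queue:
        layer, unit = queue.popleft()
        if layer == 0:
            continue  # input units have no incoming weights
        incoming = weights[layer - 1][:, unit]   # weights into this unit
        for src in np.flatnonzero(incoming):
            src = int(src)
            seen = selected.setdefault(layer - 1, set())
            if src not in seen:
                seen.add(src)
                queue.append((layer - 1, src))
    return selected


# Toy network: 4 inputs -> 3 hidden units -> 2 outputs (output 1 is the new task).
W0 = np.array([[0.2, 0.0, 0.1],
               [0.0, 0.4, 0.0],
               [0.0, 0.3, 0.0],
               [-0.3, 0.0, 0.0]])
W1 = np.array([[0.0, 0.5],
               [0.6, 0.0],
               [0.0, 0.0]])
subnet = find_subnetwork([W0, W1], output_unit=1)
print(subnet)   # {2: {1}, 1: {0}, 0: {0, 3}}
```

Hidden units 1 and 2 are never reached from the new output, so their weights stay exactly as the old tasks left them.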

Dynamic Network Expansion

Selective retraining works well for tasks that are highly related to older tasks. But when a new task has a fairly different distribution, it begins to fail. The authors therefore use a second technique that increases the capacity of the network by adding neurons, so that the new data can be represented. Their method is described below.

Suppose that you wish to expand layer l of a network by k neurons. This enlarges both the weight matrix of that layer and that of the previous layer, which will have dimensions:

Equation 3.0

𝒩 is the total number of neurons after adding the k neurons.
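Growing a layer by k units amounts to widening its incoming weight matrix and lengthening its outgoing one. In this minimal sketch, the layer grows from n to n + k units (so n + k plays the role of 𝒩); the small random initialization of the new weights, the `scale` parameter, and the helper name `expand_layer` are assumptions for illustration.

```python
import numpy as np


def expand_layer(W_in, W_out, k, rng, scale=0.01):
    """Add k units to a hidden layer that currently has n units.

    W_in:  (n_prev, n) weights into the layer   -> returns (n_prev, n + k)
    W_out: (n, n_next) weights out of the layer -> returns (n + k, n_next)
    Existing weights are kept unchanged; only the new slices are random
    (small Gaussian initialization is an assumption, not the paper's recipe).
    """
    n_prev = W_in.shape[0]
    n_next = W_out.shape[1]
    new_in = scale * rng.normal(size=(n_prev, k))
    new_out = scale * rng.normal(size=(k, n_next))
    return np.hstack([W_in, new_in]), np.vstack([W_out, new_out])


rng = np.random.default_rng(0)
W_in, W_out = np.ones((5, 8)), np.ones((8, 4))
W_in2, W_out2 = expand_layer(W_in, W_out, k=3, rng=rng)
print(W_in2.shape, W_out2.shape)   # (5, 11) (11, 4)
```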

In practice, we usually don’t want to commit to adding exactly k neurons. Instead, we would like the network to figure out the right number of neurons to add. Fortunately, there is an existing technique that uses group Lasso to regularize a network toward sparse weights, so that neurons whose weights all become zero can then be removed. This technique is described in detail in the paper Group Sparse Regularization for Deep Neural Networks.

I won’t go into detail here, but applying it to a layer gives results like these (Group Lasso is the technique that was used):

Figure 2.0 Comparison between regularization methods. Grayed cells represent removed connections.

The authors applied this per layer, and only to the k newly added neurons rather than to the entire network. The regularization zeroes out as many of the new connections as possible and keeps only the most relevant ones; neurons left without any nonzero connections are then removed, making the model compact.
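The group-sparsity step and the pruning that follows can be sketched with the proximal operator of the group-lasso penalty. This is a minimal sketch: treating each new unit’s incoming weights as one group is an assumption made for illustration, and `group_soft_threshold` and `prune_new_units` are hypothetical helpers. A whole column either shrinks or becomes exactly zero, and zeroed units are dropped.

```python
import numpy as np


def group_soft_threshold(W_new, thresh):
    """Proximal step for a group-lasso penalty with one group per column.

    Each column (one new unit's incoming weights) is shrunk toward zero by
    `thresh` in L2 norm; columns whose norm is below `thresh` become exactly 0.
    """
    norms = np.linalg.norm(W_new, axis=0)
    scale = np.maximum(0.0, 1.0 - thresh / np.maximum(norms, 1e-12))
    return W_new * scale


def prune_new_units(W_in, W_out, n_old):
    """Remove added units (columns from n_old on) whose incoming weights are all 0."""
    keep_new = np.any(W_in[:, n_old:] != 0.0, axis=0)
    keep = np.concatenate([np.ones(n_old, dtype=bool), keep_new])
    return W_in[:, keep], W_out[keep, :]


# Two old units followed by three new ones with incoming norms 0.05, 1.0, 0.02.
W_in = np.array([[1.0, 2.0, 0.03, 0.6, 0.0],
                 [0.5, 1.0, 0.04, 0.8, 0.02]])
W_out = np.arange(15.0).reshape(5, 3)

W_in[:, 2:] = group_soft_threshold(W_in[:, 2:], thresh=0.1)
W_in_p, W_out_p = prune_new_units(W_in, W_out, n_old=2)
print(W_in_p.shape, W_out_p.shape)   # (2, 3) (3, 3)
```

Only the middle new unit survives, scaled down by 10%; in a real network this step would be interleaved with gradient updates rather than applied once.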

Network Split/Duplication

There is a common problem in transfer learning called semantic drift, or catastrophic forgetting, where the model slowly shifts its weights so much that it forgets about the original tasks.

Although it is possible to add an L2 penalty that discourages the weights from moving far from their values after the previous task, this won’t help if the new tasks are very different: past a certain point, the model will simply fail to learn them.
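The L2 alternative is commonly written as a penalty on the distance from the previous task’s weights, λ‖W − W_prev‖². This minimal sketch assumes that form; `drift_penalty`, `drift_penalty_grad`, and `lam` (standing for λ) are illustrative names.

```python
import numpy as np


def drift_penalty(W, W_prev, lam):
    """lam * ||W - W_prev||^2: grows as weights move away from the old task's."""
    return lam * np.sum((W - W_prev) ** 2)


def drift_penalty_grad(W, W_prev, lam):
    """Gradient of drift_penalty with respect to W."""
    return 2.0 * lam * (W - W_prev)


W_prev = np.zeros(3)               # weights after the previous task
W = np.array([1.0, -2.0, 0.5])     # weights after some training on the new task
print(drift_penalty(W, W_prev, lam=0.5))   # 2.625
# A gradient step on the penalty alone moves W back toward W_prev.
W_step = W - 0.1 * drift_penalty_grad(W, W_prev, lam=0.5)
print(W_step)
```

However large λ is, the penalty can only hold weights near their old values; it cannot give the network new capacity for a very different task.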

Instead, it is better to duplicate neurons whose weights shift too far. If a neuron’s weights change by more than a threshold, the neuron is split: a duplicate unit is added to the same layer.

Specifically, for a hidden unit i, let ρ_i be the L2 distance between its new (incoming) weights and its old weights. If ρ_i > σ, where σ is a hyperparameter, the unit is split. After the split, the entire network needs to be trained again, but convergence is fast because the weights start from the reasonable values learned in the first round of training rather than from a random initialization.
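The split step can be sketched for a single hidden layer as follows. This is a minimal sketch under stated assumptions: the original unit is reset to its previous-task weights so old tasks keep it, the drifted weights become the new duplicate unit, and the duplicate starts with a copy of the unit’s outgoing weights. `split_drifted_units` and its arguments are hypothetical names, and the whole network would then be retrained, as described above.

```python
import numpy as np


def split_drifted_units(W_in_old, W_in_new, W_out, sigma):
    """Duplicate hidden units whose incoming weights drifted by more than sigma.

    W_in_old: (n_prev, n) incoming weights after the previous task
    W_in_new: (n_prev, n) incoming weights after training on the new task
    W_out:    (n, n_next) outgoing weights of the layer
    Returns the widened (W_in, W_out) and the indices of the split units.
    """
    rho = np.linalg.norm(W_in_new - W_in_old, axis=0)   # rho_i for each unit i
    drifted = np.flatnonzero(rho > sigma)
    W_in = W_in_new.copy()
    # Assumption: the original unit goes back to its old weights (old tasks keep it)...
    W_in[:, drifted] = W_in_old[:, drifted]
    # ...and the drifted weights become a duplicate unit appended to the layer.
    W_in = np.hstack([W_in, W_in_new[:, drifted]])
    # Assumption: the duplicate starts with a copy of the unit's outgoing weights.
    W_out = np.vstack([W_out, W_out[drifted, :]])
    return W_in, W_out, drifted


W_in_old = np.array([[1.0, 0.0, 0.5],
                     [0.0, 1.0, 0.5]])
W_in_new = np.array([[1.05, 0.0, 2.0],
                     [0.0, 1.0, -1.0]])
W_out = np.array([[1.0], [2.0], [3.0]])

W_in, W_out_split, drifted = split_drifted_units(W_in_old, W_in_new, W_out, sigma=0.5)
print(drifted)                          # [2]
print(W_in.shape, W_out_split.shape)    # (2, 4) (4, 1)
```

Unit 0 moved only slightly (ρ = 0.05), so it keeps its updated weights; only unit 2 crosses σ and is duplicated.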

Training and Evaluation

Datasets

Three datasets were used:

MNIST-Variation: 62,000 images of handwritten digits from 0 to 9. Unlike in MNIST, the digits are rotated and have noise in the background.
CIFAR-10: 60,000 images of generic objects, including vehicles and animals.
AWA-Class: 30,475 images of 50 animals.

Models

To compare performance, a variety of models were used. They are: