Neural networks have been shown to provide significant improvements in prediction accuracy by stacking multiple hidden layers, hence the trend towards deeper and deeper learning. The ImageNet challenge exemplifies this trend towards deep neural networks: from AlexNet (8 layers) and VGG-Net (19 layers) to ResNet (152 layers) within a period of 4 years. Such improvement in accuracy by increasing the depth of a neural network comes at the cost of additional execution time and power consumption.

Even if we disregard the additional execution time during training of the neural network, overlooking this added execution time during inference can result in a negative user experience. Improving inference has two dimensions: how accurately we can infer, and how quickly we can infer.

Deeper networks improve the former while hurting the latter. This article elaborates the idea of BranchyNet for fast inference through early exits.

Let’s say we participate in an experimental math competition. We are given 3 mathematical problems to solve, each with an increasing level of difficulty.

Problem 1: Finding the square root of 4. (difficulty x)

Problem 2: Finding the square root of 1000. (difficulty y)

Problem 3: Finding the square root of 13371337. (difficulty z)

For simplicity, we assume that solving a problem of difficulty x takes t_x seconds, of difficulty y takes t_y seconds, and of difficulty z takes t_z seconds, with t_x < t_y < t_z. The total time taken by a human (t_h) to solve the entire set of mathematical problems is the sum of these times: t_h = t_x + t_y + t_z.

Now, we train a deep neural network to participate in this experimental math competition as well. The neural network has to be designed so that it can solve the hardest problem, of difficulty z (assuming it also takes t_z seconds). The time taken for the neural network to solve any problem equals the time required to execute through all the layers of the network. The total time taken by the neural network (t_n) is therefore the number of mathematical problems multiplied by the time taken to solve each one: t_n = 3t_z.

Since the time required by the neural network (3t_z) is greater than the time taken by the human (t_x + t_y + t_z) to solve the same set of mathematical problems, the human wins in this case.

The above example is purely based on assumptions; it may not hold in the real world, given how fast machines can compute compared to the human brain. What the math competition example conveys is that a deep neural network performs the same amount of work for every input, regardless of its difficulty, and the total time is directly proportional to the number of layers. Each layer extracts different high-level features to aid in predicting the input's label.

Confidence in predicting an input's label is unevenly distributed between layers, depending on the features extracted. For instance, in digit recognition, certain features such as the absence of curved lines provide high confidence that the digit is a 1, 4, or 7 and not a 0, 2, 3, 5, 6, 8, or 9. Even when the neural network is highly confident about the input's label at an early layer, it still has to process through all the layers of the network before predicting, due to its rigid structure. BranchyNet relaxes this structure by allowing early exits for highly confident predictions.

The BranchyNet architecture comprises a single entry point and multiple exit points. The entry point can be viewed as the input layer, and the exit points as output layers. A branch is a subset of contiguous layers, not overlapping with any other branch, followed by an exit point. The original neural network can be viewed as the main branch, and layers branching out from it as side branches. Each branch must resolve to an exit, and branches can be numbered in increasing order of their distance from the entry point.
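As a concrete sketch of this layout, the branches of such a network can be represented as non-overlapping, numbered runs of layers. All layer names below are illustrative, not taken from the paper:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Branch:
    """A contiguous, non-overlapping run of layers ending in its own exit classifier."""
    number: int         # ordered by distance from the entry point
    layers: List[str]   # layers owned exclusively by this branch

# Hypothetical layout: two side branches off an AlexNet-like main branch.
branches = [
    Branch(1, ["conv1", "side1_conv", "side1_fc"]),           # earliest side branch
    Branch(2, ["conv2", "conv3", "side2_conv", "side2_fc"]),  # second side branch
    Branch(3, ["conv4", "conv5", "fc1", "fc2", "fc3"]),       # main branch to final exit
]

# Branches must not share layers, and numbering follows distance from the entry.
all_layers = [layer for b in branches for layer in b.layers]
assert len(all_layers) == len(set(all_layers))
```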

Figure 1. BranchyNet with 2 branches added to AlexNet.

The above figure depicts BranchyNet with 2 side branches added to the AlexNet main branch. The boxes indicate the exit points: two early exits and the final exit of the main AlexNet branch. BranchyNet enables fast inference when the input's label can be predicted at an early layer of the network with high confidence. If the classifier at an exit point is sufficiently confident about the label of an input sample, the sample exits there, and no further computation through the deeper layers is needed. Entropy is used to measure the confidence of the classifier at an exit point for the input sample:

entropy(y) = −Σ_c y_c log y_c, where y is a vector containing the predicted probabilities for all possible class labels.

The predicted probability for a class label measures the likelihood that the input belongs to that class. A simple example is classifying a digit among the class labels (1, 2, 3, 4, 5, 6, 7, 8, 9, 0). The predicted probability vector [0.05, 0.05, 0.1, 0.05, 0.05, 0.1, 0.45, 0.05, 0.05, 0.05] gives the predicted probability for each class label (1, 2, 3, 4, 5, 6, 7, 8, 9, 0) in order.
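Reading off the predicted label from such a vector is just an argmax over the class probabilities; a minimal sketch using the vector above:

```python
# Predicted probabilities for class labels (1, 2, ..., 9, 0), as in the text.
labels = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
probs = [0.05, 0.05, 0.1, 0.05, 0.05, 0.1, 0.45, 0.05, 0.05, 0.05]

# The predicted label is the class with the highest probability.
predicted = labels[probs.index(max(probs))]  # → 7
```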

Entropy measures the uncertainty in the predicted probability vector at each exit branch. The lower the entropy, the more the probability mass is concentrated on a single class label, and hence the more confident the neural network is about that particular input. Conversely, a near-uniform probability vector has high entropy and signals low confidence.
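To make this concrete, here is a minimal sketch of the entropy computation (natural log; the two probability vectors are made up for illustration):

```python
import math

def entropy(y):
    """Shannon entropy of a predicted-probability vector."""
    return -sum(p * math.log(p) for p in y if p > 0.0)

confident = [0.91, 0.03, 0.03, 0.03]  # mass concentrated on one class
uncertain = [0.25, 0.25, 0.25, 0.25]  # mass spread evenly (maximum entropy)

# Lower entropy signals higher confidence in the prediction.
assert entropy(confident) < entropy(uncertain)
```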

During inference, the entropy at an exit point is compared to the threshold set for that particular exit point. If the entropy is less than the threshold, the input sample is allowed to exit at that point. A high-level view of BranchyNet's inference phase is as follows. We have a vector T holding the threshold for each exit, i.e., T[1] is the threshold for exit 1. Inference begins at the lowest exit point: the input sample is fed through that branch, and the resulting entropy is compared to the threshold. If the entropy is below the threshold, the sample exits the network at that exit point. If not, the next lowest exit point is tried, and the procedure continues until the final exit point is reached, which always returns a prediction.
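The procedure above can be sketched as a simple loop. The branch callables below are stand-ins for the real sub-networks up to each exit point, and the threshold values are arbitrary:

```python
import math

def entropy(y):
    return -sum(p * math.log(p) for p in y if p > 0.0)

def branchy_infer(x, branches, T):
    """Try exits from lowest to highest; take the first whose entropy beats its threshold."""
    for i, branch in enumerate(branches):
        y = branch(x)  # predicted-probability vector at exit i
        if entropy(y) < T[i] or i == len(branches) - 1:
            return y.index(max(y)), i  # (predicted label, exit taken)

# Toy stand-ins: the first exit is unsure about this input, the second is confident.
branches = [lambda x: [0.4, 0.6], lambda x: [0.05, 0.95]]
T = [0.3, 0.3]
label, exit_taken = branchy_infer("some input", branches, T)  # → (1, 1)
```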

An important hyperparameter in BranchyNet, apart from the thresholds, is the set of weights used in joint optimization during training. BranchyNet uses a loss function as its optimization objective; for a classification task, this can be the softmax cross-entropy loss. Each exit branch is designed to minimize this loss function. To train the entire BranchyNet, we form a joint optimization problem over the weighted sum of the loss functions of all exit branches. The earlier branches help increase the accuracy of the later branches through the added regularization, which prevents over-fitting. Giving more weight to the earlier branches pushes the earlier layers to learn more discriminative features and allows more samples to exit the network early, thus reducing the average time taken by the network to infer a label for an input sample.
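The joint objective can be sketched as a weighted sum of per-exit losses. The weights and the toy predictions below are illustrative choices, not values from the paper:

```python
import math

def cross_entropy(y_pred, true_class):
    """Softmax cross-entropy loss for a single sample at one exit."""
    return -math.log(y_pred[true_class])

def joint_loss(exit_predictions, true_class, weights):
    """Weighted sum of the loss at every exit branch, minimized jointly during training."""
    return sum(w * cross_entropy(y, true_class)
               for w, y in zip(weights, exit_predictions))

# Two early exits plus the final exit; earlier exits weighted more heavily.
preds = [[0.7, 0.3], [0.8, 0.2], [0.9, 0.1]]
loss = joint_loss(preds, true_class=0, weights=[1.0, 0.5, 0.3])
```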

The BranchyNet architecture synergizes well with distributed execution of artificial intelligence models for inference. Since each branch (a subset of layers) can be executed independently, the branches can be distributed across different nodes and executed in parallel. In case of an early exit, the remaining branch executions can be halted by communication between nodes. This distributed nature of BranchyNet allows data scientists to deploy deeper models with very little trade-off in time. Deeper models in turn enable more sophisticated feature extraction to solve more complex problems using Artificial Intelligence.

Let’s look at an example of how a 10-layer BranchyNet could be executed. This BranchyNet has 4 branches: 3 side branches and 1 main branch. For simplicity, we use 4 nodes to execute these 4 branches. Node 1 executes the main branch up to layer 3 and then communicates the output of layer 3 to Node 2, while in parallel continuing to execute the side branch with exit 1. Node 2 executes the main branch from layer 4; similarly, it communicates the output of layer 6 to Node 3 while in parallel executing the branch with exit 2. If a node can exit early, i.e., high confidence is reached at its exit, the result is immediately communicated back to the client, halting further execution of the BranchyNet and yielding a faster response time. Early exits for high-confidence inputs and parallel execution of branches together contribute to a heavy reduction of inference time. The execution time could be reduced further by parallelizing the execution of each neural network layer itself.
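A back-of-the-envelope sketch of the time saved, assuming every layer takes one time unit and each side branch adds two layers executed after that node's main-branch layers (all numbers here are illustrative, not measurements):

```python
def latency(exit_taken, main_boundaries, side_len=2):
    """Time units until a result is produced when exit `exit_taken` fires.

    main_boundaries[i] is the depth of the last main-branch layer on node i+1;
    the final entry is the full depth of the main branch.
    """
    if exit_taken == len(main_boundaries):          # final (main) exit, no side branch
        return main_boundaries[-1]
    # Main-branch layers up to this node's boundary, plus its side branch.
    return main_boundaries[exit_taken - 1] + side_len

boundaries = [3, 6, 8, 10]        # nodes 1-4 own main-branch layers up to these depths
early = latency(1, boundaries)    # confident at exit 1: 3 + 2 = 5 units
full = latency(4, boundaries)     # no early exit: all 10 main-branch layers
```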

Figure 2. Execution of 10-layered BranchyNet.

A heavy reduction in the inference time of artificial intelligence models increases their adoption for real-time prediction tasks. Inference using BranchyNet can use paired execution between local nodes (the nodes on which an application runs: a computer, mobile, or IoT device) and remote network nodes. The local node executes the lowest branches and predicts an output locally in case of an early exit, while the remote nodes begin executing the later branches simultaneously, reducing the inference time when no local early exit occurs. The additional time in that case is only the difference between the execution time of the remaining exit branch's layers and that of the layers already executed on the local node.

If you are interested in learning more about BranchyNet, the research paper is available at https://arxiv.org/abs/1709.01686.

Please let me know if you have any feedback or suggestions on the topic. I am happy to discuss.

I am working on a tool for developers to easily augment a trained deep neural network model with branches for fast inference, and will share the progress in a later post.