Ask a birdwatcher how they recognize an Arctic tern and they’ll probably tell you it’s the pointy tips on the long angular wings or the small round head and short legs. An AI model, meanwhile, might be excellent at classifying bird species, but it could not tell you how it does so. Until now.

Exciting new research from Duke University introduces ProtoPNet, a deep learning network that can explain how it distinguishes a pigeon from a partridge in real time. The paper will be presented at next month’s NeurIPS 2019 conference in Vancouver.

ProtoPNet architecture

The researchers trained and evaluated ProtoPNet on the CUB-200-2011 dataset, which comprises 11,788 photos of 200 bird species.

ProtoPNet consists of layers from convolutional neural network models such as VGG-16, followed by a prototype layer and a fully connected layer. The Rectified Linear Unit (ReLU) serves as the activation function for the first two types of layers, while the sigmoid activation function is used for the last layer.
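The flow through these three stages can be sketched as follows. This is a minimal illustration, not the authors' code: the shapes are scaled far down from the paper's actual configuration, and numpy is used for clarity. The distance-to-similarity mapping follows the log-ratio form used in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative, scaled-down shapes (not the paper's configuration):
# a 7x7 grid of 64-d patch features, 10 prototypes per class, 2 classes.
H, W, D = 7, 7, 64
num_classes, per_class = 2, 10
num_prototypes = num_classes * per_class

z = rng.standard_normal((H, W, D))               # conv-layer feature map
prototypes = rng.standard_normal((num_prototypes, D))

# Prototype layer: squared L2 distance from each prototype to each patch,
# converted to a similarity score; global max-pooling keeps the best match.
dists = ((z[:, :, None, :] - prototypes[None, None]) ** 2).sum(-1)
sims = np.log((dists + 1.0) / (dists + 1e-4))
scores = sims.max(axis=(0, 1))                   # one score per prototype

# Fully connected layer maps the prototype scores to class logits.
fc_weights = rng.standard_normal((num_classes, num_prototypes))
logits = fc_weights @ scores
```

The max-pooling step is what makes the reasoning interpretable: each prototype's contribution to a class score comes from the single image patch it matches best.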

The convolutional layers extract a feature map from the input image, which the prototype layer compares against prototypical parts learned from training images. This produces activation maps showing where, and how strongly, each prototype matches.
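One detail worth making concrete is how prototypes are tied back to training images: during training, each prototype is periodically projected onto the nearest training-image patch in latent space, so every prototype can be visualized as an actual image crop. A minimal sketch, with all shapes and values hypothetical:

```python
import numpy as np

rng = np.random.default_rng(1)
D = 64
prototype = rng.standard_normal(D)

# Latent patches from a handful of training images (illustrative shapes):
# 5 images, each yielding a 7x7 grid of 64-d patch features.
train_patches = rng.standard_normal((5, 7, 7, D))

# Projection step: snap the prototype onto its nearest training patch,
# so it corresponds exactly to a visible crop of a training image.
flat = train_patches.reshape(-1, D)
nearest_idx = ((flat - prototype) ** 2).sum(-1).argmin()
prototype = flat[nearest_idx]
```

After projection, asking "which prototype fired?" has a concrete visual answer: the training-image crop that the prototype now equals.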

The reasoning process ProtoPNet uses to identify a test image.

The researchers highlight several innovations in their deep learning network:

Unlike post-hoc methods which create interpretation after training the network, ProtoPNet has a built-in case-based reasoning process that generates explanations during classification;

While attention-based models only point out determinants of the test image, ProtoPNet can not only do the same but also visualize which prototypical parts are compared to the test image;

The feature extraction of Bayesian Case Models is performed by Scale Invariant Feature Transform (SIFT), whereas with an end-to-end training method, ProtoPNet uses a specialized neural network architecture for feature extraction and prototype learning;

ProtoPNet does not require a decoder for prototype visualization, and its accuracy can be improved by combining several ProtoPNet models, which creates more prototypes for each class.

Experimental results show the test accuracy of a combined ProtoPNet (84.8%) is comparable to that of other state-of-the-art deep learning models. The researchers also ran a second experiment on car models with a combined network of VGG19-, ResNet34-, and DenseNet121-based ProtoPNets, which scored 91.4% accuracy.
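At the level of class scores, combining several ProtoPNet models can be sketched as pooling the prototypes of all the models, which amounts to summing the logits each model produces. The numbers below are hypothetical placeholders, not the paper's results:

```python
import numpy as np

rng = np.random.default_rng(2)
num_classes = 3

# Hypothetical class logits from three separately trained ProtoPNet models
# (e.g. VGG19-, ResNet34- and DenseNet121-based, as in the car experiment).
logits_per_model = rng.standard_normal((3, num_classes))

# Combining the models pools their prototype layers, which at the
# logit level amounts to summing each model's class scores.
combined = logits_per_model.sum(axis=0)
prediction = int(combined.argmax())
```

Because each model contributes its own prototypes, the combined network reasons over a richer set of prototypical parts per class while remaining interpretable.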

A key feature of the method is the “heat maps” it generates to show which parts of a bird (belly, wings, beak, etc.) were particularly useful for classification — essentially showing us how it reasons while distinguishing bird species. This adds information and insight lacking in common black box classification models.
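Such a heat map can be produced by bilinearly upsampling a prototype's low-resolution activation map to the input-image size and normalizing it for overlay. The sketch below assumes a 7x7 activation map and a 224x224 input, both illustrative choices:

```python
import numpy as np

# A 7x7 prototype activation map with hypothetical values.
act = np.random.default_rng(3).random((7, 7))

def upsample_bilinear(a, out_h, out_w):
    """Bilinearly interpolate a 2-D array up to (out_h, out_w)."""
    h, w = a.shape
    ys = np.linspace(0, h - 1, out_h)
    xs = np.linspace(0, w - 1, out_w)
    y0, x0 = np.floor(ys).astype(int), np.floor(xs).astype(int)
    y1, x1 = np.minimum(y0 + 1, h - 1), np.minimum(x0 + 1, w - 1)
    wy, wx = (ys - y0)[:, None], (xs - x0)[None, :]
    top = a[np.ix_(y0, x0)] * (1 - wx) + a[np.ix_(y0, x1)] * wx
    bot = a[np.ix_(y1, x0)] * (1 - wx) + a[np.ix_(y1, x1)] * wx
    return top * (1 - wy) + bot * wy

heat = upsample_bilinear(act, 224, 224)
# Normalize to [0, 1] so it can be rendered as a colour overlay.
heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-12)
```

Overlaying `heat` on the input image highlights the region — belly, wing, beak — where that prototype matched most strongly.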

The case-based reasoning feature opens the door for ProtoPNet to be applied to medical images to help doctors make diagnoses, which is the researchers’ next project. Other possible applications include training self-driving vehicles to better identify pedestrians and traffic signs, or spotting suspects in surveillance camera streams.

The paper This Looks Like That: Deep Learning for Interpretable Image Recognition was first released in June 2018 and selected by ICML 2018 while the work was still in progress. An updated version was released in September 2019 and modified yesterday on OpenReview. First author Chaofan Chen has also released the accompanying code on GitHub.