Note: this article originally appeared in Towards Data Science.

Evaluating Keras neural network performance using Yellowbrick visualizations

If you have ever used Keras to build a machine learning model, you’ve probably made a plot like this one before:

{training, validation} {loss, accuracy} plots from a Keras model training run

This is a matrix of training loss, validation loss, training accuracy, and validation accuracy plots, and it’s an essential first step for evaluating the accuracy and level of fit (or overfit) for our model. But there are many nuances to model performance that escape these simple line charts. What if you want to dig deeper?

Yellowbrick is a dataviz toolkit that provides many advanced data and model evaluation plots that let you do just that:

Since I first tried it out a year ago, Yellowbrick has definitely grown to be one of my favorite tools in my data science toolbox.

However, the library is designed to work with Scikit-Learn, and is not (yet 😉) compatible with Keras. Luckily, we can fix this by building a simple model wrapper. That’s great news, because it unlocks a handful of advanced model plots for use in your neural network model evaluation.

This post will show you how!

Note: this post assumes light familiarity with the Keras library.

Building a scikit-learn keras wrapper

yellowbrick is designed to work with machine learning algorithms from the venerable scikit-learn library. Every model in scikit-learn has the same basic API:
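As a refresher, here’s a minimal sketch of that shared API, using LogisticRegression as a stand-in for any estimator:

```python
# Every scikit-learn model follows the same fit / predict / score contract.
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)

model = LogisticRegression(max_iter=1000)
model.fit(X, y)            # train on data
preds = model.predict(X)   # predict labels
acc = model.score(X, y)    # evaluate accuracy
```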

The keras API draws heavily from scikit-learn, but adapts it to additional needs (like model compilation) that emerge when training neural networks:
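As a rough sketch (using toy random data in place of a real dataset), the extra compilation step looks like this:

```python
# Keras adds a compile() step between model construction and fitting,
# where the loss, optimizer, and metrics are declared.
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Input

# toy data standing in for a real dataset
X = np.random.rand(100, 8)
y = np.random.randint(0, 2, size=(100,))

model = Sequential([
    Input(shape=(8,)),
    Dense(16, activation="relu"),
    Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam",           # the extra compilation step
              loss="binary_crossentropy",
              metrics=["accuracy"])
model.fit(X, y, epochs=1, batch_size=32, verbose=0)
preds = model.predict(X, verbose=0)
```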

Luckily, keras already ships with scikit-learn wrappers built into the library: keras.wrappers.scikit_learn.KerasClassifier for classifiers, and keras.wrappers.scikit_learn.KerasRegressor for regressors:
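Here’s a sketch of what wrapping looks like; build_model is a hypothetical model factory, and note that in newer Keras releases these wrappers have moved to the separate scikeras package:

```python
# Wrapping a Keras model so scikit-learn tooling (here, cross_val_score)
# can treat it like any other estimator. `build_model` is a hypothetical
# factory; KerasClassifier calls it to construct a fresh model per fit.
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import cross_val_score

def build_model():
    model = Sequential([
        Dense(8, activation="relu", input_shape=(4,)),
        Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model

X = np.random.rand(60, 4)
y = np.random.randint(0, 2, size=(60,))

clf = KerasClassifier(build_fn=build_model, epochs=1, verbose=0)
scores = cross_val_score(clf, X, y, cv=3)  # scikit-learn tooling just works
```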

This is a great feature because it allows you to use your Keras neural networks with Scikit-Learn tools, like cross-validation and grid search. Jason Brownlee has a great post on this subject: “How to Grid Search Hyperparameters for Deep Learning Models in Python With Keras”.

Unfortunately, if our goal is to use Yellowbrick visualizations on Keras models, KerasClassifier / KerasRegressor doesn’t quite go far enough. 😞 There are a couple of problems:

1. Yellowbrick relies on some Scikit-Learn model semantics that KerasClassifier doesn’t provide.
2. KerasClassifier doesn’t provide alternative training methods, like fit_generator or fit_dataframe, which are hard to give up.

Luckily we can fix this by writing our own subclass, KerasBatchClassifier :
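Here’s a condensed, illustrative sketch of the idea (the full implementation linked in this post handles more edge cases):

```python
# Hypothetical condensed sketch of KerasBatchClassifier: expose the
# scikit-learn semantics Yellowbrick expects, and route training through
# fit_generator when an augmentation generator is supplied.
import numpy as np
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.base import BaseEstimator

class KerasBatchClassifier(KerasClassifier, BaseEstimator):
    # scikit-learn classifier semantics that Yellowbrick checks for
    _estimator_type = "classifier"

    def fit(self, X, y, **kwargs):
        # record the class labels, as scikit-learn classifiers do
        self.classes_ = np.unique(y)

        # build the underlying Keras model from the factory function
        self.model = self.build_fn()

        # route training through fit_generator when a generator is given
        generator = kwargs.pop("generator", None)
        if generator is not None:
            return self.model.fit_generator(
                generator.flow(X, y, batch_size=32),
                steps_per_epoch=len(X) // 32, **kwargs)
        return self.model.fit(X, y, **kwargs)
```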

KerasBatchClassifier fixes issue #1 by setting _estimator_type and classes_ properties and adding a diamond dependency on BaseEstimator, and it fixes issue #2 by making fit use fit_generator internally.

The inability to use fit_generator with KerasClassifier is a well-known pain point. This code is an adaptation of (and refinement on) existing solutions from other users — particularly this one.

With this shim in place, we can move on to the fun part: applying yellowbrick visualizations to our neural network models!

Evaluating classification

We’ll start off by checking out yellowbrick classification evaluation plots. For the purposes of this demo, I trained a very basic CNN on a subset of images of fruits from the Google Open Images dataset. You can get that dataset here, and you can follow along with the code here.

The simplest of the classification evaluation plots is ClassPredictionError, which provides a stacked bar chart of per-class model predictions:

With this chart in hand we can quickly assess which classes are popular classification targets and which ones are not, and what the most common misclassifications are within a single class.

However, I personally much prefer the ConfusionMatrix:

This visualization lets us quickly zero in on important properties of the model:

Which classes are most accurately predicted

Which classes are least accurately predicted

Which misclassifications are most common

Finally there is ClassificationReport. This provides four essential classification model metrics — precision, recall, f1 score, and support — in an easily digestible visual format:

Evaluating regression

Yellowbrick also packs tools for evaluating regression models. For this demo I trained a simple feedforward neural network that attempts to predict the price-per-day for various homes from the Boston Airbnb dataset on Kaggle. You can see the code for yourself here.

The basic regression analysis plot is PredictionError, which charts predicted values from the model against ground truth values from the dataset:

This chart is useful for identifying patterns in the data (and seeing how well the model adapts to them). For example, by examining the y values in this plot, we can see that hosts have a strong preference for pricing rentals at multiples of 100.

Then there is ResidualsPlot:

A model residual is the distance between the actual and predicted value of a single record. By putting all of our residuals on a single plot we can assess whether or not our model performs better on some sections of the data than on others. In this case we see that our residuals are larger in magnitude when the predicted value is larger as well — a sign that the model is performing better on smaller values in the dataset than on larger ones.

Conclusion

When working with neural networks, having a library of advanced visualizations you can use to dig into specific properties of your model is essential to iterating on your model builds quickly and effectively. In this post we saw how we can leverage yellowbrick with keras to build some of these kinds of graphs. Hopefully, having read this post, you’re now ready to replace one or two hacky matplotlib code gists lying around with well-maintained, well-architected visualization recipes from this nifty new library.

Interested in learning more about the Python data visualization ecosystem? I recommend watching Jake Vanderplas’s highly entertaining PyCon 2017 talk “The Python Visualization Landscape”.

Interested in learning more about Yellowbrick? In addition to the model evaluation plots showcased here, the library also provides plots for evaluating unsupervised clustering, modeling text, and modeling data features. Take a look.