4 Steps to Finding the Right Deep Learning Model

Escape Beginner Mistakes When First Applying Deep Learning

If you’ve been looking to make your application machine learning-powered, you’ll notice that there’s a flood of ML models and model implementations out there that may or may not fit your task. Especially if you’re not deeply familiar with a specific family of models, it can be overwhelming to pick which model implementation to adopt for your project.

After talking to hundreds of engineers about their ML projects as part of ModelDepot, I’ve compiled these 4 steps you should understand when picking your next machine learning model!

1. Understanding the Problem Domain

Originally from PublicDomainPictures.net

While you might be building a hot dog locator, the model you’re looking for might not be called a “hot dog locator”. The difference between how a user and an ML researcher think of a model can make it frustrating to find the right one.

The hot dog locator is an “object detection” problem under the “computer vision” category. In fact, there’s already a dataset called COCO that has bounding boxes around hot dogs!

When you think about your problem at hand, the easiest way to translate it into the right ML terms is to think about your inputs. Is it text or images? Those usually correspond to natural language processing (NLP) or computer vision (CV) respectively. From there you’d want to dive deeper into that field to find out what kind of sub-problems exist, such as sentiment classification in NLP. Furthermore, you can explore datasets that might already contain items of interest (e.g. hot dogs) to narrow your search to models trained on that specific dataset. Sometimes getting the jargon right can be tricky, so using a user-friendly ML model search tool, such as ModelDepot, can help you quickly find and understand models that fit your use case.

2. Finding the “Right” Accuracy

“Machine Learning” from xkcd

It might be obvious that accuracy is something you should care a lot about, but simply trusting any accuracy number can end badly. There are several things to keep in mind while thinking about accuracy.

Accuracy Metric

There’s a plethora of different metrics depending on what problem you’re solving. Each specific problem domain in ML has a set of standard metrics that are relevant. It’s extremely important to figure out which metrics are the most important for you!

For example, if we were building a credit card fraud detection system and only considered correct_predictions/all_predictions (a.k.a. accuracy), we could simply develop a model that always returns “not fraud” and get over 99% accuracy, since most transactions are not fraud! Such a model would never catch a single fraudulent transaction, so it’s important to pick the right metric for your task!
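To make the fraud example concrete, here is a minimal sketch in plain Python. The data is made up for illustration (1,000 transactions, 5 of them fraudulent), and the helper names are my own. It compares accuracy with recall, the fraction of actual fraud cases the model catches, which is one common alternative metric for this kind of imbalanced problem.

```python
def confusion_counts(y_true, y_pred):
    """Count true/false positives and negatives, where 1 means fraud."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    return tp, fp, fn, tn


def accuracy(y_true, y_pred):
    """correct_predictions / all_predictions."""
    tp, fp, fn, tn = confusion_counts(y_true, y_pred)
    return (tp + tn) / len(y_true)


def recall(y_true, y_pred):
    """Fraction of actual fraud cases that were flagged as fraud."""
    tp, fp, fn, tn = confusion_counts(y_true, y_pred)
    return tp / (tp + fn) if (tp + fn) else 0.0


# Made-up data: 5 fraudulent transactions out of 1,000.
y_true = [1] * 5 + [0] * 995
# A "model" that always answers "not fraud".
always_not_fraud = [0] * 1000

print(accuracy(y_true, always_not_fraud))  # 0.995
print(recall(y_true, always_not_fraud))    # 0.0
```

The accuracy looks excellent while the recall shows the model is useless for its actual job.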

Reported Accuracy

The reported accuracy is a good start to figuring out if a model meets your application’s requirements. Almost always, the original paper of the model will report accuracy metrics for the model. Make sure that you understand how the metric they used relates to the metric you’re using if the two are different. Also keep in mind that their dataset might differ from your own task, so a 2% improvement on their problem might not matter much to you in the end.

Your Own Accuracy

If you find something that seems to have reasonable reported accuracy metrics, you’ll want to test out the model yourself to see how well it will do for you. Ideally you have a test set of inputs that your model would expect to receive (e.g. emails, reviews, etc.) and the corresponding expected outputs. Testing the model on your own data is the best way to ensure that it’ll perform well for your use case, though it is also the most laborious.
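As a rough sketch of what testing on your own data can look like, the snippet below scores a model against a small hand-labeled test set. Everything here is a placeholder: `keyword_sentiment` is a hypothetical stand-in for whatever model you are evaluating, the reviews are invented, and plain accuracy is used only for brevity (swap in whichever metric matters for your task).

```python
def evaluate(predict, test_set):
    """Return the fraction of examples where predict(input) matches the expected output."""
    if not test_set:
        raise ValueError("test_set must contain at least one example")
    correct = sum(1 for text, expected in test_set if predict(text) == expected)
    return correct / len(test_set)


def keyword_sentiment(text):
    """Hypothetical stand-in for a real model: a naive keyword rule."""
    return "positive" if "great" in text.lower() else "negative"


# Invented examples of inputs your application might receive,
# paired with the output you expect for each.
test_set = [
    ("Great product, would buy again", "positive"),
    ("Arrived broken and late", "negative"),
    ("Not great, honestly", "negative"),
    ("Works as described", "positive"),
]

score = evaluate(keyword_sentiment, test_set)
print(score)  # 0.5
```

Because `evaluate` only needs a callable, you can run several candidate models against the same test set and compare their scores directly.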

In-Browser Live Demos on ModelDepot

There are certain ways to quickly demo models, such as using ModelDepot’s online demo feature. You can quickly feed models with example input and see the outcome of the model in less than a minute. You can also try out the model in online environments such as Google Colab to skip setting up a local dev environment.

3. Knowing Your Data

Depending on how much data you have or are willing to collect, your approach to finding models will vary widely! Building from scratch isn't the only approach, and can actually be the worst approach depending on your data! Let’s dive into some cases.

I Have a Lot of Data

If you have a lot of training data, you’ll want to look for models that have easily accessible training scripts so you can train your model from scratch. Getting DL models to converge can be very difficult; to make your life easier, you should look for projects on GitHub that look active. Having a supportive community around a model can go a long way to help you out.

I Have Some Data

If you only have some data, you might be able to get away with using a training technique called “transfer learning”. Transfer learning lets you take a model pre-trained on a similar domain and tune it to work well for your specific problem using your small amount of training data. You’ll want to look for pre-trained models that are easy to “dissect” and re-train. You can find some at TensorFlow Hub or Keras Applications.
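The core idea of transfer learning (keep the pre-trained part frozen, train only a small new part on your data) can be shown with a toy sketch in plain Python. This is deliberately not a real deep learning setup: `frozen_features` is a hypothetical stand-in for a pre-trained base model, the data is invented, and the “head” is a simple logistic regression. In a real project the base would be a network from somewhere like TensorFlow Hub or Keras Applications.

```python
import math


def frozen_features(x):
    """Stand-in for a pre-trained base model. It is never updated during training."""
    return [x[0] + x[1], x[0] - x[1]]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_head(examples, epochs=200, lr=0.5):
    """Fit a logistic-regression head on top of the frozen features with SGD."""
    w = [0.0, 0.0]
    b = 0.0
    for _ in range(epochs):
        for x, label in examples:
            f = frozen_features(x)
            p = sigmoid(w[0] * f[0] + w[1] * f[1] + b)
            err = p - label  # gradient of the log loss w.r.t. the logit
            w = [wi - lr * err * fi for wi, fi in zip(w, f)]
            b -= lr * err
    return w, b


def predict(x, w, b):
    f = frozen_features(x)
    return 1 if sigmoid(w[0] * f[0] + w[1] * f[1] + b) >= 0.5 else 0


# A small, invented training set: only the head's three parameters are learned.
examples = [
    ([1.0, 0.5], 1), ([0.2, 0.9], 1), ([2.0, -0.5], 1),
    ([-1.0, -0.5], 0), ([-0.3, -0.8], 0), ([0.5, -2.0], 0),
]

w, b = train_head(examples)
train_accuracy = sum(predict(x, w, b) == y for x, y in examples) / len(examples)
print(train_accuracy)
```

Since only the head is trained, a handful of labeled examples is enough here; the frozen base does the heavy lifting of turning raw inputs into useful features.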

I Only Have a Handful of Examples

No worries! Having a handful of examples is a great start. Look for models that come pre-trained, and use your examples as the “test set” to evaluate how those models perform on your data. Luckily there are several places you can look for pre-trained models, such as the model zoos for each framework: TensorFlow, Caffe, ONNX, PyTorch. ModelDepot also offers a more general search interface for pre-trained models to help you pick the right ML model.

4. Picking the Architecture

Accuracy vs Speed Tradeoff (Figure 2 from https://arxiv.org/pdf/1611.10012.pdf)

We can now look at the architectures behind the models, provided the models 1) have acceptable accuracy on your own data and 2) are easily retrainable or come pre-trained.

Accuracy, Speed and Size

One of the biggest practical considerations is the speed vs accuracy tradeoff. Researchers have developed a wide variety of architectures to match the different use cases that applications might encounter in the real world. For example, perhaps your model is supposed to run on a compute-limited mobile phone, so you might be looking for a MobileNet architecture that’s lightweight and fast. Otherwise, if you’re not compute constrained and want the best accuracy, you can go with whichever state-of-the-art model promises the best accuracy, regardless of how slow or big it is.
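One way to make this tradeoff explicit is to write down your constraints and pick the most accurate model that fits them. The sketch below does that in plain Python. The candidate names and all the numbers are placeholders I made up, not real benchmarks; in practice you would fill them in with accuracy measured on your own test set and latency measured on your target device.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Candidate:
    name: str
    accuracy: float    # measured on your own test set
    latency_ms: float  # measured on your target device
    size_mb: float     # size of the model on disk


def pick_model(candidates, max_latency_ms, max_size_mb):
    """Return the most accurate candidate within both budgets, or None if none fits."""
    feasible = [
        c for c in candidates
        if c.latency_ms <= max_latency_ms and c.size_mb <= max_size_mb
    ]
    if not feasible:
        return None
    return max(feasible, key=lambda c: c.accuracy)


# Placeholder candidates with invented numbers.
candidates = [
    Candidate("small-mobile-net", accuracy=0.71, latency_ms=30.0, size_mb=15.0),
    Candidate("mid-size-net", accuracy=0.78, latency_ms=120.0, size_mb=100.0),
    Candidate("large-sota-net", accuracy=0.84, latency_ms=900.0, size_mb=600.0),
]

# Compute-limited phone: tight latency and size budgets.
phone = pick_model(candidates, max_latency_ms=50.0, max_size_mb=20.0)
# No compute constraints: accuracy is all that matters.
server = pick_model(candidates, max_latency_ms=float("inf"), max_size_mb=float("inf"))

print(phone.name)   # small-mobile-net
print(server.name)  # large-sota-net
```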

Some models have lightweight variants, such as PSPNet50 versus the full PSPNet, that cut down on the number of layers to make the model faster and slimmer. Other times you can look to techniques such as pruning or quantization to make a model even smaller and faster.
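To give a feel for what quantization does, here is a minimal sketch of one simple scheme: mapping floating-point weights to 8-bit integer codes with a linear scale and offset. This is an illustration under my own assumptions (made-up weights, a single scale for the whole list), not how any particular framework implements it; real toolkits add calibration and per-layer or per-channel handling.

```python
def quantize(weights, num_bits=8):
    """Map floats to integer codes in [0, 2**num_bits - 1] using a linear scale and offset."""
    levels = 2 ** num_bits - 1
    lo, hi = min(weights), max(weights)
    # Guard against a zero range when all weights are identical.
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((w - lo) / scale) for w in weights]
    return codes, scale, lo


def dequantize(codes, scale, lo):
    """Recover approximate float weights from the integer codes."""
    return [lo + c * scale for c in codes]


# Made-up weights standing in for one layer of a model.
weights = [-0.82, -0.1, 0.0, 0.33, 0.97]

codes, scale, lo = quantize(weights)
restored = dequantize(codes, scale, lo)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print(codes)      # integers between 0 and 255
print(max_error)  # at most about half of `scale`
```

Each weight now needs one byte instead of the four used by a 32-bit float, at the cost of a small rounding error per weight.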

Done!

With those four steps, you can navigate from knowing what problem you want to solve, to selecting a few models that could best solve your problem as quickly as possible.

There are other considerations, such as ML framework, code quality, or the reputation of the model author, but those are usually a luxury you can afford once you move beyond the PoC/MVP stage of integrating ML in your product. Let me know your thoughts on how you decide on finding ML models in the comments!