Training our model

It’s time to leave Python behind and switch over to Swift, where we’ll use Apple’s Create ML to train a mobile-ready subreddit suggester with our data. Create ML provides high-level APIs that make it easy to train machine learning models for common tasks like text classification. While there isn’t much flexibility (yet), in about 30 lines of code, we can train a decent model.

To get started, create a new Swift Playground in Xcode. Make sure you select the macOS option for the Playground, as Create ML is only available on macOS.

Note that macOS is selected in the template menu above.

Here’s a gist with all the code you’ll need to train and export your model. I’ve included some instructions and sample output in the comments. You can learn more about MLTextClassifier and other Create ML models in Apple’s documentation.
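If the gist doesn’t load, here’s a minimal sketch of what that training code looks like; the file path and the "title"/"subreddit" column names are assumptions you’d match to your own export:

```swift
import CreateML
import Foundation

// Load the scraped posts; the path and column names are assumptions,
// so adjust them to match your own CSV export.
let dataURL = URL(fileURLWithPath: "/path/to/reddit_posts.csv")
let data = try MLDataTable(contentsOf: dataURL)

// Hold out 20% of the rows for testing (see the aside on overfitting below).
let (trainingData, testingData) = data.randomSplit(by: 0.8, seed: 5)

// Train a text classifier that maps post titles to subreddit names.
let model = try MLTextClassifier(trainingData: trainingData,
                                 textColumn: "title",
                                 labelColumn: "subreddit")

// Compare accuracy on data the model has seen vs. data it hasn't.
let trainingAccuracy = (1.0 - model.trainingMetrics.classificationError) * 100
let evaluation = model.evaluation(on: testingData,
                                  textColumn: "title",
                                  labelColumn: "subreddit")
let testingAccuracy = (1.0 - evaluation.classificationError) * 100
print("Training: \(trainingAccuracy)%, testing: \(testingAccuracy)%")

// Export a Core ML model for use in an app.
let metadata = MLModelMetadata(author: "Your Name",
                               shortDescription: "Suggests a subreddit for a post title",
                               version: "1.0")
try model.write(to: URL(fileURLWithPath: "/path/to/SubredditSuggester.mlmodel"),
                metadata: metadata)
```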

One quick aside for those just getting started with machine learning. Notice that we break our dataset into two parts, so that 80% of our data belongs to a training set and 20% to a testing set. It’s extremely important that the model never sees the testing data during training. This way, we can guard against overfitting and make sure that our model will generalize to new data in the future.

We can see clear evidence of overfitting by comparing accuracy on the training data (97%) to the testing data (63%). To improve things in the future, we should probably gather more data, but for now, it’s a decent start.

A single accuracy number alone, though, shouldn’t make us feel comfortable enough to release the model. More testing is needed. To get a better understanding of what the experience is going to be like for users, let’s go back to Python and dig into this accuracy a bit more.

Testing it out

Imagine we rolled out this autosuggestion feature today, but in a parallel universe. In this parallel universe, the exact same set of users submits the exact same set of posts, but they submit them to the subreddit suggested to them by our model. How many of the posts in this parallel universe end up in the same subreddit as their counterparts in the real world? This is a good measure of how many users we would have saved a click.

Using the same API scraper as before, we grab another 100 posts from each subreddit, but this time we take the 100 most recent submissions instead of the top posts. Unless a top 1000 post was submitted within the last 24 hours (unlikely), there shouldn’t be any overlap with the training data. We’ll predict the subreddit each should be submitted to and see how accurate our model is.
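The post runs this check back in Python, but the same measurement can also be made directly in the training Playground. Here’s a minimal sketch, assuming the fresh posts were saved to a CSV with the same columns as before (the file path is hypothetical, and `model` is the MLTextClassifier trained above):

```swift
import CreateML
import Foundation

// Fresh posts scraped after training; the path is hypothetical.
let freshData = try MLDataTable(contentsOf: URL(fileURLWithPath: "/path/to/recent_posts.csv"))

// Run the model over every fresh post and compare against the actual subreddit.
let freshMetrics = model.evaluation(on: freshData,
                                    textColumn: "title",
                                    labelColumn: "subreddit")
print("Accuracy on fresh posts: \((1.0 - freshMetrics.classificationError) * 100)%")
```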

It looks like we correctly suggested the subreddit for 55% of posts! Because subreddits receive posts at different rates, it’s tough to say exactly what fraction of all posts in a day we’d save users a click on, but it’s probably pretty high. Digging even further, let’s take a look at a confusion matrix to figure out where our model is doing well and where it’s going wrong.

In a perfect world, we’d see bright yellow on the diagonal, indicating that the predictions from our model match the actual data.

Each row denotes the actual subreddit a post ended up in, while columns are subreddits suggested by our model. The color of the square tells us what fraction of posts in row X were predicted to end up in column Y. Based on our chart, the model does really well with subreddits that have more uniform syntactic structures (e.g. AskReddit titles have question words) as well as subreddits with clue words (e.g. TIL or IAmA). It gets confused with posts going to generic places like r/funny or r/videos and has a tough time distinguishing between subreddits with similar content like r/machinelearning and r/machineslearn.
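Create ML won’t compute a confusion matrix for you (a gripe that comes up again in the final thoughts), and the chart above was produced in Python, but tallying one from parallel arrays of actual and predicted labels is short enough to sketch in Swift:

```swift
// Tally counts[actual][predicted] from parallel arrays of subreddit names.
func confusionCounts(actual: [String], predicted: [String]) -> [String: [String: Int]] {
    var counts: [String: [String: Int]] = [:]
    for (a, p) in zip(actual, predicted) {
        counts[a, default: [:]][p, default: 0] += 1
    }
    return counts
}

// Normalize each row so a cell holds the fraction of posts from the actual
// subreddit (row) that the model routed to the predicted subreddit (column).
func confusionFractions(_ counts: [String: [String: Int]]) -> [String: [String: Double]] {
    counts.mapValues { row in
        let total = Double(row.values.reduce(0, +))
        return row.mapValues { Double($0) / total }
    }
}
```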

Finally, I came up with some titles of my own to do some anecdotal testing. I tried to write these with a specific subreddit in mind based on my knowledge of the site. Overall, the model gave me a great suggestion most of the time. The only mistake was the last example, where I think it saw the word Swift and assumed I was talking about the programming language instead of the pop star.
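A spot check like this is a one-liner per title in the Playground. The titles below are stand-ins for whatever you’d write yourself, and `model` is the trained MLTextClassifier from earlier:

```swift
// Handwritten titles for anecdotal testing; replace with your own.
let spotChecks = [
    "TIL that honey found in ancient tombs is still edible",
    "What's a movie everyone loves but you can't stand?",
]
for title in spotChecks {
    let suggestion = (try? model.prediction(from: title)) ?? "unknown"
    print("\(title) -> r/\(suggestion)")
}
```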

Based on these tests, we can be confident that our model is going to save a large fraction of users a click when submitting posts. We also have some good ideas on where we can make improvements in future versions, like looking for ways to improve performance in generic subreddits like r/funny .

It’d also be nice if the model could output a confidence score to tell us how sure it was of its suggestion. We could use that to decide whether a suggestion was worth showing to a user at all. That’s a limitation of Create ML today, though it’s something we could add with a little more custom work.

Adding the model to your app

Adding the model to an app requires that you drag the .mlmodel file into Xcode’s project navigator and use the Swift classes that Xcode generates to incorporate it into your UX. The following code can be used to call the model.
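Here’s a minimal sketch of that call; the class name SubredditSuggester is an assumption based on the exported file name, since Xcode names the generated class after the .mlmodel file:

```swift
import CoreML
import NaturalLanguage

// Xcode generates a Swift class named after the .mlmodel file; the name
// SubredditSuggester below assumes the file was SubredditSuggester.mlmodel.
guard let suggester = try? NLModel(mlModel: SubredditSuggester(configuration: MLModelConfiguration()).model) else {
    fatalError("Could not load the subreddit suggester model")
}

// Suggest a subreddit for the title the user is composing.
let draftTitle = "TIL the Eiffel Tower can grow more than six inches in summer"
if let subreddit = suggester.predictedLabel(for: draftTitle) {
    print("Suggested subreddit: r/\(subreddit)")
}
```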

Final thoughts

Saving users a click here and there can seem small, but these UX improvements can have big impacts on conversion and engagement. Thanks to tools like Apple’s Create ML, you don’t need to be an expert in machine learning to save your users some time. I hope more mobile developers feel empowered to make these features a part of their apps.

All of the code can be found in this GitHub repository.

Finally, as a machine learning engineer, I found it a fun exercise to try out tools made for a different developer audience. Swift is a promising language that I plan on learning more of. That said, it’s not quite ready for data science primetime. Here are a few items that are now on my wishlist:

More flexible JSON tools. I initially tried to write my data scraper in Swift, not Python, but Swift’s JSONDecoder requires building a complete Codable data model of the payload, which for the Reddit API would have been extremely complicated. There is no option to just decode into a dictionary the way Python’s json.loads does (Foundation’s older JSONSerialization gets you a [String: Any], but traversing it means casting at every level, as sketched after this list). In the end, I found myself back in Python land.

Access to lower-level metrics. Create ML is a great high-level abstraction, but it goes a little bit too far. There aren’t enough tools to troubleshoot a model or test it to the point that I’d feel comfortable putting it into production. For example, we measured training, validation, and testing accuracy directly in Swift, but there’s no easy way to compute something like a confusion matrix without writing a lot of code yourself.

More descriptive output. I wish the models offered by Create ML also provided a prediction confidence. If we knew when the model was unsure, we could build some better fallbacks into the UX.
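To make the JSON gripe concrete, here’s what digging post titles out of a Reddit listing looks like with Foundation’s untyped JSONSerialization; the nesting follows the Reddit API’s listing format:

```swift
import Foundation

// Reddit listings nest posts under data.children[].data, so every level
// needs its own cast when working without a Codable model.
func extractTitles(from json: Data) -> [String] {
    guard let root = try? JSONSerialization.jsonObject(with: json) as? [String: Any],
          let listing = root["data"] as? [String: Any],
          let children = listing["children"] as? [[String: Any]] else {
        return []
    }
    return children.compactMap { ($0["data"] as? [String: Any])?["title"] as? String }
}
```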

Discuss this post on Hacker News and Reddit.