A quick introduction to machine learning in R with caret

If you’ve been using R for a while, and you’ve been working with basic data visualization and data exploration techniques, the next logical step is to start learning machine learning.

To help you begin learning about machine learning in R, I’m going to introduce you to an R package: the caret package. We’ll build a very simple machine learning model as a way to learn some of caret’s basic syntax and functionality.

But before diving into caret, let’s quickly discuss what machine learning is and why we use it.

What is machine learning?

Machine learning is the study of data-driven, computational methods for making inferences and predictions.

Without going into extreme depth here, let’s unpack that by looking at an example.

A simple example

Imagine that you want to understand the relationship between car weight and car fuel efficiency (i.e., miles per gallon): how is fuel efficiency affected by a car’s weight?

To answer this question, you could obtain a dataset with several different car models, and attempt to identify a relationship between weight (which we’ll call wt) and miles per gallon (which we’ll call mpg).

A good starting point would be to simply plot the data, so first, we’ll create a scatterplot using R’s ggplot2 package and the mtcars dataset that ships with R:

library(ggplot2)

# Scatterplot of fuel efficiency (mpg) against car weight (wt)
ggplot(data = mtcars, aes(x = wt, y = mpg)) +
  geom_point()





Just examining the data visually, it’s pretty clear that there’s some relationship. But if we want to make more precise claims about the relationship between wt and mpg, we need to specify this relationship mathematically. So, as we press forward in our analysis, we’ll make the assumption that this relationship can be described mathematically; more precisely, we’ll assume that it can be described by some mathematical function, $f(x)$. In the case of the above example, we’ll be making the assumption that “miles per gallon” can be described as a function of “car weight”.

Assuming this type of mathematical relationship, machine learning provides a set of methods for identifying that relationship. Said differently, machine learning provides a set of computational methods that accept data observations as inputs, and subsequently estimate that mathematical function, $f(x)$; machine learning methods learn the relationship by being trained with an input dataset.

Ultimately, once we have this mathematical function (a model), we can use that model to make predictions and inferences.
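To make this concrete, here is a minimal sketch in base R (no extra packages needed). It assumes, purely for illustration, that a straight line is a reasonable form for the function, and it uses lm() as a stand-in learning method. The names fit, new_car, and predicted_mpg are placeholders of my own.

```r
# Estimate mpg as a function of wt from the built-in mtcars data.
# Assumption for this sketch: a straight-line relationship, fit with lm().
fit <- lm(mpg ~ wt, data = mtcars)

# Inference: the estimated intercept and slope describe the relationship.
print(coef(fit))

# Prediction: estimated mpg for a hypothetical car weighing 3,000 lbs
# (wt is recorded in units of 1,000 lbs in mtcars).
new_car <- data.frame(wt = 3)
predicted_mpg <- predict(fit, newdata = new_car)
print(predicted_mpg)
```

The same two steps, estimating the function from data and then applying it to new inputs, are what the rest of this post does with caret.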

How much math you really need to know

What I just wrote in the last few paragraphs about “estimating functions” and “mathematical relationships” might cause you to ask a question: “how much math do I need to know to do machine learning?”

Ok, here is some good news: to implement basic machine learning techniques, you don’t need to know much math.

To be clear, there is quite a bit of math involved in machine learning, but most of that math is taken care of for you. What I mean is that, for the most part, R libraries and functions perform the mathematical calculations for you. You just need to know which functions to use, and when to use them.

Here’s an analogy: if you were a carpenter, you wouldn’t need to build your own power tools. You wouldn’t need to build your own drill and power saw. Therefore, you wouldn’t need to understand the mathematics, physics, and electrical engineering principles that would be required to construct those tools from scratch. You could just go and buy them “off the shelf.” To be clear, you’d still need to learn how to use those tools, but you wouldn’t need a deep understanding of math and electrical engineering to operate them.

When you’re first getting started with machine learning, the situation is very similar: you can learn to use some of the tools, without knowing the deep mathematics that makes those tools work.

Having said that, the above analogy is somewhat imperfect. At some point, as you progress to more advanced topics, it will be very beneficial to know the underlying mathematics. My argument, however, is that in the beginning, you can still be productive without a deep understanding of calculus, linear algebra, etc. In any case, I’ll be writing more about “how much math you need” in another blog post.

Ok, so you don’t need to know that much math to get started, but you’re not entirely off the hook. As I noted above, you still need to know how to use the tools properly.

In some sense, this is one of the challenges of using machine learning tools in R: many of them can be difficult to use.

R has many packages for implementing various machine learning methods, but unfortunately many of these tools were designed separately, and they are not always consistent in how they work. The syntax for some of the machine learning tools is very awkward, and syntax from one tool to the next is not always the same. If you don’t know where to start, machine learning in R can become very confusing.
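As a small illustration of that inconsistency, the sketch below asks three different modeling functions for class probabilities. The choice of models (glm() from base R, rpart() from the rpart package, and lda() from MASS) and the toy problem of predicting am from wt are mine, not part of the example above. Both packages normally ship with a standard R installation.

```r
library(rpart)  # classification trees
library(MASS)   # linear discriminant analysis

# Toy classification problem: predict transmission type (am) from weight.
cars_df <- transform(mtcars, am = factor(am))

fit_glm   <- glm(am ~ wt, data = cars_df, family = binomial)
fit_rpart <- rpart(am ~ wt, data = cars_df, method = "class")
fit_lda   <- lda(am ~ wt, data = cars_df)

# Same question, three different ways to ask it, and three result shapes.
p_glm   <- predict(fit_glm, newdata = cars_df, type = "response")  # numeric vector
p_rpart <- predict(fit_rpart, newdata = cars_df, type = "prob")    # matrix, one column per class
p_lda   <- predict(fit_lda, newdata = cars_df)$posterior           # matrix stored inside a list

str(p_glm)
str(p_rpart)
str(p_lda)
```

Each call needs a different argument (or none at all), and each returns the probabilities in a different structure. Multiply that by every model you might want to try, and the bookkeeping adds up quickly.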

This is why I recommend using the caret package to do machine learning in R.

A quick introduction to caret

For starters, let’s discuss what caret is.

The caret package is a set of tools for building machine learning models in R. The name “caret” stands for Classification And REgression Training. As the name implies, the caret package gives you a toolkit for building classification models and regression models. Moreover, caret provides you with essential tools for:

– Data preparation, including: imputation, centering/scaling data, removing correlated predictors, reducing skewness

– Data splitting

– Model evaluation

– Variable selection
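To give a feel for a few of these tools, here is a minimal sketch using caret’s createDataPartition(), preProcess(), and findCorrelation() functions on mtcars. The 80/20 split, the three predictors, and the 0.9 correlation cutoff are arbitrary choices of mine for illustration, and the sketch assumes the caret package is already installed.

```r
library(caret)

set.seed(42)  # make the random split reproducible

# Data splitting: hold out roughly 20% of the rows for testing.
train_idx <- createDataPartition(mtcars$mpg, p = 0.8, list = FALSE)
train_set <- mtcars[train_idx, ]
test_set  <- mtcars[-train_idx, ]

# Data preparation: learn centering/scaling from the training rows only,
# then apply the same transformation to both sets.
predictors <- c("wt", "hp", "disp")
pp <- preProcess(train_set[, predictors], method = c("center", "scale"))
train_scaled <- predict(pp, train_set[, predictors])
test_scaled  <- predict(pp, test_set[, predictors])

# Removing correlated predictors: findCorrelation() suggests columns to drop
# (the result may be empty if no pair exceeds the cutoff).
high_cor <- findCorrelation(cor(train_set[, predictors]), cutoff = 0.9)
print(predictors[high_cor])
```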

Caret simplifies machine learning in R

While caret has broad functionality, the real reason to use caret is that it’s simple and easy to use.

As noted above, one of the major problems with machine learning in R is that most of R’s different machine learning tools have different interfaces. They almost all “work” a little differently from one another: the syntax is slightly different from one modeling tool to the next; tools for different parts of the machine learning workflow don’t always “work well” together; tools for fine tuning models or performing critical functions may be awkward or difficult to work with. Said succinctly, R has many machine learning tools, but they can be extremely clumsy to work with.

Caret solves this problem. To simplify the process, caret provides tools for almost every part of the model building process, and moreover, provides a common interface to these different machine learning methods.

For example, caret provides a simple, common interface to almost every machine learning algorithm in R. When using caret, different learning methods like linear regression, neural networks, and support vector machines all share a common syntax (the syntax is basically identical, except for a few minor changes).

Moreover, additional parts of the machine learning workflow – like cross validation and parameter tuning – are built directly into this common interface.

To say that more simply, caret provides you with an easy-to-use toolkit for building many different model types and executing critical parts of the ML workflow. This simple interface enables rapid, iterative modeling. In turn, this iterative workflow will allow you to develop good models faster, with less effort, and with less frustration.
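Here is a rough sketch of what that common interface looks like in practice, using caret’s train() function (discussed later in this post). I’ve paired linear regression with k-nearest neighbors rather than a neural network or support vector machine, because, as far as I know, method = "knn" does not require installing an additional package. The 5-fold cross-validation setting and the object names are my own choices for illustration.

```r
library(caret)

set.seed(123)  # resampling is random; fix the seed for reproducibility

# 5-fold cross-validation, shared by both models below.
cv_5fold <- trainControl(method = "cv", number = 5)

# Linear regression.
fit_lm <- train(mpg ~ wt, data = mtcars, method = "lm", trControl = cv_5fold)

# k-nearest neighbors: same call, only the method string changes.
# caret also tunes k over a small default grid during cross-validation.
fit_knn <- train(mpg ~ wt, data = mtcars, method = "knn", trControl = cv_5fold)

# Cross-validated error estimates for each model.
print(fit_lm$results)
print(fit_knn$results)
```

Swapping one model for another is a one-word change, and the cross-validation and tuning come along without any extra code.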

Caret’s syntax

Now that you’ve been introduced to caret, let’s return to the example above (of mpg vs. wt) and see how caret works.

Again, imagine you want to learn the relationship between mpg and wt. As noted above, in mathematical terms, this means identifying a function, $f(x)$, that describes the relationship between wt and mpg.

Here in this example, we’re going to make an additional assumption that will simplify the process somewhat: we’re going to assume that the relationship is linear; that is, we’ll assume that it can be described by a straight line of the form $f(x) = \beta_0 + \beta_1 x$, where $x$ is wt and $f(x)$ is the predicted mpg.

In terms of our modeling effort, this means that we’ll be using linear regression to build our machine learning model.

Without going into the details of linear regression (I’ll save that for another blog post), let’s look at how we implement linear regression with caret.

The train() function

The core of caret’s functionality is the train() function.

train() is the function that we use to “train” the model. That is, train() is the function that will “learn” the relationship between mpg and wt.

Let’s take a look at this syntactically.

Here is the syntax for a linear regression model, regressing mpg on wt.
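A minimal sketch of that call, assuming the caret package is installed. The object name model_mpg_wt is just a label I’ve chosen; the formula mpg ~ wt says “model mpg as a function of wt”, and method = "lm" selects linear regression.

```r
library(caret)

# Train a linear regression model: mpg as a function of wt.
model_mpg_wt <- train(mpg ~ wt, data = mtcars, method = "lm")

# Inspect the estimated intercept and slope of the fitted line.
print(coef(model_mpg_wt$finalModel))
```

Note that, by default, train() also resamples the data behind the scenes to estimate model performance, so the call does a bit more than fit a single line.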