How to tune a BigQuery ML classification model to achieve a desired precision or recall

Select the probability threshold based on the ROC curve

BigQuery provides an incredibly convenient way to train machine learning models on large, structured datasets. In an earlier article, I showed you how to train a classification model to predict flight delays. Here’s the SQL that trains a model to predict whether a flight is going to be late by 15 minutes or more:

CREATE OR REPLACE MODEL flights.delayed
OPTIONS (model_type='logistic_reg',
         input_label_cols=['late'],
         data_split_method='custom',
         data_split_col='is_train_row') AS
SELECT
  IF(arr_delay < 15, 0, 1) AS late,
  carrier,
  origin,
  dest,
  dep_delay,
  taxi_out,
  distance,
  is_train_day = 'True' AS is_train_row
FROM
  `cloud-training-demos.flights.tzcorr` AS f
JOIN `cloud-training-demos.flights.trainday` AS t
USING(FL_DATE)
WHERE
  arr_delay IS NOT NULL

This will create a model called delayed in a dataset called flights. The model will use carrier, origin, etc. as inputs and predict whether or not the flight will be late. The ground truth comes from historical data of US airline arrival delays. Note that I have pre-split the data by day into whether a row is training data (is_train_day=True) or should be used for independent evaluation (is_train_day=False). This is important because flight delays on the same day tend to be highly correlated.

When the training finishes, BigQuery reports training statistics, but we should look at evaluation statistics on the withheld dataset (is_train_day=False):

SELECT * FROM ML.EVALUATE(MODEL flights.delayed,
  (
  SELECT
    IF(arr_delay < 15, 0, 1) AS late,
    carrier,
    origin,
    dest,
    dep_delay,
    taxi_out,
    distance
  FROM
    `cloud-training-demos.flights.tzcorr` AS f
  JOIN `cloud-training-demos.flights.trainday` AS t
  USING(FL_DATE)
  WHERE
    arr_delay IS NOT NULL
    AND is_train_day = 'False'
  ))

This gives us:

Evaluation statistics for flight delay model

The flight delays data is an unbalanced dataset — only 18% of flights are late, which I found with this query:

SELECT
  SUM(IF(arr_delay < 15, 0, 1))/COUNT(arr_delay) AS fraction_late
FROM
  `cloud-training-demos.flights.tzcorr` AS f
JOIN `cloud-training-demos.flights.trainday` AS t
USING(FL_DATE)
WHERE
  arr_delay IS NOT NULL
  AND is_train_day = 'True'

In an unbalanced dataset like this, accuracy is not a very informative metric, and it is common to have a business goal of meeting a desired precision (if 1s are more common) or recall (if 0s are more common).
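To make this concrete, here is a small Python sketch, with made-up numbers matching the 18% late rate, of why accuracy misleads here: a degenerate model that always predicts "on time" is 82% accurate yet catches zero late flights.

```python
# Hypothetical illustration: 18% of flights are late (label 1),
# and a model that predicts "on time" (0) for every flight.
labels = [1] * 18 + [0] * 82   # 18% positive, like the flights data
always_on_time = [0] * 100     # degenerate model: never predicts "late"

accuracy = sum(p == y for p, y in zip(always_on_time, labels)) / len(labels)
recall = sum(p == 1 and y == 1
             for p, y in zip(always_on_time, labels)) / sum(labels)

print(accuracy)  # 0.82 -- looks respectable...
print(recall)    # 0.0  -- ...but the model catches no late flights at all
```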

ROC curve

A classification model actually returns the probability that the flight will be late; the evaluation statistics in the table above are computed by thresholding that probability at 0.5. We can vary this threshold and compute the recall and false positive rate at each setting. Plotting these points gives the ROC curve (Receiver Operating Characteristic — an acronym that dates back to radar days), and we can get BigQuery to generate this curve using:

SELECT * FROM ML.ROC_CURVE(MODEL flights.delayed,
  (
  SELECT
    IF(arr_delay < 15, 0, 1) AS late,
    carrier,
    origin,
    dest,
    dep_delay,
    taxi_out,
    distance
  FROM
    `cloud-training-demos.flights.tzcorr` AS f
  JOIN `cloud-training-demos.flights.trainday` AS t
  USING(FL_DATE)
  WHERE
    arr_delay IS NOT NULL
    AND is_train_day = 'True'
  ))

Essentially, it’s the same as the ML.EVALUATE query, except that you use ML.ROC_CURVE instead. Because we are going to tune the threshold, we should do this on the training data (is_train_day=True). The table comes back with 101 rows, with the false positive rate and recall for each threshold that was tried. By default, the thresholds are chosen by computing percentiles of the predicted probability on the training dataset:

Recall, false positive rate, etc. as the probability threshold of the classification model is changed
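What ML.ROC_CURVE computes for each row can be sketched in a few lines of Python (the probabilities and labels below are made up for illustration): for each candidate threshold, classify p >= threshold as "late", then measure recall (true positive rate) and false positive rate.

```python
# Made-up predicted probabilities and true labels for illustration.
probs  = [0.9, 0.8, 0.7, 0.6, 0.4, 0.35, 0.3, 0.2, 0.1, 0.05]
labels = [1,   1,   0,   1,   1,   0,    0,   0,   0,   0]

def roc_point(threshold):
    """Return (recall, false_positive_rate) at a given probability threshold."""
    preds = [1 if p >= threshold else 0 for p in probs]
    tp = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
    fp = sum(p == 1 and y == 0 for p, y in zip(preds, labels))
    pos = sum(labels)            # number of actual positives (late flights)
    neg = len(labels) - pos      # number of actual negatives (on-time flights)
    return tp / pos, fp / neg

# Lowering the threshold trades a higher recall for a higher false positive rate.
for t in (0.25, 0.5, 0.75):
    recall, fpr = roc_point(t)
    print(t, recall, fpr)
```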

We can click on the “Explore in Data Studio” link in the BigQuery UI to get a graph (set false_positive_rate as the dimension and recall as the metric):

ROC curve in Data Studio

A more useful view is to set threshold as the dimension and the other two statistics (recall and false_positive_rate) as metrics:

Graph the variation of recall and false_positive_rate by threshold in Data Studio and choose the threshold that gives you a recall close to 70%.

We can tune the threshold to achieve a certain recall (and then live with whatever precision results). Let’s say that we want to make sure to identify at least 70% of late flights, i.e. we want a recall of 0.7. From the graph above, you can see that the probability threshold needs to be about 0.335. At that threshold, we will falsely identify about 1% of on-time flights as being late.
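The lookup described above — find the threshold whose recall is closest to the target — can be sketched in Python over a handful of hypothetical ROC rows:

```python
# Hypothetical (threshold, recall) rows, standing in for ML.ROC_CURVE output.
roc_rows = [
    (0.20, 0.85),
    (0.335, 0.70),
    (0.50, 0.55),
    (0.80, 0.30),
]

target = 0.7
# Pick the row minimizing |recall - target|, mirroring ORDER BY ABS(recall - 0.7).
best_threshold, best_recall = min(roc_rows, key=lambda row: abs(row[1] - target))
print(best_threshold)  # 0.335
```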

Tuning probability threshold in SQL

Here’s a query that will return the threshold without having to draw a graph and mouse over it:

WITH roc AS (
  SELECT * FROM ML.ROC_CURVE(MODEL flights.delayed,
  (
  SELECT
    IF(arr_delay < 15, 0, 1) AS late,
    carrier,
    origin,
    dest,
    dep_delay,
    taxi_out,
    distance
  FROM
    `cloud-training-demos.flights.tzcorr` AS f
  JOIN `cloud-training-demos.flights.trainday` AS t
  USING(FL_DATE)
  WHERE
    arr_delay IS NOT NULL
    AND is_train_day = 'True'
  ))
)
SELECT
  threshold, recall, false_positive_rate,
  ABS(recall - 0.7) AS from_desired_recall
FROM roc
ORDER BY from_desired_recall ASC
LIMIT 1

This gives us the threshold we need to use to get a recall of 0.7:

Choosing the threshold that gives us a recall of 0.7.


Evaluation and prediction with new probability threshold

We can evaluate (on the withheld dataset) to see if we do hit the desired recall:

SELECT * FROM ML.EVALUATE(MODEL flights.delayed,
  (
  SELECT
    IF(arr_delay < 15, 0, 1) AS late,
    carrier,
    origin,
    dest,
    dep_delay,
    taxi_out,
    distance
  FROM
    `cloud-training-demos.flights.tzcorr` AS f
  JOIN `cloud-training-demos.flights.trainday` AS t
  USING(FL_DATE)
  WHERE
    arr_delay IS NOT NULL
    AND is_train_day = 'False'
  ), STRUCT(0.3348 AS threshold))

This gives us:

We get the hoped-for recall of 70% on the independent evaluation dataset!

Hurray! We got recall=70% on the independent evaluation dataset also, so it appears that we have not overfit.

As with ML.EVALUATE, we can specify the desired probability threshold when we make predictions with ML.PREDICT:

SELECT * FROM ML.PREDICT(MODEL flights.delayed,
  (
  SELECT
    IF(arr_delay < 15, 0, 1) AS late,
    carrier,
    origin,
    dest,
    dep_delay,
    taxi_out,
    distance
  FROM
    `cloud-training-demos.flights.tzcorr` AS f
  JOIN `cloud-training-demos.flights.trainday` AS t
  USING(FL_DATE)
  WHERE
    arr_delay IS NOT NULL
    AND is_train_day = 'False'
  LIMIT 10
  ), STRUCT(0.3348 AS threshold))
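Conceptually, the threshold STRUCT just changes the cutoff used to turn the model’s probability into a predicted label. A minimal Python sketch of that cutoff logic, using hypothetical probabilities:

```python
# The tuned cutoff from the ROC analysis, replacing the 0.5 default.
THRESHOLD = 0.3348

def predict_late(prob_of_late):
    """Turn a predicted probability into a 0/1 label using the tuned cutoff."""
    return 1 if prob_of_late >= THRESHOLD else 0

# A flight with a 40% chance of being late is now flagged (it would not
# be at the 0.5 default), which is how the higher recall is achieved.
print(predict_late(0.40))  # 1
print(predict_late(0.20))  # 0
```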

Enjoy!