Don't use deep learning - your data isn't that big

The other day Brian was at a National Academies meeting and he gave one of his usual classic quotes:

Best quote from NAS DS Round Table: "I mean, do we need deep learning to analyze 30 subjects?" - B Caffo @simplystats #datascienceinreallife — CMU Statistics & Data Science (@CMU_Stats) May 1, 2017

When I saw that quote I was reminded of the blog post Don’t use hadoop - your data isn’t that big. Just as Hadoop had at the time that post was written, deep learning has reached a level of mania and hype that means people are trying it for every problem.

The issue is that only a very few places actually have the data to do deep learning. Sure, if you are Google and have everyone’s emails over the last decade, or if you are Facebook and have billions of tagged images, then deep learning makes sense. I’ve always thought that the major advantage of deep learning over simpler models is that, if you have a massive amount of data, you can fit a massive number of parameters.

When your dataset isn’t that big, doing something simpler is often more interpretable and works just as well, because a heavily parameterized model is prone to overfitting. To test this idea I’m going to do an experiment on the digits data. I’m going to build a model just to predict one versus zero. I’m going to do that using a simple linear regression with the top ten predictors, and I’m going to use a deep neural network following this post. First let’s load the packages we need:

library(readr)
library(h2o)
library(caret)
library(dplyr)
library(genefilter)
library(RSkittleBrewer)

Then load the data:

dat = read_csv("/home/jtleek/train.csv")
dat = dat %>% filter(label < 2)

Now what I’m going to do is break the data into a training set and a testing set, leaving 20% for testing.

library(caret)
set.seed(12345)
inTrain = createDataPartition(dat$label, p=0.8, list=FALSE)
training = dat[inTrain,]
testing = dat[-inTrain,]
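For readers without caret installed, the idea behind a split like this can be sketched in base R. This is only an illustration on a made-up label vector, not the digits data, and it samples within each class; `createDataPartition` has its own rules for grouping the outcome, so treat this as an approximation of what it aims for rather than a reimplementation.

```r
# Base-R sketch of a stratified 80/20 split on a made-up 0/1 label vector.
set.seed(12345)
label <- rep(c(0, 1), times = c(400, 600))   # hypothetical labels, not the digits data

# Within each class, sample 80% of the row indices for training.
idx_by_class <- split(seq_along(label), label)
in_train <- sort(unlist(lapply(idx_by_class, function(idx) {
  idx[sample.int(length(idx), size = round(0.8 * length(idx)))]
}), use.names = FALSE))

train_label <- label[in_train]
test_label  <- label[-in_train]

# The class proportions should match in both pieces.
prop.table(table(train_label))
prop.table(table(test_label))
```

Sampling within each class keeps the share of ones and zeros the same in the training and testing pieces, which matters more as the samples get small.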

Using these data we can now try our experiment. I’m going to compare two methods:

The Leekasso, an approach that involves picking the 10 most predictive pixels and using them in a linear model.

A deep learning architecture with 5 hidden layers and 160 nodes per layer, with a Tanh activation function and 20 epochs, as described here.
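To get a feel for how different these two models are in size, here is a rough parameter count. It assumes a plain fully connected network with 784 pixel inputs and 2 output units; those two numbers are assumptions for illustration, since the fitted network may drop constant pixels or lay out its output differently. The helper `count_params` is hypothetical and not part of any package.

```r
# Count weights and biases in a fully connected network:
# each layer has (inputs * outputs) weights plus one bias per output unit.
count_params <- function(layer_sizes) {
  n_in  <- head(layer_sizes, -1)
  n_out <- tail(layer_sizes, -1)
  sum(n_in * n_out + n_out)
}

# Assumed layout: 784 pixel inputs, 5 hidden layers of 160 nodes, 2 output units.
deep_params <- count_params(c(784, rep(160, 5), 2))

# Leekasso: 10 pixel coefficients plus an intercept.
leekasso_params <- 10 + 1

deep_params                     # 228962 under the assumptions above
deep_params / leekasso_params   # ratio of parameter counts
```

Under these assumptions the network has a bit under a quarter of a million parameters, against 11 for the linear model.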

I’m going to create training sets of size 10 to 80, increasing by 5 each time. I’m going to do this 5 times so I can try to average out some of the noise.

local.h2o <- h2o.init(ip = "localhost", port = 54321, startH2O = TRUE, nthreads=-1)
ntrain = dim(training)[1]
ss = seq(10,80,by=5)
B = 5
# one row per repetition, one column per training set size
leek = deep = matrix(NA,ncol=length(ss),nrow=B)
tsData = as.h2o(testing)
testing_labels = testing$label
for(i in seq_along(ss)){
  for(b in 1:B){
    # draw a training subsample of roughly ss[i] rows
    samp = createDataPartition(training$label,p=ss[i]/length(training$label))$Resample1
    training0 = training[samp,]
    # Leekasso: rank pixels by F-test p-value and keep the top 10
    tmp = colFtests(as.matrix(training0[,-1]),as.factor(training0$label))
    index = which(rank(tmp$p.value) <= 10)
    leekasso0 = lm(training0$label ~ ., data=training0[,(index + 1)])
    leek[b,i] = mean((predict(leekasso0,testing) > 0.5) == testing$label)
    # deep learning: h2o needs a factor label for classification
    training0$label = as.factor(training0$label)
    trData = as.h2o(training0)
    res.dl <- h2o.deeplearning(x = 2:785, y = 1, trData, activation = "Tanh", hidden=rep(160,5),epochs = 20)
    #use model to predict testing dataset
    pred.dl<-h2o.predict(object=res.dl, newdata=tsData[,-1])
    pred.dl.df<-as.data.frame(pred.dl)
    deep[b,i] = sum(diag(table(testing_labels,pred.dl.df$predict)))/length(testing_labels)
  }
  print(i)
}
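The Leekasso step in that loop can also be tried without the digits file or any extra packages. The sketch below runs the same recipe (screen each column, keep the 10 smallest p-values, fit `lm`, threshold at 0.5) on simulated data. Everything about the simulation is an assumption made up for illustration: 200 fake "pixels", 5 of them informative, and a pooled-variance t-test standing in for `colFtests`, which gives the same p-values as a two-group F-test.

```r
# Leekasso-style sketch on simulated data (base R only; not the digits data).
set.seed(1)
n <- 60; p <- 200
y <- rep(c(0, 1), each = n / 2)
x <- matrix(rnorm(n * p), nrow = n, ncol = p)
x[y == 1, 1:5] <- x[y == 1, 1:5] + 2      # make the first 5 columns informative
colnames(x) <- paste0("pixel", seq_len(p))

# Screen each column with a pooled two-sample t-test; keep the 10 smallest p-values.
pvals <- apply(x, 2, function(col) {
  t.test(col[y == 1], col[y == 0], var.equal = TRUE)$p.value
})
top10 <- order(pvals)[1:10]

# Fit a linear model on the selected columns only.
train_df <- data.frame(label = y, x[, top10])
fit <- lm(label ~ ., data = train_df)

# Evaluate on fresh simulated data generated the same way.
y_new <- rep(c(0, 1), each = 100)
x_new <- matrix(rnorm(200 * p), nrow = 200, ncol = p)
x_new[y_new == 1, 1:5] <- x_new[y_new == 1, 1:5] + 2
colnames(x_new) <- colnames(x)

# Threshold the linear predictions at 0.5 to get class calls.
acc <- mean((predict(fit, as.data.frame(x_new)) > 0.5) == y_new)
acc
```

Note that the screening uses only the training rows, so the accuracy on the fresh data is an honest estimate rather than one inflated by selecting features on the test set.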

Now we plot the accuracy of each of these methods versus sample size with vertical bars showing the 10th and 90th percentiles for accuracy.

trop = RSkittleBrewer("tropical")
plot(ss,colMeans(leek),col=trop[1],type="l", ylim=c(0.5,1), xlab="Training Set Sample Size", ylab="Accuracy",lwd=3)
upp = apply(leek,2,quantile,0.9,na.rm=T)
low = apply(leek,2,quantile,0.1,na.rm=T)
segments(ss,low,ss,upp,col=trop[1],lwd=3)
lines(ss,colMeans(deep),col=trop[2],lwd=3)
upp = apply(deep,2,quantile,0.9,na.rm=T)
low = apply(deep,2,quantile,0.1,na.rm=T)
segments(ss,low,ss,upp,col=trop[2],lwd=3)
legend(50,0.7, legend=c("Top 10 (Leekasso)", "Deep Learning"), col=trop[1:2],lwd=3,lty=1)

For low training set sample sizes it looks like the simpler method (just picking the top 10 pixels and using a linear model) slightly outperforms the more complicated method. As the sample size increases, we see the more complicated method catch up and reach comparable test set accuracy.

This is an extremely simple example, but it illustrates the larger point that Brian was making above. The sample size matters. If you are Google, Amazon, or Facebook and have near-infinite data, it makes sense to deep learn. But if you have a more modest sample size you may not be buying any accuracy - just losing interpretability. Although I guess you still get to keep the hype :).
