Crazy paths often lead to the right destination!

Optimizing a cost function is one of the most important concepts in Machine Learning. Gradient Descent is the most common optimization algorithm and the foundation of how we train an ML model. But it can be really slow for large datasets. That’s why we use a variant of this algorithm, known as Stochastic Gradient Descent (SGD), to make our model learn a lot faster. But what makes it faster? And does that speed come at a cost?

Well…Before diving into SGD, here’s a quick reminder of Vanilla Gradient Descent…

We first randomly initialize the weights of our model. Using these weights, we calculate the cost over all the data points in the training set. Then we compute the gradient of the cost w.r.t. the weights and finally update the weights. This process repeats until we reach the minimum.

The update step is something like this…

J is the cost over all the training data points
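Written out, with weights w, learning rate α, and m training examples (notation assumed here), the update is roughly:

w \leftarrow w - \alpha \, \nabla_w J(w), \qquad J(w) = \frac{1}{m} \sum_{i=1}^{m} J_i(w)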

Now, what happens if the number of data points in our training set becomes large, say m = 10,000,000? In this case, we have to sum the cost over all m examples just to perform one update step!

Here comes SGD to the rescue…

Instead of calculating the cost over all the data points, we calculate the cost of one single (randomly picked) data point and the corresponding gradient. Then we update the weights.

The update step is as follows…

J_i is the cost of ith training example
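In the same assumed notation, with i the index of the single example picked at the current step, the update is roughly:

w \leftarrow w - \alpha \, \nabla_w J_i(w)

Each step now touches only one example instead of all m, which is what makes it so cheap.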

We can easily see that in this case the update steps are performed very quickly, which is why we can get close to the minimum in a very small amount of time.

But… why does SGD work?

The key concept is that we don’t need to check all the training examples to get an idea of the direction of the decreasing slope. By analyzing only one example at a time and following its slope, we can reach a point that is very close to the actual minimum. Here’s an intuition…

Suppose you have made an app and want to improve it by taking feedback from 100 customers. You can do it in two ways. In the first way, you give the app to the first customer and take their feedback, then to the second one, then the third, and so on. After collecting feedback from all of them, you improve your app. But in the second way, you improve the app as soon as you get feedback from the first customer. Then you give it to the second one and improve it again before giving it to the third one. Notice that in this way you are improving your app at a much faster rate and can reach an optimal point much earlier.

Hopefully, you can tell that the first process is Vanilla Gradient Descent and the second one is SGD.

But SGD has some cons too…

SGD is much faster, but its convergence path is noisier than that of vanilla gradient descent. This is because each step computes not the actual gradient but an approximation based on a single example, so we see a lot of fluctuations in the cost. Still, for large datasets it is a much better choice.

Convergence paths are shown on a contour plot

We can see the noise of SGD in the above contour plot. Note that vanilla GD takes fewer updates, but each update happens only after one whole epoch. SGD takes many more update steps, but it needs fewer epochs, i.e. the number of times we iterate through all the examples is smaller, and thus it is a much faster process overall.

As you can see in the plot, there is a third variant of gradient descent known as Mini-batch gradient descent, which combines the speed of SGD and the accuracy of GD. In this case, we take a fixed number of training examples at a time (known as the batch size) and compute the cost and corresponding gradient over that batch. Then we update the weights and continue the same process with the next batch. If batch size = 1 it becomes SGD, and if batch size = m it becomes vanilla GD.

J_b is the cost of bth batch
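In the same assumed notation, with B_b the set of examples in the b-th batch, the update is roughly:

w \leftarrow w - \alpha \, \nabla_w J_b(w), \qquad J_b(w) = \frac{1}{|B_b|} \sum_{i \in B_b} J_i(w)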

Implementation from scratch

Here’s a Python implementation of mini-batch gradient descent from scratch. You can simply set batch_size = 1 to get SGD. In this code, I’ve used SGD to optimize the cost function of logistic regression for a simple binary classification problem.
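Here is a minimal sketch of that idea (assuming NumPy and a sigmoid/cross-entropy setup for logistic regression; names like minibatch_gd, lr, and epochs are illustrative, not the exact code):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def minibatch_gd(X, y, batch_size=32, lr=0.1, epochs=100):
    """Mini-batch gradient descent for logistic regression.
    X: (m, n) feature matrix, y: (m,) labels in {0, 1}."""
    m, n = X.shape
    w = np.zeros(n)
    b = 0.0
    for epoch in range(epochs):
        # Shuffle the data at the start of every epoch
        idx = np.random.permutation(m)
        X_shuf, y_shuf = X[idx], y[idx]
        for start in range(0, m, batch_size):
            xb = X_shuf[start:start + batch_size]
            yb = y_shuf[start:start + batch_size]
            # Forward pass: predicted probabilities for this batch
            p = sigmoid(xb @ w + b)
            # Gradient of the batch cost J_b w.r.t. w and b
            grad_w = xb.T @ (p - yb) / len(yb)
            grad_b = np.mean(p - yb)
            # Update step
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b

# Toy usage: batch_size=1 gives SGD, batch_size=m gives vanilla GD
X = np.random.randn(200, 2)
y = (X[:, 0] + X[:, 1] > 0).astype(float)
w, b = minibatch_gd(X, y, batch_size=1)  # SGD
preds = (sigmoid(X @ w + b) > 0.5).astype(float)
print("train accuracy:", (preds == y).mean())
```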

Find the full code here.

Still curious? Watch a video that I made recently…

I hope you enjoyed the read. Until next time… Happy learning!