A vectorised proof with no summations and no indices.

TLDR

Backpropagation is the spine of deep learning. While there is plenty of literature on this subject, little of it thoroughly explains where the gradient formulas (∂loss / ∂W) needed for backpropagation come from. Even when it does, the math tends to get long and filled with indices everywhere because of the high dimensionality of the problem: you have an index for the number of samples, an index for the number of layers and one for the number of neurons in each layer.

So every time I was preparing for a job interview, or just wanted to refresh the machine learning I had studied in the past, I had a hard time writing out the math behind backpropagation without looking at a textbook. Although I understood how it works, I found the formulas very unintuitive and sometimes confusing…

So the idea of this article is to demonstrate the backpropagation formulas in an elegant way using only vectorised calculus: no indices i, j, k, … and no summation ∑ at all! (The only index we'll be using is the one that denotes the layer.)

Introduction

Suppose we have the following neural network (NN) architecture:

NN with 2 hidden layers and an output of 3 classes

Our NN has two hidden layers and one output layer of three classes. We will be using softmax activation for all the layers.

Training this model means minimizing our loss function, which is the cross entropy (log loss). Since we have 3 classes, the log loss is: 𝜉(Y, X₀, W₀, W₁, W₂) = −sum[Y ○ log(X₃)]

With “sum” being the sum of all the elements of a matrix, and “○” being the element-wise product of two matrices of the same shape (also called the Hadamard product).
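As a quick illustration of this notation, here is a minimal NumPy sketch of the loss (the names Y and X3 mirror the matrices defined below; the code is illustrative, not taken from the derivation itself):

```python
import numpy as np

def log_loss(Y, X3):
    # Y: n x 3 one-hot labels, X3: n x 3 predicted probabilities.
    # "*" is the element-wise (Hadamard) product "○";
    # np.sum adds up all the entries of the resulting matrix.
    return -np.sum(Y * np.log(X3))
```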

X₀ is an n×3 matrix, with n being the size of the mini-batch.

W₀ is a 3×5 matrix. It holds the weights for the transition from layer 0 to layer 1.

X₁ is an n×5 matrix. It represents the data transformation of the first layer.

W₁ is a 5×4 matrix. It holds the weights for the transition from layer 1 to layer 2.

X₂ is an n×4 matrix. It represents the data transformation of the second layer.

W₂ is a 4×3 matrix. It holds the weights for the transition from layer 2 to layer 3.

X₃ is an n×3 matrix. It represents the data transformation of the third layer.

Y is an n×3 matrix, with 3 being the number of classes.
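To make these shapes concrete, here is a hedged NumPy sketch of the forward pass. The random weights and inputs are placeholders, and the softmax helper is the usual row-wise definition, spelled out here since the article applies softmax to every layer:

```python
import numpy as np

def softmax(Z):
    # Row-wise softmax; subtracting the row max keeps exp() stable.
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)

n = 8                                  # mini-batch size (arbitrary)
rng = np.random.default_rng(0)
X0 = rng.normal(size=(n, 3))           # input data, n x 3
W0 = rng.normal(size=(3, 5))           # weights, layer 0 -> 1
W1 = rng.normal(size=(5, 4))           # weights, layer 1 -> 2
W2 = rng.normal(size=(4, 3))           # weights, layer 2 -> 3

X1 = softmax(X0 @ W0)                  # n x 5
X2 = softmax(X1 @ W1)                  # n x 4
X3 = softmax(X2 @ W2)                  # n x 3, class probabilities
```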

Basically, the goal of this article is to demonstrate the following differential expression: (★)
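While working through a derivation like this, a finite-difference check is a standard way to validate a candidate gradient formula numerically. Below is a minimal sketch for the last layer only, using the well-known softmax-plus-cross-entropy identity ∂𝜉/∂W₂ = X₂ᵀ(X₃ − Y). This is a standard result stated here as an assumption about where the derivation is headed, not the (★) expression itself; the snippet reuses the arrays from the previous sketch:

```python
Y = np.eye(3)[rng.integers(0, 3, size=n)]   # random one-hot labels, n x 3

def loss_W2(W2):
    # Loss as a function of the last weight matrix only,
    # with X2 and Y held fixed.
    return -np.sum(Y * np.log(softmax(X2 @ W2)))

analytic = X2.T @ (softmax(X2 @ W2) - Y)    # candidate closed-form gradient

# Central finite differences, one weight at a time.
eps = 1e-6
numeric = np.zeros_like(W2)
for i in range(W2.shape[0]):
    for j in range(W2.shape[1]):
        d = np.zeros_like(W2)
        d[i, j] = eps
        numeric[i, j] = (loss_W2(W2 + d) - loss_W2(W2 - d)) / (2 * eps)

print(np.max(np.abs(analytic - numeric)))   # tiny (~1e-9) if the formula holds
```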