7. Weight Initialization

If you have no clue how to properly initialize your model weights (just like I had no idea when I got started), here is a simple rule of thumb: initialize all your biases with zeros (tf.zeros(shape)) and your weights (kernels of convolutions and weights of fully connected layers) with non-zero values drawn from some kind of normal distribution. For example, you could simply use tf.randomNormal(shape), but nowadays I prefer a Glorot normal distribution, which is available in tfjs-layers as follows:
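A minimal sketch of what that can look like, assuming the union @tensorflow/tfjs package (which exposes the layers initializers under tf.initializers); the shapes are just example values:

```js
import * as tf from '@tensorflow/tfjs'

// example shapes for a 3x3 conv layer with 32 input and 64 output channels
const kernelShape = [3, 3, 32, 64]
const numFilters = 64

// biases: all zeros (wrapped in tf.variable so they are trainable)
const bias = tf.variable(tf.zeros([numFilters]))

// weights: drawn from a Glorot (Xavier) normal distribution
const kernel = tf.variable(tf.initializers.glorotNormal({}).apply(kernelShape))
```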

8. Shuffle your Inputs!

A common piece of advice for training a neural network is to randomize the order in which your training samples occur by shuffling them at the beginning of each epoch. Conveniently, we can use tf.util.shuffle for that purpose, which shuffles an arbitrary array in place:
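For example (the trainSamples array here is just a stand-in for however you keep your training data references):

```js
import * as tf from '@tensorflow/tfjs'

// e.g. an array of references to your training samples
const trainSamples = [
  { imgUrl: 'img0.png', label: 0 },
  { imgUrl: 'img1.png', label: 1 },
  { imgUrl: 'img2.png', label: 2 }
]

// shuffles the array in place at the beginning of each epoch
tf.util.shuffle(trainSamples)
```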

9. Saving Model Checkpoints using FileSaver.js

Since we are training our model in the browser, you may now ask yourself: how do we automatically save checkpoints of our model weights while training? We can simply use FileSaver.js. The script exposes a function called saveAs, which we can use to store arbitrary types of files; they will end up in our downloads folder.

This way we can save our model weights:
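Here is a minimal sketch, assuming you keep your weights in a map of named tensors (netParams is a hypothetical name for that map):

```js
import { saveAs } from 'file-saver'

// "netParams" is assumed to be a map of named weight tensors
async function saveWeights(netParams, filename = 'checkpoint_epoch1.weights') {
  const weightData = await Promise.all(
    Object.values(netParams).map(t => t.data())
  )

  // pack all weight values into a single Float32Array
  const totalLength = weightData.reduce((len, arr) => len + arr.length, 0)
  const flat = new Float32Array(totalLength)
  let offset = 0
  weightData.forEach(arr => {
    flat.set(arr, offset)
    offset += arr.length
  })

  saveAs(new Blob([flat.buffer]), filename)
}
```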

Or even JSON files, for example to save the accumulated losses for an epoch:
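For instance (the losses array and the file name are just example values):

```js
function saveLossesAsJson(losses, epoch) {
  const blob = new Blob([JSON.stringify(losses)], { type: 'application/json' })
  saveAs(blob, `losses_epoch${epoch}.json`)
}
```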

Troubleshooting

Before spending a lot of time on training your model, you want to make sure that your model is actually learning what it is supposed to learn, and eliminate any potential source of errors and bugs. If you do not consider the following tips, you might end up wasting your time training complete garbage and wondering why your model does not seem to learn anything.

10. Check your Input Data, Pre- and Post Processing Logic!

If you pass garbage into your network, it will throw garbage right back at you. Thus, make sure your input data is labeled correctly and that your network inputs are what you expect them to be. Especially if you have implemented some preprocessing logic like random cropping, padding, squaring, centering, mean subtraction or whatever else, make sure to visualize your inputs after preprocessing. I would also highly recommend unit testing these steps. The same goes for postprocessing, of course!
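For a quick sanity check you can draw a preprocessed input tensor back onto a canvas, for example like this (assuming a 3-channel image tensor with values in [0, 255]; in newer tfjs versions tf.toPixels lives under tf.browser.toPixels):

```js
// debug helper: render a preprocessed input tensor onto a canvas element
async function showInput(inputTensor, canvas) {
  // toPixels expects int32 values in [0, 255] (or float values in [0, 1])
  await tf.toPixels(inputTensor.toInt(), canvas)
}
```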

I know this sounds like a tedious amount of extra work, but it is worth it for sure! You would not believe how many hours I spent trying to figure out why the heck my object detector did not learn to detect faces at all, until I eventually discovered that my preprocessing logic was turning the inputs into trash due to incorrect cropping and distortion.

11. Check your Loss Function!

In most cases tensorflow.js luckily provides you with the loss function you need. However, in case you have to implement your own loss function, you should definitely unit test it! A while ago, I implemented the Yolo v2 loss function from scratch using the tfjs-core API to train Yolo object detectors for the web. Let me tell you that this can get very hairy unless you break the problem down and make sure the individual components compute what they are supposed to.
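Such a unit test can be as simple as comparing the output of a single loss component against a value you computed by hand. A sketch (computeIouLoss is a hypothetical component of such a custom loss):

```js
// identical predicted and ground truth boxes -> IOU of 1.0 -> loss term of 0
const predBoxes = tf.tensor2d([[0, 0, 10, 10]])
const gtBoxes = tf.tensor2d([[0, 0, 10, 10]])

const iouLoss = computeIouLoss(predBoxes, gtBoxes)

console.assert(Math.abs(iouLoss.dataSync()[0]) < 1e-6)
```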

12. Overfit on a small Dataset first!

Generally it's a good idea to overfit on a small subset of your training data, to verify that the loss is converging and that your model is actually learning something useful. To do so, simply pick 10 to 20 images from your training data and train for some epochs. Once the loss converges, run inference on these 10 to 20 images and visualize the results:
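A rough sketch of what that can look like (net.predict and drawPredictions are hypothetical placeholders for your own forward pass and visualization code):

```js
// run inference on the small training subset once the loss has converged
async function debugOverfit(trainSubset) {
  for (const { img, label } of trainSubset) {
    const prediction = tf.tidy(() => net.predict(tf.fromPixels(img)))
    drawPredictions(img, await prediction.data(), label)
    prediction.dispose()
  }
}
```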

This is a very important step, which will help you to eliminate all kinds of bugs in the implementation of your network and in your pre- and postprocessing logic, since it is unlikely that your model will learn to make the desired predictions with substantial bugs in your code.

Especially if you are implementing your own loss function (tip 11), you definitely want to make sure your model is able to converge before jumping into a full training run!

Performance

Finally, I want to give you some advice that will help you to reduce training time as much as possible and prevent your browser from crashing due to memory leaks, by considering some basic principles.

13. Preventing obvious Memory Leaks

Unless you are completely new to tensorflow.js, you probably already know that we have to dispose unused tensors manually to free up their memory, either by calling tensor.dispose() or by wrapping our operations in tf.tidy blocks. Make sure there are no such memory leaks due to tensors not being disposed correctly, otherwise your application will sooner or later run out of memory.

Identifying these kinds of memory leaks is pretty easy. Simply log tf.memory() for a few iterations to verify that the number of tensors does not inadvertently grow with each iteration:
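For example, somewhere in your training loop (trainStep stands in for whatever one iteration of your training does):

```js
for (let i = 0; i < numIterations; i++) {
  trainStep(i)

  // numTensors should stay constant across iterations; if it keeps growing,
  // some tensors are not being disposed
  console.log(tf.memory())
}
```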

14. Resize your Canvases and not your Tensors!

Note that the following statements are only valid for the current state of tfjs-core (I am currently using tfjs-core version 0.12.14), until this eventually gets fixed.

I know this might sound a bit strange: why not use tf.resizeBilinear, tf.pad and such to reshape your input tensors to the desired network input shape? There is currently an open issue at tfjs illustrating the problem.

TLDR: Before calling tf.fromPixels to convert your canvases to tensors, resize your canvases such that they have the size accepted by your network; otherwise you will quickly run out of GPU memory, depending on the variety of input sizes among the images in your training data. This is less of a problem if your training images are all equally sized anyway, but in case you have to resize them explicitly, you can use the following code snippet:
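A minimal sketch of such a helper (imageToInputTensor and inputSize are just example names; note that tf.fromPixels has since moved to tf.browser.fromPixels in newer tfjs versions):

```js
// draw the source image onto a canvas that already has the network input
// size, then convert that canvas to a tensor
function imageToInputTensor(img, inputSize) {
  const canvas = document.createElement('canvas')
  canvas.width = inputSize
  canvas.height = inputSize
  canvas.getContext('2d').drawImage(img, 0, 0, inputSize, inputSize)
  return tf.fromPixels(canvas)
}
```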

15. Figuring out the optimal Batch Size

Don't go overboard with batching your inputs! Try out different batch sizes and measure the time required for backpropagation. The optimal batch size obviously depends on your GPU, the input size and the complexity of your network. In some cases you don't want to batch your inputs at all.

If in doubt, however, I would always go with a batch size of 1. Personally, I found that in some cases increasing the batch size doesn't really help performance, but in other cases I could see an overall speedup by a factor of roughly 1.5 to 2.0 by creating batches of 16 to 24 samples, for an input image size of 112 x 112 pixels and a fairly small network.
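A simple way to compare batch sizes is to time a single optimizer step for each candidate (optimizer and computeLossForBatch are hypothetical placeholders for your own training code):

```js
async function timeOneStep(batchSize) {
  const start = Date.now()
  const loss = optimizer.minimize(() => computeLossForBatch(batchSize), true)

  // awaiting the loss value forces the GPU work to finish before stopping the timer
  await loss.data()
  loss.dispose()

  console.log(`batch size ${batchSize}: ${Date.now() - start} ms`)
}
```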

16. Caching, Offline Storage, IndexedDB

Our training images (and labels) can be of considerable size, maybe up to 1GB or even more, depending on the size as well as the number of your images. Since we cannot simply read images from disk in the browser, we would instead use a file proxy, which might be a simple express server, to host our training data, and the browser would fetch every single data item from it.

Obviously, this is very inefficient, but it is something we have to keep in mind when training in the browser. If your dataset is small enough, you could probably try to keep your entire data in memory, but that's obviously not very efficient either. Initially, I tried to increase the browser cache size to simply cache the entire data on disk, but that seems to not work anymore in later versions of Chrome, and I had no luck with Firefox either.

Finally, I decided to just go for IndexedDB, an in-browser database in case you are not familiar with it, which we can use to store our entire training and test data sets. Getting started with IndexedDB is quite simple, as we can basically store and query our entire data as key-value pairs with only a few lines of code. With IndexedDB we can conveniently store our labels as plain JSON objects and our image data as blobs. Check out this blog post, which nicely explains how to persist image data and other files in IndexedDB.

Querying IndexedDB is quite fast; at least I found querying each data item to be way faster than fetching the files from the proxy over and over again. Plus, after moving your data into IndexedDB, training technically works completely offline, meaning we might not need the proxy server anymore.
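A minimal sketch of storing and querying data items in IndexedDB, using a single object store (the database and store names are just examples):

```js
function openDb() {
  return new Promise((resolve, reject) => {
    const req = indexedDB.open('trainingData', 1)
    req.onupgradeneeded = () => req.result.createObjectStore('items')
    req.onsuccess = () => resolve(req.result)
    req.onerror = () => reject(req.error)
  })
}

function putItem(db, key, value) {
  return new Promise((resolve, reject) => {
    const tx = db.transaction('items', 'readwrite')
    tx.objectStore('items').put(value, key)
    tx.oncomplete = () => resolve()
    tx.onerror = () => reject(tx.error)
  })
}

function getItem(db, key) {
  return new Promise((resolve, reject) => {
    const req = db.transaction('items').objectStore('items').get(key)
    req.onsuccess = () => resolve(req.result)
    req.onerror = () => reject(req.error)
  })
}

// usage: store an image blob and its label, then read them back later
// const db = await openDb()
// await putItem(db, 'img_0', imageBlob)
// await putItem(db, 'label_0', { label: 'face', box: [0, 0, 100, 100] })
// const blob = await getItem(db, 'img_0')
```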

17. Async Loss Reporting

This is a simple yet pretty effective tip, which helped me reduce iteration times a lot while training. The idea is the following: we certainly want to retrieve the values of the loss tensors returned by optimizer.minimize, because we want to keep track of our loss while training, but we want to avoid awaiting the Promise returned by loss.data(), which would force CPU and GPU to synchronize at each iteration. Instead, we can do something like the following to report the loss value of an iteration:
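A sketch of the idea (optimizer, computeLoss, batch and the losses array are placeholders for your own training code):

```js
// minimize returns the loss tensor when the second argument (returnCost) is true
const loss = optimizer.minimize(() => computeLoss(batch), true)

// do NOT await here: let the GPU -> CPU download resolve in the background
loss.data().then(data => {
  losses.push(data[0])
  loss.dispose()
})
```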

We simply have to keep in mind that our losses are now reported asynchronously, so in case we want to save the overall loss at the end of each epoch to a file, we have to wait for the last promises to resolve before doing so. I usually just hack around this issue by using a setTimeout to save the overall loss value 10 seconds or so after an epoch has finished:
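For example (losses is assumed to be the array collected by the asynchronous callbacks above):

```js
// hacky but simple: give the pending loss.data() promises some time to
// resolve before persisting the accumulated losses of this epoch
function onEpochEnd(epoch) {
  setTimeout(() => {
    const blob = new Blob([JSON.stringify(losses)], { type: 'application/json' })
    saveAs(blob, `losses_epoch${epoch}.json`)
    losses.length = 0 // reset for the next epoch
  }, 10000)
}
```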

After Successfully Training a Model

18. Weight Quantization

Once we are done training our model and we are satisfied with its performance, I would recommend shrinking the model size by applying weight quantization. By quantizing our model weights, we can reduce the size of our model to a quarter of the original size! Reducing the size of our model as much as possible is critical for fast delivery of our model weights to the client application, especially since we can get it basically for free.

Thus, make sure to check out my guide about weight quantization with tensorflow.js: Shrink your Tensorflow.js Web Model Size with Weight Quantization.