Model training

Let’s train the model! I will be training this model on my laptop, which does not have enough RAM to take the entire dataset into memory. With image data, this is very often the case. Keras provides the model.fit_generator() method that can use a custom Python generator yielding images from disc for training. However, as of Keras 2.0.6, we can use the Sequence object instead of a generator which allows for safe multiprocessing which means significant speedups and less risk of bottlenecking your GPU if you have one. The Keras documentation already provides good example code, which I will customize a bit to:

make it work with a dataframe that maps image names to labels

shuffle the training data after every epoch
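A minimal sketch of such a Sequence subclass could look as follows. The class name, the dataframe column names (`image_name`, `weather_labels`, `ground_labels`), the image directory, and the target size are all illustrative assumptions, not the article's exact code:

```python
import math
import numpy as np
from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import load_img, img_to_array

class ImageSequence(Sequence):
    """Yields batches of (images, labels) from a dataframe that maps
    image names to labels. Column names here are illustrative."""

    def __init__(self, df, img_dir, batch_size=32, mode='train'):
        self.df = df
        self.img_dir = img_dir
        self.batch_size = batch_size
        self.mode = mode
        self.indices = np.arange(len(df))

    def __len__(self):
        # Number of batches per epoch; ceil so the last partial batch is used.
        return math.ceil(len(self.df) / self.batch_size)

    def __getitem__(self, idx):
        batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch = self.df.iloc[batch_idx]
        images = np.stack([
            img_to_array(load_img(f"{self.img_dir}/{name}",
                                  target_size=(64, 64))) / 255.0
            for name in batch['image_name']
        ])
        return images, {'weather': np.stack(batch['weather_labels'].values),
                        'ground': np.stack(batch['ground_labels'].values)}

    def on_epoch_end(self):
        # Reshuffle between epochs so batches differ; skip in test mode.
        if self.mode == 'train':
            np.random.shuffle(self.indices)
```

Because `__getitem__` only indexes into the shuffled index array, the same class doubles as a deterministic test-time iterator when shuffling is disabled.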

This Sequence object can be used instead of a custom generator together with fit_generator() to train the model. Note that there is no need to provide the number of steps per epoch, since the __len__ method implements that logic for the generator.

Furthermore, tf.keras provides access to all the available Keras callbacks that can be used to enhance the training loop. These can be quite powerful and provide options for early stopping, learning rate scheduling, storing files for TensorBoard, and more. Here, we will use the ModelCheckpoint callback to save the model after every epoch so that we can pick up training afterwards if we want. By default, the model architecture, training configuration, state of the optimizer and the weights are stored, such that the entire model can be recreated from a single file.

Let’s train the model for a single epoch:

Epoch 1/1

Epoch 00001: saving model to ./model.h5

1265/1265 [==============================] - 941s 744ms/step - loss: 0.8686 - weather_loss: 0.6571 - ground_loss: 0.2115

Suppose we want to fine-tune the model at a later stage; we can simply read the model file and pick up training without recompiling:

Finally, it’s good to verify that our Sequence effectively passes over all the data by instantiating a Sequence in test mode (that is, without shuffling) and using it to make predictions for the entire dataset:

Wait, what about the Dataset API?

The tf.data API is a powerful library that allows us to consume data from various sources and pass it to TensorFlow models. Can we train our tf.keras model with the tf.data API instead of the Sequence object? Yes. First of all, let’s serialize the images and labels together into a TFRecord file, which is the recommended format for serializing data in TensorFlow:

After dumping the images and the labels into a TFRecord file, we can come up with another generator using the tf.data API. The idea is to instantiate a TFRecordDataset from our file and tell it how to parse the serialized data using the map() operation.

Dataset objects provide multiple methods to produce iterator objects to loop over the data. However, as of TensorFlow 1.9, we can simply pass our ds_train directly to model.fit() to train the model:

Epoch 1/1

100/100 [==============================] - 76s 755ms/step - loss: 0.5460 - weather_loss: 0.3780 - ground_loss: 0.1680