Ever since TensorFlow released bindings for Go, I’ve been itching to give them a go. Go’s easy deployment, its friendliness to microservices and its solid HTTP performance make it really handy for building a working prediction application.

The immediate and obvious downside, for anyone who’s tried to train a model, is how unintuitive scoping is with TensorFlow for Go. Python is a lot easier for a newcomer to train models with, for a number of reasons:

+ The excellent Keras library and API.
+ Most tutorials and lessons are done in Python.
+ NumPy is wonderfully intuitive.
+ Fantastic image processing support.
+ Generators are fun.

However, once the model is trained, running inference (a single forward pass) does not need much TensorFlow scoping wizardry.

Go is an obvious choice for deployment/inference because:

+ Fantastic for writing microservices.
+ Nice for people who prefer composition over inheritance when building apps.
+ Benchmarks show Go is much faster than Python. This matters little during training, where the heavy computation runs on the GPU, but it counts when serving predictions.
+ In HTTP benchmarks, Go with Chi seems to beat Python with Flask, and even Python with Twisted, by a considerable margin.
+ Easy concurrency primitives give you more ways to put the model to work.

If you want to skip the ramble and get your hands dirty, you can use telemus as a Quick Start.

Training a gender classifier in Python

For my proof of concept, I decided to build a simple real-world classifier that takes a picture and classifies whether the person in it is male or female. To this end, we will fine-tune a ResNet50 model, replacing its final layer with a softmax classifier.

Data

For the data, I will be using the wonderful CelebA dataset, which labels roughly 200,000 celebrity images with 40 binary attributes (Male being one of them).

Since I plan to source more data of my own in the future, I split the images into folders by class. If you intend to do the same, the rudimentary script below helps.

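import csv
import os
import shutil

src_path = '/path/to/celebA/data'
dst_path = '/path/to/train/directory'

def prepare_celebA():
    # shutil.copyfile does not create missing folders, so make them first.
    for label in ('male', 'female'):
        os.makedirs(os.path.join(dst_path, label), exist_ok=True)
    # Assumes a space-delimited attribute file whose header row has an
    # 'image' column; skipinitialspace tolerates runs of spaces.
    with open('list_attr_celeba.txt') as f:
        reader = csv.DictReader(f, delimiter=' ', skipinitialspace=True)
        for row in reader:
            img_name = row['image']
            src = os.path.join(src_path, img_name)
            # The Male attribute is 1 for male and -1 for female.
            if row['Male'] == '1':
                dst = os.path.join(dst_path, 'male', img_name)
            elif row['Male'] == '-1':
                dst = os.path.join(dst_path, 'female', img_name)
            else:
                continue
            shutil.copyfile(src, dst)

if __name__ == '__main__':
    prepare_celebA()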
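The script above expects an attribute file with an 'image' header column. As far as I know, the list_attr_celeba.txt distributed with CelebA has a different layout: the image count on the first line, the 40 attribute names on the second, then one row per image with the file name followed by 1/-1 values separated by runs of spaces. If your copy looks like that, a minimal sketch like the one below (my own assumption, not the author's code) reads it with str.split instead of csv. The names parse_celeba_attrs and class_folder are hypothetical.

```python
import io
from typing import Dict, Iterable, Iterator, Tuple


def parse_celeba_attrs(lines: Iterable[str]) -> Iterator[Tuple[str, Dict[str, int]]]:
    """Yield (image_name, {attribute: value}) pairs from raw CelebA attribute lines.

    Assumes the raw layout: line 1 is the image count, line 2 holds the
    attribute names, and every following line is an image name followed by
    one 1/-1 value per attribute, separated by one or more spaces.
    """
    it = iter(lines)
    next(it)  # skip the image count
    names = next(it).split()  # split() also absorbs the trailing space
    for line in it:
        fields = line.split()
        if not fields:
            continue  # ignore blank lines
        img_name, values = fields[0], fields[1:]
        if len(values) != len(names):
            raise ValueError(f"{img_name}: expected {len(names)} values, got {len(values)}")
        yield img_name, dict(zip(names, (int(v) for v in values)))


def class_folder(attrs: Dict[str, int]) -> str:
    """Map the Male attribute (1 or -1) to the class folder name."""
    return "male" if attrs["Male"] == 1 else "female"


if __name__ == "__main__":
    # A tiny stand-in for list_attr_celeba.txt with only two attributes.
    sample = io.StringIO(
        "2\n"
        "Male Smiling \n"
        "000001.jpg -1  1\n"
        "000002.jpg  1 -1\n"
    )
    for name, attrs in parse_celeba_attrs(sample):
        print(name, class_folder(attrs))  # 000001.jpg female, then 000002.jpg male
```

Inside prepare_celebA you would open the real file, loop over parse_celeba_attrs(f) and copy each image to os.path.join(dst_path, class_folder(attrs), name). Splitting on whitespace sidesteps the uneven spacing that trips up a csv reader with a single-space delimiter.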

The script above can be repurposed slightly for validation data as well.
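One way to repurpose it, assuming you use CelebA's own list_eval_partition.txt (each line an image name followed by 0 for train, 1 for validation or 2 for test), is to look up each image's partition before choosing the destination root. This is a sketch under that assumption; load_partitions, destination and PARTITION_DIRS are hypothetical names, and the paths are placeholders.

```python
import io
import os
from typing import Dict, Iterable, Optional

# Hypothetical destination roots; the test partition (2) is skipped here.
PARTITION_DIRS = {0: "/path/to/train/directory", 1: "/path/to/valid/directory"}


def load_partitions(lines: Iterable[str]) -> Dict[str, int]:
    """Map image name -> partition id from list_eval_partition.txt lines."""
    partitions = {}
    for line in lines:
        fields = line.split()
        if len(fields) == 2:
            partitions[fields[0]] = int(fields[1])
    return partitions


def destination(img_name: str, male: int, partitions: Dict[str, int]) -> Optional[str]:
    """Return where an image should be copied, or None to skip it."""
    root = PARTITION_DIRS.get(partitions.get(img_name, -1))
    if root is None:
        return None  # test images, or images missing from the partition file
    label = "male" if male == 1 else "female"
    return os.path.join(root, label, img_name)


if __name__ == "__main__":
    parts = load_partitions(io.StringIO("000001.jpg 0\n162771.jpg 1\n182638.jpg 2\n"))
    print(destination("000001.jpg", -1, parts))  # train/female
    print(destination("162771.jpg", 1, parts))   # valid/male
    print(destination("182638.jpg", 1, parts))   # None (test partition)
```

In prepare_celebA you would build the partition map once, compute dst with destination(img_name, int(row['Male']), partitions) and skip the copy whenever it returns None.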

Pre-requisites to be able to deploy in Go

While I like Keras and use it to run my models, the Go application imposes a few requirements:

+ The model to be transferred has to be a TensorFlow graph.
+ The graph has to be saved as a SavedModel with a tag.
+ The input layer and the inference layer have to be named, because the Go program looks them up by name.

Training

With the prerequisites above in mind, the first thing our code does is create a TensorFlow session and register it with the Keras backend, so that the graph Keras builds lives in a session we can save later.

import tensorflow as tf
from keras import backend as K

# Make Keras build its graph in this session so we can export it later.
sess = tf.Session()
K.set_session(sess)

I use two generators to read my prepared celeb data. The load_celeb_data code is listed below.

The images I use are 224 x 224, and since I only have an Nvidia GTX 1080, I keep train_batch_size to as much as my GPU can handle.

from keras.preprocessing.image import ImageDataGenerator

train_directory = "/path/to/train/directory"
valid_directory = "/path/to/valid/directory"
# Batch sizes must be integers, not strings.
train_batch_size = 50
valid_batch_size = 50

def load_celeb_data(img_rows=224, img_cols=224):
    # Augment the training images on the fly.
    train_batches = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    ).flow_from_directory(
        train_directory,
        target_size=(img_rows, img_cols),
        shuffle=True,
        batch_size=train_batch_size,
        classes=('male', 'female'))
    # Validation images are used as-is, without augmentation.
    valid_batches = ImageDataGenerator().flow_from_directory(
        valid_directory,
        target_size=(img_rows, img_cols),
        shuffle=True,
        batch_size=valid_batch_size,
        classes=('male', 'female'))
    return train_batches, valid_batches
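Passing classes=('male', 'female') explicitly matters later on. As I understand the Keras documentation, flow_from_directory numbers the classes in the order of that list, whereas without it the subfolder names are sorted alphanumerically, which would put female at index 0. The stdlib sketch below mimics that rule (class_indices and decode are hypothetical helpers, not Keras API) and shows how a softmax row maps back to a label, the same idea as the comparison the Go program makes.

```python
from typing import Dict, List, Optional, Sequence


def class_indices(subdirs: Sequence[str], classes: Optional[Sequence[str]] = None) -> Dict[str, int]:
    """Mimic how flow_from_directory numbers classes.

    With an explicit `classes` list the given order is used; otherwise the
    subdirectory names are sorted alphanumerically.
    """
    order = list(classes) if classes else sorted(subdirs)
    return {name: i for i, name in enumerate(order)}


def decode(probs: Sequence[float], indices: Dict[str, int]) -> str:
    """Return the label whose softmax probability is highest."""
    labels: List[str] = [""] * len(indices)
    for name, i in indices.items():
        labels[i] = name
    best = max(range(len(probs)), key=lambda i: probs[i])
    return labels[best]


if __name__ == "__main__":
    subdirs = ["male", "female"]
    explicit = class_indices(subdirs, classes=("male", "female"))
    inferred = class_indices(subdirs)
    print(explicit)                       # {'male': 0, 'female': 1}
    print(inferred)                       # {'female': 0, 'male': 1}
    print(decode([0.8, 0.2], explicit))   # male
    print(decode([0.8, 0.2], inferred))   # female
```

The same softmax output reads as opposite labels under the two orderings, so whatever order you train with is the order the prediction code has to assume.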

Finally, fine-tune the model.

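import keras
from keras.layers import Dense
from keras.models import Model
from keras.optimizers import Adam

train_gen, valid_gen = load_celeb_data()

model = keras.applications.resnet50.ResNet50()
classes = train_gen.class_indices

# Freeze the base network, then unfreeze its first 30 layers.
for layer in model.layers:
    layer.trainable = False
for layer in model.layers[:30]:
    layer.trainable = True

# Take the output just below ResNet50's 1000-class layer. (Popping from
# model.layers does not rewire the graph in some Keras versions.)
last = model.layers[-2].output
# New softmax classifier, named so the Go code can find it.
x = Dense(len(classes), activation="softmax", name="inferenceLayer")(last)
finetuned_model = Model(model.input, x)
finetuned_model.compile(optimizer=Adam(lr=1e-04),
                        loss='categorical_crossentropy',
                        metrics=['accuracy'])
# Then fit finetuned_model on train_gen / valid_gen before saving.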

The most relevant line for us is this one:

x = Dense(len(classes), activation="softmax", name="inferenceLayer")(last)

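I’ve gone ahead and named the inference layer (no points for originality there). Keras places the layer's softmax op under that name, as inferenceLayer/Softmax, which is what the Go code fetches.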

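The ResNet50 model's input layer is already named: it's called input_1. Keras generates this name automatically, so it can come out as input_2 and so on if other models were built in the same session first.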

Once the model has been fit, save the graph and its variables with a tag. The Go program has to load the model with the same tag.

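# Export the session's graph and trained variables as a SavedModel tagged
# "tags". The builder refuses to overwrite an existing forGo directory.
builder = tf.saved_model.builder.SavedModelBuilder("forGo")
builder.add_meta_graph_and_variables(sess, ["tags"])
builder.save()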

This saves the model to a folder named forGo.

You can find the full code here.

Prediction

Time to move on to prediction. We will build a simple command-line app that takes the path of an image as an argument and prints the predicted class. Copy the forGo directory somewhere the app can reach it; I like to put it inside the Go project directory because it's not too large.

The Go program is laughably simple once everything else is set up.

All we have to do is read the image, convert it to a *tf.Tensor, feed it to the input_1 operation and fetch the output of inferenceLayer/Softmax.

Here’s the gist.

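// Load the SavedModel exported from Python, using the same tag.
model, err := tf.LoadSavedModel("forGo", []string{"tags"}, nil)
if err != nil {
    log.Fatal(err)
}
defer model.Session.Close()

imageFile, err := os.Open(imgName)
if err != nil {
    log.Fatal(err)
}
defer imageFile.Close()

var imgBuffer bytes.Buffer
if _, err := io.Copy(&imgBuffer, imageFile); err != nil {
    log.Fatal(err)
}

// readImage (in the full code) turns the JPEG bytes into a *tf.Tensor.
img, err := readImage(&imgBuffer, "jpg")
if err != nil {
    log.Fatal("error reading image: ", err)
}

// Feed the tensor to input_1 and fetch the softmax output.
result, err := model.Session.Run(
    map[tf.Output]*tf.Tensor{
        model.Graph.Operation("input_1").Output(0): img,
    },
    []tf.Output{
        model.Graph.Operation("inferenceLayer/Softmax").Output(0),
    },
    nil,
)
if err != nil {
    log.Fatal(err)
}

// Index 0 is male and index 1 is female, matching the classes order
// passed to flow_from_directory during training.
if preds, ok := result[0].Value().([][]float32); ok {
    fmt.Println(preds)
    if preds[0][0] > preds[0][1] {
        fmt.Println("male")
    } else {
        fmt.Println("female")
    }
}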

You can find the full code here.