Cross-Platform On-Device ML Inference

TensorFlow Lite ft. Flutter

With the release of Flutter 1.9 and Dart 2.5, there are many exciting things to try out, but what really caught my attention is Dart’s ML Complete. It uses tflite_native, which in turn uses the TensorFlow Lite C API via Dart FFI.

It immediately sparked a crazy idea in my mind: a single codebase for an app on multiple platforms (iOS, Android, Mac, Windows, Linux, even Web) that can do low-latency, on-device machine learning inference.

Previously, it was not truly a single codebase, even with Flutter, as there had to be some code on the platform side wired up through platform channels. With FFI, we can now say it’s truly a single codebase.

After several trial-and-error iterations, I finally got a prototype of TensorFlow Object Detection working with tflite_native. The source code is MIT-licensed and can be found on my Github. Note: if you are running this example on iOS, please use Flutter 1.10.x from the dev channel (flutter channel dev; flutter upgrade), as DynamicLibrary.process() was removed in 1.9.1 stable and re-added later (ref).

It’s been a long and eventful journey, so I will split it up into a series of articles, each of which works on its own without additional context and might have a different audience:

- This article: using tflite_native to load the model and labels, create the interpreter, pre-process input to feed to the model, and interpret output data from the model.
- An upcoming article about making tflite_native, the Flutter plugin, without any platform channel.
- An upcoming article about the rest of the Dart/Flutter infrastructure of the app outside the scope of TensorFlow Lite, such as Isolate and how FFI can be used to do unsafe parallel threading in Dart.

1. Configuring dependencies and assets

To load the tflite model, we first need to have that model in our assets. Download the zip file, extract it into the assets folder of the Flutter project, then declare the assets in pubspec.yaml.

There are 2 assets, model.tflite and labels.txt, both of which have been renamed. The first file is the pre-trained ML model we use to do the prediction, but its outputs are only numbers, among which are indices into the labels stored in the second file.

For this prototype, we will also need path_provider, camera, image, and of course tflite_native. path_provider is needed to load the model from the assets, camera is used to get the real-time (raw) image stream, and image is used to process the raw image into a usable RGB format that can be fed to our model. For tflite_native, for now we’ll use my forked version.
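For reference, the relevant pubspec.yaml entries look roughly like this. The version numbers are illustrative (use whatever is current), and the git url is a placeholder for the fork, so substitute your own:

```yaml
dependencies:
  flutter:
    sdk: flutter
  path_provider: ^1.4.0
  camera: ^0.5.4
  image: ^2.1.4
  tflite_native:
    git:
      url: https://github.com/<your-fork>/tflite_native.git  # placeholder

flutter:
  assets:
    - assets/model.tflite
    - assets/labels.txt
```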

We also need to configure ios/Runner/Info.plist to request camera access permission.
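The key for camera permission is NSCameraUsageDescription; a minimal entry looks like this (the description string is up to you):

```xml
<key>NSCameraUsageDescription</key>
<string>Camera access is needed for real-time object detection.</string>
```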

2. Load tflite model and labels

Right now, it’s not straightforward to load a tflite model directly from assets. The current C API can load a model either from bytes or from an accessible file. Loading the model from a file is recommended because, deep down, the TfLite C API uses a memory-mapped file (mmap), which is vital if you need to load a large model on a memory-constrained device. Unfortunately, Flutter has no easy way to expose bundle assets to native code as files, so we have to work around it: read the model from assets, write it to a temporary file, and let the TfLite C API read from that.

Afterward, we create an InterpreterOptions, set the number of threads we want, load the model, and allocate tensors before returning an interpreter ready to be used. Ideally, with InterpreterOptions, we would also have the option to use the GPU or NNAPI, which significantly reduces inference time, but at the moment those APIs are not yet available in the TfLite C API.
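Here is a minimal sketch of that workflow. The Interpreter and InterpreterOptions names follow tflite_native’s thin wrapper around the C API; double-check the exact signatures against the package source:

```dart
import 'dart:io';

import 'package:flutter/services.dart' show rootBundle;
import 'package:path_provider/path_provider.dart';
import 'package:tflite_native/tflite.dart' as tfl;

/// Copies the bundled model to a real file so the C API can mmap it,
/// then builds an interpreter ready to be used.
Future<tfl.Interpreter> createInterpreter() async {
  // Read the model out of the asset bundle...
  final rawModel = await rootBundle.load('assets/model.tflite');
  // ...and write it to a temporary file the C API can open.
  final tmpDir = await getTemporaryDirectory();
  final modelFile = File('${tmpDir.path}/model.tflite');
  await modelFile.writeAsBytes(rawModel.buffer
      .asUint8List(rawModel.offsetInBytes, rawModel.lengthInBytes));

  // Assumed API: an InterpreterOptions with a threads setter, and
  // Interpreter.fromFile taking a path, mirroring the C API.
  final options = tfl.InterpreterOptions()..threads = 4;
  final interpreter =
      tfl.Interpreter.fromFile(modelFile.path, options: options);
  interpreter.allocateTensors();
  return interpreter;
}
```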

Loading labels is simple (but not necessarily straightforward): just read the text from labels.txt, split it on newlines, and shift by one (I guess the first element in labels.txt is reserved for the list length).
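A sketch of that label loading, with the shift-by-one baked in:

```dart
import 'package:flutter/services.dart' show rootBundle;

/// Maps a class index coming out of the model to a human-readable label.
Future<Map<int, String>> loadLabels() async {
  final raw = await rootBundle.loadString('assets/labels.txt');
  final lines = raw.split('\n');
  // Shift by one: line 0 appears to be reserved, so class index i
  // maps to line i + 1 of the file.
  return {
    for (var i = 1; i < lines.length; i++)
      if (lines[i].trim().isNotEmpty) i - 1: lines[i].trim()
  };
}
```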

3. Pre-processing input

A quick recap: a model may have one or many input tensors (e.g. SSD MobileNet has 1 input tensor, but the transfer model of Artistic Style Transfer has 2 input tensors, 1 for style and 1 for content). Each input tensor has a shape, which is a list of int, e.g. [1, 480, 640, 4], and an expected input type, usually either float32 (non-quantized) or uint8 (quantized), though there are others. Knowing the shape and expected input type is important for preparing the right input for our tensors. In our case, the input tensor expects 300 width x 300 height x 3-channel color (RGB), i.e. shape [1, 300, 300, 3], and the input type is uint8.
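You can inspect this at runtime; assuming tflite_native’s Tensor exposes shape and type getters like the C API does:

```dart
// Assumed API: getInputTensors() returning tensors with shape/type getters.
final inputTensor = interpreter.getInputTensors().single;
print(inputTensor.shape); // expect [1, 300, 300, 3]
print(inputTensor.type);  // expect uint8
```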

Depending on the platform, we get different raw images from the stream. (The method of getting raw images from the camera stream will be covered in a separate article.) On Android, the camera returns raw images in YUV_420, and it takes a little bit of effort to convert them to RGB.
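A sketch of that conversion, assuming the three-plane CameraImage layout the camera plugin delivers and the image package’s 2.x API (the one current when this was written):

```dart
import 'package:camera/camera.dart';
import 'package:image/image.dart' as img;

/// Converts a YUV_420 CameraImage to an RGB image, pixel by pixel.
img.Image yuv420ToRgb(CameraImage cameraImage) {
  final width = cameraImage.width;
  final height = cameraImage.height;
  final yPlane = cameraImage.planes[0];
  final uPlane = cameraImage.planes[1];
  final vPlane = cameraImage.planes[2];
  final uvRowStride = uPlane.bytesPerRow;
  final uvPixelStride = uPlane.bytesPerPixel ?? 1;

  final rgb = img.Image(width, height); // image ^2.x constructor
  for (var y = 0; y < height; y++) {
    for (var x = 0; x < width; x++) {
      // U and V are subsampled 2x2, hence the integer division by 2.
      final uvIndex = uvPixelStride * (x ~/ 2) + uvRowStride * (y ~/ 2);
      final yp = yPlane.bytes[y * yPlane.bytesPerRow + x];
      final up = uPlane.bytes[uvIndex];
      final vp = vPlane.bytes[uvIndex];
      // Standard YUV -> RGB conversion, clamped to [0, 255].
      final r = (yp + 1.402 * (vp - 128)).round().clamp(0, 255);
      final g = (yp - 0.344136 * (up - 128) - 0.714136 * (vp - 128))
          .round()
          .clamp(0, 255);
      final b = (yp + 1.772 * (up - 128)).round().clamp(0, 255);
      rgb.setPixelRgba(x, y, r, g, b);
    }
  }
  return rgb;
}
```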

On iOS, the camera returns raw images in BGRA, and with the help of the image package, we can convert them to RGB with ease.
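Something along these lines (again the image ^2.x API):

```dart
import 'package:camera/camera.dart';
import 'package:image/image.dart' as img;

/// iOS delivers a single interleaved BGRA8888 plane, which the image
/// package can reinterpret directly.
img.Image bgraToRgb(CameraImage cameraImage) => img.Image.fromBytes(
      cameraImage.width,
      cameraImage.height,
      cameraImage.planes[0].bytes,
      format: img.Format.bgra,
    );
```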

Now we have the image in RGB format, but not necessarily with the right dimensions, so we need to resize and crop it into a 300x300 square (and on Android we also have to rotate it by 90°, a known issue). Finally, we get the raw bytes and the original image size (because, depending on whether we’re on Android or iOS, it could be 640x480 or 480x640 after the conditional rotation).
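copyRotate and copyResizeCropSquare from the image package (2.x API) do the heavy lifting here; a minimal sketch:

```dart
import 'dart:io' show Platform;
import 'dart:typed_data';
import 'package:image/image.dart' as img;

/// Rotates (Android only), square-crops to 300x300, and grabs raw RGB bytes.
Uint8List preprocess(img.Image rgbImage) {
  var image = Platform.isAndroid ? img.copyRotate(rgbImage, 90) : rgbImage;
  image = img.copyResizeCropSquare(image, 300);
  return image.getBytes(format: img.Format.rgb); // 300 * 300 * 3 bytes
}
```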

4. Feed input, invoke model, interpret output

The raw bytes from the previous step are now ready to be fed to our model, which has only 1 input tensor. We just need to get hold of it, set the data, and invoke the model.
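Assuming tflite_native’s Tensor exposes a data setter that copies bytes into the underlying buffer (as the C API’s TfLiteTensorCopyFromBuffer does), that looks like:

```dart
// Assumed API: Tensor.data setter copying bytes into the tensor buffer.
final inputTensor = interpreter.getInputTensors().single;
inputTensor.data = preprocess(rgbImage); // 300 * 300 * 3 uint8 bytes
interpreter.invoke();
```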

Unlike the input, which has only 1 tensor, this model’s output has 4 tensors:

- the locations of the detected rectangles
- the classification indices of the detected objects
- the probability scores
- the number of detected results

We parse the output (from raw bytes) into the corresponding approximate data type. By approximate, I mean that if the output is conceptually a 2D list like List<List<Float32>>, we still parse it as a 1D list List<Float32> and interpret it as a 2D list shortly after.
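A sketch of that parsing, reinterpreting each tensor’s raw bytes as a flat float32 list (assuming a Tensor.data getter returning the raw Uint8List):

```dart
final outputs = interpreter.getOutputTensors();
final locations = outputs[0].data.buffer.asFloat32List(); // 4 coords per box
final classes = outputs[1].data.buffer.asFloat32List();   // label indices
final scores = outputs[2].data.buffer.asFloat32List();    // probabilities
final numDetections = outputs[3].data.buffer.asFloat32List().first.toInt();
```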

Now, we have the detection results scattered among 3 different variables (plus a 4th holding the number of detected results): the 1st is an ordered list of rectangle locations, the 2nd an ordered list of classification indices corresponding to our labels map, and the 3rd a list of probability scores. We need to combine these 3 ordered lists into a single new list, whose elements each have the detected class name in English (e.g. “bicycle” or “laptop”), the probability score, and the rectangle location.

Generally, we just do a for loop and process each element attribute we are interested in. For the detected class name in English, we simply need to convert the raw value from double to int and look it up in the label map we loaded before (already shifted by one). For probability scores, the raw output is between 0.0 (0%) and 1.0 (100%), so we use it directly.

The rectangle location is a little more complicated. As mentioned, this tensor is a list of rectangle locations, each of which is itself a list of 4 coordinates (top, left, bottom, right), each between 0.0 (corresponding to 0 on our 300x300 input image) and 1.0 (corresponding to 300). So, if the other lists have 10 elements (10 detected class names and probability scores), this 2D list, when flattened to 1D, has 40 elements.

In other words, assuming i is the index of the element, the coordinates of detection i sit at the flattened indices 4i (top), 4i + 1 (left), 4i + 2 (bottom), and 4i + 3 (right).
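Putting it all together, a sketch of the combining loop (Detection is a hypothetical result class of my own; Rect comes from dart:ui):

```dart
import 'dart:typed_data';
import 'dart:ui' show Rect;

/// Hypothetical holder for one combined detection result.
class Detection {
  final String label;
  final double score;
  final Rect location; // in 300x300 input coordinates
  Detection(this.label, this.score, this.location);
}

List<Detection> combine(Map<int, String> labels, Float32List locations,
    Float32List classes, Float32List scores, int numDetections) {
  return [
    for (var i = 0; i < numDetections; i++)
      Detection(
        labels[classes[i].toInt()] ?? 'unknown',
        scores[i],
        Rect.fromLTRB(
          locations[4 * i + 1] * 300, // left
          locations[4 * i] * 300, // top
          locations[4 * i + 3] * 300, // right
          locations[4 * i + 2] * 300, // bottom
        ),
      ),
  ];
}
```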

And boom 💥, there you have it! Whenever an image comes in from the camera stream, it gets processed, and we get a list of detection results, each of which has the detected object name, a probability score (from 0.0 to 1.0), and a rectangle location. To overlay these locations, we would need to scale their coordinates to match the underlying picture, but that’s out of this article’s scope. Technically, this code works on all Flutter platforms (iOS, Android, Linux, Windows, Mac) except web (though it could work there too if there’s ever an integration between Flutter Web and TensorFlow.js). However, right now we are constrained by other plugins, such as camera or path_provider, rather than by tflite_native. It’s worth remembering that tflite_native itself is used to power Dart’s ML Complete on desktop.

Final thoughts

Machine learning is not only about gradient descent, binary classification, Accelerated Linear Algebra (XLA), or other intimidating jargon. In fact, machine learning work is quite broad, and here we only do the “production inferencing” part of the workflow.

To be even more specific, production inferencing can happen remotely on a powerful server and/or locally on users’ devices. With TensorFlow Lite, and tflite_native specifically, we’re doing the latter. Even though the accuracy of local detection results might not be as good as those from remote clusters of powerful servers, local inference has other advantages: low latency, no connectivity required, no internet bandwidth consumed, and no data (potentially personal and/or sensitive) leaving the user’s device. In fact, local and remote inferencing can be combined to get the best of both worlds. In this specific example, we could extend the app so that it locally detects a traffic sign with a reasonably high score and sends only that portion of the image to the server, which then recognizes which traffic sign it is (stop sign, speed limit, warning, etc.)

Working through this example helps open the door to the machine learning world. It is not too hard for a fresh software engineer, but not too easy either. Completing it, you will understand the way a pre-trained ML model communicates (input and output), and hopefully aspire to learn more about using the GPU (instead of the CPU) for inferencing, quantized vs. non-quantized models, or converting a pre-trained model from one format to another (pb vs. ckpt vs. tflite).

Finally, as Federated Learning (FL) gets more popular, let’s hope TensorFlow Federated (TFF) will support mobile (FL has already been used to train prediction models for mobile keyboards), or, even crazier, that TFF goes Flutter-first 😉

Once again, source code is MIT-licensed and can be found on my Github.