PyTorch classification model to .tflite with the Google AI Edge library

George Soloupis
May 26, 2024

Written by George Soloupis, ML and Android GDE.

In this blog post, we will explore how to convert a PyTorch model to .tflite using the Google AI Edge library (ai-edge-torch). Traditionally, this conversion went through ONNX: the PyTorch model was exported to ONNX, the ONNX model was converted to TensorFlow, and finally the TensorFlow Lite converter produced the .tflite file. The Google AI Edge library replaces that three-stage pipeline with a direct conversion from PyTorch to .tflite.

We will demonstrate the conversion of a basic image classification model, specifically ResNet18. Our goal is to obtain a .tflite file that can be used directly with the MediaPipe library in an official Android project, with no additional pre- or post-processing code. Let’s dive into the procedure:

1. Install the helper libraries.
!pip install -r https://github.com/google-ai-edge/ai-edge-torch/releases/download/v0.1.1/requirements.txt
!pip install mediapipe
!pip install ai-edge-torch==0.1.1
!pip install validators matplotlib
# For the metadata writer
!pip install tflite-support

2. Import the libraries we are going to use.

import torch
from PIL import Image
import torchvision.transforms as transforms
from torch import nn
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import warnings
import validators
warnings.filterwarnings('ignore')
%matplotlib inline

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} for inference')

3. Download and create an instance of the ResNet18 PyTorch model.

model18 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# or any of these variants
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet152', pretrained=True)
model18.eval()

4. Download the labels .txt file.

!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

5. Download and check one of the images we are going to use for inference.

uris = [
'http://images.cocodataset.org/test-stuff2017/000000024309.jpg',
'http://images.cocodataset.org/test-stuff2017/000000028117.jpg',
'http://images.cocodataset.org/test-stuff2017/000000006149.jpg',
'http://images.cocodataset.org/test-stuff2017/000000004954.jpg',
]

# Force 3 RGB channels in case an image is grayscale or has an alpha channel
pil_image = Image.open(requests.get(uris[0], stream=True).raw).convert("RGB")
print(type(pil_image))

resized_image = pil_image.resize((224, 224))
plt.imshow(resized_image)
plt.show()

# Convert the PIL image to a NumPy array of shape (224, 224, 3)
# and add a batch dimension -> (1, 224, 224, 3)
numpy_image = np.array(resized_image)
numpy_image_same = np.expand_dims(numpy_image, axis=0)

# Check the type and shape of the array
print(type(numpy_image_same))
print(numpy_image_same.shape)
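As a sanity check on the input format: the exported model expects a float32 batch of shape (1, 224, 224, 3) in BHWC order, with raw pixel values in [0, 255], because the wrapper created in the next step does the scaling itself. The NumPy-only sketch below shows those conversions on a synthetic array. The helper name `to_model_input` is hypothetical, and the grayscale/RGBA handling is an assumption about inputs you may run into; with PIL you can achieve the same with `.convert("RGB")` before `np.array`.

```python
import numpy as np

def to_model_input(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn an HxW, HxWx3 or HxWx4 uint8 image into a
    float32 batch of shape (1, H, W, 3) with values kept in [0, 255]."""
    if image.ndim == 2:
        # Grayscale: repeat the single channel three times.
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        # RGBA: drop the alpha channel.
        image = image[..., :3]
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(f"unexpected image shape {image.shape}")
    # Add the batch dimension (HWC -> BHWC) and cast to float32 without rescaling;
    # the wrapper model divides by 255 itself.
    return np.expand_dims(image, axis=0).astype(np.float32)

# Usage with a synthetic 224x224 grayscale image.
gray = np.random.randint(0, 256, size=(224, 224), dtype=np.uint8)
batch = to_model_input(gray)
print(batch.shape, batch.dtype)  # (1, 224, 224, 3) float32
```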

6. Create the wrapper class that encapsulates the PyTorch model and adds the preprocessing steps (channel reordering and scaling) plus a softmax on the output. Depending on your PyTorch model, you can add other manipulation steps in the ‘forward’ function.

class ImageClassificationModelWrapper(nn.Module):

    def __init__(self, image_classification_model):
        super().__init__()
        self.model = image_classification_model

    def forward(self, image: torch.Tensor):
        # BHWC -> BCHW.
        image = image.permute(0, 3, 1, 2)
        # Scale pixel values from [0, 255] to [0, 1].
        # Add more steps based on your model.
        input_batch = image / 255.0
        logits = self.model(input_batch)
        # Softmax over the 1000 classes of the first (only) image in the batch;
        # the output shape is therefore (1000,), without a batch dimension.
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

        return probabilities
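Note that torchvision's pretrained classifiers were trained on inputs scaled to [0, 1] and then normalized with the ImageNet per-channel mean (0.485, 0.456, 0.406) and standard deviation (0.229, 0.224, 0.225), while the wrapper above only divides by 255. The sketch below mirrors the wrapper's steps in NumPy, with that extra normalization as an optional flag, so you can see exactly what each step does to the array. The function names are hypothetical; if you add the mean/std step to ‘forward’ (using tensors reshaped to (1, 3, 1, 1) in the same way), expect the top-5 scores to differ from the ones printed in step 10.

```python
import numpy as np

# ImageNet statistics used by torchvision's pretrained classifiers.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(batch_bhwc: np.ndarray, imagenet_norm: bool = False) -> np.ndarray:
    """Mirror of the wrapper's input steps: BHWC -> BCHW, then scale to [0, 1].
    With imagenet_norm=True it also applies per-channel (x - mean) / std."""
    x = np.transpose(batch_bhwc, (0, 3, 1, 2)).astype(np.float32) / 255.0
    if imagenet_norm:
        # Reshape the statistics to (1, C, 1, 1) so they broadcast over BCHW.
        x = (x - IMAGENET_MEAN.reshape(1, 3, 1, 1)) / IMAGENET_STD.reshape(1, 3, 1, 1)
    return x

def softmax_first(logits: np.ndarray) -> np.ndarray:
    """Mirror of softmax(logits[0], dim=0): probabilities for the first image only."""
    z = logits[0] - logits[0].max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

batch = np.full((1, 224, 224, 3), 255, dtype=np.uint8)
print(preprocess(batch).shape)  # (1, 3, 224, 224)
print(softmax_first(np.array([[1.0, 2.0, 3.0]])))
```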

7. Create the wrapped model.

wrapped_pt_model = ImageClassificationModelWrapper(model18).eval()

8. Use the ai-edge-torch library to convert the model, feeding it a randomly initialized torch tensor with the input shape the wrapper expects: (1, 224, 224, 3) in BHWC order.

import ai_edge_torch

sample_args = (torch.rand((1, 224, 224, 3)),)
edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)
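Before exporting, it is worth checking that the converted model agrees with the PyTorch wrapper on the same input. The ai-edge-torch README does this by calling the converted model on the sample inputs (`edge_model(*sample_args)`) and comparing the result with the PyTorch output using `numpy.allclose`. Below is a minimal NumPy helper in that spirit; the name `compare_outputs` and the 1e-5 tolerances are assumptions, not library defaults. In the notebook you would pass `wrapped_pt_model(*sample_args).detach().numpy()` as the reference and `edge_model(*sample_args)` as the candidate.

```python
import numpy as np

def compare_outputs(reference, candidate, atol: float = 1e-5, rtol: float = 1e-5):
    """Return (within_tolerance, max_abs_diff) for two model outputs."""
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_diff = float(np.max(np.abs(reference - candidate)))
    return bool(np.allclose(reference, candidate, atol=atol, rtol=rtol)), max_diff

# Synthetic example: a tiny floating-point perturbation stays within tolerance.
ref = np.array([0.7, 0.2, 0.1], dtype=np.float32)
ok, diff = compare_outputs(ref, ref + 1e-7)
print(ok)  # True
```

Reporting the maximum absolute difference alongside the boolean makes it easier to tell a small numerical drift from a genuinely broken conversion.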

9. Export the converted model to a .tflite file.

edge_model.export('resnet18.tflite')

10. With the model exported, it is important to test it by classifying a real image. The easiest way is to use the TensorFlow Lite Interpreter API.

import tensorflow as tf

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="/content/resnet18.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

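# Run the model on the image prepared in step 5 (float32, shape (1, 224, 224, 3)).
input_shape = input_details[0]['shape']
input_data = np.array(numpy_image_same, dtype=np.float32)
assert tuple(input_shape) == input_data.shape, (input_shape, input_data.shape)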

interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])

# Read the categories
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]
# Show the top-5 categories for the image
top5_prob, top5_catid = torch.topk(torch.from_numpy(output_data), 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

The code above prints:

laptop 0.3680455684661865
notebook 0.21134474873542786
screen 0.09741917997598648
desk 0.08575554192066193
monitor 0.05525420978665352
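If you would rather not depend on PyTorch at inference time, the top-5 lookup can be done with NumPy alone. This is a minimal sketch; `top_k_labels` is a hypothetical helper, and it assumes `probs` is the 1-D probability vector the wrapper outputs (its softmax already drops the batch dimension).

```python
import numpy as np

def top_k_labels(probs, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    probs = np.asarray(probs).reshape(-1)  # tolerate a leading batch dim of size 1
    # argsort is ascending, so reverse it and keep the first k indices.
    top = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in top]

labels = ["cat", "dog", "laptop", "desk"]
probs = np.array([0.1, 0.2, 0.6, 0.1], dtype=np.float32)
for name, p in top_k_labels(probs, labels, k=2):
    print(name, round(p, 2))  # prints "laptop 0.6", then "dog 0.2"
```

In the notebook, `top_k_labels(output_data, categories)` would play the role of the `torch.topk` loop above.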

11. You can visualize the model graph with the new Model Explorer library.

!pip install ai-edge-model-explorer
import model_explorer
model_explorer.visualize('/content/resnet18.tflite')

12. As a final step, we add metadata to the .tflite file so that the MediaPipe library can use it directly. Since the wrapper already scales the input during inference, there is no need to add normalization parameters to the .tflite file, so we keep them as MEAN = 0 and STD = 1. This way MediaPipe passes the raw [0, 255] pixel values to the model, and the division by 255 happens inside the graph.

from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import writer_utils
ImageClassifierWriter = image_classifier.MetadataWriter
_MODEL_PATH = "/content/resnet18.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "/content/imagenet_classes.txt"
_SAVE_TO_PATH = "/content/resnet18_meta.tflite"
# Normalization parameters are required when preprocessing the image. They are
# optional if the image pixel values are in the range [0, 255] and the input
# tensor is quantized to uint8. See the section on normalization and
# quantization parameters in the guide below for more details:
# https://www.tensorflow.org/lite/models/convert/metadata#normalization_and_quantization_parameters
_INPUT_NORM_MEAN = 0.0
_INPUT_NORM_STD = 1.0

# Create the metadata writer.
writer = ImageClassifierWriter.create_for_inference(
writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
[_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)
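To see why MEAN = 0 and STD = 1 are the right values here: the metadata's normalization options describe the per-channel rule normalized = (pixel - mean) / std that a consumer such as MediaPipe applies before feeding the model. The NumPy sketch below (function name hypothetical) evaluates that rule for the settings used in this post, where it is the identity, and for a hypothetical alternative, MEAN = 0 and STD = 255, that you would need only if the division by 255 were not already in the wrapper.

```python
import numpy as np

def apply_metadata_normalization(pixels, mean, std):
    """Apply the (pixel - mean) / std rule described by the metadata's normalization options."""
    return (np.asarray(pixels).astype(np.float32) - mean) / std

pixels = np.array([0, 128, 255], dtype=np.uint8)

# Settings used in this post: identity, so the wrapper receives 0..255 and divides by 255 itself.
print(apply_metadata_normalization(pixels, mean=0.0, std=1.0))

# Hypothetical alternative if the model did NOT divide by 255 internally: maps 0..255 to 0..1.
print(apply_metadata_normalization(pixels, mean=0.0, std=255.0))
```

Only one of the two places should do the scaling: combining STD = 255 in the metadata with the wrapper's `/ 255.0` would scale the pixels twice.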

The ‘resnet18_meta.tflite’ file is ready to be dropped into the official MediaPipe image classification project, where it should work out of the box.

You can find the complete Colab notebook at this GitHub repository.

Conclusion
We converted a simple PyTorch model to .tflite in a few basic steps and added metadata to the file, making it ready for use in a basic Android classification app with MediaPipe. Thanks to the Google AI Edge library, the conversion no longer needs the PyTorch-to-ONNX-to-TensorFlow detour.

George Soloupis

I am a pharmacist turned Android developer and machine learning engineer. Right now I’m a senior Android developer at Invisalign and an ML & Android GDE.