Segmentation task inside Android using different APIs
Written by George Soloupis, ML and Android GDE.
This blog post demonstrates the usage of different APIs for a common machine learning task: image segmentation. We are going to showcase implementations with the TensorFlow Lite Interpreter, the Task Library and the latest MediaPipe API.
Image segmentation is a sub-domain of computer vision and digital image processing which aims at grouping similar regions or segments of an image under their respective class labels.
Since the entire process is digital, a representation of the analog image in the form of pixels is available, making the task of forming segments equivalent to that of grouping pixels. Image segmentation is an extension of image classification where, in addition to classification, we perform localization.
Image segmentation is thus a superset of image classification, with the model pinpointing where a corresponding object is present by outlining its boundary.
Segmentation is an extremely common machine learning task, so a lot of APIs exist to help us in this domain. Let's focus on the low-level TensorFlow Lite Interpreter, the high-level TensorFlow Lite Task Library and the most recent one, MediaPipe's implementation. We are going to investigate the initialization of each library, the preprocessing of the image, the inference and the post-processing, using the same model in every case.
TensorFlow Lite Interpreter
This API has proven its value over the past four years with the flexibility to support both common and uncommon tasks. You can feed the Interpreter with custom arrays or buffers and multiple inputs or outputs, it supports CPU and GPU inference, and its superb documentation, guide and examples make it handy for experienced and inexperienced users alike. Let's have a look at the procedure for image segmentation; you can find a lot of details in the guide. The procedure below is low level, with no additional libraries for pre- and post-processing.
Initialization
The .tflite file is usually placed inside the assets folder; then you load it and prepare the Interpreter with options:
private fun getInterpreter(
    context: Context,
    modelName: String,
    useGpu: Boolean = false
): Interpreter {
    val tfliteOptions = Interpreter.Options()
    tfliteOptions.numThreads = numberThreads
    gpuDelegate = null
    if (useGpu) {
        gpuDelegate = GpuDelegate()
        tfliteOptions.addDelegate(gpuDelegate)
    }
    return Interpreter(loadModelFile(context, modelName), tfliteOptions)
}
private fun loadModelFile(context: Context, modelFile: String): MappedByteBuffer {
    val fileDescriptor = context.assets.openFd(modelFile)
    val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
    val fileChannel = inputStream.channel
    val startOffset = fileDescriptor.startOffset
    val declaredLength = fileDescriptor.declaredLength
    val retFile = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    fileDescriptor.close()
    return retFile
}
You cannot load a bitmap directly into the Interpreter; you have to convert it to an array or ByteBuffer first. Check Netron, a handy visualization tool where you can drag and drop your .tflite files and get the information about inputs and outputs. If your model contains metadata, then you can simply double-click your .tflite file and Android Studio will show the information to you.
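If you prefer to inspect the model programmatically instead of (or in addition to) using Netron, the Interpreter itself can report the tensor shapes and types. A minimal sketch using only the Interpreter API; the example values in the comments assume the DeepLabV3 model used later in this post:

// Query the loaded Interpreter for its input and output tensor details.
val inputTensor = interpreter.getInputTensor(0)
val outputTensor = interpreter.getOutputTensor(0)
Log.d(
    "Segmentation",
    "input shape=${inputTensor.shape().contentToString()} type=${inputTensor.dataType()} " +
        "output shape=${outputTensor.shape().contentToString()} type=${outputTensor.dataType()}"
)
// For DeepLabV3 this typically prints something like:
// input shape=[1, 257, 257, 3] type=FLOAT32 output shape=[1, 257, 257, 21] type=FLOAT32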
Pre-processing
The bitmap has to be converted to a ByteBuffer before feeding it to the Interpreter. Let's see the code for that:
fun bitmapToByteBuffer(
    bitmapIn: Bitmap,
    width: Int,
    height: Int,
    mean: Float = 127.5f,
    std: Float = 127.5f
): ByteBuffer {
    val bitmap = scaleBitmapAndKeepRatio(bitmapIn, width, height)
    val inputImage = ByteBuffer.allocateDirect(1 * width * height * 3 * 4)
    inputImage.order(ByteOrder.nativeOrder())
    inputImage.rewind()
    val intValues = IntArray(width * height)
    bitmap.getPixels(intValues, 0, width, 0, 0, width, height)
    var pixel = 0
    for (y in 0 until height) {
        for (x in 0 until width) {
            val value = intValues[pixel++]
            // Normalize channel values to [-1.0, 1.0]. This requirement varies by
            // model. For example, some models might require values to be normalized
            // to the range [0.0, 1.0] instead.
            inputImage.putFloat(((value shr 16 and 0xFF) - mean) / std)
            inputImage.putFloat(((value shr 8 and 0xFF) - mean) / std)
            inputImage.putFloat(((value and 0xFF) - mean) / std)
        }
    }
    inputImage.rewind()
    return inputImage
}
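The scaleBitmapAndKeepRatio helper used above is not shown in this post; the exact implementation in the linked TensorFlow Lite example may differ, but a minimal sketch of such a helper could simply center-crop the bitmap to a square and scale it to the requested size:

// Hypothetical stand-in for the scaleBitmapAndKeepRatio helper referenced above:
// center-crop to a square, then scale to the requested dimensions.
fun scaleBitmapAndKeepRatio(source: Bitmap, width: Int, height: Int): Bitmap {
    val side = minOf(source.width, source.height)
    val xOffset = (source.width - side) / 2
    val yOffset = (source.height - side) / 2
    val squared = Bitmap.createBitmap(source, xOffset, yOffset, side, side)
    return Bitmap.createScaledBitmap(squared, width, height, true)
}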
Overwhelming? Not really: this is the standard procedure for preparing the ByteBuffer that the Interpreter expects for this common task, and it also lets you manipulate the final ByteBuffer for rarer computer vision cases, for example when your model expects grayscale images, or image arrays where the channel dimension comes second (.tflite files that have been converted from PyTorch models).
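For that channel-first case, here is a hypothetical sketch (not part of the original example) of how the same loop could be restructured so that each color plane is written contiguously, as a PyTorch-converted NCHW model would expect:

// Hypothetical NCHW variant: write all red values, then all green, then all blue.
fun bitmapToChannelFirstByteBuffer(
    bitmap: Bitmap,
    width: Int,
    height: Int,
    mean: Float = 127.5f,
    std: Float = 127.5f
): ByteBuffer {
    val buffer = ByteBuffer.allocateDirect(1 * 3 * width * height * 4)
    buffer.order(ByteOrder.nativeOrder())
    val pixels = IntArray(width * height)
    bitmap.getPixels(pixels, 0, width, 0, 0, width, height)
    for (shift in intArrayOf(16, 8, 0)) { // R, G and B planes in turn
        for (value in pixels) {
            buffer.putFloat(((value shr shift and 0xFF) - mean) / std)
        }
    }
    buffer.rewind()
    return buffer
}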
Inference
This is one line of code:
interpreter.run(byteBuffer, segmentationMasks)
You feed the Interpreter with the input, inference runs, and the results are written into the segmentationMasks ByteBuffer, which is allocated based on the model's outputs:
private val segmentationMasks: ByteBuffer =
    ByteBuffer.allocateDirect(1 * imageSize * imageSize * NUM_CLASSES * 4)
        .order(ByteOrder.nativeOrder())
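If your model exposed more than one output (the Interpreter supports multiple inputs and outputs, as mentioned earlier), you would switch to runForMultipleInputsOutputs. A minimal sketch assuming a hypothetical model with a second output buffer:

// Hypothetical multi-output inference: the map keys are the output tensor indices.
val secondOutput = ByteBuffer.allocateDirect(1 * imageSize * imageSize * 4) // size depends on the model
    .order(ByteOrder.nativeOrder())
val inputs = arrayOf<Any>(byteBuffer)
val outputs = mapOf<Int, Any>(
    0 to segmentationMasks,
    1 to secondOutput
)
interpreter.runForMultipleInputsOutputs(inputs, outputs)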
Post-processing
Then the output buffer is converted to the mask bitmap:
private fun convertBytebufferMaskToBitmap(
    inputBuffer: ByteBuffer,
    imageWidth: Int,
    imageHeight: Int,
    backgroundImage: Bitmap,
    colors: IntArray
): Triple<Bitmap, Bitmap, Map<String, Int>> {
    val conf = Bitmap.Config.ARGB_8888
    val maskBitmap = Bitmap.createBitmap(imageWidth, imageHeight, conf)
    val resultBitmap = Bitmap.createBitmap(imageWidth, imageHeight, conf)
    val scaledBackgroundImage =
        ImageUtils.scaleBitmapAndKeepRatio(
            backgroundImage,
            imageWidth,
            imageHeight
        )
    val mSegmentBits = Array(imageWidth) { IntArray(imageHeight) }
    val itemsFound = HashMap<String, Int>()
    inputBuffer.rewind()
    for (y in 0 until imageHeight) {
        for (x in 0 until imageWidth) {
            var maxVal = 0f
            mSegmentBits[x][y] = 0
            for (c in 0 until NUM_CLASSES) {
                val value = inputBuffer
                    .getFloat((y * imageWidth * NUM_CLASSES + x * NUM_CLASSES + c) * 4)
                if (c == 0 || value > maxVal) {
                    maxVal = value
                    mSegmentBits[x][y] = c
                }
            }
            val label = labelsArrays[mSegmentBits[x][y]]
            val color = colors[mSegmentBits[x][y]]
            itemsFound.put(label, color)
            val newPixelColor = ColorUtils.compositeColors(
                colors[mSegmentBits[x][y]],
                scaledBackgroundImage.getPixel(x, y)
            )
            resultBitmap.setPixel(x, y, newPixelColor)
            maskBitmap.setPixel(x, y, colors[mSegmentBits[x][y]])
        }
    }
    return Triple(resultBitmap, maskBitmap, itemsFound)
}
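Putting the pieces together, a rough sketch of the whole Interpreter flow; context, inputBitmap, imageSize, labelsArrays and colors are assumed to be defined in the executor class, as in the linked example:

// Rough end-to-end sketch: load the model, preprocess, run inference, post-process.
val interpreter = getInterpreter(context, "deeplabv3.tflite", useGpu = false)
val inputBuffer = bitmapToByteBuffer(inputBitmap, imageSize, imageSize)
interpreter.run(inputBuffer, segmentationMasks)
val (resultBitmap, maskBitmap, itemsFound) =
    convertBytebufferMaskToBitmap(segmentationMasks, imageSize, imageSize, inputBitmap, colors)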
You can see all the files at this TensorFlow Lite example.
Task Library
If you do not need the low-level inference and image manipulation above, you can use the high-level image segmentation from the Task Library. Here the preprocessing is done with the help of the TensorFlow Lite Support Library, which loads the bitmap, while the Task Library uses the metadata of the .tflite file (placed in the assets folder) to do the cropping and the normalization automatically.
Initialization
private lateinit var imageSegmenter: ImageSegmenter

val options = ImageSegmenter.ImageSegmenterOptions.builder()
    .setOutputType(OutputType.CATEGORY_MASK)
    .build()
imageSegmenter = ImageSegmenter.createFromFileAndOptions(
    getApplication(),
    "deeplabv3.tflite",
    options
)
Pre-processing
val tensorImage = TensorImage.fromBitmap(bitmap)
Inference
val results = imageSegmenter.segment(tensorImage)
Post-processing
val tensorMask = results[0].masks[0]
val rawMask = tensorMask.tensorBuffer.intArray
val output = Bitmap.createBitmap(
    tensorMask.width,
    tensorMask.height,
    Bitmap.Config.ARGB_8888
)
for (y in 0 until tensorMask.height) {
    for (x in 0 until tensorMask.width) {
        output.setPixel(
            x,
            y,
            if (rawMask[y * tensorMask.width + x] == 0) Color.TRANSPARENT else Color.BLACK
        )
    }
}
scaledMaskBitmap = Bitmap.createScaledBitmap(output, bitmap.width, bitmap.height, true)
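If you then want to draw the mask on top of the original picture, a simple Canvas composition is enough. A minimal sketch, not part of the Task Library itself, using the standard android.graphics classes:

// Hypothetical helper: draw the scaled mask semi-transparently over the original bitmap.
fun overlayMask(original: Bitmap, mask: Bitmap, maskAlpha: Int = 128): Bitmap {
    val output = original.copy(Bitmap.Config.ARGB_8888, true)
    val canvas = Canvas(output)
    val paint = Paint().apply { alpha = maskAlpha }
    canvas.drawBitmap(mask, 0f, 0f, paint)
    return output
}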
Totally effortless, performant, and able to use the GPU delegate (a sketch of enabling it follows the list below). A summary of this API:
1. Uses the Support Library to load the image.
2. Task Library draws information from the .tflite’s metadata to do the cropping and the normalization.
3. The returned mask's underlying TensorBuffer can be converted immediately to an IntArray and then to a Bitmap.
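For the GPU delegate mentioned above, the Task Library exposes its own BaseOptions (org.tensorflow.lite.task.core.BaseOptions). A minimal sketch; the exact builder methods can vary between Task Library versions:

// Enabling the GPU delegate and a thread count through the Task Library's BaseOptions.
val baseOptions = BaseOptions.builder()
    .useGpu()
    .setNumThreads(4)
    .build()
val gpuOptions = ImageSegmenter.ImageSegmenterOptions.builder()
    .setBaseOptions(baseOptions)
    .setOutputType(OutputType.CATEGORY_MASK)
    .build()
val gpuSegmenter = ImageSegmenter.createFromFileAndOptions(
    getApplication(),
    "deeplabv3.tflite",
    gpuOptions
)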
MediaPipe’s Segmentation API
At the time this blog post was written there was no official segmentation example on MediaPipe's web page. On the other hand, there were enough examples for other popular tasks such as object detection, image classification, etc., and the API was handy enough that setting up the ImageSegmenter and running the segmentation task posed no difficulty. You can take a look at the source code of the ImageSegmenter class here and at the API reference. At that time the dependency used in the build.gradle file was:
implementation("com.google.mediapipe:tasks-vision:0.1.0-alpha-5")
Initialization
private lateinit var imageSegmenter: ImageSegmenter

val baseOptions = BaseOptions.builder()
    .setModelAssetPath("deeplabv3.tflite")
    .setDelegate(Delegate.CPU)
    .build()
val options = ImageSegmenter.ImageSegmenterOptions.builder()
    .setOutputType(ImageSegmenter.ImageSegmenterOptions.OutputType.CATEGORY_MASK)
    .setBaseOptions(baseOptions)
    .setResultListener(this::returnLivestreamResult)
    .setErrorListener(this::returnLivestreamError)
    .build()
imageSegmenter = ImageSegmenter.createFromOptions(
    getApplication(),
    options
)
The model is again placed in the assets folder and contains metadata that provides the information for the pre-processing (cropping, normalization). The builder can easily be set up with a CPU or GPU delegate.
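For instance, switching to the GPU delegate only changes the BaseOptions; the Delegate enum is the same one already used above:

// The same initialization, but with the GPU delegate instead of the CPU one.
val gpuBaseOptions = BaseOptions.builder()
    .setModelAssetPath("deeplabv3.tflite")
    .setDelegate(Delegate.GPU)
    .build()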
Loading the image
val image = BitmapImageBuilder(bitmap).build()
Inference
imageSegmenter.segment(image)
Post-processing
Two options here.
A. The result is acquired by the listener and the final mask is created:
private fun returnLivestreamResult(
    result: ImageSegmenterResult, image: MPImage
) {
    // We only need the first mask for this sample because we are using
    // the OutputType CATEGORY_MASK, which only provides a single mask.
    val mPImage = result.segmentations().first()
    val byteBuffer = ByteBufferExtractor.extract(mPImage)
    val pixels = IntArray(byteBuffer.capacity())
    for (i in pixels.indices) {
        val index = byteBuffer.get(i).toInt()
        val color =
            if (index in 1..20) Color.BLACK else Color.TRANSPARENT
        pixels[i] = color
    }
    val image = Bitmap.createBitmap(
        pixels,
        mPImage.width,
        mPImage.height,
        Bitmap.Config.ARGB_8888
    )
    scaledMaskBitmap = image
}
B. With the dependency
implementation("com.google.mediapipe:tasks-vision:0.1.0-alpha-6")
you can get the ImageSegmenterResult as the return value of
val result = imageSegmenter.segment(image)
Then the post-processing is exactly the same as in option A above.
You can also take a look at this file for other options and setups, for example when you want to segment a live video stream. The relationship between the Task Library and the MediaPipe solutions is explained on that web page.
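As a hedged sketch of that live-stream setup, based on the current MediaPipe Tasks API (running modes and segmentAsync); the exact option names may differ in the alpha version used in this post, and frameBitmap is a hypothetical camera frame:

// Live-stream sketch (current MediaPipe Tasks API; names may differ in the alpha version).
val liveOptions = ImageSegmenter.ImageSegmenterOptions.builder()
    .setBaseOptions(baseOptions)
    .setRunningMode(RunningMode.LIVE_STREAM)
    .setResultListener(this::returnLivestreamResult)
    .setErrorListener(this::returnLivestreamError)
    .build()
val liveSegmenter = ImageSegmenter.createFromOptions(getApplication(), liveOptions)
// Each frame is sent with a monotonically increasing timestamp in milliseconds.
liveSegmenter.segmentAsync(BitmapImageBuilder(frameBitmap).build(), SystemClock.uptimeMillis())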
Conclusion
Whether you are an experienced or an inexperienced user, you can find the API that suits your needs. From the low-level TensorFlow Lite Interpreter to the high-level Task Library and the latest MediaPipe implementation, you can perform common or uncommon machine learning tasks successfully! Image segmentation is as easy as ever.
Special thanks to Paul Ruiz and Lu Wang for giving feedback on this post.