Use TensorFlow Lite Model Maker with a custom dataset

George Soloupis
Nov 5, 2023

Written by George Soloupis, ML and Android GDE.

In this blog post, we will explore the practical application of TensorFlow Lite Model Maker using a custom audio dataset. This library revolutionizes the training process by implementing transfer learning techniques, enabling the retraining of an existing TensorFlow model with minimal sample data and training time. This approach significantly reduces the complexity of the procedure, as it automates various tasks such as resampling audio files and splitting and preparing the dataset. By leveraging this powerful tool, developers can efficiently train models tailored to their specific audio data, making the entire process seamless and highly efficient.

The purpose of this endeavor was to create a .tflite file, optimized for Android devices, capable of distinguishing background noises from gunshot sounds. The ultimate goal was to develop an application that counts the number of gunshots in scenes from the John Wick film series. One of the critical aspects of this project was curating a high-quality dataset specifically for background noises. The approach was inspired by a TensorFlow example, which served as a valuable reference point for developing a robust solution.

Whether utilizing a default audio dataset or a custom one, it’s crucial to have a well-curated collection of background noises. This ensures that the model can accurately distinguish the target sounds from other noises, including periods of silence. In our case, the background samples were initially in the form of lengthy WAV files spanning a minute or more. To effectively train the model and reserve samples for the test dataset, these lengthy files needed to be divided into smaller two-second segments. Additionally, we amalgamated various samples from different sources, creating a diverse and comprehensive set of background noises and moments of silence. This process laid the foundation for training a robust and accurate audio classification model.
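As an illustration, below is a minimal sketch of how such long recordings could be split into two-second clips. The file and folder names are hypothetical, and the exact preprocessing used for the dataset may differ.

# Minimal sketch: split a long background recording into two-second .wav clips.
# File and folder names are hypothetical; adapt them to your own dataset layout.
import os
from scipy.io import wavfile

def split_wav(path, out_dir, clip_seconds=2):
    sample_rate, data = wavfile.read(path)
    clip_len = clip_seconds * sample_rate
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.splitext(os.path.basename(path))[0]
    # Write only full-length clips; a trailing remainder shorter than
    # clip_seconds is discarded.
    for i in range(len(data) // clip_len):
        clip = data[i * clip_len:(i + 1) * clip_len]
        wavfile.write(os.path.join(out_dir, f"{base}_{i:03d}.wav"),
                      sample_rate, clip)

split_wav("long_background_noise.wav", "./dataset/train/background")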

On the other hand, the gunshot dataset consisted of .mp3 files, which had to be converted to .wav format to make them compatible with TensorFlow Lite Model Maker. These files were typically 3 to 4 seconds long, with the gunshot sound predominantly occurring within the first 1 to 2 seconds of each file.
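A conversion like this can be scripted, for example with ffmpeg. The sketch below is only illustrative: it assumes ffmpeg is installed (for instance via apt in Colab) and the directory names are hypothetical.

# Minimal sketch: convert .mp3 gunshot clips to .wav so Model Maker can read them.
# Assumes ffmpeg is installed; directory names are hypothetical.
import glob
import os
import subprocess

out_dir = "./dataset/train/gunshot"
os.makedirs(out_dir, exist_ok=True)
for mp3_path in glob.glob("./gunshots_mp3/*.mp3"):
    wav_name = os.path.splitext(os.path.basename(mp3_path))[0] + ".wav"
    # Model Maker handles resampling later; ffmpeg only changes the container/codec here.
    subprocess.run(["ffmpeg", "-y", "-i", mp3_path, os.path.join(out_dir, wav_name)],
                   check=True)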

When using Model Maker for audio, you have to start with a model spec. This is the base model from which your new model extracts information to learn about the new classes. It also determines how the dataset is transformed to respect the spec’s parameters, such as sample rate and number of channels.

YAMNet is an audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology. Its input is expected to be at 16 kHz with 1 channel. You don’t need to do any resampling yourself; Model Maker takes care of that for you.

  • frame_length decides how long each training sample is. In this case, 2 * EXPECTED_WAVEFORM_LENGTH (roughly two seconds).
  • frame_step decides how far apart consecutive training samples start. In this case, the ith sample starts 1 * EXPECTED_WAVEFORM_LENGTH (roughly one second) after the (i-1)th sample.

The rationale behind setting these specific values was to address certain constraints inherent in real-world datasets. For instance, in the gunshot dataset, the sound is audible only for a brief moment at the start of the audio file. Choosing a longer frame duration could capture the gunshot sound comprehensively, but setting it excessively long would reduce the number of samples available for training. Striking a balance between capturing critical audio features and ensuring an adequate sample size for training was essential to develop a model capable of accurately classifying the gunshot sounds within the given dataset. You can find more info about using transfer learning for audio classification in this example.
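As a rough sketch of what these values mean in seconds (using the same spec values as the training script later in this post, and assuming EXPECTED_WAVEFORM_LENGTH is YAMNet’s window size in samples at 16 kHz, roughly one second):

# Minimal sketch: translate the frame parameters into seconds.
# Same values as in the training script below; sample_rate is YAMNet's expected input rate.
from tflite_model_maker import audio_classifier

window = audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH  # samples at 16 kHz (~1 s)
sample_rate = 16000

frame_length = 2 * window  # each training sample covers ~2 s of audio
frame_step = 1 * window    # consecutive samples start ~1 s apart, so they overlap

print(f"frame_length: {frame_length / sample_rate:.2f} s, "
      f"frame_step: {frame_step / sample_rate:.2f} s")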

Colaboratory served as the platform for training the model and setting up the environment and dataset. However, at the time this article was written, a minor complication arose due to compatibility issues between Python 3.10 and some of the TensorFlow Lite Model Maker requirements. To overcome this challenge, a Conda environment was used within Colab, allowing us to successfully train the model and extract the desired results. This workaround ensured a smooth and efficient workflow despite the compatibility issues encountered. You can follow along with the Python notebook here, where you can see the minimum steps needed to prepare and train a .tflite file with Model Maker.

The procedure consists of two steps: creating the Conda environment, and preparing and training the model.

Create the Conda environment with Python 3.9 version:

!wget https://repo.anaconda.com/miniconda/Miniconda3-py39_23.3.1-0-Linux-x86_64.sh
!chmod +x Miniconda3-py39_23.3.1-0-Linux-x86_64.sh
!./Miniconda3-py39_23.3.1-0-Linux-x86_64.sh -b -f -p /usr/local
!conda update conda

import sys
sys.path.append('/usr/local/lib/python3.9/site-packages')

!conda create -n myenv python=3.9

then install all packages inside:

%%shell
eval "$(conda shell.bash hook)"
conda activate myenv && pip install tflite-model-maker && sudo apt -y install libportaudio2 && pip install numpy==1.23.4 && pip install ipykernel && pip install seaborn

Prepare and train the model:

Everything that is going to be used inside the Conda environment has to be written as a .py script:

%%writefile train_model_script.py

import tensorflow as tf
import tflite_model_maker as mm
from tflite_model_maker import audio_classifier
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import itertools
import glob
import random
import IPython

from IPython.display import Audio, Image
from IPython import display
from scipy.io import wavfile

####################################
# Directory
data_dir = './dataset'

####################################
# Specs
# Alternative base spec (not used here): spec = audio_classifier.BrowserFftSpec()
spec = audio_classifier.YamNetSpec(
    keep_yamnet_and_custom_heads=True,
    frame_step=1 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH,
    frame_length=2 * audio_classifier.YamNetSpec.EXPECTED_WAVEFORM_LENGTH)

####################################
# Data
train_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'train'), cache=True)
train_data, validation_data = train_data.split(0.8)
test_data = audio_classifier.DataLoader.from_folder(
    spec, os.path.join(data_dir, 'test'), cache=True)

####################################
# Train
batch_size = 128
epochs = 20

print('Training the model')
model = audio_classifier.create(
    train_data,
    spec,
    validation_data,
    batch_size=batch_size,
    epochs=epochs)

#####################################
# export
models_path = './models'
print(f'Exporting the TFLite model to {models_path}')
model.export(models_path, tflite_filename='gunshot_model.tflite')

and then execute inside the environment:

%%shell
eval "$(conda shell.bash hook)"
conda activate myenv
python train_model_script.py

It’s important to note that the deviation from the straightforward Python cell execution occurred because of the incompatibility between Model Maker libraries and Python 3.10 at the time of writing this article. Under normal circumstances, the process is simplified to installing tflite-model-maker and proceeding with the workflow without such additional steps. The temporary workaround ensured that the project could progress smoothly despite the specific compatibility challenges faced during that period.

pip install tflite-model-maker

After training, a .tflite model is created in the models directory, which you can download and use. Let’s check the inputs and outputs with Netron.

Inputs and Outputs properties.

A great thing to mention is that, after training, the model has two outputs:

Outputs of the model.

There are two distinct outputs to consider in this context: one from the YAMNet model, providing probabilities for 521 audio classes, and the other from our custom dataset, offering probabilities for the two specific classes, background noises and gunshots. This differentiation is crucial because testing environments are multifaceted and diverse, extending beyond simple scenarios like gunshot sounds.

By utilizing YAMNet’s output, we can effectively filter out irrelevant audio data. For instance, in the gunshot use case, if YAMNet doesn’t classify certain sounds as gunshots, it indicates that the output from our model might be an inaccurate or irrelevant classification for those instances. This interplay between the broader YAMNet model and the output trained on our custom dataset ensures a more nuanced and accurate analysis of complex audio environments, enhancing the overall reliability of our classification system.
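Here is a minimal Python sketch of how the two outputs could be read together. The output order and tensor shapes are assumptions here, so verify them with Netron as shown above; on Android the equivalent logic would run through the TFLite Interpreter or Task Library.

# Minimal sketch: run the exported model on one audio window and read both heads.
# Output order and shapes are assumptions; verify them with Netron.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="gunshot_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()

# Dummy waveform; in practice this is a mono 16 kHz window from the audio stream.
waveform = np.zeros(input_details["shape"], dtype=np.float32)
interpreter.set_tensor(input_details["index"], waveform)
interpreter.invoke()

# One output carries the 521 YAMNet class scores, the other the 2 custom classes.
outputs = [interpreter.get_tensor(d["index"]) for d in output_details]
yamnet_scores, custom_scores = sorted(outputs, key=lambda o: o.shape[-1], reverse=True)

print("Top YAMNet class index:", int(np.argmax(yamnet_scores)))
print("Background vs gunshot scores:", custom_scores[0])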

You can download the background dataset (two-second .wav files), the gunshot dataset, and the final .tflite file.

Conclusion

In this blog post, the practical application of TensorFlow Lite Model Maker using a custom audio dataset is explored. The post discusses the implementation of transfer learning techniques, allowing the retraining of existing TensorFlow models with minimal sample data and training time. This approach automates tasks such as resampling audio files and dataset preparation, simplifying the process significantly. The project’s goal was to create a .tflite file optimized for Android devices, capable of distinguishing background noises from gunshot sounds. To achieve this, a high-quality dataset was curated, balancing frame duration to capture essential audio features. The article details the challenges faced, including compatibility issues with Python 3.10 and the Model Maker libraries, and presents a workaround using a Conda environment within Colab. The final model output includes probabilities for 521 audio classes from YAMNet and for the specific classes (background noises and gunshots) from the custom dataset. The integration of YAMNet’s output allows for accurate filtering of irrelevant audio data, enhancing the reliability of the classification system in diverse audio environments.


George Soloupis

I am a pharmacist turned Android developer and machine learning engineer. Right now I’m a senior Android developer at Invisalign and an ML & Android GDE.