Android implementation of the BERT model, developed through KerasNLP.

George Soloupis
4 min read · Feb 26, 2024


Written by George Soloupis, ML and Android GDE.

This blog post demonstrates the usage of a .tflite model, generated with KerasNLP, inside an Android application. The task this ML model serves is Question Answering, where the objective is to pinpoint the exact span of text within a document that contains the answer. More details on the model and the code can be found at this blog post. The generated .tflite file can handle a sequence length of 512 input tokens, making it capable of answering questions over almost a full A4 document.

We won’t delve into the entire Android codebase, but rather direct our attention to the specific segment responsible for executing the ML model.

First, to use the TensorFlow Lite Interpreter inside Android we have to add the TensorFlow Lite dependency to the app’s build.gradle.kts file:

implementation("org.tensorflow:tensorflow-lite:2.14.0")

To get a general idea of the model’s inputs and outputs we can inspect it with Netron.app:

Inputs and outputs of the model.

The .tflite file can be added to the assets folder, and we can start the interpreter by memory-mapping the model:

public MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException {
    try (AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH);
         FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }
}

Based on the Netron analysis above, it’s evident that the model expects three integer inputs and yields two float outputs. These precise input-output requirements serve as our guideline for generating the necessary data within the Android codebase.

static final int MAX_SEQ_LEN = 512;
int[][] inputIds = new int[1][MAX_SEQ_LEN];
int[][] inputMask = new int[1][MAX_SEQ_LEN];
int[][] segmentIds = new int[1][MAX_SEQ_LEN];
float[][] startLogits = new float[1][MAX_SEQ_LEN];
float[][] endLogits = new float[1][MAX_SEQ_LEN];

Based on the info from this blog post, the context and the question have to be converted into tokens with the BertWordPieceTokenizer, which first splits the sentences into words on whitespace and then breaks each word into subword units from the BERT vocabulary. The same work is done inside the Android project by the FeatureConverter class:

public Feature convert(String query, String context) {
    List<String> queryTokens = tokenizer.tokenize(query);
    if (queryTokens.size() > maxQueryLen) {
        queryTokens = queryTokens.subList(0, maxQueryLen);
    }

    List<String> origTokens = Arrays.asList(context.trim().split("\\s+"));
    List<Integer> tokenToOrigIndex = new ArrayList<>();
    List<String> allDocTokens = new ArrayList<>();
    for (int i = 0; i < origTokens.size(); i++) {
        String token = origTokens.get(i);
        List<String> subTokens = tokenizer.tokenize(token);
        for (String subToken : subTokens) {
            tokenToOrigIndex.add(i);
            allDocTokens.add(subToken);
        }
    }

    // -3 accounts for [CLS], [SEP] and [SEP].
    int maxContextLen = maxSeqLen - queryTokens.size() - 3;
    if (allDocTokens.size() > maxContextLen) {
        allDocTokens = allDocTokens.subList(0, maxContextLen);
    }

    List<String> tokens = new ArrayList<>();
    List<Integer> segmentIds = new ArrayList<>();

    // Map token index to original index (in feature.origTokens).
    Map<Integer, Integer> tokenToOrigMap = new HashMap<>();

    // Start of generating the features.
    tokens.add("[CLS]");
    segmentIds.add(0);

    // For text input.
    for (int i = 0; i < allDocTokens.size(); i++) {
        String docToken = allDocTokens.get(i);
        tokens.add(docToken);
        segmentIds.add(0);
        tokenToOrigMap.put(tokens.size(), tokenToOrigIndex.get(i));
    }

    // For separation.
    tokens.add("[SEP]");
    segmentIds.add(0);

    // For query input.
    for (String queryToken : queryTokens) {
        tokens.add(queryToken);
        segmentIds.add(1);
    }

    // For ending mark.
    tokens.add("[SEP]");
    segmentIds.add(1);

    List<Integer> inputIds = tokenizer.convertTokensToIds(tokens);
    List<Integer> inputMask = new ArrayList<>(Collections.nCopies(inputIds.size(), 1));

    while (inputIds.size() < maxSeqLen) {
        inputIds.add(0);
        inputMask.add(0);
        segmentIds.add(0);
    }

    return new Feature(inputIds, inputMask, segmentIds, origTokens, tokenToOrigMap);
}
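As a rough illustration of the subword step that BertWordPieceTokenizer performs, here is a minimal, self-contained sketch of greedy longest-match-first WordPiece splitting. The class name and the tiny vocabulary are invented for this example; the real tokenizer loads the model’s full vocabulary file:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class WordPieceSketch {
    // Greedy longest-match-first subword split, the core idea behind
    // BERT's WordPiece tokenizer. Pieces that continue a word are
    // prefixed with "##", as in the BERT vocabulary.
    static List<String> wordPiece(String word, Set<String> vocab) {
        List<String> pieces = new ArrayList<>();
        int start = 0;
        while (start < word.length()) {
            int end = word.length();
            String piece = null;
            while (start < end) {
                String candidate = word.substring(start, end);
                if (start > 0) candidate = "##" + candidate; // continuation marker
                if (vocab.contains(candidate)) {
                    piece = candidate;
                    break;
                }
                end--;
            }
            if (piece == null) return Arrays.asList("[UNK]"); // no match at all
            pieces.add(piece);
            start = end;
        }
        return pieces;
    }

    public static void main(String[] args) {
        // Toy vocabulary, for illustration only.
        Set<String> vocab = new HashSet<>(Arrays.asList("play", "##ing", "##ed", "the"));
        System.out.println(wordPiece("playing", vocab)); // [play, ##ing]
        System.out.println(wordPiece("played", vocab));  // [play, ##ed]
    }
}
```

This is why a single whitespace-separated word in the context can map to several model tokens, and why the converter keeps the tokenToOrigIndex list to translate token positions back to original words.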

With the inputs ready, we can load them into the appropriate arrays and feed the TensorFlow Lite Interpreter:

for (int j = 0; j < MAX_SEQ_LEN; j++) {
    inputIds[0][j] = feature.inputIds[j];
    inputMask[0][j] = feature.inputMask[j];
    segmentIds[0][j] = feature.segmentIds[j];
}

Object[] inputs = {inputIds, inputMask, segmentIds};

Map<Integer, Object> output = new HashMap<>();
// Arrange outputs based on what Netron.app spits out.
output.put(0, endLogits);
output.put(1, startLogits);

Log.v(TAG, "Run inference...");
tflite.runForMultipleInputsOutputs(inputs, output);

Log.v(TAG, "Convert answers...");
List<QaAnswer> answers = getBestAnswers(startLogits[0], endLogits[0], feature);

We then select the best answer, which gives us the start and end tokens of the answer span; check the code here. That means the answer starts at one specific word and ends at another. The result is highlighted in the original text:

Mobile’s screen.
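The decoding done by getBestAnswers can be sketched as below. This is a simplified, self-contained stand-in (the class name and constants are hypothetical): each candidate span is scored by the sum of its start and end logits, subject to start ≤ end and a maximum answer length, and the highest-scoring span wins:

```java
public class BestSpan {
    // Pick the (start, end) pair maximizing startLogits[start] + endLogits[end],
    // with start <= end and a bounded answer length -- the usual decoding rule
    // for extractive question answering.
    static int[] bestSpan(float[] startLogits, float[] endLogits, int maxAnswerLen) {
        int bestStart = 0;
        int bestEnd = 0;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (int s = 0; s < startLogits.length; s++) {
            for (int e = s; e < Math.min(endLogits.length, s + maxAnswerLen); e++) {
                float score = startLogits[s] + endLogits[e];
                if (score > bestScore) {
                    bestScore = score;
                    bestStart = s;
                    bestEnd = e;
                }
            }
        }
        return new int[] {bestStart, bestEnd};
    }

    public static void main(String[] args) {
        // Toy logits for a 4-token sequence.
        float[] start = {0.1f, 2.5f, 0.3f, 0.2f};
        float[] end = {0.0f, 0.4f, 3.1f, 0.5f};
        int[] span = bestSpan(start, end, 4);
        System.out.println(span[0] + ".." + span[1]); // 1..2
    }
}
```

The winning token indices are then mapped back to whole words through the tokenToOrigMap built by the FeatureConverter, which is what lets the app highlight the answer inside the original text.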

You can build the Android project with the latest Android Studio using this GitHub repository. Download the model from here and place it inside the assets folder.

Conclusion

This Android application provides an example implementation of the Question Answering task in machine learning. Leveraging the TensorFlow Lite library, we can execute the BERT model from the KerasNLP library inside the project. The resulting implementation demonstrates the model’s capability to receive both a context and a question, subsequently producing tokens that accurately mark the precise span of text containing the answer within the document.

Written by George Soloupis

I am a pharmacist turned Android developer and machine learning engineer. Right now I’m a senior Android developer at Invisalign, and an ML & Android GDE.