Tensorflow-lite - получение растрового изображения на выходе квантованной модели
Мы работаем над приложением семантической сегментации в Android с использованием tenorflow-lite. Используемая модель deeplabv3 '.tflite' имеет ввод типа (ImageTensor) uint8[1,300,300,3] и вывод типа (SemanticPredictions) uint8[300,300]. Мы были успешно возможность запустить модель и получить выход в формате ByteBuffer с помощью метода tflite.run. Но нам не удалось извлечь изображение из этого вывода в java. Модель, которая обучается с помощью набора данных pascal voc и фактически была преобразована в Формат tflite из модели TF: ' mobilenetv2_dm05_coco_voc_trainval'.
Кажется, что проблема похожа на следующий вопрос stackru: tenorflow-lite - использование интерпретатора tflite для получения изображения на выходе
Похоже, проблема, связанная с преобразованием типов данных с плавающей точкой, решена в проблеме с github: https://github.com/tensorflow/tensorflow/issues/23483
Итак, как мы можем правильно извлечь маску сегментации из выходных данных модели UINT8?
2 ответа
Попробуйте этот код:
/**
* Converts ByteBuffer with segmentation mask to the Bitmap
*
* @param byteBuffer Output ByteBuffer from Interpreter.run
* @param imgSizeX Model output image width
* @param imgSizeY Model output image height
* @return Mono color Bitmap mask
*/
private Bitmap convertByteBufferToBitmap(ByteBuffer byteBuffer, int imgSizeX, int imgSizeY){
byteBuffer.rewind();
byteBuffer.order(ByteOrder.nativeOrder());
Bitmap bitmap = Bitmap.createBitmap(imgSizeX , imgSizeY, Bitmap.Config.ARGB_4444);
int[] pixels = new int[imgSizeX * imgSizeY];
for (int i = 0; i < imgSizeX * imgSizeY; i++)
if (byteBuffer.getFloat()>0.5)
pixels[i]= Color.argb(100, 255, 105, 180);
else
pixels[i]=Color.argb(0, 0, 0, 0);
bitmap.setPixels(pixels, 0, imgSizeX, 0, 0, imgSizeX, imgSizeY);
return bitmap;
}
Это работает для модели с моно цветным выходом.
Что-то вроде:
Byte[][] output = new Byte[300][300];
Bitmap bitmap = Bitmap.createBitmap(300,300,Bitmap.Config.ARGB_8888);
for (int row = 0; row < output.length ; row++) {
for (int col = 0; col < output[0].length ; col++) {
int pixelIntensity = output[col][row];
bitmap.setPixel(col,row,Color.rgb(pixelIntensity,pixelIntensity,pixelIntensity));
}
?