Модель Tensorflow дает разные результаты на Android и Python
Я пытаюсь запустить переобученную модель на Android, но результаты отличаются на Python и Android. Модель точно такая же, и я с подозрением отношусь к нормализации и постобработке на Android. Я также использую точно такие же файлы изображений и масок.
Код Python:
img = np.float32(cv2.imread("fig6.png")/255.0)#/255.0
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
mask = np.float32(cv2.imread("mask4.png")/255.0)
sess = K.get_session()
result = sess.run("outputs",
{"inputs_img":np.expand_dims(img,0), "inputs_mask":np.expand_dims(mask,0)})
plt.imshow(result[0])
plt.show()
Java-код:
private static final String MODEL_FILE = "file:///android_asset/optimized_model2.pb";
private TensorFlowInferenceInterface inferenceInterface;
public static final String INPUT_NAME_MASK = "inputs_mask";
public static final String INPUT_NAME_IMAGE = "inputs_img";
public static final String OUTPUT_NAME = "outputs";
private static final long[] INPUT_SIZE = {1,512,512,3};
private static final String[] OUTPUT_NODES = {"outputs"};
inferenceInterface = new TensorFlowInferenceInterface(getAssets(),MODEL_FILE);
float[] image_values = normalizeBitmap(resizeBitmap(bm));
float[] mask_values = normalizeBitmap(resizeBitmap(bmMask));
float[] mOutputs = new float[image_values.length];
inferenceInterface.feed(INPUT_NAME_IMAGE,image_values,INPUT_SIZE);
inferenceInterface.feed(INPUT_NAME_MASK,mask_values,INPUT_SIZE);
inferenceInterface.run(OUTPUT_NODES,true);
inferenceInterface.fetch(OUTPUT_NAME,mOutputs);
bm = getBackToBitmap(mOutputs );
public float[] normalizeBitmap(Bitmap bitmap){
int[] int_values = new int[bitmap.getHeight()*bitmap.getWidth()];
float[] floatValues = new float[bitmap.getHeight()*bitmap.getWidth()*3];
bitmap.getPixels(int_values,0,bitmap.getWidth(),0,0,bitmap.getWidth(),bitmap.getHeight());
for (int i = 0; i < int_values.length; ++i) {
final int val = int_values[i];
floatValues[i * 3] = ((val >> 16) & 0xFF) / 255.0f;
floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255.0f;
floatValues[i * 3 + 2] = (val & 0xFF) / 255.0f;
}
return floatValues;
}
public Bitmap getBackToBitmap(float[] input){
int[] intValues = new int[input.length/3];
Bitmap bitmap = Bitmap.createBitmap(512,512, Bitmap.Config.RGB_565);
for (int i=0;i<intValues.length;i++){
intValues[i] =
0xFF000000
| (((int) (input[i * 3] * 255)) << 16)
| (((int) (input[i * 3 + 1] * 255)) << 8)
| ((int) (input[i * 3 + 2] * 255));
}
bitmap.setPixels(intValues,0,bm.getWidth(),0,0,bm.getWidth(),bm.getHeight());
return bitmap;
}