Модель Tensorflow дает разные результаты на Android и Python

Я пытаюсь запустить переобученную модель на Android, но результаты отличаются на Python и Android. Модель точно такая же, и я с подозрением отношусь к нормализации и постобработке на Android. Я также использую точно такие же файлы изображений и масок.

Код Python:

img = np.float32(cv2.imread("fig6.png")/255.0)#/255.0
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
mask = np.float32(cv2.imread("mask4.png")/255.0)


sess = K.get_session()
result = sess.run("outputs", 
                {"inputs_img":np.expand_dims(img,0), "inputs_mask":np.expand_dims(mask,0)})
plt.imshow(result[0])
plt.show()

Java-код:


 private static final String MODEL_FILE = "file:///android_asset/optimized_model2.pb";
    private TensorFlowInferenceInterface inferenceInterface;

    public static final String INPUT_NAME_MASK = "inputs_mask";
    public static final String INPUT_NAME_IMAGE = "inputs_img";
    public static final String OUTPUT_NAME = "outputs";

    private static final long[] INPUT_SIZE = {1,512,512,3};
    private static final String[] OUTPUT_NODES = {"outputs"};

    inferenceInterface = new TensorFlowInferenceInterface(getAssets(),MODEL_FILE);

    float[] image_values = normalizeBitmap(resizeBitmap(bm));
    float[] mask_values = normalizeBitmap(resizeBitmap(bmMask));

    float[] mOutputs = new float[image_values.length];

    inferenceInterface.feed(INPUT_NAME_IMAGE,image_values,INPUT_SIZE);
    inferenceInterface.feed(INPUT_NAME_MASK,mask_values,INPUT_SIZE);

    inferenceInterface.run(OUTPUT_NODES,true);

    inferenceInterface.fetch(OUTPUT_NAME,mOutputs);

    bm = getBackToBitmap(mOutputs );

    public float[] normalizeBitmap(Bitmap bitmap){
    int[] int_values = new int[bitmap.getHeight()*bitmap.getWidth()];
    float[] floatValues = new float[bitmap.getHeight()*bitmap.getWidth()*3];

 bitmap.getPixels(int_values,0,bitmap.getWidth(),0,0,bitmap.getWidth(),bitmap.getHeight());

        for (int i = 0; i < int_values.length; ++i) {
            final int val = int_values[i];
            floatValues[i * 3] = ((val >> 16) & 0xFF) / 255.0f;
            floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255.0f;
            floatValues[i * 3 + 2] = (val & 0xFF) / 255.0f;
        }
        return floatValues;
    }

    public Bitmap getBackToBitmap(float[] input){

        int[] intValues = new int[input.length/3];

        Bitmap bitmap = Bitmap.createBitmap(512,512, Bitmap.Config.RGB_565);
        for (int i=0;i<intValues.length;i++){
            intValues[i] =
                    0xFF000000
                            | (((int) (input[i * 3] * 255)) << 16)
                            | (((int) (input[i * 3 + 1] * 255)) << 8)
                            | ((int) (input[i * 3 + 2] * 255));
        }
        bitmap.setPixels(intValues,0,bm.getWidth(),0,0,bm.getWidth(),bm.getHeight());
        return bitmap;
    }


0 ответов

Другие вопросы по тегам