Tensorflow.js: dtype фида (int32) несовместим с типом ключа input_1 (float32)
Я перенес обучение из Mobilenet в свою модель и попытался сделать прогноз:
const img = document.querySelector("img");
const image = tf.reshape(tf.fromPixels(img), [1, 224, 224, 3]);
const pretrainedModelPrediction = pretrainedModel.predict(image);
const modelPrediction = model.predict(pretrainedModelPrediction);
const prediction = modelPrediction.as1D().argMax().dataSync()[0];
console.log({ prediction });
Ошибка в этой строке кода:
const pretrainedModelPrediction = pretrainedModel.predict(image);
С этой ошибкой:
tfjs.js:67 Uncaught (in promise) Error: The dtype of the feed (int32) is incompatible with that of the key 'input_1' (float32).
at new t (tfjs.js:67)
at assertFeedCompatibility (tfjs.js:67)
at e.add (tfjs.js:67)
at new e (tfjs.js:67)
at tfjs.js:67
at tfjs.js:49
at e.scopedRun (tfjs.js:49)
at e.tidy (tfjs.js:49)
at e.tidy (tfjs.js:49)
at s (tfjs.js:67)
Есть идеи, почему возникает эта ошибка и как ее исправить?
В качестве дополнительной информации:
- я использую
@tensorflow/tfjs
версия0.12.0
- Весь код сбоя (с моделью) находится здесь: https://github.com/aralroca/skin-cancer-detection-tfjs/tree/d0d288c84919410dd422a1a19de7b207b6f49000
1 ответ
Решение
image
относится к типу int32
. Вы можете транслировать его наfloat32
.
pretrainedModel.predict(image.cast('float32'));