Классифицировать новый образ на обученной пользовательской модели в deeplearning4j (сверточная сеть)
Я новичок в Deeplearning4J. Я уже экспериментировал с его функциональностью word2vec, и все было хорошо. Но сейчас я немного запутался в отношении классификации изображений. Я играл с этим примером:
Я изменил флаг "сохранить" на true, и моя модель сохраняется в файле model.bin. Теперь начинается проблемная часть (извините, если это звучит глупо, может быть, я упускаю что-то действительно очевидное здесь)
Я создал отдельный класс под названием AnimalClassifier, и его целью является загрузка модели из файла model.bin, восстановление нейронной сети из него, а затем классификация отдельного изображения с использованием восстановленной сети. Для этого единственного изображения я создал папку "temp" -> dl4j-examples/src/main/resources/animals/temp/, где я поместил изображение белого медведя, которое ранее использовалось в процессе обучения, в AnimalsClassification.java (я хотел быть уверен, что это изображение будет классифицировано правильно - поэтому я повторно использовал изображение из папки "медведь").
Вот мой код, пытающийся классифицировать белого медведя:
protected static int height = 100;
protected static int width = 100;
protected static int channels = 3;
protected static int numExamples = 1;
protected static int numLabels = 1;
protected static int batchSize = 10;
protected static long seed = 42;
protected static Random rng = new Random(seed);
protected static int listenerFreq = 1;
protected static int iterations = 1;
protected static int epochs = 7;
protected static double splitTrainTest = 0.8;
protected static int nCores = 2;
protected static boolean save = true;
protected static String modelType = "AlexNet"; //
public static void main(String[] args) throws Exception {
String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/");
MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(basePath + "model.bin", true);
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/temp/");
FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize);
InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1);
InputSplit analysedData = inputSplit[0];
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels);
recordReader.initialize(analysedData);
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 0, 4);
while (dataIter.hasNext()) {
DataSet testDataSet = dataIter.next();
String expectedResult = testDataSet.getLabelName(0);
List<String> predict = multiLayerNetwork.predict(testDataSet);
String modelResult = predict.get(0);
System.out.println("\nFor example that is labeled " + expectedResult + " the model predicted " + modelResult + "\n\n");
}
}
После запуска я получаю сообщение об ошибке:
java.lang.UnsupportedOperationException по адресу org.datavec.api.writable.ArrayWritable.toInt(ArrayWritable.java:47) по адресу org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.getDataSet(RecordReader.jpg).jpg datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:186) по адресу org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:deIt_Date_Reader.Reader.Reader.Reader.Reader.Reader.Rat_Rat_Rate_Rate_Rate_Rat_R_D_W_D_W_D_W_D_W_W_W_W_set_set_set_set_set_set_set_ф_ф_ф_сервис at org.deeplearning4j.examples.convolution.AnimalClassifier.main(AnimalClassifier.java:66) Отключен от целевой виртуальной машины, адрес: "127.0.0.1:63967", транспорт: исключение "socket" в потоке "main" java.lang.IllegalStateException: имена меток не определены в этом наборе данных. Добавьте имена меток, чтобы использовать getLabelName с идентификатором. в org.nd4j.linalg.dataset.DataSet.getLabelName(DataSet.java:1106) в org.deeplearning4j.examples.convolution.AnimalClassifier.main(AnimalClassifier.java:68)
Я вижу, что есть метод public void setLabels(метки INDArray) в MultiLayerNetwork.java, но я не понимаю, как его использовать (особенно когда он принимает в качестве аргумента INDArray).
Я также смущен, почему я должен указать количество возможных меток в конструкторе RecordReaderDataSetIterator. Я ожидаю, что модель уже знает, какие метки использовать (не следует ли использовать метки, которые использовались во время обучения автоматически?). Я думаю, может быть, я загружаю картинку совершенно неправильно...
Подводя итог, я хотел бы добиться просто следующего:
- восстановить сеть из модели (это работает)
- загрузить изображение для классификации (также работает)
- классифицируйте это изображение, используя те же ярлыки, которые использовались во время тренировки (медведь, олень, утка, черепаха) (сложная часть)
Заранее благодарю за помощь или любые советы!
1 ответ
Итак, подытожив несколько вопросов здесь: запись для изображений - это 2 записи в коллекции. Второй 1 - это ярлык. Индекс метки зависит от типа записи, которую вы передаете.
Вторая часть вашего вопроса: несколько записей могут быть отдельно от набора данных. Список ссылается на метку для элемента в определенной строке в мини-пакете.