Классифицировать новый образ на обученной пользовательской модели в deeplearning4j (сверточная сеть)

Я новичок в Deeplearning4J. Я уже экспериментировал с его функциональностью word2vec, и все было хорошо. Но сейчас я немного запутался в отношении классификации изображений. Я играл с этим примером:

https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/AnimalsClassification.java

Я изменил флаг "сохранить" на true, и моя модель сохраняется в файле model.bin. Теперь начинается проблемная часть (извините, если это звучит глупо, может быть, я упускаю что-то действительно очевидное здесь)

Я создал отдельный класс под названием AnimalClassifier, и его целью является загрузка модели из файла model.bin, восстановление нейронной сети из него, а затем классификация отдельного изображения с использованием восстановленной сети. Для этого единственного изображения я создал папку "temp" -> dl4j-examples/src/main/resources/animals/temp/, где я поместил изображение белого медведя, которое ранее использовалось в процессе обучения, в AnimalsClassification.java (я хотел быть уверен, что это изображение будет классифицировано правильно - поэтому я повторно использовал изображение из папки "медведь").

Вот мой код, пытающийся классифицировать белого медведя:

protected static int height = 100;
    protected static int width = 100;
    protected static int channels = 3;
    protected static int numExamples = 1;
    protected static int numLabels = 1;
    protected static int batchSize = 10;

    protected static long seed = 42;
    protected static Random rng = new Random(seed);
    protected static int listenerFreq = 1;
    protected static int iterations = 1;
    protected static int epochs = 7;
    protected static double splitTrainTest = 0.8;
    protected static int nCores = 2;
    protected static boolean save = true;

    protected static String modelType = "AlexNet"; //

    public static void main(String[] args) throws Exception {

        String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/");
        MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(basePath + "model.bin", true);

        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/temp/");
        FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
        BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize);


        InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1);
        InputSplit analysedData = inputSplit[0];


        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels);
        recordReader.initialize(analysedData);
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 0, 4);
        while (dataIter.hasNext()) {
            DataSet testDataSet = dataIter.next();

            String expectedResult = testDataSet.getLabelName(0);
            List<String> predict = multiLayerNetwork.predict(testDataSet);
            String modelResult = predict.get(0);
            System.out.println("\nFor example that is labeled " + expectedResult + " the model predicted " + modelResult + "\n\n");
        }
    }

После запуска я получаю сообщение об ошибке:

java.lang.UnsupportedOperationException по адресу org.datavec.api.writable.ArrayWritable.toInt(ArrayWritable.java:47) по адресу org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.getDataSet(RecordReader.jpg).jpg datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:186) по адресу org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:deIt_Date_Reader.Reader.Reader.Reader.Reader.Reader.Rat_Rat_Rate_Rate_Rate_Rat_R_D_W_D_W_D_W_D_W_W_W_W_set_set_set_set_set_set_set_ф_ф_ф_сервис at org.deeplearning4j.examples.convolution.AnimalClassifier.main(AnimalClassifier.java:66) Отключен от целевой виртуальной машины, адрес: "127.0.0.1:63967", транспорт: исключение "socket" в потоке "main" java.lang.IllegalStateException: имена меток не определены в этом наборе данных. Добавьте имена меток, чтобы использовать getLabelName с идентификатором. в org.nd4j.linalg.dataset.DataSet.getLabelName(DataSet.java:1106) в org.deeplearning4j.examples.convolution.AnimalClassifier.main(AnimalClassifier.java:68)

Я вижу, что есть метод public void setLabels(метки INDArray) в MultiLayerNetwork.java, но я не понимаю, как его использовать (особенно когда он принимает в качестве аргумента INDArray).

Я также смущен, почему я должен указать количество возможных меток в конструкторе RecordReaderDataSetIterator. Я ожидаю, что модель уже знает, какие метки использовать (не следует ли использовать метки, которые использовались во время обучения автоматически?). Я думаю, может быть, я загружаю картинку совершенно неправильно...

Подводя итог, я хотел бы добиться просто следующего:

  1. восстановить сеть из модели (это работает)
  2. загрузить изображение для классификации (также работает)
  3. классифицируйте это изображение, используя те же ярлыки, которые использовались во время тренировки (медведь, олень, утка, черепаха) (сложная часть)

Заранее благодарю за помощь или любые советы!

1 ответ

Итак, подытожив несколько вопросов здесь: запись для изображений - это 2 записи в коллекции. Второй 1 - это ярлык. Индекс метки зависит от типа записи, которую вы передаете.

Вторая часть вашего вопроса: несколько записей могут быть отдельно от набора данных. Список ссылается на метку для элемента в определенной строке в мини-пакете.

Другие вопросы по тегам