ValueError: `validation_data` должен быть кортежем`(val_x, val_y, val_sample_weight)`или`(val_x, val_y)`. Найдено: <__ main __. Объект генератора в>

Я знаю, что меня спросили уже несколько раз, но ни один ответ не отвечал моим требованиям.

У меня есть CSV-файл с текстом (содержание газеты) и меткой, столбцы 0 и 1.

Я пытаюсь написать свой первый генератор для классификации текста, но получаю ошибку

ValueError: `validation_data` should be a tuple `(val_x, val_y, val_sample_weight)` or `(val_x, val_y)`. Found: <__main__.Generator object at 0xd376a6e80>

Вот класс

class Generator(object):

    def __init__(self, data_file):
        self.data_file = data_file
        self.length = -1

    def __iter__(self):
        while True:
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    yield row[0], row[1]

    def __len__(self):
        if self.length ==  -1:
            n_rows = 0
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    n_rows += 1
            self.length = n_rows
        return self.length

Я тоже пробовал с yield row[0], row[1], также return, Ни один не работал.

Спасибо за помощь

0 ответов

У меня была такая же ошибка, пока я не заставил свой класс Generator унаследовать методы от keras.utils.Sequence (см. Документацию fit_generator). Вы можете попробовать это:

import keras

class Generator(keras.utils.Sequence):
    def __init__(self, data_file):
        self.data_file = data_file
        self.length = -1

    def __iter__(self):
        while True:
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    yield row[0], row[1]

    def __len__(self):
        if self.length ==  -1:
            n_rows = 0
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    n_rows += 1
            self.length = n_rows
        return self.length
Другие вопросы по тегам