Поиск в сетке с помощью GridSearchCV - scikit-learn (настройка гиперпараметров) с использованием ImageDataGenerator (keras)?
Как я могу выполнить настройку гиперпараметра, когда мой ввод изображения через ImageDataGenerator
? Мои данные тренировок и тестов представлены не в виде массивов (X_train, Y_train и т. Д.). Я хочу настроить свои гиперпараметры, используя GridSearchCV
от sklearn
а также ImageDataGenerator
от keras
,
Вот несколько фрагментов из кода, который я пробовал!
#(5) Train
train_datagen = ImageDataGenerator(rescale=1./255)
validation_datagen = ImageDataGenerator(rescale=1./255)
train_batchsize = 15
val_batchsize = 10
train_generator = train_datagen.flow_from_directory(
train_dir,
batch_size=train_batchsize,
class_mode='categorical',
shuffle=False)
validation_generator = validation_datagen.flow_from_directory(
validation_dir,
target_size=(image_size, image_size1),
batch_size=val_batchsize,
class_mode='categorical')
#Function for Creating Model
def create_model():
.....................
return model
model = KerasClassifier(build_fn=create_model, batch_size=1000, epochs=10, verbose = 1)
# Use scikit-learn to grid search
activation = ['relu', 'tanh', 'sigmoid', 'hard_sigmoid', 'linear'] # softmax, softplus, softsign
momentum = [0.0, 0.2, 0.4, 0.6, 0.8, 0.9]
neurons = [1, 5, 10, 15, 20, 25, 30]
init = ['uniform', 'lecun_uniform', 'normal', 'zero', 'glorot_normal', 'glorot_uniform', 'he_normal', 'he_uniform']
optimizer = [ 'SGD', 'RMSprop', 'Adagrad', 'Adadelta', 'Adam', 'Adamax', 'Nadam']
param_grid = dict(epochs=epochs, batch_size=batch_size)
##############################################################
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)
grid_result = grid.fit_generator(train_generator, validation_generator)
Получение ошибки в этой строке:grid_result = grid.fit_generator(train_generator, validation_generator)
1 ответ
Sklearn GridSearchCV не предоставляет
fit_generator
метод. Вы, вероятно, путаете его с Keras (теперь устаревшим) fit_generator .
Это означает, что поиск модели Keras по сетке нетривиален, если вы получаете обучающие данные от генераторов. Я нашел два связанных вопроса на SO:
как использовать поиск по сетке с генератором подгонки в keras
keras/scikit-learn: использование fit_generator() с перекрестной проверкой
Так что пока приходится прибегать к обходным путям.