как предотвратить переобучение в усиленном обучении с помощью vgg16
Я пытаюсь обучить модель распознавать выражения лица, поэтому в основном проблема классификации с 7 классами:
img_size=48
batch_size=64
datagen_train=ImageDataGenerator( rotation_range=15,
width_shift_range=0.15,
height_shift_range=0.15,
shear_range=0.15,
zoom_range=0.15,
horizontal_flip=True,
preprocessing_function=preprocess_input
)
train_generator=datagen_train.flow_from_directory(
train_path,
target_size=(img_size,img_size),
# color_mode='grayscale',
batch_size=batch_size,
class_mode='categorical',
shuffle=True
)
datagen_validation=ImageDataGenerator( horizontal_flip=True, preprocessing_function=preprocess_input)
validation_generator=datagen_train.flow_from_directory(
valid_path,
target_size=(img_size,img_size),
# color_mode='grayscale',
batch_size=batch_size,
class_mode='categorical',
shuffle=True,
)
Я использую ImageDataGenerator, и я сделал свою модель с VGG16 без переноса головы, обучаясь так:
ptm = PretrainedModel(
input_shape=[48,48,3],
weights='imagenet',
include_top=False)
x = Flatten()(ptm.output)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='relu')(x)
x = BatchNormalization()(x)
x = Dense(7, activation='softmax',kernel_initializer='random_uniform', bias_initializer='random_uniform', bias_regularizer=regularizers.l2(0.01), name='predictions')(x)
opt=optimizers.RMSprop(learning_rate=0.0001)
model = Model(inputs=ptm.input, outputs=x)
model.compile(
loss='categorical_crossentropy',
optimizer=opt,
metrics=['accuracy']
)
model.summary()
Я использовал оптимизаторы и раннюю остановку и прогнал 100 эпох:
early_stopping = EarlyStopping(
monitor='val_accuracy',
min_delta=0.00005,
patience=11,
verbose=1,
restore_best_weights=True,
)
lr_scheduler = ReduceLROnPlateau(
monitor='val_accuracy',
factor=0.5,
patience=7,
min_lr=1e-7,
verbose=1,
)
callbacks = [
early_stopping,
lr_scheduler,
]
и после 61 эпохи у меня была ранняя остановка, и я получил приличную точность, но val_accuracy была очень низкой по сравнению с ней:
loss: 0.6081 - accuracy: 0.7910 - val_loss: 1.4658 - val_accuracy: 0.5608
какие-нибудь предложения о том, как я могу исправить это переоснащение? Благодарность!
1 ответ
В вашем генераторе проверки удалите
horizontal_flip=True
и установить
shuffle=False
. Также у вас есть код
validation_generator=datagen_train.flow_from_directory( etc
Вы хотите изменить его на
validation_generator=datagen_validation.flow_from_directory(etc