Проблема с ранней остановкой в tf.keras
Я тренирую свою первую перенесенную модель обучения (ура!), И мне не удается заставить модель прекратить обучение, если потеря проверки не изменилась более чем на 0,1 за более чем 3 эпохи.
Вот соответствующий блок кода
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, min_delta = 0.1)
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
callbacks=[early_stopping])
EPOCHS = 100
history = model.fit(training_batches,
epochs=EPOCHS,
validation_data=validation_batches)
А вот несколько логов:
Epoch 32/100
155/155 [==============================] - 21s 134ms/step - loss: 0.0042 - accuracy: 0.9998 - val_loss: 0.3407 - val_accuracy: 0.9012
Epoch 33/100
155/155 [==============================] - 21s 133ms/step - loss: 0.0040 - accuracy: 0.9998 - val_loss: 0.3443 - val_accuracy: 0.9000
Epoch 34/100
155/155 [==============================] - 21s 134ms/step - loss: 0.0037 - accuracy: 0.9998 - val_loss: 0.3393 - val_accuracy: 0.9019
Epoch 35/100
155/155 [==============================] - 21s 135ms/step - loss: 0.0031 - accuracy: 1.0000 - val_loss: 0.3396 - val_accuracy: 0.9000
Epoch 36/100
155/155 [==============================] - 21s 134ms/step - loss: 0.0028 - accuracy: 1.0000 - val_loss: 0.3390 - val_accuracy: 0.9000
Epoch 37/100
155/155 [==============================] - 21s 133ms/step - loss: 0.0026 - accuracy: 1.0000 - val_loss: 0.3386 - val_accuracy: 0.9025
Epoch 38/100
155/155 [==============================] - 21s 133ms/step - loss: 0.0024 - accuracy: 1.0000 - val_loss: 0.3386 - val_accuracy: 0.8994
Epoch 39/100
155/155 [==============================] - 21s 133ms/step - loss: 0.0022 - accuracy: 1.0000 - val_loss: 0.3386 - val_accuracy: 0.9019
Вопросы:
- Почему обучение не остановилось на Эпохе 37, когда у меня установлен обратный вызов EarlyStopping для отслеживания val_loss?
- Могу ли я выполнять более сложные обратные вызовы EarlyStopping? Что-то вроде "Если val_accuracy > 0.90 && val_loss не изменилось более чем на 0,1 за 3 эпохи". Если можно, могу ли я получить ссылку на учебник?
2 ответа
РЕДАКТИРОВАТЬ
Это не работает, потому что вы поместили callback
параметр в неправильном вызове метода. (и на самом деле я получил ошибку недопустимого аргумента при подборе модели сcallbacks
перешел к compile
. Таким образом, я не уверен, почему ваша модель была скомпилирована без проблем.)
Это должно быть внутри вашего fit
метод, как показано ниже. Обратите внимание, что рекомендуется установитьverbose = 1
в вашей конфигурации ранней остановки, чтобы он распечатал журнал ранней остановки.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, min_delta = 0.1, verbose = 1)
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
EPOCHS = 100
history = model.fit(training_batches,
epochs=EPOCHS,
callbacks=[early_stopping], # call back should be here!
validation_data=validation_batches)
По вашему второму вопросу возможен настраиваемый обратный вызов, вы можете обратиться к примеру из документации здесь. По сути, вы должны определить свою логику ранней остановки вon_epoch_end
.
Кстати, я считаю, что вам не следует заблаговременно останавливаться на нескольких показателях, выберите тот, который имеет значение (т.е. показатель, который вы оптимизируете - val_accuracy
) и просто следите за этим. Есть даже источники, которые не рекомендуют преждевременно останавливаться, а вместо этого рассматривают эпоху как настраиваемый гиперпараметр. См. Эту дискуссию на Reddit, которую я нашел полезной.
Уменьшить patience=3
меньше, например 1
или 2
и посмотрим, что произойдет.
Это говорит Керасу, как сильно ты хочешь попробовать. терпение = небольшое количество скажет Керасу прервать обучение раньше. С другой стороны, если вы используете большое число, он скажет Керасу подождать, пока не будет достигнута значительная точность.
терпение: количество эпох, в течение которых наблюдаемое количество получилось без улучшения, после которого обучение будет остановлено. Величины валидации не могут быть произведены для каждой эпохи, если частота валидации (model.fit(validation_freq=5)) больше единицы