Как хранить контрольные точки лучших моделей, а не только новейшие 5, в Tensorflow Object Detection API?
Я обучаю MobileNet на наборе данных WIDER FACE и столкнулся с проблемой, которую не смог решить. API обнаружения объектов TF хранит только последние 5 контрольных точек в train
dir, но я хотел бы сохранить лучшие модели относительно метрики mAP (или, по крайней мере, оставить еще много моделей в train
dir перед удалением). Например, сегодня я посмотрел на Tensorboard после следующей ночи тренировок и вижу, что ночная модель переоснащена, и я не могу восстановить лучшую контрольную точку, потому что она уже была удалена.
РЕДАКТИРОВАТЬ: Я просто использую Tensorflow Object Detection API, он по умолчанию сохраняет последние 5 контрольных точек в каталоге train, который я и указал. Я ищу какой-либо параметр конфигурации или что-нибудь, что изменит это поведение.
У кого-нибудь есть какое-то исправление в параметре code / config, чтобы установить / обходить это? Кажется, что я что-то упускаю, должно быть очевидно, что на самом деле важна лучшая модель, а не самая новая (которая может быть лучше).
Спасибо!
5 ответов
Вы можете изменить (жестко запрограммировать в своем форке или открыть запрос на извлечение и добавить опции в protos) аргументы, передаваемые в tf.train.Saver в:
https://github.com/tensorflow/models/blob/master/research/object_detection/trainer.py
Возможно, вы захотите установить:
- max_to_keep: максимальное количество последних сохраняемых контрольных точек. По умолчанию 5.
- keep_checkpoint_every_n_hours: как часто сохранять контрольные точки. По умолчанию 10000 часов.
Вы можете изменить конфиг.
в run_config.py
class RunConfig(object):
"""This class specifies the configurations for an `Estimator` run."""
def __init__(self,
model_dir=None,
tf_random_seed=None,
save_summary_steps=100,
save_checkpoints_steps=_USE_DEFAULT,
save_checkpoints_secs=_USE_DEFAULT,
session_config=None,
keep_checkpoint_max=10,
keep_checkpoint_every_n_hours=10000,
log_step_count_steps=100,
train_distribute=None,
device_fn=None,
protocol=None,
eval_distribute=None,
experimental_distribute=None):
Вас может заинтересовать этот поток Tf github, в котором рассматривается новейшая / лучшая проблема контрольных точек. Пользователь разработал свою собственную оболочку, chekmate, вокруг tf.Saver
отслеживать лучшие контрольно-пропускные пункты.
Вы можете следить за этим PR. Здесь ваша лучшая контрольная точка сохраняется в вашем каталоге контрольных точек, подкаталог назван как лучший.
Вам просто нужно интегрировать best_saver() и (вызов метода в _run_checkpoint_once()) внутри../object_detection/eval_util.py
Кроме того, он также создаст json для all_evaluation_metrices.
Для сохранения большего количества контрольных точек вы можете написать простой скрипт на Python, который будет своевременно сохранять контрольные точки для конкретного.
import os
import shutil
import time
while True:
training_file = '/home/vignesh/training' # path of your train directory
archive_file = 'home/vignesh/training/archive' #path of the directory where you want to save your checkpoints
files_to_save = []
for files in os.listdir(training_file):
if files.rsplit('.')[0]=='model':
files_to_save.append(files)
for files in files_to_save:
if files in os.listdir(archive_file):
pass
else:
shutil.copy2(training_file+'/'+files,archive_file)
time.sleep(600) # This will make the script run for every 600 seconds, modify it for your need