Как хранить контрольные точки лучших моделей, а не только новейшие 5, в Tensorflow Object Detection API?

Я обучаю MobileNet на наборе данных WIDER FACE и столкнулся с проблемой, которую не смог решить. API обнаружения объектов TF хранит только последние 5 контрольных точек в train dir, но я хотел бы сохранить лучшие модели относительно метрики mAP (или, по крайней мере, оставить еще много моделей в train dir перед удалением). Например, сегодня я посмотрел на Tensorboard после следующей ночи тренировок и вижу, что ночная модель переоснащена, и я не могу восстановить лучшую контрольную точку, потому что она уже была удалена.

РЕДАКТИРОВАТЬ: Я просто использую Tensorflow Object Detection API, он по умолчанию сохраняет последние 5 контрольных точек в каталоге train, который я и указал. Я ищу какой-либо параметр конфигурации или что-нибудь, что изменит это поведение.

У кого-нибудь есть какое-то исправление в параметре code / config, чтобы установить / обходить это? Кажется, что я что-то упускаю, должно быть очевидно, что на самом деле важна лучшая модель, а не самая новая (которая может быть лучше).

Спасибо!

5 ответов

Решение

Вы можете изменить (жестко запрограммировать в своем форке или открыть запрос на извлечение и добавить опции в protos) аргументы, передаваемые в tf.train.Saver в:

https://github.com/tensorflow/models/blob/master/research/object_detection/trainer.py

Возможно, вы захотите установить:

  • max_to_keep: максимальное количество последних сохраняемых контрольных точек. По умолчанию 5.
  • keep_checkpoint_every_n_hours: как часто сохранять контрольные точки. По умолчанию 10000 часов.

Вы можете изменить конфиг.

в run_config.py

class RunConfig(object):
  """This class specifies the configurations for an `Estimator` run."""

  def __init__(self,
           model_dir=None,
           tf_random_seed=None,
           save_summary_steps=100,
           save_checkpoints_steps=_USE_DEFAULT,
           save_checkpoints_secs=_USE_DEFAULT,
           session_config=None,
           keep_checkpoint_max=10,
           keep_checkpoint_every_n_hours=10000,
           log_step_count_steps=100,
           train_distribute=None,
           device_fn=None,
           protocol=None,
           eval_distribute=None,
           experimental_distribute=None):

Вас может заинтересовать этот поток Tf github, в котором рассматривается новейшая / лучшая проблема контрольных точек. Пользователь разработал свою собственную оболочку, chekmate, вокруг tf.Saver отслеживать лучшие контрольно-пропускные пункты.

Вы можете следить за этим PR. Здесь ваша лучшая контрольная точка сохраняется в вашем каталоге контрольных точек, подкаталог назван как лучший.

Вам просто нужно интегрировать best_saver() и (вызов метода в _run_checkpoint_once()) внутри../object_detection/eval_util.py

Кроме того, он также создаст json для all_evaluation_metrices.

Для сохранения большего количества контрольных точек вы можете написать простой скрипт на Python, который будет своевременно сохранять контрольные точки для конкретного.

      import os
import shutil
import time

while True:
    
    training_file = '/home/vignesh/training' # path of your train directory
    archive_file = 'home/vignesh/training/archive' #path of the directory where you want to save your checkpoints
    files_to_save = []

    for files in os.listdir(training_file):
        
        if files.rsplit('.')[0]=='model':
            
            files_to_save.append(files)

    for files in files_to_save:
        if files in os.listdir(archive_file):
            pass
        else:
            shutil.copy2(training_file+'/'+files,archive_file)
    time.sleep(600) # This will make the script run for every 600 seconds, modify it for your need
Другие вопросы по тегам