Стабильные базовые показатели с сохранением модели PPO и ее повторным обучением
Здравствуйте, я использую пакет Stable baselines ( https://stable-baselines.readthedocs.io/), в частности, я использую PPO2, и я не уверен, как правильно сохранить свою модель... Я тренировал его в течение 6 виртуальных дней и получил свой средний доход около 300, затем я решил, что этого недостаточно для меня, поэтому я тренировал модель еще 6 дней. Но когда я посмотрел на статистику тренировок, вторая отдача от тренировки за серию началась примерно с 30. Это говорит о том, что не были сохранены все параметры.
вот как я сохраняю использование пакета:
def make_env_init(env_id, rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param seed: (int) the inital seed for RNG
:param rank: (int) index of the subprocess
"""
def env_init():
# Important: use a different seed for each environment
env = gym.make(env_id, connection=blt.DIRECT)
env.seed(seed + rank)
return env
set_global_seeds(seed)
return env_init
envs = VecNormalize(SubprocVecEnv([make_env_init(f'envs:{env_name}', i) for i in range(processes)]), norm_reward=False)
if os.path.exists(folder / 'model_dump.zip'):
model = PPO2.load(folder / 'model_dump.zip', envs, **ppo_kwards)
else:
model = PPO2(MlpPolicy, envs, **ppo_kwards)
model.learn(total_timesteps=total_timesteps, callback=callback)
model.save(folder / 'model_dump.zip')
1 ответ
Вы сохранили модель правильно. Тренировка не является монотонным процессом: она может показывать гораздо худшие результаты после дальнейшего обучения.
Что вы можете сделать, это в первую очередь писать журналы прогресса:
model = PPO2(MlpPolicy, envs, tensorboard_log="./logs/progress_tensorboard/")
Чтобы посмотреть журнал, запустите в терминале:
tensorboard --port 6004 --logdir ./logs/progress_tensorboard/
он предоставит вам ссылку на доску, которую вы затем сможете открыть в браузере (например, http://pc0259:6004/)
Во-вторых, вы можете делать снимки модели каждые X шагов:
from stable_baselines.common.callbacks import CheckpointCallback
checkpoint_callback = CheckpointCallback(save_freq=1e4, save_path='./model_checkpoints/')
model.learn(total_timesteps=total_timesteps, callback=[callback, checkpoint_callback])
Совместив его с бревном, вы сможете выбрать наиболее эффективную модель!