Использование tf.cond() в функции модели оценки для обучения WGAN на TPU вызывает удвоение global_step

Я пытаюсь обучить GAN на TPU, поэтому я возился с классом TPUEstimator и сопутствующей функцией модели, пытаясь реализовать цикл обучения WGAN. Я пытаюсь использовать tf.cond объединить две операции обучения для TPUEstimatorSpec следующим образом:

opt = tf.cond(
    tf.equal(tf.mod(tf.train.get_or_create_global_step(), 
    CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1), 
    lambda: gen_opt, 
    lambda: critic_opt
)

gen_opt а также critic_opt являются функцией минимизации оптимизатора, который я использую, также для обновления глобального шага. CRITIC_UPDATES_PER_GEN_UPDATE является константой Python именно для этого и является частью обучения WGAN. Я пытался найти модель GAN, используя tf.cond, но все модели используют tf.group, который я не могу использовать, потому что вам нужно оптимизировать критику гораздо больше, чем генератор. Однако каждый раз, когда я запускаю 100 партий, глобальный шаг увеличивается на 200 в соответствии с номером контрольной точки. Моя модель все еще тренируется правильно, или tf.cond просто не предполагается использовать этот способ для обучения ГАН?

1 ответ

Решение

tf.cond не предполагается использовать таким образом для обучения GAN.

Вы получаете 200, потому что на каждом этапе обучения побочные эффекты (например, операции присваивания) обоих true_fn а также false_fn оцениваются. Одним из побочных эффектов является глобальный шаг tf.assign_add операция, которую определяют оба оптимизатора.

Следовательно, то, что происходит, похоже на

  • Исполнение global_step++ (gen_opt) а также global_step++ (critic_op)
  • Оценка состояния
  • Исполнение true_fn тело или false_fn кузов (в зависимости от состояния).

Если вы хотите тренировать ГАН, используя tf.condнеобходимо удалить все побочные операции (например, присвоение, отсюда и определение шага оптимизации) снаружи true_fn/false_fn и объявить все внутри них.

В качестве ссылки, вы можете увидеть этот ответ о поведении tf.cond: /questions/22870087/smuschaet-povedenie-tfcond/22870098#22870098

Другие вопросы по тегам