Использование tf.cond() в функции модели оценки для обучения WGAN на TPU вызывает удвоение global_step
Я пытаюсь обучить GAN на TPU, поэтому я возился с классом TPUEstimator и сопутствующей функцией модели, пытаясь реализовать цикл обучения WGAN. Я пытаюсь использовать tf.cond
объединить две операции обучения для TPUEstimatorSpec следующим образом:
opt = tf.cond(
tf.equal(tf.mod(tf.train.get_or_create_global_step(),
CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1),
lambda: gen_opt,
lambda: critic_opt
)
gen_opt
а также critic_opt
являются функцией минимизации оптимизатора, который я использую, также для обновления глобального шага. CRITIC_UPDATES_PER_GEN_UPDATE
является константой Python именно для этого и является частью обучения WGAN. Я пытался найти модель GAN, используя tf.cond
, но все модели используют tf.group
, который я не могу использовать, потому что вам нужно оптимизировать критику гораздо больше, чем генератор. Однако каждый раз, когда я запускаю 100 партий, глобальный шаг увеличивается на 200 в соответствии с номером контрольной точки. Моя модель все еще тренируется правильно, или tf.cond
просто не предполагается использовать этот способ для обучения ГАН?
1 ответ
tf.cond
не предполагается использовать таким образом для обучения GAN.
Вы получаете 200, потому что на каждом этапе обучения побочные эффекты (например, операции присваивания) обоих true_fn
а также false_fn
оцениваются. Одним из побочных эффектов является глобальный шаг tf.assign_add
операция, которую определяют оба оптимизатора.
Следовательно, то, что происходит, похоже на
- Исполнение
global_step++ (gen_opt)
а такжеglobal_step++ (critic_op)
- Оценка состояния
- Исполнение
true_fn
тело илиfalse_fn
кузов (в зависимости от состояния).
Если вы хотите тренировать ГАН, используя tf.cond
необходимо удалить все побочные операции (например, присвоение, отсюда и определение шага оптимизации) снаружи true_fn
/false_fn
и объявить все внутри них.
В качестве ссылки, вы можете увидеть этот ответ о поведении tf.cond
: /questions/22870087/smuschaet-povedenie-tfcond/22870098#22870098