Должен ли я перекомпилировать мой Gan каждую партию, чтобы предотвратить обучение дискриминатора?

У меня есть Ган, как так

generator = Model(g_in, g_out)
generator.compile(...)

discriminator = Model(d_in, d_out)
discriminator.trainable = True
discriminator.compile(..)

discriminator.trainable = False

gan = Model(inputs=.., outputs=..)
gan.compile(..)

#iterate over epochs and batches, without compiling

Он учится и дает приемлемый результат. Однако я получаю предупреждение:

"keras \ engine \ training.py: 490: UserWarning: расхождение между обучаемыми весами и собранными обучаемыми весами, вы установили model.trainable без звонка model.compile после? "Расхождение между обучаемым весом и собранным обучаемым"

Если я перекомпилирую дискриминатор и собираю каждую партию, предупреждение исчезает, но одна итерация занимает намного больше времени, а скорость обучения снижается.

for epoch:
  for batch:

    fakes=generator.predict_on_batch(batch)

    discriminator.trainable = True
    discriminator.compile(..)

    discriminator.train_on_batch(batch, ..)
    discriminator.train_on_batch(fakes, ..)

    discriminator.trainable = False
    discriminator.compile(..)
    gan.compile(..)

    gan.train_on_batch(batch,..)

Какой из них правильный?

1 ответ

Решение

Это ожидается, и нет необходимости перекомпилировать каждую партию. У Keras есть открытая ошибка об этом: https://github.com/keras-team/keras/issues/8585

В ответах есть несколько примеров того, как обойти предупреждение, я не буду повторять их здесь. Есть также ответ, который дает отличный совет о том, как проверить, действительно ли вы тренируетесь, чему вы должны тренироваться, если вы не уверены в специфике вашей модели: https://github.com/keras-team/keras/issues/8585

Другие вопросы по тегам