Функции потери в GAN
Я пытаюсь создать простой MNIST GAN, и мне нужно меньше говорить, что это не сработало. Я много искал и исправил большую часть своего кода. Хотя я не могу понять, как работают функции потерь.
Вот что я сделал:
loss_d = -tf.reduce_mean(tf.log(discriminator(real_data))) # maximise
loss_g = -tf.reduce_mean(tf.log(discriminator(generator(noise_input), trainable = False))) # maxmize cuz d(g) instead of 1 - d(g)
loss = loss_d + loss_g
train_d = tf.train.AdamOptimizer(learning_rate).minimize(loss_d)
train_g = tf.train.AdamOptimizer(learning_rate).minimize(loss_g)
Я получаю -0.0 в качестве значения моей потери. Можете ли вы объяснить, как работать с функциями потерь в GAN?
2 ответа
Кажется, вы пытаетесь объединить потери генератора и дискриминатора, что совершенно неправильно! Поскольку дискриминатор тренируется как с реальными, так и сгенерированными данными, вам необходимо создать две различные потери: одну для реальных данных и другую для данных о шумах (генерируемых), которые вы передаете в сеть дискриминатора.
Попробуйте изменить свой код следующим образом:
1)
loss_d_real = -tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminator(real_data),labels= tf.ones_like(discriminator(real_data))))
2)
loss_d_fake=-tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminator(noise_input),labels= tf.zeros_like(discriminator(real_data))))
тогда потеря дискриминатора будет равна = loss_d_real+loss_d_fake. Теперь создайте потери для вашего генератора:
3)
loss_g= tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discriminator(genereted_samples), labels=tf.ones_like(genereted_samples)))
Кажется, Марьям определила причину ваших значений побочных потерь (то есть суммирование потерь генератора и дискриминатора). Просто хотел добавить, что вам следует выбрать оптимизатор Stochastic Gradient Descent для дискриминатора вместо Адама - это обеспечивает более сильные теоретические гарантии конвергенции сети при игре в минимаксной игре (источник: https://github.com/soumith/ganhacks).