Как асинхронно обновить генератор и дискриминатор GAN в Tensorflow?

Я хочу разработать GAN с Tensorflow, где Генератор является авто-кодером, а Дискриминатор - сверточной нейронной сетью с двоичным выходом. Нет проблем в разработке автоэнкодера и CNN, но моя идея состоит в том, чтобы обучить 1 эпоху для каждого из компонентов (дискриминатор и генератор) и повторить этот цикл для 1000 эпох, сохраняя результаты (веса) предыдущей эпохи обучения для следующего. Как я могу реализовать это?

2 ответа

Если у вас есть две операции называется train_step_generator а также train_step_discriminator (каждый из которых, например, имеет вид tf.train.AdamOptimizer().minimize(loss) с соответствующей потерей для каждого), тогда ваш тренировочный цикл должен быть чем-то похожим на следующую структуру:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(1000):
        if epoch%2 == 0: # train discriminator on even epochs
            for i in range(training_set_size/batch_size):
                z_ = np.random.normal(0,1,batch_size) # this is the input to the generator
                batch = get_next_batch(batch_size)
                sess.run(train_step_discriminator,feed_dict={z:z_, x:batch})
        else: # train generator on odd epochs
            for i in range(training_set_size/batch_size):
                z_ = np.random.normal(0,1,batch_size)  # this is the input to the generator
                sess.run(train_step_generator,feed_dict={z:z_})

Веса будут сохраняться между итерациями.

Я решил проблему. На самом деле, я хочу, чтобы выход автоэнкодера был входом CNN, соединяющего GAN и обновляющего веса в пропорции 1:1. Я заметил, что мне нужно было с особой тщательностью различать потери генератора и дискриминатора, иначе в начале второго цикла потеря тензора генератора будет заменена на число с плавающей запятой, последняя потеря, генерируемая дискриминатором.

Вот код:

with tf.Session() as sess:
sess.run(init)
for i in range(1, num_steps+1):

здесь генератор обучения

    batch_x, batch_y=next_batch(batch_size, x_train_noisy, x_train)        
    _, l = sess.run([optimizer, loss], feed_dict={X: batch_x.reshape(n,784),
                    Y:batch_y})
    if i % display_step == 0 or i == 1:
        print('Epoch %i: Denoising Loss: %f' % (i, l))

здесь выход генератора будет использоваться как вход для дискриминатора

    output=sess.run([decoder_op],feed_dict={X: x_train})
    x_train2=np.array(output).reshape(n,784).astype(np.float64)

здесь тренинг Дискриминатор

    batch_x2, batch_y2 = next_batch(batch_size, x_train2, y_train)
    sess.run(train_op, feed_dict={X2: batch_x2.reshape(n,784), Y2: batch_y2, keep_prob: 0.8})
    if i % display_step == 0 or i == 1:
        loss3, acc = sess.run([loss_op2, accuracy], feed_dict={X2: batch_x2,
                                                             Y2: batch_y2,
                                                             keep_prob: 1.0})
        print("Epoch " + str(i) + ", CNN Loss= " + \
              "{:.4f}".format(loss3) + ", Training Accuracy= " + "{:.3f}".format(acc))

Таким образом, асинхронное обновление может быть запущено в пропорции 1:1, 1:5, 5:1 (дискриминатор: генератор) или любым другим способом.

Другие вопросы по тегам