Как генератор обучается с выходом дискриминатора в Генеративных состязательных сетях
Недавно я узнал о порождающих состязательных сетях.
Что касается обучения Генератора, я как-то запутался в том, как он учится. Вот реализация GAN:
`# train generator
z = Variable(xp.random.uniform(-1, 1, (batchsize, nz), dtype=np.float32))
x = gen(z)
yl = dis(x)
L_gen = F.softmax_cross_entropy(yl, Variable(xp.zeros(batchsize, dtype=np.int32)))
L_dis = F.softmax_cross_entropy(yl, Variable(xp.ones(batchsize, dtype=np.int32)))
# train discriminator
x2 = Variable(cuda.to_gpu(x2))
yl2 = dis(x2)
L_dis += F.softmax_cross_entropy(yl2, Variable(xp.zeros(batchsize, dtype=np.int32)))
#print "forward done"
o_gen.zero_grads()
L_gen.backward()
o_gen.update()
o_dis.zero_grads()
L_dis.backward()
o_dis.update()`
Таким образом, он рассчитывает потери для генератора, как это упоминается в статье. Тем не менее, он вызывает функцию обратного генератора, основанную на выводе дискриминатора. Выход дискриминатора - просто число (не массив).
Но мы знаем, что в целом для обучения сети мы вычисляем функцию потерь в последнем слое (потерю между выходом последних слоев и реальным выходом), а затем вычисляем градиенты. Так, например, если выходной сигнал равен 64*64, то мы сравниваем его с изображением 64*64, а затем вычисляем потери и выполняем обратное распространение.
Тем не менее, в кодах, которые я вижу в Generative Adversarial Networks, я вижу, что они вычисляют потери для генератора из выходного сигнала дискриминатора (который является просто числом), а затем они вызывают обратное распространение для генератора. Последние уровни Генераторов, например, составляют 64 * 64 пикселя, но потеря дискриминатора составляет 1*1 (что отличается от обычных сетей). Поэтому я не понимаю, как это приводит к обучению и обучению Генератора?
Я подумал, что если мы присоединяем две сети (подключаем Генератор и Дискриминатор), а затем вызываем обратное распространение, но просто обновляем параметры Генераторов, это имеет смысл и должно работать. Но то, что я вижу в кодах, совершенно другое.
Вот я и спрашиваю, как это возможно?
Спасибо
0 ответов
Вы говорите: "Тем не менее, он вызывает функцию обратной генерации, основанную на выводе Дискриминатора. Выходные данные дискриминатора - это просто число (а не массив), тогда как потеря всегда является скалярным значением. Когда мы вычисляем среднеквадратичную ошибку двух изображений, это также скалярное значение.
L_adversarial = E [log (D(x))] + E [log (1-D(G(z))]
х из реального распределения данных
z - скрытое распределение данных, которое преобразуется Генератором.
Возвращаясь к вашему актуальному вопросу, сеть Discriminator имеет функцию активации сигмоида на последнем уровне, что означает, что она выводит в диапазоне [0,1]. Дискриминатор пытается максимизировать эту потерю, максимизируя оба термина, которые добавляются в функцию потерь. Максимальное значение первого слагаемого равно 0 и возникает, когда D(x) равно 1, а максимальное значение второго слагаемого также равно 0 и имеет место, когда 1-D (G (z)) равно 1, что означает, что D (G (z)) равно 0 Таким образом, Discriminator пытается выполнить бинарную классификацию, максимизируя эту функцию потерь, где он пытается вывести 1, когда ему подают x(реальные данные), и 0, когда ему подают G(z)(сгенерированные поддельные данные). Но Генератор пытается минимизировать эту потерю, другими словами, он пытается обмануть Дискриминатор, генерируя поддельные сэмплы, которые похожи на реальные сэмплы. Со временем и генератор, и дискриминатор становятся все лучше и лучше. Это интуиция GAN.
Код находится в pytorch
bce_loss = nn.BCELoss() #bce_loss = -ylog(y_hat)-(1-y)log(1-y_hat)[similar to L_adversarial]
Discriminator = ..... #some network
Generator = ..... #some network
optimizer_generator = ....... #some optimizer for generator network
optimizer_discriminator = ....... #some optimizer for discriminator network
z = ...... #some latent data distribution that is transformed by the generator
real = ..... #real data distribution
#####################
#Update Discriminator
#####################
fake = Generator(z)
fake_prediction = Discriminator(fake)
real_prediction = Discriminator(real)
discriminator_loss = bce_loss(fake_prediction,torch.zeros(batch_size))+bce_loss(real_prediction,torch.ones(batch_size))
discriminator_loss.backward()
optimizer_discriminator.step()
#################
#Update Generator
#################
fake = Generator(z)
fake_prediction = Discriminator(fake)
generator_loss = bce_loss(fake_prediction,torch.ones(batch_size))
generator_loss.backward()
optimizer_generator.step()