Реализация улучшенных генеративных состязательных сетей в Keras
Я использую реализованный в keras код для разработки улучшенной версии Wasserstein Generative Adversarial Gans. Я тестирую код, используя предоставленную базу данных Mnist. В коде при чтении Mnist также загружаются метки изображений.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
Функции потерь для дискриминатора:
discriminator_loss.append(discriminator_model.train_on_batch([image_batch, noise],[positive_y, negative_y, dummy_y]))
И для генератора:
generator_loss.append(generator_model.train_on_batch(np.random.rand(BATCH_SIZE, 100), positive_y))
Во время процесса GAN эти метки игнорируются. Как я могу использовать эту информацию для создания изображений из определенных ярлыков?
РЕДАКТИРОВАТЬ: я заметил, что строки, которые мне нужно изменить, между 203 и 217
discriminator.trainable = False
generator_input = Input(shape=(100,))
generator_layers = generator(generator_input)
discriminator_layers_for_generator = discriminator(generator_layers)
generator_model = Model(inputs=[generator_input], outputs=
[discriminator_layers_for_generator])
//We use the Adam paramaters from Gulrajani et al.
generator_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9),
loss=wasserstein_loss)
...
real_samples = Input(shape=X_train.shape[1:])
generator_input_for_discriminator = Input(shape=(100,))
generated_samples_for_discriminator =
generator(generator_input_for_discriminator)
discriminator_output_from_generator =
discriminator(generated_samples_for_discriminator)
discriminator_output_from_real_samples = discriminator(real_samples)
И я думаю, что мне нужно изменить также конструкторы моделей make_generator и make_discriminator.