Реализация улучшенных генеративных состязательных сетей в Keras

Я использую реализованный в keras код для разработки улучшенной версии Wasserstein Generative Adversarial Gans. Я тестирую код, используя предоставленную базу данных Mnist. В коде при чтении Mnist также загружаются метки изображений.

(X_train, y_train), (X_test, y_test) = mnist.load_data()

Функции потерь для дискриминатора:

 discriminator_loss.append(discriminator_model.train_on_batch([image_batch, noise],[positive_y, negative_y, dummy_y])) 

И для генератора:

 generator_loss.append(generator_model.train_on_batch(np.random.rand(BATCH_SIZE, 100), positive_y))

Во время процесса GAN эти метки игнорируются. Как я могу использовать эту информацию для создания изображений из определенных ярлыков?

РЕДАКТИРОВАТЬ: я заметил, что строки, которые мне нужно изменить, между 203 и 217

discriminator.trainable = False
generator_input = Input(shape=(100,))
generator_layers = generator(generator_input)
discriminator_layers_for_generator = discriminator(generator_layers)
generator_model = Model(inputs=[generator_input], outputs= 
[discriminator_layers_for_generator])
//We use the Adam paramaters from Gulrajani et al.
generator_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9), 
    loss=wasserstein_loss)
...

real_samples = Input(shape=X_train.shape[1:])
generator_input_for_discriminator = Input(shape=(100,))
generated_samples_for_discriminator = 
generator(generator_input_for_discriminator)
discriminator_output_from_generator = 
       discriminator(generated_samples_for_discriminator)
discriminator_output_from_real_samples = discriminator(real_samples)

И я думаю, что мне нужно изменить также конструкторы моделей make_generator и make_discriminator.

0 ответов

Другие вопросы по тегам