MNIST DC-GAN Все градиенты нули

Я пытаюсь создать DC-GAN для MNIST на платформе Flax, используя пример TF в качестве ориентира. Сама сеть технически работает, но ни генератор, ни дискриминатор не обновляются, поскольку их градиенты всегда равны нулю. Я уже убедился, что веса инициализированы правильно, и попытался увеличить скорость обучения, но это не помогло. Я мог только подозревать, что проблема в архитектуре самой сети, но она буквально копируется построчно, за исключением BatchNormalizations в генераторе и Dropouts в дискриминаторе.

class generator_class(nn.Module):
    def apply(self, x):
        x = nn.Dense(x, features=7*7*256, bias_init=initializers.zeros)
        x = lrelu(x)
        x = x.reshape((-1, 7, 7, 256))
        x = nn.ConvTranspose(x, features=128, kernel_size=(5, 5), strides=(1, 1), bias=False)
        x = lrelu(x)
        x = nn.ConvTranspose(x, features=64, kernel_size=(5, 5), strides=(2, 2), bias=False)
        x = lrelu(x)
        x = nn.ConvTranspose(x, features=1, kernel_size=(5, 5), strides=(2, 2), bias=False)
        x = nn.tanh(x)
    return x

class discriminator_class(nn.Module):
    def apply(self, x):
        x = nn.Conv(x, features=64, kernel_size=(5, 5), strides=(2,2))
        x = lrelu(x)
        x = nn.Conv(x, features=128, kernel_size=(5, 5), strides=(2,2))
        x = lrelu(x)
        x = x.reshape((x.shape[0], -1)) #flatten
        x = nn.Dense(x, features=1)
    return x

_, init_params = generator_class.init_by_shape(random.PRNGKey(0), [((100,), jnp.float32)])
generator = nn.Model(generator_class, init_params)

_, init_params = discriminator_class.init_by_shape(random.PRNGKey(0), [((1, 28, 28, 1), jnp.float32)])
discriminator = nn.Model(discriminator_class, init_params)

@jax.vmap
def binary_cross_entropy(logits, labels):
    logits = nn.log_sigmoid(logits)
    return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))

@jax.jit
def train_step(generator_optimizer, discriminator_optimizer, images):

    noise = jax.random.normal(random.PRNGKey(0), shape = [256, 100])
    generated_images = gen(noise)
    real_output = disc(images)
    fake_output = disc(generated_images)

    def generator_loss(generator):
        return binary_cross_entropy(jnp.ones_like(fake_output), fake_output).mean(), generated_images

    def discriminator_loss(discriminator):
        real_loss = binary_cross_entropy(jnp.ones_like(real_output), real_output).mean()
        fake_loss = binary_cross_entropy(jnp.zeros_like(fake_output), fake_output).mean()
        total_loss = real_loss + fake_loss
        return total_loss, fake_output

    grad_fn_gen = jax.value_and_grad(generator_loss, has_aux=True)
    (_, preds), grad_gen = grad_fn_gen(generator_optimizer.target)
    generator_optimizer = generator_optimizer.apply_gradient(grad_gen)    

    grad_fn_disc = jax.value_and_grad(discriminator_loss, has_aux=True)
    (_, preds), grad_disc = grad_fn_disc(discriminator_optimizer.target)
    discriminator_optimizer = discriminator_optimizer.apply_gradient(grad_disc)

return generator_optimizer, discriminator_optimizer

0 ответов

Другие вопросы по тегам