MNIST DC-GAN Все градиенты нули
Я пытаюсь создать DC-GAN для MNIST на платформе Flax, используя пример TF в качестве ориентира. Сама сеть технически работает, но ни генератор, ни дискриминатор не обновляются, поскольку их градиенты всегда равны нулю. Я уже убедился, что веса инициализированы правильно, и попытался увеличить скорость обучения, но это не помогло. Я мог только подозревать, что проблема в архитектуре самой сети, но она буквально копируется построчно, за исключением BatchNormalizations в генераторе и Dropouts в дискриминаторе.
class generator_class(nn.Module):
def apply(self, x):
x = nn.Dense(x, features=7*7*256, bias_init=initializers.zeros)
x = lrelu(x)
x = x.reshape((-1, 7, 7, 256))
x = nn.ConvTranspose(x, features=128, kernel_size=(5, 5), strides=(1, 1), bias=False)
x = lrelu(x)
x = nn.ConvTranspose(x, features=64, kernel_size=(5, 5), strides=(2, 2), bias=False)
x = lrelu(x)
x = nn.ConvTranspose(x, features=1, kernel_size=(5, 5), strides=(2, 2), bias=False)
x = nn.tanh(x)
return x
class discriminator_class(nn.Module):
def apply(self, x):
x = nn.Conv(x, features=64, kernel_size=(5, 5), strides=(2,2))
x = lrelu(x)
x = nn.Conv(x, features=128, kernel_size=(5, 5), strides=(2,2))
x = lrelu(x)
x = x.reshape((x.shape[0], -1)) #flatten
x = nn.Dense(x, features=1)
return x
_, init_params = generator_class.init_by_shape(random.PRNGKey(0), [((100,), jnp.float32)])
generator = nn.Model(generator_class, init_params)
_, init_params = discriminator_class.init_by_shape(random.PRNGKey(0), [((1, 28, 28, 1), jnp.float32)])
discriminator = nn.Model(discriminator_class, init_params)
@jax.vmap
def binary_cross_entropy(logits, labels):
logits = nn.log_sigmoid(logits)
return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))
@jax.jit
def train_step(generator_optimizer, discriminator_optimizer, images):
noise = jax.random.normal(random.PRNGKey(0), shape = [256, 100])
generated_images = gen(noise)
real_output = disc(images)
fake_output = disc(generated_images)
def generator_loss(generator):
return binary_cross_entropy(jnp.ones_like(fake_output), fake_output).mean(), generated_images
def discriminator_loss(discriminator):
real_loss = binary_cross_entropy(jnp.ones_like(real_output), real_output).mean()
fake_loss = binary_cross_entropy(jnp.zeros_like(fake_output), fake_output).mean()
total_loss = real_loss + fake_loss
return total_loss, fake_output
grad_fn_gen = jax.value_and_grad(generator_loss, has_aux=True)
(_, preds), grad_gen = grad_fn_gen(generator_optimizer.target)
generator_optimizer = generator_optimizer.apply_gradient(grad_gen)
grad_fn_disc = jax.value_and_grad(discriminator_loss, has_aux=True)
(_, preds), grad_disc = grad_fn_disc(discriminator_optimizer.target)
discriminator_optimizer = discriminator_optimizer.apply_gradient(grad_disc)
return generator_optimizer, discriminator_optimizer