Почему DCGAN с неконволюционным дискриминатором не учится?

Я настроил код DCGAN-Tensorflow, изменив его архитектуру дискриминатора на не сверточную сеть. Это приводит к тому, что потери дискриминатора и генератора идут параллельно с самого начала обучения, например: введите описание изображения здесь Числа обоих потерь не постоянны, они отскакивают, но очень незначительно. Сгенерированный результат не очень хорош - модель, похоже, не учится.

Вопрос: кто-нибудь сталкивался с чем-то подобным раньше? В чем может быть проблема?


Полный репозиторий git доступен здесь.

И это код отредактированного дискриминатора, который нарушил код:

  def discriminator(self, image, y=None, reuse=False):
    with tf.variable_scope("discriminator") as scope:
      if reuse:
        scope.reuse_variables()

      if not self.y_dim:
        im = tf.reshape(image, [64, 160])

        self.h0, self.h0_w, self.h0_b = linear(im, 1024, 'd_h0', with_w = True)
        h0 = tf.nn.tanh(self.g_bn0(self.h0))
        self.h1, self.h1_w, self.h1_b = linear(h0, 1024, 'd_h1', with_w = True)
        h1 = tf.nn.tanh(self.g_bn1(self.h1))
        self.h2, self.h2_w, self.h2_b = linear(h1, 160, 'd_h2', with_w = True)
        h2 = tf.nn.tanh(self.g_bn2(self.h2))
        self.h3, self.h3_w, self.h3_b = linear(h2, 2, 'd_h3', with_w = True)        
        h3 = tf.nn.tanh(self.g_bn3(self.h3))            
        h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4')

        return tf.nn.sigmoid(h4), h4

Вместо оригинального:

  def discriminator(self, image, y=None, reuse=False):
    with tf.variable_scope("discriminator") as scope:
      if reuse:
        scope.reuse_variables()

      if not self.y_dim:
        h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
        h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv')))
        h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))
        h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv')))
        h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin')

        return tf.nn.sigmoid(h4), h4

0 ответов

Другие вопросы по тегам