Генеративная Состязательная Сеть не генерирует узнаваемые шаблоны

Я попытался построить генеративную сеть соперничества с помощью этого кода:

import tensorflow as tf
import numpy as np
import matplotlib.image as mpimg
import glob
from PIL import Image

d_epochs = 10
g_epochs = 10

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

batch_size = 9

x_train = []
y_train = []

for filename in glob.glob('trainig_data/*.jpg'):
    im = mpimg.imread(filename)
    x_train.append(im)
    y_train.append(1.0)
    if len(x_train) == 396:
        break

gen_seed = np.random.rand(396)

g_input = 1
g_hidden1 = 100
g_hidden2 = 100
g_hidden3 = 116412

g_weights = [tf.Variable(tf.random_normal([g_input,                                                                                                                                                                         g_hidden1])),
                 tf.Variable(tf.random_normal([g_hidden1, g_hidden2])),
                 tf.Variable(tf.random_normal([g_hidden2, g_hidden3])),
                 tf.Variable(tf.random_normal([5,5,3,3])),
                 tf.Variable(tf.random_normal([5,5,3,3])),
                 tf.Variable(tf.random_normal([5,5,3,3]))]

def generator(x, weights):
    output=tf.matmul([[x]], weights[0])
    output=tf.nn.relu(output)
    output=tf.matmul(output, weights[1])
    output=tf.nn.relu(output)
    output=tf.matmul(output, weights[2])
    output=tf.nn.relu(output)
    output=tf.reshape(output, [1,218,178,3])
    output=tf.nn.conv2d(output, weights[3], [1,1,1,1], 'SAME')
    output=tf.nn.relu(output)
    output=tf.nn.conv2d(output, weights[4], [1,1,1,1], 'SAME')
    output=tf.nn.relu(output)
    output=tf.nn.conv2d(output, weights[5], [1,1,1,1], 'SAME')
    return output[0]

d_weights = [tf.Variable(tf.random_normal([5,5,3,3])),
             tf.Variable(tf.random_normal([5,5,3,3])),
             tf.Variable(tf.random_normal([5,5,3,3])),
             tf.Variable(tf.random_normal([261927, 100])),
             tf.Variable(tf.random_normal([100, 100])),
             tf.Variable(tf.random_normal([100, 2]))]

def discriminator(x, weights):
    output=tf.nn.conv2d(x, weights[0], [1,1,1,1], 'SAME')
    output=tf.nn.relu(output)
    output=tf.nn.conv2d(output, weights[1], [1,1,1,1], 'SAME')
    output=tf.nn.relu(output)
    output=tf.nn.conv2d(output, weights[2], [1,2,2,1], 'SAME')
    output=tf.nn.relu(output)
    output=tf.reshape(output, [261927])
    output=tf.matmul([output], weights[3])
    output=tf.nn.relu(output)
    output=tf.matmul(output, weights[4])
    output=tf.nn.relu(output)
    output=tf.matmul(output, weights[5])
    output=tf.reduce_mean(output)
    return(output)

prediction = discriminator(x, d_weights)
loss = -tf.reduce_sum(y * tf.log(prediction + 1e-12))
optimizer = tf.train.AdamOptimizer().minimize(loss)
g_optimizer = tf.train.AdamOptimizer().minimize(-prediction)

saver = tf.train.Saver()
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for e in range(d_epochs):
        print('epoch:', e+1)
        for b in range(int(len(x_train)/batch_size)):
            start = b*batch_size
            end = start+batch_size
            batch_x = x_train[start:end]
            batch_y = y_train[start:end]
            _, c = sess.run([optimizer,loss],feed_dict={y:batch_y, x:batch_x})
    saver.save(sess, 'saved/discriminator.ckpt')

    file = 0
    for e in range(g_epochs):
        print('epoch:', e+1)
        for b in range(int(len(gen_seed)/batch_size)):
            file += 1
            preds = []
            fake = []
            start = b*batch_size
            end = start+batch_size
            batch = gen_seed[start:end]
            for i in batch:
                fake.append(sess.run(generator(i, g_weights)))
            sess.run([prediction, loss, optimizer, g_optimizer], feed_dict={x:fake, y:np.zeros(len(fake))})
            im = Image.fromarray(fake[0].astype('uint8'))
            im.save('output/'+str(file)+'.png')
    saver.save(sess, '/saved/generator.ckpt')

Но даже после 10 эпох (440 партий) обучения генератора генерируемые им шаблоны вообще не связаны (я приведу пример изображения). Пожалуйста, помогите мне решить эту проблему. PS. Я использую Изображения набора данных CelebA, если это помогает. введите описание изображения здесь

0 ответов

Другие вопросы по тегам