DCGAN просто генерирует шум

Я пытаюсь обучить DCGAN в PyTorch, который может генерировать карты глубины. Даже после небольшого изменения параметров GAN просто выводит шум. В настоящее время я сохранил количество эпох до 100. Я знаю, что недостаточно получить достойные изображения за 100 эпох, но я думаю, что мой GAN должен был изучить хотя бы что-то лучше, чем просто генерировать шум. Есть ли стандартный способ исправить эту проблему?

Вот так выглядит сгенерированное изображение

Сгенерированные изображения, кажется, улучшаются только в течение первых 2 или 3 эпох, а затем остаются такими же для остальных 97 или 98 эпох.

Я перепробовал много вариантов скорости обучения как для генератора, так и для дискриминатора. Что я заметил, так это то, что если мои скорости обучения для генератора и дискриминатора очень близки друг к другу, то генератору будет очень трудно его поддерживать, и его потери очень быстро возрастают, а потери дискриминатора быстро приближаются к нулю. Вот почему я держал уровень обучения дискриминатора очень низким, в 0.000000001 и скорость обучения генератора в 0.0001, Это просто позволяет генератору дышать на несколько больше эпох, чем в предыдущем случае, но после этого дискриминатор снова начинает доминировать, и потери генератора снова начинают увеличиваться.

Основываясь на онлайн-советах, я даже пытался поменять метки, сохраняя real_label как 0 и fake_label как 1, но безрезультатно. Казалось, это только ухудшило ситуацию.

Я очень новичок в GAN и не уверен, как стабилизировать тренировочный процесс. Любая помощь будет оценена. Заранее спасибо!

РЕДАКТИРОВАТЬ:

Модель генератора, которую я использовал, выглядит следующим образом:

class Generator(nn.Module):
def __init__(self, ngpu):
    super(Generator, self).__init__()
    self.ngpu = ngpu
    self.main = nn.Sequential(
        # input is Z, going into a convolution
        nn.ConvTranspose2d(nz, ngf*16, 4, 1, 0, bias=False),
        nn.BatchNorm2d(ngf*16),
        nn.ReLU(True),
        # state size. (ngf*16) x 4 x 4
        nn.ConvTranspose2d(ngf*16, ngf*8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf*8),
        nn.ReLU(True),
        # state size. (ngf*8) x 8 x 8
        nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 0, bias=False),
        nn.BatchNorm2d(ngf*4),
        nn.ReLU(True),
        # state size. (ngf*4) x 18 x 18
        nn.ConvTranspose2d(ngf*4, ngf*2, 3, 2, 0, bias=False),
        nn.BatchNorm2d(ngf*2),
        nn.ReLU(True),
        # state size. (ngf*2) x 37 x 37
        nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ngf),
        nn.ReLU(True),
        # state size. (ngf) x 74 x 74
        nn.ConvTranspose2d(ngf, 1, 4, 2, 0, bias=False),
        nn.Tanh(),
        # state size. (1) x 150 x 150
    )

def forward(self, input):
    return self.main(input)

Инициализации: nz = 100, ngf = 64

Модель Дискриминатора, которую я использовал, выглядит следующим образом:

class Discriminator(nn.Module):
def __init__(self, ngpu):
    super(Discriminator, self).__init__()
    self.ngpu = ngpu
    self.main = nn.Sequential(
        # input is (nc) x 150 x 150
        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf) x 75 x 75
        nn.Conv2d(ndf, ndf * 2, 3, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 2),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf*2) x 38 x 38
        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf*4) x 19 x 19
        nn.Conv2d(ndf * 4, ndf * 8, 3, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 8),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. (ndf*8) x 10 x 10
        nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. 1 x 7 x 7
        nn.Conv2d(1, 1, 4, 1, 0, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        # state size. 1 x 4 x 4
        nn.Conv2d(1, 1, 4, 1, 0, bias=False),
        nn.Sigmoid()
        # state size 1 x 1 x 1
    )

def forward(self, input):
    return self.main(input)

Инициализации: nc = 4, ndf = 64

Тренировочный цикл выглядит следующим образом:

G_losses = []
D_losses = []
iters = 0
print("Starting training loop")

for epoch in range(num_epochs):
    print("Epoch", epoch+1, "of", num_epochs)
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data.to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Concat the first 3 dimensions of data (which contains the actual image) with the depth map produced by the generator
        fake_combined = data.to(device)
        fake_combined[:, [3], :, :] = fake.to(device)
        fake_combined = torch.tensor(fake_combined, dtype=torch.float32)
        fake_combined = fake_combined.to(device)
        # Classify all fake batch with D
        output = netD(fake_combined.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake_combined).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

0 ответов

Другие вопросы по тегам