В архитектуре U-Net h и w не совпадают, я не знаю, понял я это неправильно или нет

Меня смущает структура преобразования изображений в статье, которую я прочитал. Выходные размеры слоев, которые они объединяют, не совпадают.

Должен ли это быть слой 9 и слой 7?

class ImageTransformationNetwork(nn.Module):
def __init__(self):
    super(ImageTransformationNetwork, self).__init__()
    self.layer1 = nn.Sequential(nn.Conv2d(3, 64, padding=1, kernel_size=4, stride=2),
                                nn.LeakyReLU())
    self.layer2 = nn.Sequential(nn.Conv2d(64, 128, padding=1, kernel_size=4, stride=2),
                                nn.LeakyReLU(),
                                nn.BatchNorm2d(128))
    self.layer3 = nn.Sequential(nn.Conv2d(128, 256, padding=1, kernel_size=4, stride=2),
                                nn.LeakyReLU(),
                                nn.BatchNorm2d(256))
    self.layer4 = nn.Sequential(nn.Conv2d(256, 512, padding=1, kernel_size=4, stride=2),
                                nn.LeakyReLU(),
                                nn.BatchNorm2d(512))
    self.layer5 = nn.Sequential(nn.Conv2d(512, 512, padding=1, kernel_size=4, stride=2),
                                nn.LeakyReLU(),
                                nn.BatchNorm2d(512))
    self.layer6 = nn.Sequential(nn.Conv2d(512, 512, padding=1, kernel_size=4, stride=2),
                                nn.LeakyReLU(),
                                nn.BatchNorm2d(512))
    self.layer7 = nn.Sequential(nn.Conv2d(512, 512, padding=1, kernel_size=4, stride=2),
                                nn.LeakyReLU(),
                                nn.BatchNorm2d(512))
    self.layer8 = nn.Sequential(nn.Conv2d(512, 512, padding=1, kernel_size=4, stride=2),
                                nn.LeakyReLU())
    self.layer9 = nn.Sequential(nn.ConvTranspose2d(512, 1024, padding=1, kernel_size=4, stride=2),
                                nn.ReLU(),
                                nn.BatchNorm2d(1024))
    self.layer10 = nn.Sequential(nn.ConvTranspose2d(1024+512, 1024, padding=1, kernel_size=4, stride=2),
                                nn.ReLU(),
                                nn.BatchNorm2d(1024))
    self.layer11 = nn.Sequential(nn.ConvTranspose2d(1024+512, 1024, padding=1, kernel_size=4, stride=2),
                                 nn.ReLU(),
                                 nn.BatchNorm2d(1024))
    self.layer12 = nn.Sequential(nn.ConvTranspose2d(1024+512, 1024, padding=1, kernel_size=4, stride=2),
                                 nn.ReLU(),
                                 nn.BatchNorm2d(1024))
    self.layer13 = nn.Sequential(nn.ConvTranspose2d(1024+512, 512, padding=1, kernel_size=4, stride=2),
                                 nn.ReLU(),
                                 nn.BatchNorm2d(512))
    self.layer14 = nn.Sequential(nn.ConvTranspose2d(512+256, 256, padding=1, kernel_size=4, stride=2),
                                 nn.ReLU(),
                                 nn.BatchNorm2d(256))
    self.layer15 = nn.Sequential(nn.ConvTranspose2d(256+128, 128, padding=1, kernel_size=4, stride=2),
                                 nn.Tanh(),
                                 nn.BatchNorm2d(128))
    self.layer16 = nn.Sequential(nn.ConvTranspose2d(128, 3, padding=1, kernel_size=4, stride=2))

def forward(self, x):
    x1 = self.layer1(x)
    x2 = self.layer2(x1)
    x3 = self.layer3(x2)
    x4 = self.layer4(x3)
    x5 = self.layer5(x4)
    x6 = self.layer6(x5)
    x7 = self.layer7(x6)
    x8 = self.layer8(x7)
    x9 = self.layer9(x8)
    x10 = self.layer10(torch.cat([x7, x9], dim=1))
    x11 = self.layer11(torch.cat([x6, x10], dim=1))
    x12 = self.layer12(torch.cat([x5, x11], dim=1))
    x13 = self.layer13(torch.cat([x4, x12], dim=1))
    x14 = self.layer14(torch.cat([x3, x13], dim=1))
    x15 = self.layer15(torch.cat([x2, x14], dim=1))
    x16 = self.layer16(x15)
    return x16

Не знаю, правильно ли я это понимаю. Не могли бы вы чем-нибудь помочь?

0 ответов

Другие вопросы по тегам