Расширение Unet Размеры
Я пытаюсь создать 3D Unet для сегментации медицинских изображений с помощью дилатационных слоев, чтобы дать мне необходимый размер восприимчивого поля, который мне нужен, не делая модель слишком тяжелой. У меня много проблем с установкой размеров кодера и декодера в соответствии. Мой ввод [bs,4,96,96,64] и вывод [bs,5,96,96,64], т.е. каждый пиксель имеет 5 возможных классов. Как вы видите, у меня определены мой нисходящий блок и восходящий блок, и я максимально увеличил пул в части кодера. Я показал размеры и ошибку после кода ниже:
Ошибка находится в первом восходящем блоке, поскольку он интерполирует размер 12, в то время как коррелирующий блок имеет размер 13 в 3-м измерении. Может кто-нибудь помочь мне сделать его симметричным или хотя бы работоспособным?
class UNet_down_block(torch.nn.Module):
def __init__(self, input_channel, output_channel, down_size):
super(UNet_down_block, self).__init__()
self.conv1 = torch.nn.Conv3d(input_channel, output_channel, 3, padding=1,dilation=1)
self.bn1 = torch.nn.BatchNorm3d(output_channel)
self.conv2 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1,dilation=2)
self.bn2 = torch.nn.BatchNorm3d(output_channel)
self.conv3 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1,dilation=2)
self.bn3 = torch.nn.BatchNorm3d(output_channel)
self.max_pool = torch.nn.MaxPool3d(2, 2)
self.relu = torch.nn.ELU()
self.down_size = down_size
def forward(self, x):
if self.down_size:
x = self.max_pool(x)
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x)))
x = self.relu(self.bn3(self.conv3(x)))
return x
class UNet_up_block(torch.nn.Module):
def __init__(self, prev_channel, input_channel, output_channel):
super(UNet_up_block, self).__init__()
# self.up_sampling = torch.nn.functional.interpolate(scale_factor=2, mode='trilinear')
self.conv1 = torch.nn.Conv3d(prev_channel + input_channel, output_channel, 3, padding=1)
self.bn1 = torch.nn.BatchNorm3d(output_channel)
self.conv2 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1)
self.bn2 = torch.nn.BatchNorm3d(output_channel)
self.conv3 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1)
self.bn3 = torch.nn.BatchNorm3d(output_channel)
self.relu = torch.nn.ELU()
def forward(self, prev_feature_map, x):
x = torch.nn.functional.interpolate(x,scale_factor=2, mode='trilinear')
x = torch.cat((x, prev_feature_map), dim=1)
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x)))
x = self.relu(self.bn3(self.conv3(x)))
return x
class UNet(torch.nn.Module):
def __init__(self):
super(UNet, self).__init__()
self.down_block1 = UNet_down_block(4, 24, False)
self.down_block2 = UNet_down_block(24, 72, True)
self.down_block3 = UNet_down_block(72, 148, True)
self.down_block4 = UNet_down_block(148, 224, False)
self.max_pool = torch.nn.MaxPool3d(2, 2)
self.mid_conv1 = torch.nn.Conv3d(224, 224, 3, padding=1)
self.bn1 = torch.nn.BatchNorm3d(224)
self.mid_conv2 = torch.nn.Conv3d(224, 224, 3, padding=1)
self.bn2 = torch.nn.BatchNorm3d(224)
self.mid_conv3 = torch.nn.Conv3d(224, 224, 3, padding=1)
self.bn3 = torch.nn.BatchNorm3d(224)
self.up_block1 = UNet_up_block(224, 224, 148)
self.up_block2 = UNet_up_block(148, 148, 72)
self.up_block3 = UNet_up_block(72, 72, 24)
self.up_block4 = UNet_up_block(24, 24, 8)
self.last_conv1 = torch.nn.Conv3d(8, 4, 3, padding=1)
self.last_bn = torch.nn.BatchNorm3d(4)
self.last_conv2 = torch.nn.Conv3d(4, 1, 1, padding=0)
self.relu = torch.nn.ELU()
self.last_conv3 = torch.nn.Conv3d(1, 1, 1, padding=0)
self.relu = torch.nn.ELU()
self.conv1f=torch.nn.Conv2d(1, 5, 3,padding=1)
self.conv2f=torch.nn.Conv2d(5, 5, 3,padding=1)
self.conv3f=torch.nn.Conv2d(5, 5, 3,padding=1)
def forward(self, x):
print('input unet',x.size())
self.x1 = self.down_block1(x)
print("Block 1 shape:",self.x1.size())
self.x2 = self.down_block2(self.x1)
if self.x2.size()[2]==49: ###*********************************** ifffff if self.x2.size()[2]==49:
self.x2=self.x2[:,:,1:,1:,:]
print("Block 2 shape:",self.x2.size())
self.x3 = self.down_block3(self.x2)
print("Block 3 shape:",self.x3.size())
self.x4 = self.down_block4(self.x3)
print("Block 4 shape:",self.x4.size())
self.xmid=self.max_pool(self.x4)
self.xmid = self.relu(self.bn1(self.mid_conv1(self.xmid)))
self.xmid = self.relu(self.bn2(self.mid_conv2(self.xmid)))
self.xmid = self.relu(self.bn3(self.mid_conv3(self.xmid)))
print("Block Mid shape:",self.xmid.size())
x = self.up_block1(self.x4, self.xmid)
# print("BlockU 1 shape:",x.size())
x = self.up_block2(self.x3, x)
print("BlockU 2 shape:",x.size())
x = self.up_block3(self.x2, x)
print("BlockU 3 shape:",x.size())
if self.x1.size()[2]==98: ###*********************************** ifffff
self.x1=self.x1[:,:,1:-1,1:-1,:]
# print('chan98',self.x1.size())
x = self.up_block4(self.x1, x)
print("BlockU 4 shape:",x.size())
x = self.relu(self.last_bn(self.last_conv1(x)))
x = self.last_conv2(x) # of size [batch_size,1,h,w,depth] or [bs, modalities(1) ,96 ,96 , 64]
x=x.view(batch_size,1,-1,64)
# x=x.squeeze(1)
# print('input convf',x.size())
conv=self.relu(self.conv1f(x))
conv=self.relu(self.conv2f(conv))
conv=self.conv3f(conv)
try:
conv=conv.view(batch_size,5,96,96,64)
except:
conv=conv.view(batch_size_val,5,96,96,64)
# print('unet output',conv.size())
return(conv)
Вот результат:
input unet torch.Size([1, 4, 96, 96, 64])
Block 1 shape: torch.Size([1, 24, 92, 92, 60])
Block 2 shape: torch.Size([1, 72, 42, 42, 26])
Block 3 shape: torch.Size([1, 148, 17, 17, 9])
Block 4 shape: torch.Size([1, 224, 13, 13, 5])
Block Mid shape: torch.Size([1, 224, 6, 6, 2])
Error:
x = self.up_block1(self.x4, self.xmid)
111 # print("BlockU 1 shape:",x.size())
112 x = self.up_block2(self.x3, x)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
475 result = self._slow_forward(*input, **kwargs)
476 else:
--> 477 result = self.forward(*input, **kwargs)
478 for hook in self._forward_hooks.values():
479 hook_result = hook(self, input, result)
<ipython-input-5-cbcdda025480> in forward(self, prev_feature_map, x)
39 x = torch.nn.functional.interpolate(x,scale_factor=2, mode='trilinear')
40
---> 41 x = torch.cat((x, prev_feature_map), dim=1)
42 x = self.relu(self.bn1(self.conv1(x)))
43 x = self.relu(self.bn2(self.conv2(x)))
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 12 and 13 in dimension 2 at /opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/TH/generic/THTensorMath.cpp:3616