U-net: Как изменить вес фона в pytorch
База данных, которую я использую, является МРТ-картиной мозга. Я преобразовал эти изображения в 2D-изображения. Во время обучения я обнаружил, что, поскольку большинство помеченных изображений были черным фоном, только небольшая их часть была опухолью. Это приводит к несбалансированной категории. Я хочу увеличить вес опухоли и уменьшить вес фона. Моя идея состоит в том, сколько раз фоновый пиксель может быть вычислен, чем больше он появляется, тем меньше его вес. Но я не знаю, как изменить программу? Мой код: unet.py:
import torch.nn as nn
import torch.nn.functional as F
import torch
from numpy.linalg import svd
from numpy.random import normal
from math import sqrt
class UNet(nn.Module):
def __init__(self,colordim =4):
super(UNet, self).__init__()
self.conv1_1 = nn.Conv2d(colordim,32,3,padding=1,stride=1) # input of (n,n,1), output of (n-2,n-2,64)
self.conv1_2 = nn.Conv2d(32,32,3,padding=1,stride=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2_1 = nn.Conv2d(32, 64, 3,padding=1)
self.conv2_2 = nn.Conv2d(64, 64, 3,padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.conv3_1 = nn.Conv2d(64, 128, 3,padding=1)
self.conv3_2 = nn.Conv2d(128, 128, 3,padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.conv4_1 = nn.Conv2d(128, 256, 3,padding=1)
self.conv4_2 = nn.Conv2d(256, 256, 3,padding=1)
self.upconv4 = nn.Conv2d(256, 128, 1)
self.bn4 = nn.BatchNorm2d(128)
self.bn4_out = nn.BatchNorm2d(256)
self.conv5_1 = nn.Conv2d(256, 128, 3,padding=1)
self.conv5_2 = nn.Conv2d(128, 128, 3,padding=1)
self.upconv5 = nn.Conv2d(128, 64, 1)
self.bn5 = nn.BatchNorm2d(64)
self.bn5_out = nn.BatchNorm2d(128)
self.conv6_1 = nn.Conv2d(128, 64, 3,padding=1)
self.conv6_2 = nn.Conv2d(64, 64, 3,padding=1)
self.upconv6 = nn.Conv2d(64, 32, 1)
self.bn6 = nn.BatchNorm2d(32)
self.bn6_out = nn.BatchNorm2d(64)
self.conv7_1 = nn.Conv2d(64, 32, 3,padding=1)
self.conv7_2 = nn.Conv2d(32, 32, 3,padding=1)
self.conv7_3 = nn.Conv2d(32, 1, 1)
self.bn7 = nn.BatchNorm2d(1)
self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
self._initialize_weights()
def forward(self, x1):
#x1out
x1 = F.relu(self.bn1(self.conv1_2(F.relu(self.conv1_1(x1)))))
#print('x1 size: %d'%(x1.size(2)))
#x2out
x2 = F.relu(self.bn2(self.conv2_2(F.relu(self.conv2_1(self.maxpool(x1))))))
#print('x2 size: %d'%(x2.size(2)))
#x3out
x3 = F.relu(self.bn3(self.conv3_2(F.relu(self.conv3_1(self.maxpool(x2))))))
#print('x3 size: %d'%(x3.size(2)))
#x4out
xup = F.relu((self.conv4_2(F.relu(self.conv4_1(self.maxpool(x3))))))
#print('x4 size: %d'%(xup.size(2)))
#x5in
xup = self.bn4(self.upconv4(self.upsample(xup)))
#print('x5in size: %d'%(xup.size(2)))
#x5out
xup = self.bn4_out(torch.cat((x3,xup),1))
xup = F.relu(self.conv5_2(F.relu(self.conv5_1(xup))))
#print('x5ou size: %d' % (xup.size(2)))
#x6in
xup = self.bn5(self.upconv5(self.upsample(xup)))
#x6out
xup = self.bn5_out(torch.cat((x2,xup),1))
xup = F.relu(self.conv6_2(F.relu(self.conv6_1(xup))))
#x7in
xup = self.bn6(self.upconv6(self.upsample(xup)))
#x7out
xup = self.bn6_out(torch.cat((x1,xup),1))
xup = F.relu(self.conv7_3(F.relu(self.conv7_2(F.relu(self.conv7_1(xup))))))
return F.softsign(self.bn7(xup))
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
unet = UNet().cuda()
train.py:
#-*- coding:utf-8 -*-
from __future__ import print_function
from math import log10
import numpy as np
import random
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from unet import UNet
from load_data import get_train_img,get_test_img
import torchvision
# Training settings
class option:
def __init__(self):
self.cuda = True #use cuda?
self.batchSize = 3 #training batch size
self.testBatchSize = 3 #testing batch size
self.nEpochs = 5 #number of epochs to train for
self.lr = 0.001 #Learning Rate. Default=0.01
self.threads = 0 #number of threads for data loader to use
self.seed = 123 #random seed to use. Default=123
self.size = 240 # image size
self.colordim = 4 #
self.pretrain_net = 'F:\\image\\100data\\model\\model_epoch_140.pth'
def map01(tensor,eps=1e-5):
#input/output:tensor
max = np.max(tensor.numpy(), axis=(1,2,3), keepdims=True)
min = np.min(tensor.numpy(), axis=(1,2,3), keepdims=True)
if (max-min).any():
return torch.from_numpy( (tensor.numpy() - min) / (max-min + eps) )
else:
return torch.from_numpy( (tensor.numpy() - min) / (max-min) )
opt = option()
cuda = opt.cuda
if cuda and not torch.cuda.is_available():
raise Exception("No GPU found, please run without --cuda")
torch.manual_seed(opt.seed)
if cuda:
torch.cuda.manual_seed(opt.seed)
print('===>Use the default training, test data')
print('===> Loading datasets')
train_set = get_train_img(opt.size)
test_set = get_test_img(opt.size)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
#torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
# num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
print('===> Building unet')
unet = UNet()
criterion = nn.MSELoss()
if cuda:
unet = unet.cuda()
criterion = criterion.cuda()
pretrained = False
if pretrained:
unet.load_state_dict(torch.load(opt.pretrain_net))
optimizer = optim.SGD(unet.parameters(), lr=opt.lr)
print('===> Training unet')
print(training_data_loader)
def train(epoch):
epoch_loss = 0
for iteration, batch in enumerate(training_data_loader, 1):
#print('train\' iteration and batch'+str(type(iteration))+' '+str(type(batch)))
#print()
input = Variable(batch[0])
target = Variable(batch[1])
filename = batch[2]
#target =target.squeeze(1)
#print(target.data.size())
if cuda:
input = input.cuda()
target = target.cuda()
input = unet(input)
loss = criterion( input, target)
epoch_loss += loss.data[0]
loss.backward()
optimizer.step()
if iteration%10 is 0:
print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), loss.data[0]))
imgout = input.data
torchvision.utils.save_image(imgout,"F:\\image\\100data\\output\\train_out\\"+str(epoch)+'.png',padding=0)
print("===> Epoch {} Complete: Avg. Loss: {}".format(epoch, epoch_loss / len(training_data_loader)))
print('train ok')
def test(epoch):
totalloss = 0
for iteration ,batch in enumerate(testing_data_loader):
input = Variable(batch[0])
filename = batch[1]
if cuda:
input = input.cuda()
optimizer.zero_grad()
prediction = unet(input)
imgout = prediction.data
torchvision.utils.save_image(imgout,"F:\\image\\100data\\output\\test_out\\"+str(epoch)+'.png',padding=0)
def checkpoint(epoch):
model_out_path = "F:\\image\\100data\\model\\model_epoch_{}.pth".format(epoch)
torch.save(unet.state_dict(), model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
for epoch in range(1, 1+opt.nEpochs + 1):
train(epoch)
if epoch%10 is 0:
checkpoint(epoch)
test(epoch)
checkpoint(epoch)
load_data.py:
#-*- coding:utf-8 -*-
from os.path import exists,join
from os import listdir
import numpy as np
from PIL import Image
import torch.utils.data as data
from skimage import io
from torchvision.transforms import Compose, CenterCrop, ToTensor, Scale,ToPILImage
def brats2015(dest = r'F:\image'):
if not exists(dest):
print('Sorry,dataset not exits')
print('please check the file path')
return dest
def input_transform(crop_size):
return Compose([
#CenterCrop(crop_size),
ToTensor()
])
def get_train_img(size,train_path = r'H:\deepfortest\2Dpng'):
brats_train = brats2015(train_path)
train_num = join(brats_train,'T1')
return DatasetFromFolder(image_dir = train_num,
image_path = train_path,
input_transform = input_transform(size),
target_transform = input_transform(size),
OT_point=True,
)
def get_test_img(size,test_path = r'F:\image\100data\test'):
brats_test = brats2015(test_path)
test_num = join(brats_test,'T1')
return DatasetFromFolder(image_dir = test_num,
image_path = test_path,
input_transform = input_transform(size),
target_transform = input_transform(size),
OT_point=False,
)
def is_image_file(filename):
return any(filename.endswith(extension) for extension in [".png", ".jpg"])
def load_img(image_path, name, label_bool=False):
if label_bool is False:
Flair_path1 = join(join(image_path,'Flair'),name)
T1_path2 = join(join(image_path,'T1'),name)
T1C_path3 = join(join(image_path,'T1C'),name)
T2_path4 = join(join(image_path,'T2'),name)
Flair = io.imread(Flair_path1)
T1 = io.imread(T1_path2)
T1C = io.imread(T1C_path3)
T2 = io.imread(T2_path4)
B = np.dstack((Flair,T1,T1C,T2))
if label_bool is True:
path = join(join(image_path,'OT'),name)
B = Image.open(path)
#B = np.array(A)
return B
#image_dir = r'F:\image\train\T1',image_path = r'F:\image\train'
class DatasetFromFolder(data.Dataset):
def __init__(self, image_dir, image_path, input_transform=None, target_transform=None,OT_point = True):
super(DatasetFromFolder, self).__init__()
self.image_filenames = [x for x in listdir(image_dir) if is_image_file(x)]
self.input_transform = input_transform
self.target_transform = target_transform
self.image_dir = image_dir
self.image_path = image_path
self.OT_point = OT_point
def __getitem__(self, index):
input = load_img(self.image_path, self.image_filenames[index])
if self.input_transform:
input = self.input_transform(input)
if self.OT_point:
target = load_img(self.image_path, self.image_filenames[index], label_bool=True)
if self.target_transform:
target = self.target_transform(target)
return input, target,self.image_filenames
return input,self.image_filenames
def __len__(self):
return len(self.image_filenames)
Мой каталог:
----image
------train
-------- T1
-------- T1C
-------- T2
-------- Flair
--------OT
------test
-------- T1
-------- T1C
-------- T2
-------- Flair
------model
------output