U-Net: how to change the background weight in PyTorch

The dataset I am using consists of brain MRI scans, which I have converted into 2D images. During training I noticed that most of each labeled image is black background and only a small part is tumor, so the classes are heavily imbalanced. I want to increase the weight of the tumor and decrease the weight of the background. My idea is to count how often background pixels occur: the more often a class appears, the smaller its weight should be. But I don't know how to change my program to do this.
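Roughly, what I have in mind is a loss in which every pixel is weighted by the inverse frequency of its class in the batch. Here is a minimal sketch of that idea (only a sketch: it assumes the network returns a single-channel raw logit, i.e. without the final softsign, and that the label is a {0, 1} tumor mask of the same shape; the function name is just illustrative):

    import torch
    import torch.nn.functional as F

    def inverse_frequency_bce(logits, target, eps=1e-6):
        # logits and target have the same shape, e.g. (N, 1, H, W);
        # target is a {0, 1} float tumor mask.
        n_total = target.numel()
        n_tumor = target.sum()
        n_background = n_total - n_tumor
        # Inverse-frequency class weights, scaled so the two classes
        # contribute roughly equally to the loss.
        w_tumor = n_total / (2.0 * n_tumor + eps)
        w_background = n_total / (2.0 * n_background + eps)
        # Per-pixel weight map: the rare tumor pixels get the large weight,
        # the abundant background pixels get the small one.
        weight_map = target * w_tumor + (1.0 - target) * w_background
        return F.binary_cross_entropy_with_logits(logits, target, weight=weight_map)

The weights could also be computed once over the whole training set instead of per batch (see the counting sketch after the directory listing at the end of this post).

My current code is below.

unet.py: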

    import torch.nn as nn
    import torch.nn.functional as F
    import torch
    from numpy.linalg import svd
    from numpy.random import normal
    from math import sqrt


    class UNet(nn.Module):
        def __init__(self,colordim =4):
            super(UNet, self).__init__()
            self.conv1_1 = nn.Conv2d(colordim,32,3,padding=1,stride=1)  # input (N, colordim, H, W) -> output (N, 32, H, W); padding=1 keeps the spatial size
            self.conv1_2 = nn.Conv2d(32,32,3,padding=1,stride=1)
            self.bn1 = nn.BatchNorm2d(32)

            self.conv2_1 = nn.Conv2d(32, 64, 3,padding=1)
            self.conv2_2 = nn.Conv2d(64, 64, 3,padding=1)
            self.bn2 = nn.BatchNorm2d(64)

            self.conv3_1 = nn.Conv2d(64, 128, 3,padding=1)
            self.conv3_2 = nn.Conv2d(128, 128, 3,padding=1)
            self.bn3 = nn.BatchNorm2d(128)

            self.conv4_1 = nn.Conv2d(128, 256, 3,padding=1)
            self.conv4_2 = nn.Conv2d(256, 256, 3,padding=1)
            self.upconv4 = nn.Conv2d(256, 128, 1)
            self.bn4 = nn.BatchNorm2d(128)
            self.bn4_out = nn.BatchNorm2d(256)

            self.conv5_1 = nn.Conv2d(256, 128, 3,padding=1)
            self.conv5_2 = nn.Conv2d(128, 128, 3,padding=1)
            self.upconv5 = nn.Conv2d(128, 64, 1)
            self.bn5 = nn.BatchNorm2d(64)
            self.bn5_out = nn.BatchNorm2d(128)

            self.conv6_1 = nn.Conv2d(128, 64, 3,padding=1)
            self.conv6_2 = nn.Conv2d(64, 64, 3,padding=1)
            self.upconv6 = nn.Conv2d(64, 32, 1)
            self.bn6 = nn.BatchNorm2d(32)
            self.bn6_out = nn.BatchNorm2d(64)

            self.conv7_1 = nn.Conv2d(64, 32, 3,padding=1)
            self.conv7_2 = nn.Conv2d(32, 32, 3,padding=1)
            self.conv7_3 = nn.Conv2d(32, 1, 1)
            self.bn7 = nn.BatchNorm2d(1)

            self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
            self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
            self._initialize_weights()

        def forward(self, x1):
            #x1out
            x1 = F.relu(self.bn1(self.conv1_2(F.relu(self.conv1_1(x1)))))
            #print('x1 size: %d'%(x1.size(2)))
            #x2out
            x2 = F.relu(self.bn2(self.conv2_2(F.relu(self.conv2_1(self.maxpool(x1))))))
            #print('x2 size: %d'%(x2.size(2)))
            #x3out
            x3 = F.relu(self.bn3(self.conv3_2(F.relu(self.conv3_1(self.maxpool(x2))))))
            #print('x3 size: %d'%(x3.size(2)))
            #x4out
            xup = F.relu((self.conv4_2(F.relu(self.conv4_1(self.maxpool(x3))))))
            #print('x4 size: %d'%(xup.size(2)))
            #x5in
            xup = self.bn4(self.upconv4(self.upsample(xup)))
            #print('x5in size: %d'%(xup.size(2)))
            #x5out
            xup = self.bn4_out(torch.cat((x3,xup),1))
            xup = F.relu(self.conv5_2(F.relu(self.conv5_1(xup))))
            #print('x5ou size: %d' % (xup.size(2)))
            #x6in
            xup = self.bn5(self.upconv5(self.upsample(xup)))

            #x6out
            xup = self.bn5_out(torch.cat((x2,xup),1))
            xup = F.relu(self.conv6_2(F.relu(self.conv6_1(xup))))

            #x7in
            xup = self.bn6(self.upconv6(self.upsample(xup)))

            #x7out
            xup = self.bn6_out(torch.cat((x1,xup),1))
            xup = F.relu(self.conv7_3(F.relu(self.conv7_2(F.relu(self.conv7_1(xup))))))
            return F.softsign(self.bn7(xup))


        def _initialize_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, sqrt(2. / n))
                    if m.bias is not None:
                        m.bias.data.zero_()
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()


    unet = UNet().cuda()
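
Another option I have considered (this is only a sketch, not what the code above does): give the last layer two output channels, background and tumor, drop the final softsign/BatchNorm, and use nn.CrossEntropyLoss with per-class weights. The weight values below are placeholders:

    import torch
    import torch.nn as nn

    # Hypothetical variant: conv7_3 would become nn.Conv2d(32, 2, 1) and the
    # network would return raw 2-channel logits of shape (N, 2, H, W).
    class_weights = torch.tensor([0.1, 0.9])     # [background, tumor], placeholder values
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Dummy tensors, just to show the expected shapes and dtypes:
    logits = torch.randn(3, 2, 240, 240)         # (N, classes, H, W)
    labels = torch.randint(0, 2, (3, 240, 240))  # (N, H, W) LongTensor: 0 = background, 1 = tumor
    loss = criterion(logits, labels)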

train.py:

#-*- coding:utf-8 -*-
from __future__ import print_function
from math import log10
import numpy as np
import random
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from unet import UNet
from load_data import get_train_img,get_test_img
import torchvision


# Training settings

class option:
    def __init__(self):
        self.cuda = True #use cuda?
        self.batchSize = 3 #training batch size
        self.testBatchSize = 3 #testing batch size
        self.nEpochs = 5 #number of epochs to train for
        self.lr = 0.001 #learning rate
        self.threads = 0 #number of threads for data loader to use
        self.seed = 123 #random seed to use. Default=123
        self.size = 240 # image size
        self.colordim = 4 #number of input channels (Flair, T1, T1C, T2)
        self.pretrain_net = 'F:\\image\\100data\\model\\model_epoch_140.pth'

def map01(tensor, eps=1e-5):
    # rescale each sample of a (N, C, H, W) tensor to the [0, 1] range
    arr = tensor.numpy()
    arr_max = np.max(arr, axis=(1, 2, 3), keepdims=True)
    arr_min = np.min(arr, axis=(1, 2, 3), keepdims=True)
    # eps keeps the division well-defined when a sample is constant
    return torch.from_numpy((arr - arr_min) / (arr_max - arr_min + eps))

opt = option()

cuda = opt.cuda
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")

torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)
print('===>Use the default training, test data')
print('===> Loading datasets')
train_set = get_train_img(opt.size)
test_set = get_test_img(opt.size)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
#torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
    #  num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

print('===> Building unet')
unet = UNet()

criterion = nn.MSELoss()
if cuda:
    unet = unet.cuda()
    criterion = criterion.cuda()

pretrained = False
if pretrained:
    unet.load_state_dict(torch.load(opt.pretrain_net))

optimizer = optim.SGD(unet.parameters(), lr=opt.lr)
print('===> Training unet')
print(training_data_loader)
def train(epoch):
    epoch_loss = 0
    for iteration, batch in enumerate(training_data_loader, 1):
        #print('train\' iteration and batch'+str(type(iteration))+' '+str(type(batch)))
        #print()
        input = Variable(batch[0])
        target = Variable(batch[1])
        filename = batch[2]
        #target =target.squeeze(1)
        #print(target.data.size())
        if cuda:
            input = input.cuda()
            target = target.cuda()
        optimizer.zero_grad()               # clear gradients left over from the previous iteration
        prediction = unet(input)

        loss = criterion(prediction, target)
        epoch_loss += loss.data[0]
        loss.backward()
        optimizer.step()
        if iteration % 10 == 0:
            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), loss.data[0]))

    imgout = prediction.data
    torchvision.utils.save_image(imgout,"F:\\image\\100data\\output\\train_out\\"+str(epoch)+'.png',padding=0)
    print("===> Epoch {} Complete: Avg. Loss: {}".format(epoch, epoch_loss / len(training_data_loader)))
    print('train ok')
def test(epoch):
    for iteration, batch in enumerate(testing_data_loader):
        input = Variable(batch[0])
        filename = batch[1]
        if cuda:
            input = input.cuda()

        # run the forward pass outside the `if cuda:` block so that
        # `prediction` is also defined when running on the CPU
        prediction = unet(input)

        imgout = prediction.data
        torchvision.utils.save_image(imgout,"F:\\image\\100data\\output\\test_out\\"+str(epoch)+'.png',padding=0)


def checkpoint(epoch):
    model_out_path = "F:\\image\\100data\\model\\model_epoch_{}.pth".format(epoch)
    torch.save(unet.state_dict(), model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

for epoch in range(1, opt.nEpochs + 1):
    train(epoch)
    if epoch % 10 == 0:
        checkpoint(epoch)
    test(epoch)
checkpoint(epoch)
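
The smallest change to train.py that I can think of would be to replace nn.MSELoss with a loss that already supports class weighting, for example nn.BCEWithLogitsLoss with pos_weight. Again only a sketch with dummy tensors: it assumes the network returns a single-channel logit instead of the softsign output, and the value 20.0 is a placeholder that would normally come from the background/tumor pixel ratio:

    import torch
    import torch.nn as nn

    # pos_weight > 1 makes the positive (tumor) pixels count more than the
    # background pixels; the value here is only illustrative.
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([20.0]))

    logits = torch.randn(3, 1, 240, 240)                  # network output, no softsign
    target = (torch.rand(3, 1, 240, 240) > 0.95).float()  # sparse binary tumor mask
    loss = criterion(logits, target)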

load_data.py:

#-*- coding:utf-8 -*-
from os.path import exists,join
from os import listdir
import numpy as np
from PIL import Image
import torch.utils.data as data
from skimage import io
from torchvision.transforms import Compose, CenterCrop, ToTensor, Scale,ToPILImage


def brats2015(dest = r'F:\image'):
    if not exists(dest):
        print('Sorry, the dataset does not exist')
        print('please check the file path')
    return dest

def input_transform(crop_size):
    return Compose([
        #CenterCrop(crop_size),
        ToTensor()
    ])
def get_train_img(size,train_path = r'H:\deepfortest\2Dpng'):
    brats_train = brats2015(train_path)
    train_num = join(brats_train,'T1')
    return DatasetFromFolder(image_dir = train_num,
                             image_path = train_path,
                             input_transform = input_transform(size),
                             target_transform = input_transform(size),
                             OT_point=True,
                            )

def get_test_img(size,test_path = r'F:\image\100data\test'):
    brats_test = brats2015(test_path)
    test_num = join(brats_test,'T1')
    return DatasetFromFolder(image_dir = test_num,
                             image_path = test_path,
                             input_transform = input_transform(size),
                             target_transform = input_transform(size),
                             OT_point=False,
                            )

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg"])

def load_img(image_path, name, label_bool=False):
    if label_bool is False:
        Flair_path1 = join(join(image_path,'Flair'),name)
        T1_path2 = join(join(image_path,'T1'),name)
        T1C_path3 = join(join(image_path,'T1C'),name)
        T2_path4 = join(join(image_path,'T2'),name)
        Flair = io.imread(Flair_path1)
        T1 = io.imread(T1_path2)
        T1C = io.imread(T1C_path3)
        T2 = io.imread(T2_path4)
        B = np.dstack((Flair,T1,T1C,T2))
    if label_bool is True:
        path = join(join(image_path,'OT'),name)
        B = Image.open(path)
        #B = np.array(A)
    return B
#image_dir = r'F:\image\train\T1',image_path = r'F:\image\train'
class DatasetFromFolder(data.Dataset):
    def __init__(self, image_dir, image_path, input_transform=None, target_transform=None,OT_point = True):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [x for x in listdir(image_dir) if is_image_file(x)]
        self.input_transform = input_transform
        self.target_transform = target_transform
        self.image_dir = image_dir
        self.image_path = image_path
        self.OT_point = OT_point

    def __getitem__(self, index):
        input = load_img(self.image_path, self.image_filenames[index])
        if self.input_transform:
            input = self.input_transform(input)
        if self.OT_point:
            target = load_img(self.image_path, self.image_filenames[index], label_bool=True)
            if self.target_transform:
                target = self.target_transform(target)
            return input, target, self.image_filenames[index]
        return input, self.image_filenames[index]

    def __len__(self):
        return len(self.image_filenames)

My directory layout:

----image
------train
--------T1
--------T1C
--------T2
--------Flair
--------OT
------test
--------T1
--------T1C
--------T2
--------Flair
------model
------output
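
To get the counts I mentioned at the top (how often background occurs over the whole training set), I imagine scanning the OT folder once before training, roughly like this. The path is illustrative and follows the layout above, and every non-zero label pixel is treated as tumor:

    import numpy as np
    from os import listdir
    from os.path import join
    from skimage import io

    def count_class_pixels(ot_dir):
        # Count background (label == 0) and tumor (label != 0) pixels
        # over all label images in the OT folder.
        n_background, n_tumor = 0, 0
        for name in listdir(ot_dir):
            if not name.endswith('.png'):
                continue
            label = io.imread(join(ot_dir, name))
            nonzero = int(np.count_nonzero(label))
            n_tumor += nonzero
            n_background += label.size - nonzero
        return n_background, n_tumor

    # Illustrative path matching the directory layout above:
    n_background, n_tumor = count_class_pixels(r'F:\image\train\OT')
    total = n_background + n_tumor
    print('background weight:', total / (2.0 * max(n_background, 1)))
    print('tumor weight:', total / (2.0 * max(n_tumor, 1)))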
