Потери не уменьшаются для двоичной классификации
Я пытаюсь реализовать двоичную классификацию. У меня есть набор данных изображения 100K (3 канала, 224 x 224 пикселей с предварительно измененным размером), который я пытаюсь обучить модели, если изображение безопасно для работы или нет. Я инженер по обработке данных со статистическим образованием, поэтому работаю над моделью последние 5-10 дней. Я пытался реализовать решения на основе предложений, но, к сожалению, потери не уменьшились.
Вот класс, реализованный с помощью PyTorch Lightning,
from .dataset import CloudDataset
from .split import DatasetSplit
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics import Accuracy
from torch import stack
from torch.nn import BCEWithLogitsLoss, Conv2d, Dropout, Linear, MaxPool2d, ReLU
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision.transforms import ToTensor
from util import logger
from util.config import config
class ClassifyModel(LightningModule):
def __init__(self):
super(ClassifyModel, self).__init__()
# custom dataset split class
ds = DatasetSplit(config.s3.bucket, config.train.ratio)
# split records for train, validation and test
self._train_itr, self._valid_itr, self._test_itr = ds.split()
self.conv1 = Conv2d(3, 32, 3, padding=1)
self.conv2 = Conv2d(32, 64, 3, padding=1)
self.conv3 = Conv2d(64, 64, 3, padding=1)
self.pool = MaxPool2d(2, 2)
self.fc1 = Linear(7 * 28 * 64, 512)
self.fc2 = Linear(512, 16)
self.fc3 = Linear(16, 4)
self.fc4 = Linear(4, 1)
self.dropout = Dropout(0.25)
self.relu = ReLU(inplace=True)
self.accuracy = Accuracy()
def forward(self, x):
# comments are shape before execution
# [32, 3, 224, 224]
x = self.pool(self.relu(self.conv1(x)))
# [32, 32, 112, 112]
x = self.pool(self.relu(self.conv2(x)))
# [32, 64, 56, 56]
x = self.pool(self.relu(self.conv3(x)))
# [32, 64, 28, 28]
x = self.pool(self.relu(self.conv3(x)))
# [32, 64, 14, 14]
x = self.dropout(x)
# [32, 64, 14, 14]
x = x.view(-1, 7 * 28 * 64)
# [32, 12544]
x = self.relu(self.fc1(x))
# [32, 512]
x = self.relu(self.fc2(x))
# [32, 16]
x = self.relu(self.fc3(x))
# [32, 4]
x = self.dropout(self.fc4(x))
# [32, 1]
x = x.squeeze(1)
# [32]
return x
def configure_optimizers(self):
return Adam(self.parameters(), lr=0.001)
def training_step(self, batch, batch_idx):
image, target = batch
target = target.float()
output = self.forward(image)
loss = BCEWithLogitsLoss()
output = loss(output, target)
logits = self(image)
self.accuracy(logits, target)
return {'loss': output}
def validation_step(self, batch, batch_idx):
image, target = batch
target = target.float()
output = self.forward(image)
loss = BCEWithLogitsLoss()
output = loss(output, target)
return {'val_loss': output}
def collate_fn(self, batch):
batch = list(filter(lambda x: x is not None, batch))
return default_collate(batch)
def train_dataloader(self):
transform = ToTensor()
workers = 0 if config.train.test else config.train.workers
# custom data set class that read files from s3
cds = CloudDataset(config.s3.bucket, self._train_itr, transform)
return DataLoader(
dataset=cds,
batch_size=32,
shuffle=True,
num_workers=workers,
collate_fn=self.collate_fn,
)
def val_dataloader(self):
transform = ToTensor()
workers = 0 if config.train.test else config.train.workers
# custom data set class that read files from s3
cds = CloudDataset(config.s3.bucket, self._valid_itr, transform)
return DataLoader(
dataset=cds,
batch_size=32,
num_workers=workers,
collate_fn=self.collate_fn,
)
def test_dataloader(self):
transform = ToTensor()
workers = 0 if config.train.test else config.train.workers
# custom data set class that read files from s3
cds = CloudDataset(config.s3.bucket, self._test_itr, transform)
return DataLoader(
dataset=cds,
batch_size=32,
shuffle=True,
num_workers=workers,
collate_fn=self.collate_fn,
)
def validation_epoch_end(self, outputs):
avg_loss = stack([x['val_loss'] for x in outputs]).mean()
logger.info(f'Validation loss is {avg_loss}')
def training_epoch_end(self, outs):
accuracy = self.accuracy.compute()
logger.info(f'Training accuracy is {accuracy}')
Вот пользовательский вывод журнала,
epoch 0
Validation loss is 0.5988735556602478
Training accuracy is 0.4441356360912323
epoch 1
Validation loss is 0.6406065225601196
Training accuracy is 0.4441356360912323
epoch 2
Validation loss is 0.621654748916626
Training accuracy is 0.443579763174057
epoch 3
Validation loss is 0.5089989304542542
Training accuracy is 0.4580322504043579
epoch 4
Validation loss is 0.5484663248062134
Training accuracy is 0.4886047840118408
epoch 5
Validation loss is 0.5552918314933777
Training accuracy is 0.6142301559448242
epoch 6
Validation loss is 0.661466121673584
Training accuracy is 0.625903308391571
Проблема может быть связана с оптимизатором или функцией потерь, но я не мог понять.