как создать расширенный набор данных в pythorch
Я должен добавить к исходному набору данных CIFAR для каждого изображения соответствующие, повернутые на 90 градусов. Идея состоит в том, чтобы создать RotationDateset, класс, который расширяет наборы данных.VisionDataset, который принимает CIFAR и выполняет то, что описано выше.
from __future__ import print_function, division
import skimage.io
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision.datasets import ImageFolder
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from sklearn.model_selection import train_test_split
import copy
import cv2
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import ResNet
from PIL import Image
import xml.etree.ElementTree as ET
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torchvision.models.resnet import model_urls
// org_dataset - это CIFAR // num_rots равно 4 // transforms is transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5,0.5), (0.5, 0.5, 0.5))])
class RotDataset(datasets.VisionDataset):
def __init__(self, org_dataset, transforms, num_rots):
self.samples = org_dataset.data
self.targets = []
self.num_rots = num_rots
self.transforms = transforms
for k in self.samples:
self.targets.append(k)
for i in range(0, self.num_rots):
tr = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(degrees=90*i),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# from PIL import Image
p_i = Image.fromarray(k)
te = tr(p_i)
r_im = torch.reshape(te, (k.shape))
r_im = np.array(r_im)
self.targets.append(r_im)
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
imgs = self.targets[index:index + self.num_rots]
labels = list(range(0, self.num_rots))
return imgs, labels
вот как я изначально импортирую и преобразую CIFAR:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
вот как я создаю расширенный CIFAR:
cifar_rot = RotDataset(trainset, trainset.transforms, 4)
rot_train, rot_val= train_test_split(
np.arange(len(cifar_rot.targets)),
test_size=0.2,
shuffle=True,
)
train_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_train)
val_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_val)
dataloaders_rot = {'train': torch.utils.data.DataLoader(cifar_rot, batch_size=8, sampler=train_sampler_rot)
, 'val':torch.utils.data.DataLoader(cifar_rot, batch_size=8, sampler=val_sampler_rot)}
sizes_rot = {'train':len(rot_train)*4,'val':len(rot_val)*4}
и модельное обучение
model_rot = torchvision.models.resnet34(pretrained=False)
num_ftrs = model_rot.fc.in_features
output_dim_rot = 4 # since are 4 rotations
model_rot.fc = nn.Linear(num_ftrs, output_dim_rot)
model_rot = model_rot.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_conv = optim.SGD(model_rot.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_rot = train_model(model_rot,
criterion,
optimizer_conv,
exp_lr_scheduler,
dataloaders_rot,
sizes_rot,
num_epochs=10)
torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')
// проблема в том, что когда я запускаю модель, pythorch выдает эту ошибку:
Epoch 0/9
----------
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-61-977dbbbef6fe> in <module>()
23 dataloaders_rot,
24 sizes_rot,
---> 25 num_epochs=10)
26 #Save the best trained model, for later use
27 torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')
6 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
394 _pair(0), self.dilation, self.groups)
395 return F.conv2d(input, weight, bias, self.stride,
--> 396 self.padding, self.dilation, self.groups)
397
398 def forward(self, input: Tensor) -> Tensor:
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[32, 32, 32, 3] to have 3 channels, but got 32 channels instead
кто-нибудь может мне помочь? заранее спасибо
1 ответ
проблема возникает из-за того, что вы полагаетесь на
org_dataset.data
, который представляет собой большой массив формы
(N, 32, 32, 3)
(где бы вы хотели, чтобы это было
(N, 3, 32, 32)
)
Итак, с линией
self.targets.append(k)
, вы добавили неправильные формы в свой список целей. Тогда тензор
te
имеет правильную форму (спасибо), но вы изменяете его до неправильной формы через линию после
Также хочу отметить, что случайные преобразования, такие как
RandomRotation
обычно применяются в
__getitem__
метод, а не в
__init__
. Поскольку в этих преобразованиях происходит генерация случайных чисел, вы хотите, чтобы новые выборки создавались каждую эпоху, чтобы иметь практически бесконечный набор данных и выборки. На самом деле я не уверен, что вы понимаете, что делает RandomRotation: он вращает входной тензор со случайным вращением, для которого вы только указываете диапазон возможных углов. Так что вполне возможно, что применение «поворота» параметра 180 (i=2) даст почти неизменный тензор. Я вижу, вы пытаетесь предсказать ценность
i
после этого, скорее всего, не сработает. этого вы можете использовать Вместоtorch.rot90 .
В дополнение к этому, поскольку вы уже подаете заявку
ToTensor
а также
Normalize
в
RotDataset
, они вам точно не нужны
CIFAR10
.
Последний комментарий: я правда не понимаю, зачем вы хотите
__getitem
чтобы вернуть список тензоров (и меток). Я сохраню это в приведенном ниже коде, но похоже, что в конечном итоге это что-то сломает.
Итак, вот как бы вы исправили свой код:
class RotDataset(datasets.VisionDataset):
def __init__(self, org_dataset, transforms, num_rots):
# Let's buffer the underlying dataset, we will sample
# from it on the fly
self.dataset = org_dataset
self.num_rots = num_rots
# You did not use this attribute previously, probably a mistake
# It will now be applied in the __getitem__
self.transforms = transforms
def __len__(self):
# Typical front dataset : size is the same as the
# underlying dataset size
return len(self.dataset)
def __getitem__(self, index):
# sampling from CIFAR10
sample = self.dataset[index]
# Because you want to return a list
imgs = []
for i in range(0, self.num_rots):
# Creating the corresponding rotation
rotation = torchvision.transforms.RandomRotation(degrees=90*i)
# Applying rotation, followed by other transforms (toTensor, Normalize...)
transformed = self.transform(rotation(sample))
imgs.append(transformed)
# Cleaner way to generate your range :
labels = np.arange(self.num_rots)
return imgs, labels
# transform=None, since we will apply them in RotDataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=None)
# The transforms to call in RotDataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
cifar_rot = RotDataset(trainset, transform, 4)
# using torch's random split to remove dependency on sklearn
from torch.utils.data import random_split
test_size = 0.2*len(cifar_rot)
rot_train, rot_val= random_split(cifar_rot, [len(cifar_rot)-test_size, test_size])