как создать расширенный набор данных в pythorch

Я должен добавить к исходному набору данных CIFAR для каждого изображения соответствующие, повернутые на 90 градусов. Идея состоит в том, чтобы создать RotationDateset, класс, который расширяет наборы данных.VisionDataset, который принимает CIFAR и выполняет то, что описано выше.

      from __future__ import print_function, division
import skimage.io

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision.datasets import ImageFolder
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from sklearn.model_selection import train_test_split
import copy
import cv2
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import ResNet
from PIL import Image
import xml.etree.ElementTree as ET
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torchvision.models.resnet import model_urls

// org_dataset - это CIFAR // num_rots равно 4 // transforms is transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5,0.5), (0.5, 0.5, 0.5))])

      class RotDataset(datasets.VisionDataset):
    def __init__(self, org_dataset, transforms, num_rots):
        
        self.samples = org_dataset.data
        self.targets = []
        self.num_rots = num_rots
        self.transforms = transforms

        for k in self.samples:
          self.targets.append(k)

          for i in range(0, self.num_rots):
            tr = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(degrees=90*i),
                                                 torchvision.transforms.ToTensor(),
                                                 torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
            # from PIL import Image
            p_i = Image.fromarray(k)
            te = tr(p_i)
            r_im = torch.reshape(te, (k.shape))
            r_im = np.array(r_im)
            self.targets.append(r_im)
            
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
      imgs = self.targets[index:index + self.num_rots]
      labels = list(range(0, self.num_rots))

      return imgs, labels

вот как я изначально импортирую и преобразую CIFAR:

      transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)


testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

вот как я создаю расширенный CIFAR:

      cifar_rot = RotDataset(trainset, trainset.transforms, 4)

rot_train, rot_val= train_test_split(
np.arange(len(cifar_rot.targets)),
test_size=0.2,
shuffle=True,
)

train_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_train)
val_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_val)

dataloaders_rot = {'train': torch.utils.data.DataLoader(cifar_rot, batch_size=8, sampler=train_sampler_rot)
               , 'val':torch.utils.data.DataLoader(cifar_rot, batch_size=8, sampler=val_sampler_rot)}

sizes_rot = {'train':len(rot_train)*4,'val':len(rot_val)*4}

и модельное обучение

      model_rot = torchvision.models.resnet34(pretrained=False) 

num_ftrs = model_rot.fc.in_features
output_dim_rot = 4 # since are 4 rotations

model_rot.fc = nn.Linear(num_ftrs, output_dim_rot)

model_rot = model_rot.to(device)
criterion = nn.CrossEntropyLoss()

optimizer_conv = optim.SGD(model_rot.parameters(), lr=0.001, momentum=0.9)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_rot = train_model(model_rot,
                        criterion,
                        optimizer_conv,
                        exp_lr_scheduler,
                        dataloaders_rot,
                        sizes_rot,
                        num_epochs=10)

torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')

// проблема в том, что когда я запускаю модель, pythorch выдает эту ошибку:

      Epoch 0/9
----------
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-61-977dbbbef6fe> in <module>()
     23                         dataloaders_rot,
     24                         sizes_rot,
---> 25                         num_epochs=10)
     26 #Save the best trained model, for later use
     27 torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')

6 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    394                             _pair(0), self.dilation, self.groups)
    395         return F.conv2d(input, weight, bias, self.stride,
--> 396                         self.padding, self.dilation, self.groups)
    397 
    398     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[32, 32, 32, 3] to have 3 channels, but got 32 channels instead

кто-нибудь может мне помочь? заранее спасибо

1 ответ

проблема возникает из-за того, что вы полагаетесь на org_dataset.data, который представляет собой большой массив формы (N, 32, 32, 3) (где бы вы хотели, чтобы это было (N, 3, 32, 32))

Итак, с линией self.targets.append(k), вы добавили неправильные формы в свой список целей. Тогда тензор te имеет правильную форму (спасибо), но вы изменяете его до неправильной формы через линию после

Также хочу отметить, что случайные преобразования, такие как RandomRotation обычно применяются в __getitem__ метод, а не в __init__. Поскольку в этих преобразованиях происходит генерация случайных чисел, вы хотите, чтобы новые выборки создавались каждую эпоху, чтобы иметь практически бесконечный набор данных и выборки. На самом деле я не уверен, что вы понимаете, что делает RandomRotation: он вращает входной тензор со случайным вращением, для которого вы только указываете диапазон возможных углов. Так что вполне возможно, что применение «поворота» параметра 180 (i=2) даст почти неизменный тензор. Я вижу, вы пытаетесь предсказать ценность iпосле этого, скорее всего, не сработает. этого вы можете использовать Вместоtorch.rot90 .

В дополнение к этому, поскольку вы уже подаете заявку ToTensor а также Normalize в RotDataset, они вам точно не нужны CIFAR10.

Последний комментарий: я правда не понимаю, зачем вы хотите __getitemчтобы вернуть список тензоров (и меток). Я сохраню это в приведенном ниже коде, но похоже, что в конечном итоге это что-то сломает.

Итак, вот как бы вы исправили свой код:

      class RotDataset(datasets.VisionDataset):
    def __init__(self, org_dataset, transforms, num_rots):
    
        # Let's buffer the underlying dataset, we will sample   
        # from it on the fly
        self.dataset = org_dataset
        self.num_rots = num_rots
        # You did not use this attribute previously, probably a mistake
        # It will now be applied in the __getitem__
        self.transforms = transforms
        
    def __len__(self):
        # Typical front dataset : size is the same as the 
        # underlying dataset size
        return len(self.dataset)

    def __getitem__(self, index):
        # sampling from CIFAR10
        sample = self.dataset[index]
        # Because you want to return a list
        imgs = []
        for i in range(0, self.num_rots):
            # Creating the corresponding rotation
            rotation = torchvision.transforms.RandomRotation(degrees=90*i)
            # Applying rotation, followed by other transforms (toTensor, Normalize...)
            transformed = self.transform(rotation(sample))
            imgs.append(transformed)

        # Cleaner way to generate your range : 
        labels = np.arange(self.num_rots)

        return imgs, labels

# transform=None, since we will apply them in RotDataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=None)
# The transforms to call in RotDataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
cifar_rot = RotDataset(trainset, transform, 4)

# using torch's random split to remove dependency on sklearn
from torch.utils.data import random_split
test_size = 0.2*len(cifar_rot)
rot_train, rot_val= random_split(cifar_rot, [len(cifar_rot)-test_size, test_size])
Другие вопросы по тегам