Создайте набор данных тестирования pyTorch (без ярлыков)

Я создал набор данных pyTorch для моих обучающих данных, который состоит из функций и метки, позволяющей использовать pyTorch DataLoader с помощью этого руководства. Это хорошо работает для моих данных обучения, но я получаю сообщение об ошибке ( KeyError: "['label'] not found in axis") при загрузке тестового CSV-файла, который идентичен, за исключением отсутствия столбца «label».

Если это помогает, предполагаемый входной файл csv - это данные MNIST в файле csv, который имеет 28*28 столбцов функций.

      import torch

class mnist(torch.utils.data.Dataset):
    
    def __init__(self, csv_file):
        self.train = pd.read_csv(csv_file)
        self.train_x = self.train.drop("label", axis=1)
    
    def __len__(self):
        return len(self.train)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if isinstance(idx, list):
            idx_len = len(idx)
        else:
            idx_len = 1
        
        X = np.asarray(self.train_x.iloc[idx], dtype=np.float32)
        X = np.reshape(X, (1,28,28))
        y = np.asarray(self.train.iloc[idx]['label'])
        
        sample = {'X': X, 'y':y}
        
        return torch.from_numpy(sample['X']), torch.from_numpy(sample['y'])

1 ответ

Вы должны иметь возможность использовать оба данных:

      import torch

class mnist(torch.utils.data.Dataset):
    
    def __init__(self, csv_file):
        self.train = pd.read_csv(csv_file)

        self.training = "label" in self.train.columns
        self.train_x = self.train if not self.training else self.train.drop("label", axis=1)
    
    def __len__(self):
        return len(self.train)
    
    def __getitem__(self, idx):
        ...
        
        X = np.asarray(self.train_x.iloc[idx], dtype=np.float32)
        X = np.reshape(X, (1,28,28))
        if not self.training:
            return torch.from_numpy(X])

        y = np.asarray(self.train.iloc[idx]['label'])

        sample = {'X': X, 'y':y}
        return torch.from_numpy(sample['X']), torch.from_numpy(sample['y'])
Другие вопросы по тегам