Pytorch: «KeyError: обнаружена ошибка KeyError в рабочем процессе DataLoader 0.»

Описание проблемы:

Я пытаюсь загрузить данные изображения с помощью пользовательского набора данных Pytorch. Я немного углубился и обнаружил, что мой набор изображений состоит из двух типов форм (512,512,3) и (1024,1024). Мое предположение заключается в том, что по вышеуказанной причине возникает ошибка ниже.

Примечание. Код может читать некоторые изображения, но для некоторых из них он выдает сообщение об ошибке, приведенное ниже. Это послужило причиной для проведения небольшого EDA для данных изображения и обнаружило, что в наборе данных было 2 разных формы изображений.

Q1. Как предварительно обработать такие данные изображения для обучения?

Q2. Есть ли другие причины, по которым я могу видеть сообщение об ошибке ниже?

Сообщение об ошибке:

      KeyError                                  Traceback (most recent call last)
<ipython-input-163-aa3385de8026> in <module>
----> 1 train_features, train_labels = next(iter(train_dataloader))
  2 print(f"Feature batch shape: {train_features.size()}")
  3 print(f"Labels batch shape: {train_labels.size()}")
  4 img = train_features[0].squeeze()
  5 label = train_labels[0]

 ~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils  /data/dataloader.py in __next__(self)
519             if self._sampler_iter is None:
520                 self._reset()
521             data = self._next_data()
522             self._num_yielded += 1
523             if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
1201             else:
1202                 del self._task_info[idx]
1203                 return self._process_data(data)
1204 
1205     def _try_put_index(self):

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1227         self._try_put_index()
1228         if isinstance(data, ExceptionWrapper):
1229             data.reraise()
1230         return data
1231 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_utils.py in reraise(self)
423             # have message field
424             raise self.exc_type(message=msg)
425         raise self.exc_type(msg)
426 
427 

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas  /core/indexes/base.py", line 2898, in get_loc
return self._engine.get_loc(casted_key)
File "pandas/_libs/index.pyx", line 70, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/index.pyx", line 101, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/hashtable_class_helper.pxi", line 1032, in    pandas._libs.hashtable.Int64HashTable.get_item
File "pandas/_libs/hashtable_class_helper.pxi", line 1039, in   pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 16481

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-161-f38b78d77dcb>", line 19, in __getitem__
img_path =os.path.join(self.img_dir,self.image_ids[idx])
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/series.py", line 882, in __getitem__
return self._get_value(key)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/series.py", line 990, in _get_value
loc = self.index.get_loc(label)
File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 2900, in get_loc
raise KeyError(key) from err
KeyError: 16481

Код:

      from torchvision.io import read_image
import torch
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
     # init
    def __init__(self,dataset,transforms=None,target_transforms=None):
        #self.train_data = pd.read_csv("Data/train_data.csv")
        self.image_ids = dataset.image_id
        self.image_labels = dataset.label
        self.img_dir = 'Data/images'
        self.transforms = transforms
        self.target_transforms = target_transforms
# len
    def __len__(self):
        return len(self.image_ids)
# getitem
    def __getitem__(self,idx):
        # image path
        img_path =os.path.join(self.img_dir,self.image_ids[idx])
        # image
        image = read_image(img_path)
        label = self.image_labels[idx]
    # transform image
        if self.transforms:
             image = self.transforms(image)
    # transform target
        if self.target_transforms:
             label = self.target_transforms(label)
    return image, label

Код: train_data- это объект pandas файла csv, который имеет идентификатор изображения, информацию о ярлыке.

        from sklearn.model_selection import train_test_split
  X_train, X_test = train_test_split(train_data, test_size=0.1, random_state=42)
  train_df = CustomImageDataset(X_train)
  train_dataloader = torch.utils.data.DataLoader(
        train_df,
        batch_size=64,
        num_workers=1,
        shuffle=True,
    )

3 ответа

обнаружил проблему с кодом.

Пользовательская функция загрузчика данных Pytorch " getitem " использует idx для извлечения данных, и я предполагаю, что она знает диапазон idx от функции len, например: 0, до len (строк в наборе данных).

В моем случае у меня уже был набор данных панды (train_data) с idx в качестве одного из столбцов. Когда я случайным образом разделил его на X_train и X_test, несколько строк данных были перемещены в X_test вместе с idx.

Теперь, когда я отправляю X_train пользовательскому загрузчику данных, он пытается получить image_id строки с idx, и этот idx просто случайно находится в наборе данных X_test. Это приводит к ошибке как keyerror: 16481, т.е. строка с idx=16481 отсутствует в наборе данных X_train. Он был перенесен в X_test во время разделения.

уф ...

Я получил ту же ошибку при тонкой настройке модели на основе трансформаторов DistilBertModel в PyTorch при замене ее головы .

Я забыл сбросить индексы train_dataframe и test_dataframe после train_test_split, что вызвало мой CustomDatasetнеправильно индексировать.

Проблема решена.

Я не использовал метод преобразования Pytorch, и, следовательно, разные размеры изображений создавали проблемы. Я использовал приведенный ниже код, и он решил эту проблему.

      train_transforms = transforms.Compose([
        transforms.Resize((244,244)),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
         ])



train_df = CustomImageDataset(X_train,train_transforms)

train_dataloader = torch.utils.data.DataLoader(
        train_df,
        batch_size=64,
        num_workers=1,
        shuffle=True,
    )
Другие вопросы по тегам