пары (изображение, маска) не соответствуют друг другу в задаче семантической сегментации

Я пишу простой пользовательский DataLoader (в который позже я добавлю дополнительные функции) для набора данных сегментации, но пара (изображение, маска), которую я возвращаю, используя __getitem()__метод разные; возвращенная маска принадлежит другому изображению, чем то, которое было возвращено. Моя структура каталогов /home/bohare/data/images и /home/bohare/data/masks.

Вот код, который у меня есть:

import torch
from torch.utils.data.dataset import Dataset
from PIL import Image
import glob
import os
import matplotlib.pyplot as plt

class CustomDataset(Dataset):
    def __init__(self, folder_path):
        
        self.img_files = glob.glob(os.path.join(folder_path,'images','*.png'))
        self.mask_files = glob.glob(os.path.join(folder_path,'masks','*.png'))
    
    def __getitem__(self, index):
        
        image = Image.open(self.img_files[index])
        mask = Image.open(self.mask_files[index])
        
        return image, mask
    
    def __len__(self):
        return len(self.img_files)
data = CustomDataset(folder_path = '/home/bohare/data')
len(data)

Этот код правильно показывает общий размер набора данных.

Но когда я использую:img, msk = data.__getitem__(n) где n - индекс любой пары (изображение, маска), и я рисую изображение и маску, они не соответствуют друг другу.

Как я могу изменить / что добавить в код, чтобы убедиться, что пара (изображение, маска) возвращается правильно? Спасибо за помощь.

1 ответ

Решение

glob.glob возвращает его без заказа, glob.glob звонки внутри os.listdir:

os.listdir (путь) Возвращает список, содержащий имена записей в каталоге, заданном путем. Список находится в произвольном порядке. Он не включает специальные записи '.' и "..", даже если они присутствуют в каталоге.

Чтобы решить эту проблему, вы можете просто отсортировать оба, чтобы порядок был одинаковым:

self.img_files = sorted(glob.glob(os.path.join(folder_path,'images','*.png')))
self.mask_files = sorted(glob.glob(os.path.join(folder_path,'masks','*.png')))
Другие вопросы по тегам