Получить len(набор данных) = 0 в проблеме обнаружения объекта
Решаю задачу обнаружения объектов в наборе данных фруктов: https://yadi.sk/d/UPwQB7OZrB48qQ. Мне дали код для моего класса набора данных:
class2tag = {"apple": 1, "orange": 2, "banana": 3}
class FruitDataset(Dataset):
def __init__(self, data_dir, transform=None):
self.images = []
self.annotations = []
self.transform = transform
for annotation in glob.glob(data_dir + "/*xml"):
image_fname = os.path.splitext(annotation)[0] + ".jpg"
self.images.append(cv2.cvtColor(cv2.imread(image_fname), cv2.COLOR_BGR2RGB))
with open(annotation) as f:
annotation_dict = xmltodict.parse(f.read())
bboxes = []
labels = []
objects = annotation_dict["annotation"]["object"]
if not isinstance(objects, list):
objects = [objects]
for obj in objects:
bndbox = obj["bndbox"]
bbox = [bndbox["xmin"], bndbox["ymin"], bndbox["xmax"], bndbox["ymax"]]
bbox = list(map(int, bbox))
bboxes.append(torch.tensor(bbox))
labels.append(class2tag[obj["name"]])
self.annotations.append(
{"boxes": torch.stack(bboxes).float(), "labels": torch.tensor(labels)}
)
def __getitem__(self, i):
if self.transform:
# the following code is correct if you use albumentations
# if you use torchvision transforms you have to modify it =)
res = self.transform(
image=self.images[i],
bboxes=self.annotations[i]["boxes"],
labels=self.annotations[i]["labels"],
)
return res["image"], {
"boxes": torch.tensor(res["bboxes"]),
"labels": torch.tensor(res["labels"]),
}
else:
return self.images[i], self.annotations[i]
def __len__(self):
return len(self.images)
Я делаю свой проект в Google Colab, поэтому смонтировал Google Drive и распаковал архив.
from google.colab import drive
drive.mount('/content/drive')
Затем я сделал несколько аугментаций альбументациями:
train_transform = A.Compose([
A.Flip(p=0.25),
A.RGBShift(p=0.2),
], bbox_params=A.BboxParams(format='coco'))
val_transform = A.Compose([], bbox_params=A.BboxParams(format='coco'))
train_dataset = FruitDataset("./train_zip/train", transform=train_transform)
val_dataset = FruitDataset("./test_zip/test", transform=val_transform)
Однако когда я бегу
len(train_dataset)
, Я получаю значение 0. Итак, я не могу понять, почему размер моего набора данных равен 0. Я также не могу понять, в чем проблема. Буду очень признателен за любой возможный совет.