Отображение изображений, загруженных с помощью загрузчика данных pytorch

Я работаю с некоторыми изображениями данных лидара, которые я не могу публиковать здесь из-за ограничения репутации при публикации изображений. Однако при загрузке одних и тех же изображений с использованием pytorch ImageFolder и Dataloader с единственным преобразованием, которое преобразует изображения в тензоры, кажется, что существует какое-то экстремальное пороговое значение, и я не могу найти причину этого.

Ниже показано, как я показываю первое изображение:

      dataset = gdal.Open(dir)

print(dataset.RasterCount)
img = dataset.GetRasterBand(1).ReadAsArray() 

f = plt.figure() 
plt.imshow(img) 
print(img.shape)
plt.show() 

и вот как я использую загрузчик данных и отображаю пороговое изображение:

      data_transforms = {
        'train': transforms.Compose([
            transforms.ToTensor(),
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
        ]),
    }

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x]) for x in ['train', 'val']}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 batch_size=1,
                                                 shuffle=True,
                                                 num_workers=2) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

for image in dataloders["train"]:
  f = plt.figure() 
  print(image[0].shape)
  plt.imshow(image[0].squeeze()[0,:,:]) 
  plt.show() 
  break

Любая помощь по альтернативному способу отображения изображений или любые ошибки, которые я делаю, были бы очень признательны.

1 ответ

Если вы хотите визуализировать изображения, загруженные Dataloader, я предлагаю этот скрипт:

      for batch in train_data_loader:
    inputs, targets = batch
    for img in inputs:
        image  = img.cpu().numpy()
        # transpose image to fit plt input
        image = image.T
        # normalise image
        data_min = np.min(image, axis=(1,2), keepdims=True)
        data_max = np.max(image, axis=(1,2), keepdims=True)
        scaled_data = (image - data_min) / (data_max - data_min)
        # show image
        plt.imshow(scaled_data)
        plt.show()
Другие вопросы по тегам