Добавить канал в MNIST через преобразование?

Я пытаюсь использовать набор данных MNIST из torchvision.datasets. Похоже, что предоставляется как N x H x W (uint8) (размер партии, высота, ширина) тензор. Все классы pytorch для работы с изображениями (например, Conv2d) однако требуют N x C x H x W (float32) тензор где C количество цветовых каналов Я пытался добавить добавить ToTensor преобразовать, но это не добавило цветовой канал.

Есть ли способ использования torchvision.transforms добавить это дополнительное измерение? Для сырой tensor мы могли бы просто сделать .unsqueeze(1) но это не выглядит как очень элегантное решение. Я просто пытаюсь сделать это "правильным" способом.

Вот неудачное преобразование.

import torchvision
dataset = torchvision.datasets.MNIST("~/PyTorchDatasets/MNIST/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
print(dataset.train_data[0])

1 ответ

Решение

У меня было неправильное представление: dataset.train_data не зависит от указанного transform, только вывод DataLoader(dataset,...) будет. После проверки data от

for data, _ in DataLoader(dataset):
    break

мы это видим ToTensor на самом деле делает именно то, что нужно.

Другие вопросы по тегам