Как получить весь набор данных из загрузчика данных в PyTorch
Как загрузить весь набор данных из DataLoader? Я получаю только одну партию данных.
Это мой код
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=64)
images, labels = next(iter(dataloader))
2 ответа
Вы можете установить batch_size=dataset.__len__()
если набор данных является факелом Dataset
иначе что-то вроде batch_szie=len(dataset)
должно сработать.
Осторожно, это может потребовать много памяти в зависимости от вашего набора данных.
Другой вариант - получить весь набор данных напрямую, без использования загрузчика данных, например:
images, labels = dataset[:]
Я не уверен, хотите ли вы использовать набор данных где-то еще, кроме обучения по сети (например, для проверки изображений) или хотите перебирать партии во время обучения.
Итерация по набору данных
Либо следуйте ответу Усмана Али (который может переполниться) вашей памяти, либо вы можете сделать
for i in range(len(dataset)): # or i, image in enumerate(dataset)
images, labels = dataset[i] # or whatever your dataset returns
Вы можете написать dataset[i]
потому что вы реализовали __len__
а также __getitem__
в вашей Dataset
класс (пока это подкласс Pytorch Dataset
класс).
Получение всех пакетов из загрузчика данных
Насколько я понимаю, ваш вопрос заключается в том, что вы хотите получить все партии для обучения сети. Вы должны понимать, что iter
дает вам итератор загрузчика данных (если вы не знакомы с концепцией итераторов, смотрите запись в википедии). next
говорит итератору дать вам следующий элемент.
Таким образом, в отличие от итератора, пересекающего список, загрузчик данных всегда возвращает следующий элемент. Итераторы списка останавливаются в некоторый момент. Я предполагаю, что у вас есть что-то вроде количества эпох и количества шагов за эпоху. Тогда ваш код будет выглядеть так
for i in range(epochs):
# some code
for j in range(steps_per_epoch):
images, labels = next(iter(dataloader))
prediction = net(images)
loss = net.loss(prediction, labels)
...
Будь осторожен с next(iter(dataloader))
, Если вы хотите перебрать список, это также может сработать, потому что Python кэширует объекты, но вы можете получить новый итератор каждый раз, когда он снова начинается с индекса 0. Чтобы избежать этого, возьмите итератор на вершину, вот так:
iterator = iter(dataloader)
for i in range(epochs):
for j in range(steps_per_epoch):
images, labels = next(iterator)