Эффективная функция PyTorch DataLoader collate_fn для входов различных измерений
У меня проблемы с написанием кастома collate_fn
функция для PyTorch DataLoader
учебный класс. Мне нужна пользовательская функция, потому что мои входы имеют разные размеры.
В настоящее время я пытаюсь написать базовую реализацию статьи Стэнфордского MURA. Набор данных имеет набор помеченных исследований. Исследование может содержать более одного изображения. Я создал кастом Dataset
класс, который складывает эти несколько изображений, используя torch.stack
,
Сложенный тензор затем предоставляется в качестве входных данных для модели, и список выходных данных усредняется для получения одного выходного сигнала. Эта реализация отлично работает с DataLoader
когда batch_size=1
, Тем не менее, когда я пытаюсь установить batch_size
до 8, как в оригинальной статье, DataLoader
не удается, так как он использует torch.stack
для штабелирования партии и входные данные в моей партии имеют переменные размеры (поскольку каждое исследование может иметь несколько изображений).
Чтобы исправить это, я попытался реализовать свой кастом collate_fn
функция.
def collate_fn(batch):
imgs = [item['images'] for item in batch]
targets = [item['label'] for item in batch]
targets = torch.LongTensor(targets)
return imgs, targets
Затем в цикле моей тренировочной эпохи я повторяю каждую партию следующим образом:
for image, label in zip(*batch):
label = label.type(torch.FloatTensor)
# wrap them in Variable
image = Variable(image).cuda()
label = Variable(label).cuda()
# forward
output = model(image)
output = torch.mean(output)
loss = criterion(output, label, phase)
Тем не менее, это не дает мне каких-либо улучшенных таймингов в эпоху и все равно занимает столько же времени, сколько и при размере партии, равном только 1. Я также попытался установить размер партии равным 32, и это тоже не улучшает время.
Я делаю что-то неправильно? Есть ли лучший подход к этому?
1 ответ
Очень интересная проблема! Если я правильно вас понял (а также проверил резюме статьи), у вас есть 40561 изображение из 14863 исследований, где каждое исследование вручную помечено радиологами как нормальное или ненормальное.
Я считаю, что причина, по которой у вас возникла проблема, с которой вы столкнулись, заключалась, например, в том, что вы создали стек для,
- исследование A - 12 изображений
- исследование B - 13 изображений
- исследование C - 7 изображений
- исследование D - 1 изображение и т. д.
И вы пытаетесь использовать пакет размером 8 во время обучения, что не даст результата, когда дойдет до изучения D.
Следовательно, есть ли причина, по которой мы хотим усреднить список результатов исследования, чтобы он соответствовал одной метке? В противном случае я бы просто собрал все 40561 изображение, назначил одну и ту же метку всем изображениям из одного исследования (так, чтобы список выходных данных в A сравнивался со списком из 12 меток).
Таким образом, с помощью одного загрузчика данных вы можете перемещаться между исследованиями (при желании) и использовать желаемый размер пакета во время обучения.
Я вижу, что этот вопрос существует уже некоторое время, надеюсь, он кому-то поможет в будущем :)