Эффективная функция PyTorch DataLoader collate_fn для входов различных измерений

У меня проблемы с написанием кастома collate_fn функция для PyTorch DataLoader учебный класс. Мне нужна пользовательская функция, потому что мои входы имеют разные размеры.

В настоящее время я пытаюсь написать базовую реализацию статьи Стэнфордского MURA. Набор данных имеет набор помеченных исследований. Исследование может содержать более одного изображения. Я создал кастом Datasetкласс, который складывает эти несколько изображений, используя torch.stack,

Сложенный тензор затем предоставляется в качестве входных данных для модели, и список выходных данных усредняется для получения одного выходного сигнала. Эта реализация отлично работает с DataLoader когда batch_size=1, Тем не менее, когда я пытаюсь установить batch_size до 8, как в оригинальной статье, DataLoader не удается, так как он использует torch.stack для штабелирования партии и входные данные в моей партии имеют переменные размеры (поскольку каждое исследование может иметь несколько изображений).

Чтобы исправить это, я попытался реализовать свой кастом collate_fn функция.

def collate_fn(batch):
    imgs = [item['images'] for item in batch]
    targets = [item['label'] for item in batch]
    targets = torch.LongTensor(targets)
    return imgs, targets

Затем в цикле моей тренировочной эпохи я повторяю каждую партию следующим образом:

for image, label in zip(*batch):
    label = label.type(torch.FloatTensor)
    # wrap them in Variable
    image = Variable(image).cuda()  
    label = Variable(label).cuda()
    # forward
    output = model(image)
    output = torch.mean(output)
    loss = criterion(output, label, phase)

Тем не менее, это не дает мне каких-либо улучшенных таймингов в эпоху и все равно занимает столько же времени, сколько и при размере партии, равном только 1. Я также попытался установить размер партии равным 32, и это тоже не улучшает время.

Я делаю что-то неправильно? Есть ли лучший подход к этому?

1 ответ

Очень интересная проблема! Если я правильно вас понял (а также проверил резюме статьи), у вас есть 40561 изображение из 14863 исследований, где каждое исследование вручную помечено радиологами как нормальное или ненормальное.

Я считаю, что причина, по которой у вас возникла проблема, с которой вы столкнулись, заключалась, например, в том, что вы создали стек для,

  1. исследование A - 12 изображений
  2. исследование B - 13 изображений
  3. исследование C - 7 изображений
  4. исследование D - 1 изображение и т. д.

И вы пытаетесь использовать пакет размером 8 во время обучения, что не даст результата, когда дойдет до изучения D.

Следовательно, есть ли причина, по которой мы хотим усреднить список результатов исследования, чтобы он соответствовал одной метке? В противном случае я бы просто собрал все 40561 изображение, назначил одну и ту же метку всем изображениям из одного исследования (так, чтобы список выходных данных в A сравнивался со списком из 12 меток).

Таким образом, с помощью одного загрузчика данных вы можете перемещаться между исследованиями (при желании) и использовать желаемый размер пакета во время обучения.

Я вижу, что этот вопрос существует уже некоторое время, надеюсь, он кому-то поможет в будущем :)

Другие вопросы по тегам