Как ускорить "ImageFolder" для ImageNet
Я нахожусь в университете, и все файловые системы находятся в удаленной системе, где бы я ни входил со своей учетной записью, я всегда мог получить доступ к своему домашнему каталогу. хотя я вхожу на серверы GPU через команду SSH. Это условие, при котором я использую серверы графического процессора для чтения данных.
В настоящее время я использую PyTorch для обучения ResNet с нуля на ImageNet, мои коды используют только все графические процессоры на одном компьютере, я обнаружил, что "torchvision.datasets.ImageFolder" займет почти два часа.
Не могли бы вы рассказать о том, как ускорить работу torchvision.datasets.ImageFolder? Спасибо большое.
2 ответа
Почему так долго?
Настройка ImageFolder
может занять много времени, особенно если изображения хранятся на медленном удаленном диске. Причина этой задержки заключается в том, что __init__
Функция для набора данных обходит все файлы в папках изображений и проверяет, является ли этот файл файлом изображений. Для ImageNet это может занять довольно много времени, так как есть более 1 миллиона файлов для проверки.
Что ты можешь сделать?
- Как уже указывал Кевин Сан, копирование набора данных в локальное (и, возможно, намного более быстрое) хранилище может значительно ускорить процесс.
- В качестве альтернативы вы можете создать модифицированный класс набора данных, который не читает все файлы, но использует кэшированный список файлов - кэшированный список, который вы готовите только один раз и который будет использоваться для всех запусков.
Если вы уверены, что структура папок не изменится, вы можете кэшировать структуру (а не данные, которые слишком велики), используя следующее:
import json
from functools import wraps
from torchvision.datasets import ImageNet
def file_cache(filename):
"""Decorator to cache the output of a function to disk."""
def decorator(f):
@wraps(f)
def decorated(self, directory, *args, **kwargs):
filepath = Path(directory) / filename
if filepath.is_file():
out = json.loads(filepath.read_text())
else:
out = f(self, directory, *args, **kwargs)
filepath.write_text(json.dumps(out))
return out
return decorated
return decorator
class CachedImageNet(ImageNet):
@file_cache(filename="cached_classes.json")
def find_classes(self, directory, *args, **kwargs):
classes = super().find_classes(directory, *args, **kwargs)
return classes
@file_cache(filename="cached_structure.json")
def make_dataset(self, directory, *args, **kwargs):
dataset = super().make_dataset(directory, *args, **kwargs)
return dataset