Как ускорить "ImageFolder" для ImageNet

Я нахожусь в университете, и все файловые системы находятся в удаленной системе, где бы я ни входил со своей учетной записью, я всегда мог получить доступ к своему домашнему каталогу. хотя я вхожу на серверы GPU через команду SSH. Это условие, при котором я использую серверы графического процессора для чтения данных.

В настоящее время я использую PyTorch для обучения ResNet с нуля на ImageNet, мои коды используют только все графические процессоры на одном компьютере, я обнаружил, что "torchvision.datasets.ImageFolder" займет почти два часа.

Не могли бы вы рассказать о том, как ускорить работу torchvision.datasets.ImageFolder? Спасибо большое.

2 ответа

Почему так долго?
Настройка ImageFolder может занять много времени, особенно если изображения хранятся на медленном удаленном диске. Причина этой задержки заключается в том, что __init__ Функция для набора данных обходит все файлы в папках изображений и проверяет, является ли этот файл файлом изображений. Для ImageNet это может занять довольно много времени, так как есть более 1 миллиона файлов для проверки.

Что ты можешь сделать?
- Как уже указывал Кевин Сан, копирование набора данных в локальное (и, возможно, намного более быстрое) хранилище может значительно ускорить процесс.
- В качестве альтернативы вы можете создать модифицированный класс набора данных, который не читает все файлы, но использует кэшированный список файлов - кэшированный список, который вы готовите только один раз и который будет использоваться для всех запусков.

Если вы уверены, что структура папок не изменится, вы можете кэшировать структуру (а не данные, которые слишком велики), используя следующее:

      
import json
from functools import wraps
from torchvision.datasets import ImageNet

def file_cache(filename):
    """Decorator to cache the output of a function to disk."""
    def decorator(f):
        @wraps(f)
        def decorated(self, directory, *args, **kwargs):
            filepath = Path(directory) / filename
            if filepath.is_file():
                out = json.loads(filepath.read_text())
            else:
                out = f(self, directory, *args, **kwargs)
                filepath.write_text(json.dumps(out))
            return out
        return decorated
    return decorator

class CachedImageNet(ImageNet):
    @file_cache(filename="cached_classes.json")
    def find_classes(self, directory, *args, **kwargs):
        classes = super().find_classes(directory, *args, **kwargs)
        return classes

    @file_cache(filename="cached_structure.json")
    def make_dataset(self, directory, *args, **kwargs):
        dataset = super().make_dataset(directory, *args, **kwargs)
        return dataset
Другие вопросы по тегам