Как разбить набор данных MNIST на несколько подмножеств для распределенных узлов, используя Pytorch?
Я реализую обучение DistributedDataParallel для простого CNN для torchvision.datasets.MNIST, одновременно работающего на 3 распределенных узлах. Я хочу разделить наборы данных на 3 непересекающихся подмножества (A,B,C), каждый из которых должен содержать 20000 изображений. Отдельные подмножества должны быть далее разделены на разделы обучения и тестирования, то есть 0,7% обучения и 0,3% тестирования. Я планирую предоставить каждое подмножество каждому распределенному узлу отдельно, чтобы они могли обучаться и тестировать в стиле DistributedDataParallel.
Основной код, как показано ниже, загружает набор данных MNIST из torchvision.datasets.MNIST, а затем использует torch.utils.data.distributed.DistributedSampler и torch.utils.data.DataLoader для создания пакетов данных для обучения и тестирования на одном узле.
# TRAINING DATA
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]))
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=3, rank=dist.get_rank())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=3, pin_memory=True, sampler=True)
# TESTING DATA
test_dataset = datasets.MNIST('data', train=False, download=False, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=3, pin_memory=True)
Я ожидаю, что ответ должен создать train_dataset_a, train_dataset_b и train_dataset_c, а также test_dataset_a, test_dataset_b и test_dataset_c.