Как я могу использовать PyTorch DataLoader для обучения укреплению?

Я пытаюсь настроить обобщенную платформу Reinforcement Learning в PyTorch, чтобы воспользоваться всеми высокоуровневыми утилитами, использующими PyTorch DataSet и DataLoader, такими как Ignite или FastAI, но я столкнулся с блокирующим устройством с динамической природой Усиление данных обучения:

  • Элементы данных генерируются из кода, а не читаются из файла, и они зависят от предыдущих действий и результатов модели, поэтому каждому вызову nextItem необходим доступ к состоянию модели.
  • Эпизоды обучения не имеют фиксированной длины, поэтому мне нужен динамический размер пакета, а также динамический общий размер набора данных. Я предпочел бы использовать функцию условия завершения вместо числа. Я мог бы "возможно" сделать это с отступом, как при обработке предложений НЛП, но это настоящий взлом.

Мои поиски в Google и Stackru пока что дали пшик. Кто-нибудь здесь знает о существующих решениях или обходных путях использования DataLoader или DataSet с Reinforcement Learning? Я ненавижу терять доступ ко всем существующим библиотекам, которые зависят от них.

1 ответ

Вот один фреймворк на основе PyTorch, а вот что-то из Facebook.

Когда дело доходит до вашего вопроса (и благородного квеста, без сомнения):

Вы могли бы легко создать torch.utils.data.Dataset зависит от чего-либо, включая модель, что-то вроде этого (простите за слабую абстракцию, это просто чтобы доказать свою точку зрения):

import typing

import torch
from torch.utils.data import Dataset


class Environment(Dataset):
    def __init__(self, initial_state, actor: torch.nn.Module, max_interactions: int):
        self.current_state = initial_state
        self.actor: torch.nn.Module = actor
        self.max_interactions: int = max_interactions

    # Just ignore the index
    def __getitem__(self, _):
        self.current_state = self.actor.update(self.current_state)
        return self.current_state.get_data()

    def __len__(self):
        return self.max_interactions

Предполагая, torch.nn.Module -подобная сеть имеет какую-то update изменение состояния окружающей среды. В общем, это просто структура Python, и вы можете с ней многое моделировать.

Вы можете указать max_interactions быть почти infinite или вы можете изменить его на лету, если необходимо, с некоторыми обратными вызовами во время тренировки (как __len__ будет вызываться несколько раз по всему коду, вероятно). Окружающая среда может также обеспечить batches вместо образцов.

torch.utils.data.DataLoader имеет batch_sampler аргумент, там вы можете генерировать партии различной длины. Поскольку сеть не зависит от первого измерения, вы также можете вернуть любой размер пакета, который вам нужен.

КСТАТИ. Заполнение следует использовать, если каждый образец будет иметь разную длину, и разный размер партии не имеет к этому никакого отношения.

Другие вопросы по тегам