PyTorch, выберите партии в соответствии с меткой в ​​столбце данных

У меня есть такой набор данных:

Количество записей в каждом теге не всегда одинаково.

И моя цель - загрузить только данные с определенным тегом или тегами, чтобы я получал только записи для одного мини-пакета, а затем для другого мини-пакета. Или например tag1 и tag2 Я установил размер своей мини-партии на 2.

Код, который я до сих пор полностью игнорирует tag label и просто выбирает партии случайным образом.

Я построил такие наборы данных:

      # features is a matrix with all the features columns through all rows
# target is a vector with the target column through all rows
featuresTrain, targetTrain = projutils.get_data(train=True, config=config)
train = torch.utils.data.TensorDataset(featuresTrain, targetTrain)
train_loader = make_loader(train, batch_size=config.batch_size)

А мой загрузчик (в общем) выглядит так:

      def make_loader(dataset, batch_size):
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=batch_size, 
                                     shuffle=True,
                                     pin_memory=True,
                                     num_workers=8)
return loader

Что я затем тренирую так:

      for epoch in range(config.epochs):
    for _, (features, target) in enumerate(loader):
        loss = train_batch(features, target, model, optimizer, criterion)

И train_batch:

      def train_batch(features, target, model, optimizer, criterion):
features, target = features.to(device), target.to(device)

# Forward pass ➡
outputs = model(features)
loss = criterion(outputs, target
return loss

1 ответ

Насколько я могу судить, простой набор данных, который приблизительно реализует те характеристики, которые вы ищете.

      class CustomDataset(data.Dataset):
    def __init__(self,featuresTrain,targetsTrain,tagsTrain,sample_equally = False):
       # self.tags should be a tensor in k-hot encoding form so a 2D tensor, 
       self.tags = tagsTrain
       self.x = featuresTrain
       self.y = targetsTrain
       self.unique_tagsets = None
       self.sample_equally = sample_equally

       # self.active tags is a 1D k-hot encoding vector
       self.active_tags = self.get_random_tag_set()
       
    
    def get_random_tag_set(self):
        # gets all unique sets of tags and returns one randomly
        if self.unique_tagsets is None:
             self.unique_tagsets = self.tags.unique(dim = 0)
        if self.sample_equally:
             rand_idx = torch.randint(len(self.unique_tagsets),[1])[1].detatch().int()
             return self.unique_tagsets[rand_idx]
        else:
            rand_idx = torch.randint(len(self.tags),[1])[1].detatch().int()
            return self.tags[rand_idx]

    def set_tags(self,tags):
       # specifies the set of tags that must be present for a datum to be selected
        self.active_tags = tags

    def __getitem__(self,index):
        # get all indices of elements with self.active_tags
        indices = torch.where(self.tags == self.active_tags)[0]

        # we select an index based on the indices of the elements that have the tag set
        idx = indices[index % len(indices)]

        item = self.x[idx], self.y[idx]
        return item

    def __len__(self):
        return len(self.y)

Этот набор данных случайным образом выбирает набор тегов. Тогда каждый раз __getitem__()вызывается, он использует указанный индекс для выбора среди элементов данных, которые имеют набор тегов. Вы можете позвонить или get_random_tag_set() тогда set_tags()после каждого мини-пакета или как бы часто вы ни захотели изменить набор тегов, или вы можете вручную указать набор тегов самостоятельно. Набор данных наследуется от torch.data.Dataset так что вы сможете использовать if с torch.data.Dataloader без модификации.

С помощью sample_equally.

Короче говоря, этот набор данных немного грубоват по краям, но должен позволить вам отбирать все партии с одним и тем же набором тегов. Основным недостатком является то, что каждый элемент, вероятно, будет отбираться более одного раза за партию.

Другие вопросы по тегам