PyTorch, выберите партии в соответствии с меткой в столбце данных
У меня есть такой набор данных:
Количество записей в каждом теге не всегда одинаково.
И моя цель - загрузить только данные с определенным тегом или тегами, чтобы я получал только записи для одного мини-пакета, а затем для другого мини-пакета. Или например
tag1
и
tag2
Я установил размер своей мини-партии на
2
.
Код, который я до сих пор полностью игнорирует
tag
label и просто выбирает партии случайным образом.
Я построил такие наборы данных:
# features is a matrix with all the features columns through all rows
# target is a vector with the target column through all rows
featuresTrain, targetTrain = projutils.get_data(train=True, config=config)
train = torch.utils.data.TensorDataset(featuresTrain, targetTrain)
train_loader = make_loader(train, batch_size=config.batch_size)
А мой загрузчик (в общем) выглядит так:
def make_loader(dataset, batch_size):
loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=True,
num_workers=8)
return loader
Что я затем тренирую так:
for epoch in range(config.epochs):
for _, (features, target) in enumerate(loader):
loss = train_batch(features, target, model, optimizer, criterion)
И
train_batch
:
def train_batch(features, target, model, optimizer, criterion):
features, target = features.to(device), target.to(device)
# Forward pass ➡
outputs = model(features)
loss = criterion(outputs, target
return loss
1 ответ
Насколько я могу судить, простой набор данных, который приблизительно реализует те характеристики, которые вы ищете.
class CustomDataset(data.Dataset):
def __init__(self,featuresTrain,targetsTrain,tagsTrain,sample_equally = False):
# self.tags should be a tensor in k-hot encoding form so a 2D tensor,
self.tags = tagsTrain
self.x = featuresTrain
self.y = targetsTrain
self.unique_tagsets = None
self.sample_equally = sample_equally
# self.active tags is a 1D k-hot encoding vector
self.active_tags = self.get_random_tag_set()
def get_random_tag_set(self):
# gets all unique sets of tags and returns one randomly
if self.unique_tagsets is None:
self.unique_tagsets = self.tags.unique(dim = 0)
if self.sample_equally:
rand_idx = torch.randint(len(self.unique_tagsets),[1])[1].detatch().int()
return self.unique_tagsets[rand_idx]
else:
rand_idx = torch.randint(len(self.tags),[1])[1].detatch().int()
return self.tags[rand_idx]
def set_tags(self,tags):
# specifies the set of tags that must be present for a datum to be selected
self.active_tags = tags
def __getitem__(self,index):
# get all indices of elements with self.active_tags
indices = torch.where(self.tags == self.active_tags)[0]
# we select an index based on the indices of the elements that have the tag set
idx = indices[index % len(indices)]
item = self.x[idx], self.y[idx]
return item
def __len__(self):
return len(self.y)
Этот набор данных случайным образом выбирает набор тегов. Тогда каждый раз
__getitem__()
вызывается, он использует указанный индекс для выбора среди элементов данных, которые имеют набор тегов. Вы можете позвонить или
get_random_tag_set()
тогда
set_tags()
после каждого мини-пакета или как бы часто вы ни захотели изменить набор тегов, или вы можете вручную указать набор тегов самостоятельно. Набор данных наследуется от
torch.data.Dataset
так что вы сможете использовать if с
torch.data.Dataloader
без модификации.
С помощью
sample_equally
.
Короче говоря, этот набор данных немного грубоват по краям, но должен позволить вам отбирать все партии с одним и тем же набором тегов. Основным недостатком является то, что каждый элемент, вероятно, будет отбираться более одного раза за партию.