получить индексы пакета при итерации DataLoader по набору данных huggingface
Приведенный ниже код взят из учебника huggingface:
from datasets import load_metric
metric= load_metric("glue", "mrpc")
model.eval()
for batch in eval_dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
with torch.no_grad():
outputs = model(**batch)
logits = outputs.logits
predictions = torch.argmax(logits, dim=-1)
metric.add_batch(predictions=predictions, references=batch["labels"])
metric.compute()
Внутри петли
for batch in eval_dataloader:
, как я могу узнать, какие индексы из набора данных включает этот пакет?
DataLoader создан ранее с использованием
eval_dataloader = DataLoader(
tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)
Обратите внимание, что в нем нет флага перетасовки, поэтому можно подсчитать вручную, используя размер пакета, но как это сделать при перетасовке? Можно ли сделать его полем пакета при создании набора данных и загрузчика данных?