получить индексы пакета при итерации DataLoader по набору данных huggingface

Приведенный ниже код взят из учебника huggingface:

      from datasets import load_metric

metric= load_metric("glue", "mrpc")
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

Внутри петли for batch in eval_dataloader:, как я могу узнать, какие индексы из набора данных включает этот пакет?

DataLoader создан ранее с использованием

      eval_dataloader = DataLoader(
    tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)

Обратите внимание, что в нем нет флага перетасовки, поэтому можно подсчитать вручную, используя размер пакета, но как это сделать при перетасовке? Можно ли сделать его полем пакета при создании набора данных и загрузчика данных?

0 ответов

Другие вопросы по тегам