Ошибка типа при тонкой настройке модели bert-large-uncased-all-word-masking с помощью Huggingface

Я пытаюсь настроить модель Huggingface bert-large-uncased-all-word-masking, и при обучении получаю такую ​​ошибку типа:

«TypeError: только целочисленные тензоры одного элемента могут быть преобразованы в индекс»

Вот код:

      
train_inputs = tokenizer(text_list[0:457], return_tensors='pt', max_length=512, truncation=True, padding='max_length')
train_inputs['labels']= train_inputs.input_ids.detach().clone()

Затем я случайным образом маскирую около 15% слов во входных идентификаторах и определяю класс для набора данных, а затем в цикле обучения возникает ошибка:

      class MeditationsDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings= encodings
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
    def __len__(self):
        return self.encodings.input_ids

train_dataset = MeditationsDataset(train_inputs)
train_dataloader = torch.utils.data.DataLoader(dataset= train_dataset, batch_size=8, shuffle=False)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

from transformers import BertModel, AdamW

model = BertModel.from_pretrained("bert-large-uncased-whole-word-masking")
model.to(device)
model.train()

optim = AdamW(model.parameters(), lr=1e-5)
num_epochs = 2
from tqdm.auto import tqdm

for epoch in range(num_epochs):
    loop = tqdm(train_dataloader, leave=True)
    for batch in loop:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

Ошибка возникает в "для партии в цикле"

Кто-нибудь это понимает и знает, как это решить? Заранее спасибо за помощь

1 ответ

В классе MeditationsDatasetв функции __getitem__
torch.tensor(val[idx])не рекомендуется PyTorch, вы должны использовать вместо этого val[idx].clone().detach()

Другие вопросы по тегам