Я точно настроил предварительно обученный BERT для классификации предложений, но не могу заставить его предсказывать новые предложения

Ниже результат моей тонкой настройки.

Training Loss   Valid. Loss Valid. Accur.   Training Time   Validation Time
epoch                   
1   0.16    0.11    0.96    0:02:11 0:00:05
2   0.07    0.13    0.96    0:02:19 0:00:05
3   0.03    0.14    0.97    0:02:22 0:00:05
4   0.02    0.16    0.96    0:02:21 0:00:05

Затем я попытался использовать модель для предсказания меток из файла csv. Я создал столбец метки, установил тип int64 и запустил прогноз.

print('Predicting labels for {:,} test sentences...'.format(len(input_ids)))
model.eval()
# Tracking variables 
predictions , true_labels = [], []
# Predict 
for batch in prediction_dataloader:
  # Add batch to GPU
  batch = tuple(t.to(device) for t in batch)

  # Unpack the inputs from our dataloader
  b_input_ids, b_input_mask, b_labels = batch

  # Telling the model not to compute or store gradients, saving memory and 
  # speeding up prediction
  with torch.no_grad():
      # Forward pass, calculate logit predictions
      outputs = model(b_input_ids, token_type_ids=None, 
                      attention_mask=b_input_mask)

  logits = outputs[0]

  # Move logits and labels to CPU
  logits = logits.detach().cpu().numpy()
  label_ids = b_labels.to('cpu').numpy()

  # Store predictions and true labels
  predictions.append(logits)
  true_labels.append(label_ids)


однако, хотя я могу распечатать прогнозы [4.235, -4.805] и т.д., а также true_labels[NaN,NaN.....], я не могу получить прогнозируемые метки {0 или 1}. Я что-то упустил?

1 ответ

Решение

Выходные данные моделей - логиты, то есть распределение вероятностей до нормализации с использованием softmax.

Если вы возьмете свой результат: [4.235, -4.805] и запустите softmax поверх него

In [1]: import torch
In [2]: import torch.nn.functional as F 
In [3]: F.softmax(torch.tensor([4.235, -4.805]))
Out[3]: tensor([9.9988e-01, 1.1856e-04])

Вы получите оценку вероятности 99% для метки 0. Когда у вас есть логиты в виде 2D-тензора, вы можете легко получить классы, вызвав

logits.argmax(0)

В NaNs ценностей в вашем true_labels вероятно, ошибка в том, как вы загружаете данные, это не имеет ничего общего с моделью BERT.

Другие вопросы по тегам