Я точно настроил предварительно обученный BERT для классификации предложений, но не могу заставить его предсказывать новые предложения
Ниже результат моей тонкой настройки.
Training Loss Valid. Loss Valid. Accur. Training Time Validation Time
epoch
1 0.16 0.11 0.96 0:02:11 0:00:05
2 0.07 0.13 0.96 0:02:19 0:00:05
3 0.03 0.14 0.97 0:02:22 0:00:05
4 0.02 0.16 0.96 0:02:21 0:00:05
Затем я попытался использовать модель для предсказания меток из файла csv. Я создал столбец метки, установил тип int64 и запустил прогноз.
print('Predicting labels for {:,} test sentences...'.format(len(input_ids)))
model.eval()
# Tracking variables
predictions , true_labels = [], []
# Predict
for batch in prediction_dataloader:
# Add batch to GPU
batch = tuple(t.to(device) for t in batch)
# Unpack the inputs from our dataloader
b_input_ids, b_input_mask, b_labels = batch
# Telling the model not to compute or store gradients, saving memory and
# speeding up prediction
with torch.no_grad():
# Forward pass, calculate logit predictions
outputs = model(b_input_ids, token_type_ids=None,
attention_mask=b_input_mask)
logits = outputs[0]
# Move logits and labels to CPU
logits = logits.detach().cpu().numpy()
label_ids = b_labels.to('cpu').numpy()
# Store predictions and true labels
predictions.append(logits)
true_labels.append(label_ids)
однако, хотя я могу распечатать прогнозы [4.235, -4.805] и т.д., а также true_labels[NaN,NaN.....], я не могу получить прогнозируемые метки {0 или 1}. Я что-то упустил?
1 ответ
Выходные данные моделей - логиты, то есть распределение вероятностей до нормализации с использованием softmax.
Если вы возьмете свой результат: [4.235, -4.805]
и запустите softmax поверх него
In [1]: import torch
In [2]: import torch.nn.functional as F
In [3]: F.softmax(torch.tensor([4.235, -4.805]))
Out[3]: tensor([9.9988e-01, 1.1856e-04])
Вы получите оценку вероятности 99% для метки 0. Когда у вас есть логиты в виде 2D-тензора, вы можете легко получить классы, вызвав
logits.argmax(0)
В NaN
s ценностей в вашем true_labels
вероятно, ошибка в том, как вы загружаете данные, это не имеет ничего общего с моделью BERT.