Проблема с тензором. В конце кода ошибка приходит только
Следующий код выдает ошибку только в последних нескольких строках. Пожалуйста, посмотрите последние строки кода и сообщите решение ошибки, связанной с некоторой ошибкой тензора.
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer, sample_dataset
загрузить пользовательские наборы данных
dataset = load_dataset('csv', data_files={
'train': ['train.csv'],
'eval': ['eval.csv']},
cache_dir="./data/"
)
Загрузите модель SetFit из концентратора
model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
cache_dir="./models/"
)
Создать тренера
trainer = SetFitTrainer(
model=model,
train_dataset=dataset['train'],
eval_dataset=dataset['eval'],
loss_class=CosineSimilarityLoss,
metric="accuracy",
batch_size=16,
num_iterations=20, # The number of text pairs to generate for contrastive learning
num_epochs=1, # The number of epochs to use for contrastive learning
column_mapping={"text": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
)
Тренируйтесь и оценивайте
trainer.train()
metrics = trainer.evaluate()
сохранять
trainer.model._save_pretrained(save_directory="./output/")
from setfit import SetFitModel
model = SetFitModel.from_pretrained("./output/", local_files_only=True)
sentiment_dict = {"negative": 0, "positive": 1}
inverse_dict = {value: key for (key, value) in sentiment_dict.items()}
Запустить вывод
text_list = [
"i loved the spiderman movie!",
"pineapple on pizza is the worst",
"what the fuck is this piece",
"good morning, lady boss",
"the product is excellent",
"a piece of rubbish"
]
preds = model(text_list)
'''for i in range(len(text_list)):
print(text_list[i])
print(inverse_dict[preds[i]])
print('\n')'''
Ошибка возникает следующим образом.
i loved the spiderman movie!
---------------------------------------------------------------------------
**KeyError Traceback (most recent call last)
<ipython-input-14-bf6d34450e7a> in <module>
2 for i in range(len(text_list)):
3 print(text_list[i])
----> 4 print(inverse_dict[preds[i]])
5 print('\n')
KeyError: tensor(1)**
'''