Проблема с тензором. В конце кода ошибка приходит только

Следующий код выдает ошибку только в последних нескольких строках. Пожалуйста, посмотрите последние строки кода и сообщите решение ошибки, связанной с некоторой ошибкой тензора.

      from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer, sample_dataset

загрузить пользовательские наборы данных

      dataset = load_dataset('csv', data_files={
    'train': ['train.csv'],
    'eval': ['eval.csv']},
    cache_dir="./data/"
)

Загрузите модель SetFit из концентратора

      model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    cache_dir="./models/"
)

Создать тренера

      trainer = SetFitTrainer(
    model=model,
    train_dataset=dataset['train'],
    eval_dataset=dataset['eval'],
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=16,
    num_iterations=20,  # The number of text pairs to generate for contrastive learning
    num_epochs=1,  # The number of epochs to use for contrastive learning
    column_mapping={"text": "text", "label": "label"}  # Map dataset columns to text/label expected by trainer
)

Тренируйтесь и оценивайте

      trainer.train()
metrics = trainer.evaluate()

сохранять

      trainer.model._save_pretrained(save_directory="./output/")

from setfit import SetFitModel

model = SetFitModel.from_pretrained("./output/", local_files_only=True)

sentiment_dict = {"negative": 0, "positive": 1}
inverse_dict = {value: key for (key, value) in sentiment_dict.items()}

Запустить вывод

      text_list = [
    "i loved the spiderman movie!",
    "pineapple on pizza is the worst",
    "what the fuck is this piece",
    "good morning, lady boss",
    "the product is excellent",
    "a piece of rubbish"
]

preds = model(text_list)

'''for i in range(len(text_list)):
    print(text_list[i])
    print(inverse_dict[preds[i]])
    print('\n')'''

Ошибка возникает следующим образом.

      i loved the spiderman movie!
---------------------------------------------------------------------------
**KeyError                                  Traceback (most recent call last)
<ipython-input-14-bf6d34450e7a> in <module>
      2 for i in range(len(text_list)):
      3     print(text_list[i])
----> 4     print(inverse_dict[preds[i]])
      5     print('\n')
KeyError: tensor(1)**
'''

0 ответов

Другие вопросы по тегам