ошибка в Python при токенизации строки: проблема NLP
Я использую токенизатор distilbert-base-uncased для токенизации входной строки перед циклом обучения. Это для задачи НЛП (прогнозирования настроений), и я использую набор данных из 3 столбцов (ярлыки [мир, спорт, бизнес, наука / техника], заголовок новостей, новостная статья). Моя цель - классифицировать статьи по ярлыкам.
Здесь, в файле datasets.py, я пытаюсь разбить CSV-файл на метки и текст статьи (в виде строки). Затем показываем первый текст и метку в коде Main.py.
Я получаю сообщение об ошибке, приведенное ниже:
datasets.py код:
class TextDataset(Dataset):
def __init__(self, fname, sentence_len):
self.fname = pd.read_csv(fname)
self.sentence_len = sentence_len
texts = self.fname[2].str.slice(0, self.sentence_len).tolist()
labels = self.fname[0].tolist()
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
self.vocab_size = tokenizer.vocab_size
tokens = tokenizer(texts, truncation=True, padding=True)
self.tokens = tokens["input_ids"]
self.labels = labels
def __len__(self):
return len(self.labels_fname)
def __getitem__(self, idx):
inputs = torch.tensor(self.tokens[idx], device=self.device)
label = torch.tensor(self.labels[idx], device=self.device)
return inputs, label
Код Main.py:
# Import our class
from datasets import TextDataset
# Initialise
dataset = TextDataset('/content/data/txt/train.csv', 80)
# Take out the first example
text, label = dataset[0]
# Print the contents
print("Tokens:", text)
print("Label:", label)
ошибка:
--------------------------------------------------------------------------
KeyError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
2897 try:
-> 2898 return self._engine.get_loc(casted_key)
2899 except KeyError as err:
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
KeyError: 2
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
3 frames
<ipython-input-12-d0cd36be26be> in <module>()
3 # Initialise
4
----> 5 dataset = TextDataset('/content/data/txt/train.csv', 80)
6 # Take out the first example
7 text, label = dataset[0]
/content/drive/My Drive/DL Assignment/datasets.py in __init__(self, fname, sentence_len)
75 df = pd.read_csv(fname)
76 self.sentence_len = sentence_len
---> 77 texts = df[2].str.slice(0, sentence_len).tolist()
78 labels = df[0].tolist()
79
/usr/local/lib/python3.7/dist-packages/pandas/core/frame.py in __getitem__(self, key)
2904 if self.columns.nlevels > 1:
2905 return self._getitem_multilevel(key)
-> 2906 indexer = self.columns.get_loc(key)
2907 if is_integer(indexer):
2908 indexer = [indexer]
/usr/local/lib/python3.7/dist-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
2898 return self._engine.get_loc(casted_key)
2899 except KeyError as err:
-> 2900 raise KeyError(key) from err
2901
2902 if tolerance is not None:
KeyError: 2
Как я могу это исправить?
Спасибо.