Как выполнить итерацию токенизированного набора данных pytorch Multi30k в BucketIterator?
Я использую Pytorch(1.10 v), я использую набор данных Multi30k с немецкого на английский для машинного перевода. Я использую spacy для токенизации (как для английских, так и для немецких слов) и хочу передать токенизированные данные в ( torchtext.legacy.data.BucketIterator ) для заполнения и преобразования строки в индекс. Возникла ошибка, связанная с sort_key, я ее не понимаю. Кто-нибудь, пожалуйста, помогите мне.
Код
import spacy
from torchtext.datasets import Multi30k # this is a en and gr dataset for machine translation
from torchtext.legacy.data import Field, BucketIterator
spacy_eng = spacy.load("en_core_web_sm")
spacy_ger = spacy.load("de_core_news_sm")
def tokenize_eng(text):
return [tok.text for tok in spacy_eng.tokenizer(text)]
def tokenize_ger(text):
return [tok.text for tok in spacy_ger.tokenizer(text)]
english = Field(sequential=True, use_vocab=True, tokenize=tokenize_eng, lower=True, init_token='<sos>', eos_token='<eos>')
german = Field(sequential=True, use_vocab=True, tokenize=tokenize_ger, lower=True, init_token='<sos>', eos_token='<eos>')
train, valid, test = Multi30k(root=".data", split=('train', 'valid', 'test'), language_pair=('en', 'de'))
# will make vocabulary from train data
english.build_vocab(train, max_size=10000, min_freq=2)
german.build_vocab(train, max_size=10000, min_freq=2)
train_data, valid_data, test_data = BucketIterator.splits((train, valid, test),
batch_size=64,
device='cuda')
Ошибка
Traceback (most recent call last):
File "D:\Torch\Multi30K_inbuilt_dataset.py", line 28, in <module>
train_data, valid_data, test_data = BucketIterator.splits((train, valid, test),
File "C:\Users\Devanshu\anaconda3\envs\deeplearning\lib\site-packages\torchtext\legacy\data\iterator.py", line 99, in splits
ret.append(cls(
File "C:\Users\Devanshu\anaconda3\envs\deeplearning\lib\site-packages\torchtext\legacy\data\iterator.py", line 59, in __init__
self.sort_key = dataset.sort_key
File "C:\Users\Devanshu\anaconda3\envs\deeplearning\lib\site-packages\torch\utils\data\dataset.py", line 226, in __getattr__
raise AttributeError
AttributeError
1 ответ
Надеюсь, что еще не поздно, но попробуйте использовать:
train, valid, test = Multi30k.splits(exts=('.de', '.en'), fields=(german, english))