Ошибка несоответствия размера тензора при генерации текста с помощью beam_search (библиотека huggingface)
Я использую библиотеку huggingface для генерации текста с использованием предварительно обученной модели distilgpt2. В частности, я использую функцию beam_search , так как я хотел бы включить LogitsProcessorList (который вы не можете использовать с функцией генерации ).
Соответствующая часть моего кода выглядит так:
beam_scorer = BeamSearchScorer(
batch_size=btchsze,
max_length=15, # not sure why lengths under 20 fail
num_beams=num_seq,
device=model.device,
)
j = input_ids.tile((num_seq*btchsze,1))
next_output = model.beam_search(
j,
beam_scorer,
eos_token_id=tokenizer.encode('.')[0],
logits_processor=logits_processor
)
Однако функция beam_search выдает эту ошибку, когда я пытаюсь сгенерировать, используя max_length меньше 20:
~/anaconda3/envs/techtweets37/lib/python3.7/site-packages/transformers-4.4.2-py3.8.egg/transformers/generation_beam_search.py in finalize(self, input_ids, final_beam_scores, final_beam_tokens, final_beam_indices, pad_token_id, eos_token_id)
326 # fill with hypotheses and eos_token_id if the latter fits in
327 for i, hypo in enumerate(best):
--> 328 decoded[i, : sent_lengths[i]] = hypo
329 if sent_lengths[i] < self.max_length:
330 decoded[i, sent_lengths[i]] = eos_token_id
RuntimeError: The expanded size of the tensor (15) must match the existing size (20) at non-singleton dimension 0. Target sizes: [15]. Tensor sizes: [20]
Кажется, я не могу понять, откуда берется 20: это то же самое, даже если длина ввода больше или меньше, даже если я использую другой размер партии или количество лучей. Нет ничего, что я определил как длину 20, и я не могу найти никакого значения по умолчанию. Максимальная длина последовательности влияет на результаты поиска луча, поэтому я хотел бы выяснить это и иметь возможность установить более короткую максимальную длину.
1 ответ
Это известная проблема в библиотеке обнимающих лиц:
https://github.com/huggingface/transformers/issues/11040
По сути, бомбардир использует не переданные ему передачи, а
max_length
модели.
На данный момент исправление состоит в том, чтобы установить
model.config.max_length
до желаемой максимальной длины.