Ошибка с get_peft_model() и PromptTuningConfig.

Я учусь выполнять быструю настройку и столкнулся с проблемой.

Я использую функцию get_peft_model для инициализации модели для обучения из «google/flan-t5-base».

      model_name='google/flan-t5-base'
tokenizer = AutoTokenizer.from_pretrained(model_name,)
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

peft_prompt_config = PromptTuningConfig(task_type="SEQ_2_SEQ_LM", num_virtual_tokens=20)
prompt_model = get_peft_model(original_model, peft_prompt_config)

#create a dummy input
input_ids = tokenizer('test', return_tensors="pt").input_ids

original_model.generate(input_ids)
>>Out[76]: tensor([[  0, 794,   1]])

prompt_model.generate(input_ids)
>> TypeError: generate() takes 1 positional argument but 2 were given

help(prompt_model.generate)
>>generate(**kwargs) method of peft.peft_model.PeftModelForSeq2SeqLM instance

prompt_model.generate(**{'input_ids':input_ids})
>> NotImplementedError: 

Это работает для LoRA:

      lora_config = LoraConfig(
    r=32, lora_alpha=32, target_modules=["q", "v"], lora_dropout=0.05,  task_type=TaskType.SEQ_2_SEQ_LM 
)
lora_model = get_peft_model(original_model, lora_config)
lora_model.generate(**{'input_ids':input_ids})
>> Out[92]: tensor([[  0, 794,   1]])

Вот версии:

      torch.__version__
>>Out[93]: '2.0.1+cu117'

transformers.__version__
>>Out[95]: '4.26.1'

peft.__version__
Out[98]: '0.3.0'

0 ответов

Другие вопросы по тегам