Заголовок: Генерация предложений с помощью TRL при сохранении тональности — проблема с AutoModelForCausalLMWithValueHead

В настоящее время я работаю над созданием предложений с помощью TRL (Transformers Reinforcement Learning), сохраняя при этом то же настроение, что и примеры предложений. Однако я столкнулся с проблемой с кодом TRL, который использует, который в первую очередь предназначен для генерации ответов, а не образца текста.

Я был бы очень признателен за любые рекомендации или предложения о том, как решить эту проблему и соответствующим образом изменить код TRL для создания образца текста с сохранением тональности.

Заранее благодарю за ценную информацию!

Вот код:

      # 0. imports
import torch
from transformers import GPT2Tokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer


# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# 2. initialize trainer
ppo_config = {"batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

# 3. encode a query
query_txt = "I want to rewrite this sentence with the same sentiment; ex. I really like this movie "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)

# 4. generate model response
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])

# 5. define a reward for a response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]

# 6. train model with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

0 ответов

Другие вопросы по тегам