Я настраиваю модель «t5-small» для стандартизации жаргонного текста. Я не могу получить правильный результат даже для примера из обучающей выборки.
Пример из обучающего набора: input_text = «у тебя очень сексуальный заголовок». Я ожидал, что модель изменит букву «u» на «ты».
Я думаю, что есть какая-то проблема с форматом набора данных или параметрами. Я прилагаю свой обучающий код, а также код, который я использовал для использования модели. Любая помощь будет принята с благодарностью.
Вот как выглядит мой набор данных в формате CSV:введите здесь описание изображения.
Это мой тренировочный код:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import Dataset
import pandas as pd
# Step 1: Load and preprocess the dataset
df = pd.read_csv('train_full.csv')[:100]
input_texts = df['prompt'].tolist()
target_texts = df['completion'].tolist()
# Step 2: Tokenize the input and target texts along with prompts
tokenizer = AutoTokenizer.from_pretrained("t5-small", model_max_length=512)
prompts = ["Translate the following slang expression to standard language: " for _ in input_texts]
inputs = tokenizer(prompts, input_texts, padding=True, truncation=True, return_tensors="pt")
labels = tokenizer(target_texts, padding=True, truncation=True, return_tensors="pt")
# Step 3: Convert the tokenized dataset into lists of tensors
train_dataset = Dataset.from_dict({
'input_ids': inputs['input_ids'].tolist(),
'attention_mask': inputs['attention_mask'].tolist(),
'decoder_input_ids': labels['input_ids'].tolist(),
'decoder_attention_mask': labels['attention_mask'].tolist(),
})
# Step 4: Define the model architecture and load a pre-trained model
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small") # You can use any other model as well
# Step 5: Define a custom data collator for the text-to-text generation task
class Text2TextDataCollator(DataCollatorForSeq2Seq):
def __call__(self, features):
input_ids = torch.tensor([feature["input_ids"] for feature in features])
attention_mask = torch.tensor([feature["attention_mask"] for feature in features])
decoder_input_ids = torch.tensor([feature["decoder_input_ids"] for feature in features])
decoder_attention_mask = torch.tensor([feature["decoder_attention_mask"] for feature in features])
labels = decoder_input_ids.clone()
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"labels": labels,
}
# Step 6: Fine-tune the model on the dataset using the custom data collator
training_args = Seq2SeqTrainingArguments(
output_dir="./text2text_generator",
num_train_epochs=3, # You can adjust the number of epochs based on your dataset and resources.
per_device_train_batch_size=8,
save_steps=10,
save_total_limit=2,
)
data_collator = Text2TextDataCollator(tokenizer)
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator,
)
trainer.train()
# Step 7: Save the fine-tuned model and tokenizer for future use
trainer.save_model("./fine_tuned_model_6")
Это мой код для получения вывода:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Step 1: Load the fine-tuned model and tokenizer
model_path = "./fine_tuned_model_6" # Replace with the path to your saved model directory
tokenizer = AutoTokenizer.from_pretrained("t5-small", model_max_length=512)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
prompt = "Translate the following slang expression to standard language: "
input_text = "u have a very sexy header rawr"
# Concatenate the prompt and input text
combined_text = prompt + input_text
# Tokenize and encode the combined text
input_ids = tokenizer.encode(combined_text, return_tensors="pt")
# Generate the output using the loaded model
output_ids = model.generate(input_ids, max_length=30) # Adjust max_length as needed
# Decode the output and skip special tokens
standardized_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(standardized_text)
Мой вывод таков:Translate the following slang expression to standard language: u have a very sexy header rawr rawr rawr