Подведение итогов с помощью Huggingface: как генерировать по одному слову за раз?

Я использую DistilBART для абстрактного обобщения. Методgenerate()очень прост в использовании. Однако он возвращает полные, законченные сводки. Что я хочу, так это на каждом шаге получать доступ к логитам, чтобы затем получить список кандидатов на следующее слово и выбрать на основе моих собственных критериев. После выбора продолжайте со следующим словом и так далее, пока не будет создан токен EOS.

Я знаю, что могу получить доступ к логитам, выполнив model(**input).logits[:, -1, :], но здесь вводом будет весь (закодированный) текст, так чему же будут соответствовать эти логиты? Первый сгенерированный токен? Последний?

Спасибо за ответ!

1 ответ

Для справки в будущем, вот как это можно сделать (примечание: это относится к моделям кодировщик-декодер, таким как BART):

1. Инициализация

      import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load model
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-1-1")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-1-1")

text = "..."

# Tokenize text
batch = tokenizer(text, return_tensors="pt")

2. Пример 1: Генерация сводки с жадным декодированием (без кеша)

      generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial token

# Generation loop
while True:
    with torch.no_grad():
        output = model(input_ids=batch["input_ids"], decoder_input_ids=generated_sequence)
    next_token_logits = output.logits[:, -1, :]
    next_token_scores = next_token_logits.softmax(dim=-1)

    # Take token with highest probability
    next_token = next_token_scores.argmax().unsqueeze(0).unsqueeze(0)

    # Append token to generated sequence
    generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
    # Stop if EOS token generated
    if (generated_sequence.squeeze()[-1] == tokenizer.eos_token_id):
        break

summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)

3. Пример 2: Генерация сводки с выборкой и температурой top-k, top-p (без кеша)

      from transformers.generation_utils import top_k_top_p_filtering

temperature = 0.7
generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial token

# Generation loop
while True:
    with torch.no_grad():
        output = model(input_ids=batch["input_ids"], decoder_input_ids=generated_sequence)
    logits = output.logits[:, -1, :] / temperature  # apply temperature
    filtered_logits = top_k_top_p_filtering(logits=logits, top_k=4, top_p=0.7)
    probabilities = filtered_logits.softmax(dim=-1)

    # Sample next token
    next_token = torch.multinomial(probabilities, 1)

    # Append token to generated sequence
    generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
    # Stop if EOS token generated
    if (generated_sequence.squeeze()[-1] == tokenizer.eos_token_id):
        break

summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)

(Другие стратегии генерации будут аналогичными).


4. Использование кеша

Поскольку входные данные для кодировщика (т. е. суммируемый текст) всегда одни и те же, мы можем кэшировать их, чтобы значительно ускорить генерацию.

      generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial token
input_ids = batch["input_ids"]
past_key_values = None

with torch.no_grad():
    output = model(
        input_ids=input_ids,
        decoder_input_ids=generated_sequence,
        past_key_values=past_key_values
    )
    
encoder_outputs=output.encoder_last_hidden_state

# Generation loop
while True:
    # From here on, use cached attention
    past_key_values = output.past_key_values
    next_token_logits = output.logits[:, -1, :]
    next_token_scores = next_token_logits.softmax(dim=-1)
    next_token = next_token_scores.argmax().unsqueeze(0).unsqueeze(0)  # greedy decoding
    generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
    # Stop if EOS token generated
    if (generated_sequence.squeeze()[-1] == tokenizer.eos_token_id):
        break
    with torch.no_grad():
        output = model(
            decoder_input_ids=torch.tensor([[generated_sequence.squeeze()[-1]]]),
            past_key_values=past_key_values,
            encoder_outputs=encoder_outputs
        )

summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)
Другие вопросы по тегам