Обратное распространение ошибки/минибатчинг при обучении больших языковых моделей (LLM)
Я изо всех сил пытаюсь понять, как работает обратное распространение для LLM на основе трансформатора.
Вот мое предположение о том, как работает этот процесс. Учитывая последовательность токенов длиной 64, мы обрабатываем последовательность параллельно, используя форсирование учителя (т. е. для каждой ФАКТИЧЕСКОЙ последовательной подпоследовательности, начиная с первого токена, мы ПРОГНОЗИРУЕМ следующий токен и вычисляем потери на основе нового предсказанного токена и фактического следующего токена). токен, что создает 63 значения перекрестной энтропии).
Мы делаем это для многих (скажем, размера пакета 8192) последовательностей одновременно, в одном мини-пакете, во время предварительного обучения. Затем мы делаем шаг обратного распространения ошибки по сети и корректируем веса — до сих пор мы сделали только один шаг. Затем мы переходим к следующей партии последовательностей размером 8192.
- Верно ли это понимание?
- Если да, то усредним ли мы 63 потери для одной последовательности?
- Усредняем ли мы потери по 8192 последовательностям?
- Если не усреднять, как накапливаются потери для обратного распространения ошибки для одной мини-партии и почему?
Пытался найти статьи, подробно объясняющие этот процесс для языковых моделей, но, похоже, не нашел ни одной - большинство из них были посвящены нейронным сетям в целом и не проясняли некоторые из моих вопросов о языковых последовательностях.
1 ответ
Для большинства авторегрессионных языковых моделей в основном существуют две процедуры обучения:
- Модель повседневного языка (по одному слову предсказывает другое слово)
- Модель языка в маске (учитывая фиксированное пространство последовательностей, предскажите лучшее слово, которое вписывается в маску)
Скорее всего, после популярности GPT-3 в 2022 году вы захотите понять программу обучения модели повседневного языка (CLM).
Вот пример кода, который демонстрирует, когда происходит обратное распространение ошибки . Обратите внимание наloss.backwards()
/accelerator.backwards()
, https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py#L616C1-L626C38
for step, batch in enumerate(active_dataloader):
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
Мы видим, что выполняем обратное распространение ошибки в конце каждого пакета. И каждая партия состоит из нескольких точек входных данных изDataLoader
объект. В случае наборов языковых данных точка данных, скорее всего, представляет собой входное предложение/абзац фиксированного размера с 512/1024 токенами подслов.
Вопрос: Когда в модели обычного языка происходит обратное распространение ошибки?
О: После прямого прохода модели через размер пакета n и максимальную длину последовательности l мы вычисляем потери в результате обратного распространения ошибки для шагов nxl .
Далее мы рассмотрим, какloss.backward()
работает, в обычных ванильных трансформаторах потери, скорее всего, рассчитываются с помощью
- для задач с недоумением только для декодера функция потерь - это недоумение
- для задач перевода кодировщик-декодер сглаживание меток, которое вычисляет расхождение KL между предсказанными словами и фактическими словами.
Для каждого nxl мы вычисляем потерю того, правильно ли мы предсказали слова. Некоторые конкретные примеры см. на https://towardsdatascience.com/cross-entropy-negative-log-likelihood-and-all-that-jazz-47a95bd2e81.
Потери NLL/недоумения рассчитываются на 1 xl для n предложений в большинстве случаев, когда потери при прогнозировании предложения обычно не влияют на другие предложения в пакете. Поэтому усреднение потерь для nxl по n вполне разумно.
Обратите внимание, что мы не вычисляем потери NLL для каждого токена в каждом предложении, а вычисляем потери NLL для всех токенов в предложении. Для более подробной информации это довольно длинно, но см. Раздел 7.7 на https://web.stanford.edu/~jurafsky/slp3/7.pdf. Как правило, идея та же самая: мы вычисляем выходную последовательность при каждом прямом проходе, затем проверяем, является ли каждый вычисленный нами токен правильным/неправильным, двоичную метку для каждой пары прогнозов и целевых токенов, затем вычисляются потери для всей последовательности прогнозов. за предложение.
Вопрос: Какие потери рассчитываются при обучении языковой модели?
A: Скорее всего, недоумение для простых языковых моделей, предназначенных только для декодера, а также расхождение KL или потеря перекрестной энтропии для задач seq2seq.
Вопрос: Хватит медлить, ответьте на вопрос: «Усредняем ли мы потери по последовательностям 8192?»
О: Предполагая, что 8192 — это n * l шагов для каждого пакета, т. е. если пакет состоит из 8 предложений длиной 1024, мы вычисляем потери для каждого пакета.
Определение «мини-пакета», «полного пакета» или «эпохи» различается в зависимости от того, кого вы спрашиваете, поэтому давайте в данном случае назовем 8 предложений пакетом.
Итак, «Мы усредняем?»
Ответ: В большинстве случаев у нас есть две потери: потеря обучения и потеря оценки/проверки.
Что касается потерь при проверке, обычно мы вычисляем недоумение для каждого ввода, а затем усредняем его, https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py#L642C1-L656C38
model.eval()
losses = []
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
loss = outputs.loss
losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))
losses = torch.cat(losses)
try:
eval_loss = torch.mean(losses)
perplexity = math.exp(eval_loss)
except OverflowError:
perplexity = float("inf")
Что касается потерь обучения, для каждой партии рассчитываются потери для каждого токена n * l , потери суммируются, а затем усредняются, когда регистратор сообщает о потерях обучения.
if args.with_tracking:
total_loss = 0
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We skip the first `n` batches in the dataloader when resuming from a checkpoint
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
...
if args.with_tracking:
accelerator.log(
{
"perplexity": perplexity,
"eval_loss": eval_loss,
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
А если говорить конкретно о модели GPT-2, вы можете видеть, что она выполняет одно и то же накопление для каждой партии n * l , см. https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L1269
Вопрос: Почему вы не отвечаете, средний ли он?
A: Это хороший вопрос! В большинстве случаев вы вычисляете потери для каждого пакета, да, вы усредняете потери из-за недоумения для каждого пакета, которые вы видите для всего набора данных. Но на самом деле процедура обучения полностью зависит от того, кто кодирует модель.
В большинстве моделей вы можете усреднить потери для nxl по n на партию, а затем макроусреднить потери для каждой партии по номеру. партий в эпоху, чтобы сообщить о потере обучения. Если усреднение выполнено, оно в основном делится только на n для каждой потери, рассчитанной для прогнозов 1 xl .
Хотя это интуитивно понятно, для каждой модели внутриtransformers
, каждый выбирает разные потери. https://github.com/huggingface/transformers/tree/main/src/transformers/models
Существует множество других факторов, влияющих на время обновления модели и расчета потерь. Если градиенты накапливаются, а обновления задерживаются, например https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.gradient_accumulation_steps , то среднее значение будет не только по nxl , но и, возможно, по dxnxl , где d — номер. шаги, которые вы отложили.
Эпилог
Итак, чтобы правильно ответить, усреднены потери или как/когда модель обновляет градиенты на основе потерь с обратным распространением ошибки, вы должны указать:
- Какая модель/архитектура используется? И если есть существующая база кода, которую вы можете проверить. Различные реализации также могут иметь разные процедуры расчета потерь.
- Какие гиперпараметры обучения используются, особенно. связанные с оптимизаторами?
- Какую языковую модель вы тренируете? В большинстве случаев я описывал модели случайного языка, но есть и другие, такие как модели замаскированного языка или даже неавторегрессионные модели, которые имеют другие потери.