Обратное распространение ошибки/минибатчинг при обучении больших языковых моделей (LLM)

Я изо всех сил пытаюсь понять, как работает обратное распространение для LLM на основе трансформатора.

Вот мое предположение о том, как работает этот процесс. Учитывая последовательность токенов длиной 64, мы обрабатываем последовательность параллельно, используя форсирование учителя (т. е. для каждой ФАКТИЧЕСКОЙ последовательной подпоследовательности, начиная с первого токена, мы ПРОГНОЗИРУЕМ следующий токен и вычисляем потери на основе нового предсказанного токена и фактического следующего токена). токен, что создает 63 значения перекрестной энтропии).

Мы делаем это для многих (скажем, размера пакета 8192) последовательностей одновременно, в одном мини-пакете, во время предварительного обучения. Затем мы делаем шаг обратного распространения ошибки по сети и корректируем веса — до сих пор мы сделали только один шаг. Затем мы переходим к следующей партии последовательностей размером 8192.

  1. Верно ли это понимание?
  2. Если да, то усредним ли мы 63 потери для одной последовательности?
  3. Усредняем ли мы потери по 8192 последовательностям?
  4. Если не усреднять, как накапливаются потери для обратного распространения ошибки для одной мини-партии и почему?

Пытался найти статьи, подробно объясняющие этот процесс для языковых моделей, но, похоже, не нашел ни одной - большинство из них были посвящены нейронным сетям в целом и не проясняли некоторые из моих вопросов о языковых последовательностях.

1 ответ

Для большинства авторегрессионных языковых моделей в основном существуют две процедуры обучения:

  • Модель повседневного языка (по одному слову предсказывает другое слово)
  • Модель языка в маске (учитывая фиксированное пространство последовательностей, предскажите лучшее слово, которое вписывается в маску)

Скорее всего, после популярности GPT-3 в 2022 году вы захотите понять программу обучения модели повседневного языка (CLM).

Вот пример кода, который демонстрирует, когда происходит обратное распространение ошибки . Обратите внимание наloss.backwards()/accelerator.backwards(), https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py#L616C1-L626C38

              for step, batch in enumerate(active_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                if args.with_tracking:
                    total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

Мы видим, что выполняем обратное распространение ошибки в конце каждого пакета. И каждая партия состоит из нескольких точек входных данных изDataLoaderобъект. В случае наборов языковых данных точка данных, скорее всего, представляет собой входное предложение/абзац фиксированного размера с 512/1024 токенами подслов.

Вопрос: Когда в модели обычного языка происходит обратное распространение ошибки?

О: После прямого прохода модели через размер пакета n и максимальную длину последовательности l мы вычисляем потери в результате обратного распространения ошибки для шагов nxl .


Далее мы рассмотрим, какloss.backward()работает, в обычных ванильных трансформаторах потери, скорее всего, рассчитываются с помощью

Для каждого nxl мы вычисляем потерю того, правильно ли мы предсказали слова. Некоторые конкретные примеры см. на https://towardsdatascience.com/cross-entropy-negative-log-likelihood-and-all-that-jazz-47a95bd2e81.

Потери NLL/недоумения рассчитываются на 1 xl для n предложений в большинстве случаев, когда потери при прогнозировании предложения обычно не влияют на другие предложения в пакете. Поэтому усреднение потерь для nxl по n вполне разумно.

Обратите внимание, что мы не вычисляем потери NLL для каждого токена в каждом предложении, а вычисляем потери NLL для всех токенов в предложении. Для более подробной информации это довольно длинно, но см. Раздел 7.7 на https://web.stanford.edu/~jurafsky/slp3/7.pdf. Как правило, идея та же самая: мы вычисляем выходную последовательность при каждом прямом проходе, затем проверяем, является ли каждый вычисленный нами токен правильным/неправильным, двоичную метку для каждой пары прогнозов и целевых токенов, затем вычисляются потери для всей последовательности прогнозов. за предложение.

Вопрос: Какие потери рассчитываются при обучении языковой модели?

A: Скорее всего, недоумение для простых языковых моделей, предназначенных только для декодера, а также расхождение KL или потеря перекрестной энтропии для задач seq2seq.


Вопрос: Хватит медлить, ответьте на вопрос: «Усредняем ли мы потери по последовательностям 8192?»

О: Предполагая, что 8192 — это n * l шагов для каждого пакета, т. е. если пакет состоит из 8 предложений длиной 1024, мы вычисляем потери для каждого пакета.

Определение «мини-пакета», «полного пакета» или «эпохи» различается в зависимости от того, кого вы спрашиваете, поэтому давайте в данном случае назовем 8 предложений пакетом.

Итак, «Мы усредняем?»

Ответ: В большинстве случаев у нас есть две потери: потеря обучения и потеря оценки/проверки.

Что касается потерь при проверке, обычно мы вычисляем недоумение для каждого ввода, а затем усредняем его, https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py#L642C1-L656C38

              model.eval()
        losses = []
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                outputs = model(**batch)

            loss = outputs.loss
            losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))

        losses = torch.cat(losses)
        try:
            eval_loss = torch.mean(losses)
            perplexity = math.exp(eval_loss)
        except OverflowError:
            perplexity = float("inf")

Что касается потерь обучения, для каждой партии рассчитываются потери для каждого токена n * l , потери суммируются, а затем усредняются, когда регистратор сообщает о потерях обучения.

              if args.with_tracking:
            total_loss = 0
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            active_dataloader = train_dataloader
        for step, batch in enumerate(active_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                if args.with_tracking:
                    total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

        ...

        if args.with_tracking:
            accelerator.log(
                {
                    "perplexity": perplexity,
                    "eval_loss": eval_loss,
                    "train_loss": total_loss.item() / len(train_dataloader),
                    "epoch": epoch,
                    "step": completed_steps,
                },
                step=completed_steps,
            )

А если говорить конкретно о модели GPT-2, вы можете видеть, что она выполняет одно и то же накопление для каждой партии n * l , см. https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L1269

Вопрос: Почему вы не отвечаете, средний ли он?

A: Это хороший вопрос! В большинстве случаев вы вычисляете потери для каждого пакета, да, вы усредняете потери из-за недоумения для каждого пакета, которые вы видите для всего набора данных. Но на самом деле процедура обучения полностью зависит от того, кто кодирует модель.

В большинстве моделей вы можете усреднить потери для nxl по n на партию, а затем макроусреднить потери для каждой партии по номеру. партий в эпоху, чтобы сообщить о потере обучения. Если усреднение выполнено, оно в основном делится только на n для каждой потери, рассчитанной для прогнозов 1 xl .

Хотя это интуитивно понятно, для каждой модели внутриtransformers, каждый выбирает разные потери. https://github.com/huggingface/transformers/tree/main/src/transformers/models

Существует множество других факторов, влияющих на время обновления модели и расчета потерь. Если градиенты накапливаются, а обновления задерживаются, например https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.gradient_accumulation_steps , то среднее значение будет не только по nxl , но и, возможно, по dxnxl , где d — номер. шаги, которые вы отложили.

Эпилог

Итак, чтобы правильно ответить, усреднены потери или как/когда модель обновляет градиенты на основе потерь с обратным распространением ошибки, вы должны указать:

  • Какая модель/архитектура используется? И если есть существующая база кода, которую вы можете проверить. Различные реализации также могут иметь разные процедуры расчета потерь.
  • Какие гиперпараметры обучения используются, особенно. связанные с оптимизаторами?
  • Какую языковую модель вы тренируете? В большинстве случаев я описывал модели случайного языка, но есть и другие, такие как модели замаскированного языка или даже неавторегрессионные модели, которые имеют другие потери.
Другие вопросы по тегам