Как добавить контрольную точку активации Deepspeed в LLM для тонкой настройки в PyTorch Lightning?

Я пытаюсь включить контрольную точку активации для модели T5-3b, чтобы значительно освободить память графического процессора. Однако не совсем понятно, как реализовать LLM. Судя по документации PTL , это примерно так:

      from lightning.pytorch import Trainer
import deepspeed


class MyModel(LightningModule):
    ...

    def __init__(self):
        super().__init__()
        self.block_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        self.block_2 = torch.nn.Linear(32, 2)

    def forward(self, x):
        # Use the DeepSpeed checkpointing function instead of calling the module directly
        # checkpointing self.block_1 means the activations are deleted after use,
        # and re-calculated during the backward passes
        x = deepspeed.checkpointing.checkpoint(self.block_1, x)
        return self.block_2(x)

Вот мой код модели PTL LLM. Я хочу добавить контрольную точку глубокой скорости, поэтому попытался сделать это следующим образом:

      class T5FineTuner(pl.LightningModule):
    """PyTorch Lightning T5 Model class"""

    def __init__(self, hparams, tokenizer, model):
        """initiates a PyTorch Lightning T5 Model"""
        super().__init__()
        self.hparams.update(vars(hparams))
        self.save_hyperparameters(self.hparams)

        self.model = model
        self.tokenizer = tokenizer
        self.outputdir = self.hparams.output_dir
        self.average_training_loss = None
        self.average_validation_loss = None
        self.save_only_last_epoch = self.hparams.save_only_last_epoch

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
                
        return deepspeed.checkpointing.checkpoint(self._forward, input_ids, attention_mask, decoder_attention_mask, labels)
    
    def _forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
                
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )

        return output.loss, output.logits

    def training_step(self, batch, batch_size):
        """training step"""
        input_ids = batch["source_text_input_ids"]
        attention_mask = batch["source_text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )

        self.log(
            "train_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=True,
            sync_dist=True,
        )
        return loss

К сожалению, я получаю следующую ошибку:

      RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Полный воспроизводимый пример можно найти здесь.

0 ответов

Другие вопросы по тегам