Как добавить контрольную точку активации Deepspeed в LLM для тонкой настройки в PyTorch Lightning?
Я пытаюсь включить контрольную точку активации для модели T5-3b, чтобы значительно освободить память графического процессора. Однако не совсем понятно, как реализовать LLM. Судя по документации PTL , это примерно так:
from lightning.pytorch import Trainer
import deepspeed
class MyModel(LightningModule):
...
def __init__(self):
super().__init__()
self.block_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
self.block_2 = torch.nn.Linear(32, 2)
def forward(self, x):
# Use the DeepSpeed checkpointing function instead of calling the module directly
# checkpointing self.block_1 means the activations are deleted after use,
# and re-calculated during the backward passes
x = deepspeed.checkpointing.checkpoint(self.block_1, x)
return self.block_2(x)
Вот мой код модели PTL LLM. Я хочу добавить контрольную точку глубокой скорости, поэтому попытался сделать это следующим образом:
class T5FineTuner(pl.LightningModule):
"""PyTorch Lightning T5 Model class"""
def __init__(self, hparams, tokenizer, model):
"""initiates a PyTorch Lightning T5 Model"""
super().__init__()
self.hparams.update(vars(hparams))
self.save_hyperparameters(self.hparams)
self.model = model
self.tokenizer = tokenizer
self.outputdir = self.hparams.output_dir
self.average_training_loss = None
self.average_validation_loss = None
self.save_only_last_epoch = self.hparams.save_only_last_epoch
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
return deepspeed.checkpointing.checkpoint(self._forward, input_ids, attention_mask, decoder_attention_mask, labels)
def _forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
output = self.model(
input_ids,
attention_mask=attention_mask,
labels=labels,
decoder_attention_mask=decoder_attention_mask,
)
return output.loss, output.logits
def training_step(self, batch, batch_size):
"""training step"""
input_ids = batch["source_text_input_ids"]
attention_mask = batch["source_text_attention_mask"]
labels = batch["labels"]
labels_attention_mask = batch["labels_attention_mask"]
loss, outputs = self(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_attention_mask=labels_attention_mask,
labels=labels,
)
self.log(
"train_loss",
loss,
prog_bar=True,
logger=True,
on_epoch=True,
on_step=True,
sync_dist=True,
)
return loss
К сожалению, я получаю следующую ошибку:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Полный воспроизводимый пример можно найти здесь.