Как я могу загрузить частично предварительно обученную модель pytorch?
Я пытаюсь запустить модель pytorch для задачи классификации предложений. Поскольку я работаю с медицинскими записями, я использую ClinicalBert (https://github.com/kexinhuang12345/clinicalBERT) и хотел бы использовать его предварительно обученные веса. К сожалению, модель ClinicalBert классифицирует текст только по одной двоичной метке, в то время как у меня есть 281 двоичная метка. Поэтому я пытаюсь реализовать этот код https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb, где конечный классификатор после bert имеет длину 281.
Как я могу загрузить предварительно обученные веса Берта из модели ClinicalBert без загрузки классификационных весов?
Наивно пытаясь загрузить веса из предварительно натренированных весов ClinicalBert, я получаю следующую ошибку:
size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).
В настоящее время я пытался заменить функцию from_pretrained из пакета pytorch_pretrained_bert и добавить веса и смещения классификатора следующим образом:
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
...
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path, map_location='cpu')
state_dict.pop('classifier.weight')
state_dict.pop('classifier.bias')
old_keys = []
new_keys = []
...
И я получаю следующее сообщение об ошибке: INFO - models_diagnosis - Веса BertForMultiLabelSequenceClassification не инициализированы из предварительно обученной модели: ['classifier.weight', 'classifier.bias']
В конце я хотел бы загрузить вложения bert из предварительно тренированных весов ClinicalBert и случайным образом инициализировать веса верхнего классификатора.
1 ответ
Удаление ключей в состоянии dict перед загрузкой - хорошее начало. Предполагая, что вы используете nn.Module.load_state_dict
чтобы загрузить предварительно тренированные веса, вам также необходимо установить strict=False
аргумент, чтобы избежать ошибок из-за неожиданных или отсутствующих ключей. Это будет игнорировать записи в state_dict, которых нет в модели (неожиданные ключи), и, что более важно для вас, оставит отсутствующие записи с их инициализацией по умолчанию (отсутствующие ключи). В целях безопасности вы можете проверить возвращаемое значение метода, чтобы убедиться, что рассматриваемые веса являются частью отсутствующих ключей и что нет никаких неожиданных ключей.