Как сохранить модель gpt-2-simple после обучения?

Я обучил модель чат-бота gpt-2-simple , но не могу ее сохранить. Для меня важно загрузить обученную модель из Colab, потому что в противном случае мне придется каждый раз загружать модель 355M (см. код ниже).

Я пробовал различные методы сохранения обученной модели (например,gpt2.saveload.save_gpt2()), но ничего не помогло, и у меня больше нет идей.

Мой код обучения:

      %tensorflow_version 2.x
!pip install gpt-2-simple

import gpt_2_simple as gpt2
import json

gpt2.download_gpt2(model_name="355M")

raw_data = '/content/drive/My Drive/data.json'

with open(raw_data, 'r') as f:
    df =json.load(f)

data = []

for x in df:
    for y in range(len(x['dialog'])-1):
        a = '[BOT] : ' + x['dialog'][y+1]['text']
        q = '[YOU] : ' + x['dialog'][y]['text']
        data.append(q)
        data.append(a)

with open('chatbot.txt', 'w') as f:
     for line in data:
        try:
            f.write(line)
            f.write('\n')
        except:
            pass

file_name = "/content/chatbot.txt"

sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=500,
              restore_from='fresh',
              run_name='run1',
              print_every=10,
              sample_every=100,
              save_every=100
              )

while True:
  ques = input("Question : ")
  inp = '[YOU] : '+ques+'\n'+'[BOT] :'
  x = gpt2.generate(sess,
                length=20,
                temperature = 0.6,
                include_prefix=False,
                prefix=inp,
                nsamples=1,
                )

1 ответ

Репозиторий gpt-2-simple README.md содержит ссылку на пример блокнота Colab , в котором указано следующее:

Другие необязательные, но полезные параметры gpt2.finetune:

  • restore_from: Установлен вfreshначать обучение с базовой GPT-2 или установить последнюю версиюrestartобучение с существующего КПП.
  • ...
  • : подпапка внутри контрольной точки для сохранения модели. Это полезно, если вы хотите работать с несколькими моделями (также нужно будет указать при загрузке модели)
  • overwrite: Установлен вTrueесли вы хотите продолжить настройку существующей модели (сrestore_from='latest') без создания дубликатов.

В README.md также указано, что контрольные точки модели хранятся в/checkpoint/run1по умолчанию и что можно передатьrun_nameпараметр дляfinetuneиload_gpt2если вы хотите сохранить/загрузить несколько моделей в папке контрольной точки.

В целом вы сможете делать следующее, чтобы работать с сохраненными моделями вместо повторной загрузки каждый раз:

      import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()

# To load existing model in default checkpoint dir from "run1"
gpt2.load_gpt2(sess)

# Or, to finetune existing model in default checkpoint dir from "run1"
gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=500,
              restore_from='latest',
              run_name='run1',
              overwrite=True,
              print_every=10,
              sample_every=100,
              save_every=500
)

Дополнительную информацию см. в исходном коде функций load_gpt2() и Finetune() .

Другие вопросы по тегам