Как сохранить модель gpt-2-simple после обучения?
Я обучил модель чат-бота gpt-2-simple , но не могу ее сохранить. Для меня важно загрузить обученную модель из Colab, потому что в противном случае мне придется каждый раз загружать модель 355M (см. код ниже).
Я пробовал различные методы сохранения обученной модели (например,gpt2.saveload.save_gpt2()
), но ничего не помогло, и у меня больше нет идей.
Мой код обучения:
%tensorflow_version 2.x
!pip install gpt-2-simple
import gpt_2_simple as gpt2
import json
gpt2.download_gpt2(model_name="355M")
raw_data = '/content/drive/My Drive/data.json'
with open(raw_data, 'r') as f:
df =json.load(f)
data = []
for x in df:
for y in range(len(x['dialog'])-1):
a = '[BOT] : ' + x['dialog'][y+1]['text']
q = '[YOU] : ' + x['dialog'][y]['text']
data.append(q)
data.append(a)
with open('chatbot.txt', 'w') as f:
for line in data:
try:
f.write(line)
f.write('\n')
except:
pass
file_name = "/content/chatbot.txt"
sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
dataset=file_name,
model_name='355M',
steps=500,
restore_from='fresh',
run_name='run1',
print_every=10,
sample_every=100,
save_every=100
)
while True:
ques = input("Question : ")
inp = '[YOU] : '+ques+'\n'+'[BOT] :'
x = gpt2.generate(sess,
length=20,
temperature = 0.6,
include_prefix=False,
prefix=inp,
nsamples=1,
)
1 ответ
Репозиторий gpt-2-simple README.md содержит ссылку на пример блокнота Colab , в котором указано следующее:
Другие необязательные, но полезные параметры gpt2.finetune:
restore_from
: Установлен вfresh
начать обучение с базовой GPT-2 или установить последнюю версиюrestart
обучение с существующего КПП.- ...
- : подпапка внутри контрольной точки для сохранения модели. Это полезно, если вы хотите работать с несколькими моделями (также нужно будет указать при загрузке модели)
overwrite
: Установлен вTrue
если вы хотите продолжить настройку существующей модели (сrestore_from='latest'
) без создания дубликатов.
В README.md также указано, что контрольные точки модели хранятся в/checkpoint/run1
по умолчанию и что можно передатьrun_name
параметр дляfinetune
иload_gpt2
если вы хотите сохранить/загрузить несколько моделей в папке контрольной точки.
В целом вы сможете делать следующее, чтобы работать с сохраненными моделями вместо повторной загрузки каждый раз:
import gpt_2_simple as gpt2
sess = gpt2.start_tf_sess()
# To load existing model in default checkpoint dir from "run1"
gpt2.load_gpt2(sess)
# Or, to finetune existing model in default checkpoint dir from "run1"
gpt2.finetune(sess,
dataset=file_name,
model_name='355M',
steps=500,
restore_from='latest',
run_name='run1',
overwrite=True,
print_every=10,
sample_every=100,
save_every=500
)
Дополнительную информацию см. в исходном коде функций load_gpt2() и Finetune() .