Как сохранить модель при использовании MXnet
Я использую MXnet для обучения CNN (в R), и я могу обучить модель без каких-либо ошибок с помощью следующего кода:
model <- mx.model.FeedForward.create(symbol=network,
X=train.iter,
ctx=mx.gpu(0),
num.round=20,
array.batch.size=batch.size,
learning.rate=0.1,
momentum=0.1,
eval.metric=mx.metric.accuracy,
wd=0.001,
batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
)
Но так как этот процесс занимает много времени, я запускаю его на сервере ночью и хочу сохранить модель для использования после окончания обучения.
Я использовал:
save(list = ls(), file="mymodel.RData")
а также
mx.model.save("mymodel", 10)
Но никто из них не может спасти модель! например, когда я загружаю "mymodel.RData"
Я не могу предсказать метки для тестового набора!
Другой пример, когда я загружаю "mymodel.RData"
и попробуйте построить его с помощью следующего кода:
graph.viz(model$symbol$as.json())
Я получаю следующую ошибку:
Error in model$symbol$as.json() : external pointer is not valid
Кто-нибудь может дать мне решение для сохранения и последующей загрузки этой модели для будущего использования?
Спасибо
2 ответа
Вы можете сохранить модель по
model <- mx.model.FeedForward.create(symbol=network,
X=train.iter,
ctx=mx.gpu(0),
num.round=20,
array.batch.size=batch.size,
learning.rate=0.1,
momentum=0.1,
eval.metric=mx.metric.accuracy,
wd=0.001,
epoch.end.callback=mx.callback.save.checkpoint("model_prefix")
batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
)
Модель mxnet - это список R, но ее первый компонент - это не объект R, а указатель C++, и его нельзя сохранить и перезагрузить как объект R. Следовательно, модель должна быть сериализована, чтобы вести себя как реальный объект R. Сериализованный объект также является списком, но его первый объект - это текстовая строка, содержащая информацию о модели.
Чтобы сохранить модель:
modelR <- mx.serialize(model)
save(modelR, file="~/model1.RData")
Чтобы получить его и использовать снова:
load("~/model1.RData", verbose=TRUE)
model <- mx.unserialize(modelR)
Для сохранения снимка вашего прогресса в обучении рекомендуется использовать save_snapshot ( http://mxnet.io/api/python/module.html) как часть обратного вызова после каждого обучения эпохи. В R эквивалентная команда, вероятно, mx.callback.save.checkpoint, но я не использую R и не уверен в ее использовании.
Использование этих снимков также может позволить вам воспользоваться преимуществом недорогого варианта использования AWS Spot market ( https://aws.amazon.com/ec2/spot/pricing/), который, например, теперь предлагает, например, 16 графических процессоров K80. по цене $3,8/ час по сравнению с ценой по требованию $14,4. Такая скидка 80%-90% распространена на спотовом рынке и может оптимизировать скорость и стоимость обучения, если вы правильно используете эти снимки.