Как сохранить модель при использовании MXnet

Я использую MXnet для обучения CNN (в R), и я могу обучить модель без каких-либо ошибок с помощью следующего кода:

model <- mx.model.FeedForward.create(symbol=network,
                                     X=train.iter,
                                     ctx=mx.gpu(0),
                                     num.round=20,
                                     array.batch.size=batch.size,
                                     learning.rate=0.1,
                                     momentum=0.1,  
                                     eval.metric=mx.metric.accuracy,
                                     wd=0.001,
                                     batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
    )

Но так как этот процесс занимает много времени, я запускаю его на сервере ночью и хочу сохранить модель для использования после окончания обучения.

Я использовал:

save(list = ls(), file="mymodel.RData")

а также

mx.model.save("mymodel", 10)

Но никто из них не может спасти модель! например, когда я загружаю "mymodel.RData"Я не могу предсказать метки для тестового набора!

Другой пример, когда я загружаю "mymodel.RData" и попробуйте построить его с помощью следующего кода:

graph.viz(model$symbol$as.json())

Я получаю следующую ошибку:

Error in model$symbol$as.json() : external pointer is not valid

Кто-нибудь может дать мне решение для сохранения и последующей загрузки этой модели для будущего использования?

Спасибо

2 ответа

Решение

Вы можете сохранить модель по

model <- mx.model.FeedForward.create(symbol=network,
                                 X=train.iter,
                                 ctx=mx.gpu(0),
                                 num.round=20,
                                 array.batch.size=batch.size,
                                 learning.rate=0.1,
                                 momentum=0.1,  
                                 eval.metric=mx.metric.accuracy,
                                 wd=0.001,
                                 epoch.end.callback=mx.callback.save.checkpoint("model_prefix")
                                 batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
)

Модель mxnet - это список R, но ее первый компонент - это не объект R, а указатель C++, и его нельзя сохранить и перезагрузить как объект R. Следовательно, модель должна быть сериализована, чтобы вести себя как реальный объект R. Сериализованный объект также является списком, но его первый объект - это текстовая строка, содержащая информацию о модели.

Чтобы сохранить модель:

modelR <- mx.serialize(model)
save(modelR, file="~/model1.RData")

Чтобы получить его и использовать снова:

load("~/model1.RData", verbose=TRUE)
model <- mx.unserialize(modelR)

Для сохранения снимка вашего прогресса в обучении рекомендуется использовать save_snapshot ( http://mxnet.io/api/python/module.html) как часть обратного вызова после каждого обучения эпохи. В R эквивалентная команда, вероятно, mx.callback.save.checkpoint, но я не использую R и не уверен в ее использовании.

Использование этих снимков также может позволить вам воспользоваться преимуществом недорогого варианта использования AWS Spot market ( https://aws.amazon.com/ec2/spot/pricing/), который, например, теперь предлагает, например, 16 графических процессоров K80. по цене $3,8/ час по сравнению с ценой по требованию $14,4. Такая скидка 80%-90% распространена на спотовом рынке и может оптимизировать скорость и стоимость обучения, если вы правильно используете эти снимки.

Другие вопросы по тегам