Как правильно возобновить обучение сети из файла контрольной точки tenorflow?
Я изо всех сил пытаюсь восстановить модель за один день без какого-либо успеха. Мой код состоит из class TF_MLPRegressor()
где я определяю сетевую архитектуру внутри конструктора. Затем я призываю fit()
функция, чтобы сделать обучение. Так вот, как я сохраняю простую модель Perceptron с 1 скрытым слоем внутри fit()
функция:
starting_epoch = 0
# Launch the graph
tf.set_random_seed(self.random_state) # fix the random seed before creating the Session in order to take effect!
if hasattr(self, 'sess'):
self.sess.close()
del self.sess # delete Session to release memory
gc.collect()
self.sess = tf.Session(config=self.config) # save the session to predict from new data
# Create a saver object which will save all the variables
saver = tf.train.Saver(max_to_keep=2) # max_to_keep=2 means to not keep more than 2 checkpoint files
self.sess.run(tf.global_variables_initializer())
# ... (each 100 epochs)
saver.save(self.sess, self.checkpoint_dir+"/resume", global_step=epoch)
Затем я создаю новый TF_MLPRegressor()
экземпляр с точно такими же значениями входного параметра и вызвать fit()
функция для восстановления модели так:
self.sess = tf.Session(config=self.config) # create a new session to load saved variables
ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
starting_epoch = int(ckpt.split('-')[-1])
metagraph = ".".join([ckpt, 'meta'])
saver = tf.train.import_meta_graph(metagraph)
self.sess.run(tf.global_variables_initializer()) # Initialize variables
lhl = tf.trainable_variables()[2]
lhlA = lhl.eval(session=self.sess)
saver.restore(sess=self.sess, save_path=ckpt) # Restore model weights from previously saved model
lhlB = lhl.eval(session=self.sess)
print lhlA == lhlB
lhlA
а также lhlB
последние веса скрытых слоев до и после восстановления, и согласно моему коду они полностью совпадают, а именно сохраненная модель не загружается в сеанс. Что я делаю неправильно?
1 ответ
Я нашел обходной путь! Как ни странно, метаграф не содержит всех переменных, которые я определил, или присваивает им новые имена. Для примеров в конструкторе я определяю тензоры, которые будут переносить входные векторы объектов и экспериментальные значения:
self.x = tf.placeholder("float", [None, feat_num], name='x')
self.y = tf.placeholder("float", [None], name='y')
Тем не менее, когда я делаю tf.reset_default_graph()
и загрузить метаграф, я получаю следующий список переменных:
[
<tf.Variable 'Variable:0' shape=(300, 300) dtype=float32_ref>,
<tf.Variable 'Variable_1:0' shape=(300,) dtype=float32_ref>,
<tf.Variable 'Variable_2:0' shape=(300, 1) dtype=float32_ref>,
<tf.Variable 'Variable_3:0' shape=(1,) dtype=float32_ref>
]
Для записи, каждый входной вектор объектов имеет 300 объектов. Во всяком случае, когда я позже попробую начать тренировку, используя:
_, c, p = self.sess.run([self.optimizer, self.cost, self.pred],
feed_dict={self.x: batch_x, self.y: batch_y, self.isTrain: True})
Я получаю ошибку как:
"TypeError: Cannot interpret feed_dict key as Tensor: Tensor 'x' is not an element of this graph."
Итак, каждый раз, когда я создаю экземпляр class TF_MLPRegressor()
Я определяю сетевую архитектуру внутри конструктора, я решил не загружать метаграф, и это сработало! Я не знаю, почему TF не сохраняет все переменные в мета-графике, возможно, потому что я явно определяю сетевую архитектуру (я не использую обертки или слои по умолчанию), как в примере ниже:
Подводя итог, я сохраняю свои модели, как описано в моем первом сообщении, но для их восстановления я использую это:
saver = tf.train.Saver(max_to_keep=2)
self.sess = tf.Session(config=self.config) # create a new session to load saved variables
self.sess.run(tf.global_variables_initializer())
ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
saver.restore(sess=self.sess, save_path=ckpt) # Restore model weights from previously saved model