Как правильно возобновить обучение сети из файла контрольной точки tenorflow?

Я изо всех сил пытаюсь восстановить модель за один день без какого-либо успеха. Мой код состоит из class TF_MLPRegressor()где я определяю сетевую архитектуру внутри конструктора. Затем я призываю fit() функция, чтобы сделать обучение. Так вот, как я сохраняю простую модель Perceptron с 1 скрытым слоем внутри fit() функция:

            starting_epoch = 0
            # Launch the graph
            tf.set_random_seed(self.random_state)   # fix the random seed before creating the Session in order to take effect!
            if hasattr(self, 'sess'):
                self.sess.close()
                del self.sess   # delete Session to release memory
                gc.collect()
            self.sess = tf.Session(config=self.config) # save the session to predict from new data
            # Create a saver object which will save all the variables
            saver = tf.train.Saver(max_to_keep=2)  # max_to_keep=2 means to not keep more than 2 checkpoint files
            self.sess.run(tf.global_variables_initializer())

# ... (each 100 epochs)

            saver.save(self.sess, self.checkpoint_dir+"/resume", global_step=epoch)

Затем я создаю новый TF_MLPRegressor() экземпляр с точно такими же значениями входного параметра и вызвать fit() функция для восстановления модели так:

    self.sess = tf.Session(config=self.config)  # create a new session to load saved variables
    ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
    starting_epoch = int(ckpt.split('-')[-1])
    metagraph = ".".join([ckpt, 'meta'])
    saver = tf.train.import_meta_graph(metagraph)
    self.sess.run(tf.global_variables_initializer())    # Initialize variables
    lhl = tf.trainable_variables()[2]
    lhlA = lhl.eval(session=self.sess)
    saver.restore(sess=self.sess, save_path=ckpt)   # Restore model weights from previously saved model
    lhlB = lhl.eval(session=self.sess)
    print lhlA == lhlB

lhlA а также lhlB последние веса скрытых слоев до и после восстановления, и согласно моему коду они полностью совпадают, а именно сохраненная модель не загружается в сеанс. Что я делаю неправильно?

1 ответ

Я нашел обходной путь! Как ни странно, метаграф не содержит всех переменных, которые я определил, или присваивает им новые имена. Для примеров в конструкторе я определяю тензоры, которые будут переносить входные векторы объектов и экспериментальные значения:

self.x = tf.placeholder("float", [None, feat_num], name='x')
self.y = tf.placeholder("float", [None], name='y')

Тем не менее, когда я делаю tf.reset_default_graph() и загрузить метаграф, я получаю следующий список переменных:

[
<tf.Variable 'Variable:0' shape=(300, 300) dtype=float32_ref>, 
<tf.Variable 'Variable_1:0' shape=(300,) dtype=float32_ref>, 
<tf.Variable 'Variable_2:0' shape=(300, 1) dtype=float32_ref>, 
<tf.Variable 'Variable_3:0' shape=(1,) dtype=float32_ref>
]

Для записи, каждый входной вектор объектов имеет 300 объектов. Во всяком случае, когда я позже попробую начать тренировку, используя:

_, c, p = self.sess.run([self.optimizer, self.cost, self.pred], 
feed_dict={self.x: batch_x, self.y: batch_y, self.isTrain: True})

Я получаю ошибку как:

"TypeError: Cannot interpret feed_dict key as Tensor: Tensor 'x' is not an element of this graph."

Итак, каждый раз, когда я создаю экземпляр class TF_MLPRegressor()Я определяю сетевую архитектуру внутри конструктора, я решил не загружать метаграф, и это сработало! Я не знаю, почему TF не сохраняет все переменные в мета-графике, возможно, потому что я явно определяю сетевую архитектуру (я не использую обертки или слои по умолчанию), как в примере ниже:

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py

Подводя итог, я сохраняю свои модели, как описано в моем первом сообщении, но для их восстановления я использую это:

saver = tf.train.Saver(max_to_keep=2)
self.sess = tf.Session(config=self.config)  # create a new session to load saved variables
self.sess.run(tf.global_variables_initializer())
ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
saver.restore(sess=self.sess, save_path=ckpt)   # Restore model weights from previously saved model
Другие вопросы по тегам