Ошибка: тензорный граф отличается от сеансового графа
Я пытаюсь загрузить ранее обученную тензорную обученную модель из файлов контрольных точек, теперь в этих файлах контрольных точек есть операционные переменные, поэтому для загрузки графика мне нужно сначала загрузить graph_def из файла **ckpt.meta:
graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)
saver = tf.train.import_meta_graph('/data/model_cache/model.ckpt-39.meta')
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
if os.path.isabs(ckpt.model_checkpoint_path):
saver.restore(sess, ckpt.model_checkpoint_path)
После того, как я загрузил модели, у меня есть метод, который использует эту модель для вывода для реализации алгоритма глубокого сна. Проблема в том, что когда я вызываю eval с сеансом по умолчанию, я получаю сообщение об ошибке ниже:
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 555, in eval
return _eval_using_default_session(self, feed_dict, self.graph, session)File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework /ops.py", line 3495, in _eval_using_default_session
raise ValueError("Cannot use the given session to evaluate tensor: "
ValueError: Cannot use the given session to evaluate tensor: the tensor's graph is different from the session's graph.
Я подтвердил, что tf.get_default_graph() и sess.graph указывают на один и тот же адрес памяти. Должно быть что-то очень простое, чего мне не хватает.
Я новичок в tenorflow, поэтому любая помощь в этом отношении будет по достоинству оценена. Спасибо
1 ответ
Я думаю, что ваша проблема в том, что вы путаете "Python-name" и "TensorFlow-name". Когда вы создаете, например: W = tf.get_variable("weight", ...)
"имя Python" будет W
тогда как "TensorFlow-name" будет weight
, При загрузке модели она не имеет представления о последних именах Python. Так что никогда не узнаешь что W
на самом деле
Сначала вы должны вернуть тензоры и операции, которые вы хотите использовать. Вы перечисляете их с помощью:
for tensor in tf.get_default_graph().get_operations():
print (tensor.name)
Тогда используйте оба get_operation_by_name(name)
а также get_tensor_by_name(name)
чтобы вернуть свои вещи.
Например, если вы хотите получить веса, как я сказал вам раньше, вы должны сделать:
W = graph.get_tensor_by_name("weights:0")
print(W.eval())
Я считаю, что должно работать.
Весьма вероятно, что мета-граф, который вы импортируете, т.е. /data/model_cache/model.ckpt-39.meta отличается от того, который проверяет tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
использовал.
Обычная практика - иметь get_checkpoint_state()
позвонить (или tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
) и использовать его вывод в import_meta_graph()
вызовите и затем, с тем же именем контрольной точки (и возвращенной заставкой), восстановите переменные в сеансе. Это, конечно, можно сделать, если мета-график сохраняется в каждой контрольной точке.