Tensorflow 2.1/Keras - ошибка "output_node is not in graph" при попытке заморозить график
Я пытаюсь сохранить модель, созданную с помощью Keras и сохраненную как файл.h5, но я получаю это сообщение об ошибке каждый раз, когда пытаюсь запустить функцию freeze_session: output_node / Identity не на графике
Это мой код (я использую Tensorflow 2.1.0):
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
Freezes the state of a session into a pruned computation graph.
Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
pruned so subgraphs that are not necessary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.compat.v1.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
session, input_graph_def, output_names, freeze_var_names)
return frozen_graph
model=kr.models.load_model("model.h5")
model.summary()
# inputs:
print('inputs: ', model.input.op.name)
# outputs:
print('outputs: ', model.output.op.name)
#layers:
layer_names=[layer.name for layer in model.layers]
print(layer_names)
Какие отпечатки:
inputs: input_node
outputs: output_node/Identity
['input_node', 'conv2d_6', 'max_pooling2d_6', 'conv2d_7', 'max_pooling2d_7', 'conv2d_8', 'max_pooling2d_8', 'flatten_2', 'dense_4', 'dense_5', 'output_node']
как и ожидалось (те же имена слоев и выходы, что и в модели, которую я сохранил после обучения).
Затем я пытаюсь вызвать функцию freeze_session и сохранить полученный замороженный график:
frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])
write_graph(frozen_graph, './', 'graph.pbtxt', as_text=True)
write_graph(frozen_graph, './', 'graph.pb', as_text=False)
но я получаю эту ошибку:
AssertionError Traceback (most recent call last)
<ipython-input-4-1848000e99b7> in <module>
----> 1 frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])
2 write_graph(frozen_graph, './', 'graph.pbtxt', as_text=True)
3 write_graph(frozen_graph, './', 'graph.pb', as_text=False)
<ipython-input-2-3214992381a9> in freeze_session(session, keep_var_names, output_names, clear_devices)
24 node.device = ""
25 frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
---> 26 session, input_graph_def, output_names, freeze_var_names)
27 return frozen_graph
c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\util\deprecation.py in new_func(*args, **kwargs)
322 'in a future version' if date is None else ('after %s' % date),
323 instructions)
--> 324 return func(*args, **kwargs)
325 return tf_decorator.make_decorator(
326 func, new_func, 'deprecated',
c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist, variable_names_blacklist)
275 # This graph only includes the nodes needed to evaluate the output nodes, and
276 # removes unneeded nodes like those involved in saving and assignment.
--> 277 inference_graph = extract_sub_graph(input_graph_def, output_node_names)
278
279 # Identify the ops in the graph.
c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\util\deprecation.py in new_func(*args, **kwargs)
322 'in a future version' if date is None else ('after %s' % date),
323 instructions)
--> 324 return func(*args, **kwargs)
325 return tf_decorator.make_decorator(
326 func, new_func, 'deprecated',
c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in extract_sub_graph(graph_def, dest_nodes)
195 name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
196 graph_def)
--> 197 _assert_nodes_are_present(name_to_node, dest_nodes)
198
199 nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)
c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in _assert_nodes_are_present(name_to_node, nodes)
150 """Assert that nodes are present in the graph."""
151 for d in nodes:
--> 152 assert d in name_to_node, "%s is not in graph" % d
153
154
**AssertionError: output_node/Identity is not in graph**
Я пробовал, но я действительно не знаю, как это исправить, поэтому любая помощь будет очень признательна.
1 ответ
Если вы используете Tensorflow версии 2.x, добавьте:
tf.compat.v1.disable_eager_execution()
Это должно работать. Я не проверял получившийся pb файл, но он должен работать.
Обратная связь приветствуется.
edit: Однако, следуя, например, этой теме, pb-файлы TF1 и TF2 принципиально отличаются. Мое решение может не работать должным образом или фактически создать файл pb TF1.
Если вы затем столкнетесь с
RuntimeError: попытка использовать закрытый сеанс.
Это можно решить, перезапустив ядро. У вас есть только один выстрел, используя строку выше.