В Tensorflow, получить имена всех тензоров в графе
Я создаю нейронные сети с Tensorflow
а также skflow
; по какой-то причине я хочу получить значения некоторых внутренних тензоров для данного ввода, поэтому я использую myClassifier.get_layer_value(input, "tensorName")
, myClassifier
быть skflow.estimators.TensorFlowEstimator
,
Однако мне трудно найти правильный синтаксис имени тензора, даже зная его имя (и я путаюсь между операцией и тензорами), поэтому я использую тензорную доску для построения графика и поиска имени.
Есть ли способ перечислить все тензоры в графе без использования тензорной доски?
6 ответов
Ты можешь сделать
[n.name for n in tf.get_default_graph().as_graph_def().node]
Кроме того, если вы создаете прототипы в записной книжке IPython, вы можете отобразить график прямо в записной книжке, см. show_graph
функция в записной книжке Александра Deep Dream
Попробую обобщить ответы:
Чтобы получить все узлы (введитеtensorflow.core.framework.node_def_pb2.NodeDef
):
all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]
Чтобы получить все операции (введитеtensorflow.python.framework.ops.Operation
):
all_ops = tf.get_default_graph().get_operations()
Чтобы получить все переменные (введитеtensorflow.python.ops.resource_variable_ops.ResourceVariable
):
all_vars = tf.global_variables()
Чтобы получить все тензоры (введитеtensorflow.python.framework.ops.Tensor
):
all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]
Есть способ сделать это немного быстрее, чем в ответе Ярослава, используя get_operations. Вот быстрый пример:
import tensorflow as tf
a = tf.constant(1.3, name='const_A')
b = tf.Variable(3.1, name='b')
c = tf.add(a, b, name='addition')
d = tf.multiply(c, a, name='multiply')
for op in tf.get_default_graph().get_operations():
print str(op.name)
tf.all_variables()
может получить вам информацию, которую вы хотите.
Кроме того, этот коммит, сделанный сегодня в TensorFlow Learn, предоставляет функцию get_variable_names
в оценщике, который вы можете использовать, чтобы легко получить все имена переменных.
Я думаю, что это тоже будет делать:
print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))
Но по сравнению с ответами Сальвадо и Ярослава я не знаю, какой из них лучше.
Принятый ответ дает только список строк с именами. Я предпочитаю другой подход, который дает вам (почти) прямой доступ к тензорам:
graph = tf.get_default_graph()
list_of_tuples = [op.values() for op in graph.get_operations()]
list of tuples
теперь содержит каждый тензор, каждый внутри кортежа. Вы также можете адаптировать его для получения тензоров напрямую:
graph = tf.get_default_graph()
list_of_tuples = [op.values()[0] for op in graph.get_operations()]
Поскольку OP запрашивает список тензоров вместо списка операций / узлов, код должен немного отличаться:
graph = tf.get_default_graph()
tensors_per_node = [node.values() for node in graph.get_operations()]
tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]
Предыдущие ответы хороши, я просто хотел бы поделиться функцией полезности, которую я написал, чтобы выбрать Тензор из графика:
def get_graph_op(graph, and_conds=None, op='and', or_conds=None):
"""Selects nodes' names in the graph if:
- The name contains all items in and_conds
- OR/AND depending on op
- The name contains any item in or_conds
Condition starting with a "!" are negated.
Returns all ops if no optional arguments is given.
Args:
graph (tf.Graph): The graph containing sought tensors
and_conds (list(str)), optional): Defaults to None.
"and" conditions
op (str, optional): Defaults to 'and'.
How to link the and_conds and or_conds:
with an 'and' or an 'or'
or_conds (list(str), optional): Defaults to None.
"or conditions"
Returns:
list(str): list of relevant tensor names
"""
assert op in {'and', 'or'}
if and_conds is None:
and_conds = ['']
if or_conds is None:
or_conds = ['']
node_names = [n.name for n in graph.as_graph_def().node]
ands = {
n for n in node_names
if all(
cond in n if '!' not in cond
else cond[1:] not in n
for cond in and_conds
)}
ors = {
n for n in node_names
if any(
cond in n if '!' not in cond
else cond[1:] not in n
for cond in or_conds
)}
if op == 'and':
return [
n for n in node_names
if n in ands.intersection(ors)
]
elif op == 'or':
return [
n for n in node_names
if n in ands.union(ors)
]
Так что если у вас есть график с опс:
['model/classifier/dense/kernel',
'model/classifier/dense/kernel/Assign',
'model/classifier/dense/kernel/read',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd',
'model/classifier/ArgMax/dimension',
'model/classifier/ArgMax']
Потом работает
get_graph_op(tf.get_default_graph(), ['dense', '!kernel'], 'or', ['Assign'])
возвращает:
['model/classifier/dense/kernel/Assign',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd']
Следующее решение работает для меня в TensorFlow 2.3 -
def load_pb(path_to_pb):
with tf.io.gfile.GFile(path_to_pb, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='')
return graph
tf_graph = load_pb(MODEL_FILE)
sess = tf.compat.v1.Session(graph=tf_graph)
# Show tensor names in graph
for op in tf_graph.get_operations():
print(op.values())
где
MODEL_FILE
путь к вашему замороженному графику.
Взято отсюда.
Это сработало для меня:
for n in tf.get_default_graph().as_graph_def().node:
print('\n',n)