Tf-trt convert saved_model.pb не удалось
Вот мой код конвертации.
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow import gfile, compat
from tensorflow.core.protobuf import saved_model_pb2
# from tensorflow.core.protobuf import meta_graph_pb2
with tf.Session() as sess:
converter = trt.TrtGraphConverter(input_saved_model_dir=".",
input_saved_model_signature_key='predictor',
input_saved_model_tags=[tf.saved_model.tag_constants.SERVING],
is_dynamic_op=True,
precision_mode="FP16"
)
converter.convert()
converter.save("trt_model.pb")
И при его выполнении выдает исключения:
20066 Traceback (most recent call last):
20067 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_gr aph_def_internal
20068 graph._c_graph, serialized, options) # pylint: disable=protected-access
20069 tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node global_step/Assign was passed int64 from global_step:0 incompatible with expected int64_ref.
20070
20071 During handling of the above exception, another exception occurred:
20072
20073 Traceback (most recent call last):
20074 File "converter.py", line 24, in <module>
20075 converter.save("trt_model.pb")
20076 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/compiler/tensorrt/trt_convert.py", line 717, in save
20077 importer.import_graph_def(self._converted_graph_def, name="")
20078 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
20079 return func(*args, **kwargs)
20080 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_gra ph_def
20081 producer_op_list=producer_op_list)
20082 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_gr aph_def_internal
20083 raise ValueError(str(e))
20084 ValueError: Input 0 of node global_step/Assign was passed int64 from global_step:0 incompatible with expected int64_ref
Проверяю узел графа. перед преобразованием это похоже на
···
global_step/Initializer/zeros Const
global_step VariableV2
global_step/Assign Assign
global_step/read Identity
···
после
···
IteratorGetNext IteratorGetNext
global_step/Assign Assign
inference/embed_continuous/Initializer/truncated_normal/TruncatedNormal TruncatedNormal
···
более конкретно перед
name: "global_step/Assign"
op: "Assign"
input: "global_step"
input: "global_step/Initializer/zeros"
attr {
key: "T"
value {
type: DT_INT64
}
}
attr {
key: "_class"
value {
list {
s: "loc:@global_step"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
после
name: "global_step/Assign"
op: "Assign"
input: "global_step"
input: "global_step/Initializer/zeros"
attr {
key: "T"
value {
type: DT_INT64
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
должно быть что-то не так с моим кодом или моделью, но я не могу понять. Кто-нибудь может мне помочь?
Похоже, ваш пост - это в основном код; пожалуйста, добавьте более подробную информацию....