IndexError: list index out of range при сохранении модели Keras
Я пытаюсь сохранить модель Keras/TensorFlow с помощью следующего кода:
def _as_node_list(node_names):
    """Normalize graph node names to a list of non-empty, stripped strings.

    Accepts either a comma-separated string ("a" or "a,b") or an iterable of
    strings (["a", "b"]). freeze_graph wants a comma-separated string, while
    optimize_for_inference wants a list, so export_model converts both ways
    from this single normalized form.
    """
    # type(u"") is `unicode` on Python 2 and `str` on Python 3.
    if isinstance(node_names, (str, type(u""))):
        parts = node_names.split(",")
    else:
        parts = list(node_names)
    return [name.strip() for name in parts if name.strip()]


def export_model(saver, model, input_node_names, output_node_name,
                 model_name='lyme_model', out_dir='out'):
    """Save the current Keras session graph, freeze it and optimize it for inference.

    Files written into ``out_dir``:
      * ``<model_name>_graph.pbtxt`` - text GraphDef of ``K.get_session()``
      * ``<model_name>.chkp*``       - checkpoint written by ``saver``
      * ``frozen_<model_name>.pb``   - graph with variables folded to constants
      * ``opt_<model_name>.pb``      - frozen graph after optimize_for_inference

    Args:
        saver: a ``tf.train.Saver`` for the session's variables. NOTE(review):
            it presumably must be created after the model is built so that it
            covers the model's variables - confirm at the call site.
        model: unused; kept only for interface compatibility.
        input_node_names: input node name(s), as a list or a comma-separated
            string, passed to optimize_for_inference.
        output_node_name: output node name(s): "name", "a,b" or ["a", "b"].
        model_name: base name for all output files (previously hard-coded).
        out_dir: directory for all output files (previously hard-coded).

    Raises:
        ValueError: if ``output_node_name`` names no node.
    """
    import os

    input_nodes = _as_node_list(input_node_names)
    output_nodes = _as_node_list(output_node_name)
    if not output_nodes:
        raise ValueError("output_node_name must name at least one graph node")

    sess = K.get_session()

    # Write the text GraphDef and remember its exact path. The original code
    # wrote '<name>_graph.pbtxt' but froze '<name>.pbtxt', i.e. a different
    # (possibly stale) graph. NOTE(review): a graph/checkpoint mismatch like
    # that is the likely source of the IndexError at `var.op.inputs[0]`.
    graph_file = model_name + '_graph.pbtxt'
    tf.train.write_graph(sess.graph_def, out_dir, graph_file)
    input_graph_path = os.path.join(out_dir, graph_file)

    checkpoint_path = os.path.join(out_dir, model_name + '.chkp')
    saver.save(sess, checkpoint_path)

    frozen_graph_path = os.path.join(out_dir, 'frozen_' + model_name + '.pb')
    freeze_graph.freeze_graph(
        input_graph_path,
        "",                       # input_saver: none
        False,                    # input_binary: the graph file is text
        checkpoint_path,
        ",".join(output_nodes),   # freeze_graph expects a comma-separated string
        "save/restore_all",       # restore_op_name, as in the original call
        "save/Const:0",           # filename_tensor_name, as in the original call
        frozen_graph_path,
        True,                     # clear_devices
        "")                       # initializer_nodes: none

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open(frozen_graph_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    # optimize_for_inference expects lists of node names.
    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def, input_nodes, output_nodes,
        tf.float32.as_datatype_enum)

    optimized_path = os.path.join(out_dir, 'opt_' + model_name + '.pb')
    with tf.gfile.FastGFile(optimized_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("graph saved!")
# Freeze and optimize the current session graph. The node names must match
# real nodes in the graph; NOTE(review): verify them (e.g. by printing
# [node.op.name for node in model.outputs]) - "prediction" may need a suffix.
export_model(saver=tf.train.Saver(),
             model=model,
             input_node_names=["input_1"],
             output_node_name="prediction")
И получаю следующее сообщение об ошибке в Jupyter Notebook:
IndexErrorTraceback (most recent call last)
<ipython-input-24-31885f61798a> in <module>()
4 #print([node.op.name for node in model.outputs])
5 #model.summary()
----> 6 export_model(tf.train.Saver(), model, ["input_1"], ["prediction"])
<ipython-input-22-f4d33dc4c333> in export_model(saver, model, input_node_names, output_node_name)
22 input_binary, checkpoint_path, output_node_name,
23 restore_op_name, filename_tensor_name,
---> 24 output_frozen_graph_name, clear_devices, "")
25
26
/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/freeze_graph.pyc in freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, output_graph, clear_devices, initializer_nodes, variable_names_whitelist, variable_names_blacklist, input_meta_graph, input_saved_model_dir, saved_model_tags, checkpoint_version)
252 input_saved_model_dir,
253 saved_model_tags.replace(" ", "").split(","),
--> 254 checkpoint_version=checkpoint_version)
255
256
/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/freeze_graph.pyc in freeze_graph_with_def_protos(***failed resolving arguments***)
126 var_list[key] = tensor
127 saver = saver_lib.Saver(
--> 128 var_list=var_list, write_version=checkpoint_version)
129 saver.restore(sess, input_checkpoint)
130 if initializer_nodes:
/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in __init__(self, var_list, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, saver_def, builder, defer_build, allow_empty, write_version, pad_step_number, save_relative_paths, filename)
1336 time.time() + self._keep_checkpoint_every_n_hours * 3600)
1337 elif not defer_build:
-> 1338 self.build()
1339 if self.saver_def:
1340 self._check_saver_def()
/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in build(self)
1345 if context.executing_eagerly():
1346 raise RuntimeError("Use save/restore instead of build in eager mode.")
-> 1347 self._build(self._filename, build_save=True, build_restore=True)
1348
1349 def _build_eager(self, checkpoint_path, build_save, build_restore):
/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in _build(self, checkpoint_path, build_save, build_restore)
1382 restore_sequentially=self._restore_sequentially,
1383 filename=checkpoint_path,
-> 1384 build_save=build_save, build_restore=build_restore)
1385 elif self.saver_def and self._name:
1386 # Since self._name is used as a name_scope by builder(), we are
/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in _build_internal(self, names_to_saveables, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, filename, build_save, build_restore)
811 " when eager execution is not enabled.")
812
--> 813 saveables = self._ValidateAndSliceInputs(names_to_saveables)
814 if max_to_keep is None:
815 max_to_keep = 0
/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in _ValidateAndSliceInputs(self, names_to_saveables)
718 else:
719 saveable = BaseSaverBuilder.ResourceVariableSaveable(
--> 720 variable, "", name)
721 self._AddSaveable(saveables, seen_ops, saveable)
722 return saveables
/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in __init__(self, var, slice_spec, name)
192 self._var_shape = var.shape
193 if isinstance(var, ops.Tensor):
--> 194 self.handle_op = var.op.inputs[0]
195 tensor = var
196 elif isinstance(var, resource_variable_ops.ResourceVariable):
/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.pyc in __getitem__(self, i)
2101
2102 def __getitem__(self, i):
-> 2103 return self._inputs[i]
2104
2105 # pylint: enable=protected-access
IndexError: list index out of range
Что могло вызвать эту проблему? В своём коде я не вижу операций со списками, поэтому предполагаю, что ошибка возникает где-то внутри бэкенда Keras/TensorFlow. Как это исправить? Спасибо!
P.S. Функция export_model() была взята отсюда.