IndexError: список индексов вне диапазона при сохранении модели keras

Я пытаюсь сохранить модель Keras/Tensorflow через этот код:

def export_model(saver, model, input_node_names, output_node_name):
    MODEL_NAME = 'lyme_model'
    tf.train.write_graph(K.get_session().graph_def, 'out', \
    MODEL_NAME + '_graph.pbtxt')

    saver.save(K.get_session(), 'out/' + MODEL_NAME + '.chkp')

    input_graph_path = 'out/'+MODEL_NAME+'.pbtxt'
    checkpoint_path = 'out/'+MODEL_NAME+'.chkp'
    input_saver_def_path = ""
    input_binary = False
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_frozen_graph_name = 'out/frozen_'+MODEL_NAME+'.pb'
    clear_devices = True


    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                      input_binary, checkpoint_path, output_node_name,
                      restore_op_name, filename_tensor_name,
                      output_frozen_graph_name, clear_devices, "")


    input_graph_def = tf.GraphDef()
    with tf.gfile.Open('out/frozen_' + MODEL_NAME + '.pb', "rb") as f:
         input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            input_graph_def, input_node_names, [output_node_name],
            tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
         f.write(output_graph_def.SerializeToString())

    print("graph saved!")

export_model(tf.train.Saver(), model, ["input_1"], "prediction")

и я получаю следующее сообщение об ошибке от Jupyter Notebook:

IndexErrorTraceback (most recent call last)
<ipython-input-24-31885f61798a> in <module>()
      4 #print([node.op.name for node in model.outputs])
      5 #model.summary()
----> 6 export_model(tf.train.Saver(), model, ["input_1"], ["prediction"])

<ipython-input-22-f4d33dc4c333> in export_model(saver, model, input_node_names, output_node_name)
     22                           input_binary, checkpoint_path, output_node_name,
     23                           restore_op_name, filename_tensor_name,
---> 24                           output_frozen_graph_name, clear_devices, "")
     25 
     26 

/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/freeze_graph.pyc in freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, output_graph, clear_devices, initializer_nodes, variable_names_whitelist, variable_names_blacklist, input_meta_graph, input_saved_model_dir, saved_model_tags, checkpoint_version)
    252       input_saved_model_dir,
    253       saved_model_tags.replace(" ", "").split(","),
--> 254       checkpoint_version=checkpoint_version)
    255 
    256 

/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/freeze_graph.pyc in freeze_graph_with_def_protos(***failed resolving arguments***)
    126         var_list[key] = tensor
    127       saver = saver_lib.Saver(
--> 128           var_list=var_list, write_version=checkpoint_version)
    129       saver.restore(sess, input_checkpoint)
    130       if initializer_nodes:

/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in __init__(self, var_list, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, saver_def, builder, defer_build, allow_empty, write_version, pad_step_number, save_relative_paths, filename)
   1336           time.time() + self._keep_checkpoint_every_n_hours * 3600)
   1337     elif not defer_build:
-> 1338       self.build()
   1339     if self.saver_def:
   1340       self._check_saver_def()

/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in build(self)
   1345     if context.executing_eagerly():
   1346       raise RuntimeError("Use save/restore instead of build in eager mode.")
-> 1347     self._build(self._filename, build_save=True, build_restore=True)
   1348 
   1349   def _build_eager(self, checkpoint_path, build_save, build_restore):

/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in _build(self, checkpoint_path, build_save, build_restore)
   1382           restore_sequentially=self._restore_sequentially,
   1383           filename=checkpoint_path,
-> 1384           build_save=build_save, build_restore=build_restore)
   1385     elif self.saver_def and self._name:
   1386       # Since self._name is used as a name_scope by builder(), we are

/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in _build_internal(self, names_to_saveables, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours, name, restore_sequentially, filename, build_save, build_restore)
    811                        " when eager execution is not enabled.")
    812 
--> 813     saveables = self._ValidateAndSliceInputs(names_to_saveables)
    814     if max_to_keep is None:
    815       max_to_keep = 0

/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in _ValidateAndSliceInputs(self, names_to_saveables)
    718           else:
    719             saveable = BaseSaverBuilder.ResourceVariableSaveable(
--> 720                 variable, "", name)
    721         self._AddSaveable(saveables, seen_ops, saveable)
    722     return saveables

/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in __init__(self, var, slice_spec, name)
    192       self._var_shape = var.shape
    193       if isinstance(var, ops.Tensor):
--> 194         self.handle_op = var.op.inputs[0]
    195         tensor = var
    196       elif isinstance(var, resource_variable_ops.ResourceVariable):

/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.pyc in __getitem__(self, i)
   2101 
   2102     def __getitem__(self, i):
-> 2103       return self._inputs[i]
   2104 
   2105 # pylint: enable=protected-access

IndexError: list index out of range

Что могло вызвать эту проблему? Я даже нигде не вижу операций со списками, поэтому я предполагаю, что они происходят где-то в бэкэнде Keras/Tensorflow. Интересно, как это исправить. Спасибо!

PS метод export_model() был взят отсюда.

0 ответов

Другие вопросы по тегам