КАК точно настроить предварительно обученную модель для пользовательских данных

Я хочу использовать предварительно подготовленную модель сегментации niftynet для сегментации пользовательских данных. Я скачал предварительно обученные веса и изменил путь model_dir к загруженному.

Однако когда я бегу

python3 net_segment.py train -c /home/Container_data/config/promise12_demo_train_config.ini

Я получаю ошибку ниже.

Caused by op 'save/Assign_17', defined at:
    File "net_segment.py", line 8, in <module>
      sys.exit(main())
    File "/home/NiftyNet/niftynet/__init__.py", line 142, in main
      app_driver.run(app_driver.app)
    File "/home/NiftyNet/niftynet/engine/application_driver.py", line 197, in run
      SESS_STARTED.send(application, iter_msg=None)
    File "/usr/local/lib/python3.5/dist-packages/blinker/base.py", line 267, in send
      for receiver in self.receivers_for(sender)]
    File "/usr/local/lib/python3.5/dist-packages/blinker/base.py", line 267, in <listcomp>
      for receiver in self.receivers_for(sender)]
    File "/home/NiftyNet/niftynet/engine/handler_model.py", line 109, in restore_model
      var_list=to_restore, save_relative_paths=True)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1102, in __init__
      self.build()
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1114, in build
      self._build(self._filename, build_save=True, build_restore=True)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1151, in _build
      build_save=build_save, build_restore=build_restore)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 795, in _build_internal
      restore_sequentially, reshape)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 428, in _AddRestoreOps
      assign_ops.append(saveable.restore(saveable_tensors, shapes))
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 119, in restore
      self.op.get_shape().is_fully_defined())
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/state_ops.py", line 221, in assign
      validate_shape=validate_shape)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 61, in assign
      use_locking=use_locking, name=name)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
      op_def=op_def)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
      return func(*args, **kwargs)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
      op_def=op_def)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
      self._traceback = tf_stack.extract_stack()
  InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
  Assign requires shapes of both tensors to match. lhs shape= [3,3,61,256] rhs shape= [3,3,3,61,9]
           [[node save/Assign_17 (defined at /home/NiftyNet/niftynet/engine/handler_model.py:109)  = Assign[T=DT_FLOAT, _class=["loc:@DenseVNet/conv/conv_/w"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](DenseVNet/conv/conv_/w, save/RestoreV2/_35)

https://github.com/tensorflow/models/issues/5390 Над ссылкой сказано добавить

--initialize_last_layer = False
--last_layers_contain_logits_only = False

Может кто-нибудь помочь мне, как избавиться от этой ошибки.

1 ответ

Кажется, у вас проблемы с вашим последним слоем. Когда вы используете предварительно обученную модель в новой задаче, вам, вероятно, нужно изменить последний слой, чтобы он соответствовал вашим новым требованиям.

Чтобы сделать это, вы должны изменить свой конфигурационный файл, восстановив все переменные, кроме последнего слоя: vars_to_restore = ^((?!(last_layer_name)).)*$

а затем установить num_classes чтобы удовлетворить вашу новую проблему сегментации.

Вы можете проверить учебные документы по передаче здесь: https://niftynet.readthedocs.io/en/dev/transfer_learning.html

Другие вопросы по тегам