Обучение продолжает останавливаться. Ошибка кортежа. (API Tensorflow Object_detection)

Я использую API обнаружения объектов tenorflows, когда я выполняю обучение, оно останавливается после нескольких итераций. Изначально у меня были мои изображения в формате jpg, из которых я создавал xml-файлы, преобразованные в csv, однако, люди упоминали, что причиной ошибки может быть то, что она была в jpg, а не в jpeg (хотя другие заставили его работать в формате ni jpg). Затем я преобразовал свои изображения в JPEG и выполнил остальные шаги, затем приступил к тренировкам, и возникла та же проблема. Я застрял в этом вопросе ооочень долго, но безрезультатно, и кажется, что там не так много рабочих решений. Если у кого-то есть идея, чтобы решить эту проблему, я был бы чрезвычайно благодарен. Код ниже

Instructions for updating:
Please switch to tf.train.get_or_create_global_step
WARNING:root:Variable [Conv/biases/Momentum] is not available in checkpoint
WARNING:root:Variable [Conv/weights/Momentum] is not available in checkpoint
WARNING:root:Variable [FirstStageBoxPredictor/BoxEncodingPredictor/biases/Momentum] is not available in checkpoint
WARNING:root:Variable [FirstStageBoxPredictor/BoxEncodingPredictor/weights/Momentum] is not available in checkpoint

....

    INFO:tensorflow:global step 1: loss = 1.6760 (13.660 sec/step)
INFO:tensorflow:global step 1: loss = 1.6760 (13.660 sec/step)
INFO:tensorflow:Finished training! Saving model to disk.
INFO:tensorflow:Finished training! Saving model to disk.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/summary/writer/writer.py:386: UserWarning: Attempting to use a closed FileWriter. The operation will be a noop unless the FileWriter is explicitly reopened.
  warnings.warn("Attempting to use a closed FileWriter. "
Traceback (most recent call last):
  File "train.py", line 185, in <module>
    tf.app.run()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "train.py", line 181, in main
    graph_hook_fn=graph_rewriter_fn)
  File "/usr/local/lib/python3.6/dist-packages/object_detection-0.1-py3.6.egg/object_detection/legacy/trainer.py", line 416, in train
    saver=saver)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/slim/python/slim/learning.py", line 785, in train
    ignore_live_threads=ignore_live_threads)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/supervisor.py", line 832, in stop
    ignore_live_threads=ignore_live_threads)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/coordinator.py", line 389, in join
    six.reraise(*self._exc_info_to_raise)
  File "/usr/local/lib/python3.6/dist-packages/six.py", line 693, in reraise
    raise value
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/queue_runner_impl.py", line 257, in _run
    enqueue_callable()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1257, in _single_operation_run
    self._call_tf_sessionrun(None, {}, [], target_list, None)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1407, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shape mismatch in tuple component 18. Expected [1,?,?,3], got [1,1,314,384,3]
     [[{{node batch/padding_fifo_queue_enqueue}}]]

1 ответ

Эта строка должна дать вам подсказку: Expected [1,?,?,3], got [1,1,314,384,3]Tensorflow использует 4D Tensors в качестве входного изображения модели, поэтому Tensor размера [1,?,?,3] ожидается. Тем не менее, вы предоставляете 5D Tensor. Я думаю, что есть один tf.expand_dims() много в вашем коде где-то.

Для тех, кто сталкивается с этой проблемой, проверьте ваш поезд и протестируйте файлы CSV, чтобы увидеть, есть ли какие-либо записи с шириной и высотой равными 0. Это обычно происходит, если изображение имеет другой формат с его расширением. Решите проблему, удалив эти изображения или преобразовав их в нужный формат, используя -

img = cv2.imread(test_full_path)
cv2.imwrite(test_full_path, img, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
Другие вопросы по тегам