Обучение продолжает останавливаться. Ошибка кортежа. (API Tensorflow Object_detection)
Я использую API обнаружения объектов tenorflows, когда я выполняю обучение, оно останавливается после нескольких итераций. Изначально у меня были мои изображения в формате jpg, из которых я создавал xml-файлы, преобразованные в csv, однако, люди упоминали, что причиной ошибки может быть то, что она была в jpg, а не в jpeg (хотя другие заставили его работать в формате ni jpg). Затем я преобразовал свои изображения в JPEG и выполнил остальные шаги, затем приступил к тренировкам, и возникла та же проблема. Я застрял в этом вопросе ооочень долго, но безрезультатно, и кажется, что там не так много рабочих решений. Если у кого-то есть идея, чтобы решить эту проблему, я был бы чрезвычайно благодарен. Код ниже
Instructions for updating:
Please switch to tf.train.get_or_create_global_step
WARNING:root:Variable [Conv/biases/Momentum] is not available in checkpoint
WARNING:root:Variable [Conv/weights/Momentum] is not available in checkpoint
WARNING:root:Variable [FirstStageBoxPredictor/BoxEncodingPredictor/biases/Momentum] is not available in checkpoint
WARNING:root:Variable [FirstStageBoxPredictor/BoxEncodingPredictor/weights/Momentum] is not available in checkpoint
....
INFO:tensorflow:global step 1: loss = 1.6760 (13.660 sec/step)
INFO:tensorflow:global step 1: loss = 1.6760 (13.660 sec/step)
INFO:tensorflow:Finished training! Saving model to disk.
INFO:tensorflow:Finished training! Saving model to disk.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/summary/writer/writer.py:386: UserWarning: Attempting to use a closed FileWriter. The operation will be a noop unless the FileWriter is explicitly reopened.
warnings.warn("Attempting to use a closed FileWriter. "
Traceback (most recent call last):
File "train.py", line 185, in <module>
tf.app.run()
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
return func(*args, **kwargs)
File "train.py", line 181, in main
graph_hook_fn=graph_rewriter_fn)
File "/usr/local/lib/python3.6/dist-packages/object_detection-0.1-py3.6.egg/object_detection/legacy/trainer.py", line 416, in train
saver=saver)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/slim/python/slim/learning.py", line 785, in train
ignore_live_threads=ignore_live_threads)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/supervisor.py", line 832, in stop
ignore_live_threads=ignore_live_threads)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/coordinator.py", line 389, in join
six.reraise(*self._exc_info_to_raise)
File "/usr/local/lib/python3.6/dist-packages/six.py", line 693, in reraise
raise value
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/queue_runner_impl.py", line 257, in _run
enqueue_callable()
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1257, in _single_operation_run
self._call_tf_sessionrun(None, {}, [], target_list, None)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1407, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shape mismatch in tuple component 18. Expected [1,?,?,3], got [1,1,314,384,3]
[[{{node batch/padding_fifo_queue_enqueue}}]]
1 ответ
Эта строка должна дать вам подсказку: Expected [1,?,?,3], got [1,1,314,384,3]
Tensorflow использует 4D Tensors в качестве входного изображения модели, поэтому Tensor размера [1,?,?,3]
ожидается. Тем не менее, вы предоставляете 5D Tensor. Я думаю, что есть один tf.expand_dims()
много в вашем коде где-то.
Для тех, кто сталкивается с этой проблемой, проверьте ваш поезд и протестируйте файлы CSV, чтобы увидеть, есть ли какие-либо записи с шириной и высотой равными 0. Это обычно происходит, если изображение имеет другой формат с его расширением. Решите проблему, удалив эти изображения или преобразовав их в нужный формат, используя -
img = cv2.imread(test_full_path)
cv2.imwrite(test_full_path, img, [int(cv2.IMWRITE_JPEG_QUALITY), 100])