Как загрузить обученную модель, сохраненную с помощью export_inference_graph.py?
Я следую примеру, в котором используется API обнаружения объектов tenorflow 1.15.0. В учебнике четко рассматриваются следующие аспекты:
- как скачать модель
- как загрузить собственную базу данных с файлами.xml, сделать из них файлы.cvs, а затем файлы.record
- как настроить конвейер обучения
- как получить графики тензорной доски
- как обучить контрольные точки чистых сбережений (используя model_main.py)
- как экспортировать (сохранить) модель (используя export_inference_graph.py)
Однако я не смог загрузить сохраненную модель, чтобы использовать ее. Я пробовал сtf.saved_model.loader.load(sess, flags, export_dir
, но я получаю
INFO:tensorflow:Saver not created because there are no variables in the graph to restore.
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
папка, указанная в export_dir
имеет следующую структуру:
+dir
+saved_model
-saved_model.pb
-model.ckpt.data-00000-of-00001
-model.ckpt.index
-checkpoint
-frozen_inference_graph.pb
-model.ckpt.meta
-pipeline.config
Моя конечная цель здесь - захватывать изображения с помощью камеры и передавать их в сеть для обнаружения объектов в реальном времени.\ В качестве промежуточного шага теперь я просто хочу иметь возможность передавать одно изображение и получать результат. Я смог натренировать сетку, но теперь не могу ею пользоваться.
Заранее спасибо.
1 ответ
Я нашел пример того, как загрузить модель, который позволил мне пройти через это.\ Поскольку формат папки загружаемого в примере файла такой же, как и в моем коде, мне просто пришлось его адаптировать.
Исходная функция, загружающая модель:
def load_model(model_name):
base_url = 'http://download.tensorflow.org/models/object_detection/'
model_file = model_name + '.tar.gz'
model_dir = tf.keras.utils.get_file(
fname=model_name,
origin=base_url + model_file,
untar=True)
model_dir = pathlib.Path(model_dir)/"saved_model"
model = tf.saved_model.load(str(model_dir))
model = model.signatures['serving_default']
return model
Затем я использовал эту функцию для создания этого нового
def load_local_model(model_path):
model_dir = pathlib.Path(model_path)/"saved_model"
model = tf.saved_model.load(str(model_dir))
model = model.signatures['serving_default']
return model
Сначала это не сработало, так как tf.saved_model.load
ожидалось 3 аргумента, но это было решено путем импорта двух блоков импорта в том же примере, я не знаю, какой импорт сделал трюк и почему (я отредактирую этот ответ, когда получу его), но на данный момент этот код работает и пример позволяет делать больше вещей.
Блоки импорта следующие
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from IPython.display import display
а также
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
РЕДАКТИРОВАТЬ Что действительно нужно для этого, так это следующий блок.
import os
import pathlib
if "models" in pathlib.Path.cwd().parts:
while "models" in pathlib.Path.cwd().parts:
os.chdir('..')
elif not pathlib.Path('models').exists():
!git clone --depth 1 https://github.com/tensorflow/models
%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
%%bash
cd models/research
pip install .
Otherwhise этот блок импорта не будет работать
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util