Загрузка предварительно обученного ВГГ-16 на тензор потока

Я пытаюсь загрузить предварительно обученную сеть vgg-16 с помощью tenorflow r1.1. Сеть представлена ​​в 3 файлах:

  • saved_model.pb
  • Переменные /variables.index
  • Переменные /variables.data-00000-оф-00001

После инициализации переменных sess как tf.Session()

Я использую следующий скрипт для загрузки сети и извлечения некоторых специфических слоев:

vgg_path='./'
model_filename = os.path.join(vgg_path, "saved_model.pb")
export_dir = os.path.join(vgg_path, "variables/")

with gfile.FastGFile(model_filename, 'rb') as f:
    data = compat.as_bytes(f.read())
    sm = saved_model_pb2.SavedModel()
    sm.ParseFromString(data)
    image_input, l7, l4, l3 = tf.import_graph_def(sm.meta_graphs[0].graph_def, 
            name='',return_elements=["image_input:0", "layer7_out:0",
            "layer4_out:0", "layer3_out:0"])

tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, image_input)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l7)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l4)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, l3)

saver = tf.train.Saver(tf.global_variables())
print("load data")
saver.restore(sess, export_dir)

Скрипт завершается со следующей ошибкой при инициализации переменной saver:

Ошибка типа: Переменная для сохранения не является Переменной: Тензор ("image_input:0", shape=(?,?,?, 3), dtype=float32)

Как я могу исправить свой сценарий и восстановить предварительно обученную сеть VGG?

1 ответ

Решение

Поскольку у вас есть SavedModel, вы можете использовать tf.saved_model.loader для его загрузки:

with tf.Session() as sess:
    tf.saved_model.loader.load(sess, ["some_tag"], model_dir)
Другие вопросы по тегам