Как восстановить модель в Tensorflow
Сначала я обучил модель, используя tf.contrib.gan, как показано ниже, и я смог обучить модель.
tf.contrib.gan.gan_train(
train_ops,
hooks=(
[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
tf.train.LoggingTensorHook([status_message], every_n_iter=10)] +
sync_hooks),
logdir=FLAGS.train_log_dir,
master=FLAGS.master,
is_chief=FLAGS.task == 0,
config=conf
)
Тогда я хочу оценить модель. При попытке восстановить контрольные точки следующим образом,
with tf.name_scope('inputs'):
real_images, one_hot_labels, _, num_classes = data_provider.provide_data(
FLAGS.batch_size, FLAGS.dataset_dir)
logits, end_points_des, feature, net = dcgan.discriminator(real_images)
variables_to_restore = slim.get_model_variables()
restorer = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
restorer.restore(sess, ckpt.model_checkpoint_path)
Я получаю это исключение:
2018-04-11 20:05:03.304089: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Discriminator/fully_connected_layer2/weights not found in checkpoint
2018-04-11 20:05:03.304280: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Discriminator/conv0/BatchNorm/Discriminator/conv0/BatchNorm/moving_mean/local_step not found in checkpoint
2018-04-11 20:05:03.304484: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Discriminator/conv0/BatchNorm/beta not found in checkpoint
2018-04-11 20:05:03.305197: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Discriminator/fully_connected_layer2/biases not found in checkpoint
Я использую TF 1.7rc1
1 ответ
На самом деле, возникла проблема в сгенерированном графе. Это шаги, которые я сделал, чтобы решить эту проблему.
Шаг 1: распечатать все переменные в файле контрольных точек, используя следующий код
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(file_name, '')
Шаг 2: Затем я заметил, что каждый ключ состоит из дублирования первой области ("Дискриминатор"), которая была установлена, но когда я пытаюсь загрузить модель, она не состоит из этой части. Таким образом, мне пришлось удалить эту дополнительную часть следующим образом,
def name_in_checkpoint(var):
if "Discriminator/" in var.op.name:
return var.op.name.replace("Discriminator/", "Discriminator/Discriminator/")
logits, end_points_des, feature, net = dcgan.discriminator(real_images)
variables_to_restore = slim.get_model_variables()
variables_to_restore = {name_in_checkpoint(var): var for var in variables_to_restore}
restorer = tf.train.Saver(variables_to_restore)
Шаг 3: Тогда я смог загрузить модель, как показано ниже.
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
restorer.restore(sess, ckpt.model_checkpoint_path)
Надеюсь, что это поможет кому-то, кто может столкнуться с той же проблемой.