Как извлечь функции из нескольких контрольных точек в tenorfow-slim?

У меня две группы A и B прошли обучение и получили модельные контрольные точки model_A.ckpt, а также model_B.ckpt соответственно, используя тот же [inception_v1] сеть ( ссылка).

У меня есть два входа, которые приходят из двух групп A и B. Для каждого изображения я хочу извлечь последние функции (global_pool), которые соответствуют его контрольной точке (то есть изображение из группы A будет загружаться из model_A.ckpt, и так далее). Затем global_pool группы A и global_pool группы B объединятся вместе, чтобы создать объединенный global_pool для другого этапа обучения. Как мы могли бы сделать это в Tensorflow с TF-Slim?

Это мое текущее решение, но оно не может работать

import tensorflow as tf
slim = tf.contrib.slim
from PIL import Image
from nets.inception_v1 import *
import numpy as np

checkpoint_file_A = 'model_A.ckpt-40000'
sample_images_A = ['1_A.jpg']

checkpoint_file_B = 'model_B.ckpt-40000'
sample_images_B = ['1_B.jpg']

input_tensor = tf.placeholder(tf.float32, shape=(None,299,299,3), name='input_image')
scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor)
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)    
#Load the model
sess_A = tf.Session()
arg_scope_A = inception_v1_arg_scope()
with slim.arg_scope(arg_scope_A):
  logits, end_points = inception_v1(scaled_input_tensor, num_classes=3, is_training=False)
saver_A = tf.train.Saver()
saver_A.restore(sess_A, checkpoint_file_A)    
sess_B = tf.Session()
arg_scope_B = inception_v1_arg_scope()
with slim.arg_scope(arg_scope_B):
  logits, end_points = inception_v1(scaled_input_tensor, num_classes=3, is_training=False)
saver_B = tf.train.Saver()
saver_B.restore(sess_B, checkpoint_file_B) 

for image in sample_images_A:
  im = Image.open(image).resize((299,299))
  im = np.array(im)
  im = im.reshape(-1,299,299,3)
  global_pool, _, _ = sess_A.run([end_points['global_pool']], feed_dict={input_tensor: im})
  print ( global_pool.shape)

for image in sample_images_B:
  im = Image.open(image).resize((299,299))
  im = np.array(im)
  im = im.reshape(-1,299,299,3)
  global_pool, _, _ = sess_B.run([end_points['global_pool']], feed_dict={input_tensor: im})
  print ( global_pool.shape)

0 ответов

Другие вопросы по тегам