Сохранение и восстановление функций в TensorFlow
Я работаю над проектом VAE в TensorFlow, где сети кодера / декодера встроены в функции. Идея состоит в том, чтобы сохранить, затем загрузить обученную модель и выполнить выборку, используя функцию энкодера.
После восстановления модели у меня возникают проблемы с запуском функции декодера и возвращением восстановленных обученных переменных с ошибкой "Неинициализированное значение". Я предполагаю, что это потому, что функция либо создает новый, перезаписывает существующий, либо иным образом. Но я не могу понять, как это решить. Вот некоторый код:
class VAE(object):
def __init__(self, restore=True):
self.session = tf.Session()
if restore:
self.restore_model()
self.build_decoder = tf.make_template('decoder', self._build_decoder)
@staticmethod
def _build_decoder(z, output_size=768, hidden_size=200,
hidden_activation=tf.nn.elu, output_activation=tf.nn.sigmoid):
x = tf.layers.dense(z, hidden_size, activation=hidden_activation)
x = tf.layers.dense(x, hidden_size, activation=hidden_activation)
logits = tf.layers.dense(x, output_size, activation=output_activation)
return distributions.Independent(distributions.Bernoulli(logits), 2)
def sample_decoder(self, n_samples):
prior = self.build_prior(self.latent_dim)
samples = self.build_decoder(prior.sample(n_samples), self.input_size).mean()
return self.session.run([samples])
def restore_model(self):
print("Restoring")
self.saver = tf.train.import_meta_graph(os.path.join(self.save_dir, "turbolearn.meta"))
self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir))
self._restored = True
хочу бежать samples = vae.sample_decoder(5)
В своей тренировочной программе я бегу:
if self.checkpoint:
self.saver.save(self.session, os.path.join(self.save_dir, "myvae"), write_meta_graph=True)
ОБНОВИТЬ
Основываясь на предложенном ответе ниже, я изменил метод восстановления
self.saver = tf.train.Saver()
self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))
Но теперь получим ошибку значения при создании объекта Saver():
ValueError: No variables to save
1 ответ
tf.train.import_meta_graph
восстанавливает график, то есть восстанавливает сетевую архитектуру, которая была сохранена в файл. Призыв к tf.train.Saver.restore
с другой стороны, только восстанавливают значения переменных из файла в текущий граф в сеансе (это, естественно, дает сбой, если некоторые значения в файле принадлежат переменным, которых нет в текущем активном графе).
Так что если вы уже строите сетевые уровни в коде, вам не нужно вызывать tf.train.import_meta_graph
, В противном случае это может вызвать у вас проблемы.
Не уверен, как выглядит остальная часть вашего кода, но вот несколько советов. Сначала создайте график, затем создайте сеанс и, наконец, восстановите, если применимо. Тогда ваш инициал может выглядеть так
def __init__(self, restore=True):
self.build_decoder = tf.make_template('decoder', self._build_decoder)
self.session = tf.Session()
if restore:
self.restore_model()
Однако, если вы только восстанавливаете кодер и строите декодер заново, вы можете построить декодер последним. Но затем не забудьте инициализировать его переменные перед использованием.