TensorFlow: потери восстанавливаются после восстановления сети RNN
Информация об окружающей среде
- Операционная система: Windows 7 64-разрядная
- Tensorflow установлен из предварительно построенного пункта (без CUDA): 1.0.1
- Python 3.5.2 64-битный
проблема
У меня проблемы с восстановлением сети (модель базового языка RNN). Ниже приведена упрощенная версия с той же проблемой.
Когда я запускаю его в первый раз, я получаю, например, это.
...
step 160: loss = 1.956 (perplexity = 7.069016620211226)
step 180: loss = 1.837 (perplexity = 6.274748642468816)
step 200: loss = 1.825 (perplexity = 6.202084762557817)
Но на втором запуске, после восстановления параметров, я получаю это.
step 220: loss = 2.346 (perplexity = 10.446611983898903)
step 240: loss = 2.346 (perplexity = 10.446709120339545)
...
Кажется, что все переменные tf правильно восстановлены, включая состояние, которое будет передано в RNN. Положение данных также восстанавливается (с "шага").
Я также сделал аналогичную программу для модели распознавания MNIST, и она отлично работает: потери до и после восстановления непрерывны.
Существуют ли другие параметры или состояния, которые должны быть сохранены и восстановлены?
import argparse
import os
import tensorflow as tf
import numpy as np
import math
B = 20 # batch size
H = 200 # size of hidden layer of neurons
T = 25 # number of time steps to unroll the RNN for
data_file = 'ptb.train.txt' # any plain text file will do
checkpoint_dir = "tmp"
#----------------
# prepare data
#----------------
data = open(data_file, 'r').read()
chars = list(set(data))
data_size, vocab_size = len(data), len(chars)
print('data has {0} characters, {1} unique.'.format(data_size, vocab_size))
char_to_ix = { ch:i for i,ch in enumerate(chars) }
ix_to_char = { i:ch for i,ch in enumerate(chars) }
input_index_raw = np.array([char_to_ix[ch] for ch in data])
input_index_raw = input_index_raw[0:len(input_index_raw) // T * T]
input_index_raw_shift = np.append(input_index_raw[1:], input_index_raw[0])
input_all = input_index_raw.reshape([-1, T])
target_all = input_index_raw_shift.reshape([-1, T])
num_packed_data = len(input_all)
#----------------
# build model
#----------------
class Model(object):
def __init__(self):
self.input_ph = tf.placeholder(tf.int32, [None, T], name="input_ph")
self.target_ph = tf.placeholder(tf.int32, [None, T], name="target_ph")
embedding = tf.get_variable("embedding", [vocab_size, H], initializer=tf.random_normal_initializer(), dtype=tf.float32)
# input_ph is B x T.
# input_embedded is B x T x H.
input_embedded = tf.nn.embedding_lookup(embedding, self.input_ph)
cell = tf.contrib.rnn.BasicRNNCell(H)
self.state_ph = tf.placeholder(tf.float32, (None, cell.state_size), name="state_ph")
# Make state variable so that it will be saved by the saver.
self.state = tf.get_variable("state", (B, cell.state_size), initializer=tf.zeros_initializer(), trainable=False, dtype=tf.float32)
# Construct initial_state according to whether restoring or not.
self.isRestore = tf.placeholder(tf.bool, shape=(), name="isRestore")
zero_state = cell.zero_state(B, dtype=tf.float32)
self.initial_state = tf.cond(self.isRestore, lambda: self.state, lambda: zero_state)
# input_embedded : B x T x H
# output: B x T x H
# state : B x cell.state_size
output, state_ = tf.nn.dynamic_rnn(cell, input_embedded, initial_state=self.state_ph)
self.final_state = tf.assign(self.state, state_)
# reshape to (B * T) x H.
output_flat = tf.reshape(output, [-1, H])
# Convert hidden layer's output to vector of logits for each vocabulary.
softmax_w = tf.get_variable("softmax_w", [H, vocab_size], dtype=tf.float32)
softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=tf.float32)
logits = tf.matmul(output_flat, softmax_w) + softmax_b
# cross_entropy is a vector of length B * T
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.reshape(self.target_ph, [-1]), logits=logits)
self.loss = tf.reduce_mean(cross_entropy)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
self.global_step = tf.get_variable("global_step", (), initializer=tf.zeros_initializer(), trainable=False, dtype=tf.int32)
self.training_op = optimizer.minimize(cross_entropy, global_step=self.global_step)
def train_batch(self, sess, input_batch, target_batch, initial_state):
final_state_, _, final_loss = sess.run([self.final_state, self.training_op, self.loss], feed_dict={self.input_ph: input_batch, self.target_ph: target_batch, self.state_ph: initial_state})
return final_state_, final_loss
# main
with tf.Session() as sess:
if not tf.gfile.Exists(checkpoint_dir):
tf.gfile.MakeDirs(checkpoint_dir)
batch_stride = num_packed_data // B
# make model
model = Model()
saver = tf.train.Saver()
# always initialize
init = tf.global_variables_initializer()
init.run()
# restore if necessary
isRestore = False
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
isRestore = True
last_model = ckpt.model_checkpoint_path
print("Loading " + last_model)
saver.restore(sess, last_model)
# set initial step
step = tf.train.global_step(sess, model.global_step) + 1
print("start step = {0}".format(step))
# fetch initial state
state = sess.run(model.initial_state, feed_dict={model.isRestore: isRestore})
print("Initial state: {0}".format(state))
while True:
# prepare batch data
idx = [(step + x * batch_stride) % num_packed_data for x in range(0, B)]
input_batch = input_all[idx]
target_batch = target_all[idx]
state, last_loss = model.train_batch(sess, input_batch, target_batch, state)
if step % 20 == 0:
print('step {0}: loss = {1:.3f} (perplexity = {2})'.format(step, last_loss, math.exp(last_loss)))
if step % 200 == 0:
saved_file = saver.save(sess, os.path.join(checkpoint_dir, "model.ckpt"), global_step=step)
print("Saved to " + saved_file)
print("Last state: {0}".format(model.state.eval()))
break;
step = step + 1
1 ответ
Проблема решена. Это не имело ничего общего ни с RNN, ни с TensorFlow.
Я изменился
chars = list(set(data))
в
chars = sorted(set(data))
и теперь это работает.
Это связано с тем, что python использует случайную хеш-функцию для построения набора, и каждый раз, когда python перезапускается, 'chars' имеет разный порядок.