Как распечатать градиенты во время тренировок в Tensorflow?

Для отладки модели Tensorflow мне нужно посмотреть, изменились ли градиенты или есть ли в них nans. Простая печать переменной в Tensorflow не работает, потому что все, что вы видите, это:

 <tf.Variable 'Model/embedding:0' shape=(8182, 100) dtype=float32_ref>

Я пытался использовать tf.Print класс, но не могу заставить его работать, и мне интересно, можно ли его так использовать. В моей модели у меня есть тренировочный цикл, который печатает значения потерь каждой эпохи:

def run_epoch(session, model, eval_op=None, verbose=False):
    costs = 0.0
    iters = 0
    state = session.run(model.initial_state)
    fetches = {
            "cost": model.cost,
            "final_state": model.final_state,
    }
    if eval_op is not None:
        fetches["eval_op"] = eval_op

    for step in range(model.input.epoch_size):
        feed_dict = {}
        for i, (c, h) in enumerate(model.initial_state):
            feed_dict[c] = state[i].c
            feed_dict[h] = state[i].h

        vals = session.run(fetches, feed_dict)
        cost = vals["cost"]
        state = vals["final_state"]

        costs += cost
        iters += model.input.num_steps

    print("Loss:", costs)

    return costs

Вставка print(model.gradients[0][1]) в эту функцию не будет работать, поэтому я попытался использовать следующий код сразу после потери печати:

grads = model.gradients[0][1]
x = tf.Print(grads, [grads])
session.run(x)

Но я получил следующее сообщение об ошибке:

ValueError: Fetch argument <tf.Tensor 'mul:0' shape=(8182, 100) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("mul:0", shape=(8182, 100), dtype=float32) is not an element of this graph.)

Что имеет смысл, потому что tf.Print действительно не является частью графика. Итак, я попытался с помощью tf.Print после расчета потерь в реальном графике, но это не сработало, и я все еще получил Tensor("Train/Model/mul:0", shape=(8182, 100), dtype=float32),

Как я могу напечатать переменную градиентов внутри цикла обучения в Tensorflow?

1 ответ

Решение

По моему опыту, лучший способ увидеть поток градиента в тензорном потоке не с tf.Print, но с тензорной доской. Вот пример кода, который я использовал в другой проблеме, где градиенты были ключевой проблемой в обучении:

for g, v in grads_and_vars:
  tf.summary.histogram(v.name, v)
  tf.summary.histogram(v.name + '_grad', g)

merged = tf.summary.merge_all()
writer = tf.summary.FileWriter('train_log_layer', tf.get_default_graph())

...

_, summary = sess.run([train_op, merged], feed_dict={I: 2*np.random.rand(1, 1)-1})
if i % 10 == 0:
  writer.add_summary(summary, global_step=i)

Это представит вам распределение градиентов во времени. Кстати, для проверки на NaN есть специальная функция в tenorflow: tf.is_nan, Обычно вам не нужно проверять, равен ли градиент NaN: когда это происходит, переменная также взрывается, и это будет ясно видно на тензорной доске.

Другие вопросы по тегам