Создание сводки для градиентов через облачную TPU host_call_fn()?

Насколько я понимаю, host_call и host_call_fn() передают статистику из TPU в хост. Однако в инструкциях не очень ясно, как генерировать сводку для чего-либо нескалярного.

Например, я попытался изменить официальный файл mnist_tpu.py, чтобы получить сводку для градиентов, полученных во время обучения. Model_fn() - это место, где добавляются изменения:

...
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
if FLAGS.use_tpu:
  optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

grads = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads, global_step)

if not FLAGS.skip_host_call:
    def host_call_fn(gs, loss, lr, grads):
        gs = gs[0]
        with summary.create_file_write(FLAGS.model_dir).as_default():
            summary.scalar('loss', loss[0], step=gs)
            summary.scalar('learning_rate', lr[0], step=gs)

            for index, grad in enumerate(grads):
                summary.histogram('{}-grad'.format(grads[index][1].name),
                        grads[index])

            return summary.all_summary_ops()

    gs_t = tf.reshape(global_step, [1])
    loss_t = tf.reshape(loss, [1])
    lr_t = tf.reshape(learning_rate, [1])
    grads_t = grads
    host_call = (host_call_fn, [gs_t, loss_t, lr_t, grads_t])
return tf.contrib.tpu.TPUEstimatorSpec(
    mode=mode,
    loss=loss,
    train_op=train_op
    )
....

К сожалению, приведенное выше дополнение, похоже, не работает, как при создании гистограммы во время обучения на базе процессора. Любая идея, как правильно генерировать гистограмму на нескалярных тензорах?

1 ответ

Аргументы для host_call_fn должны быть тензорами. Проблема в том, что грады - это пара градиентных тензоров и переменных. Вы должны извлечь имя переменной перед передачей его в host_call_fn и просто передать тензоры градиента. Один из способов заставить это работать - изменить аргументы host_call_fn на **kwargs, где аргументом ключевого слова name является имя переменной, и вместо этого передать словарь в качестве списка тензорных переменных.

Другие вопросы по тегам