Смущает поведение `tf.cond`

Мне нужен поток условного контроля в моем графике. Если pred является True, граф должен вызвать опцию, которая обновляет переменную и затем возвращает ее, в противном случае она возвращает переменную без изменений. Упрощенная версия:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

Тем не менее, я считаю, что оба pred=True а также pred=False привести к тому же результату y=[2], что означает, что оператор присваивания также вызывается, когда update_x_2 не выбран tf.cond, Как это объяснить? И как решить эту проблему?

2 ответа

Решение

TL;DR: если хочешь tf.cond() чтобы выполнить побочный эффект (например, назначение) в одной из ветвей, вы должны создать опцию, которая выполняет побочный эффект внутри функции, которую вы передаете tf.cond(),

Поведение tf.cond() немного не интуитивно. Поскольку выполнение в графе TensorFlow проходит через граф, все операции, на которые вы ссылаетесь в любой ветви, должны быть выполнены до оценки условия. Это означает, что как истинная, так и ложная ветви получают управляющую зависимость от tf.assign() оп, и так y всегда настроен на 2, даже если pred is False`.

Решение состоит в том, чтобы создать tf.assign() op внутри функции, которая определяет истинную ветвь. Например, вы можете структурировать свой код следующим образом:

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]
pred = tf.constant(False)
x = tf.Variable([1])

def update_x_2():
    assign_x_2 = tf.assign(x, [2])
    with tf.control_dependencies([assign_x_2]):
        return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

Это получит результат [1],

Этот ответ в точности совпадает с ответом выше. Но я хочу поделиться с вами тем, что вы можете включить все функции, которые вы хотели бы использовать, в функцию ветвления. Потому что, учитывая ваш пример кода, тензор x это может быть непосредственно использовано update_x_2 функция.

Другие вопросы по тегам