Почему tf.case строит вызываемую функцию дважды?
Я пытался построить несколько филиалов в сети. Поэтому я использовал tf.case в своем коде. Но я обнаружил, что tf.case всегда создает последнюю вызываемую функцию дважды, что приводит к ошибке переменной: "Переменная XXX уже существует"(я создал переменные slim, область видимости переменной case/If_x не существует, это почему я бы получил ошибку). Вот тестовая программа с выходом.
import numpy as np
import tensorflow as tf
slim = tf.contrib.slim
def fn1(X, Y):
with tf.variable_scope("fn1", reuse=False):
w = tf.Variable(1.0, name="w")
#w = slim.variable(name="w", shape=())
return X*w, Y*w
def fn2(X, Y):
with tf.variable_scope("fn2", reuse=False):
w = tf.Variable(2.0, name="w")
#w = slim.variable(name="w", shape=())
return X*w, Y*w
def fn3(X, Y):
with tf.variable_scope("fn3", reuse=False):
w = tf.Variable(3.0, name="w")
#w = slim.variable(name="w", shape=())
return X*w, Y*w
class Test:
def __init__(self):
self.Z = tf.placeholder(dtype=tf.int32, shape=())
self.X = tf.Variable(1.0, name="X")
self.Y = tf.Variable(2.0, name="Y")
def build(self):
self.result = tf.case(
pred_fn_pairs=[
(tf.equal(self.Z, 10), lambda : fn3(self.X, self.Y)),
(tf.equal(self.Z, 20), lambda : fn2(self.X, self.Y)),
(tf.equal(self.Z, 30), lambda : fn1(self.X, self.Y))],
exclusive=False)
test = Test()
test.build()
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
tvars = tf.trainable_variables()
tvars_vals = sess.run(tvars)
for var, val in zip(tvars, tvars_vals):
print(var.name)
aa = sess.run(test.result, feed_dict={test.Z:20})
print aa
Выход:
X:0
Y:0
case/If_0/fn1/w:0
case/If_0/fn1_1/w:0
case/If_1/fn2/w:0
case/If_2/fn3/w:0
(2.0, 4.0)