Почему tf.case строит вызываемую функцию дважды?

Я пытался построить несколько филиалов в сети. Поэтому я использовал tf.case в своем коде. Но я обнаружил, что tf.case всегда создает последнюю вызываемую функцию дважды, что приводит к ошибке переменной: "Переменная XXX уже существует"(я создал переменные slim, область видимости переменной case/If_x не существует, это почему я бы получил ошибку). Вот тестовая программа с выходом.

import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

def fn1(X, Y):
    with tf.variable_scope("fn1", reuse=False):
       w = tf.Variable(1.0, name="w")
       #w = slim.variable(name="w", shape=())
    return X*w, Y*w

def fn2(X, Y):
    with tf.variable_scope("fn2", reuse=False):
       w = tf.Variable(2.0, name="w")
       #w = slim.variable(name="w", shape=())
    return X*w, Y*w

def fn3(X, Y):
    with tf.variable_scope("fn3", reuse=False):
       w = tf.Variable(3.0, name="w")
       #w = slim.variable(name="w", shape=())
    return X*w, Y*w

class Test:

    def __init__(self):
        self.Z = tf.placeholder(dtype=tf.int32, shape=())
        self.X = tf.Variable(1.0, name="X")
        self.Y = tf.Variable(2.0, name="Y")

    def build(self):
        self.result = tf.case(
            pred_fn_pairs=[
                    (tf.equal(self.Z, 10), lambda : fn3(self.X, self.Y)),
                    (tf.equal(self.Z, 20), lambda : fn2(self.X, self.Y)),
                    (tf.equal(self.Z, 30), lambda : fn1(self.X, self.Y))],
                    exclusive=False)

test = Test()
test.build()

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    tvars = tf.trainable_variables()
    tvars_vals = sess.run(tvars)

    for var, val in zip(tvars, tvars_vals):
        print(var.name) 

    aa = sess.run(test.result, feed_dict={test.Z:20})
    print aa

Выход:

X:0
Y:0
case/If_0/fn1/w:0
case/If_0/fn1_1/w:0
case/If_1/fn2/w:0
case/If_2/fn3/w:0
(2.0, 4.0)

0 ответов

Другие вопросы по тегам