Как я могу использовать while_loop и tf.layers.batch_normalization для обучения?
Мне нужно добавить слой batch_normalization в теле цикла while, но он ломается, когда я тренируюсь в сети. Все нормально если уберу x = tf.layers.batch_normalization(x, training=flag)
, Могу ли я использовать высокий API в теле цикла? Я не хочу использовать tf.nn.tf.nn.batch_normalization
потому что это простой пример, а моя сеть гораздо сложнее.
import tensorflow as tf
from data_pre import get_data
data, labels = get_data(
['../UCR_TS_Archive_2015/ItalyPowerDemand/ItalyPowerDemand_TRAIN'], 24, 2,True, 0, 2) #pylint: disable=line-too-long
flag = True
def cond(i, x):
return i < 1
def body(i, x):
x = tf.layers.conv1d(x, 1, 7, padding='same')
x = tf.layers.batch_normalization(x, training=flag)
x = tf.nn.relu(x)
return i + 1, x
_, y = tf.while_loop(cond, body, [0, data], back_prop=False)
y = tf.layers.flatten(y)
logits = tf.layers.dense(y, 2)
loss = tf.losses.mean_squared_error(labels, logits)
optimizer = tf.train.AdamOptimizer()
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss, tf.train.get_global_step())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for _ in range(10):
sess.run(train_op)
coord.request_stop()
coord.join(threads)
Вот информация об ошибке:
Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1327, in _do_call
return fn(*args)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1312, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1420, in _call_tf_sessionrun
status, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 516, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'gradients/mean_squared_error/div_grad/Neg' has inputs from different frames. The input 'while/batch_normalization/AssignMovingAvg_1' is in frame 'while/while_context'. The input 'one_hot' is in frame ''.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "./test.py", line 40, in <module>
sess.run(train_op)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 905, in run
run_metadata_ptr)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1140, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1321, in _do_run
run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1340, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'gradients/mean_squared_error/div_grad/Neg' has inputs from different frames. The input 'while/batch_normalization/AssignMovingAvg_1' is in frame 'while/while_context'. The input 'one_hot' is in frame ''.
1 ответ
Решение
Я получил помощь от GitHub. Если у вас есть подобная проблема, вы можете получить помощь из сети, используя while_loop с batch_normalization не может тренироваться