Как чередовать поезда в тензорном потоке?

Я внедряю альтернативную схему обучения. График содержит две учебные операции. Обучение должно чередоваться между ними.

Это актуально для таких исследований, как это или это

Ниже приведен небольшой пример. Но, похоже, обновлять обе операции на каждом шагу. Как я могу явно переключаться между ними?

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# Import data
mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data', one_hot=True)

# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]), name='weights')
b = tf.Variable(tf.zeros([10]), name='biases')
y = tf.matmul(x, W) + b

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
global_step = tf.Variable(0, trainable=False)

tvars1 = [b]
train_step1 = tf.train.GradientDescentOptimizer(0.5).apply_gradients(zip(tf.gradients(cross_entropy, tvars1), tvars1), global_step)
tvars2 = [W]
train_step2 = tf.train.GradientDescentOptimizer(0.5).apply_gradients(zip(tf.gradients(cross_entropy, tvars2), tvars2), global_step)
train_step = tf.cond(tf.equal(tf.mod(global_step,2), 0), true_fn= lambda:train_step1, false_fn=lambda : train_step2)


sess = tf.InteractiveSession()
tf.global_variables_initializer().run()


# Train
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    if i % 100 == 0:
        print(sess.run([cross_entropy, global_step], feed_dict={x: mnist.test.images,
                                         y_: mnist.test.labels}))

Это приводит к

[2.0890141, 2]
[0.38277805, 202]
[0.33943111, 402]
[0.32314575, 602]
[0.3113254, 802]
[0.3006627, 1002]
[0.2965056, 1202]
[0.29858461, 1402]
[0.29135355, 1602]
[0.29006076, 1802]      

Глобальный шаг повторяется до 1802, поэтому каждый раз выполняются оба поезда. train_step называется. (Это также происходит, когда всегда ложное условие tf.equal(global_step,-1) например.)

Мой вопрос, как чередовать выполнение train_step1 а также train_step2?

1 ответ

Решение

Я думаю, что самый простой способ это просто

for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  if i % 2 == 0:
    sess.run(train_step1, feed_dict={x: batch_xs, y_: batch_ys})
  else:
    sess.run(train_step2, feed_dict={x: batch_xs, y_: batch_ys})

Но если необходимо сделать условное переключение через тензор потока, сделайте это так:

optimizer = tf.train.GradientDescentOptimizer(0.5)
train_step = tf.cond(tf.equal(tf.mod(global_step, 2), 0),
                     true_fn=lambda: optimizer.apply_gradients(zip(tf.gradients(cross_entropy, tvars1), tvars1), global_step),
                     false_fn=lambda: optimizer.apply_gradients(zip(tf.gradients(cross_entropy, tvars2), tvars2), global_step))
Другие вопросы по тегам