Значения клипа Tensorflow в коллекции?

Я пытаюсь обрезать все обучающие переменные для моих дискриминаторов в моей сети.

Я получаю переменные для дискриминаторов, как это:

A_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_d_')
B_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_B_')
discriminatorVars = self.A_d_vars + self.B_d_vars 

Теперь, если я попытаюсь сделать этоdiscriminatorVars.assign(tf.clip_by_value(discriminatorVars, 0.01, 0.1)) чтобы обрезать все значения до [0,01, 0,1], это не сработает, поскольку переменные представляют собой списки Python, а не тензоры.

Я также попробовал это, но это не работает:

self.sess.run(tf.map_fn(lambda var: var.assign(tf.clip_by_value(var, 0.01, 0.1)), var_list))

Это говорит о том, что list объект не имеет assign метод.

В настоящее время я перебираю все переменные в списке и вызываю self.sess.run(var.assign(tf.clip_by_value(var, 0.01, 0.1)))
Проблема в том, что это очень медленно.

Как я могу пакетно обновить коллекции, чтобы их значения были обрезаны?

1 ответ

Попробуйте составить список назначенных операций, которые вы хотите сделать, и используйте tf.group ( https://www.tensorflow.org/api_docs/python/tf/group), чтобы сгруппировать их. Пройти tf.group оператор для sess.run,

Session.run() может иметь нетривиальные накладные расходы, поэтому вы хотите сделать все обновления в одном Session.run() вызов.

Другие вопросы по тегам