Значения клипа Tensorflow в коллекции?
Я пытаюсь обрезать все обучающие переменные для моих дискриминаторов в моей сети.
Я получаю переменные для дискриминаторов, как это:
A_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_d_')
B_d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'A_B_')
discriminatorVars = self.A_d_vars + self.B_d_vars
Теперь, если я попытаюсь сделать этоdiscriminatorVars.assign(tf.clip_by_value(discriminatorVars, 0.01, 0.1))
чтобы обрезать все значения до [0,01, 0,1], это не сработает, поскольку переменные представляют собой списки Python, а не тензоры.
Я также попробовал это, но это не работает:
self.sess.run(tf.map_fn(lambda var: var.assign(tf.clip_by_value(var, 0.01, 0.1)), var_list))
Это говорит о том, что list
объект не имеет assign
метод.
В настоящее время я перебираю все переменные в списке и вызываю self.sess.run(var.assign(tf.clip_by_value(var, 0.01, 0.1)))
Проблема в том, что это очень медленно.
Как я могу пакетно обновить коллекции, чтобы их значения были обрезаны?
1 ответ
Попробуйте составить список назначенных операций, которые вы хотите сделать, и используйте tf.group
( https://www.tensorflow.org/api_docs/python/tf/group), чтобы сгруппировать их. Пройти tf.group
оператор для sess.run
,
Session.run()
может иметь нетривиальные накладные расходы, поэтому вы хотите сделать все обновления в одном Session.run()
вызов.