ValueError в тензорном потоке while_loop инварианты формы
import tensorflow as tf
cluster_size = tf.constant(6) # size of the cluster
m = tf.constant(6) # number of contigs (column size)
n = tf.constant(3) # number of points in a single contigs (column size)
contigs_index = tf.reshape(tf.range(0, m, 1, dtype=tf.int32), [1, -1])
contigs = tf.constant(
[[1.1, 2.2, 3.3], [6.6, 5.5, 4.4], [7.7, 8.8, 9.9], [11.1, 22.2, 33.3],
[66.6, 55.5, 44.4], [77.7, 88.8, 99.9]])
# pad zeo to the right till fixed length
def rpad_with_zero(points):
points = tf.slice(tf.pad(points, tf.reshape(tf.concat(
[tf.zeros([1, 2], tf.int32), tf.add(
tf.zeros([1, 2], tf.int32),
tf.subtract(cluster_size, tf.size(points)))], 0), [2, -1]), "CONSTANT"),
(0, tf.subtract(cluster_size, tf.size(points))),
(1, cluster_size))
return points
#calculate pearson correlation coefficient r value
def calculate_pcc(row, contigs):
r = tf.divide(tf.subtract(
tf.multiply(tf.to_float(n), tf.reduce_sum(tf.multiply(row, contigs), 1)),
tf.multiply(tf.reduce_sum(row, 1), tf.reduce_sum(contigs, 1))),
tf.multiply(
tf.sqrt(tf.subtract(
tf.multiply(tf.to_float(n), tf.reduce_sum(tf.square(row), 1)),
tf.square(tf.reduce_sum(row, 1)))),
tf.sqrt(tf.subtract(tf.multiply(
tf.to_float(n), tf.reduce_sum(tf.square(contigs), 1)),
tf.square(tf.reduce_sum(contigs, 1)))
)))
return r
#slice first row from contigs
row = tf.slice(contigs, (0, 0), (1, 3))
#calculate pcc
r = calculate_pcc(row, contigs)
#cluster member index whose r value is greater than 0.90, then casting to
# int32,
members0_index = tf.cast(tf.reshape(tf.where(tf.greater(r, 0.90)), [1, -1]),
tf.int32)
#members = index <intersection> members, padding the members index with
# zeros at right, to keep the fixed cluster length
members0_index = rpad_with_zero(
tf.reshape(tf.sets.set_intersection(contigs_index, members0_index).values,
[1, -1]))
#update index with the rest element index from contigs, and padding
contigs_index = rpad_with_zero(
tf.reshape(tf.sets.set_difference(contigs_index, members0_index).values,
[1, -1]))
#def condition(contigs, contigs_index, members0_index):
def condition(contigs_index, members0_index):
return tf.greater(tf.count_nonzero(contigs_index),
0) # iterate until there is a contig
#def body(contigs, contigs_index, members0_index):
def body(contigs_index, members0_index):
i = tf.reshape(tf.slice(contigs_index, [0, 0], [1, 1]),
[]) #the first element in the contigs_index
row = tf.slice(contigs, (i, 0),
(1, 3)) #slice the ith contig from contigs
r = calculate_pcc(row, contigs)
members_index = tf.cast(tf.reshape(tf.where(tf.greater(r, 0.90)), [1, -1]),
tf.int32)
members_index = rpad_with_zero(rpad_with_zero(
tf.reshape(tf.sets.set_intersection(contigs_index, members_index).values,
[1, -1])))
members0_index = tf.concat([members0_index, members_index], 0)
contigs_index = rpad_with_zero(
tf.reshape(tf.sets.set_difference(contigs_index, members_index).values,
[1, -1]))
#return [contigs, contigs_index, members0_index]
return [contigs_index, members0_index]
sess = tf.Session()
sess.run(tf.while_loop(condition, body,
#loop_vars=[contigs, contigs_index, members0_index],
loop_vars=[contigs_index, members0_index],
#shape_invariants=[contigs.get_shape(), contigs_index.get_shape(),
# tf.TensorShape([None, 6])]))
shape_invariants=[contigs_index.get_shape(), tf.TensorShape([None, 6])]))
Ошибка:
ValueError: Форма для while_12/Merge:0 не является инвариантом для цикла. Он входит в цикл с формой (1, 6), но имеет форму (?,?) После одной итерации. Предоставьте инварианты формы, используя либо
shape_invariants
аргумент tf.while_loop или set_shape() для переменных цикла.
Похоже переменная
contigs_index
отвечает, но я действительно не знаю почему! Я раскрываю цикл выполнения каждого оператора, но не могу найти никакого несоответствия формы!
2 ответа
shape_invariants=[contigs_index.get_shape(), tf.TensorShape([None, 6])]))
должен стать shape_invariants=[tf.TensorShape([None, None]), tf.TensorShape([None, 6])]))
, чтобы учесть изменения формы contigs_index
переменная (в rpad_with_zero
вызов).
Вы должны использовать tf.while_loop
внутри графика, а не сессии. То есть, см. Раздел "Поток управления" в учебном пособии https://sebastianraschka.com/pdf/books/dlb/appendix_g_tensorflow.pdf или в официальной документации по тензорному потоку по адресу https://www.tensorflow.org/api_guides/python/control_flow_ops для больше информации о том, как использовать условные операторы, tf.while_loop
с и т. д.