ValueError в тензорном потоке while_loop инварианты формы

import tensorflow as tf

cluster_size = tf.constant(6) # size of the cluster
m = tf.constant(6) # number of contigs (column size)
n = tf.constant(3) # number of points in a single contigs (column size)
contigs_index = tf.reshape(tf.range(0, m, 1, dtype=tf.int32), [1, -1])
contigs = tf.constant(
  [[1.1, 2.2, 3.3], [6.6, 5.5, 4.4], [7.7, 8.8, 9.9], [11.1, 22.2, 33.3],
    [66.6, 55.5, 44.4], [77.7, 88.8, 99.9]])

# pad zeo to the right till fixed length
def rpad_with_zero(points):
  points = tf.slice(tf.pad(points, tf.reshape(tf.concat(
    [tf.zeros([1, 2], tf.int32), tf.add(
      tf.zeros([1, 2], tf.int32),
      tf.subtract(cluster_size, tf.size(points)))], 0), [2, -1]), "CONSTANT"),
                    (0, tf.subtract(cluster_size, tf.size(points))),
                    (1, cluster_size))
  return points

#calculate pearson correlation coefficient r value
def calculate_pcc(row, contigs):
  r = tf.divide(tf.subtract(
      tf.multiply(tf.to_float(n), tf.reduce_sum(tf.multiply(row, contigs), 1)),
      tf.multiply(tf.reduce_sum(row, 1), tf.reduce_sum(contigs, 1))),
                tf.multiply(
      tf.sqrt(tf.subtract(
        tf.multiply(tf.to_float(n), tf.reduce_sum(tf.square(row), 1)), 
        tf.square(tf.reduce_sum(row, 1)))), 
      tf.sqrt(tf.subtract(tf.multiply(
        tf.to_float(n), tf.reduce_sum(tf.square(contigs), 1)),
        tf.square(tf.reduce_sum(contigs, 1)))
      )))
  return r

#slice first row from contigs
row = tf.slice(contigs, (0, 0), (1, 3))
#calculate pcc
r = calculate_pcc(row, contigs)
#cluster member index whose r value is greater than 0.90, then casting to
# int32,
members0_index = tf.cast(tf.reshape(tf.where(tf.greater(r, 0.90)), [1, -1]),
                         tf.int32)
#members = index <intersection> members, padding the members index with
# zeros at right, to keep the fixed cluster length
members0_index = rpad_with_zero(
  tf.reshape(tf.sets.set_intersection(contigs_index, members0_index).values,
             [1, -1]))
#update index with the rest element index from contigs, and padding
contigs_index = rpad_with_zero(
  tf.reshape(tf.sets.set_difference(contigs_index, members0_index).values,
             [1, -1]))

#def condition(contigs, contigs_index, members0_index):
def condition(contigs_index, members0_index):
  return tf.greater(tf.count_nonzero(contigs_index),
                    0) # iterate until there is a contig

#def body(contigs, contigs_index, members0_index):
def body(contigs_index, members0_index):
  i = tf.reshape(tf.slice(contigs_index, [0, 0], [1, 1]),
                 []) #the first element in the contigs_index
  row = tf.slice(contigs, (i, 0),
                 (1, 3)) #slice the ith contig from contigs
  r = calculate_pcc(row, contigs)
  members_index = tf.cast(tf.reshape(tf.where(tf.greater(r, 0.90)), [1, -1]),
                          tf.int32)
  members_index = rpad_with_zero(rpad_with_zero(
    tf.reshape(tf.sets.set_intersection(contigs_index, members_index).values,
               [1, -1])))
  members0_index = tf.concat([members0_index, members_index], 0)
  contigs_index = rpad_with_zero(
    tf.reshape(tf.sets.set_difference(contigs_index, members_index).values,
               [1, -1]))
  #return [contigs, contigs_index, members0_index]
  return [contigs_index, members0_index]

sess = tf.Session()
sess.run(tf.while_loop(condition, body,
   #loop_vars=[contigs, contigs_index, members0_index],
   loop_vars=[contigs_index, members0_index],
   #shape_invariants=[contigs.get_shape(), contigs_index.get_shape(), 
   # tf.TensorShape([None, 6])]))
   shape_invariants=[contigs_index.get_shape(), tf.TensorShape([None, 6])]))

Ошибка:

ValueError: Форма для while_12/Merge:0 не является инвариантом для цикла. Он входит в цикл с формой (1, 6), но имеет форму (?,?) После одной итерации. Предоставьте инварианты формы, используя либо shape_invariants аргумент tf.while_loop или set_shape() для переменных цикла.

Похоже переменная

contigs_index

отвечает, но я действительно не знаю почему! Я раскрываю цикл выполнения каждого оператора, но не могу найти никакого несоответствия формы!

2 ответа

shape_invariants=[contigs_index.get_shape(), tf.TensorShape([None, 6])])) должен стать shape_invariants=[tf.TensorShape([None, None]), tf.TensorShape([None, 6])])), чтобы учесть изменения формы contigs_index переменная (в rpad_with_zero вызов).

Вы должны использовать tf.while_loop внутри графика, а не сессии. То есть, см. Раздел "Поток управления" в учебном пособии https://sebastianraschka.com/pdf/books/dlb/appendix_g_tensorflow.pdf или в официальной документации по тензорному потоку по адресу https://www.tensorflow.org/api_guides/python/control_flow_ops для больше информации о том, как использовать условные операторы, tf.while_loopс и т. д.

Другие вопросы по тегам