Оценить все парные комбинации строк двух тензоров в тензорном потоке
Я пытаюсь определить пользовательский оператор в tenorflow, в котором в какой-то момент мне нужно построить матрицу (z
) который будет содержать суммы всех комбинаций пар строк двух матриц (x
а также y
). В общем, номера строк x
а также y
динамичны.
В NumPy это довольно просто:
import numpy as np
from itertools import product
rows_x = 4
rows_y = 2
dim = 2
x = np.arange(dim*rows_x).reshape(rows_x, dim)
y = np.arange(dim*rows_y).reshape(rows_y, dim)
print('x:\n{},\ny:\n{}\n'.format(x, y))
z = np.zeros((rows_x*rows_y, dim))
print('for loop:')
for i, (x_id, y_id) in enumerate(product(range(rows_x), range(rows_y))):
print('row {}: {} + {}'.format(i, x[x_id, ], y[y_id, ]))
z[i, ] = x[x_id, ] + y[y_id, ]
print('\nz:\n{}'.format(z))
возвращает:
x:
[[0 1]
[2 3]
[4 5]
[6 7]],
y:
[[0 1]
[2 3]]
for loop:
row 0: [0 1] + [0 1]
row 1: [0 1] + [2 3]
row 2: [2 3] + [0 1]
row 3: [2 3] + [2 3]
row 4: [4 5] + [0 1]
row 5: [4 5] + [2 3]
row 6: [6 7] + [0 1]
row 7: [6 7] + [2 3]
z:
[[ 0. 2.]
[ 2. 4.]
[ 2. 4.]
[ 4. 6.]
[ 4. 6.]
[ 6. 8.]
[ 6. 8.]
[ 8. 10.]]
Тем не менее, я понятия не имею, как реализовать что-то подобное в тензорном потоке.
В основном я изучал SO и API tenorflow в надежде найти функцию, которая выдала бы комбинации элементов из двух тензоров, или функцию, которая давала бы перестановки элементов тензора, но безрезультатно.
Любые предложения приветствуются.
2 ответа
Вы можете просто использовать широковещательную способность tenorflow.
import tensorflow as tf
x = tf.constant([[0, 1],[2, 3],[4, 5],[6, 7]], dtype=tf.float32)
y = tf.constant([[0, 1],[2, 3]], dtype=tf.float32)
x_ = tf.expand_dims(x, 0)
y_ = tf.expand_dims(y, 1)
z = tf.reshape(tf.add(x_, y_), [-1, 2])
# or more succinctly
z = tf.reshape(x[None] + y[:, None], [-1, 2])
sess = tf.Session()
sess.run(z)
Опция 1
определяющий z
как переменная и обновляем ее строки:
import tensorflow as tf
from itertools import product
x = tf.constant([[0, 1],[2, 3],[4, 5],[6, 7]],dtype=tf.float32)
y = tf.constant([[0, 1],[2, 3]],dtype=tf.float32)
rows_x,dim=x.get_shape()
rows_y=y.get_shape()[0]
z=tf.Variable(initial_value=tf.zeros([rows_x*rows_y,dim]),dtype=tf.float32)
for i, (x_id, y_id) in enumerate(product(range(rows_x), range(rows_y))):
z=tf.scatter_update(z,i,x[x_id]+y[y_id])
with tf.Session() as sess:
tf.global_variables_initializer().run()
z_val=sess.run(z)
print(z_val)
Это печатает
[[ 0. 2.]
[ 2. 4.]
[ 2. 4.]
[ 4. 6.]
[ 4. 6.]
[ 6. 8.]
[ 6. 8.]
[ 8. 10.]]
Вариант 2
Создание z
понимание списка бросков:
import tensorflow as tf
from itertools import product
x = tf.constant([[0, 1],[2, 3],[4, 5],[6, 7]],dtype=tf.float32)
y = tf.constant([[0, 1],[2, 3]],dtype=tf.float32)
rows_x,dim=x.get_shape().as_list()
rows_y=y.get_shape().as_list()[0]
z=[x[x_id]+y[y_id] for x_id in range(rows_x) for y_id in range(rows_y)]
z=tf.reshape(z,(rows_x*rows_y,dim))
with tf.Session() as sess:
z_val=sess.run(z)
print(z_val)
Сравнение: второе решение примерно в два раза быстрее (измеряется только конструкция z
в обоих решениях). В частности, время: первое решение: 0,211 секунды, второе решение: 0,137 секунды.