Что именно tf.expand_dims делает с вектором и почему результаты могут быть добавлены вместе, даже если формы матрицы различны?

Я добавляю два вектора, которые, как я думал, были "преобразованы" вместе, и в результате получаю 2d матрицу. Я ожидаю некоторый тип ошибки здесь, но не получил это. Я думаю, что понимаю, что происходит, он рассматривал их так, как будто было еще два набора каждого вектора по горизонтали и вертикали, но я не понимаю, почему результаты a и b не отличаются. И если они не предназначены, почему это работает вообще?

import tensorflow as tf
import numpy as np

start_vec = np.array((83,69,45))
a = tf.expand_dims(start_vec, 0)
b = tf.expand_dims(start_vec, 1)
ab_sum = a + b
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    a = sess.run(a)
    b = sess.run(b)
    ab_sum = sess.run(ab_sum)

print(a)
print(b)
print(ab_sum)

=================================================

[[83 69 45]]

[[83]
 [69]
 [45]]

[[166 152 128]
 [152 138 114]
 [128 114  90]]

1 ответ

Решение

На самом деле, в этом вопросе шире используются вещательные характеристики тензорного потока, которые так же, как и numpy ( Broadcasting). Broadcasting избавляется от требования, что форма операции между тензорами должна быть одинаковой. Конечно, он также должен соответствовать определенным условиям.

Общие правила вещания:

При работе с двумя массивами NumPy сравнивает их формы поэлементно. Он начинается с конечных размеров и продвигается вперед. Два измерения совместимы, когда

1. они равны или

2. один из них 1

Простой пример - одномерные тензоры, умноженные на скаляры.

import tensorflow as tf

start_vec = tf.constant((83,69,45))
b = start_vec * 2

with tf.Session() as sess:
    print(sess.run(b))

[166 138  90]

Вернемся к вопросу, функция tf.expand_dims() это вставить размер в форму тензора в указанном axis позиция. Ваша исходная форма данных (3,), Вы получите форму a=tf.expand_dims(start_vec, 0) является (1,3) когда ваш набор axis=0. Вы получите форму b=tf.expand_dims(start_vec, 1) является (3,1) когда ваш набор axis=1,

Сравнивая правила broadcastingВидно, что они удовлетворяют второму условию. Таким образом, их фактическая работа

83,83,83     83,69,45
69,69,69  +  83,69,45
45,45,45     83,69,45
Другие вопросы по тегам