Исходный код продукта Kronecker на TensorLy
Я пытаюсь понять код продукта Кронекера для тензоров, реализованных в TensorLy. Ниже приведен код:
def kron(self, a, b):
"""Kronecker product of two tensors.
Parameters
----------
a, b : tensor
The tensors to compute the kronecker product of.
Returns
-------
tensor
"""
s1, s2 = self.shape(a)
s3, s4 = self.shape(b)
a = self.reshape(a, (s1, 1, s2, 1))
b = self.reshape(b, (1, s3, 1, s4))
return self.reshape(a * b, (s1 * s3, s2 * s4))
Я это понимаю self.shape(a)
придаст форму тензора a
(строки, столбцы, срезы). Итак, мы принимаем формуa
в s1
а также s2
, и форма b
в s3
а также s4
.
a = self.reshape(a, (s1, 1, s2, 1))
изменяет тензор 'a', но мне трудно понять, что (s1, 1, s2, 1)
и почему мы это делаем? то же самое с(1, s3, 1, s4)
. Кроме того, почему мы это делаемself.reshape(a * b, (s1 * s3, s2 * s4))
?
Это может показаться открытым вопросом, но я только начал и хотел бы получить помощь!
1 ответ
Это довольно распространенный трюк с использованием широковещательной передачи. Вставка размеров агрегата вa
а также b
в этом выравнивании происходит следующее:
- По первой оси
b
тиражируетсяs1
раз, чтобы соответствовать каждой строкеa
. - На второй оси
a
тиражируетсяs3
раз, чтобы соответствовать каждой строкеb
. - На третьей оси
b
тиражируетсяs2
раз, чтобы соответствовать каждому столбцуa
. - На четвертой оси
a
тиражируетсяs4
раз, чтобы соответствовать каждому столбцуb
.
Когда вы выполняете умножение, вы получаете четырехмерное произведение каждой комбинации элементов. Элементresult[i, j, m, n]
происходит от a[i, m] * b[j, n]
При окончательном изменении формы одни и те же данные сохраняются в памяти и объединяются первые две и последние две оси без переупорядочивания данных.
Давайте рассмотрим простой пример:
a = [[1, 2, 3],
[2, 3, 4],
[3, 4, 5]]
b = [[6, 7]]
Формы изменены с (3, 3)
а также (1, 2)
к (3, 1, 3, 1)
а также (1, 1, 1, 2)
. Это не меняет макет в памяти, поэтомуa
становится
[[[[1], [2], [3]]],
[[[2], [3], [4]]],
[[[3], [4], [5]]]]
b
становится
[[[[6, 7]]]]
Результат будет сформирован (3, 1, 3, 2)
, и будет выглядеть так:
[[[[1*6, 1*7], [2*6, 2*7], [3*6, 3*7]]],
[[[2*6, 2*7], [3*6, 3*7], [4*6, 4*7]]],
[[[3*6, 3*7], [4*6, 4*7], [5*6, 5*7]]]]
Когда вы измените это в окончательный результат, макет памяти останется прежним, но форма изменится на (3*1, 3*2)
:
[[1*6, 1*7, 2*6, 2*7, 3*6, 3*7],
[2*6, 2*7, 3*6, 3*7, 4*6, 4*7],
[3*6, 3*7, 4*6, 4*7, 5*6, 5*7]]
И вуаля, вот продукт Кронекера a
а также b
.