np.add.at индексация с массивом
Я работаю над cs231n, и мне трудно понять, как работает эта индексация. При условии
x = [[0,4,1], [3,2,4]]
dW = np.zeros(5,6)
dout = [[[ 1.19034710e-01 -4.65005990e-01 8.93743168e-01 -9.78047129e-01
-8.88672957e-01 -4.66605091e-01]
[ -1.38617461e-03 -2.64569728e-01 -3.83712733e-01 -2.61360826e-01
8.07072009e-01 -5.47607277e-01]
[ -3.97087458e-01 -4.25187949e-02 2.57931759e-01 7.49565950e-01
1.37707667e+00 1.77392240e+00]]
[[ -1.20692745e+00 -8.28111550e-01 6.53041092e-01 -2.31247762e+00
-1.72370321e+00 2.44308033e+00]
[ -1.45191870e+00 -3.49328154e-01 6.15445782e-01 -2.84190582e-01
4.85997687e-02 4.81590106e-01]
[ -1.14828583e+00 -9.69055406e-01 -1.00773809e+00 3.63553835e-01
-1.28078363e+00 -2.54448436e+00]]]
Операция, которую они делают,
np.add.at(dW, x, dout)
х - двумерный массив. Как здесь работает индексация? я прошел сквозь np.ufunc.at
документация, но у них есть простые примеры с массивом 1d и константой:
np.add.at(a, [0, 1, 2, 2], 1)
3 ответа
In [226]: x = [[0,4,1], [3,2,4]]
...: dW = np.zeros((5,6),int)
In [227]: np.add.at(dW,x,1)
In [228]: dW
Out[228]:
array([[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0]])
С этим x
дубликатов нет, поэтому add.at
так же, как использование +=
индексации. Эквивалентно мы можем прочитать измененные значения с:
In [229]: dW[x[0], x[1]]
Out[229]: array([1, 1, 1])
Индексы работают одинаково в любом случае, включая трансляцию:
In [234]: dW[...]=0
In [235]: np.add.at(dW,[[[1],[2]],[2,4,4]],1)
In [236]: dW
Out[236]:
array([[0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 2, 0],
[0, 0, 1, 0, 2, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]])
возможные значения
Значения должны быть broadcastable
относительно индексов:
In [112]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)))
...
In [114]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)).ravel())
...
ValueError: array is not broadcastable to correct shape
In [115]: np.add.at(dW,[[[1],[2]],[2,4,4]],[1,2,3])
In [117]: np.add.at(dW,[[[1],[2]],[2,4,4]],[[1],[2]])
In [118]: dW
Out[118]:
array([[ 0, 0, 0, 0, 0, 0],
[ 0, 0, 3, 0, 9, 0],
[ 0, 0, 4, 0, 11, 0],
[ 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0]])
В этом случае индексы определяют форму (2,3), поэтому (2,3),(3,), (2,1) и скалярные значения работают. (6,) нет.
В этом случае, add.at
отображает массив (2,3) на (2,2) подмассив dW
,
В последнее время мне также трудно понять эту строку кода. Надеюсь, что то, что я получил, поможет тебе, поправь меня, если я ошибаюсь.
Три массива в этой строке кода следующие:
x , whose shape is (N,T)
dW, ---(V,D)
dout ---(N,T,D)
Затем мы приходим к строковому коду, мы хотим выяснить, что происходит
np.add.at(dW, x, dout)
Если вы не хотите знать процедуру мышления. Приведенный выше код эквивалентен:
for row in range(N):
for col in range(T):
dW[ x[row,col] , :] += dout[row,col, :]
Это процедура мышления:
Ссылаясь на этот документ
https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ufunc.at.html
Мы знаем, что х - это индексный массив. Таким образом, ключ должен понимать dW [x]. Это концепция индексации массива (dW) с использованием другого массива (x). Если вы не знакомы с этой концепцией, можете проверить эту ссылку
https://docs.scipy.org/doc/numpy-1.13.0/user/basics.indexing.html
Вообще говоря, то, что возвращается, когда используются индексные массивы, это массив такой же формы, что и индексный массив, но с типом и значениями индексируемого массива.
dW [x] даст нам массив, форма которого (N,T,D), часть (N,T) получена из x, а (D) - из dW (V,D). Обратите внимание, что каждый элемент x находится внутри диапазона [0, v).
Давайте возьмем некоторое число в качестве конкретного примера
x: np.array([[0,0],[0,0]]) ---- (2,2) N=2, T=2
dW: np.array([[0,0],[2,2]]) ---- (2,2) V=2, D=2
dout: np.arange(1,9).reshape(2,2,2) ----(2,2,2) N=2, T=2, D=2
dW[x] should be [ [[0 0] #this comes from the dW's firt row
[0 0]]
[[0 0]
[0 0]] ]
dW [x] add dout означает, что добавить элемент elemnet (здесь этот трюк, позже будет объяснено)
np.add.at(dW, x, dout) gives
[ [16 20]
[ 2 2] ]
Зачем? Процедура такова:
Он добавляет [1,2] к первой строке dW, которая равна [0,0].
Почему первый ряд? Поскольку x[0,0] = 0, что указывает на первую строку dW, dW[0] = dW[0,:] = первая строка.
Затем добавьте [3,4] в первый ряд dW[0,0]. [3,4]= DOUT [0,1],:. [0,0] снова происходит от dW, x [0,1] = 0, все еще первая строка dW [0].
Затем добавьте [5,6] к первой строке dW.
Затем добавьте [7,8] к первому ряду dW.
Таким образом, результат будет [1+3+5+7, 2+4+6+8] = [16,20]. Потому что мы не трогаем второй ряд dW. Второй ряд dW остается неизменным.
Хитрость в том, что мы будем считать только исходную строку один раз, можем подумать, что буфера нет, и каждый шаг воспроизводится в исходном месте.
Рассмотрим пример, основанный на этом назначении от cs231n. Если мы говорим о нескольких направлениях, гораздо проще использовать конкретные настройки.
np.random.seed(1)
N, T, V, D = 2, 3, 7, 6
x = np.random.randint(V, size=(N, T))
dW_man = np.zeros((V, D))
dW_man[x].shape, x.shape
((2, 3, 6), (2, 3))
x
array([[5, 3, 4],
[0, 1, 3]])
dout = np.arange(2*3*6).reshape(dW_man[x].shape)
dout
array([[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]],
[[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]]])
Какими должны быть ряды
dW_man[x]
? Хорошо
[0, 1, ...]
следует добавить в строку 5,
[ 6, 7, ..]
- в ряд 3. А также
[30, 31, ...]
следует добавить в строку 3. Давайте посчитаем это вручную. Дополнительные примеры и объяснения можно найти по этой ссылке на GitHub gist : .
dW_man[5] = dout[0, 0]
dW_man[3] = dout[0, 1]
dW_man[4] = dout[0, 2]
dW_man[0] = dout[1, 0]
dW_man[1] = dout[1, 1]
dW_man[3] = dout[1, 2]
dW_man
array([[18., 19., 20., 21., 22., 23.],
[24., 25., 26., 27., 28., 29.],
[ 0., 0., 0., 0., 0., 0.],
[30., 31., 32., 33., 34., 35.],
[12., 13., 14., 15., 16., 17.],
[ 0., 1., 2., 3., 4., 5.],
[ 0., 0., 0., 0., 0., 0.]])
Теперь воспользуемся
np.add.at
.
np.random.seed(1)
N, T, V, D = 2, 3, 7, 6
x = np.random.randint(V, size=(N, T))
dW = np.zeros((V, D))
dout = np.arange(2*3*6).reshape(dW[x].shape)
np.add.at(dW, x, dout)
dW
array([[18., 19., 20., 21., 22., 23.],
[24., 25., 26., 27., 28., 29.],
[ 0., 0., 0., 0., 0., 0.],
[36., 38., 40., 42., 44., 46.],
[12., 13., 14., 15., 16., 17.],
[ 0., 1., 2., 3., 4., 5.],
[ 0., 0., 0., 0., 0., 0.]])