Трансляция операций с индексами в NumPy

Как я могу взять элементы из массива NumPy с учетом нескольких индексных массивов с широковещательной рассылкой? Или: как я могу упростить / векторизовать этот цикл:

elems = np.random.rand(3, 10, 7) # shape N x I x M
ind = np.array([[1, 2], [3, 4], [0, 9]]) # shape N x J
res = np.stack([elems[i, ind[i]] for i in range(len(elems))]) # shape N x J x M

1 ответ

Переведите индекс цикла в arange и используйте braodcasting:

>>> elems = np.arange(2*3*4).reshape(2,3,4)
>>> ind = np.arange(0,8,2).reshape(2, 2) % 3
>>> 
>>> elems
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> elems[np.arange(2)[:, None], ind]
array([[[ 0,  1,  2,  3],
        [ 8,  9, 10, 11]],

       [[16, 17, 18, 19],
        [12, 13, 14, 15]]])
Другие вопросы по тегам