Numpy получить верхние k элементов из последнего измерения ndarray

У меня есть многомерный массив, и мне нужно получить верхние k элементов из каждой строки последнего измерения.

>>> x = np.random.random_integers(0, 100, size=(2,1,1,5))
>>> x
array([[[[99, 39, 10, 18, 68]]],
       [[[22,  3, 13, 56,  2]]]])

Я пытаюсь получить:

array([[[[ 99.,  68.]]],
       [[[ 18.,  99.]]]])

Я могу получить индексы, используя следующие, но я не уверен, как вырезать значения.

>>> k = 2
>>> parts = np.flip(-1 - np.arange(k), 0)
>>> indices = np.flip(
...     np.argpartition(x, parts, axis=-1)[..., -k:],
...     axis=-1)
>>> indices
array([[[[0, 4]]],
       [[[3, 0]]]])

2 ответа

Это может решить вашу проблему.

np.sort(x, axis=len(x.shape)-1)[...,-2:]
np.partition(x, 2)[..., -2:]

возвращает 2 самых больших элемента из каждой строки. Например,

x = np.random.random_integers(0, 100, size=(2,1,1,5))
print(x)
print(np.partition(x, 2)[..., -2:])

печатает что-то вроде

[[[[79 34 90 80 56]]]
 [[[78 11 24 20 42]]]]

[[[[80 90]]]
 [[[78 42]]]]
Другие вопросы по тегам