Группа argmax/argmin над индексами разбиения в numpy
Numpy-х ufunc
с reduceat
метод, который запускает их по смежным разделам в массиве. Поэтому вместо того, чтобы писать:
import numpy as np
a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9])
split_at = [4, 5]
maxima = [max(subarray for subarray in np.split(a, split_at)]
Я могу написать:
maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
Оба вернут максимальные значения в срезах a[0:4]
, a[4:5]
, a[5:10]
, будучи [8, 0, 9]
,
Я хотел бы, чтобы аналогичная функция выполняла argmax
отмечая, что мне нужен только один максимальный индекс в каждом разделе: [3, 4, 5]
с вышеупомянутым a
а также split_at
(несмотря на то, что индексы 5 и 9 оба получают максимальное значение в последней группе), что будет возвращено
np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)]
Ниже я опубликую возможное решение, но хотелось бы видеть такое, которое будет векторизовано без создания индекса по группам.
2 ответа
Это решение включает создание индекса по группам ([0, 0, 0, 0, 1, 2, 2, 2, 2, 2]
в приведенном выше примере).
group_lengths = np.diff(np.hstack([0, split_at, len(a)]))
n_groups = len(group_lengths)
index = np.repeat(np.arange(n_groups), group_lengths)
Тогда мы можем использовать:
maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]
Получить [3, 4, 5]
в result
, [::-1]
s убедитесь, что мы получаем первое, а не последнее значение argmax в каждой группе.
Это основано на том факте, что последний индекс в причудливом назначении определяет назначенное значение, на которое @seberg говорит, что на него не следует полагаться (и более безопасная альтернатива может быть достигнута с помощью result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]]
, который включает в себя сортировку по len(maxima) ~ n_groups
элементы).
Вдохновленный этим вопросом, я добавил функциональность argmin/max в пакет numpy_indexed. Вот как выглядит соответствующий тест. Обратите внимание, что ключи могут быть в любом порядке (и любого типа, поддерживаемого npi):
def test_argmin():
keys = [2, 0, 0, 1, 1, 2, 2, 2, 2, 2]
values = [4, 5, 6, 8, 0, 9, 8, 5, 4, 9]
unique, amin = group_by(keys).argmin(values)
npt.assert_equal(unique, [0, 1, 2])
npt.assert_equal(amin, [1, 4, 0])