Группа argmax/argmin над индексами разбиения в numpy

Numpy-х ufunc с reduceat метод, который запускает их по смежным разделам в массиве. Поэтому вместо того, чтобы писать:

import numpy as np
a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9])
split_at = [4, 5]
maxima = [max(subarray for subarray in np.split(a, split_at)]

Я могу написать:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))

Оба вернут максимальные значения в срезах a[0:4], a[4:5], a[5:10], будучи [8, 0, 9],

Я хотел бы, чтобы аналогичная функция выполняла argmax отмечая, что мне нужен только один максимальный индекс в каждом разделе: [3, 4, 5] с вышеупомянутым a а также split_at (несмотря на то, что индексы 5 и 9 оба получают максимальное значение в последней группе), что будет возвращено

np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)]

Ниже я опубликую возможное решение, но хотелось бы видеть такое, которое будет векторизовано без создания индекса по группам.

2 ответа

Это решение включает создание индекса по группам ([0, 0, 0, 0, 1, 2, 2, 2, 2, 2] в приведенном выше примере).

group_lengths = np.diff(np.hstack([0, split_at, len(a)]))
n_groups = len(group_lengths)
index = np.repeat(np.arange(n_groups), group_lengths)

Тогда мы можем использовать:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]

Получить [3, 4, 5] в result, [::-1] s убедитесь, что мы получаем первое, а не последнее значение argmax в каждой группе.

Это основано на том факте, что последний индекс в причудливом назначении определяет назначенное значение, на которое @seberg говорит, что на него не следует полагаться (и более безопасная альтернатива может быть достигнута с помощью result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]], который включает в себя сортировку по len(maxima) ~ n_groups элементы).

Вдохновленный этим вопросом, я добавил функциональность argmin/max в пакет numpy_indexed. Вот как выглядит соответствующий тест. Обратите внимание, что ключи могут быть в любом порядке (и любого типа, поддерживаемого npi):

def test_argmin():
    keys   = [2, 0, 0, 1, 1, 2, 2, 2, 2, 2]
    values = [4, 5, 6, 8, 0, 9, 8, 5, 4, 9]
    unique, amin = group_by(keys).argmin(values)
    npt.assert_equal(unique, [0, 1, 2])
    npt.assert_equal(amin,   [1, 4, 0])
Другие вопросы по тегам