Фильтровать строки массива NumPy?
Я ищу, чтобы применить функцию к каждой строке массива. Если эта функция оценивается как true, я сохраню строку, в противном случае я откажусь от нее. Например, моя функция может быть:
def f(row):
if sum(row)>10: return True
else: return False
Мне было интересно, было ли что-то похожее на:
np.apply_over_axes()
который применяет функцию к каждой строке массива numpy и возвращает результат. Я надеялся на что-то вроде:
np.filter_over_axes()
который применял бы функцию к каждой строке массива numpy и возвращал только те строки, для которых функция вернула true. Есть что-нибудь подобное? Или мне просто использовать цикл for?
2 ответа
В идеале вы могли бы реализовать векторизованную версию своей функции и использовать ее для логического индексирования. Для подавляющего большинства проблем это правильное решение. Numpy предоставляет довольно много функций, которые могут действовать на различные оси, а также на все основные операции и сравнения, поэтому наиболее полезные условия должны быть векторизованными.
import numpy as np
x = np.random.randn(20, 3)
x_new = x[np.sum(x, axis=1) > .5]
Если вы абсолютно уверены, что не можете сделать вышеизложенное, я бы предложил использовать понимание списка (или np.apply_along_axis
) создать массив bools для индексации.
def myfunc(row):
return sum(row) > .5
bool_arr = np.array([myfunc(row) for row in x])
x_new = x[bool_arr]
Это выполнит работу относительно чистым способом, но будет значительно медленнее, чем векторизованная версия. Пример:
x = np.random.randn(5000, 200)
%timeit x[np.sum(x, axis=1) > .5]
# 100 loops, best of 3: 5.71 ms per loop
%timeit x[np.array([myfunc(row) for row in x])]
# 1 loops, best of 3: 217 ms per loop
Как упомянул @Roger Fan, применение функции по строкам действительно должно выполняться векторизованным образом для всего массива. Канонический способ фильтрации — создать логическую маску и применить ее к массиву. Тем не менее, если окажется, что функция настолько сложна, что векторизация невозможна, лучше/быстрее преобразовать массив в список Python (особенно если он использует такие функции Python, как ) и применить к нему функцию.
msk = arr.sum(axis=1)>10 # best way to create a boolean mask
msk = [f(row) for row in arr.tolist()] # second best way
# ^^^^^^^^ <---- convert to list
filtered_arr = arr[msk] # filtered via boolean indexing
Рабочий пример и тест производительности
Как вы можете видеть из приведенного ниже теста timeit, цикл по списку (arr.tolist()
) намного быстрее, чем цикл по массиву numpy (arr
), отчасти потому, что Pythonsum()
и неnp.sum()
вызывается в функцииf()
. Тем не менее, векторизованный метод намного быстрее, чем оба.
def f(row):
if sum(row)>10: return True
else: return False
arr = np.random.rand(10000, 200)
%timeit arr[[f(row) for row in arr]]
# 260 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit arr[[f(row) for row in arr.tolist()]]
# 114 ms ± 4.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit arr[arr.sum(axis=1)>10]
# 10.8 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)