Эффективный способ проверки массивов высокой размерности перекрываются в двух ndarray в Python

Например, у меня есть два ndarrays, форма train_dataset является (10000, 28, 28) и форма val_dateset является (2000, 28, 28),

Кроме использования итераций, есть ли эффективный способ использовать функции массива numpy, чтобы найти перекрытие между двумя ndarrays?

5 ответов

Один трюк, который я узнал из превосходного ответа Хайме здесь, заключается в использовании np.void dtype для просмотра каждой строки во входных массивах как одного элемента. Это позволяет вам рассматривать их как одномерные массивы, которые затем могут быть переданы np.in1d или одна из других установленных подпрограмм.

import numpy as np

def find_overlap(A, B):

    if not A.dtype == B.dtype:
        raise TypeError("A and B must have the same dtype")
    if not A.shape[1:] == B.shape[1:]:
        raise ValueError("the shapes of A and B must be identical apart from "
                         "the row dimension")

    # reshape A and B to 2D arrays. force a copy if neccessary in order to
    # ensure that they are C-contiguous.
    A = np.ascontiguousarray(A.reshape(A.shape[0], -1))
    B = np.ascontiguousarray(B.reshape(B.shape[0], -1))

    # void type that views each row in A and B as a single item
    t = np.dtype((np.void, A.dtype.itemsize * A.shape[1]))

    # use in1d to find rows in A that are also in B
    return np.in1d(A.view(t), B.view(t))

Например:

gen = np.random.RandomState(0)

A = gen.randn(1000, 28, 28)
dupe_idx = gen.choice(A.shape[0], size=200, replace=False)
B = A[dupe_idx]

A_in_B = find_overlap(A, B)

print(np.all(np.where(A_in_B)[0] == np.sort(dupe_idx)))
# True

Этот метод намного эффективнее памяти, чем Divakar, так как он не требует трансляции на (m, n, ...) логический массив. На самом деле, если A а также B являются основными строками, тогда копирование не требуется вообще.


Для сравнения я немного адаптировал решения Divakar и BM.

def divakar(A, B):
    A.shape = A.shape[0], -1
    B.shape = B.shape[0], -1
    return (B[:,None] == A).all(axis=(2)).any(0)

def bm(A, B):
    t = 'S' + str(A.size // A.shape[0] * A.dtype.itemsize)
    ma = np.frombuffer(np.ascontiguousarray(A), t)
    mb = np.frombuffer(np.ascontiguousarray(B), t)
    return (mb[:, None] == ma).any(0)

тесты:

In [1]: na = 1000; nb = 200; rowshape = 28, 28

In [2]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
divakar(A, B)
   ....: 
1 loops, best of 3: 244 ms per loop

In [3]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
   ....: 
100 loops, best of 3: 2.81 ms per loop

In [4]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
   ....: 
100 loops, best of 3: 15 ms per loop

Как видите, решение BM немного быстрее моего при малых n, но np.in1d масштабируется лучше, чем проверка равенства для всех элементов (O (n log n), а не O(n²) сложность).

In [5]: na = 10000; nb = 2000; rowshape = 28, 28

In [6]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
   ....: 
1 loops, best of 3: 271 ms per loop

In [7]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
   ....: 
10 loops, best of 3: 123 ms per loop

Решение Divakar трудно использовать на моем ноутбуке для массивов такого размера, поскольку для этого требуется создание промежуточного массива 15 ГБ, тогда как у меня только 8 ГБ ОЗУ.

Полное вещание генерирует здесь логический массив 10000*2000*28*28 =150 Мо.

Для эффективности вы можете:

  • Пакет данных, для массива 200 КО:

    from pylab import *
    N=10000
    a=rand(N,28,28)
    b=a[[randint(0,N,N//5)]]
    
    packedtype='S'+ str(a.size//a.shape[0]*a.dtype.itemsize) # 'S6272' 
    ma=frombuffer(a,packedtype)  # ma.shape=10000
    mb=frombuffer(b,packedtype)  # mb.shape=2000
    
    %timeit a[:,None]==b   : 102 s
    %timeit ma[:,None]==mb   : 800 ms
    allclose((a[:,None]==b).all((2,3)),(ma[:,None]==mb)) : True
    

    Здесь меньше памяти помогает ленивое сравнение строк, ломающееся при первой разнице:

    In [31]: %timeit a[:100]==b[:100]
    10000 loops, best of 3: 175 µs per loop
    
    In [32]: %timeit a[:100]==a[:100]
    10000 loops, best of 3: 133 µs per loop
    
    In [34]: %timeit ma[:100]==mb[:100]
    100000 loops, best of 3: 7.55 µs per loop
    
    In [35]: %timeit ma[:100]==ma[:100]
    10000 loops, best of 3: 156 µs per loop
    

Решения приведены здесь с (ma[:,None]==mb).nonzero().

  • использование in1d, для (Na+Nb) ln(Na+Nb) сложность, противNa*Nb на полное сравнение:

    %timeit in1d(ma,mb).nonzero()  : 590ms 
    

Не большой выигрыш здесь, но асимптотически лучше.

Память позволяет вам использовать broadcasting, вот так -

val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)]

Пробный прогон -

In [55]: train_dataset
Out[55]: 
array([[[1, 1],
        [1, 1]],

       [[1, 0],
        [0, 0]],

       [[0, 0],
        [0, 1]],

       [[0, 1],
        [0, 0]],

       [[1, 1],
        [1, 0]]])

In [56]: val_dateset
Out[56]: 
array([[[0, 1],
        [1, 0]],

       [[1, 1],
        [1, 1]],

       [[0, 0],
        [0, 1]]])

In [57]: val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)]
Out[57]: 
array([[[1, 1],
        [1, 1]],

       [[0, 0],
        [0, 1]]])

Если элементы целые, вы можете свернуть каждый блок axis=(1,2) во входных массивах в скаляр, предполагая их как линейно индексируемые числа, а затем эффективно использовать np.in1d или же np.intersect1d чтобы найти спички.

Решение

def overlap(a,b):
    """
    returns a boolean index array for input array b representing
    elements in b that are also found in a
    """
    a.repeat(b.shape[0],axis=0)
    b.repeat(a.shape[0],axis=0)
    c = aa == bb
    c = c[::a.shape[0]]
    return c.all(axis=1)[:,0]

Вы можете использовать возвращенный индексный массив для индексации b извлечь элементы, которые также находятся в a

b[overlap(a,b)]

объяснение

Для простоты я предполагаю, что вы импортировали все из numpy для этого примера:

from numpy import *

Так, например, даны два ndarrays

a = arange(4*2*2).reshape(4,2,2)
b = arange(3*2*2).reshape(3,2,2)

мы повторяем a а также b так что они имеют одинаковую форму

aa = a.repeat(b.shape[0],axis=0)
bb = b.repeat(a.shape[0],axis=0)

тогда мы можем просто сравнить элементы aa а также bb

c = aa == bb

Наконец, чтобы получить индексы элементов в b которые также найдены в a глядя на каждый 4-й, или на самом деле, каждый shape(a)[0]й элемент c

cc == c[::a.shape[0]]

Наконец, мы извлекаем индексный массив только с элементами, где все элементы в подмассивах True

c.all(axis=1)[:,0]

В нашем примере мы получаем

array([True,  True,  True], dtype=bool)

Чтобы проверить, измените первый элемент b

b[0] = array([[50,60],[70,80]])

и мы получаем

array([False,  True,  True], dtype=bool)

Этот вопрос задается онлайн-курсом глубокого обучения Google? Вот мое решение:

sum = 0 # number of overlapping rows
for i in range(val_dataset.shape[0]): # iterate over all rows of val_dataset
    overlap = (train_dataset == val_dataset[i,:,:]).all(axis=1).all(axis=1).sum()
    if overlap:
        sum += 1
print(sum)

Автоматическое вещание используется вместо итерации. Вы можете проверить разницу в производительности.

Другие вопросы по тегам