Numpy, проблема с длинными массивами

У меня есть два массива (а и б) с n целых элементов в диапазоне (0,N).

опечатка: массивы с 2^n целыми числами, где наибольшее целое принимает значение N = 3^n

Я хочу вычислить сумму каждой комбинации элементов в a и b (sum_ij_ = a_i_ + b_j_ для всех i, j). Затем возьмите модуль N (sum_ij_ = sum_ij_ % N) и, наконец, вычислите частоту различных сумм.

Чтобы сделать это быстро с NumPy, без каких-либо циклов, я попытался использовать сетку сетки и функцию bincount.

A,B = numpy.meshgrid(a,b)
A = A + B
A = A % N
A = numpy.reshape(A,A.size)
result = numpy.bincount(A)

Теперь проблема в том, что мои входные массивы длинные. И meshgrid дает мне MemoryError, когда я использую входные данные с 2^13 элементами. Я хотел бы рассчитать это для массивов с 2^15-2^20 элементов.

то есть n в диапазоне от 15 до 20

Есть какие-нибудь умные уловки, чтобы сделать это с NumPy?

Любая помощь будет высоко оценена.

- Джон

3 ответа

Решение

Попробуйте кусать это. Ваша сетка является матрицей NxN, блокируйте до 10x10 N/10xN/10 и просто вычисляйте 100 бинов, добавьте их в конце. он использует всего ~1% памяти, как и все это.

Редактировать в ответ на комментарий Йональма:

jonalm: N ~3^n не n~3^N. N - это максимальный элемент в a, а n - это количество элементов в a.

п ~ 2^20. Если N равно ~3^n, то N равно ~ 3^(2^20) > 10^(500207). По оценкам ученых ( http://www.stormloader.com/ajy/reallife.html), во Вселенной всего около 10^87 частиц. Таким образом, нет (наивного) способа, которым компьютер может обрабатывать целое число размером 10 ^ (500207).

jonalm: Мне, однако, немного любопытно узнать о функции pv(), которую вы определяете. (Мне не удается запустить его, так как text.find() не определен (угадайте его в другом модуле)). Как работает эта функция и в чем ее преимущество?

pv - маленькая вспомогательная функция, которую я написал для отладки значения переменных. Он работает как print() за исключением того, что когда вы говорите pv(x), он печатает как буквальное имя переменной (или строку выражения), двоеточие, а затем значение переменной.

Если вы положите

#!/usr/bin/env python
import traceback
def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))
x=1
pv(x)

в сценарии вы должны получить

x: 1

Скромное преимущество использования pv перед печатью заключается в том, что он экономит ваш набор текста. Вместо того, чтобы писать

print('x: %s'%x)

Вы можете просто ударить

pv(x)

Когда нужно отслеживать несколько переменных, полезно пометить переменные. Я просто устал писать все это.

Функция pv работает с помощью модуля traceback для просмотра строки кода, используемой для вызова самой функции pv. (См. http://docs.python.org/library/traceback.html) Эта строка кода хранится в виде строки в тексте переменной. text.find() - это вызов обычного строкового метода find(). Например, если

text='pv(x)'

затем

text.find('(') == 2               # The index of the '(' in string text
text[text.find('(')+1:-1] == 'x'  # Everything in between the parentheses

Я предполагаю, что n ~ 3^N, а n~2**20

Идея состоит в том, чтобы работать модуль N. Это сокращает размер массивов. Вторая идея (важная, когда n велико) состоит в том, чтобы использовать numy ndarrays типа 'object', потому что если вы используете целочисленный тип d, вы рискуете переполниться размером максимально допустимого целого числа.

#!/usr/bin/env python
import traceback
import numpy as np

def pv(var):
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2]
    print('%s: %s'%(text[text.find('(')+1:-1],var))

Вы можете изменить n на 2**20, но ниже я покажу, что происходит с маленьким n, так что вывод легче читать.

n=100
N=int(np.exp(1./3*np.log(n)))
pv(N)
# N: 4

a=np.random.randint(N,size=n)
b=np.random.randint(N,size=n)
pv(a)
pv(b)
# a: [1 0 3 0 1 0 1 2 0 2 1 3 1 0 1 2 2 0 2 3 3 3 1 0 1 1 2 0 1 2 3 1 2 1 0 0 3
#  1 3 2 3 2 1 1 2 2 0 3 0 2 0 0 2 2 1 3 0 2 1 0 2 3 1 0 1 1 0 1 3 0 2 2 0 2
#  0 2 3 0 2 0 1 1 3 2 2 3 2 0 3 1 1 1 1 2 3 3 2 2 3 1]
# b: [1 3 2 1 1 2 1 1 1 3 0 3 0 2 2 3 2 0 1 3 1 0 0 3 3 2 1 1 2 0 1 2 0 3 3 1 0
#  3 3 3 1 1 3 3 3 1 1 0 2 1 0 0 3 0 2 1 0 2 2 0 0 0 1 1 3 1 1 1 2 1 1 3 2 3
#  3 1 2 1 0 0 2 3 1 0 2 1 1 1 1 3 3 0 2 2 3 2 0 1 3 1]

wa содержит число 0s, 1s, 2s, 3s в wb содержит число 0s, 1s, 2s, 3s в b

wa=np.bincount(a)
wb=np.bincount(b)
pv(wa)
pv(wb)
# wa: [24 28 28 20]
# wb: [21 34 20 25]
result=np.zeros(N,dtype='object')

Думайте о 0 как о токене или чипе. Аналогично для 1,2,3.

Представьте, что wa=[24 28 28 20] означает, что есть пакет с 24 0 фишками, 28 1 фишками, 28 2 фишками, 20 3 фишками.

У вас есть ва-сумка и wb-сумка. Когда вы берете фишку из каждой сумки, вы "складываете" их вместе и формируете новую фишку. Вы "мод" ответ (по модулю N).

Представьте себе, что вы берете 1 чип из wb-bag и добавляете его вместе с каждым чипом в wa-bag.

1-chip + 0-chip = 1-chip
1-chip + 1-chip = 2-chip
1-chip + 2-chip = 3-chip
1-chip + 3-chip = 4-chip = 0-chip  (we are mod'ing by N=4)

Поскольку в пакете wb есть 34 1-фишки, добавляя их ко всем фишкам в пакете wa=[24 28 28 20], вы получаете

34*24 1-chips
34*28 2-chips
34*28 3-chips
34*20 0-chips

Это только частичный счет из-за 34 1-фишек. Вы также должны обращаться с другими типами чипов в wb-bag, но это показывает вам метод, использованный ниже:

for i,count in enumerate(wb):
    partial_count=count*wa
    pv(partial_count)
    shifted_partial_count=np.roll(partial_count,i)
    pv(shifted_partial_count)
    result+=shifted_partial_count
# partial_count: [504 588 588 420]
# shifted_partial_count: [504 588 588 420]
# partial_count: [816 952 952 680]
# shifted_partial_count: [680 816 952 952]
# partial_count: [480 560 560 400]
# shifted_partial_count: [560 400 480 560]
# partial_count: [600 700 700 500]
# shifted_partial_count: [700 700 500 600]

pv(result)    
# result: [2444 2504 2520 2532]

Это конечный результат: 2444 0, 2504 1, 2520 2, 2532 3.

# This is a test to make sure the result is correct.
# This uses a very memory intensive method.
# c is too huge when n is large.
if n>1000:
    print('n is too large to run the check')
else:
    c=(a[:]+b[:,np.newaxis])
    c=c.ravel()
    c=c%N
    result2=np.bincount(c)
    pv(result2)
    assert(all(r1==r2 for r1,r2 in zip(result,result2)))
# result2: [2444 2504 2520 2532]

Проверьте свою математику, это много места, которое вы просите:

2 ^ 20 * 2 ^ 20 = 2 ^ 40 = 1 099 511 627 776

Если каждый из ваших элементов был всего один байт, это уже один терабайт памяти.

Добавьте цикл или два. Эта проблема не подходит для увеличения вашей памяти и минимизации вычислений.

Другие вопросы по тегам