Numpy, проблема с длинными массивами
У меня есть два массива (а и б) с n целых элементов в диапазоне (0,N).
опечатка: массивы с 2^n целыми числами, где наибольшее целое принимает значение N = 3^n
Я хочу вычислить сумму каждой комбинации элементов в a и b (sum_ij_ = a_i_ + b_j_ для всех i, j). Затем возьмите модуль N (sum_ij_ = sum_ij_ % N) и, наконец, вычислите частоту различных сумм.
Чтобы сделать это быстро с NumPy, без каких-либо циклов, я попытался использовать сетку сетки и функцию bincount.
A,B = numpy.meshgrid(a,b)
A = A + B
A = A % N
A = numpy.reshape(A,A.size)
result = numpy.bincount(A)
Теперь проблема в том, что мои входные массивы длинные. И meshgrid дает мне MemoryError, когда я использую входные данные с 2^13 элементами. Я хотел бы рассчитать это для массивов с 2^15-2^20 элементов.
то есть n в диапазоне от 15 до 20
Есть какие-нибудь умные уловки, чтобы сделать это с NumPy?
Любая помощь будет высоко оценена.
- Джон
3 ответа
Попробуйте кусать это. Ваша сетка является матрицей NxN, блокируйте до 10x10 N/10xN/10 и просто вычисляйте 100 бинов, добавьте их в конце. он использует всего ~1% памяти, как и все это.
Редактировать в ответ на комментарий Йональма:
jonalm: N ~3^n не n~3^N. N - это максимальный элемент в a, а n - это количество элементов в a.
п ~ 2^20. Если N равно ~3^n, то N равно ~ 3^(2^20) > 10^(500207). По оценкам ученых ( http://www.stormloader.com/ajy/reallife.html), во Вселенной всего около 10^87 частиц. Таким образом, нет (наивного) способа, которым компьютер может обрабатывать целое число размером 10 ^ (500207).
jonalm: Мне, однако, немного любопытно узнать о функции pv(), которую вы определяете. (Мне не удается запустить его, так как text.find() не определен (угадайте его в другом модуле)). Как работает эта функция и в чем ее преимущество?
pv - маленькая вспомогательная функция, которую я написал для отладки значения переменных. Он работает как print() за исключением того, что когда вы говорите pv(x), он печатает как буквальное имя переменной (или строку выражения), двоеточие, а затем значение переменной.
Если вы положите
#!/usr/bin/env python
import traceback
def pv(var):
(filename,line_number,function_name,text)=traceback.extract_stack()[-2]
print('%s: %s'%(text[text.find('(')+1:-1],var))
x=1
pv(x)
в сценарии вы должны получить
x: 1
Скромное преимущество использования pv перед печатью заключается в том, что он экономит ваш набор текста. Вместо того, чтобы писать
print('x: %s'%x)
Вы можете просто ударить
pv(x)
Когда нужно отслеживать несколько переменных, полезно пометить переменные. Я просто устал писать все это.
Функция pv работает с помощью модуля traceback для просмотра строки кода, используемой для вызова самой функции pv. (См. http://docs.python.org/library/traceback.html) Эта строка кода хранится в виде строки в тексте переменной. text.find() - это вызов обычного строкового метода find(). Например, если
text='pv(x)'
затем
text.find('(') == 2 # The index of the '(' in string text
text[text.find('(')+1:-1] == 'x' # Everything in between the parentheses
Я предполагаю, что n ~ 3^N, а n~2**20
Идея состоит в том, чтобы работать модуль N. Это сокращает размер массивов. Вторая идея (важная, когда n велико) состоит в том, чтобы использовать numy ndarrays типа 'object', потому что если вы используете целочисленный тип d, вы рискуете переполниться размером максимально допустимого целого числа.
#!/usr/bin/env python
import traceback
import numpy as np
def pv(var):
(filename,line_number,function_name,text)=traceback.extract_stack()[-2]
print('%s: %s'%(text[text.find('(')+1:-1],var))
Вы можете изменить n на 2**20, но ниже я покажу, что происходит с маленьким n, так что вывод легче читать.
n=100
N=int(np.exp(1./3*np.log(n)))
pv(N)
# N: 4
a=np.random.randint(N,size=n)
b=np.random.randint(N,size=n)
pv(a)
pv(b)
# a: [1 0 3 0 1 0 1 2 0 2 1 3 1 0 1 2 2 0 2 3 3 3 1 0 1 1 2 0 1 2 3 1 2 1 0 0 3
# 1 3 2 3 2 1 1 2 2 0 3 0 2 0 0 2 2 1 3 0 2 1 0 2 3 1 0 1 1 0 1 3 0 2 2 0 2
# 0 2 3 0 2 0 1 1 3 2 2 3 2 0 3 1 1 1 1 2 3 3 2 2 3 1]
# b: [1 3 2 1 1 2 1 1 1 3 0 3 0 2 2 3 2 0 1 3 1 0 0 3 3 2 1 1 2 0 1 2 0 3 3 1 0
# 3 3 3 1 1 3 3 3 1 1 0 2 1 0 0 3 0 2 1 0 2 2 0 0 0 1 1 3 1 1 1 2 1 1 3 2 3
# 3 1 2 1 0 0 2 3 1 0 2 1 1 1 1 3 3 0 2 2 3 2 0 1 3 1]
wa содержит число 0s, 1s, 2s, 3s в wb содержит число 0s, 1s, 2s, 3s в b
wa=np.bincount(a)
wb=np.bincount(b)
pv(wa)
pv(wb)
# wa: [24 28 28 20]
# wb: [21 34 20 25]
result=np.zeros(N,dtype='object')
Думайте о 0 как о токене или чипе. Аналогично для 1,2,3.
Представьте, что wa=[24 28 28 20] означает, что есть пакет с 24 0 фишками, 28 1 фишками, 28 2 фишками, 20 3 фишками.
У вас есть ва-сумка и wb-сумка. Когда вы берете фишку из каждой сумки, вы "складываете" их вместе и формируете новую фишку. Вы "мод" ответ (по модулю N).
Представьте себе, что вы берете 1 чип из wb-bag и добавляете его вместе с каждым чипом в wa-bag.
1-chip + 0-chip = 1-chip
1-chip + 1-chip = 2-chip
1-chip + 2-chip = 3-chip
1-chip + 3-chip = 4-chip = 0-chip (we are mod'ing by N=4)
Поскольку в пакете wb есть 34 1-фишки, добавляя их ко всем фишкам в пакете wa=[24 28 28 20], вы получаете
34*24 1-chips
34*28 2-chips
34*28 3-chips
34*20 0-chips
Это только частичный счет из-за 34 1-фишек. Вы также должны обращаться с другими типами чипов в wb-bag, но это показывает вам метод, использованный ниже:
for i,count in enumerate(wb):
partial_count=count*wa
pv(partial_count)
shifted_partial_count=np.roll(partial_count,i)
pv(shifted_partial_count)
result+=shifted_partial_count
# partial_count: [504 588 588 420]
# shifted_partial_count: [504 588 588 420]
# partial_count: [816 952 952 680]
# shifted_partial_count: [680 816 952 952]
# partial_count: [480 560 560 400]
# shifted_partial_count: [560 400 480 560]
# partial_count: [600 700 700 500]
# shifted_partial_count: [700 700 500 600]
pv(result)
# result: [2444 2504 2520 2532]
Это конечный результат: 2444 0, 2504 1, 2520 2, 2532 3.
# This is a test to make sure the result is correct.
# This uses a very memory intensive method.
# c is too huge when n is large.
if n>1000:
print('n is too large to run the check')
else:
c=(a[:]+b[:,np.newaxis])
c=c.ravel()
c=c%N
result2=np.bincount(c)
pv(result2)
assert(all(r1==r2 for r1,r2 in zip(result,result2)))
# result2: [2444 2504 2520 2532]
Проверьте свою математику, это много места, которое вы просите:
2 ^ 20 * 2 ^ 20 = 2 ^ 40 = 1 099 511 627 776
Если каждый из ваших элементов был всего один байт, это уже один терабайт памяти.
Добавьте цикл или два. Эта проблема не подходит для увеличения вашей памяти и минимизации вычислений.