Размер массива с шишечками и широковещательной памятью в памяти?

Я пытаюсь создать эффективные широковещательные массивы в NumPy, например, набор shape=[1000,1000,1000] массивы, которые имеют только 1000 элементов, но повторяются 1e6 раз. Это может быть достигнуто как через np.lib.stride_tricks.as_strided а также np.broadcast_arrays,

Однако у меня возникают проблемы с проверкой отсутствия дублирования в памяти, и это очень важно, поскольку тесты, которые на самом деле дублируют массивы в памяти, приводят к сбоям в работе моей машины и не оставляют следов.

Я попытался изучить размер массивов, используя .nbytes, но это, похоже, не соответствует фактическому использованию памяти:

>>> import numpy as np
>>> import resource
>>> initial_memuse = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
>>> pagesize = resource.getpagesize()
>>>
>>> x = np.arange(1000)
>>> memuse_x = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
>>> print("Size of x = {0} MB".format(x.nbytes/1e6))
Size of x = 0.008 MB
>>> print("Memory used = {0} MB".format((memuse_x-initial_memuse)*resource.getpagesize()/1e6))
Memory used = 150.994944 MB
>>>
>>> y = np.lib.stride_tricks.as_strided(x, [1000,10,10], strides=x.strides + (0, 0))
>>> memuse_y = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
>>> print("Size of y = {0} MB".format(y.nbytes/1e6))
Size of y = 0.8 MB
>>> print("Memory used = {0} MB".format((memuse_y-memuse_x)*resource.getpagesize()/1e6))
Memory used = 201.326592 MB
>>>
>>> z = np.lib.stride_tricks.as_strided(x, [1000,100,100], strides=x.strides + (0, 0))
>>> memuse_z = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
>>> print("Size of z = {0} MB".format(z.nbytes/1e6))
Size of z = 80.0 MB
>>> print("Memory used = {0} MB".format((memuse_z-memuse_y)*resource.getpagesize()/1e6))
Memory used = 0.0 MB

Так .nbytes сообщает о "теоретическом" размере массива, но, очевидно, не фактическом размере. resource проверка немного неловкая, так как похоже, что некоторые вещи загружаются и кэшируются (возможно?), что приводит к тому, что первый шаг занимает некоторое количество памяти, но будущие шаги не занимают ничего.

tl; dr: Как вы определяете фактический размер массива или представления массива в памяти?

1 ответ

Решение

Одним из способов было бы изучить .base атрибут массива, который ссылается на объект, из которого массив "занимает" свою память. Например:

x = np.arange(1000)
print(x.flags.owndata)      # x "owns" its data
# True
print(x.base is None)       # its base is therefore 'None'
# True

a = x.reshape(100, 10)      # a is a reshaped view onto x
print(a.flags.owndata)      # it therefore "borrows" its data
# False
print(a.base is x)          # its .base is x
# True

Все немного сложнее с np.lib.stride_tricks:

b = np.lib.stride_tricks.as_strided(x, [1000,100,100], strides=x.strides + (0, 0))

print(b.flags.owndata)
# False
print(b.base)   
# <numpy.lib.stride_tricks.DummyArray object at 0x7fb40c02b0f0>

Вот, b.base это numpy.lib.stride_tricks.DummyArray экземпляр, который выглядит так:

class DummyArray(object):
    """Dummy object that just exists to hang __array_interface__ dictionaries
    and possibly keep alive a reference to a base array.
    """

    def __init__(self, interface, base=None):
        self.__array_interface__ = interface
        self.base = base

Поэтому мы можем изучить b.base.base:

print(b.base.base is x)
# True

Если у вас есть базовый массив, то его .nbytes атрибут должен точно отражать объем памяти, который он занимает.

В принципе, можно иметь представление о виде массива или создать пошаговый массив из другого пошагового массива. Предполагая, что ваше представление или расширенный массив в конечном итоге поддерживаются другим пустым массивом, вы можете рекурсивно ссылаться на его .base приписывать. Как только вы найдете объект, чей .base является NoneВы нашли базовый объект, у которого ваш массив занимает свою память:

def find_base_nbytes(obj):
    if obj.base is not None:
        return find_base_nbytes(obj.base)
    return obj.nbytes

Как и ожидалось,

print(find_base_nbytes(x))
# 8000

print(find_base_nbytes(y))
# 8000

print(find_base_nbytes(z))
# 8000
Другие вопросы по тегам