Почему модель JAX + STAX требует больше памяти GPU, чем необходимо?

Я пытаюсь запустить модель JAX + STAX из ядер Kaggle на графическом процессоре, но это не удается из-за ошибки нехватки памяти. я установилXLA_PYTHON_CLIENT_PREALLOCATEкfalseчтобы избежать предварительного выделения памяти графического процессора, а также попытался установитьXLA_PYTHON_CLIENT_ALLOCATORкplatform, ничего не помогло. Устройство по умолчанию настроено на ЦП с самого начала, так как я не хочу, чтобы все данные хранились на графическом процессоре. Модель и пакетные данные отправляются в GPU вручную. Размер переменных (параметры модели, данные...) не должен быть проблемой, так как один и тот же код работает гладко на процессоре, без ошибок OOM. Я также сделал профилирование памяти модели. Чтобы получить только память GPU, нужно было сделать другую версию кода, где GPU является устройством по умолчанию и все данные хранятся там. Если я запустил профилирование исходного кода, где ЦП по умолчанию, я получаю профилирование только для данных ЦП. Уменьшение размера партии до 10 также было необходимо для завершения обучения модели. Профилирование показывает только объем памяти, необходимый для хранения данных и параметров (≈ 5,5 ГБ),batch_size = 100память также достигает 14,6 ГБ во время первой мини-партии, но не может идти дальше).

Вот упрощенная версия кода, который я использовал:

      import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = 'false'
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = 'platform' # Tried this, didn't help

import jax
from jax.lib import xla_bridge
jax.config.update('jax_platform_name', 'cpu') # If not set default device = CPU then all the device arrays will be saved to GPU by default

# Set the processor to GPU if available
try: print('Available GPU Devices: ', jax.devices("gpu")); device = jax.devices("gpu")[0]; gpu_available = 1
except: device = jax.devices("cpu")[0]; gpu_available = 0

# Load data into jax device arrays of dimensions (2000, 200, 200, 3)...

InitializationFunction, ApplyFunction = stax.serial(
    Conv(out_chan = 64, filter_shape = (3, 3), strides = (1, 1), padding = 'SAME'), Relu,
    Conv(out_chan = 64, filter_shape = (3, 3), strides = (1, 1), padding = 'SAME'), Relu,
    Flatten, Dense(128), Relu, Dense(2),)

key = random.PRNGKey(2793)
output_shape, parameters = jax.device_put(InitializationFunction(rng = key, input_shape = (100, image_width, image_height, number_of_channels)), device)
optimizer = optax.adam(0.001)
optimizer_state = jax.device_put(optimizer.init(parameters), device)

def Loss(parameters, inputs, targets):
    predictions = ApplyFunction(parameters, inputs)
    loss = jnp.mean(optax.softmax_cross_entropy(predictions, targets))
    return loss

@jit
def Step(parameters, optimizer_state, inputs, targets):
    loss, gradients = value_and_grad(Loss)(parameters, inputs, targets)
    updates, optimizer_state = optimizer.update(gradients, optimizer_state, parameters)
    parameters = optax.apply_updates(parameters, updates)
    return parameters, optimizer_state, loss

epochs, batch_size = 2, 100
key, subkey = random.split(key)
keys_epochs = random.split(subkey, epochs)
    
for epoch in range(epochs):
    random_indices_order = random.permutation(keys_epochs[epoch], jnp.arange(len(train_set['images'])))

    for batch_number in range(len(train_set['images']) // batch_size):
        start = batch_number * batch_size
        end = (batch_number + 1) * batch_size
        batch_inputs = jax.device_put(jnp.take(train_set['images'], random_indices_order[start:end], 0), device)
        batch_targets = jax.device_put(OneHot(jnp.take(train_set['class_numbers'], random_indices_order[start:end], 0), jnp.max(train_set['class_numbers']) + 1), device)
        parameters, optimizer_state, loss = Step(parameters, optimizer_state, inputs = batch_inputs, targets = batch_targets)          

Мои вопросы:

  1. Почему используется больше памяти графического процессора, чем необходимо для размера переменных, и больше, чем захвачено с помощью профилирования памяти устройства jax? Для чего используется избыток памяти, как его отследить и как предотвратить?
  2. Как захватить память ЦП и ГП при профилировании памяти устройства jax? Он захватывает ЦП только тогда, когда ЦП является устройством по умолчанию, хотя ГП также доступен и используется.

Вот результат профилирования памяти устройства для GPU, когда GPU настроен как устройство по умолчанию и хранит весь набор данных (2x(2000, 200, 200, 3) ≈ 1,79 ГБ). Размер пакета уменьшен до 10.Профилирование памяти устройства Jax GPU для размера пакета 10

0 ответов

Другие вопросы по тегам