Почему модель JAX + STAX требует больше памяти GPU, чем необходимо?
Я пытаюсь запустить модель JAX + STAX из ядер Kaggle на графическом процессоре, но это не удается из-за ошибки нехватки памяти. я установилXLA_PYTHON_CLIENT_PREALLOCATE
кfalse
чтобы избежать предварительного выделения памяти графического процессора, а также попытался установитьXLA_PYTHON_CLIENT_ALLOCATOR
кplatform
, ничего не помогло. Устройство по умолчанию настроено на ЦП с самого начала, так как я не хочу, чтобы все данные хранились на графическом процессоре. Модель и пакетные данные отправляются в GPU вручную. Размер переменных (параметры модели, данные...) не должен быть проблемой, так как один и тот же код работает гладко на процессоре, без ошибок OOM. Я также сделал профилирование памяти модели. Чтобы получить только память GPU, нужно было сделать другую версию кода, где GPU является устройством по умолчанию и все данные хранятся там. Если я запустил профилирование исходного кода, где ЦП по умолчанию, я получаю профилирование только для данных ЦП. Уменьшение размера партии до 10 также было необходимо для завершения обучения модели. Профилирование показывает только объем памяти, необходимый для хранения данных и параметров (≈ 5,5 ГБ),batch_size = 100
память также достигает 14,6 ГБ во время первой мини-партии, но не может идти дальше).
Вот упрощенная версия кода, который я использовал:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = 'false'
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = 'platform' # Tried this, didn't help
import jax
from jax.lib import xla_bridge
jax.config.update('jax_platform_name', 'cpu') # If not set default device = CPU then all the device arrays will be saved to GPU by default
# Set the processor to GPU if available
try: print('Available GPU Devices: ', jax.devices("gpu")); device = jax.devices("gpu")[0]; gpu_available = 1
except: device = jax.devices("cpu")[0]; gpu_available = 0
# Load data into jax device arrays of dimensions (2000, 200, 200, 3)...
InitializationFunction, ApplyFunction = stax.serial(
Conv(out_chan = 64, filter_shape = (3, 3), strides = (1, 1), padding = 'SAME'), Relu,
Conv(out_chan = 64, filter_shape = (3, 3), strides = (1, 1), padding = 'SAME'), Relu,
Flatten, Dense(128), Relu, Dense(2),)
key = random.PRNGKey(2793)
output_shape, parameters = jax.device_put(InitializationFunction(rng = key, input_shape = (100, image_width, image_height, number_of_channels)), device)
optimizer = optax.adam(0.001)
optimizer_state = jax.device_put(optimizer.init(parameters), device)
def Loss(parameters, inputs, targets):
predictions = ApplyFunction(parameters, inputs)
loss = jnp.mean(optax.softmax_cross_entropy(predictions, targets))
return loss
@jit
def Step(parameters, optimizer_state, inputs, targets):
loss, gradients = value_and_grad(Loss)(parameters, inputs, targets)
updates, optimizer_state = optimizer.update(gradients, optimizer_state, parameters)
parameters = optax.apply_updates(parameters, updates)
return parameters, optimizer_state, loss
epochs, batch_size = 2, 100
key, subkey = random.split(key)
keys_epochs = random.split(subkey, epochs)
for epoch in range(epochs):
random_indices_order = random.permutation(keys_epochs[epoch], jnp.arange(len(train_set['images'])))
for batch_number in range(len(train_set['images']) // batch_size):
start = batch_number * batch_size
end = (batch_number + 1) * batch_size
batch_inputs = jax.device_put(jnp.take(train_set['images'], random_indices_order[start:end], 0), device)
batch_targets = jax.device_put(OneHot(jnp.take(train_set['class_numbers'], random_indices_order[start:end], 0), jnp.max(train_set['class_numbers']) + 1), device)
parameters, optimizer_state, loss = Step(parameters, optimizer_state, inputs = batch_inputs, targets = batch_targets)
Мои вопросы:
- Почему используется больше памяти графического процессора, чем необходимо для размера переменных, и больше, чем захвачено с помощью профилирования памяти устройства jax? Для чего используется избыток памяти, как его отследить и как предотвратить?
- Как захватить память ЦП и ГП при профилировании памяти устройства jax? Он захватывает ЦП только тогда, когда ЦП является устройством по умолчанию, хотя ГП также доступен и используется.
Вот результат профилирования памяти устройства для GPU, когда GPU настроен как устройство по умолчанию и хранит весь набор данных (2x(2000, 200, 200, 3) ≈ 1,79 ГБ). Размер пакета уменьшен до 10.Профилирование памяти устройства Jax GPU для размера пакета 10