(API tf.data) Утечка памяти при использовании train/valid handle
Моя программа тензорного потока, похоже, не освобождает память, и мое расследование указывает на виновника tf.data. Я добавил sess.graph.finalize()
заявление, чтобы исключить добавление операций в граф. Я новичок в tf.data, поэтому, вероятно, я просто не использую API так, как это было задумано. Может быть, кто-то может обнаружить глупую ошибку в моем коде.
Документация tf.data предлагает несколько различных способов чередования обучения и проверки: один, который переключается между двумя итераторами с помощью дескриптора строки, и другой, который повторно инициализирует один итератор. В моем случае первый метод, по-видимому, приводит к утечке памяти, а второй - нет. Кто-нибудь знает, почему первый дает утечку памяти?
Код проприетарный, так что извините, если я не поделился достаточно, чтобы быть полезным. Я могу попытаться поделиться более или собрать минимальный воспроизводимый пример.
Похоже, это приводит к утечке памяти (ключевое отличие в разделе LOOK HERE):
train_dataset, val_dataset = train_utils.get_datasets()
types = train_dataset.output_types
shapes = train_dataset.output_shapes
# LOOK HERE #
train_iterator = train_dataset.make_one_shot_iterator()
val_iterator = val_dataset.make_one_shot_iterator()
handle = tf.placeholder(tf.string, shape=[], name='dataset_handle')
iterator = tf.data.Iterator.from_string_handle(handle, types, shapes)
inputs, outputs = iterator.get_next()
model = MyModel(inputs)
loss = MyLoss(model.outputs, outputs)
train_step = ...
val_step = ...
with tf.Session() as sess:
train_handle = sess.run(train_iterator.string_handle())
val_handle = sess.run(val_iterator.string_handle())
sess.graph.finalize()
while True:
sess.run(train_step, {handle: train_handle})
sess.run(val_step, {handle: val_handle})
Это не приводит к утечке памяти (см. ЗДЕСЬ):
train_dataset, val_dataset = train_utils.get_datasets()
types = train_dataset.output_types
shapes = train_dataset.output_shapes
# LOOK HERE #
iterator = tf.data.Iterator.from_structure(types, shapes)
train_init_op = iterator.make_initializer(train_dataset)
val_init_op = iterator.make_initializer(val_dataset)
inputs, outputs = iterator.get_next()
model = MyModel(inputs)
loss = MyLoss(model.outputs, outputs)
train_step = ...
val_step = ...
with tf.Session() as sess:
sess.graph.finalize()
while True:
sess.run(train_init_op)
sess.run(train_step)
sess.run(val_init_op)
sess.run(val_step)
Единственное, что может быть немного неортодоксальным в моем коде, это то, что мой train_utils.get_datasets()
Функция использует генератор для повторения (маленького) массива numpy для каждого элемента в наборе данных (один и тот же массив распределяется между всеми примерами обучения и val и предварительно вычисляется один раз). Еще одна вещь, которая может быть причиной проблем, заключается в том, что я использую много tf.py_func
s, чтобы запустить пользовательскую предварительную обработку в cython и numpy. Вот упрощенный фрагмент кода генератора:
my_np_array = np.array(...)
def gen():
while True: yield my_np_array
data_1 = tf.data.Dataset.from_generator(gen)
data_2 = tf.data.Dataset.from_csv(...)
data = tf.data.Dataset.zip((data_1, data_2))
Я отслеживаю использование памяти в терминале, используя htop
и в Python с помощью psutil.Process(pid).memory_info()
, Я не думаю, что увеличение памяти совпадает с каким-либо регулярным периодом проверки или контрольной точки в моем учебном скрипте.
Изменить: неважно. Кажется, проблема решилась сама собой. Я не уверен, что вызвало утечку памяти, но это, кажется, не ошибка tf.data.