(API tf.data) Утечка памяти при использовании train/valid handle

Моя программа тензорного потока, похоже, не освобождает память, и мое расследование указывает на виновника tf.data. Я добавил sess.graph.finalize() заявление, чтобы исключить добавление операций в граф. Я новичок в tf.data, поэтому, вероятно, я просто не использую API так, как это было задумано. Может быть, кто-то может обнаружить глупую ошибку в моем коде.

Документация tf.data предлагает несколько различных способов чередования обучения и проверки: один, который переключается между двумя итераторами с помощью дескриптора строки, и другой, который повторно инициализирует один итератор. В моем случае первый метод, по-видимому, приводит к утечке памяти, а второй - нет. Кто-нибудь знает, почему первый дает утечку памяти?

Код проприетарный, так что извините, если я не поделился достаточно, чтобы быть полезным. Я могу попытаться поделиться более или собрать минимальный воспроизводимый пример.

Похоже, это приводит к утечке памяти (ключевое отличие в разделе LOOK HERE):

train_dataset, val_dataset = train_utils.get_datasets()
types = train_dataset.output_types
shapes = train_dataset.output_shapes

# LOOK HERE #
train_iterator = train_dataset.make_one_shot_iterator()
val_iterator = val_dataset.make_one_shot_iterator()
handle = tf.placeholder(tf.string, shape=[], name='dataset_handle')
iterator = tf.data.Iterator.from_string_handle(handle, types, shapes)

inputs, outputs = iterator.get_next()

model = MyModel(inputs)
loss = MyLoss(model.outputs, outputs)

train_step = ...
val_step = ...

with tf.Session() as sess:
    train_handle = sess.run(train_iterator.string_handle())
    val_handle = sess.run(val_iterator.string_handle())

    sess.graph.finalize()
    while True:
        sess.run(train_step, {handle: train_handle})
        sess.run(val_step, {handle: val_handle})

Это не приводит к утечке памяти (см. ЗДЕСЬ):

train_dataset, val_dataset = train_utils.get_datasets()
types = train_dataset.output_types
shapes = train_dataset.output_shapes

# LOOK HERE #
iterator = tf.data.Iterator.from_structure(types, shapes)
train_init_op = iterator.make_initializer(train_dataset)
val_init_op = iterator.make_initializer(val_dataset)

inputs, outputs = iterator.get_next()

model = MyModel(inputs)
loss = MyLoss(model.outputs, outputs)

train_step = ...
val_step = ...

with tf.Session() as sess:
    sess.graph.finalize()
    while True:
        sess.run(train_init_op)
        sess.run(train_step)

        sess.run(val_init_op)
        sess.run(val_step)

Единственное, что может быть немного неортодоксальным в моем коде, это то, что мой train_utils.get_datasets() Функция использует генератор для повторения (маленького) массива numpy для каждого элемента в наборе данных (один и тот же массив распределяется между всеми примерами обучения и val и предварительно вычисляется один раз). Еще одна вещь, которая может быть причиной проблем, заключается в том, что я использую много tf.py_func s, чтобы запустить пользовательскую предварительную обработку в cython и numpy. Вот упрощенный фрагмент кода генератора:

my_np_array = np.array(...)
def gen():
    while True: yield my_np_array

data_1 = tf.data.Dataset.from_generator(gen)
data_2 = tf.data.Dataset.from_csv(...)
data = tf.data.Dataset.zip((data_1, data_2))

Я отслеживаю использование памяти в терминале, используя htop и в Python с помощью psutil.Process(pid).memory_info(), Я не думаю, что увеличение памяти совпадает с каким-либо регулярным периодом проверки или контрольной точки в моем учебном скрипте.

Утечка памяти: Утечка памяти (метод дескриптора строки

Нет утечки памяти: Нет утечки памяти (переинициализируемый итератор

Изменить: неважно. Кажется, проблема решилась сама собой. Я не уверен, что вызвало утечку памяти, но это, кажется, не ошибка tf.data.

0 ответов

Другие вопросы по тегам