Как ускорить пакетную подготовку при использовании Estimators API в сочетании с tf.data.Dataset
Я хотел бы ускорить мою тренировочную программу, использующую API Estimator с записью input_fn, используя tf.data.Dataset
,
Моя реализация занимает 2 секунды, чтобы подготовить пакет данных, а затем запускает обучение на графическом процессоре в течение 1 секунды, а затем начинает заново подготовку пакета. Что действительно неэффективно.
Я ищу способ подготовить пакеты асинхронно и загрузить их в GPU, чтобы ускорить обучение. Или в качестве альтернативы для способа кэширования наборов данных между вызовами input_fn
(dataset.cache()
не кажется хорошим выбором, так как набор данных должен быть воссоздан при каждом вызове input_fn).
Вот упрощенная версия моего кода:
def input_fn(filenames, labels, epochs):
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
if shuffle:
dataset = dataset.shuffle(buffer_size=len(labels))
dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)
dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
dataset = dataset.batch(128)
dataset = dataset.repeat(epochs) # to iterate over the training set forever
iterator = dataset.dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
train_input_fn = lambda : input_fn(train_files, train_labels, None)
eval_input_fn = lambda : input_fn(eval_files, eval_labels, 1)
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Я заметил, что API Estimator находится в активной разработке, и в основной ветке tenorflow input_fn уже может возвращать наборы данных, поэтому, возможно, я спрашиваю слишком рано, и эта функция еще не готова. Но если так, пожалуйста, предоставьте билет, где эта реализация может быть отслежена.
2 ответа
С помощью tf.data.Dataset.cache()
это действительно плохой выбор, поскольку он кэширует весь набор данных в память, что занимает много времени и может переполнить вашу память.
Путь должен использовать tf.data.Dataset.prefetch()
в конце вашего конвейера, который всегда будет гарантировать, что конвейер данных содержит buffer_size
элементы. Обычно достаточно иметь buffer_size = 1
в конце:
dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1) # prefetch one batch
Как объясняет @mrry в этом ответе, вы также можете попытаться немного увеличить количество предварительно выбранных пакетов.
Как правило, наиболее полезно добавить небольшой буфер предварительной выборки (возможно, с одним элементом) в самом конце конвейера, но более сложные конвейеры могут выиграть от дополнительной предварительной выборки, особенно когда время для создания одного элемента может изменяться.
Если у вас по-прежнему медленный входной конвейер по сравнению с вычислениями на GPU, вам нужно увеличить количество потоков, работающих параллельно, используя num_parallel_calls
аргумент tf.data.Dataset.map()
,
Несколько моментов, которые нужно добавить к ответу Оливье, в основном из этого поста:
repeat
доshuffle
немного быстрее, в нижней части размытых границ эпохи. Это может быть важно в редких случаях, но я сомневаюсь в этом.shuffle
доmap
ping - это уменьшает отпечаток памяти вашего размера буфера случайного использования, так как для этого требуется только буферизовать имена файлов, а не содержимое файла.- для меня более логично применить преобразование третьей карты к выводу
get_next()
а не набор данных - не уверен, сильно ли это влияет на скорость. Вы также можете рассмотреть вопрос о том, чтобы поместить оба других вызова карты в один и тот же, чтобы уменьшить проблемы планирования. - Эксперимент с
repeat
доbatch
ING. Вероятно, не будет иметь значения, но может быть незначительным. если тыrepeat
доshuffle
как уже упоминалось выше, вам придется. - как уже упоминалось Оливье, используйте
prefetch
,
Код с изменениями:
def input_fn(filenames, labels, epochs):
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.repeat(epochs)
if shuffle:
dataset = dataset.shuffle(buffer_size=len(labels))
def combined_map_fn(*args):
return _post_process(_read_wav(*args))
dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads)
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)
iterator = dataset.dataset.make_one_shot_iterator()
wavs, labels = iterator.get_next()
features = {'wav': wavs}
return features, labels