Как использовать последовательность / генератор в объекте tf.data.Dataset для размещения частичных данных в памяти?

Я делаю классификацию изображений с помощью Keras в Google Colab. Я загружаю изображения с помощью функции tf.keras.preprocessing.image_dataset_from_directory() (https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory), которая возвращает объект tf.data.Dataset:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=1234,
  image_size=(img_height, img_width),
  batch_size=batch_size,
  label_mode="categorical")

Я обнаружил, что, когда данные содержат тысячи изображений, model.fit () будет использовать всю память после обучения нескольких пакетов (я использую Google Colab и вижу, что использование ОЗУ растет в течение первой эпохи). Затем я пытаюсь использовать Keras Sequence, который является предлагаемым решением для загрузки частичных данных в ОЗУ (https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence):

  class DatasetGenerator(tf.keras.utils.Sequence):
      def __init__(self, dataset):
          self.dataset = dataset

      def __len__(self):
          return tf.data.experimental.cardinality(self.dataset).numpy()

      def __getitem__(self, idx):
          return list(self.dataset.as_numpy_iterator())[idx]

И я тренирую модель с помощью:

history = model.fit(DatasetGenerator(train_ds), ...)

Проблема в том, что getitem() должен возвращать пакет данных с индексом. Однако функция list (), которую я использую, должна поместить весь набор данных в ОЗУ и, таким образом, достичь предела памяти при создании экземпляра объекта DatasetGenerator (объект tf.data.Dataset не поддерживает индексацию с помощью []).

Мои вопросы:

  1. Есть ли способ реализовать getitem() (получить конкретный пакет из объекта набора данных) без помещения всего объекта в память?
  2. Если пункт 1 невозможен, есть ли обходной путь?

Заранее спасибо!

1 ответ

Решение

Я понимаю, что вы беспокоитесь о том, что полный набор данных находится в памяти.

Не волнуйтесь, tf.data.Dataset API очень эффективен и не загружает весь набор данных в память.

Внутри он просто создает последовательность функций и при вызове с model.fit() он загрузит в память только пакет, а не весь набор данных.

Вы можете прочитать больше по этой ссылке, я вставляю важную часть из документации.

API tf.data.Dataset поддерживает создание описательных и эффективных конвейеров ввода. Использование набора данных следует общей схеме:

Создайте исходный набор данных из ваших входных данных. Примените преобразования наборов данных для предварительной обработки данных. Обходите набор данных и обработайте элементы. Итерация происходит в потоковом режиме, поэтому полный набор данных не обязательно помещается в память.

Из последней строчки видно, что tf.data.Dataset API загружает в память не весь набор данных, а по одному пакету за раз.

Для создания пакетов набора данных вам нужно будет сделать следующее.

train_ds.batch(32)

Это создаст пакет размером 32. Также вы можете использовать предварительную выборку для подготовки одной партии к тренировке. Это устраняет узкое место, где модель является идеальной после обучения одной партии и ожидания следующей партии.

train_ds.batch(32).prefetch(1)

Вы также можете использовать cacheAPI, чтобы сделать конвейер данных еще быстрее. Он кэширует ваш набор данных и ускоряет обучение.

train_ds.batch(32).prefetch(1).cache()

Итак, чтобы ответить коротко, вам не нужно generator если вы беспокоитесь о загрузке всего набора данных в память, tf.data.Dataset API позаботится об этом.

Я надеюсь, что мой ответ тебе понравится.

Другие вопросы по тегам