Как дополнить фиксированный BATCH_SIZE в tf.data.Dataset?
У меня есть набор данных с 11 образцами. И когда я выбираю BATCH_SIZE
быть 2, следующий код будет иметь ошибки:
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parser)
if shuffle:
dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)
Проблема заключается в dataset = dataset.batch(batch_size)
, когда Dataset
Зацикливаясь на последнюю партию, оставшееся количество выборок равно 1, так есть ли какой-нибудь способ случайным образом выбрать одну из предыдущих посещенных выборок и сгенерировать последнюю партию?
2 ответа
@mining предлагает решение путем заполнения имен файлов.
Другое решение заключается в использовании tf.contrib.data.batch_and_drop_remainder
, Это позволит пакетировать данные с фиксированным размером пакета и удалить последнюю меньшую партию.
В ваших примерах с 11 входами и размером пакета 2 это даст 5 пакетов по 2 элемента.
Вот пример из документации:
dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))
Вы можете просто установить drop_remainder=True
в вашем звонке batch
,
dataset = dataset.batch(batch_size, drop_remainder=True)
Из документации:
drop_remainder: (Необязательно.) Скалярный tf.Tolor tf.bool, представляющий, следует ли отбрасывать последний пакет в случае, если в нем меньше элементов batch_size; поведение по умолчанию - не отбрасывать меньший пакет.