Как дополнить фиксированный BATCH_SIZE в tf.data.Dataset?

У меня есть набор данных с 11 образцами. И когда я выбираю BATCH_SIZE быть 2, следующий код будет иметь ошибки:

dataset = tf.contrib.data.TFRecordDataset(filenames) 
dataset = dataset.map(parser)
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)

Проблема заключается в dataset = dataset.batch(batch_size), когда Dataset Зацикливаясь на последнюю партию, оставшееся количество выборок равно 1, так есть ли какой-нибудь способ случайным образом выбрать одну из предыдущих посещенных выборок и сгенерировать последнюю партию?

2 ответа

Решение

@mining предлагает решение путем заполнения имен файлов.

Другое решение заключается в использовании tf.contrib.data.batch_and_drop_remainder, Это позволит пакетировать данные с фиксированным размером пакета и удалить последнюю меньшую партию.

В ваших примерах с 11 входами и размером пакета 2 это даст 5 пакетов по 2 элемента.

Вот пример из документации:

dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))

Вы можете просто установить drop_remainder=True в вашем звонке batch,

dataset = dataset.batch(batch_size, drop_remainder=True)

Из документации:

drop_remainder: (Необязательно.) Скалярный tf.Tolor tf.bool, представляющий, следует ли отбрасывать последний пакет в случае, если в нем меньше элементов batch_size; поведение по умолчанию - не отбрасывать меньший пакет.

Другие вопросы по тегам