Пользовательские партии для tf.data.Dataset

Я использую Estimator API tenorflow и хотел бы создавать собственные партии для обучения.

У меня есть примеры, которые выглядят следующим образом

example1 = {
   "num_sentences": 3,
   "sentences": [[1, 2], [3, 4], [5, 6]] 
}
example2 = {
   "num_sentences": 2,
   "sentences": [[1, 2], [3, 4]] 
}

Таким образом, пример может иметь любое количество предложений фиксированного размера. Теперь я хотел бы создать пакеты, размер которых зависит от количества предложений в пакете. В противном случае мне придется использовать размер пакета 1, так как в некоторых примерах могут быть предложения "размера пакета", а большой размер пакета не помещается в память GPU.

Например: у меня размер партии 6 и примеры с количеством предложений [5, 3, 3, 2, 2, 1]. Затем я группирую примеры в партии [5], [3, 3] и [2, 2, 1]. Обратите внимание, что пример "1" в последнем пакете будет дополнен.

Я написал алгоритм, который группирует примеры для таких партий. Теперь я не могу передать пакеты в tf.data.Dataset.

Я пытался использовать tf.data.Dataset.from_generator но метод, похоже, ожидает отдельных примеров, и я получаю ошибку, если генератор выдает пакеты, например [example1, example2].

Как я могу кормить набор данных заказными партиями? Есть ли более элегантный способ решить мою проблему?

Обновление: я предполагаю, что не могу правильно указать параметр выходных фигур. Следующий код работает нормально.

import tensorflow as tf

def gen():
    for b in range(3):
        #yield [{"num_sentences": 3, "sentences": [[1, 2], [3, 4], [5, 6]]}]
        yield {"num_sentences": 3, "sentences": [[1, 2], [3, 4], [5, 6]]}


dataset = tf.data.Dataset.from_generator(generator=gen, 
                                         output_types={'num_sentences': tf.int32, 'sentences': tf.int32},
                                         #output_shapes=tf.TensorShape([None,  {'num_sentences': tf.TensorShape(None), 'sentences': tf.TensorShape(None)}])
                                         output_shapes={'num_sentences': tf.TensorShape(None), 'sentences': tf.TensorShape(None)}
                                        )

def print_dataset(dataset):
    it = dataset.make_one_shot_iterator()
    with tf.Session() as sess:
        print(dataset.output_shapes)
        print(dataset.output_types)
        while True:
            try:
                data = it.get_next()
                print("data" + str(sess.run(data)))
            except tf.errors.OutOfRangeError:
                break

print_dataset(dataset)

Если вместо этого я получаю массив и раскомментирую output_shapes, я получаю сообщение об ошибке "Аргумент int() должен быть строкой, байтовым объектом или числом, а не" dict ""

1 ответ

Я думаю, что я понял, как решить вышеуказанную проблему. Я думаю, что я должен объединить примеры в один словарь.

# a batch with two examples each with sentence size 3
yield {"num_sentences": [3, 3], "sentences": [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]}
Другие вопросы по тегам