Пользовательские партии для tf.data.Dataset
Я использую Estimator API tenorflow и хотел бы создавать собственные партии для обучения.
У меня есть примеры, которые выглядят следующим образом
example1 = {
"num_sentences": 3,
"sentences": [[1, 2], [3, 4], [5, 6]]
}
example2 = {
"num_sentences": 2,
"sentences": [[1, 2], [3, 4]]
}
Таким образом, пример может иметь любое количество предложений фиксированного размера. Теперь я хотел бы создать пакеты, размер которых зависит от количества предложений в пакете. В противном случае мне придется использовать размер пакета 1, так как в некоторых примерах могут быть предложения "размера пакета", а большой размер пакета не помещается в память GPU.
Например: у меня размер партии 6 и примеры с количеством предложений [5, 3, 3, 2, 2, 1]. Затем я группирую примеры в партии [5], [3, 3] и [2, 2, 1]. Обратите внимание, что пример "1" в последнем пакете будет дополнен.
Я написал алгоритм, который группирует примеры для таких партий. Теперь я не могу передать пакеты в tf.data.Dataset.
Я пытался использовать tf.data.Dataset.from_generator
но метод, похоже, ожидает отдельных примеров, и я получаю ошибку, если генератор выдает пакеты, например [example1, example2].
Как я могу кормить набор данных заказными партиями? Есть ли более элегантный способ решить мою проблему?
Обновление: я предполагаю, что не могу правильно указать параметр выходных фигур. Следующий код работает нормально.
import tensorflow as tf
def gen():
for b in range(3):
#yield [{"num_sentences": 3, "sentences": [[1, 2], [3, 4], [5, 6]]}]
yield {"num_sentences": 3, "sentences": [[1, 2], [3, 4], [5, 6]]}
dataset = tf.data.Dataset.from_generator(generator=gen,
output_types={'num_sentences': tf.int32, 'sentences': tf.int32},
#output_shapes=tf.TensorShape([None, {'num_sentences': tf.TensorShape(None), 'sentences': tf.TensorShape(None)}])
output_shapes={'num_sentences': tf.TensorShape(None), 'sentences': tf.TensorShape(None)}
)
def print_dataset(dataset):
it = dataset.make_one_shot_iterator()
with tf.Session() as sess:
print(dataset.output_shapes)
print(dataset.output_types)
while True:
try:
data = it.get_next()
print("data" + str(sess.run(data)))
except tf.errors.OutOfRangeError:
break
print_dataset(dataset)
Если вместо этого я получаю массив и раскомментирую output_shapes, я получаю сообщение об ошибке "Аргумент int() должен быть строкой, байтовым объектом или числом, а не" dict ""
1 ответ
Я думаю, что я понял, как решить вышеуказанную проблему. Я думаю, что я должен объединить примеры в один словарь.
# a batch with two examples each with sentence size 3
yield {"num_sentences": [3, 3], "sentences": [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]}