Получить длину набора данных в Tensorflow

source_dataset = tf.data.TextLineDataset('primary.csv')
target_dataset = tf.data.TextLineDataset('secondary.csv')
dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
dataset = dataset.shard(10000, 0)
dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
                                              tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
dataset = dataset.map(lambda source, target: (source, tf.concat(([start_token], target), axis=0), tf.concat((target, [end_token]), axis=0)))
dataset = dataset.map(lambda source, target_in, target_out: (source, tf.size(source), target_in, target_out, tf.size(target_in)))

dataset = dataset.shuffle(NUM_SAMPLES)  #This is the important line of code

Я хотел бы полностью перемешать весь набор данных, но shuffle() требует несколько образцов, чтобы вытащить, и tf.Size() не работает с tf.data.Dataset,

Как я могу правильно перемешать?

1 ответ

Решение

Я работал с tf.data.FixedLengthRecordDataset() и столкнулся с аналогичной проблемой. В моем случае я пытался взять только определенный процент необработанных данных. Так как я знал, что все записи имеют фиксированную длину, обходной путь для меня был:

totalBytes = sum([os.path.getsize(os.path.join(filepath, filename)) for filename in os.listdir(filepath)])
numRecordsToTake = tf.cast(0.01 * percentage * totalBytes / bytesPerRecord, tf.int64)
dataset = tf.data.FixedLengthRecordDataset(filenames, recordBytes).take(numRecordsToTake)

В вашем случае, я бы посоветовал посчитать непосредственно в python количество записей в "primary.csv" и "primary.csv". В качестве альтернативы, я думаю, для вашей цели установка аргумента buffer_size на самом деле не требует подсчета файлов. Согласно принятому ответу о значении buffer_size, число, превышающее количество элементов в наборе данных, обеспечит равномерное перемешивание по всему набору данных. Поэтому достаточно просто ввести действительно большое число (которое, по вашему мнению, превысит размер набора данных).

Начиная с TensorFlow 2, длину набора данных можно легко получить с помощью cardinality() функция.

dataset = tf.data.Dataset.range(42)
#both print 42 
dataset_length_v1 = tf.data.experimental.cardinality(dataset).numpy())
dataset_length_v2 = dataset.cardinality().numpy()
Другие вопросы по тегам