Предотвращение или синхронизация перетасовки в make_tf_dataset

У меня есть модель тензорного потока, которая обучена с помощью набора рейтинговых данных, например:

Я хочу подчеркнуть этим примером, что каждая партия может иметь разную длину элементов, поэтому я попробовал следующее.

Я создаю Spark DataFrame с набором данных рейтингов.

      df = spark.read.parquet('s3a://path/to/dataset')

затем создайте SparkDatasetConverter с фреймом данных:

      conv_train = make_spark_converter(df)

теперь с идеей получить 3 равных партии из одного и того же конвертера и применить к каждой партии разные преобразования, я использовал этот код

      with conv_train.make_tf_dataset(transform_spec=transformation0, shuffling_queue_capacity=None, batch_size=2, num_epochs=1, seed=1) as rating,  \
     conv_train.make_tf_dataset(transform_spec=transformation1, shuffling_queue_capacity=None, batch_size=2, num_epochs=1, seed=1) as unique_user, \
     conv_train.make_tf_dataset(transform_spec=transformation2, shuffling_queue_capacity=None, batch_size=2, num_epochs=1, seed=1) as unique_movie:
    hist = model.fit(rating, unique_user, uynique_movie)

Проблема в том, что этот код перемешивает каждую партию независимо, даже используя одно и то же начальное число. Когда я запускаю следующий код:

      with conv_train.make_tf_dataset(shuffling_queue_capacity=None, batch_size=2, num_epochs=1, seed=1) as train0,  \
     conv_train.make_tf_dataset(shuffling_queue_capacity=None, batch_size=2, num_epochs=1, seed=1) as train1, \
     conv_train.make_tf_dataset(shuffling_queue_capacity=None, batch_size=2, num_epochs=1, seed=1) as train2:
   #Print train0
   #Print train1
   #Print train2

элементы train0, train1 и train2 не совпадают, чего не должно происходить из-за того, что я использую одно и то же семя.

Как я могу добиться, чтобы поезд0, поезд1 и поезд2 были равны?

0 ответов

Другие вопросы по тегам