Не удалось заставить Dataset.filter() работать в файле model/official/resnet/resnet_run_loop.py

В официальной модели повторной сети я хочу отфильтровать набор данных из test.bin по значению "label", если для e val_only установлено значение True. Я попробовал функцию tf.data.Dataset.filter(), чтобы получить только один класс тестовых данных, но это не сработало.

dataset = dataset.filter(lambda inputs, label: tf.equal(label,15))

Я поместил этот код в функцию resnet_run_loop.process_record_dataset, но возникла ошибка

 raise ValueError("`predicate` must return a scalar boolean tensor.")

Я обнаружил, что форма тензора 'label' имеет вид (?,):'Tensor("arg1:0", shape=(?,), dtype=int32, device=/device:CPU:0)'

2 ответа

При сравнении двух тензоров он возвращаетbool tensorтак<tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, True])>что бесполезно, если вы хотите найти ответ на вопрос «Равен ли этот тензор другому тензору?». Добавлениеtf.reduce_allему вернет такой тензор<tf.Tensor: shape=(), dtype=bool, numpy=True>и теперь это должно работать.

dataset = dataset.filter(lambda inputs, label: tf.reduce_all(tf.equal(label,15)))

Я столкнулся с той же проблемой в другой ситуации, и, как было указано в комментариях, оказалось, что проблема была вызвана пакетной обработкой перед фильтрацией.

Вы можете воспроизвести это на следующем примере:

       import pprint
import tensorflow as tf

dataset = tf.data.Dataset.zip((
    tf.data.Dataset.range(0, 5),
    tf.data.Dataset.from_tensor_slices([0, 10, 15, 20, 15])
))
pprint.pprint(list(dataset.as_numpy_iterator()))
# [(0, 0), (1, 10), (2, 15), (3, 20), (4, 15)]

filtered = dataset.filter(lambda x, y: y == 15)
pprint.pprint(list(filtered.as_numpy_iterator()))
# [(2, 15), (4, 15)]

BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = batched.filter(lambda x, y: y == 15)
# ValueError: `predicate` return type must be convertible to a scalar boolean tensor. Was [...]

Одно из простых решений этой проблемы - разблокировать набор данных, затем отфильтровать и, наконец, снова выполнить пакетную обработку:

       BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = batched.unbatch().filter(lambda x, y: y == 15).batch(BATCH_SIZE)
pprint.pprint(list(batched_filtered.as_numpy_iterator()))
# [(array([1, 2]), array([15, 15], dtype=int32)),
#  (array([4]), array([15], dtype=int32))]

Если вы не знаете или не хотите отслеживать ценность BATCH_SIZE, вы можете адаптировать это решение для расчета размера партии по запросу.

В итоге я объединил эти два решения следующим образом:

       def calculate_batch_size(dataset):
    return next(iter(dataset))[0].shape[0]

def filter_batch(dataset, pred_fn):
    batch_size = calculate_batch_size(dataset)
    return dataset.unbatch().filter(pred_fn).batch(batch_size)

BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = filter_batch(batched, lambda x, y: y == 15)
pprint.pprint(list(batched_filtered.as_numpy_iterator()))
# [(array([1, 2]), array([15, 15], dtype=int32)),
#  (array([4]), array([15], dtype=int32))]
Другие вопросы по тегам