Не удалось заставить Dataset.filter() работать в файле model/official/resnet/resnet_run_loop.py
В официальной модели повторной сети я хочу отфильтровать набор данных из test.bin по значению "label", если для e val_only установлено значение True. Я попробовал функцию tf.data.Dataset.filter(), чтобы получить только один класс тестовых данных, но это не сработало.
dataset = dataset.filter(lambda inputs, label: tf.equal(label,15))
Я поместил этот код в функцию resnet_run_loop.process_record_dataset, но возникла ошибка
raise ValueError("`predicate` must return a scalar boolean tensor.")
Я обнаружил, что форма тензора 'label' имеет вид (?,):'Tensor("arg1:0", shape=(?,), dtype=int32, device=/device:CPU:0)'
2 ответа
При сравнении двух тензоров он возвращаетbool tensor
так<tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, True])>
что бесполезно, если вы хотите найти ответ на вопрос «Равен ли этот тензор другому тензору?». Добавлениеtf.reduce_all
ему вернет такой тензор<tf.Tensor: shape=(), dtype=bool, numpy=True>
и теперь это должно работать.
dataset = dataset.filter(lambda inputs, label: tf.reduce_all(tf.equal(label,15)))
Я столкнулся с той же проблемой в другой ситуации, и, как было указано в комментариях, оказалось, что проблема была вызвана пакетной обработкой перед фильтрацией.
Вы можете воспроизвести это на следующем примере:
import pprint
import tensorflow as tf
dataset = tf.data.Dataset.zip((
tf.data.Dataset.range(0, 5),
tf.data.Dataset.from_tensor_slices([0, 10, 15, 20, 15])
))
pprint.pprint(list(dataset.as_numpy_iterator()))
# [(0, 0), (1, 10), (2, 15), (3, 20), (4, 15)]
filtered = dataset.filter(lambda x, y: y == 15)
pprint.pprint(list(filtered.as_numpy_iterator()))
# [(2, 15), (4, 15)]
BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = batched.filter(lambda x, y: y == 15)
# ValueError: `predicate` return type must be convertible to a scalar boolean tensor. Was [...]
Одно из простых решений этой проблемы - разблокировать набор данных, затем отфильтровать и, наконец, снова выполнить пакетную обработку:
BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = batched.unbatch().filter(lambda x, y: y == 15).batch(BATCH_SIZE)
pprint.pprint(list(batched_filtered.as_numpy_iterator()))
# [(array([1, 2]), array([15, 15], dtype=int32)),
# (array([4]), array([15], dtype=int32))]
Если вы не знаете или не хотите отслеживать ценность
BATCH_SIZE
, вы можете адаптировать это решение для расчета размера партии по запросу.
В итоге я объединил эти два решения следующим образом:
def calculate_batch_size(dataset):
return next(iter(dataset))[0].shape[0]
def filter_batch(dataset, pred_fn):
batch_size = calculate_batch_size(dataset)
return dataset.unbatch().filter(pred_fn).batch(batch_size)
BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = filter_batch(batched, lambda x, y: y == 15)
pprint.pprint(list(batched_filtered.as_numpy_iterator()))
# [(array([1, 2]), array([15, 15], dtype=int32)),
# (array([4]), array([15], dtype=int32))]