Как включить нетерпеливый режим в tf.data

Для проекта я использую tf.data.Dataset для написания входного конвейера. Входными данными является изображение RGB. Метка представляет собой список 2D-координат объектов на изображении, которые использовались для создания тепловой карты.

Вот MWE для воспроизведения проблемы.

 def encode_images(image, label):
        """

        Parameters
        ----------
        image
        label

        Returns
        -------

        """
        # load image
        # here the normal code
        # img_contents = tf.io.read_file(image)
        # # decode the image
        # img = tf.image.decode_jpeg(img_contents, channels=3)
        # img = tf.image.resize(img, (256, 256))
        # img = tf.cast(img, tf.float32)

        # this is just for testing
        image = tf.random.uniform(
            (256, 256, 3), minval=0, maxval=255, dtype=tf.dtypes.float32, seed=None, name=None
        )
        return image, label

    def generate_heatmap(image, label):
        """

        Parameters
        ----------
        image
        label

        Returns
        -------

        """

        start = 0.5
        sigma=3
        img_shape = (image.shape[0] , image.shape[1] )
        density_map = np.zeros(img_shape, dtype=np.float32)
        for center_x, center_y in label[0]:
            for v_y in range(img_shape[0]):
                for v_x in range(img_shape[1]):
                    x = start + v_x
                    y = start + v_y
                    d2 = (x - center_x) * (x - center_x) + (y - center_y) * (y - center_y)
                    exp = d2 / (2.0 * sigma**2)
                    if exp > math.log(100):
                        continue
                    density_map[v_y, v_x] = math.exp(-exp)
        return density_map


    X = ["img1.png", "img2.png", "img3.png", "img4.png", "img5.png"]
    y = [[[2, 3], [100, 120], [100, 120]],
         [[2, 3], [100, 120], [100, 120], [2, 1]],
         [[2, 3], [100, 120], [100, 120], [10, 10], [11, 12]],
         [[2, 3], [100, 120], [100, 120], [10, 10], [11, 12], [10, 2]],
         [[2, 3], [100, 120], [100, 120]]
         ]
    dataset = tf.data.Dataset.from_tensor_slices((X, tf.ragged.constant(y)))
    dataset = dataset.map(encode_images, num_parallel_calls=8)
    dataset = dataset.map(generate_heatmap, num_parallel_calls=8)
    dataset = dataset.batch(1, drop_remainder=False)

Проблема в том, что в generate_heatmap()функции, я использовал массив numpy для изменения элементов по индексам, что невозможно в тензорном потоке. Я пытаюсь перебрать тензор меток, что до сих пор невозможно в тензорном потоке. Другое дело, что режим ожидания не включен в tf.data.Dataset!! Любое предложение, как с этим справиться! Думаю, в пыторче такой код можно сделать быстро, не мучаясь:)!

0 ответов

Другие вопросы по тегам