Евклидово дистанционное преобразование в тензорном потоке

Я хотел бы создать функцию тензорного потока, которая копирует евклидово преобразование расстояния scipy для каждой 2-мерной матрицы в моем 3-мерном тензоре.

У меня есть трехмерный тензор, где третья ось представляет горячо закодированный объект. Я хотел бы создать для каждого измерения объекта матрицу, где значения в каждой ячейке равны расстоянию до ближайшего объекта.

Пример:

input = [[1 0 0]
         [0 1 0]
         [0 0 1],

         [0 1 0]
         [0 0 0]
         [1 0 0]]

output = [[0    1   1.41]
          [1    0   1   ]
          [1.41 1   0   ],

          [1    0   1   ]
          [1    1   1.41]
          [0    1   2   ]]              

Мое текущее решение реализовано в Python. Метод выполняет итерацию по каждой ячейке измерения элемента, создает кольцо вокруг ячейки и ищет, содержит ли кольцо элемент. Затем он вычисляет расстояние для ячейки до каждой записи объекта и принимает минимум. Если кольцо не содержит ячейку с функцией, кольцо поиска становится шире.

Код:

import numpy as np
import math

def distance_matrix():
    feature_1 = np.eye(5)
    feature_2 = np.array([[0, 1, 0, 0, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0],
                  [1, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0],])
    ground_truth = np.stack((feature_1,feature_2), axis=2)
    x = np.zeros(ground_truth.shape)

    for feature_index in range(ground_truth.shape[2]):
        for i in range(ground_truth.shape[0]):
            for j in range(ground_truth.shape[1]):
                x[i,j,feature_index] = search_ring(i,j, feature_index,0,ground_truth)
    print(x[:,:,0])

def search_ring(i, j,feature_index, ring_size, truth):
    if ring_size == 0 and truth[i,j,feature_index] == 1.:
                    return 0
    else:
        distance = truth.shape[0]
        y_min = max(i - ring_size, 0)
        y_max = min(i + ring_size, truth.shape[0] - 1)
        x_min = max(j - ring_size, 0)
        x_max = min(j + ring_size, truth.shape[1] - 1)

        if truth[y_min:y_max+1, x_min:x_max+1, feature_index].sum() > 0:
            for y in range(y_min, y_max + 1):
                for x in range(x_min, x_max + 1):
                    if y == y_min or y == y_max or x == x_min or x == x_max:
                        if truth[y,x,feature_index] == 1.:
                            dist = norm(i,j,y,x,type='euclidean')
                            distance = min(distance, dist)
            return distance
        else:
            return search_ring(i, j,feature_index, ring_size + 1, truth)

def norm(index_y_a, index_x_a, index_y_b, index_x_b, type='euclidean'):
    if type == 'euclidean':
        return math.sqrt(abs(index_y_a - index_y_b)**2 + abs(index_x_a - index_x_b)**2)
    elif type == 'manhattan':
        return abs(index_y_a - index_y_b) + abs(index_x_a - index_x_b)


def main():
    distance_matrix()
if __name__ == '__main__':
    main()

Моя проблема заключается в репликации этого в Tensorflow, так как он мне нужен для пользовательской функции потерь в Keras. Как я могу получить доступ к индексам предметов, которые я перебираю?

2 ответа

Я сделал что-то подобное с py_func создать преобразование расстояния со знаком, используя scipy, Вот как это может выглядеть в вашем случае:

import scipy.ndimage.morphology as morph
arrs = []
for channel_index in range(C):
    arrs.append(tf.py_func(morph.distance_transform_edt, [tensor[..., channel_index]], tf.float32))
edt_tf = tf.stack(arrs, axis=-1)

Обратите внимание на ограничения py_func: они не будут сериализованы в GraphDefs поэтому он не будет сериализовать тело функции в моделях, которые вы сохраняете. Смотрите документацию по tf.py_func.

Я не вижу никаких проблем для вас, чтобы использовать преобразование расстояния в kerasв общем, все что вам нужно это tf.py_func, который оборачивает существующую функцию python в tensorflow оператор.

Тем не менее, я думаю, что основная проблема здесь заключается в обратном распространении. У вашей модели будут проблемы при прямом проходе, но какой градиент вы планируете распространять? Или вам просто наплевать на его градиент вообще.

Другие вопросы по тегам