Tensorflow.py Защищенное подразделение

Я пытаюсь реализовать своего рода защищенное разделение, используя Tensorflow.where но почему-то кажется, что пропускает условие, установленное на where заявление.

Основная идея при разделении x/y, если y == 0. тогда результат деления будет x вместо метания и ошибки.

Мой код выглядит следующим образом:

def Pdivide(x,y):
    result = tf.where(y == 0., x, x/y) 
    return result

Но почему-то это условие пропускается:

>>> a = tf.Variable([1.7, 0.2, 0., 1.1, 0.9, 0.3, 23., -1.])
>>> b = tf.Variable([0., 0., 0., 1., 1., 0., 1., 1.])

>>>Pdivide(a,b)

>>>(inf, inf, nan, 1.1, 0.9, inf, 23, -1)

Предполагаемый выход:

>>>(1.7, 0.2, 0., 1.1, 0.9, 0.3, 23, -1)

PS: используя eager выполнение.

1 ответ

Итак, ответ довольно прост, по-видимому.

Почему-то тензорные элементы не могут меня сравнивать с простыми == но используя tf.equal(y, 0.) решает проблему и выдает правильный вывод.

Другие вопросы по тегам