Tensorflow.py Защищенное подразделение
Я пытаюсь реализовать своего рода защищенное разделение, используя Tensorflow.where
но почему-то кажется, что пропускает условие, установленное на where
заявление.
Основная идея при разделении x/y
, если y == 0.
тогда результат деления будет x
вместо метания и ошибки.
Мой код выглядит следующим образом:
def Pdivide(x,y):
result = tf.where(y == 0., x, x/y)
return result
Но почему-то это условие пропускается:
>>> a = tf.Variable([1.7, 0.2, 0., 1.1, 0.9, 0.3, 23., -1.])
>>> b = tf.Variable([0., 0., 0., 1., 1., 0., 1., 1.])
>>>Pdivide(a,b)
>>>(inf, inf, nan, 1.1, 0.9, inf, 23, -1)
Предполагаемый выход:
>>>(1.7, 0.2, 0., 1.1, 0.9, 0.3, 23, -1)
PS: используя eager
выполнение.
1 ответ
Итак, ответ довольно прост, по-видимому.
Почему-то тензорные элементы не могут меня сравнивать с простыми ==
но используя tf.equal(y, 0.)
решает проблему и выдает правильный вывод.