Определение пользовательского градиента как метода класса в Tensorflow
Мне нужно определить метод как пользовательский градиент следующим образом:
class CustGradClass:
def __init__(self):
pass
@tf.custom_gradient
def f(self,x):
fx = x
def grad(dy):
return dy * 1
return fx, grad
Я получаю следующую ошибку:
ValueError: Попытка преобразовать значение (< main.CustGradClass object at 0x12ed91710>) с неподдерживаемым типом () в Tensor.
Причина в том, что пользовательский градиент принимает функцию f(*x), где x - это последовательность тензоров. И первым передаваемым аргументом является сам объект, то есть я.
Из документации:
f: функция f(*x), которая возвращает кортеж (y, grad_fn), где:
x - это последовательность входов Tensor в функцию. y является Tensor или последовательностью выходных данных Tensor применения операций TensorFlow в f к x. grad_fn - это функция с сигнатурой g(*grad_ys)
Как мне заставить это работать? Нужно ли мне наследовать некоторый класс тензорного потока Python?
Я использую TF версии 1.12.0 и нетерпеливый режим.
2 ответа
Это один из возможных простых обходных путей:
import tensorflow as tf
class CustGradClass:
def __init__(self):
self.f = tf.custom_gradient(lambda x: CustGradClass._f(self, x))
@staticmethod
def _f(self, x):
fx = x * 1
def grad(dy):
return dy * 1
return fx, grad
with tf.Graph().as_default(), tf.Session() as sess:
x = tf.constant(1.0)
c = CustGradClass()
y = c.f(x)
print(tf.gradients(y, x))
# [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]
РЕДАКТИРОВАТЬ:
Если вы хотите делать это много раз на разных классах или просто хотите использовать решение более многократного использования, вы можете использовать такой декоратор, например, такой:
import functools
import tensorflow as tf
def tf_custom_gradient_method(f):
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if not hasattr(self, '_tf_custom_gradient_wrappers'):
self._tf_custom_gradient_wrappers = {}
if f not in self._tf_custom_gradient_wrappers:
self._tf_custom_gradient_wrappers[f] = tf.custom_gradient(lambda *a, **kw: f(self, *a, **kw))
return self._tf_custom_gradient_wrappers[f](*args, **kwargs)
return wrapped
Тогда вы можете просто сделать:
class CustGradClass:
def __init__(self):
pass
@tf_custom_gradient_method
def f(self, x):
fx = x * 1
def grad(dy):
return dy * 1
return fx, grad
@tf_custom_gradient_method
def f2(self, x):
fx = x * 2
def grad(dy):
return dy * 2
return fx, grad
В вашем примере вы не используете никаких переменных-членов, поэтому вы можете просто сделать метод статическим методом. Если вы используете переменные-члены, тогда вызовите статический метод из функции-члена и передайте переменные-члены в качестве параметров.
class CustGradClass:
def __init__(self):
self.some_var = ...
@staticmethod
@tf.custom_gradient
def _f(x):
fx = x
def grad(dy):
return dy * 1
return fx, grad
def f(self):
return CustGradClass._f(self.some_var)