Вычисление векторизованного гессиана в Tensorflow
проблема
Рассмотрим проблему настройки
x : A (N, K) tensor that we want to differentiate with respect to.
f(x): A function sending x to a scalar i.e.
f: R^(N x K) -> R
То, что я хочу найти, это при каждом наблюдении x[i,:]
(размер N
ось), градиент (N x K
) и гессиан (N x K x K
).
Градиенты
Теперь градиенты в каждом наблюдении легко найти, так как вам просто нужно найти градиент f
по отношению ко всем x
значения т.е.
df/dx[0,0] ... df/dx[0,K]
. .
. .
. .
df/dx[N,0] ... df/dx[N,K]
что можно сделать просто с
tf.gradients(f(x), x)
гессенцы
Теперь у меня проблема с поиском размера (N, K, K)
Гессенский тензор. Если я использую tf.hessians
функционировать наивно то есть
tf.hessians(f(x), x)
это находит (правильно) (N, K, N, K)
тензор частных вторых производных, даже для x
значения между наблюдениями. Это всегда 0 (в моем случае), поэтому для больших значений N
это может быть очень неэффективно.
Как я могу заставить Tensorflow только найти N
(K x K)
Гессианские матрицы с j, k
записи df/(dx[i,j]dx[i,k])
для наблюдения i
?
Я думаю, что может быть решение зацикливания над значениями 0, ..., N-1
, тем не мение N
статическому графу также неизвестен, он только динамически определяется для новых входных данных.
Фиктивный код
Ниже приведен минимальный рабочий пример, иллюстрирующий проблему.
import tensorflow as tf
import numpy as np
N = 2
K = 3
# Create dummy data.
x_np = np.random.rand(N, K).astype(np.float32)
# Define Tensorflow graph.
x = tf.placeholder(tf.float32, shape=(None, K), name='x')
f = tf.reduce_sum(tf.multiply(x, x), name='f')
grad = tf.gradients(f, x, name='grad')
hess = tf.hessians(f, x, name='hess')
# Run the Tensorflow graph.
sess = tf.Session()
print("\nTensorflow gradient:")
print(sess.run(grad, feed_dict={'x:0': x_np})[0])
print("\nTensorflow Hessian:")
hess_tf = sess.run(hess, feed_dict={'x:0': x_np})[0]
print(hess_tf)
# Show how we can get the Hessian we want from `hess_tf`.
hess_np = np.empty([N, K, K])
for i in range(N):
hess_np[i, :, :] = hess_tf[i, :, i, :]
print("\nWanted Hessian:")
print(hess_np)