Якобианский определитель векторной функции с Python JAX/Autograd
У меня есть функция, которая отображает векторы на векторы
и я хочу вычислить его определитель Якоби
,
где якобиан определяется как
.
Поскольку я могу использовать numpy.linalg.det
, чтобы вычислить определитель, мне просто нужна матрица Якоби. Я знаю о numdifftools.Jacobian
, но здесь используется числовое дифференцирование, и я ищу автоматическое дифференцирование. Войти Autograd
/JAX
(Я буду придерживаться Autograd
на данный момент в нем есть autograd.jacobian()
метод, но я счастлив использовать JAX
пока я получаю то, что хочу). Как мне это использоватьautograd.jacobian()
-функция правильно с векторной функцией?
В качестве простого примера рассмотрим функцию
! [f (x) = (x_0 ^ 2, x_1 ^ 2)] ( https://chart.googleapis.com/chart?cht=tx&chl=f(x%29%20%3D%20(x_0%5E2%2C%20x_1%5E2%29)
который имеет якобиан
![J_f = diag(2 x_0, 2 x_1)](https://chart.googleapis.com/chart?cht=tx&chl=J_f%20%3D%20%5Coperatorname%7Bdiag%7D(2x_0%2C%202x_1%29)
что приводит к определителю Якоби
>>> import autograd.numpy as np
>>> import autograd as ag
>>> x = np.array([[3],[11]])
>>> result = 4*x[0]*x[1]
array([132])
>>> jac = ag.jacobian(f)(x)
array([[[[ 6],
[ 0]]],
[[[ 0],
[22]]]])
>>> jac.shape
(2, 1, 2, 1)
>>> np.linalg.det(jac)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/lib/python3.8/site-packages/autograd/tracer.py", line 48, in f_wrapped
return f_raw(*args, **kwargs)
File "<__array_function__ internals>", line 5, in det
File "/usr/lib/python3.8/site-packages/numpy/linalg/linalg.py", line 2113, in det
_assert_stacked_square(a)
File "/usr/lib/python3.8/site-packages/numpy/linalg/linalg.py", line 213, in _assert_stacked_square
raise LinAlgError('Last 2 dimensions of the array must be square')
numpy.linalg.LinAlgError: Last 2 dimensions of the array must be square
Первый подход дает мне правильные значения, но неправильную форму. Почему.jacobian()
вернуть такой вложенный массив? Если я его правильно переделываю, я получаю правильный результат:
>>> jac = ag.jacobian(f)(x).reshape(-1,2,2)
array([[[ 6, 0],
[ 0, 22]]])
>>> np.linalg.det(jac)
array([132.])
Но теперь давайте посмотрим, как это работает с широковещательной рассылкой массивов, когда я пытаюсь вычислить определитель Якоби для нескольких значений x
>>> x = np.array([[3,5,7],[11,13,17]])
array([[ 3, 5, 7],
[11, 13, 17]])
>>> result = 4*x[0]*x[1]
array([132, 260, 476])
>>> jac = ag.jacobian(f)(x)
array([[[[ 6, 0, 0],
[ 0, 0, 0]],
[[ 0, 10, 0],
[ 0, 0, 0]],
[[ 0, 0, 14],
[ 0, 0, 0]]],
[[[ 0, 0, 0],
[22, 0, 0]],
[[ 0, 0, 0],
[ 0, 26, 0]],
[[ 0, 0, 0],
[ 0, 0, 34]]]])
>>> jac = ag.jacobian(f)(x).reshape(-1,2,2)
>>> jac
array([[[ 6, 0],
[ 0, 0]],
[[ 0, 0],
[ 0, 10]],
[[ 0, 0],
[ 0, 0]],
[[ 0, 0],
[14, 0]],
[[ 0, 0],
[ 0, 0]],
[[ 0, 22],
[ 0, 0]],
[[ 0, 0],
[ 0, 0]],
[[26, 0],
[ 0, 0]],
[[ 0, 0],
[ 0, 34]]])
>>> jac.shape
(9,2,2)
Здесь очевидно, что обе формы неправильные, правильные (как в матрице Якоби, которую я хочу), будет
[[[ 6, 0],
[ 0, 22]],
[[10, 0],
[ 0, 26]],
[[14, 0],
[ 0, 34]]]
с shape=(6,2,2)
Как мне использовать autograd.jacobian
(или jax.jacfwd
/jax.jacrev
), чтобы он правильно обрабатывал несколько векторных входов?
Примечание. Используя явный цикл и обрабатывая каждую точку вручную, я получаю правильный результат. Но есть ли способ сделать это на месте?
>>> dets = []
>>> for v in zip(*x):
>>> v = np.array(v)
>>> jac = ag.jacobian(f)(v)
>>> print(jac, jac.shape, '\n')
>>> det = np.linalg.det(jac)
>>> dets.append(det)
[[ 6. 0.]
[ 0. 22.]] (2, 2)
[[10. 0.]
[ 0. 26.]] (2, 2)
[[14. 0.]
[ 0. 34.]] (2, 2)
>>> dets
[131.99999999999997, 260.00000000000017, 475.9999999999998]
1 ответ
"Как мне правильно использовать эту функцию autograd.jacobian()- с векторной функцией?"
Ты написал
x = np.array([[3],[11]])
Здесь есть две проблемы. Во-первых, это вектор векторов, а autograd предназначен для векторных функций. Во-вторых, autograd ожидает числа с плавающей запятой, а не целые числа. Если вы попытаетесь различать целые числа, вы получите ошибку. Вы не видите ошибки с вектором векторов, потому что autograd автоматически преобразует ваши списки целых чисел в списки чисел с плавающей запятой.
TypeError: Can't differentiate w.r.t. type <class 'int'>
Следующий код должен дать вам определитель.
import autograd.numpy as np
import autograd as ag
def f(x):
return np.array([x[0]**2,x[1]**2])
x = np.array([3.,11.])
jac = ag.jacobian(f)(x)
result = np.linalg.det(jac)
print(result)
"Как мне использовать autograd.jacobian (или jax.jacfwd/jax.jacrev), чтобы он правильно обрабатывал несколько векторных входов?"
Есть способ сделать это на месте, он называется jax.vmap. См. Документы JAX. ( https://jax.readthedocs.io/en/latest/jax.html)
В этом случае я мог бы вычислить вектор определителей Якоби с помощью следующего кода. Обратите внимание, что я могу определить функцию f точно так же, как и раньше, vmap делает всю работу за нас за кулисами.
import jax.numpy as np
import jax
def f(x):
return np.array([x[0]**2,x[1]**2])
x = np.array([[3.,11.],[5.,13.],[7.,17.]])
jac = jax.jacobian(f)
vmap_jac = jax.vmap(jac)
result = np.linalg.det(vmap_jac(x))
print(result)