Я пытаюсь назначить объект JAX Tracer массиву NumPy, для которого требуются конкретные значения - пожалуйста, обойдите это
Я новичок в Джексе.
Я реализую вариационный автоэнкодер (VAE) с использованием Jax и Flax. Во время обучения я сэмплирую скрытый код (из дистрибутива, выведенного энкодером, который я реализую с помощью композиций модулей flax.linen.nn). Важно отметить, что помимо передачи этого кода через декодер (что является стандартом для VAE), я также передаю код внешней функции (физическому движку MuJoCo), которая пытается присвоить его массиву NumPy. Это неудивительно приводит к следующей ошибке:
TracerArrayConversionError: массив методов преобразования numpy.ndarray () был вызван для объекта JAX Tracer...
По сути, мне нужно передать конкретный массив numpy в MuJoCo. Как я могу сделать свою переменную массивом NumPy, который по-прежнему позволит реализовать мою модель вычислительно эффективным способом с использованием абстрактных трассировщиков, где это возможно?
Вот минимальный рабочий пример проблемы, с которой я столкнулся: для запуска необходимо установить тренажерный зал и mujoco (https://mujoco.org/ ):
import jax
import jax.numpy as np
import numpy as onp
import gym
from jax import jit
# create an instance of an open AI gym environment
env = gym.make('Humanoid-v3')
env.reset()
def this_fails(env, x):
# this gives a TracerArrayConversionError
env.sim.data.qpos[:] = x
return env, x
x = np.arange(len(env.sim.data.qpos))
jit_this_fails = jax.jit(this_fails, static_argnums = 0)
env, x = jit_this_fails(env, x)
1 ответ
Изменить: теперь в этой теме есть запись часто задаваемых вопросов JAX: https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-множество
Примечание: это ответ на вопрос ОП, как было написано изначально. Вопрос редактировался несколько раз и больше не задает то, что задавал изначально.
В прошлом такого рода вещи не поддерживались, но вы можете сделать это с помощью новой функции, которая является частью JAX версии 0.3.17, которая еще не выпущена на момент написания этой статьи.
Например, предположим, что вы хотите вызвать функцию на основе numpy из jit-компилируемой функции JAX; мы будем использоватьnp.sin
для простоты. Вы можете сначала попробовать что-то вроде этого:
import jax
import jax.numpy as jnp
import numpy as np
@jax.jit
def this_fails(x):
# Call a numpy function...
return np.sin(x)
x = jnp.arange(5.0)
this_fails(x)
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function this_fails at tmp.py:7 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
В результатеTracerConversionError
, потому что вы пытаетесь передать трассируемое значение JAX в функцию, которая ожидает массив numpy (примечание: см. Как думать в JAX для введения в трассировщики JAX и связанные темы).
В JAX версии 0.3.17 или новее вы можете обойти эту проблему, используяjax.pure_callback
:
@jax.jit
def numpy_callback(x):
# Need to forward-declare the shape & dtype of the expected output.
result_shape = jax.core.ShapedArray(x.shape, x.dtype)
return jax.pure_callback(np.sin, result_shape, x)
x = jnp.arange(5.0)
print(numpy_callback(x))
[ 0. 0.841471 0.9092974 0.14112 -0.7568025]
Несколько предостережений, о которых следует помнить:
- результирующее выполнение будет зависеть от обратного вызова к хосту, поэтому оно будет довольно медленным на ускорителях, таких как GPU/TPU, особенно в распределенных/многохостовых настройках. Однако в случае локального выполнения ЦП он избегает буферных копий и может быть достаточно производительным.
- если вы
vmap
функции, это приведет кfor
цикл из нескольких обратных вызовов (можно указатьvectorized=True
если функция обратного вызова изначально обрабатывает пакеты). - автодифференциальные преобразования, такие как
grad
иjacobian
не будет работать с этой функцией, потому что JAX не может рассуждать о выполняемых вычислениях. Если вы хотите использовать его с преобразованиями autodiff, вы можете определить пользовательские градиенты, как в Custom Derivative Rules, хотя для этого потребуется доступ к функции, которая вычисляет градиент для вашей функции обратного вызова.
Ничто из этого еще не задокументировано на веб-сайте JAX, но мы надеемся написать документацию дляpure_callback
скоро!