Я пытаюсь назначить объект JAX Tracer массиву NumPy, для которого требуются конкретные значения - пожалуйста, обойдите это

Я новичок в Джексе.

Я реализую вариационный автоэнкодер (VAE) с использованием Jax и Flax. Во время обучения я сэмплирую скрытый код (из дистрибутива, выведенного энкодером, который я реализую с помощью композиций модулей flax.linen.nn). Важно отметить, что помимо передачи этого кода через декодер (что является стандартом для VAE), я также передаю код внешней функции (физическому движку MuJoCo), которая пытается присвоить его массиву NumPy. Это неудивительно приводит к следующей ошибке:

TracerArrayConversionError: массив методов преобразования numpy.ndarray () был вызван для объекта JAX Tracer...

По сути, мне нужно передать конкретный массив numpy в MuJoCo. Как я могу сделать свою переменную массивом NumPy, который по-прежнему позволит реализовать мою модель вычислительно эффективным способом с использованием абстрактных трассировщиков, где это возможно?

Вот минимальный рабочий пример проблемы, с которой я столкнулся: для запуска необходимо установить тренажерный зал и mujoco (https://mujoco.org/ ):

      import jax
import jax.numpy as np
import numpy as onp
import gym
from jax import jit

# create an instance of an open AI gym environment
env = gym.make('Humanoid-v3')
env.reset()

def this_fails(env, x):
    
    # this gives a TracerArrayConversionError
    env.sim.data.qpos[:] = x

    return env, x

x = np.arange(len(env.sim.data.qpos))
jit_this_fails = jax.jit(this_fails, static_argnums = 0)
env, x = jit_this_fails(env, x)

1 ответ

Изменить: теперь в этой теме есть запись часто задаваемых вопросов JAX: https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-множество


Примечание: это ответ на вопрос ОП, как было написано изначально. Вопрос редактировался несколько раз и больше не задает то, что задавал изначально.

В прошлом такого рода вещи не поддерживались, но вы можете сделать это с помощью новой функции, которая является частью JAX версии 0.3.17, которая еще не выпущена на момент написания этой статьи.

Например, предположим, что вы хотите вызвать функцию на основе numpy из jit-компилируемой функции JAX; мы будем использоватьnp.sinдля простоты. Вы можете сначала попробовать что-то вроде этого:

      import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def this_fails(x):
  # Call a numpy function...
  return np.sin(x)

x = jnp.arange(5.0)
this_fails(x)
      jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function this_fails at tmp.py:7 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

В результатеTracerConversionError, потому что вы пытаетесь передать трассируемое значение JAX в функцию, которая ожидает массив numpy (примечание: см. Как думать в JAX для введения в трассировщики JAX и связанные темы).

В JAX версии 0.3.17 или новее вы можете обойти эту проблему, используяjax.pure_callback:

      @jax.jit
def numpy_callback(x):
  # Need to forward-declare the shape & dtype of the expected output.
  result_shape = jax.core.ShapedArray(x.shape, x.dtype)
  return jax.pure_callback(np.sin, result_shape, x)

x = jnp.arange(5.0)
print(numpy_callback(x))
      [ 0.         0.841471   0.9092974  0.14112   -0.7568025]

Несколько предостережений, о которых следует помнить:

  • результирующее выполнение будет зависеть от обратного вызова к хосту, поэтому оно будет довольно медленным на ускорителях, таких как GPU/TPU, особенно в распределенных/многохостовых настройках. Однако в случае локального выполнения ЦП он избегает буферных копий и может быть достаточно производительным.
  • если выvmapфункции, это приведет кforцикл из нескольких обратных вызовов (можно указатьvectorized=Trueесли функция обратного вызова изначально обрабатывает пакеты).
  • автодифференциальные преобразования, такие какgradиjacobianне будет работать с этой функцией, потому что JAX не может рассуждать о выполняемых вычислениях. Если вы хотите использовать его с преобразованиями autodiff, вы можете определить пользовательские градиенты, как в Custom Derivative Rules, хотя для этого потребуется доступ к функции, которая вычисляет градиент для вашей функции обратного вызова.

Ничто из этого еще не задокументировано на веб-сайте JAX, но мы надеемся написать документацию дляpure_callbackскоро!

Другие вопросы по тегам