Jax - vmap по пакету классов данных
В JAX я ищу функцию vmap для списка классов данных фиксированной длины, например:
import jax, chex
from flax import struct
@struct.dataclass
class EnvParams:
max_steps: int = 500
random_respawn: bool = False
def foo(params: EnvParams):
...
param_list = jnp.Array([EnvParams(max_steps=500), EnvParams(max_steps=600)])
jax.vmap(foo)(param_list)
Приведенный выше пример терпит неудачу, поскольку невозможно создать jnp.Array пользовательских объектов, а JAX не разрешает vmapping по спискам Python. Единственный оставшийся вариант, который я вижу, - это преобразовать класс данных для представления пакета параметров, например:
@struct.dataclass
class EnvParamBatch:
max_steps: jnp.Array = jnp.array([500, 600])
random_respawn: jnp.Array = jnp.array([False, True])
def bar(params):
...
jax.vmap(bar)(EnvParamBatch())
Было бы предпочтительнее использовать контейнер структур (каждая из которых представляет один набор параметров), поэтому мне интересно, есть ли какие-либо альтернативные подходы к этому?
NB Мне известен этот ответ , однако это не совсем тот же вопрос, и теперь могут быть лучшие решения.
1 ответ
Крис дал правильный ответ для простых кодов, но есть способ сделать это без изменения класса данных. Я столкнулся с той же проблемой, а другая часть моего кода зависит от перегруженных операторов в классе данных, поэтому я не мог легко изменять структуры данных.
Решение использует pytree и Tree_map(). Это структуры данных JAX, состоящие из списка/диктанта отслеживаемых массивов. Во-первых, вам нужно изменить свой класс на pytree . Это потребует очень небольших усилий.
Поскольку списки pytree также являются pytree, jax.tree_util.tree_map будет работать без необходимости изменять ваш data_class.
Вот минимальный рабочий пример:
import jax
from jax import jit, vmap, tree_util
from functools import partial # for JAX jit with static params
class MyContainer:
""" For JAX use """
def _tree_flatten(self):
children = (self.a,) # arrays / dynamic values
aux_data = {'a_stat': self.a_stat} # static values
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, aux_data, children):
return cls(*children, **aux_data)
"""
A container with a traced and a static member.
the * operator is overloaded as demonstration.
"""
def __init__(self, a:int):
self.a = a
self.a_stat = a*100
def __mul__(self, other):
return(MyContainer(self.a*other.a))
# Registering the datatype with JAX
tree_util.register_pytree_node(
MyContainer,
MyContainer._tree_flatten,
MyContainer._tree_unflatten)
X_list = [MyContainer(3),MyContainer(4),MyContainer(5)]
Y_list = [MyContainer(1),MyContainer(10),MyContainer(100)]
# A simple callable adds the traced var a to the static var a_stat
def simple_callable(my_container):
return(MyContainer(my_container.a+my_container.a_stat))
# Note that tree_map will try to traverse into class members as well.
# To stop it from doing that, we add is_leaf to stop it from looking
# deeper when the item is a MyContainer.
test_simple_list = jax.tree_util.tree_map(
simple_callable,
[MyContainer(3),MyContainer(4),MyContainer(5)],
is_leaf=lambda n: isinstance(n, MyContainer)
)
# see if it works
for i in range(len(X_list)):
print('simple_callable', test_simple_list[i].a)
# This also works for callables containing such list of dataclasses
# However, to do indexing, you need a list for the indices.
# this list will be automatically handled as a pytree.
tree_ind = list(range(len(X_list)))
def callcables_containing_dataclass(i):
return(X_list[i]*Y_list[i])
test_callable_list = jax.tree_util.tree_map(callcables_containing_dataclass, tree_ind)
# seeing if it works
for i in range(len(X_list)):
print('callable with dataclass', test_callable_list[i].a)
# jitting works
@jit
def test():
return(
test_simple_list = jax.tree_util.tree_map(
simple_callable,
[MyContainer(3),MyContainer(4),MyContainer(5)],
is_leaf=lambda n: isinstance(n, MyContainer)
),
jax.tree_util.tree_map(callcables_containing_dataclass, tree_ind
)
test()