Jax - vmap по пакету классов данных

В JAX я ищу функцию vmap для списка классов данных фиксированной длины, например:

      import jax, chex
from flax import struct

@struct.dataclass
class EnvParams:
    max_steps: int = 500
    random_respawn: bool = False

def foo(params: EnvParams):
    ...

param_list = jnp.Array([EnvParams(max_steps=500), EnvParams(max_steps=600)])
jax.vmap(foo)(param_list)

Приведенный выше пример терпит неудачу, поскольку невозможно создать jnp.Array пользовательских объектов, а JAX не разрешает vmapping по спискам Python. Единственный оставшийся вариант, который я вижу, - это преобразовать класс данных для представления пакета параметров, например:

      @struct.dataclass
class EnvParamBatch:
    max_steps: jnp.Array = jnp.array([500, 600])
    random_respawn: jnp.Array = jnp.array([False, True])

def bar(params):
    ...

jax.vmap(bar)(EnvParamBatch())

Было бы предпочтительнее использовать контейнер структур (каждая из которых представляет один набор параметров), поэтому мне интересно, есть ли какие-либо альтернативные подходы к этому?

NB Мне известен этот ответ , однако это не совсем тот же вопрос, и теперь могут быть лучшие решения.

1 ответ

Крис дал правильный ответ для простых кодов, но есть способ сделать это без изменения класса данных. Я столкнулся с той же проблемой, а другая часть моего кода зависит от перегруженных операторов в классе данных, поэтому я не мог легко изменять структуры данных.

Решение использует pytree и Tree_map(). Это структуры данных JAX, состоящие из списка/диктанта отслеживаемых массивов. Во-первых, вам нужно изменить свой класс на pytree . Это потребует очень небольших усилий.

Поскольку списки pytree также являются pytree, jax.tree_util.tree_map будет работать без необходимости изменять ваш data_class.

Вот минимальный рабочий пример:

      import jax
from jax import jit, vmap, tree_util
from functools import partial # for JAX jit with static params

class MyContainer:
    """ For JAX use """
    def _tree_flatten(self):
        children = (self.a,)  # arrays / dynamic values
        aux_data = {'a_stat': self.a_stat}  # static values
        return (children, aux_data)
    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)
    
    """
    A container with a traced and a static member.
    the * operator is overloaded as demonstration.
    """
    def __init__(self, a:int):
        self.a = a
        self.a_stat = a*100
    
    def __mul__(self, other):
        return(MyContainer(self.a*other.a))
    
# Registering the datatype with JAX
tree_util.register_pytree_node(
    MyContainer,
    MyContainer._tree_flatten,
    MyContainer._tree_unflatten)


X_list = [MyContainer(3),MyContainer(4),MyContainer(5)]
Y_list = [MyContainer(1),MyContainer(10),MyContainer(100)]

# A simple callable adds the traced var a to the static var a_stat
def simple_callable(my_container):
    return(MyContainer(my_container.a+my_container.a_stat))

# Note that tree_map will try to traverse into class members as well. 
# To stop it from doing that, we add is_leaf to stop it from looking 
# deeper when the item is a MyContainer. 
test_simple_list = jax.tree_util.tree_map(
    simple_callable, 
    [MyContainer(3),MyContainer(4),MyContainer(5)], 
    is_leaf=lambda n: isinstance(n, MyContainer)
)

# see if it works
for i in range(len(X_list)):
    print('simple_callable', test_simple_list[i].a)
    

# This also works for callables containing such list of dataclasses
# However, to do indexing, you need a list for the indices.
# this list will be automatically handled as a pytree.
tree_ind = list(range(len(X_list))) 
def callcables_containing_dataclass(i):
    return(X_list[i]*Y_list[i])

test_callable_list = jax.tree_util.tree_map(callcables_containing_dataclass, tree_ind)

# seeing if it works
for i in range(len(X_list)):
    print('callable with dataclass', test_callable_list[i].a)

# jitting works
@jit
def test():
    return(
        test_simple_list = jax.tree_util.tree_map(
            simple_callable, 
            [MyContainer(3),MyContainer(4),MyContainer(5)], 
            is_leaf=lambda n: isinstance(n, MyContainer)
        ),
        jax.tree_util.tree_map(callcables_containing_dataclass, tree_ind
    )

test()
Другие вопросы по тегам