Pickle меняет тип в jax

У меня есть класс данных структуры льна, содержащий массив jax numpy.

Когда я собираю дамп этого объекта и загружаю его снова, массив больше не является массивом jax numpy и преобразуется в массив numpy, вот код для его воспроизведения:

      import flax
import jax.numpy as jnp
import pickle

@flax.struct.dataclass
class A:
    data: jnp.ndarray

a = A(data=jnp.zeros((2,2)))
print(a, type(a.data))



with open('file.pickle', 'wb') as handle:
    pickle.dump(a, handle)
    
with open('file.pickle', 'rb') as handle:
    loaded_a = pickle.load(handle)

print(loaded_a, type(loaded_a.data))

Мне не нужно такое поведение, и я бы хотел, чтобы оно сохраняло свой первоначальный тип, возможно ли это?

0 ответов

Другие вопросы по тегам