Нетранзитивное создание подклассов с помощью numpy и jax
Мой вопрос простой:
>>> isinstance(x, jax.numpy.ndarray)
True
>>> issubclass(jax.numpy.ndarray, numpy.ndarray)
True
>>> isinstance(x, numpy.ndarray)
False
?
А теперь я побеспокоюсь, чтобы SE примет мой разумный вопрос.
1 ответ
Причина этого в том, что
jax.numpy.ndarray
переопределяет проверки экземпляра с помощью метакласса:
class _ArrayMeta(type(np.ndarray)): # type: ignore
"""Metaclass for overriding ndarray isinstance checks."""
def __instancecheck__(self, instance):
try:
return isinstance(instance.aval, _arraylike_types)
except AttributeError:
return isinstance(instance, _arraylike_types)
class ndarray(np.ndarray, metaclass=_ArrayMeta):
dtype: np.dtype
shape: Tuple[int, ...]
size: int
def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")
Причина, по которой ваш код возвращает то, что он делает, заключается в том, что у вас есть
x
значение, которое не является экземпляром
numpy.ndarray
, но для чего это
__instancecheck__
метод возвращает истину.
К чему такие уловки в JAX? Что ж, для целей JIT-компиляции, авто-дифференцирования и других преобразований JAX использует замещающие объекты, называемые трассировщиками, которые должны выглядеть и действовать как массив, хотя на самом деле он не является массивом. Это переопределение проверок экземпляров - один из приемов, которые JAX использует для выполнения такой трассировки.