Нетранзитивное создание подклассов с помощью numpy и jax

Мой вопрос простой:

>>> isinstance(x, jax.numpy.ndarray)
True
>>> issubclass(jax.numpy.ndarray, numpy.ndarray)
True
>>> isinstance(x, numpy.ndarray)
False

?

А теперь я побеспокоюсь, чтобы SE примет мой разумный вопрос.

1 ответ

Решение

Причина этого в том, что jax.numpy.ndarray переопределяет проверки экземпляра с помощью метакласса:

class _ArrayMeta(type(np.ndarray)):  # type: ignore
  """Metaclass for overriding ndarray isinstance checks."""

  def __instancecheck__(self, instance):
    try:
      return isinstance(instance.aval, _arraylike_types)
    except AttributeError:
      return isinstance(instance, _arraylike_types)

class ndarray(np.ndarray, metaclass=_ArrayMeta):
  dtype: np.dtype
  shape: Tuple[int, ...]
  size: int

  def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")

(просмотреть источник)

Причина, по которой ваш код возвращает то, что он делает, заключается в том, что у вас есть x значение, которое не является экземпляром numpy.ndarray, но для чего это __instancecheck__ метод возвращает истину.

К чему такие уловки в JAX? Что ж, для целей JIT-компиляции, авто-дифференцирования и других преобразований JAX использует замещающие объекты, называемые трассировщиками, которые должны выглядеть и действовать как массив, хотя на самом деле он не является массивом. Это переопределение проверок экземпляров - один из приемов, которые JAX использует для выполнения такой трассировки.

Другие вопросы по тегам