Как сравнить равенство классов данных, содержащих numpy.ndarray (bool(a==b) повышает ValueError)?

Если я создаю класс данных Python, содержащий Numpy ndarray, я больше не могу использовать автоматически сгенерированный __eq__ больше.

import numpy as np

@dataclass
class Instr:
    foo: np.ndarray
    bar: np.ndarray

arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))

ValueError: Значение истинности массива с более чем одним элементом неоднозначно. Используйте a.any() или a.all()

Это потому что ndarray.__eq__ иногда возвращает ndarray ценностей правды, сравнивая a[0] в b[0]и т. д. и т. д. для более длинного из 2. Это довольно сложно и не интуитивно понятно, и фактически вызывает ошибку только тогда, когда массивы имеют разные формы или имеют разные значения или что-то в этом роде.

Как мне безопасно сравнить @dataclassдержит массивы Numpy?


@dataclassРеализация __eq__ генерируется с помощью eval(), Его источник отсутствует в трассировке стека и не может быть просмотрен с помощью inspect, но на самом деле он использует сравнение кортежей, которое вызывает bool(foo).

import dis
dis.dis(Instr.__eq__)

выдержка:

  3          12 LOAD_FAST                0 (self)
             14 LOAD_ATTR                1 (foo)
             16 LOAD_FAST                0 (self)
             18 LOAD_ATTR                2 (bar)
             20 BUILD_TUPLE              2
             22 LOAD_FAST                1 (other)
             24 LOAD_ATTR                1 (foo)
             26 LOAD_FAST                1 (other)
             28 LOAD_ATTR                2 (bar)
             30 BUILD_TUPLE              2
             32 COMPARE_OP               2 (==)
             34 RETURN_VALUE

2 ответа

Решение

Решение заключается в том, чтобы положить в свой __eq__ метод и набор eq=False таким образом, класс данных не генерирует свой собственный (хотя проверка документов, что последний шаг не является необходимым, но я думаю, что все равно приятно быть явным).

import numpy as np

def array_eq(arr1, arr2):
    return (isinstance(arr1, np.ndarray) and
            isinstance(arr2, np.ndarray) and
            arr1.shape == arr2.shape and
            (arr1 == arr2).all())

@dataclass(eq=False)
class Instr:

    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)

Это можно настроить, если вы используете attrs вместо классов данных:

      from attrs import define, field
import numpy

@define
class C:
   an_array = field(eq=attr.cmp_using(eq=numpy.array_equal))
Другие вопросы по тегам