Как сравнить равенство классов данных, содержащих numpy.ndarray (bool(a==b) повышает ValueError)?
Если я создаю класс данных Python, содержащий Numpy ndarray, я больше не могу использовать автоматически сгенерированный __eq__
больше.
import numpy as np
@dataclass
class Instr:
foo: np.ndarray
bar: np.ndarray
arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))
ValueError: Значение истинности массива с более чем одним элементом неоднозначно. Используйте a.any() или a.all()
Это потому что ndarray.__eq__
иногда возвращает ndarray
ценностей правды, сравнивая a[0]
в b[0]
и т. д. и т. д. для более длинного из 2. Это довольно сложно и не интуитивно понятно, и фактически вызывает ошибку только тогда, когда массивы имеют разные формы или имеют разные значения или что-то в этом роде.
Как мне безопасно сравнить @dataclass
держит массивы Numpy?
@dataclass
Реализация __eq__
генерируется с помощью eval()
, Его источник отсутствует в трассировке стека и не может быть просмотрен с помощью inspect
, но на самом деле он использует сравнение кортежей, которое вызывает bool(foo).
import dis
dis.dis(Instr.__eq__)
выдержка:
3 12 LOAD_FAST 0 (self) 14 LOAD_ATTR 1 (foo) 16 LOAD_FAST 0 (self) 18 LOAD_ATTR 2 (bar) 20 BUILD_TUPLE 2 22 LOAD_FAST 1 (other) 24 LOAD_ATTR 1 (foo) 26 LOAD_FAST 1 (other) 28 LOAD_ATTR 2 (bar) 30 BUILD_TUPLE 2 32 COMPARE_OP 2 (==) 34 RETURN_VALUE
2 ответа
Решение заключается в том, чтобы положить в свой __eq__
метод и набор eq=False
таким образом, класс данных не генерирует свой собственный (хотя проверка документов, что последний шаг не является необходимым, но я думаю, что все равно приятно быть явным).
import numpy as np
def array_eq(arr1, arr2):
return (isinstance(arr1, np.ndarray) and
isinstance(arr2, np.ndarray) and
arr1.shape == arr2.shape and
(arr1 == arr2).all())
@dataclass(eq=False)
class Instr:
foo: np.ndarray
bar: np.ndarray
def __eq__(self, other):
if not isinstance(other, Instr):
return NotImplemented
return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)
Это можно настроить, если вы используете attrs вместо классов данных:
from attrs import define, field
import numpy
@define
class C:
an_array = field(eq=attr.cmp_using(eq=numpy.array_equal))