Skip to content

Commit

Permalink
fix(document): allow eq to include ndarray comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbp authored Mar 17, 2022
1 parent 1e9e945 commit e6a078d
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 2 deletions.
45 changes: 44 additions & 1 deletion docarray/document/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from ..math.ndarray import check_arraylike_equality

if TYPE_CHECKING:
from ..score import NamedScore
from .. import DocumentArray, Document
Expand Down Expand Up @@ -32,7 +34,7 @@
_all_mime_types = set(mimetypes.types_map.values())


@dataclass(unsafe_hash=True)
@dataclass(unsafe_hash=True, eq=False)
class DocumentData:
_reference_doc: 'Document' = field(hash=False, compare=False)
id: str = field(
Expand Down Expand Up @@ -111,3 +113,44 @@ def _set_default_value_if_none(self, key):
setattr(self, key, defaultdict(NamedScore))
else:
setattr(self, key, v() if callable(v) else v)

@staticmethod
def _embedding_eq(array1: 'ArrayType', array2: 'ArrayType'):

if array1 is None and array2 is None:
return True

if type(array1) == type(array2):
return check_arraylike_equality(array1, array2)
else:
return False

@staticmethod
def _tensor_eq(array1: 'ArrayType', array2: 'ArrayType'):
DocumentData._embedding_eq(array1, array2)

def __eq__(self, other):

self_non_empty_fields = self._non_empty_fields
other_non_empty_fields = other._non_empty_fields

if other_non_empty_fields != self_non_empty_fields:
return False

for key in self_non_empty_fields:

if hasattr(self, f'_{key}_eq'):

if hasattr(DocumentData, f'_{key}_eq'):
are_equal = getattr(DocumentData, f'_{key}_eq')(
getattr(self, key), getattr(other, key)
)
print(
f'are_equal( {getattr(self, key)}, { getattr(other, key)}) ---> {are_equal}'
)
if are_equal == False:
return False
else:
if getattr(self, key) != getattr(other, key):
return False
return True
66 changes: 66 additions & 0 deletions docarray/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,69 @@ def get_array_rows(array: 'ArrayType') -> Tuple[int, int]:
raise ValueError

return num_rows, ndim


def check_arraylike_equality(x: 'ArrayType', y: 'ArrayType'):
"""Check if two array type objects are the same with the supported frameworks.
Examples
>>> import numpy as np
x = np.array([[1,2,0,0,3],[1,2,0,0,3]])
check_arraylike_equality(x,x)
True
>>> from scipy import sparse as sp
x = sp.csr_matrix([[1,2,0,0,3],[1,2,0,0,3]])
check_arraylike_equality(x,x)
True
>>> import torch
x = torch.tensor([1,2,3])
check_arraylike_equality(x,x)
True
"""
x_type, x_is_sparse = get_array_type(x)
y_type, y_is_sparse = get_array_type(y)

same_array = False
if x_type == y_type and x_is_sparse == y_is_sparse:

if x_type == 'python':
same_array = x == y

if x_type == 'numpy':
# Numpy does not support sparse tensors
import numpy as np

same_array = np.array_equal(x, y)
elif x_type == 'torch':
import torch

if x_is_sparse:
# torch.equal NotImplementedError for sparse
same_array = all((x - y).coalesce().values() == 0)
else:
same_array = torch.equal(x, y)
elif x_type == 'scipy':
# Not implemented in scipy this should work for all types
# Note: you can't simply look at nonzero values because they can be in
# different positions.
if x.shape != y.shape:
same_array = False
else:
same_array = (x != y).nnz == 0
elif x_type == 'tensorflow':
if x_is_sparse:
same_array = x == y
else:
# Does not have equal implemented, only elementwise, therefore reduce .all is needed
same_array = (x == y).numpy().all()
elif x_type == 'paddle':
# Paddle does not support sparse tensor on 11/8/2021
# https://github.com/PaddlePaddle/Paddle/issues/36697
# Does not have equal implemented, only elementwise, therefore reduce .all is needed
same_array = (x == y).numpy().all()
return same_array
else:
return same_array
20 changes: 20 additions & 0 deletions tests/unit/document/test_docdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,29 @@ def test_doc_hash_complicate_content():
d1 = Document(text='hello', embedding=np.array([1, 2, 3]), id=1)
d2 = Document(text='hello', embedding=np.array([1, 2, 3]), id=1)
assert d1 == d2
assert d2 == d1
assert hash(d1) == hash(d2)


def test_doc_difference_complicate_content():
# Here we ensure != is symmetric therefore we put d1 != d2 and d2 != d1
# The __eq__ at DocumentData level is implemented in docarray/document/data.py
d1 = Document(text='hello', embedding=np.array([1, 2, 3]), id=1)
d2 = Document(text='hello', embedding=np.array([1, 2, 4]), id=1)
assert d1 != d2
assert d2 != d1

d1 = Document(text='hello', embedding=np.array([1, 2, 3, 5]), id=1)
d2 = Document(text='hello', embedding=np.array([1, 2, 4]), id=1)
assert d1 != d2
assert d2 != d1

d1 = Document(text='hello', id=1)
d2 = Document(text='hello', embedding=np.array([1, 2, 4]), id=1)
assert d1 != d2
assert d2 != d1


def test_pop_field():
d1 = Document(text='hello', embedding=np.array([1, 2, 3]), id=1)
assert d1.non_empty_fields == ('id', 'text', 'embedding')
Expand Down
25 changes: 24 additions & 1 deletion tests/unit/math/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from scipy.sparse import csr_matrix, coo_matrix, bsr_matrix, csc_matrix, issparse

from docarray.math.ndarray import get_array_rows
from docarray.math.ndarray import get_array_rows, check_arraylike_equality
from docarray.proto.docarray_pb2 import NdArrayProto
from docarray.proto.io import flush_ndarray, read_ndarray

Expand Down Expand Up @@ -51,3 +51,26 @@ def test_get_array_rows(data, expected_result, arraytype, ndarray_type):
assert isinstance(r_data_array, list)
elif ndarray_type == 'numpy':
assert isinstance(r_data_array, np.ndarray)


def get_ndarrays():
a = np.random.random([10, 3])
a[a > 0.5] = 0
return [
a,
a.tolist(),
torch.tensor(a),
tf.constant(a),
paddle.to_tensor(a),
torch.tensor(a).to_sparse(),
csr_matrix(a),
bsr_matrix(a),
coo_matrix(a),
csc_matrix(a),
]


@pytest.mark.parametrize('ndarray_val', get_ndarrays())
def test_check_arraylike_equality(ndarray_val):
assert check_arraylike_equality(ndarray_val, ndarray_val) == True
assert check_arraylike_equality(ndarray_val, ndarray_val + ndarray_val) == False

0 comments on commit e6a078d

Please sign in to comment.