Skip to content

Commit

Permalink
fix(match): exclude_self (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
numb3r3 authored Feb 25, 2022
1 parent d502925 commit 01b3976
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
5 changes: 4 additions & 1 deletion docarray/array/mixins/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ def find(
if exclude_self and isinstance(query, DocumentArray):
for i, q in enumerate(query):
matches = result[i].traverse_flat('r', filter_fn=lambda d: d.id != q.id)
result[i] = matches[:_limit]
if limit and len(matches) > limit:
result[i] = matches[:limit]
else:
result[i] = matches

if len(result) == 1:
return result[0]
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/array/mixins/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,16 @@ def doc_lists_to_doc_arrays(doc_lists, *args, **kwargs):
[('weaviate', WeaviateConfig(3)), ('pqlite', {'n_dim': 3})],
)
@pytest.mark.parametrize('limit', [1, 2, 3])
def test_match(storage, config, doc_lists, limit, start_weaviate):
@pytest.mark.parametrize('exclude_self', [True, False])
def test_match(storage, config, doc_lists, limit, exclude_self, start_weaviate):
D1, D2 = doc_lists_to_doc_arrays(doc_lists)

if config:
da = DocumentArray(D2, storage=storage, config=config)
else:
da = DocumentArray(D2, storage=storage)

D1.match(da, limit=limit)
D1.match(da, limit=limit, exclude_self=exclude_self)
for m in D1[:, 'matches']:
assert len(m) == limit

Expand Down

0 comments on commit 01b3976

Please sign in to comment.