Skip to content

Commit

Permalink
fix: customize metric fn expect no metric_name (#405)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Jun 20, 2022
1 parent 5a3ba78 commit 18a1c46
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 31 deletions.
2 changes: 1 addition & 1 deletion docarray/array/storage/memory/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _find(
batch_size = int(batch_size)

if callable(metric):
cdist = metric
cdist = lambda *x: metric(*x[:2])
elif isinstance(metric, str):
if use_scipy:
from scipy.spatial.distance import cdist as cdist
Expand Down
2 changes: 1 addition & 1 deletion docs/advanced/document-store/weaviate.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ The following configs can be set:

*You can read more about the HNSW parameters and their default values [here](https://weaviate.io/developers/weaviate/current/vector-index-plugins/hnsw.html#how-to-use-hnsw-and-parameters)

## Minimum Example
## Minimum example

The following example shows how to use DocArray with Weaviate Document Store in order to index and search text
Documents.
Expand Down
37 changes: 8 additions & 29 deletions docs/fundamentals/documentarray/matching.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Though both `.find()` and `.match()` is about finding nearest neighbours of a gi

In the sequel, we will use `.match()` to describe the features. But keep in mind that `.find()` should also work by simply switching the right and left-hand sides.

## Example
### Example

The following example finds for each element in `da1` the three closest Documents from the elements in `da2` according to Euclidean distance.

Expand Down Expand Up @@ -134,11 +134,11 @@ da2.find(da1, metric='euclidean', limit=3)
or simply:

```python
da2.find(np.array(
[[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0],
[1, 1, 1, 1, 0],
[1, 2, 2, 1, 0]]), metric='euclidean', limit=3)
da2.find(
np.array([[0, 0, 0, 0, 1], [1, 0, 0, 0, 0], [1, 1, 1, 1, 0], [1, 2, 2, 1, 0]]),
metric='euclidean',
limit=3,
)
```

The following metrics are supported:
Expand All @@ -155,27 +155,6 @@ Note that framework is auto-chosen based on the type of `.embeddings`. For examp

By default `A.match(B)` will copy the top-K matched Documents from B to `A.matches`. When these matches are big, copying them can be time-consuming. In this case, one can leverage `.match(..., only_id=True)` to keep only {attr}`~docarray.Document.id`.

### Pre filtering

Both `match` and `find` support pre-filtering by passing a `filter` argument to the method.

Pre-filtering is an advanced approximate nearest neighbors feature that allows to efficiently retrieve the nearest vectors
that respect the filtering condition.

In contrast, post-filtering in the naive approach where you first retrieve the
nearest neighbors and then discard all the candidates that do not respect the filter condition.

````{admonition} Pre-filtering is not available for in-memory backend
:class: caution
By default a DocumentArray will use the in-memory backend which does not support pre-filtering
```
````

You can find example on how to use the pre-filtering here:

- {ref}`ANNLite <annlite-filter>`
- {ref}`Weaviate <weaviate-filter>`
- {ref}`Qdrant <qdrant-filter>`


### GPU support
Expand Down Expand Up @@ -224,7 +203,7 @@ da2.embeddings = np.random.random([M, D]).astype(np.float32)
```
```python
%timeit da1.match(da2, only_id=True)
da1.match(da2, only_id=True)
```
```text
Expand All @@ -243,7 +222,7 @@ da2.embeddings = torch.tensor(np.random.random([M, D]).astype(np.float32))
```
```python
%timeit da1.match(da2, device='cuda', batch_size=1_000, only_id=True)
da1.match(da2, device='cuda', batch_size=1_000, only_id=True)
```
```text
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/array/mixins/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,23 @@
import operator


def test_customize_metric_fn():
N, D = 4, 128
da = DocumentArray.empty(N)
da.embeddings = np.random.random([N, D])

q = np.random.random([D])
_, r1 = da.find(q)[:, ['scores__cosine__value', 'id']]

from docarray.math.distance.numpy import cosine

def inv_cosine(*args):
return -cosine(*args)

_, r2 = da.find(q, metric=inv_cosine)[:, ['scores__inv_cosine__value', 'id']]
assert list(reversed(r1)) == r2


@pytest.mark.parametrize(
'storage, config',
[
Expand Down

0 comments on commit 18a1c46

Please sign in to comment.