Skip to content

Commit

Permalink
Add: Matches and BatchMatches simple API
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Jul 24, 2023
1 parent 9a6a01c commit 1b40f13
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 77 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ index = Index(

vector = np.array([0.2, 0.6, 0.4])
index.add(42, vector)
matches, distances, count = index.search(vector, 10)
matches: Matches = index.search(vector, 10)

assert len(index) == 1
assert count == 1
assert matches[0] == 42
assert distances[0] <= 0.001
assert len(matches) == 1
assert matches[0].label == 42
assert matches[0].distance <= 0.001
assert np.allclose(index[42], vector)
```

Expand Down
14 changes: 7 additions & 7 deletions python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ index = Index(

vector = np.array([0.2, 0.6, 0.4])
index.add(42, vector)
matches, distances, count = index.search(vector, 10)
matches: Matches = index.search(vector, 10)

assert len(index) == 1
assert count == 1
assert matches[0] == 42
assert distances[0] <= 0.001
assert len(matches) == 1
assert matches[0].label == 42
assert matches[0].distance <= 0.001
assert np.allclose(index[42], vector)
```

Expand Down Expand Up @@ -61,10 +61,10 @@ labels = np.arange(n)
vectors = np.random.uniform(0, 0.3, (n, index.ndim)).astype(np.float32)

index.add(labels, vectors, threads=..., copy=...)
matches, distances, counts = index.search(vectors, 10, threads=...)
matches: BatchMatches = index.search(vectors, 10, threads=...)

assert matches.shape[0] == vectors.shape[0]
assert counts[0] <= 10
assert len(matches) == vectors.shape[0]
assert len(matches[0]) <= 10
```

You can also override the default `threads` and `copy` arguments in bulk workloads.
Expand Down
20 changes: 12 additions & 8 deletions python/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,18 +271,21 @@ static py::tuple search_one_in_index(dense_index_py_t& index, py::buffer vector,
if (vector_dimensions != static_cast<Py_ssize_t>(index.scalar_words()))
throw std::invalid_argument("The number of vector dimensions doesn't match!");

py::array_t<label_t> labels_py(static_cast<Py_ssize_t>(wanted));
py::array_t<distance_t> distances_py(static_cast<Py_ssize_t>(wanted));
constexpr Py_ssize_t vectors_count = 1;
py::array_t<label_t> labels_py({vectors_count, static_cast<Py_ssize_t>(wanted)});
py::array_t<distance_t> distances_py({vectors_count, static_cast<Py_ssize_t>(wanted)});
py::array_t<Py_ssize_t> counts_py(vectors_count);
std::size_t count{};
auto labels_py1d = labels_py.template mutable_unchecked<1>();
auto distances_py1d = distances_py.template mutable_unchecked<1>();
auto labels_py2d = labels_py.template mutable_unchecked<2>();
auto distances_py2d = distances_py.template mutable_unchecked<2>();
auto counts_py1d = counts_py.template mutable_unchecked<1>();

search_config_t config;
config.exact = exact;

auto raise_and_dump = [&](dense_search_result_t result) {
result.error.raise();
count = result.dump_to(&labels_py1d(0), &distances_py1d(0));
count = result.dump_to(&labels_py2d(0, 0), &distances_py2d(0, 0));
};

switch (numpy_string_to_kind(vector_info.format)) {
Expand All @@ -295,13 +298,14 @@ static py::tuple search_one_in_index(dense_index_py_t& index, py::buffer vector,
throw std::invalid_argument("Incompatible scalars in the query vector: " + vector_info.format);
}

labels_py.resize(py_shape_t{static_cast<Py_ssize_t>(count)});
distances_py.resize(py_shape_t{static_cast<Py_ssize_t>(count)});
labels_py.resize(py_shape_t{vectors_count, static_cast<Py_ssize_t>(count)});
distances_py.resize(py_shape_t{vectors_count, static_cast<Py_ssize_t>(count)});
counts_py1d[0] = static_cast<Py_ssize_t>(count);

py::tuple results(3);
results[0] = labels_py;
results[1] = distances_py;
results[2] = static_cast<Py_ssize_t>(count);
results[2] = counts_py;
return results;
}

Expand Down
15 changes: 7 additions & 8 deletions python/scripts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,12 @@ def test_index(
if numpy_type != np.byte:
assert np.allclose(index[42], vector, atol=0.1)

matches, distances, count = index.search(vector, 10)
matches = index.search(vector, 10)
assert len(index) == 1
assert len(matches) == count
assert len(distances) == count
assert count == 1
assert matches[0] == 42
assert distances[0] == pytest.approx(0, abs=1e-3)
assert len(matches.labels) == len(matches.distances)
assert len(matches.labels) == 1
assert matches[0].label == 42
assert matches[0].distance == pytest.approx(0, abs=1e-3)

assert index.max_level >= 0
assert index.levels_stats.nodes >= 1
Expand All @@ -136,8 +135,8 @@ def test_index(
assert len(index) == 0
index.add(42, vector)
assert len(index) == 1
matches, distances, count = index.search(vector, 10)
assert count == 1
matches = index.search(vector, 10)
assert len(matches) == 1

index.load("tmp.usearch")
assert len(index) == 1
Expand Down
10 changes: 4 additions & 6 deletions python/scripts/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,8 @@ def python_inner_product_four_args(a, a_ndim, b, b_ndim):
vectors = random_vectors(count=batch_size, ndim=ndim)

index.add(labels, vectors)
matches, distances, count = index.search(vectors, 10)
assert matches.shape[0] == distances.shape[0]
assert count.shape[0] == batch_size
matches = index.search(vectors, 10)
assert len(matches) == batch_size


@pytest.mark.parametrize("ndim", dimensions[-1:])
Expand Down Expand Up @@ -153,9 +152,8 @@ def test_index_cppyy(ndim: int, batch_size: int):
vectors = random_vectors(count=batch_size, ndim=ndim)

index.add(labels, vectors)
matches, distances, count = index.search(vectors, 10)
assert matches.shape[0] == distances.shape[0]
assert count.shape[0] == batch_size
matches = index.search(vectors, 10)
assert len(matches) == batch_size


@pytest.mark.parametrize("ndim", [8])
Expand Down
97 changes: 53 additions & 44 deletions python/usearch/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,61 +102,65 @@ def _normalize_metric(metric):
return metric


class Match(NamedTuple):
label: int
distance: float


class Matches(NamedTuple):
labels: np.ndarray
distances: np.ndarray
counts: Union[np.ndarray, int]

@property
def is_multiple(self) -> bool:
return isinstance(self.counts, np.ndarray)
def __len__(self) -> int:
return len(self.labels)

@property
def batch_size(self) -> int:
return len(self.counts) if isinstance(self.counts, np.ndarray) else 1
def __getitem__(self, index: int) -> Match:
return Match(
label=self.labels[index],
distance=self.distances[index],
)

@property
def total_matches(self) -> int:
return np.sum(self.counts)
def to_list(self) -> List[tuple]:
return [(int(l), float(d)) for l, d in zip(self.labels, self.distances)]

def to_list(self, row: Optional[int] = None) -> Union[List[dict], List[List[dict]]]:
if not self.is_multiple:
assert row is None, "Exporting a single sequence is only for batch requests"
labels = self.labels
distances = self.distances
def __repr__(self) -> str:
return f"usearch.Matches({len(self)})"

elif row is None:
return [self.to_list(i) for i in range(self.batch_size)]

else:
count = self.counts[row]
labels = self.labels[row, :count]
distances = self.distances[row, :count]
class BatchMatches(NamedTuple):
labels: np.ndarray
distances: np.ndarray
counts: np.ndarray

return [
{"label": int(label), "distance": float(distance)}
for label, distance in zip(labels, distances)
]
def __len__(self) -> int:
return len(self.counts)

def __getitem__(self, index: int) -> Matches:
return Matches(
labels=self.labels[index, : self.counts[index]],
distances=self.distances[index, : self.counts[index]],
)

def to_list(self) -> List[List[tuple]]:
    """Export the matches of every query as plain Python values.

    :return: One inner list per query, holding `(label, distance)` tuples
        for the first `counts[row]` entries of that row.
    """
    # Build the nested lists straight from the arrays. Flattening per-row
    # `Matches` objects would iterate over their tuple fields (the label and
    # distance arrays) instead of over individual matches, and would not
    # produce the annotated list-of-lists shape.
    return [
        [
            (int(label), float(distance))
            for label, distance in zip(
                self.labels[row, :count], self.distances[row, :count]
            )
        ]
        for row, count in enumerate(self.counts)
    ]

def recall_first(self, expected: np.ndarray) -> float:
best_matches = self.labels if not self.is_multiple else self.labels[:, 0]
return np.sum(best_matches == expected) / len(expected)
"""Measures recall [0, 1] as of `Matches` that contain the corresponding
`expected` entry as the first result."""
return np.sum(self.labels[:, 0] == expected) / len(expected)

def recall(self, expected: np.ndarray) -> float:
    """Measure recall in [0, 1]: the share of queries whose `expected`
    label appears anywhere among that query's results.

    :param expected: One expected label per query, in query order.
    :return: Fraction of queries whose expected label was found.
    """
    # The number of queries is the length of `counts`; this class has no
    # `batch_size` attribute to read it from.
    queries = len(self.counts)
    assert len(expected) == queries
    # Only the first `counts[i]` entries of row `i` are treated as results
    # (the same slice `__getitem__` returns), so entries past that count
    # must not be allowed to produce a hit.
    hits = sum(
        expected[i] in self.labels[i, : self.counts[i]] for i in range(queries)
    )
    return hits / len(expected)

def __repr__(self) -> str:
return (
"usearch.Matches({})".format(self.total_matches)
if self.is_multiple
else "usearch.Matches({} across {} queries)".format(
self.total_matches, self.batch_size
)
)
return f"usearch.BatchMatches({np.sum(self.counts)} across {len(self)} queries)"


class CompiledMetric(NamedTuple):
Expand Down Expand Up @@ -382,7 +386,7 @@ def add(
tasks,
desc=name,
total=count_vectors,
unit="Vector",
unit="vector",
disable=log is False,
)
for labels, vectors in tasks:
Expand All @@ -405,7 +409,7 @@ def search(
exact: bool = False,
log: Union[str, bool] = False,
batch_size: int = 0,
) -> Matches:
) -> Union[Matches, BatchMatches]:
"""
Performs approximate nearest neighbors search for one or more queries.
Expand All @@ -422,13 +426,16 @@ def search(
:param batch_size: Number of vectors to process at once, defaults to 0
:type batch_size: int, optional
:return: Approximate matches for one or more queries
:rtype: Matches
:rtype: Union[Matches, BatchMatches]
"""

assert isinstance(vectors, np.ndarray), "Expects a NumPy array"
assert vectors.ndim == 1 or vectors.ndim == 2, "Expects a matrix or vector"
count_vectors = vectors.shape[0] if vectors.ndim == 2 else 1

def distil_batch(batch_matches: BatchMatches) -> Union[BatchMatches, Matches]:
return batch_matches if vectors.ndim == 2 else batch_matches[0]

if log and batch_size == 0:
batch_size = int(math.ceil(count_vectors / 100))

Expand All @@ -443,7 +450,7 @@ def search(
tasks,
desc=name,
total=count_vectors,
unit="Vector",
unit="vector",
disable=log is False,
)
for vectors in tasks:
Expand All @@ -453,14 +460,16 @@ def search(
exact=exact,
threads=threads,
)
tasks_matches.append(Matches(*tuple_))
tasks_matches.append(BatchMatches(*tuple_))
pbar.update(vectors.shape[0])

pbar.close()
return Matches(
labels=np.vstack([m.labels for m in tasks_matches]),
distances=np.vstack([m.distances for m in tasks_matches]),
counts=np.concatenate([m.counts for m in tasks_matches], axis=None),
return distil_batch(
BatchMatches(
labels=np.vstack([m.labels for m in tasks_matches]),
distances=np.vstack([m.distances for m in tasks_matches]),
counts=np.concatenate([m.counts for m in tasks_matches], axis=None),
)
)

else:
Expand All @@ -470,7 +479,7 @@ def search(
exact=exact,
threads=threads,
)
return Matches(*tuple_)
return distil_batch(BatchMatches(*tuple_))

def remove(
self,
Expand Down

0 comments on commit 1b40f13

Please sign in to comment.