From 1b40f139399dd29efe37553517f3ff8911c8cb22 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 24 Jul 2023 16:37:44 +0000 Subject: [PATCH] Add: `Matches` and `BatchMatches` simple API --- README.md | 8 ++-- python/README.md | 14 +++--- python/lib.cpp | 20 ++++---- python/scripts/test.py | 15 +++--- python/scripts/test_jit.py | 10 ++-- python/usearch/index.py | 97 +++++++++++++++++++++----------------- 6 files changed, 87 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index c18ebf4d..63314833 100644 --- a/README.md +++ b/README.md @@ -88,12 +88,12 @@ index = Index( vector = np.array([0.2, 0.6, 0.4]) index.add(42, vector) -matches, distances, count = index.search(vector, 10) +matches: Matches = index.search(vector, 10) assert len(index) == 1 -assert count == 1 -assert matches[0] == 42 -assert distances[0] <= 0.001 +assert len(matches) == 1 +assert matches[0].label == 42 +assert matches[0].distance <= 0.001 assert np.allclose(index[42], vector) ``` diff --git a/python/README.md b/python/README.md index e10c2df4..f3146113 100644 --- a/python/README.md +++ b/python/README.md @@ -23,12 +23,12 @@ index = Index( vector = np.array([0.2, 0.6, 0.4]) index.add(42, vector) -matches, distances, count = index.search(vector, 10) +matches: Matches = index.search(vector, 10) assert len(index) == 1 -assert count == 1 -assert matches[0] == 42 -assert distances[0] <= 0.001 +assert len(matches) == 1 +assert matches[0].label == 42 +assert matches[0].distance <= 0.001 assert np.allclose(index[42], vector) ``` @@ -61,10 +61,10 @@ labels = np.arange(n) vectors = np.random.uniform(0, 0.3, (n, index.ndim)).astype(np.float32) index.add(labels, vectors, threads=..., copy=...) -matches, distances, counts = index.search(vectors, 10, threads=...) +matches: BatchMatches = index.search(vectors, 10, threads=...) -assert matches.shape[0] == vectors.shape[0] -assert counts[0] <= 10 +assert len(matches) == vectors.shape[0] +assert len(matches[0]) <= 10 ``` You can also override the default `threads` and `copy` arguments in bulk workloads. diff --git a/python/lib.cpp b/python/lib.cpp index a1129bad..1509f276 100644 --- a/python/lib.cpp +++ b/python/lib.cpp @@ -271,18 +271,21 @@ static py::tuple search_one_in_index(dense_index_py_t& index, py::buffer vector, if (vector_dimensions != static_cast(index.scalar_words())) throw std::invalid_argument("The number of vector dimensions doesn't match!"); - py::array_t labels_py(static_cast(wanted)); - py::array_t distances_py(static_cast(wanted)); + constexpr Py_ssize_t vectors_count = 1; + py::array_t labels_py({vectors_count, static_cast(wanted)}); + py::array_t distances_py({vectors_count, static_cast(wanted)}); + py::array_t counts_py(vectors_count); std::size_t count{}; - auto labels_py1d = labels_py.template mutable_unchecked<1>(); - auto distances_py1d = distances_py.template mutable_unchecked<1>(); + auto labels_py2d = labels_py.template mutable_unchecked<2>(); + auto distances_py2d = distances_py.template mutable_unchecked<2>(); + auto counts_py1d = counts_py.template mutable_unchecked<1>(); search_config_t config; config.exact = exact; auto raise_and_dump = [&](dense_search_result_t result) { result.error.raise(); - count = result.dump_to(&labels_py1d(0), &distances_py1d(0)); + count = result.dump_to(&labels_py2d(0, 0), &distances_py2d(0, 0)); }; switch (numpy_string_to_kind(vector_info.format)) { @@ -295,13 +298,14 @@ static py::tuple search_one_in_index(dense_index_py_t& index, py::buffer vector, throw std::invalid_argument("Incompatible scalars in the query vector: " + vector_info.format); } - labels_py.resize(py_shape_t{static_cast(count)}); - distances_py.resize(py_shape_t{static_cast(count)}); + labels_py.resize(py_shape_t{vectors_count, static_cast(count)}); + distances_py.resize(py_shape_t{vectors_count, static_cast(count)}); + counts_py1d[0] = static_cast(count); py::tuple results(3); results[0] = labels_py; results[1] = distances_py; - results[2] = static_cast(count); + results[2] = counts_py; return results; } diff --git a/python/scripts/test.py b/python/scripts/test.py index fbbdf424..d3ea449d 100644 --- a/python/scripts/test.py +++ b/python/scripts/test.py @@ -109,13 +109,12 @@ def test_index( if numpy_type != np.byte: assert np.allclose(index[42], vector, atol=0.1) - matches, distances, count = index.search(vector, 10) + matches = index.search(vector, 10) assert len(index) == 1 - assert len(matches) == count - assert len(distances) == count - assert count == 1 - assert matches[0] == 42 - assert distances[0] == pytest.approx(0, abs=1e-3) + assert len(matches.labels) == len(matches.distances) + assert len(matches.labels) == 1 + assert matches[0].label == 42 + assert matches[0].distance == pytest.approx(0, abs=1e-3) assert index.max_level >= 0 assert index.levels_stats.nodes >= 1 @@ -136,8 +135,8 @@ def test_index( assert len(index) == 0 index.add(42, vector) assert len(index) == 1 - matches, distances, count = index.search(vector, 10) - assert count == 1 + matches = index.search(vector, 10) + assert len(matches) == 1 index.load("tmp.usearch") assert len(index) == 1 diff --git a/python/scripts/test_jit.py b/python/scripts/test_jit.py index 0e0eec51..0f4589ef 100644 --- a/python/scripts/test_jit.py +++ b/python/scripts/test_jit.py @@ -89,9 +89,8 @@ def python_inner_product_four_args(a, a_ndim, b, b_ndim): vectors = random_vectors(count=batch_size, ndim=ndim) index.add(labels, vectors) - matches, distances, count = index.search(vectors, 10) - assert matches.shape[0] == distances.shape[0] - assert count.shape[0] == batch_size + matches = index.search(vectors, 10) + assert len(matches) == batch_size @pytest.mark.parametrize("ndim", dimensions[-1:]) @@ -153,9 +152,8 @@ def test_index_cppyy(ndim: int, batch_size: int): vectors = random_vectors(count=batch_size, ndim=ndim) index.add(labels, vectors) - matches, distances, count = index.search(vectors, 10) - assert matches.shape[0] == distances.shape[0] - assert count.shape[0] == batch_size + matches = index.search(vectors, 10) + assert len(matches) == batch_size @pytest.mark.parametrize("ndim", [8]) diff --git a/python/usearch/index.py b/python/usearch/index.py index 485174c0..f0247891 100644 --- a/python/usearch/index.py +++ b/python/usearch/index.py @@ -102,47 +102,57 @@ def _normalize_metric(metric): return metric +class Match(NamedTuple): + label: int + distance: float + + class Matches(NamedTuple): labels: np.ndarray distances: np.ndarray - counts: Union[np.ndarray, int] - @property - def is_multiple(self) -> bool: - return isinstance(self.counts, np.ndarray) + def __len__(self) -> int: + return len(self.labels) - @property - def batch_size(self) -> int: - return len(self.counts) if isinstance(self.counts, np.ndarray) else 1 + def __getitem__(self, index: int) -> Match: + return Match( + label=self.labels[index], + distance=self.distances[index], + ) - @property - def total_matches(self) -> int: - return np.sum(self.counts) + def to_list(self) -> List[tuple]: + return [(int(l), float(d)) for l, d in zip(self.labels, self.distances)] - def to_list(self, row: Optional[int] = None) -> Union[List[dict], List[List[dict]]]: - if not self.is_multiple: - assert row is None, "Exporting a single sequence is only for batch requests" - labels = self.labels - distances = self.distances + def __repr__(self) -> str: + return f"usearch.Matches({len(self)})" - elif row is None: - return [self.to_list(i) for i in range(self.batch_size)] - else: - count = self.counts[row] - labels = self.labels[row, :count] - distances = self.distances[row, :count] +class BatchMatches(NamedTuple): + labels: np.ndarray + distances: np.ndarray + counts: np.ndarray - return [ - {"label": int(label), "distance": float(distance)} - for label, distance in zip(labels, distances) - ] + def __len__(self) -> int: + return len(self.counts) + + def __getitem__(self, index: int) -> Matches: + return Matches( + labels=self.labels[index, : self.counts[index]], + distances=self.distances[index, : self.counts[index]], + ) + + def to_list(self) -> List[List[tuple]]: + lists = [self.__getitem__(row) for row in range(self.__len__())] + return [item for sublist in lists for item in sublist] def recall_first(self, expected: np.ndarray) -> float: - best_matches = self.labels if not self.is_multiple else self.labels[:, 0] - return np.sum(best_matches == expected) / len(expected) + """Measures recall [0, 1] as of `Matches` that contain the corresponding + `expected` entry as the first result.""" + return np.sum(self.labels[:, 0] == expected) / len(expected) def recall(self, expected: np.ndarray) -> float: + """Measures recall [0, 1] as of `Matches` that contain the corresponding + `expected` entry anywhere among results.""" assert len(expected) == self.batch_size recall = 0 for i in range(self.batch_size): @@ -150,13 +160,7 @@ def recall(self, expected: np.ndarray) -> float: return recall / len(expected) def __repr__(self) -> str: - return ( - "usearch.Matches({})".format(self.total_matches) - if self.is_multiple - else "usearch.Matches({} across {} queries)".format( - self.total_matches, self.batch_size - ) - ) + return f"usearch.BatchMatches({np.sum(self.counts)} across {len(self)} queries)" class CompiledMetric(NamedTuple): @@ -382,7 +386,7 @@ def add( tasks, desc=name, total=count_vectors, - unit="Vector", + unit="vector", disable=log is False, ) for labels, vectors in tasks: @@ -405,7 +409,7 @@ def search( exact: bool = False, log: Union[str, bool] = False, batch_size: int = 0, - ) -> Matches: + ) -> Union[Matches, BatchMatches]: """ Performs approximate nearest neighbors search for one or more queries. @@ -422,13 +426,16 @@ def search( :param batch_size: Number of vectors to process at once, defaults to 0 :type batch_size: int, optional :return: Approximate matches for one or more queries - :rtype: Matches + :rtype: Union[Matches, BatchMatches] """ assert isinstance(vectors, np.ndarray), "Expects a NumPy array" assert vectors.ndim == 1 or vectors.ndim == 2, "Expects a matrix or vector" count_vectors = vectors.shape[0] if vectors.ndim == 2 else 1 + def distil_batch(batch_matches: BatchMatches) -> Union[BatchMatches, Matches]: + return batch_matches if vectors.ndim == 2 else batch_matches[0] + if log and batch_size == 0: batch_size = int(math.ceil(count_vectors / 100)) @@ -443,7 +450,7 @@ def search( tasks, desc=name, total=count_vectors, - unit="Vector", + unit="vector", disable=log is False, ) for vectors in tasks: @@ -453,14 +460,16 @@ def search( exact=exact, threads=threads, ) - tasks_matches.append(Matches(*tuple_)) + tasks_matches.append(BatchMatches(*tuple_)) pbar.update(vectors.shape[0]) pbar.close() - return Matches( - labels=np.vstack([m.labels for m in tasks_matches]), - distances=np.vstack([m.distances for m in tasks_matches]), - counts=np.concatenate([m.counts for m in tasks_matches], axis=None), + return distil_batch( + BatchMatches( + labels=np.vstack([m.labels for m in tasks_matches]), + distances=np.vstack([m.distances for m in tasks_matches]), + counts=np.concatenate([m.counts for m in tasks_matches], axis=None), + ) ) else: @@ -470,7 +479,7 @@ def search( exact=exact, threads=threads, ) - return Matches(*tuple_) + return distil_batch(BatchMatches(*tuple_)) def remove( self,