diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 5f016ed88..abe6fa7b3 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -555,8 +555,13 @@ target_link_options(cuvs PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld")
 if(BUILD_C_LIBRARY)
   add_library(
     cuvs_c SHARED
-    src/core/c_api.cpp src/neighbors/brute_force_c.cpp src/neighbors/ivf_flat_c.cpp
-    src/neighbors/ivf_pq_c.cpp src/neighbors/cagra_c.cpp src/distance/pairwise_distance_c.cpp
+    src/core/c_api.cpp
+    src/neighbors/brute_force_c.cpp
+    src/neighbors/ivf_flat_c.cpp
+    src/neighbors/ivf_pq_c.cpp
+    src/neighbors/cagra_c.cpp
+    src/neighbors/refine/refine_c.cpp
+    src/distance/pairwise_distance_c.cpp
   )
 
   add_library(cuvs::c_api ALIAS cuvs_c)
diff --git a/cpp/include/cuvs/core/detail/interop.hpp b/cpp/include/cuvs/core/detail/interop.hpp
index 208daaae7..2ed0b330d 100644
--- a/cpp/include/cuvs/core/detail/interop.hpp
+++ b/cpp/include/cuvs/core/detail/interop.hpp
@@ -16,6 +16,8 @@
 
 #pragma once
 
+#include <cuda_fp16.h>
+
 #include
 #include
 #include
@@ -44,7 +46,9 @@ DLDataType data_type_to_DLDataType()
 {
   uint8_t const bits{sizeof(T) * 8};
   uint16_t const lanes{1};
-  if constexpr (std::is_floating_point_v<T>) {
+  // std::is_floating_point returns false for the half type - handle
+  // that here
+  if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, half>) {
     return DLDataType{kDLFloat, bits, lanes};
   } else if constexpr (std::is_signed_v<T>) {
     return DLDataType{kDLInt, bits, lanes};
@@ -72,9 +76,13 @@ inline MdspanType from_dlpack(DLManagedTensor* managed_tensor)
   auto to_data_type = data_type_to_DLDataType<typename MdspanType::value_type>();
 
   RAFT_EXPECTS(to_data_type.code == tensor.dtype.code,
-               "code mismatch between return mdspan and DLTensor");
+               "code mismatch between return mdspan (%i) and DLTensor (%i)",
+               to_data_type.code,
+               tensor.dtype.code);
   RAFT_EXPECTS(to_data_type.bits == tensor.dtype.bits,
-               "bits mismatch between return mdspan and DLTensor");
+               "bits mismatch between return mdspan (%i) and DLTensor (%i)",
+               to_data_type.bits,
+               tensor.dtype.bits);
   RAFT_EXPECTS(to_data_type.lanes == tensor.dtype.lanes,
                "lanes mismatch between return mdspan and DLTensor");
   RAFT_EXPECTS(tensor.dtype.lanes == 1, "More than 1 DLTensor lanes not supported");
diff --git a/cpp/include/cuvs/neighbors/refine.h b/cpp/include/cuvs/neighbors/refine.h
new file mode 100644
index 000000000..4e7a572b4
--- /dev/null
+++ b/cpp/include/cuvs/neighbors/refine.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cuvs/core/c_api.h>
+#include <cuvs/distance/distance.h>
+#include <dlpack/dlpack.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+/**
+ * @defgroup ann_refine_c Approximate Nearest Neighbors Refinement C-API
+ * @{
+ */
+/**
+ * @brief Refine nearest neighbor search.
+ *
+ * Refinement is an operation that follows an approximate NN search. The approximate search has
+ * already selected n_candidates neighbor candidates for each query. We narrow it down to k
+ * neighbors. For each query, we calculate the exact distance between the query and its
+ * n_candidates neighbor candidates, and select the k nearest ones.
+ *
+ * @param[in] res cuvsResources_t opaque C handle
+ * @param[in] dataset device matrix that stores the dataset [n_rows, dims]
+ * @param[in] queries device matrix of the queries [n_queries, dims]
+ * @param[in] candidates indices of candidate vectors [n_queries, n_candidates], where
+ *            n_candidates >= k
+ * @param[in] metric distance metric to use. Euclidean (L2) is used by default
+ * @param[out] indices device matrix that stores the refined indices [n_queries, k]
+ * @param[out] distances device matrix that stores the refined distances [n_queries, k]
+ */
+cuvsError_t cuvsRefine(cuvsResources_t res,
+                       DLManagedTensor* dataset,
+                       DLManagedTensor* queries,
+                       DLManagedTensor* candidates,
+                       cuvsDistanceType metric,
+                       DLManagedTensor* indices,
+                       DLManagedTensor* distances);
+/**
+ * @}
+ */
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/cpp/src/neighbors/refine/refine_c.cpp b/cpp/src/neighbors/refine/refine_c.cpp
new file mode 100644
index 000000000..955eaded4
--- /dev/null
+++ b/cpp/src/neighbors/refine/refine_c.cpp
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstdint>
+#include <dlpack/dlpack.h>
+
+#include <raft/core/error.hpp>
+#include <raft/core/mdspan_types.hpp>
+#include <raft/core/resources.hpp>
+
+#include <cuvs/core/c_api.h>
+#include <cuvs/core/exceptions.hpp>
+#include <cuvs/core/interop.hpp>
+#include <cuvs/neighbors/refine.h>
+#include <cuvs/neighbors/refine.hpp>
+
+namespace {
+
+template <typename T>
+void _refine(bool on_device,
+             cuvsResources_t res,
+             DLManagedTensor* dataset_tensor,
+             DLManagedTensor* queries_tensor,
+             DLManagedTensor* candidates_tensor,
+             cuvsDistanceType metric,
+             DLManagedTensor* indices_tensor,
+             DLManagedTensor* distances_tensor)
+{
+  auto res_ptr = reinterpret_cast<raft::resources*>(res);
+
+  if (on_device) {
+    using queries_type    = raft::device_matrix_view<const T, int64_t, raft::row_major>;
+    using candidates_type = raft::device_matrix_view<const int64_t, int64_t, raft::row_major>;
+    using indices_type    = raft::device_matrix_view<int64_t, int64_t, raft::row_major>;
+    using distances_type  = raft::device_matrix_view<float, int64_t, raft::row_major>;
+    auto dataset          = cuvs::core::from_dlpack<queries_type>(dataset_tensor);
+    auto queries          = cuvs::core::from_dlpack<queries_type>(queries_tensor);
+    auto candidates       = cuvs::core::from_dlpack<candidates_type>(candidates_tensor);
+    auto indices          = cuvs::core::from_dlpack<indices_type>(indices_tensor);
+    auto distances        = cuvs::core::from_dlpack<distances_type>(distances_tensor);
+    cuvs::neighbors::refine(*res_ptr, dataset, queries, candidates, indices, distances, metric);
+  } else {
+    using queries_type    = raft::host_matrix_view<const T, int64_t, raft::row_major>;
+    using candidates_type = raft::host_matrix_view<const int64_t, int64_t, raft::row_major>;
+    using indices_type    = raft::host_matrix_view<int64_t, int64_t, raft::row_major>;
+    using distances_type  = raft::host_matrix_view<float, int64_t, raft::row_major>;
+    auto dataset          = cuvs::core::from_dlpack<queries_type>(dataset_tensor);
+    auto queries          = cuvs::core::from_dlpack<queries_type>(queries_tensor);
+    auto candidates       = cuvs::core::from_dlpack<candidates_type>(candidates_tensor);
+    auto indices          = cuvs::core::from_dlpack<indices_type>(indices_tensor);
+    auto distances        = cuvs::core::from_dlpack<distances_type>(distances_tensor);
+    cuvs::neighbors::refine(*res_ptr, dataset, queries, candidates, indices, distances, metric);
+  }
+}
+}  // namespace
+
+extern "C" cuvsError_t cuvsRefine(cuvsResources_t res,
+                                  DLManagedTensor* dataset_tensor,
+                                  DLManagedTensor* queries_tensor,
+                                  DLManagedTensor* candidates_tensor,
+                                  cuvsDistanceType metric,
+                                  DLManagedTensor* indices_tensor,
+                                  DLManagedTensor* distances_tensor)
+{
+  return cuvs::core::translate_exceptions([=] {
+    auto dataset    = dataset_tensor->dl_tensor;
+    auto queries    = queries_tensor->dl_tensor;
+    auto candidates = candidates_tensor->dl_tensor;
+    auto indices    = indices_tensor->dl_tensor;
+    auto distances  = distances_tensor->dl_tensor;
+
+    // all matrices must either be on host or on device, can't mix and match
+    bool on_device = cuvs::core::is_dlpack_device_compatible(dataset);
+    if (on_device != cuvs::core::is_dlpack_device_compatible(queries) ||
+        on_device != cuvs::core::is_dlpack_device_compatible(candidates) ||
+        on_device != cuvs::core::is_dlpack_device_compatible(indices) ||
+        on_device != cuvs::core::is_dlpack_device_compatible(distances)) {
+      RAFT_FAIL("Tensors must either all be on device memory, or all on host memory");
+    }
+
+    RAFT_EXPECTS(candidates.dtype.code == kDLInt && candidates.dtype.bits == 64,
+                 "candidates should be of type int64_t");
+    RAFT_EXPECTS(indices.dtype.code == kDLInt && indices.dtype.bits == 64,
+                 "indices should be of type int64_t");
+    RAFT_EXPECTS(distances.dtype.code == kDLFloat && distances.dtype.bits == 32,
+                 "distances should be of type float32");
+
+    RAFT_EXPECTS(queries.dtype.code == dataset.dtype.code,
+                 "type mismatch between dataset and queries");
+
+    if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
+      _refine<float>(on_device,
+                     res,
+                     dataset_tensor,
+                     queries_tensor,
+                     candidates_tensor,
+                     metric,
+                     indices_tensor,
+                     distances_tensor);
+    } else if (queries.dtype.code == kDLFloat && queries.dtype.bits == 16) {
+      _refine<half>(on_device,
+                    res,
+                    dataset_tensor,
+                    queries_tensor,
+                    candidates_tensor,
+                    metric,
+                    indices_tensor,
+                    distances_tensor);
+    } else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) {
+      _refine<int8_t>(on_device,
+                      res,
+                      dataset_tensor,
+                      queries_tensor,
+                      candidates_tensor,
+                      metric,
+                      indices_tensor,
+                      distances_tensor);
+    } else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) {
+      _refine<uint8_t>(on_device,
+                       res,
+                       dataset_tensor,
+                       queries_tensor,
+                       candidates_tensor,
+                       metric,
+                       indices_tensor,
+                       distances_tensor);
+    } else {
+      RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d",
+                queries.dtype.code,
+                queries.dtype.bits);
+    }
+  });
+}
diff --git a/python/cuvs/cuvs/common/cydlpack.pyx b/python/cuvs/cuvs/common/cydlpack.pyx
index 526f6c78e..79f88cddc 100644
--- a/python/cuvs/cuvs/common/cydlpack.pyx
+++ b/python/cuvs/cuvs/common/cydlpack.pyx
@@ -53,6 +53,9 @@ cdef DLManagedTensor* dlpack_c(ary):
     elif ary.dtype == np.float64:
         dtype.code = DLDataTypeCode.kDLFloat
         dtype.bits = 64
+    elif ary.dtype == np.float16:
+        dtype.code = DLDataTypeCode.kDLFloat
+        dtype.bits = 16
     elif ary.dtype == np.int8:
         dtype.code = DLDataTypeCode.kDLInt
         dtype.bits = 8
@@ -74,6 +77,8 @@ cdef DLManagedTensor* dlpack_c(ary):
     elif ary.dtype == np.bool_:
         dtype.code = DLDataTypeCode.kDLFloat
         dtype.bits = 8
+    else:
+        raise ValueError(f"Unsupported dtype {ary.dtype}")
 
     dtype.lanes = 1
diff --git a/python/cuvs/cuvs/neighbors/CMakeLists.txt b/python/cuvs/cuvs/neighbors/CMakeLists.txt
index e0041243a..3579215fd 100644
--- a/python/cuvs/cuvs/neighbors/CMakeLists.txt
+++ b/python/cuvs/cuvs/neighbors/CMakeLists.txt
@@ -16,3 +16,14 @@ add_subdirectory(brute_force)
 add_subdirectory(cagra)
 add_subdirectory(ivf_flat)
 add_subdirectory(ivf_pq)
+
+# Set the list of Cython files to build
+set(cython_sources refine.pyx)
+set(linked_libraries cuvs::cuvs cuvs::c_api)
+
+# Build all of the Cython targets
+rapids_cython_create_modules(
+  CXX
+  SOURCE_FILES "${cython_sources}"
+  LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS cuvs MODULE_PREFIX neighbors_refine_
+)
diff --git a/python/cuvs/cuvs/neighbors/__init__.py b/python/cuvs/cuvs/neighbors/__init__.py
index c431ffde7..0ecd57d75 100644
--- a/python/cuvs/cuvs/neighbors/__init__.py
+++ b/python/cuvs/cuvs/neighbors/__init__.py
@@ -15,4 +15,6 @@
 
 from cuvs.neighbors import brute_force, cagra, ivf_flat, ivf_pq
 
-__all__ = ["brute_force", "cagra", "ivf_flat", "ivf_pq"]
+from .refine import refine
+
+__all__ = ["brute_force", "cagra", "ivf_flat", "ivf_pq", "refine"]
diff --git a/python/cuvs/cuvs/neighbors/refine.pxd b/python/cuvs/cuvs/neighbors/refine.pxd
new file mode 100644
index 000000000..f02404734
--- /dev/null
+++ b/python/cuvs/cuvs/neighbors/refine.pxd
@@ -0,0 +1,30 @@
+#
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# cython: language_level=3

+from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t
+from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor
+from cuvs.distance_type cimport cuvsDistanceType
+
+
+cdef extern from "cuvs/neighbors/refine.h" nogil:
+    cuvsError_t cuvsRefine(cuvsResources_t res,
+                           DLManagedTensor* dataset,
+                           DLManagedTensor* queries,
+                           DLManagedTensor* candidates,
+                           cuvsDistanceType metric,
+                           DLManagedTensor* indices,
+                           DLManagedTensor* distances) except +
diff --git a/python/cuvs/cuvs/neighbors/refine.pyx b/python/cuvs/cuvs/neighbors/refine.pyx
new file mode 100644
index 000000000..0eccc4108
--- /dev/null
+++ b/python/cuvs/cuvs/neighbors/refine.pyx
@@ -0,0 +1,169 @@
+#
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# cython: language_level=3
+
+import numpy as np
+
+cimport cuvs.common.cydlpack
+
+from cuvs.common.resources import auto_sync_resources
+
+from cython.operator cimport dereference as deref
+from libc.stdint cimport uint32_t
+from libcpp cimport bool
+
+from cuvs.common cimport cydlpack
+from cuvs.distance_type cimport cuvsDistanceType
+
+from pylibraft.common import auto_convert_output, device_ndarray
+from pylibraft.common.cai_wrapper import wrap_array
+from pylibraft.common.interruptible import cuda_interruptible
+from pylibraft.neighbors.common import _check_input_array
+
+from cuvs.distance import DISTANCE_TYPES
+
+from cuvs.common.c_api cimport cuvsResources_t
+
+from cuvs.common.exceptions import check_cuvs
+
+
+@auto_sync_resources
+@auto_convert_output
+def refine(dataset,
+           queries,
+           candidates,
+           k=None,
+           metric="sqeuclidean",
+           indices=None,
+           distances=None,
+           resources=None):
+    """
+    Refine nearest neighbor search.
+
+    Refinement is an operation that follows an approximate NN search. The
+    approximate search has already selected n_candidates neighbor candidates
+    for each query. We narrow it down to k neighbors. For each query, we
+    calculate the exact distance between the query and its n_candidates
+    neighbor candidates, and select the k nearest ones.
+
+    Input arrays can be either CUDA array interface compliant matrices or
+    array interface compliant matrices in host memory. All arrays must be in
+    the same memory space.
+
+    Parameters
+    ----------
+    dataset : array interface compliant matrix, shape (n_samples, dim)
+        Supported dtype [float32, int8, uint8, float16]
+    queries : array interface compliant matrix, shape (n_queries, dim)
+        Supported dtype [float32, int8, uint8, float16]
+    candidates : array interface compliant matrix, shape (n_queries, k0)
+        Supported dtype int64
+    k : int
+        Number of neighbors to search (k <= k0). Optional if indices or
+        distances arrays are given (in which case their second dimension
+        is k).
+    metric : str
+        Name of distance metric to use, default="sqeuclidean"
+    indices : Optional array interface compliant matrix shape \
+            (n_queries, k).
+        If supplied, neighbor indices will be written here in-place.
+        (default None). Supported dtype int64.
+    distances : Optional array interface compliant matrix shape \
+            (n_queries, k).
+        If supplied, neighbor distances will be written here in-place.
+        (default None). Supported dtype float32.
+    {resources_docstring}
+
+    Examples
+    --------
+    >>> import cupy as cp
+    >>> from cuvs.common import Resources
+    >>> from cuvs.neighbors import ivf_pq, refine
+    >>> n_samples = 50000
+    >>> n_features = 50
+    >>> n_queries = 1000
+    >>> dataset = cp.random.random_sample((n_samples, n_features),
+    ...                                   dtype=cp.float32)
+    >>> resources = Resources()
+    >>> index_params = ivf_pq.IndexParams(n_lists=1024,
+    ...                                   metric="sqeuclidean",
+    ...                                   pq_dim=10)
+    >>> index = ivf_pq.build(index_params, dataset, resources=resources)
+    >>> # Search using the built index
+    >>> queries = cp.random.random_sample((n_queries, n_features),
+    ...                                   dtype=cp.float32)
+    >>> k = 40
+    >>> _, candidates = ivf_pq.search(ivf_pq.SearchParams(), index,
+    ...                               queries, k, resources=resources)
+    >>> k = 10
+    >>> distances, neighbors = refine(dataset, queries, candidates, k)
+    """
+    cdef cuvsResources_t res = resources.get_c_obj()
+
+    if k is None:
+        if indices is not None:
+            k = wrap_array(indices).shape[1]
+        elif distances is not None:
+            k = wrap_array(distances).shape[1]
+        else:
+            raise ValueError("Argument k must be specified if both indices "
+                             "and distances args are None")
+
+    queries_cai = wrap_array(queries)
+    dataset_cai = wrap_array(dataset)
+    candidates_cai = wrap_array(candidates)
+    n_queries = wrap_array(queries).shape[0]
+
+    on_device = hasattr(dataset, "__cuda_array_interface__")
+    ndarray = device_ndarray if on_device else np
+    if indices is None:
+        indices = ndarray.empty((n_queries, k), dtype='int64')
+
+    if distances is None:
+        distances = ndarray.empty((n_queries, k), dtype='float32')
+
+    indices_cai = wrap_array(indices)
+    distances_cai = wrap_array(distances)
+
+    _check_input_array(indices_cai, [np.dtype('int64')],
+                       exp_rows=n_queries, exp_cols=k)
+    _check_input_array(distances_cai, [np.dtype('float32')],
+                       exp_rows=n_queries, exp_cols=k)
+
+    cdef cydlpack.DLManagedTensor* dataset_dlpack = \
+        cydlpack.dlpack_c(dataset_cai)
+    cdef cydlpack.DLManagedTensor* queries_dlpack = \
+        cydlpack.dlpack_c(queries_cai)
+    cdef cydlpack.DLManagedTensor* candidates_dlpack = \
+        cydlpack.dlpack_c(candidates_cai)
+    cdef cydlpack.DLManagedTensor* indices_dlpack = \
+        cydlpack.dlpack_c(indices_cai)
+    cdef cydlpack.DLManagedTensor* distances_dlpack = \
+        cydlpack.dlpack_c(distances_cai)
+
+    cdef cuvsDistanceType c_metric = DISTANCE_TYPES[metric]
+
+    with cuda_interruptible():
+        check_cuvs(cuvsRefine(
+            res,
+            dataset_dlpack,
+            queries_dlpack,
+            candidates_dlpack,
+            c_metric,
+            indices_dlpack,
+            distances_dlpack))
+
+    return (distances, indices)
diff --git a/python/cuvs/cuvs/test/test_refine.py b/python/cuvs/cuvs/test/test_refine.py
new file mode 100644
index 000000000..c7f9f678e
--- /dev/null
+++ b/python/cuvs/cuvs/test/test_refine.py
@@ -0,0 +1,239 @@
+# Copyright (c) 2022-2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pytest
+from pylibraft.common import device_ndarray
+from sklearn.neighbors import NearestNeighbors
+from sklearn.preprocessing import normalize
+from test_ivf_pq import calc_recall, generate_data
+
+from cuvs.neighbors import refine
+
+
+def run_refine(
+    n_rows=500,
+    n_cols=50,
+    n_queries=100,
+    metric="sqeuclidean",
+    k0=40,
+    k=10,
+    inplace=False,
+    dtype=np.float32,
+    memory_type="device",
+):
+
+    dataset = generate_data((n_rows, n_cols), dtype)
+    queries = generate_data((n_queries, n_cols), dtype)
+
+    if metric == "inner_product":
+        if dtype != np.float32:
+            pytest.skip("Normalized input cannot be represented in int8")
+            return
+        dataset = normalize(dataset, norm="l2", axis=1)
+        queries = normalize(queries, norm="l2", axis=1)
+
+    dataset_device = device_ndarray(dataset)
+    queries_device = device_ndarray(queries)
+
+    # Calculate reference values with sklearn
+    skl_metric = {"sqeuclidean": "euclidean", "inner_product": "cosine"}[
+        metric
+    ]
+    nn_skl = NearestNeighbors(
+        n_neighbors=k0, algorithm="brute", metric=skl_metric
+    )
+    nn_skl.fit(dataset)
+    skl_dist, candidates = nn_skl.kneighbors(queries)
+    candidates = candidates.astype(np.int64)
+    candidates_device = device_ndarray(candidates)
+
+    out_idx = np.zeros((n_queries, k), dtype=np.int64)
+    out_dist = np.zeros((n_queries, k), dtype=np.float32)
+    out_idx_device = device_ndarray(out_idx) if inplace else None
+    out_dist_device = device_ndarray(out_dist) if inplace else None
+
+    if memory_type == "device":
+        if inplace:
+            refine(
+                dataset_device,
+                queries_device,
+                candidates_device,
+                indices=out_idx_device,
+                distances=out_dist_device,
+                metric=metric,
+            )
+        else:
+            out_dist_device, out_idx_device = refine(
+                dataset_device,
+                queries_device,
+                candidates_device,
+                k=k,
+                metric=metric,
+            )
+        out_idx = out_idx_device.copy_to_host()
+        out_dist = out_dist_device.copy_to_host()
+    elif memory_type == "host":
+        if inplace:
+            refine(
+                dataset,
+                queries,
+                candidates,
+                indices=out_idx,
+                distances=out_dist,
+                metric=metric,
+            )
+        else:
+            out_dist, out_idx = refine(
+                dataset, queries, candidates, k=k, metric=metric
+            )
+
+    skl_idx = candidates[:, :k]
+
+    recall = calc_recall(out_idx, skl_idx)
+
+    if memory_type == "device" and dtype == np.float16:
+        # fp16 differences between host and device make the
+        # reference results from sklearn likely to be substantially different
+        # from those calculated on device.
+        assert recall >= 0.9
+
+    elif recall <= 0.999:
+        # We did not find the same neighbor indices.
+        # We could have found other neighbors with the same distance.
+        if metric == "sqeuclidean":
+            skl_dist = np.power(skl_dist[:, :k], 2)
+        elif metric == "inner_product":
+            skl_dist = 1 - skl_dist[:, :k]
+        else:
+            raise ValueError("Invalid metric")
+
+        mask = out_idx != skl_idx
+        assert np.all(out_dist[mask] <= skl_dist[mask] + 1.0e-6)
+
+
+@pytest.mark.parametrize("n_queries", [100, 1024, 37])
+@pytest.mark.parametrize("inplace", [True, False])
+@pytest.mark.parametrize("metric", ["sqeuclidean", "inner_product"])
+@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8, np.float16])
+@pytest.mark.parametrize("memory_type", ["device", "host"])
+def test_refine_dtypes(n_queries, dtype, inplace, metric, memory_type):
+    run_refine(
+        n_rows=2000,
+        n_queries=n_queries,
+        n_cols=50,
+        k0=40,
+        k=10,
+        dtype=dtype,
+        inplace=inplace,
+        metric=metric,
+        memory_type=memory_type,
+    )
+
+
+@pytest.mark.parametrize(
+    "params",
+    [
+        pytest.param(
+            {
+                "n_rows": 0,
+                "n_cols": 10,
+                "n_queries": 10,
+                "k0": 10,
+                "k": 1,
+            },
+            marks=pytest.mark.xfail(reason="empty dataset"),
+        ),
+        {"n_rows": 1, "n_cols": 10, "n_queries": 10, "k": 1, "k0": 1},
+        {"n_rows": 10, "n_cols": 1, "n_queries": 10, "k": 10, "k0": 10},
+        {"n_rows": 999, "n_cols": 42, "n_queries": 453, "k0": 137, "k": 53},
+    ],
+)
+@pytest.mark.parametrize("memory_type", ["device", "host"])
+def test_refine_row_col(params, memory_type):
+    run_refine(
+        n_rows=params["n_rows"],
+        n_queries=params["n_queries"],
+        n_cols=params["n_cols"],
+        k0=params["k0"],
+        k=params["k"],
+        memory_type=memory_type,
+    )
+
+
+@pytest.mark.parametrize("memory_type", ["device", "host"])
+def test_input_dtype(memory_type):
+    with pytest.raises(Exception):
+        run_refine(dtype=np.float64, memory_type=memory_type)
+
+
+@pytest.mark.parametrize(
+    "params",
+    [
+        {"idx_shape": None, "dist_shape": None, "k": None},
+        {"idx_shape": [100, 9], "dist_shape": None, "k": 10},
+        {"idx_shape": [101, 10], "dist_shape": None, "k": None},
+        {"idx_shape": None, "dist_shape": [100, 11], "k": 10},
+        {"idx_shape": None, "dist_shape": [99, 10], "k": None},
+    ],
+)
+@pytest.mark.parametrize("memory_type", ["device", "host"])
+def test_input_assertions(params, memory_type):
+    n_cols = 5
+    n_queries = 100
+    k0 = 40
+    dtype = np.float32
+    dataset = generate_data((500, n_cols), dtype)
+    dataset_device = device_ndarray(dataset)
+
+    queries = generate_data((n_queries, n_cols), dtype)
+    queries_device = device_ndarray(queries)
+
+    candidates = np.random.randint(
+        0, 500, size=(n_queries, k0), dtype=np.int64
+    )
+    candidates_device = device_ndarray(candidates)
+
+    if params["idx_shape"] is not None:
+        out_idx = np.zeros(params["idx_shape"], dtype=np.int64)
+        out_idx_device = device_ndarray(out_idx)
+    else:
+        out_idx = None
+        out_idx_device = None
+    if params["dist_shape"] is not None:
+        out_dist = np.zeros(params["dist_shape"], dtype=np.float32)
+        out_dist_device = device_ndarray(out_dist)
+    else:
+        out_dist = None
+        out_dist_device = None
+
+    if memory_type == "device":
+        with pytest.raises(Exception):
+            distances, indices = refine(
+                dataset_device,
+                queries_device,
+                candidates_device,
+                k=params["k"],
+                indices=out_idx_device,
+                distances=out_dist_device,
+            )
+    else:
+        with pytest.raises(Exception):
+            distances, indices = refine(
+                dataset,
+                queries,
+                candidates,
+                k=params["k"],
+                indices=out_idx,
+                distances=out_dist,
+            )
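
Usage note: below is a minimal host-memory sketch of the new Python refine wrapper added by this change. The shapes and the randomly generated candidate lists are placeholders for illustration only (in practice the candidates come from an approximate search such as ivf_pq.search, as in the docstring example); it assumes cuvs, pylibraft and numpy are installed.

    import numpy as np
    from cuvs.neighbors import refine

    n_rows, n_cols, n_queries, k0, k = 1000, 16, 8, 32, 5
    dataset = np.random.random_sample((n_rows, n_cols)).astype(np.float32)
    queries = np.random.random_sample((n_queries, n_cols)).astype(np.float32)
    # placeholder candidate indices; normally produced by an ANN search
    candidates = np.random.randint(0, n_rows, size=(n_queries, k0), dtype=np.int64)

    # host (numpy) inputs exercise the new host-memory path; CUDA arrays work too
    distances, indices = refine(dataset, queries, candidates, k=k, metric="sqeuclidean")
    assert indices.shape == (n_queries, k) and distances.shape == (n_queries, k)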