-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add refine to the Python and C api's (#175)
Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai/cuvs#175
- Loading branch information
1 parent
dc4e3cf
commit 96db438
Showing
10 changed files
with
682 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cuvs/core/c_api.h> | ||
#include <cuvs/distance/distance.h> | ||
#include <dlpack/dlpack.h> | ||
#include <stdint.h> | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
/** | ||
* @defgroup ann_refine_c Approximate Nearest Neighbors Refinement C-API | ||
* @{ | ||
*/ | ||
/** | ||
* @brief Refine nearest neighbor search. | ||
* | ||
* Refinement is an operation that follows an approximate NN search. The approximate search has | ||
* already selected n_candidates neighbor candidates for each query. We narrow it down to k | ||
* neighbors. For each query, we calculate the exact distance between the query and its | ||
* n_candidates neighbor candidate, and select the k nearest ones. | ||
* | ||
* @param[in] res cuvsResources_t opaque C handle | ||
* @param[in] dataset device matrix that stores the dataset [n_rows, dims] | ||
* @param[in] queries device matrix of the queries [n_queris, dims] | ||
* @param[in] candidates indices of candidate vectors [n_queries, n_candidates], where | ||
* n_candidates >= k | ||
* @param[in] metric distance metric to use. Euclidean (L2) is used by default | ||
* @param[out] indices device matrix that stores the refined indices [n_queries, k] | ||
* @param[out] distances device matrix that stores the refined distances [n_queries, k] | ||
*/ | ||
cuvsError_t cuvsRefine(cuvsResources_t res, | ||
DLManagedTensor* dataset, | ||
DLManagedTensor* queries, | ||
DLManagedTensor* candidates, | ||
cuvsDistanceType metric, | ||
DLManagedTensor* indices, | ||
DLManagedTensor* distances); | ||
/** | ||
* @} | ||
*/ | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <cstdint> | ||
#include <dlpack/dlpack.h> | ||
|
||
#include <raft/core/error.hpp> | ||
#include <raft/core/mdspan_types.hpp> | ||
#include <raft/core/resources.hpp> | ||
|
||
#include <cuvs/core/c_api.h> | ||
#include <cuvs/core/exceptions.hpp> | ||
#include <cuvs/core/interop.hpp> | ||
#include <cuvs/neighbors/refine.h> | ||
#include <cuvs/neighbors/refine.hpp> | ||
|
||
namespace { | ||
|
||
template <typename T> | ||
void _refine(bool on_device, | ||
cuvsResources_t res, | ||
DLManagedTensor* dataset_tensor, | ||
DLManagedTensor* queries_tensor, | ||
DLManagedTensor* candidates_tensor, | ||
cuvsDistanceType metric, | ||
DLManagedTensor* indices_tensor, | ||
DLManagedTensor* distances_tensor) | ||
{ | ||
auto res_ptr = reinterpret_cast<raft::resources*>(res); | ||
|
||
if (on_device) { | ||
using queries_type = raft::device_matrix_view<const T, int64_t, raft::row_major>; | ||
using candidates_type = raft::device_matrix_view<const int64_t, int64_t, raft::row_major>; | ||
using indices_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>; | ||
using distances_type = raft::device_matrix_view<float, int64_t, raft::row_major>; | ||
auto dataset = cuvs::core::from_dlpack<queries_type>(dataset_tensor); | ||
auto queries = cuvs::core::from_dlpack<queries_type>(queries_tensor); | ||
auto candidates = cuvs::core::from_dlpack<candidates_type>(candidates_tensor); | ||
auto indices = cuvs::core::from_dlpack<indices_type>(indices_tensor); | ||
auto distances = cuvs::core::from_dlpack<distances_type>(distances_tensor); | ||
cuvs::neighbors::refine(*res_ptr, dataset, queries, candidates, indices, distances, metric); | ||
} else { | ||
using queries_type = raft::host_matrix_view<const T, int64_t, raft::row_major>; | ||
using candidates_type = raft::host_matrix_view<const int64_t, int64_t, raft::row_major>; | ||
using indices_type = raft::host_matrix_view<int64_t, int64_t, raft::row_major>; | ||
using distances_type = raft::host_matrix_view<float, int64_t, raft::row_major>; | ||
auto dataset = cuvs::core::from_dlpack<queries_type>(dataset_tensor); | ||
auto queries = cuvs::core::from_dlpack<queries_type>(queries_tensor); | ||
auto candidates = cuvs::core::from_dlpack<candidates_type>(candidates_tensor); | ||
auto indices = cuvs::core::from_dlpack<indices_type>(indices_tensor); | ||
auto distances = cuvs::core::from_dlpack<distances_type>(distances_tensor); | ||
cuvs::neighbors::refine(*res_ptr, dataset, queries, candidates, indices, distances, metric); | ||
} | ||
} | ||
} // namespace | ||
|
||
extern "C" cuvsError_t cuvsRefine(cuvsResources_t res, | ||
DLManagedTensor* dataset_tensor, | ||
DLManagedTensor* queries_tensor, | ||
DLManagedTensor* candidates_tensor, | ||
cuvsDistanceType metric, | ||
DLManagedTensor* indices_tensor, | ||
DLManagedTensor* distances_tensor) | ||
{ | ||
return cuvs::core::translate_exceptions([=] { | ||
auto dataset = dataset_tensor->dl_tensor; | ||
auto queries = queries_tensor->dl_tensor; | ||
auto candidates = candidates_tensor->dl_tensor; | ||
auto indices = indices_tensor->dl_tensor; | ||
auto distances = distances_tensor->dl_tensor; | ||
|
||
// all matrices must either be on host or on device, can't mix and match | ||
bool on_device = cuvs::core::is_dlpack_device_compatible(dataset); | ||
if (on_device != cuvs::core::is_dlpack_device_compatible(queries) || | ||
on_device != cuvs::core::is_dlpack_device_compatible(candidates) || | ||
on_device != cuvs::core::is_dlpack_device_compatible(indices) || | ||
on_device != cuvs::core::is_dlpack_device_compatible(distances)) { | ||
RAFT_FAIL("Tensors must either all be on device memory, or all on host memory"); | ||
} | ||
|
||
RAFT_EXPECTS(candidates.dtype.code == kDLInt && candidates.dtype.bits == 64, | ||
"candidates should be of type int64_t"); | ||
RAFT_EXPECTS(indices.dtype.code == kDLInt && indices.dtype.bits == 64, | ||
"indices should be of type int64_t"); | ||
RAFT_EXPECTS(distances.dtype.code == kDLFloat && distances.dtype.bits == 32, | ||
"distances should be of type float32"); | ||
|
||
RAFT_EXPECTS(queries.dtype.code == dataset.dtype.code, | ||
"type mismatch between dataset and queries"); | ||
|
||
if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) { | ||
_refine<float>(on_device, | ||
res, | ||
dataset_tensor, | ||
queries_tensor, | ||
candidates_tensor, | ||
metric, | ||
indices_tensor, | ||
distances_tensor); | ||
} else if (queries.dtype.code == kDLFloat && queries.dtype.bits == 16) { | ||
_refine<half>(on_device, | ||
res, | ||
dataset_tensor, | ||
queries_tensor, | ||
candidates_tensor, | ||
metric, | ||
indices_tensor, | ||
distances_tensor); | ||
} else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) { | ||
_refine<int8_t>(on_device, | ||
res, | ||
dataset_tensor, | ||
queries_tensor, | ||
candidates_tensor, | ||
metric, | ||
indices_tensor, | ||
distances_tensor); | ||
} else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) { | ||
_refine<uint8_t>(on_device, | ||
res, | ||
dataset_tensor, | ||
queries_tensor, | ||
candidates_tensor, | ||
metric, | ||
indices_tensor, | ||
distances_tensor); | ||
} else { | ||
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d", | ||
queries.dtype.code, | ||
queries.dtype.bits); | ||
} | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# cython: language_level=3 | ||
|
||
from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t | ||
from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor | ||
from cuvs.distance_type cimport cuvsDistanceType | ||
|
||
|
||
cdef extern from "cuvs/neighbors/refine.h" nogil: | ||
cuvsError_t cuvsRefine(cuvsResources_t res, | ||
DLManagedTensor* dataset, | ||
DLManagedTensor* queries, | ||
DLManagedTensor* candidates, | ||
cuvsDistanceType metric, | ||
DLManagedTensor* indices, | ||
DLManagedTensor* distances) except + |
Oops, something went wrong.