Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add col-major support for brute force knn #217

Merged
merged 4 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 64 additions & 8 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ struct index : cuvs::neighbors::index {
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
* the dataset, providing a speed benefit over doing this at query time.
* This index will store a non-owning reference to the dataset.
* This index will copy the host dataset onto the device, and take ownership of any
* precaculated norms.
*/
index(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset_view,
Expand All @@ -61,7 +62,8 @@ struct index : cuvs::neighbors::index {
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
* the dataset, providing a speed benefit over doing this at query time.
* The dataset will be copied to the device and the index will own the device memory.
* This index will store a non-owning reference to the dataset, but will move
* any norms supplied.
*/
index(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset_view,
Expand All @@ -71,7 +73,7 @@ struct index : cuvs::neighbors::index {

/** Construct a brute force index from dataset
*
* This class stores a non-owning reference to the dataset and norms here.
* This class stores a non-owning reference to the dataset and norms.
* Having precomputed norms gives us a performance advantage at query time.
*/
index(raft::resources const& res,
Expand All @@ -80,6 +82,17 @@ struct index : cuvs::neighbors::index {
cuvs::distance::DistanceType metric,
T metric_arg = 0.0);

/** Construct a brute force index from dataset
*
* This class stores a non-owning reference to the dataset and norms, with
* the dataset being supplied on device in a col_major format
*/
index(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::col_major> dataset_view,
std::optional<raft::device_vector<T, int64_t>>&& norms,
cuvs::distance::DistanceType metric,
T metric_arg = 0.0);

/**
* Replace the dataset with a new dataset.
*/
Expand Down Expand Up @@ -152,12 +165,34 @@ struct index : cuvs::neighbors::index {
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed ivf-flat index
* @return the constructed bruteforce index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float>;

/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* @endcode
*
* @param[in] handle
* @param[in] dataset a device pointer to a col-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed bruteforce index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float>;
/**
* @}
*/
Expand All @@ -169,7 +204,7 @@ auto build(raft::resources const& handle,
/**
* @brief Search ANN using the constructed index.
*
* See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example.
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
Expand All @@ -186,20 +221,41 @@ auto build(raft::resources const& handle,
* @endcode
*
* @param[in] handle
* @param[in] index ivf-flat constructed index
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter a optional device bitmap filter function that greenlights samples for a
* given
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> sample_filter);

/**
* @brief Search ANN using the constructed index.
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* @param[in] handle
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float>& index,
raft::device_matrix_view<const float, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> sample_filter);
/**
* @}
*/
Expand Down
52 changes: 52 additions & 0 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,36 @@ index<T>::index(raft::resources const& res,
{
}

template <typename T>
index<T>::index(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::col_major> dataset_view,
std::optional<raft::device_vector<T, int64_t>>&& norms,
cuvs::distance::DistanceType metric,
T metric_arg)
: cuvs::neighbors::index(),
metric_(metric),
dataset_(
raft::make_device_matrix<T, int64_t>(res, dataset_view.extent(0), dataset_view.extent(1))),
norms_(std::move(norms)),
metric_arg_(metric_arg)
{
// currently we don't support col_major inside tiled_brute_force_knn, because
// of limitations of the pairwise_distance API:
// 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have
// multiple options here (both dataset and queries)
// 2) because of tiling, we need to be able to set a custom stride in the PW
// api, which isn't supported
// Instead, transpose the input matrices if they are passed as col-major.
// (note: we're doing the transpose here to avoid doing per query)
raft::linalg::transpose(res,
const_cast<T*>(dataset_view.data_handle()),
dataset_.data_handle(),
dataset_view.extent(0),
dataset_view.extent(1),
raft::resource::get_cuda_stream(res));
dataset_view_ = raft::make_const_mdspan(dataset_.view());
}

template <typename T>
void index<T>::update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset)
Expand All @@ -93,6 +123,14 @@ void index<T>::update_dataset(raft::resources const& res,
->cuvs::neighbors::brute_force::index<T> \
{ \
return detail::build<T>(res, dataset, metric, metric_arg); \
} \
auto build(raft::resources const& res, \
raft::device_matrix_view<const T, int64_t, raft::col_major> dataset, \
cuvs::distance::DistanceType metric, \
T metric_arg) \
->cuvs::neighbors::brute_force::index<T> \
{ \
return detail::build<T>(res, dataset, metric, metric_arg); \
} \
\
void search( \
Expand All @@ -109,6 +147,20 @@ void index<T>::update_dataset(raft::resources const& res,
detail::brute_force_search_filtered<T, int64_t>( \
res, idx, queries, *sample_filter, neighbors, distances); \
} \
} \
void search( \
raft::resources const& res, \
const cuvs::neighbors::brute_force::index<T>& idx, \
raft::device_matrix_view<const T, int64_t, raft::col_major> queries, \
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<T, int64_t, raft::row_major> distances, \
std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> sample_filter = std::nullopt) \
{ \
if (!sample_filter.has_value()) { \
detail::brute_force_search<T, int64_t>(res, idx, queries, neighbors, distances); \
} else { \
RAFT_FAIL("filtered search isn't available with col_major queries yet"); \
} \
} \
\
template struct cuvs::neighbors::brute_force::index<T>;
Expand Down
27 changes: 8 additions & 19 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -513,11 +513,11 @@ void brute_force_knn_impl(
if (translations == nullptr) delete id_ranges;
};

template <typename T, typename IdxT>
template <typename T, typename IdxT, typename QueryLayoutT = raft::row_major>
void brute_force_search(
raft::resources const& res,
const cuvs::neighbors::brute_force::index<T>& idx,
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<const T, int64_t, QueryLayoutT> queries,
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<T, int64_t, raft::row_major> distances,
std::optional<raft::device_vector_view<const T, int64_t>> query_norms = std::nullopt)
Expand All @@ -544,7 +544,7 @@ void brute_force_search(
distances.data_handle(),
k,
true,
true,
std::is_same_v<QueryLayoutT, raft::row_major>,
nullptr,
idx.metric(),
idx.metric_arg(),
Expand Down Expand Up @@ -719,43 +719,32 @@ void brute_force_search_filtered(
return;
}

template <typename T>
template <typename T, typename LayoutT = raft::row_major>
cuvs::neighbors::brute_force::index<T> build(
raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset,
raft::device_matrix_view<const T, int64_t, LayoutT> dataset,
cuvs::distance::DistanceType metric,
T metric_arg)
{
// certain distance metrics can benefit by pre-calculating the norms for the index dataset
// which lets us avoid calculating these at query time
std::optional<raft::device_vector<T, int64_t>> norms;
auto dataset_storage = std::optional<raft::device_matrix<T, int64_t>>{};
auto dataset_view = [&res, &dataset_storage, dataset]() {
if constexpr (std::is_same_v<decltype(dataset),
raft::device_matrix_view<const T, int64_t, raft::row_major>>) {
return dataset;
} else {
dataset_storage =
raft::make_device_matrix<T, int64_t>(res, dataset.extent(0), dataset.extent(1));
raft::copy(res, dataset_storage->view(), dataset);
return raft::make_const_mdspan(dataset_storage->view());
}
}();

if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
metric == cuvs::distance::DistanceType::CosineExpanded) {
norms = raft::make_device_vector<T, int64_t>(res, dataset.extent(0));
// cosine needs the l2norm, where as l2 distances needs the squared norm
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::norm(res,
dataset_view,
dataset,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS,
raft::sqrt_op{});
} else {
raft::linalg::norm(res,
dataset_view,
dataset,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS);
Expand Down
Loading
Loading