Skip to content

Commit

Permalink
[#24085] DocDB: Refactor HNSW wrappers
Browse files Browse the repository at this point in the history
Summary:
We don't use the HNSW wrappers directly; callers go through VectorIndexIf instead.
So each wrapper can be fully defined in its .cc file, keeping only a factory in the header.
Jira: DB-12979

Test Plan: Jenkins

Reviewers: mbautin

Reviewed By: mbautin

Subscribers: ybase

Tags: #jenkins-ready

Differential Revision: https://phorge.dev.yugabyte.com/D38256
  • Loading branch information
spolitov committed Sep 23, 2024
1 parent 039c9a2 commit 09f7a0f
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 90 deletions.
4 changes: 2 additions & 2 deletions src/yb/tools/hnsw_tool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -679,13 +679,13 @@ std::optional<Status> BenchmarkExecuteHelper(
if (args.ann_method == ann_method_kind &&
args.hnsw_options.distance_kind == distance_kind &&
input_coordinate_kind == CoordinateTypeTraits<typename InputVector::value_type>::kKind) {
using IndexType = typename ANNMethodTraits<ann_method_kind>::template IndexType<
using FactoryType = typename ANNMethodTraits<ann_method_kind>::template FactoryType<
IndexedVector,
typename DistanceTraits<IndexedVector, distance_kind>::Result>;
return BenchmarkTool<InputVector, InputDistanceResult, IndexedVector, IndexedDistanceResult>(
args,
[](const HNSWOptions& options) {
return CreateIndexFactory<IndexType>(options);
return std::bind(&FactoryType::Create, options);
}
).Execute();
}
Expand Down
4 changes: 2 additions & 2 deletions src/yb/vector/ann_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ struct ANNMethodTraits {
template<>
struct ANNMethodTraits<ANNMethodKind::kUsearch> {
template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
using IndexType = UsearchIndex<Vector, DistanceResult>;
using FactoryType = UsearchIndexFactory<Vector, DistanceResult>;
};

template<>
struct ANNMethodTraits<ANNMethodKind::kHnswlib> {
template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
using IndexType = HnswlibIndex<Vector, DistanceResult>;
using FactoryType = HnswlibIndexFactory<Vector, DistanceResult>;
};

} // namespace yb::vectorindex
31 changes: 15 additions & 16 deletions src/yb/vector/hnswlib_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,20 @@

namespace yb::vectorindex {

namespace detail {
namespace {

template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
class HnswlibIndexImpl {
class HnswlibIndex : public VectorIndexIf<Vector, DistanceResult> {
public:
using Scalar = typename Vector::value_type;

using HNSWImpl = typename hnswlib::HierarchicalNSW<DistanceResult>;

explicit HnswlibIndexImpl(const HNSWOptions& options)
explicit HnswlibIndex(const HNSWOptions& options)
: options_(options) {
}

Status Reserve(size_t num_vectors) {
Status Reserve(size_t num_vectors) override {
if (hnsw_) {
return STATUS_FORMAT(
IllegalState, "Cannot reserve space for $0 vectors: Hnswlib index already initialized",
Expand All @@ -64,13 +65,13 @@ class HnswlibIndexImpl {
return Status::OK();
}

Status Insert(VertexId vertex_id, const Vector& v) {
Status Insert(VertexId vertex_id, const Vector& v) override {
hnsw_->addPoint(v.data(), vertex_id);
return Status::OK();
}

std::vector<VertexWithDistance<DistanceResult>> Search(
const Vector& query_vector, size_t max_num_results) {
const Vector& query_vector, size_t max_num_results) const override {
std::vector<VertexWithDistance<DistanceResult>> result;
auto tmp_result = hnsw_->searchKnnCloserFirst(query_vector.data(), max_num_results);
result.reserve(tmp_result.size());
Expand All @@ -87,7 +88,7 @@ class HnswlibIndexImpl {
return result;
}

Result<Vector> GetVector(VertexId vertex_id) const {
Result<Vector> GetVector(VertexId vertex_id) const override {
return STATUS(
NotSupported, "Hnswlib wrapper currently does not allow retriving vectors by id");
}
Expand Down Expand Up @@ -121,17 +122,15 @@ class HnswlibIndexImpl {
std::unique_ptr<HNSWImpl> hnsw_;
};

} // namespace detail
} // namespace

template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
HnswlibIndex<Vector, DistanceResult>::HnswlibIndex(const HNSWOptions& options)
: VectorIndexBase<Impl, Vector, DistanceResult>(std::make_unique<Impl>(options)) {
template <IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
VectorIndexIfPtr<Vector, DistanceResult> HnswlibIndexFactory<Vector, DistanceResult>::Create(
const HNSWOptions& options) {
return std::make_shared<HnswlibIndex<Vector, DistanceResult>>(options);
}

template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
HnswlibIndex<Vector, DistanceResult>::~HnswlibIndex() = default;

template class HnswlibIndex<FloatVector, float>;
template class HnswlibIndex<UInt8Vector, int32_t>;
template class HnswlibIndexFactory<FloatVector, float>;
template class HnswlibIndexFactory<UInt8Vector, int32_t>;

} // namespace yb::vectorindex
13 changes: 2 additions & 11 deletions src/yb/vector/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,10 @@

namespace yb::vectorindex {

namespace detail {
template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
class HnswlibIndexImpl;
} // namespace detail

template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
class HnswlibIndex : public VectorIndexBase<
detail::HnswlibIndexImpl<Vector, DistanceResult>, Vector, DistanceResult> {
class HnswlibIndexFactory {
public:
explicit HnswlibIndex(const HNSWOptions& options);
virtual ~HnswlibIndex();
private:
using Impl = detail::HnswlibIndexImpl<Vector, DistanceResult>;
static VectorIndexIfPtr<Vector, DistanceResult> Create(const HNSWOptions& options);
};

} // namespace yb::vectorindex
28 changes: 13 additions & 15 deletions src/yb/vector/usearch_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ scalar_kind_t ConvertCoordinateKind(CoordinateKind coordinate_kind) {
FATAL_INVALID_ENUM_VALUE(CoordinateKind, coordinate_kind);
}

namespace detail {
namespace {

template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
class UsearchIndexImpl {
class UsearchIndex : public VectorIndexIf<Vector, DistanceResult> {
public:
explicit UsearchIndexImpl(const HNSWOptions& options)
explicit UsearchIndex(const HNSWOptions& options)
: dimensions_(options.dimensions),
distance_kind_(options.distance_kind),
metric_(dimensions_,
Expand All @@ -81,20 +81,20 @@ class UsearchIndexImpl {
CHECK_GT(dimensions_, 0);
}

Status Reserve(size_t num_vectors) {
Status Reserve(size_t num_vectors) override {
index_.reserve(num_vectors);
return Status::OK();
}

Status Insert(VertexId vertex_id, const Vector& v) {
Status Insert(VertexId vertex_id, const Vector& v) override {
if (!index_.add(vertex_id, v.data())) {
return STATUS_FORMAT(RuntimeError, "Failed to add a vector");
}
return Status::OK();
}

std::vector<VertexWithDistance<DistanceResult>> Search(
const Vector& query_vector, size_t max_num_results) {
const Vector& query_vector, size_t max_num_results) const override {
auto usearch_results = index_.search(query_vector.data(), max_num_results);
std::vector<VertexWithDistance<DistanceResult>> result_vec;
result_vec.reserve(usearch_results.size());
Expand All @@ -105,7 +105,7 @@ class UsearchIndexImpl {
return result_vec;
}

Result<Vector> GetVector(VertexId vertex_id) const {
Result<Vector> GetVector(VertexId vertex_id) const override {
Vector result;
result.resize(dimensions_);
if (index_.get(vertex_id, result.data())) {
Expand All @@ -121,16 +121,14 @@ class UsearchIndexImpl {
index_dense_gt<VertexId> index_;
};

} // namespace detail
} // namespace

template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
UsearchIndex<Vector, DistanceResult>::UsearchIndex(const HNSWOptions& options)
: VectorIndexBase<Impl, Vector, DistanceResult>(std::make_unique<Impl>(options)) {
template <class Vector, class DistanceResult>
VectorIndexIfPtr<Vector, DistanceResult> UsearchIndexFactory<Vector, DistanceResult>::Create(
const HNSWOptions& options) {
return std::make_shared<UsearchIndex<Vector, DistanceResult>>(options);
}

template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
UsearchIndex<Vector, DistanceResult>::~UsearchIndex() = default;

template class UsearchIndex<FloatVector, float>;
template class UsearchIndexFactory<FloatVector, float>;

} // namespace yb::vectorindex
15 changes: 3 additions & 12 deletions src/yb/vector/usearch_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,10 @@

namespace yb::vectorindex {

namespace detail {
template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
class UsearchIndexImpl;
} // namespace detail

template<IndexableVectorType Vector, ValidDistanceResultType DistanceResult>
class UsearchIndex : public VectorIndexBase<
detail::UsearchIndexImpl<Vector, DistanceResult>, Vector, DistanceResult> {
template <class Vector, class DistanceResult>
class UsearchIndexFactory {
public:
explicit UsearchIndex(const HNSWOptions& options);
virtual ~UsearchIndex();
private:
using Impl = detail::UsearchIndexImpl<Vector, DistanceResult>;
static VectorIndexIfPtr<Vector, DistanceResult> Create(const HNSWOptions& options);
};

} // namespace yb::vectorindex
32 changes: 0 additions & 32 deletions src/yb/vector/vector_index_wrapper_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,6 @@

namespace yb::vectorindex {

// A base class for vector index implementations implementing the pointer-to-implementation idiom.
template <typename Impl, typename Vector, typename DistanceResult>
class VectorIndexBase : public VectorIndexIf<Vector, DistanceResult> {
public:
explicit VectorIndexBase(std::unique_ptr<Impl> impl)
: impl_(std::move(impl)) {}

~VectorIndexBase() override = default;

// Implementations for the VectorIndexReaderIf interface
std::vector<VertexWithDistance<DistanceResult>> Search(
const Vector& query_vector, size_t max_num_results) const override {
return impl_->Search(query_vector, max_num_results);
}

// Implementations for the VectorIndexWriterIf interface
Status Reserve(size_t num_vectors) override {
return impl_->Reserve(num_vectors);
}

Status Insert(VertexId vertex_id, const Vector& vector) override {
return impl_->Insert(vertex_id, vector);
}

Result<Vector> GetVector(VertexId vertex_id) const override {
return impl_->GetVector(vertex_id);
}

protected:
std::unique_ptr<Impl> impl_;
};

// An adapter that allows us to view an index reader with one vector type as an index reader with a
// different vector type. Casts the queries to the vector type supported by the index, and then
// casts the distance type in the results to the distance type expected by the caller.
Expand Down

0 comments on commit 09f7a0f

Please sign in to comment.