From add2f4263c6aca4392374bcbfc0818f3005b9f03 Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Tue, 15 Oct 2024 11:54:04 +0800 Subject: [PATCH 01/11] support int8 for hnsw Signed-off-by: jinjiabao.jjb --- include/vsag/constants.h | 1 + src/algorithm/hnswlib/hnswalg.cpp | 2 +- src/algorithm/hnswlib/space_ip.h | 21 ++-- src/algorithm/hnswlib/space_l2.h | 8 +- src/constants.cpp | 1 + src/factory/factory.cpp | 2 + src/index/hnsw.cpp | 37 +++++-- src/index/hnsw.h | 6 ++ src/index/hnsw_test.cpp | 128 ++++++++++++++++++------- src/index/hnsw_zparameters.cpp | 27 ++++-- src/index/hnsw_zparameters.h | 2 + src/simd/avx512.cpp | 113 ++++++++++++++++++++++ src/simd/avx512_test.cpp | 43 +++++++++ src/simd/generic.cpp | 17 ++++ src/simd/simd.cpp | 12 +++ src/simd/simd.h | 15 +++ tests/performance/test_performance.cpp | 102 ++++++++++++++------ tests/test_index.cpp | 53 ++++++++++ 18 files changed, 496 insertions(+), 94 deletions(-) create mode 100644 src/simd/avx512_test.cpp diff --git a/include/vsag/constants.h b/include/vsag/constants.h index 32889a2a..c1d5d4be 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -39,6 +39,7 @@ extern const char* const METRIC_L2; extern const char* const METRIC_COSINE; extern const char* const METRIC_IP; extern const char* const DATATYPE_FLOAT32; +extern const char* const DATATYPE_INT8; extern const char* const BLANK_INDEX; // parameters diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index 5af9036c..2bc3289f 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -40,7 +40,7 @@ HierarchicalNSW::HierarchicalNSW(SpaceInterface* s, data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); - dim_ = data_size_ / sizeof(float); + dim_ = *((size_t*)dist_func_param_); M_ = M; maxM_ = M_; maxM0_ = M_ * 2; diff --git a/src/algorithm/hnswlib/space_ip.h b/src/algorithm/hnswlib/space_ip.h index 702121d6..5b92e60a 100644 --- a/src/algorithm/hnswlib/space_ip.h +++ b/src/algorithm/hnswlib/space_ip.h @@ -14,26 +14,27 @@ // limitations under the License. #pragma once +#include "data_type.h" +#include "simd/simd.h" #include "space_interface.h" -namespace vsag { - -extern hnswlib::DISTFUNC -GetInnerProductDistanceFunc(size_t dim); - -} // namespace vsag - namespace hnswlib { + class InnerProductSpace : public SpaceInterface { DISTFUNC fstdistfunc_; size_t data_size_; size_t dim_; public: - explicit InnerProductSpace(size_t dim) { - fstdistfunc_ = vsag::GetInnerProductDistanceFunc(dim); + explicit InnerProductSpace(size_t dim, vsag::DataTypes type) { dim_ = dim; - data_size_ = dim * sizeof(float); + if (type == vsag::DataTypes::DATA_TYPE_FLOAT) { + fstdistfunc_ = vsag::GetInnerProductDistanceFunc(dim); + data_size_ = dim * sizeof(float); + } else if (type == vsag::DataTypes::DATA_TYPE_INT8) { + fstdistfunc_ = vsag::GetINT8InnerProductDistanceFunc(dim); + data_size_ = dim * sizeof(int8_t); + } } size_t diff --git a/src/algorithm/hnswlib/space_l2.h b/src/algorithm/hnswlib/space_l2.h index 08f77f4c..edae3f57 100644 --- a/src/algorithm/hnswlib/space_l2.h +++ b/src/algorithm/hnswlib/space_l2.h @@ -14,15 +14,9 @@ // limitations under the License. #pragma once +#include "simd/simd.h" #include "space_interface.h" -namespace vsag { - -extern hnswlib::DISTFUNC -GetL2DistanceFunc(size_t dim); - -} // namespace vsag - namespace hnswlib { class L2Space : public SpaceInterface { diff --git a/src/constants.cpp b/src/constants.cpp index aec320f9..42017057 100644 --- a/src/constants.cpp +++ b/src/constants.cpp @@ -40,6 +40,7 @@ const char* const METRIC_L2 = "l2"; const char* const METRIC_COSINE = "cosine"; const char* const METRIC_IP = "ip"; const char* const DATATYPE_FLOAT32 = "float32"; +const char* const DATATYPE_INT8 = "int8"; const char* const BLANK_INDEX = "blank_index"; // parameters diff --git a/src/factory/factory.cpp b/src/factory/factory.cpp index f66f9bec..c65cd38e 100644 --- a/src/factory/factory.cpp +++ b/src/factory/factory.cpp @@ -51,6 +51,7 @@ Factory::CreateIndex(const std::string& origin_name, return std::make_shared(params.space, params.max_degree, params.ef_construction, + params.type, params.use_static, false, params.use_conjugate_graph, @@ -63,6 +64,7 @@ Factory::CreateIndex(const std::string& origin_name, return std::make_shared(params.space, params.max_degree, params.ef_construction, + params.type, params.use_static, true, false, diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 7401b72c..c71f5404 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -47,6 +47,7 @@ const static float GENERATE_OMEGA = 0.51; HNSW::HNSW(std::shared_ptr space_interface, int M, int ef_construction, + DataTypes type, bool use_static, bool use_reversed_edges, bool use_conjugate_graph, @@ -55,7 +56,8 @@ HNSW::HNSW(std::shared_ptr space_interface, : space_(std::move(space_interface)), use_static_(use_static), use_conjugate_graph_(use_conjugate_graph), - use_reversed_edges_(use_reversed_edges) { + use_reversed_edges_(use_reversed_edges), + type_(type) { dim_ = *((size_t*)space_->get_dist_func_param()); M = std::min(std::max(M, MINIMAL_M), MAXIMAL_M); @@ -120,13 +122,15 @@ HNSW::build(const DatasetPtr& base) { } auto ids = base->GetIds(); - auto vectors = base->GetFloat32Vectors(); + void* vectors = nullptr; + size_t data_size = 0; + get_vectors(base, vectors, data_size); std::vector failed_ids; { SlowTaskTimer t("hnsw graph"); for (int64_t i = 0; i < num_elements; ++i) { // noexcept runtime - if (!alg_hnsw_->addPoint((const void*)(vectors + i * dim_), ids[i])) { + if (!alg_hnsw_->addPoint((const void*)((char*)vectors + data_size * i), ids[i])) { logger::debug("duplicate point: {}", ids[i]); failed_ids.emplace_back(ids[i]); } @@ -160,7 +164,9 @@ HNSW::add(const DatasetPtr& base) { int64_t num_elements = base->GetNumElements(); auto ids = base->GetIds(); - auto vectors = base->GetFloat32Vectors(); + void* vectors = nullptr; + size_t data_size = 0; + get_vectors(base, vectors, data_size); std::vector failed_ids; std::unique_lock lock(rw_mutex_); @@ -169,7 +175,7 @@ HNSW::add(const DatasetPtr& base) { } for (int64_t i = 0; i < num_elements; ++i) { // noexcept runtime - if (!alg_hnsw_->addPoint((const void*)(vectors + i * dim_), ids[i])) { + if (!alg_hnsw_->addPoint((const void*)((char*)vectors + data_size * i), ids[i])) { logger::debug("duplicate point: {}", i); failed_ids.push_back(ids[i]); } @@ -213,7 +219,9 @@ HNSW::knn_search(const DatasetPtr& query, // check query vector CHECK_ARGUMENT(query->GetNumElements() == 1, "query dataset should contain 1 vector only"); - auto vector = query->GetFloat32Vectors(); + void* vector = nullptr; + size_t data_size = 0; + get_vectors(query, vector, data_size); int64_t query_dim = query->GetDim(); CHECK_ARGUMENT( query_dim == dim_, @@ -331,7 +339,9 @@ HNSW::range_search(const DatasetPtr& query, // check query vector CHECK_ARGUMENT(query->GetNumElements() == 1, "query dataset should contain 1 vector only"); - auto vector = query->GetFloat32Vectors(); + void* vector = nullptr; + size_t data_size = 0; + get_vectors(query, vector, data_size); int64_t query_dim = query->GetDim(); CHECK_ARGUMENT( query_dim == dim_, @@ -800,4 +810,17 @@ HNSW::init_memory_space() { return true; } +void +HNSW::get_vectors(const vsag::DatasetPtr& base, void*& vectors, size_t& data_size) const { + if (type_ == DataTypes::DATA_TYPE_FLOAT) { + vectors = (void*)base->GetFloat32Vectors(); + data_size = dim_ * sizeof(float); + } else if (type_ == DataTypes::DATA_TYPE_INT8) { + vectors = (void*)base->GetInt8Vectors(); + data_size = dim_ * sizeof(int8_t); + } else { + throw std::invalid_argument("fail to support this metric"); + } +} + } // namespace vsag diff --git a/src/index/hnsw.h b/src/index/hnsw.h index a822a131..61e65f26 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -28,6 +28,7 @@ #include "../algorithm/hnswlib/hnswlib.h" #include "../common.h" +#include "../data_type.h" #include "../default_allocator.h" #include "../impl/conjugate_graph.h" #include "../logger.h" @@ -70,6 +71,7 @@ class HNSW : public Index { HNSW(std::shared_ptr space_interface, int M, int ef_construction, + DataTypes type, bool use_static = false, bool use_reversed_edges = false, bool use_conjugate_graph = false, @@ -280,6 +282,9 @@ class HNSW : public Index { tl::expected init_memory_space(); + void + get_vectors(const DatasetPtr& base, void*& vectors, size_t& data_size) const; + BinarySet empty_binaryset() const; @@ -295,6 +300,7 @@ class HNSW : public Index { bool empty_index_ = false; bool use_reversed_edges_ = false; bool is_init_memory_ = false; + DataTypes type_; std::shared_ptr allocator_; diff --git a/src/index/hnsw_test.cpp b/src/index/hnsw_test.cpp index 1e7724c9..58f91558 100644 --- a/src/index/hnsw_test.cpp +++ b/src/index/hnsw_test.cpp @@ -20,6 +20,7 @@ #include #include +#include "../data_type.h" #include "../logger.h" #include "fixtures.h" #include "vsag/bitset.h" @@ -32,8 +33,10 @@ TEST_CASE("build & add", "[ut][hnsw]") { int64_t dim = 128; int64_t max_degree = 12; int64_t ef_construction = 100; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT); std::vector ids(1); int64_t incorrect_dim = 63; @@ -66,8 +69,11 @@ TEST_CASE("build with allocator", "[ut][hnsw]") { int64_t max_degree = 12; int64_t ef_construction = 100; vsag::DefaultAllocator allocator; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction, &allocator); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + &allocator); const int64_t num_elements = 10; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); @@ -88,8 +94,10 @@ TEST_CASE("knn_search", "[ut][hnsw]") { int64_t dim = 128; int64_t max_degree = 12; int64_t ef_construction = 100; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT); const int64_t num_elements = 10; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); @@ -161,8 +169,10 @@ TEST_CASE("range_search", "[ut][hnsw]") { int64_t dim = 128; int64_t max_degree = 12; int64_t ef_construction = 100; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT); const int64_t num_elements = 10; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); @@ -257,8 +267,10 @@ TEST_CASE("serialize empty index", "[ut][hnsw]") { int64_t dim = 128; int64_t max_degree = 12; int64_t ef_construction = 100; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT); SECTION("serialize to binaryset") { auto result = index->Serialize(); @@ -281,8 +293,13 @@ TEST_CASE("deserialize on not empty index", "[ut][hnsw]") { int64_t dim = 128; int64_t max_degree = 12; int64_t ef_construction = 100; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction, false, false, true); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + false, + false, + true); const int64_t num_elements = 10; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); @@ -307,6 +324,7 @@ TEST_CASE("deserialize on not empty index", "[ut][hnsw]") { auto another_index = std::make_shared(std::make_shared(dim), max_degree, ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, false, false, true); @@ -335,8 +353,11 @@ TEST_CASE("static hnsw", "[ut][hnsw]") { int64_t dim = 128; int64_t max_degree = 12; int64_t ef_construction = 100; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction, true); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + true); const int64_t num_elements = 10; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); @@ -371,8 +392,11 @@ TEST_CASE("static hnsw", "[ut][hnsw]") { REQUIRE_FALSE(range_result.has_value()); REQUIRE(range_result.error().type == vsag::ErrorType::UNSUPPORTED_INDEX_OPERATION); - REQUIRE_THROWS(std::make_shared( - std::make_shared(127), max_degree, ef_construction, true)); + REQUIRE_THROWS(std::make_shared(std::make_shared(127), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + true)); auto remove_result = index->Remove(ids[0]); REQUIRE_FALSE(remove_result.has_value()); @@ -385,8 +409,10 @@ TEST_CASE("hnsw add vector with duplicated id", "[ut][hnsw]") { int64_t dim = 128; int64_t max_degree = 12; int64_t ef_construction = 100; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT); std::vector ids{1}; std::vector vectors(dim); @@ -420,8 +446,12 @@ TEST_CASE("build with reversed edges", "[ut][hnsw]") { int64_t dim = 128; int64_t max_degree = 12; int64_t ef_construction = 100; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction, false, true); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + false, + true); const int64_t num_elements = 1000; auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_elements, dim); @@ -451,8 +481,12 @@ TEST_CASE("build with reversed edges", "[ut][hnsw]") { in_file.seekg(0, std::ios::end); int64_t length = in_file.tellg(); in_file.seekg(0, std::ios::beg); - auto new_index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction, false, true); + auto new_index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + false, + true); REQUIRE(new_index->Deserialize(in_file).has_value()); REQUIRE(new_index->CheckGraphIntegrity()); } @@ -498,8 +532,12 @@ TEST_CASE("build with reversed edges", "[ut][hnsw]") { bs.Set(key, b); } - auto new_index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction, false, true); + auto new_index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + false, + true); REQUIRE(new_index->Deserialize(bs).has_value()); REQUIRE(new_index->CheckGraphIntegrity()); } @@ -514,8 +552,13 @@ TEST_CASE("feedback with invalid argument", "[ut][hnsw]") { int64_t num_vectors = 1000; int64_t k = 10; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction, false, false, true); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + false, + false, + true); nlohmann::json search_parameters{ {"hnsw", {{"ef_search", 200}}}, @@ -549,8 +592,13 @@ TEST_CASE("redundant feedback and empty enhancement", "[ut][hnsw]") { int64_t num_query = 1; int64_t k = 10; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction, false, false, true); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + false, + false, + true); auto [base_ids, base_vectors] = fixtures::generate_ids_and_vectors(num_base, dim); auto base = vsag::Dataset::Make(); @@ -605,8 +653,10 @@ TEST_CASE("feedback and pretrain without use conjugate graph", "[ut][hnsw]") { int64_t num_query = 1; int64_t k = 10; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT); auto [base_ids, base_vectors] = fixtures::generate_ids_and_vectors(num_base, dim); auto base = vsag::Dataset::Make(); @@ -647,8 +697,13 @@ TEST_CASE("feedback and pretrain on empty index", "[ut][hnsw]") { int64_t num_query = 1; int64_t k = 100; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction, false, false, true); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + false, + false, + true); auto [base_ids, base_vectors] = fixtures::generate_ids_and_vectors(num_base, dim); auto base = vsag::Dataset::Make(); @@ -689,8 +744,13 @@ TEST_CASE("invalid pretrain", "[ut][hnsw]") { int64_t num_query = 1; int64_t k = 100; - auto index = std::make_shared( - std::make_shared(dim), max_degree, ef_construction, false, false, true); + auto index = std::make_shared(std::make_shared(dim), + max_degree, + ef_construction, + vsag::DataTypes::DATA_TYPE_FLOAT, + false, + false, + true); auto [base_ids, base_vectors] = fixtures::generate_ids_and_vectors(num_base, dim); auto base = vsag::Dataset::Make(); diff --git a/src/index/hnsw_zparameters.cpp b/src/index/hnsw_zparameters.cpp index 5b2dde68..114953e2 100644 --- a/src/index/hnsw_zparameters.cpp +++ b/src/index/hnsw_zparameters.cpp @@ -20,6 +20,7 @@ #include #include "../common.h" +#include "../data_type.h" #include "vsag/constants.h" namespace vsag { @@ -30,26 +31,38 @@ CreateHnswParameters::FromJson(const std::string& json_string) { CHECK_ARGUMENT(params.contains(PARAMETER_DTYPE), fmt::format("parameters must contains {}", PARAMETER_DTYPE)); - CHECK_ARGUMENT( - params[PARAMETER_DTYPE] == DATATYPE_FLOAT32, - fmt::format("parameters[{}] supports {} only now", PARAMETER_DTYPE, DATATYPE_FLOAT32)); + + CreateHnswParameters obj; + + if (params[PARAMETER_DTYPE] == DATATYPE_FLOAT32) { + obj.type = DataTypes::DATA_TYPE_FLOAT; + } else if (params[PARAMETER_DTYPE] == DATATYPE_INT8) { + obj.type = DataTypes::DATA_TYPE_INT8; + } else { + throw std::invalid_argument(fmt::format("parameters[{}] supports {}, {} only now", + PARAMETER_DTYPE, + DATATYPE_FLOAT32, + DATATYPE_INT8)); + } CHECK_ARGUMENT(params.contains(PARAMETER_METRIC_TYPE), fmt::format("parameters must contains {}", PARAMETER_METRIC_TYPE)); CHECK_ARGUMENT(params.contains(PARAMETER_DIM), fmt::format("parameters must contains {}", PARAMETER_DIM)); - CreateHnswParameters obj; - // set obj.space CHECK_ARGUMENT(params.contains(INDEX_HNSW), fmt::format("parameters must contains {}", INDEX_HNSW)); + if (obj.type == DataTypes::DATA_TYPE_INT8 && params[PARAMETER_METRIC_TYPE] != METRIC_IP) { + throw std::invalid_argument(fmt::format( + "no support for INT8 when using {}, {} as metric", METRIC_L2, METRIC_COSINE)); + } if (params[PARAMETER_METRIC_TYPE] == METRIC_L2) { obj.space = std::make_shared(params[PARAMETER_DIM]); } else if (params[PARAMETER_METRIC_TYPE] == METRIC_IP) { - obj.space = std::make_shared(params[PARAMETER_DIM]); + obj.space = std::make_shared(params[PARAMETER_DIM], obj.type); } else if (params[PARAMETER_METRIC_TYPE] == METRIC_COSINE) { obj.normalize = true; - obj.space = std::make_shared(params[PARAMETER_DIM]); + obj.space = std::make_shared(params[PARAMETER_DIM], obj.type); } else { std::string metric = params[PARAMETER_METRIC_TYPE]; throw std::invalid_argument(fmt::format("parameters[{}] must in [{}, {}, {}], now is {}", diff --git a/src/index/hnsw_zparameters.h b/src/index/hnsw_zparameters.h index 58f5d008..2314346f 100644 --- a/src/index/hnsw_zparameters.h +++ b/src/index/hnsw_zparameters.h @@ -19,6 +19,7 @@ #include #include "../algorithm/hnswlib/hnswlib.h" +#include "../data_type.h" namespace vsag { @@ -35,6 +36,7 @@ struct CreateHnswParameters { bool use_conjugate_graph; bool use_static; bool normalize = false; + DataTypes type; protected: CreateHnswParameters() = default; diff --git a/src/simd/avx512.cpp b/src/simd/avx512.cpp index 7258454f..39403545 100644 --- a/src/simd/avx512.cpp +++ b/src/simd/avx512.cpp @@ -15,7 +15,10 @@ #include +#include + #include "fp32_simd.h" +#include "simd.h" #include "sq4_simd.h" #include "sq4_uniform_simd.h" #include "sq8_simd.h" @@ -87,6 +90,116 @@ InnerProductSIMD16ExtAVX512(const void* pVect1v, const void* pVect2v, const void return sum; } +float +INT8InnerProduct512AVX512(const void* pVect1v, const void* pVect2v, const void* qty_ptr) { + __mmask32 mask = 0xFFFFFFFF; + __mmask64 mask64 = 0xFFFFFFFFFFFFFFFF; + + size_t qty = *((size_t*)qty_ptr); + int32_t cTmp[16]; + + int8_t* pVect1 = (int8_t*)pVect1v; + int8_t* pVect2 = (int8_t*)pVect2v; + const int8_t* pEnd1 = pVect1 + qty; + + __m512i sum512 = _mm512_set1_epi32(0); + + while (pVect1 < pEnd1) { + __m256i v1 = _mm256_maskz_loadu_epi8(mask, pVect1); + __m512i v1_512 = _mm512_cvtepi8_epi16(v1); + pVect1 += 32; + __m256i v2 = _mm256_maskz_loadu_epi8(mask, pVect2); + __m512i v2_512 = _mm512_cvtepi8_epi16(v2); + pVect2 += 32; + + sum512 = _mm512_add_epi32(sum512, _mm512_madd_epi16(v1_512, v2_512)); + } + + _mm512_mask_storeu_epi32(cTmp, mask64, sum512); + double res = 0; + for (int i = 0; i < 16; i++) { + res += cTmp[i]; + } + return res; +} + +float +INT8InnerProduct256AVX512(const void* pVect1v, const void* pVect2v, const void* qty_ptr) { + __mmask16 mask = 0xFFFF; + __mmask64 mask64 = 0xFFFFFFFFFFFFFFFF; + size_t qty = *((size_t*)qty_ptr); + + int32_t cTmp[16]; + + int8_t* pVect1 = (int8_t*)pVect1v; + int8_t* pVect2 = (int8_t*)pVect2v; + const int8_t* pEnd1 = pVect1 + qty; + + __m512i sum512 = _mm512_set1_epi32(0); + + while (pVect1 < pEnd1) { + __m128i v1 = _mm_maskz_loadu_epi8(mask, pVect1); + __m512i v1_512 = _mm512_cvtepi8_epi32(v1); + pVect1 += 16; + __m128i v2 = _mm_maskz_loadu_epi8(mask, pVect2); + __m512i v2_512 = _mm512_cvtepi8_epi32(v2); + pVect2 += 16; + + sum512 = _mm512_add_epi32(sum512, _mm512_mullo_epi32(v1_512, v2_512)); + } + + _mm512_mask_storeu_epi32(cTmp, mask64, sum512); + double res = 0; + for (int i = 0; i < 16; i++) { + res += cTmp[i]; + } + return res; +} + +float +INT8InnerProduct256ResidualsAVX512(const void* pVect1v, const void* pVect2v, const void* qty_ptr) { + size_t qty = *((size_t*)qty_ptr); + size_t qty2 = qty >> 4 << 4; + double res = INT8InnerProduct256AVX512(pVect1v, pVect2v, &qty2); + int8_t* pVect1 = (int8_t*)pVect1v + qty2; + int8_t* pVect2 = (int8_t*)pVect2v + qty2; + + size_t qty_left = qty - qty2; + if (qty_left != 0) { + res += INT8InnerProduct(pVect1, pVect2, &qty_left); + } + return res; +} + +float +INT8InnerProduct256ResidualsAVX512Distance(const void* pVect1v, + const void* pVect2v, + const void* qty_ptr) { + return -INT8InnerProduct256ResidualsAVX512(pVect1v, pVect2v, qty_ptr); +} + +float +INT8InnerProduct512ResidualsAVX512(const void* pVect1v, const void* pVect2v, const void* qty_ptr) { + size_t qty = *((size_t*)qty_ptr); + size_t qty2 = qty >> 5 << 5; + double res = INT8InnerProduct512AVX512(pVect1v, pVect2v, &qty2); + int8_t* pVect1 = (int8_t*)pVect1v + qty2; + int8_t* pVect2 = (int8_t*)pVect2v + qty2; + + size_t qty_left = qty - qty2; + if (qty_left != 0) { + res += INT8InnerProduct256ResidualsAVX512(pVect1, pVect2, &qty_left); + } + return res; +} + +float +INT8InnerProduct512ResidualsAVX512Distance(const void* pVect1v, + const void* pVect2v, + const void* qty_ptr) { + return -INT8InnerProduct512ResidualsAVX512(pVect1v, pVect2v, qty_ptr); +} + namespace avx512 { float FP32ComputeIP(const float* query, const float* codes, uint64_t dim) { diff --git a/src/simd/avx512_test.cpp b/src/simd/avx512_test.cpp new file mode 100644 index 00000000..155ea12b --- /dev/null +++ b/src/simd/avx512_test.cpp @@ -0,0 +1,43 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "catch2/catch_approx.hpp" +#include "cpuinfo.h" +#include "fixtures.h" +#include "simd.h" + +TEST_CASE("avx512 int8", "[ut][simd][avx]") { +#if defined(ENABLE_AVX512) + if (cpuinfo_has_x86_sse()) { + auto common_dims = fixtures::get_common_used_dims(); + for (size_t dim : common_dims) { + auto vectors = fixtures::generate_vectors(2, dim); + fixtures::dist_t distance_512 = vsag::INT8InnerProduct512ResidualsAVX512Distance( + vectors.data(), vectors.data() + dim, &dim); + fixtures::dist_t distance_256 = vsag::INT8InnerProduct256ResidualsAVX512Distance( + vectors.data(), vectors.data() + dim, &dim); + fixtures::dist_t expected_distance = + vsag::INT8InnerProductDistance(vectors.data(), vectors.data() + dim, &dim); + REQUIRE(distance_512 == expected_distance); + REQUIRE(distance_256 == expected_distance); + } + } +#endif +} diff --git a/src/simd/generic.cpp b/src/simd/generic.cpp index 2aaeceee..aead3cc3 100644 --- a/src/simd/generic.cpp +++ b/src/simd/generic.cpp @@ -50,6 +50,23 @@ InnerProductDistance(const void* pVect1, const void* pVect2, const void* qty_ptr return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr); } +float +INT8InnerProduct(const void* pVect1, const void* pVect2, const void* qty_ptr) { + size_t qty = *((size_t*)qty_ptr); + int8_t* vec1 = (int8_t*)pVect1; + int8_t* vec2 = (int8_t*)pVect2; + double res = 0; + for (size_t i = 0; i < qty; i++) { + res += vec1[i] * vec2[i]; + } + return res; +} + +float +INT8InnerProductDistance(const void* pVect1, const void* pVect2, const void* qty_ptr) { + return -INT8InnerProduct(pVect1, pVect2, qty_ptr); +} + void PQDistanceFloat256(const void* single_dim_centers, float single_dim_val, void* result) { const auto* float_centers = (const float*)single_dim_centers; diff --git a/src/simd/simd.cpp b/src/simd/simd.cpp index 4519cc65..5c5f6f07 100644 --- a/src/simd/simd.cpp +++ b/src/simd/simd.cpp @@ -124,6 +124,18 @@ GetInnerProductDistanceFunc(size_t dim) { } } +DistanceFunc +GetINT8InnerProductDistanceFunc(size_t dim) { +#ifdef ENABLE_AVX512 + if (dim > 32) { + return vsag::INT8InnerProduct512ResidualsAVX512Distance; + } else if (dim > 16) { + return vsag::INT8InnerProduct256ResidualsAVX512Distance; + } +#endif + return vsag::INT8InnerProductDistance; +} + PQDistanceFunc GetPQDistanceFunc() { #ifdef ENABLE_AVX diff --git a/src/simd/simd.h b/src/simd/simd.h index 828665fc..86ca5335 100644 --- a/src/simd/simd.h +++ b/src/simd/simd.h @@ -98,6 +98,10 @@ float InnerProduct(const void* pVect1, const void* pVect2, const void* qty_ptr); float InnerProductDistance(const void* pVect1, const void* pVect2, const void* qty_ptr); +float +INT8InnerProduct(const void* pVect1, const void* pVect2, const void* qty_ptr); +float +INT8InnerProductDistance(const void* pVect1, const void* pVect2, const void* qty_ptr); void PQDistanceFloat256(const void* single_dim_centers, float single_dim_val, void* result); @@ -148,6 +152,14 @@ float L2SqrSIMD16ExtAVX512(const void* pVect1v, const void* pVect2v, const void* qty_ptr); float InnerProductSIMD16ExtAVX512(const void* pVect1v, const void* pVect2v, const void* qty_ptr); +float +INT8InnerProduct256ResidualsAVX512Distance(const void* pVect1v, + const void* pVect2v, + const void* qty_ptr); +float +INT8InnerProduct512ResidualsAVX512Distance(const void* pVect1v, + const void* pVect2v, + const void* qty_ptr); #endif typedef float (*DistanceFunc)(const void* pVect1, const void* pVect2, const void* qty_ptr); @@ -156,6 +168,9 @@ GetL2DistanceFunc(size_t dim); DistanceFunc GetInnerProductDistanceFunc(size_t dim); +DistanceFunc +GetINT8InnerProductDistanceFunc(size_t dim); + typedef void (*PQDistanceFunc)(const void* single_dim_centers, float single_dim_val, void* result); PQDistanceFunc diff --git a/tests/performance/test_performance.cpp b/tests/performance/test_performance.cpp index 8c99997f..8d2e2a44 100644 --- a/tests/performance/test_performance.cpp +++ b/tests/performance/test_performance.cpp @@ -107,29 +107,54 @@ class TestDataset { obj->number_of_base_ = train_shape.first; obj->number_of_query_ = test_shape.first; - // alloc memory - { - obj->train_ = - std::shared_ptr(new float[train_shape.first * train_shape.second]); - obj->test_ = std::shared_ptr(new float[test_shape.first * test_shape.second]); - obj->neighbors_ = std::shared_ptr( - new int64_t[neighbors_shape.first * neighbors_shape.second]); - } - // read from file { H5::DataSet dataset = file.openDataSet("/train"); H5::DataSpace dataspace = dataset.getSpace(); - H5::FloatType datatype(H5::PredType::NATIVE_FLOAT); - dataset.read(obj->train_.get(), datatype, dataspace); + auto data_type = dataset.getDataType(); + H5::PredType type = H5::PredType::ALPHA_I8; + if (data_type.getClass() == H5T_INTEGER && data_type.getSize() == 1) { + obj->train_data_type_ = vsag::DATATYPE_INT8; + type = H5::PredType::ALPHA_I8; + obj->train_data_size_ = 1; + } else if (data_type.getClass() == H5T_FLOAT) { + obj->train_data_type_ = vsag::DATATYPE_FLOAT32; + type = H5::PredType::NATIVE_FLOAT; + obj->train_data_size_ = 4; + } else { + throw std::runtime_error( + fmt::format("wrong data type, data type ({}), data size ({})", + (int)data_type.getClass(), + data_type.getSize())); + } + obj->train_ = std::shared_ptr( + new char[train_shape.first * train_shape.second * obj->train_data_size_]); + dataset.read(obj->train_.get(), type, dataspace); } + { H5::DataSet dataset = file.openDataSet("/test"); H5::DataSpace dataspace = dataset.getSpace(); - H5::FloatType datatype(H5::PredType::NATIVE_FLOAT); - dataset.read(obj->test_.get(), datatype, dataspace); + auto data_type = dataset.getDataType(); + H5::PredType type = H5::PredType::ALPHA_I8; + if (data_type.getClass() == H5T_INTEGER && data_type.getSize() == 1) { + obj->test_data_type_ = vsag::DATATYPE_INT8; + type = H5::PredType::ALPHA_I8; + obj->test_data_size_ = 1; + } else if (data_type.getClass() == H5T_FLOAT) { + obj->test_data_type_ = vsag::DATATYPE_FLOAT32; + type = H5::PredType::NATIVE_FLOAT; + obj->test_data_size_ = 4; + } else { + throw std::runtime_error("wrong data type"); + } + obj->test_ = std::shared_ptr( + new char[test_shape.first * test_shape.second * obj->test_data_size_]); + dataset.read(obj->test_.get(), type, dataspace); } { + obj->neighbors_ = std::shared_ptr( + new int64_t[neighbors_shape.first * neighbors_shape.second]); H5::DataSet dataset = file.openDataSet("/neighbors"); H5::DataSpace dataspace = dataset.getSpace(); H5::FloatType datatype(H5::PredType::NATIVE_INT64); @@ -140,14 +165,19 @@ class TestDataset { } public: - std::shared_ptr + const void* GetTrain() const { - return train_; + return train_.get(); } - std::shared_ptr + const void* GetTest() const { - return test_; + return test_.get(); + } + + const void* + GetOneTest(int64_t id) const { + return test_.get() + id * dim_ * test_data_size_; } int64_t @@ -175,6 +205,15 @@ class TestDataset { return dim_; } + std::string + GetTrainDataType() const { + return train_data_type_; + } + std::string + GetTestDataType() const { + return test_data_type_; + } + private: using shape_t = std::pair; static std::unordered_set @@ -208,8 +247,8 @@ class TestDataset { } private: - std::shared_ptr train_; - std::shared_ptr test_; + std::shared_ptr train_; + std::shared_ptr test_; std::shared_ptr neighbors_; shape_t train_shape_; shape_t test_shape_; @@ -217,6 +256,10 @@ class TestDataset { int64_t number_of_base_; int64_t number_of_query_; int64_t dim_; + size_t train_data_size_; + size_t test_data_size_; + std::string train_data_type_; + std::string test_data_type_; }; class Test { @@ -236,11 +279,12 @@ class Test { int64_t total_base = test_dataset->GetNumberOfBase(); auto ids = range(total_base); auto base = Dataset::Make(); - base->NumElements(total_base) - ->Dim(test_dataset->GetDim()) - ->Ids(ids.get()) - ->Float32Vectors(test_dataset->GetTrain().get()) - ->Owner(false); + base->NumElements(total_base)->Dim(test_dataset->GetDim())->Ids(ids.get())->Owner(false); + if (test_dataset->GetTrainDataType() == vsag::DATATYPE_FLOAT32) { + base->Float32Vectors((const float*)test_dataset->GetTrain()); + } else if (test_dataset->GetTrainDataType() == vsag::DATATYPE_INT8) { + base->Int8Vectors((const int8_t*)test_dataset->GetTrain()); + } auto build_start = std::chrono::steady_clock::now(); if (auto buildindex = index->Build(base); not buildindex.has_value()) { std::cerr << "build error: " << buildindex.error().message << std::endl; @@ -336,11 +380,13 @@ class Test { std::vector results; for (int64_t i = 0; i < total; ++i) { auto query = Dataset::Make(); - query->NumElements(1) - ->Dim(test_dataset->GetDim()) - ->Float32Vectors(test_dataset->GetTest().get() + i * test_dataset->GetDim()) - ->Owner(false); + query->NumElements(1)->Dim(test_dataset->GetDim())->Owner(false); + if (test_dataset->GetTestDataType() == vsag::DATATYPE_FLOAT32) { + query->Float32Vectors((const float*)test_dataset->GetOneTest(i)); + } else if (test_dataset->GetTestDataType() == vsag::DATATYPE_INT8) { + query->Int8Vectors((const int8_t*)test_dataset->GetOneTest(i)); + } auto result = index->KnnSearch(query, 10, search_parameters); if (not result.has_value()) { std::cerr << "query error: " << result.error().message << std::endl; diff --git a/tests/test_index.cpp b/tests/test_index.cpp index 6505f3f3..d859cecb 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -219,6 +219,59 @@ TEST_CASE("hnsw float32 recall", "[ft][index][hnsw]") { REQUIRE(range_recall > 0.99); } +TEST_CASE("hnsw int8 recall", "[ft][index][hnsw]") { + vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); + + int64_t num_vectors = 1000; + int64_t dim = 104; + auto metric_type = "ip"; + + nlohmann::json hnsw_parameters{{"max_degree", 12}, {"ef_construction", 50}, {"ef_search", 50}}; + nlohmann::json index_parameters{ + {"dtype", "int8"}, {"metric_type", "ip"}, {"dim", dim}, {"hnsw", hnsw_parameters}}; + std::shared_ptr hnsw; + if (auto index = vsag::Factory::CreateIndex("hnsw", index_parameters.dump()); + index.has_value()) { + hnsw = index.value(); + } else { + std::cout << "Build HNSW Error" << std::endl; + return; + } + std::shared_ptr ids(new int64_t[num_vectors]); + std::shared_ptr data(new int8_t[dim * num_vectors]); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_int_distribution<> distrib_real; + for (int i = 0; i < num_vectors; i++) ids[i] = i; + for (int i = 0; i < dim * num_vectors; i++) data[i] = (int8_t)distrib_real(rng); + + // build index + auto base = vsag::Dataset::Make(); + base->NumElements(num_vectors)->Dim(dim)->Ids(ids.get())->Int8Vectors(data.get())->Owner(false); + auto buildindex = hnsw->Build(base); + REQUIRE(buildindex.has_value()); + + { + for (int i = 0; i < num_vectors; i++) { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Int8Vectors(data.get() + i * dim)->Owner(false); + auto search_parameters = R"( + { + "hnsw": { + "ef_search": 100 + } + } + )"; + int64_t k = 10; + auto result = hnsw->KnnSearch(query, k, search_parameters); + REQUIRE(result.has_value()); + REQUIRE(result.value()->GetIds()[0] == ids[i]); + } + } +} + TEST_CASE("index search distance", "[ft][index]") { vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); From 32a9d1e2768dcc25039320223b73558a975f2dec Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Tue, 15 Oct 2024 15:03:57 +0800 Subject: [PATCH 02/11] modify Signed-off-by: jinjiabao.jjb --- tests/test_index.cpp | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/test_index.cpp b/tests/test_index.cpp index d859cecb..e67248ea 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -253,22 +253,20 @@ TEST_CASE("hnsw int8 recall", "[ft][index][hnsw]") { auto buildindex = hnsw->Build(base); REQUIRE(buildindex.has_value()); - { - for (int i = 0; i < num_vectors; i++) { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim)->Int8Vectors(data.get() + i * dim)->Owner(false); - auto search_parameters = R"( + for (int i = 0; i < num_vectors; i++) { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Int8Vectors(data.get() + i * dim)->Owner(false); + auto search_parameters = R"( { "hnsw": { "ef_search": 100 } } )"; - int64_t k = 10; - auto result = hnsw->KnnSearch(query, k, search_parameters); - REQUIRE(result.has_value()); - REQUIRE(result.value()->GetIds()[0] == ids[i]); - } + int64_t k = 10; + auto result = hnsw->KnnSearch(query, k, search_parameters); + REQUIRE(result.has_value()); + REQUIRE(result.value()->GetIds()[0] == ids[i]); } } From 1a90dcd3ad123b57a5242dfcd91faf07373be72a Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Tue, 15 Oct 2024 20:33:48 +0800 Subject: [PATCH 03/11] modify Signed-off-by: jinjiabao.jjb --- tests/test_index.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_index.cpp b/tests/test_index.cpp index e67248ea..e48f5ab9 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -224,7 +224,6 @@ TEST_CASE("hnsw int8 recall", "[ft][index][hnsw]") { int64_t num_vectors = 1000; int64_t dim = 104; - auto metric_type = "ip"; nlohmann::json hnsw_parameters{{"max_degree", 12}, {"ef_construction", 50}, {"ef_search", 50}}; nlohmann::json index_parameters{ From 56c01396ee432d4258b3060ca03b03c203e6ecb9 Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Wed, 16 Oct 2024 10:39:55 +0800 Subject: [PATCH 04/11] modify Signed-off-by: jinjiabao.jjb --- tests/test_index.cpp | 136 ++++++++++++++++++++----------------------- 1 file changed, 63 insertions(+), 73 deletions(-) diff --git a/tests/test_index.cpp b/tests/test_index.cpp index e48f5ab9..84e863c0 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -782,13 +782,8 @@ TEST_CASE("remove vectors from the index", "[ft][index]") { vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); int64_t num_vectors = 1000; int64_t dim = 64; - auto index_name = GENERATE("fresh_hnsw", "diskann"); + auto index_name = "fresh_hnsw"; auto metric_type = GENERATE("cosine", "ip", "l2"); - - if (index_name == std::string("diskann") and metric_type == std::string("cosine")) { - return; // TODO: support cosine for diskann - } - bool need_normalize = metric_type != std::string("cosine"); auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_vectors, dim, need_normalize); auto index = fixtures::generate_index(index_name, metric_type, num_vectors, dim, ids, vectors); @@ -807,87 +802,82 @@ TEST_CASE("remove vectors from the index", "[ft][index]") { } )"; - if (index_name != std::string("diskann")) { // index that supports remove - // remove half data - - int correct = 0; - for (int i = 0; i < num_vectors; i++) { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); + // remove half data + int correct = 0; + for (int i = 0; i < num_vectors; i++) { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); - int64_t k = 10; - auto result = index->KnnSearch(query, k, search_parameters); - REQUIRE(result.has_value()); - if (result.value()->GetIds()[0] == ids[i]) { - correct += 1; - } + int64_t k = 10; + auto result = index->KnnSearch(query, k, search_parameters); + REQUIRE(result.has_value()); + if (result.value()->GetIds()[0] == ids[i]) { + correct += 1; } - float recall_before = ((float)correct) / num_vectors; + } + float recall_before = ((float)correct) / num_vectors; - for (int i = 0; i < num_vectors / 2; ++i) { - auto result = index->Remove(ids[i]); - REQUIRE(result.has_value()); - REQUIRE(result.value()); - } - auto wrong_result = index->Remove(-1); - REQUIRE(wrong_result.has_value()); - REQUIRE_FALSE(wrong_result.value()); + for (int i = 0; i < num_vectors / 2; ++i) { + auto result = index->Remove(ids[i]); + REQUIRE(result.has_value()); + REQUIRE(result.value()); + } + auto wrong_result = index->Remove(-1); + REQUIRE(wrong_result.has_value()); + REQUIRE_FALSE(wrong_result.value()); - REQUIRE(index->GetNumElements() == num_vectors / 2); + REQUIRE(index->GetNumElements() == num_vectors / 2); - // test recall for half data - correct = 0; - for (int i = 0; i < num_vectors; i++) { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); + // test recall for half data + correct = 0; + for (int i = 0; i < num_vectors; i++) { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); - int64_t k = 10; - auto result = index->KnnSearch(query, k, search_parameters); - REQUIRE(result.has_value()); - if (i < num_vectors / 2) { - REQUIRE(result.value()->GetIds()[0] != ids[i]); - } else { - if (result.value()->GetIds()[0] == ids[i]) { - correct += 1; - } + int64_t k = 10; + auto result = index->KnnSearch(query, k, search_parameters); + REQUIRE(result.has_value()); + if (i < num_vectors / 2) { + REQUIRE(result.value()->GetIds()[0] != ids[i]); + } else { + if (result.value()->GetIds()[0] == ids[i]) { + correct += 1; } } - float recall = ((float)correct) / (num_vectors / 2); - REQUIRE(recall >= 0.98); + } + float recall = ((float)correct) / (num_vectors / 2); + REQUIRE(recall >= 0.98); - // remove all data - for (int i = num_vectors / 2; i < num_vectors; ++i) { - auto result = index->Remove(i); - REQUIRE(result.has_value()); - REQUIRE(result.value()); - } + // remove all data + for (int i = num_vectors / 2; i < num_vectors; ++i) { + auto result = index->Remove(i); + REQUIRE(result.has_value()); + REQUIRE(result.value()); + } - // add data into index again - correct = 0; - auto dataset = vsag::Dataset::Make(); - dataset->NumElements(num_vectors) - ->Dim(dim) - ->Float32Vectors(vectors.data()) - ->Ids(ids.data()) - ->Owner(false); - auto result = index->Add(dataset); + // add data into index again + correct = 0; + auto dataset = vsag::Dataset::Make(); + dataset->NumElements(num_vectors) + ->Dim(dim) + ->Float32Vectors(vectors.data()) + ->Ids(ids.data()) + ->Owner(false); + auto result = index->Add(dataset); - for (int i = 0; i < num_vectors; i++) { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); + for (int i = 0; i < num_vectors; i++) { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); - int64_t k = 10; - auto result = index->KnnSearch(query, k, search_parameters); - REQUIRE(result.has_value()); - if (result.value()->GetIds()[0] == ids[i]) { - correct += 1; - } + int64_t k = 10; + auto result = index->KnnSearch(query, k, search_parameters); + REQUIRE(result.has_value()); + if (result.value()->GetIds()[0] == ids[i]) { + correct += 1; } - float recall_after = ((float)correct) / num_vectors; - REQUIRE(abs(recall_before - recall_after) < 0.001); - } else { // index that does not supports remove - REQUIRE_THROWS(index->Remove(-1)); } + float recall_after = ((float)correct) / num_vectors; + REQUIRE(abs(recall_before - recall_after) < 0.001); } TEST_CASE("index with bsa", "[ft][index]") { From 318c637e4be7168bb762b3661769d6f14bdda828 Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Wed, 16 Oct 2024 13:48:37 +0800 Subject: [PATCH 05/11] modify Signed-off-by: jinjiabao.jjb --- src/index/hnsw.cpp | 2 +- src/index/hnsw_zparameters.cpp | 1 + tests/test_index.cpp | 136 ++++++++++++++++++--------------- 3 files changed, 75 insertions(+), 64 deletions(-) diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index c71f5404..a55f5b83 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -819,7 +819,7 @@ HNSW::get_vectors(const vsag::DatasetPtr& base, void*& vectors, size_t& data_siz vectors = (void*)base->GetInt8Vectors(); data_size = dim_ * sizeof(int8_t); } else { - throw std::invalid_argument("fail to support this metric"); + throw std::invalid_argument("fail to support this data type"); } } diff --git a/src/index/hnsw_zparameters.cpp b/src/index/hnsw_zparameters.cpp index 114953e2..274a133f 100644 --- a/src/index/hnsw_zparameters.cpp +++ b/src/index/hnsw_zparameters.cpp @@ -140,6 +140,7 @@ CreateFreshHnswParameters::FromJson(const std::string& json_string) { obj.space = parrent_obj.space; obj.use_static = false; obj.normalize = parrent_obj.normalize; + obj.type = parrent_obj.type; // set obj.use_reversed_edges obj.use_reversed_edges = true; diff --git a/tests/test_index.cpp b/tests/test_index.cpp index 84e863c0..2540c00b 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -782,8 +782,13 @@ TEST_CASE("remove vectors from the index", "[ft][index]") { vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG); int64_t num_vectors = 1000; int64_t dim = 64; - auto index_name = "fresh_hnsw"; + auto index_name = GENERATE("fresh_hnsw", "diskann"); auto metric_type = GENERATE("cosine", "ip", "l2"); + + if (index_name == std::string("diskann") and metric_type == std::string("cosine")) { + return; // TODO: support cosine for diskann + } + bool need_normalize = metric_type != std::string("cosine"); auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_vectors, dim, need_normalize); auto index = fixtures::generate_index(index_name, metric_type, num_vectors, dim, ids, vectors); @@ -802,82 +807,87 @@ TEST_CASE("remove vectors from the index", "[ft][index]") { } )"; - // remove half data - int correct = 0; - for (int i = 0; i < num_vectors; i++) { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); + if (index_name != std::string("diskann")) { // index that supports remove + // remove half data - int64_t k = 10; - auto result = index->KnnSearch(query, k, search_parameters); - REQUIRE(result.has_value()); - if (result.value()->GetIds()[0] == ids[i]) { - correct += 1; + int correct = 0; + for (int i = 0; i < num_vectors; i++) { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); + + int64_t k = 10; + auto result = index->KnnSearch(query, k, search_parameters); + REQUIRE(result.has_value()); + if (result.value()->GetIds()[0] == ids[i]) { + correct += 1; + } } - } - float recall_before = ((float)correct) / num_vectors; + float recall_before = ((float)correct) / num_vectors; - for (int i = 0; i < num_vectors / 2; ++i) { - auto result = index->Remove(ids[i]); - REQUIRE(result.has_value()); - REQUIRE(result.value()); - } - auto wrong_result = index->Remove(-1); - REQUIRE(wrong_result.has_value()); - REQUIRE_FALSE(wrong_result.value()); + for (int i = 0; i < num_vectors / 2; ++i) { + auto result = index->Remove(ids[i]); + REQUIRE(result.has_value()); + REQUIRE(result.value()); + } + auto wrong_result = index->Remove(-1); + REQUIRE(wrong_result.has_value()); + REQUIRE_FALSE(wrong_result.value()); - REQUIRE(index->GetNumElements() == num_vectors / 2); + REQUIRE(index->GetNumElements() == num_vectors / 2); - // test recall for half data - correct = 0; - for (int i = 0; i < num_vectors; i++) { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); + // test recall for half data + correct = 0; + for (int i = 0; i < num_vectors; i++) { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); - int64_t k = 10; - auto result = index->KnnSearch(query, k, search_parameters); - REQUIRE(result.has_value()); - if (i < num_vectors / 2) { - REQUIRE(result.value()->GetIds()[0] != ids[i]); - } else { - if (result.value()->GetIds()[0] == ids[i]) { - correct += 1; + int64_t k = 10; + auto result = index->KnnSearch(query, k, search_parameters); + REQUIRE(result.has_value()); + if (i < num_vectors / 2) { + REQUIRE(result.value()->GetIds()[0] != ids[i]); + } else { + if (result.value()->GetIds()[0] == ids[i]) { + correct += 1; + } } } - } - float recall = ((float)correct) / (num_vectors / 2); - REQUIRE(recall >= 0.98); + float recall = ((float)correct) / (num_vectors / 2); + REQUIRE(recall >= 0.98); - // remove all data - for (int i = num_vectors / 2; i < num_vectors; ++i) { - auto result = index->Remove(i); - REQUIRE(result.has_value()); - REQUIRE(result.value()); - } + // remove all data + for (int i = num_vectors / 2; i < num_vectors; ++i) { + auto result = index->Remove(i); + REQUIRE(result.has_value()); + REQUIRE(result.value()); + } - // add data into index again - correct = 0; - auto dataset = vsag::Dataset::Make(); - dataset->NumElements(num_vectors) - ->Dim(dim) - ->Float32Vectors(vectors.data()) - ->Ids(ids.data()) - ->Owner(false); - auto result = index->Add(dataset); + // add data into index again + correct = 0; + auto dataset = vsag::Dataset::Make(); + dataset->NumElements(num_vectors) + ->Dim(dim) + ->Float32Vectors(vectors.data()) + ->Ids(ids.data()) + ->Owner(false); + auto result = index->Add(dataset); - for (int i = 0; i < num_vectors; i++) { - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); + for (int i = 0; i < num_vectors; i++) { + auto query = vsag::Dataset::Make(); + query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false); - int64_t k = 10; - auto result = index->KnnSearch(query, k, search_parameters); - REQUIRE(result.has_value()); - if (result.value()->GetIds()[0] == ids[i]) { - correct += 1; + int64_t k = 10; + auto result = index->KnnSearch(query, k, search_parameters); + REQUIRE(result.has_value()); + if (result.value()->GetIds()[0] == ids[i]) { + correct += 1; + } } + float recall_after = ((float)correct) / num_vectors; + REQUIRE(std::abs(recall_before - recall_after) < 0.001); + } else { // index that does not supports remove + REQUIRE_THROWS(index->Remove(-1)); } - float recall_after = ((float)correct) / num_vectors; - REQUIRE(abs(recall_before - recall_after) < 0.001); } TEST_CASE("index with bsa", "[ft][index]") { From 6513d202844775fce80ddbf23b8f704d193591af Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Tue, 22 Oct 2024 19:59:46 +0800 Subject: [PATCH 06/11] modify Signed-off-by: jinjiabao.jjb --- src/index/hnsw.cpp | 20 +++++++++----------- src/index/hnsw.h | 2 +- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index a55f5b83..6ca6b6cd 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -124,7 +124,7 @@ HNSW::build(const DatasetPtr& base) { auto ids = base->GetIds(); void* vectors = nullptr; size_t data_size = 0; - get_vectors(base, vectors, data_size); + get_vectors(base, &vectors, &data_size); std::vector failed_ids; { SlowTaskTimer t("hnsw graph"); @@ -166,7 +166,7 @@ HNSW::add(const DatasetPtr& base) { auto ids = base->GetIds(); void* vectors = nullptr; size_t data_size = 0; - get_vectors(base, vectors, data_size); + get_vectors(base, &vectors, &data_size); std::vector failed_ids; std::unique_lock lock(rw_mutex_); @@ -221,7 +221,7 @@ HNSW::knn_search(const DatasetPtr& query, CHECK_ARGUMENT(query->GetNumElements() == 1, "query dataset should contain 1 vector only"); void* vector = nullptr; size_t data_size = 0; - get_vectors(query, vector, data_size); + get_vectors(query, &vector, &data_size); int64_t query_dim = query->GetDim(); CHECK_ARGUMENT( query_dim == dim_, @@ -341,7 +341,7 @@ HNSW::range_search(const DatasetPtr& query, CHECK_ARGUMENT(query->GetNumElements() == 1, "query dataset should contain 1 vector only"); void* vector = nullptr; size_t data_size = 0; - get_vectors(query, vector, data_size); + get_vectors(query, &vector, &data_size); int64_t query_dim = query->GetDim(); CHECK_ARGUMENT( query_dim == dim_, @@ -811,15 +811,13 @@ HNSW::init_memory_space() { } void -HNSW::get_vectors(const vsag::DatasetPtr& base, void*& vectors, size_t& data_size) const { +HNSW::get_vectors(const vsag::DatasetPtr& base, void** vectors_ptr, size_t* data_size_ptr) const { if (type_ == DataTypes::DATA_TYPE_FLOAT) { - vectors = (void*)base->GetFloat32Vectors(); - data_size = dim_ * sizeof(float); + *vectors_ptr = (void*)base->GetFloat32Vectors(); + *data_size_ptr = dim_ * sizeof(float); } else if (type_ == DataTypes::DATA_TYPE_INT8) { - vectors = (void*)base->GetInt8Vectors(); - data_size = dim_ * sizeof(int8_t); - } else { - throw std::invalid_argument("fail to support this data type"); + *vectors_ptr = (void*)base->GetInt8Vectors(); + *data_size_ptr = dim_ * sizeof(int8_t); } } diff --git a/src/index/hnsw.h b/src/index/hnsw.h index 61e65f26..1f4d4315 100644 --- a/src/index/hnsw.h +++ b/src/index/hnsw.h @@ -283,7 +283,7 @@ class HNSW : public Index { init_memory_space(); void - get_vectors(const DatasetPtr& base, void*& vectors, size_t& data_size) const; + get_vectors(const DatasetPtr& base, void** vectors_ptr, size_t* data_size_ptr) const; BinarySet empty_binaryset() const; From aa07e3074df35c338b30dc209d31c8486905638a Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Wed, 23 Oct 2024 11:54:09 +0800 Subject: [PATCH 07/11] modify Signed-off-by: jinjiabao.jjb --- src/data_type.h | 12 ++++++++++++ src/index/hnsw.cpp | 3 +++ 2 files changed, 15 insertions(+) diff --git a/src/data_type.h b/src/data_type.h index 39aa5802..ddfa4b69 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -18,4 +18,16 @@ namespace vsag { enum class DataTypes { DATA_TYPE_FLOAT = 0, DATA_TYPE_INT8 = 1, DATA_TYPE_FP16 = 2 }; +inline std::string +datatype_to_str(DataTypes type) { + switch (type) { + case DataTypes::DATA_TYPE_FLOAT: + return "float32"; + case DataTypes::DATA_TYPE_INT8: + return "int8"; + case DataTypes::DATA_TYPE_FP16: + return "float16"; + } +} + } // namespace vsag diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 6ca6b6cd..54b70460 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -818,6 +818,9 @@ HNSW::get_vectors(const vsag::DatasetPtr& base, void** vectors_ptr, size_t* data } else if (type_ == DataTypes::DATA_TYPE_INT8) { *vectors_ptr = (void*)base->GetInt8Vectors(); *data_size_ptr = dim_ * sizeof(int8_t); + } else { + throw std::invalid_argument( + fmt::format("no support for this metric: {}", datatype_to_str(type_))); } } From 84ce816da1faa7d79b50f2b0d7f0e415470da20c Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Wed, 23 Oct 2024 12:00:08 +0800 Subject: [PATCH 08/11] modify Signed-off-by: jinjiabao.jjb --- src/data_type.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data_type.h b/src/data_type.h index ddfa4b69..02c98600 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -14,7 +14,7 @@ // limitations under the License. #pragma once - +#include namespace vsag { enum class DataTypes { DATA_TYPE_FLOAT = 0, DATA_TYPE_INT8 = 1, DATA_TYPE_FP16 = 2 }; From 5cac32a364dc3e04810b015d79646733db2d3078 Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Wed, 23 Oct 2024 14:07:29 +0800 Subject: [PATCH 09/11] modify Signed-off-by: jinjiabao.jjb --- src/data_type.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/data_type.h b/src/data_type.h index 02c98600..64722f90 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -28,6 +28,7 @@ datatype_to_str(DataTypes type) { case DataTypes::DATA_TYPE_FP16: return "float16"; } + return "unknown type"; } } // namespace vsag From 1791ea7503c696dd2dfb19f598fdbfad361bf875 Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Wed, 23 Oct 2024 14:35:23 +0800 Subject: [PATCH 10/11] modify Signed-off-by: jinjiabao.jjb --- src/algorithm/hnswlib/space_ip.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/algorithm/hnswlib/space_ip.h b/src/algorithm/hnswlib/space_ip.h index 5b92e60a..35455cf6 100644 --- a/src/algorithm/hnswlib/space_ip.h +++ b/src/algorithm/hnswlib/space_ip.h @@ -34,6 +34,9 @@ class InnerProductSpace : public SpaceInterface { } else if (type == vsag::DataTypes::DATA_TYPE_INT8) { fstdistfunc_ = vsag::GetINT8InnerProductDistanceFunc(dim); data_size_ = dim * sizeof(int8_t); + } else { + throw std::invalid_argument( + fmt::format("no support for this metric: {}", datatype_to_str(type))); } } From b2351f956322a3ed42c731a9db2abdcdc3011525 Mon Sep 17 00:00:00 2001 From: "jinjiabao.jjb" Date: Wed, 23 Oct 2024 14:54:47 +0800 Subject: [PATCH 11/11] modify Signed-off-by: jinjiabao.jjb --- src/algorithm/hnswlib/space_ip.h | 3 +-- src/data_type.h | 15 +-------------- src/index/hnsw.cpp | 3 +-- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/src/algorithm/hnswlib/space_ip.h b/src/algorithm/hnswlib/space_ip.h index 35455cf6..50cc24a0 100644 --- a/src/algorithm/hnswlib/space_ip.h +++ b/src/algorithm/hnswlib/space_ip.h @@ -35,8 +35,7 @@ class InnerProductSpace : public SpaceInterface { fstdistfunc_ = vsag::GetINT8InnerProductDistanceFunc(dim); data_size_ = dim * sizeof(int8_t); } else { - throw std::invalid_argument( - fmt::format("no support for this metric: {}", datatype_to_str(type))); + throw std::invalid_argument(fmt::format("no support for this metric: {}", (int)type)); } } diff --git a/src/data_type.h b/src/data_type.h index 64722f90..39aa5802 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -14,21 +14,8 @@ // limitations under the License. #pragma once -#include + namespace vsag { enum class DataTypes { DATA_TYPE_FLOAT = 0, DATA_TYPE_INT8 = 1, DATA_TYPE_FP16 = 2 }; -inline std::string -datatype_to_str(DataTypes type) { - switch (type) { - case DataTypes::DATA_TYPE_FLOAT: - return "float32"; - case DataTypes::DATA_TYPE_INT8: - return "int8"; - case DataTypes::DATA_TYPE_FP16: - return "float16"; - } - return "unknown type"; -} - } // namespace vsag diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index 54b70460..5e5d7ca5 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -819,8 +819,7 @@ HNSW::get_vectors(const vsag::DatasetPtr& base, void** vectors_ptr, size_t* data *vectors_ptr = (void*)base->GetInt8Vectors(); *data_size_ptr = dim_ * sizeof(int8_t); } else { - throw std::invalid_argument( - fmt::format("no support for this metric: {}", datatype_to_str(type_))); + throw std::invalid_argument(fmt::format("no support for this metric: {}", (int)type_)); } }