Skip to content

Commit

Permalink
modify
Browse files Browse the repository at this point in the history
Signed-off-by: jinjiabao.jjb <jinjiabao.jjb@antgroup.com>
  • Loading branch information
jinjiabao.jjb committed Feb 7, 2025
1 parent 25a5d78 commit dc79928
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 86 deletions.
26 changes: 16 additions & 10 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include "hnswalg.h"

#include <memory>

#include "data_cell/graph_interface.h"
namespace hnswlib {
HierarchicalNSW::HierarchicalNSW(SpaceInterface* s,
size_t max_elements,
Expand Down Expand Up @@ -1520,20 +1522,24 @@ HierarchicalNSW::searchRange(const void* query_data,
}

// Bulk-loads vectors, level-0 neighbor lists and external labels into this
// index, replacing the current element count, entry point and max level.
//
// NOTE(review): this span is a rendered diff hunk. The five-argument signature
// and the statements using data_num / data_dim / graph[i] appear to be the
// pre-change version; the FlattenInterfacePtr / GraphInterfacePtr signature and
// the statements using data->... / graph->... appear to be the post-change one.
//
// Post-change parameters:
//   data  - source of the stored codes; read through total_count_, code_size_
//           and GetCodesById().
//   graph - source of the per-node neighbor lists; read through GetNeighbors().
//   ids   - external labels, indexed by internal id.
//           NOTE(review): presumably ids.size() >= data->total_count_; no check
//           is made here - confirm against callers.
void
HierarchicalNSW::setDataAndGraph(const float* data,
const int64_t* ids,
int64_t data_num,
int64_t data_dim,
const vsag::Vector<vsag::Vector<uint32_t>>& graph) {
resizeIndex(data_num);
for (int i = 0; i < data_num; ++i) {
std::memcpy(getDataByInternalId(i), data + i * data_dim, data_size_);
setBatchNeigohbors(i, 0, graph[i].data(), graph[i].size());
HierarchicalNSW::setDataAndGraph(vsag::FlattenInterfacePtr& data,
vsag::GraphInterfacePtr& graph,
vsag::Vector<LabelType>& ids) {
resizeIndex(data->total_count_);
// Scratch buffer of data->code_size_ bytes, reused for every element.
std::shared_ptr<uint8_t[]> temp_vector =
std::shared_ptr<uint8_t[]>(new uint8_t[data->code_size_]);
// NOTE(review): the loop index is a signed int compared against
// data->total_count_, whose type is not visible here - confirm no
// signed/unsigned mismatch or truncation for large counts.
for (int i = 0; i < data->total_count_; ++i) {
data->GetCodesById(i, temp_vector.get());
// Copies data_size_ bytes out of a buffer sized by code_size_.
// NOTE(review): assumes data->code_size_ >= data_size_ - confirm.
std::memcpy(
getDataByInternalId(i), reinterpret_cast<const char*>(temp_vector.get()), data_size_);
// A fresh neighbor vector is built per element and written to level 0.
vsag::Vector<InnerIdType> edges(allocator_);
graph->GetNeighbors(i, edges);
setBatchNeigohbors(i, 0, edges.data(), edges.size());
// Record the label in both directions: internal id -> label and label -> internal id.
setExternalLabel(i, ids[i]);
label_lookup_[ids[i]] = i;
// Every element is placed on level 0 only.
element_levels_[i] = 0;
}
cur_element_count_ = data_num;
cur_element_count_ = data->total_count_;
// The entry point is fixed to internal id 0 and the hierarchy is a single level.
enterpoint_node_ = 0;
max_level_ = 0;
}
Expand Down
14 changes: 7 additions & 7 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
#include <unordered_map>
#include <unordered_set>

#include "../../default_allocator.h"
#include "../../simd/simd.h"
#include "../../utils.h"
#include "algorithm_interface.h"
#include "block_manager.h"
#include "data_cell/flatten_interface.h"
#include "data_cell/graph_interface.h"
#include "default_allocator.h"
#include "simd/simd.h"
#include "visited_list_pool.h"
#include "vsag/dataset.h"
namespace hnswlib {
Expand Down Expand Up @@ -309,11 +311,9 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
resizeIndex(size_t new_max_elements) override;

// Loads vectors, neighbor lists and external labels into the index in one call.
// NOTE(review): this span is a rendered diff hunk; the (data, ids, data_num,
// data_dim, graph) parameter list appears to be the pre-change declaration and
// the (data, graph, ids) parameter list the post-change one.
// All three post-change parameters are taken by non-const lvalue reference.
void
setDataAndGraph(const float* data,
const int64_t* ids,
int64_t data_num,
int64_t data_dim,
const vsag::Vector<vsag::Vector<uint32_t>>& graph);
setDataAndGraph(vsag::FlattenInterfacePtr& data,
vsag::GraphInterfacePtr& graph,
vsag::Vector<LabelType>& ids);

size_t
calcSerializeSize() override;
Expand Down
101 changes: 44 additions & 57 deletions src/index/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
#include <nlohmann/json.hpp>
#include <stdexcept>

#include "../utils.h"
#include "algorithm/hnswlib/hnswlib.h"
#include "common.h"
#include "data_cell/flatten_datacell.h"
#include "data_cell/graph_datacell_parameter.h"
#include "index/hnsw_zparameters.h"
#include "logger.h"
#include "io/memory_io_parameter.h"
#include "quantization/fp32_quantizer_parameter.h"
#include "safe_allocator.h"
#include "vsag/binaryset.h"
#include "vsag/constants.h"
Expand All @@ -52,7 +54,9 @@ HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_para
use_conjugate_graph_(hnsw_params.use_conjugate_graph),
use_reversed_edges_(hnsw_params.use_reversed_edges),
type_(hnsw_params.type),
dim_(index_common_param.dim_) {
max_degree_(hnsw_params.max_degree),
dim_(index_common_param.dim_),
index_common_param_(index_common_param) {
auto M = std::min( // NOLINT(readability-identifier-naming)
std::max((int)hnsw_params.max_degree, MINIMAL_M),
MAXIMAL_M);
Expand Down Expand Up @@ -995,23 +999,17 @@ HNSW::init_feature_list() {
}

// Appends this index's vectors, level-0 neighbor lists and mapped ids to the
// caller-supplied containers. Returns false without touching them when
// use_static_ is set, true otherwise.
//
// NOTE(review): this span is a rendered diff hunk. The (dataset, graph, func)
// signature and the statements using dataset->... / dataset_vectors /
// dataset_ids / graph[offset] appear to be the pre-change version; the
// (data, graph, ids, func, allocator) signature and the statements using
// data->... / graph->... / ids.push_back appear to be the post-change one.
// Part of the loop body is collapsed in this view ("Expand All" marker), so
// how `new_id` is produced and when `bitset` is set is not visible here.
//
// Post-change parameters:
//   data      - receives each vector via InsertVector(); its total_count_ on
//               entry is used as the base offset for everything appended here.
//   graph     - receives each neighbor list via InsertNeighborsById().
//   ids       - receives one id per appended element.
//   func      - id mapping function; its use is in the collapsed region.
//   allocator - used to construct the temporary per-node edge vector.
bool
HNSW::ExtractDataAndGraph(const DatasetPtr& dataset,
Vector<Vector<uint32_t>>& graph,
IdMapFunction func) {
HNSW::ExtractDataAndGraph(FlattenInterfacePtr& data,
GraphInterfacePtr& graph,
Vector<LabelType>& ids,
IdMapFunction func,
Allocator* allocator) {
if (use_static_) {
return false;
}
auto hnsw = std::static_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_);
auto cur_element_count = hnsw->getCurrentElementCount();
int64_t origin_data_num = dataset->GetNumElements();
int64_t origin_data_dim = dataset->GetDim();
auto dataset_vectors = dataset->GetFloat32Vectors();
auto dataset_ids = const_cast<int64_t*>(dataset->GetIds());
CHECK_ARGUMENT(
origin_data_dim == dim_,
fmt::format("origin_data_dim({}) is not equal to dim_({}) when extract data in hnsw",
origin_data_dim,
dim_));
// Number of elements already in `data` before this index contributes any.
int64_t origin_data_num = data->total_count_;
int64_t valid_id_count = 0;
BitsetPtr bitset = std::make_shared<BitsetImpl>();
for (auto i = 0; i < cur_element_count; ++i) {
Expand All @@ -1022,82 +1020,71 @@ HNSW::ExtractDataAndGraph(const DatasetPtr& dataset,
}
// Destination slot: elements are packed densely after the pre-existing ones.
auto offset = valid_id_count + origin_data_num;
char* vector_data = hnsw->getDataByInternalId(i);
std::memcpy((char*)(dataset_vectors + offset * dim_), vector_data, sizeof(float) * dim_);
int* data = (int*)hnsw->getLinklistAtLevel(i, 0);
size_t size = hnsw->getListCount((unsigned int*)data);
// NOTE(review): the stored bytes are reinterpreted as float* - presumably the
// index holds FP32 vectors; confirm for other data types.
data->InsertVector(reinterpret_cast<float*>(vector_data));
int* link_data = (int*)hnsw->getLinklistAtLevel(i, 0);
size_t size = hnsw->getListCount((unsigned int*)link_data);
Vector<InnerIdType> edge(allocator);
// Neighbor ids are read starting one int past the head of the link list.
for (int j = 0; j < size; ++j) {
if (not bitset->Test(*(data + 1 + j))) {
graph[offset].push_back(origin_data_num + *(data + 1 + j));
// Neighbors flagged in `bitset` are dropped; the rest are shifted by
// origin_data_num.
// NOTE(review): the shift uses the source internal id, while the node itself
// is stored at origin_data_num + valid_id_count. These agree only if no
// earlier element was skipped - confirm against the collapsed region.
if (not bitset->Test(*(link_data + 1 + j))) {
edge.push_back(origin_data_num + *(link_data + 1 + j));
}
}
dataset_ids[offset] = new_id;
graph->InsertNeighborsById(offset, edge);
graph->IncreaseTotalCount(1);
ids.push_back(new_id);
valid_id_count++;
}
dataset->NumElements(origin_data_num + valid_id_count);
return true;
}

// Forwards the supplied data, graph and ids to the underlying
// hnswlib::HierarchicalNSW::setDataAndGraph. Returns false without doing
// anything when use_static_ is set, true otherwise.
//
// NOTE(review): this span is a rendered diff hunk; the (dataset, graph)
// signature and the five-argument setDataAndGraph call appear to be the
// pre-change version, the (data, graph, ids) signature and three-argument call
// the post-change one.
bool
HNSW::SetDataAndGraph(const DatasetPtr& dataset, const Vector<Vector<uint32_t>>& graph) {
HNSW::SetDataAndGraph(FlattenInterfacePtr& data, GraphInterfacePtr& graph, Vector<LabelType>& ids) {
if (use_static_) {
return false;
}
// alg_hnsw_ is downcast unconditionally with static_pointer_cast.
// NOTE(review): presumably always a HierarchicalNSW once use_static_ is false - confirm.
auto hnsw = std::static_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_);
hnsw->setDataAndGraph(dataset->GetFloat32Vectors(),
dataset->GetIds(),
dataset->GetNumElements(),
dataset->GetDim(),
graph);
hnsw->setDataAndGraph(data, graph, ids);
return true;
}

// Calls HNSW::ExtractDataAndGraph on every merge unit's index, in order,
// passing that unit's id_map_func, so that all units append into the same
// data / graph / ids containers.
//
// NOTE(review): this span is a rendered diff hunk; the (merge_units, dataset,
// graph) parameter list and three-argument ExtractDataAndGraph call appear to
// be the pre-change version, the five-parameter list and five-argument call
// the post-change one.
void
extract_data_and_graph(const std::vector<MergeUnit>& merge_units,
const DatasetPtr& dataset,
Vector<Vector<uint32_t>>& graph) {
FlattenInterfacePtr& data,
GraphInterfacePtr& graph,
Vector<LabelType>& ids,
Allocator* allocator) {
for (const auto& merge_unit : merge_units) {
// NOTE(review): the stats are fetched and parsed and index_name is read, but
// index_name is not used in the visible code - confirm whether a type check
// on it was intended.
auto stat_string = merge_unit.index->GetStats();
auto stats = JsonType::parse(stat_string);
std::string index_name = stats[STATSTIC_INDEX_NAME];
// NOTE(review): the dynamic_pointer_cast result is dereferenced without a
// null check, and the bool returned by ExtractDataAndGraph is discarded.
auto hnsw = std::dynamic_pointer_cast<HNSW>(merge_unit.index);
hnsw->ExtractDataAndGraph(dataset, graph, merge_unit.id_map_func);
hnsw->ExtractDataAndGraph(data, graph, ids, merge_unit.id_map_func, allocator);
}
}

// Merges the given indexes into this one. It extracts this index's own
// data/graph/ids into temporary containers, appends every merge unit's
// data/graph/ids after them, then loads the combined result back into this
// index via SetDataAndGraph.
//
// NOTE(review): this span is a rendered diff hunk. The statements building a
// Dataset with allocator-backed `vectors` / `ids` buffers and a
// Vector<Vector<uint32_t>> graph appear to be the pre-change version; the
// statements building FlattenInterface / GraphInterface instances and a
// Vector<LabelType> appear to be the post-change one.
//
// In the post-change path the function always returns an empty success value;
// the bool results of ExtractDataAndGraph and SetDataAndGraph are not inspected.
tl::expected<void, Error>
HNSW::merge(const std::vector<MergeUnit>& merge_units) {
int64_t total_data_num = this->GetNumElements();
for (const auto& merge_unit : merge_units) {
total_data_num += merge_unit.index->GetNumElements();
}
DatasetPtr dataset = Dataset::Make();
auto& allocator = allocator_;
dataset->Owner(true, allocator.get());
// NOTE(review): the two allocations below size their elements with
// sizeof(float*) / sizeof(int64_t*) (pointer sizes), not the element types.
auto vectors = (float*)allocator->Allocate(dim_ * total_data_num * sizeof(float*));
if (vectors == nullptr) {
LOG_ERROR_AND_RETURNS(ErrorType::NO_ENOUGH_MEMORY,
"fail to allocate vectors in the process of merge index");
}
dataset->Float32Vectors(vectors);
auto ids = (int64_t*)allocator->Allocate(total_data_num * sizeof(int64_t*));
if (ids == nullptr) {
LOG_ERROR_AND_RETURNS(ErrorType::NO_ENOUGH_MEMORY,
"fail to allocate ids in the process of merge index");
}
dataset->Ids(ids);
dataset->NumElements(0);
dataset->Dim(dim_);
Vector<Vector<uint32_t>> graph(
total_data_num, Vector<uint32_t>(allocator.get()), allocator.get());
// Temporary vector store: memory IO parameter with an FP32 quantizer parameter.
auto param = std::make_shared<FlattenDataCellParameter>();
param->io_parameter_ = std::make_shared<MemoryIOParameter>();
param->quantizer_parameter_ = std::make_shared<FP32QuantizerParameter>();
// Temporary graph store: memory IO parameter, degree limit set to twice max_degree_.
// NOTE(review): presumably because level 0 of the HNSW graph holds up to
// 2 * max_degree neighbors - confirm.
GraphDataCellParamPtr graph_param_ptr = std::make_shared<GraphDataCellParameter>();
graph_param_ptr->io_parameter_ = std::make_shared<vsag::MemoryIOParameter>();
graph_param_ptr->max_degree_ = max_degree_ * 2;

FlattenInterfacePtr flatten_interface =
FlattenInterface::MakeInstance(param, index_common_param_);
GraphInterfacePtr graph_interface =
GraphInterface::MakeInstance(graph_param_ptr, index_common_param_, false);
Vector<LabelType> ids(allocator_.get());
// extract data and graph
// This index's own ids are kept unchanged: the map accepts every id as-is.
IdMapFunction id_map = [](int64_t id) -> std::tuple<bool, int64_t> {
return std::make_tuple(true, id);
};
this->ExtractDataAndGraph(dataset, graph, id_map);
extract_data_and_graph(merge_units, dataset, graph);
// Order matters: this index is extracted first, so its elements come first
// in the temporary stores and the merge units' elements are appended after them.
this->ExtractDataAndGraph(flatten_interface, graph_interface, ids, id_map, allocator_.get());
extract_data_and_graph(merge_units, flatten_interface, graph_interface, ids, allocator_.get());
// TODO(inabao): merge graph
// set graph
SetDataAndGraph(dataset, graph);
SetDataAndGraph(flatten_interface, graph_interface, ids);
return {};
}
}

Expand Down
15 changes: 11 additions & 4 deletions src/index/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include "algorithm/hnswlib/hnswlib.h"
#include "base_filter_functor.h"
#include "common.h"
#include "data_cell/flatten_interface.h"
#include "data_cell/graph_interface.h"
#include "data_type.h"
#include "hnsw_zparameters.h"
#include "impl/conjugate_graph.h"
Expand Down Expand Up @@ -216,12 +218,14 @@ class HNSW : public Index {
InitMemorySpace();

// Appends this index's vectors, neighbor lists and mapped ids to the supplied
// containers; returns a bool status.
// NOTE(review): this span is a rendered diff hunk; the (dataset, graph, func)
// parameter list appears to be the pre-change declaration and the
// (data, graph, ids, func, allocator) parameter list the post-change one.
// `func` is taken by value; `allocator` is a raw, non-owning pointer as far as
// this declaration shows.
bool
ExtractDataAndGraph(const DatasetPtr& dataset,
Vector<Vector<uint32_t>>& graph,
IdMapFunction func);
ExtractDataAndGraph(FlattenInterfacePtr& data,
GraphInterfacePtr& graph,
Vector<LabelType>& ids,
IdMapFunction func,
Allocator* allocator);

// Loads the supplied vectors, neighbor lists and ids into this index; returns
// a bool status.
// NOTE(review): this span is a rendered diff hunk; the (dataset, graph)
// declaration appears to be the pre-change version and the (data, graph, ids)
// declaration the post-change one.
bool
SetDataAndGraph(const DatasetPtr& dataset, const Vector<Vector<uint32_t>>& graph);
SetDataAndGraph(FlattenInterfacePtr& data, GraphInterfacePtr& graph, Vector<LabelType>& ids);

private:
tl::expected<std::vector<int64_t>, Error>
Expand Down Expand Up @@ -324,6 +328,8 @@ class HNSW : public Index {
bool empty_index_ = false;
bool use_reversed_edges_ = false;
bool is_init_memory_ = false;
int64_t max_degree_{0};

DataTypes type_;

std::shared_ptr<Allocator> allocator_;
Expand All @@ -334,6 +340,7 @@ class HNSW : public Index {
mutable std::shared_mutex rw_mutex_;

IndexFeatureList feature_list_{};
// Common index parameters captured at construction time and reused later when
// building temporary data cells (e.g. during merge).
// Held by value rather than by reference: the constructor receives a
// `const IndexCommonParam&`, and a reference member bound to it would dangle
// as soon as the caller's object (or a temporary) is destroyed while this
// index is still alive.
IndexCommonParam index_common_param_;
};

} // namespace vsag
24 changes: 16 additions & 8 deletions src/index/hnsw_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@

#include "../data_type.h"
#include "../logger.h"
#include "data_cell/graph_datacell_parameter.h"
#include "fixtures.h"
#include "io/memory_io_parameter.h"
#include "quantization/fp32_quantizer_parameter.h"
#include "vsag/bitset.h"
#include "vsag/errors.h"
#include "vsag/options.h"
Expand Down Expand Up @@ -966,21 +969,26 @@ TEST_CASE("extract/set data and graph", "[ut][hnsw]") {
auto result = index->Build(dataset);
REQUIRE(result.has_value());

auto new_dataset = Dataset::Make();
auto new_ids = new int64_t[num_elements];
auto new_vectors = new float[num_elements * dim];
Vector<Vector<uint32_t>> graph(
num_elements, Vector<uint32_t>(allocator.get()), allocator.get());
new_dataset->NumElements(0)->Dim(dim)->Float32Vectors(new_vectors)->Ids(new_ids);
auto param = std::make_shared<FlattenDataCellParameter>();
param->io_parameter_ = std::make_shared<vsag::MemoryIOParameter>();
param->quantizer_parameter_ = std::make_shared<vsag::FP32QuantizerParameter>();
vsag::GraphDataCellParamPtr graph_param_ptr = std::make_shared<vsag::GraphDataCellParameter>();
graph_param_ptr->io_parameter_ = std::make_shared<vsag::MemoryIOParameter>();

FlattenInterfacePtr flatten_interface = FlattenInterface::MakeInstance(param, commom_param);
GraphInterfacePtr graph_interface =
GraphInterface::MakeInstance(graph_param_ptr, commom_param, false);
Vector<LabelType> ids_vector(allocator.get());

IdMapFunction id_map = [](int64_t id) -> std::tuple<bool, int64_t> {
return std::make_tuple(true, id);
};
REQUIRE(index->ExtractDataAndGraph(new_dataset, graph, id_map));
REQUIRE(index->ExtractDataAndGraph(
flatten_interface, graph_interface, ids_vector, id_map, allocator.get()));

auto another_index = std::make_shared<HNSW>(hnsw_obj, commom_param);
another_index->InitMemorySpace();
REQUIRE(another_index->SetDataAndGraph(new_dataset, graph));
REQUIRE(another_index->SetDataAndGraph(flatten_interface, graph_interface, ids_vector));

dataset->Dim(dim)
->NumElements(num_elements / 2)
Expand Down

0 comments on commit dc79928

Please sign in to comment.