Skip to content

Commit

Permalink
add graph_datacell and sparse_graph_datacell
Browse files Browse the repository at this point in the history
- graph_datacell for bottom graph
- sparse_graph_datacell for high level graph

Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
  • Loading branch information
LHT129 committed Oct 31, 2024
1 parent 5e58d3e commit 694bb38
Showing 14 changed files with 670 additions and 39 deletions.
8 changes: 4 additions & 4 deletions src/data_cell/flatten_datacell.h
Original file line number Diff line number Diff line change
@@ -30,8 +30,8 @@ class FlattenDataCell : public FlattenInterface {
public:
FlattenDataCell() = default;

explicit FlattenDataCell(const nlohmann::json& quantization_obj,
const nlohmann::json& io_obj,
explicit FlattenDataCell(const JsonType& quantization_obj,
const JsonType& io_obj,
const IndexCommonParam& common_param);

void
@@ -121,8 +121,8 @@ class FlattenDataCell : public FlattenInterface {
};

template <typename QuantTmpl, typename IOTmpl>
FlattenDataCell<QuantTmpl, IOTmpl>::FlattenDataCell(const nlohmann::json& quantization_obj,
const nlohmann::json& io_obj,
FlattenDataCell<QuantTmpl, IOTmpl>::FlattenDataCell(const JsonType& quantization_obj,
const JsonType& io_obj,
const IndexCommonParam& common_param)
: allocator_(common_param.allocator_) {
this->quantizer_ = std::make_shared<QuantTmpl>(quantization_obj, common_param);
24 changes: 12 additions & 12 deletions src/data_cell/flatten_datacell_test.cpp
Original file line number Diff line number Diff line change
@@ -30,8 +30,8 @@ template <typename QuantTmpl, typename IOTmpl, MetricType metric>
void
TestFlattenDataCell(int dim,
std::shared_ptr<Allocator> allocator,
const nlohmann::json& quantizer_json,
const nlohmann::json& io_json,
const JsonType& quantizer_json,
const JsonType& io_json,
float error = 1e-5) {
auto counts = {100, 1000};
IndexCommonParam common;
@@ -53,8 +53,8 @@ template <typename IOTmpl>
void
TestFlattenDataCellFP32(int dim,
std::shared_ptr<Allocator> allocator,
const nlohmann::json& quantizer_json,
const nlohmann::json& io_json,
const JsonType& quantizer_json,
const JsonType& io_json,
float error = 1e-5) {
constexpr MetricType metrics[3] = {
MetricType::METRIC_TYPE_L2SQR, MetricType::METRIC_TYPE_COSINE, MetricType::METRIC_TYPE_IP};
@@ -66,10 +66,10 @@ TestFlattenDataCellFP32(int dim,
dim, allocator, quantizer_json, io_json, error);
}

TEST_CASE("fp32 [ut][flatten_data_cell]") {
TEST_CASE("fp32", "[ut][flatten_data_cell]") {
auto allocator = std::make_shared<DefaultAllocator>();
auto fp32_param = nlohmann::json::parse("{}");
auto io_param = nlohmann::json::parse("{}");
auto fp32_param = JsonType::parse("{}");
auto io_param = JsonType::parse("{}");
auto dims = {8, 64, 512};
float error = 1e-5;
for (auto dim : dims) {
@@ -82,8 +82,8 @@ template <typename IOTmpl>
void
TestFlattenDataCellSQ8(int dim,
std::shared_ptr<Allocator> allocator,
const nlohmann::json& quantizer_json,
const nlohmann::json& io_json,
const JsonType& quantizer_json,
const JsonType& io_json,
float error = 1e-5) {
constexpr MetricType metrics[3] = {
MetricType::METRIC_TYPE_L2SQR, MetricType::METRIC_TYPE_COSINE, MetricType::METRIC_TYPE_IP};
@@ -95,10 +95,10 @@ TestFlattenDataCellSQ8(int dim,
dim, allocator, quantizer_json, io_json, error);
}

TEST_CASE("sq8 [ut][flatten_data_cell]") {
TEST_CASE("sq8", "[ut][flatten_data_cell]") {
auto allocator = std::make_shared<DefaultAllocator>();
auto sq8_param = nlohmann::json::parse("{}");
auto io_param = nlohmann::json::parse("{}");
auto sq8_param = JsonType::parse("{}");
auto io_param = JsonType::parse("{}");
auto dims = {32, 64, 512};
auto error = 2e-2f;
for (auto dim : dims) {
9 changes: 4 additions & 5 deletions src/data_cell/flatten_interface.cpp
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ namespace vsag {

template <typename QuantTemp, typename IOTemp>
static FlattenInterfacePtr
make_instance(const nlohmann::json& json_obj, const IndexCommonParam& common_param) {
make_instance(const JsonType& json_obj, const IndexCommonParam& common_param) {
CHECK_ARGUMENT(
json_obj.contains(QUANTIZATION_PARAMS_KEY),
fmt::format("flatten interface parameters must contains {}", QUANTIZATION_PARAMS_KEY));
@@ -39,7 +39,7 @@ make_instance(const nlohmann::json& json_obj, const IndexCommonParam& common_par

template <MetricType metric, typename IOTemp>
static FlattenInterfacePtr
make_instance(const nlohmann::json& json_obj, const IndexCommonParam& common_param) {
make_instance(const JsonType& json_obj, const IndexCommonParam& common_param) {
CHECK_ARGUMENT(
json_obj.contains(QUANTIZATION_TYPE_KEY),
fmt::format("flatten interface parameters must contains {}", QUANTIZATION_TYPE_KEY));
@@ -55,7 +55,7 @@ make_instance(const nlohmann::json& json_obj, const IndexCommonParam& common_par

template <typename IOTemp>
static FlattenInterfacePtr
make_instance(const nlohmann::json& json_obj, const IndexCommonParam& common_param) {
make_instance(const JsonType& json_obj, const IndexCommonParam& common_param) {
auto metric = common_param.metric_;
if (metric == MetricType::METRIC_TYPE_L2SQR) {
return make_instance<MetricType::METRIC_TYPE_L2SQR, IOTemp>(json_obj, common_param);
@@ -70,8 +70,7 @@ make_instance(const nlohmann::json& json_obj, const IndexCommonParam& common_par
}

FlattenInterfacePtr
FlattenInterface::MakeInstance(const nlohmann::json& json_obj,
const IndexCommonParam& common_param) {
FlattenInterface::MakeInstance(const JsonType& json_obj, const IndexCommonParam& common_param) {
CHECK_ARGUMENT(json_obj.contains(IO_TYPE_KEY),
fmt::format("flatten interface parameters must contains {}", IO_TYPE_KEY));
std::string io_string = json_obj[IO_TYPE_KEY];
2 changes: 1 addition & 1 deletion src/data_cell/flatten_interface.h
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ class FlattenInterface {
FlattenInterface() = default;

static FlattenInterfacePtr
MakeInstance(const nlohmann::json& json_obj, const IndexCommonParam& common_param);
MakeInstance(const JsonType& json_obj, const IndexCommonParam& common_param);

public:
virtual void
155 changes: 155 additions & 0 deletions src/data_cell/graph_datacell.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <limits>
#include <memory>
#include <nlohmann/json.hpp>
#include <unordered_map>
#include <vector>

#include "algorithm/hnswlib/hnswalg.h"
#include "common.h"
#include "graph_interface.h"
#include "index/index_common_param.h"
#include "io/basic_io.h"
#include "vsag/constants.h"

namespace vsag {

/**
* built by nn-descent or incremental insertion
* add neighbors and pruning
* retrieve neighbors
*/
template <typename IOTmpl, bool is_adapter>
class GraphDataCell;

template <typename IOTmpl>
class GraphDataCell<IOTmpl, false> : public GraphInterface {
public:
GraphDataCell(const JsonType& graph_json_params,
const JsonType& io_json_params,
const IndexCommonParam& common_param);

void
InsertNeighborsById(InnerIdType id, const Vector<InnerIdType>& neighbor_ids) override;

[[nodiscard]] uint32_t
GetNeighborSize(InnerIdType id) const override;

void
GetNeighbors(InnerIdType id, Vector<InnerIdType>& neighbor_ids) const override;

inline void
SetIO(std::shared_ptr<BasicIO<IOTmpl>> io) {
this->io_ = io;
}

/****
* prefetch neighbors of a base point with id
* @param id of base point
* @param neighbor_i index of neighbor, 0 for neighbor size, 1 for first neighbor
*/
void
Prefetch(InnerIdType id, uint32_t neighbor_i) override {
io_->Prefetch(id * this->code_line_size_ + sizeof(uint32_t) +
neighbor_i * sizeof(InnerIdType));
}

void
Serialize(StreamWriter& writer) override;

void
Deserialize(StreamReader& reader) override;

private:
std::shared_ptr<BasicIO<IOTmpl>> io_{nullptr};

uint32_t code_line_size_{0};
};

template <typename IOTmpl>
GraphDataCell<IOTmpl, false>::GraphDataCell(const JsonType& graph_json_params,
const JsonType& io_json_params,
const IndexCommonParam& common_param) {
this->io_ = std::make_shared<IOTmpl>(io_json_params, common_param);
if (graph_json_params.contains(GRAPH_PARAM_MAX_DEGREE)) {
this->maximum_degree_ = graph_json_params[GRAPH_PARAM_MAX_DEGREE];
}

if (graph_json_params.contains(GRAPH_PARAM_INIT_MAX_CAPACITY)) {
this->max_capacity_ = graph_json_params[GRAPH_PARAM_INIT_MAX_CAPACITY];
}

this->code_line_size_ = this->maximum_degree_ * sizeof(InnerIdType) + sizeof(uint32_t);
}

template <typename IOTmpl>
void
GraphDataCell<IOTmpl, false>::InsertNeighborsById(InnerIdType id,
const Vector<InnerIdType>& neighbor_ids) {
if (neighbor_ids.size() > this->maximum_degree_) {
logger::error(fmt::format(
"insert neighbors count {} more than {}", neighbor_ids.size(), this->maximum_degree_));
}
this->max_capacity_ = std::max(this->max_capacity_, id + 1);
auto start = id * this->code_line_size_;
uint32_t neighbor_count = std::min((uint32_t)(neighbor_ids.size()), this->maximum_degree_);
this->io_->Write((uint8_t*)(&neighbor_count), sizeof(neighbor_count), start);
start += sizeof(neighbor_count);
this->io_->Write((uint8_t*)(neighbor_ids.data()), neighbor_count * sizeof(InnerIdType), start);
}

template <typename IOTmpl>
uint32_t
GraphDataCell<IOTmpl, false>::GetNeighborSize(InnerIdType id) const {
auto start = id * this->code_line_size_;
uint32_t result = 0;
this->io_->Read(sizeof(result), start, (uint8_t*)(&result));
return result;
}

template <typename IOTmpl>
void
GraphDataCell<IOTmpl, false>::GetNeighbors(InnerIdType id,
Vector<InnerIdType>& neighbor_ids) const {
auto start = id * this->code_line_size_;
uint32_t neighbor_count = 0;
this->io_->Read(sizeof(neighbor_count), start, (uint8_t*)(&neighbor_count));
neighbor_ids.resize(neighbor_count);
start += sizeof(neighbor_count);
this->io_->Read(
neighbor_ids.size() * sizeof(InnerIdType), start, (uint8_t*)(neighbor_ids.data()));
}

template <typename IOTmpl>
void
GraphDataCell<IOTmpl, false>::Serialize(StreamWriter& writer) {
GraphInterface::Serialize(writer);
this->io_->Serialize(writer);
StreamWriter::WriteObj(writer, this->code_line_size_);
}

template <typename IOTmpl>
void
GraphDataCell<IOTmpl, false>::Deserialize(StreamReader& reader) {
GraphInterface::Deserialize(reader);
this->io_->Deserialize(reader);
StreamReader::ReadObj(reader, this->code_line_size_);
}

} // namespace vsag
67 changes: 67 additions & 0 deletions src/data_cell/graph_datacell_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "graph_datacell.h"

#include "catch2/catch_template_test_macros.hpp"
#include "fmt/format-inl.h"
#include "graph_interface_test.h"
#include "io/io_headers.h"
using namespace vsag;

template <typename IOTemp>
void
TestGraphDataCell(const JsonType& graph_param,
const JsonType& io_param,
const IndexCommonParam& param) {
auto counts = {1000, 2000};
auto max_id = 10000;
for (auto count : counts) {
auto graph = std::make_shared<GraphDataCell<IOTemp, false>>(graph_param, io_param, param);
GraphInterfaceTest test(graph);
auto other = std::make_shared<GraphDataCell<IOTemp, false>>(graph_param, io_param, param);
test.BasicTest(max_id, count, other);
}
}

TEST_CASE("graph basic test", "[ut][graph_datacell]") {
auto allocator = std::make_shared<DefaultAllocator>();
auto dims = {32, 64};
auto max_degrees = {5, 12, 24, 32, 64, 128};
auto max_capacities = {1, 100, 10000, 10'000'000, 32'179'837};
std::vector<JsonType> graph_params;
std::string graph_param_temp = R"(
{{
"max_degree": {},
"init_capacity": {}
}}
)";
for (auto degree : max_degrees) {
for (auto capacity : max_capacities) {
auto str = fmt::format(graph_param_temp, degree, capacity);
graph_params.emplace_back(JsonType::parse(str));
}
}
auto io_param = JsonType::parse("{}");
for (auto dim : dims) {
IndexCommonParam param;
param.dim_ = dim;
param.allocator_ = allocator.get();
for (auto& gp : graph_params) {
TestGraphDataCell<MemoryIO>(gp, io_param, param);
TestGraphDataCell<MemoryBlockIO>(gp, io_param, param);
}
}
}
Loading

0 comments on commit 694bb38

Please sign in to comment.