Skip to content

Commit

Permalink
add graph_datacell and sparse_graph_datacell (#96)
Browse files Browse the repository at this point in the history
- graph_datacell for bottom graph
- sparse_graph_datacell for high level graph

Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
  • Loading branch information
LHT129 authored Nov 4, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 185eccf commit 385a5d9
Showing 18 changed files with 699 additions and 62 deletions.
12 changes: 6 additions & 6 deletions src/data_cell/flatten_datacell.h
Original file line number Diff line number Diff line change
@@ -30,8 +30,8 @@ class FlattenDataCell : public FlattenInterface {
public:
FlattenDataCell() = default;

explicit FlattenDataCell(const nlohmann::json& quantization_obj,
const nlohmann::json& io_obj,
explicit FlattenDataCell(const JsonType& quantization_param,
const JsonType& io_param,
const IndexCommonParam& common_param);

void
@@ -121,12 +121,12 @@ class FlattenDataCell : public FlattenInterface {
};

template <typename QuantTmpl, typename IOTmpl>
FlattenDataCell<QuantTmpl, IOTmpl>::FlattenDataCell(const nlohmann::json& quantization_obj,
const nlohmann::json& io_obj,
FlattenDataCell<QuantTmpl, IOTmpl>::FlattenDataCell(const JsonType& quantization_param,
const JsonType& io_param,
const IndexCommonParam& common_param)
: allocator_(common_param.allocator_) {
this->quantizer_ = std::make_shared<QuantTmpl>(quantization_obj, common_param);
this->io_ = std::make_shared<IOTmpl>(io_obj, common_param);
this->quantizer_ = std::make_shared<QuantTmpl>(quantization_param, common_param);
this->io_ = std::make_shared<IOTmpl>(io_param, common_param);
this->code_size_ = quantizer_->GetCodeSize();
}

24 changes: 12 additions & 12 deletions src/data_cell/flatten_datacell_test.cpp
Original file line number Diff line number Diff line change
@@ -30,8 +30,8 @@ template <typename QuantTmpl, typename IOTmpl, MetricType metric>
void
TestFlattenDataCell(int dim,
std::shared_ptr<Allocator> allocator,
const nlohmann::json& quantizer_json,
const nlohmann::json& io_json,
const JsonType& quantizer_json,
const JsonType& io_json,
float error = 1e-5) {
auto counts = {100, 1000};
IndexCommonParam common;
@@ -53,8 +53,8 @@ template <typename IOTmpl>
void
TestFlattenDataCellFP32(int dim,
std::shared_ptr<Allocator> allocator,
const nlohmann::json& quantizer_json,
const nlohmann::json& io_json,
const JsonType& quantizer_json,
const JsonType& io_json,
float error = 1e-5) {
constexpr MetricType metrics[3] = {
MetricType::METRIC_TYPE_L2SQR, MetricType::METRIC_TYPE_COSINE, MetricType::METRIC_TYPE_IP};
@@ -66,10 +66,10 @@ TestFlattenDataCellFP32(int dim,
dim, allocator, quantizer_json, io_json, error);
}

TEST_CASE("fp32 [ut][flatten_data_cell]") {
TEST_CASE("fp32", "[ut][flatten_data_cell]") {
auto allocator = std::make_shared<DefaultAllocator>();
auto fp32_param = nlohmann::json::parse("{}");
auto io_param = nlohmann::json::parse("{}");
auto fp32_param = JsonType::parse("{}");
auto io_param = JsonType::parse("{}");
auto dims = {8, 64, 512};
float error = 1e-5;
for (auto dim : dims) {
@@ -82,8 +82,8 @@ template <typename IOTmpl>
void
TestFlattenDataCellSQ8(int dim,
std::shared_ptr<Allocator> allocator,
const nlohmann::json& quantizer_json,
const nlohmann::json& io_json,
const JsonType& quantizer_json,
const JsonType& io_json,
float error = 1e-5) {
constexpr MetricType metrics[3] = {
MetricType::METRIC_TYPE_L2SQR, MetricType::METRIC_TYPE_COSINE, MetricType::METRIC_TYPE_IP};
@@ -95,10 +95,10 @@ TestFlattenDataCellSQ8(int dim,
dim, allocator, quantizer_json, io_json, error);
}

TEST_CASE("sq8 [ut][flatten_data_cell]") {
TEST_CASE("sq8", "[ut][flatten_data_cell]") {
auto allocator = std::make_shared<DefaultAllocator>();
auto sq8_param = nlohmann::json::parse("{}");
auto io_param = nlohmann::json::parse("{}");
auto sq8_param = JsonType::parse("{}");
auto io_param = JsonType::parse("{}");
auto dims = {32, 64, 512};
auto error = 2e-2f;
for (auto dim : dims) {
41 changes: 23 additions & 18 deletions src/data_cell/flatten_interface.cpp
Original file line number Diff line number Diff line change
@@ -27,59 +27,64 @@ namespace vsag {

template <typename QuantTemp, typename IOTemp>
static FlattenInterfacePtr
make_instance(const nlohmann::json& json_obj, const IndexCommonParam& common_param) {
make_instance(const JsonType& flatten_interface_param, const IndexCommonParam& common_param) {
CHECK_ARGUMENT(
json_obj.contains(QUANTIZATION_PARAMS_KEY),
flatten_interface_param.contains(QUANTIZATION_PARAMS_KEY),
fmt::format("flatten interface parameters must contains {}", QUANTIZATION_PARAMS_KEY));
CHECK_ARGUMENT(json_obj.contains(IO_PARAMS_KEY),
CHECK_ARGUMENT(flatten_interface_param.contains(IO_PARAMS_KEY),
fmt::format("flatten interface parameters must contains {}", IO_PARAMS_KEY));
return std::make_shared<FlattenDataCell<QuantTemp, IOTemp>>(
json_obj[QUANTIZATION_PARAMS_KEY], json_obj[IO_PARAMS_KEY], common_param);
flatten_interface_param[QUANTIZATION_PARAMS_KEY],
flatten_interface_param[IO_PARAMS_KEY],
common_param);
}

template <MetricType metric, typename IOTemp>
static FlattenInterfacePtr
make_instance(const nlohmann::json& json_obj, const IndexCommonParam& common_param) {
make_instance(const JsonType& flatten_interface_param, const IndexCommonParam& common_param) {
CHECK_ARGUMENT(
json_obj.contains(QUANTIZATION_TYPE_KEY),
flatten_interface_param.contains(QUANTIZATION_TYPE_KEY),
fmt::format("flatten interface parameters must contains {}", QUANTIZATION_TYPE_KEY));

std::string quantization_string = json_obj[QUANTIZATION_TYPE_KEY];
std::string quantization_string = flatten_interface_param[QUANTIZATION_TYPE_KEY];
if (quantization_string == QUANTIZATION_TYPE_VALUE_SQ8) {
return make_instance<SQ8Quantizer<metric>, IOTemp>(json_obj, common_param);
return make_instance<SQ8Quantizer<metric>, IOTemp>(flatten_interface_param, common_param);
} else if (quantization_string == QUANTIZATION_TYPE_VALUE_FP32) {
return make_instance<FP32Quantizer<metric>, IOTemp>(json_obj, common_param);
return make_instance<FP32Quantizer<metric>, IOTemp>(flatten_interface_param, common_param);
}
return nullptr;
}

template <typename IOTemp>
static FlattenInterfacePtr
make_instance(const nlohmann::json& json_obj, const IndexCommonParam& common_param) {
make_instance(const JsonType& flatten_interface_param, const IndexCommonParam& common_param) {
auto metric = common_param.metric_;
if (metric == MetricType::METRIC_TYPE_L2SQR) {
return make_instance<MetricType::METRIC_TYPE_L2SQR, IOTemp>(json_obj, common_param);
return make_instance<MetricType::METRIC_TYPE_L2SQR, IOTemp>(flatten_interface_param,
common_param);
}
if (metric == MetricType::METRIC_TYPE_IP) {
return make_instance<MetricType::METRIC_TYPE_IP, IOTemp>(json_obj, common_param);
return make_instance<MetricType::METRIC_TYPE_IP, IOTemp>(flatten_interface_param,
common_param);
}
if (metric == MetricType::METRIC_TYPE_COSINE) {
return make_instance<MetricType::METRIC_TYPE_COSINE, IOTemp>(json_obj, common_param);
return make_instance<MetricType::METRIC_TYPE_COSINE, IOTemp>(flatten_interface_param,
common_param);
}
return nullptr;
}

FlattenInterfacePtr
FlattenInterface::MakeInstance(const nlohmann::json& json_obj,
FlattenInterface::MakeInstance(const JsonType& flatten_interface_param,
const IndexCommonParam& common_param) {
CHECK_ARGUMENT(json_obj.contains(IO_TYPE_KEY),
CHECK_ARGUMENT(flatten_interface_param.contains(IO_TYPE_KEY),
fmt::format("flatten interface parameters must contains {}", IO_TYPE_KEY));
std::string io_string = json_obj[IO_TYPE_KEY];
std::string io_string = flatten_interface_param[IO_TYPE_KEY];
if (io_string == IO_TYPE_VALUE_BLOCK_MEMORY_IO) {
return make_instance<MemoryBlockIO>(json_obj, common_param);
return make_instance<MemoryBlockIO>(flatten_interface_param, common_param);
}
if (io_string == IO_TYPE_VALUE_MEMORY_IO) {
return make_instance<MemoryIO>(json_obj, common_param);
return make_instance<MemoryIO>(flatten_interface_param, common_param);
}
return nullptr;
}
2 changes: 1 addition & 1 deletion src/data_cell/flatten_interface.h
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ class FlattenInterface {
FlattenInterface() = default;

static FlattenInterfacePtr
MakeInstance(const nlohmann::json& json_obj, const IndexCommonParam& common_param);
MakeInstance(const JsonType& flatten_interface_param, const IndexCommonParam& common_param);

public:
virtual void
155 changes: 155 additions & 0 deletions src/data_cell/graph_datacell.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <limits>
#include <memory>
#include <nlohmann/json.hpp>
#include <unordered_map>
#include <vector>

#include "algorithm/hnswlib/hnswalg.h"
#include "common.h"
#include "graph_interface.h"
#include "index/index_common_param.h"
#include "io/basic_io.h"
#include "vsag/constants.h"

namespace vsag {

/**
 * Graph storage cell: keeps, per point id, the list of its neighbor ids.
 *
 * - built by nn-descent or incremental insertion
 * - add neighbors and pruning
 * - retrieve neighbors
 *
 * Primary template, declared only. The part of this header shown here defines
 * the `is_adapter == false` partial specialization, which stores the
 * adjacency lists through an IO backend of type `IOTmpl`.
 *
 * @tparam IOTmpl      IO backend type used to read/write the adjacency records
 * @tparam is_adapter  selects the specialization (false: IO-backed storage)
 */
template <typename IOTmpl, bool is_adapter>
class GraphDataCell;

/**
 * IO-backed graph storage (`is_adapter == false`).
 *
 * Every point id owns one fixed-size record of `code_line_size_` bytes located
 * at byte offset `id * code_line_size_` inside the IO object. A record is laid
 * out as a `uint32_t` neighbor count followed by up to `maximum_degree_`
 * neighbor ids of type `InnerIdType`.
 *
 * @tparam IOTmpl IO backend; held as `BasicIO<IOTmpl>`
 */
template <typename IOTmpl>
class GraphDataCell<IOTmpl, false> : public GraphInterface {
public:
    /**
     * @param graph_param  graph settings; GRAPH_PARAM_MAX_DEGREE and
     *                     GRAPH_PARAM_INIT_MAX_CAPACITY are read when present
     * @param io_param     parameters forwarded to the IOTmpl constructor
     * @param common_param common index parameters forwarded to the IOTmpl constructor
     */
    GraphDataCell(const JsonType& graph_param,
                  const JsonType& io_param,
                  const IndexCommonParam& common_param);

    /** Store the neighbor list of `id`, truncated to the maximum degree. */
    void
    InsertNeighborsById(InnerIdType id, const Vector<InnerIdType>& neighbor_ids) override;

    /** @return the neighbor count stored in the record of `id`. */
    [[nodiscard]] uint32_t
    GetNeighborSize(InnerIdType id) const override;

    /** Load the neighbor ids stored for `id` into `neighbor_ids`. */
    void
    GetNeighbors(InnerIdType id, Vector<InnerIdType>& neighbor_ids) const override;

    /** Replace the IO backend used to store the adjacency records. */
    inline void
    SetIO(std::shared_ptr<BasicIO<IOTmpl>> io) {
        this->io_ = io;
    }

    /****
     * prefetch neighbors of a base point with id
     * @param id of base point
     * @param neighbor_i zero-based index of the neighbor inside the record of
     *        `id`; the leading uint32_t count field is always skipped, so 0
     *        addresses the first neighbor id
     */
    void
    Prefetch(InnerIdType id, uint32_t neighbor_i) override {
        // Widen before multiplying: id * code_line_size_ evaluated in the
        // operands' own (narrower) width can wrap for large graphs and would
        // then point at an unrelated record.
        const auto record_start = static_cast<uint64_t>(id) * this->code_line_size_;
        io_->Prefetch(record_start + sizeof(uint32_t) + neighbor_i * sizeof(InnerIdType));
    }

    /** Write base-class state, the IO content and the record size to `writer`. */
    void
    Serialize(StreamWriter& writer) override;

    /** Restore the state written by Serialize from `reader`. */
    void
    Deserialize(StreamReader& reader) override;

private:
    // backend that holds all adjacency records
    std::shared_ptr<BasicIO<IOTmpl>> io_{nullptr};

    // size in bytes of one record: uint32_t count + maximum_degree_ neighbor ids
    uint32_t code_line_size_{0};
};

/**
 * Build the IO backend and size the per-point record.
 *
 * The maximum degree and the initial capacity keep their existing values
 * unless the corresponding key is present in `graph_param`.
 *
 * @param graph_param  may contain GRAPH_PARAM_MAX_DEGREE and GRAPH_PARAM_INIT_MAX_CAPACITY
 * @param io_param     forwarded to the IOTmpl constructor
 * @param common_param forwarded to the IOTmpl constructor
 */
template <typename IOTmpl>
GraphDataCell<IOTmpl, false>::GraphDataCell(const JsonType& graph_param,
                                            const JsonType& io_param,
                                            const IndexCommonParam& common_param)
    : io_(std::make_shared<IOTmpl>(io_param, common_param)) {
    const bool has_max_degree = graph_param.contains(GRAPH_PARAM_MAX_DEGREE);
    if (has_max_degree) {
        this->maximum_degree_ = graph_param[GRAPH_PARAM_MAX_DEGREE];
    }

    const bool has_init_capacity = graph_param.contains(GRAPH_PARAM_INIT_MAX_CAPACITY);
    if (has_init_capacity) {
        this->max_capacity_ = graph_param[GRAPH_PARAM_INIT_MAX_CAPACITY];
    }

    // one record = neighbor id slots + the leading uint32_t neighbor count
    const auto neighbor_slot_bytes = this->maximum_degree_ * sizeof(InnerIdType);
    this->code_line_size_ = neighbor_slot_bytes + sizeof(uint32_t);
}

/**
 * Store the neighbor list of point `id`.
 *
 * Writes a uint32_t neighbor count followed by the neighbor ids into the
 * record of `id`. When more than `maximum_degree_` neighbors are supplied a
 * warning is logged and only the first `maximum_degree_` ids are written.
 * `max_capacity_` is raised to at least `id + 1`.
 *
 * @param id            point whose record is written
 * @param neighbor_ids  neighbors to store
 */
template <typename IOTmpl>
void
GraphDataCell<IOTmpl, false>::InsertNeighborsById(InnerIdType id,
                                                  const Vector<InnerIdType>& neighbor_ids) {
    if (neighbor_ids.size() > this->maximum_degree_) {
        logger::warn(fmt::format(
            "insert neighbors count {} more than {}", neighbor_ids.size(), this->maximum_degree_));
    }
    this->max_capacity_ = std::max(this->max_capacity_, id + 1);

    // Compute the byte offset in 64 bits: multiplying id by code_line_size_ in
    // their own (narrower) width can wrap for large graphs and the record
    // would then overwrite the data of another point.
    uint64_t start = static_cast<uint64_t>(id) * this->code_line_size_;

    uint32_t neighbor_count = std::min((uint32_t)(neighbor_ids.size()), this->maximum_degree_);
    this->io_->Write((uint8_t*)(&neighbor_count), sizeof(neighbor_count), start);
    start += sizeof(neighbor_count);
    // NOTE(review): this cast drops const from neighbor_ids.data(); it can become
    // a const uint8_t* cast if Write accepts a const pointer - confirm against BasicIO.
    this->io_->Write((uint8_t*)(neighbor_ids.data()), neighbor_count * sizeof(InnerIdType), start);
}

/**
 * Read the neighbor count stored at the start of the record of `id`.
 *
 * @param id point to query
 * @return the uint32_t count read from the IO backend; 0 if the read leaves
 *         the local buffer untouched
 */
template <typename IOTmpl>
uint32_t
GraphDataCell<IOTmpl, false>::GetNeighborSize(InnerIdType id) const {
    // 64-bit offset: id * code_line_size_ in the operands' own width can wrap
    // for large graphs and would read the count of a different point.
    const uint64_t start = static_cast<uint64_t>(id) * this->code_line_size_;
    uint32_t result = 0;
    this->io_->Read(sizeof(result), start, reinterpret_cast<uint8_t*>(&result));
    return result;
}

/**
 * Load the neighbor ids of point `id`.
 *
 * Reads the uint32_t neighbor count from the start of the record, resizes
 * `neighbor_ids` to that count and then reads that many ids into it.
 *
 * @param id            point to query
 * @param neighbor_ids  output; resized to the stored neighbor count
 */
template <typename IOTmpl>
void
GraphDataCell<IOTmpl, false>::GetNeighbors(InnerIdType id,
                                           Vector<InnerIdType>& neighbor_ids) const {
    // 64-bit offset: id * code_line_size_ in the operands' own width can wrap
    // for large graphs and would return the neighbors of a different point.
    uint64_t start = static_cast<uint64_t>(id) * this->code_line_size_;

    uint32_t neighbor_count = 0;
    this->io_->Read(sizeof(neighbor_count), start, reinterpret_cast<uint8_t*>(&neighbor_count));
    neighbor_ids.resize(neighbor_count);

    // the ids follow directly after the count field
    start += sizeof(neighbor_count);
    this->io_->Read(neighbor_ids.size() * sizeof(InnerIdType),
                    start,
                    reinterpret_cast<uint8_t*>(neighbor_ids.data()));
}

/**
 * Write this cell to `writer`.
 *
 * Order of the output: GraphInterface state, then the IO backend, then
 * `code_line_size_`. Deserialize reads the same sequence.
 *
 * @param writer destination stream
 */
template <typename IOTmpl>
void
GraphDataCell<IOTmpl, false>::Serialize(StreamWriter& writer) {
    // base-class part first
    GraphInterface::Serialize(writer);
    // adjacency records held by the IO backend
    io_->Serialize(writer);
    // record size last
    StreamWriter::WriteObj(writer, code_line_size_);
}

/**
 * Restore this cell from `reader`.
 *
 * Reads, in order: GraphInterface state, the IO backend, and
 * `code_line_size_` - the same sequence Serialize writes.
 *
 * @param reader source stream
 */
template <typename IOTmpl>
void
GraphDataCell<IOTmpl, false>::Deserialize(StreamReader& reader) {
    // base-class part first
    GraphInterface::Deserialize(reader);
    // adjacency records held by the IO backend
    io_->Deserialize(reader);
    // record size last
    StreamReader::ReadObj(reader, code_line_size_);
}

} // namespace vsag
Loading

0 comments on commit 385a5d9

Please sign in to comment.