diff --git a/src/data_cell/bucket_datacell_parameter.cpp b/src/data_cell/bucket_datacell_parameter.cpp
new file mode 100644
index 00000000..3a9659ce
--- /dev/null
+++ b/src/data_cell/bucket_datacell_parameter.cpp
@@ -0,0 +1,56 @@
+
+// Copyright 2024-present the vsag project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bucket_datacell_parameter.h"
+
+#include <fmt/format-inl.h>
+
+#include "inner_string_params.h"
+
+namespace vsag {
+BucketDataCellParameter::BucketDataCellParameter() = default;
+
+// Fills this parameter set from `json`. The IO_PARAMS_KEY and
+// QUANTIZATION_PARAMS_KEY entries are mandatory (enforced via CHECK_ARGUMENT);
+// BUCKETS_COUNT_KEY is optional and buckets_count_ is left unchanged without it.
+void
+BucketDataCellParameter::FromJson(const JsonType& json) {
+    CHECK_ARGUMENT(json.contains(IO_PARAMS_KEY),
+                   fmt::format("bucket interface parameters must contains {}", IO_PARAMS_KEY));
+    this->io_parameter_ = IOParameter::GetIOParameterByJson(json[IO_PARAMS_KEY]);
+
+    CHECK_ARGUMENT(
+        json.contains(QUANTIZATION_PARAMS_KEY),
+        fmt::format("bucket interface parameters must contains {}", QUANTIZATION_PARAMS_KEY));
+    this->quantizer_parameter_ =
+        QuantizerParameter::GetQuantizerParameterByJson(json[QUANTIZATION_PARAMS_KEY]);
+
+    if (json.contains(BUCKETS_COUNT_KEY)) {
+        this->buckets_count_ = json[BUCKETS_COUNT_KEY];
+    }
+}
+
+// Serializes the parameter set under the same keys that FromJson reads.
+// Dereferences io_parameter_ and quantizer_parameter_, so both must be non-null
+// (FromJson assigns them) before this is called.
+JsonType
+BucketDataCellParameter::ToJson() {
+    JsonType json;
+    json[IO_PARAMS_KEY] = this->io_parameter_->ToJson();
+    json[QUANTIZATION_PARAMS_KEY] = this->quantizer_parameter_->ToJson();
+    json[BUCKETS_COUNT_KEY] = this->buckets_count_;
+    return json;
+}
+}  // namespace vsag
diff --git a/src/data_cell/bucket_datacell_parameter.h b/src/data_cell/bucket_datacell_parameter.h
new file mode 100644
index 00000000..1fad22c2
--- /dev/null
+++ b/src/data_cell/bucket_datacell_parameter.h
@@ -0,0 +1,55 @@
+
+// Copyright 2024-present the vsag project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "io/io_parameter.h"
+#include "parameter.h"
+#include "quantization/quantizer_parameter.h"
+
+namespace vsag {
+
+// Parameters of a bucket data cell: quantizer settings, IO settings and the
+// number of buckets.
+class BucketDataCellParameter : public Parameter {
+public:
+    explicit BucketDataCellParameter();
+
+    // Loads the fields from `json`. IO_PARAMS_KEY and QUANTIZATION_PARAMS_KEY
+    // are required; BUCKETS_COUNT_KEY is optional.
+    void
+    FromJson(const JsonType& json) override;
+
+    // Serializes the fields to JSON under the same keys FromJson reads.
+    JsonType
+    ToJson() override;
+
+public:
+    // Assigned by FromJson from the QUANTIZATION_PARAMS_KEY entry.
+    QuantizerParamPtr quantizer_parameter_{nullptr};
+
+    // Assigned by FromJson from the IO_PARAMS_KEY entry.
+    IOParamPtr io_parameter_{nullptr};
+
+    // Defaults to 1; FromJson overwrites it only when BUCKETS_COUNT_KEY is present.
+    int64_t buckets_count_{1};
+};
+
+using BucketDataCellParamPtr = std::shared_ptr<BucketDataCellParameter>;
+
+}  // namespace vsag
diff --git a/src/data_cell/bucket_datacell_parameter_test.cpp b/src/data_cell/bucket_datacell_parameter_test.cpp
new file mode 100644
index 00000000..7237e0b8
--- /dev/null
+++ b/src/data_cell/bucket_datacell_parameter_test.cpp
@@ -0,0 +1,118 @@
+
+// Copyright 2024-present the vsag project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "bucket_datacell_parameter.h"
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "parameter_test.h"
+
+using namespace vsag;
+
+TEST_CASE("BucketDataCellParameter ToJson Test", "[ut][BucketDataCellParameter]") {
+    std::string param_str = R"(
+    {
+        "io_params": {
+            "type": "memory_io"
+        },
+        "quantization_params": {
+            "type": "sq8"
+        },
+        "buckets_count": 10
+    })";
+    auto param = std::make_shared<BucketDataCellParameter>();
+    auto json = JsonType::parse(param_str);
+    param->FromJson(json);
+    REQUIRE(param->buckets_count_ == 10);
+    ParameterTest::TestToJson(param);
+}
+
+TEST_CASE("BucketDataCellParameter Parse Exception", "[ut][BucketDataCellParameter]") {
+    // Parses `str` and feeds it to FromJson; any failure propagates as an exception.
+    auto check_param = [](const std::string& str) -> BucketDataCellParamPtr {
+        auto param = std::make_shared<BucketDataCellParameter>();
+        auto json = JsonType::parse(str);
+        param->FromJson(json);
+        return param;
+    };
+
+    SECTION("miss io param") {
+        // Well-formed JSON so that the missing io_params key is the only fault.
+        std::string param_str = R"(
+        {
+            "quantization_params": {
+                "type": "sq8"
+            },
+            "buckets_count": 10
+        })";
+        REQUIRE_THROWS(check_param(param_str));
+    }
+
+    SECTION("miss quantization param") {
+        std::string param_str = R"(
+        {
+            "io_params": {
+                "type": "memory_io"
+            },
+            "buckets_count": 10
+        })";
+        REQUIRE_THROWS(check_param(param_str));
+    }
+
+    SECTION("wrong io param type") {
+        // Everything except the io type is valid.
+        std::string param_str = R"(
+        {
+            "io_params": {
+                "type": "wrong_io"
+            },
+            "quantization_params": {
+                "type": "sq8"
+            },
+            "buckets_count": 10
+        })";
+        REQUIRE_THROWS(check_param(param_str));
+    }
+
+    SECTION("wrong quantization param type") {
+        // Valid io_params so that parsing reaches the quantization type.
+        std::string param_str = R"(
+        {
+            "io_params": {
+                "type": "memory_io"
+            },
+            "quantization_params": {
+                "type": "wrong_quantization"
+            },
+            "buckets_count": 10
+        })";
+        REQUIRE_THROWS(check_param(param_str));
+    }
+
+    SECTION("valid on missing buckets_count") {
+        std::string param_str = R"(
+        {
+            "io_params": {
+                "type": "memory_io"
+            },
+            "quantization_params": {
+                "type": "sq8"
+            }
+        })";
+        auto param = check_param(param_str);
+        // buckets_count_ keeps its in-class default when the key is absent.
+        REQUIRE(param->buckets_count_ == 1);
+    }
+}
diff --git a/src/data_cell/flatten_datacell_parameter.cpp b/src/data_cell/flatten_datacell_parameter.cpp
index e7456d8b..abce63ec 100644
--- a/src/data_cell/flatten_datacell_parameter.cpp
+++ b/src/data_cell/flatten_datacell_parameter.cpp
@@ -20,8 +20,7 @@
 #include "inner_string_params.h"
 
 namespace vsag {
-FlattenDataCellParameter::FlattenDataCellParameter() {
-}
+FlattenDataCellParameter::FlattenDataCellParameter() = default;
 
 void
 FlattenDataCellParameter::FromJson(const JsonType& json) {
diff --git a/src/inner_string_params.h b/src/inner_string_params.h
index feda3410..ae9d0fd0 100644
--- a/src/inner_string_params.h
+++ b/src/inner_string_params.h
@@ -55,6 +55,8 @@ const char* const BUILD_PARAMS_KEY = "build_params";
 const char* const BUILD_THREAD_COUNT = "build_thread_count";
 const char* const BUILD_EF_CONSTRUCTION = "ef_construction";
 
+const char* const BUCKETS_COUNT_KEY = "buckets_count";
+
 const std::unordered_map<std::string, std::string> DEFAULT_MAP = {
     {"INDEX_TYPE_HGRAPH", INDEX_TYPE_HGRAPH},
     {"HGRAPH_USE_REORDER_KEY", HGRAPH_USE_REORDER_KEY},
@@ -76,6 +78,7 @@ const std::unordered_map<std::string, std::string> DEFAULT_MAP = {
     {"BUILD_PARAMS_KEY", BUILD_PARAMS_KEY},
     {"BUILD_THREAD_COUNT", BUILD_THREAD_COUNT},
     {"BUILD_EF_CONSTRUCTION", BUILD_EF_CONSTRUCTION},
+    {"BUCKETS_COUNT_KEY", BUCKETS_COUNT_KEY},
 };
 
 }  // namespace vsag