From fe4145c9a84e32baf60af8291bd7dadd461a6e61 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Tue, 21 Jan 2025 12:09:20 +0000 Subject: [PATCH] use trunc train for sq as default Signed-off-by: LHT129 --- .../scalar_quantization/scalar_quantization_trainer.cpp | 8 +++++--- .../scalar_quantization/scalar_quantization_trainer.h | 9 +++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/quantization/scalar_quantization/scalar_quantization_trainer.cpp b/src/quantization/scalar_quantization/scalar_quantization_trainer.cpp index fd41a84f..4b569bb6 100644 --- a/src/quantization/scalar_quantization/scalar_quantization_trainer.cpp +++ b/src/quantization/scalar_quantization/scalar_quantization_trainer.cpp @@ -55,9 +55,11 @@ ScalarQuantizationTrainer::TrainUniform(const float* data, std::vector lower(dim_); if (mode == CLASSIC) { this->classic_train(sample_datas.data(), sample_count, upper.data(), lower.data()); - upper_bound = *std::max_element(upper.begin(), upper.end()); - lower_bound = *std::min_element(lower.begin(), lower.end()); + } else if (mode == TRUNC_BOUND) { + this->trunc_bound_train(sample_datas.data(), sample_count, upper.data(), lower.data()); } + upper_bound = *std::max_element(upper.begin(), upper.end()); + lower_bound = *std::min_element(lower.begin(), lower.end()); } void @@ -81,7 +83,7 @@ ScalarQuantizationTrainer::trunc_bound_train(const float* data, uint64_t count, float* upper_bound, float* lower_bound) const { - auto ignore_count = static_cast(static_cast(count - 1) * 0.001); + auto ignore_count = static_cast(static_cast(count - 1) * 0.0001); for (uint64_t i = 0; i < dim_; ++i) { std::priority_queue, std::greater<>> heap_max; std::priority_queue, std::less<>> heap_min; diff --git a/src/quantization/scalar_quantization/scalar_quantization_trainer.h b/src/quantization/scalar_quantization/scalar_quantization_trainer.h index da95cecb..c1a0b561 100644 --- a/src/quantization/scalar_quantization/scalar_quantization_trainer.h +++ b/src/quantization/scalar_quantization/scalar_quantization_trainer.h @@ -37,7 +37,7 @@ class ScalarQuantizationTrainer { float* upper_bound, float* lower_bound, bool need_normalize = false, - SQTrainMode mode = SQTrainMode::CLASSIC); + SQTrainMode mode = SQTrainMode::TRUNC_BOUND); void TrainUniform(const float* data, @@ -45,10 +45,7 @@ class ScalarQuantizationTrainer { float& upper_bound, float& lower_bound, bool need_normalize = false, - SQTrainMode mode = SQTrainMode::CLASSIC); - - void - Encode(const float* origin_data, uint8_t*); + SQTrainMode mode = SQTrainMode::TRUNC_BOUND); inline void SetSampleCount(uint64_t sample) { @@ -78,7 +75,7 @@ class ScalarQuantizationTrainer { uint64_t max_sample_count_{MAX_DEFAULT_SAMPLE}; - const static uint64_t MAX_DEFAULT_SAMPLE{65536}; + constexpr static uint64_t MAX_DEFAULT_SAMPLE{65536}; }; } // namespace vsag