Skip to content

Commit

Permalink
Use truncated-bound training (TRUNC_BOUND) as the default mode for scalar quantization (SQ)
Browse files Browse the repository at this point in the history
Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
  • Loading branch information
LHT129 committed Feb 11, 2025
1 parent 20b156e commit fe4145c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ ScalarQuantizationTrainer::TrainUniform(const float* data,
std::vector<float> lower(dim_);
if (mode == CLASSIC) {
this->classic_train(sample_datas.data(), sample_count, upper.data(), lower.data());
upper_bound = *std::max_element(upper.begin(), upper.end());
lower_bound = *std::min_element(lower.begin(), lower.end());
} else if (mode == TRUNC_BOUND) {
this->trunc_bound_train(sample_datas.data(), sample_count, upper.data(), lower.data());
}
upper_bound = *std::max_element(upper.begin(), upper.end());
lower_bound = *std::min_element(lower.begin(), lower.end());
}

void
Expand All @@ -81,7 +83,7 @@ ScalarQuantizationTrainer::trunc_bound_train(const float* data,
uint64_t count,
float* upper_bound,
float* lower_bound) const {
auto ignore_count = static_cast<uint64_t>(static_cast<float>(count - 1) * 0.001);
auto ignore_count = static_cast<uint64_t>(static_cast<float>(count - 1) * 0.0001);
for (uint64_t i = 0; i < dim_; ++i) {
std::priority_queue<float, std::vector<float>, std::greater<>> heap_max;
std::priority_queue<float, std::vector<float>, std::less<>> heap_min;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,15 @@ class ScalarQuantizationTrainer {
float* upper_bound,
float* lower_bound,
bool need_normalize = false,
SQTrainMode mode = SQTrainMode::CLASSIC);
SQTrainMode mode = SQTrainMode::TRUNC_BOUND);

void
TrainUniform(const float* data,
uint64_t count,
float& upper_bound,
float& lower_bound,
bool need_normalize = false,
SQTrainMode mode = SQTrainMode::CLASSIC);

void
Encode(const float* origin_data, uint8_t*);
SQTrainMode mode = SQTrainMode::TRUNC_BOUND);

inline void
SetSampleCount(uint64_t sample) {
Expand Down Expand Up @@ -78,7 +75,7 @@ class ScalarQuantizationTrainer {

uint64_t max_sample_count_{MAX_DEFAULT_SAMPLE};

const static uint64_t MAX_DEFAULT_SAMPLE{65536};
constexpr static uint64_t MAX_DEFAULT_SAMPLE{65536};
};

} // namespace vsag

0 comments on commit fe4145c

Please sign in to comment.