Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-262] fix remainder loss in decimal divide #263

Merged
merged 3 commits into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2559,6 +2559,7 @@ class AvgByCountAction<DataType, CType, ResDataType, ResCType,
auto typed_type = std::dynamic_pointer_cast<arrow::Decimal128Type>(type);
auto typed_res_type = std::dynamic_pointer_cast<arrow::Decimal128Type>(res_type);
scale_ = typed_type->scale();
res_precision_ = typed_type->precision();
res_scale_ = typed_res_type->scale();
std::unique_ptr<arrow::ArrayBuilder> builder;
arrow::MakeBuilder(ctx_->memory_pool(), res_type, &builder);
Expand Down Expand Up @@ -2660,10 +2661,11 @@ class AvgByCountAction<DataType, CType, ResDataType, ResCType,
cache_sum_[i] = 0;
} else {
cache_validity_[i] = true;
if (res_scale_ > scale_) {
if (res_scale_ != scale_) {
cache_sum_[i] = cache_sum_[i].Rescale(scale_, res_scale_).ValueOrDie();
}
cache_sum_[i] /= cache_count_[i];
cache_sum_[i] =
divide(cache_sum_[i], res_precision_, res_scale_, cache_count_[i]);
}
}
cache_sum_.resize(length_);
Expand Down Expand Up @@ -2691,11 +2693,12 @@ class AvgByCountAction<DataType, CType, ResDataType, ResCType,
cache_sum_[i + offset] = 0;
} else {
cache_validity_[i + offset] = true;
if (res_scale_ > scale_) {
if (res_scale_ != scale_) {
cache_sum_[i + offset] =
cache_sum_[i + offset].Rescale(scale_, res_scale_).ValueOrDie();
}
cache_sum_[i + offset] /= cache_count_[i + offset];
cache_sum_[i + offset] = divide(cache_sum_[i + offset], res_precision_,
res_scale_, cache_count_[i + offset]);
}
}
for (uint64_t i = 0; i < res_length; i++) {
Expand Down Expand Up @@ -2724,6 +2727,7 @@ class AvgByCountAction<DataType, CType, ResDataType, ResCType,
int in_null_count_ = 0;
// result
int scale_;
int res_precision_;
int res_scale_;
std::vector<ResCType> cache_sum_;
std::vector<int64_t> cache_count_;
Expand Down
7 changes: 7 additions & 0 deletions native-sql-engine/cpp/src/precompile/gandiva.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ arrow::Decimal128 divide(arrow::Decimal128 left, int32_t left_precision,
return arrow::Decimal128(out);
}

// Divides the decimal `x` by the integer `y` and returns the quotient as an
// arrow::Decimal128.
//
// `x`, `precision` and `scale` are wrapped into a
// gandiva::BasicDecimalScalar128 and handed to
// gandiva::decimalops::Divide(val, y); the value it returns is converted back
// to arrow::Decimal128 unchanged, so no rescaling is performed here.
//
// Unlike the plain `x / y` operator, which drops the remainder, this goes
// through the decimalops helper so the remainder can be taken into account
// (e.g. the accompanying unit test expects 30.222215 / 8 to be 3.777777
// rather than 3.777776).
arrow::Decimal128 divide(arrow::Decimal128 x, int32_t precision, int32_t scale,
                         int64_t y) {
  gandiva::BasicDecimalScalar128 val(x, precision, scale);
  arrow::BasicDecimal128 out = gandiva::decimalops::Divide(val, y);
  return arrow::Decimal128(out);
}

// A comparison with a NaN always returns false even when comparing with itself.
// To get the same result as spark, we can regard NaN as big as Infinity when
// doing comparison.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ TEST(TestArrowCompute, AggregateTest) {
"[39]",
R"(["345.262397"])",
"[785]",
R"(["0.439824"])",
R"(["0.439825"])",
R"([39])",
R"([8.85288])",
R"([11113.3])"};
Expand Down Expand Up @@ -284,8 +284,8 @@ TEST(TestArrowCompute, GroupByAggregateTest) {
R"(["15.704202", "12.050089", "19.776600", "15.878089", "24.840018",
null, "28.101000", "22.136100", "16.008800", "26.676800", "164.090699"])",
R"([140, 20, 11, 89, 131, null, 57, 27, 10, 89, 211])",
R"(["0.1121728714", "0.6025044500", "1.7978727272", "0.1784054943", "0.1896184580",
null, "0.4930000000", "0.8198555555", "1.6008800000", "0.2997393258", "0.7776810379"])"};
R"(["0.1121728714", "0.6025044500", "1.7978727273", "0.1784054944", "0.1896184580",
null, "0.4930000000", "0.8198555556", "1.6008800000", "0.2997393258", "0.7776810379"])"};
auto res_sch = arrow::schema(ret_types);
MakeInputBatch(expected_result_string, res_sch, &expected_result);
if (aggr_result_iterator->HasNext()) {
Expand Down Expand Up @@ -425,8 +425,8 @@ TEST(TestArrowCompute, GroupByAggregateWSCGTest) {
R"(["15.704202", "12.050089", "19.776600", "15.878089", "24.840018",
null, "28.101000", "22.136100", "16.008800", "26.676800", "164.090699"])",
R"([140, 20, 11, 89, 131, null, 57, 27, 10, 89, 211])",
R"(["0.1121728714", "0.6025044500", "1.7978727272", "0.1784054943", "0.1896184580",
null, "0.4930000000", "0.8198555555", "1.6008800000", "0.2997393258", "0.7776810379"])"};
R"(["0.1121728714", "0.6025044500", "1.7978727273", "0.1784054944", "0.1896184580",
null, "0.4930000000", "0.8198555556", "1.6008800000", "0.2997393258", "0.7776810379"])"};
auto res_sch = arrow::schema(ret_types);
MakeInputBatch(expected_result_string, res_sch, &expected_result);
if (aggr_result_iterator->HasNext()) {
Expand Down
11 changes: 11 additions & 0 deletions native-sql-engine/cpp/src/tests/arrow_compute_test_precompile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) {
ASSERT_EQ(res, arrow::Decimal128("32342423.0129"));
res = arrow::Decimal128("-32342423.012875").Abs();
ASSERT_EQ(res, left);
// decimal divide int test
auto x = arrow::Decimal128("30.222215");
int32_t x_precision = 14;
int32_t x_scale = 6;
int64_t y = 8;
res = x / y;
// wrong result
ASSERT_EQ(res, arrow::Decimal128("3.777776"));
// correct result
res = divide(x, x_precision, x_scale, y);
ASSERT_EQ(res, arrow::Decimal128("3.777777"));
}

TEST(TestArrowCompute, ArithmeticComparisonTest) {
Expand Down
17 changes: 17 additions & 0 deletions native-sql-engine/cpp/src/third_party/gandiva/decimal_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,23 @@ BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
return result;
}

// Divides the unscaled value of 'x' by the integer 'y', rounding the quotient
// one unit away from zero whenever twice the remainder's magnitude is at
// least the magnitude of 'y'. Throws std::runtime_error when 'y' is zero.
// Only x.value() is read; the precision and scale carried by 'x' are not used.
BasicDecimal128 Divide(const BasicDecimalScalar128& x, int64_t y) {
  if (y == 0) {
    throw std::runtime_error("divide by zero error");
  }

  const auto& dividend = x.value();
  BasicDecimal128 quotient;
  BasicDecimal128 rem;
  auto status = dividend.Divide(y, &quotient, &rem);
  DCHECK_EQ(status, arrow::DecimalStatus::kSuccess);

  // Nothing to adjust when the discarded fraction is below one half.
  const bool needs_rounding =
      BasicDecimal128::Abs(2 * rem) >= BasicDecimal128::Abs(y);
  if (!needs_rounding) {
    return quotient;
  }

  // Sign() yields 1 for positive and zero values and -1 for negative values,
  // so the operands share a sign exactly when both are negative or neither is.
  const bool same_sign = (dividend.Sign() < 0) == (y < 0);
  // Step away from zero: +1 for a non-negative quotient, -1 for a negative one.
  const int64_t step = same_sign ? 1 : -1;
  quotient += step;
  return quotient;
}

BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
int32_t out_scale, bool* overflow) {
Expand Down
3 changes: 3 additions & 0 deletions native-sql-engine/cpp/src/third_party/gandiva/decimal_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ arrow::BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
int32_t out_scale, bool* overflow);

/// Divide 'x' (decimal) by 'y' (int64_t), and return the result.
/// The definition added alongside this declaration throws std::runtime_error
/// when 'y' is zero and rounds the quotient away from zero when twice the
/// remainder's magnitude is at least the magnitude of 'y'.
BasicDecimal128 Divide(const BasicDecimalScalar128& x, int64_t y);

/// Divide 'x' by 'y', and return the remainder.
arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
Expand Down