Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-130] fix overflow and precision loss (#156)
Browse files Browse the repository at this point in the history
* fix overflow and remove duplicate cast

* scale down literal
  • Loading branch information
rui-mo authored Mar 14, 2021
1 parent 651fab6 commit 3d4225d
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,36 @@ class ColumnarAdd(left: Expression, right: Expression, original: Expression)
extends Add(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

// When an operand is a cast between two DecimalTypes, the unnecessary cast is skipped to
// avoid data loss, because the result type of that "cast" is actually the result type of the "add/subtract".
val left_val: Any = left match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
c.dataType.isInstanceOf[DecimalType]) {
c.child
} else {
left
}
case _ =>
left
}
val right_val: Any = right match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
c.dataType.isInstanceOf[DecimalType]) {
c.child
} else {
right
}
case _ =>
right
}
override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
left_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
right_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
Expand Down Expand Up @@ -76,11 +101,34 @@ class ColumnarSubtract(left: Expression, right: Expression, original: Expression
extends Subtract(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

val left_val: Any = left match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
c.dataType.isInstanceOf[DecimalType]) {
c.child
} else {
left
}
case _ =>
left
}
val right_val: Any = right match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
c.dataType.isInstanceOf[DecimalType]) {
c.child
} else {
right
}
case _ =>
right
}
override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
left_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
right_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
Expand Down Expand Up @@ -113,6 +161,7 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
extends Multiply(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand All @@ -121,10 +170,39 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression

(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
var resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, l, r)
// Scale down the unnecessarily large scale of a Literal to avoid precision loss.
val newLeftNode = left match {
case literal: ColumnarLiteral =>
val leftStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newLeftPrecision = leftStr.length - 1
val newLeftScale = leftStr.split('.')(1).length
val newLeftType =
new ArrowType.Decimal(newLeftPrecision, newLeftScale, 128)
resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, newLeftType, r)
TreeBuilder.makeFunction(
"castDECIMAL", Lists.newArrayList(left_node), newLeftType)
case _ =>
left_node
}
val newRightNode = right match {
case literal: ColumnarLiteral =>
val rightStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newRightPrecision = rightStr.length - 1
val newRightScale = rightStr.split('.')(1).length
val newRightType =
new ArrowType.Decimal(newRightPrecision, newRightScale, 128)
resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, l, newRightType)
TreeBuilder.makeFunction(
"castDECIMAL", Lists.newArrayList(right_node), newRightType)
case _ =>
right_node
}
val mulNode = TreeBuilder.makeFunction(
"multiply", Lists.newArrayList(left_node, right_node), resultType)
"multiply", Lists.newArrayList(newLeftNode, newRightNode), resultType)
(mulNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
Expand All @@ -150,6 +228,7 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression)
extends Divide(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
Expand All @@ -158,10 +237,38 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression)

(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
var resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.DIVIDE, l, r)
val newLeftNode = left match {
case literal: ColumnarLiteral =>
val leftStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newLeftPrecision = leftStr.length - 1
val newLeftScale = leftStr.split('.')(1).length
val newLeftType =
new ArrowType.Decimal(newLeftPrecision, newLeftScale, 128)
resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.DIVIDE, newLeftType, r)
TreeBuilder.makeFunction(
"castDECIMAL", Lists.newArrayList(left_node), newLeftType)
case _ =>
left_node
}
val newRightNode = right match {
case literal: ColumnarLiteral =>
val rightStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newRightPrecision = rightStr.length - 1
val newRightScale = rightStr.split('.')(1).length
val newRightType =
new ArrowType.Decimal(newRightPrecision, newRightScale, 128)
resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.DIVIDE, l, newRightType)
TreeBuilder.makeFunction(
"castDECIMAL", Lists.newArrayList(right_node), newRightType)
case _ =>
right_node
}
val divNode = TreeBuilder.makeFunction(
"divide", Lists.newArrayList(left_node, right_node), resultType)
"divide", Lists.newArrayList(newLeftNode, newRightNode), resultType)
(divNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,21 +273,21 @@ class ColumnarCheckOverflow(child: Expression, original: CheckOverflow)
override def doColumnarCodeGen(args: Object): (TreeNode, ArrowType) = {
val (child_node, childType): (TreeNode, ArrowType) =
child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
// Since Spark calls toPrecision in CheckOverflow and rescales from zero, we need to re-calculate the result dataType here.
val childScale: Int = childType match {
case d: ArrowType.Decimal => d.getScale
case _ => 0
}
val newDataType =
DecimalType(dataType.precision, dataType.scale)
val resType = CodeGeneration.getResultType(newDataType)
var function = "castDECIMAL"
if (nullOnOverflow) {
function = "castDECIMALNullOnOverflow"
if (resType == childType) {
// If target type is the same as childType, cast is not needed
(child_node, childType)
} else {
var function = "castDECIMAL"
if (nullOnOverflow) {
function = "castDECIMALNullOnOverflow"
}
val funcNode =
TreeBuilder.makeFunction(function, Lists.newArrayList(child_node), resType)
(funcNode, resType)
}
val funcNode =
TreeBuilder.makeFunction(function, Lists.newArrayList(child_node), resType)
(funcNode, resType)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
<< child_visitor_list[1]->GetResult() << ", " << rightType->precision()
<< ", " << rightType->scale() << ", " << resType->precision() << ", "
<< resType->scale() << ")";
header_list_.push_back(R"(#include "precompile/gandiva.h")");
}
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
Expand Down Expand Up @@ -596,6 +597,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
<< child_visitor_list[1]->GetResult() << ", " << rightType->precision()
<< ", " << rightType->scale() << ", " << resType->precision() << ", "
<< resType->scale() << ")";
header_list_.push_back(R"(#include "precompile/gandiva.h")");
}
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
Expand Down Expand Up @@ -632,7 +634,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
<< leftType->precision() << ", " << leftType->scale() << ", "
<< child_visitor_list[1]->GetResult() << ", " << rightType->precision()
<< ", " << rightType->scale() << ", " << resType->precision() << ", "
<< resType->scale() << ")";
<< resType->scale() << ", &overflow)";
header_list_.push_back(R"(#include "precompile/gandiva.h")");
}
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
Expand All @@ -642,7 +645,13 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
child_visitor_list[1]->GetPreCheck()})
<< ");" << std::endl;
prepare_ss << "if (" << validity << ") {" << std::endl;
if (node.return_type()->id() == arrow::Type::DECIMAL) {
prepare_ss << "bool overflow = false;" << std::endl;
}
prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl;
if (node.return_type()->id() == arrow::Type::DECIMAL) {
prepare_ss << "if (overflow) {\n" << validity << " = false;}" << std::endl;
}
prepare_ss << "}" << std::endl;

for (int i = 0; i < 2; i++) {
Expand All @@ -669,7 +678,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
<< leftType->precision() << ", " << leftType->scale() << ", "
<< child_visitor_list[1]->GetResult() << ", " << rightType->precision()
<< ", " << rightType->scale() << ", " << resType->precision() << ", "
<< resType->scale() << ")";
<< resType->scale() << ", &overflow)";
header_list_.push_back(R"(#include "precompile/gandiva.h")");
}
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
Expand All @@ -679,7 +689,13 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
child_visitor_list[1]->GetPreCheck()})
<< ");" << std::endl;
prepare_ss << "if (" << validity << ") {" << std::endl;
if (node.return_type()->id() == arrow::Type::DECIMAL) {
prepare_ss << "bool overflow = false;" << std::endl;
}
prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl;
if (node.return_type()->id() == arrow::Type::DECIMAL) {
prepare_ss << "if (overflow) {\n" << validity << " = false;}" << std::endl;
}
prepare_ss << "}" << std::endl;

for (int i = 0; i < 2; i++) {
Expand Down
10 changes: 6 additions & 4 deletions native-sql-engine/cpp/src/precompile/gandiva.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,29 +95,31 @@ arrow::Decimal128 subtract(arrow::Decimal128 left, int32_t left_precision,
arrow::Decimal128 multiply(arrow::Decimal128 left, int32_t left_precision,
int32_t left_scale, arrow::Decimal128 right,
int32_t right_precision, int32_t right_scale,
int32_t out_precision, int32_t out_scale) {
int32_t out_precision, int32_t out_scale,
bool* overflow_) {
gandiva::BasicDecimalScalar128 x(left, left_precision, left_scale);
gandiva::BasicDecimalScalar128 y(right, right_precision, right_scale);
bool overflow = false;
arrow::BasicDecimal128 out =
gandiva::decimalops::Multiply(x, y, out_precision, out_scale, &overflow);
if (overflow) {
throw std::overflow_error("Decimal multiply overflowed!");
*overflow_ = true;
}
return arrow::Decimal128(out);
}

arrow::Decimal128 divide(arrow::Decimal128 left, int32_t left_precision,
int32_t left_scale, arrow::Decimal128 right,
int32_t right_precision, int32_t right_scale,
int32_t out_precision, int32_t out_scale) {
int32_t out_precision, int32_t out_scale,
bool* overflow_) {
gandiva::BasicDecimalScalar128 x(left, left_precision, left_scale);
gandiva::BasicDecimalScalar128 y(right, right_precision, right_scale);
bool overflow = false;
arrow::BasicDecimal128 out =
gandiva::decimalops::Divide(0, x, y, out_precision, out_scale, &overflow);
if (overflow) {
throw std::overflow_error("Decimal divide overflowed!");
*overflow_ = true;
}
return arrow::Decimal128(out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,21 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) {
int32_t out_scale = 10;
auto res = castDECIMAL(left, left_precision, left_scale, out_precision, out_scale);
ASSERT_EQ(res, arrow::Decimal128("32342423.0128750000"));
bool overflow = false;
res = castDECIMALNullOnOverflow(left, left_precision, left_scale, out_precision,
out_scale, &overflow);
ASSERT_EQ(res, arrow::Decimal128("32342423.0128750000"));
res = add(left, left_precision, left_scale, right, right_precision, right_scale,
17, 9);
ASSERT_EQ(res, arrow::Decimal128("32344770.025749535"));
res = subtract(left, left_precision, left_scale, right, right_precision, right_scale,
17, 9);
ASSERT_EQ(res, arrow::Decimal128("32340076.000000465"));
res = multiply(left, left_precision, left_scale, right, right_precision, right_scale,
28, 15);
28, 15, &overflow);
ASSERT_EQ(res, arrow::Decimal128("75908083204.874689064638125"));
res = divide(left, left_precision, left_scale, right, right_precision, right_scale,
out_precision, out_scale);
out_precision, out_scale, &overflow);
ASSERT_EQ(res, arrow::Decimal128("13780.2495094037"));
}

Expand Down

0 comments on commit 3d4225d

Please sign in to comment.