Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-130] support decimal round and abs #166

Merged
merged 2 commits into from
Mar 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class ColumnarAdd(left: Expression, right: Expression, original: Expression)
with Logging {

// If casting between DecimalType, unnecessary cast is skipped to avoid data loss,
// because res type of "cast" is actually the res type of "add/subtract".
// because the result type of "cast" is actually the result type of "add/subtract",
// and is the wider type in "multiply/divide".
val left_val: Any = left match {
case c: ColumnarCast =>
if (c.child.dataType.isInstanceOf[DecimalType] &&
Expand Down Expand Up @@ -162,18 +163,41 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
with ColumnarExpression
with Logging {

// Left operand used for code generation: when `left` is a ColumnarCast whose
// source and target types are both DecimalType, the cast is bypassed and its
// child is used directly; every other expression is kept unchanged.
val left_val: Any = left match {
  case cast: ColumnarCast
      if cast.child.dataType.isInstanceOf[DecimalType] &&
        cast.dataType.isInstanceOf[DecimalType] =>
    cast.child
  case _ =>
    left
}
// Right operand used for code generation: when `right` is a ColumnarCast whose
// source and target types are both DecimalType, the cast is bypassed and its
// child is used directly; every other expression is kept unchanged.
val right_val: Any = right match {
  case cast: ColumnarCast
      if cast.child.dataType.isInstanceOf[DecimalType] &&
        cast.dataType.isInstanceOf[DecimalType] =>
    cast.child
  case _ =>
    right
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
left_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
right_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
var resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, l, r)
// Scaling down the unnecessary scale for Literal to avoid precision loss
val newLeftNode = left match {
val newLeftNode = left_val match {
case literal: ColumnarLiteral =>
val leftStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newLeftPrecision = leftStr.length - 1
Expand All @@ -187,7 +211,7 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
case _ =>
left_node
}
val newRightNode = right match {
val newRightNode = right_val match {
case literal: ColumnarLiteral =>
val rightStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newRightPrecision = rightStr.length - 1
Expand Down Expand Up @@ -230,11 +254,33 @@ class ColumnarDivide(left: Expression, right: Expression,
with ColumnarExpression
with Logging {

// Left operand used for code generation: a decimal-to-decimal ColumnarCast
// wrapped around `left` is dropped in favour of its child; any other
// expression is passed through as-is.
val left_val: Any = left match {
  case cast: ColumnarCast
      if cast.child.dataType.isInstanceOf[DecimalType] &&
        cast.dataType.isInstanceOf[DecimalType] =>
    cast.child
  case _ =>
    left
}
// Right operand used for code generation: a decimal-to-decimal ColumnarCast
// wrapped around `right` is dropped in favour of its child; any other
// expression is passed through as-is.
val right_val: Any = right match {
  case cast: ColumnarCast
      if cast.child.dataType.isInstanceOf[DecimalType] &&
        cast.dataType.isInstanceOf[DecimalType] =>
    cast.child
  case _ =>
    right
}
override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
var (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
left_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
right_val.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
Expand All @@ -244,7 +290,7 @@ class ColumnarDivide(left: Expression, right: Expression,
DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.DIVIDE, l, r)
}
val newLeftNode = left match {
val newLeftNode = left_val match {
case literal: ColumnarLiteral =>
val leftStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newLeftPrecision = leftStr.length - 1
Expand All @@ -258,7 +304,7 @@ class ColumnarDivide(left: Expression, right: Expression,
case _ =>
left_node
}
val newRightNode = right match {
val newRightNode = right_val match {
case literal: ColumnarLiteral =>
val rightStr = literal.value.asInstanceOf[Decimal].toDouble.toString
val newRightPrecision = rightStr.length - 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ class ColumnarRound(child: Expression, scale: Expression, original: Expression)
buildCheck()

def buildCheck(): Unit = {
if (child.dataType != DoubleType) {
val supportedTypes = List(FloatType, DoubleType, IntegerType, LongType)
if (supportedTypes.indexOf(child.dataType) == -1 &&
!child.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${child.dataType} is not supported in ColumnarRound")
}
Expand All @@ -54,7 +56,7 @@ class ColumnarRound(child: Expression, scale: Expression, original: Expression)
val (scale_node, scaleType): (TreeNode, ArrowType) =
scale.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
val resultType = CodeGeneration.getResultType(dataType)
val funcNode = TreeBuilder.makeFunction("round",
Lists.newArrayList(child_node, scale_node), resultType)
(funcNode, resultType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,9 @@ class ColumnarAbs(child: Expression, original: Expression)
buildCheck()

def buildCheck(): Unit = {
val supportedTypes = List(FloatType, DoubleType)
if (supportedTypes.indexOf(child.dataType) == -1) {
val supportedTypes = List(FloatType, DoubleType, IntegerType, LongType)
if (supportedTypes.indexOf(child.dataType) == -1 &&
!child.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${child.dataType} is not supported in ColumnarAbs")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,15 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
codes_str_ = func_name + "_" + std::to_string(cur_func_id);
auto validity = codes_str_ + "_validity";
std::stringstream fix_ss;
if (node.return_type()->id() != arrow::Type::DECIMAL) {
fix_ss << "round2(" << child_visitor_list[0]->GetResult();
} else {
auto childNode = node.children().at(0);
auto childType =
std::dynamic_pointer_cast<arrow::Decimal128Type>(childNode->return_type());
fix_ss << "round(" << child_visitor_list[0]->GetResult() << ", "
<< childType->precision() << ", " << childType->scale() << ", &overflow";
}
if (child_visitor_list.size() > 1) {
fix_ss << ", " << child_visitor_list[1]->GetResult();
}
Expand All @@ -514,8 +523,13 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
prepare_ss << "bool " << validity << " = " << child_visitor_list[0]->GetPreCheck()
<< ";" << std::endl;
prepare_ss << "if (" << validity << ") {" << std::endl;
prepare_ss << codes_str_ << " = round2(" << child_visitor_list[0]->GetResult()
<< fix_ss.str() << ");" << std::endl;
if (node.return_type()->id() == arrow::Type::DECIMAL) {
prepare_ss << "bool overflow = false;" << std::endl;
}
prepare_ss << codes_str_ << " = " << fix_ss.str() << ");" << std::endl;
if (node.return_type()->id() == arrow::Type::DECIMAL) {
prepare_ss << "if (overflow) {\n" << validity << " = false;}" << std::endl;
}
prepare_ss << "}" << std::endl;

prepare_str_ += prepare_ss.str();
Expand All @@ -524,14 +538,19 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
} else if (func_name.compare("abs") == 0) {
codes_str_ = "abs_" + std::to_string(cur_func_id);
auto validity = codes_str_ + "_validity";
std::stringstream fix_ss;
if (node.return_type()->id() != arrow::Type::DECIMAL) {
fix_ss << "abs(" << child_visitor_list[0]->GetResult() << ")";
} else {
fix_ss << child_visitor_list[0]->GetResult() << ".Abs()";
}
std::stringstream prepare_ss;
prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";"
<< std::endl;
prepare_ss << "bool " << validity << " = " << child_visitor_list[0]->GetPreCheck()
<< ";" << std::endl;
prepare_ss << "if (" << validity << ") {" << std::endl;
prepare_ss << codes_str_ << " = abs(" << child_visitor_list[0]->GetResult() << ");"
<< std::endl;
prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl;
prepare_ss << "}" << std::endl;

for (int i = 0; i < 1; i++) {
Expand Down
15 changes: 15 additions & 0 deletions native-sql-engine/cpp/src/precompile/gandiva.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,18 @@ bool equal_with_nan(double left, double right) {
}
return left == right;
}

/// Rounds a Decimal128 value by delegating to gandiva::decimalops::Round.
///
/// \param in                  value to round
/// \param original_precision  precision of `in`; also passed to Round as the
///                            precision argument
/// \param original_scale      scale of `in`
/// \param overflow_           set to true when Round reports an overflow; it is
///                            never reset to false here, so a flag raised by the
///                            caller beforehand is preserved
/// \param res_scale           scale handed to Round (used for both of its scale
///                            arguments); defaults to 2
/// \return the value produced by Round, wrapped as arrow::Decimal128
arrow::Decimal128 round(arrow::Decimal128 in,
                        int32_t original_precision,
                        int32_t original_scale,
                        bool* overflow_,
                        int32_t res_scale = 2) {
  // Round() reports into a local flag so that the caller's flag is only ever raised.
  bool round_overflowed = false;
  gandiva::BasicDecimalScalar128 scalar(in, original_precision, original_scale);
  auto rounded = gandiva::decimalops::Round(scalar, original_precision, res_scale,
                                            res_scale, &round_overflowed);
  if (round_overflowed) {
    *overflow_ = true;
  }
  return arrow::Decimal128(rounded);
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ TEST(TestArrowCompute, ArithmeticDecimalTest) {
res = divide(left, left_precision, left_scale, right, right_precision, right_scale,
out_precision, out_scale, &overflow);
ASSERT_EQ(res, arrow::Decimal128("13780.2495094037"));
res = round(left, left_precision, left_scale, &overflow, 4);
ASSERT_EQ(res, arrow::Decimal128("32342423.0129"));
res = arrow::Decimal128("-32342423.012875").Abs();
ASSERT_EQ(res, left);
}

TEST(TestArrowCompute, ArithmeticComparisonTest) {
Expand Down