From 5e2ca15e331b3f3c5e75a09371d199174d3a7c67 Mon Sep 17 00:00:00 2001
From: Rui Mo
Date: Mon, 15 Mar 2021 05:55:33 +0000
Subject: [PATCH] fix precision loss in divide

When a Divide is the direct child of a CheckOverflow, the expression
converter now builds the columnar divide through
ColumnarBinaryArithmetic.createDivide and passes it the CheckOverflow's
DecimalType. ColumnarDivide uses that type (via
CodeGeneration.getResultType) as the decimal result type. When no type is
supplied (resType == null, the default), the result type is still derived
with DecimalTypeUtil.getResultTypeForOperation, as before.
---
 .../oap/expression/ColumnarArithmetic.scala   | 22 +++++++++++--
 .../ColumnarExpressionConverter.scala         | 32 +++++++++++++++----
 2 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala
index 0fbd8b297..510cb2efe 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala
@@ -224,7 +224,8 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
   }
 }
 
-class ColumnarDivide(left: Expression, right: Expression, original: Expression)
+class ColumnarDivide(left: Expression, right: Expression,
+    original: Expression, resType: DecimalType = null)
     extends Divide(left: Expression, right: Expression)
     with ColumnarExpression
     with Logging {
@@ -237,8 +238,12 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression)
 
     (left_type, right_type) match {
       case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
-        var resultType = DecimalTypeUtil.getResultTypeForOperation(
-          DecimalTypeUtil.OperationType.DIVIDE, l, r)
+        var resultType = if (resType != null) {
+          CodeGeneration.getResultType(resType)
+        } else {
+          DecimalTypeUtil.getResultTypeForOperation(
+            DecimalTypeUtil.OperationType.DIVIDE, l, r)
+        }
         val newLeftNode = left match {
           case literal: ColumnarLiteral =>
             val leftStr = literal.value.asInstanceOf[Decimal].toDouble.toString
@@ -374,6 +379,17 @@ object ColumnarBinaryArithmetic {
     }
   }
 
+  def createDivide(left: Expression, right: Expression,
+      original: Expression, resType: DecimalType): Expression = {
+    buildCheck(left, right)
+    original match {
+      case d: Divide =>
+        new ColumnarDivide(left, right, d, resType)
+      case other =>
+        throw new UnsupportedOperationException(s"not currently supported: $other.")
+    }
+  }
+
   def buildCheck(left: Expression, right: Expression): Unit = {
     try {
       ConverterUtils.checkIfTypeSupported(left.dataType)
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala
index 053d836a6..20896bee6 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala
@@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.types.DecimalType
 
 object ColumnarExpressionConverter extends Logging {
   var check_if_no_calculation = true
@@ -245,12 +246,31 @@ object ColumnarExpressionConverter extends Logging {
           expr)
       case u: UnaryExpression =>
         logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
-        ColumnarUnaryOperator.create(
-          replaceWithColumnarExpression(
-            u.child,
-            attributeSeq,
-            convertBoundRefToAttrRef = convertBoundRefToAttrRef),
-          expr)
+        if (!u.isInstanceOf[CheckOverflow] || !u.child.isInstanceOf[Divide]) {
+          ColumnarUnaryOperator.create(
+            replaceWithColumnarExpression(
+              u.child,
+              attributeSeq,
+              convertBoundRefToAttrRef = convertBoundRefToAttrRef),
+            expr)
+        } else {
+          // CheckOverflow[Divide]: pass resType to Divide to avoid precision loss
+          val divide = u.child.asInstanceOf[Divide]
+          val columnarDivide = ColumnarBinaryArithmetic.createDivide(
+            replaceWithColumnarExpression(
+              divide.left,
+              attributeSeq,
+              convertBoundRefToAttrRef = convertBoundRefToAttrRef),
+            replaceWithColumnarExpression(
+              divide.right,
+              attributeSeq,
+              convertBoundRefToAttrRef = convertBoundRefToAttrRef),
+            divide,
+            u.dataType.asInstanceOf[DecimalType])
+          ColumnarUnaryOperator.create(
+            columnarDivide,
+            expr)
+        }
       case oaps: com.intel.oap.expression.ColumnarScalarSubquery =>
         oaps
       case s: org.apache.spark.sql.execution.ScalarSubquery =>