diff --git a/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala b/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala index fbff56e57..d39101abc 100644 --- a/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala +++ b/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala @@ -27,8 +27,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} +import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.util.ExecutorManager import org.apache.spark.sql.util.StructTypeFWD import org.apache.spark.{SparkConf, TaskContext} @@ -70,8 +70,10 @@ case class ColumnarConditionProjectExec( ConverterUtils.checkIfTypeSupported(attr.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${attr.dataType} is not supported in ColumnarConditionProjector.") + if (!attr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${attr.dataType} is not supported in ColumnarConditionProjector.") + } } }) // check expr @@ -80,8 +82,10 @@ case class ColumnarConditionProjectExec( ConverterUtils.checkIfTypeSupported(condExpr.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${condExpr.dataType} is not supported in ColumnarConditionProjector.") + if (!condExpr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${condExpr.dataType} is not supported in ColumnarConditionProjector.") + } } ColumnarExpressionConverter.replaceWithColumnarExpression(condExpr) } @@ -91,8 +95,10 @@ case class ColumnarConditionProjectExec( ConverterUtils.checkIfTypeSupported(expr.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${expr.dataType} is not supported in ColumnarConditionProjector.") + if (!expr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${expr.dataType} is not supported in ColumnarConditionProjector.") + } } ColumnarExpressionConverter.replaceWithColumnarExpression(expr) } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala index e2319c910..1dd5e3c06 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala @@ -18,19 +18,20 @@ package com.intel.oap.expression import com.google.common.collect.Lists - import org.apache.arrow.gandiva.evaluator._ import org.apache.arrow.gandiva.exceptions.GandivaException import org.apache.arrow.gandiva.expression._ +import org.apache.arrow.vector.types.FloatingPointPrecision import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.arrow.vector.types.pojo.Field - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import scala.collection.mutable.ListBuffer +import org.apache.arrow.gandiva.evaluator.DecimalTypeUtil + /** * A version of add that supports columnar processing for longs. */ @@ -44,22 +45,30 @@ class ColumnarAdd(left: Expression, right: Expression, original: Expression) var (right_node, right_type): (TreeNode, ArrowType) = right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) - val resultType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + val resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.ADD, l, r) + val addNode = TreeBuilder.makeFunction( + "add", Lists.newArrayList(left_node, right_node), resultType) + (addNode, resultType) + case _ => + val resultType = CodeGeneration.getResultType(left_type, right_type) + if (!left_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType) + } + if (!right_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType) + } + //logInfo(s"(TreeBuilder.makeFunction(add, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") + val funcNode = TreeBuilder.makeFunction( + "add", Lists.newArrayList(left_node, right_node), resultType) + (funcNode, resultType) } - if (!right_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), - } - - //logInfo(s"(TreeBuilder.makeFunction(add, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") - ( - TreeBuilder.makeFunction("add", Lists.newArrayList(left_node, right_node), resultType), - resultType) } } @@ -73,21 +82,30 @@ class ColumnarSubtract(left: Expression, right: Expression, original: Expression var (right_node, right_type): (TreeNode, ArrowType) = right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) - val resultType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + val resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.SUBTRACT, l, r) + val subNode = TreeBuilder.makeFunction( + "subtract", Lists.newArrayList(left_node, right_node), resultType) + (subNode, resultType) + case _ => + val resultType = CodeGeneration.getResultType(left_type, right_type) + if (!left_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType) + } + if (!right_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType) + } + //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") + val funcNode = TreeBuilder.makeFunction( + "subtract", Lists.newArrayList(left_node, right_node), resultType) + (funcNode, resultType) } - if (!right_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), - } - //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") - ( - TreeBuilder.makeFunction("subtract", Lists.newArrayList(left_node, right_node), resultType), - resultType) } } @@ -101,22 +119,30 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression var (right_node, right_type): (TreeNode, ArrowType) = right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) - val resultType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), - } - if (!right_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + val resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.MULTIPLY, l, r) + val mulNode = TreeBuilder.makeFunction( + "multiply", Lists.newArrayList(left_node, right_node), resultType) + (mulNode, resultType) + case _ => + val resultType = CodeGeneration.getResultType(left_type, right_type) + if (!left_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), + } + if (!right_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), + } + //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") + val funcNode = TreeBuilder.makeFunction( + "multiply", Lists.newArrayList(left_node, right_node), resultType) + (funcNode, resultType) } - - //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") - ( - TreeBuilder.makeFunction("multiply", Lists.newArrayList(left_node, right_node), resultType), - resultType) } } @@ -130,21 +156,30 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression) var (right_node, right_type): (TreeNode, ArrowType) = right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) - val resultType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), - } - if (!right_type.equals(resultType)) { - val func_name = CodeGeneration.getCastFuncName(resultType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + val resultType = DecimalTypeUtil.getResultTypeForOperation( + DecimalTypeUtil.OperationType.DIVIDE, l, r) + val divNode = TreeBuilder.makeFunction( + "divide", Lists.newArrayList(left_node, right_node), resultType) + (divNode, resultType) + case _ => + val resultType = CodeGeneration.getResultType(left_type, right_type) + if (!left_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType), + } + if (!right_type.equals(resultType)) { + val func_name = CodeGeneration.getCastFuncName(resultType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType), + } + //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") + val funcNode = TreeBuilder.makeFunction( + "divide", Lists.newArrayList(left_node, right_node), resultType) + (funcNode, resultType) } - //logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)") - ( - TreeBuilder.makeFunction("divide", Lists.newArrayList(left_node, right_node), resultType), - resultType) } } @@ -238,8 +273,11 @@ object ColumnarBinaryArithmetic { ConverterUtils.checkIfTypeSupported(right.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryArithmetic") + if (!left.dataType.isInstanceOf[DecimalType] || + !right.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryArithmetic") + } } } } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryOperator.scala index f39b68b0b..1991f2e5a 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarBinaryOperator.scala @@ -146,16 +146,21 @@ class ColumnarEqualTo(left: Expression, right: Expression, original: Expression) right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } + var function = "equal" val nanCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck if (nanCheck) { @@ -183,16 +188,21 @@ class ColumnarEqualNull(left: Expression, right: Expression, original: Expressio right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } + var function = "equal" val nanCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck if (nanCheck) { @@ -220,16 +230,21 @@ class ColumnarLessThan(left: Expression, right: Expression, original: Expression right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } + var function = "less_than" val nanCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck if (nanCheck) { @@ -257,16 +272,21 @@ class ColumnarLessThanOrEqual(left: Expression, right: Expression, original: Exp right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } + var function = "less_than_or_equal_to" val nanCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck if (nanCheck) { @@ -296,15 +316,19 @@ class ColumnarGreaterThan(left: Expression, right: Expression, original: Express right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } var function = "greater_than" @@ -336,16 +360,21 @@ class ColumnarGreaterThanOrEqual(left: Expression, right: Expression, original: right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) val unifiedType = CodeGeneration.getResultType(left_type, right_type) - if (!left_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - left_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) - } - if (!right_type.equals(unifiedType)) { - val func_name = CodeGeneration.getCastFuncName(unifiedType) - right_node = - TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + (left_type, right_type) match { + case (l: ArrowType.Decimal, r: ArrowType.Decimal) => + case _ => + if (!left_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + left_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), unifiedType) + } + if (!right_type.equals(unifiedType)) { + val func_name = CodeGeneration.getCastFuncName(unifiedType) + right_node = + TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), unifiedType) + } } + var function = "greater_than_or_equal_to" val nanCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck if (nanCheck) { @@ -452,8 +481,11 @@ object ColumnarBinaryOperator { ConverterUtils.checkIfTypeSupported(right.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryOperator") + if (!left.dataType.isInstanceOf[DecimalType] || + !right.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryOperator") + } } } } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarBoundAttribute.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarBoundAttribute.scala index 64d784030..0dd8ba5cb 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarBoundAttribute.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarBoundAttribute.scala @@ -42,8 +42,10 @@ class ColumnarBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean ConverterUtils.checkIfTypeSupported(dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${dataType} is not supported in ColumnarBoundReference.") + if (!dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${dataType} is not supported in ColumnarBoundReference.") + } } } override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = { diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarCaseWhenOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarCaseWhenOperator.scala index 0bc161d1f..6f5d9849c 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarCaseWhenOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarCaseWhenOperator.scala @@ -52,8 +52,10 @@ class ColumnarCaseWhen( ConverterUtils.checkIfTypeSupported(expr.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${dataType} is not supported in ColumnarCaseWhen") + if (!expr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${dataType} is not supported in ColumnarCaseWhen") + } }) } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarCoalesceOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarCoalesceOperator.scala index 00b422364..f7907225a 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarCoalesceOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarCoalesceOperator.scala @@ -53,8 +53,10 @@ class ColumnarCoalesce(exps: Seq[Expression], original: Expression) ConverterUtils.checkIfTypeSupported(expr.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${expr.dataType} is not supported in ColumnarCoalesce") + if (!expr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${expr.dataType} is not supported in ColumnarCoalesce") + } } ) } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarIfOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarIfOperator.scala index 7fbf22772..0c2700cbb 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarIfOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarIfOperator.scala @@ -43,9 +43,13 @@ class ColumnarIf(predicate: Expression, trueValue: Expression, ConverterUtils.checkIfTypeSupported(falseValue.dataType) } catch { case e : UnsupportedOperationException => - throw new UnsupportedOperationException( - s"${predicate.dataType} or ${trueValue.dataType} or ${falseValue.dataType} " + - s"is not supported in ColumnarIf") + if (!predicate.dataType.isInstanceOf[DecimalType] || + !trueValue.dataType.isInstanceOf[DecimalType] || + !falseValue.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${predicate.dataType} or ${trueValue.dataType} or ${falseValue.dataType} " + + s"is not supported in ColumnarIf") + } } } diff --git a/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala b/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala index 6daa85c73..98d49d7c9 100644 --- a/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala +++ b/core/src/main/scala/com/intel/oap/expression/ColumnarUnaryOperator.scala @@ -249,11 +249,7 @@ class ColumnarCheckOverflow(child: Expression, original: CheckOverflow) val (child_node, childType): (TreeNode, ArrowType) = child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args) // since spark will call toPrecision in checkOverFlow and rescale from zero, we need to re-calculate result dataType here - val childScale: Int = childType match { - case d: ArrowType.Decimal => d.getScale - case _ => 0 - } - val newDataType = DecimalType(dataType.precision, dataType.scale + childScale) + val newDataType = DecimalType(dataType.precision, dataType.scale) val resType = CodeGeneration.getResultType(newDataType) val funcNode = TreeBuilder.makeFunction( "castDECIMAL", diff --git a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc index 87d59b29f..750741162 100644 --- a/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc +++ b/cpp/src/codegen/arrow_compute/ext/expression_codegen_visitor.cc @@ -277,20 +277,33 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) func_name.compare("castDECIMAL") != 0) { codes_str_ = func_name + "_" + std::to_string(cur_func_id); auto validity = codes_str_ + "_validity"; - std::stringstream fix_ss; - if (node.return_type()->id() == arrow::Type::DOUBLE || - node.return_type()->id() == arrow::Type::FLOAT) { - fix_ss << " * 1.0 "; - } std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" << std::endl; prepare_ss << "bool " << validity << " = " << child_visitor_list[0]->GetPreCheck() << ";" << std::endl; prepare_ss << "if (" << validity << ") {" << std::endl; - prepare_ss << codes_str_ << " = static_cast<" << GetCTypeString(node.return_type()) - << ">(" << child_visitor_list[0]->GetResult() << fix_ss.str() << ");" - << std::endl; + + auto childNode = node.children().at(0); + if (childNode->return_type()->id() != arrow::Type::DECIMAL) { + // if not casting form Decimal + std::stringstream fix_ss; + if (node.return_type()->id() == arrow::Type::DOUBLE || + node.return_type()->id() == arrow::Type::FLOAT) { + fix_ss << " * 1.0 "; + } + prepare_ss << codes_str_ << " = static_cast<" << GetCTypeString(node.return_type()) + << ">(" << child_visitor_list[0]->GetResult() << fix_ss.str() << ");" + << std::endl; + } else { + // if casting From Decimal + auto decimal_type = + std::dynamic_pointer_cast(childNode->return_type()); + prepare_ss << codes_str_ << " = static_cast<" << GetCTypeString(node.return_type()) + << ">(castFloatFromDecimal(" << child_visitor_list[0]->GetResult() + << ", " << decimal_type->scale() << "));" << std::endl; + header_list_.push_back(R"(#include "precompile/gandiva.h")"); + } prepare_ss << "}" << std::endl; for (int i = 0; i < 1; i++) { @@ -354,11 +367,16 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) std::stringstream fix_ss; auto decimal_type = std::dynamic_pointer_cast(node.return_type()); - if (!child_visitor_list[0]->decimal_scale_.empty()) { - fix_ss << ", " << child_visitor_list[0]->decimal_scale_ << ", " - << decimal_type->scale(); - } else { + auto childNode = node.children().at(0); + if (childNode->return_type()->id() != arrow::Type::DECIMAL) { + // if not casting from Decimal fix_ss << ", " << decimal_type->precision() << ", " << decimal_type->scale(); + } else { + // if casting from Decimal + auto childType = + std::dynamic_pointer_cast(childNode->return_type()); + fix_ss << ", " << childType->precision() << ", " << childType->scale() << ", " + << decimal_type->precision() << ", " << decimal_type->scale(); } std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" @@ -382,11 +400,11 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) std::stringstream fix_ss; auto decimal_type = std::dynamic_pointer_cast(node.return_type()); - if (!child_visitor_list[0]->decimal_scale_.empty()) { - fix_ss << ", 0, " << decimal_type->scale(); - } else { - fix_ss << ", " << decimal_type->precision() << ", " << decimal_type->scale(); - } + auto childNode = node.children().at(0); + auto childType = + std::dynamic_pointer_cast(childNode->return_type()); + fix_ss << ", " << childType->scale() << ", " << decimal_type->precision() << ", " + << decimal_type->scale(); std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" << std::endl; @@ -466,6 +484,9 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } else if (func_name.compare("add") == 0) { codes_str_ = "add_" + std::to_string(cur_func_id); auto validity = codes_str_ + "_validity"; + std::stringstream fix_ss; + fix_ss << child_visitor_list[0]->GetResult() << " + " + << child_visitor_list[1]->GetResult(); std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" << std::endl; @@ -474,8 +495,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) child_visitor_list[1]->GetPreCheck()}) << ");" << std::endl; prepare_ss << "if (" << validity << ") {" << std::endl; - prepare_ss << codes_str_ << " = " << child_visitor_list[0]->GetResult() << " + " - << child_visitor_list[1]->GetResult() << ";" << std::endl; + prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl; prepare_ss << "}" << std::endl; for (int i = 0; i < 2; i++) { @@ -486,6 +506,9 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } else if (func_name.compare("subtract") == 0) { codes_str_ = "subtract_" + std::to_string(cur_func_id); auto validity = codes_str_ + "_validity"; + std::stringstream fix_ss; + fix_ss << child_visitor_list[0]->GetResult() << " - " + << child_visitor_list[1]->GetResult(); std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" << std::endl; @@ -494,8 +517,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) child_visitor_list[1]->GetPreCheck()}) << ");" << std::endl; prepare_ss << "if (" << validity << ") {" << std::endl; - prepare_ss << codes_str_ << " = " << child_visitor_list[0]->GetResult() << " - " - << child_visitor_list[1]->GetResult() << ";" << std::endl; + prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl; prepare_ss << "}" << std::endl; for (int i = 0; i < 2; i++) { @@ -506,6 +528,9 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } else if (func_name.compare("multiply") == 0) { codes_str_ = "multiply_" + std::to_string(cur_func_id); auto validity = codes_str_ + "_validity"; + std::stringstream fix_ss; + fix_ss << child_visitor_list[0]->GetResult() << " * " + << child_visitor_list[1]->GetResult(); std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" << std::endl; @@ -514,8 +539,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) child_visitor_list[1]->GetPreCheck()}) << ");" << std::endl; prepare_ss << "if (" << validity << ") {" << std::endl; - prepare_ss << codes_str_ << " = " << child_visitor_list[0]->GetResult() << " * " - << child_visitor_list[1]->GetResult() << ";" << std::endl; + prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl; prepare_ss << "}" << std::endl; for (int i = 0; i < 2; i++) { @@ -526,6 +550,25 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) } else if (func_name.compare("divide") == 0) { codes_str_ = "divide_" + std::to_string(cur_func_id); auto validity = codes_str_ + "_validity"; + std::stringstream fix_ss; + if (node.return_type()->id() != arrow::Type::DECIMAL) { + fix_ss << child_visitor_list[0]->GetResult() << " / " + << child_visitor_list[1]->GetResult(); + } else { + auto leftNode = node.children().at(0); + auto rightNode = node.children().at(1); + auto leftType = + std::dynamic_pointer_cast(leftNode->return_type()); + auto rightType = + std::dynamic_pointer_cast(rightNode->return_type()); + auto resType = + std::dynamic_pointer_cast(node.return_type()); + fix_ss << "divide(" << child_visitor_list[0]->GetResult() << ", " + << leftType->precision() << ", " << leftType->scale() << ", " + << child_visitor_list[1]->GetResult() << ", " + << rightType->precision() << ", " << rightType->scale() << ", " + << resType->precision() << ", " << resType->scale() << ")"; + } std::stringstream prepare_ss; prepare_ss << GetCTypeString(node.return_type()) << " " << codes_str_ << ";" << std::endl; @@ -534,8 +577,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node) child_visitor_list[1]->GetPreCheck()}) << ");" << std::endl; prepare_ss << "if (" << validity << ") {" << std::endl; - prepare_ss << codes_str_ << " = " << child_visitor_list[0]->GetResult() << " / " - << child_visitor_list[1]->GetResult() << ";" << std::endl; + prepare_ss << codes_str_ << " = " << fix_ss.str() << ";" << std::endl; prepare_ss << "}" << std::endl; for (int i = 0; i < 2; i++) { diff --git a/cpp/src/precompile/array.h b/cpp/src/precompile/array.h index 1641ba9f7..e5fa14e88 100644 --- a/cpp/src/precompile/array.h +++ b/cpp/src/precompile/array.h @@ -1,8 +1,8 @@ #pragma once #include - -#include "arrow/util/string_view.h" // IWYU pragma: export +#include +#include // IWYU pragma: export namespace sparkcolumnarplugin { namespace precompile { @@ -130,9 +130,12 @@ class FixedSizeBinaryArray { public: FixedSizeBinaryArray(const std::shared_ptr&); arrow::util::string_view GetView(int64_t i) const { - return arrow::util::string_view(reinterpret_cast(raw_value_[i]), + return arrow::util::string_view(reinterpret_cast(GetValue(i)), byte_width_); } + const uint8_t* GetValue(int64_t i) const { + return raw_value_ + (i + offset_) * byte_width_; + } bool IsNull(int64_t i) const { i += offset_; return null_bitmap_data_ != NULLPTR && @@ -156,6 +159,10 @@ class FixedSizeBinaryArray { class Decimal128Array : public FixedSizeBinaryArray { public: Decimal128Array(const std::shared_ptr& in) : FixedSizeBinaryArray(in) {} + arrow::Decimal128 GetView(int64_t i) const { + const arrow::Decimal128 value(GetValue(i)); + return value; + } }; arrow::Status MakeFixedSizeBinaryArray(const std::shared_ptr&, diff --git a/cpp/src/precompile/gandiva.h b/cpp/src/precompile/gandiva.h index 0c445736f..c7fb8d6b4 100644 --- a/cpp/src/precompile/gandiva.h +++ b/cpp/src/precompile/gandiva.h @@ -7,6 +7,7 @@ #include #include "third_party/gandiva/types.h" +#include "third_party/gandiva/decimal_ops.h" int32_t castDATE32(int32_t in) { return castDATE_int32(in); } int64_t castDATE64(int32_t in) { return castDATE_date32(in); } @@ -26,10 +27,37 @@ arrow::Decimal128 castDECIMAL(double val, int32_t precision, int32_t scale) { snprintf(buffer, charsNeeded, "%.*f", (int)scale, nextafter(val, val + 0.5)); auto decimal_str = std::string(buffer); free(buffer); - return arrow::Decimal128(decimal_str); + return arrow::Decimal128::FromString(decimal_str).ValueOrDie(); } -arrow::Decimal128 castDECIMAL(arrow::Decimal128 in, int32_t original_scale, +double castFloatFromDecimal(arrow::Decimal128 val, int32_t scale) { + std::string str = val.ToString(scale); + return atof(str.c_str()); +} + +arrow::Decimal128 castDECIMAL(arrow::Decimal128 in, int32_t original_precision, + int32_t original_scale, int32_t new_precision, int32_t new_scale) { - return in.Rescale(original_scale, new_scale).ValueOrDie(); -} \ No newline at end of file + bool overflow = false; + gandiva::BasicDecimalScalar128 val(in, original_precision, original_scale); + auto out = gandiva::decimalops::Convert(val, new_precision, new_scale, &overflow); + if (overflow) { + throw std::overflow_error("castDECIMAL overflowed!"); + } + return arrow::Decimal128(out); +} + +arrow::Decimal128 divide(arrow::Decimal128 left, int32_t left_precision, + int32_t left_scale, arrow::Decimal128 right, + int32_t right_precision, int32_t right_scale, + int32_t out_precision, int32_t out_scale) { + gandiva::BasicDecimalScalar128 x(left, left_precision, left_scale); + gandiva::BasicDecimalScalar128 y(right, right_precision, right_scale); + bool overflow = false; + arrow::BasicDecimal128 out = + gandiva::decimalops::Divide(0, x, y, out_precision, out_scale, &overflow); + if (overflow) { + throw std::overflow_error("Decimal divide overflowed!"); + } + return arrow::Decimal128(out); +} diff --git a/cpp/src/tests/arrow_compute_test_precompile.cc b/cpp/src/tests/arrow_compute_test_precompile.cc index 308892771..2e70bf91b 100644 --- a/cpp/src/tests/arrow_compute_test_precompile.cc +++ b/cpp/src/tests/arrow_compute_test_precompile.cc @@ -21,6 +21,7 @@ #include "precompile/array.h" #include "tests/test_utils.h" +#include "precompile/gandiva.h" namespace sparkcolumnarplugin { namespace codegen { @@ -41,5 +42,22 @@ TEST(TestArrowCompute, BooleanArrayTest) { } } } + +TEST(TestArrowCompute, ArithmeticDecimalTest) { + auto left = arrow::Decimal128("32342423.012875"); + auto right = arrow::Decimal128("2347.012874535"); + int32_t left_scale = 6; + int32_t right_scale = 9; + int32_t left_precision = 14; + int32_t right_precision = 13; + int32_t out_precision = 22; + int32_t out_scale = 10; + auto res = castDECIMAL(left, left_precision, left_scale, out_precision, out_scale); + ASSERT_EQ(res, arrow::Decimal128("32342423.0128750000")); + res = divide(left, left_precision, left_scale, right, right_precision, right_scale, + out_precision, out_scale); + ASSERT_EQ(res, arrow::Decimal128("13780.2495094037")); +} + } // namespace codegen } // namespace sparkcolumnarplugin \ No newline at end of file diff --git a/cpp/src/third_party/gandiva/CMakeLists.txt b/cpp/src/third_party/gandiva/CMakeLists.txt index e265709a3..40f6f56d1 100644 --- a/cpp/src/third_party/gandiva/CMakeLists.txt +++ b/cpp/src/third_party/gandiva/CMakeLists.txt @@ -16,4 +16,5 @@ # under the License. set(THIRDPARTY_GANDIVA_SRCS + third_party/gandiva/decimal_ops.cc third_party/gandiva/time.cc PARENT_SCOPE) \ No newline at end of file diff --git a/cpp/src/third_party/gandiva/decimal_ops.cc b/cpp/src/third_party/gandiva/decimal_ops.cc new file mode 100644 index 000000000..e03c4d630 --- /dev/null +++ b/cpp/src/third_party/gandiva/decimal_ops.cc @@ -0,0 +1,724 @@ +// This File is copied from Gandiva in Arrow-3.0. +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Algorithms adapted from Apache Impala + +#include "decimal_ops.h" + +#include +#include +#include + +#include "arrow/util/logging.h" +#include "gandiva/decimal_type_util.h" +#include "gandiva/decimal_xlarge.h" +#include "gandiva/gdv_function_stubs.h" + +// Several operations (multiply, divide, mod, ..) require converting to 256-bit, and we +// use the boost library for doing 256-bit operations. To avoid references to boost from +// the precompiled-to-ir code (this causes issues with symbol resolution at runtime), we +// use a wrapper exported from the CPP code. The wrapper functions are named gdv_xlarge_xx + +namespace gandiva { +namespace decimalops { + +using arrow::BasicDecimal128; + +static BasicDecimal128 CheckAndIncreaseScale(const BasicDecimal128& in, int32_t delta) { + return (delta <= 0) ? in : in.IncreaseScaleBy(delta); +} + +static BasicDecimal128 CheckAndReduceScale(const BasicDecimal128& in, int32_t delta) { + return (delta <= 0) ? in : in.ReduceScaleBy(delta); +} + +/// Adjust x and y to the same scale, and add them. +static BasicDecimal128 AddFastPath(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_scale) { + auto higher_scale = std::max(x.scale(), y.scale()); + + auto x_scaled = CheckAndIncreaseScale(x.value(), higher_scale - x.scale()); + auto y_scaled = CheckAndIncreaseScale(y.value(), higher_scale - y.scale()); + return x_scaled + y_scaled; +} + +/// Add x and y, caller has ensured there can be no overflow. +static BasicDecimal128 AddNoOverflow(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_scale) { + auto higher_scale = std::max(x.scale(), y.scale()); + auto sum = AddFastPath(x, y, out_scale); + return CheckAndReduceScale(sum, higher_scale - out_scale); +} + +/// Both x_value and y_value must be >= 0 +static BasicDecimal128 AddLargePositive(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, + int32_t out_scale) { + DCHECK_GE(x.value(), 0); + DCHECK_GE(y.value(), 0); + + // separate out whole/fractions. + BasicDecimal128 x_left, x_right, y_left, y_right; + x.value().GetWholeAndFraction(x.scale(), &x_left, &x_right); + y.value().GetWholeAndFraction(y.scale(), &y_left, &y_right); + + // Adjust fractional parts to higher scale. + auto higher_scale = std::max(x.scale(), y.scale()); + auto x_right_scaled = CheckAndIncreaseScale(x_right, higher_scale - x.scale()); + auto y_right_scaled = CheckAndIncreaseScale(y_right, higher_scale - y.scale()); + + BasicDecimal128 right; + BasicDecimal128 carry_to_left; + auto multiplier = BasicDecimal128::GetScaleMultiplier(higher_scale); + if (x_right_scaled >= multiplier - y_right_scaled) { + right = x_right_scaled - (multiplier - y_right_scaled); + carry_to_left = 1; + } else { + right = x_right_scaled + y_right_scaled; + carry_to_left = 0; + } + right = CheckAndReduceScale(right, higher_scale - out_scale); + + auto left = x_left + y_left + carry_to_left; + return (left * BasicDecimal128::GetScaleMultiplier(out_scale)) + right; +} + +/// x_value and y_value cannot be 0, and one must be positive and the other negative. +static BasicDecimal128 AddLargeNegative(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, + int32_t out_scale) { + DCHECK_NE(x.value(), 0); + DCHECK_NE(y.value(), 0); + DCHECK((x.value() < 0 && y.value() > 0) || (x.value() > 0 && y.value() < 0)); + + // separate out whole/fractions. + BasicDecimal128 x_left, x_right, y_left, y_right; + x.value().GetWholeAndFraction(x.scale(), &x_left, &x_right); + y.value().GetWholeAndFraction(y.scale(), &y_left, &y_right); + + // Adjust fractional parts to higher scale. + auto higher_scale = std::max(x.scale(), y.scale()); + x_right = CheckAndIncreaseScale(x_right, higher_scale - x.scale()); + y_right = CheckAndIncreaseScale(y_right, higher_scale - y.scale()); + + // Overflow not possible because one is +ve and the other is -ve. + auto left = x_left + y_left; + auto right = x_right + y_right; + + // If the whole and fractional parts have different signs, then we need to make the + // fractional part have the same sign as the whole part. If either left or right is + // zero, then nothing needs to be done. + if (left < 0 && right > 0) { + left += 1; + right -= BasicDecimal128::GetScaleMultiplier(higher_scale); + } else if (left > 0 && right < 0) { + left -= 1; + right += BasicDecimal128::GetScaleMultiplier(higher_scale); + } + right = CheckAndReduceScale(right, higher_scale - out_scale); + return (left * BasicDecimal128::GetScaleMultiplier(out_scale)) + right; +} + +static BasicDecimal128 AddLarge(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_scale) { + if (x.value() >= 0 && y.value() >= 0) { + // both positive or 0 + return AddLargePositive(x, y, out_scale); + } else if (x.value() <= 0 && y.value() <= 0) { + // both negative or 0 + BasicDecimalScalar128 x_neg(-x.value(), x.precision(), x.scale()); + BasicDecimalScalar128 y_neg(-y.value(), y.precision(), y.scale()); + return -AddLargePositive(x_neg, y_neg, out_scale); + } else { + // one positive and the other negative + return AddLargeNegative(x, y, out_scale); + } +} + +// Suppose we have a number that requires x bits to be represented and we scale it up by +// 10^scale_by. Let's say now y bits are required to represent it. This function returns +// the maximum possible y - x for a given 'scale_by'. +inline int32_t MaxBitsRequiredIncreaseAfterScaling(int32_t scale_by) { + // We rely on the following formula: + // bits_required(x * 10^y) <= bits_required(x) + floor(log2(10^y)) + 1 + // We precompute floor(log2(10^x)) + 1 for x = 0, 1, 2...75, 76 + DCHECK_GE(scale_by, 0); + DCHECK_LE(scale_by, 76); + static const int32_t floor_log2_plus_one[] = { + 0, 4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40, 44, 47, 50, + 54, 57, 60, 64, 67, 70, 74, 77, 80, 84, 87, 90, 94, 97, 100, 103, + 107, 110, 113, 117, 120, 123, 127, 130, 133, 137, 140, 143, 147, 150, 153, 157, + 160, 163, 167, 170, 173, 177, 180, 183, 187, 190, 193, 196, 200, 203, 206, 210, + 213, 216, 220, 223, 226, 230, 233, 236, 240, 243, 246, 250, 253}; + return floor_log2_plus_one[scale_by]; +} + +// If we have a number with 'num_lz' leading zeros, and we scale it up by 10^scale_by, +// this function returns the minimum number of leading zeros the result can have. +inline int32_t MinLeadingZerosAfterScaling(int32_t num_lz, int32_t scale_by) { + DCHECK_GE(scale_by, 0); + DCHECK_LE(scale_by, 76); + int32_t result = num_lz - MaxBitsRequiredIncreaseAfterScaling(scale_by); + return result; +} + +// Returns the maximum possible number of bits required to represent num * 10^scale_by. +inline int32_t MaxBitsRequiredAfterScaling(const BasicDecimalScalar128& num, + int32_t scale_by) { + auto value = num.value(); + auto value_abs = value.Abs(); + + int32_t num_occupied = 128 - value_abs.CountLeadingBinaryZeros(); + DCHECK_GE(scale_by, 0); + DCHECK_LE(scale_by, 76); + return num_occupied + MaxBitsRequiredIncreaseAfterScaling(scale_by); +} + +// Returns the minimum number of leading zero x or y would have after one of them gets +// scaled up to match the scale of the other one. +inline int32_t MinLeadingZeros(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y) { + auto x_value = x.value(); + auto x_value_abs = x_value.Abs(); + + auto y_value = y.value(); + auto y_value_abs = y_value.Abs(); + + int32_t x_lz = x_value_abs.CountLeadingBinaryZeros(); + int32_t y_lz = y_value_abs.CountLeadingBinaryZeros(); + if (x.scale() < y.scale()) { + x_lz = MinLeadingZerosAfterScaling(x_lz, y.scale() - x.scale()); + } else if (x.scale() > y.scale()) { + y_lz = MinLeadingZerosAfterScaling(y_lz, x.scale() - y.scale()); + } + return std::min(x_lz, y_lz); +} + +BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, + int32_t out_precision, int32_t out_scale) { + if (out_precision < DecimalTypeUtil::kMaxPrecision) { + // fast-path add + return AddFastPath(x, y, out_scale); + } else { + int32_t min_lz = MinLeadingZeros(x, y); + if (min_lz >= 3) { + // If both numbers have at least MIN_LZ leading zeros, we can add them directly + // without the risk of overflow. + // We want the result to have at least 2 leading zeros, which ensures that it fits + // into the maximum decimal because 2^126 - 1 < 10^38 - 1. If both x and y have at + // least 3 leading zeros, then we are guaranteed that the result will have at lest 2 + // leading zeros. + return AddNoOverflow(x, y, out_scale); + } else { + // slower-version : add whole/fraction parts separately, and then, combine. + return AddLarge(x, y, out_scale); + } + } +} + +BasicDecimal128 Subtract(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, + int32_t out_precision, int32_t out_scale) { + return Add(x, {-y.value(), y.precision(), y.scale()}, out_precision, out_scale); +} + +// Multiply when the out_precision is 38, and there is no trimming of the scale i.e +// the intermediate value is the same as the final value. +static BasicDecimal128 MultiplyMaxPrecisionNoScaleDown(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, + int32_t out_scale, + bool* overflow) { + DCHECK_EQ(x.scale() + y.scale(), out_scale); + + BasicDecimal128 result; + auto x_abs = BasicDecimal128::Abs(x.value()); + auto y_abs = BasicDecimal128::Abs(y.value()); + + if (x_abs > BasicDecimal128::GetMaxValue() / y_abs) { + *overflow = true; + } else { + // We've verified that the result will fit into 128 bits. + *overflow = false; + result = x.value() * y.value(); + } + return result; +} + +// Multiply when the out_precision is 38, and there is trimming of the scale i.e +// the intermediate value could be larger than the final value. +static BasicDecimal128 MultiplyMaxPrecisionAndScaleDown(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, + int32_t out_scale, + bool* overflow) { + auto delta_scale = x.scale() + y.scale() - out_scale; + DCHECK_GT(delta_scale, 0); + + *overflow = false; + BasicDecimal128 result; + auto x_abs = BasicDecimal128::Abs(x.value()); + auto y_abs = BasicDecimal128::Abs(y.value()); + + // It's possible that the intermediate value does not fit in 128-bits, but the + // final value will (after scaling down). + bool needs_int256 = false; + int32_t total_leading_zeros = + x_abs.CountLeadingBinaryZeros() + y_abs.CountLeadingBinaryZeros(); + // This check is quick, but conservative. In some cases it will indicate that + // converting to 256 bits is necessary, when it's not actually the case. + needs_int256 = total_leading_zeros <= 128; + if (ARROW_PREDICT_FALSE(needs_int256)) { + int64_t result_high; + uint64_t result_low; + + // This requires converting to 256-bit, and we use the boost library for that. To + // avoid references to boost from the precompiled-to-ir code (this causes issues + // with symbol resolution at runtime), we use a wrapper exported from the CPP code. + gdv_xlarge_multiply_and_scale_down(x.value().high_bits(), x.value().low_bits(), + y.value().high_bits(), y.value().low_bits(), + delta_scale, &result_high, &result_low, overflow); + result = BasicDecimal128(result_high, result_low); + } else { + if (ARROW_PREDICT_TRUE(delta_scale <= 38)) { + // The largest value that result can have here is (2^64 - 1) * (2^63 - 1), which is + // greater than BasicDecimal128::kMaxValue. + result = x.value() * y.value(); + // Since delta_scale is greater than zero, result can now be at most + // ((2^64 - 1) * (2^63 - 1)) / 10, which is less than BasicDecimal128::kMaxValue, so + // there cannot be any overflow. + result = result.ReduceScaleBy(delta_scale); + } else { + // We are multiplying decimal(38, 38) by decimal(38, 38). The result should be a + // decimal(38, 37), so delta scale = 38 + 38 - 37 = 39. Since we are not in the + // 256 bit intermediate value case and we are scaling down by 39, then we are + // guaranteed that the result is 0 (even if we try to round). The largest possible + // intermediate result is 38 "9"s. If we scale down by 39, the leftmost 9 is now + // two digits to the right of the rightmost "visible" one. The reason why we have + // to handle this case separately is because a scale multiplier with a delta_scale + // 39 does not fit into 128 bit. + DCHECK_EQ(delta_scale, 39); + result = 0; + } + } + return result; +} + +// Multiply when the out_precision is 38. +static BasicDecimal128 MultiplyMaxPrecision(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, + int32_t out_scale, bool* overflow) { + auto delta_scale = x.scale() + y.scale() - out_scale; + DCHECK_GE(delta_scale, 0); + if (delta_scale == 0) { + return MultiplyMaxPrecisionNoScaleDown(x, y, out_scale, overflow); + } else { + return MultiplyMaxPrecisionAndScaleDown(x, y, out_scale, overflow); + } +} + +BasicDecimal128 Multiply(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, + int32_t out_precision, int32_t out_scale, bool* overflow) { + BasicDecimal128 result; + *overflow = false; + if (out_precision < DecimalTypeUtil::kMaxPrecision) { + // fast-path multiply + result = x.value() * y.value(); + DCHECK_EQ(x.scale() + y.scale(), out_scale); + DCHECK_LE(BasicDecimal128::Abs(result), BasicDecimal128::GetMaxValue()); + } else if (x.value() == 0 || y.value() == 0) { + // Handle this separately to avoid divide-by-zero errors. + result = BasicDecimal128(0, 0); + } else { + result = MultiplyMaxPrecision(x, y, out_scale, overflow); + } + DCHECK(*overflow || BasicDecimal128::Abs(result) <= BasicDecimal128::GetMaxValue()); + return result; +} + +BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_precision, + int32_t out_scale, bool* overflow) { + if (y.value() == 0) { + char const* err_msg = "divide by zero error"; + gdv_fn_context_set_error_msg(context, err_msg); + return 0; + } + + // scale up to the output scale, and do an integer division. + int32_t delta_scale = out_scale + y.scale() - x.scale(); + DCHECK_GE(delta_scale, 0); + + BasicDecimal128 result; + auto num_bits_required_after_scaling = MaxBitsRequiredAfterScaling(x, delta_scale); + if (num_bits_required_after_scaling <= 127) { + // fast-path. The dividend fits in 128-bit after scaling too. + *overflow = false; + + // do the division. + auto x_scaled = CheckAndIncreaseScale(x.value(), delta_scale); + BasicDecimal128 remainder; + auto status = x_scaled.Divide(y.value(), &result, &remainder); + DCHECK_EQ(status, arrow::DecimalStatus::kSuccess); + + // round-up + if (BasicDecimal128::Abs(2 * remainder) >= BasicDecimal128::Abs(y.value())) { + result += (x.value().Sign() ^ y.value().Sign()) + 1; + } + } else { + // convert to 256-bit and do the divide. + *overflow = delta_scale > 38 && num_bits_required_after_scaling > 255; + if (!*overflow) { + int64_t result_high; + uint64_t result_low; + + gdv_xlarge_scale_up_and_divide(x.value().high_bits(), x.value().low_bits(), + y.value().high_bits(), y.value().low_bits(), + delta_scale, &result_high, &result_low, overflow); + result = BasicDecimal128(result_high, result_low); + } + } + return result; +} + +BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_precision, + int32_t out_scale, bool* overflow) { + if (y.value() == 0) { + char const* err_msg = "divide by zero error"; + gdv_fn_context_set_error_msg(context, err_msg); + return 0; + } + + // Adsjust x and y to the same scale (higher one), and then, do a integer mod. + *overflow = false; + BasicDecimal128 result; + int32_t min_lz = MinLeadingZeros(x, y); + if (min_lz >= 2) { + auto higher_scale = std::max(x.scale(), y.scale()); + auto x_scaled = CheckAndIncreaseScale(x.value(), higher_scale - x.scale()); + auto y_scaled = CheckAndIncreaseScale(y.value(), higher_scale - y.scale()); + result = x_scaled % y_scaled; + DCHECK_LE(BasicDecimal128::Abs(result), BasicDecimal128::GetMaxValue()); + } else { + int64_t result_high; + uint64_t result_low; + + gdv_xlarge_mod(x.value().high_bits(), x.value().low_bits(), x.scale(), + y.value().high_bits(), y.value().low_bits(), y.scale(), &result_high, + &result_low); + result = BasicDecimal128(result_high, result_low); + } + DCHECK(BasicDecimal128::Abs(result) <= BasicDecimal128::Abs(x.value()) || + BasicDecimal128::Abs(result) <= BasicDecimal128::Abs(y.value())); + return result; +} + +int32_t CompareSameScale(const BasicDecimal128& x, const BasicDecimal128& y) { + if (x == y) { + return 0; + } else if (x < y) { + return -1; + } else { + return 1; + } +} + +int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y) { + int32_t delta_scale = x.scale() - y.scale(); + + // fast-path : both are of the same scale. + if (delta_scale == 0) { + return CompareSameScale(x.value(), y.value()); + } + + // Check if we'll need more than 256-bits after adjusting the scale. + bool need256 = + (delta_scale < 0 && x.precision() - delta_scale > DecimalTypeUtil::kMaxPrecision) || + (y.precision() + delta_scale > DecimalTypeUtil::kMaxPrecision); + if (need256) { + return gdv_xlarge_compare(x.value().high_bits(), x.value().low_bits(), x.scale(), + y.value().high_bits(), y.value().low_bits(), y.scale()); + } else { + BasicDecimal128 x_scaled; + BasicDecimal128 y_scaled; + + if (delta_scale < 0) { + x_scaled = x.value().IncreaseScaleBy(-delta_scale); + y_scaled = y.value(); + } else { + x_scaled = x.value(); + y_scaled = y.value().IncreaseScaleBy(delta_scale); + } + return CompareSameScale(x_scaled, y_scaled); + } +} + +#define DECIMAL_OVERFLOW_IF(condition, overflow) \ + do { \ + if (*overflow || (condition)) { \ + *overflow = true; \ + return 0; \ + } \ + } while (0) + +static BasicDecimal128 GetMaxValue(int32_t precision) { + return BasicDecimal128::GetScaleMultiplier(precision) - 1; +} + +// Compute the double scale multipliers once. +static std::array kDoubleScaleMultipliers = + ([]() -> std::array { + std::array values; + values[0] = 1.0; + for (int32_t idx = 1; idx <= DecimalTypeUtil::kMaxPrecision; idx++) { + values[idx] = values[idx - 1] * 10; + } + return values; + })(); + +BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow) { + // Multiply decimal with the scale + auto unscaled = in * kDoubleScaleMultipliers[scale]; + DECIMAL_OVERFLOW_IF(std::isnan(unscaled), overflow); + + unscaled = std::round(unscaled); + + // convert scaled double to int128 + int32_t sign = unscaled < 0 ? -1 : 1; + auto unscaled_abs = std::abs(unscaled); + + // overflow if > 2^127 - 1 + DECIMAL_OVERFLOW_IF(unscaled_abs > std::ldexp(static_cast(1), 127) - 1, + overflow); + + uint64_t high_bits = static_cast(std::ldexp(unscaled_abs, -64)); + uint64_t low_bits = static_cast( + unscaled_abs - std::ldexp(static_cast(high_bits), 64)); + + auto result = BasicDecimal128(static_cast(high_bits), low_bits); + + // overflow if > max value based on precision + DECIMAL_OVERFLOW_IF(result > GetMaxValue(precision), overflow); + return result * sign; +} + +double ToDouble(const BasicDecimalScalar128& in, bool* overflow) { + // convert int128 to double + int64_t sign = in.value().Sign(); + auto value_abs = BasicDecimal128::Abs(in.value()); + double unscaled = static_cast(value_abs.low_bits()) + + std::ldexp(static_cast(value_abs.high_bits()), 64); + + // scale double. + return (unscaled * sign) / kDoubleScaleMultipliers[in.scale()]; +} + +BasicDecimal128 FromInt64(int64_t in, int32_t precision, int32_t scale, bool* overflow) { + // check if multiplying by scale will cause an overflow. + DECIMAL_OVERFLOW_IF(std::abs(in) > GetMaxValue(precision - scale), overflow); + return in * BasicDecimal128::GetScaleMultiplier(scale); +} + +// Helper function to modify the scale and/or precision of a decimal value. +static BasicDecimal128 ModifyScaleAndPrecision(const BasicDecimalScalar128& x, + int32_t out_precision, int32_t out_scale, + bool* overflow) { + int32_t delta_scale = out_scale - x.scale(); + if (delta_scale >= 0) { + // check if multiplying by delta_scale will cause an overflow. + DECIMAL_OVERFLOW_IF( + BasicDecimal128::Abs(x.value()) > GetMaxValue(out_precision - delta_scale), + overflow); + return x.value().IncreaseScaleBy(delta_scale); + } else { + // Do not do any rounding, that is handled by the caller. + auto result = x.value().ReduceScaleBy(-delta_scale, false); + DECIMAL_OVERFLOW_IF(BasicDecimal128::Abs(result) > GetMaxValue(out_precision), + overflow); + return result; + } +} + +enum RoundType { + kRoundTypeCeil, // +1 if +ve and trailing value is > 0, else no rounding. + kRoundTypeFloor, // -1 if -ve and trailing value is < 0, else no rounding. + kRoundTypeTrunc, // no rounding, truncate the trailing digits. + kRoundTypeHalfRoundUp, // if +ve and trailing value is >= half of base, +1. + // else if -ve and trailing value is >= half of base, -1. +}; + +// Compute the rounding delta for the givven rounding type. +static int32_t ComputeRoundingDelta(const BasicDecimal128& x, int32_t x_scale, + int32_t out_scale, RoundType type) { + if (type == kRoundTypeTrunc || // no rounding for this type. + out_scale >= x_scale) { // no digits dropped, so no rounding. + return 0; + } + + int32_t result = 0; + switch (type) { + case kRoundTypeHalfRoundUp: { + auto base = BasicDecimal128::GetScaleMultiplier(x_scale - out_scale); + auto trailing = x % base; + if (trailing == 0) { + result = 0; + } else if (trailing.Abs() < base / 2) { + result = 0; + } else { + result = (x < 0) ? -1 : 1; + } + break; + } + + case kRoundTypeCeil: + if (x < 0) { + // no rounding for -ve + result = 0; + } else { + auto base = BasicDecimal128::GetScaleMultiplier(x_scale - out_scale); + auto trailing = x % base; + result = (trailing == 0) ? 0 : 1; + } + break; + + case kRoundTypeFloor: + if (x > 0) { + // no rounding for +ve + result = 0; + } else { + auto base = BasicDecimal128::GetScaleMultiplier(x_scale - out_scale); + auto trailing = x % base; + result = (trailing == 0) ? 0 : -1; + } + break; + + case kRoundTypeTrunc: + break; + } + return result; +} + +// Modify the scale and round. +static BasicDecimal128 RoundWithPositiveScale(const BasicDecimalScalar128& x, + int32_t out_precision, int32_t out_scale, + RoundType round_type, bool* overflow) { + DCHECK_GE(out_scale, 0); + + auto scaled = ModifyScaleAndPrecision(x, out_precision, out_scale, overflow); + if (*overflow) { + return 0; + } + + auto delta = ComputeRoundingDelta(x.value(), x.scale(), out_scale, round_type); + if (delta == 0) { + return scaled; + } + + // If there is a rounding delta, the output scale must be less than the input scale. + // That means at least one digit is dropped after the decimal. The delta add can add + // utmost one digit before the decimal. So, overflow will occur only if the output + // precision has changed. + DCHECK_GT(x.scale(), out_scale); + auto result = scaled + delta; + DECIMAL_OVERFLOW_IF(out_precision < x.precision() && + BasicDecimal128::Abs(result) > GetMaxValue(out_precision), + overflow); + return result; +} + +// Modify scale to drop all digits to the right of the decimal and round. +// Then, zero out 'rounding_scale' number of digits to the left of the decimal point. +static BasicDecimal128 RoundWithNegativeScale(const BasicDecimalScalar128& x, + int32_t out_precision, + int32_t rounding_scale, + RoundType round_type, bool* overflow) { + DCHECK_LT(rounding_scale, 0); + + // get rid of the fractional part. + auto scaled = ModifyScaleAndPrecision(x, out_precision, 0, overflow); + auto rounding_delta = ComputeRoundingDelta(scaled, 0, -rounding_scale, round_type); + + auto base = BasicDecimal128::GetScaleMultiplier(-rounding_scale); + auto delta = rounding_delta * base - (scaled % base); + DECIMAL_OVERFLOW_IF(BasicDecimal128::Abs(scaled) > + GetMaxValue(out_precision) - BasicDecimal128::Abs(delta), + overflow); + return scaled + delta; +} + +BasicDecimal128 Round(const BasicDecimalScalar128& x, int32_t out_precision, + int32_t out_scale, int32_t rounding_scale, bool* overflow) { + // no-op if target scale is same as arg scale + if (x.scale() == out_scale && rounding_scale >= 0) { + return x.value(); + } + + if (rounding_scale < 0) { + return RoundWithNegativeScale(x, out_precision, rounding_scale, + RoundType::kRoundTypeHalfRoundUp, overflow); + } else { + return RoundWithPositiveScale(x, out_precision, rounding_scale, + RoundType::kRoundTypeHalfRoundUp, overflow); + } +} + +BasicDecimal128 Truncate(const BasicDecimalScalar128& x, int32_t out_precision, + int32_t out_scale, int32_t rounding_scale, bool* overflow) { + // no-op if target scale is same as arg scale + if (x.scale() == out_scale && rounding_scale >= 0) { + return x.value(); + } + + if (rounding_scale < 0) { + return RoundWithNegativeScale(x, out_precision, rounding_scale, + RoundType::kRoundTypeTrunc, overflow); + } else { + return RoundWithPositiveScale(x, out_precision, rounding_scale, + RoundType::kRoundTypeTrunc, overflow); + } +} + +BasicDecimal128 Ceil(const BasicDecimalScalar128& x, bool* overflow) { + return RoundWithPositiveScale(x, x.precision(), 0, RoundType::kRoundTypeCeil, overflow); +} + +BasicDecimal128 Floor(const BasicDecimalScalar128& x, bool* overflow) { + return RoundWithPositiveScale(x, x.precision(), 0, RoundType::kRoundTypeFloor, + overflow); +} + +BasicDecimal128 Convert(const BasicDecimalScalar128& x, int32_t out_precision, + int32_t out_scale, bool* overflow) { + DCHECK_GE(out_scale, 0); + DCHECK_LE(out_scale, DecimalTypeUtil::kMaxScale); + DCHECK_GT(out_precision, 0); + DCHECK_LE(out_precision, DecimalTypeUtil::kMaxScale); + + return RoundWithPositiveScale(x, out_precision, out_scale, + RoundType::kRoundTypeHalfRoundUp, overflow); +} + +int64_t ToInt64(const BasicDecimalScalar128& in, bool* overflow) { + auto rounded = RoundWithPositiveScale(in, in.precision(), 0 /*scale*/, + RoundType::kRoundTypeHalfRoundUp, overflow); + DECIMAL_OVERFLOW_IF((rounded > std::numeric_limits::max()) || + (rounded < std::numeric_limits::min()), + overflow); + return static_cast(rounded.low_bits()); +} + +} // namespace decimalops +} // namespace gandiva diff --git a/cpp/src/third_party/gandiva/decimal_ops.h b/cpp/src/third_party/gandiva/decimal_ops.h new file mode 100644 index 000000000..5a4e50bba --- /dev/null +++ b/cpp/src/third_party/gandiva/decimal_ops.h @@ -0,0 +1,91 @@ +// This File is copied from Gandiva in Arrow-3.0. +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include "gandiva/basic_decimal_scalar.h" + +namespace gandiva { +namespace decimalops { + +/// Return the sum of 'x' and 'y'. +/// out_precision and out_scale are passed along for efficiency, they must match +/// the rules in DecimalTypeSql::GetResultType. +arrow::BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, + int32_t out_precision, int32_t out_scale); + +/// Subtract 'y' from 'x', and return the result. +arrow::BasicDecimal128 Subtract(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_precision, + int32_t out_scale); + +/// Multiply 'x' from 'y', and return the result. +arrow::BasicDecimal128 Multiply(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_precision, + int32_t out_scale, bool* overflow); + +/// Divide 'x' by 'y', and return the result. +arrow::BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_precision, + int32_t out_scale, bool* overflow); + +/// Divide 'x' by 'y', and return the remainder. +arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_precision, + int32_t out_scale, bool* overflow); + +/// Compare two decimals. Returns : +/// 0 if x == y +/// 1 if x > y +/// -1 if x < y +int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y); + +/// Convert to decimal from double. +BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow); + +/// Convert from decimal to double. +double ToDouble(const BasicDecimalScalar128& in, bool* overflow); + +/// Convert to decimal from gdv_int64. +BasicDecimal128 FromInt64(int64_t in, int32_t precision, int32_t scale, bool* overflow); + +/// Convert from decimal to gdv_int64 +int64_t ToInt64(const BasicDecimalScalar128& in, bool* overflow); + +/// Convert from one decimal scale/precision to another. +BasicDecimal128 Convert(const BasicDecimalScalar128& x, int32_t out_precision, + int32_t out_scale, bool* overflow); + +/// round decimal. +BasicDecimal128 Round(const BasicDecimalScalar128& x, int32_t out_precision, + int32_t out_scale, int32_t rounding_scale, bool* overflow); + +/// truncate decimal. +BasicDecimal128 Truncate(const BasicDecimalScalar128& x, int32_t out_precision, + int32_t out_scale, int32_t rounding_scale, bool* overflow); + +/// ceil decimal +BasicDecimal128 Ceil(const BasicDecimalScalar128& x, bool* overflow); + +/// floor decimal +BasicDecimal128 Floor(const BasicDecimalScalar128& x, bool* overflow); + +} // namespace decimalops +} // namespace gandiva