Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-130] support decimal in project (#131)
Browse files Browse the repository at this point in the history
* support decimal in project

* support more functions for decimal

* fix cast decimal

* use gandiva divide functions

* use convert decimal api for rescale

* use decimal ops from arrow-3.0

* refine
  • Loading branch information
rui-mo authored Mar 5, 2021
1 parent 74c35b2 commit c38528a
Show file tree
Hide file tree
Showing 15 changed files with 1,165 additions and 172 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.util.ExecutorManager
import org.apache.spark.sql.util.StructTypeFWD
import org.apache.spark.{SparkConf, TaskContext}
Expand Down Expand Up @@ -70,8 +70,10 @@ case class ColumnarConditionProjectExec(
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarConditionProjector.")
if (!attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarConditionProjector.")
}
}
})
// check expr
Expand All @@ -80,8 +82,10 @@ case class ColumnarConditionProjectExec(
ConverterUtils.checkIfTypeSupported(condExpr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${condExpr.dataType} is not supported in ColumnarConditionProjector.")
if (!condExpr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${condExpr.dataType} is not supported in ColumnarConditionProjector.")
}
}
ColumnarExpressionConverter.replaceWithColumnarExpression(condExpr)
}
Expand All @@ -91,8 +95,10 @@ case class ColumnarConditionProjectExec(
ConverterUtils.checkIfTypeSupported(expr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${expr.dataType} is not supported in ColumnarConditionProjector.")
if (!expr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${expr.dataType} is not supported in ColumnarConditionProjector.")
}
}
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
}
Expand Down
162 changes: 100 additions & 62 deletions core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,20 @@
package com.intel.oap.expression

import com.google.common.collect.Lists

import org.apache.arrow.gandiva.evaluator._
import org.apache.arrow.gandiva.exceptions.GandivaException
import org.apache.arrow.gandiva.expression._
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.Field

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

import org.apache.arrow.gandiva.evaluator.DecimalTypeUtil

/**
* A version of add that supports columnar processing for longs.
*/
Expand All @@ -44,22 +45,30 @@ class ColumnarAdd(left: Expression, right: Expression, original: Expression)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.ADD, l, r)
val addNode = TreeBuilder.makeFunction(
"add", Lists.newArrayList(left_node, right_node), resultType)
(addNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType)
}
//logInfo(s"(TreeBuilder.makeFunction(add, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"add", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}

//logInfo(s"(TreeBuilder.makeFunction(add, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("add", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand All @@ -73,21 +82,30 @@ class ColumnarSubtract(left: Expression, right: Expression, original: Expression
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.SUBTRACT, l, r)
val subNode = TreeBuilder.makeFunction(
"subtract", Lists.newArrayList(left_node, right_node), resultType)
(subNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType)
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"subtract", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("subtract", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand All @@ -101,22 +119,30 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, l, r)
val mulNode = TreeBuilder.makeFunction(
"multiply", Lists.newArrayList(left_node, right_node), resultType)
(mulNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"multiply", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}

//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("multiply", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand All @@ -130,21 +156,30 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.DIVIDE, l, r)
val divNode = TreeBuilder.makeFunction(
"divide", Lists.newArrayList(left_node, right_node), resultType)
(divNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"divide", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("divide", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand Down Expand Up @@ -238,8 +273,11 @@ object ColumnarBinaryArithmetic {
ConverterUtils.checkIfTypeSupported(right.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryArithmetic")
if (!left.dataType.isInstanceOf[DecimalType] ||
!right.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryArithmetic")
}
}
}
}
Loading

0 comments on commit c38528a

Please sign in to comment.