Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-130] support decimal in project #131

Merged
merged 7 commits into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.util.ExecutorManager
import org.apache.spark.sql.util.StructTypeFWD
import org.apache.spark.{SparkConf, TaskContext}
Expand Down Expand Up @@ -70,8 +70,10 @@ case class ColumnarConditionProjectExec(
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarConditionProjector.")
if (!attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarConditionProjector.")
}
}
})
// check expr
Expand All @@ -80,8 +82,10 @@ case class ColumnarConditionProjectExec(
ConverterUtils.checkIfTypeSupported(condExpr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${condExpr.dataType} is not supported in ColumnarConditionProjector.")
if (!condExpr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${condExpr.dataType} is not supported in ColumnarConditionProjector.")
}
}
ColumnarExpressionConverter.replaceWithColumnarExpression(condExpr)
}
Expand All @@ -91,8 +95,10 @@ case class ColumnarConditionProjectExec(
ConverterUtils.checkIfTypeSupported(expr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${expr.dataType} is not supported in ColumnarConditionProjector.")
if (!expr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${expr.dataType} is not supported in ColumnarConditionProjector.")
}
}
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
}
Expand Down
162 changes: 100 additions & 62 deletions core/src/main/scala/com/intel/oap/expression/ColumnarArithmetic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,20 @@
package com.intel.oap.expression

import com.google.common.collect.Lists

import org.apache.arrow.gandiva.evaluator._
import org.apache.arrow.gandiva.exceptions.GandivaException
import org.apache.arrow.gandiva.expression._
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.Field

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

import org.apache.arrow.gandiva.evaluator.DecimalTypeUtil

/**
* A version of add that supports columnar processing for longs.
*/
Expand All @@ -44,22 +45,30 @@ class ColumnarAdd(left: Expression, right: Expression, original: Expression)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.ADD, l, r)
val addNode = TreeBuilder.makeFunction(
"add", Lists.newArrayList(left_node, right_node), resultType)
(addNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType)
}
//logInfo(s"(TreeBuilder.makeFunction(add, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"add", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}

//logInfo(s"(TreeBuilder.makeFunction(add, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("add", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand All @@ -73,21 +82,30 @@ class ColumnarSubtract(left: Expression, right: Expression, original: Expression
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.SUBTRACT, l, r)
val subNode = TreeBuilder.makeFunction(
"subtract", Lists.newArrayList(left_node, right_node), resultType)
(subNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType)
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"subtract", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("subtract", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand All @@ -101,22 +119,30 @@ class ColumnarMultiply(left: Expression, right: Expression, original: Expression
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.MULTIPLY, l, r)
val mulNode = TreeBuilder.makeFunction(
"multiply", Lists.newArrayList(left_node, right_node), resultType)
(mulNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"multiply", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}

//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("multiply", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand All @@ -130,21 +156,30 @@ class ColumnarDivide(left: Expression, right: Expression, original: Expression)
var (right_node, right_type): (TreeNode, ArrowType) =
right.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
(left_type, right_type) match {
case (l: ArrowType.Decimal, r: ArrowType.Decimal) =>
val resultType = DecimalTypeUtil.getResultTypeForOperation(
DecimalTypeUtil.OperationType.DIVIDE, l, r)
val divNode = TreeBuilder.makeFunction(
"divide", Lists.newArrayList(left_node, right_node), resultType)
(divNode, resultType)
case _ =>
val resultType = CodeGeneration.getResultType(left_type, right_type)
if (!left_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
left_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(left_node), resultType),
}
if (!right_type.equals(resultType)) {
val func_name = CodeGeneration.getCastFuncName(resultType)
right_node =
TreeBuilder.makeFunction(func_name, Lists.newArrayList(right_node), resultType),
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
val funcNode = TreeBuilder.makeFunction(
"divide", Lists.newArrayList(left_node, right_node), resultType)
(funcNode, resultType)
}
//logInfo(s"(TreeBuilder.makeFunction(multiply, Lists.newArrayList($left_node, $right_node), $resultType), $resultType)")
(
TreeBuilder.makeFunction("divide", Lists.newArrayList(left_node, right_node), resultType),
resultType)
}
}

Expand Down Expand Up @@ -238,8 +273,11 @@ object ColumnarBinaryArithmetic {
ConverterUtils.checkIfTypeSupported(right.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryArithmetic")
if (!left.dataType.isInstanceOf[DecimalType] ||
!right.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${left.dataType} or ${right.dataType} is not supported in ColumnarBinaryArithmetic")
}
}
}
}
Loading