
Commit 92715ee: fix on decimal dataset
rui-mo committed Jan 20, 2021 (1 parent: 957ba76)
Showing 9 changed files with 77 additions and 42 deletions.
@@ -1668,14 +1668,12 @@ final void setNull(int rowId) {
 
   @Override
   final void setInt(int rowId, int value) {
-    BigDecimal v = new BigDecimal(value);
-    writer.setSafe(rowId, v.setScale(writer.getScale()));
+    writer.setSafe(rowId, value);
   }
 
   @Override
   final void setLong(int rowId, long value) {
-    BigDecimal v = new BigDecimal(value);
-    writer.setSafe(rowId, v.setScale(writer.getScale()));
+    writer.setSafe(rowId, value);
   }
 
   @Override
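Spark's vectorized writers hand small-precision decimals to setInt/setLong as the compact unscaled value, so wrapping the argument in a BigDecimal and rescaling it with setScale effectively scaled the value a second time. A minimal sketch of the assumed semantics, using Spark's Decimal API:

```scala
import org.apache.spark.sql.types.Decimal

// Assumed behavior: a Decimal(4, 2) value 1.23 reaches setLong as the
// unscaled long 123. Writing it through unchanged preserves 1.23; the
// old path re-read 123 as the decimal 123 and rescaled it to 123.00.
val d = Decimal("1.23")
assert(d.toUnscaledLong == 123L)
```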
core/src/main/scala/com/intel/oap/ColumnarGuardRule.scala (30 additions, 0 deletions)
@@ -107,6 +107,36 @@ case class ColumnarGuardRule(conf: SparkConf) extends Rule[SparkPlan] {
       case plan: BroadcastExchangeExec =>
         ColumnarBroadcastExchangeExec(plan.mode, plan.child)
       case plan: BroadcastHashJoinExec =>
+        // We need to check if BroadcastExchangeExec can be converted to columnar-based.
+        // If not, BHJ should also be row-based.
+        val left = plan.left
+        left match {
+          case exec: BroadcastExchangeExec =>
+            new ColumnarBroadcastExchangeExec(exec.mode, exec.child)
+          case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) =>
+            new ColumnarBroadcastExchangeExec(plan.mode, plan.child)
+          case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) =>
+            plan match {
+              case ReusedExchangeExec(_, b: BroadcastExchangeExec) =>
+                new ColumnarBroadcastExchangeExec(b.mode, b.child)
+              case _ =>
+            }
+          case _ =>
+        }
+        val right = plan.right
+        right match {
+          case exec: BroadcastExchangeExec =>
+            new ColumnarBroadcastExchangeExec(exec.mode, exec.child)
+          case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) =>
+            new ColumnarBroadcastExchangeExec(plan.mode, plan.child)
+          case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) =>
+            plan match {
+              case ReusedExchangeExec(_, b: BroadcastExchangeExec) =>
+                new ColumnarBroadcastExchangeExec(b.mode, b.child)
+              case _ =>
+            }
+          case _ =>
+        }
         ColumnarBroadcastHashJoinExec(
           plan.leftKeys,
           plan.rightKeys,
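The construction attempts above act as probes: building a ColumnarBroadcastExchangeExec runs its build checks, so an unsupported child (for example, a decimal column) throws and the guard rule's enclosing try keeps the whole BroadcastHashJoinExec row-based. This replaces the per-operator checks removed from ColumnarBroadcastHashJoinExec below and additionally handles the AQE wrappers BroadcastQueryStageExec and ReusedExchangeExec. A rough sketch of the fallback pattern (the helper name is invented):

```scala
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical condensation of the guard-rule idea: try the columnar
// construction; an UnsupportedOperationException from a buildCheck()
// means "keep the row-based plan".
def convertible(probe: => SparkPlan): Boolean =
  try { probe; true }
  catch { case _: UnsupportedOperationException => false }
```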
@@ -28,7 +28,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
-import org.apache.spark.sql.types.{MapType, StructType}
+import org.apache.spark.sql.types.{DecimalType, MapType, StructType}
 import org.apache.spark.util.ExecutorManager
 import org.apache.spark.sql.util.StructTypeFWD
 import org.apache.spark.{SparkConf, TaskContext}
@@ -237,7 +237,8 @@ case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan {
   def buildCheck(): Unit = {
     for (child <- children) {
       for (schema <- child.schema) {
-        if (schema.dataType.isInstanceOf[MapType]) {
+        if (schema.dataType.isInstanceOf[MapType] ||
+          schema.dataType.isInstanceOf[DecimalType]) {
           throw new UnsupportedOperationException(
             s"${schema.dataType} is not supported in ColumnarUnionExec")
         }
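With DecimalType added, a union over a decimal column now fails this validation up front, which under the guard rule means falling back to Spark's row-based UnionExec rather than erroring mid-execution. An illustration of the check against a made-up schema:

```scala
import org.apache.spark.sql.types._

// Hypothetical input schema; the second field trips the new check.
val schema = StructType(Seq(
  StructField("id", LongType),
  StructField("price", DecimalType(10, 2))))
schema.foreach { field =>
  if (field.dataType.isInstanceOf[MapType] ||
      field.dataType.isInstanceOf[DecimalType])
    throw new UnsupportedOperationException(
      s"${field.dataType} is not supported in ColumnarUnionExec")
}
```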
@@ -54,6 +54,7 @@ import io.netty.buffer.ByteBuf
 import com.google.common.collect.Lists
 import com.intel.oap.expression._
 import com.intel.oap.vectorized.ExpressionEvaluator
+import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
 import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
@@ -100,17 +101,6 @@ case class ColumnarBroadcastHashJoinExec(
   buildCheck()
 
   def buildCheck(): Unit = {
-    // build check for BroadcastExchange
-    left match {
-      case exec: BroadcastExchangeExec =>
-        new ColumnarBroadcastExchangeExec(exec.mode, exec.child)
-      case _ =>
-    }
-    right match {
-      case exec: BroadcastExchangeExec =>
-        new ColumnarBroadcastExchangeExec(exec.mode, exec.child)
-      case _ =>
-    }
     // build check for condition
     val conditionExpr: Expression = condition.orNull
     if (conditionExpr != null) {
@@ -558,4 +548,4 @@ case class ColumnarBroadcastHashJoinExec(
   }
 
 }
-}
+}
@@ -23,7 +23,7 @@ import com.intel.oap.vectorized._
 import com.intel.oap.ColumnarPluginConfig
 import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{UserAddedJarUtils, Utils, ExecutorManager}
+import org.apache.spark.util.{ExecutorManager, UserAddedJarUtils, Utils}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 import scala.collection.JavaConverters._
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
 
 import scala.collection.mutable.ListBuffer
 import org.apache.arrow.vector.ipc.message.ArrowFieldNode
@@ -52,6 +52,7 @@ import com.intel.oap.vectorized.ExpressionEvaluator
 import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
 import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashJoin}
+import org.apache.spark.sql.types.DecimalType
 
 /**
  * Performs a hash join of two child relations by first shuffling the data using the join keys.
@@ -94,6 +95,17 @@ case class ColumnarShuffledHashJoinExec(
 
   def buildCheck(): Unit = {
     for (attr <- buildPlan.output) {
+      if (attr.dataType.isInstanceOf[DecimalType]) {
+        throw new UnsupportedOperationException(
+          s"${attr.dataType} is not supported in ColumnarShuffledHashJoin.")
+      }
       CodeGeneration.getResultType(attr.dataType)
     }
     for (attr <- streamedPlan.output) {
+      if (attr.dataType.isInstanceOf[DecimalType]) {
+        throw new UnsupportedOperationException(
+          s"${attr.dataType} is not supported in ColumnarShuffledHashJoin.")
+      }
       CodeGeneration.getResultType(attr.dataType)
     }
   }
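Both the build side and the streamed side are now scanned for decimal columns before any native code generation happens. The two loops share a shape that could be factored as below (sketch only; the helper and parameter names are invented):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types.DecimalType

// Sketch: reject an operator early if any input attribute is decimal.
def rejectDecimals(attrs: Seq[Attribute], opName: String): Unit =
  attrs.find(_.dataType.isInstanceOf[DecimalType]).foreach { attr =>
    throw new UnsupportedOperationException(
      s"${attr.dataType} is not supported in $opName.")
  }
```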
@@ -86,7 +86,7 @@ case class ColumnarSortExec(
   val numOutputRows = longMetric("numOutputRows")
   val numOutputBatches = longMetric("numOutputBatches")
 
-  ColumnarSorter.buildCheck(sortOrder)
+  ColumnarSorter.buildCheck(output)
 
   /***************** WSCG related function ******************/
   override def inputRDDs(): Seq[RDD[ColumnarBatch]] = child match {
@@ -143,6 +143,12 @@ object ColumnarProjection extends Logging {
     for (expr <- exprs) {
       val func = ColumnarExpressionConverter
         .replaceWithColumnarExpression(expr, originalInputAttributes)
+      val unsupportedTypes = List(NullType, TimestampType, BinaryType)
+      val datatype = func.dataType
+      if (unsupportedTypes.indexOf(datatype) != -1 || datatype.isInstanceOf[DecimalType]) {
+        throw new UnsupportedOperationException(
+          s"${datatype} is not supported in ColumnarProjection.")
+      }
       CodeGeneration.getResultType(func.dataType)
     }
   }
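A detail worth noting in this check (and in ColumnarSorter below): list membership via indexOf catches singleton types such as NullType, but DecimalType is parameterized, so no DecimalType(precision, scale) instance would ever equal an entry in the list; the separate isInstanceOf test is what actually catches decimals. A quick demonstration:

```scala
import org.apache.spark.sql.types._

val unsupportedTypes = List(NullType, TimestampType, BinaryType)
val dt: DataType = DecimalType(10, 2)
assert(unsupportedTypes.indexOf(dt) == -1) // membership never matches
assert(dt.isInstanceOf[DecimalType])       // the instance check does
```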
@@ -356,13 +356,13 @@ object ColumnarSorter extends Logging {
     (TreeBuilder.makeExpression(sort_node, retType), new Schema(outputFieldList.asJava))
   }
 
-  def buildCheck(sortOrder: Seq[SortOrder]): Unit = synchronized {
+  def buildCheck(outputAttributes: Seq[Attribute]): Unit = synchronized {
     val unsupportedTypes = List(NullType, TimestampType, BinaryType)
-    for (sort <- sortOrder) {
-      val keyType = ConverterUtils.getAttrFromExpr(sort.child).dataType
-      if (unsupportedTypes.indexOf(keyType) != -1 || keyType.isInstanceOf[DecimalType]) {
+    for (attr <- outputAttributes) {
+      val datatype = attr.dataType
+      if (unsupportedTypes.indexOf(datatype) != -1 || datatype.isInstanceOf[DecimalType]) {
         throw new UnsupportedOperationException(
-          s"${keyType} is not supported in ColumnarSorter.")
+          s"${datatype} is not supported in ColumnarSorter.")
       }
     }
   }
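The signature change is the substantive part of this hunk, paired with the buildCheck(output) call site in ColumnarSortExec above: checking sortOrder validated only the sort keys, so a decimal payload column (carried through the sort but not itself a key) slipped past and failed later in native code; checking every output attribute closes that gap. For example (attribute names invented):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types._

// Sort key `ts` is a long; payload `amount` is a decimal. The old
// buildCheck(sortOrder) saw only `ts`; buildCheck(output) also sees
// `amount` and rejects the plan before execution.
val output = Seq(
  AttributeReference("ts", LongType)(),
  AttributeReference("amount", DecimalType(12, 2))())
```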
@@ -214,6 +214,20 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
     }
   }
 
+  def buildCheck(): Unit = {
+    for (expr <- buildKeyExprs) {
+      ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
+    }
+    val unsupportedTypes = List(NullType, TimestampType, BinaryType, ByteType)
+    output.toList.foreach(attr => {
+      if (unsupportedTypes.indexOf(attr.dataType) != -1 ||
+        attr.dataType.isInstanceOf[DecimalType])
+        throw new UnsupportedOperationException(
+          s"${attr.dataType} is not supported in ColumnarBroadcastExchangeExec.")
+      CodeGeneration.getResultType(attr.dataType)
+    })
+  }
+
   override def doPrepare(): Unit = {
     // Materialize the future.
     relationFuture
@@ -286,22 +300,6 @@ class ColumnarBroadcastExchangeAdaptor(mode: BroadcastMode, child: SparkPlan)
 
   override def canEqual(other: Any): Boolean = other.isInstanceOf[ColumnarShuffleExchangeAdaptor]
 
-  def buildCheck(): Unit = {
-    for (expr <- buildKeyExprs) {
-      ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
-    }
-    val unsupportedTypes = List(NullType, TimestampType, BinaryType, ByteType)
-    output.toList.foreach(attr => {
-      if (unsupportedTypes.indexOf(attr.dataType) != -1 ||
-        attr.dataType.isInstanceOf[DecimalType])
-        throw new UnsupportedOperationException(
-          s"${attr.dataType} is not supported in ColumnarBroadcastExchangeExec.")
-      CodeGeneration.getResultType(attr.dataType)
-    })
-  }
-
-  override def canEqual(other: Any): Boolean = other.isInstanceOf[ColumnarBroadcastExchangeExec]
-
   override def equals(other: Any): Boolean = other match {
     case that: ColumnarShuffleExchangeAdaptor =>
       (that canEqual this) && super.equals(that)
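Net effect of the last two hunks: buildCheck moves from ColumnarBroadcastExchangeAdaptor up to ColumnarBroadcastExchangeExec itself, so the validation travels with the node the guard rule actually constructs, and a duplicate canEqual override is dropped from the adaptor. Assuming the exchange, like the join above, invokes buildCheck() during construction, merely instantiating it performs the validation; a minimal sketch of that pattern with invented names:

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types.DecimalType

// Sketch: a buildCheck() called from the class body turns construction
// itself into the validation probe used by ColumnarGuardRule.
case class ColumnarNodeSketch(output: Seq[Attribute]) {
  buildCheck()
  def buildCheck(): Unit =
    output.foreach { attr =>
      if (attr.dataType.isInstanceOf[DecimalType])
        throw new UnsupportedOperationException(
          s"${attr.dataType} is not supported here.")
    }
}
```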
