Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
Following NSE-153, optimize fallback conditions for columnar window (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Mar 24, 2021
1 parent 6fbf0fc commit 53993e3
Showing 1 changed file with 56 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.arrow.gandiva.expression.TreeBuilder
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, MakeDecimal, NamedExpression, Rank, SortOrder, UnscaledValue, WindowExpression, WindowFunction}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, MakeDecimal, NamedExpression, Rank, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, Sum}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.execution.SparkPlan
Expand Down Expand Up @@ -71,45 +71,65 @@ class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
val sparkConf = sparkContext.getConf
val numaBindingInfo = ColumnarPluginConfig.getConf.numaBindingInfo

def checkAggFunctionSpec(windowSpec: WindowSpecDefinition): Unit = {
if (windowSpec.orderSpec.nonEmpty) {
throw new UnsupportedOperationException("unsupported operation for " +
"aggregation window function: " + windowSpec)
}
}

def checkRankSpec(windowSpec: WindowSpecDefinition): Unit = {
// leave it empty for now
}

val windowFunctions: Seq[(String, Expression)] = windowExpression
.map(e => e.asInstanceOf[Alias])
.map(a => a.child.asInstanceOf[WindowExpression])
.map(w => w.windowFunction)
.map(w => (w, w.windowFunction))
.map {
case a: AggregateExpression => a.aggregateFunction
case b: WindowFunction => b
case f =>
throw new UnsupportedOperationException("unsupported window function type: " +
f)
case (expr, func) =>
(expr, func match {
case a: AggregateExpression => a.aggregateFunction
case b: WindowFunction => b
case f =>
throw new UnsupportedOperationException("unsupported window function type: " +
f)
})
}
.map { f =>
val name = f match {
case _: Sum => "sum"
case _: Average => "avg"
case _: Rank =>
val desc: Option[Boolean] = orderSpec.foldLeft[Option[Boolean]](None) {
(desc, s) =>
val currentDesc = s.direction match {
case Ascending => false
case Descending => true
case _ => throw new IllegalStateException
}
if (desc.isEmpty) {
Some(currentDesc)
} else if (currentDesc == desc.get) {
Some(currentDesc)
} else {
throw new UnsupportedOperationException("Rank: clashed rank order found")
}
}
desc match {
case Some(true) => "rank_desc"
case Some(false) => "rank_asc"
case None => "rank_asc"
}
case f => throw new UnsupportedOperationException("unsupported window function: " + f)
}
(name, f)
.map {
case (expr, func) =>
val name = func match {
case _: Sum =>
checkAggFunctionSpec(expr.windowSpec)
"sum"
case _: Average =>
checkAggFunctionSpec(expr.windowSpec)
"avg"
case _: Rank =>
checkRankSpec(expr.windowSpec)
val desc: Option[Boolean] = orderSpec.foldLeft[Option[Boolean]](None) {
(desc, s) =>
val currentDesc = s.direction match {
case Ascending => false
case Descending => true
case _ => throw new IllegalStateException
}
if (desc.isEmpty) {
Some(currentDesc)
} else if (currentDesc == desc.get) {
Some(currentDesc)
} else {
throw new UnsupportedOperationException("Rank: clashed rank order found")
}
}
desc match {
case Some(true) => "rank_desc"
case Some(false) => "rank_asc"
case None => "rank_asc"
}
case f => throw new UnsupportedOperationException("unsupported window function: " + f)
}
(name, func)
}

if (windowFunctions.isEmpty) {
Expand Down Expand Up @@ -349,11 +369,6 @@ object ColumnarWindowExec {
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan): SparkPlan = {
//TODO(): this is a quick fix on non-avg window issue
if (!windowExpression.toString.contains("avg")) {
new ColumnarWindowExec(windowExpression, partitionSpec, orderSpec, child)
} else {
createWithProjection(windowExpression, partitionSpec, orderSpec, child)
}
createWithProjection(windowExpression, partitionSpec, orderSpec, child)
}
}

0 comments on commit 53993e3

Please sign in to comment.