Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-942] Force to use hash aggregate for string type input (#941)
Browse files Browse the repository at this point in the history
* Initial commit

* Revise unit tests

* Change NativeDataFrameAggregateSuite

* Replace SortAggregate at columnar override

* Remove SortExec if sort agg is replaced by columnar hash agg

* Fix issues reported by UT
  • Loading branch information
PHILO-HE authored Jun 8, 2022
1 parent e179f98 commit a7ff199
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ class GazellePluginConfig(conf: SQLConf) extends Logging {
val enableColumnarHashAgg: Boolean =
conf.getConfString("spark.oap.sql.columnar.hashagg", "true").toBoolean && enableCpu

val ENABLE_HASH_AGG_FOR_STRING_TYPE_KEY = "spark.oap.sql.columnar.hashagg.support.string"
// To control whether hash agg is used for string type input, instead of sort agg.
val enableHashAggForStringType: Boolean =
conf.getConfString(
ENABLE_HASH_AGG_FOR_STRING_TYPE_KEY, "true").toBoolean && enableCpu

// enable or disable columnar project and filter
val enableColumnarProjFilter: Boolean =
conf.getConfString("spark.oap.sql.columnar.projfilter", "true").toBoolean && enableCpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,14 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.ShufflePartitionSpec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{ShuffleStageInfo, _}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, ColumnarArrowEvalPythonExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf

import org.apache.spark.util.ShufflePartitionUtils

import scala.collection.mutable
Expand Down Expand Up @@ -122,6 +121,33 @@ case class ColumnarPreOverrides(session: SparkSession) extends Rule[SparkPlan] {
plan.initialInputBufferOffset,
plan.resultExpressions,
child)
case plan: SortAggregateExec if (columnarConf.enableHashAggForStringType) =>
try {
val child = replaceWithColumnarPlan(plan.child)
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ColumnarHashAggregateExec(
plan.requiredChildDistributionExpressions,
plan.groupingExpressions,
plan.aggregateExpressions,
plan.aggregateAttributes,
plan.initialInputBufferOffset,
plan.resultExpressions,
// If SortAggregateExec is forcibly replaced by ColumnarHashAggregateExec,
// Sort operator is useless. So just use its child to initialize.
child match {
case sort: ColumnarSortExec =>
sort.child
case sort: SortExec =>
sort.child
case other =>
other
})
} catch {
case _: Throwable =>
logInfo("Fallback to SortAggregateExec instead of forcibly" +
" using ColumnarHashAggregateExec!")
plan
}
case plan: UnionExec =>
val children = plan.children.map(replaceWithColumnarPlan)
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange._
Expand Down Expand Up @@ -108,6 +108,15 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
plan.initialInputBufferOffset,
plan.resultExpressions,
plan.child)
case plan: SortAggregateExec if (columnarConf.enableHashAggForStringType) =>
ColumnarHashAggregateExec(
plan.requiredChildDistributionExpressions,
plan.groupingExpressions,
plan.aggregateExpressions,
plan.aggregateAttributes,
plan.initialInputBufferOffset,
plan.resultExpressions,
plan.child)
case plan: UnionExec =>
if (!enableColumnarUnion) return false
new ColumnarUnionExec(plan.children)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql

import com.intel.oap.GazellePluginConfig
import com.intel.oap.execution.ColumnarHashAggregateExec

import scala.util.Random
Expand Down Expand Up @@ -1042,6 +1043,7 @@ class DataFrameAggregateSuite extends QueryTest
Seq(true, false).foreach { value =>
test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
withSQLConf(
GazellePluginConfig.getSessionConf.ENABLE_HASH_AGG_FOR_STRING_TYPE_KEY -> "false",
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
withTempView("t1", "t2") {
sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
Expand Down Expand Up @@ -1078,6 +1080,27 @@ class DataFrameAggregateSuite extends QueryTest
}
}

Seq(true, false).foreach { value =>
test(s"Force to use hash agg for string type with (whole-stage-codegen = $value)") {
withSQLConf(
GazellePluginConfig.getSessionConf.ENABLE_HASH_AGG_FOR_STRING_TYPE_KEY -> "true",
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
withTempView("t1") {
sql("create temporary view t1 as select * from values('A'), ('B'), ('C') as t1(col1)")
// test hashAggregateExec
var df = sql("select max(col1) from t1")
assert(find(df.queryExecution.executedPlan)(
_.isInstanceOf[ColumnarHashAggregateExec]).isDefined)
checkAnswer(df, Row("C") :: Nil)
df = sql("select first(col1) from t1")
assert(find(df.queryExecution.executedPlan)(
_.isInstanceOf[ColumnarHashAggregateExec]).isDefined)
checkAnswer(df, Row("A") :: Nil)
}
}
}
}

test("SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate") {
withTempView("view") {
val nan1 = java.lang.Float.intBitsToFloat(0x7f800001)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

package org.apache.spark.sql.nativesql

import com.intel.oap.GazellePluginConfig
import com.intel.oap.execution.ColumnarHashAggregateExec

import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}

import scala.util.Random
Expand Down Expand Up @@ -1043,6 +1043,7 @@ class NativeDataFrameAggregateSuite extends QueryTest
Seq(true, false).foreach { value =>
test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
withSQLConf(
GazellePluginConfig.getSessionConf.ENABLE_HASH_AGG_FOR_STRING_TYPE_KEY -> "false",
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
withTempView("t1", "t2") {
sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
Expand Down Expand Up @@ -1079,6 +1080,27 @@ class NativeDataFrameAggregateSuite extends QueryTest
}
}

Seq(true, false).foreach { value =>
test(s"Force to use hash agg for string type with (whole-stage-codegen = $value)") {
withSQLConf(
GazellePluginConfig.getSessionConf.ENABLE_HASH_AGG_FOR_STRING_TYPE_KEY -> "true",
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
withTempView("t1") {
sql("create temporary view t1 as select * from values('A'), ('B'), ('C') as t1(col1)")
// test hashAggregateExec
var df = sql("select max(col1) from t1")
assert(find(df.queryExecution.executedPlan)(
_.isInstanceOf[ColumnarHashAggregateExec]).isDefined)
checkAnswer(df, Row("C") :: Nil)
df = sql("select first(col1) from t1")
assert(find(df.queryExecution.executedPlan)(
_.isInstanceOf[ColumnarHashAggregateExec]).isDefined)
checkAnswer(df, Row("A") :: Nil)
}
}
}
}

test("SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate") {
withTempView("view") {
val nan1 = java.lang.Float.intBitsToFloat(0x7f800001)
Expand Down

0 comments on commit a7ff199

Please sign in to comment.