Skip to content

Commit

Permalink
[GLUTEN-7364] Simplify the RuleInjector
Browse files Browse the repository at this point in the history
  • Loading branch information
beliefer committed Sep 28, 2024
1 parent b9fbb47 commit 10ae59f
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import org.apache.gluten.extension.columnar._
import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides}
import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions}
import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector}
import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasInjector}
import org.apache.gluten.extension.injector.RuleInjector
import org.apache.gluten.parser.{GlutenCacheFilesSqlParser, GlutenClickhouseSqlParser}
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.PhysicalPlanSelector
Expand All @@ -37,30 +37,30 @@ class CHRuleApi extends RuleApi {
override def injectRules(injector: RuleInjector): Unit = {
injector.gluten.skipOn(PhysicalPlanSelector.skipCond)

injectSpark(injector.spark)
injectSpark(injector)
injectLegacy(injector.gluten.legacy)
injectRas(injector.gluten.ras)
}
}

private object CHRuleApi {
def injectSpark(injector: SparkInjector): Unit = {
// Regular Spark rules.
injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply)
injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark))
injector.injectParser(
def injectSpark(injector: RuleInjector): Unit = {
// Inject the regular Spark rules directly.
injector.extensions.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply)
injector.extensions.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark))
injector.extensions.injectParser(
(spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface))
injector.injectParser(
injector.extensions.injectParser(
(spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface))
injector.injectResolutionRule(
injector.extensions.injectResolutionRule(
spark => new RewriteToDateExpresstionRule(spark, spark.sessionState.conf))
injector.injectResolutionRule(
injector.extensions.injectResolutionRule(
spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf))
injector.injectOptimizerRule(
injector.extensions.injectOptimizerRule(
spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf))
injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
injector.injectOptimizerRule(_ => CountDistinctWithoutExpand)
injector.injectOptimizerRule(_ => EqualToRewrite)
injector.extensions.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
injector.extensions.injectOptimizerRule(_ => CountDistinctWithoutExpand)
injector.extensions.injectOptimizerRule(_ => EqualToRewrite)
}

def injectLegacy(injector: LegacyInjector): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTable
import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions}
import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector}
import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasInjector}
import org.apache.gluten.extension.injector.RuleInjector
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.PhysicalPlanSelector

Expand All @@ -38,18 +38,18 @@ class VeloxRuleApi extends RuleApi {
override def injectRules(injector: RuleInjector): Unit = {
injector.gluten.skipOn(PhysicalPlanSelector.skipCond)

injectSpark(injector.spark)
injectSpark(injector)
injectLegacy(injector.gluten.legacy)
injectRas(injector.gluten.ras)
}
}

private object VeloxRuleApi {
def injectSpark(injector: SparkInjector): Unit = {
// Regular Spark rules.
injector.injectOptimizerRule(CollectRewriteRule.apply)
injector.injectOptimizerRule(HLLRewriteRule.apply)
injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
def injectSpark(injector: RuleInjector): Unit = {
// Inject the regular Spark rules directly.
injector.extensions.injectOptimizerRule(CollectRewriteRule.apply)
injector.extensions.injectOptimizerRule(HLLRewriteRule.apply)
injector.extensions.injectPostHocResolutionRule(ArrowConvertorRule.apply)
}

def injectLegacy(injector: LegacyInjector): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import org.apache.spark.sql.SparkSessionExtensions

private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) {
override def apply(exts: SparkSessionExtensions): Unit = {
val injector = new RuleInjector()
val injector = new RuleInjector(exts)
Backend.get().injectRules(injector)
injector.inject(exts)
injector.inject()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ class GlutenInjector private[injector] {
val ras: RasInjector = new RasInjector()

private[injector] def inject(extensions: SparkSessionExtensions): Unit = {
val ruleBuilder = (session: SparkSession) => new GlutenColumnarRule(session, applier)
extensions.injectColumnar(session => ruleBuilder(session))
extensions.injectColumnar(session => new GlutenColumnarRule(session, applier))
}

private def applier(session: SparkSession): ColumnarRuleApplier = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ package org.apache.gluten.extension.injector
import org.apache.spark.sql.SparkSessionExtensions

/** Injector used to inject query planner rules into Spark and Gluten. */
class RuleInjector {
val spark: SparkInjector = new SparkInjector()
class RuleInjector(val extensions: SparkSessionExtensions) {
val gluten: GlutenInjector = new GlutenInjector()

private[extension] def inject(extensions: SparkSessionExtensions): Unit = {
spark.inject(extensions)
private[extension] def inject(): Unit = {
// The regular Spark rules already injected with the `injectRules` of `RuleApi` directly.
// Only inject the Spark columnar rule here.
gluten.inject(extensions)
}
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ package org.apache.gluten.backendsapi
import org.apache.gluten.extension.injector.RuleInjector

trait RuleApi {
// Injects all Gluten / Spark query planner rules used by the backend.
// Injects all Spark query planner rules used by the Gluten backend.
def injectRules(injector: RuleInjector): Unit
}

0 comments on commit 10ae59f

Please sign in to comment.