Skip to content

Commit

Permalink
[GLUTEN-7364] Simplify the RuleInjector
Browse files Browse the repository at this point in the history
  • Loading branch information
beliefer committed Sep 29, 2024
1 parent b9fbb47 commit 4ff7473
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class CHRuleApi extends RuleApi {

private object CHRuleApi {
def injectSpark(injector: SparkInjector): Unit = {
// Regular Spark rules.
// Inject the regular Spark rules directly.
injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply)
injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark))
injector.injectParser(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class VeloxRuleApi extends RuleApi {

private object VeloxRuleApi {
def injectSpark(injector: SparkInjector): Unit = {
// Regular Spark rules.
// Inject the regular Spark rules directly.
injector.injectOptimizerRule(CollectRewriteRule.apply)
injector.injectOptimizerRule(HLLRewriteRule.apply)
injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import org.apache.spark.sql.SparkSessionExtensions

private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) {
override def apply(exts: SparkSessionExtensions): Unit = {
val injector = new RuleInjector()
val injector = new RuleInjector(exts)
Backend.get().injectRules(injector)
injector.inject(exts)
injector.inject()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ class GlutenInjector private[injector] {
val ras: RasInjector = new RasInjector()

private[injector] def inject(extensions: SparkSessionExtensions): Unit = {
val ruleBuilder = (session: SparkSession) => new GlutenColumnarRule(session, applier)
extensions.injectColumnar(session => ruleBuilder(session))
extensions.injectColumnar(session => new GlutenColumnarRule(session, applier))
}

private def applier(session: SparkSession): ColumnarRuleApplier = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ package org.apache.gluten.extension.injector
import org.apache.spark.sql.SparkSessionExtensions

/** Injector used to inject query planner rules into Spark and Gluten. */
class RuleInjector {
val spark: SparkInjector = new SparkInjector()
class RuleInjector(extensions: SparkSessionExtensions) {
val spark: SparkInjector = new SparkInjector(extensions)
val gluten: GlutenInjector = new GlutenInjector()

private[extension] def inject(extensions: SparkSessionExtensions): Unit = {
spark.inject(extensions)
private[extension] def inject(): Unit = {
// The regular Spark rules already injected with the `injectRules` of `RuleApi` directly.
// Only inject the Spark columnar rule here.
gluten.inject(extensions)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,59 +25,35 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

import scala.collection.mutable

/** Injector used to inject query planner rules into Spark. */
class SparkInjector private[injector] {
private type RuleBuilder = SparkSession => Rule[LogicalPlan]
private type StrategyBuilder = SparkSession => Strategy
private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]

private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
private val parserBuilders = mutable.Buffer.empty[ParserBuilder]
private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
private val optimizerRules = mutable.Buffer.empty[RuleBuilder]
private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
private val injectedFunctions = mutable.Buffer.empty[FunctionDescription]
private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = {
queryStagePrepRuleBuilders += builder
}
class SparkInjector private[injector] (extensions: SparkSessionExtensions) {

def injectParser(builder: ParserBuilder): Unit = {
parserBuilders += builder
def injectQueryStagePrepRule(builder: SparkSession => Rule[SparkPlan]): Unit = {
extensions.injectQueryStagePrepRule(builder)
}

def injectResolutionRule(builder: RuleBuilder): Unit = {
resolutionRuleBuilders += builder
def injectResolutionRule(builder: SparkSession => Rule[LogicalPlan]): Unit = {
extensions.injectResolutionRule(builder)
}

def injectOptimizerRule(builder: RuleBuilder): Unit = {
optimizerRules += builder
def injectPostHocResolutionRule(builder: SparkSession => Rule[LogicalPlan]): Unit = {
extensions.injectPostHocResolutionRule(builder)
}

def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
plannerStrategyBuilders += builder
def injectOptimizerRule(builder: SparkSession => Rule[LogicalPlan]): Unit = {
extensions.injectOptimizerRule(builder)
}

def injectFunction(functionDescription: FunctionDescription): Unit = {
injectedFunctions += functionDescription
def injectPlannerStrategy(builder: SparkSession => Strategy): Unit = {
extensions.injectPlannerStrategy(builder)
}

def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
postHocResolutionRuleBuilders += builder
def injectParser(builder: (SparkSession, ParserInterface) => ParserInterface): Unit = {
extensions.injectParser(builder)
}

private[injector] def inject(extensions: SparkSessionExtensions): Unit = {
queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule)
parserBuilders.foreach(extensions.injectParser)
resolutionRuleBuilders.foreach(extensions.injectResolutionRule)
optimizerRules.foreach(extensions.injectOptimizerRule)
plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy)
injectedFunctions.foreach(extensions.injectFunction)
postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule)
def injectFunction(
functionDescription: (FunctionIdentifier, ExpressionInfo, FunctionBuilder)): Unit = {
extensions.injectFunction(functionDescription)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ package org.apache.gluten.backendsapi
import org.apache.gluten.extension.injector.RuleInjector

trait RuleApi {
// Injects all Gluten / Spark query planner rules used by the backend.
// Injects all Spark query planner rules used by the Gluten backend.
def injectRules(injector: RuleInjector): Unit
}

0 comments on commit 4ff7473

Please sign in to comment.