Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-7364][CORE] Simplify the RuleInjector #7365

Merged
merged 1 commit into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class CHRuleApi extends RuleApi {

private object CHRuleApi {
def injectSpark(injector: SparkInjector): Unit = {
// Regular Spark rules.
// Inject the regular Spark rules directly.
injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply)
injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark))
injector.injectParser(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class VeloxRuleApi extends RuleApi {

private object VeloxRuleApi {
def injectSpark(injector: SparkInjector): Unit = {
// Regular Spark rules.
// Inject the regular Spark rules directly.
injector.injectOptimizerRule(CollectRewriteRule.apply)
injector.injectOptimizerRule(HLLRewriteRule.apply)
injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import org.apache.spark.sql.SparkSessionExtensions

private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) {
override def apply(exts: SparkSessionExtensions): Unit = {
val injector = new RuleInjector()
val injector = new RuleInjector(exts)
Backend.get().injectRules(injector)
injector.inject(exts)
injector.inject()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ class GlutenInjector private[injector] {
val ras: RasInjector = new RasInjector()

private[injector] def inject(extensions: SparkSessionExtensions): Unit = {
val ruleBuilder = (session: SparkSession) => new GlutenColumnarRule(session, applier)
extensions.injectColumnar(session => ruleBuilder(session))
extensions.injectColumnar(session => new GlutenColumnarRule(session, applier))
}

private def applier(session: SparkSession): ColumnarRuleApplier = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ package org.apache.gluten.extension.injector
import org.apache.spark.sql.SparkSessionExtensions

/** Injector used to inject query planner rules into Spark and Gluten. */
class RuleInjector {
val spark: SparkInjector = new SparkInjector()
class RuleInjector(extensions: SparkSessionExtensions) {
val spark: SparkInjector = new SparkInjector(extensions)
val gluten: GlutenInjector = new GlutenInjector()

private[extension] def inject(extensions: SparkSessionExtensions): Unit = {
spark.inject(extensions)
private[extension] def inject(): Unit = {
// The regular Spark rules already injected with the `injectRules` of `RuleApi` directly.
// Only inject the Spark columnar rule here.
gluten.inject(extensions)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,59 +25,35 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

import scala.collection.mutable

/** Injector used to inject query planner rules into Spark. */
class SparkInjector private[injector] {
private type RuleBuilder = SparkSession => Rule[LogicalPlan]
private type StrategyBuilder = SparkSession => Strategy
private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]

private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
private val parserBuilders = mutable.Buffer.empty[ParserBuilder]
private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
private val optimizerRules = mutable.Buffer.empty[RuleBuilder]
private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
private val injectedFunctions = mutable.Buffer.empty[FunctionDescription]
private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = {
queryStagePrepRuleBuilders += builder
}
class SparkInjector private[injector] (extensions: SparkSessionExtensions) {

def injectParser(builder: ParserBuilder): Unit = {
parserBuilders += builder
def injectQueryStagePrepRule(builder: SparkSession => Rule[SparkPlan]): Unit = {
extensions.injectQueryStagePrepRule(builder)
}

def injectResolutionRule(builder: RuleBuilder): Unit = {
resolutionRuleBuilders += builder
def injectResolutionRule(builder: SparkSession => Rule[LogicalPlan]): Unit = {
extensions.injectResolutionRule(builder)
}

def injectOptimizerRule(builder: RuleBuilder): Unit = {
optimizerRules += builder
def injectPostHocResolutionRule(builder: SparkSession => Rule[LogicalPlan]): Unit = {
extensions.injectPostHocResolutionRule(builder)
}

def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
plannerStrategyBuilders += builder
def injectOptimizerRule(builder: SparkSession => Rule[LogicalPlan]): Unit = {
extensions.injectOptimizerRule(builder)
}

def injectFunction(functionDescription: FunctionDescription): Unit = {
injectedFunctions += functionDescription
def injectPlannerStrategy(builder: SparkSession => Strategy): Unit = {
extensions.injectPlannerStrategy(builder)
}

def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
postHocResolutionRuleBuilders += builder
def injectParser(builder: (SparkSession, ParserInterface) => ParserInterface): Unit = {
extensions.injectParser(builder)
}

private[injector] def inject(extensions: SparkSessionExtensions): Unit = {
queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule)
parserBuilders.foreach(extensions.injectParser)
resolutionRuleBuilders.foreach(extensions.injectResolutionRule)
optimizerRules.foreach(extensions.injectOptimizerRule)
plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy)
injectedFunctions.foreach(extensions.injectFunction)
postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule)
def injectFunction(
functionDescription: (FunctionIdentifier, ExpressionInfo, FunctionBuilder)): Unit = {
extensions.injectFunction(functionDescription)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ package org.apache.gluten.backendsapi
import org.apache.gluten.extension.injector.RuleInjector

trait RuleApi {
// Injects all Gluten / Spark query planner rules used by the backend.
// Injects all Spark query planner rules used by the Gluten backend.
def injectRules(injector: RuleInjector): Unit
}
Loading