Skip to content

Commit

Permalink
[GLUTEN-7364] Simplify the RuleInjector
Browse files Browse the repository at this point in the history
  • Loading branch information
beliefer committed Sep 26, 2024
1 parent 0983ec2 commit 5897e55
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,22 @@ class CHRuleApi extends RuleApi {

private object CHRuleApi {
def injectSpark(injector: SparkInjector): Unit = {
// Regular Spark rules.
injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply)
injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark))
injector.injectParser(
// Inject the regular Spark rules directly.
injector.extensions.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply)
injector.extensions.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark))
injector.extensions.injectParser(
(spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface))
injector.injectParser(
injector.extensions.injectParser(
(spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface))
injector.injectResolutionRule(
injector.extensions.injectResolutionRule(
spark => new RewriteToDateExpresstionRule(spark, spark.sessionState.conf))
injector.injectResolutionRule(
injector.extensions.injectResolutionRule(
spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf))
injector.injectOptimizerRule(
injector.extensions.injectOptimizerRule(
spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf))
injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
injector.injectOptimizerRule(_ => CountDistinctWithoutExpand)
injector.injectOptimizerRule(_ => EqualToRewrite)
injector.extensions.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
injector.extensions.injectOptimizerRule(_ => CountDistinctWithoutExpand)
injector.extensions.injectOptimizerRule(_ => EqualToRewrite)
}

def injectLegacy(injector: LegacyInjector): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ class VeloxRuleApi extends RuleApi {

private object VeloxRuleApi {
def injectSpark(injector: SparkInjector): Unit = {
// Regular Spark rules.
injector.injectOptimizerRule(CollectRewriteRule.apply)
injector.injectOptimizerRule(HLLRewriteRule.apply)
injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
// Inject the regular Spark rules directly.
injector.extensions.injectOptimizerRule(CollectRewriteRule.apply)
injector.extensions.injectOptimizerRule(HLLRewriteRule.apply)
injector.extensions.injectPostHocResolutionRule(ArrowConvertorRule.apply)
}

def injectLegacy(injector: LegacyInjector): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSessionExtensions

private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) {
override def apply(exts: SparkSessionExtensions): Unit = {
val injector = new RuleInjector()
val injector = new RuleInjector(exts)
Backend.get().injectRules(injector)
injector.inject(exts)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ class GlutenInjector private[injector] {
val ras: RasInjector = new RasInjector()

private[injector] def inject(extensions: SparkSessionExtensions): Unit = {
val ruleBuilder = (session: SparkSession) => new GlutenColumnarRule(session, applier)
extensions.injectColumnar(session => ruleBuilder(session))
extensions.injectColumnar(session => new GlutenColumnarRule(session, applier))
}

private def applier(session: SparkSession): ColumnarRuleApplier = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ package org.apache.gluten.extension.injector
import org.apache.spark.sql.SparkSessionExtensions

/** Injector used to inject query planner rules into Spark and Gluten. */
class RuleInjector {
val spark: SparkInjector = new SparkInjector()
class RuleInjector(extensions: SparkSessionExtensions) {
val spark: SparkInjector = new SparkInjector(extensions)
val gluten: GlutenInjector = new GlutenInjector()

private[extension] def inject(extensions: SparkSessionExtensions): Unit = {
spark.inject(extensions)
// The regular Spark rules already injected with the `injectRules` of `RuleApi` directly.
// Only inject the Spark columnar rule here.
gluten.inject(extensions)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,68 +16,8 @@
*/
package org.apache.gluten.extension.injector

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

import scala.collection.mutable
import org.apache.spark.sql.SparkSessionExtensions

/** Injector used to inject query planner rules into Spark. */
class SparkInjector private[injector] {
private type RuleBuilder = SparkSession => Rule[LogicalPlan]
private type StrategyBuilder = SparkSession => Strategy
private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]

private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
private val parserBuilders = mutable.Buffer.empty[ParserBuilder]
private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
private val optimizerRules = mutable.Buffer.empty[RuleBuilder]
private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
private val injectedFunctions = mutable.Buffer.empty[FunctionDescription]
private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = {
queryStagePrepRuleBuilders += builder
}

def injectParser(builder: ParserBuilder): Unit = {
parserBuilders += builder
}

def injectResolutionRule(builder: RuleBuilder): Unit = {
resolutionRuleBuilders += builder
}

def injectOptimizerRule(builder: RuleBuilder): Unit = {
optimizerRules += builder
}

def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
plannerStrategyBuilders += builder
}

def injectFunction(functionDescription: FunctionDescription): Unit = {
injectedFunctions += functionDescription
}

def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
postHocResolutionRuleBuilders += builder
}

private[injector] def inject(extensions: SparkSessionExtensions): Unit = {
queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule)
parserBuilders.foreach(extensions.injectParser)
resolutionRuleBuilders.foreach(extensions.injectResolutionRule)
optimizerRules.foreach(extensions.injectOptimizerRule)
plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy)
injectedFunctions.foreach(extensions.injectFunction)
postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule)
}
}
@deprecated("This class is deprecated and will be removed in future versions.", since = "1.3.0")
class SparkInjector private[injector] (val extensions: SparkSessionExtensions) {}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ package org.apache.gluten.backendsapi
import org.apache.gluten.extension.injector.RuleInjector

trait RuleApi {
// Injects all Gluten / Spark query planner rules used by the backend.
// Injects all Spark query planner rules used by the Gluten backend.
def injectRules(injector: RuleInjector): Unit
}

0 comments on commit 5897e55

Please sign in to comment.