diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 8f7ac330cba5..9b24fd6275e4 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -45,22 +45,22 @@ class CHRuleApi extends RuleApi { private object CHRuleApi { def injectSpark(injector: SparkInjector): Unit = { - // Regular Spark rules. - injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply) - injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark)) - injector.injectParser( + // Inject the regular Spark rules directly. + injector.extensions.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply) + injector.extensions.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark)) + injector.extensions.injectParser( (spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface)) - injector.injectParser( + injector.extensions.injectParser( (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) - injector.injectResolutionRule( + injector.extensions.injectResolutionRule( spark => new RewriteToDateExpresstionRule(spark, spark.sessionState.conf)) - injector.injectResolutionRule( + injector.extensions.injectResolutionRule( spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf)) - injector.injectOptimizerRule( + injector.extensions.injectOptimizerRule( spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf)) - injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark)) - injector.injectOptimizerRule(_ => CountDistinctWithoutExpand) - injector.injectOptimizerRule(_ => EqualToRewrite) + injector.extensions.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark)) + injector.extensions.injectOptimizerRule(_ => CountDistinctWithoutExpand) + injector.extensions.injectOptimizerRule(_ => EqualToRewrite) } def injectLegacy(injector: LegacyInjector): Unit = { diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index ffbb393bef17..5a99f07c4ea1 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -46,10 +46,10 @@ class VeloxRuleApi extends RuleApi { private object VeloxRuleApi { def injectSpark(injector: SparkInjector): Unit = { - // Regular Spark rules. - injector.injectOptimizerRule(CollectRewriteRule.apply) - injector.injectOptimizerRule(HLLRewriteRule.apply) - injector.injectPostHocResolutionRule(ArrowConvertorRule.apply) + // Inject the regular Spark rules directly. + injector.extensions.injectOptimizerRule(CollectRewriteRule.apply) + injector.extensions.injectOptimizerRule(HLLRewriteRule.apply) + injector.extensions.injectPostHocResolutionRule(ArrowConvertorRule.apply) } def injectLegacy(injector: LegacyInjector): Unit = { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala index 710d96c54e25..a1091e73f8b2 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSessionExtensions private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) { override def apply(exts: SparkSessionExtensions): Unit = { - val injector = new RuleInjector() + val injector = new RuleInjector(exts) Backend.get().injectRules(injector) injector.inject(exts) } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala index ca76e61b7bb0..db3310151fa8 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala @@ -35,8 +35,7 @@ class GlutenInjector private[injector] { val ras: RasInjector = new RasInjector() private[injector] def inject(extensions: SparkSessionExtensions): Unit = { - val ruleBuilder = (session: SparkSession) => new GlutenColumnarRule(session, applier) - extensions.injectColumnar(session => ruleBuilder(session)) + extensions.injectColumnar(session => new GlutenColumnarRule(session, applier)) } private def applier(session: SparkSession): ColumnarRuleApplier = { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala index bccbd38b26d5..061c338f1a6b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala @@ -19,12 +19,13 @@ package org.apache.gluten.extension.injector import org.apache.spark.sql.SparkSessionExtensions /** Injector used to inject query planner rules into Spark and Gluten. */ -class RuleInjector { - val spark: SparkInjector = new SparkInjector() +class RuleInjector(extensions: SparkSessionExtensions) { + val spark: SparkInjector = new SparkInjector(extensions) val gluten: GlutenInjector = new GlutenInjector() private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - spark.inject(extensions) + // The regular Spark rules already injected with the `injectRules` of `RuleApi` directly. + // Only inject the Spark columnar rule here. gluten.inject(extensions) } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala index 6935e61bdd5b..2a08eb38671e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala @@ -16,68 +16,8 @@ */ package org.apache.gluten.extension.injector -import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy} -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo -import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan - -import scala.collection.mutable +import org.apache.spark.sql.SparkSessionExtensions /** Injector used to inject query planner rules into Spark. */ -class SparkInjector private[injector] { - private type RuleBuilder = SparkSession => Rule[LogicalPlan] - private type StrategyBuilder = SparkSession => Strategy - private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface - private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) - private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan] - - private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder] - private val parserBuilders = mutable.Buffer.empty[ParserBuilder] - private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] - private val optimizerRules = mutable.Buffer.empty[RuleBuilder] - private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] - private val injectedFunctions = mutable.Buffer.empty[FunctionDescription] - private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] - - def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = { - queryStagePrepRuleBuilders += builder - } - - def injectParser(builder: ParserBuilder): Unit = { - parserBuilders += builder - } - - def injectResolutionRule(builder: RuleBuilder): Unit = { - resolutionRuleBuilders += builder - } - - def injectOptimizerRule(builder: RuleBuilder): Unit = { - optimizerRules += builder - } - - def injectPlannerStrategy(builder: StrategyBuilder): Unit = { - plannerStrategyBuilders += builder - } - - def injectFunction(functionDescription: FunctionDescription): Unit = { - injectedFunctions += functionDescription - } - - def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { - postHocResolutionRuleBuilders += builder - } - - private[injector] def inject(extensions: SparkSessionExtensions): Unit = { - queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule) - parserBuilders.foreach(extensions.injectParser) - resolutionRuleBuilders.foreach(extensions.injectResolutionRule) - optimizerRules.foreach(extensions.injectOptimizerRule) - plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy) - injectedFunctions.foreach(extensions.injectFunction) - postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule) - } -} +@deprecated("This class is deprecated and will be removed in future versions.", since = "1.3.0") +class SparkInjector private[injector] (val extensions: SparkSessionExtensions) {} diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala index f8669a6fe049..7c4c8577f421 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala @@ -19,6 +19,6 @@ package org.apache.gluten.backendsapi import org.apache.gluten.extension.injector.RuleInjector trait RuleApi { - // Injects all Gluten / Spark query planner rules used by the backend. + // Injects all Spark query planner rules used by the Gluten backend. def injectRules(injector: RuleInjector): Unit }