diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 8f7ac330cba5..1d204f73e1a2 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -22,8 +22,8 @@ import org.apache.gluten.extension.columnar._ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} -import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector} import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasInjector} +import org.apache.gluten.extension.injector.RuleInjector import org.apache.gluten.parser.{GlutenCacheFilesSqlParser, GlutenClickhouseSqlParser} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.utils.PhysicalPlanSelector @@ -37,30 +37,30 @@ class CHRuleApi extends RuleApi { override def injectRules(injector: RuleInjector): Unit = { injector.gluten.skipOn(PhysicalPlanSelector.skipCond) - injectSpark(injector.spark) + injectSpark(injector) injectLegacy(injector.gluten.legacy) injectRas(injector.gluten.ras) } } private object CHRuleApi { - def injectSpark(injector: SparkInjector): Unit = { - // Regular Spark rules. - injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply) - injector.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark)) - injector.injectParser( + def injectSpark(injector: RuleInjector): Unit = { + // Inject the regular Spark rules directly. + injector.extensions.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply) + injector.extensions.injectQueryStagePrepRule(spark => CHAQEPropagateEmptyRelation(spark)) + injector.extensions.injectParser( (spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface)) - injector.injectParser( + injector.extensions.injectParser( (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) - injector.injectResolutionRule( + injector.extensions.injectResolutionRule( spark => new RewriteToDateExpresstionRule(spark, spark.sessionState.conf)) - injector.injectResolutionRule( + injector.extensions.injectResolutionRule( spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf)) - injector.injectOptimizerRule( + injector.extensions.injectOptimizerRule( spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf)) - injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark)) - injector.injectOptimizerRule(_ => CountDistinctWithoutExpand) - injector.injectOptimizerRule(_ => EqualToRewrite) + injector.extensions.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark)) + injector.extensions.injectOptimizerRule(_ => CountDistinctWithoutExpand) + injector.extensions.injectOptimizerRule(_ => EqualToRewrite) } def injectLegacy(injector: LegacyInjector): Unit = { diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index ffbb393bef17..e5181ec82a48 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -25,8 +25,8 @@ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTable import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} -import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector} import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasInjector} +import org.apache.gluten.extension.injector.RuleInjector import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.utils.PhysicalPlanSelector @@ -38,18 +38,18 @@ class VeloxRuleApi extends RuleApi { override def injectRules(injector: RuleInjector): Unit = { injector.gluten.skipOn(PhysicalPlanSelector.skipCond) - injectSpark(injector.spark) + injectSpark(injector) injectLegacy(injector.gluten.legacy) injectRas(injector.gluten.ras) } } private object VeloxRuleApi { - def injectSpark(injector: SparkInjector): Unit = { - // Regular Spark rules. - injector.injectOptimizerRule(CollectRewriteRule.apply) - injector.injectOptimizerRule(HLLRewriteRule.apply) - injector.injectPostHocResolutionRule(ArrowConvertorRule.apply) + def injectSpark(injector: RuleInjector): Unit = { + // Inject the regular Spark rules directly. + injector.extensions.injectOptimizerRule(CollectRewriteRule.apply) + injector.extensions.injectOptimizerRule(HLLRewriteRule.apply) + injector.extensions.injectPostHocResolutionRule(ArrowConvertorRule.apply) } def injectLegacy(injector: LegacyInjector): Unit = { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala index 710d96c54e25..697b41da9edc 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.SparkSessionExtensions private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) { override def apply(exts: SparkSessionExtensions): Unit = { - val injector = new RuleInjector() + val injector = new RuleInjector(exts) Backend.get().injectRules(injector) - injector.inject(exts) + injector.inject() } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala index ca76e61b7bb0..db3310151fa8 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala @@ -35,8 +35,7 @@ class GlutenInjector private[injector] { val ras: RasInjector = new RasInjector() private[injector] def inject(extensions: SparkSessionExtensions): Unit = { - val ruleBuilder = (session: SparkSession) => new GlutenColumnarRule(session, applier) - extensions.injectColumnar(session => ruleBuilder(session)) + extensions.injectColumnar(session => new GlutenColumnarRule(session, applier)) } private def applier(session: SparkSession): ColumnarRuleApplier = { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala index bccbd38b26d5..f1f5d25e2838 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala @@ -19,12 +19,12 @@ package org.apache.gluten.extension.injector import org.apache.spark.sql.SparkSessionExtensions /** Injector used to inject query planner rules into Spark and Gluten. */ -class RuleInjector { - val spark: SparkInjector = new SparkInjector() +class RuleInjector(val extensions: SparkSessionExtensions) { val gluten: GlutenInjector = new GlutenInjector() - private[extension] def inject(extensions: SparkSessionExtensions): Unit = { - spark.inject(extensions) + private[extension] def inject(): Unit = { + // The regular Spark rules already injected with the `injectRules` of `RuleApi` directly. + // Only inject the Spark columnar rule here. gluten.inject(extensions) } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala deleted file mode 100644 index 6935e61bdd5b..000000000000 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.extension.injector - -import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy} -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo -import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan - -import scala.collection.mutable - -/** Injector used to inject query planner rules into Spark. */ -class SparkInjector private[injector] { - private type RuleBuilder = SparkSession => Rule[LogicalPlan] - private type StrategyBuilder = SparkSession => Strategy - private type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface - private type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) - private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan] - - private val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder] - private val parserBuilders = mutable.Buffer.empty[ParserBuilder] - private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] - private val optimizerRules = mutable.Buffer.empty[RuleBuilder] - private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] - private val injectedFunctions = mutable.Buffer.empty[FunctionDescription] - private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] - - def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = { - queryStagePrepRuleBuilders += builder - } - - def injectParser(builder: ParserBuilder): Unit = { - parserBuilders += builder - } - - def injectResolutionRule(builder: RuleBuilder): Unit = { - resolutionRuleBuilders += builder - } - - def injectOptimizerRule(builder: RuleBuilder): Unit = { - optimizerRules += builder - } - - def injectPlannerStrategy(builder: StrategyBuilder): Unit = { - plannerStrategyBuilders += builder - } - - def injectFunction(functionDescription: FunctionDescription): Unit = { - injectedFunctions += functionDescription - } - - def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { - postHocResolutionRuleBuilders += builder - } - - private[injector] def inject(extensions: SparkSessionExtensions): Unit = { - queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule) - parserBuilders.foreach(extensions.injectParser) - resolutionRuleBuilders.foreach(extensions.injectResolutionRule) - optimizerRules.foreach(extensions.injectOptimizerRule) - plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy) - injectedFunctions.foreach(extensions.injectFunction) - postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule) - } -} diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala index f8669a6fe049..7c4c8577f421 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala @@ -19,6 +19,6 @@ package org.apache.gluten.backendsapi import org.apache.gluten.extension.injector.RuleInjector trait RuleApi { - // Injects all Gluten / Spark query planner rules used by the backend. + // Injects all Spark query planner rules used by the Gluten backend. def injectRules(injector: RuleInjector): Unit }