diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala
index 748e6deb3..f1025d4fd 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala
@@ -209,8 +209,15 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
         right)
     case plan: BroadcastQueryStageExec =>
       logDebug(
-        s"Columnar Processing for ${plan.getClass} is currently supported, actual plan is ${plan.plan.getClass}.")
-      plan
+        s"Columnar Processing for ${plan.getClass} is currently supported, actual plan is ${plan.plan}.")
+      plan.plan match {
+        case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeAdaptor) =>
+          val newBroadcast = BroadcastExchangeExec(
+            originalBroadcastPlan.mode,
+            DataToArrowColumnarExec(plan.plan, 1))
+          SparkShimLoader.getSparkShims.newBroadcastQueryStageExec(plan.id, newBroadcast)
+        case _ => plan
+      }
     case plan: BroadcastExchangeExec =>
       val child = replaceWithColumnarPlan(plan.child)
       logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
@@ -368,6 +375,17 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] {
   var isSupportAdaptive: Boolean = true
 
   def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match {
+    // Get ColumnarBroadcastExchangeExec back from the fallback made for DPP reuse.
+    case RowToColumnarExec(broadcastQueryStageExec: BroadcastQueryStageExec)
+        if (broadcastQueryStageExec.plan match {
+          case BroadcastExchangeExec(_, _: DataToArrowColumnarExec) => true
+          case _ => false
+        }) =>
+      logDebug(s"RowToColumnarExec was inserted due to a fallback of BHJ." +
+        s" See the BroadcastQueryStageExec override in ColumnarPreOverrides.")
+      val localBroadcastXchg = broadcastQueryStageExec.plan.asInstanceOf[BroadcastExchangeExec]
+      val dataToArrowColumnar = localBroadcastXchg.child.asInstanceOf[DataToArrowColumnarExec]
+      ColumnarBroadcastExchangeExec(localBroadcastXchg.mode, dataToArrowColumnar)
     case plan: RowToColumnarExec =>
       val child = replaceWithColumnarPlan(plan.child)
       if (columnarConf.enableArrowRowToColumnar) {
diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index d0839d484..d1f90d580 100644
--- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -26,11 +26,11 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeAdaptor, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, REPARTITION, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.exchange.{Exchange, REPARTITION, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
@@ -537,7 +537,7 @@ class AdaptiveQueryExecSuite
       // Even with local shuffle reader, the query stage reuse can also work.
       val ex = findReusedExchange(adaptivePlan)
       assert(ex.nonEmpty)
-      assert(ex.head.child.isInstanceOf[BroadcastExchangeExec])
+      assert(ex.head.child.isInstanceOf[ColumnarBroadcastExchangeAdaptor])
       val sub = findReusedSubquery(adaptivePlan)
       assert(sub.isEmpty)
     }
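For context, here is a minimal sketch of the scenario this patch targets: with AQE and dynamic partition pruning (DPP) both enabled, the DPP subquery can reuse the broadcast exchange built for a broadcast hash join, so the reused stage may still wrap a ColumnarBroadcastExchangeAdaptor even after the join itself falls back to row-based execution. The table and column names (factTbl, dimTbl, pkey) and the object name are hypothetical illustrations, not taken from the patch or its test suite.

import org.apache.spark.sql.SparkSession

object DppBroadcastReuseSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      // AQE wraps broadcasts in BroadcastQueryStageExec, the node matched above.
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")
      // Only apply DPP when the pruning subquery can reuse a broadcast exchange,
      // which is exactly the ReusedExchangeExec case handled in the first hunk.
      .config("spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcastOnly", "true")
      .getOrCreate()
    import spark.implicits._

    // A partitioned fact table joined with a small dimension table: the DPP
    // filter on f.pkey can reuse the dimension side's broadcast exchange.
    Seq((0, "a"), (1, "b"), (2, "c")).toDF("pkey", "v")
      .write.partitionBy("pkey").saveAsTable("factTbl")
    Seq((0, "x")).toDF("pkey", "w")
      .write.saveAsTable("dimTbl")

    val df = spark.sql(
      "SELECT f.v, d.w FROM factTbl f JOIN dimTbl d ON f.pkey = d.pkey")
    df.collect()
    // Inspect the adaptive plan; when the columnar BHJ falls back to a
    // row-based join, the reused broadcast stage is what the
    // ColumnarPreOverrides/ColumnarPostOverrides changes above rewrap.
    println(df.queryExecution.executedPlan)
    spark.stop()
  }
}

The reuseBroadcastOnly setting is what forces DPP through exchange reuse rather than a separate subquery scan, so it is the mode under which the ReusedExchangeExec pattern and the later RowToColumnarExec fallback can both appear in one plan.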