diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index bc924e6978ddc..e61952e9b206b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -167,13 +167,18 @@ case class AdaptiveSparkPlanExec(
         // Start materialization of all new stages and fail fast if any stages failed eagerly
         result.newStages.foreach { stage =>
           try {
-            stage.materialize().onComplete { res =>
+            val stageEval = stage.materialize()
+            stageEval.map(_.onComplete { res =>
               if (res.isSuccess) {
                 events.offer(StageSuccess(stage, res.get))
               } else {
                 events.offer(StageFailure(stage, res.failed.get))
               }
-            }(AdaptiveSparkPlanExec.executionContext)
+            }(AdaptiveSparkPlanExec.executionContext))
+            // The stage was already materialized, so just offer the existing result.
+            if (stageEval.isEmpty) {
+              events.offer(StageSuccess(stage, stage.resultOption.get().get))
+            }
           } catch {
             case e: Throwable =>
               cleanUpAndThrowException(Seq(e), Some(stage.id))
@@ -329,13 +334,13 @@ case class AdaptiveSparkPlanExec(
           CreateStageResult(
             newPlan = stage,
             allChildStagesMaterialized = isMaterialized,
-            newStages = if (isMaterialized) Seq.empty else Seq(stage))
+            newStages = Seq(stage))
 
         case _ =>
           val result = createQueryStages(e.child)
           val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
-          // Create a query stage only when all the child query stages are ready.
-          if (result.allChildStagesMaterialized) {
+          // Create a query stage only when no new query stages were created for the child.
+          if (result.newStages.isEmpty) {
            var newStage = newQueryStage(newPlan)
            if (conf.exchangeReuseEnabled) {
              // Check the `stageCache` again for reuse. If a match is found, ditch the new stage
@@ -350,7 +355,7 @@ case class AdaptiveSparkPlanExec(
            CreateStageResult(
              newPlan = newStage,
              allChildStagesMaterialized = isMaterialized,
-              newStages = if (isMaterialized) Seq.empty else Seq(newStage))
+              newStages = Seq(newStage))
          } else {
            CreateStageResult(newPlan = newPlan,
              allChildStagesMaterialized = false, newStages = result.newStages)
@@ -365,11 +370,25 @@ case class AdaptiveSparkPlanExec(
       if (plan.children.isEmpty) {
         CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, newStages = Seq.empty)
       } else {
-        val results = plan.children.map(createQueryStages)
+        var foundExchange = false
+        val (newPlans, materializedStatuses, newStages) = plan.children.map { child =>
+          if (!foundExchange) {
+            val stage = createQueryStages(child)
+            if (stage.newStages.nonEmpty) {
+              // Once a query stage is created, stop creating stages for the remaining children.
+              foundExchange = true
+              (stage.newPlan, stage.allChildStagesMaterialized, stage.newStages)
+            } else {
+              (child, true, Seq.empty)
+            }
+          } else {
+            (child, true, Seq.empty)
+          }
+        }.unzip3
         CreateStageResult(
-          newPlan = plan.withNewChildren(results.map(_.newPlan)),
-          allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),
-          newStages = results.flatMap(_.newStages))
+          newPlan = plan.withNewChildren(newPlans),
+          allChildStagesMaterialized = materializedStatuses.forall(x => x),
+          newStages = newStages.flatten)
       }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index 4e83b4344fbf0..97f645b7eb749 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -75,8 +75,12 @@ abstract class QueryStageExec extends LeafExecNode {
    * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
    * stage is ready.
    */
-  final def materialize(): Future[Any] = executeQuery {
-    doMaterialize()
+  final def materialize(): Option[Future[Any]] = executeQuery {
+    if (_resultOption.get.isEmpty) {
+      Option(doMaterialize())
+    } else {
+      None
+    }
   }
 
   def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 9fa97bffa8910..080f787a704ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -141,6 +141,10 @@ class AdaptiveQueryExecSuite
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
       checkNumLocalShuffleReaders(adaptivePlan)
+      val localReaders = collect(adaptivePlan) {
+        case reader: CustomShuffleReaderExec if reader.isLocalReader => reader
+      }
+      assert(localReaders.length == 1)
     }
   }
 
@@ -158,21 +162,14 @@ class AdaptiveQueryExecSuite
       val localReaders = collect(adaptivePlan) {
         case reader: CustomShuffleReaderExec if reader.isLocalReader => reader
       }
-      assert(localReaders.length == 2)
+      assert(localReaders.length == 1)
       val localShuffleRDD0 = localReaders(0).execute().asInstanceOf[ShuffledRowRDD]
-      val localShuffleRDD1 = localReaders(1).execute().asInstanceOf[ShuffledRowRDD]
       // The pre-shuffle partition size is [0, 0, 0, 72, 0]
       // We exclude the 0-size partitions, so only one partition, advisoryParallelism = 1
       // the final parallelism is
       // math.max(1, advisoryParallelism / numMappers): math.max(1, 1/2) = 1
       // and the partitions length is 1 * numMappers = 2
       assert(localShuffleRDD0.getPartitions.length == 2)
-      // The pre-shuffle partition size is [0, 72, 0, 72, 126]
-      // We exclude the 0-size partitions, so only 3 partition, advisoryParallelism = 3
-      // the final parallelism is
-      // math.max(1, advisoryParallelism / numMappers): math.max(1, 3/2) = 1
-      // and the partitions length is 1 * numMappers = 2
-      assert(localShuffleRDD1.getPartitions.length == 2)
     }
   }
 
@@ -190,15 +187,11 @@ class AdaptiveQueryExecSuite
       val localReaders = collect(adaptivePlan) {
         case reader: CustomShuffleReaderExec if reader.isLocalReader => reader
       }
-      assert(localReaders.length == 2)
+      assert(localReaders.length == 1)
       val localShuffleRDD0 = localReaders(0).execute().asInstanceOf[ShuffledRowRDD]
-      val localShuffleRDD1 = localReaders(1).execute().asInstanceOf[ShuffledRowRDD]
       // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2
       // and the partitions length is 2 * numMappers = 4
       assert(localShuffleRDD0.getPartitions.length == 4)
-      // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2
-      // and the partitions length is 2 * numMappers = 4
-      assert(localShuffleRDD1.getPartitions.length == 4)
     }
   }
 
@@ -432,9 +425,10 @@ class AdaptiveQueryExecSuite
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
       checkNumLocalShuffleReaders(adaptivePlan)
-      // Even with local shuffle reader, the query stage reuse can also work.
-      val ex = findReusedExchange(adaptivePlan)
-      assert(ex.size == 1)
+      val localReaders = collect(adaptivePlan) {
+        case reader: CustomShuffleReaderExec if reader.isLocalReader => reader
+      }
+      assert(localReaders.length == 1)
     }
   }
 
@@ -859,11 +853,7 @@ class AdaptiveQueryExecSuite
       val readers = collect(join.right) {
         case r: CustomShuffleReaderExec => r
       }
-      assert(readers.length == 1)
-      val reader = readers.head
-      assert(reader.isLocalReader)
-      assert(reader.metrics.keys.toSeq == Seq("numPartitions"))
-      assert(reader.metrics("numPartitions").value == reader.partitionSpecs.length)
+      assert(readers.length == 0)
     }
 
     withSQLConf(