[SPARK-32012][SQL] Incrementally create and materialize query stage to avoid unnecessary local shuffle #28846

Closed · wants to merge 2 commits
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

@@ -163,7 +163,7 @@ case class AdaptiveSparkPlanExec(
     val events = new LinkedBlockingQueue[StageMaterializationEvent]()
     val errors = new mutable.ArrayBuffer[Throwable]()
     var stagesToReplace = Seq.empty[QueryStageExec]
-    while (!result.allChildStagesMaterialized) {
+    while (!result.newStages.isEmpty) {
       currentPhysicalPlan = result.newPlan
       if (result.newStages.nonEmpty) {
         stagesToReplace = result.newStages ++ stagesToReplace
@@ -172,13 +172,18 @@ case class AdaptiveSparkPlanExec(
       // Start materialization of all new stages and fail fast if any stages failed eagerly
       result.newStages.foreach { stage =>
         try {
-          stage.materialize().onComplete { res =>
+          val stageEval = stage.materialize()
+          stageEval.map(_.onComplete { res =>
             if (res.isSuccess) {
               events.offer(StageSuccess(stage, res.get))
             } else {
               events.offer(StageFailure(stage, res.failed.get))
             }
-          }(AdaptiveSparkPlanExec.executionContext)
+          }(AdaptiveSparkPlanExec.executionContext))
+          // This stage was already materialized; just offer its existing result.
+          if (stageEval.isEmpty) {
+            events.offer(StageSuccess(stage, stage.resultOption.get().get))
+          }
         } catch {
           case e: Throwable =>
             cleanUpAndThrowException(Seq(e), Some(stage.id))
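
A note on the loop rewrite above: materialize() now returns Option[Future[Any]], and the loop runs for as long as a planning pass produces new stages. A stage that is already materialized (for example a reused one) yields None, and its cached result is offered straight to the event queue instead of going through a completion callback. Below is a minimal, self-contained sketch of that pattern; Stage, Event, and cached are simplified stand-ins for illustration, not Spark's actual classes.

    import java.util.concurrent.LinkedBlockingQueue
    import scala.concurrent.{ExecutionContext, Future}
    import scala.util.{Failure, Success}

    object EventLoopSketch {
      sealed trait Event
      case class StageSuccess(id: Int, result: Any) extends Event
      case class StageFailure(id: Int, error: Throwable) extends Event

      // A toy stage: materialize() returns None when a result is already cached.
      case class Stage(id: Int, cached: Option[Any]) {
        def materialize()(implicit ec: ExecutionContext): Option[Future[Any]] =
          if (cached.isEmpty) Some(Future(s"result-of-$id")) else None
      }

      def main(args: Array[String]): Unit = {
        implicit val ec: ExecutionContext = ExecutionContext.global
        val events = new LinkedBlockingQueue[Event]()
        val stages = Seq(Stage(1, None), Stage(2, Some("cached")))
        stages.foreach { stage =>
          stage.materialize() match {
            case Some(fut) => fut.onComplete {
              case Success(r) => events.offer(StageSuccess(stage.id, r))
              case Failure(e) => events.offer(StageFailure(stage.id, e))
            }
            case None =>
              // Already materialized: publish the cached result immediately.
              events.offer(StageSuccess(stage.id, stage.cached.get))
          }
        }
        // Drain one event per stage; both arrive regardless of which path ran.
        stages.foreach(_ => println(events.take()))
      }
    }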
@@ -335,13 +340,13 @@ case class AdaptiveSparkPlanExec(
         CreateStageResult(
           newPlan = stage,
           allChildStagesMaterialized = isMaterialized,
-          newStages = if (isMaterialized) Seq.empty else Seq(stage))
+          newStages = Seq(stage))

       case _ =>
         val result = createQueryStages(e.child)
         val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
-        // Create a query stage only when all the child query stages are ready.
-        if (result.allChildStagesMaterialized) {
+        // Create a query stage only when no new query stages were created in the child.
+        if (result.newStages.isEmpty) {
           var newStage = newQueryStage(newPlan)
           if (conf.exchangeReuseEnabled) {
             // Check the `stageCache` again for reuse. If a match is found, ditch the new stage
@@ -356,7 +361,7 @@ case class AdaptiveSparkPlanExec(
           CreateStageResult(
             newPlan = newStage,
             allChildStagesMaterialized = isMaterialized,
-            newStages = if (isMaterialized) Seq.empty else Seq(newStage))
+            newStages = Seq(newStage))
         } else {
           CreateStageResult(newPlan = newPlan,
             allChildStagesMaterialized = false, newStages = result.newStages)
@@ -371,11 +376,25 @@ case class AdaptiveSparkPlanExec(
     if (plan.children.isEmpty) {
       CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, newStages = Seq.empty)
     } else {
-      val results = plan.children.map(createQueryStages)
+      var foundExchange = false
+      val (newPlans, materializedStatuses, newStages) = plan.children.map { child =>
+        if (!foundExchange) {
+          val stage = createQueryStages(child)
+          if (stage.newStages.nonEmpty) {
+            // Once a query stage has been created, stop creating query stages in subsequent calls.
+            foundExchange = true
+            (stage.newPlan, stage.allChildStagesMaterialized, stage.newStages)
+          } else {
+            (child, true, Seq.empty)
+          }
+        } else {
+          (child, true, Seq.empty)
+        }
+      }.unzip3
       CreateStageResult(
-        newPlan = plan.withNewChildren(results.map(_.newPlan)),
-        allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),
-        newStages = results.flatMap(_.newStages))
+        newPlan = plan.withNewChildren(newPlans),
+        allChildStagesMaterialized = materializedStatuses.forall(x => x),
+        newStages = newStages.flatten)
     }
   }

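The createQueryStages rewrite above is the heart of the change: in each planning pass, children are scanned in order and only the first child that yields new stages is rewritten (hence foundExchange and unzip3); the rest pass through unchanged. The toy model below, with illustrative Plan/Result types rather than Spark's, shows how repeated passes under the new while (!result.newStages.isEmpty) condition still reach every exchange, one stage per pass, giving AQE a chance to re-optimize in between (e.g. switching to a broadcast join before the other side's shuffle is ever materialized).

    object IncrementalStagesSketch {
      sealed trait Plan
      case class Exchange(id: Int) extends Plan
      case class Stage(id: Int) extends Plan
      case class Join(left: Plan, right: Plan) extends Plan

      case class Result(newPlan: Plan, newStages: Seq[Int])

      def createStages(plan: Plan): Result = plan match {
        case Exchange(id) => Result(Stage(id), Seq(id)) // this pass creates one stage
        case s: Stage     => Result(s, Seq.empty)       // already a stage: nothing new
        case Join(l, r) =>
          val lr = createStages(l)
          if (lr.newStages.nonEmpty) {
            // A stage was created on the left: leave the right child untouched.
            Result(Join(lr.newPlan, r), lr.newStages)
          } else {
            val rr = createStages(r)
            Result(Join(lr.newPlan, rr.newPlan), rr.newStages)
          }
      }

      def main(args: Array[String]): Unit = {
        var result = createStages(Join(Exchange(1), Exchange(2)))
        println(result) // pass 1 creates only stage 1
        while (result.newStages.nonEmpty) {
          result = createStages(result.newPlan)
          println(result) // pass 2 creates stage 2; pass 3 creates none and ends the loop
        }
      }
    }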
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala

@@ -76,8 +76,12 @@ abstract class QueryStageExec extends LeafExecNode {
    * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
    * stage is ready.
    */
-  final def materialize(): Future[Any] = executeQuery {
-    doMaterialize()
+  final def materialize(): Option[Future[Any]] = executeQuery {
+    if (_resultOption.get.isEmpty) {
+      Option(doMaterialize())
+    } else {
+      None
+    }
   }

   def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
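With this change materialize() is effectively idempotent: a second call on an already-materialized stage returns None instead of re-running doMaterialize(). A simplified sketch of that contract, using an AtomicReference result holder in place of Spark's internal _resultOption (the surrounding names are assumptions for illustration):

    import java.util.concurrent.atomic.AtomicReference
    import scala.concurrent.Future

    class StageLike {
      // Stand-in for Spark's internal result holder.
      private val resultHolder = new AtomicReference[Option[Any]](None)

      protected def doMaterialize(): Future[Any] = {
        val rows = "rows" // pretend this ran a shuffle or broadcast
        resultHolder.set(Some(rows))
        Future.successful(rows)
      }

      // None signals "already materialized", so callers skip the callback path.
      final def materialize(): Option[Future[Any]] =
        if (resultHolder.get.isEmpty) Some(doMaterialize()) else None
    }

    object StageLikeDemo extends App {
      val s = new StageLike
      println(s.materialize().isDefined) // true: first call runs doMaterialize()
      println(s.materialize().isDefined) // false: result is already present
    }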
sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

@@ -152,6 +152,10 @@ class AdaptiveQueryExecSuite
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
       checkNumLocalShuffleReaders(adaptivePlan)
+      val localReaders = collect(adaptivePlan) {
+        case reader: CustomShuffleReaderExec if reader.isLocalReader => reader
+      }
+      assert(localReaders.length == 1)
     }
   }

@@ -169,21 +173,14 @@ class AdaptiveQueryExecSuite
       val localReaders = collect(adaptivePlan) {
         case reader: CustomShuffleReaderExec if reader.isLocalReader => reader
       }
-      assert(localReaders.length == 2)
+      assert(localReaders.length == 1)
       val localShuffleRDD0 = localReaders(0).execute().asInstanceOf[ShuffledRowRDD]
-      val localShuffleRDD1 = localReaders(1).execute().asInstanceOf[ShuffledRowRDD]
       // The pre-shuffle partition size is [0, 0, 0, 72, 0]
       // We exclude the 0-size partitions, so only one partition, advisoryParallelism = 1
       // the final parallelism is
       // math.max(1, advisoryParallelism / numMappers): math.max(1, 1/2) = 1
       // and the partitions length is 1 * numMappers = 2
       assert(localShuffleRDD0.getPartitions.length == 2)
-      // The pre-shuffle partition size is [0, 72, 0, 72, 126]
-      // We exclude the 0-size partitions, so only 3 partition, advisoryParallelism = 3
-      // the final parallelism is
-      // math.max(1, advisoryParallelism / numMappers): math.max(1, 3/2) = 1
-      // and the partitions length is 1 * numMappers = 2
-      assert(localShuffleRDD1.getPartitions.length == 2)
     }
   }

@@ -201,15 +198,11 @@ class AdaptiveQueryExecSuite
       val localReaders = collect(adaptivePlan) {
         case reader: CustomShuffleReaderExec if reader.isLocalReader => reader
       }
-      assert(localReaders.length == 2)
+      assert(localReaders.length == 1)
       val localShuffleRDD0 = localReaders(0).execute().asInstanceOf[ShuffledRowRDD]
-      val localShuffleRDD1 = localReaders(1).execute().asInstanceOf[ShuffledRowRDD]
       // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2
       // and the partitions length is 2 * numMappers = 4
       assert(localShuffleRDD0.getPartitions.length == 4)
-      // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2
-      // and the partitions length is 2 * numMappers = 4
-      assert(localShuffleRDD1.getPartitions.length == 4)
     }
   }

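The partition arithmetic in the surviving comments above reduces to one formula: a local shuffle reader exposes math.max(1, advisoryParallelism / numMappers) partitions per mapper, times numMappers. The tiny helper below is an assumed simplification for illustration, not Spark's actual API; it reproduces both asserted values:

    object LocalReaderPartitionMath {
      // partitions per mapper = max(1, advisoryParallelism / numMappers);
      // the local reader then exposes that many partitions for each mapper.
      def localPartitions(advisoryParallelism: Int, numMappers: Int): Int =
        math.max(1, advisoryParallelism / numMappers) * numMappers

      def main(args: Array[String]): Unit = {
        println(localPartitions(advisoryParallelism = 1, numMappers = 2)) // 2
        println(localPartitions(advisoryParallelism = 5, numMappers = 2)) // 4
      }
    }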
@@ -443,9 +436,10 @@ class AdaptiveQueryExecSuite
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
       checkNumLocalShuffleReaders(adaptivePlan)
-      // Even with local shuffle reader, the query stage reuse can also work.
-      val ex = findReusedExchange(adaptivePlan)
-      assert(ex.size == 1)
+      val localReaders = collect(adaptivePlan) {
+        case reader: CustomShuffleReaderExec if reader.isLocalReader => reader
+      }
+      assert(localReaders.length == 1)
     }
   }

@@ -871,11 +865,7 @@ class AdaptiveQueryExecSuite
         val readers = collect(join.right) {
           case r: CustomShuffleReaderExec => r
         }
-        assert(readers.length == 1)
-        val reader = readers.head
-        assert(reader.isLocalReader)
-        assert(reader.metrics.keys.toSeq == Seq("numPartitions"))
-        assert(reader.metrics("numPartitions").value == reader.partitionSpecs.length)
+        assert(readers.length == 0)
       }

       withSQLConf(
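These tests repeatedly pull reader nodes out of the adaptive plan with collect and a partial function, a TreeNode.collect-style traversal (in the suite it presumably comes from a test helper that also descends into query stages). A minimal analogue over a toy tree, not Spark's TreeNode:

    object CollectSketch {
      case class Tree(label: String, children: Seq[Tree] = Nil) {
        // Pre-order traversal applying a partial function, like TreeNode.collect.
        def collect[A](pf: PartialFunction[Tree, A]): Seq[A] =
          pf.lift(this).toSeq ++ children.flatMap(_.collect(pf))
      }

      def main(args: Array[String]): Unit = {
        val plan = Tree("Join", Seq(Tree("LocalReader"), Tree("Scan")))
        val readers = plan.collect { case t if t.label == "LocalReader" => t }
        assert(readers.length == 1)
        println(readers)
      }
    }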