From 479d56be0a530d0d7fb13196b5120f0202cfc704 Mon Sep 17 00:00:00 2001
From: LantaoJin
Date: Tue, 7 Jul 2020 15:00:44 +0800
Subject: [PATCH] [SPARK-32201][SQL] More general skew join pattern matching

---
 .../adaptive/OptimizeSkewedJoin.scala       | 222 +++++++++++-------
 .../adaptive/AdaptiveQueryExecSuite.scala   |  78 ++++--
 2 files changed, 196 insertions(+), 104 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 396c9c9d6b4e5..1bf72c22b3b59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -144,6 +144,21 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     sizes.sum / sizes.length
   }
 
+  // Returns the first shuffle stage found in a pre-order traversal of `plan`.
+  private def findShuffleStage(plan: SparkPlan): Option[ShuffleStageInfo] = {
+    plan collectFirst {
+      case ShuffleStage(shuffleStageInfo) =>
+        shuffleStageInfo
+    }
+  }
+
+  // Replaces the shuffle reader that produced `newCtm`'s child with the new
+  // skew-aware reader, leaving the rest of the join subtree untouched.
+  private def replaceSkewedShuffleReader(
+      smj: SparkPlan, newCtm: CustomShuffleReaderExec): SparkPlan = {
+    smj transformUp {
+      case CustomShuffleReaderExec(child, _) if child.sameResult(newCtm.child) =>
+        newCtm
+    }
+  }
+
   /*
    * This method aim to optimize the skewed join with the following steps:
    * 1. Check whether the shuffle partition is skewed based on the median size
@@ -158,95 +173,107 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
    */
   def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
     case smj @ SortMergeJoinExec(_, _, joinType, _,
-        s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
-        s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
+        s1 @ SortExec(_, _, _, _),
+        s2 @ SortExec(_, _, _, _), _)
         if supportedJoinTypes.contains(joinType) =>
-      assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
-      val numPartitions = left.partitionsWithSizes.length
-      // We use the median size of the original shuffle partitions to detect skewed partitions.
-      val leftMedSize = medianSize(left.mapStats)
-      val rightMedSize = medianSize(right.mapStats)
-      logDebug(
-        s"""
-          |Optimizing skewed join.
-          |Left side partitions size info:
-          |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
-          |Right side partitions size info:
-          |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
-        """.stripMargin)
-      val canSplitLeft = canSplitLeftSide(joinType)
-      val canSplitRight = canSplitRightSide(joinType)
-      // We use the actual partition sizes (may be coalesced) to calculate target size, so that
-      // the final data distribution is even (coalesced partitions + split partitions).
-      val leftActualSizes = left.partitionsWithSizes.map(_._2)
-      val rightActualSizes = right.partitionsWithSizes.map(_._2)
-      val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
-      val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
-
-      val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
-      var numSkewedLeft = 0
-      var numSkewedRight = 0
-      for (partitionIndex <- 0 until numPartitions) {
-        val leftActualSize = leftActualSizes(partitionIndex)
-        val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
-        val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
-        val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
-
-        val rightActualSize = rightActualSizes(partitionIndex)
-        val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
-        val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
-        val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
-
-        // A skewed partition should never be coalesced, but skip it here just to be safe.
-        val leftParts = if (isLeftSkew && !isLeftCoalesced) {
-          val reducerId = leftPartSpec.startReducerIndex
-          val skewSpecs = createSkewPartitionSpecs(
-            left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
-          if (skewSpecs.isDefined) {
-            logDebug(s"Left side partition $partitionIndex " +
-              s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
-              s"split it into ${skewSpecs.get.length} parts.")
-            numSkewedLeft += 1
+      // Find the shuffle stage on each side of the join; bail out if either side
+      // has none, since the skew optimization then does not apply.
+      val leftOpt = findShuffleStage(s1)
+      val rightOpt = findShuffleStage(s2)
+      if (leftOpt.isEmpty || rightOpt.isEmpty) {
+        smj
+      } else {
+        val left = leftOpt.get
+        val right = rightOpt.get
+        assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
+        val numPartitions = left.partitionsWithSizes.length
+        // We use the median size of the original shuffle partitions to detect skewed partitions.
+        val leftMedSize = medianSize(left.mapStats)
+        val rightMedSize = medianSize(right.mapStats)
+        logDebug(
+          s"""
+            |Optimizing skewed join.
+            |Left side partitions size info:
+            |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
+            |Right side partitions size info:
+            |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
+          """.stripMargin)
+        val canSplitLeft = canSplitLeftSide(joinType)
+        val canSplitRight = canSplitRightSide(joinType)
+        // We use the actual partition sizes (may be coalesced) to calculate target size, so that
+        // the final data distribution is even (coalesced partitions + split partitions).
+        val leftActualSizes = left.partitionsWithSizes.map(_._2)
+        val rightActualSizes = right.partitionsWithSizes.map(_._2)
+        val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
+        val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
+
+        val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+        val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
+        var numSkewedLeft = 0
+        var numSkewedRight = 0
+        for (partitionIndex <- 0 until numPartitions) {
+          val leftActualSize = leftActualSizes(partitionIndex)
+          val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
+          val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
+          val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
+
+          val rightActualSize = rightActualSizes(partitionIndex)
+          val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
+          val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
+          val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
+
+          // A skewed partition should never be coalesced, but skip it here just to be safe.
+          val leftParts = if (isLeftSkew && !isLeftCoalesced) {
+            val reducerId = leftPartSpec.startReducerIndex
+            val skewSpecs = createSkewPartitionSpecs(
+              left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
+            if (skewSpecs.isDefined) {
+              logDebug(s"Left side partition $partitionIndex " +
+                s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
+                s"split it into ${skewSpecs.get.length} parts.")
+              numSkewedLeft += 1
+            }
+            skewSpecs.getOrElse(Seq(leftPartSpec))
+          } else {
+            Seq(leftPartSpec)
+          }
-          }
-          skewSpecs.getOrElse(Seq(leftPartSpec))
-        } else {
-          Seq(leftPartSpec)
-        }
-        // A skewed partition should never be coalesced, but skip it here just to be safe.
-        val rightParts = if (isRightSkew && !isRightCoalesced) {
-          val reducerId = rightPartSpec.startReducerIndex
-          val skewSpecs = createSkewPartitionSpecs(
-            right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
-          if (skewSpecs.isDefined) {
-            logDebug(s"Right side partition $partitionIndex " +
-              s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
-              s"split it into ${skewSpecs.get.length} parts.")
-            numSkewedRight += 1
+          // A skewed partition should never be coalesced, but skip it here just to be safe.
+          val rightParts = if (isRightSkew && !isRightCoalesced) {
+            val reducerId = rightPartSpec.startReducerIndex
+            val skewSpecs = createSkewPartitionSpecs(
+              right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
+            if (skewSpecs.isDefined) {
+              logDebug(s"Right side partition $partitionIndex " +
+                s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
+                s"split it into ${skewSpecs.get.length} parts.")
+              numSkewedRight += 1
+            }
+            skewSpecs.getOrElse(Seq(rightPartSpec))
+          } else {
+            Seq(rightPartSpec)
+          }
-          }
-          skewSpecs.getOrElse(Seq(rightPartSpec))
-        } else {
-          Seq(rightPartSpec)
-        }
-        for {
-          leftSidePartition <- leftParts
-          rightSidePartition <- rightParts
-        } {
-          leftSidePartitions += leftSidePartition
-          rightSidePartitions += rightSidePartition
+          // Pair each left-side slice with each right-side slice, so every
+          // combination of splits is read by exactly one join task.
+          for {
+            leftSidePartition <- leftParts
+            rightSidePartition <- rightParts
+          } {
+            leftSidePartitions += leftSidePartition
+            rightSidePartitions += rightSidePartition
+          }
+        }
-      }
-      logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
-      if (numSkewedLeft > 0 || numSkewedRight > 0) {
-        val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions)
-        val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions)
-        smj.copy(
-          left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
-      } else {
-        smj
+        logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
+        if (numSkewedLeft > 0 || numSkewedRight > 0) {
+          val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions)
+          val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions)
+          // Swap the skew-aware readers into the existing subtrees rather than
+          // rebuilding the sort children directly, since the readers may sit
+          // several nodes below the sorts.
+          val newSmj = replaceSkewedShuffleReader(
+            replaceSkewedShuffleReader(smj, newLeft), newRight).asInstanceOf[SortMergeJoinExec]
+          newSmj.copy(isSkewJoin = true)
+        } else {
+          smj
+        }
+      }
   }
 
@@ -263,15 +290,19 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     val shuffleStages = collectShuffleStages(plan)
     if (shuffleStages.length == 2) {
-      // When multi table join, there will be too many complex combination to consider.
-      // Currently we only handle 2 table join like following use case.
+      // SPARK-32201. Skew join now matches the pattern below, where ".." may contain
+      // any number of intermediate nodes (e.g. a BroadcastHashJoinExec), so queries
+      // joining more than two tables can also be optimized.
       // SMJ
       //   Sort
-      //     Shuffle
+      //     ..
+      //       Shuffle
       //   Sort
-      //     Shuffle
+      //     ..
+      //       Shuffle
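+      // For example, in the SPARK-32201 test added below, ".." on one side is the
+      // aggregate of a grouped subquery, roughly:
+      // SMJ
+      //   Sort
+      //     Shuffle
+      //   Sort
+      //     HashAggregate
+      //       Shuffle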
       val optimizePlan = optimizeSkewJoin(plan)
       val numShuffles = ensureRequirements.apply(optimizePlan).collect {
         case e: ShuffleExchangeExec => e
       }.length
@@ -316,6 +347,23 @@ private object ShuffleStage {
       }
       Some(ShuffleStageInfo(s, mapStats, partitions))
 
+    case _: LeafExecNode => None
+
+    // Look through unary nodes (e.g. aggregates, projects) via their child.
+    case u: UnaryExecNode =>
+      u.child match {
+        case ShuffleStage(ss: ShuffleStageInfo) => Some(ss)
+        case _ => None
+      }
+
+    // For binary nodes (e.g. broadcast joins), check the left side first, then the right.
+    case b: BinaryExecNode =>
+      (b.left, b.right) match {
+        case (ShuffleStage(ss: ShuffleStageInfo), _) => Some(ss)
+        case (_, ShuffleStage(ss: ShuffleStageInfo)) => Some(ss)
+        case _ => None
+      }
+
     case _ => None
   }
 }
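To make the recursive matching concrete: the extractor above keeps looking through unary and binary operators until it reaches a shuffle stage or a leaf. Below is a minimal, self-contained sketch of that idea in plain Scala; Plan, Shuffle, Unary and Binary are toy stand-ins for illustration, not Spark's classes.

object ShuffleStageSketch extends App {
  sealed trait Plan
  case class Leaf(name: String) extends Plan
  case class Shuffle(id: Int) extends Plan
  case class Unary(child: Plan) extends Plan
  case class Binary(left: Plan, right: Plan) extends Plan

  object ShuffleStage {
    // Matches any plan whose subtree contains a Shuffle, looking through unary
    // nodes and preferring the left child of binary nodes, as the rule above does.
    def unapply(plan: Plan): Option[Shuffle] = plan match {
      case s: Shuffle => Some(s)
      case _: Leaf => None
      case Unary(child) => unapply(child)
      case Binary(l, r) => unapply(l).orElse(unapply(r))
    }
  }

  // Sort > Aggregate > Shuffle, like the grouped-subquery side in the test below.
  val side: Plan = Unary(Unary(Shuffle(42)))
  side match {
    case ShuffleStage(s) => println(s"found shuffle stage ${s.id}")
    case _ => println("no shuffle stage under this node")
  }
}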
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index c696d3f648ed1..7865ee8bb67a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -692,23 +692,6 @@ class AdaptiveQueryExecSuite
         'id as "value2")
         .createOrReplaceTempView("skewData2")
 
-      def checkSkewJoin(
-          joins: Seq[SortMergeJoinExec],
-          leftSkewNum: Int,
-          rightSkewNum: Int): Unit = {
-        assert(joins.size == 1 && joins.head.isSkewJoin)
-        assert(joins.head.left.collect {
-          case r: CustomShuffleReaderExec => r
-        }.head.partitionSpecs.collect {
-          case p: PartialReducerPartitionSpec => p.reducerIndex
-        }.distinct.length == leftSkewNum)
-        assert(joins.head.right.collect {
-          case r: CustomShuffleReaderExec => r
-        }.head.partitionSpecs.collect {
-          case p: PartialReducerPartitionSpec => p.reducerIndex
-        }.distinct.length == rightSkewNum)
-      }
-
       // skewed inner join optimization
       val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
         "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
@@ -730,6 +713,67 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  // Moved out of the test body above so the new SPARK-32201 test can reuse it.
+  private def checkSkewJoin(
+      joins: Seq[SortMergeJoinExec],
+      leftSkewNum: Int,
+      rightSkewNum: Int): Unit = {
+    assert(joins.size == 1 && joins.head.isSkewJoin)
+    assert(joins.head.left.collect {
+      case r: CustomShuffleReaderExec => r
+    }.head.partitionSpecs.collect {
+      case p: PartialReducerPartitionSpec => p.reducerIndex
+    }.distinct.length == leftSkewNum)
+    assert(joins.head.right.collect {
+      case r: CustomShuffleReaderExec => r
+    }.head.partitionSpecs.collect {
+      case p: PartialReducerPartitionSpec => p.reducerIndex
+    }.distinct.length == rightSkewNum)
+  }
+
+  test("SPARK-32201: handle general skew join pattern") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "100",
+      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
+      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
+      withTempView("skewData1", "skewData2") {
+        spark
+          .range(0, 1000, 1, 10)
+          .select(
+            when('id < 250, 249)
+              .when('id >= 750, 1000)
+              .otherwise('id).as("key1"),
+            'id as "value1")
+          .createOrReplaceTempView("skewData1")
+
+        spark
+          .range(0, 1000, 1, 10)
+          .select(
+            when('id < 250, 249)
+              .otherwise('id).as("key2"),
+            'id as "value2")
+          .createOrReplaceTempView("skewData2")
+        val sqlText =
+          """
+            |SELECT * FROM
+            |  skewData1 AS data1
+            |  INNER JOIN
+            |  (
+            |    SELECT skewData2.key2, sum(skewData2.value2) AS sum2
+            |    FROM skewData2 GROUP BY skewData2.key2
+            |  ) AS data2
+            |ON data1.key1 = data2.key2
+            |""".stripMargin
+
+        val (_, adaptivePlan) = runAdaptiveAndVerifyResult(sqlText)
+        val innerSmj = findTopLevelSortMergeJoin(adaptivePlan)
+        // Two skewed partitions on the raw left side; none on the aggregated
+        // right side, which has at most one row per key.
+        checkSkewJoin(innerSmj, 2, 0)
+      }
+    }
+  }
+
   test("SPARK-30291: AQE should catch the exceptions when doing materialize") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {