[SPARK-32201][SQL] More general skew join pattern matching
LantaoJin committed Jul 7, 2020
1 parent 5d296ed commit 479d56b
Showing 2 changed files with 196 additions and 104 deletions.
@@ -144,6 +144,21 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
sizes.sum / sizes.length
}

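// Find the first shuffle stage anywhere under `plan`, so the Sort no longer has to
// sit directly on top of the shuffle stage (there may be intermediate nodes between).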
private def findShuffleStage(plan: SparkPlan): Option[ShuffleStageInfo] = {
plan collectFirst {
case ShuffleStage(shuffleStageInfo) =>
shuffleStageInfo
}
}

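// Swap in the new skew-aware shuffle reader: replace any existing
// CustomShuffleReaderExec whose child reads the same shuffle (checked via
// `sameResult`) with `newReader`.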
private def replaceSkewedShuffleReader(
smj: SparkPlan, newReader: CustomShuffleReaderExec): SparkPlan = {
smj transformUp {
case CustomShuffleReaderExec(child, _) if child.sameResult(newReader.child) =>
newReader
}
}

/*
* This method aims to optimize the skewed join with the following steps:
* 1. Check whether the shuffle partition is skewed based on the median size
@@ -158,95 +173,107 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
*/
def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
case smj @ SortMergeJoinExec(_, _, joinType, _,
s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
s1 @ SortExec(_, _, _, _),
s2 @ SortExec(_, _, _, _), _)
if supportedJoinTypes.contains(joinType) =>
assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
val numPartitions = left.partitionsWithSizes.length
// We use the median size of the original shuffle partitions to detect skewed partitions.
val leftMedSize = medianSize(left.mapStats)
val rightMedSize = medianSize(right.mapStats)
logDebug(
s"""
|Optimizing skewed join.
|Left side partitions size info:
|${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
|Right side partitions size info:
|${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
""".stripMargin)
val canSplitLeft = canSplitLeftSide(joinType)
val canSplitRight = canSplitRightSide(joinType)
// We use the actual partition sizes (may be coalesced) to calculate target size, so that
// the final data distribution is even (coalesced partitions + split partitions).
val leftActualSizes = left.partitionsWithSizes.map(_._2)
val rightActualSizes = right.partitionsWithSizes.map(_._2)
val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
val rightTargetSize = targetSize(rightActualSizes, rightMedSize)

val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
var numSkewedLeft = 0
var numSkewedRight = 0
for (partitionIndex <- 0 until numPartitions) {
val leftActualSize = leftActualSizes(partitionIndex)
val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex

val rightActualSize = rightActualSizes(partitionIndex)
val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex

// A skewed partition should never be coalesced, but skip it here just to be safe.
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
val reducerId = leftPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Left side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedLeft += 1
// Find the shuffle stage on each side of the join, searching the whole subtree (SPARK-32201).
val leftOpt = findShuffleStage(s1)
val rightOpt = findShuffleStage(s2)
if (leftOpt.isEmpty || rightOpt.isEmpty) {
smj
} else {
val left = leftOpt.get
val right = rightOpt.get
assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
val numPartitions = left.partitionsWithSizes.length
// We use the median size of the original shuffle partitions to detect skewed partitions.
val leftMedSize = medianSize(left.mapStats)
val rightMedSize = medianSize(right.mapStats)
logDebug(
s"""
|Optimizing skewed join.
|Left side partitions size info:
|${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
|Right side partitions size info:
|${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
""".stripMargin)
val canSplitLeft = canSplitLeftSide(joinType)
val canSplitRight = canSplitRightSide(joinType)
// We use the actual partition sizes (may be coalesced) to calculate target size, so that
// the final data distribution is even (coalesced partitions + split partitions).
val leftActualSizes = left.partitionsWithSizes.map(_._2)
val rightActualSizes = right.partitionsWithSizes.map(_._2)
val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
val rightTargetSize = targetSize(rightActualSizes, rightMedSize)

val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
var numSkewedLeft = 0
var numSkewedRight = 0
for (partitionIndex <- 0 until numPartitions) {
val leftActualSize = leftActualSizes(partitionIndex)
val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex

val rightActualSize = rightActualSizes(partitionIndex)
val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex

// A skewed partition should never be coalesced, but skip it here just to be safe.
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
val reducerId = leftPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Left side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedLeft += 1
}
skewSpecs.getOrElse(Seq(leftPartSpec))
} else {
Seq(leftPartSpec)
}
skewSpecs.getOrElse(Seq(leftPartSpec))
} else {
Seq(leftPartSpec)
}

// A skewed partition should never be coalesced, but skip it here just to be safe.
val rightParts = if (isRightSkew && !isRightCoalesced) {
val reducerId = rightPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Right side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedRight += 1
// A skewed partition should never be coalesced, but skip it here just to be safe.
val rightParts = if (isRightSkew && !isRightCoalesced) {
val reducerId = rightPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Right side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedRight += 1
}
skewSpecs.getOrElse(Seq(rightPartSpec))
} else {
Seq(rightPartSpec)
}
skewSpecs.getOrElse(Seq(rightPartSpec))
} else {
Seq(rightPartSpec)
}

for {
leftSidePartition <- leftParts
rightSidePartition <- rightParts
} {
leftSidePartitions += leftSidePartition
rightSidePartitions += rightSidePartition
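// Pair every split of the left partition with every split of the right partition
// (a cartesian product of the sub-partitions), so each task joins one left part
// with one right part.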
for {
leftSidePartition <- leftParts
rightSidePartition <- rightParts
} {
leftSidePartitions += leftSidePartition
rightSidePartitions += rightSidePartition
}
}
}

logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
if (numSkewedLeft > 0 || numSkewedRight > 0) {
val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions)
val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions)
smj.copy(
left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
} else {
smj
logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
if (numSkewedLeft > 0 || numSkewedRight > 0) {
val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions)
val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions)
val newSmj = replaceSkewedShuffleReader(
replaceSkewedShuffleReader(smj, newLeft), newRight).asInstanceOf[SortMergeJoinExec]
newSmj.copy(isSkewJoin = true)
} else {
smj
}
}
}

@@ -263,15 +290,19 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
val shuffleStages = collectShuffleStages(plan)

if (shuffleStages.length == 2) {
// When multi table join, there will be too many complex combination to consider.
// Currently we only handle 2 table join like following use case.
// SPARK-32201: skew join now supports the pattern below, where ".." may contain any
// number of nodes (including, for example, a BroadcastHashJoinExec), so joins of
// more than two tables can be handled.
// SMJ
// Sort
// Shuffle
// ..
// Shuffle
// Sort
// Shuffle
// ..
// Shuffle
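// For example, in the new test below the right side aggregates before the join, so
// the final HashAggregate sits between the Sort and its shuffle stage:
//   SMJ
//     Sort
//       Shuffle
//     Sort
//       HashAggregate
//         Shuffle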
val optimizePlan = optimizeSkewJoin(plan)
val numShuffles = ensureRequirements.apply(optimizePlan).collect {
val ensuredPlan = ensureRequirements.apply(optimizePlan)
val numShuffles = ensuredPlan.collect {
case e: ShuffleExchangeExec => e
}.length

@@ -316,6 +347,23 @@ private object ShuffleStage {
}
Some(ShuffleStageInfo(s, mapStats, partitions))

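// SPARK-32201: recurse when the shuffle stage is not an immediate child -- stop at
// leaf nodes, look through unary nodes, and for binary nodes try the left side
// first, then the right.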
case _: LeafExecNode => None

case UnaryExecNode((_, ShuffleStage(ss: ShuffleStageInfo))) =>
Some(ss)

case b: BinaryExecNode =>
b.left match {
case ShuffleStage(ss: ShuffleStageInfo) =>
Some(ss)
case _ =>
b.right match {
case ShuffleStage(ss: ShuffleStageInfo) =>
Some(ss)
case _ => None
}
}

case _ => None
}
}
@@ -692,23 +692,6 @@ class AdaptiveQueryExecSuite
'id as "value2")
.createOrReplaceTempView("skewData2")

def checkSkewJoin(
joins: Seq[SortMergeJoinExec],
leftSkewNum: Int,
rightSkewNum: Int): Unit = {
assert(joins.size == 1 && joins.head.isSkewJoin)
assert(joins.head.left.collect {
case r: CustomShuffleReaderExec => r
}.head.partitionSpecs.collect {
case p: PartialReducerPartitionSpec => p.reducerIndex
}.distinct.length == leftSkewNum)
assert(joins.head.right.collect {
case r: CustomShuffleReaderExec => r
}.head.partitionSpecs.collect {
case p: PartialReducerPartitionSpec => p.reducerIndex
}.distinct.length == rightSkewNum)
}

// skewed inner join optimization
val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM skewData1 join skewData2 ON key1 = key2")
@@ -730,6 +713,67 @@ class AdaptiveQueryExecSuite
}
}

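// Verifies that the plan contains exactly one skew-optimized sort-merge join and
// that each side splits the expected number of distinct skewed reducer partitions.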
private def checkSkewJoin(
joins: Seq[SortMergeJoinExec],
leftSkewNum: Int,
rightSkewNum: Int): Unit = {
assert(joins.size == 1 && joins.head.isSkewJoin)
assert(joins.head.left.collect {
case r: CustomShuffleReaderExec => r
}.head.partitionSpecs.collect {
case p: PartialReducerPartitionSpec => p.reducerIndex
}.distinct.length == leftSkewNum)
assert(joins.head.right.collect {
case r: CustomShuffleReaderExec => r
}.head.partitionSpecs.collect {
case p: PartialReducerPartitionSpec => p.reducerIndex
}.distinct.length == rightSkewNum)
}

test("SPARK-32201: handle general skew join pattern") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
SQLConf.SHUFFLE_PARTITIONS.key -> "100",
SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
withTempView("skewData1", "skewData2") {
spark
.range(0, 1000, 1, 10)
.select(
when('id < 250, 249)
.when('id >= 750, 1000)
.otherwise('id).as("key1"),
'id as "value1")
.createOrReplaceTempView("skewData1")

spark
.range(0, 1000, 1, 10)
.select(
when('id < 250, 249)
.otherwise('id).as("key2"),
'id as "value2")
.createOrReplaceTempView("skewData2")
val sqlText =
"""
|SELECT * FROM
| skewData1 AS data1
| INNER JOIN
| (
| SELECT skewData2.key2, sum(skewData2.value2) AS sum2
| FROM skewData2 GROUP BY skewData2.key2
| ) AS data2
|ON data1.key1 = data2.key2
|""".stripMargin

val (_, adaptivePlan) = runAdaptiveAndVerifyResult(sqlText)
val innerSmj = findTopLevelSortMergeJoin(adaptivePlan)
checkSkewJoin(innerSmj, 2, 0)
}
}
}

test("SPARK-30291: AQE should catch the exceptions when doing materialize") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
