Skip to content

Commit

Permalink
### What changes were proposed in this pull request?
Browse files Browse the repository at this point in the history
SPJ: Support auto-shuffling one side when there are fewer join keys than partition keys

  ### Why are the changes needed?

This is the last planned SPJ scenario that is not yet supported.

  ### How was this patch tested?
Updated the existing unit test in KeyGroupedPartitioningSuite.

  ### Was this patch authored or co-authored using generative AI tooling?
 No.
  • Loading branch information
szehon-ho committed Jun 22, 2024
1 parent 80bba44 commit a2c1305
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,13 @@ object KeyGroupedPartitioning {
val projectedOriginalPartitionValues =
originalPartitionValues.map(project(expressions, projectionPositions, _))

KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length,
projectedPartitionValues, projectedOriginalPartitionValues)
val finalPartitionValues = projectedPartitionValues
.map(InternalRowComparableWrapper(_, projectedExpressions))
.distinct
.map(_.row)

KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
finalPartitionValues, projectedOriginalPartitionValues)
}

def project(
Expand Down Expand Up @@ -871,20 +876,12 @@ case class KeyGroupedShuffleSpec(
if (results.forall(p => p.isEmpty)) None else Some(results)
}

override def canCreatePartitioning: Boolean = {
// Allow one side shuffle for SPJ for now only if partially-clustered is not enabled
// and for join keys less than partition keys only if transforms are not enabled.
val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
e: Expression => e.isInstanceOf[AttributeReference]
} else {
e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
}
override def canCreatePartitioning: Boolean =
SQLConf.get.v2BucketingShuffleEnabled &&
!SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
partitioning.expressions.forall(checkExprType)
}


partitioning.expressions.forall{e =>
e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
}

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ case class BatchScanExec(
}
k.copy(expressions = expressions, numPartitions = newPartValues.length,
partitionValues = newPartValues)
case k: KeyGroupedPartitioning if spjParams.joinKeyPositions.isDefined =>
val expressions = spjParams.joinKeyPositions.get.map(i => k.expressions(i))
val newPartValues = k.partitionValues.map{r =>
val projectedRow = KeyGroupedPartitioning.project(expressions,
spjParams.joinKeyPositions.get, r)
InternalRowComparableWrapper(projectedRow, expressions)
}.distinct.map(_.row)
k.copy(expressions = expressions, numPartitions = newPartValues.length,
partitionValues = newPartValues)
case p => p
}
}
Expand Down Expand Up @@ -279,7 +288,8 @@ case class StoragePartitionJoinParams(
case other: StoragePartitionJoinParams =>
this.commonPartitionValues == other.commonPartitionValues &&
this.replicatePartitions == other.replicatePartitions &&
this.applyPartialClustering == other.applyPartialClustering
this.applyPartialClustering == other.applyPartialClustering &&
this.joinKeyPositions == other.joinKeyPositions
case _ =>
false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ case class EnsureRequirements(
child
case ((child, dist), idx) =>
if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
child
bestSpecOpt match {
case Some(KeyGroupedShuffleSpec(partitioning, _, Some(joinKeyPositions))) =>
populateJoinKeyPositions(child, Some(joinKeyPositions))
case _ => child
}
} else {
val newPartitioning = bestSpecOpt.map { bestSpec =>
// Use the best spec to create a new partitioning to re-shuffle this child
Expand Down Expand Up @@ -578,6 +582,20 @@ case class EnsureRequirements(
child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
}


/**
 * Returns a copy of `plan` in which each [[BatchScanExec]] reached by the traversal has the
 * `joinKeyPositions` field of its `spjParams` replaced with the supplied value.
 *
 * Non-scan nodes are rebuilt by applying this method to each of their children. A
 * [[BatchScanExec]] is rewritten directly and the traversal does not descend below it.
 *
 * @param plan             root of the physical plan subtree to rewrite
 * @param joinKeyPositions value to store in `spjParams.joinKeyPositions` of each scan reached
 * @return the rewritten plan
 */
private def populateJoinKeyPositions(
    plan: SparkPlan,
    joinKeyPositions: Option[Seq[Int]]): SparkPlan = {
  plan match {
    case batchScan: BatchScanExec =>
      // Only the join key positions change; every other SPJ parameter is carried over.
      val updatedParams = batchScan.spjParams.copy(joinKeyPositions = joinKeyPositions)
      batchScan.copy(spjParams = updatedParams)
    case other =>
      other.mapChildren(populateJoinKeyPositions(_, joinKeyPositions))
  }
}

private def reduceCommonPartValues(
commonPartValues: Seq[(InternalRow, Int)],
expressions: Seq[Expression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2145,8 +2145,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" +
"less join keys than partition keys for now.")
assert(shuffles.size == 1, "SPJ should be triggered")
checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
Row(1, "aa", 30.0, 89.0),
Row(1, "aa", 40.0, 42.0),
Expand Down

0 comments on commit a2c1305

Please sign in to comment.