From 8940f6e48f056e65fa86a7a7d5bf17f3f3d2bb14 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 21 Jun 2024 17:20:54 -0700 Subject: [PATCH] [SPARK-48613][SQL] SPJ: Support auto-shuffle one side + less join keys than partition keys ### What changes were proposed in this pull request? SPJ: Support auto-shuffle one side + less join keys than partition keys ### Why are the changes needed? This is the last planned scenario for SPJ not yet supported. ### How was this patch tested? Update existing unit test in KeyGroupedPartitionSuite ### Was this patch authored or co-authored using generative AI tooling? No. --- .../plans/physical/partitioning.scala | 25 ++++++++----------- .../datasources/v2/BatchScanExec.scala | 12 ++++++++- .../exchange/EnsureRequirements.scala | 20 ++++++++++++++- .../KeyGroupedPartitioningSuite.scala | 3 +-- 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 19595eef10b34..89cf70eab3417 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -434,8 +434,13 @@ object KeyGroupedPartitioning { val projectedOriginalPartitionValues = originalPartitionValues.map(project(expressions, projectionPositions, _)) - KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length, - projectedPartitionValues, projectedOriginalPartitionValues) + val finalPartitionValues = projectedPartitionValues + .map(InternalRowComparableWrapper(_, projectedExpressions)) + .distinct + .map(_.row) + + KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length, + finalPartitionValues, projectedOriginalPartitionValues) } def project( @@ -871,20 +876,12 @@ case class KeyGroupedShuffleSpec( if (results.forall(p => p.isEmpty)) None else Some(results) } - override def canCreatePartitioning: Boolean = { - // Allow one side shuffle for SPJ for now only if partially-clustered is not enabled - // and for join keys less than partition keys only if transforms are not enabled. - val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - e: Expression => e.isInstanceOf[AttributeReference] - } else { - e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression] - } + override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled && - partitioning.expressions.forall(checkExprType) - } - - + partitioning.expressions.forall{e => + e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression] + } override def createPartitioning(clustering: Seq[Expression]): Partitioning = { val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index f949dbf71a371..d5f5c4332555f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -130,6 +130,15 @@ case class BatchScanExec( } k.copy(expressions = expressions, numPartitions = newPartValues.length, partitionValues = newPartValues) + case k: KeyGroupedPartitioning if spjParams.joinKeyPositions.isDefined => + val expressions = spjParams.joinKeyPositions.get.map(i => k.expressions(i)) + val newPartValues = k.partitionValues.map{r => + val projectedRow = KeyGroupedPartitioning.project(expressions, + spjParams.joinKeyPositions.get, r) + InternalRowComparableWrapper(projectedRow, expressions) + }.distinct.map(_.row) + k.copy(expressions = expressions, numPartitions = newPartValues.length, + partitionValues = newPartValues) case p => p } } @@ -279,7 +288,8 @@ case class StoragePartitionJoinParams( case other: StoragePartitionJoinParams => this.commonPartitionValues == other.commonPartitionValues && this.replicatePartitions == other.replicatePartitions && - this.applyPartialClustering == other.applyPartialClustering + this.applyPartialClustering == other.applyPartialClustering && + this.joinKeyPositions == other.joinKeyPositions case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 67d879bdd8bf4..0a1075289424d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -175,7 +175,11 @@ case class EnsureRequirements( child case ((child, dist), idx) => if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) { - child + bestSpecOpt match { + case Some(KeyGroupedShuffleSpec(_, _, Some(joinKeyPositions))) => + populateJoinKeyPositions(child, Some(joinKeyPositions)) + case _ => child + } } else { val newPartitioning = bestSpecOpt.map { bestSpec => // Use the best spec to create a new partitioning to re-shuffle this child @@ -578,6 +582,20 @@ case class EnsureRequirements( child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions)) } + + private def populateJoinKeyPositions(plan: SparkPlan, + joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match { + case scan: BatchScanExec => + scan.copy( + spjParams = scan.spjParams.copy( + joinKeyPositions = joinKeyPositions + ) + ) + case node => + node.mapChildren(child => populateJoinKeyPositions( + child, joinKeyPositions)) + } + private def reduceCommonPartValues( commonPartValues: Seq[(InternalRow, Int)], expressions: Seq[Expression], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index a5de5bc1913b9..155e6ded43bc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -2145,8 +2145,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) - assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" + - "less join keys than partition keys for now.") + assert(shuffles.size == 1, "SPJ should be triggered") checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), Row(1, "aa", 30.0, 89.0), Row(1, "aa", 40.0, 42.0),