From 6a250e7910c691e7b74989e0b6ef9418a6a7c60b Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Tue, 13 Aug 2024 13:07:51 +0800 Subject: [PATCH] [SPARK-49179][SQL] Fix v2 multi bucketed inner joins throw AssertionError ### What changes were proposed in this pull request? For an SMJ with an inner join type, the output partitioning just wraps the left and right output partitionings in a `PartitioningCollection`, so not every partitioning in the collection is guaranteed to satisfy the required clustering. `createKeyGroupedShuffleSpec` now returns the first `KeyGroupedShuffleSpec` that can be created from the collection instead of asserting that either all or none of its members produce one. ### Why are the changes needed? To fix the exception thrown when a query contains multiple bucketed inner joins, for example: ```sql SELECT * FROM testcat.ns.t1 JOIN testcat.ns.t2 ON t1.id = t2.id JOIN testcat.ns.t3 ON t1.id = t3.id ``` ``` Cause: java.lang.AssertionError: assertion failed at scala.Predef$.assert(Predef.scala:264) at org.apache.spark.sql.execution.exchange.EnsureRequirements.createKeyGroupedShuffleSpec(EnsureRequirements.scala:642) at org.apache.spark.sql.execution.exchange.EnsureRequirements.$anonfun$checkKeyGroupCompatible$1(EnsureRequirements.scala:385) at scala.collection.immutable.List.map(List.scala:247) at scala.collection.immutable.List.map(List.scala:79) at org.apache.spark.sql.execution.exchange.EnsureRequirements.checkKeyGroupCompatible(EnsureRequirements.scala:382) at org.apache.spark.sql.execution.exchange.EnsureRequirements.checkKeyGroupCompatible(EnsureRequirements.scala:364) at org.apache.spark.sql.execution.exchange.EnsureRequirements.org$apache$spark$sql$execution$exchange$EnsureRequirements$$ensureDistributionAndOrdering(EnsureRequirements.scala:166) at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$1.applyOrElse(EnsureRequirements.scala:714) at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$1.applyOrElse(EnsureRequirements.scala:689) at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUpWithPruning$4(TreeNode.scala:528) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:84) at 
org.apache.spark.sql.catalyst.trees.TreeNode.transformUpWithPruning(TreeNode.scala:528) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:497) at org.apache.spark.sql.execution.exchange.EnsureRequirements.apply(EnsureRequirements.scala:689) at org.apache.spark.sql.execution.exchange.EnsureRequirements.apply(EnsureRequirements.scala:51) at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec$.$anonfun$applyPhysicalRules$2(AdaptiveSparkPlanExec.scala:882) ``` ### Does this PR introduce _any_ user-facing change? yes, it's a bug fix ### How was this patch tested? add test ### Was this patch authored or co-authored using generative AI tooling? no Closes #47683 from ulysses-you/SPARK-49179. Authored-by: ulysses-you Signed-off-by: youxiduo (cherry picked from commit 8133294d6c2c925b97b3dbfcd3aa5e0762882d5f) Signed-off-by: youxiduo --- .../exchange/EnsureRequirements.scala | 7 +++--- .../KeyGroupedPartitioningSuite.scala | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 42c880e7c6262..ee0ea11816f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -550,7 +550,7 @@ case class EnsureRequirements( private def createKeyGroupedShuffleSpec( partitioning: Partitioning, distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = { - def check(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { + def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { val attributes = partitioning.expressions.flatMap(_.collectLeaves()) val clustering = distribution.clustering @@ -570,11 +570,10 @@ case class EnsureRequirements( } 
partitioning match { - case p: KeyGroupedPartitioning => check(p) + case p: KeyGroupedPartitioning => tryCreate(p) case PartitioningCollection(partitionings) => val specs = partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution)) - assert(specs.forall(_.isEmpty) || specs.forall(_.isDefined)) - specs.head + specs.filter(_.isDefined).map(_.get).headOption case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 6b07c77aefb60..0718f090cff46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -330,6 +330,28 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { .add("price", FloatType) .add("time", TimestampType) + test("SPARK-49179: Fix v2 multi bucketed inner joins throw AssertionError") { + val cols = Array( + Column.create("id", LongType), + Column.create("name", StringType)) + val buckets = Array(bucket(8, "id")) + + withTable("t1", "t2", "t3") { + Seq("t1", "t2", "t3").foreach { t => + createTable(t, cols, buckets) + sql(s"INSERT INTO testcat.ns.$t VALUES (1, 'aa'), (2, 'bb'), (3, 'cc')") + } + val df = sql( + """ + |SELECT t1.id, t2.id, t3.name FROM testcat.ns.t1 + |JOIN testcat.ns.t2 ON t1.id = t2.id + |JOIN testcat.ns.t3 ON t1.id = t3.id + |""".stripMargin) + checkAnswer(df, Seq(Row(1, 1, "aa"), Row(2, 2, "bb"), Row(3, 3, "cc"))) + assert(collectShuffles(df.queryExecution.executedPlan).isEmpty) + } + } + test("partitioned join: join with two partition keys and matching & sorted partitions") { val items_partitions = Array(bucket(8, "id"), days("arrive_time")) createTable(items, items_schema, items_partitions)