diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 2b3675d0a9bfb..bc90a869fd9b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -547,7 +547,7 @@ case class EnsureRequirements( private def createKeyGroupedShuffleSpec( partitioning: Partitioning, distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = { - def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { + def check(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { val attributes = partitioning.expressions.flatMap(_.collectLeaves()) val clustering = distribution.clustering @@ -567,10 +567,11 @@ case class EnsureRequirements( } partitioning match { - case p: KeyGroupedPartitioning => tryCreate(p) + case p: KeyGroupedPartitioning => check(p) case PartitioningCollection(partitionings) => val specs = partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution)) - specs.filter(_.isDefined).map(_.get).headOption + assert(specs.forall(_.isEmpty) || specs.forall(_.isDefined)) + specs.head case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 06485f1b4ce40..cf76f6ca32cad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -328,28 +328,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { .add("price", FloatType) .add("time", TimestampType) - test("SPARK-49179: Fix v2 multi bucketed inner joins throw AssertionError") { - val cols = Array( - Column.create("id", LongType), - Column.create("name", StringType)) - val buckets = Array(bucket(8, "id")) - - withTable("t1", "t2", "t3") { - Seq("t1", "t2", "t3").foreach { t => - createTable(t, cols, buckets) - sql(s"INSERT INTO testcat.ns.$t VALUES (1, 'aa'), (2, 'bb'), (3, 'cc')") - } - val df = sql( - """ - |SELECT t1.id, t2.id, t3.name FROM testcat.ns.t1 - |JOIN testcat.ns.t2 ON t1.id = t2.id - |JOIN testcat.ns.t3 ON t1.id = t3.id - |""".stripMargin) - checkAnswer(df, Seq(Row(1, 1, "aa"), Row(2, 2, "bb"), Row(3, 3, "cc"))) - assert(collectShuffles(df.queryExecution.executedPlan).isEmpty) - } - } - test("partitioned join: join with two partition keys and matching & sorted partitions") { val items_partitions = Array(bucket(8, "id"), days("arrive_time")) createTable(items, items_schema, items_partitions)