diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b176598ed8c2c..3641654b89b76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -135,9 +135,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { leftKeys: IndexedSeq[Expression], rightKeys: IndexedSeq[Expression], expectedOrderOfKeys: Seq[Expression], - currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + currentOrderOfKeys: Seq[Expression]): Option[(Seq[Expression], Seq[Expression])] = { if (expectedOrderOfKeys.size != currentOrderOfKeys.size) { - return (leftKeys, rightKeys) + return None + } + + // Check if the current order already satisfies the expected order. + if (expectedOrderOfKeys.zip(currentOrderOfKeys).forall(p => p._1.semanticEquals(p._2))) { + return Some(leftKeys, rightKeys) } // Build a lookup between an expression and the positions its holds in the current key seq. @@ -164,10 +169,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { rightKeysBuffer += rightKeys(index) case _ => // The expression cannot be found, or we have exhausted all indices for that expression. - return (leftKeys, rightKeys) + return None } } - (leftKeysBuffer.toSeq, rightKeysBuffer.toSeq) + Some(leftKeysBuffer.toSeq, rightKeysBuffer.toSeq) } private def reorderJoinKeys( @@ -176,19 +181,48 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { leftPartitioning: Partitioning, rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { - (leftPartitioning, rightPartitioning) match { - case (HashPartitioning(leftExpressions, _), _) => - reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys) - case (_, HashPartitioning(rightExpressions, _)) => - reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) - case _ => - (leftKeys, rightKeys) - } + reorderJoinKeysRecursively( + leftKeys, + rightKeys, + Some(leftPartitioning), + Some(rightPartitioning)) + .getOrElse((leftKeys, rightKeys)) } else { (leftKeys, rightKeys) } } + /** + * Recursively reorders the join keys based on partitioning. It starts reordering the + * join keys to match HashPartitioning on either side, followed by PartitioningCollection. + */ + private def reorderJoinKeysRecursively( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftPartitioning: Option[Partitioning], + rightPartitioning: Option[Partitioning]): Option[(Seq[Expression], Seq[Expression])] = { + (leftPartitioning, rightPartitioning) match { + case (Some(HashPartitioning(leftExpressions, _)), _) => + reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys) + .orElse(reorderJoinKeysRecursively( + leftKeys, rightKeys, None, rightPartitioning)) + case (_, Some(HashPartitioning(rightExpressions, _))) => + reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) + .orElse(reorderJoinKeysRecursively( + leftKeys, rightKeys, leftPartitioning, None)) + case (Some(PartitioningCollection(partitionings)), _) => + partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) => + res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning)) + }.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning)) + case (_, Some(PartitioningCollection(partitionings))) => + partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) => + res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, Some(p))) + }.orElse(None) + case _ => + None + } + } + /** * When the physical operators are created for JOIN, the ordering of join keys is based on order * in which the join keys appear in the user query. That might not match with the output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala new file mode 100644 index 0000000000000..38e68cd2512e7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection} +import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} +import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.test.SharedSparkSession + +class EnsureRequirementsSuite extends SharedSparkSession { + private val exprA = Literal(1) + private val exprB = Literal(2) + private val exprC = Literal(3) + + test("reorder should handle PartitioningCollection") { + val plan1 = DummySparkPlan( + outputPartitioning = PartitioningCollection(Seq( + HashPartitioning(exprA :: exprB :: Nil, 5), + HashPartitioning(exprA :: Nil, 5)))) + val plan2 = DummySparkPlan() + + // Test PartitioningCollection on the left side of join. + val smjExec1 = SortMergeJoinExec( + exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2) + EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprA, exprB)) + assert(rightKeys === Seq(exprB, exprA)) + case other => fail(other.toString) + } + + // Test PartitioningCollection on the right side of join. + val smjExec2 = SortMergeJoinExec( + exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1) + EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => + assert(leftKeys === Seq(exprB, exprA)) + assert(rightKeys === Seq(exprA, exprB)) + case other => fail(other.toString) + } + + // Both sides are PartitioningCollection, but left side cannot be reorderd to match + // and it should fall back to the right side. + val smjExec3 = SortMergeJoinExec( + exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1) + EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => + assert(leftKeys === Seq(exprC, exprA)) + assert(rightKeys === Seq(exprA, exprB)) + case other => fail(other.toString) + } + } + + test("reorder should fallback to the other side partitioning") { + val plan1 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5)) + val plan2 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5)) + + // Test fallback to the right side, which has HashPartitioning. + val smjExec1 = SortMergeJoinExec( + exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2) + EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprB, exprA)) + assert(rightKeys === Seq(exprB, exprC)) + case other => fail(other.toString) + } + + // Test fallback to the right side, which has PartitioningCollection. + val plan3 = DummySparkPlan( + outputPartitioning = PartitioningCollection(Seq(HashPartitioning(exprB :: exprC :: Nil, 5)))) + val smjExec2 = SortMergeJoinExec( + exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan3) + EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => + assert(leftKeys === Seq(exprB, exprA)) + assert(rightKeys === Seq(exprB, exprC)) + case other => fail(other.toString) + } + + // The right side has HashPartitioning, so it is matched first, but no reordering match is + // found, and it should fall back to the left side, which has a PartitioningCollection. + val smjExec3 = SortMergeJoinExec( + exprC :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, plan3, plan1) + EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprB, exprC)) + assert(rightKeys === Seq(exprB, exprA)) + case other => fail(other.toString) + } + } +}