Skip to content

Commit

Permalink
[SPARK-48065][SQL] SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
When spark.sql.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled is true, relax the check in KeyGroupedPartitioning.satisfies0(distribution): instead of requiring every clustering key (here, the join keys) to be among the partition keys, require only that the two sets overlap.

  ### Why are the changes needed?
When spark.sql.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled is true, SPJ no longer triggers if there are more join keys than partition keys, even though SPJ is supported in that same case when the flag is false.

  ### Does this PR introduce _any_ user-facing change?
No

  ### How was this patch tested?
Added tests in KeyGroupedPartitioningSuite

 ### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#46325 from szehon-ho/fix_spj_less_join_key.

Authored-by: Szehon Ho <szehon.apache@gmail.com>
Signed-off-by: Chao Sun <chao@openai.com>
  • Loading branch information
szehon-ho authored and JacobZheng0927 committed May 11, 2024
1 parent db4cf8c commit 7b433c5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,9 @@ case class KeyGroupedPartitioning(
val attributes = expressions.flatMap(_.collectLeaves())

if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// check that all join keys (required clustering keys) contained in partitioning
requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) &&
// check that join keys (required clustering keys)
// overlap with partition keys (KeyGroupedPartitioning attributes)
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
} else {
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,66 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
}

test("SPARK-48065: SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict") {
  // Both tables are partitioned by identity(id) only, while the query below joins on
  // (id, data): the join keys are a superset of the partition keys.
  val table1 = "tab1e1"
  val table2 = "table2"
  val partition = Array(identity("id"))

  // Renders (id, data) pairs as a SQL VALUES list; every row uses the same timestamp literal.
  def valuesClause(rows: (Int, String)*): String =
    rows
      .map { case (id, data) => s"($id, '$data', cast('2020-01-01' as timestamp))" }
      .mkString(", ")

  createTable(table1, columns, partition)
  sql(s"INSERT INTO testcat.ns.$table1 VALUES " + valuesClause(
    1 -> "aa",
    2 -> "bb",
    2 -> "cc",
    3 -> "dd",
    3 -> "dd",
    3 -> "ee",
    3 -> "ee"))

  createTable(table2, columns, partition)
  sql(s"INSERT INTO testcat.ns.$table2 VALUES " + valuesClause(
    4 -> "zz",
    4 -> "zz",
    3 -> "dd",
    3 -> "dd",
    3 -> "xx",
    3 -> "xx",
    2 -> "ww"))

  // Exercise every combination of value push-down and partially-clustered distribution.
  for {
    pushDownValues <- Seq(true, false)
    partiallyClustered <- Seq(true, false)
  } {
    withSQLConf(
      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
        partiallyClustered.toString,
      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
      val df = sql(
        s"""
           |${selectWithMergeJoinHint("t1", "t2")}
           |t1.id AS id, t1.data AS t1data, t2.data AS t2data
           |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
           |ON t1.id = t2.id AND t1.data = t2.data ORDER BY t1.id, t1data, t2data
           |""".stripMargin)
      val plan = df.queryExecution.executedPlan

      // No shuffle in the executed plan means the storage-partitioned join was used.
      assert(collectShuffles(plan).isEmpty, "SPJ should be triggered")

      // Each of the two scans is expected to expose 8 input partitions when partially
      // clustered and 4 otherwise.
      // NOTE(review): 4 presumably corresponds to the distinct id values 1..4 across both
      // tables -- confirm against the scan implementation.
      val expectedPartitions = if (partiallyClustered) 8 else 4
      val scans = collectScans(plan).map(_.inputRDD.partitions.length)
      assert(scans == Seq(expectedPartitions, expectedPartitions))

      // Only (3, 'dd') appears in both tables, twice on each side, giving 2 x 2 = 4 rows.
      checkAnswer(df, Seq.fill(4)(Row(3, "dd", "dd")))
    }
  }
}

test("SPARK-44647: test join key is subset of cluster key " +
"with push values and partially-clustered") {
val table1 = "tab1e1"
Expand Down

0 comments on commit 7b433c5

Please sign in to comment.