diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 31d984aaaed06..a070c843411ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -380,7 +380,14 @@ case class KeyGroupedPartitioning( } else { // We'll need to find leaf attributes from the partition expressions first. val attributes = expressions.flatMap(_.collectLeaves()) - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // Check that all join keys (required clustering keys) are contained in the partitioning + requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) + } else { + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } } case _ => @@ -389,8 +396,20 @@ case class KeyGroupedPartitioning( } } - override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = - KeyGroupedShuffleSpec(this, distribution) + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = { + val result = KeyGroupedShuffleSpec(this, distribution) + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // If allowing join keys to be a subset of the clustering keys, we should create a new + // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as + // the returned shuffle spec. + val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) + val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, + partitionValues, originalPartitionValues) + result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) + } else { + result + } + } lazy val uniquePartitionValues: Seq[InternalRow] = { partitionValues @@ -403,8 +422,25 @@ case class KeyGroupedPartitioning( object KeyGroupedPartitioning { def apply( expressions: Seq[Expression], - partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { - KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues) + projectionPositions: Seq[Int], + partitionValues: Seq[InternalRow], + originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { + val projectedExpressions = projectionPositions.map(expressions(_)) + val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) + val projectedOriginalPartitionValues = + originalPartitionValues.map(project(expressions, projectionPositions, _)) + + KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length, + projectedPartitionValues, projectedOriginalPartitionValues) + } + + def project( + expressions: Seq[Expression], + positions: Seq[Int], + input: InternalRow): InternalRow = { + val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType)) + .toArray + new GenericInternalRow(projectedValues) } def supportsExpressions(expressions: Seq[Expression]): Boolean = { @@ -717,9 +753,18 @@ case class CoalescedHashShuffleSpec( override def numPartitions: Int = partitions.length } +/** + * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]].
+ * + * @param partitioning key grouped partitioning + * @param distribution the clustered distribution this spec is created for + * @param joinKeyPositions positions of the join keys among the cluster keys. + * This is set if joining on a subset of cluster keys is allowed. + */ case class KeyGroupedShuffleSpec( partitioning: KeyGroupedPartitioning, - distribution: ClusteredDistribution) extends ShuffleSpec { + distribution: ClusteredDistribution, + joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec { /** * A sequence where each element is a set of positions of the partition expression to the cluster @@ -754,7 +799,7 @@ case class KeyGroupedShuffleSpec( // 3.3 each pair of partition expressions at the same index must share compatible // transform functions. // 4. the partition values from both sides are following the same order. - case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) => + case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6ece724e6adac..ad6c7ad1701b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1492,6 +1492,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS = + buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled") + .doc("Whether to allow storage-partitioned join in the case where join keys are " + + "a subset of the partition keys of the source tables. At planning time, " + + "Spark will group the partitions by only those keys that are in the join keys. " + + s"This is currently enabled only if ${REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key} " + + "is false."
+ ) + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -4783,6 +4795,9 @@ class SQLConf extends Serializable with Logging { def v2BucketingShuffleEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED) + def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean = + getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 3ffcccdc410e9..afcc762e636a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -120,7 +120,12 @@ case class BatchScanExec( val newPartValues = spjParams.commonPartitionValues.get.flatMap { case (partValue, numSplits) => Seq.fill(numSplits)(partValue) } - k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues) + val expressions = spjParams.joinKeyPositions match { + case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i)) + case _ => k.expressions + } + k.copy(expressions = expressions, numPartitions = newPartValues.length, + partitionValues = newPartValues) case p => p } } @@ -132,14 +137,29 @@ case class BatchScanExec( // return an empty RDD with 1 partition if dynamic filtering removed the only split sparkContext.parallelize(Array.empty[InternalRow], 1) } else { - var finalPartitions = filteredPartitions - - outputPartitioning match { + val finalPartitions = outputPartitioning match { case p: KeyGroupedPartitioning => - val groupedPartitions = filteredPartitions.map(splits => { - assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) - (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) - }) + assert(spjParams.keyGroupedPartitioning.isDefined) + val expressions = spjParams.keyGroupedPartitioning.get + + // Re-group the input partitions if we are projecting on a subset of join keys + val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match { + case Some(projectPositions) => + val projectedExpressions = projectPositions.map(i => expressions(i)) + val parts = filteredPartitions.flatten.groupBy(part => { + val row = part.asInstanceOf[HasPartitionKey].partitionKey() + val projectedRow = KeyGroupedPartitioning.project( + expressions, projectPositions, row) + InternalRowComparableWrapper(projectedRow, projectedExpressions) + }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq + (parts, projectedExpressions) + case _ => + val groupedParts = filteredPartitions.map(splits => { + assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) + (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) + }) + (groupedParts, expressions) + } // When partially clustered, the input partitions are not grouped by partition // values. Here we'll need to check `commonPartitionValues` and decide how to group @@ -149,12 +169,12 @@ case class BatchScanExec( // should contain. 
val commonPartValuesMap = spjParams.commonPartitionValues .get - .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) + .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. val numSplits = commonPartValuesMap - .get(InternalRowComparableWrapper(partValue, p.expressions)) + .get(InternalRowComparableWrapper(partValue, partExpressions)) assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + "common partition values from Spark plan") @@ -169,37 +189,37 @@ case class BatchScanExec( // sides of a join will have the same number of partitions & splits. splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) } - (InternalRowComparableWrapper(partValue, p.expressions), newSplits) + (InternalRowComparableWrapper(partValue, partExpressions), newSplits) } // Now fill missing partition keys with empty partitions val partitionMapping = nestGroupedPartitions.toMap - finalPartitions = spjParams.commonPartitionValues.get.flatMap { + spjParams.commonPartitionValues.get.flatMap { case (partValue, numSplits) => // Use empty partition for those partition values that are not present. partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), + InternalRowComparableWrapper(partValue, partExpressions), Seq.fill(numSplits)(Seq.empty)) } } else { // either `commonPartitionValues` is not defined, or it is defined but // `applyPartialClustering` is false. val partitionMapping = groupedPartitions.map { case (partValue, splits) => - InternalRowComparableWrapper(partValue, p.expressions) -> splits + InternalRowComparableWrapper(partValue, partExpressions) -> splits }.toMap // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there // could exist duplicated partition values, as partition grouping is not done // at the beginning and postponed to this method. It is important to use unique // partition values here so that grouped partitions won't get duplicated. 
- finalPartitions = p.uniquePartitionValues.map { partValue => + p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) + InternalRowComparableWrapper(partValue, partExpressions), Seq.empty) } } - case _ => + case _ => filteredPartitions } new DataSourceRDD( @@ -234,6 +254,7 @@ case class BatchScanExec( case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, + joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { @@ -247,6 +268,7 @@ case class StoragePartitionJoinParams( } override def hashCode(): Int = Objects.hashCode( + joinKeyPositions: Option[Seq[Int]], commonPartitionValues: Option[Seq[(InternalRow, Int)]], applyPartialClustering: java.lang.Boolean, replicatePartitions: java.lang.Boolean) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index f8e6fd1d0167f..8552c950f6776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -380,7 +380,8 @@ case class EnsureRequirements( val rightSpec = specs(1) var isCompatible = false - if (!conf.v2BucketingPushPartValuesEnabled) { + if (!conf.v2BucketingPushPartValuesEnabled && + !conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { isCompatible = leftSpec.isCompatibleWith(rightSpec) } else { logInfo("Pushing common partition values for storage-partitioned join") @@ -505,10 +506,10 @@ case class EnsureRequirements( } // Now we need to push-down the common partition key to the scan in each child - newLeft = populatePartitionValues( - left, mergedPartValues, applyPartialClustering, replicateLeftSide) - newRight = populatePartitionValues( - right, mergedPartValues, applyPartialClustering, replicateRightSide) + newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions, + applyPartialClustering, replicateLeftSide) + newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions, + applyPartialClustering, replicateRightSide) } } @@ -530,19 +531,21 @@ case class EnsureRequirements( private def populatePartitionValues( plan: SparkPlan, values: Seq[(InternalRow, Int)], + joinKeyPositions: Option[Seq[Int]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => scan.copy( spjParams = scan.spjParams.copy( commonPartitionValues = Some(values), + joinKeyPositions = joinKeyPositions, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions ) ) case node => node.mapChildren(child => populatePartitionValues( - child, values, applyPartialClustering, replicatePartitions)) + child, values, joinKeyPositions, applyPartialClustering, replicatePartitions)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 0853387084e97..cec9f5556d3f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -98,14 +98,17 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val catalystDistribution = physical.ClusteredDistribution( Seq(TransformExpression(YearsFunction, Seq(attr("ts"))))) val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) + val projectedPositions = catalystDistribution.clustering.indices checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, + partitionValues, partitionValues)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, + partitionValues, partitionValues)) } test("non-clustered distribution: no partition") { @@ -1288,35 +1291,292 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val items_partitions = Array(identity("id")) createTable(items, items_schema, items_partitions) sql(s"INSERT INTO testcat.ns.$items VALUES " + - s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + - s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + - s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + - s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + - s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") val purchases_partitions = Array(identity("item_id")) createTable(purchases, purchases_schema, purchases_partitions) sql(s"INSERT INTO testcat.ns.$purchases VALUES " + - s"(1, 42.0, cast('2020-01-01' as timestamp)), " + - s"(1, 44.0, cast('2020-01-15' as timestamp)), " + - s"(1, 45.0, cast('2020-01-15' as timestamp)), " + - s"(2, 11.0, cast('2020-01-01' as timestamp)), " + - s"(3, 19.5, cast('2020-02-01' as timestamp))") + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(1, 44.0, cast('2020-01-15' as timestamp)), " + + s"(1, 45.0, cast('2020-01-15' as timestamp)), " + + s"(2, 11.0, cast('2020-01-01' as timestamp)), " + + s"(3, 19.5, cast('2020-02-01' as timestamp))") Seq(true, false).foreach { pushDownValues => Seq(true, false).foreach { partiallyClustered => { withSQLConf( SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> - partiallyClustered.toString, + partiallyClustered.toString, SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) { // The dynamic filtering effectively filtered out all the partitions val df = sql(s"SELECT p.price from testcat.ns.$items i, testcat.ns.$purchases p " + - "WHERE i.id = p.item_id AND i.price > 50.0") + "WHERE i.id = p.item_id AND i.price > 50.0") checkAnswer(df, Seq.empty) } } } } } + + test("SPARK-44647: test join key is subset of cluster key " + + "with push values and partially-clustered") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id"), identity("data")) + createTable(table1, schema, partition) + sql(s"INSERT INTO 
testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-01' as timestamp)), " + + "(2, 'cc', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp))") + + createTable(table2, schema, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(3, 'yy', cast('2020-01-01' as timestamp)), " + + "(3, 'yy', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(2, 'ww', cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + + val df = sql("SELECT t1.id AS id, t1.data AS t1data, t2.data AS t2data " + + s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + + "ON t1.id = t2.id ORDER BY t1.id, t1data, t2data") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true) => assert(scans == Seq(8, 8)) + // SPJ and not partially-clustered + case (true, false) => assert(scans == Seq(4, 4)) + // No SPJ + case _ => assert(scans == Seq(5, 4)) + } + + checkAnswer(df, Seq( + Row(2, "bb", "ww"), + Row(2, "cc", "ww"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy") + )) + } + } + } + } + } + + test("SPARK-44647: test join key is the second cluster key") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id"), identity("data")) + createTable(table1, schema, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-02' as timestamp)), " + + "(3, 'cc', cast('2020-01-03' as timestamp))") + + createTable(table2, schema, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'aa', cast('2020-01-01' as timestamp)), " + + "(5, 'bb', cast('2020-01-02' as timestamp)), " + + "(6, 'cc', cast('2020-01-03' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> 
"false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> + pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + + val df = sql("SELECT t1.id AS t1id, t2.id as t2id, t1.data AS data " + + s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + + "ON t1.data = t2.data ORDER BY t1id, t1id, data") + + checkAnswer(df, Seq(Row(1, 4, "aa"), Row(2, 5, "bb"), Row(3, 6, "cc"))) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (pushDownValues, allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true, true) => assert(scans == Seq(3, 3)) + // non-SPJ or SPJ/partially-clustered + case _ => assert(scans == Seq(3, 3)) + } + } + } + } + } + } + + test("SPARK-44647: test join key is the second partition key and a transform") { + val items_partitions = Array(bucket(8, "id"), days("arrive_time")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id"), days("time")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(1, 44.0, cast('2020-01-15' as timestamp)), " + + s"(1, 45.0, cast('2020-01-15' as timestamp)), " + + s"(2, 11.0, cast('2020-01-01' as timestamp)), " + + s"(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, " + + "p.item_id, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.arrive_time = p.time " + + "ORDER BY id, purchase_price, p.item_id, sale_price") + + // Currently SPJ for case where join key not same as partition key + // only supported when push-part-values enabled + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case 
(true, true) => assert(scans == Seq(5, 5)) + // SPJ and not partially-clustered + case (true, false) => assert(scans == Seq(3, 3)) + // No SPJ + case _ => assert(scans == Seq(4, 4)) + } + + checkAnswer(df, + Seq( + Row(1, "aa", 40.0, 1, 42.0), + Row(1, "aa", 40.0, 2, 11.0), + Row(1, "aa", 41.0, 1, 44.0), + Row(1, "aa", 41.0, 1, 45.0), + Row(2, "bb", 10.0, 1, 42.0), + Row(2, "bb", 10.0, 2, 11.0), + Row(2, "bb", 10.5, 1, 42.0), + Row(2, "bb", 10.5, 2, 11.0), + Row(3, "cc", 15.5, 3, 19.5) + ) + ) + } + } + } + } + } + + test("SPARK-44647: shuffle one side and join keys are less than partition keys") { + val items_partitions = Array(identity("id"), identity("name")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 89.0, cast('2020-01-03' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { pushdownValues => + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key + -> partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "SPJ should be triggered") + checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), + Row(1, "aa", 30.0, 89.0), + Row(1, "aa", 40.0, 42.0), + Row(1, "aa", 40.0, 89.0), + Row(3, "bb", 10.0, 19.5))) + } + } + } } }
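For reviewers trying the change out locally, here is a minimal, illustrative sketch (not part of the patch) of how the new flag is intended to be exercised end to end, mirroring the tests above. It assumes the test-only InMemoryTableCatalog class is on the classpath and registered as `testcat`; the table names, schema, and data are made up for illustration.

import org.apache.spark.sql.SparkSession

object AllowJoinKeysSubsetExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      // Test-only catalog implementation, assumed to be available on the classpath.
      .config("spark.sql.catalog.testcat",
        "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")
      // Disable broadcast joins so the storage-partitioned join is observable.
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      // Existing prerequisites for storage-partitioned join.
      .config("spark.sql.sources.v2.bucketing.enabled", "true")
      .config("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
      .config("spark.sql.requireAllClusterKeysForCoPartition", "false")
      // The flag added by this patch: join keys may be a subset of the partition keys.
      .config("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled", "true")
      .getOrCreate()

    // Both tables are partitioned on (id, data); the join below only uses `id`.
    spark.sql("CREATE NAMESPACE IF NOT EXISTS testcat.ns")
    spark.sql("CREATE TABLE testcat.ns.t1 (id INT, data STRING) PARTITIONED BY (id, data)")
    spark.sql("CREATE TABLE testcat.ns.t2 (id INT, data STRING) PARTITIONED BY (id, data)")
    spark.sql("INSERT INTO testcat.ns.t1 VALUES (1, 'aa'), (2, 'bb'), (3, 'cc')")
    spark.sql("INSERT INTO testcat.ns.t2 VALUES (1, 'xx'), (2, 'yy'), (3, 'zz')")

    // With the flag enabled, each side is re-grouped on `id` alone and the plan
    // should contain no shuffle (exchange) under the join.
    val df = spark.sql(
      "SELECT t1.id, t1.data, t2.data " +
        "FROM testcat.ns.t1 t1 JOIN testcat.ns.t2 t2 ON t1.id = t2.id")
    df.explain()
    df.show()

    spark.stop()
  }
}

If the flag is flipped back to false in the same session, EnsureRequirements falls back to inserting shuffles on both sides, matching the behavior asserted in the tests above.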
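Separately, a small self-contained sketch of what the new KeyGroupedPartitioning.project helper computes, since it is the core of the re-grouping logic above: given the partition expressions and the positions of the join keys among them, it keeps only those fields of a partition value. The attribute names, types, and row contents are made up for illustration.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow}
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

object ProjectPartitionValueSketch {
  def main(args: Array[String]): Unit = {
    // Partition expressions of a table partitioned on (id, data); illustrative only.
    val expressions = Seq(
      AttributeReference("id", IntegerType)(),
      AttributeReference("data", StringType)())

    // One partition value of that table: (id = 3, data = 'dd').
    val partValue: InternalRow =
      new GenericInternalRow(Array[Any](3, UTF8String.fromString("dd")))

    // Joining only on `id` means the single join key sits at position 0 of the
    // partition keys, so the projected partition value keeps just that field.
    val projected = KeyGroupedPartitioning.project(expressions, Seq(0), partValue)
    assert(projected.numFields == 1 && projected.getInt(0) == 3)
  }
}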