diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java new file mode 100644 index 0000000000000..0b9ed7fb681ea --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java @@ -0,0 +1,16 @@ +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.annotation.Evolving; + +/** + * A 'reducer' for output of user-defined functions. + * + * A user-defined function f_source(x) is 'reducible' on another user-defined function f_target(x), + * if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x. + * @param <T> function output type + * @since 4.0.0 + */ +@Evolving +public interface Reducer<T> { + T reduce(T arg1); +}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java new file mode 100644 index 0000000000000..39103d063f351 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -0,0 +1,25 @@ +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.annotation.Evolving; +import scala.Option; + +/** + * Base class for user-defined functions that can be 'reduced' on another function. + * + * A function f_source(x) is 'reducible' on another function f_target(x) if + * there exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x. + * + * @since 4.0.0 + */ +@Evolving +public interface ReducibleFunction<T, A> extends ScalarFunction<T> { + + /** + * If this function is 'reducible' on another function, return the {@link Reducer} function. + * @param other other function + * @param thisArgument argument for this function instance + * @param otherArgument argument for other function instance + * @return a reduction function if it is reducible, none if not + */ + Option<Reducer<A>> reducer(ReducibleFunction<?, ?> other, Option<?> thisArgument, Option<?> otherArgument); +}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 8412de554b711..cc5810993a9fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.connector.catalog.functions.BoundFunction +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ReducibleFunction} import org.apache.spark.sql.types.DataType /** @@ -54,6 +54,31 @@ case class TransformExpression( false } + /** + * Whether this [[TransformExpression]]'s function is compatible with the `other` + * [[TransformExpression]]'s function. + * + * This is true if both are instances of [[ReducibleFunction]] and there exists a [[Reducer]] r(x) + * such that r(t1(x)) = t2(x), or r(t2(x)) = t1(x), for all input x.
+ * + * @param other the transform expression to compare to + * @return true if compatible, false if not + */ + def isCompatible(other: TransformExpression): Boolean = { + if (isSameFunction(other)) { + true + } else { + (function, other.function) match { + case (f: ReducibleFunction[Any, Any] @unchecked, + o: ReducibleFunction[Any, Any] @unchecked) => + val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt) + val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt) + reducer.isDefined || otherReducer.isDefined + case _ => false + } + } + } + override def dataType: DataType = function.resultType() override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index c98a2a92a3abb..3c1cc5e2e9e92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, IntegerType} @@ -635,6 +636,22 @@ trait ShuffleSpec { */ def createPartitioning(clustering: Seq[Expression]): Partitioning = throw SparkUnsupportedOperationException() + + /** + * Return a set of [[Reducer]] for the partition expressions of this shuffle spec, + * on the partition expressions of another shuffle spec. + *
+ * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is + * 'reducible' on the corresponding partition expression function of the other shuffle spec. + *
+ * If a value is returned, there must be one Option[[Reducer]] per partition expression. + * A None value in the set indicates that the particular partition expression is not reducible + * on the corresponding expression on the other shuffle spec. + *
+ * Returning none also indicates that none of the partition expressions can be reduced on the + * corresponding expression on the other shuffle spec. + */ + def reducers(spec: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = None } case object SinglePartitionShuffleSpec extends ShuffleSpec { @@ -829,20 +846,60 @@ case class KeyGroupedShuffleSpec( } } + override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && + // Only support partition expressions are AttributeReference for now + partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { + KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) + } + + override def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = { + other match { + case otherSpec: KeyGroupedShuffleSpec => + val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map { + case (e1: TransformExpression, e2: TransformExpression) + if e1.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked] + && e2.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked] => + e1.function.asInstanceOf[ReducibleFunction[Any, Any]].reducer( + e2.function.asInstanceOf[ReducibleFunction[Any, Any]], + e1.numBucketsOpt.map(a => a.asInstanceOf[Any]), + e2.numBucketsOpt.map(a => a.asInstanceOf[Any])) + case (_, _) => None + } + + // optimize to not return a value, if none of the partition expressions need reducing + if (results.forall(p => p.isEmpty)) None else Some(results) + case _ => None + } + } + private def isExpressionCompatible(left: Expression, right: Expression): Boolean = (left, right) match { case (_: LeafExpression, _: LeafExpression) => true case (left: TransformExpression, right: TransformExpression) => - left.isSameFunction(right) + if (SQLConf.get.v2BucketingPushPartValuesEnabled && + !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled && + SQLConf.get.v2BucketingAllowCompatibleTransforms) { + left.isCompatible(right) + } else { + left.isSameFunction(right) + } case _ => false } +} - override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && - // Only support partition expressions are AttributeReference for now - partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) - - override def createPartitioning(clustering: Seq[Expression]): Partitioning = { - KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) +object KeyGroupedShuffleSpec { + def reducePartitionValue(row: InternalRow, + expressions: Seq[Expression], + reducers: Seq[Option[Reducer[Any]]]): + InternalRowComparableWrapper = { + val partitionVals = row.toSeq(expressions.map(_.dataType)) + val reducedRow = partitionVals.zip(reducers).map{ + case (v, Some(reducer)) => reducer.reduce(v) + case (v, _) => v + }.toArray + InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 04b392d0c44f4..447ccd326f6d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1537,6 +1537,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS = + buildConf("spark.sql.sources.v2.bucketing.allow.enabled") + 
.doc("Whether to allow storage-partition join in the case where the partition transforms" + + "are compatible but not identical. This config requires both " + + s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " + + s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + + "to be disabled." + ) + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -5201,6 +5213,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean = getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS) + def v2BucketingAllowCompatibleTransforms: Boolean = + getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 7cce599040189..3772d1f9f8847 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.connector.read._ import org.apache.spark.util.ArrayImplicits._ @@ -164,6 +165,18 @@ case class BatchScanExec( (groupedParts, expressions) } + // Also re-group the partitions if we are reducing compatible partition expressions + val finalGroupedPartitions = spjParams.reducers match { + case Some(reducers) => + val result = groupedPartitions.groupBy { case (row, _) => + KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers) + }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq + val rowOrdering = RowOrdering.createNaturalAscendingOrdering( + expressions.map(_.dataType)) + result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + case _ => groupedPartitions + } + // When partially clustered, the input partitions are not grouped by partition // values. Here we'll need to check `commonPartitionValues` and decide how to group // and replicate splits within a partition. @@ -174,7 +187,7 @@ case class BatchScanExec( .get .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap - val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => + val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. 
val numSplits = commonPartValuesMap .get(InternalRowComparableWrapper(partValue, partExpressions)) @@ -207,7 +220,7 @@ case class BatchScanExec( } else { // either `commonPartitionValues` is not defined, or it is defined but // `applyPartialClustering` is false. - val partitionMapping = groupedPartitions.map { case (partValue, splits) => + val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) => InternalRowComparableWrapper(partValue, partExpressions) -> splits }.toMap @@ -224,7 +237,6 @@ case class BatchScanExec( case _ => filteredPartitions } - new DataSourceRDD( sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics) } @@ -259,6 +271,7 @@ case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, + reducers: Option[Seq[Option[Reducer[Any]]]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { override def equals(other: Any): Boolean = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 2a7c1206bb410..1c4c797861b2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} @@ -505,11 +506,28 @@ case class EnsureRequirements( } } - // Now we need to push-down the common partition key to the scan in each child - newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions, - applyPartialClustering, replicateLeftSide) - newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions, - applyPartialClustering, replicateRightSide) + // in case of compatible but not identical partition expressions, we apply 'reduce' + // transforms to group one side's partitions as well as the common partition values + val leftReducers = leftSpec.reducers(rightSpec) + val rightReducers = rightSpec.reducers(leftSpec) + + if (leftReducers.isDefined || rightReducers.isDefined) { + mergedPartValues = reduceCommonPartValues(mergedPartValues, + leftSpec.partitioning.expressions, + leftReducers) + mergedPartValues = reduceCommonPartValues(mergedPartValues, + rightSpec.partitioning.expressions, + rightReducers) + val rowOrdering = RowOrdering + .createNaturalAscendingOrdering(partitionExprs.map(_.dataType)) + mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + } + + // Now we need to push-down the common partition information to the scan in each child + newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions, + leftReducers, applyPartialClustering, replicateLeftSide) + newRight = populateCommonPartitionInfo(right, mergedPartValues, rightSpec.joinKeyPositions, + rightReducers, applyPartialClustering, 
replicateRightSide) } } @@ -527,11 +545,12 @@ case class EnsureRequirements( joinType == LeftAnti || joinType == LeftOuter } - // Populate the common partition values down to the scan nodes - private def populatePartitionValues( + // Populate the common partition information down to the scan nodes + private def populateCommonPartitionInfo( plan: SparkPlan, values: Seq[(InternalRow, Int)], joinKeyPositions: Option[Seq[Int]], + reducers: Option[Seq[Option[Reducer[Any]]]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => @@ -539,13 +558,25 @@ case class EnsureRequirements( spjParams = scan.spjParams.copy( commonPartitionValues = Some(values), joinKeyPositions = joinKeyPositions, + reducers = reducers, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions ) ) case node => - node.mapChildren(child => populatePartitionValues( - child, values, joinKeyPositions, applyPartialClustering, replicatePartitions)) + node.mapChildren(child => populateCommonPartitionInfo( + child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions)) + } + + private def reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)], + expressions: Seq[Expression], + reducers: Option[Seq[Option[Reducer[Any]]]]) = { + reducers match { + case Some(reducers) => commonPartValues.groupBy { case (row, _) => + KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) + }.map{ case(wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq + case _ => commonPartValues + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index e6448d4d80fda..2a49679719d92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -68,6 +68,10 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { .add("id", IntegerType) .add("data", StringType) .add("ts", TimestampType) + private val schema2 = new StructType() + .add("store_id", IntegerType) + .add("dept_id", IntegerType) + .add("data", StringType) test("clustered distribution: output partitioning should be KeyGroupedPartitioning") { val partitions: Array[Transform] = Array(Expressions.years("ts")) @@ -1310,6 +1314,312 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-47094: Support compatible buckets") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq( + ((2, 4), (4, 2)), + ((4, 2), (2, 4)), + ((2, 2), (4, 6)), + ((6, 2), (2, 2))).foreach { + case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) => + catalog.clearTables() + + val partition1 = Array(bucket(table1buckets1, "store_id"), + bucket(table1buckets2, "dept_id")) + val partition2 = Array(bucket(table2buckets1, "store_id"), + bucket(table2buckets2, "dept_id")) + + Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) => + createTable(tab, schema2, part) + val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " + + "(0, 0, 'aa'), " + + "(0, 0, 'ab'), " + // duplicate partition key + "(0, 1, 'ac'), " + + "(0, 2, 'ad'), " + + "(0, 3, 'ae'), " + + "(0, 4, 'af'), " + + "(0, 5, 'ag'), " + + "(1, 0, 'ah'), " + + "(1, 0, 'ai'), " + // duplicate partition key + "(1, 1, 'aj'), " + + "(1, 2, 'ak'), 
" + + "(1, 3, 'al'), " + + "(1, 4, 'am'), " + + "(1, 5, 'an'), " + + "(2, 0, 'ao'), " + + "(2, 0, 'ap'), " + // duplicate partition key + "(2, 1, 'aq'), " + + "(2, 2, 'ar'), " + + "(2, 3, 'as'), " + + "(2, 4, 'at'), " + + "(2, 5, 'au'), " + + "(3, 0, 'av'), " + + "(3, 0, 'aw'), " + // duplicate partition key + "(3, 1, 'ax'), " + + "(3, 2, 'ay'), " + + "(3, 3, 'az'), " + + "(3, 4, 'ba'), " + + "(3, 5, 'bb'), " + + "(4, 0, 'bc'), " + + "(4, 0, 'bd'), " + // duplicate partition key + "(4, 1, 'be'), " + + "(4, 2, 'bf'), " + + "(4, 3, 'bg'), " + + "(4, 4, 'bh'), " + + "(4, 5, 'bi'), " + + "(5, 0, 'bj'), " + + "(5, 0, 'bk'), " + // duplicate partition key + "(5, 1, 'bl'), " + + "(5, 2, 'bm'), " + + "(5, 3, 'bn'), " + + "(5, 4, 'bo'), " + + "(5, 5, 'bp')" + + // additional unmatched partitions to test push down + val finalStr = if (tab == table1) { + insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')" + } else { + insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')" + } + + sql(finalStr) + } + + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString, + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t1.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + val expectedBuckets = Math.min(table1buckets1, table2buckets1) * + Math.min(table1buckets2, table2buckets2) + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 0, "aa", "aa"), + Row(0, 0, "aa", "ab"), + Row(0, 0, "ab", "aa"), + Row(0, 0, "ab", "ab"), + Row(0, 1, "ac", "ac"), + Row(0, 2, "ad", "ad"), + Row(0, 3, "ae", "ae"), + Row(0, 4, "af", "af"), + Row(0, 5, "ag", "ag"), + Row(1, 0, "ah", "ah"), + Row(1, 0, "ah", "ai"), + Row(1, 0, "ai", "ah"), + Row(1, 0, "ai", "ai"), + Row(1, 1, "aj", "aj"), + Row(1, 2, "ak", "ak"), + Row(1, 3, "al", "al"), + Row(1, 4, "am", "am"), + Row(1, 5, "an", "an"), + Row(2, 0, "ao", "ao"), + Row(2, 0, "ao", "ap"), + Row(2, 0, "ap", "ao"), + Row(2, 0, "ap", "ap"), + Row(2, 1, "aq", "aq"), + Row(2, 2, "ar", "ar"), + Row(2, 3, "as", "as"), + Row(2, 4, "at", "at"), + Row(2, 5, "au", "au"), + Row(3, 0, "av", "av"), + Row(3, 0, "av", "aw"), + Row(3, 0, "aw", "av"), + Row(3, 0, "aw", "aw"), + Row(3, 1, "ax", "ax"), + Row(3, 2, "ay", "ay"), + Row(3, 3, "az", "az"), + Row(3, 4, "ba", "ba"), + Row(3, 5, "bb", "bb"), + Row(4, 0, "bc", "bc"), + Row(4, 0, "bc", "bd"), + Row(4, 0, "bd", "bc"), + Row(4, 0, "bd", "bd"), + Row(4, 1, "be", "be"), + Row(4, 2, "bf", "bf"), + Row(4, 3, "bg", "bg"), + Row(4, 4, "bh", "bh"), + Row(4, 5, "bi", "bi"), + Row(5, 0, "bj", "bj"), + Row(5, 0, "bj", "bk"), + Row(5, 0, "bk", "bj"), + Row(5, 0, "bk", "bk"), + Row(5, 1, "bl", "bl"), + Row(5, 2, "bm", "bm"), + Row(5, 3, "bn", "bn"), + Row(5, 4, "bo", "bo"), + Row(5, 5, "bp", "bp"), + )) + } + } + } + } + + test("SPARK-47094: Support compatible buckets with less join keys than partition keys") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq((2, 4), (4, 2), (2, 6), (6, 2)).foreach { + case (table1buckets, table2buckets) => + catalog.clearTables() + + val partition1 = Array(bucket(3, "store_id"), + bucket(table1buckets, "dept_id")) + val partition2 = Array(bucket(3, "store_id"), + bucket(table2buckets, "dept_id")) + + createTable(table1, schema2, partition1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 0, 'aa'), " + + "(1, 0, 'ab'), " + + "(2, 1, 'ac'), " + + "(3, 2, 'ad'), " + + "(4, 3, 'ae'), " + + "(5, 4, 'af'), " + + "(6, 5, 'ag'), " + + + // value without other side match + "(6, 6, 'xx')" + ) + + createTable(table2, schema2, partition2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(6, 0, '01'), " + + "(5, 1, '02'), " + // duplicate partition key + "(5, 1, '03'), " + + "(4, 2, '04'), " + + "(3, 3, '05'), " + + "(2, 4, '06'), " + + "(1, 5, '07'), " + + + // value without other side match + "(7, 7, '99')" + ) + + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t2.store_id, t1.dept_id, t2.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + + val expectedBuckets = Math.min(table1buckets, table2buckets) + + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 6, 0, 0, "aa", "01"), + Row(1, 6, 0, 0, "ab", "01"), + Row(2, 5, 1, 1, "ac", "02"), + Row(2, 5, 1, 1, "ac", "03"), + Row(3, 4, 2, 2, "ad", "04"), + Row(4, 3, 3, 3, "ae", "05"), + Row(5, 2, 4, 4, "af", "06"), + Row(6, 1, 5, 5, "ag", "07"), + )) + } + } + } + + test("SPARK-47094: Compatible buckets does not support SPJ with " + + "push-down values or partially-clustered") { + val table1 = "tab1e1" + val table2 = "table2" + + val partition1 = Array(bucket(4, "store_id"), + bucket(2, "dept_id")) + val partition2 = Array(bucket(2, "store_id"), + bucket(2, "dept_id")) + + createTable(table1, schema2, partition1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 0, 'aa'), " + + "(1, 1, 'bb'), " + + "(2, 2, 'cc')" + ) + + createTable(table2, schema2, partition2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(0, 0, 'aa'), " + + "(1, 1, 'bb'), " + + "(2, 2, 'cc')" + ) + + Seq(true, false).foreach{ allowPushDown => + Seq(true, false).foreach{ partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> allowPushDown.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t1.store_id, t1.dept_id, t2.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + + (allowPushDown, partiallyClustered) match { + case (true, false) => + assert(shuffles.isEmpty, "SPJ should be triggered") + assert(scans == Seq(2, 2)) + case (_, _) => + assert(shuffles.nonEmpty, "SPJ should not be triggered") + assert(scans == Seq(3, 2)) + } + + checkAnswer(df, Seq( + Row(0, 0, 0, 0, "aa", "aa"), + Row(1, 1, 1, 1, "bb", "bb"), + Row(2, 2, 2, 2, "cc", "cc") + )) + } + } + } + } + test("SPARK-44647: test join key is the second cluster key") { val table1 = "tab1e1" val table2 = "table2" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 61895d49c4a2a..67da85480ef92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -76,7 +76,7 @@ object UnboundBucketFunction extends UnboundFunction { override def name(): String = "bucket" } -object BucketFunction extends ScalarFunction[Int] { +object BucketFunction extends ReducibleFunction[Int, Int] { override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) override def resultType(): DataType = IntegerType override def name(): String = "bucket" @@ -85,6 +85,23 @@ object BucketFunction extends ScalarFunction[Int] { override def produceResult(input: InternalRow): Int = { (input.getLong(1) % input.getInt(0)).toInt } + + override def reducer(func: ReducibleFunction[_, _], + thisNumBuckets: Option[_], + otherNumBuckets: Option[_]): Option[Reducer[Int]] = { + (thisNumBuckets, otherNumBuckets) match { + case (Some(thisNumBucketsVal: Int), Some(otherNumBucketsVal: Int)) + if func.isInstanceOf[ReducibleFunction[_, _]] && + ((thisNumBucketsVal > otherNumBucketsVal) && + (thisNumBucketsVal % otherNumBucketsVal == 0)) => + Some(BucketReducer(thisNumBucketsVal, otherNumBucketsVal)) + case _ => None + } + } +} + +case class BucketReducer(thisNumBuckets: Int, otherNumBuckets: Int) extends Reducer[Int] { + override def reduce(bucket: Int): Int = bucket % otherNumBuckets } object UnboundStringSelfFunction extends UnboundFunction {
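For illustration, the reduction mechanism exercised above can be sketched outside of Spark: a Reducer maps partition values produced by the finer-grained bucket transform onto those of the coarser one, so both sides of the join can be grouped into min(b1, b2) buckets. The sketch below mirrors the semantics of the BucketReducer added in the test changes; the object and method names are illustrative only and are not part of this patch.

// Illustrative sketch, not part of the diff: mirrors BucketReducer's modulo semantics.
object CompatibleBucketsExample {
  // Plays the role of Reducer[Int]: maps bucket IDs of the finer transform onto the coarser one.
  case class ModReducer(otherNumBuckets: Int) {
    def reduce(bucket: Int): Int = bucket % otherNumBuckets
  }

  def main(args: Array[String]): Unit = {
    // bucket(4, x) assigns x % 4 and bucket(2, x) assigns x % 2. Since 4 % 2 == 0,
    // (x % 4) % 2 == x % 2, so reducing the 4-bucket IDs reproduces the 2-bucket IDs
    // and the two sides can be co-grouped into min(4, 2) = 2 partitions without a shuffle.
    val reducer = ModReducer(otherNumBuckets = 2)
    val fourBucketIds = Seq(0, 1, 2, 3)
    println(fourBucketIds.map(reducer.reduce)) // List(0, 1, 0, 1)
  }
}

As the new conf's doc string states, this reduction path only applies when V2_BUCKETING_ENABLED and V2_BUCKETING_PUSH_PART_VALUES_ENABLED are enabled and V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED is disabled, which is what the last test above verifies.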