[SPARK-47094][SQL] SPJ : Dynamically rebalance number of buckets when they are not equal

  ### What changes were proposed in this pull request?
- Allow SPJ between 'compatible' bucket functions.
- Add a mechanism to define 'reducible' functions: a function whose output can be 'reduced' to the output of another function for all inputs. For example, a bucket transform that maps x to x mod 8 is reducible on one that maps x to x mod 4 via the reducer r(b) = b mod 4, since (x mod 8) mod 4 = x mod 4 for every x.

  ### Why are the changes needed?
- SPJ currently applies only if the partition transform expressions on both sides are identical.

  ### Does this PR introduce _any_ user-facing change?
No

  ### How was this patch tested?
Added new tests in KeyGroupedPartitioningSuite

  ### Was this patch authored or co-authored using generative AI tooling?
No
szehon-ho committed Feb 27, 2024
1 parent 76c4fd5 commit c81c039
Showing 9 changed files with 531 additions and 22 deletions.
@@ -0,0 +1,16 @@
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;

/**
* A 'reducer' for the output of user-defined functions.
*
* A user-defined function f_source(x) is 'reducible' on another user-defined function f_target(x),
* if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x.
* @param <T> function output type
* @since 4.0.0
*/
@Evolving
public interface Reducer<T> {
T reduce(T arg1);
}
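As a concrete illustration of the contract above, here is a minimal Scala sketch (not part of this commit) of a Reducer for modulo-style bucket ids; the name ModReducer and the modulo semantics are assumptions for illustration only:

```scala
import org.apache.spark.sql.connector.catalog.functions.Reducer

// Hypothetical reducer: maps a bucket id computed as hash(x) % m down to the id
// that hash(x) % n would produce. Valid whenever n divides m, because
// (h % m) % n == h % n for any non-negative h.
case class ModReducer(divisor: Int) extends Reducer[Int] {
  override def reduce(bucketId: Int): Int = bucketId % divisor
}
```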
@@ -0,0 +1,25 @@
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;
import scala.Option;

/**
* Base interface for user-defined functions whose output can be 'reduced' to the output of another function.
*
* A function f_source(x) is 'reducible' on another function f_target(x) if
* there exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x.
*
* @since 4.0.0
*/
@Evolving
public interface ReducibleFunction<T, A> extends ScalarFunction<T> {

/**
* If this function is 'reducible' on another function, return the {@link Reducer} function.
* @param other other function
* @param thisArgument argument for this function instance
* @param otherArgument argument for other function instance
* @return a {@link Reducer} if this function is reducible on the other function, None otherwise
*/
Option<Reducer<A>> reducer(ReducibleFunction<?, ?> other, Option<?> thisArgument, Option<?> otherArgument);
}
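A companion sketch of a ReducibleFunction implementation in Scala, reusing the hypothetical ModReducer above; the function name, the plain value-modulo bucketing, and the divisibility rule are illustrative assumptions, not part of this commit:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

// Hypothetical bucket transform: bucket(numBuckets, value) = value mod numBuckets.
// The bucket count is not part of the function instance; it arrives through the
// thisArgument/otherArgument options, which is how TransformExpression passes
// numBucketsOpt in this PR.
object BucketByMod extends ReducibleFunction[Int, Int] {
  override def name(): String = "bucket_by_mod"
  override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
  override def resultType(): DataType = IntegerType

  // bucket(numBuckets, value)
  override def produceResult(input: InternalRow): Int =
    Math.floorMod(input.getLong(1), input.getInt(0).toLong).toInt

  override def reducer(
      other: ReducibleFunction[_, _],
      thisArgument: Option[_],
      otherArgument: Option[_]): Option[Reducer[Int]] = {
    (other, thisArgument, otherArgument) match {
      // Reducible only onto the same kind of bucket function, and only when the
      // other side's bucket count evenly divides this side's.
      case (BucketByMod, Some(thisBuckets: Int), Some(otherBuckets: Int))
          if thisBuckets != otherBuckets && thisBuckets % otherBuckets == 0 =>
        Some(ModReducer(otherBuckets))
      case _ => None
    }
  }
}
```

With this sketch, an 8-bucket side reports ModReducer(4) against a 4-bucket side, while the reverse direction reports None; that asymmetry is why the planner code below probes both directions.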
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.connector.catalog.functions.BoundFunction
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ReducibleFunction}
import org.apache.spark.sql.types.DataType

/**
@@ -54,6 +54,31 @@ case class TransformExpression(
false
}

/**
* Whether this [[TransformExpression]]'s function is compatible with the `other`
* [[TransformExpression]]'s function.
*
* This is true if both are the same function, or if both are instances of
* [[ReducibleFunction]] and there exists a [[Reducer]] r(x) such that r(t1(x)) = t2(x),
* or r(t2(x)) = t1(x), for all input x, where t1 and t2 are the two transform functions.
*
* @param other the transform expression to compare to
* @return true if compatible, false if not
*/
def isCompatible(other: TransformExpression): Boolean = {
if (isSameFunction(other)) {
true
} else {
(function, other.function) match {
case (f: ReducibleFunction[Any, Any] @unchecked,
o: ReducibleFunction[Any, Any] @unchecked) =>
val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt)
val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt)
reducer.isDefined || otherReducer.isDefined
case _ => false
}
}
}

override def dataType: DataType = function.resultType()

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
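A REPL-style usage sketch of the new check, assuming the hypothetical BucketByMod function sketched after the ReducibleFunction interface above:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, TransformExpression}
import org.apache.spark.sql.types.LongType

val id = AttributeReference("id", LongType)()
// The same hypothetical bucket function bound with different bucket counts.
val bucket8 = TransformExpression(BucketByMod, Seq(id), numBucketsOpt = Some(8))
val bucket4 = TransformExpression(BucketByMod, Seq(id), numBucketsOpt = Some(4))

// Compatible because BucketByMod.reducer(BucketByMod, Some(8), Some(4)) is defined.
assert(bucket8.isCompatible(bucket4))
```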
@@ -24,6 +24,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, IntegerType}

@@ -635,6 +636,22 @@ trait ShuffleSpec {
*/
def createPartitioning(clustering: Seq[Expression]): Partitioning =
throw SparkUnsupportedOperationException()

/**
* Return a sequence of [[Reducer]]s for the partition expressions of this shuffle spec,
* relative to the partition expressions of another shuffle spec.
* <p>
* A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
* 'reducible' on the corresponding partition expression function of the other shuffle spec.
* <p>
* If a value is returned, there must be one Option of [[Reducer]] per partition expression.
* A None element indicates that the particular partition expression is not reducible
* on the corresponding expression of the other shuffle spec.
* <p>
* Returning None for the whole result indicates that none of the partition expressions
* can be reduced on the corresponding expressions of the other shuffle spec.
*/
def reducers(spec: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = None
}

case object SinglePartitionShuffleSpec extends ShuffleSpec {
@@ -829,20 +846,60 @@ case class KeyGroupedShuffleSpec(
}
}

override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// Only support partition expressions that are AttributeReference for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
}

override def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = {
other match {
case otherSpec: KeyGroupedShuffleSpec =>
val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map {
case (e1: TransformExpression, e2: TransformExpression)
if e1.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked]
&& e2.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked] =>
e1.function.asInstanceOf[ReducibleFunction[Any, Any]].reducer(
e2.function.asInstanceOf[ReducibleFunction[Any, Any]],
e1.numBucketsOpt.map(a => a.asInstanceOf[Any]),
e2.numBucketsOpt.map(a => a.asInstanceOf[Any]))
case (_, _) => None
}

// Optimization: return None if none of the partition expressions need reducing
if (results.forall(p => p.isEmpty)) None else Some(results)
case _ => None
}
}

private def isExpressionCompatible(left: Expression, right: Expression): Boolean =
(left, right) match {
case (_: LeafExpression, _: LeafExpression) => true
case (left: TransformExpression, right: TransformExpression) =>
left.isSameFunction(right)
if (SQLConf.get.v2BucketingPushPartValuesEnabled &&
!SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
SQLConf.get.v2BucketingAllowCompatibleTransforms) {
left.isCompatible(right)
} else {
left.isSameFunction(right)
}
case _ => false
}
}

override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// Only support partition expressions that are AttributeReference for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
}
}

object KeyGroupedShuffleSpec {
def reducePartitionValue(row: InternalRow,
expressions: Seq[Expression],
reducers: Seq[Option[Reducer[Any]]]):
InternalRowComparableWrapper = {
val partitionVals = row.toSeq(expressions.map(_.dataType))
val reducedRow = partitionVals.zip(reducers).map{
case (v, Some(reducer)) => reducer.reduce(v)
case (v, _) => v
}.toArray
InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions)
}
}

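A REPL-style sketch of the reduce step on a single partition value, again assuming the hypothetical BucketByMod/ModReducer sketches above; the cast to Reducer[Any] mirrors how the surrounding planner code handles the erased generic:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedShuffleSpec
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.types.LongType

val id = AttributeReference("id", LongType)()
val exprs = Seq(TransformExpression(BucketByMod, Seq(id), numBucketsOpt = Some(8)))
val reducers = Seq(Some(ModReducer(4).asInstanceOf[Reducer[Any]]))

// Partition value "bucket 5" on the 8-bucket side reduces to "bucket 1", so it
// groups together with the 4-bucket side's partition value 1.
val wrapped = KeyGroupedShuffleSpec.reducePartitionValue(InternalRow(5), exprs, reducers)
// wrapped.row holds the single reduced value 1
```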
@@ -1537,6 +1537,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS =
buildConf("spark.sql.sources.v2.bucketing.allow.enabled")
.doc("Whether to allow storage-partition join in the case where the partition transforms" +
"are compatible but not identical. This config requires both " +
s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " +
s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
"to be disabled."
)
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -5201,6 +5213,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)

def v2BucketingAllowCompatibleTransforms: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)

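For context, a REPL-style sketch of the configuration combination this new flag participates in, using the conf constants from the diff above; the catalog and table names are hypothetical, and the tables' catalog must expose ReducibleFunction-aware bucket transforms for the reduction to kick in:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// The prerequisites spelled out in the doc string, plus the new flag itself.
spark.conf.set(SQLConf.V2_BUCKETING_ENABLED.key, "true")
spark.conf.set(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key, "true")
spark.conf.set(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key, "false")
spark.conf.set(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key, "true")

// Hypothetical tables bucketed into 8 and 4 buckets on customer_id: with this
// change SPJ can fold the 8-bucket side down to 4 groups instead of shuffling.
val joined = spark.table("testcat.ns.orders_by_8")
  .join(spark.table("testcat.ns.customers_by_4"), "customer_id")
```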
@@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.connector.read._
import org.apache.spark.util.ArrayImplicits._

@@ -164,6 +165,18 @@ case class BatchScanExec(
(groupedParts, expressions)
}

// Also re-group the partitions if we are reducing compatible partition expressions
val finalGroupedPartitions = spjParams.reducers match {
case Some(reducers) =>
val result = groupedPartitions.groupBy { case (row, _) =>
KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
}.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq
val rowOrdering = RowOrdering.createNaturalAscendingOrdering(
expressions.map(_.dataType))
result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
case _ => groupedPartitions
}

// When partially clustered, the input partitions are not grouped by partition
// values. Here we'll need to check `commonPartitionValues` and decide how to group
// and replicate splits within a partition.
@@ -174,7 +187,7 @@
.get
.map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
.toMap
val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) =>
// `commonPartValuesMap` should contain the part value since it's the super set.
val numSplits = commonPartValuesMap
.get(InternalRowComparableWrapper(partValue, partExpressions))
@@ -207,7 +220,7 @@
} else {
// either `commonPartitionValues` is not defined, or it is defined but
// `applyPartialClustering` is false.
val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) =>
InternalRowComparableWrapper(partValue, partExpressions) -> splits
}.toMap

@@ -224,7 +237,6 @@

case _ => filteredPartitions
}

new DataSourceRDD(
sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics)
}
@@ -259,6 +271,7 @@ case class StoragePartitionJoinParams(
keyGroupedPartitioning: Option[Seq[Expression]] = None,
joinKeyPositions: Option[Seq[Int]] = None,
commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
reducers: Option[Seq[Option[Reducer[Any]]]] = None,
applyPartialClustering: Boolean = false,
replicatePartitions: Boolean = false) {
override def equals(other: Any): Boolean = other match {
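The essence of the re-grouping above, shown with plain Scala collections (the bucket ids and split labels are made up): splits whose 8-bucket ids agree modulo 4 end up in the same group, mirroring the groupBy/flatMap over reducePartitionValue.

```scala
// Plain-collections sketch of the re-grouping: splits keyed by an 8-bucket id
// collapse into 4 groups once the id is reduced with (_ % 4).
val grouped = Seq(0 -> Seq("split-a"), 1 -> Seq("split-b"), 4 -> Seq("split-c"), 5 -> Seq("split-d"))
val reduced = grouped
  .groupBy { case (bucket, _) => bucket % 4 }
  .map { case (bucket, entries) => bucket -> entries.flatMap(_._2) }
  .toSeq.sortBy(_._1)
// reduced == Seq(0 -> Seq("split-a", "split-c"), 1 -> Seq("split-b", "split-d"))
```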
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
@@ -505,11 +506,28 @@ case class EnsureRequirements(
}
}

// Now we need to push-down the common partition key to the scan in each child
newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions,
applyPartialClustering, replicateLeftSide)
newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions,
applyPartialClustering, replicateRightSide)
// In the case of compatible but not identical partition expressions, apply the 'reduce'
// transforms to the common partition values here, and push the reducers down to the
// scans so each side's partitions are re-grouped to match
val leftReducers = leftSpec.reducers(rightSpec)
val rightReducers = rightSpec.reducers(leftSpec)

if (leftReducers.isDefined || rightReducers.isDefined) {
mergedPartValues = reduceCommonPartValues(mergedPartValues,
leftSpec.partitioning.expressions,
leftReducers)
mergedPartValues = reduceCommonPartValues(mergedPartValues,
rightSpec.partitioning.expressions,
rightReducers)
val rowOrdering = RowOrdering
.createNaturalAscendingOrdering(partitionExprs.map(_.dataType))
mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
}

// Now we need to push-down the common partition information to the scan in each child
newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions,
leftReducers, applyPartialClustering, replicateLeftSide)
newRight = populateCommonPartitionInfo(right, mergedPartValues, rightSpec.joinKeyPositions,
rightReducers, applyPartialClustering, replicateRightSide)
}
}

@@ -527,25 +545,38 @@
joinType == LeftAnti || joinType == LeftOuter
}

// Populate the common partition values down to the scan nodes
private def populatePartitionValues(
// Populate the common partition information down to the scan nodes
private def populateCommonPartitionInfo(
plan: SparkPlan,
values: Seq[(InternalRow, Int)],
joinKeyPositions: Option[Seq[Int]],
reducers: Option[Seq[Option[Reducer[Any]]]],
applyPartialClustering: Boolean,
replicatePartitions: Boolean): SparkPlan = plan match {
case scan: BatchScanExec =>
scan.copy(
spjParams = scan.spjParams.copy(
commonPartitionValues = Some(values),
joinKeyPositions = joinKeyPositions,
reducers = reducers,
applyPartialClustering = applyPartialClustering,
replicatePartitions = replicatePartitions
)
)
case node =>
node.mapChildren(child => populatePartitionValues(
child, values, joinKeyPositions, applyPartialClustering, replicatePartitions))
node.mapChildren(child => populateCommonPartitionInfo(
child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
}

private def reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)],
expressions: Seq[Expression],
reducers: Option[Seq[Option[Reducer[Any]]]]) = {
reducers match {
case Some(reducers) => commonPartValues.groupBy { case (row, _) =>
KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers)
}.map{ case(wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq
case _ => commonPartValues
}
}

/**
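And the counterpart for reduceCommonPartValues, again with plain collections and made-up values: common partition values carry a split count, and counts for ids that collapse to the same reduced value are summed, matching the map(...).sum above.

```scala
// (bucket id, numSplits) pairs from the 8-bucket side folded onto 4 buckets.
val commonPartValues = Seq(0 -> 2, 1 -> 3, 4 -> 1, 5 -> 2)
val reducedPartValues = commonPartValues
  .groupBy { case (bucket, _) => bucket % 4 }
  .map { case (bucket, entries) => bucket -> entries.map(_._2).sum }
  .toSeq.sortBy(_._1)
// reducedPartValues == Seq(0 -> 3, 1 -> 5)
```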