[SPARK-44647][SQL] Support SPJ where join keys are less than cluster keys

- Add new conf spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled
- Change the key compatibility checks in EnsureRequirements: remove the requirement that all partition keys appear in the join keys, so that isKeyCompatible = true in this case (when the new flag is enabled).
- Change BatchScanExec/DataSourceV2Relation to group splits by join keys when they differ from the partition keys (previously splits were grouped only by partition values). Do the same for all auxiliary data structures, such as commonPartValues.
- Implement partiallyClustered skew handling:
  - Group only the replicate side (now by join key as well), and replicate by the total size of the other-side partitions that share the join key (a small padding sketch follows this list).
  - Add an additional sort of partitions based on join key, because grouping the replicate side can leave its partition ordering out of sync with the non-replicate side.
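
A minimal standalone sketch of the replicate-side padding (hypothetical split names, plain Scala collections rather than Spark's internal types); it only mirrors the shape of the padTo call in BatchScanExec:

// The replicate side holds 2 splits for a given join key, while the other side reported
// 3 splits for that key; pad with empty split groups so both sides end up with the same
// number of partitions for that key.
val splitsForKey: Seq[String] = Seq("split-a", "split-b")
val numSplitsOtherSide = 3

val padded: Seq[Seq[String]] = splitsForKey.map(Seq(_)).padTo(numSplitsOtherSide, Seq.empty)
// padded == Seq(Seq("split-a"), Seq("split-b"), Seq())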

- Support Storage Partition Join in cases where the join condition does not contain all of the partition keys, but only some of them (see the usage sketch below).
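
A minimal usage sketch of the new behaviour. The catalog, table, and column names are hypothetical, and a DSv2 source that reports KeyGroupedPartitioning (e.g. Iceberg) is assumed; both tables are assumed to be partitioned by (id, dept) while the join is on id only:

spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
spark.conf.set(
  "spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled", "true")

val joined = spark.sql(
  """
    |SELECT a.id, a.dept, b.amount
    |FROM cat.db.orders a
    |JOIN cat.db.payments b ON a.id = b.id
  """.stripMargin)

// With the flag on, partitions can be grouped by id alone, so the join may be planned as a
// storage-partitioned join without an Exchange on either side (subject to the caveats below).
joined.explain()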

No user-facing change: the new behaviour is gated behind a config that defaults to false.

- Added tests in KeyGroupedPartitioningSuite.
- Found two existing problems, to be addressed in separate PRs:
  - Because of apache#37886, we have to select all join keys to trigger SPJ in this case; otherwise the DSV2 scan does not report KeyGroupedPartitioning and SPJ does not get triggered. Need to see how to relax this.
  - https://issues.apache.org/jira/browse/SPARK-44641 was found when testing this change. This PR refactors some of that code to add group-by-join-key but does not change the underlying logic, so the issue continues to exist. Hopefully it will also get fixed in another way.

Closes apache#42306 from szehon-ho/spj_attempt_master.

Authored-by: Szehon Ho <szehon.apache@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
szehon-ho committed Feb 7, 2024
1 parent f0691c4 commit ce439be
Showing 5 changed files with 389 additions and 44 deletions.
@@ -380,7 +380,14 @@ case class KeyGroupedPartitioning(
} else {
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))

if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// check that all join keys (required clustering keys) are contained in the partitioning
requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
} else {
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
}

case _ =>
@@ -389,8 +396,20 @@
}
}

override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
KeyGroupedShuffleSpec(this, distribution)
override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
val result = KeyGroupedShuffleSpec(this, distribution)
if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// If allowing join keys to be subset of clustering keys, we should create a new
// `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as
// the returned shuffle spec.
val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
partitionValues, originalPartitionValues)
result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
} else {
result
}
}

lazy val uniquePartitionValues: Seq[InternalRow] = {
partitionValues
@@ -403,8 +422,25 @@
object KeyGroupedPartitioning {
def apply(
expressions: Seq[Expression],
partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues)
projectionPositions: Seq[Int],
partitionValues: Seq[InternalRow],
originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
val projectedExpressions = projectionPositions.map(expressions(_))
val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
val projectedOriginalPartitionValues =
originalPartitionValues.map(project(expressions, projectionPositions, _))

KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length,
projectedPartitionValues, projectedOriginalPartitionValues)
}

def project(
expressions: Seq[Expression],
positions: Seq[Int],
input: InternalRow): InternalRow = {
val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType))
.toArray
new GenericInternalRow(projectedValues)
}

def supportsExpressions(expressions: Seq[Expression]): Boolean = {
@@ -717,9 +753,18 @@ case class CoalescedHashShuffleSpec(
override def numPartitions: Int = partitions.length
}

/**
* [[ShuffleSpec]] created by [[KeyGroupedPartitioning]].
*
* @param partitioning key grouped partitioning
* @param distribution distribution
* @param joinKeyPositions positions of join keys among cluster keys.
* This is set if joining on a subset of cluster keys is allowed.
*/
case class KeyGroupedShuffleSpec(
partitioning: KeyGroupedPartitioning,
distribution: ClusteredDistribution) extends ShuffleSpec {
distribution: ClusteredDistribution,
joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec {

/**
* A sequence where each element is a set of positions of the partition expression to the cluster
@@ -754,7 +799,7 @@ case class KeyGroupedShuffleSpec(
// 3.3 each pair of partition expressions at the same index must share compatible
// transform functions.
// 4. the partition values from both sides are following the same order.
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
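
For intuition, a minimal standalone sketch (hypothetical values, plain Scala collections rather than InternalRow) of the position-based projection that the new KeyGroupedPartitioning.project helper performs:

// A partition value over partition expressions (id, dept, bucket); the join only uses
// positions 0 and 2, so the projected value keeps just those fields, in that order.
val partitionValue: Array[Any] = Array(42, "engineering", 7)
val joinKeyPositions: Seq[Int] = Seq(0, 2)

val projected: Array[Any] = joinKeyPositions.map(partitionValue(_)).toArray
// projected.toSeq == Seq(42, 7)
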
@@ -1492,6 +1492,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS =
buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled")
.doc("Whether to allow storage-partition join in the case where join keys are" +
"a subset of the partition keys of the source tables. At planning time, " +
"Spark will group the partitions by only those keys that are in the join keys." +
s"This is currently enabled only if ${REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key} " +
"is false."
)
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -4783,6 +4795,9 @@ class SQLConf extends Serializable with Logging {
def v2BucketingShuffleEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)

def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)

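A small sketch of the interplay called out in the doc string above: the new flag is honoured only when REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION is false. The conf entries are referenced through the SQLConf object shown in this diff; a SparkSession named spark is assumed to be in scope:

import org.apache.spark.sql.internal.SQLConf

spark.conf.set(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key, "false")
spark.conf.set(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key, "true")

// Inside the planner the flag is read back through the accessor added in this diff.
val enabled = SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys
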
@@ -120,7 +120,12 @@ case class BatchScanExec(
val newPartValues = spjParams.commonPartitionValues.get.flatMap {
case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
}
k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
val expressions = spjParams.joinKeyPositions match {
case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
case _ => k.expressions
}
k.copy(expressions = expressions, numPartitions = newPartValues.length,
partitionValues = newPartValues)
case p => p
}
}
@@ -132,14 +137,29 @@
// return an empty RDD with 1 partition if dynamic filtering removed the only split
sparkContext.parallelize(Array.empty[InternalRow], 1)
} else {
var finalPartitions = filteredPartitions

outputPartitioning match {
val finalPartitions = outputPartitioning match {
case p: KeyGroupedPartitioning =>
val groupedPartitions = filteredPartitions.map(splits => {
assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
(splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
})
assert(spjParams.keyGroupedPartitioning.isDefined)
val expressions = spjParams.keyGroupedPartitioning.get

// Re-group the input partitions if we are projecting on a subset of join keys
val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match {
case Some(projectPositions) =>
val projectedExpressions = projectPositions.map(i => expressions(i))
val parts = filteredPartitions.flatten.groupBy(part => {
val row = part.asInstanceOf[HasPartitionKey].partitionKey()
val projectedRow = KeyGroupedPartitioning.project(
expressions, projectPositions, row)
InternalRowComparableWrapper(projectedRow, projectedExpressions)
}).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq
(parts, projectedExpressions)
case _ =>
val groupedParts = filteredPartitions.map(splits => {
assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
(splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
})
(groupedParts, expressions)
}

// When partially clustered, the input partitions are not grouped by partition
// values. Here we'll need to check `commonPartitionValues` and decide how to group
@@ -149,12 +169,12 @@
// should contain.
val commonPartValuesMap = spjParams.commonPartitionValues
.get
.map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
.map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
.toMap
val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
// `commonPartValuesMap` should contain the part value since it's the super set.
val numSplits = commonPartValuesMap
.get(InternalRowComparableWrapper(partValue, p.expressions))
.get(InternalRowComparableWrapper(partValue, partExpressions))
assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
"common partition values from Spark plan")

@@ -169,37 +189,37 @@
// sides of a join will have the same number of partitions & splits.
splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
}
(InternalRowComparableWrapper(partValue, p.expressions), newSplits)
(InternalRowComparableWrapper(partValue, partExpressions), newSplits)
}

// Now fill missing partition keys with empty partitions
val partitionMapping = nestGroupedPartitions.toMap
finalPartitions = spjParams.commonPartitionValues.get.flatMap {
spjParams.commonPartitionValues.get.flatMap {
case (partValue, numSplits) =>
// Use empty partition for those partition values that are not present.
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, p.expressions),
InternalRowComparableWrapper(partValue, partExpressions),
Seq.fill(numSplits)(Seq.empty))
}
} else {
// either `commonPartitionValues` is not defined, or it is defined but
// `applyPartialClustering` is false.
val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
InternalRowComparableWrapper(partValue, p.expressions) -> splits
InternalRowComparableWrapper(partValue, partExpressions) -> splits
}.toMap

// In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
// could exist duplicated partition values, as partition grouping is not done
// at the beginning and postponed to this method. It is important to use unique
// partition values here so that grouped partitions won't get duplicated.
finalPartitions = p.uniquePartitionValues.map { partValue =>
p.uniquePartitionValues.map { partValue =>
// Use empty partition for those partition values that are not present
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
InternalRowComparableWrapper(partValue, partExpressions), Seq.empty)
}
}

case _ =>
case _ => filteredPartitions
}

new DataSourceRDD(
Expand Down Expand Up @@ -234,6 +254,7 @@ case class BatchScanExec(

case class StoragePartitionJoinParams(
keyGroupedPartitioning: Option[Seq[Expression]] = None,
joinKeyPositions: Option[Seq[Int]] = None,
commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
applyPartialClustering: Boolean = false,
replicatePartitions: Boolean = false) {
@@ -247,6 +268,7 @@ case class StoragePartitionJoinParams(
}

override def hashCode(): Int = Objects.hashCode(
joinKeyPositions: Option[Seq[Int]],
commonPartitionValues: Option[Seq[(InternalRow, Int)]],
applyPartialClustering: java.lang.Boolean,
replicatePartitions: java.lang.Boolean)
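
To make the re-grouping concrete, a standalone sketch with hypothetical splits and keys (plain tuples instead of InternalRow and HasPartitionKey); it mirrors only the shape of the groupBy on the projected join key performed above:

// Each input split carries a full partition key (id, dept). When only id is a join key,
// splits are re-grouped by the projected key, so (1, "eng") and (1, "ops") land together.
val splits: Seq[((Int, String), String)] = Seq(
  ((1, "eng"), "split-a"),
  ((1, "ops"), "split-b"),
  ((2, "eng"), "split-c"))

val regrouped: Map[Int, Seq[String]] =
  splits
    .groupBy { case ((id, _), _) => id }               // project (id, dept) -> id
    .map { case (id, group) => id -> group.map(_._2) }

// regrouped == Map(1 -> Seq("split-a", "split-b"), 2 -> Seq("split-c"))
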
@@ -380,7 +380,8 @@ case class EnsureRequirements(
val rightSpec = specs(1)

var isCompatible = false
if (!conf.v2BucketingPushPartValuesEnabled) {
if (!conf.v2BucketingPushPartValuesEnabled &&
!conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
isCompatible = leftSpec.isCompatibleWith(rightSpec)
} else {
logInfo("Pushing common partition values for storage-partitioned join")
@@ -505,10 +506,10 @@
}

// Now we need to push-down the common partition key to the scan in each child
newLeft = populatePartitionValues(
left, mergedPartValues, applyPartialClustering, replicateLeftSide)
newRight = populatePartitionValues(
right, mergedPartValues, applyPartialClustering, replicateRightSide)
newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions,
applyPartialClustering, replicateLeftSide)
newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions,
applyPartialClustering, replicateRightSide)
}
}

@@ -530,19 +531,21 @@
private def populatePartitionValues(
plan: SparkPlan,
values: Seq[(InternalRow, Int)],
joinKeyPositions: Option[Seq[Int]],
applyPartialClustering: Boolean,
replicatePartitions: Boolean): SparkPlan = plan match {
case scan: BatchScanExec =>
scan.copy(
spjParams = scan.spjParams.copy(
commonPartitionValues = Some(values),
joinKeyPositions = joinKeyPositions,
applyPartialClustering = applyPartialClustering,
replicatePartitions = replicatePartitions
)
)
case node =>
node.mapChildren(child => populatePartitionValues(
child, values, applyPartialClustering, replicatePartitions))
child, values, joinKeyPositions, applyPartialClustering, replicatePartitions))
}

/**
