[NSE-329] fix output partitioning in BHJ and SHJ (oap-project#335)
* fix output partitioning for SHJ and BHJ

* refine
rui-mo authored and zhouyuan committed May 20, 2021
1 parent 66a8ccb commit 4251ff2
Showing 2 changed files with 90 additions and 1 deletion.
@@ -29,17 +29,19 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
-import org.apache.spark.sql.execution.joins.{HashJoin,ShuffledJoin,BaseJoinExec}
+import org.apache.spark.sql.execution.joins.{BaseJoinExec, HashJoin, ShuffledJoin}
import org.apache.spark.sql.execution.joins.HashedRelationInfo
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.util.{ExecutorManager, UserAddedJarUtils}
import org.apache.spark.sql.types.DecimalType

import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/**
@@ -138,6 +140,74 @@ case class ColumnarBroadcastHashJoinExec(
s"ColumnarBroadcastHashJoinExec doesn't support doExecute")
}

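  // Cap on how many expanded output partitionings the methods below may
  // generate; this reads the conf key upstream Spark uses for the same
  // feature (default 8 there), and 0 disables the expansion.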
  val broadcastHashJoinOutputPartitioningExpandLimit: Int = sqlContext.getConf(
    "spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit").trim().toInt

  override lazy val outputPartitioning: Partitioning = {
    joinType match {
      case _: InnerLike if broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
        streamedPlan.outputPartitioning match {
          case h: HashPartitioning => expandOutputPartitioning(h)
          case c: PartitioningCollection => expandOutputPartitioning(c)
          case other => other
        }
      case _ => streamedPlan.outputPartitioning
    }
  }

  // A one-to-many mapping from a streamed key to build keys.
  private lazy val streamedKeyToBuildKeyMapping = {
    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
    streamedKeyExprs.zip(buildKeyExprs).foreach {
      case (streamedKey, buildKey) =>
        val key = streamedKey.canonicalized
        mapping.get(key) match {
          case Some(v) => mapping.put(key, v :+ buildKey)
          case None => mapping.put(key, Seq(buildKey))
        }
    }
    mapping.toMap
  }
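  // Example (illustrative keys, not from the patch): streamed keys Seq(b, c)
  // with build keys Seq(x, y) give b -> Seq(x) and c -> Seq(y); a repeated
  // streamed key accumulates, e.g. streamed Seq(b, b) with build Seq(x, y)
  // yields b -> Seq(x, y).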

  // Expands the given partitioning collection recursively.
  private def expandOutputPartitioning(
      partitioning: PartitioningCollection): PartitioningCollection = {
    PartitioningCollection(partitioning.partitionings.flatMap {
      case h: HashPartitioning => expandOutputPartitioning(h).partitionings
      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
      case other => Seq(other)
    })
  }

  // Expands the given hash partitioning by substituting streamed keys with build keys.
  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
  // the expanded partitioning will have the following expressions:
  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
  // The expanded expressions are returned as a PartitioningCollection.
  private def expandOutputPartitioning(
      partitioning: HashPartitioning): PartitioningCollection = {
    val maxNumCombinations = broadcastHashJoinOutputPartitioningExpandLimit
    var currentNumCombinations = 0

    def generateExprCombinations(
        current: Seq[Expression],
        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
      if (currentNumCombinations >= maxNumCombinations) {
        Nil
      } else if (current.isEmpty) {
        currentNumCombinations += 1
        Seq(accumulated)
      } else {
        val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
        generateExprCombinations(current.tail, accumulated :+ current.head) ++
          buildKeysOpt.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
            .getOrElse(Nil)
      }
    }

    PartitioningCollection(
      generateExprCombinations(partitioning.expressions, Nil)
        .map(HashPartitioning(_, partitioning.numPartitions)))
  }

  override def inputRDDs(): Seq[RDD[ColumnarBatch]] = streamedPlan match {
    case c: ColumnarCodegenSupport if c.supportColumnarCodegen =>
      c.inputRDDs
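The comment above describes exactly what generateExprCombinations enumerates. As a minimal standalone sketch, here is the same walk over plain strings; the key names, the limit value, and the object name are illustrative, not part of the patch:

object ExpandSketch {
  // Streamed-to-build key mapping from the comment's example: b = x, c = y.
  val streamedToBuild: Map[String, Seq[String]] = Map("b" -> Seq("x"), "c" -> Seq("y"))

  val expandLimit = 8 // stand-in for broadcastHashJoinOutputPartitioningExpandLimit
  var numCombinations = 0

  // At each position, keep the streamed expression and also fork into every
  // build key it is equi-joined with; stop producing results once the cap hits.
  def combinations(current: Seq[String], acc: Seq[String]): Seq[Seq[String]] =
    if (numCombinations >= expandLimit) {
      Nil
    } else if (current.isEmpty) {
      numCombinations += 1
      Seq(acc)
    } else {
      combinations(current.tail, acc :+ current.head) ++
        streamedToBuild.get(current.head)
          .map(_.flatMap(b => combinations(current.tail, acc :+ b)))
          .getOrElse(Nil)
    }

  def main(args: Array[String]): Unit =
    // Prints the four expansions from the comment:
    // List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y)
    combinations(Seq("a", "b", "c"), Nil).foreach(println)
}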
@@ -170,6 +170,25 @@ case class ColumnarShuffledHashJoinExec(
    throw new UnsupportedOperationException(
      s"ColumnarShuffledHashJoinExec doesn't support doExecute")
  }

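  // The build side is fully consumed to build the hash table; surviving rows
  // keep the streamed side's distribution, so the output partitioning follows
  // the side opposite the build side.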
  override def outputPartitioning: Partitioning = buildSide match {
    case BuildLeft =>
      joinType match {
        case _: InnerLike | RightOuter => right.outputPartitioning
        case x =>
          throw new IllegalArgumentException(
            s"HashJoin should not take $x as the JoinType with building left side")
      }
    case BuildRight =>
      joinType match {
        case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin =>
          left.outputPartitioning
        case x =>
          throw new IllegalArgumentException(
            s"HashJoin should not take $x as the JoinType with building right side")
      }
  }

  override def supportsColumnar = true

  override def inputRDDs(): Seq[RDD[ColumnarBatch]] = streamedPlan match {
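For context, the conf key read in the BHJ change is the one upstream Spark defines (default 8; 0 disables expansion). Below is a hypothetical end-to-end run on vanilla Spark, with made-up table names and data, showing why the expanded partitioning matters:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

object OutputPartitioningDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
    import spark.implicits._

    // Keep expansion enabled (8 is the upstream default); 0 would fall back
    // to reporting only the streamed side's partitioning.
    spark.conf.set(
      "spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit", "8")

    val streamed = Seq((1, "s1"), (2, "s2")).toDF("b", "payload").repartition($"b")
    val build = Seq((1, "t1"), (2, "t2")).toDF("x", "tag")

    // Inner broadcast-hash join on b = x: with expansion, a downstream
    // operator keyed on either "b" or "x" can reuse the join's output
    // partitioning instead of inserting another exchange.
    val joined = streamed.join(broadcast(build), $"b" === $"x")
    joined.groupBy($"x").count().explain()

    spark.stop()
  }
}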
