From 4251ff27b116d46419817d6b6350571f3d2b5753 Mon Sep 17 00:00:00 2001
From: Rui Mo
Date: Thu, 20 May 2021 19:48:20 +0800
Subject: [PATCH] [NSE-329] fix out partitioning in BHJ and SHJ (#335)

* fix out partitioning for SHJ and BHJ

* refine
---
 .../ColumnarBroadcastHashJoinExec.scala       | 72 ++++++++++++++++++-
 .../ColumnarShuffledHashJoinExec.scala        | 19 +++++
 2 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBroadcastHashJoinExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBroadcastHashJoinExec.scala
index a2fb26206..c3c18e5a1 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBroadcastHashJoinExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBroadcastHashJoinExec.scala
@@ -29,10 +29,11 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
-import org.apache.spark.sql.execution.joins.{HashJoin,ShuffledJoin,BaseJoinExec}
+import org.apache.spark.sql.execution.joins.{BaseJoinExec, HashJoin, ShuffledJoin}
 import org.apache.spark.sql.execution.joins.HashedRelationInfo
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
@@ -40,6 +41,7 @@ import org.apache.spark.util.{ExecutorManager, UserAddedJarUtils}
 import org.apache.spark.sql.types.DecimalType
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.collection.mutable.ListBuffer
 
 /**
@@ -138,6 +140,74 @@ case class ColumnarBroadcastHashJoinExec(
       s"ColumnarBroadcastHashJoinExec doesn't support doExecute")
   }
 
+  val broadcastHashJoinOutputPartitioningExpandLimit: Int = sqlContext.getConf(
+    "spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit").trim().toInt
+
+  override lazy val outputPartitioning: Partitioning = {
+    joinType match {
+      case _: InnerLike if broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
+        streamedPlan.outputPartitioning match {
+          case h: HashPartitioning => expandOutputPartitioning(h)
+          case c: PartitioningCollection => expandOutputPartitioning(c)
+          case other => other
+        }
+      case _ => streamedPlan.outputPartitioning
+    }
+  }
+
+  // A one-to-many mapping from a streamed key to build keys.
+  private lazy val streamedKeyToBuildKeyMapping = {
+    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+    streamedKeyExprs.zip(buildKeyExprs).foreach {
+      case (streamedKey, buildKey) =>
+        val key = streamedKey.canonicalized
+        mapping.get(key) match {
+          case Some(v) => mapping.put(key, v :+ buildKey)
+          case None => mapping.put(key, Seq(buildKey))
+        }
+    }
+    mapping.toMap
+  }
+
+  // Expands the given partitioning collection recursively.
+  private def expandOutputPartitioning(partitioning: PartitioningCollection): PartitioningCollection = {
+    PartitioningCollection(partitioning.partitionings.flatMap {
+      case h: HashPartitioning => expandOutputPartitioning(h).partitionings
+      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+      case other => Seq(other)
+    })
+  }
+
+  // Expands the given hash partitioning by substituting streamed keys with build keys.
+  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
+  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
+  // the expanded partitioning will have the following expressions:
+  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+  // The expanded expressions are returned as PartitioningCollection.
+  private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
+    val maxNumCombinations = broadcastHashJoinOutputPartitioningExpandLimit
+    var currentNumCombinations = 0
+
+    def generateExprCombinations(current: Seq[Expression],
+        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
+      if (currentNumCombinations >= maxNumCombinations) {
+        Nil
+      } else if (current.isEmpty) {
+        currentNumCombinations += 1
+        Seq(accumulated)
+      } else {
+        val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+        generateExprCombinations(current.tail, accumulated :+ current.head) ++
+          buildKeysOpt.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
+            .getOrElse(Nil)
+      }
+    }
+
+    PartitioningCollection(
+      generateExprCombinations(partitioning.expressions, Nil)
+        .map(HashPartitioning(_, partitioning.numPartitions)))
+  }
+
   override def inputRDDs(): Seq[RDD[ColumnarBatch]] = streamedPlan match {
     case c: ColumnarCodegenSupport if c.supportColumnarCodegen == true =>
       c.inputRDDs
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledHashJoinExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledHashJoinExec.scala
index 31935dd97..a610d925a 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledHashJoinExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledHashJoinExec.scala
@@ -170,6 +170,25 @@ case class ColumnarShuffledHashJoinExec(
     throw new UnsupportedOperationException(
       s"ColumnarShuffledHashJoinExec doesn't support doExecute")
   }
+
+  override def outputPartitioning: Partitioning = buildSide match {
+    case BuildLeft =>
+      joinType match {
+        case _: InnerLike | RightOuter => right.outputPartitioning
+        case x =>
+          throw new IllegalArgumentException(
+            s"HashJoin should not take $x as the JoinType with building left side")
+      }
+    case BuildRight =>
+      joinType match {
+        case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin =>
+          left.outputPartitioning
+        case x =>
+          throw new IllegalArgumentException(
+            s"HashJoin should not take $x as the JoinType with building right side")
+      }
+  }
+
   override def supportsColumnar = true
 
   override def inputRDDs(): Seq[RDD[ColumnarBatch]] = streamedPlan match {
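For reference, the expansion added to ColumnarBroadcastHashJoinExec keeps every streamed-side partitioning expression and, where that expression is also a join key, branches into the build keys it maps to, so downstream operators can match on either side's keys without an extra shuffle. The sketch below replays that combination logic standalone, with plain strings standing in for canonicalized Catalyst Expressions; the object name ExpandPartitioningSketch, the sample keys, and the hard-coded limit of 8 are illustrative assumptions, not part of the patch.

```scala
// Minimal standalone sketch of the generateExprCombinations logic.
// Strings stand in for canonicalized Catalyst Expressions; the real code
// wraps each result in HashPartitioning inside a PartitioningCollection.
object ExpandPartitioningSketch {
  def main(args: Array[String]): Unit = {
    // Streamed keys "b" and "c" are joined against build keys "x" and "y".
    val streamedKeyToBuildKeys: Map[String, Seq[String]] =
      Map("b" -> Seq("x"), "c" -> Seq("y"))
    val expandLimit = 8 // stands in for the outputPartitioningExpandLimit conf
    var numCombinations = 0

    // Keep each expression as-is, and additionally branch into every build
    // key it maps to; stop enumerating once the limit is reached.
    def generate(current: Seq[String], acc: Seq[String]): Seq[Seq[String]] =
      if (numCombinations >= expandLimit) {
        Nil
      } else if (current.isEmpty) {
        numCombinations += 1
        Seq(acc)
      } else {
        val substitutes = streamedKeyToBuildKeys.getOrElse(current.head, Nil)
        generate(current.tail, acc :+ current.head) ++
          substitutes.flatMap(b => generate(current.tail, acc :+ b))
      }

    // Prints the four expanded sequences from the patch comment:
    // List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y)
    generate(Seq("a", "b", "c"), Nil).foreach(println)
  }
}
```

With these inputs the sketch reproduces the example in the patch comment; in the real method each resulting sequence becomes a HashPartitioning in the returned PartitioningCollection, and the shared counter caps the total number of combinations at the configured spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit.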