Skip to content

Commit

Permalink
revert spark 21052 in spark 2.3 branch
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Dec 14, 2018
1 parent 96a5a12 commit ad3bdad
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ case class BroadcastHashJoinExec(
extends BinaryExecNode with HashJoin with CodegenSupport {

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildKeys)
Expand All @@ -62,13 +61,12 @@ case class BroadcastHashJoinExec(

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val avgHashProbe = longMetric("avgHashProbe")

val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
val hashed = broadcastRelation.value.asReadOnlyCopy()
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
join(streamedIter, hashed, numOutputRows, avgHashProbe)
join(streamedIter, hashed, numOutputRows)
}
}

Expand Down Expand Up @@ -110,23 +108,6 @@ case class BroadcastHashJoinExec(
}
}

/**
* Returns the codes used to add a task completion listener to update avg hash probe
* at the end of the task.
*/
private def genTaskListener(avgHashProbe: String, relationTerm: String): String = {
val listenerClass = classOf[TaskCompletionListener].getName
val taskContextClass = classOf[TaskContext].getName
s"""
| $taskContextClass$$.MODULE$$.get().addTaskCompletionListener(new $listenerClass() {
| @Override
| public void onTaskCompletion($taskContextClass context) {
| $avgHashProbe.set($relationTerm.getAverageProbesPerLookup());
| }
| });
""".stripMargin
}

/**
* Returns a tuple of Broadcast of HashedRelation and the variable name for it.
*/
Expand All @@ -144,7 +125,6 @@ case class BroadcastHashJoinExec(
v => s"""
| $v = (($clsName) $broadcast.value()).asReadOnlyCopy();
| incPeakExecutionMemory($v.estimatedSize());
| ${genTaskListener(avgHashProbe, v)}
""".stripMargin, forceInline = true)
(broadcastRelation, relationTerm)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,7 @@ trait HashJoin {
protected def join(
streamedIter: Iterator[InternalRow],
hashed: HashedRelation,
numOutputRows: SQLMetric,
avgHashProbe: SQLMetric): Iterator[InternalRow] = {
numOutputRows: SQLMetric): Iterator[InternalRow] = {

val joinedIter = joinType match {
case _: InnerLike =>
Expand All @@ -213,10 +212,6 @@ trait HashJoin {
s"BroadcastHashJoin should not take $x as the JoinType")
}

// At the end of the task, we update the avg hash probe.
TaskContext.get().addTaskCompletionListener(_ =>
avgHashProbe.set(hashed.getAverageProbesPerLookup))

val resultProj = createResultProjection
joinedIter.map { r =>
numOutputRows += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,6 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
* Release any used resources.
*/
def close(): Unit

/**
* Returns the average number of probes per key lookup.
*/
def getAverageProbesPerLookup: Double
}

private[execution] object HashedRelation {
Expand Down Expand Up @@ -280,8 +275,6 @@ private[joins] class UnsafeHashedRelation(
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
read(() => in.readInt(), () => in.readLong(), in.readBytes)
}

override def getAverageProbesPerLookup: Double = binaryMap.getAverageProbesPerLookup
}

private[joins] object UnsafeHashedRelation {
Expand Down Expand Up @@ -395,10 +388,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
// The number of unique keys.
private var numKeys = 0L

// Tracking average number of probes per key lookup.
private var numKeyLookups = 0L
private var numProbes = 0L

// needed by serializer
def this() = {
this(
Expand Down Expand Up @@ -483,8 +472,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
if (isDense) {
numKeyLookups += 1
numProbes += 1
if (key >= minKey && key <= maxKey) {
val value = array((key - minKey).toInt)
if (value > 0) {
Expand All @@ -493,14 +480,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
} else {
var pos = firstSlot(key)
numKeyLookups += 1
numProbes += 1
while (array(pos + 1) != 0) {
if (array(pos) == key) {
return getRow(array(pos + 1), resultRow)
}
pos = nextSlot(pos)
numProbes += 1
}
}
null
Expand Down Expand Up @@ -528,8 +512,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
if (isDense) {
numKeyLookups += 1
numProbes += 1
if (key >= minKey && key <= maxKey) {
val value = array((key - minKey).toInt)
if (value > 0) {
Expand All @@ -538,14 +520,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
} else {
var pos = firstSlot(key)
numKeyLookups += 1
numProbes += 1
while (array(pos + 1) != 0) {
if (array(pos) == key) {
return valueIter(array(pos + 1), resultRow)
}
pos = nextSlot(pos)
numProbes += 1
}
}
null
Expand Down Expand Up @@ -585,11 +564,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
private def updateIndex(key: Long, address: Long): Unit = {
var pos = firstSlot(key)
assert(numKeys < array.length / 2)
numKeyLookups += 1
numProbes += 1
while (array(pos) != key && array(pos + 1) != 0) {
pos = nextSlot(pos)
numProbes += 1
}
if (array(pos + 1) == 0) {
// this is the first value for this key, put the address in array.
Expand Down Expand Up @@ -721,8 +697,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
writeLong(maxKey)
writeLong(numKeys)
writeLong(numValues)
writeLong(numKeyLookups)
writeLong(numProbes)

writeLong(array.length)
writeLongArray(writeBuffer, array, array.length)
Expand Down Expand Up @@ -764,8 +738,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
maxKey = readLong()
numKeys = readLong()
numValues = readLong()
numKeyLookups = readLong()
numProbes = readLong()

val length = readLong().toInt
mask = length - 2
Expand All @@ -783,11 +755,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
override def read(kryo: Kryo, in: Input): Unit = {
read(() => in.readBoolean(), () => in.readLong(), in.readBytes)
}

/**
* Returns the average number of probes per key lookup.
*/
def getAverageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups
}

private[joins] class LongHashedRelation(
Expand Down Expand Up @@ -839,8 +806,6 @@ private[joins] class LongHashedRelation(
resultRow = new UnsafeRow(nFields)
map = in.readObject().asInstanceOf[LongToUnsafeRowMap]
}

override def getAverageProbesPerLookup: Double = map.getAverageProbesPerLookup
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ case class ShuffledHashJoinExec(
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"),
"avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))

override def requiredChildDistribution: Seq[Distribution] =
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
Expand All @@ -63,10 +62,9 @@ case class ShuffledHashJoinExec(

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val avgHashProbe = longMetric("avgHashProbe")
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
val hashed = buildHashedRelation(buildIter)
join(streamIter, hashed, numOutputRows, avgHashProbe)
join(streamIter, hashed, numOutputRows)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -231,50 +231,6 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
)
}

test("BroadcastHashJoin metrics: track avg probe") {
// The executed plan looks like:
// Project [a#210, b#211, b#221]
// +- BroadcastHashJoin [a#210], [a#220], Inner, BuildRight
// :- Project [_1#207 AS a#210, _2#208 AS b#211]
// : +- Filter isnotnull(_1#207)
// : +- LocalTableScan [_1#207, _2#208]
// +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, binary, true]))
// +- Project [_1#217 AS a#220, _2#218 AS b#221]
// +- Filter isnotnull(_1#217)
// +- LocalTableScan [_1#217, _2#218]
//
// Assume the execution plan with node id is
// WholeStageCodegen disabled:
// Project(nodeId = 0)
// BroadcastHashJoin(nodeId = 1)
// ...(ignored)
//
// WholeStageCodegen enabled:
// WholeStageCodegen(nodeId = 0)
// Project(nodeId = 1)
// BroadcastHashJoin(nodeId = 2)
// Project(nodeId = 3)
// Filter(nodeId = 4)
// ...(ignored)
Seq(true, false).foreach { enableWholeStage =>
val df1 = generateRandomBytesDF()
val df2 = generateRandomBytesDF()
val df = df1.join(broadcast(df2), "a")
val nodeIds = if (enableWholeStage) {
Set(2L)
} else {
Set(1L)
}
val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get
nodeIds.foreach { nodeId =>
val probes = metrics(nodeId)._2("avg hash probe (min, med, max)")
probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe =>
assert(probe.toDouble > 1.0)
}
}
}
}

test("ShuffledHashJoin metrics") {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "40",
"spark.sql.shuffle.partitions" -> "2",
Expand All @@ -287,59 +243,11 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
val metrics = getSparkPlanMetrics(df, 1, Set(1L))
testSparkPlanMetrics(df, 1, Map(
1L -> (("ShuffledHashJoin", Map(
"number of output rows" -> 2L,
"avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))))
"number of output rows" -> 2L))))
)
}
}

test("ShuffledHashJoin metrics: track avg probe") {
// The executed plan looks like:
// Project [a#308, b#309, b#319]
// +- ShuffledHashJoin [a#308], [a#318], Inner, BuildRight
// :- Exchange hashpartitioning(a#308, 2)
// : +- Project [_1#305 AS a#308, _2#306 AS b#309]
// : +- Filter isnotnull(_1#305)
// : +- LocalTableScan [_1#305, _2#306]
// +- Exchange hashpartitioning(a#318, 2)
// +- Project [_1#315 AS a#318, _2#316 AS b#319]
// +- Filter isnotnull(_1#315)
// +- LocalTableScan [_1#315, _2#316]
//
// Assume the execution plan with node id is
// WholeStageCodegen disabled:
// Project(nodeId = 0)
// ShuffledHashJoin(nodeId = 1)
// ...(ignored)
//
// WholeStageCodegen enabled:
// WholeStageCodegen(nodeId = 0)
// Project(nodeId = 1)
// ShuffledHashJoin(nodeId = 2)
// ...(ignored)
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "5000000",
"spark.sql.shuffle.partitions" -> "2",
"spark.sql.join.preferSortMergeJoin" -> "false") {
Seq(true, false).foreach { enableWholeStage =>
val df1 = generateRandomBytesDF(65535 * 5)
val df2 = generateRandomBytesDF(65535)
val df = df1.join(df2, "a")
val nodeIds = if (enableWholeStage) {
Set(2L)
} else {
Set(1L)
}
val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get
nodeIds.foreach { nodeId =>
val probes = metrics(nodeId)._2("avg hash probe (min, med, max)")
probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe =>
assert(probe.toDouble > 1.0)
}
}
}
}
}

test("BroadcastHashJoin(outer) metrics") {
val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
Expand Down

0 comments on commit ad3bdad

Please sign in to comment.