Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow returning an EmptyHashedRelation when a broadcast result is empty [databricks] #4256

Merged
merged 27 commits into from
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
26a8d37
Allow returning an EmptyHashedRelation when a broadcast result is empty
abellina Nov 17, 2021
6508550
Address review comments
abellina Dec 3, 2021
3f31a67
Revert change in integration tests
abellina Dec 6, 2021
b9684c2
Remove unwanted change
abellina Dec 6, 2021
0dffe9f
Minor cleanup
abellina Dec 6, 2021
5df6437
identity.length == 0 to identity.isEmpty
abellina Dec 6, 2021
d048334
Apply more suggested cleanup
abellina Dec 6, 2021
e021a00
Remove incRefCount
abellina Dec 6, 2021
01bd86f
EmptyHashedRelation was introduced in 3.1.x, so this fixes the shims
abellina Dec 7, 2021
282f402
Cache build schema outside of mapPartitions
abellina Dec 7, 2021
30beaf1
Fix bug with the broadcast helper + make a new test in join_test
abellina Dec 7, 2021
6c0d593
Adds a test that forces a broadcast for the EmptyHashedRelation scenario
abellina Dec 7, 2021
a3fc8ab
Fix typo
abellina Dec 7, 2021
4bba02f
Upmerge to 22.02
abellina Dec 7, 2021
2e83fa8
Fix typo
abellina Dec 7, 2021
f62167c
Adding isFoldableNonLitAllowed to UnaryExprMeta
abellina Dec 7, 2021
ee80225
Fix Spark 3.0.x build
abellina Dec 8, 2021
71b91f3
Also need to fix 30Xdb
abellina Dec 8, 2021
5ac0143
Move isEmptyRelation override to Spark31xdb
abellina Dec 8, 2021
012bdf0
Upmerge
abellina Dec 8, 2021
a41f97d
Disable in databricks and do some cleanup
abellina Dec 8, 2021
7fd4a20
Parametrize databricks so we don't request AQE when that is an invalid…
abellina Dec 8, 2021
6f9b75f
Cleanup
abellina Dec 8, 2021
a0704e2
Extra comment in RapidsMeta, take care of other review comments
abellina Dec 9, 2021
5c527eb
Fix import spacing
abellina Dec 9, 2021
438ff4c
Call the non-capturing assert
abellina Dec 9, 2021
34f2b59
Apply suggestion in GpuBroadcastHelper
abellina Dec 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuHashJoin, JoinTypeChecks, SerializeConcatHostBuffersDeserializeBatch}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuBroadcastHelper, GpuHashJoin, JoinTypeChecks}
import org.apache.spark.sql.vectorized.ColumnarBatch

class GpuBroadcastHashJoinMeta(
Expand Down Expand Up @@ -148,16 +148,16 @@ case class GpuBroadcastHashJoinExec(

val targetSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf)

val broadcastRelation = broadcastExchange
.executeColumnarBroadcast[SerializeConcatHostBuffersDeserializeBatch]()
val broadcastRelation = broadcastExchange.executeColumnarBroadcast[Any]()

val rdd = streamedPlan.executeColumnar()
rdd.mapPartitions { it =>
val stIt = new CollectTimeIterator("broadcast join stream", it, streamTime)
val builtBatch = broadcastRelation.value.batch
GpuColumnVector.extractBases(builtBatch).foreach(_.noWarnLeakExpected())
doJoin(builtBatch, stIt, targetSize, spillCallback,
numOutputRows, joinOutputRows, numOutputBatches, opTime, joinTime)
withResource(
GpuBroadcastHelper.getBroadcastBatch(broadcastRelation, buildPlan)) { builtBatch =>
doJoin(builtBatch, stIt, targetSize, spillCallback,
numOutputRows, joinOutputRows, numOutputBatches, opTime, joinTime)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuHashJoin, JoinTypeChecks, SerializeConcatHostBuffersDeserializeBatch}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuBroadcastHelper, GpuHashJoin, JoinTypeChecks}
import org.apache.spark.sql.vectorized.ColumnarBatch

class GpuBroadcastHashJoinMeta(
Expand Down Expand Up @@ -147,16 +147,16 @@ case class GpuBroadcastHashJoinExec(

val targetSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf)

val broadcastRelation = broadcastExchange
.executeColumnarBroadcast[SerializeConcatHostBuffersDeserializeBatch]()
val broadcastRelation = broadcastExchange.executeColumnarBroadcast[Any]()

val rdd = streamedPlan.executeColumnar()
rdd.mapPartitions { it =>
val stIt = new CollectTimeIterator("broadcast join stream", it, streamTime)
val builtBatch = broadcastRelation.value.batch
GpuColumnVector.extractBases(builtBatch).foreach(_.noWarnLeakExpected())
doJoin(builtBatch, stIt, targetSize, spillCallback,
numOutputRows, joinOutputRows, numOutputBatches, opTime, joinTime)
withResource(
GpuBroadcastHelper.getBroadcastBatch(broadcastRelation, buildPlan)) { builtBatch =>
doJoin(builtBatch, stIt, targetSize, spillCallback,
numOutputRows, joinOutputRows, numOutputBatches, opTime, joinTime)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuHashJoin, JoinTypeChecks, SerializeConcatHostBuffersDeserializeBatch}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuBroadcastHelper, GpuHashJoin, JoinTypeChecks}
import org.apache.spark.sql.vectorized.ColumnarBatch

class GpuBroadcastHashJoinMeta(
Expand Down Expand Up @@ -149,16 +149,16 @@ case class GpuBroadcastHashJoinExec(

val targetSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf)

val broadcastRelation = broadcastExchange
.executeColumnarBroadcast[SerializeConcatHostBuffersDeserializeBatch]()
val broadcastRelation = broadcastExchange.executeColumnarBroadcast[Any]()

val rdd = streamedPlan.executeColumnar()
rdd.mapPartitions { it =>
val stIt = new CollectTimeIterator("broadcast join stream", it, streamTime)
val builtBatch = broadcastRelation.value.batch
GpuColumnVector.extractBases(builtBatch).foreach(_.noWarnLeakExpected())
doJoin(builtBatch, stIt, targetSize, spillCallback,
numOutputRows, joinOutputRows, numOutputBatches, opTime, joinTime)
withResource(
GpuBroadcastHelper.getBroadcastBatch(broadcastRelation, buildPlan)) { builtBatch =>
doJoin(builtBatch, stIt, targetSize, spillCallback,
numOutputRows, joinOutputRows, numOutputBatches, opTime, joinTime)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.nvidia.spark.rapids.shims.v2

import com.nvidia.spark.rapids._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
Expand All @@ -28,7 +27,7 @@ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuHashJoin, JoinTypeChecks, SerializeConcatHostBuffersDeserializeBatch}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuBroadcastHelper, GpuHashJoin, JoinTypeChecks}
import org.apache.spark.sql.vectorized.ColumnarBatch

class GpuBroadcastHashJoinMeta(
Expand Down Expand Up @@ -148,16 +147,16 @@ case class GpuBroadcastHashJoinExec(

val targetSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf)

val broadcastRelation = broadcastExchange
.executeColumnarBroadcast[SerializeConcatHostBuffersDeserializeBatch]()
val broadcastRelation = broadcastExchange.executeColumnarBroadcast[Any]()

val rdd = streamedPlan.executeColumnar()
rdd.mapPartitions { it =>
val stIt = new CollectTimeIterator("broadcast join stream", it, streamTime)
val builtBatch = broadcastRelation.value.batch
GpuColumnVector.extractBases(builtBatch).foreach(_.noWarnLeakExpected())
doJoin(builtBatch, stIt, targetSize, spillCallback,
numOutputRows, joinOutputRows, numOutputBatches, opTime, joinTime)
withResource(
GpuBroadcastHelper.getBroadcastBatch(broadcastRelation, buildPlan)) { builtBatch =>
doJoin(builtBatch, stIt, targetSize, spillCallback,
numOutputRows, joinOutputRows, numOutputBatches, opTime, joinTime)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -303,28 +303,36 @@ abstract class GpuBroadcastExchangeExecBase(
// Setup a job group here so later it may get cancelled by groupId if necessary.
sparkContext.setJobGroup(_runId.toString, s"broadcast exchange (runId ${_runId})",
interruptOnCancel = true)
val batch = withResource(new NvtxWithMetrics("broadcast collect", NvtxColor.GREEN,
collectTime)) { _ =>
val data = child.executeColumnar().map(cb => try {
new SerializeBatchDeserializeHostBuffer(cb)
} finally {
cb.close()
})
val d = data.collect()
new SerializeConcatHostBuffersDeserializeBatch(d, output)
}

val numRows = batch.numRows
checkRowLimit(numRows)
numOutputBatches += 1
numOutputRows += numRows
var dataSize = 0L
val broadcastResult =
withResource(new NvtxWithMetrics("broadcast collect", NvtxColor.GREEN,
collectTime)) { _ =>
val childRdd = child.executeColumnar()
val data = childRdd.map(cb => try {
new SerializeBatchDeserializeHostBuffer(cb)
} finally {
cb.close()
})
val d = data.collect()
if (d.length == 0) {
// This call for `HashedRelationBroadcastMode` produces
// `EmptyHashedRelation` allowing the AQE rule `EliminateJoinToEmptyRelation` to
// optimize out our parent join given that this is an empty broadcast result.
mode.transform(Iterator.empty, None)
} else {
val batch = new SerializeConcatHostBuffersDeserializeBatch(d, output)
val numRows = batch.numRows
checkRowLimit(numRows)
numOutputBatches += 1
numOutputRows += numRows
batch
}
}

withResource(new NvtxWithMetrics("broadcast build", NvtxColor.DARK_GREEN,
buildTime)) { _ =>
// we only support hashjoin so this is a noop
// val relation = mode.transform(input, Some(numRows))
val dataSize = batch.dataSize

gpuLongMetric("dataSize") += dataSize
if (dataSize >= MAX_BROADCAST_TABLE_BYTES) {
throw new SparkException(
Expand All @@ -335,7 +343,7 @@ abstract class GpuBroadcastExchangeExecBase(
val broadcasted = withResource(new NvtxWithMetrics("broadcast", NvtxColor.CYAN,
broadcastTime)) { _ =>
// Broadcast the relation
sparkContext.broadcast(batch.asInstanceOf[Any])
sparkContext.broadcast(broadcastResult.asInstanceOf[Any])
}

SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids.execution

import com.nvidia.spark.rapids.GpuColumnVector

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.EmptyHashedRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

object GpuBroadcastHelper {
  /**
   * Given a broadcast relation, get a `ColumnarBatch` that can be used on the GPU.
   *
   * The broadcast relation may or may not contain any data, so the empty cases
   * are special-cased: `EmptyHashedRelation` (produced by
   * `HashedRelationBroadcastMode` for hash joins) and an empty `Array`
   * (produced by `IdentityBroadcastMode`, as used by broadcast nested loop
   * joins) both yield an empty batch built from the broadcast plan's schema.
   *
   * If a broadcast result of any other type is received we throw. No other
   * cases are known at this moment, so this is a defensive measure.
   *
   * @param broadcastRelation the broadcast as produced by a broadcast exchange
   * @param broadcastPlan the SparkPlan used to obtain the schema for the
   *                      empty-batch cases
   * @return a `ColumnarBatch`, or throw if the broadcast can't be handled
   */
  def getBroadcastBatch(broadcastRelation: Broadcast[Any],
      broadcastPlan: SparkPlan): ColumnarBatch = {
    broadcastRelation.value match {
      case broadcastBatch: SerializeConcatHostBuffersDeserializeBatch =>
        val builtBatch = broadcastBatch.batch
        // The caller takes ownership of the returned batch, so bump the
        // reference counts; the broadcast keeps its own reference to the
        // underlying columns.
        GpuColumnVector.incRefCounts(builtBatch)
        builtBatch
      case EmptyHashedRelation =>
        // Empty build side of a hash join: return an empty batch with the
        // build-side schema so downstream join code can proceed uniformly.
        GpuColumnVector.emptyBatch(broadcastPlan.schema)
      case identity: Array[_] if identity.isEmpty =>
        // A broadcast nested loop join uses `IdentityBroadcastMode`, which when
        // transformed can produce an Array[InternalRow]. Only the empty result
        // is handled here; non-empty results are expected to arrive as
        // `SerializeConcatHostBuffersDeserializeBatch`.
        // NOTE: `Array[_]` (rather than `Array[Any]`) avoids relying on the
        // erased element type of the array.
        GpuColumnVector.emptyBatch(broadcastPlan.schema)
      case t =>
        throw new IllegalStateException(s"Invalid broadcast batch received $t")
    }
  }

  /**
   * Given a broadcast relation, get the number of rows that the received batch
   * contains.
   *
   * The broadcast relation may or may not contain any data, so the empty
   * relation cases (hash or identity, depending on the type of join) report
   * zero rows without materializing a batch.
   *
   * If a broadcast result of any other type is received we throw. No other
   * cases are known at this moment, so this is a defensive measure.
   *
   * @param broadcastRelation the broadcast as produced by a broadcast exchange
   * @return number of rows for the batch received, or 0 for an empty relation
   */
  def getBroadcastBatchNumRows(broadcastRelation: Broadcast[Any]): Int = {
    broadcastRelation.value match {
      case broadcastBatch: SerializeConcatHostBuffersDeserializeBatch =>
        broadcastBatch.batch.numRows()
      case EmptyHashedRelation => 0
      case identity: Array[_] if identity.isEmpty => 0
      case t =>
        throw new IllegalStateException(s"Invalid broadcast batch received $t")
    }
  }
}

Original file line number Diff line number Diff line change
Expand Up @@ -435,23 +435,26 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
}

private[this] def makeBuiltBatch(
broadcastRelation: Broadcast[SerializeConcatHostBuffersDeserializeBatch],
broadcastRelation: Broadcast[Any],
buildTime: GpuMetric,
buildDataSize: GpuMetric): ColumnarBatch = {
withResource(new NvtxWithMetrics("build join table", NvtxColor.GREEN, buildTime)) { _ =>
val ret = broadcastRelation.value.batch
buildDataSize += GpuColumnVector.getTotalDeviceMemoryUsed(ret)
GpuColumnVector.incRefCounts(ret)
withResource(GpuBroadcastHelper.getBroadcastBatch(
broadcastRelation, broadcast)) { builtBatch =>
GpuColumnVector.incRefCounts(builtBatch)
buildDataSize += GpuColumnVector.getTotalDeviceMemoryUsed(builtBatch)
builtBatch
}
}
}

private[this] def computeBuildRowCount(
broadcastRelation: Broadcast[SerializeConcatHostBuffersDeserializeBatch],
broadcastRelation: Broadcast[Any],
buildTime: GpuMetric,
buildDataSize: GpuMetric): Int = {
withResource(new NvtxWithMetrics("build join table", NvtxColor.GREEN, buildTime)) { _ =>
buildDataSize += 0
broadcastRelation.value.batch.numRows()
GpuBroadcastHelper.getBroadcastBatchNumRows(broadcastRelation)
}
}

Expand All @@ -468,7 +471,7 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
}

val broadcastRelation =
broadcastExchange.executeColumnarBroadcast[SerializeConcatHostBuffersDeserializeBatch]()
broadcastExchange.executeColumnarBroadcast[Any]()

val joinCondition = boundCondition.orElse {
// For outer joins use a true condition if there are any columns in the build side
Expand All @@ -489,7 +492,7 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
}

private def leftExistenceJoin(
broadcastRelation: Broadcast[SerializeConcatHostBuffersDeserializeBatch],
broadcastRelation: Broadcast[Any],
exists: Boolean,
buildTime: GpuMetric,
buildDataSize: GpuMetric): RDD[ColumnarBatch] = {
Expand All @@ -504,9 +507,7 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
}
}

private def doUnconditionalJoin(
broadcastRelation: Broadcast[SerializeConcatHostBuffersDeserializeBatch]
): RDD[ColumnarBatch] = {
private def doUnconditionalJoin(broadcastRelation: Broadcast[Any]): RDD[ColumnarBatch] = {
if (output.isEmpty) {
doUnconditionalJoinRowCount(broadcastRelation)
} else {
Expand Down Expand Up @@ -565,9 +566,7 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
}

/** Special-case handling of an unconditional join that just needs to output a row count. */
private def doUnconditionalJoinRowCount(
broadcastRelation: Broadcast[SerializeConcatHostBuffersDeserializeBatch]
): RDD[ColumnarBatch] = {
private def doUnconditionalJoinRowCount(broadcastRelation: Broadcast[Any]): RDD[ColumnarBatch] = {
if (joinType == LeftAnti) {
// degenerate case, no rows are returned.
left.executeColumnar().mapPartitions { _ =>
Expand Down Expand Up @@ -604,13 +603,13 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
}

private def doConditionalJoin(
broadcastRelation: Broadcast[SerializeConcatHostBuffersDeserializeBatch],
broadcastRelation: Broadcast[Any],
boundCondition: Option[GpuExpression],
numFirstTableColumns: Int): RDD[ColumnarBatch] = {
val buildTime = gpuLongMetric(BUILD_TIME)
val buildDataSize = gpuLongMetric(BUILD_DATA_SIZE)
val spillCallback = GpuMetric.makeSpillCallback(allMetrics)
lazy val builtBatch = makeBuiltBatch(broadcastRelation, buildTime, buildDataSize)

val streamAttributes = streamed.output
val numOutputRows = gpuLongMetric(NUM_OUTPUT_ROWS)
val numOutputBatches = gpuLongMetric(NUM_OUTPUT_BATCHES)
Expand All @@ -619,7 +618,6 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
val joinOutputRows = gpuLongMetric(JOIN_OUTPUT_ROWS)
val nestedLoopJoinType = joinType
val buildSide = getGpuBuildSide
val spillCallback = GpuMetric.makeSpillCallback(allMetrics)
streamed.executeColumnar().mapPartitions { streamedIter =>
val lazyStream = streamedIter.map { cb =>
withResource(cb) { cb =>
Expand Down