[GLUTEN-4882] ColumnarBroadcastExchangeExec should set/cancel with job tag for Spark3.5 (apache#4882)

ulysses-you authored and taiyang-li committed Oct 9, 2024
1 parent 8bcd2c8 commit 6222d29
Showing 6 changed files with 84 additions and 13 deletions.
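
For context: the patch replaces the direct sparkContext.setJobGroup / cancelJobGroup calls with a shim, because Spark 3.5 switched broadcast-job cancellation from job groups to job tags. Below is a minimal sketch of the two SparkContext call patterns being abstracted over; the object and method names are illustrative, not part of the commit, and the real runId/jobTag values come from BroadcastExchangeLike.

import java.util.UUID

import org.apache.spark.SparkContext

object BroadcastCancellationSketch {
  // Spark 3.2-3.4: run the broadcast job under a job group keyed by the
  // exchange's runId, then cancel the whole group if the broadcast times out.
  def withJobGroup(sc: SparkContext, runId: UUID): Unit = {
    sc.setJobGroup(
      runId.toString,
      s"broadcast exchange (runId $runId)",
      interruptOnCancel = true)
    // ... run the collect job that builds the broadcast relation ...
    sc.cancelJobGroup(runId.toString) // only reached on timeout
  }

  // Spark 3.5+: job groups are replaced by job tags for this purpose
  // (see https://github.com/apache/spark/pull/41440), so tag the jobs
  // and cancel by tag instead.
  def withJobTag(sc: SparkContext, jobTag: String): Unit = {
    sc.addJobTag(jobTag)
    sc.setInterruptOnCancel(true)
    // ... run the collect job that builds the broadcast relation ...
    sc.cancelJobsWithTag(jobTag) // only reached on timeout
  }
}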
ColumnarBroadcastExchangeExec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 import io.glutenproject.backendsapi.BackendsApiManager
 import io.glutenproject.extension.{GlutenPlan, ValidationResult}
 import io.glutenproject.metrics.GlutenTimeMetric
+import io.glutenproject.sql.shims.SparkShimLoader
 
 import org.apache.spark.{broadcast, SparkException}
 import org.apache.spark.launcher.SparkLauncher
@@ -59,11 +60,7 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
       session,
       BroadcastExchangeExec.executionContext) {
       try {
-        // Setup a job group here so later it may get cancelled by groupId if necessary.
-        sparkContext.setJobGroup(
-          runId.toString,
-          s"broadcast exchange (runId $runId)",
-          interruptOnCancel = true)
+        SparkShimLoader.getSparkShims.setJobDescriptionOrTagForBroadcastExchange(sparkContext, this)
         val relation = GlutenTimeMetric.millis(longMetric("collectTime")) {
           _ =>
             // this created relation ignore HashedRelationBroadcastMode isNullAware, because we
@@ -169,7 +166,7 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
       case ex: TimeoutException =>
         logError(s"Could not execute broadcast in $timeout secs.", ex)
         if (!relationFuture.isDone) {
-          sparkContext.cancelJobGroup(runId.toString)
+          SparkShimLoader.getSparkShims.cancelJobGroupForBroadcastExchange(sparkContext, this)
           relationFuture.cancel(true)
         }
         throw new SparkException(
SparkShims.scala
@@ -18,7 +18,7 @@ package io.glutenproject.sql.shims
 
 import io.glutenproject.expression.Sig
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkContext, TaskContext}
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.scheduler.TaskInfo
 import org.apache.spark.shuffle.{ShuffleHandle, ShuffleReader}
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.{FileSourceScanExec, GlobalLimitExec, Spar
 import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionDirectory, PartitionedFile, PartitioningAwareFileIndex, WriteJobDescription, WriteTaskResult}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.text.TextScan
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.storage.{BlockId, BlockManagerId}
@@ -117,6 +117,15 @@ trait SparkShims {
 
   def createTestTaskContext(): TaskContext
 
+  // To be compatible with Spark-3.5 and later
+  // See https://github.com/apache/spark/pull/41440
+  def setJobDescriptionOrTagForBroadcastExchange(
+      sc: SparkContext,
+      broadcastExchange: BroadcastExchangeLike): Unit
+  def cancelJobGroupForBroadcastExchange(
+      sc: SparkContext,
+      broadcastExchange: BroadcastExchangeLike): Unit
+
   def getShuffleReaderParam[K, C](
       handle: ShuffleHandle,
       startMapIndex: Int,
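Declaring the pair as abstract members of SparkShims keeps ColumnarBroadcastExchangeExec itself version-agnostic: the operator calls one set/cancel pair, and each per-version shim below maps it onto whichever cancellation API that Spark release provides (job groups through Spark 3.4, job tags from Spark 3.5).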
Spark32Shims.scala
@@ -20,7 +20,7 @@ import io.glutenproject.execution.datasource.GlutenParquetWriterInjects
 import io.glutenproject.expression.{ExpressionNames, Sig}
 import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
 
-import org.apache.spark.{ShuffleUtils, TaskContext, TaskContextUtils}
+import org.apache.spark.{ShuffleUtils, SparkContext, TaskContext, TaskContextUtils}
 import org.apache.spark.scheduler.TaskInfo
 import org.apache.spark.shuffle.ShuffleHandle
 import org.apache.spark.sql.{AnalysisException, SparkSession}
@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.text.TextScan
 import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.storage.{BlockId, BlockManagerId}
@@ -124,6 +125,22 @@ class Spark32Shims extends SparkShims {
     TaskContextUtils.createTestTaskContext()
   }
 
+  def setJobDescriptionOrTagForBroadcastExchange(
+      sc: SparkContext,
+      broadcastExchange: BroadcastExchangeLike): Unit = {
+    // Setup a job group here so later it may get cancelled by groupId if necessary.
+    sc.setJobGroup(
+      broadcastExchange.runId.toString,
+      s"broadcast exchange (runId ${broadcastExchange.runId})",
+      interruptOnCancel = true)
+  }
+
+  def cancelJobGroupForBroadcastExchange(
+      sc: SparkContext,
+      broadcastExchange: BroadcastExchangeLike): Unit = {
+    sc.cancelJobGroup(broadcastExchange.runId.toString)
+  }
+
   override def getShuffleReaderParam[K, C](
       handle: ShuffleHandle,
       startMapIndex: Int,
Spark33Shims.scala
@@ -21,7 +21,7 @@ import io.glutenproject.execution.datasource.GlutenParquetWriterInjects
 import io.glutenproject.expression.{ExpressionNames, Sig}
 import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
 
-import org.apache.spark.{ShuffleDependency, ShuffleUtils, SparkEnv, SparkException, TaskContext, TaskContextUtils}
+import org.apache.spark.{ShuffleDependency, ShuffleUtils, SparkContext, SparkEnv, SparkException, TaskContext, TaskContextUtils}
 import org.apache.spark.scheduler.TaskInfo
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.ShuffleHandle
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.text.TextScan
 import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.storage.{BlockId, BlockManagerId}
@@ -167,6 +168,22 @@ class Spark33Shims extends SparkShims {
     TaskContextUtils.createTestTaskContext()
   }
 
+  def setJobDescriptionOrTagForBroadcastExchange(
+      sc: SparkContext,
+      broadcastExchange: BroadcastExchangeLike): Unit = {
+    // Setup a job group here so later it may get cancelled by groupId if necessary.
+    sc.setJobGroup(
+      broadcastExchange.runId.toString,
+      s"broadcast exchange (runId ${broadcastExchange.runId})",
+      interruptOnCancel = true)
+  }
+
+  def cancelJobGroupForBroadcastExchange(
+      sc: SparkContext,
+      broadcastExchange: BroadcastExchangeLike): Unit = {
+    sc.cancelJobGroup(broadcastExchange.runId.toString)
+  }
+
   override def getShuffleReaderParam[K, C](
       handle: ShuffleHandle,
       startMapIndex: Int,
Spark34Shims.scala
@@ -20,7 +20,7 @@ import io.glutenproject.GlutenConfig
 import io.glutenproject.expression.{ExpressionNames, Sig}
 import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
 
-import org.apache.spark.{ShuffleUtils, SparkException, TaskContext, TaskContextUtils}
+import org.apache.spark.{ShuffleUtils, SparkContext, SparkException, TaskContext, TaskContextUtils}
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.paths.SparkPath
 import org.apache.spark.scheduler.TaskInfo
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.text.TextScan
 import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.storage.{BlockId, BlockManagerId}
@@ -205,6 +206,22 @@ class Spark34Shims extends SparkShims {
     TaskContextUtils.createTestTaskContext()
   }
 
+  def setJobDescriptionOrTagForBroadcastExchange(
+      sc: SparkContext,
+      broadcastExchange: BroadcastExchangeLike): Unit = {
+    // Setup a job group here so later it may get cancelled by groupId if necessary.
+    sc.setJobGroup(
+      broadcastExchange.runId.toString,
+      s"broadcast exchange (runId ${broadcastExchange.runId})",
+      interruptOnCancel = true)
+  }
+
+  def cancelJobGroupForBroadcastExchange(
+      sc: SparkContext,
+      broadcastExchange: BroadcastExchangeLike): Unit = {
+    sc.cancelJobGroup(broadcastExchange.runId.toString)
+  }
+
   override def getShuffleReaderParam[K, C](
       handle: ShuffleHandle,
       startMapIndex: Int,
Spark35Shims.scala
@@ -20,7 +20,7 @@ import io.glutenproject.GlutenConfig
 import io.glutenproject.expression.{ExpressionNames, Sig}
 import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}
 
-import org.apache.spark.{ShuffleUtils, SparkException, TaskContext, TaskContextUtils}
+import org.apache.spark.{ShuffleUtils, SparkContext, SparkException, TaskContext, TaskContextUtils}
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.paths.SparkPath
 import org.apache.spark.scheduler.TaskInfo
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.text.TextScan
 import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.storage.{BlockId, BlockManagerId}
@@ -205,6 +205,20 @@ class Spark35Shims extends SparkShims {
     TaskContextUtils.createTestTaskContext()
   }
 
+  override def setJobDescriptionOrTagForBroadcastExchange(
+      sc: SparkContext,
+      broadcastExchange: BroadcastExchangeLike): Unit = {
+    // Setup a job tag here so later it may get cancelled by tag if necessary.
+    sc.addJobTag(broadcastExchange.jobTag)
+    sc.setInterruptOnCancel(true)
+  }
+
+  override def cancelJobGroupForBroadcastExchange(
+      sc: SparkContext,
+      broadcastExchange: BroadcastExchangeLike): Unit = {
+    sc.cancelJobsWithTag(broadcastExchange.jobTag)
+  }
+
   override def getShuffleReaderParam[K, C](
       handle: ShuffleHandle,
       startMapIndex: Int,
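Note that the Spark 3.5 shim keeps the cancelJobGroupForBroadcastExchange name even though it cancels by tag, so the shared call site in ColumnarBroadcastExchangeExec stays identical across versions; the tag comes from BroadcastExchangeLike.jobTag, and setInterruptOnCancel(true) preserves the interruptOnCancel = true behavior of the old job-group setup.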
