diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index f4be03c90be75..7c03bad90ebbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicLong
 import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
 
-import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkEnv, SparkException, SparkThrowable, SparkThrowableHelper}
+import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkContext, SparkEnv, SparkException, SparkThrowable, SparkThrowableHelper}
 import org.apache.spark.SparkContext.{SPARK_JOB_DESCRIPTION, SPARK_JOB_INTERRUPT_ON_CANCEL}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{SPARK_DRIVER_PREFIX, SPARK_EXECUTOR_PREFIX}
@@ -128,7 +128,8 @@ object SQLExecution extends Logging {
         sparkPlanInfo = SparkPlanInfo.EMPTY,
         time = System.currentTimeMillis(),
         modifiedConfigs = redactedConfigs,
-        jobTags = sc.getJobTags()
+        jobTags = sc.getJobTags(),
+        jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
       )
       try {
         body match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index bf33ba2c96f19..dcbf328c71e33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -343,7 +343,7 @@ class SQLAppStatusListener(
 
   private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = {
     val SparkListenerSQLExecutionStart(executionId, rootExecutionId, description, details,
-      physicalPlanDescription, sparkPlanInfo, time, modifiedConfigs, _) = event
+      physicalPlanDescription, sparkPlanInfo, time, modifiedConfigs, _, _) = event
 
     val planGraph = SparkPlanGraph(sparkPlanInfo)
     val sqlPlanMetrics = planGraph.allNodes.flatMap { node =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 3a22dd23548fc..416b9547b0462 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -54,7 +54,8 @@ case class SparkListenerSQLExecutionStart(
     sparkPlanInfo: SparkPlanInfo,
     time: Long,
     modifiedConfigs: Map[String, String] = Map.empty,
-    jobTags: Set[String] = Set.empty)
+    jobTags: Set[String] = Set.empty,
+    jobGroupId: Option[String] = None)
   extends SparkListenerEvent
 
 @DeveloperApi
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
index b8a109919f8f6..94d33731b6de5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
@@ -227,6 +227,7 @@ class SQLExecutionSuite extends SparkFunSuite with SQLConfHelper {
 
       spark.range(1).collect()
 
+      spark.sparkContext.listenerBus.waitUntilEmpty()
       assert(jobTags.contains(jobTag))
       assert(sqlJobTags.contains(jobTag))
     } finally {
@@ -234,6 +235,39 @@ class SQLExecutionSuite extends SparkFunSuite with SQLConfHelper {
       spark.stop()
     }
   }
+
+  test("jobGroupId property") {
+    val spark = SparkSession.builder().master("local[*]").appName("test").getOrCreate()
+    val JobGroupId = "test-JobGroupId"
+    try {
+      spark.sparkContext.setJobGroup(JobGroupId, "job Group id")
+
+      var jobGroupIdOpt: Option[String] = None
+      var sqlJobGroupIdOpt: Option[String] = None
+      spark.sparkContext.addSparkListener(new SparkListener {
+        override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+          jobGroupIdOpt = Some(jobStart.properties.getProperty(SparkContext.SPARK_JOB_GROUP_ID))
+        }
+
+        override def onOtherEvent(event: SparkListenerEvent): Unit = {
+          event match {
+            case e: SparkListenerSQLExecutionStart =>
+              sqlJobGroupIdOpt = e.jobGroupId
+            case _ => // ignore other events
+          }
+        }
+      })
+
+      spark.range(1).collect()
+
+      spark.sparkContext.listenerBus.waitUntilEmpty()
+      assert(jobGroupIdOpt.contains(JobGroupId))
+      assert(sqlJobGroupIdOpt.contains(JobGroupId))
+    } finally {
+      spark.sparkContext.clearJobGroup()
+      spark.stop()
+    }
+  }
 }
 
 object SQLExecutionSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index 17e77cf8d8fb3..e63ff019a2b6c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -344,7 +344,7 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
     val listener = new SparkListener {
       override def onOtherEvent(event: SparkListenerEvent): Unit = {
         event match {
-          case SparkListenerSQLExecutionStart(_, _, _, _, planDescription, _, _, _, _) =>
+          case SparkListenerSQLExecutionStart(_, _, _, _, planDescription, _, _, _, _, _) =>
             assert(expected.forall(planDescription.contains))
             checkDone = true
           case _ => // ignore other events
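For reference, a minimal sketch (not part of the patch) of how a downstream listener might consume the new jobGroupId field once it is available on SparkListenerSQLExecutionStart. The class name JobGroupExecutionListener and the executionsByGroup map are illustrative assumptions, not Spark APIs:

import scala.collection.mutable

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart

// Hypothetical listener: correlates SQL executions with the job group
// that submitted them, using the jobGroupId field added by this patch.
class JobGroupExecutionListener extends SparkListener {
  // executionId -> jobGroupId; only populated when a job group was set
  // on the submitting thread (jobGroupId is None otherwise).
  private val executionsByGroup = mutable.Map.empty[Long, String]

  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case e: SparkListenerSQLExecutionStart =>
      e.jobGroupId.foreach(group => executionsByGroup(e.executionId) = group)
    case _ => // ignore other events
  }

  // All SQL execution ids observed for a given job group.
  def executionsFor(groupId: String): Seq[Long] =
    executionsByGroup.collect { case (id, group) if group == groupId => id }.toSeq
}

// Usage: spark.sparkContext.addSparkListener(new JobGroupExecutionListener)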