diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cae22d733d2f8..efb5f9d501e48 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1571,6 +1571,15 @@ class SparkContext(config: SparkConf) extends Logging { listenerBus.addListener(listener) } + /** + * :: DeveloperApi :: + * Deregister the listener from Spark's listener bus. + */ + @DeveloperApi + def removeSparkListener(listener: SparkListenerInterface): Unit = { + listenerBus.removeListener(listener) + } + private[spark] def getExecutorIds(): Seq[String] = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index c451c596b069a..8fba82de5462d 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ +import org.apache.spark.scheduler.SparkListener import org.apache.spark.util.Utils class SparkContextSuite extends SparkFunSuite with LocalSparkContext { @@ -451,4 +452,19 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } } + + test("register and deregister Spark listener from SparkContext") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + try { + val sparkListener1 = new SparkListener { } + val sparkListener2 = new SparkListener { } + sc.addSparkListener(sparkListener1) + sc.addSparkListener(sparkListener2) + assert(sc.listenerBus.listeners.contains(sparkListener1)) + assert(sc.listenerBus.listeners.contains(sparkListener2)) + sc.removeSparkListener(sparkListener1) + assert(!sc.listenerBus.listeners.contains(sparkListener1)) + assert(sc.listenerBus.listeners.contains(sparkListener2)) + } + } }