From 2d7c030065d206418c08d278f0808c3ef3e7cde4 Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Sat, 17 Oct 2015 00:07:27 +0800
Subject: [PATCH] Add a unit test

---
 .../spark/scheduler/TaskSetManager.scala           |  2 +-
 .../scheduler/ReceiverTrackerSuite.scala           | 24 +++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index c02597c4365c9..947feda3b34aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -388,7 +388,7 @@ private[spark] class TaskSetManager(
     if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) {
       // Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic
       for (index <- dequeueTaskFromList(execId, pendingTasksWithNoPrefs)) {
-        return Some((index, TaskLocality.PROCESS_LOCAL, false))
+        return Some((index, TaskLocality.NO_PREF, false))
       }
     }
 
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index 45138b748ecab..fda86aef457d4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart, TaskLocality}
+import org.apache.spark.scheduler.TaskLocality.TaskLocality
 import org.apache.spark.storage.{StorageLevel, StreamBlockId}
 import org.apache.spark.streaming._
 import org.apache.spark.streaming.dstream.ReceiverInputDStream
@@ -80,6 +82,28 @@ class ReceiverTrackerSuite extends TestSuiteBase {
       }
     }
   }
+
+  test("SPARK-11063: TaskSetManager should use Receiver RDD's preferredLocations") {
+    // Use ManualClock to prevent batches from starting so that we can make sure the only task is
+    // for starting the Receiver
+    val _conf = conf.clone.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
+    withStreamingContext(new StreamingContext(_conf, Milliseconds(100))) { ssc =>
+      @volatile var receiverTaskLocality: TaskLocality = null
+      ssc.sparkContext.addSparkListener(new SparkListener {
+        override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+          receiverTaskLocality = taskStart.taskInfo.taskLocality
+        }
+      })
+      val input = ssc.receiverStream(new TestReceiver)
+      val output = new TestOutputStream(input)
+      output.register()
+      ssc.start()
+      eventually(timeout(10 seconds), interval(10 millis)) {
+        // If preferredLocations is set correctly, receiverTaskLocality should be NODE_LOCAL
+        assert(receiverTaskLocality === TaskLocality.NODE_LOCAL)
+      }
+    }
+  }
 }
 
 /** An input DStream with for testing rate controlling */
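
Note (not part of the patch): the one-line TaskSetManager change above affects which TaskLocality is reported for tasks that have no location preference. Below is a minimal standalone sketch of the same observation path the new test relies on, assuming a local-mode SparkContext of roughly this Spark version; the object name, the empty host lists, and the sleep are illustrative only. It builds an RDD whose partitions have no preferred locations, so its tasks land in pendingTasksWithNoPrefs and, with this patch, are reported as NO_PREF rather than PROCESS_LOCAL.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
import org.apache.spark.scheduler.TaskLocality.TaskLocality

// Hypothetical sketch, not part of the patch or the Spark code base.
object NoPrefLocalitySketch {
  // Last task locality reported by the scheduler; written from the listener thread.
  @volatile private var lastLocality: TaskLocality = null

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("no-pref-locality-sketch"))
    try {
      sc.addSparkListener(new SparkListener {
        override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
          lastLocality = taskStart.taskInfo.taskLocality
        }
      })
      // makeRDD accepts (element, preferredHosts) pairs; empty host lists put the
      // resulting tasks into TaskSetManager's pendingTasksWithNoPrefs bucket.
      val rdd = sc.makeRDD(Seq((1, Seq.empty[String]), (2, Seq.empty[String])))
      rdd.count()
      // Listener events are delivered asynchronously; give the bus a moment.
      Thread.sleep(500)
      // With the patch this prints NO_PREF; before it, the same tasks were
      // reported as PROCESS_LOCAL.
      println(s"reported locality: $lastLocality")
    } finally {
      sc.stop()
    }
  }
}

In the streaming test added above, by contrast, the receiver RDD does carry a preferred location, so the same kind of listener is expected to observe NODE_LOCAL.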