[SPY-317] Backports of SPARK-1582, SPARK-1601 and SPARK-1602 #6

Merged 3 commits on Apr 24, 2014
12 changes: 10 additions & 2 deletions core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -46,7 +46,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (loading.contains(key)) {
logInfo("Another thread is loading %s, waiting for it to finish...".format(key))
while (loading.contains(key)) {
try {loading.wait()} catch {case _ : Throwable =>}
try {
loading.wait()
} catch {
case e: Exception =>
logWarning("Got an exception while waiting for another thread to load " + key, e)
}
}
logInfo("Finished waiting for %s".format(key))
// See whether someone else has successfully loaded it. The main way this would fail
@@ -74,7 +79,10 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val elements = new ArrayBuffer[Any]
elements ++= computedValues
blockManager.put(key, elements, storageLevel, tellMaster = true)
return elements.iterator.asInstanceOf[Iterator[T]]
val returnValue: Iterator[T] = elements.iterator.asInstanceOf[Iterator[T]]

new InterruptibleIterator(context, returnValue)

} finally {
loading.synchronized {
loading.remove(key)
12 changes: 11 additions & 1 deletion core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -24,7 +24,17 @@ package org.apache.spark
class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
extends Iterator[T] {

def hasNext: Boolean = !context.interrupted && delegate.hasNext
def hasNext: Boolean = {
// TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt
// is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
// (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
// introduces an expensive read fence.
if (context.interrupted) {
throw new TaskKilledException
} else {
delegate.hasNext
}
}

def next(): T = delegate.next()
}
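For readers following the backport, here is a minimal behavioural sketch of the new contract (illustrative only; it assumes the four-argument TaskContext constructor used in Task.run below, and it sits in the org.apache.spark package because TaskKilledException is private[spark]):

package org.apache.spark

object InterruptibleIteratorSketch {
  def main(args: Array[String]) {
    // A context whose interrupted flag we flip by hand, as Task.kill() would do for a real task.
    val context = new TaskContext(0, 0, 0, runningLocally = true)
    val it = new InterruptibleIterator(context, Iterator(1, 2, 3))

    println(it.next())           // 1 -- hasNext/next simply delegate while the task is alive
    context.interrupted = true   // what Task.kill() sets on the running task's context

    try {
      it.hasNext                 // now throws instead of quietly reporting an exhausted iterator
    } catch {
      case _: TaskKilledException => println("iteration stopped: task was killed")
    }
  }
}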
15 changes: 14 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -355,16 +355,27 @@ class SparkContext(
* // In a separate thread:
* sc.cancelJobGroup("some_job_to_cancel")
* }}}
*
* If interruptOnCancel is set to true for the job group, then job cancellation will result
* in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure
* that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208,
* where HDFS may respond to Thread.interrupt() by marking nodes as dead.
*/
def setJobGroup(groupId: String, description: String) {
def setJobGroup(groupId: String, description: String, interruptOnCancel: Boolean = false) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
// Note: Specifying interruptOnCancel in setJobGroup (rather than cancelJobGroup) avoids
// changing several public APIs and allows Spark cancellations outside of the cancelJobGroup
// APIs to also take advantage of this property (e.g., internal job failures or canceling from
// JobProgressTab UI) on a per-job basis.
setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString)
}

/** Clear the job group id and its description. */
def clearJobGroup() {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null)
}

// Post init
@@ -1022,6 +1033,8 @@ object SparkContext {

private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"

private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel"

private[spark] val SPARK_UNKNOWN_USER = "<unknown>"

implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
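As a usage note (a hedged sketch rather than part of the patch; the group id, sleep, and input path are made up, and sc is an existing SparkContext), the flag is supplied when the job group is set and takes effect when the group is later cancelled:

// Tag all jobs submitted from this thread and opt in to Thread.interrupt() on cancellation.
sc.setJobGroup("etl-2014-04-24", "cancellable ETL load", interruptOnCancel = true)

// From another thread (a UI handler, an admin endpoint, ...): cancel the whole group.
// Running tasks get their executor threads interrupted in addition to the usual killed flag.
new Thread(new Runnable {
  def run() {
    Thread.sleep(5000)                    // let the job get going first
    sc.cancelJobGroup("etl-2014-04-24")
  }
}).start()

// This job is killed partway through; the driver sees the SparkException("Job ... cancelled")
// raised in DAGScheduler further down in this diff.
sc.textFile("hdfs:///tmp/example-input").count()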
23 changes: 23 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskKilledException.scala
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

/**
* Exception for a task getting killed.
*/
private[spark] class TaskKilledException extends RuntimeException
@@ -69,12 +69,12 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}

case KillTask(taskId, _) =>
case KillTask(taskId, _, interruptThread) =>
if (executor == null) {
logError("Received KillTask command but executor was null")
System.exit(1)
} else {
executor.killTask(taskId)
executor.killTask(taskId, interruptThread)
}

case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
16 changes: 7 additions & 9 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -138,10 +138,10 @@ private[spark] class Executor(
threadPool.execute(tr)
}

def killTask(taskId: Long) {
def killTask(taskId: Long, interruptThread: Boolean) {
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill()
tr.kill(interruptThread)
}
}

@@ -163,16 +163,14 @@ private[spark] class Executor(
class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {

object TaskKilledException extends Exception

@volatile private var killed = false
@volatile private var task: Task[Any] = _

def kill() {
def kill(interruptThread: Boolean) {
logInfo("Executor is trying to kill task " + taskId)
killed = true
if (task != null) {
task.kill()
task.kill(interruptThread)
}
}

@@ -202,7 +200,7 @@ private[spark] class Executor(
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
throw TaskKilledException
throw new TaskKilledException
}

attemptedTask = Some(task)
@@ -216,7 +214,7 @@ private[spark] class Executor(

// If the task has been killed, let's fail it.
if (task.killed) {
throw TaskKilledException
throw new TaskKilledException
}

for (m <- task.metrics) {
@@ -254,7 +252,7 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}

case TaskKilledException => {
case _: TaskKilledException | _: InterruptedException if task.killed => {
logInfo("Executor killed task " + taskId)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}
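One subtlety in the catch clause above: the "if task.killed" guard applies to both alternatives, so an InterruptedException (or even a TaskKilledException) is only reported as KILLED when the kill flag was actually set, and a stray Thread.interrupt() from elsewhere still surfaces as an ordinary failure. A rough equivalent of that decision, written as a hypothetical helper that is not in the patch:

package org.apache.spark.executor

private[executor] object KillReporting {
  // Mirrors the TaskRunner guard above: an exception counts as a kill only if the kill flag is set.
  def reportAsKilled(t: Throwable, killFlagSet: Boolean): Boolean = t match {
    case (_: org.apache.spark.TaskKilledException | _: InterruptedException) if killFlagSet => true
    case _ => false  // anything else is serialized as an ExceptionFailure and the task is marked FAILED
  }
}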
@@ -78,7 +78,8 @@ private[spark] class MesosExecutorBackend
if (executor == null) {
logError("Received KillTask but executor was null")
} else {
executor.killTask(t.getValue.toLong)
// TODO: Determine the 'interruptOnCancel' property set for the given job.
executor.killTask(t.getValue.toLong, interruptThread = false)
}
}

@@ -962,10 +962,13 @@ class DAGScheduler(
if (!jobIdToStageIds.contains(jobId)) {
logDebug("Trying to cancel unregistered job " + jobId)
} else {
val job = idToActiveJob(jobId)
val independentStages = removeJobAndIndependentStages(jobId)
independentStages.foreach { taskSched.cancelTasks }
val shouldInterruptThread =
if (job.properties == null) false
else job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false").toBoolean
independentStages.foreach { taskSched.cancelTasks(_, shouldInterruptThread) }
val error = new SparkException("Job %d cancelled".format(jobId))
val job = idToActiveJob(jobId)
job.listener.jobFailed(error)
jobIdToStageIds -= jobId
activeJobs -= job
12 changes: 10 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -47,8 +47,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex

final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
taskThread = Thread.currentThread()
if (_killed) {
kill()
kill(interruptThread = false)
}
runTask(context)
}
@@ -65,6 +66,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
// Task context, to be initialized in run().
@transient protected var context: TaskContext = _

// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _

// A flag to indicate whether the task is killed. This is used in case context is not yet
// initialized when kill() is invoked.
@volatile @transient private var _killed = false
@@ -78,12 +82,16 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
* Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
* code and user code to properly handle the flag. This function should be idempotent so it can
* be called multiple times.
* If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread.
*/
def kill() {
def kill(interruptThread: Boolean) {
_killed = true
if (context != null) {
context.interrupted = true
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()
}
}
}

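Putting the Task changes together with CacheManager and InterruptibleIterator above, kill(interruptThread = true) delivers two signals to a running task. A sketch of how a record-processing loop observes the first one (illustrative only; the process method and its types are hypothetical):

import org.apache.spark.{InterruptibleIterator, TaskContext}

// Signal 1: context.interrupted = true -- the next hasNext call on any InterruptibleIterator
//           wrapped around the task's input throws TaskKilledException.
// Signal 2: taskThread.interrupt()     -- a task blocked in wait()/sleep()/interruptible I/O
//           gets an InterruptedException immediately instead of finishing the blocking call.
def process(records: Iterator[Array[Byte]], context: TaskContext): Long = {
  val interruptible = new InterruptibleIterator(context, records)
  var n = 0L
  while (interruptible.hasNext) {  // throws once the task has been killed
    interruptible.next()
    n += 1
  }
  n
}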
@@ -45,7 +45,7 @@ private[spark] trait TaskScheduler {
def submitTasks(taskSet: TaskSet): Unit

// Cancel a stage.
def cancelTasks(stageId: Int)
def cancelTasks(stageId: Int, interruptThread: Boolean)

// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -31,8 +31,8 @@ private[spark] class TaskSet(
val properties: Properties) {
val id: String = stageId + "." + attempt

def kill() {
tasks.foreach(_.kill())
def kill(interruptThread: Boolean) {
tasks.foreach(_.kill(interruptThread))
}

override def toString: String = "TaskSet " + id
@@ -165,7 +165,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.reviveOffers()
}

override def cancelTasks(stageId: Int): Unit = synchronized {
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
// There are two possible cases here:
@@ -178,7 +178,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (taskIds.size > 0) {
taskIds.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId)
backend.killTask(tid, execId, interruptThread)
}
}
logInfo("Stage %d was cancelled".format(stageId))
@@ -31,7 +31,8 @@ private[spark] object CoarseGrainedClusterMessages {
// Driver to executors
case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage

case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage
case class KillTask(taskId: Long, executor: String, interruptThread: Boolean)
extends CoarseGrainedClusterMessage

case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
extends CoarseGrainedClusterMessage
@@ -100,8 +100,8 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
case ReviveOffers =>
makeOffers()

case KillTask(taskId, executorId) =>
executorActor(executorId) ! KillTask(taskId, executorId)
case KillTask(taskId, executorId, interruptThread) =>
executorActor(executorId) ! KillTask(taskId, executorId, interruptThread)

case StopDriver =>
sender ! true
@@ -215,8 +215,8 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
driverActor ! ReviveOffers
}

override def killTask(taskId: Long, executorId: String) {
driverActor ! KillTask(taskId, executorId)
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
driverActor ! KillTask(taskId, executorId, interruptThread)
}

override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
@@ -30,7 +30,8 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int

def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException
def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit =
throw new UnsupportedOperationException

// Memory used by each executor (in megabytes)
protected val executorMemory: Int = SparkContext.executorMemoryRequested
@@ -44,7 +44,7 @@ private[local]
case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)

private[local]
case class KillTask(taskId: Long)
case class KillTask(taskId: Long, interruptThread: Boolean)

private[spark]
class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
@@ -62,8 +62,8 @@ class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
launchTask(localScheduler.resourceOffer(freeCores))
}

case KillTask(taskId) =>
executor.killTask(taskId)
case KillTask(taskId, interruptThread) =>
executor.killTask(taskId, interruptThread)
}

private def launchTask(tasks: Seq[TaskDescription]) {
@@ -128,7 +128,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}

override def cancelTasks(stageId: Int): Unit = synchronized {
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId))
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
@@ -141,7 +141,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
val taskIds = taskSetTaskIds(tsm.taskSet.id)
if (taskIds.size > 0) {
taskIds.foreach { tid =>
localActor ! KillTask(tid)
localActor ! KillTask(tid, interruptThread)
}
}
logInfo("Stage %d was cancelled".format(stageId))