SPARK-1601 & SPARK-1602: two bug fixes related to cancellation
This should go into 1.0 since the bug returns wrong data when it is hit (which is pretty likely if cancellation is used). Test case attached.

1. Do not put partially executed partitions into the cache when a task is killed.

2. The iterator returned by CacheManager#getOrCompute was not an InterruptibleIterator, which made jobs uninterruptible.

Thanks @aarondav and @ahirreddy for reporting and helping debug.
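
To make the failure mode concrete before the diff, here is a self-contained Scala analogue (not Spark code; the AtomicBoolean stands in for TaskContext.interrupted and the class names are invented). With the old hasNext contract a kill simply makes the iterator look exhausted, so whoever drains it, for example the code that populates the block cache, stores a truncated partition as if it were complete; with the new contract the kill surfaces as an exception instead.

import java.util.concurrent.atomic.AtomicBoolean

// Old behaviour: a kill makes the iterator look exhausted.
class OldStyleIterator[T](killed: AtomicBoolean, delegate: Iterator[T]) extends Iterator[T] {
  def hasNext: Boolean = !killed.get && delegate.hasNext
  def next(): T = delegate.next()
}

// New behaviour: a kill surfaces as an exception, so the consumer aborts instead of
// treating the partial output as a complete partition.
class NewStyleIterator[T](killed: AtomicBoolean, delegate: Iterator[T]) extends Iterator[T] {
  def hasNext: Boolean =
    if (killed.get) throw new RuntimeException("task killed") else delegate.hasNext
  def next(): T = delegate.next()
}

object TruncationDemo extends App {
  val killed = new AtomicBoolean(false)
  // The underlying computation flips the kill flag partway through the partition.
  def data = (1 to 10).iterator.map { x => if (x == 5) killed.set(true); x }

  // Old style: draining the iterator "succeeds" with only the first five elements,
  // which is exactly the kind of partial partition that used to end up in the cache.
  println(new OldStyleIterator(killed, data).toList)   // List(1, 2, 3, 4, 5)

  // New style: the same kill is an error, not a short list.
  // new NewStyleIterator(killed, data).toList          // throws RuntimeException("task killed")
}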

Author: Reynold Xin <rxin@apache.org>

Closes apache#521 from rxin/kill and squashes the following commits:

401033f [Reynold Xin] Merge branch 'master' of https://git-wip-us.apache.org/repos/asf/spark into kill
7a7bdd2 [Reynold Xin] Add a new line in the end of JobCancellationSuite.scala.
35cd9f7 [Reynold Xin] Fixed a bug that partially executed partitions can be put into cache (in task killing).
rxin authored and pdeyhim committed Jun 25, 2014
1 parent 539b2ff commit f957cb5
Showing 5 changed files with 86 additions and 15 deletions.
15 changes: 11 additions & 4 deletions core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -47,7 +47,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (loading.contains(key)) {
logInfo("Another thread is loading %s, waiting for it to finish...".format(key))
while (loading.contains(key)) {
try {loading.wait()} catch {case _ : Throwable =>}
try {
loading.wait()
} catch {
case e: Exception =>
logWarning(s"Got an exception while waiting for another thread to load $key", e)
}
}
logInfo("Finished waiting for %s".format(key))
/* See whether someone else has successfully loaded it. The main way this would fail
@@ -72,7 +77,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val computedValues = rdd.computeOrReadCheckpoint(split, context)

// Persist the result, so long as the task is not running locally
if (context.runningLocally) { return computedValues }
if (context.runningLocally) {
return computedValues
}

// Keep track of blocks with updated statuses
var updatedBlocks = Seq[(BlockId, BlockStatus)]()
@@ -88,7 +95,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
blockManager.get(key) match {
case Some(values) =>
new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Failure to store %s".format(key))
throw new Exception("Block manager failed to return persisted valued")
@@ -107,7 +114,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val metrics = context.taskMetrics
metrics.updatedBlocks = Some(updatedBlocks)

returnValue
new InterruptibleIterator(context, returnValue)

} finally {
loading.synchronized {
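
A hypothetical consumer-side sketch of what the relocated wrapping buys (the helper below is invented for illustration and is not a Spark API): because the value returned from getOrCompute is now always an InterruptibleIterator, code iterating over a cached partition observes a kill as a TaskKilledException on its next hasNext call, and the exception propagates to the TaskRunner, which reports the task as killed (see the Executor.scala change below) instead of letting the partition run to completion.

// Hypothetical helper, not Spark code: drain whatever CacheManager#getOrCompute returned.
def consumePartition[T](iter: Iterator[T])(process: T => Unit): Unit = {
  // With the InterruptibleIterator wrap at the single return point, hasNext throws
  // TaskKilledException once the task is marked killed, so this loop aborts early;
  // the exception is deliberately not caught here and bubbles up to the TaskRunner.
  while (iter.hasNext) {
    process(iter.next())
  }
}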
12 changes: 11 additions & 1 deletion core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -24,7 +24,17 @@ package org.apache.spark
private[spark] class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
extends Iterator[T] {

def hasNext: Boolean = !context.interrupted && delegate.hasNext
def hasNext: Boolean = {
// TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt
// is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
// (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
// introduces an expensive read fence.
if (context.interrupted) {
throw new TaskKilledException
} else {
delegate.hasNext
}
}

def next(): T = delegate.next()
}
23 changes: 23 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskKilledException.scala
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

/**
* Exception for a task getting killed.
*/
private[spark] class TaskKilledException extends RuntimeException
8 changes: 3 additions & 5 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -161,8 +161,6 @@ private[spark] class Executor(
class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {

object TaskKilledException extends Exception

@volatile private var killed = false
@volatile private var task: Task[Any] = _

@@ -200,7 +198,7 @@ private[spark] class Executor(
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
throw TaskKilledException
throw new TaskKilledException
}

attemptedTask = Some(task)
@@ -214,7 +212,7 @@

// If the task has been killed, let's fail it.
if (task.killed) {
throw TaskKilledException
throw new TaskKilledException
}

val resultSer = SparkEnv.get.serializer.newInstance()
@@ -257,7 +255,7 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}

case TaskKilledException | _: InterruptedException if task.killed => {
case _: TaskKilledException | _: InterruptedException if task.killed => {
logInfo("Executor killed task " + taskId)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}
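
One subtle consequence of turning the executor-local `object TaskKilledException extends Exception` into the shared class above is that the catch arm can no longer match the exception by object equality; it has to match by type, which is why `case TaskKilledException` became `case _: TaskKilledException`. A standalone illustration of the difference (names invented for the example):

object MatchStyleDemo {
  // Old style: a singleton exception, matched by equality against that one instance.
  object SingletonKill extends Exception
  // New style: an ordinary class, thrown with `new`, matched by type.
  class InstanceKill extends RuntimeException

  def describe(t: Throwable): String = t match {
    case SingletonKill   => "matched the one-and-only singleton instance"
    case _: InstanceKill => "matched any instance of the class"
    case _               => "unmatched"
  }

  // describe(SingletonKill)    == "matched the one-and-only singleton instance"
  // describe(new InstanceKill) == "matched any instance of the class"
  // An arm written `case SingletonKill` can only ever match that single object, so
  // once TaskKilledException became a class (thrown with `new` from both Executor
  // and InterruptibleIterator) the arm had to become `case _: TaskKilledException`.
}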
43 changes: 38 additions & 5 deletions core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -84,6 +84,35 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(sc.parallelize(1 to 10, 2).count === 10)
}

test("do not put partially executed partitions into cache") {
// In this test case, we create a scenario in which a partition is only partially executed,
// and make sure CacheManager does not put that partially executed partition into the
// BlockManager.
import JobCancellationSuite._
sc = new SparkContext("local", "test")

// Run from 1 to 10, and then block and wait for the task to be killed.
val rdd = sc.parallelize(1 to 1000, 2).map { x =>
if (x > 10) {
taskStartedSemaphore.release()
taskCancelledSemaphore.acquire()
}
x
}.cache()

val rdd1 = rdd.map(x => x)

future {
taskStartedSemaphore.acquire()
sc.cancelAllJobs()
taskCancelledSemaphore.release(100000)
}

intercept[SparkException] { rdd1.count() }
// If the partial block is put into cache, rdd.count() would return a number less than 1000.
assert(rdd.count() === 1000)
}

test("job group") {
sc = new SparkContext("local[2]", "test")

@@ -114,7 +143,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(jobB.get() === 100)
}


test("job group with interruption") {
sc = new SparkContext("local[2]", "test")

@@ -145,15 +173,14 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(jobB.get() === 100)
}

/*
test("two jobs sharing the same stage") {
ignore("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
// sem2: make sure the first stage is not finished until cancel is issued
val sem1 = new Semaphore(0)
val sem2 = new Semaphore(0)

sc = new SparkContext("local[2]", "test")
sc.dagScheduler.addSparkListener(new SparkListener {
sc.addSparkListener(new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart) {
sem1.release()
}
@@ -179,7 +206,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
intercept[SparkException] { f1.get() }
intercept[SparkException] { f2.get() }
}
*/

def testCount() {
// Cancel before launching any tasks
{
Expand Down Expand Up @@ -238,3 +265,9 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
}
}
}


object JobCancellationSuite {
val taskStartedSemaphore = new Semaphore(0)
val taskCancelledSemaphore = new Semaphore(0)
}
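
The two semaphores above coordinate the test's handshake: the mapped task announces that it is past the first few elements and then blocks until cancellation has been issued, which guarantees the kill lands while the partition is half-done. A standalone sketch of that handshake outside Spark (plain Scala concurrency; the Future plays the role of the job thread, and the names are only for illustration):

import java.util.concurrent.Semaphore
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object HandshakeSketch extends App {
  val taskStartedSemaphore   = new Semaphore(0)
  val taskCancelledSemaphore = new Semaphore(0)

  // The "task": after the first few elements it signals progress, then blocks,
  // so any cancellation issued by the "driver" lands mid-partition.
  val job = Future {
    (1 to 1000).map { x =>
      if (x > 10) {
        taskStartedSemaphore.release()
        taskCancelledSemaphore.acquire()
      }
      x
    }
  }

  // The "driver": wait until the task is mid-partition, cancel, then release far
  // more permits than needed so every blocked element can proceed afterwards.
  taskStartedSemaphore.acquire()
  // ... this is the point where the real test calls sc.cancelAllJobs() ...
  taskCancelledSemaphore.release(100000)

  println(Await.result(job, 30.seconds).length)   // 1000 once everything is unblocked
}

The assertion on rdd.count() at the end of the test is what actually detects the bug: if CacheManager had cached the truncated partition from the cancelled run, the second count would come back smaller than 1000.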
