Skip to content

Commit

Permalink
Added more test cases covering cleanup when fault happens in ShuffleBlockFetcherIteratorSuite
Browse files Browse the repository at this point in the history
  • Loading branch information
rxin authored and aarondav committed Oct 10, 2014
1 parent 1be4e8e commit cb589ec
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,17 @@ final class ShuffleBlockFetcherIterator(
private[this] def cleanup() {
isZombie = true
// Release the current buffer if necessary
if (currentResult != null && currentResult.buf != null) {
if (currentResult != null && !currentResult.failed) {
currentResult.buf.release()
}

// Release buffers in the results queue
val iter = results.iterator()
while (iter.hasNext) {
val result = iter.next()
result.buf.release()
if (!result.failed) {
result.buf.release()
}
}
}

Expand Down Expand Up @@ -313,7 +315,7 @@ object ShuffleBlockFetcherIterator {
* @param blocks Sequence of tuples, where the first element of each tuple is the block id,
* and the second element is the estimated size, used to calculate bytesInFlight.
*/
class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}

Expand All @@ -324,7 +326,8 @@ object ShuffleBlockFetcherIterator {
* Note that this is NOT the exact bytes. -1 if failure is present.
* @param buf [[ManagedBuffer]] for the content. null if the fetch failed.
*/
class FetchResult(val blockId: BlockId, val size: Long, val buf: ManagedBuffer) {
case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) {
def failed: Boolean = size == -1
if (failed) assert(buf == null) else assert(buf != null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

package org.apache.spark.storage

import java.util.concurrent.Semaphore

import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global

import org.mockito.Mockito._
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.invocation.InvocationOnMock
Expand All @@ -30,80 +35,200 @@ import org.apache.spark.serializer.TestSerializer


class ShuffleBlockFetcherIteratorSuite extends FunSuite {
// Some of the tests are quite tricky because we are testing the cleanup behavior
// in the presence of faults.

val conf = new SparkConf
/** Creates a mock [[BlockTransferService]] that returns data from the given map. */
private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val blocks = invocation.getArguments()(2).asInstanceOf[Seq[String]]
val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]

test("handle successful local reads") {
val buf = mock(classOf[ManagedBuffer])
val blockManager = mock(classOf[BlockManager])
doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
for (blockId <- blocks) {
if (data.contains(BlockId(blockId))) {
listener.onBlockFetchSuccess(blockId, data(BlockId(blockId)))
} else {
listener.onBlockFetchFailure(new BlockNotFoundException(blockId))
}
}
}
})
transfer
}

val blockIds = Array[BlockId](
ShuffleBlockId(0, 0, 0),
ShuffleBlockId(0, 1, 0),
ShuffleBlockId(0, 2, 0),
ShuffleBlockId(0, 3, 0),
ShuffleBlockId(0, 4, 0))
private val conf = new SparkConf

// All blocks should be fetched successfully
blockIds.foreach { blockId =>
test("successful 3 local reads + 2 remote reads") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
doReturn(localBmId).when(blockManager).blockManagerId

// Make sure blockManager.getBlockData would return the blocks
val localBlocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]))
localBlocks.foreach { case (blockId, buf) =>
doReturn(buf).when(blockManager).getBlockData(meq(blockId.toString))
}

val bmId = BlockManagerId("test-client", "test-client", 1)
// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val remoteBlocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]),
ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer])
)

val transfer = createMockTransfer(remoteBlocks)

val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(bmId, blockIds.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)
(localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq),
(remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)
)

val iterator = new ShuffleBlockFetcherIterator(
new TaskContext(0, 0, 0),
mock(classOf[BlockTransferService]),
transfer,
blockManager,
blocksByAddress,
new TestSerializer,
48 * 1024 * 1024)

// Local blocks are fetched immediately.
verify(blockManager, times(5)).getBlockData(any())
// 3 local blocks fetched in initialization
verify(blockManager, times(3)).getBlockData(any())

for (i <- 0 until 5) {
assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
assert(iterator.next()._2.isDefined,
val (blockId, subIterator) = iterator.next()
assert(subIterator.isDefined,
s"iterator should have 5 elements defined but actually has $i elements")

// Make sure we release the buffer once the iterator is exhausted.
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
verify(mockBuf, times(0)).release()
subIterator.get.foreach(_ => Unit) // exhaust the iterator
verify(mockBuf, times(1)).release()
}
// No more fetching of local blocks.
verify(blockManager, times(5)).getBlockData(any())

// 3 local blocks, and 2 remote blocks
// (but from the same block manager so one call to fetchBlocks)
verify(blockManager, times(3)).getBlockData(any())
verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any())
}

test("handle remote fetch failures in BlockTransferService") {
test("release current unexhausted buffer in case the task completes early") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
doReturn(localBmId).when(blockManager).blockManagerId

// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
)

// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)

val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
listener.onBlockFetchFailure(new Exception("blah"))
future {
// Return the first two blocks, and wait till task completion before returning the 3rd one
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
sem.acquire()
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
}
}
})

val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))

val taskContext = new TaskContext(0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
blockManager,
blocksByAddress,
new TestSerializer,
48 * 1024 * 1024)

// Exhaust the first block, and then it should be released.
iterator.next()._2.get.foreach(_ => Unit)
verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()

// Get the 2nd block but do not exhaust the iterator
val subIter = iterator.next()._2.get

// Complete the task; then the 2nd block buffer should be released
verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
taskContext.markTaskCompleted()
verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release()

// The 3rd block should not be retained because the iterator is already in zombie state
sem.release()
verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).retain()
verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
}

test("fail all blocks if any of the remote request fails") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
doReturn(localBmId).when(blockManager).blockManagerId

// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
)

when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1))
// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)

val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
future {
// Return the first block, and then fail.
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchFailure(new BlockNotFoundException("blah"))
sem.release()
}
}
})

val blId1 = ShuffleBlockId(0, 0, 0)
val blId2 = ShuffleBlockId(0, 1, 0)
val bmId = BlockManagerId("test-server", "test-server", 1)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(bmId, Seq((blId1, 1L), (blId2, 1L))))
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))

val taskContext = new TaskContext(0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
new TaskContext(0, 0, 0),
taskContext,
transfer,
blockManager,
blocksByAddress,
new TestSerializer,
48 * 1024 * 1024)

iterator.foreach { case (_, iterOption) =>
assert(!iterOption.isDefined)
}
// Continue only after the mock calls onBlockFetchFailure
sem.acquire()

// The first block should be defined, and the last two are not defined (due to failure)
assert(iterator.next()._2.isDefined === true)
assert(iterator.next()._2.isDefined === false)
assert(iterator.next()._2.isDefined === false)
}
}

0 comments on commit cb589ec

Please sign in to comment.