diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 0635b98742096..aefb2f5685537 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -26,7 +26,8 @@ import org.apache.spark._
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
 
-private[hash] object BlockStoreShuffleFetcher extends Logging {
+private[hash] class BlockStoreShuffleFetcher extends Logging {
+
   def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index ca6eddf8d5c12..b868f32f5cce1 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.shuffle.hash
 
+import org.apache.spark.storage.BlockManager
 import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
@@ -27,18 +28,19 @@ private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
-    context: TaskContext)
+    context: TaskContext,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    blockStoreShuffleFetcher: BlockStoreShuffleFetcher = new BlockStoreShuffleFetcher)
   extends ShuffleReader[K, C] {
 
   require(endPartition == startPartition + 1,
     "Hash shuffle currently only supports fetching one partition")
 
   private val dep = handle.dependency
-  private val blockManager = SparkEnv.get.blockManager
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+    val blockStreams = blockStoreShuffleFetcher.fetchBlockStreams(
       handle.shuffleId, startPartition, context)
 
     // Wrap the streams for compression based on configuration
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index 491dc3659e184..53b2b89a5e641 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -17,16 +17,22 @@
 
 package org.apache.spark.shuffle.hash
 
-import java.io.{File, FileWriter}
+import java.io._
+import java.nio.ByteBuffer
 
 import scala.language.reflectiveCalls
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
-import org.apache.spark.executor.ShuffleWriteMetrics
+import org.mockito.Matchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark._
+import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics}
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.shuffle.FileShuffleBlockResolver
-import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
+import org.apache.spark.serializer._
+import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver}
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment}
 
 class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
   private val testConf = new SparkConf(false)
@@ -107,4 +113,100 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
     for (i <- 0 until numBytes) writer.write(i)
     writer.close()
   }
+
+  test("HashShuffleReader.read() releases resources and tracks metrics") {
+    val shuffleId = 1
+    val numMaps = 2
+    val numKeyValuePairs = 10
+
+    val mockContext = mock(classOf[TaskContext])
+
+    val mockTaskMetrics = mock(classOf[TaskMetrics])
+    val mockReadMetrics = mock(classOf[ShuffleReadMetrics])
+    when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics)
+    when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics)
+
+    val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher])
+
+    val mockDep = mock(classOf[ShuffleDependency[_, _, _]])
+    when(mockDep.keyOrdering).thenReturn(None)
+    when(mockDep.aggregator).thenReturn(None)
+    when(mockDep.serializer).thenReturn(Some(new Serializer {
+      override def newInstance(): SerializerInstance = new SerializerInstance {
+
+        override def deserializeStream(s: InputStream): DeserializationStream =
+          new DeserializationStream {
+            override def readObject[T: ClassManifest](): T = null.asInstanceOf[T]
+
+            override def close(): Unit = s.close()
+
+            private val values = {
+              for (i <- 0 to numKeyValuePairs * 2) yield i
+            }.iterator
+
+            private def getValueOrEOF(): Int = {
+              if (values.hasNext) {
+                values.next()
+              } else {
+                throw new EOFException("End of the file: mock deserializeStream")
+              }
+            }
+
+            // NOTE: the readKey and readValue methods are called by asKeyValueIterator()
+            // which is wrapped in a NextIterator
+            override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
+
+            override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
+          }
+
+        override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T =
+          null.asInstanceOf[T]
+
+        override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0)
+
+        override def serializeStream(s: OutputStream): SerializationStream =
+          null.asInstanceOf[SerializationStream]
+
+        override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T]
+      }
+    }))
+
+    val mockBlockManager = {
+      // Create a block manager that isn't configured for compression, just returns input stream
+      val blockManager = mock(classOf[BlockManager])
+      when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]()))
+        .thenAnswer(new Answer[InputStream] {
+          override def answer(invocation: InvocationOnMock): InputStream = {
+            val blockId = invocation.getArguments()(0).asInstanceOf[BlockId]
+            val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream]
+            inputStream
+          }
+        })
+      blockManager
+    }
+
+    val mockInputStream = mock(classOf[InputStream])
+    when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]()))
+      .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream)))
+
+    val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep)
+
+    val reader = new HashShuffleReader(shuffleHandle, 0, 1,
+      mockContext, mockBlockManager, mockShuffleFetcher)
+
+    val values = reader.read()
+    // Verify that we're reading the correct values
+    var numValuesRead = 0
+    for (((key: Int, value: Int), i) <- values.zipWithIndex) {
+      assert(key == i * 2)
+      assert(value == i * 2 + 1)
+      numValuesRead += 1
+    }
+    // Verify that we read the correct number of values
+    assert(numKeyValuePairs == numValuesRead)
+    // Verify that our input stream was closed
+    verify(mockInputStream, times(1)).close()
+    // Verify that we collected metrics for each key/value pair
+    verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1)
+  }
 }