diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 049ba326b6921..d12dd8a894bb2 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -523,7 +523,7 @@ class SparkContext(config: SparkConf) extends Logging {
     val job = new NewHadoopJob(hadoopConfiguration)
     NewFileInputFormat.addInputPath(job, new Path(path))
     val updateConf = job.getConfiguration
-    new RawFileRDD(
+    new BinaryFileRDD(
       this,
       classOf[ByteInputFormat],
       classOf[String],
@@ -548,7 +548,7 @@ class SparkContext(config: SparkConf) extends Logging {
     val job = new NewHadoopJob(hadoopConfiguration)
     NewFileInputFormat.addInputPath(job, new Path(path))
     val updateConf = job.getConfiguration
-    new RawFileRDD(
+    new BinaryFileRDD(
       this,
       classOf[StreamInputFormat],
       classOf[String],
@@ -565,9 +565,9 @@ class SparkContext(config: SparkConf) extends Logging {
    * @param path Directory to the input data files
    * @return An RDD of data with values, RDD[(Array[Byte])]
    */
-  def fixedLengthBinaryFiles(path: String): RDD[Array[Byte]] = {
-    val lines = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path)
-    val data = lines.map{ case (k, v) => v.getBytes}
+  def binaryRecords(path: String): RDD[Array[Byte]] = {
+    val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path)
+    val data = br.map{ case (k, v) => v.getBytes}
     data
   }
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index d366befa6240c..a407263aa8dd0 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -289,7 +289,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * @param minPartitions A suggestion value of the minimal splitting number for input data.
    */
   def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
-      JavaPairRDD[String,Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))
+      JavaPairRDD[String, Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))
 
   /**
    * Load data from a flat binary file, assuming each record is a set of numbers
@@ -299,8 +299,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * @param path Directory to the input data files
    * @return An RDD of data with values, JavaRDD[(Array[Byte])]
    */
-  def fixedLengthBinaryFiles(path: String): JavaRDD[Array[Byte]] = {
-    new JavaRDD(sc.fixedLengthBinaryFiles(path))
+  def binaryRecords(path: String): JavaRDD[Array[Byte]] = {
+    new JavaRDD(sc.binaryRecords(path))
   }
 
   /** Get an RDD for a Hadoop SequenceFile with given key and value types.
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
index a59317afc0e0f..309fd578b532f 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAt
  * a parameter recordLength in the Hadoop configuration.
  */
-object FixedLengthBinaryInputFormat {
+private[spark] object FixedLengthBinaryInputFormat {
 
   /**
    * This function retrieves the recordLength by checking the configuration parameter
@@ -42,7 +42,7 @@ object FixedLengthBinaryInputFormat {
 
 }
 
-class FixedLengthBinaryInputFormat extends FileInputFormat[LongWritable, BytesWritable] {
+private[spark] class FixedLengthBinaryInputFormat extends FileInputFormat[LongWritable, BytesWritable] {
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
index e7ea91809aac5..a292a1e41d912 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -37,7 +37,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileSplit
 * VALUE = the record itself (BytesWritable)
 *
 */
-class FixedLengthBinaryRecordReader extends RecordReader[LongWritable, BytesWritable] {
+private[spark] class FixedLengthBinaryRecordReader extends RecordReader[LongWritable, BytesWritable] {
 
   override def close() {
     if (fileInputStream != null) {
diff --git a/core/src/main/scala/org/apache/spark/input/RawFileInput.scala b/core/src/main/scala/org/apache/spark/input/RawFileInput.scala
index eca82184f587f..fa79ed7af3513 100644
--- a/core/src/main/scala/org/apache/spark/input/RawFileInput.scala
+++ b/core/src/main/scala/org/apache/spark/input/RawFileInput.scala
@@ -73,8 +73,18 @@ abstract class StreamBasedRecordReader[T](
   private val key = path.toString
   private var value: T = null.asInstanceOf[T]
 
+  // the file to be read when nextkeyvalue is called
+  private lazy val fileIn: FSDataInputStream = fs.open(path)
+
   override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
-  override def close() = {}
+  override def close() = {
+    // make sure the file is closed
+    try {
+      fileIn.close()
+    } catch {
+      case ioe: java.io.IOException => // do nothing
+    }
+  }
 
   override def getProgress = if (processed) 1.0f else 0.0f
 
@@ -82,9 +92,10 @@ abstract class StreamBasedRecordReader[T](
 
   override def getCurrentValue = value
 
+
   override def nextKeyValue = {
     if (!processed) {
-      val fileIn: FSDataInputStream = fs.open(path)
+
       value = parseStream(fileIn)
       processed = true
       true
@@ -104,7 +115,7 @@ abstract class StreamBasedRecordReader[T](
 /**
  * Reads the record in directly as a stream for other objects to manipulate and handle
  */
-class StreamRecordReader(
+private[spark] class StreamRecordReader(
     split: CombineFileSplit,
     context: TaskAttemptContext,
     index: Integer)
@@ -117,7 +128,7 @@ class StreamRecordReader(
  * A class for extracting the information from the file using the
  * BinaryRecordReader (as Byte array)
  */
-class StreamInputFormat extends StreamFileInputFormat[DataInputStream] {
+private[spark] class StreamInputFormat extends StreamFileInputFormat[DataInputStream] {
 
   override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)= {
     new CombineFileRecordReader[String,DataInputStream](
@@ -146,7 +157,7 @@ abstract class BinaryRecordReader[T](
 
 }
 
-class ByteRecordReader(
+private[spark] class ByteRecordReader(
     split: CombineFileSplit,
     context: TaskAttemptContext,
     index: Integer)
@@ -158,7 +169,7 @@ class ByteRecordReader(
 /**
  * A class for reading the file using the BinaryRecordReader (as Byte array)
  */
-class ByteInputFormat extends StreamFileInputFormat[Array[Byte]] {
+private[spark] class ByteInputFormat extends StreamFileInputFormat[Array[Byte]] {
 
   override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)= {
     new CombineFileRecordReader[String,Array[Byte]](
diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
index 7c31e2b50ab75..c7dc50820d59b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
@@ -23,10 +23,10 @@ package org.apache.spark.rdd
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce._
-import org.apache.spark.{Partition, SparkContext}
+import org.apache.spark.{InterruptibleIterator, TaskContext, Partition, SparkContext}
 import org.apache.spark.input.StreamFileInputFormat
 
-private[spark] class RawFileRDD[T](
+private[spark] class BinaryFileRDD[T](
     sc : SparkContext,
     inputFormatClass: Class[_ <: StreamFileInputFormat[T]],
     keyClass: Class[String],
@@ -35,6 +35,7 @@ private[spark] class RawFileRDD[T](
     minPartitions: Int)
   extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) {
 
+
   override def getPartitions: Array[Partition] = {
     val inputFormat = inputFormatClass.newInstance
     inputFormat match {
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e8bd65f8e4507..5bf2dbccf28f5 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -836,6 +836,28 @@ public Tuple2 call(Tuple2 pair) {
     Assert.assertEquals(pairs, readRDD.collect());
   }
 
+  @Test
+  public void binaryFiles() throws Exception {
+    // Reusing the wholeText files example
+    byte[] content1 = "spark is easy to use.\n".getBytes("utf-8");
+    byte[] content2 = "spark is also easy to use.\n".getBytes("utf-8");
+
+    String tempDirName = tempDir.getAbsolutePath();
+    File file1 = new File(tempDirName + "/part-00000");
+    Files.write(content1, file1);
+    File file2 = new File(tempDirName + "/part-00001");
+    Files.write(content2, file2);
+
+    JavaPairRDD<String, byte[]> readRDD = sc.binaryFiles(tempDirName, 3);
+    List<Tuple2<String, byte[]>> result = readRDD.collect();
+    for (Tuple2<String, byte[]> res : result) {
+      // the keys are fully qualified paths, so check which file each entry refers to
+      if (res._1().contains(file1.toString()))
+        Assert.assertArrayEquals(content1, res._2());
+      else
+        Assert.assertArrayEquals(content2, res._2());
+    }
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void writeWithNewAPIHadoopFile() {
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index c70e22cf09433..37db288430692 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -224,6 +224,60 @@ class FileSuite extends FunSuite with LocalSparkContext {
     assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
   }
 
+  test("byte stream input") {
+    sc = new SparkContext("local", "test")
+    val outFile = new File(tempDir, "part-00000.bin")
+    val outFileName = outFile.getAbsolutePath
+
+    // create file
+    val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+    val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+    // write data to file
+    val file = new java.io.FileOutputStream(outFile)
+    val channel = file.getChannel
+    channel.write(bbuf)
+    channel.close()
+    file.close()
+
+    val inRdd = sc.binaryFiles(outFileName)
+    val (infile: String, indata: Array[Byte]) = inRdd.first
+
+    // Try reading the output back as a binary file; the key is a fully
+    // qualified path, so only check that it contains the file name
+    assert(infile.contains(outFileName))
+    assert(indata === testOutput)
+  }
+
+  test("fixed length byte stream input") {
+    // a fixed record length of 6 bytes
+    sc = new SparkContext("local", "test")
+
+    val outFile = new File(tempDir, "part-00000.bin")
+    val outFileName = outFile.getAbsolutePath
+
+    // create file
+    val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+    val testOutputCopies = 10
+    // write data to file
+    val file = new java.io.FileOutputStream(outFile)
+    val channel = file.getChannel
+    for (i <- 1 to testOutputCopies) {
+      // rewrap the buffer so its position is reset before each write
+      val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+      channel.write(bbuf)
+    }
+    channel.close()
+    file.close()
+    sc.hadoopConfiguration.setInt("recordLength", testOutput.length)
+
+    val inRdd = sc.binaryRecords(outFileName)
+    // make sure there are enough elements
+    assert(inRdd.count == testOutputCopies)
+
+    // now just compare the first record
+    val indata: Array[Byte] = inRdd.first
+    assert(indata === testOutput)
+  }
+
   test("file caching") {
     sc = new SparkContext("local", "test")
     val out = new FileWriter(tempDir + "/input")
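
For context, a minimal Scala sketch of how the binaryFiles and binaryRecords APIs added above could be called. The input paths /data/bins and /data/records.bin, the application name, and the 6-byte record length are placeholder assumptions; "recordLength" is the Hadoop configuration key exercised by the fixed-length FileSuite test.

import org.apache.spark.{SparkConf, SparkContext}

object BinaryFilesExample {
  def main(args: Array[String]): Unit = {
    // placeholder app name and local master; adjust for a real deployment
    val sc = new SparkContext(new SparkConf().setAppName("binary-files-example").setMaster("local"))

    // Whole-file API: each element is (fully qualified file path, file contents as a byte array)
    val files = sc.binaryFiles("/data/bins")
    files.map { case (path, bytes) => s"$path -> ${bytes.length} bytes" }
      .collect()
      .foreach(println)

    // Fixed-length record API: each element is one record of "recordLength" bytes,
    // read through FixedLengthBinaryInputFormat
    sc.hadoopConfiguration.setInt("recordLength", 6)
    val records = sc.binaryRecords("/data/records.bin")
    println(s"read ${records.count()} six-byte records")

    sc.stop()
  }
}

The Java wrappers in JavaSparkContext (binaryFiles and binaryRecords) mirror these calls, returning JavaPairRDD<String, byte[]> and JavaRDD<byte[]> respectively.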