add whole text files reader
yinxusen committed Mar 27, 2014
1 parent d679843 commit 28cb0fe
Showing 4 changed files with 290 additions and 0 deletions.
68 changes: 68 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib

import org.apache.spark.mllib.input.WholeTextFileInputFormat
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext

/**
 * Extra functions available on SparkContext for MLlib through an implicit conversion. Import
 * `org.apache.spark.mllib.MLContext._` at the top of your program to use these functions.
 */
class MLContext(self: SparkContext) {

  /**
   * Read a directory of text files from HDFS, a local file system (available on all nodes), or any
   * Hadoop-supported file system URI. Each file is read as a single record and returned as a
   * key-value pair, where the key is the path of the file and the value is its entire content.
   *
   * <p> For example, if you have the following files:
   * {{{
   * hdfs://a-hdfs-path/part-00000
   * hdfs://a-hdfs-path/part-00001
   * ...
   * hdfs://a-hdfs-path/part-nnnnn
   * }}}
   *
   * Do `val rdd = mlContext.wholeTextFile("hdfs://a-hdfs-path")`,
   *
   * <p> then `rdd` contains
   * {{{
   * (a-hdfs-path/part-00000, its content)
   * (a-hdfs-path/part-00001, its content)
   * ...
   * (a-hdfs-path/part-nnnnn, its content)
   * }}}
   */
  def wholeTextFile(path: String): RDD[(String, String)] = {
    self.newAPIHadoopFile(
      path,
      classOf[WholeTextFileInputFormat],
      classOf[String],
      classOf[String])
  }
}

/**
 * The MLContext object contains a number of implicit conversions and parameters for use with
 * various MLlib features.
 */
object MLContext {
  implicit def sparkContextToMLContext(sc: SparkContext): MLContext = new MLContext(sc)
}
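
A minimal usage sketch of the new API (the object name, master string, and input path below are illustrative placeholders, not part of this commit): importing MLContext._ brings the implicit SparkContext-to-MLContext conversion into scope, so wholeTextFile can be called directly on a SparkContext.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.MLContext._

// Hypothetical driver program exercising wholeTextFile through the implicit conversion.
object WholeTextFileExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "whole-text-file-example")

    // Each element is (file path, entire file content).
    val rdd = sc.wholeTextFile("hdfs://a-hdfs-path")
    rdd.collect().foreach { case (path, content) =>
      println(s"$path -> ${content.length} characters")
    }

    sc.stop()
  }
}
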
47 changes: 47 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileInputFormat.scala
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.input

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.InputSplit
import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader
import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit

/**
 * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for
 * reading whole text files. Each file is read as a key-value pair, where the key is the file path
 * and the value is the entire content of the file.
 */
private[mllib] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] {
  override protected def isSplitable(context: JobContext, file: Path): Boolean = false

  override def createRecordReader(
      split: InputSplit,
      context: TaskAttemptContext): RecordReader[String, String] = {

    new CombineFileRecordReader[String, String](
      split.asInstanceOf[CombineFileSplit],
      context,
      classOf[WholeTextFileRecordReader])
  }
}
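
Because WholeTextFileInputFormat is a standard new-API Hadoop InputFormat, it can also be used directly with SparkContext.newAPIHadoopFile; that call is exactly what MLContext.wholeTextFile above wraps. A sketch, with an illustrative helper name and placeholder path:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.input.WholeTextFileInputFormat
import org.apache.spark.rdd.RDD

// Mirrors the body of MLContext.wholeTextFile from this commit.
def readWholeTextFiles(sc: SparkContext, path: String): RDD[(String, String)] =
  sc.newAPIHadoopFile(
    path,
    classOf[WholeTextFileInputFormat],
    classOf[String],
    classOf[String])
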
72 changes: 72 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/input/WholeTextFileRecordReader.scala
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.input

import com.google.common.io.{ByteStreams, Closeables}

import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.InputSplit
import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext

/**
 * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] that reads a single whole text file
 * as one key-value pair, where the key is the file path and the value is the entire content of
 * the file.
 */
private[mllib] class WholeTextFileRecordReader(
    split: CombineFileSplit,
    context: TaskAttemptContext,
    index: Integer)
  extends RecordReader[String, String] {

  private val path = split.getPath(index)
  private val fs = path.getFileSystem(context.getConfiguration)

  // True means the current file has already been processed and should be skipped.
  private var processed = false

  private val key = path.toString
  private var value: String = null

  override def initialize(split: InputSplit, context: TaskAttemptContext) = {}

  override def close() = {}

  override def getProgress = if (processed) 1.0f else 0.0f

  override def getCurrentKey = key

  override def getCurrentValue = value

  override def nextKeyValue = {
    if (!processed) {
      val fileIn = fs.open(path)
      val innerBuffer = ByteStreams.toByteArray(fileIn)

      value = new Text(innerBuffer).toString
      Closeables.close(fileIn, false)

      processed = true
      true
    } else {
      false
    }
  }
}
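
The reader follows the standard Hadoop RecordReader contract: the caller invokes nextKeyValue() until it returns false and fetches each record via getCurrentKey/getCurrentValue, so an instance of this reader yields exactly one (path, content) pair. A minimal sketch of that driving loop, using a hypothetical helper name:

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.mapreduce.RecordReader

// Illustrative helper showing how a RecordReader such as WholeTextFileRecordReader
// is consumed; each successful nextKeyValue() exposes one record.
def readAll(reader: RecordReader[String, String]): Seq[(String, String)] = {
  val records = ArrayBuffer.empty[(String, String)]
  while (reader.nextKeyValue()) {
    records += ((reader.getCurrentKey, reader.getCurrentValue))
  }
  reader.close()
  records.toSeq
}
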
103 changes: 103 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/input/WholeTextFileRecordReaderSuite.scala
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.input

import java.io.DataOutputStream
import java.io.File
import java.io.FileOutputStream

import scala.collection.immutable.IndexedSeq

import com.google.common.io.Files

import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite

import org.apache.hadoop.io.Text

import org.apache.spark.SparkContext
import org.apache.spark.mllib.MLContext._

/**
 * Tests the correctness of
 * [[org.apache.spark.mllib.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary
 * directory is created as fake input, and the temporary storage is deleted at the end.
 */
class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll {
  private var sc: SparkContext = _

  override def beforeAll() {
    sc = new SparkContext("local", "test")
  }

  override def afterAll() {
    sc.stop()
  }

  private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte]) = {
    val out = new DataOutputStream(new FileOutputStream(s"${inputDir.toString}/$fileName"))
    out.write(contents, 0, contents.length)
    out.close()
  }

  /**
   * This test checks the behavior of WholeTextFileRecordReader on local disk. There are
   * three aspects to verify:
   *   1) Whether all files are read;
   *   2) Whether the paths are read correctly;
   *   3) Whether the contents match the originals.
   */
  test("Correctness of WholeTextFileRecordReader.") {

    val dir = Files.createTempDir()
    println(s"Local disk address is ${dir.toString}.")

    WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) =>
      createNativeFile(dir, filename, contents)
    }

    val res = sc.wholeTextFile(dir.toString).collect()

    assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size,
      "Number of files read does not match the number of files written.")

    for ((filename, contents) <- res) {
      val shortName = filename.split('/').last
      assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName),
        s"Missing file name $filename.")
      assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString,
        s"Contents of file $filename do not match.")
    }

    dir.delete()
  }
}

/**
 * Files to be tested are defined here.
 */
object WholeTextFileRecordReaderSuite {
  private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte)

  private val fileNames = Array("part-00000", "part-00001", "part-00002")
  private val fileLengths = Array(10, 100, 1000)

  private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) =>
    filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray
  }.toMap
}
