Skip to content

Commit

Permalink
update libSVMFile to determine number of features automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 31, 2014
1 parent 3432e84 commit 6f59eed
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 13 deletions.
27 changes: 20 additions & 7 deletions mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,37 @@ class MLContext(self: SparkContext) {
* where the feature indices are converted to zero-based.
*
* @param path file or directory path in any Hadoop-supported file system URI
 * @param numFeatures number of features; it will be determined from the input
 *                    if a non-positive value is given
 * @param labelParser parser for labels, default: _.toDouble
* @return labeled data stored as an RDD[LabeledPoint]
*/
def libSVMFile(
    path: String,
    numFeatures: Int,
    labelParser: String => Double = _.toDouble): RDD[LabeledPoint] = {
  // Split each non-empty line into whitespace-separated tokens:
  // token 0 is the label, the rest are "index:value" pairs (1-based indices).
  // NOTE(review): `parsed` is not cached, so when numFeatures <= 0 the input
  // is read from storage twice (once for the max-index pass, once for parsing).
  val parsed = self.textFile(path).map(_.trim).filter(!_.isEmpty).map(_.split(' '))
  // Determine the number of features: use the caller-supplied value when
  // positive, otherwise scan the data for the largest feature index.
  val d = if (numFeatures > 0) {
    numFeatures
  } else {
    // A line with only a label contributes 0 so it does not affect the max.
    // NOTE(review): reduce throws on an empty RDD — assumes non-empty input.
    parsed.map { items =>
      if (items.length > 1) {
        // LIBSVM indices are ascending, so the last pair holds the max index.
        items.last.split(':')(0).toInt
      } else {
        0
      }
    }.reduce(math.max)
  }
  parsed.map { items =>
    val label = labelParser(items.head)
    // Convert each "index:value" token to a (zero-based index, value) pair.
    val (indices, values) = items.tail.map { item =>
      val indexAndValue = item.split(':')
      val index = indexAndValue(0).toInt - 1 // convert 1-based to 0-based
      val value = indexAndValue(1).toDouble
      (index, value)
    }.unzip
    LabeledPoint(label, Vectors.sparse(d, indices.toArray, values.toArray))
  }
}
}
Expand Down
21 changes: 15 additions & 6 deletions mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,26 @@ class MLContextSuite extends FunSuite with LocalSparkContext {
val lines =
"""
|1 1:1.0 3:2.0 5:3.0
|0
|0 2:4.0 4:5.0 6:6.0
""".stripMargin
val tempDir = Files.createTempDir()
val file = new File(tempDir.getPath, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
val points = sc.libSVMFile(tempDir.toURI.toString, 6).collect()
assert(points.length === 2)
assert(points(0).label === 1.0)
assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
assert(points(1).label === 0.0)
assert(points(1).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))

val pointsWithNumFeatures = sc.libSVMFile(tempDir.toURI.toString, 6).collect()
val pointsWithoutNumFeatures = sc.libSVMFile(tempDir.toURI.toString, 0).collect()

for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
assert(points.length === 3)
assert(points(0).label === 1.0)
assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
assert(points(1).label == 0.0)
assert(points(1).features == Vectors.sparse(6, Seq()))
assert(points(2).label === 0.0)
assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
}

try {
file.delete()
tempDir.delete()
Expand Down

0 comments on commit 6f59eed

Please sign in to comment.