Commit

Permalink
move libSVMFile to MLUtils and rename to loadLibSVMData
mengxr committed Apr 2, 2014
1 parent c26c4fc commit eb6e793
Showing 4 changed files with 145 additions and 153 deletions.
88 changes: 0 additions & 88 deletions mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala

This file was deleted.

101 changes: 101 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -38,6 +38,107 @@ object MLUtils {
    eps
  }

  /**
   * Multiclass label parser, which parses a string into a double.
   */
  val multiclassLabelParser: String => Double = _.toDouble

  /**
   * Binary label parser, which outputs 1.0 (positive) if the value is greater than 0.5,
   * or 0.0 (negative) otherwise.
   */
  val binaryLabelParser: String => Double = label => if (label.toDouble > 0.5) 1.0 else 0.0
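
  // A hypothetical sketch of how the two parsers behave on raw label strings:
  //
  //   multiclassLabelParser("3")  // 3.0
  //   binaryLabelParser("0.9")    // 1.0, since 0.9 > 0.5
  //   binaryLabelParser("-1")     // 0.0, since -1.0 <= 0.5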

  /**
   * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint].
   * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
   * Each line represents a labeled sparse feature vector using the following format:
   * {{{label index1:value1 index2:value2 ...}}}
   * where the indices are one-based and in ascending order.
   * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]],
   * where the feature indices are converted to zero-based.
   *
   * @param sc Spark context
   * @param path file or directory path in any Hadoop-supported file system URI
   * @param labelParser parser for labels (the convenience overloads below default to
   *                    binaryLabelParser)
   * @param numFeatures number of features, which will be determined from the input data if a
   *                    negative value is given
   * @param minSplits minimum number of partitions (the convenience overloads below default to
   *                  sc.defaultMinSplits)
   * @return labeled data stored as an RDD[LabeledPoint]
   */
  def loadLibSVMData(
      sc: SparkContext,
      path: String,
      labelParser: String => Double,
      numFeatures: Int,
      minSplits: Int): RDD[LabeledPoint] = {
    val parsed = sc.textFile(path, minSplits)
      .map(_.trim)
      .filter(!_.isEmpty)
      .map(_.split(' '))
    // Determine the number of features: either the given value, or the largest
    // one-based index that appears in the data.
    val d = if (numFeatures >= 0) {
      numFeatures
    } else {
      parsed.map { items =>
        if (items.length > 1) {
          items.last.split(':')(0).toInt
        } else {
          0
        }
      }.reduce(math.max)
    }
    parsed.map { items =>
      val label = labelParser(items.head)
      val (indices, values) = items.tail.map { item =>
        val indexAndValue = item.split(':')
        // Convert the one-based LIBSVM index to a zero-based index.
        val index = indexAndValue(0).toInt - 1
        val value = indexAndValue(1).toDouble
        (index, value)
      }.unzip
      LabeledPoint(label, Vectors.sparse(d, indices.toArray, values.toArray))
    }
  }
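
  // A minimal usage sketch, assuming a hypothetical file "data/sample_libsvm.txt"
  // that contains the single line "1 1:0.5 3:2.0":
  //
  //   val data = loadLibSVMData(sc, "data/sample_libsvm.txt", binaryLabelParser, 3,
  //     sc.defaultMinSplits)
  //
  // This yields one LabeledPoint(1.0, Vectors.sparse(3, Array(0, 2), Array(0.5, 2.0))),
  // with the one-based indices 1 and 3 mapped to zero-based 0 and 2.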

  // Convenient methods for calling from Java.

  /**
   * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
   * with the number of features determined automatically and the default number of partitions.
   */
  def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
    loadLibSVMData(sc, path, binaryLabelParser, -1, sc.defaultMinSplits)

  /**
   * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
   * with the number of features specified explicitly and the default number of partitions.
   */
  def loadLibSVMData(sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] =
    loadLibSVMData(sc, path, binaryLabelParser, numFeatures, sc.defaultMinSplits)

  /**
   * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
   * with the given label parser, the number of features determined automatically,
   * and the default number of partitions.
   */
  def loadLibSVMData(
      sc: SparkContext,
      path: String,
      labelParser: String => Double): RDD[LabeledPoint] =
    loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits)

  /**
   * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
   * with the given label parser, the number of features specified explicitly,
   * and the default number of partitions.
   */
  def loadLibSVMData(
      sc: SparkContext,
      path: String,
      labelParser: String => Double,
      numFeatures: Int): RDD[LabeledPoint] =
    loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits)
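
  // A hypothetical sketch of the two most common call patterns; the paths are
  // placeholders, not real files:
  //
  //   val binaryData = loadLibSVMData(sc, "hdfs://path/to/binary_data")
  //   val multiclassData = loadLibSVMData(sc, "hdfs://path/to/multiclass_data",
  //     multiclassLabelParser)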

  /**
   * Load labeled data from a file. The data format used here is
   * <L>, <f1> <f2> ...
64 changes: 0 additions & 64 deletions mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala

This file was deleted.

45 changes: 44 additions & 1 deletion mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -17,14 +17,18 @@

package org.apache.spark.mllib.util

import java.io.File

import org.scalatest.FunSuite

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
  squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._

class MLUtilsSuite extends FunSuite with LocalSparkContext {

@@ -63,4 +67,43 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
    assert(mean === Vectors.dense(2.0, 3.0, 4.0))
    assert(std === Vectors.dense(1.0, 1.0, 1.0))
  }

test("loadLibSVMData") {
val lines =
"""
|+1 1:1.0 3:2.0 5:3.0
|-1
|-1 2:4.0 4:5.0 6:6.0
""".stripMargin
val tempDir = Files.createTempDir()
val file = new File(tempDir.getPath, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString

val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect()
val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()

for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
assert(points.length === 3)
assert(points(0).label === 1.0)
assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
assert(points(1).label == 0.0)
assert(points(1).features == Vectors.sparse(6, Seq()))
assert(points(2).label === 0.0)
assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
}

val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect()
assert(multiclassPoints.length === 3)
assert(multiclassPoints(0).label === 1.0)
assert(multiclassPoints(1).label === -1.0)
assert(multiclassPoints(2).label === -1.0)

try {
file.delete()
tempDir.delete()
} catch {
case t: Throwable =>
}
}
}
