rename loadLibSVMData to loadLibSVMFile; hide LabelParser from user APIs
mengxr committed May 5, 2014
1 parent 54b812c commit 649fcf0
Showing 6 changed files with 93 additions and 57 deletions.
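
The user-visible effect of the rename is that the `LabelParser` argument disappears from the public API in favor of a `multiclass: Boolean` flag. A minimal before/after sketch of a caller (assuming an existing SparkContext `sc`; the input path is hypothetical):

    import org.apache.spark.mllib.util.MLUtils

    // Before this commit: the label parser object was part of the public API.
    // val examples = MLUtils.loadLibSVMData(sc, "data/sample.txt", MulticlassLabelParser)

    // After this commit: the parser choice is folded into a boolean flag.
    val examples = MLUtils.loadLibSVMFile(sc, "data/sample.txt", multiclass = true)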
2 changes: 1 addition & 1 deletion  examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
@@ -96,7 +96,7 @@ object BinaryClassification {

     Logger.getRootLogger.setLevel(Level.WARN)

-    val examples = MLUtils.loadLibSVMData(sc, params.input).cache()
+    val examples = MLUtils.loadLibSVMFile(sc, params.input).cache()

     val splits = examples.randomSplit(Array(0.8, 0.2))
     val training = splits(0).cache()
4 changes: 2 additions & 2 deletions  examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
@@ -22,7 +22,7 @@ import scopt.OptionParser

 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.mllib.regression.LinearRegressionWithSGD
-import org.apache.spark.mllib.util.{MulticlassLabelParser, MLUtils}
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.mllib.optimization.{SimpleUpdater, SquaredL2Updater, L1Updater}

 /**
@@ -82,7 +82,7 @@ object LinearRegression extends App {

     Logger.getRootLogger.setLevel(Level.WARN)

-    val examples = MLUtils.loadLibSVMData(sc, params.input, MulticlassLabelParser).cache()
+    val examples = MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache()

     val splits = examples.randomSplit(Array(0.8, 0.2))
     val training = splits(0).cache()
4 changes: 2 additions & 2 deletions  examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
@@ -75,8 +75,8 @@ object SparseNaiveBayes {
     val minPartitions =
       if (params.minPartitions > 0) params.minPartitions else sc.defaultMinPartitions

-    val examples = MLUtils.loadLibSVMData(sc, params.input, MulticlassLabelParser,
-      params.numFeatures, minPartitions)
+    val examples =
+      MLUtils.loadLibSVMFile(sc, params.input, multiclass = true, params.numFeatures, minPartitions)
     // Cache examples because it will be used in both training and evaluation.
     examples.cache()
13 changes: 10 additions & 3 deletions  mllib/src/main/scala/org/apache/spark/mllib/util/LabelParser.scala
@@ -18,16 +18,23 @@
 package org.apache.spark.mllib.util

 /** Trait for label parsers. */
-trait LabelParser extends Serializable {
+private trait LabelParser extends Serializable {
   /** Parses a string label into a double label. */
   def parse(labelString: String): Double
 }

+/** Factory methods for label parsers. */
+private object LabelParser {
+  def getInstance(multiclass: Boolean): LabelParser = {
+    if (multiclass) MulticlassLabelParser else BinaryLabelParser
+  }
+}
+
 /**
  * Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5,
  * or 0.0 (negative) otherwise. So it works with +1/-1 labeling and +1/0 labeling.
  */
-object BinaryLabelParser extends LabelParser {
+private object BinaryLabelParser extends LabelParser {
   /** Gets the default instance of BinaryLabelParser. */
   def getInstance(): LabelParser = this

@@ -41,7 +48,7 @@ object BinaryLabelParser {
 /**
  * Label parser for multiclass labels, which converts the input label to double.
  */
-object MulticlassLabelParser extends LabelParser {
+private object MulticlassLabelParser extends LabelParser {
   /** Gets the default instance of MulticlassLabelParser. */
   def getInstance(): LabelParser = this
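
The two parsers, now hidden behind `LabelParser.getInstance`, implement exactly the mapping described in the scaladoc above. A standalone sketch of that mapping (an illustration, not the MLlib source):

    // Binary: any value greater than 0.5 is positive, so +1/-1 and 1/0
    // labelings both collapse onto {1.0, 0.0}.
    def parseBinary(labelString: String): Double =
      if (labelString.toDouble > 0.5) 1.0 else 0.0

    // Multiclass: the label string is converted to a double as-is.
    def parseMulticlass(labelString: String): Double = labelString.toDouble

    parseBinary("+1")     // 1.0
    parseBinary("-1")     // 0.0
    parseMulticlass("3")  // 3.0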
119 changes: 74 additions & 45 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -29,6 +29,7 @@ import org.apache.spark.rdd.PartitionwiseSampledRDD
 import org.apache.spark.util.random.BernoulliSampler
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.storage.StorageLevel

 /**
  * Helper methods to load, save and pre-process data used in ML Lib.
@@ -44,7 +45,6 @@ object MLUtils {
   }

   /**
-   * :: Experimental ::
    * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint].
    * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
    * Each line represents a labeled sparse feature vector using the following format:
@@ -55,84 +55,113 @@ object MLUtils {
    *
    * @param sc Spark context
    * @param path file or directory path in any Hadoop-supported file system URI
-   * @param labelParser parser for labels, default: 1.0 if label > 0.5 or 0.0 otherwise
+   * @param labelParser parser for labels
    * @param numFeatures number of features, which will be determined from the input data if a
-   *                    negative value is given. The default value is -1.
-   * @param minPartitions min number of partitions, default: sc.defaultMinPartitions
+   *                    nonpositive value is given. This is useful when the dataset is already split
+   *                    into multiple files and you want to load them separately, because some
+   *                    features may not present in certain files, which leads to inconsistent
+   *                    feature dimensions.
+   * @param minPartitions min number of partitions
    * @return labeled data stored as an RDD[LabeledPoint]
    */
-  @Experimental
-  def loadLibSVMData(
+  private def loadLibSVMFile(
       sc: SparkContext,
       path: String,
       labelParser: LabelParser,
       numFeatures: Int,
       minPartitions: Int): RDD[LabeledPoint] = {
     val parsed = sc.textFile(path, minPartitions)
       .map(_.trim)
-      .filter(!_.isEmpty)
-      .map(_.split(' '))
+      .filter(line => !(line.isEmpty || line.startsWith("#")))
+      .map { line =>
+        val items = line.split(' ')
+        val label = labelParser.parse(items.head)
+        val (indices, values) = items.tail.map { item =>
+          val indexAndValue = item.split(':')
+          val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
+          val value = indexAndValue(1).toDouble
+          (index, value)
+        }.unzip
+        (label, indices.toArray, values.toArray)
+      }

     // Determine number of features.
-    val d = if (numFeatures >= 0) {
+    val d = if (numFeatures > 0) {
       numFeatures
     } else {
-      parsed.map { items =>
-        if (items.length > 1) {
-          items.last.split(':')(0).toInt
-        } else {
-          0
-        }
-      }.reduce(math.max)
+      parsed.persist(StorageLevel.MEMORY_ONLY)
+      parsed.map { case (label, indices, values) =>
+        indices.lastOption.getOrElse(0)
+      }.reduce(math.max) + 1
     }
-    parsed.map { items =>
-      val label = labelParser.parse(items.head)
-      val (indices, values) = items.tail.map { item =>
-        val indexAndValue = item.split(':')
-        val index = indexAndValue(0).toInt - 1
-        val value = indexAndValue(1).toDouble
-        (index, value)
-      }.unzip
-      LabeledPoint(label, Vectors.sparse(d, indices.toArray, values.toArray))
+
+    parsed.map { case (label, indices, values) =>
+      LabeledPoint(label, Vectors.sparse(d, indices, values))
     }
   }

-  // Convenient methods for calling from Java.
+  // Convenient methods for `loadLibSVMFile`.

   /**
-   * :: Experimental ::
-   * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
-   * with number of features determined automatically and the default number of partitions.
+   * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint].
+   * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
+   * Each line represents a labeled sparse feature vector using the following format:
+   * {{{label index1:value1 index2:value2 ...}}}
+   * where the indices are one-based and in ascending order.
+   * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]],
+   * where the feature indices are converted to zero-based.
+   *
+   * @param sc Spark context
+   * @param path file or directory path in any Hadoop-supported file system URI
+   * @param multiclass whether the input labels contain more than two classes. If false, any label
+   *                   with value greater than 0.5 will be mapped to 1.0, or 0.0 otherwise. So it
+   *                   works for both +1/-1 and 1/0 cases. If true, the double value parsed directly
+   *                   from the label string will be used as the label value.
+   * @param numFeatures number of features, which will be determined from the input data if a
+   *                    nonpositive value is given. This is useful when the dataset is already split
+   *                    into multiple files and you want to load them separately, because some
+   *                    features may not present in certain files, which leads to inconsistent
+   *                    feature dimensions.
+   * @param minPartitions min number of partitions
+   * @return labeled data stored as an RDD[LabeledPoint]
    */
-  @Experimental
-  def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
-    loadLibSVMData(sc, path, BinaryLabelParser, -1, sc.defaultMinPartitions)
+  def loadLibSVMFile(
+      sc: SparkContext,
+      path: String,
+      multiclass: Boolean,
+      numFeatures: Int,
+      minPartitions: Int): RDD[LabeledPoint] =
+    loadLibSVMFile(sc, path, LabelParser.getInstance(multiclass), numFeatures, minPartitions)

   /**
-   * :: Experimental ::
    * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
-   * with the given label parser, number of features determined automatically,
+   * with the given label parser, number of features specified explicitly,
    * and the default number of partitions.
    */
-  @Experimental
-  def loadLibSVMData(
+  def loadLibSVMFile(
       sc: SparkContext,
       path: String,
-      labelParser: LabelParser): RDD[LabeledPoint] =
-    loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinPartitions)
+      multiclass: Boolean,
+      numFeatures: Int): RDD[LabeledPoint] =
+    loadLibSVMFile(sc, path, multiclass, numFeatures, sc.defaultMinPartitions)

   /**
-   * :: Experimental ::
    * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
-   * with the given label parser, number of features specified explicitly,
+   * with the given label parser, number of features determined automatically,
    * and the default number of partitions.
    */
-  @Experimental
-  def loadLibSVMData(
+  def loadLibSVMFile(
       sc: SparkContext,
       path: String,
-      labelParser: LabelParser,
-      numFeatures: Int): RDD[LabeledPoint] =
-    loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinPartitions)
+      multiclass: Boolean): RDD[LabeledPoint] =
+    loadLibSVMFile(sc, path, multiclass, -1, sc.defaultMinPartitions)
+
+  /**
+   * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
+   * with number of features determined automatically and the default number of partitions.
+   */
+  def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] =
+    loadLibSVMFile(sc, path, multiclass = false, -1, sc.defaultMinPartitions)

   /**
    * Save labeled data in LIBSVM format.
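
One detail of the rewritten private loader is worth noting. When `numFeatures` is nonpositive, the old code reduced over the largest one-based index, while the new code reduces over the largest zero-based index and adds one; both yield the same feature count. Because `parsed` is now traversed twice in that case (once to determine the dimension, once to build the points), it is persisted first, and the scan relies on the documented guarantee that indices are in ascending order, so the last index on each line is the largest. A plain-Scala sketch of the arithmetic on a hypothetical two-line input:

    // Hypothetical input (one-based indices, largest index 6):
    //   +1 1:1.0 3:2.0
    //   -1 2:4.0 6:6.0
    // After parsing, indices are zero-based:
    val zeroBasedIndices = Seq(Array(0, 2), Array(1, 5))
    // The last index per line is the largest because indices are ascending.
    val maxIndex = zeroBasedIndices.map(_.lastOption.getOrElse(0)).reduce(math.max)
    val d = maxIndex + 1  // 6, the true feature count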
8 changes: 4 additions & 4 deletions  mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -59,7 +59,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
     }
   }

-  test("loadLibSVMData") {
+  test("loadLibSVMFile") {
     val lines =
       """
         |+1 1:1.0 3:2.0 5:3.0
@@ -71,8 +71,8 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
     Files.write(lines, file, Charsets.US_ASCII)
     val path = tempDir.toURI.toString

-    val pointsWithNumFeatures = loadLibSVMData(sc, path, BinaryLabelParser, 6).collect()
-    val pointsWithoutNumFeatures = loadLibSVMData(sc, path).collect()
+    val pointsWithNumFeatures = loadLibSVMFile(sc, path, multiclass = false, 6).collect()
+    val pointsWithoutNumFeatures = loadLibSVMFile(sc, path).collect()

     for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
       assert(points.length === 3)
@@ -84,7 +84,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
       assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
     }

-    val multiclassPoints = loadLibSVMData(sc, path, MulticlassLabelParser).collect()
+    val multiclassPoints = loadLibSVMFile(sc, path, multiclass = true).collect()
     assert(multiclassPoints.length === 3)
     assert(multiclassPoints(0).label === 1.0)
     assert(multiclassPoints(1).label === -1.0)
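
For completeness, a usage sketch covering the public overloads exercised by this test (assuming an existing SparkContext `sc`; the paths and the feature count 780 are hypothetical):

    import org.apache.spark.mllib.util.MLUtils

    // Binary labels, feature count inferred, default partitions.
    val binary = MLUtils.loadLibSVMFile(sc, "data/binary.txt")

    // Multiclass labels, feature count inferred.
    val multi = MLUtils.loadLibSVMFile(sc, "data/multi.txt", multiclass = true)

    // Explicit feature count: useful when the data is split across files
    // that may not all mention every feature.
    val part = MLUtils.loadLibSVMFile(sc, "data/part-00000", multiclass = false, numFeatures = 780)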
