Skip to content

Commit

Permalink
change labelParser from annoymous function to trait
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Apr 7, 2014
1 parent 87d0928 commit 7f8eb36
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.util

/** Trait for label parsers. */
trait LabelParser extends Serializable {
/** Parses a string label into a double label. */
def apply(labelString: String): Double
}

/**
* Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5,
* or 0.0 (negative) otherwise. So it works with +1/-1 labeling and +1/0 labeling.
*/
class BinaryLabelParser extends LabelParser {
/**
* Parses the input label into positive (1.0) if the value is greater than 0.5,
* or negative (0.0) otherwise.
*/
override def apply(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0
}

object BinaryLabelParser {
private lazy val instance = new BinaryLabelParser()
/** Gets the default instance of BinaryLabelParser. */
def apply() = instance
}

/**
* Label parser for multiclass labels, which converts the input label to double.
*/
class MulticlassLabelParser extends LabelParser {
override def apply(labelString: String): Double = labelString.toDouble
}

object MulticlassLabelParser {
private lazy val instance = new MulticlassLabelParser()
/** Gets the default instance of MulticlassLabelParser. */
def apply() = instance
}
26 changes: 4 additions & 22 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,6 @@ object MLUtils {
eps
}

/**
* Multiclass label parser, which parses a string into double.
*/
val multiclassLabelParser: String => Double = _.toDouble

/**
* Binary label parser, which outputs 1.0 (positive) if the value is greater than 0.5,
* or 0.0 (negative) otherwise.
*/
val binaryLabelParser: String => Double = label => if (label.toDouble > 0.5) 1.0 else 0.0

/**
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint].
* The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
Expand All @@ -69,7 +58,7 @@ object MLUtils {
def loadLibSVMData(
sc: SparkContext,
path: String,
labelParser: String => Double,
labelParser: LabelParser,
numFeatures: Int,
minSplits: Int): RDD[LabeledPoint] = {
val parsed = sc.textFile(path, minSplits)
Expand Down Expand Up @@ -107,14 +96,7 @@ object MLUtils {
* with number of features determined automatically and the default number of partitions.
*/
def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
loadLibSVMData(sc, path, binaryLabelParser, -1, sc.defaultMinSplits)

/**
* Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
* with number of features specified explicitly and the default number of partitions.
*/
def loadLibSVMData(sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] =
loadLibSVMData(sc, path, binaryLabelParser, numFeatures, sc.defaultMinSplits)
loadLibSVMData(sc, path, BinaryLabelParser(), -1, sc.defaultMinSplits)

/**
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
Expand All @@ -124,7 +106,7 @@ object MLUtils {
def loadLibSVMData(
sc: SparkContext,
path: String,
labelParser: String => Double): RDD[LabeledPoint] =
labelParser: LabelParser): RDD[LabeledPoint] =
loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits)

/**
Expand All @@ -135,7 +117,7 @@ object MLUtils {
def loadLibSVMData(
sc: SparkContext,
path: String,
labelParser: String => Double,
labelParser: LabelParser,
numFeatures: Int): RDD[LabeledPoint] =
loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString

val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect()
val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser(), 6).collect()
val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()

for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
Expand All @@ -93,7 +93,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
}

val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect()
val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser()).collect()
assert(multiclassPoints.length === 3)
assert(multiclassPoints(0).label === 1.0)
assert(multiclassPoints(1).label === -1.0)
Expand Down

0 comments on commit 7f8eb36

Please sign in to comment.