diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala index c328718bef947..f7966d3ebb613 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala @@ -27,7 +27,10 @@ trait LabelParser extends Serializable { * Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5, * or 0.0 (negative) otherwise. So it works with +1/-1 labeling and +1/0 labeling. */ -class BinaryLabelParser extends LabelParser { +object BinaryLabelParser extends LabelParser { + /** Gets the default instance of BinaryLabelParser. */ + def getInstance(): LabelParser = this + /** * Parses the input label into positive (1.0) if the value is greater than 0.5, * or negative (0.0) otherwise. @@ -35,19 +38,12 @@ class BinaryLabelParser extends LabelParser { override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0 } -object BinaryLabelParser extends BinaryLabelParser { - /** Gets the default instance of BinaryLabelParser. */ - def getInstance(): BinaryLabelParser = this -} - /** * Label parser for multiclass labels, which converts the input label to double. */ -class MulticlassLabelParser extends LabelParser { - override def parse(labelString: String): Double = labelString.toDouble -} - -object MulticlassLabelParser extends MulticlassLabelParser { +object MulticlassLabelParser extends LabelParser { /** Gets the default instance of MulticlassLabelParser. */ - def getInstance(): MulticlassLabelParser = this + def getInstance(): LabelParser = this + + override def parse(labelString: String): Double = labelString.toDouble }