diff --git a/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala b/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala
deleted file mode 100644
index 00988bc480dc8..0000000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib
-
-import org.apache.spark.SparkContext
-
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.rdd.RDD
-
-/**
- * Provides methods related to machine learning on top of [[org.apache.spark.SparkContext]].
- *
- * @param sparkContext a [[org.apache.spark.SparkContext]] instance
- */
-class MLContext(val sparkContext: SparkContext) {
-  /**
-   * Reads labeled data in the LIBSVM format into an RDD[LabeledPoint].
-   * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
-   * Each line represents a labeled sparse feature vector using the following format:
-   * {{{label index1:value1 index2:value2 ...}}}
-   * where the indices are one-based and in ascending order.
-   * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]],
-   * where the feature indices are converted to zero-based.
-   *
-   * @param path file or directory path in any Hadoop-supported file system URI
-   * @param minSplits min number of partitions, default: sparkContext.defaultMinSplits
-   * @param numFeatures number of features, which will be determined from the input data if a
-   *                    non-positive value is given. The default value is 0.
-   * @param labelParser parser for labels, default: _.toDouble
-   * @return labeled data stored as an RDD[LabeledPoint]
-   */
-  def libSVMFile(
-      path: String,
-      minSplits: Int = sparkContext.defaultMinSplits,
-      numFeatures: Int = 0,
-      labelParser: String => Double = _.toDouble): RDD[LabeledPoint] = {
-    val parsed = sparkContext.textFile(path, minSplits)
-      .map(_.trim)
-      .filter(!_.isEmpty)
-      .map(_.split(' '))
-    // Determine number of features.
-    val d = if (numFeatures > 0) {
-      numFeatures
-    } else {
-      parsed.map { items =>
-        if (items.length > 1) {
-          items.last.split(':')(0).toInt
-        } else {
-          0
-        }
-      }.reduce(math.max)
-    }
-    parsed.map { items =>
-      val label = labelParser(items.head)
-      val (indices, values) = items.tail.map { item =>
-        val indexAndValue = item.split(':')
-        val index = indexAndValue(0).toInt - 1
-        val value = indexAndValue(1).toDouble
-        (index, value)
-      }.unzip
-      LabeledPoint(label, Vectors.sparse(d, indices.toArray, values.toArray))
-    }
-  }
-}
-
-object MLContext {
-  /**
-   * Creates an [[org.apache.spark.mllib.MLContext]] instance from
-   * an [[org.apache.spark.SparkContext]] instance.
-   */
-  def apply(sc: SparkContext): MLContext = new MLContext(sc)
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 4c47eb5c7506c..cb85e433bfc73 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -38,6 +38,107 @@ object MLUtils {
     eps
   }
 
+  /**
+   * Multiclass label parser, which parses a string into a double.
+   */
+  val multiclassLabelParser: String => Double = _.toDouble
+
+  /**
+   * Binary label parser, which outputs 1.0 (positive) if the value is greater than 0.5,
+   * or 0.0 (negative) otherwise.
+   */
+  val binaryLabelParser: String => Double = label => if (label.toDouble > 0.5) 1.0 else 0.0
+
+  /**
+   * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint].
+   * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
+   * Each line represents a labeled sparse feature vector using the following format:
+   * {{{label index1:value1 index2:value2 ...}}}
+   * where the indices are one-based and in ascending order.
+   * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]],
+   * where the feature indices are converted to zero-based.
+   *
+   * @param sc Spark context
+   * @param path file or directory path in any Hadoop-supported file system URI
+   * @param labelParser parser for labels; the convenience overloads below default to
+   *                    binaryLabelParser
+   * @param numFeatures number of features, which will be determined from the input data if a
+   *                    negative value is given (the convenience overloads below pass -1)
+   * @param minSplits min number of partitions; the convenience overloads below default to
+   *                  sc.defaultMinSplits
+   * @return labeled data stored as an RDD[LabeledPoint]
+   */
+  def loadLibSVMData(
+      sc: SparkContext,
+      path: String,
+      labelParser: String => Double,
+      numFeatures: Int,
+      minSplits: Int): RDD[LabeledPoint] = {
+    val parsed = sc.textFile(path, minSplits)
+      .map(_.trim)
+      .filter(!_.isEmpty)
+      .map(_.split(' '))
+    // Determine number of features.
+    val d = if (numFeatures >= 0) {
+      numFeatures
+    } else {
+      parsed.map { items =>
+        if (items.length > 1) {
+          items.last.split(':')(0).toInt
+        } else {
+          0
+        }
+      }.reduce(math.max)
+    }
+    parsed.map { items =>
+      val label = labelParser(items.head)
+      val (indices, values) = items.tail.map { item =>
+        val indexAndValue = item.split(':')
+        val index = indexAndValue(0).toInt - 1
+        val value = indexAndValue(1).toDouble
+        (index, value)
+      }.unzip
+      LabeledPoint(label, Vectors.sparse(d, indices.toArray, values.toArray))
+    }
+  }
+
+  // Convenience methods for calling from Java.
+
+  /**
+   * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
+   * with the number of features determined automatically and the default number of partitions.
+   */
+  def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
+    loadLibSVMData(sc, path, binaryLabelParser, -1, sc.defaultMinSplits)
+
+  /**
+   * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
+   * with the number of features specified explicitly and the default number of partitions.
+   */
+  def loadLibSVMData(sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] =
+    loadLibSVMData(sc, path, binaryLabelParser, numFeatures, sc.defaultMinSplits)
+
+  /**
+   * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
+   * with the given label parser, the number of features determined automatically,
+   * and the default number of partitions.
+   */
+  def loadLibSVMData(
+      sc: SparkContext,
+      path: String,
+      labelParser: String => Double): RDD[LabeledPoint] =
+    loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits)
+
+  /**
+   * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
+   * with the given label parser, the number of features specified explicitly,
+   * and the default number of partitions.
+   */
+  def loadLibSVMData(
+      sc: SparkContext,
+      path: String,
+      labelParser: String => Double,
+      numFeatures: Int): RDD[LabeledPoint] =
+    loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits)
+
   /**
    * Load labeled data from a file. The data format used here is
    * <L>, <f1> <f2> ...
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala
deleted file mode 100644
index 6313978d546b9..0000000000000
--- a/mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib
-
-import java.io.File
-
-import org.scalatest.FunSuite
-
-import com.google.common.base.Charsets
-import com.google.common.io.Files
-
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.LocalSparkContext
-
-class MLContextSuite extends FunSuite with LocalSparkContext {
-  test("libSVMFile") {
-    val lines =
-      """
-        |1 1:1.0 3:2.0 5:3.0
-        |0
-        |0 2:4.0 4:5.0 6:6.0
-      """.stripMargin
-    val tempDir = Files.createTempDir()
-    val file = new File(tempDir.getPath, "part-00000")
-    Files.write(lines, file, Charsets.US_ASCII)
-
-    val mlc = MLContext(sc)
-
-    val pointsWithNumFeatures = mlc.libSVMFile(tempDir.toURI.toString, numFeatures = 6).collect()
-    val pointsWithoutNumFeatures = mlc.libSVMFile(tempDir.toURI.toString).collect()
-
-    for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
-      assert(points.length === 3)
-      assert(points(0).label === 1.0)
-      assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
-      assert(points(1).label == 0.0)
-      assert(points(1).features == Vectors.sparse(6, Seq()))
-      assert(points(2).label === 0.0)
-      assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
-    }
-
-    try {
-      file.delete()
-      tempDir.delete()
-    } catch {
-      case t: Throwable =>
-    }
-  }
-}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 2081fe46b17ef..27d41c7869aa0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -17,14 +17,18 @@
 
 package org.apache.spark.mllib.util
 
+import java.io.File
+
 import org.scalatest.FunSuite
 
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
   squaredDistance => breezeSquaredDistance}
+import com.google.common.base.Charsets
+import com.google.common.io.Files
 
-import org.apache.spark.mllib.util.MLUtils._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils._
 
 class MLUtilsSuite extends FunSuite with LocalSparkContext {
 
@@ -63,4 +67,43 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
     assert(mean === Vectors.dense(2.0, 3.0, 4.0))
     assert(std === Vectors.dense(1.0, 1.0, 1.0))
   }
+
+  test("loadLibSVMData") {
+    val lines =
+      """
+        |+1 1:1.0 3:2.0 5:3.0
+        |-1
+        |-1 2:4.0 4:5.0 6:6.0
+      """.stripMargin
+    val tempDir = Files.createTempDir()
+    val file = new File(tempDir.getPath, "part-00000")
+    Files.write(lines, file, Charsets.US_ASCII)
+    val path = tempDir.toURI.toString
+
+    val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect()
+    val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()
+
+    for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
+      assert(points.length === 3)
+      assert(points(0).label === 1.0)
+      assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+      assert(points(1).label === 0.0)
+      assert(points(1).features === Vectors.sparse(6, Seq()))
+      assert(points(2).label === 0.0)
+      assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
+    }
+
+    val multiclassPoints =
+      MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect()
+    assert(multiclassPoints.length === 3)
+    assert(multiclassPoints(0).label === 1.0)
+    assert(multiclassPoints(1).label === -1.0)
+    assert(multiclassPoints(2).label === -1.0)
+
+    try {
+      file.delete()
+      tempDir.delete()
+    } catch {
+      case t: Throwable =>
+    }
+  }
 }
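
For anyone trying out the patch, here is a minimal usage sketch of the new MLUtils.loadLibSVMData entry points added above. The local SparkContext setup and the input path data/sample_libsvm.txt are illustrative assumptions, not part of this change.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.MLUtils

object LoadLibSVMDataExample {
  def main(args: Array[String]): Unit = {
    // Local context for illustration only; any SparkContext works.
    val sc = new SparkContext("local", "LoadLibSVMDataExample")

    // Default overload: binary labels (1.0 if label > 0.5, else 0.0),
    // feature count inferred from the data, default number of partitions.
    // The path is a hypothetical example, not shipped with this patch.
    val binaryData = MLUtils.loadLibSVMData(sc, "data/sample_libsvm.txt")

    // Keep labels as-is (e.g. -1/+1 or 0/1/2) by passing multiclassLabelParser;
    // here the feature count (6) and the partition count are given explicitly.
    val multiclassData = MLUtils.loadLibSVMData(
      sc, "data/sample_libsvm.txt", MLUtils.multiclassLabelParser, 6, sc.defaultMinSplits)

    println(s"binary: ${binaryData.count()} points, multiclass: ${multiclassData.count()} points")

    sc.stop()
  }
}
```

The overloads exist mainly so Java callers get the common defaults (binaryLabelParser, inferred feature count, sc.defaultMinSplits) without Scala default arguments; Scala callers can always use the full five-argument form as in the second call.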