Skip to content

Commit

Permalink
add libSVMFile to MLContext
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Mar 27, 2014
1 parent f0fe616 commit 78c4671
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
61 changes: 61 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/MLContext.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib

import org.apache.spark.SparkContext

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

class MLContext(self: SparkContext) {
  /**
   * Reads labeled data in the LIBSVM format into an RDD[LabeledPoint].
   * The LIBSVM format is a text-based format used by LIBSVM (http://www.csie.ntu.edu.tw/~cjlin/libsvm/).
   * Each line represents a labeled sparse feature vector using the following format:
   * {{{label index1:value1 index2:value2 ...}}}
   * where the indices are one-based and in ascending order.
   * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]] instance,
   * where the feature indices are converted to zero-based.
   *
   * @param path file or directory path in any Hadoop-supported file system URI
   * @param numFeatures number of features
   * @param labelParser parser for labels, default: _.toDouble
   * @return labeled data stored as an RDD[LabeledPoint]
   */
  def libSVMFile(
      path: String,
      numFeatures: Int,
      labelParser: String => Double = _.toDouble): RDD[LabeledPoint] = {
    // Blank lines (e.g. trailing newlines) carry no record and are dropped.
    val records = self.textFile(path).map(_.trim).filter(_.nonEmpty)
    records.map { record =>
      val tokens = record.split(' ')
      // First token is the label; the rest are "index:value" pairs.
      val label = labelParser(tokens.head)
      val indexedValues = tokens.tail.map { pair =>
        val parts = pair.split(':')
        // LIBSVM indices are one-based; convert to zero-based for Vectors.sparse.
        (parts(0).toInt - 1, parts(1).toDouble)
      }
      LabeledPoint(label, Vectors.sparse(numFeatures, indexedValues))
    }
  }
}

object MLContext {
  /**
   * Implicitly converts a [[org.apache.spark.SparkContext]] into an [[MLContext]], so that MLlib
   * input methods such as `libSVMFile` can be called directly on a SparkContext after
   * `import org.apache.spark.mllib.MLContext._`.
   */
  implicit def sparkContextToMLContext(sc: SparkContext): MLContext = new MLContext(sc)
}
51 changes: 51 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/MLContextSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib

import org.apache.spark.mllib.MLContext._
import org.apache.spark.mllib.util.LocalSparkContext
import org.scalatest.FunSuite
import com.google.common.io.Files
import java.io.File
import com.google.common.base.Charsets
import org.apache.spark.mllib.linalg.Vectors

class MLContextSuite extends FunSuite with LocalSparkContext {
  test("libSVMFile") {
    // Two LIBSVM records with one-based feature indices 1..6.
    val lines =
      """
        |1 1:1.0 3:2.0 5:3.0
        |0 2:4.0 4:5.0 6:6.0
      """.stripMargin
    val tempDir = Files.createTempDir()
    val file = new File(tempDir.getPath, "part-00000")
    Files.write(lines, file, Charsets.US_ASCII)
    try {
      val points = sc.libSVMFile(tempDir.toURI.toString, 6).collect()
      assert(points.length === 2)
      // Indices in the parsed vectors must be zero-based (shifted down by one).
      assert(points(0).label === 1.0)
      assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
      assert(points(1).label === 0.0)
      assert(points(1).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
    } finally {
      // Best-effort cleanup that also runs when an assertion above fails.
      // File.delete() signals failure via its Boolean return value and does not
      // throw, so no catch clause (previously a silent `case _: Throwable`) is needed.
      file.delete()
      tempDir.delete()
    }
  }
}

0 comments on commit 78c4671

Please sign in to comment.