Skip to content

Commit

Permalink
update DecisionTree to use RDD[Vector]
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Apr 2, 2014
1 parent 11999c7 commit c26c4fc
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ class DenseVector(val values: Array[Double]) extends Vector {
override def toArray: Array[Double] = values

private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)

override def apply(i: Int) = values(i)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
* A class that implements a decision tree algorithm for classification and regression. It
Expand Down Expand Up @@ -295,7 +296,7 @@ object DecisionTree extends Serializable with Logging {
val numNodes = scala.math.pow(2, level).toInt
logDebug("numNodes = " + numNodes)
// Find the number of features by looking at the first sample.
val numFeatures = input.first().features.length
val numFeatures = input.first().features.size
logDebug("numFeatures = " + numFeatures)
val numBins = bins(0).length
logDebug("numBins = " + numBins)
Expand Down Expand Up @@ -902,7 +903,7 @@ object DecisionTree extends Serializable with Logging {
val count = input.count()

// Find the number of features by looking at the first sample
val numFeatures = input.take(1)(0).features.length
val numFeatures = input.take(1)(0).features.size

val maxBins = strategy.maxBins
val numBins = if (maxBins <= count) maxBins else count.toInt
Expand Down Expand Up @@ -1116,7 +1117,7 @@ object DecisionTree extends Serializable with Logging {
sc.textFile(dir).map { line =>
val parts = line.trim().split(",")
val label = parts(0).toDouble
val features = parts.slice(1,parts.length).map(_.toDouble)
val features = Vectors.dense(parts.slice(1,parts.length).map(_.toDouble))
LabeledPoint(label, features)
}
}
Expand All @@ -1127,7 +1128,7 @@ object DecisionTree extends Serializable with Logging {
*/
private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
threshold: Double = 0.5): Double = {
def predictedValue(features: Array[Double]) = {
def predictedValue(features: Vector) = {
if (model.predict(features) < threshold) 0.0 else 1.0
}
val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector

/**
* Model to store the decision tree parameters
Expand All @@ -33,7 +34,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
* @param features vector representing a single data point
* @return Double prediction from the trained model
*/
def predict(features: Array[Double]): Double = {
def predict(features: Vector): Double = {
topNode.predictIfLeaf(features)
}

Expand All @@ -43,7 +44,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
* @param features RDD representing data points to be predicted
* @return RDD[Double] where each entry contains the corresponding prediction
*/
def predict(features: RDD[Array[Double]]): RDD[Double] = {
def predict(features: RDD[Vector]): RDD[Double] = {
features.map(x => predict(x))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.Logging
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vector

/**
* Node in a decision tree
Expand Down Expand Up @@ -54,8 +55,8 @@ class Node (
logDebug("stats = " + stats)
logDebug("predict = " + predict)
if (!isLeaf) {
val leftNodeIndex = id*2 + 1
val rightNodeIndex = id*2 + 2
val leftNodeIndex = id * 2 + 1
val rightNodeIndex = id * 2 + 2
leftNode = Some(nodes(leftNodeIndex))
rightNode = Some(nodes(rightNodeIndex))
leftNode.get.build(nodes)
Expand All @@ -68,7 +69,7 @@ class Node (
* @param feature feature vector of the data point to predict
* @return predicted value
*/
def predictIfLeaf(feature: Array[Double]) : Double = {
def predictIfLeaf(feature: Vector) : Double = {
if (isLeaf) {
predict
} else{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.mllib.tree.model.Filter
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vectors

class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {

Expand Down Expand Up @@ -396,7 +397,7 @@ object DecisionTreeSuite {
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i))
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
arr(i) = lp
}
arr
Expand All @@ -405,7 +406,7 @@ object DecisionTreeSuite {
def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i))
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
arr(i) = lp
}
arr
Expand All @@ -415,9 +416,9 @@ object DecisionTreeSuite {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
if (i < 600){
arr(i) = new LabeledPoint(1.0,Array(0.0,1.0))
arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
} else {
arr(i) = new LabeledPoint(0.0,Array(1.0,0.0))
arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0))
}
}
arr
Expand Down

0 comments on commit c26c4fc

Please sign in to comment.