diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 21989d9a025a0..2cea58cd3fd22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -151,6 +151,8 @@ class DenseVector(val values: Array[Double]) extends Vector { override def toArray: Array[Double] = values private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) + + override def apply(i: Int) = values(i) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 33205b919db8f..dee9594a9dd79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * A class that implements a decision tree algorithm for classification and regression. It @@ -295,7 +296,7 @@ object DecisionTree extends Serializable with Logging { val numNodes = scala.math.pow(2, level).toInt logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. - val numFeatures = input.first().features.length + val numFeatures = input.first().features.size logDebug("numFeatures = " + numFeatures) val numBins = bins(0).length logDebug("numBins = " + numBins) @@ -902,7 +903,7 @@ object DecisionTree extends Serializable with Logging { val count = input.count() // Find the number of features by looking at the first sample - val numFeatures = input.take(1)(0).features.length + val numFeatures = input.take(1)(0).features.size val maxBins = strategy.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt @@ -1116,7 +1117,7 @@ object DecisionTree extends Serializable with Logging { sc.textFile(dir).map { line => val parts = line.trim().split(",") val label = parts(0).toDouble - val features = parts.slice(1,parts.length).map(_.toDouble) + val features = Vectors.dense(parts.slice(1,parts.length).map(_.toDouble)) LabeledPoint(label, features) } } @@ -1127,7 +1128,7 @@ object DecisionTree extends Serializable with Logging { */ private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], threshold: Double = 0.5): Double = { - def predictedValue(features: Array[Double]) = { + def predictedValue(features: Vector) = { if (model.predict(features) < threshold) 0.0 else 1.0 } val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a8bbf21daec01..a6dca84a2ce09 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector /** * Model to store the decision tree parameters @@ -33,7 +34,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param features array representing a single data point * @return Double prediction from the trained model */ - def predict(features: Array[Double]): Double = { + def predict(features: Vector): Double = { topNode.predictIfLeaf(features) } @@ -43,7 +44,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param features RDD representing data points to be predicted * @return RDD[Int] where each entry contains the corresponding prediction */ - def predict(features: RDD[Array[Double]]): RDD[Double] = { + def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index ea4693c5c2f4e..aac3f9ce308f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.Logging import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.linalg.Vector /** * Node in a decision tree @@ -54,8 +55,8 @@ class Node ( logDebug("stats = " + stats) logDebug("predict = " + predict) if (!isLeaf) { - val leftNodeIndex = id*2 + 1 - val rightNodeIndex = id*2 + 2 + val leftNodeIndex = id * 2 + 1 + val rightNodeIndex = id * 2 + 2 leftNode = Some(nodes(leftNodeIndex)) rightNode = Some(nodes(rightNodeIndex)) leftNode.get.build(nodes) @@ -68,7 +69,7 @@ class Node ( * @param feature feature value * @return predicted value */ - def predictIfLeaf(feature: Array[Double]) : Double = { + def predictIfLeaf(feature: Vector) : Double = { if (isLeaf) { predict } else{ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 4349c7000a0ae..350130c914f26 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.linalg.Vectors class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { @@ -396,7 +397,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ - val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i)) + val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } arr @@ -405,7 +406,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ - val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i)) + val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) arr(i) = lp } arr @@ -415,9 +416,9 @@ object DecisionTreeSuite { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ if (i < 600){ - arr(i) = new LabeledPoint(1.0,Array(0.0,1.0)) + arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { - arr(i) = new LabeledPoint(0.0,Array(1.0,0.0)) + arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)) } } arr