From 98ec8d57a0a0897b093ced7e3284228ee21ce5f4 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sat, 21 Dec 2013 22:39:29 -0800
Subject: [PATCH] tree building and prediction logic

Signed-off-by: Manish Amde
---
 .../spark/mllib/tree/DecisionTree.scala       | 248 +++++++++---------
 .../spark/mllib/tree/DecisionTreeRunner.scala |  74 ++++++
 .../apache/spark/mllib/tree/Strategy.scala    |   8 +-
 .../mllib/tree/model/DecisionTreeModel.scala  |   6 +-
 .../tree/model/InformationGainStats.scala     |   2 +-
 .../apache/spark/mllib/tree/model/Node.scala  |  60 +++++
 6 files changed, 272 insertions(+), 126 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 7749bdd687d1f..d8ffa12030f8d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -24,9 +24,10 @@
 import org.apache.spark.{SparkContext, Logging}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.Split
 import org.apache.spark.mllib.tree.impurity.Gini
+import scala.util.control.Breaks._
 
-class DecisionTree(val strategy : Strategy) extends Logging {
+class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
 
   def train(input : RDD[LabeledPoint]) : DecisionTreeModel = {
 
@@ -36,6 +37,8 @@ class DecisionTree(val strategy : Strategy) extends Logging {
     //TODO: Find all splits and bins using quantiles including support for categorical features, single-pass
     //TODO: Think about broadcasting this
     val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)
+    logDebug("numBins = " + bins(0).length)
+    strategy.numBins = bins(0).length
 
     //TODO: Level-wise training of tree and obtain Decision Tree model
     val maxDepth = strategy.maxDepth
@@ -44,47 +47,86 @@
     val filters = new Array[List[Filter]](maxNumNodes)
     filters(0) = List()
     val parentImpurities = new Array[Double](maxNumNodes)
-    //Dummy value for top node (calculate from scratch during first split calculation)
-    parentImpurities(0) = Double.MinValue
-
-    for (level <- 0 until maxDepth){
-
-      println("#####################################")
-      println("level = " + level)
-      println("#####################################")
-
-      //Find best split for all nodes at a level
-      val numNodes= scala.math.pow(2,level).toInt
-      val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins)
-      for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){
-        for (i <- 0 to 1){
-          val nodeIndex = (scala.math.pow(2,level+1)).toInt - 1 + 2*index + i
-          if(level < maxDepth - 1){
-            val impurity = if (i == 0) nodeSplitStats._2.leftImpurity else nodeSplitStats._2.rightImpurity
-            println("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
-            parentImpurities(nodeIndex) = impurity
-            println("updating nodeIndex = " + nodeIndex)
-            filters(nodeIndex) = new Filter(nodeSplitStats._1, if(i == 0) -1 else 1) :: filters((nodeIndex-1)/2)
-            for (filter <- filters(nodeIndex)){
-              println(filter)
-            }
-          }
-        }
-        println("final best split = " + nodeSplitStats._1)
-      }
-      require(scala.math.pow(2,level)==splitsStatsForLevel.length)
-    }
-    //TODO: Extract decision tree model
-    return new DecisionTreeModel()
+    //Dummy value for top node (updated during first split calculation)
+    //parentImpurities(0) = Double.MinValue
+    val nodes = new Array[Node](maxNumNodes)
+
+    breakable {
+      for (level <- 0 until maxDepth){
+
+        logDebug("#####################################")
+        logDebug("level = " + level)
+        logDebug("#####################################")
+
+        //Find best split for all nodes at a level
+        val numNodes = scala.math.pow(2,level).toInt
+        val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters, splits, bins)
+
+        for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){
+
+          extractNodeInfo(nodeSplitStats, level, index, nodes)
+          extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters)
+          logDebug("final best split = " + nodeSplitStats._1)
+        }
+        require(scala.math.pow(2,level) == splitsStatsForLevel.length)
+
+        val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
+        logDebug("all leaf = " + allLeaf)
+        if (allLeaf) break
+      }
+    }
+
+    val topNode = nodes(0)
+    topNode.build(nodes)
+
+    val decisionTreeModel = {
+      new DecisionTreeModel(topNode)
+    }
+    return decisionTreeModel
   }
 
+  private def extractNodeInfo(nodeSplitStats: (Split, InformationGainStats), level: Int, index: Int, nodes: Array[Node]) {
+    val split = nodeSplitStats._1
+    val stats = nodeSplitStats._2
+    val nodeIndex = scala.math.pow(2, level).toInt - 1 + index
+    val predict = {
+      val leftSamples = nodeSplitStats._2.leftSamples.toDouble
+      val rightSamples = nodeSplitStats._2.rightSamples.toDouble
+      val totalSamples = leftSamples + rightSamples
+      leftSamples / totalSamples
+    }
+    val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
+    val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
+    logDebug("Node = " + node)
+    nodes(nodeIndex) = node
+  }
+
+  private def extractInfoForLowerLevels(level: Int, index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], filters: Array[List[Filter]]) {
+    for (i <- 0 to 1) {
+
+      val nodeIndex = (scala.math.pow(2, level + 1)).toInt - 1 + 2 * index + i
+
+      if (level < maxDepth - 1) {
+
+        val impurity = if (i == 0) nodeSplitStats._2.leftImpurity else nodeSplitStats._2.rightImpurity
+        logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
+        parentImpurities(nodeIndex) = impurity
+        filters(nodeIndex) = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) :: filters((nodeIndex - 1) / 2)
+        for (filter <- filters(nodeIndex)) {
+          logDebug("Filter = " + filter)
+        }
+
+      }
+    }
+  }
 }
 
-object DecisionTree extends Serializable {
+object DecisionTree extends Serializable with Logging {
 
   /*
   Returns an Array[Split] of optimal splits for all nodes at a given level
@@ -110,12 +152,12 @@ object DecisionTree extends Serializable {
 
     //Common calculations for multiple nested methods
     val numNodes = scala.math.pow(2, level).toInt
-    println("numNodes = " + numNodes)
+    logDebug("numNodes = " + numNodes)
     //Find the number of features by looking at the first sample
     val numFeatures = input.take(1)(0).features.length
-    println("numFeatures = " + numFeatures)
-    val numSplits = strategy.numSplits
-    println("numSplits = " + numSplits)
+    logDebug("numFeatures = " + numFeatures)
+    val numSplits = strategy.numBins
+    logDebug("numSplits = " + numSplits)
 
     /*Find the filters used before reaching the current code*/
     def findParentFilters(nodeIndex: Int): List[Filter] = {
@@ -136,7 +178,7 @@
     def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
 
       //Leaf
-      if (parentFilters.length == 0 ){
+      if ((level > 0) && (parentFilters.length == 0)) {
         return false
       }
 
@@ -156,9 +198,9 @@
 
     /*Finds the right bin for the given feature*/
     def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
-      //println("finding bin for labeled point " + labeledPoint.features(featureIndex))
+      //logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
       //TODO: Do binary search
-      for (binIndex <- 0 until strategy.numSplits) {
+      for (binIndex <- 0 until strategy.numBins) {
         val bin = bins(featureIndex)(binIndex)
         //TODO: Remove this requirement post basic functional
         val lowThreshold = bin.lowSplit.threshold
@@ -196,7 +238,7 @@
         }
       } else {
         for (featureIndex <- 0 until numFeatures) {
-          //println("shift+featureIndex =" + (shift+featureIndex))
+          //logDebug("shift+featureIndex =" + (shift+featureIndex))
           arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
         }
       }
@@ -239,7 +281,7 @@
 
     //TODO: This length if different for regression
     val binAggregateLength = 2*numSplits * numFeatures * numNodes
-    println("binAggregageLength = " + binAggregateLength)
+    logDebug("binAggregateLength = " + binAggregateLength)
 
     /*Combines the aggregates from partitions
     @param agg1 Array containing aggregates from one or more partitions
@@ -255,14 +297,14 @@
       combinedAggregate
     }
 
-    println("input = " + input.count)
+    logDebug("input = " + input.count)
     val binMappedRDD = input.map(x => findBinsForLevel(x))
-    println("binMappedRDD.count = " + binMappedRDD.count)
+    logDebug("binMappedRDD.count = " + binMappedRDD.count)
 
     //calculate bin aggregates
     val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp)
-    println("binAggregates.length = " + binAggregates.length)
-    //binAggregates.foreach(x => println(x))
+    logDebug("binAggregates.length = " + binAggregates.length)
+    //binAggregates.foreach(x => logDebug(x))
 
     def calculateGainForSplit(leftNodeAgg: Array[Array[Double]],
@@ -312,21 +354,21 @@
     def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
       val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
       val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
-      //println("binData.length = " + binData.length)
-      //println("binData.sum = " + binData.sum)
+      //logDebug("binData.length = " + binData.length)
+      //logDebug("binData.sum = " + binData.sum)
       for (featureIndex <- 0 until numFeatures) {
-        //println("featureIndex = " + featureIndex)
+        //logDebug("featureIndex = " + featureIndex)
         val shift = 2*featureIndex*numSplits
         leftNodeAgg(featureIndex)(0) = binData(shift + 0)
-        //println("binData(shift + 0) = " + binData(shift + 0))
+        //logDebug("binData(shift + 0) = " + binData(shift + 0))
         leftNodeAgg(featureIndex)(1) = binData(shift + 1)
-        //println("binData(shift + 1) = " + binData(shift + 1))
+        //logDebug("binData(shift + 1) = " + binData(shift + 1))
         rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1)))
-        //println(binData(shift + (2 * (numSplits - 1))))
+        //logDebug(binData(shift + (2 * (numSplits - 1))))
         rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1)
-        //println(binData(shift + (2 * (numSplits - 1)) + 1))
+        //logDebug(binData(shift + (2 * (numSplits - 1)) + 1))
         for (splitIndex <- 1 until numSplits - 1) {
-          //println("splitIndex = " + splitIndex)
+          //logDebug("splitIndex = " + splitIndex)
           leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2*splitIndex) +
             leftNodeAgg(featureIndex)(2 * splitIndex - 2)
           leftNodeAgg(featureIndex)(2 * splitIndex + 1)
@@ -347,7 +389,7 @@
 
       for (featureIndex <- 0 until numFeatures) {
         for (index <- 0 until numSplits -1) {
-          //println("splitIndex = " + index)
+          //logDebug("splitIndex = " + index)
           gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
         }
       }
@@ -360,12 +402,12 @@
     @param binData Array[Double] of size 2*numSplits*numFeatures
     */
     def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = {
-      println("node impurity = " + nodeImpurity)
+      logDebug("node impurity = " + nodeImpurity)
       val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
       val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
-      //println("gains.size = " + gains.size)
-      //println("gains(0).size = " + gains(0).size)
+      //logDebug("gains.size = " + gains.size)
+      //logDebug("gains(0).size = " + gains(0).size)
 
       val (bestFeatureIndex,bestSplitIndex, gainStats) = {
         var bestFeatureIndex = 0
@@ -378,13 +420,13 @@
         for (featureIndex <- 0 until numFeatures) {
           for (splitIndex <- 0 until numSplits - 1){
             val gainStats = gains(featureIndex)(splitIndex)
-            //println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
+            //logDebug("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
             if(gainStats.gain > bestGainStats.gain) {
              bestGainStats = gainStats
              bestFeatureIndex = featureIndex
              bestSplitIndex = splitIndex
-              //println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex)
-              //println( "gain stats = " + bestGainStats)
+              //logDebug("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex)
+              //logDebug( "gain stats = " + bestGainStats)
            }
          }
        }
@@ -400,9 +442,9 @@
       val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
       val shift = 2*node*numSplits*numFeatures
       val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures)
-      println("nodeImpurityIndex = " + nodeImpurityIndex)
+      logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
       val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
-      println("node impurity = " + parentNodeImpurity)
+      logDebug("node impurity = " + parentNodeImpurity)
       bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
     }
 
@@ -419,47 +461,42 @@
     */
   def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
-    val numSplits = strategy.numSplits
-    println("numSplits = " + numSplits)
+    val count = input.count()
+
+    //Find the number of features by looking at the first sample
+    val numFeatures = input.take(1)(0).features.length
+
+    val maxBins = strategy.maxBins
+    val numBins = if (maxBins <= count) maxBins else count.toInt
+    logDebug("numBins = " + numBins)
 
     //Calculate the number of sample for approximate quantile calculation
     //TODO: Justify this calculation
-    val requiredSamples = numSplits*numSplits
-    val count = input.count()
+    val requiredSamples = numBins*numBins
     val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
-    println("fraction of data used for calculating quantiles = " + fraction)
-
+    logDebug("fraction of data used for calculating quantiles = " + fraction)
 
     //sampled input for RDD calculation
     val sampledInput = input.sample(false, fraction, 42).collect()
     val numSamples = sampledInput.length
 
-    //Find the number of features by looking at the first sample
-    val numFeatures = input.take(1)(0).features.length
+    val stride : Double = numSamples.toDouble/numBins
+    logDebug("stride = " + stride)
 
     strategy.quantileCalculationStrategy match {
       case "sort" => {
-        val splits = Array.ofDim[Split](numFeatures,numSplits-1)
-        val bins = Array.ofDim[Bin](numFeatures,numSplits)
+        val splits = Array.ofDim[Split](numFeatures,numBins-1)
+        val bins = Array.ofDim[Bin](numFeatures,numBins)
 
         //Find all splits
         for (featureIndex <- 0 until numFeatures){
          val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
-          if (numSamples < numSplits) {
-            //TODO: Test this
-            println("numSamples = " + numSamples + ", less than numSplits = " + numSplits)
-            for (index <- 0 until numSplits-1) {
-              val split = new Split(featureIndex,featureSamples(index),"continuous")
-              splits(featureIndex)(index) = split
-            }
-          } else {
-            val stride : Double = numSamples.toDouble/numSplits
-            println("stride = " + stride)
-            for (index <- 0 until numSplits-1) {
-              val sampleIndex = (index+1)*stride.toInt
-              val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")
-              splits(featureIndex)(index) = split
-            }
+          val stride : Double = numSamples.toDouble/numBins
+          logDebug("stride = " + stride)
+          for (index <- 0 until numBins-1) {
+            val sampleIndex = ((index+1)*stride).toInt
+            val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")
+            splits(featureIndex)(index) = split
          }
        }
 
@@ -467,18 +504,18 @@
        for (featureIndex <- 0 until numFeatures){
          bins(featureIndex)(0)
            = new Bin(new DummyLowSplit("continuous"),splits(featureIndex)(0),"continuous")
-          for (index <- 1 until numSplits - 1){
+          for (index <- 1 until numBins - 1){
            val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),"continuous")
            bins(featureIndex)(index) = bin
          }
-          bins(featureIndex)(numSplits-1)
-            = new Bin(splits(featureIndex)(numSplits-3),new DummyHighSplit("continuous"),"continuous")
+          bins(featureIndex)(numBins-1)
+            = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit("continuous"),"continuous")
        }
 
        (splits,bins)
      }
      case "minMax" => {
-        (Array.ofDim[Split](numFeatures,numSplits),Array.ofDim[Bin](numFeatures,numSplits+2))
+        (Array.ofDim[Split](numFeatures,numBins),Array.ofDim[Bin](numFeatures,numBins+2))
      }
      case "approximateHistogram" => {
        throw new UnsupportedOperationException("approximate histogram not supported yet.")
@@ -487,37 +524,6 @@
     }
   }
 
-  def main(args: Array[String]) {
-
-    val sc = new SparkContext(args(0), "DecisionTree")
-    val data = loadLabeledData(sc, args(1))
-    val maxDepth = args(2).toInt
-
-    val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, numSplits = 569)
-    val model = new DecisionTree(strategy).train(data)
-
-    sc.stop()
-  }
-
-  /**
-   * Load labeled data from a file. The data format used here is
-   * <L>, <f1> <f2> ...
-   * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
-   *
-   * @param sc SparkContext
-   * @param dir Directory to the input data files.
-   * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
-   *         the label, and the second element represents the feature values (an array of Double).
-   */
-  def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
-    sc.textFile(dir).map { line =>
-      val parts = line.trim().split(",")
-      val label = parts(0).toDouble
-      val features = parts.slice(1,parts.length).map(_.toDouble)
-      LabeledPoint(label, features)
-    }
-  }
-
 }
\ No newline at end of file
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala
new file mode 100644
index 0000000000000..542a3d9c3b33d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.tree
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.mllib.tree.impurity.Gini
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
+
+object DecisionTreeRunner extends Logging {
+
+  def main(args: Array[String]) {
+
+    val sc = new SparkContext(args(0), "DecisionTree")
+    val data = loadLabeledData(sc, args(1))
+    val maxDepth = args(2).toInt
+    val maxBins = args(3).toInt
+
+    val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, maxBins = maxBins)
+    val model = new DecisionTree(strategy).train(data)
+
+    val accuracy = accuracyScore(model, data)
+    logDebug("accuracy = " + accuracy)
+
+    sc.stop()
+  }
+
+  /**
+   * Load labeled data from a file. The data format used here is
+   * <L>, <f1> <f2> ...
+   * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
+   *
+   * @param sc SparkContext
+   * @param dir Directory to the input data files.
+   * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
+   *         the label, and the second element represents the feature values (an array of Double).
+   */
+  def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
+    sc.textFile(dir).map { line =>
+      val parts = line.trim().split(",")
+      val label = parts(0).toDouble
+      val features = parts.slice(1,parts.length).map(_.toDouble)
+      LabeledPoint(label, features)
+    }
+  }
+
+  //TODO: Port this to a metrics package
+  def accuracyScore(model : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = {
+    val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
+    val count = data.count()
+    logDebug("correct prediction count = " + correctCount)
+    logDebug("data count = " + count)
+    correctCount.toDouble / count
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala
index 7f88053043e0a..c688a478ce0d2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala
@@ -18,11 +18,13 @@ package org.apache.spark.mllib.tree
 
 import org.apache.spark.mllib.tree.impurity.Impurity
 
-case class Strategy (
+class Strategy (
     val kind : String,
     val impurity : Impurity,
     val maxDepth : Int,
-    val numSplits : Int,
-    val quantileCalculationStrategy : String = "sort") {
+    val maxBins : Int,
+    val quantileCalculationStrategy : String = "sort") extends Serializable {
+
+  var numBins : Int = Int.MinValue
 
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index d0465d8c6fb6f..1d7c03289c407 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -16,6 +16,10 @@
  */
 package org.apache.spark.mllib.tree.model
 
-class DecisionTreeModel {
+import org.apache.spark.mllib.regression.LabeledPoint
+
+class DecisionTreeModel(val topNode : Node) extends Serializable {
+
+  def predict(features : Array[Double]) = if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0
 
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 4ca02beec03c0..60a4f99b7f806 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -21,7 +21,7 @@ class InformationGainStats(val gain : Double,
     val leftImpurity : Double,
     val leftSamples : Long,
     val rightImpurity : Double,
-    val rightSamples : Long) {
+    val rightSamples : Long) extends Serializable {
 
   override def toString = "gain = " + gain + ", impurity = " + impurity + ", left impurity = "
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
new file mode 100644
index 0000000000000..a9210e10ae48b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.Logging
+import org.apache.spark.mllib.regression.LabeledPoint
+
+class Node (val id : Int,
+    val predict : Double,
+    val isLeaf : Boolean,
+    val split : Option[Split],
+    var leftNode : Option[Node],
+    var rightNode : Option[Node],
+    val stats : Option[InformationGainStats]) extends Serializable with Logging {
+
+  override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", split = " + split + ", stats = " + stats
+
+  def build(nodes : Array[Node]) : Unit = {
+
+    logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt)
+    logDebug("stats = " + stats)
+    logDebug("predict = " + predict)
+    if (!isLeaf) {
+      val leftNodeIndex = id*2 + 1
+      val rightNodeIndex = id*2 + 2
+      leftNode = Some(nodes(leftNodeIndex))
+      rightNode = Some(nodes(rightNodeIndex))
+      leftNode.get.build(nodes)
+      rightNode.get.build(nodes)
+    }
+  }
+
+  def predictIfLeaf(feature : Array[Double]) : Double = {
+    if (isLeaf) {
+      predict
+    } else {
+      if (feature(split.get.feature) <= split.get.threshold) {
+        leftNode.get.predictIfLeaf(feature)
+      } else {
+        rightNode.get.predictIfLeaf(feature)
+      }
+    }
+  }
+
+}
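
A few sketches follow for the arithmetic this patch leans on. First, the tree layout: train, extractNodeInfo, extractInfoForLowerLevels and Node.build all assume a complete binary tree stored level-order in a flat array — root at index 0, the node at position index of a level at 2^level - 1 + index, children of node i at 2i + 1 and 2i + 2, and the parent at (i - 1) / 2. A minimal sketch of that indexing (the object and helper names are illustrative, not part of the patch):

```scala
// Flat-array layout of a complete binary tree, as assumed by DecisionTree.train
// and Node.build in this patch. Illustrative helper, not committed code.
object TreeIndexing {
  def nodeIndex(level: Int, indexInLevel: Int): Int = (1 << level) - 1 + indexInLevel
  def leftChild(i: Int): Int = 2 * i + 1
  def rightChild(i: Int): Int = 2 * i + 2
  def parent(i: Int): Int = (i - 1) / 2

  def main(args: Array[String]): Unit = {
    // The children of the node at (level, index) land at
    // 2^(level+1) - 1 + 2*index + i for i in {0, 1}, which is exactly
    // the nodeIndex that extractInfoForLowerLevels computes.
    val (level, index) = (2, 1)
    val node = nodeIndex(level, index)
    assert(leftChild(node) == nodeIndex(level + 1, 2 * index))
    assert(rightChild(node) == nodeIndex(level + 1, 2 * index + 1))
    // Filters are inherited through the parent slot: filters((nodeIndex-1)/2).
    assert(parent(leftChild(node)) == node)
  }
}
```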
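
Second, the bin aggregates: findBestSplits packs every (node, feature, bin, label) count into one Array[Double] of length 2*numBins*numFeatures*numNodes, so a single RDD.aggregate pass builds all histograms for a level. A sketch of the offset arithmetic implied by the node shift (2*node*numSplits*numFeatures) and the feature shift (2*featureIndex*numSplits) in the patch, plus the element-wise combine that binCombOp performs with a loop (helper names are illustrative):

```scala
// Offset of the (node, feature, bin, label) counter in the flat aggregate:
// nodes are the outermost stride, then features, then bins, with two adjacent
// slots per bin (count of label 0.0, count of label 1.0). Sketch only.
def aggIndex(node: Int, feature: Int, bin: Int, label: Int,
             numBins: Int, numFeatures: Int): Int =
  2 * numBins * numFeatures * node + 2 * numBins * feature + 2 * bin + label

// Combining partition aggregates is plain element-wise addition, which is
// what the patch's binCombOp loop computes.
def combine(agg1: Array[Double], agg2: Array[Double]): Array[Double] =
  agg1.zip(agg2).map { case (a, b) => a + b }
```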
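
Third, findBin still scans the bins linearly and carries a "TODO: Do binary search". Because a feature's bins are contiguous intervals (lowSplit.threshold, highSplit.threshold] sorted by threshold, that TODO could be met with an ordinary binary search; a sketch under that assumption (not the committed code):

```scala
// Binary search over contiguous, sorted bins: a value belongs to bin b iff
// low(b) < value <= high(b). Returns -1 when no bin matches, where the
// linear version signals an error instead. Illustrative sketch.
def findBinBinarySearch(value: Double,
                        low: Array[Double], high: Array[Double]): Int = {
  var lo = 0
  var hi = low.length - 1
  while (lo <= hi) {
    val mid = (lo + hi) >>> 1
    if (value <= low(mid)) hi = mid - 1       // strictly left of this bin
    else if (value > high(mid)) lo = mid + 1  // strictly right of this bin
    else return mid                           // low(mid) < value <= high(mid)
  }
  -1
}
```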
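
Fourth, the quantile splits: find_splits_bins with the "sort" strategy samples roughly numBins*numBins points, sorts each feature column, and reads thresholds off at stride numSamples/numBins, giving roughly equi-frequency bins. The kernel of that computation as a standalone sketch (assumes at least numBins samples per feature, which the numBins = min(maxBins, count) clamp is meant to guarantee):

```scala
// Equi-frequency split candidates for one continuous feature: sort the
// sampled values and take every (numSamples/numBins)-th order statistic.
// Standalone sketch of the "sort" strategy's inner loop.
def continuousSplits(samples: Array[Double], numBins: Int): Array[Double] = {
  val sorted = samples.sorted
  val stride = sorted.length.toDouble / numBins
  Array.tabulate(numBins - 1)(i => sorted(((i + 1) * stride).toInt))
}
```

With the numBins - 1 interior thresholds in hand, the bins are then assembled as (DummyLowSplit, splits(0)], (splits(i-1), splits(i)] for the interior, and (splits(numBins-2), DummyHighSplit] for the last bin.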
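
Fifth, the gain that binsToBestSplit maximizes is the usual impurity decrease: parent impurity minus the sample-weighted impurities of the two children. With the Gini impurity the runner uses and the two-class counts the aggregates provide, the core reduces to the following sketch (calculateGainForSplit additionally records the per-side counts and impurities it stores in InformationGainStats):

```scala
// Gini impurity of a two-class count pair, and the information gain of a
// candidate split. Sketch of the structure only, not the committed method.
def gini(c0: Double, c1: Double): Double = {
  val n = c0 + c1
  if (n == 0) 0.0
  else { val f0 = c0 / n; val f1 = c1 / n; 1.0 - f0 * f0 - f1 * f1 }
}

def giniGain(l0: Double, l1: Double, r0: Double, r1: Double): Double = {
  val left = l0 + l1; val right = r0 + r1; val total = left + right
  val parentImpurity = gini(l0 + r0, l1 + r1)
  parentImpurity - (left / total) * gini(l0, l1) - (right / total) * gini(r0, r1)
}
```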
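
Finally, end-to-end usage: DecisionTreeRunner.main reads its arguments positionally — master, input path, maxDepth, maxBins — and the same flow can be driven programmatically. A hedged example using only the APIs this patch introduces (the input path and parameter values are illustrative, not from the patch):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.{DecisionTree, DecisionTreeRunner, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

object DecisionTreeExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "DecisionTreeExample")
    // loadLabeledData expects "label,f1,f2,..." lines; the path is assumed.
    val data = DecisionTreeRunner.loadLabeledData(sc, "data/sample.csv")
    val strategy = new Strategy(kind = "classification", impurity = Gini,
      maxDepth = 3, maxBins = 100)
    val model = new DecisionTree(strategy).train(data)
    println("accuracy = " + DecisionTreeRunner.accuracyScore(model, data))
    sc.stop()
  }
}
```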