diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4fd030e3a3c05..a2a3dba213e7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.tree.model.Split import org.apache.spark.mllib.tree.impurity.Gini -class DecisionTree(val strategy : Strategy) { +class DecisionTree(val strategy : Strategy) extends Logging { def train(input : RDD[LabeledPoint]) : DecisionTreeModel = { @@ -42,20 +42,43 @@ class DecisionTree(val strategy : Strategy) { val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1 val filters = new Array[List[Filter]](maxNumNodes) + filters(0) = List() + val parentImpurities = new Array[Double](maxNumNodes) + //Dummy value for top node (calculate from scratch during first split calculation) + parentImpurities(0) = Double.MinValue for (level <- 0 until maxDepth){ + + println("#####################################") + println("level = " + level) + println("#####################################") + //Find best split for all nodes at a level val numNodes= scala.math.pow(2,level).toInt - //TODO: Change the input parent impurities values - val splits_stats_for_level = DecisionTree.findBestSplits(input, Array(2.0), strategy, level, filters,splits,bins) - for (tmp <- splits_stats_for_level){ - println("final best split = " + tmp._1) + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins) + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){ + for (i <- 0 to 1){ + val nodeIndex = (scala.math.pow(2,level+1)).toInt - 1 + 2*index + i + if(level < maxDepth - 1){ + val impurity = if (i == 0) nodeSplitStats._2.leftImpurity else nodeSplitStats._2.rightImpurity + println("nodeIndex = " + nodeIndex + ", impurity = " + impurity) + parentImpurities(nodeIndex) = impurity + println("updating nodeIndex = " + nodeIndex) + filters(nodeIndex) = new Filter(nodeSplitStats._1, if(i == 0) - 1 else 1) :: filters((nodeIndex-1)/2) + for (filter <- filters(nodeIndex)){ + println(filter) + } + } + } + println("final best split = " + nodeSplitStats._1) } - //TODO: update filters and decision tree model - require(scala.math.pow(2,level)==splits_stats_for_level.length) + require(scala.math.pow(2,level)==splitsStatsForLevel.length) + } + //TODO: Extract decision tree model + return new DecisionTreeModel() } @@ -99,7 +122,7 @@ object DecisionTree extends Serializable { if (level == 0) { List[Filter]() } else { - val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex + val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex //val parentFilterIndex = nodeFilterIndex / 2 //TODO: Check left or right filter filters(nodeFilterIndex) @@ -155,11 +178,11 @@ object DecisionTree extends Serializable { // calculating bin index and label per feature per node val arr = new Array[Double](1+(numFeatures * numNodes)) arr(0) = labeledPoint.label - for (nodeIndex <- 0 until numNodes) { - val parentFilters = findParentFilters(nodeIndex) + for (index <- 0 until numNodes) { + val parentFilters = findParentFilters(index) //Find out whether the sample qualifies for the particular node val sampleValid = isSampleValid(parentFilters, labeledPoint) - val shift = 1 + numFeatures * nodeIndex + val shift = 1 + numFeatures * index if (!sampleValid) { //Add to invalid bin index -1 for (featureIndex <- 0 until numFeatures) { @@ -251,22 +274,26 @@ object DecisionTree extends Serializable { val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) val rightCount = right0Count + right1Count + val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) - //println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - - - //println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) val leftWeight = leftCount.toDouble / (leftCount + rightCount) val rightWeight = rightCount.toDouble / (leftCount + rightCount) - val gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } - new InformationGainStats(gain,topImpurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) + new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) } @@ -339,7 +366,7 @@ object DecisionTree extends Serializable { var bestFeatureIndex = 0 var bestSplitIndex = 0 //Initialization with infeasible values - var bestGainStats = new InformationGainStats(-1.0,-1.0,-1.0,0,-1.0,0) + var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,0,-1.0,0) // var maxGain = Double.MinValue // var leftSamples = Long.MinValue // var rightSamples = Long.MinValue @@ -351,8 +378,8 @@ object DecisionTree extends Serializable { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex - println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex - + ", gain stats = " + bestGainStats) + //println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex) + //println( "gain stats = " + bestGainStats) } } } @@ -365,9 +392,12 @@ object DecisionTree extends Serializable { //Calculate best splits for all nodes at a given level val bestSplits = new Array[(Split, InformationGainStats)](numNodes) for (node <- 0 until numNodes){ + val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node val shift = 2*node*numSplits*numFeatures val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) - val parentNodeImpurity = parentImpurities(node/2) + println("nodeImpurityIndex = " + nodeImpurityIndex) + val parentNodeImpurity = parentImpurities(nodeImpurityIndex) + println("node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) } @@ -456,8 +486,9 @@ object DecisionTree extends Serializable { val sc = new SparkContext(args(0), "DecisionTree") val data = loadLabeledData(sc, args(1)) + val maxDepth = args(2).toInt - val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = 2, numSplits = 569) + val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, numSplits = 569) val model = new DecisionTree(strategy).train(data) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index a95f0431c6e8f..3396a015e7858 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -18,11 +18,15 @@ package org.apache.spark.mllib.tree.impurity object Gini extends Impurity { - def calculate(c0 : Double, c1 : Double): Double = { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - 1 - f0*f0 - f1*f1 - } + def calculate(c0 : Double, c1 : Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + 1 - f0*f0 - f1*f1 + } + } }