diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 883ddcf74999e..ddb78d3903049 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -45,7 +45,8 @@ class DecisionTree(val strategy : Strategy) { for (level <- 0 until maxDepth){ //Find best split for all nodes at a level val numNodes= scala.math.pow(2,level).toInt - val bestSplits = DecisionTree.findBestSplits(input, strategy, level, filters,splits,bins) + //TODO: Change the input parent impurities values + val bestSplits = DecisionTree.findBestSplits(input, Array(0.0), strategy, level, filters,splits,bins) //TODO: update filters and decision tree model } @@ -60,6 +61,7 @@ object DecisionTree extends Serializable { Returns an Array[Split] of optimal splits for all nodes at a given level @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree + @param parentImpurities Impurities for all parent nodes for the current level @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for construction the DecisionTree @param level Level of the tree @param filters Filter for all nodes at a given level @@ -70,13 +72,14 @@ object DecisionTree extends Serializable { */ def findBestSplits( input : RDD[LabeledPoint], + parentImpurities : Array[Double], strategy: Strategy, level: Int, filters : Array[List[Filter]], splits : Array[Array[Split]], bins : Array[Array[Bin]]) : Array[Split] = { - //TODO: Move these calculations outside + //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt println("numNodes = " + numNodes) //Find the number of features by looking at the first sample @@ -118,6 +121,7 @@ object DecisionTree extends Serializable { /*Finds the right bin for the given feature*/ def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = { + println("finding bin for labeled point " + labeledPoint.features(featureIndex)) //TODO: Do binary search for (binIndex <- 0 until strategy.numSplits) { val bin = bins(featureIndex)(binIndex) @@ -134,7 +138,7 @@ object DecisionTree extends Serializable { } /*Finds bins for all nodes (and all features) at a given level - k features, l nodes + k features, l nodes (level = log2(l)) Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. ,blk Denotes invalid sample for tree by noting bin for feature 1 as -1 */ @@ -167,33 +171,53 @@ object DecisionTree extends Serializable { } /* - Performs a sequential aggreation over a partition + Performs a sequential aggregation over a partition. - @param agg Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification - and 3*numSplits*numFeatures*numNodes for regression + for p bins, k features, l nodes (level = log2(l)) storage is of the form: + b111_left_count,b111_right_count, .... , bpk1_left_count, bpk1_right_count, .... , bpkl_left_count, bpkl_right_count + + @param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification @param arr Array[Double] of size 1+(numFeatures*numNodes) - @return Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification - and 3*numSplits*numFeatures*numNodes for regression + @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification */ def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { + //TODO: Requires logic for regressions for (node <- 0 until numNodes) { val validSignalIndex = 1+numFeatures*node val isSampleValidForNode = if(arr(validSignalIndex) != -1) true else false - if(isSampleValidForNode) { + if(isSampleValidForNode){ + val label = arr(0) for (feature <- 0 until numFeatures){ val arrShift = 1 + numFeatures*node - val aggShift = numSplits*numFeatures*node + val aggShift = 2*numSplits*numFeatures*node val arrIndex = arrShift + feature - val aggIndex = aggShift + feature*numSplits + arr(arrIndex).toInt - agg(aggIndex) = agg(aggIndex) + 1 + val aggIndex = aggShift + 2*feature*numSplits + arr(arrIndex).toInt*2 + label match { + case(0.0) => agg(aggIndex) = agg(aggIndex) + 1 + case(1.0) => agg(aggIndex+1) = agg(aggIndex+1) + 1 + } } } } agg } - def binCombOp(par1 : Array[Double], par2: Array[Double]) : Array[Double] = { - par1 + //TODO: This length if different for regression + val binAggregateLength = 2*numSplits * numFeatures * numNodes + println("binAggregageLength = " + binAggregateLength) + + /*Combines the aggregates from partitions + @param agg1 Array containing aggregates from one or more partitions + @param agg2 Array contianing aggregates from one or more partitions + + @return Combined aggregate from agg1 and agg2 + */ + def binCombOp(agg1 : Array[Double], agg2: Array[Double]) : Array[Double] = { + val combinedAggregate = new Array[Double](binAggregateLength) + for (index <- 0 until binAggregateLength){ + combinedAggregate(index) = agg1(index) + agg2(index) + } + combinedAggregate } println("input = " + input.count) @@ -201,15 +225,125 @@ object DecisionTree extends Serializable { println("binMappedRDD.count = " + binMappedRDD.count) //calculate bin aggregates - val binAggregates = binMappedRDD.aggregate(Array.fill[Double](numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) - - //find best split + val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) println("binAggregates.length = " + binAggregates.length) + binAggregates.foreach(x => println(x)) + + + def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = { + + val left0Count = leftNodeAgg(featureIndex)(2 * index) + val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) + val leftCount = left0Count + left1Count + println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) + val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) + val right0Count = rightNodeAgg(featureIndex)(2 * index) + val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) + val rightCount = right0Count + right1Count + println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) + val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity + + } + + /* + Extracts left and right split aggregates + + @param binData Array[Double] of size 2*numFeatures*numSplits + @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], Array[Double]) where + each array is of size(numFeature,2*(numSplits-1)) + */ + def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) + println("binData.length = " + binData.length) + println("binData.sum = " + binData.sum) + for (featureIndex <- 0 until numFeatures) { + println("featureIndex = " + featureIndex) + val shift = 2*featureIndex*numSplits + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + println("binData(shift + 0) = " + binData(shift + 0)) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + println("binData(shift + 1) = " + binData(shift + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) + println(binData(shift + (2 * (numSplits - 1)))) + rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) + println(binData(shift + (2 * (numSplits - 1)) + 1)) + for (splitIndex <- 1 until numSplits - 1) { + println("splitIndex = " + splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex) + = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) + leftNodeAgg(featureIndex)(2 * splitIndex + 1) + = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) + rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex)) + = binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex)) + rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1) + = binData(shift + (2 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1) + } + } + (leftNodeAgg, rightNodeAgg) + } + + def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double): Array[Array[Double]] = { + + val gains = Array.ofDim[Double](numFeatures, numSplits - 1) + + for (featureIndex <- 0 until numFeatures) { + for (index <- 0 until numSplits -1) { + println("splitIndex = " + index) + gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity) + } + } + gains + } + + /* + Find the best split for a node given bin aggregate data + + @param binData Array[Double] of size 2*numSplits*numFeatures + */ + def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : Split = { + println("node impurity = " + nodeImpurity) + val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) + + println("gains.size = " + gains.size) + println("gains(0).size = " + gains(0).size) + + val (bestFeatureIndex,bestSplitIndex) = { + var bestFeatureIndex = 0 + var bestSplitIndex = 0 + var maxGain = Double.MinValue + for (featureIndex <- 0 until numFeatures) { + for (splitIndex <- 0 until numSplits - 1){ + val gain = gains(featureIndex)(splitIndex) + println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) + if(gain > maxGain) { + maxGain = gain + bestFeatureIndex = featureIndex + bestSplitIndex = splitIndex + println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex + ", maxGain = " + maxGain) + } + } + } + (bestFeatureIndex,bestSplitIndex) + } + + splits(bestFeatureIndex)(bestSplitIndex) + } + //Calculate best splits for all nodes at a given level val bestSplits = new Array[Split](numNodes) for (node <- 0 until numNodes){ - val binsForNode = binAggregates.slice(node,numSplits*node) + val shift = 2*node*numSplits*numFeatures + val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) + val parentNodeImpurity = parentImpurities(node/2) + bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) } bestSplits diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 0e8c9ba850e4f..e886c40901b45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -69,7 +69,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length==99) assert(bins(0).length==100) println(splits(1)(98)) - DecisionTree.findBestSplits(rdd,strategy,0,Array[List[Filter]](),splits,bins) + DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) } }