diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 575eb4e8d825f..883ddcf74999e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -37,7 +37,6 @@ class DecisionTree(val strategy : Strategy) {
     val (splits, bins) = DecisionTree.find_splits_bins(input, strategy)
 
     //TODO: Level-wise training of tree and obtain Decision Tree model
-
     val maxDepth = strategy.maxDepth
 
     val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1
@@ -55,8 +54,20 @@ class DecisionTree(val strategy : Strategy) {
 
 }
 
-object DecisionTree extends Logging {
+object DecisionTree extends Serializable {
+
+  /*
+  Returns an Array[Split] of optimal splits for all nodes at a given level
+
+  @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
+  @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
+  @param level Level of the tree
+  @param filters Filters for all nodes at a given level
+  @param splits possible splits for all features
+  @param bins possible bins for all features
+  @return Array[Split] instance with the best splits for all nodes at a given level.
+  */
   def findBestSplits(
       input : RDD[LabeledPoint],
       strategy: Strategy,
@@ -65,6 +76,16 @@ object DecisionTree extends Logging {
       splits : Array[Array[Split]],
       bins : Array[Array[Bin]]) : Array[Split] = {
 
+    //TODO: Move these calculations outside
+    val numNodes = scala.math.pow(2, level).toInt
+    println("numNodes = " + numNodes)
+    //Find the number of features by looking at the first sample
+    val numFeatures = input.take(1)(0).features.length
+    println("numFeatures = " + numFeatures)
+    val numSplits = strategy.numSplits
+    println("numSplits = " + numSplits)
+
+    /*Find the filters used before reaching the current node*/
     def findParentFilters(nodeIndex: Int): List[Filter] = {
       if (level == 0) {
         List[Filter]()
@@ -75,6 +96,10 @@ object DecisionTree extends Logging {
       }
     }
 
+    /*Find whether the sample is valid input for the current node.
+
+    In other words, does it pass through all the filters for the current node.
+    */
     def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
       for (filter <- parentFilters) {
@@ -91,48 +116,49 @@ object DecisionTree extends Logging {
       true
     }
 
+    /*Finds the right bin for the given feature*/
     def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
-      //TODO: Do binary search
       for (binIndex <- 0 until strategy.numSplits) {
         val bin = bins(featureIndex)(binIndex)
-        //TODO: Remove this requirement post basic functional testing
-        require(bin.lowSplit.feature == featureIndex)
-        require(bin.highSplit.feature == featureIndex)
+        //TODO: Remove this requirement post basic functional testing
         val lowThreshold = bin.lowSplit.threshold
         val highThreshold = bin.highSplit.threshold
         val features = labeledPoint.features
-        if ((lowThreshold < features(featureIndex)) & (highThreshold < features(featureIndex))) {
+        if ((lowThreshold <= features(featureIndex)) && (highThreshold > features(featureIndex))) {
           return binIndex
         }
       }
       throw new UnknownError("no bin was found.")
     }
 
-    def findBinsForLevel: Array[Double] = {
-      val numNodes = scala.math.pow(2, level).toInt
-      //Find the number of features by looking at the first sample
-      val numFeatures = input.take(1)(0).features.length
+    /*Finds bins for all nodes (and all features) at a given level
+     k features, l nodes
+     Storage: label, b11, b12, .., b1k, b21, b22, .., b2k, .., bl1, bl2, .., blk
+     An invalid sample for the tree is denoted by setting the bin for feature 1 to -1
+    */
+    def findBinsForLevel(labeledPoint : LabeledPoint) : Array[Double] = {
+
-      //TODO: Bit pack more by removing redundant label storage
       // calculating bin index and label per feature per node
-      val arr = new Array[Double](2 * numFeatures * numNodes)
+      val arr = new Array[Double](1 + (numFeatures * numNodes))
+      arr(0) = labeledPoint.label
       for (nodeIndex <- 0 until numNodes) {
         val parentFilters = findParentFilters(nodeIndex)
         //Find out whether the sample qualifies for the particular node
         val sampleValid = isSampleValid(parentFilters, labeledPoint)
-        val shift = 2 * numFeatures * nodeIndex
-        if (sampleValid) {
+        val shift = 1 + numFeatures * nodeIndex
+        if (!sampleValid) {
           //Add to invalid bin index -1
-          for (featureIndex <- shift until (shift + numFeatures) by 2) {
-            arr(featureIndex + 1) = -1
-            arr(featureIndex + 2) = labeledPoint.label
+          for (featureIndex <- 0 until numFeatures) {
+            arr(shift + featureIndex) = -1
+            //TODO: Break since marking one bin is sufficient
           }
         } else {
           for (featureIndex <- 0 until numFeatures) {
-            arr(shift + (featureIndex * 2) + 1) = findBin(featureIndex, labeledPoint)
-            arr(shift + (featureIndex * 2) + 2) = labeledPoint.label
+            //println("shift+featureIndex =" + (shift+featureIndex))
+            arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
           }
         }
@@ -140,30 +166,80 @@ object DecisionTree extends Logging {
       arr
     }
 
-    val binMappedRDD = input.map(labeledPoint => findBinsForLevel)
+    /*
+    Performs a sequential aggregation over a partition
+
+    @param agg Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification
+    and 3*numSplits*numFeatures*numNodes for regression
+    @param arr Array[Double] of size 1+(numFeatures*numNodes)
+    @return Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification
+    and 3*numSplits*numFeatures*numNodes for regression
+    */
+    def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = {
+      for (node <- 0 until numNodes) {
+        val validSignalIndex = 1 + numFeatures*node
+        val isSampleValidForNode = arr(validSignalIndex) != -1
+        if (isSampleValidForNode) {
+          for (feature <- 0 until numFeatures){
+            val arrShift = 1 + numFeatures*node
+            val aggShift = numSplits*numFeatures*node
+            val arrIndex = arrShift + feature
+            val aggIndex = aggShift + feature*numSplits + arr(arrIndex).toInt
+            agg(aggIndex) = agg(aggIndex) + 1
+          }
+        }
+      }
+      agg
+    }
+
+    /*Combines bin aggregates from two partitions by element-wise addition*/
+    def binCombOp(par1 : Array[Double], par2: Array[Double]) : Array[Double] = {
+      par1.zip(par2).map { case (a, b) => a + b }
+    }
+
+    println("input = " + input.count)
+    val binMappedRDD = input.map(x => findBinsForLevel(x))
+    println("binMappedRDD.count = " + binMappedRDD.count)
     //calculate bin aggregates
+
+    val binAggregates = binMappedRDD.aggregate(Array.fill[Double](numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp)
     //find best split
+    println("binAggregates.length = " + binAggregates.length)
 
-    Array[Split]()
+    val bestSplits = new Array[Split](numNodes)
+    for (node <- 0 until numNodes){
+      //Slice out this node's numSplits*numFeatures bin counts
+      val binsForNode = binAggregates.slice(numSplits*numFeatures*node, numSplits*numFeatures*(node + 1))
+    }
+
+    bestSplits
   }
 
+  /*
+  Returns splits and bins for decision tree calculation.
+
+  @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
+  @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
+  @return a tuple of (splits,bins) where splits is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an
+  Array[Array[Bin]] of size (numFeatures,numSplits)
+  */
   def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
     val numSplits = strategy.numSplits
-    logDebug("numSplits = " + numSplits)
+    println("numSplits = " + numSplits)
 
     //Calculate the number of samples for approximate quantile calculation
     //TODO: Justify this calculation
     val requiredSamples = numSplits*numSplits
     val count = input.count()
     val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
-    logDebug("fraction of data used for calculating quantiles = " + fraction)
+    println("fraction of data used for calculating quantiles = " + fraction)
 
     //sampled input for RDD calculation
     val sampledInput = input.sample(false, fraction, 42).collect()
     val numSamples = sampledInput.length
 
+    //TODO: Remove this requirement
     require(numSamples > numSplits, "length of input samples should be greater than numSplits")
 
     //Find the number of features by looking at the first sample
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala
index a7077f0914033..7f88053043e0a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/Strategy.scala
@@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree
 
 import org.apache.spark.mllib.tree.impurity.Impurity
 
-class Strategy (
+case class Strategy (
     val kind : String,
     val impurity : Impurity,
     val maxDepth : Int,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index fadebb0c203eb..4b6e679820f59 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.mllib.tree.impurity
 
-trait Impurity {
+trait Impurity extends Serializable {
 
   def calculate(c0 : Double, c1 : Double): Double
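Note on the aggregation above: findBinsForLevel emits one flat record per sample (label, then one bin index per node and feature, with -1 marking a sample that does not reach a node), and binSeqOp/binCombOp fold those records into a flat array of bin counts. The following is a minimal standalone sketch of that layout and index arithmetic, outside Spark; the object name, the helper names, the small constants, and the hand-made records are illustrative assumptions, not part of the patch.

```scala
// Standalone sketch of the flat-array bin aggregation used by findBestSplits.
// Layout assumptions: record = (label, bin per (node, feature)); -1 = sample
// did not reach the node; counts stored node-major, then feature, then bin.
object BinAggregationSketch {
  val numNodes = 2      // nodes at the current level (2^level)
  val numFeatures = 2
  val numSplits = 4

  // Mirrors binSeqOp: bump one bin count per valid (node, feature) pair.
  def seqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
    for (node <- 0 until numNodes) {
      if (arr(1 + numFeatures * node) != -1) {  // sample reached this node
        for (feature <- 0 until numFeatures) {
          val bin = arr(1 + numFeatures * node + feature).toInt
          val aggIndex = numSplits * numFeatures * node + feature * numSplits + bin
          agg(aggIndex) += 1
        }
      }
    }
    agg
  }

  // Mirrors a corrected binCombOp: partition results add element-wise.
  def combOp(a: Array[Double], b: Array[Double]): Array[Double] =
    a.zip(b).map { case (x, y) => x + y }

  def main(args: Array[String]): Unit = {
    // Two hand-made records: label, bins for node 0 (f0, f1), bins for node 1.
    val records = Seq(
      Array(1.0, 0.0, 3.0, -1.0, -1.0), // valid only for node 0
      Array(0.0, 2.0, 3.0, 1.0, 0.0))   // valid for both nodes
    val zero = Array.fill[Double](numSplits * numFeatures * numNodes)(0)
    val agg = records.map(r => seqOp(zero.clone(), r)).reduce(combOp)
    println(agg.mkString(","))
  }
}
```

This also shows why binCombOp must add the two arrays rather than return the first one: with the identity version, counts from all partitions but one would be dropped.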
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 22c6b6eca1876..0e8c9ba850e4f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -28,6 +28,7 @@ import org.jblas._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.impurity.Gini
+import org.apache.spark.mllib.tree.model.Filter
 
 class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
 
@@ -54,6 +55,21 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     assert(bins(0).length==100)
     println(splits(1)(98))
   }
+
+  test("stump"){
+    val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints()
+    assert(arr.length == 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy("regression",Gini,3,100,"sort")
+    val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
+    assert(splits.length==2)
+    assert(splits(0).length==99)
+    assert(bins.length==2)
+    assert(bins(0).length==100)
+    println(splits(1)(98))
+    DecisionTree.findBestSplits(rdd,strategy,0,Array[List[Filter]](),splits,bins)
+  }
+
 }
 
 object DecisionTreeSuite {
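For context on the test assertions above (99 splits and 100 bins per feature with numSplits = 100): find_splits_bins samples roughly numSplits * numSplits points, sorts each feature's sampled values, and takes equally spaced quantiles as thresholds under the "sort" strategy. The patch truncates before that code, so the sketch below is only a rough illustration under those assumptions; the function and object names and the stride arithmetic are mine, not the patch's.

```scala
// Rough illustration of equal-frequency ("sort") split finding for one
// feature. Produces numSplits - 1 interior thresholds, matching the suite's
// assert(splits(0).length == 99) for numSplits = 100.
object SortSplitsSketch {
  def thresholds(sampledValues: Array[Double], numSplits: Int): Array[Double] = {
    val sorted = sampledValues.sorted
    val stride = sorted.length.toDouble / numSplits // samples per bin
    Array.tabulate(numSplits - 1)(i => sorted(((i + 1) * stride).toInt))
  }

  def main(args: Array[String]): Unit = {
    val samples = Array.tabulate(1000)(i => i.toDouble) // hypothetical feature values
    println(thresholds(samples, 100).take(5).mkString(",")) // 10.0,20.0,30.0,...
  }
}
```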