diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index bba67bb5894a9..059a9336b5f9e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -50,23 +50,34 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log //Cache input RDD for speedup during multiple passes input.cache() + logDebug("algo = " + strategy.algo) + //Finding the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) logDebug("numSplits = " + bins(0).length) + + //Noting numBins for the input data strategy.numBins = bins(0).length + //The depth of the decision tree val maxDepth = strategy.maxDepth - + //The max number of nodes possible given the depth of the tree val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + //Initializing an array to hold filters applied to points for each node val filters = new Array[List[Filter]](maxNumNodes) + //The filter at the top node is an empty list filters(0) = List() + //Initializing an array to hold parent impurity calculations for each node val parentImpurities = new Array[Double](maxNumNodes) //Dummy value for top node (updated during first split calculation) - //parentImpurities(0) = Double.MinValue val nodes = new Array[Node](maxNumNodes) - logDebug("algo = " + strategy.algo) - + //The main idea here is to perform level-wise training of the decision tree nodes thus + // reducing the passes over the data from l to log2(l) where l is the total number of nodes. + // Each data sample is checked for validity w.r.t. each node at a given level -- i.e., + // the sample is only used for the split calculation at the node if the sample would have + // still survived the filters of the parent nodes. 
breakable { for (level <- 0 until maxDepth) { @@ -79,36 +90,41 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log level, filters, splits, bins) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - + //Extract info for nodes at the current level extractNodeInfo(nodeSplitStats, level, index, nodes) + //Extract info for nodes at the next lower level extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) logDebug("final best split = " + nodeSplitStats._1) } require(scala.math.pow(2, level) == splitsStatsForLevel.length) - + //Check whether all the nodes at the current level are leaves val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) - if (allLeaf) break + if (allLeaf) break //no more tree construction } } + //Initialize the top or root node of the tree val topNode = nodes(0) + //Build the full tree using the node info calculated in the level-wise best split calculations topNode.build(nodes) - val decisionTreeModel = { - return new DecisionTreeModel(topNode, strategy.algo) - } - return decisionTreeModel + //Return a decision tree model + return new DecisionTreeModel(topNode, strategy.algo) } - + /** + * Extract the decision tree node information for the given tree level and node index + */ private def extractNodeInfo( nodeSplitStats: (Split, InformationGainStats), - level: Int, index: Int, - nodes: Array[Node]) { + level: Int, + index: Int, + nodes: Array[Node]) + : Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 @@ -119,35 +135,37 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log nodes(nodeIndex) = node } + /** + * Extract the decision tree node information for the children of the node + */ private def extractInfoForLowerLevels( level: Int, index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], - filters: Array[List[Filter]]) { + 
filters: Array[List[Filter]]) + : Unit = { + // 0 corresponds to the left child node and 1 corresponds to the right child node. for (i <- 0 to 1) { - + //Calculating the index of the node from the node level and the index at the current level val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth - 1) { - val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity } else { nodeSplitStats._2.rightImpurity } - logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) + //noting the parent impurities parentImpurities(nodeIndex) = impurity + //noting the parent's filters for the child nodes val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) - for (filter <- filters(nodeIndex)) { logDebug("Filter = " + filter) } - } } }