Skip to content

Commit

Permalink
more documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Mar 10, 2014
1 parent 794ff4d commit d1ef4f6
Showing 1 changed file with 39 additions and 21 deletions.
60 changes: 39 additions & 21 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,34 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log

//Cache input RDD for speedup during multiple passes
input.cache()
logDebug("algo = " + strategy.algo)

//Finding the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
logDebug("numSplits = " + bins(0).length)

//Noting numBins for the input data
strategy.numBins = bins(0).length

//The depth of the decision tree
val maxDepth = strategy.maxDepth

//The max number of nodes possible given the depth of the tree
val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1
//Initalizing an array to hold filters applied to points for each node
val filters = new Array[List[Filter]](maxNumNodes)
//The filter at the top node is an empty list
filters(0) = List()
//Initializing an array to hold parent impurity calculations for each node
val parentImpurities = new Array[Double](maxNumNodes)
//Dummy value for top node (updated during first split calculation)
//parentImpurities(0) = Double.MinValue
val nodes = new Array[Node](maxNumNodes)

logDebug("algo = " + strategy.algo)

//The main-idea here is to perform level-wise training of the decision tree nodes thus
// reducing the passes over the data from l to log2(l) where l is the total number of nodes.
// Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
// the sample is only used for the split calculation at the node if the sampled would have
// still survived the filters of the parent nodes.
breakable {
for (level <- 0 until maxDepth) {

Expand All @@ -79,36 +90,41 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
level, filters, splits, bins)

for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {

//Extract info for nodes at the current level
extractNodeInfo(nodeSplitStats, level, index, nodes)
//Extract info for nodes at the next lower level
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
filters)
logDebug("final best split = " + nodeSplitStats._1)

}
require(scala.math.pow(2, level) == splitsStatsForLevel.length)

//Check whether all the nodes at the current level at leaves
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf)
if (allLeaf) break
if (allLeaf) break //no more tree construction

}
}

//Initialize the top or root node of the tree
val topNode = nodes(0)
//Build the full tree using the node info calculated in the level-wise best split calculations
topNode.build(nodes)

val decisionTreeModel = {
return new DecisionTreeModel(topNode, strategy.algo)
}
return decisionTreeModel
//Return a decision tree model
return new DecisionTreeModel(topNode, strategy.algo)
}


/**
* Extract the decision tree node information for th given tree level and node index
*/
private def extractNodeInfo(
nodeSplitStats: (Split, InformationGainStats),
level: Int, index: Int,
nodes: Array[Node]) {
level: Int,
index: Int,
nodes: Array[Node])
: Unit = {

val split = nodeSplitStats._1
val stats = nodeSplitStats._2
Expand All @@ -119,35 +135,37 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
nodes(nodeIndex) = node
}

/**
* Extract the decision tree node information for the children of the node
*/
private def extractInfoForLowerLevels(
level: Int,
index: Int,
maxDepth: Int,
nodeSplitStats: (Split, InformationGainStats),
parentImpurities: Array[Double],
filters: Array[List[Filter]]) {
filters: Array[List[Filter]])
: Unit = {

// 0 corresponds to the left child node and 1 corresponds to the right child node.
for (i <- 0 to 1) {

//Calculating the index of the node from the node level and the index at the current level
val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i

if (level < maxDepth - 1) {

val impurity = if (i == 0) {
nodeSplitStats._2.leftImpurity
} else {
nodeSplitStats._2.rightImpurity
}

logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
//noting the parent impurities
parentImpurities(nodeIndex) = impurity
//noting the parents filters for the child nodes
val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)

for (filter <- filters(nodeIndex)) {
logDebug("Filter = " + filter)
}

}
}
}
Expand Down

0 comments on commit d1ef4f6

Please sign in to comment.