Skip to content

Commit

Permalink
basic stump working
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent 8bca1e2 commit 0012a77
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 19 deletions.
170 changes: 152 additions & 18 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class DecisionTree(val strategy : Strategy) {
for (level <- 0 until maxDepth){
//Find best split for all nodes at a level
val numNodes= scala.math.pow(2,level).toInt
val bestSplits = DecisionTree.findBestSplits(input, strategy, level, filters,splits,bins)
//TODO: Change the input parent impurities values
val bestSplits = DecisionTree.findBestSplits(input, Array(0.0), strategy, level, filters,splits,bins)
//TODO: update filters and decision tree model
}

Expand All @@ -60,6 +61,7 @@ object DecisionTree extends Serializable {
Returns an Array[Split] of optimal splits for all nodes at a given level
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
@param parentImpurities Impurities for all parent nodes for the current level
@param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
@param level Level of the tree
@param filters Filter for all nodes at a given level
Expand All @@ -70,13 +72,14 @@ object DecisionTree extends Serializable {
*/
def findBestSplits(
input : RDD[LabeledPoint],
parentImpurities : Array[Double],
strategy: Strategy,
level: Int,
filters : Array[List[Filter]],
splits : Array[Array[Split]],
bins : Array[Array[Bin]]) : Array[Split] = {

//TODO: Move these calculations outside
//Common calculations for multiple nested methods
val numNodes = scala.math.pow(2, level).toInt
println("numNodes = " + numNodes)
//Find the number of features by looking at the first sample
Expand Down Expand Up @@ -118,6 +121,7 @@ object DecisionTree extends Serializable {

/*Finds the right bin for the given feature*/
def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = {
println("finding bin for labeled point " + labeledPoint.features(featureIndex))
//TODO: Do binary search
for (binIndex <- 0 until strategy.numSplits) {
val bin = bins(featureIndex)(binIndex)
Expand All @@ -134,7 +138,7 @@ object DecisionTree extends Serializable {
}

/*Finds bins for all nodes (and all features) at a given level
k features, l nodes
k features, l nodes (level = log2(l))
Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. ,blk
Denotes invalid sample for tree by noting bin for feature 1 as -1
*/
Expand Down Expand Up @@ -167,49 +171,179 @@ object DecisionTree extends Serializable {
}

/*
Performs a sequential aggreation over a partition
Performs a sequential aggregation over a partition.
@param agg Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification
and 3*numSplits*numFeatures*numNodes for regression
for p bins, k features, l nodes (level = log2(l)) storage is of the form:
b111_left_count,b111_right_count, .... , bpk1_left_count, bpk1_right_count, .... , bpkl_left_count, bpkl_right_count
@param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification
@param arr Array[Double] of size 1+(numFeatures*numNodes)
@return Array[Double] storing aggregate calculation of size numSplits*numFeatures*numNodes for classification
and 3*numSplits*numFeatures*numNodes for regression
@return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification
*/
def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = {
//TODO: Requires logic for regressions
for (node <- 0 until numNodes) {
val validSignalIndex = 1+numFeatures*node
val isSampleValidForNode = if(arr(validSignalIndex) != -1) true else false
if(isSampleValidForNode) {
if(isSampleValidForNode){
val label = arr(0)
for (feature <- 0 until numFeatures){
val arrShift = 1 + numFeatures*node
val aggShift = numSplits*numFeatures*node
val aggShift = 2*numSplits*numFeatures*node
val arrIndex = arrShift + feature
val aggIndex = aggShift + feature*numSplits + arr(arrIndex).toInt
agg(aggIndex) = agg(aggIndex) + 1
val aggIndex = aggShift + 2*feature*numSplits + arr(arrIndex).toInt*2
label match {
case(0.0) => agg(aggIndex) = agg(aggIndex) + 1
case(1.0) => agg(aggIndex+1) = agg(aggIndex+1) + 1
}
}
}
}
agg
}

def binCombOp(par1 : Array[Double], par2: Array[Double]) : Array[Double] = {
par1
//TODO: This length is different for regression
val binAggregateLength = 2*numSplits * numFeatures * numNodes
println("binAggregageLength = " + binAggregateLength)

/*Merges two partial bin aggregates by adding them position by position.
@param agg1 Array containing aggregates from one or more partitions
@param agg2 Array containing aggregates from one or more partitions
@return Newly allocated array of length binAggregateLength whose entry i is agg1(i) + agg2(i);
        neither input array is modified
*/
def binCombOp(agg1 : Array[Double], agg2: Array[Double]) : Array[Double] = {
  val merged = new Array[Double](binAggregateLength)
  var position = 0
  while (position < binAggregateLength) {
    merged(position) = agg1(position) + agg2(position)
    position += 1
  }
  merged
}

println("input = " + input.count)
val binMappedRDD = input.map(x => findBinsForLevel(x))
println("binMappedRDD.count = " + binMappedRDD.count)
//calculate bin aggregates

val binAggregates = binMappedRDD.aggregate(Array.fill[Double](numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp)

//find best split
val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp)
println("binAggregates.length = " + binAggregates.length)
binAggregates.foreach(x => println(x))


/*
Computes the impurity reduction obtained by one candidate split of a node.
Counts are stored pairwise per split: slot 2*index and slot 2*index+1 hold the two
per-label counts that are handed, in that order, to strategy.impurity.calculate.
@param leftNodeAgg per-feature counts for the left side of every candidate split
@param featureIndex feature whose candidate split is evaluated
@param index position of the candidate split within the feature
@param rightNodeAgg per-feature counts for the right side of every candidate split
@param topImpurity impurity of the node before it is split
@return topImpurity minus the left and right impurities, each weighted by that side's
        share of the total count. No guard is applied for an empty node: when both
        sides have a zero count the weights are 0/0 and the result is NaN.
*/
def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = {
  // Offsets of the two per-label counts belonging to this split.
  val firstSlot = 2 * index
  val secondSlot = firstSlot + 1

  // Left side: counts, debug trace, impurity.
  val leftRow = leftNodeAgg(featureIndex)
  val leftZeros = leftRow(firstSlot)
  val leftOnes = leftRow(secondSlot)
  val leftTotal = leftZeros + leftOnes
  println("left0count = " + leftZeros + ", left1count = " + leftOnes + ", leftCount = " + leftTotal)
  val impurityOnLeft = strategy.impurity.calculate(leftZeros, leftOnes)

  // Right side: same steps, performed only after the left side is done.
  val rightRow = rightNodeAgg(featureIndex)
  val rightZeros = rightRow(firstSlot)
  val rightOnes = rightRow(secondSlot)
  val rightTotal = rightZeros + rightOnes
  println("right0count = " + rightZeros + ", right1count = " + rightOnes + ", rightCount = " + rightTotal)
  val impurityOnRight = strategy.impurity.calculate(rightZeros, rightOnes)

  // Each side is weighted by its share of the samples reaching the node.
  val sampleTotal = leftTotal + rightTotal
  val leftShare = leftTotal / sampleTotal
  val rightShare = rightTotal / sampleTotal

  topImpurity - leftShare * impurityOnLeft - rightShare * impurityOnRight
}

/*
Turns the per-bin counts of one node into cumulative left/right counts for every candidate split.
For feature f and split s, leftNodeAgg(f)(2*s) and leftNodeAgg(f)(2*s+1) hold the two per-label
counts summed over bins 0..s, while rightNodeAgg(f)(2*s) and rightNodeAgg(f)(2*s+1) hold the same
counts summed over bins s+1..numSplits-1.
@param binData Array[Double] of size 2*numFeatures*numSplits
@return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Double]], Array[Array[Double]]) where
each array is of size (numFeatures, 2*(numSplits-1))
*/
def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
  val aggWidth = 2 * (numSplits - 1)
  val leftNodeAgg = Array.ofDim[Double](numFeatures, aggWidth)
  val rightNodeAgg = Array.ofDim[Double](numFeatures, aggWidth)
  println("binData.length = " + binData.length)
  println("binData.sum = " + binData.sum)

  (0 until numFeatures).foreach { featureIndex =>
    println("featureIndex = " + featureIndex)
    // Start of this feature's bins inside binData (two counts per bin).
    val base = 2 * featureIndex * numSplits
    val leftRow = leftNodeAgg(featureIndex)
    val rightRow = rightNodeAgg(featureIndex)

    // Seed the left running sums with the first bin.
    val firstBinZeros = binData(base + 0)
    leftRow(0) = firstBinZeros
    println("binData(shift + 0) = " + firstBinZeros)
    val firstBinOnes = binData(base + 1)
    leftRow(1) = firstBinOnes
    println("binData(shift + 1) = " + firstBinOnes)

    // Seed the right running sums with the last bin.
    val lastSplitSlot = 2 * (numSplits - 2)
    val lastBinSlot = base + (2 * (numSplits - 1))
    val lastBinZeros = binData(lastBinSlot)
    rightRow(lastSplitSlot) = lastBinZeros
    println(lastBinZeros)
    val lastBinOnes = binData(lastBinSlot + 1)
    rightRow(lastSplitSlot + 1) = lastBinOnes
    println(lastBinOnes)

    // Left sums grow from the first bin upward; right sums grow from the last bin downward.
    (1 until numSplits - 1).foreach { splitIndex =>
      println("splitIndex = " + splitIndex)
      val leftSlot = 2 * splitIndex
      leftRow(leftSlot) = binData(base + leftSlot) + leftRow(leftSlot - 2)
      leftRow(leftSlot + 1) = binData(base + leftSlot + 1) + leftRow(leftSlot - 1)

      // The bin just right of the target split, and the already-filled split it extends.
      val sourceSlot = 2 * (numSplits - 1 - splitIndex)
      val targetSlot = 2 * (numSplits - 2 - splitIndex)
      rightRow(targetSlot) = binData(base + sourceSlot) + rightRow(sourceSlot)
      rightRow(targetSlot + 1) = binData(base + sourceSlot + 1) + rightRow(sourceSlot + 1)
    }
  }

  (leftNodeAgg, rightNodeAgg)
}

/*
Evaluates the gain of every (feature, split) candidate of one node.
@param leftNodeAgg per-feature counts for the left side of every candidate split
@param rightNodeAgg per-feature counts for the right side of every candidate split
@param nodeImpurity impurity value forwarded to calculateGainForSplit for every candidate
@return Array of size (numFeatures, numSplits-1) where entry (f, s) is the gain of split s on feature f
*/
def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double): Array[Array[Double]] = {
  val candidatesPerFeature = numSplits - 1
  val gainTable = Array.ofDim[Double](numFeatures, candidatesPerFeature)

  // Feature-major traversal: all splits of one feature are processed before the next feature.
  for (featureIndex <- 0 until numFeatures; index <- 0 until candidatesPerFeature) {
    println("splitIndex = " + index)
    gainTable(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
  }

  gainTable
}

/*
Picks the split with the largest gain for a node, given the node's bin aggregate data.
Candidates are scanned feature by feature and split by split; a candidate replaces the current
best only when its gain is strictly greater, so the earliest candidate wins ties. If no gain
exceeds Double.MinValue the result is the split at position (0, 0).
@param binData Array[Double] of size 2*numSplits*numFeatures
@param nodeImpurity impurity value forwarded to the gain calculation
@return the entry of splits at the (feature, split) position with the largest gain
*/
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : Split = {
  println("node impurity = " + nodeImpurity)
  val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
  val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)

  println("gains.size = " + gains.size)
  println("gains(0).size = " + gains(0).size)

  // Every (feature, split) candidate in feature-major order.
  val candidates =
    for (featureIndex <- 0 until numFeatures; splitIndex <- 0 until numSplits - 1)
      yield (featureIndex, splitIndex)

  // Carry (feature, split, gain) of the best candidate seen so far through the scan.
  val (bestFeatureIndex, bestSplitIndex, _) =
    candidates.foldLeft((0, 0, Double.MinValue)) { case (incumbent, (featureIndex, splitIndex)) =>
      val gain = gains(featureIndex)(splitIndex)
      println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
      if (gain > incumbent._3) {
        println("bestFeatureIndex = " + featureIndex + ", bestSplitIndex = " + splitIndex + ", maxGain = " + gain)
        (featureIndex, splitIndex, gain)
      } else {
        incumbent
      }
    }

  splits(bestFeatureIndex)(bestSplitIndex)
}

//Calculate best splits for all nodes at a given level
val bestSplits = new Array[Split](numNodes)
for (node <- 0 until numNodes){
val binsForNode = binAggregates.slice(node,numSplits*node)
val shift = 2*node*numSplits*numFeatures
val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures)
val parentNodeImpurity = parentImpurities(node/2)
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
}

bestSplits
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(1)(98))
DecisionTree.findBestSplits(rdd,strategy,0,Array[List[Filter]](),splits,bins)
DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins)
}

}
Expand Down

0 comments on commit 0012a77

Please sign in to comment.