Skip to content

Commit

Permalink
minor cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent c0e522b commit f067d68
Showing 1 changed file with 7 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
logDebug("numSplits = " + bins(0).length)
strategy.numBins = bins(0).length

//TODO: Level-wise training of tree and obtain Decision Tree model
val maxDepth = strategy.maxDepth

val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1
Expand All @@ -62,7 +61,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
logDebug("#####################################")

//Find best split for all nodes at a level
val numNodes= scala.math.pow(2,level).toInt
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins)

for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){
Expand Down Expand Up @@ -105,7 +103,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
private def extractInfoForLowerLevels(level: Int, index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], filters: Array[List[Filter]]) {
for (i <- 0 to 1) {

val nodeIndex = (scala.math.pow(2, level + 1)).toInt - 1 + 2 * index + i
val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i

if (level < maxDepth - 1) {

Expand Down Expand Up @@ -205,7 +203,6 @@ object DecisionTree extends Serializable with Logging {
def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = {

if (isFeatureContinuous){
//TODO: Do binary search
for (binIndex <- 0 until strategy.numBins) {
val bin = bins(featureIndex)(binIndex)
val lowThreshold = bin.lowSplit.threshold
Expand Down Expand Up @@ -250,9 +247,12 @@ object DecisionTree extends Serializable with Logging {
val shift = 1 + numFeatures * nodeIndex
if (!sampleValid) {
//Add to invalid bin index -1
for (featureIndex <- 0 until numFeatures) {
arr(shift+featureIndex) = -1
//TODO: Break since marking one bin is sufficient
breakable {
for (featureIndex <- 0 until numFeatures) {
arr(shift+featureIndex) = -1
//Breaking since marking one bin is sufficient
break()
}
}
} else {
for (featureIndex <- 0 until numFeatures) {
Expand Down Expand Up @@ -318,7 +318,6 @@ object DecisionTree extends Serializable with Logging {
def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = {
strategy.algo match {
case Classification => classificationBinSeqOp(arr, agg)
//TODO: Implement this
case Regression => regressionBinSeqOp(arr, agg)
}
agg
Expand Down Expand Up @@ -599,7 +598,6 @@ object DecisionTree extends Serializable with Logging {

logDebug("maxBins = " + numBins)
//Calculate the number of sample for approximate quantile calculation
//TODO: Justify this calculation
val requiredSamples = numBins*numBins
val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
logDebug("fraction of data used for calculating quantiles = " + fraction)
Expand All @@ -624,7 +622,6 @@ object DecisionTree extends Serializable with Logging {
val stride : Double = numSamples.toDouble/numBins
logDebug("stride = " + stride)
for (index <- 0 until numBins-1) {
//TODO: Investigate this
val sampleIndex = (index+1)*stride.toInt
val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List())
splits(featureIndex)(index) = split
Expand Down

0 comments on commit f067d68

Please sign in to comment.