Skip to content

Commit

Permalink
updated predict and split threshold logic
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent b09dc98 commit c0e522b
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ object DecisionTree extends Serializable with Logging {
val lowThreshold = bin.lowSplit.threshold
val highThreshold = bin.highSplit.threshold
val features = labeledPoint.features
if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
if ((lowThreshold < features(featureIndex)) & (highThreshold >= features(featureIndex))) {
return binIndex
}
}
Expand Down Expand Up @@ -400,7 +400,8 @@ object DecisionTree extends Serializable with Logging {
}
}

val predict = leftCount / (leftCount + rightCount)
//val predict = leftCount / (leftCount + rightCount)
val predict = (left1Count + right1Count) / (leftCount + rightCount)

new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict)
}
Expand Down Expand Up @@ -672,8 +673,8 @@ object DecisionTree extends Serializable with Logging {

//Find all bins
for (featureIndex <- 0 until numFeatures){
val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinous) { //bins for categorical variables are already assigned
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) { //bins for categorical variables are already assigned
bins(featureIndex)(0)
= new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0),Continuous,Double.MinValue)
for (index <- 1 until numBins - 1){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ object DecisionTreeRunner extends Logging {
//TODO: Make these generic MLTable metrics
def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = {
val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean()
println("meanSumOfSquares = " + meanSumOfSquares)
meanSumOfSquares
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializabl
def predict(features : Array[Double]) = {
algo match {
case Classification => {
if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0
if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0
}
case Regression => {
topNode.predictIfLeaf(features)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ class InformationGainStats(val gain : Double,
//val rightSamples : Long
val predict : Double) extends Serializable {

override def toString =
"gain = " + gain + ", impurity = " + impurity + ", left impurity = "
+ leftImpurity + ", right impurity = " + rightImpurity + ", predict = " + predict
override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
.format(gain, impurity, leftImpurity, rightImpurity, predict)
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Node ( val id : Int,
def build(nodes : Array[Node]) : Unit = {

logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt )
logDebug("id = " + id + ", split = " + split)
logDebug("stats = " + stats)
logDebug("predict = " + predict)
if (!isLeaf) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(0==bestSplits(0)._2.gain)
assert(0==bestSplits(0)._2.leftImpurity)
assert(0==bestSplits(0)._2.rightImpurity)
assert(0.01==bestSplits(0)._2.predict)
println(bestSplits(0)._2.predict)
}

test("stump with fixed label 1 for Gini"){
Expand All @@ -181,7 +181,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(0==bestSplits(0)._2.gain)
assert(0==bestSplits(0)._2.leftImpurity)
assert(0==bestSplits(0)._2.rightImpurity)
assert(0.01==bestSplits(0)._2.predict)
assert(1==bestSplits(0)._2.predict)

}

Expand All @@ -207,7 +207,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(0==bestSplits(0)._2.gain)
assert(0==bestSplits(0)._2.leftImpurity)
assert(0==bestSplits(0)._2.rightImpurity)
assert(0.01==bestSplits(0)._2.predict)
assert(0==bestSplits(0)._2.predict)
}

test("stump with fixed label 1 for Entropy"){
Expand All @@ -231,7 +231,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(0==bestSplits(0)._2.gain)
assert(0==bestSplits(0)._2.leftImpurity)
assert(0==bestSplits(0)._2.rightImpurity)
assert(0.01==bestSplits(0)._2.predict)
assert(1==bestSplits(0)._2.predict)
}


Expand Down

0 comments on commit c0e522b

Please sign in to comment.