diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index f89c53a7ad70d..ed0cf825b1d50 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -204,15 +204,12 @@ object DecisionTree extends Serializable with Logging { } /*Finds the right bin for the given feature*/ - def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = { - //logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex)) + def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinous : Boolean) : Int = { - val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinous){ //TODO: Do binary search for (binIndex <- 0 until strategy.numBins) { val bin = bins(featureIndex)(binIndex) - //TODO: Remove this requirement post basic functional val lowThreshold = bin.lowSplit.threshold val highThreshold = bin.highSplit.threshold val features = labeledPoint.features @@ -222,9 +219,9 @@ object DecisionTree extends Serializable with Logging { } throw new UnknownError("no bin was found for continuous variable.") } else { + for (binIndex <- 0 until strategy.numBins) { val bin = bins(featureIndex)(binIndex) - //TODO: Remove this requirement post basic functional val category = bin.category val features = labeledPoint.features if (category == features(featureIndex)) { @@ -262,7 +259,8 @@ object DecisionTree extends Serializable with Logging { } else { for (featureIndex <- 0 until numFeatures) { //logDebug("shift+featureIndex =" + (shift+featureIndex)) - arr(shift + featureIndex) = findBin(featureIndex, labeledPoint) + val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinous) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 4e68611d2be9e..40bb94e6794d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -75,6 +75,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { println(splits(1)(0)) println(splits(1)(1)) println(bins(1)(0)) + //TODO: Add asserts + } test("split and bin calculations for categorical variables with no sample for one category"){ @@ -100,12 +102,28 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { println(bins(1)(1)) println(bins(0)(2)) println(bins(0)(3)) + //TODO: Add asserts + } //TODO: Test max feature value > num bins - test("stump with fixed label 0 for Gini"){ + test("stump with all categorical variables"){ + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + strategy.numBins = 100 + val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) + println(bestSplits(0)._1) + println(bestSplits(0)._2) + //TODO: Add asserts + } + + + test("stump with fixed label 0 for Gini"){ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length == 1000) val rdd = sc.parallelize(arr)