From ff363a7353b28e9bcf16944deb376e075555dfd1 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 16 Mar 2014 23:28:45 -0700 Subject: [PATCH] binary search for bins and while loop for categorical feature bins --- .../spark/mllib/tree/DecisionTree.scala | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index a16bff2b5f4d7..b7492038445cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -319,6 +319,7 @@ object DecisionTree extends Serializable with Logging { true } + // TODO: Unit test this /** * Finds the right bin for the given feature */ @@ -328,26 +329,47 @@ object DecisionTree extends Serializable with Logging { isFeatureContinuous: Boolean) : Int = { - if (isFeatureContinuous){ - for (binIndex <- 0 until strategy.numBins) { - val bin = bins(featureIndex)(binIndex) + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + + def binarySearchForBins(): Int = { + var left = 0 + var right = binForFeatures.length-1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) val lowThreshold = bin.lowSplit.threshold val highThreshold = bin.highSplit.threshold - val features = labeledPoint.features - if ((lowThreshold < features(featureIndex)) & (highThreshold >= features(featureIndex))) { - return binIndex + if ((lowThreshold < feature) & (highThreshold >= feature)){ + return mid + } + else if ((lowThreshold >= feature)){ + right = mid - 1 } + else { + left = mid + 1 + } + } + -1 + } + + if (isFeatureContinuous){ + val binIndex = binarySearchForBins() + if (binIndex == -1){ + throw new UnknownError("no bin was found for continuous variable.") } - throw new UnknownError("no bin was found for continuous variable.") + binIndex } else { val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) - for (binIndex <- 0 until numCategoricalBins) { + var binIndex = 0 + while (binIndex < numCategoricalBins) { val bin = bins(featureIndex)(binIndex) val category = bin.category val features = labeledPoint.features if (category == features(featureIndex)) { return binIndex } + binIndex += 1 } throw new UnknownError("no bin was found for categorical variable.")