diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c467d5ba65d94..fb752b06380ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -545,17 +545,24 @@ object DecisionTree extends Serializable with Logging { -1 } + /** + * Sequential search helper method to find bin for categorical feature in multiclass + * classification. Dummy value of 0 used since it is not used in future calculation + */ + def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = 0 + /** * Sequential search helper method to find bin for categorical feature. */ - def sequentialBinSearchForCategoricalFeature(): Int = { - val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) + def sequentialBinSearchForCategoricalFeatureInMultiClassClassification(): Int = { + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { val bin = bins(featureIndex)(binIndex) - val category = bin.category + val categories = bin.highSplit.categories val features = labeledPoint.features - if (category == features(featureIndex)) { + if (categories.contains(features(featureIndex))) { return binIndex } binIndex += 1 @@ -572,7 +579,14 @@ object DecisionTree extends Serializable with Logging { binIndex } else { // Perform sequential search to find bin for categorical features. - val binIndex = sequentialBinSearchForCategoricalFeature() + val binIndex = { + if (strategy.isMultiClassification) { + sequentialBinSearchForCategoricalFeatureInBinaryClassification() + } + else { + sequentialBinSearchForCategoricalFeatureInMultiClassClassification() + } + } if (binIndex == -1){ throw new UnknownError("no bin was found for categorical variable.") } @@ -584,7 +598,8 @@ object DecisionTree extends Serializable with Logging { * Finds bins for all nodes (and all features) at a given level. * For l nodes, k features the storage is as follows: * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, - * where b_ij is an integer between 0 and numBins - 1. + * where b_ij is an integer between 0 and numBins - 1 for regressions and binary + * classification and an invalid value for categorical feature in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. */ def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = { @@ -646,7 +661,22 @@ object DecisionTree extends Serializable with Logging { = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses label.toInt match { case n: Int => - agg(aggIndex + n) = agg(aggIndex + n) + 1 * labelWeights.getOrElse(n, 1) + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous && strategy.isMultiClassification) { + // Find all matching bins and increment their values + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + var binIndex = 0 + while (binIndex < numCategoricalBins) { + if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)){ + agg(aggIndex + binIndex) + = agg(aggIndex + binIndex) + labelWeights.getOrElse(binIndex, 1) + } + binIndex += 1 + } + } else { + agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1) + } } featureIndex += 1 } @@ -705,6 +735,7 @@ object DecisionTree extends Serializable with Logging { agg } + // TODO: Double-check this // Calculate bin aggregate length for classification or regression. val binAggregateLength = strategy.algo match { case Classification => numClasses * numBins * numFeatures * numNodes @@ -785,10 +816,10 @@ object DecisionTree extends Serializable with Logging { } if (leftTotalCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1) + return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) } if (rightTotalCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0) + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) } val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) @@ -812,16 +843,16 @@ object DecisionTree extends Serializable with Logging { = leftCounts.zip(rightCounts) .map{case (leftCount, rightCount) => leftCount + rightCount} - def indexOfLargest(array: Seq[Double]): Int = { + def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1,Double.MinValue,0) { case ((maxIndex, maxValue, currentIndex), currentValue) => if(currentValue > maxValue) (currentIndex,currentValue,currentIndex+1) else (maxIndex,maxValue,currentIndex+1) } - if (result._1 < 0) result._1 else 0 + if (result._1 < 0) 0 else result._1 } - val predict = indexOfLargest(leftRightCounts) + val predict = indexOfLargestArrayElement(leftRightCounts) val prob = leftRightCounts(predict) / totalCount new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) @@ -1051,8 +1082,20 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - // TODO: Modify this for categorical variables to go over only valid splits - while (splitIndex < numBins - 1) { + val maxSplitIndex : Double = { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + numBins - 1 + } else { // Categorical feature + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + if (strategy.isMultiClassification) { + math.pow(2.0, featureCategories - 1).toInt - 1 + } else { // Binary classification + featureCategories + } + } + } + while (splitIndex < maxSplitIndex) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats @@ -1176,24 +1219,29 @@ object DecisionTree extends Serializable with Logging { splits(featureIndex)(index) = split } } else { // Categorical feature - val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) // Use different bin/split calculation strategy for multiclass classification if (strategy.isMultiClassification) { - // Iterate from 0 to 2^maxFeatureValue - 1 leading to 2^(maxFeatureValue- 1) - 1 - // combinations. + // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 - while (index < math.pow(2.0, maxFeatureValue).toInt - 1) { + while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { val categories: List[Double] - = extractMultiClassCategories(index + 1, maxFeatureValue) + = extractMultiClassCategories(index + 1, featureCategories) splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, categories) bins(featureIndex)(index) = { if (index == 0) { - new Bin(new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), Categorical, Double.MinValue) + new Bin( + new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), + Categorical, + Double.MinValue) } else { - new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Categorical, + new Bin( + splits(featureIndex)(index - 1), + splits(featureIndex)(index), + Categorical, Double.MinValue) } } @@ -1210,7 +1258,7 @@ object DecisionTree extends Serializable with Logging { // Check for missing categorical variables and putting them last in the sorted list. val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until maxFeatureValue) { + for (i <- 0 until featureCategories) { if (centroidForCategories.contains(i)) { fullCentroidForCategories(i) = centroidForCategories(i) } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 6366960f39b02..ead76d64b6383 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -42,8 +42,11 @@ object Entropy extends Impurity { var impurity = 0.0 var classIndex = 0 while (classIndex < numClasses) { - val freq = counts(classIndex) / totalCount - impurity -= freq * log2(freq) + val classCount = counts(classIndex) + if (classCount != 0) { + val freq = classCount / totalCount + impurity -= freq * log2(freq) + } classIndex += 1 } impurity diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 2d71e1e366069..c89c1e371a40e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param highSplit signifying the upper threshold for the continuous feature to be * accepted in the bin * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin + * @param category categorical label value accepted in the bin for binary classification */ private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index b5259d8e18822..e7a55d52e7367 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -35,7 +35,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(bins.length === 2) @@ -51,6 +51,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -130,6 +131,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -237,7 +239,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) } - test("split and bin calculations for categorical variables wiht multiclass classification") { + test("split and bin calculations for categorical variables with multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -245,12 +247,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 100, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), - numClassesForClassification = 3) + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - // Expecting 2^3 - 1 = 7 bins/splits + // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) assert(splits(0)(0).threshold === Double.MinValue) assert(splits(0)(0).featureType === Categorical) @@ -287,6 +289,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(1)(2).categories.contains(1.0)) assert(splits(0)(3) === null) + assert(splits(1)(3) === null) // Check bins. @@ -329,22 +332,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("split and bin calculations for categorical variables with no sample for one category " + - "for multiclass classification") { - val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 3, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3), - numClassesForClassification = 3) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - - } - test("classification stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) @@ -352,6 +339,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, + numClassesForClassification = 2, maxDepth = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) @@ -367,8 +355,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict > 0.4) - assert(stats.predict < 0.5) + assert(stats.predict === 0) + assert(stats.prob > 0.5) + assert(stats.prob < 0.6) assert(stats.impurity > 0.2) } @@ -403,7 +392,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -426,7 +415,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -450,7 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -474,7 +463,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -498,7 +487,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99)