From 4d5f70c4688c1183b754f2133a4d5a11d862070a Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 5 May 2014 22:52:08 -0700 Subject: [PATCH] added multiclass support for find splits bins --- .../spark/mllib/tree/DecisionTree.scala | 117 ++++++++++++------ .../mllib/tree/configuration/Strategy.scala | 10 +- 2 files changed, 90 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 7c18834ef3468..1c2f4cd704741 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1006,15 +1006,20 @@ object DecisionTree extends Serializable with Logging { // TODO: Multiclass modification here /* - * TODO: Add a require statement ensuring #bins is always greater than the categories. + * Ensure #bins is always greater than the categories. For multiclass classification, + * #bins should be greater than 2^(maxCategories - 1) - 1. * It's a limitation of the current implementation but a reasonable trade-off since features * with large number of categories get favored over continuous features. */ if (strategy.categoricalFeaturesInfo.size > 0) { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 require(numBins >= maxCategoriesForFeatures) + if (strategy.isMultiClassification) { + require(numBins > math.pow(2, maxCategoriesForFeatures.toInt) - 1) + } } + // Calculate the number of sample for approximate quantile calculation. 
val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 @@ -1048,49 +1053,69 @@ object DecisionTree extends Serializable with Logging { val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) splits(featureIndex)(index) = split } - } else { - // TODO: Multiclass modification here + } else { // Categorical feature val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) - require(maxFeatureValue < numBins, "number of categories should be less than number " + - "of bins") - - // For categorical variables, each bin is a category. The bins are sorted and they - // are ordered by calculating the centroid of their corresponding labels. - val centroidForCategories = - sampledInput.map(lp => (lp.features(featureIndex),lp.label)) - .groupBy(_._1) - .mapValues(x => x.map(_._2).sum / x.map(_._1).length) - - // Check for missing categorical variables and putting them last in the sorted list. - val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until maxFeatureValue) { - if (centroidForCategories.contains(i)) { - fullCentroidForCategories(i) = centroidForCategories(i) - } else { - fullCentroidForCategories(i) = Double.MaxValue - } - } - // bins sorted by centroids - val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - - logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) - - var categoriesForSplit = List[Double]() - categoriesSortedByCentroid.iterator.zipWithIndex.foreach { - case ((key, value), index) => - categoriesForSplit = key :: categoriesForSplit - splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, - categoriesForSplit) + // Use different bin/split calculation strategy for multiclass classification + if (strategy.isMultiClassification) { + // Iterate from 1 to 2^maxFeatureValue leading to 2^(maxFeatureValue- 1) - 1 + // combinations. 
+ var index = 1 + while (index < math.pow(2.0, maxFeatureValue).toInt) { + val categories: List[Double] = extractMultiClassCategories(index, maxFeatureValue) + splits(featureIndex)(index) + = new Split(featureIndex, Double.MinValue, Categorical, categories) bins(featureIndex)(index) = { if (index == 0) { new Bin(new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), Categorical, key) + splits(featureIndex)(0), Categorical, index) } else { - new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), - Categorical, key) + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Categorical, + Double.MinValue) } } + index += 1 + } + } else { // regression or binary classification + + // For categorical variables, each bin is a category. The bins are sorted and they + // are ordered by calculating the centroid of their corresponding labels. + val centroidForCategories = + sampledInput.map(lp => (lp.features(featureIndex),lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + + // Check for missing categorical variables and putting them last in the sorted list. 
+ val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() + for (i <- 0 until maxFeatureValue) { + if (centroidForCategories.contains(i)) { + fullCentroidForCategories(i) = centroidForCategories(i) + } else { + fullCentroidForCategories(i) = Double.MaxValue + } + } + + // bins sorted by centroids + val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) + + logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + + var categoriesForSplit = List[Double]() + categoriesSortedByCentroid.iterator.zipWithIndex.foreach { + case ((key, value), index) => + categoriesForSplit = key :: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, + categoriesForSplit) + bins(featureIndex)(index) = { + if (index == 0) { + new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), Categorical, key) + } else { + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Categorical, key) + } + } + } } } featureIndex += 1 @@ -1120,4 +1145,24 @@ object DecisionTree extends Serializable with Logging { throw new UnsupportedOperationException("approximate histogram not supported yet.") } } + + /** + * Extract the categories encoded by index i: category j is included iff bit j of i is set. + */ + private def extractMultiClassCategories(i: Int, maxFeatureValue: Double): List[Double] = { + // TODO: Test this + var categories = List[Double]() + var j = 0 + var remainingBits = i + while (j < maxFeatureValue) { + if (remainingBits % 2 != 0) { + // updating the list of categories. 
+ categories = j.toDouble :: categories + } + remainingBits = remainingBits >> 1 + j += 1 + } + categories + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 1b505fd76eb75..3aa2d5382cb83 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -37,6 +37,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * zero-indexed. * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. + * @param numClassesForClassification number of classes for classification. Default value is 2 + * leads to binary classification * */ @Experimental @@ -47,4 +49,10 @@ class Strategy ( val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemoryInMB: Int = 128) extends Serializable + val maxMemoryInMB: Int = 128, + val numClassesForClassification: Int = 2) extends Serializable { + + require(numClassesForClassification >= 2) + val isMultiClassification = numClassesForClassification > 2 + +}