Skip to content

Commit

Permalink
added multiclass support for find splits bins
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 12, 2014
1 parent 46f909c commit 4d5f70c
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 37 deletions.
117 changes: 81 additions & 36 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1006,15 +1006,20 @@ object DecisionTree extends Serializable with Logging {
// TODO: Multiclass modification here

/*
* TODO: Add a require statement ensuring #bins is always greater than the categories.
* Ensure #bins is always greater than the categories. For multiclass classification,
* #bins should be greater than 2^(maxCategories - 1) - 1.
* It's a limitation of the current implementation but a reasonable trade-off since features
* with large number of categories get favored over continuous features.
*/
if (strategy.categoricalFeaturesInfo.size > 0) {
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
require(numBins >= maxCategoriesForFeatures)
if (strategy.isMultiClassification) {
require(numBins > math.pow(2, maxCategoriesForFeatures.toInt) - 1)
}
}


// Calculate the number of samples required for the approximate quantile calculation.
val requiredSamples = numBins*numBins
val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
Expand Down Expand Up @@ -1048,49 +1053,69 @@ object DecisionTree extends Serializable with Logging {
val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List())
splits(featureIndex)(index) = split
}
} else {
// TODO: Multiclass modification here
} else { // Categorical feature
val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
require(maxFeatureValue < numBins, "number of categories should be less than number " +
"of bins")

// For categorical variables, each bin is a category. The bins are sorted and they
// are ordered by calculating the centroid of their corresponding labels.
val centroidForCategories =
sampledInput.map(lp => (lp.features(featureIndex),lp.label))
.groupBy(_._1)
.mapValues(x => x.map(_._2).sum / x.map(_._1).length)

// Check for missing categorical variables and putting them last in the sorted list.
val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
for (i <- 0 until maxFeatureValue) {
if (centroidForCategories.contains(i)) {
fullCentroidForCategories(i) = centroidForCategories(i)
} else {
fullCentroidForCategories(i) = Double.MaxValue
}
}

// bins sorted by centroids
val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)

logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)

var categoriesForSplit = List[Double]()
categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
case ((key, value), index) =>
categoriesForSplit = key :: categoriesForSplit
splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
categoriesForSplit)
// Use different bin/split calculation strategy for multiclass classification
if (strategy.isMultiClassification) {
// Iterate from 1 to 2^maxFeatureValue leading to 2^(maxFeatureValue- 1) - 1
// combinations.
var index = 1
while (index < math.pow(2.0, maxFeatureValue).toInt) {
val categories: List[Double] = extractMultiClassCategories(index, maxFeatureValue)
splits(featureIndex)(index)
= new Split(featureIndex, Double.MinValue, Categorical, categories)
bins(featureIndex)(index) = {
if (index == 0) {
new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
splits(featureIndex)(0), Categorical, key)
splits(featureIndex)(0), Categorical, index)
} else {
new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
Categorical, key)
new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Categorical,
Double.MinValue)
}
}
index += 1
}
} else { // regression or binary classification

// For categorical variables, each bin is a category. The bins are sorted and they
// are ordered by calculating the centroid of their corresponding labels.
val centroidForCategories =
sampledInput.map(lp => (lp.features(featureIndex),lp.label))
.groupBy(_._1)
.mapValues(x => x.map(_._2).sum / x.map(_._1).length)

// Check for missing categorical variables and putting them last in the sorted list.
val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
for (i <- 0 until maxFeatureValue) {
if (centroidForCategories.contains(i)) {
fullCentroidForCategories(i) = centroidForCategories(i)
} else {
fullCentroidForCategories(i) = Double.MaxValue
}
}

// bins sorted by centroids
val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)

logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)

var categoriesForSplit = List[Double]()
categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
case ((key, value), index) =>
categoriesForSplit = key :: categoriesForSplit
splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
categoriesForSplit)
bins(featureIndex)(index) = {
if (index == 0) {
new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
splits(featureIndex)(0), Categorical, key)
} else {
new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
Categorical, key)
}
}
}
}
}
featureIndex += 1
Expand Down Expand Up @@ -1120,4 +1145,24 @@ object DecisionTree extends Serializable with Logging {
throw new UnsupportedOperationException("approximate histogram not supported yet.")
}
}

/**
 * Decodes a subset of categories from the binary representation of an index.
 *
 * Bit `j` of `i` (bit 0 being the least significant) selects category `j`. Only the lowest
 * `maxFeatureValue` bits are inspected.
 *
 * @param i index whose set bits identify the selected categories
 * @param maxFeatureValue number of low-order bits (candidate categories) to inspect
 * @return the selected categories as doubles, in descending order
 */
private def extractMultiClassCategories(i: Int, maxFeatureValue: Double): List[Double] = {
  var categories = List[Double]()
  var j = 0
  // The shifted value must live outside the loop: re-initializing it on every iteration
  // would discard the shift and test only the lowest bit of `i` each time.
  var bitShiftedInput = i
  while (j < maxFeatureValue) {
    if (bitShiftedInput % 2 != 0) {
      // Bit j is set, so category j belongs to this subset.
      categories = j.toDouble :: categories
    }
    // Move the next higher bit into the lowest position for the next iteration.
    bitShiftedInput = bitShiftedInput >> 1
    j += 1
  }
  categories
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* zero-indexed.
* @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
* 128 MB.
* @param numClassesForClassification number of classes for classification. Default value is 2,
* which leads to binary classification.
*
*/
@Experimental
Expand All @@ -47,4 +49,10 @@ class Strategy (
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
val maxMemoryInMB: Int = 128) extends Serializable
val maxMemoryInMB: Int = 128,
val numClassesForClassification: Int = 2) extends Serializable {

require(numClassesForClassification >= 2)
val isMultiClassification = numClassesForClassification > 2

}

0 comments on commit 4d5f70c

Please sign in to comment.