[MLlib] SPARK-1536: multiclass classification support for decision tree
The ability to perform multiclass classification is a big advantage of decision trees and was a highly requested feature for MLlib. This pull request adds multiclass classification support to the MLlib decision tree. It also adds sample-weight support via a WeightedLabeledPoint class for handling unbalanced datasets during classification, and will enable algorithms such as AdaBoost, which require weighted instances.

It handles the special case where the categorical variables cannot be ordered for multiclass classification, so the optimizations used to speed up binary classification cannot be applied directly to multiclass classification with categorical variables. More specifically, for m categories in a categorical feature, it analyzes all ```2^(m-1) - 1``` categorical splits, provided the number of splits is less than the maxBins provided in the input (see the sketch below). This condition will not be met for features with a large number of categories -- using decision trees is not recommended for such datasets in general, since the categorical features are favored over continuous features. Moreover, the user can apply a combination of tricks (increasing maxBins, binary-encoding the categorical features, or using a one-vs-all classification strategy) to avoid these constraints.
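
For intuition, here is a minimal, hypothetical Scala sketch (illustration only, not code from this pull request) of why m categories admit exactly ```2^(m-1) - 1``` distinct unordered splits: every non-empty proper subset of the categories defines a split, and a subset and its complement define the same split.

```scala
// Hypothetical illustration only: enumerate the 2^(m-1) - 1 candidate
// splits of an unordered categorical feature with m categories.
object UnorderedSplits {
  def enumerate(categories: Seq[Int]): Seq[Set[Int]] = {
    val m = categories.length
    // Bitmasks 1 .. 2^(m-1) - 1: keeping the last category out of the
    // left subset ensures a split and its complement are counted once.
    (1 until (1 << (m - 1))).map { mask =>
      categories.zipWithIndex.collect {
        case (cat, i) if ((mask >> i) & 1) == 1 => cat
      }.toSet
    }
  }

  def main(args: Array[String]): Unit = {
    val splits = enumerate(Seq(0, 1, 2, 3))  // m = 4 categories
    println(splits.size)                     // 2^(4-1) - 1 = 7
  }
}
```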

The new code is accompanied by unit tests and has also been tested on the iris and covtype datasets.

cc: mengxr, etrain, hirakendu, atalwalkar, srowen

Author: Manish Amde <manish9ue@gmail.com>
Author: manishamde <manish9ue@gmail.com>
Author: Evan Sparks <sparks@cs.berkeley.edu>

Closes apache#886 from manishamde/multiclass and squashes the following commits:

26f8acc [Manish Amde] another attempt at fixing mima
c5b2d04 [Manish Amde] more MIMA fixes
1ce7212 [Manish Amde] change problem filter for mima
10fdd82 [Manish Amde] fixing MIMA excludes
e1c970d [Manish Amde] merged master
abf2901 [Manish Amde] adding classes to MimaExcludes.scala
45e767a [Manish Amde] adding developer api annotation for overriden methods
c8428c4 [Manish Amde] fixing weird multiline bug
afced16 [Manish Amde] removed label weights support
2d85a48 [Manish Amde] minor: fixed scalastyle issues reprise
4e85f2c [Manish Amde] minor: fixed scalastyle issues
b2ae41f [Manish Amde] minor: scalastyle
e4c1321 [Manish Amde] using while loop for regression histograms
d75ac32 [Manish Amde] removed WeightedLabeledPoint from this PR
0fecd38 [Manish Amde] minor: add newline to EOF
2061cf5 [Manish Amde] merged from master
06b1690 [Manish Amde] fixed off-by-one error in bin to split conversion
9cc3e31 [Manish Amde] added implicit conversion import
5c1b2ca [Manish Amde] doc for PointConverter class
485eaae [Manish Amde] implicit conversion from LabeledPoint to WeightedLabeledPoint
3d7f911 [Manish Amde] updated doc
8e44ab8 [Manish Amde] updated doc
adc7315 [Manish Amde] support ordered categorical splits for multiclass classification
e3e8843 [Manish Amde] minor code formatting
23d4268 [Manish Amde] minor: another minor code style
34ee7b9 [Manish Amde] minor: code style
237762d [Manish Amde] renaming functions
12e6d0a [Manish Amde] minor: removing line in doc
9a90c93 [Manish Amde] Merge branch 'master' into multiclass
1892a2c [Manish Amde] tests and use multiclass binaggregate length when atleast one categorical feature is present
f5f6b83 [Manish Amde] multiclass for continous variables
8cfd3b6 [Manish Amde] working for categorical multiclass classification
828ff16 [Manish Amde] added categorical variable test
bce835f [Manish Amde] code cleanup
7e5f08c [Manish Amde] minor doc
1dd2735 [Manish Amde] bin search logic for multiclass
f16a9bb [Manish Amde] fixing while loop
d811425 [Manish Amde] multiclass bin aggregate logic
ab5cb21 [Manish Amde] multiclass logic
d8e4a11 [Manish Amde] sample weights
ed5a2df [Manish Amde] fixed classification requirements
d012be7 [Manish Amde] fixed while loop
18d2835 [Manish Amde] changing default values for num classes
6b912dc [Manish Amde] added numclasses to tree runner, predict logic for multiclass, add multiclass option to train
75f2bfc [Manish Amde] minor code style fix
e547151 [Manish Amde] minor modifications
34549d0 [Manish Amde] fixing error during merge
098e8c5 [Manish Amde] merged master
e006f9d [Manish Amde] changing variable names
5c78e1a [Manish Amde] added multiclass support
6c7af22 [Manish Amde] prepared for multiclass without breaking binary classification
46e06ee [Manish Amde] minor mods
3f85a17 [Manish Amde] tests for multiclass classification
4d5f70c [Manish Amde] added multiclass support for find splits bins
46f909c [Manish Amde] todo for multiclass support
455bea9 [Manish Amde] fixed tests
14aea48 [Manish Amde] changing instance format to weighted labeled point
a1a6e09 [Manish Amde] added weighted point class
968ca9d [Manish Amde] merged master
7fc9545 [Manish Amde] added docs
ce004a1 [Manish Amde] minor formatting
b27ad2c [Manish Amde] formatting
426bb28 [Manish Amde] programming guide blurb
8053fed [Manish Amde] more formatting
5eca9e4 [Manish Amde] grammar
4731cda [Manish Amde] formatting
5e82202 [Manish Amde] added documentation, fixed off by 1 error in max level calculation
cbd9f14 [Manish Amde] modified scala.math to math
dad9652 [Manish Amde] removed unused imports
e0426ee [Manish Amde] renamed parameter
718506b [Manish Amde] added unit test
1517155 [Manish Amde] updated documentation
9dbdabe [Manish Amde] merge from master
719d009 [Manish Amde] updating user documentation
fecf89a [manishamde] Merge pull request alteryx#6 from etrain/deep_tree
0287772 [Evan Sparks] Fixing scalastyle issue.
2f1e093 [Manish Amde] minor: added doc for maxMemory parameter
2f6072c [manishamde] Merge pull request alteryx#5 from etrain/deep_tree
abc5a23 [Evan Sparks] Parameterizing max memory.
50b143a [Manish Amde] adding support for very deep trees
manishamde authored and mengxr committed Jul 18, 2014
1 parent 586e716 commit d88f6be
Showing 12 changed files with 926 additions and 258 deletions.
8 changes: 5 additions & 3 deletions docs/mllib-decision-tree.md
@@ -77,15 +77,17 @@ bins if the condition is not satisfied.
 
 **Categorical features**
 
-For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for
-binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the
+For `$M$` categorical feature values, one could come up with `$2^{M-1}-1$` split candidates. For
+binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the
 categorical feature values by the proportion of labels falling in one of the two classes (see
 Section 9.2.4 in
 [The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
 details). For example, for a binary classification problem with one categorical feature with three
 categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical
 features are ordered as A followed by C followed by B, i.e. A, C, B. The two split candidates are A \| C, B
-and A , C \| B where \| denotes the split.
+and A , C \| B where \| denotes the split. A similar heuristic is used for multiclass classification
+when `$2^{M-1}-1$` is greater than the number of bins -- the impurity for each categorical feature value
+is used for ordering.
 
 ### Stopping rule
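
The ordering heuristic described in this doc change can be made concrete with a short, hypothetical Scala sketch (illustration only, not code from this commit):

```scala
// Sort categories by the proportion of label-1 instances; only the
// m - 1 "prefix" splits of the sorted order then need to be considered.
object CategoryOrdering {
  def main(args: Array[String]): Unit = {
    // Category -> proportion of label 1, using the example from the doc.
    val labelOneProportion = Seq("A" -> 0.2, "B" -> 0.6, "C" -> 0.4)
    val ordered = labelOneProportion.sortBy(_._2).map(_._1)
    println(ordered.mkString(", "))  // A, C, B
    // Prefix splits over the sorted order: A | C, B  and  A, C | B
    (1 until ordered.length).map(ordered.splitAt).foreach {
      case (left, right) =>
        println(left.mkString(", ") + " | " + right.mkString(", "))
    }
  }
}
```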
examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -49,6 +49,7 @@ object DecisionTreeRunner {
   case class Params(
       input: String = null,
       algo: Algo = Classification,
+      numClassesForClassification: Int = 2,
       maxDepth: Int = 5,
       impurity: ImpurityType = Gini,
       maxBins: Int = 100)
@@ -68,6 +69,10 @@ object DecisionTreeRunner {
     opt[Int]("maxDepth")
       .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
       .action((x, c) => c.copy(maxDepth = x))
+    opt[Int]("numClassesForClassification")
+      .text(s"number of classes for classification, "
+        + s"default: ${defaultParams.numClassesForClassification}")
+      .action((x, c) => c.copy(numClassesForClassification = x))
     opt[Int]("maxBins")
       .text(s"max number of bins, default: ${defaultParams.maxBins}")
       .action((x, c) => c.copy(maxBins = x))
@@ -118,7 +123,13 @@ object DecisionTreeRunner {
       case Variance => impurity.Variance
     }
 
-    val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
+    val strategy
+      = new Strategy(
+          algo = params.algo,
+          impurity = impurityCalculator,
+          maxDepth = params.maxDepth,
+          maxBins = params.maxBins,
+          numClassesForClassification = params.numClassesForClassification)
     val model = DecisionTree.train(training, strategy)
 
     if (params.algo == Classification) {
@@ -139,12 +150,8 @@ object DecisionTreeRunner {
    */
   private def accuracyScore(
       model: DecisionTreeModel,
-      data: RDD[LabeledPoint],
-      threshold: Double = 0.5): Double = {
-    def predictedValue(features: Vector): Double = {
-      if (model.predict(features) < threshold) 0.0 else 1.0
-    }
-    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+      data: RDD[LabeledPoint]): Double = {
+    val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
     val count = data.count()
     correctCount.toDouble / count
   }
732 changes: 532 additions & 200 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  * @param algo classification or regression
  * @param impurity criterion used for information gain calculation
  * @param maxDepth maximum depth of the tree
+ * @param numClassesForClassification number of classes for classification. The default value
+ *                                    of 2 leads to binary classification.
  * @param maxBins maximum number of bins used for splitting features
  * @param quantileCalculationStrategy algorithm for calculating quantiles
  * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
@@ -44,7 +46,15 @@ class Strategy (
     val algo: Algo,
     val impurity: Impurity,
     val maxDepth: Int,
+    val numClassesForClassification: Int = 2,
     val maxBins: Int = 100,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val maxMemoryInMB: Int = 128) extends Serializable
+    val maxMemoryInMB: Int = 128) extends Serializable {
+
+  require(numClassesForClassification >= 2)
+  val isMulticlassClassification = numClassesForClassification > 2
+  val isMulticlassWithCategoricalFeatures
+    = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+
+}
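
A minimal usage sketch of the extended Strategy (assuming only the APIs visible in this diff; `trainingData` is a hypothetical RDD[LabeledPoint] prepared elsewhere):

```scala
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Configure a 3-class tree, e.g. for the iris dataset mentioned above.
val strategy = new Strategy(
  algo = Classification,
  impurity = Gini,
  maxDepth = 5,
  numClassesForClassification = 3,
  maxBins = 100)
// The multiclass flag is derived automatically from numClassesForClassification.
assert(strategy.isMulticlassClassification)
// `trainingData`: an RDD[LabeledPoint], assumed prepared elsewhere.
val model = DecisionTree.train(trainingData, strategy)
```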
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -31,23 +31,35 @@ object Entropy extends Impurity {
 
   /**
    * :: DeveloperApi ::
-   * entropy calculation
-   * @param c0 count of instances with label 0
-   * @param c1 count of instances with label 1
-   * @return entropy value
+   * information calculation for multiclass classification
+   * @param counts Array[Double] with counts for each label
+   * @param totalCount sum of counts for all labels
+   * @return information value
    */
   @DeveloperApi
-  override def calculate(c0: Double, c1: Double): Double = {
-    if (c0 == 0 || c1 == 0) {
-      0
-    } else {
-      val total = c0 + c1
-      val f0 = c0 / total
-      val f1 = c1 / total
-      -(f0 * log2(f0)) - (f1 * log2(f1))
-    }
-  }
+  override def calculate(counts: Array[Double], totalCount: Double): Double = {
+    val numClasses = counts.length
+    var impurity = 0.0
+    var classIndex = 0
+    while (classIndex < numClasses) {
+      val classCount = counts(classIndex)
+      if (classCount != 0) {
+        val freq = classCount / totalCount
+        impurity -= freq * log2(freq)
+      }
+      classIndex += 1
+    }
+    impurity
+  }
+
+  /**
+   * :: DeveloperApi ::
+   * variance calculation
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   */
+  @DeveloperApi
+  override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+    throw new UnsupportedOperationException("Entropy.calculate")
 }
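
A quick sanity check for the multiclass entropy (a sketch; the signature is exactly the one added above): a uniform three-class distribution has entropy log2(3) ≈ 1.585 bits, and a pure node has entropy 0.

```scala
import org.apache.spark.mllib.tree.impurity.Entropy

println(Entropy.calculate(Array(2.0, 2.0, 2.0), 6.0))  // log2(3) ~= 1.585
println(Entropy.calculate(Array(6.0, 0.0, 0.0), 6.0))  // pure node -> 0.0
```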
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -30,23 +30,32 @@ object Gini extends Impurity {
 
   /**
    * :: DeveloperApi ::
-   * Gini coefficient calculation
-   * @param c0 count of instances with label 0
-   * @param c1 count of instances with label 1
-   * @return Gini coefficient value
+   * information calculation for multiclass classification
+   * @param counts Array[Double] with counts for each label
+   * @param totalCount sum of counts for all labels
+   * @return information value
    */
   @DeveloperApi
-  override def calculate(c0: Double, c1: Double): Double = {
-    if (c0 == 0 || c1 == 0) {
-      0
-    } else {
-      val total = c0 + c1
-      val f0 = c0 / total
-      val f1 = c1 / total
-      1 - f0 * f0 - f1 * f1
-    }
-  }
+  override def calculate(counts: Array[Double], totalCount: Double): Double = {
+    val numClasses = counts.length
+    var impurity = 1.0
+    var classIndex = 0
+    while (classIndex < numClasses) {
+      val freq = counts(classIndex) / totalCount
+      impurity -= freq * freq
+      classIndex += 1
+    }
+    impurity
+  }
+
+  /**
+   * :: DeveloperApi ::
+   * variance calculation
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   */
+  @DeveloperApi
+  override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+    throw new UnsupportedOperationException("Gini.calculate")
 }
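
Likewise for the multiclass Gini impurity, 1 - Σ freq² (a sketch): for counts (1, 2, 3) out of 6, the impurity is 1 - (1 + 4 + 9)/36 = 22/36 ≈ 0.611.

```scala
import org.apache.spark.mllib.tree.impurity.Gini

println(Gini.calculate(Array(1.0, 2.0, 3.0), 6.0))  // 22/36 ~= 0.611
println(Gini.calculate(Array(0.0, 6.0, 0.0), 6.0))  // pure node -> 0.0
```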
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -28,13 +28,13 @@ trait Impurity extends Serializable {
 
   /**
    * :: DeveloperApi ::
-   * information calculation for binary classification
-   * @param c0 count of instances with label 0
-   * @param c1 count of instances with label 1
+   * information calculation for multiclass classification
+   * @param counts Array[Double] with counts for each label
+   * @param totalCount sum of counts for all labels
    * @return information value
    */
   @DeveloperApi
-  def calculate(c0: Double, c1: Double): Double
+  def calculate(counts: Array[Double], totalCount: Double): Double
 
   /**
    * :: DeveloperApi ::
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -25,7 +25,16 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
  */
 @Experimental
 object Variance extends Impurity {
-  override def calculate(c0: Double, c1: Double): Double =
+
+  /**
+   * :: DeveloperApi ::
+   * information calculation for multiclass classification
+   * @param counts Array[Double] with counts for each label
+   * @param totalCount sum of counts for all labels
+   * @return information value
+   */
+  @DeveloperApi
+  override def calculate(counts: Array[Double], totalCount: Double): Double =
     throw new UnsupportedOperationException("Variance.calculate")
 
   /**
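
Variance keeps only its regression signature. A sketch of its use (assuming the usual population-variance form (sumSquares - sum²/count)/count, whose implementation is not shown in this excerpt): the impurity of labels (1, 2, 3) is 2/3.

```scala
import org.apache.spark.mllib.tree.impurity.Variance

val labels = Seq(1.0, 2.0, 3.0)
val count = labels.size.toDouble
val sum = labels.sum
val sumSquares = labels.map(x => x * x).sum
println(Variance.calculate(count, sum, sumSquares))  // 2/3 ~= 0.667
```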
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
  * @param highSplit signifying the upper threshold for the continuous feature to be
  *                  accepted in the bin
  * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin
+ * @param category categorical label value accepted in the bin for binary classification
  */
 private[tree]
 case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -27,17 +27,19 @@ import org.apache.spark.annotation.DeveloperApi
  * @param leftImpurity left node impurity
  * @param rightImpurity right node impurity
  * @param predict predicted value
+ * @param prob probability of the label (classification only)
  */
 @DeveloperApi
 class InformationGainStats(
     val gain: Double,
     val impurity: Double,
     val leftImpurity: Double,
     val rightImpurity: Double,
-    val predict: Double) extends Serializable {
+    val predict: Double,
+    val prob: Double = 0.0) extends Serializable {
 
   override def toString = {
-    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
-      .format(gain, impurity, leftImpurity, rightImpurity, predict)
+    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
+      .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
   }
 }