Skip to content

Commit

Permalink
tests for multiclass classification
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 12, 2014
1 parent 4d5f70c commit 3f85a17
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 25 deletions.
30 changes: 17 additions & 13 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1003,8 +1003,6 @@ object DecisionTree extends Serializable with Logging {
val numBins = if (maxBins <= count) maxBins else count.toInt
logDebug("numBins = " + numBins)

// TODO: Multiclass modification here

/*
* Ensure #bins is always greater than the categories. For multiclass classification,
* #bins should be greater than 2^(maxCategories - 1) - 1.
Expand Down Expand Up @@ -1058,17 +1056,18 @@ object DecisionTree extends Serializable with Logging {

// Use different bin/split calculation strategy for multiclass classification
if (strategy.isMultiClassification) {
// Iterate from 1 to 2^maxFeatureValue leading to 2^(maxFeatureValue- 1) - 1
// Iterate over indices 0 to 2^maxFeatureValue - 2, generating
// 2^maxFeatureValue - 1 category combinations.
var index = 1
while (index < math.pow(2.0, maxFeatureValue).toInt) {
val categories: List[Double] = extractMultiClassCategories(index, maxFeatureValue)
var index = 0
while (index < math.pow(2.0, maxFeatureValue).toInt - 1) {
val categories: List[Double]
= extractMultiClassCategories(index + 1, maxFeatureValue)
splits(featureIndex)(index)
= new Split(featureIndex, Double.MinValue, Categorical, categories)
bins(featureIndex)(index) = {
if (index == 0) {
new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
splits(featureIndex)(0), Categorical, index)
splits(featureIndex)(0), Categorical, Double.MinValue)
} else {
new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Categorical,
Double.MinValue)
Expand Down Expand Up @@ -1147,19 +1146,24 @@ object DecisionTree extends Serializable with Logging {
}

/**
* Nested method to extract list of eligible categories given an index
* Nested method to extract the list of eligible categories given an index. It extracts the
* positions of ones in the binary representation of the input. If the binary
* representation of a number is 01101 (13), the output list should be (3.0, 2.0,
* 0.0). The maxFeatureValue depicts the number of rightmost digits that will be tested for ones.
*/
private def extractMultiClassCategories(i: Int, maxFeatureValue: Double): List[Double] = {
// TODO: Test this
private[tree] def extractMultiClassCategories(
input: Int,
maxFeatureValue: Int): List[Double] = {
var categories = List[Double]()
var j = 0
var bitShiftedInput = input
while (j < maxFeatureValue) {
var copy = i
if (copy % 2 != 0) {
if (bitShiftedInput % 2 != 0) {
// updating the list of categories.
categories = j.toDouble :: categories
}
copy = copy >> 1
//Right shift by one
bitShiftedInput = bitShiftedInput >> 1
j += 1
}
categories
Expand Down
138 changes: 126 additions & 12 deletions mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,120 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(1)(3) === null)
}

// Verifies bit-position extraction used for multiclass categorical splits:
// 13 = 0b1101 has ones at bit positions 0, 2 and 3, so the eligible
// categories are (3.0, 2.0, 0.0), built most-significant-first.
// maxFeatureValue = 10 means the 10 rightmost bits are examined.
test("extract categories from a number for multiclass classification") {
  val categories = DecisionTree.extractMultiClassCategories(13, 10)
  assert(categories.length === 3)
  // Use === for consistent ScalaTest failure reporting.
  assert(categories.toSeq === List(3.0, 2.0, 0.0).toSeq)
}

// For multiclass classification, categorical features are treated as unordered,
// so splits enumerate category subsets rather than single category boundaries.
test("split and bin calculations for categorical variables with multiclass classification") {
  val arr = DecisionTreeSuite.generateCategoricalDataPoints()
  assert(arr.length === 1000)
  val rdd = sc.parallelize(arr)
  val strategy = new Strategy(
    Classification,
    Gini,
    maxDepth = 3,
    maxBins = 100,
    categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2),
    numClassesForClassification = 3)
  val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)

  // Each feature has 2 categories, so we expect 2^2 - 1 = 3 bins/splits per
  // feature; the fourth slot must remain unused (null).
  // Split 0 for each feature: category subset {0.0}.
  assert(splits(0)(0).feature === 0)
  assert(splits(0)(0).threshold === Double.MinValue)
  assert(splits(0)(0).featureType === Categorical)
  assert(splits(0)(0).categories.length === 1)
  assert(splits(0)(0).categories.contains(0.0))
  assert(splits(1)(0).feature === 1)
  assert(splits(1)(0).threshold === Double.MinValue)
  assert(splits(1)(0).featureType === Categorical)
  assert(splits(1)(0).categories.length === 1)
  assert(splits(1)(0).categories.contains(0.0))

  // Split 1 for each feature: category subset {1.0}.
  assert(splits(0)(1).feature === 0)
  assert(splits(0)(1).threshold === Double.MinValue)
  assert(splits(0)(1).featureType === Categorical)
  assert(splits(0)(1).categories.length === 1)
  assert(splits(0)(1).categories.contains(1.0))
  assert(splits(1)(1).feature === 1)
  assert(splits(1)(1).threshold === Double.MinValue)
  assert(splits(1)(1).featureType === Categorical)
  assert(splits(1)(1).categories.length === 1)
  assert(splits(1)(1).categories.contains(1.0))

  // Split 2 for each feature: category subset {0.0, 1.0}.
  assert(splits(0)(2).feature === 0)
  assert(splits(0)(2).threshold === Double.MinValue)
  assert(splits(0)(2).featureType === Categorical)
  assert(splits(0)(2).categories.length === 2)
  assert(splits(0)(2).categories.contains(0.0))
  assert(splits(0)(2).categories.contains(1.0))
  assert(splits(1)(2).feature === 1)
  assert(splits(1)(2).threshold === Double.MinValue)
  assert(splits(1)(2).featureType === Categorical)
  assert(splits(1)(2).categories.length === 2)
  assert(splits(1)(2).categories.contains(0.0))
  assert(splits(1)(2).categories.contains(1.0))

  // Only 3 splits should have been generated.
  assert(splits(0)(3) === null)

  // Check bins: bin i is bounded below by split i-1 and above by split i;
  // bin 0's low split is a dummy with an empty category list.
  assert(bins(0)(0).category === Double.MinValue)
  assert(bins(0)(0).lowSplit.categories.length === 0)
  assert(bins(0)(0).highSplit.categories.length === 1)
  assert(bins(0)(0).highSplit.categories.contains(0.0))
  assert(bins(1)(0).category === Double.MinValue)
  assert(bins(1)(0).lowSplit.categories.length === 0)
  assert(bins(1)(0).highSplit.categories.length === 1)
  assert(bins(1)(0).highSplit.categories.contains(0.0))

  assert(bins(0)(1).category === Double.MinValue)
  assert(bins(0)(1).lowSplit.categories.length === 1)
  assert(bins(0)(1).lowSplit.categories.contains(0.0))
  assert(bins(0)(1).highSplit.categories.length === 1)
  assert(bins(0)(1).highSplit.categories.contains(1.0))
  assert(bins(1)(1).category === Double.MinValue)
  assert(bins(1)(1).lowSplit.categories.length === 1)
  assert(bins(1)(1).lowSplit.categories.contains(0.0))
  assert(bins(1)(1).highSplit.categories.length === 1)
  assert(bins(1)(1).highSplit.categories.contains(1.0))

  assert(bins(0)(2).category === Double.MinValue)
  assert(bins(0)(2).lowSplit.categories.length === 1)
  assert(bins(0)(2).lowSplit.categories.contains(1.0))
  assert(bins(0)(2).highSplit.categories.length === 2)
  assert(bins(0)(2).highSplit.categories.contains(1.0))
  assert(bins(0)(2).highSplit.categories.contains(0.0))
  assert(bins(1)(2).category === Double.MinValue)
  assert(bins(1)(2).lowSplit.categories.length === 1)
  assert(bins(1)(2).lowSplit.categories.contains(1.0))
  assert(bins(1)(2).highSplit.categories.length === 2)
  assert(bins(1)(2).highSplit.categories.contains(1.0))
  assert(bins(1)(2).highSplit.categories.contains(0.0))

  // Only 3 bins should have been generated.
  assert(bins(0)(3) === null)
  assert(bins(1)(3) === null)
}

// Smoke test: each feature declares 3 categories (Map(0 -> 3, 1 -> 3)) but the
// generated data only contains category values {0.0, 1.0}, so one category has
// no samples. This test currently only verifies that findSplitsBins completes
// without throwing; the computed splits/bins are not asserted.
// TODO(review): add assertions on the contents of splits and bins.
test("split and bin calculations for categorical variables with no sample for one category " +
"for multiclass classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(
Classification,
Gini,
maxDepth = 3,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3),
numClassesForClassification = 3)
// Result intentionally unused beyond exercising the code path.
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)

}

test("classification stump with all categorical variables") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
Expand Down Expand Up @@ -430,29 +544,29 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

object DecisionTreeSuite {

def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
def generateOrderedLabeledPointsWithLabel0(): Array[WeightedLabeledPoint] = {
val arr = new Array[WeightedLabeledPoint](1000)
for (i <- 0 until 1000) {
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
arr(i) = lp
}
arr
}

def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
def generateOrderedLabeledPointsWithLabel1(): Array[WeightedLabeledPoint] = {
val arr = new Array[WeightedLabeledPoint](1000)
for (i <- 0 until 1000) {
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
arr(i) = lp
}
arr
}

def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
def generateOrderedLabeledPoints(): Array[WeightedLabeledPoint] = {
val arr = new Array[WeightedLabeledPoint](1000)
for (i <- 0 until 1000) {
if (i < 600) {
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
arr(i) = lp
} else {
val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
Expand All @@ -462,11 +576,11 @@ object DecisionTreeSuite {
arr
}

def generateCategoricalDataPoints(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
def generateCategoricalDataPoints(): Array[WeightedLabeledPoint] = {
val arr = new Array[WeightedLabeledPoint](1000)
for (i <- 0 until 1000) {
if (i < 600) {
arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(0.0, 1.0))
} else {
arr(i) = new WeightedLabeledPoint(0.0, Vectors.dense(1.0, 0.0))
}
Expand Down

0 comments on commit 3f85a17

Please sign in to comment.