diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index e7a55d52e7367..664abf742d4a1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -133,7 +133,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 3,
       numClassesForClassification = 2,
       maxBins = 100,
-      categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
 
     // Check splits.
@@ -483,7 +483,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.predict === 1)
   }
 
-  test("test second level node building with/without groups") {
+  test("second level node building with/without groups") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -529,6 +529,36 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
   }
 
+  // Placeholder: reported as pending (not as passing) until this case is implemented.
+  test("stump with continuous variables for multiclass classification") {
+    pending
+  }
+
+  // Runs findBestSplits on the multiclass categorical fixture and checks the one split returned.
+  test("stump with categorical variables for multiclass classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    // NOTE(review): 31 is presumably 2^maxDepth - 1 node slots; confirm against findBestSplits.
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+    assert(bestSplit.feature === 0)
+    assert(bestSplit.categories.length === 1)
+    assert(bestSplit.categories.contains(0))
+    assert(bestSplit.featureType === Categorical)
+  }
+
+  // Placeholder: reported as pending (not as passing) until this case is implemented.
+  test("stump with continuous + categorical variables for multiclass classification") {
+    pending
+  }
+
 }
 
 object DecisionTreeSuite {
@@ -576,4 +606,23 @@ object DecisionTreeSuite {
     }
     arr
   }
+
+  /**
+   * Builds the 3000-element fixture used by the multiclass categorical stump test.
+   * Indices 1000 until 2000 hold WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0));
+   * every other index holds WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0)).
+   */
+  def generateCategoricalDataPointsForMulticlass(): Array[WeightedLabeledPoint] = {
+    val arr = new Array[WeightedLabeledPoint](3000)
+    for (i <- 0 until 3000) {
+      // The first and last thirds hold the same point, so only the middle third is special.
+      if (i >= 1000 && i < 2000) {
+        arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+      } else {
+        arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+      }
+    }
+    arr
+  }
+
 }