diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 9e461cfdbbd08..7c9b4796ed62b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -34,14 +34,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. */ -class Strategy ( - val algo: Algo, - val impurity: Impurity, - val maxDepth: Int, - val maxBins: Int = 100, - val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable { +class Strategy ( + val algo: Algo, + val impurity: Impurity, + val maxDepth: Int, + val maxBins: Int = 100, + val quantileCalculationStrategy: QuantileStrategy = Sort, + val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable { var numBins: Int = Int.MinValue - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 8832d7a6929a9..b93995fcf9441 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.tree.impurity -import java.lang.UnsupportedOperationException - /** * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during * binary classification. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 3f043125a6aba..c0407554a91b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -17,32 +17,30 @@ package org.apache.spark.mllib.tree.impurity -import java.lang.UnsupportedOperationException - /** - * Class for calculating the [[http://en.wikipedia - * .org/wiki/Decision_tree_learning#Gini_impurity]] during binary classification + * Class for calculating the + * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] + * during binary classification. */ object Gini extends Impurity { /** - * gini coefficient calculation + * Gini coefficient calculation * @param c0 count of instances with label 0 * @param c1 count of instances with label 1 - * @return gini coefficient value + * @return Gini coefficient value */ - def calculate(c0 : Double, c1 : Double): Double = { + override def calculate(c0: Double, c1: Double): Double = { if (c0 == 0 || c1 == 0) { 0 } else { val total = c0 + c1 val f0 = c0 / total val f1 = c1 / total - 1 - f0*f0 - f1*f1 + 1 - f0 * f0 - f1 * f1 } } def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 97092c85aea61..a4069063af2ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree.impurity /** - * Trail for calculating information gain + * Trait for calculating information gain. */ trait Impurity extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 35b1c4e5c3727..b74577dcec167 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -17,25 +17,21 @@ package org.apache.spark.mllib.tree.impurity -import java.lang.UnsupportedOperationException - /** * Class for calculating variance during regression */ object Variance extends Impurity { - def calculate(c0: Double, c1: Double): Double - = throw new UnsupportedOperationException("Variance.calculate") + override def calculate(c0: Double, c1: Double): Double = + throw new UnsupportedOperationException("Variance.calculate") /** * variance calculation * @param count number of instances * @param sum sum of labels * @param sumSquares summation of squares of the labels - * @return */ - def calculate(count: Double, sum: Double, sumSquares: Double): Double = { - val squaredLoss = sumSquares - (sum*sum)/count - squaredLoss/count + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + val squaredLoss = sumSquares - (sum * sum) / count + squaredLoss / count } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 47afe3aed2b1b..a57faa13745f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -30,6 +30,4 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param featureType type of feature -- categorical or continuous * @param category categorical label value accepted in the bin */ -case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) { - -} +case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a056da77641ee..a8bbf21daec01 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -46,6 +46,4 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable def predict(features: RDD[Array[Double]]): RDD[Double] = { features.map(x => predict(x)) } - - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 64ff826486f5b..99bf79cf12e45 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -36,6 +36,4 @@ class InformationGainStats( "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" .format(gain, impurity, leftImpurity, rightImpurity, predict) } - - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index c3e5c00c8d53c..ea4693c5c2f4e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -37,7 +37,7 @@ class Node ( val split: Option[Split], var leftNode: Option[Node], var rightNode: Option[Node], - val stats: Option[InformationGainStats]) extends Serializable with Logging{ + val stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + "split = " + split + ", stats = " + stats @@ -46,7 +46,7 @@ class Node ( * build the left node and right nodes if not leaf * @param nodes array of nodes */ - def build(nodes : Array[Node]): Unit = { + def build(nodes: Array[Node]): Unit = { logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) @@ -68,7 +68,7 @@ class Node ( * @param feature feature value * @return predicted value */ - def predictIfLeaf(feature : Array[Double]) : Double = { + def predictIfLeaf(feature: Array[Double]) : Double = { if (isLeaf) { predict } else{ @@ -87,5 +87,4 @@ class Node ( } } } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index fffd68d7a64b5..4e64a81dda74e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -42,7 +42,7 @@ case class Split( * @param feature feature index * @param featureType type of feature -- categorical or continuous */ -class DummyLowSplit(feature: Int, featureType : FeatureType) +class DummyLowSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MinValue, featureType, List()) /** @@ -50,7 +50,7 @@ class DummyLowSplit(feature: Int, featureType : FeatureType) * @param feature feature index * @param featureType type of feature -- categorical or continuous */ -class DummyHighSplit(feature: Int, featureType : FeatureType) +class DummyHighSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) /** @@ -59,6 +59,6 @@ class DummyHighSplit(feature: Int, featureType : FeatureType) * @param feature feature index * @param featureType type of feature -- categorical or continuous */ -class DummyCategoricalSplit(feature: Int, featureType : FeatureType) +class DummyCategoricalSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2dfcdd857b504..a359bf3a76ce1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -41,246 +41,254 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { System.clearProperty("spark.driver.port") } - test("split and bin calculation"){ + test("split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(bins.length==2) - assert(splits(0).length==99) - assert(bins(0).length==100) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) } - test("split and bin calculation for categorical variables"){ + test("split and bin calculation for categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2, - 1-> 2)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(bins.length==2) - assert(splits(0).length==99) - assert(bins(0).length==100) - - //Checking splits - - assert(splits(0)(0).feature == 0) - assert(splits(0)(0).threshold == Double.MinValue) - assert(splits(0)(0).featureType == Categorical) - assert(splits(0)(0).categories.length == 1) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(0)(1).feature == 0) - assert(splits(0)(1).threshold == Double.MinValue) - assert(splits(0)(1).featureType == Categorical) - assert(splits(0)(1).categories.length == 2) + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) assert(splits(0)(1).categories.contains(1.0)) assert(splits(0)(1).categories.contains(0.0)) - assert(splits(0)(2) == null) + assert(splits(0)(2) === null) - assert(splits(1)(0).feature == 1) - assert(splits(1)(0).threshold == Double.MinValue) - assert(splits(1)(0).featureType == Categorical) - assert(splits(1)(0).categories.length == 1) + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) assert(splits(1)(0).categories.contains(0.0)) - - assert(splits(1)(1).feature == 1) - assert(splits(1)(1).threshold == Double.MinValue) - assert(splits(1)(1).featureType == Categorical) - assert(splits(1)(1).categories.length == 2) + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) assert(splits(1)(1).categories.contains(1.0)) assert(splits(1)(1).categories.contains(0.0)) - assert(splits(1)(2) == null) - + assert(splits(1)(2) === null) - // Checks bins + // Check bins. - assert(bins(0)(0).category == 1.0) - assert(bins(0)(0).lowSplit.categories.length == 0) - assert(bins(0)(0).highSplit.categories.length == 1) + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) assert(bins(0)(0).highSplit.categories.contains(1.0)) - assert(bins(0)(1).category == 0.0) - assert(bins(0)(1).lowSplit.categories.length == 1) + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) assert(bins(0)(1).lowSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.length == 2) + assert(bins(0)(1).highSplit.categories.length === 2) assert(bins(0)(1).highSplit.categories.contains(1.0)) assert(bins(0)(1).highSplit.categories.contains(0.0)) - assert(bins(0)(2) == null) + assert(bins(0)(2) === null) - assert(bins(1)(0).category == 0.0) - assert(bins(1)(0).lowSplit.categories.length == 0) - assert(bins(1)(0).highSplit.categories.length == 1) + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) assert(bins(1)(0).highSplit.categories.contains(0.0)) - assert(bins(1)(1).category == 1.0) - assert(bins(1)(1).lowSplit.categories.length == 1) + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length == 2) + assert(bins(1)(1).highSplit.categories.length === 2) assert(bins(1)(1).highSplit.categories.contains(0.0)) assert(bins(1)(1).highSplit.categories.contains(1.0)) - assert(bins(1)(2) == null) - + assert(bins(1)(2) === null) } - test("split and bin calculations for categorical variables with no sample for one category"){ + test("split and bin calculations for categorical variables with no sample for one category") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, - 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - - //Checking splits - - assert(splits(0)(0).feature == 0) - assert(splits(0)(0).threshold == Double.MinValue) - assert(splits(0)(0).featureType == Categorical) - assert(splits(0)(0).categories.length == 1) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) assert(splits(0)(0).categories.contains(1.0)) - assert(splits(0)(1).feature == 0) - assert(splits(0)(1).threshold == Double.MinValue) - assert(splits(0)(1).featureType == Categorical) - assert(splits(0)(1).categories.length == 2) + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) assert(splits(0)(1).categories.contains(1.0)) assert(splits(0)(1).categories.contains(0.0)) - assert(splits(0)(2).feature == 0) - assert(splits(0)(2).threshold == Double.MinValue) - assert(splits(0)(2).featureType == Categorical) - assert(splits(0)(2).categories.length == 3) + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 3) assert(splits(0)(2).categories.contains(1.0)) assert(splits(0)(2).categories.contains(0.0)) assert(splits(0)(2).categories.contains(2.0)) - assert(splits(0)(3) == null) + assert(splits(0)(3) === null) - assert(splits(1)(0).feature == 1) - assert(splits(1)(0).threshold == Double.MinValue) - assert(splits(1)(0).featureType == Categorical) - assert(splits(1)(0).categories.length == 1) + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) assert(splits(1)(0).categories.contains(0.0)) - assert(splits(1)(1).feature == 1) - assert(splits(1)(1).threshold == Double.MinValue) - assert(splits(1)(1).featureType == Categorical) - assert(splits(1)(1).categories.length == 2) + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) assert(splits(1)(1).categories.contains(1.0)) assert(splits(1)(1).categories.contains(0.0)) - assert(splits(1)(2).feature == 1) - assert(splits(1)(2).threshold == Double.MinValue) - assert(splits(1)(2).featureType == Categorical) - assert(splits(1)(2).categories.length == 3) + assert(splits(1)(2).feature === 1) + assert(splits(1)(2).threshold === Double.MinValue) + assert(splits(1)(2).featureType === Categorical) + assert(splits(1)(2).categories.length === 3) assert(splits(1)(2).categories.contains(1.0)) assert(splits(1)(2).categories.contains(0.0)) assert(splits(1)(2).categories.contains(2.0)) - assert(splits(1)(3) == null) + assert(splits(1)(3) === null) + // Check bins. - // Checks bins - - assert(bins(0)(0).category == 1.0) - assert(bins(0)(0).lowSplit.categories.length == 0) - assert(bins(0)(0).highSplit.categories.length == 1) + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) assert(bins(0)(0).highSplit.categories.contains(1.0)) - assert(bins(0)(1).category == 0.0) - assert(bins(0)(1).lowSplit.categories.length == 1) + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) assert(bins(0)(1).lowSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.length == 2) + assert(bins(0)(1).highSplit.categories.length === 2) assert(bins(0)(1).highSplit.categories.contains(1.0)) assert(bins(0)(1).highSplit.categories.contains(0.0)) - assert(bins(0)(2).category == 2.0) - assert(bins(0)(2).lowSplit.categories.length == 2) + assert(bins(0)(2).category === 2.0) + assert(bins(0)(2).lowSplit.categories.length === 2) assert(bins(0)(2).lowSplit.categories.contains(1.0)) assert(bins(0)(2).lowSplit.categories.contains(0.0)) - assert(bins(0)(2).highSplit.categories.length == 3) + assert(bins(0)(2).highSplit.categories.length === 3) assert(bins(0)(2).highSplit.categories.contains(1.0)) assert(bins(0)(2).highSplit.categories.contains(0.0)) assert(bins(0)(2).highSplit.categories.contains(2.0)) - assert(bins(0)(3) == null) + assert(bins(0)(3) === null) - assert(bins(1)(0).category == 0.0) - assert(bins(1)(0).lowSplit.categories.length == 0) - assert(bins(1)(0).highSplit.categories.length == 1) + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) assert(bins(1)(0).highSplit.categories.contains(0.0)) - assert(bins(1)(1).category == 1.0) - assert(bins(1)(1).lowSplit.categories.length == 1) + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length == 2) + assert(bins(1)(1).highSplit.categories.length === 2) assert(bins(1)(1).highSplit.categories.contains(0.0)) assert(bins(1)(1).highSplit.categories.contains(1.0)) - assert(bins(1)(2).category == 2.0) - assert(bins(1)(2).lowSplit.categories.length == 2) + assert(bins(1)(2).category === 2.0) + assert(bins(1)(2).lowSplit.categories.length === 2) assert(bins(1)(2).lowSplit.categories.contains(0.0)) assert(bins(1)(2).lowSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.length == 3) + assert(bins(1)(2).highSplit.categories.length === 3) assert(bins(1)(2).highSplit.categories.contains(0.0)) assert(bins(1)(2).highSplit.categories.contains(1.0)) assert(bins(1)(2).highSplit.categories.contains(2.0)) - assert(bins(1)(3) == null) - - + assert(bins(1)(3) === null) } - test("classification stump with all categorical variables"){ + test("classification stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, - 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins) val split = bestSplits(0)._1 - assert(split.categories.length == 1) + assert(split.categories.length === 1) assert(split.categories.contains(1.0)) - assert(split.featureType == Categorical) - assert(split.threshold == Double.MinValue) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) val stats = bestSplits(0)._2 assert(stats.gain > 0) assert(stats.predict > 0.4) assert(stats.predict < 0.5) assert(stats.impurity > 0.2) - } - test("regression stump with all categorical variables"){ + test("regression stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, - 1-> 3)) + val strategy = new Strategy( + Regression, + Variance, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins) val split = bestSplits(0)._1 - assert(split.categories.length == 1) + assert(split.categories.length === 1) assert(split.categories.contains(1.0)) - assert(split.featureType == Categorical) - assert(split.threshold == Double.MinValue) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) val stats = bestSplits(0)._2 assert(stats.gain > 0) @@ -289,110 +297,104 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(stats.impurity > 0.2) } - - test("stump with fixed label 0 for Gini"){ + test("stump with fixed label 0 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(splits(0).length==99) - assert(bins.length==2) - assert(bins(0).length==100) - assert(splits(0).length==99) - assert(bins(0).length==100) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins) - assert(bestSplits.length == 1) - assert(0==bestSplits(0)._1.feature) - assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2.gain) - assert(0==bestSplits(0)._2.leftImpurity) - assert(0==bestSplits(0)._2.rightImpurity) - + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) } - test("stump with fixed label 1 for Gini"){ + test("stump with fixed label 1 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(splits(0).length==99) - assert(bins.length==2) - assert(bins(0).length==100) - assert(splits(0).length==99) - assert(bins(0).length==100) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins) - assert(bestSplits.length == 1) - assert(0==bestSplits(0)._1.feature) - assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2.gain) - assert(0==bestSplits(0)._2.leftImpurity) - assert(0==bestSplits(0)._2.rightImpurity) - assert(1==bestSplits(0)._2.predict) - + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) } - - test("stump with fixed label 0 for Entropy"){ + test("stump with fixed label 0 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Entropy,3,100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(splits(0).length==99) - assert(bins.length==2) - assert(bins(0).length==100) - assert(splits(0).length==99) - assert(bins(0).length==100) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins) - assert(bestSplits.length == 1) - assert(0==bestSplits(0)._1.feature) - assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2.gain) - assert(0==bestSplits(0)._2.leftImpurity) - assert(0==bestSplits(0)._2.rightImpurity) - assert(0==bestSplits(0)._2.predict) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 0) } - test("stump with fixed label 1 for Entropy"){ + test("stump with fixed label 1 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() - assert(arr.length == 1000) + assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Entropy,3,100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - assert(splits.length==2) - assert(splits(0).length==99) - assert(bins.length==2) - assert(bins(0).length==100) - assert(splits(0).length==99) - assert(bins(0).length==100) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins) - assert(bestSplits.length == 1) - assert(0==bestSplits(0)._1.feature) - assert(10==bestSplits(0)._1.threshold) - assert(0==bestSplits(0)._2.gain) - assert(0==bestSplits(0)._2.leftImpurity) - assert(0==bestSplits(0)._2.rightImpurity) - assert(1==bestSplits(0)._2.predict) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) } - - } object DecisionTreeSuite { @@ -406,7 +408,6 @@ object DecisionTreeSuite { arr } - def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ @@ -427,5 +428,4 @@ object DecisionTreeSuite { } arr } - }