diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index dd766c12d28a4..2070fa7efd664 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -738,12 +738,15 @@ object DecisionTree extends Serializable with Logging { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count - val totalCount = leftCount + rightCount - if (totalCount == 0) { - // Return arbitrary prediction. - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + // If left child or right child doesn't satisfy minimum instances per node, + // then this split is invalid, return invalid information gain stats + if ((leftCount < metadata.minInstancesPerNode) || + (rightCount < metadata.minInstancesPerNode)) { + return InformationGainStats.invalidInformationGainStats } + val totalCount = leftCount + rightCount + val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) // impurity of parent node @@ -763,6 +766,9 @@ object DecisionTree extends Serializable with Logging { val rightWeight = rightCount / totalCount.toDouble val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + if (gain < metadata.minInfoGain) { + return InformationGainStats.invalidInformationGainStats + } new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) } @@ -807,6 +813,9 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) + if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { + (Split.noSplit, InformationGainStats.invalidInformationGainStats) + } (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (metadata.isUnordered(featureIndex)) { // Unordered categorical feature @@ -820,6 +829,9 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) + if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { + (Split.noSplit, InformationGainStats.invalidInformationGainStats) + } (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature @@ -891,6 +903,9 @@ object DecisionTree extends Serializable with Logging { calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) + if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) { + (Split.noSplit, InformationGainStats.invalidInformationGainStats) + } val categoriesForSplit = categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) val bestFeatureSplit = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index cfc8192a85abd..48b958fb95975 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -61,6 +61,8 @@ class Strategy ( val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + val minInstancesPerNode: Int = 0, + val minInfoGain: Double = 0.0, val maxMemoryInMB: Int = 128) extends Serializable { if (algo == Classification) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index e95add7558bcf..5ceaa8154d11a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata( val unorderedFeatures: Set[Int], val numBins: Array[Int], val impurity: Impurity, - val quantileStrategy: QuantileStrategy) extends Serializable { + val quantileStrategy: QuantileStrategy, + val minInstancesPerNode: Int, + val minInfoGain: Double) extends Serializable { def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) @@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata { new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, - strategy.impurity, strategy.quantileCalculationStrategy) + strategy.impurity, strategy.quantileCalculationStrategy, + strategy.minInstancesPerNode, strategy.minInfoGain) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index fb12298e0f5d3..dce9d2ec8a5f2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -43,3 +43,8 @@ class InformationGainStats( .format(gain, impurity, leftImpurity, rightImpurity, predict, prob) } } + + +private[tree] object InformationGainStats { + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, 0.0) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 50fb48b40de3d..da1aefb01c6ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType +import org.apache.spark.mllib.tree.configuration.FeatureType +import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType /** * :: DeveloperApi :: @@ -66,3 +68,7 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType) private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) + +private[tree] object Split { + val noSplit = new Split(-1, Double.MinValue, FeatureType.Continuous, List()) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 8e556c917b2e7..bb3d1e03c69a4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} +import org.apache.spark.mllib.tree.model.{Split, DecisionTreeModel, Node} import org.apache.spark.mllib.util.LocalSparkContext @@ -684,6 +684,45 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.6) } + test("split must satisfy min instances per node requirements") { + val arr = new Array[LabeledPoint](3) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, minInstancesPerNode = 4) + + val model = DecisionTree.train(input, strategy) + assert(model.topNode.isLeaf) + assert(model.topNode.predict == 0.0) + assert(model.topNode.split.get == Split.noSplit) + val predicts = input.map(p => model.predict(p.features)).collect() + predicts.foreach { predict => + assert(predict == 0.0) + } + } + + test("split must satisfy min info gain requirements") { + val arr = new Array[LabeledPoint](3) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, minInfoGain = 1.0) + + val model = DecisionTree.train(input, strategy) + assert(model.topNode.isLeaf) + assert(model.topNode.predict == 0.0) + assert(model.topNode.split.get == Split.noSplit) + val predicts = input.map(p => model.predict(p.features)).collect() + predicts.foreach { predict => + assert(predict == 0.0) + } + } } object DecisionTreeSuite {