diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 72c3ab475b61f..4683e6eb966be 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -55,6 +55,8 @@ object DecisionTreeRunner { maxDepth: Int = 5, impurity: ImpurityType = Gini, maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, fracTest: Double = 0.2) def main(args: Array[String]) { @@ -75,6 +77,13 @@ object DecisionTreeRunner { opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) opt[Double]("fracTest") .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) @@ -179,7 +188,9 @@ object DecisionTreeRunner { impurity = impurityCalculator, maxDepth = params.maxDepth, maxBins = params.maxBins, - numClassesForClassification = numClasses) + numClassesForClassification = numClasses, + minInstancesPerNode = params.minInstancesPerNode, + minInfoGain = params.minInfoGain) val model = DecisionTree.train(training, strategy) println(model) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 4343124f102a0..fa0fa69f38634 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable { categoricalFeaturesInfoJMap: java.util.Map[Int, Int], impurityStr: String, maxDepth: Int, - maxBins: Int): DecisionTreeModel = { + maxBins: Int, + minInstancesPerNode: Int, + minInfoGain: Double): DecisionTreeModel = { val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) @@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable { maxDepth = maxDepth, numClassesForClassification = numClasses, maxBins = maxBins, - categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap) + categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap, + minInstancesPerNode = minInstancesPerNode, + minInfoGain = minInfoGain) DecisionTree.train(data, strategy) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 56bb8812100a7..c7f2576c822b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -389,7 +389,7 @@ object DecisionTree extends Serializable with Logging { var groupIndex = 0 var doneTraining = true while (groupIndex < numGroups) { - val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level, + val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer, numGroups, groupIndex) doneTraining = doneTraining && doneTrainingGroup groupIndex += 1 @@ -898,7 +898,7 @@ object DecisionTree extends Serializable with Logging { } }.maxBy(_._2.gain) - require(predict.isDefined, "must calculate predict for each node") + assert(predict.isDefined, "must calculate predict for each node") (bestSplit, bestSplitStats, predict.get) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 31d1e8ac30eea..caaccbfb8ad16 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -77,6 +77,8 @@ class Strategy ( } require(minInstancesPerNode >= 1, s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + require(maxMemoryInMB <= 10240, + s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") val isMulticlassClassification = algo == Classification && numClassesForClassification > 2 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 6fac2be2797bc..d8476b5cd7bc7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -17,18 +17,14 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi - /** - * :: DeveloperApi :: * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) */ -@DeveloperApi private[tree] class Predict( val predict: Double, - val prob: Double = 0.0) extends Serializable{ + val prob: Double = 0.0) extends Serializable { override def toString = { "predict = %f, prob = %f".format(predict, prob) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 1bd7ea05c46c8..2b2e579b992f6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -714,8 +714,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(gain == InformationGainStats.invalidInformationGainStats) } - test("don't choose split that doesn't satisfy min instance per node requirements") { - // if a split doesn't satisfy min instances per node requirements, + test("do not choose split that does not satisfy min instance per node requirements") { + // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index ccc000ac70ba6..5b13ab682bbfc 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -138,7 +138,8 @@ class DecisionTree(object): @staticmethod def trainClassifier(data, numClasses, categoricalFeaturesInfo, - impurity="gini", maxDepth=5, maxBins=32): + impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, + minInfoGain=0.0): """ Train a DecisionTreeModel for classification. @@ -154,6 +155,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. + :param minInstancesPerNode: Min number of instances required at child nodes to create + the parent split + :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel """ sc = data.context @@ -164,13 +168,14 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( dataBytes._jrdd, "classification", numClasses, categoricalFeaturesInfoJMap, - impurity, maxDepth, maxBins) + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) dataBytes.unpersist() return DecisionTreeModel(sc, model) @staticmethod def trainRegressor(data, categoricalFeaturesInfo, - impurity="variance", maxDepth=5, maxBins=32): + impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, + minInfoGain=0.0): """ Train a DecisionTreeModel for regression. @@ -185,6 +190,9 @@ def trainRegressor(data, categoricalFeaturesInfo, E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. + :param minInstancesPerNode: Min number of instances required at child nodes to create + the parent split + :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel """ sc = data.context @@ -195,7 +203,7 @@ def trainRegressor(data, categoricalFeaturesInfo, model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( dataBytes._jrdd, "regression", 0, categoricalFeaturesInfoJMap, - impurity, maxDepth, maxBins) + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) dataBytes.unpersist() return DecisionTreeModel(sc, model)