Skip to content

Commit

Permalink
[SPARK-3516] [mllib] DecisionTree: Add minInstancesPerNode, minInfoGain params to example and Python API
Browse files Browse the repository at this point in the history

Added minInstancesPerNode, minInfoGain params to:
* DecisionTreeRunner.scala example
* Python API (tree.py)

Also:
* Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements"
* small style fixes

CC: mengxr

Author: qiping.lqp <qiping.lqp@alibaba-inc.com>
Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>
Author: chouqin <liqiping1991@gmail.com>

Closes #2349 from jkbradley/chouqin-dt-preprune and squashes the following commits:

61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy.
a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes
e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune
f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
efcc736 [qiping.lqp] fix bug
10b8012 [qiping.lqp] fix style
6728fad [qiping.lqp] minor fix: remove empty lines
bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune
cadd569 [qiping.lqp] add api docs
46b891f [qiping.lqp] fix bug
e72c7e4 [qiping.lqp] add comments
845c6fa [qiping.lqp] fix style
f195e83 [qiping.lqp] fix style
987cbf4 [qiping.lqp] fix bug
ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain
ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
  • Loading branch information
qiping.lqp authored and mengxr committed Sep 16, 2014
1 parent 983d6a9 commit fdb302f
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ object DecisionTreeRunner {
maxDepth: Int = 5,
impurity: ImpurityType = Gini,
maxBins: Int = 32,
minInstancesPerNode: Int = 1,
minInfoGain: Double = 0.0,
fracTest: Double = 0.2)

def main(args: Array[String]) {
Expand All @@ -75,6 +77,13 @@ object DecisionTreeRunner {
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
opt[Int]("minInstancesPerNode")
.text(s"min number of instances required at child nodes to create the parent split," +
s" default: ${defaultParams.minInstancesPerNode}")
.action((x, c) => c.copy(minInstancesPerNode = x))
opt[Double]("minInfoGain")
.text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
.action((x, c) => c.copy(minInfoGain = x))
opt[Double]("fracTest")
.text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
Expand Down Expand Up @@ -179,7 +188,9 @@ object DecisionTreeRunner {
impurity = impurityCalculator,
maxDepth = params.maxDepth,
maxBins = params.maxBins,
numClassesForClassification = numClasses)
numClassesForClassification = numClasses,
minInstancesPerNode = params.minInstancesPerNode,
minInfoGain = params.minInfoGain)
val model = DecisionTree.train(training, strategy)

println(model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable {
categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
impurityStr: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
maxBins: Int,
minInstancesPerNode: Int,
minInfoGain: Double): DecisionTreeModel = {

val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)

Expand All @@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable {
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)

DecisionTree.train(data, strategy)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ object DecisionTree extends Serializable with Logging {
var groupIndex = 0
var doneTraining = true
while (groupIndex < numGroups) {
val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
topNode, splits, bins, timer, numGroups, groupIndex)
doneTraining = doneTraining && doneTrainingGroup
groupIndex += 1
Expand Down Expand Up @@ -898,7 +898,7 @@ object DecisionTree extends Serializable with Logging {
}
}.maxBy(_._2.gain)

require(predict.isDefined, "must calculate predict for each node")
assert(predict.isDefined, "must calculate predict for each node")

(bestSplit, bestSplitStats, predict.get)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class Strategy (
}
require(minInstancesPerNode >= 1,
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
require(maxMemoryInMB <= 10240,
s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")

val isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,14 @@

package org.apache.spark.mllib.tree.model

import org.apache.spark.annotation.DeveloperApi

/**
* :: DeveloperApi ::
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
@DeveloperApi
private[tree] class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable{
val prob: Double = 0.0) extends Serializable {

override def toString = {
"predict = %f, prob = %f".format(predict, prob)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(gain == InformationGainStats.invalidInformationGainStats)
}

test("don't choose split that doesn't satisfy min instance per node requirements") {
// if a split doesn't satisfy min instances per node requirements,
test("do not choose split that does not satisfy min instance per node requirements") {
// if a split does not satisfy min instances per node requirements,
// this split is invalid, even though the information gain of split is large.
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
Expand Down
16 changes: 12 additions & 4 deletions python/pyspark/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ class DecisionTree(object):

@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
impurity="gini", maxDepth=5, maxBins=32):
impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0):
"""
Train a DecisionTreeModel for classification.
Expand All @@ -154,6 +155,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
:param minInstancesPerNode: Min number of instances required at child nodes to create
the parent split
:param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
Expand All @@ -164,13 +168,14 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "classification",
numClasses, categoricalFeaturesInfoJMap,
impurity, maxDepth, maxBins)
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)

@staticmethod
def trainRegressor(data, categoricalFeaturesInfo,
impurity="variance", maxDepth=5, maxBins=32):
impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0):
"""
Train a DecisionTreeModel for regression.
Expand All @@ -185,6 +190,9 @@ def trainRegressor(data, categoricalFeaturesInfo,
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
:param minInstancesPerNode: Min number of instances required at child nodes to create
the parent split
:param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
"""
sc = data.context
Expand All @@ -195,7 +203,7 @@ def trainRegressor(data, categoricalFeaturesInfo,
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "regression",
0, categoricalFeaturesInfoJMap,
impurity, maxDepth, maxBins)
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)

Expand Down

0 comments on commit fdb302f

Please sign in to comment.