diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index aaa5a4fef6697..1c813244e5630 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impurity.Impurity
 
 /**
  A class that implements a decision tree algorithm for classification and regression.
@@ -38,7 +39,7 @@ algorithm (classification, regression, etc.), feature type (continuous,
  categorical), depth of the tree, quantile calculation strategy, etc.
  */
 
-class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
+class DecisionTree private (val strategy : Strategy) extends Serializable with Logging {
 
   /**
    Method to train a decision tree model over an RDD
@@ -157,6 +158,70 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
 
 object DecisionTree extends Serializable with Logging {
 
+  /**
+   Method to train a decision tree model over an RDD
+
+   @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+   for DecisionTree
+   @param strategy The configuration parameters for the tree algorithm which specify the type of algorithm
+   (classification, regression, etc.), feature type (continuous, categorical),
+   depth of the tree, quantile calculation strategy, etc.
+   @return a DecisionTreeModel that can be used for prediction
+  */
+  def train(input : RDD[LabeledPoint], strategy : Strategy) : DecisionTreeModel = {
+    new DecisionTree(strategy).train(input : RDD[LabeledPoint])
+  }
+
+  /**
+   Method to train a decision tree model over an RDD
+
+   @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+   for DecisionTree
+   @param algo classification or regression
+   @param impurity criterion used for information gain calculation
+   @param maxDepth maximum depth of the tree
+   @return a DecisionTreeModel that can be used for prediction
+  */
+  def train(
+      input : RDD[LabeledPoint],
+      algo : Algo,
+      impurity : Impurity,
+      maxDepth : Int
+      ) : DecisionTreeModel = {
+    val strategy = new Strategy(algo,impurity,maxDepth)
+    new DecisionTree(strategy).train(input : RDD[LabeledPoint])
+  }
+
+
+  /**
+   Method to train a decision tree model over an RDD
+
+   @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+   for DecisionTree
+   @param algo classification or regression
+   @param impurity criterion used for information gain calculation
+   @param maxDepth maximum depth of the tree
+   @param maxBins maximum number of bins used for splitting features
+   @param quantileCalculationStrategy algorithm for calculating quantiles
+   @param categoricalFeaturesInfo A map storing information about the categorical variables and the number of discrete
+   values they take. For example, an entry (n -> k) implies the feature n is
+   categorical with k categories 0, 1, 2, ... , k-1. It's important to note that
+   features are zero-indexed.
+   @return a DecisionTreeModel that can be used for prediction
+  */
+  def train(
+      input : RDD[LabeledPoint],
+      algo : Algo,
+      impurity : Impurity,
+      maxDepth : Int,
+      maxBins : Int,
+      quantileCalculationStrategy : QuantileStrategy,
+      categoricalFeaturesInfo : Map[Int,Int]
+      ) : DecisionTreeModel = {
+    val strategy = new Strategy(algo,impurity,maxDepth,maxBins,quantileCalculationStrategy,categoricalFeaturesInfo)
+    new DecisionTree(strategy).train(input : RDD[LabeledPoint])
+  }
+
   /**
     Returns an Array[Split] of optimal splits for all nodes at a given level
@@ -717,13 +782,13 @@ object DecisionTree extends Serializable with Logging {
     for DecisionTree
     @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing parameters
     for construction the DecisionTree
-    @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,
-    numSplits-1) and bins is an
-    Array[Array[Bin]] of size (numFeatures,numSplits1)
+    @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree.model.Split] of
+    size (numFeatures,numSplits-1) and bins is an Array of [org.apache.spark.mllib.tree.model.Bin] of
+    size (numFeatures,numSplits1)
    */
   def findSplitsBins(
       input : RDD[LabeledPoint],
-      strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = {
+      strategy : Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
     val count = input.count()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala
index 05d000f3a3ddc..d93633d26228d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala
@@ -87,7 +87,7 @@ object DecisionTreeRunner extends Logging {
     val maxBins = options.getOrElse('maxBins,"100").toString.toInt
     val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth,
       maxBins = maxBins)
-    val model = new DecisionTree(strategy).train(trainData)
+    val model = DecisionTree.train(trainData,strategy)
 
     //Load test data
     val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 973aaee49e5fb..88dfa76fc284f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -37,7 +37,7 @@ class Strategy (
     val algo : Algo,
     val impurity : Impurity,
     val maxDepth : Int,
-    val maxBins : Int,
+    val maxBins : Int = 100,
     val quantileCalculationStrategy : QuantileStrategy = Sort,
     val categoricalFeaturesInfo : Map[Int,Int] = Map[Int,Int]()) extends Serializable {
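For reference, here is a minimal sketch of how a caller would use the companion-object API introduced by this patch. The choice of Gini impurity, maxDepth = 5, maxBins = 100, and the categorical-feature map Map(0 -> 3) are illustrative assumptions, not part of the patch; any RDD[LabeledPoint] can serve as the training input.

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD

object DecisionTreeTrainExample {
  // `data` can be any RDD[LabeledPoint]; how it is loaded is outside the scope of this patch.
  def trainModels(data: RDD[LabeledPoint]): DecisionTreeModel = {
    // Strategy-based overload; maxBins now defaults to 100 and can be omitted here.
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5)
    val modelFromStrategy = DecisionTree.train(data, strategy)

    // Convenience overload that builds the Strategy internally from (algo, impurity, maxDepth).
    val modelFromParams = DecisionTree.train(data, Classification, Gini, 5)

    // Full overload: 100 bins, Sort quantile strategy, and feature 0 treated as a
    // categorical feature with 3 values (the map is purely illustrative).
    val modelFull = DecisionTree.train(data, Classification, Gini, 5, 100, Sort, Map(0 -> 3))

    modelFull
  }
}

Because the DecisionTree constructor is now private, these overloads become the public entry points, which is why DecisionTreeRunner above switches from new DecisionTree(strategy).train(trainData) to DecisionTree.train(trainData,strategy).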