diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 9cc7c494f9d64..bba67bb5894a9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.tree -import scala.util.control.Breaks._ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.tree.model._ @@ -29,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} +import scala.util.control.Breaks._ /** * A class that implements a decision tree algorithm for classification and regression. It @@ -181,8 +181,8 @@ object DecisionTree extends Serializable with Logging { input: RDD[LabeledPoint], algo: Algo, impurity: Impurity, - maxDepth: Int - ): DecisionTreeModel = { + maxDepth: Int) + : DecisionTreeModel = { val strategy = new Strategy(algo,impurity,maxDepth) new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } @@ -211,8 +211,8 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int] - ): DecisionTreeModel = { + categoricalFeaturesInfo: Map[Int,Int]) + : DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).train(input: RDD[LabeledPoint]) @@ -238,7 +238,8 @@ object DecisionTree extends Serializable with Logging { level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], - bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + bins: Array[Array[Bin]]) + 
: Array[(Split, InformationGainStats)] = { //Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt @@ -301,7 +302,8 @@ object DecisionTree extends Serializable with Logging { def findBin( featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean): Int = { + isFeatureContinuous: Boolean) + : Int = { if (isFeatureContinuous){ for (binIndex <- 0 until strategy.numBins) { @@ -515,7 +517,8 @@ object DecisionTree extends Serializable with Logging { featureIndex: Int, splitIndex: Int, rightNodeAgg: Array[Array[Double]], - topImpurity: Double): InformationGainStats = { + topImpurity: Double) + : InformationGainStats = { strategy.algo match { case Classification => { @@ -694,7 +697,8 @@ object DecisionTree extends Serializable with Logging { def calculateGainsForAllNodeSplits( leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], - nodeImpurity: Double): Array[Array[InformationGainStats]] = { + nodeImpurity: Double) + : Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) @@ -715,7 +719,8 @@ object DecisionTree extends Serializable with Logging { */ def binsToBestSplit( binData: Array[Double], - nodeImpurity: Double): (Split, InformationGainStats) = { + nodeImpurity: Double) + : (Split, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) @@ -786,7 +791,8 @@ object DecisionTree extends Serializable with Logging { */ def findSplitsBins( input: RDD[LabeledPoint], - strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + strategy: Strategy) + : (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() @@ -947,12 +953,11 @@ object DecisionTree extends Serializable with Logging { } val options = nextOption(Map(),arglist) logDebug(options.toString()) - //TODO: Add validation for input parameters //Load training data val trainData = 
loadLabeledData(sc, options.get('trainDataDir).get.toString) - //Figure out the type of algorithm + //Identify the type of algorithm val algoStr = options.get('algo).get.toString val algo = algoStr match { case "Classification" => Classification @@ -1007,7 +1012,10 @@ object DecisionTree extends Serializable with Logging { } } - //TODO: Port them to a metrics package + //TODO: Port this method to a generic metrics package + /** + * Calculates the classifier accuracy. + */ def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { val correctCount = data.filter(y => model.predict(y.features) == y.label).count() val count = data.count() @@ -1016,7 +1024,10 @@ object DecisionTree extends Serializable with Logging { correctCount.toDouble / count } - //TODO: Make these generic MLTable metrics + //TODO: Port this method to a generic metrics package + /** + * Calculates the mean squared error for regression. + */ def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label))