Skip to content

Commit

Permalink
minor improvements to docs and style
Browse files: browse the repository at this point in the history
  • Loading branch information
manishamde committed Mar 10, 2014
1 parent eb8fcbe commit 794ff4d
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.mllib.tree

import scala.util.control.Breaks._
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.tree.model._
Expand All @@ -29,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity}
import scala.util.control.Breaks._

/**
* A class that implements a decision tree algorithm for classification and regression. It
Expand Down Expand Up @@ -181,8 +181,8 @@ object DecisionTree extends Serializable with Logging {
input: RDD[LabeledPoint],
algo: Algo,
impurity: Impurity,
maxDepth: Int
): DecisionTreeModel = {
maxDepth: Int)
: DecisionTreeModel = {
val strategy = new Strategy(algo,impurity,maxDepth)
new DecisionTree(strategy).train(input: RDD[LabeledPoint])
}
Expand Down Expand Up @@ -211,8 +211,8 @@ object DecisionTree extends Serializable with Logging {
maxDepth: Int,
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int,Int]
): DecisionTreeModel = {
categoricalFeaturesInfo: Map[Int,Int])
: DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
categoricalFeaturesInfo)
new DecisionTree(strategy).train(input: RDD[LabeledPoint])
Expand All @@ -238,7 +238,8 @@ object DecisionTree extends Serializable with Logging {
level: Int,
filters: Array[List[Filter]],
splits: Array[Array[Split]],
bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = {
bins: Array[Array[Bin]])
: Array[(Split, InformationGainStats)] = {

//Common calculations for multiple nested methods
val numNodes = scala.math.pow(2, level).toInt
Expand Down Expand Up @@ -301,7 +302,8 @@ object DecisionTree extends Serializable with Logging {
def findBin(
featureIndex: Int,
labeledPoint: LabeledPoint,
isFeatureContinuous: Boolean): Int = {
isFeatureContinuous: Boolean)
: Int = {

if (isFeatureContinuous){
for (binIndex <- 0 until strategy.numBins) {
Expand Down Expand Up @@ -515,7 +517,8 @@ object DecisionTree extends Serializable with Logging {
featureIndex: Int,
splitIndex: Int,
rightNodeAgg: Array[Array[Double]],
topImpurity: Double): InformationGainStats = {
topImpurity: Double)
: InformationGainStats = {

strategy.algo match {
case Classification => {
Expand Down Expand Up @@ -694,7 +697,8 @@ object DecisionTree extends Serializable with Logging {
def calculateGainsForAllNodeSplits(
leftNodeAgg: Array[Array[Double]],
rightNodeAgg: Array[Array[Double]],
nodeImpurity: Double): Array[Array[InformationGainStats]] = {
nodeImpurity: Double)
: Array[Array[InformationGainStats]] = {

val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)

Expand All @@ -715,7 +719,8 @@ object DecisionTree extends Serializable with Logging {
*/
def binsToBestSplit(
binData: Array[Double],
nodeImpurity: Double): (Split, InformationGainStats) = {
nodeImpurity: Double)
: (Split, InformationGainStats) = {

logDebug("node impurity = " + nodeImpurity)
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
Expand Down Expand Up @@ -786,7 +791,8 @@ object DecisionTree extends Serializable with Logging {
*/
def findSplitsBins(
input: RDD[LabeledPoint],
strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
strategy: Strategy)
: (Array[Array[Split]], Array[Array[Bin]]) = {

val count = input.count()

Expand Down Expand Up @@ -947,12 +953,11 @@ object DecisionTree extends Serializable with Logging {
}
val options = nextOption(Map(),arglist)
logDebug(options.toString())
//TODO: Add validation for input parameters

//Load training data
val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)

//Figure out the type of algorithm
//Identify the type of algorithm
val algoStr = options.get('algo).get.toString
val algo = algoStr match {
case "Classification" => Classification
Expand Down Expand Up @@ -1007,7 +1012,10 @@ object DecisionTree extends Serializable with Logging {
}
}

//TODO: Port them to a metrics package
//TODO: Port this method to a generic metrics package
/**
* Calculates the classifier accuracy: the fraction of points in `data` whose predicted
* label equals the actual label.
*/
def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
val count = data.count()
Expand All @@ -1016,7 +1024,10 @@ object DecisionTree extends Serializable with Logging {
correctCount.toDouble / count
}

//TODO: Make these generic MLTable metrics
//TODO: Port this method to a generic metrics package
/**
* Calculates the mean squared error for regression.
*/
def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
val meanSumOfSquares =
data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label))
Expand Down

0 comments on commit 794ff4d

Please sign in to comment.