diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 059a9336b5f9e..085832cf12070 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.tree
 
+import scala.util.control.Breaks._
+
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.tree.model._
@@ -28,7 +30,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity}
-import scala.util.control.Breaks._
+import java.util.Random
+import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * A class that implements a decision tree algorithm for classification and regression. It
@@ -48,32 +51,32 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
    */
   def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
 
-    //Cache input RDD for speedup during multiple passes
+    // Cache input RDD for speedup during multiple passes
     input.cache()
     logDebug("algo = " + strategy.algo)
 
-    //Finding the splits and the corresponding bins (interval between the splits) using a sample
+    // Finding the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
     logDebug("numSplits = " + bins(0).length)
 
-    //Noting numBins for the input data
+    // Noting numBins for the input data
     strategy.numBins = bins(0).length
 
-    //The depth of the decision tree
+    // The depth of the decision tree
     val maxDepth = strategy.maxDepth
 
-    //The max number of nodes possible given the depth of the tree
+    // The max number of nodes possible given the depth of the tree
     val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1
 
-    //Initalizing an array to hold filters applied to points for each node
+    // Initializing an array to hold filters applied to points for each node
    val filters = new Array[List[Filter]](maxNumNodes)
 
-    //The filter at the top node is an empty list
+    // The filter at the top node is an empty list
    filters(0) = List()
 
-    //Initializing an array to hold parent impurity calculations for each node
+    // Initializing an array to hold parent impurity calculations for each node
    val parentImpurities = new Array[Double](maxNumNodes)
 
-    //Dummy value for top node (updated during first split calculation)
+    // Dummy value for top node (updated during first split calculation)
    val nodes = new Array[Node](maxNumNodes)
 
-    //The main-idea here is to perform level-wise training of the decision tree nodes thus
+    // The main idea here is to perform level-wise training of the decision tree nodes thus
     // reducing the passes over the data from l to log2(l) where l is the total number of nodes.
     // Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
     // the sample is only used for the split calculation at the node if the sampled would have
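
The comment above is the heart of the patch's design: nodes live in a flat array indexed in level order, so one pass over the data can compute best splits for every node at a level. A small standalone sketch of the indexing this relies on (illustrative only, not part of the patch):

    // Level-order indexing for a complete binary tree: level d occupies
    // indices [2^d - 1, 2^(d+1) - 1), so maxDepth passes cover all nodes
    // instead of one pass per node.
    object LevelIndexing {
      def nodesAtLevel(level: Int): Range =
        (math.pow(2, level).toInt - 1) until (math.pow(2, level + 1).toInt - 1)

      def main(args: Array[String]): Unit = {
        (0 until 3).foreach { level =>
          println(s"level $level -> nodes ${nodesAtLevel(level).mkString(", ")}")
        }
        // level 0 -> nodes 0
        // level 1 -> nodes 1, 2
        // level 2 -> nodes 3, 4, 5, 6
      }
    }
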
@@ -85,21 +88,21 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
       logDebug("level = " + level)
       logDebug("#####################################")
 
-      //Find best split for all nodes at a level
+      // Find best split for all nodes at a level
       val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
         level, filters, splits, bins)
 
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
-        //Extract info for nodes at the current level
+        // Extract info for nodes at the current level
         extractNodeInfo(nodeSplitStats, level, index, nodes)
-        //Extract info for nodes at the next lower level
+        // Extract info for nodes at the next lower level
         extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
           filters)
         logDebug("final best split = " + nodeSplitStats._1)
       }
       require(scala.math.pow(2, level) == splitsStatsForLevel.length)
 
-      //Check whether all the nodes at the current level at leaves
+      // Check whether all the nodes at the current level are leaves
       val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
       logDebug("all leaf = " + allLeaf)
       if (allLeaf) break //no more tree construction
@@ -107,12 +110,12 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
       }
     }
 
-    //Initialize the top or root node of the tree
+    // Initialize the top or root node of the tree
     val topNode = nodes(0)
 
-    //Build the full tree using the node info calculated in the level-wise best split calculations
+    // Build the full tree using the node info calculated in the level-wise best split calculations
     topNode.build(nodes)
 
-    //Return a decision tree model
+    // Return a decision tree model
     return new DecisionTreeModel(topNode, strategy.algo)
   }
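
The next hunk leans on the same flat layout: the children of the node at position `index` within `level` sit at 2^(level+1) - 1 + 2*index + i for i in {0 (left), 1 (right)}, and (nodeIndex - 1) / 2 walks back to the parent, which is how `filters` chains a child's filter onto its parent's list. A minimal check of that arithmetic, for illustration only:

    // Child/parent index arithmetic used by extractInfoForLowerLevels below.
    object NodeArithmetic {
      def childIndex(level: Int, index: Int, i: Int): Int =
        math.pow(2, level + 1).toInt - 1 + 2 * index + i

      def parentIndex(nodeIndex: Int): Int = (nodeIndex - 1) / 2

      def main(args: Array[String]): Unit = {
        val (level, index) = (1, 0)
        val globalIndex = math.pow(2, level).toInt - 1 + index // node 1
        for (i <- 0 to 1) {
          val child = childIndex(level, index, i) // 3, then 4
          assert(parentIndex(child) == globalIndex)
        }
      }
    }
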
@@ -149,7 +152,7 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
     // 0 corresponds to the left child node and 1 corresponds to the right child node.
     for (i <- 0 to 1) {
-      //Calculating the index of the node from the node level and the index at the current level
+      // Calculating the index of the node from the node level and the index at the current level
       val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
       if (level < maxDepth - 1) {
         val impurity = if (i == 0) {
@@ -158,9 +161,9 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
           nodeSplitStats._2.rightImpurity
         }
         logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
-        //noting the parent impurities
+        // noting the parent impurities
         parentImpurities(nodeIndex) = impurity
-        //noting the parents filters for the child nodes
+        // noting the parents filters for the child nodes
         val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
         filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
         for (filter <- filters(nodeIndex)) {
@@ -236,6 +239,8 @@ object DecisionTree extends Serializable with Logging {
     new DecisionTree(strategy).train(input: RDD[LabeledPoint])
   }
 
+  val InvalidBinIndex = -1
+
   /**
    * Returns an array of optimal splits for all nodes at a given level
    * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
@@ -259,16 +264,16 @@ object DecisionTree extends Serializable with Logging {
       bins: Array[Array[Bin]]) : Array[(Split, InformationGainStats)] = {
 
-    //Common calculations for multiple nested methods
+    // Common calculations for multiple nested methods
     val numNodes = scala.math.pow(2, level).toInt
     logDebug("numNodes = " + numNodes)
 
-    //Find the number of features by looking at the first sample
-    val numFeatures = input.take(1)(0).features.length
+    // Find the number of features by looking at the first sample
+    val numFeatures = input.first().features.length
     logDebug("numFeatures = " + numFeatures)
 
     val numBins = strategy.numBins
     logDebug("numBins = " + numBins)
 
-    /*Find the filters used before reaching the current code*/
+    /** Find the filters used before reaching the current code */
     def findParentFilters(nodeIndex: Int): List[Filter] = {
       if (level == 0) {
         List[Filter]()
@@ -284,7 +289,7 @@ object DecisionTree extends Serializable with Logging {
      */
     def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
-      //Leaf
+      // Leaf
       if ((level > 0) & (parentFilters.length == 0) ){
         return false
       }
@@ -360,59 +365,52 @@ object DecisionTree extends Serializable with Logging {
      */
     def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
-
-      // calculating bin index and label per feature per node
       val arr = new Array[Double](1 + (numFeatures * numNodes))
       arr(0) = labeledPoint.label
       for (nodeIndex <- 0 until numNodes) {
         val parentFilters = findParentFilters(nodeIndex)
-        //Find out whether the sample qualifies for the particular node
+        // Find out whether the sample qualifies for the particular node
         val sampleValid = isSampleValid(parentFilters, labeledPoint)
         val shift = 1 + numFeatures * nodeIndex
         if (!sampleValid) {
-          //Add to invalid bin index -1
-          breakable {
-            for (featureIndex <- 0 until numFeatures) {
-              arr(shift + featureIndex) = -1
-              //Breaking since marking one bin is sufficient
-              break()
-            }
-          }
+          // marking one bin as -1 is sufficient
+          arr(shift) = InvalidBinIndex
         } else {
-          for (featureIndex <- 0 until numFeatures) {
-            //logDebug("shift+featureIndex =" + (shift+featureIndex))
-            val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-            arr(shift + featureIndex) = findBin(featureIndex,
-              labeledPoint,isFeatureContinous)
+          var featureIndex = 0
+          while (featureIndex < numFeatures) {
+            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+            arr(shift + featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous)
+            featureIndex += 1
+          }
         }
-      }
       arr
     }
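
findBinsForLevel flattens each sample into 1 + numFeatures * numNodes doubles: the label first, then one bin index per (node, feature) slot; writing InvalidBinIndex into a node's first slot is enough to mark the sample invalid for that node, which is exactly what the simplified branch above does. A toy illustration of the layout (the sizes here are made up):

    // Per-sample array layout produced by findBinsForLevel:
    //   arr(0) = label
    //   arr(1 + numFeatures * node + feature) = bin index for that node/feature
    //   arr(1 + numFeatures * node) == InvalidBinIndex => sample skips the node
    object BinArrayLayout {
      val InvalidBinIndex = -1

      def main(args: Array[String]): Unit = {
        val (numFeatures, numNodes) = (3, 2)
        val arr = new Array[Double](1 + numFeatures * numNodes)
        arr(0) = 1.0                                   // label
        for (f <- 0 until numFeatures) arr(1 + f) = f  // bins for node 0
        arr(1 + numFeatures * 1) = InvalidBinIndex     // sample invalid for node 1
        println(arr(1 + numFeatures * 1) != InvalidBinIndex) // false
      }
    }
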
 
     /**
-      Performs a sequential aggregation over a partition for classification.
-
-      for p bins, k features, l nodes (level = log2(l)) storage is of the form:
-      b111_left_count,b111_right_count, .... , ..
-      .. bpk1_left_count, bpk1_right_count, .... , ..
-      .. bpkl_left_count, bpkl_right_count
-
-      @param agg Array[Double] storing aggregate calculation of size
-      2*numSplits*numFeatures*numNodes for classification
-      @param arr Array[Double] of size 1+(numFeatures*numNodes)
-      @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes
-      for classification
-      */
+     * Performs a sequential aggregation over a partition for classification.
+     *
+     * for p bins, k features, l nodes (level = log2(l)) storage is of the form:
+     * b111_left_count,b111_right_count, .... , ..
+     * .. bpk1_left_count, bpk1_right_count, .... , ..
+     * .. bpkl_left_count, bpkl_right_count
+     *
+     * @param agg Array[Double] storing aggregate calculation of size
+     * 2*numSplits*numFeatures*numNodes for classification
+     * @param arr Array[Double] of size 1+(numFeatures*numNodes)
+     * @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes
+     * for classification
+     */
     def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
-      for (node <- 0 until numNodes) {
-        val validSignalIndex = 1 + numFeatures * node
-        val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false
+      for (nodeIndex <- 0 until numNodes) {
+        val validSignalIndex = 1 + numFeatures * nodeIndex
+        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
         if (isSampleValidForNode) {
           val label = arr(0)
           for (featureIndex <- 0 until numFeatures) {
-            val arrShift = 1 + numFeatures * node
-            val aggShift = 2 * numBins * numFeatures * node
+            val arrShift = 1 + numFeatures * nodeIndex
+            val aggShift = 2 * numBins * numFeatures * nodeIndex
             val arrIndex = arrShift + featureIndex
             val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
             label match {
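
For classification, the aggregate is a flat array holding a (label-0 count, label-1 count) pair per (node, feature, bin); the aggIndex expression above is just the offset into that array. A sketch of the same arithmetic with arbitrary example sizes:

    // Offset of the (label-0 count, label-1 count) pair for (node, feature, bin)
    // in the flat classification aggregate; mirrors
    // aggShift + 2 * featureIndex * numBins + binIndex * 2 in the patch.
    object ClassificationAggIndex {
      def aggIndex(node: Int, feature: Int, bin: Int, numBins: Int, numFeatures: Int): Int =
        2 * numBins * numFeatures * node + 2 * feature * numBins + 2 * bin

      def main(args: Array[String]): Unit = {
        val (numBins, numFeatures, numNodes) = (4, 3, 2)
        val agg = new Array[Double](2 * numBins * numFeatures * numNodes)
        val idx = aggIndex(1, 2, 3, numBins, numFeatures) // slot 46 of 48
        agg(idx) += 1     // one more label-0 sample in this bin
        agg(idx + 1) += 1 // one more label-1 sample in this bin
        println(s"index = $idx, length = ${agg.length}")
      }
    }
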
@@ -425,28 +423,28 @@ object DecisionTree extends Serializable with Logging {
     }
 
     /**
-      Performs a sequential aggregation over a partition for regression.
-
-      for p bins, k features, l nodes (level = log2(l)) storage is of the form:
-      b111_count,b111_sum, b111_sum_squares .... , ..
-      .. bpk1_count, bpk1_sum, bpk1_sum_squares, .... , ..
-      .. bpkl_count, bpkl_sum, bpkl_sum_squares
-
-      @param agg Array[Double] storing aggregate calculation of size
-      3*numSplits*numFeatures*numNodes for classification
-      @param arr Array[Double] of size 1+(numFeatures*numNodes)
-      @return Array[Double] storing aggregate calculation of size 3*numSplits*numFeatures*numNodes
-      for regression
-      */
+     * Performs a sequential aggregation over a partition for regression.
+     *
+     * for p bins, k features, l nodes (level = log2(l)) storage is of the form:
+     * b111_count,b111_sum, b111_sum_squares .... , ..
+     * .. bpk1_count, bpk1_sum, bpk1_sum_squares, .... , ..
+     * .. bpkl_count, bpkl_sum, bpkl_sum_squares
+     *
+     * @param agg Array[Double] storing aggregate calculation of size
+     * 3*numSplits*numFeatures*numNodes for regression
+     * @param arr Array[Double] of size 1+(numFeatures*numNodes)
+     * @return Array[Double] storing aggregate calculation of size
+     * 3*numSplits*numFeatures*numNodes for regression
+     */
     def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) {
-      for (node <- 0 until numNodes) {
-        val validSignalIndex = 1 + numFeatures * node
-        val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false
+      for (nodeIndex <- 0 until numNodes) {
+        val validSignalIndex = 1 + numFeatures * nodeIndex
+        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
         if (isSampleValidForNode) {
           val label = arr(0)
           for (feature <- 0 until numFeatures) {
-            val arrShift = 1 + numFeatures * node
-            val aggShift = 3 * numBins * numFeatures * node
+            val arrShift = 1 + numFeatures * nodeIndex
+            val aggShift = 3 * numBins * numFeatures * nodeIndex
             val arrIndex = arrShift + feature
             val aggIndex = aggShift + 3 * feature * numBins + arr(arrIndex).toInt * 3
             //count, sum, sum^2
@@ -513,13 +511,12 @@ object DecisionTree extends Serializable with Logging {
     logDebug("input = " + input.count)
     val binMappedRDD = input.map(x => findBinsForLevel(x))
     logDebug("binMappedRDD.count = " + binMappedRDD.count)
 
-    //calculate bin aggregates
+    // calculate bin aggregates
     val binAggregates = {
       binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
     }
     logDebug("binAggregates.length = " + binAggregates.length)
-    //binAggregates.foreach(x => logDebug(x))
 
     /**
      * Calculates the information gain for all splits
@@ -578,7 +575,6 @@ object DecisionTree extends Serializable with Logging {
         }
       }
 
-      //val predict = leftCount / (leftCount + rightCount)
       val predict = (left1Count + right1Count) / (leftCount + rightCount)
 
       new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
@@ -640,7 +636,8 @@ object DecisionTree extends Serializable with Logging {
      * Array[Double]) where each array is of size(numFeature,2*(numSplits-1))
      */
     def extractLeftRightNodeAggregates(
-        binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
+        binData: Array[Double])
+      : (Array[Array[Double]], Array[Array[Double]]) = {
       strategy.algo match {
         case Classification => {
@@ -747,7 +744,7 @@ object DecisionTree extends Serializable with Logging {
       val (bestFeatureIndex,bestSplitIndex, gainStats) = {
         var bestFeatureIndex = 0
         var bestSplitIndex = 0
-        //Initialization with infeasible values
+        // Initialization with infeasible values
         var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1)
         for (featureIndex <- 0 until numFeatures) {
           for (splitIndex <- 0 until numBins - 1){
@@ -767,7 +764,7 @@ object DecisionTree extends Serializable with Logging {
       (splits(bestFeatureIndex)(bestSplitIndex),gainStats)
     }
 
-    //Calculate best splits for all nodes at a given level
+    // Calculate best splits for all nodes at a given level
     val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
     def getBinDataForNode(node: Int): Array[Double] = {
       strategy.algo match {
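
The regression path above keeps a (count, sum, sum of squares) triplet per bin; triplets add component-wise, so left/right statistics for any candidate split come from prefix sums over bins, and the prediction and variance impurity fall out directly. A standalone sketch of that arithmetic, for illustration:

    // (count, sum, sumSquares) is sufficient for the regression statistics:
    // mean = sum / count, variance = sumSquares / count - mean^2, and
    // triplets for disjoint bins merge by component-wise addition.
    object RegressionTriplet {
      final case class Triplet(count: Double, sum: Double, sumSquares: Double) {
        def +(other: Triplet): Triplet =
          Triplet(count + other.count, sum + other.sum, sumSquares + other.sumSquares)
        def mean: Double = sum / count
        def variance: Double = sumSquares / count - mean * mean
      }

      def main(args: Array[String]): Unit = {
        val left = Triplet(2, 3.0, 5.0)   // labels 1.0 and 2.0
        val right = Triplet(1, 4.0, 16.0) // label 4.0
        val all = left + right
        println(s"predict = ${all.mean}, impurity = ${all.variance}")
        // predict = 2.333..., impurity = 1.555...
      }
    }
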
logDebug("maxBins = " + numBins) - //Calculate the number of sample for approximate quantile calculation + + // Calculate the number of sample for approximate quantile calculation val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 logDebug("fraction of data used for calculating quantiles = " + fraction) - //sampled input for RDD calculation - val sampledInput = input.sample(false, fraction, 42).collect() + + // sampled input for RDD calculation + val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length - val stride: Double = numSamples.toDouble/numBins + val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) strategy.quantileCalculationStrategy match { case Sort => { - val splits = Array.ofDim[Split](numFeatures,numBins-1) - val bins = Array.ofDim[Bin](numFeatures,numBins) + val splits = Array.ofDim[Split](numFeatures, numBins-1) + val bins = Array.ofDim[Bin](numFeatures, numBins) //Find all splits for (featureIndex <- 0 until numFeatures){ @@ -860,10 +858,10 @@ object DecisionTree extends Serializable with Logging { = sampledInput.map(lp => (lp.features(featureIndex),lp.label)) .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length) - //Checking for missing categorical variables + // Checking for missing categorical variables val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until maxFeatureValue){ - if (centriodForCategories.contains(i)){ + for (i <- 0 until maxFeatureValue) { + if (centriodForCategories.contains(i)) { fullCentriodForCategories(i) = centriodForCategories(i) } else { fullCentriodForCategories(i) = Double.MaxValue @@ -871,14 +869,14 @@ object DecisionTree extends Serializable with Logging { } val categoriesSortedByCentriod - = fullCentriodForCategories.toList sortBy {_._2} + = fullCentriodForCategories.toList.sortBy{_._2} logDebug("centriod for categorical variable = " + categoriesSortedByCentriod) var categoriesForSplit = List[Double]() categoriesSortedByCentriod.iterator.zipWithIndex foreach { case((key, value), index) => { - categoriesForSplit = key:: categoriesForSplit + categoriesForSplit = key :: categoriesForSplit splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) bins(featureIndex)(index) = { @@ -896,10 +894,10 @@ object DecisionTree extends Serializable with Logging { } } - //Find all bins + // Find all bins for (featureIndex <- 0 until numFeatures){ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { //bins for categorical variables are already assigned + if (isFeatureContinuous) { // bins for categorical variables are already assigned bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0), Continuous,Double.MinValue) @@ -933,7 +931,7 @@ object DecisionTree extends Serializable with Logging { val usage = """ - Usage: DecisionTreeRunner[slices] --algo [slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] """ @@ -948,12 +946,10 @@ object DecisionTree extends Serializable with Logging { val sc = new SparkContext(args(0), "DecisionTree") - - val arglist = args.toList.drop(1) + val argList = args.toList.drop(1) type OptionMap = Map[Symbol, Any] def nextOption(map : OptionMap, list: List[String]): OptionMap = { - def isSwitch(s : String) = (s(0) 
     strategy.quantileCalculationStrategy match {
       case Sort => {
-        val splits = Array.ofDim[Split](numFeatures,numBins-1)
-        val bins = Array.ofDim[Bin](numFeatures,numBins)
+        val splits = Array.ofDim[Split](numFeatures, numBins-1)
+        val bins = Array.ofDim[Bin](numFeatures, numBins)
 
         //Find all splits
         for (featureIndex <- 0 until numFeatures){
@@ -860,10 +858,10 @@ object DecisionTree extends Serializable with Logging {
             = sampledInput.map(lp => (lp.features(featureIndex),lp.label))
               .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length)
 
-          //Checking for missing categorical variables
+          // Checking for missing categorical variables
           val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]()
-          for (i <- 0 until maxFeatureValue){
-            if (centriodForCategories.contains(i)){
+          for (i <- 0 until maxFeatureValue) {
+            if (centriodForCategories.contains(i)) {
               fullCentriodForCategories(i) = centriodForCategories(i)
             } else {
               fullCentriodForCategories(i) = Double.MaxValue
@@ -871,14 +869,14 @@ object DecisionTree extends Serializable with Logging {
             }
           }
 
           val categoriesSortedByCentriod
-            = fullCentriodForCategories.toList sortBy {_._2}
+            = fullCentriodForCategories.toList.sortBy{_._2}
           logDebug("centriod for categorical variable = " + categoriesSortedByCentriod)
 
           var categoriesForSplit = List[Double]()
           categoriesSortedByCentriod.iterator.zipWithIndex foreach {
             case((key, value), index) => {
-              categoriesForSplit = key:: categoriesForSplit
+              categoriesForSplit = key :: categoriesForSplit
               splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
                 categoriesForSplit)
               bins(featureIndex)(index) = {
@@ -896,10 +894,10 @@ object DecisionTree extends Serializable with Logging {
         }
       }
 
-      //Find all bins
+      // Find all bins
       for (featureIndex <- 0 until numFeatures){
         val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-        if (isFeatureContinuous) { //bins for categorical variables are already assigned
+        if (isFeatureContinuous) { // bins for categorical variables are already assigned
           bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
             splits(featureIndex)(0), Continuous,Double.MinValue)
@@ -933,7 +931,7 @@ object DecisionTree extends Serializable with Logging {
 
     val usage = """
-      Usage: DecisionTreeRunner<master>[slices] --algo <Classification,Regression>
+      Usage: DecisionTreeRunner <master>[slices] --algo <Classification,Regression>
         --trainDataDir path --testDataDir path --maxDepth num [--impurity <Gini,Entropy,Variance>]
         [--maxBins num]
       """
@@ -948,12 +946,10 @@ object DecisionTree extends Serializable with Logging {
 
     val sc = new SparkContext(args(0), "DecisionTree")
 
-
-    val arglist = args.toList.drop(1)
+    val argList = args.toList.drop(1)
     type OptionMap = Map[Symbol, Any]
 
     def nextOption(map : OptionMap, list: List[String]): OptionMap = {
-      def isSwitch(s : String) = (s(0) == '-')
       list match {
         case Nil => map
         case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail)
@@ -969,20 +965,20 @@ object DecisionTree extends Serializable with Logging {
           sys.exit(1)
         }
       }
-    val options = nextOption(Map(),arglist)
+    val options = nextOption(Map(),argList)
     logDebug(options.toString())
 
-    //Load training data
+    // Load training data
     val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)
 
-    //Identify the type of algorithm
+    // Identify the type of algorithm
     val algoStr = options.get('algo).get.toString
     val algo = algoStr match {
       case "Classification" => Classification
       case "Regression" => Regression
     }
 
-    //Identify the type of impurity
+    // Identify the type of impurity
     val impurityStr = options.getOrElse('impurity,
       if (algo == Classification) "Gini" else "Variance").toString
     val impurity = impurityStr match {
@@ -994,19 +990,22 @@ object DecisionTree extends Serializable with Logging {
 
     val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt
     val maxBins = options.getOrElse('maxBins,"100").toString.toInt
 
-    val strategy = new Strategy(algo = algo, impurity = impurity, maxDepth = maxDepth,
-      maxBins = maxBins)
-    val model = DecisionTree.train(trainData,strategy)
+    val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
+    val model = DecisionTree.train(trainData, strategy)
 
-    //Load test data
+    // Load test data
     val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
 
-    //Measure algorithm accuracy
-    val accuracy = accuracyScore(model, testData)
-    logDebug("accuracy = " + accuracy)
+    // Measure algorithm accuracy
+    if (algo == Classification) {
+      val accuracy = accuracyScore(model, testData)
+      logDebug("accuracy = " + accuracy)
+    }
 
-    val mse = meanSquaredError(model,testData)
-    logDebug("mean square error = " + mse)
+    if (algo == Regression) {
+      val mse = meanSquaredError(model, testData)
+      logDebug("mean square error = " + mse)
+    }
 
     sc.stop()
   }
@@ -1030,7 +1029,7 @@ object DecisionTree extends Serializable with Logging {
     }
   }
 
-  //TODO: Port this method to a generic metrics package
+  // TODO: Port this method to a generic metrics package
   /**
    * Calculates the classifier accuracy.
    */
@@ -1042,14 +1041,11 @@ object DecisionTree extends Serializable with Logging {
     correctCount.toDouble / count
   }
 
-  //TODO: Port this method to a generic metrics package
+  // TODO: Port this method to a generic metrics package
   /**
    * Calculates the mean squared error for regression
    */
   def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
-    val meanSumOfSquares =
-      data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label))
-        .mean()
-    meanSumOfSquares
+    data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean()
   }
 }
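
One remaining nit in the simplified meanSquaredError: it still calls tree.predict twice per record. If that cost matters, a local binding keeps the one-liner spirit with a single prediction each time. A possible variant, assuming the same DecisionTreeModel API as above:

    // Hypothetical alternative to the meanSquaredError above: one predict()
    // call per record instead of two.
    def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
      data.map { y =>
        val err = tree.predict(y.features) - y.label
        err * err
      }.mean()
    }
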