Skip to content

Commit

Permalink
TreePoint
Browse files Browse the repository at this point in the history
* Updated doc
* Made some methods private

Changed timer to report time in seconds.
  • Loading branch information
jkbradley committed Aug 14, 2014
1 parent 8464a6e commit e66f1b1
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
Expand Down Expand Up @@ -59,8 +59,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

timer.start("total")

timer.start("init")
// Cache input RDD for speedup during multiple passes.
timer.start("init")
val retaggedInput = input.retag(classOf[LabeledPoint])
logDebug("algo = " + strategy.algo)
timer.stop("init")
Expand All @@ -74,7 +74,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("numBins = " + numBins)

timer.start("init")
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache()
timer.stop("init")

// depth of the decision tree
Expand All @@ -90,7 +90,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
// num features
val numFeatures = retaggedInput.take(1)(0).features.size
val numFeatures = treeInput.take(1)(0).features.size

// Calculate level for single group construction

Expand Down Expand Up @@ -118,10 +118,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* still survived the filters of the parent nodes.
*/

var findBestSplitsTime: Long = 0
var extractNodeInfoTime: Long = 0
var extractInfoForLowerLevelsTime: Long = 0

var level = 0
var break = false
while (level <= maxDepth && !break) {
Expand Down Expand Up @@ -618,8 +614,6 @@ object DecisionTree extends Serializable with Logging {
true
}

// TODO: REMOVED findBin()

/**
* Finds bins for all nodes (and all features) at a given level.
* For l nodes, k features the storage is as follows:
Expand Down Expand Up @@ -664,11 +658,9 @@ object DecisionTree extends Serializable with Logging {
arr
}

timer.start("findBinsForLevel")

// Find feature bins for all nodes at a level.
timer.start("findBinsForLevel")
val binMappedRDD = input.map(x => findBinsForLevel(x))

timer.stop("findBinsForLevel")

/**
Expand Down Expand Up @@ -1126,7 +1118,6 @@ object DecisionTree extends Serializable with Logging {

val rightChildShift = numClasses * numBins * numFeatures
var splitIndex = 0
var TMPDEBUG = 0.0
while (splitIndex < numBins - 1) {
var classIndex = 0
while (classIndex < numClasses) {
Expand All @@ -1136,7 +1127,6 @@ object DecisionTree extends Serializable with Logging {
val rightBinValue = binData(rightChildShift + shift + classIndex)
leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
TMPDEBUG += leftBinValue + rightBinValue
classIndex += 1
}
splitIndex += 1
Expand Down Expand Up @@ -1344,9 +1334,8 @@ object DecisionTree extends Serializable with Logging {
}
}

timer.start("chooseSplits")

// Calculate best splits for all nodes at a given level
timer.start("chooseSplits")
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
// Iterating over all nodes at this level
var node = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ class TimeTracker extends Serializable {
}

/**
* Stops a timer and returns the elapsed time in nanoseconds.
* Stops a timer and returns the elapsed time in seconds.
*/
def stop(timerLabel: String): Long = {
def stop(timerLabel: String): Double = {
val tmpTime = System.nanoTime()
if (!starts.contains(timerLabel)) {
throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" +
Expand All @@ -60,16 +60,16 @@ class TimeTracker extends Serializable {
} else {
totals(timerLabel) = elapsed
}
elapsed
elapsed / 1e9
}

/**
* Print all timing results.
* Print all timing results in seconds.
*/
override def toString: String = {
s"Timing\n" +
totals.map { case (label, elapsed) =>
s" $label: $elapsed"
s" $label: ${elapsed / 1e9}"
}.mkString("\n")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,36 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.model.Bin
import org.apache.spark.rdd.RDD


/**
* Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
* of size (numFeatures, numBins).
* TODO: ADD DOC
* Internal representation of LabeledPoint for DecisionTree.
 * This bins feature values based on a subsample of the data as follows:
* (a) Continuous features are binned into ranges.
* (b) Unordered categorical features are binned based on subsets of feature values.
* "Unordered categorical features" are categorical features with low arity used in
* multiclass classification.
* (c) Ordered categorical features are binned based on feature values.
* "Ordered categorical features" are categorical features with high arity,
* or any categorical feature used in regression or binary classification.
*
* @param label Label from LabeledPoint
* @param features Binned feature values.
* Same length as LabeledPoint.features, but values are bin indices.
*/
/**
 * Internal representation of a LabeledPoint used by DecisionTree training.
 *
 * @param label    Label copied from the original LabeledPoint.
 * @param features Binned feature values: same length as LabeledPoint.features,
 *                 but each element is a bin index rather than a raw feature value.
 */
private[tree] class TreePoint(val label: Double, val features: Array[Int]) extends Serializable {
}


private[tree] object TreePoint {

/**
* Convert an input dataset into its TreePoint representation,
* binning feature values in preparation for DecisionTree training.
* @param input Input dataset.
* @param strategy DecisionTree training info, used for dataset metadata.
* @param bins Bins for features, of size (numFeatures, numBins).
* @return TreePoint dataset representation
*/
def convertToTreeRDD(
input: RDD[LabeledPoint],
strategy: Strategy,
Expand All @@ -42,7 +62,12 @@ private[tree] object TreePoint {
}
}

def labeledPointToTreePoint(
/**
* Convert one LabeledPoint into its TreePoint representation.
* @param bins Bins for features, of size (numFeatures, numBins).
* @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
*/
private def labeledPointToTreePoint(
labeledPoint: LabeledPoint,
isMulticlassClassification: Boolean,
bins: Array[Array[Bin]],
Expand Down Expand Up @@ -77,16 +102,11 @@ private[tree] object TreePoint {
/**
* Find bin for one (labeledPoint, feature).
*
* @param featureIndex
* @param labeledPoint
* @param isFeatureContinuous
* @param isUnorderedFeature (only applies if feature is categorical)
* @param bins Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
* of size (numFeatures, numBins).
* @param categoricalFeaturesInfo
* @return
* @param bins Bins for features, of size (numFeatures, numBins).
* @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
*/
def findBin(
private def findBin(
featureIndex: Int,
labeledPoint: LabeledPoint,
isFeatureContinuous: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)

val model = DecisionTree.train(input, strategy)
validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)

val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Expand All @@ -710,11 +715,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val gain = bestSplits(0)._2
assert(gain.leftImpurity === 0)
assert(gain.rightImpurity === 0)

val model = DecisionTree.train(input, strategy)
validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
}

test("stump with continuous variables for multiclass classification") {
Expand Down

0 comments on commit e66f1b1

Please sign in to comment.