Skip to content

Commit

Permalink
[Fix #79] Replace Breakable For Loops By While Loops
Browse files Browse the repository at this point in the history
Author: Sandeep <sandeep@techaddict.me>

Closes #503 from techaddict/fix-79 and squashes the following commits:

e3f6746 [Sandeep] Style changes
07a4f6b [Sandeep] for loop to While loop
0a6d8e9 [Sandeep] Breakable for loop to While loop
  • Loading branch information
techaddict authored and rxin committed Apr 24, 2014
1 parent 6ab7578 commit bb68f47
Showing 1 changed file with 31 additions and 29 deletions.
60 changes: 31 additions & 29 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.mllib.tree

import scala.util.control.Breaks._

import org.apache.spark.annotation.Experimental
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.SparkContext._
Expand Down Expand Up @@ -82,31 +80,34 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* still survived the filters of the parent nodes.
*/

// TODO: Convert for loop to while loop
breakable {
for (level <- 0 until maxDepth) {

logDebug("#####################################")
logDebug("level = " + level)
logDebug("#####################################")

// Find best split for all nodes at a level.
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
level, filters, splits, bins)

for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
// Extract info for nodes at the current level.
extractNodeInfo(nodeSplitStats, level, index, nodes)
// Extract info for nodes at the next lower level.
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
filters)
logDebug("final best split = " + nodeSplitStats._1)
}
require(scala.math.pow(2, level) == splitsStatsForLevel.length)
// Check whether all the nodes at the current level at leaves.
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf)
if (allLeaf) break // no more tree construction
var level = 0
var break = false
while (level < maxDepth && !break) {

logDebug("#####################################")
logDebug("level = " + level)
logDebug("#####################################")

// Find best split for all nodes at a level.
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
level, filters, splits, bins)

for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
// Extract info for nodes at the current level.
extractNodeInfo(nodeSplitStats, level, index, nodes)
// Extract info for nodes at the next lower level.
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
filters)
logDebug("final best split = " + nodeSplitStats._1)
}
require(scala.math.pow(2, level) == splitsStatsForLevel.length)
// Check whether all the nodes at the current level at leaves.
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf)
if (allLeaf) {
break = true // no more tree construction
} else {
level += 1
}
}

Expand Down Expand Up @@ -146,8 +147,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
parentImpurities: Array[Double],
filters: Array[List[Filter]]): Unit = {
// 0 corresponds to the left child node and 1 corresponds to the right child node.
// TODO: Convert to while loop
for (i <- 0 to 1) {
var i = 0
while (i <= 1) {
// Calculate the index of the node from the node level and the index at the current level.
val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
if (level < maxDepth - 1) {
Expand All @@ -166,6 +167,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("Filter = " + filter)
}
}
i += 1
}
}
}
Expand Down

0 comments on commit bb68f47

Please sign in to comment.