Skip to content

Commit

Permalink
adding enum for feature type
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent 154aa77 commit b0e3e76
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 31 deletions.
43 changes: 23 additions & 20 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.mllib.tree.model.Split
import scala.util.control.Breaks._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.FeatureType._


class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
Expand Down Expand Up @@ -353,21 +354,13 @@ object DecisionTree extends Serializable with Logging {
def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
//logDebug("binData.length = " + binData.length)
//logDebug("binData.sum = " + binData.sum)
for (featureIndex <- 0 until numFeatures) {
//logDebug("featureIndex = " + featureIndex)
val shift = 2*featureIndex*numSplits
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
//logDebug("binData(shift + 0) = " + binData(shift + 0))
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
//logDebug("binData(shift + 1) = " + binData(shift + 1))
rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1)))
//logDebug(binData(shift + (2 * (numSplits - 1))))
rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1)
//logDebug(binData(shift + (2 * (numSplits - 1)) + 1))
for (splitIndex <- 1 until numSplits - 1) {
//logDebug("splitIndex = " + splitIndex)
leftNodeAgg(featureIndex)(2 * splitIndex)
= binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2)
leftNodeAgg(featureIndex)(2 * splitIndex + 1)
Expand Down Expand Up @@ -479,33 +472,43 @@ object DecisionTree extends Serializable with Logging {

//Find all splits
for (featureIndex <- 0 until numFeatures){
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted

val stride : Double = numSamples.toDouble/numBins
logDebug("stride = " + stride)
for (index <- 0 until numBins-1) {
val sampleIndex = (index+1)*stride.toInt
val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")
splits(featureIndex)(index) = split
val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinous) {
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted

val stride : Double = numSamples.toDouble/numBins
logDebug("stride = " + stride)
for (index <- 0 until numBins-1) {
val sampleIndex = (index+1)*stride.toInt
val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous)
splits(featureIndex)(index) = split
}
} else {
val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
for (index <- 0 until maxFeatureValue){
//TODO: Sort by centroid
val split = new Split(featureIndex,index,Categorical)
splits(featureIndex)(index) = split
}
}
}

//Find all bins
for (featureIndex <- 0 until numFeatures){
bins(featureIndex)(0)
= new Bin(new DummyLowSplit("continuous"),splits(featureIndex)(0),"continuous")
= new Bin(new DummyLowSplit(Continuous),splits(featureIndex)(0),Continuous)
for (index <- 1 until numBins - 1){
val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),"continuous")
val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous)
bins(featureIndex)(index) = bin
}
bins(featureIndex)(numBins-1)
= new Bin(splits(featureIndex)(numBins-3),new DummyHighSplit("continuous"),"continuous")
= new Bin(splits(featureIndex)(numBins-3),new DummyHighSplit(Continuous),Continuous)
}

(splits,bins)
}
case MinMax => {
(Array.ofDim[Split](numFeatures,numBins),Array.ofDim[Bin](numFeatures,numBins+2))
throw new UnsupportedOperationException("minmax not supported yet.")
}
case ApproxHist => {
throw new UnsupportedOperationException("approximate histogram not supported yet.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class Strategy (
val impurity : Impurity,
val maxDepth : Int,
val maxBins : Int,
val quantileCalculationStrategy : QuantileStrategy = Sort) extends Serializable {
val quantileCalculationStrategy : QuantileStrategy = Sort,
val categoricalFeaturesInfo : Map[Int,Int] = Map[Int,Int]()) extends Serializable {

var numBins : Int = Int.MinValue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.mllib.tree.model

case class Bin(lowSplit : Split, highSplit : Split, kind : String) {
import org.apache.spark.mllib.tree.configuration.FeatureType._

/**
 * A bin for a single feature, delimited by a pair of splits.
 *
 * @param lowSplit    split forming the lower edge of the bin
 * @param highSplit   split forming the upper edge of the bin
 * @param featureType type of the feature this bin belongs to
 */
case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType)
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
*/
package org.apache.spark.mllib.tree.model

case class Split(feature: Int, threshold : Double, kind : String){
override def toString = "Feature = " + feature + ", threshold = " + threshold + ", kind = " + kind
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType

/**
 * A split on one feature.
 *
 * @param feature     index of the feature the split applies to
 * @param threshold   threshold value of the split
 * @param featureType type of the feature
 */
case class Split(feature: Int, threshold : Double, featureType : FeatureType) {

  /** Renders the feature index, threshold and feature type on a single line. */
  override def toString: String =
    s"Feature = $feature, threshold = $threshold, featureType = $featureType"
}

class DummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind)
/**
 * Placeholder split carrying the sentinel values `Int.MinValue` (feature index) and
 * `Double.MinValue` (threshold).
 *
 * @param kind feature type handed through to the underlying Split
 */
class DummyLowSplit(kind : FeatureType)
  extends Split(Int.MinValue, Double.MinValue, kind)

class DummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind)
/**
 * Placeholder split carrying the sentinel values `Int.MaxValue` (feature index) and
 * `Double.MaxValue` (threshold).
 *
 * @param kind feature type handed through to the underlying Split
 */
class DummyHighSplit(kind : FeatureType)
  extends Split(Int.MaxValue, Double.MaxValue, kind)

Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
import org.apache.spark.mllib.tree.model.Filter
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._

class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {

Expand All @@ -48,7 +49,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Gini,3,100,"sort")
val strategy = new Strategy(Regression,Gini,3,100)
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(bins.length==2)
Expand All @@ -61,7 +62,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Gini,3,100,"sort")
val strategy = new Strategy(Regression,Gini,3,100)
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
Expand All @@ -87,7 +88,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Gini,3,100,"sort")
val strategy = new Strategy(Regression,Gini,3,100)
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
Expand All @@ -113,7 +114,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Entropy,3,100,"sort")
val strategy = new Strategy(Regression,Entropy,3,100)
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
Expand All @@ -138,7 +139,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy("regression",Entropy,3,100,"sort")
val strategy = new Strategy(Regression,Entropy,3,100)
val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
assert(splits.length==2)
assert(splits(0).length==99)
Expand Down

0 comments on commit b0e3e76

Please sign in to comment.