From b0e3e76c47b1b449c91832aee2a6e94cee0a7c6b Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sun, 12 Jan 2014 11:45:47 -0800
Subject: [PATCH] adding enum for feature type

Signed-off-by: Manish Amde
---
 .../spark/mllib/tree/DecisionTree.scala       | 43 ++++++++++---------
 .../mllib/tree/configuration/Strategy.scala   |  3 +-
 .../apache/spark/mllib/tree/model/Bin.scala   |  4 +-
 .../apache/spark/mllib/tree/model/Split.scala | 10 +++--
 .../spark/mllib/tree/DecisionTreeSuite.scala  | 11 ++---
 5 files changed, 40 insertions(+), 31 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 9cd1597e6fa18..1665d0ee1ffb9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.tree.model.Split
 import scala.util.control.Breaks._
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.FeatureType._
 
 class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
 
@@ -353,21 +354,13 @@ object DecisionTree extends Serializable with Logging {
     def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
       val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
       val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1))
-      //logDebug("binData.length = " + binData.length)
-      //logDebug("binData.sum = " + binData.sum)
       for (featureIndex <- 0 until numFeatures) {
-        //logDebug("featureIndex = " + featureIndex)
         val shift = 2*featureIndex*numSplits
         leftNodeAgg(featureIndex)(0) = binData(shift + 0)
-        //logDebug("binData(shift + 0) = " + binData(shift + 0))
         leftNodeAgg(featureIndex)(1) = binData(shift + 1)
-        //logDebug("binData(shift + 1) = " + binData(shift + 1))
         rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1)))
-        //logDebug(binData(shift + (2 * (numSplits - 1))))
         rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1)
-        //logDebug(binData(shift + (2 * (numSplits - 1)) + 1))
         for (splitIndex <- 1 until numSplits - 1) {
-          //logDebug("splitIndex = " + splitIndex)
           leftNodeAgg(featureIndex)(2 * splitIndex)
             = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2)
           leftNodeAgg(featureIndex)(2 * splitIndex + 1)
@@ -479,33 +472,43 @@
     //Find all splits
     for (featureIndex <- 0 until numFeatures){
-      val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
-
-      val stride : Double = numSamples.toDouble/numBins
-      logDebug("stride = " + stride)
-      for (index <- 0 until numBins-1) {
-        val sampleIndex = (index+1)*stride.toInt
-        val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous")
-        splits(featureIndex)(index) = split
+      val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+      if (isFeatureContinuous) {
+        val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
+
+        val stride : Double = numSamples.toDouble/numBins
+        logDebug("stride = " + stride)
+        for (index <- 0 until numBins-1) {
+          val sampleIndex = (index+1)*stride.toInt
+          val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous)
+          splits(featureIndex)(index) = split
+        }
+      } else {
+        val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
+        for (index <- 0 until maxFeatureValue){
+          //TODO: Sort by centroid
+          val split = new Split(featureIndex,index,Categorical)
+          splits(featureIndex)(index) = split
+        }
       }
     }
 
     //Find all bins
     for (featureIndex <- 0 until numFeatures){
       bins(featureIndex)(0)
-        = new Bin(new DummyLowSplit("continuous"),splits(featureIndex)(0),"continuous")
+        = new Bin(new DummyLowSplit(Continuous),splits(featureIndex)(0),Continuous)
       for (index <- 1 until numBins - 1){
-        val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),"continuous")
+        val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous)
         bins(featureIndex)(index) = bin
       }
       bins(featureIndex)(numBins-1)
-        = new Bin(splits(featureIndex)(numBins-3),new DummyHighSplit("continuous"),"continuous")
+        = new Bin(splits(featureIndex)(numBins-3),new DummyHighSplit(Continuous),Continuous)
     }
     (splits,bins)
   }
   case MinMax => {
-    (Array.ofDim[Split](numFeatures,numBins),Array.ofDim[Bin](numFeatures,numBins+2))
+    throw new UnsupportedOperationException("minmax not supported yet.")
   }
   case ApproxHist => {
     throw new UnsupportedOperationException("approximate histogram not supported yet.")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 3c759bbc1c805..281dabd3364d8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -25,7 +25,8 @@ class Strategy (
     val impurity : Impurity,
     val maxDepth : Int,
     val maxBins : Int,
-    val quantileCalculationStrategy : QuantileStrategy = Sort) extends Serializable {
+    val quantileCalculationStrategy : QuantileStrategy = Sort,
+    val categoricalFeaturesInfo : Map[Int,Int] = Map[Int,Int]()) extends Serializable {
 
   var numBins : Int = Int.MinValue
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index 25d16a9a2fc2f..13191851956ad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.mllib.tree.model
 
-case class Bin(lowSplit : Split, highSplit : Split, kind : String) {
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+
+case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType) {
 
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index 1b39154d42e47..01aa349115302 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -16,11 +16,13 @@
  */
 package org.apache.spark.mllib.tree.model
 
-case class Split(feature: Int, threshold : Double, kind : String){
-  override def toString = "Feature = " + feature + ", threshold = " + threshold + ", kind = " + kind
+import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
+
+case class Split(feature: Int, threshold : Double, featureType : FeatureType){
+  override def toString = "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType
 }
 
-class DummyLowSplit(kind : String) extends Split(Int.MinValue, Double.MinValue, kind)
+class DummyLowSplit(kind : FeatureType) extends Split(Int.MinValue, Double.MinValue, kind)
 
-class DummyHighSplit(kind : String) extends Split(Int.MaxValue, Double.MaxValue, kind)
+class DummyHighSplit(kind : FeatureType) extends Split(Int.MaxValue, Double.MaxValue, kind)
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 86cd1f432d162..6097c6d5ac985 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
 import org.apache.spark.mllib.tree.model.Filter
 import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.Algo._
 
 class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
 
@@ -48,7 +49,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy("regression",Gini,3,100,"sort")
+    val strategy = new Strategy(Regression,Gini,3,100)
     val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
     assert(splits.length==2)
     assert(bins.length==2)
@@ -61,7 +62,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy("regression",Gini,3,100,"sort")
+    val strategy = new Strategy(Regression,Gini,3,100)
     val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
@@ -87,7 +88,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy("regression",Gini,3,100,"sort")
+    val strategy = new Strategy(Regression,Gini,3,100)
     val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
@@ -113,7 +114,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy("regression",Entropy,3,100,"sort")
+    val strategy = new Strategy(Regression,Entropy,3,100)
     val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
@@ -138,7 +139,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
    assert(arr.length == 1000)
     val rdd = sc.parallelize(arr)
-    val strategy = new Strategy("regression",Entropy,3,100,"sort")
+    val strategy = new Strategy(Regression,Entropy,3,100)
     val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy)
     assert(splits.length==2)
     assert(splits(0).length==99)
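
Note: DecisionTree.scala, Bin.scala, and Split.scala all import
org.apache.spark.mllib.tree.configuration.FeatureType, but that file is not
among the five files in this patch, so its definition must come from a
companion change. A minimal sketch of what the object presumably looks like,
following the same Enumeration pattern as Algo and QuantileStrategy; only the
Continuous and Categorical values and the FeatureType type member are actually
required by the code above:

    package org.apache.spark.mllib.tree.configuration

    // Presumed definition (not part of this diff): distinguishes features
    // split by a numeric threshold (Continuous) from features split by
    // category value (Categorical).
    object FeatureType extends Enumeration {
      type FeatureType = Value
      val Continuous, Categorical = Value
    }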
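The new categoricalFeaturesInfo parameter on Strategy maps a feature index to
the number of distinct values that feature takes; any feature without an entry
keeps the continuous split-finding path, and the split-finding code above
creates one categorical split per value (0 until maxFeatureValue). A
hypothetical usage sketch, reusing the constructor argument order seen in the
updated tests (the Map(0 -> 3) entry is an illustrative assumption):

    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.configuration.Algo._
    import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
    import org.apache.spark.mllib.tree.impurity.Gini

    // Feature 0 is categorical with 3 distinct values; all other features
    // are treated as continuous. maxDepth = 3, maxBins = 100, as in the tests.
    val strategy = new Strategy(Regression, Gini, 3, 100, Sort, Map(0 -> 3))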