diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1665d0ee1ffb9..1ff8c05bcb790 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -27,6 +27,7 @@ import scala.util.control.Breaks._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.Algo._ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { @@ -51,6 +52,9 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { //parentImpurities(0) = Double.MinValue val nodes = new Array[Node](maxNumNodes) + logDebug("algo = " + strategy.algo) + + breakable { for (level <- 0 until maxDepth){ @@ -247,8 +251,47 @@ object DecisionTree extends Serializable with Logging { arr } - /* - Performs a sequential aggregation over a partition. + def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + for (node <- 0 until numNodes) { + val validSignalIndex = 1 + numFeatures * node + val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false + if (isSampleValidForNode) { + val label = arr(0) + for (feature <- 0 until numFeatures) { + val arrShift = 1 + numFeatures * node + val aggShift = 2 * numSplits * numFeatures * node + val arrIndex = arrShift + feature + val aggIndex = aggShift + 2 * feature * numSplits + arr(arrIndex).toInt * 2 + label match { + case (0.0) => agg(aggIndex) = agg(aggIndex) + 1 + case (1.0) => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + } + } + } + } + } + + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { + for (node <- 0 until numNodes) { + val validSignalIndex = 1 + numFeatures * node + val isSampleValidForNode = if (arr(validSignalIndex) != -1) true else false + if (isSampleValidForNode) { + val label = arr(0) + for (feature <- 0 until numFeatures) { + val arrShift = 1 + numFeatures * node + val aggShift = 3 * numSplits * numFeatures * node + val arrIndex = arrShift + feature + val aggIndex = aggShift + 3 * feature * numSplits + arr(arrIndex).toInt * 3 + //count, sum, sum^2 + agg(aggIndex) = agg(aggIndex) + 1 + agg(aggIndex + 1) = agg(aggIndex + 1) + label + agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + } + } + } + } + + /*Performs a sequential aggregation over a partition. for p bins, k features, l nodes (level = log2(l)) storage is of the form: b111_left_count,b111_right_count, .... , bpk1_left_count, bpk1_right_count, .... , bpkl_left_count, bpkl_right_count @@ -256,32 +299,23 @@ object DecisionTree extends Serializable with Logging { @param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification @param arr Array[Double] of size 1+(numFeatures*numNodes) @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification - */ + */ def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { //TODO: Requires logic for regressions - for (node <- 0 until numNodes) { - val validSignalIndex = 1+numFeatures*node - val isSampleValidForNode = if(arr(validSignalIndex) != -1) true else false - if(isSampleValidForNode){ - val label = arr(0) - for (feature <- 0 until numFeatures){ - val arrShift = 1 + numFeatures*node - val aggShift = 2*numSplits*numFeatures*node - val arrIndex = arrShift + feature - val aggIndex = aggShift + 2*feature*numSplits + arr(arrIndex).toInt*2 - label match { - case(0.0) => agg(aggIndex) = agg(aggIndex) + 1 - case(1.0) => agg(aggIndex+1) = agg(aggIndex+1) + 1 - } - } - } + strategy.algo match { + case Classification => classificationBinSeqOp(arr, agg) + //TODO: Implement this + case Regression => regressionBinSeqOp(arr, agg) } agg } - //TODO: This length if different for regression - val binAggregateLength = 2*numSplits * numFeatures * numNodes - logDebug("binAggregageLength = " + binAggregateLength) + //TODO: This length is different for regression + val binAggregateLength = strategy.algo match { + case Classification => 2*numSplits * numFeatures * numNodes + case Regression => 3*numSplits * numFeatures * numNodes + } + logDebug("binAggregateLength = " + binAggregateLength) /*Combines the aggregates from partitions @param agg1 Array containing aggregates from one or more partitions @@ -290,11 +324,22 @@ object DecisionTree extends Serializable with Logging { @return Combined aggregate from agg1 and agg2 */ def binCombOp(agg1 : Array[Double], agg2: Array[Double]) : Array[Double] = { - val combinedAggregate = new Array[Double](binAggregateLength) - for (index <- 0 until binAggregateLength){ - combinedAggregate(index) = agg1(index) + agg2(index) + strategy.algo match { + case Classification => { + val combinedAggregate = new Array[Double](binAggregateLength) + for (index <- 0 until binAggregateLength){ + combinedAggregate(index) = agg1(index) + agg2(index) + } + combinedAggregate + } + case Regression => { + val combinedAggregate = new Array[Double](binAggregateLength) + for (index <- 0 until binAggregateLength){ + combinedAggregate(index) = agg1(index) + agg2(index) + } + combinedAggregate + } } - combinedAggregate } logDebug("input = " + input.count) @@ -302,7 +347,7 @@ object DecisionTree extends Serializable with Logging { logDebug("binMappedRDD.count = " + binMappedRDD.count) //calculate bin aggregates - val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) + val binAggregates = binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) logDebug("binAggregates.length = " + binAggregates.length) //binAggregates.foreach(x => logDebug(x)) @@ -312,36 +357,70 @@ object DecisionTree extends Serializable with Logging { index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double) : InformationGainStats = { + strategy.algo match { + case Classification => { - val left0Count = leftNodeAgg(featureIndex)(2 * index) - val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) - val leftCount = left0Count + left1Count + val left0Count = leftNodeAgg(featureIndex)(2 * index) + val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) + val leftCount = left0Count + left1Count - val right0Count = rightNodeAgg(featureIndex)(2 * index) - val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) - val rightCount = right0Count + right1Count + val right0Count = rightNodeAgg(featureIndex)(2 * index) + val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) + val rightCount = right0Count + right1Count - val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) - if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) - if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) + if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) + if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) - val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) + val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) + val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) } - } + case Regression => { + val leftCount = leftNodeAgg(featureIndex)(3 * index) + val leftSum = leftNodeAgg(featureIndex)(3 * index + 1) + val leftSumSquares = leftNodeAgg(featureIndex)(3 * index + 2) + + val rightCount = rightNodeAgg(featureIndex)(3 * index) + val rightSum = rightNodeAgg(featureIndex)(3 * index + 1) + val rightSumSquares = rightNodeAgg(featureIndex)(3 * index + 2) + + val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares) + + if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) + if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) + + val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) - new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) + + } + } } /* @@ -352,26 +431,60 @@ object DecisionTree extends Serializable with Logging { each array is of size(numFeature,2*(numSplits-1)) */ def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { - val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) - for (featureIndex <- 0 until numFeatures) { - val shift = 2*featureIndex*numSplits - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) - rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) - for (splitIndex <- 1 until numSplits - 1) { - leftNodeAgg(featureIndex)(2 * splitIndex) - = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) - leftNodeAgg(featureIndex)(2 * splitIndex + 1) - = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) - rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex)) - = binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex)) - rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1) - = binData(shift + (2 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1) + strategy.algo match { + case Classification => { + + val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) + for (featureIndex <- 0 until numFeatures) { + val shift = 2*featureIndex*numSplits + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) + rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) + for (splitIndex <- 1 until numSplits - 1) { + leftNodeAgg(featureIndex)(2 * splitIndex) + = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) + leftNodeAgg(featureIndex)(2 * splitIndex + 1) + = binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) + rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex)) + = binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex)) + rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1) + = binData(shift + (2 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1) + } + } + (leftNodeAgg, rightNodeAgg) + } + case Regression => { + + val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numSplits - 1)) + for (featureIndex <- 0 until numFeatures) { + val shift = 3*featureIndex*numSplits + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(2) = binData(shift + 2) + rightNodeAgg(featureIndex)(3 * (numSplits - 2)) = binData(shift + (3 * (numSplits - 1))) + rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 1) = binData(shift + (3 * (numSplits - 1)) + 1) + rightNodeAgg(featureIndex)(3 * (numSplits - 2) + 2) = binData(shift + (3 * (numSplits - 1)) + 2) + for (splitIndex <- 1 until numSplits - 1) { + leftNodeAgg(featureIndex)(3 * splitIndex) + = binData(shift + 3*splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3) + leftNodeAgg(featureIndex)(3 * splitIndex + 1) + = binData(shift + 3*splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) + leftNodeAgg(featureIndex)(3 * splitIndex + 2) + = binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex)) + = binData(shift + (3 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex)) + rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 1) + = binData(shift + (3 * (numSplits - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 1) + rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 2) + = binData(shift + (3 * (numSplits - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 2) + } + } + (leftNodeAgg, rightNodeAgg) } } - (leftNodeAgg, rightNodeAgg) } def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double) @@ -421,10 +534,24 @@ object DecisionTree extends Serializable with Logging { //Calculate best splits for all nodes at a given level val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + def getBinDataForNode(node: Int): Array[Double] = { + strategy.algo match { + case Classification => { + val shift = 2 * node * numSplits * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 2 * numSplits * numFeatures) + binsForNode + } + case Regression => { + val shift = 3 * node * numSplits * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numSplits * numFeatures) + binsForNode + } + } + } + for (node <- 0 until numNodes){ val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node - val shift = 2*node*numSplits*numFeatures - val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures) + val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) logDebug("node impurity = " + parentNodeImpurity) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala index 65b5ab1162597..ae18cb0aaa4e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.mllib.tree +import org.apache.spark.SparkContext._ import org.apache.spark.{Logging, SparkContext} import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance} import org.apache.spark.rdd.RDD @@ -95,6 +96,9 @@ object DecisionTreeRunner extends Logging { val accuracy = accuracyScore(model, testData) logDebug("accuracy = " + accuracy) + val mse = meanSquaredError(model,testData) + logDebug("mean square error = " + mse) + sc.stop() } @@ -126,6 +130,14 @@ object DecisionTreeRunner extends Logging { correctCount.toDouble / count } + //TODO: Make these generic MLTable metrics + def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = { + val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean() + println("meanSumOfSquares = " + meanSumOfSquares) + meanSumOfSquares + } + + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 00feb25e25322..350627e9de1dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.mllib.tree.impurity +import javax.naming.OperationNotSupportedException + object Entropy extends Impurity { def log2(x: Double) = scala.math.log(x) / scala.math.log(2) @@ -31,4 +33,6 @@ object Entropy extends Impurity { } } - } + def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new OperationNotSupportedException("Entropy.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 3396a015e7858..8befeb5a475f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.mllib.tree.impurity +import javax.naming.OperationNotSupportedException + object Gini extends Impurity { def calculate(c0 : Double, c1 : Double): Double = { @@ -29,4 +31,5 @@ object Gini extends Impurity { } } - } + def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new OperationNotSupportedException("Gini.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 4b6e679820f59..cda534b462234 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -20,4 +20,6 @@ trait Impurity extends Serializable { def calculate(c0 : Double, c1 : Double): Double + def calculate(count : Double, sum : Double, sumSquares : Double) : Double + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 98f332122785e..65f5b3702779a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -17,7 +17,14 @@ package org.apache.spark.mllib.tree.impurity import javax.naming.OperationNotSupportedException +import org.apache.spark.Logging -object Variance extends Impurity { +object Variance extends Impurity with Logging { def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate") - } + + def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + val squaredLoss = sumSquares - (sum*sum)/count + squaredLoss/count + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 6097c6d5ac985..5f9aad0de2f65 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -49,7 +49,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Regression,Gini,3,100) + val strategy = new Strategy(Classification,Gini,3,100) val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) assert(splits.length==2) assert(bins.length==2) @@ -62,7 +62,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Regression,Gini,3,100) + val strategy = new Strategy(Classification,Gini,3,100) val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) @@ -88,7 +88,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Regression,Gini,3,100) + val strategy = new Strategy(Classification,Gini,3,100) val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) @@ -114,7 +114,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Regression,Entropy,3,100) + val strategy = new Strategy(Classification,Entropy,3,100) val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) @@ -139,7 +139,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Regression,Entropy,3,100) + val strategy = new Strategy(Classification,Entropy,3,100) val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99)