From 46e06ee0ceb223aee50fa811a35d25090a5c4d42 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 6 May 2014 18:05:58 -0700 Subject: [PATCH] minor mods --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6c7097cdb5a95..49b821d589071 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -562,6 +562,7 @@ object DecisionTree extends Serializable with Logging { val label = arr(0) // Iterate over all features. var featureIndex = 0 + // TODO: Multiclass modification here while (featureIndex < numFeatures) { // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex @@ -569,10 +570,8 @@ object DecisionTree extends Serializable with Logging { // Update the left or right count for one bin. val aggShift = 2 * numBins * numFeatures * nodeIndex val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 - // TODO: Multiclass modification here label match { - case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 - case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + case n: Double => agg(aggIndex) = agg(aggIndex + n.toInt) + 1 } featureIndex += 1 }