From abc5a23bf80d792a345d723b44bff3ee217cd5ac Mon Sep 17 00:00:00 2001 From: Evan Sparks Date: Mon, 21 Apr 2014 18:41:36 -0700 Subject: [PATCH] Parameterizing max memory. --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 8 ++++++-- .../apache/spark/mllib/tree/configuration/Strategy.scala | 3 ++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ad901d4f67398..ffee3fd848955 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.util.Utils.memoryStringToMb import org.apache.spark.mllib.linalg.{Vector, Vectors} /** @@ -79,7 +80,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = scala.math.pow(2, 27).toInt //128MB + val maxMemoryUsage = strategy.maxMemory * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage) val numElementsPerNode = { strategy.algo match { @@ -1158,10 +1159,13 @@ object DecisionTree extends Serializable with Logging { val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt val maxBins = options.getOrElse('maxBins, "100").toString.toInt + val maxMemUsage = memoryStringToMb(options.getOrElse('maxMemory, "128m").toString) - val strategy = new Strategy(algo, impurity, maxDepth, maxBins) + val strategy = new Strategy(algo, impurity, maxDepth, maxBins, maxMemory=maxMemUsage) val model = DecisionTree.train(trainData, strategy) + + // Load test data. val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 8767aca47cd5a..fd7a9ed1514c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -43,4 +43,5 @@ class Strategy ( val maxDepth: Int, val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable + val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + val maxMemory: Int = 128) extends Serializable