diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala deleted file mode 100644 index a6f27e7fd1111..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationTree.scala +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.mllib.classification - -class ClassificationTree { - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index d8ffa12030f8d..b8cfe03aa151b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -285,7 +285,7 @@ object DecisionTree extends Serializable with Logging { /*Combines the aggregates from partitions @param agg1 Array containing aggregates from one or more partitions - @param agg2 Array contianing aggregates from one or more partitions + @param agg2 Array containing aggregates from one or more partitions @return Combined aggregate from agg1 and agg2 */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala index 542a3d9c3b33d..d46733336d558 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala @@ -17,25 +17,69 @@ package org.apache.spark.mllib.tree import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.mllib.tree.impurity.Gini +import org.apache.spark.mllib.tree.impurity.{Gini,Entropy,Variance} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.DecisionTreeModel object DecisionTreeRunner extends Logging { + val usage = """ + Usage: DecisionTreeRunner [slices] --kind --trainDataDir path --testDataDir path [--maxDepth num] [--impurity ] [--maxBins num] + """ + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println(usage) + System.exit(1) + } + val sc = new SparkContext(args(0), "DecisionTree") - val data = loadLabeledData(sc, args(1)) - val maxDepth = args(2).toInt - val maxBins = args(3).toInt - val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, maxBins = maxBins) - val model = new DecisionTree(strategy).train(data) - val accuracy = accuracyScore(model, data) + val arglist = args.toList.drop(1) + type OptionMap = Map[Symbol, Any] + + def nextOption(map : OptionMap, list: List[String]) : OptionMap = { + def isSwitch(s : String) = (s(0) == '-') + list match { + case Nil => map + case "--kind" :: string :: tail => nextOption(map ++ Map('kind -> string), tail) + case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) + case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) + case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) + case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string), tail) + case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), tail) + case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) + case option :: tail => println("Unknown option "+option) + exit(1) + } + } + val options = nextOption(Map(),arglist) + logDebug(options.toString()) + + val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) + + val typeStr = options.get('type).toString + //TODO: Create enum + val impurityStr = options.getOrElse('impurity,if (typeStr == "classification") "Gini" else "Variance").toString + val impurity = { + impurityStr match { + case "Gini" => Gini + case "Entropy" => Entropy + case "Variance" => Variance + } + } + val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt + val maxBins = options.getOrElse('maxBins,"100").toString.toInt + + val strategy = new Strategy(kind = typeStr, impurity = Gini, maxDepth = maxDepth, maxBins = maxBins) + val model = new DecisionTree(strategy).train(trainData) + + val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) + val accuracy = accuracyScore(model, testData) logDebug("accuracy = " + accuracy) sc.stop()