Skip to content

Commit

Permalink
moved RF code to use WeightedEnsembleModel class
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Oct 26, 2014
1 parent fee06d3 commit d971f73
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -317,7 +317,7 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
err * err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
val rfModel = rf.train(input)
rfModel.trees(0)
rfModel.baseLearners(0)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum

/**
* :: Experimental ::
Expand Down Expand Up @@ -458,7 +459,7 @@ object GradientBoosting extends Logging {


// 3. Output classifier
new WeightedEnsembleModel(baseLearners, baseLearnerWeights, strategy.algo)
new WeightedEnsembleModel(baseLearners, baseLearnerWeights, strategy.algo, Sum)

}

Expand Down
28 changes: 15 additions & 13 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
import org.apache.spark.mllib.tree.impurity.Impurities
Expand Down Expand Up @@ -78,9 +79,9 @@ private class RandomForest (
/**
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return RandomForestModel that can be used for prediction
* @return WeightedEnsembleModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint]): RandomForestModel = {
def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {

val timer = new TimeTracker()

Expand Down Expand Up @@ -193,7 +194,8 @@ private class RandomForest (
logInfo(s"$timer")

val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
RandomForestModel.build(trees)
val treeWeights = Array.fill[Double](numTrees)(1.0)
new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
}

}
Expand All @@ -214,14 +216,14 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
* @return RandomForestModel that can be used for prediction
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
seed: Int): RandomForestModel = {
seed: Int): WeightedEnsembleModel = {
require(strategy.algo == Classification,
s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
Expand Down Expand Up @@ -252,7 +254,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
* @return RandomForestModel that can be used for prediction
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
Expand All @@ -263,7 +265,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
seed: Int = Utils.random.nextInt()): RandomForestModel = {
seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Classification, impurityType, maxDepth,
numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
Expand All @@ -282,7 +284,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
seed: Int): RandomForestModel = {
seed: Int): WeightedEnsembleModel = {
trainClassifier(input.rdd, numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
Expand All @@ -302,14 +304,14 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
* @return RandomForestModel that can be used for prediction
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
seed: Int): RandomForestModel = {
seed: Int): WeightedEnsembleModel = {
require(strategy.algo == Regression,
s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
Expand Down Expand Up @@ -339,7 +341,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
* @return RandomForestModel that can be used for prediction
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
Expand All @@ -349,7 +351,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
seed: Int = Utils.random.nextInt()): RandomForestModel = {
seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Regression, impurityType, maxDepth,
0, maxBins, Sort, categoricalFeaturesInfo)
Expand All @@ -367,7 +369,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
seed: Int): RandomForestModel = {
seed: Int): WeightedEnsembleModel = {
trainRegressor(input.rdd,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.configuration

import org.apache.spark.annotation.DeveloperApi

/**
* :: Experimental ::
* Enum to select ensemble combining strategy for base learners
*/
@DeveloperApi
object EnsembleCombiningStrategy extends Enumeration {
type EnsembleCombiningStrategy = Value
val Sum, Average = Value
}

This file was deleted.

Loading

0 comments on commit d971f73

Please sign in to comment.