Skip to content

Commit

Permalink
[MLLIB] SPARK-1547: Add Gradient Boosting to MLlib
Browse files Browse the repository at this point in the history
Given the popular demand for gradient boosting and AdaBoost in MLlib, I am creating a WIP branch for early feedback on gradient boosting with AdaBoost to follow soon after this PR is accepted. This is based on work done along with hirakendu that was pending due to decision tree optimizations and random forests work.

Ideally, boosting algorithms should work with any base learners.  This will soon be possible once the MLlib API is finalized -- we want to ensure we use a consistent interface for the underlying base learners. In the meantime, this PR uses decision trees as base learners for the gradient boosting algorithm. The current PR allows "pluggable" loss functions and provides least squares error and least absolute error by default.

Here is the task list:
- [x] Gradient boosting support
- [x] Pluggable loss functions
- [x] Stochastic gradient boosting support – Re-use the BaggedPoint approach used for RandomForest.
- [x] Binary classification support
- [x] Support configurable checkpointing – This approach will avoid long lineage chains.
- [x] Create classification and regression APIs
- [x] Weighted Ensemble Model -- created a WeightedEnsembleModel class that can be used by ensemble algorithms such as random forests and boosting.
- [x] Unit Tests

Future work:
+ Multi-class classification is currently not supported by this PR since it requires discussion on the best way to support "deviance" as a loss function.
+ BaggedRDD caching -- Avoid repeating feature to bin mapping for each tree estimator after standard API work is completed.

cc: jkbradley hirakendu mengxr etrain atalwalkar chouqin

Author: Manish Amde <manish9ue@gmail.com>
Author: manishamde <manish9ue@gmail.com>

Closes apache#2607 from manishamde/gbt and squashes the following commits:

991c7b5 [Manish Amde] public api
ff2a796 [Manish Amde] addressing comments
b4c1318 [Manish Amde] removing spaces
8476b6b [Manish Amde] fixing line length
0183cb9 [Manish Amde] fixed naming and formatting issues
1c40c33 [Manish Amde] add newline, removed spaces
e33ab61 [Manish Amde] minor comment
eadbf09 [Manish Amde] parameter renaming
035a2ed [Manish Amde] jkbradley formatting suggestions
9f7359d [Manish Amde] simplified gbt logic and added more tests
49ba107 [Manish Amde] merged from master
eff21fe [Manish Amde] Added gradient boosting tests
3fd0528 [Manish Amde] moved helper methods to new class
a32a5ab [Manish Amde] added test for subsampling without replacement
781542a [Manish Amde] added support for fractional subsampling with replacement
3a18cc1 [Manish Amde] cleaned up api for conversion to bagged point and moved tests to it's own test suite
0e81906 [Manish Amde] improving caching unpersisting logic
d971f73 [Manish Amde] moved RF code to use WeightedEnsembleModel class
fee06d3 [Manish Amde] added weighted ensemble model
1b01943 [Manish Amde] add weights for base learners
9bc6e74 [Manish Amde] adding random seed as parameter
d2c8323 [Manish Amde] Merge branch 'master' into gbt
2ae97b7 [Manish Amde] added documentation for the loss classes
9366b8f [Manish Amde] minor: using numTrees instead of trees.size
3b43896 [Manish Amde] added learning rate for prediction
9b2e35e [Manish Amde] Merge branch 'master' into gbt
6a11c02 [manishamde] fixing formatting
823691b [Manish Amde] fixing RF test
1f47941 [Manish Amde] changing access modifier
5b67102 [Manish Amde] shortened parameter list
5ab3796 [Manish Amde] minor reformatting
9155a9d [Manish Amde] consolidated boosting configuration and added public API
631baea [Manish Amde] Merge branch 'master' into gbt
2cb1258 [Manish Amde] public API support
3b8ffc0 [Manish Amde] added documentation
8e10c63 [Manish Amde] modified unpersist strategy
f62bc48 [Manish Amde] added unpersist
bdca43a [Manish Amde] added timing parameters
2fbc9c7 [Manish Amde] fixing binomial classification prediction
6dd4dd8 [Manish Amde] added support for log loss
9af0231 [Manish Amde] classification attempt
62cc000 [Manish Amde] basic checkpointing
4784091 [Manish Amde] formatting
78ed452 [Manish Amde] added newline and fixed if statement
3973dd1 [Manish Amde] minor indicating subsample is double during comparison
aa8fae7 [Manish Amde] minor refactoring
1a8031c [Manish Amde] sampling with replacement
f1c9ef7 [Manish Amde] Merge branch 'master' into gbt
cdceeef [Manish Amde] added documentation
6251fd5 [Manish Amde] modified method name
5538521 [Manish Amde] disable checkpointing for now
0ae1c0a [Manish Amde] basic gradient boosting code from earlier branches
  • Loading branch information
manishamde authored and mengxr committed Nov 1, 2014
1 parent e07fb6a commit 8602195
Show file tree
Hide file tree
Showing 20 changed files with 1,331 additions and 267 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -317,7 +317,7 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
err * err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
val rfModel = rf.train(input)
rfModel.trees(0)
rfModel.weakHypotheses(0)
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree

import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy}
import org.apache.spark.Logging
import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.loss.Losses
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum

/**
* :: Experimental ::
* A class that implements gradient boosting for regression and binary classification problems.
* @param boostingStrategy Parameters for the gradient boosting algorithm
*/
@Experimental
class GradientBoosting (
private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {

/**
* Method to train a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return WeightedEnsembleModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
val algo = boostingStrategy.algo
algo match {
case Regression => GradientBoosting.boost(input, boostingStrategy)
case Classification =>
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoosting.boost(remappedInput, boostingStrategy)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
}

}


object GradientBoosting extends Logging {

/**
* Method to train a gradient boosting model.
*
* Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
* is recommended to clearly specify regression.
* Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
* is recommended to clearly specify regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
* @return WeightedEnsembleModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
new GradientBoosting(boostingStrategy).train(input)
}

/**
* Method to train a gradient boosting classification model.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
val algo = boostingStrategy.algo
require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
new GradientBoosting(boostingStrategy).train(input)
}

/**
* Method to train a gradient boosting regression model.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param boostingStrategy Configuration options for the boosting algorithm.
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
val algo = boostingStrategy.algo
require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
new GradientBoosting(boostingStrategy).train(input)
}

/**
* Method to train a gradient boosting binary classification model.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param numEstimators Number of estimators used in boosting stages. In other words,
* number of boosting iterations performed.
* @param loss Loss function used for minimization during gradient boosting.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be between in the interval (0, 1]
* @param subsamplingRate Fraction of the training data used for learning the decision tree.
* @param numClassesForClassification Number of classes for classification.
* (Ignored for regression.)
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
* the number of discrete values they take. For example,
* an entry (n -> k) implies the feature n is categorical with k
* categories 0, 1, 2, ... , k-1. It's important to note that
* features are zero-indexed.
* @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
* supported.)
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
numEstimators: Int,
loss: String,
learningRate: Double,
subsamplingRate: Double,
numClassesForClassification: Int,
categoricalFeaturesInfo: Map[Int, Int],
weakLearnerParams: Strategy): WeightedEnsembleModel = {
val lossType = Losses.fromString(loss)
val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType,
learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
weakLearnerParams)
new GradientBoosting(boostingStrategy).train(input)
}

/**
* Method to train a gradient boosting regression model.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
* @param numEstimators Number of estimators used in boosting stages. In other words,
* number of boosting iterations performed.
* @param loss Loss function used for minimization during gradient boosting.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be between in the interval (0, 1]
* @param subsamplingRate Fraction of the training data used for learning the decision tree.
* @param numClassesForClassification Number of classes for classification.
* (Ignored for regression.)
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
* the number of discrete values they take. For example,
* an entry (n -> k) implies the feature n is categorical with k
* categories 0, 1, 2, ... , k-1. It's important to note that
* features are zero-indexed.
* @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
* supported.)
* @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
numEstimators: Int,
loss: String,
learningRate: Double,
subsamplingRate: Double,
numClassesForClassification: Int,
categoricalFeaturesInfo: Map[Int, Int],
weakLearnerParams: Strategy): WeightedEnsembleModel = {
val lossType = Losses.fromString(loss)
val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType,
learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
weakLearnerParams)
new GradientBoosting(boostingStrategy).train(input)
}

/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
*/
def trainClassifier(
input: RDD[LabeledPoint],
numEstimators: Int,
loss: String,
learningRate: Double,
subsamplingRate: Double,
numClassesForClassification: Int,
categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer],
weakLearnerParams: Strategy): WeightedEnsembleModel = {
trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate,
numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
weakLearnerParams)
}

/**
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
*/
def trainRegressor(
input: RDD[LabeledPoint],
numEstimators: Int,
loss: String,
learningRate: Double,
subsamplingRate: Double,
numClassesForClassification: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
weakLearnerParams: Strategy): WeightedEnsembleModel = {
trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate,
numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
weakLearnerParams)
}


/**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
* @param boostingStrategy boosting parameters
* @return
*/
private def boost(
input: RDD[LabeledPoint],
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {

val timer = new TimeTracker()
timer.start("total")
timer.start("init")

// Initialize gradient boosting parameters
val numEstimators = boostingStrategy.numEstimators
val baseLearners = new Array[DecisionTreeModel](numEstimators)
val baseLearnerWeights = new Array[Double](numEstimators)
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
val strategy = boostingStrategy.weakLearnerParams

// Cache input
input.persist(StorageLevel.MEMORY_AND_DISK)

timer.stop("init")

logDebug("##########")
logDebug("Building tree 0")
logDebug("##########")
var data = input

// 1. Initialize tree
timer.start("building tree 0")
val firstTreeModel = new DecisionTree(strategy).train(data)
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = 1.0
val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
Sum)
logDebug("error of gbt = " + loss.computeError(startingModel, input))
// Note: A model of type regression is used since we require raw prediction
timer.stop("building tree 0")

// psuedo-residual for second iteration
data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
point.features))

var m = 1
while (m < numEstimators) {
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
val model = new DecisionTree(strategy).train(data)
timer.stop(s"building tree $m")
// Create partial model
baseLearners(m) = model
baseLearnerWeights(m) = learningRate
// Note: A model of type regression is used since we require raw prediction
val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
baseLearnerWeights.slice(0, m + 1), Regression, Sum)
logDebug("error of gbt = " + loss.computeError(partialModel, input))
// Update data with pseudo-residuals
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
point.features))
m += 1
}

timer.stop("total")

logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")


// 3. Output classifier
new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)

}

}
Loading

0 comments on commit 8602195

Please sign in to comment.