forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MLLIB] SPARK-1547: Add Gradient Boosting to MLlib
Given the popular demand for gradient boosting and AdaBoost in MLlib, I am creating a WIP branch for early feedback on gradient boosting with AdaBoost to follow soon after this PR is accepted. This is based on work done along with hirakendu that was pending due to decision tree optimizations and random forests work. Ideally, boosting algorithms should work with any base learners. This will soon be possible once the MLlib API is finalized -- we want to ensure we use a consistent interface for the underlying base learners. In the meantime, this PR uses decision trees as base learners for the gradient boosting algorithm. The current PR allows "pluggable" loss functions and provides least squares error and least absolute error by default. Here is the task list: - [x] Gradient boosting support - [x] Pluggable loss functions - [x] Stochastic gradient boosting support – Re-use the BaggedPoint approach used for RandomForest. - [x] Binary classification support - [x] Support configurable checkpointing – This approach will avoid long lineage chains. - [x] Create classification and regression APIs - [x] Weighted Ensemble Model -- created a WeightedEnsembleModel class that can be used by ensemble algorithms such as random forests and boosting. - [x] Unit Tests Future work: + Multi-class classification is currently not supported by this PR since it requires discussion on the best way to support "deviance" as a loss function. + BaggedRDD caching -- Avoid repeating feature to bin mapping for each tree estimator after standard API work is completed. cc: jkbradley hirakendu mengxr etrain atalwalkar chouqin Author: Manish Amde <manish9ue@gmail.com> Author: manishamde <manish9ue@gmail.com> Closes apache#2607 from manishamde/gbt and squashes the following commits: 991c7b5 [Manish Amde] public api ff2a796 [Manish Amde] addressing comments b4c1318 [Manish Amde] removing spaces 8476b6b [Manish Amde] fixing line length 0183cb9 [Manish Amde] fixed naming and formatting issues 1c40c33 [Manish Amde] add newline, removed spaces e33ab61 [Manish Amde] minor comment eadbf09 [Manish Amde] parameter renaming 035a2ed [Manish Amde] jkbradley formatting suggestions 9f7359d [Manish Amde] simplified gbt logic and added more tests 49ba107 [Manish Amde] merged from master eff21fe [Manish Amde] Added gradient boosting tests 3fd0528 [Manish Amde] moved helper methods to new class a32a5ab [Manish Amde] added test for subsampling without replacement 781542a [Manish Amde] added support for fractional subsampling with replacement 3a18cc1 [Manish Amde] cleaned up api for conversion to bagged point and moved tests to it's own test suite 0e81906 [Manish Amde] improving caching unpersisting logic d971f73 [Manish Amde] moved RF code to use WeightedEnsembleModel class fee06d3 [Manish Amde] added weighted ensemble model 1b01943 [Manish Amde] add weights for base learners 9bc6e74 [Manish Amde] adding random seed as parameter d2c8323 [Manish Amde] Merge branch 'master' into gbt 2ae97b7 [Manish Amde] added documentation for the loss classes 9366b8f [Manish Amde] minor: using numTrees instead of trees.size 3b43896 [Manish Amde] added learning rate for prediction 9b2e35e [Manish Amde] Merge branch 'master' into gbt 6a11c02 [manishamde] fixing formatting 823691b [Manish Amde] fixing RF test 1f47941 [Manish Amde] changing access modifier 5b67102 [Manish Amde] shortened parameter list 5ab3796 [Manish Amde] minor reformatting 9155a9d [Manish Amde] consolidated boosting configuration and added public API 631baea [Manish Amde] Merge branch 'master' into gbt 2cb1258 [Manish Amde] public API support 3b8ffc0 [Manish Amde] added documentation 8e10c63 [Manish Amde] modified unpersist strategy f62bc48 [Manish Amde] added unpersist bdca43a [Manish Amde] added timing parameters 2fbc9c7 [Manish Amde] fixing binomial classification prediction 6dd4dd8 [Manish Amde] added support for log loss 9af0231 [Manish Amde] classification attempt 62cc000 [Manish Amde] basic checkpointing 4784091 [Manish Amde] formatting 78ed452 [Manish Amde] added newline and fixed if statement 3973dd1 [Manish Amde] minor indicating subsample is double during comparison aa8fae7 [Manish Amde] minor refactoring 1a8031c [Manish Amde] sampling with replacement f1c9ef7 [Manish Amde] Merge branch 'master' into gbt cdceeef [Manish Amde] added documentation 6251fd5 [Manish Amde] modified method name 5538521 [Manish Amde] disable checkpointing for now 0ae1c0a [Manish Amde] basic gradient boosting code from earlier branches
- Loading branch information
1 parent
e07fb6a
commit 8602195
Showing
20 changed files
with
1,331 additions
and
267 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
314 changes: 314 additions & 0 deletions
314
mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,314 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.tree | ||
|
||
import scala.collection.JavaConverters._ | ||
|
||
import org.apache.spark.annotation.Experimental | ||
import org.apache.spark.api.java.JavaRDD | ||
import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy} | ||
import org.apache.spark.Logging | ||
import org.apache.spark.mllib.tree.impl.TimeTracker | ||
import org.apache.spark.mllib.tree.loss.Losses | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.mllib.regression.LabeledPoint | ||
import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} | ||
import org.apache.spark.mllib.tree.configuration.Algo._ | ||
import org.apache.spark.storage.StorageLevel | ||
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum | ||
|
||
/** | ||
* :: Experimental :: | ||
* A class that implements gradient boosting for regression and binary classification problems. | ||
* @param boostingStrategy Parameters for the gradient boosting algorithm | ||
*/ | ||
@Experimental | ||
class GradientBoosting ( | ||
private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { | ||
|
||
/** | ||
* Method to train a gradient boosting model | ||
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. | ||
* @return WeightedEnsembleModel that can be used for prediction | ||
*/ | ||
def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { | ||
val algo = boostingStrategy.algo | ||
algo match { | ||
case Regression => GradientBoosting.boost(input, boostingStrategy) | ||
case Classification => | ||
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) | ||
GradientBoosting.boost(remappedInput, boostingStrategy) | ||
case _ => | ||
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") | ||
} | ||
} | ||
|
||
} | ||
|
||
|
||
object GradientBoosting extends Logging { | ||
|
||
/** | ||
* Method to train a gradient boosting model. | ||
* | ||
* Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] | ||
* is recommended to clearly specify regression. | ||
* Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] | ||
* is recommended to clearly specify regression. | ||
* | ||
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. | ||
* For classification, labels should take values {0, 1, ..., numClasses-1}. | ||
* For regression, labels are real numbers. | ||
* @param boostingStrategy Configuration options for the boosting algorithm. | ||
* @return WeightedEnsembleModel that can be used for prediction | ||
*/ | ||
def train( | ||
input: RDD[LabeledPoint], | ||
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { | ||
new GradientBoosting(boostingStrategy).train(input) | ||
} | ||
|
||
/** | ||
* Method to train a gradient boosting classification model. | ||
* | ||
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. | ||
* For classification, labels should take values {0, 1, ..., numClasses-1}. | ||
* For regression, labels are real numbers. | ||
* @param boostingStrategy Configuration options for the boosting algorithm. | ||
* @return WeightedEnsembleModel that can be used for prediction | ||
*/ | ||
def trainClassifier( | ||
input: RDD[LabeledPoint], | ||
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { | ||
val algo = boostingStrategy.algo | ||
require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.") | ||
new GradientBoosting(boostingStrategy).train(input) | ||
} | ||
|
||
/** | ||
* Method to train a gradient boosting regression model. | ||
* | ||
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. | ||
* For classification, labels should take values {0, 1, ..., numClasses-1}. | ||
* For regression, labels are real numbers. | ||
* @param boostingStrategy Configuration options for the boosting algorithm. | ||
* @return WeightedEnsembleModel that can be used for prediction | ||
*/ | ||
def trainRegressor( | ||
input: RDD[LabeledPoint], | ||
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { | ||
val algo = boostingStrategy.algo | ||
require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.") | ||
new GradientBoosting(boostingStrategy).train(input) | ||
} | ||
|
||
/** | ||
* Method to train a gradient boosting binary classification model. | ||
* | ||
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. | ||
* For classification, labels should take values {0, 1, ..., numClasses-1}. | ||
* For regression, labels are real numbers. | ||
* @param numEstimators Number of estimators used in boosting stages. In other words, | ||
* number of boosting iterations performed. | ||
* @param loss Loss function used for minimization during gradient boosting. | ||
* @param learningRate Learning rate for shrinking the contribution of each estimator. The | ||
* learning rate should be between in the interval (0, 1] | ||
* @param subsamplingRate Fraction of the training data used for learning the decision tree. | ||
* @param numClassesForClassification Number of classes for classification. | ||
* (Ignored for regression.) | ||
* @param categoricalFeaturesInfo A map storing information about the categorical variables and | ||
* the number of discrete values they take. For example, | ||
* an entry (n -> k) implies the feature n is categorical with k | ||
* categories 0, 1, 2, ... , k-1. It's important to note that | ||
* features are zero-indexed. | ||
* @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is | ||
* supported.) | ||
* @return WeightedEnsembleModel that can be used for prediction | ||
*/ | ||
def trainClassifier( | ||
input: RDD[LabeledPoint], | ||
numEstimators: Int, | ||
loss: String, | ||
learningRate: Double, | ||
subsamplingRate: Double, | ||
numClassesForClassification: Int, | ||
categoricalFeaturesInfo: Map[Int, Int], | ||
weakLearnerParams: Strategy): WeightedEnsembleModel = { | ||
val lossType = Losses.fromString(loss) | ||
val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType, | ||
learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, | ||
weakLearnerParams) | ||
new GradientBoosting(boostingStrategy).train(input) | ||
} | ||
|
||
/** | ||
* Method to train a gradient boosting regression model. | ||
* | ||
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. | ||
* For classification, labels should take values {0, 1, ..., numClasses-1}. | ||
* For regression, labels are real numbers. | ||
* @param numEstimators Number of estimators used in boosting stages. In other words, | ||
* number of boosting iterations performed. | ||
* @param loss Loss function used for minimization during gradient boosting. | ||
* @param learningRate Learning rate for shrinking the contribution of each estimator. The | ||
* learning rate should be between in the interval (0, 1] | ||
* @param subsamplingRate Fraction of the training data used for learning the decision tree. | ||
* @param numClassesForClassification Number of classes for classification. | ||
* (Ignored for regression.) | ||
* @param categoricalFeaturesInfo A map storing information about the categorical variables and | ||
* the number of discrete values they take. For example, | ||
* an entry (n -> k) implies the feature n is categorical with k | ||
* categories 0, 1, 2, ... , k-1. It's important to note that | ||
* features are zero-indexed. | ||
* @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is | ||
* supported.) | ||
* @return WeightedEnsembleModel that can be used for prediction | ||
*/ | ||
def trainRegressor( | ||
input: RDD[LabeledPoint], | ||
numEstimators: Int, | ||
loss: String, | ||
learningRate: Double, | ||
subsamplingRate: Double, | ||
numClassesForClassification: Int, | ||
categoricalFeaturesInfo: Map[Int, Int], | ||
weakLearnerParams: Strategy): WeightedEnsembleModel = { | ||
val lossType = Losses.fromString(loss) | ||
val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType, | ||
learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, | ||
weakLearnerParams) | ||
new GradientBoosting(boostingStrategy).train(input) | ||
} | ||
|
||
/** | ||
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] | ||
*/ | ||
def trainClassifier( | ||
input: RDD[LabeledPoint], | ||
numEstimators: Int, | ||
loss: String, | ||
learningRate: Double, | ||
subsamplingRate: Double, | ||
numClassesForClassification: Int, | ||
categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer], | ||
weakLearnerParams: Strategy): WeightedEnsembleModel = { | ||
trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate, | ||
numClassesForClassification, | ||
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, | ||
weakLearnerParams) | ||
} | ||
|
||
/** | ||
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] | ||
*/ | ||
def trainRegressor( | ||
input: RDD[LabeledPoint], | ||
numEstimators: Int, | ||
loss: String, | ||
learningRate: Double, | ||
subsamplingRate: Double, | ||
numClassesForClassification: Int, | ||
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], | ||
weakLearnerParams: Strategy): WeightedEnsembleModel = { | ||
trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate, | ||
numClassesForClassification, | ||
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, | ||
weakLearnerParams) | ||
} | ||
|
||
|
||
/** | ||
* Internal method for performing regression using trees as base learners. | ||
* @param input training dataset | ||
* @param boostingStrategy boosting parameters | ||
* @return | ||
*/ | ||
private def boost( | ||
input: RDD[LabeledPoint], | ||
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { | ||
|
||
val timer = new TimeTracker() | ||
timer.start("total") | ||
timer.start("init") | ||
|
||
// Initialize gradient boosting parameters | ||
val numEstimators = boostingStrategy.numEstimators | ||
val baseLearners = new Array[DecisionTreeModel](numEstimators) | ||
val baseLearnerWeights = new Array[Double](numEstimators) | ||
val loss = boostingStrategy.loss | ||
val learningRate = boostingStrategy.learningRate | ||
val strategy = boostingStrategy.weakLearnerParams | ||
|
||
// Cache input | ||
input.persist(StorageLevel.MEMORY_AND_DISK) | ||
|
||
timer.stop("init") | ||
|
||
logDebug("##########") | ||
logDebug("Building tree 0") | ||
logDebug("##########") | ||
var data = input | ||
|
||
// 1. Initialize tree | ||
timer.start("building tree 0") | ||
val firstTreeModel = new DecisionTree(strategy).train(data) | ||
baseLearners(0) = firstTreeModel | ||
baseLearnerWeights(0) = 1.0 | ||
val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, | ||
Sum) | ||
logDebug("error of gbt = " + loss.computeError(startingModel, input)) | ||
// Note: A model of type regression is used since we require raw prediction | ||
timer.stop("building tree 0") | ||
|
||
// psuedo-residual for second iteration | ||
data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), | ||
point.features)) | ||
|
||
var m = 1 | ||
while (m < numEstimators) { | ||
timer.start(s"building tree $m") | ||
logDebug("###################################################") | ||
logDebug("Gradient boosting tree iteration " + m) | ||
logDebug("###################################################") | ||
val model = new DecisionTree(strategy).train(data) | ||
timer.stop(s"building tree $m") | ||
// Create partial model | ||
baseLearners(m) = model | ||
baseLearnerWeights(m) = learningRate | ||
// Note: A model of type regression is used since we require raw prediction | ||
val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), | ||
baseLearnerWeights.slice(0, m + 1), Regression, Sum) | ||
logDebug("error of gbt = " + loss.computeError(partialModel, input)) | ||
// Update data with pseudo-residuals | ||
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), | ||
point.features)) | ||
m += 1 | ||
} | ||
|
||
timer.stop("total") | ||
|
||
logInfo("Internal timing for DecisionTree:") | ||
logInfo(s"$timer") | ||
|
||
|
||
// 3. Output classifier | ||
new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) | ||
|
||
} | ||
|
||
} |
Oops, something went wrong.