Skip to content

Commit

Permalink
sampling with replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Oct 6, 2014
1 parent f1c9ef7 commit 1a8031c
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ object GradientBoosting extends Logging {
val loss = boostingStrategy.loss
val learningRate = boostingStrategy.learningRate
// TODO: Implement Stochastic gradient boosting using BaggedPoint
val subSample = boostingStrategy.subSample
val subSample = boostingStrategy.subsample

// Cache input
input.cache()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,16 @@ private class RandomForest (
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
val baggedInput = if (numTrees > 1) {
BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
// TODO: Have a stricter check in the strategy
val isRandomForest = numTrees > 1
val baggedInput = if (isRandomForest) {
val subsample = 1.0
val withReplacement = true
BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement)
} else {
BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
val subsample = strategy.subsample
val withReplacement = false
BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement)
}.persist(StorageLevel.MEMORY_AND_DISK)

// depth of the decision tree
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.mllib.tree.loss.{LeastSquaresError, Loss}
case class BoostingStrategy(
numEstimators: Int = 100,
learningRate: Double = 0.1,
subSample: Double = 1,
subsample: Double = 1,
loss: Loss = LeastSquaresError,
checkpointPeriod: Int = 20
)
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* this split will not be considered as a valid split.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
* @param subsample Fraction of the training data used for learning each decision tree.
*                  Default value is 1 (the full training set).
*/
@Experimental
class Strategy (
Expand All @@ -70,7 +71,8 @@ class Strategy (
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
val minInstancesPerNode: Int = 1,
val minInfoGain: Double = 0.0,
val maxMemoryInMB: Int = 256) extends Serializable {
val maxMemoryInMB: Int = 256,
val subsample: Double = 1) extends Serializable {

if (algo == Classification) {
require(numClassesForClassification >= 2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import cern.jet.random.engine.DRand

import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

/**
* Internal representation of a datapoint which belongs to several subsamples of the same dataset,
Expand All @@ -47,20 +48,61 @@ private[tree] object BaggedPoint {
* Convert an input dataset into its BaggedPoint representation,
* choosing subsample counts for each instance.
* Each subsample has the same number of instances as the original dataset,
* and is created by subsampling with replacement.
* and is created by subsampling with or without replacement, as selected by withReplacement.
* @param input Input dataset.
* @param subsample Fraction of the training data used for learning each decision tree.
* @param numSubsamples Number of subsamples of this RDD to take.
* @param seed Random seed.
* @param withReplacement Sampling with/without replacement.
* @return BaggedPoint dataset representation
*/
def convertToBaggedRDD[Datum](
    input: RDD[Datum],
    subsample: Double,
    numSubsamples: Int,
    withReplacement: Boolean): RDD[BaggedPoint[Datum]] = {
  // Dispatch to the conversion that matches the requested sampling mode.
  if (withReplacement) {
    convertToBaggedRDDSamplingWithReplacement(input, subsample, numSubsamples)
  } else if (numSubsamples == 1 && subsample == 1.0) {
    // A single subsample that keeps the whole dataset needs no random draws:
    // every instance simply gets the one-element weight array Array(1.0).
    convertToBaggedRDDWithoutSampling(input)
  } else {
    // Otherwise the per-instance weights must actually be drawn, producing
    // numSubsamples weights per instance. The non-sampling conversion cannot be
    // used here because it ignores subsample and always emits a single weight.
    convertToBaggedRDDSamplingWithoutReplacement(input, subsample, numSubsamples)
  }
}

/**
 * Convert an input dataset into BaggedPoints whose subsample weights are 0.0 or 1.0.
 * For every instance and every subsample, one value is drawn from the partition's
 * random generator; the weight is 1.0 when the draw is below `subsample`, else 0.0.
 * @param input Input dataset.
 * @param subsample Threshold compared against each random draw.
 * @param numSubsamples Number of weights generated per instance.
 * @return BaggedPoint dataset representation
 */
private def convertToBaggedRDDSamplingWithoutReplacement[Datum](
    input: RDD[Datum],
    subsample: Double,
    numSubsamples: Int): RDD[BaggedPoint[Datum]] = {
  val baseSeed = Utils.random.nextLong()
  input.mapPartitionsWithIndex { (partitionIndex, instances) =>
    // Each partition seeds its own generator with baseSeed + partitionIndex + 1,
    // so the draws within a partition depend only on baseSeed and the partition index.
    val generator = new XORShiftRandom
    generator.setSeed(baseSeed + partitionIndex + 1)
    instances.map { instance =>
      val weights = new Array[Double](numSubsamples)
      weights.indices.foreach { slot =>
        // One draw per subsample slot, taken in slot order.
        weights(slot) = if (generator.nextDouble() < subsample) 1.0 else 0.0
      }
      new BaggedPoint(instance, weights)
    }
  }
}

private def convertToBaggedRDDSamplingWithReplacement[Datum](
input: RDD[Datum],
subsample: Double,
numSubsamples: Int): RDD[BaggedPoint[Datum]] = {
val seed = Utils.random.nextInt()
input.mapPartitionsWithIndex { (partitionIndex, instances) =>
// TODO: Support different sampling rates, and sampling without replacement.
// Use random seed = seed + partitionIndex + 1 to make generation reproducible.
val poisson = new Poisson(1.0, new DRand(seed + partitionIndex + 1))
val poisson = new Poisson(subsample, new DRand(seed + partitionIndex + 1))
instances.map { instance =>
val subsampleWeights = new Array[Double](numSubsamples)
var subsampleIndex = 0
Expand All @@ -73,7 +115,8 @@ private[tree] object BaggedPoint {
}
}

def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
/**
 * Convert an input dataset into BaggedPoints without any sampling:
 * every datum is paired with a single subsample weight of 1.0.
 * @param input Input dataset.
 * @return BaggedPoint dataset representation
 */
private def convertToBaggedRDDWithoutSampling[Datum](
    input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
  input.map { datum =>
    // Allocate a fresh one-element weight array for each datum.
    val unitWeights = Array(1.0)
    new BaggedPoint(datum, unitWeights)
  }
}

Expand Down

0 comments on commit 1a8031c

Please sign in to comment.