Skip to content

Commit

Permalink
[SPARK-11515][ML] QuantileDiscretizer should take random seed
Browse files Browse the repository at this point in the history
cc jkbradley

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #9535 from yu-iskw/SPARK-11515.
  • Loading branch information
yu-iskw authored and mengxr committed Feb 11, 2016
1 parent efb65e0 commit 574571c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param.{IntParam, _}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{DoubleType, StructType}
Expand All @@ -33,7 +33,8 @@ import org.apache.spark.util.random.XORShiftRandom
/**
* Params for [[QuantileDiscretizer]].
*/
private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol {
private[feature] trait QuantileDiscretizerBase extends Params
with HasInputCol with HasOutputCol with HasSeed {

/**
* Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
Expand Down Expand Up @@ -73,6 +74,9 @@ final class QuantileDiscretizer(override val uid: String)
/**
 * Sets the name of the output column.
 * @group setParam
 */
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
 * Sets the random seed used when sampling the input dataset to collect
 * quantile statistics (see the `seed` passed through to `getSampledInput`),
 * so repeated fits on the same data are reproducible.
 * @group setParam
 */
def setSeed(value: Long): this.type = set(seed, value)

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
Expand All @@ -85,7 +89,8 @@ final class QuantileDiscretizer(override val uid: String)
}

override def fit(dataset: DataFrame): Bucketizer = {
val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets))
val samples = QuantileDiscretizer
.getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
.map { case Row(feature: Double) => feature }
val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
val splits = QuantileDiscretizer.getSplits(candidates)
Expand All @@ -101,13 +106,13 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
/**
 * Samples the given dataset to collect quantile statistics.
 *
 * @param dataset input data; callers pass a single-column projection
 * @param numBins number of requested buckets; drives the required sample size
 * @param seed    seed for the sampling RNG, making the sample reproducible
 * @return the collected sampled rows
 * @throws IllegalArgumentException if the dataset is empty
 */
private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = {
  val totalSamples = dataset.count()
  require(totalSamples > 0,
    "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
  val requiredSamples = math.max(numBins * numBins, 10000)
  // Convert to Double before dividing: Int/Long division would truncate the
  // fraction to 0 whenever totalSamples > requiredSamples, sampling nothing.
  // Reuse totalSamples rather than triggering a second count() job.
  val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
  dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ private object QuantileDiscretizerSuite extends SparkFunSuite {

val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
.setNumBuckets(numBucket)
.setNumBuckets(numBucket).setSeed(1)
val result = discretizer.fit(df).transform(df)

val transformedFeatures = result.select("result").collect()
Expand Down

0 comments on commit 574571c

Please sign in to comment.