diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index ef41ac996195f..018a492f75b7f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -25,46 +25,55 @@ import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /**
  * :: AlphaComponent ::
  * `Bucketizer` maps a column of continuous features to a column of feature buckets.
  */
 @AlphaComponent
-final class Bucketizer(override val parent: Estimator[Bucketizer] = null)
+private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
   extends Model[Bucketizer] with HasInputCol with HasOutputCol {
 
-  /**
-   * The given buckets should match 1) its size is larger than zero; 2) it is ordered in a non-DESC
-   * way.
-   */
-  private def checkBuckets(buckets: Array[Double]): Boolean = {
-    if (buckets.size == 0) false
-    else if (buckets.size == 1) true
-    else {
-      buckets.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
-        if (validator & prevValue <= currValue) {
-          (true, currValue)
-        } else {
-          (false, currValue)
-        }
-      }._1
-    }
-  }
+  def this() = this(null)
 
   /**
-   * Parameter for mapping continuous features into buckets.
+   * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
+   * A bucket defined by splits x,y holds values in the range (x,y].
    * @group param
    */
-  val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
-    "Split points for mapping continuous features into buckets.", checkBuckets)
+  val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
+    "Split points for mapping continuous features into buckets. With n splits, there are n+1" +
+      " buckets. A bucket defined by splits x,y holds values in the range (x,y].",
+    Bucketizer.checkSplits)
 
   /** @group getParam */
-  def getBuckets: Array[Double] = $(buckets)
+  def getSplits: Array[Double] = $(splits)
 
   /** @group setParam */
-  def setBuckets(value: Array[Double]): this.type = set(buckets, value)
+  def setSplits(value: Array[Double]): this.type = set(splits, value)
+
+  /** @group param */
+  val lowerInclusive: BooleanParam = new BooleanParam(this, "lowerInclusive",
+    "An indicator of the inclusiveness of negative infinity.")
+  setDefault(lowerInclusive -> true)
+
+  /** @group getParam */
+  def getLowerInclusive: Boolean = $(lowerInclusive)
+
+  /** @group setParam */
+  def setLowerInclusive(value: Boolean): this.type = set(lowerInclusive, value)
+
+  /** @group param */
+  val upperInclusive: BooleanParam = new BooleanParam(this, "upperInclusive",
+    "An indicator of the inclusiveness of positive infinity.")
+  setDefault(upperInclusive -> true)
+
+  /** @group getParam */
+  def getUpperInclusive: Boolean = $(upperInclusive)
+
+  /** @group setParam */
+  def setUpperInclusive(value: Boolean): this.type = set(upperInclusive, value)
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -74,24 +83,61 @@ final class Bucketizer(override val parent: Estimator[Bucketizer] = null)
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double => binarySearchForBuckets($(buckets), feature) }
-    val outputColName = $(outputCol)
-    val metadata = NominalAttribute.defaultAttr
-      .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
-    dataset.select(col("*"), bucketizer(dataset($(inputCol))).as(outputColName, metadata))
+    val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
+    val bucketizer = udf { feature: Double =>
+      Bucketizer.binarySearchForBuckets(wrappedSplits, feature) }
+    val newCol = bucketizer(dataset($(inputCol)))
+    val newField = prepOutputField(dataset.schema)
+    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+  }
+
+  private def prepOutputField(schema: StructType): StructField = {
+    val attr = new NominalAttribute(
+      name = Some($(outputCol)),
+      isOrdinal = Some(true),
+      numValues = Some($(splits).size),
+      values = Some($(splits).map(_.toString)))
+
+    attr.toStructField()
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    require(schema.fields.forall(_.name != $(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    StructType(schema.fields :+ prepOutputField(schema))
+  }
+}
+
+object Bucketizer {
+  /**
+   * The given splits must satisfy two conditions: 1) the array is non-empty; 2) the values are
+   * strictly increasing.
+   */
+  private def checkSplits(splits: Array[Double]): Boolean = {
+    if (splits.size == 0) false
+    else if (splits.size == 1) true
+    else {
+      splits.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
+        if (validator && prevValue < currValue) {
+          (true, currValue)
+        } else {
+          (false, currValue)
+        }
+      }._1
+    }
   }
 
   /**
    * Binary searching in several buckets to place each data point.
    */
-  private def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
+  private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
     var left = 0
-    var right = wrappedSplits.length - 2
+    var right = splits.length - 2
     while (left <= right) {
       val mid = left + (right - left) / 2
-      val split = wrappedSplits(mid)
-      if ((feature > split) && (feature <= wrappedSplits(mid + 1))) {
+      val split = splits(mid)
+      if ((feature > split) && (feature <= splits(mid + 1))) {
         return mid
       } else if (feature <= split) {
         right = mid - 1
@@ -99,20 +145,6 @@ final class Bucketizer(override val parent: Estimator[Bucketizer] = null)
         left = mid + 1
       }
     }
-    -1
-  }
-
-  override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-
-    val inputFields = schema.fields
-    val outputColName = $(outputCol)
-
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-
-    val attr = NominalAttribute.defaultAttr.withName(outputColName)
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    throw new RuntimeException("Failed to find a bucket.")
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 68c4ffbf5a73b..a89d9bbedb9f3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.ml.feature
 
+import org.scalatest.FunSuite
+
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.scalatest.FunSuite
 
 class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -34,11 +35,15 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
     val bucketizer: Bucketizer = new Bucketizer()
       .setInputCol("feature")
      .setOutputCol("result")
-      .setBuckets(buckets)
+      .setSplits(buckets)
 
     bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
       case Row(x: Double, y: Double) =>
        assert(x === y, "The feature value is not correct after bucketing.")
     }
   }
+
+  test("Binary search for finding buckets") {
+
+  }
 }
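
Note: the new "Binary search for finding buckets" test is added as an empty placeholder in this patch. A minimal sketch of a body it could carry is shown below. It assumes the suite, which lives in org.apache.spark.ml.feature and can therefore see the private[feature] helper, passes splits already wrapped with Double.MinValue/Double.MaxValue the same way transform() does; the feature/bucket pairs are illustrative values, not taken from the patch.

  test("Binary search for finding buckets") {
    // Splits pre-wrapped with Double.MinValue / Double.MaxValue, mirroring transform().
    val wrappedSplits = Array(Double.MinValue, -0.5, 0.0, 0.5, Double.MaxValue)
    // Bucket i covers (wrappedSplits(i), wrappedSplits(i + 1)], so each split value
    // falls into the bucket below it (e.g. -0.5 lands in bucket 0, 0.0 in bucket 1).
    val expected = Seq(-0.9 -> 0.0, -0.5 -> 0.0, -0.3 -> 1.0, 0.0 -> 1.0, 0.2 -> 2.0, 0.6 -> 3.0)
    expected.foreach { case (feature, bucket) =>
      assert(Bucketizer.binarySearchForBuckets(wrappedSplits, feature) === bucket,
        s"Feature $feature should be placed in bucket $bucket.")
    }
  }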