diff --git a/docs/ml-features.md b/docs/ml-features.md
index 3dbb960dea03e..418e94ad1ea19 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -855,6 +855,116 @@ for more details on the API.
+
+## TargetEncoder
+
+[Target Encoding](https://www.researchgate.net/publication/220520258_A_Preprocessing_Scheme_for_High-Cardinality_Categorical_Attributes_in_Classification_and_Prediction_Problems) is a data-preprocessing technique that transforms high-cardinality categorical features into quasi-continuous scalar attributes suited for use in regression-type models. This paradigm maps individual values of an independent feature to a scalar that represents some estimate of the dependent attribute, so categorical values that exhibit similar statistics with respect to the target get a similar representation.
+
+By leveraging the relationship between categorical features and the target variable, Target Encoding usually performs better than One-Hot Encoding and does not require a final binary vector encoding, decreasing the overall dimensionality of the dataset.
+
+Users can specify input and output column names by setting `inputCol` and `outputCol` for single-column use cases, or `inputCols` and `outputCols` for multi-column use cases (both arrays are required to have the same size). These columns are expected to contain categorical indices (non-negative integers), with missing values (null) treated as a separate category. The data type must be a subclass of `NumericType`. For string-type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first.
+
+Users can specify the target column name by setting `labelCol`. This column is expected to contain the ground-truth labels from which encodings will be derived. Observations with a missing label (null) are not considered when calculating estimates. The data type must be a subclass of `NumericType`.
+
+`TargetEncoder` supports the `handleInvalid` parameter to choose how to handle invalid input, meaning categories not seen during training, when encoding new data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an exception).
+
+`TargetEncoder` supports the `targetType` parameter to choose the label type when fitting data, which affects how estimates are calculated. Available options include 'binary' and 'continuous'.
+
+When set to 'binary', the target attribute $Y$ is expected to be binary, $Y\in\{ 0,1 \}$. The transformation maps individual values $X_{i}$ to the conditional probability of $Y$ given that $X=X_{i}\;$: $\;\; S_{i}=P(Y\mid X=X_{i})$. This approach is also known as bin-counting.
+
+When set to 'continuous', the target attribute $Y$ is expected to be continuous, $Y\in\mathbb{R}$. The transformation maps individual values $X_{i}$ to the average of $Y$ given that $X=X_{i}\;$: $\;\; S_{i}=E[Y\mid X=X_{i}]$. This approach is also known as mean-encoding.
+
+`TargetEncoder` supports the `smoothing` parameter to tune how in-category stats and overall stats are blended. High-cardinality categorical features are usually unevenly distributed across all possible values of $X$.
+Therefore, calculating encodings $S_{i}$ according only to in-class statistics makes these estimates very unreliable, and rarely seen categories will very likely cause overfitting in learning.
+
+Smoothing prevents this behaviour by weighting in-class estimates with overall estimates according to the relative size of the particular class on the whole dataset:
+
+$\;\;\; S_{i}=\lambda(n_{i})\, P(Y\mid X=X_{i})+(1-\lambda(n_{i}))\, P(Y)$ for the binary case
+
+$\;\;\; S_{i}=\lambda(n_{i})\, E[Y\mid X=X_{i}]+(1-\lambda(n_{i}))\, E[Y]$ for the continuous case
+
+where $\lambda(n_{i})$ is a monotonically increasing function of $n_{i}$, bounded between 0 and 1.
+
+Usually $\lambda(n_{i})$ is implemented as the parametric function $\lambda(n_{i})=\frac{n_{i}}{n_{i}+m}$, where $m$ is the smoothing factor, represented by the `smoothing` parameter in `TargetEncoder`.
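+
+For example (an illustrative calculation, not output produced by the estimator itself): with smoothing factor $m=1$, a category seen $n_{i}=3$ times whose in-category rate is $P(Y\mid X=X_{i})=1/3$, in a dataset whose overall rate is $P(Y)=1/2$, gets $\lambda(3)=\frac{3}{3+1}=0.75$ and encoding $0.75\cdot\frac{1}{3}+0.25\cdot\frac{1}{2}=0.375$; a category seen only once with in-category rate $1$ gets $\lambda(1)=0.5$ and is pulled much harder towards the overall rate: $0.5\cdot 1+0.5\cdot\frac{1}{2}=0.75$.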
+
+**Examples**
+
+To illustrate how `TargetEncoder` works, let's assume we have the following
+DataFrame with a column `feature` and two flavours of `target` (binary and continuous):
+
+~~~~
+ feature | target | target
+         | (bin)  | (cont)
+ --------|--------|--------
+    1    |   0    |  1.3
+    1    |   1    |  2.5
+    1    |   0    |  1.6
+    2    |   1    |  1.8
+    2    |   0    |  2.4
+    3    |   1    |  3.2
+~~~~
+
+Applying `TargetEncoder` with 'binary' target type,
+`feature` as the input column, `target (bin)` as the label column
+and `encoded` as the output column, we are able to fit a model
+on the data to learn encodings and transform the data according
+to these mappings:
+
+~~~~
+ feature | target | encoded
+         | (bin)  |
+ --------|--------|--------
+    1    |   0    |  0.333
+    1    |   1    |  0.333
+    1    |   0    |  0.333
+    2    |   1    |  0.5
+    2    |   0    |  0.5
+    3    |   1    |  1.0
+~~~~
+
+Applying `TargetEncoder` with 'continuous' target type,
+`feature` as the input column, `target (cont)` as the label column
+and `encoded` as the output column, we are able to fit a model
+on the data to learn encodings and transform the data according
+to these mappings:
+
+~~~~
+ feature | target | encoded
+         | (cont) |
+ --------|--------|--------
+    1    |  1.3   |  1.8
+    1    |  2.5   |  1.8
+    1    |  1.6   |  1.8
+    2    |  1.8   |  2.1
+    2    |  2.4   |  2.1
+    3    |  3.2   |  3.2
+~~~~
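+
+In code, a minimal single-column sketch might look as follows (the `train` and `test` DataFrames here are assumed to carry `feature`/`target` columns and are not part of the examples above; `handleInvalid` is set to 'keep' so categories never seen during fitting are encoded with the overall statistics instead of raising an error):
+
+~~~~
+import org.apache.spark.ml.feature.TargetEncoder
+
+val encoder = new TargetEncoder()
+  .setInputCol("feature")       // categorical indices
+  .setOutputCol("encoded")
+  .setLabelCol("target")
+  .setTargetType("continuous")  // mean-encoding
+  .setSmoothing(1.0)            // blend in-category and overall means
+  .setHandleInvalid("keep")     // map unseen categories to overall statistics
+
+val model = encoder.fit(train)
+val encoded = model.transform(test)
+~~~~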
+
+<div class="codetabs">
+
+<div data-lang="python" markdown="1">
+
+Refer to the [TargetEncoder Python docs](api/python/reference/api/pyspark.ml.feature.TargetEncoder.html) for more details on the API.
+
+{% include_example python/ml/target_encoder_example.py %}
+</div>
+
+<div data-lang="scala" markdown="1">
+
+Refer to the [TargetEncoder Scala docs](api/scala/org/apache/spark/ml/feature/TargetEncoder.html) for more details on the API.
+
+{% include_example scala/org/apache/spark/examples/ml/TargetEncoderExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [TargetEncoder Java docs](api/java/org/apache/spark/ml/feature/TargetEncoder.html)
+for more details on the API.
+
+{% include_example java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java %}
+</div>
+
+</div>
+ ## VectorIndexer `VectorIndexer` helps index categorical features in datasets of `Vector`s. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java new file mode 100644 index 0000000000000..da391bd469192 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.SparkSession; + +// $example on$ +import org.apache.spark.ml.feature.TargetEncoder; +import org.apache.spark.ml.feature.TargetEncoderModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import java.util.Arrays; +import java.util.List; +// $example off$ + +public class JavaTargetEncoderExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaTargetEncoderExample") + .getOrCreate(); + + // Note: categorical features are usually first encoded with StringIndexer + // $example on$ + List data = Arrays.asList( + RowFactory.create(0.0, 1.0, 0, 10.0), + RowFactory.create(1.0, 0.0, 1, 20.0), + RowFactory.create(2.0, 1.0, 0, 30.0), + RowFactory.create(0.0, 2.0, 1, 40.0), + RowFactory.create(0.0, 1.0, 0, 50.0), + RowFactory.create(2.0, 0.0, 1, 60.0) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("categoryIndex1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("categoryIndex2", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("binaryLabel", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("continuousLabel", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + // binary target + TargetEncoder bin_encoder = new TargetEncoder() + .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) + .setOutputCols(new String[] {"categoryIndex1Target", "categoryIndex2Target"}) + .setLabelCol("binaryLabel") + .setTargetType("binary"); + + TargetEncoderModel bin_model = bin_encoder.fit(df); + Dataset bin_encoded = bin_model.transform(df); + bin_encoded.show(); + + // continuous target + TargetEncoder cont_encoder = new TargetEncoder() + .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) + .setOutputCols(new String[] {"categoryIndex1Target", "categoryIndex2Target"}) + .setLabelCol("continuousLabel") + .setTargetType("continuous"); + + 
+    TargetEncoderModel cont_model = cont_encoder.fit(df);
+    Dataset<Row> cont_encoded = cont_model.transform(df);
+    cont_encoded.show();
+    // $example off$
+
+    spark.stop();
+  }
+}
diff --git a/examples/src/main/python/ml/target_encoder_example.py b/examples/src/main/python/ml/target_encoder_example.py
new file mode 100644
index 0000000000000..f6c1010de71f3
--- /dev/null
+++ b/examples/src/main/python/ml/target_encoder_example.py
@@ -0,0 +1,65 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# $example on$
+from pyspark.ml.feature import TargetEncoder
+
+# $example off$
+from pyspark.sql import SparkSession
+
+if __name__ == "__main__":
+    spark = SparkSession.builder.appName("TargetEncoderExample").getOrCreate()
+
+    # Note: categorical features are usually first encoded with StringIndexer
+    # $example on$
+    df = spark.createDataFrame(
+        [
+            (0.0, 1.0, 0, 10.0),
+            (1.0, 0.0, 1, 20.0),
+            (2.0, 1.0, 0, 30.0),
+            (0.0, 2.0, 1, 40.0),
+            (0.0, 1.0, 0, 50.0),
+            (2.0, 0.0, 1, 60.0),
+        ],
+        ["categoryIndex1", "categoryIndex2", "binaryLabel", "continuousLabel"],
+    )
+
+    # binary target
+    encoder = TargetEncoder(
+        inputCols=["categoryIndex1", "categoryIndex2"],
+        outputCols=["categoryIndex1Target", "categoryIndex2Target"],
+        labelCol="binaryLabel",
+        targetType="binary"
+    )
+    model = encoder.fit(df)
+    encoded = model.transform(df)
+    encoded.show()
+
+    # continuous target
+    encoder = TargetEncoder(
+        inputCols=["categoryIndex1", "categoryIndex2"],
+        outputCols=["categoryIndex1Target", "categoryIndex2Target"],
+        labelCol="continuousLabel",
+        targetType="continuous"
+    )
+
+    model = encoder.fit(df)
+    encoded = model.transform(df)
+    encoded.show()
+    # $example off$
+
+    spark.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TargetEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TargetEncoderExample.scala
new file mode 100644
index 0000000000000..a03f903c86d06
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/TargetEncoderExample.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.feature.TargetEncoder
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+object TargetEncoderExample {
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession
+      .builder()
+      .appName("TargetEncoderExample")
+      .getOrCreate()
+
+    // Note: categorical features are usually first encoded with StringIndexer
+    // $example on$
+    val df = spark.createDataFrame(Seq(
+      (0.0, 1.0, 0, 10.0),
+      (1.0, 0.0, 1, 20.0),
+      (2.0, 1.0, 0, 30.0),
+      (0.0, 2.0, 1, 40.0),
+      (0.0, 1.0, 0, 50.0),
+      (2.0, 0.0, 1, 60.0)
+    )).toDF("categoryIndex1", "categoryIndex2",
+      "binaryLabel", "continuousLabel")
+
+    // binary target
+    val bin_encoder = new TargetEncoder()
+      .setInputCols(Array("categoryIndex1", "categoryIndex2"))
+      .setOutputCols(Array("categoryIndex1Target", "categoryIndex2Target"))
+      .setLabelCol("binaryLabel")
+      .setTargetType("binary")
+
+    val bin_model = bin_encoder.fit(df)
+    val bin_encoded = bin_model.transform(df)
+    bin_encoded.show()
+
+    // continuous target
+    val cont_encoder = new TargetEncoder()
+      .setInputCols(Array("categoryIndex1", "categoryIndex2"))
+      .setOutputCols(Array("categoryIndex1Target", "categoryIndex2Target"))
+      .setLabelCol("continuousLabel")
+      .setTargetType("continuous")
+
+    val cont_model = cont_encoder.fit(df)
+    val cont_encoded = cont_model.transform(df)
+    cont_encoded.show()
+    // $example off$
+
+    spark.stop()
+  }
+}
+// scalastyle:on println
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
new file mode 100644
index 0000000000000..2be3529d00a35
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
@@ -0,0 +1,447 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/** Private trait for params and common methods for TargetEncoder and TargetEncoderModel */
+private[ml] trait TargetEncoderBase extends Params with HasLabelCol
+  with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols with HasHandleInvalid {
+
+  /**
+   * Param for how to handle invalid data during transform().
+   * Options are 'keep' (invalid data presented as an extra categorical feature) or
+   * 'error' (throw an error).
+   * Note that this Param is only used during transform; during fitting, invalid data
+   * will result in an error.
+   * Default: "error"
+   * @group param
+   */
+  @Since("4.0.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    "How to handle invalid data during transform(). " +
+      "Options are 'keep' (invalid data presented as an extra categorical feature) " +
+      "or 'error' (throw an error). Note that this Param is only used during transform; " +
+      "during fitting, invalid data will result in an error.",
+    ParamValidators.inArray(TargetEncoder.supportedHandleInvalids))
+
+  setDefault(handleInvalid -> TargetEncoder.ERROR_INVALID)
+
+  @Since("4.0.0")
+  val targetType: Param[String] = new Param[String](this, "targetType",
+    "Type of label considered during fit(). " +
+      "Options are 'binary' and 'continuous'. When 'binary', estimates are calculated as " +
+      "the conditional probability of the target given each category; when 'continuous', " +
+      "estimates are calculated as the average of the target given each category.",
+    ParamValidators.inArray(TargetEncoder.supportedTargetTypes))
+
+  setDefault(targetType -> TargetEncoder.TARGET_BINARY)
+
+  final def getTargetType: String = $(targetType)
+
+  @Since("4.0.0")
+  val smoothing: DoubleParam = new DoubleParam(this, "smoothing",
+    "Smoothing factor used to blend in-category estimates with overall estimates " +
+      "according to the relative size of the particular category on the whole dataset.",
+    ParamValidators.gtEq(0.0))
+
+  setDefault(smoothing -> 0.0)
+
+  final def getSmoothing: Double = $(smoothing)
+
+  private[feature] lazy val inputFeatures = if (isSet(inputCol)) Array($(inputCol))
+    else if (isSet(inputCols)) $(inputCols)
+    else Array.empty[String]
+
+  private[feature] lazy val outputFeatures = if (isSet(outputCol)) Array($(outputCol))
+    else if (isSet(outputCols)) $(outputCols)
+    else inputFeatures.map{field: String => s"${field}_indexed"}
+
+  private[feature] def validateSchema(schema: StructType,
+                                      fitting: Boolean): StructType = {
+
+    require(inputFeatures.length > 0,
+      s"At least one input column must be specified.")
+
+    require(inputFeatures.length == outputFeatures.length,
+      s"The number of input columns ${inputFeatures.length} must be the same as the number of " +
+        s"output columns ${outputFeatures.length}.")
+
+    val features = if (fitting) inputFeatures :+ $(labelCol)
+      else inputFeatures
+
+    features.foreach {
+      feature => {
+        try {
+          val field = schema(feature)
+          if (!field.dataType.isInstanceOf[NumericType]) {
+            throw new SparkException(s"Data type for column ${feature} is ${field.dataType}" +
+              s", but a subclass of ${NumericType} is required.")
+          }
+        } catch {
+          case e: IllegalArgumentException =>
+            throw new SparkException(s"No column named ${feature} found on dataset.")
+        }
+      }
+    }
+    schema
+  }
+
+}
+
+/**
+ * Target Encoding maps a column of categorical indices into a numerical feature derived
+ * from the target.
+ *
+ * When `handleInvalid` is configured to 'keep', previously unseen values of a feature
+ * are mapped to the dataset overall statistics.
+ *
+ * When 'targetType' is configured to 'binary', categories are encoded as the conditional
+ * probability of the target given that category (bin counting).
+ * When 'targetType' is configured to 'continuous', categories are encoded as the average
+ * of the target given that category (mean encoding).
+ *
+ * Parameter 'smoothing' controls how in-category stats and overall stats are weighted.
+ *
+ * @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols
+ * come in pairs, specified by the order in the arrays, and each pair is treated independently.
+ *
+ * @see `StringIndexer` for converting categorical values into category indices
+ */
+@Since("4.0.0")
+class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String)
+  extends Estimator[TargetEncoderModel] with TargetEncoderBase with DefaultParamsWritable {
+
+  @Since("4.0.0")
+  def this() = this(Identifiable.randomUID("TargetEncoder"))
+
+  /** @group setParam */
+  @Since("4.0.0")
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /** @group setParam */
+  @Since("4.0.0")
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  @Since("4.0.0")
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  @Since("4.0.0")
+  def setInputCols(values: Array[String]): this.type = set(inputCols, values)
+
+  /** @group setParam */
+  @Since("4.0.0")
+  def setOutputCols(values: Array[String]): this.type = set(outputCols, values)
+
+  /** @group setParam */
+  @Since("4.0.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /** @group setParam */
+  @Since("4.0.0")
+  def setTargetType(value: String): this.type = set(targetType, value)
+
+  /** @group setParam */
+  @Since("4.0.0")
+  def setSmoothing(value: Double): this.type = set(smoothing, value)
+
+  @Since("4.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    validateSchema(schema, fitting = true)
+  }
+
+  @Since("4.0.0")
+  override def fit(dataset: Dataset[_]): TargetEncoderModel = {
+    validateSchema(dataset.schema, fitting = true)
+
+    val feature_types = inputFeatures.map{
+      feature => dataset.schema(feature).dataType
+    }
+    val label_type = dataset.schema($(labelCol)).dataType
+
+    val stats = dataset
+      .select((inputFeatures :+ $(labelCol)).map(col).toIndexedSeq: _*)
+      .rdd.treeAggregate(
+        Array.fill(inputFeatures.length) {
+          Map.empty[Option[Double], (Double, Double)]
+        })(
+        (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) {
+          val label = label_type match {
+            case ByteType => row.getByte(inputFeatures.length).toDouble
+            case ShortType => row.getShort(inputFeatures.length).toDouble
+            case IntegerType => row.getInt(inputFeatures.length).toDouble
+            case LongType => row.getLong(inputFeatures.length).toDouble
+            case FloatType => row.getFloat(inputFeatures.length).toDouble
+            case DoubleType => row.getDouble(inputFeatures.length)
+          }
+          inputFeatures.indices.map {
+            feature => {
+              val category: Option[Double] = {
+                if (row.isNullAt(feature)) None // null category
+                else {
+                  val value: Double = feature_types(feature) match {
+                    case ByteType => row.getByte(feature).toDouble
+                    case ShortType => row.getShort(feature).toDouble
+                    case IntegerType => row.getInt(feature).toDouble
+                    case LongType => row.getLong(feature).toDouble
+                    case FloatType => row.getFloat(feature).toDouble
+                    case DoubleType => row.getDouble(feature)
+                  }
+                  if (value < 0.0 || value != value.toInt) throw new SparkException(
+                    s"Values from column ${inputFeatures(feature)} must be indices, " +
+                      s"but got $value.")
+                  else Some(value) // non-null category
+                }
+              }
+              val (class_count, class_stat) = agg(feature).getOrElse(category, (0.0, 0.0))
+              val (global_count, global_stat) =
+                agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0))
+              $(targetType) match {
+                case TargetEncoder.TARGET_BINARY => // counting
+                  if (label == 1.0) {
+                    // positive => increment both counters and stats for current & unseen categories
+                    agg(feature) +
+                      (category -> (1 + class_count, 1 + class_stat)) +
+                      (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 1 + global_stat))
+                  } else if (label == 0.0) {
+                    // negative => increment only the counters (not the stats) for 
current & unseen categories
+                    agg(feature) +
+                      (category -> (1 + class_count, class_stat)) +
+                      (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, global_stat))
+                  } else throw new SparkException(
+                    s"Values from column ${getLabelCol} must be binary (0,1) but got $label.")
+                case TargetEncoder.TARGET_CONTINUOUS => // incremental mean
+                  // increment counter and iterate on mean for current & unseen categories
+                  agg(feature) +
+                    (category -> (1 + class_count,
+                      class_stat + ((label - class_stat) / (1 + class_count)))) +
+                    (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count,
+                      global_stat + ((label - global_stat) / (1 + global_count))))
+              }
+            }
+          }.toArray
+        } else agg, // ignore null-labeled observations
+        (agg1, agg2) => inputFeatures.indices.map {
+          feature => {
+            val categories = agg1(feature).keySet ++ agg2(feature).keySet
+            categories.map(category =>
+              category -> {
+                val (counter1, stat1) = agg1(feature).getOrElse(category, (0.0, 0.0))
+                val (counter2, stat2) = agg2(feature).getOrElse(category, (0.0, 0.0))
+                $(targetType) match {
+                  case TargetEncoder.TARGET_BINARY => (counter1 + counter2, stat1 + stat2)
+                  case TargetEncoder.TARGET_CONTINUOUS => (counter1 + counter2,
+                    ((counter1 * stat1) + (counter2 * stat2)) / (counter1 + counter2))
+                }
+              }).toMap
+          }
+        }.toArray)
+
+    // encodings: Map[feature, Map[Some(category), encoding]]
+    val encodings: Map[String, Map[Option[Double], Double]] =
+      inputFeatures.zip(stats).map {
+        case (feature, stat) =>
+          val (global_count, global_stat) = stat.get(TargetEncoder.UNSEEN_CATEGORY).get
+          feature -> stat.map {
+            case (cat, (class_count, class_stat)) => cat -> {
+              val weight = class_count / (class_count + $(smoothing)) // smoothing weight
+              $(targetType) match {
+                case TargetEncoder.TARGET_BINARY =>
+                  // calculate conditional probabilities and blend
+                  weight * (class_stat / class_count) + (1 - weight) * (global_stat / global_count)
+                case TargetEncoder.TARGET_CONTINUOUS =>
+                  // blend means
+                  weight * class_stat + (1 - weight) * global_stat
+              }
+            }
+          }
+      }.toMap
+
+    val model = new TargetEncoderModel(uid, encodings).setParent(this)
+    copyValues(model)
+  }
+
+  @Since("4.0.0")
+  override def copy(extra: ParamMap): TargetEncoder = defaultCopy(extra)
+}
+
+@Since("4.0.0")
+object TargetEncoder extends DefaultParamsReadable[TargetEncoder] {
+
+  // handleInvalid parameter values
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID)
+
+  // targetType parameter values
+  private[feature] val TARGET_BINARY: String = "binary"
+  private[feature] val TARGET_CONTINUOUS: String = "continuous"
+  private[feature] val supportedTargetTypes: Array[String] = Array(TARGET_BINARY, TARGET_CONTINUOUS)
+
+  private[feature] val UNSEEN_CATEGORY: Option[Double] = Some(-1)
+
+  @Since("4.0.0")
+  override def load(path: String): TargetEncoder = super.load(path)
+}
+
+/**
+ * @param encodings  Fitted mappings from each input feature to its encodings: for every
+ *                   input column, a map from each category seen during fitting (with `None`
+ *                   for the null category and a reserved key for unseen categories) to the
+ *                   scalar encoding derived from the target.
+ */ +@Since("4.0.0") +class TargetEncoderModel private[ml] ( + @Since("4.0.0") override val uid: String, + @Since("4.0.0") val encodings: Map[String, Map[Option[Double], Double]]) + extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable { + + @Since("4.0.0") + override def transformSchema(schema: StructType): StructType = { + inputFeatures.zip(outputFeatures) + .foldLeft(validateSchema(schema, fitting = false)) { + case (newSchema, fieldName) => + val field = schema(fieldName._1) + newSchema.add(StructField(fieldName._2, field.dataType, field.nullable)) + } + } + + @Since("4.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + validateSchema(dataset.schema, fitting = false) + + // builds a column-to-column function from a map of encodings + val apply_encodings: Map[Option[Double], Double] => (Column => Column) = + (mappings: Map[Option[Double], Double]) => { + (col: Column) => { + val nullWhen = when(col.isNull, + mappings.get(None) match { + case Some(code) => lit(code) + case None => if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) { + lit(mappings.get(TargetEncoder.UNSEEN_CATEGORY).get) + } else raise_error(lit( + s"Unseen null value in feature ${col.toString}. To handle unseen values, " + + s"set Param handleInvalid to ${TargetEncoder.KEEP_INVALID}.")) + }) + val ordered_mappings = (mappings - None).toList.sortWith { + (a, b) => (b._1 == TargetEncoder.UNSEEN_CATEGORY) || + ((a._1 != TargetEncoder.UNSEEN_CATEGORY) && (a._1.get < b._1.get)) + } + ordered_mappings + .foldLeft(nullWhen)( + (new_col: Column, mapping) => { + val (Some(original), encoded) = mapping + if (original != TargetEncoder.UNSEEN_CATEGORY.get) { + new_col.when(col === original, lit(encoded)) + } else { // unseen category + new_col.otherwise( + if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) lit(encoded) + else raise_error(concat( + lit("Unseen value "), col, + lit(s" in feature ${col.toString}. 
To handle unseen values, " +
+                        s"set Param handleInvalid to ${TargetEncoder.KEEP_INVALID}."))))
+                }
+              })
+          }
+        }
+
+    dataset.withColumns(
+      inputFeatures.zip(outputFeatures).map {
+        feature =>
+          feature._2 -> (encodings.get(feature._1) match {
+            case Some(dict) =>
+              apply_encodings(dict)(col(feature._1))
+                .as(feature._2, NominalAttribute.defaultAttr
+                  .withName(feature._2)
+                  .withNumValues(dict.size)
+                  .withValues(dict.values.toSet.toArray.map(_.toString)).toMetadata())
+            case None =>
+              throw new SparkException(s"No encodings found for ${feature._1}.")
+          })
+      }.toMap)
+  }
+
+
+  @Since("4.0.0")
+  override def copy(extra: ParamMap): TargetEncoderModel = {
+    val copied = new TargetEncoderModel(uid, encodings)
+    copyValues(copied, extra).setParent(parent)
+  }
+
+  @Since("4.0.0")
+  override def write: MLWriter = new TargetEncoderModel.TargetEncoderModelWriter(this)
+
+  @Since("4.0.0")
+  override def toString: String = {
+    s"TargetEncoderModel: uid=$uid, " +
+      s"handleInvalid=${$(handleInvalid)}, targetType=${$(targetType)}, " +
+      s"numInputCols=${inputFeatures.length}, numOutputCols=${outputFeatures.length}, " +
+      s"smoothing=${$(smoothing)}"
+  }
+
+}
+
+@Since("4.0.0")
+object TargetEncoderModel extends MLReadable[TargetEncoderModel] {
+
+  private[TargetEncoderModel]
+  class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter {
+
+    private case class Data(encodings: Map[String, Map[Option[Double], Double]])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
+      val data = Data(instance.encodings)
+      val dataPath = new Path(path, "data").toString
+      sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
+    }
+  }
+
+  private class TargetEncoderModelReader extends MLReader[TargetEncoderModel] {
+
+    private val className = classOf[TargetEncoderModel].getName
+
+    override def load(path: String): TargetEncoderModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sparkSession.read.parquet(dataPath)
+        .select("encodings")
+        .head()
+      val encodings = data.getAs[Map[String, Map[Option[Double], Double]]](0)
+      val model = new TargetEncoderModel(metadata.uid, encodings)
+      metadata.getAndSetParams(model)
+      model
+    }
+  }
+
+  @Since("4.0.0")
+  override def read: MLReader[TargetEncoderModel] = new TargetEncoderModelReader
+
+  @Since("4.0.0")
+  override def load(path: String): TargetEncoderModel = super.load(path)
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java
new file mode 100644
index 0000000000000..44e38543c515e
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import org.apache.spark.SharedSparkSession;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.spark.sql.types.DataTypes.*;
+
+public class JavaTargetEncoderSuite extends SharedSparkSession {
+
+  @Test
+  public void testTargetEncoderBinary() {
+
+    List<Row> data = Arrays.asList(
+      RowFactory.create((short)0, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3),
+      RowFactory.create((short)1, 4, 5.0, 1.0, 2.0/3, 1.0, 1.0/3),
+      RowFactory.create((short)2, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3),
+      RowFactory.create((short)0, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3),
+      RowFactory.create((short)1, 3, 6.0, 0.0, 2.0/3, 0.0, 2.0/3),
+      RowFactory.create((short)2, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3),
+      RowFactory.create((short)0, 3, 7.0, 0.0, 1.0/3, 0.0, 0.0),
+      RowFactory.create((short)1, 4, 8.0, 1.0, 2.0/3, 1.0, 1.0),
+      RowFactory.create((short)2, 3, null, 0.0, 1.0/3, 0.0, 0.0));
+    StructType schema = createStructType(new StructField[]{
+      createStructField("input1", ShortType, true),
+      createStructField("input2", IntegerType, true),
+      createStructField("input3", DoubleType, true),
+      createStructField("label", DoubleType, false),
+      createStructField("expected1", DoubleType, false),
+      createStructField("expected2", DoubleType, false),
+      createStructField("expected3", DoubleType, false)
+    });
+    Dataset<Row> dataset = spark.createDataFrame(data, schema);
+
+    TargetEncoder encoder = new TargetEncoder()
+      .setInputCols(new String[]{"input1", "input2", "input3"})
+      .setOutputCols(new String[]{"output1", "output2", "output3"})
+      .setTargetType("binary");
+    TargetEncoderModel model = encoder.fit(dataset);
+    Dataset<Row> output = model.transform(dataset);
+
+    Assertions.assertEquals(
+      output.select("output1", "output2", "output3").collectAsList(),
+      output.select("expected1", "expected2", "expected3").collectAsList());
+
+  }
+
+  @Test
+  public void testTargetEncoderContinuous() {
+
+    List<Row> data = Arrays.asList(
+      RowFactory.create((short)0, 3, 5.0, 10.0, 40.0, 50.0, 20.0),
+      RowFactory.create((short)1, 4, 5.0, 20.0, 50.0, 50.0, 20.0),
+      RowFactory.create((short)2, 3, 5.0, 30.0, 60.0, 50.0, 20.0),
+      RowFactory.create((short)0, 4, 6.0, 40.0, 40.0, 50.0, 50.0),
+      RowFactory.create((short)1, 3, 6.0, 50.0, 50.0, 50.0, 50.0),
+      RowFactory.create((short)2, 4, 6.0, 60.0, 60.0, 50.0, 50.0),
+      RowFactory.create((short)0, 3, 7.0, 70.0, 40.0, 50.0, 70.0),
+      RowFactory.create((short)1, 4, 8.0, 80.0, 50.0, 50.0, 80.0),
+      RowFactory.create((short)2, 3, null, 90.0, 60.0, 50.0, 90.0));
+    StructType schema = createStructType(new StructField[]{
+      createStructField("input1", ShortType, true),
+      createStructField("input2", IntegerType, true),
+      createStructField("input3", DoubleType, true),
+      createStructField("label", DoubleType, false),
+      createStructField("expected1", DoubleType, false),
+      createStructField("expected2", DoubleType, false),
+      createStructField("expected3", DoubleType, false)
+    });
+    Dataset<Row> dataset = spark.createDataFrame(data, schema);
+
+    TargetEncoder encoder = new TargetEncoder()
+      .setInputCols(new String[]{"input1", "input2", "input3"})
+      .setOutputCols(new String[]{"output1", "output2", "output3"})
+      .setTargetType("continuous");
+    TargetEncoderModel model = encoder.fit(dataset);
+    Dataset<Row> output = model.transform(dataset);
+
+    Assertions.assertEquals(
+      output.select("output1", "output2", "output3").collectAsList(),
+      output.select("expected1", "expected2", "expected3").collectAsList());
+
+  }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
new file mode 100644
index 0000000000000..4d3f4f3f7213b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
@@ -0,0 +1,469 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.ml.feature + +import scala.collection.immutable.HashMap + +import org.apache.spark.{SparkException, SparkRuntimeException} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ + + @transient var data: Seq[Row] = _ + @transient var schema: StructType = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + // scalastyle:off + data = Seq( + Row(0.toShort, 3, 5.0, 0.toByte, 1.0/3, 0.0, 1.0/3, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5), + Row(1.toShort, 4, 5.0, 1.toByte, 2.0/3, 1.0, 1.0/3, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5), + Row(2.toShort, 3, 5.0, 0.toByte, 1.0/3, 0.0, 1.0/3, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5), + Row(0.toShort, 4, 6.0, 1.toByte, 1.0/3, 1.0, 2.0/3, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0), + Row(1.toShort, 3, 6.0, 0.toByte, 2.0/3, 0.0, 2.0/3, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0), + Row(2.toShort, 4, 6.0, 1.toByte, 1.0/3, 1.0, 2.0/3, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0), + Row(0.toShort, 3, 7.0, 0.toByte, 1.0/3, 0.0, 0.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0), + Row(1.toShort, 4, 8.0, 1.toByte, 2.0/3, 1.0, 1.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0), + Row(2.toShort, 3, 9.0, 0.toByte, 1.0/3, 0.0, 0.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)) + // scalastyle:on + + schema = StructType(Array( + StructField("input1", ShortType, nullable = true), + StructField("input2", IntegerType, nullable = true), + StructField("input3", DoubleType, nullable = true), + StructField("binaryLabel", ByteType), + StructField("binaryExpected1", DoubleType), + StructField("binaryExpected2", DoubleType), + StructField("binaryExpected3", DoubleType), + StructField("continuousLabel", DoubleType), + StructField("continuousExpected1", DoubleType), + StructField("continuousExpected2", DoubleType), + StructField("continuousExpected3", DoubleType), + StructField("smoothingExpected1", DoubleType), + StructField("smoothingExpected2", DoubleType), + StructField("smoothingExpected3", DoubleType))) + } + + test("params") { + ParamsSuite.checkParams(new TargetEncoder) + } + + test("TargetEncoder - binary target") { + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("binaryLabel") + .setTargetType(TargetEncoder.TARGET_BINARY) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df) + + val expected_encodings = Map( + "input1" -> + Map(Some(0.0) -> 1.0/3, Some(1.0) -> 2.0/3, Some(2.0) -> 1.0/3, Some(-1.0) -> 4.0/9), + "input2" -> Map(Some(3.0) -> 0.0, Some(4.0) -> 1.0, Some(-1.0) -> 4.0/9), + "input3" -> HashMap(Some(5.0) -> 1.0/3, Some(6.0) -> 2.0/3, Some(7.0) -> 0.0, + Some(8.0) -> 1.0, Some(9.0) -> 0.0, Some(-1.0) -> 4.0/9)) + + assert(model.encodings.equals(expected_encodings)) + + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df.select("input1", "input2", "input3", + "binaryExpected1", "binaryExpected2", "binaryExpected3"), + model, + "output1", "binaryExpected1", + "output2", "binaryExpected2", + "output3", "binaryExpected3") { + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === 
expected2) + assert(output3 === expected3) + } + + } + + test("TargetEncoder - continuous target") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df) + + val expected_encodings = Map( + "input1" -> Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), + "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), + "input3" -> HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, + Some(8.0) -> 80.0, Some(9.0) -> 90.0, Some(-1.0) -> 50.0)) + + assert(model.encodings.equals(expected_encodings)) + + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df.select("input1", "input2", "input3", + "continuousExpected1", "continuousExpected2", "continuousExpected3"), + model, + "output1", "continuousExpected1", + "output2", "continuousExpected2", + "output3", "continuousExpected3") { + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } + + } + + test("TargetEncoder - smoothing") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + .setSmoothing(1) + + val model = encoder.fit(df) + + val expected_encodings = Map( + "input1" -> Map(Some(0.0) -> 42.5, Some(1.0) -> 50.0, Some(2.0) -> 57.5, Some(-1.0) -> 50.0), + "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), + "input3" -> HashMap(Some(5.0) -> 27.5, Some(6.0) -> 50.0, Some(7.0) -> 60.0, + Some(8.0) -> 65.0, Some(9.0) -> 70.0, Some(-1.0) -> 50.0)) + + assert(model.encodings.equals(expected_encodings)) + + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df.select("input1", "input2", "input3", + "smoothingExpected1", "smoothingExpected2", "smoothingExpected3"), + model, + "output1", "smoothingExpected1", + "output2", "smoothingExpected2", + "output3", "smoothingExpected3") { + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } + + } + + test("TargetEncoder - unseen value - keep") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setHandleInvalid(TargetEncoder.KEEP_INVALID) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df) + + val data_unseen = Row(0.toShort, 3, 10.0, + 0.toByte, 0.0, 0.0, 0.0, 0.0, 40.0, 50.0, 50.0, 0.0, 0.0, 0.0) + + val df_unseen = spark + .createDataFrame(sc.parallelize(data :+ data_unseen), schema) + + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df_unseen.select("input1", "input2", "input3", + "continuousExpected1", "continuousExpected2", "continuousExpected3"), + model, + "output1", 
"continuousExpected1", + "output2", "continuousExpected2", + "output3", "continuousExpected3") { + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } + } + + test("TargetEncoder - unseen value - error") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setHandleInvalid(TargetEncoder.ERROR_INVALID) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df) + + val data_unseen = Row(0.toShort, 3, 10.0, + 0.toByte, 0.0, 0.0, 0.0, 0.0, 4.0/9, 4.0/9, 4.0/9, 0.0, 0.0, 0.0) + + val df_unseen = spark + .createDataFrame(sc.parallelize(data :+ data_unseen), schema) + + val ex = intercept[SparkRuntimeException] { + val out = model.transform(df_unseen) + out.show(false) + } + + assert(ex.isInstanceOf[SparkRuntimeException]) + assert(ex.getMessage.contains("Unseen value 10.0 in feature input3")) + + } + + test("TargetEncoder - missing feature") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("binaryLabel") + .setInputCols(Array("input1", "input2", "input3")) + .setTargetType(TargetEncoder.TARGET_BINARY) + .setOutputCols(Array("output1", "output2", "output3")) + + val ex = intercept[SparkException] { + val model = encoder.fit(df.drop("input3")) + print(model.encodings) + } + + assert(ex.isInstanceOf[SparkException]) + assert(ex.getMessage.contains("No column named input3 found on dataset")) + } + + test("TargetEncoder - wrong data type") { + + val wrong_schema = new StructType( + schema.map{ + field: StructField => if (field.name != "input3") field + else new StructField(field.name, StringType, field.nullable, field.metadata) + }.toArray) + + val df = spark + .createDataFrame(sc.parallelize(data), wrong_schema) + .drop("continuousLabel") + + val encoder = new TargetEncoder() + .setLabelCol("binaryLabel") + .setInputCols(Array("input1", "input2", "input3")) + .setTargetType(TargetEncoder.TARGET_BINARY) + .setOutputCols(Array("output1", "output2", "output3")) + + val ex = intercept[SparkException] { + val model = encoder.fit(df) + print(model.encodings) + } + + assert(ex.isInstanceOf[SparkException]) + assert(ex.getMessage.contains("Data type for column input3 is StringType")) + } + + test("TargetEncoder - seen null category") { + + val data_null = Row(2.toShort, 3, null, + 0.toByte, 1.0/3, 0.0, 0.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0) + + val df_null = spark + .createDataFrame(sc.parallelize(data.dropRight(1) :+ data_null), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df_null) + + val expected_encodings = Map( + "input1" -> Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), + "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), + "input3" -> HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, + Some(8.0) -> 80.0, None -> 90.0, Some(-1.0) -> 50.0)) + + assert(model.encodings.equals(expected_encodings)) + + val output = model.transform(df_null) + + 
output.select("output1", "continuousExpected1",
+        "output2", "continuousExpected2",
+        "output3", "continuousExpected3")
+      .collect().foreach {
+        case Row(output1: Double, expected1: Double,
+          output2: Double, expected2: Double,
+          output3: Double, expected3: Double) =>
+          assert(output1 === expected1)
+          assert(output2 === expected2)
+          assert(output3 === expected3)
+      }
+
+  }
+
+  test("TargetEncoder - unseen null category") {
+
+    val df = spark
+      .createDataFrame(sc.parallelize(data), schema)
+
+    val encoder = new TargetEncoder()
+      .setLabelCol("continuousLabel")
+      .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+      .setHandleInvalid(TargetEncoder.KEEP_INVALID)
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("output1", "output2", "output3"))
+
+    val data_null = Row(null, null, null,
+      0.toByte, 1.0/3, 0.0, 0.0, 90.0, 50.0, 50.0, 50.0, 57.5, 50.0, 70.0)
+
+    val df_null = spark
+      .createDataFrame(sc.parallelize(data :+ data_null), schema)
+
+    val model = encoder.fit(df)
+
+    val output = model.transform(df_null)
+
+    output.select("output1", "continuousExpected1",
+        "output2", "continuousExpected2",
+        "output3", "continuousExpected3")
+      .collect().foreach {
+        case Row(output1: Double, expected1: Double,
+          output2: Double, expected2: Double,
+          output3: Double, expected3: Double) =>
+          assert(output1 === expected1)
+          assert(output2 === expected2)
+          assert(output3 === expected3)
+      }
+
+  }
+
+  test("TargetEncoder - non-indexed categories") {
+
+    val encoder = new TargetEncoder()
+      .setLabelCol("binaryLabel")
+      .setTargetType(TargetEncoder.TARGET_BINARY)
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("output1", "output2", "output3"))
+
+    val data_noindex = Row(
+      0.toShort, 3, 5.1, 0.toByte, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
+
+    val df_noindex = spark
+      .createDataFrame(sc.parallelize(data :+ data_noindex), schema)
+
+    val ex = intercept[SparkException] {
+      val model = encoder.fit(df_noindex)
+      print(model.encodings)
+    }
+
+    assert(ex.isInstanceOf[SparkException])
+    assert(ex.getMessage.contains(
+      "Values from column input3 must be indices, but got 5.1"))
+
+  }
+
+  test("TargetEncoder - null label") {
+
+    val data_nolabel = Row(2.toShort, 3, 5.0,
+      null, 1.0/3, 0.0, 0.0, null, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)
+
+    val df_nolabel = spark
+      .createDataFrame(sc.parallelize(data :+ data_nolabel), schema)
+
+    val encoder = new TargetEncoder()
+      .setLabelCol("continuousLabel")
+      .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("output1", "output2", "output3"))
+
+    val model = encoder.fit(df_nolabel)
+
+    val expected_encodings = Map(
+      "input1" -> Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0),
+      "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0),
+      "input3" -> HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0,
+        Some(8.0) -> 80.0, Some(9.0) -> 90.0, Some(-1.0) -> 50.0))
+
+    assert(model.encodings.equals(expected_encodings))
+
+  }
+
+  test("TargetEncoder - non-binary labels") {
+
+    val encoder = new TargetEncoder()
+      .setLabelCol("binaryLabel")
+      .setTargetType(TargetEncoder.TARGET_BINARY)
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("output1", "output2", "output3"))
+
+    val data_non_binary = Row(
+      0.toShort, 3, 5.0, 2.toByte, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
+
+    val df_non_binary = spark
+      .createDataFrame(sc.parallelize(data :+ data_non_binary), schema)
+
+    val ex = intercept[SparkException] {
+      val model = encoder.fit(df_non_binary)
+      print(model.encodings)
+    }
+
+    assert(ex.isInstanceOf[SparkException])
+    assert(ex.getMessage.contains(
+      "Values from column binaryLabel must be binary (0,1) but got 2.0"))
+
+  }
+
+  test("TargetEncoder - R/W single-column") {
+
+    val encoder = new
TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setInputCol("input1") + .setOutputCol("output1") + .setHandleInvalid(TargetEncoder.ERROR_INVALID) + .setSmoothing(2) + + testDefaultReadWrite(encoder) + + } + + test("TargetEncoder - R/W multi-column") { + + val encoder = new TargetEncoder() + .setLabelCol("binaryLabel") + .setTargetType(TargetEncoder.TARGET_BINARY) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + .setHandleInvalid(TargetEncoder.KEEP_INVALID) + .setSmoothing(1) + + testDefaultReadWrite(encoder) + + } + +} \ No newline at end of file diff --git a/python/docs/source/reference/pyspark.ml.rst b/python/docs/source/reference/pyspark.ml.rst index 965cbe7eb5a57..f81498d3b5eae 100644 --- a/python/docs/source/reference/pyspark.ml.rst +++ b/python/docs/source/reference/pyspark.ml.rst @@ -104,6 +104,8 @@ Feature StopWordsRemover StringIndexer StringIndexerModel + TargetEncoder + TargetEncoderModel Tokenizer UnivariateFeatureSelector UnivariateFeatureSelectorModel diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 9a392c9dd420f..98fc6dc690880 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -104,6 +104,8 @@ "StopWordsRemover", "StringIndexer", "StringIndexerModel", + "TargetEncoder", + "TargetEncoderModel", "Tokenizer", "UnivariateFeatureSelector", "UnivariateFeatureSelectorModel", @@ -5200,6 +5202,305 @@ def loadDefaultStopWords(language: str) -> List[str]: return list(stopWordsObj.loadDefaultStopWords(language)) +class _TargetEncoderParams( + HasLabelCol, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasHandleInvalid +): + """ + Params for :py:class:`TargetEncoder` and :py:class:`TargetEncoderModel`. + + .. versionadded:: 4.0.0 + """ + + handleInvalid: Param[str] = Param( + Params._dummy(), + "handleInvalid", + "How to handle invalid data during transform(). " + + "Options are 'keep' (invalid data presented as an extra " + + "categorical feature) or error (throw an error).", + typeConverter=TypeConverters.toString, + ) + + targetType: Param[str] = Param( + Params._dummy(), + "targetType", + "whether the label is 'binary' or 'continuous'", + typeConverter=TypeConverters.toString, + ) + + smoothing: Param[float] = Param( + Params._dummy(), + "smoothing", + "value to smooth in-category averages with overall averages.", + typeConverter=TypeConverters.toFloat, + ) + + def __init__(self, *args: Any): + super(_TargetEncoderParams, self).__init__(*args) + self._setDefault(handleInvalid="error", targetType="binary", smoothing=0.0) + + @since("4.0.0") + def getTargetType(self) -> str: + """ + Gets the value of targetType or its default value. + """ + return self.getOrDefault(self.targetType) + + @since("4.0.0") + def getSmoothing(self) -> float: + """ + Gets the value of smoothing or its default value. + """ + return self.getOrDefault(self.smoothing) + + +@inherit_doc +class TargetEncoder( + JavaEstimator["TargetEncoderModel"], + _TargetEncoderParams, + JavaMLReadable["TargetEncoder"], + JavaMLWritable, +): + """ + Target Encoding maps a column of categorical indices into a numerical feature derived + from the target. + + When :py:attr:`handleInvalid` is configured to 'keep', previously unseen values of + a feature are mapped to the dataset overall statistics. 
+ + When :py:attr:'targetType' is configured to 'binary', categories are encoded as the + conditional probability of the target given that category (bin counting). + When :py:attr:'targetType' is configured to 'continuous', categories are encoded as + the average of the target given that category (mean encoding) + + Parameter :py:attr:'smoothing' controls how in-category stats and overall stats are + weighted to build the encodings + + @note When encoding multi-column by using `inputCols` and `outputCols` params, + input/output cols come in pairs, specified by the order in the arrays, and each pair + is treated independently. + + .. versionadded:: 4.0.0 + """ + + _input_kwargs: Dict[str, Any] + + @overload + def __init__( + self, + *, + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ..., + labelCol: str = ..., + handleInvalid: str = ..., + targetType: str = ..., + smoothing: float = ..., + ): + ... + + @overload + def __init__( + self, + *, + labelCol: str = ..., + handleInvalid: str = ..., + targetType: str = ..., + smoothing: float = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + ): + ... + + @keyword_only + def __init__( + self, + *, + inputCols: Optional[List[str]] = None, + outputCols: Optional[List[str]] = None, + labelCol: str = "label", + handleInvalid: str = "error", + targetType: str = "binary", + smoothing: float = 0.0, + inputCol: Optional[str] = None, + outputCol: Optional[str] = None, + ): + """ + __init__(self, \\*, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \ + targetType="binary", smoothing=0.0, inputCol=None, outputCol=None) + """ + super(TargetEncoder, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.TargetEncoder", self.uid) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @overload + def setParams( + self, + *, + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ..., + labelCol: str = ..., + handleInvalid: str = ..., + targetType: str = ..., + smoothing: float = ..., + ) -> "TargetEncoder": + ... + + @overload + def setParams( + self, + *, + labelCol: str = ..., + handleInvalid: str = ..., + targetType: str = ..., + smoothing: float = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + ) -> "TargetEncoder": + ... + + @keyword_only + @since("4.0.0") + def setParams( + self, + *, + inputCols: Optional[List[str]] = None, + outputCols: Optional[List[str]] = None, + labelCol: str = "label", + handleInvalid: str = "error", + targetType: str = "binary", + smoothing: float = 0.0, + inputCol: Optional[str] = None, + outputCol: Optional[str] = None, + ) -> "TargetEncoder": + """ + setParams(self, \\*, inputCols=None, outputCols=None, handleInvalid="error", \ + dropLast=True, inputCol=None, outputCol=None) + Sets params for this TargetEncoder. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("4.0.0") + def setInputCols(self, value: List[str]) -> "TargetEncoder": + """ + Sets the value of :py:attr:`inputCols`. + """ + return self._set(inputCols=value) + + @since("4.0.0") + def setOutputCols(self, value: List[str]) -> "TargetEncoder": + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + @since("4.0.0") + def setInputCol(self, value: str) -> "TargetEncoder": + """ + Sets the value of :py:attr:`inputCol`. 
+        """
+        return self._set(inputCol=value)
+
+    @since("4.0.0")
+    def setOutputCol(self, value: str) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`outputCol`.
+        """
+        return self._set(outputCol=value)
+
+    @since("4.0.0")
+    def setHandleInvalid(self, value: str) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        return self._set(handleInvalid=value)
+
+    @since("4.0.0")
+    def setTargetType(self, value: str) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`targetType`.
+        """
+        return self._set(targetType=value)
+
+    @since("4.0.0")
+    def setSmoothing(self, value: float) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`smoothing`.
+        """
+        return self._set(smoothing=value)
+
+    def _create_model(self, java_model: "JavaObject") -> "TargetEncoderModel":
+        return TargetEncoderModel(java_model)
+
+
+class TargetEncoderModel(
+    JavaModel, _TargetEncoderParams, JavaMLReadable["TargetEncoderModel"], JavaMLWritable
+):
+    """
+    Model fitted by :py:class:`TargetEncoder`.
+
+    .. versionadded:: 4.0.0
+    """
+
+    @since("4.0.0")
+    def setInputCols(self, value: List[str]) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`inputCols`.
+        """
+        return self._set(inputCols=value)
+
+    @since("4.0.0")
+    def setOutputCols(self, value: List[str]) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`outputCols`.
+        """
+        return self._set(outputCols=value)
+
+    @since("4.0.0")
+    def setInputCol(self, value: str) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`inputCol`.
+        """
+        return self._set(inputCol=value)
+
+    @since("4.0.0")
+    def setOutputCol(self, value: str) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`outputCol`.
+        """
+        return self._set(outputCol=value)
+
+    @since("4.0.0")
+    def setHandleInvalid(self, value: str) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        return self._set(handleInvalid=value)
+
+    @since("4.0.0")
+    def setTargetType(self, value: str) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`targetType`.
+        """
+        return self._set(targetType=value)
+
+    @since("4.0.0")
+    def setSmoothing(self, value: float) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`smoothing`.
+        """
+        return self._set(smoothing=value)
+
+    @property
+    @since("4.0.0")
+    def encodings(self) -> Dict[str, Dict[float, float]]:
+        """
+        Fitted mappings from each input column to its encodings: a dictionary keyed
+        by input column name, whose values map each category index of that column
+        to its derived encoding.
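+
+        For example, a model fit on a single input column ``input1`` with a
+        continuous target could expose encodings shaped like
+        ``{'input1': {0.0: 40.0, 1.0: 50.0, 2.0: 60.0}}`` (values illustrative).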
+        """
+        return self._call_java("encodings")
+
+
 @inherit_doc
 class Tokenizer(
     JavaTransformer,
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index 4bf6641723da6..666ed1c4269e1 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -29,6 +29,7 @@
     StopWordsRemover,
     StringIndexer,
     StringIndexerModel,
+    TargetEncoder,
     VectorSizeHint,
 )
 from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
@@ -346,6 +347,80 @@ def test_string_indexer_from_labels(self):
         )
         self.assertEqual(len(transformed_list), 5)
 
+    def test_target_encoder_binary(self):
+        df = self.spark.createDataFrame(
+            [
+                (0, 3, 5.0, 0.0),
+                (1, 4, 5.0, 1.0),
+                (2, 3, 5.0, 0.0),
+                (0, 4, 6.0, 1.0),
+                (1, 3, 6.0, 0.0),
+                (2, 4, 6.0, 1.0),
+                (0, 3, 7.0, 0.0),
+                (1, 4, 8.0, 1.0),
+                (2, 3, 9.0, 0.0),
+            ],
+            schema="input1 short, input2 int, input3 double, label double",
+        )
+        encoder = TargetEncoder(
+            inputCols=["input1", "input2", "input3"],
+            outputCols=["output1", "output2", "output3"],
+            labelCol="label",
+            targetType="binary",
+        )
+        model = encoder.fit(df)
+        te = model.transform(df)
+        actual = te.drop("label").collect()
+        expected = [
+            Row(input1=0, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3),
+            Row(input1=1, input2=4, input3=5.0, output1=2.0 / 3, output2=1.0, output3=1.0 / 3),
+            Row(input1=2, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3),
+            Row(input1=0, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3),
+            Row(input1=1, input2=3, input3=6.0, output1=2.0 / 3, output2=0.0, output3=2.0 / 3),
+            Row(input1=2, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3),
+            Row(input1=0, input2=3, input3=7.0, output1=1.0 / 3, output2=0.0, output3=0.0),
+            Row(input1=1, input2=4, input3=8.0, output1=2.0 / 3, output2=1.0, output3=1.0),
+            Row(input1=2, input2=3, input3=9.0, output1=1.0 / 3, output2=0.0, output3=0.0),
+        ]
+        self.assertEqual(actual, expected)
+
+    def test_target_encoder_continuous(self):
+        df = self.spark.createDataFrame(
+            [
+                (0, 3, 5.0, 10.0),
+                (1, 4, 5.0, 20.0),
+                (2, 3, 5.0, 30.0),
+                (0, 4, 6.0, 40.0),
+                (1, 3, 6.0, 50.0),
+                (2, 4, 6.0, 60.0),
+                (0, 3, 7.0, 70.0),
+                (1, 4, 8.0, 80.0),
+                (2, 3, 9.0, 90.0),
+            ],
+            schema="input1 short, input2 int, input3 double, label double",
+        )
+        encoder = TargetEncoder(
+            inputCols=["input1", "input2", "input3"],
+            outputCols=["output1", "output2", "output3"],
+            labelCol="label",
+            targetType="continuous",
+        )
+        model = encoder.fit(df)
+        te = model.transform(df)
+        actual = te.drop("label").collect()
+        expected = [
+            Row(input1=0, input2=3, input3=5.0, output1=40.0, output2=50.0, output3=20.0),
+            Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=20.0),
+            Row(input1=2, input2=3, input3=5.0, output1=60.0, output2=50.0, output3=20.0),
+            Row(input1=0, input2=4, input3=6.0, output1=40.0, output2=50.0, output3=50.0),
+            Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0),
+            Row(input1=2, input2=4, input3=6.0, output1=60.0, output2=50.0, output3=50.0),
+            Row(input1=0, input2=3, input3=7.0, output1=40.0, output2=50.0, output3=70.0),
+            Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=80.0),
+            Row(input1=2, input2=3, input3=9.0, output1=60.0, output2=50.0, output3=90.0),
+        ]
+        self.assertEqual(actual, expected)
+
     def test_vector_size_hint(self):
         df = self.spark.createDataFrame(
             [