
[SPARK-37178][ML] Add Target Encoding to ml.feature #48347

Open
wants to merge 8 commits into master

Conversation


@rebo16v rebo16v commented Oct 4, 2024

What changes were proposed in this pull request?

Adds support for target encoding of ml features.
Target Encoding maps a column of categorical indices into a numerical feature derived from the target.
Leveraging the relationship between categorical variables and the target variable, target encoding usually performs better than one-hot encoding, while avoiding the need to add extra columns.

Why are the changes needed?

Target Encoding is a well-known encoding technique for categorical features.
It's supported in most ML frameworks:
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.TargetEncoder.html
https://search.r-project.org/CRAN/refmans/dataPreparation/html/target_encode.html

Does this PR introduce any user-facing change?

The Spark API now includes two new classes in package org.apache.spark.ml.feature:

  • TargetEncoder (estimator)
  • TargetEncoderModel (transformer)

How was this patch tested?

Scala => org.apache.spark.ml.feature.TargetEncoderSuite
Java => org.apache.spark.ml.feature.JavaTargetEncoderSuite
Python => python.pyspark.ml.tests.test_feature.FeatureTests (added 2 tests)

Was this patch authored or co-authored using generative AI tooling?

No

Some design notes:

  • binary and continuous target types (no multi-label yet)

  • available in Scala, Java and Python APIs

  • fitting implemented on RDD API (treeAggregate)

  • transformation implemented on Dataframe API (no UDFs)

  • categorical features must be indices (integers) in Double-typed columns (as if StringIndexer were used before)

  • unseen categories in training are represented as class -1.0

  • Encodings structure

    • Map[String, Map[Double, Double]] => Map[ feature_name, Map[ original_category, encoded_category ] ]
  • Parameters

    • inputCol(s) / outputCol(s) / labelCol => as usual
    • targetType
      • binary => encodings calculated as in-category conditional probability (counting)
      • continuous => encodings calculated as in-category target mean (incrementally)
    • handleInvalid
      • error => raises an error if trying to encode an unseen category
      • keep => encodes an unseen category with the overall statistics
    • smoothing => controls how in-category stats and overall stats are weighted to calculate final encodings, to avoid overfitting (see the usage sketch below)
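
Given those parameters, a minimal usage sketch (setter names follow the usual ml.feature conventions and are assumptions here; trainingDf/testDf and the column names are illustrative):

```scala
import org.apache.spark.ml.feature.TargetEncoder

// Sketch only: assumes standard ml.feature-style setters for the parameters
// described above. "categoryIndex" holds StringIndexer-style category indices.
val encoder = new TargetEncoder()
  .setInputCols(Array("categoryIndex"))
  .setOutputCols(Array("categoryEncoded"))
  .setLabelCol("label")
  .setTargetType("binary")   // or "continuous" for mean-encoding
  .setHandleInvalid("keep")  // unseen categories get the overall statistics
  .setSmoothing(1.0)

val model = encoder.fit(trainingDf)    // TargetEncoderModel holding the fitted encodings
val encoded = model.transform(testDf)  // appends the encoded numeric column
```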

@srowen srowen (Member) left a comment:
Good start, just need to clarify the implementation and consider supporting a few more cases

@@ -855,6 +855,46 @@ for more details on the API.

</div>

## TargetEncoder

Target Encoding maps a column of categorical indices into a numerical feature derived from the target. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first.
Member:

Let's drop at least a link to information on what target encoding is here.
Also, the explanation you give in the PR about what this actually does to which types of input is valuable and should probably be here too, either here or below in discussion of what the parameters do in some detail.

Author:

I think it's OK now. What do you think?

feature => {
  try {
    val field = schema(feature)
    if (field.dataType != DoubleType) {
Member:

Do the features have to be floats? I'd imagine they aren't if they're categorical representations you're encoding. I think it's OK to demand they're not strings and are already passed through StringIndexer in that case, but it feels like any numeric type works here

Author:

I mimicked this behavior from other encoders (e.g. OneHotEncoder).
What would be your approach? Accepting integers? Checking for a nominal attribute in metadata?

Member:

I think that's OK if it's what other encoders do. But I see checks like https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala#L93 - maybe follow that?

Author:

You're right.
Now accepting any subclass of NumericType for features & label.
(Maybe it doesn't make much sense for the continuous case, but it can be done anyway.)
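
A sketch of such a check (illustrative helper, not the exact PR code):

```scala
import org.apache.spark.sql.types.{NumericType, StructType}

// Accept any numeric column (Byte, Short, Int, Long, Float, Double, Decimal)
// for features and label, instead of requiring DoubleType specifically.
def validateNumeric(schema: StructType, colName: String): Unit = {
  val dataType = schema(colName).dataType
  require(dataType.isInstanceOf[NumericType],
    s"Column $colName must be of numeric type, but got $dataType.")
}
```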

validateSchema(dataset.schema, fitting = true)

val stats = dataset
  .select(ArraySeq.unsafeWrapArray(
Member:

Is the ArraySeq business necessary? You're just selecting columns with the `: _*` syntax, so any Seq would do.

Author:

It doesn't work in Scala 2.13:
"Passing an explicit array value to a Scala varargs method is deprecated (since 2.13.0) and will result in a defensive copy; Use the more efficient non-copying ArraySeq.unsafeWrapArray or an explicit toIndexedSeq call"

Member:

OK, I'd say .toIndexedSeq is simpler

Author:

You're right. Done.
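
For reference, the two warning-free spellings under Scala 2.13 (selectAll and the column handling are illustrative, not the PR's code):

```scala
import scala.collection.immutable.ArraySeq
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.col

def selectAll(df: DataFrame, names: Array[String]): DataFrame = {
  val cols: Array[Column] = names.map(col)
  // Passing the Array straight to varargs triggers the 2.13 deprecation quoted
  // above; ArraySeq.unsafeWrapArray(cols) avoids the defensive copy, while
  // .toIndexedSeq was adopted here as the simpler spelling.
  df.select(cols.toIndexedSeq: _*)
}
```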

globalCounter._2 + ((label - globalCounter._2) / (1 + globalCounter._1))))
}
} catch {
case e: SparkRuntimeException =>
Member:

Indent

Author:

Got resolved in the overall refactor.

case e: SparkRuntimeException =>
  if (e.getErrorClass == "ROW_VALUE_IS_NULL") {
    throw new SparkException(s"Null value found in feature ${inputFeatures(feature)}." +
      s" See Imputer estimator for completing missing values.")
Member:

It seems like you can still target-encode null; it's just another possible value, no?

Author:

Yes, it will be encoded as an unseen category (global statistics).
We could raise an error (as we do while fitting).

Member:

But this throws an exception? How is it handled as unseen but also raises an exception? It shouldn't, right?

Author:

Actually, it raises an exception while fitting and encodes as an unseen category while transforming.
I'll check scikit-learn behavior.

Member:

Right, I mean while fitting. I don't feel like this is necessary

Author:

Following the scikit-learn approach, now treating null as another category
(category becomes Option[Double]):
encodings: Map[String, Map[Option[Double], Double]]
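
For example, a fitted model's encodings might then look like this (numbers are made up):

```scala
// One map per input feature: category index -> encoded value.
// None is the null "category"; unseen categories are handled at transform
// time via the overall statistics.
val encodings: Map[String, Map[Option[Double], Double]] = Map(
  "categoryIndex" -> Map(
    Some(0.0) -> 0.25,  // encoding for category 0
    Some(1.0) -> 0.60,  // encoding for category 1
    None      -> 0.40)) // encoding for nulls, treated as their own category
```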

val value = row.getDouble(feature)
if (value < 0.0 || value != value.toInt) throw new SparkException(
  s"Values from column ${inputFeatures(feature)} must be indices, but got $value.")
val counter = agg(feature).getOrElse(value, (0.0, 0.0))
Member:

Use val (foo, bar) = syntax so you don't have to use more cryptic ._1 references later

Author:

done
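
i.e. something like (names illustrative):

```scala
// Destructure the (count, statistic) pair once instead of using ._1/._2 below.
val (categoryCount, categoryStat) = agg(feature).getOrElse(value, (0.0, 0.0))
```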

val globalCounter = agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0))
$(targetType) match {
  case TargetEncoder.TARGET_BINARY =>
    if (label == 1.0) agg(feature) +
Member:

These if-else clauses need to be indented and with braces around them for clarity

Author:

done

})(
  (agg, row: Row) => {
    val label = row.getDouble(inputFeatures.length)
    Range(0, inputFeatures.length).map {
Member:

(0 until inputFeatures.length) feels a little more idiomatic, or even for ... yield

Author:

Finally changed to inputFeatures.indices.

val values = agg1(feature).keySet ++ agg2(feature).keySet
values.map(value =>
  value -> {
    val stat1 = agg1(feature).getOrElse(value, (0.0, 0.0))
Member:

Same, let's give names to the elements of this tuple.
A comment or two in these blocks about what this sum is doing would help too.

Author:

done

rest
  .foldLeft(when(col === first._1, first._2))(
    (new_col: Column, encoding) =>
      if (encoding._1 != TargetEncoder.UNSEEN_CATEGORY) {
Member:

And same again around here - some comments and more descriptive var names are important, as I have trouble evaluating the logic

Author:

done

  Map.empty[Option[Double], (Double, Double)]
})(
  (agg, row: Row) => {
    val label = label_type match {
Author:

Didn't work yet on handling null labels.
I checked scikit-learn and it fails at this (encoding everything to NaN).
We could:

  1. raise an exception
  2. not consider the observation and keep going

What do you think?

Member:

I see. I guess I think it's most sensible to ignore nulls then

Author:

done

`TargetEncoder` supports the `targetType` parameter to choose the label type when fitting data, affecting how statistics are calculated.
Available options include 'binary' and 'continuous'.
When set to 'binary', encodings will be fitted from target conditional probabilities (a.k.a. bin-counting).
When set to 'continuous', encodings will be fitted according to the target mean (a.k.a. mean-encoding).
Member:

I think you can describe this a little bit more somewhere, could be here or at the top: what does target encoding actually do? With a simplistic example of a few rows?

Just want to make it immediately clear in one paragraph what this is doing for binary vs continuous targets.

Author:

done
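
For illustration (smoothing aside): given five rows with (category, label) = (0, 1), (0, 0), (1, 1), (1, 1), (1, 0), a binary target encodes category 0 as P(label = 1 | category = 0) = 1/2 = 0.5 and category 1 as 2/3 ≈ 0.67, while a continuous target encodes each category as the mean of its labels, which coincides with those values when the labels are 0/1.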

case (cat, (class_count, class_stat)) => cat -> {
  val weight = class_count / (class_count + $(smoothing))
  $(targetType) match {
    case TargetEncoder.TARGET_BINARY =>
Member:

This all might be worth a few lines of comments explaining the math here

Author:

done
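
For reference, the blend being computed is roughly (a sketch of the math, not the exact code; categoryStat/globalStat are illustrative names):

```scala
// Shrink each category's statistic toward the global statistic. Categories
// with few observations (class_count small relative to `smoothing`) lean more
// on the global value, which protects rare categories against overfitting.
val weight = class_count / (class_count + smoothing) // in [0, 1)
val encoding = weight * categoryStat + (1 - weight) * globalStat
```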

Array.fill(inputFeatures.length) {
  Map.empty[Option[Double], (Double, Double)]
})(
  (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) {
Member:

Ah, if the label is null, hm, that could be an error. I don't think it makes sense to have input to any supervised problem where the label is missing. Here you'd be welcome to copy whatever other similar classes do in this situation

Author (@rebo16v, Oct 9, 2024):

I think we didn't understand each other.
So, you want to raise an exception in case a label is null?

inputFeatures.indices.map {
  feature => {
    val category: Option[Double] = {
      if (row.isNullAt(feature)) None // null category
Member:

Here it doesn't seem like it's ignoring inputs where feature i is null, but reading it as a category None. Maybe I misread the rest? But it seems like it proceeds anyway.

I don't think that's crazy, but per your earlier comments, I thought the intent might be to ignore these.

Author (@rebo16v, Oct 9, 2024):

Nulls in features are treated as a category (with key None in the map):
encodings: Map[feature, Map[Some(category), encoding]]
It's null labels that I'm dropping.

Member:

I'm fine with that too.
