[SPARK-37178][ML] Add Target Encoding to ml.feature #48347
base: master
Conversation
Good start, just need to clarify the implementation and consider supporting a few more cases
docs/ml-features.md
Outdated
@@ -855,6 +855,46 @@ for more details on the API.

</div>

## TargetEncoder

Target Encoding maps a column of categorical indices into a numerical feature derived from the target. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first.
Let's drop at least a link to information on what target encoding is here.
Also, the explanation you give in the PR about what this actually does to which types of input is valuable and should probably be here too, either here or below in discussion of what the parameters do in some detail.
i think it's ok now. what do you think?
feature => {
  try {
    val field = schema(feature)
    if (field.dataType != DoubleType) {
Do the features have to be floats? I'd imagine they aren't if they're categorical representations you're encoding. I think it's OK to demand they're not strings and are already passed through StringIndexer in that case, but it feels like any numeric type works here
i mimic this behavior from other encoders (i.e. OneHotEncoder)
what would be your approach? accepting Integers? checking for nominal attribute in metadata?
I think that's OK if it's what other encoders do. But I see checks like https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala#L93 - maybe follow that?
you're right
now accepting any subclass of NumericType for features & label
(maybe it doesn't make much sense in the continuous case, but it can be done anyway)
validateSchema(dataset.schema, fitting = true)

val stats = dataset
  .select(ArraySeq.unsafeWrapArray(
Is the ArraySeq business necessary? You're just selecting columns with `: _*` syntax, so any Seq would do.
it doesn't work in Scala 2.13
Passing an explicit array value to a Scala varargs method is deprecated (since 2.13.0) and will result in a defensive copy; Use the more efficient non-copying ArraySeq.unsafeWrapArray or an explicit toIndexedSeq call
OK, I'd say .toIndexedSeq is simpler
you're right. done
    globalCounter._2 + ((label - globalCounter._2) / (1 + globalCounter._1))))
  }
} catch {
  case e: SparkRuntimeException =>
Indent
got resolved in the overall refactor
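The expression `globalCounter._2 + ((label - globalCounter._2) / (1 + globalCounter._1))` in the snippet above is the standard one-pass running-mean update (count in `._1`, mean in `._2`). A minimal plain-Python sketch of the same idea, separate from the PR's Scala code:

```python
# Running-mean update: new_mean = mean + (x - mean) / (n + 1).
# Equivalent to recomputing sum(values) / len(values), but done
# incrementally in a single pass, as the aggregator above does.
def running_mean(values):
    mean, n = 0.0, 0
    for x in values:
        mean = mean + (x - mean) / (n + 1)
        n += 1
    return mean

print(running_mean([1.0, 2.0, 3.0, 4.0]))  # 2.5
```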
case e: SparkRuntimeException =>
  if (e.getErrorClass == "ROW_VALUE_IS_NULL") {
    throw new SparkException(s"Null value found in feature ${inputFeatures(feature)}." +
      s" See Imputer estimator for completing missing values.")
It seems like you can still target-encode null; it's just another possible value, no?
yes
it will be encoded as an unseen category (global statistics)
we could raise an error (as we do while fitting)
But this throws an exception? how is it handled as unseen but also raises an exception? it shouldn't, right?
Actually, it raises an exception while fitting and encodes as unseen category while transforming.
I'll check scikit-learn behavior.
Right, I mean while fitting. I don't feel like this is necessary
following scikit approach, now treating null as another category
(category becomes Option[Double])
encodings: Map[String, Map[Option[Double], Double]]
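The encoding table described above (`Map[String, Map[Option[Double], Double]]`, with null as its own category and unseen values falling back to global statistics) can be sketched in plain Python; the `UNSEEN` name below is a hypothetical stand-in for `TargetEncoder.UNSEEN_CATEGORY`, not the PR's actual constant:

```python
# Per-feature encoding lookup: None keys the null category, and any
# value missing from the map falls back to the global statistic
# stored under the UNSEEN sentinel.
UNSEEN = "unseen"  # hypothetical stand-in for TargetEncoder.UNSEEN_CATEGORY

def encode_value(encodings, value):
    return encodings.get(value, encodings[UNSEEN])

enc = {1.0: 0.8, 2.0: 0.3, None: 0.5, UNSEEN: 0.55}
print(encode_value(enc, 1.0))   # 0.8  (seen category)
print(encode_value(enc, None))  # 0.5  (null is just another category)
print(encode_value(enc, 9.0))   # 0.55 (unseen -> global statistic)
```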
val value = row.getDouble(feature)
if (value < 0.0 || value != value.toInt) throw new SparkException(
  s"Values from column ${inputFeatures(feature)} must be indices, but got $value.")
val counter = agg(feature).getOrElse(value, (0.0, 0.0))
Use `val (foo, bar) = ...` syntax so you don't have to use more cryptic `._1` references later.
done
val globalCounter = agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0))
$(targetType) match {
  case TargetEncoder.TARGET_BINARY =>
    if (label == 1.0) agg(feature) +
These if-else clauses need to be indented and with braces around them for clarity
done
})(
  (agg, row: Row) => {
    val label = row.getDouble(inputFeatures.length)
    Range(0, inputFeatures.length).map {
`(1 until inputFeatures.length)` feels a little more idiomatic, or even `for ... yield`.
finally changed for inputFeatures.indices
val values = agg1(feature).keySet ++ agg2(feature).keySet
values.map(value =>
  value -> {
    val stat1 = agg1(feature).getOrElse(value, (0.0, 0.0))
Same, let's give names to the elements of this tuple
A comment or two in these blocks about what this sum is doing would help too
done
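The merge in the snippet above combines per-category `(count, statistic)` pairs from two partial aggregations (treeAggregate's combine step). A hedged sketch, assuming each pair is `(count, mean)` so the means must be merged as a weighted average (the PR's exact tuple semantics may differ):

```python
# Combine two partial per-category aggregates. Assumes (count, mean)
# pairs; categories missing from one side default to (0.0, 0.0),
# like the getOrElse fallback in the snippet above.
def merge_stats(agg1, agg2):
    merged = {}
    for cat in set(agg1) | set(agg2):
        n1, m1 = agg1.get(cat, (0.0, 0.0))
        n2, m2 = agg2.get(cat, (0.0, 0.0))
        n = n1 + n2
        merged[cat] = (n, (n1 * m1 + n2 * m2) / n if n > 0 else 0.0)
    return merged

m = merge_stats({1.0: (2, 0.5)}, {1.0: (2, 1.0), 2.0: (1, 0.0)})
print(m[1.0])  # (4, 0.75)
```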
rest
  .foldLeft(when(col === first._1, first._2))(
    (new_col: Column, encoding) =>
      if (encoding._1 != TargetEncoder.UNSEEN_CATEGORY) {
And same again around here - some comments and more descriptive var names are important, as I have trouble evaluating the logic
done
Map.empty[Option[Double], (Double, Double)]
})(
  (agg, row: Row) => {
    val label = label_type match {
didn't work yet on handling null labels
i checked scikit and it fails at this (encoding all to NaN)
we could
- raise an exception
- do not consider the observation and keep going
what do you think?
I see. I guess I think it's most sensible to ignore nulls then
done
docs/ml-features.md
Outdated
`TargetEncoder` supports the `targetType` parameter to choose the label type when fitting data, affecting how statistics are calculated.
Available options include 'binary' and 'continuous' (mean-encoding).
When set to 'binary', encodings will be fitted from target conditional probabilities (a.k.a. bin-counting).
When set to 'continuous', encodings will be fitted according to the target mean (a.k.a. mean-encoding).
I think you can describe this a little bit more somewhere, could be here or at the top - what does target encoding actually do? with a simplistic example of a few rows?
Just want to make it immediately clear in one paragraph what this is doing for binary vs continuous targets.
done
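As a simplistic few-row illustration of the two modes (plain Python, not the Spark API): for a 0/1 label, the per-category mean is exactly the conditional probability P(target = 1 | category), i.e. the 'binary' bin-counting case; for a real-valued label it is plain mean-encoding ('continuous'):

```python
# Fit per-category statistics from a tiny dataset. With 0/1 labels
# the mean is P(target = 1 | category); with continuous labels it is
# the target mean for that category.
def fit_encodings(categories, targets):
    groups = {}
    for c, t in zip(categories, targets):
        groups.setdefault(c, []).append(t)
    return {c: sum(ts) / len(ts) for c, ts in groups.items()}

print(fit_encodings([0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0]))
# {0.0: 0.5, 1.0: 1.0}   -- binary: conditional probabilities
print(fit_encodings([0.0, 0.0, 1.0, 1.0], [10.0, 20.0, 30.0, 50.0]))
# {0.0: 15.0, 1.0: 40.0} -- continuous: per-category means
```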
case (cat, (class_count, class_stat)) => cat -> {
  val weight = class_count / (class_count + $(smoothing))
  $(targetType) match {
    case TargetEncoder.TARGET_BINARY =>
This all might be worth a few lines of comments explaining the math here
done
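The math worth commenting in the snippet above: `weight = class_count / (class_count + smoothing)` shrinks rare categories toward the global prior. A standalone sketch of that blend, assuming the usual smoothing formula `weight * category_stat + (1 - weight) * global_stat` (the PR may arrange the terms differently):

```python
# Smoothing blend: rare categories (small n) are pulled toward the
# global statistic; frequent ones stay close to their own statistic.
def smoothed_encoding(n, cat_stat, global_stat, smoothing):
    weight = n / (n + smoothing)
    return weight * cat_stat + (1.0 - weight) * global_stat

print(smoothed_encoding(1, 1.0, 0.5, 1.0))     # 0.75 (pulled to prior)
print(smoothed_encoding(1000, 1.0, 0.5, 1.0))  # ~0.9995 (near cat_stat)
```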
Array.fill(inputFeatures.length) {
  Map.empty[Option[Double], (Double, Double)]
})(
  (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) {
Ah, if the label is null, hm, that could be an error. I don't think it makes sense to have input to any supervised problem where the label is missing. Here you'd be welcome to copy whatever other similar classes do in this situation
i think we didn't understand each other
so, you want to raise an exception in case a label is null?
inputFeatures.indices.map {
  feature => {
    val category: Option[Double] = {
      if (row.isNullAt(feature)) None // null category
Here it doesn't seem like it's ignoring inputs where feature i is null, but reading it as a category None. Maybe I misread the rest? But it seems like it proceeds anyway.
I don't think that's crazy but per your earlier comments, thought the intent might be to ignore these
nulls in features are treated as a category (with key None in the map)
encodings: Map[feature, Map[Some(category), encoding]]
it's null labels that I'm dropping
I'm fine with that too.
What changes were proposed in this pull request?
Adds support for target encoding of ml features.
Target Encoding maps a column of categorical indices into a numerical feature derived from the target.
Leveraging the relationship between categorical variables and the target variable, target encoding usually performs better than one-hot encoding (while avoiding the need to add extra columns).
Why are the changes needed?
Target Encoding is a well-known encoding technique for categorical features.
It's supported in most ML frameworks:
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.TargetEncoder.html
https://search.r-project.org/CRAN/refmans/dataPreparation/html/target_encode.html
Does this PR introduce any user-facing change?
Spark API now includes 2 new classes in package org.apache.spark.ml
How was this patch tested?
Scala => org.apache.spark.ml.feature.TargetEncoderSuite
Java => org.apache.spark.ml.feature.JavaTargetEncoderSuite
Python => python.pyspark.ml.tests.test_feature.FeatureTests (added 2 tests)
Was this patch authored or co-authored using generative AI tooling?
No
Some design notes ...
binary and continuous target types (no multi-label yet)
available in Scala, Java and Python APIs
fitting implemented on RDD API (treeAggregate)
transformation implemented on Dataframe API (no UDFs)
categorical features must be indices (integers) in Double-typed columns (as if StringIndexer were used before)
unseen categories in training are represented as class -1.0
Encodings structure
Parameters