Skip to content

Commit

Permalink
feat(steps): implement target encoding for columns
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman committed Apr 9, 2024
1 parent 3e64906 commit 3bef1aa
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
8 changes: 7 additions & 1 deletion ibisml/steps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from ibisml.steps.common import Cast, Drop, Mutate, MutateAt
from ibisml.steps.encode import CategoricalEncode, CountEncode, OneHotEncode
from ibisml.steps.encode import (
CategoricalEncode,
CountEncode,
OneHotEncode,
TargetEncode,
)
from ibisml.steps.feature_selection import ZeroVariance
from ibisml.steps.impute import FillNA, ImputeMean, ImputeMedian, ImputeMode
from ibisml.steps.standardize import ScaleMinMax, ScaleStandard
Expand All @@ -22,5 +27,6 @@
"OneHotEncode",
"ScaleMinMax",
"ScaleStandard",
"TargetEncode",
"ZeroVariance",
)
72 changes: 72 additions & 0 deletions ibisml/steps/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,75 @@ def transform_table(self, table: ir.Table) -> ir.Table:
fillna = FillNA(self.value_counts_, 0)
fillna.fit_table(table, Metadata())
return fillna.transform_table(table)


class TargetEncode(Step):
"""A step for target encoding select columns.
Parameters
----------
inputs
A selection of columns to target encode.
smooth
The amount of mixing of the target mean conditioned on the value of the
category with the global target mean. A larger `smooth` value will put
more weight on the global target mean.
Examples
--------
>>> import ibisml as ml
Target encode all string columns.
>>> step = ml.TargetEncode(ml.string())
"""

def __init__(self, inputs: SelectionType, smooth: float = 0.0) -> None:
self.inputs = selector(inputs)
self.smooth = smooth

def _repr(self) -> Iterable[tuple[str, Any]]:
yield ("", self.inputs)
yield ("smooth", self.smooth)

def fit_table(self, table: ir.Table, metadata: Metadata) -> None:
target_means = (
table.aggregate([table[c].mean().name(c) for c in metadata.targets])
.execute()
.to_dict("records")[0]
)

target_aggs = {}
for target in metadata.targets:
target_aggs[f"{target}_mean"] = table[target].mean()
target_aggs[f"{target}_count"] = table[target].count()

columns = self.inputs.select_columns(table, metadata)
self.encodings_ = {}
for column in columns:
agged = table.group_by(column).aggregate(**target_aggs)

target_encodings = {}
for target in metadata.targets:
target_encodings[f"{target}"] = agged[f"{target}_mean"] * agged[
f"{target}_count"
] + target_means[target] * self.smooth / (
agged[f"{target}_count"] + self.smooth
)

self.encodings_[column] = ibis.memtable(
agged.mutate(**target_encodings).drop(target_aggs).to_pyarrow()
)

def transform_table(self, table: ir.Table) -> ir.Table:
for c, encodings in self.encodings_.items():
joined = table.left_join(
encodings, table[c] == encodings[0], lname="left_{name}", rname=""
)
table = joined.drop(encodings.columns[0], f"left_{c}").rename(
{c: encodings.columns[1]}
if len(encodings.columns) < 3
else {f"{c}{k + 1}": t for k, t in enumerate(encodings.columns[1:])}
)

return table

0 comments on commit 3bef1aa

Please sign in to comment.