feat: add allowed_values_transformer #46

Merged
merged 1 commit into from
Jan 17, 2023
1 change: 1 addition & 0 deletions docs/README.md
@@ -53,6 +53,7 @@ poetry install
|[`Deep transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/deep_transformer/)|[`ToVecTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/deep_transformer/#sk_transformers.deep_transformer.ToVecTransformer)|This transformer trains an [FT-Transformer](https://paperswithcode.com/method/ft-transformer) using the [pytorch-widedeep package](https://github.com/jrzaurin/pytorch-widedeep) and extracts the embeddings from its embedding layer.|
|[`Encoder transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/encoder_transformer/)|[`MeanEncoderTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/encoder_transformer/#sk_transformers.encoder_transformer.MeanEncoderTransformer)|Scikit-learn API for the [feature-engine MeanEncoder](https://feature-engine.readthedocs.io/en/latest/api_doc/encoding/MeanEncoder.html).|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`AggregateTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.AggregateTransformer)|This transformer uses Pandas groupby method and aggregate to apply function on a column grouped by another column.|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`AllowedValuesTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.AllowedValuesTransformer)|This transformer replaces values that are *not* in a list with another value.|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`ColumnDropperTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.ColumnDropperTransformer)|Drops columns from a dataframe using Pandas drop method.|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`DtypeTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.DtypeTransformer)|Transformer that converts a column to a different dtype.|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`FunctionsTransformer`]( https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.FunctionsTransformer)|This transformer is a plain wrapper around the [sklearn.preprocessing.FunctionTransformer](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.FunctionTransformer.html).|
31 changes: 29 additions & 2 deletions examples/playground.ipynb
@@ -464,6 +464,33 @@
"transformer.fit_transform(X).to_numpy()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### [`AllowedValuesTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.AllowedValuesTransformer)\n",
"\n",
"Replaces all values that are not in a list of allowed values with a replacement value.\n",
"This performs a complementary transformation to that of the ValueReplacerTransformer.\n",
"This is useful for lumping several minor categories together:\n",
"list the major categories to keep, and every other value is replaced."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sk_transformers import AllowedValuesTransformer\n",
"\n",
"X = pd.DataFrame({\"foo\": [\"a\", \"b\", \"c\", \"d\", \"e\"]})\n",
"transformer = AllowedValuesTransformer([(\"foo\", [\"a\", \"b\"], \"other\")])\n",
"transformer.fit_transform(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -648,7 +675,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "trans",
"display_name": "skt10",
"language": "python",
"name": "python3"
},
@@ -666,7 +693,7 @@
},
"vscode": {
"interpreter": {
"hash": "e0ab0d7b7c2358a4e8dc9a679aa1e03c864d2b2d0f3bb28338b17fac2dad41ae"
"hash": "27d70add5eb29fae1e3167508c5721c993885797430f60563aa3128fba219313"
}
}
},
1 change: 1 addition & 0 deletions src/sk_transformers/__init__.py
@@ -6,6 +6,7 @@
from sk_transformers.encoder_transformer import MeanEncoderTransformer
from sk_transformers.generic_transformer import (
    AggregateTransformer,
    AllowedValuesTransformer,
    ColumnDropperTransformer,
    DtypeTransformer,
    FunctionsTransformer,
57 changes: 57 additions & 0 deletions src/sk_transformers/generic_transformer.py
@@ -613,3 +613,60 @@ def __prefix_df_column_names(
        elif isinstance(df, pd.DataFrame):
            df.columns = [prefix + "_" + column for column in df.columns]
        return df


class AllowedValuesTransformer(BaseTransformer):
    """Replaces all values that are not in a list of allowed values with a
    replacement value. This performs a complementary transformation to that of
    the ValueReplacerTransformer. This is useful for lumping several minor
    categories together: list the major categories to keep, and every other
    value is replaced.

    Example:
    ```python
    import pandas as pd
    from sk_transformers.generic_transformer import AllowedValuesTransformer

    X = pd.DataFrame({"foo": ["a", "b", "c", "d", "e"]})
    transformer = AllowedValuesTransformer([("foo", ["a", "b"], "other")])
    transformer.fit_transform(X)
    ```
    ```
         foo
    0      a
    1      b
    2  other
    3  other
    4  other
    ```

    Args:
        features (List[Tuple[str, List[Any], Any]]): List of tuples where
            the first element is the column name,
            the second element is the list of allowed values in the column, and
            the third element is the value that replaces disallowed values in the column.
    """

    def __init__(self, features: List[Tuple[str, List[Any], Any]]) -> None:
        super().__init__()
        self.features = features

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """Replaces values not in a list with another value.

        Args:
            X (pd.DataFrame): Dataframe containing the columns with values to be replaced.

        Returns:
            pd.DataFrame: Dataframe with replaced values.
        """

        X = check_ready_to_transform(
            self,
            X,
            [feature[0] for feature in self.features],
        )

        for (column, allowed_values, replacement) in self.features:
            # Mask is True where the value is not in the allowed list.
            X.loc[~X[column].isin(allowed_values), column] = replacement

        return X
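
The core of `transform` is a negated `Series.isin` mask combined with a `.loc` assignment. A minimal sketch of that step in plain pandas, outside the transformer class, might look like the following. The helper name `replace_disallowed` is hypothetical and not part of sk-transformers, and the explicit `copy()` is an assumption made here so that the caller's frame is left untouched.

```python
from typing import Any, List, Tuple

import pandas as pd


def replace_disallowed(
    df: pd.DataFrame, features: List[Tuple[str, List[Any], Any]]
) -> pd.DataFrame:
    """Hypothetical helper mirroring the loop in `transform`."""
    result = df.copy()  # assumption: work on a copy of the input
    for column, allowed_values, replacement in features:
        # True where the value is NOT one of the allowed values.
        disallowed = ~result[column].isin(allowed_values)
        result.loc[disallowed, column] = replacement
    return result


X = pd.DataFrame({"foo": ["a", "b", "c", "d", "e"]})
out = replace_disallowed(X, [("foo", ["a", "b"], "other")])
print(out["foo"].tolist())
```

Each tuple is applied independently, so several columns can be handled in one pass with different allowed lists and replacement values.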
18 changes: 18 additions & 0 deletions tests/test_transformer/test_generic_transformer.py
@@ -5,6 +5,7 @@

from sk_transformers import (
    AggregateTransformer,
    AllowedValuesTransformer,
    ColumnDropperTransformer,
    DtypeTransformer,
    FunctionsTransformer,
@@ -294,3 +295,20 @@ def test_left_join_transformer_in_pipeline_with_nan(X_categorical) -> None:
    assert "a_values" in result.columns
    assert np.isclose(result["a_values"].iloc[5], np.nan, equal_nan=True)
    assert pipeline.steps[0][0] == "leftjointransformer"


def test_allowed_values_transformer_in_pipeline(X) -> None:
    values = [
        ("a", [1, 2], -999),
        ("c", ["1", "2"], "other"),
    ]
    pipeline = make_pipeline(AllowedValuesTransformer(values))
    result = pipeline.fit_transform(X)
    expected_a = np.array([1, 2, -999, -999, -999, -999, -999, -999, -999, -999])
    expected_c = np.array(
        ["1", "1", "1", "1", "2", "2", "2", "other", "other", "other"]
    )

    assert np.array_equal(result["a"].to_numpy(), expected_a)
    assert np.array_equal(result["c"].to_numpy(), expected_c)
    assert pipeline.steps[0][0] == "allowedvaluestransformer"
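
The test above covers an integer column and a string column. It does not exercise missing values. Because `Series.isin` returns `False` for `NaN` unless `NaN` is itself in the list, the masking step used in `transform` would also replace missing entries. The sketch below illustrates this in plain pandas. It is not part of this PR's test suite, and the `notna()` variant is only one possible way to keep missing values, shown here as an assumption.

```python
import numpy as np
import pandas as pd

allowed = ["1", "2"]
s = pd.Series(["1", "2", np.nan, "3"], dtype=object)

# Same mask as in `transform`: NaN is not in `allowed`, so it is flagged too.
replaced = s.copy()
replaced.loc[~replaced.isin(allowed)] = "other"

# Variant that leaves missing values alone by excluding them from the mask.
kept = s.copy()
kept.loc[(~kept.isin(allowed)) & kept.notna()] = "other"

print(replaced.tolist())
print(kept.tolist())
```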