From 0ec8b42f6fe7810c982a495073241d66f04b7915 Mon Sep 17 00:00:00 2001 From: Chris Lemke <1@lemke.ai> Date: Fri, 20 Jan 2023 14:09:45 +0100 Subject: [PATCH] refactor: add possibility to name target column --- .pre-commit-config.yaml | 2 -- src/sk_transformers/generic_transformer.py | 17 ++++++++++++----- .../test_generic_transformer.py | 10 ++++++++++ 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 548587a..c727e87 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -86,9 +86,7 @@ repos: rev: 22.12.0 hooks: - id: black - args: ["--config=pyproject.toml"] - id: black-jupyter - args: ["--config=pyproject.toml"] files: \.ipynb$ - repo: https://github.com/PyCQA/isort diff --git a/src/sk_transformers/generic_transformer.py b/src/sk_transformers/generic_transformer.py index fb7397c..4fdcc6e 100644 --- a/src/sk_transformers/generic_transformer.py +++ b/src/sk_transformers/generic_transformer.py @@ -31,8 +31,9 @@ class ColumnEvalTransformer(BaseTransformer): ``` Args: - features (List[Tuple[str, str]]): List of tuples containing the column name and the method (`eval_func`) to apply. - E.g. `("foo", "str.upper()")` will apply the `str.upper()` on the column `foo`. + features (List[Union[Tuple[str, str], Tuple[str, str, str]]]): List of tuples containing the column name and the method (`eval_func`) to apply. + As a third entry in the tuple an alternative name for the column can be provided. If not provided the column name will be used. + This can be useful if the the original column should not be replaced but another column should be created. Raises: ValueError: If the `eval_func` starts with a dot (`.`). @@ -40,7 +41,9 @@ class ColumnEvalTransformer(BaseTransformer): ValueError: If the `eval_func` tries to assign multiple columns to one target column. """ - def __init__(self, features: List[Tuple[str, str]]) -> None: + def __init__( + self, features: List[Union[Tuple[str, str], Tuple[str, str, str]]] + ) -> None: super().__init__() self.features = features @@ -60,7 +63,11 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame: force_all_finite="allow-nan", ) - for (column, eval_func) in self.features: + for eval_tuple in self.features: + + column = eval_tuple[0] + eval_func = eval_tuple[1] + new_column = eval_tuple[2] if len(eval_tuple) == 3 else column # type: ignore if eval_func[0] == ".": raise ValueError( @@ -77,7 +84,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame: ) try: - X[column] = eval( # pylint: disable=eval-used # nosec + X[new_column] = eval( # pylint: disable=eval-used # nosec f"X[{'column'}].{eval_func}" ) except ValueError as e: diff --git a/tests/test_transformer/test_generic_transformer.py b/tests/test_transformer/test_generic_transformer.py index ee23774..1007cd8 100644 --- a/tests/test_transformer/test_generic_transformer.py +++ b/tests/test_transformer/test_generic_transformer.py @@ -29,6 +29,16 @@ def test_column_eval_transformer_in_pipeline(X_strings) -> None: assert pipeline.steps[0][0] == "columnevaltransformer" +def test_column_eval_transformer_in_pipeline_new_column(X_strings) -> None: + pipeline = make_pipeline( + ColumnEvalTransformer([("email", "str.contains('@')", "new_column")]) + ) + X = pipeline.fit_transform(X_strings) + + assert X["new_column"].to_list() == [True, True, True, True, True, False] + assert pipeline.steps[0][0] == "columnevaltransformer" + + def test_column_eval_transformer_with_invalid_start_of_eval(X_strings) -> None: with pytest.raises(ValueError) as error: transformer = ColumnEvalTransformer([("email", ".str.contains('@')")])