Skip to content

Commit

Permalink
refactor: add possibility to name target column
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Lemke committed Jan 20, 2023
1 parent e414da6 commit 0ec8b42
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ repos:
rev: 22.12.0
hooks:
- id: black
args: ["--config=pyproject.toml"]
- id: black-jupyter
args: ["--config=pyproject.toml"]
files: \.ipynb$

- repo: https://github.com/PyCQA/isort
Expand Down
17 changes: 12 additions & 5 deletions src/sk_transformers/generic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,19 @@ class ColumnEvalTransformer(BaseTransformer):
```
Args:
features (List[Tuple[str, str]]): List of tuples containing the column name and the method (`eval_func`) to apply.
E.g. `("foo", "str.upper()")` will apply the `str.upper()` on the column `foo`.
features (List[Union[Tuple[str, str], Tuple[str, str, str]]]): List of tuples containing the column name and the method (`eval_func`) to apply.
As a third entry in the tuple an alternative name for the column can be provided. If not provided the column name will be used.
This can be useful if the the original column should not be replaced but another column should be created.
Raises:
ValueError: If the `eval_func` starts with a dot (`.`).
Warning: If the `eval_func` contains `apply` but not `swifter`.
ValueError: If the `eval_func` tries to assign multiple columns to one target column.
"""

def __init__(self, features: List[Tuple[str, str]]) -> None:
def __init__(
self, features: List[Union[Tuple[str, str], Tuple[str, str, str]]]
) -> None:
super().__init__()
self.features = features

Expand All @@ -60,7 +63,11 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
force_all_finite="allow-nan",
)

for (column, eval_func) in self.features:
for eval_tuple in self.features:

column = eval_tuple[0]
eval_func = eval_tuple[1]
new_column = eval_tuple[2] if len(eval_tuple) == 3 else column # type: ignore

if eval_func[0] == ".":
raise ValueError(
Expand All @@ -77,7 +84,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
)

try:
X[column] = eval( # pylint: disable=eval-used # nosec
X[new_column] = eval( # pylint: disable=eval-used # nosec
f"X[{'column'}].{eval_func}"
)
except ValueError as e:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_transformer/test_generic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def test_column_eval_transformer_in_pipeline(X_strings) -> None:
assert pipeline.steps[0][0] == "columnevaltransformer"


def test_column_eval_transformer_in_pipeline_new_column(X_strings) -> None:
pipeline = make_pipeline(
ColumnEvalTransformer([("email", "str.contains('@')", "new_column")])
)
X = pipeline.fit_transform(X_strings)

assert X["new_column"].to_list() == [True, True, True, True, True, False]
assert pipeline.steps[0][0] == "columnevaltransformer"


def test_column_eval_transformer_with_invalid_start_of_eval(X_strings) -> None:
with pytest.raises(ValueError) as error:
transformer = ColumnEvalTransformer([("email", ".str.contains('@')")])
Expand Down

0 comments on commit 0ec8b42

Please sign in to comment.