Skip to content

Commit

Permalink
feat: add SequentialTableTransformer (#893)
Browse files Browse the repository at this point in the history
Closes #802 

### Summary of Changes
Added ```SequentialTableTransformer``` class in transformers to
transform tables sequentially using multiple transformers.
Added ```TransformerNotInvertibleError```, which gets raised when
attempting to inverse transform a table using a non-invertible
transformer.
Modified test helper ```assert_tables_equal``` to include more optional
parameters.

---------

Co-authored-by: xXstupidnameXx <xxstupidnamexx@gmail.com>
Co-authored-by: xXstupidnameXx <118291096+xXstupidnameXx@users.noreply.github.com>
Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
  • Loading branch information
5 people authored Jul 12, 2024
1 parent 4e65ba9 commit e93299f
Show file tree
Hide file tree
Showing 6 changed files with 409 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/safeds/data/tabular/transformation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ._one_hot_encoder import OneHotEncoder
from ._range_scaler import RangeScaler
from ._robust_scaler import RobustScaler
from ._sequential_table_transformer import SequentialTableTransformer
from ._simple_imputer import SimpleImputer
from ._standard_scaler import StandardScaler
from ._table_transformer import TableTransformer
Expand All @@ -27,6 +28,7 @@
"LabelEncoder": "._label_encoder:LabelEncoder",
"OneHotEncoder": "._one_hot_encoder:OneHotEncoder",
"RangeScaler": "._range_scaler:RangeScaler",
"SequentialTableTransformer": "._sequential_table_transformer:SequentialTableTransformer",
"RobustScaler": "._robust_scaler:RobustScaler",
"SimpleImputer": "._simple_imputer:SimpleImputer",
"StandardScaler": "._standard_scaler:StandardScaler",
Expand All @@ -42,6 +44,7 @@
"LabelEncoder",
"OneHotEncoder",
"RangeScaler",
"SequentialTableTransformer",
"RobustScaler",
"SimpleImputer",
"StandardScaler",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from warnings import warn

from safeds._utils import _structural_hash
from safeds.exceptions import TransformerNotFittedError, TransformerNotInvertibleError

from ._invertible_table_transformer import InvertibleTableTransformer

if TYPE_CHECKING:
from safeds.data.tabular.containers import Table

from ._table_transformer import TableTransformer


class SequentialTableTransformer(InvertibleTableTransformer):
"""
The SequentialTableTransformer transforms a table using multiple transformers in sequence.
Parameters
----------
transformers:
The list of transformers used to transform the table. Used in the order as they are supplied in the list.
"""

def __init__(
self,
transformers: list[TableTransformer],
) -> None:
super().__init__(None)

# Check if transformers actually contains any transformers.
if transformers is None or len(transformers) == 0:
warn(
"transformers should contain at least 1 transformer",
UserWarning,
stacklevel=2,
)

# Parameters
self._transformers: list[TableTransformer] = transformers

# Internal State
self._is_fitted: bool = False

def __hash__(self) -> int:
return _structural_hash(
super().__hash__(),
self._transformers,
self._is_fitted,
)

@property
def is_fitted(self) -> bool:
"""Whether the transformer is fitted."""
return self._is_fitted

def fit(self, table: Table) -> SequentialTableTransformer:
"""
Fits all the transformers in order.
Parameters
----------
table:
The table used to fit the transformers.
Returns
-------
fitted_transformer:
The fitted transformer.
Raises
------
ValueError:
Raises a ValueError if the table has no rows.
"""
if table.row_count == 0:
raise ValueError("The SequentialTableTransformer cannot be fitted because the table contains 0 rows.")

current_table: Table = table
fitted_transformers: list[TableTransformer] = []

for transformer in self._transformers:
fitted_transformer = transformer.fit(current_table)
fitted_transformers.append(fitted_transformer)
current_table = fitted_transformer.transform(current_table)

result: SequentialTableTransformer = SequentialTableTransformer(
transformers=fitted_transformers,
)

result._is_fitted = True
return result

def transform(self, table: Table) -> Table:
"""
Transform the table using all the transformers sequentially.
Might change the order and type of columns base on the transformers used.
Parameters
----------
table:
The table to be transformed.
Returns
-------
transformed_table:
The transformed table.
Raises
------
TransformerNotFittedError:
Raises a TransformerNotFittedError if the transformer isn't fitted.
"""
if not self._is_fitted:
raise TransformerNotFittedError

current_table: Table = table
for transformer in self._transformers:
current_table = transformer.transform(current_table)

return current_table

def inverse_transform(self, transformed_table: Table) -> Table:
"""
Inversely transforms the table using all the transformers sequentially in inverse order.
Might change the order and type of columns base on the transformers used.
Parameters
----------
transformed_table:
The table to be transformed back.
Returns
-------
original_table:
The original table.
Raises
------
TransformerNotFittedError:
Raises a TransformerNotFittedError if the transformer isn't fitted.
TransformerNotInvertibleError:
Raises a TransformerNotInvertibleError if one of the transformers isn't invertible.
"""
if not self._is_fitted:
raise TransformerNotFittedError

# sequentially inverse transform the table with all transformers, working from the back of the list forwards.
current_table: Table = transformed_table
for transformer in reversed(self._transformers):
# check if transformer is invertible
if not (isinstance(transformer, InvertibleTableTransformer)):
raise TransformerNotInvertibleError(str(type(transformer)))
current_table = transformer.inverse_transform(current_table)

return current_table
2 changes: 2 additions & 0 deletions src/safeds/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NonNumericColumnError,
OutputLengthMismatchError,
TransformerNotFittedError,
TransformerNotInvertibleError,
ValueNotPresentWhenFittedError,
)
from ._ml import (
Expand Down Expand Up @@ -66,6 +67,7 @@ class OutOfBoundsError(SafeDsError):
"NonNumericColumnError",
"OutputLengthMismatchError",
"TransformerNotFittedError",
"TransformerNotInvertibleError",
"ValueNotPresentWhenFittedError",
# ML exceptions
"DatasetMissesDataError",
Expand Down
7 changes: 7 additions & 0 deletions src/safeds/exceptions/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ def __init__(self) -> None:
super().__init__("The transformer has not been fitted yet.")


class TransformerNotInvertibleError(Exception):
"""Raised when a function tries to invert a non-invertible transformer."""

def __init__(self, transformer_type: str) -> None:
super().__init__(f"{transformer_type} is not invertible.")


class ValueNotPresentWhenFittedError(Exception):
"""Exception raised when attempting to one-hot-encode a table containing values not present in the fitting phase."""

Expand Down
27 changes: 25 additions & 2 deletions tests/helpers/_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
from safeds.data.tabular.containers import Cell, Column, Table


def assert_tables_equal(table1: Table, table2: Table) -> None:
def assert_tables_equal(
table1: Table,
table2: Table,
*,
ignore_column_order: bool = False,
ignore_row_order: bool = False,
ignore_types: bool = False,
ignore_float_imprecision: bool = True,
) -> None:
"""
Assert that two tables are almost equal.
Expand All @@ -16,8 +24,23 @@ def assert_tables_equal(table1: Table, table2: Table) -> None:
The first table.
table2:
The table to compare the first table to.
ignore_column_order:
Ignore the column order when True. Will return true, even when the column order is different.
ignore_row_order:
Ignore the column order when True. Will return true, even when the row order is different.
ignore_types:
Ignore differing data Types. Will return true, even when columns have differing data types.
ignore_float_imprecision:
If False, check if floating point values match EXACTLY.
"""
assert_frame_equal(table1._data_frame, table2._data_frame)
assert_frame_equal(
table1._data_frame,
table2._data_frame,
check_row_order=not ignore_row_order,
check_column_order=not ignore_column_order,
check_dtypes=not ignore_types,
check_exact=not ignore_float_imprecision,
)


def assert_that_tabular_datasets_are_equal(table1: TabularDataset, table2: TabularDataset) -> None:
Expand Down
Loading

0 comments on commit e93299f

Please sign in to comment.