diff --git a/etna/transforms/__init__.py b/etna/transforms/__init__.py
index 2ba16b1c0..a46f37d90 100644
--- a/etna/transforms/__init__.py
+++ b/etna/transforms/__init__.py
@@ -2,6 +2,7 @@
 from etna.transforms.base import IrreversibleTransform
 from etna.transforms.base import NewPerSegmentWrapper
 from etna.transforms.base import NewTransform
+from etna.transforms.base import OneSegmentTransform
 from etna.transforms.base import PerSegmentWrapper
 from etna.transforms.base import ReversiblePerSegmentWrapper
 from etna.transforms.base import ReversibleTransform
diff --git a/etna/transforms/base.py b/etna/transforms/base.py
index d59ce3707..e99484a54 100644
--- a/etna/transforms/base.py
+++ b/etna/transforms/base.py
@@ -342,13 +342,85 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
         return df
 
 
+class OneSegmentTransform(ABC, BaseMixin):
+    """Base class to create one segment transforms to apply to data."""
+
+    @abstractmethod
+    def fit(self, df: pd.DataFrame):
+        """Fit the transform.
+
+        Should be implemented by user.
+
+        Parameters
+        ----------
+        df:
+            Dataframe in etna long format.
+        """
+        pass
+
+    @abstractmethod
+    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Transform dataframe.
+
+        Should be implemented by user.
+
+        Parameters
+        ----------
+        df:
+            Dataframe in etna long format.
+
+        Returns
+        -------
+        :
+            Transformed Dataframe in etna long format.
+        """
+        pass
+
+    def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Fit and transform Dataframe.
+
+        May be reimplemented. But it is not recommended.
+
+        Parameters
+        ----------
+        df:
+            Dataframe in etna long format to transform.
+
+        Returns
+        -------
+        :
+            Transformed Dataframe.
+        """
+        # Do not chain on the result of ``fit``: its contract above declares no return value.
+        self.fit(df=df)
+        return self.transform(df=df)
+
+    @abstractmethod
+    def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Inverse transform Dataframe.
+
+        Should be implemented by user.
+
+        Parameters
+        ----------
+        df:
+            Dataframe in etna long format to be inverse transformed.
+
+        Returns
+        -------
+        :
+            Dataframe after applying inverse transformation.
+        """
+        pass
+
+
 class NewPerSegmentWrapper(NewTransform):
     """Class to apply transform in per segment manner."""
 
-    def __init__(self, transform: NewTransform):
+    def __init__(self, transform: OneSegmentTransform, required_features: Union[Literal["all"], List[str]]):
         self._base_transform = transform
-        self.segment_transforms: Optional[Dict[str, NewTransform]] = None
-        super().__init__(required_features=transform.required_features)
+        self.segment_transforms: Optional[Dict[str, OneSegmentTransform]] = None
+        super().__init__(required_features=required_features)
 
     def _fit(self, df: pd.DataFrame):
         """Fit transform on each segment."""
@@ -356,7 +428,7 @@ def _fit(self, df: pd.DataFrame):
         segments = df.columns.get_level_values("segment").unique()
         for segment in segments:
             self.segment_transforms[segment] = deepcopy(self._base_transform)
-            self.segment_transforms[segment]._fit(df[segment])
+            self.segment_transforms[segment].fit(df[segment])
 
     def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
         """Apply transform to each segment separately."""
@@ -364,11 +436,11 @@ def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
             raise ValueError("Transform is not fitted!")
 
         results = []
-        for key, value in self.segment_transforms.items():
-            seg_df = value._transform(df[key])
+        for segment, transform in self.segment_transforms.items():
+            seg_df = transform.transform(df[segment])
 
             _idx = seg_df.columns.to_frame()
-            _idx.insert(0, "segment", key)
+            _idx.insert(0, "segment", segment)
             seg_df.columns = pd.MultiIndex.from_frame(_idx)
 
             results.append(seg_df)
@@ -382,15 +454,15 @@ def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
 class IrreversiblePerSegmentWrapper(NewPerSegmentWrapper, IrreversibleTransform):
     """Class to apply irreversible transform in per segment manner."""
 
-    def __init__(self, transform: IrreversibleTransform):
-        super().__init__(transform=transform)
+    def __init__(self, transform: OneSegmentTransform, required_features: Union[Literal["all"], List[str]]):
+        super().__init__(transform=transform, required_features=required_features)
 
 
 class ReversiblePerSegmentWrapper(NewPerSegmentWrapper, ReversibleTransform):
     """Class to apply reversible transform in per segment manner."""
 
-    def __init__(self, transform: ReversibleTransform):
-        super().__init__(transform=transform)
+    def __init__(self, transform: OneSegmentTransform, required_features: Union[Literal["all"], List[str]]):
+        super().__init__(transform=transform, required_features=required_features)
 
     def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
         """Apply inverse_transform to each segment."""
@@ -398,11 +470,11 @@ def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
             raise ValueError("Transform is not fitted!")
 
         results = []
-        for key, value in self.segment_transforms.items():
-            seg_df = value._inverse_transform(df[key])  # type: ignore
+        for segment, transform in self.segment_transforms.items():
+            seg_df = transform.inverse_transform(df[segment])
 
             _idx = seg_df.columns.to_frame()
-            _idx.insert(0, "segment", key)
+            _idx.insert(0, "segment", segment)
             seg_df.columns = pd.MultiIndex.from_frame(_idx)
 
             results.append(seg_df)
diff --git a/etna/transforms/missing_values/resample.py b/etna/transforms/missing_values/resample.py
index 46827a31a..9420e0797 100644
--- a/etna/transforms/missing_values/resample.py
+++ b/etna/transforms/missing_values/resample.py
@@ -4,14 +4,15 @@
 
 import pandas as pd
 
-from etna.transforms.base import PerSegmentWrapper
-from etna.transforms.base import Transform
+from etna.datasets import TSDataset
+from etna.transforms.base import IrreversiblePerSegmentWrapper
+from etna.transforms.base import OneSegmentTransform
 
 
-class _OneSegmentResampleWithDistributionTransform(Transform):
+class _OneSegmentResampleWithDistributionTransform(OneSegmentTransform):
     """_OneSegmentResampleWithDistributionTransform resamples the given column using the distribution of the other column."""
 
-    def __init__(self, in_column: str, distribution_column: str, inplace: bool, out_column: Optional[str]):
+    def __init__(self, in_column: str, distribution_column: str, inplace: bool, out_column: str):
         """
         Init _OneSegmentResampleWithDistributionTransform.
 
@@ -34,7 +35,7 @@ def __init__(self, in_column: str, distribution_column: str, inplace: bool, out_
         self.distribution_column = distribution_column
         self.inplace = inplace
         self.out_column = out_column
-        self.distribution: pd.DataFrame = None
+        self.distribution: Optional[pd.DataFrame] = None
 
     def _get_folds(self, df: pd.DataFrame) -> List[int]:
         """
@@ -101,8 +102,12 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
         df = df.drop(["fold", "distribution"], axis=1)
         return df
 
+    def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Return the given Dataframe unchanged; no inverse computation is performed."""
+        return df
+
 
-class ResampleWithDistributionTransform(PerSegmentWrapper):
+class ResampleWithDistributionTransform(IrreversiblePerSegmentWrapper):
     """ResampleWithDistributionTransform resamples the given column using the distribution of the other column.
 
     Warning
@@ -136,13 +141,15 @@ def __init__(
         self.distribution_column = distribution_column
         self.inplace = inplace
         self.out_column = self._get_out_column(out_column)
+        self.in_column_regressor: Optional[bool] = None
         super().__init__(
             transform=_OneSegmentResampleWithDistributionTransform(
                 in_column=in_column,
                 distribution_column=distribution_column,
                 inplace=inplace,
                 out_column=self.out_column,
-            )
+            ),
+            required_features=[in_column, distribution_column],
         )
 
     def _get_out_column(self, out_column: Optional[str]) -> str:
@@ -154,3 +161,17 @@ def _get_out_column(self, out_column: Optional[str]) -> str:
         if out_column:
             return out_column
         return self.__repr__()
+
+    def get_regressors_info(self) -> List[str]:
+        """Return the list with regressors created by the transform."""
+        if self.inplace:
+            return []
+        if self.in_column_regressor is None:
+            warnings.warn("Regressors info might be incorrect. Fit the transform to get the correct regressors info.")
+        return [self.out_column] if self.in_column_regressor else []
+
+    def fit(self, ts: TSDataset) -> "ResampleWithDistributionTransform":
+        """Fit the transform."""
+        self.in_column_regressor = self.in_column in ts.regressors
+        super().fit(ts)
+        return self
diff --git a/tests/test_transforms/test_missing_values/test_resample_transform.py b/tests/test_transforms/test_missing_values/test_resample_transform.py
index 05575f26c..fd0028126 100644
--- a/tests/test_transforms/test_missing_values/test_resample_transform.py
+++ b/tests/test_transforms/test_missing_values/test_resample_transform.py
@@ -8,7 +8,7 @@ def test_fail_on_incompatible_freq(incompatible_freq_ts):
         in_column="exog", inplace=True, distribution_column="target", out_column=None
     )
     with pytest.raises(ValueError, match="Can not infer in_column frequency!"):
-        _ = resampler.fit(incompatible_freq_ts.df)
+        _ = resampler.fit(incompatible_freq_ts)
 
 
 @pytest.mark.parametrize(
@@ -27,7 +27,7 @@ def test_fit(ts, request):
     resampler = ResampleWithDistributionTransform(
         in_column="regressor_exog", inplace=True, distribution_column="target", out_column=None
     )
-    resampler.fit(ts.df)
+    resampler.fit(ts)
     segments = ts.df.columns.get_level_values("segment").unique()
     for segment in segments:
         assert (resampler.segment_transforms[segment].distribution == expected_distribution[segment]).all().all()
@@ -48,10 +48,11 @@ def test_transform(daily_exog_ts, inplace, out_column, expected_resampled_ts, re
     resampler = ResampleWithDistributionTransform(
         in_column="regressor_exog", inplace=inplace, distribution_column="target", out_column=out_column
     )
-    resampled_df = resampler.fit_transform(daily_exog_ts.df)
+    resampled_df = resampler.fit_transform(daily_exog_ts).to_pandas()
     assert resampled_df.equals(expected_resampled_df)
 
 
+@pytest.mark.xfail(reason="TSDataset 2.0")
 @pytest.mark.parametrize(
     "inplace,out_column,expected_resampled_ts",
     (
@@ -77,4 +78,29 @@ def test_fit_transform_with_nans(daily_exog_ts_diff_endings):
     resampler = ResampleWithDistributionTransform(
         in_column="regressor_exog", inplace=True, distribution_column="target"
     )
-    daily_exog_ts_diff_endings.fit_transform([resampler])
+    _ = resampler.fit_transform(daily_exog_ts_diff_endings)
+
+
+@pytest.mark.filterwarnings("ignore: Regressors info might be incorrect.")
+@pytest.mark.parametrize(
+    "inplace, in_column_regressor, out_column, expected_regressors",
+    [
+        (True, False, None, []),
+        (False, False, "output_regressor", []),
+        (False, True, "output_regressor", ["output_regressor"]),
+    ],
+)
+def test_get_regressors_info(
+    daily_exog_ts, inplace, in_column_regressor, out_column, expected_regressors, in_column="regressor_exog"
+):
+    daily_exog_ts = daily_exog_ts["ts"]
+    if in_column_regressor:
+        daily_exog_ts._regressors.append(in_column)
+    else:
+        daily_exog_ts._regressors.remove(in_column)
+    resampler = ResampleWithDistributionTransform(
+        in_column=in_column, inplace=inplace, distribution_column="target", out_column=out_column
+    )
+    resampler.fit(daily_exog_ts)
+    regressors_info = resampler.get_regressors_info()
+    assert sorted(regressors_info) == sorted(expected_regressors)