From 9c51e18cd460085c975484cf2ee6589c7c5e41b3 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Wed, 19 Apr 2023 10:46:31 +0300 Subject: [PATCH 1/2] feature: add default params_to_tune for TimeSeriesImputerTransform --- etna/transforms/missing_values/imputation.py | 29 +++++++++++++++++++ .../test_impute_transform.py | 13 +++++++++ .../test_resample_transform.py | 5 ++++ 3 files changed, 47 insertions(+) diff --git a/etna/transforms/missing_values/imputation.py b/etna/transforms/missing_values/imputation.py index b0c59497b..6ee79723b 100644 --- a/etna/transforms/missing_values/imputation.py +++ b/etna/transforms/missing_values/imputation.py @@ -1,13 +1,20 @@ from enum import Enum +from typing import Dict from typing import List from typing import Optional import numpy as np import pandas as pd +from etna import SETTINGS from etna.transforms.base import OneSegmentTransform from etna.transforms.base import ReversiblePerSegmentWrapper +if SETTINGS.auto_required: + from optuna.distributions import BaseDistribution + from optuna.distributions import CategoricalDistribution + from optuna.distributions import IntUniformDistribution + class ImputerMode(str, Enum): """Enum for different imputation strategy.""" @@ -277,5 +284,27 @@ def get_regressors_info(self) -> List[str]: """Return the list with regressors created by the transform.""" return [] + def params_to_tune(self) -> Dict[str, "BaseDistribution"]: + """Get default grid for tuning hyperparameters. + + This grid doesn't tune the ``seasonality`` parameter. It is expected to be set by the user. + Strategy "seasonal" is suggested only if ``seasonality`` is set higher than 1. + + Returns + ------- + : + Grid to tune.
+ """ + if self.seasonality > 1: + return { + "strategy": CategoricalDistribution(["constant", "mean", "running_mean", "forward_fill", "seasonal"]), + "window": IntUniformDistribution(low=1, high=20), + } + else: + return { + "strategy": CategoricalDistribution(["constant", "mean", "running_mean", "forward_fill"]), + "window": IntUniformDistribution(low=1, high=20), + } + __all__ = ["TimeSeriesImputerTransform"] diff --git a/tests/test_transforms/test_missing_values/test_impute_transform.py b/tests/test_transforms/test_missing_values/test_impute_transform.py index 7bbd3d304..c0b9def17 100644 --- a/tests/test_transforms/test_missing_values/test_impute_transform.py +++ b/tests/test_transforms/test_missing_values/test_impute_transform.py @@ -8,6 +8,7 @@ from etna.models import NaiveModel from etna.transforms.missing_values import TimeSeriesImputerTransform from etna.transforms.missing_values.imputation import _OneSegmentTimeSeriesImputerTransform +from tests.test_transforms.utils import assert_sampling_is_valid from tests.test_transforms.utils import assert_transformation_equals_loaded_original @@ -393,3 +394,15 @@ def test_constant_fill_strategy(ts_with_missing_range_x_index_two_segments: TSDa def test_save_load(ts_to_fill): transform = TimeSeriesImputerTransform() assert_transformation_equals_loaded_original(transform=transform, ts=ts_to_fill) + + +@pytest.mark.parametrize( + "transform, expected_strategy_length", + [(TimeSeriesImputerTransform(), 4), (TimeSeriesImputerTransform(seasonality=7), 5)], +) +def test_params_to_tune(transform, expected_strategy_length, ts_to_fill): + ts = ts_to_fill + grid = transform.params_to_tune() + assert len(grid) > 0 + assert len(grid["strategy"].choices) == expected_strategy_length + assert_sampling_is_valid(transform=transform, ts=ts) diff --git a/tests/test_transforms/test_missing_values/test_resample_transform.py b/tests/test_transforms/test_missing_values/test_resample_transform.py index 10b070f9d..6d294a479 100644 --- 
a/tests/test_transforms/test_missing_values/test_resample_transform.py +++ b/tests/test_transforms/test_missing_values/test_resample_transform.py @@ -127,3 +127,8 @@ def test_get_regressors_info_not_fitted(): transform = ResampleWithDistributionTransform(in_column="regressor_exog", distribution_column="target") with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"): _ = transform.get_regressors_info() + + +def test_params_to_tune(): + transform = ResampleWithDistributionTransform(in_column="regressor_exog", distribution_column="target") + assert len(transform.params_to_tune()) == 0 From 366de576e72675c5fa8b5781ce4a1adbed9ec169 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Wed, 19 Apr 2023 10:48:45 +0300 Subject: [PATCH 2/2] chore: update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 386ef9abe..f2f1394fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add default `params_to_tune` for `HoltWintersModel`, `HoltModel` and `SimpleExpSmoothingModel` ([#1209](https://github.com/tinkoff-ai/etna/pull/1209)) - Add default `params_to_tune` for `RNNModel` and `MLPModel` ([#1218](https://github.com/tinkoff-ai/etna/pull/1218)) - Add default `params_to_tune` for `DateFlagsTransform`, `TimeFlagsTransform`, `SpecialDaysTransform` and `FourierTransform` ([#1228](https://github.com/tinkoff-ai/etna/pull/1228)) +- Add default `params_to_tune` for `TimeSeriesImputerTransform` ([#1232](https://github.com/tinkoff-ai/etna/pull/1232)) ### Fixed - Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110)) - `ProphetModel` fails with additional seasonality set ([#1157](https://github.com/tinkoff-ai/etna/pull/1157))