Skip to content

Commit

Permalink
Add default params_to_tune for TimeSeriesImputerTransform (#1232)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored Apr 19, 2023
1 parent 68634b0 commit a1c19c8
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add default `params_to_tune` for `RNNModel` and `MLPModel` ([#1218](https://github.com/tinkoff-ai/etna/pull/1218))
- Add default `params_to_tune` for `DateFlagsTransform`, `TimeFlagsTransform`, `SpecialDaysTransform` and `FourierTransform` ([#1228](https://github.com/tinkoff-ai/etna/pull/1228))
- Add default `params_to_tune` for `MedianOutliersTransform`, `DensityOutliersTransform` and `PredictionIntervalOutliersTransform` ([#1231](https://github.com/tinkoff-ai/etna/pull/1231))
- Add default `params_to_tune` for `TimeSeriesImputerTransform` ([#1232](https://github.com/tinkoff-ai/etna/pull/1232))
### Fixed
- Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
- `ProphetModel` fails with additional seasonality set ([#1157](https://github.com/tinkoff-ai/etna/pull/1157))
Expand Down
29 changes: 29 additions & 0 deletions etna/transforms/missing_values/imputation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from enum import Enum
from typing import Dict
from typing import List
from typing import Optional

import numpy as np
import pandas as pd

from etna import SETTINGS
from etna.transforms.base import OneSegmentTransform
from etna.transforms.base import ReversiblePerSegmentWrapper

if SETTINGS.auto_required:
from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalDistribution
from optuna.distributions import IntUniformDistribution


class ImputerMode(str, Enum):
"""Enum for different imputation strategy."""
Expand Down Expand Up @@ -277,5 +284,27 @@ def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
return []

def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
"""Get default grid for tuning hyperparameters.
This grid doesn't tune ``seasonality`` parameter. It expected to be set by the user.
Strategy "seasonal" is suggested only if ``seasonality`` is set higher than 1.
Returns
-------
:
Grid to tune.
"""
if self.seasonality > 1:
return {
"strategy": CategoricalDistribution(["constant", "mean", "running_mean", "forward_fill", "seasonal"]),
"window": IntUniformDistribution(low=1, high=20),
}
else:
return {
"strategy": CategoricalDistribution(["constant", "mean", "running_mean", "forward_fill"]),
"window": IntUniformDistribution(low=1, high=20),
}


__all__ = ["TimeSeriesImputerTransform"]
13 changes: 13 additions & 0 deletions tests/test_transforms/test_missing_values/test_impute_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from etna.models import NaiveModel
from etna.transforms.missing_values import TimeSeriesImputerTransform
from etna.transforms.missing_values.imputation import _OneSegmentTimeSeriesImputerTransform
from tests.test_transforms.utils import assert_sampling_is_valid
from tests.test_transforms.utils import assert_transformation_equals_loaded_original


Expand Down Expand Up @@ -393,3 +394,15 @@ def test_constant_fill_strategy(ts_with_missing_range_x_index_two_segments: TSDa
def test_save_load(ts_to_fill):
transform = TimeSeriesImputerTransform()
assert_transformation_equals_loaded_original(transform=transform, ts=ts_to_fill)


@pytest.mark.parametrize(
"transform, expected_strategy_length",
[(TimeSeriesImputerTransform(), 4), (TimeSeriesImputerTransform(seasonality=7), 5)],
)
def test_params_to_tune(transform, expected_strategy_length, ts_to_fill):
ts = ts_to_fill
grid = transform.params_to_tune()
assert len(grid) > 0
assert len(grid["strategy"].choices) == expected_strategy_length
assert_sampling_is_valid(transform=transform, ts=ts)
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,8 @@ def test_get_regressors_info_not_fitted():
transform = ResampleWithDistributionTransform(in_column="regressor_exog", distribution_column="target")
with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"):
_ = transform.get_regressors_info()


def test_params_to_tune():
transform = ResampleWithDistributionTransform(in_column="regressor_exog", distribution_column="target")
assert len(transform.params_to_tune()) == 0

0 comments on commit a1c19c8

Please sign in to comment.