Skip to content

Add default params_to_tune for TimeSeriesImputerTransform #1232

Merged
merged 3 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add default `params_to_tune` for `RNNModel` and `MLPModel` ([#1218](https://github.com/tinkoff-ai/etna/pull/1218))
- Add default `params_to_tune` for `DateFlagsTransform`, `TimeFlagsTransform`, `SpecialDaysTransform` and `FourierTransform` ([#1228](https://github.com/tinkoff-ai/etna/pull/1228))
- Add default `params_to_tune` for `MedianOutliersTransform`, `DensityOutliersTransform` and `PredictionIntervalOutliersTransform` ([#1231](https://github.com/tinkoff-ai/etna/pull/1231))
- Add default `params_to_tune` for `TimeSeriesImputerTransform` ([#1232](https://github.com/tinkoff-ai/etna/pull/1232))
### Fixed
- Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
- `ProphetModel` fails with additional seasonality set ([#1157](https://github.com/tinkoff-ai/etna/pull/1157))
Expand Down
29 changes: 29 additions & 0 deletions etna/transforms/missing_values/imputation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from enum import Enum
from typing import Dict
from typing import List
from typing import Optional

import numpy as np
import pandas as pd

from etna import SETTINGS
from etna.transforms.base import OneSegmentTransform
from etna.transforms.base import ReversiblePerSegmentWrapper

if SETTINGS.auto_required:
from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalDistribution
from optuna.distributions import IntUniformDistribution


class ImputerMode(str, Enum):
"""Enum for different imputation strategy."""
Expand Down Expand Up @@ -277,5 +284,27 @@ def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
return []

def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
"""Get default grid for tuning hyperparameters.

This grid doesn't tune ``seasonality`` parameter. It expected to be set by the user.
Strategy "seasonal" is suggested only if ``seasonality`` is set higher than 1.

Returns
-------
:
Grid to tune.
"""
if self.seasonality > 1:
return {
"strategy": CategoricalDistribution(["constant", "mean", "running_mean", "forward_fill", "seasonal"]),
"window": IntUniformDistribution(low=1, high=20),
}
else:
return {
"strategy": CategoricalDistribution(["constant", "mean", "running_mean", "forward_fill"]),
"window": IntUniformDistribution(low=1, high=20),
}


__all__ = ["TimeSeriesImputerTransform"]
13 changes: 13 additions & 0 deletions tests/test_transforms/test_missing_values/test_impute_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from etna.models import NaiveModel
from etna.transforms.missing_values import TimeSeriesImputerTransform
from etna.transforms.missing_values.imputation import _OneSegmentTimeSeriesImputerTransform
from tests.test_transforms.utils import assert_sampling_is_valid
from tests.test_transforms.utils import assert_transformation_equals_loaded_original


Expand Down Expand Up @@ -393,3 +394,15 @@ def test_constant_fill_strategy(ts_with_missing_range_x_index_two_segments: TSDa
def test_save_load(ts_to_fill):
transform = TimeSeriesImputerTransform()
assert_transformation_equals_loaded_original(transform=transform, ts=ts_to_fill)


@pytest.mark.parametrize(
"transform, expected_strategy_length",
[(TimeSeriesImputerTransform(), 4), (TimeSeriesImputerTransform(seasonality=7), 5)],
)
def test_params_to_tune(transform, expected_strategy_length, ts_to_fill):
ts = ts_to_fill
grid = transform.params_to_tune()
assert len(grid) > 0
assert len(grid["strategy"].choices) == expected_strategy_length
assert_sampling_is_valid(transform=transform, ts=ts)
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,8 @@ def test_get_regressors_info_not_fitted():
transform = ResampleWithDistributionTransform(in_column="regressor_exog", distribution_column="target")
with pytest.raises(ValueError, match="Fit the transform to get the correct regressors info!"):
_ = transform.get_regressors_info()


def test_params_to_tune():
transform = ResampleWithDistributionTransform(in_column="regressor_exog", distribution_column="target")
assert len(transform.params_to_tune()) == 0