Skip to content

Implement in-sample predictions in BATS/TBATS #1181

Merged
merged 19 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:

- name: PyTest ("not long")
run: |
poetry run pytest tests -v --cov=etna -m "not long_1 and not long_2" --ignore=tests/test_experimental --cov-report=xml --durations=10
poetry run pytest tests -v --cov=etna -m "not long_1 and not long_2 and not long_3" --ignore=tests/test_experimental --cov-report=xml --durations=10
poetry run pytest etna -v --doctest-modules --ignore=etna/libs --durations=10

- name: Upload coverage
Expand Down Expand Up @@ -148,6 +148,44 @@ jobs:
- name: Upload coverage
uses: codecov/codecov-action@v2

long-3-test:
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

- name: Set up Python
id: setup-python
uses: actions/setup-python@v2
with:
python-version: 3.8

- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: 1.4.0 # TODO: remove after poetry fix
virtualenvs-create: true
virtualenvs-in-project: true

- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v2
with:
path: .venv
key: venv-${{ runner.os }}-3.8-${{ hashFiles('**/poetry.lock') }}

- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: |
poetry install -E "all tests" -vv

- name: PyTest ("long")
run: |
poetry run pytest tests -v --cov=etna -m "long_3" --ignore=tests/test_experimental --cov-report=xml --durations=10

- name: Upload coverage
uses: codecov/codecov-action@v2

experimental-test:
runs-on: ubuntu-latest

Expand Down
12 changes: 7 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ChangePointsLevelTransform` and base classes `PerIntervalModel`, `BaseChangePointsModelAdapter` for per-interval transforms ([#998](https://github.com/tinkoff-ai/etna/pull/998))
- Method `set_params` to change parameters of ETNA objects ([#1102](https://github.com/tinkoff-ai/etna/pull/1102))
- Function `plot_forecast_decomposition` ([#1129](https://github.com/tinkoff-ai/etna/pull/1129))
- Method `forecast_components` for forecast decomposition in `_TBATSAdapter` ([#1125](https://github.com/tinkoff-ai/etna/issues/1125))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_CatBoostAdapter` ([#1135](https://github.com/tinkoff-ai/etna/issues/1135))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_HoltWintersAdapter ` ([#1146](https://github.com/tinkoff-ai/etna/issues/1146))
- Methods `predict_components` for forecast decomposition in `_ProphetAdapter` ([#1161](https://github.com/tinkoff-ai/etna/issues/1161))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_SARIMAXAdapter` and `_AutoARIMAAdapter` ([#1149](https://github.com/tinkoff-ai/etna/issues/1149))
- Method `forecast_components` for forecast decomposition in `_TBATSAdapter` ([#1133](https://github.com/tinkoff-ai/etna/pull/1133))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_CatBoostAdapter` ([#1148](https://github.com/tinkoff-ai/etna/pull/1148))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_HoltWintersAdapter ` ([#1162](https://github.com/tinkoff-ai/etna/pull/1162))
- Method `predict_components` for forecast decomposition in `_ProphetAdapter` ([#1172](https://github.com/tinkoff-ai/etna/pull/1172))
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_SARIMAXAdapter` and `_AutoARIMAAdapter` ([#1174](https://github.com/tinkoff-ai/etna/pull/1174))
- Add `refit` parameter into `backtest` ([#1159](https://github.com/tinkoff-ai/etna/pull/1159))
- Add `stride` parameter into `backtest` ([#1165](https://github.com/tinkoff-ai/etna/pull/1165))
- Add optional parameter `ts` into `forecast` method of pipelines ([#1071](https://github.com/tinkoff-ai/etna/pull/1071))
- Add tests on `transform` method of transforms on subset of segments, on new segments, on future with gap ([#1094](https://github.com/tinkoff-ai/etna/pull/1094))
- Add tests on `inverse_transform` method of transforms on subset of segments, on new segments, on future with gap ([#1127](https://github.com/tinkoff-ai/etna/pull/1127))
- In-sample prediction for `BATSModel` and `TBATSModel` ([#1181](https://github.com/tinkoff-ai/etna/pull/1181))
- Method `predict_components` for forecast decomposition in `_TBATSAdapter` ([#1181](https://github.com/tinkoff-ai/etna/pull/1181))
-
### Changed
- Add optional `features` parameter in the signature of `TSDataset.to_pandas`, `TSDataset.to_flatten` ([#809](https://github.com/tinkoff-ai/etna/pull/809))
Expand Down
117 changes: 109 additions & 8 deletions etna/models/tbats.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class _TBATSAdapter(BaseAdapter):
def __init__(self, model: Estimator):
self._model = model
self._fitted_model: Optional[Model] = None
self._first_train_timestamp = None
self._last_train_timestamp = None
self._freq = None

Expand All @@ -32,6 +33,7 @@ def fit(self, df: pd.DataFrame, regressors: Iterable[str]):

target = df["target"]
self._fitted_model = self._model.fit(target)
self._first_train_timestamp = df["timestamp"].min()
self._last_train_timestamp = df["timestamp"].max()
self._freq = freq

Expand Down Expand Up @@ -65,7 +67,35 @@ def forecast(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Itera
return y_pred

def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Iterable[float]) -> pd.DataFrame:
raise NotImplementedError("Method predict isn't currently implemented!")
if self._fitted_model is None or self._freq is None:
raise ValueError("Model is not fitted! Fit the model before calling predict method!")

train_timestamp = pd.date_range(
start=str(self._first_train_timestamp), end=str(self._last_train_timestamp), freq=self._freq
)

if not (set(train_timestamp) >= set(df["timestamp"])):
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError("Method predict isn't currently implemented for out-of-sample prediction!")

y_pred = pd.DataFrame()
y_pred["target"] = self._fitted_model.y_hat
y_pred["timestamp"] = train_timestamp

if prediction_interval:
for quantile in quantiles:
confidence_intervals = self._fitted_model._calculate_confidence_intervals(
y_pred["target"].values, quantile
)

if quantile < 1 / 2:
y_pred[f"target_{quantile:.4g}"] = confidence_intervals["lower_bound"]
else:
y_pred[f"target_{quantile:.4g}"] = confidence_intervals["upper_bound"]

# selecting time points from provided dataframe
y_pred = y_pred.merge(df["timestamp"], on="timestamp").drop(columns=["timestamp"])
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

return y_pred

def get_model(self) -> Model:
"""Get internal :py:class:`tbats.tbats.Model` model that was fitted inside etna class.
Expand Down Expand Up @@ -114,7 +144,28 @@ def predict_components(self, df: pd.DataFrame) -> pd.DataFrame:
:
dataframe with prediction components
"""
raise NotImplementedError("Prediction decomposition isn't currently implemented!")
if self._fitted_model is None or self._freq is None:
raise ValueError("Model is not fitted! Fit the model before estimating forecast components!")

train_timestamp = pd.date_range(
start=str(self._first_train_timestamp), end=str(self._last_train_timestamp), freq=self._freq
)

if not (set(train_timestamp) >= set(df["timestamp"])):
raise NotImplementedError(
"Method predict_components isn't currently implemented for out-of-sample prediction!"
)

self._check_components()

raw_components = self._decompose_predict()
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
components = self._process_components(raw_components=raw_components)

# selecting time points from provided dataframe
components["timestamp"] = train_timestamp
components = components.merge(df["timestamp"], on="timestamp").drop(columns=["timestamp"])
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

return components

def _get_steps_to_forecast(self, df: pd.DataFrame) -> int:
if self._freq is None:
Expand Down Expand Up @@ -157,6 +208,16 @@ def _check_components(self):
if len(not_fitted_components) > 0:
warn(f"Following components are not fitted: {', '.join(not_fitted_components)}!")

def _rescale_components(self, raw_components: np.ndarray) -> np.ndarray:
"""Rescale components when Box-Cox transform used."""
if self._fitted_model is None:
raise ValueError("Fitted model is not set!")

transformed_pred = np.sum(raw_components, axis=1)
pred = self._fitted_model._inv_boxcox(transformed_pred)
components = raw_components * pred[..., np.newaxis] / transformed_pred[..., np.newaxis]
return components

def _decompose_forecast(self, horizon: int) -> np.ndarray:
"""Estimate raw forecast components."""
if self._fitted_model is None:
Expand All @@ -175,9 +236,33 @@ def _decompose_forecast(self, horizon: int) -> np.ndarray:
raw_components = np.stack(components, axis=0)

if model.params.components.use_box_cox:
transformed_pred = np.sum(raw_components, axis=1)
pred = model._inv_boxcox(transformed_pred)
raw_components = raw_components * pred[..., np.newaxis] / transformed_pred[..., np.newaxis]
raw_components = self._rescale_components(raw_components)

return raw_components

def _decompose_predict(self) -> np.ndarray:
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
"""Estimate raw prediction components."""
if self._fitted_model is None:
raise ValueError("Fitted model is not set!")

model = self._fitted_model
state_matrix = model.matrix.make_F_matrix()
component_weights = model.matrix.make_w_vector()
error_weights = model.matrix.make_g_vector()

steps = len(model.y)
state = model.params.x0
weighted_error = model.resid_boxcox[..., np.newaxis] * error_weights[np.newaxis]

components = []
for t in range(steps):
components.append(component_weights * state)
state = state_matrix @ state + weighted_error[t]

raw_components = np.stack(components, axis=0)

if model.params.components.use_box_cox:
raw_components = self._rescale_components(raw_components)

return raw_components

Expand Down Expand Up @@ -223,13 +308,21 @@ def _process_components(self, raw_components: np.ndarray) -> pd.DataFrame:
raw_components[:, component_idx : component_idx + p + q], axis=1
)

return pd.DataFrame(data=named_components)
return pd.DataFrame(data=named_components).add_prefix("target_component_")


class BATSModel(
PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel
):
"""Class for holding segment interval BATS model."""
"""Class for holding segment interval BATS model.

Notes
-----
This model supports in-sample and out-of-sample prediction decomposition.
Prediction components for BATS model are: local level, trend, seasonality and ARMA component.
In-sample and out-of-sample decompositions components are estimated directly from the fitted model parameters.
Box-Cox transform supported with components proportional rescaling.
"""

def __init__(
self,
Expand Down Expand Up @@ -298,7 +391,15 @@ def __init__(
class TBATSModel(
PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel
):
"""Class for holding segment interval TBATS model."""
"""Class for holding segment interval TBATS model.

Notes
-----
This model supports in-sample and out-of-sample prediction decomposition.
Prediction components for TBATS model are: local level, trend, seasonality and ARMA component.
In-sample and out-of-sample decompositions components are estimated directly from the fitted model parameters.
Box-Cox transform supported with components proportional rescaling.
"""

def __init__(
self,
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ doctest_optionflags = "NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL NUMBER"
markers = [
"smoke",
"long_1",
"long_2"
"long_2",
"long_3"
]

[tool.coverage.report]
Expand Down
12 changes: 6 additions & 6 deletions tests/test_models/test_inference/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class TestPredictInSampleFull:
(HoltModel(), []),
(HoltWintersModel(), []),
(SimpleExpSmoothingModel(), []),
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_predict_in_sample_full(self, model, transforms, example_tsds):
Expand Down Expand Up @@ -95,8 +97,6 @@ def test_predict_in_sample_full_failed_not_enough_context(self, model, transform
@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
(
DeepARModel(
dataset_builder=PytorchForecastingDatasetBuilder(
Expand Down Expand Up @@ -171,6 +171,8 @@ class TestPredictInSampleSuffix:
(NaiveModel(lag=3), []),
(SeasonalMovingAverageModel(), []),
(DeadlineMovingAverageModel(window=1), []),
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_predict_in_sample_suffix(self, model, transforms, example_tsds):
Expand All @@ -180,8 +182,6 @@ def test_predict_in_sample_suffix(self, model, transforms, example_tsds):
@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
(
DeepARModel(
dataset_builder=PytorchForecastingDatasetBuilder(
Expand Down Expand Up @@ -714,6 +714,8 @@ def _test_predict_subset_segments(self, ts, model, transforms, segments, num_ski
(SeasonalMovingAverageModel(), []),
(NaiveModel(lag=3), []),
(DeadlineMovingAverageModel(window=1), []),
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
],
)
def test_predict_subset_segments(self, model, transforms, example_tsds):
Expand All @@ -723,8 +725,6 @@ def test_predict_subset_segments(self, model, transforms, example_tsds):
@pytest.mark.parametrize(
"model, transforms",
[
(BATSModel(use_trend=True), []),
(TBATSModel(use_trend=True), []),
(
DeepARModel(
dataset_builder=PytorchForecastingDatasetBuilder(
Expand Down
Loading