Skip to content

Forecast decomposition for TBATS #1133

Merged
merged 21 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `ChangePointsLevelTransform` and base classes `PerIntervalModel`, `BaseChangePointsModelAdapter` for per-interval transforms ([#998](https://github.com/tinkoff-ai/etna/pull/998))
- Method `set_params` to change parameters of ETNA objects ([#1102](https://github.com/tinkoff-ai/etna/pull/1102))
- Function `plot_forecast_decomposition` ([#1129](https://github.com/tinkoff-ai/etna/pull/1129))
-
- Method `forecast_components` for forecast decomposition in `_TBATSAdapter` ([#1125](https://github.com/tinkoff-ai/etna/issues/1125))
-
### Changed
- Add optional `features` parameter in the signature of `TSDataset.to_pandas`, `TSDataset.to_flatten` ([#809](https://github.com/tinkoff-ai/etna/pull/809))
- Signature of the constructor of `TFTModel`, `DeepARModel` ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
Expand Down
160 changes: 151 additions & 9 deletions etna/models/tbats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Iterable
from typing import Optional
from typing import Tuple
from warnings import warn

import numpy as np
import pandas as pd
from tbats.abstract import ContextInterface
from tbats.abstract import Estimator
Expand Down Expand Up @@ -39,15 +41,7 @@ def forecast(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Itera
if self._fitted_model is None or self._freq is None:
raise ValueError("Model is not fitted! Fit the model before calling predict method!")

if df["timestamp"].min() <= self._last_train_timestamp:
raise NotImplementedError(
"It is not possible to make in-sample predictions with BATS/TBATS model! "
"In-sample predictions aren't supported by current implementation."
)

steps_to_forecast = determine_num_steps(
start_timestamp=self._last_train_timestamp, end_timestamp=df["timestamp"].max(), freq=self._freq
)
steps_to_forecast = self._get_steps_to_forecast(df=df)
steps_to_skip = steps_to_forecast - df.shape[0]

y_pred = pd.DataFrame()
Expand Down Expand Up @@ -83,6 +77,154 @@ def get_model(self) -> Model:
"""
return self._fitted_model

def forecast_components(self, df: pd.DataFrame) -> pd.DataFrame:
    """Estimate forecast components.

    Components are computed for every step between the end of the train
    data and the last timestamp in ``df``; only the last ``df.shape[0]``
    steps are returned, mirroring the ``steps_to_skip`` computation done
    in ``forecast``.

    Parameters
    ----------
    df:
        features dataframe

    Returns
    -------
    :
        dataframe with forecast components

    Raises
    ------
    ValueError:
        if the model is not fitted
    """
    if self._fitted_model is None or self._freq is None:
        raise ValueError("Model is not fitted! Fit the model before estimating forecast components!")

    self._check_components()

    horizon = self._get_steps_to_forecast(df=df)
    raw_components = self._decompose_forecast(horizon=horizon)
    components = self._process_components(raw_components=raw_components)

    # ``df`` may start later than the first timestamp after the train data;
    # drop the leading steps that were only needed to roll the state forward
    # so that the result lines up with the requested rows.
    steps_to_skip = horizon - df.shape[0]
    components = components.iloc[steps_to_skip:].reset_index(drop=True)

    return components

def predict_components(self, df: pd.DataFrame) -> pd.DataFrame:
    """Estimate prediction components.

    Decomposition of predictions is not available for this model, so the
    call always fails.

    Parameters
    ----------
    df:
        features dataframe

    Returns
    -------
    :
        dataframe with prediction components

    Raises
    ------
    NotImplementedError:
        always
    """
    message = "Prediction decomposition isn't currently implemented!"
    raise NotImplementedError(message)

def _get_steps_to_forecast(self, df: pd.DataFrame) -> int:
    """Count the steps from the end of the train data to the last timestamp in ``df``.

    Parameters
    ----------
    df:
        features dataframe with a ``timestamp`` column

    Returns
    -------
    :
        number of steps between the last train timestamp and the maximum
        timestamp of ``df``, as computed by ``determine_num_steps``

    Raises
    ------
    ValueError:
        if the data frequency is not set
    NotImplementedError:
        if ``df`` contains timestamps that are not after the last train timestamp
    """
    if self._freq is None:
        raise ValueError("Data frequency is not set!")

    # Any timestamp at or before the end of the train data is an in-sample request.
    if df["timestamp"].min() <= self._last_train_timestamp:
        raise NotImplementedError(
            "It is not possible to make in-sample predictions with BATS/TBATS model! "
            "In-sample predictions aren't supported by current implementation."
        )

    steps_to_forecast = determine_num_steps(
        start_timestamp=self._last_train_timestamp, end_timestamp=df["timestamp"].max(), freq=self._freq
    )
    return steps_to_forecast

def _check_components(self):
    """Compare fitted model params with the initial params.

    TBATS tries different models and selects best based on AIC.
    That's why some components may not be present in fitted model.
    A warning listing such components is issued; nothing is returned.

    Raises
    ------
    ValueError:
        if the fitted model is not set
    """
    if self._fitted_model is None:
        raise ValueError("Fitted model is not set!")

    fitted_model_params = self._fitted_model.params.components

    not_fitted_components = []

    # Seasonality was requested but the selected model has no seasonal periods.
    seasonal_periods = self._model.seasonal_periods
    if (
        seasonal_periods is not None
        and len(seasonal_periods) > 0
        and len(fitted_model_params.seasonal_periods) == 0
    ):
        not_fitted_components.append("Seasonal")

    # ARMA errors were requested but the selected model does not use them.
    if self._model.use_arma_errors and not fitted_model_params.use_arma_errors:
        not_fitted_components.append("ARMA")

    if len(not_fitted_components) > 0:
        warn(f"Following components are not fitted: {', '.join(not_fitted_components)}!")

def _decompose_forecast(self, horizon: int) -> np.ndarray:
    """Estimate raw forecast components.

    Starting from the last fitted state, the state is advanced ``horizon``
    times with the model's F matrix; at every step the element-wise product
    of the w vector and the current state is recorded.

    Parameters
    ----------
    horizon:
        number of steps to roll the state forward

    Returns
    -------
    :
        array with one row per forecast step

    Raises
    ------
    ValueError:
        if the fitted model is not set
    """
    fitted = self._fitted_model
    if fitted is None:
        raise ValueError("Fitted model is not set!")

    transition = fitted.matrix.make_F_matrix()
    weights = fitted.matrix.make_w_vector()

    current_state = fitted.x_last
    per_step = []
    for _ in range(horizon):
        per_step.append(weights * current_state)
        current_state = transition @ current_state

    stacked = np.stack(per_step, axis=0)

    if not fitted.params.components.use_box_cox:
        return stacked

    # The rows sum to the forecast on the Box-Cox scale. Rescale every row
    # proportionally so that it sums to the inverse-transformed forecast.
    boxcox_forecast = np.sum(stacked, axis=1)
    forecast = fitted._inv_boxcox(boxcox_forecast)
    numerator = stacked * forecast[..., np.newaxis]
    return numerator / boxcox_forecast[..., np.newaxis]

def _process_components(self, raw_components: np.ndarray) -> pd.DataFrame:
    """Select meaningful components and assign names to them.

    Columns of ``raw_components`` are consumed left to right: local level,
    optional trend, seasonal blocks, then ARMA terms.

    Parameters
    ----------
    raw_components:
        array of raw components, indexed as ``raw_components[:, column]``

    Returns
    -------
    :
        dataframe with one named column per component present in the fitted model

    Raises
    ------
    ValueError:
        if the fitted model is not set
    """
    if self._fitted_model is None:
        raise ValueError("Fitted model is not set!")

    params_components = self._fitted_model.params.components
    named_components = {}

    named_components["local_level"] = raw_components[:, 0]

    component_idx = 1
    if params_components.use_trend:
        named_components["trend"] = raw_components[:, component_idx]
        component_idx += 1

    if len(params_components.seasonal_periods) != 0:
        seasonal_periods = params_components.seasonal_periods

        if hasattr(params_components, "seasonal_harmonics"):
            # TBATS: every seasonal period takes 2 columns per harmonic; they are summed.
            seasonal_harmonics = params_components.seasonal_harmonics
            for seasonal_period, seasonal_harmonic in zip(seasonal_periods, seasonal_harmonics):
                named_components[f"seasonal(s={seasonal_period})"] = np.sum(
                    raw_components[:, component_idx : component_idx + 2 * seasonal_harmonic], axis=1
                )
                component_idx += 2 * seasonal_harmonic

        else:
            # BATS: every seasonal period takes ``seasonal_period`` columns and only
            # the last column of each block is used.
            # NOTE(review): presumably the only column with a non-zero weight -- confirm against tbats.
            component_idx -= 1
            for seasonal_period in seasonal_periods:
                component_idx += seasonal_period
                named_components[f"seasonal(s={seasonal_period})"] = raw_components[:, component_idx]

            # Step past the last seasonal column to the first ARMA column.
            component_idx += 1

    if params_components.p > 0 or params_components.q > 0:
        p, q = params_components.p, params_components.q
        named_components[f"arma(p={p},q={q})"] = np.sum(
            raw_components[:, component_idx : component_idx + p + q], axis=1
        )

    return pd.DataFrame(data=named_components)


class BATSModel(
PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel
Expand Down
Loading