Skip to content

Forecast decomposition for CatBoost models #1148

Merged
merged 9 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Method `set_params` to change parameters of ETNA objects ([#1102](https://github.com/tinkoff-ai/etna/pull/1102))
- Function `plot_forecast_decomposition` ([#1129](https://github.com/tinkoff-ai/etna/pull/1129))
- Method `forecast_components` for forecast decomposition in `_TBATSAdapter` [#1125](https://github.com/tinkoff-ai/etna/issues/1125)
-
- Methods `forecast_components` and `predict_components` for forecast decomposition in `_CatBoostAdapter` [#1135](https://github.com/tinkoff-ai/etna/issues/1135)
-
### Changed
- Add optional `features` parameter in the signature of `TSDataset.to_pandas`, `TSDataset.to_flatten` ([#809](https://github.com/tinkoff-ai/etna/pull/809))
- Signature of the constructor of `TFTModel`, `DeepARModel` ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
Expand All @@ -39,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add saving/loading for transforms, models, pipelines, ensembles; tutorial for saving/loading ([#1068](https://github.com/tinkoff-ai/etna/pull/1068))
- Add hierarchical time series support([#1083](https://github.com/tinkoff-ai/etna/pull/1083))
- Add `WAPE` metric & `wape` functional metric ([#1085](https://github.com/tinkoff-ai/etna/pull/1085))
-
### Fixed
- Missed kwargs in TFT init([#1078](https://github.com/tinkoff-ai/etna/pull/1078))

Expand Down
78 changes: 62 additions & 16 deletions etna/models/catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,26 @@ def __init__(
def _prepare_float_category_columns(self, df: pd.DataFrame):
df[self._float_category_columns] = df[self._float_category_columns].astype(str).astype("category")

def _prepare_pool(self, features: pd.DataFrame, target: np.ndarray) -> Pool:
"""Prepare pool for CatBoost model."""
columns_dtypes = features.dtypes
category_columns_dtypes = columns_dtypes[columns_dtypes == "category"]
self._categorical = category_columns_dtypes.index.tolist()

# select only columns with float categories
float_category_columns_dtypes_indices = [
idx
for idx, x in enumerate(category_columns_dtypes)
if issubclass(x.categories.dtype.type, (float, np.floating))
]
float_category_columns_dtypes = category_columns_dtypes.iloc[float_category_columns_dtypes_indices]
float_category_columns = float_category_columns_dtypes.index
self._float_category_columns = float_category_columns
self._prepare_float_category_columns(features)

train_pool = Pool(features, target, cat_features=self._categorical)
return train_pool

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_CatBoostAdapter":
"""
Fit Catboost model.
Expand All @@ -57,22 +77,7 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_CatBoostAdapter":
"""
features = df.drop(columns=["timestamp", "target"])
target = df["target"]
columns_dtypes = features.dtypes
category_columns_dtypes = columns_dtypes[columns_dtypes == "category"]
self._categorical = category_columns_dtypes.index.tolist()

# select only columns with float categories
float_category_columns_dtypes_indices = [
idx
for idx, x in enumerate(category_columns_dtypes)
if issubclass(x.categories.dtype.type, (float, np.floating))
]
float_category_columns_dtypes = category_columns_dtypes.iloc[float_category_columns_dtypes_indices]
float_category_columns = float_category_columns_dtypes.index
self._float_category_columns = float_category_columns
self._prepare_float_category_columns(features)

train_pool = Pool(features, target.values, cat_features=self._categorical)
train_pool = self._prepare_pool(features, target.values)
self.model.fit(train_pool)
return self

Expand Down Expand Up @@ -106,6 +111,47 @@ def get_model(self) -> CatBoostRegressor:
"""
return self.model

def forecast_components(self, df: pd.DataFrame) -> pd.DataFrame:
"""Estimate forecast components.

Parameters
----------
df:
features dataframe

Returns
-------
:
dataframe with forecast components
"""
return self.predict_components(df=df)

def predict_components(self, df: pd.DataFrame) -> pd.DataFrame:
"""Estimate prediction components.

Parameters
----------
df:
features dataframe

Returns
-------
:
dataframe with prediction components
"""
features = df.drop(columns=["timestamp", "target"])

prediction = self.model.predict(features)
pool = self._prepare_pool(features, prediction)
shap_values = self.model.get_feature_importance(pool, type="ShapValues")

# encapsulate expected contribution into components
components = shap_values[:, :-1] + shap_values[:, -1, np.newaxis] / (shap_values.shape[1] - 1)

component_names = [f"target_component_{name}" for name in features.columns]

return pd.DataFrame(data=components, columns=component_names)


class CatBoostPerSegmentModel(
PerSegmentModelMixin,
Expand Down
47 changes: 47 additions & 0 deletions tests/test_models/test_catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from etna.metrics import MAE
from etna.models import CatBoostMultiSegmentModel
from etna.models import CatBoostPerSegmentModel
from etna.models.catboost import _CatBoostAdapter
from etna.pipeline import Pipeline
from etna.transforms import DateFlagsTransform
from etna.transforms import LabelEncoderTransform
Expand Down Expand Up @@ -144,3 +145,49 @@ def test_save_load(model, example_tsds):
horizon = 3
transforms = [LagTransform(in_column="target", lags=list(range(horizon, horizon + 3)))]
assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=transforms, horizon=horizon)


@pytest.fixture()
def dfs_w_exog():
df = generate_ar_df(start_time="2021-01-01", periods=105, n_segments=1)
df["f1"] = np.sin(df["target"])
df["f2"] = np.cos(df["target"])

df.drop(columns=["segment"], inplace=True)
train = df.iloc[:-5]
test = df.iloc[-5:]
return train, test


def test_forecast_components_equal_predict_components(dfs_w_exog):
train, test = dfs_w_exog

model = _CatBoostAdapter(iterations=10)
model.fit(train, [])

prediction_components = model.predict_components(df=test)
forecast_components = model.forecast_components(df=test)
pd.testing.assert_frame_equal(prediction_components, forecast_components)


def test_forecast_components_names(dfs_w_exog, answer=("target_component_f1", "target_component_f2")):
train, test = dfs_w_exog

model = _CatBoostAdapter(iterations=10)
model.fit(train, [])

components = model.forecast_components(df=test)
assert set(components.columns) == set(answer)


def test_decomposition_sums_to_target(dfs_w_exog):
train, test = dfs_w_exog

model = _CatBoostAdapter(iterations=10)
model.fit(train, [])

y_pred = model.predict(test)
components = model.forecast_components(df=test)

y_hat_pred = np.sum(components.values, axis=1)
np.testing.assert_allclose(y_hat_pred, y_pred)