Skip to content

Commit

Permalink
support multiple seasonalities in mstl_decomposition (#861)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Jun 28, 2024
1 parent d73e43b commit 3e104b5
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 19 deletions.
4 changes: 2 additions & 2 deletions dev/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies:
- matplotlib
- numba>=0.55.0
- numpy>=1.21.6
- pandas>=1.3.5,<2.2
- pandas>=1.3.5
- pyspark>=3.3
- pip
- prophet
Expand All @@ -23,7 +23,7 @@ dependencies:
- fugue[dask,ray]
- nbdev
- plotly-resampler
- polars
- polars[numpy]>=0.0.0rc0
- supersmoother
- tqdm
- utilsforecast>=0.1.4
2 changes: 1 addition & 1 deletion dev/local_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies:
- datasetsforecast
- nbdev
- plotly-resampler
- polars
- polars[numpy]>=0.0.0rc0
- supersmoother
- tqdm
- utilsforecast>=0.1.4
9 changes: 6 additions & 3 deletions nbs/src/core/models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9119,7 +9119,7 @@
"outputs": [],
"source": [
"#| exporti\n",
"def _predict_mstl_seas(mstl_ob, h, season_length):\n",
"def _predict_mstl_components(mstl_ob, h, season_length):\n",
" seasoncolumns = mstl_ob.filter(regex='seasonal*').columns\n",
" nseasons = len(seasoncolumns)\n",
" seascomp = np.full((h, nseasons), np.nan)\n",
Expand All @@ -9128,8 +9128,11 @@
" mp = seasonal_periods[i]\n",
" colname = seasoncolumns[i]\n",
" seascomp[:, i] = np.tile(mstl_ob[colname].values[-mp:], trunc(1 + (h-1)/mp))[:h]\n",
" lastseas = seascomp.sum(axis=1)\n",
" return lastseas"
" return seascomp\n",
"\n",
"def _predict_mstl_seas(mstl_ob, h, season_length):\n",
" seascomp = _predict_mstl_components(mstl_ob, h, season_length)\n",
" return seascomp.sum(axis=1)"
]
},
{
Expand Down
32 changes: 29 additions & 3 deletions nbs/src/feature_engineering.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "eb55820a-551f-45e6-b6a8-c8f2d868aa32",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -41,7 +53,7 @@
"\n",
"from statsforecast import StatsForecast\n",
"from statsforecast.core import _id_as_idx\n",
"from statsforecast.models import MSTL, _predict_mstl_seas"
"from statsforecast.models import MSTL, _predict_mstl_components"
]
},
{
Expand Down Expand Up @@ -92,11 +104,13 @@
" train_features = []\n",
" future_features = []\n",
" df_constructor = type(df)\n",
" seas_cols = [c for c in sf.fitted_[0, 0].model_.columns if c.startswith('seasonal')]\n",
" for fitted_model in sf.fitted_[:, 0]:\n",
" train_features.append(fitted_model.model_[['trend', 'seasonal']])\n",
" train_features.append(fitted_model.model_[['trend'] + seas_cols])\n",
" seas_comp = _predict_mstl_components(fitted_model.model_, h, model.season_length)\n",
" future_df = df_constructor({\n",
" 'trend': fitted_model.trend_forecaster.predict(h)['mean'],\n",
" 'seasonal': _predict_mstl_seas(fitted_model.model_, h, model.season_length)\n",
" **dict(zip(seas_cols, seas_comp.T)),\n",
" })\n",
" future_features.append(future_df)\n",
" train_features = vertical_concat(train_features, match_categories=False)\n",
Expand Down Expand Up @@ -185,6 +199,18 @@
"with_estimate = train_df_pl.with_columns(estimate=pl.col('trend') + pl.col('seasonal'))\n",
"assert smape(with_estimate, models=['estimate'])['estimate'].mean() < 0.1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb35b3f1-9c67-4422-a4db-f5b51a246dac",
"metadata": {},
"outputs": [],
"source": [
"model = MSTL(season_length=[7, 28])\n",
"train_df, X_df = mstl_decomposition(series, model, 'D', horizon)\n",
"assert train_df.columns.intersection(X_df.columns).tolist() == ['unique_id', 'ds', 'trend', 'seasonal7', 'seasonal28']"
]
}
],
"metadata": {
Expand Down
2 changes: 2 additions & 0 deletions statsforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,8 @@
'statsforecast.models._intervals': ('src/core/models.html#_intervals', 'statsforecast/models.py'),
'statsforecast.models._optimized_ses_forecast': ( 'src/core/models.html#_optimized_ses_forecast',
'statsforecast/models.py'),
'statsforecast.models._predict_mstl_components': ( 'src/core/models.html#_predict_mstl_components',
'statsforecast/models.py'),
'statsforecast.models._predict_mstl_seas': ( 'src/core/models.html#_predict_mstl_seas',
'statsforecast/models.py'),
'statsforecast.models._probability': ('src/core/models.html#_probability', 'statsforecast/models.py'),
Expand Down
16 changes: 9 additions & 7 deletions statsforecast/feature_engineering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# %% auto 0
__all__ = ['mstl_decomposition']

# %% ../nbs/src/feature_engineering.ipynb 2
# %% ../nbs/src/feature_engineering.ipynb 3
from typing import Tuple

import pandas as pd
Expand All @@ -18,9 +18,9 @@

from . import StatsForecast
from .core import _id_as_idx
from .models import MSTL, _predict_mstl_seas
from .models import MSTL, _predict_mstl_components

# %% ../nbs/src/feature_engineering.ipynb 3
# %% ../nbs/src/feature_engineering.ipynb 4
def mstl_decomposition(
df: DataFrame,
model: MSTL,
Expand Down Expand Up @@ -61,14 +61,16 @@ def mstl_decomposition(
train_features = []
future_features = []
df_constructor = type(df)
seas_cols = [c for c in sf.fitted_[0, 0].model_.columns if c.startswith("seasonal")]
for fitted_model in sf.fitted_[:, 0]:
train_features.append(fitted_model.model_[["trend", "seasonal"]])
train_features.append(fitted_model.model_[["trend"] + seas_cols])
seas_comp = _predict_mstl_components(
fitted_model.model_, h, model.season_length
)
future_df = df_constructor(
{
"trend": fitted_model.trend_forecaster.predict(h)["mean"],
"seasonal": _predict_mstl_seas(
fitted_model.model_, h, model.season_length
),
**dict(zip(seas_cols, seas_comp.T)),
}
)
future_features.append(future_df)
Expand Down
10 changes: 7 additions & 3 deletions statsforecast/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4962,7 +4962,7 @@ def forecast(
return res

# %% ../nbs/src/core/models.ipynb 363
def _predict_mstl_seas(mstl_ob, h, season_length):
def _predict_mstl_components(mstl_ob, h, season_length):
seasoncolumns = mstl_ob.filter(regex="seasonal*").columns
nseasons = len(seasoncolumns)
seascomp = np.full((h, nseasons), np.nan)
Expand All @@ -4975,8 +4975,12 @@ def _predict_mstl_seas(mstl_ob, h, season_length):
seascomp[:, i] = np.tile(
mstl_ob[colname].values[-mp:], trunc(1 + (h - 1) / mp)
)[:h]
lastseas = seascomp.sum(axis=1)
return lastseas
return seascomp


def _predict_mstl_seas(mstl_ob, h, season_length):
seascomp = _predict_mstl_components(mstl_ob, h, season_length)
return seascomp.sum(axis=1)

# %% ../nbs/src/core/models.ipynb 364
class MSTL(_TS):
Expand Down

0 comments on commit 3e104b5

Please sign in to comment.