Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix MSTL where trend forecaster supports level #625

Merged
merged 5 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 84 additions & 45 deletions nbs/src/core/models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
"source": [
"#| export\n",
"import warnings\n",
"from inspect import signature\n",
"from math import trunc\n",
"from typing import Any, Dict, List, Optional, Sequence, Tuple, Union\n",
"\n",
Expand Down Expand Up @@ -819,40 +818,50 @@
" cls_.forward(y=x, h=h, fitted=True)['fitted'],\n",
" cls_.predict_in_sample()['fitted'], \n",
" )\n",
" \n",
" def check_dict_equals(dict_1, dict_2):\n",
" if not dict_1.keys() == dict_2.keys():\n",
" return False\n",
" return all(np.array_equal(dict_1[key], dict_2[key], equal_nan=True) for key in dict_1)\n",
" \n",
"\n",
" if test_forward:\n",
" if not check_dict_equals(cls_.predict(h=h), cls_.forward(y=x, h=h)):\n",
" try:\n",
" pd.testing.assert_frame_equal(\n",
" pd.DataFrame(cls_.predict(h=h)),\n",
" pd.DataFrame(cls_.forward(y=x, h=h)),\n",
" )\n",
" except AssertionError:\n",
" raise Exception('predict and forward methods are not equal')\n",
" \n",
" if level is not None:\n",
" fcst_cls = cls_.predict(h=h, level=level)\n",
" fcst_forecast = cls_.forecast(y=x, h=h, level=level)\n",
" if not check_dict_equals(fcst_cls, fcst_forecast):\n",
" fcst_cls = pd.DataFrame(cls_.predict(h=h, level=level))\n",
" fcst_forecast = pd.DataFrame(cls_.forecast(y=x, h=h, level=level))\n",
" try:\n",
" pd.testing.assert_frame_equal(fcst_cls, fcst_forecast)\n",
" except AssertionError:\n",
" raise Exception('predict and forecast methods are not equal with levels')\n",
" \n",
" if test_forward:\n",
" if not check_dict_equals(cls_.predict(h=h, level=level), \n",
" cls_.forward(y=x, h=h, level=level)):\n",
" try:\n",
" pd.testing.assert_frame_equal(\n",
" pd.DataFrame(cls_.predict(h=h, level=level)),\n",
" pd.DataFrame(cls_.forward(y=x, h=h, level=level))\n",
" )\n",
" except AssertionError:\n",
" raise Exception('predict and forward methods are not equal with levels')\n",
" \n",
" if not skip_insample:\n",
" fcst_cls = cls_.predict_in_sample(level=level)\n",
" fcst_cls = pd.DataFrame(cls_.predict_in_sample(level=level))\n",
" fcst_forecast = cls_.forecast(y=x, h=h, level=level, fitted=True)\n",
" fcst_forecast = {key: val for key, val in fcst_forecast.items() if 'fitted' in key}\n",
" if not check_dict_equals(fcst_cls, fcst_forecast):\n",
" fcst_forecast = pd.DataFrame({key: val for key, val in fcst_forecast.items() if 'fitted' in key})\n",
" try:\n",
" pd.testing.assert_frame_equal(fcst_cls, fcst_forecast)\n",
" except AssertionError:\n",
" raise Exception(\n",
" 'predict and forecast methods are not equal with ' \n",
" 'levels for fitted values '\n",
" )\n",
" if test_forward:\n",
" fcst_forward = cls_.forecast(y=x, h=h, level=level, fitted=True)\n",
" fcst_forward = {key: val for key, val in fcst_forward.items() if 'fitted' in key}\n",
" if not check_dict_equals(fcst_cls, fcst_forward):\n",
" fcst_forward = pd.DataFrame({key: val for key, val in fcst_forward.items() if 'fitted' in key})\n",
" try:\n",
" pd.testing.assert_frame_equal(fcst_cls, fcst_forward)\n",
" except AssertionError:\n",
" raise Exception(\n",
" 'predict and forward methods are not equal with ' \n",
" 'levels for fitted values '\n",
Expand Down Expand Up @@ -6041,6 +6050,7 @@
" \n",
" if level is None:\n",
" return res\n",
" level = sorted(level) \n",
" if self.prediction_intervals is not None:\n",
" res = self._add_predict_conformal_intervals(res, level)\n",
" else:\n",
Expand All @@ -6066,6 +6076,7 @@
" \"\"\" \n",
" res = {'fitted': self.model_['fitted']}\n",
" if level is not None:\n",
" level = sorted(level)\n",
" res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n",
" return res\n",
" \n",
Expand Down Expand Up @@ -8945,7 +8956,7 @@
" )\n",
" x_sa = self.model_[['trend', 'remainder']].sum(axis=1).values\n",
" self.trend_forecaster = self.trend_forecaster.new().fit(y=x_sa, X=X)\n",
" self._store_cs(y=y, X=X)\n",
" self._store_cs(y=x_sa, X=X)\n",
" return self\n",
" \n",
" def predict(\n",
Expand All @@ -8971,16 +8982,16 @@
" Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n",
" \"\"\"\n",
" kwargs: Dict[str, Any] = {'h': h, 'X': X}\n",
" if 'level' in signature(self.trend_forecaster.predict).parameters:\n",
" if self.trend_forecaster.prediction_intervals is None:\n",
" kwargs['level'] = level\n",
" res = self.trend_forecaster.predict(**kwargs)\n",
" seas = _predict_mstl_seas(self.model_, h=h, season_length=self.season_length)\n",
" res = {key: val + seas for key, val in res.items()}\n",
" if level is None:\n",
" if level is None or self.trend_forecaster.prediction_intervals is None:\n",
" return res\n",
" level = sorted(level)\n",
" if self.trend_forecaster.prediction_intervals is not None:\n",
" res = self._add_predict_conformal_intervals(res, level)\n",
" res = self.trend_forecaster._add_predict_conformal_intervals(res, level)\n",
" else:\n",
" raise Exception(\n",
" \"You have to instantiate either the trend forecaster class or MSTL class with `prediction_intervals` to calculate them\"\n",
Expand All @@ -9000,11 +9011,7 @@
" forecasts : dict \n",
" Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n",
" \"\"\"\n",
" kwargs = {}\n",
" if 'level' in signature(self.trend_forecaster.predict_in_sample).parameters:\n",
" kwargs['level'] = level\n",
" \n",
" res = self.trend_forecaster.predict_in_sample(**kwargs)\n",
" res = self.trend_forecaster.predict_in_sample(level=level)\n",
" seas = self.model_.filter(regex='seasonal*').sum(axis=1).values\n",
" res = {key: val + seas for key, val in res.items()}\n",
" return res\n",
Expand Down Expand Up @@ -9057,27 +9064,26 @@
" 'X_future': X_future,\n",
" 'fitted': fitted\n",
" }\n",
" if 'level' in signature(self.trend_forecaster.forecast).parameters:\n",
" if fitted or self.trend_forecaster.prediction_intervals is None:\n",
" kwargs['level'] = level\n",
" res = self.trend_forecaster.forecast(**kwargs)\n",
" if level is not None:\n",
" level = sorted(level)\n",
" if self.trend_forecaster.prediction_intervals is not None:\n",
" res = self.trend_forecaster._add_conformal_intervals(fcst=res, y=x_sa, X=X, level=level)\n",
" elif f'lo-{level[0]}' not in res:\n",
" raise Exception(\n",
" \"You have to instantiate either the trend forecaster class or MSTL class with `prediction_intervals` to calculate them\"\n",
" ) \n",
" #reseasonalize results\n",
" seas_h = _predict_mstl_seas(model_, h=h, season_length=self.season_length)\n",
" seas_insample = model_.filter(regex='seasonal*').sum(axis=1).values\n",
" res = {\n",
" key: val + (seas_insample if 'fitted' in key else seas_h) \\\n",
" for key, val in res.items()\n",
" }\n",
" if level is None:\n",
" return res\n",
" level = sorted(level)\n",
" if self.trend_forecaster.prediction_intervals is not None:\n",
" res = self._add_conformal_intervals(fcst=res, y=y, X=X, level=level)\n",
" else:\n",
" raise Exception(\n",
" \"You have to instantiate either the trend forecaster class or MSTL class with `prediction_intervals` to calculate them\"\n",
" )\n",
" return res\n",
" \n",
"\n",
" def forward(\n",
" self,\n",
" y: np.ndarray,\n",
Expand Down Expand Up @@ -9124,20 +9130,20 @@
" 'X_future': X_future,\n",
" 'fitted': fitted\n",
" }\n",
" if 'level' in signature(self.trend_forecaster.forward).parameters:\n",
" if fitted or self.trend_forecaster.prediction_intervals is None:\n",
" kwargs['level'] = level\n",
" res = self.trend_forecaster.forward(**kwargs)\n",
" if level is not None:\n",
" level = sorted(level)\n",
" if self.trend_forecaster.prediction_intervals is not None:\n",
" res = self.trend_forecaster._add_conformal_intervals(fcst=res, y=x_sa, X=X, level=level) \n",
" #reseasonalize results\n",
" seas_h = _predict_mstl_seas(model_, h=h, season_length=self.season_length)\n",
" seas_insample = model_.filter(regex='seasonal*').sum(axis=1).values\n",
" res = {\n",
" key: val + (seas_insample if 'fitted' in key else seas_h) \\\n",
" for key, val in res.items()\n",
" }\n",
" if level is not None:\n",
" level = sorted(level)\n",
" if self.trend_forecaster.prediction_intervals is not None:\n",
" res = self._add_conformal_intervals(fcst=res, y=y, X=X, level=level)\n",
" return res"
]
},
Expand Down Expand Up @@ -9170,6 +9176,39 @@
" test_forward=test_forward)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# intervals with & without conformal\n",
"# For each trend forecaster configuration, `fit().predict()` and `forecast()`\n",
"# are called with the same horizon and levels and their outputs compared.\n",
"\n",
"# trend fcst supports level, use native levels\n",
"mstl_native = MSTL(season_length=12, trend_forecaster=ARIMA(order=(0, 1, 0)))\n",
"res_native_fp = pd.DataFrame(mstl_native.fit(y=ap).predict(h=24, level=[80, 95]))\n",
"res_native_fc = pd.DataFrame(mstl_native.forecast(y=ap, h=24, level=[80, 95]))\n",
"pd.testing.assert_frame_equal(res_native_fp, res_native_fc)\n",
"\n",
"# trend fcst supports level, use conformal\n",
"mstl_conformal = MSTL(\n",
"    season_length=12,\n",
"    trend_forecaster=ARIMA(\n",
"        order=(0, 1, 0),\n",
"        prediction_intervals=ConformalIntervals(h=24),\n",
"    ),\n",
")\n",
"res_conformal_fp = pd.DataFrame(mstl_conformal.fit(y=ap).predict(h=24, level=[80, 95]))\n",
"res_conformal_fc = pd.DataFrame(mstl_conformal.forecast(y=ap, h=24, level=[80, 95]))\n",
"pd.testing.assert_frame_equal(res_conformal_fp, res_conformal_fc)\n",
"# native and conformal outputs are expected to differ; compare the frames\n",
"# built above (`res_*`). Referencing undefined names here would make\n",
"# `test_fail` succeed on a NameError without comparing anything.\n",
"test_fail(lambda: pd.testing.assert_frame_equal(res_native_fp, res_conformal_fp))\n",
"\n",
"# trend fcst doesn't support level\n",
"mstl_bad = MSTL(season_length=12, trend_forecaster=CrostonClassic())\n",
"test_fail(lambda: mstl_bad.fit(y=ap).predict(h=24, level=[80, 95]), contains='prediction_intervals')\n",
"test_fail(lambda: mstl_bad.forecast(y=ap, h=24, level=[80, 95]), contains='prediction_intervals')"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -9180,8 +9219,8 @@
"# conformal prediction\n",
"# define the prediction interval in the trend_forecaster\n",
"trend_forecasters = [\n",
" AutoARIMA(prediction_intervals=ConformalIntervals(h=13, n_windows=2)), \n",
" AutoCES(prediction_intervals=ConformalIntervals(h=13, n_windows=2)), \n",
" AutoARIMA(prediction_intervals=ConformalIntervals(h=13, n_windows=2)),\n",
" AutoCES(prediction_intervals=ConformalIntervals(h=13, n_windows=2)),\n",
"]\n",
"skip_insamples = [False, True]\n",
"test_forwards = [False, True]\n",
Expand Down
Loading
Loading