Skip to content

Commit

Permalink
fix: chunk series in parallel forecast
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Sep 19, 2024
1 parent bb02eeb commit ede9e7a
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 223 deletions.
182 changes: 70 additions & 112 deletions nbs/src/core/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
"import reprlib\n",
"import time\n",
"import warnings\n",
"from collections import defaultdict\n",
"from concurrent.futures import ProcessPoolExecutor, as_completed\n",
"from pathlib import Path\n",
"from typing import Any, Dict, List, Optional, Union\n",
Expand All @@ -100,7 +99,7 @@
"import pandas as pd\n",
"import utilsforecast.processing as ufp\n",
"from fugue.execution.factory import make_execution_engine, try_get_context_execution_engine\n",
"from threadpoolctl import ThreadpoolController, threadpool_limits\n",
"from threadpoolctl import ThreadpoolController\n",
"from tqdm.auto import tqdm\n",
"from triad import conditional_dispatcher\n",
"from utilsforecast.compat import DataFrame, pl_DataFrame, pl_Series\n",
Expand All @@ -124,39 +123,7 @@
" datefmt='%Y-%m-%d %H:%M:%S',\n",
" )\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"_controller = ThreadpoolController()\n",
"\n",
"@_controller.wrap(limits=1)\n",
"def _forecast_serie(h, y, X, X_future, models, fallback_model, level, fitted):\n",
" forecast_res = {}\n",
" fitted_res = {}\n",
" times = {}\n",
" for model in models:\n",
" start = time.perf_counter()\n",
" model_kwargs = dict(h=h, y=y, X=X, X_future=X_future, fitted=fitted)\n",
" if \"level\" in inspect.signature(model.forecast).parameters and level:\n",
" model_kwargs[\"level\"] = level\n",
" try:\n",
" model_res = model.forecast(**model_kwargs)\n",
" except Exception as e:\n",
" if fallback_model is None:\n",
" raise e\n",
" model_res = fallback_model.forecast(**model_kwargs)\n",
" model_name = repr(model)\n",
" times[model_name] = time.perf_counter() - start\n",
" for k, v in model_res.items():\n",
" if k == \"mean\":\n",
" forecast_res[model_name] = v\n",
" elif k.startswith((\"lo\", \"hi\")):\n",
" col_name = f\"{model_name}-{k}\"\n",
" forecast_res[col_name] = v\n",
" elif k == \"fitted\":\n",
" fitted_res[model_name] = v\n",
" elif k.startswith((\"fitted-lo\", \"fitted-hi\")):\n",
" col_name = f'{model_name}-{k.replace(\"fitted-\", \"\")}'\n",
" fitted_res[col_name] = v\n",
" return forecast_res, fitted_res, times"
"_controller = ThreadpoolController()"
]
},
{
Expand Down Expand Up @@ -471,19 +438,19 @@
" def split_fm(self, fm, n_chunks):\n",
" return [fm[idxs] for idxs in np.array_split(range(self.n_groups), n_chunks) if idxs.size]\n",
"\n",
"\n",
" @_controller.wrap(limits=1)\n",
" def _single_threaded_fit(self, models, fallback_model=None):\n",
" with threadpool_limits(limits=1):\n",
" return self.fit(models=models, fallback_model=fallback_model)\n",
" return self.fit(models=models, fallback_model=fallback_model)\n",
"\n",
" @_controller.wrap(limits=1)\n",
" def _single_threaded_predict(self, fm, h, X=None, level=tuple()):\n",
" with threadpool_limits(limits=1):\n",
" return self.predict(fm=fm, h=h, X=X, level=level)\n",
" return self.predict(fm=fm, h=h, X=X, level=level)\n",
"\n",
" @_controller.wrap(limits=1)\n",
" def _single_threaded_fit_predict(self, models, h, X=None, level=tuple()):\n",
" with threadpool_limits(limits=1):\n",
" return self.fit_predict(models=models, h=h, X=X, level=level)\n",
" return self.fit_predict(models=models, h=h, X=X, level=level)\n",
"\n",
" @_controller.wrap(limits=1)\n",
" def _single_threaded_forecast(\n",
" self,\n",
" models,\n",
Expand All @@ -495,18 +462,18 @@
" verbose=False,\n",
" target_col='y',\n",
" ):\n",
" with threadpool_limits(limits=1):\n",
" return self.forecast(\n",
" models=models,\n",
" h=h,\n",
" fallback_model=fallback_model,\n",
" fitted=fitted,\n",
" X=X,\n",
" level=level,\n",
" verbose=verbose,\n",
" target_col=target_col,\n",
" )\n",
" \n",
" return self.forecast(\n",
" models=models,\n",
" h=h,\n",
" fallback_model=fallback_model,\n",
" fitted=fitted,\n",
" X=X,\n",
" level=level,\n",
" verbose=verbose,\n",
" target_col=target_col,\n",
" )\n",
"\n",
" @_controller.wrap(limits=1)\n",
" def _single_threaded_cross_validation(\n",
" self,\n",
" models,\n",
Expand All @@ -521,20 +488,19 @@
" verbose=False,\n",
" target_col='y',\n",
" ):\n",
" with threadpool_limits(limits=1):\n",
" return self.cross_validation(\n",
" models=models,\n",
" h=h,\n",
" test_size=test_size,\n",
" fallback_model=fallback_model,\n",
" step_size=step_size,\n",
" input_size=input_size,\n",
" fitted=fitted,\n",
" level=level,\n",
" refit=refit,\n",
" verbose=verbose,\n",
" target_col=target_col,\n",
" )"
" return self.cross_validation(\n",
" models=models,\n",
" h=h,\n",
" test_size=test_size,\n",
" fallback_model=fallback_model,\n",
" step_size=step_size,\n",
" input_size=input_size,\n",
" fitted=fitted,\n",
" level=level,\n",
" refit=refit,\n",
" verbose=verbose,\n",
" target_col=target_col,\n",
" )"
]
},
{
Expand Down Expand Up @@ -1685,12 +1651,14 @@
" fm = np.vstack([f.get() for f in futures])\n",
" return fm \n",
" \n",
" def _get_gas_Xs(self, X):\n",
" gas = self.ga.split(self.n_jobs)\n",
" def _get_gas_Xs(self, X, tasks_per_job=1):\n",
" n_chunks = min(tasks_per_job * self.n_jobs, self.ga.n_groups)\n",
" gas = self.ga.split(n_chunks)\n",
" if X is not None:\n",
" Xs = X.split(self.n_jobs)\n",
" Xs = X.split(n_chunks)\n",
" else:\n",
" from itertools import repeat\n",
"\n",
" Xs = repeat(None)\n",
" return gas, Xs\n",
" \n",
Expand Down Expand Up @@ -1735,57 +1703,47 @@
" return fm, fcsts, cols\n",
"\n",
" def _forecast_parallel(self, h, fitted, X, level, target_col):\n",
" n_series = self.ga.n_groups\n",
" forecast_res = defaultdict(lambda: np.empty(n_series * h, dtype=self.ga.data.dtype))\n",
" fitted_res = defaultdict(\n",
" lambda: np.empty(self.ga.data.shape[0], dtype=self.ga.data.dtype)\n",
" )\n",
" fitted_res[target_col] = self.ga.data[:, 0]\n",
" future2pos = {}\n",
" times = {repr(m): 0.0 for m in self.models}\n",
" gas, Xs = self._get_gas_Xs(X=X, tasks_per_job=100)\n",
" results = [None] * len(gas)\n",
" with ProcessPoolExecutor(self.n_jobs) as executor:\n",
" for i, serie in enumerate(self.ga):\n",
" y_train = serie[:, 0]\n",
" X_train = serie[:, 1:] if serie.shape[1] > 1 else None\n",
" if X is None:\n",
" X_future = None\n",
" else:\n",
" X_future = X[i]\n",
" future = executor.submit(\n",
" _forecast_serie,\n",
" future2pos = {\n",
" executor.submit(\n",
" ga.forecast,\n",
" h=h,\n",
" y=y_train,\n",
" X=X_train,\n",
" X_future=X_future,\n",
" models=self.models,\n",
" fallback_model=self.fallback_model,\n",
" fitted=fitted,\n",
" X=X,\n",
" level=level,\n",
" )\n",
" future2pos[future] = i\n",
" verbose=False,\n",
" target_col=target_col,\n",
" ): i\n",
" for i, (ga, X) in enumerate(zip(gas, Xs))\n",
" }\n",
" iterable = tqdm(\n",
" as_completed(future2pos), disable=not self.verbose, total=n_series, desc=\"Forecast\"\n",
" ) \n",
" as_completed(future2pos),\n",
" disable=not self.verbose,\n",
" total=len(future2pos),\n",
" desc=\"Forecast\",\n",
" bar_format=\"{l_bar}{bar}| {n_fmt}/{total_fmt} [Elapsed: {elapsed}{postfix}]\",\n",
" )\n",
" for future in iterable:\n",
" i = future2pos[future]\n",
" fcst_idxs = slice(i * h, (i + 1) * h)\n",
" serie_idxs = slice(self.ga.indptr[i], self.ga.indptr[i + 1])\n",
" serie_fcst, serie_fitted, serie_times = future.result()\n",
" for k, v in serie_fcst.items():\n",
" forecast_res[k][fcst_idxs] = v\n",
" for k, v in serie_fitted.items():\n",
" fitted_res[k][serie_idxs] = v\n",
" for model_name, model_time in serie_times.items():\n",
" times[model_name] += model_time\n",
" return {\n",
" 'cols': list(forecast_res.keys()),\n",
" 'forecasts': np.hstack([v[:, None] for v in forecast_res.values()]),\n",
" 'fitted': {\n",
" 'cols': list(fitted_res.keys()),\n",
" 'values': np.hstack([v[:, None] for v in fitted_res.values()]),\n",
" results[i] = future.result()\n",
" result = {\n",
" 'cols': results[0]['cols'],\n",
" 'forecasts': np.vstack([r['forecasts'] for r in results]),\n",
" 'times': {\n",
" m: sum(r['times'][m] for r in results)\n",
" for m in [repr(m) for m in self.models]\n",
" },\n",
" 'times': times,\n",
" } \n",
" }\n",
" if fitted:\n",
" result['fitted'] = {\n",
" 'cols': results[0]['fitted']['cols'],\n",
" 'values': np.hstack([r['fitted']['values'] for r in results]),\n",
" }\n",
" return result\n",
"\n",
" def _cross_validation_parallel(self, h, test_size, step_size, input_size, fitted, level, refit, target_col):\n",
" #create elements for each core\n",
Expand Down
1 change: 0 additions & 1 deletion python/statsforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@
'statsforecast/core.py'),
'statsforecast.core._StatsForecast.save': ( 'src/core/core.html#_statsforecast.save',
'statsforecast/core.py'),
'statsforecast.core._forecast_serie': ('src/core/core.html#_forecast_serie', 'statsforecast/core.py'),
'statsforecast.core._get_n_jobs': ('src/core/core.html#_get_n_jobs', 'statsforecast/core.py'),
'statsforecast.core._id_as_idx': ('src/core/core.html#_id_as_idx', 'statsforecast/core.py'),
'statsforecast.core._maybe_warn_sort_df': ( 'src/core/core.html#_maybe_warn_sort_df',
Expand Down
Loading

0 comments on commit ede9e7a

Please sign in to comment.