
Commit

Add option to support user defined optimizer for NeuralForecast Models (#901)

* update configure_optimizers

* update model parameter options

* update nbs notebook

* improve test

Fix failure

Fix doc failure

improve test

* Review: Consider 'lr' signature, ignoring the passed-in parameter in optimizer_kwargs

update2

Add test

* Fix failure, simplify test

fix failure

* Fix failure

* Review: Simplify test with reduced models, max_steps

* Change the max_steps from 5 to 1

---------

Co-authored-by: Cristian Challu <cristiani.challu@gmail.com>
JQGoh and cchallu authored Mar 1, 2024
1 parent e1e4474 commit ae91929
Showing 45 changed files with 442 additions and 6 deletions.
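In practice, the two new constructor arguments are passed straight to any NeuralForecast model. The snippet below is a minimal usage sketch, not taken from the diff: it assumes the AirPassengersPanel sample data from neuralforecast.utils, uses illustrative hyperparameter values, and picks Adadelta merely as a stand-in for any torch.optim.Optimizer subclass.

import torch
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.utils import AirPassengersPanel  # small sample dataset, used here only for illustration

# Pass the optimizer *class* (not an instance) plus its keyword arguments.
# 'lr' is still taken from the model's learning_rate (see configure_optimizers in the diff below).
model = NHITS(
    h=12,
    input_size=24,
    max_steps=100,
    learning_rate=1e-3,
    optimizer=torch.optim.Adadelta,      # any subclass of torch.optim.Optimizer
    optimizer_kwargs={"rho": 0.45},      # extra arguments forwarded to the optimizer
)

nf = NeuralForecast(models=[model], freq="M")
nf.fit(AirPassengersPanel)
forecasts = nf.predict()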
20 changes: 19 additions & 1 deletion nbs/common.base_multivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@
"outputs": [],
"source": [
"#| export\n",
"import inspect\n",
"import random\n",
"import warnings\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import pytorch_lightning as pl\n",
"from copy import deepcopy\n",
"from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
"\n",
"from neuralforecast.common._scalers import TemporalNorm\n",
Expand Down Expand Up @@ -107,6 +109,8 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(BaseMultivariate, self).__init__()\n",
"\n",
Expand Down Expand Up @@ -140,6 +144,11 @@
" self.early_stop_patience_steps = early_stop_patience_steps\n",
" self.val_check_steps = val_check_steps\n",
" self.step_size = step_size\n",
" # custom optimizer defined by user\n",
" if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
" raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
" self.optimizer = optimizer\n",
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {} \n",
"\n",
" # Scaler\n",
" self.scaler = TemporalNorm(scaler_type=scaler_type, dim=2) # Time dimension is in the second axis\n",
Expand Down Expand Up @@ -202,7 +211,16 @@
" random.seed(self.random_seed)\n",
" \n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" if self.optimizer:\n",
" optimizer_signature = inspect.signature(self.optimizer)\n",
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
" if 'lr' in optimizer_signature.parameters:\n",
" if 'lr' in optimizer_kwargs:\n",
" warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
" optimizer_kwargs['lr'] = self.learning_rate\n",
" optimizer = self.optimizer(params=self.parameters(), **self.optimizer_kwargs)\n",
" else:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) \n",
" scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=optimizer,\n",
" step_size=self.lr_decay_steps,\n",
" gamma=0.5),\n",
Expand Down
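The branch added above also settles how 'lr' is resolved: whenever the user-supplied optimizer class accepts an lr parameter, the model's learning_rate wins over anything placed in optimizer_kwargs. The following standalone sketch mirrors that selection logic outside of any Lightning module; the helper name build_optimizer and the toy nn.Linear model are illustrative, not part of the library.

import inspect
import warnings
from copy import deepcopy

import torch
import torch.nn as nn

def build_optimizer(model_params, learning_rate,
                    user_optimizer=None, user_optimizer_kwargs=None):
    # Hypothetical helper mirroring configure_optimizers above.
    if user_optimizer is not None:
        if not issubclass(user_optimizer, torch.optim.Optimizer):
            raise TypeError("optimizer is not a valid subclass of torch.optim.Optimizer")
        kwargs = deepcopy(user_optimizer_kwargs or {})
        # If the optimizer accepts 'lr', the model's learning rate overrides optimizer_kwargs.
        if 'lr' in inspect.signature(user_optimizer).parameters:
            if 'lr' in kwargs:
                warnings.warn("ignoring learning rate passed in optimizer_kwargs, "
                              "using the model's learning rate")
            kwargs['lr'] = learning_rate
        return user_optimizer(params=model_params, **kwargs)
    # Default behavior is unchanged: Adam with the model's learning rate.
    return torch.optim.Adam(model_params, lr=learning_rate)

# Example: the 'lr' in optimizer_kwargs is ignored in favor of learning_rate=0.001.
toy = nn.Linear(4, 1)
opt = build_optimizer(toy.parameters(), 0.001,
                      user_optimizer=torch.optim.Adadelta,
                      user_optimizer_kwargs={"lr": 0.8, "rho": 0.45})
print(type(opt).__name__, opt.defaults["lr"])  # Adadelta 0.001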
20 changes: 19 additions & 1 deletion nbs/common.base_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,15 @@
"outputs": [],
"source": [
"#| export\n",
"import inspect\n",
"import random\n",
"import warnings\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import pytorch_lightning as pl\n",
"from copy import deepcopy\n",
"from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
"\n",
"from neuralforecast.common._scalers import TemporalNorm\n",
Expand Down Expand Up @@ -113,6 +115,8 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(BaseRecurrent, self).__init__()\n",
"\n",
Expand Down Expand Up @@ -154,6 +158,11 @@
" self.lr_decay_steps = max(max_steps // self.num_lr_decays, 1) if self.num_lr_decays > 0 else 10e7\n",
" self.early_stop_patience_steps = early_stop_patience_steps\n",
" self.val_check_steps = val_check_steps\n",
" # custom optimizer defined by user\n",
" if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
" raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
" self.optimizer = optimizer\n",
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {} \n",
"\n",
" # Variables\n",
" self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
Expand Down Expand Up @@ -214,7 +223,16 @@
" random.seed(self.random_seed)\n",
" \n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" if self.optimizer:\n",
" optimizer_signature = inspect.signature(self.optimizer)\n",
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
" if 'lr' in optimizer_signature.parameters:\n",
" if 'lr' in optimizer_kwargs:\n",
" warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
" optimizer_kwargs['lr'] = self.learning_rate\n",
" optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
" else:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=optimizer,\n",
" step_size=self.lr_decay_steps,\n",
" gamma=0.5),\n",
Expand Down
20 changes: 19 additions & 1 deletion nbs/common.base_windows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,15 @@
"outputs": [],
"source": [
"#| export\n",
"import inspect\n",
"import random\n",
"import warnings\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import pytorch_lightning as pl\n",
"from copy import deepcopy\n",
"from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
"\n",
"from neuralforecast.common._scalers import TemporalNorm\n",
Expand Down Expand Up @@ -118,6 +120,8 @@
" drop_last_loader=False,\n",
" random_seed=1,\n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(BaseWindows, self).__init__()\n",
"\n",
Expand Down Expand Up @@ -164,6 +168,11 @@
" self.val_check_steps = val_check_steps\n",
" self.windows_batch_size = windows_batch_size\n",
" self.step_size = step_size\n",
" # custom optimizer defined by user\n",
" if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
" raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
" self.optimizer = optimizer\n",
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}\n",
"\n",
" # Variables\n",
" self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
Expand Down Expand Up @@ -228,7 +237,16 @@
" random.seed(self.random_seed)\n",
" \n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" if self.optimizer:\n",
" optimizer_signature = inspect.signature(self.optimizer)\n",
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
" if 'lr' in optimizer_signature.parameters:\n",
" if 'lr' in optimizer_kwargs:\n",
" warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
" optimizer_kwargs['lr'] = self.learning_rate\n",
" optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
" else:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=optimizer,\n",
" step_size=self.lr_decay_steps,\n",
" gamma=0.5),\n",
Expand Down
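All three base classes apply the same up-front validation: optimizer must be an optimizer class, i.e. a subclass of torch.optim.Optimizer, not an instance or an arbitrary module. A quick illustration of the check, with SGD chosen purely as an example:

import torch

issubclass(torch.optim.SGD, torch.optim.Optimizer)   # True  -> accepted
issubclass(torch.nn.Module, torch.optim.Optimizer)   # False -> the constructor raises TypeError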
79 changes: 79 additions & 0 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2433,6 +2433,85 @@
"test_pl_df3[0, \"static_1\"] = np.nan\n",
"test_fail(lambda: nf.fit(pl_df, static_df=test_pl_df3), contains=\"Found missing values in ['static_1']\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "859a474c",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test customized optimizer behavior such that the user defiend optimizer result should differ from default\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" # default optimizer is based on Adam\n",
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" nf.fit(AirPassengersPanel_train)\n",
" default_optimizer_predict = nf.predict()\n",
" mean = default_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
"\n",
" # using a customized optimizer\n",
" params.update({\n",
" \"optimizer\": torch.optim.Adadelta,\n",
" \"optimizer_kwargs\": {\"rho\": 0.45}, \n",
" })\n",
" models2 = [nf_model(**params)]\n",
" nf2 = NeuralForecast(models=models2, freq='M')\n",
" nf2.fit(AirPassengersPanel_train)\n",
" customized_optimizer_predict = nf2.predict()\n",
" mean2 = customized_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
" assert mean2 != mean"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3db3fe1e",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test that if the user-defined optimizer is not a subclass of torch.optim.optimizer, failed with exception\n",
"# tests cover different types of base classes such as basewindows, baserecurrent, basemultivariate\n",
"test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
"test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
"test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d908240f",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test that if we pass in \"lr\" parameter, we expect warning and it ignores the passed in 'lr' parameter\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\n",
" \"h\": 12, \n",
" \"input_size\": 24, \n",
" \"max_steps\": 1, \n",
" \"optimizer\": torch.optim.Adadelta, \n",
" \"optimizer_kwargs\": {\"lr\": 0.8, \"rho\": 0.45}\n",
" }\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
" warnings.simplefilter('always', UserWarning)\n",
" nf.fit(AirPassengersPanel_train)\n",
" assert any(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\" in str(w.message) for w in issued_warnings)"
]
}
],
"metadata": {
Expand Down
6 changes: 6 additions & 0 deletions nbs/models.autoformer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
" `alias`: str, optional, Custom name of the model.<br>\n",
" `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"\t*References*<br>\n",
Expand Down Expand Up @@ -530,6 +532,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
" optimizer = None,\n",
" optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(Autoformer, self).__init__(h=h,\n",
" input_size=input_size,\n",
Expand All @@ -554,6 +558,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
Expand Down
6 changes: 6 additions & 0 deletions nbs/models.deepar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
" `alias`: str, optional, Custom name of the model.<br>\n",
" `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
"\n",
" **References**<br>\n",
Expand Down Expand Up @@ -237,6 +239,8 @@
" random_seed: int = 1,\n",
" num_workers_loader = 0,\n",
" drop_last_loader = False,\n",
" optimizer = None,\n",
" optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
"\n",
" # DeepAR does not support historic exogenous variables\n",
Expand Down Expand Up @@ -279,6 +283,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" self.horizon_backup = self.h # Used because h=0 during training\n",
Expand Down
6 changes: 6 additions & 0 deletions nbs/models.dilated_rnn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
" `alias`: str, optional, Custom name of the model.<br>\n",
" `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
" \"\"\"\n",
" # Class attributes\n",
Expand Down Expand Up @@ -422,6 +424,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
" optimizer = None,\n",
" optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(DilatedRNN, self).__init__(\n",
" h=h,\n",
Expand All @@ -443,6 +447,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
Expand Down
6 changes: 6 additions & 0 deletions nbs/models.dlinear.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
" `alias`: str, optional, Custom name of the model.<br>\n",
" `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"\t*References*<br>\n",
Expand Down Expand Up @@ -195,6 +197,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
" optimizer = None,\n",
" optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(DLinear, self).__init__(h=h,\n",
" input_size=input_size,\n",
Expand All @@ -219,6 +223,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
Expand Down