diff --git a/nbs/common.base_multivariate.ipynb b/nbs/common.base_multivariate.ipynb
index 7e0eac5b6..1196a8694 100644
--- a/nbs/common.base_multivariate.ipynb
+++ b/nbs/common.base_multivariate.ipynb
@@ -54,6 +54,7 @@
"outputs": [],
"source": [
"#| export\n",
+ "import inspect\n",
"import random\n",
"import warnings\n",
"\n",
@@ -61,6 +62,7 @@
"import torch\n",
"import torch.nn as nn\n",
"import pytorch_lightning as pl\n",
+ "from copy import deepcopy\n",
"from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
"\n",
"from neuralforecast.common._scalers import TemporalNorm\n",
@@ -107,6 +109,8 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
+ " optimizer=None,\n",
+ " optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(BaseMultivariate, self).__init__()\n",
"\n",
@@ -140,6 +144,11 @@
" self.early_stop_patience_steps = early_stop_patience_steps\n",
" self.val_check_steps = val_check_steps\n",
" self.step_size = step_size\n",
+ " # custom optimizer defined by user\n",
+ " if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
+ " raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
+ " self.optimizer = optimizer\n",
+ " self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}\n",
"\n",
" # Scaler\n",
" self.scaler = TemporalNorm(scaler_type=scaler_type, dim=2) # Time dimension is in the second axis\n",
@@ -202,7 +211,16 @@
" random.seed(self.random_seed)\n",
" \n",
" def configure_optimizers(self):\n",
- " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
+ " if self.optimizer:\n",
+ " optimizer_signature = inspect.signature(self.optimizer)\n",
+ " optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
+ " if 'lr' in optimizer_signature.parameters:\n",
+ " if 'lr' in optimizer_kwargs:\n",
+ " warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
+ " optimizer_kwargs['lr'] = self.learning_rate\n",
+ " optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
+ " else:\n",
+ " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=optimizer,\n",
" step_size=self.lr_decay_steps,\n",
" gamma=0.5),\n",
diff --git a/nbs/common.base_recurrent.ipynb b/nbs/common.base_recurrent.ipynb
index 23ce2add0..600b6c5bd 100644
--- a/nbs/common.base_recurrent.ipynb
+++ b/nbs/common.base_recurrent.ipynb
@@ -60,6 +60,7 @@
"outputs": [],
"source": [
"#| export\n",
+ "import inspect\n",
"import random\n",
"import warnings\n",
"\n",
@@ -67,6 +68,7 @@
"import torch\n",
"import torch.nn as nn\n",
"import pytorch_lightning as pl\n",
+ "from copy import deepcopy\n",
"from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
"\n",
"from neuralforecast.common._scalers import TemporalNorm\n",
@@ -113,6 +115,8 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
+ " optimizer=None,\n",
+ " optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(BaseRecurrent, self).__init__()\n",
"\n",
@@ -154,6 +158,11 @@
" self.lr_decay_steps = max(max_steps // self.num_lr_decays, 1) if self.num_lr_decays > 0 else 10e7\n",
" self.early_stop_patience_steps = early_stop_patience_steps\n",
" self.val_check_steps = val_check_steps\n",
+ " # custom optimizer defined by user\n",
+ " if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
+ " raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
+ " self.optimizer = optimizer\n",
+ " self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}\n",
"\n",
" # Variables\n",
" self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
@@ -214,7 +223,16 @@
" random.seed(self.random_seed)\n",
" \n",
" def configure_optimizers(self):\n",
- " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
+ " if self.optimizer:\n",
+ " optimizer_signature = inspect.signature(self.optimizer)\n",
+ " optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
+ " if 'lr' in optimizer_signature.parameters:\n",
+ " if 'lr' in optimizer_kwargs:\n",
+ " warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
+ " optimizer_kwargs['lr'] = self.learning_rate\n",
+ " optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
+ " else:\n",
+ " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=optimizer,\n",
" step_size=self.lr_decay_steps,\n",
" gamma=0.5),\n",
diff --git a/nbs/common.base_windows.ipynb b/nbs/common.base_windows.ipynb
index cc728b408..d80e87114 100644
--- a/nbs/common.base_windows.ipynb
+++ b/nbs/common.base_windows.ipynb
@@ -60,6 +60,7 @@
"outputs": [],
"source": [
"#| export\n",
+ "import inspect\n",
"import random\n",
"import warnings\n",
"\n",
@@ -67,6 +68,7 @@
"import torch\n",
"import torch.nn as nn\n",
"import pytorch_lightning as pl\n",
+ "from copy import deepcopy\n",
"from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
"\n",
"from neuralforecast.common._scalers import TemporalNorm\n",
@@ -118,6 +120,8 @@
" drop_last_loader=False,\n",
" random_seed=1,\n",
" alias=None,\n",
+ " optimizer=None,\n",
+ " optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(BaseWindows, self).__init__()\n",
"\n",
@@ -164,6 +168,11 @@
" self.val_check_steps = val_check_steps\n",
" self.windows_batch_size = windows_batch_size\n",
" self.step_size = step_size\n",
+ " # custom optimizer defined by user\n",
+ " if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
+ " raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
+ " self.optimizer = optimizer\n",
+ " self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}\n",
"\n",
" # Variables\n",
" self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
@@ -228,7 +237,16 @@
" random.seed(self.random_seed)\n",
" \n",
" def configure_optimizers(self):\n",
- " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
+ " if self.optimizer:\n",
+ " optimizer_signature = inspect.signature(self.optimizer)\n",
+ " optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
+ " if 'lr' in optimizer_signature.parameters:\n",
+ " if 'lr' in optimizer_kwargs:\n",
+ " warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
+ " optimizer_kwargs['lr'] = self.learning_rate\n",
+ " optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
+ " else:\n",
+ " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=optimizer,\n",
" step_size=self.lr_decay_steps,\n",
" gamma=0.5),\n",
diff --git a/nbs/core.ipynb b/nbs/core.ipynb
index f03ca60b2..ce6d5f183 100644
--- a/nbs/core.ipynb
+++ b/nbs/core.ipynb
@@ -2433,6 +2433,85 @@
"test_pl_df3[0, \"static_1\"] = np.nan\n",
"test_fail(lambda: nf.fit(pl_df, static_df=test_pl_df3), contains=\"Found missing values in ['static_1']\")"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "859a474c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# test customized optimizer behavior: results from a user-defined optimizer should differ from the default\n",
+ "# tests cover models built on the different base classes: BaseWindows, BaseRecurrent, BaseMultivariate\n",
+ "\n",
+ "for nf_model in [NHITS, RNN, StemGNN]:\n",
+ " # the default optimizer is Adam\n",
+ " params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n",
+ " if nf_model.__name__ == \"StemGNN\":\n",
+ " params.update({\"n_series\": 2})\n",
+ " models = [nf_model(**params)]\n",
+ " nf = NeuralForecast(models=models, freq='M')\n",
+ " nf.fit(AirPassengersPanel_train)\n",
+ " default_optimizer_predict = nf.predict()\n",
+ " mean = default_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
+ "\n",
+ " # using a customized optimizer\n",
+ " params.update({\n",
+ " \"optimizer\": torch.optim.Adadelta,\n",
+ " \"optimizer_kwargs\": {\"rho\": 0.45}, \n",
+ " })\n",
+ " models2 = [nf_model(**params)]\n",
+ " nf2 = NeuralForecast(models=models2, freq='M')\n",
+ " nf2.fit(AirPassengersPanel_train)\n",
+ " customized_optimizer_predict = nf2.predict()\n",
+ " mean2 = customized_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
+ " assert mean2 != mean"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3db3fe1e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# test that a user-defined optimizer that is not a subclass of torch.optim.Optimizer fails with an exception\n",
+ "# tests cover the different base classes: BaseWindows, BaseRecurrent, BaseMultivariate\n",
+ "test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
+ "test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
+ "test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d908240f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# test that passing 'lr' in optimizer_kwargs raises a warning and the value is ignored in favor of the model's learning rate\n",
+ "# tests cover models built on the different base classes: BaseWindows, BaseRecurrent, BaseMultivariate\n",
+ "\n",
+ "for nf_model in [NHITS, RNN, StemGNN]:\n",
+ " params = {\n",
+ " \"h\": 12, \n",
+ " \"input_size\": 24, \n",
+ " \"max_steps\": 1, \n",
+ " \"optimizer\": torch.optim.Adadelta, \n",
+ " \"optimizer_kwargs\": {\"lr\": 0.8, \"rho\": 0.45}\n",
+ " }\n",
+ " if nf_model.__name__ == \"StemGNN\":\n",
+ " params.update({\"n_series\": 2})\n",
+ " models = [nf_model(**params)]\n",
+ " nf = NeuralForecast(models=models, freq='M')\n",
+ " with warnings.catch_warnings(record=True) as issued_warnings:\n",
+ " warnings.simplefilter('always', UserWarning)\n",
+ " nf.fit(AirPassengersPanel_train)\n",
+ " assert any(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\" in str(w.message) for w in issued_warnings)"
+ ]
}
],
"metadata": {
diff --git a/nbs/models.autoformer.ipynb b/nbs/models.autoformer.ipynb
index 478903365..27bd6a1aa 100644
--- a/nbs/models.autoformer.ipynb
+++ b/nbs/models.autoformer.ipynb
@@ -488,6 +488,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
"\t*References*
\n",
@@ -530,6 +532,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(Autoformer, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -554,6 +558,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
diff --git a/nbs/models.deepar.ipynb b/nbs/models.deepar.ipynb
index 8117416e9..adeee0bdd 100644
--- a/nbs/models.deepar.ipynb
+++ b/nbs/models.deepar.ipynb
@@ -197,6 +197,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
" **References**
\n",
@@ -237,6 +239,8 @@
" random_seed: int = 1,\n",
" num_workers_loader = 0,\n",
" drop_last_loader = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
"\n",
" # DeepAR does not support historic exogenous variables\n",
@@ -279,6 +283,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" self.horizon_backup = self.h # Used because h=0 during training\n",
diff --git a/nbs/models.dilated_rnn.ipynb b/nbs/models.dilated_rnn.ipynb
index 12c418103..4970ee48e 100644
--- a/nbs/models.dilated_rnn.ipynb
+++ b/nbs/models.dilated_rnn.ipynb
@@ -390,6 +390,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
" \"\"\"\n",
" # Class attributes\n",
@@ -422,6 +424,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(DilatedRNN, self).__init__(\n",
" h=h,\n",
@@ -443,6 +447,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
diff --git a/nbs/models.dlinear.ipynb b/nbs/models.dlinear.ipynb
index eb899300e..f95aec9cf 100644
--- a/nbs/models.dlinear.ipynb
+++ b/nbs/models.dlinear.ipynb
@@ -162,6 +162,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
"\t*References*
\n",
@@ -195,6 +197,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(DLinear, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -219,6 +223,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
diff --git a/nbs/models.fedformer.ipynb b/nbs/models.fedformer.ipynb
index 9fcce237c..c18503ac9 100644
--- a/nbs/models.fedformer.ipynb
+++ b/nbs/models.fedformer.ipynb
@@ -477,6 +477,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
" \"\"\"\n",
@@ -518,6 +520,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer=None,\n",
+ " optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(FEDformer, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -541,6 +545,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
" # Architecture\n",
" self.futr_input_size = len(self.futr_exog_list)\n",
diff --git a/nbs/models.gru.ipynb b/nbs/models.gru.ipynb
index dab5fa84d..e52f2f29d 100644
--- a/nbs/models.gru.ipynb
+++ b/nbs/models.gru.ipynb
@@ -124,6 +124,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
" \"\"\"\n",
" # Class attributes\n",
@@ -157,6 +159,8 @@
" random_seed=1,\n",
" num_workers_loader=0,\n",
" drop_last_loader=False,\n",
+ " optimizer=None,\n",
+ " optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(GRU, self).__init__(\n",
" h=h,\n",
@@ -178,6 +182,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
diff --git a/nbs/models.informer.ipynb b/nbs/models.informer.ipynb
index e05756e50..2fc393642 100644
--- a/nbs/models.informer.ipynb
+++ b/nbs/models.informer.ipynb
@@ -297,6 +297,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
"\t*References*
\n",
@@ -339,6 +341,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer=None,\n",
+ " optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(Informer, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -363,6 +367,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
diff --git a/nbs/models.lstm.ipynb b/nbs/models.lstm.ipynb
index 5d70fca27..7c3c6ea6d 100644
--- a/nbs/models.lstm.ipynb
+++ b/nbs/models.lstm.ipynb
@@ -122,6 +122,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
" \"\"\"\n",
" # Class attributes\n",
@@ -154,6 +156,8 @@
" random_seed = 1,\n",
" num_workers_loader = 0,\n",
" drop_last_loader = False,\n",
+ " optimizer=None,\n",
+ " optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(LSTM, self).__init__(\n",
" h=h,\n",
@@ -175,6 +179,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
diff --git a/nbs/models.mlp.ipynb b/nbs/models.mlp.ipynb
index 8a4f22bcb..76db09f1a 100644
--- a/nbs/models.mlp.ipynb
+++ b/nbs/models.mlp.ipynb
@@ -114,6 +114,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
" \"\"\"\n",
" # Class attributes\n",
@@ -145,6 +147,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
"\n",
" # Inherit BaseWindows class\n",
@@ -171,6 +175,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
diff --git a/nbs/models.nbeats.ipynb b/nbs/models.nbeats.ipynb
index f154acfbe..2e9975b0e 100644
--- a/nbs/models.nbeats.ipynb
+++ b/nbs/models.nbeats.ipynb
@@ -270,6 +270,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
" **References:**
\n",
@@ -307,6 +309,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
"\n",
" # Inherit BaseWindows class\n",
@@ -329,6 +333,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
diff --git a/nbs/models.nbeatsx.ipynb b/nbs/models.nbeatsx.ipynb
index 10d59d948..42042954d 100644
--- a/nbs/models.nbeatsx.ipynb
+++ b/nbs/models.nbeatsx.ipynb
@@ -413,6 +413,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
" **References:**
\n",
@@ -456,6 +458,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer=None,\n",
+ " optimizer_kwargs=None,\n",
" **trainer_kwargs,\n",
" ):\n",
" # Protect horizon collapsed seasonality and trend NBEATSx-i basis\n",
@@ -488,6 +492,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
diff --git a/nbs/models.nhits.ipynb b/nbs/models.nhits.ipynb
index 3d9dbba74..f00cd6c08 100644
--- a/nbs/models.nhits.ipynb
+++ b/nbs/models.nhits.ipynb
@@ -304,6 +304,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
" **References:**
\n",
@@ -347,6 +349,8 @@
" random_seed: int = 1,\n",
" num_workers_loader = 0,\n",
" drop_last_loader = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
"\n",
" # Inherit BaseWindows class\n",
@@ -373,6 +377,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
diff --git a/nbs/models.patchtst.ipynb b/nbs/models.patchtst.ipynb
index 40d3ced3e..09d52ec9c 100644
--- a/nbs/models.patchtst.ipynb
+++ b/nbs/models.patchtst.ipynb
@@ -715,6 +715,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
" **References:**
\n",
@@ -764,6 +766,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(PatchTST, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -788,6 +792,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs) \n",
" # Asserts\n",
" if stat_exog_list is not None:\n",
diff --git a/nbs/models.rnn.ipynb b/nbs/models.rnn.ipynb
index 6185bc90c..4ea6d92f1 100644
--- a/nbs/models.rnn.ipynb
+++ b/nbs/models.rnn.ipynb
@@ -125,7 +125,9 @@
" `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
\n",
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
 " `alias`: str, optional, Custom name of the model.

\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).

\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
" \"\"\"\n",
" # Class attributes\n",
@@ -159,6 +162,8 @@
" random_seed=1,\n",
" num_workers_loader=0,\n",
" drop_last_loader=False,\n",
+ " optimizer=None,\n",
+ " optimizer_kwargs=None,\n",
" **trainer_kwargs):\n",
" super(RNN, self).__init__(\n",
" h=h,\n",
@@ -180,6 +185,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
diff --git a/nbs/models.stemgnn.ipynb b/nbs/models.stemgnn.ipynb
index d707063ca..9df803157 100644
--- a/nbs/models.stemgnn.ipynb
+++ b/nbs/models.stemgnn.ipynb
@@ -188,6 +188,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
" \"\"\"\n",
" # Class attributes\n",
@@ -217,6 +219,8 @@
" random_seed: int = 1,\n",
" num_workers_loader = 0,\n",
" drop_last_loader = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
"\n",
" # Inherit BaseMultivariate class\n",
@@ -239,6 +243,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Exogenous variables\n",
diff --git a/nbs/models.tcn.ipynb b/nbs/models.tcn.ipynb
index 3dcf43003..f44072e1e 100644
--- a/nbs/models.tcn.ipynb
+++ b/nbs/models.tcn.ipynb
@@ -128,6 +128,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
" \"\"\"\n",
" # Class attributes\n",
@@ -160,6 +162,8 @@
" random_seed: int = 1,\n",
" num_workers_loader = 0,\n",
" drop_last_loader = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(TCN, self).__init__(\n",
" h=h,\n",
@@ -181,6 +185,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs\n",
" )\n",
"\n",
diff --git a/nbs/models.tft.ipynb b/nbs/models.tft.ipynb
index 8839ac243..56e60d99b 100644
--- a/nbs/models.tft.ipynb
+++ b/nbs/models.tft.ipynb
@@ -678,6 +678,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
" **References:**
\n",
@@ -715,6 +717,8 @@
" num_workers_loader = 0,\n",
" drop_last_loader = False,\n",
" random_seed: int = 1,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs\n",
" ):\n",
"\n",
@@ -738,6 +742,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
" self.example_length = input_size + h\n",
"\n",
diff --git a/nbs/models.timesnet.ipynb b/nbs/models.timesnet.ipynb
index ddc21bce3..09c4de132 100644
--- a/nbs/models.timesnet.ipynb
+++ b/nbs/models.timesnet.ipynb
@@ -257,6 +257,10 @@
" Workers to be used by `TimeSeriesDataLoader`.\n",
" drop_last_loader : bool (default=False)\n",
" If True `TimeSeriesDataLoader` drops last non-full batch.\n",
+ " optimizer : Subclass of 'torch.optim.Optimizer', optional (default=None)\n",
+ " User specified optimizer instead of the default choice (Adam).\n",
+ " optimizer_kwargs : dict, optional (default=None)\n",
+ " Keyword arguments used by the user specified `optimizer`.\n",
" **trainer_kwargs\n",
" Keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer)\n",
"\n",
@@ -297,6 +301,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(TimesNet, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -321,6 +327,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
diff --git a/nbs/models.vanillatransformer.ipynb b/nbs/models.vanillatransformer.ipynb
index 9689500e1..242994f81 100644
--- a/nbs/models.vanillatransformer.ipynb
+++ b/nbs/models.vanillatransformer.ipynb
@@ -195,6 +195,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
" `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.

\n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
"\n",
"\t*References*
\n",
@@ -234,6 +236,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(VanillaTransformer, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -257,6 +261,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
diff --git a/neuralforecast/common/_base_multivariate.py b/neuralforecast/common/_base_multivariate.py
index be891fcfa..490f458bd 100644
--- a/neuralforecast/common/_base_multivariate.py
+++ b/neuralforecast/common/_base_multivariate.py
@@ -4,6 +4,7 @@
__all__ = ['BaseMultivariate']
# %% ../../nbs/common.base_multivariate.ipynb 5
+import inspect
import random
import warnings
@@ -11,6 +12,7 @@
import torch
import torch.nn as nn
import pytorch_lightning as pl
+from copy import deepcopy
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from ._scalers import TemporalNorm
@@ -52,6 +54,8 @@ def __init__(
drop_last_loader=False,
random_seed=1,
alias=None,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs,
):
super(BaseMultivariate, self).__init__()
@@ -88,6 +92,13 @@ def __init__(
self.early_stop_patience_steps = early_stop_patience_steps
self.val_check_steps = val_check_steps
self.step_size = step_size
+ # custom optimizer defined by user
+ if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):
+ raise TypeError(
+ "optimizer is not a valid subclass of torch.optim.Optimizer"
+ )
+ self.optimizer = optimizer
+ self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}
# Scaler
self.scaler = TemporalNorm(
@@ -151,7 +162,20 @@ def on_fit_start(self):
random.seed(self.random_seed)
def configure_optimizers(self):
- optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+ if self.optimizer:
+ optimizer_signature = inspect.signature(self.optimizer)
+ optimizer_kwargs = deepcopy(self.optimizer_kwargs)
+ if "lr" in optimizer_signature.parameters:
+ if "lr" in optimizer_kwargs:
+ warnings.warn(
+ "ignoring learning rate passed in optimizer_kwargs, using the model's learning rate"
+ )
+ optimizer_kwargs["lr"] = self.learning_rate
+ optimizer = self.optimizer(
+ params=self.parameters(), **optimizer_kwargs
+ )
+ else:
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = {
"scheduler": torch.optim.lr_scheduler.StepLR(
optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5
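This `configure_optimizers` override is the heart of the change: a user-supplied `torch.optim.Optimizer` subclass replaces the default Adam, and the model's `learning_rate` takes precedence over any `lr` in `optimizer_kwargs`, so the `StepLR` scheduler stays consistent with what was configured. A minimal standalone sketch of that resolution logic, assuming only PyTorch (`resolve_optimizer` is an illustrative name, not part of the codebase):

```python
import inspect
import warnings
from copy import deepcopy

import torch


def resolve_optimizer(params, optimizer_cls, optimizer_kwargs, learning_rate):
    """Sketch of the resolution logic used by configure_optimizers above."""
    # No custom class given: fall back to the default Adam optimizer.
    if optimizer_cls is None:
        return torch.optim.Adam(params, lr=learning_rate)
    kwargs = deepcopy(optimizer_kwargs or {})
    # If the custom optimizer accepts 'lr', the model's learning_rate wins.
    if "lr" in inspect.signature(optimizer_cls).parameters:
        if "lr" in kwargs:
            warnings.warn(
                "ignoring learning rate passed in optimizer_kwargs, "
                "using the model's learning rate"
            )
        kwargs["lr"] = learning_rate
    return optimizer_cls(params=params, **kwargs)


# Adadelta with a custom rho; the conflicting lr=0.8 is dropped with a warning.
net = torch.nn.Linear(4, 1)
opt = resolve_optimizer(
    net.parameters(), torch.optim.Adadelta, {"rho": 0.45, "lr": 0.8}, 1e-3
)
assert opt.defaults["lr"] == 1e-3
```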
diff --git a/neuralforecast/common/_base_recurrent.py b/neuralforecast/common/_base_recurrent.py
index fec19ed70..a8086c297 100644
--- a/neuralforecast/common/_base_recurrent.py
+++ b/neuralforecast/common/_base_recurrent.py
@@ -4,6 +4,7 @@
__all__ = ['BaseRecurrent']
# %% ../../nbs/common.base_recurrent.ipynb 6
+import inspect
import random
import warnings
@@ -11,6 +12,7 @@
import torch
import torch.nn as nn
import pytorch_lightning as pl
+from copy import deepcopy
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from ._scalers import TemporalNorm
@@ -52,6 +54,8 @@ def __init__(
drop_last_loader=False,
random_seed=1,
alias=None,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs,
):
super(BaseRecurrent, self).__init__()
@@ -101,6 +105,13 @@ def __init__(
)
self.early_stop_patience_steps = early_stop_patience_steps
self.val_check_steps = val_check_steps
+ # custom optimizer defined by user
+ if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):
+ raise TypeError(
+ "optimizer is not a valid subclass of torch.optim.Optimizer"
+ )
+ self.optimizer = optimizer
+ self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}
# Variables
self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []
@@ -163,7 +174,18 @@ def on_fit_start(self):
random.seed(self.random_seed)
def configure_optimizers(self):
- optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+ if self.optimizer:
+ optimizer_signature = inspect.signature(self.optimizer)
+ optimizer_kwargs = deepcopy(self.optimizer_kwargs)
+ if "lr" in optimizer_signature.parameters:
+ if "lr" in optimizer_kwargs:
+ warnings.warn(
+ "ignoring learning rate passed in optimizer_kwargs, using the model's learning rate"
+ )
+ optimizer_kwargs["lr"] = self.learning_rate
+ optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)
+ else:
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = {
"scheduler": torch.optim.lr_scheduler.StepLR(
optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5
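One subtlety of the `issubclass` guard repeated across the three base classes: it expects the optimizer class itself, not an instance, because the base class constructs the optimizer later with the model's parameters inside `configure_optimizers`. A small illustration under that assumption (`validate_optimizer` is a hypothetical helper, not from this diff):

```python
import torch


def validate_optimizer(optimizer):
    # Same guard as the base classes; None means "use the default Adam".
    if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):
        raise TypeError("optimizer is not a valid subclass of torch.optim.Optimizer")


validate_optimizer(torch.optim.Adadelta)  # ok: an Optimizer subclass
validate_optimizer(None)                  # ok: falls back to the default Adam
try:
    # An *instance* is also rejected: issubclass requires a class.
    validate_optimizer(torch.optim.Adadelta(torch.nn.Linear(1, 1).parameters()))
except TypeError as e:
    print(e)  # issubclass() arg 1 must be a class
```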
diff --git a/neuralforecast/common/_base_windows.py b/neuralforecast/common/_base_windows.py
index e4c64fe0a..c44273246 100644
--- a/neuralforecast/common/_base_windows.py
+++ b/neuralforecast/common/_base_windows.py
@@ -4,6 +4,7 @@
__all__ = ['BaseWindows']
# %% ../../nbs/common.base_windows.ipynb 5
+import inspect
import random
import warnings
@@ -11,6 +12,7 @@
import torch
import torch.nn as nn
import pytorch_lightning as pl
+from copy import deepcopy
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from ._scalers import TemporalNorm
@@ -56,6 +58,8 @@ def __init__(
drop_last_loader=False,
random_seed=1,
alias=None,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs,
):
super(BaseWindows, self).__init__()
@@ -107,6 +111,13 @@ def __init__(
self.val_check_steps = val_check_steps
self.windows_batch_size = windows_batch_size
self.step_size = step_size
+ # custom optimizer defined by user
+ if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):
+ raise TypeError(
+ "optimizer is not a valid subclass of torch.optim.Optimizer"
+ )
+ self.optimizer = optimizer
+ self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}
# Variables
self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []
@@ -173,7 +184,18 @@ def on_fit_start(self):
random.seed(self.random_seed)
def configure_optimizers(self):
- optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+ if self.optimizer:
+ optimizer_signature = inspect.signature(self.optimizer)
+ optimizer_kwargs = deepcopy(self.optimizer_kwargs)
+ if "lr" in optimizer_signature.parameters:
+ if "lr" in optimizer_kwargs:
+ warnings.warn(
+ "ignoring learning rate passed in optimizer_kwargs, using the model's learning rate"
+ )
+ optimizer_kwargs["lr"] = self.learning_rate
+ optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)
+ else:
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = {
"scheduler": torch.optim.lr_scheduler.StepLR(
optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5
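With all three base classes patched, every model that forwards `optimizer` and `optimizer_kwargs` (see the per-model diffs below) exposes the same contract that the core.ipynb tests above exercise: a custom optimizer changes the fitted forecasts, a non-optimizer class fails fast, and a conflicting `lr` only warns. A hedged end-to-end sketch of that contract (the train split here is illustrative; the tests use their own `AirPassengersPanel_train`):

```python
import torch

from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.utils import AirPassengersPanel

# Illustrative split; the notebook tests use a prebuilt AirPassengersPanel_train.
Y_train = AirPassengersPanel[AirPassengersPanel["ds"] < "1959-01-01"]

# Any torch.optim.Optimizer subclass can be plugged in with its kwargs.
model = NHITS(
    h=12,
    input_size=24,
    max_steps=10,
    optimizer=torch.optim.Adadelta,
    optimizer_kwargs={"rho": 0.45},
)
nf = NeuralForecast(models=[model], freq="M")
nf.fit(Y_train)
forecasts = nf.predict()

# A class that is not an optimizer is rejected at construction time.
try:
    NHITS(h=12, input_size=24, optimizer=torch.nn.Module)
except TypeError as e:
    print(e)  # optimizer is not a valid subclass of torch.optim.Optimizer
```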
diff --git a/neuralforecast/models/autoformer.py b/neuralforecast/models/autoformer.py
index b11bfb8e3..058ada27e 100644
--- a/neuralforecast/models/autoformer.py
+++ b/neuralforecast/models/autoformer.py
@@ -473,6 +473,8 @@ class Autoformer(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
*References*
@@ -517,6 +519,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs,
):
super(Autoformer, self).__init__(
@@ -543,6 +547,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs,
)
diff --git a/neuralforecast/models/deepar.py b/neuralforecast/models/deepar.py
index 3779c6a4c..8b4e520b6 100644
--- a/neuralforecast/models/deepar.py
+++ b/neuralforecast/models/deepar.py
@@ -105,6 +105,8 @@ class DeepAR(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References**
@@ -149,8 +151,11 @@ def __init__(
random_seed: int = 1,
num_workers_loader=0,
drop_last_loader=False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
+
# DeepAR does not support historic exogenous variables
if hist_exog_list is not None:
raise Exception("DeepAR does not support historic exogenous variables.")
@@ -196,6 +201,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
diff --git a/neuralforecast/models/dilated_rnn.py b/neuralforecast/models/dilated_rnn.py
index d2c572ed3..7b3b69caf 100644
--- a/neuralforecast/models/dilated_rnn.py
+++ b/neuralforecast/models/dilated_rnn.py
@@ -316,6 +316,8 @@ class DilatedRNN(BaseRecurrent):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
"""
@@ -350,6 +352,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
super(DilatedRNN, self).__init__(
@@ -372,6 +376,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
diff --git a/neuralforecast/models/dlinear.py b/neuralforecast/models/dlinear.py
index c0f336c2f..08fbd1464 100644
--- a/neuralforecast/models/dlinear.py
+++ b/neuralforecast/models/dlinear.py
@@ -75,6 +75,8 @@ class DLinear(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
*References*
@@ -110,6 +112,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
super(DLinear, self).__init__(
@@ -136,6 +140,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
diff --git a/neuralforecast/models/fedformer.py b/neuralforecast/models/fedformer.py
index ad83203f9..df40bbd31 100644
--- a/neuralforecast/models/fedformer.py
+++ b/neuralforecast/models/fedformer.py
@@ -468,6 +468,8 @@ class FEDformer(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
"""
@@ -511,6 +513,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs,
):
super(FEDformer, self).__init__(
@@ -536,6 +540,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs,
)
# Architecture
diff --git a/neuralforecast/models/gru.py b/neuralforecast/models/gru.py
index 3b4d7f9ac..c8290700a 100644
--- a/neuralforecast/models/gru.py
+++ b/neuralforecast/models/gru.py
@@ -51,6 +51,8 @@ class GRU(BaseRecurrent):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
"""
@@ -86,6 +88,8 @@ def __init__(
random_seed=1,
num_workers_loader=0,
drop_last_loader=False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
super(GRU, self).__init__(
@@ -108,6 +112,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
diff --git a/neuralforecast/models/informer.py b/neuralforecast/models/informer.py
index 964c50922..c81b3b36a 100644
--- a/neuralforecast/models/informer.py
+++ b/neuralforecast/models/informer.py
@@ -212,6 +212,8 @@ class Informer(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
*References*
@@ -256,6 +258,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs,
):
super(Informer, self).__init__(
@@ -282,6 +286,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs,
)
diff --git a/neuralforecast/models/lstm.py b/neuralforecast/models/lstm.py
index db14a9e22..9db945c8b 100644
--- a/neuralforecast/models/lstm.py
+++ b/neuralforecast/models/lstm.py
@@ -51,6 +51,8 @@ class LSTM(BaseRecurrent):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
"""
@@ -85,6 +87,8 @@ def __init__(
random_seed=1,
num_workers_loader=0,
drop_last_loader=False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
super(LSTM, self).__init__(
@@ -107,6 +111,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
diff --git a/neuralforecast/models/mlp.py b/neuralforecast/models/mlp.py
index f781973e4..816b40005 100644
--- a/neuralforecast/models/mlp.py
+++ b/neuralforecast/models/mlp.py
@@ -49,6 +49,8 @@ class MLP(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
"""
@@ -82,8 +84,11 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
+
# Inherit BaseWindows class
super(MLP, self).__init__(
h=h,
@@ -109,6 +114,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
diff --git a/neuralforecast/models/nbeats.py b/neuralforecast/models/nbeats.py
index 878756c24..dcb8fe8df 100644
--- a/neuralforecast/models/nbeats.py
+++ b/neuralforecast/models/nbeats.py
@@ -228,6 +228,8 @@ class NBEATS(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
@@ -267,6 +269,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs,
):
# Inherit BaseWindows class
@@ -290,6 +294,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs,
)
diff --git a/neuralforecast/models/nbeatsx.py b/neuralforecast/models/nbeatsx.py
index f3ce97150..3a1f9152c 100644
--- a/neuralforecast/models/nbeatsx.py
+++ b/neuralforecast/models/nbeatsx.py
@@ -309,6 +309,8 @@ class NBEATSx(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
@@ -352,6 +354,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs,
):
# Protect horizon collapsed seasonality and trend NBEATSx-i basis
@@ -385,6 +389,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs,
)
diff --git a/neuralforecast/models/nhits.py b/neuralforecast/models/nhits.py
index 603d9241e..eb128b383 100644
--- a/neuralforecast/models/nhits.py
+++ b/neuralforecast/models/nhits.py
@@ -222,6 +222,8 @@ class NHITS(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
@@ -267,6 +269,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader=0,
drop_last_loader=False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs,
):
# Inherit BaseWindows class
@@ -294,6 +298,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs,
)
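Window-based models take the new arguments the same way. A sketch swapping in a momentum optimizer (parameter values are illustrative):

```python
import torch
from neuralforecast.models import NHITS

# Any torch.optim.Optimizer subclass works; here SGD with momentum
# replaces the default Adam. 'lr' is injected from `learning_rate`.
model = NHITS(
    h=12,
    input_size=24,
    optimizer=torch.optim.SGD,
    optimizer_kwargs={"momentum": 0.9},
)
```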
diff --git a/neuralforecast/models/patchtst.py b/neuralforecast/models/patchtst.py
index 67d58ac56..9355d69a2 100644
--- a/neuralforecast/models/patchtst.py
+++ b/neuralforecast/models/patchtst.py
@@ -865,6 +865,8 @@ class PatchTST(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
@@ -916,6 +918,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
super(PatchTST, self).__init__(
@@ -942,6 +946,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
# Asserts
diff --git a/neuralforecast/models/rnn.py b/neuralforecast/models/rnn.py
index c3d9eb003..6aadf0e5f 100644
--- a/neuralforecast/models/rnn.py
+++ b/neuralforecast/models/rnn.py
@@ -50,6 +50,8 @@ class RNN(BaseRecurrent):
`random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
"""
@@ -86,6 +88,8 @@ def __init__(
random_seed=1,
num_workers_loader=0,
drop_last_loader=False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
super(RNN, self).__init__(
@@ -108,6 +112,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
diff --git a/neuralforecast/models/stemgnn.py b/neuralforecast/models/stemgnn.py
index 6df17e90b..6c427a4d9 100644
--- a/neuralforecast/models/stemgnn.py
+++ b/neuralforecast/models/stemgnn.py
@@ -161,6 +161,8 @@ class StemGNN(BaseMultivariate):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
"""
@@ -192,6 +194,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader=0,
drop_last_loader=False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
# Inherit BaseMultivariate class
@@ -215,6 +219,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
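Multivariate models (`BaseMultivariate`) accept the same pair; a sketch assuming a panel of 7 series (all values illustrative):

```python
import torch
from neuralforecast.models import StemGNN

# StemGNN needs the number of series up front; the optimizer handling
# is otherwise identical to the univariate models.
model = StemGNN(
    h=12,
    input_size=24,
    n_series=7,
    optimizer=torch.optim.RMSprop,
    optimizer_kwargs={"alpha": 0.9},
)
```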
diff --git a/neuralforecast/models/tcn.py b/neuralforecast/models/tcn.py
index 2cf270b04..7a31d1555 100644
--- a/neuralforecast/models/tcn.py
+++ b/neuralforecast/models/tcn.py
@@ -47,6 +47,8 @@ class TCN(BaseRecurrent):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
"""
@@ -81,6 +83,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader=0,
drop_last_loader=False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
super(TCN, self).__init__(
@@ -103,6 +107,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
diff --git a/neuralforecast/models/tft.py b/neuralforecast/models/tft.py
index 9e96a14ce..914552b63 100644
--- a/neuralforecast/models/tft.py
+++ b/neuralforecast/models/tft.py
@@ -418,6 +418,8 @@ class TFT(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
@@ -457,6 +459,8 @@ def __init__(
num_workers_loader=0,
drop_last_loader=False,
random_seed: int = 1,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
# Inherit BaseWindows class
@@ -480,6 +484,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
self.example_length = input_size + h
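One interaction worth noting: per the `configure_optimizers` logic added in this PR, an `'lr'` entry in `optimizer_kwargs` is overwritten with the model's `learning_rate` when training starts, and a warning is emitted. A sketch of the conflicting call (values illustrative):

```python
import torch
from neuralforecast.models import TFT

# The 'lr' below is ignored: when the trainer calls configure_optimizers,
# it is replaced by learning_rate (here 1e-3) and a warning is raised.
model = TFT(
    h=12,
    input_size=48,
    learning_rate=1e-3,
    optimizer=torch.optim.Adam,
    optimizer_kwargs={"lr": 0.1, "betas": (0.9, 0.999)},
)
```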
diff --git a/neuralforecast/models/timesnet.py b/neuralforecast/models/timesnet.py
index 01f39e8c9..c1233b203 100644
--- a/neuralforecast/models/timesnet.py
+++ b/neuralforecast/models/timesnet.py
@@ -173,6 +173,10 @@ class TimesNet(BaseWindows):
Workers to be used by `TimeSeriesDataLoader`.
drop_last_loader : bool (default=False)
If True `TimeSeriesDataLoader` drops last non-full batch.
+ optimizer : Subclass of `torch.optim.Optimizer`, optional (default=None)
+ User-specified optimizer instead of the default choice (Adam).
+ optimizer_kwargs : dict, optional (default=None)
+ Dictionary of parameters used by the user-specified `optimizer`.
**trainer_kwargs
Keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer)
@@ -215,6 +219,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs
):
super(TimesNet, self).__init__(
@@ -241,6 +247,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
diff --git a/neuralforecast/models/vanillatransformer.py b/neuralforecast/models/vanillatransformer.py
index 6f6dc201d..c754ab66b 100644
--- a/neuralforecast/models/vanillatransformer.py
+++ b/neuralforecast/models/vanillatransformer.py
@@ -113,6 +113,8 @@ class VanillaTransformer(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
*References*
@@ -154,6 +156,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
**trainer_kwargs,
):
super(VanillaTransformer, self).__init__(
@@ -179,6 +183,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs,
)
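Finally, the constructors validate the argument eagerly: anything that is not a `torch.optim.Optimizer` subclass is rejected with a `TypeError` before training starts. A quick check (the model choice here is arbitrary):

```python
import torch
from neuralforecast.models import MLP

try:
    MLP(h=12, input_size=24, optimizer=dict)  # not an Optimizer subclass
except TypeError as err:
    print(err)  # "optimizer is not a valid subclass of torch.optim.Optimizer"
```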