diff --git a/nbs/common.base_multivariate.ipynb b/nbs/common.base_multivariate.ipynb index 7e0eac5b6..1196a8694 100644 --- a/nbs/common.base_multivariate.ipynb +++ b/nbs/common.base_multivariate.ipynb @@ -54,6 +54,7 @@ "outputs": [], "source": [ "#| export\n", + "import inspect\n", "import random\n", "import warnings\n", "\n", @@ -61,6 +62,7 @@ "import torch\n", "import torch.nn as nn\n", "import pytorch_lightning as pl\n", + "from copy import deepcopy\n", "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", "\n", "from neuralforecast.common._scalers import TemporalNorm\n", @@ -107,6 +109,8 @@ " drop_last_loader=False,\n", " random_seed=1, \n", " alias=None,\n", + " optimizer=None,\n", + " optimizer_kwargs=None,\n", " **trainer_kwargs):\n", " super(BaseMultivariate, self).__init__()\n", "\n", @@ -140,6 +144,11 @@ " self.early_stop_patience_steps = early_stop_patience_steps\n", " self.val_check_steps = val_check_steps\n", " self.step_size = step_size\n", + " # custom optimizer defined by user\n", + " if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n", + " raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n", + " self.optimizer = optimizer\n", + " self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {} \n", "\n", " # Scaler\n", " self.scaler = TemporalNorm(scaler_type=scaler_type, dim=2) # Time dimension is in the second axis\n", @@ -202,7 +211,16 @@ " random.seed(self.random_seed)\n", " \n", " def configure_optimizers(self):\n", - " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " if self.optimizer:\n", + " optimizer_signature = inspect.signature(self.optimizer)\n", + " optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n", + " if 'lr' in optimizer_signature.parameters:\n", + " if 'lr' in optimizer_kwargs:\n", + " warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n", + " optimizer_kwargs['lr'] = self.learning_rate\n", + " optimizer = self.optimizer(params=self.parameters(), **self.optimizer_kwargs)\n", + " else:\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) \n", " scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=optimizer,\n", " step_size=self.lr_decay_steps,\n", " gamma=0.5),\n", diff --git a/nbs/common.base_recurrent.ipynb b/nbs/common.base_recurrent.ipynb index 23ce2add0..600b6c5bd 100644 --- a/nbs/common.base_recurrent.ipynb +++ b/nbs/common.base_recurrent.ipynb @@ -60,6 +60,7 @@ "outputs": [], "source": [ "#| export\n", + "import inspect\n", "import random\n", "import warnings\n", "\n", @@ -67,6 +68,7 @@ "import torch\n", "import torch.nn as nn\n", "import pytorch_lightning as pl\n", + "from copy import deepcopy\n", "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", "\n", "from neuralforecast.common._scalers import TemporalNorm\n", @@ -113,6 +115,8 @@ " drop_last_loader=False,\n", " random_seed=1, \n", " alias=None,\n", + " optimizer=None,\n", + " optimizer_kwargs=None,\n", " **trainer_kwargs):\n", " super(BaseRecurrent, self).__init__()\n", "\n", @@ -154,6 +158,11 @@ " self.lr_decay_steps = max(max_steps // self.num_lr_decays, 1) if self.num_lr_decays > 0 else 10e7\n", " self.early_stop_patience_steps = early_stop_patience_steps\n", " self.val_check_steps = val_check_steps\n", + " # custom optimizer defined by user\n", + " if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n", + " raise TypeError(\"optimizer 
is not a valid subclass of torch.optim.Optimizer\")\n", + " self.optimizer = optimizer\n", + " self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {} \n", "\n", " # Variables\n", " self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n", @@ -214,7 +223,16 @@ " random.seed(self.random_seed)\n", " \n", " def configure_optimizers(self):\n", - " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " if self.optimizer:\n", + " optimizer_signature = inspect.signature(self.optimizer)\n", + " optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n", + " if 'lr' in optimizer_signature.parameters:\n", + " if 'lr' in optimizer_kwargs:\n", + " warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n", + " optimizer_kwargs['lr'] = self.learning_rate\n", + " optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n", + " else:\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", " scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=optimizer,\n", " step_size=self.lr_decay_steps,\n", " gamma=0.5),\n", diff --git a/nbs/common.base_windows.ipynb b/nbs/common.base_windows.ipynb index cc728b408..d80e87114 100644 --- a/nbs/common.base_windows.ipynb +++ b/nbs/common.base_windows.ipynb @@ -60,6 +60,7 @@ "outputs": [], "source": [ "#| export\n", + "import inspect\n", "import random\n", "import warnings\n", "\n", @@ -67,6 +68,7 @@ "import torch\n", "import torch.nn as nn\n", "import pytorch_lightning as pl\n", + "from copy import deepcopy\n", "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", "\n", "from neuralforecast.common._scalers import TemporalNorm\n", @@ -118,6 +120,8 @@ " drop_last_loader=False,\n", " random_seed=1,\n", " alias=None,\n", + " optimizer=None,\n", + " optimizer_kwargs=None,\n", " **trainer_kwargs):\n", " super(BaseWindows, self).__init__()\n", "\n", @@ -164,6 +168,11 @@ " self.val_check_steps = val_check_steps\n", " self.windows_batch_size = windows_batch_size\n", " self.step_size = step_size\n", + " # custom optimizer defined by user\n", + " if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n", + " raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n", + " self.optimizer = optimizer\n", + " self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}\n", "\n", " # Variables\n", " self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n", @@ -228,7 +237,16 @@ " random.seed(self.random_seed)\n", " \n", " def configure_optimizers(self):\n", - " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " if self.optimizer:\n", + " optimizer_signature = inspect.signature(self.optimizer)\n", + " optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n", + " if 'lr' in optimizer_signature.parameters:\n", + " if 'lr' in optimizer_kwargs:\n", + " warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n", + " optimizer_kwargs['lr'] = self.learning_rate\n", + " optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n", + " else:\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", " scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer=optimizer,\n", " step_size=self.lr_decay_steps,\n", " gamma=0.5),\n", diff --git a/nbs/core.ipynb b/nbs/core.ipynb index f03ca60b2..ce6d5f183 100644 --- 
a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -2433,6 +2433,85 @@ "test_pl_df3[0, \"static_1\"] = np.nan\n", "test_fail(lambda: nf.fit(pl_df, static_df=test_pl_df3), contains=\"Found missing values in ['static_1']\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "859a474c", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# test customized optimizer behavior such that the user defiend optimizer result should differ from default\n", + "# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n", + "\n", + "for nf_model in [NHITS, RNN, StemGNN]:\n", + " # default optimizer is based on Adam\n", + " params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n", + " if nf_model.__name__ == \"StemGNN\":\n", + " params.update({\"n_series\": 2})\n", + " models = [nf_model(**params)]\n", + " nf = NeuralForecast(models=models, freq='M')\n", + " nf.fit(AirPassengersPanel_train)\n", + " default_optimizer_predict = nf.predict()\n", + " mean = default_optimizer_predict.loc[:, nf_model.__name__].mean()\n", + "\n", + " # using a customized optimizer\n", + " params.update({\n", + " \"optimizer\": torch.optim.Adadelta,\n", + " \"optimizer_kwargs\": {\"rho\": 0.45}, \n", + " })\n", + " models2 = [nf_model(**params)]\n", + " nf2 = NeuralForecast(models=models2, freq='M')\n", + " nf2.fit(AirPassengersPanel_train)\n", + " customized_optimizer_predict = nf2.predict()\n", + " mean2 = customized_optimizer_predict.loc[:, nf_model.__name__].mean()\n", + " assert mean2 != mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3db3fe1e", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# test that if the user-defined optimizer is not a subclass of torch.optim.optimizer, failed with exception\n", + "# tests cover different types of base classes such as basewindows, baserecurrent, basemultivariate\n", + "test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n", + "test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n", + "test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d908240f", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# test that if we pass in \"lr\" parameter, we expect warning and it ignores the passed in 'lr' parameter\n", + "# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n", + "\n", + "for nf_model in [NHITS, RNN, StemGNN]:\n", + " params = {\n", + " \"h\": 12, \n", + " \"input_size\": 24, \n", + " \"max_steps\": 1, \n", + " \"optimizer\": torch.optim.Adadelta, \n", + " \"optimizer_kwargs\": {\"lr\": 0.8, \"rho\": 0.45}\n", + " }\n", + " if nf_model.__name__ == \"StemGNN\":\n", + " params.update({\"n_series\": 2})\n", + " models = [nf_model(**params)]\n", + " nf = NeuralForecast(models=models, freq='M')\n", + " with warnings.catch_warnings(record=True) as issued_warnings:\n", + " warnings.simplefilter('always', UserWarning)\n", + " nf.fit(AirPassengersPanel_train)\n", + " assert any(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\" in 
str(w.message) for w in issued_warnings)" + ] } ], "metadata": { diff --git a/nbs/models.autoformer.ipynb b/nbs/models.autoformer.ipynb index 478903365..27bd6a1aa 100644 --- a/nbs/models.autoformer.ipynb +++ b/nbs/models.autoformer.ipynb @@ -488,6 +488,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", "\t*References*
\n", @@ -530,6 +532,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", " super(Autoformer, self).__init__(h=h,\n", " input_size=input_size,\n", @@ -554,6 +558,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Architecture\n", diff --git a/nbs/models.deepar.ipynb b/nbs/models.deepar.ipynb index 8117416e9..adeee0bdd 100644 --- a/nbs/models.deepar.ipynb +++ b/nbs/models.deepar.ipynb @@ -197,6 +197,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", " **References**
\n", @@ -237,6 +239,8 @@ " random_seed: int = 1,\n", " num_workers_loader = 0,\n", " drop_last_loader = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", "\n", " # DeepAR does not support historic exogenous variables\n", @@ -279,6 +283,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " self.horizon_backup = self.h # Used because h=0 during training\n", diff --git a/nbs/models.dilated_rnn.ipynb b/nbs/models.dilated_rnn.ipynb index 12c418103..4970ee48e 100644 --- a/nbs/models.dilated_rnn.ipynb +++ b/nbs/models.dilated_rnn.ipynb @@ -390,6 +390,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", " \"\"\"\n", " # Class attributes\n", @@ -422,6 +424,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", " super(DilatedRNN, self).__init__(\n", " h=h,\n", @@ -443,6 +447,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs\n", " )\n", "\n", diff --git a/nbs/models.dlinear.ipynb b/nbs/models.dlinear.ipynb index eb899300e..f95aec9cf 100644 --- a/nbs/models.dlinear.ipynb +++ b/nbs/models.dlinear.ipynb @@ -162,6 +162,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", "\t*References*
\n", @@ -195,6 +197,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", " super(DLinear, self).__init__(h=h,\n", " input_size=input_size,\n", @@ -219,6 +223,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Architecture\n", diff --git a/nbs/models.fedformer.ipynb b/nbs/models.fedformer.ipynb index 9fcce237c..c18503ac9 100644 --- a/nbs/models.fedformer.ipynb +++ b/nbs/models.fedformer.ipynb @@ -477,6 +477,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", " \"\"\"\n", @@ -518,6 +520,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer=None,\n", + " optimizer_kwargs=None,\n", " **trainer_kwargs):\n", " super(FEDformer, self).__init__(h=h,\n", " input_size=input_size,\n", @@ -541,6 +545,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", " # Architecture\n", " self.futr_input_size = len(self.futr_exog_list)\n", diff --git a/nbs/models.gru.ipynb b/nbs/models.gru.ipynb index dab5fa84d..e52f2f29d 100644 --- a/nbs/models.gru.ipynb +++ b/nbs/models.gru.ipynb @@ -124,6 +124,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", " \"\"\"\n", " # Class attributes\n", @@ -157,6 +159,8 @@ " random_seed=1,\n", " num_workers_loader=0,\n", " drop_last_loader=False,\n", + " optimizer=None,\n", + " optimizer_kwargs=None,\n", " **trainer_kwargs):\n", " super(GRU, self).__init__(\n", " h=h,\n", @@ -178,6 +182,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs\n", " )\n", "\n", diff --git a/nbs/models.informer.ipynb b/nbs/models.informer.ipynb index e05756e50..2fc393642 100644 --- a/nbs/models.informer.ipynb +++ b/nbs/models.informer.ipynb @@ -297,6 +297,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", "\t*References*
\n", @@ -339,6 +341,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer=None,\n", + " optimizer_kwargs=None,\n", " **trainer_kwargs):\n", " super(Informer, self).__init__(h=h,\n", " input_size=input_size,\n", @@ -363,6 +367,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Architecture\n", diff --git a/nbs/models.lstm.ipynb b/nbs/models.lstm.ipynb index 5d70fca27..7c3c6ea6d 100644 --- a/nbs/models.lstm.ipynb +++ b/nbs/models.lstm.ipynb @@ -122,6 +122,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", " \"\"\"\n", " # Class attributes\n", @@ -154,6 +156,8 @@ " random_seed = 1,\n", " num_workers_loader = 0,\n", " drop_last_loader = False,\n", + " optimizer=None,\n", + " optimizer_kwargs=None,\n", " **trainer_kwargs):\n", " super(LSTM, self).__init__(\n", " h=h,\n", @@ -175,6 +179,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs\n", " )\n", "\n", diff --git a/nbs/models.mlp.ipynb b/nbs/models.mlp.ipynb index 8a4f22bcb..76db09f1a 100644 --- a/nbs/models.mlp.ipynb +++ b/nbs/models.mlp.ipynb @@ -114,6 +114,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", " \"\"\"\n", " # Class attributes\n", @@ -145,6 +147,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", "\n", " # Inherit BaseWindows class\n", @@ -171,6 +175,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Architecture\n", diff --git a/nbs/models.nbeats.ipynb b/nbs/models.nbeats.ipynb index f154acfbe..2e9975b0e 100644 --- a/nbs/models.nbeats.ipynb +++ b/nbs/models.nbeats.ipynb @@ -270,6 +270,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", " **References:**
\n", @@ -307,6 +309,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", "\n", " # Inherit BaseWindows class\n", @@ -329,6 +333,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Architecture\n", diff --git a/nbs/models.nbeatsx.ipynb b/nbs/models.nbeatsx.ipynb index 10d59d948..42042954d 100644 --- a/nbs/models.nbeatsx.ipynb +++ b/nbs/models.nbeatsx.ipynb @@ -413,6 +413,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", " **References:**
\n", @@ -456,6 +458,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer=None,\n", + " optimizer_kwargs=None,\n", " **trainer_kwargs,\n", " ):\n", " # Protect horizon collapsed seasonality and trend NBEATSx-i basis\n", @@ -488,6 +492,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Architecture\n", diff --git a/nbs/models.nhits.ipynb b/nbs/models.nhits.ipynb index 3d9dbba74..f00cd6c08 100644 --- a/nbs/models.nhits.ipynb +++ b/nbs/models.nhits.ipynb @@ -304,6 +304,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", " **References:**
\n", @@ -347,6 +349,8 @@ " random_seed: int = 1,\n", " num_workers_loader = 0,\n", " drop_last_loader = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", "\n", " # Inherit BaseWindows class\n", @@ -373,6 +377,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Architecture\n", diff --git a/nbs/models.patchtst.ipynb b/nbs/models.patchtst.ipynb index 40d3ced3e..09d52ec9c 100644 --- a/nbs/models.patchtst.ipynb +++ b/nbs/models.patchtst.ipynb @@ -715,6 +715,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", " **References:**
\n", @@ -764,6 +766,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", " super(PatchTST, self).__init__(h=h,\n", " input_size=input_size,\n", @@ -788,6 +792,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs) \n", " # Asserts\n", " if stat_exog_list is not None:\n", diff --git a/nbs/models.rnn.ipynb b/nbs/models.rnn.ipynb index 6185bc90c..4ea6d92f1 100644 --- a/nbs/models.rnn.ipynb +++ b/nbs/models.rnn.ipynb @@ -125,7 +125,10 @@ " `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
\n", " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `alias`: str, optional, Custom name of the model.
\n", + "\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", " \"\"\"\n", " # Class attributes\n", @@ -159,6 +162,8 @@ " random_seed=1,\n", " num_workers_loader=0,\n", " drop_last_loader=False,\n", + " optimizer=None,\n", + " optimizer_kwargs=None,\n", " **trainer_kwargs):\n", " super(RNN, self).__init__(\n", " h=h,\n", @@ -180,6 +185,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs\n", " )\n", "\n", diff --git a/nbs/models.stemgnn.ipynb b/nbs/models.stemgnn.ipynb index d707063ca..9df803157 100644 --- a/nbs/models.stemgnn.ipynb +++ b/nbs/models.stemgnn.ipynb @@ -188,6 +188,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", " \"\"\"\n", " # Class attributes\n", @@ -217,6 +219,8 @@ " random_seed: int = 1,\n", " num_workers_loader = 0,\n", " drop_last_loader = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", "\n", " # Inherit BaseMultivariate class\n", @@ -239,6 +243,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Exogenous variables\n", diff --git a/nbs/models.tcn.ipynb b/nbs/models.tcn.ipynb index 3dcf43003..f44072e1e 100644 --- a/nbs/models.tcn.ipynb +++ b/nbs/models.tcn.ipynb @@ -128,6 +128,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", " \"\"\"\n", " # Class attributes\n", @@ -160,6 +162,8 @@ " random_seed: int = 1,\n", " num_workers_loader = 0,\n", " drop_last_loader = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", " super(TCN, self).__init__(\n", " h=h,\n", @@ -181,6 +185,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs\n", " )\n", "\n", diff --git a/nbs/models.tft.ipynb b/nbs/models.tft.ipynb index 8839ac243..56e60d99b 100644 --- a/nbs/models.tft.ipynb +++ b/nbs/models.tft.ipynb @@ -678,6 +678,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", " **References:**
\n", @@ -715,6 +717,8 @@ " num_workers_loader = 0,\n", " drop_last_loader = False,\n", " random_seed: int = 1,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs\n", " ):\n", "\n", @@ -738,6 +742,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", " self.example_length = input_size + h\n", "\n", diff --git a/nbs/models.timesnet.ipynb b/nbs/models.timesnet.ipynb index ddc21bce3..09c4de132 100644 --- a/nbs/models.timesnet.ipynb +++ b/nbs/models.timesnet.ipynb @@ -257,6 +257,10 @@ " Workers to be used by `TimeSeriesDataLoader`.\n", " drop_last_loader : bool (default=False)\n", " If True `TimeSeriesDataLoader` drops last non-full batch.\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional (default=None)\n", + " User specified optimizer instead of the default choice (Adam).\n", + " `optimizer_kwargs`: dict, optional (defualt=None)\n", + " List of parameters used by the user specified `optimizer`.\n", " **trainer_kwargs\n", " Keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer)\n", "\n", @@ -297,6 +301,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", " super(TimesNet, self).__init__(h=h,\n", " input_size=input_size,\n", @@ -321,6 +327,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Architecture\n", diff --git a/nbs/models.vanillatransformer.ipynb b/nbs/models.vanillatransformer.ipynb index 9689500e1..242994f81 100644 --- a/nbs/models.vanillatransformer.ipynb +++ b/nbs/models.vanillatransformer.ipynb @@ -195,6 +195,8 @@ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
\n", + " `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
\n", " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", "\n", "\t*References*
\n", @@ -234,6 +236,8 @@ " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", " super(VanillaTransformer, self).__init__(h=h,\n", " input_size=input_size,\n", @@ -257,6 +261,8 @@ " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Architecture\n", diff --git a/neuralforecast/common/_base_multivariate.py b/neuralforecast/common/_base_multivariate.py index be891fcfa..490f458bd 100644 --- a/neuralforecast/common/_base_multivariate.py +++ b/neuralforecast/common/_base_multivariate.py @@ -4,6 +4,7 @@ __all__ = ['BaseMultivariate'] # %% ../../nbs/common.base_multivariate.ipynb 5 +import inspect import random import warnings @@ -11,6 +12,7 @@ import torch import torch.nn as nn import pytorch_lightning as pl +from copy import deepcopy from pytorch_lightning.callbacks.early_stopping import EarlyStopping from ._scalers import TemporalNorm @@ -52,6 +54,8 @@ def __init__( drop_last_loader=False, random_seed=1, alias=None, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs, ): super(BaseMultivariate, self).__init__() @@ -88,6 +92,13 @@ def __init__( self.early_stop_patience_steps = early_stop_patience_steps self.val_check_steps = val_check_steps self.step_size = step_size + # custom optimizer defined by user + if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer): + raise TypeError( + "optimizer is not a valid subclass of torch.optim.Optimizer" + ) + self.optimizer = optimizer + self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {} # Scaler self.scaler = TemporalNorm( @@ -151,7 +162,20 @@ def on_fit_start(self): random.seed(self.random_seed) def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + if self.optimizer: + optimizer_signature = inspect.signature(self.optimizer) + optimizer_kwargs = deepcopy(self.optimizer_kwargs) + if "lr" in optimizer_signature.parameters: + if "lr" in optimizer_kwargs: + warnings.warn( + "ignoring learning rate passed in optimizer_kwargs, using the model's learning rate" + ) + optimizer_kwargs["lr"] = self.learning_rate + optimizer = self.optimizer( + params=self.parameters(), **self.optimizer_kwargs + ) + else: + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) scheduler = { "scheduler": torch.optim.lr_scheduler.StepLR( optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5 diff --git a/neuralforecast/common/_base_recurrent.py b/neuralforecast/common/_base_recurrent.py index fec19ed70..a8086c297 100644 --- a/neuralforecast/common/_base_recurrent.py +++ b/neuralforecast/common/_base_recurrent.py @@ -4,6 +4,7 @@ __all__ = ['BaseRecurrent'] # %% ../../nbs/common.base_recurrent.ipynb 6 +import inspect import random import warnings @@ -11,6 +12,7 @@ import torch import torch.nn as nn import pytorch_lightning as pl +from copy import deepcopy from pytorch_lightning.callbacks.early_stopping import EarlyStopping from ._scalers import TemporalNorm @@ -52,6 +54,8 @@ def __init__( drop_last_loader=False, random_seed=1, alias=None, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs, ): super(BaseRecurrent, self).__init__() @@ -101,6 +105,13 @@ def __init__( ) self.early_stop_patience_steps = early_stop_patience_steps self.val_check_steps = 
val_check_steps + # custom optimizer defined by user + if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer): + raise TypeError( + "optimizer is not a valid subclass of torch.optim.Optimizer" + ) + self.optimizer = optimizer + self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {} # Variables self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else [] @@ -163,7 +174,18 @@ def on_fit_start(self): random.seed(self.random_seed) def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + if self.optimizer: + optimizer_signature = inspect.signature(self.optimizer) + optimizer_kwargs = deepcopy(self.optimizer_kwargs) + if "lr" in optimizer_signature.parameters: + if "lr" in optimizer_kwargs: + warnings.warn( + "ignoring learning rate passed in optimizer_kwargs, using the model's learning rate" + ) + optimizer_kwargs["lr"] = self.learning_rate + optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs) + else: + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) scheduler = { "scheduler": torch.optim.lr_scheduler.StepLR( optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5 diff --git a/neuralforecast/common/_base_windows.py b/neuralforecast/common/_base_windows.py index e4c64fe0a..c44273246 100644 --- a/neuralforecast/common/_base_windows.py +++ b/neuralforecast/common/_base_windows.py @@ -4,6 +4,7 @@ __all__ = ['BaseWindows'] # %% ../../nbs/common.base_windows.ipynb 5 +import inspect import random import warnings @@ -11,6 +12,7 @@ import torch import torch.nn as nn import pytorch_lightning as pl +from copy import deepcopy from pytorch_lightning.callbacks.early_stopping import EarlyStopping from ._scalers import TemporalNorm @@ -56,6 +58,8 @@ def __init__( drop_last_loader=False, random_seed=1, alias=None, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs, ): super(BaseWindows, self).__init__() @@ -107,6 +111,13 @@ def __init__( self.val_check_steps = val_check_steps self.windows_batch_size = windows_batch_size self.step_size = step_size + # custom optimizer defined by user + if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer): + raise TypeError( + "optimizer is not a valid subclass of torch.optim.Optimizer" + ) + self.optimizer = optimizer + self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {} # Variables self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else [] @@ -173,7 +184,18 @@ def on_fit_start(self): random.seed(self.random_seed) def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + if self.optimizer: + optimizer_signature = inspect.signature(self.optimizer) + optimizer_kwargs = deepcopy(self.optimizer_kwargs) + if "lr" in optimizer_signature.parameters: + if "lr" in optimizer_kwargs: + warnings.warn( + "ignoring learning rate passed in optimizer_kwargs, using the model's learning rate" + ) + optimizer_kwargs["lr"] = self.learning_rate + optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs) + else: + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) scheduler = { "scheduler": torch.optim.lr_scheduler.StepLR( optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5 diff --git a/neuralforecast/models/autoformer.py b/neuralforecast/models/autoformer.py index b11bfb8e3..058ada27e 100644 --- a/neuralforecast/models/autoformer.py +++ b/neuralforecast/models/autoformer.py @@ 
-473,6 +473,8 @@ class Autoformer(BaseWindows): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
+ `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
*References*
@@ -517,6 +519,8 @@ def __init__( random_seed: int = 1, num_workers_loader: int = 0, drop_last_loader: bool = False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs, ): super(Autoformer, self).__init__( @@ -543,6 +547,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs, ) diff --git a/neuralforecast/models/deepar.py b/neuralforecast/models/deepar.py index 3779c6a4c..8b4e520b6 100644 --- a/neuralforecast/models/deepar.py +++ b/neuralforecast/models/deepar.py @@ -105,6 +105,8 @@ class DeepAR(BaseWindows): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
+ `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References**
@@ -149,8 +151,11 @@ def __init__( random_seed: int = 1, num_workers_loader=0, drop_last_loader=False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs ): + # DeepAR does not support historic exogenous variables if hist_exog_list is not None: raise Exception("DeepAR does not support historic exogenous variables.") @@ -196,6 +201,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs ) diff --git a/neuralforecast/models/dilated_rnn.py b/neuralforecast/models/dilated_rnn.py index d2c572ed3..7b3b69caf 100644 --- a/neuralforecast/models/dilated_rnn.py +++ b/neuralforecast/models/dilated_rnn.py @@ -316,6 +316,8 @@ class DilatedRNN(BaseRecurrent): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
+ `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
""" @@ -350,6 +352,8 @@ def __init__( random_seed: int = 1, num_workers_loader: int = 0, drop_last_loader: bool = False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs ): super(DilatedRNN, self).__init__( @@ -372,6 +376,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs ) diff --git a/neuralforecast/models/dlinear.py b/neuralforecast/models/dlinear.py index c0f336c2f..08fbd1464 100644 --- a/neuralforecast/models/dlinear.py +++ b/neuralforecast/models/dlinear.py @@ -75,6 +75,8 @@ class DLinear(BaseWindows): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
+ `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
*References*
@@ -110,6 +112,8 @@ def __init__( random_seed: int = 1, num_workers_loader: int = 0, drop_last_loader: bool = False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs ): super(DLinear, self).__init__( @@ -136,6 +140,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs ) diff --git a/neuralforecast/models/fedformer.py b/neuralforecast/models/fedformer.py index ad83203f9..df40bbd31 100644 --- a/neuralforecast/models/fedformer.py +++ b/neuralforecast/models/fedformer.py @@ -468,6 +468,8 @@ class FEDformer(BaseWindows): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
+ `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
""" @@ -511,6 +513,8 @@ def __init__( random_seed: int = 1, num_workers_loader: int = 0, drop_last_loader: bool = False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs, ): super(FEDformer, self).__init__( @@ -536,6 +540,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs, ) # Architecture diff --git a/neuralforecast/models/gru.py b/neuralforecast/models/gru.py index 3b4d7f9ac..c8290700a 100644 --- a/neuralforecast/models/gru.py +++ b/neuralforecast/models/gru.py @@ -51,6 +51,8 @@ class GRU(BaseRecurrent): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
+ `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
""" @@ -86,6 +88,8 @@ def __init__( random_seed=1, num_workers_loader=0, drop_last_loader=False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs ): super(GRU, self).__init__( @@ -108,6 +112,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs ) diff --git a/neuralforecast/models/informer.py b/neuralforecast/models/informer.py index 964c50922..c81b3b36a 100644 --- a/neuralforecast/models/informer.py +++ b/neuralforecast/models/informer.py @@ -212,6 +212,8 @@ class Informer(BaseWindows): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
+ `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
*References*
@@ -256,6 +258,8 @@ def __init__( random_seed: int = 1, num_workers_loader: int = 0, drop_last_loader: bool = False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs, ): super(Informer, self).__init__( @@ -282,6 +286,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs, ) diff --git a/neuralforecast/models/lstm.py b/neuralforecast/models/lstm.py index db14a9e22..9db945c8b 100644 --- a/neuralforecast/models/lstm.py +++ b/neuralforecast/models/lstm.py @@ -51,6 +51,8 @@ class LSTM(BaseRecurrent): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
+ `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
""" @@ -85,6 +87,8 @@ def __init__( random_seed=1, num_workers_loader=0, drop_last_loader=False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs ): super(LSTM, self).__init__( @@ -107,6 +111,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs ) diff --git a/neuralforecast/models/mlp.py b/neuralforecast/models/mlp.py index f781973e4..816b40005 100644 --- a/neuralforecast/models/mlp.py +++ b/neuralforecast/models/mlp.py @@ -49,6 +49,8 @@ class MLP(BaseWindows): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer to use instead of the default choice (Adam).&#x0D;
+ `optimizer_kwargs`: dict, optional, keyword arguments passed to the user-specified `optimizer`.&#x0D;
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
""" @@ -82,8 +84,11 @@ def __init__( random_seed: int = 1, num_workers_loader: int = 0, drop_last_loader: bool = False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs ): + # Inherit BaseWindows class super(MLP, self).__init__( h=h, @@ -109,6 +114,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs ) diff --git a/neuralforecast/models/nbeats.py b/neuralforecast/models/nbeats.py index 878756c24..dcb8fe8df 100644 --- a/neuralforecast/models/nbeats.py +++ b/neuralforecast/models/nbeats.py @@ -228,6 +228,8 @@ class NBEATS(BaseWindows): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
@@ -267,6 +269,8 @@ def __init__(
        random_seed: int = 1,
        num_workers_loader: int = 0,
        drop_last_loader: bool = False,
+        optimizer=None,
+        optimizer_kwargs=None,
        **trainer_kwargs,
    ):
        # Inherit BaseWindows class
@@ -290,6 +294,8 @@ def __init__(
            num_workers_loader=num_workers_loader,
            drop_last_loader=drop_last_loader,
            random_seed=random_seed,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
            **trainer_kwargs,
        )
diff --git a/neuralforecast/models/nbeatsx.py b/neuralforecast/models/nbeatsx.py
index f3ce97150..3a1f9152c 100644
--- a/neuralforecast/models/nbeatsx.py
+++ b/neuralforecast/models/nbeatsx.py
@@ -309,6 +309,8 @@ class NBEATSx(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
@@ -352,6 +354,8 @@ def __init__(
        random_seed: int = 1,
        num_workers_loader: int = 0,
        drop_last_loader: bool = False,
+        optimizer=None,
+        optimizer_kwargs=None,
        **trainer_kwargs,
    ):
        # Protect horizon collapsed seasonality and trend NBEATSx-i basis
@@ -385,6 +389,8 @@ def __init__(
            num_workers_loader=num_workers_loader,
            drop_last_loader=drop_last_loader,
            random_seed=random_seed,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
            **trainer_kwargs,
        )
diff --git a/neuralforecast/models/nhits.py b/neuralforecast/models/nhits.py
index 603d9241e..eb128b383 100644
--- a/neuralforecast/models/nhits.py
+++ b/neuralforecast/models/nhits.py
@@ -222,6 +222,8 @@ class NHITS(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
@@ -267,6 +269,8 @@ def __init__(
        random_seed: int = 1,
        num_workers_loader=0,
        drop_last_loader=False,
+        optimizer=None,
+        optimizer_kwargs=None,
        **trainer_kwargs,
    ):
        # Inherit BaseWindows class
@@ -294,6 +298,8 @@ def __init__(
            num_workers_loader=num_workers_loader,
            drop_last_loader=drop_last_loader,
            random_seed=random_seed,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
            **trainer_kwargs,
        )
diff --git a/neuralforecast/models/patchtst.py b/neuralforecast/models/patchtst.py
index 67d58ac56..9355d69a2 100644
--- a/neuralforecast/models/patchtst.py
+++ b/neuralforecast/models/patchtst.py
@@ -865,6 +865,8 @@ class PatchTST(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
@@ -916,6 +918,8 @@ def __init__(
        random_seed: int = 1,
        num_workers_loader: int = 0,
        drop_last_loader: bool = False,
+        optimizer=None,
+        optimizer_kwargs=None,
        **trainer_kwargs
    ):
        super(PatchTST, self).__init__(
@@ -942,6 +946,8 @@ def __init__(
            num_workers_loader=num_workers_loader,
            drop_last_loader=drop_last_loader,
            random_seed=random_seed,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
            **trainer_kwargs
        )
        # Asserts
diff --git a/neuralforecast/models/rnn.py b/neuralforecast/models/rnn.py
index c3d9eb003..6aadf0e5f 100644
--- a/neuralforecast/models/rnn.py
+++ b/neuralforecast/models/rnn.py
@@ -50,6 +50,8 @@ class RNN(BaseRecurrent):
`random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.
`alias`: str, optional, Custom name of the model.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
""" @@ -86,6 +88,8 @@ def __init__( random_seed=1, num_workers_loader=0, drop_last_loader=False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs ): super(RNN, self).__init__( @@ -108,6 +112,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs ) diff --git a/neuralforecast/models/stemgnn.py b/neuralforecast/models/stemgnn.py index 6df17e90b..6c427a4d9 100644 --- a/neuralforecast/models/stemgnn.py +++ b/neuralforecast/models/stemgnn.py @@ -161,6 +161,8 @@ class StemGNN(BaseMultivariate): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
""" @@ -192,6 +194,8 @@ def __init__( random_seed: int = 1, num_workers_loader=0, drop_last_loader=False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs ): # Inherit BaseMultivariate class @@ -215,6 +219,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs ) diff --git a/neuralforecast/models/tcn.py b/neuralforecast/models/tcn.py index 2cf270b04..7a31d1555 100644 --- a/neuralforecast/models/tcn.py +++ b/neuralforecast/models/tcn.py @@ -47,6 +47,8 @@ class TCN(BaseRecurrent): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
""" @@ -81,6 +83,8 @@ def __init__( random_seed: int = 1, num_workers_loader=0, drop_last_loader=False, + optimizer=None, + optimizer_kwargs=None, **trainer_kwargs ): super(TCN, self).__init__( @@ -103,6 +107,8 @@ def __init__( num_workers_loader=num_workers_loader, drop_last_loader=drop_last_loader, random_seed=random_seed, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, **trainer_kwargs ) diff --git a/neuralforecast/models/tft.py b/neuralforecast/models/tft.py index 9e96a14ce..914552b63 100644 --- a/neuralforecast/models/tft.py +++ b/neuralforecast/models/tft.py @@ -418,6 +418,8 @@ class TFT(BaseWindows): `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
@@ -457,6 +459,8 @@ def __init__(
        num_workers_loader=0,
        drop_last_loader=False,
        random_seed: int = 1,
+        optimizer=None,
+        optimizer_kwargs=None,
        **trainer_kwargs
    ):
        # Inherit BaseWindows class
@@ -480,6 +484,8 @@ def __init__(
            num_workers_loader=num_workers_loader,
            drop_last_loader=drop_last_loader,
            random_seed=random_seed,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
            **trainer_kwargs
        )
        self.example_length = input_size + h
diff --git a/neuralforecast/models/timesnet.py b/neuralforecast/models/timesnet.py
index 01f39e8c9..c1233b203 100644
--- a/neuralforecast/models/timesnet.py
+++ b/neuralforecast/models/timesnet.py
@@ -173,6 +173,10 @@ class TimesNet(BaseWindows):
        Workers to be used by `TimeSeriesDataLoader`.
    drop_last_loader : bool (default=False)
        If True `TimeSeriesDataLoader` drops last non-full batch.
+    optimizer : subclass of `torch.optim.Optimizer`, optional (default=None)
+        User-specified optimizer instead of the default choice (Adam).
+    optimizer_kwargs : dict, optional (default=None)
+        Keyword arguments used by the user-specified `optimizer`.
    **trainer_kwargs
        Keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer)
@@ -215,6 +219,8 @@ def __init__(
        random_seed: int = 1,
        num_workers_loader: int = 0,
        drop_last_loader: bool = False,
+        optimizer=None,
+        optimizer_kwargs=None,
        **trainer_kwargs
    ):
        super(TimesNet, self).__init__(
@@ -241,6 +247,8 @@ def __init__(
            num_workers_loader=num_workers_loader,
            drop_last_loader=drop_last_loader,
            random_seed=random_seed,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
            **trainer_kwargs
        )
diff --git a/neuralforecast/models/vanillatransformer.py b/neuralforecast/models/vanillatransformer.py
index 6f6dc201d..c754ab66b 100644
--- a/neuralforecast/models/vanillatransformer.py
+++ b/neuralforecast/models/vanillatransformer.py
@@ -113,6 +113,8 @@ class VanillaTransformer(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user-specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
*References*
@@ -154,6 +156,8 @@ def __init__(
        random_seed: int = 1,
        num_workers_loader: int = 0,
        drop_last_loader: bool = False,
+        optimizer=None,
+        optimizer_kwargs=None,
        **trainer_kwargs,
    ):
        super(VanillaTransformer, self).__init__(
@@ -179,6 +183,8 @@ def __init__(
            num_workers_loader=num_workers_loader,
            drop_last_loader=drop_last_loader,
            random_seed=random_seed,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
            **trainer_kwargs,
        )
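For context, a minimal usage sketch of the new arguments (not part of the patch itself): the model choice and the numeric values are illustrative placeholders, and `torch.optim.AdamW` stands in for any `torch.optim.Optimizer` subclass.

import torch
from neuralforecast.models import NHITS

# Pass the optimizer class itself (not an instance); any extra keyword
# arguments for that optimizer go in `optimizer_kwargs`. If `optimizer` is
# left as None, the model keeps its default choice (Adam).
model = NHITS(
    h=12,
    input_size=24,
    max_steps=100,
    optimizer=torch.optim.AdamW,
    optimizer_kwargs={"weight_decay": 1e-3},
)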