From ddb03b066270d0474db43e298cc36eeee18eee6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Fri, 12 Jan 2024 13:01:33 -0600 Subject: [PATCH] add missing models to filename dict (#856) --- nbs/core.ipynb | 35 +++++++++++++++++------------- neuralforecast/core.py | 48 ++++++++++++++++++++++-------------------- 2 files changed, 45 insertions(+), 38 deletions(-) diff --git a/nbs/core.ipynb b/nbs/core.ipynb index eeb201034..5e6761771 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -198,21 +198,26 @@ "outputs": [], "source": [ "#| exporti\n", - "MODEL_FILENAME_DICT = {'gru': GRU, 'lstm': LSTM, 'rnn': RNN, \n", - " 'tcn': TCN, 'deepar': DeepAR, 'dilatedrnn': DilatedRNN,\n", - " 'mlp': MLP, 'nbeats': NBEATS, 'nbeatsx': NBEATSx, 'nhits': NHITS,\n", - " 'tft': TFT,\n", - " 'vanillatransformer': VanillaTransformer, 'informer': Informer, 'autoformer': Autoformer, 'patchtst': PatchTST,\n", - " 'stemgnn': StemGNN,\n", - " 'autogru': GRU, 'autolstm': LSTM, 'autornn': RNN,\n", - " 'autotcn': TCN, 'autodeepar': DeepAR, 'autodilatedrnn': DilatedRNN,\n", - " 'automlp': MLP, 'autonbeats': NBEATS, 'autonbeatsx': NBEATSx, 'autonhits': NHITS,\n", - " 'autotft': TFT,\n", - " 'autovanillatransformer': VanillaTransformer,'autoinformer': Informer, 'autoautoformer': Autoformer, 'autopatchtst': PatchTST,\n", - " 'autofedformer': FEDformer,\n", - " 'autostemgnn': StemGNN,\n", - " 'autotimesnet': TimesNet,\n", - " }" + "MODEL_FILENAME_DICT = {\n", + " 'autoformer': Autoformer, 'autoautoformer': Autoformer,\n", + " 'deepar': DeepAR, 'autodeepar': DeepAR,\n", + " 'dilatedrnn': DilatedRNN , 'autodilatedrnn': DilatedRNN,\n", + " 'fedformer': FEDformer, 'autofedformer': FEDformer,\n", + " 'gru': GRU, 'autogru': GRU,\n", + " 'informer': Informer, 'autoinformer': Informer,\n", + " 'lstm': LSTM, 'autolstm': LSTM,\n", + " 'mlp': MLP, 'automlp': MLP,\n", + " 'nbeats': NBEATS, 'autonbeats': NBEATS,\n", + " 'nbeatsx': NBEATSx, 'autonbeatsx': NBEATSx,\n", + " 'nhits': NHITS, 'autonhits': NHITS,\n", + " 'patchtst': PatchTST, 'autopatchtst': PatchTST,\n", + " 'rnn': RNN, 'autornn': RNN,\n", + " 'stemgnn': StemGNN, 'autostemgnn': StemGNN,\n", + " 'tcn': TCN, 'autotcn': TCN, \n", + " 'tft': TFT, 'autotft': TFT,\n", + " 'timesnet': TimesNet, 'autotimesnet': TimesNet,\n", + " 'vanillatransformer': VanillaTransformer, 'autovanillatransformer': VanillaTransformer,\n", + "}" ] }, { diff --git a/neuralforecast/core.py b/neuralforecast/core.py index 285253a08..86673dcb9 100644 --- a/neuralforecast/core.py +++ b/neuralforecast/core.py @@ -103,40 +103,42 @@ def _insample_times( # %% ../nbs/core.ipynb 7 MODEL_FILENAME_DICT = { - "gru": GRU, - "lstm": LSTM, - "rnn": RNN, - "tcn": TCN, + "autoformer": Autoformer, + "autoautoformer": Autoformer, "deepar": DeepAR, + "autodeepar": DeepAR, "dilatedrnn": DilatedRNN, - "mlp": MLP, - "nbeats": NBEATS, - "nbeatsx": NBEATSx, - "nhits": NHITS, - "tft": TFT, - "vanillatransformer": VanillaTransformer, - "informer": Informer, - "autoformer": Autoformer, - "patchtst": PatchTST, - "stemgnn": StemGNN, + "autodilatedrnn": DilatedRNN, + "fedformer": FEDformer, + "autofedformer": FEDformer, + "gru": GRU, "autogru": GRU, + "informer": Informer, + "autoinformer": Informer, + "lstm": LSTM, "autolstm": LSTM, - "autornn": RNN, - "autotcn": TCN, - "autodeepar": DeepAR, - "autodilatedrnn": DilatedRNN, + "mlp": MLP, "automlp": MLP, + "nbeats": NBEATS, "autonbeats": NBEATS, + "nbeatsx": NBEATSx, "autonbeatsx": NBEATSx, + "nhits": NHITS, "autonhits": NHITS, - "autotft": TFT, - "autovanillatransformer": VanillaTransformer, - "autoinformer": Informer, - "autoautoformer": Autoformer, + "patchtst": PatchTST, "autopatchtst": PatchTST, - "autofedformer": FEDformer, + "rnn": RNN, + "autornn": RNN, + "stemgnn": StemGNN, "autostemgnn": StemGNN, + "tcn": TCN, + "autotcn": TCN, + "tft": TFT, + "autotft": TFT, + "timesnet": TimesNet, "autotimesnet": TimesNet, + "vanillatransformer": VanillaTransformer, + "autovanillatransformer": VanillaTransformer, } # %% ../nbs/core.ipynb 8