Skip to content

Commit

Permalink
add missing models to filename dict (#856)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Jan 12, 2024
1 parent 68affcd commit ddb03b0
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 38 deletions.
35 changes: 20 additions & 15 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,21 +198,26 @@
"outputs": [],
"source": [
"#| exporti\n",
"MODEL_FILENAME_DICT = {'gru': GRU, 'lstm': LSTM, 'rnn': RNN, \n",
" 'tcn': TCN, 'deepar': DeepAR, 'dilatedrnn': DilatedRNN,\n",
" 'mlp': MLP, 'nbeats': NBEATS, 'nbeatsx': NBEATSx, 'nhits': NHITS,\n",
" 'tft': TFT,\n",
" 'vanillatransformer': VanillaTransformer, 'informer': Informer, 'autoformer': Autoformer, 'patchtst': PatchTST,\n",
" 'stemgnn': StemGNN,\n",
" 'autogru': GRU, 'autolstm': LSTM, 'autornn': RNN,\n",
" 'autotcn': TCN, 'autodeepar': DeepAR, 'autodilatedrnn': DilatedRNN,\n",
" 'automlp': MLP, 'autonbeats': NBEATS, 'autonbeatsx': NBEATSx, 'autonhits': NHITS,\n",
" 'autotft': TFT,\n",
" 'autovanillatransformer': VanillaTransformer,'autoinformer': Informer, 'autoautoformer': Autoformer, 'autopatchtst': PatchTST,\n",
" 'autofedformer': FEDformer,\n",
" 'autostemgnn': StemGNN,\n",
" 'autotimesnet': TimesNet,\n",
" }"
"MODEL_FILENAME_DICT = {\n",
" 'autoformer': Autoformer, 'autoautoformer': Autoformer,\n",
" 'deepar': DeepAR, 'autodeepar': DeepAR,\n",
" 'dilatedrnn': DilatedRNN , 'autodilatedrnn': DilatedRNN,\n",
" 'fedformer': FEDformer, 'autofedformer': FEDformer,\n",
" 'gru': GRU, 'autogru': GRU,\n",
" 'informer': Informer, 'autoinformer': Informer,\n",
" 'lstm': LSTM, 'autolstm': LSTM,\n",
" 'mlp': MLP, 'automlp': MLP,\n",
" 'nbeats': NBEATS, 'autonbeats': NBEATS,\n",
" 'nbeatsx': NBEATSx, 'autonbeatsx': NBEATSx,\n",
" 'nhits': NHITS, 'autonhits': NHITS,\n",
" 'patchtst': PatchTST, 'autopatchtst': PatchTST,\n",
" 'rnn': RNN, 'autornn': RNN,\n",
" 'stemgnn': StemGNN, 'autostemgnn': StemGNN,\n",
" 'tcn': TCN, 'autotcn': TCN, \n",
" 'tft': TFT, 'autotft': TFT,\n",
" 'timesnet': TimesNet, 'autotimesnet': TimesNet,\n",
" 'vanillatransformer': VanillaTransformer, 'autovanillatransformer': VanillaTransformer,\n",
"}"
]
},
{
Expand Down
48 changes: 25 additions & 23 deletions neuralforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,40 +103,42 @@ def _insample_times(

# %% ../nbs/core.ipynb 7
MODEL_FILENAME_DICT = {
"gru": GRU,
"lstm": LSTM,
"rnn": RNN,
"tcn": TCN,
"autoformer": Autoformer,
"autoautoformer": Autoformer,
"deepar": DeepAR,
"autodeepar": DeepAR,
"dilatedrnn": DilatedRNN,
"mlp": MLP,
"nbeats": NBEATS,
"nbeatsx": NBEATSx,
"nhits": NHITS,
"tft": TFT,
"vanillatransformer": VanillaTransformer,
"informer": Informer,
"autoformer": Autoformer,
"patchtst": PatchTST,
"stemgnn": StemGNN,
"autodilatedrnn": DilatedRNN,
"fedformer": FEDformer,
"autofedformer": FEDformer,
"gru": GRU,
"autogru": GRU,
"informer": Informer,
"autoinformer": Informer,
"lstm": LSTM,
"autolstm": LSTM,
"autornn": RNN,
"autotcn": TCN,
"autodeepar": DeepAR,
"autodilatedrnn": DilatedRNN,
"mlp": MLP,
"automlp": MLP,
"nbeats": NBEATS,
"autonbeats": NBEATS,
"nbeatsx": NBEATSx,
"autonbeatsx": NBEATSx,
"nhits": NHITS,
"autonhits": NHITS,
"autotft": TFT,
"autovanillatransformer": VanillaTransformer,
"autoinformer": Informer,
"autoautoformer": Autoformer,
"patchtst": PatchTST,
"autopatchtst": PatchTST,
"autofedformer": FEDformer,
"rnn": RNN,
"autornn": RNN,
"stemgnn": StemGNN,
"autostemgnn": StemGNN,
"tcn": TCN,
"autotcn": TCN,
"tft": TFT,
"autotft": TFT,
"timesnet": TimesNet,
"autotimesnet": TimesNet,
"vanillatransformer": VanillaTransformer,
"autovanillatransformer": VanillaTransformer,
}

# %% ../nbs/core.ipynb 8
Expand Down

0 comments on commit ddb03b0

Please sign in to comment.