From 71ee427013fd341ef19abcfac65e9f6b39aea803 Mon Sep 17 00:00:00 2001 From: Phil Gaudreau Date: Fri, 25 Aug 2023 16:29:53 -0700 Subject: [PATCH 1/2] Enables extra parameters to be passed to the load_from_checkpoint function # Summary: When someone trains a model using a GPU and then tries to this model on a machine with CPUs only. They are presented with the following error message ``` RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU. ``` If one specifies the following: ``` MODEL_FILENAME_DICT[model_name].load_from_checkpoint(f"{path}/{model}", map_location=torch.device('cpu')) ``` The error is remedied. This PR generalizes the error above, by allowing users to pass additional arguments to the `load_from_checkpoint` to avoid such complications. --- neuralforecast/core.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/neuralforecast/core.py b/neuralforecast/core.py index d195dd545..ed1416e66 100644 --- a/neuralforecast/core.py +++ b/neuralforecast/core.py @@ -662,7 +662,7 @@ def save( pickle.dump(config_dict, f) @staticmethod - def load(path, verbose=False): + def load(path, verbose=False, **kwargs): """Load NeuralForecast `core.NeuralForecast`'s method to load checkpoint from path. @@ -671,7 +671,10 @@ def load(path, verbose=False): ----------- path : str Directory to save current status. - + **kwargs : + Additional keyword arguments to be passed to the function + `load_from_checkpoint`. + Returns ------- result : NeuralForecast @@ -690,7 +693,7 @@ def load(path, verbose=False): for model in models_ckpt: model_name = model.split("_")[0] models.append( - MODEL_FILENAME_DICT[model_name].load_from_checkpoint(f"{path}/{model}") + MODEL_FILENAME_DICT[model_name].load_from_checkpoint(f"{path}/{model}", **kwargs) ) if verbose: print(f"Model {model_name} loaded.") From 77e924d9ec79f3257655521507f186a33dfc50d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Tue, 29 Aug 2023 10:49:53 -0600 Subject: [PATCH 2/2] update notebook --- nbs/core.ipynb | 43 +++++++++++++++++++++++++++--------------- neuralforecast/core.py | 8 +++++--- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/nbs/core.ipynb b/nbs/core.ipynb index 0ae9f8ce2..85ca772f5 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -714,7 +714,7 @@ " pickle.dump(config_dict, f)\n", "\n", " @staticmethod\n", - " def load(path, verbose=False):\n", + " def load(path, verbose=False, **kwargs):\n", " \"\"\"Load NeuralForecast\n", "\n", " `core.NeuralForecast`'s method to load checkpoint from path.\n", @@ -723,7 +723,10 @@ " -----------\n", " path : str\n", " Directory to save current status.\n", - " \n", + " kwargs\n", + " Additional keyword arguments to be passed to the function\n", + " `load_from_checkpoint`.\n", + "\n", " Returns\n", " -------\n", " result : NeuralForecast\n", @@ -740,7 +743,7 @@ " models = []\n", " for model in models_ckpt:\n", " model_name = model.split('_')[0]\n", - " models.append(MODEL_FILENAME_DICT[model_name].load_from_checkpoint(f\"{path}/{model}\"))\n", + " models.append(MODEL_FILENAME_DICT[model_name].load_from_checkpoint(f\"{path}/{model}\", **kwargs))\n", " if verbose: print(f\"Model {model_name} loaded.\")\n", "\n", " if verbose: print(10*'-' + ' Loading dataset ' + 10*'-')\n", @@ -833,6 +836,26 @@ "show_doc(NeuralForecast.predict_insample, title_level=3)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "93155738-b40f-43d3-ba76-d345bf2583d5", + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(NeuralForecast.save, title_level=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e915796-173c-4400-812f-c6351d5df3be", + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(NeuralForecast.load, title_level=3)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -911,8 +934,8 @@ "metadata": {}, "outputs": [], "source": [ - "#| hide\\n\",\n", - "# test fit+cross_validation behaviour\\n\",\n", + "#| hide\n", + "# test fit+cross_validation behaviour\n", "models = [NHITS(h=12, input_size=24, max_steps=10)]\n", "nf = NeuralForecast(models=models, freq='M')\n", "nf.fit(AirPassengersPanel_train)\n", @@ -1435,16 +1458,6 @@ "assert valid_losses[-1][1] > 30, 'Validation loss is too low'" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "8b83c08b", - "metadata": {}, - "outputs": [], - "source": [ - "AirPassengersPanel_train" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/neuralforecast/core.py b/neuralforecast/core.py index ed1416e66..0216ccb0e 100644 --- a/neuralforecast/core.py +++ b/neuralforecast/core.py @@ -671,10 +671,10 @@ def load(path, verbose=False, **kwargs): ----------- path : str Directory to save current status. - **kwargs : + **kwargs Additional keyword arguments to be passed to the function `load_from_checkpoint`. - + Returns ------- result : NeuralForecast @@ -693,7 +693,9 @@ def load(path, verbose=False, **kwargs): for model in models_ckpt: model_name = model.split("_")[0] models.append( - MODEL_FILENAME_DICT[model_name].load_from_checkpoint(f"{path}/{model}", **kwargs) + MODEL_FILENAME_DICT[model_name].load_from_checkpoint( + f"{path}/{model}", **kwargs + ) ) if verbose: print(f"Model {model_name} loaded.")