From bad9863a7767158dec31bc5f9f6ab6b415bfa777 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 23 Nov 2023 16:28:21 +0100
Subject: [PATCH 1/4] Reorder `configure_model`

---
 src/lightning/pytorch/trainer/trainer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 92b21436cfa1a..ae0fc7756fbf2 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -947,15 +947,14 @@ def _run(
         self.__setup_profiler()
 
         call._call_setup_hook(self)  # allow user to setup lightning_module in accelerator environment
+        log.debug(f"{self.__class__.__name__}: configuring model")
+        call._call_configure_model(self)
 
         # check if we should delay restoring checkpoint till later
         if not self.strategy.restore_checkpoint_after_setup:
             log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
             self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
 
-        log.debug(f"{self.__class__.__name__}: configuring model")
-        call._call_configure_model(self)
-
         # reset logger connector
         self._logger_connector.reset_results()
         self._logger_connector.reset_metrics()

From 34366f3eefb0f0e660ef36aefd3d41c7f2604cdd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 23 Nov 2023 17:08:33 +0100
Subject: [PATCH 2/4] Hook order test

---
 tests/tests_pytorch/models/test_hooks.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index 5794f790f700c..978fe0ab6b740 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -576,10 +576,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
         {"name": "prepare_data"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
+        {"name": "configure_model"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
-        {"name": "configure_model"},
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
         {"name": "on_fit_start"},
@@ -654,10 +654,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
         {"name": "prepare_data"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
+        {"name": "configure_model"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
-        {"name": "configure_model"},
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
         {"name": "on_fit_start"},

From 1610be9b8011ce82619f9e319f3683771ad6246c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 23 Nov 2023 20:34:56 +0100
Subject: [PATCH 3/4] Update docs

---
 docs/source-pytorch/common/lightning_module.rst | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst
index 3f3fd0758bc59..f358abb3b6a98 100644
--- a/docs/source-pytorch/common/lightning_module.rst
+++ b/docs/source-pytorch/common/lightning_module.rst
@@ -1065,22 +1065,19 @@ for more information.
 
 .. code-block:: python
 
+    # devices can be GPUs, TPUs, ...
+    @on_every_device
     def fit(self):
+        configure_callbacks()
+
         if global_rank == 0:
             # prepare data is called on GLOBAL_ZERO only
             prepare_data()
 
-        configure_callbacks()
-
-        with parallel(devices):
-            # devices can be GPUs, TPUs, ...
-            train_on_device(model)
-
-
-    def train_on_device(model):
-        # called PER DEVICE
         setup("fit")
+        configure_model()
         configure_optimizers()
+
         on_fit_start()
 
         # the sanity check runs here

From 8a2f2fa9f1a892338d62d3c882ec17ebd872a8c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 24 Nov 2023 03:50:27 +0100
Subject: [PATCH 4/4] Update lightning_module.rst

---
 docs/source-pytorch/common/lightning_module.rst | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst
index f358abb3b6a98..15e3af75d7aec 100644
--- a/docs/source-pytorch/common/lightning_module.rst
+++ b/docs/source-pytorch/common/lightning_module.rst
@@ -1065,13 +1065,11 @@ for more information.
 
 .. code-block:: python
 
-    # devices can be GPUs, TPUs, ...
-    @on_every_device
+    # runs on every device: devices can be GPUs, TPUs, ...
     def fit(self):
         configure_callbacks()
 
-        if global_rank == 0:
-            # prepare data is called on GLOBAL_ZERO only
+        if local_rank == 0:
             prepare_data()
 
         setup("fit")
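
For illustration only (not part of the patches above): a minimal sketch of a ``LightningModule`` that builds its layers in ``configure_model``, the hook whose position these patches change. The class name, layer sizes, and loss are placeholders. With the reordering, ``configure_model`` runs right after ``setup("fit")`` and before the checkpoint is restored, so the layers created here already exist when the checkpoint's state dict is loaded (as the hook-order test in patch 2 asserts).

.. code-block:: python

    import torch
    import lightning.pytorch as pl


    class LitTransformer(pl.LightningModule):
        # Hypothetical example: layers are created in `configure_model` rather than
        # `__init__`, so strategies can defer or shard their instantiation.

        def __init__(self):
            super().__init__()
            self.model = None

        def configure_model(self):
            # Runs once per device after `setup("fit")`; with this patch series it also
            # runs before checkpoint restore, so the layers exist when weights are loaded.
            if self.model is None:  # guard against being called more than once
                layer = torch.nn.TransformerEncoderLayer(d_model=128, nhead=4)
                self.model = torch.nn.TransformerEncoder(layer, num_layers=2)

        def training_step(self, batch, batch_idx):
            return self.model(batch).mean()

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)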