From 41e3be197f5a2fd0f65b37b743ebfd157a55595d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 23 Sep 2021 18:57:53 -0700 Subject: [PATCH] Remove `call_configure_sharded_model` lifecycle property (#9612) --- CHANGELOG.md | 3 + pytorch_lightning/accelerators/accelerator.py | 14 ----- pytorch_lightning/core/hooks.py | 8 +-- .../plugins/training_type/fully_sharded.py | 11 ---- .../training_type/training_type_plugin.py | 14 ----- pytorch_lightning/trainer/trainer.py | 15 +---- tests/accelerators/test_common.py | 55 ------------------- ..._ddp_fully_sharded_with_full_state_dict.py | 29 +++++----- tests/plugins/test_deepspeed_plugin.py | 2 +- 9 files changed, 22 insertions(+), 129 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d85d89169a928..a8c8284fb1062 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -351,6 +351,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated properties `DeepSpeedPlugin.cpu_offload*` in favor of `offload_optimizer`, `offload_parameters` and `pin_memory` ([#9244](https://github.com/PyTorchLightning/pytorch-lightning/pull/9244)) +- Removed `call_configure_sharded_model_hook` property from `Accelerator` and `TrainingTypePlugin` ([#9612](https://github.com/PyTorchLightning/pytorch-lightning/pull/9612)) + + ### Fixed diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 3036fd83ebf22..c6aa2f75f7b81 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -401,20 +401,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None: """ self.training_type_plugin.save_checkpoint(checkpoint, filepath) - @property - def call_configure_sharded_model_hook(self) -> bool: - """Allow model parallel hook to be called in suitable environments determined by the training type plugin. - This is useful for when we want to shard the model once within fit. - - Returns: - True if we want to call the model parallel setup hook. - """ - return self.training_type_plugin.call_configure_sharded_model_hook - - @call_configure_sharded_model_hook.setter - def call_configure_sharded_model_hook(self, mode: bool) -> None: - self.training_type_plugin.call_configure_sharded_model_hook = mode - @property def setup_optimizers_in_pre_dispatch(self) -> bool: """Override to delay setting optimizers and schedulers till after dispatch. This is useful when the diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index f49b2f0fc0396..4f2161fd03afd 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -297,12 +297,8 @@ def configure_sharded_model(self) -> None: where we'd like to shard the model instantly, which is useful for extremely large models which can save memory and initialization time. - The accelerator manages whether to call this hook at every given stage. - For sharded plugins where model parallelism is required, the hook is usually on called once - to initialize the sharded parameters, and not called again in the same process. - - By default for accelerators/plugins that do not use model sharding techniques, - this hook is called during each fit/val/test/predict stages. + This hook is called during each of fit/val/test/predict stages in the same process, so ensure that + implementation of this hook is idempotent. """ diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 72338e2923c07..74f30b76e383f 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -141,17 +141,6 @@ def wrap_policy(*args, **kwargs): ): yield - def setup_environment(self) -> None: - super().setup_environment() - model_call_configure_sharded_model_hook = getattr( - self.lightning_module, "call_configure_sharded_model_hook", False - ) - if not model_call_configure_sharded_model_hook: - # if model has not called configure sharded model, we reset - # the training type plugin's call_configure_sharded_model_hook - # to give trainer a chance to configure. - self.call_configure_sharded_model_hook = True - def configure_ddp(self) -> None: if not self.cpu_offload: # When using CPU Offload, FSDP will manage the CUDA movement for us. diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 13d6f93f5fb97..675b5bc953503 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -39,7 +39,6 @@ def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None: self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io - self._call_configure_sharded_model_hook = True @property def checkpoint_io(self) -> CheckpointIO: @@ -281,19 +280,6 @@ def model_sharded_context(self) -> Generator: """ yield - @property - def call_configure_sharded_model_hook(self) -> bool: - """Allow model parallel hook to be called in suitable environments determined by the training type plugin. - - This is useful for when we want to shard the model once within fit. - Returns: True if we want to call the model parallel setup hook. - """ - return self._call_configure_sharded_model_hook - - @call_configure_sharded_model_hook.setter - def call_configure_sharded_model_hook(self, mode: bool) -> None: - self._call_configure_sharded_model_hook = mode - @abstractmethod def teardown(self) -> None: """This method is called to teardown the training process. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 581ff11554cb3..7f8ca97b28647 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1291,18 +1291,9 @@ def _call_setup_hook(self) -> None: self.accelerator.barrier("post_setup") def _call_configure_sharded_model(self) -> None: - # Call configure sharded model hook if accelerator requests. In some cases - # we will not call the hook; the hook has initialized the sharded model for example. - - # used on the model if the user re-create a trainer with resume_from_checkpoint - model = self.lightning_module - model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) - if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook: - with self.accelerator.model_sharded_context(): - self.call_hook("configure_sharded_model") - self.call_hook("on_configure_sharded_model") - model.call_configure_sharded_model_hook = True - self.accelerator.call_configure_sharded_model_hook = False + with self.accelerator.model_sharded_context(): + self.call_hook("configure_sharded_model") + self.call_hook("on_configure_sharded_model") def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 61f0a1e247215..93564e27defa9 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -16,7 +16,6 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.plugins import SingleDevicePlugin from tests.accelerators.test_dp import CustomClassificationModelDP from tests.helpers.boring_model import BoringModel from tests.helpers.datamodules import ClassifDataModule @@ -77,57 +76,3 @@ def configure_sharded_model(self): trainer.fit(model) assert model.configure_sharded_model_called - - -class DummyModel(BoringModel): - def __init__(self): - super().__init__() - self.configure_sharded_model_called = False - - def configure_sharded_model(self): - self.configure_sharded_model_called = True - - -def test_configure_sharded_model_false(tmpdir): - """Ensure ``configure_sharded_model`` is not called, when turned off.""" - - class CustomPlugin(SingleDevicePlugin): - @property - def call_configure_sharded_model_hook(self) -> bool: - return False - - model = DummyModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - plugins=CustomPlugin(device=torch.device("cpu")), - ) - trainer.fit(model) - - assert not model.configure_sharded_model_called - - -def test_accelerator_configure_sharded_model_called_once(tmpdir): - """Ensure that the configure sharded model hook is called, and set to False after to ensure not called - again.""" - - model = DummyModel() - trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1) - assert trainer.accelerator.call_configure_sharded_model_hook is True - trainer.fit(model) - assert trainer.accelerator.call_configure_sharded_model_hook is False - - -def test_configure_sharded_model_called_once(tmpdir): - """Ensure ``configure_sharded_model`` is only called once.""" - - model = DummyModel() - trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1) - trainer.fit(model) - - assert model.configure_sharded_model_called - model.configure_sharded_model_called = False - - assert not model.configure_sharded_model_called diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index c9c29d31c42ae..473c2dfb185a0 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -49,19 +49,21 @@ def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): class TestFSDPModel(BoringModel): - def setup(self, stage: str) -> None: - if stage != "fit": - # when running stages like test, validate, and predict, we will skip setting up, - # will directly use the module itself unless we load from checkpoint - return - # resetting call_configure_sharded_model_hook attribute so that we could call - # configure sharded model - self.call_configure_sharded_model_hook = False - # for loading full state dict, we first need to create a new unwrapped model - # to load state dict and then wrapping + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.layer: Optional[torch.nn.Module] = None + + def _init_model(self) -> None: self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + def setup(self, stage: str) -> None: + if self.layer is None: + self._init_model() + def configure_sharded_model(self) -> None: + # the model is already wrapped with FSDP: no need to wrap again! + if isinstance(self.layer, FullyShardedDataParallel): + return for i, layer in enumerate(self.layer): if i % 2 == 0: self.layer[i] = wrap(layer) @@ -69,7 +71,7 @@ def configure_sharded_model(self) -> None: def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # when loading full state dict, we first need to create a new unwrapped model - self.setup("fit") + self._init_model() def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -131,13 +133,8 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): trainer.fit(model) - model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) - trainer_accelerator_call_configure_sharded_model_hook = trainer.accelerator.call_configure_sharded_model_hook - model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path - assert model_call_configure_sharded_model_hook - assert not trainer_accelerator_call_configure_sharded_model_hook trainer.save_checkpoint(model_path, weights_only=True) _assert_save_equality(trainer, model_path, cls=TestFSDPModel) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index c7ccaab3e72f4..a351237ec5b3a 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -603,7 +603,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): run_checkpoint_test(tmpdir) -@RunIf(min_gpus=1, deepspeed=True, special=False) +@RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the optimizer state and scheduler states cannot be restored."""