Remove call_configure_sharded_model lifecycle property (#9612)
ananthsub committed Sep 24, 2021
1 parent 2b2537d commit 41e3be1
Showing 9 changed files with 22 additions and 129 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -351,6 +351,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated properties `DeepSpeedPlugin.cpu_offload*` in favor of `offload_optimizer`, `offload_parameters` and `pin_memory` ([#9244](https://github.com/PyTorchLightning/pytorch-lightning/pull/9244))


- Removed `call_configure_sharded_model_hook` property from `Accelerator` and `TrainingTypePlugin` ([#9612](https://github.com/PyTorchLightning/pytorch-lightning/pull/9612))


### Fixed


14 changes: 0 additions & 14 deletions pytorch_lightning/accelerators/accelerator.py
@@ -401,20 +401,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
"""
self.training_type_plugin.save_checkpoint(checkpoint, filepath)

@property
def call_configure_sharded_model_hook(self) -> bool:
"""Allow model parallel hook to be called in suitable environments determined by the training type plugin.
This is useful for when we want to shard the model once within fit.
Returns:
True if we want to call the model parallel setup hook.
"""
return self.training_type_plugin.call_configure_sharded_model_hook

@call_configure_sharded_model_hook.setter
def call_configure_sharded_model_hook(self, mode: bool) -> None:
self.training_type_plugin.call_configure_sharded_model_hook = mode

@property
def setup_optimizers_in_pre_dispatch(self) -> bool:
"""Override to delay setting optimizers and schedulers till after dispatch. This is useful when the
8 changes: 2 additions & 6 deletions pytorch_lightning/core/hooks.py
@@ -297,12 +297,8 @@ def configure_sharded_model(self) -> None:
where we'd like to shard the model instantly, which can save memory and initialization time for extremely
large models.
The accelerator manages whether to call this hook at every given stage.
For sharded plugins where model parallelism is required, the hook is usually called only once
to initialize the sharded parameters, and it is not called again in the same process.
By default, for accelerators/plugins that do not use model sharding techniques,
this hook is called during each of the fit/val/test/predict stages.
This hook is called during each of the fit/val/test/predict stages in the same process, so ensure that
the implementation of this hook is idempotent.
"""


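For illustration only (not part of this commit's diff): a minimal sketch of an idempotent `configure_sharded_model` override in the spirit of the updated docstring above. The class name, the lazily built `self.block` attribute, and the layer sizes are assumptions made for this example.

import torch
import pytorch_lightning as pl


class ShardedExample(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.block = None  # built lazily inside the hook

    def configure_sharded_model(self) -> None:
        # the hook may now run once per fit/val/test/predict call in the same process,
        # so guard against rebuilding (and re-wrapping) the module
        if self.block is not None:
            return
        self.block = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))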
11 changes: 0 additions & 11 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -141,17 +141,6 @@ def wrap_policy(*args, **kwargs):
):
yield

def setup_environment(self) -> None:
super().setup_environment()
model_call_configure_sharded_model_hook = getattr(
self.lightning_module, "call_configure_sharded_model_hook", False
)
if not model_call_configure_sharded_model_hook:
# if model has not called configure sharded model, we reset
# the training type plugin's call_configure_sharded_model_hook
# to give trainer a chance to configure.
self.call_configure_sharded_model_hook = True

def configure_ddp(self) -> None:
if not self.cpu_offload:
# When using CPU Offload, FSDP will manage the CUDA movement for us.
14 changes: 0 additions & 14 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -39,7 +39,6 @@ def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None:
self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO()
self._checkpoint_io = checkpoint_io
self._call_configure_sharded_model_hook = True

@property
def checkpoint_io(self) -> CheckpointIO:
@@ -281,19 +280,6 @@ def model_sharded_context(self) -> Generator:
"""
yield

@property
def call_configure_sharded_model_hook(self) -> bool:
"""Allow model parallel hook to be called in suitable environments determined by the training type plugin.
This is useful for when we want to shard the model once within fit.
Returns: True if we want to call the model parallel setup hook.
"""
return self._call_configure_sharded_model_hook

@call_configure_sharded_model_hook.setter
def call_configure_sharded_model_hook(self, mode: bool) -> None:
self._call_configure_sharded_model_hook = mode

@abstractmethod
def teardown(self) -> None:
"""This method is called to teardown the training process.
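As a side note (not from this diff), `model_sharded_context` above remains the extension point through which a plugin controls the context in which `configure_sharded_model` runs. Below is a hedged sketch of a hypothetical custom plugin, reusing the `SingleDevicePlugin` base that also appears in the tests further down; the float64 default dtype is purely illustrative.

from contextlib import contextmanager
from typing import Generator

import torch
from pytorch_lightning.plugins import SingleDevicePlugin


class Float64InitPlugin(SingleDevicePlugin):
    """Hypothetical plugin: modules created inside configure_sharded_model default to float64."""

    @contextmanager
    def model_sharded_context(self) -> Generator:
        previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float64)
        try:
            yield
        finally:
            torch.set_default_dtype(previous_dtype)

Such a plugin would be passed to the Trainer the same way the removed test below passes its CustomPlugin, e.g. plugins=Float64InitPlugin(device=torch.device("cpu")).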
15 changes: 3 additions & 12 deletions pytorch_lightning/trainer/trainer.py
@@ -1291,18 +1291,9 @@ def _call_setup_hook(self) -> None:
self.accelerator.barrier("post_setup")

def _call_configure_sharded_model(self) -> None:
# Call the configure sharded model hook if the accelerator requests it. In some cases
# we will not call the hook; for example, when the hook has already initialized the sharded model.

# used on the model if the user re-creates a trainer with resume_from_checkpoint
model = self.lightning_module
model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
with self.accelerator.model_sharded_context():
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")
model.call_configure_sharded_model_hook = True
self.accelerator.call_configure_sharded_model_hook = False
with self.accelerator.model_sharded_context():
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")

def _call_teardown_hook(self) -> None:
fn = self.state.fn._setup_fn
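A hedged sketch (not from this commit) of what the simplified `_call_configure_sharded_model` above implies for users: the hook is now invoked unconditionally, once per Trainer entry point in the same process. `BoringModel` is the repository's test helper imported in the test files below; the expected count of two calls is an assumption about the new flow, not an assertion taken from this diff.

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # repo test helper, as used in the tests below


class HookCounterModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.hook_calls = 0

    def configure_sharded_model(self) -> None:
        self.hook_calls += 1


model = HookCounterModel()
trainer = Trainer(fast_dev_run=True)
trainer.fit(model)
trainer.validate(model)
assert model.hook_calls == 2  # one call per entry point (assumed), hence the idempotency requirement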
55 changes: 0 additions & 55 deletions tests/accelerators/test_common.py
@@ -16,7 +16,6 @@

import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleDevicePlugin
from tests.accelerators.test_dp import CustomClassificationModelDP
from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule
@@ -77,57 +76,3 @@ def configure_sharded_model(self):
trainer.fit(model)

assert model.configure_sharded_model_called


class DummyModel(BoringModel):
def __init__(self):
super().__init__()
self.configure_sharded_model_called = False

def configure_sharded_model(self):
self.configure_sharded_model_called = True


def test_configure_sharded_model_false(tmpdir):
"""Ensure ``configure_sharded_model`` is not called, when turned off."""

class CustomPlugin(SingleDevicePlugin):
@property
def call_configure_sharded_model_hook(self) -> bool:
return False

model = DummyModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
plugins=CustomPlugin(device=torch.device("cpu")),
)
trainer.fit(model)

assert not model.configure_sharded_model_called


def test_accelerator_configure_sharded_model_called_once(tmpdir):
"""Ensure that the configure sharded model hook is called, and that the flag is then set to False so the
hook is not called again."""

model = DummyModel()
trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1)
assert trainer.accelerator.call_configure_sharded_model_hook is True
trainer.fit(model)
assert trainer.accelerator.call_configure_sharded_model_hook is False


def test_configure_sharded_model_called_once(tmpdir):
"""Ensure ``configure_sharded_model`` is only called once."""

model = DummyModel()
trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1)
trainer.fit(model)

assert model.configure_sharded_model_called
model.configure_sharded_model_called = False

assert not model.configure_sharded_model_called
29 changes: 13 additions & 16 deletions tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py
@@ -49,27 +49,29 @@ def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir):


class TestFSDPModel(BoringModel):
def setup(self, stage: str) -> None:
if stage != "fit":
# when running stages like test, validate, and predict, we will skip setting up,
# will directly use the module itself unless we load from checkpoint
return
# resetting call_configure_sharded_model_hook attribute so that we could call
# configure sharded model
self.call_configure_sharded_model_hook = False
# for loading full state dict, we first need to create a new unwrapped model
# to load state dict and then wrapping
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.layer: Optional[torch.nn.Module] = None

def _init_model(self) -> None:
self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

def setup(self, stage: str) -> None:
if self.layer is None:
self._init_model()

def configure_sharded_model(self) -> None:
# the model is already wrapped with FSDP: no need to wrap again!
if isinstance(self.layer, FullyShardedDataParallel):
return
for i, layer in enumerate(self.layer):
if i % 2 == 0:
self.layer[i] = wrap(layer)
self.layer = wrap(self.layer)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# when loading full state dict, we first need to create a new unwrapped model
self.setup("fit")
self._init_model()

def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -131,13 +133,8 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.fit(model)

model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
trainer_accelerator_call_configure_sharded_model_hook = trainer.accelerator.call_configure_sharded_model_hook

model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

assert model_call_configure_sharded_model_hook
assert not trainer_accelerator_call_configure_sharded_model_hook
trainer.save_checkpoint(model_path, weights_only=True)

_assert_save_equality(trainer, model_path, cls=TestFSDPModel)
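For orientation (not part of the diff): a hedged sketch of how a model like `TestFSDPModel` is typically driven through `_run_multiple_stages`. The `plugins="fsdp"` alias, single GPU, 16-bit precision, and the explicit checkpoint path are assumptions about the usual test configuration rather than values taken from this file.

import torch
from pytorch_lightning import Trainer

if torch.cuda.is_available():
    model = TestFSDPModel()
    # assumed configuration: "fsdp" plugin alias, one GPU, 16-bit precision
    trainer = Trainer(default_root_dir=".", gpus=1, plugins="fsdp", precision=16, max_epochs=1)
    _run_multiple_stages(trainer, model, model_path="fsdp_model.ckpt")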
2 changes: 1 addition & 1 deletion tests/plugins/test_deepspeed_plugin.py
@@ -603,7 +603,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
run_checkpoint_test(tmpdir)


@RunIf(min_gpus=1, deepspeed=True, special=False)
@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir):
"""Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the
optimizer state and scheduler states cannot be restored."""
