Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
four4fish committed Oct 12, 2021
1 parent b30d9e5 commit f12eb07
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 54 deletions.
86 changes: 44 additions & 42 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: Trai
def connect(self, model: "pl.LightningModule") -> None:
"""Transfers ownership of the model to this plugin.
See deprecation warning below.
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
`training_type_plugin.connect` directly.
"""
Expand All @@ -73,18 +75,10 @@ def connect(self, model: "pl.LightningModule") -> None:
def setup_environment(self) -> None:
"""Setup any processes or distributed connections.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call `training_type_plugin.setup_environment` directly.
This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator
environment before setup is complete.
"""
rank_zero_deprecation(
"`Accelerator.setup_environment` is deprecated in v1.5 and will be removed in v1.6. "
"`setup_environment` logic is implemented directly in the `TrainingTypePlugin` implementations."
)
self.training_type_plugin.setup_environment()
self.accelerator.setup_environment()

def setup(self, trainer: "pl.Trainer") -> None:
"""Setup plugins for the trainer fit and creates optimizers.
Expand Down Expand Up @@ -189,16 +183,8 @@ def root_device(self) -> torch.device:
def teardown(self) -> None:
"""This method is called to teardown the training process.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call `training_type_plugin.teardown` directly.
It is the right place to release memory and free other resources.
"""
rank_zero_deprecation(
"`Accelerator.teardown` is deprecated in v1.5 and will be removed in v1.6. "
"`teardown` logic is implemented directly in the `TrainingTypePlugin` implementations."
)
self.training_type_plugin.teardown()

def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
Expand Down Expand Up @@ -588,22 +574,17 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
raise NotImplementedError

def on_train_start(self) -> None:
"""Called when train begins.
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
`training_type_plugin.on_train_start` directly.
"""
rank_zero_deprecation(
"`Accelerator.on_train_start` is deprecated in v1.5 and will be removed in v1.6. "
"`on_train_start` logic is implemented directly in the `TrainingTypePlugin` implementations."
)
"""Called when train begins."""
return self.training_type_plugin.on_train_start()

def on_validation_start(self) -> None:
"""Called when validation begins.
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
`training_type_plugin.on_validation_start` directly.
See deprecation warning below.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call `training_type_plugin.on_validation_start` directly.
"""
rank_zero_deprecation(
"`Accelerator.on_validation_start` is deprecated in v1.5 and will be removed in v1.6. "
Expand All @@ -614,8 +595,11 @@ def on_validation_start(self) -> None:
def on_test_start(self) -> None:
"""Called when test begins.
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
`training_type_plugin.on_test_start` directly.
See deprecation warning below.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call `training_type_plugin.on_test_start` directly.
"""
rank_zero_deprecation(
"`Accelerator.on_test_start` is deprecated in v1.5 and will be removed in v1.6. "
Expand All @@ -626,8 +610,11 @@ def on_test_start(self) -> None:
def on_predict_start(self) -> None:
"""Called when predict begins.
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
`training_type_plugin.on_predict_start` directly.
See deprecation warning below.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call `training_type_plugin.on_predict_start` directly.
"""
rank_zero_deprecation(
"`Accelerator.on_predict_start` is deprecated in v1.5 and will be removed in v1.6. "
Expand All @@ -638,8 +625,11 @@ def on_predict_start(self) -> None:
def on_validation_end(self) -> None:
"""Called when validation ends.
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
`training_type_plugin.on_validation_end` directly.
See deprecation warning below.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call `training_type_plugin.on_validation_end` directly.
"""
rank_zero_deprecation(
"`Accelerator.on_validation_end` is deprecated in v1.5 and will be removed in v1.6. "
Expand All @@ -650,8 +640,11 @@ def on_validation_end(self) -> None:
def on_test_end(self) -> None:
"""Called when test end.
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
`training_type_plugin.on_test_end` directly.
See deprecation warning below.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call `training_type_plugin.on_test_end` directly.
"""
rank_zero_deprecation(
"`Accelerator.on_test_end` is deprecated in v1.5 and will be removed in v1.6. "
Expand All @@ -662,8 +655,11 @@ def on_test_end(self) -> None:
def on_predict_end(self) -> None:
"""Called when predict ends.
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
`training_type_plugin.on_predict_end` directly.
See deprecation warning below.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call `training_type_plugin.on_predict_end` directly.
"""
rank_zero_deprecation(
"`Accelerator.on_predict_end` is deprecated in v1.5 and will be removed in v1.6. "
Expand All @@ -674,8 +670,11 @@ def on_predict_end(self) -> None:
def on_train_end(self) -> None:
"""Called when train ends.
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
`training_type_plugin.on_train_end` directly.
See deprecation warning below.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call `training_type_plugin.on_train_end` directly.
"""
rank_zero_deprecation(
"`Accelerator.on_train_end` is deprecated in v1.5 and will be removed in v1.6. "
Expand All @@ -687,8 +686,11 @@ def on_train_end(self) -> None:
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
"""Called in the training loop before anything happens for that batch.
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
`training_type_plugin.on_train_batch_start` directly.
See deprecation warning below.
.. deprecated:: v1.5
This method is deprecated in v1.5 and will be removed in v1.6.
Please call `training_type_plugin.on_train_batch_start` directly.
"""
rank_zero_deprecation(
"`Accelerator.on_train_batch_start` is deprecated in v1.5 and will be removed in v1.6. "
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
# SET UP TRAINING
# ----------------------------
self.call_hook("on_before_accelerator_backend_setup")
self.training_type_plugin.setup_environment()
self.accelerator.setup_environment()
self._call_setup_hook() # allow user to setup lightning_module in accelerator environment

# check if we should delay restoring checkpoint till later
Expand Down Expand Up @@ -1147,7 +1147,7 @@ def _post_dispatch(self):
self.accelerator.post_dispatch(self)
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
# which need to happen before.
self.training_type_plugin.teardown()
self.accelerator.teardown()
self.data_connector.teardown()
self._active_loop.teardown()
self.logger_connector.teardown()
Expand Down
8 changes: 1 addition & 7 deletions tests/deprecated_api/test_remove_1-6.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import():
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin # noqa: F401


def test_v1_6_0_deprecated_accelerator_collective():
def test_v1_6_0_deprecated_accelerator_pass_through_functions():
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type import SingleDevicePlugin

Expand All @@ -352,9 +352,6 @@ def test_v1_6_0_deprecated_accelerator_collective():
model = BoringModel()
accelerator.connect(model)

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.setup_environment()

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.teardown()

Expand Down Expand Up @@ -389,9 +386,6 @@ def test_v1_6_0_deprecated_accelerator_collective():
with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.restore_checkpoint_after_pre_dispatch

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_train_start()

with pytest.deprecated_call(match="will be removed in v1.6"):
accelerator.on_validation_start()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_invalid_on_cpu(tmpdir):
):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins="fsdp")
assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin)
trainer.training_type_plugin.setup_environment()
trainer.accelerator.setup_environment()


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
Expand Down
4 changes: 2 additions & 2 deletions tests/plugins/test_ddp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_ddp_configure_ddp():
# test wrap the model if fitting
trainer.state.fn = TrainerFn.FITTING
trainer.training_type_plugin.connect(model)
trainer.training_type_plugin.setup_environment()
trainer.accelerator.setup_environment()
trainer.accelerator.setup(trainer)
trainer.lightning_module.trainer = trainer
assert isinstance(trainer.model, LightningModule)
Expand All @@ -123,7 +123,7 @@ def test_ddp_configure_ddp():
)
# test do not wrap the model if trainerFN is not fitting
trainer.training_type_plugin.connect(model)
trainer.training_type_plugin.setup_environment()
trainer.accelerator.setup_environment()
trainer.accelerator.setup(trainer)
trainer.lightning_module.trainer = trainer
trainer._pre_dispatch()
Expand Down

0 comments on commit f12eb07

Please sign in to comment.