From 6b728713bb3b35ad58cd0085acaa443b33ab03ac Mon Sep 17 00:00:00 2001 From: shabie <30535146+shabie@users.noreply.github.com> Date: Thu, 18 Nov 2021 18:29:13 +0100 Subject: [PATCH 01/10] log metrics for correct dataloader only (#10522) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: tchaton Co-authored-by: Carlos Mocholí --- .../logger_connector/logger_connector.py | 17 +++++++++--- .../logging_/test_eval_loop_logging.py | 27 +++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 37fcb06a1dc24..640fc667705a8 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -154,6 +154,19 @@ def update_eval_step_metrics(self) -> None: # increment the step even if nothing was logged self._increment_eval_log_step() + @staticmethod + def _filter_metrics_for_dataloader( + dl_idx: int, metrics: Dict[str, Union[Any, Dict[str, Any]]], metric_prefix: str = "dataloader_idx" + ) -> Dict[str, Union[Any, Dict[str, Any]]]: + result = {} + for k, v in metrics.items(): + if metric_prefix not in k: + result[k] = v + continue + if k.endswith(f"{metric_prefix}_{dl_idx}"): + result[k] = v + return result + def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None: if self.trainer.sanity_checking: return @@ -162,9 +175,7 @@ def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None: has_been_initialized = len(self.eval_loop_results) == num_dataloaders for dl_idx in range(self.trainer._evaluation_loop.num_dataloaders): # remove callback metrics that don't belong to this dataloader - callback_metrics = { - k: v for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k - } + callback_metrics = self._filter_metrics_for_dataloader(dl_idx, metrics) if has_been_initialized: self.eval_loop_results[dl_idx].update(callback_metrics) else: diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 6ed40b5f03082..88229effbc8c9 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -23,6 +23,7 @@ from pytorch_lightning import callbacks, Trainer from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset @@ -672,3 +673,29 @@ def val_dataloader(self): enable_model_summary=False, ) trainer.fit(model) + + +@pytest.mark.parametrize( + ["kwargs", "expected"], + [ + ({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}), + ( + {"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}}, + {"acc/dataloader_idx_0": 123}, + ), + ( + {"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}}, + {"acc/dataloader_idx_10": 321}, + ), + ( + {"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}}, + {"top_3_acc/dataloader_idx_3": 321}, + ), + # theoretical case, as `/dataloader_idx_3` would have been added + ({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}), + ], +) +def test_filter_metrics_for_dataloader(kwargs, expected): + """Logged metrics should only include metrics from the concerned dataloader.""" + actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs) + assert actual == expected From 0f6d89422be4b1ea97a9b286164ec9ccb4e7a068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 18 Nov 2021 18:48:53 +0100 Subject: [PATCH 02/10] Control automatic resubmission on SLURM (#10601) --- CHANGELOG.md | 2 +- docs/source/clouds/cluster.rst | 8 +++++ .../environments/cluster_environment.py | 2 +- .../plugins/environments/slurm_environment.py | 11 +++++- .../trainer/connectors/signal_connector.py | 23 ++++-------- .../environments/test_slurm_environment.py | 1 + .../connectors/test_signal_connector.py | 36 +++++++++++++++++-- 7 files changed, 60 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e915d07b799b8..9ada4815bcf57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601)) - diff --git a/docs/source/clouds/cluster.rst b/docs/source/clouds/cluster.rst index 2d6b4e19d6e98..ce594c0b5ea2a 100644 --- a/docs/source/clouds/cluster.rst +++ b/docs/source/clouds/cluster.rst @@ -210,6 +210,14 @@ To get this behavior make sure to add the correct signal to your SLURM script # 90 seconds before training ends SBATCH --signal=SIGUSR1@90 +If auto-resubmit is not desired, it can be turned off in the :class:`~pytorch_lightning.plugins.environments.slurm_environment.SLURMEnvironment` plugin: + +.. code-block:: python + + from pytorch_lightning.plugins import SLURMEnvironment + + trainer = Trainer(plugins=[SLURMEnvironment(auto_requeue=False)]) + Building SLURM scripts ---------------------- diff --git a/pytorch_lightning/plugins/environments/cluster_environment.py b/pytorch_lightning/plugins/environments/cluster_environment.py index 1cf209c897cf4..af274bd176b14 100644 --- a/pytorch_lightning/plugins/environments/cluster_environment.py +++ b/pytorch_lightning/plugins/environments/cluster_environment.py @@ -23,7 +23,7 @@ class ClusterEnvironment(ABC): def __new__(cls, *args: Any, **kwargs: Any) -> "ClusterEnvironment": # TODO: remove in 1.7 _check_for_deprecated_methods(cls) - return super().__new__(cls, *args, **kwargs) + return super().__new__(cls) @property @abstractmethod diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py index d9be5eda54c6b..4e7070be6b2f1 100644 --- a/pytorch_lightning/plugins/environments/slurm_environment.py +++ b/pytorch_lightning/plugins/environments/slurm_environment.py @@ -22,7 +22,16 @@ class SLURMEnvironment(ClusterEnvironment): - """Cluster environment for training on a cluster managed by SLURM.""" + """Cluster environment for training on a cluster managed by SLURM. + + Args: + auto_requeue: Whether automatic job resubmission is enabled or not. How and under which conditions a job gets + rescheduled gets determined by the owner of this plugin. + """ + + def __init__(self, auto_requeue: bool = True) -> None: + super().__init__() + self.auto_requeue = auto_requeue @property def creates_processes_externally(self) -> bool: diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index dc33d1244441f..795145b5be6af 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -8,6 +8,7 @@ from typing import Callable, List, Union import pytorch_lightning as pl +from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities.imports import _fault_tolerant_training log = logging.getLogger(__name__) @@ -36,17 +37,18 @@ def register_signal_handlers(self) -> None: if _fault_tolerant_training(): sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn) - if self._is_on_slurm(): - log.info("Set SLURM handle signals.") + environment = self.trainer._accelerator_connector.cluster_environment + if isinstance(environment, SLURMEnvironment) and environment.auto_requeue: + log.info("SLURM auto-requeueing enabled. Setting signal handlers.") sigusr1_handlers.append(self.slurm_sigusr1_handler_fn) sigterm_handlers.append(self.sigterm_handler_fn) # signal.SIGUSR1 doesn't seem available on windows if not self._is_on_windows(): - if not self._has_already_handler(signal.SIGUSR1): + if sigusr1_handlers and not self._has_already_handler(signal.SIGUSR1): signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) - if not self._has_already_handler(signal.SIGTERM): + if sigterm_handlers and not self._has_already_handler(signal.SIGTERM): signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: @@ -86,19 +88,6 @@ def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) - def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None: log.info("bypassing sigterm") - def _is_on_slurm(self) -> bool: - # see if we're using slurm (not interactive) - on_slurm = False - try: - job_name = os.environ["SLURM_JOB_NAME"] - if job_name != "bash": - on_slurm = True - # todo: specify the possible exception - except Exception: - pass - - return on_slurm - def _is_on_windows(self) -> bool: return sys.platform == "win32" diff --git a/tests/plugins/environments/test_slurm_environment.py b/tests/plugins/environments/test_slurm_environment.py index 5515c6bfc4986..f2c726548eb24 100644 --- a/tests/plugins/environments/test_slurm_environment.py +++ b/tests/plugins/environments/test_slurm_environment.py @@ -52,6 +52,7 @@ def test_default_attributes(): def test_attributes_from_environment_variables(caplog): """Test that the SLURM cluster environment takes the attributes from the environment variables.""" env = SLURMEnvironment() + assert env.auto_requeue is True assert env.main_address == "1.1.1.1" assert env.main_port == 15000 + 1234 assert env.world_size() == 20 diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py index aa5407e2f1228..76dae5e07db35 100644 --- a/tests/trainer/connectors/test_signal_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -19,6 +19,8 @@ import pytest from pytorch_lightning import Trainer +from pytorch_lightning.plugins.environments import SLURMEnvironment +from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector from pytorch_lightning.utilities.exceptions import ExitGracefullyException from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -29,9 +31,6 @@ @RunIf(skip_windows=True) def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir): - # hack to reset the signal - signal.signal(signal.SIGUSR1, 0) - if register_handler: def handler(*_): @@ -57,3 +56,34 @@ def training_step(self, batch, batch_idx): else: trainer.fit(model) assert trainer._terminate_gracefully == (False if register_handler else terminate_gracefully) + + # reset the signal to system defaults + signal.signal(signal.SIGUSR1, signal.SIG_DFL) + + +@RunIf(skip_windows=True) +@pytest.mark.parametrize("auto_requeue", (True, False)) +def test_auto_requeue_flag(auto_requeue): + sigterm_handler_default = signal.getsignal(signal.SIGTERM) + sigusr1_handler_default = signal.getsignal(signal.SIGUSR1) + + trainer = Trainer(plugins=[SLURMEnvironment(auto_requeue=auto_requeue)]) + connector = SignalConnector(trainer) + connector.register_signal_handlers() + + if auto_requeue: + sigterm_handlers = signal.getsignal(signal.SIGTERM).signal_handlers + assert len(sigterm_handlers) == 1 + assert sigterm_handlers[0].__qualname__ == "SignalConnector.sigterm_handler_fn" + + sigusr1_handlers = signal.getsignal(signal.SIGUSR1).signal_handlers + assert len(sigusr1_handlers) == 1 + assert sigusr1_handlers[0].__qualname__ == "SignalConnector.slurm_sigusr1_handler_fn" + else: + assert signal.getsignal(signal.SIGTERM) is sigterm_handler_default + assert signal.getsignal(signal.SIGUSR1) is sigusr1_handler_default + + # restore the signal handlers so the next test can run with system defaults + # TODO: should this be done in SignalConnector teardown? + signal.signal(signal.SIGTERM, sigterm_handler_default) + signal.signal(signal.SIGUSR1, sigusr1_handler_default) From 2c7c4aab8087d4c1c99c57c7acc66ef9a8e815d4 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 18 Nov 2021 10:51:54 -0800 Subject: [PATCH 03/10] Refactor progress bar initialization to avoid extra attribute set on Trainer (#10553) --- .../trainer/connectors/callback_connector.py | 47 ++++++++++--------- pytorch_lightning/trainer/trainer.py | 13 +++-- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 6a54e973ffcf3..03926aeb9bc68 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -94,9 +94,7 @@ def on_trainer_init( " bar pass `enable_progress_bar = False` to the Trainer." ) - self.trainer._progress_bar_callback = self.configure_progress_bar( - progress_bar_refresh_rate, process_position, enable_progress_bar - ) + self.configure_progress_bar(progress_bar_refresh_rate, process_position, enable_progress_bar) # configure the ModelSummary callback self._configure_model_summary_callback(enable_model_summary, weights_summary) @@ -193,9 +191,10 @@ def _configure_model_summary_callback( ) max_depth = ModelSummaryMode.get_max_depth(weights_summary) - is_progress_bar_rich = isinstance(self.trainer._progress_bar_callback, RichProgressBar) + progress_bar_callback = self.trainer.progress_bar_callback + is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar) - if self.trainer._progress_bar_callback is not None and is_progress_bar_rich: + if progress_bar_callback is not None and is_progress_bar_rich: model_summary = RichModelSummary(max_depth=max_depth) else: model_summary = ModelSummary(max_depth=max_depth) @@ -214,12 +213,7 @@ def _configure_swa_callbacks(self): def configure_progress_bar( self, refresh_rate: Optional[int] = None, process_position: int = 0, enable_progress_bar: bool = True - ) -> Optional[ProgressBarBase]: - if os.getenv("COLAB_GPU") and refresh_rate is None: - # smaller refresh rate on colab causes crashes, choose a higher value - refresh_rate = 20 - refresh_rate = 1 if refresh_rate is None else refresh_rate - + ) -> None: progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)] if len(progress_bars) > 1: raise MisconfigurationException( @@ -227,19 +221,28 @@ def configure_progress_bar( " progress bar is supported." ) if len(progress_bars) == 1: + # the user specified the progress bar in the callbacks list + # so the trainer doesn't need to provide a default one + if enable_progress_bar: + return + + # otherwise the user specified a progress bar callback but also + # elected to disable the progress bar with the trainer flag progress_bar_callback = progress_bars[0] - if not enable_progress_bar: - raise MisconfigurationException( - "Trainer was configured with `enable_progress_bar=False`" - f" but found `{progress_bar_callback.__class__.__name__}` in callbacks list." - ) - elif refresh_rate > 0 and enable_progress_bar: - progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position) - self.trainer.callbacks.append(progress_bar_callback) - else: - progress_bar_callback = None + raise MisconfigurationException( + "Trainer was configured with `enable_progress_bar=False`" + f" but found `{progress_bar_callback.__class__.__name__}` in callbacks list." + ) + + # Return early if the user intends to disable the progress bar callback + if refresh_rate == 0 or not enable_progress_bar: + return + if refresh_rate is None: + # smaller refresh rate on colab causes crashes, choose a higher value + refresh_rate = 20 if os.getenv("COLAB_GPU") else 1 - return progress_bar_callback + progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position) + self.trainer.callbacks.append(progress_bar_callback) def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None: if max_time is None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index be9c71e2fe470..26fbcb4362c73 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1763,10 +1763,6 @@ def data_parallel(self) -> bool: _StrategyType.DDP2, ) - @property - def progress_bar_callback(self) -> Optional[ProgressBarBase]: - return self._progress_bar_callback - @property def progress_bar_dict(self) -> dict: """Read-only for progress bar metrics.""" @@ -1845,6 +1841,15 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] + @property + def progress_bar_callback(self) -> Optional[ProgressBarBase]: + """An instance of :class:`~pytorch_lightning.callbacks.progress.base.ProgressBarBase` found in the + Trainer.callbacks list, or ``None`` if one doesn't exist.""" + for c in self.callbacks: + if isinstance(c, ProgressBarBase): + return c + return None + @property def resume_from_checkpoint(self) -> Optional[Union[str, Path]]: resume_from_checkpoint = self.checkpoint_connector.resume_from_checkpoint_fit_path From 700521c7d353f79f38d1146cffc63bbe23df8e66 Mon Sep 17 00:00:00 2001 From: four4fish <88516121+four4fish@users.noreply.github.com> Date: Thu, 18 Nov 2021 16:39:01 -0800 Subject: [PATCH 04/10] 1/n Move precision plugin into strategy - update reference (#10570) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 1/n move precision plugin into strategy - update reference * update precision plugin reference in tpu_spawn * add missing reference in error message * add back removed license line * update references in tests * update reference in trainer * update return annotation for precision_plugin property on TTP * simplify access to precision plugin reference in sharded plug * add changelog * remove precision property from ttp and add deprecation message * fix make doc and update precision reference * simplify a reference to precision accidentally overridden Adrian's change, now add it back * Update CHANGELOG.md add Adrian's change back * Update accelerator precision Add Adrian's change back * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add none check for precision plugin just to be safe * Update ipu.py * update precision_plugin param deprecation message * Update accelerator.py * Remove deprecated warning Tests will fail after 9940 Co-authored-by: Adrian Wälchli Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 6 +- pytorch_lightning/accelerators/accelerator.py | 59 +++++++++++++------ pytorch_lightning/accelerators/tpu.py | 5 +- pytorch_lightning/lite/lite.py | 2 +- .../plugins/training_type/ddp.py | 3 + .../plugins/training_type/ddp_spawn.py | 3 + .../plugins/training_type/deepspeed.py | 8 ++- pytorch_lightning/plugins/training_type/dp.py | 9 ++- .../plugins/training_type/fully_sharded.py | 5 +- .../plugins/training_type/horovod.py | 9 ++- .../plugins/training_type/ipu.py | 6 +- .../plugins/training_type/parallel.py | 4 +- .../plugins/training_type/sharded.py | 2 +- .../plugins/training_type/sharded_spawn.py | 5 +- .../plugins/training_type/single_device.py | 4 +- .../plugins/training_type/single_tpu.py | 4 +- .../plugins/training_type/tpu_spawn.py | 8 ++- .../training_type/training_type_plugin.py | 10 +++- .../connectors/accelerator_connector.py | 13 ++-- pytorch_lightning/trainer/trainer.py | 4 +- tests/accelerators/test_ipu.py | 8 +-- tests/accelerators/test_tpu.py | 16 ++++- ..._ddp_fully_sharded_with_full_state_dict.py | 4 +- tests/plugins/test_deepspeed_plugin.py | 4 +- 24 files changed, 142 insertions(+), 59 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ada4815bcf57..ae0515cf22703 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) +- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) + + - @@ -50,7 +53,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `DistributedType` in favor of `_StrategyType` ([#10505](https://github.com/PyTorchLightning/pytorch-lightning/pull/10505)) -- +- Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) - @@ -139,6 +142,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `reload_dataloaders_every_epoch` from `Trainer` in favour of `reload_dataloaders_every_n_epochs` ([#10481](https://github.com/PyTorchLightning/pytorch-lightning/pull/10481)) +- Removed the `precision_plugin` attribute from `Accelerator` in favor of its equivalent attribute `precision_plugin` in the `TrainingTypePlugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) ### Fixed diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 14b6a47c7243f..eb3886b209503 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -25,6 +25,7 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -44,15 +45,23 @@ class Accelerator: One to handle differences from the training routine and one to handle different precisions. """ - def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: TrainingTypePlugin) -> None: + def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_plugin: TrainingTypePlugin) -> None: """ Args: precision_plugin: the plugin to handle precision-specific parts + + .. deprecated:: + The ``precision_plugin`` parameter has been deprecated and will be removed soon. + Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead. + training_type_plugin: the plugin to handle different training routines """ - self.precision_plugin = precision_plugin + self.training_type_plugin = training_type_plugin + if precision_plugin is not None: + self.training_type_plugin._precision_plugin = precision_plugin + self.optimizers: List = [] self.lr_schedulers: List = [] self.optimizer_frequencies: List = [] @@ -84,7 +93,7 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None: if self.training_type_plugin.setup_optimizers_in_pre_dispatch: self.setup_optimizers(trainer) - self.precision_plugin.pre_dispatch() + self.training_type_plugin.precision_plugin.pre_dispatch() def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: """Moves the state of the optimizers to the GPU if needed.""" @@ -96,12 +105,12 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: def dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.dispatch(trainer) - self.precision_plugin.dispatch(trainer) + self.training_type_plugin.precision_plugin.dispatch(trainer) def post_dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something after the training/evaluation/prediction starts.""" self.training_type_plugin.post_dispatch(trainer) - self.precision_plugin.post_dispatch() + self.training_type_plugin.precision_plugin.post_dispatch() @property def model(self) -> Module: @@ -159,7 +168,7 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details """ - with self.precision_plugin.train_step_context(): + with self.training_type_plugin.precision_plugin.train_step_context(): return self.training_type_plugin.training_step(*step_kwargs.values()) def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -167,7 +176,7 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details """ - with self.precision_plugin.val_step_context(): + with self.training_type_plugin.precision_plugin.val_step_context(): return self.training_type_plugin.validation_step(*step_kwargs.values()) def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -175,7 +184,7 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details """ - with self.precision_plugin.test_step_context(): + with self.training_type_plugin.precision_plugin.test_step_context(): return self.training_type_plugin.test_step(*step_kwargs.values()) def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: @@ -183,7 +192,7 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details """ - with self.precision_plugin.predict_step_context(): + with self.training_type_plugin.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: @@ -193,11 +202,11 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: closure_loss: a tensor holding the loss value to backpropagate """ self.training_type_plugin.pre_backward(closure_loss) - closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss) + closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss) - self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) + self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) - closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss) + closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss) self.training_type_plugin.post_backward(closure_loss) return closure_loss @@ -208,7 +217,7 @@ def optimizer_step( opt_idx: int, closure: Callable[[], Any], model: Optional[Union["pl.LightningModule", Module]] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """performs the actual optimizer step. @@ -220,7 +229,7 @@ def optimizer_step( **kwargs: Any extra arguments to ``optimizer.step`` """ model = model or self.lightning_module - self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) + self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: """Zeros all model parameter's gradients.""" @@ -248,26 +257,38 @@ def setup_training_type_plugin(self) -> None: def setup_precision_plugin(self) -> None: """Attaches the precision plugin to the accelerator.""" - model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers) + model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect( + self.model, self.optimizers, self.lr_schedulers + ) self.model = model self.optimizers = optimizers self.lr_schedulers = schedulers @property def amp_backend(self) -> Optional[LightningEnum]: - if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): + if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin): return AMPType.APEX - if isinstance(self.precision_plugin, NativeMixedPrecisionPlugin): + if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin): return AMPType.NATIVE return None @property def precision(self) -> Union[str, int]: - return self.precision_plugin.precision + """The type of precision being used with this accelerator. + + .. deprecated:: + This property been deprecated and will be removed soon. + Use ``training_type_plugin.precision_plugin.precision`` instead. + """ + rank_zero_deprecation( + f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon" + f" Use `training_type_plugin.precision_plugin.precision` instead." + ) + return self.training_type_plugin.precision_plugin.precision @property def scaler(self) -> Optional["GradScaler"]: - return getattr(self.precision_plugin, "scaler", None) + return getattr(self.training_type_plugin.precision_plugin, "scaler", None) def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: """Returns state of an optimizer. diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 6e824a25f6b9d..673e8419ca7fb 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -36,10 +36,11 @@ def setup(self, trainer: "pl.Trainer") -> None: ValueError: If the precision or training type plugin are unsupported. """ - if not isinstance(self.precision_plugin, TPUPrecisionPlugin): + if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin): # this configuration should have been avoided in the accelerator connector raise ValueError( - f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}." + f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`," + f" found: {self.training_type_plugin.precision_plugin}." ) if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)): raise ValueError( diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 2a2ed9586b420..bb07c763156aa 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -108,7 +108,7 @@ def __init__( ) self._accelerator = self._accelerator_connector.accelerator self._strategy = self._accelerator.training_type_plugin - self._precision_plugin = self._accelerator.precision_plugin + self._precision_plugin = self._strategy.precision_plugin self._models_setup: int = 0 # wrap the run method so we can inject setup logic or spawn processes for the user diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 0285859a6714a..6d1b168d5ac7a 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -36,6 +36,7 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( @@ -86,6 +87,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, @@ -96,6 +98,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self.interactive_ddp_procs = [] self._num_nodes = 1 diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index a77027adb6dcf..da724944ade7e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -29,6 +29,7 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn @@ -65,6 +66,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, @@ -74,6 +76,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self._num_nodes = 1 self.sync_batchnorm = False diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index eb087ad199808..01959bdcee212 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -30,6 +30,7 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.trainer.states import TrainerFn @@ -129,6 +130,7 @@ def __init__( synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, partition_module: bool = True, + precision_plugin: Optional[PrecisionPlugin] = None, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. `For more information: https://pytorch- @@ -273,6 +275,7 @@ def __init__( super().__init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, + precision_plugin=precision_plugin, ) self.config = self._load_config(config) @@ -331,7 +334,7 @@ def __init__( @property def precision(self) -> Union[str, int]: - return self._precision or self.lightning_module.trainer.precision + return self._precision or self.precision_plugin.precision @property def amp_level(self) -> Optional[str]: @@ -456,8 +459,7 @@ def init_deepspeed(self): "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs." ) - precision = self.lightning_module.trainer.accelerator.precision - model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) + model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision) if self.zero_stage_3 and self.partition_module: # Ensure the entire model has been moved to the appropriate device diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 83328e8c47271..3f1b9a3acfa50 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -18,6 +18,7 @@ from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.enums import _StrategyType @@ -35,8 +36,14 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) @property def global_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index c9601a905df1c..73ea87b05835e 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -18,6 +18,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.enums import _StrategyType @@ -46,6 +47,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): """Plugin for Fully Sharded Data Parallel provided by FairScale. @@ -97,6 +99,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self.cpu_offload = cpu_offload self.move_grads_to_cpu = move_grads_to_cpu @@ -124,7 +127,7 @@ def setup_distributed(self) -> None: @contextlib.contextmanager def model_sharded_context(self) -> Generator: - precision = self.lightning_module.trainer.precision + precision = self.precision_plugin.precision def wrap_policy(*args, **kwargs): return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 51558189a3d35..961d2764b8ef3 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -21,6 +21,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available @@ -41,8 +42,14 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) rank_zero_only.rank = self.global_rank @property diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 898e62791d6ee..c24008ac3ee4f 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -22,6 +22,7 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE @@ -64,6 +65,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, training_opts: Optional["poptorch.Options"] = None, inference_opts: Optional["poptorch.Options"] = None, ) -> None: @@ -84,6 +86,7 @@ def __init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) if not _IPU_AVAILABLE: raise MisconfigurationException( @@ -116,8 +119,7 @@ def setup(self) -> None: self.lightning_module.trainer._update_dataloader = self._convert_to_poptorch_loader def pre_dispatch(self) -> None: - precision = self.lightning_module.trainer.precision - model = LightningIPUModule(self.lightning_module, precision) + model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision) self.model = model # reset the backup diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 4f4b2c5b8e3c3..07ede1ae4f833 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -23,6 +23,7 @@ from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp @@ -36,8 +37,9 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(checkpoint_io) + super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.parallel_devices = parallel_devices self.cluster_environment = cluster_environment diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index d7563437bd16b..eb4cb48534708 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -75,7 +75,7 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin optim_class = type(optimizer) zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE: - precision = self._precision or self.lightning_module.trainer.precision + precision = self._precision or self.precision_plugin.precision is_fp16 = precision in ("mixed", 16) # For multi-node training, compressing the model shards in fp16 before broadcasting # improves performance. When using PyTorch AMP, it will not degrade diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 12e627edbe5cb..12c06b9dde541 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -118,9 +118,8 @@ def post_training_step(self): def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process - precision_plugin = trainer.accelerator.precision_plugin - if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin): - precision_plugin.scaler = ShardedGradScaler() + if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): + self.precision_plugin.scaler = ShardedGradScaler() return super().new_process(trainer, mp_queue) @classmethod diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 1737bf3b41ca8..12a0f625b64fc 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -16,6 +16,7 @@ import torch from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -27,8 +28,9 @@ def __init__( self, device: torch.device, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(checkpoint_io) + super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.device: torch.device = device self.global_rank = 0 self.local_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 9fed2000391dd..e6f6a5f4b26f2 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -16,6 +16,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -33,12 +34,13 @@ def __init__( self, device: int, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, ): device = xm.xla_device(device) checkpoint_io = checkpoint_io or XLACheckpointIO() - super().__init__(device=device, checkpoint_io=checkpoint_io) + super().__init__(device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.debug = debug self.tpu_local_core_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7aa4a67721c04..3ab9a8171aac5 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -27,6 +27,7 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn @@ -56,11 +57,14 @@ def __init__( self, parallel_devices: Optional[List[int]] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, **_: Any ) -> None: checkpoint_io = checkpoint_io or XLACheckpointIO() - super().__init__(parallel_devices=parallel_devices, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin + ) self.debug = debug self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 @@ -167,7 +171,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: set_shared_parameters(self.model.module, shared_params) trainer.accelerator.setup_optimizers(trainer) - trainer.precision_plugin.connect(self._model, None, None) + self.precision_plugin.connect(self._model, None, None) self.barrier("pre-run-stage") diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index c23edf594146f..7010c0e878dc9 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -25,6 +25,7 @@ from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT @@ -33,16 +34,23 @@ class TrainingTypePlugin(ABC): """Base class for all training type plugins that change the behaviour of the training, validation and test- loop.""" - def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None: + def __init__( + self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None + ) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io + self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin() @property def checkpoint_io(self) -> CheckpointIO: return self._checkpoint_io + @property + def precision_plugin(self) -> PrecisionPlugin: + return self._precision_plugin + @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: self._checkpoint_io = plugin diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 5532385ca1d98..e5df9c3b84898 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -405,6 +405,9 @@ def training_type_plugin(self) -> TrainingTypePlugin: # attach checkpoint plugin to the training type plugin if self._checkpoint_io is not None: self._training_type_plugin.checkpoint_io = self._checkpoint_io + precision_plugin = self.precision_plugin + if precision_plugin is not None: + self._training_type_plugin._precision_plugin = precision_plugin self._training_type_plugin_resolved = True return self._training_type_plugin @@ -531,11 +534,11 @@ def use_deepspeed(self) -> bool: @property def _is_sharded_training_type(self) -> bool: - return isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)) + return isinstance(self._training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)) @property def _is_fully_sharded_training_type(self) -> bool: - return isinstance(self.training_type_plugin, DDPFullyShardedPlugin) + return isinstance(self._training_type_plugin, DDPFullyShardedPlugin) @property def is_distributed(self) -> bool: @@ -793,12 +796,10 @@ def select_accelerator(self) -> Accelerator: acc_cls = IPUAccelerator else: acc_cls = CPUAccelerator - # as precision_plugin is dependent on training_type_plugin, make sure - # that we first select training_type_plugin, then precision_plugin - accelerator = acc_cls(training_type_plugin=self.training_type_plugin, precision_plugin=self.precision_plugin) + + accelerator = acc_cls(precision_plugin=None, training_type_plugin=self.training_type_plugin) # transfer ownership of the plugins to the accelerator self._training_type_plugin = proxy(self.training_type_plugin) - self._precision_plugin = proxy(self.precision_plugin) return accelerator diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 26fbcb4362c73..2f6e987635d47 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1568,7 +1568,7 @@ def training_type_plugin(self) -> TrainingTypePlugin: @property def precision_plugin(self) -> PrecisionPlugin: - return self.accelerator.precision_plugin + return self.training_type_plugin.precision_plugin @property def global_rank(self) -> int: @@ -1672,7 +1672,7 @@ def amp_backend(self) -> Optional[str]: @property def precision(self) -> Union[str, int]: - return self.accelerator.precision + return self.training_type_plugin.precision_plugin.precision @property def scaler(self): diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index dfaa1c8042355..be2e597c9a2f9 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -193,8 +193,8 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st model = IPUModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) - assert isinstance(trainer.accelerator.precision_plugin, IPUPrecisionPlugin) - assert trainer.accelerator.precision_plugin.precision == 16 + assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin) + assert trainer.training_type_plugin.precision_plugin.precision == 16 with pytest.raises(SystemExit): trainer.fit(model) @@ -213,8 +213,8 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin) - assert isinstance(trainer.accelerator.precision_plugin, IPUPrecisionPlugin) - assert trainer.accelerator.precision_plugin.precision == 16 + assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin) + assert trainer.training_type_plugin.precision_plugin.precision == 16 with pytest.raises(SystemExit): trainer.fit(model) diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index 78e4c505bb99a..fc1ce413cd494 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -23,7 +23,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.tpu import TPUAccelerator -from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO +from pytorch_lightning.plugins import DDPPlugin, TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO from pytorch_lightning.utilities import find_shared_parameters from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset @@ -292,11 +292,23 @@ def test_tpu_invalid_raises(): with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): accelerator.setup(object()) - accelerator = TPUAccelerator(TPUPrecisionPlugin(), object()) + accelerator = TPUAccelerator(TPUPrecisionPlugin(), DDPPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"): accelerator.setup(object()) +def test_tpu_invalid_raises_set_precision_with_strategy(): + accelerator = TPUAccelerator(object(), TPUSpawnPlugin(precision_plugin=object())) + with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"): + accelerator.setup(object()) + + accelerator = TPUAccelerator(None, DDPPlugin(precision_plugin=TPUPrecisionPlugin())) + with pytest.raises( + ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin" + ): + accelerator.setup(object()) + + @RunIf(tpu=True) def test_xla_checkpoint_plugin_being_default(): trainer = Trainer(tpu_cores=8) diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 1468c7f4a4137..c0fab297173e7 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -34,8 +34,8 @@ def test_invalid_on_cpu(tmpdir): def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", gpus=1, precision=16) - assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin) - assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) + assert isinstance(trainer.training_type_plugin, DDPFullyShardedPlugin) + assert isinstance(trainer.training_type_plugin.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) class TestFSDPModel(BoringModel): diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 2d39a3de6b5c5..480b050c39b36 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -170,8 +170,8 @@ def test_deepspeed_precision_choice(amp_backend, precision, tmpdir): ) assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) - assert isinstance(trainer.accelerator.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.accelerator.precision_plugin.precision == precision + assert isinstance(trainer.training_type_plugin.precision_plugin, DeepSpeedPrecisionPlugin) + assert trainer.training_type_plugin.precision_plugin.precision == precision @RunIf(deepspeed=True) From 35f6cbe09ff1898d70ee24ff7a19782e1912f8dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 19 Nov 2021 01:52:55 +0100 Subject: [PATCH 05/10] Use `update_wrapper` in test_hooks.py (#10578) --- tests/models/test_hooks.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 9e4b545ecc5bc..8328afdac7529 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial +from functools import partial, update_wrapper from inspect import getmembers, isfunction from unittest import mock from unittest.mock import ANY, PropertyMock @@ -223,7 +223,9 @@ def call(hook, fn, *args, **kwargs): for h in get_members(Callback): attr = getattr(self, h) - setattr(self, h, partial(call, h, attr)) + partial_h = partial(call, h, attr) + update_wrapper(partial_h, attr) + setattr(self, h, partial_h) def on_save_checkpoint(*args, **kwargs): return {"foo": True} @@ -256,7 +258,9 @@ def call(hook, fn, *args, **kwargs): for h in pl_module_hooks: attr = getattr(self, h) - setattr(self, h, partial(call, h, attr)) + partial_h = partial(call, h, attr) + update_wrapper(partial_h, attr) + setattr(self, h, partial_h) def validation_epoch_end(self, *args, **kwargs): # `BoringModel` does not have a return for `validation_step_end` so this would fail @@ -852,7 +856,9 @@ def call(hook, fn, *args, **kwargs): for h in get_members(LightningDataModule): attr = getattr(self, h) - setattr(self, h, partial(call, h, attr)) + partial_h = partial(call, h, attr) + update_wrapper(partial_h, attr) + setattr(self, h, partial_h) model = BoringModel() batches = 2 @@ -871,20 +877,12 @@ def call(hook, fn, *args, **kwargs): called = [] dm = HookedDataModule(called) trainer.fit(model, datamodule=dm) - batch_transfer = [ - dict(name="on_before_batch_transfer", args=(ANY, 0)), - dict(name="transfer_batch_to_device", args=(ANY, torch.device("cpu"), 0)), - dict(name="on_after_batch_transfer", args=(ANY, 0)), - ] expected = [ dict(name="prepare_data"), dict(name="setup", kwargs=dict(stage="fit")), dict(name="val_dataloader"), - *batch_transfer * batches, dict(name="train_dataloader"), - *batch_transfer * batches, dict(name="val_dataloader"), - *batch_transfer * batches, dict( name="on_save_checkpoint", args=( @@ -910,7 +908,6 @@ def call(hook, fn, *args, **kwargs): dict(name="prepare_data"), dict(name="setup", kwargs=dict(stage="validate")), dict(name="val_dataloader"), - *batch_transfer * batches, dict(name="teardown", kwargs=dict(stage="validate")), ] assert called == expected @@ -922,7 +919,6 @@ def call(hook, fn, *args, **kwargs): dict(name="prepare_data"), dict(name="setup", kwargs=dict(stage="test")), dict(name="test_dataloader"), - *batch_transfer * batches, dict(name="teardown", kwargs=dict(stage="test")), ] assert called == expected @@ -934,7 +930,6 @@ def call(hook, fn, *args, **kwargs): dict(name="prepare_data"), dict(name="setup", kwargs=dict(stage="predict")), dict(name="predict_dataloader"), - *batch_transfer * batches, dict(name="teardown", kwargs=dict(stage="predict")), ] assert called == expected From 0de8ab4f2ef8cf3544d83abd43c34eaab6e62dc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 19 Nov 2021 03:04:53 +0100 Subject: [PATCH 06/10] Fix failing master due to an interction between PRs (#10627) --- tests/models/test_hooks.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 8328afdac7529..19c4e71d54fc4 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -485,8 +485,7 @@ def training_step(self, batch, batch_idx): dict(name="Callback.on_init_start", args=(trainer,)), dict(name="Callback.on_init_end", args=(trainer,)), ] - with pytest.deprecated_call(match="on_train_dataloader` is deprecated in v1.5"): - trainer.fit(model) + trainer.fit(model) saved_ckpt = { "callbacks": ANY, "epoch": 1, @@ -588,8 +587,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): enable_model_summary=False, callbacks=[HookedCallback([])], ) - with pytest.deprecated_call(match="on_keyboard_interrupt` callback hook was deprecated in v1.5"): - trainer.fit(model) + trainer.fit(model) best_model_path = trainer.checkpoint_callback.best_model_path # resume from checkpoint with HookedModel @@ -611,8 +609,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): dict(name="Callback.on_init_start", args=(trainer,)), dict(name="Callback.on_init_end", args=(trainer,)), ] - with pytest.deprecated_call(match="on_train_dataloader` is deprecated in v1.5"): - trainer.fit(model, ckpt_path=best_model_path) + trainer.fit(model, ckpt_path=best_model_path) saved_ckpt = { "callbacks": ANY, "epoch": 2, # TODO: wrong saved epoch @@ -707,8 +704,7 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader, dict(name="Callback.on_init_end", args=(trainer,)), ] fn = getattr(trainer, verb) - with pytest.deprecated_call(match=f"on_{dataloader}_dataloader` is deprecated in v1.5"): - fn(model, verbose=False) + fn(model, verbose=False) hooks = [ dict(name="train", args=(False,)), dict(name=f"on_{noun}_model_eval"), @@ -752,8 +748,7 @@ def test_trainer_model_hook_system_predict(tmpdir): dict(name="Callback.on_init_start", args=(trainer,)), dict(name="Callback.on_init_end", args=(trainer,)), ] - with pytest.deprecated_call(match="on_predict_dataloader` is deprecated in v1.5"): - trainer.predict(model) + trainer.predict(model) expected = [ dict(name="Callback.on_init_start", args=(trainer,)), dict(name="Callback.on_init_end", args=(trainer,)), From 5788789f0164b984fdcb9d6390b3babbf5997fec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 19 Nov 2021 03:07:33 +0100 Subject: [PATCH 07/10] Move benchmarks into the test directory (#10614) --- .azure-pipelines/gpu-benchmark.yml | 2 +- .azure-pipelines/gpu-tests.yml | 4 ++-- benchmarks/__init__.py | 18 ------------------ pyproject.toml | 6 +----- tests/benchmarks/__init__.py | 0 .../benchmarks}/generate_comparison.py | 2 +- .../benchmarks}/test_basic_parity.py | 0 .../benchmarks}/test_sharded_parity.py | 0 tests/special_tests.sh | 2 +- 9 files changed, 6 insertions(+), 28 deletions(-) delete mode 100644 benchmarks/__init__.py create mode 100644 tests/benchmarks/__init__.py rename {benchmarks => tests/benchmarks}/generate_comparison.py (97%) rename {benchmarks => tests/benchmarks}/test_basic_parity.py (100%) rename {benchmarks => tests/benchmarks}/test_sharded_parity.py (100%) diff --git a/.azure-pipelines/gpu-benchmark.yml b/.azure-pipelines/gpu-benchmark.yml index f8b9593d72798..6d45cc2f4566a 100644 --- a/.azure-pipelines/gpu-benchmark.yml +++ b/.azure-pipelines/gpu-benchmark.yml @@ -36,7 +36,7 @@ jobs: steps: - bash: | - python -m pytest benchmarks -v --durations=0 + python -m pytest tests/benchmarks -v --durations=0 displayName: 'Testing: benchmarks' env: PL_RUNNING_BENCHMARKS: 1 diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index f1af36a6090b9..71332a840fdb0 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -68,7 +68,7 @@ jobs: displayName: 'Get legacy checkpoints' - bash: | - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 + python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests --ignore tests/benchmarks -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 displayName: 'Testing: standard' - bash: | @@ -113,5 +113,5 @@ jobs: displayName: 'Testing: examples' - bash: | - python -m pytest benchmarks -v --maxfail=2 --durations=0 + python -m pytest tests/benchmarks -v --maxfail=2 --durations=0 displayName: 'Testing: benchmarks' diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py deleted file mode 100644 index b4a3da40d40d0..0000000000000 --- a/benchmarks/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -_BENCHMARK_ROOT = os.path.dirname(__file__) -_PROJECT_ROOT = os.path.dirname(_BENCHMARK_ROOT) -_PATH_DATASETS = os.path.join(_PROJECT_ROOT, "Datasets") diff --git a/pyproject.toml b/pyproject.toml index 6546d96e3d5e5..08b7b50eee770 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ requires = [ [tool.isort] known_first_party = [ - "benchmarks", "docs", "pl_examples", "pytorch_lightning", @@ -24,7 +23,7 @@ line-length = 120 [tool.mypy] -files = ["pytorch_lightning", "pl_examples", "benchmarks"] +files = ["pytorch_lightning"] disallow_untyped_defs = "True" ignore_missing_imports = "True" show_error_codes = "True" @@ -53,9 +52,6 @@ module = [ "pytorch_lightning.distributed.*", "pytorch_lightning.tuner.*", "pytorch_lightning.utilities.*", - "pl_examples.*", - "benchmarks.*", - "tests.helpers.*" ] ignore_errors = "True" diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/benchmarks/generate_comparison.py b/tests/benchmarks/generate_comparison.py similarity index 97% rename from benchmarks/generate_comparison.py rename to tests/benchmarks/generate_comparison.py index 5a9cde0d80ed3..bc95b5d9cf591 100644 --- a/benchmarks/generate_comparison.py +++ b/tests/benchmarks/generate_comparison.py @@ -16,7 +16,7 @@ import matplotlib.pylab as plt import pandas as pd -from benchmarks.test_basic_parity import measure_loops +from tests.benchmarks.test_basic_parity import measure_loops from tests.helpers.advanced_models import ParityModuleMNIST, ParityModuleRNN NUM_EPOCHS = 20 diff --git a/benchmarks/test_basic_parity.py b/tests/benchmarks/test_basic_parity.py similarity index 100% rename from benchmarks/test_basic_parity.py rename to tests/benchmarks/test_basic_parity.py diff --git a/benchmarks/test_sharded_parity.py b/tests/benchmarks/test_sharded_parity.py similarity index 100% rename from benchmarks/test_sharded_parity.py rename to tests/benchmarks/test_sharded_parity.py diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 4143ec5930ee3..6a3701a7ee9d5 100755 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -20,7 +20,7 @@ export PL_RUNNING_SPECIAL_TESTS=1 defaults='-m coverage run --source pytorch_lightning --append -m pytest --capture=no' # find tests marked as `@RunIf(special=True)`. done manually instead of with pytest because it is faster -grep_output=$(grep --recursive --word-regexp 'tests' 'benchmarks' --regexp 'special=True' --include '*.py' --exclude 'tests/conftest.py') +grep_output=$(grep --recursive --word-regexp 'tests' --regexp 'special=True' --include '*.py' --exclude 'tests/conftest.py') # file paths, remove duplicates files=$(echo "$grep_output" | cut -f1 -d: | sort | uniq) From 7d3ad5b76eebe1333a89a51a72b77ec198d86489 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 19 Nov 2021 03:13:35 +0000 Subject: [PATCH 08/10] Don't register signal in thread (#10610) --- CHANGELOG.md | 2 +- .../trainer/connectors/signal_connector.py | 10 ++++++++-- tests/trainer/connectors/test_signal_connector.py | 14 ++++++++++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae0515cf22703..d2380a1dc6c79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -147,7 +147,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610)) - diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 795145b5be6af..90d0f6928283f 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -2,6 +2,7 @@ import os import signal import sys +import threading from signal import Signals from subprocess import call from types import FrameType, FunctionType @@ -46,10 +47,10 @@ def register_signal_handlers(self) -> None: # signal.SIGUSR1 doesn't seem available on windows if not self._is_on_windows(): if sigusr1_handlers and not self._has_already_handler(signal.SIGUSR1): - signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) + self._register_signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) if sigterm_handlers and not self._has_already_handler(signal.SIGTERM): - signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) + self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: if self.trainer.is_global_zero: @@ -96,3 +97,8 @@ def _has_already_handler(self, signum: Signals) -> bool: return isinstance(signal.getsignal(signum), FunctionType) except AttributeError: return False + + @staticmethod + def _register_signal(signum: Signals, handlers: HandlersCompose) -> None: + if threading.current_thread() is threading.main_thread(): + signal.signal(signum, handlers) diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py index 76dae5e07db35..fbfce158e3675 100644 --- a/tests/trainer/connectors/test_signal_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures import os import signal from time import sleep @@ -87,3 +88,16 @@ def test_auto_requeue_flag(auto_requeue): # TODO: should this be done in SignalConnector teardown? signal.signal(signal.SIGTERM, sigterm_handler_default) signal.signal(signal.SIGUSR1, sigusr1_handler_default) + + +def _registering_signals(): + trainer = Trainer() + trainer.signal_connector.register_signal_handlers() + + +@RunIf(skip_windows=True) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_signal_connector_in_thread(): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + for future in concurrent.futures.as_completed([executor.submit(_registering_signals)]): + assert future.exception() is None From 137b62d80df9896ccce63bf607a29cfdbf1f06f0 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 19 Nov 2021 11:29:57 +0530 Subject: [PATCH 09/10] Add `refresh_rate` to RichProgressBar (#10497) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: ananthsub Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 3 + .../callbacks/progress/rich_progress.py | 59 +++++++++++-------- tests/callbacks/test_rich_progress_bar.py | 27 ++++++++- 3 files changed, 63 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2380a1dc6c79..438e2c9933310 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) +- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497)) + + - Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index c091223fba0bd..e2a269d659127 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -206,7 +206,8 @@ class RichProgressBar(ProgressBarBase): trainer = Trainer(callbacks=RichProgressBar()) Args: - refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled. + refresh_rate: Determines at which rate (in number of batches) the progress bars get updated. + Set it to ``0`` to disable the display. leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False theme: Contains styles used to stylize the progress bar. @@ -222,7 +223,7 @@ class RichProgressBar(ProgressBarBase): def __init__( self, - refresh_rate_per_second: int = 10, + refresh_rate: int = 1, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), ) -> None: @@ -231,7 +232,7 @@ def __init__( "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`." ) super().__init__() - self._refresh_rate_per_second: int = refresh_rate_per_second + self._refresh_rate: int = refresh_rate self._leave: bool = leave self._enabled: bool = True self.progress: Optional[Progress] = None @@ -242,17 +243,12 @@ def __init__( self.theme = theme @property - def refresh_rate_per_second(self) -> float: - """Refresh rate for Rich Progress. - - Returns: Refresh rate for Progress Bar. - Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress). - """ - return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1 + def refresh_rate(self) -> float: + return self._refresh_rate @property def is_enabled(self) -> bool: - return self._enabled and self._refresh_rate_per_second > 0 + return self._enabled and self.refresh_rate > 0 @property def is_disabled(self) -> bool: @@ -289,7 +285,7 @@ def _init_progress(self, trainer): self.progress = CustomProgress( *self.configure_columns(trainer), self._metric_component, - refresh_per_second=self.refresh_rate_per_second, + auto_refresh=False, disable=self.is_disabled, console=self._console, ) @@ -297,6 +293,10 @@ def _init_progress(self, trainer): # progress has started self._progress_stopped = False + def refresh(self) -> None: + if self.progress: + self.progress.refresh() + def on_train_start(self, trainer, pl_module): super().on_train_start(trainer, pl_module) self._init_progress(trainer) @@ -328,10 +328,12 @@ def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self._init_progress(trainer) self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description) + self.refresh() def on_sanity_check_end(self, trainer, pl_module): super().on_sanity_check_end(trainer, pl_module) self._update(self.val_sanity_progress_bar_id, visible=False) + self.refresh() def on_train_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) @@ -354,6 +356,7 @@ def on_train_epoch_start(self, trainer, pl_module): self.progress.reset( self.main_progress_bar_id, total=total_batches, description=train_description, visible=True ) + self.refresh() def on_validation_epoch_start(self, trainer, pl_module): super().on_validation_epoch_start(trainer, pl_module) @@ -364,6 +367,7 @@ def on_validation_epoch_start(self, trainer, pl_module): val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch total_val_batches = self.total_val_batches * val_checks_per_epoch self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False) + self.refresh() def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]: if self.progress is not None: @@ -371,45 +375,54 @@ def _add_task(self, total_batches: int, description: str, visible: bool = True) f"[{self.theme.description}]{description}", total=total_batches, visible=visible ) - def _update(self, progress_bar_id: int, visible: bool = True) -> None: - if self.progress is not None: - self.progress.update(progress_bar_id, advance=1.0, visible=visible) + def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None: + if self.progress is not None and self._should_update(current, total): + self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible) + self.refresh() + + def _should_update(self, current: int, total: int) -> bool: + return self.is_enabled and (current % self.refresh_rate == 0 or current == total) def on_validation_epoch_end(self, trainer, pl_module): super().on_validation_epoch_end(trainer, pl_module) if self.val_progress_bar_id is not None: - self._update(self.val_progress_bar_id, visible=False) + self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False) def on_test_epoch_start(self, trainer, pl_module): - super().on_train_epoch_start(trainer, pl_module) self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description) + self.refresh() def on_predict_epoch_start(self, trainer, pl_module): super().on_predict_epoch_start(trainer, pl_module) self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) + self.refresh() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) - self._update(self.main_progress_bar_id) + self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches) self._update_metrics(trainer, pl_module) + self.refresh() def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if trainer.sanity_checking: - self._update(self.val_sanity_progress_bar_id) + self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches) elif self.val_progress_bar_id is not None: # check to see if we should update the main training progress bar if self.main_progress_bar_id is not None: - self._update(self.main_progress_bar_id) - self._update(self.val_progress_bar_id) + self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches) + self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches) + self.refresh() def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self._update(self.test_progress_bar_id) + self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches) + self.refresh() def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self._update(self.predict_progress_bar_id) + self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches) + self.refresh() def _get_train_description(self, current_epoch: int) -> str: train_description = f"Epoch {current_epoch}" diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 8f3f20630b5c0..8ca7326fd78f6 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -36,11 +36,11 @@ def test_rich_progress_bar_callback(): @RunIf(rich=True) -def test_rich_progress_bar_refresh_rate(): - progress_bar = RichProgressBar(refresh_rate_per_second=1) +def test_rich_progress_bar_refresh_rate_enabled(): + progress_bar = RichProgressBar(refresh_rate=1) assert progress_bar.is_enabled assert not progress_bar.is_disabled - progress_bar = RichProgressBar(refresh_rate_per_second=0) + progress_bar = RichProgressBar(refresh_rate=0) assert not progress_bar.is_enabled assert progress_bar.is_disabled @@ -180,3 +180,24 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count): ) trainer.fit(model) assert mock_progress_reset.call_count == reset_call_count + + +@RunIf(rich=True) +@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") +@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)])) +def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count): + + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + limit_train_batches=6, + limit_val_batches=6, + max_epochs=1, + callbacks=RichProgressBar(refresh_rate=refresh_rate), + ) + + trainer.fit(model) + + assert progress_update.call_count == expected_call_count From c09c9c760760859a101088ab54b61f933fee53c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Nov 2021 07:49:52 +0100 Subject: [PATCH 10/10] Remove redundant fit call from accelerator connector test (#10626) --- tests/accelerators/test_accelerator_connector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index c789e86c161a1..a9c9c50d80168 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -350,7 +350,6 @@ def test_accelerator_choice_ddp_cpu_and_strategy_spawn(tmpdir): def _test_accelerator_choice_ddp_cpu_and_strategy(tmpdir, ddp_strategy_class): - model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, strategy=ddp_strategy_class(find_unused_parameters=True), @@ -362,7 +361,6 @@ def _test_accelerator_choice_ddp_cpu_and_strategy(tmpdir, ddp_strategy_class): assert isinstance(trainer.accelerator, CPUAccelerator) assert trainer.training_type_plugin.num_processes == 2 assert trainer.training_type_plugin.parallel_devices == [torch.device("cpu")] * 2 - trainer.fit(model) @mock.patch.dict(