diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5d81adb23cf86..0551a281b431c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Removed
 
+- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
+
+
 - Prune deprecated classif. metrics from `pytorch_lightning.metrics.functional.classification` ([#7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499))
 
 
diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py
deleted file mode 100644
index 5fad9bca8ecf1..0000000000000
--- a/pytorch_lightning/trainer/connectors/profiler_connector.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-from typing import Union
-from weakref import proxy
-
-from pytorch_lightning.profiler import (
-    AdvancedProfiler,
-    BaseProfiler,
-    PassThroughProfiler,
-    PyTorchProfiler,
-    SimpleProfiler,
-)
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-
-PROFILERS = {
-    "simple": SimpleProfiler,
-    "advanced": AdvancedProfiler,
-    "pytorch": PyTorchProfiler,
-}
-
-
-class ProfilerConnector:
-
-    def __init__(self, trainer):
-        self.trainer = trainer
-
-    def on_trainer_init(self, profiler: Union[BaseProfiler, str]):
-
-        if profiler and not isinstance(profiler, (str, BaseProfiler)):
-            raise MisconfigurationException(
-                "Only None, str and subclasses of `BaseProfiler`"
-                " are valid values for `Trainer`'s `profiler` parameter."
-                f" Received {profiler} which is of type {type(profiler)}."
-            )
-        if isinstance(profiler, str):
-            if profiler.lower() in PROFILERS:
-                profiler_class = PROFILERS[profiler.lower()]
-                profiler = profiler_class()
-            else:
-                raise ValueError(
-                    "When passing string value for the `profiler` parameter of"
-                    " `Trainer`, it can only be 'simple', 'advanced' or 'pytorch'"
-                )
-        self.trainer.profiler = profiler or PassThroughProfiler()
-
-    def setup(self) -> None:
-        trainer = self.trainer
-        local_rank = trainer.local_rank if trainer.world_size > 1 else None
-        trainer.profiler._lightning_module = proxy(trainer.lightning_module)
-        trainer.profiler.setup(stage=trainer.state.fn._setup_fn, local_rank=local_rank, log_dir=trainer.log_dir)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6a20625978e39..862abdbea46ca 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -18,6 +18,7 @@
 from itertools import count
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional, Union
+from weakref import proxy
 
 import torch
 from torch.utils.data import DataLoader
@@ -31,7 +32,13 @@
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.plugins import Plugin
 from pytorch_lightning.plugins.environments import ClusterEnvironment
-from pytorch_lightning.profiler import BaseProfiler
+from pytorch_lightning.profiler import (
+    AdvancedProfiler,
+    BaseProfiler,
+    PassThroughProfiler,
+    PyTorchProfiler,
+    SimpleProfiler,
+)
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
@@ -43,7 +50,6 @@
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
-from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector
 from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
 from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
@@ -325,7 +331,6 @@ def __init__(
         self.callback_connector = CallbackConnector(self)
         self.debugging_connector = DebuggingConnector(self)
         self.training_tricks_connector = TrainingTricksConnector(self)
-        self.profile_connector = ProfilerConnector(self)
         self.checkpoint_connector = CheckpointConnector(self)
         self.slurm_connector = SLURMConnector(self)
         self.tuner = Tuner(self)
@@ -382,7 +387,7 @@ def __init__(
         self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)
 
         # configure profiler
-        self.profile_connector.on_trainer_init(profiler)
+        self.__init_profiler(profiler)
 
         # init logger flags
         self.logger_connector.on_trainer_init(
@@ -813,7 +818,7 @@ def _dispatch(self):
 
     def run_stage(self):
         self.accelerator.dispatch(self)
-        self.profile_connector.setup()
+        self.__setup_profiler()
 
         if self.evaluating:
             return self._run_evaluate()
@@ -1256,3 +1261,25 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
     @staticmethod
     def _log_api_event(event: str) -> None:
         torch._C._log_api_usage_once("lightning.trainer." + event)
+
+    def __init_profiler(self, profiler: Optional[Union[BaseProfiler, str]]) -> None:
+        if isinstance(profiler, str):
+            PROFILERS = {
+                "simple": SimpleProfiler,
+                "advanced": AdvancedProfiler,
+                "pytorch": PyTorchProfiler,
+            }
+            profiler = profiler.lower()
+            if profiler not in PROFILERS:
+                raise MisconfigurationException(
+                    "When passing string value for the `profiler` parameter of `Trainer`,"
+                    f" it can only be one of {list(PROFILERS.keys())}"
+                )
+            profiler_class = PROFILERS[profiler]
+            profiler = profiler_class()
+        self.profiler: BaseProfiler = profiler or PassThroughProfiler()
+
+    def __setup_profiler(self) -> None:
+        local_rank = self.local_rank if self.world_size > 1 else None
+        self.profiler._lightning_module = proxy(self.lightning_module)
+        self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
diff --git a/tests/test_profiler.py b/tests/test_profiler.py
index 202a525a15b74..acc2bac1c466f 100644
--- a/tests/test_profiler.py
+++ b/tests/test_profiler.py
@@ -23,7 +23,7 @@
 from packaging.version import Version
 
 from pytorch_lightning import Callback, Trainer
-from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler
+from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.profiler.pytorch import RegisterRecordFunction
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
@@ -512,3 +512,26 @@ def test_pytorch_profiler_deepcopy(tmpdir):
     torch.tensor(1)
     pytorch_profiler.describe()
     assert deepcopy(pytorch_profiler)
+
+
+@pytest.mark.parametrize(['profiler', 'expected'], [
+    (None, PassThroughProfiler),
+    (SimpleProfiler(), SimpleProfiler),
+    (AdvancedProfiler(), AdvancedProfiler),
+    ('simple', SimpleProfiler),
+    ('Simple', SimpleProfiler),
+    ('advanced', AdvancedProfiler),
+    ('pytorch', PyTorchProfiler),
+])
+def test_trainer_profiler_correct_args(profiler, expected):
+    kwargs = {'profiler': profiler} if profiler is not None else {}
+    trainer = Trainer(**kwargs)
+    assert isinstance(trainer.profiler, expected)
+
+
+def test_trainer_profiler_incorrect_str_arg():
+    with pytest.raises(
+        MisconfigurationException,
+        match=r"When passing string value for the `profiler` parameter of `Trainer`, it can only be one of.*"
+    ):
+        Trainer(profiler="unknown_profiler")
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 2d0b68b1a6cc7..a8567db70d0a6 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -36,7 +36,6 @@
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
 from pytorch_lightning.plugins import DDPSpawnPlugin
-from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import DeviceType, DistributedType
 from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -1305,42 +1304,6 @@ def training_step(self, *args, **kwargs):
     log_metrics_mock.assert_has_calls(expected_calls)
 
 
-@pytest.mark.parametrize(['profiler', 'expected'], [
-    (None, PassThroughProfiler),
-    (SimpleProfiler(), SimpleProfiler),
-    (AdvancedProfiler(), AdvancedProfiler),
-    ('simple', SimpleProfiler),
-    ('Simple', SimpleProfiler),
-    ('advanced', AdvancedProfiler),
-    ('pytorch', PyTorchProfiler),
-])
-def test_trainer_profiler_correct_args(profiler, expected):
-    kwargs = {'profiler': profiler} if profiler is not None else {}
-    trainer = Trainer(**kwargs)
-    assert isinstance(trainer.profiler, expected)
-
-
-def test_trainer_profiler_incorrect_str_arg():
-    with pytest.raises(ValueError, match=r".*can only be 'simple', 'advanced' or 'pytorch'"):
-        Trainer(profiler="unknown_profiler")
-
-
-@pytest.mark.parametrize('profiler', (
-    42,
-    [42],
-    dict(a=42),
-    torch.tensor(42),
-    Trainer(),
-))
-def test_trainer_profiler_incorrect_arg_type(profiler):
-    with pytest.raises(
-        MisconfigurationException,
-        match="Only None, str and subclasses of `BaseProfiler`"
-        r" are valid values for `Trainer`'s `profiler` parameter. *"
-    ):
-        Trainer(profiler=profiler)
-
-
 class TestLightningDataModule(LightningDataModule):
 
     def __init__(self, dataloaders):
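
For reference, a minimal sketch of the `Trainer(profiler=...)` behaviour that the relocated tests above cover (an illustrative snippet, not part of the diff itself):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PassThroughProfiler, SimpleProfiler
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# A string name is resolved case-insensitively to the matching profiler class.
trainer = Trainer(profiler="simple")
assert isinstance(trainer.profiler, SimpleProfiler)

# A `BaseProfiler` instance is used as-is.
trainer = Trainer(profiler=SimpleProfiler())
assert isinstance(trainer.profiler, SimpleProfiler)

# Omitting the argument falls back to the no-op PassThroughProfiler.
trainer = Trainer()
assert isinstance(trainer.profiler, PassThroughProfiler)

# Unknown string names now raise MisconfigurationException (previously ValueError).
try:
    Trainer(profiler="unknown_profiler")
except MisconfigurationException as err:
    print(err)
```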