From 7a5ef1b31550b29ecf9209e5ad88dbc7bbbadcea Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 22 May 2021 08:37:18 -0700 Subject: [PATCH 1/6] Remove ProfilerConnector class --- .../trainer/connectors/profiler_connector.py | 61 ------------------- pytorch_lightning/trainer/trainer.py | 35 +++++++++-- 2 files changed, 30 insertions(+), 66 deletions(-) delete mode 100644 pytorch_lightning/trainer/connectors/profiler_connector.py diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py deleted file mode 100644 index 5fad9bca8ecf1..0000000000000 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License -from typing import Union -from weakref import proxy - -from pytorch_lightning.profiler import ( - AdvancedProfiler, - BaseProfiler, - PassThroughProfiler, - PyTorchProfiler, - SimpleProfiler, -) -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -PROFILERS = { - "simple": SimpleProfiler, - "advanced": AdvancedProfiler, - "pytorch": PyTorchProfiler, -} - - -class ProfilerConnector: - - def __init__(self, trainer): - self.trainer = trainer - - def on_trainer_init(self, profiler: Union[BaseProfiler, str]): - - if profiler and not isinstance(profiler, (str, BaseProfiler)): - raise MisconfigurationException( - "Only None, str and subclasses of `BaseProfiler`" - " are valid values for `Trainer`'s `profiler` parameter." - f" Received {profiler} which is of type {type(profiler)}." 
- ) - if isinstance(profiler, str): - if profiler.lower() in PROFILERS: - profiler_class = PROFILERS[profiler.lower()] - profiler = profiler_class() - else: - raise ValueError( - "When passing string value for the `profiler` parameter of" - " `Trainer`, it can only be 'simple', 'advanced' or 'pytorch'" - ) - self.trainer.profiler = profiler or PassThroughProfiler() - - def setup(self) -> None: - trainer = self.trainer - local_rank = trainer.local_rank if trainer.world_size > 1 else None - trainer.profiler._lightning_module = proxy(trainer.lightning_module) - trainer.profiler.setup(stage=trainer.state.fn._setup_fn, local_rank=local_rank, log_dir=trainer.log_dir) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6a20625978e39..9b5385f5b975d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -18,6 +18,7 @@ from itertools import count from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Union +from weakref import proxy import torch from torch.utils.data import DataLoader @@ -31,7 +32,13 @@ from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.plugins import Plugin from pytorch_lightning.plugins.environments import ClusterEnvironment -from pytorch_lightning.profiler import BaseProfiler +from pytorch_lightning.profiler import ( + AdvancedProfiler, + BaseProfiler, + PassThroughProfiler, + PyTorchProfiler, + SimpleProfiler, +) from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import ConfigValidator from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector @@ -43,7 +50,6 @@ from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector -from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin @@ -325,7 +331,6 @@ def __init__( self.callback_connector = CallbackConnector(self) self.debugging_connector = DebuggingConnector(self) self.training_tricks_connector = TrainingTricksConnector(self) - self.profile_connector = ProfilerConnector(self) self.checkpoint_connector = CheckpointConnector(self) self.slurm_connector = SLURMConnector(self) self.tuner = Tuner(self) @@ -382,7 +387,7 @@ def __init__( self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size) # configure profiler - self.profile_connector.on_trainer_init(profiler) + self.__init_profiler(profiler) # init logger flags self.logger_connector.on_trainer_init( @@ -813,7 +818,7 @@ def _dispatch(self): def run_stage(self): self.accelerator.dispatch(self) - self.profile_connector.setup() + self.__setup_profiler() if self.evaluating: return self._run_evaluate() @@ -1256,3 +1261,23 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any: @staticmethod def _log_api_event(event: str) -> None: torch._C._log_api_usage_once("lightning.trainer." 
+ event) + + def __init_profiler(self, profiler: Optional[Union[BaseProfiler, str]]) -> None: + if isinstance(profiler, str): + PROFILERS = { + "simple": SimpleProfiler, + "advanced": AdvancedProfiler, + "pytorch": PyTorchProfiler, + } + if profiler.lower() not in PROFILERS: + raise MisconfigurationException( + f"When passing string value for the `profiler` parameter of `Trainer, it can only be one of {list(PROFILERS.keys())}" + ) + profiler_class = PROFILERS[profiler.lower()] + profiler = profiler_class() + self.profiler = profiler or PassThroughProfiler() + + def __setup_profiler(self) -> None: + local_rank = self.local_rank if self.world_size > 1 else None + self.profiler._lightning_module = proxy(self.lightning_module) + self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir) From e5953b4eecaf37279f5125986e865a2d01e2f82f Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 22 May 2021 08:38:36 -0700 Subject: [PATCH 2/6] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9b5385f5b975d..360aa18921f8f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1271,7 +1271,8 @@ def __init_profiler(self, profiler: Optional[Union[BaseProfiler, str]]) -> None: } if profiler.lower() not in PROFILERS: raise MisconfigurationException( - f"When passing string value for the `profiler` parameter of `Trainer, it can only be one of {list(PROFILERS.keys())}" + "When passing string value for the `profiler` parameter of `Trainer, " + f" it can only be one of {list(PROFILERS.keys())}" ) profiler_class = PROFILERS[profiler.lower()] profiler = profiler_class() From c7d36a3e034edd6972aab267032ce2c7dee086b2 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 22 May 2021 08:39:39 -0700 Subject: [PATCH 3/6] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d81adb23cf86..0551a281b431c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed +- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654)) + + - Prune deprecated classif. 
metrics from `pytorch_lightning.metrics.functional.classification` ([#7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499)) From 1f7cdc3b89ff2aa937b647326fa2a8dc7b4ded79 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 22 May 2021 08:40:34 -0700 Subject: [PATCH 4/6] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 360aa18921f8f..eeb641b7c0b23 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1276,7 +1276,7 @@ def __init_profiler(self, profiler: Optional[Union[BaseProfiler, str]]) -> None: ) profiler_class = PROFILERS[profiler.lower()] profiler = profiler_class() - self.profiler = profiler or PassThroughProfiler() + self.profiler: BaseProfiler = profiler or PassThroughProfiler() def __setup_profiler(self) -> None: local_rank = self.local_rank if self.world_size > 1 else None From 5fec1fb62375b85dd66c3efd7792e567d0bf3cd8 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 22 May 2021 08:44:05 -0700 Subject: [PATCH 5/6] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index eeb641b7c0b23..de207be711fc1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1269,12 +1269,13 @@ def __init_profiler(self, profiler: Optional[Union[BaseProfiler, str]]) -> None: "advanced": AdvancedProfiler, "pytorch": PyTorchProfiler, } - if profiler.lower() not in PROFILERS: + profiler = profiler.lower() + if profiler not in PROFILERS: raise MisconfigurationException( "When passing string value for the `profiler` parameter of `Trainer, " f" it can only be one of {list(PROFILERS.keys())}" ) - profiler_class = PROFILERS[profiler.lower()] + profiler_class = PROFILERS[profiler] profiler = profiler_class() self.profiler: BaseProfiler = profiler or PassThroughProfiler() From f755f6a95e115c25e601c5bda6f3608228de7a1f Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 22 May 2021 09:01:04 -0700 Subject: [PATCH 6/6] tests --- pytorch_lightning/trainer/trainer.py | 2 +- tests/test_profiler.py | 25 ++++++++++++++++++- tests/trainer/test_trainer.py | 37 ---------------------------- 3 files changed, 25 insertions(+), 39 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index de207be711fc1..862abdbea46ca 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1272,7 +1272,7 @@ def __init_profiler(self, profiler: Optional[Union[BaseProfiler, str]]) -> None: profiler = profiler.lower() if profiler not in PROFILERS: raise MisconfigurationException( - "When passing string value for the `profiler` parameter of `Trainer, " + "When passing string value for the `profiler` parameter of `Trainer`," f" it can only be one of {list(PROFILERS.keys())}" ) profiler_class = PROFILERS[profiler] diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 202a525a15b74..acc2bac1c466f 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -23,7 +23,7 @@ from packaging.version import Version from pytorch_lightning import Callback, Trainer -from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler +from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, 
SimpleProfiler from pytorch_lightning.profiler.pytorch import RegisterRecordFunction from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE @@ -512,3 +512,26 @@ def test_pytorch_profiler_deepcopy(tmpdir): torch.tensor(1) pytorch_profiler.describe() assert deepcopy(pytorch_profiler) + + +@pytest.mark.parametrize(['profiler', 'expected'], [ + (None, PassThroughProfiler), + (SimpleProfiler(), SimpleProfiler), + (AdvancedProfiler(), AdvancedProfiler), + ('simple', SimpleProfiler), + ('Simple', SimpleProfiler), + ('advanced', AdvancedProfiler), + ('pytorch', PyTorchProfiler), +]) +def test_trainer_profiler_correct_args(profiler, expected): + kwargs = {'profiler': profiler} if profiler is not None else {} + trainer = Trainer(**kwargs) + assert isinstance(trainer.profiler, expected) + + +def test_trainer_profiler_incorrect_str_arg(): + with pytest.raises( + MisconfigurationException, + match=r"When passing string value for the `profiler` parameter of `Trainer`, it can only be one of.*" + ): + Trainer(profiler="unknown_profiler") diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2d0b68b1a6cc7..a8567db70d0a6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -36,7 +36,6 @@ from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler from pytorch_lightning.plugins import DDPSpawnPlugin -from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import DeviceType, DistributedType from pytorch_lightning.utilities.cloud_io import load as pl_load @@ -1305,42 +1304,6 @@ def training_step(self, *args, **kwargs): log_metrics_mock.assert_has_calls(expected_calls) -@pytest.mark.parametrize(['profiler', 'expected'], [ - (None, PassThroughProfiler), - (SimpleProfiler(), SimpleProfiler), - (AdvancedProfiler(), AdvancedProfiler), - ('simple', SimpleProfiler), - ('Simple', SimpleProfiler), - ('advanced', AdvancedProfiler), - ('pytorch', PyTorchProfiler), -]) -def test_trainer_profiler_correct_args(profiler, expected): - kwargs = {'profiler': profiler} if profiler is not None else {} - trainer = Trainer(**kwargs) - assert isinstance(trainer.profiler, expected) - - -def test_trainer_profiler_incorrect_str_arg(): - with pytest.raises(ValueError, match=r".*can only be 'simple', 'advanced' or 'pytorch'"): - Trainer(profiler="unknown_profiler") - - -@pytest.mark.parametrize('profiler', ( - 42, - [42], - dict(a=42), - torch.tensor(42), - Trainer(), -)) -def test_trainer_profiler_incorrect_arg_type(profiler): - with pytest.raises( - MisconfigurationException, - match="Only None, str and subclasses of `BaseProfiler`" - r" are valid values for `Trainer`'s `profiler` parameter. *" - ): - Trainer(profiler=profiler) - - class TestLightningDataModule(LightningDataModule): def __init__(self, dataloaders):
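
Usage sketch (illustrative; not part of the patch series itself). With these six patches applied, profiler selection moves from the removed `ProfilerConnector` into `Trainer.__init_profiler`. The snippet below mirrors the parametrized cases added to tests/test_profiler.py in PATCH 6/6 and assumes a pytorch-lightning checkout with the full series applied:

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # String values are lower-cased before the registry lookup, so "Simple" works too.
    trainer = Trainer(profiler="simple")
    assert isinstance(trainer.profiler, SimpleProfiler)

    # A BaseProfiler instance is used as-is.
    trainer = Trainer(profiler=AdvancedProfiler())
    assert isinstance(trainer.profiler, AdvancedProfiler)

    # Omitting the argument falls back to the no-op PassThroughProfiler.
    trainer = Trainer()
    assert isinstance(trainer.profiler, PassThroughProfiler)

    # Unrecognized strings now raise MisconfigurationException (previously ValueError).
    try:
        Trainer(profiler="unknown_profiler")
    except MisconfigurationException as err:
        print(err)

One behavioral change worth noting: `ProfilerConnector.on_trainer_init` rejected values that were neither `None`, `str`, nor a `BaseProfiler` subclass (e.g. `profiler=42`) with a `MisconfigurationException`, but `__init_profiler` does not carry that type check over. That is why PATCH 6/6 deletes `test_trainer_profiler_incorrect_arg_type` while moving the remaining profiler-argument tests from tests/trainer/test_trainer.py to tests/test_profiler.py.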