Remove ProfilerConnector class #7654

Merged: 6 commits, May 24, 2021

This PR deletes pytorch_lightning/trainer/connectors/profiler_connector.py and inlines its logic into `Trainer` as the private `__init_profiler` and `__setup_profiler` methods; the corresponding tests move from tests/trainer/test_trainer.py into tests/test_profiler.py.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))


- Prune deprecated classif. metrics from `pytorch_lightning.metrics.functional.classification` ([#7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499))


61 changes: 0 additions & 61 deletions pytorch_lightning/trainer/connectors/profiler_connector.py

This file was deleted.
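
For context, the deleted connector held exactly the logic this PR inlines into `Trainer`. The following is a hedged reconstruction, inferred only from its two call sites in trainer.py (`on_trainer_init(profiler)` and `setup()`) and from the tests removed from tests/trainer/test_trainer.py further down; the verbatim deleted source may have differed:

```python
# Hedged reconstruction of the deleted connector; not the verbatim source.
from pytorch_lightning.profiler import (
    AdvancedProfiler,
    BaseProfiler,
    PassThroughProfiler,
    PyTorchProfiler,
    SimpleProfiler,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class ProfilerConnector:
    """Approximate shape of the removed class."""

    def __init__(self, trainer):
        self.trainer = trainer

    def on_trainer_init(self, profiler) -> None:
        # The removed test_trainer.py tests pin this validation behaviour:
        # bad types raised MisconfigurationException, bad strings ValueError.
        if profiler and not isinstance(profiler, (str, BaseProfiler)):
            raise MisconfigurationException(
                "Only None, str and subclasses of `BaseProfiler`"
                " are valid values for `Trainer`'s `profiler` parameter."
            )
        if isinstance(profiler, str):
            profilers = {
                "simple": SimpleProfiler,
                "advanced": AdvancedProfiler,
                "pytorch": PyTorchProfiler,
            }
            if profiler.lower() not in profilers:
                raise ValueError(
                    "The profiler can only be 'simple', 'advanced' or 'pytorch'"
                )
            profiler = profilers[profiler.lower()]()
        self.trainer.profiler = profiler or PassThroughProfiler()

    def setup(self) -> None:
        # Mirrors the inlined Trainer.__setup_profiler below, minus the
        # weakref proxy assignment, which this PR adds.
        trainer = self.trainer
        local_rank = trainer.local_rank if trainer.world_size > 1 else None
        trainer.profiler.setup(
            stage=trainer.state.fn._setup_fn,
            local_rank=local_rank,
            log_dir=trainer.log_dir,
        )
```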

37 changes: 32 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -18,6 +18,7 @@
from itertools import count
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union
from weakref import proxy

import torch
from torch.utils.data import DataLoader
@@ -31,7 +32,13 @@
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.plugins import Plugin
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.profiler import BaseProfiler
from pytorch_lightning.profiler import (
AdvancedProfiler,
BaseProfiler,
PassThroughProfiler,
PyTorchProfiler,
SimpleProfiler,
)
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
@@ -43,7 +50,6 @@
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
@@ -325,7 +331,6 @@ def __init__(
self.callback_connector = CallbackConnector(self)
self.debugging_connector = DebuggingConnector(self)
self.training_tricks_connector = TrainingTricksConnector(self)
self.profile_connector = ProfilerConnector(self)
self.checkpoint_connector = CheckpointConnector(self)
self.slurm_connector = SLURMConnector(self)
self.tuner = Tuner(self)
@@ -382,7 +387,7 @@ def __init__(
self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

# configure profiler
self.profile_connector.on_trainer_init(profiler)
self.__init_profiler(profiler)

# init logger flags
self.logger_connector.on_trainer_init(
@@ -813,7 +818,7 @@ def _dispatch(self):

def run_stage(self):
self.accelerator.dispatch(self)
self.profile_connector.setup()
self.__setup_profiler()

if self.evaluating:
return self._run_evaluate()
@@ -1256,3 +1261,25 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
@staticmethod
def _log_api_event(event: str) -> None:
torch._C._log_api_usage_once("lightning.trainer." + event)

def __init_profiler(self, profiler: Optional[Union[BaseProfiler, str]]) -> None:
if isinstance(profiler, str):
PROFILERS = {
"simple": SimpleProfiler,
"advanced": AdvancedProfiler,
"pytorch": PyTorchProfiler,
}
profiler = profiler.lower()
if profiler not in PROFILERS:
raise MisconfigurationException(
"When passing string value for the `profiler` parameter of `Trainer`,"
f" it can only be one of {list(PROFILERS.keys())}"
)
profiler_class = PROFILERS[profiler]
profiler = profiler_class()
Review comment on lines +1266 to +1279 by @ananthsub (Contributor, Author), May 22, 2021:

> this could be formalized in a profiler registry later

(See the registry sketch after this file's diff.)

self.profiler: BaseProfiler = profiler or PassThroughProfiler()

def __setup_profiler(self) -> None:
local_rank = self.local_rank if self.world_size > 1 else None
self.profiler._lightning_module = proxy(self.lightning_module)
self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
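
Two notes on the inlined methods above. First, `__setup_profiler` hands the profiler a `weakref.proxy` to the LightningModule rather than a direct reference, presumably so the profiler does not keep the module alive through a strong reference cycle. Second, as the review comment above suggests, the hard-coded `PROFILERS` dict could later be formalized into a registry. Below is a minimal sketch of one possible shape; it is purely illustrative, not part of this PR, and `PROFILER_REGISTRY` and `resolve_profiler` are hypothetical names:

```python
from typing import Dict, Type

from pytorch_lightning.profiler import (
    AdvancedProfiler,
    BaseProfiler,
    PyTorchProfiler,
    SimpleProfiler,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Hypothetical registry: neither of these names exists in pytorch-lightning.
PROFILER_REGISTRY: Dict[str, Type[BaseProfiler]] = {
    "simple": SimpleProfiler,
    "advanced": AdvancedProfiler,
    "pytorch": PyTorchProfiler,
}


def resolve_profiler(name: str) -> BaseProfiler:
    """Instantiate the profiler registered under ``name`` (case-insensitive)."""
    try:
        return PROFILER_REGISTRY[name.lower()]()
    except KeyError:
        raise MisconfigurationException(
            f"Unknown profiler {name!r}; expected one of {list(PROFILER_REGISTRY)}"
        ) from None
```

Third-party code could then register a custom profiler with a single assignment (`PROFILER_REGISTRY["mine"] = MyProfiler`) instead of `Trainer.__init_profiler` growing a new branch.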
25 changes: 24 additions & 1 deletion tests/test_profiler.py
@@ -23,7 +23,7 @@
from packaging.version import Version

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
@@ -512,3 +512,26 @@ def test_pytorch_profiler_deepcopy(tmpdir):
torch.tensor(1)
pytorch_profiler.describe()
assert deepcopy(pytorch_profiler)


@pytest.mark.parametrize(['profiler', 'expected'], [
(None, PassThroughProfiler),
(SimpleProfiler(), SimpleProfiler),
(AdvancedProfiler(), AdvancedProfiler),
('simple', SimpleProfiler),
('Simple', SimpleProfiler),
('advanced', AdvancedProfiler),
('pytorch', PyTorchProfiler),
])
def test_trainer_profiler_correct_args(profiler, expected):
kwargs = {'profiler': profiler} if profiler is not None else {}
trainer = Trainer(**kwargs)
assert isinstance(trainer.profiler, expected)


def test_trainer_profiler_incorrect_str_arg():
with pytest.raises(
MisconfigurationException,
match=r"When passing string value for the `profiler` parameter of `Trainer`, it can only be one of.*"
):
Trainer(profiler="unknown_profiler")
37 changes: 0 additions & 37 deletions tests/trainer/test_trainer.py
@@ -36,7 +36,6 @@
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import DeviceType, DistributedType
from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -1305,42 +1304,6 @@ def training_step(self, *args, **kwargs):
log_metrics_mock.assert_has_calls(expected_calls)


@pytest.mark.parametrize(['profiler', 'expected'], [
(None, PassThroughProfiler),
(SimpleProfiler(), SimpleProfiler),
(AdvancedProfiler(), AdvancedProfiler),
('simple', SimpleProfiler),
('Simple', SimpleProfiler),
('advanced', AdvancedProfiler),
('pytorch', PyTorchProfiler),
])
def test_trainer_profiler_correct_args(profiler, expected):
kwargs = {'profiler': profiler} if profiler is not None else {}
trainer = Trainer(**kwargs)
assert isinstance(trainer.profiler, expected)


def test_trainer_profiler_incorrect_str_arg():
with pytest.raises(ValueError, match=r".*can only be 'simple', 'advanced' or 'pytorch'"):
Trainer(profiler="unknown_profiler")


@pytest.mark.parametrize('profiler', (
42,
[42],
dict(a=42),
torch.tensor(42),
Trainer(),
))
def test_trainer_profiler_incorrect_arg_type(profiler):
with pytest.raises(
MisconfigurationException,
match="Only None, str and subclasses of `BaseProfiler`"
r" are valid values for `Trainer`'s `profiler` parameter. *"
):
Trainer(profiler=profiler)


class TestLightningDataModule(LightningDataModule):

def __init__(self, dataloaders):