Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[doc] Move each profiler to its own file + Add missing PyTorchProfiler to the doc #7822

Merged
merged 10 commits into from
Jun 4, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `Trainer.fit` now raises an error when using manual optimization with unsupported features such as `gradient_clip_val` or `accumulate_grad_batches` ([#7788](https://github.com/PyTorchLightning/pytorch-lightning/pull/7788))


- Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))


### Deprecated


Expand Down
9 changes: 8 additions & 1 deletion docs/source/api_references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,15 @@ Profiler API
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst

AbstractProfiler
AdvancedProfiler
BaseProfiler
PassThroughProfiler
PyTorchProfiler
SimpleProfiler
tchaton marked this conversation as resolved.
Show resolved Hide resolved

profilers

Trainer API
-----------
Expand Down
10 changes: 6 additions & 4 deletions pytorch_lightning/profiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,16 @@ def custom_processing_step(self, data):
python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))'

"""

from pytorch_lightning.profiler.profilers import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.profiler.advanced import AdvancedProfiler
from pytorch_lightning.profiler.base import AbstractProfiler, BaseProfiler, PassThroughProfiler
from pytorch_lightning.profiler.pytorch import PyTorchProfiler
from pytorch_lightning.profiler.simple import SimpleProfiler

__all__ = [
'AbstractProfiler',
'BaseProfiler',
'SimpleProfiler',
'AdvancedProfiler',
'PassThroughProfiler',
"PyTorchProfiler",
'PyTorchProfiler',
'SimpleProfiler',
]
92 changes: 92 additions & 0 deletions pytorch_lightning/profiler/advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiler to check if there are any bottlenecks in your code."""
import cProfile
import io
import logging
import pstats
from pathlib import Path
from typing import Dict, Optional, Union

from pytorch_lightning.profiler.base import BaseProfiler

log = logging.getLogger(__name__)


class AdvancedProfiler(BaseProfiler):
"""
This profiler uses Python's cProfiler to record more detailed information about
time spent in each function call recorded during a given action. The output is quite
verbose and you should only use this if you want very detailed reports.
"""

def __init__(
self,
dirpath: Optional[Union[str, Path]] = None,
filename: Optional[str] = None,
line_count_restriction: float = 1.0,
output_filename: Optional[str] = None,
) -> None:
"""
Args:
dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`)
will be used.

filename: If present, filename where the profiler results will be saved instead of printing to stdout.
The ``.txt`` extension will be used automatically.

line_count_restriction: this can be used to limit the number of functions
reported for each action. either an integer (to select a count of lines),
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)

Raises:
ValueError:
If you attempt to stop recording an action which was never started.
"""
super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename)
self.profiled_actions: Dict[str, cProfile.Profile] = {}
self.line_count_restriction = line_count_restriction

def start(self, action_name: str) -> None:
if action_name not in self.profiled_actions:
self.profiled_actions[action_name] = cProfile.Profile()
self.profiled_actions[action_name].enable()

def stop(self, action_name: str) -> None:
pr = self.profiled_actions.get(action_name)
if pr is None:
raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
pr.disable()

def summary(self) -> str:
recorded_stats = {}
for action_name, pr in self.profiled_actions.items():
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
ps.print_stats(self.line_count_restriction)
recorded_stats[action_name] = s.getvalue()
return self._stats_to_str(recorded_stats)

def teardown(self, stage: Optional[str] = None) -> None:
super().teardown(stage=stage)
self.profiled_actions = {}

def __reduce__(self):
# avoids `TypeError: cannot pickle 'cProfile.Profile' object`
return (
self.__class__,
tuple(),
dict(dirpath=self.dirpath, filename=self.filename, line_count_restriction=self.line_count_restriction),
)
215 changes: 215 additions & 0 deletions pytorch_lightning/profiler/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiler to check if there are any bottlenecks in your code."""
import logging
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, Optional, TextIO, Union

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem

log = logging.getLogger(__name__)


class AbstractProfiler(ABC):
"""Specification of a profiler."""

@abstractmethod
def start(self, action_name: str) -> None:
"""Defines how to start recording an action."""

@abstractmethod
def stop(self, action_name: str) -> None:
"""Defines how to record the duration once an action is complete."""

@abstractmethod
def summary(self) -> str:
"""Create profiler summary in text format."""

@abstractmethod
def setup(self, **kwargs: Any) -> None:
"""Execute arbitrary pre-profiling set-up steps as defined by subclass."""

@abstractmethod
def teardown(self, **kwargs: Any) -> None:
"""Execute arbitrary post-profiling tear-down steps as defined by subclass."""


class BaseProfiler(AbstractProfiler):
"""
If you wish to write a custom profiler, you should inherit from this class.
"""

def __init__(
self,
dirpath: Optional[Union[str, Path]] = None,
filename: Optional[str] = None,
output_filename: Optional[str] = None,
) -> None:
self.dirpath = dirpath
self.filename = filename
if output_filename is not None:
rank_zero_warn(
"`Profiler` signature has changed in v1.3. The `output_filename` parameter has been removed in"
" favor of `dirpath` and `filename`. Support for the old signature will be removed in v1.5",
DeprecationWarning
)
filepath = Path(output_filename)
self.dirpath = filepath.parent
self.filename = filepath.stem

self._output_file: Optional[TextIO] = None
self._write_stream: Optional[Callable] = None
self._local_rank: Optional[int] = None
self._log_dir: Optional[str] = None
self._stage: Optional[str] = None

@contextmanager
def profile(self, action_name: str) -> None:
"""
Yields a context manager to encapsulate the scope of a profiled action.

Example::

with self.profile('load training data'):
# load training data code

The profiler will start once you've entered the context and will automatically
stop once you exit the code block.
"""
try:
self.start(action_name)
yield action_name
finally:
self.stop(action_name)

def profile_iterable(self, iterable, action_name: str) -> None:
iterator = iter(iterable)
while True:
try:
self.start(action_name)
value = next(iterator)
self.stop(action_name)
yield value
except StopIteration:
self.stop(action_name)
break

def _rank_zero_info(self, *args, **kwargs) -> None:
if self._local_rank in (None, 0):
log.info(*args, **kwargs)

def _prepare_filename(self, extension: str = ".txt") -> str:
filename = ""
if self._stage is not None:
filename += f"{self._stage}-"
filename += str(self.filename)
if self._local_rank is not None:
filename += f"-{self._local_rank}"
filename += extension
return filename

def _prepare_streams(self) -> None:
if self._write_stream is not None:
return
if self.filename:
filepath = os.path.join(self.dirpath, self._prepare_filename())
fs = get_filesystem(filepath)
file = fs.open(filepath, "a")
self._output_file = file
self._write_stream = file.write
else:
self._write_stream = self._rank_zero_info

def describe(self) -> None:
"""Logs a profile report after the conclusion of run."""
# there are pickling issues with open file handles in Python 3.6
# so to avoid them, we open and close the files within this function
# by calling `_prepare_streams` and `teardown`
self._prepare_streams()
summary = self.summary()
if summary:
self._write_stream(summary)
if self._output_file is not None:
self._output_file.flush()
self.teardown(stage=self._stage)

def _stats_to_str(self, stats: Dict[str, str]) -> str:
stage = f"{self._stage.upper()} " if self._stage is not None else ""
output = [stage + "Profiler Report"]
for action, value in stats.items():
header = f"Profile stats for: {action}"
if self._local_rank is not None:
header += f" rank: {self._local_rank}"
output.append(header)
output.append(value)
return os.linesep.join(output)

def setup(
self,
stage: Optional[str] = None,
local_rank: Optional[int] = None,
log_dir: Optional[str] = None,
) -> None:
"""Execute arbitrary pre-profiling set-up steps."""
self._stage = stage
self._local_rank = local_rank
self._log_dir = log_dir
self.dirpath = self.dirpath or log_dir

def teardown(self, stage: Optional[str] = None) -> None:
"""
Execute arbitrary post-profiling tear-down steps.

Closes the currently open file and stream.
"""
self._write_stream = None
if self._output_file is not None:
self._output_file.close()
self._output_file = None # can't pickle TextIOWrapper

def __del__(self) -> None:
self.teardown(stage=self._stage)

def start(self, action_name: str) -> None:
raise NotImplementedError

def stop(self, action_name: str) -> None:
raise NotImplementedError

def summary(self) -> str:
raise NotImplementedError

@property
def local_rank(self) -> int:
return 0 if self._local_rank is None else self._local_rank


class PassThroughProfiler(BaseProfiler):
"""
This class should be used when you don't want the (small) overhead of profiling.
The Trainer uses this class by default.
"""

def start(self, action_name: str) -> None:
pass

def stop(self, action_name: str) -> None:
pass

def summary(self) -> str:
return ""
Loading