diff --git a/.gitignore b/.gitignore
index c007140257188..99939ff7fce0c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -157,4 +157,4 @@ tags
 data
 MNIST
 runs
-*traces*
+*trace*
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 51ad97decd867..1c0dfd5a482b7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,6 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `Trainer.predict` config validation ([#6543](https://github.com/PyTorchLightning/pytorch-lightning/pull/6543))
 
+- Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
+
+
 - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
 
@@ -68,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
 
+- Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
+
+
 ### Deprecated
 
 - `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
@@ -76,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
 
+- Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
+
+
 - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),
     [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),
diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py
index 5668fd6654b2f..54bc5cdf0122c 100644
--- a/pytorch_lightning/profiler/profilers.py
+++ b/pytorch_lightning/profiler/profilers.py
@@ -21,31 +21,19 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import Optional, Union
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Union
 
 import numpy as np
 
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 
 log = logging.getLogger(__name__)
 
 
-class BaseProfiler(ABC):
-    """
-    If you wish to write a custom profiler, you should inhereit from this class.
- """ - - def __init__(self, output_streams: Optional[Union[list, tuple]] = None): - """ - Args: - output_streams: callable - """ - if output_streams: - if not isinstance(output_streams, (list, tuple)): - output_streams = [output_streams] - else: - output_streams = [] - self.write_streams = output_streams +class AbstractProfiler(ABC): + """Specification of a profiler.""" @abstractmethod def start(self, action_name: str) -> None: @@ -55,23 +43,47 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: """Defines how to record the duration once an action is complete.""" - def setup( + @abstractmethod + def summary(self) -> str: + """Create profiler summary in text format.""" + + @abstractmethod + def setup(self, **kwargs: Any) -> None: + """Execute arbitrary pre-profiling set-up steps as defined by subclass.""" + + @abstractmethod + def teardown(self, **kwargs: Any) -> None: + """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" + + +class BaseProfiler(AbstractProfiler): + """ + If you wish to write a custom profiler, you should inherit from this class. + """ + + def __init__( self, - stage: Optional[str] = None, - local_rank: Optional[int] = None, - log_dir: Optional[str] = None + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + output_filename: Optional[str] = None, ) -> None: - """Execute arbitrary pre-profiling set-up steps.""" - self.stage = stage - self.local_rank = local_rank - self.log_dir = log_dir + self.dirpath = dirpath + self.filename = filename + if output_filename is not None: + rank_zero_warn( + "`Profiler` signature has changed in v1.3. The `output_filename` parameter has been removed in" + " favor of `dirpath` and `filename`. Support for the old signature will be removed in v1.5", + DeprecationWarning + ) + filepath = Path(output_filename) + self.dirpath = filepath.parent + self.filename = filepath.stem - def teardown(self, stage: Optional[str] = None) -> None: - """Execute arbitrary post-profiling tear-down steps.""" - self.stage = stage - if self.output_file: - self.output_file.close() - self.output_file = None + self._output_file: Optional[TextIO] = None + self._write_stream: Optional[Callable] = None + self._local_rank: Optional[int] = None + self._log_dir: Optional[str] = None + self._stage: Optional[str] = None @contextmanager def profile(self, action_name: str) -> None: @@ -104,19 +116,94 @@ def profile_iterable(self, iterable, action_name: str) -> None: self.stop(action_name) break + def _rank_zero_info(self, *args, **kwargs) -> None: + if self._local_rank in (None, 0): + log.info(*args, **kwargs) + + def _prepare_filename(self) -> str: + filename = "" + if self._stage is not None: + filename += f"{self._stage}-" + filename += str(self.filename) + if self._local_rank is not None: + filename += f"-{self.local_rank}" + filename += ".txt" + return filename + + def _prepare_streams(self) -> None: + if self._write_stream is not None: + return + if self.filename: + dirpath = self.dirpath or self._log_dir + filepath = os.path.join(dirpath, self._prepare_filename()) + fs = get_filesystem(filepath) + file = fs.open(filepath, "a") + self._output_file = file + self._write_stream = file.write + else: + self._write_stream = self._rank_zero_info + def describe(self) -> None: - """Logs a profile report after the conclusion of the training run.""" - for write in self.write_streams: - write(self.summary()) - if self.output_file is not None: - self.output_file.flush() + """Logs a profile report 
+        # there are pickling issues with open file handles in Python 3.6
+        # so to avoid them, we open and close the files within this function
+        # by calling `_prepare_streams` and `teardown`
+        self._prepare_streams()
+        self._write_stream(self.summary())
+        if self._output_file is not None:
+            self._output_file.flush()
+        self.teardown(stage=self._stage)
+
+    def _stats_to_str(self, stats: Dict[str, str]) -> str:
+        stage = f"{self._stage.upper()} " if self._stage is not None else ""
+        output = [stage + "Profiler Report"]
+        for action, value in stats.items():
+            header = f"Profile stats for: {action}"
+            if self._local_rank is not None:
+                header += f" rank: {self._local_rank}"
+            output.append(header)
+            output.append(value)
+        return os.linesep.join(output)
+
+    def setup(
+        self,
+        stage: Optional[str] = None,
+        local_rank: Optional[int] = None,
+        log_dir: Optional[str] = None,
+    ) -> None:
+        """Execute arbitrary pre-profiling set-up steps."""
+        self._stage = stage
+        self._local_rank = local_rank
+        self._log_dir = log_dir
+        if self.dirpath is None:
+            self.dirpath = self._log_dir
+
+    def teardown(self, stage: Optional[str] = None) -> None:
+        """
+        Execute arbitrary post-profiling tear-down steps.
+
+        Closes the currently open file and stream.
+        """
+        self._write_stream = None
+        if self._output_file is not None:
+            self._output_file.close()
+            self._output_file = None  # can't pickle TextIOWrapper
+
+    def __del__(self) -> None:
+        self.teardown(stage=self._stage)
+
+    def start(self, action_name: str) -> None:
+        raise NotImplementedError
+
+    def stop(self, action_name: str) -> None:
+        raise NotImplementedError
 
-    @abstractmethod
     def summary(self) -> str:
-        """Create profiler summary in text format."""
+        raise NotImplementedError
 
-    def __del__(self):
-        self.teardown(None)
+    @property
+    def local_rank(self):
+        return '0' if self._local_rank is None else self._local_rank
 
 
 class PassThroughProfiler(BaseProfiler):
@@ -125,10 +212,6 @@ class PassThroughProfiler(BaseProfiler):
     The Trainer uses this class by default.
     """
 
-    def __init__(self):
-        self.output_file = None
-        super().__init__(output_streams=None)
-
     def start(self, action_name: str) -> None:
         pass
 
@@ -145,30 +228,32 @@ class SimpleProfiler(BaseProfiler):
     the mean duration of each action and the total time spent over the entire training run.
     """
 
-    def __init__(self, output_filename: Optional[str] = None, extended=True):
+    def __init__(
+        self,
+        dirpath: Optional[Union[str, Path]] = None,
+        filename: Optional[str] = None,
+        extended: bool = True,
+        output_filename: Optional[str] = None,
+    ) -> None:
         """
         Args:
-            output_filename: optionally save profile results to file instead of printing
-                to std out when training is finished.
+            dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
+                ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`)
+                will be used.
+
+            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
+                The ``.txt`` extension will be used automatically.
 
         Raises:
             ValueError:
                 If you attempt to start an action which has already started, or
                 if you attempt to stop recording an action which was never started.
""" - self.current_actions = {} + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + self.current_actions: Dict[str, float] = {} self.recorded_durations = defaultdict(list) self.extended = extended - - self.output_fname = output_filename - self.output_file = None - if self.output_fname: - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") - - streaming_out = [self.output_file.write] if self.output_file else [log.info] self.start_time = time.monotonic() - super().__init__(output_streams=streaming_out) def start(self, action_name: str) -> None: if action_name in self.current_actions: @@ -183,14 +268,18 @@ def stop(self, action_name: str) -> None: duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def make_report(self): + def _make_report(self) -> Tuple[list, float]: total_duration = time.monotonic() - self.start_time report = [[a, d, 100. * np.sum(d) / total_duration] for a, d in self.recorded_durations.items()] report.sort(key=lambda x: x[2], reverse=True) return report, total_duration def summary(self) -> str: - output_string = "\n\nProfiler Report\n" + sep = os.linesep + output_string = "" + if self._stage is not None: + output_string += f"{self._stage.upper()} " + output_string += f"Profiler Report{sep}" if self.extended: @@ -198,16 +287,16 @@ def summary(self) -> str: max_key = np.max([len(k) for k in self.recorded_durations.keys()]) def log_row(action, mean, num_calls, total, per): - row = f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t|" + row = f"{sep}{action:<{max_key}s}\t| {mean:<15}\t|" row += f"{num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" return row output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %") output_string_len = len(output_string) - output_string += f"{os.linesep}{'-' * output_string_len}" - report, total_duration = self.make_report() + output_string += f"{sep}{'-' * output_string_len}" + report, total_duration = self._make_report() output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %") - output_string += f"{os.linesep}{'-' * output_string_len}" + output_string += f"{sep}{'-' * output_string_len}" for action, durations, duration_per in report: output_string += log_row( action, @@ -219,14 +308,14 @@ def log_row(action, mean, num_calls, total, per): else: def log_row(action, mean, total): - return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}" + return f"{sep}{action:<20s}\t| {mean:<15}\t| {total:<15}" output_string += log_row("Action", "Mean duration (s)", "Total time (s)") - output_string += f"{os.linesep}{'-' * 65}" + output_string += f"{sep}{'-' * 65}" for action, durations in self.recorded_durations.items(): output_string += log_row(action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}") - output_string += os.linesep + output_string += sep return output_string @@ -237,11 +326,22 @@ class AdvancedProfiler(BaseProfiler): verbose and you should only use this if you want very detailed reports. """ - def __init__(self, output_filename: Optional[str] = None, line_count_restriction: float = 1.0): + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + line_count_restriction: float = 1.0, + output_filename: Optional[str] = None, + ) -> None: """ Args: - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. + dirpath: Directory path for the ``filename``. 
+                ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`)
+                will be used.
+
+            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
+                The ``.txt`` extension will be used automatically.
+
             line_count_restriction: this can be used to limit the number of functions
                 reported for each action. either an integer (to select a count of lines),
                 or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
@@ -250,18 +350,10 @@ def __init__(self, output_filename: Optional[str] = None, line_count_restriction
 
         Raises:
             ValueError:
                 If you attempt to stop recording an action which was never started.
         """
-        self.profiled_actions = {}
+        super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename)
+        self.profiled_actions: Dict[str, cProfile.Profile] = {}
         self.line_count_restriction = line_count_restriction
 
-        self.output_fname = output_filename
-        self.output_file = None
-        if self.output_fname:
-            fs = get_filesystem(self.output_fname)
-            self.output_file = fs.open(self.output_fname, "w")
-
-        streaming_out = [self.output_file.write] if self.output_file else [log.info]
-        super().__init__(output_streams=streaming_out)
-
     def start(self, action_name: str) -> None:
         if action_name not in self.profiled_actions:
             self.profiled_actions[action_name] = cProfile.Profile()
@@ -270,9 +362,7 @@ def start(self, action_name: str) -> None:
     def stop(self, action_name: str) -> None:
         pr = self.profiled_actions.get(action_name)
         if pr is None:
-            raise ValueError(  # pragma: no-cover
-                f"Attempting to stop recording an action ({action_name}) which was never started."
-            )
+            raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
         pr.disable()
 
     def summary(self) -> str:
@@ -282,10 +372,16 @@ def summary(self) -> str:
             ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
             ps.print_stats(self.line_count_restriction)
             recorded_stats[action_name] = s.getvalue()
+        return self._stats_to_str(recorded_stats)
 
-        # log to standard out
-        output_string = f"{os.linesep}Profiler Report{os.linesep}"
-        for action, stats in recorded_stats.items():
-            output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"
+    def teardown(self, stage: Optional[str] = None) -> None:
+        super().teardown(stage=stage)
+        self.profiled_actions = {}
 
-        return output_string
+    def __reduce__(self):
+        # avoids `TypeError: cannot pickle 'cProfile.Profile' object`
+        return (
+            self.__class__,
+            tuple(),
+            dict(dirpath=self.dirpath, filename=self.filename, line_count_restriction=self.line_count_restriction),
+        )
diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py
index c35979fa918af..55b1c286789f4 100644
--- a/pytorch_lightning/profiler/pytorch.py
+++ b/pytorch_lightning/profiler/pytorch.py
@@ -16,13 +16,12 @@
 import inspect
 import logging
 import os
-from typing import List, Optional
+from pathlib import Path
+from typing import List, Optional, Union
 
 import torch
 
 from pytorch_lightning.profiler.profilers import BaseProfiler
-from pytorch_lightning.utilities import rank_zero_only
-from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.distributed import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -46,7 +45,8 @@ class PyTorchProfiler(BaseProfiler):
 
     def __init__(
         self,
-        output_filename: Optional[str] = None,
+        dirpath: Optional[Union[str, Path]] = None,
+        filename: Optional[str] = None,
         enabled: bool = True,
         use_cuda: bool = False,
         record_shapes: bool = False,
@@ -61,18 +61,19 @@ def __init__(
         row_limit: int = 20,
         sort_by_key: Optional[str] = None,
         profiled_functions: Optional[List] = None,
-        local_rank: Optional[int] = None,
+        output_filename: Optional[str] = None,
     ):
         """
         This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of
         different operators inside your model - both on the CPU and GPU
 
         Args:
+            dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
+                ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`)
+                will be used.
 
-            output_filename: optionally save profile results to file instead of printing
-                to std out when training is finished. When using ``ddp``,
-                each rank will stream the profiled operation to their own file
-                with the extension ``_{rank}.txt``
+            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
+                The ``.txt`` extension will be used automatically.
 
             enabled: Setting this to False makes this context manager a no-op.
 
@@ -116,13 +117,9 @@ def __init__(
             profiled_functions: list of profiled functions which will create a context manager on.
                 Any other will be pass through.
 
-            local_rank: When running in distributed setting, local_rank is used for each process
-                to write to their own file if `output_fname` is provided.
-
         Raises:
             MisconfigurationException:
-                If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``, or
-                if log file is not a ``.txt`` file.
+                If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``.
             ValueError:
                 If you attempt to stop recording an action which was never started.
""" @@ -159,37 +156,20 @@ def __init__( self.running_stack = [] self.profiler = None - self.output_fname = output_filename - self.output_file = None - if local_rank is not None: - self.setup(local_rank=local_rank) - self.setup = super().setup + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) - def setup(self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None): + def setup( + self, + stage: Optional[str] = None, + local_rank: Optional[int] = None, + log_dir: Optional[str] = None + ) -> None: super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir) - # when logging to `log.info`, only perform profiling on rank 0 - if local_rank != 0 and self.output_fname is None: - self.wrap_functions_into_rank_zero_only() - - if self.output_fname: - if local_rank is not None: - if '.txt' not in self.output_fname: - raise MisconfigurationException("Log file should be .txt file.") - - self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt") - - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") - - streaming_out = [self.output_file.write] if self.output_file else [log.info] - super().__init__(output_streams=streaming_out) - - def wrap_functions_into_rank_zero_only(self): - self.start = rank_zero_only(self.start) - self.stop = rank_zero_only(self.stop) - self.summary = rank_zero_only(self.summary) - self.describe = rank_zero_only(self.describe) + # if the user didn't provide `path_to_export_trace`, + # set it as TensorBoardLogger log_dir if exists + if self.path_to_export_trace is None: + self.path_to_export_trace = log_dir def start(self, action_name: str) -> None: if action_name not in self.profiled_functions: @@ -231,6 +211,7 @@ def _stop(self, action_name: str) -> None: # when running ``emit_nvtx``, PyTorch requires 2 context manager. # The parent_profiler is being closed too. 
             self._parent_profiler.__exit__(None, None, None)
+            self._parent_profiler = None
             return
 
         function_events = self.profiler.function_events
@@ -258,7 +239,6 @@ def stop(self, action_name: str) -> None:
     def summary(self) -> str:
         recorded_stats = {}
         output_string = ''
-        local_rank = '0' if self.local_rank is None else self.local_rank
 
         if not self.enabled:
             return output_string
@@ -271,7 +251,7 @@ def summary(self) -> str:
             function_events.populate_cpu_children = lambda: None
 
             if self.export_to_chrome:
-                filename = f"{action_name}_{local_rank}_trace.json"
+                filename = f"{action_name}_{self.local_rank}_trace.json"
                 path_to_trace = filename if self.path_to_export_trace is None \
                     else os.path.join(self.path_to_export_trace, filename)
                 function_events.export_chrome_trace(path_to_trace)
@@ -283,10 +263,4 @@ def summary(self) -> str:
             data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes)
             table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit)
             recorded_stats[action_name] = table
-
-        # log to standard out
-        output_string = f"{os.linesep}Profiler Report{os.linesep}"
-        for action, stats in recorded_stats.items():
-            output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}")
-
-        return output_string
+        return self._stats_to_str(recorded_stats)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 20c842939fe17..da41b9855b44a 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -15,6 +15,7 @@
 import torch
 
 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import PredictionCollection
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -99,6 +100,10 @@ def on_evaluation_end(self, *args, **kwargs):
         else:
             self.trainer.call_hook('on_validation_end', *args, **kwargs)
 
+        if self.trainer.state != TrainerState.FITTING:
+            # summarize profile results
+            self.trainer.profiler.describe()
+
     def reload_evaluation_dataloaders(self):
         model = self.trainer.lightning_module
         if self.trainer.testing:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 384a1b67a64f8..cc471f76b6033 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -137,9 +137,7 @@ def on_train_end(self):
             self.trainer.logger.finalize("success")
 
         # summarize profile results
-        # todo (tchaton) All ranks should call describe.
-        if self.trainer.global_rank == 0:
-            self.trainer.profiler.describe()
+        self.trainer.profiler.describe()
 
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index f449a37e33c25..0c5f581d7775c 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -20,6 +20,7 @@
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache
 from tests.deprecated_api import no_deprecated_call
 from tests.helpers import BoringModel
@@ -203,3 +204,12 @@ def on_test_epoch_end(self, outputs):
     model = NewSignatureModel()
     with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."):
         trainer.test(model)
+
+
+@pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
+def test_v1_5_0_profiler_output_filename(tmpdir, cls):
+    filepath = str(tmpdir / "test.txt")
+    with pytest.deprecated_call(match="`output_filename` parameter has been removed"):
+        profiler = cls(output_filename=filepath)
+    assert profiler.dirpath == tmpdir
+    assert profiler.filename == "test"
diff --git a/tests/test_profiler.py b/tests/test_profiler.py
index cc4fff3b7ede4..cf6afcc9b626c 100644
--- a/tests/test_profiler.py
+++ b/tests/test_profiler.py
@@ -14,6 +14,7 @@
 import logging
 import os
 import time
+from copy import deepcopy
 from distutils.version import LooseVersion
 from pathlib import Path
 
@@ -21,8 +22,7 @@
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import Callback
+from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -46,8 +46,7 @@ def _sleep_generator(durations):
 
 @pytest.fixture
 def simple_profiler():
-    profiler = SimpleProfiler()
-    return profiler
+    return SimpleProfiler()
 
 
 @pytest.mark.parametrize(["action", "expected"], [
@@ -93,14 +92,6 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5):
     assert all(durations < PROFILER_OVERHEAD_MAX_TOLERANCE)
 
 
-def test_simple_profiler_describe(caplog, simple_profiler):
-    """Ensure the profiler won't fail when reporting the summary."""
-    with caplog.at_level(logging.INFO):
-        simple_profiler.describe()
-
-    assert "Profiler Report" in caplog.text
-
-
 def test_simple_profiler_value_errors(simple_profiler):
     """Ensure errors are raised where expected."""
 
@@ -116,10 +107,75 @@ def test_simple_profiler_value_errors(simple_profiler):
         simple_profiler.stop(action)
 
 
+def test_simple_profiler_deepcopy(tmpdir):
+    simple_profiler = SimpleProfiler(dirpath=tmpdir, filename="test")
+    simple_profiler.describe()
+    assert deepcopy(simple_profiler)
+
+
+def test_simple_profiler_log_dir(tmpdir):
+    """Ensure the profiler dirpath defaults to `trainer.log_dir` when not present"""
+    profiler = SimpleProfiler(filename="profiler")
+    assert profiler._log_dir is None
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        profiler=profiler,
+    )
+    trainer.fit(model)
+
+    expected = profiler.dirpath
+    assert trainer.log_dir == expected
+    assert profiler._log_dir == trainer.log_dir
+    assert Path(os.path.join(profiler.dirpath, "fit-profiler.txt")).exists()
+
+
+@RunIf(skip_windows=True)
+def test_simple_profiler_distributed_files(tmpdir):
+    """Ensure the proper files are saved in distributed mode"""
+    profiler = SimpleProfiler(dirpath=tmpdir, filename='profiler')
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=2,
+        accelerator="ddp_cpu",
+        num_processes=2,
+        profiler=profiler,
+        logger=False,
+    )
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+
+    actual = set(os.listdir(profiler.dirpath))
+    expected = {f"{stage}-profiler-{rank}.txt" for stage in ("fit", "validate", "test") for rank in (0, 1)}
+    assert actual == expected
+
+    for f in profiler.dirpath.listdir():
+        assert f.read_text('utf-8')
+
+
+def test_simple_profiler_logs(tmpdir, caplog, simple_profiler):
+    """Ensure that the number of printed logs is correct"""
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=2,
+        profiler=simple_profiler,
+        logger=False,
+    )
+    with caplog.at_level(logging.INFO, logger="pytorch_lightning.profiler.profilers"):
+        trainer.fit(model)
+        trainer.test(model)
+
+    assert caplog.text.count("Profiler Report") == 2
+
+
 @pytest.fixture
 def advanced_profiler(tmpdir):
-    profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"))
-    return profiler
+    return AdvancedProfiler(dirpath=tmpdir, filename="profiler")
 
 
 @pytest.mark.parametrize(["action", "expected"], [
@@ -180,7 +236,8 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler):
         pass
     # log to stdout and print to file
     advanced_profiler.describe()
-    data = Path(advanced_profiler.output_fname).read_text()
+    path = advanced_profiler.dirpath / f"{advanced_profiler.filename}.txt"
+    data = path.read_text("utf-8")
     assert len(data) > 0
 
 
@@ -195,10 +252,14 @@ def test_advanced_profiler_value_errors(advanced_profiler):
         advanced_profiler.stop(action)
 
 
+def test_advanced_profiler_deepcopy(advanced_profiler):
+    advanced_profiler.describe()
+    assert deepcopy(advanced_profiler)
+
+
 @pytest.fixture
 def pytorch_profiler(tmpdir):
-    profiler = PyTorchProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"), local_rank=0)
-    return profiler
+    return PyTorchProfiler(dirpath=tmpdir, filename="profiler")
 
 
 def test_pytorch_profiler_describe(pytorch_profiler):
@@ -208,7 +269,8 @@ def test_pytorch_profiler_describe(pytorch_profiler):
 
     # log to stdout and print to file
     pytorch_profiler.describe()
-    data = Path(pytorch_profiler.output_fname).read_text()
+    path = pytorch_profiler.dirpath / f"{pytorch_profiler.filename}.txt"
+    data = path.read_text("utf-8")
     assert len(data) > 0
 
 
@@ -223,47 +285,53 @@ def test_pytorch_profiler_value_errors(pytorch_profiler):
         pytorch_profiler.stop(action)
 
 
-@RunIf(min_gpus=2, special=True)
-@pytest.mark.parametrize("use_output_filename", [False, True])
-def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename):
-    """Ensure that the profiler can be given to the training and default step are properly recorded.
""" - - if use_output_filename: - output_filename = os.path.join(tmpdir, "profiler.txt") - else: - output_filename = None +@RunIf(min_torch="1.6.0") +def test_advanced_profiler_cprofile_deepcopy(tmpdir): + """Checks for pickle issue reported in #6522""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + profiler="advanced", + stochastic_weight_avg=True, + ) + trainer.fit(model) - profiler = PyTorchProfiler(output_filename=output_filename) +@RunIf(min_gpus=2, special=True) +def test_pytorch_profiler_trainer_ddp(tmpdir): + """Ensure that the profiler can be given to the training and default step are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=None, filename="profiler") model = BoringModel() trainer = Trainer( - fast_dev_run=True, - profiler=profiler, + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + profiler=pytorch_profiler, accelerator="ddp", gpus=2, ) trainer.fit(model) - enabled = use_output_filename or not use_output_filename and profiler.local_rank == 0 + assert len(pytorch_profiler.summary()) > 0 + assert set(pytorch_profiler.profiled_actions) == {'training_step_and_backward', 'validation_step'} - if enabled: - assert len(profiler.summary()) > 0 - assert set(profiler.profiled_actions.keys()) == {'training_step_and_backward', 'validation_step'} - else: - assert profiler.summary() is None - assert set(profiler.profiled_actions.keys()) == set() + files = sorted(f for f in os.listdir(pytorch_profiler.dirpath) if "fit" in f) + rank = int(os.getenv("LOCAL_RANK", "0")) + expected = f"fit-profiler-{rank}.txt" + assert files[rank] == expected - # todo (tchaton) add support for all ranks - if use_output_filename and os.getenv("LOCAL_RANK") == "0": - data = Path(profiler.output_fname).read_text() - assert len(data) > 0 + path = os.path.join(pytorch_profiler.dirpath, expected) + data = Path(path).read_text("utf-8") + assert len(data) > 0 def test_pytorch_profiler_nested(tmpdir): """Ensure that the profiler handles nested context""" pytorch_profiler = PyTorchProfiler( - profiled_functions=["a", "b", "c"], use_cuda=False, output_filename=os.path.join(tmpdir, "profiler.txt") + profiled_functions=["a", "b", "c"], use_cuda=False, dirpath=tmpdir, filename="profiler" ) with pytorch_profiler.profile("a"): @@ -327,13 +395,18 @@ def test_profiler_teardown(tmpdir, cls): class TestCallback(Callback): - def on_fit_end(self, trainer, pl_module) -> None: - assert trainer.profiler.output_file is not None - - profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt")) + def on_fit_end(self, trainer, *args, **kwargs) -> None: + # describe sets it to None + assert trainer.profiler._output_file is None + profiler = cls(dirpath=tmpdir, filename="profiler") model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()]) trainer.fit(model) - assert profiler.output_file is None + assert profiler._output_file is None + + +def test_pytorch_profiler_deepcopy(pytorch_profiler): + pytorch_profiler.describe() + assert deepcopy(pytorch_profiler) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 4dc5b5f34b50c..3eb0596b55fc4 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -13,7 +13,6 @@ # limitations under the License. 
 from pytorch_lightning import Trainer
-from tests.accelerators import DDPLauncher
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -84,8 +83,7 @@ def test_get_model_gpu(tmpdir):
 
 
 @RunIf(min_gpus=1, skip_windows=True)
-@DDPLauncher.run("--accelerator [accelerator]", max_epochs=["1"], accelerator=["ddp", "ddp_spawn"])
-def test_get_model_ddp_gpu(tmpdir, args=None):
+def test_get_model_ddp_gpu(tmpdir):
     """
     Tests that `trainer.lightning_module` extracts the model correctly when using GPU + ddp accelerators
    """
@@ -99,7 +97,6 @@ def test_get_model_ddp_gpu(tmpdir):
         limit_val_batches=2,
         max_epochs=1,
         gpus=1,
-        accelerator=args.accelerator
     )
     trainer.fit(model)
     return 1
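
A minimal usage sketch of the API this patch introduces (the directory and filename values below are hypothetical; the report path follows the `_prepare_filename` logic added to `BaseProfiler`):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import SimpleProfiler

# Reports are now written per stage (and per rank under DDP) as
# "<stage>-<filename>[-<rank>].txt", e.g. "fit-perf_logs.txt".
profiler = SimpleProfiler(dirpath=".", filename="perf_logs")  # hypothetical values
trainer = Trainer(profiler=profiler)

# The deprecated signature keeps working until v1.5 and is remapped internally:
# SimpleProfiler(output_filename="./perf_logs.txt")  ->  dirpath=".", filename="perf_logs"
```

If `dirpath` is omitted while `filename` is given, `BaseProfiler.setup` falls back to `trainer.log_dir`, so the report lands next to the TensorBoard logs.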