diff --git a/CHANGELOG.md b/CHANGELOG.md index e1106189e0c17..1b3359ace54f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) +- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618)) + + - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py new file mode 100644 index 0000000000000..ca640a96f9588 --- /dev/null +++ b/pl_examples/basic_examples/profiler_example.py @@ -0,0 +1,102 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script will generate 2 traces: one for `training_step` and one for `validation_step`. +The traces can be visualized in 2 ways: +* With Chrome: + 1. Open Chrome and copy/paste this url: `chrome://tracing/`. + 2. Once tracing opens, click on `Load` at the top-right and load one of the generated traces. +* With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin) + 1. pip install tensorboard torch-tb-profiler + 2. tensorboard --logdir={FOLDER} +""" + +import sys +from argparse import ArgumentParser + +import torch +import torchvision +import torchvision.models as models +import torchvision.transforms as T + +from pl_examples import cli_lightning_logo +from pytorch_lightning import LightningDataModule, LightningModule, Trainer + +DEFAULT_CMD_LINE = ( + "--max_epochs", + "1", + "--limit_train_batches", + "15", + "--limit_val_batches", + "15", + "--profiler", + "pytorch", + "--gpus", + f"{int(torch.cuda.is_available())}", +) + + +class ModelToProfile(LightningModule): + + def __init__(self, model): + super().__init__() + self.model = model + self.criterion = torch.nn.CrossEntropyLoss() + + def training_step(self, batch, batch_idx): + inputs, labels = batch + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + self.log("train_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + inputs, labels = batch + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + self.log("val_loss", loss) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9) + + +class CIFAR10DataModule(LightningDataModule): + + transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]) + + def train_dataloader(self, *args, **kwargs): + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=self.transform) + return torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0) + + def val_dataloader(self, *args, **kwargs): + valset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=self.transform) + return torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=0) + + +def cli_main(): + + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parser) + cmd_line = None if len(sys.argv) != 1 else DEFAULT_CMD_LINE + args = parser.parse_args(args=cmd_line) + + model = ModelToProfile(models.resnet50(pretrained=True)) + datamodule = CIFAR10DataModule() + trainer = Trainer(**vars(args)) + trainer.fit(model, datamodule=datamodule) + + +if __name__ == '__main__': + cli_lightning_logo() + cli_main() diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 4aa5fedf2b210..1dcd541ca0610 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -448,8 +448,10 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: Lightn .. deprecated::v1.3 Will be removed in v1.5.0. """ - rank_zero_warn('Accelerator method `connect_training_type_plugin` was deprecated in v1.3.' - ' It will be removed in v1.5.') + rank_zero_warn( + 'Accelerator method `connect_training_type_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.' + ) self.setup_training_type_plugin(plugin, model) # todo: remove in v1.5 @@ -459,6 +461,8 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: .. deprecated::v1.3 Will be removed in v1.5.0. """ - rank_zero_warn('Accelerator method `connect_precision_plugin` was deprecated in v1.3.' - ' It will be removed in v1.5.') + rank_zero_warn( + 'Accelerator method `connect_precision_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.' + ) self.setup_precision_plugin(plugin) diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index e09a5ea11a084..6ac6e16c18529 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -121,7 +121,8 @@ def custom_processing_step(self, data): Autograd includes a profiler that lets you inspect the cost of different operators inside your model - both on the CPU and GPU. -Find the Pytorch Profiler doc at [PyTorch Profiler](https://pytorch-lightning.readthedocs.io/en/stable/profiler.html) +To read more about the PyTorch Profiler and all its options, +have a look at its `docs `__ .. code-block:: python @@ -134,16 +135,16 @@ def custom_processing_step(self, data): This profiler works with PyTorch ``DistributedDataParallel``. -If ``output_filename`` is provided, each rank will save their profiled operation to their own file. +If ``filename`` is provided, each rank will save their profiled operation to their own file. The profiler +report can be quite long, so you setting a ``filename`` will save the report instead of logging it to the +output in your terminal. If no filename is given, it will be logged only on rank 0. +The profiler's results will be printed on the completion of ``{fit,validate,test,predict}``. -The profiler's results will be printed on the completion of a training `fit()`. This profiler -report can be quite long, so you can also specify an `output_filename` to save the report instead -of logging it to the output in your terminal. - -This profiler will record only for `training_step_and_backward`, `evaluation_step` and `test_step` functions by default. -The output below shows the profiling for the action `training_step_and_backward`. -The user can provide ``PyTorchProfiler(profiled_functions=[...])`` to extend the scope of profiled functions. +This profiler will record ``training_step_and_backward``, ``training_step``, ``backward``, +``validation_step``, ``test_step``, and ``predict_step`` by default. +The output below shows the profiling for the action ``training_step_and_backward``. +The user can provide ``PyTorchProfiler(record_functions={...})`` to extend the scope of profiled functions. .. note:: When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the `SimpleProfiler`. # noqa E501 @@ -184,13 +185,13 @@ def custom_processing_step(self, data): To visualize the profiled operation, you can either: -* Use:: +Use:: nvvp trace_name.prof -* Use:: +Or:: - python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' + python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' """ diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 46d72583fb466..bc9e3541dbaa8 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -120,14 +120,14 @@ def _rank_zero_info(self, *args, **kwargs) -> None: if self._local_rank in (None, 0): log.info(*args, **kwargs) - def _prepare_filename(self) -> str: + def _prepare_filename(self, extension: str = ".txt") -> str: filename = "" if self._stage is not None: filename += f"{self._stage}-" filename += str(self.filename) if self._local_rank is not None: filename += f"-{self._local_rank}" - filename += ".txt" + filename += extension return filename def _prepare_streams(self) -> None: diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 974883a4724c6..73abc1baf939d 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -17,7 +17,7 @@ import os from functools import partial from pathlib import Path -from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union import torch from torch import nn, Tensor @@ -26,6 +26,7 @@ from pytorch_lightning.profiler.profilers import BaseProfiler from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE if TYPE_CHECKING: from torch.autograd.profiler import EventList @@ -33,6 +34,9 @@ from pytorch_lightning.core.lightning import LightningModule +if _KINETO_AVAILABLE: + from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler + log = logging.getLogger(__name__) _PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx] @@ -93,12 +97,108 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self._handles = {} +class ScheduleWrapper: + """ + This class is used to override the schedule logic from the profiler and perform + recording for both `training_step`, `validation_step`. + """ + + def __init__(self, schedule: Callable) -> None: + if not _KINETO_AVAILABLE: + raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.") + self._schedule = schedule + self._num_training_step_and_backward = 0 + self._num_validation_step = 0 + self._num_test_step = 0 + self._num_predict_step = 0 + self._training_step_and_backward_reached_end = False + self._validation_step_reached_end = False + self._test_step_reached_end = False + self._predict_step_reached_end = False + # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached. + self._current_action: Optional[str] = None + self._start_action_name: Optional[str] = None + + def setup(self, start_action_name: str) -> None: + self._start_action_name = start_action_name + + def pre_step(self, current_action: str) -> None: + self._current_action = current_action + + @property + def num_step(self) -> int: + if self._current_action == "training_step_and_backward": + return self._num_training_step_and_backward + elif self._current_action == "validation_step": + return self._num_validation_step + elif self._current_action == "test_step": + return self._num_test_step + elif self._current_action == "predict_step": + return self._num_predict_step + else: + return 0 + + def _step(self) -> None: + if self._current_action == "training_step_and_backward": + self._num_training_step_and_backward += 1 + elif self._current_action == "validation_step": + if self._start_action_name == "on_train_start" and self._num_training_step_and_backward > 0: + self._num_validation_step += 1 + else: + self._num_validation_step += 1 + elif self._current_action == "test_step": + self._num_test_step += 1 + elif self._current_action == "predict_step": + self._num_predict_step += 1 + + @property + def has_finished(self) -> bool: + if self._current_action == "training_step_and_backward": + return self._training_step_and_backward_reached_end + elif self._current_action == "validation_step": + return self._validation_step_reached_end + elif self._current_action == "test_step": + return self._test_step_reached_end + elif self._current_action == "predict_step": + return self._predict_step_reached_end + return False + + def __call__(self, num_step: int) -> 'ProfilerAction': + # ignore the provided input. Keep internal state instead. + if self.has_finished: + return ProfilerAction.NONE + + self._step() + action = self._schedule(self.num_step) + if action == ProfilerAction.RECORD_AND_SAVE: + if self._current_action == "training_step_and_backward": + self._training_step_and_backward_reached_end = True + elif self._current_action == "validation_step": + self._validation_step_reached_end = True + elif self._current_action == "test_step": + self._test_step_reached_end = True + elif self._current_action == "predict_step": + self._predict_step_reached_end = True + return action + + class PyTorchProfiler(BaseProfiler): - RECORD_FUNCTIONS = ( - "training_step_and_backward", "training_step", "backward", "validation_step", "test_step", "predict_step" - ) - AVAILABLE_SORT_KEYS = ( + RECORD_FUNCTIONS = { + "training_step_and_backward", + "training_step", + "backward", + "validation_step", + "test_step", + "predict_step", + } + STEP_FUNCTIONS = { + "training_step_and_backward", + "validation_step", + "test_step", + "predict_step", + } + AVAILABLE_SORT_KEYS = { "cpu_time", "cuda_time", "cpu_time_total", @@ -108,8 +208,13 @@ class PyTorchProfiler(BaseProfiler): "self_cpu_memory_usage", "self_cuda_memory_usage", "count", - ) - START_RECORD_FUNCTIONS = ('on_train_start', 'on_validation_start', 'on_test_start', 'on_predict_start') + } + START_RECORD_FUNCTIONS = { + 'on_train_start', + 'on_validation_start', + 'on_test_start', + 'on_predict_start', + } def __init__( self, @@ -118,10 +223,9 @@ def __init__( group_by_input_shapes: bool = False, emit_nvtx: bool = False, export_to_chrome: bool = True, - path_to_export_trace: Optional[str] = None, row_limit: int = 20, sort_by_key: Optional[str] = None, - record_functions: List[str] = None, + record_functions: Set[str] = None, record_module_names: bool = True, profiled_functions: Optional[List] = None, output_filename: Optional[str] = None, @@ -154,9 +258,6 @@ def __init__( export_to_chrome: Whether to export the sequence of profiled operators for Chrome. It will generate a ``.json`` file which can be read by Chrome. - path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``. - By default, it will be save where the file being is being run. - row_limit: Limit the number of rows in a table, ``-1`` is a special value that removes the limit completely. @@ -166,7 +267,7 @@ def __init__( ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. - record_functions: list of profiled functions which will create a context manager on. + record_functions: Set of profiled functions which will create a context manager on. Any other will be pass through. record_module_names: Whether to add module names while recording autograd operation. @@ -176,6 +277,8 @@ def __init__( Raises: MisconfigurationException: If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``. + If arg ``schedule`` is not a ``Callable``. + If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``. """ super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) @@ -184,11 +287,10 @@ def __init__( self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False) self._emit_nvtx = emit_nvtx self._export_to_chrome = export_to_chrome - self._path_to_export_trace = path_to_export_trace self._row_limit = row_limit self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total" - self._record_functions_start = set(record_functions + list(self.START_RECORD_FUNCTIONS)) - self._record_functions = set(record_functions + list(self.RECORD_FUNCTIONS)) + self._record_functions_start = record_functions | self.START_RECORD_FUNCTIONS + self._record_functions = record_functions | self.RECORD_FUNCTIONS self._record_module_names = record_module_names self._profiler_kwargs = profiler_kwargs @@ -198,25 +300,48 @@ def __init__( self._register: Optional[RegisterRecordFunction] = None self._parent_profiler: Optional[_PROFILER] = None self._recording_map: Dict[str, record_function] = {} + self._start_action_name: Optional[str] = None + self._schedule: Optional[ScheduleWrapper] = None - if self._export_to_chrome and self._path_to_export_trace is None: - rank_zero_warn( - "The exported trace would be saved locally as `path_to_export_trace` is None." - " Note: Each functions will generate its own traced file." - ) + if _KINETO_AVAILABLE: + self.__init_kineto__(profiler_kwargs) if self._sort_by_key not in self.AVAILABLE_SORT_KEYS: raise MisconfigurationException( f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " ) + def __init_kineto__(self, profiler_kwargs: Any): + has_schedule = "schedule" in profiler_kwargs + self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs + + schedule = profiler_kwargs.get("schedule", None) + if schedule is not None: + if not isinstance(schedule, Callable): + raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}") + action = schedule(0) + if not isinstance(action, ProfilerAction): + raise MisconfigurationException( + f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}" + ) + schedule = schedule if has_schedule else self._default_schedule() + self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule + self._profiler_kwargs["schedule"] = self._schedule + + activities = profiler_kwargs.get("activities", None) + self._profiler_kwargs["activities"] = activities or self._default_activities() + self._export_to_flame_graph = profiler_kwargs.get("export_to_flame_graph", False) + self._metric = profiler_kwargs.get("metric", "self_cpu_time_total") + with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph + self._profiler_kwargs["with_stack"] = with_stack + def __deprecation_check( self, profiled_functions: Optional[List[str]], - record_functions: Optional[List[str]], - ) -> List[str]: + record_functions: Optional[Set[str]], + ) -> Set[str]: if record_functions is None: - record_functions = [] + record_functions = set() if profiled_functions is not None: rank_zero_warn( @@ -224,7 +349,7 @@ def __deprecation_check( " `record_functions` in v1.3 and will be removed in v1.5", DeprecationWarning ) if not record_functions: - record_functions += profiled_functions + record_functions |= set(profiled_functions) else: raise MisconfigurationException( "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`." @@ -233,15 +358,25 @@ def __deprecation_check( return record_functions - def setup( - self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None - ) -> None: - super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir) - - # if the user didn't provide `path_to_export_trace`, - # set it as TensorBoardLogger log_dir if exists - if self._path_to_export_trace is None: - self._path_to_export_trace = log_dir + @staticmethod + def _default_schedule() -> Optional[callable]: + if _KINETO_AVAILABLE: + # Those schedule defaults allow the profiling overhead to be negligible over training time. + return torch.profiler.schedule(wait=1, warmup=1, active=2) + + def _default_activities(self) -> List['ProfilerActivity']: + activities = [] + if not _KINETO_AVAILABLE: + return activities + if self._profiler_kwargs.get("use_cpu", True): + activities.append(ProfilerActivity.CPU) + if self._profiler_kwargs.get("use_cuda", torch.cuda.is_available()): + activities.append(ProfilerActivity.CUDA) + return activities + + @property + def step_action_names(self) -> Set[str]: + return self.STEP_FUNCTIONS | self._record_functions def start(self, action_name: str) -> None: if self.profiler is None and action_name in self._record_functions_start: @@ -253,11 +388,18 @@ def start(self, action_name: str) -> None: except (AttributeError, RuntimeError): pass + if self._schedule is not None: + self._schedule.setup(action_name) + self._create_profilers() - self.profiler.__enter__() + profiler = self.profiler.__enter__() + if profiler is not None: + self.profiler = profiler + if self._parent_profiler is not None: self._parent_profiler.__enter__() + if self._register is not None: self._register.__enter__() @@ -269,11 +411,39 @@ def start(self, action_name: str) -> None: recording.__enter__() self._recording_map[action_name] = recording + if self._schedule is not None: + self._schedule.pre_step(action_name) + def stop(self, action_name: str) -> None: if action_name in self._recording_map: self._recording_map[action_name].__exit__(None, None, None) del self._recording_map[action_name] + if not _KINETO_AVAILABLE or self._emit_nvtx: + return + + if action_name in self.step_action_names: + if self._schedule is not None: + self._schedule._current_action = action_name + + def on_trace_ready(profiler): + filename = f"{action_name}_{self.local_rank}" + + if self.dirpath is not None: + if self._export_to_chrome: + handler = tensorboard_trace_handler(self.dirpath, filename) + handler(profiler) + + if self._export_to_flame_graph: + path = os.path.join(self.dirpath, self._prepare_filename(extension=".stack")) + profiler.export_stacks(path, metric=self._metric) + else: + rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None") + + if not self._has_on_trace_ready: + self.profiler.on_trace_ready = on_trace_ready + self.profiler.step() + def summary(self) -> str: if not self._profiler_kwargs.get("enabled", True) or self._emit_nvtx: return "" @@ -283,11 +453,9 @@ def summary(self) -> str: if not self.function_events: return "" - if self._export_to_chrome: + if self._export_to_chrome and not _KINETO_AVAILABLE: filename = f"{self.local_rank}_trace.json" - path_to_trace = ( - filename if self._path_to_export_trace is None else os.path.join(self._path_to_export_trace, filename) - ) + path_to_trace = (filename if self.dirpath is None else os.path.join(self.dirpath, filename)) self.function_events.export_chrome_trace(path_to_trace) data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes) @@ -302,7 +470,9 @@ def _create_profilers(self) -> None: self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx) else: self._parent_profiler = None - self.profiler = self._create_profiler(torch.autograd.profiler.profile) + self.profiler = self._create_profiler( + torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile + ) if self._record_module_names and self._lightning_module is not None: self._register = RegisterRecordFunction(self._lightning_module) @@ -311,9 +481,10 @@ def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} return profiler(**kwargs) - def _cache_functions_events(self): - if not self._emit_nvtx: - self.function_events = self.profiler.function_events + def _cache_functions_events(self) -> None: + if self._emit_nvtx: + return + self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events def _delete_profilers(self) -> None: if self.profiler is not None: diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 5a780660a0a99..baeac9be57218 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -70,6 +70,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") _TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0") +_KINETO_AVAILABLE = torch.profiler.kineto_available() if _TORCH_GREATER_EQUAL_1_8 else False _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index bcb351984a175..46379a9d10c14 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -30,6 +30,7 @@ def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch): """ class TestModel(BoringModel): + def on_fit_start(self): if delay_dispatch: # Ensure we haven't setup optimizers if we've delayed dispatch @@ -41,14 +42,11 @@ def on_fit_end(self): assert len(self.trainer.optimizers) > 0 class CustomPlugin(SingleDevicePlugin): + @property def setup_optimizers_in_pre_dispatch(self) -> bool: return delay_dispatch model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - plugins=CustomPlugin(device=torch.device("cpu")) - ) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu"))) trainer.fit(model) diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index fe85fbaea9025..5483e33d9cddb 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -56,6 +56,7 @@ def __new__( *args, min_gpus: int = 0, min_torch: Optional[str] = None, + max_torch: Optional[str] = None, min_python: Optional[str] = None, quantization: bool = False, amp_apex: bool = False, @@ -76,6 +77,7 @@ def __new__( args: native pytest.mark.skipif arguments min_gpus: min number of gpus required to run test min_torch: minimum pytorch version to run test + max_torch: maximum pytorch version to run test min_python: minimum python version required to run test quantization: if `torch.quantization` package is required to run test amp_apex: NVIDIA Apex is installed @@ -102,6 +104,11 @@ def __new__( conditions.append(torch_version < LooseVersion(min_torch)) reasons.append(f"torch>={min_torch}") + if max_torch: + torch_version = LooseVersion(get_distribution("torch").version) + conditions.append(torch_version >= LooseVersion(max_torch)) + reasons.append(f"torch<{max_torch}") + if min_python: py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" conditions.append(py_version < LooseVersion(min_python)) diff --git a/tests/helpers/test_datasets.py b/tests/helpers/test_datasets.py index 42b5df0ff91a4..8c866bdbab789 100644 --- a/tests/helpers/test_datasets.py +++ b/tests/helpers/test_datasets.py @@ -20,11 +20,13 @@ from tests.helpers.datasets import AverageDataset, MNIST, TrialMNIST -@pytest.mark.parametrize('dataset_cls,args', [ - (MNIST, dict(root=PATH_DATASETS)), - (TrialMNIST, dict(root=PATH_DATASETS)), - (AverageDataset, dict()), -]) +@pytest.mark.parametrize( + 'dataset_cls,args', [ + (MNIST, dict(root=PATH_DATASETS)), + (TrialMNIST, dict(root=PATH_DATASETS)), + (AverageDataset, dict()), + ] +) def test_pickling_dataset_mnist(tmpdir, dataset_cls, args): mnist = dataset_cls(**args) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 5d144aef36573..a6e33b3366f33 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -17,7 +17,6 @@ import time from copy import deepcopy from distutils.version import LooseVersion -from pathlib import Path import numpy as np import pytest @@ -27,7 +26,7 @@ from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler from pytorch_lightning.profiler.pytorch import RegisterRecordFunction from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -266,6 +265,7 @@ def pytorch_profiler(tmpdir): return PyTorchProfiler(dirpath=tmpdir, filename="profiler") +@RunIf(max_torch="1.8.1") def test_pytorch_profiler_describe(pytorch_profiler): """Ensure the profiler won't fail when reporting the summary.""" with pytorch_profiler.profile("on_test_start"): @@ -302,30 +302,41 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler): """Ensure that the profiler can be given to the training and default step are properly recorded. """ model = BoringModel() trainer = Trainer( - max_epochs=1, default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=5, profiler=pytorch_profiler, accelerator="ddp", gpus=2, ) trainer.fit(model) - expected = ('validation_step', 'training_step_and_backward', 'training_step', 'backward') + expected = {'validation_step'} + if not _KINETO_AVAILABLE: + expected |= {'training_step_and_backward', 'training_step', 'backward'} for name in expected: - assert sum(e.name == name for e in pytorch_profiler.function_events) + assert sum(e.name == name for e in pytorch_profiler.function_events), name files = set(os.listdir(pytorch_profiler.dirpath)) expected = f"fit-profiler-{trainer.local_rank}.txt" assert expected in files - path = os.path.join(pytorch_profiler.dirpath, expected) - assert Path(path).read_text() + path = pytorch_profiler.dirpath / expected + assert path.read_text("utf-8") + + if _KINETO_AVAILABLE: + files = os.listdir(pytorch_profiler.dirpath) + files = [file for file in files if file.endswith('.json')] + assert len(files) == 2, files + local_rank = trainer.local_rank + assert any(f'training_step_{local_rank}' in f for f in files) + assert any(f'validation_step_{local_rank}' in f for f in files) -def test_pytorch_profiler_trainer_test(tmpdir, pytorch_profiler): +def test_pytorch_profiler_trainer_test(tmpdir): """Ensure that the profiler can be given to the trainer and test step are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, @@ -340,27 +351,32 @@ def test_pytorch_profiler_trainer_test(tmpdir, pytorch_profiler): path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt" assert path.read_text("utf-8") + if _KINETO_AVAILABLE: + files = sorted([file for file in os.listdir(tmpdir) if file.endswith('.json')]) + assert any(f'test_step_{trainer.local_rank}' in f for f in files) + -def test_pytorch_profiler_trainer_predict(tmpdir, pytorch_profiler): +def test_pytorch_profiler_trainer_predict(tmpdir): """Ensure that the profiler can be given to the trainer and predict function are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) model = BoringModel() model.predict_dataloader = model.train_dataloader trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - limit_test_batches=2, + limit_predict_batches=2, profiler=pytorch_profiler, ) trainer.predict(model) assert sum(e.name == 'predict_step' for e in pytorch_profiler.function_events) - path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt" assert path.read_text("utf-8") -def test_pytorch_profiler_trainer_validate(tmpdir, pytorch_profiler): +def test_pytorch_profiler_trainer_validate(tmpdir): """Ensure that the profiler can be given to the trainer and validate function are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, @@ -380,7 +396,7 @@ def test_pytorch_profiler_nested(tmpdir): """Ensure that the profiler handles nested context""" pytorch_profiler = PyTorchProfiler( - profiled_functions=["a", "b", "c"], use_cuda=False, dirpath=tmpdir, filename="profiler" + record_functions={"a", "b", "c"}, use_cuda=False, dirpath=tmpdir, filename="profiler", schedule=None ) with pytorch_profiler.profile("a"): @@ -409,11 +425,6 @@ def test_pytorch_profiler_nested(tmpdir): 'aten::zeros', 'aten::add', 'aten::zero_', 'c', 'b', 'a', 'aten::fill_', 'aten::empty', 'aten::ones' } - if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"): - expected = { - 'aten::ones', 'a', 'aten::add', 'aten::empty', 'aten::zero_', 'b', 'c', 'aten::zeros', 'aten::fill_' - } - assert events_name == expected, (events_name, torch.__version__, platform.system()) @@ -439,20 +450,22 @@ def test_register_record_function(tmpdir): use_cuda = torch.cuda.is_available() pytorch_profiler = PyTorchProfiler( export_to_chrome=False, - record_functions=["a"], + record_functions={"a"}, use_cuda=use_cuda, dirpath=tmpdir, filename="profiler", + schedule=None, + on_trace_ready=None, ) class TestModel(BoringModel): def __init__(self): super().__init__() - self.layer = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)) + self.layer = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), torch.nn.Linear(1, 1)) model = TestModel() - input = torch.rand((1, 8)) + input = torch.rand((1, 1)) if use_cuda: model = model.cuda() @@ -490,8 +503,8 @@ def on_fit_end(self, trainer, *args, **kwargs) -> None: assert profiler._output_file is None -@pytest.mark.skipif(_TORCH_GREATER_EQUAL_1_8, reason="currently not supported for PyTorch 1.8") -def test_pytorch_profiler_deepcopy(pytorch_profiler): +def test_pytorch_profiler_deepcopy(tmpdir): + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler", schedule=None) pytorch_profiler.start("on_train_start") torch.tensor(1) pytorch_profiler.describe() diff --git a/tests/utilities/test_argparse.py b/tests/utilities/test_argparse.py index aef266d639b4a..f13af4362364c 100644 --- a/tests/utilities/test_argparse.py +++ b/tests/utilities/test_argparse.py @@ -18,6 +18,7 @@ class ArgparseExample: + def __init__(self, a: int = 0, b: str = '', c: bool = False): self.a = a self.b = b