From 37f22c99ffc16ae4010ba7f2ff42f0b86cd1f0ad Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 22 Mar 2021 02:37:54 +0530 Subject: [PATCH 01/22] Add trainer.predict config validation (#6543) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- CHANGELOG.md | 4 +- .../trainer/configuration_validator.py | 9 +++- tests/trainer/test_config_validator.py | 50 ++++++++++++++++++- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d696535311c9d3..6004a28dd0829b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,8 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) -- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) +- Added `Trainer.predict` config validation ([#6543](https://github.com/PyTorchLightning/pytorch-lightning/pull/6543)) + +- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) ### Changed diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 8c539b5ff478de..a7ba2b1c401232 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -40,7 +40,8 @@ def verify_loop_configurations(self, model: LightningModule) -> None: self.__verify_eval_loop_configuration(model, 'val') elif self.trainer.state == TrainerState.TESTING: self.__verify_eval_loop_configuration(model, 'test') - # TODO: add predict + elif self.trainer.state == TrainerState.PREDICTING: + self.__verify_predict_loop_configuration(model) def __verify_train_loop_configuration(self, model): # ----------------------------------- @@ -99,3 +100,9 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) - rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop') if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop') + + def __verify_predict_loop_configuration(self, model: LightningModule) -> None: + + has_predict_dataloader = is_overridden('predict_dataloader', model) + if not has_predict_dataloader: + raise MisconfigurationException('Dataloader not found for `Trainer.predict`') diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 59e10480a485e1..9fccd9b36440ae 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +import torch -from pytorch_lightning import Trainer +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset def test_wrong_train_setting(tmpdir): @@ -101,3 +102,48 @@ def test_val_loop_config(tmpdir): model = BoringModel() model.validation_step = None trainer.validate(model) + + +@pytest.mark.parametrize("datamodule", [False, True]) +def test_trainer_predict_verify_config(tmpdir, datamodule): + + class TestModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + class TestLightningDataModule(LightningDataModule): + + def __init__(self, dataloaders): + super().__init__() + self._dataloaders = dataloaders + + def test_dataloader(self): + return self._dataloaders + + def predict_dataloader(self): + return self._dataloaders + + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + + model = TestModel() + + trainer = Trainer(default_root_dir=tmpdir) + + if datamodule: + datamodule = TestLightningDataModule(dataloaders) + results = trainer.predict(model, datamodule=datamodule) + else: + results = trainer.predict(model, dataloaders=dataloaders) + + assert len(results) == 2 + assert results[0][0].shape == torch.Size([1, 2]) + + model.predict_dataloader = None + + with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): + trainer.predict(model) From 42a7b7058573bc659eb1fd6a64035ae80e270211 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 22 Mar 2021 02:40:54 +0530 Subject: [PATCH 02/22] Add DDP Spawn being default for Multi GPUs (#6292) --- docs/source/advanced/multi_gpu.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index a2a74c7587ae3b..5cdb0b377f2b75 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -267,7 +267,7 @@ Lightning allows multiple ways of training - TPUs (``tpu_cores=8|x``) (tpu or TPU pod) .. note:: - If you request multiple GPUs or nodes without setting a mode, DDP will be automatically used. + If you request multiple GPUs or nodes without setting a mode, DDP Spawn will be automatically used. For a deeper understanding of what Lightning is doing, feel free to read this `guide `_. From 51c9260fad5b1ed3b4e41a9ebf460bf4c609fe2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 22 Mar 2021 00:39:55 +0100 Subject: [PATCH 03/22] Move profiler tests (#6619) --- tests/special_tests.sh | 4 +- tests/test_profiler.py | 143 ++++++++++++++++++++++++++++++++-- tests/trainer/test_trainer.py | 125 ----------------------------- 3 files changed, 138 insertions(+), 134 deletions(-) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index dd67af470c4ec5..3fe9d6c0e277c6 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -34,9 +34,9 @@ python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp python ${DEFAULTS} tests/callbacks/test_pruning.py::test_pruning_callback_ddp -python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp +python ${DEFAULTS} tests/test_profiler.py::test_pytorch_profiler_trainer_ddp python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model python ${DEFAULTS} tests/checkpointing/test_checkpoint_callback_frequency.py::test_top_k_ddp -nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_nested_emit_nvtx +nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 9b51ca7f7c6d2c..5221c0cbf7bf68 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -14,12 +14,17 @@ import logging import os import time +from distutils.version import LooseVersion from pathlib import Path import numpy as np import pytest +import torch -from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler +from pytorch_lightning import Trainer +from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler, PyTorchProfiler +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005 @@ -44,12 +49,6 @@ def simple_profiler(): return profiler -@pytest.fixture -def advanced_profiler(tmpdir): - profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt")) - return profiler - - @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), @@ -116,6 +115,12 @@ def test_simple_profiler_value_errors(simple_profiler): simple_profiler.stop(action) +@pytest.fixture +def advanced_profiler(tmpdir): + profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt")) + return profiler + + @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), @@ -187,3 +192,127 @@ def test_advanced_profiler_value_errors(advanced_profiler): advanced_profiler.start(action) advanced_profiler.stop(action) + + +@pytest.fixture +def pytorch_profiler(tmpdir): + profiler = PyTorchProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"), local_rank=0) + return profiler + + +def test_pytorch_profiler_describe(pytorch_profiler): + """Ensure the profiler won't fail when reporting the summary.""" + with pytorch_profiler.profile("test_step"): + pass + + # log to stdout and print to file + pytorch_profiler.describe() + data = Path(pytorch_profiler.output_fname).read_text() + assert len(data) > 0 + + +def test_pytorch_profiler_value_errors(pytorch_profiler): + """Ensure errors are raised where expected.""" + + action = "test_step" + with pytest.raises(ValueError): + pytorch_profiler.stop(action) + + pytorch_profiler.start(action) + pytorch_profiler.stop(action) + + +@RunIf(min_gpus=2, special=True) +@pytest.mark.parametrize("use_output_filename", [False, True]) +def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename): + """Ensure that the profiler can be given to the training and default step are properly recorded. """ + + if use_output_filename: + output_filename = os.path.join(tmpdir, "profiler.txt") + else: + output_filename = None + + profiler = PyTorchProfiler(output_filename=output_filename) + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + profiler=profiler, + accelerator="ddp", + gpus=2, + ) + trainer.fit(model) + + enabled = use_output_filename or not use_output_filename and profiler.local_rank == 0 + + if enabled: + assert len(profiler.summary()) > 0 + assert set(profiler.profiled_actions.keys()) == {'training_step_and_backward', 'validation_step'} + else: + assert profiler.summary() is None + assert set(profiler.profiled_actions.keys()) == set() + + if use_output_filename: + profiler.describe() + data = Path(profiler.output_fname).read_text() + assert len(data) > 0 + + +def test_pytorch_profiler_nested(tmpdir): + """Ensure that the profiler handles nested context""" + + pytorch_profiler = PyTorchProfiler( + profiled_functions=["a", "b", "c"], use_cuda=False, output_filename=os.path.join(tmpdir, "profiler.txt") + ) + + with pytorch_profiler.profile("a"): + a = torch.ones(42) + with pytorch_profiler.profile("b"): + b = torch.zeros(42) + with pytorch_profiler.profile("c"): + _ = a + b + + pa = pytorch_profiler.profiled_actions + + # From PyTorch 1.8.0, less operation are being traced. + if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"): + expected_ = { + 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'add'], + 'b': ['zeros', 'empty', 'zero_'], + 'c': ['add'], + } + # From PyTorch 1.6.0, more operation are being traced. + elif LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + expected_ = { + 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'fill_', 'add', 'empty'], + 'b': ['zeros', 'empty', 'zero_', 'fill_'], + 'c': ['add', 'empty'], + } + else: + expected_ = { + 'a': ['add'], + 'b': [], + 'c': ['add'], + } + + for n in ('a', 'b', 'c'): + pa[n] = [e.name for e in pa[n]] + if LooseVersion(torch.__version__) >= LooseVersion("1.7.1"): + pa[n] = [e.replace("aten::", "") for e in pa[n]] + assert pa[n] == expected_[n] + + +@RunIf(min_gpus=1, special=True) +def test_pytorch_profiler_nested_emit_nvtx(tmpdir): + """ + This test check emit_nvtx is correctly supported + """ + profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + profiler=profiler, + gpus=1, + ) + trainer.fit(model) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3375b02c5496b0..66889bb7e11390 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -17,7 +17,6 @@ import sys from argparse import Namespace from copy import deepcopy -from distutils.version import LooseVersion from pathlib import Path from unittest.mock import ANY, call, patch @@ -43,12 +42,6 @@ from tests.helpers.runif import RunIf -@pytest.fixture -def pytorch_profiler(tmpdir): - profiler = PyTorchProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"), local_rank=0) - return profiler - - @pytest.mark.parametrize("url_ckpt", [True, False]) def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently.""" @@ -1488,124 +1481,6 @@ def test_trainer_predict_ddp_cpu(tmpdir): predict(tmpdir, "ddp_cpu", 0, 2) -def test_pytorch_profiler_describe(pytorch_profiler): - """Ensure the profiler won't fail when reporting the summary.""" - with pytorch_profiler.profile("test_step"): - pass - - # log to stdout and print to file - pytorch_profiler.describe() - data = Path(pytorch_profiler.output_fname).read_text() - assert len(data) > 0 - - -def test_pytorch_profiler_value_errors(pytorch_profiler): - """Ensure errors are raised where expected.""" - - action = "test_step" - with pytest.raises(ValueError): - pytorch_profiler.stop(action) - - pytorch_profiler.start(action) - pytorch_profiler.stop(action) - - -@RunIf(min_gpus=2, special=True) -@pytest.mark.parametrize("use_output_filename", [False, True]) -def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename): - """Ensure that the profiler can be given to the training and default step are properly recorded. """ - - if use_output_filename: - output_filename = os.path.join(tmpdir, "profiler.txt") - else: - output_filename = None - - profiler = PyTorchProfiler(output_filename=output_filename) - - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - profiler=profiler, - accelerator="ddp", - gpus=2, - ) - trainer.fit(model) - - enabled = use_output_filename or not use_output_filename and profiler.local_rank == 0 - - if enabled: - assert len(profiler.summary()) > 0 - assert set(profiler.profiled_actions.keys()) == {'training_step_and_backward', 'validation_step'} - else: - assert profiler.summary() is None - assert set(profiler.profiled_actions.keys()) == set() - - if use_output_filename: - profiler.describe() - data = Path(profiler.output_fname).read_text() - assert len(data) > 0 - - -def test_pytorch_profiler_nested(tmpdir): - """Ensure that the profiler handles nested context""" - - pytorch_profiler = PyTorchProfiler( - profiled_functions=["a", "b", "c"], use_cuda=False, output_filename=os.path.join(tmpdir, "profiler.txt") - ) - - with pytorch_profiler.profile("a"): - a = torch.ones(42) - with pytorch_profiler.profile("b"): - b = torch.zeros(42) - with pytorch_profiler.profile("c"): - _ = a + b - - pa = pytorch_profiler.profiled_actions - - # From PyTorch 1.8.0, less operation are being traced. - if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"): - expected_ = { - 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'add'], - 'b': ['zeros', 'empty', 'zero_'], - 'c': ['add'], - } - # From PyTorch 1.6.0, more operation are being traced. - elif LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - expected_ = { - 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'fill_', 'add', 'empty'], - 'b': ['zeros', 'empty', 'zero_', 'fill_'], - 'c': ['add', 'empty'], - } - else: - expected_ = { - 'a': ['add'], - 'b': [], - 'c': ['add'], - } - - for n in ('a', 'b', 'c'): - pa[n] = [e.name for e in pa[n]] - if LooseVersion(torch.__version__) >= LooseVersion("1.7.1"): - pa[n] = [e.replace("aten::", "") for e in pa[n]] - assert pa[n] == expected_[n] - - -@RunIf(min_gpus=1, special=True) -def test_pytorch_profiler_nested_emit_nvtx(tmpdir): - """ - This test check emit_nvtx is correctly supported - """ - profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) - - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - profiler=profiler, - gpus=1, - ) - trainer.fit(model) - - @pytest.mark.parametrize( ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"], [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)], From 870247ffe6a5ce819cf7e0a22b997a765d2f6675 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 22 Mar 2021 01:38:10 +0100 Subject: [PATCH 04/22] drop mypy from .pre-commit-config.yaml (#6542) --- .pre-commit-config.yaml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21c52539a890d3..45eca43de93ac8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,8 +33,3 @@ repos: hooks: - id: yapf args: [--parallel, --in-place] - - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.790 - hooks: - - id: mypy From 853523ee643fe0f0cc30d40d9e85a8869e7edfd8 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 22 Mar 2021 08:53:51 +0000 Subject: [PATCH 05/22] Clean utilities/argparse and add missing tests (#6607) --- pytorch_lightning/utilities/argparse.py | 10 +--- ...est_argparse_utils.py => test_argparse.py} | 46 ++++++++++++++++++- 2 files changed, 47 insertions(+), 9 deletions(-) rename tests/utilities/{test_argparse_utils.py => test_argparse.py} (80%) diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 49cbaf3c6bdcfe..46d88184ee1904 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -67,7 +67,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp # Value has been passed as a flag => It is currently None, so we need to set it to True # We always set to True, regardless of the default value. # Users must pass False directly, but when passing nothing True is assumed. - # i.e. the only way to disable somthing that defaults to True is to use the long form: + # i.e. the only way to disable something that defaults to True is to use the long form: # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, # which then becomes True here. @@ -242,9 +242,6 @@ def add_argparse_args( if arg == 'track_grad_norm': use_type = float - if arg_default is inspect._empty: - arg_default = None - parser.add_argument( f'--{arg}', dest=arg, @@ -291,10 +288,7 @@ def _gpus_allowed_type(x) -> Union[int, str]: def _gpus_arg_default(x) -> Union[int, str]: - if ',' in x: - return str(x) - else: - return int(x) + return _gpus_allowed_type(x) def _int_or_float_type(x) -> Union[int, float]: diff --git a/tests/utilities/test_argparse_utils.py b/tests/utilities/test_argparse.py similarity index 80% rename from tests/utilities/test_argparse_utils.py rename to tests/utilities/test_argparse.py index b2eac514941e6a..fdf5ae0cafe65a 100644 --- a/tests/utilities/test_argparse_utils.py +++ b/tests/utilities/test_argparse.py @@ -1,17 +1,51 @@ import io -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from typing import List +from unittest.mock import MagicMock import pytest from pytorch_lightning import Trainer from pytorch_lightning.utilities.argparse import ( add_argparse_args, + from_argparse_args, get_abbrev_qualified_cls_name, + parse_argparser, parse_args_from_docstring, + _gpus_arg_default, + _int_or_float_type ) +class ArgparseExample: + def __init__(self, a: int = 0, b: str = '', c: bool = False): + self.a = a + self.b = b + self.c = c + + +def test_from_argparse_args(): + args = Namespace(a=1, b='test', c=True, d='not valid') + my_instance = from_argparse_args(ArgparseExample, args) + assert my_instance.a == 1 + assert my_instance.b == 'test' + assert my_instance.c + + parser = ArgumentParser() + mock_trainer = MagicMock() + _ = from_argparse_args(mock_trainer, parser) + mock_trainer.parse_argparser.assert_called_once_with(parser) + + +def test_parse_argparser(): + args = Namespace(a=1, b='test', c=None, d='not valid') + new_args = parse_argparser(ArgparseExample, args) + assert new_args.a == 1 + assert new_args.b == 'test' + assert new_args.c + assert new_args.d == 'not valid' + + def test_parse_args_from_docstring_normal(): args_help = parse_args_from_docstring( """Constrain image dataset @@ -168,3 +202,13 @@ def test_add_argparse_args_no_argument_group(): args = parser.parse_args(fake_argv) assert args.main_arg == "abc" assert args.my_parameter == 2 + + +def test_gpus_arg_default(): + assert _gpus_arg_default('1,2') == '1,2' + assert _gpus_arg_default('1') == 1 + + +def test_int_or_float_type(): + assert isinstance(_int_or_float_type('0.0'), float) + assert isinstance(_int_or_float_type('0'), int) From 58c9fa7edbb40dd3fbfc544ee042e8b23693db08 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Mon, 22 Mar 2021 11:43:53 +0000 Subject: [PATCH 06/22] Allow training type plugin to delay optimizer creation (FSDP 2/n) (#6331) * Allow training_type_plugin to delay optimizer configure * Add missing references to trainer, add a CPU accelerator based test --- pytorch_lightning/accelerators/accelerator.py | 9 +++-- .../training_type/training_type_plugin.py | 10 ++++++ pytorch_lightning/trainer/trainer.py | 4 +-- tests/accelerators/test_cpu.py | 35 ++++++++++++++++++- 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index ceb9d98505acc5..60e6ea88b4250d 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -85,7 +85,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None: model: the LightningModule """ self.setup_training_type_plugin(self.training_type_plugin, model) - self.setup_optimizers(trainer) + if not self.training_type_plugin.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) self.setup_precision_plugin(self.precision_plugin) def start_training(self, trainer: 'Trainer') -> None: @@ -97,12 +98,14 @@ def start_evaluating(self, trainer: 'Trainer') -> None: def start_predicting(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_predicting(trainer) - def pre_dispatch(self) -> None: + def pre_dispatch(self, trainer: 'Trainer') -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.pre_dispatch() + if self.training_type_plugin.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) self.precision_plugin.pre_dispatch() - def post_dispatch(self) -> None: + def post_dispatch(self, trainer: 'Trainer') -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.post_dispatch() self.precision_plugin.post_dispatch() diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 6a87792c7bd03f..b6f1be359bbf2d 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -182,3 +182,13 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule): def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): optimizer.step(closure=lambda_closure, **kwargs) + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + """ + Override to delay setting optimizers and schedulers till after dispatch. + This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. + However this may break certain precision plugins such as APEX which require optimizers to be set. + Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. + """ + return False diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 53b4920bd85ef6..0e9e28c9996f23 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -495,7 +495,7 @@ def fit( return self.accelerator.results or 1 def pre_dispatch(self): - self.accelerator.pre_dispatch() + self.accelerator.pre_dispatch(self) # log hyper-parameters if self.logger is not None: @@ -505,7 +505,7 @@ def pre_dispatch(self): self.logger.save() def post_dispatch(self): - self.accelerator.post_dispatch() + self.accelerator.post_dispatch(self) self.accelerator.teardown() def dispatch(self): diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 81a5132e473569..349e4175a74446 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -2,11 +2,12 @@ import pytest import torch - +from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel def test_unsupported_precision_plugins(): @@ -18,3 +19,35 @@ def test_unsupported_precision_plugins(): ) with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."): accelerator.setup(trainer=trainer, model=model) + + +@pytest.mark.parametrize("delay_dispatch", [True, False]) +def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch): + """ + Test when using a custom training type plugin that delays setup optimizers, + we do not call setup optimizers till ``pre_dispatch``. + """ + + class TestModel(BoringModel): + def on_fit_start(self): + if delay_dispatch: + # Ensure we haven't setup optimizers if we've delayed dispatch + assert len(self.trainer.optimizers) == 0 + else: + assert len(self.trainer.optimizers) > 0 + + def on_fit_end(self): + assert len(self.trainer.optimizers) > 0 + + class CustomPlugin(SingleDevicePlugin): + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + return delay_dispatch + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins=CustomPlugin(device=torch.device("cpu")) + ) + trainer.fit(model) From e2e1de0fb73e6ba69fb26b7ade4371c5ee6a1845 Mon Sep 17 00:00:00 2001 From: camruta <79558951+camruta@users.noreply.github.com> Date: Mon, 22 Mar 2021 04:49:06 -0700 Subject: [PATCH 07/22] Add teardown method to BaseProfiler. (#6370) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ Co-authored-by: ananthsub --- .gitignore | 1 + CHANGELOG.md | 18 ++++++++++++------ pytorch_lightning/profiler/profilers.py | 20 ++++++++++++++------ pytorch_lightning/profiler/pytorch.py | 8 +++++--- pytorch_lightning/trainer/trainer.py | 1 + pytorch_lightning/trainer/training_loop.py | 1 + tests/test_profiler.py | 22 ++++++++++++++++++++-- 7 files changed, 54 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index cd0ba22453512d..c0071402571884 100644 --- a/.gitignore +++ b/.gitignore @@ -157,3 +157,4 @@ tags data MNIST runs +*traces* diff --git a/CHANGELOG.md b/CHANGELOG.md index 6004a28dd0829b..5f005f583c5ed3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,8 +14,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) + - Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) @@ -37,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) +- Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370)) + + - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) @@ -120,6 +125,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565)) + - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) @@ -147,6 +153,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) +- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587)) + + +- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576)) + + - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506)) @@ -170,12 +182,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) -- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587)) - - -- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576)) - - ## [1.2.3] - 2021-03-09 ### Fixed diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index d704ba83236c16..55898dc2ee4e16 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -55,6 +55,10 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: """Defines how to record the duration once an action is complete.""" + def teardown(self) -> None: + """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" + pass + @contextmanager def profile(self, action_name: str) -> None: """ @@ -211,14 +215,16 @@ def log_row(action, mean, total): def describe(self): """Logs a profile report after the conclusion of the training run.""" super().describe() - if self.output_file: - self.output_file.flush() + self.teardown() - def __del__(self): + def teardown(self) -> None: """Close profiler's stream.""" if self.output_file: self.output_file.close() + def __del__(self): + self.teardown() + class AdvancedProfiler(BaseProfiler): """ @@ -283,10 +289,12 @@ def summary(self) -> str: def describe(self): """Logs a profile report after the conclusion of the training run.""" super().describe() - if self.output_file: - self.output_file.flush() + self.teardown() - def __del__(self): + def teardown(self) -> None: """Close profiler's stream.""" if self.output_file: self.output_file.close() + + def __del__(self): + self.teardown() diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 88a33a3d367f8b..fdde80589acf3d 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -294,10 +294,12 @@ def summary(self) -> str: def describe(self): """Logs a profile report after the conclusion of the training run.""" super().describe() - if self.output_file: - self.output_file.flush() + self.teardown() - def __del__(self): + def teardown(self) -> None: """Close profiler's stream.""" if self.output_file: self.output_file.close() + + def __del__(self): + self.teardown() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0e9e28c9996f23..a5b99871d55f91 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1077,6 +1077,7 @@ def call_teardown_hook(self, model: LightningModule) -> None: else: state = None + self.profiler.teardown() self.teardown(stage=state) model.teardown(stage=state) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7e737c424ff261..a77d91a7402b4a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -140,6 +140,7 @@ def on_train_end(self): self.trainer.logger.finalize("success") # summarize profile results + # todo (tchaton) All ranks should call describe. if self.trainer.global_rank == 0: self.trainer.profiler.describe() diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 5221c0cbf7bf68..ccdd8a569c9a8b 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -252,8 +252,8 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename): assert profiler.summary() is None assert set(profiler.profiled_actions.keys()) == set() - if use_output_filename: - profiler.describe() + # todo (tchaton) add support for all ranks + if use_output_filename and os.getenv("LOCAL_RANK") == "0": data = Path(profiler.output_fname).read_text() assert len(data) > 0 @@ -316,3 +316,21 @@ def test_pytorch_profiler_nested_emit_nvtx(tmpdir): gpus=1, ) trainer.fit(model) + + +@pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) +def test_profiler_teardown(tmpdir, cls): + """ + This test checks if profiler teardown method is called when trainer is exiting. + """ + profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt")) + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + profiler=profiler, + ) + trainer.fit(model) + + assert profiler.output_file.closed From 1fae10a2dc8224379eac84d6242e0847c2685565 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 13:39:19 +0100 Subject: [PATCH 08/22] refactoring setup (#6590) * refactoring setup * . * docs * flake8 --- docs/source/conf.py | 23 ++++--- pytorch_lightning/__init__.py | 81 +++++++------------------ pytorch_lightning/callbacks/progress.py | 3 +- pytorch_lightning/info.py | 35 +++++++++++ pytorch_lightning/setup_tools.py | 6 +- setup.py | 46 ++++++++------ 6 files changed, 101 insertions(+), 93 deletions(-) create mode 100644 pytorch_lightning/info.py diff --git a/docs/source/conf.py b/docs/source/conf.py index ccf824bb37d9b7..11a0d2a0538bb8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,7 +13,6 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. # import m2r -import builtins import glob import os import shutil @@ -27,10 +26,13 @@ FOLDER_GENERATED = 'generated' SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True)) -if SPHINX_MOCK_REQUIREMENTS: - builtins.__LIGHTNING_SETUP__ = True -import pytorch_lightning # noqa: E402 +try: + from pytorch_lightning import info +except ImportError: + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append(os.path.join(PATH_ROOT, "pytorch_lightning")) + import info # -- Project documents ------------------------------------------------------- @@ -79,13 +81,13 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # -- Project information ----------------------------------------------------- project = 'PyTorch Lightning' -copyright = pytorch_lightning.__copyright__ -author = pytorch_lightning.__author__ +copyright = info.__copyright__ +author = info.__author__ # The short X.Y version -version = pytorch_lightning.__version__ +version = info.__version__ # The full version, including alpha/beta/rc tags -release = pytorch_lightning.__version__ +release = info.__version__ # -- General configuration --------------------------------------------------- @@ -176,8 +178,8 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # documentation. html_theme_options = { - 'pytorch_project': pytorch_lightning.__homepage__, - 'canonical_url': pytorch_lightning.__homepage__, + 'pytorch_project': info.__homepage__, + 'canonical_url': info.__homepage__, 'collapse_navigation': False, 'display_version': True, 'logo_only': False, @@ -279,6 +281,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: 'torch': ('https://pytorch.org/docs/stable/', None), 'numpy': ('https://numpy.org/doc/stable/', None), 'PIL': ('https://pillow.readthedocs.io/en/stable/', None), + 'torchmetrics': ('https://torchmetrics.readthedocs.io/en/stable/', None), } # -- Options for todo extension ---------------------------------------------- diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 569078c994ba4e..b9660475bf2f7d 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -2,42 +2,17 @@ import logging import os -import sys -import time -_this_year = time.strftime("%Y") -__version__ = '1.3.0dev' -__author__ = 'William Falcon et al.' -__author_email__ = 'waf2107@columbia.edu' -__license__ = 'Apache-2.0' -__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' -__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' -# this has to be simple string, see: https://github.com/pypa/twine/issues/522 -__docs__ = ( - "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." - " Scale your models. Write less boilerplate." +from pytorch_lightning.info import ( # noqa: F401 + __author__, + __author_email__, + __copyright__, + __docs__, + __homepage__, + __license__, + __version__, ) -__long_docs__ = """ -Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. - It's more of a style-guide than a framework. -In Lightning, you organize your code into 3 distinct categories: - -1. Research code (goes in the LightningModule). -2. Engineering code (you delete, and is handled by the Trainer). -3. Non-essential research code (logging, etc. this goes in Callbacks). - -Although your research/production project might start simple, once you add things like GPU AND TPU training, - 16-bit precision, etc, you end up spending more time engineering than researching. - Lightning automates AND rigorously tests those parts for you. - -Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. - -Documentation -------------- -- https://pytorch-lightning.readthedocs.io/en/latest -- https://pytorch-lightning.readthedocs.io/en/stable -""" _root_logger = logging.getLogger() _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -50,32 +25,20 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -try: - # This variable is injected in the __builtins__ by the build - # process. It used to enable importing subpackages of skimage when - # the binaries are not built - _ = None if __LIGHTNING_SETUP__ else None -except NameError: - __LIGHTNING_SETUP__: bool = False - -if __LIGHTNING_SETUP__: # pragma: no-cover - sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover - # We are not importing the rest of the lightning during the build process, as it may not be compiled yet -else: - from pytorch_lightning import metrics - from pytorch_lightning.callbacks import Callback - from pytorch_lightning.core import LightningDataModule, LightningModule - from pytorch_lightning.trainer import Trainer - from pytorch_lightning.utilities.seed import seed_everything - - __all__ = [ - 'Trainer', - 'LightningDataModule', - 'LightningModule', - 'Callback', - 'seed_everything', - 'metrics', - ] +from pytorch_lightning import metrics # noqa: E402 +from pytorch_lightning.callbacks import Callback # noqa: E402 +from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402 +from pytorch_lightning.trainer import Trainer # noqa: E402 +from pytorch_lightning.utilities.seed import seed_everything # noqa: E402 + +__all__ = [ + 'Trainer', + 'LightningDataModule', + 'LightningModule', + 'Callback', + 'seed_everything', + 'metrics', +] # for compatibility with namespace packages __import__('pkg_resources').declare_namespace(__name__) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 74e57e2b5642e0..78db9a7dba12eb 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -39,8 +39,7 @@ class tqdm(_tqdm): """ - Custom tqdm progressbar where we append 0 to floating points/strings to - prevent the progress bar from flickering + Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering """ @staticmethod diff --git a/pytorch_lightning/info.py b/pytorch_lightning/info.py new file mode 100644 index 00000000000000..0e7a1c25a74f15 --- /dev/null +++ b/pytorch_lightning/info.py @@ -0,0 +1,35 @@ +import time + +_this_year = time.strftime("%Y") +__version__ = '1.3.0dev' +__author__ = 'William Falcon et al.' +__author_email__ = 'waf2107@columbia.edu' +__license__ = 'Apache-2.0' +__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' +__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' +# this has to be simple string, see: https://github.com/pypa/twine/issues/522 +__docs__ = ( + "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." + " Scale your models. Write less boilerplate." +) +__long_docs__ = """ +Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. + It's more of a style-guide than a framework. + +In Lightning, you organize your code into 3 distinct categories: + +1. Research code (goes in the LightningModule). +2. Engineering code (you delete, and is handled by the Trainer). +3. Non-essential research code (logging, etc. this goes in Callbacks). + +Although your research/production project might start simple, once you add things like GPU AND TPU training, + 16-bit precision, etc, you end up spending more time engineering than researching. + Lightning automates AND rigorously tests those parts for you. + +Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. + +Documentation +------------- +- https://pytorch-lightning.readthedocs.io/en/latest +- https://pytorch-lightning.readthedocs.io/en/stable +""" diff --git a/pytorch_lightning/setup_tools.py b/pytorch_lightning/setup_tools.py index f5aed2608635e4..3362ccb479895e 100644 --- a/pytorch_lightning/setup_tools.py +++ b/pytorch_lightning/setup_tools.py @@ -16,7 +16,7 @@ import re from typing import List -from pytorch_lightning import __homepage__, __version__, _PROJECT_ROOT +_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]: @@ -40,10 +40,10 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme return reqs -def _load_readme_description(path_dir: str, homepage: str = __homepage__, version: str = __version__) -> str: +def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: """Load readme as decribtion - >>> _load_readme_description(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE '
...' """ path_readme = os.path.join(path_dir, "README.md") diff --git a/setup.py b/setup.py index 5d619d51977b29..e53e24ebf07023 100755 --- a/setup.py +++ b/setup.py @@ -16,20 +16,22 @@ import os # Always prefer setuptools over distutils +import sys + from setuptools import find_packages, setup try: - import builtins + from pytorch_lightning import info, setup_tools except ImportError: - import __builtin__ as builtins + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append("pytorch_lightning") + import info + import setup_tools # https://packaging.python.org/guides/single-sourcing-package-version/ # http://blog.ionelmc.ro/2014/05/25/python-packaging/ -PATH_ROOT = os.path.dirname(__file__) -builtins.__LIGHTNING_SETUP__ = True - -import pytorch_lightning # noqa: E402 -from pytorch_lightning.setup_tools import _load_readme_description, _load_requirements # noqa: E402 +_PATH_ROOT = os.path.dirname(__file__) +_PATH_REQUIRE = os.path.join(_PATH_ROOT, 'requirements') # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. @@ -37,10 +39,10 @@ # From local copy of repo, use like `pip install ".[dev, docs]"` extras = { # 'docs': load_requirements(file_name='docs.txt'), - 'examples': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='examples.txt'), - 'loggers': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='loggers.txt'), - 'extra': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='extra.txt'), - 'test': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='test.txt') + 'examples': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='examples.txt'), + 'loggers': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='loggers.txt'), + 'extra': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='extra.txt'), + 'test': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='test.txt') } extras['dev'] = extras['extra'] + extras['loggers'] + extras['test'] extras['all'] = extras['dev'] + extras['examples'] # + extras['docs'] @@ -53,6 +55,12 @@ # filter cpu only packages extras[ex] = [pkg for pkg in extras[kw] if not any(pgpu.lower() in pkg.lower() for pgpu in PACKAGES_GPU_ONLY)] +long_description = setup_tools._load_readme_description( + _PATH_ROOT, + homepage=info.__homepage__, + version=info.__version__, +) + # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious # what happens and to non-engineers they won't know to look in init ... @@ -60,22 +68,22 @@ # engineer specific practices setup( name="pytorch-lightning", - version=pytorch_lightning.__version__, - description=pytorch_lightning.__docs__, - author=pytorch_lightning.__author__, - author_email=pytorch_lightning.__author_email__, - url=pytorch_lightning.__homepage__, + version=info.__version__, + description=info.__docs__, + author=info.__author__, + author_email=info.__author_email__, + url=info.__homepage__, download_url='https://github.com/PyTorchLightning/pytorch-lightning', - license=pytorch_lightning.__license__, + license=info.__license__, packages=find_packages(exclude=['tests', 'tests/*', 'benchmarks', 'legacy', 'legacy/*']), - long_description=_load_readme_description(PATH_ROOT), + long_description=long_description, long_description_content_type='text/markdown', include_package_data=True, zip_safe=False, keywords=['deep learning', 'pytorch', 'AI'], python_requires='>=3.6', setup_requires=[], - install_requires=_load_requirements(PATH_ROOT), + install_requires=setup_tools._load_requirements(_PATH_ROOT), extras_require=extras, project_urls={ "Bug Tracker": "https://github.com/PyTorchLightning/pytorch-lightning/issues", From e62c7c7839beea9be336fe9f30873d005f9cdc5e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Mar 2021 17:49:01 +0100 Subject: [PATCH 09/22] hotfix: mock examples (#6632) * mock examples * drop from GA --- azure-pipelines.yml | 2 ++ pl_examples/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 9e2ff77563fa01..b7a2d851052edb 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -121,4 +121,6 @@ jobs: # cd pl_examples/basic_examples # bash submit_ddp_job.sh # bash submit_ddp2_job.sh + env: + PL_USE_MOCKED_MNIST: "1" displayName: 'Examples' diff --git a/pl_examples/__init__.py b/pl_examples/__init__.py index ffd60f9ed71af4..150ac309ddcebd 100644 --- a/pl_examples/__init__.py +++ b/pl_examples/__init__.py @@ -15,10 +15,10 @@ _DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets') _TORCHVISION_AVAILABLE = _module_available("torchvision") -_TORCHVISION_MNIST_AVAILABLE = True +_TORCHVISION_MNIST_AVAILABLE = not bool(os.environ.get("PL_USE_MOCKED_MNIST", False)) _DALI_AVAILABLE = _module_available("nvidia.dali") -if _TORCHVISION_AVAILABLE: +if _TORCHVISION_MNIST_AVAILABLE: try: from torchvision.datasets.mnist import MNIST MNIST(_DATASETS_PATH, download=True) From 2064ece5825dfa07c339ed8c6e8ea59183e5938e Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 22 Mar 2021 18:32:31 +0000 Subject: [PATCH 10/22] [refactor] Add setup to profilers + _run_stage_setup to trainer 2/5 (#6633) * add setup * update * updates on comment * Minor changes * Extra import * Docs Co-authored-by: Carlos Mocholi --- CHANGELOG.md | 3 ++ .../plugins/training_type/horovod.py | 6 +-- .../training_type/training_type_plugin.py | 6 +-- pytorch_lightning/profiler/profilers.py | 53 ++++++++----------- pytorch_lightning/profiler/pytorch.py | 21 ++------ .../trainer/connectors/profiler_connector.py | 5 +- pytorch_lightning/trainer/properties.py | 10 ++++ pytorch_lightning/trainer/trainer.py | 28 +++++----- pytorch_lightning/trainer/training_loop.py | 3 -- tests/test_profiler.py | 17 +++--- 10 files changed, 72 insertions(+), 80 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f005f583c5ed3..51ad97decd867f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370)) +- Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](https://github.com/PyTorchLightning/pytorch-lightning/pull/6633)) + + - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 9f1bafe309f89e..8d0add27cbb29c 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -96,14 +96,14 @@ def start_training(self, trainer): stack.enter_context(optimizer.skip_synchronize()) # set up training routine - self._results = trainer.run_train() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() def start_evaluating(self, trainer): with ExitStack(): - self._results = trainer.run_evaluate() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() @@ -111,7 +111,7 @@ def start_evaluating(self, trainer): def start_predicting(self, trainer): with ExitStack(): # set up training routine - self._results = trainer.run_predict() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index b6f1be359bbf2d..89f27963caadfd 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -132,15 +132,15 @@ def rpc_enabled(self) -> bool: def start_training(self, trainer: 'Trainer') -> None: # double dispatch to initiate the training loop - self._results = trainer.run_train() + self._results = trainer.run_stage() def start_evaluating(self, trainer: 'Trainer') -> None: # double dispatch to initiate the test loop - self._results = trainer.run_evaluate() + self._results = trainer.run_stage() def start_predicting(self, trainer: 'Trainer') -> None: # double dispatch to initiate the predicting loop - self._results = trainer.run_predict() + self._results = trainer.run_stage() def training_step(self, *args, **kwargs): return self.lightning_module.training_step(*args, **kwargs) diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 55898dc2ee4e16..5668fd6654b2f2 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -55,9 +55,23 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: """Defines how to record the duration once an action is complete.""" - def teardown(self) -> None: - """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" - pass + def setup( + self, + stage: Optional[str] = None, + local_rank: Optional[int] = None, + log_dir: Optional[str] = None + ) -> None: + """Execute arbitrary pre-profiling set-up steps.""" + self.stage = stage + self.local_rank = local_rank + self.log_dir = log_dir + + def teardown(self, stage: Optional[str] = None) -> None: + """Execute arbitrary post-profiling tear-down steps.""" + self.stage = stage + if self.output_file: + self.output_file.close() + self.output_file = None @contextmanager def profile(self, action_name: str) -> None: @@ -94,13 +108,15 @@ def describe(self) -> None: """Logs a profile report after the conclusion of the training run.""" for write in self.write_streams: write(self.summary()) + if self.output_file is not None: + self.output_file.flush() @abstractmethod def summary(self) -> str: """Create profiler summary in text format.""" - def on_train_start(self, local_rank: Optional[int] = None): - self.local_rank = local_rank + def __del__(self): + self.teardown(None) class PassThroughProfiler(BaseProfiler): @@ -110,6 +126,7 @@ class PassThroughProfiler(BaseProfiler): """ def __init__(self): + self.output_file = None super().__init__(output_streams=None) def start(self, action_name: str) -> None: @@ -212,19 +229,6 @@ def log_row(action, mean, total): output_string += os.linesep return output_string - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - self.teardown() - - def teardown(self) -> None: - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() - - def __del__(self): - self.teardown() - class AdvancedProfiler(BaseProfiler): """ @@ -285,16 +289,3 @@ def summary(self) -> str: output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}" return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - self.teardown() - - def teardown(self) -> None: - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() - - def __del__(self): - self.teardown() diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index fdde80589acf3d..c35979fa918af0 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -162,11 +162,11 @@ def __init__( self.output_fname = output_filename self.output_file = None if local_rank is not None: - self.on_train_start(local_rank=local_rank) - self.on_train_start = super().on_train_start + self.setup(local_rank=local_rank) + self.setup = super().setup - def on_train_start(self, local_rank: Optional[str] = None): - self.local_rank = local_rank + def setup(self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None): + super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir) # when logging to `log.info`, only perform profiling on rank 0 if local_rank != 0 and self.output_fname is None: @@ -290,16 +290,3 @@ def summary(self) -> str: output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}") return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - self.teardown() - - def teardown(self) -> None: - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() - - def __del__(self): - self.teardown() diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 98d65c1285ff79..e628d6d96bd199 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -54,6 +54,7 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]): ) self.trainer.profiler = profiler or PassThroughProfiler() - def on_train_start(self, trainer): + def setup(self) -> None: + trainer = self.trainer local_rank = trainer.local_rank if trainer.world_size > 1 else None - self.trainer.profiler.on_train_start(local_rank) + trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index b5654b148afc6b..315e3c60c05579 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -491,6 +491,16 @@ def sanity_checking(self, val: bool) -> None: elif self.sanity_checking: self._running_stage = None + @property + def _setup_state(self) -> TrainerState: + # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" + return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + + @property + def _teardown_state(self) -> Optional[TrainerState]: + if self.state.running: + return self._setup_state + # Used to represent the concrete type TrainerProperties class methods are called on. _T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a5b99871d55f91..f7bd1757b9bc21 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -445,13 +445,15 @@ def fit( | || {self.dispatch} || | || LIGHTNING - {self.accelerator.start_training} or || - {self.accelerator.start_evaluating} or || FLOW - {self.accelerator.start_predicting} || + {self.accelerator.start_training} || + or {self.accelerator.start_evaluating} || + or {self.accelerator.start_predicting} || FLOW + | || + {self.run_stage} || | || DIRECTION - {self.run_train} or || - {self.run_evaluation} or || - {self.run_predict} || + {self.run_train} || + or {self.run_evaluation} || + or {self.run_predict} || | || results \/ This is used to guide readers to the core loops: train, test, predict. @@ -518,6 +520,9 @@ def dispatch(self): def run_stage(self): results = None + + self.profile_connector.setup() + if self.evaluating: results = self.run_evaluate() elif self.predicting: @@ -1060,8 +1065,7 @@ def tune( def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" - # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = self._setup_state if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') @@ -1072,12 +1076,8 @@ def call_setup_hook(self, model: LightningModule) -> None: model.setup(stage=state) def call_teardown_hook(self, model: LightningModule) -> None: - if self.state.running: - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - else: - state = None - - self.profiler.teardown() + state = self._teardown_state + self.profiler.teardown(stage=state) self.teardown(stage=state) model.teardown(stage=state) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a77d91a7402b4a..384a1b67a64f84 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -102,9 +102,6 @@ def on_train_start(self): # hook self.trainer.call_hook("on_train_start") - # provide rank to profiler - self.trainer.profile_connector.on_train_start(self.trainer) - def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): # clean hparams if hasattr(model, "hparams"): diff --git a/tests/test_profiler.py b/tests/test_profiler.py index ccdd8a569c9a8b..cc4fff3b7ede49 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -22,7 +22,8 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler, PyTorchProfiler +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -323,14 +324,16 @@ def test_profiler_teardown(tmpdir, cls): """ This test checks if profiler teardown method is called when trainer is exiting. """ + + class TestCallback(Callback): + + def on_fit_end(self, trainer, pl_module) -> None: + assert trainer.profiler.output_file is not None + profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt")) model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - profiler=profiler, - ) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()]) trainer.fit(model) - assert profiler.output_file.closed + assert profiler.output_file is None From 8cd75a4dd51939881da265752c2d81307cbe4d9e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Mar 2021 08:51:45 +0100 Subject: [PATCH 11/22] fix comparing versions (#6434) * fix comparing versions * chlog * . * ... * datasets --- .github/workflows/docs-checks.yml | 2 +- CHANGELOG.md | 3 +++ Makefile | 2 +- docs/source/conf.py | 1 + pytorch_lightning/utilities/imports.py | 22 ++++++++++++++++++---- requirements/extra.txt | 1 + 6 files changed, 25 insertions(+), 6 deletions(-) diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml index 5ee4f23b4b3ccd..4488c598c8ac75 100644 --- a/.github/workflows/docs-checks.yml +++ b/.github/workflows/docs-checks.yml @@ -98,7 +98,7 @@ jobs: # First run the same pipeline as Read-The-Docs cd docs make clean - make html --debug --jobs $(nproc) SPHINXOPTS="-W" + make html --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" - name: Upload built docs uses: actions/upload-artifact@v2 diff --git a/CHANGELOG.md b/CHANGELOG.md index 51ad97decd867f..c542b854af104a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -165,6 +165,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506)) +- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) + + ## [1.2.4] - 2021-03-16 ### Changed diff --git a/Makefile b/Makefile index d35e0b77f84296..04b08fa2d27d1a 100644 --- a/Makefile +++ b/Makefile @@ -29,4 +29,4 @@ test: clean docs: clean pip install --quiet -r requirements/docs.txt - python -m sphinx -b html -W docs/source docs/build + python -m sphinx -b html -W --keep-going docs/source docs/build diff --git a/docs/source/conf.py b/docs/source/conf.py index 11a0d2a0538bb8..6163de976da405 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -334,6 +334,7 @@ def package_list_from_file(file): } MOCK_PACKAGES = [] if SPHINX_MOCK_REQUIREMENTS: + MOCK_PACKAGES += ['fairscale'] # mock also base packages when we are on RTD since we don't install them there MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements.txt')) MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'extra.txt')) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 41a13d6c678a0d..8090c4ed6590f8 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities""" +import importlib import operator import platform import sys @@ -19,7 +20,7 @@ from importlib.util import find_spec import torch -from pkg_resources import DistributionNotFound, get_distribution +from pkg_resources import DistributionNotFound def _module_available(module_path: str) -> bool: @@ -42,11 +43,24 @@ def _module_available(module_path: str) -> bool: def _compare_version(package: str, op, version) -> bool: + """ + Compare package version with some requirements + + >>> _compare_version("torch", operator.ge, "0.1") + True + """ try: - pkg_version = LooseVersion(get_distribution(package).version) - return op(pkg_version, LooseVersion(version)) - except DistributionNotFound: + pkg = importlib.import_module(package) + except (ModuleNotFoundError, DistributionNotFound): + return False + try: + pkg_version = LooseVersion(pkg.__version__) + except AttributeError: return False + if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")): + # this is mock by sphinx, so it shall return True ro generate all summaries + return True + return op(pkg_version, LooseVersion(version)) _IS_WINDOWS = platform.system() == "Windows" diff --git a/requirements/extra.txt b/requirements/extra.txt index a05c4971ac450a..715916c4e36acb 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,5 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 +# todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip From efce2b77779467884df9a3d9c16c3176ea81a650 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Mar 2021 09:35:51 +0100 Subject: [PATCH 12/22] Prune metrics: regression 8/n (#6636) * explained_variance * tests * mean_absolute_error * mean_squared_error * mean_relative_error * mean_squared_log_error * chlog --- CHANGELOG.md | 2 + .../metrics/functional/explained_variance.py | 68 +---------- .../metrics/functional/mean_absolute_error.py | 34 +----- .../metrics/functional/mean_relative_error.py | 37 +----- .../metrics/functional/mean_squared_error.py | 34 +----- .../functional/mean_squared_log_error.py | 34 +----- .../metrics/regression/explained_variance.py | 106 ++---------------- .../metrics/regression/mean_absolute_error.py | 64 ++--------- .../metrics/regression/mean_squared_error.py | 65 ++--------- .../regression/mean_squared_log_error.py | 67 ++--------- tests/accelerators/test_cpu.py | 1 + .../regression/test_explained_variance.py | 77 ------------- tests/metrics/regression/test_mean_error.py | 87 -------------- tests/metrics/test_remove_1-5_metrics.py | 59 +++++++++- tests/utilities/test_argparse.py | 4 +- 15 files changed, 115 insertions(+), 624 deletions(-) delete mode 100644 tests/metrics/regression/test_explained_variance.py delete mode 100644 tests/metrics/regression/test_mean_error.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c542b854af104a..57a071bff297af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -90,6 +90,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). [#6584](https://github.com/PyTorchLightning/pytorch-lightning/pull/6584), + [#6636](https://github.com/PyTorchLightning/pytorch-lightning/pull/6636), + ) diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index fa8d43c06c7efd..bcfe698bf4c5ea 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -11,77 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple, Union +from typing import Sequence, Union import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import explained_variance as _explained_variance - -def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - _check_same_shape(preds, target) - return preds, target - - -def _explained_variance_compute( - preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - diff_avg = torch.mean(target - preds, dim=0) - numerator = torch.mean((target - preds - diff_avg)**2, dim=0) - - target_avg = torch.mean(target, dim=0) - denominator = torch.mean((target - target_avg)**2, dim=0) - - # Take care of division by zero - nonzero_numerator = numerator != 0 - nonzero_denominator = denominator != 0 - valid_score = nonzero_numerator & nonzero_denominator - output_scores = torch.ones_like(diff_avg) - output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score]) - output_scores[nonzero_numerator & ~nonzero_denominator] = 0. - - # Decide what to do in multioutput case - # Todo: allow user to pass in tensor with weights - if multioutput == 'raw_values': - return output_scores - if multioutput == 'uniform_average': - return torch.mean(output_scores) - if multioutput == 'variance_weighted': - denom_sum = torch.sum(denominator) - return torch.sum(denominator / denom_sum * output_scores) +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_explained_variance, ver_deprecate="1.3.0", ver_remove="1.5.0") def explained_variance( preds: torch.Tensor, target: torch.Tensor, multioutput: str = 'uniform_average', ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: """ - Computes explained variance. - - Args: - preds: estimated labels - target: ground truth labels - multioutput: Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is `'uniform_average'`.): - - * `'raw_values'` returns full set of scores - * `'uniform_average'` scores are uniformly averaged - * `'variance_weighted'` scores are weighted by their individual variances - - Example: - - >>> from pytorch_lightning.metrics.functional import explained_variance - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> explained_variance(preds, target) - tensor(0.9572) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> explained_variance(preds, target, multioutput='raw_values') - tensor([0.9677, 1.0000]) + .. deprecated:: + Use :func:`torchmetrics.functional.explained_variance`. Will be removed in v1.5.0. """ - preds, target = _explained_variance_update(preds, target) - return _explained_variance_compute(preds, target, multioutput) diff --git a/pytorch_lightning/metrics/functional/mean_absolute_error.py b/pytorch_lightning/metrics/functional/mean_absolute_error.py index 2bd8f125ecb9e4..85aa07c802eca5 100644 --- a/pytorch_lightning/metrics/functional/mean_absolute_error.py +++ b/pytorch_lightning/metrics/functional/mean_absolute_error.py @@ -11,40 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import mean_absolute_error as _mean_absolute_error - -def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_abs_error = torch.sum(torch.abs(preds - target)) - n_obs = target.numel() - return sum_abs_error, n_obs - - -def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_abs_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_absolute_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean absolute error - - Args: - pred: estimated labels - target: ground truth labels - - Return: - Tensor with MAE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_absolute_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_absolute_error(x, y) - tensor(0.2500) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_absolute_error`. Will be removed in v1.5.0. """ - sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) - return _mean_absolute_error_compute(sum_abs_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_relative_error.py b/pytorch_lightning/metrics/functional/mean_relative_error.py index bfe5eb6b847d70..be21371bdc91a3 100644 --- a/pytorch_lightning/metrics/functional/mean_relative_error.py +++ b/pytorch_lightning/metrics/functional/mean_relative_error.py @@ -11,43 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional.regression.mean_relative_error import mean_relative_error as _mean_relative_error - -def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - target_nz = target.clone() - target_nz[target == 0] = 1 - sum_rltv_error = torch.sum(torch.abs((preds - target) / target_nz)) - n_obs = target.numel() - return sum_rltv_error, n_obs - - -def _mean_relative_error_compute(sum_rltv_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_rltv_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_relative_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_relative_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean relative error - - Args: - pred: estimated labels - target: ground truth labels - - Return: - Tensor with mean relative error - - Example: - - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_relative_error(x, y) - tensor(0.1250) - + .. deprecated:: + Use :func:`torchmetrics.functional.regression.mean_relative_error`. Will be removed in v1.5.0. """ - sum_rltv_error, n_obs = _mean_relative_error_update(preds, target) - return _mean_relative_error_compute(sum_rltv_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_squared_error.py b/pytorch_lightning/metrics/functional/mean_squared_error.py index 66c0aadef06510..9d1850dcd8689a 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_error.py @@ -11,40 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import mean_squared_error as _mean_squared_error - -def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_squared_error = torch.sum(torch.pow(preds - target, 2)) - n_obs = target.numel() - return sum_squared_error, n_obs - - -def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_squared_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_squared_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean squared error - - Args: - preds: estimated labels - target: ground truth labels - - Return: - Tensor with MSE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_squared_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_squared_error(x, y) - tensor(0.2500) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_squared_error`. Will be removed in v1.5.0. """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) - return _mean_squared_error_compute(sum_squared_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_squared_log_error.py b/pytorch_lightning/metrics/functional/mean_squared_log_error.py index baec63c7248f27..56654ea47daf28 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_log_error.py @@ -11,40 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import mean_squared_log_error as _mean_squared_log_error - -def _mean_squared_log_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_squared_log_error = torch.sum(torch.pow(torch.log1p(preds) - torch.log1p(target), 2)) - n_obs = target.numel() - return sum_squared_log_error, n_obs - - -def _mean_squared_log_error_compute(sum_squared_log_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_squared_log_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_squared_log_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean squared log error - - Args: - preds: estimated labels - target: ground truth labels - - Return: - Tensor with RMSLE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_squared_log_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_squared_log_error(x, y) - tensor(0.0207) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_squared_log_error`. Will be removed in v1.5.0. """ - sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) - return _mean_squared_log_error_compute(sum_squared_log_error, n_obs) diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index 8b0259694ef4c6..4f820718545cba 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -13,72 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import ExplainedVariance as _ExplainedVariance -from pytorch_lightning.metrics.functional.explained_variance import ( - _explained_variance_compute, - _explained_variance_update, -) -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated -class ExplainedVariance(Metric): - r""" - Computes `explained variance - `_: - - .. math:: \text{ExplainedVariance} = 1 - \frac{\text{Var}(y - \hat{y})}{\text{Var}(y)} - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a - tensor of predictions. - - Forward accepts - - - ``preds`` (float tensor): ``(N,)`` or ``(N, ...)`` (multioutput) - - ``target`` (long tensor): ``(N,)`` or ``(N, ...)`` (multioutput) - - In the case of multioutput, as default the variances will be uniformly - averaged over the additional dimensions. Please see argument `multioutput` - for changing this behavior. - - Args: - multioutput: - Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is `'uniform_average'`.): - - * `'raw_values'` returns full set of scores - * `'uniform_average'` scores are uniformly averaged - * `'variance_weighted'` scores are weighted by their individual variances - - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. - - Example: - - >>> from pytorch_lightning.metrics import ExplainedVariance - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> explained_variance = ExplainedVariance() - >>> explained_variance(preds, target) - tensor(0.9572) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> explained_variance = ExplainedVariance(multioutput='raw_values') - >>> explained_variance(preds, target) - tensor([0.9677, 1.0000]) - """ +class ExplainedVariance(_ExplainedVariance): + @deprecated(target=_ExplainedVariance, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, multioutput: str = 'uniform_average', @@ -87,43 +29,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') - if multioutput not in allowed_multioutput: - raise ValueError( - f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}' - ) - self.multioutput = multioutput - self.add_state("y", default=[], dist_reduce_fx=None) - self.add_state("y_pred", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `ExplainedVariance` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - preds, target = _explained_variance_update(preds, target) - self.y_pred.append(preds) - self.y.append(target) + This implementation refers to :class:`~torchmetrics.ExplainedVariance`. - def compute(self): - """ - Computes explained variance over state. + .. deprecated:: + Use :class:`~torchmetrics.ExplainedVariance`. Will be removed in v1.5.0. """ - preds = torch.cat(self.y_pred, dim=0) - target = torch.cat(self.y, dim=0) - return _explained_variance_compute(preds, target, self.multioutput) diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index 484ccbe83284e9..8510275c127d78 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -13,42 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import MeanAbsoluteError as _MeanAbsoluteError -from pytorch_lightning.metrics.functional.mean_absolute_error import ( - _mean_absolute_error_compute, - _mean_absolute_error_update, -) +from pytorch_lightning.utilities.deprecation import deprecated -class MeanAbsoluteError(Metric): - r""" - Computes `mean absolute error `_ (MAE): - - .. math:: \text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} | - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanAbsoluteError - >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> mean_absolute_error = MeanAbsoluteError() - >>> mean_absolute_error(preds, target) - tensor(0.5000) - """ +class MeanAbsoluteError(_MeanAbsoluteError): + @deprecated(target=_MeanAbsoluteError, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, compute_on_step: bool = True, @@ -56,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) + This implementation refers to :class:`~torchmetrics.MeanAbsoluteError`. - self.sum_abs_error += sum_abs_error - self.total += n_obs - - def compute(self): - """ - Computes mean absolute error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanAbsoluteError`. Will be removed in v1.5.0. """ - return _mean_absolute_error_compute(self.sum_abs_error, self.total) diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index c26371514e7cd6..cbe09faf0046ce 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -13,43 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import MeanSquaredError as _MeanSquaredError -from pytorch_lightning.metrics.functional.mean_squared_error import ( - _mean_squared_error_compute, - _mean_squared_error_update, -) +from pytorch_lightning.utilities.deprecation import deprecated -class MeanSquaredError(Metric): - r""" - Computes `mean squared error `_ (MSE): - - .. math:: \text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y_i})^2 - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanSquaredError - >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) - >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) - >>> mean_squared_error = MeanSquaredError() - >>> mean_squared_error(preds, target) - tensor(0.8750) - - """ +class MeanSquaredError(_MeanSquaredError): + @deprecated(target=_MeanSquaredError, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, compute_on_step: bool = True, @@ -57,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) - - self.sum_squared_error += sum_squared_error - self.total += n_obs + This implementation refers to :class:`~torchmetrics.MeanSquaredError`. - def compute(self): - """ - Computes mean squared error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanSquaredError`. Will be removed in v1.5.0. """ - return _mean_squared_error_compute(self.sum_squared_error, self.total) diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index caaf09a3663ffa..795d6f5409abfb 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -13,45 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import MeanSquaredLogError as _MeanSquaredLogError -from pytorch_lightning.metrics.functional.mean_squared_log_error import ( - _mean_squared_log_error_compute, - _mean_squared_log_error_update, -) +from pytorch_lightning.utilities.deprecation import deprecated -class MeanSquaredLogError(Metric): - r""" - Computes `mean squared logarithmic error - `_ - (MSLE): - - .. math:: \text{MSLE} = \frac{1}{N}\sum_i^N (\log_e(1 + y_i) - \log_e(1 + \hat{y_i}))^2 - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanSquaredLogError - >>> target = torch.tensor([2.5, 5, 4, 8]) - >>> preds = torch.tensor([3, 5, 2.5, 7]) - >>> mean_squared_log_error = MeanSquaredLogError() - >>> mean_squared_log_error(preds, target) - tensor(0.0397) - - """ +class MeanSquaredLogError(_MeanSquaredLogError): + @deprecated(target=_MeanSquaredLogError, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, compute_on_step: bool = True, @@ -59,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) - - self.sum_squared_log_error += sum_squared_log_error - self.total += n_obs + This implementation refers to :class:`~torchmetrics.MeanSquaredLogError`. - def compute(self): - """ - Compute mean squared logarithmic error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanSquaredLogError`. Will be removed in v1.5.0. """ - return _mean_squared_log_error_compute(self.sum_squared_log_error, self.total) diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 349e4175a74446..bcb351984a1752 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -2,6 +2,7 @@ import pytest import torch + from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin diff --git a/tests/metrics/regression/test_explained_variance.py b/tests/metrics/regression/test_explained_variance.py deleted file mode 100644 index adab562ac60552..00000000000000 --- a/tests/metrics/regression/test_explained_variance.py +++ /dev/null @@ -1,77 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import explained_variance_score - -from pytorch_lightning.metrics.functional import explained_variance -from pytorch_lightning.metrics.regression import ExplainedVariance -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_fn(sk_target, sk_preds) - - -def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - return sk_fn(sk_target, sk_preds) - - -@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), - ], -) -class TestExplainedVariance(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - ExplainedVariance, - partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), - dist_sync_on_step, - metric_args=dict(multioutput=multioutput), - ) - - def test_explained_variance_functional(self, multioutput, preds, target, sk_metric): - self.run_functional_metric_test( - preds, - target, - explained_variance, - partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), - metric_args=dict(multioutput=multioutput), - ) - - -def test_error_on_different_shape(metric_class=ExplainedVariance): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/metrics/regression/test_mean_error.py b/tests/metrics/regression/test_mean_error.py deleted file mode 100644 index 041ce12f11164c..00000000000000 --- a/tests/metrics/regression/test_mean_error.py +++ /dev/null @@ -1,87 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error -from sklearn.metrics import mean_squared_error as sk_mean_squared_error -from sklearn.metrics import mean_squared_log_error as sk_mean_squared_log_error - -from pytorch_lightning.metrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error -from pytorch_lightning.metrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_fn(sk_preds, sk_target) - - -def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - return sk_fn(sk_preds, sk_target) - - -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), - ], -) -@pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn", - [ - (MeanSquaredError, mean_squared_error, sk_mean_squared_error), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error), - (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error), - ], -) -class TestMeanError(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_mean_error_class( - self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step - ): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=metric_class, - sk_metric=partial(sk_metric, sk_fn=sk_fn), - dist_sync_on_step=dist_sync_on_step, - ) - - def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): - self.run_functional_metric_test( - preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=partial(sk_metric, sk_fn=sk_fn), - ) - - -@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError]) -def test_error_on_different_shape(metric_class): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index 339d07b1636327..eaf17ec0792daf 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -22,10 +22,14 @@ AUROC, AveragePrecision, ConfusionMatrix, + ExplainedVariance, F1, FBeta, HammingDistance, IoU, + MeanAbsoluteError, + MeanSquaredError, + MeanSquaredLogError, MetricCollection, Precision, PrecisionRecallCurve, @@ -38,10 +42,14 @@ auroc, average_precision, confusion_matrix, + explained_variance, f1, fbeta, hamming_distance, iou, + mean_absolute_error, + mean_squared_error, + mean_squared_log_error, precision, precision_recall, precision_recall_curve, @@ -50,6 +58,7 @@ stat_scores, ) from pytorch_lightning.metrics.functional.accuracy import accuracy +from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -232,8 +241,52 @@ def test_v1_5_metric_detect(): IoU(num_classes=1) target = torch.randint(0, 2, (10, 25, 25)) - pred = torch.tensor(target) - pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] + preds = torch.tensor(target) + preds[2:5, 7:13, 9:15] = 1 - preds[2:5, 7:13, 9:15] iou.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert torch.allclose(iou(pred, target), torch.tensor(0.9660), atol=1e-4) + assert torch.allclose(iou(preds, target), torch.tensor(0.9660), atol=1e-4) + + +def test_v1_5_metric_regress(): + ExplainedVariance.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + ExplainedVariance() + + MeanAbsoluteError.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanAbsoluteError() + + MeanSquaredError.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanSquaredError() + + MeanSquaredLogError.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanSquaredLogError() + + target = torch.tensor([3, -0.5, 2, 7]) + preds = torch.tensor([2.5, 0.0, 2, 8]) + explained_variance.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = explained_variance(preds, target) + assert torch.allclose(res, torch.tensor(0.9572), atol=1e-4) + + x = torch.tensor([0., 1, 2, 3]) + y = torch.tensor([0., 1, 2, 2]) + mean_absolute_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_absolute_error(x, y) == 0.25 + + mean_relative_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_relative_error(x, y) == 0.125 + + mean_squared_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_squared_error(x, y) == 0.25 + + mean_squared_log_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = mean_squared_log_error(x, y) + assert torch.allclose(res, torch.tensor(0.0207), atol=1e-4) diff --git a/tests/utilities/test_argparse.py b/tests/utilities/test_argparse.py index fdf5ae0cafe65a..aef266d639b4ae 100644 --- a/tests/utilities/test_argparse.py +++ b/tests/utilities/test_argparse.py @@ -7,13 +7,13 @@ from pytorch_lightning import Trainer from pytorch_lightning.utilities.argparse import ( + _gpus_arg_default, + _int_or_float_type, add_argparse_args, from_argparse_args, get_abbrev_qualified_cls_name, parse_argparser, parse_args_from_docstring, - _gpus_arg_default, - _int_or_float_type ) From f93414d085784e177e75a143bafefa5ffdadd0c8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Mar 2021 11:01:25 +0100 Subject: [PATCH 13/22] Prune metyrics: regression 9/n (#6637) * psnr * r2score * ssim * chlog --- CHANGELOG.md | 2 + pytorch_lightning/metrics/functional/psnr.py | 88 +---------- .../metrics/functional/r2score.py | 124 +-------------- pytorch_lightning/metrics/functional/ssim.py | 142 +----------------- pytorch_lightning/metrics/regression/psnr.py | 123 +-------------- .../metrics/regression/r2score.py | 122 +-------------- pytorch_lightning/metrics/regression/ssim.py | 78 +--------- tests/metrics/regression/test_psnr.py | 133 ---------------- tests/metrics/regression/test_r2score.py | 114 -------------- tests/metrics/regression/test_ssim.py | 104 ------------- tests/metrics/test_remove_1-5_metrics.py | 39 +++++ 11 files changed, 80 insertions(+), 989 deletions(-) delete mode 100644 tests/metrics/regression/test_psnr.py delete mode 100644 tests/metrics/regression/test_r2score.py delete mode 100644 tests/metrics/regression/test_ssim.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 57a071bff297af..4cf3e0f1fd326a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -92,6 +92,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). [#6636](https://github.com/PyTorchLightning/pytorch-lightning/pull/6636), + [#6637](https://github.com/PyTorchLightning/pytorch-lightning/pull/6637), + ) diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py index 0b50ea092b7fad..dd7aa44ae628ea 100644 --- a/pytorch_lightning/metrics/functional/psnr.py +++ b/pytorch_lightning/metrics/functional/psnr.py @@ -14,46 +14,12 @@ from typing import Optional, Tuple, Union import torch -from torchmetrics.utilities import reduce +from torchmetrics.functional import psnr as _psnr -from pytorch_lightning.utilities import rank_zero_warn - - -def _psnr_compute( - sum_squared_error: torch.Tensor, - n_obs: torch.Tensor, - data_range: torch.Tensor, - base: float = 10.0, - reduction: str = 'elementwise_mean', -) -> torch.Tensor: - psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs) - psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) - return reduce(psnr, reduction=reduction) - - -def _psnr_update(preds: torch.Tensor, - target: torch.Tensor, - dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if dim is None: - sum_squared_error = torch.sum(torch.pow(preds - target, 2)) - n_obs = torch.tensor(target.numel(), device=target.device) - return sum_squared_error, n_obs - - sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim) - - if isinstance(dim, int): - dim_list = [dim] - else: - dim_list = list(dim) - if not dim_list: - n_obs = torch.tensor(target.numel(), device=target.device) - else: - n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod() - n_obs = n_obs.expand_as(sum_squared_error) - - return sum_squared_error, n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_psnr, ver_deprecate="1.3.0", ver_remove="1.5.0") def psnr( preds: torch.Tensor, target: torch.Tensor, @@ -63,50 +29,6 @@ def psnr( dim: Optional[Union[int, Tuple[int, ...]]] = None, ) -> torch.Tensor: """ - Computes the peak signal-to-noise ratio - - Args: - preds: estimated signal - target: groun truth signal - data_range: - the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given - when ``dim`` is not None. - base: a base of a logarithm to use (default: 10) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - dim: - Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is - None meaning scores will be reduced across all dimensions. - Return: - Tensor with PSNR score - - Raises: - ValueError: - If ``dim`` is not ``None`` and ``data_range`` is not provided. - - Example: - >>> from pytorch_lightning.metrics.functional import psnr - >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) - >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - >>> psnr(pred, target) - tensor(2.5527) - + .. deprecated:: + Use :func:`torchmetrics.functional.psnr`. Will be removed in v1.5.0. """ - if dim is None and reduction != 'elementwise_mean': - rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') - - if data_range is None: - if dim is not None: - # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate - # `data_range` in the future. - raise ValueError("The `data_range` must be given when `dim` is not None.") - - data_range = target.max() - target.min() - else: - data_range = torch.tensor(float(data_range)) - sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim) - return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction) diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index d3f1090564a88b..49273d9cefaed9 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -11,133 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import r2score as _r2score -from pytorch_lightning.utilities import rank_zero_warn - - -def _r2score_update( - preds: torch.tensor, - target: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - _check_same_shape(preds, target) - if preds.ndim > 2: - raise ValueError( - 'Expected both prediction and target to be 1D or 2D tensors,' - f' but recevied tensors with dimension {preds.shape}' - ) - if len(preds) < 2: - raise ValueError('Needs atleast two samples to calculate r2 score.') - - sum_error = torch.sum(target, dim=0) - sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0) - residual = torch.sum(torch.pow(target - preds, 2.0), dim=0) - total = target.size(0) - - return sum_squared_error, sum_error, residual, total - - -def _r2score_compute( - sum_squared_error: torch.Tensor, - sum_error: torch.Tensor, - residual: torch.Tensor, - total: torch.Tensor, - adjusted: int = 0, - multioutput: str = "uniform_average" -) -> torch.Tensor: - mean_error = sum_error / total - diff = sum_squared_error - sum_error * mean_error - raw_scores = 1 - (residual / diff) - - if multioutput == "raw_values": - r2score = raw_scores - elif multioutput == "uniform_average": - r2score = torch.mean(raw_scores) - elif multioutput == "variance_weighted": - diff_sum = torch.sum(diff) - r2score = torch.sum(diff / diff_sum * raw_scores) - else: - raise ValueError( - 'Argument `multioutput` must be either `raw_values`,' - f' `uniform_average` or `variance_weighted`. Received {multioutput}.' - ) - - if adjusted < 0 or not isinstance(adjusted, int): - raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.') - - if adjusted != 0: - if adjusted > total - 1: - rank_zero_warn( - "More independent regressions than datapoints in" - " adjusted r2 score. Falls back to standard r2 score.", UserWarning - ) - elif adjusted == total - 1: - rank_zero_warn("Division by zero in adjusted r2 score. Falls back to" " standard r2 score.", UserWarning) - else: - r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1) - return r2score +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_r2score, ver_deprecate="1.3.0", ver_remove="1.5.0") def r2score( preds: torch.Tensor, target: torch.Tensor, adjusted: int = 0, multioutput: str = "uniform_average", ) -> torch.Tensor: - r""" - Computes r2 score also known as `coefficient of determination - `_: - - .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} - - where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and - :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate - adjusted r2 score given by - - .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} - - where the parameter :math:`k` (the number of independent regressors) should - be provided as the ``adjusted`` argument. - - Args: - preds: estimated labels - target: ground truth labels - adjusted: number of independent regressors for calculating adjusted r2 score. - Default 0 (standard r2 score). - multioutput: Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is ``'uniform_average'``.): - - * ``'raw_values'`` returns full set of scores - * ``'uniform_average'`` scores are uniformly averaged - * ``'variance_weighted'`` scores are weighted by their individual variances - - Raises: - ValueError: - If both ``preds`` and ``targets`` are not ``1D`` or ``2D`` tensors. - ValueError: - If ``len(preds)`` is less than ``2`` - since at least ``2`` sampels are needed to calculate r2 score. - ValueError: - If ``multioutput`` is not one of ``raw_values``, - ``uniform_average`` or ``variance_weighted``. - ValueError: - If ``adjusted`` is not an ``integer`` greater than ``0``. - - Example: - - >>> from pytorch_lightning.metrics.functional import r2score - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> r2score(preds, target) - tensor(0.9486) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> r2score(preds, target, multioutput='raw_values') - tensor([0.9654, 0.9082]) """ - sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) - return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput) + .. deprecated:: + Use :func:`torchmetrics.functional.r2score`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/ssim.py b/pytorch_lightning/metrics/functional/ssim.py index 4899a3ad3be4dc..8809fec8d8ff1c 100644 --- a/pytorch_lightning/metrics/functional/ssim.py +++ b/pytorch_lightning/metrics/functional/ssim.py @@ -11,107 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence import torch -from torch.nn import functional as F -from torchmetrics.utilities import reduce -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import ssim as _ssim - -def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.device): - dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device) - gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) - return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) - - -def _gaussian_kernel( - channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device -): - gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device) - gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device) - kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) - - return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) - - -def _ssim_update( - preds: torch.Tensor, - target: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - if preds.dtype != target.dtype: - raise TypeError( - "Expected `preds` and `target` to have the same data type." - f" Got pred: {preds.dtype} and target: {target.dtype}." - ) - _check_same_shape(preds, target) - if len(preds.shape) != 4: - raise ValueError( - "Expected `preds` and `target` to have BxCxHxW shape." - f" Got pred: {preds.shape} and target: {target.shape}." - ) - return preds, target - - -def _ssim_compute( - preds: torch.Tensor, - target: torch.Tensor, - kernel_size: Sequence[int] = (11, 11), - sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", - data_range: Optional[float] = None, - k1: float = 0.01, - k2: float = 0.03, -): - if len(kernel_size) != 2 or len(sigma) != 2: - raise ValueError( - "Expected `kernel_size` and `sigma` to have the length of two." - f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}." - ) - - if any(x % 2 == 0 or x <= 0 for x in kernel_size): - raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.") - - if any(y <= 0 for y in sigma): - raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.") - - if data_range is None: - data_range = max(preds.max() - preds.min(), target.max() - target.min()) - - c1 = pow(k1 * data_range, 2) - c2 = pow(k2 * data_range, 2) - device = preds.device - - channel = preds.size(1) - dtype = preds.dtype - kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device) - pad_w = (kernel_size[0] - 1) // 2 - pad_h = (kernel_size[1] - 1) // 2 - - preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode='reflect') - target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode='reflect') - - input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) - outputs = F.conv2d(input_list, kernel, groups=channel) - output_list = [outputs[x * preds.size(0):(x + 1) * preds.size(0)] for x in range(len(outputs))] - - mu_pred_sq = output_list[0].pow(2) - mu_target_sq = output_list[1].pow(2) - mu_pred_target = output_list[0] * output_list[1] - - sigma_pred_sq = output_list[2] - mu_pred_sq - sigma_target_sq = output_list[3] - mu_target_sq - sigma_pred_target = output_list[4] - mu_pred_target - - upper = 2 * sigma_pred_target + c2 - lower = sigma_pred_sq + sigma_target_sq + c2 - - ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower) - ssim_idx = ssim_idx[..., pad_h:-pad_h, pad_w:-pad_w] - - return reduce(ssim_idx, reduction) +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_ssim, ver_deprecate="1.3.0", ver_remove="1.5.0") def ssim( preds: torch.Tensor, target: torch.Tensor, @@ -123,44 +31,6 @@ def ssim( k2: float = 0.03, ) -> torch.Tensor: """ - Computes Structual Similarity Index Measure - - Args: - preds: estimated image - target: ground truth image - kernel_size: size of the gaussian kernel (default: (11, 11)) - sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - data_range: Range of the image. If ``None``, it is determined from the image (max - min) - k1: Parameter of SSIM. Default: 0.01 - k2: Parameter of SSIM. Default: 0.03 - - Return: - Tensor with SSIM score - - Raises: - TypeError: - If ``preds`` and ``target`` don't have the same data type. - ValueError: - If ``preds`` and ``target`` don't have ``BxCxHxW shape``. - ValueError: - If the length of ``kernel_size`` or ``sigma`` is not ``2``. - ValueError: - If one of the elements of ``kernel_size`` is not an ``odd positive number``. - ValueError: - If one of the elements of ``sigma`` is not a ``positive number``. - - Example: - >>> from pytorch_lightning.metrics.functional import ssim - >>> preds = torch.rand([16, 1, 16, 16]) - >>> target = preds * 0.75 - >>> ssim(preds, target) - tensor(0.9219) + .. deprecated:: + Use :func:`torchmetrics.functional.ssim`. Will be removed in v1.5.0. """ - preds, target = _ssim_update(preds, target) - return _ssim_compute(preds, target, kernel_size, sigma, reduction, data_range, k1, k2) diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py index 746ff1e52d574f..85b8eceaa24c5b 100644 --- a/pytorch_lightning/metrics/regression/psnr.py +++ b/pytorch_lightning/metrics/regression/psnr.py @@ -11,61 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Tuple, Union -import torch -from torchmetrics import Metric +from torchmetrics import PSNR as _PSNR -from pytorch_lightning import utilities -from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update +from pytorch_lightning.utilities.deprecation import deprecated -class PSNR(Metric): - r""" - Computes `peak signal-to-noise ratio `_ (PSNR): - - .. math:: \text{PSNR}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)}\right) - - Where :math:`\text{MSE}` denotes the `mean-squared-error - `_ function. - - Args: - data_range: - the range of the data. If None, it is determined from the data (max - min). - The ``data_range`` must be given when ``dim`` is not None. - base: a base of a logarithm to use (default: 10) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - dim: - Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is - None meaning scores will be reduced across all dimensions and all batches. - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``dim`` is not ``None`` and ``data_range`` is not given. - - Example: - - >>> from pytorch_lightning.metrics import PSNR - >>> psnr = PSNR() - >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) - >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - >>> psnr(preds, target) - tensor(2.5527) - - """ +class PSNR(_PSNR): + @deprecated(target=_PSNR, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, data_range: Optional[float] = None, @@ -76,71 +31,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - - if dim is None and reduction != 'elementwise_mean': - utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') - - if dim is None: - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - else: - self.add_state("sum_squared_error", default=[]) - self.add_state("total", default=[]) - - if data_range is None: - if dim is not None: - # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to - # calculate `data_range` in the future. - raise ValueError("The `data_range` must be given when `dim` is not None.") - - self.data_range = None - self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min) - self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max) - else: - self.register_buffer("data_range", torch.tensor(float(data_range))) - self.base = base - self.reduction = reduction - self.dim = tuple(dim) if isinstance(dim, Sequence) else dim - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + This implementation refers to :class:`~torchmetrics.PSNR`. - Args: - preds: Predictions from model - target: Ground truth values + .. deprecated:: + Use :class:`~torchmetrics.PSNR`. Will be removed in v1.5.0. """ - sum_squared_error, n_obs = _psnr_update(preds, target, dim=self.dim) - if self.dim is None: - if self.data_range is None: - # keep track of min and max target values - self.min_target = min(target.min(), self.min_target) - self.max_target = max(target.max(), self.max_target) - - self.sum_squared_error += sum_squared_error - self.total += n_obs - else: - self.sum_squared_error.append(sum_squared_error) - self.total.append(n_obs) - - def compute(self): - """ - Compute peak signal-to-noise ratio over state. - """ - if self.data_range is not None: - data_range = self.data_range - else: - data_range = self.max_target - self.min_target - - if self.dim is None: - sum_squared_error = self.sum_squared_error - total = self.total - else: - sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error]) - total = torch.cat([values.flatten() for values in self.total]) - return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction) diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index 8156b8bc72d484..52621d6df7c286 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -13,81 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import R2Score as _R2Score -from pytorch_lightning.metrics.functional.r2score import _r2score_compute, _r2score_update +from pytorch_lightning.utilities.deprecation import deprecated -class R2Score(Metric): - r""" - Computes r2 score also known as `coefficient of determination - `_: - - .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} - - where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and - :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate - adjusted r2 score given by - - .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} - - where the parameter :math:`k` (the number of independent regressors) should - be provided as the `adjusted` argument. - - Forward accepts - - - ``preds`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) - - ``target`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) - - In the case of multioutput, as default the variances will be uniformly - averaged over the additional dimensions. Please see argument `multioutput` - for changing this behavior. - - Args: - num_outputs: - Number of outputs in multioutput setting (default is 1) - adjusted: - number of independent regressors for calculating adjusted r2 score. - Default 0 (standard r2 score). - multioutput: - Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is ``'uniform_average'``.): - - * ``'raw_values'`` returns full set of scores - * ``'uniform_average'`` scores are uniformly averaged - * ``'variance_weighted'`` scores are weighted by their individual variances - - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``adjusted`` parameter is not an integer larger or equal to 0. - ValueError: - If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. - - Example: - - >>> from pytorch_lightning.metrics import R2Score - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> r2score = R2Score() - >>> r2score(preds, target) - tensor(0.9486) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> r2score = R2Score(num_outputs=2, multioutput='raw_values') - >>> r2score(preds, target) - tensor([0.9654, 0.9082]) - """ +class R2Score(_R2Score): + @deprecated(target=_R2Score, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, num_outputs: int = 1, @@ -98,50 +31,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.num_outputs = num_outputs - - if adjusted < 0 or not isinstance(adjusted, int): - raise ValueError('`adjusted` parameter should be an integer larger or equal to 0.') - self.adjusted = adjusted - - allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') - if multioutput not in allowed_multioutput: - raise ValueError( - f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}' - ) - self.multioutput = multioutput - - self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) + This implementation refers to :class:`~torchmetrics.R2Score`. - self.sum_squared_error += sum_squared_error - self.sum_error += sum_error - self.residual += residual - self.total += total - - def compute(self) -> torch.Tensor: - """ - Computes r2 score over the metric states. + .. deprecated:: + Use :class:`~torchmetrics.R2Score`. Will be removed in v1.5.0. """ - return _r2score_compute( - self.sum_squared_error, self.sum_error, self.residual, self.total, self.adjusted, self.multioutput - ) diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py index a3bbab938ffad9..b290808c6fa5e4 100644 --- a/pytorch_lightning/metrics/regression/ssim.py +++ b/pytorch_lightning/metrics/regression/ssim.py @@ -13,43 +13,14 @@ # limitations under the License. from typing import Any, Optional, Sequence -import torch -from torchmetrics import Metric +from torchmetrics import SSIM as _SSIM -from pytorch_lightning.metrics.functional.ssim import _ssim_compute, _ssim_update -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated -class SSIM(Metric): - """ - Computes `Structual Similarity Index Measure - `_ (SSIM). - - Args: - kernel_size: size of the gaussian kernel (default: (11, 11)) - sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - data_range: Range of the image. If ``None``, it is determined from the image (max - min) - k1: Parameter of SSIM. Default: 0.01 - k2: Parameter of SSIM. Default: 0.03 - - Return: - Tensor with SSIM score - - Example: - >>> from pytorch_lightning.metrics import SSIM - >>> preds = torch.rand([16, 1, 16, 16]) - >>> target = preds * 0.75 - >>> ssim = SSIM() - >>> ssim(preds, target) - tensor(0.9219) - """ +class SSIM(_SSIM): + @deprecated(target=_SSIM, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, kernel_size: Sequence[int] = (11, 11), @@ -62,44 +33,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - rank_zero_warn( - 'Metric `SSIM` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' - ) - - self.add_state("y", default=[], dist_reduce_fx=None) - self.add_state("y_pred", default=[], dist_reduce_fx=None) - self.kernel_size = kernel_size - self.sigma = sigma - self.data_range = data_range - self.k1 = k1 - self.k2 = k2 - self.reduction = reduction - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + This implementation refers to :class:`~torchmetrics.SSIM`. - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target = _ssim_update(preds, target) - self.y_pred.append(preds) - self.y.append(target) - - def compute(self): - """ - Computes explained variance over state. + .. deprecated:: + Use :class:`~torchmetrics.SSIM`. Will be removed in v1.5.0. """ - preds = torch.cat(self.y_pred, dim=0) - target = torch.cat(self.y, dim=0) - return _ssim_compute( - preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2 - ) diff --git a/tests/metrics/regression/test_psnr.py b/tests/metrics/regression/test_psnr.py deleted file mode 100644 index eb07fffb9d55c2..00000000000000 --- a/tests/metrics/regression/test_psnr.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import namedtuple -from functools import partial - -import numpy as np -import pytest -import torch -from skimage.metrics import peak_signal_noise_ratio - -from pytorch_lightning.metrics.functional import psnr -from pytorch_lightning.metrics.regression import PSNR -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -Input = namedtuple('Input', ["preds", "target"]) - -_input_size = (NUM_BATCHES, BATCH_SIZE, 32, 32) -_inputs = [ - Input( - preds=torch.randint(n_cls_pred, _input_size, dtype=torch.float), - target=torch.randint(n_cls_target, _input_size, dtype=torch.float), - ) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)] -] - - -def _to_sk_peak_signal_noise_ratio_inputs(value, dim): - value = value.numpy() - batches = value[None] if value.ndim == len(_input_size) - 1 else value - - if dim is None: - return [batches] - - num_dims = np.size(dim) - if not num_dims: - return batches - - inputs = [] - for batch in batches: - batch = np.moveaxis(batch, dim, np.arange(-num_dims, 0)) - psnr_input_shape = batch.shape[-num_dims:] - inputs.extend(batch.reshape(-1, *psnr_input_shape)) - return inputs - - -def _sk_psnr(preds, target, data_range, reduction, dim): - sk_preds_lists = _to_sk_peak_signal_noise_ratio_inputs(preds, dim=dim) - sk_target_lists = _to_sk_peak_signal_noise_ratio_inputs(target, dim=dim) - np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum} - return np_reduce_map[reduction]([ - peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range) - for sk_target, sk_preds in zip(sk_target_lists, sk_preds_lists) - ]) - - -def _base_e_sk_psnr(preds, target, data_range, reduction, dim): - return _sk_psnr(preds, target, data_range, reduction, dim) * np.log(10) - - -@pytest.mark.parametrize( - "preds, target, data_range, reduction, dim", - [ - (_inputs[0].preds, _inputs[0].target, 10, "elementwise_mean", None), - (_inputs[1].preds, _inputs[1].target, 10, "elementwise_mean", None), - (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", None), - (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", 1), - (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", (1, 2)), - (_inputs[2].preds, _inputs[2].target, 5, "sum", (1, 2)), - ], -) -@pytest.mark.parametrize( - "base, sk_metric", - [ - (10.0, _sk_psnr), - (2.718281828459045, _base_e_sk_psnr), - ], -) -class TestPSNR(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step): - _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} - self.run_class_metric_test( - ddp, - preds, - target, - PSNR, - partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim), - metric_args=_args, - dist_sync_on_step=dist_sync_on_step, - ) - - def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim): - _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} - self.run_functional_metric_test( - preds, - target, - psnr, - partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim), - metric_args=_args, - ) - - -@pytest.mark.parametrize("reduction", ["none", "sum"]) -def test_reduction_for_dim_none(reduction): - match = f"The `reduction={reduction}` will not have any effect when `dim` is None." - with pytest.warns(UserWarning, match=match): - PSNR(reduction=reduction, dim=None) - - with pytest.warns(UserWarning, match=match): - psnr(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None) - - -def test_missing_data_range(): - with pytest.raises(ValueError): - PSNR(data_range=None, dim=0) - - with pytest.raises(ValueError): - psnr(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0) diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py deleted file mode 100644 index 232b003e6116a7..00000000000000 --- a/tests/metrics/regression/test_r2score.py +++ /dev/null @@ -1,114 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import r2_score as sk_r2score - -from pytorch_lightning.metrics.functional import r2score -from pytorch_lightning.metrics.regression import R2Score -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, adjusted, multioutput): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) - if adjusted != 0: - r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) - return r2_score - - -def _multi_target_sk_metric(preds, target, adjusted, multioutput): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) - if adjusted != 0: - r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) - return r2_score - - -@pytest.mark.parametrize("adjusted", [0, 5, 10]) -@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_outputs", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric, 1), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric, num_targets), - ], -) -class TestR2Score(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - R2Score, - partial(sk_metric, adjusted=adjusted, multioutput=multioutput), - dist_sync_on_step, - metric_args=dict(adjusted=adjusted, multioutput=multioutput, num_outputs=num_outputs), - ) - - def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs): - self.run_functional_metric_test( - preds, - target, - r2score, - partial(sk_metric, adjusted=adjusted, multioutput=multioutput), - metric_args=dict(adjusted=adjusted, multioutput=multioutput), - ) - - -def test_error_on_different_shape(metric_class=R2Score): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) - - -def test_error_on_multidim_tensors(metric_class=R2Score): - metric = metric_class() - with pytest.raises( - ValueError, - match=r'Expected both prediction and target to be 1D or 2D tensors,' - r' but recevied tensors with dimension .' - ): - metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5)) - - -def test_error_on_too_few_samples(metric_class=R2Score): - metric = metric_class() - with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'): - metric(torch.randn(1, ), torch.randn(1, )) - - -def test_warning_on_too_large_adjusted(metric_class=R2Score): - metric = metric_class(adjusted=10) - - with pytest.warns( - UserWarning, - match="More independent regressions than datapoints in" - " adjusted r2 score. Falls back to standard r2 score." - ): - metric(torch.randn(10, ), torch.randn(10, )) - - with pytest.warns(UserWarning, match="Division by zero in adjusted r2 score. Falls back to" " standard r2 score."): - metric(torch.randn(11, ), torch.randn(11, )) diff --git a/tests/metrics/regression/test_ssim.py b/tests/metrics/regression/test_ssim.py deleted file mode 100644 index f7e4b7a58e0011..00000000000000 --- a/tests/metrics/regression/test_ssim.py +++ /dev/null @@ -1,104 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from skimage.metrics import structural_similarity - -from pytorch_lightning.metrics.functional import ssim -from pytorch_lightning.metrics.regression import SSIM -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -Input = namedtuple('Input', ["preds", "target", "multichannel"]) - -_inputs = [] -for size, channel, coef, multichannel, dtype in [ - (12, 3, 0.9, True, torch.float), - (13, 1, 0.8, False, torch.float32), - (14, 1, 0.7, False, torch.double), - (15, 3, 0.6, True, torch.float64), -]: - preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) - _inputs.append(Input( - preds=preds, - target=preds * coef, - multichannel=multichannel, - )) - - -def _sk_metric(preds, target, data_range, multichannel): - c, h, w = preds.shape[-3:] - sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() - sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() - if not multichannel: - sk_preds = sk_preds[:, :, :, 0] - sk_target = sk_target[:, :, :, 0] - - return structural_similarity( - sk_target, - sk_preds, - data_range=data_range, - multichannel=multichannel, - gaussian_weights=True, - win_size=11, - sigma=1.5, - use_sample_covariance=False - ) - - -@pytest.mark.parametrize( - "preds, target, multichannel", - [(i.preds, i.target, i.multichannel) for i in _inputs], -) -class TestSSIM(MetricTester): - atol = 6e-5 - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - SSIM, - partial(_sk_metric, data_range=1.0, multichannel=multichannel), - metric_args={"data_range": 1.0}, - dist_sync_on_step=dist_sync_on_step, - ) - - def test_ssim_functional(self, preds, target, multichannel): - self.run_functional_metric_test( - preds, - target, - ssim, - partial(_sk_metric, data_range=1.0, multichannel=multichannel), - metric_args={"data_range": 1.0}, - ) - - -@pytest.mark.parametrize( - ['pred', 'target', 'kernel', 'sigma'], - [ - pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input - ], -) -def test_ssim_invalid_inputs(pred, target, kernel, sigma): - pred_t = torch.rand(pred) - target_t = torch.rand(target, dtype=torch.float64) - with pytest.raises(TypeError): - ssim(pred_t, target_t) - - pred = torch.rand(pred) - target = torch.rand(target) - with pytest.raises(ValueError): - ssim(pred, target, kernel, sigma) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index eaf17ec0792daf..43dd330bcfcbea 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -33,8 +33,11 @@ MetricCollection, Precision, PrecisionRecallCurve, + PSNR, + R2Score, Recall, ROC, + SSIM, StatScores, ) from pytorch_lightning.metrics.functional import ( @@ -53,8 +56,11 @@ precision, precision_recall, precision_recall_curve, + psnr, + r2score, recall, roc, + ssim, stat_scores, ) from pytorch_lightning.metrics.functional.accuracy import accuracy @@ -290,3 +296,36 @@ def test_v1_5_metric_regress(): with pytest.deprecated_call(match='It will be removed in v1.5.0'): res = mean_squared_log_error(x, y) assert torch.allclose(res, torch.tensor(0.0207), atol=1e-4) + + PSNR.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + PSNR() + + R2Score.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + R2Score() + + SSIM.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + SSIM() + + preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + psnr.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = psnr(preds, target) + assert torch.allclose(res, torch.tensor(2.5527), atol=1e-4) + + target = torch.tensor([3, -0.5, 2, 7]) + preds = torch.tensor([2.5, 0.0, 2, 8]) + r2score.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = r2score(preds, target) + assert torch.allclose(res, torch.tensor(0.9486), atol=1e-4) + + preds = torch.rand([16, 1, 16, 16]) + target = preds * 0.75 + ssim.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = ssim(preds, target) + assert torch.allclose(res, torch.tensor(0.9219), atol=1e-4) From 36d180e53271295359c1ca7da1d222bbd2ed7940 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 23 Mar 2021 11:07:35 +0100 Subject: [PATCH 14/22] Refactor base profilers 3/5 (#6621) Co-authored-by: tchaton --- .gitignore | 2 +- CHANGELOG.md | 9 + pytorch_lightning/profiler/profilers.py | 270 +++++++++++++------ pytorch_lightning/profiler/pytorch.py | 76 ++---- pytorch_lightning/trainer/evaluation_loop.py | 5 + pytorch_lightning/trainer/training_loop.py | 4 +- tests/deprecated_api/test_remove_1-5.py | 10 + tests/test_profiler.py | 167 ++++++++---- tests/trainer/properties/test_get_model.py | 5 +- 9 files changed, 355 insertions(+), 193 deletions(-) diff --git a/.gitignore b/.gitignore index c0071402571884..99939ff7fce0cd 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,4 @@ tags data MNIST runs -*traces* +*trace* diff --git a/CHANGELOG.md b/CHANGELOG.md index 4cf3e0f1fd326a..32cf9122efe340 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `Trainer.predict` config validation ([#6543](https://github.com/PyTorchLightning/pytorch-lightning/pull/6543)) +- Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) @@ -68,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) +- Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + ### Deprecated - `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) @@ -76,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 5668fd6654b2f2..54bc5cdf0122c4 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -21,31 +21,19 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import contextmanager -from typing import Optional, Union +from pathlib import Path +from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Union import numpy as np +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem log = logging.getLogger(__name__) -class BaseProfiler(ABC): - """ - If you wish to write a custom profiler, you should inhereit from this class. - """ - - def __init__(self, output_streams: Optional[Union[list, tuple]] = None): - """ - Args: - output_streams: callable - """ - if output_streams: - if not isinstance(output_streams, (list, tuple)): - output_streams = [output_streams] - else: - output_streams = [] - self.write_streams = output_streams +class AbstractProfiler(ABC): + """Specification of a profiler.""" @abstractmethod def start(self, action_name: str) -> None: @@ -55,23 +43,47 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: """Defines how to record the duration once an action is complete.""" - def setup( + @abstractmethod + def summary(self) -> str: + """Create profiler summary in text format.""" + + @abstractmethod + def setup(self, **kwargs: Any) -> None: + """Execute arbitrary pre-profiling set-up steps as defined by subclass.""" + + @abstractmethod + def teardown(self, **kwargs: Any) -> None: + """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" + + +class BaseProfiler(AbstractProfiler): + """ + If you wish to write a custom profiler, you should inherit from this class. + """ + + def __init__( self, - stage: Optional[str] = None, - local_rank: Optional[int] = None, - log_dir: Optional[str] = None + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + output_filename: Optional[str] = None, ) -> None: - """Execute arbitrary pre-profiling set-up steps.""" - self.stage = stage - self.local_rank = local_rank - self.log_dir = log_dir + self.dirpath = dirpath + self.filename = filename + if output_filename is not None: + rank_zero_warn( + "`Profiler` signature has changed in v1.3. The `output_filename` parameter has been removed in" + " favor of `dirpath` and `filename`. Support for the old signature will be removed in v1.5", + DeprecationWarning + ) + filepath = Path(output_filename) + self.dirpath = filepath.parent + self.filename = filepath.stem - def teardown(self, stage: Optional[str] = None) -> None: - """Execute arbitrary post-profiling tear-down steps.""" - self.stage = stage - if self.output_file: - self.output_file.close() - self.output_file = None + self._output_file: Optional[TextIO] = None + self._write_stream: Optional[Callable] = None + self._local_rank: Optional[int] = None + self._log_dir: Optional[str] = None + self._stage: Optional[str] = None @contextmanager def profile(self, action_name: str) -> None: @@ -104,19 +116,94 @@ def profile_iterable(self, iterable, action_name: str) -> None: self.stop(action_name) break + def _rank_zero_info(self, *args, **kwargs) -> None: + if self._local_rank in (None, 0): + log.info(*args, **kwargs) + + def _prepare_filename(self) -> str: + filename = "" + if self._stage is not None: + filename += f"{self._stage}-" + filename += str(self.filename) + if self._local_rank is not None: + filename += f"-{self.local_rank}" + filename += ".txt" + return filename + + def _prepare_streams(self) -> None: + if self._write_stream is not None: + return + if self.filename: + dirpath = self.dirpath or self._log_dir + filepath = os.path.join(dirpath, self._prepare_filename()) + fs = get_filesystem(filepath) + file = fs.open(filepath, "a") + self._output_file = file + self._write_stream = file.write + else: + self._write_stream = self._rank_zero_info + def describe(self) -> None: - """Logs a profile report after the conclusion of the training run.""" - for write in self.write_streams: - write(self.summary()) - if self.output_file is not None: - self.output_file.flush() + """Logs a profile report after the conclusion of run.""" + # there are pickling issues with open file handles in Python 3.6 + # so to avoid them, we open and close the files within this function + # by calling `_prepare_streams` and `teardown` + self._prepare_streams() + self._write_stream(self.summary()) + if self._output_file is not None: + self._output_file.flush() + self.teardown(stage=self._stage) + + def _stats_to_str(self, stats: Dict[str, str]) -> str: + stage = f"{self._stage.upper()} " if self._stage is not None else "" + output = [stage + "Profiler Report"] + for action, value in stats.items(): + header = f"Profile stats for: {action}" + if self._local_rank is not None: + header += f" rank: {self._local_rank}" + output.append(header) + output.append(value) + return os.linesep.join(output) + + def setup( + self, + stage: Optional[str] = None, + local_rank: Optional[int] = None, + log_dir: Optional[str] = None, + ) -> None: + """Execute arbitrary pre-profiling set-up steps.""" + self._stage = stage + self._local_rank = local_rank + self._log_dir = log_dir + if self.dirpath is None: + self.dirpath = self._log_dir + + def teardown(self, stage: Optional[str] = None) -> None: + """ + Execute arbitrary post-profiling tear-down steps. + + Closes the currently open file and stream. + """ + self._write_stream = None + if self._output_file is not None: + self._output_file.close() + self._output_file = None # can't pickle TextIOWrapper + + def __del__(self) -> None: + self.teardown(stage=self._stage) + + def start(self, action_name: str) -> None: + raise NotImplementedError + + def stop(self, action_name: str) -> None: + raise NotImplementedError - @abstractmethod def summary(self) -> str: - """Create profiler summary in text format.""" + raise NotImplementedError - def __del__(self): - self.teardown(None) + @property + def local_rank(self): + return '0' if self._local_rank is None else self._local_rank class PassThroughProfiler(BaseProfiler): @@ -125,10 +212,6 @@ class PassThroughProfiler(BaseProfiler): The Trainer uses this class by default. """ - def __init__(self): - self.output_file = None - super().__init__(output_streams=None) - def start(self, action_name: str) -> None: pass @@ -145,30 +228,32 @@ class SimpleProfiler(BaseProfiler): the mean duration of each action and the total time spent over the entire training run. """ - def __init__(self, output_filename: Optional[str] = None, extended=True): + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + extended: bool = True, + output_filename: Optional[str] = None, + ) -> None: """ Args: - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. + + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. Raises: ValueError: If you attempt to start an action which has already started, or if you attempt to stop recording an action which was never started. """ - self.current_actions = {} + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + self.current_actions: Dict[str, float] = {} self.recorded_durations = defaultdict(list) self.extended = extended - - self.output_fname = output_filename - self.output_file = None - if self.output_fname: - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") - - streaming_out = [self.output_file.write] if self.output_file else [log.info] self.start_time = time.monotonic() - super().__init__(output_streams=streaming_out) def start(self, action_name: str) -> None: if action_name in self.current_actions: @@ -183,14 +268,18 @@ def stop(self, action_name: str) -> None: duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def make_report(self): + def _make_report(self) -> Tuple[list, float]: total_duration = time.monotonic() - self.start_time report = [[a, d, 100. * np.sum(d) / total_duration] for a, d in self.recorded_durations.items()] report.sort(key=lambda x: x[2], reverse=True) return report, total_duration def summary(self) -> str: - output_string = "\n\nProfiler Report\n" + sep = os.linesep + output_string = "" + if self._stage is not None: + output_string += f"{self._stage.upper()} " + output_string += f"Profiler Report{sep}" if self.extended: @@ -198,16 +287,16 @@ def summary(self) -> str: max_key = np.max([len(k) for k in self.recorded_durations.keys()]) def log_row(action, mean, num_calls, total, per): - row = f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t|" + row = f"{sep}{action:<{max_key}s}\t| {mean:<15}\t|" row += f"{num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" return row output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %") output_string_len = len(output_string) - output_string += f"{os.linesep}{'-' * output_string_len}" - report, total_duration = self.make_report() + output_string += f"{sep}{'-' * output_string_len}" + report, total_duration = self._make_report() output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %") - output_string += f"{os.linesep}{'-' * output_string_len}" + output_string += f"{sep}{'-' * output_string_len}" for action, durations, duration_per in report: output_string += log_row( action, @@ -219,14 +308,14 @@ def log_row(action, mean, num_calls, total, per): else: def log_row(action, mean, total): - return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}" + return f"{sep}{action:<20s}\t| {mean:<15}\t| {total:<15}" output_string += log_row("Action", "Mean duration (s)", "Total time (s)") - output_string += f"{os.linesep}{'-' * 65}" + output_string += f"{sep}{'-' * 65}" for action, durations in self.recorded_durations.items(): output_string += log_row(action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}") - output_string += os.linesep + output_string += sep return output_string @@ -237,11 +326,22 @@ class AdvancedProfiler(BaseProfiler): verbose and you should only use this if you want very detailed reports. """ - def __init__(self, output_filename: Optional[str] = None, line_count_restriction: float = 1.0): + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + line_count_restriction: float = 1.0, + output_filename: Optional[str] = None, + ) -> None: """ Args: - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. + + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. + line_count_restriction: this can be used to limit the number of functions reported for each action. either an integer (to select a count of lines), or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) @@ -250,18 +350,10 @@ def __init__(self, output_filename: Optional[str] = None, line_count_restriction ValueError: If you attempt to stop recording an action which was never started. """ - self.profiled_actions = {} + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + self.profiled_actions: Dict[str, cProfile.Profile] = {} self.line_count_restriction = line_count_restriction - self.output_fname = output_filename - self.output_file = None - if self.output_fname: - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") - - streaming_out = [self.output_file.write] if self.output_file else [log.info] - super().__init__(output_streams=streaming_out) - def start(self, action_name: str) -> None: if action_name not in self.profiled_actions: self.profiled_actions[action_name] = cProfile.Profile() @@ -270,9 +362,7 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: pr = self.profiled_actions.get(action_name) if pr is None: - raise ValueError( # pragma: no-cover - f"Attempting to stop recording an action ({action_name}) which was never started." - ) + raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.") pr.disable() def summary(self) -> str: @@ -282,10 +372,16 @@ def summary(self) -> str: ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative') ps.print_stats(self.line_count_restriction) recorded_stats[action_name] = s.getvalue() + return self._stats_to_str(recorded_stats) - # log to standard out - output_string = f"{os.linesep}Profiler Report{os.linesep}" - for action, stats in recorded_stats.items(): - output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}" + def teardown(self, stage: Optional[str] = None) -> None: + super().teardown(stage=stage) + self.profiled_actions = {} - return output_string + def __reduce__(self): + # avoids `TypeError: cannot pickle 'cProfile.Profile' object` + return ( + self.__class__, + tuple(), + dict(dirpath=self.dirpath, filename=self.filename, line_count_restriction=self.line_count_restriction), + ) diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index c35979fa918af0..55b1c286789f43 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -16,13 +16,12 @@ import inspect import logging import os -from typing import List, Optional +from pathlib import Path +from typing import List, Optional, Union import torch from pytorch_lightning.profiler.profilers import BaseProfiler -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -46,7 +45,8 @@ class PyTorchProfiler(BaseProfiler): def __init__( self, - output_filename: Optional[str] = None, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, enabled: bool = True, use_cuda: bool = False, record_shapes: bool = False, @@ -61,18 +61,19 @@ def __init__( row_limit: int = 20, sort_by_key: Optional[str] = None, profiled_functions: Optional[List] = None, - local_rank: Optional[int] = None, + output_filename: Optional[str] = None, ): """ This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of different operators inside your model - both on the CPU and GPU Args: + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. When using ``ddp``, - each rank will stream the profiled operation to their own file - with the extension ``_{rank}.txt`` + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. enabled: Setting this to False makes this context manager a no-op. @@ -116,13 +117,9 @@ def __init__( profiled_functions: list of profiled functions which will create a context manager on. Any other will be pass through. - local_rank: When running in distributed setting, local_rank is used for each process - to write to their own file if `output_fname` is provided. - Raises: MisconfigurationException: - If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``, or - if log file is not a ``.txt`` file. + If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``. ValueError: If you attempt to stop recording an action which was never started. """ @@ -159,37 +156,20 @@ def __init__( self.running_stack = [] self.profiler = None - self.output_fname = output_filename - self.output_file = None - if local_rank is not None: - self.setup(local_rank=local_rank) - self.setup = super().setup + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) - def setup(self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None): + def setup( + self, + stage: Optional[str] = None, + local_rank: Optional[int] = None, + log_dir: Optional[str] = None + ) -> None: super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir) - # when logging to `log.info`, only perform profiling on rank 0 - if local_rank != 0 and self.output_fname is None: - self.wrap_functions_into_rank_zero_only() - - if self.output_fname: - if local_rank is not None: - if '.txt' not in self.output_fname: - raise MisconfigurationException("Log file should be .txt file.") - - self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt") - - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") - - streaming_out = [self.output_file.write] if self.output_file else [log.info] - super().__init__(output_streams=streaming_out) - - def wrap_functions_into_rank_zero_only(self): - self.start = rank_zero_only(self.start) - self.stop = rank_zero_only(self.stop) - self.summary = rank_zero_only(self.summary) - self.describe = rank_zero_only(self.describe) + # if the user didn't provide `path_to_export_trace`, + # set it as TensorBoardLogger log_dir if exists + if self.path_to_export_trace is None: + self.path_to_export_trace = log_dir def start(self, action_name: str) -> None: if action_name not in self.profiled_functions: @@ -231,6 +211,7 @@ def _stop(self, action_name: str) -> None: # when running ``emit_nvtx``, PyTorch requires 2 context manager. # The parent_profiler is being closed too. self._parent_profiler.__exit__(None, None, None) + self._parent_profiler = None return function_events = self.profiler.function_events @@ -258,7 +239,6 @@ def stop(self, action_name: str) -> None: def summary(self) -> str: recorded_stats = {} output_string = '' - local_rank = '0' if self.local_rank is None else self.local_rank if not self.enabled: return output_string @@ -271,7 +251,7 @@ def summary(self) -> str: function_events.populate_cpu_children = lambda: None if self.export_to_chrome: - filename = f"{action_name}_{local_rank}_trace.json" + filename = f"{action_name}_{self.local_rank}_trace.json" path_to_trace = filename if self.path_to_export_trace is None \ else os.path.join(self.path_to_export_trace, filename) function_events.export_chrome_trace(path_to_trace) @@ -283,10 +263,4 @@ def summary(self) -> str: data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes) table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit) recorded_stats[action_name] = table - - # log to standard out - output_string = f"{os.linesep}Profiler Report{os.linesep}" - for action, stats in recorded_stats.items(): - output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}") - - return output_string + return self._stats_to_str(recorded_stats) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 20c842939fe171..da41b9855b44ad 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -15,6 +15,7 @@ import torch from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden @@ -99,6 +100,10 @@ def on_evaluation_end(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_end', *args, **kwargs) + if self.trainer.state != TrainerState.FITTING: + # summarize profile results + self.trainer.profiler.describe() + def reload_evaluation_dataloaders(self): model = self.trainer.lightning_module if self.trainer.testing: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 384a1b67a64f84..cc471f76b60334 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -137,9 +137,7 @@ def on_train_end(self): self.trainer.logger.finalize("success") # summarize profile results - # todo (tchaton) All ranks should call describe. - if self.trainer.global_rank == 0: - self.trainer.profiler.describe() + self.trainer.profiler.describe() # give accelerators a chance to finish self.trainer.accelerator.on_train_end() diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index f449a37e33c25a..0c5f581d7775c2 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -20,6 +20,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.profiler import BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache from tests.deprecated_api import no_deprecated_call from tests.helpers import BoringModel @@ -203,3 +204,12 @@ def on_test_epoch_end(self, outputs): model = NewSignatureModel() with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."): trainer.test(model) + + +@pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) +def test_v1_5_0_profiler_output_filename(tmpdir, cls): + filepath = str(tmpdir / "test.txt") + with pytest.deprecated_call(match="`output_filename` parameter has been removed"): + profiler = cls(output_filename=filepath) + assert profiler.dirpath == tmpdir + assert profiler.filename == "test" diff --git a/tests/test_profiler.py b/tests/test_profiler.py index cc4fff3b7ede49..cf6afcc9b626c1 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -14,6 +14,7 @@ import logging import os import time +from copy import deepcopy from distutils.version import LooseVersion from pathlib import Path @@ -21,8 +22,7 @@ import pytest import torch -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import Callback +from pytorch_lightning import Callback, Trainer from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -46,8 +46,7 @@ def _sleep_generator(durations): @pytest.fixture def simple_profiler(): - profiler = SimpleProfiler() - return profiler + return SimpleProfiler() @pytest.mark.parametrize(["action", "expected"], [ @@ -93,14 +92,6 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5): assert all(durations < PROFILER_OVERHEAD_MAX_TOLERANCE) -def test_simple_profiler_describe(caplog, simple_profiler): - """Ensure the profiler won't fail when reporting the summary.""" - with caplog.at_level(logging.INFO): - simple_profiler.describe() - - assert "Profiler Report" in caplog.text - - def test_simple_profiler_value_errors(simple_profiler): """Ensure errors are raised where expected.""" @@ -116,10 +107,75 @@ def test_simple_profiler_value_errors(simple_profiler): simple_profiler.stop(action) +def test_simple_profiler_deepcopy(tmpdir): + simple_profiler = SimpleProfiler(dirpath=tmpdir, filename="test") + simple_profiler.describe() + assert deepcopy(simple_profiler) + + +def test_simple_profiler_log_dir(tmpdir): + """Ensure the profiler dirpath defaults to `trainer.log_dir` when not present""" + profiler = SimpleProfiler(filename="profiler") + assert profiler._log_dir is None + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + profiler=profiler, + ) + trainer.fit(model) + + expected = profiler.dirpath + assert trainer.log_dir == expected + assert profiler._log_dir == trainer.log_dir + assert Path(os.path.join(profiler.dirpath, "fit-profiler.txt")).exists() + + +@RunIf(skip_windows=True) +def test_simple_profiler_distributed_files(tmpdir): + """Ensure the proper files are saved in distributed""" + profiler = SimpleProfiler(dirpath=tmpdir, filename='profiler') + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=2, + accelerator="ddp_cpu", + num_processes=2, + profiler=profiler, + logger=False, + ) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + + actual = set(os.listdir(profiler.dirpath)) + expected = {f"{stage}-profiler-{rank}.txt" for stage in ("fit", "validate", "test") for rank in (0, 1)} + assert actual == expected + + for f in profiler.dirpath.listdir(): + assert f.read_text('utf-8') + + +def test_simple_profiler_logs(tmpdir, caplog, simple_profiler): + """Ensure that the number of printed logs is correct""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=2, + profiler=simple_profiler, + logger=False, + ) + with caplog.at_level(logging.INFO, logger="pytorch_lightning.profiler.profilers"): + trainer.fit(model) + trainer.test(model) + + assert caplog.text.count("Profiler Report") == 2 + + @pytest.fixture def advanced_profiler(tmpdir): - profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt")) - return profiler + return AdvancedProfiler(dirpath=tmpdir, filename="profiler") @pytest.mark.parametrize(["action", "expected"], [ @@ -180,7 +236,8 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler): pass # log to stdout and print to file advanced_profiler.describe() - data = Path(advanced_profiler.output_fname).read_text() + path = advanced_profiler.dirpath / f"{advanced_profiler.filename}.txt" + data = path.read_text("utf-8") assert len(data) > 0 @@ -195,10 +252,14 @@ def test_advanced_profiler_value_errors(advanced_profiler): advanced_profiler.stop(action) +def test_advanced_profiler_deepcopy(advanced_profiler): + advanced_profiler.describe() + assert deepcopy(advanced_profiler) + + @pytest.fixture def pytorch_profiler(tmpdir): - profiler = PyTorchProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"), local_rank=0) - return profiler + return PyTorchProfiler(dirpath=tmpdir, filename="profiler") def test_pytorch_profiler_describe(pytorch_profiler): @@ -208,7 +269,8 @@ def test_pytorch_profiler_describe(pytorch_profiler): # log to stdout and print to file pytorch_profiler.describe() - data = Path(pytorch_profiler.output_fname).read_text() + path = pytorch_profiler.dirpath / f"{pytorch_profiler.filename}.txt" + data = path.read_text("utf-8") assert len(data) > 0 @@ -223,47 +285,53 @@ def test_pytorch_profiler_value_errors(pytorch_profiler): pytorch_profiler.stop(action) -@RunIf(min_gpus=2, special=True) -@pytest.mark.parametrize("use_output_filename", [False, True]) -def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename): - """Ensure that the profiler can be given to the training and default step are properly recorded. """ - - if use_output_filename: - output_filename = os.path.join(tmpdir, "profiler.txt") - else: - output_filename = None +@RunIf(min_torch="1.6.0") +def test_advanced_profiler_cprofile_deepcopy(tmpdir): + """Checks for pickle issue reported in #6522""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + profiler="advanced", + stochastic_weight_avg=True, + ) + trainer.fit(model) - profiler = PyTorchProfiler(output_filename=output_filename) +@RunIf(min_gpus=2, special=True) +def test_pytorch_profiler_trainer_ddp(tmpdir): + """Ensure that the profiler can be given to the training and default step are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=None, filename="profiler") model = BoringModel() trainer = Trainer( - fast_dev_run=True, - profiler=profiler, + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + profiler=pytorch_profiler, accelerator="ddp", gpus=2, ) trainer.fit(model) - enabled = use_output_filename or not use_output_filename and profiler.local_rank == 0 + assert len(pytorch_profiler.summary()) > 0 + assert set(pytorch_profiler.profiled_actions) == {'training_step_and_backward', 'validation_step'} - if enabled: - assert len(profiler.summary()) > 0 - assert set(profiler.profiled_actions.keys()) == {'training_step_and_backward', 'validation_step'} - else: - assert profiler.summary() is None - assert set(profiler.profiled_actions.keys()) == set() + files = sorted(f for f in os.listdir(pytorch_profiler.dirpath) if "fit" in f) + rank = int(os.getenv("LOCAL_RANK", "0")) + expected = f"fit-profiler-{rank}.txt" + assert files[rank] == expected - # todo (tchaton) add support for all ranks - if use_output_filename and os.getenv("LOCAL_RANK") == "0": - data = Path(profiler.output_fname).read_text() - assert len(data) > 0 + path = os.path.join(pytorch_profiler.dirpath, expected) + data = Path(path).read_text("utf-8") + assert len(data) > 0 def test_pytorch_profiler_nested(tmpdir): """Ensure that the profiler handles nested context""" pytorch_profiler = PyTorchProfiler( - profiled_functions=["a", "b", "c"], use_cuda=False, output_filename=os.path.join(tmpdir, "profiler.txt") + profiled_functions=["a", "b", "c"], use_cuda=False, dirpath=tmpdir, filename="profiler" ) with pytorch_profiler.profile("a"): @@ -327,13 +395,18 @@ def test_profiler_teardown(tmpdir, cls): class TestCallback(Callback): - def on_fit_end(self, trainer, pl_module) -> None: - assert trainer.profiler.output_file is not None - - profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt")) + def on_fit_end(self, trainer, *args, **kwargs) -> None: + # describe sets it to None + assert trainer.profiler._output_file is None + profiler = cls(dirpath=tmpdir, filename="profiler") model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()]) trainer.fit(model) - assert profiler.output_file is None + assert profiler._output_file is None + + +def test_pytorch_profiler_deepcopy(pytorch_profiler): + pytorch_profiler.describe() + assert deepcopy(pytorch_profiler) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 4dc5b5f34b50ca..3eb0596b55fc40 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -13,7 +13,6 @@ # limitations under the License. from pytorch_lightning import Trainer -from tests.accelerators import DDPLauncher from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -84,8 +83,7 @@ def test_get_model_gpu(tmpdir): @RunIf(min_gpus=1, skip_windows=True) -@DDPLauncher.run("--accelerator [accelerator]", max_epochs=["1"], accelerator=["ddp", "ddp_spawn"]) -def test_get_model_ddp_gpu(tmpdir, args=None): +def test_get_model_ddp_gpu(tmpdir): """ Tests that `trainer.lightning_module` extracts the model correctly when using GPU + ddp accelerators """ @@ -99,7 +97,6 @@ def test_get_model_ddp_gpu(tmpdir, args=None): limit_val_batches=2, max_epochs=1, gpus=1, - accelerator=args.accelerator ) trainer.fit(model) return 1 From a74909affa0535da02e64b94f6d5f9b2da03c08f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Mar 2021 16:05:32 +0100 Subject: [PATCH 15/22] prune metrics: info retrieval (#6649) --- CHANGELOG.md | 2 - pytorch_lightning/metrics/__init__.py | 1 - .../metrics/functional/__init__.py | 1 - .../functional/ir_average_precision.py | 54 ------- .../metrics/retrieval/__init__.py | 15 -- .../retrieval/mean_average_precision.py | 61 -------- .../metrics/retrieval/retrieval_metric.py | 140 ------------------ tests/metrics/functional/test_retrieval.py | 36 ----- tests/metrics/retrieval/__init__.py | 0 tests/metrics/retrieval/test_map.py | 119 --------------- 10 files changed, 429 deletions(-) delete mode 100644 pytorch_lightning/metrics/functional/ir_average_precision.py delete mode 100644 pytorch_lightning/metrics/retrieval/__init__.py delete mode 100644 pytorch_lightning/metrics/retrieval/mean_average_precision.py delete mode 100644 pytorch_lightning/metrics/retrieval/retrieval_metric.py delete mode 100644 tests/metrics/functional/test_retrieval.py delete mode 100644 tests/metrics/retrieval/__init__.py delete mode 100644 tests/metrics/retrieval/test_map.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 32cf9122efe340..81bfa85cc073fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added `RetrievalMAP` metric, the corresponding functional version `retrieval_average_precision` and a generic superclass for retrieval metrics `RetrievalMetric` ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032)) - - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 500689f3182fb3..1da24737a3752f 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -39,7 +39,6 @@ R2Score, SSIM, ) -from pytorch_lightning.metrics.retrieval import RetrievalMAP # noqa: F401 warn( "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package" diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 1701389cd1c64f..3b31dad5d3411d 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -28,7 +28,6 @@ from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401 from pytorch_lightning.metrics.functional.image_gradients import image_gradients # noqa: F401 from pytorch_lightning.metrics.functional.iou import iou # noqa: F401 -from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision # noqa: F401 from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error # noqa: F401 from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401 from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401 diff --git a/pytorch_lightning/metrics/functional/ir_average_precision.py b/pytorch_lightning/metrics/functional/ir_average_precision.py deleted file mode 100644 index 83b14a21c5553d..00000000000000 --- a/pytorch_lightning/metrics/functional/ir_average_precision.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - - -def retrieval_average_precision(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - r""" - Computes average precision (for information retrieval), as explained - `here `_. - - `preds` and `target` should be of the same shape and live on the same device. If no `target` is ``True``, - 0 is returned. Target must be of type `bool` or `int`, otherwise an error is raised. - - Args: - preds: estimated probabilities of each document to be relevant. - target: ground truth about each document being relevant or not. Requires `bool` or `int` tensor. - - Return: - a single-value tensor with the average precision (AP) of the predictions `preds` wrt the labels `target`. - - Example: - >>> preds = torch.tensor([0.2, 0.3, 0.5]) - >>> target = torch.tensor([True, False, True]) - >>> retrieval_average_precision(preds, target) - tensor(0.8333) - """ - - if preds.shape != target.shape or preds.device != target.device: - raise ValueError("`preds` and `target` must have the same shape and live on the same device") - - if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64): - raise ValueError("`target` must be a tensor of booleans or integers") - - if target.dtype is not torch.bool: - target = target.bool() - - if target.sum() == 0: - return torch.tensor(0, device=preds.device) - - target = target[torch.argsort(preds, dim=-1, descending=True)] - positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0] - res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean() - return res diff --git a/pytorch_lightning/metrics/retrieval/__init__.py b/pytorch_lightning/metrics/retrieval/__init__.py deleted file mode 100644 index c5c12b3b6643c9..00000000000000 --- a/pytorch_lightning/metrics/retrieval/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401 -from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401 diff --git a/pytorch_lightning/metrics/retrieval/mean_average_precision.py b/pytorch_lightning/metrics/retrieval/mean_average_precision.py deleted file mode 100644 index 956a53cca2e778..00000000000000 --- a/pytorch_lightning/metrics/retrieval/mean_average_precision.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch - -from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision -from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric - - -class RetrievalMAP(RetrievalMetric): - r""" - Computes `Mean Average Precision - `_. - - Works with binary data. Accepts integer or float predictions from a model output. - - Forward accepts - - ``indexes`` (long tensor): ``(N, ...)`` - - ``preds`` (float tensor): ``(N, ...)`` - - ``target`` (long or bool tensor): ``(N, ...)`` - - `indexes`, `preds` and `target` must have the same dimension. - `indexes` indicate to which query a prediction belongs. - Predictions will be first grouped by indexes and then MAP will be computed as the mean - of the Average Precisions over each query. - - Args: - query_without_relevant_docs: - Specify what to do with queries that do not have at least a positive target. Choose from: - - - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned - - ``'error'``: raise a ``ValueError`` - - ``'pos'``: score on those queries is counted as ``1.0`` - - ``'neg'``: score on those queries is counted as ``0.0`` - exclude: - Do not take into account predictions where the target is equal to this value. default `-100` - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects - the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None - - Example: - >>> from pytorch_lightning.metrics import RetrievalMAP - >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) - >>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) - >>> target = torch.tensor([False, False, True, False, True, False, False]) - - >>> map = RetrievalMAP() - >>> map(indexes, preds, target) - tensor(0.7500) - >>> map.compute() - tensor(0.7500) - """ - - def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - valid_indexes = target != self.exclude - return retrieval_average_precision(preds[valid_indexes], target[valid_indexes]) diff --git a/pytorch_lightning/metrics/retrieval/retrieval_metric.py b/pytorch_lightning/metrics/retrieval/retrieval_metric.py deleted file mode 100644 index 6f9088d00083cf..00000000000000 --- a/pytorch_lightning/metrics/retrieval/retrieval_metric.py +++ /dev/null @@ -1,140 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional - -import torch -from torchmetrics import Metric - -from pytorch_lightning.metrics.utils import get_group_indexes - -#: get_group_indexes is used to group predictions belonging to the same query - -IGNORE_IDX = -100 - - -class RetrievalMetric(Metric, ABC): - r""" - Works with binary data. Accepts integer or float predictions from a model output. - - Forward accepts - - ``indexes`` (long tensor): ``(N, ...)`` - - ``preds`` (float or int tensor): ``(N, ...)`` - - ``target`` (long or bool tensor): ``(N, ...)`` - - `indexes`, `preds` and `target` must have the same dimension and will be flatten - to single dimension once provided. - - `indexes` indicate to which query a prediction belongs. - Predictions will be first grouped by indexes. Then the - real metric, defined by overriding the `_metric` method, - will be computed as the mean of the scores over each query. - - Args: - query_without_relevant_docs: - Specify what to do with queries that do not have at least a positive target. Choose from: - - - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned - - ``'error'``: raise a ``ValueError`` - - ``'pos'``: score on those queries is counted as ``1.0`` - - ``'neg'``: score on those queries is counted as ``0.0`` - exclude: - Do not take into account predictions where the target is equal to this value. default `-100` - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects - the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None - - """ - - def __init__( - self, - query_without_relevant_docs: str = 'skip', - exclude: int = IGNORE_IDX, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - dist_sync_fn: Callable = None - ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn - ) - - query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg') - if query_without_relevant_docs not in query_without_relevant_docs_options: - raise ValueError( - f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}. " - f"Allowed values are {query_without_relevant_docs_options}" - ) - - self.query_without_relevant_docs = query_without_relevant_docs - self.exclude = exclude - - self.add_state("idx", default=[], dist_reduce_fx=None) - self.add_state("preds", default=[], dist_reduce_fx=None) - self.add_state("target", default=[], dist_reduce_fx=None) - - def update(self, idx: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> None: - if not (idx.shape == target.shape == preds.shape): - raise ValueError("`idx`, `preds` and `target` must be of the same shape") - - idx = idx.to(dtype=torch.int64).flatten() - preds = preds.to(dtype=torch.float32).flatten() - target = target.to(dtype=torch.int64).flatten() - - self.idx.append(idx) - self.preds.append(preds) - self.target.append(target) - - def compute(self) -> torch.Tensor: - r""" - First concat state `idx`, `preds` and `target` since they were stored as lists. After that, - compute list of groups that will help in keeping together predictions about the same query. - Finally, for each group compute the `_metric` if the number of positive targets is at least - 1, otherwise behave as specified by `self.query_without_relevant_docs`. - """ - - idx = torch.cat(self.idx, dim=0) - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - - res = [] - kwargs = {'device': idx.device, 'dtype': torch.float32} - - groups = get_group_indexes(idx) - for group in groups: - - mini_preds = preds[group] - mini_target = target[group] - - if not mini_target.sum(): - if self.query_without_relevant_docs == 'error': - raise ValueError( - f"`{self.__class__.__name__}.compute()` was provided with " - f"a query without positive targets, indexes: {group}" - ) - if self.query_without_relevant_docs == 'pos': - res.append(torch.tensor(1.0, **kwargs)) - elif self.query_without_relevant_docs == 'neg': - res.append(torch.tensor(0.0, **kwargs)) - else: - res.append(self._metric(mini_preds, mini_target)) - - if len(res) > 0: - return torch.stack(res).mean() - return torch.tensor(0.0, **kwargs) - - @abstractmethod - def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - r""" - Compute a metric over a predictions and target of a single group. - This method should be overridden by subclasses. - """ diff --git a/tests/metrics/functional/test_retrieval.py b/tests/metrics/functional/test_retrieval.py deleted file mode 100644 index a0573cba1d27ed..00000000000000 --- a/tests/metrics/functional/test_retrieval.py +++ /dev/null @@ -1,36 +0,0 @@ -import math - -import numpy as np -import pytest -import torch -from sklearn.metrics import average_precision_score as sk_average_precision - -from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - pytest.param(sk_average_precision, retrieval_average_precision), -]) -def test_against_sklearn(sklearn_metric, torch_metric): - """Compare PL metrics to sklearn version. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - seed_everything(0) - - rounds = 25 - sizes = [1, 4, 10, 100] - - for size in sizes: - for _ in range(rounds): - a = np.random.randn(size) - b = np.random.randn(size) > 0 - - sk = torch.tensor(sklearn_metric(b, a), device=device) - pl = torch_metric(torch.tensor(a, device=device), torch.tensor(b, device=device)) - - # `torch_metric`s return 0 when no label is True - # while `sklearn.average_precision_score` returns NaN - if math.isnan(sk): - assert pl == 0 - else: - assert torch.allclose(sk.float(), pl.float()) diff --git a/tests/metrics/retrieval/__init__.py b/tests/metrics/retrieval/__init__.py deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/tests/metrics/retrieval/test_map.py b/tests/metrics/retrieval/test_map.py deleted file mode 100644 index fe43f19b20eb67..00000000000000 --- a/tests/metrics/retrieval/test_map.py +++ /dev/null @@ -1,119 +0,0 @@ -import math -import random -from typing import Callable, List - -import numpy as np -import pytest -import torch -from sklearn.metrics import average_precision_score as sk_average_precision -from torchmetrics import Metric - -from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_class_metric'], [ - [sk_average_precision, RetrievalMAP], -]) -def test_against_sklearn(sklearn_metric: Callable, torch_class_metric: Metric) -> None: - """Compare PL metrics to sklearn version. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - seed_everything(0) - - rounds = 20 - sizes = [1, 4, 10, 100] - batch_sizes = [1, 4, 10] - query_without_relevant_docs_options = ['skip', 'pos', 'neg'] - - def compute_sklearn_metric(target: List[np.ndarray], preds: List[np.ndarray], behaviour: str) -> torch.Tensor: - """ Compute sk metric with multiple iterations using the base `sklearn_metric`. """ - sk_results = [] - kwargs = {'device': device, 'dtype': torch.float32} - - for b, a in zip(target, preds): - res = sklearn_metric(b, a) - - if math.isnan(res): - if behaviour == 'skip': - pass - elif behaviour == 'pos': - sk_results.append(torch.tensor(1.0, **kwargs)) - else: - sk_results.append(torch.tensor(0.0, **kwargs)) - else: - sk_results.append(torch.tensor(res, **kwargs)) - if len(sk_results) > 0: - sk_results = torch.stack(sk_results).mean() - else: - sk_results = torch.tensor(0.0, **kwargs) - - return sk_results - - def do_test(batch_size: int, size: int) -> None: - """ For each possible behaviour of the metric, check results are correct. """ - for behaviour in query_without_relevant_docs_options: - metric = torch_class_metric(query_without_relevant_docs=behaviour) - shape = (size, ) - - indexes = [] - preds = [] - target = [] - - for i in range(batch_size): - indexes.append(np.ones(shape, dtype=int) * i) - preds.append(np.random.randn(*shape)) - target.append(np.random.randn(*shape) > 0) - - sk_results = compute_sklearn_metric(target, preds, behaviour) - - indexes_tensor = torch.cat([torch.tensor(i) for i in indexes]) - preds_tensor = torch.cat([torch.tensor(p) for p in preds]) - target_tensor = torch.cat([torch.tensor(t) for t in target]) - - # lets assume data are not ordered - perm = torch.randperm(indexes_tensor.nelement()) - indexes_tensor = indexes_tensor.view(-1)[perm].view(indexes_tensor.size()) - preds_tensor = preds_tensor.view(-1)[perm].view(preds_tensor.size()) - target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size()) - - # shuffle ids to require also sorting of documents ability from the lightning metric - pl_result = metric(indexes_tensor, preds_tensor, target_tensor) - - assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=True) - - for batch_size in batch_sizes: - for size in sizes: - for _ in range(rounds): - do_test(batch_size, size) - - -@pytest.mark.parametrize(['torch_class_metric'], [ - [RetrievalMAP], -]) -def test_input_data(torch_class_metric: Metric) -> None: - """Check PL metrics inputs are controlled correctly. """ - - device = 'cuda' if torch.cuda.is_available() else 'cpu' - seed_everything(0) - - for _ in range(10): - - length = random.randint(0, 20) - - # check error when `query_without_relevant_docs='error'` is raised correctly - indexes = torch.tensor([0] * length, device=device, dtype=torch.int64) - preds = torch.rand(size=(length, ), device=device, dtype=torch.float32) - target = torch.tensor([False] * length, device=device, dtype=torch.bool) - - metric = torch_class_metric(query_without_relevant_docs='error') - - try: - metric(indexes, preds, target) - except Exception as e: - assert isinstance(e, ValueError) - - # check ValueError with non-accepted argument - try: - metric = torch_class_metric(query_without_relevant_docs='casual_argument') - except Exception as e: - assert isinstance(e, ValueError) From 0995d30fab0590d155895a77535663794118b5f6 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 23 Mar 2021 15:13:13 +0000 Subject: [PATCH 16/22] Flash predict step (#6577) * add predict_step * Update predict_loop.py * Update trainer.py * Update trainer.py * resolve bugs * update * update * update * resolve bug * resolve some failing tests * udpate tests * update * resolve tests * add a test * remove typo * add a test for attachement * update * changed to on_train_dataloader * remove __flash_special_attr__ * resolve tests * update * update * update * update on comments * Update pytorch_lightning/trainer/data_loading.py Co-authored-by: Jirka Borovec Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Jirka Borovec --- docs/source/starter/introduction_guide.rst | 6 +- pytorch_lightning/accelerators/accelerator.py | 11 ++- pytorch_lightning/core/hooks.py | 24 +++++++ pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/overrides/base.py | 2 +- .../plugins/training_type/ddp.py | 2 +- .../plugins/training_type/ddp_spawn.py | 2 +- pytorch_lightning/plugins/training_type/dp.py | 2 +- .../plugins/training_type/tpu_spawn.py | 4 +- .../training_type/training_type_plugin.py | 4 +- .../trainer/connectors/data_connector.py | 4 ++ pytorch_lightning/trainer/data_loading.py | 15 ++-- pytorch_lightning/trainer/predict_loop.py | 12 +++- pytorch_lightning/trainer/trainer.py | 8 ++- tests/overrides/test_data_parallel.py | 2 +- tests/trainer/test_dataloaders.py | 68 +++++++++++++++++++ tests/trainer/test_trainer.py | 36 +++++++++- 17 files changed, 174 insertions(+), 30 deletions(-) diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index c65894367a39e8..551b8182caa7d1 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -882,8 +882,8 @@ Or maybe we have a model that we use to do generation generated_imgs = model(z) -To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict`` function -By default, LightningModule ``predict`` calls forward, but it can be overriden to add any processing logic. +To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict_step`` function +By default, LightningModule ``predict_step`` calls forward, but it can be overriden to add any processing logic. .. code-block:: python @@ -893,7 +893,7 @@ By default, LightningModule ``predict`` calls forward, but it can be overriden t imgs = self.decoder(z) return imgs - def predict(self, batch, batch_idx: int , dataloader_idx: int = None): + def predict_step(self, batch, batch_idx: int , dataloader_idx: int = None): return self(batch) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 60e6ea88b4250d..9ea2cec491d2c5 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -219,7 +219,7 @@ def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context(): return self.training_type_plugin.test_step(*args) - def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: + def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: """The actual predict step. Args: @@ -235,7 +235,7 @@ def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: args[0] = batch with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context(): - return self.training_type_plugin.predict(*args) + return self.training_type_plugin.predict_step(*args) def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: """A hook to do something at the end of the training step @@ -359,7 +359,12 @@ def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None: def to_device(self, batch: Any) -> Any: """Pushes the batch to the root device""" - return self.batch_to_device(batch, self.root_device) + # Todo (tchaton) Better fix + is_dict = isinstance(batch, dict) + if is_dict: + batch = [batch] + batch = self.batch_to_device(batch, self.root_device) + return batch[0] if is_dict else batch @property def amp_backend(self) -> Optional[LightningEnum]: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 9624f94652713b..8c68cc96eabc23 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -282,6 +282,18 @@ def on_test_end(self) -> None: """ # do something at the end of testing + def on_predict_start(self) -> None: + """ + Called at the beginning of predicting. + """ + # do something at the start of predicting + + def on_predict_end(self) -> None: + """ + Called at the end of predicting. + """ + # do something at the end of predicting + def on_before_zero_grad(self, optimizer: Optimizer) -> None: """ Called after optimizer.step() and before optimizer.zero_grad(). @@ -594,6 +606,18 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: will have an argument ``dataloader_idx`` which matches the order here. """ + def on_train_dataloader(self) -> None: + """Called before requesting the train dataloader.""" + + def on_val_dataloader(self) -> None: + """Called before requesting the val dataloader.""" + + def on_test_dataloader(self) -> None: + """Called before requesting the test dataloader.""" + + def on_predict_dataloader(self) -> None: + """Called before requesting the predict dataloader.""" + def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: """ Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 044dd95f3b8c66..4d36fe48448dca 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1054,7 +1054,7 @@ def test_epoch_end(self, outputs): self.log('final_metric', final_value) """ - def predict(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): """ Use this function with trainer.predict(...). Override if you need to add any processing logic. """ diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 1d6f4e93b5779d..0c1ac7b359fd0a 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -53,7 +53,7 @@ def forward(self, *inputs, **kwargs): elif trainer and (trainer.sanity_checking or trainer.validating): output = self.module.validation_step(*inputs, **kwargs) elif trainer and trainer.predicting: - output = self.module.predict(*inputs, **kwargs) + output = self.module.predict_step(*inputs, **kwargs) else: output = self.module(*inputs, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index bcadf16607b4fe..58e26e7db32d85 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -298,7 +298,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) - def predict(self, *args, **kwargs): + def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def post_training_step(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index ea1efd6e158734..87d7fa5faecac5 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -282,7 +282,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) - def predict(self, *args, **kwargs): + def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def post_training_step(self): diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index b96b7097d07c7e..a8e42e0fa747af 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -83,7 +83,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) - def predict(self, *args, **kwargs): + def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def training_step_end(self, output): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index c883ff504f24d7..3887e0cd98908f 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -294,8 +294,8 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.lightning_module.test_step(*args, **kwargs) - def predict(self, *args, **kwargs): - return self.lightning_module.predict(*args, **kwargs) + def predict_step(self, *args, **kwargs): + return self.lightning_module.predict_step(*args, **kwargs) def save_checkpoint(self, filepath, weights_only: bool = False): """Save model/training states as a checkpoint file through state-dump and file-write. diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 89f27963caadfd..08dca63a7c9250 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -154,8 +154,8 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.lightning_module.test_step(*args, **kwargs) - def predict(self, *args, **kwargs): - return self.lightning_module.predict(*args, **kwargs) + def predict_step(self, *args, **kwargs): + return self.lightning_module.predict_step(*args, **kwargs) def training_step_end(self, output): return output diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b3fc0b4eb7b297..5d2f141dc64a83 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -150,6 +150,10 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N self.trainer.datamodule = datamodule datamodule.trainer = self.trainer + # experimental feature for Flash + if hasattr(datamodule, "data_pipeline"): + model.data_pipeline = datamodule.data_pipeline + class _PatchDataLoader(object): r""" diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 56da7039bbca75..1a9c69d107b972 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -16,7 +16,7 @@ import platform from abc import ABC from copy import deepcopy -from typing import Callable, Iterable, List, Tuple, Union +from typing import Iterable, List, Tuple, Union from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -191,7 +191,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: Args: model: The current `LightningModule` """ - self.train_dataloader = self.request_dataloader(model.train_dataloader) + self.train_dataloader = self.request_dataloader(model, "train") if self.overfit_batches > 0: if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler): @@ -271,7 +271,7 @@ def _reset_eval_dataloader( """ # always get the loaders first so we can count how many there are loader_name = f'{mode}_dataloader' - dataloaders = self.request_dataloader(getattr(model, loader_name)) + dataloaders = self.request_dataloader(model, mode) if not isinstance(dataloaders, list): dataloaders = [dataloaders] @@ -280,7 +280,7 @@ def _reset_eval_dataloader( # duplicate it the numb of times needed to match the train loaders if self.overfit_batches > 0: num_loaders = len(dataloaders) - train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader')) + train_dataloader = self.request_dataloader(model, 'train') dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders) @@ -380,7 +380,7 @@ def reset_predict_dataloader(self, model) -> None: if has_loader: self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict') - def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: + def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader: """Handles downloading data in the GPU or TPU case. Args: @@ -389,9 +389,10 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: Returns: The dataloader """ - dataloader = dataloader_fx() + if model.trainer is not None: + model.trainer.call_hook(f"on_{stage}_dataloader") + dataloader: DataLoader = getattr(model, f'{stage}_dataloader')() dataloader = self._flatten_dl_only(dataloader) - self.accelerator.barrier('get_dataloaders') return dataloader diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 70329b4fdf514c..53e82fd3f62b3d 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -65,7 +65,7 @@ def _get_num_dataloaders(self, dataloaders): length = len(dataloaders[0]) return length - def predict(self, batch, batch_idx, dataloader_idx): + def predict_step(self, batch, batch_idx, dataloader_idx): # configure args args = [batch, batch_idx] if self.num_dataloaders: @@ -74,7 +74,7 @@ def predict(self, batch, batch_idx, dataloader_idx): model_ref = self.trainer.lightning_module model_ref._current_fx_name = "predict" - predictions = self.trainer.accelerator.predict(args) + predictions = self.trainer.accelerator.predict_step(args) if predictions is None: self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") @@ -99,3 +99,11 @@ def _convert_to_numpy(v): return results[0] return results + + def on_predict_start(self): + # hook + self.trainer.call_hook("on_predict_start") + + def on_predict_end(self): + # hook + self.trainer.call_hook("on_predict_end") diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f7bd1757b9bc21..bb5d6919964e53 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -762,6 +762,8 @@ def run_evaluate(self): return eval_loop_results def run_predict(self): + self.predict_loop.on_predict_start() + # prepare dataloaders dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() @@ -784,7 +786,6 @@ def run_predict(self): for dataloader_idx, dataloader in enumerate(dataloaders): dataloader = self.accelerator.process_dataloader(dataloader) dl_max_batches = self.predict_loop.max_batches[dataloader_idx] - for batch_idx, batch in enumerate(dataloader): if batch is None: continue @@ -794,10 +795,11 @@ def run_predict(self): break # lightning module methods - with self.profiler.profile("predict"): - self.predict_loop.predict(batch, batch_idx, dataloader_idx) + with self.profiler.profile("predict_step"): + self.predict_loop.predict_step(batch, batch_idx, dataloader_idx) results = self.predict_loop.on_predict_epoch_end() + self.predict_loop.on_predict_end() return results def run_sanity_check(self, ref_model): diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 3921e7ef33b8e1..aaf47c82d5f087 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -24,7 +24,7 @@ ("training", "training_step"), ("testing", "test_step"), ("validating", "validation_step"), - ("predicting", "predict"), + ("predicting", "predict_step"), ] ) def test_lightning_wrapper_module_methods(wrapper_class, stage): diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 52c51777e2a893..505af173b79108 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1159,3 +1159,71 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir): new_data_loader = trainer.replace_sampler(train, SequentialSampler(train.dataset)) assert (new_data_loader.multiprocessing_context == train.multiprocessing_context) + + +def test_request_dataloader(tmpdir): + """ + This test asserts dataloader can be modified and properly set to the trainer. + """ + + class DataLoaderWrapper: + + def __init__(self, loader): + self.loader = loader + self._iter = iter(self.loader) + + def __iter__(self): + self._iter = iter(self.loader) + return self._iter + + def __next__(self): + return next(self._iter) + + class DataLoaderFunc: + + def __init__(self, loader): + self.loader = loader + + def __call__(self): + return self.loader + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.on_train_dataloader_called = False + self.on_train_batch_start_called = False + self.on_val_dataloader_called = False + self.on_val_batch_start_called = False + + def on_train_dataloader(self) -> None: + loader = self.train_dataloader() + self.train_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) + self.on_train_dataloader_called = True + + def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: + assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper) + self.on_train_batch_start_called = True + + def on_val_dataloader(self) -> None: + loader = self.val_dataloader() + self.val_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) + self.on_val_dataloader_called = True + + def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: + assert isinstance(self.trainer.val_dataloaders[0], DataLoaderWrapper) + self.on_val_batch_start_called = True + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + model = TestModel() + trainer.fit(model) + trainer.test(model) + assert model.on_train_dataloader_called + assert model.on_train_batch_start_called + assert model.on_val_dataloader_called + assert model.on_val_batch_start_called diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 66889bb7e11390..d461d9d152e743 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1440,11 +1440,11 @@ def test_trainer_predict_no_return(tmpdir): class CustomBoringModel(BoringModel): - def predict(self, batch, batch_idx, dataloader_idx=None): + def predict_step(self, batch, batch_idx, dataloader_idx=None): if (batch_idx + 1) % 2 == 0: return - return super().predict(batch, batch_idx, dataloader_idx) + return super().predict_step(batch, batch_idx, dataloader_idx) with pytest.warns(UserWarning, match='predict returned None'): predict(tmpdir, None, None, 1, model=CustomBoringModel()) @@ -1731,3 +1731,35 @@ def test_check_val_every_n_epoch_exception(tmpdir): max_epochs=1, check_val_every_n_epoch=1.2, ) + + +def test_trainer_attach_data_pipeline_to_model(tmpdir): + + class DataPipeline: + + pass + + class TestDataModule(LightningDataModule): + + data_pipeline = DataPipeline() + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + class TestCallback(Callback): + + def on_fit_start(self, trainer, pl_module: LightningModule) -> None: + """Called when fit begins""" + assert isinstance(pl_module.data_pipeline, DataPipeline) + + model = BoringModel() + dm = TestDataModule() + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()]) + trainer.fit(model, datamodule=dm) From 3cf0c3117a6c0ddff9bef5a216cad1cb4af5b6e6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Mar 2021 17:41:36 +0100 Subject: [PATCH 17/22] fix back-compatibility for Accel (#6655) --- pytorch_lightning/accelerators/accelerator.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 9ea2cec491d2c5..4aa5fedf2b210c 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -21,6 +21,7 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum @@ -437,3 +438,27 @@ def results(self) -> Any: In distributed training, we make sure to transfer the results to the appropriate master process. """ return self.training_type_plugin.results + + # todo: remove in v1.5 + def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + """ + Attaches the training type plugin to the accelerator. + Also transfers ownership of the model to this plugin + + .. deprecated::v1.3 + Will be removed in v1.5.0. + """ + rank_zero_warn('Accelerator method `connect_training_type_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.') + self.setup_training_type_plugin(plugin, model) + + # todo: remove in v1.5 + def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: + """Attaches the precision plugin to the accelerator + + .. deprecated::v1.3 + Will be removed in v1.5.0. + """ + rank_zero_warn('Accelerator method `connect_precision_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.') + self.setup_precision_plugin(plugin) From 51b10f78f4b4c4b704219c619dc5e73784aca57b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 23 Mar 2021 18:13:29 +0100 Subject: [PATCH 18/22] Refactor PyTorch profiler 4/5 (#6349) Co-authored-by: thomas chaton --- CHANGELOG.md | 9 + pytorch_lightning/profiler/profilers.py | 12 +- pytorch_lightning/profiler/pytorch.py | 363 +++++++++++------- .../trainer/connectors/profiler_connector.py | 3 +- pytorch_lightning/trainer/predict_loop.py | 4 + pytorch_lightning/trainer/training_loop.py | 2 +- pytorch_lightning/utilities/imports.py | 1 + tests/checkpointing/test_torch_saving.py | 1 + tests/deprecated_api/test_remove_1-5.py | 5 + tests/test_profiler.py | 176 ++++++--- tests/trainer/properties/test_get_model.py | 20 - 11 files changed, 377 insertions(+), 219 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 81bfa85cc073fc..e1106189e0c17a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) +- Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) @@ -72,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) +- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + ### Deprecated - `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) @@ -83,6 +89,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) +- Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 54bc5cdf0122c4..46d72583fb466e 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -126,7 +126,7 @@ def _prepare_filename(self) -> str: filename += f"{self._stage}-" filename += str(self.filename) if self._local_rank is not None: - filename += f"-{self.local_rank}" + filename += f"-{self._local_rank}" filename += ".txt" return filename @@ -134,8 +134,7 @@ def _prepare_streams(self) -> None: if self._write_stream is not None: return if self.filename: - dirpath = self.dirpath or self._log_dir - filepath = os.path.join(dirpath, self._prepare_filename()) + filepath = os.path.join(self.dirpath, self._prepare_filename()) fs = get_filesystem(filepath) file = fs.open(filepath, "a") self._output_file = file @@ -175,8 +174,7 @@ def setup( self._stage = stage self._local_rank = local_rank self._log_dir = log_dir - if self.dirpath is None: - self.dirpath = self._log_dir + self.dirpath = self.dirpath or log_dir def teardown(self, stage: Optional[str] = None) -> None: """ @@ -202,8 +200,8 @@ def summary(self) -> str: raise NotImplementedError @property - def local_rank(self): - return '0' if self._local_rank is None else self._local_rank + def local_rank(self) -> int: + return 0 if self._local_rank is None else self._local_rank class PassThroughProfiler(BaseProfiler): diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 55b1c286789f43..974883a4724c6f 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -12,25 +12,92 @@ # See the License for the specific language governing permissions and # limitations under the License. """Profiler to check if there are any bottlenecks in your code.""" - import inspect import logging import os +from functools import partial from pathlib import Path -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING, Union import torch +from torch import nn, Tensor +from torch.autograd.profiler import record_function from pytorch_lightning.profiler.profilers import BaseProfiler from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +if TYPE_CHECKING: + from torch.autograd.profiler import EventList + from torch.utils.hooks import RemovableHandle + + from pytorch_lightning.core.lightning import LightningModule + log = logging.getLogger(__name__) +_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx] + + +class RegisterRecordFunction: + """ + While profiling autograd operations, this class will add labels for module names around the forward function. + + The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows: + + Example:: + from pytorch_lightning.profilers import PyTorchProfiler + profiler = PyTorchProfiler(record_module_names=False) + Trainer(profiler=profiler) + + It can be used outside of Lightning as follows: + + Example:: + from pytorch_lightning import Trainer, seed_everything + with RegisterRecordFunction(model): + out = model(batch) + """ + + def __init__(self, model: nn.Module) -> None: + self._model = model + self._records: Dict[str, record_function] = {} + self._handles: Dict[str, List['RemovableHandle']] = {} + + def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: + record = record_function(record_name) + record.__enter__() + self._records[record_name] = record + return input + + def _stop_recording_forward(self, _: nn.Module, __: Tensor, output: Tensor, record_name: str) -> Tensor: + self._records[record_name].__exit__(None, None, None) + return output + + def __enter__(self) -> None: + for module_name, module in self._model.named_modules(): + if module_name: + full_name = f"{type(module).__module__}.{type(module).__name__}" + record_name = f"{full_name}: {module_name}" + pre_forward_handle = module.register_forward_pre_hook( + partial(self._start_recording_forward, record_name=record_name) + ) + post_forward_handle = module.register_forward_hook( + partial(self._stop_recording_forward, record_name=record_name) + ) + + self._handles[module_name] = [pre_forward_handle, post_forward_handle] + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + for handles in self._handles.values(): + for h in handles: + h.remove() + self._handles = {} + class PyTorchProfiler(BaseProfiler): - PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step") + RECORD_FUNCTIONS = ( + "training_step_and_backward", "training_step", "backward", "validation_step", "test_step", "predict_step" + ) AVAILABLE_SORT_KEYS = ( "cpu_time", "cuda_time", @@ -42,27 +109,24 @@ class PyTorchProfiler(BaseProfiler): "self_cuda_memory_usage", "count", ) + START_RECORD_FUNCTIONS = ('on_train_start', 'on_validation_start', 'on_test_start', 'on_predict_start') def __init__( self, dirpath: Optional[Union[str, Path]] = None, filename: Optional[str] = None, - enabled: bool = True, - use_cuda: bool = False, - record_shapes: bool = False, - profile_memory: bool = False, group_by_input_shapes: bool = False, - with_stack: bool = False, - use_kineto: bool = False, - use_cpu: bool = True, emit_nvtx: bool = False, - export_to_chrome: bool = False, - path_to_export_trace: str = None, + export_to_chrome: bool = True, + path_to_export_trace: Optional[str] = None, row_limit: int = 20, sort_by_key: Optional[str] = None, + record_functions: List[str] = None, + record_module_names: bool = True, profiled_functions: Optional[List] = None, output_filename: Optional[str] = None, - ): + **profiler_kwargs: Any, + ) -> None: """ This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of different operators inside your model - both on the CPU and GPU @@ -75,24 +139,8 @@ def __init__( filename: If present, filename where the profiler results will be saved instead of printing to stdout. The ``.txt`` extension will be used automatically. - enabled: Setting this to False makes this context manager a no-op. - - use_cuda: Enables timing of CUDA events as well using the cudaEvent API. - Adds approximately 4us of overhead to each tensor operation. - - record_shapes: If shapes recording is set, information about input dimensions will be collected. - - profile_memory: Whether to report memory usage, default: True (Introduced in PyTorch 1.6.0) - group_by_input_shapes: Include operator input shapes and group calls by shape. - with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0) - - use_kineto: experimental support for Kineto profiler (Introduced in PyTorch 1.8.0) - - use_cpu: use_kineto=True and can be used to lower the overhead - for GPU-only profiling (Introduced in PyTorch 1.8.0) - emit_nvtx: Context manager that makes every autograd operation emit an NVTX range Run:: @@ -103,164 +151,189 @@ def __init__( nvvp trace_name.prof torch.autograd.profiler.load_nvprof(path) - export_to_chrome: Wether to export the sequence of profiled operators for Chrome. + export_to_chrome: Whether to export the sequence of profiled operators for Chrome. It will generate a ``.json`` file which can be read by Chrome. path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``. By default, it will be save where the file being is being run. - row_limit: Limit the number of rows in a table, `0` is a special value that + row_limit: Limit the number of rows in a table, ``-1`` is a special value that removes the limit completely. - sort_by_key: Keys to sort out profiled table + sort_by_key: Attribute used to sort entries. By default + they are printed in the same order as they were registered. + Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``, + ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, + ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. - profiled_functions: list of profiled functions which will create a context manager on. + record_functions: list of profiled functions which will create a context manager on. Any other will be pass through. + record_module_names: Whether to add module names while recording autograd operation. + + profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version + Raises: MisconfigurationException: If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``. - ValueError: - If you attempt to stop recording an action which was never started. """ + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) - self.profiled_actions = {} - self.enabled = enabled - self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS - self.use_cuda = use_cuda - self.record_shapes = record_shapes - self.profile_memory = profile_memory - self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total") - self.with_stack = with_stack - self.group_by_input_shapes = group_by_input_shapes and record_shapes - self.use_kineto = use_kineto - self.use_cpu = use_cpu - self.row_limit = row_limit - self.emit_nvtx = emit_nvtx - self.export_to_chrome = export_to_chrome - self.path_to_export_trace = path_to_export_trace - - if export_to_chrome and path_to_export_trace is None: + record_functions = self.__deprecation_check(profiled_functions, record_functions) + + self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False) + self._emit_nvtx = emit_nvtx + self._export_to_chrome = export_to_chrome + self._path_to_export_trace = path_to_export_trace + self._row_limit = row_limit + self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total" + self._record_functions_start = set(record_functions + list(self.START_RECORD_FUNCTIONS)) + self._record_functions = set(record_functions + list(self.RECORD_FUNCTIONS)) + self._record_module_names = record_module_names + self._profiler_kwargs = profiler_kwargs + + self.profiler: Optional[_PROFILER] = None + self.function_events: Optional['EventList'] = None + self._lightning_module: Optional['LightningModule'] = None # set by ProfilerConnector + self._register: Optional[RegisterRecordFunction] = None + self._parent_profiler: Optional[_PROFILER] = None + self._recording_map: Dict[str, record_function] = {} + + if self._export_to_chrome and self._path_to_export_trace is None: rank_zero_warn( - "The exported trace would be save locally as `path_to_export_trace` is empty." + "The exported trace would be saved locally as `path_to_export_trace` is None." " Note: Each functions will generate its own traced file." ) - if self.sort_by_key not in self.AVAILABLE_SORT_KEYS: + if self._sort_by_key not in self.AVAILABLE_SORT_KEYS: raise MisconfigurationException( - f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " + f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " ) - self.profiled_actions = {} - self.context_names = {} - self.running_stack = [] - self.profiler = None + def __deprecation_check( + self, + profiled_functions: Optional[List[str]], + record_functions: Optional[List[str]], + ) -> List[str]: + if record_functions is None: + record_functions = [] - super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + if profiled_functions is not None: + rank_zero_warn( + "`PyTorchProfiler.profiled_functions` has been renamed to" + " `record_functions` in v1.3 and will be removed in v1.5", DeprecationWarning + ) + if not record_functions: + record_functions += profiled_functions + else: + raise MisconfigurationException( + "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`." + " Please use only the later." + ) + + return record_functions def setup( - self, - stage: Optional[str] = None, - local_rank: Optional[int] = None, - log_dir: Optional[str] = None + self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None ) -> None: super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir) # if the user didn't provide `path_to_export_trace`, # set it as TensorBoardLogger log_dir if exists - if self.path_to_export_trace is None: - self.path_to_export_trace = log_dir + if self._path_to_export_trace is None: + self._path_to_export_trace = log_dir def start(self, action_name: str) -> None: - if action_name not in self.profiled_functions: - return - - if len(self.running_stack) > 0: - self._stop(self.running_stack[-1]) - self.running_stack.append(action_name) + if self.profiler is None and action_name in self._record_functions_start: + + # close profiler if it is already opened. might happen if 2 profilers + # are created and the first one did not call `describe` + try: + torch.autograd._disable_profiler() # noqa + except (AttributeError, RuntimeError): + pass + + self._create_profilers() + + self.profiler.__enter__() + if self._parent_profiler is not None: + self._parent_profiler.__enter__() + if self._register is not None: + self._register.__enter__() + + if ( + self.profiler is not None and action_name in self._record_functions + and action_name not in self._recording_map + ): + recording = record_function(action_name) + recording.__enter__() + self._recording_map[action_name] = recording - self.context_names[action_name] = "/".join(self.running_stack) - - self._start(action_name) + def stop(self, action_name: str) -> None: + if action_name in self._recording_map: + self._recording_map[action_name].__exit__(None, None, None) + del self._recording_map[action_name] - def _start(self, action_name: str) -> None: - if self.emit_nvtx: - self._parent_profiler = self._create_profiler(action_name, torch.cuda.profiler.profile, enter=True) - self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx) - else: - self._create_profiler(action_name, torch.autograd.profiler.profile) - - def _create_profiler(self, action_name, profiler, enter=True): - init_args = inspect.signature(profiler.__init__).parameters - profiler_args = {k: v for k, v in vars(self).items() if k in init_args} - pr = profiler(**profiler_args) - if enter: - out_pr = pr.__enter__() - if out_pr is not None: - pr = out_pr - self.profiler = pr - return self.profiler - - def _stop(self, action_name: str) -> None: - if self.profiler is None: - return - - self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None) - - if isinstance(self.profiler, torch.autograd.profiler.emit_nvtx): - # when running ``emit_nvtx``, PyTorch requires 2 context manager. - # The parent_profiler is being closed too. - self._parent_profiler.__exit__(None, None, None) - self._parent_profiler = None - return + def summary(self) -> str: + if not self._profiler_kwargs.get("enabled", True) or self._emit_nvtx: + return "" - function_events = self.profiler.function_events - self.profiler = None - for name in self.running_stack: - if name not in self.profiled_actions: - self.profiled_actions[name] = function_events - else: - self.profiled_actions[name] += function_events + self._delete_profilers() - def stop(self, action_name: str) -> None: - if action_name not in self.profiled_functions: - return + if not self.function_events: + return "" - if len(self.running_stack) == 0 or self.running_stack[-1] != action_name: - raise ValueError( # pragma: no-cover - f"Attempting to stop recording an action ({action_name}) which was never started." + if self._export_to_chrome: + filename = f"{self.local_rank}_trace.json" + path_to_trace = ( + filename if self._path_to_export_trace is None else os.path.join(self._path_to_export_trace, filename) ) - self._stop(action_name) - self.running_stack.pop() - # restore running profiler - if len(self.running_stack) > 0: - self._start(self.running_stack[-1]) + self.function_events.export_chrome_trace(path_to_trace) - def summary(self) -> str: - recorded_stats = {} - output_string = '' + data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes) + table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit) - if not self.enabled: - return output_string + recorded_stats = {"records": table} + return self._stats_to_str(recorded_stats) - for action_name, function_events in self.profiled_actions.items(): + def _create_profilers(self) -> None: + if self._emit_nvtx: + self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile) + self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx) + else: + self._parent_profiler = None + self.profiler = self._create_profiler(torch.autograd.profiler.profile) + if self._record_module_names and self._lightning_module is not None: + self._register = RegisterRecordFunction(self._lightning_module) + + def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: + init_parameters = inspect.signature(profiler.__init__).parameters + kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} + return profiler(**kwargs) + + def _cache_functions_events(self): + if not self._emit_nvtx: + self.function_events = self.profiler.function_events + + def _delete_profilers(self) -> None: + if self.profiler is not None: + self.profiler.__exit__(None, None, None) + self._cache_functions_events() + self.profiler = None + + if self._parent_profiler is not None: + self._parent_profiler.__exit__(None, None, None) + self._parent_profiler = None - # next line is a workaround for a pytorch issue (fixed on master, still present - # on 1.7). Without it the code fails with `AssertionError: There is already a CPU - # parent event for detach` - function_events.populate_cpu_children = lambda: None + if self._register is not None: + self._register.__exit__(None, None, None) + self._register = None - if self.export_to_chrome: - filename = f"{action_name}_{self.local_rank}_trace.json" - path_to_trace = filename if self.path_to_export_trace is None \ - else os.path.join(self.path_to_export_trace, filename) - function_events.export_chrome_trace(path_to_trace) + def teardown(self, stage: Optional[str] = None) -> None: + self._delete_profilers() - if self.emit_nvtx: - return output_string + for k in self._recording_map: + self.stop(k) + self._recording_map = {} - else: - data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes) - table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit) - recorded_stats[action_name] = table - return self._stats_to_str(recorded_stats) + super().teardown(stage=stage) diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index e628d6d96bd199..191e8711463ab0 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License - from typing import Union +from weakref import proxy from pytorch_lightning.profiler import ( AdvancedProfiler, @@ -57,4 +57,5 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]): def setup(self) -> None: trainer = self.trainer local_rank = trainer.local_rank if trainer.world_size > 1 else None + trainer.profiler.lightning_module = proxy(trainer.lightning_module) trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir) diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 53e82fd3f62b3d..b33f41cb2ea487 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -44,6 +44,8 @@ def on_predict_model_eval(self, *_, **__): model_ref.on_predict_model_eval() def setup(self, model, max_batches, dataloaders): + self.trainer.call_hook("on_predict_start") + # copy properties for forward overrides self.trainer.model_connector.copy_trainer_model_properties(model) @@ -86,6 +88,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx): return def on_predict_epoch_end(self): + self.trainer.profiler.describe() + self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.lightning_module) results = self._predictions diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index cc471f76b60334..c3ba34ca66d2d3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -743,7 +743,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, # backward pass if result is not None: - with self.trainer.profiler.profile("model_backward"): + with self.trainer.profiler.profile("backward"): self.backward(result, optimizer, opt_idx) # hook - call this hook only diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 8090c4ed6590f8..5a780660a0a99a 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -68,6 +68,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0") _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") +_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0") _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') diff --git a/tests/checkpointing/test_torch_saving.py b/tests/checkpointing/test_torch_saving.py index c8b1e96aeaf0a0..8eabc4640046f3 100644 --- a/tests/checkpointing/test_torch_saving.py +++ b/tests/checkpointing/test_torch_saving.py @@ -47,6 +47,7 @@ def test_model_torch_save_ddp_cpu(tmpdir): max_epochs=num_epochs, accelerator="ddp_cpu", num_processes=2, + logger=False, ) temp_path = os.path.join(tmpdir, 'temp.pt') trainer.fit(model) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 0c5f581d7775c2..725db1180d9e82 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -81,6 +81,11 @@ def on_save_checkpoint(self, *args): trainer.save_checkpoint(filepath) +def test_v1_5_0_legacy_profiler_argument(): + with pytest.deprecated_call(match="renamed to `record_functions` in v1.3"): + PyTorchProfiler(profiled_functions=[]) + + def test_v1_5_0_running_sanity_check(): trainer = Trainer() with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): diff --git a/tests/test_profiler.py b/tests/test_profiler.py index cf6afcc9b626c1..5d144aef36573a 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import os +import platform import time from copy import deepcopy from distutils.version import LooseVersion @@ -24,6 +25,9 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler +from pytorch_lightning.profiler.pytorch import RegisterRecordFunction +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -126,10 +130,10 @@ def test_simple_profiler_log_dir(tmpdir): ) trainer.fit(model) - expected = profiler.dirpath + expected = tmpdir / "lightning_logs" / "version_0" assert trainer.log_dir == expected assert profiler._log_dir == trainer.log_dir - assert Path(os.path.join(profiler.dirpath, "fit-profiler.txt")).exists() + assert expected.join("fit-profiler.txt").exists() @RunIf(skip_windows=True) @@ -264,8 +268,8 @@ def pytorch_profiler(tmpdir): def test_pytorch_profiler_describe(pytorch_profiler): """Ensure the profiler won't fail when reporting the summary.""" - with pytorch_profiler.profile("test_step"): - pass + with pytorch_profiler.profile("on_test_start"): + torch.tensor(0) # log to stdout and print to file pytorch_profiler.describe() @@ -274,15 +278,10 @@ def test_pytorch_profiler_describe(pytorch_profiler): assert len(data) > 0 -def test_pytorch_profiler_value_errors(pytorch_profiler): +def test_pytorch_profiler_raises(pytorch_profiler): """Ensure errors are raised where expected.""" - - action = "test_step" - with pytest.raises(ValueError): - pytorch_profiler.stop(action) - - pytorch_profiler.start(action) - pytorch_profiler.stop(action) + with pytest.raises(MisconfigurationException, match="profiled_functions` and `PyTorchProfiler.record"): + PyTorchProfiler(profiled_functions=["a"], record_functions=["b"]) @RunIf(min_torch="1.6.0") @@ -299,9 +298,8 @@ def test_advanced_profiler_cprofile_deepcopy(tmpdir): @RunIf(min_gpus=2, special=True) -def test_pytorch_profiler_trainer_ddp(tmpdir): +def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler): """Ensure that the profiler can be given to the training and default step are properly recorded. """ - pytorch_profiler = PyTorchProfiler(dirpath=None, filename="profiler") model = BoringModel() trainer = Trainer( max_epochs=1, @@ -314,17 +312,68 @@ def test_pytorch_profiler_trainer_ddp(tmpdir): ) trainer.fit(model) - assert len(pytorch_profiler.summary()) > 0 - assert set(pytorch_profiler.profiled_actions) == {'training_step_and_backward', 'validation_step'} + expected = ('validation_step', 'training_step_and_backward', 'training_step', 'backward') + for name in expected: + assert sum(e.name == name for e in pytorch_profiler.function_events) - files = sorted(f for f in os.listdir(pytorch_profiler.dirpath) if "fit" in f) - rank = int(os.getenv("LOCAL_RANK", "0")) - expected = f"fit-profiler-{rank}.txt" - assert files[rank] == expected + files = set(os.listdir(pytorch_profiler.dirpath)) + expected = f"fit-profiler-{trainer.local_rank}.txt" + assert expected in files path = os.path.join(pytorch_profiler.dirpath, expected) - data = Path(path).read_text("utf-8") - assert len(data) > 0 + assert Path(path).read_text() + + +def test_pytorch_profiler_trainer_test(tmpdir, pytorch_profiler): + """Ensure that the profiler can be given to the trainer and test step are properly recorded. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_test_batches=2, + profiler=pytorch_profiler, + ) + trainer.test(model) + + assert sum(e.name == 'test_step' for e in pytorch_profiler.function_events) + + path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + +def test_pytorch_profiler_trainer_predict(tmpdir, pytorch_profiler): + """Ensure that the profiler can be given to the trainer and predict function are properly recorded. """ + model = BoringModel() + model.predict_dataloader = model.train_dataloader + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_test_batches=2, + profiler=pytorch_profiler, + ) + trainer.predict(model) + + assert sum(e.name == 'predict_step' for e in pytorch_profiler.function_events) + + path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + +def test_pytorch_profiler_trainer_validate(tmpdir, pytorch_profiler): + """Ensure that the profiler can be given to the trainer and validate function are properly recorded. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=2, + profiler=pytorch_profiler, + ) + trainer.validate(model) + + assert sum(e.name == 'validation_step' for e in pytorch_profiler.function_events) + + path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") def test_pytorch_profiler_nested(tmpdir): @@ -341,34 +390,31 @@ def test_pytorch_profiler_nested(tmpdir): with pytorch_profiler.profile("c"): _ = a + b - pa = pytorch_profiler.profiled_actions + pytorch_profiler.describe() - # From PyTorch 1.8.0, less operation are being traced. - if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"): - expected_ = { - 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'add'], - 'b': ['zeros', 'empty', 'zero_'], - 'c': ['add'], + events_name = {e.name for e in pytorch_profiler.function_events} + + if platform.system() == "Windows": + expected = {'a', 'add', 'b', 'c', 'profiler::_record_function_enter', 'profiler::_record_function_exit'} + else: + expected = { + 'signed char', 'add', 'profiler::_record_function_exit', 'bool', 'char', 'profiler::_record_function_enter' } - # From PyTorch 1.6.0, more operation are being traced. - elif LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - expected_ = { - 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'fill_', 'add', 'empty'], - 'b': ['zeros', 'empty', 'zero_', 'fill_'], - 'c': ['add', 'empty'], + + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + expected = {'add', 'zeros', 'ones', 'zero_', 'b', 'fill_', 'c', 'a', 'empty'} + + if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"): + expected = { + 'aten::zeros', 'aten::add', 'aten::zero_', 'c', 'b', 'a', 'aten::fill_', 'aten::empty', 'aten::ones' } - else: - expected_ = { - 'a': ['add'], - 'b': [], - 'c': ['add'], + + if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"): + expected = { + 'aten::ones', 'a', 'aten::add', 'aten::empty', 'aten::zero_', 'b', 'c', 'aten::zeros', 'aten::fill_' } - for n in ('a', 'b', 'c'): - pa[n] = [e.name for e in pa[n]] - if LooseVersion(torch.__version__) >= LooseVersion("1.7.1"): - pa[n] = [e.replace("aten::", "") for e in pa[n]] - assert pa[n] == expected_[n] + assert events_name == expected, (events_name, torch.__version__, platform.system()) @RunIf(min_gpus=1, special=True) @@ -387,6 +433,43 @@ def test_pytorch_profiler_nested_emit_nvtx(tmpdir): trainer.fit(model) +@RunIf(min_torch="1.5.0") +def test_register_record_function(tmpdir): + + use_cuda = torch.cuda.is_available() + pytorch_profiler = PyTorchProfiler( + export_to_chrome=False, + record_functions=["a"], + use_cuda=use_cuda, + dirpath=tmpdir, + filename="profiler", + ) + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)) + + model = TestModel() + input = torch.rand((1, 8)) + + if use_cuda: + model = model.cuda() + input = input.cuda() + + with pytorch_profiler.profile("a"): + with RegisterRecordFunction(model): + model(input) + + pytorch_profiler.describe() + event_names = [e.name for e in pytorch_profiler.function_events] + assert 'torch.nn.modules.container.Sequential: layer' in event_names + assert 'torch.nn.modules.linear.Linear: layer.0' in event_names + assert 'torch.nn.modules.activation.ReLU: layer.1' in event_names + assert 'torch.nn.modules.linear.Linear: layer.2' in event_names + + @pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) def test_profiler_teardown(tmpdir, cls): """ @@ -407,6 +490,9 @@ def on_fit_end(self, trainer, *args, **kwargs) -> None: assert profiler._output_file is None +@pytest.mark.skipif(_TORCH_GREATER_EQUAL_1_8, reason="currently not supported for PyTorch 1.8") def test_pytorch_profiler_deepcopy(pytorch_profiler): + pytorch_profiler.start("on_train_start") + torch.tensor(1) pytorch_profiler.describe() assert deepcopy(pytorch_profiler) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 3eb0596b55fc40..5dc1ea5de4e8a0 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -80,23 +80,3 @@ def test_get_model_gpu(tmpdir): gpus=1, ) trainer.fit(model) - - -@RunIf(min_gpus=1, skip_windows=True) -def test_get_model_ddp_gpu(tmpdir): - """ - Tests that `trainer.lightning_module` extracts the model correctly when using GPU + ddp accelerators - """ - - model = TrainerGetModel() - - limit_train_batches = 2 - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=limit_train_batches, - limit_val_batches=2, - max_epochs=1, - gpus=1, - ) - trainer.fit(model) - return 1 From fd5cb7fcc32bd0962ba9b978489cd6014cfa6a6f Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 23 Mar 2021 20:43:21 +0000 Subject: [PATCH 19/22] Add PyTorch 1.8 Profiler 5/5 (#6618) * Refactor profilers * Update PassThrough * WIP - This is broken and will change * Update pytorch_lightning/profiler/pytorch.py Co-authored-by: thomas chaton * resolve tests * resolve tests * find output * try something * update * add support for test and predict * update * update * use getattr * test * test * update * tests * update * update * update * update * update * remove file * update * update * update * update * update * test * update# * update * update tests * update * add suport for 1.8 * rename records * add support for 1.8 * update * resolve flake8 * resolve test * Refactor basic profilers * Fixes * Unused import * Introduce setup * Profile on all ranks. Print to stdout on 0 * Introduce dirpath + filename * CHANGELOG * Add tests. Address comments * add `on_run_stage_setup` * add on_run_stage_setup function * update * add test for RegisterRecordFunction * update lightnng flow direction * move variable to private * remove trace * Undo code that should be in 3/4 * Multi-stage multi-rank * 2/5 changes * Pass stage in __del__ * Remove TODOs * Describe on_evaluation_end. Add tests * Typo * Address comments * deepcopy tests * Advanced teardown * Fix teardown test * Fix tests * Minor change * Update CHANGELOG.md * Fix test * Quick fixes * Fix 6522 * resolve ddp tests * resolve tests * resolve some tests * update tests * resolve tests * update * resolve tests * resolve some tests * Missed fixes from 3/5 * Fixes * resolve some tests * resolve test for 1.7.1 * Broken refactor * Missed stage * Minor changes * resolve tests * Update CHANGELOG * resolve bug * remove print * Typo * Cleanup * resolve ddp test * remove barrier * update profiler * update * Smaller model * update * resolve tests * update * Minor changes. CHANGELOG * Minimize diff * update to 1.8.1 * RunIf. Extra code. Check segfault * resolve tests * Typo. Bad merge * Fixing a bad merge * replace for kineto * Update pytorch_lightning/profiler/pytorch.py Co-authored-by: ananthsub * Update pytorch_lightning/profiler/pytorch.py Co-authored-by: ananthsub * Minor changes * Bad merge * Use lists for flexibility * Use sets * predict_step * Ananth's suggestion * update * Docs * Update pl_examples/basic_examples/profiler_example.py Co-authored-by: Jirka Borovec * update example * update example Co-authored-by: Carlos Mocholi Co-authored-by: ananthsub Co-authored-by: Jirka Borovec --- CHANGELOG.md | 3 + .../basic_examples/profiler_example.py | 102 +++++++ pytorch_lightning/accelerators/accelerator.py | 12 +- pytorch_lightning/profiler/__init__.py | 25 +- pytorch_lightning/profiler/profilers.py | 4 +- pytorch_lightning/profiler/pytorch.py | 257 +++++++++++++++--- pytorch_lightning/utilities/imports.py | 1 + tests/accelerators/test_cpu.py | 8 +- tests/helpers/runif.py | 7 + tests/helpers/test_datasets.py | 12 +- tests/test_profiler.py | 63 +++-- tests/utilities/test_argparse.py | 1 + 12 files changed, 399 insertions(+), 96 deletions(-) create mode 100644 pl_examples/basic_examples/profiler_example.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e1106189e0c17a..1b3359ace54f92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) +- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618)) + + - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py new file mode 100644 index 00000000000000..ca640a96f9588b --- /dev/null +++ b/pl_examples/basic_examples/profiler_example.py @@ -0,0 +1,102 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script will generate 2 traces: one for `training_step` and one for `validation_step`. +The traces can be visualized in 2 ways: +* With Chrome: + 1. Open Chrome and copy/paste this url: `chrome://tracing/`. + 2. Once tracing opens, click on `Load` at the top-right and load one of the generated traces. +* With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin) + 1. pip install tensorboard torch-tb-profiler + 2. tensorboard --logdir={FOLDER} +""" + +import sys +from argparse import ArgumentParser + +import torch +import torchvision +import torchvision.models as models +import torchvision.transforms as T + +from pl_examples import cli_lightning_logo +from pytorch_lightning import LightningDataModule, LightningModule, Trainer + +DEFAULT_CMD_LINE = ( + "--max_epochs", + "1", + "--limit_train_batches", + "15", + "--limit_val_batches", + "15", + "--profiler", + "pytorch", + "--gpus", + f"{int(torch.cuda.is_available())}", +) + + +class ModelToProfile(LightningModule): + + def __init__(self, model): + super().__init__() + self.model = model + self.criterion = torch.nn.CrossEntropyLoss() + + def training_step(self, batch, batch_idx): + inputs, labels = batch + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + self.log("train_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + inputs, labels = batch + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + self.log("val_loss", loss) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9) + + +class CIFAR10DataModule(LightningDataModule): + + transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]) + + def train_dataloader(self, *args, **kwargs): + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=self.transform) + return torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0) + + def val_dataloader(self, *args, **kwargs): + valset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=self.transform) + return torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=0) + + +def cli_main(): + + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parser) + cmd_line = None if len(sys.argv) != 1 else DEFAULT_CMD_LINE + args = parser.parse_args(args=cmd_line) + + model = ModelToProfile(models.resnet50(pretrained=True)) + datamodule = CIFAR10DataModule() + trainer = Trainer(**vars(args)) + trainer.fit(model, datamodule=datamodule) + + +if __name__ == '__main__': + cli_lightning_logo() + cli_main() diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 4aa5fedf2b210c..1dcd541ca0610c 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -448,8 +448,10 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: Lightn .. deprecated::v1.3 Will be removed in v1.5.0. """ - rank_zero_warn('Accelerator method `connect_training_type_plugin` was deprecated in v1.3.' - ' It will be removed in v1.5.') + rank_zero_warn( + 'Accelerator method `connect_training_type_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.' + ) self.setup_training_type_plugin(plugin, model) # todo: remove in v1.5 @@ -459,6 +461,8 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: .. deprecated::v1.3 Will be removed in v1.5.0. """ - rank_zero_warn('Accelerator method `connect_precision_plugin` was deprecated in v1.3.' - ' It will be removed in v1.5.') + rank_zero_warn( + 'Accelerator method `connect_precision_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.' + ) self.setup_precision_plugin(plugin) diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index e09a5ea11a084d..6ac6e16c185290 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -121,7 +121,8 @@ def custom_processing_step(self, data): Autograd includes a profiler that lets you inspect the cost of different operators inside your model - both on the CPU and GPU. -Find the Pytorch Profiler doc at [PyTorch Profiler](https://pytorch-lightning.readthedocs.io/en/stable/profiler.html) +To read more about the PyTorch Profiler and all its options, +have a look at its `docs `__ .. code-block:: python @@ -134,16 +135,16 @@ def custom_processing_step(self, data): This profiler works with PyTorch ``DistributedDataParallel``. -If ``output_filename`` is provided, each rank will save their profiled operation to their own file. +If ``filename`` is provided, each rank will save their profiled operation to their own file. The profiler +report can be quite long, so you setting a ``filename`` will save the report instead of logging it to the +output in your terminal. If no filename is given, it will be logged only on rank 0. +The profiler's results will be printed on the completion of ``{fit,validate,test,predict}``. -The profiler's results will be printed on the completion of a training `fit()`. This profiler -report can be quite long, so you can also specify an `output_filename` to save the report instead -of logging it to the output in your terminal. - -This profiler will record only for `training_step_and_backward`, `evaluation_step` and `test_step` functions by default. -The output below shows the profiling for the action `training_step_and_backward`. -The user can provide ``PyTorchProfiler(profiled_functions=[...])`` to extend the scope of profiled functions. +This profiler will record ``training_step_and_backward``, ``training_step``, ``backward``, +``validation_step``, ``test_step``, and ``predict_step`` by default. +The output below shows the profiling for the action ``training_step_and_backward``. +The user can provide ``PyTorchProfiler(record_functions={...})`` to extend the scope of profiled functions. .. note:: When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the `SimpleProfiler`. # noqa E501 @@ -184,13 +185,13 @@ def custom_processing_step(self, data): To visualize the profiled operation, you can either: -* Use:: +Use:: nvvp trace_name.prof -* Use:: +Or:: - python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' + python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' """ diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 46d72583fb466e..bc9e3541dbaa8a 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -120,14 +120,14 @@ def _rank_zero_info(self, *args, **kwargs) -> None: if self._local_rank in (None, 0): log.info(*args, **kwargs) - def _prepare_filename(self) -> str: + def _prepare_filename(self, extension: str = ".txt") -> str: filename = "" if self._stage is not None: filename += f"{self._stage}-" filename += str(self.filename) if self._local_rank is not None: filename += f"-{self._local_rank}" - filename += ".txt" + filename += extension return filename def _prepare_streams(self) -> None: diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 974883a4724c6f..73abc1baf939d8 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -17,7 +17,7 @@ import os from functools import partial from pathlib import Path -from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union import torch from torch import nn, Tensor @@ -26,6 +26,7 @@ from pytorch_lightning.profiler.profilers import BaseProfiler from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE if TYPE_CHECKING: from torch.autograd.profiler import EventList @@ -33,6 +34,9 @@ from pytorch_lightning.core.lightning import LightningModule +if _KINETO_AVAILABLE: + from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler + log = logging.getLogger(__name__) _PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx] @@ -93,12 +97,108 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self._handles = {} +class ScheduleWrapper: + """ + This class is used to override the schedule logic from the profiler and perform + recording for both `training_step`, `validation_step`. + """ + + def __init__(self, schedule: Callable) -> None: + if not _KINETO_AVAILABLE: + raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.") + self._schedule = schedule + self._num_training_step_and_backward = 0 + self._num_validation_step = 0 + self._num_test_step = 0 + self._num_predict_step = 0 + self._training_step_and_backward_reached_end = False + self._validation_step_reached_end = False + self._test_step_reached_end = False + self._predict_step_reached_end = False + # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached. + self._current_action: Optional[str] = None + self._start_action_name: Optional[str] = None + + def setup(self, start_action_name: str) -> None: + self._start_action_name = start_action_name + + def pre_step(self, current_action: str) -> None: + self._current_action = current_action + + @property + def num_step(self) -> int: + if self._current_action == "training_step_and_backward": + return self._num_training_step_and_backward + elif self._current_action == "validation_step": + return self._num_validation_step + elif self._current_action == "test_step": + return self._num_test_step + elif self._current_action == "predict_step": + return self._num_predict_step + else: + return 0 + + def _step(self) -> None: + if self._current_action == "training_step_and_backward": + self._num_training_step_and_backward += 1 + elif self._current_action == "validation_step": + if self._start_action_name == "on_train_start" and self._num_training_step_and_backward > 0: + self._num_validation_step += 1 + else: + self._num_validation_step += 1 + elif self._current_action == "test_step": + self._num_test_step += 1 + elif self._current_action == "predict_step": + self._num_predict_step += 1 + + @property + def has_finished(self) -> bool: + if self._current_action == "training_step_and_backward": + return self._training_step_and_backward_reached_end + elif self._current_action == "validation_step": + return self._validation_step_reached_end + elif self._current_action == "test_step": + return self._test_step_reached_end + elif self._current_action == "predict_step": + return self._predict_step_reached_end + return False + + def __call__(self, num_step: int) -> 'ProfilerAction': + # ignore the provided input. Keep internal state instead. + if self.has_finished: + return ProfilerAction.NONE + + self._step() + action = self._schedule(self.num_step) + if action == ProfilerAction.RECORD_AND_SAVE: + if self._current_action == "training_step_and_backward": + self._training_step_and_backward_reached_end = True + elif self._current_action == "validation_step": + self._validation_step_reached_end = True + elif self._current_action == "test_step": + self._test_step_reached_end = True + elif self._current_action == "predict_step": + self._predict_step_reached_end = True + return action + + class PyTorchProfiler(BaseProfiler): - RECORD_FUNCTIONS = ( - "training_step_and_backward", "training_step", "backward", "validation_step", "test_step", "predict_step" - ) - AVAILABLE_SORT_KEYS = ( + RECORD_FUNCTIONS = { + "training_step_and_backward", + "training_step", + "backward", + "validation_step", + "test_step", + "predict_step", + } + STEP_FUNCTIONS = { + "training_step_and_backward", + "validation_step", + "test_step", + "predict_step", + } + AVAILABLE_SORT_KEYS = { "cpu_time", "cuda_time", "cpu_time_total", @@ -108,8 +208,13 @@ class PyTorchProfiler(BaseProfiler): "self_cpu_memory_usage", "self_cuda_memory_usage", "count", - ) - START_RECORD_FUNCTIONS = ('on_train_start', 'on_validation_start', 'on_test_start', 'on_predict_start') + } + START_RECORD_FUNCTIONS = { + 'on_train_start', + 'on_validation_start', + 'on_test_start', + 'on_predict_start', + } def __init__( self, @@ -118,10 +223,9 @@ def __init__( group_by_input_shapes: bool = False, emit_nvtx: bool = False, export_to_chrome: bool = True, - path_to_export_trace: Optional[str] = None, row_limit: int = 20, sort_by_key: Optional[str] = None, - record_functions: List[str] = None, + record_functions: Set[str] = None, record_module_names: bool = True, profiled_functions: Optional[List] = None, output_filename: Optional[str] = None, @@ -154,9 +258,6 @@ def __init__( export_to_chrome: Whether to export the sequence of profiled operators for Chrome. It will generate a ``.json`` file which can be read by Chrome. - path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``. - By default, it will be save where the file being is being run. - row_limit: Limit the number of rows in a table, ``-1`` is a special value that removes the limit completely. @@ -166,7 +267,7 @@ def __init__( ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. - record_functions: list of profiled functions which will create a context manager on. + record_functions: Set of profiled functions which will create a context manager on. Any other will be pass through. record_module_names: Whether to add module names while recording autograd operation. @@ -176,6 +277,8 @@ def __init__( Raises: MisconfigurationException: If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``. + If arg ``schedule`` is not a ``Callable``. + If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``. """ super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) @@ -184,11 +287,10 @@ def __init__( self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False) self._emit_nvtx = emit_nvtx self._export_to_chrome = export_to_chrome - self._path_to_export_trace = path_to_export_trace self._row_limit = row_limit self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total" - self._record_functions_start = set(record_functions + list(self.START_RECORD_FUNCTIONS)) - self._record_functions = set(record_functions + list(self.RECORD_FUNCTIONS)) + self._record_functions_start = record_functions | self.START_RECORD_FUNCTIONS + self._record_functions = record_functions | self.RECORD_FUNCTIONS self._record_module_names = record_module_names self._profiler_kwargs = profiler_kwargs @@ -198,25 +300,48 @@ def __init__( self._register: Optional[RegisterRecordFunction] = None self._parent_profiler: Optional[_PROFILER] = None self._recording_map: Dict[str, record_function] = {} + self._start_action_name: Optional[str] = None + self._schedule: Optional[ScheduleWrapper] = None - if self._export_to_chrome and self._path_to_export_trace is None: - rank_zero_warn( - "The exported trace would be saved locally as `path_to_export_trace` is None." - " Note: Each functions will generate its own traced file." - ) + if _KINETO_AVAILABLE: + self.__init_kineto__(profiler_kwargs) if self._sort_by_key not in self.AVAILABLE_SORT_KEYS: raise MisconfigurationException( f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " ) + def __init_kineto__(self, profiler_kwargs: Any): + has_schedule = "schedule" in profiler_kwargs + self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs + + schedule = profiler_kwargs.get("schedule", None) + if schedule is not None: + if not isinstance(schedule, Callable): + raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}") + action = schedule(0) + if not isinstance(action, ProfilerAction): + raise MisconfigurationException( + f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}" + ) + schedule = schedule if has_schedule else self._default_schedule() + self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule + self._profiler_kwargs["schedule"] = self._schedule + + activities = profiler_kwargs.get("activities", None) + self._profiler_kwargs["activities"] = activities or self._default_activities() + self._export_to_flame_graph = profiler_kwargs.get("export_to_flame_graph", False) + self._metric = profiler_kwargs.get("metric", "self_cpu_time_total") + with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph + self._profiler_kwargs["with_stack"] = with_stack + def __deprecation_check( self, profiled_functions: Optional[List[str]], - record_functions: Optional[List[str]], - ) -> List[str]: + record_functions: Optional[Set[str]], + ) -> Set[str]: if record_functions is None: - record_functions = [] + record_functions = set() if profiled_functions is not None: rank_zero_warn( @@ -224,7 +349,7 @@ def __deprecation_check( " `record_functions` in v1.3 and will be removed in v1.5", DeprecationWarning ) if not record_functions: - record_functions += profiled_functions + record_functions |= set(profiled_functions) else: raise MisconfigurationException( "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`." @@ -233,15 +358,25 @@ def __deprecation_check( return record_functions - def setup( - self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None - ) -> None: - super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir) - - # if the user didn't provide `path_to_export_trace`, - # set it as TensorBoardLogger log_dir if exists - if self._path_to_export_trace is None: - self._path_to_export_trace = log_dir + @staticmethod + def _default_schedule() -> Optional[callable]: + if _KINETO_AVAILABLE: + # Those schedule defaults allow the profiling overhead to be negligible over training time. + return torch.profiler.schedule(wait=1, warmup=1, active=2) + + def _default_activities(self) -> List['ProfilerActivity']: + activities = [] + if not _KINETO_AVAILABLE: + return activities + if self._profiler_kwargs.get("use_cpu", True): + activities.append(ProfilerActivity.CPU) + if self._profiler_kwargs.get("use_cuda", torch.cuda.is_available()): + activities.append(ProfilerActivity.CUDA) + return activities + + @property + def step_action_names(self) -> Set[str]: + return self.STEP_FUNCTIONS | self._record_functions def start(self, action_name: str) -> None: if self.profiler is None and action_name in self._record_functions_start: @@ -253,11 +388,18 @@ def start(self, action_name: str) -> None: except (AttributeError, RuntimeError): pass + if self._schedule is not None: + self._schedule.setup(action_name) + self._create_profilers() - self.profiler.__enter__() + profiler = self.profiler.__enter__() + if profiler is not None: + self.profiler = profiler + if self._parent_profiler is not None: self._parent_profiler.__enter__() + if self._register is not None: self._register.__enter__() @@ -269,11 +411,39 @@ def start(self, action_name: str) -> None: recording.__enter__() self._recording_map[action_name] = recording + if self._schedule is not None: + self._schedule.pre_step(action_name) + def stop(self, action_name: str) -> None: if action_name in self._recording_map: self._recording_map[action_name].__exit__(None, None, None) del self._recording_map[action_name] + if not _KINETO_AVAILABLE or self._emit_nvtx: + return + + if action_name in self.step_action_names: + if self._schedule is not None: + self._schedule._current_action = action_name + + def on_trace_ready(profiler): + filename = f"{action_name}_{self.local_rank}" + + if self.dirpath is not None: + if self._export_to_chrome: + handler = tensorboard_trace_handler(self.dirpath, filename) + handler(profiler) + + if self._export_to_flame_graph: + path = os.path.join(self.dirpath, self._prepare_filename(extension=".stack")) + profiler.export_stacks(path, metric=self._metric) + else: + rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None") + + if not self._has_on_trace_ready: + self.profiler.on_trace_ready = on_trace_ready + self.profiler.step() + def summary(self) -> str: if not self._profiler_kwargs.get("enabled", True) or self._emit_nvtx: return "" @@ -283,11 +453,9 @@ def summary(self) -> str: if not self.function_events: return "" - if self._export_to_chrome: + if self._export_to_chrome and not _KINETO_AVAILABLE: filename = f"{self.local_rank}_trace.json" - path_to_trace = ( - filename if self._path_to_export_trace is None else os.path.join(self._path_to_export_trace, filename) - ) + path_to_trace = (filename if self.dirpath is None else os.path.join(self.dirpath, filename)) self.function_events.export_chrome_trace(path_to_trace) data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes) @@ -302,7 +470,9 @@ def _create_profilers(self) -> None: self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx) else: self._parent_profiler = None - self.profiler = self._create_profiler(torch.autograd.profiler.profile) + self.profiler = self._create_profiler( + torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile + ) if self._record_module_names and self._lightning_module is not None: self._register = RegisterRecordFunction(self._lightning_module) @@ -311,9 +481,10 @@ def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} return profiler(**kwargs) - def _cache_functions_events(self): - if not self._emit_nvtx: - self.function_events = self.profiler.function_events + def _cache_functions_events(self) -> None: + if self._emit_nvtx: + return + self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events def _delete_profilers(self) -> None: if self.profiler is not None: diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 5a780660a0a99a..baeac9be572184 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -70,6 +70,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") _TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0") +_KINETO_AVAILABLE = torch.profiler.kineto_available() if _TORCH_GREATER_EQUAL_1_8 else False _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index bcb351984a1752..46379a9d10c14f 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -30,6 +30,7 @@ def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch): """ class TestModel(BoringModel): + def on_fit_start(self): if delay_dispatch: # Ensure we haven't setup optimizers if we've delayed dispatch @@ -41,14 +42,11 @@ def on_fit_end(self): assert len(self.trainer.optimizers) > 0 class CustomPlugin(SingleDevicePlugin): + @property def setup_optimizers_in_pre_dispatch(self) -> bool: return delay_dispatch model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - plugins=CustomPlugin(device=torch.device("cpu")) - ) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu"))) trainer.fit(model) diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index fe85fbaea90255..5483e33d9cddb4 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -56,6 +56,7 @@ def __new__( *args, min_gpus: int = 0, min_torch: Optional[str] = None, + max_torch: Optional[str] = None, min_python: Optional[str] = None, quantization: bool = False, amp_apex: bool = False, @@ -76,6 +77,7 @@ def __new__( args: native pytest.mark.skipif arguments min_gpus: min number of gpus required to run test min_torch: minimum pytorch version to run test + max_torch: maximum pytorch version to run test min_python: minimum python version required to run test quantization: if `torch.quantization` package is required to run test amp_apex: NVIDIA Apex is installed @@ -102,6 +104,11 @@ def __new__( conditions.append(torch_version < LooseVersion(min_torch)) reasons.append(f"torch>={min_torch}") + if max_torch: + torch_version = LooseVersion(get_distribution("torch").version) + conditions.append(torch_version >= LooseVersion(max_torch)) + reasons.append(f"torch<{max_torch}") + if min_python: py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" conditions.append(py_version < LooseVersion(min_python)) diff --git a/tests/helpers/test_datasets.py b/tests/helpers/test_datasets.py index 42b5df0ff91a4c..8c866bdbab789c 100644 --- a/tests/helpers/test_datasets.py +++ b/tests/helpers/test_datasets.py @@ -20,11 +20,13 @@ from tests.helpers.datasets import AverageDataset, MNIST, TrialMNIST -@pytest.mark.parametrize('dataset_cls,args', [ - (MNIST, dict(root=PATH_DATASETS)), - (TrialMNIST, dict(root=PATH_DATASETS)), - (AverageDataset, dict()), -]) +@pytest.mark.parametrize( + 'dataset_cls,args', [ + (MNIST, dict(root=PATH_DATASETS)), + (TrialMNIST, dict(root=PATH_DATASETS)), + (AverageDataset, dict()), + ] +) def test_pickling_dataset_mnist(tmpdir, dataset_cls, args): mnist = dataset_cls(**args) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 5d144aef36573a..a6e33b3366f33e 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -17,7 +17,6 @@ import time from copy import deepcopy from distutils.version import LooseVersion -from pathlib import Path import numpy as np import pytest @@ -27,7 +26,7 @@ from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler from pytorch_lightning.profiler.pytorch import RegisterRecordFunction from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -266,6 +265,7 @@ def pytorch_profiler(tmpdir): return PyTorchProfiler(dirpath=tmpdir, filename="profiler") +@RunIf(max_torch="1.8.1") def test_pytorch_profiler_describe(pytorch_profiler): """Ensure the profiler won't fail when reporting the summary.""" with pytorch_profiler.profile("on_test_start"): @@ -302,30 +302,41 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler): """Ensure that the profiler can be given to the training and default step are properly recorded. """ model = BoringModel() trainer = Trainer( - max_epochs=1, default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=5, profiler=pytorch_profiler, accelerator="ddp", gpus=2, ) trainer.fit(model) - expected = ('validation_step', 'training_step_and_backward', 'training_step', 'backward') + expected = {'validation_step'} + if not _KINETO_AVAILABLE: + expected |= {'training_step_and_backward', 'training_step', 'backward'} for name in expected: - assert sum(e.name == name for e in pytorch_profiler.function_events) + assert sum(e.name == name for e in pytorch_profiler.function_events), name files = set(os.listdir(pytorch_profiler.dirpath)) expected = f"fit-profiler-{trainer.local_rank}.txt" assert expected in files - path = os.path.join(pytorch_profiler.dirpath, expected) - assert Path(path).read_text() + path = pytorch_profiler.dirpath / expected + assert path.read_text("utf-8") + + if _KINETO_AVAILABLE: + files = os.listdir(pytorch_profiler.dirpath) + files = [file for file in files if file.endswith('.json')] + assert len(files) == 2, files + local_rank = trainer.local_rank + assert any(f'training_step_{local_rank}' in f for f in files) + assert any(f'validation_step_{local_rank}' in f for f in files) -def test_pytorch_profiler_trainer_test(tmpdir, pytorch_profiler): +def test_pytorch_profiler_trainer_test(tmpdir): """Ensure that the profiler can be given to the trainer and test step are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, @@ -340,27 +351,32 @@ def test_pytorch_profiler_trainer_test(tmpdir, pytorch_profiler): path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt" assert path.read_text("utf-8") + if _KINETO_AVAILABLE: + files = sorted([file for file in os.listdir(tmpdir) if file.endswith('.json')]) + assert any(f'test_step_{trainer.local_rank}' in f for f in files) + -def test_pytorch_profiler_trainer_predict(tmpdir, pytorch_profiler): +def test_pytorch_profiler_trainer_predict(tmpdir): """Ensure that the profiler can be given to the trainer and predict function are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) model = BoringModel() model.predict_dataloader = model.train_dataloader trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - limit_test_batches=2, + limit_predict_batches=2, profiler=pytorch_profiler, ) trainer.predict(model) assert sum(e.name == 'predict_step' for e in pytorch_profiler.function_events) - path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt" assert path.read_text("utf-8") -def test_pytorch_profiler_trainer_validate(tmpdir, pytorch_profiler): +def test_pytorch_profiler_trainer_validate(tmpdir): """Ensure that the profiler can be given to the trainer and validate function are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, @@ -380,7 +396,7 @@ def test_pytorch_profiler_nested(tmpdir): """Ensure that the profiler handles nested context""" pytorch_profiler = PyTorchProfiler( - profiled_functions=["a", "b", "c"], use_cuda=False, dirpath=tmpdir, filename="profiler" + record_functions={"a", "b", "c"}, use_cuda=False, dirpath=tmpdir, filename="profiler", schedule=None ) with pytorch_profiler.profile("a"): @@ -409,11 +425,6 @@ def test_pytorch_profiler_nested(tmpdir): 'aten::zeros', 'aten::add', 'aten::zero_', 'c', 'b', 'a', 'aten::fill_', 'aten::empty', 'aten::ones' } - if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"): - expected = { - 'aten::ones', 'a', 'aten::add', 'aten::empty', 'aten::zero_', 'b', 'c', 'aten::zeros', 'aten::fill_' - } - assert events_name == expected, (events_name, torch.__version__, platform.system()) @@ -439,20 +450,22 @@ def test_register_record_function(tmpdir): use_cuda = torch.cuda.is_available() pytorch_profiler = PyTorchProfiler( export_to_chrome=False, - record_functions=["a"], + record_functions={"a"}, use_cuda=use_cuda, dirpath=tmpdir, filename="profiler", + schedule=None, + on_trace_ready=None, ) class TestModel(BoringModel): def __init__(self): super().__init__() - self.layer = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)) + self.layer = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), torch.nn.Linear(1, 1)) model = TestModel() - input = torch.rand((1, 8)) + input = torch.rand((1, 1)) if use_cuda: model = model.cuda() @@ -490,8 +503,8 @@ def on_fit_end(self, trainer, *args, **kwargs) -> None: assert profiler._output_file is None -@pytest.mark.skipif(_TORCH_GREATER_EQUAL_1_8, reason="currently not supported for PyTorch 1.8") -def test_pytorch_profiler_deepcopy(pytorch_profiler): +def test_pytorch_profiler_deepcopy(tmpdir): + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler", schedule=None) pytorch_profiler.start("on_train_start") torch.tensor(1) pytorch_profiler.describe() diff --git a/tests/utilities/test_argparse.py b/tests/utilities/test_argparse.py index aef266d639b4ae..f13af4362364ca 100644 --- a/tests/utilities/test_argparse.py +++ b/tests/utilities/test_argparse.py @@ -18,6 +18,7 @@ class ArgparseExample: + def __init__(self, a: int = 0, b: str = '', c: bool = False): self.a = a self.b = b From 64d0fa44720da0d5e77f751faa5569950a25619e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Mar 2021 23:05:04 +0100 Subject: [PATCH 20/22] update coverage config (#6524) * update coverage config * parallel * parallel * Apply suggestions from code review * Apply suggestions from code review * paralel * paralel * paralel * combine * combine * . * .. * .. * .. * rev * cb * cb * drop * drop * . * .. * ... * ... * ... * . --- .github/workflows/ci_test-base.yml | 2 +- .github/workflows/ci_test-conda.yml | 2 +- .github/workflows/ci_test-full.yml | 2 +- azure-pipelines.yml | 2 +- requirements/test.txt | 4 ++-- setup.cfg | 5 ----- tests/special_tests.sh | 2 +- 7 files changed, 7 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index 0e84642e2f8109..77363992718af4 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -68,7 +68,7 @@ jobs: - name: Test Package [only] run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest pytorch_lightning -v --cov=pytorch_lightning --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml + coverage run --source pytorch_lightning -m pytest pytorch_lightning -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - name: Upload pytest test results uses: actions/upload-artifact@v2 diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 812d06f3108127..da853bf623d1bf 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -44,7 +44,7 @@ jobs: - name: Tests run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest pytorch_lightning tests --cov=pytorch_lightning -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml + coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml shell: bash -l {0} - name: Upload pytest results diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index ba8d8044149939..5a3e23a37fd0b9 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -134,7 +134,7 @@ jobs: - name: Tests run: | # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest pytorch_lightning tests --cov=pytorch_lightning -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml + coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml - name: Examples run: | diff --git a/azure-pipelines.yml b/azure-pipelines.yml index b7a2d851052edb..d88a31ae9775a3 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -78,7 +78,7 @@ jobs: displayName: 'Get legacy checkpoints' - bash: | - python -m pytest pytorch_lightning tests -v --cov=pytorch_lightning --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 + python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 displayName: 'Testing: standard' - bash: | diff --git a/requirements/test.txt b/requirements/test.txt index 099a6fe43b6e62..259cc2e2d64424 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,8 +1,8 @@ coverage>5.2.0 codecov>=2.1 pytest>=6.0 -pytest-cov>2.10 -# pytest-xdist +#pytest-cov>2.10 +#pytest-xdist flake8>=3.6 check-manifest twine==3.2 diff --git a/setup.cfg b/setup.cfg index ab1e1e8c1addc7..3775eb0070f7ca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,11 +47,6 @@ omit = pytorch_lightning/utilities/xla_device_utils.py pytorch_lightning/utilities/distributed.py pytorch_lightning/tuner/auto_gpu_select.py - # TODO: temporary, until accelerator refactor is finished - pytorch_lightning/accelerators/accelerator.py - pytorch_lightning/plugins/training_type/*.py - pytorch_lightning/plugins/precision/*.py - pytorch_lightning/plugins/base_plugin.py [flake8] diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 3fe9d6c0e277c6..c381b5e9feeb6a 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -14,7 +14,7 @@ # Running special tests set -e export PL_RUNNING_SPECIAL_TESTS=1 -DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" +DEFAULTS="-m coverage run --source pytorch_lightning --append -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_invalid_deepspeed_defaults_no_precision From 741c452551780e110938c8635db496682784be07 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 23 Mar 2021 22:07:48 +0000 Subject: [PATCH 21/22] Fix disabled grads after call to predict (#6657) --- CHANGELOG.md | 3 +++ pytorch_lightning/trainer/trainer.py | 4 ++++ tests/trainer/test_trainer.py | 13 +++++++++++++ 3 files changed, 20 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b3359ace54f92..6a1e85d4add8db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) +- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657)) + + ## [1.2.4] - 2021-03-16 ### Changed diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bb5d6919964e53..dbc493aa76e040 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -800,6 +800,10 @@ def run_predict(self): results = self.predict_loop.on_predict_epoch_end() self.predict_loop.on_predict_end() + + # re-enable grads + torch.set_grad_enabled(True) + return results def run_sanity_check(self, ref_model): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d461d9d152e743..490f205a7bbec2 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1450,6 +1450,19 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): predict(tmpdir, None, None, 1, model=CustomBoringModel()) +def test_trainer_predict_grad(tmpdir): + class CustomBoringModel(BoringModel): + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch.expand_as(batch).grad_fn is None + return super().predict_step(batch, batch_idx, dataloader_idx) + + predict(tmpdir, None, None, 1, model=CustomBoringModel()) + + x = torch.zeros(1, requires_grad=True) + assert x.expand_as(x).grad_fn is not None + + @pytest.mark.parametrize('datamodule', [False, True]) def test_trainer_predict_cpu(tmpdir, datamodule): predict(tmpdir, None, None, 1, datamodule=datamodule) From b1e3dcc607522b06e88d0cb086ab655b49c88b35 Mon Sep 17 00:00:00 2001 From: Eric Rubiel Date: Tue, 23 Mar 2021 19:08:57 -0400 Subject: [PATCH 22/22] Use `pl.LightningModule` in new-project docs (#6656) --- docs/source/starter/new-project.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index f68865f3695c34..7a1164b1bdf3a1 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -83,7 +83,7 @@ Step 1: Define LightningModule .. testcode:: - class LitAutoEncoder(LightningModule): + class LitAutoEncoder(pl.LightningModule): def __init__(self): super().__init__()