diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index 0e84642e2f8109..77363992718af4 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -68,7 +68,7 @@ jobs: - name: Test Package [only] run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest pytorch_lightning -v --cov=pytorch_lightning --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml + coverage run --source pytorch_lightning -m pytest pytorch_lightning -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - name: Upload pytest test results uses: actions/upload-artifact@v2 diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 812d06f3108127..da853bf623d1bf 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -44,7 +44,7 @@ jobs: - name: Tests run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest pytorch_lightning tests --cov=pytorch_lightning -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml + coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml shell: bash -l {0} - name: Upload pytest results diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index ba8d8044149939..5a3e23a37fd0b9 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -134,7 +134,7 @@ jobs: - name: Tests run: | # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 - python -m pytest pytorch_lightning tests --cov=pytorch_lightning -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml + coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml - name: Examples run: | diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml index 5ee4f23b4b3ccd..4488c598c8ac75 100644 --- a/.github/workflows/docs-checks.yml +++ b/.github/workflows/docs-checks.yml @@ -98,7 +98,7 @@ jobs: # First run the same pipeline as Read-The-Docs cd docs make clean - make html --debug --jobs $(nproc) SPHINXOPTS="-W" + make html --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" - name: Upload built docs uses: actions/upload-artifact@v2 diff --git a/.gitignore b/.gitignore index cd0ba22453512d..99939ff7fce0cd 100644 --- a/.gitignore +++ b/.gitignore @@ -157,3 +157,4 @@ tags data MNIST runs +*trace* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21c52539a890d3..45eca43de93ac8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,8 +33,3 @@ repos: hooks: - id: yapf args: [--parallel, --in-place] - - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.790 - hooks: - - id: mypy diff --git a/CHANGELOG.md b/CHANGELOG.md index d696535311c9d3..6a1e85d4add8db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,13 +9,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added `RetrievalMAP` metric, the corresponding functional version `retrieval_average_precision` and a generic superclass for retrieval metrics `RetrievalMetric` ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032)) - - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) + - Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) @@ -37,11 +37,28 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) +- Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370)) + + +- Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](https://github.com/PyTorchLightning/pytorch-lightning/pull/6633)) + + - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) -- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) +- Added `Trainer.predict` config validation ([#6543](https://github.com/PyTorchLightning/pytorch-lightning/pull/6543)) + + +- Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + +- Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + +- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618)) + + +- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) ### Changed @@ -58,6 +75,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) +- Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + +- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + ### Deprecated - `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) @@ -66,6 +89,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) +- Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621)) + + +- Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) + + - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), @@ -80,6 +109,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). [#6584](https://github.com/PyTorchLightning/pytorch-lightning/pull/6584), + [#6636](https://github.com/PyTorchLightning/pytorch-lightning/pull/6636), + + [#6637](https://github.com/PyTorchLightning/pytorch-lightning/pull/6637), + ) @@ -118,6 +151,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565)) + - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) @@ -145,9 +179,21 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) +- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587)) + + +- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576)) + + - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506)) +- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) + + +- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657)) + + ## [1.2.4] - 2021-03-16 ### Changed @@ -168,12 +214,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) -- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587)) - - -- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576)) - - ## [1.2.3] - 2021-03-09 ### Fixed diff --git a/Makefile b/Makefile index d35e0b77f84296..04b08fa2d27d1a 100644 --- a/Makefile +++ b/Makefile @@ -29,4 +29,4 @@ test: clean docs: clean pip install --quiet -r requirements/docs.txt - python -m sphinx -b html -W docs/source docs/build + python -m sphinx -b html -W --keep-going docs/source docs/build diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 9e2ff77563fa01..d88a31ae9775a3 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -78,7 +78,7 @@ jobs: displayName: 'Get legacy checkpoints' - bash: | - python -m pytest pytorch_lightning tests -v --cov=pytorch_lightning --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 + python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=50 displayName: 'Testing: standard' - bash: | @@ -121,4 +121,6 @@ jobs: # cd pl_examples/basic_examples # bash submit_ddp_job.sh # bash submit_ddp2_job.sh + env: + PL_USE_MOCKED_MNIST: "1" displayName: 'Examples' diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index a2a74c7587ae3b..5cdb0b377f2b75 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -267,7 +267,7 @@ Lightning allows multiple ways of training - TPUs (``tpu_cores=8|x``) (tpu or TPU pod) .. note:: - If you request multiple GPUs or nodes without setting a mode, DDP will be automatically used. + If you request multiple GPUs or nodes without setting a mode, DDP Spawn will be automatically used. For a deeper understanding of what Lightning is doing, feel free to read this `guide `_. diff --git a/docs/source/conf.py b/docs/source/conf.py index ccf824bb37d9b7..6163de976da405 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,7 +13,6 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. # import m2r -import builtins import glob import os import shutil @@ -27,10 +26,13 @@ FOLDER_GENERATED = 'generated' SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True)) -if SPHINX_MOCK_REQUIREMENTS: - builtins.__LIGHTNING_SETUP__ = True -import pytorch_lightning # noqa: E402 +try: + from pytorch_lightning import info +except ImportError: + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append(os.path.join(PATH_ROOT, "pytorch_lightning")) + import info # -- Project documents ------------------------------------------------------- @@ -79,13 +81,13 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # -- Project information ----------------------------------------------------- project = 'PyTorch Lightning' -copyright = pytorch_lightning.__copyright__ -author = pytorch_lightning.__author__ +copyright = info.__copyright__ +author = info.__author__ # The short X.Y version -version = pytorch_lightning.__version__ +version = info.__version__ # The full version, including alpha/beta/rc tags -release = pytorch_lightning.__version__ +release = info.__version__ # -- General configuration --------------------------------------------------- @@ -176,8 +178,8 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # documentation. html_theme_options = { - 'pytorch_project': pytorch_lightning.__homepage__, - 'canonical_url': pytorch_lightning.__homepage__, + 'pytorch_project': info.__homepage__, + 'canonical_url': info.__homepage__, 'collapse_navigation': False, 'display_version': True, 'logo_only': False, @@ -279,6 +281,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: 'torch': ('https://pytorch.org/docs/stable/', None), 'numpy': ('https://numpy.org/doc/stable/', None), 'PIL': ('https://pillow.readthedocs.io/en/stable/', None), + 'torchmetrics': ('https://torchmetrics.readthedocs.io/en/stable/', None), } # -- Options for todo extension ---------------------------------------------- @@ -331,6 +334,7 @@ def package_list_from_file(file): } MOCK_PACKAGES = [] if SPHINX_MOCK_REQUIREMENTS: + MOCK_PACKAGES += ['fairscale'] # mock also base packages when we are on RTD since we don't install them there MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements.txt')) MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'extra.txt')) diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index c65894367a39e8..551b8182caa7d1 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -882,8 +882,8 @@ Or maybe we have a model that we use to do generation generated_imgs = model(z) -To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict`` function -By default, LightningModule ``predict`` calls forward, but it can be overriden to add any processing logic. +To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict_step`` function +By default, LightningModule ``predict_step`` calls forward, but it can be overriden to add any processing logic. .. code-block:: python @@ -893,7 +893,7 @@ By default, LightningModule ``predict`` calls forward, but it can be overriden t imgs = self.decoder(z) return imgs - def predict(self, batch, batch_idx: int , dataloader_idx: int = None): + def predict_step(self, batch, batch_idx: int , dataloader_idx: int = None): return self(batch) diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst index f68865f3695c34..7a1164b1bdf3a1 100644 --- a/docs/source/starter/new-project.rst +++ b/docs/source/starter/new-project.rst @@ -83,7 +83,7 @@ Step 1: Define LightningModule .. testcode:: - class LitAutoEncoder(LightningModule): + class LitAutoEncoder(pl.LightningModule): def __init__(self): super().__init__() diff --git a/pl_examples/__init__.py b/pl_examples/__init__.py index ffd60f9ed71af4..150ac309ddcebd 100644 --- a/pl_examples/__init__.py +++ b/pl_examples/__init__.py @@ -15,10 +15,10 @@ _DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets') _TORCHVISION_AVAILABLE = _module_available("torchvision") -_TORCHVISION_MNIST_AVAILABLE = True +_TORCHVISION_MNIST_AVAILABLE = not bool(os.environ.get("PL_USE_MOCKED_MNIST", False)) _DALI_AVAILABLE = _module_available("nvidia.dali") -if _TORCHVISION_AVAILABLE: +if _TORCHVISION_MNIST_AVAILABLE: try: from torchvision.datasets.mnist import MNIST MNIST(_DATASETS_PATH, download=True) diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py new file mode 100644 index 00000000000000..ca640a96f9588b --- /dev/null +++ b/pl_examples/basic_examples/profiler_example.py @@ -0,0 +1,102 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script will generate 2 traces: one for `training_step` and one for `validation_step`. +The traces can be visualized in 2 ways: +* With Chrome: + 1. Open Chrome and copy/paste this url: `chrome://tracing/`. + 2. Once tracing opens, click on `Load` at the top-right and load one of the generated traces. +* With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin) + 1. pip install tensorboard torch-tb-profiler + 2. tensorboard --logdir={FOLDER} +""" + +import sys +from argparse import ArgumentParser + +import torch +import torchvision +import torchvision.models as models +import torchvision.transforms as T + +from pl_examples import cli_lightning_logo +from pytorch_lightning import LightningDataModule, LightningModule, Trainer + +DEFAULT_CMD_LINE = ( + "--max_epochs", + "1", + "--limit_train_batches", + "15", + "--limit_val_batches", + "15", + "--profiler", + "pytorch", + "--gpus", + f"{int(torch.cuda.is_available())}", +) + + +class ModelToProfile(LightningModule): + + def __init__(self, model): + super().__init__() + self.model = model + self.criterion = torch.nn.CrossEntropyLoss() + + def training_step(self, batch, batch_idx): + inputs, labels = batch + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + self.log("train_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + inputs, labels = batch + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + self.log("val_loss", loss) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9) + + +class CIFAR10DataModule(LightningDataModule): + + transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]) + + def train_dataloader(self, *args, **kwargs): + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=self.transform) + return torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0) + + def val_dataloader(self, *args, **kwargs): + valset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=self.transform) + return torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=0) + + +def cli_main(): + + parser = ArgumentParser() + parser = Trainer.add_argparse_args(parser) + cmd_line = None if len(sys.argv) != 1 else DEFAULT_CMD_LINE + args = parser.parse_args(args=cmd_line) + + model = ModelToProfile(models.resnet50(pretrained=True)) + datamodule = CIFAR10DataModule() + trainer = Trainer(**vars(args)) + trainer.fit(model, datamodule=datamodule) + + +if __name__ == '__main__': + cli_lightning_logo() + cli_main() diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 569078c994ba4e..b9660475bf2f7d 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -2,42 +2,17 @@ import logging import os -import sys -import time -_this_year = time.strftime("%Y") -__version__ = '1.3.0dev' -__author__ = 'William Falcon et al.' -__author_email__ = 'waf2107@columbia.edu' -__license__ = 'Apache-2.0' -__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' -__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' -# this has to be simple string, see: https://github.com/pypa/twine/issues/522 -__docs__ = ( - "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." - " Scale your models. Write less boilerplate." +from pytorch_lightning.info import ( # noqa: F401 + __author__, + __author_email__, + __copyright__, + __docs__, + __homepage__, + __license__, + __version__, ) -__long_docs__ = """ -Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. - It's more of a style-guide than a framework. -In Lightning, you organize your code into 3 distinct categories: - -1. Research code (goes in the LightningModule). -2. Engineering code (you delete, and is handled by the Trainer). -3. Non-essential research code (logging, etc. this goes in Callbacks). - -Although your research/production project might start simple, once you add things like GPU AND TPU training, - 16-bit precision, etc, you end up spending more time engineering than researching. - Lightning automates AND rigorously tests those parts for you. - -Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. - -Documentation -------------- -- https://pytorch-lightning.readthedocs.io/en/latest -- https://pytorch-lightning.readthedocs.io/en/stable -""" _root_logger = logging.getLogger() _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -50,32 +25,20 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) -try: - # This variable is injected in the __builtins__ by the build - # process. It used to enable importing subpackages of skimage when - # the binaries are not built - _ = None if __LIGHTNING_SETUP__ else None -except NameError: - __LIGHTNING_SETUP__: bool = False - -if __LIGHTNING_SETUP__: # pragma: no-cover - sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover - # We are not importing the rest of the lightning during the build process, as it may not be compiled yet -else: - from pytorch_lightning import metrics - from pytorch_lightning.callbacks import Callback - from pytorch_lightning.core import LightningDataModule, LightningModule - from pytorch_lightning.trainer import Trainer - from pytorch_lightning.utilities.seed import seed_everything - - __all__ = [ - 'Trainer', - 'LightningDataModule', - 'LightningModule', - 'Callback', - 'seed_everything', - 'metrics', - ] +from pytorch_lightning import metrics # noqa: E402 +from pytorch_lightning.callbacks import Callback # noqa: E402 +from pytorch_lightning.core import LightningDataModule, LightningModule # noqa: E402 +from pytorch_lightning.trainer import Trainer # noqa: E402 +from pytorch_lightning.utilities.seed import seed_everything # noqa: E402 + +__all__ = [ + 'Trainer', + 'LightningDataModule', + 'LightningModule', + 'Callback', + 'seed_everything', + 'metrics', +] # for compatibility with namespace packages __import__('pkg_resources').declare_namespace(__name__) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index ceb9d98505acc5..1dcd541ca0610c 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -21,6 +21,7 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum @@ -85,7 +86,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None: model: the LightningModule """ self.setup_training_type_plugin(self.training_type_plugin, model) - self.setup_optimizers(trainer) + if not self.training_type_plugin.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) self.setup_precision_plugin(self.precision_plugin) def start_training(self, trainer: 'Trainer') -> None: @@ -97,12 +99,14 @@ def start_evaluating(self, trainer: 'Trainer') -> None: def start_predicting(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_predicting(trainer) - def pre_dispatch(self) -> None: + def pre_dispatch(self, trainer: 'Trainer') -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.pre_dispatch() + if self.training_type_plugin.setup_optimizers_in_pre_dispatch: + self.setup_optimizers(trainer) self.precision_plugin.pre_dispatch() - def post_dispatch(self) -> None: + def post_dispatch(self, trainer: 'Trainer') -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.post_dispatch() self.precision_plugin.post_dispatch() @@ -216,7 +220,7 @@ def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context(): return self.training_type_plugin.test_step(*args) - def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: + def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: """The actual predict step. Args: @@ -232,7 +236,7 @@ def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: args[0] = batch with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context(): - return self.training_type_plugin.predict(*args) + return self.training_type_plugin.predict_step(*args) def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: """A hook to do something at the end of the training step @@ -356,7 +360,12 @@ def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None: def to_device(self, batch: Any) -> Any: """Pushes the batch to the root device""" - return self.batch_to_device(batch, self.root_device) + # Todo (tchaton) Better fix + is_dict = isinstance(batch, dict) + if is_dict: + batch = [batch] + batch = self.batch_to_device(batch, self.root_device) + return batch[0] if is_dict else batch @property def amp_backend(self) -> Optional[LightningEnum]: @@ -429,3 +438,31 @@ def results(self) -> Any: In distributed training, we make sure to transfer the results to the appropriate master process. """ return self.training_type_plugin.results + + # todo: remove in v1.5 + def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + """ + Attaches the training type plugin to the accelerator. + Also transfers ownership of the model to this plugin + + .. deprecated::v1.3 + Will be removed in v1.5.0. + """ + rank_zero_warn( + 'Accelerator method `connect_training_type_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.' + ) + self.setup_training_type_plugin(plugin, model) + + # todo: remove in v1.5 + def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: + """Attaches the precision plugin to the accelerator + + .. deprecated::v1.3 + Will be removed in v1.5.0. + """ + rank_zero_warn( + 'Accelerator method `connect_precision_plugin` was deprecated in v1.3.' + ' It will be removed in v1.5.' + ) + self.setup_precision_plugin(plugin) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 74e57e2b5642e0..78db9a7dba12eb 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -39,8 +39,7 @@ class tqdm(_tqdm): """ - Custom tqdm progressbar where we append 0 to floating points/strings to - prevent the progress bar from flickering + Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering """ @staticmethod diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 9624f94652713b..8c68cc96eabc23 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -282,6 +282,18 @@ def on_test_end(self) -> None: """ # do something at the end of testing + def on_predict_start(self) -> None: + """ + Called at the beginning of predicting. + """ + # do something at the start of predicting + + def on_predict_end(self) -> None: + """ + Called at the end of predicting. + """ + # do something at the end of predicting + def on_before_zero_grad(self, optimizer: Optimizer) -> None: """ Called after optimizer.step() and before optimizer.zero_grad(). @@ -594,6 +606,18 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: will have an argument ``dataloader_idx`` which matches the order here. """ + def on_train_dataloader(self) -> None: + """Called before requesting the train dataloader.""" + + def on_val_dataloader(self) -> None: + """Called before requesting the val dataloader.""" + + def on_test_dataloader(self) -> None: + """Called before requesting the test dataloader.""" + + def on_predict_dataloader(self) -> None: + """Called before requesting the predict dataloader.""" + def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: """ Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 5f364169a9b52d..8e0718ab891dc9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1054,7 +1054,7 @@ def test_epoch_end(self, outputs): self.log('final_metric', final_value) """ - def predict(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): """ Use this function with trainer.predict(...). Override if you need to add any processing logic. """ diff --git a/pytorch_lightning/info.py b/pytorch_lightning/info.py new file mode 100644 index 00000000000000..0e7a1c25a74f15 --- /dev/null +++ b/pytorch_lightning/info.py @@ -0,0 +1,35 @@ +import time + +_this_year = time.strftime("%Y") +__version__ = '1.3.0dev' +__author__ = 'William Falcon et al.' +__author_email__ = 'waf2107@columbia.edu' +__license__ = 'Apache-2.0' +__copyright__ = f'Copyright (c) 2018-{_this_year}, {__author__}.' +__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning' +# this has to be simple string, see: https://github.com/pypa/twine/issues/522 +__docs__ = ( + "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." + " Scale your models. Write less boilerplate." +) +__long_docs__ = """ +Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. + It's more of a style-guide than a framework. + +In Lightning, you organize your code into 3 distinct categories: + +1. Research code (goes in the LightningModule). +2. Engineering code (you delete, and is handled by the Trainer). +3. Non-essential research code (logging, etc. this goes in Callbacks). + +Although your research/production project might start simple, once you add things like GPU AND TPU training, + 16-bit precision, etc, you end up spending more time engineering than researching. + Lightning automates AND rigorously tests those parts for you. + +Overall, Lightning guarantees rigorously tested, correct, modern best practices for the automated parts. + +Documentation +------------- +- https://pytorch-lightning.readthedocs.io/en/latest +- https://pytorch-lightning.readthedocs.io/en/stable +""" diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 1701389cd1c64f..3b31dad5d3411d 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -28,7 +28,6 @@ from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401 from pytorch_lightning.metrics.functional.image_gradients import image_gradients # noqa: F401 from pytorch_lightning.metrics.functional.iou import iou # noqa: F401 -from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision # noqa: F401 from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error # noqa: F401 from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401 from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401 diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index fa8d43c06c7efd..bcfe698bf4c5ea 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -11,77 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple, Union +from typing import Sequence, Union import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import explained_variance as _explained_variance - -def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - _check_same_shape(preds, target) - return preds, target - - -def _explained_variance_compute( - preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - diff_avg = torch.mean(target - preds, dim=0) - numerator = torch.mean((target - preds - diff_avg)**2, dim=0) - - target_avg = torch.mean(target, dim=0) - denominator = torch.mean((target - target_avg)**2, dim=0) - - # Take care of division by zero - nonzero_numerator = numerator != 0 - nonzero_denominator = denominator != 0 - valid_score = nonzero_numerator & nonzero_denominator - output_scores = torch.ones_like(diff_avg) - output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score]) - output_scores[nonzero_numerator & ~nonzero_denominator] = 0. - - # Decide what to do in multioutput case - # Todo: allow user to pass in tensor with weights - if multioutput == 'raw_values': - return output_scores - if multioutput == 'uniform_average': - return torch.mean(output_scores) - if multioutput == 'variance_weighted': - denom_sum = torch.sum(denominator) - return torch.sum(denominator / denom_sum * output_scores) +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_explained_variance, ver_deprecate="1.3.0", ver_remove="1.5.0") def explained_variance( preds: torch.Tensor, target: torch.Tensor, multioutput: str = 'uniform_average', ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: """ - Computes explained variance. - - Args: - preds: estimated labels - target: ground truth labels - multioutput: Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is `'uniform_average'`.): - - * `'raw_values'` returns full set of scores - * `'uniform_average'` scores are uniformly averaged - * `'variance_weighted'` scores are weighted by their individual variances - - Example: - - >>> from pytorch_lightning.metrics.functional import explained_variance - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> explained_variance(preds, target) - tensor(0.9572) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> explained_variance(preds, target, multioutput='raw_values') - tensor([0.9677, 1.0000]) + .. deprecated:: + Use :func:`torchmetrics.functional.explained_variance`. Will be removed in v1.5.0. """ - preds, target = _explained_variance_update(preds, target) - return _explained_variance_compute(preds, target, multioutput) diff --git a/pytorch_lightning/metrics/functional/ir_average_precision.py b/pytorch_lightning/metrics/functional/ir_average_precision.py deleted file mode 100644 index 83b14a21c5553d..00000000000000 --- a/pytorch_lightning/metrics/functional/ir_average_precision.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - - -def retrieval_average_precision(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - r""" - Computes average precision (for information retrieval), as explained - `here `_. - - `preds` and `target` should be of the same shape and live on the same device. If no `target` is ``True``, - 0 is returned. Target must be of type `bool` or `int`, otherwise an error is raised. - - Args: - preds: estimated probabilities of each document to be relevant. - target: ground truth about each document being relevant or not. Requires `bool` or `int` tensor. - - Return: - a single-value tensor with the average precision (AP) of the predictions `preds` wrt the labels `target`. - - Example: - >>> preds = torch.tensor([0.2, 0.3, 0.5]) - >>> target = torch.tensor([True, False, True]) - >>> retrieval_average_precision(preds, target) - tensor(0.8333) - """ - - if preds.shape != target.shape or preds.device != target.device: - raise ValueError("`preds` and `target` must have the same shape and live on the same device") - - if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64): - raise ValueError("`target` must be a tensor of booleans or integers") - - if target.dtype is not torch.bool: - target = target.bool() - - if target.sum() == 0: - return torch.tensor(0, device=preds.device) - - target = target[torch.argsort(preds, dim=-1, descending=True)] - positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0] - res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean() - return res diff --git a/pytorch_lightning/metrics/functional/mean_absolute_error.py b/pytorch_lightning/metrics/functional/mean_absolute_error.py index 2bd8f125ecb9e4..85aa07c802eca5 100644 --- a/pytorch_lightning/metrics/functional/mean_absolute_error.py +++ b/pytorch_lightning/metrics/functional/mean_absolute_error.py @@ -11,40 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import mean_absolute_error as _mean_absolute_error - -def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_abs_error = torch.sum(torch.abs(preds - target)) - n_obs = target.numel() - return sum_abs_error, n_obs - - -def _mean_absolute_error_compute(sum_abs_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_abs_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_absolute_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean absolute error - - Args: - pred: estimated labels - target: ground truth labels - - Return: - Tensor with MAE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_absolute_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_absolute_error(x, y) - tensor(0.2500) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_absolute_error`. Will be removed in v1.5.0. """ - sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) - return _mean_absolute_error_compute(sum_abs_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_relative_error.py b/pytorch_lightning/metrics/functional/mean_relative_error.py index bfe5eb6b847d70..be21371bdc91a3 100644 --- a/pytorch_lightning/metrics/functional/mean_relative_error.py +++ b/pytorch_lightning/metrics/functional/mean_relative_error.py @@ -11,43 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional.regression.mean_relative_error import mean_relative_error as _mean_relative_error - -def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - target_nz = target.clone() - target_nz[target == 0] = 1 - sum_rltv_error = torch.sum(torch.abs((preds - target) / target_nz)) - n_obs = target.numel() - return sum_rltv_error, n_obs - - -def _mean_relative_error_compute(sum_rltv_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_rltv_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_relative_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_relative_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean relative error - - Args: - pred: estimated labels - target: ground truth labels - - Return: - Tensor with mean relative error - - Example: - - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_relative_error(x, y) - tensor(0.1250) - + .. deprecated:: + Use :func:`torchmetrics.functional.regression.mean_relative_error`. Will be removed in v1.5.0. """ - sum_rltv_error, n_obs = _mean_relative_error_update(preds, target) - return _mean_relative_error_compute(sum_rltv_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_squared_error.py b/pytorch_lightning/metrics/functional/mean_squared_error.py index 66c0aadef06510..9d1850dcd8689a 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_error.py @@ -11,40 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import mean_squared_error as _mean_squared_error - -def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_squared_error = torch.sum(torch.pow(preds - target, 2)) - n_obs = target.numel() - return sum_squared_error, n_obs - - -def _mean_squared_error_compute(sum_squared_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_squared_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_squared_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean squared error - - Args: - preds: estimated labels - target: ground truth labels - - Return: - Tensor with MSE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_squared_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_squared_error(x, y) - tensor(0.2500) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_squared_error`. Will be removed in v1.5.0. """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) - return _mean_squared_error_compute(sum_squared_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/mean_squared_log_error.py b/pytorch_lightning/metrics/functional/mean_squared_log_error.py index baec63c7248f27..56654ea47daf28 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_log_error.py @@ -11,40 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import mean_squared_log_error as _mean_squared_log_error - -def _mean_squared_log_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: - _check_same_shape(preds, target) - sum_squared_log_error = torch.sum(torch.pow(torch.log1p(preds) - torch.log1p(target), 2)) - n_obs = target.numel() - return sum_squared_log_error, n_obs - - -def _mean_squared_log_error_compute(sum_squared_log_error: torch.Tensor, n_obs: int) -> torch.Tensor: - return sum_squared_log_error / n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_mean_squared_log_error, ver_deprecate="1.3.0", ver_remove="1.5.0") def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ - Computes mean squared log error - - Args: - preds: estimated labels - target: ground truth labels - - Return: - Tensor with RMSLE - - Example: - >>> from pytorch_lightning.metrics.functional import mean_squared_log_error - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mean_squared_log_error(x, y) - tensor(0.0207) + .. deprecated:: + Use :func:`torchmetrics.functional.mean_squared_log_error`. Will be removed in v1.5.0. """ - sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) - return _mean_squared_log_error_compute(sum_squared_log_error, n_obs) diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py index 0b50ea092b7fad..dd7aa44ae628ea 100644 --- a/pytorch_lightning/metrics/functional/psnr.py +++ b/pytorch_lightning/metrics/functional/psnr.py @@ -14,46 +14,12 @@ from typing import Optional, Tuple, Union import torch -from torchmetrics.utilities import reduce +from torchmetrics.functional import psnr as _psnr -from pytorch_lightning.utilities import rank_zero_warn - - -def _psnr_compute( - sum_squared_error: torch.Tensor, - n_obs: torch.Tensor, - data_range: torch.Tensor, - base: float = 10.0, - reduction: str = 'elementwise_mean', -) -> torch.Tensor: - psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs) - psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) - return reduce(psnr, reduction=reduction) - - -def _psnr_update(preds: torch.Tensor, - target: torch.Tensor, - dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if dim is None: - sum_squared_error = torch.sum(torch.pow(preds - target, 2)) - n_obs = torch.tensor(target.numel(), device=target.device) - return sum_squared_error, n_obs - - sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim) - - if isinstance(dim, int): - dim_list = [dim] - else: - dim_list = list(dim) - if not dim_list: - n_obs = torch.tensor(target.numel(), device=target.device) - else: - n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod() - n_obs = n_obs.expand_as(sum_squared_error) - - return sum_squared_error, n_obs +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_psnr, ver_deprecate="1.3.0", ver_remove="1.5.0") def psnr( preds: torch.Tensor, target: torch.Tensor, @@ -63,50 +29,6 @@ def psnr( dim: Optional[Union[int, Tuple[int, ...]]] = None, ) -> torch.Tensor: """ - Computes the peak signal-to-noise ratio - - Args: - preds: estimated signal - target: groun truth signal - data_range: - the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given - when ``dim`` is not None. - base: a base of a logarithm to use (default: 10) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - dim: - Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is - None meaning scores will be reduced across all dimensions. - Return: - Tensor with PSNR score - - Raises: - ValueError: - If ``dim`` is not ``None`` and ``data_range`` is not provided. - - Example: - >>> from pytorch_lightning.metrics.functional import psnr - >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) - >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - >>> psnr(pred, target) - tensor(2.5527) - + .. deprecated:: + Use :func:`torchmetrics.functional.psnr`. Will be removed in v1.5.0. """ - if dim is None and reduction != 'elementwise_mean': - rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') - - if data_range is None: - if dim is not None: - # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate - # `data_range` in the future. - raise ValueError("The `data_range` must be given when `dim` is not None.") - - data_range = target.max() - target.min() - else: - data_range = torch.tensor(float(data_range)) - sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim) - return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction) diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index d3f1090564a88b..49273d9cefaed9 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -11,133 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple import torch -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import r2score as _r2score -from pytorch_lightning.utilities import rank_zero_warn - - -def _r2score_update( - preds: torch.tensor, - target: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - _check_same_shape(preds, target) - if preds.ndim > 2: - raise ValueError( - 'Expected both prediction and target to be 1D or 2D tensors,' - f' but recevied tensors with dimension {preds.shape}' - ) - if len(preds) < 2: - raise ValueError('Needs atleast two samples to calculate r2 score.') - - sum_error = torch.sum(target, dim=0) - sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0) - residual = torch.sum(torch.pow(target - preds, 2.0), dim=0) - total = target.size(0) - - return sum_squared_error, sum_error, residual, total - - -def _r2score_compute( - sum_squared_error: torch.Tensor, - sum_error: torch.Tensor, - residual: torch.Tensor, - total: torch.Tensor, - adjusted: int = 0, - multioutput: str = "uniform_average" -) -> torch.Tensor: - mean_error = sum_error / total - diff = sum_squared_error - sum_error * mean_error - raw_scores = 1 - (residual / diff) - - if multioutput == "raw_values": - r2score = raw_scores - elif multioutput == "uniform_average": - r2score = torch.mean(raw_scores) - elif multioutput == "variance_weighted": - diff_sum = torch.sum(diff) - r2score = torch.sum(diff / diff_sum * raw_scores) - else: - raise ValueError( - 'Argument `multioutput` must be either `raw_values`,' - f' `uniform_average` or `variance_weighted`. Received {multioutput}.' - ) - - if adjusted < 0 or not isinstance(adjusted, int): - raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.') - - if adjusted != 0: - if adjusted > total - 1: - rank_zero_warn( - "More independent regressions than datapoints in" - " adjusted r2 score. Falls back to standard r2 score.", UserWarning - ) - elif adjusted == total - 1: - rank_zero_warn("Division by zero in adjusted r2 score. Falls back to" " standard r2 score.", UserWarning) - else: - r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1) - return r2score +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_r2score, ver_deprecate="1.3.0", ver_remove="1.5.0") def r2score( preds: torch.Tensor, target: torch.Tensor, adjusted: int = 0, multioutput: str = "uniform_average", ) -> torch.Tensor: - r""" - Computes r2 score also known as `coefficient of determination - `_: - - .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} - - where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and - :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate - adjusted r2 score given by - - .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} - - where the parameter :math:`k` (the number of independent regressors) should - be provided as the ``adjusted`` argument. - - Args: - preds: estimated labels - target: ground truth labels - adjusted: number of independent regressors for calculating adjusted r2 score. - Default 0 (standard r2 score). - multioutput: Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is ``'uniform_average'``.): - - * ``'raw_values'`` returns full set of scores - * ``'uniform_average'`` scores are uniformly averaged - * ``'variance_weighted'`` scores are weighted by their individual variances - - Raises: - ValueError: - If both ``preds`` and ``targets`` are not ``1D`` or ``2D`` tensors. - ValueError: - If ``len(preds)`` is less than ``2`` - since at least ``2`` sampels are needed to calculate r2 score. - ValueError: - If ``multioutput`` is not one of ``raw_values``, - ``uniform_average`` or ``variance_weighted``. - ValueError: - If ``adjusted`` is not an ``integer`` greater than ``0``. - - Example: - - >>> from pytorch_lightning.metrics.functional import r2score - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> r2score(preds, target) - tensor(0.9486) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> r2score(preds, target, multioutput='raw_values') - tensor([0.9654, 0.9082]) """ - sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) - return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput) + .. deprecated:: + Use :func:`torchmetrics.functional.r2score`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/ssim.py b/pytorch_lightning/metrics/functional/ssim.py index 4899a3ad3be4dc..8809fec8d8ff1c 100644 --- a/pytorch_lightning/metrics/functional/ssim.py +++ b/pytorch_lightning/metrics/functional/ssim.py @@ -11,107 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence import torch -from torch.nn import functional as F -from torchmetrics.utilities import reduce -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.functional import ssim as _ssim - -def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.device): - dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device) - gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) - return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) - - -def _gaussian_kernel( - channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device -): - gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device) - gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device) - kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) - - return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) - - -def _ssim_update( - preds: torch.Tensor, - target: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - if preds.dtype != target.dtype: - raise TypeError( - "Expected `preds` and `target` to have the same data type." - f" Got pred: {preds.dtype} and target: {target.dtype}." - ) - _check_same_shape(preds, target) - if len(preds.shape) != 4: - raise ValueError( - "Expected `preds` and `target` to have BxCxHxW shape." - f" Got pred: {preds.shape} and target: {target.shape}." - ) - return preds, target - - -def _ssim_compute( - preds: torch.Tensor, - target: torch.Tensor, - kernel_size: Sequence[int] = (11, 11), - sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", - data_range: Optional[float] = None, - k1: float = 0.01, - k2: float = 0.03, -): - if len(kernel_size) != 2 or len(sigma) != 2: - raise ValueError( - "Expected `kernel_size` and `sigma` to have the length of two." - f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}." - ) - - if any(x % 2 == 0 or x <= 0 for x in kernel_size): - raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.") - - if any(y <= 0 for y in sigma): - raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.") - - if data_range is None: - data_range = max(preds.max() - preds.min(), target.max() - target.min()) - - c1 = pow(k1 * data_range, 2) - c2 = pow(k2 * data_range, 2) - device = preds.device - - channel = preds.size(1) - dtype = preds.dtype - kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device) - pad_w = (kernel_size[0] - 1) // 2 - pad_h = (kernel_size[1] - 1) // 2 - - preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode='reflect') - target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode='reflect') - - input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) - outputs = F.conv2d(input_list, kernel, groups=channel) - output_list = [outputs[x * preds.size(0):(x + 1) * preds.size(0)] for x in range(len(outputs))] - - mu_pred_sq = output_list[0].pow(2) - mu_target_sq = output_list[1].pow(2) - mu_pred_target = output_list[0] * output_list[1] - - sigma_pred_sq = output_list[2] - mu_pred_sq - sigma_target_sq = output_list[3] - mu_target_sq - sigma_pred_target = output_list[4] - mu_pred_target - - upper = 2 * sigma_pred_target + c2 - lower = sigma_pred_sq + sigma_target_sq + c2 - - ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower) - ssim_idx = ssim_idx[..., pad_h:-pad_h, pad_w:-pad_w] - - return reduce(ssim_idx, reduction) +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_ssim, ver_deprecate="1.3.0", ver_remove="1.5.0") def ssim( preds: torch.Tensor, target: torch.Tensor, @@ -123,44 +31,6 @@ def ssim( k2: float = 0.03, ) -> torch.Tensor: """ - Computes Structual Similarity Index Measure - - Args: - preds: estimated image - target: ground truth image - kernel_size: size of the gaussian kernel (default: (11, 11)) - sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - data_range: Range of the image. If ``None``, it is determined from the image (max - min) - k1: Parameter of SSIM. Default: 0.01 - k2: Parameter of SSIM. Default: 0.03 - - Return: - Tensor with SSIM score - - Raises: - TypeError: - If ``preds`` and ``target`` don't have the same data type. - ValueError: - If ``preds`` and ``target`` don't have ``BxCxHxW shape``. - ValueError: - If the length of ``kernel_size`` or ``sigma`` is not ``2``. - ValueError: - If one of the elements of ``kernel_size`` is not an ``odd positive number``. - ValueError: - If one of the elements of ``sigma`` is not a ``positive number``. - - Example: - >>> from pytorch_lightning.metrics.functional import ssim - >>> preds = torch.rand([16, 1, 16, 16]) - >>> target = preds * 0.75 - >>> ssim(preds, target) - tensor(0.9219) + .. deprecated:: + Use :func:`torchmetrics.functional.ssim`. Will be removed in v1.5.0. """ - preds, target = _ssim_update(preds, target) - return _ssim_compute(preds, target, kernel_size, sigma, reduction, data_range, k1, k2) diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index 8b0259694ef4c6..4f820718545cba 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -13,72 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import ExplainedVariance as _ExplainedVariance -from pytorch_lightning.metrics.functional.explained_variance import ( - _explained_variance_compute, - _explained_variance_update, -) -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated -class ExplainedVariance(Metric): - r""" - Computes `explained variance - `_: - - .. math:: \text{ExplainedVariance} = 1 - \frac{\text{Var}(y - \hat{y})}{\text{Var}(y)} - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a - tensor of predictions. - - Forward accepts - - - ``preds`` (float tensor): ``(N,)`` or ``(N, ...)`` (multioutput) - - ``target`` (long tensor): ``(N,)`` or ``(N, ...)`` (multioutput) - - In the case of multioutput, as default the variances will be uniformly - averaged over the additional dimensions. Please see argument `multioutput` - for changing this behavior. - - Args: - multioutput: - Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is `'uniform_average'`.): - - * `'raw_values'` returns full set of scores - * `'uniform_average'` scores are uniformly averaged - * `'variance_weighted'` scores are weighted by their individual variances - - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. - - Example: - - >>> from pytorch_lightning.metrics import ExplainedVariance - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> explained_variance = ExplainedVariance() - >>> explained_variance(preds, target) - tensor(0.9572) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> explained_variance = ExplainedVariance(multioutput='raw_values') - >>> explained_variance(preds, target) - tensor([0.9677, 1.0000]) - """ +class ExplainedVariance(_ExplainedVariance): + @deprecated(target=_ExplainedVariance, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, multioutput: str = 'uniform_average', @@ -87,43 +29,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') - if multioutput not in allowed_multioutput: - raise ValueError( - f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}' - ) - self.multioutput = multioutput - self.add_state("y", default=[], dist_reduce_fx=None) - self.add_state("y_pred", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `ExplainedVariance` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - preds, target = _explained_variance_update(preds, target) - self.y_pred.append(preds) - self.y.append(target) + This implementation refers to :class:`~torchmetrics.ExplainedVariance`. - def compute(self): - """ - Computes explained variance over state. + .. deprecated:: + Use :class:`~torchmetrics.ExplainedVariance`. Will be removed in v1.5.0. """ - preds = torch.cat(self.y_pred, dim=0) - target = torch.cat(self.y, dim=0) - return _explained_variance_compute(preds, target, self.multioutput) diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index 484ccbe83284e9..8510275c127d78 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -13,42 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import MeanAbsoluteError as _MeanAbsoluteError -from pytorch_lightning.metrics.functional.mean_absolute_error import ( - _mean_absolute_error_compute, - _mean_absolute_error_update, -) +from pytorch_lightning.utilities.deprecation import deprecated -class MeanAbsoluteError(Metric): - r""" - Computes `mean absolute error `_ (MAE): - - .. math:: \text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} | - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanAbsoluteError - >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> mean_absolute_error = MeanAbsoluteError() - >>> mean_absolute_error(preds, target) - tensor(0.5000) - """ +class MeanAbsoluteError(_MeanAbsoluteError): + @deprecated(target=_MeanAbsoluteError, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, compute_on_step: bool = True, @@ -56,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) + This implementation refers to :class:`~torchmetrics.MeanAbsoluteError`. - self.sum_abs_error += sum_abs_error - self.total += n_obs - - def compute(self): - """ - Computes mean absolute error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanAbsoluteError`. Will be removed in v1.5.0. """ - return _mean_absolute_error_compute(self.sum_abs_error, self.total) diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index c26371514e7cd6..cbe09faf0046ce 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -13,43 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import MeanSquaredError as _MeanSquaredError -from pytorch_lightning.metrics.functional.mean_squared_error import ( - _mean_squared_error_compute, - _mean_squared_error_update, -) +from pytorch_lightning.utilities.deprecation import deprecated -class MeanSquaredError(Metric): - r""" - Computes `mean squared error `_ (MSE): - - .. math:: \text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y_i})^2 - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanSquaredError - >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) - >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) - >>> mean_squared_error = MeanSquaredError() - >>> mean_squared_error(preds, target) - tensor(0.8750) - - """ +class MeanSquaredError(_MeanSquaredError): + @deprecated(target=_MeanSquaredError, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, compute_on_step: bool = True, @@ -57,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) - - self.sum_squared_error += sum_squared_error - self.total += n_obs + This implementation refers to :class:`~torchmetrics.MeanSquaredError`. - def compute(self): - """ - Computes mean squared error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanSquaredError`. Will be removed in v1.5.0. """ - return _mean_squared_error_compute(self.sum_squared_error, self.total) diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index caaf09a3663ffa..795d6f5409abfb 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -13,45 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import MeanSquaredLogError as _MeanSquaredLogError -from pytorch_lightning.metrics.functional.mean_squared_log_error import ( - _mean_squared_log_error_compute, - _mean_squared_log_error_update, -) +from pytorch_lightning.utilities.deprecation import deprecated -class MeanSquaredLogError(Metric): - r""" - Computes `mean squared logarithmic error - `_ - (MSLE): - - .. math:: \text{MSLE} = \frac{1}{N}\sum_i^N (\log_e(1 + y_i) - \log_e(1 + \hat{y_i}))^2 - - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - - Args: - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example: - - >>> from pytorch_lightning.metrics import MeanSquaredLogError - >>> target = torch.tensor([2.5, 5, 4, 8]) - >>> preds = torch.tensor([3, 5, 2.5, 7]) - >>> mean_squared_log_error = MeanSquaredLogError() - >>> mean_squared_log_error(preds, target) - tensor(0.0397) - - """ +class MeanSquaredLogError(_MeanSquaredLogError): + @deprecated(target=_MeanSquaredLogError, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, compute_on_step: bool = True, @@ -59,31 +28,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_log_error, n_obs = _mean_squared_log_error_update(preds, target) - - self.sum_squared_log_error += sum_squared_log_error - self.total += n_obs + This implementation refers to :class:`~torchmetrics.MeanSquaredLogError`. - def compute(self): - """ - Compute mean squared logarithmic error over state. + .. deprecated:: + Use :class:`~torchmetrics.MeanSquaredLogError`. Will be removed in v1.5.0. """ - return _mean_squared_log_error_compute(self.sum_squared_log_error, self.total) diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py index 746ff1e52d574f..85b8eceaa24c5b 100644 --- a/pytorch_lightning/metrics/regression/psnr.py +++ b/pytorch_lightning/metrics/regression/psnr.py @@ -11,61 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Tuple, Union -import torch -from torchmetrics import Metric +from torchmetrics import PSNR as _PSNR -from pytorch_lightning import utilities -from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update +from pytorch_lightning.utilities.deprecation import deprecated -class PSNR(Metric): - r""" - Computes `peak signal-to-noise ratio `_ (PSNR): - - .. math:: \text{PSNR}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)}\right) - - Where :math:`\text{MSE}` denotes the `mean-squared-error - `_ function. - - Args: - data_range: - the range of the data. If None, it is determined from the data (max - min). - The ``data_range`` must be given when ``dim`` is not None. - base: a base of a logarithm to use (default: 10) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - dim: - Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is - None meaning scores will be reduced across all dimensions and all batches. - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``dim`` is not ``None`` and ``data_range`` is not given. - - Example: - - >>> from pytorch_lightning.metrics import PSNR - >>> psnr = PSNR() - >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) - >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - >>> psnr(preds, target) - tensor(2.5527) - - """ +class PSNR(_PSNR): + @deprecated(target=_PSNR, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, data_range: Optional[float] = None, @@ -76,71 +31,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - - if dim is None and reduction != 'elementwise_mean': - utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') - - if dim is None: - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - else: - self.add_state("sum_squared_error", default=[]) - self.add_state("total", default=[]) - - if data_range is None: - if dim is not None: - # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to - # calculate `data_range` in the future. - raise ValueError("The `data_range` must be given when `dim` is not None.") - - self.data_range = None - self.add_state("min_target", default=torch.tensor(0.0), dist_reduce_fx=torch.min) - self.add_state("max_target", default=torch.tensor(0.0), dist_reduce_fx=torch.max) - else: - self.register_buffer("data_range", torch.tensor(float(data_range))) - self.base = base - self.reduction = reduction - self.dim = tuple(dim) if isinstance(dim, Sequence) else dim - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + This implementation refers to :class:`~torchmetrics.PSNR`. - Args: - preds: Predictions from model - target: Ground truth values + .. deprecated:: + Use :class:`~torchmetrics.PSNR`. Will be removed in v1.5.0. """ - sum_squared_error, n_obs = _psnr_update(preds, target, dim=self.dim) - if self.dim is None: - if self.data_range is None: - # keep track of min and max target values - self.min_target = min(target.min(), self.min_target) - self.max_target = max(target.max(), self.max_target) - - self.sum_squared_error += sum_squared_error - self.total += n_obs - else: - self.sum_squared_error.append(sum_squared_error) - self.total.append(n_obs) - - def compute(self): - """ - Compute peak signal-to-noise ratio over state. - """ - if self.data_range is not None: - data_range = self.data_range - else: - data_range = self.max_target - self.min_target - - if self.dim is None: - sum_squared_error = self.sum_squared_error - total = self.total - else: - sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error]) - total = torch.cat([values.flatten() for values in self.total]) - return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction) diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index 8156b8bc72d484..52621d6df7c286 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -13,81 +13,14 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch -from torchmetrics import Metric +from torchmetrics import R2Score as _R2Score -from pytorch_lightning.metrics.functional.r2score import _r2score_compute, _r2score_update +from pytorch_lightning.utilities.deprecation import deprecated -class R2Score(Metric): - r""" - Computes r2 score also known as `coefficient of determination - `_: - - .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} - - where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and - :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate - adjusted r2 score given by - - .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} - - where the parameter :math:`k` (the number of independent regressors) should - be provided as the `adjusted` argument. - - Forward accepts - - - ``preds`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) - - ``target`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) - - In the case of multioutput, as default the variances will be uniformly - averaged over the additional dimensions. Please see argument `multioutput` - for changing this behavior. - - Args: - num_outputs: - Number of outputs in multioutput setting (default is 1) - adjusted: - number of independent regressors for calculating adjusted r2 score. - Default 0 (standard r2 score). - multioutput: - Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is ``'uniform_average'``.): - - * ``'raw_values'`` returns full set of scores - * ``'uniform_average'`` scores are uniformly averaged - * ``'variance_weighted'`` scores are weighted by their individual variances - - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Raises: - ValueError: - If ``adjusted`` parameter is not an integer larger or equal to 0. - ValueError: - If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``. - - Example: - - >>> from pytorch_lightning.metrics import R2Score - >>> target = torch.tensor([3, -0.5, 2, 7]) - >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> r2score = R2Score() - >>> r2score(preds, target) - tensor(0.9486) - - >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) - >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) - >>> r2score = R2Score(num_outputs=2, multioutput='raw_values') - >>> r2score(preds, target) - tensor([0.9654, 0.9082]) - """ +class R2Score(_R2Score): + @deprecated(target=_R2Score, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, num_outputs: int = 1, @@ -98,50 +31,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.num_outputs = num_outputs - - if adjusted < 0 or not isinstance(adjusted, int): - raise ValueError('`adjusted` parameter should be an integer larger or equal to 0.') - self.adjusted = adjusted - - allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') - if multioutput not in allowed_multioutput: - raise ValueError( - f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}' - ) - self.multioutput = multioutput - - self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values """ - sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) + This implementation refers to :class:`~torchmetrics.R2Score`. - self.sum_squared_error += sum_squared_error - self.sum_error += sum_error - self.residual += residual - self.total += total - - def compute(self) -> torch.Tensor: - """ - Computes r2 score over the metric states. + .. deprecated:: + Use :class:`~torchmetrics.R2Score`. Will be removed in v1.5.0. """ - return _r2score_compute( - self.sum_squared_error, self.sum_error, self.residual, self.total, self.adjusted, self.multioutput - ) diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py index a3bbab938ffad9..b290808c6fa5e4 100644 --- a/pytorch_lightning/metrics/regression/ssim.py +++ b/pytorch_lightning/metrics/regression/ssim.py @@ -13,43 +13,14 @@ # limitations under the License. from typing import Any, Optional, Sequence -import torch -from torchmetrics import Metric +from torchmetrics import SSIM as _SSIM -from pytorch_lightning.metrics.functional.ssim import _ssim_compute, _ssim_update -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated -class SSIM(Metric): - """ - Computes `Structual Similarity Index Measure - `_ (SSIM). - - Args: - kernel_size: size of the gaussian kernel (default: (11, 11)) - sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - data_range: Range of the image. If ``None``, it is determined from the image (max - min) - k1: Parameter of SSIM. Default: 0.01 - k2: Parameter of SSIM. Default: 0.03 - - Return: - Tensor with SSIM score - - Example: - >>> from pytorch_lightning.metrics import SSIM - >>> preds = torch.rand([16, 1, 16, 16]) - >>> target = preds * 0.75 - >>> ssim = SSIM() - >>> ssim(preds, target) - tensor(0.9219) - """ +class SSIM(_SSIM): + @deprecated(target=_SSIM, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, kernel_size: Sequence[int] = (11, 11), @@ -62,44 +33,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - rank_zero_warn( - 'Metric `SSIM` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' - ) - - self.add_state("y", default=[], dist_reduce_fx=None) - self.add_state("y_pred", default=[], dist_reduce_fx=None) - self.kernel_size = kernel_size - self.sigma = sigma - self.data_range = data_range - self.k1 = k1 - self.k2 = k2 - self.reduction = reduction - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + This implementation refers to :class:`~torchmetrics.SSIM`. - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target = _ssim_update(preds, target) - self.y_pred.append(preds) - self.y.append(target) - - def compute(self): - """ - Computes explained variance over state. + .. deprecated:: + Use :class:`~torchmetrics.SSIM`. Will be removed in v1.5.0. """ - preds = torch.cat(self.y_pred, dim=0) - target = torch.cat(self.y, dim=0) - return _ssim_compute( - preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2 - ) diff --git a/pytorch_lightning/metrics/retrieval/__init__.py b/pytorch_lightning/metrics/retrieval/__init__.py deleted file mode 100644 index c5c12b3b6643c9..00000000000000 --- a/pytorch_lightning/metrics/retrieval/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401 -from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401 diff --git a/pytorch_lightning/metrics/retrieval/mean_average_precision.py b/pytorch_lightning/metrics/retrieval/mean_average_precision.py deleted file mode 100644 index 956a53cca2e778..00000000000000 --- a/pytorch_lightning/metrics/retrieval/mean_average_precision.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch - -from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision -from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric - - -class RetrievalMAP(RetrievalMetric): - r""" - Computes `Mean Average Precision - `_. - - Works with binary data. Accepts integer or float predictions from a model output. - - Forward accepts - - ``indexes`` (long tensor): ``(N, ...)`` - - ``preds`` (float tensor): ``(N, ...)`` - - ``target`` (long or bool tensor): ``(N, ...)`` - - `indexes`, `preds` and `target` must have the same dimension. - `indexes` indicate to which query a prediction belongs. - Predictions will be first grouped by indexes and then MAP will be computed as the mean - of the Average Precisions over each query. - - Args: - query_without_relevant_docs: - Specify what to do with queries that do not have at least a positive target. Choose from: - - - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned - - ``'error'``: raise a ``ValueError`` - - ``'pos'``: score on those queries is counted as ``1.0`` - - ``'neg'``: score on those queries is counted as ``0.0`` - exclude: - Do not take into account predictions where the target is equal to this value. default `-100` - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects - the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None - - Example: - >>> from pytorch_lightning.metrics import RetrievalMAP - >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1]) - >>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) - >>> target = torch.tensor([False, False, True, False, True, False, False]) - - >>> map = RetrievalMAP() - >>> map(indexes, preds, target) - tensor(0.7500) - >>> map.compute() - tensor(0.7500) - """ - - def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - valid_indexes = target != self.exclude - return retrieval_average_precision(preds[valid_indexes], target[valid_indexes]) diff --git a/pytorch_lightning/metrics/retrieval/retrieval_metric.py b/pytorch_lightning/metrics/retrieval/retrieval_metric.py deleted file mode 100644 index 6f9088d00083cf..00000000000000 --- a/pytorch_lightning/metrics/retrieval/retrieval_metric.py +++ /dev/null @@ -1,140 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional - -import torch -from torchmetrics import Metric - -from pytorch_lightning.metrics.utils import get_group_indexes - -#: get_group_indexes is used to group predictions belonging to the same query - -IGNORE_IDX = -100 - - -class RetrievalMetric(Metric, ABC): - r""" - Works with binary data. Accepts integer or float predictions from a model output. - - Forward accepts - - ``indexes`` (long tensor): ``(N, ...)`` - - ``preds`` (float or int tensor): ``(N, ...)`` - - ``target`` (long or bool tensor): ``(N, ...)`` - - `indexes`, `preds` and `target` must have the same dimension and will be flatten - to single dimension once provided. - - `indexes` indicate to which query a prediction belongs. - Predictions will be first grouped by indexes. Then the - real metric, defined by overriding the `_metric` method, - will be computed as the mean of the scores over each query. - - Args: - query_without_relevant_docs: - Specify what to do with queries that do not have at least a positive target. Choose from: - - - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned - - ``'error'``: raise a ``ValueError`` - - ``'pos'``: score on those queries is counted as ``1.0`` - - ``'neg'``: score on those queries is counted as ``0.0`` - exclude: - Do not take into account predictions where the target is equal to this value. default `-100` - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects - the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None - - """ - - def __init__( - self, - query_without_relevant_docs: str = 'skip', - exclude: int = IGNORE_IDX, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - dist_sync_fn: Callable = None - ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn - ) - - query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg') - if query_without_relevant_docs not in query_without_relevant_docs_options: - raise ValueError( - f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}. " - f"Allowed values are {query_without_relevant_docs_options}" - ) - - self.query_without_relevant_docs = query_without_relevant_docs - self.exclude = exclude - - self.add_state("idx", default=[], dist_reduce_fx=None) - self.add_state("preds", default=[], dist_reduce_fx=None) - self.add_state("target", default=[], dist_reduce_fx=None) - - def update(self, idx: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> None: - if not (idx.shape == target.shape == preds.shape): - raise ValueError("`idx`, `preds` and `target` must be of the same shape") - - idx = idx.to(dtype=torch.int64).flatten() - preds = preds.to(dtype=torch.float32).flatten() - target = target.to(dtype=torch.int64).flatten() - - self.idx.append(idx) - self.preds.append(preds) - self.target.append(target) - - def compute(self) -> torch.Tensor: - r""" - First concat state `idx`, `preds` and `target` since they were stored as lists. After that, - compute list of groups that will help in keeping together predictions about the same query. - Finally, for each group compute the `_metric` if the number of positive targets is at least - 1, otherwise behave as specified by `self.query_without_relevant_docs`. - """ - - idx = torch.cat(self.idx, dim=0) - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - - res = [] - kwargs = {'device': idx.device, 'dtype': torch.float32} - - groups = get_group_indexes(idx) - for group in groups: - - mini_preds = preds[group] - mini_target = target[group] - - if not mini_target.sum(): - if self.query_without_relevant_docs == 'error': - raise ValueError( - f"`{self.__class__.__name__}.compute()` was provided with " - f"a query without positive targets, indexes: {group}" - ) - if self.query_without_relevant_docs == 'pos': - res.append(torch.tensor(1.0, **kwargs)) - elif self.query_without_relevant_docs == 'neg': - res.append(torch.tensor(0.0, **kwargs)) - else: - res.append(self._metric(mini_preds, mini_target)) - - if len(res) > 0: - return torch.stack(res).mean() - return torch.tensor(0.0, **kwargs) - - @abstractmethod - def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - r""" - Compute a metric over a predictions and target of a single group. - This method should be overridden by subclasses. - """ diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index 1d6f4e93b5779d..0c1ac7b359fd0a 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -53,7 +53,7 @@ def forward(self, *inputs, **kwargs): elif trainer and (trainer.sanity_checking or trainer.validating): output = self.module.validation_step(*inputs, **kwargs) elif trainer and trainer.predicting: - output = self.module.predict(*inputs, **kwargs) + output = self.module.predict_step(*inputs, **kwargs) else: output = self.module(*inputs, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index bcadf16607b4fe..58e26e7db32d85 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -298,7 +298,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) - def predict(self, *args, **kwargs): + def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def post_training_step(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index ea1efd6e158734..87d7fa5faecac5 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -282,7 +282,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) - def predict(self, *args, **kwargs): + def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def post_training_step(self): diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index b96b7097d07c7e..a8e42e0fa747af 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -83,7 +83,7 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.model(*args, **kwargs) - def predict(self, *args, **kwargs): + def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) def training_step_end(self, output): diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 9f1bafe309f89e..8d0add27cbb29c 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -96,14 +96,14 @@ def start_training(self, trainer): stack.enter_context(optimizer.skip_synchronize()) # set up training routine - self._results = trainer.run_train() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() def start_evaluating(self, trainer): with ExitStack(): - self._results = trainer.run_evaluate() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() @@ -111,7 +111,7 @@ def start_evaluating(self, trainer): def start_predicting(self, trainer): with ExitStack(): # set up training routine - self._results = trainer.run_predict() + self._results = trainer.run_stage() # Make sure all workers have finished training before returning to the user hvd.join() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index c883ff504f24d7..3887e0cd98908f 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -294,8 +294,8 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.lightning_module.test_step(*args, **kwargs) - def predict(self, *args, **kwargs): - return self.lightning_module.predict(*args, **kwargs) + def predict_step(self, *args, **kwargs): + return self.lightning_module.predict_step(*args, **kwargs) def save_checkpoint(self, filepath, weights_only: bool = False): """Save model/training states as a checkpoint file through state-dump and file-write. diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 6a87792c7bd03f..08dca63a7c9250 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -132,15 +132,15 @@ def rpc_enabled(self) -> bool: def start_training(self, trainer: 'Trainer') -> None: # double dispatch to initiate the training loop - self._results = trainer.run_train() + self._results = trainer.run_stage() def start_evaluating(self, trainer: 'Trainer') -> None: # double dispatch to initiate the test loop - self._results = trainer.run_evaluate() + self._results = trainer.run_stage() def start_predicting(self, trainer: 'Trainer') -> None: # double dispatch to initiate the predicting loop - self._results = trainer.run_predict() + self._results = trainer.run_stage() def training_step(self, *args, **kwargs): return self.lightning_module.training_step(*args, **kwargs) @@ -154,8 +154,8 @@ def validation_step(self, *args, **kwargs): def test_step(self, *args, **kwargs): return self.lightning_module.test_step(*args, **kwargs) - def predict(self, *args, **kwargs): - return self.lightning_module.predict(*args, **kwargs) + def predict_step(self, *args, **kwargs): + return self.lightning_module.predict_step(*args, **kwargs) def training_step_end(self, output): return output @@ -182,3 +182,13 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule): def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): optimizer.step(closure=lambda_closure, **kwargs) + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + """ + Override to delay setting optimizers and schedulers till after dispatch. + This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. + However this may break certain precision plugins such as APEX which require optimizers to be set. + Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. + """ + return False diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index e09a5ea11a084d..6ac6e16c185290 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -121,7 +121,8 @@ def custom_processing_step(self, data): Autograd includes a profiler that lets you inspect the cost of different operators inside your model - both on the CPU and GPU. -Find the Pytorch Profiler doc at [PyTorch Profiler](https://pytorch-lightning.readthedocs.io/en/stable/profiler.html) +To read more about the PyTorch Profiler and all its options, +have a look at its `docs `__ .. code-block:: python @@ -134,16 +135,16 @@ def custom_processing_step(self, data): This profiler works with PyTorch ``DistributedDataParallel``. -If ``output_filename`` is provided, each rank will save their profiled operation to their own file. +If ``filename`` is provided, each rank will save their profiled operation to their own file. The profiler +report can be quite long, so you setting a ``filename`` will save the report instead of logging it to the +output in your terminal. If no filename is given, it will be logged only on rank 0. +The profiler's results will be printed on the completion of ``{fit,validate,test,predict}``. -The profiler's results will be printed on the completion of a training `fit()`. This profiler -report can be quite long, so you can also specify an `output_filename` to save the report instead -of logging it to the output in your terminal. - -This profiler will record only for `training_step_and_backward`, `evaluation_step` and `test_step` functions by default. -The output below shows the profiling for the action `training_step_and_backward`. -The user can provide ``PyTorchProfiler(profiled_functions=[...])`` to extend the scope of profiled functions. +This profiler will record ``training_step_and_backward``, ``training_step``, ``backward``, +``validation_step``, ``test_step``, and ``predict_step`` by default. +The output below shows the profiling for the action ``training_step_and_backward``. +The user can provide ``PyTorchProfiler(record_functions={...})`` to extend the scope of profiled functions. .. note:: When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the `SimpleProfiler`. # noqa E501 @@ -184,13 +185,13 @@ def custom_processing_step(self, data): To visualize the profiled operation, you can either: -* Use:: +Use:: nvvp trace_name.prof -* Use:: +Or:: - python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' + python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' """ diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index d704ba83236c16..bc9e3541dbaa8a 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -21,31 +21,19 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import contextmanager -from typing import Optional, Union +from pathlib import Path +from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Union import numpy as np +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem log = logging.getLogger(__name__) -class BaseProfiler(ABC): - """ - If you wish to write a custom profiler, you should inhereit from this class. - """ - - def __init__(self, output_streams: Optional[Union[list, tuple]] = None): - """ - Args: - output_streams: callable - """ - if output_streams: - if not isinstance(output_streams, (list, tuple)): - output_streams = [output_streams] - else: - output_streams = [] - self.write_streams = output_streams +class AbstractProfiler(ABC): + """Specification of a profiler.""" @abstractmethod def start(self, action_name: str) -> None: @@ -55,6 +43,48 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: """Defines how to record the duration once an action is complete.""" + @abstractmethod + def summary(self) -> str: + """Create profiler summary in text format.""" + + @abstractmethod + def setup(self, **kwargs: Any) -> None: + """Execute arbitrary pre-profiling set-up steps as defined by subclass.""" + + @abstractmethod + def teardown(self, **kwargs: Any) -> None: + """Execute arbitrary post-profiling tear-down steps as defined by subclass.""" + + +class BaseProfiler(AbstractProfiler): + """ + If you wish to write a custom profiler, you should inherit from this class. + """ + + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + output_filename: Optional[str] = None, + ) -> None: + self.dirpath = dirpath + self.filename = filename + if output_filename is not None: + rank_zero_warn( + "`Profiler` signature has changed in v1.3. The `output_filename` parameter has been removed in" + " favor of `dirpath` and `filename`. Support for the old signature will be removed in v1.5", + DeprecationWarning + ) + filepath = Path(output_filename) + self.dirpath = filepath.parent + self.filename = filepath.stem + + self._output_file: Optional[TextIO] = None + self._write_stream: Optional[Callable] = None + self._local_rank: Optional[int] = None + self._log_dir: Optional[str] = None + self._stage: Optional[str] = None + @contextmanager def profile(self, action_name: str) -> None: """ @@ -86,17 +116,92 @@ def profile_iterable(self, iterable, action_name: str) -> None: self.stop(action_name) break + def _rank_zero_info(self, *args, **kwargs) -> None: + if self._local_rank in (None, 0): + log.info(*args, **kwargs) + + def _prepare_filename(self, extension: str = ".txt") -> str: + filename = "" + if self._stage is not None: + filename += f"{self._stage}-" + filename += str(self.filename) + if self._local_rank is not None: + filename += f"-{self._local_rank}" + filename += extension + return filename + + def _prepare_streams(self) -> None: + if self._write_stream is not None: + return + if self.filename: + filepath = os.path.join(self.dirpath, self._prepare_filename()) + fs = get_filesystem(filepath) + file = fs.open(filepath, "a") + self._output_file = file + self._write_stream = file.write + else: + self._write_stream = self._rank_zero_info + def describe(self) -> None: - """Logs a profile report after the conclusion of the training run.""" - for write in self.write_streams: - write(self.summary()) + """Logs a profile report after the conclusion of run.""" + # there are pickling issues with open file handles in Python 3.6 + # so to avoid them, we open and close the files within this function + # by calling `_prepare_streams` and `teardown` + self._prepare_streams() + self._write_stream(self.summary()) + if self._output_file is not None: + self._output_file.flush() + self.teardown(stage=self._stage) + + def _stats_to_str(self, stats: Dict[str, str]) -> str: + stage = f"{self._stage.upper()} " if self._stage is not None else "" + output = [stage + "Profiler Report"] + for action, value in stats.items(): + header = f"Profile stats for: {action}" + if self._local_rank is not None: + header += f" rank: {self._local_rank}" + output.append(header) + output.append(value) + return os.linesep.join(output) + + def setup( + self, + stage: Optional[str] = None, + local_rank: Optional[int] = None, + log_dir: Optional[str] = None, + ) -> None: + """Execute arbitrary pre-profiling set-up steps.""" + self._stage = stage + self._local_rank = local_rank + self._log_dir = log_dir + self.dirpath = self.dirpath or log_dir + + def teardown(self, stage: Optional[str] = None) -> None: + """ + Execute arbitrary post-profiling tear-down steps. + + Closes the currently open file and stream. + """ + self._write_stream = None + if self._output_file is not None: + self._output_file.close() + self._output_file = None # can't pickle TextIOWrapper + + def __del__(self) -> None: + self.teardown(stage=self._stage) + + def start(self, action_name: str) -> None: + raise NotImplementedError + + def stop(self, action_name: str) -> None: + raise NotImplementedError - @abstractmethod def summary(self) -> str: - """Create profiler summary in text format.""" + raise NotImplementedError - def on_train_start(self, local_rank: Optional[int] = None): - self.local_rank = local_rank + @property + def local_rank(self) -> int: + return 0 if self._local_rank is None else self._local_rank class PassThroughProfiler(BaseProfiler): @@ -105,9 +210,6 @@ class PassThroughProfiler(BaseProfiler): The Trainer uses this class by default. """ - def __init__(self): - super().__init__(output_streams=None) - def start(self, action_name: str) -> None: pass @@ -124,30 +226,32 @@ class SimpleProfiler(BaseProfiler): the mean duration of each action and the total time spent over the entire training run. """ - def __init__(self, output_filename: Optional[str] = None, extended=True): + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + extended: bool = True, + output_filename: Optional[str] = None, + ) -> None: """ Args: - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. + + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. Raises: ValueError: If you attempt to start an action which has already started, or if you attempt to stop recording an action which was never started. """ - self.current_actions = {} + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + self.current_actions: Dict[str, float] = {} self.recorded_durations = defaultdict(list) self.extended = extended - - self.output_fname = output_filename - self.output_file = None - if self.output_fname: - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") - - streaming_out = [self.output_file.write] if self.output_file else [log.info] self.start_time = time.monotonic() - super().__init__(output_streams=streaming_out) def start(self, action_name: str) -> None: if action_name in self.current_actions: @@ -162,14 +266,18 @@ def stop(self, action_name: str) -> None: duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def make_report(self): + def _make_report(self) -> Tuple[list, float]: total_duration = time.monotonic() - self.start_time report = [[a, d, 100. * np.sum(d) / total_duration] for a, d in self.recorded_durations.items()] report.sort(key=lambda x: x[2], reverse=True) return report, total_duration def summary(self) -> str: - output_string = "\n\nProfiler Report\n" + sep = os.linesep + output_string = "" + if self._stage is not None: + output_string += f"{self._stage.upper()} " + output_string += f"Profiler Report{sep}" if self.extended: @@ -177,16 +285,16 @@ def summary(self) -> str: max_key = np.max([len(k) for k in self.recorded_durations.keys()]) def log_row(action, mean, num_calls, total, per): - row = f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t|" + row = f"{sep}{action:<{max_key}s}\t| {mean:<15}\t|" row += f"{num_calls:<15}\t| {total:<15}\t| {per:<15}\t|" return row output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %") output_string_len = len(output_string) - output_string += f"{os.linesep}{'-' * output_string_len}" - report, total_duration = self.make_report() + output_string += f"{sep}{'-' * output_string_len}" + report, total_duration = self._make_report() output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %") - output_string += f"{os.linesep}{'-' * output_string_len}" + output_string += f"{sep}{'-' * output_string_len}" for action, durations, duration_per in report: output_string += log_row( action, @@ -198,27 +306,16 @@ def log_row(action, mean, num_calls, total, per): else: def log_row(action, mean, total): - return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}" + return f"{sep}{action:<20s}\t| {mean:<15}\t| {total:<15}" output_string += log_row("Action", "Mean duration (s)", "Total time (s)") - output_string += f"{os.linesep}{'-' * 65}" + output_string += f"{sep}{'-' * 65}" for action, durations in self.recorded_durations.items(): output_string += log_row(action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}") - output_string += os.linesep + output_string += sep return output_string - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - if self.output_file: - self.output_file.flush() - - def __del__(self): - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() - class AdvancedProfiler(BaseProfiler): """ @@ -227,11 +324,22 @@ class AdvancedProfiler(BaseProfiler): verbose and you should only use this if you want very detailed reports. """ - def __init__(self, output_filename: Optional[str] = None, line_count_restriction: float = 1.0): + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + line_count_restriction: float = 1.0, + output_filename: Optional[str] = None, + ) -> None: """ Args: - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. + + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. + line_count_restriction: this can be used to limit the number of functions reported for each action. either an integer (to select a count of lines), or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) @@ -240,18 +348,10 @@ def __init__(self, output_filename: Optional[str] = None, line_count_restriction ValueError: If you attempt to stop recording an action which was never started. """ - self.profiled_actions = {} + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + self.profiled_actions: Dict[str, cProfile.Profile] = {} self.line_count_restriction = line_count_restriction - self.output_fname = output_filename - self.output_file = None - if self.output_fname: - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") - - streaming_out = [self.output_file.write] if self.output_file else [log.info] - super().__init__(output_streams=streaming_out) - def start(self, action_name: str) -> None: if action_name not in self.profiled_actions: self.profiled_actions[action_name] = cProfile.Profile() @@ -260,9 +360,7 @@ def start(self, action_name: str) -> None: def stop(self, action_name: str) -> None: pr = self.profiled_actions.get(action_name) if pr is None: - raise ValueError( # pragma: no-cover - f"Attempting to stop recording an action ({action_name}) which was never started." - ) + raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.") pr.disable() def summary(self) -> str: @@ -272,21 +370,16 @@ def summary(self) -> str: ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative') ps.print_stats(self.line_count_restriction) recorded_stats[action_name] = s.getvalue() + return self._stats_to_str(recorded_stats) - # log to standard out - output_string = f"{os.linesep}Profiler Report{os.linesep}" - for action, stats in recorded_stats.items(): - output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}" - - return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - if self.output_file: - self.output_file.flush() + def teardown(self, stage: Optional[str] = None) -> None: + super().teardown(stage=stage) + self.profiled_actions = {} - def __del__(self): - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() + def __reduce__(self): + # avoids `TypeError: cannot pickle 'cProfile.Profile' object` + return ( + self.__class__, + tuple(), + dict(dirpath=self.dirpath, filename=self.filename, line_count_restriction=self.line_count_restriction), + ) diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 88a33a3d367f8b..73abc1baf939d8 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -12,27 +12,193 @@ # See the License for the specific language governing permissions and # limitations under the License. """Profiler to check if there are any bottlenecks in your code.""" - import inspect import logging import os -from typing import List, Optional +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union import torch +from torch import nn, Tensor +from torch.autograd.profiler import record_function from pytorch_lightning.profiler.profilers import BaseProfiler -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE + +if TYPE_CHECKING: + from torch.autograd.profiler import EventList + from torch.utils.hooks import RemovableHandle + + from pytorch_lightning.core.lightning import LightningModule + +if _KINETO_AVAILABLE: + from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler log = logging.getLogger(__name__) +_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx] + + +class RegisterRecordFunction: + """ + While profiling autograd operations, this class will add labels for module names around the forward function. + + The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows: + + Example:: + from pytorch_lightning.profilers import PyTorchProfiler + profiler = PyTorchProfiler(record_module_names=False) + Trainer(profiler=profiler) + + It can be used outside of Lightning as follows: + + Example:: + from pytorch_lightning import Trainer, seed_everything + with RegisterRecordFunction(model): + out = model(batch) + """ + + def __init__(self, model: nn.Module) -> None: + self._model = model + self._records: Dict[str, record_function] = {} + self._handles: Dict[str, List['RemovableHandle']] = {} + + def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: + record = record_function(record_name) + record.__enter__() + self._records[record_name] = record + return input + + def _stop_recording_forward(self, _: nn.Module, __: Tensor, output: Tensor, record_name: str) -> Tensor: + self._records[record_name].__exit__(None, None, None) + return output + + def __enter__(self) -> None: + for module_name, module in self._model.named_modules(): + if module_name: + full_name = f"{type(module).__module__}.{type(module).__name__}" + record_name = f"{full_name}: {module_name}" + pre_forward_handle = module.register_forward_pre_hook( + partial(self._start_recording_forward, record_name=record_name) + ) + post_forward_handle = module.register_forward_hook( + partial(self._stop_recording_forward, record_name=record_name) + ) + + self._handles[module_name] = [pre_forward_handle, post_forward_handle] + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + for handles in self._handles.values(): + for h in handles: + h.remove() + self._handles = {} + + +class ScheduleWrapper: + """ + This class is used to override the schedule logic from the profiler and perform + recording for both `training_step`, `validation_step`. + """ + + def __init__(self, schedule: Callable) -> None: + if not _KINETO_AVAILABLE: + raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.") + self._schedule = schedule + self._num_training_step_and_backward = 0 + self._num_validation_step = 0 + self._num_test_step = 0 + self._num_predict_step = 0 + self._training_step_and_backward_reached_end = False + self._validation_step_reached_end = False + self._test_step_reached_end = False + self._predict_step_reached_end = False + # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached. + self._current_action: Optional[str] = None + self._start_action_name: Optional[str] = None + + def setup(self, start_action_name: str) -> None: + self._start_action_name = start_action_name + + def pre_step(self, current_action: str) -> None: + self._current_action = current_action + + @property + def num_step(self) -> int: + if self._current_action == "training_step_and_backward": + return self._num_training_step_and_backward + elif self._current_action == "validation_step": + return self._num_validation_step + elif self._current_action == "test_step": + return self._num_test_step + elif self._current_action == "predict_step": + return self._num_predict_step + else: + return 0 + + def _step(self) -> None: + if self._current_action == "training_step_and_backward": + self._num_training_step_and_backward += 1 + elif self._current_action == "validation_step": + if self._start_action_name == "on_train_start" and self._num_training_step_and_backward > 0: + self._num_validation_step += 1 + else: + self._num_validation_step += 1 + elif self._current_action == "test_step": + self._num_test_step += 1 + elif self._current_action == "predict_step": + self._num_predict_step += 1 + + @property + def has_finished(self) -> bool: + if self._current_action == "training_step_and_backward": + return self._training_step_and_backward_reached_end + elif self._current_action == "validation_step": + return self._validation_step_reached_end + elif self._current_action == "test_step": + return self._test_step_reached_end + elif self._current_action == "predict_step": + return self._predict_step_reached_end + return False + + def __call__(self, num_step: int) -> 'ProfilerAction': + # ignore the provided input. Keep internal state instead. + if self.has_finished: + return ProfilerAction.NONE + + self._step() + action = self._schedule(self.num_step) + if action == ProfilerAction.RECORD_AND_SAVE: + if self._current_action == "training_step_and_backward": + self._training_step_and_backward_reached_end = True + elif self._current_action == "validation_step": + self._validation_step_reached_end = True + elif self._current_action == "test_step": + self._test_step_reached_end = True + elif self._current_action == "predict_step": + self._predict_step_reached_end = True + return action + class PyTorchProfiler(BaseProfiler): - PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step") - AVAILABLE_SORT_KEYS = ( + RECORD_FUNCTIONS = { + "training_step_and_backward", + "training_step", + "backward", + "validation_step", + "test_step", + "predict_step", + } + STEP_FUNCTIONS = { + "training_step_and_backward", + "validation_step", + "test_step", + "predict_step", + } + AVAILABLE_SORT_KEYS = { "cpu_time", "cuda_time", "cpu_time_total", @@ -42,56 +208,43 @@ class PyTorchProfiler(BaseProfiler): "self_cpu_memory_usage", "self_cuda_memory_usage", "count", - ) + } + START_RECORD_FUNCTIONS = { + 'on_train_start', + 'on_validation_start', + 'on_test_start', + 'on_predict_start', + } def __init__( self, - output_filename: Optional[str] = None, - enabled: bool = True, - use_cuda: bool = False, - record_shapes: bool = False, - profile_memory: bool = False, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, group_by_input_shapes: bool = False, - with_stack: bool = False, - use_kineto: bool = False, - use_cpu: bool = True, emit_nvtx: bool = False, - export_to_chrome: bool = False, - path_to_export_trace: str = None, + export_to_chrome: bool = True, row_limit: int = 20, sort_by_key: Optional[str] = None, + record_functions: Set[str] = None, + record_module_names: bool = True, profiled_functions: Optional[List] = None, - local_rank: Optional[int] = None, - ): + output_filename: Optional[str] = None, + **profiler_kwargs: Any, + ) -> None: """ This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of different operators inside your model - both on the CPU and GPU Args: + dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) + will be used. - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. When using ``ddp``, - each rank will stream the profiled operation to their own file - with the extension ``_{rank}.txt`` - - enabled: Setting this to False makes this context manager a no-op. - - use_cuda: Enables timing of CUDA events as well using the cudaEvent API. - Adds approximately 4us of overhead to each tensor operation. - - record_shapes: If shapes recording is set, information about input dimensions will be collected. - - profile_memory: Whether to report memory usage, default: True (Introduced in PyTorch 1.6.0) + filename: If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. group_by_input_shapes: Include operator input shapes and group calls by shape. - with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0) - - use_kineto: experimental support for Kineto profiler (Introduced in PyTorch 1.8.0) - - use_cpu: use_kineto=True and can be used to lower the overhead - for GPU-only profiling (Introduced in PyTorch 1.8.0) - emit_nvtx: Context manager that makes every autograd operation emit an NVTX range Run:: @@ -102,202 +255,256 @@ def __init__( nvvp trace_name.prof torch.autograd.profiler.load_nvprof(path) - export_to_chrome: Wether to export the sequence of profiled operators for Chrome. + export_to_chrome: Whether to export the sequence of profiled operators for Chrome. It will generate a ``.json`` file which can be read by Chrome. - path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``. - By default, it will be save where the file being is being run. - - row_limit: Limit the number of rows in a table, `0` is a special value that + row_limit: Limit the number of rows in a table, ``-1`` is a special value that removes the limit completely. - sort_by_key: Keys to sort out profiled table + sort_by_key: Attribute used to sort entries. By default + they are printed in the same order as they were registered. + Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``, + ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, + ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. - profiled_functions: list of profiled functions which will create a context manager on. + record_functions: Set of profiled functions which will create a context manager on. Any other will be pass through. - local_rank: When running in distributed setting, local_rank is used for each process - to write to their own file if `output_fname` is provided. + record_module_names: Whether to add module names while recording autograd operation. + + profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version Raises: MisconfigurationException: - If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``, or - if log file is not a ``.txt`` file. - ValueError: - If you attempt to stop recording an action which was never started. + If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``. + If arg ``schedule`` is not a ``Callable``. + If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``. """ + super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename) + + record_functions = self.__deprecation_check(profiled_functions, record_functions) + + self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False) + self._emit_nvtx = emit_nvtx + self._export_to_chrome = export_to_chrome + self._row_limit = row_limit + self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total" + self._record_functions_start = record_functions | self.START_RECORD_FUNCTIONS + self._record_functions = record_functions | self.RECORD_FUNCTIONS + self._record_module_names = record_module_names + self._profiler_kwargs = profiler_kwargs + + self.profiler: Optional[_PROFILER] = None + self.function_events: Optional['EventList'] = None + self._lightning_module: Optional['LightningModule'] = None # set by ProfilerConnector + self._register: Optional[RegisterRecordFunction] = None + self._parent_profiler: Optional[_PROFILER] = None + self._recording_map: Dict[str, record_function] = {} + self._start_action_name: Optional[str] = None + self._schedule: Optional[ScheduleWrapper] = None + + if _KINETO_AVAILABLE: + self.__init_kineto__(profiler_kwargs) + + if self._sort_by_key not in self.AVAILABLE_SORT_KEYS: + raise MisconfigurationException( + f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " + ) - self.profiled_actions = {} - self.enabled = enabled - self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS - self.use_cuda = use_cuda - self.record_shapes = record_shapes - self.profile_memory = profile_memory - self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total") - self.with_stack = with_stack - self.group_by_input_shapes = group_by_input_shapes and record_shapes - self.use_kineto = use_kineto - self.use_cpu = use_cpu - self.row_limit = row_limit - self.emit_nvtx = emit_nvtx - self.export_to_chrome = export_to_chrome - self.path_to_export_trace = path_to_export_trace - - if export_to_chrome and path_to_export_trace is None: + def __init_kineto__(self, profiler_kwargs: Any): + has_schedule = "schedule" in profiler_kwargs + self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs + + schedule = profiler_kwargs.get("schedule", None) + if schedule is not None: + if not isinstance(schedule, Callable): + raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}") + action = schedule(0) + if not isinstance(action, ProfilerAction): + raise MisconfigurationException( + f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}" + ) + schedule = schedule if has_schedule else self._default_schedule() + self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule + self._profiler_kwargs["schedule"] = self._schedule + + activities = profiler_kwargs.get("activities", None) + self._profiler_kwargs["activities"] = activities or self._default_activities() + self._export_to_flame_graph = profiler_kwargs.get("export_to_flame_graph", False) + self._metric = profiler_kwargs.get("metric", "self_cpu_time_total") + with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph + self._profiler_kwargs["with_stack"] = with_stack + + def __deprecation_check( + self, + profiled_functions: Optional[List[str]], + record_functions: Optional[Set[str]], + ) -> Set[str]: + if record_functions is None: + record_functions = set() + + if profiled_functions is not None: rank_zero_warn( - "The exported trace would be save locally as `path_to_export_trace` is empty." - " Note: Each functions will generate its own traced file." + "`PyTorchProfiler.profiled_functions` has been renamed to" + " `record_functions` in v1.3 and will be removed in v1.5", DeprecationWarning ) + if not record_functions: + record_functions |= set(profiled_functions) + else: + raise MisconfigurationException( + "You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`." + " Please use only the later." + ) + + return record_functions + + @staticmethod + def _default_schedule() -> Optional[callable]: + if _KINETO_AVAILABLE: + # Those schedule defaults allow the profiling overhead to be negligible over training time. + return torch.profiler.schedule(wait=1, warmup=1, active=2) + + def _default_activities(self) -> List['ProfilerActivity']: + activities = [] + if not _KINETO_AVAILABLE: + return activities + if self._profiler_kwargs.get("use_cpu", True): + activities.append(ProfilerActivity.CPU) + if self._profiler_kwargs.get("use_cuda", torch.cuda.is_available()): + activities.append(ProfilerActivity.CUDA) + return activities + + @property + def step_action_names(self) -> Set[str]: + return self.STEP_FUNCTIONS | self._record_functions - if self.sort_by_key not in self.AVAILABLE_SORT_KEYS: - raise MisconfigurationException( - f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " - ) + def start(self, action_name: str) -> None: + if self.profiler is None and action_name in self._record_functions_start: - self.profiled_actions = {} - self.context_names = {} - self.running_stack = [] - self.profiler = None + # close profiler if it is already opened. might happen if 2 profilers + # are created and the first one did not call `describe` + try: + torch.autograd._disable_profiler() # noqa + except (AttributeError, RuntimeError): + pass - self.output_fname = output_filename - self.output_file = None - if local_rank is not None: - self.on_train_start(local_rank=local_rank) - self.on_train_start = super().on_train_start + if self._schedule is not None: + self._schedule.setup(action_name) - def on_train_start(self, local_rank: Optional[str] = None): - self.local_rank = local_rank + self._create_profilers() - # when logging to `log.info`, only perform profiling on rank 0 - if local_rank != 0 and self.output_fname is None: - self.wrap_functions_into_rank_zero_only() + profiler = self.profiler.__enter__() + if profiler is not None: + self.profiler = profiler - if self.output_fname: - if local_rank is not None: - if '.txt' not in self.output_fname: - raise MisconfigurationException("Log file should be .txt file.") + if self._parent_profiler is not None: + self._parent_profiler.__enter__() - self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt") + if self._register is not None: + self._register.__enter__() - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") + if ( + self.profiler is not None and action_name in self._record_functions + and action_name not in self._recording_map + ): + recording = record_function(action_name) + recording.__enter__() + self._recording_map[action_name] = recording - streaming_out = [self.output_file.write] if self.output_file else [log.info] - super().__init__(output_streams=streaming_out) + if self._schedule is not None: + self._schedule.pre_step(action_name) - def wrap_functions_into_rank_zero_only(self): - self.start = rank_zero_only(self.start) - self.stop = rank_zero_only(self.stop) - self.summary = rank_zero_only(self.summary) - self.describe = rank_zero_only(self.describe) + def stop(self, action_name: str) -> None: + if action_name in self._recording_map: + self._recording_map[action_name].__exit__(None, None, None) + del self._recording_map[action_name] - def start(self, action_name: str) -> None: - if action_name not in self.profiled_functions: + if not _KINETO_AVAILABLE or self._emit_nvtx: return - if len(self.running_stack) > 0: - self._stop(self.running_stack[-1]) - self.running_stack.append(action_name) + if action_name in self.step_action_names: + if self._schedule is not None: + self._schedule._current_action = action_name - self.context_names[action_name] = "/".join(self.running_stack) + def on_trace_ready(profiler): + filename = f"{action_name}_{self.local_rank}" - self._start(action_name) + if self.dirpath is not None: + if self._export_to_chrome: + handler = tensorboard_trace_handler(self.dirpath, filename) + handler(profiler) - def _start(self, action_name: str) -> None: - if self.emit_nvtx: - self._parent_profiler = self._create_profiler(action_name, torch.cuda.profiler.profile, enter=True) - self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx) - else: - self._create_profiler(action_name, torch.autograd.profiler.profile) - - def _create_profiler(self, action_name, profiler, enter=True): - init_args = inspect.signature(profiler.__init__).parameters - profiler_args = {k: v for k, v in vars(self).items() if k in init_args} - pr = profiler(**profiler_args) - if enter: - out_pr = pr.__enter__() - if out_pr is not None: - pr = out_pr - self.profiler = pr - return self.profiler - - def _stop(self, action_name: str) -> None: - if self.profiler is None: - return + if self._export_to_flame_graph: + path = os.path.join(self.dirpath, self._prepare_filename(extension=".stack")) + profiler.export_stacks(path, metric=self._metric) + else: + rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None") - self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None) + if not self._has_on_trace_ready: + self.profiler.on_trace_ready = on_trace_ready + self.profiler.step() - if isinstance(self.profiler, torch.autograd.profiler.emit_nvtx): - # when running ``emit_nvtx``, PyTorch requires 2 context manager. - # The parent_profiler is being closed too. - self._parent_profiler.__exit__(None, None, None) - return + def summary(self) -> str: + if not self._profiler_kwargs.get("enabled", True) or self._emit_nvtx: + return "" - function_events = self.profiler.function_events - self.profiler = None - for name in self.running_stack: - if name not in self.profiled_actions: - self.profiled_actions[name] = function_events - else: - self.profiled_actions[name] += function_events + self._delete_profilers() - def stop(self, action_name: str) -> None: - if action_name not in self.profiled_functions: - return + if not self.function_events: + return "" + + if self._export_to_chrome and not _KINETO_AVAILABLE: + filename = f"{self.local_rank}_trace.json" + path_to_trace = (filename if self.dirpath is None else os.path.join(self.dirpath, filename)) + self.function_events.export_chrome_trace(path_to_trace) - if len(self.running_stack) == 0 or self.running_stack[-1] != action_name: - raise ValueError( # pragma: no-cover - f"Attempting to stop recording an action ({action_name}) which was never started." + data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes) + table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit) + + recorded_stats = {"records": table} + return self._stats_to_str(recorded_stats) + + def _create_profilers(self) -> None: + if self._emit_nvtx: + self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile) + self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx) + else: + self._parent_profiler = None + self.profiler = self._create_profiler( + torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile ) - self._stop(action_name) - self.running_stack.pop() - # restore running profiler - if len(self.running_stack) > 0: - self._start(self.running_stack[-1]) + if self._record_module_names and self._lightning_module is not None: + self._register = RegisterRecordFunction(self._lightning_module) - def summary(self) -> str: - recorded_stats = {} - output_string = '' - local_rank = '0' if self.local_rank is None else self.local_rank + def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: + init_parameters = inspect.signature(profiler.__init__).parameters + kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} + return profiler(**kwargs) - if not self.enabled: - return output_string + def _cache_functions_events(self) -> None: + if self._emit_nvtx: + return + self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events - for action_name, function_events in self.profiled_actions.items(): + def _delete_profilers(self) -> None: + if self.profiler is not None: + self.profiler.__exit__(None, None, None) + self._cache_functions_events() + self.profiler = None - # next line is a workaround for a pytorch issue (fixed on master, still present - # on 1.7). Without it the code fails with `AssertionError: There is already a CPU - # parent event for detach` - function_events.populate_cpu_children = lambda: None + if self._parent_profiler is not None: + self._parent_profiler.__exit__(None, None, None) + self._parent_profiler = None - if self.export_to_chrome: - filename = f"{action_name}_{local_rank}_trace.json" - path_to_trace = filename if self.path_to_export_trace is None \ - else os.path.join(self.path_to_export_trace, filename) - function_events.export_chrome_trace(path_to_trace) + if self._register is not None: + self._register.__exit__(None, None, None) + self._register = None - if self.emit_nvtx: - return output_string + def teardown(self, stage: Optional[str] = None) -> None: + self._delete_profilers() - else: - data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes) - table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit) - recorded_stats[action_name] = table - - # log to standard out - output_string = f"{os.linesep}Profiler Report{os.linesep}" - for action, stats in recorded_stats.items(): - output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}") - - return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - if self.output_file: - self.output_file.flush() - - def __del__(self): - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() + for k in self._recording_map: + self.stop(k) + self._recording_map = {} + + super().teardown(stage=stage) diff --git a/pytorch_lightning/setup_tools.py b/pytorch_lightning/setup_tools.py index f5aed2608635e4..3362ccb479895e 100644 --- a/pytorch_lightning/setup_tools.py +++ b/pytorch_lightning/setup_tools.py @@ -16,7 +16,7 @@ import re from typing import List -from pytorch_lightning import __homepage__, __version__, _PROJECT_ROOT +_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]: @@ -40,10 +40,10 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme return reqs -def _load_readme_description(path_dir: str, homepage: str = __homepage__, version: str = __version__) -> str: +def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: """Load readme as decribtion - >>> _load_readme_description(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE '
...' """ path_readme = os.path.join(path_dir, "README.md") diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 8c539b5ff478de..a7ba2b1c401232 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -40,7 +40,8 @@ def verify_loop_configurations(self, model: LightningModule) -> None: self.__verify_eval_loop_configuration(model, 'val') elif self.trainer.state == TrainerState.TESTING: self.__verify_eval_loop_configuration(model, 'test') - # TODO: add predict + elif self.trainer.state == TrainerState.PREDICTING: + self.__verify_predict_loop_configuration(model) def __verify_train_loop_configuration(self, model): # ----------------------------------- @@ -99,3 +100,9 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) - rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop') if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop') + + def __verify_predict_loop_configuration(self, model: LightningModule) -> None: + + has_predict_dataloader = is_overridden('predict_dataloader', model) + if not has_predict_dataloader: + raise MisconfigurationException('Dataloader not found for `Trainer.predict`') diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b3fc0b4eb7b297..5d2f141dc64a83 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -150,6 +150,10 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N self.trainer.datamodule = datamodule datamodule.trainer = self.trainer + # experimental feature for Flash + if hasattr(datamodule, "data_pipeline"): + model.data_pipeline = datamodule.data_pipeline + class _PatchDataLoader(object): r""" diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 98d65c1285ff79..191e8711463ab0 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License - from typing import Union +from weakref import proxy from pytorch_lightning.profiler import ( AdvancedProfiler, @@ -54,6 +54,8 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]): ) self.trainer.profiler = profiler or PassThroughProfiler() - def on_train_start(self, trainer): + def setup(self) -> None: + trainer = self.trainer local_rank = trainer.local_rank if trainer.world_size > 1 else None - self.trainer.profiler.on_train_start(local_rank) + trainer.profiler.lightning_module = proxy(trainer.lightning_module) + trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 56da7039bbca75..1a9c69d107b972 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -16,7 +16,7 @@ import platform from abc import ABC from copy import deepcopy -from typing import Callable, Iterable, List, Tuple, Union +from typing import Iterable, List, Tuple, Union from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -191,7 +191,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: Args: model: The current `LightningModule` """ - self.train_dataloader = self.request_dataloader(model.train_dataloader) + self.train_dataloader = self.request_dataloader(model, "train") if self.overfit_batches > 0: if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler): @@ -271,7 +271,7 @@ def _reset_eval_dataloader( """ # always get the loaders first so we can count how many there are loader_name = f'{mode}_dataloader' - dataloaders = self.request_dataloader(getattr(model, loader_name)) + dataloaders = self.request_dataloader(model, mode) if not isinstance(dataloaders, list): dataloaders = [dataloaders] @@ -280,7 +280,7 @@ def _reset_eval_dataloader( # duplicate it the numb of times needed to match the train loaders if self.overfit_batches > 0: num_loaders = len(dataloaders) - train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader')) + train_dataloader = self.request_dataloader(model, 'train') dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders) @@ -380,7 +380,7 @@ def reset_predict_dataloader(self, model) -> None: if has_loader: self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict') - def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: + def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader: """Handles downloading data in the GPU or TPU case. Args: @@ -389,9 +389,10 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: Returns: The dataloader """ - dataloader = dataloader_fx() + if model.trainer is not None: + model.trainer.call_hook(f"on_{stage}_dataloader") + dataloader: DataLoader = getattr(model, f'{stage}_dataloader')() dataloader = self._flatten_dl_only(dataloader) - self.accelerator.barrier('get_dataloaders') return dataloader diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index fc2192216a1625..c53681d20ac423 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -15,6 +15,7 @@ import torch from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden @@ -99,6 +100,10 @@ def on_evaluation_end(self, *args, **kwargs): else: self.trainer.call_hook('on_validation_end', *args, **kwargs) + if self.trainer.state != TrainerState.FITTING: + # summarize profile results + self.trainer.profiler.describe() + def reload_evaluation_dataloaders(self): model = self.trainer.lightning_module if self.trainer.testing: diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index 70329b4fdf514c..b33f41cb2ea487 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -44,6 +44,8 @@ def on_predict_model_eval(self, *_, **__): model_ref.on_predict_model_eval() def setup(self, model, max_batches, dataloaders): + self.trainer.call_hook("on_predict_start") + # copy properties for forward overrides self.trainer.model_connector.copy_trainer_model_properties(model) @@ -65,7 +67,7 @@ def _get_num_dataloaders(self, dataloaders): length = len(dataloaders[0]) return length - def predict(self, batch, batch_idx, dataloader_idx): + def predict_step(self, batch, batch_idx, dataloader_idx): # configure args args = [batch, batch_idx] if self.num_dataloaders: @@ -74,7 +76,7 @@ def predict(self, batch, batch_idx, dataloader_idx): model_ref = self.trainer.lightning_module model_ref._current_fx_name = "predict" - predictions = self.trainer.accelerator.predict(args) + predictions = self.trainer.accelerator.predict_step(args) if predictions is None: self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") @@ -86,6 +88,8 @@ def predict(self, batch, batch_idx, dataloader_idx): return def on_predict_epoch_end(self): + self.trainer.profiler.describe() + self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.lightning_module) results = self._predictions @@ -99,3 +103,11 @@ def _convert_to_numpy(v): return results[0] return results + + def on_predict_start(self): + # hook + self.trainer.call_hook("on_predict_start") + + def on_predict_end(self): + # hook + self.trainer.call_hook("on_predict_end") diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index b5654b148afc6b..315e3c60c05579 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -491,6 +491,16 @@ def sanity_checking(self, val: bool) -> None: elif self.sanity_checking: self._running_stage = None + @property + def _setup_state(self) -> TrainerState: + # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" + return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + + @property + def _teardown_state(self) -> Optional[TrainerState]: + if self.state.running: + return self._setup_state + # Used to represent the concrete type TrainerProperties class methods are called on. _T = TypeVar('_T', bound=TrainerProperties) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 53b4920bd85ef6..dbc493aa76e040 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -445,13 +445,15 @@ def fit( | || {self.dispatch} || | || LIGHTNING - {self.accelerator.start_training} or || - {self.accelerator.start_evaluating} or || FLOW - {self.accelerator.start_predicting} || + {self.accelerator.start_training} || + or {self.accelerator.start_evaluating} || + or {self.accelerator.start_predicting} || FLOW + | || + {self.run_stage} || | || DIRECTION - {self.run_train} or || - {self.run_evaluation} or || - {self.run_predict} || + {self.run_train} || + or {self.run_evaluation} || + or {self.run_predict} || | || results \/ This is used to guide readers to the core loops: train, test, predict. @@ -495,7 +497,7 @@ def fit( return self.accelerator.results or 1 def pre_dispatch(self): - self.accelerator.pre_dispatch() + self.accelerator.pre_dispatch(self) # log hyper-parameters if self.logger is not None: @@ -505,7 +507,7 @@ def pre_dispatch(self): self.logger.save() def post_dispatch(self): - self.accelerator.post_dispatch() + self.accelerator.post_dispatch(self) self.accelerator.teardown() def dispatch(self): @@ -518,6 +520,9 @@ def dispatch(self): def run_stage(self): results = None + + self.profile_connector.setup() + if self.evaluating: results = self.run_evaluate() elif self.predicting: @@ -757,6 +762,8 @@ def run_evaluate(self): return eval_loop_results def run_predict(self): + self.predict_loop.on_predict_start() + # prepare dataloaders dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() @@ -779,7 +786,6 @@ def run_predict(self): for dataloader_idx, dataloader in enumerate(dataloaders): dataloader = self.accelerator.process_dataloader(dataloader) dl_max_batches = self.predict_loop.max_batches[dataloader_idx] - for batch_idx, batch in enumerate(dataloader): if batch is None: continue @@ -789,10 +795,15 @@ def run_predict(self): break # lightning module methods - with self.profiler.profile("predict"): - self.predict_loop.predict(batch, batch_idx, dataloader_idx) + with self.profiler.profile("predict_step"): + self.predict_loop.predict_step(batch, batch_idx, dataloader_idx) results = self.predict_loop.on_predict_epoch_end() + self.predict_loop.on_predict_end() + + # re-enable grads + torch.set_grad_enabled(True) + return results def run_sanity_check(self, ref_model): @@ -1060,8 +1071,7 @@ def tune( def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" - # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders" - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state + state = self._setup_state if self.datamodule is not None: called = getattr(self.datamodule, f'has_setup_{state}') @@ -1072,11 +1082,8 @@ def call_setup_hook(self, model: LightningModule) -> None: model.setup(stage=state) def call_teardown_hook(self, model: LightningModule) -> None: - if self.state.running: - state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state - else: - state = None - + state = self._teardown_state + self.profiler.teardown(stage=state) self.teardown(stage=state) model.teardown(stage=state) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7e737c424ff261..c3ba34ca66d2d3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -102,9 +102,6 @@ def on_train_start(self): # hook self.trainer.call_hook("on_train_start") - # provide rank to profiler - self.trainer.profile_connector.on_train_start(self.trainer) - def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): # clean hparams if hasattr(model, "hparams"): @@ -140,8 +137,7 @@ def on_train_end(self): self.trainer.logger.finalize("success") # summarize profile results - if self.trainer.global_rank == 0: - self.trainer.profiler.describe() + self.trainer.profiler.describe() # give accelerators a chance to finish self.trainer.accelerator.on_train_end() @@ -747,7 +743,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, # backward pass if result is not None: - with self.trainer.profiler.profile("model_backward"): + with self.trainer.profiler.profile("backward"): self.backward(result, optimizer, opt_idx) # hook - call this hook only diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 49cbaf3c6bdcfe..46d88184ee1904 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -67,7 +67,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp # Value has been passed as a flag => It is currently None, so we need to set it to True # We always set to True, regardless of the default value. # Users must pass False directly, but when passing nothing True is assumed. - # i.e. the only way to disable somthing that defaults to True is to use the long form: + # i.e. the only way to disable something that defaults to True is to use the long form: # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, # which then becomes True here. @@ -242,9 +242,6 @@ def add_argparse_args( if arg == 'track_grad_norm': use_type = float - if arg_default is inspect._empty: - arg_default = None - parser.add_argument( f'--{arg}', dest=arg, @@ -291,10 +288,7 @@ def _gpus_allowed_type(x) -> Union[int, str]: def _gpus_arg_default(x) -> Union[int, str]: - if ',' in x: - return str(x) - else: - return int(x) + return _gpus_allowed_type(x) def _int_or_float_type(x) -> Union[int, float]: diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 41a13d6c678a0d..baeac9be572184 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities""" +import importlib import operator import platform import sys @@ -19,7 +20,7 @@ from importlib.util import find_spec import torch -from pkg_resources import DistributionNotFound, get_distribution +from pkg_resources import DistributionNotFound def _module_available(module_path: str) -> bool: @@ -42,11 +43,24 @@ def _module_available(module_path: str) -> bool: def _compare_version(package: str, op, version) -> bool: + """ + Compare package version with some requirements + + >>> _compare_version("torch", operator.ge, "0.1") + True + """ try: - pkg_version = LooseVersion(get_distribution(package).version) - return op(pkg_version, LooseVersion(version)) - except DistributionNotFound: + pkg = importlib.import_module(package) + except (ModuleNotFoundError, DistributionNotFound): + return False + try: + pkg_version = LooseVersion(pkg.__version__) + except AttributeError: return False + if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")): + # this is mock by sphinx, so it shall return True ro generate all summaries + return True + return op(pkg_version, LooseVersion(version)) _IS_WINDOWS = platform.system() == "Windows" @@ -54,7 +68,9 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0") _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0") _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0") +_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0") +_KINETO_AVAILABLE = torch.profiler.kineto_available() if _TORCH_GREATER_EQUAL_1_8 else False _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') diff --git a/requirements/extra.txt b/requirements/extra.txt index a05c4971ac450a..715916c4e36acb 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,5 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 +# todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip diff --git a/requirements/test.txt b/requirements/test.txt index 099a6fe43b6e62..259cc2e2d64424 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,8 +1,8 @@ coverage>5.2.0 codecov>=2.1 pytest>=6.0 -pytest-cov>2.10 -# pytest-xdist +#pytest-cov>2.10 +#pytest-xdist flake8>=3.6 check-manifest twine==3.2 diff --git a/setup.cfg b/setup.cfg index ab1e1e8c1addc7..3775eb0070f7ca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,11 +47,6 @@ omit = pytorch_lightning/utilities/xla_device_utils.py pytorch_lightning/utilities/distributed.py pytorch_lightning/tuner/auto_gpu_select.py - # TODO: temporary, until accelerator refactor is finished - pytorch_lightning/accelerators/accelerator.py - pytorch_lightning/plugins/training_type/*.py - pytorch_lightning/plugins/precision/*.py - pytorch_lightning/plugins/base_plugin.py [flake8] diff --git a/setup.py b/setup.py index 5d619d51977b29..e53e24ebf07023 100755 --- a/setup.py +++ b/setup.py @@ -16,20 +16,22 @@ import os # Always prefer setuptools over distutils +import sys + from setuptools import find_packages, setup try: - import builtins + from pytorch_lightning import info, setup_tools except ImportError: - import __builtin__ as builtins + # alternative https://stackoverflow.com/a/67692/4521646 + sys.path.append("pytorch_lightning") + import info + import setup_tools # https://packaging.python.org/guides/single-sourcing-package-version/ # http://blog.ionelmc.ro/2014/05/25/python-packaging/ -PATH_ROOT = os.path.dirname(__file__) -builtins.__LIGHTNING_SETUP__ = True - -import pytorch_lightning # noqa: E402 -from pytorch_lightning.setup_tools import _load_readme_description, _load_requirements # noqa: E402 +_PATH_ROOT = os.path.dirname(__file__) +_PATH_REQUIRE = os.path.join(_PATH_ROOT, 'requirements') # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. @@ -37,10 +39,10 @@ # From local copy of repo, use like `pip install ".[dev, docs]"` extras = { # 'docs': load_requirements(file_name='docs.txt'), - 'examples': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='examples.txt'), - 'loggers': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='loggers.txt'), - 'extra': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='extra.txt'), - 'test': _load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='test.txt') + 'examples': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='examples.txt'), + 'loggers': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='loggers.txt'), + 'extra': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='extra.txt'), + 'test': setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name='test.txt') } extras['dev'] = extras['extra'] + extras['loggers'] + extras['test'] extras['all'] = extras['dev'] + extras['examples'] # + extras['docs'] @@ -53,6 +55,12 @@ # filter cpu only packages extras[ex] = [pkg for pkg in extras[kw] if not any(pgpu.lower() in pkg.lower() for pgpu in PACKAGES_GPU_ONLY)] +long_description = setup_tools._load_readme_description( + _PATH_ROOT, + homepage=info.__homepage__, + version=info.__version__, +) + # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious # what happens and to non-engineers they won't know to look in init ... @@ -60,22 +68,22 @@ # engineer specific practices setup( name="pytorch-lightning", - version=pytorch_lightning.__version__, - description=pytorch_lightning.__docs__, - author=pytorch_lightning.__author__, - author_email=pytorch_lightning.__author_email__, - url=pytorch_lightning.__homepage__, + version=info.__version__, + description=info.__docs__, + author=info.__author__, + author_email=info.__author_email__, + url=info.__homepage__, download_url='https://github.com/PyTorchLightning/pytorch-lightning', - license=pytorch_lightning.__license__, + license=info.__license__, packages=find_packages(exclude=['tests', 'tests/*', 'benchmarks', 'legacy', 'legacy/*']), - long_description=_load_readme_description(PATH_ROOT), + long_description=long_description, long_description_content_type='text/markdown', include_package_data=True, zip_safe=False, keywords=['deep learning', 'pytorch', 'AI'], python_requires='>=3.6', setup_requires=[], - install_requires=_load_requirements(PATH_ROOT), + install_requires=setup_tools._load_requirements(_PATH_ROOT), extras_require=extras, project_urls={ "Bug Tracker": "https://github.com/PyTorchLightning/pytorch-lightning/issues", diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 81a5132e473569..46379a9d10c14f 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -3,10 +3,12 @@ import pytest import torch +from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel def test_unsupported_precision_plugins(): @@ -18,3 +20,33 @@ def test_unsupported_precision_plugins(): ) with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."): accelerator.setup(trainer=trainer, model=model) + + +@pytest.mark.parametrize("delay_dispatch", [True, False]) +def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch): + """ + Test when using a custom training type plugin that delays setup optimizers, + we do not call setup optimizers till ``pre_dispatch``. + """ + + class TestModel(BoringModel): + + def on_fit_start(self): + if delay_dispatch: + # Ensure we haven't setup optimizers if we've delayed dispatch + assert len(self.trainer.optimizers) == 0 + else: + assert len(self.trainer.optimizers) > 0 + + def on_fit_end(self): + assert len(self.trainer.optimizers) > 0 + + class CustomPlugin(SingleDevicePlugin): + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + return delay_dispatch + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu"))) + trainer.fit(model) diff --git a/tests/checkpointing/test_torch_saving.py b/tests/checkpointing/test_torch_saving.py index c8b1e96aeaf0a0..8eabc4640046f3 100644 --- a/tests/checkpointing/test_torch_saving.py +++ b/tests/checkpointing/test_torch_saving.py @@ -47,6 +47,7 @@ def test_model_torch_save_ddp_cpu(tmpdir): max_epochs=num_epochs, accelerator="ddp_cpu", num_processes=2, + logger=False, ) temp_path = os.path.join(tmpdir, 'temp.pt') trainer.fit(model) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index f449a37e33c25a..725db1180d9e82 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -20,6 +20,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.profiler import BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache from tests.deprecated_api import no_deprecated_call from tests.helpers import BoringModel @@ -80,6 +81,11 @@ def on_save_checkpoint(self, *args): trainer.save_checkpoint(filepath) +def test_v1_5_0_legacy_profiler_argument(): + with pytest.deprecated_call(match="renamed to `record_functions` in v1.3"): + PyTorchProfiler(profiled_functions=[]) + + def test_v1_5_0_running_sanity_check(): trainer = Trainer() with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): @@ -203,3 +209,12 @@ def on_test_epoch_end(self, outputs): model = NewSignatureModel() with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."): trainer.test(model) + + +@pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) +def test_v1_5_0_profiler_output_filename(tmpdir, cls): + filepath = str(tmpdir / "test.txt") + with pytest.deprecated_call(match="`output_filename` parameter has been removed"): + profiler = cls(output_filename=filepath) + assert profiler.dirpath == tmpdir + assert profiler.filename == "test" diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index fe85fbaea90255..5483e33d9cddb4 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -56,6 +56,7 @@ def __new__( *args, min_gpus: int = 0, min_torch: Optional[str] = None, + max_torch: Optional[str] = None, min_python: Optional[str] = None, quantization: bool = False, amp_apex: bool = False, @@ -76,6 +77,7 @@ def __new__( args: native pytest.mark.skipif arguments min_gpus: min number of gpus required to run test min_torch: minimum pytorch version to run test + max_torch: maximum pytorch version to run test min_python: minimum python version required to run test quantization: if `torch.quantization` package is required to run test amp_apex: NVIDIA Apex is installed @@ -102,6 +104,11 @@ def __new__( conditions.append(torch_version < LooseVersion(min_torch)) reasons.append(f"torch>={min_torch}") + if max_torch: + torch_version = LooseVersion(get_distribution("torch").version) + conditions.append(torch_version >= LooseVersion(max_torch)) + reasons.append(f"torch<{max_torch}") + if min_python: py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" conditions.append(py_version < LooseVersion(min_python)) diff --git a/tests/metrics/functional/test_retrieval.py b/tests/metrics/functional/test_retrieval.py deleted file mode 100644 index a0573cba1d27ed..00000000000000 --- a/tests/metrics/functional/test_retrieval.py +++ /dev/null @@ -1,36 +0,0 @@ -import math - -import numpy as np -import pytest -import torch -from sklearn.metrics import average_precision_score as sk_average_precision - -from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - pytest.param(sk_average_precision, retrieval_average_precision), -]) -def test_against_sklearn(sklearn_metric, torch_metric): - """Compare PL metrics to sklearn version. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - seed_everything(0) - - rounds = 25 - sizes = [1, 4, 10, 100] - - for size in sizes: - for _ in range(rounds): - a = np.random.randn(size) - b = np.random.randn(size) > 0 - - sk = torch.tensor(sklearn_metric(b, a), device=device) - pl = torch_metric(torch.tensor(a, device=device), torch.tensor(b, device=device)) - - # `torch_metric`s return 0 when no label is True - # while `sklearn.average_precision_score` returns NaN - if math.isnan(sk): - assert pl == 0 - else: - assert torch.allclose(sk.float(), pl.float()) diff --git a/tests/metrics/regression/test_explained_variance.py b/tests/metrics/regression/test_explained_variance.py deleted file mode 100644 index adab562ac60552..00000000000000 --- a/tests/metrics/regression/test_explained_variance.py +++ /dev/null @@ -1,77 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import explained_variance_score - -from pytorch_lightning.metrics.functional import explained_variance -from pytorch_lightning.metrics.regression import ExplainedVariance -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_fn(sk_target, sk_preds) - - -def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - return sk_fn(sk_target, sk_preds) - - -@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), - ], -) -class TestExplainedVariance(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - ExplainedVariance, - partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), - dist_sync_on_step, - metric_args=dict(multioutput=multioutput), - ) - - def test_explained_variance_functional(self, multioutput, preds, target, sk_metric): - self.run_functional_metric_test( - preds, - target, - explained_variance, - partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), - metric_args=dict(multioutput=multioutput), - ) - - -def test_error_on_different_shape(metric_class=ExplainedVariance): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/metrics/regression/test_mean_error.py b/tests/metrics/regression/test_mean_error.py deleted file mode 100644 index 041ce12f11164c..00000000000000 --- a/tests/metrics/regression/test_mean_error.py +++ /dev/null @@ -1,87 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error -from sklearn.metrics import mean_squared_error as sk_mean_squared_error -from sklearn.metrics import mean_squared_log_error as sk_mean_squared_log_error - -from pytorch_lightning.metrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error -from pytorch_lightning.metrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_fn(sk_preds, sk_target) - - -def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - return sk_fn(sk_preds, sk_target) - - -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), - ], -) -@pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn", - [ - (MeanSquaredError, mean_squared_error, sk_mean_squared_error), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error), - (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error), - ], -) -class TestMeanError(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_mean_error_class( - self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step - ): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=metric_class, - sk_metric=partial(sk_metric, sk_fn=sk_fn), - dist_sync_on_step=dist_sync_on_step, - ) - - def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): - self.run_functional_metric_test( - preds=preds, - target=target, - metric_functional=metric_functional, - sk_metric=partial(sk_metric, sk_fn=sk_fn), - ) - - -@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError]) -def test_error_on_different_shape(metric_class): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) diff --git a/tests/metrics/regression/test_psnr.py b/tests/metrics/regression/test_psnr.py deleted file mode 100644 index eb07fffb9d55c2..00000000000000 --- a/tests/metrics/regression/test_psnr.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import namedtuple -from functools import partial - -import numpy as np -import pytest -import torch -from skimage.metrics import peak_signal_noise_ratio - -from pytorch_lightning.metrics.functional import psnr -from pytorch_lightning.metrics.regression import PSNR -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -Input = namedtuple('Input', ["preds", "target"]) - -_input_size = (NUM_BATCHES, BATCH_SIZE, 32, 32) -_inputs = [ - Input( - preds=torch.randint(n_cls_pred, _input_size, dtype=torch.float), - target=torch.randint(n_cls_target, _input_size, dtype=torch.float), - ) for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)] -] - - -def _to_sk_peak_signal_noise_ratio_inputs(value, dim): - value = value.numpy() - batches = value[None] if value.ndim == len(_input_size) - 1 else value - - if dim is None: - return [batches] - - num_dims = np.size(dim) - if not num_dims: - return batches - - inputs = [] - for batch in batches: - batch = np.moveaxis(batch, dim, np.arange(-num_dims, 0)) - psnr_input_shape = batch.shape[-num_dims:] - inputs.extend(batch.reshape(-1, *psnr_input_shape)) - return inputs - - -def _sk_psnr(preds, target, data_range, reduction, dim): - sk_preds_lists = _to_sk_peak_signal_noise_ratio_inputs(preds, dim=dim) - sk_target_lists = _to_sk_peak_signal_noise_ratio_inputs(target, dim=dim) - np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum} - return np_reduce_map[reduction]([ - peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range) - for sk_target, sk_preds in zip(sk_target_lists, sk_preds_lists) - ]) - - -def _base_e_sk_psnr(preds, target, data_range, reduction, dim): - return _sk_psnr(preds, target, data_range, reduction, dim) * np.log(10) - - -@pytest.mark.parametrize( - "preds, target, data_range, reduction, dim", - [ - (_inputs[0].preds, _inputs[0].target, 10, "elementwise_mean", None), - (_inputs[1].preds, _inputs[1].target, 10, "elementwise_mean", None), - (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", None), - (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", 1), - (_inputs[2].preds, _inputs[2].target, 5, "elementwise_mean", (1, 2)), - (_inputs[2].preds, _inputs[2].target, 5, "sum", (1, 2)), - ], -) -@pytest.mark.parametrize( - "base, sk_metric", - [ - (10.0, _sk_psnr), - (2.718281828459045, _base_e_sk_psnr), - ], -) -class TestPSNR(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_psnr(self, preds, target, data_range, base, reduction, dim, sk_metric, ddp, dist_sync_on_step): - _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} - self.run_class_metric_test( - ddp, - preds, - target, - PSNR, - partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim), - metric_args=_args, - dist_sync_on_step=dist_sync_on_step, - ) - - def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduction, dim): - _args = {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim} - self.run_functional_metric_test( - preds, - target, - psnr, - partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim), - metric_args=_args, - ) - - -@pytest.mark.parametrize("reduction", ["none", "sum"]) -def test_reduction_for_dim_none(reduction): - match = f"The `reduction={reduction}` will not have any effect when `dim` is None." - with pytest.warns(UserWarning, match=match): - PSNR(reduction=reduction, dim=None) - - with pytest.warns(UserWarning, match=match): - psnr(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None) - - -def test_missing_data_range(): - with pytest.raises(ValueError): - PSNR(data_range=None, dim=0) - - with pytest.raises(ValueError): - psnr(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0) diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py deleted file mode 100644 index 232b003e6116a7..00000000000000 --- a/tests/metrics/regression/test_r2score.py +++ /dev/null @@ -1,114 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from sklearn.metrics import r2_score as sk_r2score - -from pytorch_lightning.metrics.functional import r2score -from pytorch_lightning.metrics.regression import R2Score -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _single_target_sk_metric(preds, target, adjusted, multioutput): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) - if adjusted != 0: - r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) - return r2_score - - -def _multi_target_sk_metric(preds, target, adjusted, multioutput): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) - if adjusted != 0: - r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) - return r2_score - - -@pytest.mark.parametrize("adjusted", [0, 5, 10]) -@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_outputs", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric, 1), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric, num_targets), - ], -) -class TestR2Score(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - R2Score, - partial(sk_metric, adjusted=adjusted, multioutput=multioutput), - dist_sync_on_step, - metric_args=dict(adjusted=adjusted, multioutput=multioutput, num_outputs=num_outputs), - ) - - def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs): - self.run_functional_metric_test( - preds, - target, - r2score, - partial(sk_metric, adjusted=adjusted, multioutput=multioutput), - metric_args=dict(adjusted=adjusted, multioutput=multioutput), - ) - - -def test_error_on_different_shape(metric_class=R2Score): - metric = metric_class() - with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): - metric(torch.randn(100, ), torch.randn(50, )) - - -def test_error_on_multidim_tensors(metric_class=R2Score): - metric = metric_class() - with pytest.raises( - ValueError, - match=r'Expected both prediction and target to be 1D or 2D tensors,' - r' but recevied tensors with dimension .' - ): - metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5)) - - -def test_error_on_too_few_samples(metric_class=R2Score): - metric = metric_class() - with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'): - metric(torch.randn(1, ), torch.randn(1, )) - - -def test_warning_on_too_large_adjusted(metric_class=R2Score): - metric = metric_class(adjusted=10) - - with pytest.warns( - UserWarning, - match="More independent regressions than datapoints in" - " adjusted r2 score. Falls back to standard r2 score." - ): - metric(torch.randn(10, ), torch.randn(10, )) - - with pytest.warns(UserWarning, match="Division by zero in adjusted r2 score. Falls back to" " standard r2 score."): - metric(torch.randn(11, ), torch.randn(11, )) diff --git a/tests/metrics/regression/test_ssim.py b/tests/metrics/regression/test_ssim.py deleted file mode 100644 index f7e4b7a58e0011..00000000000000 --- a/tests/metrics/regression/test_ssim.py +++ /dev/null @@ -1,104 +0,0 @@ -from collections import namedtuple -from functools import partial - -import pytest -import torch -from skimage.metrics import structural_similarity - -from pytorch_lightning.metrics.functional import ssim -from pytorch_lightning.metrics.regression import SSIM -from tests.metrics.utils import BATCH_SIZE, MetricTester, NUM_BATCHES - -torch.manual_seed(42) - -Input = namedtuple('Input', ["preds", "target", "multichannel"]) - -_inputs = [] -for size, channel, coef, multichannel, dtype in [ - (12, 3, 0.9, True, torch.float), - (13, 1, 0.8, False, torch.float32), - (14, 1, 0.7, False, torch.double), - (15, 3, 0.6, True, torch.float64), -]: - preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) - _inputs.append(Input( - preds=preds, - target=preds * coef, - multichannel=multichannel, - )) - - -def _sk_metric(preds, target, data_range, multichannel): - c, h, w = preds.shape[-3:] - sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() - sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() - if not multichannel: - sk_preds = sk_preds[:, :, :, 0] - sk_target = sk_target[:, :, :, 0] - - return structural_similarity( - sk_target, - sk_preds, - data_range=data_range, - multichannel=multichannel, - gaussian_weights=True, - win_size=11, - sigma=1.5, - use_sample_covariance=False - ) - - -@pytest.mark.parametrize( - "preds, target, multichannel", - [(i.preds, i.target, i.multichannel) for i in _inputs], -) -class TestSSIM(MetricTester): - atol = 6e-5 - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp, - preds, - target, - SSIM, - partial(_sk_metric, data_range=1.0, multichannel=multichannel), - metric_args={"data_range": 1.0}, - dist_sync_on_step=dist_sync_on_step, - ) - - def test_ssim_functional(self, preds, target, multichannel): - self.run_functional_metric_test( - preds, - target, - ssim, - partial(_sk_metric, data_range=1.0, multichannel=multichannel), - metric_args={"data_range": 1.0}, - ) - - -@pytest.mark.parametrize( - ['pred', 'target', 'kernel', 'sigma'], - [ - pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input - ], -) -def test_ssim_invalid_inputs(pred, target, kernel, sigma): - pred_t = torch.rand(pred) - target_t = torch.rand(target, dtype=torch.float64) - with pytest.raises(TypeError): - ssim(pred_t, target_t) - - pred = torch.rand(pred) - target = torch.rand(target) - with pytest.raises(ValueError): - ssim(pred, target, kernel, sigma) diff --git a/tests/metrics/retrieval/__init__.py b/tests/metrics/retrieval/__init__.py deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/tests/metrics/retrieval/test_map.py b/tests/metrics/retrieval/test_map.py deleted file mode 100644 index fe43f19b20eb67..00000000000000 --- a/tests/metrics/retrieval/test_map.py +++ /dev/null @@ -1,119 +0,0 @@ -import math -import random -from typing import Callable, List - -import numpy as np -import pytest -import torch -from sklearn.metrics import average_precision_score as sk_average_precision -from torchmetrics import Metric - -from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_class_metric'], [ - [sk_average_precision, RetrievalMAP], -]) -def test_against_sklearn(sklearn_metric: Callable, torch_class_metric: Metric) -> None: - """Compare PL metrics to sklearn version. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - seed_everything(0) - - rounds = 20 - sizes = [1, 4, 10, 100] - batch_sizes = [1, 4, 10] - query_without_relevant_docs_options = ['skip', 'pos', 'neg'] - - def compute_sklearn_metric(target: List[np.ndarray], preds: List[np.ndarray], behaviour: str) -> torch.Tensor: - """ Compute sk metric with multiple iterations using the base `sklearn_metric`. """ - sk_results = [] - kwargs = {'device': device, 'dtype': torch.float32} - - for b, a in zip(target, preds): - res = sklearn_metric(b, a) - - if math.isnan(res): - if behaviour == 'skip': - pass - elif behaviour == 'pos': - sk_results.append(torch.tensor(1.0, **kwargs)) - else: - sk_results.append(torch.tensor(0.0, **kwargs)) - else: - sk_results.append(torch.tensor(res, **kwargs)) - if len(sk_results) > 0: - sk_results = torch.stack(sk_results).mean() - else: - sk_results = torch.tensor(0.0, **kwargs) - - return sk_results - - def do_test(batch_size: int, size: int) -> None: - """ For each possible behaviour of the metric, check results are correct. """ - for behaviour in query_without_relevant_docs_options: - metric = torch_class_metric(query_without_relevant_docs=behaviour) - shape = (size, ) - - indexes = [] - preds = [] - target = [] - - for i in range(batch_size): - indexes.append(np.ones(shape, dtype=int) * i) - preds.append(np.random.randn(*shape)) - target.append(np.random.randn(*shape) > 0) - - sk_results = compute_sklearn_metric(target, preds, behaviour) - - indexes_tensor = torch.cat([torch.tensor(i) for i in indexes]) - preds_tensor = torch.cat([torch.tensor(p) for p in preds]) - target_tensor = torch.cat([torch.tensor(t) for t in target]) - - # lets assume data are not ordered - perm = torch.randperm(indexes_tensor.nelement()) - indexes_tensor = indexes_tensor.view(-1)[perm].view(indexes_tensor.size()) - preds_tensor = preds_tensor.view(-1)[perm].view(preds_tensor.size()) - target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size()) - - # shuffle ids to require also sorting of documents ability from the lightning metric - pl_result = metric(indexes_tensor, preds_tensor, target_tensor) - - assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=True) - - for batch_size in batch_sizes: - for size in sizes: - for _ in range(rounds): - do_test(batch_size, size) - - -@pytest.mark.parametrize(['torch_class_metric'], [ - [RetrievalMAP], -]) -def test_input_data(torch_class_metric: Metric) -> None: - """Check PL metrics inputs are controlled correctly. """ - - device = 'cuda' if torch.cuda.is_available() else 'cpu' - seed_everything(0) - - for _ in range(10): - - length = random.randint(0, 20) - - # check error when `query_without_relevant_docs='error'` is raised correctly - indexes = torch.tensor([0] * length, device=device, dtype=torch.int64) - preds = torch.rand(size=(length, ), device=device, dtype=torch.float32) - target = torch.tensor([False] * length, device=device, dtype=torch.bool) - - metric = torch_class_metric(query_without_relevant_docs='error') - - try: - metric(indexes, preds, target) - except Exception as e: - assert isinstance(e, ValueError) - - # check ValueError with non-accepted argument - try: - metric = torch_class_metric(query_without_relevant_docs='casual_argument') - except Exception as e: - assert isinstance(e, ValueError) diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index b9a12eb7b76010..c50bed29cc0adb 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -22,15 +22,22 @@ AUROC, AveragePrecision, ConfusionMatrix, + ExplainedVariance, F1, FBeta, HammingDistance, IoU, + MeanAbsoluteError, + MeanSquaredError, + MeanSquaredLogError, MetricCollection, Precision, PrecisionRecallCurve, + PSNR, + R2Score, Recall, ROC, + SSIM, StatScores, ) from pytorch_lightning.metrics.functional import ( @@ -38,18 +45,26 @@ auroc, average_precision, confusion_matrix, + explained_variance, f1, fbeta, hamming_distance, iou, + mean_absolute_error, + mean_squared_error, + mean_squared_log_error, precision, precision_recall, precision_recall_curve, + psnr, + r2score, recall, roc, + ssim, stat_scores, ) from pytorch_lightning.metrics.functional.accuracy import accuracy +from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -229,8 +244,85 @@ def test_v1_5_metric_detect(): IoU(num_classes=1) target = torch.randint(0, 2, (10, 25, 25)) - pred = torch.tensor(target) - pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] + preds = torch.tensor(target) + preds[2:5, 7:13, 9:15] = 1 - preds[2:5, 7:13, 9:15] iou._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): - assert torch.allclose(iou(pred, target), torch.tensor(0.9660), atol=1e-4) + assert torch.allclose(iou(preds, target), torch.tensor(0.9660), atol=1e-4) + + +def test_v1_5_metric_regress(): + ExplainedVariance.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + ExplainedVariance() + + MeanAbsoluteError.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanAbsoluteError() + + MeanSquaredError.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanSquaredError() + + MeanSquaredLogError.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + MeanSquaredLogError() + + target = torch.tensor([3, -0.5, 2, 7]) + preds = torch.tensor([2.5, 0.0, 2, 8]) + explained_variance.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = explained_variance(preds, target) + assert torch.allclose(res, torch.tensor(0.9572), atol=1e-4) + + x = torch.tensor([0., 1, 2, 3]) + y = torch.tensor([0., 1, 2, 2]) + mean_absolute_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_absolute_error(x, y) == 0.25 + + mean_relative_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_relative_error(x, y) == 0.125 + + mean_squared_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert mean_squared_error(x, y) == 0.25 + + mean_squared_log_error.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = mean_squared_log_error(x, y) + assert torch.allclose(res, torch.tensor(0.0207), atol=1e-4) + + PSNR.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + PSNR() + + R2Score.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + R2Score() + + SSIM.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + SSIM() + + preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + psnr.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = psnr(preds, target) + assert torch.allclose(res, torch.tensor(2.5527), atol=1e-4) + + target = torch.tensor([3, -0.5, 2, 7]) + preds = torch.tensor([2.5, 0.0, 2, 8]) + r2score.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = r2score(preds, target) + assert torch.allclose(res, torch.tensor(0.9486), atol=1e-4) + + preds = torch.rand([16, 1, 16, 16]) + target = preds * 0.75 + ssim.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + res = ssim(preds, target) + assert torch.allclose(res, torch.tensor(0.9219), atol=1e-4) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 3921e7ef33b8e1..aaf47c82d5f087 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -24,7 +24,7 @@ ("training", "training_step"), ("testing", "test_step"), ("validating", "validation_step"), - ("predicting", "predict"), + ("predicting", "predict_step"), ] ) def test_lightning_wrapper_module_methods(wrapper_class, stage): diff --git a/tests/special_tests.sh b/tests/special_tests.sh index dd67af470c4ec5..c381b5e9feeb6a 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -14,7 +14,7 @@ # Running special tests set -e export PL_RUNNING_SPECIAL_TESTS=1 -DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" +DEFAULTS="-m coverage run --source pytorch_lightning --append -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_invalid_deepspeed_defaults_no_precision @@ -34,9 +34,9 @@ python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp python ${DEFAULTS} tests/callbacks/test_pruning.py::test_pruning_callback_ddp -python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp +python ${DEFAULTS} tests/test_profiler.py::test_pytorch_profiler_trainer_ddp python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model python ${DEFAULTS} tests/checkpointing/test_checkpoint_callback_frequency.py::test_top_k_ddp -nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_nested_emit_nvtx +nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx diff --git a/tests/test_profiler.py b/tests/test_profiler.py index 9b51ca7f7c6d2c..a6e33b3366f33e 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -13,13 +13,22 @@ # limitations under the License. import logging import os +import platform import time -from pathlib import Path +from copy import deepcopy +from distutils.version import LooseVersion import numpy as np import pytest +import torch -from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler +from pytorch_lightning.profiler.pytorch import RegisterRecordFunction +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005 @@ -40,14 +49,7 @@ def _sleep_generator(durations): @pytest.fixture def simple_profiler(): - profiler = SimpleProfiler() - return profiler - - -@pytest.fixture -def advanced_profiler(tmpdir): - profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt")) - return profiler + return SimpleProfiler() @pytest.mark.parametrize(["action", "expected"], [ @@ -93,14 +95,6 @@ def test_simple_profiler_overhead(simple_profiler, n_iter=5): assert all(durations < PROFILER_OVERHEAD_MAX_TOLERANCE) -def test_simple_profiler_describe(caplog, simple_profiler): - """Ensure the profiler won't fail when reporting the summary.""" - with caplog.at_level(logging.INFO): - simple_profiler.describe() - - assert "Profiler Report" in caplog.text - - def test_simple_profiler_value_errors(simple_profiler): """Ensure errors are raised where expected.""" @@ -116,6 +110,77 @@ def test_simple_profiler_value_errors(simple_profiler): simple_profiler.stop(action) +def test_simple_profiler_deepcopy(tmpdir): + simple_profiler = SimpleProfiler(dirpath=tmpdir, filename="test") + simple_profiler.describe() + assert deepcopy(simple_profiler) + + +def test_simple_profiler_log_dir(tmpdir): + """Ensure the profiler dirpath defaults to `trainer.log_dir` when not present""" + profiler = SimpleProfiler(filename="profiler") + assert profiler._log_dir is None + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + profiler=profiler, + ) + trainer.fit(model) + + expected = tmpdir / "lightning_logs" / "version_0" + assert trainer.log_dir == expected + assert profiler._log_dir == trainer.log_dir + assert expected.join("fit-profiler.txt").exists() + + +@RunIf(skip_windows=True) +def test_simple_profiler_distributed_files(tmpdir): + """Ensure the proper files are saved in distributed""" + profiler = SimpleProfiler(dirpath=tmpdir, filename='profiler') + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=2, + accelerator="ddp_cpu", + num_processes=2, + profiler=profiler, + logger=False, + ) + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + + actual = set(os.listdir(profiler.dirpath)) + expected = {f"{stage}-profiler-{rank}.txt" for stage in ("fit", "validate", "test") for rank in (0, 1)} + assert actual == expected + + for f in profiler.dirpath.listdir(): + assert f.read_text('utf-8') + + +def test_simple_profiler_logs(tmpdir, caplog, simple_profiler): + """Ensure that the number of printed logs is correct""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=2, + profiler=simple_profiler, + logger=False, + ) + with caplog.at_level(logging.INFO, logger="pytorch_lightning.profiler.profilers"): + trainer.fit(model) + trainer.test(model) + + assert caplog.text.count("Profiler Report") == 2 + + +@pytest.fixture +def advanced_profiler(tmpdir): + return AdvancedProfiler(dirpath=tmpdir, filename="profiler") + + @pytest.mark.parametrize(["action", "expected"], [ pytest.param("a", [3, 1]), pytest.param("b", [2]), @@ -174,7 +239,8 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler): pass # log to stdout and print to file advanced_profiler.describe() - data = Path(advanced_profiler.output_fname).read_text() + path = advanced_profiler.dirpath / f"{advanced_profiler.filename}.txt" + data = path.read_text("utf-8") assert len(data) > 0 @@ -187,3 +253,259 @@ def test_advanced_profiler_value_errors(advanced_profiler): advanced_profiler.start(action) advanced_profiler.stop(action) + + +def test_advanced_profiler_deepcopy(advanced_profiler): + advanced_profiler.describe() + assert deepcopy(advanced_profiler) + + +@pytest.fixture +def pytorch_profiler(tmpdir): + return PyTorchProfiler(dirpath=tmpdir, filename="profiler") + + +@RunIf(max_torch="1.8.1") +def test_pytorch_profiler_describe(pytorch_profiler): + """Ensure the profiler won't fail when reporting the summary.""" + with pytorch_profiler.profile("on_test_start"): + torch.tensor(0) + + # log to stdout and print to file + pytorch_profiler.describe() + path = pytorch_profiler.dirpath / f"{pytorch_profiler.filename}.txt" + data = path.read_text("utf-8") + assert len(data) > 0 + + +def test_pytorch_profiler_raises(pytorch_profiler): + """Ensure errors are raised where expected.""" + with pytest.raises(MisconfigurationException, match="profiled_functions` and `PyTorchProfiler.record"): + PyTorchProfiler(profiled_functions=["a"], record_functions=["b"]) + + +@RunIf(min_torch="1.6.0") +def test_advanced_profiler_cprofile_deepcopy(tmpdir): + """Checks for pickle issue reported in #6522""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + profiler="advanced", + stochastic_weight_avg=True, + ) + trainer.fit(model) + + +@RunIf(min_gpus=2, special=True) +def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler): + """Ensure that the profiler can be given to the training and default step are properly recorded. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=5, + limit_val_batches=5, + profiler=pytorch_profiler, + accelerator="ddp", + gpus=2, + ) + trainer.fit(model) + + expected = {'validation_step'} + if not _KINETO_AVAILABLE: + expected |= {'training_step_and_backward', 'training_step', 'backward'} + for name in expected: + assert sum(e.name == name for e in pytorch_profiler.function_events), name + + files = set(os.listdir(pytorch_profiler.dirpath)) + expected = f"fit-profiler-{trainer.local_rank}.txt" + assert expected in files + + path = pytorch_profiler.dirpath / expected + assert path.read_text("utf-8") + + if _KINETO_AVAILABLE: + files = os.listdir(pytorch_profiler.dirpath) + files = [file for file in files if file.endswith('.json')] + assert len(files) == 2, files + local_rank = trainer.local_rank + assert any(f'training_step_{local_rank}' in f for f in files) + assert any(f'validation_step_{local_rank}' in f for f in files) + + +def test_pytorch_profiler_trainer_test(tmpdir): + """Ensure that the profiler can be given to the trainer and test step are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_test_batches=2, + profiler=pytorch_profiler, + ) + trainer.test(model) + + assert sum(e.name == 'test_step' for e in pytorch_profiler.function_events) + + path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + if _KINETO_AVAILABLE: + files = sorted([file for file in os.listdir(tmpdir) if file.endswith('.json')]) + assert any(f'test_step_{trainer.local_rank}' in f for f in files) + + +def test_pytorch_profiler_trainer_predict(tmpdir): + """Ensure that the profiler can be given to the trainer and predict function are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) + model = BoringModel() + model.predict_dataloader = model.train_dataloader + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_predict_batches=2, + profiler=pytorch_profiler, + ) + trainer.predict(model) + + assert sum(e.name == 'predict_step' for e in pytorch_profiler.function_events) + path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + +def test_pytorch_profiler_trainer_validate(tmpdir): + """Ensure that the profiler can be given to the trainer and validate function are properly recorded. """ + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None) + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=2, + profiler=pytorch_profiler, + ) + trainer.validate(model) + + assert sum(e.name == 'validation_step' for e in pytorch_profiler.function_events) + + path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt" + assert path.read_text("utf-8") + + +def test_pytorch_profiler_nested(tmpdir): + """Ensure that the profiler handles nested context""" + + pytorch_profiler = PyTorchProfiler( + record_functions={"a", "b", "c"}, use_cuda=False, dirpath=tmpdir, filename="profiler", schedule=None + ) + + with pytorch_profiler.profile("a"): + a = torch.ones(42) + with pytorch_profiler.profile("b"): + b = torch.zeros(42) + with pytorch_profiler.profile("c"): + _ = a + b + + pytorch_profiler.describe() + + events_name = {e.name for e in pytorch_profiler.function_events} + + if platform.system() == "Windows": + expected = {'a', 'add', 'b', 'c', 'profiler::_record_function_enter', 'profiler::_record_function_exit'} + else: + expected = { + 'signed char', 'add', 'profiler::_record_function_exit', 'bool', 'char', 'profiler::_record_function_enter' + } + + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + expected = {'add', 'zeros', 'ones', 'zero_', 'b', 'fill_', 'c', 'a', 'empty'} + + if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"): + expected = { + 'aten::zeros', 'aten::add', 'aten::zero_', 'c', 'b', 'a', 'aten::fill_', 'aten::empty', 'aten::ones' + } + + assert events_name == expected, (events_name, torch.__version__, platform.system()) + + +@RunIf(min_gpus=1, special=True) +def test_pytorch_profiler_nested_emit_nvtx(tmpdir): + """ + This test check emit_nvtx is correctly supported + """ + profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + profiler=profiler, + gpus=1, + ) + trainer.fit(model) + + +@RunIf(min_torch="1.5.0") +def test_register_record_function(tmpdir): + + use_cuda = torch.cuda.is_available() + pytorch_profiler = PyTorchProfiler( + export_to_chrome=False, + record_functions={"a"}, + use_cuda=use_cuda, + dirpath=tmpdir, + filename="profiler", + schedule=None, + on_trace_ready=None, + ) + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), torch.nn.Linear(1, 1)) + + model = TestModel() + input = torch.rand((1, 1)) + + if use_cuda: + model = model.cuda() + input = input.cuda() + + with pytorch_profiler.profile("a"): + with RegisterRecordFunction(model): + model(input) + + pytorch_profiler.describe() + event_names = [e.name for e in pytorch_profiler.function_events] + assert 'torch.nn.modules.container.Sequential: layer' in event_names + assert 'torch.nn.modules.linear.Linear: layer.0' in event_names + assert 'torch.nn.modules.activation.ReLU: layer.1' in event_names + assert 'torch.nn.modules.linear.Linear: layer.2' in event_names + + +@pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) +def test_profiler_teardown(tmpdir, cls): + """ + This test checks if profiler teardown method is called when trainer is exiting. + """ + + class TestCallback(Callback): + + def on_fit_end(self, trainer, *args, **kwargs) -> None: + # describe sets it to None + assert trainer.profiler._output_file is None + + profiler = cls(dirpath=tmpdir, filename="profiler") + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()]) + trainer.fit(model) + + assert profiler._output_file is None + + +def test_pytorch_profiler_deepcopy(tmpdir): + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler", schedule=None) + pytorch_profiler.start("on_train_start") + torch.tensor(1) + pytorch_profiler.describe() + assert deepcopy(pytorch_profiler) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 4dc5b5f34b50ca..5dc1ea5de4e8a0 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -13,7 +13,6 @@ # limitations under the License. from pytorch_lightning import Trainer -from tests.accelerators import DDPLauncher from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -81,25 +80,3 @@ def test_get_model_gpu(tmpdir): gpus=1, ) trainer.fit(model) - - -@RunIf(min_gpus=1, skip_windows=True) -@DDPLauncher.run("--accelerator [accelerator]", max_epochs=["1"], accelerator=["ddp", "ddp_spawn"]) -def test_get_model_ddp_gpu(tmpdir, args=None): - """ - Tests that `trainer.lightning_module` extracts the model correctly when using GPU + ddp accelerators - """ - - model = TrainerGetModel() - - limit_train_batches = 2 - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=limit_train_batches, - limit_val_batches=2, - max_epochs=1, - gpus=1, - accelerator=args.accelerator - ) - trainer.fit(model) - return 1 diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 59e10480a485e1..9fccd9b36440ae 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +import torch -from pytorch_lightning import Trainer +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset def test_wrong_train_setting(tmpdir): @@ -101,3 +102,48 @@ def test_val_loop_config(tmpdir): model = BoringModel() model.validation_step = None trainer.validate(model) + + +@pytest.mark.parametrize("datamodule", [False, True]) +def test_trainer_predict_verify_config(tmpdir, datamodule): + + class TestModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + class TestLightningDataModule(LightningDataModule): + + def __init__(self, dataloaders): + super().__init__() + self._dataloaders = dataloaders + + def test_dataloader(self): + return self._dataloaders + + def predict_dataloader(self): + return self._dataloaders + + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + + model = TestModel() + + trainer = Trainer(default_root_dir=tmpdir) + + if datamodule: + datamodule = TestLightningDataModule(dataloaders) + results = trainer.predict(model, datamodule=datamodule) + else: + results = trainer.predict(model, dataloaders=dataloaders) + + assert len(results) == 2 + assert results[0][0].shape == torch.Size([1, 2]) + + model.predict_dataloader = None + + with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): + trainer.predict(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 52c51777e2a893..505af173b79108 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1159,3 +1159,71 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir): new_data_loader = trainer.replace_sampler(train, SequentialSampler(train.dataset)) assert (new_data_loader.multiprocessing_context == train.multiprocessing_context) + + +def test_request_dataloader(tmpdir): + """ + This test asserts dataloader can be modified and properly set to the trainer. + """ + + class DataLoaderWrapper: + + def __init__(self, loader): + self.loader = loader + self._iter = iter(self.loader) + + def __iter__(self): + self._iter = iter(self.loader) + return self._iter + + def __next__(self): + return next(self._iter) + + class DataLoaderFunc: + + def __init__(self, loader): + self.loader = loader + + def __call__(self): + return self.loader + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.on_train_dataloader_called = False + self.on_train_batch_start_called = False + self.on_val_dataloader_called = False + self.on_val_batch_start_called = False + + def on_train_dataloader(self) -> None: + loader = self.train_dataloader() + self.train_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) + self.on_train_dataloader_called = True + + def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: + assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper) + self.on_train_batch_start_called = True + + def on_val_dataloader(self) -> None: + loader = self.val_dataloader() + self.val_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) + self.on_val_dataloader_called = True + + def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: + assert isinstance(self.trainer.val_dataloaders[0], DataLoaderWrapper) + self.on_val_batch_start_called = True + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + model = TestModel() + trainer.fit(model) + trainer.test(model) + assert model.on_train_dataloader_called + assert model.on_train_batch_start_called + assert model.on_val_dataloader_called + assert model.on_val_batch_start_called diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3375b02c5496b0..490f205a7bbec2 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -17,7 +17,6 @@ import sys from argparse import Namespace from copy import deepcopy -from distutils.version import LooseVersion from pathlib import Path from unittest.mock import ANY, call, patch @@ -43,12 +42,6 @@ from tests.helpers.runif import RunIf -@pytest.fixture -def pytorch_profiler(tmpdir): - profiler = PyTorchProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"), local_rank=0) - return profiler - - @pytest.mark.parametrize("url_ckpt", [True, False]) def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently.""" @@ -1447,16 +1440,29 @@ def test_trainer_predict_no_return(tmpdir): class CustomBoringModel(BoringModel): - def predict(self, batch, batch_idx, dataloader_idx=None): + def predict_step(self, batch, batch_idx, dataloader_idx=None): if (batch_idx + 1) % 2 == 0: return - return super().predict(batch, batch_idx, dataloader_idx) + return super().predict_step(batch, batch_idx, dataloader_idx) with pytest.warns(UserWarning, match='predict returned None'): predict(tmpdir, None, None, 1, model=CustomBoringModel()) +def test_trainer_predict_grad(tmpdir): + class CustomBoringModel(BoringModel): + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch.expand_as(batch).grad_fn is None + return super().predict_step(batch, batch_idx, dataloader_idx) + + predict(tmpdir, None, None, 1, model=CustomBoringModel()) + + x = torch.zeros(1, requires_grad=True) + assert x.expand_as(x).grad_fn is not None + + @pytest.mark.parametrize('datamodule', [False, True]) def test_trainer_predict_cpu(tmpdir, datamodule): predict(tmpdir, None, None, 1, datamodule=datamodule) @@ -1488,124 +1494,6 @@ def test_trainer_predict_ddp_cpu(tmpdir): predict(tmpdir, "ddp_cpu", 0, 2) -def test_pytorch_profiler_describe(pytorch_profiler): - """Ensure the profiler won't fail when reporting the summary.""" - with pytorch_profiler.profile("test_step"): - pass - - # log to stdout and print to file - pytorch_profiler.describe() - data = Path(pytorch_profiler.output_fname).read_text() - assert len(data) > 0 - - -def test_pytorch_profiler_value_errors(pytorch_profiler): - """Ensure errors are raised where expected.""" - - action = "test_step" - with pytest.raises(ValueError): - pytorch_profiler.stop(action) - - pytorch_profiler.start(action) - pytorch_profiler.stop(action) - - -@RunIf(min_gpus=2, special=True) -@pytest.mark.parametrize("use_output_filename", [False, True]) -def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename): - """Ensure that the profiler can be given to the training and default step are properly recorded. """ - - if use_output_filename: - output_filename = os.path.join(tmpdir, "profiler.txt") - else: - output_filename = None - - profiler = PyTorchProfiler(output_filename=output_filename) - - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - profiler=profiler, - accelerator="ddp", - gpus=2, - ) - trainer.fit(model) - - enabled = use_output_filename or not use_output_filename and profiler.local_rank == 0 - - if enabled: - assert len(profiler.summary()) > 0 - assert set(profiler.profiled_actions.keys()) == {'training_step_and_backward', 'validation_step'} - else: - assert profiler.summary() is None - assert set(profiler.profiled_actions.keys()) == set() - - if use_output_filename: - profiler.describe() - data = Path(profiler.output_fname).read_text() - assert len(data) > 0 - - -def test_pytorch_profiler_nested(tmpdir): - """Ensure that the profiler handles nested context""" - - pytorch_profiler = PyTorchProfiler( - profiled_functions=["a", "b", "c"], use_cuda=False, output_filename=os.path.join(tmpdir, "profiler.txt") - ) - - with pytorch_profiler.profile("a"): - a = torch.ones(42) - with pytorch_profiler.profile("b"): - b = torch.zeros(42) - with pytorch_profiler.profile("c"): - _ = a + b - - pa = pytorch_profiler.profiled_actions - - # From PyTorch 1.8.0, less operation are being traced. - if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"): - expected_ = { - 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'add'], - 'b': ['zeros', 'empty', 'zero_'], - 'c': ['add'], - } - # From PyTorch 1.6.0, more operation are being traced. - elif LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - expected_ = { - 'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'fill_', 'add', 'empty'], - 'b': ['zeros', 'empty', 'zero_', 'fill_'], - 'c': ['add', 'empty'], - } - else: - expected_ = { - 'a': ['add'], - 'b': [], - 'c': ['add'], - } - - for n in ('a', 'b', 'c'): - pa[n] = [e.name for e in pa[n]] - if LooseVersion(torch.__version__) >= LooseVersion("1.7.1"): - pa[n] = [e.replace("aten::", "") for e in pa[n]] - assert pa[n] == expected_[n] - - -@RunIf(min_gpus=1, special=True) -def test_pytorch_profiler_nested_emit_nvtx(tmpdir): - """ - This test check emit_nvtx is correctly supported - """ - profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) - - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - profiler=profiler, - gpus=1, - ) - trainer.fit(model) - - @pytest.mark.parametrize( ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"], [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)], @@ -1856,3 +1744,35 @@ def test_check_val_every_n_epoch_exception(tmpdir): max_epochs=1, check_val_every_n_epoch=1.2, ) + + +def test_trainer_attach_data_pipeline_to_model(tmpdir): + + class DataPipeline: + + pass + + class TestDataModule(LightningDataModule): + + data_pipeline = DataPipeline() + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + class TestCallback(Callback): + + def on_fit_start(self, trainer, pl_module: LightningModule) -> None: + """Called when fit begins""" + assert isinstance(pl_module.data_pipeline, DataPipeline) + + model = BoringModel() + dm = TestDataModule() + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()]) + trainer.fit(model, datamodule=dm) diff --git a/tests/utilities/test_argparse_utils.py b/tests/utilities/test_argparse.py similarity index 80% rename from tests/utilities/test_argparse_utils.py rename to tests/utilities/test_argparse.py index b2eac514941e6a..f13af4362364ca 100644 --- a/tests/utilities/test_argparse_utils.py +++ b/tests/utilities/test_argparse.py @@ -1,17 +1,52 @@ import io -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from typing import List +from unittest.mock import MagicMock import pytest from pytorch_lightning import Trainer from pytorch_lightning.utilities.argparse import ( + _gpus_arg_default, + _int_or_float_type, add_argparse_args, + from_argparse_args, get_abbrev_qualified_cls_name, + parse_argparser, parse_args_from_docstring, ) +class ArgparseExample: + + def __init__(self, a: int = 0, b: str = '', c: bool = False): + self.a = a + self.b = b + self.c = c + + +def test_from_argparse_args(): + args = Namespace(a=1, b='test', c=True, d='not valid') + my_instance = from_argparse_args(ArgparseExample, args) + assert my_instance.a == 1 + assert my_instance.b == 'test' + assert my_instance.c + + parser = ArgumentParser() + mock_trainer = MagicMock() + _ = from_argparse_args(mock_trainer, parser) + mock_trainer.parse_argparser.assert_called_once_with(parser) + + +def test_parse_argparser(): + args = Namespace(a=1, b='test', c=None, d='not valid') + new_args = parse_argparser(ArgparseExample, args) + assert new_args.a == 1 + assert new_args.b == 'test' + assert new_args.c + assert new_args.d == 'not valid' + + def test_parse_args_from_docstring_normal(): args_help = parse_args_from_docstring( """Constrain image dataset @@ -168,3 +203,13 @@ def test_add_argparse_args_no_argument_group(): args = parser.parse_args(fake_argv) assert args.main_arg == "abc" assert args.my_parameter == 2 + + +def test_gpus_arg_default(): + assert _gpus_arg_default('1,2') == '1,2' + assert _gpus_arg_default('1') == 1 + + +def test_int_or_float_type(): + assert isinstance(_int_or_float_type('0.0'), float) + assert isinstance(_int_or_float_type('0'), int)