diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index afd86ca98c213..c0b97439737ff 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -76,8 +76,7 @@ jobs: with: name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - # Use always() to always run this step to publish test results when there are test failures - if: always() + if: failure() - name: Statistics if: success() diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 7fc4de8ddbfd3..d64fedbfbe590 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -50,5 +50,4 @@ jobs: with: name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - # Use always() to always run this step to publish test results when there are test failures - if: always() + if: failure() diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 1a2115a40fcfd..b87a1d8557843 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -129,8 +129,7 @@ jobs: with: name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - # Use always() to always run this step to publish test results when there are test failures - if: always() + if: failure() - name: Statistics if: success() diff --git a/CHANGELOG.md b/CHANGELOG.md index 951aee6049b7a..f078349ef3665 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,47 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased.Features] - YYYY-MM-DD + +### Added + + +### Changed + + +### Deprecated + + +### Removed + + +### Fixed + + + +## [unreleased.BugFix] - YYYY-MM-DD + +### Added + + +### Changed + + +### Deprecated + + +### Removed + + +### Fixed + +- Fixed trainer by default `None` in `DDPAccelerator` ([#4915](https://github.com/PyTorchLightning/pytorch-lightning/pull/4915)) + + +- Fixed `LightningOptimizer` exposes optimizer attributes ([#5095](https://github.com/PyTorchLightning/pytorch-lightning/pull/5095)) + + + ## [1.1.0] - 2020-12-09 ### Added @@ -44,9 +85,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Changed -- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549)) - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903)) -- WandbLogger does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648)) +- `WandbLogger` does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648)) - Changed `automatic_optimization` to be a model attribute ([#4602](https://github.com/PyTorchLightning/pytorch-lightning/pull/4602)) - Changed `Simple Profiler` report to order by percentage time spent + num calls ([#4880](https://github.com/PyTorchLightning/pytorch-lightning/pull/4880)) - Simplify optimization Logic ([#4984](https://github.com/PyTorchLightning/pytorch-lightning/pull/4984)) @@ -64,6 +104,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed - Removed `reorder` parameter of the `auc` metric ([#5004](https://github.com/PyTorchLightning/pytorch-lightning/pull/5004)) +- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549)) ### Fixed diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst index d6d082e2ed779..d4cf578e10bda 100644 --- a/docs/source/introduction_guide.rst +++ b/docs/source/introduction_guide.rst @@ -601,8 +601,8 @@ In this method we do all the preparation we need to do once (instead of on every def setup(self, stage): # transform transform=transforms.Compose([transforms.ToTensor()]) - MNIST(os.getcwd(), train=True, download=False, transform=transform) - MNIST(os.getcwd(), train=False, download=False, transform=transform) + mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform) + mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform) # train/val split mnist_train, mnist_val = random_split(mnist_train, [55000, 5000]) diff --git a/docs/source/multi_gpu.rst b/docs/source/multi_gpu.rst index def47810504d6..b3e0b905f27f4 100644 --- a/docs/source/multi_gpu.rst +++ b/docs/source/multi_gpu.rst @@ -663,7 +663,7 @@ It is highly recommended to use Sharded Training in multi-GPU environments where A technical note: as batch size scales, storing activations for the backwards pass becomes the bottleneck in training. As a result, sharding optimizer state and gradients becomes less impactful. Work within the future will bring optional sharding to activations and model parameters to reduce memory further, but come with a speed cost. -To use Sharded Training, you need to first install FairScale using the command below or install all extras using ``pip install pytorch-lightning["extra"]``. +To use Sharded Training, you need to first install FairScale using the command below. .. code-block:: bash diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst index 6ca72b8069d6d..06e6e9679d29f 100644 --- a/docs/source/optimizers.rst +++ b/docs/source/optimizers.rst @@ -191,37 +191,48 @@ override the :meth:`optimizer_step` function. For example, here step optimizer A every 2 batches and optimizer B every 4 batches -.. testcode:: +.. 
note:: When using ``Trainer(enable_pl_optimizer=True)``, there is no need to call ``.zero_grad()``. - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False): - optimizer.step() +.. testcode:: def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx): optimizer.zero_grad() # Alternating schedule for optimizer steps (ie: GANs) - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False): + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): # update generator opt every 2 steps if optimizer_i == 0: if batch_nb % 2 == 0 : - optimizer.step() - optimizer.zero_grad() + optimizer.step(closure=closure) # update discriminator opt every 4 steps if optimizer_i == 1: if batch_nb % 4 == 0 : - optimizer.step() - optimizer.zero_grad() + optimizer.step(closure=closure) + +.. note:: When using ``Trainer(enable_pl_optimizer=True)``, ``.step`` accepts a boolean ``make_optimizer_step`` which can be used as follows. + +.. testcode:: + + def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx): + optimizer.zero_grad() + + # Alternating schedule for optimizer steps (ie: GANs) + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): + # update generator opt every 2 steps + if optimizer_idx == 0: + optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 2) == 0) - # ... - # add as many optimizers as you want + # update discriminator opt every 4 steps + if optimizer_idx == 1: + optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 4) == 0) Here we add a learning-rate warm up .. testcode:: # learning rate warm-up - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False): + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): # warm up lr if self.trainer.global_step < 500: lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) @@ -229,8 +240,20 @@ Here we add a learning-rate warm up pg['lr'] = lr_scale * self.hparams.learning_rate # update params - optimizer.step() - optimizer.zero_grad() + optimizer.step(closure=closure) + +The default ``optimizer_step`` relies on the internal ``LightningOptimizer`` to properly perform a step. + +.. testcode:: + + from pytorch_lightning.core.optimizer import LightningOptimizer + + # function hook in LightningModule + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): + if not isinstance(optimizer, LightningOptimizer): + # wrap into LightningOptimizer only for running the step + optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer) + optimizer.step(closure=closure) ---------- diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index d6c5139cb3799..408d95a72dc47 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -1,6 +1,6 @@ """Root package info.""" -__version__ = '1.1.0' +__version__ = '1.1.1rc0' __author__ = 'William Falcon et al.'
__author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0' diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 57979b73f2cb6..f24a4ce8beb8a 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -14,7 +14,7 @@ """Various hooks to be used in the Lightning code.""" -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn @@ -501,7 +501,7 @@ def val_dataloader(self): will have an argument ``dataloader_idx`` which matches the order here. """ - def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: + def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: """ Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom data structure. @@ -549,6 +549,7 @@ def transfer_batch_to_device(self, batch, device) - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection` """ + device = device or self.device return move_data_to_device(batch, device) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f29e7f75bfbff..358b24fe1f40c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -22,6 +22,7 @@ import tempfile from abc import ABC from argparse import Namespace +from pathlib import Path from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch @@ -1170,7 +1171,6 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): def optimizer_step( self, - *args, epoch: int = None, batch_idx: int = None, optimizer: Optimizer = None, @@ -1179,7 +1179,6 @@ def optimizer_step( on_tpu: bool = None, using_native_amp: bool = None, using_lbfgs: bool = None, - **kwargs, ) -> None: r""" Override this method to adjust the default way the @@ -1254,7 +1253,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, if not isinstance(optimizer, LightningOptimizer): # wraps into LightingOptimizer only for running step optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer) - optimizer.step(closure=optimizer_closure, *args, **kwargs) + optimizer.step(closure=optimizer_closure) def optimizer_zero_grad( self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int @@ -1532,12 +1531,19 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: else: self._hparams = hp - def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs): - """Saves the model in ONNX format + @torch.no_grad() + def to_onnx( + self, + file_path: Union[str, Path], + input_sample: Optional[Any] = None, + **kwargs, + ): + """ + Saves the model in ONNX format Args: - file_path: The path of the file the model should be saved to. - input_sample: A sample of an input tensor for tracing. + file_path: The path of the file the onnx model should be saved to. + input_sample: An input for tracing. Default: None (Use self.example_input_array) **kwargs: Will be passed to torch.onnx.export function. Example: @@ -1556,31 +1562,32 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg ... 
os.path.isfile(tmpfile.name) True """ + mode = self.training - if isinstance(input_sample, Tensor): - input_data = input_sample - elif self.example_input_array is not None: - input_data = self.example_input_array - else: - if input_sample is not None: + if input_sample is None: + if self.example_input_array is None: raise ValueError( - f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`" + "Could not export to ONNX since neither `input_sample` nor" + " `model.example_input_array` attribute is set." ) - raise ValueError( - "Could not export to ONNX since neither `input_sample` nor" - " `model.example_input_array` attribute is set." - ) - input_data = input_data.to(self.device) + input_sample = self.example_input_array + + input_sample = self.transfer_batch_to_device(input_sample) + if "example_outputs" not in kwargs: self.eval() - with torch.no_grad(): - kwargs["example_outputs"] = self(input_data) + kwargs["example_outputs"] = self(input_sample) - torch.onnx.export(self, input_data, file_path, **kwargs) + torch.onnx.export(self, input_sample, file_path, **kwargs) + self.train(mode) + @torch.no_grad() def to_torchscript( - self, file_path: Optional[str] = None, method: Optional[str] = 'script', - example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs + self, + file_path: Optional[Union[str, Path]] = None, + method: Optional[str] = 'script', + example_inputs: Optional[Any] = None, + **kwargs, ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. @@ -1592,7 +1599,7 @@ def to_torchscript( Args: file_path: Path where to save the torchscript. Default: None (no file saved). method: Whether to use TorchScript's script or trace method. Default: 'script' - example_inputs: Tensor to be used to do tracing when method is set to 'trace'. + example_inputs: An input to be used to do tracing when method is set to 'trace'. Default: None (Use self.example_input_array) **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or :func:`torch.jit.trace` function. @@ -1626,21 +1633,27 @@ def to_torchscript( This LightningModule as a torchscript, regardless of whether file_path is defined or not. 
""" - mode = self.training - with torch.no_grad(): - if method == 'script': - torchscript_module = torch.jit.script(self.eval(), **kwargs) - elif method == 'trace': - # if no example inputs are provided, try to see if model has example_input_array set - if example_inputs is None: - example_inputs = self.example_input_array - # automatically send example inputs to the right device and use trace - example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device) - torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) - else: - raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:" - f"{method}") + + if method == 'script': + torchscript_module = torch.jit.script(self.eval(), **kwargs) + elif method == 'trace': + # if no example inputs are provided, try to see if model has example_input_array set + if example_inputs is None: + if self.example_input_array is None: + raise ValueError( + 'Choosing method=`trace` requires either `example_inputs`' + ' or `model.example_input_array` to be defined' + ) + example_inputs = self.example_input_array + + # automatically send example inputs to the right device and use trace + example_inputs = self.transfer_batch_to_device(example_inputs) + torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) + else: + raise ValueError("The 'method' parameter only supports 'script' or 'trace'," + f" but value given was: {method}") + self.train(mode) if file_path is not None: diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index dc63231ba6ccb..c8e9ff8b80a2f 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -57,12 +57,35 @@ def __init__(self, else: self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) - self._trainer = None self._optimizer = optimizer + self._trainer = None self._accumulate_grad_batches = accumulate_grad_batches - self._automatic_optimization = None self._optimizer_idx = None + @property + def defaults(self): + return self._optimizer.defaults + + @defaults.setter + def defaults(self, defaults): + self._optimizer.defaults = defaults + + @property + def state(self): + return self._optimizer.state + + @state.setter + def state(self, state): + self._optimizer.state = state + + @property + def param_groups(self): + return self._optimizer.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self._optimizer.param_groups = param_groups + @property def accumulate_grad_batches(self): return self._accumulate_grad_batches @@ -73,7 +96,6 @@ def accumulate_grad_batches(self, accumulate_grad_batches): def _on_trainer_init(self, trainer): self._trainer = proxy(trainer) - self._automatic_optimization = trainer.train_loop.automatic_optimization for opt_idx, opt in enumerate(trainer.optimizers): if opt == self._optimizer: self._optimizer_idx = opt_idx diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index 13cb705f30b17..b4cbb6b073efe 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -14,7 +14,7 @@ from pytorch_lightning.metrics.classification.accuracy import Accuracy from pytorch_lightning.metrics.classification.average_precision import AveragePrecision from pytorch_lightning.metrics.classification.confusion_matrix import 
ConfusionMatrix -from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 +from pytorch_lightning.metrics.classification.f_beta import FBeta, Fbeta, F1 from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve from pytorch_lightning.metrics.classification.roc import ROC diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index 0a8a952470dbc..33878cb48965d 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -92,9 +92,8 @@ def __init__( self.add_state("target", default=[], dist_reduce_fx=None) rank_zero_warn( - 'Metric `AveragePrecision` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' + 'Metric `AveragePrecision` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' ) def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index 56cc00f9a5dce..fadfd000ebbe1 100755 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -20,6 +20,7 @@ _fbeta_compute ) from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn class FBeta(Metric): @@ -51,11 +52,11 @@ class FBeta(Metric): Threshold value for binary or multi-label logits. default: 0.5 average: - * `'micro'` computes metric globally - * `'macro'` computes metric for each class and uniformly averages them - * `'weighted'` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - * `None` computes and returns the metric per class + - ``'micro'`` computes metric globally + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``'none'`` computes and returns the metric per class multilabel: If predictions are from multilabel classification. compute_on_step: @@ -131,6 +132,34 @@ def compute(self) -> torch.Tensor: self.actual_positives, self.beta, self.average) +# todo: remove in v1.2 +class Fbeta(FBeta): + r""" + Computes `F-score `_ + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.classification.f_beta.FBeta` + """ + def __init__( + self, + num_classes: int, + beta: float = 1.0, + threshold: float = 0.5, + average: str = "micro", + multilabel: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + rank_zero_warn( + "This `Fbeta` was deprecated in v1.0.x in favor of" + " `from pytorch_lightning.metrics.classification.f_beta import FBeta`." + " It will be removed in v1.2.0", DeprecationWarning + ) + super().__init__( + num_classes, beta, threshold, average, multilabel, compute_on_step, dist_sync_on_step, process_group + ) + + class F1(FBeta): """ Computes F1 metric. F1 metrics correspond to a harmonic mean of the @@ -156,11 +185,11 @@ class F1(FBeta): Threshold value for binary or multi-label logits. 
default: 0.5 average: - * `'micro'` computes metric globally - * `'macro'` computes metric for each class and uniformly averages them - * `'weighted'` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - * `None` computes and returns the metric per class + - ``'micro'`` computes metric globally + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``'none'`` computes and returns the metric per class multilabel: If predictions are from multilabel classification. compute_on_step: @@ -183,7 +212,6 @@ class F1(FBeta): def __init__( self, num_classes: int = 1, - beta: float = 1.0, threshold: float = 0.5, average: str = "micro", multilabel: bool = False, diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index 052a25a7a977d..620904898535d 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -102,9 +102,8 @@ def __init__( self.add_state("target", default=[], dist_reduce_fx=None) rank_zero_warn( - 'Metric `PrecisionRecallCurve` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' + 'Metric `PrecisionRecallCurve` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' ) def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 89e8265b19fc1..2b7d82488b491 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -105,9 +105,8 @@ def __init__( self.add_state("target", default=[], dist_reduce_fx=None) rank_zero_warn( - 'Metric `ROC` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' + 'Metric `ROC` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' 
) def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index e13242e40b0ac..e38ab5f415c32 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -17,13 +17,18 @@ auc, auroc, dice_score, + f1_score, + fbeta_score, + get_num_classes, + iou, multiclass_auroc, precision, precision_recall, recall, stat_scores, stat_scores_multiple_classes, - iou, + to_categorical, + to_onehot, ) from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # TODO: unify metrics between class and functional, add below diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 1c43ec75bb508..e1ba601b51553 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -15,12 +15,75 @@ from typing import Callable, Optional, Sequence, Tuple import torch -from torch.nn import functional as F -from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce +from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap +from pytorch_lightning.metrics.functional.f_beta import fbeta as __fb, f1 as __f1 +from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve, precision_recall_curve as __prc +from pytorch_lightning.metrics.functional.roc import roc as __roc +from pytorch_lightning.metrics.utils import ( + to_categorical as __tc, + to_onehot as __to, + get_num_classes as __gnc, + reduce, + class_reduce, +) from pytorch_lightning.utilities import rank_zero_warn +def to_onehot( + tensor: torch.Tensor, + num_classes: Optional[int] = None, +) -> torch.Tensor: + """ + Converts a dense label tensor to one-hot format + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_onehot` + """ + rank_zero_warn( + "This `to_onehot` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import to_onehot`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __to(tensor, num_classes) + + +def to_categorical( + tensor: torch.Tensor, + argmax_dim: int = 1 +) -> torch.Tensor: + """ + Converts a tensor of probabilities to a dense label tensor + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_categorical` + + """ + rank_zero_warn( + "This `to_categorical` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import to_categorical`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __tc(tensor, argmax_dim) + + +def get_num_classes( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, +) -> int: + """ + Calculates the number of classes for a given prediction and target tensor. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.get_num_classes` + + """ + rank_zero_warn( + "This `get_num_classes` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import get_num_classes`."
+ " It will be removed in v1.3.0", DeprecationWarning + ) + return __gnc(pred, target, num_classes) + + def stat_scores( pred: torch.Tensor, target: torch.Tensor, @@ -332,52 +395,28 @@ def recall( num_classes=num_classes, class_reduction=class_reduction)[1] -def _binary_clf_curve( +# todo: remove in 1.3 +def roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, pos_label: int = 1., ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py - """ - if sample_weight is not None and not isinstance(sample_weight, torch.Tensor): - sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float) - - # remove class dimension if necessary - if pred.ndim > target.ndim: - pred = pred[:, 0] - desc_score_indices = torch.argsort(pred, descending=True) - - pred = pred[desc_score_indices] - target = target[desc_score_indices] - - if sample_weight is not None: - weight = sample_weight[desc_score_indices] - else: - weight = 1. - - # pred typically has many tied values. Here we extract - # the indices associated with the distinct values. We also - # concatenate a value for the end of the curve. - distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0] - threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) - - target = (target == pos_label).to(torch.long) - tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] - - if sample_weight is not None: - # express fps as a cumsum to ensure fps is increasing even in - # the presence of floating point errors - fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] - else: - fps = 1 + threshold_idxs - tps + Computes the Receiver Operating Characteristic (ROC). It assumes the classifier is binary. - return fps, tps, pred[threshold_idxs] + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` + """ + rank_zero_warn( + "This `roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py -def __roc( +def _roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, @@ -386,22 +425,13 @@ def __roc( """ Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. - .. warning:: Deprecated - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - false-positive rate (fpr), true-positive rate (tpr), thresholds + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` Example: >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 1, 1]) - >>> fpr, tpr, thresholds = __roc(x, y) + >>> fpr, tpr, thresholds = _roc(x, y) >>> fpr tensor([0., 0., 0., 0., 1.]) >>> tpr tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) >>> thresholds tensor([4, 3, 2, 1, 0]) """ - fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) + rank_zero_warn( + "This `_roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`."
+ " It will be removed in v1.3.0", DeprecationWarning + ) + fps, tps, thresholds = _binary_clf_curve(pred, target, sample_weights=sample_weight, pos_label=pos_label) # Add an extra threshold position # to make sure that the curve starts at (0, 0) @@ -434,7 +467,7 @@ def __roc( # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py -def __multiclass_roc( +def multiclass_roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, @@ -443,7 +476,7 @@ def __multiclass_roc( """ Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. - .. warning:: Deprecated + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` Args: pred: estimated probabilities @@ -462,19 +495,24 @@ def __multiclass_roc( ... [0.05, 0.05, 0.85, 0.05], ... [0.05, 0.05, 0.05, 0.85]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> __multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE + >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500]))) """ + rank_zero_warn( + "This `multiclass_roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`." + " It will be removed in v1.3.0", DeprecationWarning + ) num_classes = get_num_classes(pred, target, num_classes) class_roc_vals = [] for c in range(num_classes): pred_c = pred[:, c] - class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) + class_roc_vals.append(_roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) return tuple(class_roc_vals) @@ -572,7 +610,7 @@ def auroc( @auc_decorator() def _auroc(pred, target, sample_weight, pos_label): - return __roc(pred, target, sample_weight, pos_label) + return _roc(pred, target, sample_weight, pos_label) return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) @@ -625,7 +663,7 @@ def multiclass_auroc( @multiclass_auc_decorator() def _multiclass_auroc(pred, target, sample_weight, num_classes): - return __multiclass_roc(pred, target, sample_weight, num_classes) + return multiclass_roc(pred, target, sample_weight, num_classes) class_aurocs = _multiclass_auroc(pred=pred, target=target, sample_weight=sample_weight, @@ -772,3 +810,110 @@ def iou( ]) return reduce(scores, reduction=reduction) + + +# todo: remove in 1.3 +def precision_recall_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +): + """ + Computes precision-recall pairs for different thresholds. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` + """ + rank_zero_warn( + "This `precision_recall_curve` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`." 
+ " It will be removed in v1.3.0", DeprecationWarning + ) + return __prc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) + + +# todo: remove in 1.3 +def multiclass_precision_recall_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, +): + """ + Computes precision-recall pairs for different thresholds given multiclass scores. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` + """ + rank_zero_warn( + "This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`." + " It will be removed in v1.3.0", DeprecationWarning + ) + if num_classes is None: + num_classes = get_num_classes(pred, target, num_classes) + return __prc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes) + + +# todo: remove in 1.3 +def average_precision( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +): + """ + Compute average precision from prediction scores. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision` + """ + rank_zero_warn( + "This `average_precision` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.average_precision import average_precision`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) + + +# todo: remove in 1.2 +def fbeta_score( + pred: torch.Tensor, + target: torch.Tensor, + beta: float, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', +) -> torch.Tensor: + """ + Computes the F-beta score which is a weighted harmonic mean of precision and recall. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.f_beta.fbeta` + """ + rank_zero_warn( + "This `fbeta_score` was deprecated in v1.0.x in favor of" + " `from pytorch_lightning.metrics.functional.f_beta import fbeta`." + " It will be removed in v1.2.0", DeprecationWarning + ) + if num_classes is None: + num_classes = get_num_classes(pred, target) + return __fb(preds=pred, target=target, beta=beta, num_classes=num_classes, average=class_reduction) + + +# todo: remove in 1.2 +def f1_score( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', +) -> torch.Tensor: + """ + Computes the F1-score (a.k.a. F-measure), which is the harmonic mean of the precision and recall. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.f_beta.f1` + """ + rank_zero_warn( + "This `f1_score` was deprecated in v1.0.x in favor of" + " `from pytorch_lightning.metrics.functional.f_beta import f1`."
+ " It will be removed in v1.2.0", DeprecationWarning + ) + if num_classes is None: + num_classes = get_num_classes(pred, target) + return __f1(preds=pred, target=target, num_classes=num_classes, average=class_reduction) diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index 012e1486ebb1f..20b38c58a2a6b 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -23,10 +23,11 @@ def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tup return preds, target -def _explained_variance_compute(preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', - ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: +def _explained_variance_compute( + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = 'uniform_average', +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: diff_avg = torch.mean(target - preds, dim=0) numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0) @@ -52,10 +53,11 @@ def _explained_variance_compute(preds: torch.Tensor, return torch.sum(denominator / denom_sum * output_scores) -def explained_variance(preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', - ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: +def explained_variance( + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = 'uniform_average', +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: """ Computes explained variance. diff --git a/pytorch_lightning/metrics/functional/f_beta.py b/pytorch_lightning/metrics/functional/f_beta.py index 3f0a7a0449325..2b0ba194d56f0 100755 --- a/pytorch_lightning/metrics/functional/f_beta.py +++ b/pytorch_lightning/metrics/functional/f_beta.py @@ -83,11 +83,11 @@ def fbeta( Threshold value for binary or multi-label logits. default: 0.5 average: - * `'micro'` computes metric globally - * `'macro'` computes metric for each class and uniformly averages them - * `'weighted'` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - * `None` computes and returns the metric per class + - ``'micro'`` computes metric globally + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``'none'`` computes and returns the metric per class multilabel: If predictions are from multilabel classification. @@ -110,7 +110,6 @@ def f1( preds: torch.Tensor, target: torch.Tensor, num_classes: int, - beta: float = 1.0, threshold: float = 0.5, average: str = "micro", multilabel: bool = False @@ -136,11 +135,11 @@ def f1( Threshold value for binary or multi-label logits. 
default: 0.5 average: - * `'micro'` computes metric globally - * `'macro'` computes metric for each class and uniformly averages them - * `'weighted'` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - * `None` computes and returns the metric per class + - ``'micro'`` computes metric globally + - ``'macro'`` computes metric for each class and uniformly averages them + - ``'weighted'`` computes metric for each class and does a weighted-average, + where each class is weighted by their support (accounts for class imbalance) + - ``'none'`` computes and returns the metric per class multilabel: If predictions are from multilabel classification. diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py new file mode 100644 index 0000000000000..c116b16d363a9 --- /dev/null +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning.metrics.utils import reduce as __reduce, class_reduce as __cr +from pytorch_lightning.utilities import rank_zero_warn + + +def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: + rank_zero_warn( + "This `reduce` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import reduce`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __reduce(to_reduce=to_reduce, reduction=reduction) + + +def class_reduce(num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = 'none'): + rank_zero_warn( + "This `class_reduce` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import class_reduce`."
+ " It will be removed in v1.3.0", DeprecationWarning + ) + return __cr(num=num, denom=denom, weights=weights, class_reduction=class_reduction) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 20dfb0f4b380f..68a0f4781c9a9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -477,7 +477,7 @@ def _process_result(self, training_step_output, split_batch): return training_step_output_for_epoch_end - def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure, *args, **kwargs): + def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): model_ref = self.trainer.get_model() is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) @@ -491,16 +491,14 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_ # model hook model_ref.optimizer_step( - epoch=self.trainer.current_epoch, - batch_idx=batch_idx, - optimizer=optimizer, - optimizer_idx=opt_idx, - optimizer_closure=train_step_and_backward_closure, + self.trainer.current_epoch, + batch_idx, + optimizer, + opt_idx, + train_step_and_backward_closure, on_tpu=self.trainer.use_tpu and TPU_AVAILABLE, using_native_amp=using_native_amp, using_lbfgs=is_lbfgs, - *args, - **kwargs, ) def on_before_zero_grad(self, optimizer): diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index a11862b4003bc..e5641337cc8d2 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -54,6 +54,7 @@ def _module_available(module_path: str) -> bool: OMEGACONF_AVAILABLE = _module_available("omegaconf") HYDRA_AVAILABLE = _module_available("hydra") HOROVOD_AVAILABLE = _module_available("horovod.torch") +BOLTS_AVAILABLE = _module_available("pl_bolts") TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') diff --git a/requirements/extra.txt b/requirements/extra.txt index ad54358269bd1..3f14b1e5910dd 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,4 @@ torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip \ No newline at end of file +https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 3e2e6d040f44c..e3a597063d02e 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -55,6 +55,11 @@ def test_automatic_optimization_num_calls(enable_pl_optimizer, tmpdir): class TestModel(BoringModel): + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + def configure_optimizers(self): optimizer = SGD(self.layer.parameters(), lr=0.1) optimizer_2 = Adam(self.layer.parameters(), lr=0.1) @@ -98,3 +103,47 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, assert sgd_zero_grad.call_count == 4 assert adam_step.call_count == 2 assert adam_zero_grad.call_count == 2 + + +@pytest.mark.parametrize("enable_pl_optimizer", [False, True]) +def test_params_groups_and_state_are_accessible(enable_pl_optimizer, tmpdir): + + with patch("torch.optim.SGD.step") as sgd_step, \ + patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \ + 
patch("torch.optim.Adam.step") as adam_step, \ + patch("torch.optim.Adam.zero_grad") as adam_zero_grad: + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def configure_optimizers(self): + optimizer = SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = Adam(self.layer.parameters(), lr=0.1) + return [optimizer, optimizer_2] + + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, + on_tpu=False, using_native_amp=False, using_lbfgs=False): + # warm up lr + if self.trainer.global_step < 500: + lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) + for pg in optimizer.param_groups: + pg['lr'] = lr_scale * 0.01 + + optimizer.step(closure=closure) + + model = TestModel() + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=8, + accumulate_grad_batches=1, + enable_pl_optimizer=enable_pl_optimizer + ) + + trainer.fit(model) diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index e6ec59ec4f5aa..a9fcf918cc699 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -19,10 +19,12 @@ import torch.nn as nn from torch.optim import Adam, Optimizer +import pytorch_lightning as pl from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset +from tests.base.boring_model import BoringModel, RandomDataset, RandomDictDataset, RandomDictStringDataset def test_lightning_optimizer(tmpdir): @@ -80,8 +82,8 @@ def configure_optimizers(self): assert trainer.optimizers[0].__repr__() == expected -@patch("torch.optim.Adam.step") -@patch("torch.optim.SGD.step") +@patch("torch.optim.Adam.step", autospec=True) +@patch("torch.optim.SGD.step", autospec=True) def test_lightning_optimizer_manual_optimization(mock_sgd_step, mock_adam_step, tmpdir): """ Test that the user can use our LightningOptimizer. Not recommended for now. @@ -96,13 +98,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): output = self.layer(batch) loss_1 = self.loss(batch, output) self.manual_backward(loss_1, opt_1) - opt_1.step(idx="1") + opt_1.step() def closure(): output = self.layer(batch) loss_2 = self.loss(batch, output) self.manual_backward(loss_2, opt_2) - opt_2.step(closure=closure, idx="2") + opt_2.step(closure=closure) def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -133,8 +135,8 @@ def automatic_optimization(self) -> bool: assert len(mock_adam_step.mock_calls) == 8 -@patch("torch.optim.Adam.step") -@patch("torch.optim.SGD.step") +@patch("torch.optim.Adam.step", autospec=True) +@patch("torch.optim.SGD.step", autospec=True) def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(mock_sgd_step, mock_adam_step, tmpdir): """ Test that the user can use our LightningOptimizer. Not recommended. 
@@ -149,13 +151,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): output = self.layer(batch) loss_1 = self.loss(batch, output) self.manual_backward(loss_1, opt_1) - opt_1.step(idx="1") + opt_1.step() def closure(): output = self.layer(batch) loss_2 = self.loss(batch, output) self.manual_backward(loss_2, opt_2) - opt_2.step(closure=closure, idx="2") + opt_2.step(closure=closure) def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -191,13 +193,29 @@ def test_state(tmpdir): model = torch.nn.Linear(3, 4) optimizer = torch.optim.Adam(model.parameters()) lightning_optimizer = LightningOptimizer(optimizer) + + # test state + assert optimizer.state == lightning_optimizer.state + lightning_optimizer.state = optimizer.state + assert optimizer.state == lightning_optimizer.state + + # test param_groups + assert optimizer.param_groups == lightning_optimizer.param_groups + lightning_optimizer.param_groups = optimizer.param_groups + assert optimizer.param_groups == lightning_optimizer.param_groups + + # test defaults + assert optimizer.defaults == lightning_optimizer.defaults + lightning_optimizer.defaults = optimizer.defaults + assert optimizer.defaults == lightning_optimizer.defaults + assert isinstance(lightning_optimizer, LightningOptimizer) assert isinstance(lightning_optimizer, Adam) assert isinstance(lightning_optimizer, Optimizer) lightning_dict = {} - special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", - "_trainer", "_use_accumulate_grad_batches_from_trainer", "_automatic_optimization", - "_accumulate_grad_batches"] + special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", "_support_closure", + "_trainer", "__getstate__", "__setstate__", "state_dict", "load_state_dict", + "zero_grad", "add_param_group"] for k, v in lightning_optimizer.__dict__.items(): if k not in special_attrs: lightning_dict[k] = v diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py new file mode 100644 index 0000000000000..99e21d1ed6b22 --- /dev/null +++ b/tests/deprecated_api/__init__.py @@ -0,0 +1,21 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in vX.Y.Z""" +import sys + + +def _soft_unimport_module(str_module): + # once the module is imported, e.g. when parsing with pytest, it lives in memory + if str_module in sys.modules: + del sys.modules[str_module] diff --git a/tests/deprecated_api/test_remove_1-2.py b/tests/deprecated_api/test_remove_1-2.py new file mode 100644 index 0000000000000..331208d56df10 --- /dev/null +++ b/tests/deprecated_api/test_remove_1-2.py @@ -0,0 +1,45 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in vX.Y.Z""" + +import pytest +import torch + +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +def test_tbd_remove_in_v1_2_0(): + with pytest.deprecated_call(match='will be removed in v1.2'): + ModelCheckpoint(filepath='..') + + with pytest.deprecated_call(match='will be removed in v1.2'): + ModelCheckpoint('..') + + with pytest.raises(MisconfigurationException, match='inputs which are not feasible'): + ModelCheckpoint(filepath='..', dirpath='.') + + +def test_tbd_remove_in_v1_2_0_metrics(): + from pytorch_lightning.metrics.classification import Fbeta + from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score + + with pytest.deprecated_call(match='will be removed in v1.2'): + Fbeta(2) + + with pytest.deprecated_call(match='will be removed in v1.2'): + fbeta_score(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 1]), 0.2) + + with pytest.deprecated_call(match='will be removed in v1.2'): + f1_score(torch.tensor([0, 1, 0, 1]), torch.tensor([0, 1, 0, 0])) diff --git a/tests/test_deprecated.py b/tests/deprecated_api/test_remove_1-3.py similarity index 52% rename from tests/test_deprecated.py rename to tests/deprecated_api/test_remove_1-3.py index f549de1f4d71e..7ec69796b1e46 100644 --- a/tests/test_deprecated.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Test deprecated functionality which will be removed in vX.Y.Z""" -import sys from argparse import ArgumentParser from unittest import mock @@ -21,10 +20,8 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.metrics.functional.classification import auc from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate def test_tbd_remove_in_v1_3_0(tmpdir): @@ -51,15 +48,61 @@ def __init__(self, hparams): DeprecatedHparamsModel({}) -def test_tbd_remove_in_v1_2_0(): - with pytest.deprecated_call(match='will be removed in v1.2'): - checkpoint_cb = ModelCheckpoint(filepath='.') +def test_tbd_remove_in_v1_3_0_metrics(): + from pytorch_lightning.metrics.functional.classification import to_onehot + with pytest.deprecated_call(match='will be removed in v1.3'): + to_onehot(torch.tensor([1, 2, 3])) + + from pytorch_lightning.metrics.functional.classification import to_categorical + with pytest.deprecated_call(match='will be removed in v1.3'): + to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]])) + + from pytorch_lightning.metrics.functional.classification import get_num_classes + with pytest.deprecated_call(match='will be removed in v1.3'): + get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1])) - with pytest.deprecated_call(match='will be removed in v1.2'): - checkpoint_cb = ModelCheckpoint('.') + x_binary = torch.tensor([0, 1, 2, 3]) + y_binary = torch.tensor([0, 1, 2, 3]) - with pytest.raises(MisconfigurationException, match='inputs which are not feasible'): - checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.') + from pytorch_lightning.metrics.functional.classification import roc + with pytest.deprecated_call(match='will be removed in v1.3'): + roc(pred=x_binary, target=y_binary) + + from pytorch_lightning.metrics.functional.classification import _roc + with pytest.deprecated_call(match='will be removed in v1.3'): + _roc(pred=x_binary, target=y_binary) + + x_multy = torch.tensor([[0.85, 0.05, 0.05, 0.05], + [0.05, 0.85, 0.05, 0.05], + [0.05, 0.05, 0.85, 0.05], + [0.05, 0.05, 0.05, 0.85]]) + y_multy = torch.tensor([0, 1, 3, 2]) + + from pytorch_lightning.metrics.functional.classification import multiclass_roc + with pytest.deprecated_call(match='will be removed in v1.3'): + multiclass_roc(pred=x_multy, target=y_multy) + + from pytorch_lightning.metrics.functional.classification import average_precision + with pytest.deprecated_call(match='will be removed in v1.3'): + average_precision(pred=x_binary, target=y_binary) + + from pytorch_lightning.metrics.functional.classification import precision_recall_curve + with pytest.deprecated_call(match='will be removed in v1.3'): + precision_recall_curve(pred=x_binary, target=y_binary) + + from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve + with pytest.deprecated_call(match='will be removed in v1.3'): + multiclass_precision_recall_curve(pred=x_multy, target=y_multy) + + from pytorch_lightning.metrics.functional.reduction import reduce + with pytest.deprecated_call(match='will be removed in v1.3'): + reduce(torch.tensor([0, 1, 1, 0]), 'sum') + + from pytorch_lightning.metrics.functional.reduction import class_reduce + with pytest.deprecated_call(match='will be removed in v1.3'): + class_reduce(torch.randint(1, 10, (50,)).float(), + torch.randint(10, 
@@ -68,6 +111,7 @@ def test_tbd_remove_in_v1_2_0():
     (False, PassThroughProfiler),
 ])
 def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
+    # remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
     with pytest.deprecated_call(match='will be removed in v1.3'):
         trainer = Trainer(profiler=profiler)
         assert isinstance(trainer.profiler, expected)
@@ -91,47 +135,3 @@ def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, ex
     assert getattr(args, "profiler") == expected_parsed_arg
     trainer = Trainer.from_argparse_args(args)
     assert isinstance(trainer.profiler, expected_profiler)
-
-
-def _soft_unimport_module(str_module):
-    # once the module is imported e.g with parsing with pytest it lives in memory
-    if str_module in sys.modules:
-        del sys.modules[str_module]
-
-
-class ModelVer0_6(EvalModelTemplate):
-
-    # todo: this shall not be needed while evaluate asks for dataloader explicitly
-    def val_dataloader(self):
-        return self.dataloader(train=False)
-
-    def validation_step(self, batch, batch_idx, *args, **kwargs):
-        return {'val_loss': torch.tensor(0.6)}
-
-    def validation_end(self, outputs):
-        return {'val_loss': torch.tensor(0.6)}
-
-    def test_dataloader(self):
-        return self.dataloader(train=False)
-
-    def test_end(self, outputs):
-        return {'test_loss': torch.tensor(0.6)}
-
-
-class ModelVer0_7(EvalModelTemplate):
-
-    # todo: this shall not be needed while evaluate asks for dataloader explicitly
-    def val_dataloader(self):
-        return self.dataloader(train=False)
-
-    def validation_step(self, batch, batch_idx, *args, **kwargs):
-        return {'val_loss': torch.tensor(0.7)}
-
-    def validation_end(self, outputs):
-        return {'val_loss': torch.tensor(0.7)}
-
-    def test_dataloader(self):
-        return self.dataloader(train=False)
-
-    def test_end(self, outputs):
-        return {'test_loss': torch.tensor(0.7)}
diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py
index f7bd7d558f5b4..a6fbe9e849785 100644
--- a/tests/metrics/functional/test_classification.py
+++ b/tests/metrics/functional/test_classification.py
@@ -17,13 +17,13 @@
     accuracy,
     precision,
     recall,
-    _binary_clf_curve,
     dice_score,
     auroc,
     multiclass_auroc,
     auc,
     iou,
 )
+from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
 from pytorch_lightning.metrics.utils import to_onehot, get_num_classes, to_categorical
 
 
@@ -222,7 +222,7 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
     if sample_weight is not None:
         sample_weight = torch.ones_like(pred) * sample_weight
 
-    fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
+    fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
 
     assert isinstance(tps, torch.Tensor)
     assert isinstance(fps, torch.Tensor)
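
For background on the `_binary_clf_curve` helper whose import path and keyword arguments change above: it computes cumulative false/true positive counts at each distinct prediction threshold. A rough sketch of the standard algorithm (sample weights omitted for brevity; this is not Lightning's exact code):

    import torch


    def binary_clf_curve_sketch(preds: torch.Tensor, target: torch.Tensor):
        # Sort scores in descending order and reorder the labels to match.
        order = torch.argsort(preds, descending=True)
        preds, target = preds[order], target[order]

        # Indices where the score changes mark the distinct thresholds;
        # always keep the last index so every sample is counted.
        distinct = torch.nonzero(preds[1:] != preds[:-1]).squeeze(-1)
        threshold_idx = torch.cat([distinct, torch.tensor([preds.numel() - 1])])

        # Cumulative positives at each threshold; the remainder are false positives.
        tps = torch.cumsum(target, dim=0)[threshold_idx]
        fps = (threshold_idx + 1) - tps
        return fps, tps, preds[threshold_idx]
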
diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py
index a3919a6a8a7dd..82727d37479b6 100644
--- a/tests/models/test_onnx.py
+++ b/tests/models/test_onnx.py
@@ -21,44 +21,44 @@
 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
-from tests.base import EvalModelTemplate
+from tests.base import BoringModel, EvalModelTemplate
 
 
 def test_model_saves_with_input_sample(tmpdir):
-    """Test that ONNX model saves with input sample and size is greater than 3 MB"""
-    model = EvalModelTemplate()
+    """Test that ONNX model saves with input sample and file size is greater than 400 bytes"""
+    model = BoringModel()
     trainer = Trainer(max_epochs=1)
     trainer.fit(model)
 
     file_path = os.path.join(tmpdir, "model.onnx")
-    input_sample = torch.randn((1, 28 * 28))
+    input_sample = torch.randn((1, 32))
     model.to_onnx(file_path, input_sample)
     assert os.path.isfile(file_path)
-    assert os.path.getsize(file_path) > 3e+06
+    assert os.path.getsize(file_path) > 4e2
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_model_saves_on_gpu(tmpdir):
     """Test that model saves on gpu"""
-    model = EvalModelTemplate()
+    model = BoringModel()
     trainer = Trainer(gpus=1, max_epochs=1)
     trainer.fit(model)
 
     file_path = os.path.join(tmpdir, "model.onnx")
-    input_sample = torch.randn((1, 28 * 28))
+    input_sample = torch.randn((1, 32))
     model.to_onnx(file_path, input_sample)
     assert os.path.isfile(file_path)
-    assert os.path.getsize(file_path) > 3e+06
+    assert os.path.getsize(file_path) > 4e2
 
 
 def test_model_saves_with_example_output(tmpdir):
     """Test that ONNX model saves when provided with example output"""
-    model = EvalModelTemplate()
+    model = BoringModel()
     trainer = Trainer(max_epochs=1)
     trainer.fit(model)
 
     file_path = os.path.join(tmpdir, "model.onnx")
-    input_sample = torch.randn((1, 28 * 28))
+    input_sample = torch.randn((1, 32))
     model.eval()
     example_outputs = model.forward(input_sample)
     model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
@@ -67,11 +67,13 @@ def test_model_saves_with_example_output(tmpdir):
 
 
 def test_model_saves_with_example_input_array(tmpdir):
-    """Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
-    model = EvalModelTemplate()
+    """Test that ONNX model saves with `example_input_array` and file size is greater than 400 bytes"""
+    model = BoringModel()
+    model.example_input_array = torch.randn(5, 32)
+
     file_path = os.path.join(tmpdir, "model.onnx")
     model.to_onnx(file_path)
     assert os.path.exists(file_path) is True
-    assert os.path.getsize(file_path) > 3e+06
+    assert os.path.getsize(file_path) > 4e2
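
Outside of the tests, the export path exercised here works the same way; roughly (assuming the test suite's `BoringModel`, a 32-feature linear model):

    import torch

    from tests.base import BoringModel

    model = BoringModel()

    # Either pass an input sample explicitly...
    model.to_onnx("model.onnx", input_sample=torch.randn(1, 32))

    # ...or set `example_input_array` and let the export pick it up.
    model.example_input_array = torch.randn(5, 32)
    model.to_onnx("model_from_example.onnx")
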
@@ -100,7 +102,9 @@ def test_model_saves_on_multi_gpu(tmpdir):
 
 def test_verbose_param(tmpdir, capsys):
     """Test that output is present when verbose parameter is set"""
-    model = EvalModelTemplate()
+    model = BoringModel()
+    model.example_input_array = torch.randn(5, 32)
+
     file_path = os.path.join(tmpdir, "model.onnx")
     model.to_onnx(file_path, verbose=True)
     captured = capsys.readouterr()
@@ -108,8 +112,8 @@
 
 def test_error_if_no_input(tmpdir):
-    """Test that an exception is thrown when there is no input tensor"""
-    model = EvalModelTemplate()
+    """Test that an error is thrown when there is no input tensor"""
+    model = BoringModel()
     model.example_input_array = None
     file_path = os.path.join(tmpdir, "model.onnx")
     with pytest.raises(ValueError, match=r'Could not export to ONNX since neither `input_sample` nor'
@@ -117,21 +121,12 @@
         model.to_onnx(file_path)
 
 
-def test_error_if_input_sample_is_not_tensor(tmpdir):
-    """Test that an exception is thrown when there is no input tensor"""
-    model = EvalModelTemplate()
-    model.example_input_array = None
-    file_path = os.path.join(tmpdir, "model.onnx")
-    input_sample = np.random.randn(1, 28 * 28)
-    with pytest.raises(ValueError, match=f'Received `input_sample` of type {type(input_sample)}. Expected type is '
-                                         f'`Tensor`'):
-        model.to_onnx(file_path, input_sample)
-
-
 def test_if_inference_output_is_valid(tmpdir):
     """Test that the output inferred from ONNX model is same as from PyTorch"""
-    model = EvalModelTemplate()
-    trainer = Trainer(max_epochs=5)
+    model = BoringModel()
+    model.example_input_array = torch.randn(5, 32)
+
+    trainer = Trainer(max_epochs=2)
     trainer.fit(model)
 
     model.eval()
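
`test_if_inference_output_is_valid` hinges on comparing ONNX-runtime inference against the PyTorch forward pass. A sketch of that comparison (assuming the `onnxruntime` package is available; the tolerances are illustrative):

    import numpy as np
    import onnxruntime
    import torch

    from tests.base import BoringModel

    model = BoringModel()
    model.example_input_array = torch.randn(5, 32)
    model.eval()
    model.to_onnx("model.onnx", model.example_input_array)

    with torch.no_grad():
        torch_out = model(model.example_input_array)

    # Feed the same input through the exported graph.
    session = onnxruntime.InferenceSession("model.onnx")
    input_name = session.get_inputs()[0].name
    ort_out = session.run(None, {input_name: model.example_input_array.numpy()})[0]

    # The two outputs should agree up to numerical tolerance.
    assert np.allclose(torch_out.numpy(), ort_out, rtol=1e-03, atol=1e-05)
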
""" - model = EvalModelTemplate().to(device) + model = BoringModel().to(device) + model.example_input_array = torch.randn(5, 32) + script = model.to_torchscript() assert next(script.parameters()).device == device script_output = script(model.example_input_array.to(device)) @@ -69,7 +100,7 @@ def test_torchscript_device(device): def test_torchscript_retain_training_state(): """ Test that torchscript export does not alter the training mode of original model. """ - model = EvalModelTemplate() + model = BoringModel() model.train(True) script = model.to_torchscript() assert model.training @@ -81,7 +112,7 @@ def test_torchscript_retain_training_state(): @pytest.mark.parametrize("modelclass", [ - EvalModelTemplate, + BoringModel, ParityModuleRNN, BasicGAN, ]) @@ -100,7 +131,7 @@ def test_torchscript_properties(modelclass): @pytest.mark.parametrize("modelclass", [ - EvalModelTemplate, + BoringModel, ParityModuleRNN, BasicGAN, ]) @@ -109,9 +140,27 @@ def test_torchscript_properties(modelclass): reason="torch.save/load has bug loading script modules on torch <= 1.4", ) def test_torchscript_save_load(tmpdir, modelclass): - """ Test that scripted LightningModules is correctly saved and can be loaded. """ + """ Test that scripted LightningModule is correctly saved and can be loaded. """ model = modelclass() output_file = str(tmpdir / "model.pt") script = model.to_torchscript(file_path=output_file) loaded_script = torch.jit.load(output_file) assert torch.allclose(next(script.parameters()), next(loaded_script.parameters())) + + +def test_torchcript_invalid_method(tmpdir): + """Test that an error is thrown with invalid torchscript method""" + model = BoringModel() + model.train(True) + + with pytest.raises(ValueError, match="only supports 'script' or 'trace'"): + model.to_torchscript(method='temp') + + +def test_torchscript_with_no_input(tmpdir): + """Test that an error is thrown when there is no input tensor""" + model = BoringModel() + model.example_input_array = None + + with pytest.raises(ValueError, match='requires either `example_inputs` or `model.example_input_array`'): + model.to_torchscript(method='trace') diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 5e341e9c66f63..33d14e852b285 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -825,7 +825,7 @@ def optimizer_closure(): retain_graph = num_backward != backward_idx # noqa E225 self.manual_backward(loss_1, opt, retain_graph=retain_graph) - opt.step(1, closure=optimizer_closure, something="new") + opt.step(closure=optimizer_closure) def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer @@ -855,7 +855,7 @@ def automatic_optimization(self) -> bool: ) trainer.fit(model) - expected_calls = [call(1, closure=ANY, something="new") for s in range(2)] + expected_calls = [call(closure=ANY) for s in range(2)] step_mock.assert_has_calls(expected_calls) @@ -902,7 +902,7 @@ def dis_closure(): if batch_idx % 4 == 0 : # Note: Set make_optimizer_step to True or it will use by default # Trainer(accumulate_grad_batches=x) - opt_dis.step(closure=dis_closure, make_optimizer_step=True, optim='adam') + opt_dis.step(closure=dis_closure, make_optimizer_step=True) def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer @@ -935,8 +935,7 @@ def automatic_optimization(self) -> bool: trainer.fit(model) expected_calls = 
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 5e341e9c66f63..33d14e852b285 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -825,7 +825,7 @@ def optimizer_closure():
                 retain_graph = num_backward != backward_idx  # noqa E225
                 self.manual_backward(loss_1, opt, retain_graph=retain_graph)
 
-            opt.step(1, closure=optimizer_closure, something="new")
+            opt.step(closure=optimizer_closure)
 
         def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
@@ -855,7 +855,7 @@ def automatic_optimization(self) -> bool:
     )
 
     trainer.fit(model)
-    expected_calls = [call(1, closure=ANY, something="new") for s in range(2)]
+    expected_calls = [call(closure=ANY) for s in range(2)]
     step_mock.assert_has_calls(expected_calls)
 
 
@@ -902,7 +902,7 @@ def dis_closure():
             if batch_idx % 4 == 0 :
                 # Note: Set make_optimizer_step to True or it will use by default
                 # Trainer(accumulate_grad_batches=x)
-                opt_dis.step(closure=dis_closure, make_optimizer_step=True, optim='adam')
+                opt_dis.step(closure=dis_closure, make_optimizer_step=True)
 
         def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
@@ -935,8 +935,7 @@ def automatic_optimization(self) -> bool:
 
     trainer.fit(model)
     expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)]
     mock_sgd_step.assert_has_calls(expected_calls)
-
-    expected_calls = [call(closure=ANY, optim='adam') for s in range(2)]
+    expected_calls = [call(closure=ANY) for s in range(2)]
     mock_adam_step.assert_has_calls(expected_calls)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 9b29d6ec2b1dd..c24f1f5421e5c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -958,6 +958,7 @@ def test_gradient_clipping(tmpdir):
     """
     Test gradient clipping
    """
+    tutils.reset_seed()
 
     model = EvalModelTemplate()
 
@@ -995,6 +996,7 @@ def test_gradient_clipping_fp16(tmpdir):
     """
     Test gradient clipping with fp16
     """
+    tutils.reset_seed()
 
     model = EvalModelTemplate()
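
The `opt.step(closure=...)` calls asserted in the manual-optimization tests above follow Lightning's manual optimization flow. A rough sketch of a module using it (built on the test suite's `BoringModel`; `compute_loss` is a hypothetical stand-in for the model's actual loss computation):

    from tests.base import BoringModel


    class ManualOptModel(BoringModel):

        @property
        def automatic_optimization(self) -> bool:
            # Opt out of Lightning's automatic optimization loop.
            return False

        def training_step(self, batch, batch_idx):
            # LightningOptimizer wrapping the one from configure_optimizers().
            opt = self.optimizers()

            def closure():
                loss = self.compute_loss(batch)  # hypothetical loss helper
                self.manual_backward(loss, opt)

            # step() runs the closure before the underlying optimizer step;
            # as the diff above shows, extra positional/keyword args are gone.
            opt.step(closure=closure)
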