diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index afd86ca98c213..c0b97439737ff 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -76,8 +76,7 @@ jobs: with: name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - # Use always() to always run this step to publish test results when there are test failures - if: always() + if: failure() - name: Statistics if: success() diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 7fc4de8ddbfd3..d64fedbfbe590 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -50,5 +50,4 @@ jobs: with: name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - # Use always() to always run this step to publish test results when there are test failures - if: always() + if: failure() diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 1a2115a40fcfd..b87a1d8557843 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -129,8 +129,7 @@ jobs: with: name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - # Use always() to always run this step to publish test results when there are test failures - if: always() + if: failure() - name: Statistics if: success() diff --git a/CHANGELOG.md b/CHANGELOG.md index 951aee6049b7a..f078349ef3665 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,47 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased.Features] - YYYY-MM-DD + +### Added + + +### Changed + + +### Deprecated + + +### Removed + + +### Fixed + + + +## [unreleased.BugFix] - YYYY-MM-DD + +### Added + + +### Changed + + +### Deprecated + + +### Removed + + +### Fixed + +- Fixed trainer by default `None` in `DDPAccelerator` ([#4915](https://github.com/PyTorchLightning/pytorch-lightning/pull/4915)) + + +- Fixed `LightningOptimizer` exposes optimizer attributes ([#5095](https://github.com/PyTorchLightning/pytorch-lightning/pull/5095)) + + + ## [1.1.0] - 2020-12-09 ### Added @@ -44,9 +85,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549)) - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903)) -- WandbLogger does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648)) +- `WandbLogger` does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648)) - Changed `automatic_optimization` to be a model attribute ([#4602](https://github.com/PyTorchLightning/pytorch-lightning/pull/4602)) - Changed `Simple Profiler` report to order by percentage time spent + num calls ([#4880](https://github.com/PyTorchLightning/pytorch-lightning/pull/4880)) - Simplify optimization Logic ([#4984](https://github.com/PyTorchLightning/pytorch-lightning/pull/4984)) @@ -64,6 +104,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed - Removed `reorder` parameter of the `auc` metric ([#5004](https://github.com/PyTorchLightning/pytorch-lightning/pull/5004)) +- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549)) ### Fixed diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst index 6ca72b8069d6d..06e6e9679d29f 100644 --- a/docs/source/optimizers.rst +++ b/docs/source/optimizers.rst @@ -191,37 +191,48 @@ override the :meth:`optimizer_step` function. For example, here step optimizer A every 2 batches and optimizer B every 4 batches -.. testcode:: +.. note:: When using Trainer(enable_pl_optimizer=True), there is no need to call `.zero_grad()`. - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False): - optimizer.step() +.. testcode:: def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx): optimizer.zero_grad() # Alternating schedule for optimizer steps (ie: GANs) - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False): + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): # update generator opt every 2 steps if optimizer_i == 0: if batch_nb % 2 == 0 : - optimizer.step() - optimizer.zero_grad() + optimizer.step(closure=closure) # update discriminator opt every 4 steps if optimizer_i == 1: if batch_nb % 4 == 0 : - optimizer.step() - optimizer.zero_grad() + optimizer.step(closure=closure) + +.. note:: When using ``Trainer(enable_pl_optimizer=True)``, ``.step`` accepts a boolean ``make_optimizer_step`` which can be used as follow. + +.. testcode:: + + def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx): + optimizer.zero_grad() + + # Alternating schedule for optimizer steps (ie: GANs) + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): + # update generator opt every 2 steps + if optimizer_i == 0: + optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 2) == 0) - # ... - # add as many optimizers as you want + # update discriminator opt every 4 steps + if optimizer_i == 1: + optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 4) == 0) Here we add a learning-rate warm up .. testcode:: # learning rate warm-up - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False): + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): # warm up lr if self.trainer.global_step < 500: lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) @@ -229,8 +240,20 @@ Here we add a learning-rate warm up pg['lr'] = lr_scale * self.hparams.learning_rate # update params - optimizer.step() - optimizer.zero_grad() + optimizer.step(closure=closure) + +The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step. + +.. testcode:: + + from pytorch_lightning.core.optimizer import LightningOptimizer + + # function hook in LightningModule + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): + if not isinstance(optimizer, LightningOptimizer): + # wraps into LightingOptimizer only for running step + optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer) + optimizer.step(closure=closure) ---------- diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index d6c5139cb3799..408d95a72dc47 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -1,6 +1,6 @@ """Root package info.""" -__version__ = '1.1.0' +__version__ = '1.1.1rc0' __author__ = 'William Falcon et al.' __author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0' diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f29e7f75bfbff..ef05ce69c1828 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1170,7 +1170,6 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): def optimizer_step( self, - *args, epoch: int = None, batch_idx: int = None, optimizer: Optimizer = None, @@ -1179,7 +1178,6 @@ def optimizer_step( on_tpu: bool = None, using_native_amp: bool = None, using_lbfgs: bool = None, - **kwargs, ) -> None: r""" Override this method to adjust the default way the @@ -1254,7 +1252,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, if not isinstance(optimizer, LightningOptimizer): # wraps into LightingOptimizer only for running step optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer) - optimizer.step(closure=optimizer_closure, *args, **kwargs) + optimizer.step(closure=optimizer_closure) def optimizer_zero_grad( self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index dc63231ba6ccb..c8e9ff8b80a2f 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -57,12 +57,35 @@ def __init__(self, else: self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) - self._trainer = None self._optimizer = optimizer + self._trainer = None self._accumulate_grad_batches = accumulate_grad_batches - self._automatic_optimization = None self._optimizer_idx = None + @property + def defaults(self): + return self._optimizer.defaults + + @defaults.setter + def defaults(self, defaults): + self._optimizer.defaults = defaults + + @property + def state(self): + return self._optimizer.state + + @state.setter + def state(self, state): + self._optimizer.state = state + + @property + def param_groups(self): + return self._optimizer.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self._optimizer.param_groups = param_groups + @property def accumulate_grad_batches(self): return self._accumulate_grad_batches @@ -73,7 +96,6 @@ def accumulate_grad_batches(self, accumulate_grad_batches): def _on_trainer_init(self, trainer): self._trainer = proxy(trainer) - self._automatic_optimization = trainer.train_loop.automatic_optimization for opt_idx, opt in enumerate(trainer.optimizers): if opt == self._optimizer: self._optimizer_idx = opt_idx diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index 13cb705f30b17..b4cbb6b073efe 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -14,7 +14,7 @@ from pytorch_lightning.metrics.classification.accuracy import Accuracy from pytorch_lightning.metrics.classification.average_precision import AveragePrecision from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix -from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 +from pytorch_lightning.metrics.classification.f_beta import FBeta, Fbeta, F1 from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve from pytorch_lightning.metrics.classification.roc import ROC diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index 0a8a952470dbc..33878cb48965d 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -92,9 +92,8 @@ def __init__( self.add_state("target", default=[], dist_reduce_fx=None) rank_zero_warn( - 'Metric `AveragePrecision` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' + 'Metric `AveragePrecision` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' ) def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index 56cc00f9a5dce..d6147b00463b3 100755 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -20,6 +20,7 @@ _fbeta_compute ) from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn class FBeta(Metric): @@ -131,6 +132,34 @@ def compute(self) -> torch.Tensor: self.actual_positives, self.beta, self.average) +# todo: remove in v1.2 +class Fbeta(FBeta): + r""" + Computes `F-score `_ + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.classification.f_beta.FBeta` + """ + def __init__( + self, + num_classes: int, + beta: float = 1.0, + threshold: float = 0.5, + average: str = "micro", + multilabel: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + rank_zero_warn( + "This `Fbeta` was deprecated in v1.0.x in favor of" + " `from pytorch_lightning.metrics.classification.f_beta import FBeta`." + " It will be removed in v1.2.0", DeprecationWarning + ) + super().__init__( + num_classes, beta, threshold, average, multilabel, compute_on_step, dist_sync_on_step, process_group + ) + + class F1(FBeta): """ Computes F1 metric. F1 metrics correspond to a harmonic mean of the diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index 052a25a7a977d..620904898535d 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -102,9 +102,8 @@ def __init__( self.add_state("target", default=[], dist_reduce_fx=None) rank_zero_warn( - 'Metric `PrecisionRecallCurve` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' + 'Metric `PrecisionRecallCurve` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' ) def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 89e8265b19fc1..2b7d82488b491 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -105,9 +105,8 @@ def __init__( self.add_state("target", default=[], dist_reduce_fx=None) rank_zero_warn( - 'Metric `ROC` will save all targets and' - ' predictions in buffer. For large datasets this may lead' - ' to large memory footprint.' + 'Metric `ROC` will save all targets and predictions in buffer.' + ' For large datasets this may lead to large memory footprint.' ) def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index e13242e40b0ac..e38ab5f415c32 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -17,13 +17,18 @@ auc, auroc, dice_score, + f1_score, + fbeta_score, + get_num_classes, + iou, multiclass_auroc, precision, precision_recall, recall, stat_scores, stat_scores_multiple_classes, - iou, + to_categorical, + to_onehot, ) from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # TODO: unify metrics between class and functional, add below diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 1c43ec75bb508..e1ba601b51553 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -15,12 +15,75 @@ from typing import Callable, Optional, Sequence, Tuple import torch -from torch.nn import functional as F -from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce +from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap +from pytorch_lightning.metrics.functional.f_beta import fbeta as __fb, f1 as __f1 +from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve, precision_recall_curve as __prc +from pytorch_lightning.metrics.functional.roc import roc as __roc +from pytorch_lightning.metrics.utils import ( + to_categorical as __tc, + to_onehot as __to, + get_num_classes as __gnc, + reduce, + class_reduce, +) from pytorch_lightning.utilities import rank_zero_warn +def to_onehot( + tensor: torch.Tensor, + num_classes: Optional[int] = None, +) -> torch.Tensor: + """ + Converts a dense label tensor to one-hot format + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_onehot` + """ + rank_zero_warn( + "This `to_onehot` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import to_onehot`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __to(tensor, num_classes) + + +def to_categorical( + tensor: torch.Tensor, + argmax_dim: int = 1 +) -> torch.Tensor: + """ + Converts a tensor of probabilities to a dense label tensor + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_categorical` + + """ + rank_zero_warn( + "This `to_categorical` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import to_categorical`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __tc(tensor) + + +def get_num_classes( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, +) -> int: + """ + Calculates the number of classes for a given prediction and target tensor. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.get_num_classes` + + """ + rank_zero_warn( + "This `get_num_classes` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.utils import get_num_classes`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __gnc(pred,target, num_classes) + + def stat_scores( pred: torch.Tensor, target: torch.Tensor, @@ -332,52 +395,28 @@ def recall( num_classes=num_classes, class_reduction=class_reduction)[1] -def _binary_clf_curve( +# todo: remove in 1.3 +def roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, pos_label: int = 1., ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py - """ - if sample_weight is not None and not isinstance(sample_weight, torch.Tensor): - sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float) - - # remove class dimension if necessary - if pred.ndim > target.ndim: - pred = pred[:, 0] - desc_score_indices = torch.argsort(pred, descending=True) - - pred = pred[desc_score_indices] - target = target[desc_score_indices] - - if sample_weight is not None: - weight = sample_weight[desc_score_indices] - else: - weight = 1. - - # pred typically has many tied values. Here we extract - # the indices associated with the distinct values. We also - # concatenate a value for the end of the curve. - distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0] - threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) - - target = (target == pos_label).to(torch.long) - tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] - - if sample_weight is not None: - # express fps as a cumsum to ensure fps is increasing even in - # the presence of floating point errors - fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] - else: - fps = 1 + threshold_idxs - tps + Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. - return fps, tps, pred[threshold_idxs] + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` + """ + rank_zero_warn( + "This `multiclass_roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py -def __roc( +def _roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, @@ -386,22 +425,13 @@ def __roc( """ Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. - .. warning:: Deprecated - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - false-positive rate (fpr), true-positive rate (tpr), thresholds + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` Example: >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 1, 1]) - >>> fpr, tpr, thresholds = __roc(x, y) + >>> fpr, tpr, thresholds = _roc(x, y) >>> fpr tensor([0., 0., 0., 0., 1.]) >>> tpr @@ -410,9 +440,12 @@ def __roc( tensor([4, 3, 2, 1, 0]) """ - fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) + rank_zero_warn( + "This `multiclass_roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`." + " It will be removed in v1.3.0", DeprecationWarning + ) + fps, tps, thresholds = _binary_clf_curve(pred, target, sample_weights=sample_weight, pos_label=pos_label) # Add an extra threshold position # to make sure that the curve starts at (0, 0) @@ -434,7 +467,7 @@ def __roc( # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py -def __multiclass_roc( +def multiclass_roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, @@ -443,7 +476,7 @@ def __multiclass_roc( """ Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. - .. warning:: Deprecated + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc` Args: pred: estimated probabilities @@ -462,19 +495,24 @@ def __multiclass_roc( ... [0.05, 0.05, 0.85, 0.05], ... [0.05, 0.05, 0.05, 0.85]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> __multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE + >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500]))) """ + rank_zero_warn( + "This `multiclass_roc` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.roc import roc`." + " It will be removed in v1.3.0", DeprecationWarning + ) num_classes = get_num_classes(pred, target, num_classes) class_roc_vals = [] for c in range(num_classes): pred_c = pred[:, c] - class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) + class_roc_vals.append(_roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) return tuple(class_roc_vals) @@ -572,7 +610,7 @@ def auroc( @auc_decorator() def _auroc(pred, target, sample_weight, pos_label): - return __roc(pred, target, sample_weight, pos_label) + return _roc(pred, target, sample_weight, pos_label) return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) @@ -625,7 +663,7 @@ def multiclass_auroc( @multiclass_auc_decorator() def _multiclass_auroc(pred, target, sample_weight, num_classes): - return __multiclass_roc(pred, target, sample_weight, num_classes) + return multiclass_roc(pred, target, sample_weight, num_classes) class_aurocs = _multiclass_auroc(pred=pred, target=target, sample_weight=sample_weight, @@ -772,3 +810,110 @@ def iou( ]) return reduce(scores, reduction=reduction) + + +# todo: remove in 1.3 +def precision_recall_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +): + """ + Computes precision-recall pairs for different thresholds. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` + """ + rank_zero_warn( + "This `precision_recall_curve` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __prc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) + + +# todo: remove in 1.3 +def multiclass_precision_recall_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, +): + """ + Computes precision-recall pairs for different thresholds given a multiclass scores. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` + """ + rank_zero_warn( + "This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of" + " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`." + " It will be removed in v1.3.0", DeprecationWarning + ) + if num_classes is None: + num_classes = get_num_classes(pred, target, num_classes) + return __prc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes) + + +# todo: remove in 1.3 +def average_precision( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +): + """ + Compute average precision from prediction scores. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision` + """ + rank_zero_warn( + "This `average_precision` was deprecated in v1.1.0 in favor of" + " `pytorch_lightning.metrics.functional.average_precision import average_precision`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) + + +# todo: remove in 1.2 +def fbeta_score( + pred: torch.Tensor, + target: torch.Tensor, + beta: float, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', +) -> torch.Tensor: + """ + Computes the F-beta score which is a weighted harmonic mean of precision and recall. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.f_beta.fbeta` + """ + rank_zero_warn( + "This `average_precision` was deprecated in v1.0.x in favor of" + " `from pytorch_lightning.metrics.functional.f_beta import fbeta`." + " It will be removed in v1.2.0", DeprecationWarning + ) + if num_classes is None: + num_classes = get_num_classes(pred, target) + return __fb(preds=pred, target=target, beta=beta, num_classes=num_classes, average=class_reduction) + + +# todo: remove in 1.2 +def f1_score( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + class_reduction: str = 'micro', +) -> torch.Tensor: + """ + Computes the F1-score (a.k.a F-measure), which is the harmonic mean of the precision and recall. + + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.f_beta.f1` + """ + rank_zero_warn( + "This `average_precision` was deprecated in v1.0.x in favor of" + " `from pytorch_lightning.metrics.functional.f_beta import f1`." + " It will be removed in v1.2.0", DeprecationWarning + ) + if num_classes is None: + num_classes = get_num_classes(pred, target) + return __f1(preds=pred, target=target, num_classes=num_classes, average=class_reduction) diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index 012e1486ebb1f..20b38c58a2a6b 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -23,10 +23,11 @@ def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tup return preds, target -def _explained_variance_compute(preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', - ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: +def _explained_variance_compute( + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = 'uniform_average', +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: diff_avg = torch.mean(target - preds, dim=0) numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0) @@ -52,10 +53,11 @@ def _explained_variance_compute(preds: torch.Tensor, return torch.sum(denominator / denom_sum * output_scores) -def explained_variance(preds: torch.Tensor, - target: torch.Tensor, - multioutput: str = 'uniform_average', - ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: +def explained_variance( + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = 'uniform_average', +) -> Union[torch.Tensor, Sequence[torch.Tensor]]: """ Computes explained variance. diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py new file mode 100644 index 0000000000000..c116b16d363a9 --- /dev/null +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from pytorch_lightning.metrics.utils import reduce as __reduce, class_reduce as __cr +from pytorch_lightning.utilities import rank_zero_warn + + +def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: + rank_zero_warn( + "This `reduce` was deprecated in v1.1.0 in favor of" + " `pytorch_lightning.metrics.utils import reduce`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __reduce(to_reduce=to_reduce, reduction=reduction) + + +def class_reduce(num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = 'none'): + rank_zero_warn( + "This `class_reduce` was deprecated in v1.1.0 in favor of" + " `pytorch_lightning.metrics.utils import class_reduce`." + " It will be removed in v1.3.0", DeprecationWarning + ) + return __cr(num=num, denom=denom, weights=weights, class_reduction=class_reduction) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 20dfb0f4b380f..68a0f4781c9a9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -477,7 +477,7 @@ def _process_result(self, training_step_output, split_batch): return training_step_output_for_epoch_end - def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure, *args, **kwargs): + def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): model_ref = self.trainer.get_model() is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) @@ -491,16 +491,14 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_ # model hook model_ref.optimizer_step( - epoch=self.trainer.current_epoch, - batch_idx=batch_idx, - optimizer=optimizer, - optimizer_idx=opt_idx, - optimizer_closure=train_step_and_backward_closure, + self.trainer.current_epoch, + batch_idx, + optimizer, + opt_idx, + train_step_and_backward_closure, on_tpu=self.trainer.use_tpu and TPU_AVAILABLE, using_native_amp=using_native_amp, using_lbfgs=is_lbfgs, - *args, - **kwargs, ) def on_before_zero_grad(self, optimizer): diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index a11862b4003bc..e5641337cc8d2 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -54,6 +54,7 @@ def _module_available(module_path: str) -> bool: OMEGACONF_AVAILABLE = _module_available("omegaconf") HYDRA_AVAILABLE = _module_available("hydra") HOROVOD_AVAILABLE = _module_available("horovod.torch") +BOLTS_AVAILABLE = _module_available("pl_bolts") TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') diff --git a/requirements/extra.txt b/requirements/extra.txt index ad54358269bd1..3f14b1e5910dd 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,4 @@ torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip \ No newline at end of file +https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 3e2e6d040f44c..e3a597063d02e 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -55,6 +55,11 @@ def test_automatic_optimization_num_calls(enable_pl_optimizer, tmpdir): class TestModel(BoringModel): + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + def configure_optimizers(self): optimizer = SGD(self.layer.parameters(), lr=0.1) optimizer_2 = Adam(self.layer.parameters(), lr=0.1) @@ -98,3 +103,47 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, assert sgd_zero_grad.call_count == 4 assert adam_step.call_count == 2 assert adam_zero_grad.call_count == 2 + + +@pytest.mark.parametrize("enable_pl_optimizer", [False, True]) +def test_params_groups_and_state_are_accessible(enable_pl_optimizer, tmpdir): + + with patch("torch.optim.SGD.step") as sgd_step, \ + patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \ + patch("torch.optim.Adam.step") as adam_step, \ + patch("torch.optim.Adam.zero_grad") as adam_zero_grad: + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def configure_optimizers(self): + optimizer = SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = Adam(self.layer.parameters(), lr=0.1) + return [optimizer, optimizer_2] + + def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, + on_tpu=False, using_native_amp=False, using_lbfgs=False): + # warm up lr + if self.trainer.global_step < 500: + lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) + for pg in optimizer.param_groups: + pg['lr'] = lr_scale * 0.01 + + optimizer.step(closure=closure) + + model = TestModel() + model.training_epoch_end = None + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + limit_train_batches=8, + accumulate_grad_batches=1, + enable_pl_optimizer=enable_pl_optimizer + ) + + trainer.fit(model) diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index e6ec59ec4f5aa..a9fcf918cc699 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -19,10 +19,12 @@ import torch.nn as nn from torch.optim import Adam, Optimizer +import pytorch_lightning as pl from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset +from tests.base.boring_model import BoringModel, RandomDataset, RandomDictDataset, RandomDictStringDataset def test_lightning_optimizer(tmpdir): @@ -80,8 +82,8 @@ def configure_optimizers(self): assert trainer.optimizers[0].__repr__() == expected -@patch("torch.optim.Adam.step") -@patch("torch.optim.SGD.step") +@patch("torch.optim.Adam.step", autospec=True) +@patch("torch.optim.SGD.step", autospec=True) def test_lightning_optimizer_manual_optimization(mock_sgd_step, mock_adam_step, tmpdir): """ Test that the user can use our LightningOptimizer. Not recommended for now. @@ -96,13 +98,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): output = self.layer(batch) loss_1 = self.loss(batch, output) self.manual_backward(loss_1, opt_1) - opt_1.step(idx="1") + opt_1.step() def closure(): output = self.layer(batch) loss_2 = self.loss(batch, output) self.manual_backward(loss_2, opt_2) - opt_2.step(closure=closure, idx="2") + opt_2.step(closure=closure) def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -133,8 +135,8 @@ def automatic_optimization(self) -> bool: assert len(mock_adam_step.mock_calls) == 8 -@patch("torch.optim.Adam.step") -@patch("torch.optim.SGD.step") +@patch("torch.optim.Adam.step", autospec=True) +@patch("torch.optim.SGD.step", autospec=True) def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(mock_sgd_step, mock_adam_step, tmpdir): """ Test that the user can use our LightningOptimizer. Not recommended. @@ -149,13 +151,13 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): output = self.layer(batch) loss_1 = self.loss(batch, output) self.manual_backward(loss_1, opt_1) - opt_1.step(idx="1") + opt_1.step() def closure(): output = self.layer(batch) loss_2 = self.loss(batch, output) self.manual_backward(loss_2, opt_2) - opt_2.step(closure=closure, idx="2") + opt_2.step(closure=closure) def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -191,13 +193,29 @@ def test_state(tmpdir): model = torch.nn.Linear(3, 4) optimizer = torch.optim.Adam(model.parameters()) lightning_optimizer = LightningOptimizer(optimizer) + + # test state + assert optimizer.state == lightning_optimizer.state + lightning_optimizer.state = optimizer.state + assert optimizer.state == lightning_optimizer.state + + # test param_groups + assert optimizer.param_groups == lightning_optimizer.param_groups + lightning_optimizer.param_groups = optimizer.param_groups + assert optimizer.param_groups == lightning_optimizer.param_groups + + # test defaults + assert optimizer.defaults == lightning_optimizer.defaults + lightning_optimizer.defaults = optimizer.defaults + assert optimizer.defaults == lightning_optimizer.defaults + assert isinstance(lightning_optimizer, LightningOptimizer) assert isinstance(lightning_optimizer, Adam) assert isinstance(lightning_optimizer, Optimizer) lightning_dict = {} - special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", - "_trainer", "_use_accumulate_grad_batches_from_trainer", "_automatic_optimization", - "_accumulate_grad_batches"] + special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", "_support_closure", + "_trainer", "__getstate__", "__setstate__", "state_dict", "load_state_dict", + "zero_grad", "__setstate__", "add_param_group"] for k, v in lightning_optimizer.__dict__.items(): if k not in special_attrs: lightning_dict[k] = v diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index f7bd7d558f5b4..a6fbe9e849785 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -17,13 +17,13 @@ accuracy, precision, recall, - _binary_clf_curve, dice_score, auroc, multiclass_auroc, auc, iou, ) +from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve from pytorch_lightning.metrics.utils import to_onehot, get_num_classes, to_categorical @@ -222,7 +222,7 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape): if sample_weight is not None: sample_weight = torch.ones_like(pred) * sample_weight - fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label) + fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) assert isinstance(tps, torch.Tensor) assert isinstance(fps, torch.Tensor) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index f549de1f4d71e..59c6728009b6f 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -51,6 +51,63 @@ def __init__(self, hparams): DeprecatedHparamsModel({}) +def test_tbd_remove_in_v1_3_0_metrics(): + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import to_onehot + to_onehot(torch.tensor([1, 2, 3])) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import to_categorical + to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]])) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import get_num_classes + get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1])) + + x_binary = torch.tensor([0, 1, 2, 3]) + y_binary = torch.tensor([0, 1, 2, 3]) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import roc + roc(pred=x_binary, target=y_binary) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import _roc + _roc(pred=x_binary, target=y_binary) + + x_multy = torch.tensor([[0.85, 0.05, 0.05, 0.05], + [0.05, 0.85, 0.05, 0.05], + [0.05, 0.05, 0.85, 0.05], + [0.05, 0.05, 0.05, 0.85]]) + y_multy = torch.tensor([0, 1, 3, 2]) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import multiclass_roc + multiclass_roc(pred=x_multy, target=y_multy) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import average_precision + average_precision(pred=x_binary, target=y_binary) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import precision_recall_curve + precision_recall_curve(pred=x_binary, target=y_binary) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve + multiclass_precision_recall_curve(pred=x_multy, target=y_multy) + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.reduction import reduce + reduce(torch.tensor([0, 1, 1, 0]), 'sum') + + with pytest.deprecated_call(match='will be removed in v1.3'): + from pytorch_lightning.metrics.functional.reduction import class_reduce + class_reduce(torch.randint(1, 10, (50,)).float(), + torch.randint(10, 20, (50,)).float(), + torch.randint(1, 100, (50,)).float()) + + def test_tbd_remove_in_v1_2_0(): with pytest.deprecated_call(match='will be removed in v1.2'): checkpoint_cb = ModelCheckpoint(filepath='.') @@ -62,6 +119,20 @@ def test_tbd_remove_in_v1_2_0(): checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.') +def test_tbd_remove_in_v1_2_0_metrics(): + from pytorch_lightning.metrics.classification import Fbeta + from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score + + with pytest.deprecated_call(match='will be removed in v1.2'): + Fbeta(2) + + with pytest.deprecated_call(match='will be removed in v1.2'): + fbeta_score(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 1]), 0.2) + + with pytest.deprecated_call(match='will be removed in v1.2'): + f1_score(torch.tensor([0, 1, 0, 1]), torch.tensor([0, 1, 0, 0])) + + # TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py @pytest.mark.parametrize(['profiler', 'expected'], [ (True, SimpleProfiler), diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 5e341e9c66f63..33d14e852b285 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -825,7 +825,7 @@ def optimizer_closure(): retain_graph = num_backward != backward_idx # noqa E225 self.manual_backward(loss_1, opt, retain_graph=retain_graph) - opt.step(1, closure=optimizer_closure, something="new") + opt.step(closure=optimizer_closure) def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer @@ -855,7 +855,7 @@ def automatic_optimization(self) -> bool: ) trainer.fit(model) - expected_calls = [call(1, closure=ANY, something="new") for s in range(2)] + expected_calls = [call(closure=ANY) for s in range(2)] step_mock.assert_has_calls(expected_calls) @@ -902,7 +902,7 @@ def dis_closure(): if batch_idx % 4 == 0 : # Note: Set make_optimizer_step to True or it will use by default # Trainer(accumulate_grad_batches=x) - opt_dis.step(closure=dis_closure, make_optimizer_step=True, optim='adam') + opt_dis.step(closure=dis_closure, make_optimizer_step=True) def training_epoch_end(self, outputs) -> None: # outputs should be an array with an entry per optimizer @@ -935,8 +935,7 @@ def automatic_optimization(self) -> bool: trainer.fit(model) expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)] mock_sgd_step.assert_has_calls(expected_calls) - - expected_calls = [call(closure=ANY, optim='adam') for s in range(2)] + expected_calls = [call(closure=ANY) for s in range(2)] mock_adam_step.assert_has_calls(expected_calls)