diff --git a/docs/source/conf.py b/docs/source/conf.py index 6163de976da40..47dcc13614522 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -331,6 +331,7 @@ def package_list_from_file(file): 'comet-ml': 'comet_ml', 'neptune-client': 'neptune', 'hydra-core': 'hydra', + 'pyDeprecate': 'deprecate', } MOCK_PACKAGES = [] if SPHINX_MOCK_REQUIREMENTS: diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2a0c108ba7603..2781586730151 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -30,7 +30,7 @@ import yaml from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache @@ -258,9 +258,9 @@ def save_checkpoint(self, trainer, unused: Optional = None): to handle correct behaviour in distributed training, i.e., saving only on rank 0. """ if unused is not None: - rank_zero_warn( + rank_zero_deprecation( "`ModelCheckpoint.save_checkpoint` signature has changed in v1.3. The `pl_module` parameter" - " has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning + " has been removed. Support for the old signature will be removed in v1.5" ) global_step = trainer.global_step @@ -371,9 +371,9 @@ def __init_triggers( # period takes precedence over every_n_val_epochs for backwards compatibility if period is not None: - rank_zero_warn( + rank_zero_deprecation( 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ' Please use `every_n_val_epochs` instead.' ) self._every_n_val_epochs = period @@ -381,17 +381,17 @@ def __init_triggers( @property def period(self) -> Optional[int]: - rank_zero_warn( + rank_zero_deprecation( 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ' Please use `every_n_val_epochs` instead.' ) return self._period @period.setter def period(self, value: Optional[int]) -> None: - rank_zero_warn( + rank_zero_deprecation( 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ' Please use `every_n_val_epochs` instead.' ) self._period = value diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 4d36fe48448dc..8e0718ab891dc 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,7 +38,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -1226,9 +1226,8 @@ def training_step(...): opt_a.step() """ if optimizer is not None: - rank_zero_warn( - "`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4", - DeprecationWarning + rank_zero_deprecation( + "`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4" ) # make sure we're using manual opt diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 1da24737a3752..9b27fdf0cb253 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from warnings import warn from pytorch_lightning.metrics.classification import ( # noqa: F401 Accuracy, @@ -39,8 +38,9 @@ R2Score, SSIM, ) +from pytorch_lightning.utilities import rank_zero_deprecation -warn( +rank_zero_deprecation( "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package" - " (https://github.com/PyTorchLightning/metrics) since v1.3 and will be removed in v1.5", DeprecationWarning + " (https://github.com/PyTorchLightning/metrics) since v1.3 and will be removed in v1.5" ) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index b9d0a45e6fd33..1a9febe0c831c 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -15,12 +15,12 @@ from torchmetrics import Accuracy as _Accuracy -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class Accuracy(_Accuracy): - @deprecated(target=_Accuracy, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_Accuracy) def __init__( self, threshold: float = 0.5, diff --git a/pytorch_lightning/metrics/classification/auc.py b/pytorch_lightning/metrics/classification/auc.py index ce28e1d4e7072..05bc7b27d7e68 100644 --- a/pytorch_lightning/metrics/classification/auc.py +++ b/pytorch_lightning/metrics/classification/auc.py @@ -15,12 +15,12 @@ from torchmetrics import AUC as _AUC -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class AUC(_AUC): - @deprecated(target=_AUC, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_AUC) def __init__( self, reorder: bool = False, diff --git a/pytorch_lightning/metrics/classification/auroc.py b/pytorch_lightning/metrics/classification/auroc.py index 0866406ecea8f..e10b094fd5a2e 100644 --- a/pytorch_lightning/metrics/classification/auroc.py +++ b/pytorch_lightning/metrics/classification/auroc.py @@ -15,12 +15,12 @@ from torchmetrics import AUROC as _AUROC -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class AUROC(_AUROC): - @deprecated(target=_AUROC, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_AUROC) def __init__( self, num_classes: Optional[int] = None, diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index 106d6ea6111b2..6c8cdbd52891d 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -15,12 +15,12 @@ from torchmetrics import AveragePrecision as _AveragePrecision -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class AveragePrecision(_AveragePrecision): - @deprecated(target=_AveragePrecision, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_AveragePrecision) def __init__( self, num_classes: Optional[int] = None, diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py index aacd8dcf3b498..2995f668380de 100644 --- a/pytorch_lightning/metrics/classification/confusion_matrix.py +++ b/pytorch_lightning/metrics/classification/confusion_matrix.py @@ -15,12 +15,12 @@ from torchmetrics import ConfusionMatrix as _ConfusionMatrix -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class ConfusionMatrix(_ConfusionMatrix): - @deprecated(target=_ConfusionMatrix, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_ConfusionMatrix) def __init__( self, num_classes: int, diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index bac3cc3e99c4e..a3f4172f05400 100644 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -16,12 +16,12 @@ from torchmetrics import F1 as _F1 from torchmetrics import FBeta as _FBeta -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class FBeta(_FBeta): - @deprecated(target=_FBeta, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_FBeta) def __init__( self, num_classes: int, @@ -43,7 +43,7 @@ def __init__( class F1(_F1): - @deprecated(target=_F1, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_F1) def __init__( self, num_classes: int, diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index b59c3e1053ab8..d66b0c2d9cfa8 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -15,12 +15,12 @@ from torchmetrics import HammingDistance as _HammingDistance -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class HammingDistance(_HammingDistance): - @deprecated(target=_HammingDistance, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_HammingDistance) def __init__( self, threshold: float = 0.5, diff --git a/pytorch_lightning/metrics/classification/iou.py b/pytorch_lightning/metrics/classification/iou.py index d5b5d8eeb47e2..f1d9d0945511a 100644 --- a/pytorch_lightning/metrics/classification/iou.py +++ b/pytorch_lightning/metrics/classification/iou.py @@ -15,12 +15,12 @@ from torchmetrics import IoU as _IoU -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class IoU(_IoU): - @deprecated(target=_IoU, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_IoU) def __init__( self, num_classes: int, diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index ae3ee40da0ca5..7b95d21dae97c 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -16,12 +16,12 @@ from torchmetrics import Precision as _Precision from torchmetrics import Recall as _Recall -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class Precision(_Precision): - @deprecated(target=_Precision, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_Precision) def __init__( self, num_classes: Optional[int] = None, @@ -47,7 +47,7 @@ def __init__( class Recall(_Recall): - @deprecated(target=_Recall, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_Recall) def __init__( self, num_classes: Optional[int] = None, diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index fb8f6a812028c..285cb2fb78ccc 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -15,12 +15,12 @@ from torchmetrics import PrecisionRecallCurve as _PrecisionRecallCurve -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class PrecisionRecallCurve(_PrecisionRecallCurve): - @deprecated(target=_PrecisionRecallCurve, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_PrecisionRecallCurve) def __init__( self, num_classes: Optional[int] = None, diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 5850913f61ed9..3f6cf50803c86 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -15,12 +15,12 @@ from torchmetrics import ROC as _ROC -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class ROC(_ROC): - @deprecated(target=_ROC, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_ROC) def __init__( self, num_classes: Optional[int] = None, diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index 2c4764477b262..1eed815d4b4cd 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -15,12 +15,12 @@ from torchmetrics import StatScores as _StatScores -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class StatScores(_StatScores): - @deprecated(target=_StatScores, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_StatScores) def __init__( self, threshold: float = 0.5, diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py index 975b8280f77d5..56bb1912e48e6 100644 --- a/pytorch_lightning/metrics/compositional.py +++ b/pytorch_lightning/metrics/compositional.py @@ -15,25 +15,21 @@ import torch from torchmetrics import Metric -from torchmetrics.metric import CompositionalMetric as __CompositionalMetric +from torchmetrics.metric import CompositionalMetric as _CompositionalMetric -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.metrics.utils import deprecated_metrics -class CompositionalMetric(__CompositionalMetric): - """ - .. deprecated:: - Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0. - """ +class CompositionalMetric(_CompositionalMetric): + @deprecated_metrics(target=_CompositionalMetric) def __init__( self, operator: Callable, metric_a: Union[Metric, int, float, torch.Tensor], metric_b: Union[Metric, int, float, torch.Tensor, None], ): - rank_zero_warn( - "This `CompositionalMetric` was deprecated since v1.3.0 in favor of" - " `torchmetrics.metric.CompositionalMetric`. It will be removed in v1.5.0", DeprecationWarning - ) - super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b) + """ + .. deprecated:: + Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 601442cd01202..69fa9d75590e0 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import accuracy as _accuracy -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_accuracy, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_accuracy) def accuracy( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/auc.py b/pytorch_lightning/metrics/functional/auc.py index 7cd3457789bf7..7cc6aa458d397 100644 --- a/pytorch_lightning/metrics/functional/auc.py +++ b/pytorch_lightning/metrics/functional/auc.py @@ -14,10 +14,10 @@ import torch from torchmetrics.functional import auc as _auc -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_auc, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_auc) def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor: """ .. deprecated:: diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py index 16058110175c5..c49aa1a8fdc48 100644 --- a/pytorch_lightning/metrics/functional/auroc.py +++ b/pytorch_lightning/metrics/functional/auroc.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import auroc as _auroc -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_auroc, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_auroc) def auroc( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/average_precision.py b/pytorch_lightning/metrics/functional/average_precision.py index e4ce3941fe008..017b34739a0f4 100644 --- a/pytorch_lightning/metrics/functional/average_precision.py +++ b/pytorch_lightning/metrics/functional/average_precision.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import average_precision as _average_precision -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_average_precision, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_average_precision) def average_precision( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 90d5c0a66550a..be1fec196a346 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -21,7 +21,7 @@ from pytorch_lightning.metrics.functional.auc import auc as __auc from pytorch_lightning.metrics.functional.auroc import auroc as __auroc from pytorch_lightning.metrics.functional.iou import iou as __iou -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn def stat_scores( @@ -58,10 +58,10 @@ def stat_scores_multiple_classes( .. deprecated:: Use :func:`torchmetrics.functional.stat_scores`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `stat_scores_multiple_classes` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrics.functional import stat_scores`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) @@ -144,10 +144,10 @@ def precision_recall( .. deprecated:: Use :func:`torchmetrics.functional.precision_recall`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `precision_recall` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrcs.functional import precision_recall`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) @@ -172,10 +172,10 @@ def precision( .. deprecated:: Use :func:`torchmetrics.functional.precision`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `precision` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrics.functional import precision`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] @@ -192,10 +192,10 @@ def recall( .. deprecated:: Use :func:`torchmetrics.functional.recall`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `recall` was deprecated in v1.2.0 in favor of" " `from pytorch_lightning.metrics.functional import recall`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1] @@ -210,20 +210,16 @@ def auc( .. deprecated:: Use :func:`torchmetrics.functional.auc`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `auc` was deprecated in v1.2.0 in favor of" " `pytorch_lightning.metrics.functional.auc import auc`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) return __auc(x, y) # todo: remove in 1.4 def _auc_decorator() -> Callable: - rank_zero_warn( - "This `_auc_decorator` was deprecated in v1.2.0." - " It will be removed in v1.4.0", DeprecationWarning - ) def wrapper(func_to_decorate: Callable) -> Callable: @@ -240,10 +236,6 @@ def new_func(*args, **kwargs) -> torch.Tensor: # todo: remove in 1.4 def _multiclass_auc_decorator() -> Callable: - rank_zero_warn( - "This `_multiclass_auc_decorator` was deprecated in v1.2.0." - " It will be removed in v1.4.0", DeprecationWarning - ) def wrapper(func_to_decorate: Callable) -> Callable: @@ -273,10 +265,9 @@ def auroc( .. deprecated:: Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.4.0. """ - rank_zero_warn( - "This `auroc` was deprecated in v1.2.0 in favor of" - " `pytorch_lightning.metrics.functional.auroc import auroc`." - " It will be removed in v1.4.0", DeprecationWarning + rank_zero_deprecation( + "This `auroc` was deprecated in v1.2.0 in favor of `pytorch_lightning.metrics.functional.auroc import auroc`." + " It will be removed in v1.4.0" ) return __auroc( preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, max_fpr=max_fpr, num_classes=1 @@ -294,10 +285,10 @@ def multiclass_auroc( .. deprecated:: Use :func:`torchmetrics.functional.auroc`. Will be removed in v1.4.0. """ - rank_zero_warn( + rank_zero_deprecation( "This `multiclass_auroc` was deprecated in v1.2.0 in favor of" " `pytorch_lightning.metrics.functional.auroc import auroc`." - " It will be removed in v1.4.0", DeprecationWarning + " It will be removed in v1.4.0" ) return __auroc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes) @@ -346,10 +337,9 @@ def iou( .. deprecated:: Use :func:`torchmetrics.functional.iou`. Will be removed in v1.4.0. """ - rank_zero_warn( - "This `iou` was deprecated in v1.2.0 in favor of" - " `from pytorch_lightning.metrics.functional.iou import iou`." - " It will be removed in v1.4.0", DeprecationWarning + rank_zero_deprecation( + "This `iou` was deprecated in v1.2.0 in favor of `from pytorch_lightning.metrics.functional.iou import iou`." + " It will be removed in v1.4.0" ) return __iou( pred=pred, diff --git a/pytorch_lightning/metrics/functional/confusion_matrix.py b/pytorch_lightning/metrics/functional/confusion_matrix.py index 5cf8818176696..038bd8b49b730 100644 --- a/pytorch_lightning/metrics/functional/confusion_matrix.py +++ b/pytorch_lightning/metrics/functional/confusion_matrix.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import confusion_matrix as _confusion_matrix -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_confusion_matrix, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_confusion_matrix) def confusion_matrix( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index bcfe698bf4c5e..233a0851b8d56 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import explained_variance as _explained_variance -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_explained_variance, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_explained_variance) def explained_variance( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/f_beta.py b/pytorch_lightning/metrics/functional/f_beta.py index e4d926e0ab8bf..f994c9a8a3271 100644 --- a/pytorch_lightning/metrics/functional/f_beta.py +++ b/pytorch_lightning/metrics/functional/f_beta.py @@ -15,10 +15,10 @@ from torchmetrics.functional import f1 as _f1 from torchmetrics.functional import fbeta as _fbeta -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_fbeta, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_fbeta) def fbeta( preds: torch.Tensor, target: torch.Tensor, @@ -34,7 +34,7 @@ def fbeta( """ -@deprecated(target=_f1, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_f1) def f1( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index ef6bb3277fef2..6a390e776f111 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -14,10 +14,10 @@ import torch from torchmetrics.functional import hamming_distance as _hamming_distance -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_hamming_distance, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_hamming_distance) def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: """ .. deprecated:: diff --git a/pytorch_lightning/metrics/functional/image_gradients.py b/pytorch_lightning/metrics/functional/image_gradients.py index b65c21613a5a5..e2151c5fc1d93 100644 --- a/pytorch_lightning/metrics/functional/image_gradients.py +++ b/pytorch_lightning/metrics/functional/image_gradients.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import image_gradients as _image_gradients -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_image_gradients, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_image_gradients) def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ .. deprecated:: diff --git a/pytorch_lightning/metrics/functional/iou.py b/pytorch_lightning/metrics/functional/iou.py index 7ae520eb25dee..76f59854ad4bf 100644 --- a/pytorch_lightning/metrics/functional/iou.py +++ b/pytorch_lightning/metrics/functional/iou.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import iou as _iou -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_iou, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_iou) def iou( pred: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/mean_absolute_error.py b/pytorch_lightning/metrics/functional/mean_absolute_error.py index 85aa07c802eca..219284d79d623 100644 --- a/pytorch_lightning/metrics/functional/mean_absolute_error.py +++ b/pytorch_lightning/metrics/functional/mean_absolute_error.py @@ -15,10 +15,10 @@ import torch from torchmetrics.functional import mean_absolute_error as _mean_absolute_error -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_mean_absolute_error, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_mean_absolute_error) def mean_absolute_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ .. deprecated:: diff --git a/pytorch_lightning/metrics/functional/mean_relative_error.py b/pytorch_lightning/metrics/functional/mean_relative_error.py index be21371bdc91a..329fe040ebc7d 100644 --- a/pytorch_lightning/metrics/functional/mean_relative_error.py +++ b/pytorch_lightning/metrics/functional/mean_relative_error.py @@ -15,10 +15,10 @@ import torch from torchmetrics.functional.regression.mean_relative_error import mean_relative_error as _mean_relative_error -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_mean_relative_error, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_mean_relative_error) def mean_relative_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ .. deprecated:: diff --git a/pytorch_lightning/metrics/functional/mean_squared_error.py b/pytorch_lightning/metrics/functional/mean_squared_error.py index 9d1850dcd8689..5bbc0bb1c6a83 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_error.py @@ -15,10 +15,10 @@ import torch from torchmetrics.functional import mean_squared_error as _mean_squared_error -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_mean_squared_error, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_mean_squared_error) def mean_squared_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ .. deprecated:: diff --git a/pytorch_lightning/metrics/functional/mean_squared_log_error.py b/pytorch_lightning/metrics/functional/mean_squared_log_error.py index 56654ea47daf2..29786529381d5 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_log_error.py @@ -15,10 +15,10 @@ import torch from torchmetrics.functional import mean_squared_log_error as _mean_squared_log_error -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_mean_squared_log_error, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_mean_squared_log_error) def mean_squared_log_error(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ .. deprecated:: diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py index 29feb959f2f18..c59d7cf2b8976 100644 --- a/pytorch_lightning/metrics/functional/nlp.py +++ b/pytorch_lightning/metrics/functional/nlp.py @@ -21,10 +21,10 @@ import torch from torchmetrics.functional import bleu_score as _bleu_score -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_bleu_score, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_bleu_score) def bleu_score( translate_corpus: Sequence[str], reference_corpus: Sequence[str], diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 1b5be382a13af..7b6c8641b5829 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -18,10 +18,10 @@ from torchmetrics.functional import precision_recall as _precision_recall from torchmetrics.functional import recall as _recall -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_precision, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_precision) def precision( preds: torch.Tensor, target: torch.Tensor, @@ -39,7 +39,7 @@ def precision( """ -@deprecated(target=_recall, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_recall) def recall( preds: torch.Tensor, target: torch.Tensor, @@ -57,7 +57,7 @@ def recall( """ -@deprecated(target=_precision_recall, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_precision_recall) def precision_recall( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py index d1d643ba70c22..dc9863cbb47c4 100644 --- a/pytorch_lightning/metrics/functional/precision_recall_curve.py +++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import precision_recall_curve as _precision_recall_curve -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_precision_recall_curve, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_precision_recall_curve) def precision_recall_curve( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py index dd7aa44ae628e..51be9d47b91f9 100644 --- a/pytorch_lightning/metrics/functional/psnr.py +++ b/pytorch_lightning/metrics/functional/psnr.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import psnr as _psnr -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_psnr, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_psnr) def psnr( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index 49273d9cefaed..fe4b541989358 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -15,10 +15,10 @@ import torch from torchmetrics.functional import r2score as _r2score -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_r2score, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_r2score) def r2score( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/roc.py b/pytorch_lightning/metrics/functional/roc.py index 1ca534eb6a5be..928a0b40fca54 100644 --- a/pytorch_lightning/metrics/functional/roc.py +++ b/pytorch_lightning/metrics/functional/roc.py @@ -16,10 +16,10 @@ from torch import Tensor from torchmetrics.functional import roc as _roc -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_roc, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_roc) def roc( preds: Tensor, target: Tensor, diff --git a/pytorch_lightning/metrics/functional/self_supervised.py b/pytorch_lightning/metrics/functional/self_supervised.py index c3dc1cbfad659..65dec211e938a 100644 --- a/pytorch_lightning/metrics/functional/self_supervised.py +++ b/pytorch_lightning/metrics/functional/self_supervised.py @@ -14,10 +14,10 @@ import torch from torchmetrics.functional import embedding_similarity as _embedding_similarity -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_embedding_similarity, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_embedding_similarity) def embedding_similarity( batch: torch.Tensor, similarity: str = 'cosine', diff --git a/pytorch_lightning/metrics/functional/ssim.py b/pytorch_lightning/metrics/functional/ssim.py index 8809fec8d8ff1..31cff7fcfb9b4 100644 --- a/pytorch_lightning/metrics/functional/ssim.py +++ b/pytorch_lightning/metrics/functional/ssim.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import ssim as _ssim -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_ssim, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_ssim) def ssim( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index 6f234e84d9aab..30c03da237fe6 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -16,10 +16,10 @@ import torch from torchmetrics.functional import stat_scores as _stat_scores -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics -@deprecated(target=_stat_scores, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_stat_scores) def stat_scores( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index b76c91dcdf2f1..ee0fcdb8a92e1 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -16,16 +16,12 @@ from torchmetrics import Metric as _Metric from torchmetrics.collections import MetricCollection as _MetricCollection -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class Metric(_Metric): - r""" - .. deprecated:: - Use :class:`torchmetrics.Metric`. Will be removed in v1.5.0. - """ - @deprecated(target=_Metric, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_Metric) def __init__( self, compute_on_step: bool = True, @@ -33,7 +29,7 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - """ + r""" .. deprecated:: Use :class:`torchmetrics.Metric`. Will be removed in v1.5.0. """ @@ -41,7 +37,7 @@ def __init__( class MetricCollection(_MetricCollection): - @deprecated(target=_MetricCollection, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_MetricCollection) def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): """ .. deprecated:: diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index 4f820718545cb..0f94ae2fb3754 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -15,12 +15,12 @@ from torchmetrics import ExplainedVariance as _ExplainedVariance -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class ExplainedVariance(_ExplainedVariance): - @deprecated(target=_ExplainedVariance, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_ExplainedVariance) def __init__( self, multioutput: str = 'uniform_average', diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index 8510275c127d7..57c7db420445b 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -15,12 +15,12 @@ from torchmetrics import MeanAbsoluteError as _MeanAbsoluteError -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class MeanAbsoluteError(_MeanAbsoluteError): - @deprecated(target=_MeanAbsoluteError, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_MeanAbsoluteError) def __init__( self, compute_on_step: bool = True, diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index cbe09faf0046c..c8e9c151c99d9 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -15,12 +15,12 @@ from torchmetrics import MeanSquaredError as _MeanSquaredError -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class MeanSquaredError(_MeanSquaredError): - @deprecated(target=_MeanSquaredError, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_MeanSquaredError) def __init__( self, compute_on_step: bool = True, diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index 795d6f5409abf..c8ee8a7069115 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -15,12 +15,12 @@ from torchmetrics import MeanSquaredLogError as _MeanSquaredLogError -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class MeanSquaredLogError(_MeanSquaredLogError): - @deprecated(target=_MeanSquaredLogError, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_MeanSquaredLogError) def __init__( self, compute_on_step: bool = True, diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py index 85b8eceaa24c5..f972e9a8e2b5e 100644 --- a/pytorch_lightning/metrics/regression/psnr.py +++ b/pytorch_lightning/metrics/regression/psnr.py @@ -15,12 +15,12 @@ from torchmetrics import PSNR as _PSNR -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class PSNR(_PSNR): - @deprecated(target=_PSNR, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_PSNR) def __init__( self, data_range: Optional[float] = None, diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index 52621d6df7c28..ad5f7f3bd8d07 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -15,12 +15,12 @@ from torchmetrics import R2Score as _R2Score -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class R2Score(_R2Score): - @deprecated(target=_R2Score, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_R2Score) def __init__( self, num_outputs: int = 1, diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py index b290808c6fa5e..cf5571f3e68f4 100644 --- a/pytorch_lightning/metrics/regression/ssim.py +++ b/pytorch_lightning/metrics/regression/ssim.py @@ -15,12 +15,12 @@ from torchmetrics import SSIM as _SSIM -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.metrics.utils import deprecated_metrics class SSIM(_SSIM): - @deprecated(target=_SSIM, ver_deprecate="1.3.0", ver_remove="1.5.0") + @deprecated_metrics(target=_SSIM) def __init__( self, kernel_size: Sequence[int] = (11, 11), diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index cf7e82fc36ad8..4adc88a37ba21 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Optional import torch +from deprecate import deprecated from torchmetrics.utilities.data import dim_zero_cat as _dim_zero_cat from torchmetrics.utilities.data import dim_zero_mean as _dim_zero_mean from torchmetrics.utilities.data import dim_zero_sum as _dim_zero_sum @@ -24,25 +26,27 @@ from torchmetrics.utilities.distributed import class_reduce as _class_reduce from torchmetrics.utilities.distributed import reduce as _reduce -from pytorch_lightning.utilities.deprecation import deprecated +from pytorch_lightning.utilities import rank_zero_deprecation +deprecated_metrics = partial(deprecated, deprecated_in="1.3.0", remove_in="1.5.0", stream=rank_zero_deprecation) -@deprecated(target=_dim_zero_cat, ver_deprecate="1.3.0", ver_remove="1.5.0") + +@deprecated_metrics(target=_dim_zero_cat) def dim_zero_cat(x): pass -@deprecated(target=_dim_zero_sum, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_dim_zero_sum) def dim_zero_sum(x): pass -@deprecated(target=_dim_zero_mean, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_dim_zero_mean) def dim_zero_mean(x): pass -@deprecated(target=_to_onehot, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_to_onehot) def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: """ .. deprecated:: @@ -50,7 +54,7 @@ def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> """ -@deprecated(target=_select_topk, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_select_topk) def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: """ .. deprecated:: @@ -58,7 +62,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch """ -@deprecated(target=_to_categorical, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_to_categorical) def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ .. deprecated:: @@ -66,7 +70,7 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ -@deprecated(target=_get_num_classes, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_get_num_classes) def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: """ .. deprecated:: @@ -74,7 +78,7 @@ def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optio """ -@deprecated(target=_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_reduce) def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ .. deprecated:: @@ -82,7 +86,7 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: """ -@deprecated(target=_class_reduce, ver_deprecate="1.3.0", ver_remove="1.5.0") +@deprecated_metrics(target=_class_reduce) def class_reduce( num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" ) -> torch.Tensor: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8823d48a7817e..30552572a0c61 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -19,7 +19,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.warnings import WarningCache @@ -243,10 +243,10 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: callback_states = {} for callback in self.callbacks: if self.__is_old_signature(callback.on_save_checkpoint): - rank_zero_warn( + rank_zero_deprecation( "`Callback.on_save_checkpoint` signature has changed in v1.3." " A `checkpoint` parameter has been added." - " Support for the old signature will be removed in v1.5", DeprecationWarning + " Support for the old signature will be removed in v1.5" ) state = callback.on_save_checkpoint(self, self.lightning_module) # noqa: parameter-unfilled else: diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 69d3887fc7718..32dbc8c4088a3 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -14,7 +14,7 @@ from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector -from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn +from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_deprecation class DeprecatedDistDeviceAttributes: @@ -24,96 +24,94 @@ class DeprecatedDistDeviceAttributes: @property def on_cpu(self) -> bool: - rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._device_type == DeviceType.CPU @on_cpu.setter def on_cpu(self, val: bool) -> None: - rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._device_type = DeviceType.CPU @property def on_tpu(self) -> bool: - rank_zero_warn("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._device_type == DeviceType.TPU @on_tpu.setter def on_tpu(self, val: bool) -> None: - rank_zero_warn("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._device_type = DeviceType.TPU @property def use_tpu(self) -> bool: - rank_zero_warn("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.") return self.on_tpu @use_tpu.setter def use_tpu(self, val: bool) -> None: - rank_zero_warn("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.") self.on_tpu = val @property def on_gpu(self) -> bool: - rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._device_type == DeviceType.GPU @on_gpu.setter def on_gpu(self, val: bool) -> None: - rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._device_type = DeviceType.GPU @property def use_dp(self) -> bool: - rank_zero_warn("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._distrib_type == DistributedType.DP @use_dp.setter def use_dp(self, val: bool) -> None: - rank_zero_warn("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_dp` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._distrib_type = DistributedType.DP @property def use_ddp(self) -> bool: - rank_zero_warn("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) @use_ddp.setter def use_ddp(self, val: bool) -> None: - rank_zero_warn("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._distrib_type = DistributedType.DDP @property def use_ddp2(self) -> bool: - rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._distrib_type == DistributedType.DDP2 @use_ddp2.setter def use_ddp2(self, val: bool) -> None: - rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._distrib_type = DistributedType.DDP2 @property def use_horovod(self) -> bool: - rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.") return self.accelerator_connector._distrib_type == DistributedType.HOROVOD @use_horovod.setter def use_horovod(self, val: bool) -> None: - rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) + rank_zero_deprecation("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._distrib_type = DistributedType.HOROVOD @property def use_single_gpu(self) -> bool: - rank_zero_warn( - "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning - ) + rank_zero_deprecation("Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.") # todo, limiting to exclude DDP2 is not clear but it comes from connectors... return ( self.accelerator_connector._device_type and self.accelerator_connector._device_type == DeviceType.GPU @@ -122,10 +120,7 @@ def use_single_gpu(self) -> bool: @use_single_gpu.setter def use_single_gpu(self, val: bool) -> None: - rank_zero_warn( - "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", - DeprecationWarning, - ) + rank_zero_deprecation("Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.") if val: self.accelerator_connector._device_type = DeviceType.GPU @@ -138,23 +133,22 @@ class DeprecatedTrainerAttributes: @property def accelerator_backend(self) -> Accelerator: - rank_zero_warn( + rank_zero_deprecation( "The `Trainer.accelerator_backend` attribute is deprecated in favor of `Trainer.accelerator`" - " since 1.2 and will be removed in v1.4.", DeprecationWarning + " since 1.2 and will be removed in v1.4." ) return self.accelerator def get_model(self) -> LightningModule: - rank_zero_warn( + rank_zero_deprecation( "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`" - " and will be removed in v1.4.", DeprecationWarning + " and will be removed in v1.4." ) return self.lightning_module @property def running_sanity_check(self) -> bool: - rank_zero_warn( - "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking`" - " and will be removed in v1.5.", DeprecationWarning + rank_zero_deprecation( + "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5." ) return self.sanity_checking diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index da41b9855b44a..c53681d20ac42 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -344,8 +344,7 @@ def call_on_evaluation_epoch_end_hook(self): model_hook_fx(outputs) else: self.warning_cache.warn( - f"`ModelHooks.{hook_name}` signature has changed in v1.3." - " `outputs` parameter has been added." + f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added." " Support for the old signature will be removed in v1.5", DeprecationWarning ) model_hook_fx() diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 3e2ee3e51efe1..f4617c23da383 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -17,6 +17,7 @@ from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.distributed import ( # noqa: F401 AllGatherGrad, + rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn, diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index eb53579f948e8..80db2429f7d2a 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -1,7 +1,5 @@ -from warnings import warn +from pytorch_lightning.utilities import rank_zero_deprecation -warn( - "`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4", DeprecationWarning -) +rank_zero_deprecation("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v1.4") from pytorch_lightning.utilities.argparse import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py deleted file mode 100644 index f6591b4060b03..0000000000000 --- a/pytorch_lightning/utilities/deprecation.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -from functools import wraps -from typing import Any, Callable, List, Optional, Tuple - -from pytorch_lightning.utilities import rank_zero_warn - - -def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]]: - """Parse function arguments, types and default values - - Example: - >>> get_func_arguments_and_types(get_func_arguments_and_types) - [('func', typing.Callable, )] - """ - func_default_params = inspect.signature(func).parameters - name_type_default = [] - for arg in func_default_params: - arg_type = func_default_params[arg].annotation - arg_default = func_default_params[arg].default - name_type_default.append((arg, arg_type, arg_default)) - return name_type_default - - -def deprecated(target: Callable, ver_deprecate: Optional[str] = "", ver_remove: Optional[str] = "") -> Callable: - """ - Decorate a function or class ``__init__`` with warning message - and pass all arguments directly to the target class/method. - """ - - def inner_function(base): - - @wraps(base) - def wrapped_fn(*args, **kwargs): - is_class = inspect.isclass(target) - target_func = target.__init__ if is_class else target - # warn user only once in lifetime - if not getattr(wrapped_fn, 'warned', False): - target_str = f'{target.__module__}.{target.__name__}' - base_name = base.__qualname__.split('.')[-2] if is_class else base.__name__ - base_str = f'{base.__module__}.{base_name}' - rank_zero_warn( - f"`{base_str}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." - f" It will be removed in v{ver_remove}.", DeprecationWarning - ) - wrapped_fn.warned = True - - if args: # in case any args passed move them to kwargs - # parse only the argument names - arg_names = [arg[0] for arg in get_func_arguments_and_types(base)] - # convert args to kwargs - kwargs.update({k: v for k, v in zip(arg_names, args)}) - # fill by base defaults - base_defaults = {arg[0]: arg[2] for arg in get_func_arguments_and_types(base) if arg[2] != inspect._empty} - kwargs = dict(list(base_defaults.items()) + list(kwargs.items())) - - target_args = [arg[0] for arg in get_func_arguments_and_types(target_func)] - assert all(arg in target_args for arg in kwargs), \ - "Failed mapping, arguments missing in target base: %s" % [arg not in target_args for arg in kwargs] - # all args were already moved to kwargs - return target_func(**kwargs) - - return wrapped_fn - - return inner_function diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 658f349a22215..d301b9d8a7217 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -15,7 +15,7 @@ import logging import os import warnings -from functools import wraps +from functools import partial, wraps from typing import Any, Optional, Union import torch @@ -63,6 +63,7 @@ def _debug(*args, **kwargs): rank_zero_debug = rank_zero_only(_debug) rank_zero_info = rank_zero_only(_info) rank_zero_warn = rank_zero_only(_warn) +rank_zero_deprecation = partial(rank_zero_warn, category=DeprecationWarning) def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None): diff --git a/pytorch_lightning/utilities/model_utils.py b/pytorch_lightning/utilities/model_utils.py index 7fd5b287f7ba3..728f73f4f0d32 100644 --- a/pytorch_lightning/utilities/model_utils.py +++ b/pytorch_lightning/utilities/model_utils.py @@ -1,8 +1,7 @@ -from warnings import warn +from pytorch_lightning.utilities import rank_zero_deprecation -warn( - "`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4", - DeprecationWarning +rank_zero_deprecation( + "`model_utils` package has been renamed to `model_helpers` since v1.2 and will be removed in v1.4" ) from pytorch_lightning.utilities.model_helpers import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/warning_utils.py b/pytorch_lightning/utilities/warning_utils.py index c520086f62a81..0668bababa609 100644 --- a/pytorch_lightning/utilities/warning_utils.py +++ b/pytorch_lightning/utilities/warning_utils.py @@ -1,7 +1,5 @@ -from warnings import warn +from pytorch_lightning.utilities import rank_zero_deprecation -warn( - "`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4", DeprecationWarning -) +rank_zero_deprecation("`warning_utils` package has been renamed to `warnings` since v1.2 and will be removed in v1.4") from pytorch_lightning.utilities.warnings import * # noqa: F403 E402 F401 diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py index aa0af1697ac51..f028222e3930b 100644 --- a/pytorch_lightning/utilities/xla_device_utils.py +++ b/pytorch_lightning/utilities/xla_device_utils.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from warnings import warn +from pytorch_lightning.utilities import rank_zero_deprecation -warn( - "`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4", - DeprecationWarning +rank_zero_deprecation( + "`xla_device_utils` package has been renamed to `xla_device` since v1.2 and will be removed in v1.4" ) from pytorch_lightning.utilities.xla_device import * # noqa: F403 E402 F401 diff --git a/requirements.txt b/requirements.txt index f196b5e639bf5..4649983b79d78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ tqdm>=4.41.0 fsspec[http]>=0.8.1 tensorboard>=2.2.0 torchmetrics>=0.2.0 +pyDeprecate==0.1.1 \ No newline at end of file diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index f5ca312888cd4..99e1b31f6edad 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -142,14 +142,6 @@ def test_v1_4_0_deprecated_metrics(): with pytest.deprecated_call(match='will be removed in v1.4'): multiclass_auroc(torch.rand(20, 5).softmax(dim=-1), torch.randint(0, 5, (20, )), num_classes=5) - from pytorch_lightning.metrics.functional.classification import _auc_decorator - with pytest.deprecated_call(match='will be removed in v1.4'): - _auc_decorator() - - from pytorch_lightning.metrics.functional.classification import _multiclass_auc_decorator - with pytest.deprecated_call(match='will be removed in v1.4'): - _multiclass_auc_decorator() - class CustomDDPPlugin(DDPSpawnPlugin): diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 00ff2bd6bc889..e52e39cb16488 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -2,7 +2,8 @@ from torchmetrics import Metric as TMetric from pytorch_lightning import Trainer -from pytorch_lightning.metrics import Metric as PLMetric, MetricCollection +from pytorch_lightning.metrics import Metric as PLMetric +from pytorch_lightning.metrics import MetricCollection from tests.helpers.boring_model import BoringModel diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index a5b3f10db24d8..d3703bf3691c9 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -91,50 +91,47 @@ def test_v1_5_metrics_collection(): target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) - MetricCollection.__init__.warned = False - with pytest.deprecated_call( - match="`pytorch_lightning.metrics.metric.MetricCollection` was deprecated since v1.3.0 in favor" - " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0." - ): + MetricCollection.__init__._warned = False + with pytest.deprecated_call(match="It will be removed in v1.5.0."): metrics = MetricCollection([Accuracy()]) assert metrics(preds, target) == {'Accuracy': torch.tensor(0.1250)} def test_v1_5_metric_accuracy(): - accuracy.warned = False + accuracy._warned = False preds = torch.tensor([0, 0, 1, 0, 1]) target = torch.tensor([0, 0, 1, 1, 1]) with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert accuracy(preds, target) == torch.tensor(0.8) - Accuracy.__init__.warned = False + Accuracy.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): Accuracy() def test_v1_5_metric_auc_auroc(): - AUC.__init__.warned = False + AUC.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): AUC() - ROC.__init__.warned = False + ROC.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): ROC() - AUROC.__init__.warned = False + AUROC.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): AUROC() x = torch.tensor([0, 1, 2, 3]) y = torch.tensor([0, 1, 2, 2]) - auc.warned = False + auc._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert auc(x, y) == torch.tensor(4.) preds = torch.tensor([0, 1, 2, 3]) target = torch.tensor([0, 1, 1, 1]) - roc.warned = False + roc._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): fpr, tpr, thrs = roc(preds, target, pos_label=1) assert torch.equal(fpr, torch.tensor([0., 0., 0., 0., 1.])) @@ -143,49 +140,49 @@ def test_v1_5_metric_auc_auroc(): preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) target = torch.tensor([0, 0, 1, 1, 1]) - auroc.warned = False + auroc._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert auroc(preds, target) == torch.tensor(0.5) def test_v1_5_metric_precision_recall(): - AveragePrecision.__init__.warned = False + AveragePrecision.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): AveragePrecision() - Precision.__init__.warned = False + Precision.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): Precision() - Recall.__init__.warned = False + Recall.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): Recall() - PrecisionRecallCurve.__init__.warned = False + PrecisionRecallCurve.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): PrecisionRecallCurve() pred = torch.tensor([0, 1, 2, 3]) target = torch.tensor([0, 1, 1, 1]) - average_precision.warned = False + average_precision._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert average_precision(pred, target) == torch.tensor(1.) - precision.warned = False + precision._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert precision(pred, target) == torch.tensor(0.5) - recall.warned = False + recall._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert recall(pred, target) == torch.tensor(0.5) - precision_recall.warned = False + precision_recall._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): prec, rc = precision_recall(pred, target) - assert prec == torch.tensor(0.5) - assert rc == torch.tensor(0.5) + assert prec == torch.tensor(0.5) + assert rc == torch.tensor(0.5) - precision_recall_curve.warned = False + precision_recall_curve._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): prec, rc, thrs = precision_recall_curve(pred, target) assert torch.equal(prec, torch.tensor([1., 1., 1., 1.])) @@ -194,141 +191,141 @@ def test_v1_5_metric_precision_recall(): def test_v1_5_metric_classif_mix(): - ConfusionMatrix.__init__.warned = False + ConfusionMatrix.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): ConfusionMatrix(num_classes=1) - FBeta.__init__.warned = False + FBeta.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): FBeta(num_classes=1) - F1.__init__.warned = False + F1.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): F1(num_classes=1) - HammingDistance.__init__.warned = False + HammingDistance.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): HammingDistance() - StatScores.__init__.warned = False + StatScores.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): StatScores() target = torch.tensor([1, 1, 0, 0]) preds = torch.tensor([0, 1, 0, 0]) - confusion_matrix.warned = False + confusion_matrix._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert torch.equal(confusion_matrix(preds, target, num_classes=2), torch.tensor([[2., 0.], [1., 1.]])) target = torch.tensor([0, 1, 2, 0, 1, 2]) preds = torch.tensor([0, 2, 1, 0, 0, 1]) - fbeta.warned = False + fbeta._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert torch.allclose(fbeta(preds, target, num_classes=3, beta=0.5), torch.tensor(0.3333), atol=1e-4) - f1.warned = False + f1._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert torch.allclose(f1(preds, target, num_classes=3), torch.tensor(0.3333), atol=1e-4) target = torch.tensor([[0, 1], [1, 1]]) preds = torch.tensor([[0, 1], [0, 1]]) - hamming_distance.warned = False + hamming_distance._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert hamming_distance(preds, target) == torch.tensor(0.25) preds = torch.tensor([1, 0, 2, 1]) target = torch.tensor([1, 1, 2, 0]) - stat_scores.warned = False + stat_scores._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert torch.equal(stat_scores(preds, target, reduce='micro'), torch.tensor([2, 2, 6, 2, 4])) def test_v1_5_metric_detect(): - IoU.__init__.warned = False + IoU.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): IoU(num_classes=1) target = torch.randint(0, 2, (10, 25, 25)) preds = torch.tensor(target) preds[2:5, 7:13, 9:15] = 1 - preds[2:5, 7:13, 9:15] - iou.warned = False + iou._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): res = iou(preds, target) assert torch.allclose(res, torch.tensor(0.9660), atol=1e-4) def test_v1_5_metric_regress(): - ExplainedVariance.__init__.warned = False + ExplainedVariance.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): ExplainedVariance() - MeanAbsoluteError.__init__.warned = False + MeanAbsoluteError.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): MeanAbsoluteError() - MeanSquaredError.__init__.warned = False + MeanSquaredError.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): MeanSquaredError() - MeanSquaredLogError.__init__.warned = False + MeanSquaredLogError.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): MeanSquaredLogError() target = torch.tensor([3, -0.5, 2, 7]) preds = torch.tensor([2.5, 0.0, 2, 8]) - explained_variance.warned = False + explained_variance._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): res = explained_variance(preds, target) assert torch.allclose(res, torch.tensor(0.9572), atol=1e-4) x = torch.tensor([0., 1, 2, 3]) y = torch.tensor([0., 1, 2, 2]) - mean_absolute_error.warned = False + mean_absolute_error._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert mean_absolute_error(x, y) == 0.25 - mean_relative_error.warned = False + mean_relative_error._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert mean_relative_error(x, y) == 0.125 - mean_squared_error.warned = False + mean_squared_error._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert mean_squared_error(x, y) == 0.25 - mean_squared_log_error.warned = False + mean_squared_log_error._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): res = mean_squared_log_error(x, y) assert torch.allclose(res, torch.tensor(0.0207), atol=1e-4) - PSNR.__init__.warned = False + PSNR.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): PSNR() - R2Score.__init__.warned = False + R2Score.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): R2Score() - SSIM.__init__.warned = False + SSIM.__init__._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): SSIM() preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - psnr.warned = False + psnr._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): res = psnr(preds, target) assert torch.allclose(res, torch.tensor(2.5527), atol=1e-4) target = torch.tensor([3, -0.5, 2, 7]) preds = torch.tensor([2.5, 0.0, 2, 8]) - r2score.warned = False + r2score._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): res = r2score(preds, target) assert torch.allclose(res, torch.tensor(0.9486), atol=1e-4) preds = torch.rand([16, 1, 16, 16]) target = preds * 0.75 - ssim.warned = False + ssim._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): res = ssim(preds, target) assert torch.allclose(res, torch.tensor(0.9219), atol=1e-4) @@ -337,13 +334,13 @@ def test_v1_5_metric_regress(): def test_v1_5_metric_others(): translate_corpus = ['the cat is on the mat'.split()] reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] - bleu_score.warned = False + bleu_score._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): res = bleu_score(translate_corpus, reference_corpus) assert torch.allclose(res, torch.tensor(0.7598), atol=1e-4) embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]) - embedding_similarity.warned = False + embedding_similarity._warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): res = embedding_similarity(embeddings) assert torch.allclose( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 490f205a7bbec..4ca2f737f5106 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1451,6 +1451,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): def test_trainer_predict_grad(tmpdir): + class CustomBoringModel(BoringModel): def predict_step(self, batch, batch_idx, dataloader_idx=None): diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py deleted file mode 100644 index 42179f86b80ed..0000000000000 --- a/tests/utilities/test_deprecation.py +++ /dev/null @@ -1,94 +0,0 @@ -import pytest - -from pytorch_lightning.utilities.deprecation import deprecated -from tests.helpers.utils import no_warning_call - - -def my_sum(a=0, b=3): - return a + b - - -def my2_sum(a, b): - return a + b - - -@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") -def dep_sum(a, b=5): - pass - - -@deprecated(target=my2_sum, ver_deprecate="0.1", ver_remove="0.5") -def dep2_sum(a, b): - pass - - -@deprecated(target=my2_sum, ver_deprecate="0.1", ver_remove="0.5") -def dep3_sum(a, b=4): - pass - - -def test_deprecated_func(): - with pytest.deprecated_call( - match='`tests.utilities.test_deprecation.dep_sum` was deprecated since v0.1 in favor' - ' of `tests.utilities.test_deprecation.my_sum`. It will be removed in v0.5.' - ): - assert dep_sum(2) == 7 - - # check that the warning is raised only once per function - with no_warning_call(DeprecationWarning): - assert dep_sum(3) == 8 - - # and does not affect other functions - with pytest.deprecated_call( - match='`tests.utilities.test_deprecation.dep3_sum` was deprecated since v0.1 in favor' - ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.' - ): - assert dep3_sum(2, 1) == 3 - - -def test_deprecated_func_incomplete(): - - # missing required argument - with pytest.raises(TypeError, match="missing 1 required positional argument: 'b'"): - dep2_sum(2) - - # check that the warning is raised only once per function - with no_warning_call(DeprecationWarning): - assert dep2_sum(2, 1) == 3 - - # reset the warning - dep2_sum.warned = False - # does not affect other functions - with pytest.deprecated_call( - match='`tests.utilities.test_deprecation.dep2_sum` was deprecated since v0.1 in favor' - ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.' - ): - assert dep2_sum(b=2, a=1) == 3 - - -class NewCls: - - def __init__(self, c, d="abc"): - self.my_c = c - self.my_d = d - - -class PastCls: - - @deprecated(target=NewCls, ver_deprecate="0.2", ver_remove="0.4") - def __init__(self, c, d="efg"): - pass - - -def test_deprecated_class(): - with pytest.deprecated_call( - match='`tests.utilities.test_deprecation.PastCls` was deprecated since v0.2 in favor' - ' of `tests.utilities.test_deprecation.NewCls`. It will be removed in v0.4.' - ): - past = PastCls(2) - assert past.my_c == 2 - assert past.my_d == "efg" - - # check that the warning is raised only once per function - with no_warning_call(DeprecationWarning): - assert PastCls(c=2, d="")