Skip to content

Commit

Permalink
Simplify deprecations (#6620)
Browse files Browse the repository at this point in the history
* use external deprecate

* simplify

* simplify

* simplify

* flake8

* .

* others

* .
  • Loading branch information
Borda authored Mar 25, 2021
1 parent 9be092d commit 217c12a
Show file tree
Hide file tree
Showing 66 changed files with 246 additions and 450 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def package_list_from_file(file):
'comet-ml': 'comet_ml',
'neptune-client': 'neptune',
'hydra-core': 'hydra',
'pyDeprecate': 'deprecate',
}
MOCK_PACKAGES = []
if SPHINX_MOCK_REQUIREMENTS:
Expand Down
18 changes: 9 additions & 9 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import yaml

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache
Expand Down Expand Up @@ -258,9 +258,9 @@ def save_checkpoint(self, trainer, unused: Optional = None):
to handle correct behaviour in distributed training, i.e., saving only on rank 0.
"""
if unused is not None:
rank_zero_warn(
rank_zero_deprecation(
"`ModelCheckpoint.save_checkpoint` signature has changed in v1.3. The `pl_module` parameter"
" has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
" has been removed. Support for the old signature will be removed in v1.5"
)

global_step = trainer.global_step
Expand Down Expand Up @@ -371,27 +371,27 @@ def __init_triggers(

# period takes precedence over every_n_val_epochs for backwards compatibility
if period is not None:
rank_zero_warn(
rank_zero_deprecation(
'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
' Please use `every_n_val_epochs` instead.'
)
self._every_n_val_epochs = period

self._period = self._every_n_val_epochs

@property
def period(self) -> Optional[int]:
rank_zero_warn(
rank_zero_deprecation(
'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
' Please use `every_n_val_epochs` instead.'
)
return self._period

@period.setter
def period(self, value: Optional[int]) -> None:
rank_zero_warn(
rank_zero_deprecation(
'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
' Please use `every_n_val_epochs` instead.'
)
self._period = value

Expand Down
7 changes: 3 additions & 4 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -1229,9 +1229,8 @@ def training_step(...):
opt_a.step()
"""
if optimizer is not None:
rank_zero_warn(
"`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4",
DeprecationWarning
rank_zero_deprecation(
"`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4"
)

# make sure we're using manual opt
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from warnings import warn

from pytorch_lightning.metrics.classification import ( # noqa: F401
Accuracy,
Expand Down Expand Up @@ -39,8 +38,9 @@
R2Score,
SSIM,
)
from pytorch_lightning.utilities import rank_zero_deprecation

warn(
rank_zero_deprecation(
"`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"
" (https://github.com/PyTorchLightning/metrics) since v1.3 and will be removed in v1.5", DeprecationWarning
" (https://github.com/PyTorchLightning/metrics) since v1.3 and will be removed in v1.5"
)
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from torchmetrics import Accuracy as _Accuracy

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class Accuracy(_Accuracy):

@deprecated(target=_Accuracy, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_Accuracy)
def __init__(
self,
threshold: float = 0.5,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from torchmetrics import AUC as _AUC

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class AUC(_AUC):

@deprecated(target=_AUC, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_AUC)
def __init__(
self,
reorder: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from torchmetrics import AUROC as _AUROC

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class AUROC(_AUROC):

@deprecated(target=_AUROC, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_AUROC)
def __init__(
self,
num_classes: Optional[int] = None,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from torchmetrics import AveragePrecision as _AveragePrecision

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class AveragePrecision(_AveragePrecision):

@deprecated(target=_AveragePrecision, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_AveragePrecision)
def __init__(
self,
num_classes: Optional[int] = None,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from torchmetrics import ConfusionMatrix as _ConfusionMatrix

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class ConfusionMatrix(_ConfusionMatrix):

@deprecated(target=_ConfusionMatrix, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_ConfusionMatrix)
def __init__(
self,
num_classes: int,
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/metrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from torchmetrics import F1 as _F1
from torchmetrics import FBeta as _FBeta

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class FBeta(_FBeta):

@deprecated(target=_FBeta, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_FBeta)
def __init__(
self,
num_classes: int,
Expand All @@ -43,7 +43,7 @@ def __init__(

class F1(_F1):

@deprecated(target=_F1, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_F1)
def __init__(
self,
num_classes: int,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification/hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from torchmetrics import HammingDistance as _HammingDistance

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class HammingDistance(_HammingDistance):

@deprecated(target=_HammingDistance, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_HammingDistance)
def __init__(
self,
threshold: float = 0.5,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from torchmetrics import IoU as _IoU

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class IoU(_IoU):

@deprecated(target=_IoU, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_IoU)
def __init__(
self,
num_classes: int,
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/metrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from torchmetrics import Precision as _Precision
from torchmetrics import Recall as _Recall

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class Precision(_Precision):

@deprecated(target=_Precision, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_Precision)
def __init__(
self,
num_classes: Optional[int] = None,
Expand All @@ -47,7 +47,7 @@ def __init__(

class Recall(_Recall):

@deprecated(target=_Recall, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_Recall)
def __init__(
self,
num_classes: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from torchmetrics import PrecisionRecallCurve as _PrecisionRecallCurve

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class PrecisionRecallCurve(_PrecisionRecallCurve):

@deprecated(target=_PrecisionRecallCurve, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_PrecisionRecallCurve)
def __init__(
self,
num_classes: Optional[int] = None,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from torchmetrics import ROC as _ROC

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class ROC(_ROC):

@deprecated(target=_ROC, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_ROC)
def __init__(
self,
num_classes: Optional[int] = None,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from torchmetrics import StatScores as _StatScores

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


class StatScores(_StatScores):

@deprecated(target=_StatScores, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_StatScores)
def __init__(
self,
threshold: float = 0.5,
Expand Down
20 changes: 8 additions & 12 deletions pytorch_lightning/metrics/compositional.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,21 @@

import torch
from torchmetrics import Metric
from torchmetrics.metric import CompositionalMetric as __CompositionalMetric
from torchmetrics.metric import CompositionalMetric as _CompositionalMetric

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.metrics.utils import deprecated_metrics


class CompositionalMetric(__CompositionalMetric):
"""
.. deprecated::
Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0.
"""
class CompositionalMetric(_CompositionalMetric):

@deprecated_metrics(target=_CompositionalMetric)
def __init__(
self,
operator: Callable,
metric_a: Union[Metric, int, float, torch.Tensor],
metric_b: Union[Metric, int, float, torch.Tensor, None],
):
rank_zero_warn(
"This `CompositionalMetric` was deprecated since v1.3.0 in favor of"
" `torchmetrics.metric.CompositionalMetric`. It will be removed in v1.5.0", DeprecationWarning
)
super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b)
"""
.. deprecated::
Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0.
"""
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import torch
from torchmetrics.functional import accuracy as _accuracy

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated(target=_accuracy, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_accuracy)
def accuracy(
preds: torch.Tensor,
target: torch.Tensor,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
import torch
from torchmetrics.functional import auc as _auc

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated(target=_auc, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_auc)
def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor:
"""
.. deprecated::
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import torch
from torchmetrics.functional import auroc as _auroc

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated(target=_auroc, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_auroc)
def auroc(
preds: torch.Tensor,
target: torch.Tensor,
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import torch
from torchmetrics.functional import average_precision as _average_precision

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.metrics.utils import deprecated_metrics


@deprecated(target=_average_precision, ver_deprecate="1.3.0", ver_remove="1.5.0")
@deprecated_metrics(target=_average_precision)
def average_precision(
preds: torch.Tensor,
target: torch.Tensor,
Expand Down
Loading

0 comments on commit 217c12a

Please sign in to comment.