Skip to content

Commit

Permalink
prune warning & deprecation wrapper (#6540)
Browse files Browse the repository at this point in the history
* docs

* wrapper

* test

* count

* flake8

(cherry picked from commit 555a6fe)
  • Loading branch information
Borda committed Mar 23, 2021
1 parent 3566171 commit b762fb4
Show file tree
Hide file tree
Showing 26 changed files with 160 additions and 45 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

import torch
from torch import Tensor
from torchmetrics import Metric

from pytorch_lightning.metrics import Metric
from pytorch_lightning.utilities.distributed import sync_ddp_if_available


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.accuracy import _accuracy_compute, _accuracy_update
from pytorch_lightning.metrics.metric import Metric


class Accuracy(Metric):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.auc import _auc_compute, _auc_update
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, List, Optional, Union

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.average_precision import _average_precision_compute, _average_precision_update
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update
from pytorch_lightning.metrics.metric import Metric


class ConfusionMatrix(Metric):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_compute, _hamming_distance_update
from pytorch_lightning.metrics.metric import Metric


class HammingDistance(Metric):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from typing import Any, List, Optional, Tuple, Union

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.precision_recall_curve import (
_precision_recall_curve_compute,
_precision_recall_curve_update,
)
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, List, Optional, Tuple, Union

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.roc import _roc_compute, _roc_update
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Callable, Optional, Tuple

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_compute, _stat_scores_update
from pytorch_lightning.metrics.metric import Metric


class StatScores(Metric):
Expand Down
11 changes: 5 additions & 6 deletions pytorch_lightning/metrics/compositional.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@


class CompositionalMetric(__CompositionalMetric):
r"""
This implementation refers to :class:`~torchmetrics.metric.CompositionalMetric`.
.. warning:: This metric is deprecated, use ``torchmetrics.metric.CompositionalMetric``. Will be removed in v1.5.0.
"""
.. deprecated::
Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0.
"""

def __init__(
Expand All @@ -34,7 +33,7 @@ def __init__(
metric_b: Union[Metric, int, float, torch.Tensor, None],
):
rank_zero_warn(
"This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`."
" It will be removed in v1.5.0", DeprecationWarning
"This `CompositionalMetric` was deprecated since v1.3.0 in favor of"
" `torchmetrics.metric.CompositionalMetric`. It will be removed in v1.5.0", DeprecationWarning
)
super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b)
28 changes: 12 additions & 16 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from torchmetrics import Metric as __Metric
from torchmetrics import MetricCollection as __MetricCollection
from torchmetrics import Metric as _Metric
from torchmetrics.collections import MetricCollection as _MetricCollection

from pytorch_lightning.utilities.deprecation import deprecated
from pytorch_lightning.utilities.distributed import rank_zero_warn


class Metric(__Metric):
class Metric(_Metric):
r"""
This implementation refers to :class:`~torchmetrics.Metric`.
.. warning:: This metric is deprecated, use ``torchmetrics.Metric``. Will be removed in v1.5.0.
.. deprecated::
Use :class:`torchmetrics.Metric`. Will be removed in v1.5.0.
"""

def __init__(
Expand All @@ -45,16 +45,12 @@ def __init__(
)


class MetricCollection(__MetricCollection):
r"""
This implementation refers to :class:`~torchmetrics.MetricCollection`.
.. warning:: This metric is deprecated, use ``torchmetrics.MetricCollection``. Will be removed in v1.5.0.
class MetricCollection(_MetricCollection):
"""
.. deprecated::
Use :class:`torchmetrics.MetricCollection`. Will be removed in v1.5.0.
"""

@deprecated(target=_MetricCollection, ver_deprecate="1.3.0", ver_remove="1.5.0")
def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]):
rank_zero_warn(
"This `MetricCollection` was deprecated since v1.3.0 in favor of `torchmetrics.MetricCollection`."
" It will be removed in v1.5.0", DeprecationWarning
)
super().__init__(metrics=metrics)
pass
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/regression/explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.explained_variance import (
_explained_variance_compute,
_explained_variance_update,
)
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.mean_absolute_error import (
_mean_absolute_error_compute,
_mean_absolute_error_update,
)
from pytorch_lightning.metrics.metric import Metric


class MeanAbsoluteError(Metric):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/regression/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.mean_squared_error import (
_mean_squared_error_compute,
_mean_squared_error_update,
)
from pytorch_lightning.metrics.metric import Metric


class MeanSquaredError(Metric):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.mean_squared_log_error import (
_mean_squared_log_error_compute,
_mean_squared_log_error_update,
)
from pytorch_lightning.metrics.metric import Metric


class MeanSquaredLogError(Metric):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/regression/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from typing import Any, Optional, Sequence, Tuple, Union

import torch
from torchmetrics import Metric

from pytorch_lightning import utilities
from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update
from pytorch_lightning.metrics.metric import Metric


class PSNR(Metric):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/regression/r2score.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.r2score import _r2score_compute, _r2score_update
from pytorch_lightning.metrics.metric import Metric


class R2Score(Metric):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/regression/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Optional, Sequence

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.functional.ssim import _ssim_compute, _ssim_update
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from typing import Any

import torch

from pytorch_lightning.metrics.metric import Metric
from torchmetrics import Metric


class MetricsHolder:
Expand Down
73 changes: 73 additions & 0 deletions pytorch_lightning/utilities/deprecation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import wraps
from typing import Any, Callable, List, Tuple

from pytorch_lightning.utilities import rank_zero_warn


def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]]:
"""Parse function arguments, types and default values
Example:
>>> get_func_arguments_and_types(get_func_arguments_and_types)
[('func', typing.Callable, <class 'inspect._empty'>)]
"""
func_default_params = inspect.signature(func).parameters
name_type_default = []
for arg in func_default_params:
arg_type = func_default_params[arg].annotation
arg_default = func_default_params[arg].default
name_type_default.append((arg, arg_type, arg_default))
return name_type_default


def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable:
"""
Decorate a function or class ``__init__`` with warning message
and pass all arguments directly to the target class/method.
"""

def inner_function(func):

@wraps(func)
def wrapped_fn(*args, **kwargs):
is_class = inspect.isclass(target)
target_func = target.__init__ if is_class else target
# warn user only once in lifetime
if not getattr(inner_function, 'warned', False):
target_str = f'{target.__module__}.{target.__name__}'
func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__
rank_zero_warn(
f"The `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`."
f" It will be removed in v{ver_remove}.", DeprecationWarning
)
inner_function.warned = True

if args: # in case any args passed move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_func_arguments_and_types(func)]
# convert args to kwargs
kwargs.update({k: v for k, v in zip(cls_arg_names, args)})

target_args = [arg[0] for arg in get_func_arguments_and_types(target_func)]
assert all(arg in target_args for arg in kwargs), \
"Failed mapping, arguments missing in target func: %s" % [arg not in target_args for arg in kwargs]
# all args were already moved to kwargs
return target_func(**kwargs)

return wrapped_fn

return inner_function
12 changes: 12 additions & 0 deletions tests/deprecated_api/test_remove_1-5_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest
import torch

from pytorch_lightning.metrics import Accuracy, MetricCollection
from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot


Expand All @@ -34,3 +35,14 @@ def test_v1_5_0_metrics_utils():
x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
with pytest.deprecated_call(match="It will be removed in v1.5.0"):
assert torch.equal(to_categorical(x), torch.Tensor([1, 0]).to(int))


def test_v1_5_0_metrics_collection():
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
with pytest.deprecated_call(
match="The `MetricCollection` was deprecated since v1.3.0 in favor"
" of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0"
):
metrics = MetricCollection([Accuracy()])
assert metrics(preds, target) == {'Accuracy': torch.Tensor([0.1250])[0]}
2 changes: 1 addition & 1 deletion tests/metrics/test_metric_lightning.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torchmetrics import Metric, MetricCollection

from pytorch_lightning import Trainer
from pytorch_lightning.metrics import Metric, MetricCollection
from tests.helpers.boring_model import BoringModel


Expand Down
Loading

0 comments on commit b762fb4

Please sign in to comment.