From 42fb68bcb001b2cfdee01bebf3a9bdde87bcc038 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 22:50:18 +0100 Subject: [PATCH 1/6] fix deprecation wrapper & tests --- pytorch_lightning/utilities/deprecation.py | 23 ++++--- tests/utilities/test_deprecation.py | 72 +++++++++++++++++++--- 2 files changed, 78 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py index 3e2034c6a0453..67f2f6bc248e4 100644 --- a/pytorch_lightning/utilities/deprecation.py +++ b/pytorch_lightning/utilities/deprecation.py @@ -38,33 +38,36 @@ def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") """ Decorate a function or class ``__init__`` with warning message and pass all arguments directly to the target class/method. - """ + """ - def inner_function(func): + def inner_function(base): - @wraps(func) + @wraps(base) def wrapped_fn(*args, **kwargs): is_class = inspect.isclass(target) target_func = target.__init__ if is_class else target # warn user only once in lifetime - if not getattr(inner_function, 'warned', False): + if not getattr(wrapped_fn, 'warned', False): target_str = f'{target.__module__}.{target.__name__}' - func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__ + base_name = base.__qualname__.split('.')[-2] if is_class else base.__name__ rank_zero_warn( - f"The `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." + f"The `{base_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." f" It will be removed in v{ver_remove}.", DeprecationWarning ) - inner_function.warned = True + wrapped_fn.warned = True if args: # in case any args passed move them to kwargs # parse only the argument names - cls_arg_names = [arg[0] for arg in get_func_arguments_and_types(func)] + arg_names = [arg[0] for arg in get_func_arguments_and_types(base)] # convert args to kwargs - kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) + kwargs.update({k: v for k, v in zip(arg_names, args)}) + # fill by base defaults + base_defaults = {arg[0]: arg[2] for arg in get_func_arguments_and_types(base) if arg[2] != inspect._empty} + kwargs = dict(list(base_defaults.items()) + list(kwargs.items())) target_args = [arg[0] for arg in get_func_arguments_and_types(target_func)] assert all(arg in target_args for arg in kwargs), \ - "Failed mapping, arguments missing in target func: %s" % [arg not in target_args for arg in kwargs] + "Failed mapping, arguments missing in target base: %s" % [arg not in target_args for arg in kwargs] # all args were already moved to kwargs return target_func(**kwargs) diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index 7c653c07ad168..58da7cd2f9bcc 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -4,34 +4,92 @@ from tests.helpers.utils import no_warning_call -def my_sum(a, b=3): +def my_sum(a=0, b=3): + return a + b + + +def my2_sum(a, b): return a + b @deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") -def dep_sum(a, b): +def dep_sum(a, b=5): pass -@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") +@deprecated(target=my2_sum, ver_deprecate="0.1", ver_remove="0.5") def dep2_sum(a, b): pass +@deprecated(target=my2_sum, ver_deprecate="0.1", ver_remove="0.5") +def dep3_sum(a, b=4): + pass + + def test_deprecated_func(): with pytest.deprecated_call( match='The `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' ' It will be removed in v0.5.' ): - assert dep_sum(2, b=5) == 7 + assert dep_sum(2) == 7 # check that the warning is raised only once per function with no_warning_call(DeprecationWarning): - assert dep_sum(2, b=5) == 7 + assert dep_sum(3) == 8 # and does not affect other functions with pytest.deprecated_call( - match='The `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' + match='The `dep3_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my2_sum`.' ' It will be removed in v0.5.' ): - assert dep2_sum(2) == 5 + assert dep3_sum(2, 1) == 3 + + +def test_deprecated_func_incomplete(): + + # missing required argument + with pytest.raises(TypeError, match="missing 1 required positional argument: 'b'"): + dep2_sum(2) + + + # check that the warning is raised only once per function + with no_warning_call(DeprecationWarning): + assert dep2_sum(2, 1) == 3 + + # reset the warning + dep2_sum.warned = False + # does not affect other functions + with pytest.deprecated_call( + match='The `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my2_sum`.' + ' It will be removed in v0.5.' + ): + assert dep2_sum(b=2, a=1) == 3 + + +class NewCls: + + def __init__(self, c, d="abc"): + self.my_c = c + self.my_d = d + + +class PastCls: + + @deprecated(target=NewCls, ver_deprecate="0.2", ver_remove="0.4") + def __init__(self, c, d="efg"): + pass + + +def test_deprecated_class(): + with pytest.deprecated_call( + match='The `PastCls` was deprecated since v0.2 in favor of `tests.utilities.test_deprecation.NewCls`.' + ' It will be removed in v0.4.' + ): + past = PastCls(2) + assert past.my_c == 2 + assert past.my_d == "efg" + + # check that the warning is raised only once per function + with no_warning_call(DeprecationWarning): + assert PastCls(c=2, d="") From db285a05f04450267ed910a194272f4ffb65f595 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 22:55:22 +0100 Subject: [PATCH 2/6] flake8 --- tests/utilities/test_deprecation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index 58da7cd2f9bcc..d3055926f0f71 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -52,7 +52,6 @@ def test_deprecated_func_incomplete(): with pytest.raises(TypeError, match="missing 1 required positional argument: 'b'"): dep2_sum(2) - # check that the warning is raised only once per function with no_warning_call(DeprecationWarning): assert dep2_sum(2, 1) == 3 @@ -83,8 +82,8 @@ def __init__(self, c, d="efg"): def test_deprecated_class(): with pytest.deprecated_call( - match='The `PastCls` was deprecated since v0.2 in favor of `tests.utilities.test_deprecation.NewCls`.' - ' It will be removed in v0.4.' + match='The `PastCls` was deprecated since v0.2 in favor of `tests.utilities.test_deprecation.NewCls`.' + ' It will be removed in v0.4.' ): past = PastCls(2) assert past.my_c == 2 From cc98c0205f4a7a2a779b822ea5ca1ce11a51934c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 23:21:05 +0100 Subject: [PATCH 3/6] tuping --- pytorch_lightning/utilities/deprecation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py index 67f2f6bc248e4..b7395af643665 100644 --- a/pytorch_lightning/utilities/deprecation.py +++ b/pytorch_lightning/utilities/deprecation.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect from functools import wraps -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Tuple, Optional from pytorch_lightning.utilities import rank_zero_warn @@ -34,7 +34,7 @@ def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]] return name_type_default -def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable: +def deprecated(target: Callable, ver_deprecate: Optional[str] = "", ver_remove: Optional[str] = "") -> Callable: """ Decorate a function or class ``__init__`` with warning message and pass all arguments directly to the target class/method. From 167141f3278458967fb0bcf4b9b330c30bd3250e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 23:56:14 +0100 Subject: [PATCH 4/6] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/utilities/deprecation.py | 2 +- tests/utilities/test_deprecation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py index b7395af643665..6bcac52b9d686 100644 --- a/pytorch_lightning/utilities/deprecation.py +++ b/pytorch_lightning/utilities/deprecation.py @@ -51,7 +51,7 @@ def wrapped_fn(*args, **kwargs): target_str = f'{target.__module__}.{target.__name__}' base_name = base.__qualname__.split('.')[-2] if is_class else base.__name__ rank_zero_warn( - f"The `{base_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." + f"`{base_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." f" It will be removed in v{ver_remove}.", DeprecationWarning ) wrapped_fn.warned = True diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index d3055926f0f71..2b05468a1b071 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -82,7 +82,7 @@ def __init__(self, c, d="efg"): def test_deprecated_class(): with pytest.deprecated_call( - match='The `PastCls` was deprecated since v0.2 in favor of `tests.utilities.test_deprecation.NewCls`.' + match='`test.utilites.test_deprecation.PastCls` was deprecated since v0.2 in favor of `tests.utilities.test_deprecation.NewCls`.' ' It will be removed in v0.4.' ): past = PastCls(2) From aa78d207dc7aadc9d782f60f98b62e9046384888 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Mar 2021 00:00:19 +0100 Subject: [PATCH 5/6] fix --- pytorch_lightning/utilities/deprecation.py | 3 ++- tests/utilities/test_deprecation.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py index 6bcac52b9d686..4460e5d070b10 100644 --- a/pytorch_lightning/utilities/deprecation.py +++ b/pytorch_lightning/utilities/deprecation.py @@ -50,8 +50,9 @@ def wrapped_fn(*args, **kwargs): if not getattr(wrapped_fn, 'warned', False): target_str = f'{target.__module__}.{target.__name__}' base_name = base.__qualname__.split('.')[-2] if is_class else base.__name__ + base_str = f'{base.__module__}.{base_name}' rank_zero_warn( - f"`{base_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." + f"`{base_str}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." f" It will be removed in v{ver_remove}.", DeprecationWarning ) wrapped_fn.warned = True diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py index 2b05468a1b071..2f54a77099701 100644 --- a/tests/utilities/test_deprecation.py +++ b/tests/utilities/test_deprecation.py @@ -29,8 +29,8 @@ def dep3_sum(a, b=4): def test_deprecated_func(): with pytest.deprecated_call( - match='The `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' - ' It will be removed in v0.5.' + match='`tests.utilities.test_deprecation.dep_sum` was deprecated since v0.1 in favor' + ' of `tests.utilities.test_deprecation.my_sum`. It will be removed in v0.5.' ): assert dep_sum(2) == 7 @@ -40,8 +40,8 @@ def test_deprecated_func(): # and does not affect other functions with pytest.deprecated_call( - match='The `dep3_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my2_sum`.' - ' It will be removed in v0.5.' + match='`tests.utilities.test_deprecation.dep3_sum` was deprecated since v0.1 in favor' + ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.' ): assert dep3_sum(2, 1) == 3 @@ -60,8 +60,8 @@ def test_deprecated_func_incomplete(): dep2_sum.warned = False # does not affect other functions with pytest.deprecated_call( - match='The `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my2_sum`.' - ' It will be removed in v0.5.' + match='`tests.utilities.test_deprecation.dep2_sum` was deprecated since v0.1 in favor' + ' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.' ): assert dep2_sum(b=2, a=1) == 3 @@ -82,8 +82,8 @@ def __init__(self, c, d="efg"): def test_deprecated_class(): with pytest.deprecated_call( - match='`test.utilites.test_deprecation.PastCls` was deprecated since v0.2 in favor of `tests.utilities.test_deprecation.NewCls`.' - ' It will be removed in v0.4.' + match='`tests.utilities.test_deprecation.PastCls` was deprecated since v0.2 in favor' + ' of `tests.utilities.test_deprecation.NewCls`. It will be removed in v0.4.' ): past = PastCls(2) assert past.my_c == 2 From 5afee0ed4809d48facc553c8deddd6ae90726f34 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Mar 2021 00:58:59 +0100 Subject: [PATCH 6/6] fix --- tests/deprecated_api/test_remove_1-5_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index 7c8c9ad296416..3428c0b761e93 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -41,8 +41,8 @@ def test_v1_5_0_metrics_collection(): target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) with pytest.deprecated_call( - match="The `MetricCollection` was deprecated since v1.3.0 in favor" - " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0" + match="`pytorch_lightning.metrics.metric.MetricCollection` was deprecated since v1.3.0 in favor" + " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0." ): metrics = MetricCollection([Accuracy()]) assert metrics(preds, target) == {'Accuracy': torch.Tensor([0.1250])[0]}