Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix deprecation wrapper & tests #6553

Merged
merged 6 commits into from
Mar 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions pytorch_lightning/utilities/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import inspect
from functools import wraps
from typing import Any, Callable, List, Tuple
from typing import Any, Callable, List, Tuple, Optional

from pytorch_lightning.utilities import rank_zero_warn

Expand All @@ -34,37 +34,41 @@ def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]]
return name_type_default


def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable:
def deprecated(target: Callable, ver_deprecate: Optional[str] = "", ver_remove: Optional[str] = "") -> Callable:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def deprecated(target: Callable, ver_deprecate: Optional[str] = "", ver_remove: Optional[str] = "") -> Callable:
def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable:

Optional[...] references to Union[..., None]. Here, the default is "".

"""
Decorate a function or class ``__init__`` with warning message
and pass all arguments directly to the target class/method.
"""
"""

def inner_function(func):
def inner_function(base):

@wraps(func)
@wraps(base)
def wrapped_fn(*args, **kwargs):
is_class = inspect.isclass(target)
target_func = target.__init__ if is_class else target
# warn user only once in lifetime
if not getattr(inner_function, 'warned', False):
if not getattr(wrapped_fn, 'warned', False):
target_str = f'{target.__module__}.{target.__name__}'
func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__
base_name = base.__qualname__.split('.')[-2] if is_class else base.__name__
base_str = f'{base.__module__}.{base_name}'
rank_zero_warn(
f"The `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`."
f"`{base_str}` was deprecated since v{ver_deprecate} in favor of `{target_str}`."
f" It will be removed in v{ver_remove}.", DeprecationWarning
)
inner_function.warned = True
wrapped_fn.warned = True

if args: # in case any args passed move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_func_arguments_and_types(func)]
arg_names = [arg[0] for arg in get_func_arguments_and_types(base)]
# convert args to kwargs
kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
kwargs.update({k: v for k, v in zip(arg_names, args)})
# fill by base defaults
base_defaults = {arg[0]: arg[2] for arg in get_func_arguments_and_types(base) if arg[2] != inspect._empty}
kwargs = dict(list(base_defaults.items()) + list(kwargs.items()))

target_args = [arg[0] for arg in get_func_arguments_and_types(target_func)]
assert all(arg in target_args for arg in kwargs), \
"Failed mapping, arguments missing in target func: %s" % [arg not in target_args for arg in kwargs]
"Failed mapping, arguments missing in target base: %s" % [arg not in target_args for arg in kwargs]
# all args were already moved to kwargs
return target_func(**kwargs)

Expand Down
4 changes: 2 additions & 2 deletions tests/deprecated_api/test_remove_1-5_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def test_v1_5_0_metrics_collection():
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
with pytest.deprecated_call(
match="The `MetricCollection` was deprecated since v1.3.0 in favor"
" of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0"
match="`pytorch_lightning.metrics.metric.MetricCollection` was deprecated since v1.3.0 in favor"
" of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0."
):
metrics = MetricCollection([Accuracy()])
assert metrics(preds, target) == {'Accuracy': torch.Tensor([0.1250])[0]}
77 changes: 67 additions & 10 deletions tests/utilities/test_deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,91 @@
from tests.helpers.utils import no_warning_call


def my_sum(a, b=3):
def my_sum(a=0, b=3):
return a + b


def my2_sum(a, b):
return a + b


@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5")
def dep_sum(a, b):
def dep_sum(a, b=5):
pass


@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5")
@deprecated(target=my2_sum, ver_deprecate="0.1", ver_remove="0.5")
def dep2_sum(a, b):
pass


@deprecated(target=my2_sum, ver_deprecate="0.1", ver_remove="0.5")
def dep3_sum(a, b=4):
pass


def test_deprecated_func():
with pytest.deprecated_call(
match='The `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.'
' It will be removed in v0.5.'
match='`tests.utilities.test_deprecation.dep_sum` was deprecated since v0.1 in favor'
' of `tests.utilities.test_deprecation.my_sum`. It will be removed in v0.5.'
):
assert dep_sum(2, b=5) == 7
assert dep_sum(2) == 7

# check that the warning is raised only once per function
with no_warning_call(DeprecationWarning):
assert dep_sum(2, b=5) == 7
assert dep_sum(3) == 8

# and does not affect other functions
with pytest.deprecated_call(
match='The `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.'
' It will be removed in v0.5.'
match='`tests.utilities.test_deprecation.dep3_sum` was deprecated since v0.1 in favor'
' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
):
assert dep2_sum(2) == 5
assert dep3_sum(2, 1) == 3


def test_deprecated_func_incomplete():

# missing required argument
with pytest.raises(TypeError, match="missing 1 required positional argument: 'b'"):
dep2_sum(2)

# check that the warning is raised only once per function
with no_warning_call(DeprecationWarning):
assert dep2_sum(2, 1) == 3

# reset the warning
dep2_sum.warned = False
# does not affect other functions
with pytest.deprecated_call(
match='`tests.utilities.test_deprecation.dep2_sum` was deprecated since v0.1 in favor'
' of `tests.utilities.test_deprecation.my2_sum`. It will be removed in v0.5.'
):
assert dep2_sum(b=2, a=1) == 3


class NewCls:

def __init__(self, c, d="abc"):
self.my_c = c
self.my_d = d


class PastCls:

@deprecated(target=NewCls, ver_deprecate="0.2", ver_remove="0.4")
def __init__(self, c, d="efg"):
pass


def test_deprecated_class():
with pytest.deprecated_call(
match='`tests.utilities.test_deprecation.PastCls` was deprecated since v0.2 in favor'
' of `tests.utilities.test_deprecation.NewCls`. It will be removed in v0.4.'
):
past = PastCls(2)
assert past.my_c == 2
assert past.my_d == "efg"

# check that the warning is raised only once per function
with no_warning_call(DeprecationWarning):
assert PastCls(c=2, d="")