Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix deprecation wrapper & tests #6553

Merged
merged 6 commits into from
Mar 17, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions pytorch_lightning/utilities/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,33 +38,36 @@ def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "")
"""
Decorate a function or class ``__init__`` with warning message
and pass all arguments directly to the target class/method.
"""
"""

def inner_function(func):
def inner_function(base):

@wraps(func)
@wraps(base)
def wrapped_fn(*args, **kwargs):
is_class = inspect.isclass(target)
target_func = target.__init__ if is_class else target
# warn user only once in lifetime
if not getattr(inner_function, 'warned', False):
if not getattr(wrapped_fn, 'warned', False):
target_str = f'{target.__module__}.{target.__name__}'
func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__
base_name = base.__qualname__.split('.')[-2] if is_class else base.__name__
rank_zero_warn(
f"The `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`."
f"The `{base_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`."
Borda marked this conversation as resolved.
Show resolved Hide resolved
f" It will be removed in v{ver_remove}.", DeprecationWarning
)
inner_function.warned = True
wrapped_fn.warned = True

if args: # in case any args passed move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_func_arguments_and_types(func)]
arg_names = [arg[0] for arg in get_func_arguments_and_types(base)]
# convert args to kwargs
kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
kwargs.update({k: v for k, v in zip(arg_names, args)})
# fill by base defaults
base_defaults = {arg[0]: arg[2] for arg in get_func_arguments_and_types(base) if arg[2] != inspect._empty}
kwargs = dict(list(base_defaults.items()) + list(kwargs.items()))

target_args = [arg[0] for arg in get_func_arguments_and_types(target_func)]
assert all(arg in target_args for arg in kwargs), \
"Failed mapping, arguments missing in target func: %s" % [arg not in target_args for arg in kwargs]
"Failed mapping, arguments missing in target base: %s" % [arg not in target_args for arg in kwargs]
# all args were already moved to kwargs
return target_func(**kwargs)

Expand Down
71 changes: 64 additions & 7 deletions tests/utilities/test_deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,91 @@
from tests.helpers.utils import no_warning_call


def my_sum(a, b=3):
def my_sum(a=0, b=3):
return a + b


def my2_sum(a, b):
return a + b


@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5")
def dep_sum(a, b):
def dep_sum(a, b=5):
pass


@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5")
@deprecated(target=my2_sum, ver_deprecate="0.1", ver_remove="0.5")
def dep2_sum(a, b):
pass


@deprecated(target=my2_sum, ver_deprecate="0.1", ver_remove="0.5")
def dep3_sum(a, b=4):
pass


def test_deprecated_func():
with pytest.deprecated_call(
match='The `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.'
' It will be removed in v0.5.'
):
assert dep_sum(2, b=5) == 7
assert dep_sum(2) == 7

# check that the warning is raised only once per function
with no_warning_call(DeprecationWarning):
assert dep_sum(2, b=5) == 7
assert dep_sum(3) == 8

# and does not affect other functions
with pytest.deprecated_call(
match='The `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.'
match='The `dep3_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my2_sum`.'
' It will be removed in v0.5.'
):
assert dep3_sum(2, 1) == 3


def test_deprecated_func_incomplete():

# missing required argument
with pytest.raises(TypeError, match="missing 1 required positional argument: 'b'"):
dep2_sum(2)

# check that the warning is raised only once per function
with no_warning_call(DeprecationWarning):
assert dep2_sum(2, 1) == 3

# reset the warning
dep2_sum.warned = False
# does not affect other functions
with pytest.deprecated_call(
match='The `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my2_sum`.'
' It will be removed in v0.5.'
):
assert dep2_sum(2) == 5
assert dep2_sum(b=2, a=1) == 3


class NewCls:

def __init__(self, c, d="abc"):
self.my_c = c
self.my_d = d


class PastCls:

@deprecated(target=NewCls, ver_deprecate="0.2", ver_remove="0.4")
def __init__(self, c, d="efg"):
pass


def test_deprecated_class():
with pytest.deprecated_call(
match='The `PastCls` was deprecated since v0.2 in favor of `tests.utilities.test_deprecation.NewCls`.'
' It will be removed in v0.4.'
):
Borda marked this conversation as resolved.
Show resolved Hide resolved
past = PastCls(2)
assert past.my_c == 2
assert past.my_d == "efg"

# check that the warning is raised only once per function
with no_warning_call(DeprecationWarning):
assert PastCls(c=2, d="")