diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 0faf0b2662b..566ec1c2a34 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -400,23 +400,36 @@ def deprecated_parameter( _validate_deadline(deadline) def decorator(func: Callable) -> Callable: + def deprecation_warning(): + qualname = func.__qualname__ if func_name is None else func_name + _warn_or_error( + f'The {parameter_desc} parameter of {qualname} was ' + f'used but is deprecated.\n' + f'It will be removed in cirq {deadline}.\n' + f'{fix}\n' + ) + @functools.wraps(func) def decorated_func(*args, **kwargs) -> Any: if match(args, kwargs): if rewrite is not None: args, kwargs = rewrite(args, kwargs) + deprecation_warning() + return func(*args, **kwargs) - qualname = func.__qualname__ if func_name is None else func_name - _warn_or_error( - f'The {parameter_desc} parameter of {qualname} was ' - f'used but is deprecated.\n' - f'It will be removed in cirq {deadline}.\n' - f'{fix}\n' - ) + @functools.wraps(func) + async def async_decorated_func(*args, **kwargs) -> Any: + if match(args, kwargs): + if rewrite is not None: + args, kwargs = rewrite(args, kwargs) + deprecation_warning() - return func(*args, **kwargs) + return await func(*args, **kwargs) - return decorated_func + if inspect.iscoroutinefunction(func): + return async_decorated_func + else: + return decorated_func return decorator @@ -436,13 +449,12 @@ def deprecate_attributes(module_name: str, deprecated_attributes: Dict[str, Tupl will cause a warning for these deprecated attributes. """ - for (deadline, _) in deprecated_attributes.values(): + for deadline, _ in deprecated_attributes.values(): _validate_deadline(deadline) module = sys.modules[module_name] class Wrapped(ModuleType): - __dict__ = module.__dict__ # Workaround for: https://github.com/python/mypy/issues/8083 diff --git a/cirq-core/cirq/_compat_test.py b/cirq-core/cirq/_compat_test.py index 122affd6537..c5dba50c975 100644 --- a/cirq-core/cirq/_compat_test.py +++ b/cirq-core/cirq/_compat_test.py @@ -14,6 +14,7 @@ import collections import dataclasses import importlib.metadata +import inspect import logging import multiprocessing import os @@ -26,7 +27,7 @@ from importlib.machinery import ModuleSpec from unittest import mock - +import duet import numpy as np import pandas as pd import pytest @@ -263,6 +264,40 @@ def f_with_badly_deprecated_param(new_count): # pragma: no cover # pylint: enable=unused-variable +@duet.sync +async def test_deprecated_parameter_async_function(): + @deprecated_parameter( + deadline='v1.2', + fix='Double it yourself.', + func_name='test_func', + parameter_desc='double_count', + match=lambda args, kwargs: 'double_count' in kwargs, + rewrite=lambda args, kwargs: (args, {'new_count': kwargs['double_count'] * 2}), + ) + async def f(new_count): + return new_count + + assert inspect.iscoroutinefunction(f) + + # Does not warn on usual use. + with cirq.testing.assert_logs(count=0): + assert await f(1) == 1 + assert await f(new_count=1) == 1 + + with cirq.testing.assert_deprecated( + '_compat_test.py:', + 'double_count parameter of test_func was used', + 'will be removed in cirq v1.2', + 'Double it yourself.', + deadline='v1.2', + ): + # pylint: disable=unexpected-keyword-arg + # pylint: disable=no-value-for-parameter + assert await f(double_count=1) == 2 + # pylint: enable=no-value-for-parameter + # pylint: enable=unexpected-keyword-arg + + def test_wrap_module(): my_module = types.ModuleType('my_module', 'my doc string') my_module.foo = 'foo'