diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 74fe883..596e69a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -61,6 +61,7 @@ jobs: if: matrix.anyio-version == 'anyio-v4' run: pip install --upgrade "anyio>=4.0.0,<5.0" - name: Lint + if: matrix.anyio-version == 'anyio-v4' run: bash scripts/lint.sh - run: mkdir coverage - name: Test diff --git a/asyncer/_compat.py b/asyncer/_compat.py new file mode 100644 index 0000000..ade175e --- /dev/null +++ b/asyncer/_compat.py @@ -0,0 +1,37 @@ +# AnyIO 4.1.0 renamed cancellable to abandon_on_cancel +import importlib +import importlib.metadata +from typing import Callable, TypeVar, Union + +import anyio +import anyio.to_thread +from anyio import CapacityLimiter +from typing_extensions import TypeVarTuple, Unpack + +ANYIO_VERSION = importlib.metadata.version("anyio") + +T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + +if ANYIO_VERSION >= "4.1.0": + + async def run_sync( + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], + abandon_on_cancel: bool = False, + limiter: Union[CapacityLimiter, None] = None, + ) -> T_Retval: + return await anyio.to_thread.run_sync( + func, *args, abandon_on_cancel=abandon_on_cancel, limiter=limiter + ) +else: + + async def run_sync( + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], + abandon_on_cancel: bool = False, + limiter: Union[CapacityLimiter, None] = None, + ) -> T_Retval: + return await anyio.to_thread.run_sync( + func, *args, cancellable=abandon_on_cancel, limiter=limiter + ) diff --git a/asyncer/_main.py b/asyncer/_main.py index 4eeb5eb..78117e4 100644 --- a/asyncer/_main.py +++ b/asyncer/_main.py @@ -12,6 +12,9 @@ TypeVar, Union, ) +from warnings import warn + +from asyncer._compat import run_sync if sys.version_info >= (3, 10): from typing import ParamSpec @@ -319,7 +322,8 @@ def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval: def asyncify( function: Callable[T_ParamSpec, T_Retval], *, - cancellable: bool = False, + abandon_on_cancel: bool = False, + cancellable: Union[bool, None] = None, limiter: Optional[anyio.CapacityLimiter] = None, ) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: """ @@ -359,14 +363,24 @@ def do_work(arg1, arg2, kwarg1="", kwarg2="") -> str: original one, that when called runs the same original function in a thread worker and returns the result. """ + if cancellable is not None: + abandon_on_cancel = cancellable + warn( + "The `cancellable=` keyword argument to `asyncer.asyncify()` is " + "deprecated since Asyncer 0.0.8, following AnyIO 4.1.0. " + "Use `abandon_on_cancel=` instead.", + DeprecationWarning, + stacklevel=2, + ) @functools.wraps(function) async def wrapper( *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs ) -> T_Retval: partial_f = functools.partial(function, *args, **kwargs) - return await anyio.to_thread.run_sync( - partial_f, cancellable=cancellable, limiter=limiter + + return await run_sync( + partial_f, abandon_on_cancel=abandon_on_cancel, limiter=limiter ) return wrapper diff --git a/tests/test_param_cancellable.py b/tests/test_param_cancellable.py new file mode 100644 index 0000000..4c2b0d4 --- /dev/null +++ b/tests/test_param_cancellable.py @@ -0,0 +1,37 @@ +import warnings + +import anyio +import pytest +from asyncer import asyncify + + +def test_cancellable_warns(): + def do_async_work(): + return "Hello World!" + + async def main(): + result = await asyncify(do_async_work, cancellable=True)() + return result + + with pytest.warns(DeprecationWarning) as record: + result = anyio.run(main) + assert isinstance(record[0].message, Warning) + assert ( + "The `cancellable=` keyword argument to `asyncer.asyncify()` is " + "deprecated since Asyncer 0.0.8" in record[0].message.args[0] + ) + assert result == "Hello World!" + + +def test_abandon_on_cancel_no(): + def do_async_work(): + return "Hello World!" + + async def main(): + result = await asyncify(do_async_work, abandon_on_cancel=True)() + return result + + with warnings.catch_warnings(): + warnings.simplefilter("error") + result = anyio.run(main) + assert result == "Hello World!"