-
I've written a decorator that accepts a coroutine function returning an object which supports being used as an async context manager, and lets me use "async with" on the call directly, without having to manually await it first. So, instead of writing:
conn = await connect(host)
async with conn:
    # do something here with conn
I can directly write:
async with connect(host) as conn:
    # do something here with conn
The type returned by awaiting connect() is an async context manager (with __aenter__ and __aexit__ methods). In fact, the decorator allows either syntax to be used on the wrapped coroutine. This is working great, but I noticed that when I feed this through mypy, it complains that it can't find …
Here's the code I've been using, with some very rudimentary type hints which at least let mypy type-check the wrapped functions as if they weren't decorated. Any thoughts on getting rid of the errors above, though? Is there some way I can easily have the decorator add the necessary annotations for the new methods?
DecoratedFunc = TypeVar('DecoratedFunc', bound=Callable)
def async_context_manager(coro: DecoratedFunc) -> DecoratedFunc:
    """Decorator for functions returning asynchronous context managers

    This decorator can be used on functions which return objects
    intended to be async context managers. The object returned by
    the function should implement __aenter__ and __aexit__ methods
    to run when the async context is entered and exited.

    This wrapper also allows the use of "await" on the function being
    decorated, to return the context manager without entering it.
    """

    class AsyncContextManager:
        """Async context manager wrapper around a not-yet-awaited coroutine."""

        def __init__(self, coro):
            self._coro = coro
            # Context manager obtained by awaiting _coro; set in __aenter__
            # and cleared again in __aexit__.
            self._result = None

        def __await__(self):
            # "await wrapper" returns the underlying context manager
            # without entering it.
            return self._coro.__await__()

        async def __aenter__(self):
            self._result = await self._coro
            return await self._result.__aenter__()

        async def __aexit__(self, *exc_info):
            if self._result is None:
                raise RuntimeError(
                    '__aexit__ called without a matching __aenter__')
            try:
                # Forward the wrapped manager's return value: if it returns
                # a true value to suppress an exception, that must be honored
                # rather than overridden with False.
                return await self._result.__aexit__(*exc_info)
            finally:
                self._result = None

    @functools.wraps(coro)
    def context_wrapper(*args, **kwargs):
        """Return an async context manager wrapper for this coroutine"""
        return AsyncContextManager(coro(*args, **kwargs))

    return cast(DecoratedFunc, context_wrapper)
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 8 replies
-
I've made some progress on this. The latest version I have is:
import asyncio
import functools
from types import TracebackType
from typing import Any, AsyncContextManager, Awaitable, Callable, Generator
from typing import Generic, Optional, Type, TypeVar, cast
_ACM = TypeVar('_ACM', bound=AsyncContextManager)
class _ACMWrapper(Generic[_ACM]):
"""Async context manager wrapper"""
def __init__(self, coro: Awaitable[_ACM]):
self._coro = coro
self._coro_result: Optional[_ACM] = None
def __await__(self) -> Generator[Any, None, _ACM]:
return self._coro.__await__()
async def __aenter__(self) -> _ACM:
self._coro_result = await self._coro
return await self._coro_result.__aenter__()
async def __aexit__(self, exc_type: Type[BaseException],
exc_value: BaseException,
traceback: TracebackType) -> bool:
exit_result = await self._coro_result.__aexit__(
exc_type, exc_value, traceback)
self._coro_result = None
return exit_result
# Generic aliases parameterized by _ACM: a callable taking any arguments that
# returns an awaitable of _ACM, and the decorated form that returns an
# _ACMWrapper[_ACM] instead.
_ACMCoroFunc = Callable[..., Awaitable[_ACM]]
_ACMWrapperFunc = Callable[..., _ACMWrapper[_ACM]]
def async_context_manager(coro: _ACMCoroFunc[_ACM]) -> _ACMWrapperFunc[_ACM]:
"""Decorator for functions returning asynchronous context managers
This decorator can be used on functions which return objects
intended to be async context managers. The object returned by
the function should implement __aenter__ and __aexit__ methods
to run when the async context is entered and exited.
This wrapper also allows the use of "await" on the function being
decorated, to return the context manager without entering it.
Because the aliases use Callable[..., ...], a type checker does not
check the arguments passed to the decorated function.
"""
# functools.wraps copies the wrapped function's metadata (name, docstring,
# etc.) onto context_wrapper.
@functools.wraps(coro)
def context_wrapper(*args, **kwargs) -> _ACMWrapper[_ACM]:
"""Return an async context manager wrapper for this coroutine"""
# coro(...) is expected to return an awaitable (per _ACMCoroFunc); it is not
# awaited here -- _ACMWrapper awaits it later in __await__ or __aenter__.
return _ACMWrapper(coro(*args, **kwargs))
# cast() only informs the type checker; at runtime it returns its argument unchanged.
return cast(_ACMWrapperFunc[_ACM], context_wrapper) This avoids the errors above related to |
Beta Was this translation helpful? Give feedback.
-
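I got a chance to look at this again today, and found the issue! It turns out that the definition of __aexit__ in both _Wrapper and MyCtxMgr was slightly off: it was missing "Optional" on some of the arguments. That made those classes incompatible with the bound on R, which requires them to match the AsyncContextManager protocol. Here's a corrected version:
import asyncio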
import functools
from types import TracebackType
from typing import Any, AsyncContextManager, Awaitable, Callable, Generator
from typing import Generic, Optional, ParamSpec, Type, TypeVar
R = TypeVar('R', bound=AsyncContextManager)
class _Wrapper(Generic[R]):
def __init__(self, coro: Awaitable[R]):
self._coro = coro
self._coro_result: Optional[R] = None
def __await__(self) -> Generator[Any, None, R]:
return self._coro.__await__()
async def __aenter__(self) -> R:
self._coro_result = await self._coro
return await self._coro_result.__aenter__()
async def __aexit__(self, exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType]) -> Optional[bool]:
assert self._coro_result is not None
exit_result = await self._coro_result.__aexit__(
exc_type, exc_value, traceback)
self._coro_result = None
return exit_result
# Captures the full parameter list of the decorated coroutine function so the
# wrapper keeps the same, type-checked signature.
P = ParamSpec('P')
# Generic aliases parameterized by P and R: the undecorated coroutine function
# (AF) and the decorated function that returns a _Wrapper instead (WF).
# NOTE: the discussion reports that mypy rejected these ParamSpec-based
# aliases; inline the Callable[...] types in the signature if mypy must accept
# this code.
AF = Callable[P, Awaitable[R]]
WF = Callable[P, _Wrapper[R]]
def async_context_manager(coro: AF[P, R]) -> WF[P, R]:
    """Let a coroutine function's result be awaited or used with "async with".

    The returned function keeps ``coro``'s parameter types (through ``P``)
    and its metadata (through ``functools.wraps``). Instead of a coroutine,
    it returns a ``_Wrapper`` around the coroutine object.
    """
    def wrapfunc(*args: P.args, **kwargs: P.kwargs) -> _Wrapper[R]:
        pending = coro(*args, **kwargs)
        return _Wrapper(pending)

    return functools.wraps(coro)(wrapfunc)
class MyCtxMgr:
    """Example async context manager that reports its lifecycle on stdout."""

    def __init__(self, x: int):
        print('Created: x =', x)
        self.x = x

    async def __aenter__(self) -> 'MyCtxMgr':
        print('Entered context, x =', self.x)
        # The manager itself is what "async with ... as" binds.
        return self

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_value: Optional[BaseException],
                        traceback: Optional[TracebackType]) -> Optional[bool]:
        print('Exited context, x =', self.x)
        # A false result means exceptions from the block are never suppressed.
        return False
@async_context_manager
async def get_context(x: int) -> MyCtxMgr:
    """Build a MyCtxMgr for *x*; awaitable or usable directly with "async with"."""
    manager = MyCtxMgr(x)
    return manager
async def run():
    """Demonstrate both ways of using a decorated coroutine function."""
    # 1) Await first to obtain the context manager, then enter it.
    mgr = await get_context(1)
    async with mgr:
        pass
    # 2) Enter the context directly, without awaiting the call first.
    async with get_context(2):
        pass
asyncio.run(run())
With this version, if I change the arguments to something incorrect, pyright is able to catch it. For instance:
ctx = await get_context(1, 'xxx')
results in: [pyright error output missing]
Similarly:
ctx = await get_context(1, key='xxx')
results in: [pyright error output missing]
Unfortunately, mypy doesn't like the type aliases AF and WF using P as an argument. To make mypy happy, these definitions need to be removed and the signature expanded to something like:
def async_context_manager(coro: Callable[P, Awaitable[R]]) -> \
        Callable[P, _Wrapper[R]]:
This use of [… rest of sentence missing from this copy …]:
def wrapfunc(*args: P.args,  # type: ignore
             **kwargs: P.kwargs) -> _Wrapper[R]:  # type: ignore
Beta Was this translation helpful? Give feedback.
I got a chance to look at this again today, and found the issue! It turns out that the definition of __aexit__ in both _Wrapper and MyCtxMgr was slightly off: it was missing "Optional" on some of the arguments. That made those classes incompatible with the bound on R, which requires them to match the AsyncContextManager protocol. Here's a corrected version: