Hey there! For context: I asked the same question on the typing Gitter but had no luck there, I think because visibility is quite low, and I've had no replies so far. I'm trying my luck here to understand whether what I'm trying to achieve is possible at all, or whether I shouldn't bother. The idea seems simple: we need to preserve the full type signature of the function that a decorator with arguments is applied to. The problem comes when the decorator has to work with both sync and async functions. The closest I've come to a solution is the following:
However, I have also tried
Here's a potential solution. It type checks without errors in pyright. Mypy produces a few overload-related errors that appear to be false positives, so you could add # type: ignore comments to suppress them.

Code sample in pyright playground

from typing import Awaitable, Callable, Protocol, overload
from typing_extensions import TypeIs
from inspect import iscoroutinefunction
from functools import wraps


def is_coroutine[**P, R](
    func: Callable[P, R | Awaitable[R]],
) -> TypeIs[Callable[P, Awaitable[R]]]:
    return iscoroutinefunction(func)


class SyncOrAsync(Protocol):
    @overload
    def __call__[**P, R](
        self, _func: Callable[P, Awaitable[R]]
    ) -> Callable[P, Awaitable[R]]:
        ...

    @overload
    def __call__[**P, R](self, _func: Callable[P, R]) -> Callable[P, R]:
        ...

    def __call__[**P, R](
        self, _func: Callable[P, Awaitable[R]] | Callable[P, R]
    ) -> Callable[P, Awaitable[R]] | Callable[P, R]:
        ...


def my_dec(param1: str, param2: int | None = None) -> SyncOrAsync:
    @overload
    def decorator[**P, R](
        _func: Callable[P, Awaitable[R]],
    ) -> Callable[P, Awaitable[R]]:
        ...

    @overload
    def decorator[**P, R](
        _func: Callable[P, R],
    ) -> Callable[P, R]:
        ...

    def decorator[**P, R](
        _func: Callable[P, Awaitable[R]] | Callable[P, R],
    ) -> Callable[P, Awaitable[R]] | Callable[P, R]:
        if is_coroutine(_func):
            _awaitable_func = _func

            @wraps(_awaitable_func)
            async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                return await _awaitable_func(*args, **kwargs)

            return _async_wrapper
        else:

            @wraps(_func)
            def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                return _func(*args, **kwargs)

            return _sync_wrapper

    return decorator


@my_dec(param1="test")
def test_sync() -> str:
    return "test return"


v_sync = test_sync()


@my_dec(param1="test")
async def test_async() -> str:
    return "test return"


v_async = test_async()
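To check the runtime behaviour separately from the typing, here is a minimal sketch of the same sync/async dispatch with the annotations stripped, so it should also run on interpreters older than 3.12. The `calls` list and the `greet_sync` / `greet_async` functions are hypothetical and only there for illustration; the discussion does not say what `param1` is actually used for, so recording it in a log is an assumption.

```python
import asyncio
from functools import wraps
from inspect import iscoroutinefunction

# Hypothetical call log, used only to show that the decorator argument
# is visible inside both wrappers.
calls = []


def my_dec(param1, param2=None):
    def decorator(func):
        if iscoroutinefunction(func):
            # Async branch: the wrapper is itself a coroutine function,
            # so callers still have to await the result.
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                calls.append(f"{param1}:async:{func.__name__}")
                return await func(*args, **kwargs)

            return async_wrapper

        # Sync branch: the wrapper returns the plain value.
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            calls.append(f"{param1}:sync:{func.__name__}")
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator


@my_dec(param1="test")
def greet_sync():
    return "test return"


@my_dec(param1="test")
async def greet_async():
    return "test return"


v_sync = greet_sync()                 # plain str
v_async = asyncio.run(greet_async())  # coroutine must be awaited / run
```

The point is that `greet_async()` on its own still returns a coroutine object, which is why the typed version has to keep `Awaitable[R]` in the async overload instead of collapsing both cases to `R`.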