From 9a213c1321593fcf9344161b043bdd7d1e600e24 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 16 Dec 2023 13:51:24 +0100 Subject: [PATCH] Use `ParamSpec` for `run_in_threadpool` (#2375) --- starlette/_exception_handler.py | 16 +++++++++++++++- starlette/concurrency.py | 11 ++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index ea9ffbe9d..5b4b68e88 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -4,7 +4,16 @@ from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send +from starlette.types import ( + ASGIApp, + ExceptionHandler, + HTTPExceptionHandler, + Message, + Receive, + Scope, + Send, + WebSocketExceptionHandler, +) from starlette.websockets import WebSocket ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler] @@ -59,12 +68,17 @@ async def sender(message: Message) -> None: raise RuntimeError(msg) from exc if scope["type"] == "http": + nonlocal conn + handler = typing.cast(HTTPExceptionHandler, handler) + conn = typing.cast(Request, conn) if is_async_callable(handler): response = await handler(conn, exc) else: response = await run_in_threadpool(handler, conn, exc) await response(scope, receive, sender) elif scope["type"] == "websocket": + handler = typing.cast(WebSocketExceptionHandler, handler) + conn = typing.cast(WebSocket, conn) if is_async_callable(handler): await handler(conn, exc) else: diff --git a/starlette/concurrency.py b/starlette/concurrency.py index ca6033c0f..c44ee840f 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,9 +1,16 @@ import functools +import sys import typing import warnings import anyio.to_thread +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +P = ParamSpec("P") T = typing.TypeVar("T") @@ -24,10 +31,8 @@ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ign task_group.start_soon(run, functools.partial(func, **kwargs)) -# TODO: We should use `ParamSpec` here, but mypy doesn't support it yet. -# Check https://github.com/python/mypy/issues/12278 for more details. async def run_in_threadpool( - func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any + func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs ) -> T: if kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here