Skip to content

Commit

Permalink
Use ParamSpec for run_in_threadpool (#2375)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Dec 16, 2023
1 parent 6715eb4 commit 9a213c1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
16 changes: 15 additions & 1 deletion starlette/_exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.types import (
ASGIApp,
ExceptionHandler,
HTTPExceptionHandler,
Message,
Receive,
Scope,
Send,
WebSocketExceptionHandler,
)
from starlette.websockets import WebSocket

ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
Expand Down Expand Up @@ -59,12 +68,17 @@ async def sender(message: Message) -> None:
raise RuntimeError(msg) from exc

if scope["type"] == "http":
nonlocal conn
handler = typing.cast(HTTPExceptionHandler, handler)
conn = typing.cast(Request, conn)
if is_async_callable(handler):
response = await handler(conn, exc)
else:
response = await run_in_threadpool(handler, conn, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
handler = typing.cast(WebSocketExceptionHandler, handler)
conn = typing.cast(WebSocket, conn)
if is_async_callable(handler):
await handler(conn, exc)
else:
Expand Down
11 changes: 8 additions & 3 deletions starlette/concurrency.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import functools
import sys
import typing
import warnings

import anyio.to_thread

if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec

P = ParamSpec("P")
T = typing.TypeVar("T")


Expand All @@ -24,10 +31,8 @@ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ign
task_group.start_soon(run, functools.partial(func, **kwargs))


# TODO: We should use `ParamSpec` here, but mypy doesn't support it yet.
# Check https://github.com/python/mypy/issues/12278 for more details.
async def run_in_threadpool(
func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
if kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
Expand Down

0 comments on commit 9a213c1

Please sign in to comment.