From f4b94afe0f1c3624a16629c111c35556bf4fcb07 Mon Sep 17 00:00:00 2001 From: abersheeran Date: Mon, 6 Jan 2025 09:39:32 +0800 Subject: [PATCH] Fix `asgi.py`'s `Task` thread safety --- a2wsgi/asgi.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/a2wsgi/asgi.py b/a2wsgi/asgi.py index 2692325..c9b4bb7 100644 --- a/a2wsgi/asgi.py +++ b/a2wsgi/asgi.py @@ -3,12 +3,14 @@ import threading from http import HTTPStatus from io import BytesIO -from typing import Any, Coroutine, Deque, Iterable, Optional +from typing import Any, Coroutine, Deque, Iterable, Optional, TypeVar from typing import cast as typing_cast from .asgi_typing import HTTPScope, ASGIApp, ReceiveEvent, SendEvent from .wsgi_typing import Environ, StartResponse, IterableChunks +T = TypeVar("T") + class defaultdict(dict): def __init__(self, default_factory, *args, **kwargs) -> None: @@ -198,7 +200,7 @@ def asgi_done_callback(self, future: asyncio.Future) -> None: finally: self.asgi_done.set() - def start_asgi_app(self, environ: Environ) -> asyncio.Task: + async def start_asgi_app(self, environ: Environ) -> asyncio.Task: run_asgi: asyncio.Task = self.loop.create_task( typing_cast( Coroutine[None, None, None], @@ -208,6 +210,9 @@ def start_asgi_app(self, environ: Environ) -> asyncio.Task: run_asgi.add_done_callback(self.asgi_done_callback) return run_asgi + def execute_in_loop(self, coro: Coroutine[None, None, T]) -> T: + return asyncio.run_coroutine_threadsafe(coro, self.loop).result() + def __call__( self, environ: Environ, start_response: StartResponse ) -> IterableChunks: @@ -217,7 +222,7 @@ def __call__( receive_eof = False body_sent = False - asgi_task = self.start_asgi_app(environ) + asgi_task = self.execute_in_loop(self.start_asgi_app(environ)) # activate loop self.loop.call_soon_threadsafe(lambda: None) @@ -287,10 +292,10 @@ def __call__( self.receive_event.set({"type": "http.disconnect"}) break - if asgi_task.done(): + if self.asgi_done.is_set(): break # HTTP response ends, wait for run_asgi's background tasks self.asgi_done.wait(self.wait_time) - asgi_task.cancel() + self.loop.call_soon_threadsafe(asgi_task.cancel) yield b""