From a5e80dbae9ad48ad3e4d07a22130cd3f46cd15bb Mon Sep 17 00:00:00 2001 From: synodriver Date: Fri, 10 Mar 2023 12:35:49 +0800 Subject: [PATCH 1/2] implement asgi lifespan state --- src/hypercorn/asyncio/lifespan.py | 4 +++- src/hypercorn/asyncio/run.py | 4 ++-- src/hypercorn/asyncio/tcp_server.py | 5 ++++- src/hypercorn/asyncio/udp_server.py | 6 ++++-- src/hypercorn/protocol/__init__.py | 8 +++++++- src/hypercorn/protocol/h11.py | 6 +++++- src/hypercorn/protocol/h2.py | 6 +++++- src/hypercorn/protocol/h3.py | 6 +++++- src/hypercorn/protocol/http_stream.py | 5 ++++- src/hypercorn/protocol/quic.py | 5 ++++- src/hypercorn/protocol/ws_stream.py | 5 ++++- src/hypercorn/trio/lifespan.py | 3 +++ src/hypercorn/trio/run.py | 4 ++-- src/hypercorn/trio/tcp_server.py | 6 ++++-- src/hypercorn/trio/udp_server.py | 5 ++++- src/hypercorn/typing.py | 3 +++ 16 files changed, 63 insertions(+), 18 deletions(-) diff --git a/src/hypercorn/asyncio/lifespan.py b/src/hypercorn/asyncio/lifespan.py index 244950c6..c05caa67 100644 --- a/src/hypercorn/asyncio/lifespan.py +++ b/src/hypercorn/asyncio/lifespan.py @@ -2,7 +2,7 @@ import asyncio from functools import partial -from typing import Any, Callable +from typing import Any, Callable, Dict from ..config import Config from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope @@ -27,12 +27,14 @@ def __init__(self, app: AppWrapper, config: Config, loop: asyncio.AbstractEventL # required to ensure the support has been checked before # waiting on timeouts. self._started = asyncio.Event() + self.state: Dict[str, Any] = {} async def handle_lifespan(self) -> None: self._started.set() scope: LifespanScope = { "type": "lifespan", "asgi": {"spec_version": "2.0", "version": "3.0"}, + "state": self.state, } def _call_soon(func: Callable, *args: Any) -> Any: diff --git a/src/hypercorn/asyncio/run.py b/src/hypercorn/asyncio/run.py index b50b9c65..e786aae5 100644 --- a/src/hypercorn/asyncio/run.py +++ b/src/hypercorn/asyncio/run.py @@ -93,7 +93,7 @@ def _signal_handler(*_: Any) -> None: # noqa: N803 async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: server_tasks.add(asyncio.current_task(loop)) - await TCPServer(app, loop, config, context, reader, writer) + await TCPServer(app, loop, config, context, reader, writer, lifespan.state.copy()) servers = [] for sock in sockets.secure_sockets: @@ -127,7 +127,7 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW sock = _share_socket(sock) _, protocol = await loop.create_datagram_endpoint( - lambda: UDPServer(app, loop, config, context), sock=sock + lambda: UDPServer(app, loop, config, context, lifespan.state.copy()), sock=sock ) server_tasks.add(loop.create_task(protocol.run())) bind = repr_socket_addr(sock.family, sock.getsockname()) diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index 91b3c050..e3c90526 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -2,7 +2,7 @@ import asyncio from ssl import SSLError -from typing import Any, Generator, Optional +from typing import Any, Generator, Optional, Dict from .task_group import TaskGroup from .worker_context import WorkerContext @@ -24,6 +24,7 @@ def __init__( context: WorkerContext, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, + app_state: Dict[str, Any] ) -> None: self.app = app self.config = config @@ -36,6 +37,7 @@ def __init__( self.idle_lock = asyncio.Lock() self._idle_handle: Optional[asyncio.Task] = None + self.app_state = app_state def __await__(self) -> Generator[Any, None, None]: return self.run().__await__() @@ -64,6 +66,7 @@ async def run(self) -> None: server, self.protocol_send, alpn_protocol, + self.app_state ) await self.protocol.initiate() await self._start_idle() diff --git a/src/hypercorn/asyncio/udp_server.py b/src/hypercorn/asyncio/udp_server.py index 629ab9f4..e5df8db6 100644 --- a/src/hypercorn/asyncio/udp_server.py +++ b/src/hypercorn/asyncio/udp_server.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import Optional, Tuple, TYPE_CHECKING +from typing import Optional, Tuple, Dict, Any, TYPE_CHECKING from .task_group import TaskGroup from .worker_context import WorkerContext @@ -22,6 +22,7 @@ def __init__( loop: asyncio.AbstractEventLoop, config: Config, context: WorkerContext, + app_state: Dict[str, Any], ) -> None: self.app = app self.config = config @@ -30,6 +31,7 @@ def __init__( self.protocol: "QuicProtocol" self.protocol_queue: asyncio.Queue = asyncio.Queue(10) self.transport: Optional[asyncio.DatagramTransport] = None + self.app_state = app_state def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore self.transport = transport @@ -48,7 +50,7 @@ async def run(self) -> None: server = parse_socket_addr(socket.family, socket.getsockname()) async with TaskGroup(self.loop) as task_group: self.protocol = QuicProtocol( - self.app, self.config, self.context, task_group, server, self.protocol_send + self.app, self.config, self.context, task_group, server, self.protocol_send, self.app_state ) while not self.context.terminated.is_set() or not self.protocol.idle: diff --git a/src/hypercorn/protocol/__init__.py b/src/hypercorn/protocol/__init__.py index 39385681..c81d473a 100755 --- a/src/hypercorn/protocol/__init__.py +++ b/src/hypercorn/protocol/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Awaitable, Callable, Optional, Tuple, Union +from typing import Awaitable, Callable, Optional, Tuple, Union, Dict, Any from .h2 import H2Protocol from .h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol @@ -21,6 +21,7 @@ def __init__( server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], alpn_protocol: Optional[str] = None, + app_state: Dict[str, Any] = None, ) -> None: self.app = app self.config = config @@ -31,6 +32,7 @@ def __init__( self.server = server self.send = send self.protocol: Union[H11Protocol, H2Protocol] + self.app_state = app_state if alpn_protocol == "h2": self.protocol = H2Protocol( self.app, @@ -41,6 +43,7 @@ def __init__( self.client, self.server, self.send, + self.app_state, ) else: self.protocol = H11Protocol( @@ -52,6 +55,7 @@ def __init__( self.client, self.server, self.send, + self.app_state, ) async def initiate(self) -> None: @@ -70,6 +74,7 @@ async def handle(self, event: Event) -> None: self.client, self.server, self.send, + self.app_state, ) await self.protocol.initiate() if error.data != b"": @@ -84,6 +89,7 @@ async def handle(self, event: Event) -> None: self.client, self.server, self.send, + self.app_state, ) await self.protocol.initiate(error.headers, error.settings) if error.data != b"": diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index e18d4884..fe1f3fa1 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import Awaitable, Callable, cast, Optional, Tuple, Type, Union +from typing import Awaitable, Callable, cast, Optional, Tuple, Type, Union, Dict, Any import h11 @@ -88,6 +88,7 @@ def __init__( client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], + app_state: Dict[str, Any], ) -> None: self.app = app self.can_read = context.event_class() @@ -102,6 +103,7 @@ def __init__( self.ssl = ssl self.stream: Optional[Union[HTTPStream, WSStream]] = None self.task_group = task_group + self.app_state = app_state async def initiate(self) -> None: pass @@ -205,6 +207,7 @@ async def _create_stream(self, request: h11.Request) -> None: self.server, self.stream_send, STREAM_ID, + self.app_state ) self.connection = H11WSConnection(cast(h11.Connection, self.connection)) else: @@ -218,6 +221,7 @@ async def _create_stream(self, request: h11.Request) -> None: self.server, self.stream_send, STREAM_ID, + self.app_state ) await self.stream.handle( Request( diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index 6e76d493..7e976ea7 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union, Any import h2 import h2.connection @@ -88,6 +88,7 @@ def __init__( client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], + app_state: Dict[str, Any], ) -> None: self.app = app self.client = client @@ -117,6 +118,7 @@ def __init__( self.has_data = self.context.event_class() self.priority = priority.PriorityTree() self.stream_buffers: Dict[int, StreamBuffer] = {} + self.app_state = app_state @property def idle(self) -> bool: @@ -318,6 +320,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: self.server, self.stream_send, request.stream_id, + self.app_state ) else: self.streams[request.stream_id] = HTTPStream( @@ -330,6 +333,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: self.server, self.stream_send, request.stream_id, + self.app_state ) self.stream_buffers[request.stream_id] = StreamBuffer(self.context.event_class) try: diff --git a/src/hypercorn/protocol/h3.py b/src/hypercorn/protocol/h3.py index 88d9a4d3..a33c9f78 100644 --- a/src/hypercorn/protocol/h3.py +++ b/src/hypercorn/protocol/h3.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union +from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union, Any from aioquic.h3.connection import H3Connection from aioquic.h3.events import DataReceived, HeadersReceived @@ -37,6 +37,7 @@ def __init__( server: Optional[Tuple[str, int]], quic: QuicConnection, send: Callable[[], Awaitable[None]], + app_state: Dict[str, Any], ) -> None: self.app = app self.client = client @@ -47,6 +48,7 @@ def __init__( self.server = server self.streams: Dict[int, Union[HTTPStream, WSStream]] = {} self.task_group = task_group + self.app_state = app_state async def handle(self, quic_event: QuicEvent) -> None: for event in self.connection.handle_event(quic_event): @@ -102,6 +104,7 @@ async def _create_stream(self, request: HeadersReceived) -> None: self.server, self.stream_send, request.stream_id, + self.app_state, ) else: self.streams[request.stream_id] = HTTPStream( @@ -114,6 +117,7 @@ async def _create_stream(self, request: HeadersReceived) -> None: self.server, self.stream_send, request.stream_id, + self.app_state, ) await self.streams[request.stream_id].handle( diff --git a/src/hypercorn/protocol/http_stream.py b/src/hypercorn/protocol/http_stream.py index 6cd9beea..b591f1c2 100644 --- a/src/hypercorn/protocol/http_stream.py +++ b/src/hypercorn/protocol/http_stream.py @@ -2,7 +2,7 @@ from enum import auto, Enum from time import time -from typing import Awaitable, Callable, Optional, Tuple +from typing import Awaitable, Callable, Optional, Tuple, Dict, Any from urllib.parse import unquote from .events import Body, EndBody, Event, InformationalResponse, Request, Response, StreamClosed @@ -47,6 +47,7 @@ def __init__( server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], stream_id: int, + app_state: Dict[str, Any], ) -> None: self.app = app self.client = client @@ -62,6 +63,7 @@ def __init__( self.state = ASGIHTTPState.REQUEST self.stream_id = stream_id self.task_group = task_group + self.app_state = app_state @property def idle(self) -> bool: @@ -87,6 +89,7 @@ async def handle(self, event: Event) -> None: "client": self.client, "server": self.server, "extensions": {}, + "state": self.app_state, } if event.http_version in PUSH_VERSIONS: self.scope["extensions"]["http.response.push"] = {} diff --git a/src/hypercorn/protocol/quic.py b/src/hypercorn/protocol/quic.py index 3d16e54d..e3798322 100644 --- a/src/hypercorn/protocol/quic.py +++ b/src/hypercorn/protocol/quic.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import Awaitable, Callable, Dict, Optional, Tuple +from typing import Awaitable, Callable, Dict, Optional, Tuple, Any from aioquic.buffer import Buffer from aioquic.h3.connection import H3_ALPN @@ -34,6 +34,7 @@ def __init__( task_group: TaskGroup, server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], + app_state: Dict[str, Any], ) -> None: self.app = app self.config = config @@ -46,6 +47,7 @@ def __init__( self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False) self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile) + self.app_state = app_state @property def idle(self) -> bool: @@ -110,6 +112,7 @@ async def _handle_events( self.server, connection, partial(self.send_all, connection), + self.app_state, ) elif isinstance(event, ConnectionIdIssued): self.connections[event.connection_id] = connection diff --git a/src/hypercorn/protocol/ws_stream.py b/src/hypercorn/protocol/ws_stream.py index cebbf89b..e85cbdcf 100644 --- a/src/hypercorn/protocol/ws_stream.py +++ b/src/hypercorn/protocol/ws_stream.py @@ -3,7 +3,7 @@ from enum import auto, Enum from io import BytesIO, StringIO from time import time -from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, Union +from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, Union, Dict, Any from urllib.parse import unquote from wsproto.connection import Connection, ConnectionState, ConnectionType @@ -172,6 +172,7 @@ def __init__( server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], stream_id: int, + app_state: Dict[str, Any], ) -> None: self.app = app self.app_put: Optional[Callable] = None @@ -193,6 +194,7 @@ def __init__( self.connection: Connection self.handshake: Handshake + self.app_state = app_state @property def idle(self) -> bool: @@ -219,6 +221,7 @@ async def handle(self, event: Event) -> None: "server": self.server, "subprotocols": self.handshake.subprotocols or [], "extensions": {"websocket.http.response": {}}, + "state": self.app_state, } if not valid_server_name(self.config, event): diff --git a/src/hypercorn/trio/lifespan.py b/src/hypercorn/trio/lifespan.py index a45fc528..9eabe644 100644 --- a/src/hypercorn/trio/lifespan.py +++ b/src/hypercorn/trio/lifespan.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import Dict, Any import trio from ..config import Config @@ -21,6 +22,7 @@ def __init__(self, app: AppWrapper, config: Config) -> None: config.max_app_queue_size ) self.supported = True + self.state: Dict[str, Any] = {} async def handle_lifespan( self, *, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED @@ -29,6 +31,7 @@ async def handle_lifespan( scope: LifespanScope = { "type": "lifespan", "asgi": {"spec_version": "2.0", "version": "3.0"}, + "state": self.state, } try: await self.app( diff --git a/src/hypercorn/trio/run.py b/src/hypercorn/trio/run.py index 5dfbf91f..b6118cc6 100644 --- a/src/hypercorn/trio/run.py +++ b/src/hypercorn/trio/run.py @@ -69,7 +69,7 @@ async def worker_serve( await config.log.info(f"Running on http://{bind} (CTRL + C to quit)") for sock in sockets.quic_sockets: - await server_nursery.start(UDPServer(app, config, context, sock).run) + await server_nursery.start(UDPServer(app, config, context, sock, lifespan.state.copy()).run) bind = repr_socket_addr(sock.family, sock.getsockname()) await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)") @@ -82,7 +82,7 @@ async def worker_serve( nursery.start_soon( partial( trio.serve_listeners, - partial(TCPServer, app, config, context), + partial(TCPServer, app, config, context, lifespan.state.copy()), listeners, handler_nursery=server_nursery, ), diff --git a/src/hypercorn/trio/tcp_server.py b/src/hypercorn/trio/tcp_server.py index 3419440f..21c078ef 100644 --- a/src/hypercorn/trio/tcp_server.py +++ b/src/hypercorn/trio/tcp_server.py @@ -1,7 +1,7 @@ from __future__ import annotations from math import inf -from typing import Any, Generator, Optional +from typing import Any, Generator, Optional, Dict import trio @@ -18,7 +18,7 @@ class TCPServer: def __init__( - self, app: AppWrapper, config: Config, context: WorkerContext, stream: trio.abc.Stream + self, app: AppWrapper, config: Config, context: WorkerContext, stream: trio.abc.Stream, app_state: Dict[str, Any] ) -> None: self.app = app self.config = config @@ -29,6 +29,7 @@ def __init__( self.stream = stream self._idle_handle: Optional[trio.CancelScope] = None + self.app_state = app_state def __await__(self) -> Generator[Any, None, None]: return self.run().__await__() @@ -64,6 +65,7 @@ async def run(self) -> None: server, self.protocol_send, alpn_protocol, + self.app_state ) await self.protocol.initiate() await self._start_idle() diff --git a/src/hypercorn/trio/udp_server.py b/src/hypercorn/trio/udp_server.py index b8d4530b..2146324f 100644 --- a/src/hypercorn/trio/udp_server.py +++ b/src/hypercorn/trio/udp_server.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import Dict, Any import trio from .task_group import TaskGroup @@ -19,11 +20,13 @@ def __init__( config: Config, context: WorkerContext, socket: trio.socket.socket, + app_state: Dict[str, Any], ) -> None: self.app = app self.config = config self.context = context self.socket = trio.socket.from_stdlib_socket(socket) + self.app_state = app_state async def run( self, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED @@ -34,7 +37,7 @@ async def run( server = parse_socket_addr(self.socket.family, self.socket.getsockname()) async with TaskGroup() as task_group: self.protocol = QuicProtocol( - self.app, self.config, self.context, task_group, server, self.protocol_send + self.app, self.config, self.context, task_group, server, self.protocol_send, self.app_state ) while not self.context.terminated.is_set() or not self.protocol.idle: diff --git a/src/hypercorn/typing.py b/src/hypercorn/typing.py index 206415c0..aeb9ccf5 100644 --- a/src/hypercorn/typing.py +++ b/src/hypercorn/typing.py @@ -39,6 +39,7 @@ class HTTPScope(TypedDict): client: Optional[Tuple[str, int]] server: Optional[Tuple[str, Optional[int]]] extensions: Dict[str, dict] + state: Dict[str, Any] class WebsocketScope(TypedDict): @@ -55,11 +56,13 @@ class WebsocketScope(TypedDict): server: Optional[Tuple[str, Optional[int]]] subprotocols: Iterable[str] extensions: Dict[str, dict] + state: Dict[str, Any] class LifespanScope(TypedDict): type: Literal["lifespan"] asgi: ASGIVersions + state: Dict[str, Any] WWWScope = Union[HTTPScope, WebsocketScope] From a0fd02cf6d6f69348ee5c03af9742f699d84b373 Mon Sep 17 00:00:00 2001 From: synodriver Date: Wed, 24 May 2023 14:41:25 +0800 Subject: [PATCH 2/2] fix trio appstate init --- src/hypercorn/trio/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hypercorn/trio/run.py b/src/hypercorn/trio/run.py index b6118cc6..f4ce767d 100644 --- a/src/hypercorn/trio/run.py +++ b/src/hypercorn/trio/run.py @@ -82,7 +82,7 @@ async def worker_serve( nursery.start_soon( partial( trio.serve_listeners, - partial(TCPServer, app, config, context, lifespan.state.copy()), + partial(TCPServer, app, config, context, app_state=lifespan.state.copy()), listeners, handler_nursery=server_nursery, ),