Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement asgi lifespan state #107

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/hypercorn/asyncio/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
from functools import partial
from typing import Any, Callable
from typing import Any, Callable, Dict

from ..config import Config
from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope
Expand All @@ -27,12 +27,14 @@ def __init__(self, app: AppWrapper, config: Config, loop: asyncio.AbstractEventL
# required to ensure the support has been checked before
# waiting on timeouts.
self._started = asyncio.Event()
self.state: Dict[str, Any] = {}

async def handle_lifespan(self) -> None:
self._started.set()
scope: LifespanScope = {
"type": "lifespan",
"asgi": {"spec_version": "2.0", "version": "3.0"},
"state": self.state,
}

def _call_soon(func: Callable, *args: Any) -> Any:
Expand Down
4 changes: 2 additions & 2 deletions src/hypercorn/asyncio/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _signal_handler(*_: Any) -> None: # noqa: N803

async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
server_tasks.add(asyncio.current_task(loop))
await TCPServer(app, loop, config, context, reader, writer)
await TCPServer(app, loop, config, context, reader, writer, lifespan.state.copy())

servers = []
for sock in sockets.secure_sockets:
Expand Down Expand Up @@ -127,7 +127,7 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW
sock = _share_socket(sock)

_, protocol = await loop.create_datagram_endpoint(
lambda: UDPServer(app, loop, config, context), sock=sock
lambda: UDPServer(app, loop, config, context, lifespan.state.copy()), sock=sock
)
server_tasks.add(loop.create_task(protocol.run()))
bind = repr_socket_addr(sock.family, sock.getsockname())
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/asyncio/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
from ssl import SSLError
from typing import Any, Generator, Optional
from typing import Any, Generator, Optional, Dict

from .task_group import TaskGroup
from .worker_context import WorkerContext
Expand All @@ -24,6 +24,7 @@ def __init__(
context: WorkerContext,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
app_state: Dict[str, Any]
) -> None:
self.app = app
self.config = config
Expand All @@ -36,6 +37,7 @@ def __init__(
self.idle_lock = asyncio.Lock()

self._idle_handle: Optional[asyncio.Task] = None
self.app_state = app_state

def __await__(self) -> Generator[Any, None, None]:
return self.run().__await__()
Expand Down Expand Up @@ -64,6 +66,7 @@ async def run(self) -> None:
server,
self.protocol_send,
alpn_protocol,
self.app_state
)
await self.protocol.initiate()
await self._start_idle()
Expand Down
6 changes: 4 additions & 2 deletions src/hypercorn/asyncio/udp_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from typing import Optional, Tuple, TYPE_CHECKING
from typing import Optional, Tuple, Dict, Any, TYPE_CHECKING

from .task_group import TaskGroup
from .worker_context import WorkerContext
Expand All @@ -22,6 +22,7 @@ def __init__(
loop: asyncio.AbstractEventLoop,
config: Config,
context: WorkerContext,
app_state: Dict[str, Any],
) -> None:
self.app = app
self.config = config
Expand All @@ -30,6 +31,7 @@ def __init__(
self.protocol: "QuicProtocol"
self.protocol_queue: asyncio.Queue = asyncio.Queue(10)
self.transport: Optional[asyncio.DatagramTransport] = None
self.app_state = app_state

def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
self.transport = transport
Expand All @@ -48,7 +50,7 @@ async def run(self) -> None:
server = parse_socket_addr(socket.family, socket.getsockname())
async with TaskGroup(self.loop) as task_group:
self.protocol = QuicProtocol(
self.app, self.config, self.context, task_group, server, self.protocol_send
self.app, self.config, self.context, task_group, server, self.protocol_send, self.app_state
)

while not self.context.terminated.is_set() or not self.protocol.idle:
Expand Down
8 changes: 7 additions & 1 deletion src/hypercorn/protocol/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Awaitable, Callable, Optional, Tuple, Union
from typing import Awaitable, Callable, Optional, Tuple, Union, Dict, Any

from .h2 import H2Protocol
from .h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol
Expand All @@ -21,6 +21,7 @@ def __init__(
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
alpn_protocol: Optional[str] = None,
app_state: Dict[str, Any] = None,
) -> None:
self.app = app
self.config = config
Expand All @@ -31,6 +32,7 @@ def __init__(
self.server = server
self.send = send
self.protocol: Union[H11Protocol, H2Protocol]
self.app_state = app_state
if alpn_protocol == "h2":
self.protocol = H2Protocol(
self.app,
Expand All @@ -41,6 +43,7 @@ def __init__(
self.client,
self.server,
self.send,
self.app_state,
)
else:
self.protocol = H11Protocol(
Expand All @@ -52,6 +55,7 @@ def __init__(
self.client,
self.server,
self.send,
self.app_state,
)

async def initiate(self) -> None:
Expand All @@ -70,6 +74,7 @@ async def handle(self, event: Event) -> None:
self.client,
self.server,
self.send,
self.app_state,
)
await self.protocol.initiate()
if error.data != b"":
Expand All @@ -84,6 +89,7 @@ async def handle(self, event: Event) -> None:
self.client,
self.server,
self.send,
self.app_state,
)
await self.protocol.initiate(error.headers, error.settings)
if error.data != b"":
Expand Down
6 changes: 5 additions & 1 deletion src/hypercorn/protocol/h11.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from itertools import chain
from typing import Awaitable, Callable, cast, Optional, Tuple, Type, Union
from typing import Awaitable, Callable, cast, Optional, Tuple, Type, Union, Dict, Any

import h11

Expand Down Expand Up @@ -88,6 +88,7 @@ def __init__(
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
app_state: Dict[str, Any],
) -> None:
self.app = app
self.can_read = context.event_class()
Expand All @@ -102,6 +103,7 @@ def __init__(
self.ssl = ssl
self.stream: Optional[Union[HTTPStream, WSStream]] = None
self.task_group = task_group
self.app_state = app_state

async def initiate(self) -> None:
pass
Expand Down Expand Up @@ -205,6 +207,7 @@ async def _create_stream(self, request: h11.Request) -> None:
self.server,
self.stream_send,
STREAM_ID,
self.app_state
)
self.connection = H11WSConnection(cast(h11.Connection, self.connection))
else:
Expand All @@ -218,6 +221,7 @@ async def _create_stream(self, request: h11.Request) -> None:
self.server,
self.stream_send,
STREAM_ID,
self.app_state
)
await self.stream.handle(
Request(
Expand Down
6 changes: 5 additions & 1 deletion src/hypercorn/protocol/h2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union, Any

import h2
import h2.connection
Expand Down Expand Up @@ -88,6 +88,7 @@ def __init__(
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
app_state: Dict[str, Any],
) -> None:
self.app = app
self.client = client
Expand Down Expand Up @@ -117,6 +118,7 @@ def __init__(
self.has_data = self.context.event_class()
self.priority = priority.PriorityTree()
self.stream_buffers: Dict[int, StreamBuffer] = {}
self.app_state = app_state

@property
def idle(self) -> bool:
Expand Down Expand Up @@ -318,6 +320,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None:
self.server,
self.stream_send,
request.stream_id,
self.app_state
)
else:
self.streams[request.stream_id] = HTTPStream(
Expand All @@ -330,6 +333,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None:
self.server,
self.stream_send,
request.stream_id,
self.app_state
)
self.stream_buffers[request.stream_id] = StreamBuffer(self.context.event_class)
try:
Expand Down
6 changes: 5 additions & 1 deletion src/hypercorn/protocol/h3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union, Any

from aioquic.h3.connection import H3Connection
from aioquic.h3.events import DataReceived, HeadersReceived
Expand Down Expand Up @@ -37,6 +37,7 @@ def __init__(
server: Optional[Tuple[str, int]],
quic: QuicConnection,
send: Callable[[], Awaitable[None]],
app_state: Dict[str, Any],
) -> None:
self.app = app
self.client = client
Expand All @@ -47,6 +48,7 @@ def __init__(
self.server = server
self.streams: Dict[int, Union[HTTPStream, WSStream]] = {}
self.task_group = task_group
self.app_state = app_state

async def handle(self, quic_event: QuicEvent) -> None:
for event in self.connection.handle_event(quic_event):
Expand Down Expand Up @@ -102,6 +104,7 @@ async def _create_stream(self, request: HeadersReceived) -> None:
self.server,
self.stream_send,
request.stream_id,
self.app_state,
)
else:
self.streams[request.stream_id] = HTTPStream(
Expand All @@ -114,6 +117,7 @@ async def _create_stream(self, request: HeadersReceived) -> None:
self.server,
self.stream_send,
request.stream_id,
self.app_state,
)

await self.streams[request.stream_id].handle(
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/http_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from enum import auto, Enum
from time import time
from typing import Awaitable, Callable, Optional, Tuple
from typing import Awaitable, Callable, Optional, Tuple, Dict, Any
from urllib.parse import unquote

from .events import Body, EndBody, Event, InformationalResponse, Request, Response, StreamClosed
Expand Down Expand Up @@ -47,6 +47,7 @@ def __init__(
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
stream_id: int,
app_state: Dict[str, Any],
) -> None:
self.app = app
self.client = client
Expand All @@ -62,6 +63,7 @@ def __init__(
self.state = ASGIHTTPState.REQUEST
self.stream_id = stream_id
self.task_group = task_group
self.app_state = app_state

@property
def idle(self) -> bool:
Expand All @@ -87,6 +89,7 @@ async def handle(self, event: Event) -> None:
"client": self.client,
"server": self.server,
"extensions": {},
"state": self.app_state,
}
if event.http_version in PUSH_VERSIONS:
self.scope["extensions"]["http.response.push"] = {}
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/quic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import Awaitable, Callable, Dict, Optional, Tuple
from typing import Awaitable, Callable, Dict, Optional, Tuple, Any

from aioquic.buffer import Buffer
from aioquic.h3.connection import H3_ALPN
Expand Down Expand Up @@ -34,6 +34,7 @@ def __init__(
task_group: TaskGroup,
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
app_state: Dict[str, Any],
) -> None:
self.app = app
self.config = config
Expand All @@ -46,6 +47,7 @@ def __init__(

self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False)
self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile)
self.app_state = app_state

@property
def idle(self) -> bool:
Expand Down Expand Up @@ -110,6 +112,7 @@ async def _handle_events(
self.server,
connection,
partial(self.send_all, connection),
self.app_state,
)
elif isinstance(event, ConnectionIdIssued):
self.connections[event.connection_id] = connection
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/ws_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import auto, Enum
from io import BytesIO, StringIO
from time import time
from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, Union
from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, Union, Dict, Any
from urllib.parse import unquote

from wsproto.connection import Connection, ConnectionState, ConnectionType
Expand Down Expand Up @@ -172,6 +172,7 @@ def __init__(
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
stream_id: int,
app_state: Dict[str, Any],
) -> None:
self.app = app
self.app_put: Optional[Callable] = None
Expand All @@ -193,6 +194,7 @@ def __init__(

self.connection: Connection
self.handshake: Handshake
self.app_state = app_state

@property
def idle(self) -> bool:
Expand All @@ -219,6 +221,7 @@ async def handle(self, event: Event) -> None:
"server": self.server,
"subprotocols": self.handshake.subprotocols or [],
"extensions": {"websocket.http.response": {}},
"state": self.app_state,
}

if not valid_server_name(self.config, event):
Expand Down
Loading