Patch summary (free text before the first "diff --git" header is ignored by
patch tools):

* a2wsgi/types.py is deleted; its loose MutableMapping/Any aliases are
  replaced by two new modules, a2wsgi/asgi_typing.py (TypedDict definitions
  for ASGI scopes and events) and a2wsgi/wsgi_typing.py (TypedDict/Protocol
  definitions for the WSGI environ, streams and start_response).
* a2wsgi/asgi.py: build_scope() collects headers in one pass over environ,
  falls back to "" for a missing PATH_INFO / QUERY_STRING, adds an empty
  "extensions" dict, and only sets "client" when both REMOTE_ADDR and
  REMOTE_PORT are present.
* a2wsgi/wsgi.py: build_environ() stores SERVER_PORT as a string and also
  sets REMOTE_PORT; start_response() now returns a write callable; wsgi()
  calls close() on the application iterable when one is available.

Review changes relative to the patch as submitted:

* wsgi_typing.StartResponse.__call__ is annotated with
  Optional[ExceptionInfo] instead of "ExceptionInfo | None". The annotation
  is evaluated when the method is defined, and "|" between a typing alias
  and None raises TypeError before Python 3.10. Optional is already imported
  by that module. Because of this, the post-image blob id on the
  wsgi_typing.py "index" line no longer matches the content.
* "pre- read" in the InputStream docstring is corrected to "pre-read".

diff --git a/a2wsgi/asgi.py b/a2wsgi/asgi.py
index 6fb549a..1bda6a5 100644
--- a/a2wsgi/asgi.py
+++ b/a2wsgi/asgi.py
@@ -3,11 +3,11 @@
 import threading
 from http import HTTPStatus
 from io import BytesIO
-from itertools import chain
 from typing import Any, Coroutine, Deque, Iterable, Optional
 from typing import cast as typing_cast
 
-from .types import ASGIApp, Environ, ExcInfo, Message, Scope, StartResponse
+from .asgi_typing import HTTPScope, ASGIApp, ReceiveEvent, SendEvent
+from .wsgi_typing import Environ, StartResponse, ExceptionInfo, IterableChunks
 
 
 class defaultdict(dict):
@@ -70,46 +70,45 @@ def wait(self) -> Any:
         return message
 
 
-def build_scope(environ: Environ) -> Scope:
+def build_scope(environ: Environ) -> HTTPScope:
     headers = [
-        (key.lower().replace("_", "-").encode("latin-1"), value.encode("latin-1"))
-        for key, value in chain(
-            (
-                (key[5:], value)
-                for key, value in environ.items()
-                if key.startswith("HTTP_")
-                and key not in ("HTTP_CONTENT_TYPE", "HTTP_CONTENT_LENGTH")
-            ),
-            (
-                (key, value)
-                for key, value in environ.items()
-                if key in ("CONTENT_TYPE", "CONTENT_LENGTH")
-            ),
+        (
+            (key[5:] if key.startswith("HTTP_") else key)
+            .lower()
+            .replace("_", "-")
+            .encode("latin-1"),
+            value.encode("latin-1"),  # type: ignore
         )
+        for key, value in environ.items()
+        if (
+            key.startswith("HTTP_")
+            and key not in ("HTTP_CONTENT_TYPE", "HTTP_CONTENT_LENGTH")
+        )
+        or key in ("CONTENT_TYPE", "CONTENT_LENGTH")
     ]
 
-    if environ.get("REMOTE_ADDR") and environ.get("REMOTE_PORT"):
-        client = (environ["REMOTE_ADDR"], int(environ["REMOTE_PORT"]))
-    else:
-        client = None
-
     root_path = environ.get("SCRIPT_NAME", "").encode("latin1").decode("utf8")
-    path = root_path + environ["PATH_INFO"].encode("latin1").decode("utf8")
+    path = root_path + environ.get("PATH_INFO", "").encode("latin1").decode("utf8")
 
-    return {
-        "wsgi_environ": environ,
+    scope: HTTPScope = {
+        "wsgi_environ": environ,  # type: ignore a2wsgi
         "type": "http",
         "asgi": {"version": "3.0", "spec_version": "3.0"},
         "http_version": environ.get("SERVER_PROTOCOL", "http/1.0").split("/")[1],
         "method": environ["REQUEST_METHOD"],
         "scheme": environ.get("wsgi.url_scheme", "http"),
         "path": path,
-        "query_string": environ["QUERY_STRING"].encode("ascii"),
+        "query_string": environ.get("QUERY_STRING", "").encode("ascii"),
         "root_path": root_path,
-        "client": client,
         "server": (environ["SERVER_NAME"], int(environ["SERVER_PORT"])),
         "headers": headers,
+        "extensions": {},
     }
+    if environ.get("REMOTE_ADDR") and environ.get("REMOTE_PORT"):
+        client = (environ.get("REMOTE_ADDR", ""), int(environ.get("REMOTE_PORT", "0")))
+        scope["client"] = client
+
+    return scope
 
 
 class ASGIMiddleware:
@@ -164,12 +163,12 @@ def _init_async_lock():
         self.asgi_done = threading.Event()
         self.wsgi_should_stop: bool = False
 
-    async def asgi_receive(self) -> Message:
+    async def asgi_receive(self) -> ReceiveEvent:
         async with self.async_lock:
             self.sync_event.set({"type": "receive"})
             return await self.async_event.wait()
 
-    async def asgi_send(self, message: Message) -> None:
+    async def asgi_send(self, message: SendEvent) -> None:
         async with self.async_lock:
             self.sync_event.set(message)
             await self.async_event.wait()
@@ -201,7 +200,7 @@ def start_asgi_app(self, environ: Environ) -> asyncio.Task:
 
     def __call__(
         self, environ: Environ, start_response: StartResponse
-    ) -> Iterable[bytes]:
+    ) -> IterableChunks:
         read_count: int = 0
         body = environ["wsgi.input"] or BytesIO()
         content_length = int(environ.get("CONTENT_LENGTH", None) or 0)
@@ -262,8 +261,8 @@ def __call__(
         yield b""
 
     def error_response(
-        self, start_response: StartResponse, exception: ExcInfo
-    ) -> Iterable[bytes]:
+        self, start_response: StartResponse, exception: ExceptionInfo
+    ) -> IterableChunks:
         start_response(
             "500 Internal Server Error",
             [
diff --git a/a2wsgi/asgi_typing.py b/a2wsgi/asgi_typing.py
new file mode 100644
index 0000000..b89db48
--- /dev/null
+++ b/a2wsgi/asgi_typing.py
@@ -0,0 +1,182 @@
+"""
+https://asgi.readthedocs.io/en/latest/specs/index.html
+"""
+import sys
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Literal,
+    Optional,
+    Tuple,
+    TypedDict,
+    Union,
+)
+
+if sys.version_info >= (3, 11):
+    from typing import NotRequired
+else:
+    from typing_extensions import NotRequired
+
+
+class ASGIVersions(TypedDict):
+    spec_version: str
+    version: Literal["3.0"]
+
+
+class HTTPScope(TypedDict):
+    type: Literal["http"]
+    asgi: ASGIVersions
+    http_version: str
+    method: str
+    scheme: str
+    path: str
+    raw_path: NotRequired[bytes]
+    query_string: bytes
+    root_path: str
+    headers: Iterable[Tuple[bytes, bytes]]
+    client: NotRequired[Tuple[str, int]]
+    server: NotRequired[Tuple[str, Optional[int]]]
+    state: NotRequired[Dict[str, Any]]
+    extensions: NotRequired[Dict[str, Dict[object, object]]]
+
+
+class WebSocketScope(TypedDict):
+    type: Literal["websocket"]
+    asgi: ASGIVersions
+    http_version: str
+    scheme: str
+    path: str
+    raw_path: bytes
+    query_string: bytes
+    root_path: str
+    headers: Iterable[Tuple[bytes, bytes]]
+    client: NotRequired[Tuple[str, int]]
+    server: NotRequired[Tuple[str, Optional[int]]]
+    subprotocols: Iterable[str]
+    state: NotRequired[Dict[str, Any]]
+    extensions: NotRequired[Dict[str, Dict[object, object]]]
+
+
+class LifespanScope(TypedDict):
+    type: Literal["lifespan"]
+    asgi: ASGIVersions
+    state: NotRequired[Dict[str, Any]]
+
+
+WWWScope = Union[HTTPScope, WebSocketScope]
+Scope = Union[HTTPScope, WebSocketScope, LifespanScope]
+
+
+class HTTPRequestEvent(TypedDict):
+    type: Literal["http.request"]
+    body: bytes
+    more_body: NotRequired[bool]
+
+
+class HTTPResponseStartEvent(TypedDict):
+    type: Literal["http.response.start"]
+    status: int
+    headers: NotRequired[Iterable[Tuple[bytes, bytes]]]
+    trailers: NotRequired[bool]
+
+
+class HTTPResponseBodyEvent(TypedDict):
+    type: Literal["http.response.body"]
+    body: NotRequired[bytes]
+    more_body: NotRequired[bool]
+
+
+class HTTPDisconnectEvent(TypedDict):
+    type: Literal["http.disconnect"]
+
+
+class WebSocketConnectEvent(TypedDict):
+    type: Literal["websocket.connect"]
+
+
+class WebSocketAcceptEvent(TypedDict):
+    type: Literal["websocket.accept"]
+    subprotocol: NotRequired[str]
+    headers: NotRequired[Iterable[Tuple[bytes, bytes]]]
+
+
+class WebSocketReceiveEvent(TypedDict):
+    type: Literal["websocket.receive"]
+    bytes: NotRequired[bytes]
+    text: NotRequired[str]
+
+
+class WebSocketSendEvent(TypedDict):
+    type: Literal["websocket.send"]
+    bytes: NotRequired[bytes]
+    text: NotRequired[str]
+
+
+class WebSocketDisconnectEvent(TypedDict):
+    type: Literal["websocket.disconnect"]
+    code: int
+
+
+class WebSocketCloseEvent(TypedDict):
+    type: Literal["websocket.close"]
+    code: NotRequired[int]
+    reason: NotRequired[str]
+
+
+class LifespanStartupEvent(TypedDict):
+    type: Literal["lifespan.startup"]
+
+
+class LifespanShutdownEvent(TypedDict):
+    type: Literal["lifespan.shutdown"]
+
+
+class LifespanStartupCompleteEvent(TypedDict):
+    type: Literal["lifespan.startup.complete"]
+
+
+class LifespanStartupFailedEvent(TypedDict):
+    type: Literal["lifespan.startup.failed"]
+    message: str
+
+
+class LifespanShutdownCompleteEvent(TypedDict):
+    type: Literal["lifespan.shutdown.complete"]
+
+
+class LifespanShutdownFailedEvent(TypedDict):
+    type: Literal["lifespan.shutdown.failed"]
+    message: str
+
+
+ReceiveEvent = Union[
+    HTTPRequestEvent,
+    HTTPDisconnectEvent,
+    WebSocketConnectEvent,
+    WebSocketReceiveEvent,
+    WebSocketDisconnectEvent,
+    LifespanStartupEvent,
+    LifespanShutdownEvent,
+]
+
+SendEvent = Union[
+    HTTPResponseStartEvent,
+    HTTPResponseBodyEvent,
+    HTTPDisconnectEvent,
+    WebSocketAcceptEvent,
+    WebSocketSendEvent,
+    WebSocketCloseEvent,
+    LifespanStartupCompleteEvent,
+    LifespanStartupFailedEvent,
+    LifespanShutdownCompleteEvent,
+    LifespanShutdownFailedEvent,
+]
+
+Receive = Callable[[], Awaitable[ReceiveEvent]]
+
+Send = Callable[[SendEvent], Awaitable[None]]
+
+ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
diff --git a/a2wsgi/types.py b/a2wsgi/types.py
deleted file mode 100644
index 79d0f0b..0000000
--- a/a2wsgi/types.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from types import TracebackType
-from typing import (
-    Any,
-    Awaitable,
-    Callable,
-    Iterable,
-    MutableMapping,
-    Optional,
-    Tuple,
-    Type,
-)
-
-ExcInfo = Tuple[Type[BaseException], BaseException, Optional[TracebackType]]
-
-Message = MutableMapping[str, Any]
-
-Scope = MutableMapping[str, Any]
-
-Receive = Callable[[], Awaitable[Message]]
-
-Send = Callable[[Message], Awaitable[None]]
-
-ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
-
-Environ = MutableMapping[str, Any]
-
-StartResponse = Callable[[str, Iterable[Tuple[str, str]], Optional[ExcInfo]], None]
-
-WSGIApp = Callable[[Environ, StartResponse], Iterable[bytes]]
diff --git a/a2wsgi/wsgi.py b/a2wsgi/wsgi.py
index 999ddd3..8771fe2 100644
--- a/a2wsgi/wsgi.py
+++ b/a2wsgi/wsgi.py
@@ -7,7 +7,8 @@
 import typing
 from concurrent.futures import ThreadPoolExecutor
 
-from .types import Environ, Message, Receive, Scope, Send, StartResponse, WSGIApp
+from .asgi_typing import HTTPScope, Scope, Receive, Send, SendEvent
+from .wsgi_typing import Environ, StartResponse, ExceptionInfo, WSGIApp, WriteCallable
 
 
 class Body:
@@ -87,21 +88,21 @@ def unicode_to_wsgi(u):
     return u.encode(ENC, ESC).decode("iso-8859-1")
 
 
-def build_environ(scope: Scope, body: Body) -> Environ:
+def build_environ(scope: HTTPScope, body: Body) -> Environ:
     """
     Builds a scope and request body into a WSGI environ object.
     """
     script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
     path_info = scope["path"].encode("utf8").decode("latin1")
     if path_info.startswith(script_name):
-        path_info = path_info[len(script_name):]
+        path_info = path_info[len(script_name) :]
 
     script_name_environ_var = os.environ.get("SCRIPT_NAME", "")
     if script_name_environ_var:
         script_name = unicode_to_wsgi(script_name_environ_var)
 
-    environ = {
-        "asgi.scope": scope,
+    environ: Environ = {
+        "asgi.scope": scope,  # type: ignore a2wsgi
         "REQUEST_METHOD": scope["method"],
         "SCRIPT_NAME": script_name,
         "PATH_INFO": path_info,
@@ -117,13 +118,16 @@ def build_environ(scope: Scope, body: Body) -> Environ:
     }
 
     # Get server name and port - required in WSGI, not in ASGI
-    server = scope.get("server") or ("localhost", 80)
-    environ["SERVER_NAME"] = server[0]
-    environ["SERVER_PORT"] = server[1]
+    server_addr, server_port = scope.get("server") or ("localhost", 80)
+    environ["SERVER_NAME"] = server_addr
+    environ["SERVER_PORT"] = str(server_port or 0)
 
     # Get client IP address
-    if scope.get("client"):
-        environ["REMOTE_ADDR"] = scope["client"][0]
+    client = scope.get("client")
+    if client is not None:
+        addr, port = client
+        environ["REMOTE_ADDR"] = addr
+        environ["REMOTE_PORT"] = str(port)
 
     # Go through headers and make them into environ entries
     for name, value in scope.get("headers", []):
@@ -177,12 +181,14 @@ def __init__(self, app: WSGIApp, executor: ThreadPoolExecutor) -> None:
         self.app = app
         self.executor = executor
         self.send_event = asyncio.Event()
-        self.send_queue: typing.Deque[typing.Union[Message, None]] = collections.deque()
+        self.send_queue: typing.Deque[
+            typing.Union[SendEvent, None]
+        ] = collections.deque()
         self.loop = asyncio.get_event_loop()
         self.response_started = False
         self.exc_info: typing.Any = None
 
-    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+    async def __call__(self, scope: HTTPScope, receive: Receive, send: Send) -> None:
         body = Body(self.loop, receive)
         environ = build_environ(scope, body)
         sender = None
@@ -204,7 +210,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
             if sender and not sender.done():
                 sender.cancel()  # pragma: no cover
 
-    def send(self, message: typing.Optional[Message]) -> None:
+    def send(self, message: typing.Optional[SendEvent]) -> None:
         self.send_queue.append(message)
         self.loop.call_soon_threadsafe(self.send_event.set)
 
@@ -223,8 +229,8 @@ def start_response(
         self,
         status: str,
         response_headers: typing.List[typing.Tuple[str, str]],
-        exc_info: typing.Any = None,
-    ) -> None:
+        exc_info: typing.Optional[ExceptionInfo] = None,
+    ) -> WriteCallable:
         self.exc_info = exc_info
         if not self.response_started:
             self.response_started = True
@@ -241,9 +247,18 @@ def start_response(
                     "headers": headers,
                 }
             )
+        return lambda chunk: self.send(
+            {"type": "http.response.body", "body": chunk, "more_body": True}
+        )
 
     def wsgi(self, environ: Environ, start_response: StartResponse) -> None:
-        for chunk in self.app(environ, start_response):
-            self.send({"type": "http.response.body", "body": chunk, "more_body": True})
+        iterable = self.app(environ, start_response)
+        try:
+            for chunk in iterable:
+                self.send(
+                    {"type": "http.response.body", "body": chunk, "more_body": True}
+                )
 
-        self.send({"type": "http.response.body", "body": b""})
+            self.send({"type": "http.response.body", "body": b""})
+        finally:
+            getattr(iterable, "close", lambda: None)()
diff --git a/a2wsgi/wsgi_typing.py b/a2wsgi/wsgi_typing.py
new file mode 100644
index 0000000..50952b6
--- /dev/null
+++ b/a2wsgi/wsgi_typing.py
@@ -0,0 +1,194 @@
+"""
+https://peps.python.org/pep-3333/
+"""
+from types import TracebackType
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    List,
+    Optional,
+    Protocol,
+    Tuple,
+    Type,
+    TypedDict,
+)
+
+CGIRequiredDefined = TypedDict(
+    "CGIRequiredDefined",
+    {
+        # The HTTP request method, such as GET or POST. This cannot ever be an
+        # empty string, and so is always required.
+        "REQUEST_METHOD": str,
+        # When HTTP_HOST is not set, these variables can be combined to determine
+        # a default.
+        # SERVER_NAME and SERVER_PORT are required strings and must never be empty.
+        "SERVER_NAME": str,
+        "SERVER_PORT": str,
+        # The version of the protocol the client used to send the request.
+        # Typically this will be something like "HTTP/1.0" or "HTTP/1.1" and
+        # may be used by the application to determine how to treat any HTTP
+        # request headers. (This variable should probably be called REQUEST_PROTOCOL,
+        # since it denotes the protocol used in the request, and is not necessarily
+        # the protocol that will be used in the server's response. However, for
+        # compatibility with CGI we have to keep the existing name.)
+        "SERVER_PROTOCOL": str,
+    },
+)
+
+CGIOptionalDefined = TypedDict(
+    "CGIOptionalDefined",
+    {
+        "REQUEST_URI": str,
+        "REMOTE_ADDR": str,
+        "REMOTE_PORT": str,
+        # The initial portion of the request URL’s “path” that corresponds to the
+        # application object, so that the application knows its virtual “location”.
+        # This may be an empty string, if the application corresponds to the “root”
+        # of the server.
+        "SCRIPT_NAME": str,
+        # The remainder of the request URL’s “path”, designating the virtual
+        # “location” of the request’s target within the application. This may be an
+        # empty string, if the request URL targets the application root and does
+        # not have a trailing slash.
+        "PATH_INFO": str,
+        # The portion of the request URL that follows the “?”, if any. May be empty
+        # or absent.
+        "QUERY_STRING": str,
+        # The contents of any Content-Type fields in the HTTP request. May be empty
+        # or absent.
+        "CONTENT_TYPE": str,
+        # The contents of any Content-Length fields in the HTTP request. May be empty
+        # or absent.
+        "CONTENT_LENGTH": str,
+    },
+    total=False,
+)
+
+
+class InputStream(Protocol):
+    """
+    An input stream (file-like object) from which the HTTP request body bytes can be
+    read. (The server or gateway may perform reads on-demand as requested by the
+    application, or it may pre-read the client's request body and buffer it in-memory
+    or on disk, or use any other technique for providing such an input stream, according
+    to its preference.)
+    """
+
+    def read(self, size: int = -1, /) -> bytes:
+        """
+        The server is not required to read past the client's specified Content-Length,
+        and should simulate an end-of-file condition if the application attempts to read
+        past that point. The application should not attempt to read more data than is
+        specified by the CONTENT_LENGTH variable.
+        A server should allow read() to be called without an argument, and return the
+        remainder of the client's input stream.
+        A server should return empty bytestrings from any attempt to read from an empty
+        or exhausted input stream.
+        """
+        raise NotImplementedError
+
+    def readline(self, limit: int = -1, /) -> bytes:
+        """
+        Servers should support the optional "size" argument to readline(), but as in
+        WSGI 1.0, they are allowed to omit support for it.
+        (In WSGI 1.0, the size argument was not supported, on the grounds that it might
+        have been complex to implement, and was not often used in practice... but then
+        the cgi module started using it, and so practical servers had to start
+        supporting it anyway!)
+        """
+        raise NotImplementedError
+
+    def readlines(self, hint: int = -1, /) -> List[bytes]:
+        """
+        Note that the hint argument to readlines() is optional for both caller and
+        implementer. The application is free not to supply it, and the server or gateway
+        is free to ignore it.
+        """
+        raise NotImplementedError
+
+
+class ErrorStream(Protocol):
+    """
+    An output stream (file-like object) to which error output can be written,
+    for the purpose of recording program or other errors in a standardized and
+    possibly centralized location. This should be a "text mode" stream;
+    i.e., applications should use "\n" as a line ending, and assume that it will
+    be converted to the correct line ending by the server/gateway.
+    (On platforms where the str type is unicode, the error stream should accept
+    and log arbitrary unicode without raising an error; it is allowed, however,
+    to substitute characters that cannot be rendered in the stream's encoding.)
+    For many servers, wsgi.errors will be the server's main error log. Alternatively,
+    this may be sys.stderr, or a log file of some sort. The server's documentation
+    should include an explanation of how to configure this or where to find the
+    recorded output. A server or gateway may supply different error streams to
+    different applications, if this is desired.
+    """
+
+    def flush(self) -> None:
+        """
+        Since the errors stream may not be rewound, servers and gateways are free to
+        forward write operations immediately, without buffering. In this case, the
+        flush() method may be a no-op. Portable applications, however, cannot assume
+        that output is unbuffered or that flush() is a no-op. They must call flush()
+        if they need to ensure that output has in fact been written.
+        (For example, to minimize intermingling of data from multiple processes writing
+        to the same error log.)
+        """
+        raise NotImplementedError
+
+    def write(self, s: str, /) -> Any:
+        raise NotImplementedError
+
+    def writelines(self, seq: List[str], /) -> Any:
+        raise NotImplementedError
+
+
+WSGIDefined = TypedDict(
+    "WSGIDefined",
+    {
+        "wsgi.version": Tuple[int, int],  # e.g. (1, 0)
+        "wsgi.url_scheme": str,  # e.g. "http" or "https"
+        "wsgi.input": InputStream,
+        "wsgi.errors": ErrorStream,
+        # This value should evaluate true if the application object may be simultaneously
+        # invoked by another thread in the same process, and should evaluate false otherwise.
+        "wsgi.multithread": bool,
+        # This value should evaluate true if an equivalent application object may be
+        # simultaneously invoked by another process, and should evaluate false otherwise.
+        "wsgi.multiprocess": bool,
+        # This value should evaluate true if the server or gateway expects (but does
+        # not guarantee!) that the application will only be invoked this one time during
+        # the life of its containing process. Normally, this will only be true for a
+        # gateway based on CGI (or something similar).
+        "wsgi.run_once": bool,
+    },
+)
+
+
+class Environ(CGIRequiredDefined, CGIOptionalDefined, WSGIDefined):
+    """
+    WSGI Environ
+    """
+
+
+ExceptionInfo = Tuple[Type[BaseException], BaseException, Optional[TracebackType]]
+
+# https://peps.python.org/pep-3333/#the-write-callable
+WriteCallable = Callable[[bytes], None]
+
+
+class StartResponse(Protocol):
+    def __call__(
+        self,
+        status: str,
+        response_headers: List[Tuple[str, str]],
+        exc_info: Optional[ExceptionInfo] = None,
+        /,
+    ) -> WriteCallable:
+        raise NotImplementedError
+
+
+IterableChunks = Iterable[bytes]
+
+WSGIApp = Callable[[Environ, StartResponse], IterableChunks]