From 5f1838fa9190e430118407dcd7131846044482c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hu=CC=80ng=20X=2E=20Le=CC=82?= Date: Thu, 18 Apr 2019 12:03:33 +0700 Subject: [PATCH 01/31] Fix watchdog reload worker repeatedly if there are multiple changed files --- sanic/reloader_helpers.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py index b58391f64b..5e1338a43b 100644 --- a/sanic/reloader_helpers.py +++ b/sanic/reloader_helpers.py @@ -121,7 +121,7 @@ def kill_process_children(pid): pass # should signal error here -def kill_program_completly(proc): +def kill_program_completely(proc): """Kill worker and it's child processes and exit. :param proc: worker process (process ID) @@ -141,12 +141,14 @@ def watchdog(sleep_interval): mtimes = {} worker_process = restart_with_reloader() signal.signal( - signal.SIGTERM, lambda *args: kill_program_completly(worker_process) + signal.SIGTERM, lambda *args: kill_program_completely(worker_process) ) signal.signal( - signal.SIGINT, lambda *args: kill_program_completly(worker_process) + signal.SIGINT, lambda *args: kill_program_completely(worker_process) ) while True: + need_reload = False + for filename in _iter_module_files(): try: mtime = os.stat(filename).st_mtime @@ -156,12 +158,13 @@ def watchdog(sleep_interval): old_time = mtimes.get(filename) if old_time is None: mtimes[filename] = mtime - continue elif mtime > old_time: - kill_process_children(worker_process.pid) - worker_process.terminate() - worker_process = restart_with_reloader() mtimes[filename] = mtime - break + need_reload = True + + if need_reload: + kill_process_children(worker_process.pid) + worker_process.terminate() + worker_process = restart_with_reloader() sleep(sleep_interval) From c928c9fe1a1004ca8d004515187a4e86dd29d83d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Mon, 26 Aug 2019 14:53:16 +0300 Subject: [PATCH 02/31] Partially working Trio 
server. --- sanic/app.py | 2 +- sanic/{server.py => server_asyncio.py} | 0 sanic/server_trio.py | 709 +++++++++++++++++++++++++ 3 files changed, 710 insertions(+), 1 deletion(-) rename sanic/{server.py => server_asyncio.py} (100%) create mode 100644 sanic/server_trio.py diff --git a/sanic/app.py b/sanic/app.py index 343ef0cf7e..86ed12feb2 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1008,7 +1008,7 @@ async def handle_request(self, request, write_callback, stream_callback): # - Add exception handling pass else: - write_callback(response) + await write_callback(response) # -------------------------------------------------------------------- # # Testing diff --git a/sanic/server.py b/sanic/server_asyncio.py similarity index 100% rename from sanic/server.py rename to sanic/server_asyncio.py diff --git a/sanic/server_trio.py b/sanic/server_trio.py new file mode 100644 index 0000000000..dcf2860e2f --- /dev/null +++ b/sanic/server_trio.py @@ -0,0 +1,709 @@ +import trio +import os +import socket +import stat +import traceback + +from functools import partial +from inspect import isawaitable +from ipaddress import ip_address +from multiprocessing import Process +from signal import SIG_IGN, SIGINT, SIGTERM, Signals +from signal import signal as signal_func +from time import time + +from httptools import HttpRequestParser +from httptools.parser.errors import HttpParserError +from multidict import CIMultiDict + +from sanic.exceptions import ( + HeaderExpectationFailed, + InvalidUsage, + PayloadTooLarge, + RequestTimeout, + ServerError, + ServiceUnavailable, +) +from sanic.log import access_logger, logger +from sanic.request import EXPECT_HEADER, Request, StreamBuffer +from sanic.response import HTTPResponse + +class Signal: + stopped = False + + +class HttpProtocol: + """ + This class provides a basic HTTP implementation of the sanic framework. 
+ """ + + __slots__ = ( + # app + "app", + # event loop, connection + "loop", + "transport", + "connections", + "signal", + # request params + "parser", + "request", + "url", + "headers", + # request config + "request_handler", + "request_timeout", + "response_timeout", + "keep_alive_timeout", + "request_max_size", + "request_buffer_queue_size", + "request_class", + "is_request_stream", + "router", + "error_handler", + # enable or disable access log purpose + "access_log", + # connection management + "_total_request_size", + "_last_request_time", + "_last_response_time", + "_is_stream_handler", + "_keep_alive", + "_header_fragment", + "state", + "_debug", + "nursery", + ) + + def __init__( + self, + *, + loop, + app, + request_handler, + error_handler, + signal=Signal(), + connections=None, + request_timeout=60, + response_timeout=60, + keep_alive_timeout=5, + request_max_size=None, + request_buffer_queue_size=100, + request_class=None, + access_log=True, + keep_alive=True, + is_request_stream=False, + router=None, + state=None, + debug=False, + **kwargs + ): + self.app = app + self.transport = None + self.request = None + self.parser = None + self.url = None + self.headers = None + self.router = router + self.signal = signal + self.access_log = access_log + self.connections = connections or {} + self.request_handler = request_handler + self.error_handler = error_handler + self.request_timeout = request_timeout + self.request_buffer_queue_size = request_buffer_queue_size + self.response_timeout = response_timeout + self.keep_alive_timeout = keep_alive_timeout + self.request_max_size = request_max_size + self.request_class = request_class or Request + self.is_request_stream = is_request_stream + self._is_stream_handler = False + self._total_request_size = 0 + self._last_request_time = None + self._last_response_time = None + self._keep_alive = keep_alive + self._header_fragment = b"" + self.state = state or {} + if "requests_count" not in self.state: + 
self.state["requests_count"] = 0 + self._debug = debug + + @property + def keep_alive(self): + """ + Check if the connection needs to be kept alive based on the params + attached to the `_keep_alive` attribute, :attr:`Signal.stopped` + and :func:`HttpProtocol.parser.should_keep_alive` + + :return: ``True`` if connection is to be kept alive ``False`` else + """ + return ( + self._keep_alive + and not self.signal.stopped + and self.parser + and self.parser.should_keep_alive() + ) + + # -------------------------------------------- # + # Parsing + # -------------------------------------------- # + + def data_received(self, data): + # Check for the request itself getting too large and exceeding + # memory limits + self._total_request_size += len(data) + if self._total_request_size > self.request_max_size: + self.write_error(PayloadTooLarge("Payload Too Large")) + + # Create parser if this is the first time we're receiving data + if self.parser is None: + assert self.request is None + self.headers = [] + self.parser = HttpRequestParser(self) + + # requests count + self.state["requests_count"] = self.state["requests_count"] + 1 + + # Parse request chunk or close connection + try: + self.parser.feed_data(data) + except HttpParserError: + message = "Bad Request" + if self._debug: + message += "\n" + traceback.format_exc() + self.write_error(InvalidUsage(message)) + + def on_url(self, url): + if not self.url: + self.url = url + else: + self.url += url + + def on_header(self, name, value): + self._header_fragment += name + + if value is not None: + if ( + self._header_fragment == b"Content-Length" + and int(value) > self.request_max_size + ): + self.write_error(PayloadTooLarge("Payload Too Large")) + try: + value = value.decode() + except UnicodeDecodeError: + value = value.decode("latin_1") + self.headers.append( + (self._header_fragment.decode().casefold(), value) + ) + + self._header_fragment = b"" + + def on_headers_complete(self): + self.request = self.request_class( + 
url_bytes=self.url, + headers=CIMultiDict(self.headers), + version=self.parser.get_http_version(), + method=self.parser.get_method().decode(), + transport=self.transport, + app=self.app, + ) + + if self.request.headers.get(EXPECT_HEADER): + self.expect_handler() + + if self.is_request_stream: + self._is_stream_handler = self.router.is_stream_handler( + self.request + ) + if self._is_stream_handler: + self.request.stream = StreamBuffer( + self.request_buffer_queue_size + ) + self.execute_request_handler() + + def expect_handler(self): + """ + Handler for Expect Header. + """ + expect = self.request.headers.get(EXPECT_HEADER) + if self.request.version == "1.1": + if expect.lower() == "100-continue": + self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") + else: + self.write_error( + HeaderExpectationFailed( + "Unknown Expect: {expect}".format(expect=expect) + ) + ) + + def on_body(self, body): + if self.is_request_stream and self._is_stream_handler: + self.nursery.start_soon(self.body_append, body) + else: + self.request.body_push(body) + + async def body_append(self, body): + if self.request.stream.is_full(): + self.transport.pause_reading() + await self.request.stream.put(body) + self.transport.resume_reading() + else: + await self.request.stream.put(body) + + def on_message_complete(self): + # Entire request (headers and whole body) is received. 
+ if self.is_request_stream and self._is_stream_handler: + self._request_stream_task = self.loop.create_task( + self.request.stream.put(None) + ) + return + self.request.body_finish() + self.execute_request_handler() + + def execute_request_handler(self): + """ + Invoke the request handler defined by the + :func:`sanic.app.Sanic.handle_request` method + + :return: None + """ + self.nursery.cancel_scope.deadline = trio.current_time() + self.request_timeout + self._last_request_time = time() + self.nursery.start_soon( + self.request_handler, self.request, self.write_response, self.stream_response + ) + + # -------------------------------------------- # + # Responding + # -------------------------------------------- # + def log_response(self, response): + """ + Helper method provided to enable the logging of responses in case if + the :attr:`HttpProtocol.access_log` is enabled. + + :param response: Response generated for the current request + + :type response: :class:`sanic.response.HTTPResponse` or + :class:`sanic.response.StreamingHTTPResponse` + + :return: None + """ + if self.access_log: + extra = {"status": getattr(response, "status", 0)} + + if isinstance(response, HTTPResponse): + extra["byte"] = len(response.body) + else: + extra["byte"] = -1 + + extra["host"] = "UNKNOWN" + if self.request is not None: + if self.request.ip: + extra["host"] = "{0}:{1}".format( + self.request.ip, self.request.port + ) + + extra["request"] = "{0} {1}".format( + self.request.method, self.request.url + ) + else: + extra["request"] = "nil" + + access_logger.info("", extra=extra) + + async def write_response(self, response): + """ + Writes response content synchronously to the transport. 
+ """ + keep_alive = self.keep_alive + try: + await self.transport.send_all( + response.output( + self.request.version, keep_alive, self.keep_alive_timeout + ) + ) + self.log_response(response) + except AttributeError: + logger.error( + "Invalid response object for url %s, " + "Expected Type: HTTPResponse, Actual Type: %s", + self.url, + type(response), + ) + self.write_error(ServerError("Invalid response type")) + except RuntimeError: + if self._debug: + logger.error( + "Connection lost before response written @ %s", + self.request.ip, + ) + keep_alive = False + except Exception as e: + self.bail_out( + "Writing response failed, connection closed {}".format(repr(e)) + ) + finally: + if not keep_alive: + self.nursery.cancel_scope.cancel() + else: + self._last_response_time = time() + self.cleanup() + + async def drain(self): + await self._not_paused.wait() + + async def push_data(self, data): + self.transport.write(data) + + async def stream_response(self, response): + """ + Streams a response to the client asynchronously. Attaches + the transport to the response so the response consumer can + write to the response as needed. 
+ """ + if self._response_timeout_handler: + self._response_timeout_handler.cancel() + self._response_timeout_handler = None + + try: + keep_alive = self.keep_alive + response.protocol = self + await response.stream( + self.request.version, keep_alive, self.keep_alive_timeout + ) + self.log_response(response) + except AttributeError: + logger.error( + "Invalid response object for url %s, " + "Expected Type: HTTPResponse, Actual Type: %s", + self.url, + type(response), + ) + self.write_error(ServerError("Invalid response type")) + except RuntimeError: + if self._debug: + logger.error( + "Connection lost before response written @ %s", + self.request.ip, + ) + keep_alive = False + except Exception as e: + self.bail_out( + "Writing response failed, connection closed {}".format(repr(e)) + ) + finally: + if not keep_alive: + self.transport.close() + self.transport = None + else: + self.nursery.cancel_scope.deadline = trio.current_time() + self.keep_alive_timeout + self._last_response_time = time() + self.cleanup() + + def write_error(self, exception): + response = None + try: + response = self.error_handler.response(self.request, exception) + version = self.request.version if self.request else "1.1" + self.transport.send_all(response.output(version)) + except RuntimeError: + if self._debug: + logger.error( + "Connection lost before error written @ %s", + self.request.ip if self.request else "Unknown", + ) + except Exception as e: + self.bail_out( + "Writing error failed, connection closed {}".format(repr(e)), + from_error=True, + ) + finally: + if self.parser and ( + self.keep_alive or getattr(response, "status", 0) == 408 + ): + self.log_response(response) + try: + self.transport.close() + except AttributeError: + logger.debug("Connection lost before server could close it.") + + def bail_out(self, message, from_error=False): + """ + In case if the transport pipes are closed and the sanic app encounters + an error while writing data to the transport pipe, we log the 
error + with proper details. + + :param message: Error message to display + :param from_error: If the bail out was invoked while handling an + exception scenario. + + :type message: str + :type from_error: bool + + :return: None + """ + if from_error or self.transport is None or self.transport.is_closing(): + logger.error( + "Transport closed @ %s and exception " + "experienced during error handling", + ( + self.transport.get_extra_info("peername") + if self.transport is not None + else "N/A" + ), + ) + logger.debug("Exception:", exc_info=True) + else: + self.write_error(ServerError(message)) + logger.error(message) + + def cleanup(self): + """This is called when KeepAlive feature is used, + it resets the connection in order for it to be able + to handle receiving another request on the same connection.""" + self.parser = None + self.request = None + self.url = None + self.headers = None + self._total_request_size = 0 + self._is_stream_handler = False + + async def run(self, stream): + async with stream, trio.open_nursery() as self.nursery: + stream.get_extra_info = lambda option: "fake address" + self.transport = stream + self.nursery.cancel_scope.deadline = trio.current_time() + self.request_timeout + async for data in stream: + self.data_received(data) + +async def trigger_events(events): + """Trigger event callbacks (functions or async) + + :param events: one or more sync or async functions to execute + :param loop: event loop + """ + for event in events: + result = event() + if isawaitable(result): + await result + +def bind_socket(host: str, port: int) -> socket: + """Create socket and bind to host. 
+ :param host: IPv4, IPv6, hostname or unix:/tmp/socket may be specified + :param port: IP port number, 0 or None for UNIX sockets + :return: socket.socket object + """ + if host.lower().startswith("unix:"): # UNIX socket + name = host[5:] + sock = socket.socket(socket.AF_UNIX) + if os.path.exists(name) and stat.S_ISSOCK(os.stat(name).st_mode): + os.unlink(name) + oldmask = os.umask(0o111) + try: + sock.bind(name) + finally: + os.umask(oldmask) + return sock + try: # IP address: family must be specified for IPv6 at least + ip = ip_address(host) + host = str(ip) + sock = socket.socket( + socket.AF_INET6 if ip.version == 6 else socket.AF_INET + ) + except ValueError: # Hostname, may become AF_INET or AF_INET6 + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, port)) + return sock + + +def serve_multiple(server_settings, workers): + """Start multiple server processes simultaneously. Stop on interrupt + and terminate signals, and drain connections when complete. + + :param server_settings: kw arguments to be passed to the serve function + :param workers: number of workers to launch + :param stop_event: if provided, is used as a stop signal + :return: + """ + server_settings["reuse_port"] = True + server_settings["run_multiple"] = True + + # Handling when custom socket is not provided. + if server_settings.get("sock") is None: + sock = bind_socket(server_settings["host"], server_settings["port"]) + sock.set_inheritable(True) + server_settings["sock"] = sock + server_settings["host"] = None + server_settings["port"] = None + + processes = [] + + def sig_handler(signal, frame): + logger.info("Received signal %s. 
Shutting down.", Signals(signal).name) + for process in processes: + os.kill(process.pid, SIGTERM) + + signal_func(SIGINT, lambda s, f: sig_handler(s, f)) + signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) + + for _ in range(workers): + process = Process(target=serve, kwargs=server_settings) + process.daemon = True + process.start() + processes.append(process) + + for process in processes: + process.join() + + # the above processes will block this until they're stopped + for process in processes: + process.terminate() + + sock = server_settings.get("sock") + sockname = sock.getsockname() + sock.close() + # Remove UNIX socket + if isinstance(sockname, str): + os.unlink(sockname) + +def serve( + host, + port, + app, + request_handler, + error_handler, + before_start=None, + after_start=None, + before_stop=None, + after_stop=None, + debug=False, + request_timeout=60, + response_timeout=60, + keep_alive_timeout=5, + ssl=None, + sock=None, + request_max_size=None, + request_buffer_queue_size=100, + reuse_port=False, + protocol=HttpProtocol, + backlog=100, + register_sys_signals=True, + run_multiple=False, + run_async=False, + connections=None, + signal=Signal(), + request_class=None, + access_log=True, + keep_alive=True, + is_request_stream=False, + router=None, + websocket_max_size=None, + websocket_max_queue=None, + websocket_read_limit=2 ** 16, + websocket_write_limit=2 ** 16, + state=None, + graceful_shutdown_timeout=15.0, + asyncio_server_kwargs=None, + loop=None, +): + async def handle_connection(stream): + proto = protocol( + connections=connections, + signal=signal, + app=app, + request_handler=request_handler, + error_handler=error_handler, + request_timeout=request_timeout, + response_timeout=response_timeout, + keep_alive_timeout=keep_alive_timeout, + request_max_size=request_max_size, + request_class=request_class, + access_log=access_log, + keep_alive=keep_alive, + is_request_stream=is_request_stream, + router=router, + 
websocket_max_size=websocket_max_size, + websocket_max_queue=websocket_max_queue, + websocket_read_limit=websocket_read_limit, + websocket_write_limit=websocket_write_limit, + state=state, + debug=debug, + loop=None, + ) + await proto.run(stream) + + app.asgi = False + assert not (run_async or run_multiple or asyncio_server_kwargs or loop), "Not implemented" + + server = partial( + runserver, + host, + port, + before_start, + after_start, + before_stop, + after_stop, + ssl, + sock, + reuse_port, + backlog, + register_sys_signals, + run_multiple, + graceful_shutdown_timeout, + handle_connection + ) + return server() if run_async else trio.run(server) + +async def runserver( + host, + port, + before_start, + after_start, + before_stop, + after_stop, + ssl, + sock, + reuse_port, + backlog, + register_sys_signals, + run_multiple, + graceful_shutdown_timeout, + handle_connection +): + async with trio.open_nursery() as main_nursery: + await trigger_events(before_start) + # open_tcp_listeners cannot bind UNIX sockets, so do it here + if host and host.startswith("unix:"): + unix_socket_name = host[5:] + sock, host, port = bind_socket(host, port), None, None + else: + unix_socket_name = None + try: + listeners = await trio.open_tcp_listeners( + host=host, port=port or 8000, backlog=backlog + ) + except Exception: + logger.exception("Unable to start server") + return + await trigger_events(after_start) + pid = os.getpid() + logger.info("Starting worker [%s]", pid) + # Accept connections until a signal is received, then perform graceful exit + async with trio.open_nursery() as acceptor: + acceptor.start_soon(partial( + trio.serve_listeners, + handle_connection, + listeners, + handler_nursery=main_nursery + )) + with trio.open_signal_receiver(SIGINT, SIGTERM) as sigiter: + async for _ in sigiter: + acceptor.cancel_scope.cancel() + break + logger.info("Stopping worker [%s]", pid) + await trigger_events(before_stop) + if unix_socket_name: + os.unlink(unix_socket_name) + 
main_nursery.cancel_scope.deadline = trio.current_time() + graceful_shutdown_timeout + await trigger_events(after_stop) From 2176c8178b6ee44bbd26a4b831b0afb46cb239fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Tue, 27 Aug 2019 16:54:43 +0300 Subject: [PATCH 03/31] A very minimal Trio-based HTTP server. --- sanic/server_trio.py | 734 +++++++++---------------------------------- 1 file changed, 155 insertions(+), 579 deletions(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index dcf2860e2f..a859c9339b 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -2,20 +2,23 @@ import os import socket import stat +import sys +import time import traceback from functools import partial from inspect import isawaitable from ipaddress import ip_address from multiprocessing import Process -from signal import SIG_IGN, SIGINT, SIGTERM, Signals +from signal import SIG_IGN, SIGINT, SIGTERM, SIGHUP, Signals from signal import signal as signal_func -from time import time +from time import time, sleep as time_sleep from httptools import HttpRequestParser from httptools.parser.errors import HttpParserError from multidict import CIMultiDict +from sanic.compat import Header from sanic.exceptions import ( HeaderExpectationFailed, InvalidUsage, @@ -31,450 +34,61 @@ class Signal: stopped = False - class HttpProtocol: - """ - This class provides a basic HTTP implementation of the sanic framework. 
- """ - - __slots__ = ( - # app - "app", - # event loop, connection - "loop", - "transport", - "connections", - "signal", - # request params - "parser", - "request", - "url", - "headers", - # request config - "request_handler", - "request_timeout", - "response_timeout", - "keep_alive_timeout", - "request_max_size", - "request_buffer_queue_size", - "request_class", - "is_request_stream", - "router", - "error_handler", - # enable or disable access log purpose - "access_log", - # connection management - "_total_request_size", - "_last_request_time", - "_last_response_time", - "_is_stream_handler", - "_keep_alive", - "_header_fragment", - "state", - "_debug", - "nursery", - ) - - def __init__( - self, - *, - loop, - app, - request_handler, - error_handler, - signal=Signal(), - connections=None, - request_timeout=60, - response_timeout=60, - keep_alive_timeout=5, - request_max_size=None, - request_buffer_queue_size=100, - request_class=None, - access_log=True, - keep_alive=True, - is_request_stream=False, - router=None, - state=None, - debug=False, - **kwargs - ): - self.app = app - self.transport = None - self.request = None - self.parser = None - self.url = None - self.headers = None - self.router = router - self.signal = signal - self.access_log = access_log - self.connections = connections or {} - self.request_handler = request_handler - self.error_handler = error_handler - self.request_timeout = request_timeout - self.request_buffer_queue_size = request_buffer_queue_size - self.response_timeout = response_timeout - self.keep_alive_timeout = keep_alive_timeout - self.request_max_size = request_max_size - self.request_class = request_class or Request - self.is_request_stream = is_request_stream - self._is_stream_handler = False - self._total_request_size = 0 - self._last_request_time = None - self._last_response_time = None - self._keep_alive = keep_alive - self._header_fragment = b"" - self.state = state or {} - if "requests_count" not in self.state: - 
self.state["requests_count"] = 0 - self._debug = debug - - @property - def keep_alive(self): - """ - Check if the connection needs to be kept alive based on the params - attached to the `_keep_alive` attribute, :attr:`Signal.stopped` - and :func:`HttpProtocol.parser.should_keep_alive` - - :return: ``True`` if connection is to be kept alive ``False`` else - """ - return ( - self._keep_alive - and not self.signal.stopped - and self.parser - and self.parser.should_keep_alive() - ) - - # -------------------------------------------- # - # Parsing - # -------------------------------------------- # - - def data_received(self, data): - # Check for the request itself getting too large and exceeding - # memory limits - self._total_request_size += len(data) - if self._total_request_size > self.request_max_size: - self.write_error(PayloadTooLarge("Payload Too Large")) - - # Create parser if this is the first time we're receiving data - if self.parser is None: - assert self.request is None - self.headers = [] - self.parser = HttpRequestParser(self) - - # requests count - self.state["requests_count"] = self.state["requests_count"] + 1 - - # Parse request chunk or close connection - try: - self.parser.feed_data(data) - except HttpParserError: - message = "Bad Request" - if self._debug: - message += "\n" + traceback.format_exc() - self.write_error(InvalidUsage(message)) - - def on_url(self, url): - if not self.url: - self.url = url - else: - self.url += url - - def on_header(self, name, value): - self._header_fragment += name - - if value is not None: - if ( - self._header_fragment == b"Content-Length" - and int(value) > self.request_max_size - ): - self.write_error(PayloadTooLarge("Payload Too Large")) - try: - value = value.decode() - except UnicodeDecodeError: - value = value.decode("latin_1") - self.headers.append( - (self._header_fragment.decode().casefold(), value) - ) - - self._header_fragment = b"" - - def on_headers_complete(self): - self.request = self.request_class( - 
url_bytes=self.url, - headers=CIMultiDict(self.headers), - version=self.parser.get_http_version(), - method=self.parser.get_method().decode(), - transport=self.transport, - app=self.app, - ) - - if self.request.headers.get(EXPECT_HEADER): - self.expect_handler() - - if self.is_request_stream: - self._is_stream_handler = self.router.is_stream_handler( - self.request - ) - if self._is_stream_handler: - self.request.stream = StreamBuffer( - self.request_buffer_queue_size - ) - self.execute_request_handler() - - def expect_handler(self): - """ - Handler for Expect Header. - """ - expect = self.request.headers.get(EXPECT_HEADER) - if self.request.version == "1.1": - if expect.lower() == "100-continue": - self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") - else: - self.write_error( - HeaderExpectationFailed( - "Unknown Expect: {expect}".format(expect=expect) - ) - ) - - def on_body(self, body): - if self.is_request_stream and self._is_stream_handler: - self.nursery.start_soon(self.body_append, body) - else: - self.request.body_push(body) - - async def body_append(self, body): - if self.request.stream.is_full(): - self.transport.pause_reading() - await self.request.stream.put(body) - self.transport.resume_reading() - else: - await self.request.stream.put(body) - - def on_message_complete(self): - # Entire request (headers and whole body) is received. 
- if self.is_request_stream and self._is_stream_handler: - self._request_stream_task = self.loop.create_task( - self.request.stream.put(None) - ) - return - self.request.body_finish() - self.execute_request_handler() - - def execute_request_handler(self): - """ - Invoke the request handler defined by the - :func:`sanic.app.Sanic.handle_request` method - - :return: None - """ - self.nursery.cancel_scope.deadline = trio.current_time() + self.request_timeout - self._last_request_time = time() - self.nursery.start_soon( - self.request_handler, self.request, self.write_response, self.stream_response - ) - - # -------------------------------------------- # - # Responding - # -------------------------------------------- # - def log_response(self, response): - """ - Helper method provided to enable the logging of responses in case if - the :attr:`HttpProtocol.access_log` is enabled. - - :param response: Response generated for the current request - - :type response: :class:`sanic.response.HTTPResponse` or - :class:`sanic.response.StreamingHTTPResponse` - - :return: None - """ - if self.access_log: - extra = {"status": getattr(response, "status", 0)} - - if isinstance(response, HTTPResponse): - extra["byte"] = len(response.body) - else: - extra["byte"] = -1 - - extra["host"] = "UNKNOWN" - if self.request is not None: - if self.request.ip: - extra["host"] = "{0}:{1}".format( - self.request.ip, self.request.port - ) - - extra["request"] = "{0} {1}".format( - self.request.method, self.request.url - ) - else: - extra["request"] = "nil" - - access_logger.info("", extra=extra) - - async def write_response(self, response): - """ - Writes response content synchronously to the transport. 
- """ - keep_alive = self.keep_alive - try: - await self.transport.send_all( - response.output( - self.request.version, keep_alive, self.keep_alive_timeout - ) - ) - self.log_response(response) - except AttributeError: - logger.error( - "Invalid response object for url %s, " - "Expected Type: HTTPResponse, Actual Type: %s", - self.url, - type(response), - ) - self.write_error(ServerError("Invalid response type")) - except RuntimeError: - if self._debug: - logger.error( - "Connection lost before response written @ %s", - self.request.ip, - ) - keep_alive = False - except Exception as e: - self.bail_out( - "Writing response failed, connection closed {}".format(repr(e)) - ) - finally: - if not keep_alive: - self.nursery.cancel_scope.cancel() - else: - self._last_response_time = time() - self.cleanup() - - async def drain(self): - await self._not_paused.wait() - - async def push_data(self, data): - self.transport.write(data) - - async def stream_response(self, response): - """ - Streams a response to the client asynchronously. Attaches - the transport to the response so the response consumer can - write to the response as needed. 
- """ - if self._response_timeout_handler: - self._response_timeout_handler.cancel() - self._response_timeout_handler = None - - try: - keep_alive = self.keep_alive - response.protocol = self - await response.stream( - self.request.version, keep_alive, self.keep_alive_timeout - ) - self.log_response(response) - except AttributeError: - logger.error( - "Invalid response object for url %s, " - "Expected Type: HTTPResponse, Actual Type: %s", - self.url, - type(response), - ) - self.write_error(ServerError("Invalid response type")) - except RuntimeError: - if self._debug: - logger.error( - "Connection lost before response written @ %s", - self.request.ip, - ) - keep_alive = False - except Exception as e: - self.bail_out( - "Writing response failed, connection closed {}".format(repr(e)) - ) - finally: - if not keep_alive: - self.transport.close() - self.transport = None - else: - self.nursery.cancel_scope.deadline = trio.current_time() + self.keep_alive_timeout - self._last_response_time = time() - self.cleanup() - - def write_error(self, exception): - response = None - try: - response = self.error_handler.response(self.request, exception) - version = self.request.version if self.request else "1.1" - self.transport.send_all(response.output(version)) - except RuntimeError: - if self._debug: - logger.error( - "Connection lost before error written @ %s", - self.request.ip if self.request else "Unknown", - ) - except Exception as e: - self.bail_out( - "Writing error failed, connection closed {}".format(repr(e)), - from_error=True, - ) - finally: - if self.parser and ( - self.keep_alive or getattr(response, "status", 0) == 408 - ): - self.log_response(response) - try: - self.transport.close() - except AttributeError: - logger.debug("Connection lost before server could close it.") - - def bail_out(self, message, from_error=False): - """ - In case if the transport pipes are closed and the sanic app encounters - an error while writing data to the transport pipe, we log the 
error - with proper details. - - :param message: Error message to display - :param from_error: If the bail out was invoked while handling an - exception scenario. - - :type message: str - :type from_error: bool - - :return: None - """ - if from_error or self.transport is None or self.transport.is_closing(): - logger.error( - "Transport closed @ %s and exception " - "experienced during error handling", - ( - self.transport.get_extra_info("peername") - if self.transport is not None - else "N/A" - ), - ) - logger.debug("Exception:", exc_info=True) - else: - self.write_error(ServerError(message)) - logger.error(message) - - def cleanup(self): - """This is called when KeepAlive feature is used, - it resets the connection in order for it to be able - to handle receiving another request on the same connection.""" - self.parser = None - self.request = None - self.url = None - self.headers = None - self._total_request_size = 0 - self._is_stream_handler = False + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.request_class = self.request_class or Request async def run(self, stream): async with stream, trio.open_nursery() as self.nursery: - stream.get_extra_info = lambda option: "fake address" - self.transport = stream - self.nursery.cancel_scope.deadline = trio.current_time() + self.request_timeout - async for data in stream: - self.data_received(data) + try: + while True: + self.nursery.cancel_scope.deadline = trio.current_time() + self.request_timeout + # Read headers + buffer = bytearray() + async for data in stream: + prevpos = max(0, len(buffer) - 3) + buffer += data + pos = buffer.find(b"\r\n\r\n", prevpos) + if pos > 0: break # End of headers + if buffer > request_max_size: + self.error_response("request too large") + return + else: + return # Peer closed connection + headers = buffer[:pos] + del buffer[:pos + 4] + try: + headers = headers.decode() + except UnicodeDecodeError: + headers = headers.decode("ISO-8859-1") + req, *headers = 
headers.split("\r\n") + method, path, version = req.split(" ") + version = version[5:] + assert version == "1.1" + headers = dict(h.split(": ", 1) for h in headers) + self.nursery.cancel_scope.deadline = trio.current_time() + self.response_timeout + request = self.request_class( + url_bytes=path.encode(), + headers=Header(headers), + version=version, + method=method, + transport=None, + app=self.app, + ) + keep_alive = True + async def write_response(response): + await stream.send_all(response.output( + request.version, keep_alive, self.keep_alive_timeout + )) + await self.request_handler(request, write_response, None) + except trio.BrokenResourceError: + pass # Connection reset by peer + except Exception: + logger.exception("Error in server") + + async def error_response(self, message): + pass async def trigger_events(events): """Trigger event callbacks (functions or async) @@ -487,85 +101,9 @@ async def trigger_events(events): if isawaitable(result): await result -def bind_socket(host: str, port: int) -> socket: - """Create socket and bind to host. 
- :param host: IPv4, IPv6, hostname or unix:/tmp/socket may be specified - :param port: IP port number, 0 or None for UNIX sockets - :return: socket.socket object - """ - if host.lower().startswith("unix:"): # UNIX socket - name = host[5:] - sock = socket.socket(socket.AF_UNIX) - if os.path.exists(name) and stat.S_ISSOCK(os.stat(name).st_mode): - os.unlink(name) - oldmask = os.umask(0o111) - try: - sock.bind(name) - finally: - os.umask(oldmask) - return sock - try: # IP address: family must be specified for IPv6 at least - ip = ip_address(host) - host = str(ip) - sock = socket.socket( - socket.AF_INET6 if ip.version == 6 else socket.AF_INET - ) - except ValueError: # Hostname, may become AF_INET or AF_INET6 - sock = socket.socket() - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind((host, port)) - return sock - def serve_multiple(server_settings, workers): - """Start multiple server processes simultaneously. Stop on interrupt - and terminate signals, and drain connections when complete. - - :param server_settings: kw arguments to be passed to the serve function - :param workers: number of workers to launch - :param stop_event: if provided, is used as a stop signal - :return: - """ - server_settings["reuse_port"] = True - server_settings["run_multiple"] = True - - # Handling when custom socket is not provided. - if server_settings.get("sock") is None: - sock = bind_socket(server_settings["host"], server_settings["port"]) - sock.set_inheritable(True) - server_settings["sock"] = sock - server_settings["host"] = None - server_settings["port"] = None - - processes = [] - - def sig_handler(signal, frame): - logger.info("Received signal %s. 
Shutting down.", Signals(signal).name) - for process in processes: - os.kill(process.pid, SIGTERM) - - signal_func(SIGINT, lambda s, f: sig_handler(s, f)) - signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) - - for _ in range(workers): - process = Process(target=serve, kwargs=server_settings) - process.daemon = True - process.start() - processes.append(process) - - for process in processes: - process.join() - - # the above processes will block this until they're stopped - for process in processes: - process.terminate() - - sock = server_settings.get("sock") - sockname = sock.getsockname() - sock.close() - # Remove UNIX socket - if isinstance(sockname, str): - os.unlink(sockname) + return serve(**server_settings, workers=workers) def serve( host, @@ -585,7 +123,6 @@ def serve( sock=None, request_max_size=None, request_buffer_queue_size=100, - reuse_port=False, protocol=HttpProtocol, backlog=100, register_sys_signals=True, @@ -605,6 +142,7 @@ def serve( state=None, graceful_shutdown_timeout=15.0, asyncio_server_kwargs=None, + workers=1, loop=None, ): async def handle_connection(stream): @@ -629,81 +167,119 @@ async def handle_connection(stream): websocket_write_limit=websocket_write_limit, state=state, debug=debug, - loop=None, ) await proto.run(stream) app.asgi = False assert not (run_async or run_multiple or asyncio_server_kwargs or loop), "Not implemented" + acceptor = partial( + runaccept, + before_start=before_start, + after_start=after_start, + before_stop=before_stop, + after_stop=after_stop, + handle_connection=handle_connection, + graceful_shutdown_timeout=graceful_shutdown_timeout, + ) + server = partial( runserver, - host, - port, - before_start, - after_start, - before_stop, - after_stop, - ssl, - sock, - reuse_port, - backlog, - register_sys_signals, - run_multiple, - graceful_shutdown_timeout, - handle_connection + acceptor=acceptor, + host=host, + port=port, + ssl=ssl, + sock=sock, + backlog=backlog, + workers=workers, ) - return server() if 
run_async else trio.run(server) + return server() if run_async else server() -async def runserver( +def runserver( + acceptor, host, port, - before_start, - after_start, - before_stop, - after_stop, ssl, sock, - reuse_port, backlog, - register_sys_signals, - run_multiple, - graceful_shutdown_timeout, - handle_connection + workers, ): - async with trio.open_nursery() as main_nursery: - await trigger_events(before_start) - # open_tcp_listeners cannot bind UNIX sockets, so do it here - if host and host.startswith("unix:"): - unix_socket_name = host[5:] - sock, host, port = bind_socket(host, port), None, None - else: - unix_socket_name = None - try: - listeners = await trio.open_tcp_listeners( - host=host, port=port or 8000, backlog=backlog - ) - except Exception: - logger.exception("Unable to start server") - return - await trigger_events(after_start) + if host and host.startswith("unix:"): + open_listeners = partial( + # Not Implemented: open_unix_listeners, path=host[5:], backlog=backlog + ) + else: + open_listeners = partial( + trio.open_tcp_listeners, + host=host, port=port or 8000, backlog=backlog + ) + try: + listeners = trio.run(open_listeners) + except Exception: + logger.exception("Unable to start server") + return + if ssl: + listeners = [ + trio.SSLListener(l, ssl, https_compatible=True) for l in listeners + ] + master_pid = os.getpid() + runworker = lambda: trio.run(acceptor, listeners, master_pid) + processes = [] + # Setup signal handlers to avoid crashing + sig = None + def handler(s, tb): + nonlocal sig + sig = s + for s in (SIGINT, SIGTERM, SIGHUP): + signal_func(s, handler) + + if workers: + for l in listeners: l.socket.set_inheritable(True) + while True: + while len(processes) < workers: + p = Process(target=runworker) + p.daemon = True + p.start() + processes.append(p) + time_sleep(0.1) # Poll for dead processes + processes = [p for p in processes if p.is_alive()] + s, sig = sig, None + if not s: + continue + for p in processes: os.kill(p.pid, 
SIGHUP) + if s in (SIGINT, SIGTERM): + break + for l in listeners: trio.run(l.aclose) + for p in processes: p.join() + else: + runworker() + + +async def runaccept(listeners, master_pid, before_start, after_start, before_stop, after_stop, handle_connection, graceful_shutdown_timeout): + try: pid = os.getpid() logger.info("Starting worker [%s]", pid) - # Accept connections until a signal is received, then perform graceful exit - async with trio.open_nursery() as acceptor: - acceptor.start_soon(partial( - trio.serve_listeners, - handle_connection, - listeners, - handler_nursery=main_nursery - )) - with trio.open_signal_receiver(SIGINT, SIGTERM) as sigiter: - async for _ in sigiter: + async with trio.open_nursery() as main_nursery: + await trigger_events(before_start) + # Accept connections until a signal is received, then perform graceful exit + async with trio.open_nursery() as acceptor: + acceptor.start_soon(partial( + trio.serve_listeners, + handler=handle_connection, + listeners=listeners, + handler_nursery=main_nursery + )) + await trigger_events(after_start) + # Wait for a signal and then exit gracefully + with trio.open_signal_receiver(SIGINT, SIGTERM, SIGHUP) as sigiter: + s = await sigiter.__anext__() + logger.info(f"Received {Signals(s).name}") + if s != SIGHUP: + os.kill(master_pid, SIGTERM) acceptor.cancel_scope.cancel() - break - logger.info("Stopping worker [%s]", pid) - await trigger_events(before_stop) - if unix_socket_name: - os.unlink(unix_socket_name) - main_nursery.cancel_scope.deadline = trio.current_time() + graceful_shutdown_timeout - await trigger_events(after_stop) + main_nursery.cancel_scope.deadline = trio.current_time() + graceful_shutdown_timeout + await trigger_events(before_stop) + await trigger_events(after_stop) + logger.info(f"Gracefully finished worker [{pid}]") + except BaseException as e: + logger.exception(f"Stopped worker [{pid}]") From 387287d16c5cb730ada947af4bcf269d6ad46f85 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Tue, 27 Aug 2019 17:57:01 +0300 Subject: [PATCH 04/31] Add quick termination of idle connections on shutdown. --- sanic/server_trio.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index a859c9339b..4a3e832280 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -34,19 +34,23 @@ class Signal: stopped = False +idle_connections = set() + class HttpProtocol: def __init__(self, **kwargs): self.__dict__.update(kwargs) self.request_class = self.request_class or Request async def run(self, stream): - async with stream, trio.open_nursery() as self.nursery: - try: + try: + async with stream, trio.open_nursery() as self.nursery: while True: self.nursery.cancel_scope.deadline = trio.current_time() + self.request_timeout # Read headers buffer = bytearray() + idle_connections.add(self.nursery.cancel_scope) async for data in stream: + idle_connections.remove(self.nursery.cancel_scope) prevpos = max(0, len(buffer) - 3) buffer += data pos = buffer.find(b"\r\n\r\n", prevpos) @@ -82,10 +86,12 @@ async def write_response(response): request.version, keep_alive, self.keep_alive_timeout )) await self.request_handler(request, write_response, None) - except trio.BrokenResourceError: - pass # Connection reset by peer - except Exception: - logger.exception("Error in server") + except trio.BrokenResourceError: + pass # Connection reset by peer + except Exception: + logger.exception("Error in server") + finally: + idle_connections.remove(self.nursery.cancel_scope) async def error_response(self, message): pass @@ -277,7 +283,9 @@ async def runaccept(listeners, master_pid, before_start, after_start, before_sto if s != SIGHUP: os.kill(master_pid, SIGTERM) acceptor.cancel_scope.cancel() - main_nursery.cancel_scope.deadline = trio.current_time() + graceful_shutdown_timeout + now = trio.current_time() + for c in idle_connections: c.deadline = now + 0.1 + 
main_nursery.cancel_scope.deadline = now + graceful_shutdown_timeout await trigger_events(before_stop) await trigger_events(after_stop) logger.info(f"Gracefully finished worker [{pid}]") From 94cc4c2fda8eed05073feb7d63247f1c5b6bcbb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Tue, 27 Aug 2019 18:11:28 +0300 Subject: [PATCH 05/31] Just cancel 'em straight away. --- sanic/server_trio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 4a3e832280..8c1608cce6 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -284,7 +284,7 @@ async def runaccept(listeners, master_pid, before_start, after_start, before_sto os.kill(master_pid, SIGTERM) acceptor.cancel_scope.cancel() now = trio.current_time() - for c in idle_connections: c.deadline = now + 0.1 + for c in idle_connections: c.cancel() main_nursery.cancel_scope.deadline = now + graceful_shutdown_timeout await trigger_events(before_stop) await trigger_events(after_stop) From 26e43755b73f46050dd7b06e9fe14c8870f9ffcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Tue, 27 Aug 2019 18:14:46 +0300 Subject: [PATCH 06/31] Allow auto_reload with multiple workers and use serve workers= argument instead of serve_multiple --- sanic/app.py | 27 ++++++++++++--------------- sanic/server_trio.py | 3 --- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 86ed12feb2..538789cc93 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -24,7 +24,7 @@ from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.router import Router -from sanic.server import HttpProtocol, Signal, serve, serve_multiple +from sanic.server import HttpProtocol, Signal, serve from sanic.static import register as static_register from sanic.testing import SanicASGITestClient, SanicTestClient from sanic.views 
import CompositionView @@ -1120,21 +1120,18 @@ def run( try: self.is_running = True - if workers == 1: - if auto_reload and os.name != "posix": - # This condition must be removed after implementing - # auto reloader for other operating systems. - raise NotImplementedError - - if ( - auto_reload - and os.environ.get("SANIC_SERVER_RUNNING") != "true" - ): - reloader_helpers.watchdog(2) - else: - serve(**server_settings) + if auto_reload and os.name != "posix": + # This condition must be removed after implementing + # auto reloader for other operating systems. + raise NotImplementedError + + if ( + auto_reload + and os.environ.get("SANIC_SERVER_RUNNING") != "true" + ): + reloader_helpers.watchdog(2) else: - serve_multiple(server_settings, workers) + serve(**server_settings, workers=workers) except BaseException: error_logger.exception( "Experienced exception while trying to serve" diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 8c1608cce6..7dabde83ba 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -108,9 +108,6 @@ async def trigger_events(events): await result -def serve_multiple(server_settings, workers): - return serve(**server_settings, workers=workers) - def serve( host, port, From 2760a904bb4e830819d6916752bdd19a764fdb36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Wed, 28 Aug 2019 10:35:34 +0300 Subject: [PATCH 07/31] server.py switching (for testing) --- sanic/server.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 sanic/server.py diff --git a/sanic/server.py b/sanic/server.py new file mode 100644 index 0000000000..f02a72f5ca --- /dev/null +++ b/sanic/server.py @@ -0,0 +1 @@ +from sanic.server_trio import * From 183413a345f5459beacfdb3f879d7ceec9677855 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Thu, 29 Aug 2019 13:11:46 +0300 Subject: [PATCH 08/31] HTTP/2 support etc. 
--- sanic/app.py | 1 + sanic/server_trio.py | 276 ++++++++++++++++++++++++++++++++----------- 2 files changed, 209 insertions(+), 68 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 538789cc93..ad2fed4bb1 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1298,6 +1298,7 @@ def _helper( raise ValueError("SSLContext or certificate and key required.") context = create_default_context(purpose=Purpose.CLIENT_AUTH) context.load_cert_chain(cert, keyfile=key) + context.set_alpn_protocols(["h2", "http/1.1"]) ssl = context if stop_event is not None: if debug: diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 7dabde83ba..cd4a1153c5 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -31,71 +31,192 @@ from sanic.request import EXPECT_HEADER, Request, StreamBuffer from sanic.response import HTTPResponse + class Signal: stopped = False + idle_connections = set() + class HttpProtocol: def __init__(self, **kwargs): self.__dict__.update(kwargs) self.request_class = self.request_class or Request + self.stream = None async def run(self, stream): + assert not self.stream + self.stream = stream + self.ssl = False try: async with stream, trio.open_nursery() as self.nursery: - while True: - self.nursery.cancel_scope.deadline = trio.current_time() + self.request_timeout - # Read headers - buffer = bytearray() - idle_connections.add(self.nursery.cancel_scope) - async for data in stream: - idle_connections.remove(self.nursery.cancel_scope) - prevpos = max(0, len(buffer) - 3) - buffer += data - pos = buffer.find(b"\r\n\r\n", prevpos) - if pos > 0: break # End of headers - if buffer > request_max_size: - self.error_response("request too large") - return - else: - return # Peer closed connection - headers = buffer[:pos] - del buffer[:pos + 4] - try: - headers = headers.decode() - except UnicodeDecodeError: - headers = headers.decode("ISO-8859-1") - req, *headers = headers.split("\r\n") - method, path, version = req.split(" ") - version = version[5:] - assert 
version == "1.1" - headers = dict(h.split(": ", 1) for h in headers) - self.nursery.cancel_scope.deadline = trio.current_time() + self.response_timeout - request = self.request_class( - url_bytes=path.encode(), - headers=Header(headers), - version=version, - method=method, - transport=None, - app=self.app, - ) - keep_alive = True - async def write_response(response): - await stream.send_all(response.output( - request.version, keep_alive, self.keep_alive_timeout - )) - await self.request_handler(request, write_response, None) + await self.http1() + self.nursery.cancel_scope.cancel() except trio.BrokenResourceError: pass # Connection reset by peer except Exception: logger.exception("Error in server") finally: - idle_connections.remove(self.nursery.cancel_scope) + idle_connections.discard(self.nursery.cancel_scope) async def error_response(self, message): pass + async def http1(self): + while True: + self.nursery.cancel_scope.deadline = ( + trio.current_time() + self.request_timeout + ) + # Read headers + buffer = bytearray() + idle_connections.add(self.nursery.cancel_scope) + async for data in self.stream: + idle_connections.discard(self.nursery.cancel_scope) + prevpos = max(0, len(buffer) - 3) + buffer += data + pos = buffer.find(b"\r\n\r\n", prevpos) + if pos > 0: + break # End of headers + if buffer > request_max_size: + self.error_response("request too large") + return + else: + return # Peer closed connection + headers = buffer[:pos] + if headers == b"PRI * HTTP/2.0": # HTTP2 without Upgrade (nghttp) + return await self.http2(data_received=buffer) + del buffer[: pos + 4] + try: + headers = headers.decode() + except UnicodeDecodeError: + headers = headers.decode("ISO-8859-1") + req, *headers = headers.split("\r\n") + method, path, version = req.split(" ") + version = version[5:] + assert version == "1.1" + headers = [ + (name.casefold(), value) + for name, value in (h.split(": ", 1) for h in headers) + ] + hdrs = {} + for name, value in headers: + old = 
hdrs.get(name) + hdrs[name] = value if old is None else f"{old}, {value}" + if hdrs.get("upgrade") == "h2c": # HTTP2 with Upgrade (curl) + assert ( + not buffer + ), "Extra bytes after upgrade request: {buffer!r}" + return await self.http2(settings_header=hdrs["http2-settings"]) + # Process response + self.nursery.cancel_scope.deadline = ( + trio.current_time() + self.response_timeout + ) + request = self.request_class( + url_bytes=path.encode(), + headers=Header(headers), + version=version, + method=method, + transport=None, + app=self.app, + ) + keep_alive = True + + async def write_response(response): + await self.stream.send_all( + response.output( + request.version, keep_alive, self.keep_alive_timeout + ) + ) + + await self.request_handler(request, write_response, None) + + async def h2_sender(self): + async for _ in self.can_send: + await self.stream.send_all(self.conn.data_to_send()) + + async def http2(self, data_received=None, settings_header=None): + import h2 + from h2.events import ( + RequestReceived, + DataReceived, + ConnectionTerminated, + ) + + config = h2.config.H2Configuration( + client_side=False, + header_encoding="utf-8", + validate_outbound_headers=False, + normalize_outbound_headers=False, + validate_inbound_headers=False, + normalize_inbound_headers=False, + logger=None, # logger + ) + self.conn = h2.connection.H2Connection(config=config) + if settings_header: # Upgrade from HTTP 1.1 + self.conn.initiate_upgrade_connection(settings_header) + await self.stream.send_all( + b"HTTP/1.1 101 Switching Protocols\r\n" + b"Connection: Upgrade\r\n" + b"Upgrade: h2c\r\n\r\n" + c.data_to_send() + ) + else: # h2 ALPN negotiated on SSL init + self.conn.initiate_connection() + await self.stream.send_all(self.conn.data_to_send()) + # A trigger mechanism that ensures promptly sending data from self.conn + # to stream; size must be > 0 to avoid data left unsent in buffer + # when a stream is canceled while awaiting on send_some. 
+ self.send_some, self.can_send = trio.open_memory_channel(1) + self.nursery.start_soon(self.h2_sender) + idle_connections.add(self.nursery.cancel_scope) + async for data in self.stream: + if data_received: + data = data_received + data + data_received = None + # print(">>>", data) + for event in self.conn.receive_data(data): + # print("-*-", event) + if isinstance(event, RequestReceived): + self.nursery.start_soon( + self.h2request, event.stream_id, event.headers + ) + idle_connections.discard(self.nursery.cancel_scope) + if isinstance(event, ConnectionTerminated): + return + await self.send_some.send(...) + + async def h2request(self, stream_id, headers): + hdrs = {} + for name, value in headers: + old = hdrs.get(name) + hdrs[name] = value if old is None else f"{old}, {value}" + # Process response + self.nursery.cancel_scope.deadline = ( + trio.current_time() + self.response_timeout + ) + request = self.request_class( + url_bytes=hdrs.get(":path", "").encode(), + headers=Header(headers), + version="h2", + method=hdrs[":method"], + transport=None, + app=self.app, + ) + + async def write_response(response): + headers = ( + (":status", f"{response.status}"), + ("content-length", f"{len(response.body)}"), + ("content-type", response.content_type), + *response.headers, + ) + self.conn.send_headers(stream_id, headers) + self.conn.send_data(stream_id, response.body, end_stream=True) + await self.send_some.send(...) 
+ + await self.request_handler(request, write_response, None) + + async def trigger_events(events): """Trigger event callbacks (functions or async) @@ -174,7 +295,9 @@ async def handle_connection(stream): await proto.run(stream) app.asgi = False - assert not (run_async or run_multiple or asyncio_server_kwargs or loop), "Not implemented" + assert not ( + run_async or run_multiple or asyncio_server_kwargs or loop + ), "Not implemented" acceptor = partial( runaccept, @@ -198,15 +321,8 @@ async def handle_connection(stream): ) return server() if run_async else server() -def runserver( - acceptor, - host, - port, - ssl, - sock, - backlog, - workers, -): + +def runserver(acceptor, host, port, ssl, sock, backlog, workers): if host and host.startswith("unix:"): open_listeners = partial( # Not Implemented: open_unix_listeners, path=host[5:], backlog=backlog @@ -214,13 +330,17 @@ def runserver( else: open_listeners = partial( trio.open_tcp_listeners, - host=host, port=port or 8000, backlog=backlog + host=host, + port=port or 8000, + backlog=backlog, ) try: listeners = trio.run(open_listeners) except Exception: logger.exception("Unable to start server") return + for l in listeners: + l.socket.set_inheritable(True) if ssl: listeners = [ trio.SSLListener(l, ssl, https_compatible=True) for l in listeners @@ -230,14 +350,15 @@ def runserver( processes = [] # Setup signal handlers to avoid crashing sig = None + def handler(s, tb): nonlocal sig sig = s + for s in (SIGINT, SIGTERM, SIGHUP): signal_func(s, handler) if workers: - for l in listeners: l.socket.set_inheritable(True) while True: while len(processes) < workers: p = Process(target=runworker) @@ -249,16 +370,28 @@ def handler(s, tb): s, sig = sig, None if not s: continue - for p in processes: os.kill(p.pid, SIGHUP) + for p in processes: + os.kill(p.pid, SIGHUP) if s in (SIGINT, SIGTERM): break - for l in listeners: trio.run(l.aclose) - for p in processes: p.join() + for l in listeners: + trio.run(l.aclose) + for p in 
processes: + p.join() else: runworker() -async def runaccept(listeners, master_pid, before_start, after_start, before_stop, after_stop, handle_connection, graceful_shutdown_timeout): +async def runaccept( + listeners, + master_pid, + before_start, + after_start, + before_stop, + after_stop, + handle_connection, + graceful_shutdown_timeout, +): try: pid = os.getpid() logger.info("Starting worker [%s]", pid) @@ -266,23 +399,30 @@ async def runaccept(listeners, master_pid, before_start, after_start, before_sto await trigger_events(before_start) # Accept connections until a signal is received, then perform graceful exit async with trio.open_nursery() as acceptor: - acceptor.start_soon(partial( - trio.serve_listeners, - handler=handle_connection, - listeners=listeners, - handler_nursery=main_nursery - )) + acceptor.start_soon( + partial( + trio.serve_listeners, + handler=handle_connection, + listeners=listeners, + handler_nursery=main_nursery, + ) + ) await trigger_events(after_start) # Wait for a signal and then exit gracefully - with trio.open_signal_receiver(SIGINT, SIGTERM, SIGHUP) as sigiter: + with trio.open_signal_receiver( + SIGINT, SIGTERM, SIGHUP + ) as sigiter: s = await sigiter.__anext__() logger.info(f"Received {Signals(s).name}") if s != SIGHUP: os.kill(master_pid, SIGTERM) acceptor.cancel_scope.cancel() now = trio.current_time() - for c in idle_connections: c.cancel() - main_nursery.cancel_scope.deadline = now + graceful_shutdown_timeout + for c in idle_connections: + c.cancel() + main_nursery.cancel_scope.deadline = ( + now + graceful_shutdown_timeout + ) await trigger_events(before_stop) await trigger_events(after_stop) logger.info(f"Gracefully finished worker [{pid}]") From fac8c36f97353d7ec1971be0f455c08d1ba0856a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Thu, 29 Aug 2019 18:02:25 +0300 Subject: [PATCH 09/31] Protocol and SSL autodetection, cleanup. 
--- sanic/server_trio.py | 203 ++++++++++++++++++++++++++----------------- 1 file changed, 124 insertions(+), 79 deletions(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index cd4a1153c5..ca489526a2 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -7,16 +7,16 @@ import traceback from functools import partial +from h2.config import H2Configuration +from h2.connection import H2Connection +from h2.events import RequestReceived, DataReceived, ConnectionTerminated from inspect import isawaitable from ipaddress import ip_address from multiprocessing import Process from signal import SIG_IGN, SIGINT, SIGTERM, SIGHUP, Signals from signal import signal as signal_func from time import time, sleep as time_sleep - -from httptools import HttpRequestParser from httptools.parser.errors import HttpParserError -from multidict import CIMultiDict from sanic.compat import Header from sanic.exceptions import ( @@ -36,23 +36,103 @@ class Signal: stopped = False +h2config = H2Configuration( + client_side=False, + header_encoding="utf-8", + validate_outbound_headers=False, + normalize_outbound_headers=False, + validate_inbound_headers=False, + normalize_inbound_headers=False, + logger=None, # logger +) + + idle_connections = set() +def parse_h1_request(data: bytes) -> dict: + try: + data = data.decode() + except UnicodeDecodeError: + data = data.decode("ISO-8859-1") + req, *hlines = data.split("\r\n") + method, path, version = req.split(" ") + assert version == "HTTP/1.1" + headers = {":method": method, ":path": path} + for name, value in (h.split(": ", 1) for h in hlines): + name = name.lower() + old = headers.get(name) + headers[name] = value if old is None else f"{old}, {value}" + return headers + + +def push_back(stream, data): + orig_class = stream.__class__ + + class PushbackStream(orig_class): + async def receive_some(self, max_bytes=None): + if max_bytes and max_bytes < len(data): + ret = data[:max_bytes] + del data[:max_bytes] + return ret + 
self.__class__ = orig_class + return data + + stream.__class__ = PushbackStream + + class HttpProtocol: def __init__(self, **kwargs): self.__dict__.update(kwargs) self.request_class = self.request_class or Request self.stream = None + async def ssl_init(self): + self.stream = trio.SSLStream( + self.stream, self.ssl, server_side=True, https_compatible=True + ) + await self.stream.do_handshake() + self.alpn = self.stream.selected_alpn_protocol() + + async def sniff_protocol(self): + buffer = bytearray() + req = await self._receive_request_using(buffer) + if isinstance(req, bytearray): + # HTTP1 but might be Upgrade to websocket or h2c + headers = parse_h1_request(req) + upgrade = headers.get("upgrade") + if upgrade == "h2c": + return self.http2(settings_header=headers["http2-settings"]) + if upgrade == "websocket": + return self.websocket() + return self.http1(headers=headers) + push_back(self.stream, buffer) + if req == "ssl": + if not self.ssl: + raise RuntimeError("Only plain HTTP supported (not SSL).") + await self.ssl_init() + if not self.alpn or self.alpn == "http/1.1": + return self.http1() + if self.alpn == "h2": + return self.http2() + raise RuntimeError(f"Unknown ALPN {self.alpn}") + # HTTP2 (not Upgrade) + if req == "h2": + return self.http2() + async def run(self, stream): assert not self.stream self.stream = stream - self.ssl = False try: async with stream, trio.open_nursery() as self.nursery: - await self.http1() - self.nursery.cancel_scope.cancel() + self.nursery.cancel_scope.deadline = ( + trio.current_time() + self.request_timeout + ) + protocol_coroutine = await self.sniff_protocol() + if not protocol_coroutine: + return + await protocol_coroutine + self.nursery.cancel_scope.cancel() # Terminate all connections except trio.BrokenResourceError: pass # Connection reset by peer except Exception: @@ -60,66 +140,48 @@ async def run(self, stream): finally: idle_connections.discard(self.nursery.cancel_scope) - async def error_response(self, message): - 
pass - - async def http1(self): - while True: - self.nursery.cancel_scope.deadline = ( - trio.current_time() + self.request_timeout - ) - # Read headers - buffer = bytearray() - idle_connections.add(self.nursery.cancel_scope) + async def _receive_request_using(self, buffer: bytearray): + idle_connections.add(self.nursery.cancel_scope) + with trio.fail_after(self.request_timeout): async for data in self.stream: idle_connections.discard(self.nursery.cancel_scope) prevpos = max(0, len(buffer) - 3) buffer += data + if buffer[0] < 0x20: + return "ssl" + if len(buffer) > self.request_max_size: + raise RuntimeError("Request larger than request_max_size") pos = buffer.find(b"\r\n\r\n", prevpos) if pos > 0: - break # End of headers - if buffer > request_max_size: - self.error_response("request too large") + req = buffer[:pos] + if req == b"PRI * HTTP/2.0": return "h2" + del buffer[: pos + 4] + return req + if buffer: + raise RuntimeError("Peer disconnected after {buffer!r}") + + async def http1(self, headers=None): + buffer = bytearray() + while True: + # Process request + if headers is None: + req = await self._receive_request_using(buffer) + if not req: return - else: - return # Peer closed connection - headers = buffer[:pos] - if headers == b"PRI * HTTP/2.0": # HTTP2 without Upgrade (nghttp) - return await self.http2(data_received=buffer) - del buffer[: pos + 4] - try: - headers = headers.decode() - except UnicodeDecodeError: - headers = headers.decode("ISO-8859-1") - req, *headers = headers.split("\r\n") - method, path, version = req.split(" ") - version = version[5:] - assert version == "1.1" - headers = [ - (name.casefold(), value) - for name, value in (h.split(": ", 1) for h in headers) - ] - hdrs = {} - for name, value in headers: - old = hdrs.get(name) - hdrs[name] = value if old is None else f"{old}, {value}" - if hdrs.get("upgrade") == "h2c": # HTTP2 with Upgrade (curl) - assert ( - not buffer - ), "Extra bytes after upgrade request: {buffer!r}" - return await 
self.http2(settings_header=hdrs["http2-settings"]) + headers = parse_h1_request(req) # Process response self.nursery.cancel_scope.deadline = ( trio.current_time() + self.response_timeout ) request = self.request_class( - url_bytes=path.encode(), + url_bytes=headers[":path"].encode(), headers=Header(headers), - version=version, - method=method, + version="1.1", + method=headers[":method"], transport=None, app=self.app, ) + headers = None keep_alive = True async def write_response(response): @@ -130,35 +192,23 @@ async def write_response(response): ) await self.request_handler(request, write_response, None) + self.nursery.cancel_scope.deadline = ( + trio.current_time() + self.request_timeout + ) async def h2_sender(self): async for _ in self.can_send: await self.stream.send_all(self.conn.data_to_send()) - async def http2(self, data_received=None, settings_header=None): - import h2 - from h2.events import ( - RequestReceived, - DataReceived, - ConnectionTerminated, - ) - - config = h2.config.H2Configuration( - client_side=False, - header_encoding="utf-8", - validate_outbound_headers=False, - normalize_outbound_headers=False, - validate_inbound_headers=False, - normalize_inbound_headers=False, - logger=None, # logger - ) - self.conn = h2.connection.H2Connection(config=config) + async def http2(self, settings_header=None): + self.conn = H2Connection(config=h2config) if settings_header: # Upgrade from HTTP 1.1 self.conn.initiate_upgrade_connection(settings_header) await self.stream.send_all( b"HTTP/1.1 101 Switching Protocols\r\n" b"Connection: Upgrade\r\n" - b"Upgrade: h2c\r\n\r\n" + c.data_to_send() + b"Upgrade: h2c\r\n" + b"\r\n" + self.conn.data_to_send() ) else: # h2 ALPN negotiated on SSL init self.conn.initiate_connection() @@ -170,10 +220,6 @@ async def http2(self, data_received=None, settings_header=None): self.nursery.start_soon(self.h2_sender) idle_connections.add(self.nursery.cancel_scope) async for data in self.stream: - if data_received: - data = 
data_received + data - data_received = None - # print(">>>", data) for event in self.conn.receive_data(data): # print("-*-", event) if isinstance(event, RequestReceived): @@ -216,6 +262,9 @@ async def write_response(response): await self.request_handler(request, write_response, None) + async def websocket(self): + logger.info("Websocket requested, not yet implemented") + async def trigger_events(events): """Trigger event callbacks (functions or async) @@ -274,6 +323,7 @@ async def handle_connection(stream): connections=connections, signal=signal, app=app, + ssl=ssl, request_handler=request_handler, error_handler=error_handler, request_timeout=request_timeout, @@ -314,7 +364,6 @@ async def handle_connection(stream): acceptor=acceptor, host=host, port=port, - ssl=ssl, sock=sock, backlog=backlog, workers=workers, @@ -322,7 +371,7 @@ async def handle_connection(stream): return server() if run_async else server() -def runserver(acceptor, host, port, ssl, sock, backlog, workers): +def runserver(acceptor, host, port, sock, backlog, workers): if host and host.startswith("unix:"): open_listeners = partial( # Not Implemented: open_unix_listeners, path=host[5:], backlog=backlog @@ -341,10 +390,6 @@ def runserver(acceptor, host, port, ssl, sock, backlog, workers): return for l in listeners: l.socket.set_inheritable(True) - if ssl: - listeners = [ - trio.SSLListener(l, ssl, https_compatible=True) for l in listeners - ] master_pid = os.getpid() runworker = lambda: trio.run(acceptor, listeners, master_pid) processes = [] From 9bc620d4a3f272cc94bcadc74c7e80b169a1b527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Thu, 29 Aug 2019 19:27:38 +0300 Subject: [PATCH 10/31] Get SSL SNI as protocol.servername. 
--- sanic/server_trio.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index ca489526a2..1370e55030 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -86,8 +86,13 @@ def __init__(self, **kwargs): self.__dict__.update(kwargs) self.request_class = self.request_class or Request self.stream = None + self.servername = None # User-visible server hostname, no port! async def ssl_init(self): + def servername_callback(sock, req_hostname, cb_context): + self.servername = req_hostname + + self.ssl.sni_callback = servername_callback self.stream = trio.SSLStream( self.stream, self.ssl, server_side=True, https_compatible=True ) @@ -210,7 +215,7 @@ async def http2(self, settings_header=None): b"Upgrade: h2c\r\n" b"\r\n" + self.conn.data_to_send() ) - else: # h2 ALPN negotiated on SSL init + else: # straight into HTTP/2 mode self.conn.initiate_connection() await self.stream.send_all(self.conn.data_to_send()) # A trigger mechanism that ensures promptly sending data from self.conn @@ -219,6 +224,7 @@ async def http2(self, settings_header=None): self.send_some, self.can_send = trio.open_memory_channel(1) self.nursery.start_soon(self.h2_sender) idle_connections.add(self.nursery.cancel_scope) + self.requests = {} async for data in self.stream: for event in self.conn.receive_data(data): # print("-*-", event) @@ -226,7 +232,7 @@ async def http2(self, settings_header=None): self.nursery.start_soon( self.h2request, event.stream_id, event.headers ) - idle_connections.discard(self.nursery.cancel_scope) + #idle_connections.discard(self.nursery.cancel_scope) if isinstance(event, ConnectionTerminated): return await self.send_some.send(...) 
@@ -237,9 +243,6 @@ async def h2request(self, stream_id, headers): old = hdrs.get(name) hdrs[name] = value if old is None else f"{old}, {value}" # Process response - self.nursery.cancel_scope.deadline = ( - trio.current_time() + self.response_timeout - ) request = self.request_class( url_bytes=hdrs.get(":path", "").encode(), headers=Header(headers), @@ -260,7 +263,8 @@ async def write_response(response): self.conn.send_data(stream_id, response.body, end_stream=True) await self.send_some.send(...) - await self.request_handler(request, write_response, None) + with trio.fail_after(self.response_timeout): + await self.request_handler(request, write_response, None) async def websocket(self): logger.info("Websocket requested, not yet implemented") From bf32a195152d68b0156c60d059d05bae5c440eb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Mon, 2 Sep 2019 18:09:16 +0300 Subject: [PATCH 11/31] Streamlined request handler logic. Sketching new streaming framework. --- sanic/app.py | 25 ++++++++++++- sanic/exceptions.py | 5 +++ sanic/request.py | 1 + sanic/response.py | 86 ++++++++++++++++++++++++++++++++++++++++++++ sanic/server_trio.py | 82 ++++++++++++++++++++++++++++++------------ 5 files changed, 175 insertions(+), 24 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index ad2fed4bb1..5442106eaa 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1010,6 +1010,29 @@ async def handle_request(self, request, write_callback, stream_callback): else: await write_callback(response) + async def handle_request_trio(self, request): + # Request middleware + response = await self._run_request_middleware(request) + if response: + return await request.respond(response) + # Fetch handler from router + handler, args, kwargs, uri = self.router.get(request) + if handler is None: + raise ServerError( + "None was returned while requesting a handler from the router" + ) + bp = getattr(handler, "__blueprintname__", None) + bp = (bp,) if bp else () + 
request.endpoint = self._build_endpoint_name(*bp, handler.__name__) + request.uri_template = uri + # Run main handler + response = handler(request, *args, **kwargs) + if isawaitable(response): + response = await response + # Returned (non-streaming) response + if response is not None: + await request.respond(response) + # -------------------------------------------------------------------- # # Testing # -------------------------------------------------------------------- # @@ -1323,7 +1346,7 @@ def _helper( "app": self, "signal": Signal(), "debug": debug, - "request_handler": self.handle_request, + "request_handler": self.handle_request_trio, "error_handler": self.error_handler, "request_timeout": self.config.REQUEST_TIMEOUT, "response_timeout": self.config.RESPONSE_TIMEOUT, diff --git a/sanic/exceptions.py b/sanic/exceptions.py index 2c4ab2c02e..728235133e 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -175,6 +175,11 @@ class ServiceUnavailable(SanicException): pass +@add_status_code(505) +class VersionNotSupported(SanicException): + pass + + class URLBuildError(ServerError): pass diff --git a/sanic/request.py b/sanic/request.py index 8356d579b8..aa0060f679 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -88,6 +88,7 @@ class Request(dict): "parsed_form", "parsed_json", "raw_url", + "respond", "stream", "transport", "uri_template", diff --git a/sanic/response.py b/sanic/response.py index 6f937c9576..c9cb415050 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -144,6 +144,92 @@ def get_headers( headers, ) +class NewStreamingHTTPResponse(BaseHTTPResponse): + __slots__ = ( + "stream", + "keep_alive", + "keep_alive_timeout", + ) + + def __init__(self, stream): + self.stream = stream + + async def write(self, data): + """Writes a chunk of data to the streaming response. + :param data: bytes-ish data to be written. 
+ """ + if self.chunked is None: + raise RuntimeError( + "cannot write data before setting content type, " + "try using response.write_headers() first" + ) + + if type(data) != bytes: + data = self._encode_body(data) + + if self.chunked: + await self.stream.send_all(b"%x\r\n%b\r\n" % (len(data), data)) + else: + await self.stream.send_all(data) + + async def aclose(self): + if self.chunked: + await self.stream.send_all(b"0\r\n\r\n") + + async def write_headers( + self, status=200, headers=None, content_type="text/plain", chunked=True + ): + self.chunked = chunked + headers = self.get_headers( + status, headers, content_type, chunked + ) + await self.stream.send_all(headers) + + def _headers_as_bytes(self, headers: Header) -> bytes: + hbytes = b"" + for name, value in headers.items(): + try: + hbytes += b"%b: %b\r\n" % ( + name.encode(), + value.encode("utf-8"), + ) + except AttributeError: + hbytes += b"%b: %b\r\n" % ( + str(name).encode(), + str(value).encode("utf-8"), + ) + + return hbytes + + def get_headers( + self, status=200, headers=None, content_type="text/plain", chunked=True + ) -> bytes: + headers = Header(headers or {}) + # This is all returned in a kind-of funky way + # We tried to make this as fast as possible in pure python + timeout_header = b"" + #if self.keep_alive and self.keep_alive_timeout is not None: + # timeout_header = b"Keep-Alive: %d\r\n" % self.keep_alive_timeout + + if chunked: + headers["Transfer-Encoding"] = "chunked" + headers.pop("Content-Length", None) + headers["Content-Type"] = headers.get("Content-Type", content_type) + + headers = self._headers_as_bytes(headers) + + if status == 200: + status_code = b"OK" + else: + status_code = STATUS_CODES.get(status) + + return (b"HTTP/1.1 %d %b\r\n" b"%b" b"%b\r\n") % ( + status, + status_code, + timeout_header, + headers, + ) + class HTTPResponse(BaseHTTPResponse): __slots__ = ("body", "status", "content_type", "headers", "_cookies") diff --git a/sanic/server_trio.py b/sanic/server_trio.py 
index 1370e55030..2dd15e54a8 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -25,17 +25,25 @@ PayloadTooLarge, RequestTimeout, ServerError, + VersionNotSupported, ServiceUnavailable, ) from sanic.log import access_logger, logger from sanic.request import EXPECT_HEADER, Request, StreamBuffer -from sanic.response import HTTPResponse +from sanic.response import HTTPResponse, NewStreamingHTTPResponse class Signal: stopped = False +# Compatibility wrapper for request StreamBuffer (asyncio) +class TrioStreamBuffer: + def __init__(self, buffer_size=100): + self.sender, self.receiver = trio.open_memory_channel(100) + self.read = self.receiver.receive + self.put = self.sender.send + h2config = H2Configuration( client_side=False, header_encoding="utf-8", @@ -49,7 +57,6 @@ class Signal: idle_connections = set() - def parse_h1_request(data: bytes) -> dict: try: data = data.decode() @@ -57,7 +64,8 @@ def parse_h1_request(data: bytes) -> dict: data = data.decode("ISO-8859-1") req, *hlines = data.split("\r\n") method, path, version = req.split(" ") - assert version == "HTTP/1.1" + if version != "HTTP/1.1": + raise VersionNotSupported(f"Expected 'HTTP/1.1', got '{version}'") headers = {":method": method, ":path": path} for name, value in (h.split(": ", 1) for h in hlines): name = name.lower() @@ -80,6 +88,28 @@ async def receive_some(self, max_bytes=None): stream.__class__ = PushbackStream +class H1StreamRequest: + def __init__(self, headers, stream, set_timeout): + self.length = int(headers.get("content-length")) + if self.length <= 0: + raise InvalidUsage("Content-length must be positive") + self.expect_continue = headers.get("expect", "").lower() == "100-continue" + self.stream = stream + self.set_timeout = set_timeout + + async def read(self): + if self.expect_continue: + await self.stream.send_all(b"HTTP/1.1 100 Continue\r\n\r\n") + self.except_continue = False + buf = await self.stream.read_some() + if len(buf) > self.length: + push_back(self.stream, 
buf[self.length:]) + del buf[self.length:] + self.length -= len(buf) + # Extend or switch deadline + self.set_timeout("request" if self.length else "response") + return buf + class HttpProtocol: def __init__(self, **kwargs): @@ -125,14 +155,17 @@ async def sniff_protocol(self): if req == "h2": return self.http2() + def set_timeout(self, timeout: str): + self.nursery.cancel_scope.deadline = ( + trio.current_time() + getattr(self, f"{timeout}_timeout") + ) + async def run(self, stream): assert not self.stream self.stream = stream try: async with stream, trio.open_nursery() as self.nursery: - self.nursery.cancel_scope.deadline = ( - trio.current_time() + self.request_timeout - ) + self.set_timeout("request") protocol_coroutine = await self.sniff_protocol() if not protocol_coroutine: return @@ -163,7 +196,7 @@ async def _receive_request_using(self, buffer: bytearray): del buffer[: pos + 4] return req if buffer: - raise RuntimeError("Peer disconnected after {buffer!r}") + raise RuntimeError(f"Peer disconnected after {buffer!r}") async def http1(self, headers=None): buffer = bytearray() @@ -174,10 +207,6 @@ async def http1(self, headers=None): if not req: return headers = parse_h1_request(req) - # Process response - self.nursery.cancel_scope.deadline = ( - trio.current_time() + self.response_timeout - ) request = self.request_class( url_bytes=headers[":path"].encode(), headers=Header(headers), @@ -186,20 +215,27 @@ async def http1(self, headers=None): transport=None, app=self.app, ) + if "content-length" in headers: + request.stream = H1StreamRequest(headers, self.streams) + else: + self.set_timeout("response") headers = None - keep_alive = True - - async def write_response(response): - await self.stream.send_all( - response.output( - request.version, keep_alive, self.keep_alive_timeout - ) - ) - await self.request_handler(request, write_response, None) - self.nursery.cancel_scope.deadline = ( - trio.current_time() + self.request_timeout - ) + # Process response + 
request.respond = self.h1_respond + await self.request_handler(request) + self.set_timeout("request") + + async def h1_respond(self, response): + # TODO: Prevent multiple responses + if isinstance(response, dict): + headers = response + response = NewStreamingHTTPResponse(self.stream) + await response.write_headers(headers[":status"], headers, headers["content-type"]) + return response + await self.stream.send_all( + response.output("1.1", self.keep_alive, self.keep_alive_timeout) + ) async def h2_sender(self): async for _ in self.can_send: From ec6707cd8bd001c8b07cf8d3b804443815d84593 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Mon, 2 Sep 2019 19:29:11 +0300 Subject: [PATCH 12/31] Streaming responses now working with HTTP/1. --- sanic/server_trio.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 2dd15e54a8..751ddb34b4 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -220,23 +220,29 @@ async def http1(self, headers=None): else: self.set_timeout("response") headers = None + _response = None + # request.respond() + async def respond(response=None, *, status=200, headers=None, content_type="text/html"): + nonlocal _response + if _response: + raise ServerError("Duplicate responses for a single request!") + if _response is None: + _response = NewStreamingHTTPResponse(self.stream) + await _response.write_headers(status, headers, content_type) + return _response + _response = response + await self.stream.send_all( + response.output("1.1", self.keep_alive, self.keep_alive_timeout) + ) - # Process response - request.respond = self.h1_respond + request.respond = respond await self.request_handler(request) + if not _response: + raise ServerError("Request handler made no response.") + if hasattr(_response, "aclose"): + await _response.aclose() self.set_timeout("request") - async def h1_respond(self, response): - # TODO: Prevent 
multiple responses - if isinstance(response, dict): - headers = response - response = NewStreamingHTTPResponse(self.stream) - await response.write_headers(headers[":status"], headers, headers["content-type"]) - return response - await self.stream.send_all( - response.output("1.1", self.keep_alive, self.keep_alive_timeout) - ) - async def h2_sender(self): async for _ in self.can_send: await self.stream.send_all(self.conn.data_to_send()) From 7031547a579b60dd6baf59b43f1dcc9bdd86c608 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Tue, 3 Sep 2019 11:09:34 +0300 Subject: [PATCH 13/31] Streaming fixes, now request and response may be streamed. --- sanic/response.py | 95 ++++++++++++++++++++------------------------ sanic/server_trio.py | 38 +++++++++++++----- 2 files changed, 71 insertions(+), 62 deletions(-) diff --git a/sanic/response.py b/sanic/response.py index c9cb415050..59aa5f5557 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -147,88 +147,77 @@ def get_headers( class NewStreamingHTTPResponse(BaseHTTPResponse): __slots__ = ( "stream", - "keep_alive", - "keep_alive_timeout", + "length", ) def __init__(self, stream): self.stream = stream + self.length = None - async def write(self, data): + async def write(self, data=b"", end_stream=False): """Writes a chunk of data to the streaming response. :param data: bytes-ish data to be written. 
""" - if self.chunked is None: - raise RuntimeError( - "cannot write data before setting content type, " - "try using response.write_headers() first" - ) + if self.chunked is None and self.length is None: + raise RuntimeError("Response body is closed") - if type(data) != bytes: + if isinstance(data, (bytes, bytearray)): data = self._encode_body(data) + size = len(data) if self.chunked: - await self.stream.send_all(b"%x\r\n%b\r\n" % (len(data), data)) + if size and not end_stream: + data = b"%x\r\n%b\r\n" % (size, data) + elif size: + data = b"%x\r\n%b\r\n0\r\n\r\n" % (size, data) + self.chunked = None + else: + data = b"0\r\n\r\n" + self.chunked = None else: - await self.stream.send_all(data) + self.length -= size + if self.length < 0: + await self.stream.aclose() + raise RuntimeError( + "Response body larger than specified in content-length" + ) + await self.stream.send_all(data) async def aclose(self): if self.chunked: - await self.stream.send_all(b"0\r\n\r\n") + await self.write(end_stream=True) + elif self.length: + l, self.length = self.length, None + raise RuntimeError(f"Response aclosed with {l} bytes missing") async def write_headers( - self, status=200, headers=None, content_type="text/plain", chunked=True + self, status=200, headers=None, content_type="text/plain" ): - self.chunked = chunked - headers = self.get_headers( - status, headers, content_type, chunked - ) + headers = self.get_headers(status, headers, content_type) await self.stream.send_all(headers) def _headers_as_bytes(self, headers: Header) -> bytes: - hbytes = b"" - for name, value in headers.items(): - try: - hbytes += b"%b: %b\r\n" % ( - name.encode(), - value.encode("utf-8"), - ) - except AttributeError: - hbytes += b"%b: %b\r\n" % ( - str(name).encode(), - str(value).encode("utf-8"), - ) - - return hbytes + headers = "".join(f"{n}: {v!s}\r\n" for n, v in headers.items()) + return headers.encode() def get_headers( - self, status=200, headers=None, content_type="text/plain", chunked=True + 
self, status=200, headers=None, content_type="text/plain" ) -> bytes: headers = Header(headers or {}) - # This is all returned in a kind-of funky way - # We tried to make this as fast as possible in pure python - timeout_header = b"" - #if self.keep_alive and self.keep_alive_timeout is not None: - # timeout_header = b"Keep-Alive: %d\r\n" % self.keep_alive_timeout - - if chunked: - headers["Transfer-Encoding"] = "chunked" - headers.pop("Content-Length", None) - headers["Content-Type"] = headers.get("Content-Type", content_type) - - headers = self._headers_as_bytes(headers) - if status == 200: - status_code = b"OK" + if "content-length" in headers: + self.length = int(headers["content-length"]) + self.chunked = False else: - status_code = STATUS_CODES.get(status) + headers["transfer-encoding"] = "chunked" + self.chunked = True - return (b"HTTP/1.1 %d %b\r\n" b"%b" b"%b\r\n") % ( - status, - status_code, - timeout_header, - headers, - ) + if "content-type" not in headers: + headers["content-type"] = content_type + + headers = self._headers_as_bytes(headers) + status_code = b"OK" if status == 200 else STATUS_CODES.get(status) + return b"HTTP/1.1 %d %b\r\n%b\r\n" % (status, status_code, headers) class HTTPResponse(BaseHTTPResponse): diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 751ddb34b4..7847c8f385 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -89,22 +89,27 @@ async def receive_some(self, max_bytes=None): stream.__class__ = PushbackStream class H1StreamRequest: - def __init__(self, headers, stream, set_timeout): + def __init__(self, headers, stream, set_timeout, trigger_continue): self.length = int(headers.get("content-length")) if self.length <= 0: raise InvalidUsage("Content-length must be positive") - self.expect_continue = headers.get("expect", "").lower() == "100-continue" self.stream = stream self.set_timeout = set_timeout + self.trigger_continue = trigger_continue + + async def __aiter__(self): + while True: + data = await 
self.read() + if not data: return + yield data async def read(self): - if self.expect_continue: - await self.stream.send_all(b"HTTP/1.1 100 Continue\r\n\r\n") - self.except_continue = False - buf = await self.stream.read_some() + await self.trigger_continue() + if self.length == 0: return None + buf = await self.stream.receive_some() if len(buf) > self.length: push_back(self.stream, buf[self.length:]) - del buf[self.length:] + buf = buf[:self.length] self.length -= len(buf) # Extend or switch deadline self.set_timeout("request" if self.length else "response") @@ -215,17 +220,32 @@ async def http1(self, headers=None): transport=None, app=self.app, ) + need_continue = headers.get("expect", "").lower() == "100-continue" + async def trigger_continue(): + nonlocal need_continue + if need_continue is False: + return + await self.stream.send_all(b"HTTP/1.1 100 Continue\r\n\r\n") + need_continue = False + if "chunked" in headers.get("transfer-encoding", "").lower(): + raise RuntimeError("Chunked requests not supported") # FIXME if "content-length" in headers: - request.stream = H1StreamRequest(headers, self.streams) + request.stream = H1StreamRequest( + headers, + self.stream, + self.set_timeout, + trigger_continue, + ) else: self.set_timeout("response") headers = None _response = None - # request.respond() + # Implement request.respond: async def respond(response=None, *, status=200, headers=None, content_type="text/html"): nonlocal _response if _response: raise ServerError("Duplicate responses for a single request!") + await trigger_continue() if _response is None: _response = NewStreamingHTTPResponse(self.stream) await _response.write_headers(status, headers, content_type) From c9a8232eb9a0f6137e78b72393705e264667d76e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Tue, 3 Sep 2019 15:37:07 +0300 Subject: [PATCH 14/31] Bugfixes --- sanic/response.py | 2 +- sanic/server_trio.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff 
--git a/sanic/response.py b/sanic/response.py index 59aa5f5557..893fed642e 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -161,7 +161,7 @@ async def write(self, data=b"", end_stream=False): if self.chunked is None and self.length is None: raise RuntimeError("Response body is closed") - if isinstance(data, (bytes, bytearray)): + if not isinstance(data, (bytes, bytearray)): data = self._encode_body(data) size = len(data) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 7847c8f385..314bd2d7ab 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -91,7 +91,7 @@ async def receive_some(self, max_bytes=None): class H1StreamRequest: def __init__(self, headers, stream, set_timeout, trigger_continue): self.length = int(headers.get("content-length")) - if self.length <= 0: + if self.length < 0: raise InvalidUsage("Content-length must be positive") self.stream = stream self.set_timeout = set_timeout @@ -205,6 +205,7 @@ async def _receive_request_using(self, buffer: bytearray): async def http1(self, headers=None): buffer = bytearray() + _response = None while True: # Process request if headers is None: @@ -230,6 +231,8 @@ async def trigger_continue(): if "chunked" in headers.get("transfer-encoding", "").lower(): raise RuntimeError("Chunked requests not supported") # FIXME if "content-length" in headers: + push_back(self.stream, buffer) + del buffer[:] request.stream = H1StreamRequest( headers, self.stream, @@ -246,7 +249,7 @@ async def respond(response=None, *, status=200, headers=None, content_type="text if _response: raise ServerError("Duplicate responses for a single request!") await trigger_continue() - if _response is None: + if response is None: _response = NewStreamingHTTPResponse(self.stream) await _response.write_headers(status, headers, content_type) return _response @@ -261,6 +264,7 @@ async def respond(response=None, *, status=200, headers=None, content_type="text raise ServerError("Request handler made no response.") if 
hasattr(_response, "aclose"): await _response.aclose() + _response = None self.set_timeout("request") async def h2_sender(self): From c1fd59bf1b9a8b81cea29c05b2433a115cd0bfd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Tue, 3 Sep 2019 18:55:51 +0300 Subject: [PATCH 15/31] Non-streaming handlers working on top of streaming requests. --- sanic/app.py | 3 +++ sanic/request.py | 16 +++++++++++++++- sanic/server_trio.py | 17 +++++++---------- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 5442106eaa..875fdc4af3 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1025,6 +1025,9 @@ async def handle_request_trio(self, request): bp = (bp,) if bp else () request.endpoint = self._build_endpoint_name(*bp, handler.__name__) request.uri_template = uri + # Load header body before starting handler? + if not hasattr(handler, "is_stream"): + await request.receive_body() # Run main handler response = handler(request, *args, **kwargs) if isawaitable(response): diff --git a/sanic/request.py b/sanic/request.py index aa0060f679..583bf211e6 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -11,7 +11,7 @@ from httptools import parse_url -from sanic.exceptions import InvalidUsage +from sanic.exceptions import HeaderExpectationFailed, InvalidUsage from sanic.log import error_logger, logger @@ -137,6 +137,20 @@ def body_push(self, data): def body_finish(self): self.body = b"".join(self.body) + async def receive_body(self): + if self.stream: + max_size = self.stream.request_max_size + body = [] + if self.stream.length > max_size: + raise HeaderExpectationFailed("Request body is too large.") + async for data in request.stream: + if self.stream.pos > max_size: + raise HeaderExpectationFailed("Request body is too large.") + body.append(data) + self.body = b"".join(body) + self.stream = None + + @property def json(self): if self.parsed_json is None: diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 
314bd2d7ab..1f18261ea8 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -37,13 +37,6 @@ class Signal: stopped = False -# Compatibility wrapper for request StreamBuffer (asyncio) -class TrioStreamBuffer: - def __init__(self, buffer_size=100): - self.sender, self.receiver = trio.open_memory_channel(100) - self.read = self.receiver.receive - self.put = self.sender.send - h2config = H2Configuration( client_side=False, header_encoding="utf-8", @@ -63,7 +56,7 @@ def parse_h1_request(data: bytes) -> dict: except UnicodeDecodeError: data = data.decode("ISO-8859-1") req, *hlines = data.split("\r\n") - method, path, version = req.split(" ") + method, path, version = req.split(" ", 2) if version != "HTTP/1.1": raise VersionNotSupported(f"Expected 'HTTP/1.1', got '{version}'") headers = {":method": method, ":path": path} @@ -75,6 +68,8 @@ def parse_h1_request(data: bytes) -> dict: def push_back(stream, data): + if not data: + return orig_class = stream.__class__ class PushbackStream(orig_class): @@ -89,10 +84,12 @@ async def receive_some(self, max_bytes=None): stream.__class__ = PushbackStream class H1StreamRequest: + __slots__ = "length", "pos", "stream", "set_timeout", "trigger_continue" def __init__(self, headers, stream, set_timeout, trigger_continue): self.length = int(headers.get("content-length")) if self.length < 0: raise InvalidUsage("Content-length must be positive") + self.pos = 0 self.stream = stream self.set_timeout = set_timeout self.trigger_continue = trigger_continue @@ -105,12 +102,12 @@ async def __aiter__(self): async def read(self): await self.trigger_continue() - if self.length == 0: return None + if self.pos == self.length: return None buf = await self.stream.receive_some() if len(buf) > self.length: push_back(self.stream, buf[self.length:]) buf = buf[:self.length] - self.length -= len(buf) + self.pos += len(buf) # Extend or switch deadline self.set_timeout("request" if self.length else "response") return buf From 
2c5f0160b87dd88d7f3082ba0fbc9d7766f96d26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Wed, 4 Sep 2019 13:31:39 +0300 Subject: [PATCH 16/31] Response middleware for non-streaming responses. --- sanic/server_trio.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 1f18261ea8..a4a23fc99a 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -250,7 +250,13 @@ async def respond(response=None, *, status=200, headers=None, content_type="text _response = NewStreamingHTTPResponse(self.stream) await _response.write_headers(status, headers, content_type) return _response + # Middleware has a chance to replace the response + response = await self.app._run_response_middleware( + request, response + ) _response = response + if not isinstance(response, HTTPResponse): + raise ServerError(f"Handling {request.path}: HTTPResponse expected but got {type(response).__name__}") await self.stream.send_all( response.output("1.1", self.keep_alive, self.keep_alive_timeout) ) From b2ea924bf03d7c3d575e7a0bbde3e044751cc6e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Wed, 4 Sep 2019 14:25:28 +0300 Subject: [PATCH 17/31] Minor cleanup to push_back() --- sanic/server_trio.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index a4a23fc99a..1851817514 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -70,15 +70,15 @@ def parse_h1_request(data: bytes) -> dict: def push_back(stream, data): if not data: return - orig_class = stream.__class__ + stream_type = type(stream) - class PushbackStream(orig_class): + class PushbackStream(stream_type): async def receive_some(self, max_bytes=None): if max_bytes and max_bytes < len(data): ret = data[:max_bytes] del data[:max_bytes] return ret - self.__class__ = orig_class + self.__class__ = stream_type return data stream.__class__ = PushbackStream From 
4537544fde54bf3b08d56cdde375882b1e072202 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Wed, 4 Sep 2019 17:15:32 +0300 Subject: [PATCH 18/31] HTTP1 header formatting moved to headers.format_headers and rewritten. - New implementation is one line of code and twice faster than the old one. - Whole header block encoded to UTF-8 in one pass. - No longer supports custom encode method on header values. - Cookie objects now have __str__ in addition to encode, to work with this. --- sanic/cookies.py | 6 +++++- sanic/headers.py | 13 +++++++++++-- sanic/response.py | 16 ++-------------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/sanic/cookies.py b/sanic/cookies.py index 19907945e1..ed672fba1a 100644 --- a/sanic/cookies.py +++ b/sanic/cookies.py @@ -130,6 +130,10 @@ def encode(self, encoding): :return: Cookie encoded in a codec of choosing. :except: UnicodeEncodeError """ + return str(self).encode(encoding) + + def __str__(self): + """Format as a Set-Cookie header value.""" output = ["%s=%s" % (self.key, _quote(self.value))] for key, value in self.items(): if key == "max-age": @@ -147,4 +151,4 @@ def encode(self, encoding): else: output.append("%s=%s" % (self._keys[key], value)) - return "; ".join(output).encode(encoding) + return "; ".join(output) diff --git a/sanic/headers.py b/sanic/headers.py index 6c9fa2215f..e1ac48b34d 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -1,9 +1,9 @@ import re -from typing import Dict, Iterable, Optional, Tuple +from typing import Any, Dict, Iterable, Optional, Tuple from urllib.parse import unquote - +HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str Options = Dict[str, str] # key=value fields in various headers OptionsIterable = Iterable[Tuple[str, str]] # May contain duplicate keys @@ -165,3 +165,12 @@ def parse_host(host: str) -> Tuple[Optional[str], Optional[int]]: return None, None host, port = m.groups() return host.lower(), port and int(port) + + +def 
format_http1(headers: HeaderIterable) -> bytes: + """Convert a headers iterable into HTTP/1 header format. + + - Outputs UTF-8 bytes where each header line ends with \\r\\n. + - Values are converted into strings if necessary. + """ + return "".join(f"{name}: {val}\r\n" for name, val in headers).encode() diff --git a/sanic/response.py b/sanic/response.py index 6f937c9576..83eacd53f7 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -7,6 +7,7 @@ from sanic.compat import Header from sanic.cookies import CookieJar +from sanic.headers import format_http1 from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers @@ -30,20 +31,7 @@ def _encode_body(self, data): return str(data).encode() def _parse_headers(self): - headers = b"" - for name, value in self.headers.items(): - try: - headers += b"%b: %b\r\n" % ( - name.encode(), - value.encode("utf-8"), - ) - except AttributeError: - headers += b"%b: %b\r\n" % ( - str(name).encode(), - str(value).encode("utf-8"), - ) - - return headers + return format_http1(self.headers.items()) @property def cookies(self): From d248dbb72bb1b9df9ecb956d11e9c91bd015b5a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Wed, 4 Sep 2019 17:51:59 +0300 Subject: [PATCH 19/31] Linter --- sanic/headers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sanic/headers.py b/sanic/headers.py index e1ac48b34d..ec3dc2377d 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple from urllib.parse import unquote + HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str Options = Dict[str, str] # key=value fields in various headers OptionsIterable = Iterable[Tuple[str, str]] # May contain duplicate keys From 7dc683913f6da0c252d8a08b717fb23c6f578659 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Thu, 5 Sep 2019 11:43:23 +0300 Subject: [PATCH 20/31] format_http1_response --- sanic/headers.py 
| 20 ++++++++++++++++ sanic/response.py | 60 ++++++++++++----------------------------------- 2 files changed, 35 insertions(+), 45 deletions(-) diff --git a/sanic/headers.py b/sanic/headers.py index ec3dc2377d..cc9a0dcfbe 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -3,6 +3,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple from urllib.parse import unquote +from sanic.helpers import STATUS_CODES + HeaderIterable = Iterable[Tuple[str, Any]] # Values convertible to str Options = Dict[str, str] # key=value fields in various headers @@ -175,3 +177,21 @@ def format_http1(headers: HeaderIterable) -> bytes: - Values are converted into strings if necessary. """ return "".join(f"{name}: {val}\r\n" for name, val in headers).encode() + + +def format_http1_response( + status: int, headers: HeaderIterable, body=b"" +) -> bytes: + """Format a full HTTP/1.1 response. + + - If `body` is included, content-length must be specified in headers. + """ + headers = format_http1(headers) + if status == 200: + return b"HTTP/1.1 200 OK\r\n%b\r\n%b" % (headers, body) + return b"HTTP/1.1 %d %b\r\n%b\r\n%b" % ( + status, + STATUS_CODES.get(status, b"UNKNOWN"), + headers, + body, + ) diff --git a/sanic/response.py b/sanic/response.py index 83eacd53f7..92362c3d4d 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -7,8 +7,8 @@ from sanic.compat import Header from sanic.cookies import CookieJar -from sanic.headers import format_http1 -from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers +from sanic.headers import format_http1, format_http1_response +from sanic.helpers import has_message_body, remove_entity_headers try: @@ -104,33 +104,17 @@ async def stream( def get_headers( self, version="1.1", keep_alive=False, keep_alive_timeout=None ): - # This is all returned in a kind-of funky way - # We tried to make this as fast as possible in pure python - timeout_header = b"" + if "Content-Type" not in self.headers: + self.headers["Content-Type"] = 
self.content_type + if keep_alive and keep_alive_timeout is not None: - timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout + self.headers["Keep-Alive"] = keep_alive_timeout if self.chunked and version == "1.1": self.headers["Transfer-Encoding"] = "chunked" self.headers.pop("Content-Length", None) - self.headers["Content-Type"] = self.headers.get( - "Content-Type", self.content_type - ) - headers = self._parse_headers() - - if self.status == 200: - status = b"OK" - else: - status = STATUS_CODES.get(self.status) - - return (b"HTTP/%b %d %b\r\n" b"%b" b"%b\r\n") % ( - version.encode(), - self.status, - status, - timeout_header, - headers, - ) + return format_http1_response(self.status, self.headers.items()) class HTTPResponse(BaseHTTPResponse): @@ -156,11 +140,8 @@ def __init__( self._cookies = None def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): - # This is all returned in a kind-of funky way - # We tried to make this as fast as possible in pure python - timeout_header = b"" - if keep_alive and keep_alive_timeout is not None: - timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout + if "Content-Type" not in self.headers: + self.headers["Content-Type"] = self.content_type body = b"" if has_message_body(self.status): @@ -176,24 +157,13 @@ def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): if self.status in (304, 412): self.headers = remove_entity_headers(self.headers) - headers = self._parse_headers() + if keep_alive and keep_alive_timeout is not None: + self.headers["Connection"] = "keep-alive" + self.headers["Keep-Alive"] = keep_alive_timeout + elif not keep_alive: + self.headers["Connection"] = "close" - if self.status == 200: - status = b"OK" - else: - status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE") - - return ( - b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b" - ) % ( - version.encode(), - self.status, - status, - b"keep-alive" if keep_alive else b"close", - 
timeout_header, - headers, - body, - ) + return format_http1_response(self.status, self.headers.items(), body) @property def cookies(self): From 8ed885b875175588b79d0373f07c7cfc4efdbaf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Fri, 6 Sep 2019 17:53:57 +0300 Subject: [PATCH 21/31] Hacks to make it run on Windows. --- sanic/server_trio.py | 61 ++++++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 1851817514..816da4076b 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -13,11 +13,13 @@ from inspect import isawaitable from ipaddress import ip_address from multiprocessing import Process -from signal import SIG_IGN, SIGINT, SIGTERM, SIGHUP, Signals +from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import signal as signal_func from time import time, sleep as time_sleep from httptools.parser.errors import HttpParserError +SIGHUP = SIGTERM + from sanic.compat import Header from sanic.exceptions import ( HeaderExpectationFailed, @@ -391,31 +393,30 @@ def serve( workers=1, loop=None, ): - async def handle_connection(stream): - proto = protocol( - connections=connections, - signal=signal, - app=app, - ssl=ssl, - request_handler=request_handler, - error_handler=error_handler, - request_timeout=request_timeout, - response_timeout=response_timeout, - keep_alive_timeout=keep_alive_timeout, - request_max_size=request_max_size, - request_class=request_class, - access_log=access_log, - keep_alive=keep_alive, - is_request_stream=is_request_stream, - router=router, - websocket_max_size=websocket_max_size, - websocket_max_queue=websocket_max_queue, - websocket_read_limit=websocket_read_limit, - websocket_write_limit=websocket_write_limit, - state=state, - debug=debug, - ) - await proto.run(stream) + proto = partial( + protocol, + connections=connections, + signal=signal, + app=app, + ssl=ssl, + 
request_handler=request_handler, + error_handler=error_handler, + request_timeout=request_timeout, + response_timeout=response_timeout, + keep_alive_timeout=keep_alive_timeout, + request_max_size=request_max_size, + request_class=request_class, + access_log=access_log, + keep_alive=keep_alive, + is_request_stream=is_request_stream, + router=router, + websocket_max_size=websocket_max_size, + websocket_max_queue=websocket_max_queue, + websocket_read_limit=websocket_read_limit, + websocket_write_limit=websocket_write_limit, + state=state, + debug=debug, + ) app.asgi = False assert not ( @@ -428,7 +429,7 @@ async def handle_connection(stream): after_start=after_start, before_stop=before_stop, after_stop=after_stop, - handle_connection=handle_connection, + proto=proto, graceful_shutdown_timeout=graceful_shutdown_timeout, ) @@ -479,7 +480,7 @@ def handler(s, tb): if workers: while True: while len(processes) < workers: - p = Process(target=runworker) + p = Process(target=trio.run, args=(acceptor, listeners, master_pid)) p.daemon = True p.start() processes.append(p) @@ -507,7 +508,7 @@ async def runaccept( after_start, before_stop, after_stop, - handle_connection, + proto, graceful_shutdown_timeout, ): try: @@ -520,7 +521,7 @@ async def runaccept( acceptor.start_soon( partial( trio.serve_listeners, - handler=handle_connection, + handler=proto().run, listeners=listeners, handler_nursery=main_nursery, ) From 8e97f2828e196e475576e4dd033e5780cac681bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Fri, 6 Sep 2019 18:45:11 +0300 Subject: [PATCH 22/31] Fix that hack. 
--- sanic/server_trio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 816da4076b..3edc4fbeac 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -521,7 +521,7 @@ async def runaccept( acceptor.start_soon( partial( trio.serve_listeners, - handler=proto().run, + handler=lambda stream: proto().run(stream), listeners=listeners, handler_nursery=main_nursery, ) From 84539fc60519e5934211b3025d2d78b68b084707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Fri, 6 Sep 2019 19:11:29 +0300 Subject: [PATCH 23/31] NewStreamingResponse only sends headers on write. --- sanic/response.py | 21 ++++++++++----------- sanic/server_trio.py | 6 ++---- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/sanic/response.py b/sanic/response.py index cd6f46ec66..eee066f101 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -120,11 +120,16 @@ class NewStreamingHTTPResponse(BaseHTTPResponse): __slots__ = ( "stream", "length", + "status", + "headers", ) - def __init__(self, stream): + def __init__(self, stream, status, headers, content_type): self.stream = stream self.length = None + self.status = status + headers["Content-Type"] = content_type + self.headers = headers async def write(self, data=b"", end_stream=False): """Writes a chunk of data to the streaming response. 
@@ -153,6 +158,10 @@ async def write(self, data=b"", end_stream=False): raise RuntimeError( "Response body larger than specified in content-length" ) + # Prepend header block to data, if not yet sent + if self.status is not None: + data = format_http1_response(self.status, self.headers, data) + self.status = self.headers = None await self.stream.send_all(data) async def aclose(self): @@ -162,16 +171,6 @@ async def aclose(self): l, self.length = self.length, None raise RuntimeError(f"Response aclosed with {l} bytes missing") - async def write_headers( - self, status=200, headers=None, content_type="text/plain" - ): - headers = self.get_headers(status, headers, content_type) - await self.stream.send_all(headers) - - def _headers_as_bytes(self, headers: Header) -> bytes: - headers = "".join(f"{n}: {v!s}\r\n" for n, v in headers.items()) - return headers.encode() - def get_headers( self, status=200, headers=None, content_type="text/plain" ) -> bytes: diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 3edc4fbeac..c65f61df61 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -249,10 +249,8 @@ async def respond(response=None, *, status=200, headers=None, content_type="text raise ServerError("Duplicate responses for a single request!") await trigger_continue() if response is None: - _response = NewStreamingHTTPResponse(self.stream) - await _response.write_headers(status, headers, content_type) - return _response - # Middleware has a chance to replace the response + response = NewStreamingHTTPResponse(self.stream, status, headers, content_type) + # Middleware has a chance to replace or modify the response response = await self.app._run_response_middleware( request, response ) From 491ce25ab4616f00bd42654dbb0be40692802471 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Sun, 8 Sep 2019 20:45:02 +0300 Subject: [PATCH 24/31] Major cleanup; rewrote H1StreamRequest and NewStreamingHTTPResponse into a combined class 
protocol.H1Stream; added error handling. --- sanic/app.py | 20 ++++-- sanic/headers.py | 4 +- sanic/protocol.py | 155 +++++++++++++++++++++++++++++++++++++++++++ sanic/request.py | 14 ++-- sanic/response.py | 74 --------------------- sanic/server_trio.py | 135 ++++++++++++++----------------------- 6 files changed, 232 insertions(+), 170 deletions(-) create mode 100644 sanic/protocol.py diff --git a/sanic/app.py b/sanic/app.py index 56105705be..ca63b8cd20 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -22,7 +22,11 @@ from sanic.exceptions import SanicException, ServerError, URLBuildError from sanic.handlers import ErrorHandler from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger -from sanic.response import HTTPResponse, StreamingHTTPResponse +from sanic.response import ( + BaseHTTPResponse, + HTTPResponse, + StreamingHTTPResponse, +) from sanic.router import Router from sanic.server import HttpProtocol, Signal, serve from sanic.static import register as static_register @@ -1033,8 +1037,16 @@ async def handle_request_trio(self, request): if isawaitable(response): response = await response # Returned (non-streaming) response - if response is not None: - await request.respond(response) + if isinstance(response, BaseHTTPResponse): + await request.respond( + status=response.status, + headers=response.headers, + content_type=response.content_type, + ).send(data_bytes=response.body, end_stream=True) + elif response is not None: + raise ServerError( + f"Handling {request.path}: HTTPResponse expected but got {type(response).__name__} {response!r:.200}" + ) # -------------------------------------------------------------------- # # Testing @@ -1065,7 +1077,7 @@ def run( stop_event: Any = None, register_sys_signals: bool = True, access_log: Optional[bool] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """Run the HTTP Server and listen until keyboard interrupt or term signal. On termination, drain connections before closing. 
diff --git a/sanic/headers.py b/sanic/headers.py index cc9a0dcfbe..90ab345c55 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -180,12 +180,14 @@ def format_http1(headers: HeaderIterable) -> bytes: def format_http1_response( - status: int, headers: HeaderIterable, body=b"" + status: int, headers: HeaderIterable, body: Optional[bytes] = None ) -> bytes: """Format a full HTTP/1.1 response. - If `body` is included, content-length must be specified in headers. """ + if body is None: + body = b"" headers = format_http1(headers) if status == 200: return b"HTTP/1.1 200 OK\r\n%b\r\n%b" % (headers, body) diff --git a/sanic/protocol.py b/sanic/protocol.py new file mode 100644 index 0000000000..bb123218a6 --- /dev/null +++ b/sanic/protocol.py @@ -0,0 +1,155 @@ +from sanic.headers import format_http1, format_http1_response +from sanic.helpers import has_message_body, remove_entity_headers + + +# FIXME: Put somewhere before response: +if False: + # Middleware has a chance to replace or modify the response + response = await self.app._run_response_middleware( + request, response + ) + +class H1Stream: + __slots__ = ("stream", "length", "pos", "set_timeout", "response_state", "status", "headers", "bytes_left") + + def __init__(self, headers, stream, set_timeout, need_continue): + self.length = int(headers.get("content-length", "0")) + assert self.length >= 0 + self.pos = None if need_continue else 0 + self.stream = stream + self.status = self.bytes_left = None + self.response_state = 0 + self.set_timeout = set_timeout + self.update_deadline() + + async def aclose(self): + # Finish sending a response (if no error) + if self.response_state < 2: + await self.send(end_stream=True) + # Response fully sent, request fully read? 
+        if self.pos != self.length or self.response_state != 2:
+            await self.stream.aclose()  # If not, must disconnect :(
+
+    def update_deadline(self):
+        # Extend or switch deadline
+        self.set_timeout("request" if self.pos is not None and self.pos < self.length else "response")
+
+    # Request methods
+
+    def dont_continue(self):
+        """Prevent a pending 100 Continue response being sent, and avoid
+        receiving the request body. Does not by itself send a 417 response."""
+        if self.pos is None:
+            self.pos = self.length = 0
+
+    async def trigger_continue(self):
+        if self.pos is None:
+            self.pos = 0
+            await self.stream.send_all(b"HTTP/1.1 100 Continue\r\n\r\n")
+
+    async def __aiter__(self):
+        while True:
+            data = await self.read()
+            if not data:
+                return
+            yield data
+
+    async def read(self):
+        await self.trigger_continue()
+        if self.pos == self.length:
+            return None
+        buf = await self.stream.receive_some()
+        if self.pos + len(buf) > self.length:
+            self.stream.push_back(buf[self.length - self.pos :])
+            buf = buf[: self.length - self.pos]
+        self.pos += len(buf)
+        self.update_deadline()
+        return buf
+
+    # Response methods
+
+    def respond(self, status, headers):
+        if self.response_state > 0:
+            self.response_state = 3  # FAIL mode
+            raise RuntimeError("Response already started")
+        self.status = status
+        self.headers = headers
+        return self
+
+    async def send(self, data=None, data_bytes=None, end_stream=False):
+        """Send any pending response headers and the given data as body.
+ :param data: str-convertible data to be written + :param data_bytes: bytes-ish data to be written (used if data is None) + :end_stream: whether to close the stream after this block + """ + data = self.data_to_send(data, data_bytes, end_stream) + if data is None: + return + # Check if the request expects a 100-continue first + if self.pos is None: + if self.status == 417: + self.dont_continue() + else: + await self.trigger_continue() + # Send response + await self.stream.send_all(data) + + def data_to_send(self, data, data_bytes, end_stream): + """Format output data bytes for given body data. + Headers are prepended to the first output block and then cleared. + :param data: str-convertible data to be written + :param data_bytes: bytes-ish data to be written (used if data is None) + :return: bytes to send, or None if there is nothing to send + """ + data = data_bytes if data is None else f"{data}".encode() + size = len(data) if data is not None else 0 + + # Headers not yet sent? + if self.response_state == 0: + status, headers = self.status, self.headers + if status in (304, 412): + headers = remove_entity_headers(headers) + if not has_message_body(status): + # Header-only response status + assert ( + size == 0 and end_stream + ), f"A {status} response may only have headers, no body." + assert "content-length" not in self.headers + assert "transfer-encoding" not in self.headers + elif end_stream: + # Non-streaming response (all in one block) + headers["content-length"] = size + elif "content-length" in headers: + # Streaming response with size known in advance + self.bytes_left = int(headers["content-length"]) - size + assert self.bytes_left >= 0 + else: + # Length not known, use chunked encoding + headers["transfer-encoding"] = "chunked" + data = b"%x\r\n%b\r\n" % (size, data) if size else None + self.bytes_left = ... 
+ self.status = self.headers = None + self.response_state = 2 if end_stream else 1 + return format_http1_response(status, headers.items(), data) + + if self.response_state == 2: + if size: + raise RuntimeError("Cannot send data to a closed stream") + return + + self.response_state = 2 if end_stream else 1 + + # Chunked encoding + if self.bytes_left is ...: + if end_stream: + self.bytes_left = None + if size: + return b"%x\r\n%b\r\n0\r\n\r\n" % (size, data) + return b"0\r\n\r\n" + return b"%x\r\n%b\r\n" % (size, data) if size else None + + # Normal encoding + if isinstance(self.bytes_left, int): + self.bytes_left -= size + assert self.bytes_left >= 0 + return data if size else None diff --git a/sanic/request.py b/sanic/request.py index 04d3e3c00b..a53976a194 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -10,6 +10,7 @@ from httptools import parse_url +from sanic.compat import Header from sanic.exceptions import HeaderExpectationFailed, InvalidUsage from sanic.headers import ( parse_content_header, @@ -94,7 +95,6 @@ class Request(dict): "parsed_json", "parsed_forwarded", "raw_url", - "respond", "stream", "transport", "uri_template", @@ -145,18 +145,22 @@ def body_finish(self): self.body = b"".join(self.body) async def receive_body(self): - if self.stream: - max_size = self.stream.request_max_size + if not self.stream.pos: + max_size = self.app.config.REQUEST_MAX_SIZE body = [] if self.stream.length > max_size: raise HeaderExpectationFailed("Request body is too large.") - async for data in request.stream: + async for data in self.stream: if self.stream.pos > max_size: raise HeaderExpectationFailed("Request body is too large.") body.append(data) self.body = b"".join(body) - self.stream = None + def respond(self, status=200, headers=None, content_type="text/html"): + headers = Header(headers or {}) + if "content-type" not in headers: + headers["content-type"] = content_type + return self.stream.respond(status, headers) @property def json(self): diff --git 
a/sanic/response.py b/sanic/response.py index eee066f101..92362c3d4d 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -116,80 +116,6 @@ def get_headers( return format_http1_response(self.status, self.headers.items()) -class NewStreamingHTTPResponse(BaseHTTPResponse): - __slots__ = ( - "stream", - "length", - "status", - "headers", - ) - - def __init__(self, stream, status, headers, content_type): - self.stream = stream - self.length = None - self.status = status - headers["Content-Type"] = content_type - self.headers = headers - - async def write(self, data=b"", end_stream=False): - """Writes a chunk of data to the streaming response. - :param data: bytes-ish data to be written. - """ - if self.chunked is None and self.length is None: - raise RuntimeError("Response body is closed") - - if not isinstance(data, (bytes, bytearray)): - data = self._encode_body(data) - size = len(data) - - if self.chunked: - if size and not end_stream: - data = b"%x\r\n%b\r\n" % (size, data) - elif size: - data = b"%x\r\n%b\r\n0\r\n\r\n" % (size, data) - self.chunked = None - else: - data = b"0\r\n\r\n" - self.chunked = None - else: - self.length -= size - if self.length < 0: - await self.stream.aclose() - raise RuntimeError( - "Response body larger than specified in content-length" - ) - # Prepend header block to data, if not yet sent - if self.status is not None: - data = format_http1_response(self.status, self.headers, data) - self.status = self.headers = None - await self.stream.send_all(data) - - async def aclose(self): - if self.chunked: - await self.write(end_stream=True) - elif self.length: - l, self.length = self.length, None - raise RuntimeError(f"Response aclosed with {l} bytes missing") - - def get_headers( - self, status=200, headers=None, content_type="text/plain" - ) -> bytes: - headers = Header(headers or {}) - - if "content-length" in headers: - self.length = int(headers["content-length"]) - self.chunked = False - else: - headers["transfer-encoding"] = 
"chunked" - self.chunked = True - - if "content-type" not in headers: - headers["content-type"] = content_type - - headers = self._headers_as_bytes(headers) - status_code = b"OK" if status == 200 else STATUS_CODES.get(status) - return b"HTTP/1.1 %d %b\r\n%b\r\n" % (status, status_code, headers) - class HTTPResponse(BaseHTTPResponse): __slots__ = ("body", "status", "content_type", "headers", "_cookies") diff --git a/sanic/server_trio.py b/sanic/server_trio.py index c65f61df61..6ce0937aa6 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -1,4 +1,3 @@ -import trio import os import socket import stat @@ -7,18 +6,20 @@ import traceback from functools import partial -from h2.config import H2Configuration -from h2.connection import H2Connection -from h2.events import RequestReceived, DataReceived, ConnectionTerminated from inspect import isawaitable from ipaddress import ip_address from multiprocessing import Process from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import signal as signal_func -from time import time, sleep as time_sleep -from httptools.parser.errors import HttpParserError +from time import sleep as time_sleep +from time import time -SIGHUP = SIGTERM +import trio + +from h2.config import H2Configuration +from h2.connection import H2Connection +from h2.events import ConnectionTerminated, DataReceived, RequestReceived +from httptools.parser.errors import HttpParserError from sanic.compat import Header from sanic.exceptions import ( @@ -26,13 +27,22 @@ InvalidUsage, PayloadTooLarge, RequestTimeout, + SanicException, ServerError, - VersionNotSupported, ServiceUnavailable, + VersionNotSupported, ) from sanic.log import access_logger, logger +from sanic.protocol import H1Stream from sanic.request import EXPECT_HEADER, Request, StreamBuffer -from sanic.response import HTTPResponse, NewStreamingHTTPResponse +from sanic.response import HTTPResponse + + +try: + from signal import SIGHUP +except: + SIGHUP = SIGTERM + class Signal: @@ 
-52,6 +62,7 @@ class Signal: idle_connections = set() + def parse_h1_request(data: bytes) -> dict: try: data = data.decode() @@ -85,35 +96,6 @@ async def receive_some(self, max_bytes=None): stream.__class__ = PushbackStream -class H1StreamRequest: - __slots__ = "length", "pos", "stream", "set_timeout", "trigger_continue" - def __init__(self, headers, stream, set_timeout, trigger_continue): - self.length = int(headers.get("content-length")) - if self.length < 0: - raise InvalidUsage("Content-length must be positive") - self.pos = 0 - self.stream = stream - self.set_timeout = set_timeout - self.trigger_continue = trigger_continue - - async def __aiter__(self): - while True: - data = await self.read() - if not data: return - yield data - - async def read(self): - await self.trigger_continue() - if self.pos == self.length: return None - buf = await self.stream.receive_some() - if len(buf) > self.length: - push_back(self.stream, buf[self.length:]) - buf = buf[:self.length] - self.pos += len(buf) - # Extend or switch deadline - self.set_timeout("request" if self.length else "response") - return buf - class HttpProtocol: def __init__(self, **kwargs): @@ -143,9 +125,10 @@ async def sniff_protocol(self): if upgrade == "h2c": return self.http2(settings_header=headers["http2-settings"]) if upgrade == "websocket": - return self.websocket() + self.websocket = True + self.stream.push_back(buffer) return self.http1(headers=headers) - push_back(self.stream, buffer) + self.stream.push_back(buffer) if req == "ssl": if not self.ssl: raise RuntimeError("Only plain HTTP supported (not SSL).") @@ -160,13 +143,14 @@ async def sniff_protocol(self): return self.http2() def set_timeout(self, timeout: str): - self.nursery.cancel_scope.deadline = ( - trio.current_time() + getattr(self, f"{timeout}_timeout") + self.nursery.cancel_scope.deadline = trio.current_time() + getattr( + self, f"{timeout}_timeout" ) async def run(self, stream): assert not self.stream self.stream = stream + 
self.stream.push_back = partial(push_back, stream) try: async with stream, trio.open_nursery() as self.nursery: self.set_timeout("request") @@ -196,7 +180,8 @@ async def _receive_request_using(self, buffer: bytearray): pos = buffer.find(b"\r\n\r\n", prevpos) if pos > 0: req = buffer[:pos] - if req == b"PRI * HTTP/2.0": return "h2" + if req == b"PRI * HTTP/2.0": + return "h2" del buffer[: pos + 4] return req if buffer: @@ -221,53 +206,29 @@ async def http1(self, headers=None): app=self.app, ) need_continue = headers.get("expect", "").lower() == "100-continue" - async def trigger_continue(): - nonlocal need_continue - if need_continue is False: - return - await self.stream.send_all(b"HTTP/1.1 100 Continue\r\n\r\n") - need_continue = False + if "chunked" in headers.get("transfer-encoding", "").lower(): raise RuntimeError("Chunked requests not supported") # FIXME if "content-length" in headers: push_back(self.stream, buffer) del buffer[:] - request.stream = H1StreamRequest( - headers, - self.stream, - self.set_timeout, - trigger_continue, - ) - else: - self.set_timeout("response") + request.stream = H1Stream( + headers, self.stream, self.set_timeout, need_continue + ) headers = None - _response = None - # Implement request.respond: - async def respond(response=None, *, status=200, headers=None, content_type="text/html"): - nonlocal _response - if _response: - raise ServerError("Duplicate responses for a single request!") - await trigger_continue() - if response is None: - response = NewStreamingHTTPResponse(self.stream, status, headers, content_type) - # Middleware has a chance to replace or modify the response - response = await self.app._run_response_middleware( - request, response - ) - _response = response - if not isinstance(response, HTTPResponse): - raise ServerError(f"Handling {request.path}: HTTPResponse expected but got {type(response).__name__}") - await self.stream.send_all( - response.output("1.1", self.keep_alive, self.keep_alive_timeout) - ) - - 
request.respond = respond - await self.request_handler(request) - if not _response: - raise ServerError("Request handler made no response.") - if hasattr(_response, "aclose"): - await _response.aclose() - _response = None + try: + await self.request_handler(request) + except Exception as e: + r = self.app.error_handler.default(request, e) + try: + await request.stream.respond(r.status, r.headers).send( + data_bytes=r.body + ) + except RuntimeError: + pass # If we cannot send to client anymore + raise + finally: + await request.stream.aclose() self.set_timeout("request") async def h2_sender(self): @@ -301,7 +262,7 @@ async def http2(self, settings_header=None): self.nursery.start_soon( self.h2request, event.stream_id, event.headers ) - #idle_connections.discard(self.nursery.cancel_scope) + # idle_connections.discard(self.nursery.cancel_scope) if isinstance(event, ConnectionTerminated): return await self.send_some.send(...) @@ -478,7 +439,9 @@ def handler(s, tb): if workers: while True: while len(processes) < workers: - p = Process(target=trio.run, args=(acceptor, listeners, master_pid)) + p = Process( + target=trio.run, args=(acceptor, listeners, master_pid) + ) p.daemon = True p.start() processes.append(p) From af846949deee139c9efee507a30f46a1ab23f79a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Mon, 9 Sep 2019 12:21:57 +0300 Subject: [PATCH 25/31] Multiprocessing/server rewrite fully in async context with better signal handling. 
--- sanic/server_trio.py | 77 +++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 41 deletions(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 6ce0937aa6..10c2afcdd4 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -8,7 +8,7 @@ from functools import partial from inspect import isawaitable from ipaddress import ip_address -from multiprocessing import Process +import multiprocessing as mp from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import signal as signal_func from time import sleep as time_sleep @@ -401,10 +401,10 @@ def serve( backlog=backlog, workers=workers, ) - return server() if run_async else server() + return server() if run_async else trio.run(server) -def runserver(acceptor, host, port, sock, backlog, workers): +async def runserver(acceptor, host, port, sock, backlog, workers): if host and host.startswith("unix:"): open_listeners = partial( # Not Implemented: open_unix_listeners, path=host[5:], backlog=backlog @@ -417,54 +417,51 @@ def runserver(acceptor, host, port, sock, backlog, workers): backlog=backlog, ) try: - listeners = trio.run(open_listeners) + listeners = await open_listeners() except Exception: logger.exception("Unable to start server") return for l in listeners: l.socket.set_inheritable(True) - master_pid = os.getpid() - runworker = lambda: trio.run(acceptor, listeners, master_pid) processes = [] - # Setup signal handlers to avoid crashing - sig = None - def handler(s, tb): - nonlocal sig - sig = s - - for s in (SIGINT, SIGTERM, SIGHUP): - signal_func(s, handler) - - if workers: - while True: - while len(processes) < workers: - p = Process( - target=trio.run, args=(acceptor, listeners, master_pid) - ) - p.daemon = True - p.start() - processes.append(p) - time_sleep(0.1) # Poll for dead processes - processes = [p for p in processes if p.is_alive()] - s, sig = sig, None - if not s: - continue + try: + if workers: + # Spawn method is consistent across platforms and, 
unlike the fork + # method, can be used from within async functions (such as this one). + mp.set_start_method("spawn") + with trio.open_signal_receiver(SIGINT, SIGTERM, SIGHUP) as sigiter: + while True: + while len(processes) < workers: + p = mp.Process( + target=trio.run, args=(acceptor, listeners) + ) + p.daemon = True + p.start() + processes.append(p) + # Wait for signals and periodically check processes + with trio.move_on_after(0.1): + s = await sigiter.__anext__() + if s in (SIGTERM, SIGINT): + break + logger.info("SIGHUP: Restarting all workers!") + for p in processes: + p.terminate() + processes = [p for p in processes if p.is_alive()] + else: # workers=0 single-process mode + await acceptor(listeners) + finally: + with trio.CancelScope() as cs: + cs.shield = True + # Close listeners and wait for workers to terminate + for l in listeners: + await l.aclose() for p in processes: - os.kill(p.pid, SIGHUP) - if s in (SIGINT, SIGTERM): - break - for l in listeners: - trio.run(l.aclose) - for p in processes: - p.join() - else: - runworker() + p.join() async def runaccept( listeners, - master_pid, before_start, after_start, before_stop, @@ -494,8 +491,6 @@ async def runaccept( ) as sigiter: s = await sigiter.__anext__() logger.info(f"Received {Signals(s).name}") - if s != SIGHUP: - os.kill(master_pid, SIGTERM) acceptor.cancel_scope.cancel() now = trio.current_time() for c in idle_connections: From e0ce6d26eaad72430632c1edc3fc391e383a2f8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Mon, 9 Sep 2019 15:09:45 +0300 Subject: [PATCH 26/31] Cleanup and error handling. 
--- sanic/server_trio.py | 72 ++++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 10c2afcdd4..f89782f329 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -5,6 +5,7 @@ import time import traceback +from enum import Enum from functools import partial from inspect import isawaitable from ipaddress import ip_address @@ -48,6 +49,7 @@ class Signal: stopped = False +SSL, H2 = Enum("Protocol", "SSL H2") h2config = H2Configuration( client_side=False, @@ -61,7 +63,15 @@ class Signal: idle_connections = set() +quit = trio.Event() +def trigger_graceful_exit(): + """Signals all running connections to terminate smoothly.""" + # Disallow new requests + quit.set() + # Promptly terminate idle connections + for c in idle_connections: + c.cancel() def parse_h1_request(data: bytes) -> dict: try: @@ -116,8 +126,7 @@ def servername_callback(sock, req_hostname, cb_context): self.alpn = self.stream.selected_alpn_protocol() async def sniff_protocol(self): - buffer = bytearray() - req = await self._receive_request_using(buffer) + req = await self.receive_request() if isinstance(req, bytearray): # HTTP1 but might be Upgrade to websocket or h2c headers = parse_h1_request(req) @@ -126,10 +135,8 @@ async def sniff_protocol(self): return self.http2(settings_header=headers["http2-settings"]) if upgrade == "websocket": self.websocket = True - self.stream.push_back(buffer) return self.http1(headers=headers) - self.stream.push_back(buffer) - if req == "ssl": + if req is SSL: if not self.ssl: raise RuntimeError("Only plain HTTP supported (not SSL).") await self.ssl_init() @@ -139,7 +146,7 @@ async def sniff_protocol(self): return self.http2() raise RuntimeError(f"Unknown ALPN {self.alpn}") # HTTP2 (not Upgrade) - if req == "h2": + if req is H2: return self.http2() def set_timeout(self, timeout: str): @@ -151,49 +158,51 @@ async def run(self, stream): assert not self.stream 
self.stream = stream self.stream.push_back = partial(push_back, stream) - try: - async with stream, trio.open_nursery() as self.nursery: + async with stream, trio.open_nursery() as self.nursery: + try: self.set_timeout("request") protocol_coroutine = await self.sniff_protocol() if not protocol_coroutine: return await protocol_coroutine self.nursery.cancel_scope.cancel() # Terminate all connections - except trio.BrokenResourceError: - pass # Connection reset by peer - except Exception: - logger.exception("Error in server") - finally: - idle_connections.discard(self.nursery.cancel_scope) - - async def _receive_request_using(self, buffer: bytearray): + except trio.BrokenResourceError: + pass # Connection reset by peer + except Exception: + logger.exception("Error in server") + finally: + idle_connections.discard(self.nursery.cancel_scope) + + async def receive_request(self): idle_connections.add(self.nursery.cancel_scope) with trio.fail_after(self.request_timeout): + buffer = bytearray() async for data in self.stream: idle_connections.discard(self.nursery.cancel_scope) prevpos = max(0, len(buffer) - 3) buffer += data - if buffer[0] < 0x20: - return "ssl" + if buffer[0] == 0x16: + self.stream.push_back(buffer) + return SSL if len(buffer) > self.request_max_size: raise RuntimeError("Request larger than request_max_size") pos = buffer.find(b"\r\n\r\n", prevpos) if pos > 0: req = buffer[:pos] if req == b"PRI * HTTP/2.0": - return "h2" - del buffer[: pos + 4] + self.stream.push_back(buffer) + return H2 + self.stream.push_back(buffer[pos+4:]) return req if buffer: - raise RuntimeError(f"Peer disconnected after {buffer!r}") + raise RuntimeError(f"Peer disconnected after {buffer!r:.200}") async def http1(self, headers=None): - buffer = bytearray() _response = None - while True: + while not quit.is_set(): # Process request if headers is None: - req = await self._receive_request_using(buffer) + req = await self.receive_request() if not req: return headers = 
parse_h1_request(req) @@ -209,15 +218,15 @@ async def http1(self, headers=None): if "chunked" in headers.get("transfer-encoding", "").lower(): raise RuntimeError("Chunked requests not supported") # FIXME - if "content-length" in headers: - push_back(self.stream, buffer) - del buffer[:] request.stream = H1Stream( headers, self.stream, self.set_timeout, need_continue ) headers = None try: await self.request_handler(request) + except trio.BrokenResourceError: + logger.info(f"Client disconnected during {request.method} {request.path}") + return except Exception as e: r = self.app.error_handler.default(request, e) try: @@ -474,7 +483,7 @@ async def runaccept( logger.info("Starting worker [%s]", pid) async with trio.open_nursery() as main_nursery: await trigger_events(before_start) - # Accept connections until a signal is received, then perform graceful exit + # Accept connections until a signal is received async with trio.open_nursery() as acceptor: acceptor.start_soon( partial( @@ -492,12 +501,11 @@ async def runaccept( s = await sigiter.__anext__() logger.info(f"Received {Signals(s).name}") acceptor.cancel_scope.cancel() - now = trio.current_time() - for c in idle_connections: - c.cancel() + # No longer accepting new connections. Attempt graceful exit. main_nursery.cancel_scope.deadline = ( - now + graceful_shutdown_timeout + trio.current_time() + graceful_shutdown_timeout ) + trigger_graceful_exit() await trigger_events(before_stop) await trigger_events(after_stop) logger.info(f"Gracefully finished worker [{pid}]") From c4a4fc6838aba4a1b2ff85c9c82322da0c7dcec7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Mon, 9 Sep 2019 16:35:12 +0300 Subject: [PATCH 27/31] Autoreloader should now work with any number of workers, on any OS. 
--- sanic/app.py | 15 +------ sanic/reloader_helpers.py | 95 +++++---------------------------------- sanic/server_trio.py | 62 +++++++++++++------------ 3 files changed, 48 insertions(+), 124 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index ca63b8cd20..ac63d6abc9 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1117,13 +1117,8 @@ def run( "#asynchronous-support" ) - # Default auto_reload to false - auto_reload = False - # If debug is set, default it to true (unless on windows) - if debug and os.name == "posix": - auto_reload = True - # Allow for overriding either of the defaults - auto_reload = kwargs.get("auto_reload", auto_reload) + # Allow for overriding the default of following debug mode setting + auto_reload = kwargs.get("auto_reload", debug) if sock is None: host, port = host or "127.0.0.1", port or 8000 @@ -1158,11 +1153,6 @@ def run( try: self.is_running = True - if auto_reload and os.name != "posix": - # This condition must be removed after implementing - # auto reloader for other operating systems. 
- raise NotImplementedError - if ( auto_reload and os.environ.get("SANIC_SERVER_RUNNING") != "true" @@ -1177,7 +1167,6 @@ def run( raise finally: self.is_running = False - logger.info("Server Stopped") def stop(self): """This kills the Sanic""" diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py index 5e1338a43b..c96d6bcf91 100644 --- a/sanic/reloader_helpers.py +++ b/sanic/reloader_helpers.py @@ -3,7 +3,6 @@ import subprocess import sys -from multiprocessing import Process from time import sleep @@ -56,80 +55,7 @@ def restart_with_reloader(): args = _get_args_for_reloading() new_environ = os.environ.copy() new_environ["SANIC_SERVER_RUNNING"] = "true" - cmd = " ".join(args) - worker_process = Process( - target=subprocess.call, - args=(cmd,), - kwargs={"cwd": cwd, "shell": True, "env": new_environ}, - ) - worker_process.start() - return worker_process - - -def kill_process_children_unix(pid): - """Find and kill child processes of a process (maximum two level). - - :param pid: PID of parent process (process ID) - :return: Nothing - """ - root_process_path = "/proc/{pid}/task/{pid}/children".format(pid=pid) - if not os.path.isfile(root_process_path): - return - with open(root_process_path) as children_list_file: - children_list_pid = children_list_file.read().split() - - for child_pid in children_list_pid: - children_proc_path = "/proc/%s/task/%s/children" % ( - child_pid, - child_pid, - ) - if not os.path.isfile(children_proc_path): - continue - with open(children_proc_path) as children_list_file_2: - children_list_pid_2 = children_list_file_2.read().split() - for _pid in children_list_pid_2: - try: - os.kill(int(_pid), signal.SIGTERM) - except ProcessLookupError: - continue - try: - os.kill(int(child_pid), signal.SIGTERM) - except ProcessLookupError: - continue - - -def kill_process_children_osx(pid): - """Find and kill child processes of a process. 
- - :param pid: PID of parent process (process ID) - :return: Nothing - """ - subprocess.run(["pkill", "-P", str(pid)]) - - -def kill_process_children(pid): - """Find and kill child processes of a process. - - :param pid: PID of parent process (process ID) - :return: Nothing - """ - if sys.platform == "darwin": - kill_process_children_osx(pid) - elif sys.platform == "linux": - kill_process_children_unix(pid) - else: - pass # should signal error here - - -def kill_program_completely(proc): - """Kill worker and it's child processes and exit. - - :param proc: worker process (process ID) - :return: Nothing - """ - kill_process_children(proc.pid) - proc.terminate() - os._exit(0) + return subprocess.Popen(args, cwd=cwd, env=new_environ) def watchdog(sleep_interval): @@ -140,13 +66,14 @@ def watchdog(sleep_interval): """ mtimes = {} worker_process = restart_with_reloader() - signal.signal( - signal.SIGTERM, lambda *args: kill_program_completely(worker_process) - ) - signal.signal( - signal.SIGINT, lambda *args: kill_program_completely(worker_process) - ) - while True: + quit = False + def terminate(sig, frame): + nonlocal quit + quit = True + worker_process.terminate() + signal.signal(signal.SIGTERM, terminate) + signal.signal(signal.SIGINT, terminate) + while not quit: need_reload = False for filename in _iter_module_files(): @@ -162,8 +89,10 @@ def watchdog(sleep_interval): mtimes[filename] = mtime need_reload = True + if quit or worker_process.poll(): + return + if need_reload: - kill_process_children(worker_process.pid) worker_process.terminate() worker_process = restart_with_reloader() diff --git a/sanic/server_trio.py b/sanic/server_trio.py index f89782f329..0876439b95 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -440,6 +440,7 @@ async def runserver(acceptor, host, port, sock, backlog, workers): # method, can be used from within async functions (such as this one). 
mp.set_start_method("spawn") with trio.open_signal_receiver(SIGINT, SIGTERM, SIGHUP) as sigiter: + logger.info(f"Server starting, {workers} worker processes") while True: while len(processes) < workers: p = mp.Process( @@ -448,16 +449,18 @@ async def runserver(acceptor, host, port, sock, backlog, workers): p.daemon = True p.start() processes.append(p) + logger.info("Worker [%s]", p.pid) # Wait for signals and periodically check processes with trio.move_on_after(0.1): s = await sigiter.__anext__() - if s in (SIGTERM, SIGINT): - break - logger.info("SIGHUP: Restarting all workers!") + logger.info(f"Server received {Signals(s).name}") for p in processes: p.terminate() + if s in (SIGTERM, SIGINT): + break processes = [p for p in processes if p.is_alive()] - else: # workers=0 single-process mode + else: # workers=0 + logger.info("Worker starting") await acceptor(listeners) finally: with trio.CancelScope() as cs: @@ -467,6 +470,8 @@ async def runserver(acceptor, host, port, sock, backlog, workers): await l.aclose() for p in processes: p.join() + logger.info("Server stopped") + async def runaccept( @@ -479,35 +484,36 @@ async def runaccept( graceful_shutdown_timeout, ): try: - pid = os.getpid() - logger.info("Starting worker [%s]", pid) async with trio.open_nursery() as main_nursery: await trigger_events(before_start) # Accept connections until a signal is received - async with trio.open_nursery() as acceptor: - acceptor.start_soon( - partial( - trio.serve_listeners, - handler=lambda stream: proto().run(stream), - listeners=listeners, - handler_nursery=main_nursery, + with trio.open_signal_receiver(SIGINT, SIGTERM, SIGHUP) as sigiter: + async with trio.open_nursery() as acceptor: + acceptor.start_soon( + partial( + trio.serve_listeners, + handler=lambda stream: proto().run(stream), + listeners=listeners, + handler_nursery=main_nursery, + ) ) - ) - await trigger_events(after_start) - # Wait for a signal and then exit gracefully - with trio.open_signal_receiver( - SIGINT, 
SIGTERM, SIGHUP - ) as sigiter: + await trigger_events(after_start) + # Wait for a signal and then exit gracefully s = await sigiter.__anext__() logger.info(f"Received {Signals(s).name}") acceptor.cancel_scope.cancel() - # No longer accepting new connections. Attempt graceful exit. - main_nursery.cancel_scope.deadline = ( - trio.current_time() + graceful_shutdown_timeout - ) - trigger_graceful_exit() - await trigger_events(before_stop) + # No longer accepting new connections. Attempt graceful exit. + main_nursery.cancel_scope.deadline = ( + trio.current_time() + graceful_shutdown_timeout + ) + trigger_graceful_exit() + main_nursery.start_soon(trigger_events, before_stop) + # Eat any extra signals (if server and workers were signaled) + with trio.move_on_after(0.01): + async for s in sigiter: + pass + # Now any further signals will cause stacktraces await trigger_events(after_stop) - logger.info(f"Gracefully finished worker [{pid}]") - except BaseException as e: - logger.exception(f"Stopped worker [{pid}]") + logger.info(f"Worker finished gracefully") + except BaseException: + logger.exception(f"Worker terminating") From 87bdb90da2b3f100425b9d8a4b436437929433fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Mon, 9 Sep 2019 18:21:33 +0300 Subject: [PATCH 28/31] Rework reloader. 
--- sanic/app.py | 15 ++++++-------- sanic/reloader_helpers.py | 19 ++++++++++++++++-- sanic/server_trio.py | 41 ++++++++++++++++++++++++--------------- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index ac63d6abc9..4b96e643e7 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1116,6 +1116,9 @@ def run( "https://sanic.readthedocs.io/en/latest/sanic/deploying.html" "#asynchronous-support" ) + self.is_first_process = ( + os.environ.get("SANIC_SERVER_RUNNING") != "true" + ) # Allow for overriding the default of following debug mode setting auto_reload = kwargs.get("auto_reload", debug) @@ -1153,10 +1156,7 @@ def run( try: self.is_running = True - if ( - auto_reload - and os.environ.get("SANIC_SERVER_RUNNING") != "true" - ): + if auto_reload and self.is_first_process: reloader_helpers.watchdog(2) else: serve(**server_settings, workers=workers) @@ -1395,10 +1395,7 @@ def _helper( if self.configure_logging and debug: logger.setLevel(logging.DEBUG) - if ( - self.config.LOGO - and os.environ.get("SANIC_SERVER_RUNNING") != "true" - ): + if self.config.LOGO and self.is_first_process: logger.debug( self.config.LOGO if isinstance(self.config.LOGO, str) @@ -1409,7 +1406,7 @@ def _helper( server_settings["run_async"] = True # Serve - if host and port and os.environ.get("SANIC_SERVER_RUNNING") != "true": + if host and port and self.is_first_process: proto = "http" if ssl is not None: proto = "https" diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py index c96d6bcf91..bb5de04f97 100644 --- a/sanic/reloader_helpers.py +++ b/sanic/reloader_helpers.py @@ -57,6 +57,19 @@ def restart_with_reloader(): new_environ["SANIC_SERVER_RUNNING"] = "true" return subprocess.Popen(args, cwd=cwd, env=new_environ) +def join(worker_process): + try: + # Graceful + worker_process.terminate() + worker_process.wait(2) + except subprocess.TimeoutExpired: + # Not so graceful + try: + worker_process.terminate() + worker_process.wait(1) + except 
subprocess.TimeoutExpired: + worker_process.kill() + worker_process.wait() def watchdog(sleep_interval): """Watch project files, restart worker process if a change happened. @@ -89,11 +102,13 @@ def terminate(sig, frame): mtimes[filename] = mtime need_reload = True - if quit or worker_process.poll(): + if worker_process.poll(): return if need_reload: - worker_process.terminate() + join(worker_process) worker_process = restart_with_reloader() sleep(sleep_interval) + + join(worker_process) \ No newline at end of file diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 0876439b95..5ebc9d47f6 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -472,7 +472,20 @@ async def runserver(acceptor, host, port, sock, backlog, workers): p.join() logger.info("Server stopped") - +async def sighandler(scopes, task_status=trio.TASK_STATUS_IGNORED): + with trio.open_signal_receiver(SIGINT, SIGTERM, SIGHUP) as sigiter: + t = None + task_status.started() + async for s in sigiter: + # Ignore spuriously repeated signals + if t is not None and trio.current_time() - t < 0.5: + logger.debug(f"Ignored {Signals(s).name}") + continue + logger.info(f"Received {Signals(s).name}") + if not scopes: + raise trio.Cancelled("Signaled too many times") + scopes.pop().cancel() + t = trio.current_time() async def runaccept( listeners, @@ -484,11 +497,15 @@ async def runaccept( graceful_shutdown_timeout, ): try: - async with trio.open_nursery() as main_nursery: - await trigger_events(before_start) - # Accept connections until a signal is received - with trio.open_signal_receiver(SIGINT, SIGTERM, SIGHUP) as sigiter: + async with trio.open_nursery() as signal_nursery: + cancel_scopes = [signal_nursery.cancel_scope] + sigscope = await signal_nursery.start(sighandler, cancel_scopes) + async with trio.open_nursery() as main_nursery: + cancel_scopes.append(main_nursery.cancel_scope) + await trigger_events(before_start) + # Accept connections until a signal is received async with 
trio.open_nursery() as acceptor: + cancel_scopes.append(acceptor.cancel_scope) acceptor.start_soon( partial( trio.serve_listeners, @@ -498,22 +515,14 @@ async def runaccept( ) ) await trigger_events(after_start) - # Wait for a signal and then exit gracefully - s = await sigiter.__anext__() - logger.info(f"Received {Signals(s).name}") - acceptor.cancel_scope.cancel() # No longer accepting new connections. Attempt graceful exit. main_nursery.cancel_scope.deadline = ( trio.current_time() + graceful_shutdown_timeout ) trigger_graceful_exit() - main_nursery.start_soon(trigger_events, before_stop) - # Eat any extra signals (if server and workers were signaled) - with trio.move_on_after(0.01): - async for s in sigiter: - pass - # Now any further signals will cause stacktraces - await trigger_events(after_stop) + await trigger_events(before_stop) + await trigger_events(after_stop) + signal_nursery.cancel_scope.cancel() # Exit signal handler logger.info(f"Worker finished gracefully") except BaseException: logger.exception(f"Worker terminating") From 53f088e0cf3c5584ece8f61de5082970a9514b78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Mon, 9 Sep 2019 18:33:15 +0300 Subject: [PATCH 29/31] Faster and more correct reload and exit. --- sanic/reloader_helpers.py | 73 ++++++++++++++++++++------------------- sanic/server_trio.py | 2 +- 2 files changed, 39 insertions(+), 36 deletions(-) diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py index bb5de04f97..eac192a31f 100644 --- a/sanic/reloader_helpers.py +++ b/sanic/reloader_helpers.py @@ -77,38 +77,41 @@ def watchdog(sleep_interval): :param sleep_interval: interval in second. 
:return: Nothing """ - mtimes = {} - worker_process = restart_with_reloader() - quit = False - def terminate(sig, frame): - nonlocal quit - quit = True - worker_process.terminate() - signal.signal(signal.SIGTERM, terminate) - signal.signal(signal.SIGINT, terminate) - while not quit: - need_reload = False - - for filename in _iter_module_files(): - try: - mtime = os.stat(filename).st_mtime - except OSError: - continue - - old_time = mtimes.get(filename) - if old_time is None: - mtimes[filename] = mtime - elif mtime > old_time: - mtimes[filename] = mtime - need_reload = True - - if worker_process.poll(): - return - - if need_reload: - join(worker_process) - worker_process = restart_with_reloader() - - sleep(sleep_interval) - - join(worker_process) \ No newline at end of file + try: + mtimes = {} + worker_process = restart_with_reloader() + quit = False + def terminate(sig, frame): + nonlocal quit + quit = True + worker_process.terminate() + signal.signal(signal.SIGTERM, terminate) + signal.signal(signal.SIGINT, terminate) + while not quit: + for i in range(10): + sleep(sleep_interval / 10) + if quit or worker_process.poll(): + return + + need_reload = False + + for filename in _iter_module_files(): + try: + mtime = os.stat(filename).st_mtime + except OSError: + continue + + old_time = mtimes.get(filename) + if old_time is None: + mtimes[filename] = mtime + elif mtime > old_time: + mtimes[filename] = mtime + need_reload = True + + if need_reload: + worker_process.terminate() + sleep(0.1) + worker_process = restart_with_reloader() + + finally: + join(worker_process) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 5ebc9d47f6..f4d8d4bafd 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -460,7 +460,7 @@ async def runserver(acceptor, host, port, sock, backlog, workers): break processes = [p for p in processes if p.is_alive()] else: # workers=0 - logger.info("Worker starting") + logger.info("Server and worker started") await acceptor(listeners) 
finally: with trio.CancelScope() as cs: From 870b6012a1b4f8ad4e1b4569a0a22a8854bcbb49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Tue, 10 Sep 2019 17:31:09 +0300 Subject: [PATCH 30/31] Optimized to run faster. --- sanic/app.py | 2 +- sanic/protocol.py | 9 +---- sanic/request.py | 2 +- sanic/server_trio.py | 81 +++++++++++++++++++++++++------------------- 4 files changed, 49 insertions(+), 45 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 4b96e643e7..2df3a09df5 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1030,7 +1030,7 @@ async def handle_request_trio(self, request): request.endpoint = self._build_endpoint_name(*bp, handler.__name__) request.uri_template = uri # Load header body before starting handler? - if not hasattr(handler, "is_stream"): + if request.stream.length and not hasattr(handler, "is_stream"): await request.receive_body() # Run main handler response = handler(request, *args, **kwargs) diff --git a/sanic/protocol.py b/sanic/protocol.py index bb123218a6..66d92ba202 100644 --- a/sanic/protocol.py +++ b/sanic/protocol.py @@ -12,15 +12,13 @@ class H1Stream: __slots__ = ("stream", "length", "pos", "set_timeout", "response_state", "status", "headers", "bytes_left") - def __init__(self, headers, stream, set_timeout, need_continue): + def __init__(self, headers, stream, need_continue): self.length = int(headers.get("content-length", "0")) assert self.length >= 0 self.pos = None if need_continue else 0 self.stream = stream self.status = self.bytes_left = None self.response_state = 0 - self.set_timeout = set_timeout - self.update_deadline() async def aclose(self): # Finish sending a response (if no error) @@ -30,10 +28,6 @@ async def aclose(self): if self.pos != self.length or self.response_state != 2: await self.stream.aclose() # If not, must disconnect :( - def update_deadline(self): - # Extend or switch deadline - self.set_timeout("request" if self.pos is not None and self.pos < self.length else "request") - # 
Request methods def dont_continue(self): @@ -63,7 +57,6 @@ async def read(self): self.stream.push_back(buf[self.length :]) buf = buf[: self.length] self.pos += len(buf) - self.update_deadline() return buf # Response methods diff --git a/sanic/request.py b/sanic/request.py index 594126e6a8..fa68c0b18f 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -136,7 +136,7 @@ def body_finish(self): self.body = b"".join(self.body) async def receive_body(self): - if not self.stream.pos: + if self.stream.length and not self.stream.pos: max_size = self.app.config.REQUEST_MAX_SIZE body = [] if self.stream.length > max_size: diff --git a/sanic/server_trio.py b/sanic/server_trio.py index f4d8d4bafd..0bf9c36c91 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -83,16 +83,14 @@ def parse_h1_request(data: bytes) -> dict: if version != "HTTP/1.1": raise VersionNotSupported(f"Expected 'HTTP/1.1', got '{version}'") headers = {":method": method, ":path": path} - for name, value in (h.split(": ", 1) for h in hlines): + for name, val in (h.split(": ", 1) for h in hlines): name = name.lower() - old = headers.get(name) - headers[name] = value if old is None else f"{old}, {value}" + headers[name] = f"{headers[name]}, {val}" if name in headers else val return headers def push_back(stream, data): - if not data: - return + assert data, "Ensure that data is not empty before calling this function." 
stream_type = type(stream) class PushbackStream(stream_type): @@ -127,15 +125,6 @@ def servername_callback(sock, req_hostname, cb_context): async def sniff_protocol(self): req = await self.receive_request() - if isinstance(req, bytearray): - # HTTP1 but might be Upgrade to websocket or h2c - headers = parse_h1_request(req) - upgrade = headers.get("upgrade") - if upgrade == "h2c": - return self.http2(settings_header=headers["http2-settings"]) - if upgrade == "websocket": - self.websocket = True - return self.http1(headers=headers) if req is SSL: if not self.ssl: raise RuntimeError("Only plain HTTP supported (not SSL).") @@ -148,6 +137,14 @@ async def sniff_protocol(self): # HTTP2 (not Upgrade) if req is H2: return self.http2() + # HTTP1 but might be Upgrade to websocket or h2c + headers = parse_h1_request(req) + upgrade = headers.get("upgrade") + if upgrade == "h2c": + return self.http2(settings_header=headers["http2-settings"]) + if upgrade == "websocket": + self.websocket = True + return self.http1(headers=headers) def set_timeout(self, timeout: str): self.nursery.cancel_scope.deadline = trio.current_time() + getattr( @@ -175,25 +172,32 @@ async def run(self, stream): async def receive_request(self): idle_connections.add(self.nursery.cancel_scope) - with trio.fail_after(self.request_timeout): - buffer = bytearray() - async for data in self.stream: - idle_connections.discard(self.nursery.cancel_scope) - prevpos = max(0, len(buffer) - 3) + buffer = None + prevpos = 0 + async for data in self.stream: + idle_connections.discard(self.nursery.cancel_scope) + if buffer is None: + buffer = data + else: + # Headers normally come in one packet, so this may be slower buffer += data - if buffer[0] == 0x16: + buflen = len(buffer) + if buffer[0] == 0x16: + self.stream.push_back(buffer) + return SSL + if buflen > self.request_max_size: + raise RuntimeError("Request larger than request_max_size") + pos = buffer.find(b"\r\n\r\n", prevpos) + if pos > 0: + req = buffer[:pos] + if 
req == b"PRI * HTTP/2.0": self.stream.push_back(buffer) - return SSL - if len(buffer) > self.request_max_size: - raise RuntimeError("Request larger than request_max_size") - pos = buffer.find(b"\r\n\r\n", prevpos) - if pos > 0: - req = buffer[:pos] - if req == b"PRI * HTTP/2.0": - self.stream.push_back(buffer) - return H2 - self.stream.push_back(buffer[pos+4:]) - return req + return H2 + pos += 4 # Skip header and its trailing \r\n\r\n + if buflen > pos: + self.stream.push_back(buffer[pos:]) + return req + prevpos = buflen - 3 # \r\n\r\n may cross packet boundary if buffer: raise RuntimeError(f"Peer disconnected after {buffer!r:.200}") @@ -206,6 +210,7 @@ async def http1(self, headers=None): if not req: return headers = parse_h1_request(req) + request = self.request_class( url_bytes=headers[":path"].encode(), headers=Header(headers), @@ -218,10 +223,12 @@ async def http1(self, headers=None): if "chunked" in headers.get("transfer-encoding", "").lower(): raise RuntimeError("Chunked requests not supported") # FIXME - request.stream = H1Stream( - headers, self.stream, self.set_timeout, need_continue - ) + + request.stream = H1Stream(headers, self.stream, need_continue) headers = None + # Timeout between consecutive requests, reset *here* to minimize + # chances of timeouting request handlers. 
+ self.set_timeout("request") try: await self.request_handler(request) except trio.BrokenResourceError: @@ -238,7 +245,6 @@ async def http1(self, headers=None): raise finally: await request.stream.aclose() - self.set_timeout("request") async def h2_sender(self): async for _ in self.can_send: @@ -487,6 +493,11 @@ async def sighandler(scopes, task_status=trio.TASK_STATUS_IGNORED): scopes.pop().cancel() t = trio.current_time() +async def runbench(stream): + async for d in stream: + if d[-4:] == b"\r\n\r\n": + await stream.send_all(b"HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\ncontent-length: 13\r\n\r\nHello World!\n") + async def runaccept( listeners, before_start, From dda029657bb13bc53f6c26dab137794ee28ea9f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= Date: Tue, 10 Sep 2019 18:43:58 +0300 Subject: [PATCH 31/31] Fix receive_request return value handling (should probably use exceptions instead). --- sanic/server_trio.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sanic/server_trio.py b/sanic/server_trio.py index 0bf9c36c91..8902dcc45e 100644 --- a/sanic/server_trio.py +++ b/sanic/server_trio.py @@ -137,6 +137,8 @@ async def sniff_protocol(self): # HTTP2 (not Upgrade) if req is H2: return self.http2() + if req is None: + return None # HTTP1 but might be Upgrade to websocket or h2c headers = parse_h1_request(req) upgrade = headers.get("upgrade") @@ -207,7 +209,7 @@ async def http1(self, headers=None): # Process request if headers is None: req = await self.receive_request() - if not req: + if not isinstance(req, (bytes, bytearray)): return headers = parse_h1_request(req)