diff --git a/sanic/app.py b/sanic/app.py index 58c785b095..34016ddc48 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -22,15 +22,13 @@ from sanic.exceptions import SanicException, ServerError, URLBuildError from sanic.handlers import ErrorHandler from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger -from sanic.response import HTTPResponse, StreamingHTTPResponse -from sanic.router import Router -from sanic.server import ( - AsyncioServer, - HttpProtocol, - Signal, - serve, - serve_multiple, +from sanic.response import ( + BaseHTTPResponse, + HTTPResponse, + StreamingHTTPResponse, ) +from sanic.router import Router +from sanic.server import HttpProtocol, Signal, serve from sanic.static import register as static_register from sanic.testing import SanicASGITestClient, SanicTestClient from sanic.views import CompositionView @@ -1019,7 +1017,41 @@ async def handle_request(self, request, write_callback, stream_callback): # - Add exception handling pass else: - write_callback(response) + await write_callback(response) + + async def handle_request_trio(self, request): + # Request middleware + response = await self._run_request_middleware(request) + if response: + return await request.respond(response) + # Fetch handler from router + handler, args, kwargs, uri = self.router.get(request) + if handler is None: + raise ServerError( + "None was returned while requesting a handler from the router" + ) + bp = getattr(handler, "__blueprintname__", None) + bp = (bp,) if bp else () + request.endpoint = self._build_endpoint_name(*bp, handler.__name__) + request.uri_template = uri + # Load header body before starting handler? + if request.stream.length and not hasattr(handler, "is_stream"): + await request.receive_body() + # Run main handler + response = handler(request, *args, **kwargs) + if isawaitable(response): + response = await response + # Returned (non-streaming) response + if isinstance(response, BaseHTTPResponse): + await request.respond( + status=response.status, + headers=response.headers, + content_type=response.content_type, + ).send(data_bytes=response.body, end_stream=True) + elif response is not None: + raise ServerError( + f"Handling {request.path}: HTTPResponse expected but got {type(response).__name__} {response!r:.200}" + ) # -------------------------------------------------------------------- # # Testing @@ -1050,7 +1082,7 @@ def run( stop_event: Any = None, register_sys_signals: bool = True, access_log: Optional[bool] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """Run the HTTP Server and listen until keyboard interrupt or term signal. On termination, drain connections before closing. @@ -1089,14 +1121,12 @@ def run( "https://sanic.readthedocs.io/en/latest/sanic/deploying.html" "#asynchronous-support" ) + self.is_first_process = ( + os.environ.get("SANIC_SERVER_RUNNING") != "true" + ) - # Default auto_reload to false - auto_reload = False - # If debug is set, default it to true (unless on windows) - if debug and os.name == "posix": - auto_reload = True - # Allow for overriding either of the defaults - auto_reload = kwargs.get("auto_reload", auto_reload) + # Allow for overriding the default of following debug mode setting + auto_reload = kwargs.get("auto_reload", debug) if sock is None: host, port = host or "127.0.0.1", port or 8000 @@ -1131,21 +1161,10 @@ def run( try: self.is_running = True - if workers == 1: - if auto_reload and os.name != "posix": - # This condition must be removed after implementing - # auto reloader for other operating systems. 
- raise NotImplementedError - - if ( - auto_reload - and os.environ.get("SANIC_SERVER_RUNNING") != "true" - ): - reloader_helpers.watchdog(2) - else: - serve(**server_settings) + if auto_reload and self.is_first_process: + reloader_helpers.watchdog(2) else: - serve_multiple(server_settings, workers) + serve(**server_settings, workers=workers) except BaseException: error_logger.exception( "Experienced exception while trying to serve" @@ -1153,7 +1172,6 @@ def run( raise finally: self.is_running = False - logger.info("Server Stopped") def stop(self): """This kills the Sanic""" @@ -1312,6 +1330,7 @@ def _helper( raise ValueError("SSLContext or certificate and key required.") context = create_default_context(purpose=Purpose.CLIENT_AUTH) context.load_cert_chain(cert, keyfile=key) + context.set_alpn_protocols(["h2", "http/1.1"]) ssl = context if stop_event is not None: if debug: @@ -1342,7 +1361,7 @@ def _helper( "app": self, "signal": Signal(), "debug": debug, - "request_handler": self.handle_request, + "request_handler": self.handle_request_trio, "error_handler": self.error_handler, "request_timeout": self.config.REQUEST_TIMEOUT, "response_timeout": self.config.RESPONSE_TIMEOUT, @@ -1381,10 +1400,7 @@ def _helper( if self.configure_logging and debug: logger.setLevel(logging.DEBUG) - if ( - self.config.LOGO - and os.environ.get("SANIC_SERVER_RUNNING") != "true" - ): + if self.config.LOGO and self.is_first_process: logger.debug( self.config.LOGO if isinstance(self.config.LOGO, str) @@ -1395,7 +1411,7 @@ def _helper( server_settings["run_async"] = True # Serve - if host and port and os.environ.get("SANIC_SERVER_RUNNING") != "true": + if host and port and self.is_first_process: proto = "http" if ssl is not None: proto = "https" diff --git a/sanic/cookies.py b/sanic/cookies.py index 19907945e1..ed672fba1a 100644 --- a/sanic/cookies.py +++ b/sanic/cookies.py @@ -130,6 +130,10 @@ def encode(self, encoding): :return: Cookie encoded in a codec of choosing. 
    :except: UnicodeEncodeError
    """
+        return str(self).encode(encoding)
+
+    def __str__(self):
+        """Format as a Set-Cookie header value."""
         output = ["%s=%s" % (self.key, _quote(self.value))]
         for key, value in self.items():
             if key == "max-age":
@@ -147,4 +151,4 @@ def encode(self, encoding):
         else:
             output.append("%s=%s" % (self._keys[key], value))
 
-        return "; ".join(output).encode(encoding)
+        return "; ".join(output)
diff --git a/sanic/exceptions.py b/sanic/exceptions.py
index 2c4ab2c02e..728235133e 100644
--- a/sanic/exceptions.py
+++ b/sanic/exceptions.py
@@ -175,6 +175,11 @@ class ServiceUnavailable(SanicException):
     pass
 
 
+@add_status_code(505)
+class VersionNotSupported(SanicException):
+    pass
+
+
 class URLBuildError(ServerError):
     pass
 
diff --git a/sanic/headers.py b/sanic/headers.py
index 142ab27bf1..265bbf401d 100644
--- a/sanic/headers.py
+++ b/sanic/headers.py
@@ -1,9 +1,12 @@
 import re
 
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 from urllib.parse import unquote
 
+from sanic.helpers import STATUS_CODES
+
 
+HeaderIterable = Iterable[Tuple[str, Any]]  # Values convertible to str
 Options = Dict[str, Union[int, str]]  # key=value fields in various headers
 OptionsIterable = Iterable[Tuple[str, str]]  # May contain duplicate keys
 
@@ -170,3 +173,32 @@ def parse_host(host: str) -> Tuple[Optional[str], Optional[int]]:
         return None, None
     host, port = m.groups()
     return host.lower(), int(port) if port is not None else None
+
+
+def format_http1(headers: HeaderIterable) -> bytes:
+    """Convert a headers iterable into HTTP/1 header format.
+
+    - Outputs UTF-8 bytes where each header line ends with \\r\\n.
+    - Values are converted into strings if necessary.
+    """
+    return "".join(f"{name}: {val}\r\n" for name, val in headers).encode()
+
+
+def format_http1_response(
+    status: int, headers: HeaderIterable, body: Optional[bytes] = None
+) -> bytes:
+    """Format a full HTTP/1.1 response.
+
+    - If `body` is included, content-length must be specified in headers.
+    """
+    if body is None:
+        body = b""
+    headers = format_http1(headers)
+    if status == 200:
+        return b"HTTP/1.1 200 OK\r\n%b\r\n%b" % (headers, body)
+    return b"HTTP/1.1 %d %b\r\n%b\r\n%b" % (
+        status,
+        STATUS_CODES.get(status, b"UNKNOWN"),
+        headers,
+        body,
+    )
diff --git a/sanic/protocol.py b/sanic/protocol.py
new file mode 100644
index 0000000000..66d92ba202
--- /dev/null
+++ b/sanic/protocol.py
@@ -0,0 +1,148 @@
+from sanic.headers import format_http1, format_http1_response
+from sanic.helpers import has_message_body, remove_entity_headers
+
+
+# FIXME: Put somewhere before response (kept commented out here, since
+# `await` is not valid at module level):
+#     # Middleware has a chance to replace or modify the response
+#     response = await self.app._run_response_middleware(
+#         request, response
+#     )
+
+class H1Stream:
+    __slots__ = (
+        "stream",
+        "length",
+        "pos",
+        "set_timeout",
+        "response_state",
+        "status",
+        "headers",
+        "bytes_left",
+    )
+
+    def __init__(self, headers, stream, need_continue):
+        self.length = int(headers.get("content-length", "0"))
+        assert self.length >= 0
+        self.pos = None if need_continue else 0
+        self.stream = stream
+        self.status = self.bytes_left = None
+        self.response_state = 0
+
+    async def aclose(self):
+        # Finish sending a response (if no error)
+        if self.response_state < 2:
+            await self.send(end_stream=True)
+        # Response fully sent, request fully read?
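+        # (If the request body was not fully read or the response never
+        # completed, HTTP/1 framing on this connection is lost and it
+        # cannot be reused for another request.)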
+        if self.pos != self.length or self.response_state != 2:
+            await self.stream.aclose()  # If not, must disconnect :(
+
+    # Request methods
+
+    def dont_continue(self):
+        """Prevent a pending 100 Continue response being sent, and avoid
+        receiving the request body. Does not by itself send a 417 response."""
+        if self.pos is None:
+            self.pos = self.length = 0
+
+    async def trigger_continue(self):
+        if self.pos is None:
+            self.pos = 0
+            await self.stream.send_all(b"HTTP/1.1 100 Continue\r\n\r\n")
+
+    async def __aiter__(self):
+        while True:
+            data = await self.read()
+            if not data:
+                return
+            yield data
+
+    async def read(self):
+        await self.trigger_continue()
+        if self.pos == self.length:
+            return None
+        buf = await self.stream.receive_some()
+        # Only consume the bytes still belonging to this request's body;
+        # push any excess (the start of the next request) back on the stream.
+        remaining = self.length - self.pos
+        if len(buf) > remaining:
+            self.stream.push_back(buf[remaining:])
+            buf = buf[:remaining]
+        self.pos += len(buf)
+        return buf
+
+    # Response methods
+
+    def respond(self, status, headers):
+        if self.response_state > 0:
+            self.response_state = 3  # FAIL mode
+            raise RuntimeError("Response already started")
+        self.status = status
+        self.headers = headers
+        return self
+
+    async def send(self, data=None, data_bytes=None, end_stream=False):
+        """Send any pending response headers and the given data as body.
+
+        :param data: str-convertible data to be written
+        :param data_bytes: bytes-ish data to be written (used if data is None)
+        :param end_stream: whether to close the stream after this block
+        """
+        # Note the status before data_to_send clears it; it is needed for
+        # the 417 check below.
+        status = self.status
+        data = self.data_to_send(data, data_bytes, end_stream)
+        if data is None:
+            return
+        # Check if the request expects a 100-continue first
+        if self.pos is None:
+            if status == 417:
+                self.dont_continue()
+            else:
+                await self.trigger_continue()
+        # Send response
+        await self.stream.send_all(data)
+
+    def data_to_send(self, data, data_bytes, end_stream):
+        """Format output data bytes for given body data.
+
+        Headers are prepended to the first output block and then cleared.
+
+        :param data: str-convertible data to be written
+        :param data_bytes: bytes-ish data to be written (used if data is None)
+        :return: bytes to send, or None if there is nothing to send
+        """
+        data = data_bytes if data is None else f"{data}".encode()
+        size = len(data) if data is not None else 0
+
+        # Headers not yet sent?
+        if self.response_state == 0:
+            status, headers = self.status, self.headers
+            if status in (304, 412):
+                headers = remove_entity_headers(headers)
+            if not has_message_body(status):
+                # Header-only response status
+                assert (
+                    size == 0 and end_stream
+                ), f"A {status} response may only have headers, no body."
+                assert "content-length" not in headers
+                assert "transfer-encoding" not in headers
+            elif end_stream:
+                # Non-streaming response (all in one block)
+                headers["content-length"] = size
+            elif "content-length" in headers:
+                # Streaming response with size known in advance
+                self.bytes_left = int(headers["content-length"]) - size
+                assert self.bytes_left >= 0
+            else:
+                # Length not known, use chunked encoding
+                headers["transfer-encoding"] = "chunked"
+                data = b"%x\r\n%b\r\n" % (size, data) if size else None
+                self.bytes_left = ...
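+                # `...` (Ellipsis) is a sentinel meaning "chunked transfer
+                # in progress"; it is matched with `self.bytes_left is ...`
+                # in the blocks below.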
+            self.status = self.headers = None
+            self.response_state = 2 if end_stream else 1
+            return format_http1_response(status, headers.items(), data)
+
+        if self.response_state == 2:
+            if size:
+                raise RuntimeError("Cannot send data to a closed stream")
+            return
+
+        self.response_state = 2 if end_stream else 1
+
+        # Chunked encoding
+        if self.bytes_left is ...:
+            if end_stream:
+                self.bytes_left = None
+                if size:
+                    return b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
+                return b"0\r\n\r\n"
+            return b"%x\r\n%b\r\n" % (size, data) if size else None
+
+        # Normal encoding
+        if isinstance(self.bytes_left, int):
+            self.bytes_left -= size
+            assert self.bytes_left >= 0
+            return data if size else None
diff --git a/sanic/reloader_helpers.py b/sanic/reloader_helpers.py
index b58391f64b..eac192a31f 100644
--- a/sanic/reloader_helpers.py
+++ b/sanic/reloader_helpers.py
@@ -3,7 +3,6 @@
 import subprocess
 import sys
 
-from multiprocessing import Process
 from time import sleep
 
 
@@ -56,81 +55,21 @@ def restart_with_reloader():
     args = _get_args_for_reloading()
     new_environ = os.environ.copy()
     new_environ["SANIC_SERVER_RUNNING"] = "true"
-    cmd = " ".join(args)
-    worker_process = Process(
-        target=subprocess.call,
-        args=(cmd,),
-        kwargs={"cwd": cwd, "shell": True, "env": new_environ},
-    )
-    worker_process.start()
-    return worker_process
-
-
-def kill_process_children_unix(pid):
-    """Find and kill child processes of a process (maximum two level).
-
-    :param pid: PID of parent process (process ID)
-    :return: Nothing
-    """
-    root_process_path = "/proc/{pid}/task/{pid}/children".format(pid=pid)
-    if not os.path.isfile(root_process_path):
-        return
-    with open(root_process_path) as children_list_file:
-        children_list_pid = children_list_file.read().split()
-
-    for child_pid in children_list_pid:
-        children_proc_path = "/proc/%s/task/%s/children" % (
-            child_pid,
-            child_pid,
-        )
-        if not os.path.isfile(children_proc_path):
-            continue
-        with open(children_proc_path) as children_list_file_2:
-            children_list_pid_2 = children_list_file_2.read().split()
-        for _pid in children_list_pid_2:
-            try:
-                os.kill(int(_pid), signal.SIGTERM)
-            except ProcessLookupError:
-                continue
+    return subprocess.Popen(args, cwd=cwd, env=new_environ)
+
+def join(worker_process):
+    try:
+        # Graceful
+        worker_process.terminate()
+        worker_process.wait(2)
+    except subprocess.TimeoutExpired:
+        # Not so graceful
         try:
-            os.kill(int(child_pid), signal.SIGTERM)
-        except ProcessLookupError:
-            continue
-
-
-def kill_process_children_osx(pid):
-    """Find and kill child processes of a process.
-
-    :param pid: PID of parent process (process ID)
-    :return: Nothing
-    """
-    subprocess.run(["pkill", "-P", str(pid)])
-
-
-def kill_process_children(pid):
-    """Find and kill child processes of a process.
-
-    :param pid: PID of parent process (process ID)
-    :return: Nothing
-    """
-    if sys.platform == "darwin":
-        kill_process_children_osx(pid)
-    elif sys.platform == "linux":
-        kill_process_children_unix(pid)
-    else:
-        pass  # should signal error here
-
-
-def kill_program_completly(proc):
-    """Kill worker and it's child processes and exit.
-
-    :param proc: worker process (process ID)
-    :return: Nothing
-    """
-    kill_process_children(proc.pid)
-    proc.terminate()
-    os._exit(0)
-
+            worker_process.terminate()
+            worker_process.wait(1)
+        except subprocess.TimeoutExpired:
+            worker_process.kill()
+            worker_process.wait()
 
 def watchdog(sleep_interval):
     """Watch project files, restart worker process if a change happened.
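The next hunk rebuilds `watchdog` around the `subprocess.Popen` worker created above. For orientation, a minimal sketch of the same watch-and-restart pattern, with a hypothetical `paths` argument standing in for `_iter_module_files()` (an illustration, not the patch's code):

```python
import os
import subprocess
import sys
from time import sleep


def start_worker():
    # Re-run this program in a child process; the env flag tells it to
    # serve rather than watch (mirrors SANIC_SERVER_RUNNING above).
    env = {**os.environ, "SANIC_SERVER_RUNNING": "true"}
    return subprocess.Popen([sys.executable, *sys.argv], env=env)


def watch(paths, interval=2):
    mtimes = {}
    worker = start_worker()
    try:
        # Restart the worker whenever a watched file's mtime advances.
        while worker.poll() is None:
            sleep(interval)
            for path in paths:
                try:
                    mtime = os.stat(path).st_mtime
                except OSError:
                    continue
                if mtimes.setdefault(path, mtime) < mtime:
                    mtimes[path] = mtime
                    worker.terminate()
                    worker.wait()
                    worker = start_worker()
                    break
    finally:
        worker.terminate()
```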
@@ -138,30 +77,41 @@ def watchdog(sleep_interval): :param sleep_interval: interval in second. :return: Nothing """ - mtimes = {} - worker_process = restart_with_reloader() - signal.signal( - signal.SIGTERM, lambda *args: kill_program_completly(worker_process) - ) - signal.signal( - signal.SIGINT, lambda *args: kill_program_completly(worker_process) - ) - while True: - for filename in _iter_module_files(): - try: - mtime = os.stat(filename).st_mtime - except OSError: - continue - - old_time = mtimes.get(filename) - if old_time is None: - mtimes[filename] = mtime - continue - elif mtime > old_time: - kill_process_children(worker_process.pid) + try: + mtimes = {} + worker_process = restart_with_reloader() + quit = False + def terminate(sig, frame): + nonlocal quit + quit = True + worker_process.terminate() + signal.signal(signal.SIGTERM, terminate) + signal.signal(signal.SIGINT, terminate) + while not quit: + for i in range(10): + sleep(sleep_interval / 10) + if quit or worker_process.poll(): + return + + need_reload = False + + for filename in _iter_module_files(): + try: + mtime = os.stat(filename).st_mtime + except OSError: + continue + + old_time = mtimes.get(filename) + if old_time is None: + mtimes[filename] = mtime + elif mtime > old_time: + mtimes[filename] = mtime + need_reload = True + + if need_reload: worker_process.terminate() + sleep(0.1) worker_process = restart_with_reloader() - mtimes[filename] = mtime - break - sleep(sleep_interval) + finally: + join(worker_process) diff --git a/sanic/request.py b/sanic/request.py index 246eb351ef..e40a6b1f9e 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -9,7 +9,8 @@ from httptools import parse_url # type: ignore -from sanic.exceptions import InvalidUsage +from sanic.compat import Header +from sanic.exceptions import HeaderExpectationFailed, InvalidUsage from sanic.headers import ( parse_content_header, parse_forwarded, @@ -161,6 +162,24 @@ def body_push(self, data): def body_finish(self): self.body = b"".join(self.body) + async def receive_body(self): + if self.stream.length and not self.stream.pos: + max_size = self.app.config.REQUEST_MAX_SIZE + body = [] + if self.stream.length > max_size: + raise HeaderExpectationFailed("Request body is too large.") + async for data in self.stream: + if self.stream.pos > max_size: + raise HeaderExpectationFailed("Request body is too large.") + body.append(data) + self.body = b"".join(body) + + def respond(self, status=200, headers=None, content_type="text/html"): + headers = Header(headers or {}) + if "content-type" not in headers: + headers["content-type"] = content_type + return self.stream.respond(status, headers) + @property def json(self): if self.parsed_json is None: diff --git a/sanic/response.py b/sanic/response.py index 91fb25f4da..14fe78a227 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -7,7 +7,8 @@ from sanic.compat import Header from sanic.cookies import CookieJar -from sanic.helpers import STATUS_CODES, has_message_body, remove_entity_headers +from sanic.headers import format_http1, format_http1_response +from sanic.helpers import has_message_body, remove_entity_headers try: @@ -30,20 +31,7 @@ def _encode_body(self, data): return str(data).encode() def _parse_headers(self): - headers = b"" - for name, value in self.headers.items(): - try: - headers += b"%b: %b\r\n" % ( - name.encode(), - value.encode("utf-8"), - ) - except AttributeError: - headers += b"%b: %b\r\n" % ( - str(name).encode(), - str(value).encode("utf-8"), - ) - - return headers + return 
format_http1(self.headers.items()) @property def cookies(self): @@ -116,33 +104,17 @@ async def stream( def get_headers( self, version="1.1", keep_alive=False, keep_alive_timeout=None ): - # This is all returned in a kind-of funky way - # We tried to make this as fast as possible in pure python - timeout_header = b"" + if "Content-Type" not in self.headers: + self.headers["Content-Type"] = self.content_type + if keep_alive and keep_alive_timeout is not None: - timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout + self.headers["Keep-Alive"] = keep_alive_timeout if self.chunked and version == "1.1": self.headers["Transfer-Encoding"] = "chunked" self.headers.pop("Content-Length", None) - self.headers["Content-Type"] = self.headers.get( - "Content-Type", self.content_type - ) - headers = self._parse_headers() - - if self.status == 200: - status = b"OK" - else: - status = STATUS_CODES.get(self.status) - - return (b"HTTP/%b %d %b\r\n" b"%b" b"%b\r\n") % ( - version.encode(), - self.status, - status, - timeout_header, - headers, - ) + return format_http1_response(self.status, self.headers.items()) class HTTPResponse(BaseHTTPResponse): @@ -168,11 +140,8 @@ def __init__( self._cookies = None def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): - # This is all returned in a kind-of funky way - # We tried to make this as fast as possible in pure python - timeout_header = b"" - if keep_alive and keep_alive_timeout is not None: - timeout_header = b"Keep-Alive: %d\r\n" % keep_alive_timeout + if "Content-Type" not in self.headers: + self.headers["Content-Type"] = self.content_type body = b"" if has_message_body(self.status): @@ -188,24 +157,13 @@ def output(self, version="1.1", keep_alive=False, keep_alive_timeout=None): if self.status in (304, 412): self.headers = remove_entity_headers(self.headers) - headers = self._parse_headers() + if keep_alive and keep_alive_timeout is not None: + self.headers["Connection"] = "keep-alive" + self.headers["Keep-Alive"] = keep_alive_timeout + elif not keep_alive: + self.headers["Connection"] = "close" - if self.status == 200: - status = b"OK" - else: - status = STATUS_CODES.get(self.status, b"UNKNOWN RESPONSE") - - return ( - b"HTTP/%b %d %b\r\n" b"Connection: %b\r\n" b"%b" b"%b\r\n" b"%b" - ) % ( - version.encode(), - self.status, - status, - b"keep-alive" if keep_alive else b"close", - timeout_header, - headers, - body, - ) + return format_http1_response(self.status, self.headers.items(), body) @property def cookies(self): diff --git a/sanic/server.py b/sanic/server.py index 41af81c0fc..f02a72f5ca 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -1,990 +1 @@ -import asyncio -import os -import traceback - -from collections import deque -from functools import partial -from inspect import isawaitable -from multiprocessing import Process -from signal import SIG_IGN, SIGINT, SIGTERM, Signals -from signal import signal as signal_func -from socket import SO_REUSEADDR, SOL_SOCKET, socket -from time import time - -from httptools import HttpRequestParser # type: ignore -from httptools.parser.errors import HttpParserError # type: ignore - -from sanic.compat import Header -from sanic.exceptions import ( - HeaderExpectationFailed, - InvalidUsage, - PayloadTooLarge, - RequestTimeout, - ServerError, - ServiceUnavailable, -) -from sanic.log import access_logger, logger -from sanic.request import EXPECT_HEADER, Request, StreamBuffer -from sanic.response import HTTPResponse - - -try: - import uvloop # type: ignore - - if not 
isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -except ImportError: - pass - - -class Signal: - stopped = False - - -class HttpProtocol(asyncio.Protocol): - """ - This class provides a basic HTTP implementation of the sanic framework. - """ - - __slots__ = ( - # app - "app", - # event loop, connection - "loop", - "transport", - "connections", - "signal", - # request params - "parser", - "request", - "url", - "headers", - # request config - "request_handler", - "request_timeout", - "response_timeout", - "keep_alive_timeout", - "request_max_size", - "request_buffer_queue_size", - "request_class", - "is_request_stream", - "router", - "error_handler", - # enable or disable access log purpose - "access_log", - # connection management - "_total_request_size", - "_request_timeout_handler", - "_response_timeout_handler", - "_keep_alive_timeout_handler", - "_last_request_time", - "_last_response_time", - "_is_stream_handler", - "_not_paused", - "_request_handler_task", - "_request_stream_task", - "_keep_alive", - "_header_fragment", - "state", - "_debug", - ) - - def __init__( - self, - *, - loop, - app, - request_handler, - error_handler, - signal=Signal(), - connections=None, - request_timeout=60, - response_timeout=60, - keep_alive_timeout=5, - request_max_size=None, - request_buffer_queue_size=100, - request_class=None, - access_log=True, - keep_alive=True, - is_request_stream=False, - router=None, - state=None, - debug=False, - **kwargs - ): - self.loop = loop - self.app = app - self.transport = None - self.request = None - self.parser = None - self.url = None - self.headers = None - self.router = router - self.signal = signal - self.access_log = access_log - self.connections = connections if connections is not None else set() - self.request_handler = request_handler - self.error_handler = error_handler - self.request_timeout = request_timeout - self.request_buffer_queue_size = request_buffer_queue_size - self.response_timeout = response_timeout - self.keep_alive_timeout = keep_alive_timeout - self.request_max_size = request_max_size - self.request_class = request_class or Request - self.is_request_stream = is_request_stream - self._is_stream_handler = False - self._not_paused = asyncio.Event(loop=loop) - self._total_request_size = 0 - self._request_timeout_handler = None - self._response_timeout_handler = None - self._keep_alive_timeout_handler = None - self._last_request_time = None - self._last_response_time = None - self._request_handler_task = None - self._request_stream_task = None - self._keep_alive = keep_alive - self._header_fragment = b"" - self.state = state if state else {} - if "requests_count" not in self.state: - self.state["requests_count"] = 0 - self._debug = debug - self._not_paused.set() - self._body_chunks = deque() - - @property - def keep_alive(self): - """ - Check if the connection needs to be kept alive based on the params - attached to the `_keep_alive` attribute, :attr:`Signal.stopped` - and :func:`HttpProtocol.parser.should_keep_alive` - - :return: ``True`` if connection is to be kept alive ``False`` else - """ - return ( - self._keep_alive - and not self.signal.stopped - and self.parser.should_keep_alive() - ) - - # -------------------------------------------- # - # Connection - # -------------------------------------------- # - - def connection_made(self, transport): - self.connections.add(self) - self._request_timeout_handler = self.loop.call_later( - self.request_timeout, 
self.request_timeout_callback - ) - self.transport = transport - self._last_request_time = time() - - def connection_lost(self, exc): - self.connections.discard(self) - if self._request_handler_task: - self._request_handler_task.cancel() - if self._request_stream_task: - self._request_stream_task.cancel() - if self._request_timeout_handler: - self._request_timeout_handler.cancel() - if self._response_timeout_handler: - self._response_timeout_handler.cancel() - if self._keep_alive_timeout_handler: - self._keep_alive_timeout_handler.cancel() - - def pause_writing(self): - self._not_paused.clear() - - def resume_writing(self): - self._not_paused.set() - - def request_timeout_callback(self): - # See the docstring in the RequestTimeout exception, to see - # exactly what this timeout is checking for. - # Check if elapsed time since request initiated exceeds our - # configured maximum request timeout value - time_elapsed = time() - self._last_request_time - if time_elapsed < self.request_timeout: - time_left = self.request_timeout - time_elapsed - self._request_timeout_handler = self.loop.call_later( - time_left, self.request_timeout_callback - ) - else: - if self._request_stream_task: - self._request_stream_task.cancel() - if self._request_handler_task: - self._request_handler_task.cancel() - self.write_error(RequestTimeout("Request Timeout")) - - def response_timeout_callback(self): - # Check if elapsed time since response was initiated exceeds our - # configured maximum request timeout value - time_elapsed = time() - self._last_request_time - if time_elapsed < self.response_timeout: - time_left = self.response_timeout - time_elapsed - self._response_timeout_handler = self.loop.call_later( - time_left, self.response_timeout_callback - ) - else: - if self._request_stream_task: - self._request_stream_task.cancel() - if self._request_handler_task: - self._request_handler_task.cancel() - self.write_error(ServiceUnavailable("Response Timeout")) - - def keep_alive_timeout_callback(self): - """ - Check if elapsed time since last response exceeds our configured - maximum keep alive timeout value and if so, close the transport - pipe and let the response writer handle the error. - - :return: None - """ - time_elapsed = time() - self._last_response_time - if time_elapsed < self.keep_alive_timeout: - time_left = self.keep_alive_timeout - time_elapsed - self._keep_alive_timeout_handler = self.loop.call_later( - time_left, self.keep_alive_timeout_callback - ) - else: - logger.debug("KeepAlive Timeout. 
Closing connection.") - self.transport.close() - self.transport = None - - # -------------------------------------------- # - # Parsing - # -------------------------------------------- # - - def data_received(self, data): - # Check for the request itself getting too large and exceeding - # memory limits - self._total_request_size += len(data) - if self._total_request_size > self.request_max_size: - self.write_error(PayloadTooLarge("Payload Too Large")) - - # Create parser if this is the first time we're receiving data - if self.parser is None: - assert self.request is None - self.headers = [] - self.parser = HttpRequestParser(self) - - # requests count - self.state["requests_count"] = self.state["requests_count"] + 1 - - # Parse request chunk or close connection - try: - self.parser.feed_data(data) - except HttpParserError: - message = "Bad Request" - if self._debug: - message += "\n" + traceback.format_exc() - self.write_error(InvalidUsage(message)) - - def on_url(self, url): - if not self.url: - self.url = url - else: - self.url += url - - def on_header(self, name, value): - self._header_fragment += name - - if value is not None: - if ( - self._header_fragment == b"Content-Length" - and int(value) > self.request_max_size - ): - self.write_error(PayloadTooLarge("Payload Too Large")) - try: - value = value.decode() - except UnicodeDecodeError: - value = value.decode("latin_1") - self.headers.append( - (self._header_fragment.decode().casefold(), value) - ) - - self._header_fragment = b"" - - def on_headers_complete(self): - self.request = self.request_class( - url_bytes=self.url, - headers=Header(self.headers), - version=self.parser.get_http_version(), - method=self.parser.get_method().decode(), - transport=self.transport, - app=self.app, - ) - # Remove any existing KeepAlive handler here, - # It will be recreated if required on the new request. - if self._keep_alive_timeout_handler: - self._keep_alive_timeout_handler.cancel() - self._keep_alive_timeout_handler = None - - if self.request.headers.get(EXPECT_HEADER): - self.expect_handler() - - if self.is_request_stream: - self._is_stream_handler = self.router.is_stream_handler( - self.request - ) - if self._is_stream_handler: - self.request.stream = StreamBuffer( - self.request_buffer_queue_size - ) - self.execute_request_handler() - - def expect_handler(self): - """ - Handler for Expect Header. - """ - expect = self.request.headers.get(EXPECT_HEADER) - if self.request.version == "1.1": - if expect.lower() == "100-continue": - self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") - else: - self.write_error( - HeaderExpectationFailed( - "Unknown Expect: {expect}".format(expect=expect) - ) - ) - - def on_body(self, body): - if self.is_request_stream and self._is_stream_handler: - # body chunks can be put into asyncio.Queue out of order if - # multiple tasks put concurrently and the queue is full in python - # 3.7. so we should not create more than one task putting into the - # queue simultaneously. 
- self._body_chunks.append(body) - if ( - not self._request_stream_task - or self._request_stream_task.done() - ): - self._request_stream_task = self.loop.create_task( - self.stream_append() - ) - else: - self.request.body_push(body) - - async def stream_append(self): - while self._body_chunks: - body = self._body_chunks.popleft() - if self.request.stream.is_full(): - self.transport.pause_reading() - await self.request.stream.put(body) - self.transport.resume_reading() - else: - await self.request.stream.put(body) - - def on_message_complete(self): - # Entire request (headers and whole body) is received. - # We can cancel and remove the request timeout handler now. - if self._request_timeout_handler: - self._request_timeout_handler.cancel() - self._request_timeout_handler = None - if self.is_request_stream and self._is_stream_handler: - self._body_chunks.append(None) - if ( - not self._request_stream_task - or self._request_stream_task.done() - ): - self._request_stream_task = self.loop.create_task( - self.stream_append() - ) - return - self.request.body_finish() - self.execute_request_handler() - - def execute_request_handler(self): - """ - Invoke the request handler defined by the - :func:`sanic.app.Sanic.handle_request` method - - :return: None - """ - self._response_timeout_handler = self.loop.call_later( - self.response_timeout, self.response_timeout_callback - ) - self._last_request_time = time() - self._request_handler_task = self.loop.create_task( - self.request_handler( - self.request, self.write_response, self.stream_response - ) - ) - - # -------------------------------------------- # - # Responding - # -------------------------------------------- # - def log_response(self, response): - """ - Helper method provided to enable the logging of responses in case if - the :attr:`HttpProtocol.access_log` is enabled. - - :param response: Response generated for the current request - - :type response: :class:`sanic.response.HTTPResponse` or - :class:`sanic.response.StreamingHTTPResponse` - - :return: None - """ - if self.access_log: - extra = {"status": getattr(response, "status", 0)} - - if isinstance(response, HTTPResponse): - extra["byte"] = len(response.body) - else: - extra["byte"] = -1 - - extra["host"] = "UNKNOWN" - if self.request is not None: - if self.request.ip: - extra["host"] = "{0}:{1}".format( - self.request.ip, self.request.port - ) - - extra["request"] = "{0} {1}".format( - self.request.method, self.request.url - ) - else: - extra["request"] = "nil" - - access_logger.info("", extra=extra) - - def write_response(self, response): - """ - Writes response content synchronously to the transport. 
- """ - if self._response_timeout_handler: - self._response_timeout_handler.cancel() - self._response_timeout_handler = None - try: - keep_alive = self.keep_alive - self.transport.write( - response.output( - self.request.version, keep_alive, self.keep_alive_timeout - ) - ) - self.log_response(response) - except AttributeError: - logger.error( - "Invalid response object for url %s, " - "Expected Type: HTTPResponse, Actual Type: %s", - self.url, - type(response), - ) - self.write_error(ServerError("Invalid response type")) - except RuntimeError: - if self._debug: - logger.error( - "Connection lost before response written @ %s", - self.request.ip, - ) - keep_alive = False - except Exception as e: - self.bail_out( - "Writing response failed, connection closed {}".format(repr(e)) - ) - finally: - if not keep_alive: - self.transport.close() - self.transport = None - else: - self._keep_alive_timeout_handler = self.loop.call_later( - self.keep_alive_timeout, self.keep_alive_timeout_callback - ) - self._last_response_time = time() - self.cleanup() - - async def drain(self): - await self._not_paused.wait() - - async def push_data(self, data): - self.transport.write(data) - - async def stream_response(self, response): - """ - Streams a response to the client asynchronously. Attaches - the transport to the response so the response consumer can - write to the response as needed. - """ - if self._response_timeout_handler: - self._response_timeout_handler.cancel() - self._response_timeout_handler = None - - try: - keep_alive = self.keep_alive - response.protocol = self - await response.stream( - self.request.version, keep_alive, self.keep_alive_timeout - ) - self.log_response(response) - except AttributeError: - logger.error( - "Invalid response object for url %s, " - "Expected Type: HTTPResponse, Actual Type: %s", - self.url, - type(response), - ) - self.write_error(ServerError("Invalid response type")) - except RuntimeError: - if self._debug: - logger.error( - "Connection lost before response written @ %s", - self.request.ip, - ) - keep_alive = False - except Exception as e: - self.bail_out( - "Writing response failed, connection closed {}".format(repr(e)) - ) - finally: - if not keep_alive: - self.transport.close() - self.transport = None - else: - self._keep_alive_timeout_handler = self.loop.call_later( - self.keep_alive_timeout, self.keep_alive_timeout_callback - ) - self._last_response_time = time() - self.cleanup() - - def write_error(self, exception): - # An error _is_ a response. - # Don't throw a response timeout, when a response _is_ given. 
- if self._response_timeout_handler: - self._response_timeout_handler.cancel() - self._response_timeout_handler = None - response = None - try: - response = self.error_handler.response(self.request, exception) - version = self.request.version if self.request else "1.1" - self.transport.write(response.output(version)) - except RuntimeError: - if self._debug: - logger.error( - "Connection lost before error written @ %s", - self.request.ip if self.request else "Unknown", - ) - except Exception as e: - self.bail_out( - "Writing error failed, connection closed {}".format(repr(e)), - from_error=True, - ) - finally: - if self.parser and ( - self.keep_alive or getattr(response, "status", 0) == 408 - ): - self.log_response(response) - try: - self.transport.close() - except AttributeError: - logger.debug("Connection lost before server could close it.") - - def bail_out(self, message, from_error=False): - """ - In case if the transport pipes are closed and the sanic app encounters - an error while writing data to the transport pipe, we log the error - with proper details. - - :param message: Error message to display - :param from_error: If the bail out was invoked while handling an - exception scenario. - - :type message: str - :type from_error: bool - - :return: None - """ - if from_error or self.transport is None or self.transport.is_closing(): - logger.error( - "Transport closed @ %s and exception " - "experienced during error handling", - ( - self.transport.get_extra_info("peername") - if self.transport is not None - else "N/A" - ), - ) - logger.debug("Exception:", exc_info=True) - else: - self.write_error(ServerError(message)) - logger.error(message) - - def cleanup(self): - """This is called when KeepAlive feature is used, - it resets the connection in order for it to be able - to handle receiving another request on the same connection.""" - self.parser = None - self.request = None - self.url = None - self.headers = None - self._request_handler_task = None - self._request_stream_task = None - self._total_request_size = 0 - self._is_stream_handler = False - - def close_if_idle(self): - """Close the connection if a request is not being sent or received - - :return: boolean - True if closed, false if staying open - """ - if not self.parser: - self.transport.close() - return True - return False - - def close(self): - """ - Force close the connection. - """ - if self.transport is not None: - self.transport.close() - self.transport = None - - -def trigger_events(events, loop): - """Trigger event callbacks (functions or async) - - :param events: one or more sync or async functions to execute - :param loop: event loop - """ - for event in events: - result = event(loop) - if isawaitable(result): - loop.run_until_complete(result) - - -class AsyncioServer: - """ - Wraps an asyncio server with functionality that might be useful to - a user who needs to manage the server lifecycle manually. - """ - - __slots__ = ( - "loop", - "serve_coro", - "_after_start", - "_before_stop", - "_after_stop", - "server", - "connections", - ) - - def __init__( - self, - loop, - serve_coro, - connections, - after_start, - before_stop, - after_stop, - ): - # Note, Sanic already called "before_server_start" events - # before this helper was even created. So we don't need it here. 
- self.loop = loop - self.serve_coro = serve_coro - self._after_start = after_start - self._before_stop = before_stop - self._after_stop = after_stop - self.server = None - self.connections = connections - - def after_start(self): - """Trigger "after_server_start" events""" - trigger_events(self._after_start, self.loop) - - def before_stop(self): - """Trigger "before_server_stop" events""" - trigger_events(self._before_stop, self.loop) - - def after_stop(self): - """Trigger "after_server_stop" events""" - trigger_events(self._after_stop, self.loop) - - def is_serving(self): - if self.server: - return self.server.is_serving() - return False - - def wait_closed(self): - if self.server: - return self.server.wait_closed() - - def close(self): - if self.server: - self.server.close() - coro = self.wait_closed() - task = asyncio.ensure_future(coro, loop=self.loop) - return task - - def __await__(self): - """Starts the asyncio server, returns AsyncServerCoro""" - task = asyncio.ensure_future(self.serve_coro) - while not task.done(): - yield - self.server = task.result() - return self - - -def serve( - host, - port, - app, - request_handler, - error_handler, - before_start=None, - after_start=None, - before_stop=None, - after_stop=None, - debug=False, - request_timeout=60, - response_timeout=60, - keep_alive_timeout=5, - ssl=None, - sock=None, - request_max_size=None, - request_buffer_queue_size=100, - reuse_port=False, - loop=None, - protocol=HttpProtocol, - backlog=100, - register_sys_signals=True, - run_multiple=False, - run_async=False, - connections=None, - signal=Signal(), - request_class=None, - access_log=True, - keep_alive=True, - is_request_stream=False, - router=None, - websocket_max_size=None, - websocket_max_queue=None, - websocket_read_limit=2 ** 16, - websocket_write_limit=2 ** 16, - state=None, - graceful_shutdown_timeout=15.0, - asyncio_server_kwargs=None, -): - """Start asynchronous HTTP Server on an individual process. - - :param host: Address to host on - :param port: Port to host on - :param request_handler: Sanic request handler with middleware - :param error_handler: Sanic error handler with middleware - :param before_start: function to be executed before the server starts - listening. Takes arguments `app` instance and `loop` - :param after_start: function to be executed after the server starts - listening. Takes arguments `app` instance and `loop` - :param before_stop: function to be executed when a stop signal is - received before it is respected. Takes arguments - `app` instance and `loop` - :param after_stop: function to be executed when a stop signal is - received after it is respected. Takes arguments - `app` instance and `loop` - :param debug: enables debug output (slows server) - :param request_timeout: time in seconds - :param response_timeout: time in seconds - :param keep_alive_timeout: time in seconds - :param ssl: SSLContext - :param sock: Socket for the server to accept connections from - :param request_max_size: size in bytes, `None` for no limit - :param reuse_port: `True` for multiple workers - :param loop: asyncio compatible event loop - :param protocol: subclass of asyncio protocol class - :param run_async: bool: Do not create a new event loop for the server, - and return an AsyncServer object rather than running it - :param request_class: Request class to use - :param access_log: disable/enable access log - :param websocket_max_size: enforces the maximum size for - incoming messages in bytes. 
- :param websocket_max_queue: sets the maximum length of the queue - that holds incoming messages. - :param websocket_read_limit: sets the high-water limit of the buffer for - incoming bytes, the low-water limit is half - the high-water limit. - :param websocket_write_limit: sets the high-water limit of the buffer for - outgoing bytes, the low-water limit is a - quarter of the high-water limit. - :param is_request_stream: disable/enable Request.stream - :param request_buffer_queue_size: streaming request buffer queue size - :param router: Router object - :param graceful_shutdown_timeout: How long take to Force close non-idle - connection - :param asyncio_server_kwargs: key-value args for asyncio/uvloop - create_server method - :return: Nothing - """ - if not run_async: - # create new event_loop after fork - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - if debug: - loop.set_debug(debug) - - app.asgi = False - - connections = connections if connections is not None else set() - server = partial( - protocol, - loop=loop, - connections=connections, - signal=signal, - app=app, - request_handler=request_handler, - error_handler=error_handler, - request_timeout=request_timeout, - response_timeout=response_timeout, - keep_alive_timeout=keep_alive_timeout, - request_max_size=request_max_size, - request_buffer_queue_size=request_buffer_queue_size, - request_class=request_class, - access_log=access_log, - keep_alive=keep_alive, - is_request_stream=is_request_stream, - router=router, - websocket_max_size=websocket_max_size, - websocket_max_queue=websocket_max_queue, - websocket_read_limit=websocket_read_limit, - websocket_write_limit=websocket_write_limit, - state=state, - debug=debug, - ) - asyncio_server_kwargs = ( - asyncio_server_kwargs if asyncio_server_kwargs else {} - ) - server_coroutine = loop.create_server( - server, - host, - port, - ssl=ssl, - reuse_port=reuse_port, - sock=sock, - backlog=backlog, - **asyncio_server_kwargs - ) - - if run_async: - return AsyncioServer( - loop, - server_coroutine, - connections, - after_start, - before_stop, - after_stop, - ) - - trigger_events(before_start, loop) - - try: - http_server = loop.run_until_complete(server_coroutine) - except BaseException: - logger.exception("Unable to start server") - return - - trigger_events(after_start, loop) - - # Ignore SIGINT when run_multiple - if run_multiple: - signal_func(SIGINT, SIG_IGN) - - # Register signals for graceful termination - if register_sys_signals: - _singals = (SIGTERM,) if run_multiple else (SIGINT, SIGTERM) - for _signal in _singals: - try: - loop.add_signal_handler(_signal, loop.stop) - except NotImplementedError: - logger.warning( - "Sanic tried to use loop.add_signal_handler " - "but it is not implemented on this platform." - ) - pid = os.getpid() - try: - logger.info("Starting worker [%s]", pid) - loop.run_forever() - finally: - logger.info("Stopping worker [%s]", pid) - - # Run the on_stop function if provided - trigger_events(before_stop, loop) - - # Wait for event loop to finish and all connections to drain - http_server.close() - loop.run_until_complete(http_server.wait_closed()) - - # Complete all tasks on the loop - signal.stopped = True - for connection in connections: - connection.close_if_idle() - - # Gracefully shutdown timeout. - # We should provide graceful_shutdown_timeout, - # instead of letting connection hangs forever. - # Let's roughly calcucate time. 
- start_shutdown = 0 - while connections and (start_shutdown < graceful_shutdown_timeout): - loop.run_until_complete(asyncio.sleep(0.1)) - start_shutdown = start_shutdown + 0.1 - - # Force close non-idle connection after waiting for - # graceful_shutdown_timeout - coros = [] - for conn in connections: - if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close_connection()) - else: - conn.close() - - _shutdown = asyncio.gather(*coros, loop=loop) - loop.run_until_complete(_shutdown) - - trigger_events(after_stop, loop) - - loop.close() - - -def serve_multiple(server_settings, workers): - """Start multiple server processes simultaneously. Stop on interrupt - and terminate signals, and drain connections when complete. - - :param server_settings: kw arguments to be passed to the serve function - :param workers: number of workers to launch - :param stop_event: if provided, is used as a stop signal - :return: - """ - server_settings["reuse_port"] = True - server_settings["run_multiple"] = True - - # Handling when custom socket is not provided. - if server_settings.get("sock") is None: - sock = socket() - sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) - sock.bind((server_settings["host"], server_settings["port"])) - sock.set_inheritable(True) - server_settings["sock"] = sock - server_settings["host"] = None - server_settings["port"] = None - - processes = [] - - def sig_handler(signal, frame): - logger.info("Received signal %s. Shutting down.", Signals(signal).name) - for process in processes: - os.kill(process.pid, SIGTERM) - - signal_func(SIGINT, lambda s, f: sig_handler(s, f)) - signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) - - for _ in range(workers): - process = Process(target=serve, kwargs=server_settings) - process.daemon = True - process.start() - processes.append(process) - - for process in processes: - process.join() - - # the above processes will block this until they're stopped - for process in processes: - process.terminate() - server_settings.get("sock").close() +from sanic.server_trio import * diff --git a/sanic/server_asyncio.py b/sanic/server_asyncio.py new file mode 100644 index 0000000000..e1d7a3d22f --- /dev/null +++ b/sanic/server_asyncio.py @@ -0,0 +1,889 @@ +import asyncio +import os +import traceback + +from functools import partial +from inspect import isawaitable +from multiprocessing import Process +from signal import SIG_IGN, SIGINT, SIGTERM, Signals +from signal import signal as signal_func +from socket import SO_REUSEADDR, SOL_SOCKET, socket +from time import time + +from httptools import HttpRequestParser +from httptools.parser.errors import HttpParserError + +from sanic.compat import Header +from sanic.exceptions import ( + HeaderExpectationFailed, + InvalidUsage, + PayloadTooLarge, + RequestTimeout, + ServerError, + ServiceUnavailable, +) +from sanic.log import access_logger, logger +from sanic.request import EXPECT_HEADER, Request, StreamBuffer +from sanic.response import HTTPResponse + + +try: + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except ImportError: + pass + + +class Signal: + stopped = False + + +class HttpProtocol(asyncio.Protocol): + """ + This class provides a basic HTTP implementation of the sanic framework. 
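+
+    (Carried over from sanic/server.py, which now re-exports the trio
+    implementation via ``from sanic.server_trio import *``.)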
+ """ + + __slots__ = ( + # app + "app", + # event loop, connection + "loop", + "transport", + "connections", + "signal", + # request params + "parser", + "request", + "url", + "headers", + # request config + "request_handler", + "request_timeout", + "response_timeout", + "keep_alive_timeout", + "request_max_size", + "request_buffer_queue_size", + "request_class", + "is_request_stream", + "router", + "error_handler", + # enable or disable access log purpose + "access_log", + # connection management + "_total_request_size", + "_request_timeout_handler", + "_response_timeout_handler", + "_keep_alive_timeout_handler", + "_last_request_time", + "_last_response_time", + "_is_stream_handler", + "_not_paused", + "_request_handler_task", + "_request_stream_task", + "_keep_alive", + "_header_fragment", + "state", + "_debug", + ) + + def __init__( + self, + *, + loop, + app, + request_handler, + error_handler, + signal=Signal(), + connections=None, + request_timeout=60, + response_timeout=60, + keep_alive_timeout=5, + request_max_size=None, + request_buffer_queue_size=100, + request_class=None, + access_log=True, + keep_alive=True, + is_request_stream=False, + router=None, + state=None, + debug=False, + **kwargs + ): + self.loop = loop + self.app = app + self.transport = None + self.request = None + self.parser = None + self.url = None + self.headers = None + self.router = router + self.signal = signal + self.access_log = access_log + self.connections = connections if connections is not None else set() + self.request_handler = request_handler + self.error_handler = error_handler + self.request_timeout = request_timeout + self.request_buffer_queue_size = request_buffer_queue_size + self.response_timeout = response_timeout + self.keep_alive_timeout = keep_alive_timeout + self.request_max_size = request_max_size + self.request_class = request_class or Request + self.is_request_stream = is_request_stream + self._is_stream_handler = False + self._not_paused = asyncio.Event(loop=loop) + self._total_request_size = 0 + self._request_timeout_handler = None + self._response_timeout_handler = None + self._keep_alive_timeout_handler = None + self._last_request_time = None + self._last_response_time = None + self._request_handler_task = None + self._request_stream_task = None + self._keep_alive = keep_alive + self._header_fragment = b"" + self.state = state if state else {} + if "requests_count" not in self.state: + self.state["requests_count"] = 0 + self._debug = debug + self._not_paused.set() + + @property + def keep_alive(self): + """ + Check if the connection needs to be kept alive based on the params + attached to the `_keep_alive` attribute, :attr:`Signal.stopped` + and :func:`HttpProtocol.parser.should_keep_alive` + + :return: ``True`` if connection is to be kept alive ``False`` else + """ + return ( + self._keep_alive + and not self.signal.stopped + and self.parser.should_keep_alive() + ) + + # -------------------------------------------- # + # Connection + # -------------------------------------------- # + + def connection_made(self, transport): + self.connections.add(self) + self._request_timeout_handler = self.loop.call_later( + self.request_timeout, self.request_timeout_callback + ) + self.transport = transport + self._last_request_time = time() + + def connection_lost(self, exc): + self.connections.discard(self) + if self._request_handler_task: + self._request_handler_task.cancel() + if self._request_stream_task: + self._request_stream_task.cancel() + if self._request_timeout_handler: + 
self._request_timeout_handler.cancel() + if self._response_timeout_handler: + self._response_timeout_handler.cancel() + if self._keep_alive_timeout_handler: + self._keep_alive_timeout_handler.cancel() + + def pause_writing(self): + self._not_paused.clear() + + def resume_writing(self): + self._not_paused.set() + + def request_timeout_callback(self): + # See the docstring in the RequestTimeout exception, to see + # exactly what this timeout is checking for. + # Check if elapsed time since request initiated exceeds our + # configured maximum request timeout value + time_elapsed = time() - self._last_request_time + if time_elapsed < self.request_timeout: + time_left = self.request_timeout - time_elapsed + self._request_timeout_handler = self.loop.call_later( + time_left, self.request_timeout_callback + ) + else: + if self._request_stream_task: + self._request_stream_task.cancel() + if self._request_handler_task: + self._request_handler_task.cancel() + self.write_error(RequestTimeout("Request Timeout")) + + def response_timeout_callback(self): + # Check if elapsed time since response was initiated exceeds our + # configured maximum request timeout value + time_elapsed = time() - self._last_request_time + if time_elapsed < self.response_timeout: + time_left = self.response_timeout - time_elapsed + self._response_timeout_handler = self.loop.call_later( + time_left, self.response_timeout_callback + ) + else: + if self._request_stream_task: + self._request_stream_task.cancel() + if self._request_handler_task: + self._request_handler_task.cancel() + self.write_error(ServiceUnavailable("Response Timeout")) + + def keep_alive_timeout_callback(self): + """ + Check if elapsed time since last response exceeds our configured + maximum keep alive timeout value and if so, close the transport + pipe and let the response writer handle the error. + + :return: None + """ + time_elapsed = time() - self._last_response_time + if time_elapsed < self.keep_alive_timeout: + time_left = self.keep_alive_timeout - time_elapsed + self._keep_alive_timeout_handler = self.loop.call_later( + time_left, self.keep_alive_timeout_callback + ) + else: + logger.debug("KeepAlive Timeout. 
Closing connection.") + self.transport.close() + self.transport = None + + # -------------------------------------------- # + # Parsing + # -------------------------------------------- # + + def data_received(self, data): + # Check for the request itself getting too large and exceeding + # memory limits + self._total_request_size += len(data) + if self._total_request_size > self.request_max_size: + self.write_error(PayloadTooLarge("Payload Too Large")) + + # Create parser if this is the first time we're receiving data + if self.parser is None: + assert self.request is None + self.headers = [] + self.parser = HttpRequestParser(self) + + # requests count + self.state["requests_count"] = self.state["requests_count"] + 1 + + # Parse request chunk or close connection + try: + self.parser.feed_data(data) + except HttpParserError: + message = "Bad Request" + if self._debug: + message += "\n" + traceback.format_exc() + self.write_error(InvalidUsage(message)) + + def on_url(self, url): + if not self.url: + self.url = url + else: + self.url += url + + def on_header(self, name, value): + self._header_fragment += name + + if value is not None: + if ( + self._header_fragment == b"Content-Length" + and int(value) > self.request_max_size + ): + self.write_error(PayloadTooLarge("Payload Too Large")) + try: + value = value.decode() + except UnicodeDecodeError: + value = value.decode("latin_1") + self.headers.append( + (self._header_fragment.decode().casefold(), value) + ) + + self._header_fragment = b"" + + def on_headers_complete(self): + self.request = self.request_class( + url_bytes=self.url, + headers=Header(self.headers), + version=self.parser.get_http_version(), + method=self.parser.get_method().decode(), + transport=self.transport, + app=self.app, + ) + # Remove any existing KeepAlive handler here, + # It will be recreated if required on the new request. + if self._keep_alive_timeout_handler: + self._keep_alive_timeout_handler.cancel() + self._keep_alive_timeout_handler = None + + if self.request.headers.get(EXPECT_HEADER): + self.expect_handler() + + if self.is_request_stream: + self._is_stream_handler = self.router.is_stream_handler( + self.request + ) + if self._is_stream_handler: + self.request.stream = StreamBuffer( + self.request_buffer_queue_size + ) + self.execute_request_handler() + + def expect_handler(self): + """ + Handler for Expect Header. + """ + expect = self.request.headers.get(EXPECT_HEADER) + if self.request.version == "1.1": + if expect.lower() == "100-continue": + self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") + else: + self.write_error( + HeaderExpectationFailed( + "Unknown Expect: {expect}".format(expect=expect) + ) + ) + + def on_body(self, body): + if self.is_request_stream and self._is_stream_handler: + self._request_stream_task = self.loop.create_task( + self.body_append(body) + ) + else: + self.request.body_push(body) + + async def body_append(self, body): + if self.request.stream.is_full(): + self.transport.pause_reading() + await self.request.stream.put(body) + self.transport.resume_reading() + else: + await self.request.stream.put(body) + + def on_message_complete(self): + # Entire request (headers and whole body) is received. + # We can cancel and remove the request timeout handler now. 
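+        # (execute_request_handler below arms the response timeout in its
+        # place once the handler task is scheduled.)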
+
+    def on_message_complete(self):
+        # The entire request (headers and whole body) has been received,
+        # so the request timeout handler can be cancelled and removed now.
+        if self._request_timeout_handler:
+            self._request_timeout_handler.cancel()
+            self._request_timeout_handler = None
+        if self.is_request_stream and self._is_stream_handler:
+            self._request_stream_task = self.loop.create_task(
+                self.request.stream.put(None)
+            )
+            return
+        self.request.body_finish()
+        self.execute_request_handler()
+
+    def execute_request_handler(self):
+        """
+        Invoke the request handler defined by the
+        :func:`sanic.app.Sanic.handle_request` method
+
+        :return: None
+        """
+        self._response_timeout_handler = self.loop.call_later(
+            self.response_timeout, self.response_timeout_callback
+        )
+        self._last_request_time = time()
+        self._request_handler_task = self.loop.create_task(
+            self.request_handler(
+                self.request, self.write_response, self.stream_response
+            )
+        )
+
+    # -------------------------------------------- #
+    # Responding
+    # -------------------------------------------- #
+    def log_response(self, response):
+        """
+        Helper method that logs the response when
+        :attr:`HttpProtocol.access_log` is enabled.
+
+        :param response: Response generated for the current request
+
+        :type response: :class:`sanic.response.HTTPResponse` or
+            :class:`sanic.response.StreamingHTTPResponse`
+
+        :return: None
+        """
+        if self.access_log:
+            extra = {"status": getattr(response, "status", 0)}
+
+            if isinstance(response, HTTPResponse):
+                extra["byte"] = len(response.body)
+            else:
+                extra["byte"] = -1
+
+            extra["host"] = "UNKNOWN"
+            if self.request is not None:
+                if self.request.ip:
+                    extra["host"] = "{0}:{1}".format(
+                        self.request.ip, self.request.port
+                    )
+
+                extra["request"] = "{0} {1}".format(
+                    self.request.method, self.request.url
+                )
+            else:
+                extra["request"] = "nil"
+
+            access_logger.info("", extra=extra)
+
+    def write_response(self, response):
+        """
+        Writes response content synchronously to the transport.
+        """
+        if self._response_timeout_handler:
+            self._response_timeout_handler.cancel()
+            self._response_timeout_handler = None
+        try:
+            keep_alive = self.keep_alive
+            self.transport.write(
+                response.output(
+                    self.request.version, keep_alive, self.keep_alive_timeout
+                )
+            )
+            self.log_response(response)
+        except AttributeError:
+            logger.error(
+                "Invalid response object for url %s, "
+                "Expected Type: HTTPResponse, Actual Type: %s",
+                self.url,
+                type(response),
+            )
+            self.write_error(ServerError("Invalid response type"))
+        except RuntimeError:
+            if self._debug:
+                logger.error(
+                    "Connection lost before response written @ %s",
+                    self.request.ip,
+                )
+            keep_alive = False
+        except Exception as e:
+            self.bail_out(
+                "Writing response failed, connection closed {}".format(repr(e))
+            )
+        finally:
+            if not keep_alive:
+                self.transport.close()
+                self.transport = None
+            else:
+                self._keep_alive_timeout_handler = self.loop.call_later(
+                    self.keep_alive_timeout, self.keep_alive_timeout_callback
+                )
+                self._last_response_time = time()
+                self.cleanup()
+
+    async def drain(self):
+        await self._not_paused.wait()
+
+    async def push_data(self, data):
+        self.transport.write(data)
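Editor's note: `log_response` above logs an empty message and passes all the interesting fields via `extra`, leaving the layout to the logging formatter. A standalone sketch of a stdlib handler consuming those fields; the format string is illustrative, not Sanic's LOGGING_CONFIG_DEFAULTS:

import logging

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(
    "%(asctime)s - (%(name)s): %(host)s %(request)s %(status)d %(byte)d"
))
access = logging.getLogger("demo.access")
access.addHandler(handler)
access.setLevel(logging.INFO)
# The extra dict must supply every custom field the formatter references.
access.info("", extra={
    "host": "127.0.0.1:51234",
    "request": "GET /",
    "status": 200,
    "byte": 13,
})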
+
+    async def stream_response(self, response):
+        """
+        Streams a response to the client asynchronously. Attaches the
+        transport to the response so the response consumer can write to
+        the response as needed.
+        """
+        if self._response_timeout_handler:
+            self._response_timeout_handler.cancel()
+            self._response_timeout_handler = None
+
+        try:
+            keep_alive = self.keep_alive
+            response.protocol = self
+            await response.stream(
+                self.request.version, keep_alive, self.keep_alive_timeout
+            )
+            self.log_response(response)
+        except AttributeError:
+            logger.error(
+                "Invalid response object for url %s, "
+                "Expected Type: HTTPResponse, Actual Type: %s",
+                self.url,
+                type(response),
+            )
+            self.write_error(ServerError("Invalid response type"))
+        except RuntimeError:
+            if self._debug:
+                logger.error(
+                    "Connection lost before response written @ %s",
+                    self.request.ip,
+                )
+            keep_alive = False
+        except Exception as e:
+            self.bail_out(
+                "Writing response failed, connection closed {}".format(repr(e))
+            )
+        finally:
+            if not keep_alive:
+                self.transport.close()
+                self.transport = None
+            else:
+                self._keep_alive_timeout_handler = self.loop.call_later(
+                    self.keep_alive_timeout, self.keep_alive_timeout_callback
+                )
+                self._last_response_time = time()
+                self.cleanup()
+
+    def write_error(self, exception):
+        # An error _is_ a response.
+        # Don't throw a response timeout when a response _is_ given.
+        if self._response_timeout_handler:
+            self._response_timeout_handler.cancel()
+            self._response_timeout_handler = None
+        response = None
+        try:
+            response = self.error_handler.response(self.request, exception)
+            version = self.request.version if self.request else "1.1"
+            self.transport.write(response.output(version))
+        except RuntimeError:
+            if self._debug:
+                logger.error(
+                    "Connection lost before error written @ %s",
+                    self.request.ip if self.request else "Unknown",
+                )
+        except Exception as e:
+            self.bail_out(
+                "Writing error failed, connection closed {}".format(repr(e)),
+                from_error=True,
+            )
+        finally:
+            if self.parser and (
+                self.keep_alive or getattr(response, "status", 0) == 408
+            ):
+                self.log_response(response)
+            try:
+                self.transport.close()
+            except AttributeError:
+                logger.debug("Connection lost before server could close it.")
+
+    def bail_out(self, message, from_error=False):
+        """
+        Log the error with full details when the transport pipes are
+        already closed and the sanic app encounters an error while
+        writing data to them.
+
+        :param message: Error message to display
+        :param from_error: If the bail out was invoked while handling an
+            exception scenario.
+
+        :type message: str
+        :type from_error: bool
+
+        :return: None
+        """
+        if from_error or self.transport is None or self.transport.is_closing():
+            logger.error(
+                "Transport closed @ %s and exception "
+                "experienced during error handling",
+                (
+                    self.transport.get_extra_info("peername")
+                    if self.transport is not None
+                    else "N/A"
+                ),
+            )
+            logger.debug("Exception:", exc_info=True)
+        else:
+            self.write_error(ServerError(message))
+            logger.error(message)
+
+    def cleanup(self):
+        """Reset the connection so that it can receive another request on
+        the same connection; called when the keep-alive feature is in
+        use."""
+        self.parser = None
+        self.request = None
+        self.url = None
+        self.headers = None
+        self._request_handler_task = None
+        self._request_stream_task = None
+        self._total_request_size = 0
+        self._is_stream_handler = False
+
+    def close_if_idle(self):
+        """Close the connection if a request is not being sent or received
+
+        :return: boolean - True if closed, false if staying open
+        """
+        if not self.parser:
+            self.transport.close()
+            return True
+        return False
+
+    def close(self):
+        """
+        Force close the connection.
+        """
+        if self.transport is not None:
+            self.transport.close()
+            self.transport = None
+
+
+def trigger_events(events, loop):
+    """Trigger event callbacks (functions or async)
+
+    :param events: one or more sync or async functions to execute
+    :param loop: event loop
+    """
+    for event in events:
+        result = event(loop)
+        if isawaitable(result):
+            loop.run_until_complete(result)
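Editor's note: a usage sketch for the dispatch logic of `trigger_events` above, showing that sync and async listeners mix freely because only awaitable results are driven on the loop. The local copy and listener names are illustrative:

import asyncio
from inspect import isawaitable


def trigger_events(events, loop):
    # Same logic as above: run sync listeners inline, await async ones.
    for event in events:
        result = event(loop)
        if isawaitable(result):
            loop.run_until_complete(result)


def sync_listener(loop):
    print("sync listener called")


async def async_listener(loop):
    print("async listener awaited")


loop = asyncio.new_event_loop()
trigger_events([sync_listener, async_listener], loop)
loop.close()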
+
+
+def serve(
+    host,
+    port,
+    app,
+    request_handler,
+    error_handler,
+    before_start=None,
+    after_start=None,
+    before_stop=None,
+    after_stop=None,
+    debug=False,
+    request_timeout=60,
+    response_timeout=60,
+    keep_alive_timeout=5,
+    ssl=None,
+    sock=None,
+    request_max_size=None,
+    request_buffer_queue_size=100,
+    reuse_port=False,
+    loop=None,
+    protocol=HttpProtocol,
+    backlog=100,
+    register_sys_signals=True,
+    run_multiple=False,
+    run_async=False,
+    connections=None,
+    signal=Signal(),
+    request_class=None,
+    access_log=True,
+    keep_alive=True,
+    is_request_stream=False,
+    router=None,
+    websocket_max_size=None,
+    websocket_max_queue=None,
+    websocket_read_limit=2 ** 16,
+    websocket_write_limit=2 ** 16,
+    state=None,
+    graceful_shutdown_timeout=15.0,
+    asyncio_server_kwargs=None,
+):
+    """Start an asynchronous HTTP server on an individual process.
+
+    :param host: Address to host on
+    :param port: Port to host on
+    :param request_handler: Sanic request handler with middleware
+    :param error_handler: Sanic error handler with middleware
+    :param before_start: function to be executed before the server starts
+        listening. Takes arguments `app` instance and `loop`
+    :param after_start: function to be executed after the server starts
+        listening. Takes arguments `app` instance and `loop`
+    :param before_stop: function to be executed when a stop signal is
+        received, before it is respected. Takes arguments `app` instance
+        and `loop`
+    :param after_stop: function to be executed when a stop signal is
+        received, after it is respected. Takes arguments `app` instance
+        and `loop`
+    :param debug: enables debug output (slows server)
+    :param request_timeout: time in seconds
+    :param response_timeout: time in seconds
+    :param keep_alive_timeout: time in seconds
+    :param ssl: SSLContext
+    :param sock: Socket for the server to accept connections from
+    :param request_max_size: size in bytes, `None` for no limit
+    :param reuse_port: `True` for multiple workers
+    :param loop: asyncio compatible event loop
+    :param protocol: subclass of asyncio protocol class
+    :param request_class: Request class to use
+    :param access_log: disable/enable access log
+    :param websocket_max_size: enforces the maximum size for
+        incoming messages in bytes.
+    :param websocket_max_queue: sets the maximum length of the queue
+        that holds incoming messages.
+    :param websocket_read_limit: sets the high-water limit of the buffer
+        for incoming bytes; the low-water limit is half the high-water
+        limit.
+    :param websocket_write_limit: sets the high-water limit of the buffer
+        for outgoing bytes; the low-water limit is a quarter of the
+        high-water limit.
+    :param is_request_stream: disable/enable Request.stream
+    :param request_buffer_queue_size: streaming request buffer queue size
+    :param router: Router object
+    :param graceful_shutdown_timeout: how long to wait before force-closing
+        non-idle connections
+    :param asyncio_server_kwargs: key-value args for asyncio/uvloop
+        create_server method
+    :return: Nothing
+    """
+    if not run_async:
+        # create new event_loop after fork
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    if debug:
+        loop.set_debug(debug)
+
+    app.asgi = False
+
+    connections = connections if connections is not None else set()
+    server = partial(
+        protocol,
+        loop=loop,
+        connections=connections,
+        signal=signal,
+        app=app,
+        request_handler=request_handler,
+        error_handler=error_handler,
+        request_timeout=request_timeout,
+        response_timeout=response_timeout,
+        keep_alive_timeout=keep_alive_timeout,
+        request_max_size=request_max_size,
+        request_class=request_class,
+        access_log=access_log,
+        keep_alive=keep_alive,
+        is_request_stream=is_request_stream,
+        router=router,
+        websocket_max_size=websocket_max_size,
+        websocket_max_queue=websocket_max_queue,
+        websocket_read_limit=websocket_read_limit,
+        websocket_write_limit=websocket_write_limit,
+        state=state,
+        debug=debug,
+    )
+    asyncio_server_kwargs = (
+        asyncio_server_kwargs if asyncio_server_kwargs else {}
+    )
+    server_coroutine = loop.create_server(
+        server,
+        host,
+        port,
+        ssl=ssl,
+        reuse_port=reuse_port,
+        sock=sock,
+        backlog=backlog,
+        **asyncio_server_kwargs
+    )
+
+    if run_async:
+        return server_coroutine
+
+    trigger_events(before_start, loop)
+
+    try:
+        http_server = loop.run_until_complete(server_coroutine)
+    except BaseException:
+        logger.exception("Unable to start server")
+        return
+
+    trigger_events(after_start, loop)
+
+    # Ignore SIGINT when run_multiple
+    if run_multiple:
+        signal_func(SIGINT, SIG_IGN)
+
+    # Register signals for graceful termination
+    if register_sys_signals:
+        _signals = (SIGTERM,) if run_multiple else (SIGINT, SIGTERM)
+        for _signal in _signals:
+            try:
+                loop.add_signal_handler(_signal, loop.stop)
+            except NotImplementedError:
+                logger.warning(
+                    "Sanic tried to use loop.add_signal_handler "
+                    "but it is not implemented on this platform."
+                )
+    pid = os.getpid()
+    try:
+        logger.info("Starting worker [%s]", pid)
+        loop.run_forever()
+    finally:
+        logger.info("Stopping worker [%s]", pid)
+
+        # Run the on_stop function if provided
+        trigger_events(before_stop, loop)
+
+        # Wait for event loop to finish and all connections to drain
+        http_server.close()
+        loop.run_until_complete(http_server.wait_closed())
+
+        # Complete all tasks on the loop
+        signal.stopped = True
+        for connection in connections:
+            connection.close_if_idle()
+
+        # Give connections up to graceful_shutdown_timeout to finish,
+        # instead of letting them hang forever; roughly track the time
+        # spent waiting.
+        start_shutdown = 0
+        while connections and (start_shutdown < graceful_shutdown_timeout):
+            loop.run_until_complete(asyncio.sleep(0.1))
+            start_shutdown = start_shutdown + 0.1
+
+        # Force close non-idle connections after waiting for
+        # graceful_shutdown_timeout
+        coros = []
+        for conn in connections:
+            if hasattr(conn, "websocket") and conn.websocket:
+                coros.append(conn.websocket.close_connection())
+            else:
+                conn.close()
+
+        _shutdown = asyncio.gather(*coros, loop=loop)
+        loop.run_until_complete(_shutdown)
+
+        trigger_events(after_stop, loop)
+
+        loop.close()
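Editor's note: `serve_multiple` below binds one socket in the parent, marks it inheritable, and lets every worker process accept on the same file descriptor. A standalone sketch of that pre-bound socket pattern; port 0 is used here only so the snippet runs anywhere:

from socket import SO_REUSEADDR, SOL_SOCKET, socket

sock = socket()
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
sock.bind(("127.0.0.1", 0))  # 0 picks a free ephemeral port
sock.set_inheritable(True)  # child processes may accept on this fd
print("bound", sock.getsockname(), "inheritable:", sock.get_inheritable())
sock.close()

Binding before forking avoids the race where two workers contend to bind the same port, and it is why the host/port entries are cleared from server_settings once the socket is installed.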
+
+
+def serve_multiple(server_settings, workers):
+    """Start multiple server processes simultaneously. Stop on interrupt
+    and terminate signals, and drain connections when complete.
+
+    :param server_settings: kw arguments to be passed to the serve function
+    :param workers: number of workers to launch
+    :return:
+    """
+    server_settings["reuse_port"] = True
+    server_settings["run_multiple"] = True
+
+    # Handling when custom socket is not provided.
+    if server_settings.get("sock") is None:
+        sock = socket()
+        sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+        sock.bind((server_settings["host"], server_settings["port"]))
+        sock.set_inheritable(True)
+        server_settings["sock"] = sock
+        server_settings["host"] = None
+        server_settings["port"] = None
+
+    processes = []
+
+    def sig_handler(signal, frame):
+        logger.info(
+            "Received signal %s. Shutting down.", Signals(signal).name
+        )
+        for process in processes:
+            os.kill(process.pid, SIGTERM)
+
+    signal_func(SIGINT, lambda s, f: sig_handler(s, f))
+    signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
+
+    for _ in range(workers):
+        process = Process(target=serve, kwargs=server_settings)
+        process.daemon = True
+        process.start()
+        processes.append(process)
+
+    for process in processes:
+        process.join()
+
+    # the above processes will block this until they're stopped
+    for process in processes:
+        process.terminate()
+    server_settings.get("sock").close()
diff --git a/sanic/server_trio.py b/sanic/server_trio.py
new file mode 100644
index 0000000000..8902dcc45e
--- /dev/null
+++ b/sanic/server_trio.py
@@ -0,0 +1,541 @@
+import multiprocessing as mp
+import os
+import socket
+import stat
+import sys
+import traceback
+
+from enum import Enum
+from functools import partial
+from inspect import isawaitable
+from ipaddress import ip_address
+from signal import SIG_IGN, SIGINT, SIGTERM, Signals
+from signal import signal as signal_func
+from time import sleep as time_sleep
+from time import time
+
+import trio
+
+from h2.config import H2Configuration
+from h2.connection import H2Connection
+from h2.events import ConnectionTerminated, DataReceived, RequestReceived
+from httptools.parser.errors import HttpParserError
+
+from sanic.compat import Header
+from sanic.exceptions import (
+    HeaderExpectationFailed,
+    InvalidUsage,
+    PayloadTooLarge,
+    RequestTimeout,
+    SanicException,
+    ServerError,
+    ServiceUnavailable,
+    VersionNotSupported,
+)
+from sanic.log import access_logger, logger
+from sanic.protocol import H1Stream
+from sanic.request import EXPECT_HEADER, Request, StreamBuffer
+from sanic.response import HTTPResponse
+
+
+try:
+    from signal import SIGHUP
+except ImportError:
+    # Windows has no SIGHUP; fall back to SIGTERM so the signal sets
+    # used below stay valid.
+    SIGHUP = SIGTERM
+
+
+class Signal:
+    stopped = False
+
+
+SSL, H2 = Enum("Protocol", "SSL H2")
+
+h2config = H2Configuration(
+    client_side=False,
+    header_encoding="utf-8",
+    validate_outbound_headers=False,
+    normalize_outbound_headers=False,
+    validate_inbound_headers=False,
+    normalize_inbound_headers=False,
+    logger=None,
+)
+
+
+idle_connections = set()
+quit = trio.Event()
+
+
+def trigger_graceful_exit():
+    """Signal all running connections to terminate smoothly."""
+    # Disallow new requests
+    quit.set()
+    # Promptly terminate idle connections
+    for c in idle_connections:
+        c.cancel()
+
+
+def parse_h1_request(data: bytes) -> dict:
+    try:
+        data = data.decode()
+    except UnicodeDecodeError:
+        data = data.decode("ISO-8859-1")
+    req, *hlines = data.split("\r\n")
+    method, path, version = req.split(" ", 2)
+    if version != "HTTP/1.1":
+        raise VersionNotSupported(f"Expected 'HTTP/1.1', got '{version}'")
+    headers = {":method": method, ":path": path}
+    for name, val in (h.split(": ", 1) for h in hlines):
+        name = name.lower()
+        headers[name] = f"{headers[name]}, {val}" if name in headers else val
+    return headers
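Editor's note: a usage sketch for `parse_h1_request` above, with a local trimmed copy of the same folding logic so the snippet runs standalone; note how a duplicated header folds into one comma-separated value:

def fold_h1_headers(data: bytes) -> dict:
    # Trimmed copy of parse_h1_request: no version check, latin-1 decode.
    lines = data.decode("ISO-8859-1").split("\r\n")
    method, path, _version = lines[0].split(" ", 2)
    headers = {":method": method, ":path": path}
    for name, val in (h.split(": ", 1) for h in lines[1:]):
        name = name.lower()
        headers[name] = f"{headers[name]}, {val}" if name in headers else val
    return headers


print(fold_h1_headers(
    b"GET / HTTP/1.1\r\nHost: example.com\r\n"
    b"Accept: text/html\r\nAccept: application/json"
))
# {':method': 'GET', ':path': '/', 'host': 'example.com',
#  'accept': 'text/html, application/json'}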
+
+
+def push_back(stream, data):
+    assert data, "Ensure that data is not empty before calling this function."
+    stream_type = type(stream)
+
+    class PushbackStream(stream_type):
+        async def receive_some(self, max_bytes=None):
+            nonlocal data
+            if max_bytes and max_bytes < len(data):
+                # `data` is bytes (immutable), so slice off the served
+                # prefix rather than deleting it in place.
+                ret = data[:max_bytes]
+                data = data[max_bytes:]
+                return ret
+            self.__class__ = stream_type
+            return data
+
+    stream.__class__ = PushbackStream
+
+
+class HttpProtocol:
+    def __init__(self, **kwargs):
+        self.__dict__.update(kwargs)
+        self.request_class = self.request_class or Request
+        self.stream = None
+        self.servername = None  # User-visible server hostname, no port!
+
+    async def ssl_init(self):
+        def servername_callback(sock, req_hostname, cb_context):
+            self.servername = req_hostname
+
+        self.ssl.sni_callback = servername_callback
+        self.stream = trio.SSLStream(
+            self.stream, self.ssl, server_side=True, https_compatible=True
+        )
+        await self.stream.do_handshake()
+        self.alpn = self.stream.selected_alpn_protocol()
+
+    async def sniff_protocol(self):
+        req = await self.receive_request()
+        if req is SSL:
+            if not self.ssl:
+                raise RuntimeError("Only plain HTTP supported (not SSL).")
+            await self.ssl_init()
+            if not self.alpn or self.alpn == "http/1.1":
+                return self.http1()
+            if self.alpn == "h2":
+                return self.http2()
+            raise RuntimeError(f"Unknown ALPN {self.alpn}")
+        # HTTP/2 by prior knowledge (not an Upgrade)
+        if req is H2:
+            return self.http2()
+        if req is None:
+            return None
+        # HTTP/1.1, but might be an Upgrade to websocket or h2c
+        headers = parse_h1_request(req)
+        upgrade = headers.get("upgrade")
+        if upgrade == "h2c":
+            return self.http2(settings_header=headers["http2-settings"])
+        if upgrade == "websocket":
+            self.websocket = True
+        return self.http1(headers=headers)
+
+    def set_timeout(self, timeout: str):
+        self.nursery.cancel_scope.deadline = trio.current_time() + getattr(
+            self, f"{timeout}_timeout"
+        )
+
+    async def run(self, stream):
+        assert not self.stream
+        self.stream = stream
+        self.stream.push_back = partial(push_back, stream)
+        async with stream, trio.open_nursery() as self.nursery:
+            try:
+                self.set_timeout("request")
+                protocol_coroutine = await self.sniff_protocol()
+                if not protocol_coroutine:
+                    return
+                await protocol_coroutine
+                self.nursery.cancel_scope.cancel()  # Terminate connection
+            except trio.BrokenResourceError:
+                pass  # Connection reset by peer
+            except Exception:
+                logger.exception("Error in server")
+            finally:
+                idle_connections.discard(self.nursery.cancel_scope)
+
+    async def receive_request(self):
+        idle_connections.add(self.nursery.cancel_scope)
+        buffer = None
+        prevpos = 0
+        async for data in self.stream:
+            idle_connections.discard(self.nursery.cancel_scope)
+            if buffer is None:
+                buffer = data
+            else:
+                # Headers normally arrive in one packet, so appending
+                # like this is the rare slow path.
+                buffer += data
+            buflen = len(buffer)
+            if buffer[0] == 0x16:
+                # A TLS handshake record starts with byte 0x16
+                self.stream.push_back(buffer)
+                return SSL
+            if buflen > self.request_max_size:
+                raise RuntimeError("Request larger than request_max_size")
+            pos = buffer.find(b"\r\n\r\n", prevpos)
+            if pos > 0:
+                req = buffer[:pos]
+                if req == b"PRI * HTTP/2.0":
+                    self.stream.push_back(buffer)
+                    return H2
+                pos += 4  # Skip the request head's trailing \r\n\r\n
+                if buflen > pos:
+                    self.stream.push_back(buffer[pos:])
+                return req
+            prevpos = buflen - 3  # \r\n\r\n may cross packet boundaries
+        if buffer:
+            raise RuntimeError(f"Peer disconnected after {buffer!r:.200}")
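Editor's note: a standalone toy illustration of the `__class__`-swap pushback trick used by `push_back` above; the stream type here is a stub and the max_bytes split is omitted for brevity, so this is a simplified sketch, not the full helper:

import asyncio


class ToyStream:
    def __init__(self, chunks):
        self._chunks = list(chunks)

    async def receive_some(self, max_bytes=None):
        return self._chunks.pop(0) if self._chunks else b""


def push_back(stream, data):
    stream_type = type(stream)

    class PushbackStream(stream_type):
        async def receive_some(self, max_bytes=None):
            # First read returns the pushed-back bytes, then the instance
            # reverts to its original class for all later reads.
            self.__class__ = stream_type
            return data

    stream.__class__ = PushbackStream


async def main():
    s = ToyStream([b"world"])
    push_back(s, b"hello ")
    print(await s.receive_some(), await s.receive_some())


asyncio.run(main())  # b'hello ' b'world'

Swapping the instance's class avoids wrapping the stream in a proxy object, so the rest of the protocol code keeps using the very same stream reference it was handed.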
+
+    async def http1(self, headers=None):
+        while not quit.is_set():
+            # Process request
+            if headers is None:
+                req = await self.receive_request()
+                if not isinstance(req, (bytes, bytearray)):
+                    return
+                headers = parse_h1_request(req)
+
+            request = self.request_class(
+                url_bytes=headers[":path"].encode(),
+                headers=Header(headers),
+                version="1.1",
+                method=headers[":method"],
+                transport=None,
+                app=self.app,
+            )
+            need_continue = (
+                headers.get("expect", "").lower() == "100-continue"
+            )
+
+            if "chunked" in headers.get("transfer-encoding", "").lower():
+                raise RuntimeError("Chunked requests not supported")  # FIXME
+
+            request.stream = H1Stream(headers, self.stream, need_continue)
+            headers = None
+            # Timeout between consecutive requests; reset *here* to
+            # minimize the chance of timing out request handlers.
+            self.set_timeout("request")
+            try:
+                await self.request_handler(request)
+            except trio.BrokenResourceError:
+                logger.info(
+                    f"Client disconnected during "
+                    f"{request.method} {request.path}"
+                )
+                return
+            except Exception as e:
+                r = self.app.error_handler.default(request, e)
+                try:
+                    await request.stream.respond(r.status, r.headers).send(
+                        data_bytes=r.body
+                    )
+                except RuntimeError:
+                    pass  # We can no longer send anything to the client
+                raise
+            finally:
+                await request.stream.aclose()
+
+    async def h2_sender(self):
+        async for _ in self.can_send:
+            await self.stream.send_all(self.conn.data_to_send())
+
+    async def http2(self, settings_header=None):
+        self.conn = H2Connection(config=h2config)
+        if settings_header:  # Upgrade from HTTP/1.1
+            self.conn.initiate_upgrade_connection(settings_header)
+            await self.stream.send_all(
+                b"HTTP/1.1 101 Switching Protocols\r\n"
+                b"Connection: Upgrade\r\n"
+                b"Upgrade: h2c\r\n"
+                b"\r\n" + self.conn.data_to_send()
+            )
+        else:  # straight into HTTP/2 mode
+            self.conn.initiate_connection()
+            await self.stream.send_all(self.conn.data_to_send())
+        # A trigger mechanism that ensures data from self.conn is sent to
+        # the stream promptly; the buffer size must be > 0 to avoid data
+        # being left unsent when a stream is cancelled while awaiting send.
+        self.send_some, self.can_send = trio.open_memory_channel(1)
+        self.nursery.start_soon(self.h2_sender)
+        idle_connections.add(self.nursery.cancel_scope)
+        self.requests = {}
+        async for data in self.stream:
+            for event in self.conn.receive_data(data):
+                if isinstance(event, RequestReceived):
+                    self.nursery.start_soon(
+                        self.h2request, event.stream_id, event.headers
+                    )
+                if isinstance(event, ConnectionTerminated):
+                    return
+            await self.send_some.send(...)
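Editor's note: a minimal sketch of the h2 state-machine calls `http2()` above makes for the direct (non-upgrade) path; it needs the `h2` package but no sockets, since the connection object only queues bytes for the caller to write:

from h2.config import H2Configuration
from h2.connection import H2Connection

config = H2Configuration(client_side=False, header_encoding="utf-8")
conn = H2Connection(config=config)
conn.initiate_connection()
preface = conn.data_to_send()  # the server SETTINGS frame, as bytes
print(f"{len(preface)} bytes queued for the client")

Keeping the state machine separate from I/O is what lets `h2_sender` drain `data_to_send()` from a single dedicated task, triggered through the memory channel.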
+
+    async def h2request(self, stream_id, headers):
+        hdrs = {}
+        for name, value in headers:
+            old = hdrs.get(name)
+            hdrs[name] = value if old is None else f"{old}, {value}"
+        # Process response
+        request = self.request_class(
+            url_bytes=hdrs.get(":path", "").encode(),
+            headers=Header(headers),
+            version="h2",
+            method=hdrs[":method"],
+            transport=None,
+            app=self.app,
+        )
+
+        async def write_response(response):
+            headers = (
+                (":status", f"{response.status}"),
+                ("content-length", f"{len(response.body)}"),
+                ("content-type", response.content_type),
+                # response.headers is dict-like; expand to (name, value)
+                # pairs rather than bare keys.
+                *response.headers.items(),
+            )
+            self.conn.send_headers(stream_id, headers)
+            self.conn.send_data(stream_id, response.body, end_stream=True)
+            await self.send_some.send(...)
+
+        with trio.fail_after(self.response_timeout):
+            await self.request_handler(request, write_response, None)
+
+    async def websocket(self):
+        logger.info("Websocket requested, not yet implemented")
+
+
+async def trigger_events(events):
+    """Trigger event callbacks (functions or async)
+
+    :param events: one or more sync or async functions to execute
+    """
+    for event in events:
+        result = event()
+        if isawaitable(result):
+            await result
+
+
+def serve(
+    host,
+    port,
+    app,
+    request_handler,
+    error_handler,
+    before_start=None,
+    after_start=None,
+    before_stop=None,
+    after_stop=None,
+    debug=False,
+    request_timeout=60,
+    response_timeout=60,
+    keep_alive_timeout=5,
+    ssl=None,
+    sock=None,
+    request_max_size=None,
+    request_buffer_queue_size=100,
+    protocol=HttpProtocol,
+    backlog=100,
+    register_sys_signals=True,
+    run_multiple=False,
+    run_async=False,
+    connections=None,
+    signal=Signal(),
+    request_class=None,
+    access_log=True,
+    keep_alive=True,
+    is_request_stream=False,
+    router=None,
+    websocket_max_size=None,
+    websocket_max_queue=None,
+    websocket_read_limit=2 ** 16,
+    websocket_write_limit=2 ** 16,
+    state=None,
+    graceful_shutdown_timeout=15.0,
+    asyncio_server_kwargs=None,
+    workers=1,
+    loop=None,
+):
+    proto = partial(
+        protocol,
+        connections=connections,
+        signal=signal,
+        app=app,
+        ssl=ssl,
+        request_handler=request_handler,
+        error_handler=error_handler,
+        request_timeout=request_timeout,
+        response_timeout=response_timeout,
+        keep_alive_timeout=keep_alive_timeout,
+        request_max_size=request_max_size,
+        request_class=request_class,
+        access_log=access_log,
+        keep_alive=keep_alive,
+        is_request_stream=is_request_stream,
+        router=router,
+        websocket_max_size=websocket_max_size,
+        websocket_max_queue=websocket_max_queue,
+        websocket_read_limit=websocket_read_limit,
+        websocket_write_limit=websocket_write_limit,
+        state=state,
+        debug=debug,
+    )
+
+    app.asgi = False
+    assert not (
+        run_async or run_multiple or asyncio_server_kwargs or loop
+    ), "Not implemented"
+
+    acceptor = partial(
+        runaccept,
+        before_start=before_start,
+        after_start=after_start,
+        before_stop=before_stop,
+        after_stop=after_stop,
+        proto=proto,
+        graceful_shutdown_timeout=graceful_shutdown_timeout,
+    )
+
+    server = partial(
+        runserver,
+        acceptor=acceptor,
+        host=host,
+        port=port,
+        sock=sock,
+        backlog=backlog,
+        workers=workers,
+    )
+    return server() if run_async else trio.run(server)
+
+
+async def runserver(acceptor, host, port, sock, backlog, workers):
+    if host and host.startswith("unix:"):
+        # Not Implemented: open_unix_listeners(path=host[5:], backlog=backlog)
+        raise NotImplementedError("unix: sockets are not supported yet")
+    else:
+        open_listeners = partial(
+            trio.open_tcp_listeners,
+            host=host,
+            port=port or 8000,
+            backlog=backlog,
+        )
+    try:
+        listeners = await open_listeners()
+    except Exception:
+        logger.exception("Unable to start server")
+        return
+    for l in listeners:
+        l.socket.set_inheritable(True)
+    processes = []
+
+    try:
+        if workers:
+            # The spawn start method is consistent across platforms and,
+            # unlike fork, can be used from within async functions (such
+            # as this one).
+            mp.set_start_method("spawn")
+            with trio.open_signal_receiver(SIGINT, SIGTERM, SIGHUP) as sigiter:
+                logger.info(f"Server starting, {workers} worker processes")
+                while True:
+                    while len(processes) < workers:
+                        p = mp.Process(
+                            target=trio.run, args=(acceptor, listeners)
+                        )
+                        p.daemon = True
+                        p.start()
+                        processes.append(p)
+                        logger.info("Worker [%s]", p.pid)
+                    # Wait for signals and periodically check processes;
+                    # `s` stays None when the 0.1 s window expires without
+                    # a signal being delivered.
+                    s = None
+                    with trio.move_on_after(0.1):
+                        s = await sigiter.__anext__()
+                    if s is not None:
+                        logger.info(f"Server received {Signals(s).name}")
+                        for p in processes:
+                            p.terminate()
+                        if s in (SIGTERM, SIGINT):
+                            break
+                    processes = [p for p in processes if p.is_alive()]
+        else:  # workers=0
+            logger.info("Server and worker started")
+            await acceptor(listeners)
+    finally:
+        with trio.CancelScope() as cs:
+            cs.shield = True
+            # Close listeners and wait for workers to terminate
+            for l in listeners:
+                await l.aclose()
+            for p in processes:
+                p.join()
+            logger.info("Server stopped")
+
+
+async def sighandler(scopes, task_status=trio.TASK_STATUS_IGNORED):
+    with trio.open_signal_receiver(SIGINT, SIGTERM, SIGHUP) as sigiter:
+        t = None
+        task_status.started()
+        async for s in sigiter:
+            # Ignore spuriously repeated signals
+            if t is not None and trio.current_time() - t < 0.5:
+                logger.debug(f"Ignored {Signals(s).name}")
+                continue
+            logger.info(f"Received {Signals(s).name}")
+            if not scopes:
+                # trio.Cancelled must not be raised directly; escalate
+                # with an ordinary exception instead.
+                raise RuntimeError("Signaled too many times")
+            scopes.pop().cancel()
+            t = trio.current_time()
+
+
+async def runbench(stream):
+    # Minimal hardcoded responder used for benchmarking the accept loop.
+    async for d in stream:
+        if d[-4:] == b"\r\n\r\n":
+            await stream.send_all(
+                b"HTTP/1.1 200 OK\r\n"
+                b"content-type: text/plain\r\n"
+                b"content-length: 13\r\n"
+                b"\r\n"
+                b"Hello World!\n"
+            )
+
+
+async def runaccept(
+    listeners,
+    before_start,
+    after_start,
+    before_stop,
+    after_stop,
+    proto,
+    graceful_shutdown_timeout,
+):
+    try:
+        async with trio.open_nursery() as signal_nursery:
+            cancel_scopes = [signal_nursery.cancel_scope]
+            await signal_nursery.start(sighandler, cancel_scopes)
+            async with trio.open_nursery() as main_nursery:
+                cancel_scopes.append(main_nursery.cancel_scope)
+                await trigger_events(before_start)
+                # Accept connections until a signal is received
+                async with trio.open_nursery() as acceptor:
+                    cancel_scopes.append(acceptor.cancel_scope)
+                    acceptor.start_soon(
+                        partial(
+                            trio.serve_listeners,
+                            handler=lambda stream: proto().run(stream),
+                            listeners=listeners,
+                            handler_nursery=main_nursery,
+                        )
+                    )
+                    await trigger_events(after_start)
+                # No longer accepting new connections. Attempt graceful exit.
+                main_nursery.cancel_scope.deadline = (
+                    trio.current_time() + graceful_shutdown_timeout
+                )
+                trigger_graceful_exit()
+                await trigger_events(before_stop)
+                await trigger_events(after_stop)
+                signal_nursery.cancel_scope.cancel()  # Exit signal handler
+                logger.info("Worker finished gracefully")
+    except BaseException:
+        logger.exception("Worker terminating")
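Editor's note: `runaccept` above layers nurseries so that a signal first stops the acceptor, then gives in-flight requests a deadline before hard cancellation. A standalone sketch of that nested-nursery shutdown pattern; it requires `trio`, and the worker task here is illustrative:

import trio


async def worker(seconds):
    try:
        await trio.sleep(seconds)
        print(f"worker({seconds}) finished")
    except trio.Cancelled:
        print(f"worker({seconds}) cancelled at the deadline")
        raise


async def main():
    async with trio.open_nursery() as nursery:
        nursery.start_soon(worker, 1)
        nursery.start_soon(worker, 3)
        await trio.sleep(0.1)
        # Graceful window: anything not done within 2 seconds of now
        # is cancelled when the deadline expires.
        nursery.cancel_scope.deadline = trio.current_time() + 2


trio.run(main)  # worker(1) finishes; worker(3) is cancelled

Setting the deadline instead of calling cancel() immediately is what turns a hard stop into a bounded drain, mirroring graceful_shutdown_timeout above.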