diff --git a/tests/utils.py b/tests/utils.py index 47dd21353..8a6f83bc5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,3 @@ -import asyncio import os from contextlib import asynccontextmanager, contextmanager from pathlib import Path @@ -8,14 +7,12 @@ @asynccontextmanager async def run_server(config: Config, sockets=None): - server = Server(config=config) - cancel_handle = asyncio.ensure_future(server.serve(sockets=sockets)) - await asyncio.sleep(0.1) + server = Server(config=config, sockets=sockets) + await server.start_serving() try: yield server finally: await server.shutdown() - cancel_handle.cancel() @contextmanager diff --git a/uvicorn/server.py b/uvicorn/server.py index 494e2b658..052a42549 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -46,7 +46,9 @@ def __init__(self) -> None: class Server: - def __init__(self, config: Config) -> None: + def __init__( + self, config: Config, *, sockets: Optional[List[socket.socket]] = None + ) -> None: self.config = config self.server_state = ServerState() @@ -55,11 +57,19 @@ def __init__(self, config: Config) -> None: self.force_exit = False self.last_notified = 0.0 + self._sockets: Optional[List[socket.socket]] = sockets + + self._main_task: Optional[asyncio.Task] = None + + #: Created on demand and set once immediately after startup has + #: completed and the server has started listening for requests. + self._startup_event: Optional[asyncio.Event] = None + def run(self, sockets: Optional[List[socket.socket]] = None) -> None: self.config.setup_event_loop() return asyncio.run(self.serve(sockets=sockets)) - async def serve(self, sockets: Optional[List[socket.socket]] = None) -> None: + async def _main(self) -> None: process_id = os.getpid() config = self.config @@ -68,22 +78,62 @@ async def serve(self, sockets: Optional[List[socket.socket]] = None) -> None: self.lifespan = config.lifespan_class(config) - self.install_signal_handlers() - message = "Started server process [%d]" color_message = "Started server process [" + click.style("%d", fg="cyan") + "]" logger.info(message, process_id, extra={"color_message": color_message}) - await self.startup(sockets=sockets) + await self.startup(sockets=self._sockets) if self.should_exit: return await self.main_loop() - await self.shutdown(sockets=sockets) + await self._shutdown(sockets=self._sockets) message = "Finished server process [%d]" color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]" logger.info(message, process_id, extra={"color_message": color_message}) + async def start_serving(self) -> None: + """ + Starts the server running in a background task and blocks until startup + is either fully complete, or fails. + + Idempotent. Can be called multiple times without creating multiple + instances. + """ + if self._startup_event is None: + # We defer creating the startup event until start serving is called + # because there is no guarantee that the constructor will be called + # with an active loop. + self._startup_event = asyncio.Event() + + if self._main_task is None: + self._main_task = asyncio.create_task(self._main()) + + # If the main task exits before the startup event is set it means that + # startup has failed. + await asyncio.wait( + [asyncio.create_task(self._startup_event.wait()), self._main_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + async def serve(self, sockets: Optional[List[socket.socket]] = None) -> None: + if self._main_task is not None: + raise RuntimeError("cannot call serve on running server") + + if sockets is not None: + if self._sockets is not None: + raise RuntimeError("cannot override already provided sockets list") + self._sockets = sockets + + self.install_signal_handlers() + try: + await self.start_serving() + await self.wait_closed() + + except asyncio.CancelledError: + self.close() + await self.wait_closed() + async def startup(self, sockets: list = None) -> None: await self.lifespan.startup() if self.lifespan.should_exit: @@ -171,6 +221,9 @@ def _share_socket(sock: socket.SocketType) -> socket.SocketType: self.started = True + assert self._startup_event is not None + self._startup_event.set() + def _log_started_message(self, listeners: Sequence[socket.SocketType]) -> None: config = self.config @@ -249,7 +302,7 @@ async def on_tick(self, counter: int) -> bool: return self.server_state.total_requests >= self.config.limit_max_requests return False - async def shutdown(self, sockets: Optional[List[socket.socket]] = None) -> None: + async def _shutdown(self, sockets: Optional[List[socket.socket]] = None) -> None: logger.info("Shutting down") # Stop accepting new connections. @@ -283,6 +336,27 @@ async def shutdown(self, sockets: Optional[List[socket.socket]] = None) -> None: if not self.force_exit: await self.lifespan.shutdown() + def close(self, *, force_exit: bool = False) -> None: + """ + Asks the server, asynchronously, to initiate shutdown. + It should be safe to call this from a request handler. + """ + self.should_exit = True + if force_exit and not self.force_exit: + self.force_exit = True + + async def wait_closed(self) -> None: + """ + Blocks until the server is completely shutdown. + """ + if self._main_task is None: + raise RuntimeError("Server hasn't been started") + await self._main_task + + async def shutdown(self) -> None: + self.close() + await self.wait_closed() + def install_signal_handlers(self) -> None: if threading.current_thread() is not threading.main_thread(): # Signals can only be listened to from the main thread.