Skip to content

Commit

Permalink
🐛 fix: add Lifespan._on_ready() for forward adapter startup
Browse files Browse the repository at this point in the history
fix #2475
  • Loading branch information
ProgramRipper committed Dec 5, 2023
1 parent 28ad682 commit a47c35a
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 60 deletions.
12 changes: 0 additions & 12 deletions nonebot/drivers/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup

from ._lifespan import LIFESPAN_FUNC, Lifespan

try:
import uvicorn
from fastapi.responses import Response
Expand Down Expand Up @@ -97,8 +95,6 @@ def __init__(self, env: Env, config: NoneBotConfig):

self.fastapi_config: Config = Config(**config.dict())

self._lifespan = Lifespan()

self._server_app = FastAPI(
lifespan=self._lifespan_manager,
openapi_url=self.fastapi_config.fastapi_openapi_url,
Expand Down Expand Up @@ -155,14 +151,6 @@ async def _handle(websocket: WebSocket) -> None:
name=setup.name,
)

@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_startup(func)

@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_shutdown(func)

@contextlib.asynccontextmanager
async def _lifespan_manager(self, app: FastAPI):
await self._lifespan.startup()
Expand Down
14 changes: 0 additions & 14 deletions nonebot/drivers/none.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver

from ._lifespan import LIFESPAN_FUNC, Lifespan

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
Expand All @@ -35,8 +33,6 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)

self._lifespan = Lifespan()

self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False

Expand All @@ -52,16 +48,6 @@ def logger(self):
"""none driver 使用的 logger"""
return logger

@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)

@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)

@override
def run(self, *args, **kwargs):
"""启动 none driver"""
Expand Down
27 changes: 3 additions & 24 deletions nonebot/drivers/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,7 @@
import asyncio
from functools import wraps
from typing_extensions import override
from typing import (
Any,
Dict,
List,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Coroutine,
cast,
)
from typing import Any, Dict, List, Tuple, Union, Optional, cast

from pydantic import BaseSettings

Expand Down Expand Up @@ -57,8 +46,6 @@
"Install with pip: `pip install nonebot2[quart]`"
) from e

_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])


def catch_closed(func):
@wraps(func)
Expand Down Expand Up @@ -102,6 +89,8 @@ def __init__(self, env: Env, config: NoneBotConfig):
self._server_app = Quart(
self.__class__.__qualname__, **self.quart_config.quart_extra
)
self._server_app.before_serving(self._lifespan.startup)
self._server_app.after_serving(self._lifespan.shutdown)

@property
@override
Expand Down Expand Up @@ -150,16 +139,6 @@ async def _handle() -> None:
view_func=_handle,
)

@override
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.before_serving(func) # type: ignore

@override
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.after_serving(func) # type: ignore

@override
def run(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class Lifespan:
def __init__(self) -> None:
self._startup_funcs: List[LIFESPAN_FUNC] = []
self._ready_funcs: List[LIFESPAN_FUNC] = []
self._shutdown_funcs: List[LIFESPAN_FUNC] = []

def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
Expand All @@ -21,6 +22,10 @@ def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._shutdown_funcs.append(func)
return func

def _on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._ready_funcs.append(func)
return func

Check warning on line 27 in nonebot/internal/driver/_lifespan.py

View check run for this annotation

Codecov / codecov/patch

nonebot/internal/driver/_lifespan.py#L26-L27

Added lines #L26 - L27 were not covered by tests

@staticmethod
async def _run_lifespan_func(
funcs: List[LIFESPAN_FUNC],
Expand All @@ -35,6 +40,8 @@ async def startup(self) -> None:
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)

await self._run_lifespan_func(self._ready_funcs)

async def shutdown(self) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
Expand Down
18 changes: 9 additions & 9 deletions nonebot/internal/driver/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
from typing_extensions import TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from typing import TYPE_CHECKING, Any, Set, Dict, Type, AsyncGenerator

from nonebot.log import logger
from nonebot.config import Env, Config
Expand All @@ -16,6 +16,7 @@
T_BotDisconnectionHook,
)

from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(self, env: Env, config: Config):
"""全局配置对象"""
self._bots: Dict[str, "Bot"] = {}
self._bot_tasks: Set[asyncio.Task] = set()
self._lifespan = Lifespan()

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -100,15 +102,13 @@ def run(self, *args, **kwargs):

self.on_shutdown(self._cleanup)

@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动器启动时执行的函数"""
raise NotImplementedError
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)

Check warning on line 107 in nonebot/internal/driver/abstract.py

View check run for this annotation

Codecov / codecov/patch

nonebot/internal/driver/abstract.py#L107

Added line #L107 was not covered by tests

@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
"""注册一个在驱动器停止时执行的函数"""
raise NotImplementedError
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)

Check warning on line 111 in nonebot/internal/driver/abstract.py

View check run for this annotation

Codecov / codecov/patch

nonebot/internal/driver/abstract.py#L111

Added line #L111 was not covered by tests

@classmethod
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from nonebot.params import Depends
from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed
from nonebot.drivers._lifespan import Lifespan
from nonebot.internal.driver._lifespan import Lifespan
from nonebot.drivers import (
URL,
Driver,
Expand Down

0 comments on commit a47c35a

Please sign in to comment.