Skip to content

Commit

Permalink
[melobot] Refactor plugin working mechanism & add deprecated removing…
Browse files Browse the repository at this point in the history
… flags
  • Loading branch information
aicorein committed Dec 10, 2024
1 parent d1d67f0 commit 4508081
Show file tree
Hide file tree
Showing 10 changed files with 283 additions and 111 deletions.
4 changes: 3 additions & 1 deletion src/melobot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from .di import Depends
from .handle import Flow, FlowStore, node, rewind, stop
from .log import GenericLogger, Logger, LogLevel, get_logger
from .plugin import AsyncShare, Plugin, PluginLifeSpan, SyncShare

# REMOVE: 3.0.0 (Plugin)
from .plugin import AsyncShare, Plugin, PluginLifeSpan, PluginPlanner, SyncShare
from .session import Rule, SessionStore, enter_session, suspend
from .typ import HandleLevel, LogicMode
21 changes: 14 additions & 7 deletions src/melobot/bot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..exceptions import BotError
from ..io.base import AbstractInSource, AbstractIOSource, AbstractOutSource
from ..log.base import GenericLogger, Logger, NullLogger
from ..plugin.base import Plugin, PluginLifeSpan
from ..plugin.base import Plugin, PluginLifeSpan, PluginPlanner
from ..plugin.ipc import AsyncShare, IPCManager, SyncShare
from ..plugin.load import PluginLoader
from ..protocols.base import ProtocolStack
Expand Down Expand Up @@ -241,11 +241,13 @@ def _core_init(self) -> None:
self.logger.debug(f"当前异步事件循环策略:{asyncio.get_event_loop_policy()}")

def load_plugin(
self, plugin: ModuleType | str | PathLike[str] | Plugin, load_depth: int = 1
self,
plugin: ModuleType | str | PathLike[str] | PluginPlanner,
load_depth: int = 1,
) -> Bot:
"""加载插件
:param plugin: 可以被加载为插件的对象(插件目录对应的模块,插件的目录路径,可直接 import 包名称,插件对象
:param plugin: 可以被加载为插件的对象(插件目录对应的模块,插件的目录路径,可直接 import 包名称,插件管理器对象
:param load_depth:
插件加载时的相对引用深度,默认值 1 只支持向上引用到插件目录一级。
增加为 2 可以引用到插件目录的父目录一级,依此类推。
Expand All @@ -261,20 +263,25 @@ def load_plugin(
return self

self._plugins[p.name] = p
self._dispatcher.internal_add(*p.handlers)

if self._hook_bus.get_evoke_time(BotLifeSpan.STARTED) != -1:
asyncio.create_task(self._dispatcher.add(*p.handlers))
else:
self._dispatcher.add_nowait(*p.handlers)

for share in p.shares:
self.ipc_manager.add(p.name, share)
for func in p.funcs:
self.ipc_manager.add_func(p.name, func)
self.logger.info(f"成功加载插件:{p.name}")

if self._hook_bus.get_evoke_time(BotLifeSpan.STARTED) != -1:
asyncio.create_task(p._hook_bus.emit(PluginLifeSpan.INITED))
asyncio.create_task(p.hook_bus.emit(PluginLifeSpan.INITED))
return self

def load_plugins(
self,
plugins: Iterable[ModuleType | str | PathLike[str] | Plugin],
plugins: Iterable[ModuleType | str | PathLike[str] | PluginPlanner],
load_depth: int = 1,
) -> None:
"""与 :func:`load_plugin` 行为类似,但是参数变为可迭代对象
Expand Down Expand Up @@ -335,7 +342,7 @@ async def core_run(self) -> None:
await self._hook_bus.emit(BotLifeSpan.RELOADED)

for p in self._plugins.values():
await p._hook_bus.emit(PluginLifeSpan.INITED)
await p.hook_bus.emit(PluginLifeSpan.INITED)

timed_task = asyncio.create_task(self._dispatcher.timed_gc())
self._tasks.append(timed_task)
Expand Down
30 changes: 19 additions & 11 deletions src/melobot/bot/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..adapter.model import Event
from ..handle.base import EventHandler
from ..typ import HandleLevel
from ..typ import AsyncCallable, HandleLevel
from ..utils import RWContext

KeyT = TypeVar("KeyT", bound=float, default=float)
Expand Down Expand Up @@ -52,19 +52,23 @@ def setdefault(self, key: KeyT, default: ValT) -> ValT:
class Dispatcher:
def __init__(self) -> None:
self.handlers: _KeyOrderDict[HandleLevel, set[EventHandler]] = _KeyOrderDict()
self.broadcast_ctrl = RWContext()
self.dispatch_ctrl = RWContext()
self.gc_interval = 5

def internal_add(self, *handlers: EventHandler) -> None:
def add_nowait(self, *handlers: EventHandler) -> None:
for h in handlers:
self.handlers.setdefault(h.flow.priority, set()).add(h)
h.flow.on_priority_reset(
lambda new_prior, h=h: self._reset_hook(h, new_prior)
)

async def add(self, *handlers: EventHandler) -> None:
async with self.broadcast_ctrl.write():
self.internal_add(*handlers)
async def add(
self, *handlers: EventHandler, callback: AsyncCallable[[], None] | None = None
) -> None:
async with self.dispatch_ctrl.write():
self.add_nowait(*handlers)
if callback is not None:
await callback()

async def _remove(self, *handlers: EventHandler) -> None:
for h in handlers:
Expand All @@ -74,15 +78,19 @@ async def _remove(self, *handlers: EventHandler) -> None:
if len(h_set) == 0:
self.handlers.pop(h.flow.priority)

async def expire(self, *handlers: EventHandler) -> None:
async with self.broadcast_ctrl.write():
async def remove(
self, *handlers: EventHandler, callback: AsyncCallable[[], None] | None = None
) -> None:
async with self.dispatch_ctrl.write():
await self._remove(*handlers)
if callback is not None:
await callback()

async def _reset_hook(self, handler: EventHandler, new_prior: HandleLevel) -> None:
if handler.flow.priority == new_prior:
return

async with self.broadcast_ctrl.write():
async with self.dispatch_ctrl.write():
old_prior = handler.flow.priority
if old_prior == new_prior:
return
Expand All @@ -93,7 +101,7 @@ async def _reset_hook(self, handler: EventHandler, new_prior: HandleLevel) -> No
self.handlers.setdefault(new_prior, set()).add(handler)

async def broadcast(self, event: Event) -> None:
async with self.broadcast_ctrl.read():
async with self.dispatch_ctrl.read():
for h_set in self.handlers.values():
tasks = tuple(asyncio.create_task(h.handle(event)) for h in h_set)
await asyncio.wait(tasks)
Expand All @@ -103,7 +111,7 @@ async def broadcast(self, event: Event) -> None:
async def timed_gc(self) -> None:
while True:
await asyncio.sleep(self.gc_interval)
async with self.broadcast_ctrl.write():
async with self.dispatch_ctrl.write():
hs = tuple(
h for h_set in self.handlers.values() for h in h_set if h.invalid
)
Expand Down
1 change: 1 addition & 0 deletions src/melobot/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def unfold(self, obj: T) -> Generator[T, None, None]:
finally:
self.remove(token)

# REMOVE: 3.0.0
@deprecated("将于 melobot v3.0.0 移除,使用 unfold 方法代替")
@contextmanager
def in_ctx(self, obj: T) -> Generator[None, None, None]:
Expand Down
6 changes: 3 additions & 3 deletions src/melobot/handle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, plugin: "Plugin", flow: Flow) -> None:
self._temp = flow.temp
self.invalid = False

async def _handle_event(self, event: Event) -> None:
async def _handle(self, event: Event) -> None:
try:
await self.flow.run(event)
except Exception:
Expand All @@ -39,12 +39,12 @@ async def handle(self, event: Event) -> None:
async with self._handle_ctrl.read():
if self.invalid:
return
return await self._handle_event(event)
return await self._handle(event)

async with self._handle_ctrl.write():
if self.invalid:
return
await self._handle_event(event)
await self._handle(event)
self.invalid = True
return

Expand Down
4 changes: 3 additions & 1 deletion src/melobot/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .base import Plugin, PluginLifeSpan
# REMOVE: 3.0.0
from .base import LegacyPlugin as Plugin
from .base import PluginLifeSpan, PluginPlanner
from .ipc import AsyncShare, SyncShare
13 changes: 9 additions & 4 deletions src/melobot/plugin/__init__.template.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# pylint: disable=invalid-name
from os import listdir as _VAR1
from pathlib import Path as _VAR2
from typing import Any as _VAR3

from melobot.plugin.load import plugin_get_attr as _VAR4
from melobot import get_bot as _VAR4

_VAR5 = _VAR2(__file__).parent
_VAR6 = set(fname.split(".")[0] for fname in _VAR1(_VAR5))
_VAR7 = _VAR5.parts[-1]


def __getattr__(name: str) -> _VAR3:
if name in _VAR6 or name.startswith("_"):
def __getattr__(_VAR8: str) -> _VAR3:
if _VAR8 in _VAR6 or _VAR8.startswith("_"):
raise AttributeError
return _VAR4(_VAR5.parts[-1], name)
_VAR9 = _VAR4().get_share(_VAR7, _VAR8)
if _VAR9.static:
return _VAR9.get()
return _VAR9
Loading

0 comments on commit 4508081

Please sign in to comment.