Skip to content

Commit

Permalink
Improve functionality and consistency of Sanic.event() (#2827)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
  • Loading branch information
talljosh and ahopkins authored Dec 7, 2023
1 parent 4499d2c commit 5f0787b
Show file tree
Hide file tree
Showing 5 changed files with 369 additions and 81 deletions.
25 changes: 19 additions & 6 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from asyncio.futures import Future
from collections import defaultdict, deque
from contextlib import contextmanager, suppress
from enum import Enum
from functools import partial, wraps
from inspect import isawaitable
from os import environ
Expand Down Expand Up @@ -714,7 +715,12 @@ async def handle_registration(request):
)

async def event(
self, event: str, timeout: Optional[Union[int, float]] = None
self,
event: Union[str, Enum],
timeout: Optional[Union[int, float]] = None,
*,
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> None:
"""Wait for a specific event to be triggered.
Expand All @@ -737,13 +743,18 @@ async def event(
timeout (Optional[Union[int, float]]): An optional timeout value
in seconds. If provided, the wait will be terminated if the
timeout is reached. Defaults to `None`, meaning no timeout.
condition: If provided, method will only return when the signal
is dispatched with the given condition.
exclusive: When true (default), the signal can only be dispatched
when the condition has been met. When ``False``, the signal can
be dispatched either with or without it.
Raises:
NotFound: If the event is not found and auto-registration of
events is not enabled.
Returns:
None
The context dict of the dispatched signal.
Examples:
```python
Expand All @@ -759,16 +770,18 @@ async def after_server_start(app, loop):
```
"""

signal = self.signal_router.name_index.get(event)
if not signal:
waiter = self.signal_router.get_waiter(event, condition, exclusive)
if not waiter:
if self.config.EVENT_AUTOREGISTER:
self.signal_router.reset()
self.add_signal(None, event)
signal = self.signal_router.name_index[event]
waiter = self.signal_router.get_waiter(
event, condition, exclusive
)
self.signal_router.finalize()
else:
raise NotFound("Could not find signal %s" % event)
return await wait_for(signal.ctx.event.wait(), timeout=timeout)
return await wait_for(waiter.wait(), timeout=timeout)

def report_exception(
self, handler: Callable[[Sanic, Exception], Coroutine[Any, Any, None]]
Expand Down
39 changes: 31 additions & 8 deletions sanic/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,33 +512,56 @@ async def dispatch(self, *args, **kwargs):
condition = kwargs.pop("condition", {})
condition.update({"__blueprint__": self.name})
kwargs["condition"] = condition
await asyncio.gather(
return await asyncio.gather(
*[app.dispatch(*args, **kwargs) for app in self.apps]
)

def event(self, event: str, timeout: Optional[Union[int, float]] = None):
def event(
self,
event: str,
timeout: Optional[Union[int, float]] = None,
*,
condition: Optional[Dict[str, Any]] = None,
):
"""Wait for a signal event to be dispatched.
Args:
event (str): Name of the signal event.
timeout (Optional[Union[int, float]]): Timeout for the event to be
dispatched.
condition: If provided, method will only return when the signal
is dispatched with the given condition.
Returns:
Awaitable: Awaitable for the event to be dispatched.
"""
events = set()
if condition is None:
condition = {}
condition.update({"__blueprint__": self.name})

waiters = []
for app in self.apps:
signal = app.signal_router.name_index.get(event)
if not signal:
waiter = app.signal_router.get_waiter(
event, condition, exclusive=False
)
if not waiter:
raise NotFound("Could not find signal %s" % event)
events.add(signal.ctx.event)
waiters.append(waiter)

return self._event(waiters, timeout)

return asyncio.wait(
[asyncio.create_task(event.wait()) for event in events],
async def _event(self, waiters, timeout):
done, pending = await asyncio.wait(
[asyncio.create_task(waiter.wait()) for waiter in waiters],
return_when=asyncio.FIRST_COMPLETED,
timeout=timeout,
)
for task in pending:
task.cancel()
if not done:
raise TimeoutError()
(finished_task,) = done
return finished_task.result()

@staticmethod
def _extract_value(*values):
Expand Down
4 changes: 2 additions & 2 deletions sanic/mixins/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def decorator(handler: SignalHandler):
def add_signal(
self,
handler: Optional[Callable[..., Any]],
event: str,
event: Union[str, Enum],
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> Callable[..., Any]:
Expand All @@ -88,7 +88,7 @@ def add_signal(
"""
if not handler:

async def noop():
async def noop(**context):
...

handler = noop
Expand Down
98 changes: 80 additions & 18 deletions sanic/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import asyncio

from collections import deque
from dataclasses import dataclass
from enum import Enum
from inspect import isawaitable
from typing import Any, Dict, List, Optional, Tuple, Union, cast
Expand Down Expand Up @@ -86,6 +88,39 @@ class Signal(Route):
"""A `Route` that is used to dispatch signals to handlers"""


@dataclass
class SignalWaiter:
"""A record representing a future waiting for a signal"""

signal: Signal
event_definition: str
trigger: str = ""
requirements: Optional[Dict[str, str]] = None
exclusive: bool = True

future: Optional[asyncio.Future] = None

async def wait(self):
"""Block until the signal is next dispatched.
Return the context of the signal dispatch, if any.
"""
loop = asyncio.get_running_loop()
self.future = loop.create_future()
self.signal.ctx.waiters.append(self)
try:
return await self.future
finally:
self.signal.ctx.waiters.remove(self)

def matches(self, event, condition):
return (
(condition is None and not self.exclusive)
or (condition is None and not self.requirements)
or condition == self.requirements
) and (self.trigger or event == self.event_definition)


class SignalGroup(RouteGroup):
"""A `RouteGroup` that is used to dispatch signals to handlers"""

Expand All @@ -104,7 +139,7 @@ def __init__(self) -> None:
self.ctx.loop = None

@staticmethod
def format_event(event: str) -> str:
def format_event(event: Union[str, Enum]) -> str:
"""Ensure event strings in proper format
Args:
Expand All @@ -113,13 +148,15 @@ def format_event(event: str) -> str:
Returns:
str: formatted event string
"""
if isinstance(event, Enum):
event = str(event.value)
if "." not in event:
event = GENERIC_SIGNAL_FORMAT % event
return event

def get( # type: ignore
self,
event: str,
event: Union[str, Enum],
condition: Optional[Dict[str, str]] = None,
):
"""Get the handlers for a signal
Expand Down Expand Up @@ -186,18 +223,20 @@ async def _dispatch(
error_logger.warning(str(e))
return None

events = [signal.ctx.event for signal in group]
for signal_event in events:
signal_event.set()
if context:
params.update(context)
params.pop("__trigger__", None)

signals = group.routes
if not reverse:
signals = signals[::-1]
try:
for signal in signals:
params.pop("__trigger__", None)
for waiter in signal.ctx.waiters:
if waiter.matches(event, condition):
waiter.future.set_result(dict(params))

for signal in signals:
requirements = signal.extra.requirements
if (
(condition is None and signal.ctx.exclusive is False)
Expand All @@ -223,13 +262,10 @@ async def _dispatch(
)
setattr(e, "__dispatched__", True)
raise e
finally:
for signal_event in events:
signal_event.clear()

async def dispatch(
self,
event: str,
event: Union[str, Enum],
*,
context: Optional[Dict[str, Any]] = None,
condition: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -271,15 +307,29 @@ async def dispatch(
await asyncio.sleep(0)
return task

def add( # type: ignore
def get_waiter(
self,
handler: SignalHandler,
event: str,
event: Union[str, Enum],
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> Signal:
event = self.format_event(event)
event_definition = event
):
event_definition = self.format_event(event)
name, trigger, _ = self._get_event_parts(event_definition)
signal = cast(Signal, self.name_index.get(name))
if not signal:
return None

if event_definition.endswith(".*") and not trigger:
trigger = "*"
return SignalWaiter(
signal=signal,
event_definition=event_definition,
trigger=trigger,
requirements=condition,
exclusive=bool(exclusive),
)

def _get_event_parts(self, event: str) -> Tuple[str, str, str]:
parts = self._build_event_parts(event)
if parts[2].startswith("<"):
name = ".".join([*parts[:-1], "*"])
Expand All @@ -291,8 +341,20 @@ def add( # type: ignore
if not trigger:
event = ".".join([*parts[:2], "<__trigger__>"])

return name, trigger, event

def add( # type: ignore
self,
handler: SignalHandler,
event: Union[str, Enum],
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> Signal:
event_definition = self.format_event(event)
name, trigger, event_string = self._get_event_parts(event_definition)

signal = super().add(
event,
event_string,
handler,
name=name,
append=True,
Expand Down Expand Up @@ -326,7 +388,7 @@ def finalize(self, do_compile: bool = True, do_optimize: bool = False):
raise RuntimeError("Cannot finalize signals outside of event loop")

for signal in self.routes:
signal.ctx.event = asyncio.Event()
signal.ctx.waiters = deque()

return super().finalize(do_compile=do_compile, do_optimize=do_optimize)

Expand Down
Loading

0 comments on commit 5f0787b

Please sign in to comment.