Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve functionality and consistency of Sanic.event() #2827

Merged
merged 5 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from asyncio.futures import Future
from collections import defaultdict, deque
from contextlib import contextmanager, suppress
from enum import Enum
from functools import partial, wraps
from inspect import isawaitable
from os import environ
Expand Down Expand Up @@ -714,7 +715,12 @@ async def handle_registration(request):
)

async def event(
self, event: str, timeout: Optional[Union[int, float]] = None
self,
event: Union[str, Enum],
timeout: Optional[Union[int, float]] = None,
*,
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> None:
"""Wait for a specific event to be triggered.

Expand All @@ -737,13 +743,18 @@ async def event(
timeout (Optional[Union[int, float]]): An optional timeout value
in seconds. If provided, the wait will be terminated if the
timeout is reached. Defaults to `None`, meaning no timeout.
condition: If provided, method will only return when the signal
is dispatched with the given condition.
exclusive: When true (default), the signal can only be dispatched
when the condition has been met. When ``False``, the signal can
be dispatched either with or without it.

Raises:
NotFound: If the event is not found and auto-registration of
events is not enabled.

Returns:
None
The context dict of the dispatched signal.

Examples:
```python
Expand All @@ -759,16 +770,18 @@ async def after_server_start(app, loop):
```
"""

signal = self.signal_router.name_index.get(event)
if not signal:
waiter = self.signal_router.get_waiter(event, condition, exclusive)
if not waiter:
if self.config.EVENT_AUTOREGISTER:
self.signal_router.reset()
self.add_signal(None, event)
signal = self.signal_router.name_index[event]
waiter = self.signal_router.get_waiter(
event, condition, exclusive
)
self.signal_router.finalize()
else:
raise NotFound("Could not find signal %s" % event)
return await wait_for(signal.ctx.event.wait(), timeout=timeout)
return await wait_for(waiter.wait(), timeout=timeout)

def report_exception(
self, handler: Callable[[Sanic, Exception], Coroutine[Any, Any, None]]
Expand Down
39 changes: 31 additions & 8 deletions sanic/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,33 +512,56 @@
condition = kwargs.pop("condition", {})
condition.update({"__blueprint__": self.name})
kwargs["condition"] = condition
await asyncio.gather(
return await asyncio.gather(
*[app.dispatch(*args, **kwargs) for app in self.apps]
)

def event(self, event: str, timeout: Optional[Union[int, float]] = None):
def event(
self,
event: str,
timeout: Optional[Union[int, float]] = None,
*,
condition: Optional[Dict[str, Any]] = None,
):
"""Wait for a signal event to be dispatched.

Args:
event (str): Name of the signal event.
timeout (Optional[Union[int, float]]): Timeout for the event to be
dispatched.
condition: If provided, method will only return when the signal
is dispatched with the given condition.

Returns:
Awaitable: Awaitable for the event to be dispatched.
"""
events = set()
if condition is None:
condition = {}
condition.update({"__blueprint__": self.name})

waiters = []
for app in self.apps:
signal = app.signal_router.name_index.get(event)
if not signal:
waiter = app.signal_router.get_waiter(
event, condition, exclusive=False
)
if not waiter:
raise NotFound("Could not find signal %s" % event)
events.add(signal.ctx.event)
waiters.append(waiter)

return self._event(waiters, timeout)

return asyncio.wait(
[asyncio.create_task(event.wait()) for event in events],
async def _event(self, waiters, timeout):
done, pending = await asyncio.wait(
[asyncio.create_task(waiter.wait()) for waiter in waiters],
return_when=asyncio.FIRST_COMPLETED,
timeout=timeout,
)
for task in pending:
task.cancel()

Check warning on line 560 in sanic/blueprints.py

View check run for this annotation

Codecov / codecov/patch

sanic/blueprints.py#L560

Added line #L560 was not covered by tests
if not done:
raise TimeoutError()

Check warning on line 562 in sanic/blueprints.py

View check run for this annotation

Codecov / codecov/patch

sanic/blueprints.py#L562

Added line #L562 was not covered by tests
(finished_task,) = done
return finished_task.result()

@staticmethod
def _extract_value(*values):
Expand Down
4 changes: 2 additions & 2 deletions sanic/mixins/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def decorator(handler: SignalHandler):
def add_signal(
self,
handler: Optional[Callable[..., Any]],
event: str,
event: Union[str, Enum],
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> Callable[..., Any]:
Expand All @@ -88,7 +88,7 @@ def add_signal(
"""
if not handler:

async def noop():
async def noop(**context):
...

handler = noop
Expand Down
98 changes: 80 additions & 18 deletions sanic/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import asyncio

from collections import deque
from dataclasses import dataclass
from enum import Enum
from inspect import isawaitable
from typing import Any, Dict, List, Optional, Tuple, Union, cast
Expand Down Expand Up @@ -86,6 +88,39 @@ class Signal(Route):
"""A `Route` that is used to dispatch signals to handlers"""


@dataclass
class SignalWaiter:
"""A record representing a future waiting for a signal"""

signal: Signal
event_definition: str
trigger: str = ""
requirements: Optional[Dict[str, str]] = None
exclusive: bool = True

future: Optional[asyncio.Future] = None

async def wait(self):
"""Block until the signal is next dispatched.

Return the context of the signal dispatch, if any.
"""
loop = asyncio.get_running_loop()
self.future = loop.create_future()
self.signal.ctx.waiters.append(self)
try:
return await self.future
finally:
self.signal.ctx.waiters.remove(self)

def matches(self, event, condition):
return (
(condition is None and not self.exclusive)
or (condition is None and not self.requirements)
or condition == self.requirements
) and (self.trigger or event == self.event_definition)


class SignalGroup(RouteGroup):
"""A `RouteGroup` that is used to dispatch signals to handlers"""

Expand All @@ -104,7 +139,7 @@ def __init__(self) -> None:
self.ctx.loop = None

@staticmethod
def format_event(event: str) -> str:
def format_event(event: Union[str, Enum]) -> str:
"""Ensure event strings in proper format

Args:
Expand All @@ -113,13 +148,15 @@ def format_event(event: str) -> str:
Returns:
str: formatted event string
"""
if isinstance(event, Enum):
event = str(event.value)
if "." not in event:
event = GENERIC_SIGNAL_FORMAT % event
return event

def get( # type: ignore
self,
event: str,
event: Union[str, Enum],
condition: Optional[Dict[str, str]] = None,
):
"""Get the handlers for a signal
Expand Down Expand Up @@ -186,18 +223,20 @@ async def _dispatch(
error_logger.warning(str(e))
return None

events = [signal.ctx.event for signal in group]
for signal_event in events:
signal_event.set()
if context:
params.update(context)
params.pop("__trigger__", None)

signals = group.routes
if not reverse:
signals = signals[::-1]
try:
for signal in signals:
params.pop("__trigger__", None)
for waiter in signal.ctx.waiters:
if waiter.matches(event, condition):
waiter.future.set_result(dict(params))

for signal in signals:
requirements = signal.extra.requirements
if (
(condition is None and signal.ctx.exclusive is False)
Expand All @@ -223,13 +262,10 @@ async def _dispatch(
)
setattr(e, "__dispatched__", True)
raise e
finally:
for signal_event in events:
signal_event.clear()

async def dispatch(
self,
event: str,
event: Union[str, Enum],
*,
context: Optional[Dict[str, Any]] = None,
condition: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -271,15 +307,29 @@ async def dispatch(
await asyncio.sleep(0)
return task

def add( # type: ignore
def get_waiter(
self,
handler: SignalHandler,
event: str,
event: Union[str, Enum],
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> Signal:
event = self.format_event(event)
event_definition = event
):
event_definition = self.format_event(event)
name, trigger, _ = self._get_event_parts(event_definition)
signal = cast(Signal, self.name_index.get(name))
if not signal:
return None

if event_definition.endswith(".*") and not trigger:
trigger = "*"
return SignalWaiter(
signal=signal,
event_definition=event_definition,
trigger=trigger,
requirements=condition,
exclusive=bool(exclusive),
)

def _get_event_parts(self, event: str) -> Tuple[str, str, str]:
parts = self._build_event_parts(event)
if parts[2].startswith("<"):
name = ".".join([*parts[:-1], "*"])
Expand All @@ -291,8 +341,20 @@ def add( # type: ignore
if not trigger:
event = ".".join([*parts[:2], "<__trigger__>"])

return name, trigger, event

def add( # type: ignore
self,
handler: SignalHandler,
event: Union[str, Enum],
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> Signal:
event_definition = self.format_event(event)
name, trigger, event_string = self._get_event_parts(event_definition)

signal = super().add(
event,
event_string,
handler,
name=name,
append=True,
Expand Down Expand Up @@ -326,7 +388,7 @@ def finalize(self, do_compile: bool = True, do_optimize: bool = False):
raise RuntimeError("Cannot finalize signals outside of event loop")

for signal in self.routes:
signal.ctx.event = asyncio.Event()
signal.ctx.waiters = deque()

return super().finalize(do_compile=do_compile, do_optimize=do_optimize)

Expand Down
Loading
Loading