From 4d0a0a357063950e2ec0b8cb34b7c846136f7dc9 Mon Sep 17 00:00:00 2001 From: Josh Bartlett Date: Sat, 23 Sep 2023 22:59:44 +1000 Subject: [PATCH 1/4] Update app.event() to return context dict (#2088), and only fire for the specific requested event (#2826). Also added support for passing conditions= and exclusive= arguments to app.event(). --- sanic/app.py | 23 +++- sanic/blueprints.py | 37 ++++-- sanic/mixins/signals.py | 2 +- sanic/signals.py | 92 ++++++++++--- tests/test_signals.py | 282 ++++++++++++++++++++++++++++++++++------ 5 files changed, 364 insertions(+), 72 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 452954cadb..367aa962ee 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -17,6 +17,7 @@ from asyncio.futures import Future from collections import defaultdict, deque from contextlib import contextmanager, suppress +from enum import Enum from functools import partial, wraps from inspect import isawaitable from os import environ @@ -663,7 +664,12 @@ async def handle_registration(request): ) async def event( - self, event: str, timeout: Optional[Union[int, float]] = None + self, + event: Union[str, Enum], + timeout: Optional[Union[int, float]] = None, + *, + condition: Optional[Dict[str, Any]] = None, + exclusive: bool = True, ) -> None: """Wait for a specific event to be triggered. @@ -686,13 +692,18 @@ async def event( timeout (Optional[Union[int, float]]): An optional timeout value in seconds. If provided, the wait will be terminated if the timeout is reached. Defaults to `None`, meaning no timeout. + condition: If provided, method will only return when the signal + is dispatched with the given condition. + exclusive: When true (default), the signal can only be dispatched + when the condition has been met. When ``False``, the signal can + be dispatched either with or without it. Raises: NotFound: If the event is not found and auto-registration of events is not enabled. Returns: - None + The context dict of the dispatched signal. Examples: ```python @@ -708,16 +719,16 @@ async def after_server_start(app, loop): ``` """ - signal = self.signal_router.name_index.get(event) - if not signal: + waiter = self.signal_router.get_waiter(event, condition, exclusive) + if not waiter: if self.config.EVENT_AUTOREGISTER: self.signal_router.reset() self.add_signal(None, event) - signal = self.signal_router.name_index[event] + waiter = self.signal_router.get_waiter(event, condition, exclusive) self.signal_router.finalize() else: raise NotFound("Could not find signal %s" % event) - return await wait_for(signal.ctx.event.wait(), timeout=timeout) + return await wait_for(waiter.wait(), timeout=timeout) def report_exception( self, handler: Callable[[Sanic, Exception], Coroutine[Any, Any, None]] diff --git a/sanic/blueprints.py b/sanic/blueprints.py index 28199e9ecc..4aabd2db55 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -512,33 +512,54 @@ async def dispatch(self, *args, **kwargs): condition = kwargs.pop("condition", {}) condition.update({"__blueprint__": self.name}) kwargs["condition"] = condition - await asyncio.gather( + return await asyncio.gather( *[app.dispatch(*args, **kwargs) for app in self.apps] ) - def event(self, event: str, timeout: Optional[Union[int, float]] = None): + def event( + self, + event: str, + timeout: Optional[Union[int, float]] = None, + *, + condition: Optional[Dict[str, Any]] = None, + ): """Wait for a signal event to be dispatched. Args: event (str): Name of the signal event. timeout (Optional[Union[int, float]]): Timeout for the event to be dispatched. + condition: If provided, method will only return when the signal + is dispatched with the given condition. Returns: Awaitable: Awaitable for the event to be dispatched. """ - events = set() + if condition is None: + condition = {} + condition.update({"__blueprint__": self.name}) + + waiters = [] for app in self.apps: - signal = app.signal_router.name_index.get(event) - if not signal: + waiter = app.signal_router.get_waiter(event, condition, exclusive=False) + if not waiter: raise NotFound("Could not find signal %s" % event) - events.add(signal.ctx.event) + waiters.append(waiter) + + return self._event(waiters, timeout) - return asyncio.wait( - [asyncio.create_task(event.wait()) for event in events], + async def _event(self, waiters, timeout): + done, pending = await asyncio.wait( + [asyncio.create_task(waiter.wait()) for waiter in waiters], return_when=asyncio.FIRST_COMPLETED, timeout=timeout, ) + for task in pending: + task.cancel() + if not done: + raise TimeoutError() + finished_task, = done + return finished_task.result() @staticmethod def _extract_value(*values): diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index f1419dc4c9..7b23d73a67 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -88,7 +88,7 @@ def add_signal( """ if not handler: - async def noop(): + async def noop(**context): ... handler = noop diff --git a/sanic/signals.py b/sanic/signals.py index fe252c12be..513c43e049 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -2,6 +2,8 @@ import asyncio +from collections import deque +from dataclasses import dataclass from enum import Enum from inspect import isawaitable from typing import Any, Dict, List, Optional, Tuple, Union, cast @@ -76,6 +78,38 @@ class Signal(Route): """A `Route` that is used to dispatch signals to handlers""" +@dataclass +class SignalWaiter: + """A record representing a future waiting for a signal""" + + signal: Signal + event_definition: str + trigger: str = "" + requirements: Optional[Dict[str, str]] = None + exclusive: bool = True + + future: Optional[asyncio.Future] = None + + async def wait(self): + """Block until the signal is next dispatched. + + Return the context of the signal dispatch, if any. + """ + loop = asyncio.get_running_loop() + self.future = loop.create_future() + self.signal.ctx.waiters.append(self) + try: + return await self.future + finally: + self.signal.ctx.waiters.remove(self) + + def matches(self, event, condition): + return ((condition is None and not self.exclusive) + or (condition is None and not self.requirements) + or condition == self.requirements + ) and (self.trigger or event == self.event_definition) + + class SignalGroup(RouteGroup): """A `RouteGroup` that is used to dispatch signals to handlers""" @@ -160,18 +194,20 @@ async def _dispatch( error_logger.warning(str(e)) return None - events = [signal.ctx.event for signal in group] - for signal_event in events: - signal_event.set() if context: params.update(context) + params.pop("__trigger__", None) signals = group.routes if not reverse: signals = signals[::-1] try: for signal in signals: - params.pop("__trigger__", None) + for waiter in signal.ctx.waiters: + if waiter.matches(event, condition): + waiter.future.set_result(dict(params)) + + for signal in signals: requirements = signal.extra.requirements if ( (condition is None and signal.ctx.exclusive is False) @@ -197,9 +233,6 @@ async def _dispatch( ) setattr(e, "__dispatched__", True) raise e - finally: - for signal_event in events: - signal_event.clear() async def dispatch( self, @@ -244,14 +277,29 @@ async def dispatch( await asyncio.sleep(0) return task - def add( # type: ignore - self, - handler: SignalHandler, - event: str, - condition: Optional[Dict[str, Any]] = None, - exclusive: bool = True, - ) -> Signal: - event_definition = event + def get_waiter( + self, + event: Union[str, Enum], + condition: Optional[Dict[str, Any]], + exclusive: bool, + ): + event_definition = str(event.value) if isinstance(event, Enum) else event + name, trigger, _ = self._get_event_parts(event_definition) + signal = cast(Signal, self.name_index.get(name)) + if not signal: + return None + + if event_definition.endswith(".*") and not trigger: + trigger = "*" + return SignalWaiter( + signal=signal, + event_definition=event_definition, + trigger=trigger, + requirements=condition, + exclusive=bool(exclusive), + ) + + def _get_event_parts(self, event): parts = self._build_event_parts(event) if parts[2].startswith("<"): name = ".".join([*parts[:-1], "*"]) @@ -263,6 +311,18 @@ def add( # type: ignore if not trigger: event = ".".join([*parts[:2], "<__trigger__>"]) + return name, trigger, event + + def add( # type: ignore + self, + handler: SignalHandler, + event: str, + condition: Optional[Dict[str, Any]] = None, + exclusive: bool = True, + ) -> Signal: + event_definition = event + name, trigger, event = self._get_event_parts(event) + signal = super().add( event, handler, @@ -298,7 +358,7 @@ def finalize(self, do_compile: bool = True, do_optimize: bool = False): raise RuntimeError("Cannot finalize signals outside of event loop") for signal in self.routes: - signal.ctx.event = asyncio.Event() + signal.ctx.waiters = deque() return super().finalize(do_compile=do_compile, do_optimize=do_optimize) diff --git a/tests/test_signals.py b/tests/test_signals.py index 8a564a98bc..b451b9cef0 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -78,6 +78,49 @@ def handler(): ... +@pytest.mark.asyncio +async def test_dispatch_signal_triggers_event(app): + + @app.signal("foo.bar.baz") + def sync_signal(*args): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.baz")) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + + assert event_task.done() + event_task.result() # Will raise if there was an exception + + +@pytest.mark.asyncio +async def test_dispatch_signal_triggers_correct_event(app): + # Check for https://github.com/sanic-org/sanic/issues/2826 + + @app.signal("foo.bar.baz") + def sync_signal(*args): + pass + + @app.signal("foo.bar.spam") + def sync_signal(*args): + pass + + app.signal_router.finalize() + + baz_task = asyncio.create_task(app.event("foo.bar.baz")) + spam_task = asyncio.create_task(app.event("foo.bar.spam")) + + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + + assert baz_task.done() + assert not spam_task.done() + baz_task.result() + spam_task.cancel() + + @pytest.mark.asyncio async def test_dispatch_signal_with_enum_event(app): counter = 0 @@ -97,6 +140,26 @@ def sync_signal(*_): assert counter == 1 +@pytest.mark.asyncio +async def test_dispatch_signal_with_enum_event_to_event(app): + + class FooEnum(Enum): + FOO_BAR_BAZ = "foo.bar.baz" + + @app.signal(FooEnum.FOO_BAR_BAZ) + def sync_signal(*args): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event(FooEnum.FOO_BAR_BAZ)) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + + assert event_task.done() + event_task.result() # Will raise if there was an exception + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_multiple_handlers(app): counter = 0 @@ -121,22 +184,45 @@ async def async_signal(*_): @pytest.mark.asyncio -async def test_dispatch_signal_triggers_triggers_event(app): - counter = 0 +async def test_dispatch_signal_triggers_multiple_events(app): @app.signal("foo.bar.baz") - def sync_signal(*args): - nonlocal app - nonlocal counter - group, *_ = app.signal_router.get("foo.bar.baz") - for signal in group: - counter += signal.ctx.event.is_set() + def sync_signal(*_): + pass app.signal_router.finalize() + event_task1 = asyncio.create_task(app.event("foo.bar.baz")) + event_task2 = asyncio.create_task(app.event("foo.bar.baz")) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) - assert counter == 1 + assert event_task1.done() + assert event_task2.done() + event_task1.result() # Will raise if there was an exception + event_task2.result() # Will raise if there was an exception + + +@pytest.mark.asyncio +async def test_dispatch_signal_with_multiple_handlers_triggers_event_once(app): + + @app.signal("foo.bar.baz") + def sync_signal(*_): + pass + + @app.signal("foo.bar.baz") + async def async_signal(*_): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.baz")) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + + assert event_task.done() + event_task.result() # Will raise if there was an exception @pytest.mark.asyncio @@ -155,6 +241,40 @@ def sync_signal(baz): assert counter == 9 +@pytest.mark.asyncio +async def test_dispatch_signal_triggers_parameterized_dynamic_route_event(app): + + @app.signal("foo.bar.") + def sync_signal(baz): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.")) + await app.dispatch("foo.bar.9") + await asyncio.sleep(0) + + assert event_task.done() + event_task.result() # Will raise if there was an exception + + +@pytest.mark.asyncio +async def test_dispatch_signal_triggers_starred_dynamic_route_event(app): + + @app.signal("foo.bar.") + def sync_signal(baz): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.*")) + await app.dispatch("foo.bar.9") + await asyncio.sleep(0) + + assert event_task.done() + event_task.result() # Will raise if there was an exception + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_with_requirements(app): counter = 0 @@ -172,6 +292,26 @@ def sync_signal(*_): assert counter == 1 +@pytest.mark.asyncio +async def test_dispatch_signal_to_event_with_requirements(app): + + @app.signal("foo.bar.baz") + def sync_signal(*_): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"})) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + assert not event_task.done() + + await app.dispatch("foo.bar.baz", condition={"one": "two"}) + await asyncio.sleep(0) + assert event_task.done() + event_task.result() # Will raise if there was an exception + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_with_requirements_exclusive(app): counter = 0 @@ -189,6 +329,28 @@ def sync_signal(*_): assert counter == 2 +@pytest.mark.asyncio +async def test_dispatch_signal_to_event_with_requirements_exclusive(app): + + @app.signal("foo.bar.baz") + def sync_signal(*_): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"}, exclusive=False)) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + assert event_task.done() + event_task.result() # Will raise if there was an exception + + event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"}, exclusive=False)) + await app.dispatch("foo.bar.baz", condition={"one": "two"}) + await asyncio.sleep(0) + assert event_task.done() + event_task.result() # Will raise if there was an exception + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_with_context(app): counter = 0 @@ -204,6 +366,22 @@ def sync_signal(amount): assert counter == 9 +@pytest.mark.asyncio +async def test_dispatch_signal_to_event_with_context(app): + + @app.signal("foo.bar.baz") + def sync_signal(**context): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.baz")) + await app.dispatch("foo.bar.baz", context={"amount": 9}) + await asyncio.sleep(0) + assert event_task.done() + assert event_task.result()['amount'] == 9 + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_with_context_fail(app): counter = 0 @@ -219,6 +397,22 @@ def sync_signal(amount): await app.dispatch("foo.bar.baz", {"amount": 9}) +@pytest.mark.asyncio +async def test_dispatch_signal_to_dynamic_route_event(app): + + @app.signal("foo.bar.") + def sync_signal(**context): + pass + + app.signal_router.finalize() + + event_task = asyncio.create_task(app.event("foo.bar.")) + await app.dispatch("foo.bar.baz") + await asyncio.sleep(0) + assert event_task.done() + assert event_task.result()['something'] == "baz" + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_on_bp(app): bp = Blueprint("bp") @@ -267,61 +461,67 @@ def bp_signal(): @pytest.mark.asyncio -async def test_dispatch_signal_triggers_event(app): - app_counter = 0 +async def test_dispatch_signal_triggers_event_on_bp(app): + bp = Blueprint("bp") @app.signal("foo.bar.baz") def app_signal(): ... - async def do_wait(): - nonlocal app_counter - await app.event("foo.bar.baz") - app_counter += 1 + @bp.signal("foo.bar.baz") + def bp_signal(): + ... + app.blueprint(bp) app.signal_router.finalize() + app_task = asyncio.create_task(app.event("foo.bar.baz")) + bp_task = asyncio.create_task(bp.event("foo.bar.baz")) + await asyncio.sleep(0) await app.dispatch("foo.bar.baz") - waiter = app.event("foo.bar.baz") - assert isawaitable(waiter) - fut = asyncio.ensure_future(do_wait()) - await app.dispatch("foo.bar.baz") - await fut + # Allow a few event loop iterations for tasks to finish + for _ in range(5): + await asyncio.sleep(0) - assert app_counter == 1 + assert app_task.done() + assert bp_task.done() + app_task.result() + bp_task.result() + + app_task = asyncio.create_task(app.event("foo.bar.baz")) + bp_task = asyncio.create_task(bp.event("foo.bar.baz")) + await asyncio.sleep(0) + await bp.dispatch("foo.bar.baz") + + # Allow a few event loop iterations for tasks to finish + for _ in range(5): + await asyncio.sleep(0) + + assert bp_task.done() + assert not app_task.done() + bp_task.result() + app_task.cancel() @pytest.mark.asyncio -async def test_dispatch_signal_triggers_event_on_bp(app): +async def test_dispatch_signal_triggers_event_on_bp_with_context(app): bp = Blueprint("bp") - bp_counter = 0 @bp.signal("foo.bar.baz") def bp_signal(): ... - async def do_wait(): - nonlocal bp_counter - await bp.event("foo.bar.baz") - bp_counter += 1 - app.blueprint(bp) app.signal_router.finalize() - signal_group, *_ = app.signal_router.get( - "foo.bar.baz", condition={"blueprint": "bp"} - ) - await bp.dispatch("foo.bar.baz") - waiter = bp.event("foo.bar.baz") - assert isawaitable(waiter) - - fut = do_wait() - for signal in signal_group: - signal.ctx.event.set() - await asyncio.gather(fut) - - assert bp_counter == 1 + event_task = asyncio.create_task(bp.event("foo.bar.baz")) + await asyncio.sleep(0) + await app.dispatch("foo.bar.baz", context={"amount": 9}) + for _ in range(5): + await asyncio.sleep(0) + assert event_task.done() + assert event_task.result()['amount'] == 9 def test_bad_finalize(app): From 5fb7eaaeabdf5d23ee3eb2530d1062ec27d6928e Mon Sep 17 00:00:00 2001 From: Josh Bartlett Date: Mon, 25 Sep 2023 11:02:04 +1000 Subject: [PATCH 2/4] Prettified --- sanic/app.py | 4 +++- sanic/blueprints.py | 16 ++++++++------- sanic/signals.py | 19 +++++++++-------- tests/test_signals.py | 48 ++++++++++++++++++++----------------------- 4 files changed, 45 insertions(+), 42 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 367aa962ee..6a170d878f 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -724,7 +724,9 @@ async def after_server_start(app, loop): if self.config.EVENT_AUTOREGISTER: self.signal_router.reset() self.add_signal(None, event) - waiter = self.signal_router.get_waiter(event, condition, exclusive) + waiter = self.signal_router.get_waiter( + event, condition, exclusive + ) self.signal_router.finalize() else: raise NotFound("Could not find signal %s" % event) diff --git a/sanic/blueprints.py b/sanic/blueprints.py index 4aabd2db55..bf0635b3b7 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -517,11 +517,11 @@ async def dispatch(self, *args, **kwargs): ) def event( - self, - event: str, - timeout: Optional[Union[int, float]] = None, - *, - condition: Optional[Dict[str, Any]] = None, + self, + event: str, + timeout: Optional[Union[int, float]] = None, + *, + condition: Optional[Dict[str, Any]] = None, ): """Wait for a signal event to be dispatched. @@ -541,7 +541,9 @@ def event( waiters = [] for app in self.apps: - waiter = app.signal_router.get_waiter(event, condition, exclusive=False) + waiter = app.signal_router.get_waiter( + event, condition, exclusive=False + ) if not waiter: raise NotFound("Could not find signal %s" % event) waiters.append(waiter) @@ -558,7 +560,7 @@ async def _event(self, waiters, timeout): task.cancel() if not done: raise TimeoutError() - finished_task, = done + (finished_task,) = done return finished_task.result() @staticmethod diff --git a/sanic/signals.py b/sanic/signals.py index 513c43e049..a247da6319 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -104,9 +104,10 @@ async def wait(self): self.signal.ctx.waiters.remove(self) def matches(self, event, condition): - return ((condition is None and not self.exclusive) - or (condition is None and not self.requirements) - or condition == self.requirements + return ( + (condition is None and not self.exclusive) + or (condition is None and not self.requirements) + or condition == self.requirements ) and (self.trigger or event == self.event_definition) @@ -278,12 +279,14 @@ async def dispatch( return task def get_waiter( - self, - event: Union[str, Enum], - condition: Optional[Dict[str, Any]], - exclusive: bool, + self, + event: Union[str, Enum], + condition: Optional[Dict[str, Any]], + exclusive: bool, ): - event_definition = str(event.value) if isinstance(event, Enum) else event + event_definition = ( + str(event.value) if isinstance(event, Enum) else event + ) name, trigger, _ = self._get_event_parts(event_definition) signal = cast(Signal, self.name_index.get(name)) if not signal: diff --git a/tests/test_signals.py b/tests/test_signals.py index b451b9cef0..6073ff9bc4 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -80,7 +80,6 @@ def handler(): @pytest.mark.asyncio async def test_dispatch_signal_triggers_event(app): - @app.signal("foo.bar.baz") def sync_signal(*args): pass @@ -92,7 +91,7 @@ def sync_signal(*args): await asyncio.sleep(0) assert event_task.done() - event_task.result() # Will raise if there was an exception + event_task.result() # Will raise if there was an exception @pytest.mark.asyncio @@ -142,7 +141,6 @@ def sync_signal(*_): @pytest.mark.asyncio async def test_dispatch_signal_with_enum_event_to_event(app): - class FooEnum(Enum): FOO_BAR_BAZ = "foo.bar.baz" @@ -157,7 +155,7 @@ def sync_signal(*args): await asyncio.sleep(0) assert event_task.done() - event_task.result() # Will raise if there was an exception + event_task.result() # Will raise if there was an exception @pytest.mark.asyncio @@ -185,7 +183,6 @@ async def async_signal(*_): @pytest.mark.asyncio async def test_dispatch_signal_triggers_multiple_events(app): - @app.signal("foo.bar.baz") def sync_signal(*_): pass @@ -200,13 +197,12 @@ def sync_signal(*_): assert event_task1.done() assert event_task2.done() - event_task1.result() # Will raise if there was an exception - event_task2.result() # Will raise if there was an exception + event_task1.result() # Will raise if there was an exception + event_task2.result() # Will raise if there was an exception @pytest.mark.asyncio async def test_dispatch_signal_with_multiple_handlers_triggers_event_once(app): - @app.signal("foo.bar.baz") def sync_signal(*_): pass @@ -222,7 +218,7 @@ async def async_signal(*_): await asyncio.sleep(0) assert event_task.done() - event_task.result() # Will raise if there was an exception + event_task.result() # Will raise if there was an exception @pytest.mark.asyncio @@ -243,7 +239,6 @@ def sync_signal(baz): @pytest.mark.asyncio async def test_dispatch_signal_triggers_parameterized_dynamic_route_event(app): - @app.signal("foo.bar.") def sync_signal(baz): pass @@ -255,12 +250,11 @@ def sync_signal(baz): await asyncio.sleep(0) assert event_task.done() - event_task.result() # Will raise if there was an exception + event_task.result() # Will raise if there was an exception @pytest.mark.asyncio async def test_dispatch_signal_triggers_starred_dynamic_route_event(app): - @app.signal("foo.bar.") def sync_signal(baz): pass @@ -272,7 +266,7 @@ def sync_signal(baz): await asyncio.sleep(0) assert event_task.done() - event_task.result() # Will raise if there was an exception + event_task.result() # Will raise if there was an exception @pytest.mark.asyncio @@ -294,14 +288,15 @@ def sync_signal(*_): @pytest.mark.asyncio async def test_dispatch_signal_to_event_with_requirements(app): - @app.signal("foo.bar.baz") def sync_signal(*_): pass app.signal_router.finalize() - event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"})) + event_task = asyncio.create_task( + app.event("foo.bar.baz", condition={"one": "two"}) + ) await app.dispatch("foo.bar.baz") await asyncio.sleep(0) assert not event_task.done() @@ -309,7 +304,7 @@ def sync_signal(*_): await app.dispatch("foo.bar.baz", condition={"one": "two"}) await asyncio.sleep(0) assert event_task.done() - event_task.result() # Will raise if there was an exception + event_task.result() # Will raise if there was an exception @pytest.mark.asyncio @@ -331,24 +326,27 @@ def sync_signal(*_): @pytest.mark.asyncio async def test_dispatch_signal_to_event_with_requirements_exclusive(app): - @app.signal("foo.bar.baz") def sync_signal(*_): pass app.signal_router.finalize() - event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"}, exclusive=False)) + event_task = asyncio.create_task( + app.event("foo.bar.baz", condition={"one": "two"}, exclusive=False) + ) await app.dispatch("foo.bar.baz") await asyncio.sleep(0) assert event_task.done() - event_task.result() # Will raise if there was an exception + event_task.result() # Will raise if there was an exception - event_task = asyncio.create_task(app.event("foo.bar.baz", condition={"one": "two"}, exclusive=False)) + event_task = asyncio.create_task( + app.event("foo.bar.baz", condition={"one": "two"}, exclusive=False) + ) await app.dispatch("foo.bar.baz", condition={"one": "two"}) await asyncio.sleep(0) assert event_task.done() - event_task.result() # Will raise if there was an exception + event_task.result() # Will raise if there was an exception @pytest.mark.asyncio @@ -368,7 +366,6 @@ def sync_signal(amount): @pytest.mark.asyncio async def test_dispatch_signal_to_event_with_context(app): - @app.signal("foo.bar.baz") def sync_signal(**context): pass @@ -379,7 +376,7 @@ def sync_signal(**context): await app.dispatch("foo.bar.baz", context={"amount": 9}) await asyncio.sleep(0) assert event_task.done() - assert event_task.result()['amount'] == 9 + assert event_task.result()["amount"] == 9 @pytest.mark.asyncio @@ -399,7 +396,6 @@ def sync_signal(amount): @pytest.mark.asyncio async def test_dispatch_signal_to_dynamic_route_event(app): - @app.signal("foo.bar.") def sync_signal(**context): pass @@ -410,7 +406,7 @@ def sync_signal(**context): await app.dispatch("foo.bar.baz") await asyncio.sleep(0) assert event_task.done() - assert event_task.result()['something'] == "baz" + assert event_task.result()["something"] == "baz" @pytest.mark.asyncio @@ -521,7 +517,7 @@ def bp_signal(): for _ in range(5): await asyncio.sleep(0) assert event_task.done() - assert event_task.result()['amount'] == 9 + assert event_task.result()["amount"] == 9 def test_bad_finalize(app): From c35497349eff5302e6a3d427eef028851dac08e3 Mon Sep 17 00:00:00 2001 From: Josh Bartlett Date: Mon, 25 Sep 2023 11:28:45 +1000 Subject: [PATCH 3/4] Fixed type checking error --- sanic/mixins/signals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index 7b23d73a67..b6be461108 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -65,7 +65,7 @@ def decorator(handler: SignalHandler): def add_signal( self, handler: Optional[Callable[..., Any]], - event: str, + event: Union[str, Enum], condition: Optional[Dict[str, Any]] = None, exclusive: bool = True, ) -> Callable[..., Any]: From 9469f36ac1a373ab4406082a7a88e24ba2aa63d7 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Thu, 7 Dec 2023 11:49:21 +0200 Subject: [PATCH 4/4] make pretty --- sanic/signals.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sanic/signals.py b/sanic/signals.py index 1786ca129b..799d1e9143 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -394,7 +394,11 @@ def finalize(self, do_compile: bool = True, do_optimize: bool = False): def _build_event_parts(self, event: str) -> Tuple[str, str, str]: parts = path_to_parts(event, self.delimiter) - if len(parts) != 3 or parts[0].startswith("<") or parts[1].startswith("<"): + if ( + len(parts) != 3 + or parts[0].startswith("<") + or parts[1].startswith("<") + ): raise InvalidSignal("Invalid signal event: %s" % event) if ( @@ -402,7 +406,9 @@ def _build_event_parts(self, event: str) -> Tuple[str, str, str]: and event not in RESERVED_NAMESPACES[parts[0]] and not (parts[2].startswith("<") and parts[2].endswith(">")) ): - raise InvalidSignal("Cannot declare reserved signal event: %s" % event) + raise InvalidSignal( + "Cannot declare reserved signal event: %s" % event + ) return parts def _clean_trigger(self, trigger: str) -> str: