Skip to content

Commit

Permalink
Add method to get a callback before connecting to the stream
Browse files Browse the repository at this point in the history
Signed-off-by: Mathias L. Baumann <mathias.baumann@frequenz.com>
  • Loading branch information
Marenz committed Sep 25, 2024
1 parent aec339d commit ba15b6e
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 17 deletions.
98 changes: 81 additions & 17 deletions src/frequenz/client/dispatch/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime, timedelta
from importlib.resources import files
from pathlib import Path
from typing import Any, AsyncIterator, Awaitable, Iterator, cast
from typing import Any, AsyncIterator, Awaitable, Callable, Iterator, cast

# pylint: disable=no-name-in-module
from frequenz.api.common.v1.pagination.pagination_params_pb2 import PaginationParams
Expand Down Expand Up @@ -48,12 +48,77 @@
DEFAULT_DISPATCH_PORT = 50051


class _DispatchBroadcaster(
GrpcStreamBroadcaster[StreamMicrogridDispatchesResponse, DispatchEvent]
):
"""A broadcaster for dispatch events.
Offers pre-connect callbacks to allow for setup before (re-)connecting.
"""

def __init__(
self, client: "Client", request: StreamMicrogridDispatchesRequest
) -> None:
"""Initialize the broadcaster.
Args:
client: The client instance.
request: The stream request.
"""
super().__init__(
stream_name="StreamMicrogridDispatches",
stream_method=self._dispatch_stream_method,
transform=DispatchEvent.from_protobuf,
)
self._pre_connect_cbs: list[Callable[[], bool]] = []
self._request = request
self._client = client

def add_pre_connect_cb(self, cb: Callable[[], bool]) -> None:
"""Add a pre-connect callback.
The callback will be called before connecting to the stream. If the callback
returns False, it will be removed from the list of callbacks.
Args:
cb: The callback to add.
"""
self._pre_connect_cbs.append(cb)

def remove_pre_connect_cb(self, cb: Callable[[], bool]) -> None:
"""Remove a pre-connect callback.
Args:
cb: The callback to remove.
"""
self._pre_connect_cbs.remove(cb)

def _dispatch_stream_method(
self,
) -> AsyncIterator[StreamMicrogridDispatchesResponse]:
# Collect callbacks that return False when called
to_remove = []
for cb in self._pre_connect_cbs:
if not cb(): # Call the callback function and check its return value
to_remove.append(cb) # Mark for removal if it returns False

# Remove callbacks that returned False
for cb in to_remove:
self.remove_pre_connect_cb(cb)

return cast(
AsyncIterator[StreamMicrogridDispatchesResponse],
self._client.stub.StreamMicrogridDispatches(
self._request,
metadata=self._client._metadata, # pylint: disable=protected-access
),
)


class Client(BaseApiClient[dispatch_pb2_grpc.MicrogridDispatchServiceStub]):
"""Dispatch API client."""

streams: dict[
int, GrpcStreamBroadcaster[StreamMicrogridDispatchesResponse, DispatchEvent]
] = {}
streams: dict[int, _DispatchBroadcaster] = {}
"""A dictionary of streamers, keyed by microgrid_id."""

def __init__(
Expand Down Expand Up @@ -180,7 +245,9 @@ def to_interval(
else:
break

def stream(self, microgrid_id: int) -> channels.Receiver[DispatchEvent]:
def stream(
self, microgrid_id: int, pre_connect_cb: Callable[[], bool] | None = None
) -> channels.Receiver[DispatchEvent]:
"""Receive a stream of dispatch events.
This function returns a receiver channel that can be used to receive
Expand All @@ -197,31 +264,28 @@ def stream(self, microgrid_id: int) -> channels.Receiver[DispatchEvent]:
Args:
microgrid_id: The microgrid_id to receive dispatches for.
pre_connect_cb: An optional callback to be called before connecting.
If the callback returns False, it will be removed from the list of
callbacks.
Returns:
A receiver channel to receive the stream of dispatch events.
"""
return self._get_stream(microgrid_id).new_receiver()
return self._get_stream(microgrid_id, pre_connect_cb).new_receiver()

def _get_stream(
self, microgrid_id: int
self, microgrid_id: int, pre_connect_cb: Callable[[], bool] | None = None
) -> GrpcStreamBroadcaster[StreamMicrogridDispatchesResponse, DispatchEvent]:
"""Get an instance to the streaming helper."""
broadcaster = self.streams.get(microgrid_id)
if broadcaster is None:
request = StreamMicrogridDispatchesRequest(microgrid_id=microgrid_id)
broadcaster = GrpcStreamBroadcaster(
stream_name="StreamMicrogridDispatches",
stream_method=lambda: cast(
AsyncIterator[StreamMicrogridDispatchesResponse],
self.stub.StreamMicrogridDispatches(
request, metadata=self._metadata
),
),
transform=DispatchEvent.from_protobuf,
)
broadcaster = _DispatchBroadcaster(self, request)
self.streams[microgrid_id] = broadcaster

if pre_connect_cb:
broadcaster.add_pre_connect_cb(pre_connect_cb)

return broadcaster

async def create(
Expand Down
25 changes: 25 additions & 0 deletions tests/test_dispatch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,28 @@ async def expect(dispatch: Dispatch, event: Event) -> None:

# Expect the first dispatch deletion
await expect(dispatches[0], Event.DELETED)


async def test_dispatch_stream_pre_connect(
client: FakeClient, sample: Dispatch
) -> None:
"""Test dispatching a stream of dispatches without connecting to the stream."""
microgrid_id = random.randint(1, 100)
dispatches = [sample, sample, sample]

pre_connect_called = False

def pre_connect() -> bool:
nonlocal pre_connect_called
pre_connect_called = True
return True

stream = client.stream(microgrid_id, pre_connect)

await asyncio.sleep(0.1)

dispatches[0] = await client.create(**to_create_params(microgrid_id, dispatches[0]))

msg = await stream.receive()

assert pre_connect_called

0 comments on commit ba15b6e

Please sign in to comment.