Skip to content

Commit

Permalink
Add support in FakeService for stream()
Browse files Browse the repository at this point in the history
Also adds a test for it.

Signed-off-by: Mathias L. Baumann <mathias.baumann@frequenz.com>
  • Loading branch information
Marenz committed Sep 24, 2024
1 parent 8d012f6 commit 8e6719b
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 10 deletions.
1 change: 1 addition & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
## New Features

* Added support for duration=None when creating a dispatch.
* The `FakeService` now supports the `stream()` method.

## Bug Fixes

Expand Down
91 changes: 82 additions & 9 deletions src/frequenz/client/dispatch/test/_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
Useful for testing.
"""
import dataclasses
import logging
from dataclasses import dataclass, replace
from datetime import datetime, timezone
from typing import AsyncIterator

import grpc
import grpc.aio
Expand All @@ -24,16 +26,19 @@
GetMicrogridDispatchResponse,
ListMicrogridDispatchesRequest,
ListMicrogridDispatchesResponse,
StreamMicrogridDispatchesRequest,
StreamMicrogridDispatchesResponse,
UpdateMicrogridDispatchRequest,
UpdateMicrogridDispatchResponse,
)
from frequenz.channels import Broadcast
from google.protobuf.empty_pb2 import Empty

# pylint: enable=no-name-in-module
from frequenz.client.base.conversion import to_datetime as _to_dt

from .._internal_types import DispatchCreateRequest
from ..types import Dispatch
from ..types import Dispatch, DispatchEvent, Event

ALL_KEY = "all"
"""Key that has access to all resources in the FakeService."""
Expand All @@ -46,12 +51,25 @@
class FakeService:
"""Dispatch mock service for testing."""

_stream_channel: Broadcast[tuple[int, DispatchEvent]]
"""Channel for dispatch events."""

_stream_sender: Sender[tuple[int, DispatchEvent]]
"""Sender for dispatch events."""

dispatches: dict[int, list[Dispatch]] = dataclasses.field(default_factory=dict)
"""List of dispatches per microgrid."""

_last_id: int = 0
"""Last used dispatch id."""

def __init__(self) -> None:
    """Create a new fake dispatch service instance.

    Sets up the broadcast channel used to fan events out to stream
    subscribers and starts with an empty per-microgrid dispatch store.
    """
    super().__init__()
    channel: Broadcast[tuple[int, DispatchEvent]] = Broadcast(
        name="dispatch-stream"
    )
    self._stream_channel = channel
    self._stream_sender = channel.new_sender()
    self.dispatches = {}
def _check_access(self, metadata: grpc.aio.Metadata) -> None:
"""Check if the access key is valid.
Expand Down Expand Up @@ -120,6 +138,35 @@ async def ListMicrogridDispatches(
),
)

async def StreamMicrogridDispatches(
    self, request: StreamMicrogridDispatchesRequest, metadata: grpc.aio.Metadata
) -> AsyncIterator[StreamMicrogridDispatchesResponse]:
    """Stream change events for dispatches of one microgrid.

    Args:
        request: The stream request, selecting the microgrid to watch.
        metadata: The call metadata carrying the access key.

    Yields:
        One response per create/update/delete of a dispatch that belongs
        to the requested microgrid.
    """
    self._check_access(metadata)

    # Subscribe before iterating so no event published afterwards is missed.
    receiver = self._stream_channel.new_receiver()

    async for microgrid_id, dispatch_event in receiver:
        # All microgrids share one broadcast channel, so events for other
        # microgrids are filtered out on the consuming side.
        if microgrid_id != request.microgrid_id:
            continue

        logging.info("Sending event %s", dispatch_event)

        yield StreamMicrogridDispatchesResponse(
            event=dispatch_event.event.value,
            dispatch=dispatch_event.dispatch.to_protobuf(),
        )

# pylint: disable=too-many-branches
@staticmethod
def _filter_dispatch(
Expand Down Expand Up @@ -179,6 +226,13 @@ async def CreateMicrogridDispatch(
# implicitly create the list if it doesn't exist
self.dispatches.setdefault(request.microgrid_id, []).append(new_dispatch)

await self._stream_sender.send(
(
request.microgrid_id,
DispatchEvent(dispatch=new_dispatch, event=Event.CREATED),
)
)

return CreateMicrogridDispatchResponse(dispatch=new_dispatch.to_protobuf())

async def UpdateMicrogridDispatch(
Expand Down Expand Up @@ -241,9 +295,9 @@ async def UpdateMicrogridDispatch(
| "bymonthdays"
| "bymonths"
):
getattr(pb_dispatch.data.recurrence, split_path[1])[:] = (
getattr(request.update.recurrence, split_path[1])[:]
)
getattr(pb_dispatch.data.recurrence, split_path[1])[
:
] = getattr(request.update.recurrence, split_path[1])[:]

dispatch = Dispatch.from_protobuf(pb_dispatch)
dispatch = replace(
Expand All @@ -253,6 +307,13 @@ async def UpdateMicrogridDispatch(

grid_dispatches[index] = dispatch

await self._stream_sender.send(
(
request.microgrid_id,
DispatchEvent(dispatch=dispatch, event=Event.UPDATED),
)
)

return UpdateMicrogridDispatchResponse(dispatch=dispatch.to_protobuf())

async def GetMicrogridDispatch(
Expand Down Expand Up @@ -285,19 +346,31 @@ async def DeleteMicrogridDispatch(
"""Delete a given dispatch."""
self._check_access(metadata)
grid_dispatches = self.dispatches.get(request.microgrid_id, [])
num_dispatches = len(grid_dispatches)
self.dispatches[request.microgrid_id] = [
d for d in grid_dispatches if d.id != request.dispatch_id
]

if len(self.dispatches[request.microgrid_id]) == num_dispatches:
dispatch_to_delete = next(
(d for d in grid_dispatches if d.id == request.dispatch_id), None
)

if dispatch_to_delete is None:
error = grpc.RpcError()
# pylint: disable=protected-access
error._code = grpc.StatusCode.NOT_FOUND # type: ignore
error._details = "Dispatch not found" # type: ignore
# pylint: enable=protected-access
raise error

grid_dispatches.remove(dispatch_to_delete)

await self._stream_sender.send(
(
request.microgrid_id,
DispatchEvent(
dispatch=dispatch_to_delete,
event=Event.DELETED,
),
)
)

return Empty()

# pylint: enable=invalid-name
Expand Down
50 changes: 49 additions & 1 deletion tests/test_dispatch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Tests for the frequenz.client.dispatch package."""

import asyncio
import random
from dataclasses import replace
from datetime import timedelta
Expand All @@ -13,7 +14,7 @@
from frequenz.client.dispatch.test.client import FakeClient, to_create_params
from frequenz.client.dispatch.test.fixtures import client, generator, sample
from frequenz.client.dispatch.test.generator import DispatchGenerator
from frequenz.client.dispatch.types import Dispatch
from frequenz.client.dispatch.types import Dispatch, Event

# Ignore flake8 error in the rest of the file to use the same fixture names
# flake8: noqa[811]
Expand Down Expand Up @@ -261,3 +262,50 @@ async def test_delete_dispatch_fail(client: FakeClient) -> None:
"""Test deleting a non-existent dispatch."""
with raises(grpc.RpcError):
await client.delete(microgrid_id=1, dispatch_id=1)


async def test_dispatch_stream(client: FakeClient, sample: Dispatch) -> None:
    """Test that create/update/delete operations emit matching stream events."""
    microgrid_id = random.randint(1, 100)
    dispatches = [sample, sample, sample]

    stream = client.stream(microgrid_id)

    async def expect(dispatch: Dispatch, event: Event) -> None:
        # Each mutation must produce exactly one matching stream message.
        message = await stream.receive()
        assert message.dispatch == dispatch
        assert message.event == event

    # Give the stream a moment to start before producing events.
    await asyncio.sleep(0.1)

    # Create three dispatches, expecting a CREATED event after each one.
    for index in range(3):
        dispatches[index] = await client.create(
            **to_create_params(microgrid_id, dispatches[index])
        )
        await expect(dispatches[index], Event.CREATED)

    # Update the first dispatch and expect the corresponding UPDATED event.
    dispatches[0] = await client.update(
        microgrid_id=microgrid_id,
        dispatch_id=dispatches[0].id,
        new_fields={"start_time": dispatches[0].start_time + timedelta(minutes=1)},
    )
    await expect(dispatches[0], Event.UPDATED)

    # Delete the first dispatch and expect the corresponding DELETED event.
    await client.delete(microgrid_id=microgrid_id, dispatch_id=dispatches[0].id)
    await expect(dispatches[0], Event.DELETED)

0 comments on commit 8e6719b

Please sign in to comment.