Skip to content

Commit

Permalink
Add input validation checks for the Client Requests
Browse files Browse the repository at this point in the history
Signed-off-by: camille-bouvy-frequenz <camille.bouvy@frequenz.com>
  • Loading branch information
camille-bouvy-frequenz committed Aug 29, 2024
1 parent 56228ae commit 3c2976b
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 39 deletions.
112 changes: 78 additions & 34 deletions src/frequenz/client/electricity_trading/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Module to define the client class."""

import logging
from datetime import datetime
from datetime import datetime, timezone
from decimal import Decimal, InvalidOperation
from typing import Awaitable, cast

Expand Down Expand Up @@ -146,6 +146,8 @@ async def stream_gridpool_orders( # pylint: disable=too-many-arguments
Raises:
grpc.RpcError: If an error occurs while streaming the orders.
"""
self.validate_params(delivery_period=delivery_period)

gridpool_order_filter = GridpoolOrderFilter(
order_states=order_states,
side=market_side,
Expand Down Expand Up @@ -202,6 +204,8 @@ async def stream_gridpool_trades( # pylint: disable=too-many-arguments
Raises:
grpc.RpcError: If an error occurs while streaming gridpool trades.
"""
self.validate_params(delivery_period=delivery_period)

gridpool_trade_filter = GridpoolTradeFilter(
trade_states=trade_states,
trade_ids=trade_ids,
Expand Down Expand Up @@ -254,6 +258,8 @@ async def stream_public_trades(
Raises:
grpc.RpcError: If an error occurs while streaming public trades.
"""
self.validate_params(delivery_period=delivery_period)

public_trade_filter = PublicTradeFilter(
states=states,
delivery_period=delivery_period,
Expand All @@ -280,6 +286,60 @@ async def stream_public_trades(
raise e
return self._public_trades_streams[public_trade_filter].new_receiver()

def validate_params( # pylint: disable=too-many-arguments
self,
price: Price | None | _Sentinel = NO_VALUE,
quantity: Energy | None | _Sentinel = NO_VALUE,
stop_price: Price | None | _Sentinel = NO_VALUE,
peak_price_delta: Price | None | _Sentinel = NO_VALUE,
display_quantity: Energy | None | _Sentinel = NO_VALUE,
delivery_period: DeliveryPeriod | None = None,
valid_until: datetime | None | _Sentinel = NO_VALUE,
) -> None:
"""
Validate the parameters of an order.
This method ensures the following:
- Price and quantity values have the correct number of decimal places and are positive.
- The delivery_start and valid_until values are in the future.
Args:
price: The price of the order.
quantity: The quantity of the order.
stop_price: The stop price of the order.
peak_price_delta: The peak price delta of the order.
display_quantity: The display quantity of the order.
delivery_period: The delivery period of the order.
valid_until: The valid until of the order.
Raises:
ValueError: If the parameters are invalid.
"""
if not isinstance(price, _Sentinel) and price is not None:
validate_decimal_places(price.amount, PRECISION_DECIMAL_PRICE, "price")
if not isinstance(quantity, _Sentinel) and quantity is not None:
validate_decimal_places(
quantity.mwh, PRECISION_DECIMAL_QUANTITY, "quantity"
)
if not isinstance(stop_price, _Sentinel) and stop_price is not None:
validate_decimal_places(
stop_price.amount, PRECISION_DECIMAL_PRICE, "stop price"
)
if not isinstance(peak_price_delta, _Sentinel) and peak_price_delta is not None:
validate_decimal_places(
peak_price_delta.amount, PRECISION_DECIMAL_PRICE, "peak price delta"
)
if not isinstance(display_quantity, _Sentinel) and display_quantity is not None:
validate_decimal_places(
display_quantity.mwh, PRECISION_DECIMAL_QUANTITY, "display quantity"
)
if not isinstance(delivery_period, _Sentinel) and delivery_period is not None:
if delivery_period.start < datetime.now(timezone.utc):
raise ValueError("delivery_period must be in the future")
if not isinstance(valid_until, _Sentinel) and valid_until is not None:
if valid_until < datetime.now(timezone.utc):
raise ValueError("valid_until must be in the future")

async def create_gridpool_order( # pylint: disable=too-many-arguments, too-many-locals
self,
gridpool_id: int,
Expand Down Expand Up @@ -322,21 +382,15 @@ async def create_gridpool_order( # pylint: disable=too-many-arguments, too-many
Raises:
grpc.RpcError: An error occurred while creating the order.
"""
validate_decimal_places(price.amount, PRECISION_DECIMAL_PRICE, "price")
validate_decimal_places(quantity.mwh, PRECISION_DECIMAL_QUANTITY, "quantity")
if stop_price is not None:
validate_decimal_places(
stop_price.amount, PRECISION_DECIMAL_PRICE, "stop price"
)
if peak_price_delta is not None:
validate_decimal_places(
peak_price_delta.amount, PRECISION_DECIMAL_PRICE, "peak price delta"
)
if display_quantity is not None:
validate_decimal_places(
display_quantity.mwh, PRECISION_DECIMAL_QUANTITY, "display quantity"
)

self.validate_params(
price=price,
quantity=quantity,
stop_price=stop_price,
peak_price_delta=peak_price_delta,
display_quantity=display_quantity,
delivery_period=delivery_period,
valid_until=valid_until,
)
order = Order(
delivery_area=delivery_area,
delivery_period=delivery_period,
Expand Down Expand Up @@ -412,24 +466,14 @@ async def update_gridpool_order( # pylint: disable=too-many-arguments, too-many
Raises:
ValueError: If no fields to update are provided.
"""
if not isinstance(price, _Sentinel) and price is not None:
validate_decimal_places(price.amount, PRECISION_DECIMAL_PRICE, "price")
if not isinstance(quantity, _Sentinel) and quantity is not None:
validate_decimal_places(
quantity.mwh, PRECISION_DECIMAL_QUANTITY, "quantity"
)
if not isinstance(stop_price, _Sentinel) and stop_price is not None:
validate_decimal_places(
stop_price.amount, PRECISION_DECIMAL_PRICE, "stop price"
)
if not isinstance(peak_price_delta, _Sentinel) and peak_price_delta is not None:
validate_decimal_places(
peak_price_delta.amount, PRECISION_DECIMAL_PRICE, "peak price delta"
)
if not isinstance(display_quantity, _Sentinel) and display_quantity is not None:
validate_decimal_places(
display_quantity.mwh, PRECISION_DECIMAL_QUANTITY, "display quantity"
)
self.validate_params(
price=price,
quantity=quantity,
stop_price=stop_price,
peak_price_delta=peak_price_delta,
display_quantity=display_quantity,
valid_until=valid_until,
)

params = {
"price": price,
Expand Down
15 changes: 10 additions & 5 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""Tests for the methods in the client."""
import asyncio
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from unittest.mock import AsyncMock

Expand Down Expand Up @@ -34,6 +34,10 @@
TradeState,
)

# Set the year of the delivery starts to one year from now to make sure
# we never get an expired delivery start
NEXT_YEAR = datetime.now().year + 1


@pytest.fixture
def set_up() -> Generator[Any, Any, Any]:
Expand All @@ -48,18 +52,19 @@ def set_up() -> Generator[Any, Any, Any]:
asyncio.set_event_loop(loop)

# Set up the parameters for the orders
delivery_start = datetime(NEXT_YEAR, 1, 1, 12, 0, tzinfo=timezone.utc)
gridpool_id = 123
delivery_area = DeliveryArea(code="DE", code_type=EnergyMarketCodeType.EUROPE_EIC)
delivery_period = DeliveryPeriod(
start=datetime.fromisoformat("2023-01-01T00:00:00+00:00"),
start=delivery_start,
duration=timedelta(minutes=15),
)
order_type = OrderType.LIMIT
side = MarketSide.BUY
price = Price(amount=Decimal("50"), currency=Currency.EUR)
quantity = Energy(mwh=Decimal("0.1"))
order_execution_option = OrderExecutionOption.AON
valid_until = datetime.fromisoformat("2023-01-01T00:00:00+00:00")
valid_until = delivery_start + timedelta(hours=3)

yield {
"client": _,
Expand Down Expand Up @@ -103,8 +108,8 @@ def set_up_order_detail_response(
),
open_quantity=Energy(mwh=Decimal("5.00")),
filled_quantity=Energy(mwh=Decimal("0.00")),
create_time=datetime.fromisoformat("2024-01-03T12:00:00+00:00"),
modification_time=datetime.fromisoformat("2024-01-03T12:00:00+00:00"),
create_time=set_up["delivery_period"].start - timedelta(hours=2),
modification_time=set_up["delivery_period"].start - timedelta(hours=1),
).to_pb()


Expand Down

0 comments on commit 3c2976b

Please sign in to comment.