From 165de2d193f6cd638b10a560499bb6cd9c81e74b Mon Sep 17 00:00:00 2001 From: Anton Agestam Date: Thu, 14 Dec 2023 16:33:52 +0100 Subject: [PATCH] feature: Refine exposed protocol types By refining the exposed protocol types to include unions types for the different kinds of headers, we allow typing at call-sites to work smoother. This also enables using exhaustiveness checking in a place where it wasn't possible before, when matching the type of an entity's header schema. See the replaced NotImplementedError in the diff for reference. --- src/kio/static/protocol.py | 66 ++++++++++++++++++++++++++++++++++++-- tests/test_integration.py | 18 ++++++----- 2 files changed, 74 insertions(+), 10 deletions(-) diff --git a/src/kio/static/protocol.py b/src/kio/static/protocol.py index 52554af6..57ab7fba 100644 --- a/src/kio/static/protocol.py +++ b/src/kio/static/protocol.py @@ -1,12 +1,46 @@ +from __future__ import annotations + from typing import ClassVar from typing import Protocol +from typing import TypeAlias + +import kio.schema.request_header.v0 +import kio.schema.request_header.v1 +import kio.schema.request_header.v2 +import kio.schema.response_header.v0 +import kio.schema.response_header.v1 from kio._utils import DataclassInstance from .constants import EntityType from .primitive import i16 -__all__ = ("Entity", "Payload") +__all__ = ( + "Entity", + "Payload", + "ResponsePayload", + "RequestPayload", + "RequestHeader", + "ResponseHeader", + "Header", + "HeaderV0RequestPayload", + "HeaderV1RequestPayload", + "HeaderV2RequestPayload", + "HeaderV0ResponsePayload", + "HeaderV1ResponsePayload", +) + + +RequestHeader: TypeAlias = ( + kio.schema.request_header.v0.RequestHeader + | kio.schema.request_header.v1.RequestHeader + | kio.schema.request_header.v2.RequestHeader +) +ResponseHeader: TypeAlias = ( + kio.schema.response_header.v0.ResponseHeader + | kio.schema.response_header.v1.ResponseHeader +) +Header: TypeAlias = RequestHeader | ResponseHeader class Entity(DataclassInstance, Protocol): @@ -51,8 +85,36 @@ class Payload(DataclassInstance, Protocol): # required to accept ANY subtype of Entity as set-type. This is quirky but sound # from static type checking perspective. @property - def __header_schema__(self) -> type[Entity]: + def __header_schema__(self) -> type[Header]: """ The header entity type that should be used when sending or receiving this payload. """ + + +class HeaderV0RequestPayload(Payload, Protocol): + __header_schema__: ClassVar[type[kio.schema.request_header.v0.RequestHeader]] + + +class HeaderV1RequestPayload(Payload, Protocol): + __header_schema__: ClassVar[type[kio.schema.request_header.v1.RequestHeader]] + + +class HeaderV2RequestPayload(Payload, Protocol): + __header_schema__: ClassVar[type[kio.schema.request_header.v2.RequestHeader]] + + +RequestPayload: TypeAlias = ( + HeaderV0RequestPayload | HeaderV1RequestPayload | HeaderV2RequestPayload +) + + +class HeaderV0ResponsePayload(Payload, Protocol): + __header_schema__: ClassVar[type[kio.schema.response_header.v0.ResponseHeader]] + + +class HeaderV1ResponsePayload(Payload, Protocol): + __header_schema__: ClassVar[type[kio.schema.response_header.v1.ResponseHeader]] + + +ResponsePayload: TypeAlias = HeaderV0ResponsePayload | HeaderV1ResponsePayload diff --git a/tests/test_integration.py b/tests/test_integration.py index a90e6a64..941cc876 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -9,6 +9,7 @@ from typing import Any from typing import Final from typing import TypeVar +from typing import assert_never from typing import assert_type from unittest import mock @@ -59,8 +60,9 @@ from kio.static.primitive import i32 from kio.static.primitive import i32Timedelta from kio.static.primitive import i64 -from kio.static.protocol import Entity -from kio.static.protocol import Payload +from kio.static.protocol import RequestHeader +from kio.static.protocol import RequestPayload +from kio.static.protocol import ResponsePayload from . import fixtures @@ -71,12 +73,12 @@ def write_request_header( buffer: Writable, - payload: Payload, + payload: RequestPayload, correlation_id: i32, client_id: str | None, ) -> None: header_schema = payload.__header_schema__ - header: Entity + header: RequestHeader if issubclass(header_schema, kio.schema.request_header.v0.header.RequestHeader): header = header_schema( @@ -98,14 +100,14 @@ def write_request_header( client_id=client_id, ) else: - raise NotImplementedError(f"Unknown request header schema: {header_schema}") + assert_never(header_schema) entity_writer(header_schema)(buffer, header) async def send( stream: StreamWriter, - payload: Payload, + payload: RequestPayload, correlation_id: i32, ) -> None: write_request = entity_writer(type(payload)) @@ -142,7 +144,7 @@ async def read_response_bytes(stream: StreamReader) -> io.BytesIO: return io.BytesIO(await stream.read(response_length)) -R = TypeVar("R", bound=Payload) +R = TypeVar("R", bound=ResponsePayload) def parse_response( @@ -162,7 +164,7 @@ def parse_response( async def make_request( - request: Payload, + request: RequestPayload, response_type: type[R], ) -> R: correlation_id = i32(secrets.randbelow(i32.__high__ + 1))