Skip to content

Commit

Permalink
Implement KIP-893 nullable entity fields
Browse files Browse the repository at this point in the history
This commit implements support for the nullable entity field introduced
in `ConsumerGroupHeartbeatResponse.assignment`. This kind of field was
formalized upstream in KIP-893 but is not mentioned in any official
documentation of the protocol.

The KIP also mentions tagged fields. I left this out for now as I don't
understand why another mechanism for omitting a tagged field value is
needed, and AFAIK no entity uses this yet.
  • Loading branch information
aiven-anton committed Nov 21, 2023
1 parent da82e8c commit 04455db
Show file tree
Hide file tree
Showing 16 changed files with 306 additions and 43 deletions.
23 changes: 14 additions & 9 deletions codegen/generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@


def format_default(
type_: Primitive,
type_: Primitive | EntityType | CommonStructType,
default: str | int | float | bool,
optional: bool,
custom_type: CustomTypeDef | None,
Expand All @@ -99,6 +99,7 @@ def format_default(
| Primitive.uint32
| Primitive.uint64
), str(default):
assert not isinstance(type_, EntityType | CommonStructType)
if custom_type_open:
return "".join(
(
Expand Down Expand Up @@ -133,7 +134,7 @@ def format_default(
return "None"

raise NotImplementedError(
f"Failed parsing default for {type_.value=} field: {default=!r}"
f"Failed parsing default for {type_=} field: {default=!r}"
)


Expand Down Expand Up @@ -166,8 +167,6 @@ def format_dataclass_field(
if isinstance(field_type, PrimitiveArrayType):
field_kwargs["default"] = "()"
elif default is not None:
assert not isinstance(field_type, EntityType)
assert not isinstance(field_type, CommonStructType)
field_kwargs["default"] = format_default(
field_type, default, optional, custom_type
)
Expand Down Expand Up @@ -355,21 +354,27 @@ def generate_entity_array_field(
return f" {to_snake_case(field.name)}: tuple[{field.type}, ...]{field_call}\n"


def entity_annotation(field: EntityField | CommonStructField, optional: bool) -> str:
return f"{field.type} | None" if optional else str(field.type)


def generate_entity_field(
field: EntityField | CommonStructField,
version: int,
) -> str:
optional = (
field.nullableVersions.matches(version) if field.nullableVersions else False
)
field_call = format_dataclass_field(
field_type=field.type,
default=None,
optional=(
field.nullableVersions.matches(version) if field.nullableVersions else False
),
default=field.default,
optional=optional,
custom_type=None,
tag=field.get_tag(version),
ignorable=field.ignorable,
)
return f" {to_snake_case(field.name)}: {field.type}{field_call}\n"
annotation = entity_annotation(field, optional)
return f" {to_snake_case(field.name)}: {annotation}{field_call}\n"


def generate_common_struct_array_field(
Expand Down
1 change: 0 additions & 1 deletion codegen/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def main() -> None:
"FetchResponse", # Records
"FetchSnapshotResponse", # Records
"FetchRequest", # Should not output tagged field if its value equals to default (presumably)
"ConsumerGroupHeartbeatResponse", # Nullable `assignment` field
}:
module_code[module_path].append(
test_code_java.format(
Expand Down
2 changes: 2 additions & 0 deletions codegen/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ class CommonStructArrayField(_BaseField):

class CommonStructField(_BaseField):
type: CommonStructType
default: Literal["null"] | None = None


class EntityArrayField(_BaseField):
Expand All @@ -382,6 +383,7 @@ class EntityArrayField(_BaseField):
class EntityField(_BaseField):
type: EntityType
fields: tuple[Field, ...]
default: Literal["null"] | None = None


class _BaseSchema(BaseModel):
Expand Down
18 changes: 14 additions & 4 deletions src/kio/_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Work-around for broken support for cache decorators in
# https://github.com/python/typeshed/issues/6347
# https://stackoverflow.com/a/73517689
from typing import TYPE_CHECKING
from typing import Protocol
from typing import TypeVar

__all__ = ("cache",)
__all__ = ("cache", "DataclassInstance")


# Work-around for broken support for cache decorators in
# https://github.com/python/typeshed/issues/6347
# https://stackoverflow.com/a/73517689
if TYPE_CHECKING:
_C = TypeVar("_C")

Expand All @@ -14,3 +16,11 @@ def cache(c: _C) -> _C:

else:
from functools import cache


if TYPE_CHECKING:
from _typeshed import DataclassInstance
else:

class DataclassInstance(Protocol):
...
2 changes: 1 addition & 1 deletion src/kio/schema/consumer_group_heartbeat/v0/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ class ConsumerGroupHeartbeatResponse(ApiMessage):
"""True if the member should compute the assignment for the group."""
heartbeat_interval: i32Timedelta = field(metadata={"kafka_type": "timedelta_i32"})
"""The heartbeat interval in milliseconds."""
assignment: Assignment
assignment: Assignment | None = field(default=None)
"""null if not provided; the assignment otherwise."""
21 changes: 21 additions & 0 deletions src/kio/serial/_introspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,27 @@ class FieldKind(enum.Enum):
def classify_field(field: Field[T]) -> tuple[FieldKind, type[T]]:
type_origin = get_origin(field.type)

if type_origin is UnionType:
try:
a, b = get_args(field.type)
except ValueError:
raise SchemaError(
f"Field {field.name} has unsupported union type: {field.type}"
) from None

if a is NoneType:
inner_type = b
elif b is NoneType:
inner_type = a
else:
raise SchemaError("Only union with None is supported")

return (
(FieldKind.entity, inner_type)
if is_dataclass(inner_type)
else (FieldKind.primitive, inner_type)
)

if type_origin is not tuple:
return (
(FieldKind.entity, field.type) # type: ignore[return-value]
Expand Down
39 changes: 36 additions & 3 deletions src/kio/serial/_parse.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dataclasses import Field
from dataclasses import fields
from typing import IO
from typing import Literal
from typing import TypeVar
from typing import assert_never
from typing import overload

from kio._utils import cache
from kio.static.protocol import Entity
Expand All @@ -13,6 +15,8 @@
from ._introspect import get_field_tag
from ._introspect import get_schema_field_type
from ._introspect import is_optional
from ._shared import NullableEntityMarker
from .readers import read_int8


def get_reader(
Expand Down Expand Up @@ -114,7 +118,11 @@ def get_field_reader(
)
)
case FieldKind.entity:
return entity_reader(field_type) # type: ignore[type-var]
return ( # type: ignore[no-any-return]
entity_reader(field_type, nullable=True) # type: ignore[call-overload]
if is_optional(field)
else entity_reader(field_type) # type: ignore[type-var]
)
case FieldKind.entity_tuple:
return array_reader( # type: ignore[return-value]
entity_reader(field_type) # type: ignore[type-var]
Expand All @@ -126,8 +134,24 @@ def get_field_reader(
E = TypeVar("E", bound=Entity)


@cache
@overload
def entity_reader(entity_type: type[E]) -> readers.Reader[E]:
...


@overload
def entity_reader(
entity_type: type[E],
nullable: Literal[True],
) -> readers.Reader[E | None]:
...


@cache
def entity_reader(
entity_type: type[E],
nullable: bool = False,
) -> readers.Reader[E | None]:
field_readers = {}
tagged_field_readers = {}
is_request_header = entity_type.__name__ == "RequestHeader"
Expand Down Expand Up @@ -170,4 +194,13 @@ def read_entity(buffer: IO[bytes]) -> E:

return entity_type(**kwargs)

return read_entity
if not nullable:
return read_entity

# This is undocumented behavior, formalized in KIP-893.
# https://cwiki.apache.org/confluence/display/KAFKA/KIP-893%3A+The+Kafka+protocol+should+support+nullable+structs
def read_nullable_entity(buffer: IO[bytes]) -> E | None:
marker = NullableEntityMarker(read_int8(buffer))
return None if marker is NullableEntityMarker.null else read_entity(buffer)

return read_nullable_entity
37 changes: 34 additions & 3 deletions src/kio/serial/_serialize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import io
from dataclasses import Field
from dataclasses import fields
from typing import Literal
from typing import TypeVar
from typing import assert_never
from typing import overload

from kio._utils import cache
from kio.static.protocol import Entity
Expand All @@ -13,10 +15,12 @@
from ._introspect import get_field_tag
from ._introspect import get_schema_field_type
from ._introspect import is_optional
from ._shared import NullableEntityMarker
from .writers import Writable
from .writers import Writer
from .writers import compact_array_writer
from .writers import legacy_array_writer
from .writers import write_int8
from .writers import write_tagged_field
from .writers import write_unsigned_varint

Expand Down Expand Up @@ -128,7 +132,11 @@ def get_field_writer(
)
)
case FieldKind.entity:
return entity_writer(field_type) # type: ignore[type-var]
return ( # type: ignore[no-any-return]
entity_writer(field_type, nullable=True) # type: ignore[call-overload]
if optional
else entity_writer(field_type) # type: ignore[type-var]
)
case FieldKind.entity_tuple:
return array_writer( # type: ignore[return-value]
entity_writer(field_type) # type: ignore[type-var]
Expand All @@ -140,8 +148,31 @@ def get_field_writer(
E = TypeVar("E", bound=Entity)


@cache
def _wrap_nullable(write_entity: Writer[E]) -> Writer[E | None]:
# This is undocumented behavior, formalized in KIP-893.
# https://cwiki.apache.org/confluence/display/KAFKA/KIP-893%3A+The+Kafka+protocol+should+support+nullable+structs
def write_nullable(buffer: Writable, entity: E | None) -> None:
if entity is None:
write_int8(buffer, NullableEntityMarker.null.value)
return
write_int8(buffer, NullableEntityMarker.not_null.value)
write_entity(buffer, entity)

return write_nullable


@overload
def entity_writer(entity_type: type[E]) -> Writer[E]:
...


@overload
def entity_writer(entity_type: type[E], nullable: Literal[True]) -> Writer[E | None]:
...


@cache
def entity_writer(entity_type: type[E], nullable: bool = False) -> Writer[E | None]:
field_writers = {}
tagged_field_writers = {}
is_request_header = entity_type.__name__ == "RequestHeader"
Expand Down Expand Up @@ -204,4 +235,4 @@ def write_entity(buffer: Writable, entity: E) -> None:
write_unsigned_varint(buffer, num_tagged_fields)
buffer.write(tag_buffer.getvalue())

return write_entity
return _wrap_nullable(write_entity) if nullable else write_entity # type: ignore[return-value]
8 changes: 8 additions & 0 deletions src/kio/serial/_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import enum

from kio.static.primitive import i8


class NullableEntityMarker(enum.Enum):
null = i8(-1)
not_null = i8(1)
11 changes: 2 additions & 9 deletions src/kio/static/protocol.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
from typing import TYPE_CHECKING
from typing import ClassVar
from typing import Protocol

from .primitive import i16

if TYPE_CHECKING:
from _typeshed import DataclassInstance
else:

class DataclassInstance(Protocol):
...
from kio._utils import DataclassInstance

from .primitive import i16

__all__ = ("ApiMessage", "Entity", "Payload")

Expand Down
32 changes: 31 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@
from subprocess import PIPE
from subprocess import Popen
from subprocess import TimeoutExpired
from types import NoneType
from types import UnionType
from typing import Any
from typing import get_args
from typing import get_origin
from uuid import UUID

import pytest
import pytest_asyncio
from hypothesis import settings

from kio._utils import DataclassInstance
from kio.serial import entity_writer
from kio.static.protocol import Entity

Expand Down Expand Up @@ -90,11 +95,36 @@ async def stream_writer(
return async_buffers[1]


def is_nullable_entity_field(field: dataclasses.Field) -> bool:
if get_origin(field.type) is not UnionType:
return False
try:
a, b = get_args(field.type)
except ValueError:
return False
return (a is NoneType and dataclasses.is_dataclass(b)) or (
b is NoneType and dataclasses.is_dataclass(a)
)


def map_nullable_entity_fields(obj: DataclassInstance) -> dict[str, bool]:
"""Return map of KIP-893 nullable entity fields."""
return {
field.name: is_nullable_entity_field(field) for field in dataclasses.fields(obj)
}


class JavaTester:
class _Encoder(JSONEncoder):
def default(self, o: Any) -> Any:
if dataclasses.is_dataclass(o):
return self._replace_tzaware_nulls(dataclasses.asdict(o))
return self._replace_tzaware_nulls(
{
k: v
for k, v in dataclasses.asdict(o).items()
if (v is not None or not map_nullable_entity_fields(o)[k])
}
)
if isinstance(o, timedelta):
return round(o.total_seconds() * 1000)
if isinstance(o, datetime):
Expand Down
Loading

0 comments on commit 04455db

Please sign in to comment.