Skip to content

Commit

Permalink
fix: Correct the logic for tagged field defaults
Browse files Browse the repository at this point in the history
This fixes an issue where the logic for tagged fields was not in sync
with upstream. This is in the realm of undocumented features of the
Kafka protocol, so there is no accurate documentation to reference. The
closest thing to documentation that exists is [this section in a source
code README][messages-readme].

[messages-readme]:https://github.com/apache/kafka/tree/4419677e06b77ed9103c179d713450b1eb7881a1/clients/src/main/resources/common/message#deserializing-messages

There was a [long discussion] about this behavior on the Kafka dev
mailing list.

[long discussion]: https://www.mail-archive.com/dev@kafka.apache.org/msg144437.html

This fixes #215.
  • Loading branch information
aiven-anton committed Oct 29, 2024
1 parent a4a78b1 commit 8be3f71
Show file tree
Hide file tree
Showing 6 changed files with 327 additions and 32 deletions.
99 changes: 99 additions & 0 deletions src/kio/serial/_implicit_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import MISSING
from dataclasses import Field
from dataclasses import fields
from datetime import timedelta
from types import MappingProxyType
from typing import Final
from typing import TypeVar
from typing import assert_never
from uuid import UUID

from kio.serial._introspect import EntityField
from kio.serial._introspect import EntityTupleField
from kio.serial._introspect import PrimitiveField
from kio.serial._introspect import PrimitiveTupleField
from kio.serial._introspect import classify_field
from kio.serial._introspect import is_optional
from kio.serial.readers import tz_aware_from_i64
from kio.static.constants import uuid_zero
from kio.static.primitive import Records
from kio.static.primitive import TZAware
from kio.static.primitive import f64
from kio.static.primitive import i8
from kio.static.primitive import i16
from kio.static.primitive import i32
from kio.static.primitive import i32Timedelta
from kio.static.primitive import i64
from kio.static.primitive import i64Timedelta
from kio.static.primitive import u8
from kio.static.primitive import u16
from kio.static.primitive import u32
from kio.static.primitive import u64
from kio.static.protocol import Entity

# Generic placeholder tying the type passed to get_implicit_default() to the
# type of the value it returns.
T = TypeVar("T")

# Read-only lookup table from a primitive type to the value used as its implicit
# default, i.e. the value assumed for a field of that type that declares no
# explicit default. Integer and float types map to zero, the timedelta types to
# a zero-length duration, TZAware to whatever tz_aware_from_i64 yields for
# timestamp 0, UUID to uuid_zero, and str/bytes to their empty values.
primitive_implicit_defaults: Final[Mapping[type, object]] = MappingProxyType(
    {
        u8: u8(0),
        u16: u16(0),
        u32: u32(0),
        u64: u64(0),
        i8: i8(0),
        i16: i16(0),
        i32: i32(0),
        i64: i64(0),
        f64: f64(0.0),
        i32Timedelta: i32Timedelta.parse(timedelta(0)),
        i64Timedelta: i64Timedelta.parse(timedelta(0)),
        TZAware: tz_aware_from_i64(i64(0)),
        UUID: uuid_zero,
        str: "",
        bytes: b"",
    }
)


def get_implicit_default(field_type: type[T]) -> T:
    """Return the implicit default value registered for a primitive field type.

    The type itself is looked up first. Failing that, its ancestors are tried in
    method resolution order, so that a subclass of a primitive (at any depth)
    resolves to the default registered for that primitive.

    Args:
        field_type: The class to find an implicit default for.

    Returns:
        The value stored in ``primitive_implicit_defaults`` for ``field_type``
        or for its nearest registered ancestor.

    Raises:
        NotImplementedError: If ``field_type`` is ``Records`` or a subclass of it.
        KeyError: If neither ``field_type`` nor any of its ancestors has a
            registered implicit default. The key is ``field_type`` itself.
    """
    # Records fields have null as their implicit default. Supporting this would
    # require changing code generation to always expect null for a tagged
    # records field. As of writing there are no tagged records fields, nor other
    # occurrences where such an implicit default would be needed, so this can be
    # safely deferred.
    if issubclass(field_type, Records):
        raise NotImplementedError("Tagged record fields are not supported")

    # Walk the whole MRO instead of inspecting only the first direct base: the
    # first two entries are always the type itself and its first base, so every
    # previously working lookup resolves identically, while deeper subclasses
    # and types with a mixin listed first are now handled as well.
    for candidate in field_type.__mro__:
        if candidate in primitive_implicit_defaults:
            # mypy has no way of typing a mapping as T -> T on a per-item level.
            return primitive_implicit_defaults[candidate]  # type: ignore[return-value]

    # Report the type that was actually asked for, not one of its bases.
    raise KeyError(field_type)


# Type variable for the value carried by the dataclass field handed to
# get_tagged_field_default().
# NOTE(review): the bound is Entity, yet that function also returns explicit
# field defaults and primitive implicit defaults — confirm the bound is intended.
U = TypeVar("U", bound=Entity)


def get_tagged_field_default(field: Field[U]) -> U:
if field.default is not MISSING:
return field.default

if is_optional(field):
raise TypeError("Optional fields should have None as explicit default")

field_class = classify_field(field)

if isinstance(field_class, PrimitiveField):
return get_implicit_default(field_class.type_)
elif isinstance(field_class, EntityField):
return field.type(
**{
nested_field.name: get_tagged_field_default(nested_field)
for nested_field in fields(field.type)
}
)
elif isinstance(field_class, PrimitiveTupleField | EntityTupleField):
raise TypeError("Tuple fields should have the empty tuple as explicit default")

assert_never(field_class)
16 changes: 13 additions & 3 deletions src/kio/serial/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from kio.static.protocol import Entity

from . import readers
from ._implicit_defaults import get_tagged_field_default
from ._introspect import EntityField
from ._introspect import EntityTupleField
from ._introspect import PrimitiveField
Expand Down Expand Up @@ -165,7 +166,11 @@ def entity_reader(
is_tagged_field=tag is not None,
)
if tag is not None:
tagged_field_readers[tag] = field, field_reader
tagged_field_readers[tag] = (
field,
field_reader,
get_tagged_field_default(field),
)
else:
field_readers[field] = field_reader

Expand All @@ -185,12 +190,17 @@ def read_entity(buffer: IO[bytes]) -> E:
return entity_type(**kwargs)

# Read tagged fields.
tagged_field_values = {}
num_tagged_fields = readers.read_unsigned_varint(buffer)
for _ in range(num_tagged_fields):
field_tag = readers.read_unsigned_varint(buffer)
readers.read_unsigned_varint(buffer) # field length
field, field_reader = tagged_field_readers[field_tag]
kwargs[field.name] = field_reader(buffer)
field, field_reader, _ = tagged_field_readers[field_tag]
tagged_field_values[field.name] = field_reader(buffer)

# Resolve tagged field implicit defaults.
for field, _, implicit_default in tagged_field_readers.values():
kwargs[field.name] = tagged_field_values.get(field.name, implicit_default)

return entity_type(**kwargs)

Expand Down
18 changes: 15 additions & 3 deletions src/kio/serial/_serialize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io

from dataclasses import MISSING
from dataclasses import Field
from dataclasses import fields
from typing import Literal
Expand All @@ -11,6 +12,7 @@
from kio.static.protocol import Entity

from . import writers
from ._implicit_defaults import get_tagged_field_default
from ._introspect import EntityField
from ._introspect import EntityTupleField
from ._introspect import PrimitiveField
Expand Down Expand Up @@ -183,7 +185,11 @@ def entity_writer(entity_type: type[E], nullable: bool = False) -> Writer[E | No
is_tag=tag is not None,
)
if tag is not None:
tagged_field_writers[tag] = field, field_writer
tagged_field_writers[tag] = (
field,
field_writer,
get_tagged_field_default(field),
)
else:
field_writers[field] = field_writer

Expand Down Expand Up @@ -212,11 +218,17 @@ def write_entity(buffer: Writable, entity: E) -> None:
num_tagged_fields = 0
with io.BytesIO() as tag_buffer:
# Serialize tagged fields. Note that order is important to fulfill spec.
for tag, (field, field_writer) in tagged_field_writers.items():
for tag, (
field,
field_writer,
implicit_default,
) in tagged_field_writers.items():
field_value = getattr(entity, field.name)

# Skip default-valued fields.
if field_value == field.default:
if field_value == field.default or (
field.default == MISSING and field_value == implicit_default
):
continue

# Write the tag to the buffer and increase counter.
Expand Down
6 changes: 3 additions & 3 deletions src/kio/serial/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def read_timedelta_i64(buffer: IO[bytes]) -> i64Timedelta:
return datetime.timedelta(milliseconds=read_int64(buffer)) # type: ignore[return-value]


def _tz_aware_from_i64(timestamp: i64) -> TZAware:
def tz_aware_from_i64(timestamp: i64) -> TZAware:
dt = datetime.datetime.fromtimestamp(timestamp / 1000, datetime.UTC)
try:
return TZAware.truncate(dt)
Expand All @@ -213,11 +213,11 @@ def _tz_aware_from_i64(timestamp: i64) -> TZAware:


def read_datetime_i64(buffer: IO[bytes]) -> TZAware:
return _tz_aware_from_i64(read_int64(buffer))
return tz_aware_from_i64(read_int64(buffer))


def read_nullable_datetime_i64(buffer: IO[bytes]) -> TZAware | None:
timestamp = read_int64(buffer)
if timestamp == -1:
return None
return _tz_aware_from_i64(timestamp)
return tz_aware_from_i64(timestamp)
118 changes: 118 additions & 0 deletions tests/serial/test_implicit_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from dataclasses import dataclass
from dataclasses import fields
from datetime import timedelta
from uuid import UUID

import pytest

from kio.schema.types import BrokerId
from kio.schema.types import GroupId
from kio.schema.types import ProducerId
from kio.schema.types import TopicName
from kio.schema.types import TransactionalId
from kio.serial._implicit_defaults import get_implicit_default
from kio.serial._implicit_defaults import get_tagged_field_default
from kio.serial.readers import tz_aware_from_i64
from kio.static.primitive import Records
from kio.static.primitive import TZAware
from kio.static.primitive import f64
from kio.static.primitive import i8
from kio.static.primitive import i16
from kio.static.primitive import i32
from kio.static.primitive import i32Timedelta
from kio.static.primitive import i64
from kio.static.primitive import i64Timedelta
from kio.static.primitive import u8
from kio.static.primitive import u16
from kio.static.primitive import u32
from kio.static.primitive import u64


class TestGetImplicitDefault:
    """Checks for ``get_implicit_default``."""

    @pytest.mark.parametrize(
        ("annotation", "expected"),
        [
            # Unsigned integers.
            (u8, 0),
            (u16, 0),
            (u32, 0),
            (u64, 0),
            # Signed integers.
            (i8, 0),
            (i16, 0),
            (i32, 0),
            (i64, 0),
            # Floating point.
            (f64, 0.0),
            # Durations and timestamps.
            (i32Timedelta, timedelta(0)),
            (i64Timedelta, timedelta(0)),
            (TZAware, tz_aware_from_i64(i64(0))),
            # Identifiers, text and binary data.
            (UUID, UUID(int=0)),
            (str, ""),
            (bytes, b""),
            # Schema types that subclass a primitive.
            (BrokerId, 0),
            (GroupId, ""),
            (ProducerId, 0),
            (TopicName, ""),
            (TransactionalId, ""),
        ],
    )
    def test_returns_expected_value(self, annotation: type, expected: object) -> None:
        """Each supported type resolves to its expected implicit default."""
        implicit_default = get_implicit_default(annotation)
        assert implicit_default == expected

    def test_raises_not_implemented_error_for_records(self) -> None:
        """Records has no implicit default and is rejected outright."""
        with pytest.raises(NotImplementedError):
            get_implicit_default(Records)


class TestGetTaggedFieldDefault:
    """Checks for ``get_tagged_field_default``."""

    def test_raises_type_error_for_optional_field(self) -> None:
        """An optional field without an explicit default is rejected."""

        @dataclass
        class OptionalHolder:
            a: u8 | None

        (only_field,) = fields(OptionalHolder)
        expected_message = r"Optional fields should have None as explicit default"

        with pytest.raises(TypeError, match=expected_message):
            get_tagged_field_default(only_field)

    def test_raises_type_error_for_tuple_field(self) -> None:
        """A tuple field without an explicit default is rejected."""

        @dataclass
        class TupleHolder:
            a: tuple[u8, ...]

        (only_field,) = fields(TupleHolder)
        expected_message = (
            r"Tuple fields should have the empty tuple as explicit default"
        )

        with pytest.raises(TypeError, match=expected_message):
            get_tagged_field_default(only_field)

    def test_can_get_default_for_primitive_field(self) -> None:
        """A primitive field without a default resolves to its implicit default."""

        @dataclass
        class PrimitiveHolder:
            a: u8

        (only_field,) = fields(PrimitiveHolder)
        resolved = get_tagged_field_default(only_field)
        assert resolved == 0

    def test_can_get_default_for_entity_field(self) -> None:
        """A nested entity is built from the defaults of its own fields."""

        @dataclass
        class Inner:
            a: u8

        @dataclass
        class Outer:
            b: Inner

        (only_field,) = fields(Outer)
        resolved = get_tagged_field_default(only_field)
        assert resolved == Inner(a=u8(0))

    def test_returns_explicit_default_if_defined(self) -> None:
        """An explicit default takes precedence over the implicit one."""

        @dataclass
        class ExplicitHolder:
            a: u8 = u8(1)

        (only_field,) = fields(ExplicitHolder)
        resolved = get_tagged_field_default(only_field)
        assert resolved == u8(1)
Loading

0 comments on commit 8be3f71

Please sign in to comment.