Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Correct the logic for tagged field defaults #216

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions codegen/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def test_{entity_snake_case}_roundtrip(instance: {entity_type}) -> None:
"""

test_code_java = """\
{xfail}
@pytest.mark.java
@given(instance=from_type({entity_type}))
def test_{entity_snake_case}_java(instance: {entity_type}, java_tester: JavaTester) -> None:
Expand Down Expand Up @@ -89,20 +88,10 @@ def main() -> None:
)

if entity_type.__type__ is not EntityType.nested:
xfail = (
""
if entity_type.__name__ not in "UpdateRaftVoterResponse"
else (
"@pytest.mark.xfail("
'reason="https://github.com/Aiven-Open/kio/issues/215"'
")"
)
)
module_code[module_path].append(
test_code_java.format(
entity_type=entity_type.__name__,
entity_snake_case=to_snake_case(entity_type.__name__),
xfail=xfail,
)
)

Expand Down
99 changes: 99 additions & 0 deletions src/kio/serial/_implicit_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import MISSING
from dataclasses import Field
from dataclasses import fields
from datetime import timedelta
from types import MappingProxyType
from typing import Final
from typing import TypeVar
from typing import assert_never
from uuid import UUID

from kio.serial._introspect import EntityField
from kio.serial._introspect import EntityTupleField
from kio.serial._introspect import PrimitiveField
from kio.serial._introspect import PrimitiveTupleField
from kio.serial._introspect import classify_field
from kio.serial._introspect import is_optional
from kio.serial.readers import tz_aware_from_i64
from kio.static.constants import uuid_zero
from kio.static.primitive import Records
from kio.static.primitive import TZAware
from kio.static.primitive import f64
from kio.static.primitive import i8
from kio.static.primitive import i16
from kio.static.primitive import i32
from kio.static.primitive import i32Timedelta
from kio.static.primitive import i64
from kio.static.primitive import i64Timedelta
from kio.static.primitive import u8
from kio.static.primitive import u16
from kio.static.primitive import u32
from kio.static.primitive import u64
from kio.static.protocol import Entity

T = TypeVar("T")

primitive_implicit_defaults: Final[Mapping[type, object]] = MappingProxyType(
{
u8: u8(0),
u16: u16(0),
u32: u32(0),
u64: u64(0),
i8: i8(0),
i16: i16(0),
i32: i32(0),
i64: i64(0),
f64: f64(0.0),
i32Timedelta: i32Timedelta.parse(timedelta(0)),
i64Timedelta: i64Timedelta.parse(timedelta(0)),
TZAware: tz_aware_from_i64(i64(0)),
UUID: uuid_zero,
str: "",
bytes: b"",
}
)


def get_implicit_default(field_type: type[T]) -> T:
# Records fields have null as implicit default, supporting this requires changing
# code generation to always expect null for a tagged records field. As of writing
# there are no tagged records fields, or other occurrences where we would need such
# implicit default, so this can be safely deferred.
if issubclass(field_type, Records):
raise NotImplementedError("Tagged record fields are not supported")

try:
# mypy has no way of typing a mapping as T -> T on a per-item level.
return primitive_implicit_defaults[field_type] # type: ignore[return-value]
except KeyError:
return primitive_implicit_defaults[field_type.__bases__[0]] # type: ignore[return-value]


U = TypeVar("U", bound=Entity)


def get_tagged_field_default(field: Field[U]) -> U:
if field.default is not MISSING:
return field.default

if is_optional(field):
raise TypeError("Optional fields should have None as explicit default")

field_class = classify_field(field)

if isinstance(field_class, PrimitiveField):
return get_implicit_default(field_class.type_)
elif isinstance(field_class, EntityField):
return field.type(
**{
nested_field.name: get_tagged_field_default(nested_field)
for nested_field in fields(field.type)
}
)
elif isinstance(field_class, PrimitiveTupleField | EntityTupleField):
raise TypeError("Tuple fields should have the empty tuple as explicit default")

assert_never(field_class)
16 changes: 13 additions & 3 deletions src/kio/serial/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from kio.static.protocol import Entity

from . import readers
from ._implicit_defaults import get_tagged_field_default
from ._introspect import EntityField
from ._introspect import EntityTupleField
from ._introspect import PrimitiveField
Expand Down Expand Up @@ -165,7 +166,11 @@ def entity_reader(
is_tagged_field=tag is not None,
)
if tag is not None:
tagged_field_readers[tag] = field, field_reader
tagged_field_readers[tag] = (
field,
field_reader,
get_tagged_field_default(field),
)
else:
field_readers[field] = field_reader

Expand All @@ -185,12 +190,17 @@ def read_entity(buffer: IO[bytes]) -> E:
return entity_type(**kwargs)

# Read tagged fields.
tagged_field_values = {}
num_tagged_fields = readers.read_unsigned_varint(buffer)
for _ in range(num_tagged_fields):
field_tag = readers.read_unsigned_varint(buffer)
readers.read_unsigned_varint(buffer) # field length
field, field_reader = tagged_field_readers[field_tag]
kwargs[field.name] = field_reader(buffer)
field, field_reader, _ = tagged_field_readers[field_tag]
tagged_field_values[field.name] = field_reader(buffer)

# Resolve tagged field implicit defaults.
for field, _, implicit_default in tagged_field_readers.values():
kwargs[field.name] = tagged_field_values.get(field.name, implicit_default)

return entity_type(**kwargs)

Expand Down
18 changes: 15 additions & 3 deletions src/kio/serial/_serialize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io

from dataclasses import MISSING
from dataclasses import Field
from dataclasses import fields
from typing import Literal
Expand All @@ -11,6 +12,7 @@
from kio.static.protocol import Entity

from . import writers
from ._implicit_defaults import get_tagged_field_default
from ._introspect import EntityField
from ._introspect import EntityTupleField
from ._introspect import PrimitiveField
Expand Down Expand Up @@ -183,7 +185,11 @@ def entity_writer(entity_type: type[E], nullable: bool = False) -> Writer[E | No
is_tag=tag is not None,
)
if tag is not None:
tagged_field_writers[tag] = field, field_writer
tagged_field_writers[tag] = (
field,
field_writer,
get_tagged_field_default(field),
)
else:
field_writers[field] = field_writer

Expand Down Expand Up @@ -212,11 +218,17 @@ def write_entity(buffer: Writable, entity: E) -> None:
num_tagged_fields = 0
with io.BytesIO() as tag_buffer:
# Serialize tagged fields. Note that order is important to fulfill spec.
for tag, (field, field_writer) in tagged_field_writers.items():
for tag, (
field,
field_writer,
implicit_default,
) in tagged_field_writers.items():
field_value = getattr(entity, field.name)

# Skip default-valued fields.
if field_value == field.default:
if field_value == field.default or (
field.default == MISSING and field_value == implicit_default
):
continue

# Write the tag to the buffer and increase counter.
Expand Down
6 changes: 3 additions & 3 deletions src/kio/serial/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def read_timedelta_i64(buffer: IO[bytes]) -> i64Timedelta:
return datetime.timedelta(milliseconds=read_int64(buffer)) # type: ignore[return-value]


def _tz_aware_from_i64(timestamp: i64) -> TZAware:
def tz_aware_from_i64(timestamp: i64) -> TZAware:
dt = datetime.datetime.fromtimestamp(timestamp / 1000, datetime.UTC)
try:
return TZAware.truncate(dt)
Expand All @@ -213,11 +213,11 @@ def _tz_aware_from_i64(timestamp: i64) -> TZAware:


def read_datetime_i64(buffer: IO[bytes]) -> TZAware:
return _tz_aware_from_i64(read_int64(buffer))
return tz_aware_from_i64(read_int64(buffer))


def read_nullable_datetime_i64(buffer: IO[bytes]) -> TZAware | None:
timestamp = read_int64(buffer)
if timestamp == -1:
return None
return _tz_aware_from_i64(timestamp)
return tz_aware_from_i64(timestamp)
1 change: 0 additions & 1 deletion tests/generated/test_update_raft_voter_v0_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def test_update_raft_voter_response_roundtrip(
assert instance == result


@pytest.mark.xfail(reason="https://github.com/Aiven-Open/kio/issues/215")
@pytest.mark.java
@given(instance=from_type(UpdateRaftVoterResponse))
def test_update_raft_voter_response_java(
Expand Down
118 changes: 118 additions & 0 deletions tests/serial/test_implicit_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from dataclasses import dataclass
from dataclasses import fields
from datetime import timedelta
from uuid import UUID

import pytest

from kio.schema.types import BrokerId
from kio.schema.types import GroupId
from kio.schema.types import ProducerId
from kio.schema.types import TopicName
from kio.schema.types import TransactionalId
from kio.serial._implicit_defaults import get_implicit_default
from kio.serial._implicit_defaults import get_tagged_field_default
from kio.serial.readers import tz_aware_from_i64
from kio.static.primitive import Records
from kio.static.primitive import TZAware
from kio.static.primitive import f64
from kio.static.primitive import i8
from kio.static.primitive import i16
from kio.static.primitive import i32
from kio.static.primitive import i32Timedelta
from kio.static.primitive import i64
from kio.static.primitive import i64Timedelta
from kio.static.primitive import u8
from kio.static.primitive import u16
from kio.static.primitive import u32
from kio.static.primitive import u64


class TestGetImplicitDefault:
@pytest.mark.parametrize(
("annotation", "expected"),
(
(u8, 0),
(u16, 0),
(u32, 0),
(u64, 0),
(i8, 0),
(i16, 0),
(i32, 0),
(i64, 0),
(f64, 0.0),
(i32Timedelta, timedelta(0)),
(i64Timedelta, timedelta(0)),
(TZAware, tz_aware_from_i64(i64(0))),
(UUID, UUID(int=0)),
(str, ""),
(bytes, b""),
(BrokerId, 0),
(GroupId, ""),
(ProducerId, 0),
(TopicName, ""),
(TransactionalId, ""),
),
)
def test_returns_expected_value(self, annotation: type, expected: object) -> None:
assert get_implicit_default(annotation) == expected

def test_raises_not_implemented_error_for_records(self) -> None:
with pytest.raises(NotImplementedError):
get_implicit_default(Records)


class TestGetTaggedFieldDefault:
def test_raises_type_error_for_optional_field(self) -> None:
@dataclass
class A:
a: u8 | None

[field] = fields(A)

with pytest.raises(
TypeError,
match=r"Optional fields should have None as explicit default",
):
get_tagged_field_default(field)

def test_raises_type_error_for_tuple_field(self) -> None:
@dataclass
class A:
a: tuple[u8, ...]

[field] = fields(A)

with pytest.raises(
TypeError,
match=r"Tuple fields should have the empty tuple as explicit default",
):
get_tagged_field_default(field)

def test_can_get_default_for_primitive_field(self) -> None:
@dataclass
class A:
a: u8

[field] = fields(A)
assert get_tagged_field_default(field) == 0

def test_can_get_default_for_entity_field(self) -> None:
@dataclass
class A:
a: u8

@dataclass
class B:
b: A

[field] = fields(B)
assert get_tagged_field_default(field) == A(a=u8(0))

def test_returns_explicit_default_if_defined(self) -> None:
@dataclass
class A:
a: u8 = u8(1)

[field] = fields(A)
assert get_tagged_field_default(field) == u8(1)
Loading