-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: Correct the logic for tagged field defaults
This fixes an issue where the logic for tagged fields was not in sync with upstream. This is in the realm of undocumented features of the Kafka protocol, so there is no accurate documentation to reference. The closest to documentation that exists, is [this section in a source code README][messages-readme]. [messages-readme]:https://github.com/apache/kafka/tree/4419677e06b77ed9103c179d713450b1eb7881a1/clients/src/main/resources/common/message#deserializing-messages There was a [long discussion] about this behavior on the Kafka dev mailing list. [long discussion]: https://www.mail-archive.com/dev@kafka.apache.org/msg144437.html This fixes #215.
- Loading branch information
1 parent
a4a78b1
commit 8be3f71
Showing
6 changed files
with
327 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import Mapping | ||
from dataclasses import MISSING | ||
from dataclasses import Field | ||
from dataclasses import fields | ||
from datetime import timedelta | ||
from types import MappingProxyType | ||
from typing import Final | ||
from typing import TypeVar | ||
from typing import assert_never | ||
from uuid import UUID | ||
|
||
from kio.serial._introspect import EntityField | ||
from kio.serial._introspect import EntityTupleField | ||
from kio.serial._introspect import PrimitiveField | ||
from kio.serial._introspect import PrimitiveTupleField | ||
from kio.serial._introspect import classify_field | ||
from kio.serial._introspect import is_optional | ||
from kio.serial.readers import tz_aware_from_i64 | ||
from kio.static.constants import uuid_zero | ||
from kio.static.primitive import Records | ||
from kio.static.primitive import TZAware | ||
from kio.static.primitive import f64 | ||
from kio.static.primitive import i8 | ||
from kio.static.primitive import i16 | ||
from kio.static.primitive import i32 | ||
from kio.static.primitive import i32Timedelta | ||
from kio.static.primitive import i64 | ||
from kio.static.primitive import i64Timedelta | ||
from kio.static.primitive import u8 | ||
from kio.static.primitive import u16 | ||
from kio.static.primitive import u32 | ||
from kio.static.primitive import u64 | ||
from kio.static.protocol import Entity | ||
|
||
T = TypeVar("T") | ||
|
||
primitive_implicit_defaults: Final[Mapping[type, object]] = MappingProxyType( | ||
{ | ||
u8: u8(0), | ||
u16: u16(0), | ||
u32: u32(0), | ||
u64: u64(0), | ||
i8: i8(0), | ||
i16: i16(0), | ||
i32: i32(0), | ||
i64: i64(0), | ||
f64: f64(0.0), | ||
i32Timedelta: i32Timedelta.parse(timedelta(0)), | ||
i64Timedelta: i64Timedelta.parse(timedelta(0)), | ||
TZAware: tz_aware_from_i64(i64(0)), | ||
UUID: uuid_zero, | ||
str: "", | ||
bytes: b"", | ||
} | ||
) | ||
|
||
|
||
def get_implicit_default(field_type: type[T]) -> T: | ||
# Records fields have null as implicit default, supporting this requires changing | ||
# code generation to always expect null for a tagged records field. As of writing | ||
# there are no tagged records fields, or other occurrences where we would need such | ||
# implicit default, so this can be safely deferred. | ||
if issubclass(field_type, Records): | ||
raise NotImplementedError("Tagged record fields are not supported") | ||
|
||
try: | ||
# mypy has no way of typing a mapping as T -> T on a per-item level. | ||
return primitive_implicit_defaults[field_type] # type: ignore[return-value] | ||
except KeyError: | ||
return primitive_implicit_defaults[field_type.__bases__[0]] # type: ignore[return-value] | ||
|
||
|
||
U = TypeVar("U", bound=Entity) | ||
|
||
|
||
def get_tagged_field_default(field: Field[U]) -> U: | ||
if field.default is not MISSING: | ||
return field.default | ||
|
||
if is_optional(field): | ||
raise TypeError("Optional fields should have None as explicit default") | ||
|
||
field_class = classify_field(field) | ||
|
||
if isinstance(field_class, PrimitiveField): | ||
return get_implicit_default(field_class.type_) | ||
elif isinstance(field_class, EntityField): | ||
return field.type( | ||
**{ | ||
nested_field.name: get_tagged_field_default(nested_field) | ||
for nested_field in fields(field.type) | ||
} | ||
) | ||
elif isinstance(field_class, PrimitiveTupleField | EntityTupleField): | ||
raise TypeError("Tuple fields should have the empty tuple as explicit default") | ||
|
||
assert_never(field_class) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from dataclasses import dataclass | ||
from dataclasses import fields | ||
from datetime import timedelta | ||
from uuid import UUID | ||
|
||
import pytest | ||
|
||
from kio.schema.types import BrokerId | ||
from kio.schema.types import GroupId | ||
from kio.schema.types import ProducerId | ||
from kio.schema.types import TopicName | ||
from kio.schema.types import TransactionalId | ||
from kio.serial._implicit_defaults import get_implicit_default | ||
from kio.serial._implicit_defaults import get_tagged_field_default | ||
from kio.serial.readers import tz_aware_from_i64 | ||
from kio.static.primitive import Records | ||
from kio.static.primitive import TZAware | ||
from kio.static.primitive import f64 | ||
from kio.static.primitive import i8 | ||
from kio.static.primitive import i16 | ||
from kio.static.primitive import i32 | ||
from kio.static.primitive import i32Timedelta | ||
from kio.static.primitive import i64 | ||
from kio.static.primitive import i64Timedelta | ||
from kio.static.primitive import u8 | ||
from kio.static.primitive import u16 | ||
from kio.static.primitive import u32 | ||
from kio.static.primitive import u64 | ||
|
||
|
||
class TestGetImplicitDefault: | ||
@pytest.mark.parametrize( | ||
("annotation", "expected"), | ||
( | ||
(u8, 0), | ||
(u16, 0), | ||
(u32, 0), | ||
(u64, 0), | ||
(i8, 0), | ||
(i16, 0), | ||
(i32, 0), | ||
(i64, 0), | ||
(f64, 0.0), | ||
(i32Timedelta, timedelta(0)), | ||
(i64Timedelta, timedelta(0)), | ||
(TZAware, tz_aware_from_i64(i64(0))), | ||
(UUID, UUID(int=0)), | ||
(str, ""), | ||
(bytes, b""), | ||
(BrokerId, 0), | ||
(GroupId, ""), | ||
(ProducerId, 0), | ||
(TopicName, ""), | ||
(TransactionalId, ""), | ||
), | ||
) | ||
def test_returns_expected_value(self, annotation: type, expected: object) -> None: | ||
assert get_implicit_default(annotation) == expected | ||
|
||
def test_raises_not_implemented_error_for_records(self) -> None: | ||
with pytest.raises(NotImplementedError): | ||
get_implicit_default(Records) | ||
|
||
|
||
class TestGetTaggedFieldDefault: | ||
def test_raises_type_error_for_optional_field(self) -> None: | ||
@dataclass | ||
class A: | ||
a: u8 | None | ||
|
||
[field] = fields(A) | ||
|
||
with pytest.raises( | ||
TypeError, | ||
match=r"Optional fields should have None as explicit default", | ||
): | ||
get_tagged_field_default(field) | ||
|
||
def test_raises_type_error_for_tuple_field(self) -> None: | ||
@dataclass | ||
class A: | ||
a: tuple[u8, ...] | ||
|
||
[field] = fields(A) | ||
|
||
with pytest.raises( | ||
TypeError, | ||
match=r"Tuple fields should have the empty tuple as explicit default", | ||
): | ||
get_tagged_field_default(field) | ||
|
||
def test_can_get_default_for_primitive_field(self) -> None: | ||
@dataclass | ||
class A: | ||
a: u8 | ||
|
||
[field] = fields(A) | ||
assert get_tagged_field_default(field) == 0 | ||
|
||
def test_can_get_default_for_entity_field(self) -> None: | ||
@dataclass | ||
class A: | ||
a: u8 | ||
|
||
@dataclass | ||
class B: | ||
b: A | ||
|
||
[field] = fields(B) | ||
assert get_tagged_field_default(field) == A(a=u8(0)) | ||
|
||
def test_returns_explicit_default_if_defined(self) -> None: | ||
@dataclass | ||
class A: | ||
a: u8 = u8(1) | ||
|
||
[field] = fields(A) | ||
assert get_tagged_field_default(field) == u8(1) |
Oops, something went wrong.