Skip to content

Commit

Permalink
feat: apply same logic to serde_serialize and add non-regression test
Browse files Browse the repository at this point in the history
Signed-off-by: Luka Peschke <mail@lukapeschke.com>
  • Loading branch information
lukapeschke committed Nov 7, 2024
1 parent 5c38506 commit 5506e1f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
13 changes: 10 additions & 3 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ fn serde_serialize<S: serde::ser::Serializer>(
}
}

if retry_with_lax_check {
// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
Expand All @@ -189,8 +190,14 @@ fn serde_serialize<S: serde::ser::Serializer>(
}
}

for err in &errors {
extra.warnings.custom_warning(err.to_string());
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
} else {
// NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
// will have to be returned here
}

infer_serialize(value, serializer, include, exclude, extra)
Expand Down
29 changes: 28 additions & 1 deletion tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
],
)
def test_union_of_unions_of_models_with_tagged_union_json_key_serialization(
input: bool | int | float | str, expected: bytes
input: dict[bool | int | float | str, str], expected: bytes
) -> None:
s = SchemaSerializer(
core_schema.dict_schema(
Expand All @@ -973,3 +973,30 @@ def test_union_of_unions_of_models_with_tagged_union_json_key_serialization(
)

assert s.to_json(input, warnings='error') == expected


@pytest.mark.parametrize(
'input,expected',
[
({'key': True}, b'{"key":true}'),
({'key': 1}, b'{"key":1}'),
({'key': 2.3}, b'{"key":2.3}'),
({'key': 'a'}, b'{"key":"a"}'),
],
)
def test_union_of_unions_of_models_with_tagged_union_json_serialization(
input: dict[str, bool | int | float | str], expected: bytes
) -> None:
s = SchemaSerializer(
core_schema.dict_schema(
keys_schema=core_schema.str_schema(),
values_schema=core_schema.union_schema(
[
core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]),
core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]),
]
),
)
)

assert s.to_json(input, warnings='error') == expected

0 comments on commit 5506e1f

Please sign in to comment.