Skip to content

Commit

Permalink
fix(union_serializer): do not raise warnings in nested unions
Browse files Browse the repository at this point in the history
In case unions of unions are used, this will bubble-up the errors rather
than warning immediately. If no solution is found among all serializers
by the top-level union, it will warn as before.

Signed-off-by: Luka Peschke <mail@lukapeschke.com>
  • Loading branch information
lukapeschke committed Nov 6, 2024
1 parent 4cb82bf commit 043fce1
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 7 deletions.
34 changes: 27 additions & 7 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::build_tools::py_schema_err;
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
use crate::definitions::DefinitionsBuilder;
use crate::tools::{truncate_safe_repr, SchemaDict};
use crate::PydanticSerializationUnexpectedValue;

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
Expand Down Expand Up @@ -89,7 +90,8 @@ fn to_python(
}
}

if retry_with_lax_check {
// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
Expand All @@ -98,8 +100,17 @@ fn to_python(
}
}

for err in &errors {
extra.warnings.custom_warning(err.to_string());
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
}
// Otherwise, if we've encountered errors, return them to the parent union, which should take
// care of the formatting for us
else if !errors.is_empty() {
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}

infer_to_python(value, include, exclude, extra)
Expand All @@ -122,7 +133,8 @@ fn json_key<'a>(
}
}

if retry_with_lax_check {
// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.json_key(key, &new_extra) {
Expand All @@ -131,10 +143,18 @@ fn json_key<'a>(
}
}

for err in &errors {
extra.warnings.custom_warning(err.to_string());
// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
}
// Otherwise, if we've encountered errors, return them to the parent union, which should take
// care of the formatting for us
else if !errors.is_empty() {
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}

infer_json_key(key, extra)
}

Expand Down
225 changes: 225 additions & 0 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,228 @@ class ModelB:
model_b = ModelB(field=1)
assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}


class ModelDog:
def __init__(self, type_: Literal['dog']) -> None:
self.type_ = 'dog'


class ModelCat:
def __init__(self, type_: Literal['cat']) -> None:
self.type_ = 'cat'


class ModelAlien:
def __init__(self, type_: Literal['alien']) -> None:
self.type_ = 'alien'


@pytest.fixture
def model_a_b_union_schema() -> core_schema.UnionSchema:
return core_schema.union_schema(
[
core_schema.model_schema(
cls=ModelA,
schema=core_schema.model_fields_schema(
fields={
'a': core_schema.model_field(core_schema.str_schema()),
'b': core_schema.model_field(core_schema.str_schema()),
},
),
),
core_schema.model_schema(
cls=ModelB,
schema=core_schema.model_fields_schema(
fields={
'c': core_schema.model_field(core_schema.str_schema()),
'd': core_schema.model_field(core_schema.str_schema()),
},
),
),
]
)


@pytest.fixture
def union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema:
return core_schema.union_schema(
[
model_a_b_union_schema,
core_schema.union_schema(
[
core_schema.model_schema(
cls=ModelCat,
schema=core_schema.model_fields_schema(
fields={
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
},
),
),
core_schema.model_schema(
cls=ModelDog,
schema=core_schema.model_fields_schema(
fields={
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
},
),
),
]
),
]
)


@pytest.mark.parametrize(
'input,expected',
[
(ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}),
(ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}),
(ModelCat(type_='cat'), {'type_': 'cat'}),
(ModelDog(type_='dog'), {'type_': 'dog'}),
],
)
def test_union_of_unions_of_models(union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any) -> None:
s = SchemaSerializer(union_of_unions_schema)
assert s.to_python(input, warnings='error') == expected


def test_union_of_unions_of_models_invalid_variant(union_of_unions_schema: core_schema.UnionSchema) -> None:
s = SchemaSerializer(union_of_unions_schema)
# All warnings should be available
messages = [
'Expected `ModelA` but got `ModelAlien`',
'Expected `ModelB` but got `ModelAlien`',
'Expected `ModelCat` but got `ModelAlien`',
'Expected `ModelDog` but got `ModelAlien`',
]

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
s.to_python(ModelAlien(type_='alien'))
for m in messages:
assert m in str(w[0].message)


@pytest.fixture
def tagged_union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema:
return core_schema.union_schema(
[
model_a_b_union_schema,
core_schema.tagged_union_schema(
discriminator='type_',
choices={
'cat': core_schema.model_schema(
cls=ModelCat,
schema=core_schema.model_fields_schema(
fields={
'type_': core_schema.model_field(core_schema.literal_schema(['cat'])),
},
),
),
'dog': core_schema.model_schema(
cls=ModelDog,
schema=core_schema.model_fields_schema(
fields={
'type_': core_schema.model_field(core_schema.literal_schema(['dog'])),
},
),
),
},
),
]
)


@pytest.mark.parametrize(
'input,expected',
[
(ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}),
(ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}),
(ModelCat(type_='cat'), {'type_': 'cat'}),
(ModelDog(type_='dog'), {'type_': 'dog'}),
],
)
def test_union_of_unions_of_models_with_tagged_union(
tagged_union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any
) -> None:
s = SchemaSerializer(tagged_union_of_unions_schema)
assert s.to_python(input, warnings='error') == expected


def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
tagged_union_of_unions_schema: core_schema.UnionSchema,
) -> None:
s = SchemaSerializer(tagged_union_of_unions_schema)
# All warnings should be available
messages = [
'Expected `ModelA` but got `ModelAlien`',
'Expected `ModelB` but got `ModelAlien`',
'Expected `ModelCat` but got `ModelAlien`',
'Expected `ModelDog` but got `ModelAlien`',
]

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
s.to_python(ModelAlien(type_='alien'))
for m in messages:
assert m in str(w[0].message)


@dataclasses.dataclass(frozen=True)
class DataClassA:
a: str


@dataclasses.dataclass(frozen=True)
class DataClassB:
b: str


@pytest.mark.parametrize(
'input,expected',
[
({True: '1'}, b'{"true":"1"}'),
({1: '1'}, b'{"1":"1"}'),
({2.3: '1'}, b'{"2.3":"1"}'),
({'a': 'b'}, b'{"a":"b"}'),
],
)
def test_union_of_unions_of_models_with_tagged_union_json_key_serialization(
input: bool | int | float | str, expected: bytes
) -> None:
s = SchemaSerializer(
core_schema.dict_schema(
keys_schema=core_schema.union_schema(
[
core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]),
core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]),
]
),
values_schema=core_schema.str_schema(),
)
)

assert s.to_json(input, warnings='error') == expected


def test_union_of_unions_of_models_with_tagged_union_json_serialization_invalid_variant(
tagged_union_of_unions_schema: core_schema.UnionSchema,
) -> None:
s = SchemaSerializer(
core_schema.dict_schema(keys_schema=tagged_union_of_unions_schema, values_schema=core_schema.str_schema())
)

# All warnings should be available
messages = [
'Expected `ModelA` but got `ModelAlien`',
'Expected `ModelB` but got `ModelAlien`',
'Expected `ModelCat` but got `ModelAlien`',
'Expected `ModelDog` but got `ModelAlien`',
]

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
s.to_python({ModelAlien(type_='alien'): 'coucou'})
for m in messages:
assert m in str(w[0].message)

0 comments on commit 043fce1

Please sign in to comment.