diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 24544a971..ac674ef34 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -9,6 +9,7 @@ use crate::build_tools::py_schema_err; use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; use crate::definitions::DefinitionsBuilder; use crate::tools::{truncate_safe_repr, SchemaDict}; +use crate::PydanticSerializationUnexpectedValue; use super::{ infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck, @@ -89,7 +90,8 @@ fn to_python( } } - if retry_with_lax_check { + // If extra.check is SerCheck::Strict, we're in a nested union + if extra.check != SerCheck::Strict && retry_with_lax_check { new_extra.check = SerCheck::Lax; for comb_serializer in choices { if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -98,8 +100,17 @@ fn to_python( } } - for err in &errors { - extra.warnings.custom_warning(err.to_string()); + // If extra.check is SerCheck::None, we're in a top-level union. 
We should thus raise the warnings + if extra.check == SerCheck::None { + for err in &errors { + extra.warnings.custom_warning(err.to_string()); + } + } + // Otherwise, if we've encountered errors, return them to the parent union, which should take + // care of the formatting for us + else if !errors.is_empty() { + let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n"); + return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); } infer_to_python(value, include, exclude, extra) @@ -122,7 +133,8 @@ fn json_key<'a>( } } - if retry_with_lax_check { + // If extra.check is SerCheck::Strict, we're in a nested union + if extra.check != SerCheck::Strict && retry_with_lax_check { new_extra.check = SerCheck::Lax; for comb_serializer in choices { if let Ok(v) = comb_serializer.json_key(key, &new_extra) { @@ -131,10 +143,18 @@ fn json_key<'a>( } } - for err in &errors { - extra.warnings.custom_warning(err.to_string()); + // If extra.check is SerCheck::None, we're in a top-level union. 
We should thus raise the warnings + if extra.check == SerCheck::None { + for err in &errors { + extra.warnings.custom_warning(err.to_string()); + } + } + // Otherwise, if we've encountered errors, return them to the parent union, which should take + // care of the formatting for us + else if !errors.is_empty() { + let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n"); + return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); } - infer_json_key(key, extra) } @@ -160,7 +180,8 @@ fn serde_serialize( } } - if retry_with_lax_check { + // If extra.check is SerCheck::Strict, we're in a nested union + if extra.check != SerCheck::Strict && retry_with_lax_check { new_extra.check = SerCheck::Lax; for comb_serializer in choices { if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -169,8 +190,14 @@ fn serde_serialize( } } - for err in &errors { - extra.warnings.custom_warning(err.to_string()); + // If extra.check is SerCheck::None, we're in a top-level union. 
We should thus raise the warnings + if extra.check == SerCheck::None { + for err in &errors { + extra.warnings.custom_warning(err.to_string()); + } + } else { + // NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors + // will have to be returned here } infer_serialize(value, serializer, include, exclude, extra) diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index 8b6d6f128..66ec3b9b8 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import json import uuid @@ -778,3 +780,223 @@ class ModelB: model_b = ModelB(field=1) assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'} assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'} + + +class ModelDog: + def __init__(self, type_: Literal['dog']) -> None: + self.type_ = 'dog' + + +class ModelCat: + def __init__(self, type_: Literal['cat']) -> None: + self.type_ = 'cat' + + +class ModelAlien: + def __init__(self, type_: Literal['alien']) -> None: + self.type_ = 'alien' + + +@pytest.fixture +def model_a_b_union_schema() -> core_schema.UnionSchema: + return core_schema.union_schema( + [ + core_schema.model_schema( + cls=ModelA, + schema=core_schema.model_fields_schema( + fields={ + 'a': core_schema.model_field(core_schema.str_schema()), + 'b': core_schema.model_field(core_schema.str_schema()), + }, + ), + ), + core_schema.model_schema( + cls=ModelB, + schema=core_schema.model_fields_schema( + fields={ + 'c': core_schema.model_field(core_schema.str_schema()), + 'd': core_schema.model_field(core_schema.str_schema()), + }, + ), + ), + ] + ) + + +@pytest.fixture +def union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema: + return core_schema.union_schema( + [ + model_a_b_union_schema, + core_schema.union_schema( + [ + core_schema.model_schema( + cls=ModelCat, + schema=core_schema.model_fields_schema( + fields={ 
+ 'type_': core_schema.model_field(core_schema.literal_schema(['cat'])), + }, + ), + ), + core_schema.model_schema( + cls=ModelDog, + schema=core_schema.model_fields_schema( + fields={ + 'type_': core_schema.model_field(core_schema.literal_schema(['dog'])), + }, + ), + ), + ] + ), + ] + ) + + +@pytest.mark.parametrize( + 'input,expected', + [ + (ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}), + (ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}), + (ModelCat(type_='cat'), {'type_': 'cat'}), + (ModelDog(type_='dog'), {'type_': 'dog'}), + ], +) +def test_union_of_unions_of_models(union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any) -> None: + s = SchemaSerializer(union_of_unions_schema) + assert s.to_python(input, warnings='error') == expected + + +def test_union_of_unions_of_models_invalid_variant(union_of_unions_schema: core_schema.UnionSchema) -> None: + s = SchemaSerializer(union_of_unions_schema) + # All warnings should be available + messages = [ + 'Expected `ModelA` but got `ModelAlien`', + 'Expected `ModelB` but got `ModelAlien`', + 'Expected `ModelCat` but got `ModelAlien`', + 'Expected `ModelDog` but got `ModelAlien`', + ] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + s.to_python(ModelAlien(type_='alien')) + for m in messages: + assert m in str(w[0].message) + + +@pytest.fixture +def tagged_union_of_unions_schema(model_a_b_union_schema: core_schema.UnionSchema) -> core_schema.UnionSchema: + return core_schema.union_schema( + [ + model_a_b_union_schema, + core_schema.tagged_union_schema( + discriminator='type_', + choices={ + 'cat': core_schema.model_schema( + cls=ModelCat, + schema=core_schema.model_fields_schema( + fields={ + 'type_': core_schema.model_field(core_schema.literal_schema(['cat'])), + }, + ), + ), + 'dog': core_schema.model_schema( + cls=ModelDog, + schema=core_schema.model_fields_schema( + fields={ + 'type_': core_schema.model_field(core_schema.literal_schema(['dog'])), + }, + ), + ), + 
}, + ), + ] + ) + + +@pytest.mark.parametrize( + 'input,expected', + [ + (ModelA(a='a', b='b'), {'a': 'a', 'b': 'b'}), + (ModelB(c='c', d='d'), {'c': 'c', 'd': 'd'}), + (ModelCat(type_='cat'), {'type_': 'cat'}), + (ModelDog(type_='dog'), {'type_': 'dog'}), + ], +) +def test_union_of_unions_of_models_with_tagged_union( + tagged_union_of_unions_schema: core_schema.UnionSchema, input: Any, expected: Any +) -> None: + s = SchemaSerializer(tagged_union_of_unions_schema) + assert s.to_python(input, warnings='error') == expected + + +def test_union_of_unions_of_models_with_tagged_union_invalid_variant( + tagged_union_of_unions_schema: core_schema.UnionSchema, +) -> None: + s = SchemaSerializer(tagged_union_of_unions_schema) + # All warnings should be available + messages = [ + 'Expected `ModelA` but got `ModelAlien`', + 'Expected `ModelB` but got `ModelAlien`', + 'Expected `ModelCat` but got `ModelAlien`', + 'Expected `ModelDog` but got `ModelAlien`', + ] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + s.to_python(ModelAlien(type_='alien')) + for m in messages: + assert m in str(w[0].message) + + +@pytest.mark.parametrize( + 'input,expected', + [ + ({True: '1'}, b'{"true":"1"}'), + ({1: '1'}, b'{"1":"1"}'), + ({2.3: '1'}, b'{"2.3":"1"}'), + ({'a': 'b'}, b'{"a":"b"}'), + ], +) +def test_union_of_unions_of_models_with_tagged_union_json_key_serialization( + input: dict[bool | int | float | str, str], expected: bytes +) -> None: + s = SchemaSerializer( + core_schema.dict_schema( + keys_schema=core_schema.union_schema( + [ + core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]), + core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]), + ] + ), + values_schema=core_schema.str_schema(), + ) + ) + + assert s.to_json(input, warnings='error') == expected + + +@pytest.mark.parametrize( + 'input,expected', + [ + ({'key': True}, b'{"key":true}'), + ({'key': 1}, b'{"key":1}'), + ({'key': 2.3}, 
b'{"key":2.3}'), + ({'key': 'a'}, b'{"key":"a"}'), + ], +) +def test_union_of_unions_of_models_with_tagged_union_json_serialization( + input: dict[str, bool | int | float | str], expected: bytes +) -> None: + s = SchemaSerializer( + core_schema.dict_schema( + keys_schema=core_schema.str_schema(), + values_schema=core_schema.union_schema( + [ + core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]), + core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]), + ] + ), + ) + ) + + assert s.to_json(input, warnings='error') == expected