diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 72cbc34d4..ac674ef34 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -180,7 +180,8 @@ fn serde_serialize( } } - if retry_with_lax_check { + // If extra.check is SerCheck::Strict, we're in a nested union + if extra.check != SerCheck::Strict && retry_with_lax_check { new_extra.check = SerCheck::Lax; for comb_serializer in choices { if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -189,8 +190,14 @@ fn serde_serialize( } } - for err in &errors { - extra.warnings.custom_warning(err.to_string()); + // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings + if extra.check == SerCheck::None { + for err in &errors { + extra.warnings.custom_warning(err.to_string()); + } + } else { + // NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors + // will have to be returned here } infer_serialize(value, serializer, include, exclude, extra) diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index 20ac832a7..66ec3b9b8 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -958,7 +958,7 @@ def test_union_of_unions_of_models_with_tagged_union_invalid_variant( ], ) def test_union_of_unions_of_models_with_tagged_union_json_key_serialization( - input: bool | int | float | str, expected: bytes + input: dict[bool | int | float | str, str], expected: bytes ) -> None: s = SchemaSerializer( core_schema.dict_schema( @@ -973,3 +973,30 @@ def test_union_of_unions_of_models_with_tagged_union_json_key_serialization( ) assert s.to_json(input, warnings='error') == expected + + +@pytest.mark.parametrize( + 'input,expected', + [ + ({'key': True}, b'{"key":true}'), + ({'key': 1}, b'{"key":1}'), + ({'key': 2.3}, b'{"key":2.3}'), + ({'key': 'a'}, b'{"key":"a"}'), + ], +) +def test_union_of_unions_of_models_with_tagged_union_json_serialization( + input: dict[str, bool | int | float | str], expected: bytes +) -> None: + s = SchemaSerializer( + core_schema.dict_schema( + keys_schema=core_schema.str_schema(), + values_schema=core_schema.union_schema( + [ + core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]), + core_schema.union_schema([core_schema.float_schema(), core_schema.str_schema()]), + ] + ), + ) + ) + + assert s.to_json(input, warnings='error') == expected