From 72b1c8e7261a08be044ab11b1aeeb950394904e0 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 4 Nov 2024 16:44:00 +0000 Subject: [PATCH 1/2] add test for wrap serializer in function --- tests/serializers/test_union.py | 139 +++++++++++++++++++++----------- 1 file changed, 93 insertions(+), 46 deletions(-) diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index 8b6d6f128..e1295a085 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -60,53 +60,36 @@ def __init__(self, c, d): @pytest.fixture(scope='module') def model_serializer() -> SchemaSerializer: return SchemaSerializer( - { - 'type': 'union', - 'choices': [ - { - 'type': 'model', - 'cls': ModelA, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'a': {'type': 'model-field', 'schema': {'type': 'bytes'}}, - 'b': { - 'type': 'model-field', - 'schema': { - 'type': 'float', - 'serialization': { - 'type': 'format', - 'formatting_string': '0.1f', - 'when_used': 'unless-none', - }, - }, - }, - }, - }, - }, - { - 'type': 'model', - 'cls': ModelB, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'c': {'type': 'model-field', 'schema': {'type': 'bytes'}}, - 'd': { - 'type': 'model-field', - 'schema': { - 'type': 'float', - 'serialization': { - 'type': 'format', - 'formatting_string': '0.2f', - 'when_used': 'unless-none', - }, - }, - }, - }, - }, - }, + core_schema.union_schema( + [ + core_schema.model_schema( + ModelA, + core_schema.model_fields_schema( + { + 'a': core_schema.model_field(core_schema.bytes_schema()), + 'b': core_schema.model_field( + core_schema.float_schema( + serialization=core_schema.format_ser_schema('0.1f', when_used='unless-none') + ) + ), + } + ), + ), + core_schema.model_schema( + ModelB, + core_schema.model_fields_schema( + { + 'c': core_schema.model_field(core_schema.bytes_schema()), + 'd': core_schema.model_field( + core_schema.float_schema( + serialization=core_schema.format_ser_schema('0.2f', when_used='unless-none') + ) + ), + } + ), + ), ], - } + ) ) @@ -778,3 +761,67 @@ class ModelB: model_b = ModelB(field=1) assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'} assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'} + + +def test_union_model_wrap_serializer(): + def wrap_serializer(value, handler): + return handler(value) + + class Data: + pass + + class ModelA: + a: Data + + class ModelB: + a: Data + + model_serializer = SchemaSerializer( + core_schema.union_schema( + [ + core_schema.model_schema( + ModelA, + core_schema.model_fields_schema( + { + 'a': core_schema.model_field( + core_schema.model_schema( + Data, + core_schema.model_fields_schema({}), + ) + ), + }, + ), + serialization=core_schema.wrap_serializer_function_ser_schema(wrap_serializer), + ), + core_schema.model_schema( + ModelB, + core_schema.model_fields_schema( + { + 'a': core_schema.model_field( + core_schema.model_schema( + Data, + core_schema.model_fields_schema({}), + ) + ), + }, + ), + serialization=core_schema.wrap_serializer_function_ser_schema(wrap_serializer), + ), + ], + ) + ) + + input_value = ModelA() + input_value.a = Data() + + assert model_serializer.to_python(input_value) == {'a': {}} + assert model_serializer.to_python(input_value, mode='json') == {'a': {}} + assert model_serializer.to_json(input_value) == b'{"a":{}}' + + # add some additional attribute, should be ignored & not break serialization + + input_value.a._a = 'foo' + + assert model_serializer.to_python(input_value) == {'a': {}} + assert model_serializer.to_python(input_value, mode='json') == {'a': {}} + assert model_serializer.to_json(input_value) == b'{"a":{}}' From e375b85c31aba6c460f0d2417e7c7096243a2ae7 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 4 Nov 2024 17:00:51 +0000 Subject: [PATCH 2/2] fix union wrap combo --- src/serializers/extra.rs | 4 ++++ src/serializers/fields.rs | 23 ++++++++++---------- src/serializers/type_serializers/function.rs | 15 +++++++++++++ 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index ba47445b5..28aeea133 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -198,6 +198,10 @@ impl<'a> Extra<'a> { pub fn serialize_infer<'py>(&'py self, value: &'py Bound<'py, PyAny>) -> super::infer::SerializeInfer<'py> { super::infer::SerializeInfer::new(value, None, None, self) } + + pub(crate) fn model_type_name(&self) -> Option> { + self.model.and_then(|model| model.get_type().name().ok()) + } } #[derive(Clone, Copy, PartialEq, Eq)] diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index 6cd76c36b..4498d8fa7 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -200,7 +200,15 @@ impl GeneralFieldsSerializer { }; output_dict.set_item(key, value)?; } else if field_extra.check == SerCheck::Strict { - return Err(PydanticSerializationUnexpectedValue::new_err(None)); + let type_name = field_extra.model_type_name(); + return Err(PydanticSerializationUnexpectedValue::new_err(Some(format!( + "Unexpected field `{key}`{for_type_name}", + for_type_name = if let Some(type_name) = type_name { + format!(" for type `{type_name}`") + } else { + String::new() + }, + )))); } } } @@ -212,22 +220,15 @@ impl GeneralFieldsSerializer { && self.required_fields > used_req_fields { let required_fields = self.required_fields; - let type_name = match extra.model { - Some(model) => model - .get_type() - .qualname() - .ok() - .unwrap_or_else(|| PyString::new_bound(py, "")) - .to_string(), - None => "".to_string(), - }; + let type_name = extra.model_type_name(); let field_value = match extra.model { Some(model) => truncate_safe_repr(model, Some(100)), None => "".to_string(), }; Err(PydanticSerializationUnexpectedValue::new_err(Some(format!( - "Expected {required_fields} fields but got {used_req_fields} for type `{type_name}` with value `{field_value}` - serialized value may not be as expected." + "Expected {required_fields} fields but got {used_req_fields}{for_type_name} with value `{field_value}` - serialized value may not be as expected.", + for_type_name = if let Some(type_name) = type_name { format!(" for type `{type_name}`") } else { String::new() }, )))) } else { Ok(output_dict) diff --git a/src/serializers/type_serializers/function.rs b/src/serializers/type_serializers/function.rs index d0fea665f..f01b0c556 100644 --- a/src/serializers/type_serializers/function.rs +++ b/src/serializers/type_serializers/function.rs @@ -179,6 +179,13 @@ impl FunctionPlainSerializer { .expect("fallback_serializer unexpectedly none") .as_ref() } + + fn retry_with_lax_check(&self) -> bool { + self.fallback_serializer + .as_ref() + .map_or(false, |f| f.retry_with_lax_check()) + || self.return_serializer.retry_with_lax_check() + } } fn on_error(py: Python, err: PyErr, function_name: &str, extra: &Extra) -> PyResult<()> { @@ -271,6 +278,10 @@ macro_rules! function_type_serializer { fn get_name(&self) -> &str { &self.name } + + fn retry_with_lax_check(&self) -> bool { + self.retry_with_lax_check() + } } }; } @@ -409,6 +420,10 @@ impl FunctionWrapSerializer { fn get_fallback_serializer(&self) -> &CombinedSerializer { self.serializer.as_ref() } + + fn retry_with_lax_check(&self) -> bool { + self.serializer.retry_with_lax_check() || self.return_serializer.retry_with_lax_check() + } } impl_py_gc_traverse!(FunctionWrapSerializer {