diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 8a1f041ae..e8384128e 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -1,3 +1,4 @@ +use pyo3::exceptions::PyAttributeError; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; @@ -12,6 +13,8 @@ use crate::serializers::shared::{BuildSerializer, CombinedSerializer, PydanticSe use crate::tools::SchemaDict; use super::errors::py_err_se_err; +use super::errors::PydanticSerializationError; +use super::ob_type::{ObType, ObTypeLookup}; use super::Extra; #[derive(Debug, Clone)] @@ -52,9 +55,10 @@ impl ComputedFields { // Do not serialize computed fields return Ok(()); } - for computed_fields in &self.0 { - computed_fields.to_python(model, output_dict, filter, include, exclude, extra)?; + for computed_field in &self.0 { + computed_field.to_python(model, output_dict, filter, include, exclude, extra)?; } + Ok(()) } @@ -102,6 +106,7 @@ struct ComputedField { serializer: CombinedSerializer, alias: String, alias_py: Py, + has_ser_func: bool, } impl ComputedField { @@ -123,6 +128,7 @@ impl ComputedField { serializer, alias: alias_py.extract()?, alias_py: alias_py.into_py(py), + has_ser_func: has_ser_function(return_schema), }) } @@ -139,8 +145,7 @@ impl ComputedField { let property_name_py = self.property_name_py.as_ref(py); if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? { - let next_value = model.getattr(property_name_py)?; - + let next_value = get_next_value(self, model, extra.ob_type_lookup)?; let value = self .serializer .to_python(next_value, next_include, next_exclude, extra)?; @@ -177,9 +182,8 @@ impl_py_gc_traverse!(ComputedFieldSerializer<'_> { computed_field }); impl<'py> Serialize for ComputedFieldSerializer<'py> { fn serialize(&self, serializer: S) -> Result { - let py = self.model.py(); - let property_name_py = self.computed_field.property_name_py.as_ref(py); - let next_value = self.model.getattr(property_name_py).map_err(py_err_se_err)?; + let next_value = + get_next_value(self.computed_field, self.model, self.extra.ob_type_lookup).map_err(py_err_se_err)?; let s = PydanticSerializer::new( next_value, &self.computed_field.serializer, @@ -190,3 +194,60 @@ impl<'py> Serialize for ComputedFieldSerializer<'py> { s.serialize(serializer) } } + +fn has_ser_function(schema: &PyDict) -> bool { + let py = schema.py(); + let ser_schema = schema + .get_as::<&PyDict>(intern!(py, "serialization")) + .unwrap_or_default(); + ser_schema.is_some_and(|s| s.contains(intern!(py, "function")).unwrap_or_default()) +} + +fn get_next_value<'a>( + field: &'a ComputedField, + input_value: &'a PyAny, + ob_type_lookup: &'a ObTypeLookup, +) -> PyResult<&'a PyAny> { + let py = input_value.py(); + // Backwards compatiability. + let mut legacy_attr_error: Option = None; + let legacy_result = match ob_type_lookup.get_type(input_value) { + ObType::Dataclass | ObType::PydanticSerializable => { + match input_value.getattr(field.property_name_py.as_ref(py)) { + Ok(attr) => Ok(Some(attr)), + Err(err) => { + if err.get_type(py).is_subclass_of::()? { + legacy_attr_error = Some(err); + Ok(None) + } else { + Err(err) + } + } + } + } + _ => Ok(None), + }; + match legacy_result { + Ok(opt) => { + if let Some(legacy_next_value) = opt { + return Ok(legacy_next_value); + } + } + Err(err) => return Err(err), + }; + + // Default behavior: If custom serialization function provided, compute value based on input. + if field.has_ser_func { + return Ok(input_value); + } + // Fallback behavior: Check if computed field is a property of input object + // (i.e. in some cases input_value can be ObType::Unknown) + if let Ok(next_value_from_input) = input_value.getattr(field.property_name_py.as_ref(py)) { + return Ok(next_value_from_input); + } + + Err(legacy_attr_error.unwrap_or(PydanticSerializationError::new_err(format!( + "No serialization function found for '{}'", + field.property_name + )))) +} diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index df507a248..e9f1fc1d9 100644 --- a/tests/serializers/test_typed_dict.py +++ b/tests/serializers/test_typed_dict.py @@ -1,11 +1,11 @@ import json -from typing import Any, Dict +from typing import Any, Dict, List import pytest from dirty_equals import IsStrictDict from typing_extensions import TypedDict -from pydantic_core import SchemaSerializer, core_schema +from pydantic_core import PydanticSerializationError, SchemaSerializer, core_schema @pytest.mark.parametrize('extra_behavior_kw', [{}, {'extra_behavior': 'ignore'}, {'extra_behavior': None}]) @@ -333,3 +333,109 @@ def test_extra_custom_serializer(): m = {'extra': 'extra'} assert s.to_python(m) == {'extra': 'extra bam!'} + + +def test_computed_fields_with_plain_serializer_function(): + def ser_x(v: dict): + two = v['0'] + v['1'] + 1 + return two + + schema = core_schema.typed_dict_schema( + { + '0': core_schema.typed_dict_field(core_schema.int_schema()), + '1': core_schema.typed_dict_field(core_schema.int_schema()), + }, + computed_fields=[ + core_schema.computed_field( + '2', core_schema.int_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_x)) + ) + ], + ) + s = SchemaSerializer(schema) + value = {'0': 0, '1': 1} + assert s.to_python(value) == {'0': 0, '1': 1, '2': 2} + assert s.to_json(value) == b'{"0":0,"1":1,"2":2}' + + def ser_foo(_v: dict): + return 'bar' + + schema = core_schema.typed_dict_schema( + {}, + computed_fields=[ + core_schema.computed_field( + 'foo', core_schema.str_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_foo)) + ) + ], + ) + s = SchemaSerializer(schema) + assert s.to_python({}) == {'foo': 'bar'} + assert s.to_json({}) == b'{"foo":"bar"}' + + +def test_computed_fields_with_warpped_serializer_function(): + def ser_to_upper(string_arr: List[str]) -> List[str]: + return [s.upper() for s in string_arr] + + def ser_columns(v: dict, serializer: core_schema.SerializerFunctionWrapHandler, _) -> str: + column_keys = serializer([key for key in v.keys()]) + return column_keys + + schema = core_schema.typed_dict_schema( + { + 'one': core_schema.typed_dict_field(core_schema.int_schema()), + 'two': core_schema.typed_dict_field(core_schema.int_schema()), + 'three': core_schema.typed_dict_field(core_schema.int_schema()), + }, + computed_fields=[ + core_schema.computed_field( + 'columns', + core_schema.int_schema( + serialization=core_schema.wrap_serializer_function_ser_schema( + ser_columns, + info_arg=True, + schema=core_schema.list_schema( + serialization=core_schema.plain_serializer_function_ser_schema(ser_to_upper) + ), + ) + ), + ) + ], + ) + s = SchemaSerializer(schema) + value = {'one': 1, 'two': 2, 'three': 3} + assert s.to_python(value) == {'one': 1, 'two': 2, 'three': 3, 'columns': ['ONE', 'TWO', 'THREE']} + assert s.to_json(value) == b'{"one":1,"two":2,"three":3,"columns":["ONE","TWO","THREE"]}' + + +def test_computed_fields_with_typed_dict_model(): + class Model(TypedDict): + x: int + + def ser_y(v: Any) -> str: + return f'{v["x"]}.00' + + s = SchemaSerializer( + core_schema.typed_dict_schema( + {'x': core_schema.typed_dict_field(core_schema.int_schema())}, + computed_fields=[ + core_schema.computed_field( + 'y', core_schema.str_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_y)) + ) + ], + ) + ) + assert s.to_python(Model(x=1000)) == {'x': 1000, 'y': '1000.00'} + + +def test_computed_fields_without_ser_function(): + class Model(TypedDict): + x: int + + s = SchemaSerializer( + core_schema.typed_dict_schema( + {'x': core_schema.typed_dict_field(core_schema.int_schema())}, + computed_fields=[core_schema.computed_field('y', core_schema.str_schema())], + ) + ) + with pytest.raises(PydanticSerializationError, match="^No serialization function found for 'y'$"): + s.to_python(Model(x=1000))