From 0929f6222ad079112ff02aee9bb81432f25fb037 Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Sun, 15 Oct 2023 05:31:08 -0400 Subject: [PATCH 01/12] feat: Serialize computed field without a model --- src/serializers/computed_fields.rs | 31 ++++++++++-- tests/serializers/test_typed_dict.py | 71 +++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 5 deletions(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 8a1f041ae..8714dde2c 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -52,9 +52,10 @@ impl ComputedFields { // Do not serialize computed fields return Ok(()); } - for computed_fields in &self.0 { - computed_fields.to_python(model, output_dict, filter, include, exclude, extra)?; + for computed_field in &self.0 { + computed_field.to_python(model, output_dict, filter, include, exclude, extra)?; } + Ok(()) } @@ -102,6 +103,7 @@ struct ComputedField { serializer: CombinedSerializer, alias: String, alias_py: Py, + has_ser_func: bool, } impl ComputedField { @@ -123,6 +125,7 @@ impl ComputedField { serializer, alias: alias_py.extract()?, alias_py: alias_py.into_py(py), + has_ser_func: has_ser_function(return_schema), }) } @@ -139,7 +142,11 @@ impl ComputedField { let property_name_py = self.property_name_py.as_ref(py); if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? { - let next_value = model.getattr(property_name_py)?; + let next_value = if self.has_ser_func { + model + } else { + model.getattr(property_name_py)? + }; let value = self .serializer @@ -179,7 +186,11 @@ impl<'py> Serialize for ComputedFieldSerializer<'py> { fn serialize(&self, serializer: S) -> Result { let py = self.model.py(); let property_name_py = self.computed_field.property_name_py.as_ref(py); - let next_value = self.model.getattr(property_name_py).map_err(py_err_se_err)?; + let next_value = if self.computed_field.has_ser_func { + self.model + } else { + self.model.getattr(property_name_py).map_err(py_err_se_err)? + }; let s = PydanticSerializer::new( next_value, &self.computed_field.serializer, @@ -190,3 +201,15 @@ impl<'py> Serialize for ComputedFieldSerializer<'py> { s.serialize(serializer) } } + +fn has_ser_function(schema: &PyDict) -> bool { + let py = schema.py(); + let ser_schema = schema + .get_as::<&PyDict>(intern!(py, "serialization")) + .unwrap_or_default(); + + match ser_schema { + Some(s) => s.contains(intern!(py, "function")).unwrap_or_default(), + None => false, + } +} diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index df507a248..e9a220f64 100644 --- a/tests/serializers/test_typed_dict.py +++ b/tests/serializers/test_typed_dict.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict +from typing import Any, Dict, List import pytest from dirty_equals import IsStrictDict @@ -333,3 +333,72 @@ def test_extra_custom_serializer(): m = {'extra': 'extra'} assert s.to_python(m) == {'extra': 'extra bam!'} + + +def test_computed_fields_with_plain_serializer_function(): + def ser_x(v: dict): + two = v['0'] + v['1'] + 1 + return two + + schema = core_schema.typed_dict_schema( + { + '0': core_schema.typed_dict_field(core_schema.int_schema()), + '1': core_schema.typed_dict_field(core_schema.int_schema()), + }, + computed_fields=[ + core_schema.computed_field( + '2', core_schema.int_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_x)) + ) + ], + ) + s = SchemaSerializer(schema) + value = {'0': 0, '1': 1} + assert s.to_python(value) == {'0': 0, '1': 1, '2': 2} + + def ser_foo(_v: dict): + return 'bar' + + schema = core_schema.typed_dict_schema( + {}, + computed_fields=[ + core_schema.computed_field( + 'foo', core_schema.str_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_foo)) + ) + ], + ) + s = SchemaSerializer(schema) + assert s.to_python({}) == {'foo': 'bar'} + + +def test_computed_fields_with_warpped_serializer_function(): + def ser_to_upper(string_arr: List[str]) -> List[str]: + return [s.upper() for s in string_arr] + + def ser_columns(v: dict, serializer: core_schema.SerializerFunctionWrapHandler, _) -> str: + column_keys = serializer([key for key in v.keys()]) + return column_keys + + schema = core_schema.typed_dict_schema( + { + 'one': core_schema.typed_dict_field(core_schema.int_schema()), + 'two': core_schema.typed_dict_field(core_schema.int_schema()), + 'three': core_schema.typed_dict_field(core_schema.int_schema()), + }, + computed_fields=[ + core_schema.computed_field( + 'columns', + core_schema.int_schema( + serialization=core_schema.wrap_serializer_function_ser_schema( + ser_columns, + info_arg=True, + schema=core_schema.list_schema( + serialization=core_schema.plain_serializer_function_ser_schema(ser_to_upper) + ), + ) + ), + ) + ], + ) + s = SchemaSerializer(schema) + value = {'one': 1, 'two': 2, 'three': 3} + assert s.to_python(value) == {'one': 1, 'two': 2, 'three': 3, 'columns': ['ONE', 'TWO', 'THREE']} From a1bba7481420a70cc644ed4d3da3d2829f7252ae Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Sun, 15 Oct 2023 06:14:34 -0400 Subject: [PATCH 02/12] Check if model is a dict --- src/serializers/computed_fields.rs | 36 ++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 8714dde2c..7c06a7da9 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -12,6 +12,8 @@ use crate::serializers::shared::{BuildSerializer, CombinedSerializer, PydanticSe use crate::tools::SchemaDict; use super::errors::py_err_se_err; +use super::errors::PydanticSerializationError; +use super::ob_type::{IsType, ObType}; use super::Extra; #[derive(Debug, Clone)] @@ -142,10 +144,17 @@ impl ComputedField { let property_name_py = self.property_name_py.as_ref(py); if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? { - let next_value = if self.has_ser_func { - model - } else { - model.getattr(property_name_py)? + let next_value = match extra.ob_type_lookup.is_type(model, ObType::Dict) { + IsType::Exact => { + if !self.has_ser_func { + return Err(PydanticSerializationError::new_err(format!( + "no serialization function found for {}", + self.property_name + ))); + } + model + } + _ => model.getattr(property_name_py)?, }; let value = self @@ -186,11 +195,20 @@ impl<'py> Serialize for ComputedFieldSerializer<'py> { fn serialize(&self, serializer: S) -> Result { let py = self.model.py(); let property_name_py = self.computed_field.property_name_py.as_ref(py); - let next_value = if self.computed_field.has_ser_func { - self.model - } else { - self.model.getattr(property_name_py).map_err(py_err_se_err)? - }; + let next_value = match self.extra.ob_type_lookup.is_type(self.model, ObType::Dict) { + IsType::Exact => { + if self.computed_field.has_ser_func { + Ok(self.model) + } else { + Err(PydanticSerializationError::new_err(format!( + "no serialization function found for {}", + self.computed_field.property_name + ))) + } + } + _ => self.model.getattr(property_name_py), + } + .map_err(py_err_se_err)?; let s = PydanticSerializer::new( next_value, &self.computed_field.serializer, From 4d830d7fddded7531364ba77bde6e7117189ca35 Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Sun, 15 Oct 2023 06:46:01 -0400 Subject: [PATCH 03/12] clean up --- src/serializers/computed_fields.rs | 55 ++++++++++++++---------------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 7c06a7da9..40b2960b3 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -13,7 +13,7 @@ use crate::tools::SchemaDict; use super::errors::py_err_se_err; use super::errors::PydanticSerializationError; -use super::ob_type::{IsType, ObType}; +use super::ob_type::{ObType, ObTypeLookup}; use super::Extra; #[derive(Debug, Clone)] @@ -144,19 +144,7 @@ impl ComputedField { let property_name_py = self.property_name_py.as_ref(py); if let Some((next_include, next_exclude)) = filter.key_filter(property_name_py, include, exclude)? { - let next_value = match extra.ob_type_lookup.is_type(model, ObType::Dict) { - IsType::Exact => { - if !self.has_ser_func { - return Err(PydanticSerializationError::new_err(format!( - "no serialization function found for {}", - self.property_name - ))); - } - model - } - _ => model.getattr(property_name_py)?, - }; - + let next_value = get_next_value(self, model, extra.ob_type_lookup)?; let value = self .serializer .to_python(next_value, next_include, next_exclude, extra)?; @@ -193,22 +181,8 @@ impl_py_gc_traverse!(ComputedFieldSerializer<'_> { computed_field }); impl<'py> Serialize for ComputedFieldSerializer<'py> { fn serialize(&self, serializer: S) -> Result { - let py = self.model.py(); - let property_name_py = self.computed_field.property_name_py.as_ref(py); - let next_value = match self.extra.ob_type_lookup.is_type(self.model, ObType::Dict) { - IsType::Exact => { - if self.computed_field.has_ser_func { - Ok(self.model) - } else { - Err(PydanticSerializationError::new_err(format!( - "no serialization function found for {}", - self.computed_field.property_name - ))) - } - } - _ => self.model.getattr(property_name_py), - } - .map_err(py_err_se_err)?; + let next_value = + get_next_value(self.computed_field, self.model, self.extra.ob_type_lookup).map_err(py_err_se_err)?; let s = PydanticSerializer::new( next_value, &self.computed_field.serializer, @@ -231,3 +205,24 @@ fn has_ser_function(schema: &PyDict) -> bool { None => false, } } + +fn get_next_value<'a>( + field: &'a ComputedField, + input_value: &'a PyAny, + ob_type_lookup: &'a ObTypeLookup, +) -> PyResult<&'a PyAny> { + let next_value = match ob_type_lookup.get_type(input_value) { + ObType::Unknown | ObType::Dataclass => input_value.getattr(field.property_name_py.as_ref(input_value.py())), + _ => { + if field.has_ser_func { + Ok(input_value) + } else { + Err(PydanticSerializationError::new_err(format!( + "no serialization function found for {}", + field.property_name + ))) + } + } + }; + next_value +} From 264e6acd6240f0b777ac8dd634b62b05baf798ab Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Sun, 15 Oct 2023 07:16:03 -0400 Subject: [PATCH 04/12] Fix --- src/serializers/computed_fields.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 40b2960b3..e1021d839 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -211,8 +211,10 @@ fn get_next_value<'a>( input_value: &'a PyAny, ob_type_lookup: &'a ObTypeLookup, ) -> PyResult<&'a PyAny> { - let next_value = match ob_type_lookup.get_type(input_value) { - ObType::Unknown | ObType::Dataclass => input_value.getattr(field.property_name_py.as_ref(input_value.py())), + match ob_type_lookup.get_type(input_value) { + ObType::Dataclass | ObType::PydanticSerializable | ObType::Unknown => { + input_value.getattr(field.property_name_py.as_ref(input_value.py())) + } _ => { if field.has_ser_func { Ok(input_value) @@ -223,6 +225,5 @@ fn get_next_value<'a>( ))) } } - }; - next_value + } } From 747dd3e3e114b3ec5bc6799a90b83b04c34d37ef Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Sun, 15 Oct 2023 07:56:25 -0400 Subject: [PATCH 05/12] Fallback behavior for ObType::Unknown --- src/serializers/computed_fields.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index e1021d839..2e1805994 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -211,14 +211,16 @@ fn get_next_value<'a>( input_value: &'a PyAny, ob_type_lookup: &'a ObTypeLookup, ) -> PyResult<&'a PyAny> { + let property_name = field.property_name_py.as_ref(input_value.py()); match ob_type_lookup.get_type(input_value) { - ObType::Dataclass | ObType::PydanticSerializable | ObType::Unknown => { - input_value.getattr(field.property_name_py.as_ref(input_value.py())) - } + ObType::Dataclass | ObType::PydanticSerializable => input_value.getattr(property_name), _ => { if field.has_ser_func { Ok(input_value) } else { + if input_value.hasattr(property_name).unwrap_or_default() { + return input_value.getattr(property_name); + } Err(PydanticSerializationError::new_err(format!( "no serialization function found for {}", field.property_name From d9951880be0b98359ecc59228fc0f838224e89e8 Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Sun, 15 Oct 2023 08:58:34 -0400 Subject: [PATCH 06/12] Update tests --- tests/serializers/test_model.py | 6 +++--- tests/serializers/test_typed_dict.py | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/serializers/test_model.py b/tests/serializers/test_model.py index 32ecd3c1a..cc01c92e0 100644 --- a/tests/serializers/test_model.py +++ b/tests/serializers/test_model.py @@ -690,12 +690,12 @@ class Model: ), ) ) - with pytest.raises(AttributeError, match="^'Model' object has no attribute 'area'$"): + with pytest.raises(PydanticSerializationError, match="^No serialization function found for 'area'$"): s.to_python(Model(3)) - with pytest.raises(AttributeError, match="^'Model' object has no attribute 'area'$"): + with pytest.raises(PydanticSerializationError, match="^No serialization function found for 'area'$"): s.to_python(Model(3), mode='json') - e = "^Error serializing to JSON: AttributeError: 'Model' object has no attribute 'area'$" + e = "^Error serializing to JSON: PydanticSerializationError: No serialization function found for 'area'$" with pytest.raises(PydanticSerializationError, match=e): s.to_json(Model(3)) diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index e9a220f64..c89cdc240 100644 --- a/tests/serializers/test_typed_dict.py +++ b/tests/serializers/test_typed_dict.py @@ -354,6 +354,7 @@ def ser_x(v: dict): s = SchemaSerializer(schema) value = {'0': 0, '1': 1} assert s.to_python(value) == {'0': 0, '1': 1, '2': 2} + assert s.to_json(value) == b'{"0":0,"1":1,"2":2}' def ser_foo(_v: dict): return 'bar' @@ -368,6 +369,7 @@ def ser_foo(_v: dict): ) s = SchemaSerializer(schema) assert s.to_python({}) == {'foo': 'bar'} + assert s.to_json({}) == b'{"foo":"bar"}' def test_computed_fields_with_warpped_serializer_function(): @@ -402,3 +404,4 @@ def ser_columns(v: dict, serializer: core_schema.SerializerFunctionWrapHandler, s = SchemaSerializer(schema) value = {'one': 1, 'two': 2, 'three': 3} assert s.to_python(value) == {'one': 1, 'two': 2, 'three': 3, 'columns': ['ONE', 'TWO', 'THREE']} + assert s.to_json(value) == b'{"one":1,"two":2,"three":3,"columns":["ONE","TWO","THREE"]}' From fd5521cb384c7af2daa7637d929f2529a50c514c Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Sun, 15 Oct 2023 08:59:33 -0400 Subject: [PATCH 07/12] Refactor get_next_value --- src/serializers/computed_fields.rs | 51 +++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 2e1805994..da1b9b057 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -1,3 +1,4 @@ +use pyo3::exceptions::PyAttributeError; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; @@ -211,20 +212,46 @@ fn get_next_value<'a>( input_value: &'a PyAny, ob_type_lookup: &'a ObTypeLookup, ) -> PyResult<&'a PyAny> { + // Backwards compatiability. + let legacy_result = match ob_type_lookup.get_type(input_value) { + ObType::Dataclass | ObType::PydanticSerializable | ObType::Unknown => { + py_get_attrs(input_value, field.property_name_py.as_ref(input_value.py())) + } + _ => Ok(None), + }; + let legacy_next_value = match legacy_result { + Ok(res) => res, + Err(err) => return Err(err), + }; + if legacy_next_value.is_some() { + return Ok(legacy_next_value.unwrap()); + } + + // Default behavior: If custom serialization function provided, compute value based on input. + if field.has_ser_func { + return Ok(input_value); + } + + // Fallback behavior: Check if computed field is a property of input object. let property_name = field.property_name_py.as_ref(input_value.py()); - match ob_type_lookup.get_type(input_value) { - ObType::Dataclass | ObType::PydanticSerializable => input_value.getattr(property_name), - _ => { - if field.has_ser_func { - Ok(input_value) + if input_value.hasattr(property_name).unwrap_or_default() { + return input_value.getattr(property_name); + } + + Err(PydanticSerializationError::new_err(format!( + "No serialization function found for '{}'", + property_name + ))) +} + +fn py_get_attrs<'a>(obj: &'a PyAny, attr_name: &PyString) -> PyResult> { + match obj.getattr(attr_name) { + Ok(attr) => Ok(Some(attr)), + Err(err) => { + if err.get_type(obj.py()).is_subclass_of::()? { + Ok(None) } else { - if input_value.hasattr(property_name).unwrap_or_default() { - return input_value.getattr(property_name); - } - Err(PydanticSerializationError::new_err(format!( - "no serialization function found for {}", - field.property_name - ))) + Err(err) } } } From 6dc985f996bb8c90286a37d090296daecc139b89 Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Sun, 15 Oct 2023 09:11:02 -0400 Subject: [PATCH 08/12] Lint --- src/serializers/computed_fields.rs | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index da1b9b057..2c68f2d35 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -214,33 +214,34 @@ fn get_next_value<'a>( ) -> PyResult<&'a PyAny> { // Backwards compatiability. let legacy_result = match ob_type_lookup.get_type(input_value) { - ObType::Dataclass | ObType::PydanticSerializable | ObType::Unknown => { + ObType::Dataclass | ObType::PydanticSerializable => { py_get_attrs(input_value, field.property_name_py.as_ref(input_value.py())) } _ => Ok(None), }; - let legacy_next_value = match legacy_result { - Ok(res) => res, + match legacy_result { + Ok(opt) => { + if let Some(legacy_next_value) = opt { + return Ok(legacy_next_value); + } + } Err(err) => return Err(err), }; - if legacy_next_value.is_some() { - return Ok(legacy_next_value.unwrap()); - } // Default behavior: If custom serialization function provided, compute value based on input. if field.has_ser_func { return Ok(input_value); } - - // Fallback behavior: Check if computed field is a property of input object. - let property_name = field.property_name_py.as_ref(input_value.py()); - if input_value.hasattr(property_name).unwrap_or_default() { - return input_value.getattr(property_name); + // Fallback behavior: Check if computed field is a property of input object + // (i.e. in some cases input_value can be ObType::Unknown) + if let Ok(Some(next_value_from_input)) = py_get_attrs(input_value, field.property_name_py.as_ref(input_value.py())) + { + return Ok(next_value_from_input); } Err(PydanticSerializationError::new_err(format!( "No serialization function found for '{}'", - property_name + field.property_name ))) } From d2932153186e94378bbc9d122149bee69d8796a7 Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Wed, 25 Oct 2023 13:54:30 -0400 Subject: [PATCH 09/12] Address ser_schema comment --- src/serializers/computed_fields.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 2c68f2d35..5885dda90 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -200,11 +200,7 @@ fn has_ser_function(schema: &PyDict) -> bool { let ser_schema = schema .get_as::<&PyDict>(intern!(py, "serialization")) .unwrap_or_default(); - - match ser_schema { - Some(s) => s.contains(intern!(py, "function")).unwrap_or_default(), - None => false, - } + ser_schema.is_some_and(|s| s.contains(intern!(py, "function")).unwrap_or_default()) } fn get_next_value<'a>( From 147e917f556a9487dc2ec4ea4bc8dfcdbe7b6c63 Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Wed, 25 Oct 2023 14:02:31 -0400 Subject: [PATCH 10/12] Add test to compute on TypedDict model --- tests/serializers/test_typed_dict.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index c89cdc240..a13fac751 100644 --- a/tests/serializers/test_typed_dict.py +++ b/tests/serializers/test_typed_dict.py @@ -405,3 +405,23 @@ def ser_columns(v: dict, serializer: core_schema.SerializerFunctionWrapHandler, value = {'one': 1, 'two': 2, 'three': 3} assert s.to_python(value) == {'one': 1, 'two': 2, 'three': 3, 'columns': ['ONE', 'TWO', 'THREE']} assert s.to_json(value) == b'{"one":1,"two":2,"three":3,"columns":["ONE","TWO","THREE"]}' + + +def test_computed_fields_with_typed_dict_model(): + class Model(TypedDict): + x: int + + def ser_y(v: Any) -> str: + return f'{v["x"]}.00' + + s = SchemaSerializer( + core_schema.typed_dict_schema( + {'x': core_schema.typed_dict_field(core_schema.int_schema())}, + computed_fields=[ + core_schema.computed_field( + 'y', core_schema.str_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_y)) + ) + ], + ) + ) + assert s.to_python(Model(x=1000)) == {'x': 1000, 'y': '1000.00'} From bdc1326bd022d180be519040c2eff8d0698e4b64 Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Wed, 25 Oct 2023 14:23:15 -0400 Subject: [PATCH 11/12] Address comment on error message --- src/serializers/computed_fields.rs | 34 +++++++++++++--------------- tests/serializers/test_model.py | 6 ++--- tests/serializers/test_typed_dict.py | 16 ++++++++++++- 3 files changed, 34 insertions(+), 22 deletions(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 5885dda90..7938d9e50 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -208,10 +208,22 @@ fn get_next_value<'a>( input_value: &'a PyAny, ob_type_lookup: &'a ObTypeLookup, ) -> PyResult<&'a PyAny> { + let py = input_value.py(); + let mut legacy_attr_error: Option = None; // Backwards compatiability. let legacy_result = match ob_type_lookup.get_type(input_value) { ObType::Dataclass | ObType::PydanticSerializable => { - py_get_attrs(input_value, field.property_name_py.as_ref(input_value.py())) + match input_value.getattr(field.property_name_py.as_ref(py)) { + Ok(attr) => Ok(Some(attr)), + Err(err) => { + if err.get_type(py).is_subclass_of::()? { + legacy_attr_error = Some(err); + Ok(None) + } else { + Err(err) + } + } + } } _ => Ok(None), }; @@ -230,26 +242,12 @@ fn get_next_value<'a>( } // Fallback behavior: Check if computed field is a property of input object // (i.e. in some cases input_value can be ObType::Unknown) - if let Ok(Some(next_value_from_input)) = py_get_attrs(input_value, field.property_name_py.as_ref(input_value.py())) - { + if let Ok(next_value_from_input) = input_value.getattr(field.property_name_py.as_ref(py)) { return Ok(next_value_from_input); } - Err(PydanticSerializationError::new_err(format!( + Err(legacy_attr_error.unwrap_or(PydanticSerializationError::new_err(format!( "No serialization function found for '{}'", field.property_name - ))) -} - -fn py_get_attrs<'a>(obj: &'a PyAny, attr_name: &PyString) -> PyResult> { - match obj.getattr(attr_name) { - Ok(attr) => Ok(Some(attr)), - Err(err) => { - if err.get_type(obj.py()).is_subclass_of::()? { - Ok(None) - } else { - Err(err) - } - } - } + )))) } diff --git a/tests/serializers/test_model.py b/tests/serializers/test_model.py index cc01c92e0..32ecd3c1a 100644 --- a/tests/serializers/test_model.py +++ b/tests/serializers/test_model.py @@ -690,12 +690,12 @@ class Model: ), ) ) - with pytest.raises(PydanticSerializationError, match="^No serialization function found for 'area'$"): + with pytest.raises(AttributeError, match="^'Model' object has no attribute 'area'$"): s.to_python(Model(3)) - with pytest.raises(PydanticSerializationError, match="^No serialization function found for 'area'$"): + with pytest.raises(AttributeError, match="^'Model' object has no attribute 'area'$"): s.to_python(Model(3), mode='json') - e = "^Error serializing to JSON: PydanticSerializationError: No serialization function found for 'area'$" + e = "^Error serializing to JSON: AttributeError: 'Model' object has no attribute 'area'$" with pytest.raises(PydanticSerializationError, match=e): s.to_json(Model(3)) diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index a13fac751..e9f1fc1d9 100644 --- a/tests/serializers/test_typed_dict.py +++ b/tests/serializers/test_typed_dict.py @@ -5,7 +5,7 @@ from dirty_equals import IsStrictDict from typing_extensions import TypedDict -from pydantic_core import SchemaSerializer, core_schema +from pydantic_core import PydanticSerializationError, SchemaSerializer, core_schema @pytest.mark.parametrize('extra_behavior_kw', [{}, {'extra_behavior': 'ignore'}, {'extra_behavior': None}]) @@ -425,3 +425,17 @@ def ser_y(v: Any) -> str: ) ) assert s.to_python(Model(x=1000)) == {'x': 1000, 'y': '1000.00'} + + +def test_computed_fields_without_ser_function(): + class Model(TypedDict): + x: int + + s = SchemaSerializer( + core_schema.typed_dict_schema( + {'x': core_schema.typed_dict_field(core_schema.int_schema())}, + computed_fields=[core_schema.computed_field('y', core_schema.str_schema())], + ) + ) + with pytest.raises(PydanticSerializationError, match="^No serialization function found for 'y'$"): + s.to_python(Model(x=1000)) From 144120fb30c4c715e88f74d0a71b010474230a16 Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Wed, 25 Oct 2023 14:29:45 -0400 Subject: [PATCH 12/12] Small clean up --- src/serializers/computed_fields.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 7938d9e50..e8384128e 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -209,8 +209,8 @@ fn get_next_value<'a>( ob_type_lookup: &'a ObTypeLookup, ) -> PyResult<&'a PyAny> { let py = input_value.py(); - let mut legacy_attr_error: Option = None; // Backwards compatiability. + let mut legacy_attr_error: Option = None; let legacy_result = match ob_type_lookup.get_type(input_value) { ObType::Dataclass | ObType::PydanticSerializable => { match input_value.getattr(field.property_name_py.as_ref(py)) {