From 7dc970f9976ef59eefe6fe2a0a2c30559850e37a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 25 May 2023 23:27:23 +0100 Subject: [PATCH] fix for model subclass serialization in unions (#632) --- src/serializers/fields.rs | 10 ++- src/serializers/type_serializers/model.rs | 5 +- tests/serializers/test_union.py | 101 ++++++++++++++++++++-- 3 files changed, 105 insertions(+), 11 deletions(-) diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index 4c3a08816c..dd73e843c3 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -6,6 +6,7 @@ use pyo3::types::{PyDict, PyString}; use ahash::AHashMap; use serde::ser::SerializeMap; +use crate::serializers::extra::SerCheck; use crate::PydanticSerializationUnexpectedValue; use super::computed_fields::ComputedFields; @@ -75,7 +76,7 @@ fn exclude_default(value: &PyAny, extra: &Extra, serializer: &CombinedSerializer Ok(false) } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] pub(super) enum FieldsMode { // typeddict with no extra items SimpleDict, @@ -196,10 +197,10 @@ impl TypeSerializer for GeneralFieldsSerializer { if field.required { used_req_fields += 1; } - } else if matches!(self.mode, FieldsMode::TypedDictAllow) { + } else if self.mode == FieldsMode::TypedDictAllow { let value = infer_to_python(value, next_include, next_exclude, &extra)?; output_dict.set_item(key, value)?; - } else if extra.check.enabled() { + } else if extra.check == SerCheck::Strict { return Err(PydanticSerializationUnexpectedValue::new_err(None)); } } @@ -281,11 +282,12 @@ impl TypeSerializer for GeneralFieldsSerializer { map.serialize_entry(&output_key, &s)?; } } - } else if matches!(self.mode, FieldsMode::TypedDictAllow) { + } else if self.mode == FieldsMode::TypedDictAllow { let output_key = infer_json_key(key, &extra).map_err(py_err_se_err)?; let s = SerializeInfer::new(value, next_include, next_exclude, &extra); map.serialize_entry(&output_key, &s)? } + // no error case here since unions (which need the error case) use `to_python(..., mode='json')` } } // this is used to include `__pydantic_extra__` in serialization on models diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index 6107ff2b2e..6c92d8d356 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -109,9 +109,10 @@ fn has_extra(schema: &PyDict, config: Option<&PyDict>) -> PyResult { impl ModelSerializer { fn allow_value(&self, value: &PyAny, extra: &Extra) -> PyResult { + let class = self.class.as_ref(value.py()); match extra.check { - SerCheck::Strict => Ok(value.get_type().is(self.class.as_ref(value.py()))), - SerCheck::Lax => value.is_instance(self.class.as_ref(value.py())), + SerCheck::Strict => Ok(value.get_type().is(class)), + SerCheck::Lax => value.is_instance(class), SerCheck::None => value.hasattr(intern!(value.py(), "__dict__")), } } diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index aeaf3c9262..7177522472 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -1,3 +1,4 @@ +import dataclasses import json import re @@ -179,9 +180,6 @@ def test_typed_dict_literal(): def test_typed_dict_missing(): - """ - TODO, needs tests for each case - """ s = SchemaSerializer( core_schema.union_schema( [ @@ -189,7 +187,9 @@ def test_typed_dict_missing(): core_schema.typed_dict_schema( dict( foo=core_schema.typed_dict_field( - core_schema.int_schema(serialization=core_schema.format_ser_schema('04d')) + core_schema.int_schema( + serialization=core_schema.format_ser_schema('04d', when_used='always') + ) ), bar=core_schema.typed_dict_field(core_schema.int_schema()), ) @@ -201,7 +201,8 @@ def test_typed_dict_missing(): assert s.to_python(dict(foo=1)) == {'foo': 1} assert s.to_python(dict(foo=1), mode='json') == {'foo': 1} assert s.to_json(dict(foo=1)) == b'{"foo":1}' - assert s.to_python(dict(foo=1, bar=2)) == {'foo': 1, 'bar': 2} + + assert s.to_python(dict(foo=1, bar=2)) == {'foo': '0001', 'bar': 2} assert s.to_python(dict(foo=1, bar=2), mode='json') == {'foo': '0001', 'bar': 2} assert s.to_json(dict(foo=1, bar=2)) == b'{"foo":"0001","bar":2}' @@ -269,3 +270,93 @@ def test_typed_dict_different_fields(): assert s.to_python(dict(spam=1, ham=2)) == {'spam': 1, 'ham': 2} assert s.to_python(dict(spam=1, ham=2), mode='json') == {'spam': 1, 'ham': '0002'} assert s.to_json(dict(spam=1, ham=2)) == b'{"spam":1,"ham":"0002"}' + + +def test_dataclass_union(): + @dataclasses.dataclass + class BaseUser: + name: str + + @dataclasses.dataclass + class User(BaseUser): + surname: str + + @dataclasses.dataclass + class DBUser(User): + password_hash: str + + @dataclasses.dataclass + class Item: + name: str + price: float + + user_schema = core_schema.dataclass_schema( + User, + core_schema.dataclass_args_schema( + 'User', + [ + core_schema.dataclass_field(name='name', schema=core_schema.str_schema()), + core_schema.dataclass_field(name='surname', schema=core_schema.str_schema()), + ], + ), + ['name', 'surname'], + ) + item_schema = core_schema.dataclass_schema( + Item, + core_schema.dataclass_args_schema( + 'Item', + [ + core_schema.dataclass_field(name='name', schema=core_schema.str_schema()), + core_schema.dataclass_field(name='price', schema=core_schema.float_schema()), + ], + ), + ['name', 'price'], + ) + s = SchemaSerializer(core_schema.union_schema([user_schema, item_schema])) + assert s.to_python(User(name='foo', surname='bar')) == {'name': 'foo', 'surname': 'bar'} + assert s.to_python(DBUser(name='foo', surname='bar', password_hash='x')) == {'name': 'foo', 'surname': 'bar'} + assert s.to_json(DBUser(name='foo', surname='bar', password_hash='x')) == b'{"name":"foo","surname":"bar"}' + + +def test_model_union(): + class BaseUser: + def __init__(self, name: str): + self.name = name + + class User(BaseUser): + def __init__(self, name: str, surname: str): + super().__init__(name) + self.surname = surname + + class DBUser(User): + def __init__(self, name: str, surname: str, password_hash: str): + super().__init__(name, surname) + self.password_hash = password_hash + + class Item: + def __init__(self, name: str, price: float): + self.name = name + self.price = price + + user_schema = core_schema.model_schema( + User, + core_schema.model_fields_schema( + { + 'name': core_schema.model_field(schema=core_schema.str_schema()), + 'surname': core_schema.model_field(schema=core_schema.str_schema()), + } + ), + ) + item_schema = core_schema.model_schema( + Item, + core_schema.model_fields_schema( + { + 'name': core_schema.model_field(schema=core_schema.str_schema()), + 'price': core_schema.model_field(schema=core_schema.float_schema()), + } + ), + ) + s = SchemaSerializer(core_schema.union_schema([user_schema, item_schema])) + assert s.to_python(User(name='foo', surname='bar')) == {'name': 'foo', 'surname': 'bar'} + assert s.to_python(DBUser(name='foo', surname='bar', password_hash='x')) == {'name': 'foo', 'surname': 'bar'} + assert s.to_json(DBUser(name='foo', surname='bar', password_hash='x')) == b'{"name":"foo","surname":"bar"}'