Skip to content

Commit

Permalink
fix for model subclass serialization in unions (pydantic#632)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored May 25, 2023
1 parent 726ef5f commit 7dc970f
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 11 deletions.
10 changes: 6 additions & 4 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::types::{PyDict, PyString};
use ahash::AHashMap;
use serde::ser::SerializeMap;

use crate::serializers::extra::SerCheck;
use crate::PydanticSerializationUnexpectedValue;

use super::computed_fields::ComputedFields;
Expand Down Expand Up @@ -75,7 +76,7 @@ fn exclude_default(value: &PyAny, extra: &Extra, serializer: &CombinedSerializer
Ok(false)
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Eq, PartialEq)]
pub(super) enum FieldsMode {
// typeddict with no extra items
SimpleDict,
Expand Down Expand Up @@ -196,10 +197,10 @@ impl TypeSerializer for GeneralFieldsSerializer {
if field.required {
used_req_fields += 1;
}
} else if matches!(self.mode, FieldsMode::TypedDictAllow) {
} else if self.mode == FieldsMode::TypedDictAllow {
let value = infer_to_python(value, next_include, next_exclude, &extra)?;
output_dict.set_item(key, value)?;
} else if extra.check.enabled() {
} else if extra.check == SerCheck::Strict {
return Err(PydanticSerializationUnexpectedValue::new_err(None));
}
}
Expand Down Expand Up @@ -281,11 +282,12 @@ impl TypeSerializer for GeneralFieldsSerializer {
map.serialize_entry(&output_key, &s)?;
}
}
} else if matches!(self.mode, FieldsMode::TypedDictAllow) {
} else if self.mode == FieldsMode::TypedDictAllow {
let output_key = infer_json_key(key, &extra).map_err(py_err_se_err)?;
let s = SerializeInfer::new(value, next_include, next_exclude, &extra);
map.serialize_entry(&output_key, &s)?
}
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
}
}
// this is used to include `__pydantic_extra__` in serialization on models
Expand Down
5 changes: 3 additions & 2 deletions src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ fn has_extra(schema: &PyDict, config: Option<&PyDict>) -> PyResult<bool> {

impl ModelSerializer {
fn allow_value(&self, value: &PyAny, extra: &Extra) -> PyResult<bool> {
let class = self.class.as_ref(value.py());
match extra.check {
SerCheck::Strict => Ok(value.get_type().is(self.class.as_ref(value.py()))),
SerCheck::Lax => value.is_instance(self.class.as_ref(value.py())),
SerCheck::Strict => Ok(value.get_type().is(class)),
SerCheck::Lax => value.is_instance(class),
SerCheck::None => value.hasattr(intern!(value.py(), "__dict__")),
}
}
Expand Down
101 changes: 96 additions & 5 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import json
import re

Expand Down Expand Up @@ -179,17 +180,16 @@ def test_typed_dict_literal():


def test_typed_dict_missing():
"""
TODO, needs tests for each case
"""
s = SchemaSerializer(
core_schema.union_schema(
[
core_schema.typed_dict_schema(dict(foo=core_schema.typed_dict_field(core_schema.int_schema()))),
core_schema.typed_dict_schema(
dict(
foo=core_schema.typed_dict_field(
core_schema.int_schema(serialization=core_schema.format_ser_schema('04d'))
core_schema.int_schema(
serialization=core_schema.format_ser_schema('04d', when_used='always')
)
),
bar=core_schema.typed_dict_field(core_schema.int_schema()),
)
Expand All @@ -201,7 +201,8 @@ def test_typed_dict_missing():
assert s.to_python(dict(foo=1)) == {'foo': 1}
assert s.to_python(dict(foo=1), mode='json') == {'foo': 1}
assert s.to_json(dict(foo=1)) == b'{"foo":1}'
assert s.to_python(dict(foo=1, bar=2)) == {'foo': 1, 'bar': 2}

assert s.to_python(dict(foo=1, bar=2)) == {'foo': '0001', 'bar': 2}
assert s.to_python(dict(foo=1, bar=2), mode='json') == {'foo': '0001', 'bar': 2}
assert s.to_json(dict(foo=1, bar=2)) == b'{"foo":"0001","bar":2}'

Expand Down Expand Up @@ -269,3 +270,93 @@ def test_typed_dict_different_fields():
assert s.to_python(dict(spam=1, ham=2)) == {'spam': 1, 'ham': 2}
assert s.to_python(dict(spam=1, ham=2), mode='json') == {'spam': 1, 'ham': '0002'}
assert s.to_json(dict(spam=1, ham=2)) == b'{"spam":1,"ham":"0002"}'


def test_dataclass_union():
@dataclasses.dataclass
class BaseUser:
name: str

@dataclasses.dataclass
class User(BaseUser):
surname: str

@dataclasses.dataclass
class DBUser(User):
password_hash: str

@dataclasses.dataclass
class Item:
name: str
price: float

user_schema = core_schema.dataclass_schema(
User,
core_schema.dataclass_args_schema(
'User',
[
core_schema.dataclass_field(name='name', schema=core_schema.str_schema()),
core_schema.dataclass_field(name='surname', schema=core_schema.str_schema()),
],
),
['name', 'surname'],
)
item_schema = core_schema.dataclass_schema(
Item,
core_schema.dataclass_args_schema(
'Item',
[
core_schema.dataclass_field(name='name', schema=core_schema.str_schema()),
core_schema.dataclass_field(name='price', schema=core_schema.float_schema()),
],
),
['name', 'price'],
)
s = SchemaSerializer(core_schema.union_schema([user_schema, item_schema]))
assert s.to_python(User(name='foo', surname='bar')) == {'name': 'foo', 'surname': 'bar'}
assert s.to_python(DBUser(name='foo', surname='bar', password_hash='x')) == {'name': 'foo', 'surname': 'bar'}
assert s.to_json(DBUser(name='foo', surname='bar', password_hash='x')) == b'{"name":"foo","surname":"bar"}'


def test_model_union():
class BaseUser:
def __init__(self, name: str):
self.name = name

class User(BaseUser):
def __init__(self, name: str, surname: str):
super().__init__(name)
self.surname = surname

class DBUser(User):
def __init__(self, name: str, surname: str, password_hash: str):
super().__init__(name, surname)
self.password_hash = password_hash

class Item:
def __init__(self, name: str, price: float):
self.name = name
self.price = price

user_schema = core_schema.model_schema(
User,
core_schema.model_fields_schema(
{
'name': core_schema.model_field(schema=core_schema.str_schema()),
'surname': core_schema.model_field(schema=core_schema.str_schema()),
}
),
)
item_schema = core_schema.model_schema(
Item,
core_schema.model_fields_schema(
{
'name': core_schema.model_field(schema=core_schema.str_schema()),
'price': core_schema.model_field(schema=core_schema.float_schema()),
}
),
)
s = SchemaSerializer(core_schema.union_schema([user_schema, item_schema]))
assert s.to_python(User(name='foo', surname='bar')) == {'name': 'foo', 'surname': 'bar'}
assert s.to_python(DBUser(name='foo', surname='bar', password_hash='x')) == {'name': 'foo', 'surname': 'bar'}
assert s.to_json(DBUser(name='foo', surname='bar', password_hash='x')) == b'{"name":"foo","surname":"bar"}'

0 comments on commit 7dc970f

Please sign in to comment.