Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix wrap serializer breaking union serialization in presence of extra fields #1530

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ impl<'a> Extra<'a> {
/// Wrap `value` in a `super::infer::SerializeInfer` that carries this `Extra`.
///
/// The second and third arguments of `SerializeInfer::new` are always passed
/// as `None` here.
// NOTE(review): those two `None`s are presumably the include/exclude filters —
// confirm against `SerializeInfer::new` in `infer.rs`.
pub fn serialize_infer<'py>(&'py self, value: &'py Bound<'py, PyAny>) -> super::infer::SerializeInfer<'py> {
super::infer::SerializeInfer::new(value, None, None, self)
}

/// Name of the Python type of `self.model`, if it can be determined.
///
/// Returns `None` when no model is set on this `Extra`, or when asking the
/// type for its name fails (the lookup error is discarded rather than
/// propagated, so callers can simply omit the name from their message).
pub(crate) fn model_type_name(&self) -> Option<Bound<'a, PyString>> {
    // No model recorded: nothing to name.
    let model = self.model?;
    let model_type = model.get_type();
    model_type.name().ok()
}
}

#[derive(Clone, Copy, PartialEq, Eq)]
Expand Down
23 changes: 12 additions & 11 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,15 @@ impl GeneralFieldsSerializer {
};
output_dict.set_item(key, value)?;
} else if field_extra.check == SerCheck::Strict {
return Err(PydanticSerializationUnexpectedValue::new_err(None));
let type_name = field_extra.model_type_name();
return Err(PydanticSerializationUnexpectedValue::new_err(Some(format!(
"Unexpected field `{key}`{for_type_name}",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a drive-by improvement to the warning which was necessary for me to debug.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminds me of #1483

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice improvement though, thanks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we go ahead and fix #1483 while we're at it here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last comment here - we should do a quick search through the codebase to see if other (similar) warnings can be updated with this structure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I will punt on both of these and leave them for a later refactoring of serialization warnings.

for_type_name = if let Some(type_name) = type_name {
format!(" for type `{type_name}`")
} else {
String::new()
},
))));
}
}
}
Expand All @@ -212,22 +220,15 @@ impl GeneralFieldsSerializer {
&& self.required_fields > used_req_fields
{
let required_fields = self.required_fields;
let type_name = match extra.model {
Some(model) => model
.get_type()
.qualname()
.ok()
.unwrap_or_else(|| PyString::new_bound(py, "<unknown python object>"))
.to_string(),
None => "<unknown python object>".to_string(),
};
let type_name = extra.model_type_name();
let field_value = match extra.model {
Some(model) => truncate_safe_repr(model, Some(100)),
None => "<unknown python object>".to_string(),
};

Err(PydanticSerializationUnexpectedValue::new_err(Some(format!(
"Expected {required_fields} fields but got {used_req_fields} for type `{type_name}` with value `{field_value}` - serialized value may not be as expected."
"Expected {required_fields} fields but got {used_req_fields}{for_type_name} with value `{field_value}` - serialized value may not be as expected.",
for_type_name = if let Some(type_name) = type_name { format!(" for type `{type_name}`") } else { String::new() },
))))
} else {
Ok(output_dict)
Expand Down
15 changes: 15 additions & 0 deletions src/serializers/type_serializers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ impl FunctionPlainSerializer {
.expect("fallback_serializer unexpectedly none")
.as_ref()
}

/// Returns `true` if either the fallback serializer (when one is present) or
/// the return serializer reports `retry_with_lax_check()`.
///
/// The fallback serializer is consulted first; the return serializer is only
/// asked when the fallback is absent or answers `false`.
fn retry_with_lax_check(&self) -> bool {
    let fallback_retries = match self.fallback_serializer.as_ref() {
        Some(fallback) => fallback.retry_with_lax_check(),
        // No fallback serializer configured: it cannot request a retry.
        None => false,
    };
    fallback_retries || self.return_serializer.retry_with_lax_check()
}
}

fn on_error(py: Python, err: PyErr, function_name: &str, extra: &Extra) -> PyResult<()> {
Expand Down Expand Up @@ -271,6 +278,10 @@ macro_rules! function_type_serializer {
/// Returns this serializer's `name` field as a borrowed string slice.
fn get_name(&self) -> &str {
&self.name
}

/// Forwards to the `retry_with_lax_check` method defined directly on the
/// concrete serializer type.
// NOTE(review): this looks self-recursive but presumably is not. Assuming this
// block sits inside a trait impl generated by the macro (the impl header is not
// visible here), method resolution picks the inherent method of the same name
// on the type before the trait method — confirm each type the macro is applied
// to defines that inherent method, otherwise this would recurse forever.
fn retry_with_lax_check(&self) -> bool {
self.retry_with_lax_check()
}
}
};
}
Expand Down Expand Up @@ -409,6 +420,10 @@ impl FunctionWrapSerializer {
/// Returns a reference to the wrapped `serializer`, which this type exposes
/// as its fallback serializer.
fn get_fallback_serializer(&self) -> &CombinedSerializer {
self.serializer.as_ref()
}

/// Returns `true` if either the wrapped serializer or the return serializer
/// reports `retry_with_lax_check()`.
fn retry_with_lax_check(&self) -> bool {
    // The wrapped serializer is asked first; the return serializer is only
    // consulted when the wrapped one answers `false`.
    if self.serializer.retry_with_lax_check() {
        return true;
    }
    self.return_serializer.retry_with_lax_check()
}
}

impl_py_gc_traverse!(FunctionWrapSerializer {
Expand Down
139 changes: 93 additions & 46 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,53 +60,36 @@ def __init__(self, c, d):
@pytest.fixture(scope='module')
def model_serializer() -> SchemaSerializer:
return SchemaSerializer(
{
'type': 'union',
'choices': [
{
'type': 'model',
'cls': ModelA,
'schema': {
'type': 'model-fields',
'fields': {
'a': {'type': 'model-field', 'schema': {'type': 'bytes'}},
'b': {
'type': 'model-field',
'schema': {
'type': 'float',
'serialization': {
'type': 'format',
'formatting_string': '0.1f',
'when_used': 'unless-none',
},
},
},
},
},
},
{
'type': 'model',
'cls': ModelB,
'schema': {
'type': 'model-fields',
'fields': {
'c': {'type': 'model-field', 'schema': {'type': 'bytes'}},
'd': {
'type': 'model-field',
'schema': {
'type': 'float',
'serialization': {
'type': 'format',
'formatting_string': '0.2f',
'when_used': 'unless-none',
},
},
},
},
},
},
core_schema.union_schema(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much cleaner, thanks.

[
core_schema.model_schema(
ModelA,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(core_schema.bytes_schema()),
'b': core_schema.model_field(
core_schema.float_schema(
serialization=core_schema.format_ser_schema('0.1f', when_used='unless-none')
)
),
}
),
),
core_schema.model_schema(
ModelB,
core_schema.model_fields_schema(
{
'c': core_schema.model_field(core_schema.bytes_schema()),
'd': core_schema.model_field(
core_schema.float_schema(
serialization=core_schema.format_ser_schema('0.2f', when_used='unless-none')
)
),
}
),
),
],
}
)
)


Expand Down Expand Up @@ -778,3 +761,67 @@ class ModelB:
model_b = ModelB(field=1)
assert s.to_python(model_a) == {'field': 1, 'TAG': 'a'}
assert s.to_python(model_b) == {'field': 1, 'TAG': 'b'}


def test_union_model_wrap_serializer():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test fails on main.

def wrap_serializer(value, handler):
return handler(value)

class Data:
pass

class ModelA:
a: Data

class ModelB:
a: Data

model_serializer = SchemaSerializer(
core_schema.union_schema(
[
core_schema.model_schema(
ModelA,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(
core_schema.model_schema(
Data,
core_schema.model_fields_schema({}),
)
),
},
),
serialization=core_schema.wrap_serializer_function_ser_schema(wrap_serializer),
),
core_schema.model_schema(
ModelB,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(
core_schema.model_schema(
Data,
core_schema.model_fields_schema({}),
)
),
},
),
serialization=core_schema.wrap_serializer_function_ser_schema(wrap_serializer),
),
],
)
)

input_value = ModelA()
input_value.a = Data()

assert model_serializer.to_python(input_value) == {'a': {}}
assert model_serializer.to_python(input_value, mode='json') == {'a': {}}
assert model_serializer.to_json(input_value) == b'{"a":{}}'

# add some additional attribute, should be ignored & not break serialization

input_value.a._a = 'foo'

assert model_serializer.to_python(input_value) == {'a': {}}
assert model_serializer.to_python(input_value, mode='json') == {'a': {}}
assert model_serializer.to_json(input_value) == b'{"a":{}}'
Loading