From 6472887b3ad3e865e494b46efb66a71e87ceb1cf Mon Sep 17 00:00:00 2001
From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Date: Wed, 13 Nov 2024 11:00:12 -0500
Subject: [PATCH] WIP: Simplify shared union serializer logic (#1538)

Co-authored-by: David Hewitt
---
 Reviewer notes (commentary placed after the three-dash marker; not part of
 the commit message):

 * TaggedUnionSerializer::serde_serialize passed `None` as the discriminator
   to tagged_union_serialize, so the tag lookup never ran on the JSON path and
   every value fell through to left-to-right union matching. The code this
   patch removes called `self.get_discriminator_value(value, extra)`, and the
   sibling to_python / json_key methods still do; serde_serialize now does too.
 * Dropped the redundant `return` from the tail-position match arms in both
   serde_serialize implementations. No behaviour change.
 * The copy under review had lost its line breaks, indentation and every
   `<Ident...>` generic argument. They are reconstructed here; hunk line counts
   were re-checked and are unchanged. TODO confirm against the tree before
   applying: `PyResult<PyObject>`, `Option<Py<PyAny>>`,
   `HashMap<String, usize>` and the `S: serde::ser::Serializer` bound.
 * The post-image blob id on the `index` line predates the fixes above.

 src/serializers/type_serializers/union.rs | 261 +++++++---------------
 1 file changed, 84 insertions(+), 177 deletions(-)

diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs
index 35a5c8bc7..40a72e9dd 100644
--- a/src/serializers/type_serializers/union.rs
+++ b/src/serializers/type_serializers/union.rs
@@ -8,8 +8,8 @@ use std::borrow::Cow;
 use crate::build_tools::py_schema_err;
 use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
 use crate::definitions::DefinitionsBuilder;
+use crate::serializers::PydanticSerializationUnexpectedValue;
 use crate::tools::{truncate_safe_repr, SchemaDict};
-use crate::PydanticSerializationUnexpectedValue;
 
 use super::{
     infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
@@ -70,22 +70,23 @@ impl UnionSerializer {
 
 impl_py_gc_traverse!(UnionSerializer { choices });
 
-fn to_python(
-    value: &Bound<'_, PyAny>,
-    include: Option<&Bound<'_, PyAny>>,
-    exclude: Option<&Bound<'_, PyAny>>,
+fn union_serialize<S>(
+    // if this returns `Ok(Some(v))`, we picked a union variant to serialize,
+    // Or `Ok(None)` if we couldn't find a suitable variant to serialize
+    // Finally, `Err(err)` if we encountered errors while trying to serialize
+    mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
     extra: &Extra,
     choices: &[CombinedSerializer],
     retry_with_lax_check: bool,
-) -> PyResult<PyObject> {
+) -> PyResult<Option<S>> {
     // try the serializers in left to right order with error_on fallback=true
     let mut new_extra = extra.clone();
     new_extra.check = SerCheck::Strict;
     let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
 
     for comb_serializer in choices {
-        match comb_serializer.to_python(value, include, exclude, &new_extra) {
-            Ok(v) => return Ok(v),
+        match selector(comb_serializer, &new_extra) {
+            Ok(v) => return Ok(Some(v)),
             Err(err) => errors.push(err),
         }
     }
@@ -94,8 +95,8 @@ fn to_python(
     if extra.check != SerCheck::Strict && retry_with_lax_check {
         new_extra.check = SerCheck::Lax;
         for comb_serializer in choices {
-            if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
-                return Ok(v);
+            if let Ok(v) = selector(comb_serializer, &new_extra) {
+                return Ok(Some(v));
             }
         }
     }
@@ -113,94 +114,45 @@ fn to_python(
         let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
         return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
     }
-    infer_to_python(value, include, exclude, extra)
+    Ok(None)
 }
 
-fn json_key<'a>(
-    key: &'a Bound<'_, PyAny>,
+fn tagged_union_serialize<S>(
+    discriminator_value: Option<Py<PyAny>>,
+    lookup: &HashMap<String, usize>,
+    // if this returns `Ok(v)`, we picked a union variant to serialize, where
+    // `S` is intermediate state which can be passed on to the finalizer
+    mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
     extra: &Extra,
     choices: &[CombinedSerializer],
     retry_with_lax_check: bool,
-) -> PyResult<Cow<'a, str>> {
+) -> PyResult<Option<S>> {
     let mut new_extra = extra.clone();
     new_extra.check = SerCheck::Strict;
-    let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
-
-    for comb_serializer in choices {
-        match comb_serializer.json_key(key, &new_extra) {
-            Ok(v) => return Ok(v),
-            Err(err) => errors.push(err),
-        }
-    }
 
-    // If extra.check is SerCheck::Strict, we're in a nested union
-    if extra.check != SerCheck::Strict && retry_with_lax_check {
-        new_extra.check = SerCheck::Lax;
-        for comb_serializer in choices {
-            if let Ok(v) = comb_serializer.json_key(key, &new_extra) {
-                return Ok(v);
-            }
-        }
-    }
-
-    // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
-    if extra.check == SerCheck::None {
-        for err in &errors {
-            extra.warnings.custom_warning(err.to_string());
-        }
-    }
-    // Otherwise, if we've encountered errors, return them to the parent union, which should take
-    // care of the formatting for us
-    else if !errors.is_empty() {
-        let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
-        return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
-    }
-    infer_json_key(key, extra)
-}
-
-#[allow(clippy::too_many_arguments)]
-fn serde_serialize<S: serde::ser::Serializer>(
-    value: &Bound<'_, PyAny>,
-    serializer: S,
-    include: Option<&Bound<'_, PyAny>>,
-    exclude: Option<&Bound<'_, PyAny>>,
-    extra: &Extra,
-    choices: &[CombinedSerializer],
-    retry_with_lax_check: bool,
-) -> Result<S::Ok, S::Error> {
-    let py = value.py();
-    let mut new_extra = extra.clone();
-    new_extra.check = SerCheck::Strict;
-    let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();
-
-    for comb_serializer in choices {
-        match comb_serializer.to_python(value, include, exclude, &new_extra) {
-            Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
-            Err(err) => errors.push(err),
-        }
-    }
-
-    // If extra.check is SerCheck::Strict, we're in a nested union
-    if extra.check != SerCheck::Strict && retry_with_lax_check {
-        new_extra.check = SerCheck::Lax;
-        for comb_serializer in choices {
-            if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
-                return infer_serialize(v.bind(py), serializer, None, None, extra);
+    if let Some(tag) = discriminator_value {
+        let tag_str = tag.to_string();
+        if let Some(&serializer_index) = lookup.get(&tag_str) {
+            let selected_serializer = &choices[serializer_index];
+
+            match selector(selected_serializer, &new_extra) {
+                Ok(v) => return Ok(Some(v)),
+                Err(_) => {
+                    if retry_with_lax_check {
+                        new_extra.check = SerCheck::Lax;
+                        if let Ok(v) = selector(selected_serializer, &new_extra) {
+                            return Ok(Some(v));
+                        }
+                    }
+                }
             }
         }
     }
 
-    // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
-    if extra.check == SerCheck::None {
-        for err in &errors {
-            extra.warnings.custom_warning(err.to_string());
-        }
-    } else {
-        // NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
-        // will have to be returned here
-    }
-
-    infer_serialize(value, serializer, include, exclude, extra)
+    // if we haven't returned at this point, we should fallback to the union serializer
+    // which preserves the historical expectation that we do our best with serialization
+    // even if that means we resort to inference
+    union_serialize(selector, extra, choices, retry_with_lax_check)
 }
 
 impl TypeSerializer for UnionSerializer {
@@ -211,18 +163,23 @@ impl TypeSerializer for UnionSerializer {
         exclude: Option<&Bound<'_, PyAny>>,
         extra: &Extra,
     ) -> PyResult<PyObject> {
-        to_python(
-            value,
-            include,
-            exclude,
+        union_serialize(
+            |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
             extra,
             &self.choices,
             self.retry_with_lax_check(),
-        )
+        )?
+        .map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
     }
 
     fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
-        json_key(key, extra, &self.choices, self.retry_with_lax_check())
+        union_serialize(
+            |comb_serializer, new_extra| comb_serializer.json_key(key, new_extra),
+            extra,
+            &self.choices,
+            self.retry_with_lax_check(),
+        )?
+        .map_or_else(|| infer_json_key(key, extra), Ok)
     }
 
     fn serde_serialize<S: serde::ser::Serializer>(
@@ -233,15 +190,16 @@ impl TypeSerializer for UnionSerializer {
         exclude: Option<&Bound<'_, PyAny>>,
         extra: &Extra,
     ) -> Result<S::Ok, S::Error> {
-        serde_serialize(
-            value,
-            serializer,
-            include,
-            exclude,
+        match union_serialize(
+            |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
             extra,
             &self.choices,
             self.retry_with_lax_check(),
-        )
+        ) {
+            Ok(Some(v)) => infer_serialize(v.bind(value.py()), serializer, None, None, extra),
+            Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
+            Err(err) => Err(serde::ser::Error::custom(err.to_string())),
+        }
     }
 
     fn get_name(&self) -> &str {
@@ -309,62 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer {
         exclude: Option<&Bound<'_, PyAny>>,
         extra: &Extra,
     ) -> PyResult<PyObject> {
-        let mut new_extra = extra.clone();
-        new_extra.check = SerCheck::Strict;
-
-        if let Some(tag) = self.get_discriminator_value(value, extra) {
-            let tag_str = tag.to_string();
-            if let Some(&serializer_index) = self.lookup.get(&tag_str) {
-                let serializer = &self.choices[serializer_index];
-
-                match serializer.to_python(value, include, exclude, &new_extra) {
-                    Ok(v) => return Ok(v),
-                    Err(_) => {
-                        if self.retry_with_lax_check() {
-                            new_extra.check = SerCheck::Lax;
-                            if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) {
-                                return Ok(v);
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        to_python(
-            value,
-            include,
-            exclude,
+        tagged_union_serialize(
+            self.get_discriminator_value(value, extra),
+            &self.lookup,
+            |comb_serializer: &CombinedSerializer, new_extra: &Extra| {
+                comb_serializer.to_python(value, include, exclude, new_extra)
+            },
             extra,
             &self.choices,
             self.retry_with_lax_check(),
-        )
+        )?
+        .map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
     }
 
     fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
-        let mut new_extra = extra.clone();
-        new_extra.check = SerCheck::Strict;
-
-        if let Some(tag) = self.get_discriminator_value(key, extra) {
-            let tag_str = tag.to_string();
-            if let Some(&serializer_index) = self.lookup.get(&tag_str) {
-                let serializer = &self.choices[serializer_index];
-
-                match serializer.json_key(key, &new_extra) {
-                    Ok(v) => return Ok(v),
-                    Err(_) => {
-                        if self.retry_with_lax_check() {
-                            new_extra.check = SerCheck::Lax;
-                            if let Ok(v) = serializer.json_key(key, &new_extra) {
-                                return Ok(v);
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        json_key(key, extra, &self.choices, self.retry_with_lax_check())
+        tagged_union_serialize(
+            self.get_discriminator_value(key, extra),
+            &self.lookup,
+            |comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra),
+            extra,
+            &self.choices,
+            self.retry_with_lax_check(),
+        )?
+        .map_or_else(|| infer_json_key(key, extra), Ok)
     }
 
     fn serde_serialize<S: serde::ser::Serializer>(
@@ -375,38 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer {
         exclude: Option<&Bound<'_, PyAny>>,
         extra: &Extra,
     ) -> Result<S::Ok, S::Error> {
-        let py = value.py();
-        let mut new_extra = extra.clone();
-        new_extra.check = SerCheck::Strict;
-
-        if let Some(tag) = self.get_discriminator_value(value, extra) {
-            let tag_str = tag.to_string();
-            if let Some(&serializer_index) = self.lookup.get(&tag_str) {
-                let selected_serializer = &self.choices[serializer_index];
-
-                match selected_serializer.to_python(value, include, exclude, &new_extra) {
-                    Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
-                    Err(_) => {
-                        if self.retry_with_lax_check() {
-                            new_extra.check = SerCheck::Lax;
-                            if let Ok(v) = selected_serializer.to_python(value, include, exclude, &new_extra) {
-                                return infer_serialize(v.bind(py), serializer, None, None, extra);
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        serde_serialize(
-            value,
-            serializer,
-            include,
-            exclude,
+        match tagged_union_serialize(
+            self.get_discriminator_value(value, extra),
+            &self.lookup,
+            |comb_serializer: &CombinedSerializer, new_extra: &Extra| {
+                comb_serializer.to_python(value, include, exclude, new_extra)
+            },
             extra,
             &self.choices,
             self.retry_with_lax_check(),
-        )
+        ) {
+            Ok(Some(v)) => infer_serialize(v.bind(value.py()), serializer, None, None, extra),
+            Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
+            Err(err) => Err(serde::ser::Error::custom(err.to_string())),
+        }
     }
 
     fn get_name(&self) -> &str {