From f0379045903f702c8177d7aa06de17643d804c89 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 11 Nov 2024 15:24:49 +0000 Subject: [PATCH 1/8] wip: simplify unions --- src/serializers/type_serializers/union.rs | 145 ++++++++-------------- 1 file changed, 55 insertions(+), 90 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index ac674ef34..d39fe83fc 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -70,50 +70,60 @@ impl UnionSerializer { impl_py_gc_traverse!(UnionSerializer { choices }); -fn to_python( - value: &Bound<'_, PyAny>, - include: Option<&Bound<'_, PyAny>>, - exclude: Option<&Bound<'_, PyAny>>, +fn union_serialize( + // if this returns `Ok(v)`, we picked a union variant to serialize, where + // `S` is intermediate state which can be passed on to the finalizer + mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult, + // if called with `Some(v)`, we have intermediate state to finish + // if `None`, we need to just go to fallback + finalizer: impl FnOnce(Option) -> R, extra: &Extra, choices: &[CombinedSerializer], retry_with_lax_check: bool, -) -> PyResult { +) -> R { // try the serializers in left to right order with error_on fallback=true let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); for comb_serializer in choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return Ok(v), + match selector(comb_serializer, &new_extra) { + Ok(v) => return finalizer(Some(v)), Err(err) => errors.push(err), } } - // If extra.check is SerCheck::Strict, we're in a nested union - if extra.check != SerCheck::Strict && retry_with_lax_check { + if retry_with_lax_check { new_extra.check = SerCheck::Lax; for comb_serializer in choices { - if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) { - return Ok(v); + if let Ok(v) = selector(comb_serializer, &new_extra) { + return finalizer(Some(v)); } } } - // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings - if extra.check == SerCheck::None { - for err in &errors { - extra.warnings.custom_warning(err.to_string()); - } - } - // Otherwise, if we've encountered errors, return them to the parent union, which should take - // care of the formatting for us - else if !errors.is_empty() { - let message = errors.iter().map(ToString::to_string).collect::>().join("\n"); - return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); + for err in &errors { + extra.warnings.custom_warning(err.to_string()); } - infer_to_python(value, include, exclude, extra) + finalizer(None) +} + +fn to_python( + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + choices: &[CombinedSerializer], + retry_with_lax_check: bool, +) -> PyResult { + union_serialize( + |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), + |v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok), + extra, + choices, + retry_with_lax_check, + ) } fn json_key<'a>( @@ -122,40 +132,13 @@ fn json_key<'a>( choices: &[CombinedSerializer], retry_with_lax_check: bool, ) -> PyResult> { - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); - - for comb_serializer in choices { - match comb_serializer.json_key(key, &new_extra) { - Ok(v) => return Ok(v), - Err(err) => errors.push(err), - } - } - - // If extra.check is SerCheck::Strict, we're in a nested union - if extra.check != SerCheck::Strict && retry_with_lax_check { - new_extra.check = SerCheck::Lax; - for comb_serializer in choices { - if let Ok(v) = comb_serializer.json_key(key, &new_extra) { - return Ok(v); - } - } - } - - // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings - if extra.check == SerCheck::None { - for err in &errors { - extra.warnings.custom_warning(err.to_string()); - } - } - // Otherwise, if we've encountered errors, return them to the parent union, which should take - // care of the formatting for us - else if !errors.is_empty() { - let message = errors.iter().map(ToString::to_string).collect::>().join("\n"); - return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); - } - infer_json_key(key, extra) + union_serialize( + |comb_serializer, new_extra| comb_serializer.json_key(key, new_extra), + |v| v.map_or_else(|| infer_json_key(key, extra), Ok), + extra, + choices, + retry_with_lax_check, + ) } #[allow(clippy::too_many_arguments)] @@ -168,39 +151,21 @@ fn serde_serialize( choices: &[CombinedSerializer], retry_with_lax_check: bool, ) -> Result { - let py = value.py(); - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new(); - - for comb_serializer in choices { - match comb_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(err) => errors.push(err), - } - } - - // If extra.check is SerCheck::Strict, we're in a nested union - if extra.check != SerCheck::Strict && retry_with_lax_check { - new_extra.check = SerCheck::Lax; - for comb_serializer in choices { - if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) { - return infer_serialize(v.bind(py), serializer, None, None, extra); - } - } - } - - // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings - if extra.check == SerCheck::None { - for err in &errors { - extra.warnings.custom_warning(err.to_string()); - } - } else { - // NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors - // will have to be returned here - } - - infer_serialize(value, serializer, include, exclude, extra) + union_serialize( + |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), + |v| { + infer_serialize( + v.as_ref().map_or(value, |v| v.bind(value.py())), + serializer, + None, + None, + extra, + ) + }, + extra, + choices, + retry_with_lax_check, + ) } impl TypeSerializer for UnionSerializer { From a5fbc2a17432ac90a339e528dd0454ac4e2385ad Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 11 Nov 2024 11:09:52 -0500 Subject: [PATCH 2/8] remove middleman functions --- src/serializers/type_serializers/union.rs | 139 +++++++++------------- 1 file changed, 56 insertions(+), 83 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 6ca702792..d3fb3682c 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -9,7 +9,6 @@ use crate::build_tools::py_schema_err; use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; use crate::definitions::DefinitionsBuilder; use crate::tools::{truncate_safe_repr, SchemaDict}; -use crate::PydanticSerializationUnexpectedValue; use super::{ infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck, @@ -93,7 +92,8 @@ fn union_serialize( } } - if retry_with_lax_check { + // If extra.check is SerCheck::Strict, we're in a nested union + if extra.check != SerCheck::Strict && retry_with_lax_check { new_extra.check = SerCheck::Lax; for comb_serializer in choices { if let Ok(v) = selector(comb_serializer, &new_extra) { @@ -102,72 +102,23 @@ fn union_serialize( } } - for err in &errors { - extra.warnings.custom_warning(err.to_string()); + // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings + if extra.check == SerCheck::None { + for err in &errors { + extra.warnings.custom_warning(err.to_string()); + } } + // Otherwise, if we've encountered errors, return them to the parent union, which should take + // care of the formatting for us + // TODO: change up return type to support this + // else if !errors.is_empty() { + // let message = errors.iter().map(ToString::to_string).collect::>().join("\n"); + // return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); + // } finalizer(None) } -fn to_python( - value: &Bound<'_, PyAny>, - include: Option<&Bound<'_, PyAny>>, - exclude: Option<&Bound<'_, PyAny>>, - extra: &Extra, - choices: &[CombinedSerializer], - retry_with_lax_check: bool, -) -> PyResult { - union_serialize( - |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), - |v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok), - extra, - choices, - retry_with_lax_check, - ) -} - -fn json_key<'a>( - key: &'a Bound<'_, PyAny>, - extra: &Extra, - choices: &[CombinedSerializer], - retry_with_lax_check: bool, -) -> PyResult> { - union_serialize( - |comb_serializer, new_extra| comb_serializer.json_key(key, new_extra), - |v| v.map_or_else(|| infer_json_key(key, extra), Ok), - extra, - choices, - retry_with_lax_check, - ) -} - -#[allow(clippy::too_many_arguments)] -fn serde_serialize( - value: &Bound<'_, PyAny>, - serializer: S, - include: Option<&Bound<'_, PyAny>>, - exclude: Option<&Bound<'_, PyAny>>, - extra: &Extra, - choices: &[CombinedSerializer], - retry_with_lax_check: bool, -) -> Result { - union_serialize( - |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), - |v| { - infer_serialize( - v.as_ref().map_or(value, |v| v.bind(value.py())), - serializer, - None, - None, - extra, - ) - }, - extra, - choices, - retry_with_lax_check, - ) -} - impl TypeSerializer for UnionSerializer { fn to_python( &self, @@ -176,10 +127,9 @@ impl TypeSerializer for UnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> PyResult { - to_python( - value, - include, - exclude, + union_serialize( + |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), + |v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok), extra, &self.choices, self.retry_with_lax_check(), @@ -187,7 +137,13 @@ impl TypeSerializer for UnionSerializer { } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - json_key(key, extra, &self.choices, self.retry_with_lax_check()) + union_serialize( + |comb_serializer, new_extra| comb_serializer.json_key(key, new_extra), + |v| v.map_or_else(|| infer_json_key(key, extra), Ok), + extra, + &self.choices, + self.retry_with_lax_check(), + ) } fn serde_serialize( @@ -198,11 +154,17 @@ impl TypeSerializer for UnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result { - serde_serialize( - value, - serializer, - include, - exclude, + union_serialize( + |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), + |v| { + infer_serialize( + v.as_ref().map_or(value, |v| v.bind(value.py())), + serializer, + None, + None, + extra, + ) + }, extra, &self.choices, self.retry_with_lax_check(), @@ -296,10 +258,9 @@ impl TypeSerializer for TaggedUnionSerializer { } } - to_python( - value, - include, - exclude, + union_serialize( + |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), + |v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok), extra, &self.choices, self.retry_with_lax_check(), @@ -329,7 +290,13 @@ impl TypeSerializer for TaggedUnionSerializer { } } - json_key(key, extra, &self.choices, self.retry_with_lax_check()) + union_serialize( + |comb_serializer, new_extra| comb_serializer.json_key(key, new_extra), + |v| v.map_or_else(|| infer_json_key(key, extra), Ok), + extra, + &self.choices, + self.retry_with_lax_check(), + ) } fn serde_serialize( @@ -363,11 +330,17 @@ impl TypeSerializer for TaggedUnionSerializer { } } - serde_serialize( - value, - serializer, - include, - exclude, + union_serialize( + |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), + |v| { + infer_serialize( + v.as_ref().map_or(value, |v| v.bind(value.py())), + serializer, + None, + None, + extra, + ) + }, extra, &self.choices, self.retry_with_lax_check(), From 742a3e9316eb21e955beabd72f38e7886bcd1ed8 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 11 Nov 2024 11:17:47 -0500 Subject: [PATCH 3/8] add support for the raised ser unexpected err --- src/serializers/type_serializers/union.rs | 24 ++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index d3fb3682c..2cb5780c0 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -8,6 +8,7 @@ use std::borrow::Cow; use crate::build_tools::py_schema_err; use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; use crate::definitions::DefinitionsBuilder; +use crate::serializers::PydanticSerializationUnexpectedValue; use crate::tools::{truncate_safe_repr, SchemaDict}; use super::{ @@ -79,7 +80,7 @@ fn union_serialize( extra: &Extra, choices: &[CombinedSerializer], retry_with_lax_check: bool, -) -> R { +) -> PyResult { // try the serializers in left to right order with error_on fallback=true let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; @@ -87,7 +88,7 @@ fn union_serialize( for comb_serializer in choices { match selector(comb_serializer, &new_extra) { - Ok(v) => return finalizer(Some(v)), + Ok(v) => return Ok(finalizer(Some(v))), Err(err) => errors.push(err), } } @@ -97,7 +98,7 @@ fn union_serialize( new_extra.check = SerCheck::Lax; for comb_serializer in choices { if let Ok(v) = selector(comb_serializer, &new_extra) { - return finalizer(Some(v)); + return Ok(finalizer(Some(v))); } } } @@ -110,13 +111,12 @@ fn union_serialize( } // Otherwise, if we've encountered errors, return them to the parent union, which should take // care of the formatting for us - // TODO: change up return type to support this - // else if !errors.is_empty() { - // let message = errors.iter().map(ToString::to_string).collect::>().join("\n"); - // return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); - // } + else if !errors.is_empty() { + let message = errors.iter().map(ToString::to_string).collect::>().join("\n"); + return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); + } - finalizer(None) + Ok(finalizer(None)) } impl TypeSerializer for UnionSerializer { @@ -134,6 +134,7 @@ impl TypeSerializer for UnionSerializer { &self.choices, self.retry_with_lax_check(), ) + .unwrap() } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { @@ -144,6 +145,7 @@ impl TypeSerializer for UnionSerializer { &self.choices, self.retry_with_lax_check(), ) + .unwrap() } fn serde_serialize( @@ -169,6 +171,7 @@ impl TypeSerializer for UnionSerializer { &self.choices, self.retry_with_lax_check(), ) + .unwrap() } fn get_name(&self) -> &str { @@ -265,6 +268,7 @@ impl TypeSerializer for TaggedUnionSerializer { &self.choices, self.retry_with_lax_check(), ) + .unwrap() } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { @@ -297,6 +301,7 @@ impl TypeSerializer for TaggedUnionSerializer { &self.choices, self.retry_with_lax_check(), ) + .unwrap() } fn serde_serialize( @@ -345,6 +350,7 @@ impl TypeSerializer for TaggedUnionSerializer { &self.choices, self.retry_with_lax_check(), ) + .unwrap() } fn get_name(&self) -> &str { From 92f5bbcde1c7fa08be2afb28e6b15ffd9b6c6326 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 11 Nov 2024 12:16:38 -0500 Subject: [PATCH 4/8] fix tests --- src/serializers/type_serializers/union.rs | 38 ++++++++++++++--------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 2cb5780c0..385d8dae1 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -133,8 +133,7 @@ impl TypeSerializer for UnionSerializer { extra, &self.choices, self.retry_with_lax_check(), - ) - .unwrap() + )? } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { @@ -144,8 +143,7 @@ impl TypeSerializer for UnionSerializer { extra, &self.choices, self.retry_with_lax_check(), - ) - .unwrap() + )? } fn serde_serialize( @@ -156,7 +154,7 @@ impl TypeSerializer for UnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result { - union_serialize( + match union_serialize( |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), |v| { infer_serialize( @@ -170,8 +168,15 @@ impl TypeSerializer for UnionSerializer { extra, &self.choices, self.retry_with_lax_check(), - ) - .unwrap() + ) { + Ok(v) => v, + // TODO: we don't expect to hit this branch, but if we do, we should change the return + // type of this function to return a PyResult... + Err(err) => { + let message = err.to_string(); + Err(serde::ser::Error::custom(message)) + } + } } fn get_name(&self) -> &str { @@ -267,8 +272,7 @@ impl TypeSerializer for TaggedUnionSerializer { extra, &self.choices, self.retry_with_lax_check(), - ) - .unwrap() + )? } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { @@ -300,8 +304,7 @@ impl TypeSerializer for TaggedUnionSerializer { extra, &self.choices, self.retry_with_lax_check(), - ) - .unwrap() + )? } fn serde_serialize( @@ -335,7 +338,7 @@ impl TypeSerializer for TaggedUnionSerializer { } } - union_serialize( + match union_serialize( |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), |v| { infer_serialize( @@ -349,8 +352,15 @@ impl TypeSerializer for TaggedUnionSerializer { extra, &self.choices, self.retry_with_lax_check(), - ) - .unwrap() + ) { + Ok(v) => v, + // TODO: we don't expect to hit this branch, but if we do, we should change the return + // type of this function to return a PyResult... + Err(err) => { + let message = err.to_string(); + Err(serde::ser::Error::custom(message)) + } + } } fn get_name(&self) -> &str { From 57c3c6f064b8afbceb4260a43d420afc274237a4 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 11 Nov 2024 12:20:10 -0500 Subject: [PATCH 5/8] cleaner error handling --- src/serializers/type_serializers/union.rs | 26 ++++++----------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 385d8dae1..3694878a7 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -154,7 +154,7 @@ impl TypeSerializer for UnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result { - match union_serialize( + union_serialize( |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), |v| { infer_serialize( @@ -168,15 +168,8 @@ impl TypeSerializer for UnionSerializer { extra, &self.choices, self.retry_with_lax_check(), - ) { - Ok(v) => v, - // TODO: we don't expect to hit this branch, but if we do, we should change the return - // type of this function to return a PyResult... - Err(err) => { - let message = err.to_string(); - Err(serde::ser::Error::custom(message)) - } - } + ) + .map_err(|err| serde::ser::Error::custom(err.to_string()))? } fn get_name(&self) -> &str { @@ -338,7 +331,7 @@ impl TypeSerializer for TaggedUnionSerializer { } } - match union_serialize( + union_serialize( |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), |v| { infer_serialize( @@ -352,15 +345,8 @@ impl TypeSerializer for TaggedUnionSerializer { extra, &self.choices, self.retry_with_lax_check(), - ) { - Ok(v) => v, - // TODO: we don't expect to hit this branch, but if we do, we should change the return - // type of this function to return a PyResult... - Err(err) => { - let message = err.to_string(); - Err(serde::ser::Error::custom(message)) - } - } + ) + .map_err(|err| serde::ser::Error::custom(err.to_string()))? } fn get_name(&self) -> &str { From 7ea0d645870fd894152216c70e49775aa6a27f05 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 11 Nov 2024 15:06:58 -0500 Subject: [PATCH 6/8] checkpoint - using tagged_union_serializer too --- src/serializers/type_serializers/union.rs | 187 ++++++++++++---------- 1 file changed, 103 insertions(+), 84 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 3694878a7..47cf1aa1b 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -4,6 +4,7 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyTuple}; use smallvec::SmallVec; use std::borrow::Cow; +use std::sync::Arc; use crate::build_tools::py_schema_err; use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; @@ -119,6 +120,41 @@ fn union_serialize( Ok(finalizer(None)) } +fn tagged_union_serialize( + discriminator_value: Option>, + lookup: &HashMap, + // if this returns `Ok(v)`, we picked a union variant to serialize, where + // `S` is intermediate state which can be passed on to the finalizer + mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult, + extra: &Extra, + choices: &Vec, + retry_with_lax_check: bool, +) -> Option { + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + + if let Some(tag) = discriminator_value { + let tag_str = tag.to_string(); + if let Some(&serializer_index) = lookup.get(&tag_str) { + let selected_serializer = &choices[serializer_index]; + + match selector(&selected_serializer, &new_extra) { + Ok(v) => return Some(v), + Err(_) => { + if retry_with_lax_check { + new_extra.check = SerCheck::Lax; + if let Ok(v) = selector(&selected_serializer, &new_extra) { + return Some(v); + } + } + } + } + } + } + + None +} + impl TypeSerializer for UnionSerializer { fn to_python( &self, @@ -237,67 +273,56 @@ impl TypeSerializer for TaggedUnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> PyResult { - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - - if let Some(tag) = self.get_discriminator_value(value, extra) { - let tag_str = tag.to_string(); - if let Some(&serializer_index) = self.lookup.get(&tag_str) { - let serializer = &self.choices[serializer_index]; - - match serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return Ok(v), - Err(_) => { - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) { - return Ok(v); - } - } - } - } - } - } + let to_python_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| { + comb_serializer.to_python(value, include, exclude, new_extra) + }; - union_serialize( - |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), - |v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok), + tagged_union_serialize( + self.get_discriminator_value(value, extra), + &self.lookup, + to_python_selector, extra, &self.choices, self.retry_with_lax_check(), - )? + ) + .map_or_else( + || { + union_serialize( + to_python_selector, + |v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok), + extra, + &self.choices, + self.retry_with_lax_check(), + )? + }, + Ok, + ) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - - if let Some(tag) = self.get_discriminator_value(key, extra) { - let tag_str = tag.to_string(); - if let Some(&serializer_index) = self.lookup.get(&tag_str) { - let serializer = &self.choices[serializer_index]; - - match serializer.json_key(key, &new_extra) { - Ok(v) => return Ok(v), - Err(_) => { - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - if let Ok(v) = serializer.json_key(key, &new_extra) { - return Ok(v); - } - } - } - } - } - } + let json_key_selector = + |comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra); - union_serialize( - |comb_serializer, new_extra| comb_serializer.json_key(key, new_extra), - |v| v.map_or_else(|| infer_json_key(key, extra), Ok), + tagged_union_serialize( + self.get_discriminator_value(key, extra), + &self.lookup, + json_key_selector, extra, &self.choices, self.retry_with_lax_check(), - )? + ) + .map_or_else( + || { + union_serialize( + json_key_selector, + |v| v.map_or_else(|| infer_json_key(key, extra), Ok), + extra, + &self.choices, + self.retry_with_lax_check(), + )? + }, + Ok, + ) } fn serde_serialize( @@ -308,45 +333,39 @@ impl TypeSerializer for TaggedUnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result { - let py = value.py(); - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - - if let Some(tag) = self.get_discriminator_value(value, extra) { - let tag_str = tag.to_string(); - if let Some(&serializer_index) = self.lookup.get(&tag_str) { - let selected_serializer = &self.choices[serializer_index]; - - match selected_serializer.to_python(value, include, exclude, &new_extra) { - Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra), - Err(_) => { - if self.retry_with_lax_check() { - new_extra.check = SerCheck::Lax; - if let Ok(v) = selected_serializer.to_python(value, include, exclude, &new_extra) { - return infer_serialize(v.bind(py), serializer, None, None, extra); - } - } - } - } - } - } + let serde_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| { + comb_serializer.to_python(value, include, exclude, new_extra) + }; - union_serialize( - |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), - |v| { - infer_serialize( - v.as_ref().map_or(value, |v| v.bind(value.py())), - serializer, - None, - None, - extra, - ) - }, + tagged_union_serialize( + None, + &self.lookup, + serde_selector, extra, &self.choices, self.retry_with_lax_check(), ) - .map_err(|err| serde::ser::Error::custom(err.to_string()))? + .map_or_else( + || { + union_serialize( + serde_selector, + |v| { + infer_serialize( + v.as_ref().map_or(value, |v| v.bind(value.py())), + serializer, + None, + None, + extra, + ) + }, + extra, + &self.choices, + self.retry_with_lax_check(), + ) + .map_err(|err| serde::ser::Error::custom(err.to_string()))? + }, + |v| infer_serialize(v.bind(value.py()), serializer, None, None, extra), + ) } fn get_name(&self) -> &str { From f3a304a4b7a1e6bb2d867f09e25a8a5b406c056d Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 11 Nov 2024 15:14:00 -0500 Subject: [PATCH 7/8] figure out issue with borrowed serializer --- src/serializers/type_serializers/union.rs | 43 +++++++++++------------ 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 47cf1aa1b..77ffb23d2 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -4,7 +4,6 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyTuple}; use smallvec::SmallVec; use std::borrow::Cow; -use std::sync::Arc; use crate::build_tools::py_schema_err; use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD}; @@ -127,7 +126,7 @@ fn tagged_union_serialize( // `S` is intermediate state which can be passed on to the finalizer mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult, extra: &Extra, - choices: &Vec, + choices: &[CombinedSerializer], retry_with_lax_check: bool, ) -> Option { let mut new_extra = extra.clone(); @@ -138,12 +137,12 @@ fn tagged_union_serialize( if let Some(&serializer_index) = lookup.get(&tag_str) { let selected_serializer = &choices[serializer_index]; - match selector(&selected_serializer, &new_extra) { + match selector(selected_serializer, &new_extra) { Ok(v) => return Some(v), Err(_) => { if retry_with_lax_check { new_extra.check = SerCheck::Lax; - if let Ok(v) = selector(&selected_serializer, &new_extra) { + if let Ok(v) = selector(selected_serializer, &new_extra) { return Some(v); } } @@ -337,35 +336,33 @@ impl TypeSerializer for TaggedUnionSerializer { comb_serializer.to_python(value, include, exclude, new_extra) }; - tagged_union_serialize( + if let Some(v) = tagged_union_serialize( None, &self.lookup, serde_selector, extra, &self.choices, self.retry_with_lax_check(), - ) - .map_or_else( - || { - union_serialize( - serde_selector, - |v| { - infer_serialize( - v.as_ref().map_or(value, |v| v.bind(value.py())), - serializer, - None, - None, - extra, - ) - }, + ) { + return infer_serialize(v.bind(value.py()), serializer, None, None, extra); + } + + union_serialize( + serde_selector, + |v| { + infer_serialize( + v.as_ref().map_or(value, |v| v.bind(value.py())), + serializer, + None, + None, extra, - &self.choices, - self.retry_with_lax_check(), ) - .map_err(|err| serde::ser::Error::custom(err.to_string()))? }, - |v| infer_serialize(v.bind(value.py()), serializer, None, None, extra), + extra, + &self.choices, + self.retry_with_lax_check(), ) + .map_err(|err| serde::ser::Error::custom(err.to_string()))? } fn get_name(&self) -> &str { From ee9dd3d1fca65afe941e933300ddf48e4bd6168e Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Tue, 12 Nov 2024 13:08:30 -0500 Subject: [PATCH 8/8] continue refactor --- src/serializers/type_serializers/union.rs | 125 +++++++--------------- 1 file changed, 38 insertions(+), 87 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 77ffb23d2..40a72e9dd 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -70,17 +70,15 @@ impl UnionSerializer { impl_py_gc_traverse!(UnionSerializer { choices }); -fn union_serialize( - // if this returns `Ok(v)`, we picked a union variant to serialize, where - // `S` is intermediate state which can be passed on to the finalizer +fn union_serialize( + // if this returns `Ok(Some(v))`, we picked a union variant to serialize, + // Or `Ok(None)` if we couldn't find a suitable variant to serialize + // Finally, `Err(err)` if we encountered errors while trying to serialize mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult, - // if called with `Some(v)`, we have intermediate state to finish - // if `None`, we need to just go to fallback - finalizer: impl FnOnce(Option) -> R, extra: &Extra, choices: &[CombinedSerializer], retry_with_lax_check: bool, -) -> PyResult { +) -> PyResult> { // try the serializers in left to right order with error_on fallback=true let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; @@ -88,7 +86,7 @@ fn union_serialize( for comb_serializer in choices { match selector(comb_serializer, &new_extra) { - Ok(v) => return Ok(finalizer(Some(v))), + Ok(v) => return Ok(Some(v)), Err(err) => errors.push(err), } } @@ -98,7 +96,7 @@ fn union_serialize( new_extra.check = SerCheck::Lax; for comb_serializer in choices { if let Ok(v) = selector(comb_serializer, &new_extra) { - return Ok(finalizer(Some(v))); + return Ok(Some(v)); } } } @@ -116,7 +114,7 @@ fn union_serialize( return Err(PydanticSerializationUnexpectedValue::new_err(Some(message))); } - Ok(finalizer(None)) + Ok(None) } fn tagged_union_serialize( @@ -128,7 +126,7 @@ fn tagged_union_serialize( extra: &Extra, choices: &[CombinedSerializer], retry_with_lax_check: bool, -) -> Option { +) -> PyResult> { let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; @@ -138,12 +136,12 @@ fn tagged_union_serialize( let selected_serializer = &choices[serializer_index]; match selector(selected_serializer, &new_extra) { - Ok(v) => return Some(v), + Ok(v) => return Ok(Some(v)), Err(_) => { if retry_with_lax_check { new_extra.check = SerCheck::Lax; if let Ok(v) = selector(selected_serializer, &new_extra) { - return Some(v); + return Ok(Some(v)); } } } @@ -151,7 +149,10 @@ fn tagged_union_serialize( } } - None + // if we haven't returned at this point, we should fallback to the union serializer + // which preserves the historical expectation that we do our best with serialization + // even if that means we resort to inference + union_serialize(selector, extra, choices, retry_with_lax_check) } impl TypeSerializer for UnionSerializer { @@ -164,21 +165,21 @@ impl TypeSerializer for UnionSerializer { ) -> PyResult { union_serialize( |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), - |v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok), extra, &self.choices, self.retry_with_lax_check(), )? + .map_or_else(|| infer_to_python(value, include, exclude, extra), Ok) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { union_serialize( |comb_serializer, new_extra| comb_serializer.json_key(key, new_extra), - |v| v.map_or_else(|| infer_json_key(key, extra), Ok), extra, &self.choices, self.retry_with_lax_check(), )? + .map_or_else(|| infer_json_key(key, extra), Ok) } fn serde_serialize( @@ -189,22 +190,16 @@ impl TypeSerializer for UnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result { - union_serialize( + match union_serialize( |comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra), - |v| { - infer_serialize( - v.as_ref().map_or(value, |v| v.bind(value.py())), - serializer, - None, - None, - extra, - ) - }, extra, &self.choices, self.retry_with_lax_check(), - ) - .map_err(|err| serde::ser::Error::custom(err.to_string()))? + ) { + Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra), + Ok(None) => infer_serialize(value, serializer, include, exclude, extra), + Err(err) => Err(serde::ser::Error::custom(err.to_string())), + } } fn get_name(&self) -> &str { @@ -272,56 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> PyResult { - let to_python_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| { - comb_serializer.to_python(value, include, exclude, new_extra) - }; - tagged_union_serialize( self.get_discriminator_value(value, extra), &self.lookup, - to_python_selector, + |comb_serializer: &CombinedSerializer, new_extra: &Extra| { + comb_serializer.to_python(value, include, exclude, new_extra) + }, extra, &self.choices, self.retry_with_lax_check(), - ) - .map_or_else( - || { - union_serialize( - to_python_selector, - |v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok), - extra, - &self.choices, - self.retry_with_lax_check(), - )? - }, - Ok, - ) + )? + .map_or_else(|| infer_to_python(value, include, exclude, extra), Ok) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - let json_key_selector = - |comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra); - tagged_union_serialize( self.get_discriminator_value(key, extra), &self.lookup, - json_key_selector, + |comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra), extra, &self.choices, self.retry_with_lax_check(), - ) - .map_or_else( - || { - union_serialize( - json_key_selector, - |v| v.map_or_else(|| infer_json_key(key, extra), Ok), - extra, - &self.choices, - self.retry_with_lax_check(), - )? - }, - Ok, - ) + )? + .map_or_else(|| infer_json_key(key, extra), Ok) } fn serde_serialize( @@ -332,37 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result { - let serde_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| { - comb_serializer.to_python(value, include, exclude, new_extra) - }; - - if let Some(v) = tagged_union_serialize( + match tagged_union_serialize( None, &self.lookup, - serde_selector, + |comb_serializer: &CombinedSerializer, new_extra: &Extra| { + comb_serializer.to_python(value, include, exclude, new_extra) + }, extra, &self.choices, self.retry_with_lax_check(), ) { - return infer_serialize(v.bind(value.py()), serializer, None, None, extra); + Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra), + Ok(None) => infer_serialize(value, serializer, include, exclude, extra), + Err(err) => Err(serde::ser::Error::custom(err.to_string())), } - - union_serialize( - serde_selector, - |v| { - infer_serialize( - v.as_ref().map_or(value, |v| v.bind(value.py())), - serializer, - None, - None, - extra, - ) - }, - extra, - &self.choices, - self.retry_with_lax_check(), - ) - .map_err(|err| serde::ser::Error::custom(err.to_string()))? } fn get_name(&self) -> &str {