Skip to content

Commit

Permalink
WIP: Simplify shared union serializer logic (#1538)
Browse files Browse the repository at this point in the history
Co-authored-by: David Hewitt <mail@davidhewitt.dev>
  • Loading branch information
sydney-runkle and davidhewitt authored Nov 13, 2024
1 parent 061711f commit 6472887
Showing 1 changed file with 84 additions and 177 deletions.
261 changes: 84 additions & 177 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use std::borrow::Cow;
use crate::build_tools::py_schema_err;
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
use crate::definitions::DefinitionsBuilder;
use crate::serializers::PydanticSerializationUnexpectedValue;
use crate::tools::{truncate_safe_repr, SchemaDict};
use crate::PydanticSerializationUnexpectedValue;

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
Expand Down Expand Up @@ -70,22 +70,23 @@ impl UnionSerializer {

impl_py_gc_traverse!(UnionSerializer { choices });

fn to_python(
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
fn union_serialize<S>(
// `selector` attempts serialization with a single choice. The overall result is
// `Ok(Some(v))` if we picked a union variant to serialize,
// `Ok(None)` if we couldn't find a suitable variant to serialize, and
// `Err(err)` if we encountered errors while trying to serialize
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> PyResult<PyObject> {
) -> PyResult<Option<S>> {
// try the serializers in left-to-right order with strict checking (`SerCheck::Strict`)
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return Ok(v),
match selector(comb_serializer, &new_extra) {
Ok(v) => return Ok(Some(v)),
Err(err) => errors.push(err),
}
}
Expand All @@ -94,8 +95,8 @@ fn to_python(
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
return Ok(v);
if let Ok(v) = selector(comb_serializer, &new_extra) {
return Ok(Some(v));
}
}
}
Expand All @@ -113,94 +114,45 @@ fn to_python(
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}

infer_to_python(value, include, exclude, extra)
Ok(None)
}

fn json_key<'a>(
key: &'a Bound<'_, PyAny>,
fn tagged_union_serialize<S>(
discriminator_value: Option<Py<PyAny>>,
lookup: &HashMap<String, usize>,
// if this returns `Ok(v)`, we picked a union variant to serialize, where
// `v: S` is the serialized result handed back to the caller as `Ok(Some(v))`
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> PyResult<Cow<'a, str>> {
) -> PyResult<Option<S>> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer.json_key(key, &new_extra) {
Ok(v) => return Ok(v),
Err(err) => errors.push(err),
}
}

// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.json_key(key, &new_extra) {
return Ok(v);
}
}
}

// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
}
// Otherwise, if we've encountered errors, return them to the parent union, which should take
// care of the formatting for us
else if !errors.is_empty() {
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}
infer_json_key(key, extra)
}

#[allow(clippy::too_many_arguments)]
fn serde_serialize<S: serde::ser::Serializer>(
value: &Bound<'_, PyAny>,
serializer: S,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> Result<S::Ok, S::Error> {
let py = value.py();
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
Err(err) => errors.push(err),
}
}

// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
return infer_serialize(v.bind(py), serializer, None, None, extra);
if let Some(tag) = discriminator_value {
let tag_str = tag.to_string();
if let Some(&serializer_index) = lookup.get(&tag_str) {
let selected_serializer = &choices[serializer_index];

match selector(selected_serializer, &new_extra) {
Ok(v) => return Ok(Some(v)),
Err(_) => {
if retry_with_lax_check {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selector(selected_serializer, &new_extra) {
return Ok(Some(v));
}
}
}
}
}
}

// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
} else {
// NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
// will have to be returned here
}

infer_serialize(value, serializer, include, exclude, extra)
// if we haven't returned at this point, we should fall back to the union serializer,
// which preserves the historical expectation that we do our best with serialization
// even if that means we resort to inference
union_serialize(selector, extra, choices, retry_with_lax_check)
}

impl TypeSerializer for UnionSerializer {
Expand All @@ -211,18 +163,23 @@ impl TypeSerializer for UnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
to_python(
value,
include,
exclude,
union_serialize(
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)
)?
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
json_key(key, extra, &self.choices, self.retry_with_lax_check())
union_serialize(
|comb_serializer, new_extra| comb_serializer.json_key(key, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_json_key(key, extra), Ok)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -233,15 +190,16 @@ impl TypeSerializer for UnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
serde_serialize(
value,
serializer,
include,
exclude,
match union_serialize(
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)
) {
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
}
}

fn get_name(&self) -> &str {
Expand Down Expand Up @@ -309,62 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = self.get_discriminator_value(value, extra) {
let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let serializer = &self.choices[serializer_index];

match serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return Ok(v),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) {
return Ok(v);
}
}
}
}
}
}

to_python(
value,
include,
exclude,
tagged_union_serialize(
self.get_discriminator_value(value, extra),
&self.lookup,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)
)?
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = self.get_discriminator_value(key, extra) {
let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let serializer = &self.choices[serializer_index];

match serializer.json_key(key, &new_extra) {
Ok(v) => return Ok(v),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = serializer.json_key(key, &new_extra) {
return Ok(v);
}
}
}
}
}
}

json_key(key, extra, &self.choices, self.retry_with_lax_check())
tagged_union_serialize(
self.get_discriminator_value(key, extra),
&self.lookup,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_json_key(key, extra), Ok)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -375,38 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let py = value.py();
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = self.get_discriminator_value(value, extra) {
let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let selected_serializer = &self.choices[serializer_index];

match selected_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selected_serializer.to_python(value, include, exclude, &new_extra) {
return infer_serialize(v.bind(py), serializer, None, None, extra);
}
}
}
}
}
}

serde_serialize(
value,
serializer,
include,
exclude,
match tagged_union_serialize(
None,
&self.lookup,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)
) {
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
}
}

fn get_name(&self) -> &str {
Expand Down

0 comments on commit 6472887

Please sign in to comment.