Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify shared union serializer logic #1538

Merged
merged 9 commits into from
Nov 13, 2024
261 changes: 84 additions & 177 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use std::borrow::Cow;
use crate::build_tools::py_schema_err;
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
use crate::definitions::DefinitionsBuilder;
use crate::serializers::PydanticSerializationUnexpectedValue;
use crate::tools::{truncate_safe_repr, SchemaDict};
use crate::PydanticSerializationUnexpectedValue;

use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerCheck,
Expand Down Expand Up @@ -70,22 +70,23 @@ impl UnionSerializer {

impl_py_gc_traverse!(UnionSerializer { choices });

fn to_python(
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
fn union_serialize<S>(
// if this returns `Ok(Some(v))`, we picked a union variant to serialize,
// Or `Ok(None)` if we couldn't find a suitable variant to serialize
// Finally, `Err(err)` if we encountered errors while trying to serialize
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> PyResult<PyObject> {
) -> PyResult<Option<S>> {
// try the serializers in left to right order with error_on fallback=true
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return Ok(v),
match selector(comb_serializer, &new_extra) {
Ok(v) => return Ok(Some(v)),
Err(err) => errors.push(err),
}
}
Expand All @@ -94,8 +95,8 @@ fn to_python(
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
return Ok(v);
if let Ok(v) = selector(comb_serializer, &new_extra) {
return Ok(Some(v));
}
}
}
Expand All @@ -113,94 +114,45 @@ fn to_python(
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}

infer_to_python(value, include, exclude, extra)
Ok(None)
}

fn json_key<'a>(
key: &'a Bound<'_, PyAny>,
fn tagged_union_serialize<S>(
discriminator_value: Option<Py<PyAny>>,
lookup: &HashMap<String, usize>,
// if this returns `Ok(v)`, we picked a union variant to serialize, where
// `S` is intermediate state which can be passed on to the finalizer
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> PyResult<Cow<'a, str>> {
) -> PyResult<Option<S>> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer.json_key(key, &new_extra) {
Ok(v) => return Ok(v),
Err(err) => errors.push(err),
}
}

// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.json_key(key, &new_extra) {
return Ok(v);
}
}
}

// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
}
// Otherwise, if we've encountered errors, return them to the parent union, which should take
// care of the formatting for us
else if !errors.is_empty() {
let message = errors.iter().map(ToString::to_string).collect::<Vec<_>>().join("\n");
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}
infer_json_key(key, extra)
}

#[allow(clippy::too_many_arguments)]
fn serde_serialize<S: serde::ser::Serializer>(
value: &Bound<'_, PyAny>,
serializer: S,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> Result<S::Ok, S::Error> {
let py = value.py();
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
Err(err) => errors.push(err),
}
}

// If extra.check is SerCheck::Strict, we're in a nested union
if extra.check != SerCheck::Strict && retry_with_lax_check {
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = comb_serializer.to_python(value, include, exclude, &new_extra) {
return infer_serialize(v.bind(py), serializer, None, None, extra);
if let Some(tag) = discriminator_value {
let tag_str = tag.to_string();
if let Some(&serializer_index) = lookup.get(&tag_str) {
let selected_serializer = &choices[serializer_index];

match selector(selected_serializer, &new_extra) {
Ok(v) => return Ok(Some(v)),
Err(_) => {
if retry_with_lax_check {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selector(selected_serializer, &new_extra) {
return Ok(Some(v));
}
}
}
}
}
sydney-runkle marked this conversation as resolved.
Show resolved Hide resolved
}

// If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
if extra.check == SerCheck::None {
for err in &errors {
extra.warnings.custom_warning(err.to_string());
}
} else {
// NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
// will have to be returned here
}

infer_serialize(value, serializer, include, exclude, extra)
// if we haven't returned at this point, we should fallback to the union serializer
// which preserves the historical expectation that we do our best with serialization
// even if that means we resort to inference
union_serialize(selector, extra, choices, retry_with_lax_check)
}

impl TypeSerializer for UnionSerializer {
Expand All @@ -211,18 +163,23 @@ impl TypeSerializer for UnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
to_python(
value,
include,
exclude,
union_serialize(
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)
)?
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
json_key(key, extra, &self.choices, self.retry_with_lax_check())
union_serialize(
|comb_serializer, new_extra| comb_serializer.json_key(key, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_json_key(key, extra), Ok)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -233,15 +190,16 @@ impl TypeSerializer for UnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
serde_serialize(
value,
serializer,
include,
exclude,
match union_serialize(
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)
) {
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
}
}

fn get_name(&self) -> &str {
Expand Down Expand Up @@ -309,62 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = self.get_discriminator_value(value, extra) {
let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let serializer = &self.choices[serializer_index];

match serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return Ok(v),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = serializer.to_python(value, include, exclude, &new_extra) {
return Ok(v);
}
}
}
}
}
}

to_python(
value,
include,
exclude,
tagged_union_serialize(
self.get_discriminator_value(value, extra),
&self.lookup,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)
)?
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = self.get_discriminator_value(key, extra) {
let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let serializer = &self.choices[serializer_index];

match serializer.json_key(key, &new_extra) {
Ok(v) => return Ok(v),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = serializer.json_key(key, &new_extra) {
return Ok(v);
}
}
}
}
}
}

json_key(key, extra, &self.choices, self.retry_with_lax_check())
tagged_union_serialize(
self.get_discriminator_value(key, extra),
&self.lookup,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_json_key(key, extra), Ok)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -375,38 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let py = value.py();
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

if let Some(tag) = self.get_discriminator_value(value, extra) {
let tag_str = tag.to_string();
if let Some(&serializer_index) = self.lookup.get(&tag_str) {
let selected_serializer = &self.choices[serializer_index];

match selected_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return infer_serialize(v.bind(py), serializer, None, None, extra),
Err(_) => {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selected_serializer.to_python(value, include, exclude, &new_extra) {
return infer_serialize(v.bind(py), serializer, None, None, extra);
}
}
}
}
}
}

serde_serialize(
value,
serializer,
include,
exclude,
match tagged_union_serialize(
None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sydney-runkle this is the source of the perf regression; accidentally switched off tagged union serialization optimization in the JSON case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😬 oops. Great find!

&self.lookup,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)
) {
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
}
}

fn get_name(&self) -> &str {
Expand Down
Loading