Skip to content

Commit

Permalink
continue refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle committed Nov 12, 2024
1 parent f3a304a commit ee9dd3d
Showing 1 changed file with 38 additions and 87 deletions.
125 changes: 38 additions & 87 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,23 @@ impl UnionSerializer {

impl_py_gc_traverse!(UnionSerializer { choices });

fn union_serialize<S, R>(
// if this returns `Ok(v)`, we picked a union variant to serialize, where
// `S` is intermediate state which can be passed on to the finalizer
fn union_serialize<S>(
// if this returns `Ok(Some(v))`, we picked a union variant to serialize,
// Or `Ok(None)` if we couldn't find a suitable variant to serialize
// Finally, `Err(err)` if we encountered errors while trying to serialize
mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult<S>,
// if called with `Some(v)`, we have intermediate state to finish
// if `None`, we need to just go to fallback
finalizer: impl FnOnce(Option<S>) -> R,
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> PyResult<R> {
) -> PyResult<Option<S>> {
// try the serializers in left to right order with error_on fallback=true
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;
let mut errors: SmallVec<[PyErr; SMALL_UNION_THRESHOLD]> = SmallVec::new();

for comb_serializer in choices {
match selector(comb_serializer, &new_extra) {
Ok(v) => return Ok(finalizer(Some(v))),
Ok(v) => return Ok(Some(v)),
Err(err) => errors.push(err),
}
}
Expand All @@ -98,7 +96,7 @@ fn union_serialize<S, R>(
new_extra.check = SerCheck::Lax;
for comb_serializer in choices {
if let Ok(v) = selector(comb_serializer, &new_extra) {
return Ok(finalizer(Some(v)));
return Ok(Some(v));
}
}
}
Expand All @@ -116,7 +114,7 @@ fn union_serialize<S, R>(
return Err(PydanticSerializationUnexpectedValue::new_err(Some(message)));
}

Ok(finalizer(None))
Ok(None)
}

fn tagged_union_serialize<S>(
Expand All @@ -128,7 +126,7 @@ fn tagged_union_serialize<S>(
extra: &Extra,
choices: &[CombinedSerializer],
retry_with_lax_check: bool,
) -> Option<S> {
) -> PyResult<Option<S>> {
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

Expand All @@ -138,20 +136,23 @@ fn tagged_union_serialize<S>(
let selected_serializer = &choices[serializer_index];

match selector(selected_serializer, &new_extra) {
Ok(v) => return Some(v),
Ok(v) => return Ok(Some(v)),
Err(_) => {
if retry_with_lax_check {
new_extra.check = SerCheck::Lax;
if let Ok(v) = selector(selected_serializer, &new_extra) {
return Some(v);
return Ok(Some(v));
}
}
}
}
}
}

None
// if we haven't returned at this point, we should fallback to the union serializer
// which preserves the historical expectation that we do our best with serialization
// even if that means we resort to inference
union_serialize(selector, extra, choices, retry_with_lax_check)
}

impl TypeSerializer for UnionSerializer {
Expand All @@ -164,21 +165,21 @@ impl TypeSerializer for UnionSerializer {
) -> PyResult<PyObject> {
union_serialize(
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
|v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
union_serialize(
|comb_serializer, new_extra| comb_serializer.json_key(key, new_extra),
|v| v.map_or_else(|| infer_json_key(key, extra), Ok),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
.map_or_else(|| infer_json_key(key, extra), Ok)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -189,22 +190,16 @@ impl TypeSerializer for UnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
union_serialize(
match union_serialize(
|comb_serializer, new_extra| comb_serializer.to_python(value, include, exclude, new_extra),
|v| {
infer_serialize(
v.as_ref().map_or(value, |v| v.bind(value.py())),
serializer,
None,
None,
extra,
)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)
.map_err(|err| serde::ser::Error::custom(err.to_string()))?
) {
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
}
}

fn get_name(&self) -> &str {
Expand Down Expand Up @@ -272,56 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
let to_python_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
};

tagged_union_serialize(
self.get_discriminator_value(value, extra),
&self.lookup,
to_python_selector,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)
.map_or_else(
|| {
union_serialize(
to_python_selector,
|v| v.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
},
Ok,
)
)?
.map_or_else(|| infer_to_python(value, include, exclude, extra), Ok)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
let json_key_selector =
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra);

tagged_union_serialize(
self.get_discriminator_value(key, extra),
&self.lookup,
json_key_selector,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra),
extra,
&self.choices,
self.retry_with_lax_check(),
)
.map_or_else(
|| {
union_serialize(
json_key_selector,
|v| v.map_or_else(|| infer_json_key(key, extra), Ok),
extra,
&self.choices,
self.retry_with_lax_check(),
)?
},
Ok,
)
)?
.map_or_else(|| infer_json_key(key, extra), Ok)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -332,37 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer {
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let serde_selector = |comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
};

if let Some(v) = tagged_union_serialize(
match tagged_union_serialize(
None,
&self.lookup,
serde_selector,
|comb_serializer: &CombinedSerializer, new_extra: &Extra| {
comb_serializer.to_python(value, include, exclude, new_extra)
},
extra,
&self.choices,
self.retry_with_lax_check(),
) {
return infer_serialize(v.bind(value.py()), serializer, None, None, extra);
Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra),
Ok(None) => infer_serialize(value, serializer, include, exclude, extra),
Err(err) => Err(serde::ser::Error::custom(err.to_string())),
}

union_serialize(
serde_selector,
|v| {
infer_serialize(
v.as_ref().map_or(value, |v| v.bind(value.py())),
serializer,
None,
None,
extra,
)
},
extra,
&self.choices,
self.retry_with_lax_check(),
)
.map_err(|err| serde::ser::Error::custom(err.to_string()))?
}

fn get_name(&self) -> &str {
Expand Down

0 comments on commit ee9dd3d

Please sign in to comment.