Fix warning for tuple of wrong size in union (#1174)
davidhewitt authored Feb 1, 2024
1 parent 758bc51 commit dcaf63e
Showing 2 changed files with 144 additions and 126 deletions.
244 changes: 118 additions & 126 deletions src/serializers/type_serializers/tuple.rs
@@ -7,8 +7,10 @@ use std::iter;
use serde::ser::SerializeSeq;

use crate::definitions::DefinitionsBuilder;
use crate::serializers::extra::SerCheck;
use crate::serializers::type_serializers::any::AnySerializer;
use crate::tools::SchemaDict;
use crate::PydanticSerializationUnexpectedValue;

use super::{
infer_json_key, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, CombinedSerializer, Extra,
@@ -70,52 +72,14 @@ impl TypeSerializer for TupleSerializer {
let py = value.py();

let n_items = py_tuple.len();
let mut py_tuple_iter = py_tuple.iter();
let mut items = Vec::with_capacity(n_items);

macro_rules! use_serializers {
($serializers_iter:expr) => {
for (index, serializer) in $serializers_iter.enumerate() {
let element = match py_tuple_iter.next() {
Some(value) => value,
None => break,
};
let op_next = self
.filter
.index_filter(index, include, exclude, Some(n_items))?;
if let Some((next_include, next_exclude)) = op_next {
items.push(serializer.to_python(element, next_include, next_exclude, extra)?);
}
}
};
}

if let Some(variadic_item_index) = self.variadic_item_index {
// Need `saturating_sub` to handle items with too few elements without panicking
let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len());
let serializers_iter = self.serializers[..variadic_item_index]
.iter()
.chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items))
.chain(self.serializers[variadic_item_index + 1..].iter());
use_serializers!(serializers_iter);
} else {
use_serializers!(self.serializers.iter());
let mut warned = false;
for (i, element) in py_tuple_iter.enumerate() {
if !warned {
extra
.warnings
.custom_warning("Unexpected extra items present in tuple".to_string());
warned = true;
}
let op_next =
self.filter
.index_filter(i + self.serializers.len(), include, exclude, Some(n_items))?;
if let Some((next_include, next_exclude)) = op_next {
items.push(AnySerializer.to_python(element, next_include, next_exclude, extra)?);
}
}
};
self.for_each_tuple_item_and_serializer(py_tuple, include, exclude, extra, |entry| {
entry
.serializer
.to_python(entry.item, entry.include, entry.exclude, extra)
.map(|item| items.push(item))
})??;

match extra.mode {
SerMode::Json => Ok(PyList::new(py, items).into_py(py)),
@@ -132,35 +96,14 @@
fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
match key.downcast::<PyTuple>() {
Ok(py_tuple) => {
let mut py_tuple_iter = py_tuple.iter();

let mut key_builder = KeyBuilder::new();

let n_items = py_tuple.len();

macro_rules! use_serializers {
($serializers_iter:expr) => {
for serializer in $serializers_iter {
let element = match py_tuple_iter.next() {
Some(value) => value,
None => break,
};
key_builder.push(&serializer.json_key(element, extra)?);
}
};
}

if let Some(variadic_item_index) = self.variadic_item_index {
// Need `saturating_sub` to handle items with too few elements without panicking
let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len());
let serializers_iter = self.serializers[..variadic_item_index]
.iter()
.chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items))
.chain(self.serializers[variadic_item_index + 1..].iter());
use_serializers!(serializers_iter);
} else {
use_serializers!(self.serializers.iter());
};
self.for_each_tuple_item_and_serializer(py_tuple, None, None, extra, |entry| {
entry
.serializer
.json_key(entry.item, extra)
.map(|key| key_builder.push(&key))
})??;

Ok(Cow::Owned(key_builder.finish()))
}
@@ -184,63 +127,18 @@
let py_tuple: &PyTuple = py_tuple.downcast().map_err(py_err_se_err)?;

let n_items = py_tuple.len();
let mut py_tuple_iter = py_tuple.iter();
let mut seq = serializer.serialize_seq(Some(n_items))?;

macro_rules! use_serializers {
($serializers_iter:expr) => {
for (index, serializer) in $serializers_iter.enumerate() {
let element = match py_tuple_iter.next() {
Some(value) => value,
None => break,
};
let op_next = self
.filter
.index_filter(index, include, exclude, Some(n_items))
.map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = op_next {
let item_serialize =
PydanticSerializer::new(element, serializer, next_include, next_exclude, extra);
seq.serialize_element(&item_serialize)?;
}
}
};
}

if let Some(variadic_item_index) = self.variadic_item_index {
// Need `saturating_sub` to handle items with too few elements without panicking
let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len());
let serializers_iter = self.serializers[..variadic_item_index]
.iter()
.chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items))
.chain(self.serializers[variadic_item_index + 1..].iter());
use_serializers!(serializers_iter);
} else {
use_serializers!(self.serializers.iter());
let mut warned = false;
for (i, element) in py_tuple_iter.enumerate() {
if !warned {
extra
.warnings
.custom_warning("Unexpected extra items present in tuple".to_string());
warned = true;
}
let op_next = self
.filter
.index_filter(i + self.serializers.len(), include, exclude, Some(n_items))
.map_err(py_err_se_err)?;
if let Some((next_include, next_exclude)) = op_next {
let item_serialize = PydanticSerializer::new(
element,
&CombinedSerializer::Any(AnySerializer),
next_include,
next_exclude,
extra,
);
seq.serialize_element(&item_serialize)?;
}
}
};
self.for_each_tuple_item_and_serializer(py_tuple, include, exclude, extra, |entry| {
seq.serialize_element(&PydanticSerializer::new(
entry.item,
entry.serializer,
entry.include,
entry.exclude,
extra,
))
})
.map_err(py_err_se_err)??;

seq.end()
}
@@ -254,6 +152,100 @@
fn get_name(&self) -> &str {
&self.name
}

fn retry_with_lax_check(&self) -> bool {
true
}
}

struct TupleSerializerEntry<'a, 'py> {
item: &'py PyAny,
include: Option<&'py PyAny>,
exclude: Option<&'py PyAny>,
serializer: &'a CombinedSerializer,
}

impl TupleSerializer {
/// Try to serialize each item in the tuple with the corresponding serializer.
///
/// If the tuple length doesn't match the number of serializers, an error is returned in strict mode.
///
/// The error type E is the type of the error returned by the closure, which is why there are two
/// levels of `Result`.
fn for_each_tuple_item_and_serializer<E>(
&self,
tuple: &PyTuple,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut f: impl for<'a, 'py> FnMut(TupleSerializerEntry<'a, 'py>) -> Result<(), E>,
) -> PyResult<Result<(), E>> {
let n_items = tuple.len();
let mut py_tuple_iter = tuple.iter();

macro_rules! use_serializers {
($serializers_iter:expr) => {
for (index, serializer) in $serializers_iter.enumerate() {
let element = match py_tuple_iter.next() {
Some(value) => value,
None => break,
};
let op_next = self.filter.index_filter(index, include, exclude, Some(n_items))?;
if let Some((next_include, next_exclude)) = op_next {
if let Err(e) = f(TupleSerializerEntry {
item: element,
include: next_include,
exclude: next_exclude,
serializer,
}) {
return Ok(Err(e));
};
}
}
};
}

if let Some(variadic_item_index) = self.variadic_item_index {
// Need `saturating_sub` to handle items with too few elements without panicking
let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len());
let serializers_iter = self.serializers[..variadic_item_index]
.iter()
.chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items))
.chain(self.serializers[variadic_item_index + 1..].iter());
use_serializers!(serializers_iter);
} else if extra.check == SerCheck::Strict && n_items != self.serializers.len() {
return Err(PydanticSerializationUnexpectedValue::new_err(Some(format!(
"Expected {} items, but got {}",
self.serializers.len(),
n_items
))));
} else {
use_serializers!(self.serializers.iter());
let mut warned = false;
for (i, element) in py_tuple_iter.enumerate() {
if !warned {
extra
.warnings
.custom_warning("Unexpected extra items present in tuple".to_string());
warned = true;
}
let op_next = self
.filter
.index_filter(i + self.serializers.len(), include, exclude, Some(n_items))?;
if let Some((next_include, next_exclude)) = op_next {
if let Err(e) = f(TupleSerializerEntry {
item: element,
include: next_include,
exclude: next_exclude,
serializer: &CombinedSerializer::Any(AnySerializer),
}) {
return Ok(Err(e));
};
}
}
};
Ok(Ok(()))
}
}

pub(crate) struct KeyBuilder {
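Editor's note: the doc comment on `for_each_tuple_item_and_serializer` mentions the two levels of `Result` in its return type. The following minimal, self-contained sketch (not part of this commit; `DriverError` and `for_each_item` are made-up names for illustration) shows the pattern under the same assumptions: the outer `Result` carries the driver's own failures, while the inner `Result` carries whatever error type `E` the caller's closure produces, so one helper can serve callers whose closures fail with `PyErr`, serde errors, or anything else.

```rust
// Sketch only: a driver that walks items and calls a caller-supplied closure.
// Outer Result = the driver's own error type; inner Result = the closure's
// error type E, returned untouched so the caller can unwrap it with `??`.
#[derive(Debug)]
struct DriverError(String);

fn for_each_item<E>(
    items: &[i64],
    mut f: impl FnMut(i64) -> Result<(), E>,
) -> Result<Result<(), E>, DriverError> {
    for &item in items {
        if item < 0 {
            // a failure in the driver itself -> outer Err
            return Err(DriverError(format!("negative item {item}")));
        }
        if let Err(e) = f(item) {
            // a failure in the caller's closure -> inner Err, propagated as-is
            return Ok(Err(e));
        }
    }
    Ok(Ok(()))
}

fn main() {
    let outcome = for_each_item(&[1, 2, 3], |i| {
        if i == 2 {
            Err(format!("closure rejected {i}"))
        } else {
            Ok(())
        }
    });
    // Prints Ok(Err("closure rejected 2")): the driver succeeded, the closure did not.
    println!("{outcome:?}");
}
```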
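For the variadic branch, the per-item serializer sequence is built by chaining the fixed head, a repetition of the variadic serializer, and the fixed tail. A tiny sketch with strings standing in for serializers (again illustration only, not code from the commit):

```rust
use std::iter;

fn main() {
    // Stand-ins for self.serializers; index 1 plays the variadic element.
    let serializers = ["head", "variadic", "tail"];
    let variadic_item_index = 1;
    let n_items = 5_usize; // length of the incoming tuple

    // The variadic serializer itself occupies one slot in `serializers`, hence
    // the `+ 1`; `saturating_sub` keeps this at 0 for tuples that are too short.
    let n_variadic_items = (n_items + 1).saturating_sub(serializers.len());

    let per_item: Vec<&str> = serializers[..variadic_item_index]
        .iter()
        .chain(iter::repeat(&serializers[variadic_item_index]).take(n_variadic_items))
        .chain(serializers[variadic_item_index + 1..].iter())
        .copied()
        .collect();

    // Prints ["head", "variadic", "variadic", "variadic", "tail"]
    println!("{per_item:?}");
}
```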
26 changes: 26 additions & 0 deletions tests/serializers/test_list_tuple.py
@@ -411,3 +411,29 @@ def test_tuple_pos_dict_key():
assert s.to_python({(1, 'a', 2): 1}, mode='json') == {'1,a,2': 1}
assert s.to_json({(1, 'a'): 1}) == b'{"1,a":1}'
assert s.to_json({(1, 'a', 2): 1}) == b'{"1,a,2":1}'


def test_tuple_wrong_size_union():
# See https://github.com/pydantic/pydantic/issues/8677

f = core_schema.float_schema()
s = SchemaSerializer(
core_schema.union_schema([core_schema.tuple_schema([f, f]), core_schema.tuple_schema([f, f, f])])
)
assert s.to_python((1.0, 2.0)) == (1.0, 2.0)
assert s.to_python((1.0, 2.0, 3.0)) == (1.0, 2.0, 3.0)

with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'):
s.to_python((1.0, 2.0, 3.0, 4.0))

assert s.to_python((1.0, 2.0), mode='json') == [1.0, 2.0]
assert s.to_python((1.0, 2.0, 3.0), mode='json') == [1.0, 2.0, 3.0]

with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'):
s.to_python((1.0, 2.0, 3.0, 4.0), mode='json')

assert s.to_json((1.0, 2.0)) == b'[1.0,2.0]'
assert s.to_json((1.0, 2.0, 3.0)) == b'[1.0,2.0,3.0]'

with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'):
s.to_json((1.0, 2.0, 3.0, 4.0))
