diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 553f0b33c..db46f06fe 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -1,7 +1,7 @@ use std::fmt; use pyo3::exceptions::PyValueError; -use pyo3::types::{PyDict, PyType}; +use pyo3::types::{PyDict, PyList, PyType}; use pyo3::{intern, prelude::*}; use crate::errors::{ErrorTypeDefaults, InputValue, ValError, ValResult}; @@ -42,6 +42,8 @@ impl TryFrom<&str> for InputType { } } +pub type ValMatch = ValResult>; + /// all types have three methods: `validate_*`, `strict_*`, `lax_*` /// the convention is to either implement: /// * `strict_*` & `lax_*` if they have different behavior @@ -87,13 +89,13 @@ pub trait Input<'py>: fmt::Debug + ToPyObject { fn validate_dataclass_args<'a>(&'a self, dataclass_name: &str) -> ValResult>; - fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValResult>>; + fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch>; - fn validate_bytes<'a>(&'a self, strict: bool) -> ValResult>>; + fn validate_bytes<'a>(&'a self, strict: bool) -> ValMatch>; - fn validate_bool(&self, strict: bool) -> ValResult>; + fn validate_bool(&self, strict: bool) -> ValMatch; - fn validate_int(&self, strict: bool) -> ValResult>>; + fn validate_int(&self, strict: bool) -> ValMatch>; fn exact_int(&self) -> ValResult> { self.validate_int(true).and_then(|val_match| { @@ -113,7 +115,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject { }) } - fn validate_float(&self, strict: bool) -> ValResult>>; + fn validate_float(&self, strict: bool) -> ValMatch>; fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValResult> { if strict { @@ -145,18 +147,11 @@ pub trait Input<'py>: fmt::Debug + ToPyObject { self.validate_dict(strict) } - fn validate_list<'a>(&'a self, strict: bool) -> ValResult> { - if strict { - self.strict_list() - } else { - self.lax_list() - } - } - fn strict_list<'a>(&'a self) -> ValResult>; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_list<'a>(&'a self) -> ValResult> { - self.strict_list() - } + type List<'a>: Iterable<'py> + AsPyList<'py> + where + Self: 'a; + + fn validate_list(&self, strict: bool) -> ValMatch>; fn validate_tuple<'a>(&'a self, strict: bool) -> ValResult> { if strict { @@ -201,25 +196,25 @@ pub trait Input<'py>: fmt::Debug + ToPyObject { fn validate_iter(&self) -> ValResult; - fn validate_date(&self, strict: bool) -> ValResult>>; + fn validate_date(&self, strict: bool) -> ValMatch>; fn validate_time( &self, strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult>>; + ) -> ValMatch>; fn validate_datetime( &self, strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult>>; + ) -> ValMatch>; fn validate_timedelta( &self, strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult>>; + ) -> ValMatch>; } /// The problem to solve here is that iterating collections often returns owned @@ -238,3 +233,42 @@ impl<'py, T: Input<'py> + ?Sized> BorrowInput<'py> for &'_ T { self } } + +pub enum Never {} + +// Pairs with Iterable below +pub trait ConsumeIterator { + type Output; + fn consume_iterator(self, iterator: impl Iterator) -> Self::Output; +} + +// This slightly awkward trait is used to define types which can be iterable. This formulation +// arises because the Python enums have several different underlying iterator types, and we want to +// be able to dispatch over each of them without overhead. +pub trait Iterable<'py> { + type Input: BorrowInput<'py>; + fn len(&self) -> Option; + fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult; +} + +// Necessary for inputs which don't support certain types, e.g. String -> list +impl<'py> Iterable<'py> for Never { + type Input = Bound<'py, PyAny>; // Doesn't really matter what this is + fn len(&self) -> Option { + unreachable!() + } + fn iterate(self, _consumer: impl ConsumeIterator, Output = R>) -> ValResult { + unreachable!() + } +} + +// Optimization pathway for inputs which are already python lists +pub trait AsPyList<'py>: Iterable<'py> { + fn as_py_list(&self) -> Option<&Bound<'py, PyList>>; +} + +impl<'py> AsPyList<'py> for Never { + fn as_py_list(&self) -> Option<&Bound<'py, PyList>> { + unreachable!() + } +} diff --git a/src/input/input_json.rs b/src/input/input_json.rs index fc24b9375..1bd26f259 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -2,7 +2,8 @@ use std::borrow::Cow; use jiter::{JsonArray, JsonValue}; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyString}; +use pyo3::types::{PyDict, PyList, PyString}; +use smallvec::SmallVec; use speedate::MicrosecondsPrecisionOverflowBehavior; use strum::EnumMessage; @@ -13,6 +14,7 @@ use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; +use super::input_abstract::{AsPyList, ConsumeIterator, Iterable, Never, ValMatch}; use super::return_enums::ValidationMatch; use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_float, str_as_int}; use super::{ @@ -37,7 +39,7 @@ impl From> for LocItem { } } -impl<'py> Input<'py> for JsonValue<'_> { +impl<'py, 'data> Input<'py> for JsonValue<'data> { fn as_error_value(&self) -> InputValue { // cloning JsonValue is cheap due to use of Arc InputValue::Json(self.clone().into_static()) @@ -172,16 +174,14 @@ impl<'py> Input<'py> for JsonValue<'_> { self.validate_dict(false) } - fn validate_list<'a>(&'a self, _strict: bool) -> ValResult> { + type List<'a> = &'a JsonArray<'data> where Self: 'a; + + fn validate_list(&self, _strict: bool) -> ValMatch<&JsonArray<'data>> { match self { - JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(ValidationMatch::strict(a)), _ => Err(ValError::new(ErrorTypeDefaults::ListType, self)), } } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn strict_list<'a>(&'a self) -> ValResult> { - self.validate_list(false) - } fn validate_tuple<'a>(&'a self, _strict: bool) -> ValResult> { // just as in set's case, List has to be allowed @@ -375,8 +375,9 @@ impl<'py> Input<'py> for str { Err(ValError::new(ErrorTypeDefaults::DictType, self)) } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn strict_list<'a>(&'a self) -> ValResult> { + type List<'a> = Never; + + fn validate_list(&self, _strict: bool) -> ValMatch { Err(ValError::new(ErrorTypeDefaults::ListType, self)) } @@ -449,3 +450,20 @@ impl BorrowInput<'_> for String { fn string_to_vec(s: &str) -> JsonArray { JsonArray::new(s.chars().map(|c| JsonValue::Str(c.to_string().into())).collect()) } + +impl<'a, 'data> Iterable<'_> for &'a JsonArray<'data> { + type Input = &'a JsonValue<'data>; + + fn len(&self) -> Option { + Some(SmallVec::len(self)) + } + fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult { + Ok(consumer.consume_iterator(self.iter().map(Ok))) + } +} + +impl<'py> AsPyList<'py> for &'_ JsonArray<'_> { + fn as_py_list(&self) -> Option<&Bound<'py, PyList>> { + None + } +} diff --git a/src/input/input_python.rs b/src/input/input_python.rs index edc9aec04..1c192c290 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -23,6 +23,7 @@ use super::datetime::{ float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; +use super::input_abstract::ValMatch; use super::return_enums::ValidationMatch; use super::shared::{ decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int, @@ -461,24 +462,25 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { } } - fn strict_list<'a>(&'a self) -> ValResult> { - match self.lax_list()? { - GenericIterable::List(iter) => Ok(GenericIterable::List(iter)), - _ => Err(ValError::new(ErrorTypeDefaults::ListType, self)), - } - } + type List<'a> = GenericIterable<'a, 'py> where Self: 'a; - fn lax_list<'a>(&'a self) -> ValResult> { - match self - .extract_generic_iterable() - .map_err(|_| ValError::new(ErrorTypeDefaults::ListType, self))? - { - GenericIterable::PyString(_) - | GenericIterable::Bytes(_) - | GenericIterable::Dict(_) - | GenericIterable::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ListType, self)), - other => Ok(other), + fn validate_list<'a>(&'a self, strict: bool) -> ValMatch> { + if let Ok(list) = self.downcast::() { + return Ok(ValidationMatch::exact(GenericIterable::List(list))); + } else if !strict { + match self.extract_generic_iterable() { + Ok( + GenericIterable::PyString(_) + | GenericIterable::Bytes(_) + | GenericIterable::Dict(_) + | GenericIterable::Mapping(_), + ) + | Err(_) => {} + Ok(other) => return Ok(ValidationMatch::lax(other)), + } } + + Err(ValError::new(ErrorTypeDefaults::ListType, self)) } fn strict_tuple<'a>(&'a self) -> ValResult> { diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 486dff5ff..76c1a9f39 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -11,6 +11,7 @@ use crate::validators::decimal::create_decimal; use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime, }; +use super::input_abstract::{Never, ValMatch}; use super::shared::{str_as_bool, str_as_float, str_as_int}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, @@ -138,7 +139,9 @@ impl<'py> Input<'py> for StringMapping<'py> { } } - fn strict_list<'a>(&'a self) -> ValResult> { + type List<'a> = Never where Self: 'a; + + fn validate_list(&self, _strict: bool) -> ValMatch { Err(ValError::new(ErrorTypeDefaults::ListType, self)) } diff --git a/src/input/mod.rs b/src/input/mod.rs index d7ca0a5bf..1d5047005 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -15,12 +15,13 @@ pub(crate) use datetime::{ duration_as_pytimedelta, pydate_as_date, pydatetime_as_datetime, pytime_as_time, EitherDate, EitherDateTime, EitherTime, EitherTimedelta, }; -pub(crate) use input_abstract::{BorrowInput, Input, InputType}; +pub(crate) use input_abstract::{AsPyList, BorrowInput, ConsumeIterator, Input, InputType, Iterable}; pub(crate) use input_string::StringMapping; pub(crate) use return_enums::{ - py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherFloat, EitherInt, EitherString, - GenericArguments, GenericIterable, GenericIterator, GenericMapping, Int, JsonArgs, JsonObjectGenericIterator, - MappingGenericIterator, PyArgs, StringMappingGenericIterator, ValidationMatch, + no_validator_iter_to_vec, py_string_str, validate_iter_to_vec, AttributesGenericIterator, DictGenericIterator, + EitherBytes, EitherFloat, EitherInt, EitherString, GenericArguments, GenericIterable, GenericIterator, + GenericMapping, Int, JsonArgs, JsonObjectGenericIterator, MappingGenericIterator, MaxLengthCheck, PyArgs, + StringMappingGenericIterator, ValidationMatch, }; // Defined here as it's not exported by pyo3 diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index fc9e4d279..235d075a6 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -26,6 +26,7 @@ use crate::errors::{ use crate::tools::{extract_i64, py_err}; use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator}; +use super::input_abstract::{AsPyList, ConsumeIterator, Iterable}; use super::input_string::StringMapping; use super::{py_error_on_minusone, BorrowInput, Input}; @@ -121,9 +122,31 @@ impl<'py> GenericIterable<'_, 'py> { GenericIterable::JsonString(s) => Ok(Box::new(PyString::new_bound(py, s).iter()?)), } } + + pub fn as_sequence_iterator_py(&self) -> PyResult>> + '_>> { + match self { + GenericIterable::List(iter) => Ok(Box::new(iter.iter().map(Ok))), + GenericIterable::Tuple(iter) => Ok(Box::new(iter.iter().map(Ok))), + GenericIterable::Set(iter) => Ok(Box::new(iter.iter().map(Ok))), + GenericIterable::FrozenSet(iter) => Ok(Box::new(iter.iter().map(Ok))), + // Note that this iterates over only the keys, just like doing iter({}) in Python + GenericIterable::Dict(iter) => Ok(Box::new(iter.iter().map(|(k, _)| Ok(k)))), + GenericIterable::DictKeys(iter) => Ok(Box::new(iter.iter()?)), + GenericIterable::DictValues(iter) => Ok(Box::new(iter.iter()?)), + GenericIterable::DictItems(iter) => Ok(Box::new(iter.iter()?)), + // Note that this iterates over only the keys, just like doing iter({}) in Python + GenericIterable::Mapping(iter) => Ok(Box::new(iter.keys()?.iter()?)), + GenericIterable::PyString(iter) => Ok(Box::new(iter.iter()?)), + GenericIterable::Bytes(iter) => Ok(Box::new(iter.iter()?)), + GenericIterable::PyByteArray(iter) => Ok(Box::new(iter.iter()?)), + GenericIterable::Sequence(iter) => Ok(Box::new(iter.iter()?)), + GenericIterable::Iterator(iter) => Ok(Box::new(iter.iter()?)), + _ => unreachable!(), + } + } } -struct MaxLengthCheck<'a, INPUT: ?Sized> { +pub struct MaxLengthCheck<'a, INPUT: ?Sized> { current_length: usize, max_length: Option, field_type: &'a str, @@ -132,7 +155,12 @@ struct MaxLengthCheck<'a, INPUT: ?Sized> { } impl<'a, INPUT: ?Sized> MaxLengthCheck<'a, INPUT> { - fn new(max_length: Option, field_type: &'a str, input: &'a INPUT, actual_length: Option) -> Self { + pub(crate) fn new( + max_length: Option, + field_type: &'a str, + input: &'a INPUT, + actual_length: Option, + ) -> Self { Self { current_length: 0, max_length, @@ -177,7 +205,7 @@ macro_rules! any_next_error { } #[allow(clippy::too_many_arguments)] -fn validate_iter_to_vec<'py>( +pub(crate) fn validate_iter_to_vec<'py>( py: Python<'py>, iter: impl Iterator>>, capacity: usize, @@ -289,7 +317,7 @@ fn validate_iter_to_set<'py>( } } -fn no_validator_iter_to_vec<'py>( +pub(crate) fn no_validator_iter_to_vec<'py>( py: Python<'py>, input: &(impl Input<'py> + ?Sized), iter: impl Iterator>>, @@ -428,6 +456,43 @@ impl<'py> GenericIterable<'_, 'py> { } } +impl<'py> Iterable<'py> for GenericIterable<'_, 'py> { + type Input = Bound<'py, PyAny>; + fn len(&self) -> Option { + self.generic_len() + } + fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult { + match self { + GenericIterable::List(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), + GenericIterable::Tuple(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), + GenericIterable::Set(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), + GenericIterable::FrozenSet(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), + // Note that this iterates over only the keys, just like doing iter({}) in Python + GenericIterable::Dict(iter) => Ok(consumer.consume_iterator(iter.iter().map(|(k, _)| Ok(k)))), + GenericIterable::DictKeys(iter) => Ok(consumer.consume_iterator(iter.iter()?)), + GenericIterable::DictValues(iter) => Ok(consumer.consume_iterator(iter.iter()?)), + GenericIterable::DictItems(iter) => Ok(consumer.consume_iterator(iter.iter()?)), + // Note that this iterates over only the keys, just like doing iter({}) in Python + GenericIterable::Mapping(iter) => Ok(consumer.consume_iterator(iter.keys()?.iter()?)), + GenericIterable::PyString(iter) => Ok(consumer.consume_iterator(iter.iter()?)), + GenericIterable::Bytes(iter) => Ok(consumer.consume_iterator(iter.iter()?)), + GenericIterable::PyByteArray(iter) => Ok(consumer.consume_iterator(iter.iter()?)), + GenericIterable::Sequence(iter) => Ok(consumer.consume_iterator(iter.iter()?)), + GenericIterable::Iterator(iter) => Ok(consumer.consume_iterator(iter.iter()?)), + _ => unreachable!(), + } + } +} + +impl<'py> AsPyList<'py> for GenericIterable<'_, 'py> { + fn as_py_list(&self) -> Option<&Bound<'py, PyList>> { + match self { + GenericIterable::List(iter) => Some(iter), + _ => None, + } + } +} + #[cfg_attr(debug_assertions, derive(Debug))] pub enum GenericMapping<'a, 'py> { PyDict(&'a Bound<'py, PyDict>), diff --git a/src/validators/list.rs b/src/validators/list.rs index 6409d370b..4d56e1160 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -4,9 +4,11 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::errors::ValResult; -use crate::input::{GenericIterable, Input}; +use crate::input::{ + no_validator_iter_to_vec, validate_iter_to_vec, AsPyList, BorrowInput, ConsumeIterator, Input, Iterable, + MaxLengthCheck, +}; use crate::tools::SchemaDict; -use crate::validators::Exactness; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; @@ -122,24 +124,34 @@ impl Validator for ListValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let seq = input.validate_list(state.strict_or(self.strict))?; - let exactness = match &seq { - GenericIterable::List(_) | GenericIterable::JsonArray(_) => Exactness::Exact, - GenericIterable::Tuple(_) => Exactness::Strict, - _ => Exactness::Lax, - }; - state.floor_exactness(exactness); + let seq = input.validate_list(state.strict_or(self.strict))?.unpack(state); + let actual_length = seq.len(); let output = match self.item_validator { - Some(ref v) => seq.validate_to_vec(py, input, self.max_length, "List", v, state)?, - None => match seq { - GenericIterable::List(list) => { - length_check!(input, "List", self.min_length, self.max_length, list); - let list_copy = list.get_slice(0, usize::MAX); + Some(ref v) => seq.iterate(ValidateToVec { + py, + input, + actual_length, + max_length: self.max_length, + field_type: "List", + item_validator: v, + state, + })??, + None => { + if let Some(py_list) = seq.as_py_list() { + length_check!(input, "List", self.min_length, self.max_length, py_list); + let list_copy = py_list.get_slice(0, usize::MAX); return Ok(list_copy.into_py(py)); } - _ => seq.to_vec(py, input, "List", self.max_length)?, - }, + + seq.iterate(ToVec { + py, + input, + actual_length, + max_length: self.max_length, + field_type: "List", + })?? + } }; min_length_check!(input, "List", self.min_length, output); Ok(output.into_py(py)) @@ -164,3 +176,54 @@ impl Validator for ListValidator { } } } + +struct ValidateToVec<'a, 's, 'py, I: Input<'py> + ?Sized> { + py: Python<'py>, + input: &'a I, + actual_length: Option, + max_length: Option, + field_type: &'static str, + item_validator: &'a CombinedValidator, + state: &'a mut ValidationState<'s, 'py>, +} + +// pretty arbitrary default capacity when creating vecs from iteration +const DEFAULT_CAPACITY: usize = 10; + +impl<'py, T, I: Input<'py> + ?Sized> ConsumeIterator> for ValidateToVec<'_, '_, 'py, I> +where + T: BorrowInput<'py>, +{ + type Output = ValResult>; + fn consume_iterator(self, iterator: impl Iterator>) -> ValResult> { + let capacity = self.actual_length.unwrap_or(DEFAULT_CAPACITY); + let max_length_check = MaxLengthCheck::new(self.max_length, self.field_type, self.input, self.actual_length); + validate_iter_to_vec( + self.py, + iterator, + capacity, + max_length_check, + self.item_validator, + self.state, + ) + } +} + +struct ToVec<'a, 'py, I: Input<'py> + ?Sized> { + py: Python<'py>, + input: &'a I, + actual_length: Option, + max_length: Option, + field_type: &'static str, +} + +impl<'py, T, I: Input<'py> + ?Sized> ConsumeIterator> for ToVec<'_, 'py, I> +where + T: BorrowInput<'py>, +{ + type Output = ValResult>; + fn consume_iterator(self, iterator: impl Iterator>) -> ValResult> { + let max_length_check = MaxLengthCheck::new(self.max_length, self.field_type, self.input, self.actual_length); + no_validator_iter_to_vec(self.py, self.input, iterator, max_length_check) + } +}