From aeaac5631c2bda7ae487910888d4dc82532fc6b8 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Mon, 25 Sep 2023 23:03:48 +0100 Subject: [PATCH] skip recursion checks for non-recursive definitions --- src/definitions.rs | 96 +++++++++++++----- src/serializers/extra.rs | 4 +- src/serializers/mod.rs | 4 +- .../type_serializers/definitions.rs | 24 ++--- src/validators/definitions.rs | 98 ++++++++----------- src/validators/generator.rs | 3 +- src/validators/mod.rs | 8 +- 7 files changed, 137 insertions(+), 100 deletions(-) diff --git a/src/definitions.rs b/src/definitions.rs index 0d01fd2ae..fa0885468 100644 --- a/src/definitions.rs +++ b/src/definitions.rs @@ -5,11 +5,11 @@ /// and then get a definition from a reference using an integer id (just for performance of not using a HashMap) use std::collections::hash_map::Entry; -use pyo3::prelude::*; +use pyo3::{prelude::*, PyTraverseError, PyVisit}; use ahash::AHashMap; -use crate::build_tools::py_schema_err; +use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse}; // An integer id for the reference pub type ReferenceId = usize; @@ -24,23 +24,44 @@ pub type ReferenceId = usize; /// They get indexed by a ReferenceId, which are integer identifiers /// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer} /// gets build. -pub type Definitions = [T]; +pub type Definitions = [Definition]; #[derive(Clone, Debug)] -struct Definition { - pub id: ReferenceId, - pub value: Option, +struct DefinitionSlot { + id: ReferenceId, + value: Option, + recursive: bool, +} + +#[derive(Clone)] +pub struct Definition { + pub value: T, + pub recursive: bool, +} + +impl PyGcTraverse for Definition { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + self.value.py_gc_traverse(visit) + } +} + +impl std::fmt::Debug for Definition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.value.fmt(f) + } } #[derive(Clone, Debug)] pub struct DefinitionsBuilder { - definitions: AHashMap>, + definitions: AHashMap>, + in_flight_definitions: AHashMap, } impl DefinitionsBuilder { pub fn new() -> Self { Self { definitions: AHashMap::new(), + in_flight_definitions: AHashMap::new(), } } @@ -51,34 +72,53 @@ impl DefinitionsBuilder { // We either need a String copy or two hashmap lookups // Neither is better than the other // We opted for the easier outward facing API - match self.definitions.entry(reference.to_string()) { + let id = match self.definitions.entry(reference.to_string()) { Entry::Occupied(entry) => entry.get().id, Entry::Vacant(entry) => { - entry.insert(Definition { + entry.insert(DefinitionSlot { id: next_id, value: None, + recursive: false, }); next_id } + }; + // If this definition is currently being built, then it's recursive + if let Some(recursive) = self.in_flight_definitions.get_mut(&id) { + *recursive = true; } + id } - /// Add a definition, returning the ReferenceId that maps to it - pub fn add_definition(&mut self, reference: String, value: T) -> PyResult { + /// Add a definition + pub fn build_definition( + &mut self, + reference: String, + constructor: impl FnOnce(&mut Self) -> PyResult, + ) -> PyResult<()> { let next_id = self.definitions.len(); - match self.definitions.entry(reference.clone()) { - Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) { - Some(_) => py_schema_err!("Duplicate ref: `{}`", reference), - None => Ok(entry.get().id), - }, - Entry::Vacant(entry) => { - entry.insert(Definition { - id: next_id, - value: Some(value), - }); - Ok(next_id) + let id = match self.definitions.entry(reference.clone()) { + Entry::Occupied(entry) => { + let entry = entry.into_mut(); + if entry.value.is_some() { + return py_schema_err!("Duplicate ref: `{}`", reference); + } + entry } + Entry::Vacant(entry) => entry.insert(DefinitionSlot { + id: next_id, + value: None, + recursive: false, + }), } + .id; + self.in_flight_definitions.insert(id, false); + let value = constructor(self)?; + // can unwrap because the entry was just built above + let slot = self.definitions.get_mut(&reference).unwrap(); + slot.value = Some(value); + slot.recursive = self.in_flight_definitions.remove(&id).unwrap(); + Ok(()) } /// Retrieve an item definition using a ReferenceId @@ -99,13 +139,19 @@ impl DefinitionsBuilder { } /// Consume this Definitions into a vector of items, indexed by each items ReferenceId - pub fn finish(self) -> PyResult> { + pub fn finish(self) -> PyResult>> { // We need to create a vec of defs according to the order in their ids - let mut defs: Vec<(usize, T)> = Vec::new(); + let mut defs: Vec<(usize, Definition)> = Vec::new(); for (reference, def) in self.definitions { match def.value { None => return py_schema_err!("Definitions error: definition {} was never filled", reference), - Some(v) => defs.push((def.id, v)), + Some(value) => defs.push(( + def.id, + Definition { + value, + recursive: def.recursive, + }, + )), } } defs.sort_by_key(|(id, _)| *id); diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 9972a82c4..32e6f8d87 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -11,7 +11,7 @@ use super::config::SerializationConfig; use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER}; use super::ob_type::ObTypeLookup; use super::shared::CombinedSerializer; -use crate::definitions::Definitions; +use crate::definitions::{Definition, Definitions}; use crate::recursion_guard::RecursionGuard; /// this is ugly, would be much better if extra could be stored in `SerializationState` @@ -156,7 +156,7 @@ impl SerCheck { #[cfg_attr(debug_assertions, derive(Debug))] pub(crate) struct ExtraOwned { mode: SerMode, - definitions: Vec, + definitions: Vec>, warnings: CollectWarnings, by_alias: bool, exclude_unset: bool, diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 6dbc076fe..420c58be9 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -5,7 +5,7 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict}; use pyo3::{PyTraverseError, PyVisit}; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definition, DefinitionsBuilder}; use crate::py_gc::PyGcTraverse; use config::SerializationConfig; @@ -30,7 +30,7 @@ mod type_serializers; #[derive(Debug)] pub struct SchemaSerializer { serializer: CombinedSerializer, - definitions: Vec, + definitions: Vec>, expected_json_size: AtomicUsize, config: SerializationConfig, } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 4614bbc56..0a93a3e60 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -27,11 +27,11 @@ impl BuildSerializer for DefinitionsSerializerBuilder { let schema_definitions: &PyList = schema.get_as_req(intern!(py, "definitions"))?; for schema_definition in schema_definitions { - let reference = schema_definition - .extract::<&PyDict>()? - .get_as_req::(intern!(py, "ref"))?; - let serializer = CombinedSerializer::build(schema_definition.downcast()?, config, definitions)?; - definitions.add_definition(reference, serializer)?; + let schema = schema_definition.downcast::()?; + let reference = schema.get_as_req::(intern!(py, "ref"))?; + definitions.build_definition(reference, |definitions| { + CombinedSerializer::build(schema, config, definitions) + })?; } let inner_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?; @@ -69,8 +69,8 @@ impl TypeSerializer for DefinitionRefSerializer { extra: &Extra, ) -> PyResult { let value_id = extra.rec_guard.add(value, self.serializer_id)?; - let comb_serializer = extra.definitions.get(self.serializer_id).unwrap(); - let r = comb_serializer.to_python(value, include, exclude, extra); + let definition = extra.definitions.get(self.serializer_id).unwrap(); + let r = definition.value.to_python(value, include, exclude, extra); extra.rec_guard.pop(value_id, self.serializer_id); r } @@ -88,8 +88,10 @@ impl TypeSerializer for DefinitionRefSerializer { extra: &Extra, ) -> Result { let value_id = extra.rec_guard.add(value, self.serializer_id).map_err(py_err_se_err)?; - let comb_serializer = extra.definitions.get(self.serializer_id).unwrap(); - let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); + let definition = extra.definitions.get(self.serializer_id).unwrap(); + let r = definition + .value + .serde_serialize(value, serializer, include, exclude, extra); extra.rec_guard.pop(value_id, self.serializer_id); r } @@ -99,7 +101,7 @@ impl TypeSerializer for DefinitionRefSerializer { } fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - let comb_serializer = definitions.get(self.serializer_id).unwrap(); - comb_serializer.retry_with_lax_check(definitions) + let definition = definitions.get(self.serializer_id).unwrap(); + definition.value.retry_with_lax_check(definitions) } } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 3a35fce4c..6a26cb3bc 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -28,8 +28,9 @@ impl BuildValidator for DefinitionsValidatorBuilder { let reference = schema_definition .extract::<&PyDict>()? .get_as_req::(intern!(py, "ref"))?; - let validator = build_validator(schema_definition, config, definitions)?; - definitions.add_definition(reference, validator)?; + definitions.build_definition(reference, |definitions| { + build_validator(schema_definition, config, definitions) + })?; } let inner_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?; @@ -82,22 +83,25 @@ impl Validator for DefinitionRefValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - if let Some(id) = input.identity() { - if state.recursion_guard.contains_or_insert(id, self.validator_id) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) - } else { - if state.recursion_guard.incr_depth() { - return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); - } - let output = validate(self.validator_id, py, input, state); - state.recursion_guard.remove(id, self.validator_id); - state.recursion_guard.decr_depth(); - output + let definition = state.definitions.get(self.validator_id).unwrap(); + if definition.recursive { + if let Some(id) = input.identity() { + return if state.recursion_guard.contains_or_insert(id, self.validator_id) { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) + } else { + if state.recursion_guard.incr_depth() { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); + } + let output = definition.value.validate(py, input, state); + state.recursion_guard.remove(id, self.validator_id); + state.recursion_guard.decr_depth(); + output + }; } - } else { - validate(self.validator_id, py, input, state) - } + }; + + definition.value.validate(py, input, state) } fn validate_assignment<'data>( @@ -108,22 +112,29 @@ impl Validator for DefinitionRefValidator { field_value: &'data PyAny, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - if let Some(id) = obj.identity() { - if state.recursion_guard.contains_or_insert(id, self.validator_id) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) - } else { - if state.recursion_guard.incr_depth() { - return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); - } - let output = validate_assignment(self.validator_id, py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.validator_id); - state.recursion_guard.decr_depth(); - output + let definition = state.definitions.get(self.validator_id).unwrap(); + if definition.recursive { + if let Some(id) = obj.identity() { + return if state.recursion_guard.contains_or_insert(id, self.validator_id) { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) + } else { + if state.recursion_guard.incr_depth() { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); + } + let output = definition + .value + .validate_assignment(py, obj, field_name, field_value, state); + state.recursion_guard.remove(id, self.validator_id); + state.recursion_guard.decr_depth(); + output + }; } - } else { - validate_assignment(self.validator_id, py, obj, field_name, field_value, state) - } + }; + + definition + .value + .validate_assignment(py, obj, field_name, field_value, state) } fn different_strict_behavior( @@ -151,26 +162,3 @@ impl Validator for DefinitionRefValidator { Ok(()) } } - -fn validate<'data>( - validator_id: usize, - py: Python<'data>, - input: &'data impl Input<'data>, - state: &mut ValidationState, -) -> ValResult<'data, PyObject> { - let validator = state.definitions.get(validator_id).unwrap(); - validator.validate(py, input, state) -} - -#[allow(clippy::too_many_arguments)] -fn validate_assignment<'data>( - validator_id: usize, - py: Python<'data>, - obj: &'data PyAny, - field_name: &'data str, - field_value: &'data PyAny, - state: &mut ValidationState, -) -> ValResult<'data, PyObject> { - let validator = state.definitions.get(validator_id).unwrap(); - validator.validate_assignment(py, obj, field_name, field_value, state) -} diff --git a/src/validators/generator.rs b/src/validators/generator.rs index bf6d009e1..f69d1ec0b 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -3,6 +3,7 @@ use std::fmt; use pyo3::prelude::*; use pyo3::types::PyDict; +use crate::definitions::Definition; use crate::errors::{ErrorType, LocItem, ValError, ValResult}; use crate::input::{GenericIterator, Input}; use crate::recursion_guard::RecursionGuard; @@ -223,7 +224,7 @@ impl ValidatorIterator { pub struct InternalValidator { name: String, validator: CombinedValidator, - definitions: Vec, + definitions: Vec>, // TODO, do we need data? data: Option>, strict: Option, diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 4ee677663..aa584a8e3 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -9,7 +9,7 @@ use pyo3::types::{PyAny, PyDict, PyTuple, PyType}; use pyo3::{intern, PyTraverseError, PyVisit}; use crate::build_tools::{py_schema_err, py_schema_error_type, SchemaError}; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definition, DefinitionsBuilder}; use crate::errors::{LocItem, ValError, ValResult, ValidationError}; use crate::input::{Input, InputType, StringMapping}; use crate::py_gc::PyGcTraverse; @@ -101,7 +101,7 @@ impl PySome { #[derive(Debug, Clone)] pub struct SchemaValidator { validator: CombinedValidator, - definitions: Vec, + definitions: Vec>, schema: PyObject, #[pyo3(get)] title: PyObject, @@ -119,7 +119,7 @@ impl SchemaValidator { validator.complete(&definitions_builder)?; let mut definitions = definitions_builder.clone().finish()?; for val in &mut definitions { - val.complete(&definitions_builder)?; + val.value.complete(&definitions_builder)?; } let config_title = match config { Some(c) => c.get_item("title"), @@ -395,7 +395,7 @@ impl<'py> SelfValidator<'py> { validator.complete(&definitions_builder)?; let mut definitions = definitions_builder.clone().finish()?; for val in &mut definitions { - val.complete(&definitions_builder)?; + val.value.complete(&definitions_builder)?; } Ok(SchemaValidator { validator,