Skip to content

Commit

Permalink
skip recursion checks for non-recursive definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 25, 2023
1 parent e610984 commit aeaac56
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 100 deletions.
96 changes: 71 additions & 25 deletions src/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
/// and then get a definition from a reference using an integer id (just for performance of not using a HashMap)
use std::collections::hash_map::Entry;

use pyo3::prelude::*;
use pyo3::{prelude::*, PyTraverseError, PyVisit};

use ahash::AHashMap;

use crate::build_tools::py_schema_err;
use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};

// An integer id for the reference
pub type ReferenceId = usize;
Expand All @@ -24,23 +24,44 @@ pub type ReferenceId = usize;
/// They get indexed by a ReferenceId, which are integer identifiers
/// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer}
/// gets build.
pub type Definitions<T> = [T];
pub type Definitions<T> = [Definition<T>];

#[derive(Clone, Debug)]
struct Definition<T> {
pub id: ReferenceId,
pub value: Option<T>,
struct DefinitionSlot<T> {
id: ReferenceId,
value: Option<T>,
recursive: bool,
}

#[derive(Clone)]
pub struct Definition<T> {
pub value: T,
pub recursive: bool,
}

impl<T: PyGcTraverse> PyGcTraverse for Definition<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
self.value.py_gc_traverse(visit)
}
}

impl<T: std::fmt::Debug> std::fmt::Debug for Definition<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.value.fmt(f)
}
}

#[derive(Clone, Debug)]
pub struct DefinitionsBuilder<T> {
definitions: AHashMap<String, Definition<T>>,
definitions: AHashMap<String, DefinitionSlot<T>>,
in_flight_definitions: AHashMap<ReferenceId, bool>,
}

impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
pub fn new() -> Self {
Self {
definitions: AHashMap::new(),
in_flight_definitions: AHashMap::new(),
}
}

Expand All @@ -51,34 +72,53 @@ impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
// We either need a String copy or two hashmap lookups
// Neither is better than the other
// We opted for the easier outward facing API
match self.definitions.entry(reference.to_string()) {
let id = match self.definitions.entry(reference.to_string()) {
Entry::Occupied(entry) => entry.get().id,
Entry::Vacant(entry) => {
entry.insert(Definition {
entry.insert(DefinitionSlot {
id: next_id,
value: None,
recursive: false,
});
next_id
}
};
// If this definition is currently being built, then it's recursive
if let Some(recursive) = self.in_flight_definitions.get_mut(&id) {
*recursive = true;
}
id
}

/// Add a definition, returning the ReferenceId that maps to it
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<ReferenceId> {
/// Add a definition
pub fn build_definition(
&mut self,
reference: String,
constructor: impl FnOnce(&mut Self) -> PyResult<T>,
) -> PyResult<()> {
let next_id = self.definitions.len();
match self.definitions.entry(reference.clone()) {
Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) {
Some(_) => py_schema_err!("Duplicate ref: `{}`", reference),
None => Ok(entry.get().id),
},
Entry::Vacant(entry) => {
entry.insert(Definition {
id: next_id,
value: Some(value),
});
Ok(next_id)
let id = match self.definitions.entry(reference.clone()) {
Entry::Occupied(entry) => {
let entry = entry.into_mut();
if entry.value.is_some() {
return py_schema_err!("Duplicate ref: `{}`", reference);
}
entry
}
Entry::Vacant(entry) => entry.insert(DefinitionSlot {
id: next_id,
value: None,
recursive: false,
}),
}
.id;
self.in_flight_definitions.insert(id, false);
let value = constructor(self)?;
// can unwrap because the entry was just built above
let slot = self.definitions.get_mut(&reference).unwrap();
slot.value = Some(value);
slot.recursive = self.in_flight_definitions.remove(&id).unwrap();
Ok(())
}

/// Retrieve an item definition using a ReferenceId
Expand All @@ -99,13 +139,19 @@ impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
}

/// Consume this Definitions into a vector of items, indexed by each items ReferenceId
pub fn finish(self) -> PyResult<Vec<T>> {
pub fn finish(self) -> PyResult<Vec<Definition<T>>> {
// We need to create a vec of defs according to the order in their ids
let mut defs: Vec<(usize, T)> = Vec::new();
let mut defs: Vec<(usize, Definition<T>)> = Vec::new();
for (reference, def) in self.definitions {
match def.value {
None => return py_schema_err!("Definitions error: definition {} was never filled", reference),
Some(v) => defs.push((def.id, v)),
Some(value) => defs.push((
def.id,
Definition {
value,
recursive: def.recursive,
},
)),
}
}
defs.sort_by_key(|(id, _)| *id);
Expand Down
4 changes: 2 additions & 2 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use super::shared::CombinedSerializer;
use crate::definitions::Definitions;
use crate::definitions::{Definition, Definitions};
use crate::recursion_guard::RecursionGuard;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
Expand Down Expand Up @@ -156,7 +156,7 @@ impl SerCheck {
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct ExtraOwned {
mode: SerMode,
definitions: Vec<CombinedSerializer>,
definitions: Vec<Definition<CombinedSerializer>>,
warnings: CollectWarnings,
by_alias: bool,
exclude_unset: bool,
Expand Down
4 changes: 2 additions & 2 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};
use pyo3::{PyTraverseError, PyVisit};

use crate::definitions::DefinitionsBuilder;
use crate::definitions::{Definition, DefinitionsBuilder};
use crate::py_gc::PyGcTraverse;

use config::SerializationConfig;
Expand All @@ -30,7 +30,7 @@ mod type_serializers;
#[derive(Debug)]
pub struct SchemaSerializer {
serializer: CombinedSerializer,
definitions: Vec<CombinedSerializer>,
definitions: Vec<Definition<CombinedSerializer>>,
expected_json_size: AtomicUsize,
config: SerializationConfig,
}
Expand Down
24 changes: 13 additions & 11 deletions src/serializers/type_serializers/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ impl BuildSerializer for DefinitionsSerializerBuilder {
let schema_definitions: &PyList = schema.get_as_req(intern!(py, "definitions"))?;

for schema_definition in schema_definitions {
let reference = schema_definition
.extract::<&PyDict>()?
.get_as_req::<String>(intern!(py, "ref"))?;
let serializer = CombinedSerializer::build(schema_definition.downcast()?, config, definitions)?;
definitions.add_definition(reference, serializer)?;
let schema = schema_definition.downcast::<PyDict>()?;
let reference = schema.get_as_req::<String>(intern!(py, "ref"))?;
definitions.build_definition(reference, |definitions| {
CombinedSerializer::build(schema, config, definitions)
})?;
}

let inner_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?;
Expand Down Expand Up @@ -69,8 +69,8 @@ impl TypeSerializer for DefinitionRefSerializer {
extra: &Extra,
) -> PyResult<PyObject> {
let value_id = extra.rec_guard.add(value, self.serializer_id)?;
let comb_serializer = extra.definitions.get(self.serializer_id).unwrap();
let r = comb_serializer.to_python(value, include, exclude, extra);
let definition = extra.definitions.get(self.serializer_id).unwrap();
let r = definition.value.to_python(value, include, exclude, extra);
extra.rec_guard.pop(value_id, self.serializer_id);
r
}
Expand All @@ -88,8 +88,10 @@ impl TypeSerializer for DefinitionRefSerializer {
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let value_id = extra.rec_guard.add(value, self.serializer_id).map_err(py_err_se_err)?;
let comb_serializer = extra.definitions.get(self.serializer_id).unwrap();
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
let definition = extra.definitions.get(self.serializer_id).unwrap();
let r = definition
.value
.serde_serialize(value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, self.serializer_id);
r
}
Expand All @@ -99,7 +101,7 @@ impl TypeSerializer for DefinitionRefSerializer {
}

fn retry_with_lax_check(&self, definitions: &Definitions<CombinedSerializer>) -> bool {
let comb_serializer = definitions.get(self.serializer_id).unwrap();
comb_serializer.retry_with_lax_check(definitions)
let definition = definitions.get(self.serializer_id).unwrap();
definition.value.retry_with_lax_check(definitions)
}
}
98 changes: 43 additions & 55 deletions src/validators/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ impl BuildValidator for DefinitionsValidatorBuilder {
let reference = schema_definition
.extract::<&PyDict>()?
.get_as_req::<String>(intern!(py, "ref"))?;
let validator = build_validator(schema_definition, config, definitions)?;
definitions.add_definition(reference, validator)?;
definitions.build_definition(reference, |definitions| {
build_validator(schema_definition, config, definitions)
})?;
}

let inner_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?;
Expand Down Expand Up @@ -82,22 +83,25 @@ impl Validator for DefinitionRefValidator {
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
if let Some(id) = input.identity() {
if state.recursion_guard.contains_or_insert(id, self.validator_id) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
}
let output = validate(self.validator_id, py, input, state);
state.recursion_guard.remove(id, self.validator_id);
state.recursion_guard.decr_depth();
output
let definition = state.definitions.get(self.validator_id).unwrap();
if definition.recursive {
if let Some(id) = input.identity() {
return if state.recursion_guard.contains_or_insert(id, self.validator_id) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
}
let output = definition.value.validate(py, input, state);
state.recursion_guard.remove(id, self.validator_id);
state.recursion_guard.decr_depth();
output
};
}
} else {
validate(self.validator_id, py, input, state)
}
};

definition.value.validate(py, input, state)
}

fn validate_assignment<'data>(
Expand All @@ -108,22 +112,29 @@ impl Validator for DefinitionRefValidator {
field_value: &'data PyAny,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
if let Some(id) = obj.identity() {
if state.recursion_guard.contains_or_insert(id, self.validator_id) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
}
let output = validate_assignment(self.validator_id, py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.validator_id);
state.recursion_guard.decr_depth();
output
let definition = state.definitions.get(self.validator_id).unwrap();
if definition.recursive {
if let Some(id) = obj.identity() {
return if state.recursion_guard.contains_or_insert(id, self.validator_id) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
}
let output = definition
.value
.validate_assignment(py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.validator_id);
state.recursion_guard.decr_depth();
output
};
}
} else {
validate_assignment(self.validator_id, py, obj, field_name, field_value, state)
}
};

definition
.value
.validate_assignment(py, obj, field_name, field_value, state)
}

fn different_strict_behavior(
Expand Down Expand Up @@ -151,26 +162,3 @@ impl Validator for DefinitionRefValidator {
Ok(())
}
}

fn validate<'data>(
validator_id: usize,
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let validator = state.definitions.get(validator_id).unwrap();
validator.validate(py, input, state)
}

#[allow(clippy::too_many_arguments)]
fn validate_assignment<'data>(
validator_id: usize,
py: Python<'data>,
obj: &'data PyAny,
field_name: &'data str,
field_value: &'data PyAny,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let validator = state.definitions.get(validator_id).unwrap();
validator.validate_assignment(py, obj, field_name, field_value, state)
}
3 changes: 2 additions & 1 deletion src/validators/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::definitions::Definition;
use crate::errors::{ErrorType, LocItem, ValError, ValResult};
use crate::input::{GenericIterator, Input};
use crate::recursion_guard::RecursionGuard;
Expand Down Expand Up @@ -223,7 +224,7 @@ impl ValidatorIterator {
pub struct InternalValidator {
name: String,
validator: CombinedValidator,
definitions: Vec<CombinedValidator>,
definitions: Vec<Definition<CombinedValidator>>,
// TODO, do we need data?
data: Option<Py<PyDict>>,
strict: Option<bool>,
Expand Down
Loading

0 comments on commit aeaac56

Please sign in to comment.