Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

skip recursion checks for non-recursive definitions #989

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 71 additions & 25 deletions src/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
/// and then get a definition from a reference using an integer id (just for performance of not using a HashMap)
use std::collections::hash_map::Entry;

use pyo3::prelude::*;
use pyo3::{prelude::*, PyTraverseError, PyVisit};

use ahash::AHashMap;

use crate::build_tools::py_schema_err;
use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};

// An integer id for the reference
pub type ReferenceId = usize;
Expand All @@ -24,23 +24,44 @@ pub type ReferenceId = usize;
/// They get indexed by a ReferenceId, which are integer identifiers
/// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer}
/// gets build.
pub type Definitions<T> = [T];
pub type Definitions<T> = [Definition<T>];

#[derive(Clone, Debug)]
struct Definition<T> {
pub id: ReferenceId,
pub value: Option<T>,
struct DefinitionSlot<T> {
id: ReferenceId,
value: Option<T>,
recursive: bool,
}

#[derive(Clone)]
pub struct Definition<T> {
pub value: T,
pub recursive: bool,
}

impl<T: PyGcTraverse> PyGcTraverse for Definition<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
self.value.py_gc_traverse(visit)
}
}

impl<T: std::fmt::Debug> std::fmt::Debug for Definition<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.value.fmt(f)
}
}

#[derive(Clone, Debug)]
pub struct DefinitionsBuilder<T> {
definitions: AHashMap<String, Definition<T>>,
definitions: AHashMap<String, DefinitionSlot<T>>,
in_flight_definitions: AHashMap<ReferenceId, bool>,
}

impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
pub fn new() -> Self {
Self {
definitions: AHashMap::new(),
in_flight_definitions: AHashMap::new(),
}
}

Expand All @@ -51,34 +72,53 @@ impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
// We either need a String copy or two hashmap lookups
// Neither is better than the other
// We opted for the easier outward facing API
match self.definitions.entry(reference.to_string()) {
let id = match self.definitions.entry(reference.to_string()) {
Entry::Occupied(entry) => entry.get().id,
Entry::Vacant(entry) => {
entry.insert(Definition {
entry.insert(DefinitionSlot {
id: next_id,
value: None,
recursive: false,
});
next_id
}
};
// If this definition is currently being built, then it's recursive
if let Some(recursive) = self.in_flight_definitions.get_mut(&id) {
*recursive = true;
}
id
}

/// Add a definition, returning the ReferenceId that maps to it
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<ReferenceId> {
/// Add a definition
pub fn build_definition(
&mut self,
reference: String,
constructor: impl FnOnce(&mut Self) -> PyResult<T>,
) -> PyResult<()> {
let next_id = self.definitions.len();
match self.definitions.entry(reference.clone()) {
Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) {
Some(_) => py_schema_err!("Duplicate ref: `{}`", reference),
None => Ok(entry.get().id),
},
Entry::Vacant(entry) => {
entry.insert(Definition {
id: next_id,
value: Some(value),
});
Ok(next_id)
let id = match self.definitions.entry(reference.clone()) {
Entry::Occupied(entry) => {
let entry = entry.into_mut();
if entry.value.is_some() {
return py_schema_err!("Duplicate ref: `{}`", reference);
}
entry
}
Entry::Vacant(entry) => entry.insert(DefinitionSlot {
id: next_id,
value: None,
recursive: false,
}),
}
.id;
self.in_flight_definitions.insert(id, false);
let value = constructor(self)?;
// can unwrap because the entry was just built above
let slot = self.definitions.get_mut(&reference).unwrap();
slot.value = Some(value);
slot.recursive = self.in_flight_definitions.remove(&id).unwrap();
Ok(())
}

/// Retrieve an item definition using a ReferenceId
Expand All @@ -99,13 +139,19 @@ impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
}

/// Consume this Definitions into a vector of items, indexed by each items ReferenceId
pub fn finish(self) -> PyResult<Vec<T>> {
pub fn finish(self) -> PyResult<Vec<Definition<T>>> {
// We need to create a vec of defs according to the order in their ids
let mut defs: Vec<(usize, T)> = Vec::new();
let mut defs: Vec<(usize, Definition<T>)> = Vec::new();
for (reference, def) in self.definitions {
match def.value {
None => return py_schema_err!("Definitions error: definition {} was never filled", reference),
Some(v) => defs.push((def.id, v)),
Some(value) => defs.push((
def.id,
Definition {
value,
recursive: def.recursive,
},
)),
}
}
defs.sort_by_key(|(id, _)| *id);
Expand Down
4 changes: 2 additions & 2 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use super::shared::CombinedSerializer;
use crate::definitions::Definitions;
use crate::definitions::{Definition, Definitions};
use crate::recursion_guard::RecursionGuard;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
Expand Down Expand Up @@ -156,7 +156,7 @@ impl SerCheck {
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct ExtraOwned {
mode: SerMode,
definitions: Vec<CombinedSerializer>,
definitions: Vec<Definition<CombinedSerializer>>,
warnings: CollectWarnings,
by_alias: bool,
exclude_unset: bool,
Expand Down
4 changes: 2 additions & 2 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};
use pyo3::{PyTraverseError, PyVisit};

use crate::definitions::DefinitionsBuilder;
use crate::definitions::{Definition, DefinitionsBuilder};
use crate::py_gc::PyGcTraverse;

use config::SerializationConfig;
Expand All @@ -30,7 +30,7 @@ mod type_serializers;
#[derive(Debug)]
pub struct SchemaSerializer {
serializer: CombinedSerializer,
definitions: Vec<CombinedSerializer>,
definitions: Vec<Definition<CombinedSerializer>>,
expected_json_size: AtomicUsize,
config: SerializationConfig,
}
Expand Down
24 changes: 13 additions & 11 deletions src/serializers/type_serializers/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ impl BuildSerializer for DefinitionsSerializerBuilder {
let schema_definitions: &PyList = schema.get_as_req(intern!(py, "definitions"))?;

for schema_definition in schema_definitions {
let reference = schema_definition
.extract::<&PyDict>()?
.get_as_req::<String>(intern!(py, "ref"))?;
let serializer = CombinedSerializer::build(schema_definition.downcast()?, config, definitions)?;
definitions.add_definition(reference, serializer)?;
let schema = schema_definition.downcast::<PyDict>()?;
let reference = schema.get_as_req::<String>(intern!(py, "ref"))?;
definitions.build_definition(reference, |definitions| {
CombinedSerializer::build(schema, config, definitions)
})?;
}

let inner_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?;
Expand Down Expand Up @@ -69,8 +69,8 @@ impl TypeSerializer for DefinitionRefSerializer {
extra: &Extra,
) -> PyResult<PyObject> {
let value_id = extra.rec_guard.add(value, self.serializer_id)?;
let comb_serializer = extra.definitions.get(self.serializer_id).unwrap();
let r = comb_serializer.to_python(value, include, exclude, extra);
let definition = extra.definitions.get(self.serializer_id).unwrap();
let r = definition.value.to_python(value, include, exclude, extra);
extra.rec_guard.pop(value_id, self.serializer_id);
Comment on lines 71 to 74
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar question about serializer recursion.

r
}
Expand All @@ -88,8 +88,10 @@ impl TypeSerializer for DefinitionRefSerializer {
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let value_id = extra.rec_guard.add(value, self.serializer_id).map_err(py_err_se_err)?;
let comb_serializer = extra.definitions.get(self.serializer_id).unwrap();
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
let definition = extra.definitions.get(self.serializer_id).unwrap();
let r = definition
.value
.serde_serialize(value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, self.serializer_id);
Comment on lines 90 to 95
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we can / should skip recursion on serializing too?

r
}
Expand All @@ -99,7 +101,7 @@ impl TypeSerializer for DefinitionRefSerializer {
}

fn retry_with_lax_check(&self, definitions: &Definitions<CombinedSerializer>) -> bool {
let comb_serializer = definitions.get(self.serializer_id).unwrap();
comb_serializer.retry_with_lax_check(definitions)
let definition = definitions.get(self.serializer_id).unwrap();
definition.value.retry_with_lax_check(definitions)
}
}
98 changes: 43 additions & 55 deletions src/validators/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ impl BuildValidator for DefinitionsValidatorBuilder {
let reference = schema_definition
.extract::<&PyDict>()?
.get_as_req::<String>(intern!(py, "ref"))?;
let validator = build_validator(schema_definition, config, definitions)?;
definitions.add_definition(reference, validator)?;
definitions.build_definition(reference, |definitions| {
build_validator(schema_definition, config, definitions)
})?;
}

let inner_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?;
Expand Down Expand Up @@ -82,22 +83,25 @@ impl Validator for DefinitionRefValidator {
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
if let Some(id) = input.identity() {
if state.recursion_guard.contains_or_insert(id, self.validator_id) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
}
let output = validate(self.validator_id, py, input, state);
state.recursion_guard.remove(id, self.validator_id);
state.recursion_guard.decr_depth();
output
let definition = state.definitions.get(self.validator_id).unwrap();
if definition.recursive {
Copy link
Contributor Author

@davidhewitt davidhewitt Sep 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I comment out this block I get a segfault in the test suite which gives me confidence the recursion detection is working correctly.

if let Some(id) = input.identity() {
return if state.recursion_guard.contains_or_insert(id, self.validator_id) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
}
let output = definition.value.validate(py, input, state);
state.recursion_guard.remove(id, self.validator_id);
state.recursion_guard.decr_depth();
output
};
}
} else {
validate(self.validator_id, py, input, state)
}
};

definition.value.validate(py, input, state)
}

fn validate_assignment<'data>(
Expand All @@ -108,22 +112,29 @@ impl Validator for DefinitionRefValidator {
field_value: &'data PyAny,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
if let Some(id) = obj.identity() {
if state.recursion_guard.contains_or_insert(id, self.validator_id) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
}
let output = validate_assignment(self.validator_id, py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.validator_id);
state.recursion_guard.decr_depth();
output
let definition = state.definitions.get(self.validator_id).unwrap();
if definition.recursive {
if let Some(id) = obj.identity() {
return if state.recursion_guard.contains_or_insert(id, self.validator_id) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
}
let output = definition
.value
.validate_assignment(py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.validator_id);
state.recursion_guard.decr_depth();
output
};
}
} else {
validate_assignment(self.validator_id, py, obj, field_name, field_value, state)
}
};

definition
.value
.validate_assignment(py, obj, field_name, field_value, state)
}

fn different_strict_behavior(
Expand Down Expand Up @@ -151,26 +162,3 @@ impl Validator for DefinitionRefValidator {
Ok(())
}
}

fn validate<'data>(
validator_id: usize,
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let validator = state.definitions.get(validator_id).unwrap();
validator.validate(py, input, state)
}

#[allow(clippy::too_many_arguments)]
fn validate_assignment<'data>(
validator_id: usize,
py: Python<'data>,
obj: &'data PyAny,
field_name: &'data str,
field_value: &'data PyAny,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
let validator = state.definitions.get(validator_id).unwrap();
validator.validate_assignment(py, obj, field_name, field_value, state)
}
3 changes: 2 additions & 1 deletion src/validators/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::definitions::Definition;
use crate::errors::{ErrorType, LocItem, ValError, ValResult};
use crate::input::{GenericIterator, Input};
use crate::recursion_guard::RecursionGuard;
Expand Down Expand Up @@ -223,7 +224,7 @@ impl ValidatorIterator {
pub struct InternalValidator {
name: String,
validator: CombinedValidator,
definitions: Vec<CombinedValidator>,
definitions: Vec<Definition<CombinedValidator>>,
// TODO, do we need data?
data: Option<Py<PyDict>>,
strict: Option<bool>,
Expand Down
Loading