Skip to content

Commit

Permalink
Ensure recursion guard is always used as a stack (pydantic#1166)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored Jan 19, 2024
1 parent 4da7192 commit 7a5f8e6
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 97 deletions.
63 changes: 57 additions & 6 deletions src/recursion_guard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,59 @@ type RecursionKey = (

/// Scope guard that tracks one `(obj_id, node_id)` pair in a recursion state for as long as it is alive.
///
/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
/// It's used in `validators/definition` to detect when a reference is reused within itself.
pub(crate) struct RecursionGuard<'a, S: ContainsRecursionState> {
// Mutable borrow of the value that holds the recursion state; held for the guard's whole lifetime.
state: &'a mut S,
// First half of the tracked key; removed from the recursion state again when the guard is dropped.
obj_id: usize,
// Second half of the tracked key, stored alongside `obj_id` so that drop can remove the same pair.
node_id: usize,
}

/// Reasons why `RecursionGuard::new` can refuse to create a guard.
pub(crate) enum RecursionError {
/// Cyclic reference detected: the `(obj_id, node_id)` pair was already present in the recursion state.
Cyclic,
/// Recursion limit exceeded: incrementing the depth counter reported that the limit was passed.
Depth,
}

impl<S: ContainsRecursionState> RecursionGuard<'_, S> {
    /// Creates a recursion guard for the given object and node id.
    ///
    /// On success the `(obj_id, node_id)` pair is recorded in the recursion state and the depth
    /// counter is incremented. When dropped, the guard releases the recursion for the given
    /// object and node id again.
    ///
    /// # Errors
    ///
    /// * `RecursionError::Cyclic` if the `(obj_id, node_id)` pair is already being tracked.
    /// * `RecursionError::Depth` if entering this level exceeds the recursion limit.
    ///
    /// In both cases no guard exists afterwards, so nothing will be released on drop; the
    /// bookkeeping done by this call is undone before the error is returned.
    pub fn new(state: &'_ mut S, obj_id: usize, node_id: usize) -> Result<RecursionGuard<'_, S>, RecursionError> {
        state.access_recursion_state(|recursion_state| {
            if !recursion_state.insert(obj_id, node_id) {
                // Nothing was inserted, so there is nothing to undo.
                return Err(RecursionError::Cyclic);
            }
            if recursion_state.incr_depth() {
                // No guard is constructed on this path, so `Drop` will never run for this
                // entry. Undo the increment and the insert here (same order as `Drop`),
                // otherwise the key and the extra depth would stay behind for any caller
                // that recovers from the error and keeps using the same state.
                recursion_state.decr_depth();
                recursion_state.remove(obj_id, node_id);
                return Err(RecursionError::Depth);
            }
            Ok(())
        })?;
        Ok(RecursionGuard { state, obj_id, node_id })
    }

    /// Retrieves the underlying state for further use.
    ///
    /// The returned reference borrows the guard mutably, so the guard cannot be dropped while
    /// the state is in use.
    pub fn state(&mut self) -> &mut S {
        self.state
    }
}

impl<S: ContainsRecursionState> Drop for RecursionGuard<'_, S> {
    /// Leaves the recursion level this guard entered: the depth counter is decremented and the
    /// guard's `(obj_id, node_id)` pair is removed from the recursion state.
    fn drop(&mut self) {
        // Copy the key out first so the closure does not need to borrow `self`.
        let (obj_id, node_id) = (self.obj_id, self.node_id);
        self.state.access_recursion_state(|recursion_state| {
            recursion_state.decr_depth();
            recursion_state.remove(obj_id, node_id);
        });
    }
}

/// This trait is used to retrieve the recursion state from some other type.
///
/// `RecursionGuard` is generic over this trait: it calls `access_recursion_state` once when the
/// guard is created and once more when the guard is dropped.
pub(crate) trait ContainsRecursionState {
/// Calls `f` with mutable access to the contained `RecursionState` and returns whatever `f` returns.
fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R;
}

/// State for the RecursionGuard. Can also be used directly to increase / decrease depth.
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard {
pub struct RecursionState {
ids: RecursionStack,
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
// use one number for all validators
Expand All @@ -31,11 +82,11 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi
255
};

impl RecursionGuard {
impl RecursionState {
// insert a new value
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
fn insert(&mut self, obj_id: usize, node_id: usize) -> bool {
self.ids.insert((obj_id, node_id))
}

Expand Down Expand Up @@ -68,7 +119,7 @@ impl RecursionGuard {
self.depth = self.depth.saturating_sub(1);
}

pub fn remove(&mut self, obj_id: usize, node_id: usize) {
fn remove(&mut self, obj_id: usize, node_id: usize) {
self.ids.remove(&(obj_id, node_id));
}
}
Expand Down Expand Up @@ -98,7 +149,7 @@ impl RecursionStack {
// insert a new value
// * return `false` if the stack already had it in it
// * return `true` if the stack didn't have it in it and it was inserted
pub fn insert(&mut self, v: RecursionKey) -> bool {
fn insert(&mut self, v: RecursionKey) -> bool {
match self {
Self::Array { data, len } => {
if *len < ARRAY_SIZE {
Expand Down Expand Up @@ -129,7 +180,7 @@ impl RecursionStack {
}
}

pub fn remove(&mut self, v: &RecursionKey) {
fn remove(&mut self, v: &RecursionKey) {
match self {
Self::Array { data, len } => {
*len = len.checked_sub(1).expect("remove from empty recursion guard");
Expand Down
56 changes: 29 additions & 27 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,23 @@ use serde::ser::Error;
use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use crate::recursion_guard::ContainsRecursionState;
use crate::recursion_guard::RecursionError;
use crate::recursion_guard::RecursionGuard;
use crate::recursion_guard::RecursionState;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
/// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work
pub(crate) struct SerializationState {
warnings: CollectWarnings,
rec_guard: SerRecursionGuard,
rec_guard: SerRecursionState,
config: SerializationConfig,
}

impl SerializationState {
pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult<Self> {
let warnings = CollectWarnings::new(false);
let rec_guard = SerRecursionGuard::default();
let rec_guard = SerRecursionState::default();
let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?;
Ok(Self {
warnings,
Expand Down Expand Up @@ -77,7 +80,7 @@ pub(crate) struct Extra<'a> {
pub exclude_none: bool,
pub round_trip: bool,
pub config: &'a SerializationConfig,
pub rec_guard: &'a SerRecursionGuard,
pub rec_guard: &'a SerRecursionState,
// the next two are used for union logic
pub check: SerCheck,
// data representing the current model field
Expand All @@ -101,7 +104,7 @@ impl<'a> Extra<'a> {
exclude_none: bool,
round_trip: bool,
config: &'a SerializationConfig,
rec_guard: &'a SerRecursionGuard,
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a PyAny>,
) -> Self {
Expand All @@ -124,6 +127,22 @@ impl<'a> Extra<'a> {
}
}

/// Enters a recursion level for `value` under the definition id `def_ref_id`.
///
/// The address of `value` is used as the object id. The returned guard gives access to this
/// `Extra` again through its `state()` method.
///
/// # Errors
///
/// Returns a `PyValueError` when the guard cannot be created, with a message that says
/// whether the id was repeated or the depth limit was exceeded.
pub fn recursion_guard<'x, 'y>(
    // TODO: this double reference is a bit of a hack, but it's necessary because the recursion
    // guard is not passed around with &mut reference
    //
    // See how validation has &mut ValidationState passed around; we should aim to refactor
    // to match that.
    self: &'x mut &'y Self,
    value: &PyAny,
    def_ref_id: usize,
) -> PyResult<RecursionGuard<'x, &'y Self>> {
    let obj_id = value.as_ptr() as usize;
    match RecursionGuard::new(self, obj_id, def_ref_id) {
        Ok(guard) => Ok(guard),
        Err(RecursionError::Depth) => Err(PyValueError::new_err("Circular reference detected (depth exceeded)")),
        Err(RecursionError::Cyclic) => Err(PyValueError::new_err("Circular reference detected (id repeated)")),
    }
}

pub fn serialize_infer<'py>(&'py self, value: &'py PyAny) -> super::infer::SerializeInfer<'py> {
super::infer::SerializeInfer::new(value, None, None, self)
}
Expand Down Expand Up @@ -157,7 +176,7 @@ pub(crate) struct ExtraOwned {
exclude_none: bool,
round_trip: bool,
config: SerializationConfig,
rec_guard: SerRecursionGuard,
rec_guard: SerRecursionState,
check: SerCheck,
model: Option<PyObject>,
field_name: Option<String>,
Expand Down Expand Up @@ -340,29 +359,12 @@ impl CollectWarnings {

#[derive(Default, Clone)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct SerRecursionGuard {
guard: RefCell<RecursionGuard>,
pub struct SerRecursionState {
guard: RefCell<RecursionState>,
}

impl SerRecursionGuard {
pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
let id = value.as_ptr() as usize;
let mut guard = self.guard.borrow_mut();

if guard.insert(id, def_ref_id) {
if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
} else {
Ok(id)
}
} else {
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
}
}

pub fn pop(&self, id: usize, def_ref_id: usize) {
let mut guard = self.guard.borrow_mut();
guard.decr_depth();
guard.remove(id, def_ref_id);
impl ContainsRecursionState for &'_ Extra<'_> {
    /// Runs `f` against the recursion state stored in this `Extra`'s `rec_guard`.
    ///
    /// The state lives in a `RefCell`, which is mutably borrowed for the duration of `f`;
    /// calling this again from inside `f` on the same state would therefore panic.
    fn access_recursion_state<R>(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R {
        let mut recursion_state = self.rec_guard.guard.borrow_mut();
        f(&mut recursion_state)
    }
}
29 changes: 15 additions & 14 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,22 @@ pub(crate) fn infer_to_python_known(
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> PyResult<PyObject> {
let py = value.py();
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) {
Ok(id) => id,

let mode = extra.mode;
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) {
Ok(v) => v,
Err(e) => {
return match extra.mode {
return match mode {
SerMode::Json => Err(e),
// if recursion is detected but we're serializing to python, we just return the value
_ => Ok(value.into_py(py)),
};
}
};
let extra = guard.state();

macro_rules! serialize_seq {
($t:ty) => {
Expand Down Expand Up @@ -220,7 +223,6 @@ pub(crate) fn infer_to_python_known(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
} else if extra.serialize_unknown {
serialize_unknown(value).into_py(py)
Expand Down Expand Up @@ -267,15 +269,13 @@ pub(crate) fn infer_to_python_known(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
let next_result = infer_to_python(next_value, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
}
value.into_py(py)
}
_ => value.into_py(py),
},
};
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
Ok(value)
}

Expand Down Expand Up @@ -332,18 +332,21 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
serializer: S,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> Result<S::Ok, S::Error> {
let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) {
let extra_serialize_unknown = extra.serialize_unknown;
let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) {
Ok(v) => v,
Err(e) => {
return if extra.serialize_unknown {
return if extra_serialize_unknown {
serializer.serialize_str("...")
} else {
Err(e)
}
Err(py_err_se_err(e))
};
}
};
let extra = guard.state();

macro_rules! serialize {
($t:ty) => {
match value.extract::<$t>() {
Expand Down Expand Up @@ -506,7 +509,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
let next_result = infer_serialize(next_value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
return next_result;
} else if extra.serialize_unknown {
serializer.serialize_str(&serialize_unknown(value))
Expand All @@ -520,7 +522,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}
}
};
extra.rec_guard.pop(value_id, INFER_DEF_REF_ID);
ser_result
}

Expand Down
8 changes: 4 additions & 4 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::py_gc::PyGcTraverse;

use config::SerializationConfig;
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
use extra::{CollectWarnings, SerRecursionGuard};
use extra::{CollectWarnings, SerRecursionState};
pub(crate) use extra::{Extra, SerMode, SerializationState};
pub use shared::CombinedSerializer;
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
Expand Down Expand Up @@ -52,7 +52,7 @@ impl SchemaSerializer {
exclude_defaults: bool,
exclude_none: bool,
round_trip: bool,
rec_guard: &'a SerRecursionGuard,
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a PyAny>,
) -> Extra<'b> {
Expand Down Expand Up @@ -113,7 +113,7 @@ impl SchemaSerializer {
) -> PyResult<PyObject> {
let mode: SerMode = mode.into();
let warnings = CollectWarnings::new(warnings);
let rec_guard = SerRecursionGuard::default();
let rec_guard = SerRecursionState::default();
let extra = self.build_extra(
py,
&mode,
Expand Down Expand Up @@ -152,7 +152,7 @@ impl SchemaSerializer {
fallback: Option<&PyAny>,
) -> PyResult<PyObject> {
let warnings = CollectWarnings::new(warnings);
let rec_guard = SerRecursionGuard::default();
let rec_guard = SerRecursionState::default();
let extra = self.build_extra(
py,
&SerMode::Json,
Expand Down
19 changes: 7 additions & 12 deletions src/serializers/type_serializers/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,12 @@ impl TypeSerializer for DefinitionRefSerializer {
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> PyResult<PyObject> {
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let value_id = extra.rec_guard.add(value, self.definition.id())?;
let r = comb_serializer.to_python(value, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
r
let mut guard = extra.recursion_guard(value, self.definition.id())?;
comb_serializer.to_python(value, include, exclude, guard.state())
})
}

Expand All @@ -87,17 +85,14 @@ impl TypeSerializer for DefinitionRefSerializer {
serializer: S,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
mut extra: &Extra,
) -> Result<S::Ok, S::Error> {
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let value_id = extra
.rec_guard
.add(value, self.definition.id())
let mut guard = extra
.recursion_guard(value, self.definition.id())
.map_err(py_err_se_err)?;
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
r
comb_serializer.serde_serialize(value, serializer, include, exclude, guard.state())
})
}

Expand Down
Loading

0 comments on commit 7a5f8e6

Please sign in to comment.