Skip to content

Commit

Permalink
validators never need to be cloned
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 26, 2023
1 parent cddcc71 commit 5f14973
Show file tree
Hide file tree
Showing 25 changed files with 69 additions and 58 deletions.
8 changes: 8 additions & 0 deletions src/py_gc.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use ahash::AHashMap;
use enum_dispatch::enum_dispatch;
use pyo3::{AsPyPointer, Py, PyTraverseError, PyVisit};
Expand Down Expand Up @@ -35,6 +37,12 @@ impl<T: PyGcTraverse> PyGcTraverse for AHashMap<String, T> {
}
}

impl<T: PyGcTraverse> PyGcTraverse for Arc<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
T::py_gc_traverse(self, visit)
}
}

impl<T: PyGcTraverse> PyGcTraverse for Box<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
T::py_gc_traverse(self, visit)
Expand Down
4 changes: 2 additions & 2 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::tools::SchemaDict;
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
#[derive(Debug)]
struct Parameter {
positional: bool,
name: String,
Expand All @@ -24,7 +24,7 @@ struct Parameter {
validator: CombinedValidator,
}

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct ArgumentsValidator {
parameters: Vec<Parameter>,
positional_params_count: usize,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::tools::SchemaDict;
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct CallValidator {
function: PyObject,
arguments_validator: Box<CombinedValidator>,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::tools::SchemaDict;
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct ChainValidator {
steps: Vec<CombinedValidator>,
name: String,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/custom_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl CustomError {
}
}

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct CustomErrorValidator {
validator: Box<CombinedValidator>,
custom_error: CustomError,
Expand Down
6 changes: 3 additions & 3 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use super::{
build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator,
};

#[derive(Debug, Clone)]
#[derive(Debug)]
struct Field {
kw_only: bool,
name: String,
Expand All @@ -30,7 +30,7 @@ struct Field {
frozen: bool,
}

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct DataclassArgsValidator {
fields: Vec<Field>,
positional_count: usize,
Expand Down Expand Up @@ -441,7 +441,7 @@ impl Validator for DataclassArgsValidator {
}
}

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct DataclassValidator {
strict: bool,
validator: Box<CombinedValidator>,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use super::any::AnyValidator;
use super::list::length_check;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct DictValidator {
strict: bool,
key_validator: Box<CombinedValidator>,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/frozenset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::set::set_build;
use super::validation_state::ValidationState;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct FrozenSetValidator {
strict: bool,
item_validator: Box<CombinedValidator>,
Expand Down
20 changes: 11 additions & 9 deletions src/validators/function.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use pyo3::exceptions::{PyAssertionError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict, PyString};
Expand Down Expand Up @@ -130,7 +132,7 @@ macro_rules! impl_validator {
};
}

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct FunctionBeforeValidator {
validator: Box<CombinedValidator>,
func: PyObject,
Expand Down Expand Up @@ -163,7 +165,7 @@ impl FunctionBeforeValidator {

impl_validator!(FunctionBeforeValidator);

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct FunctionAfterValidator {
validator: Box<CombinedValidator>,
func: PyObject,
Expand Down Expand Up @@ -264,9 +266,9 @@ impl Validator for FunctionPlainValidator {
}
}

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct FunctionWrapValidator {
validator: Box<CombinedValidator>,
validator: Arc<CombinedValidator>,
func: PyObject,
config: PyObject,
name: String,
Expand All @@ -290,7 +292,7 @@ impl BuildValidator for FunctionWrapValidator {
let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false);
let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false);
Ok(Self {
validator: Box::new(validator),
validator: Arc::new(validator),
func: function_info.function.clone(),
config: match config {
Some(c) => c.into(),
Expand Down Expand Up @@ -341,7 +343,7 @@ impl Validator for FunctionWrapValidator {
validator: InternalValidator::new(
py,
"ValidatorCallable",
&self.validator,
self.validator.clone(),
state,
self.hide_input_in_errors,
self.validation_error_cause,
Expand All @@ -367,7 +369,7 @@ impl Validator for FunctionWrapValidator {
validator: InternalValidator::new(
py,
"AssignmentValidatorCallable",
&self.validator,
self.validator.clone(),
state,
self.hide_input_in_errors,
self.validation_error_cause,
Expand Down Expand Up @@ -396,7 +398,7 @@ impl Validator for FunctionWrapValidator {
}

#[pyclass(module = "pydantic_core._pydantic_core")]
#[derive(Debug, Clone)]
#[derive(Debug)]
struct ValidatorCallable {
validator: InternalValidator,
}
Expand Down Expand Up @@ -428,7 +430,7 @@ impl ValidatorCallable {
}

#[pyclass(module = "pydantic_core._pydantic_core")]
#[derive(Debug, Clone)]
#[derive(Debug)]
struct AssignmentValidatorCallable {
updated_field_name: String,
updated_field_value: Py<PyAny>,
Expand Down
18 changes: 9 additions & 9 deletions src/validators/generator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt;
use std::sync::Arc;

use pyo3::prelude::*;
use pyo3::types::PyDict;
Expand All @@ -14,7 +15,7 @@ use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, InputT

#[derive(Debug, Clone)]
pub struct GeneratorValidator {
item_validator: Option<Box<CombinedValidator>>,
item_validator: Option<Arc<CombinedValidator>>,
min_length: Option<usize>,
max_length: Option<usize>,
name: String,
Expand All @@ -30,7 +31,7 @@ impl BuildValidator for GeneratorValidator {
config: Option<&PyDict>,
definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
let item_validator = get_items_schema(schema, config, definitions)?;
let item_validator = get_items_schema(schema, config, definitions)?.map(Arc::new);
let name = match item_validator {
Some(ref v) => format!("{}[{}]", Self::EXPECTED_TYPE, v.get_name()),
None => format!("{}[any]", Self::EXPECTED_TYPE),
Expand Down Expand Up @@ -67,7 +68,7 @@ impl Validator for GeneratorValidator {
InternalValidator::new(
py,
"ValidatorIterator",
v,
v.clone(),
state,
self.hide_input_in_errors,
self.validation_error_cause,
Expand Down Expand Up @@ -106,7 +107,7 @@ impl Validator for GeneratorValidator {
}

#[pyclass(module = "pydantic_core._pydantic_core")]
#[derive(Debug, Clone)]
#[derive(Debug)]
struct ValidatorIterator {
iterator: GenericIterator,
validator: Option<InternalValidator>,
Expand Down Expand Up @@ -213,12 +214,11 @@ impl ValidatorIterator {
}
}

/// Cloneable validator wrapper for use in generators in functions, this can be passed back to python
/// Owned validator wrapper for use in generators in functions, this can be passed back to python
/// mid-validation
#[derive(Clone)]
pub struct InternalValidator {
name: String,
validator: CombinedValidator,
validator: Arc<CombinedValidator>,
// TODO, do we need data?
data: Option<Py<PyDict>>,
strict: Option<bool>,
Expand All @@ -241,15 +241,15 @@ impl InternalValidator {
pub fn new(
py: Python,
name: &str,
validator: &CombinedValidator,
validator: Arc<CombinedValidator>,
state: &ValidationState,
hide_input_in_errors: bool,
validation_error_cause: bool,
) -> Self {
let extra = state.extra();
Self {
name: name.to_string(),
validator: validator.clone(),
validator,
data: extra.data.map(|d| d.into_py(py)),
strict: extra.strict,
from_attributes: extra.from_attributes,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::tools::SchemaDict;
use super::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct JsonValidator {
validator: Option<Box<CombinedValidator>>,
name: String,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/json_or_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::InputType;
use super::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, Validator};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct JsonOrPython {
json: Box<CombinedValidator>,
python: Box<CombinedValidator>,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/lax_or_strict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::tools::SchemaDict;
use super::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct LaxOrStrictValidator {
strict: bool,
lax_validator: Box<CombinedValidator>,
Expand Down
8 changes: 4 additions & 4 deletions src/validators/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct ListValidator {
strict: bool,
item_validator: Option<Box<CombinedValidator>>,
Expand All @@ -22,13 +22,13 @@ pub fn get_items_schema(
schema: &PyDict,
config: Option<&PyDict>,
definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<Option<Box<CombinedValidator>>> {
) -> PyResult<Option<CombinedValidator>> {
match schema.get_item(pyo3::intern!(schema.py(), "items_schema")) {
Some(d) => {
let validator = build_validator(d, config, definitions)?;
match validator {
CombinedValidator::Any(_) => Ok(None),
_ => Ok(Some(Box::new(validator))),
_ => Ok(Some(validator)),
}
}
None => Ok(None),
Expand Down Expand Up @@ -100,7 +100,7 @@ impl BuildValidator for ListValidator {
definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
let py = schema.py();
let item_validator = get_items_schema(schema, config, definitions)?;
let item_validator = get_items_schema(schema, config, definitions)?.map(Box::new);
Ok(Self {
strict: crate::build_tools::is_strict(schema, config)?,
item_validator,
Expand Down
6 changes: 3 additions & 3 deletions src/validators/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct BoolLiteral {
}

#[derive(Debug, Clone)]
pub struct LiteralLookup<T: Clone + Debug> {
pub struct LiteralLookup<T: Debug> {
// Specialized lookups for ints, bools and strings because they
// (1) are easy to convert between Rust and Python
// (2) hashing them in Rust is very fast
Expand All @@ -35,7 +35,7 @@ pub struct LiteralLookup<T: Clone + Debug> {
pub values: Vec<T>,
}

impl<T: Clone + Debug> LiteralLookup<T> {
impl<T: Debug> LiteralLookup<T> {
pub fn new<'py>(py: Python<'py>, expected: impl Iterator<Item = (&'py PyAny, T)>) -> PyResult<Self> {
let mut expected_int = AHashMap::new();
let mut expected_str: AHashMap<String, usize> = AHashMap::new();
Expand Down Expand Up @@ -135,7 +135,7 @@ impl<T: Clone + Debug> LiteralLookup<T> {
}
}

impl<T: PyGcTraverse + Clone + Debug> PyGcTraverse for LiteralLookup<T> {
impl<T: PyGcTraverse + Debug> PyGcTraverse for LiteralLookup<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
self.expected_py.py_gc_traverse(visit)?;
self.values.py_gc_traverse(visit)?;
Expand Down
13 changes: 7 additions & 6 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl PySome {
}

#[pyclass(module = "pydantic_core._pydantic_core")]
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct SchemaValidator {
validator: CombinedValidator,
definitions: Definitions<CombinedValidator>,
Expand Down Expand Up @@ -141,9 +141,10 @@ impl SchemaValidator {
})
}

pub fn __reduce__(&self, py: Python) -> PyResult<PyObject> {
let args = (self.schema.as_ref(py),);
let cls = Py::new(py, self.clone())?.getattr(py, "__class__")?;
pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<PyObject> {
let py = slf.py();
let args = (slf.try_borrow()?.schema.to_object(py),);
let cls = slf.getattr("__class__")?;
Ok((cls, args).into_py(py))
}

Expand Down Expand Up @@ -598,7 +599,7 @@ impl<'a> Extra<'a> {
}
}

#[derive(Debug, Clone)]
#[derive(Debug)]
#[enum_dispatch(PyGcTraverse)]
pub enum CombinedValidator {
// typed dict e.g. heterogeneous dicts or simply a model
Expand Down Expand Up @@ -694,7 +695,7 @@ pub enum CombinedValidator {
/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,
/// validators defined in `build_validator` also need `EXPECTED_TYPE` as a const, but that can't be part of the trait
#[enum_dispatch(CombinedValidator)]
pub trait Validator: Send + Sync + Clone + Debug {
pub trait Validator: Send + Sync + Debug {
/// Do the actual validation for this schema/type
fn validate<'data>(
&self,
Expand Down
Loading

0 comments on commit 5f14973

Please sign in to comment.