From 827a905855eff686a22e8ed5df70d1c95510b5d4 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 26 Sep 2023 20:08:02 +0100 Subject: [PATCH] wip precompiled schema --- generate_self_schema.py | 21 +++-- python/pydantic_core/core_schema.py | 56 ++++++++++++- src/serializers/mod.rs | 5 +- src/serializers/shared.rs | 2 + src/serializers/type_serializers/mod.rs | 1 + .../type_serializers/precompiled.rs | 80 +++++++++++++++++++ src/validators/mod.rs | 7 +- src/validators/precompiled.rs | 65 +++++++++++++++ tests/test_schema_functions.py | 9 +++ 9 files changed, 235 insertions(+), 11 deletions(-) create mode 100644 src/serializers/type_serializers/precompiled.rs create mode 100644 src/validators/precompiled.rs diff --git a/generate_self_schema.py b/generate_self_schema.py index 2c190bbad..83167d038 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -54,8 +54,9 @@ def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core return type_dict_schema(obj, definitions) elif obj == Any or obj == type: return {'type': 'any'} - if isinstance(obj, type) and issubclass(obj, core_schema.Protocol): + elif isinstance(obj, type) and issubclass(obj, core_schema.Protocol): return {'type': 'callable'} + # elif isinstance(obj, ForwardRef): origin = get_origin(obj) assert origin is not None, f'origin cannot be None, obj={obj}, you probably need to fix generate_self_schema.py' @@ -151,6 +152,9 @@ def type_dict_schema( # noqa: C901 else: field_type = eval_forward_ref(field_type) + if fr_arg == 'SchemaValidator' or fr_arg == 'SchemaSerializer': + schema = {'type': 'is-instance', 'cls': Ident(fr_arg), 'cls_repr': f'pydantic_core.{fr_arg}'} + if schema is None: if get_origin(field_type) == core_schema.Required: required = True @@ -202,7 +206,7 @@ def main() -> None: definitions: dict[str, core_schema.CoreSchema] = {} choices = {} - for s in schema_union.__args__: + for s in get_args(schema_union): type_ = 
s.__annotations__['type'] m = re.search(r"Literal\['(.+?)']", type_.__forward_arg__) assert m, f'Unknown schema type: {type_}' @@ -217,9 +221,9 @@ def main() -> None: *definitions.values(), ], ) - python_code = ( - f'# this file is auto-generated by generate_self_schema.py, DO NOT edit manually\nself_schema = {schema}\n' - ) + python_code = f"""# this file is auto-generated by generate_self_schema.py, DO NOT edit manually +from pydantic_core import SchemaValidator, SchemaSerializer +self_schema = {schema}\n""" try: from black import Mode, TargetVersion, format_file_contents except ImportError: @@ -236,5 +240,12 @@ def main() -> None: print(f'Self schema definition written to {SAVE_PATH}') +class Ident(str): + """Format a literal as a Ident in the output""" + + def __repr__(self) -> str: + return str(self) + + if __name__ == '__main__': main() diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 2d7061ffd..fd541e243 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -30,15 +30,17 @@ from typing import Literal if TYPE_CHECKING: - from pydantic_core import PydanticUndefined + from pydantic_core import PydanticUndefined, SchemaSerializer, SchemaValidator else: - # The initial build of pydantic_core requires PydanticUndefined to generate + # The initial build of pydantic_core requires some Rust structures to generate # the core schema; so we need to conditionally skip it. mypy doesn't like # this at all, hence the TYPE_CHECKING branch above. 
try: - from pydantic_core import PydanticUndefined + from pydantic_core import PydanticUndefined, SchemaSerializer, SchemaValidator except ImportError: - PydanticUndefined = object() + PydanticUndefined = 'PydanticUndefined' + SchemaValidator = 'SchemaValidator' + SchemaSerializer = 'SchemaSerializer' ExtraBehavior = Literal['allow', 'forbid', 'ignore'] @@ -3605,6 +3607,50 @@ def definition_reference_schema( return _dict_not_none(type='definition-ref', schema_ref=schema_ref, metadata=metadata, serialization=serialization) +class PrecompiledSchema(TypedDict, total=False): + type: Required[Literal['precompiled']] + schema: CoreSchema + validator: SchemaValidator + serializer: SchemaSerializer + ref: str + metadata: Any + + +def precompiled_schema( + schema: CoreSchema, + validator: SchemaValidator, + serializer: SchemaSerializer, + ref: str | None = None, + metadata: Any = None, +) -> PrecompiledSchema: + """ + Returns a schema that bundles `schema` with a `SchemaValidator` and `SchemaSerializer` already built for it, + validation and serialization are delegated to those objects instead of building `schema` again, e.g.: + + ```py + from pydantic_core import SchemaSerializer, SchemaValidator, core_schema + + inner = core_schema.int_schema() + schema = core_schema.precompiled_schema( + schema=inner, + validator=SchemaValidator(inner), + serializer=SchemaSerializer(inner), + ) + assert schema['type'] == 'precompiled' + ``` + + Args: + schema: The schema that `validator` and `serializer` are expected to have been built from + validator: The already-built validator that validation is delegated to + serializer: The already-built serializer that serialization is delegated to + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + """ + return _dict_not_none( + type='precompiled', schema=schema, validator=validator, serializer=serializer, ref=ref, metadata=metadata + ) + + MYPY = False # See https://github.com/python/mypy/issues/14034 for details, in summary mypy is extremely slow to 
process this # union which kills performance not just for pydantic, but even for code using pydantic @@ -3658,6 +3704,7 @@ def definition_reference_schema( DefinitionsSchema, DefinitionReferenceSchema, UuidSchema, + PrecompiledSchema, ] elif False: CoreSchema: TypeAlias = Mapping[str, Any] @@ -3713,6 +3760,7 @@ def definition_reference_schema( 'definitions', 'definition-ref', 'uuid', + 'precompiled', ] CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field'] diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 72028346b..c1f828a30 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -26,10 +26,11 @@ mod ob_type; mod shared; mod type_serializers; -#[pyclass(module = "pydantic_core._pydantic_core")] +#[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaSerializer { serializer: CombinedSerializer, + schema: PyObject, definitions: Definitions, expected_json_size: AtomicUsize, config: SerializationConfig, @@ -77,6 +78,7 @@ impl SchemaSerializer { let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?; Ok(Self { serializer, + schema: schema.into(), definitions: definitions_builder.finish()?, expected_json_size: AtomicUsize::new(1024), config: SerializationConfig::from_config(config)?, @@ -183,6 +185,7 @@ impl SchemaSerializer { fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { self.serializer.py_gc_traverse(&visit)?; + visit.call(&self.schema)?; self.definitions.py_gc_traverse(&visit)?; Ok(()) } diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 7c24ff6db..96a9e7fd5 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -141,6 +141,7 @@ combined_serializer! 
{ Recursive: super::type_serializers::definitions::DefinitionRefSerializer; TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer; TupleVariable: super::type_serializers::tuple::TupleVariableSerializer; + Precompiled: super::type_serializers::precompiled::PrecompiledSerializer; } } @@ -250,6 +251,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::TuplePositional(inner) => inner.py_gc_traverse(visit), CombinedSerializer::TupleVariable(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::Precompiled(inner) => inner.py_gc_traverse(visit), } } } diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index b942b5b86..b97762a36 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -15,6 +15,7 @@ pub mod literal; pub mod model; pub mod nullable; pub mod other; +pub mod precompiled; pub mod set_frozenset; pub mod simple; pub mod string; diff --git a/src/serializers/type_serializers/precompiled.rs b/src/serializers/type_serializers/precompiled.rs new file mode 100644 index 000000000..e38f45a49 --- /dev/null +++ b/src/serializers/type_serializers/precompiled.rs @@ -0,0 +1,80 @@ +use std::borrow::Cow; + +use pyo3::types::PyDict; +use pyo3::{intern, prelude::*}; + +use crate::build_tools::py_schema_err; +use crate::definitions::DefinitionsBuilder; +use crate::serializers::shared::TypeSerializer; +use crate::serializers::Extra; +use crate::tools::SchemaDict; +use crate::SchemaSerializer; + +use super::{BuildSerializer, CombinedSerializer}; + +#[derive(Debug, Clone)] +pub struct PrecompiledSerializer { + serializer: Py<SchemaSerializer>, +} + +impl BuildSerializer for PrecompiledSerializer { + const EXPECTED_TYPE: &'static str = "precompiled"; + + fn build( + schema: &PyDict, + _config: Option<&PyDict>, + _definitions: &mut DefinitionsBuilder<CombinedSerializer>, + ) -> PyResult<CombinedSerializer> { + let py = schema.py(); + let sub_schema: 
&PyAny = schema.get_as_req(intern!(py, "schema"))?; + let serializer: PyRef<SchemaSerializer> = schema.get_as_req(intern!(py, "serializer"))?; + + // TODO DEBUG THIS LATER + // if !serializer.schema.is(sub_schema) { + // return py_schema_err!("precompiled schema mismatch"); + // } + + Ok(CombinedSerializer::Precompiled(PrecompiledSerializer { + serializer: serializer.into(), + })) + } +} + +impl_py_gc_traverse!(PrecompiledSerializer { serializer }); + +impl TypeSerializer for PrecompiledSerializer { + fn to_python( + &self, + value: &PyAny, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult<PyObject> { + self.serializer + .get() + .serializer + .to_python(value, include, exclude, extra) + } + + fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> { + self.serializer.get().serializer.json_key(key, extra) + } + + fn serde_serialize<S: serde::ser::Serializer>( + &self, + value: &PyAny, + serializer: S, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> Result<S::Ok, S::Error> { + self.serializer + .get() + .serializer + .serde_serialize(value, serializer, include, exclude, extra) + } + + fn get_name(&self) -> &str { + self.serializer.get().serializer.get_name() + } +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 42aad2001..bb6f16229 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -46,6 +46,7 @@ mod model; mod model_fields; mod none; mod nullable; +mod precompiled; mod set; mod string; mod time; @@ -97,7 +98,7 @@ impl PySome { } } -#[pyclass(module = "pydantic_core._pydantic_core")] +#[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaValidator { validator: CombinedValidator, @@ -542,6 +543,8 @@ pub fn build_validator<'a>( // recursive (self-referencing) models definitions::DefinitionRefValidator, definitions::DefinitionsValidatorBuilder, + // precompiled models + precompiled::PrecompiledValidator, ) } @@ -690,6 +693,8 @@ pub enum CombinedValidator { 
DefinitionRef(definitions::DefinitionRefValidator), // input dependent JsonOrPython(json_or_python::JsonOrPython), + // reusing a sub-schema + Precompiled(precompiled::PrecompiledValidator), } /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, diff --git a/src/validators/precompiled.rs b/src/validators/precompiled.rs new file mode 100644 index 000000000..17e52fd25 --- /dev/null +++ b/src/validators/precompiled.rs @@ -0,0 +1,65 @@ +use pyo3::types::PyDict; +use pyo3::{intern, prelude::*}; + +use crate::build_tools::py_schema_err; +use crate::definitions::DefinitionsBuilder; +use crate::errors::ValResult; +use crate::input::Input; +use crate::tools::SchemaDict; +use crate::SchemaValidator; + +use super::{BuildValidator, CombinedValidator, ValidationState, Validator}; + +#[derive(Debug)] +pub struct PrecompiledValidator { + validator: Py<SchemaValidator>, +} + +impl BuildValidator for PrecompiledValidator { + const EXPECTED_TYPE: &'static str = "precompiled"; + + fn build( + schema: &PyDict, + _config: Option<&PyDict>, + _definitions: &mut DefinitionsBuilder<CombinedValidator>, + ) -> PyResult<CombinedValidator> { + let py = schema.py(); + let sub_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?; + let validator: PyRef<SchemaValidator> = schema.get_as_req(intern!(py, "validator"))?; + + // TODO DEBUG THIS LATER + // if !validator.schema.is(sub_schema) { + // return py_schema_err!("precompiled schema mismatch"); + // } + + Ok(CombinedValidator::Precompiled(PrecompiledValidator { + validator: validator.into(), + })) + } +} + +impl_py_gc_traverse!(PrecompiledValidator { validator }); + +impl Validator for PrecompiledValidator { + fn validate<'data>( + &self, + py: Python<'data>, + input: &'data impl Input<'data>, + state: &mut ValidationState, + ) -> ValResult<'data, PyObject> { + self.validator.get().validator.validate(py, input, state) + } + + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + 
self.validator.get().validator.different_strict_behavior(ultra_strict) + } + + fn get_name(&self) -> &str { + self.validator.get().validator.get_name() + } + + fn complete(&self) -> PyResult<()> { + // No need to complete a precompiled validator + Ok(()) + } +} diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index d4b53cbe4..b8dcd9fe3 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -39,6 +39,10 @@ def args(*args, **kwargs): return args, kwargs +INT_SCHEMA = core_schema.int_schema() +INT_VALIDATOR = SchemaValidator(INT_SCHEMA) +INT_SERIALIZER = SchemaSerializer(INT_SCHEMA) + all_schema_functions = [ (core_schema.any_schema, args(), {'type': 'any'}), (core_schema.any_schema, args(metadata=['foot', 'spa']), {'type': 'any', 'metadata': ['foot', 'spa']}), @@ -289,6 +293,11 @@ def args(*args, **kwargs): (core_schema.uuid_schema, args(), {'type': 'uuid'}), (core_schema.decimal_schema, args(), {'type': 'decimal'}), (core_schema.decimal_schema, args(multiple_of=5, gt=1.2), {'type': 'decimal', 'multiple_of': 5, 'gt': 1.2}), + ( + core_schema.precompiled_schema, + args(schema=INT_SCHEMA, validator=INT_VALIDATOR, serializer=INT_SERIALIZER), + {'type': 'precompiled', 'schema': INT_SCHEMA, 'validator': INT_VALIDATOR, 'serializer': INT_SERIALIZER}, + ), ]