Skip to content

Commit

Permalink
wip precompiled schema
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 26, 2023
1 parent 5f14973 commit 827a905
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 11 deletions.
21 changes: 16 additions & 5 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core
return type_dict_schema(obj, definitions)
elif obj == Any or obj == type:
return {'type': 'any'}
if isinstance(obj, type) and issubclass(obj, core_schema.Protocol):
elif isinstance(obj, type) and issubclass(obj, core_schema.Protocol):
return {'type': 'callable'}
# elif isinstance(obj, ForwardRef):

origin = get_origin(obj)
assert origin is not None, f'origin cannot be None, obj={obj}, you probably need to fix generate_self_schema.py'
Expand Down Expand Up @@ -151,6 +152,9 @@ def type_dict_schema( # noqa: C901
else:
field_type = eval_forward_ref(field_type)

if fr_arg == 'SchemaValidator' or fr_arg == 'SchemaSerializer':
schema = {'type': 'is-instance', 'cls': Ident(fr_arg), 'cls_repr': f'pydantic_core.{fr_arg}'}

if schema is None:
if get_origin(field_type) == core_schema.Required:
required = True
Expand Down Expand Up @@ -202,7 +206,7 @@ def main() -> None:
definitions: dict[str, core_schema.CoreSchema] = {}

choices = {}
for s in schema_union.__args__:
for s in get_args(schema_union):
type_ = s.__annotations__['type']
m = re.search(r"Literal\['(.+?)']", type_.__forward_arg__)
assert m, f'Unknown schema type: {type_}'
Expand All @@ -217,9 +221,9 @@ def main() -> None:
*definitions.values(),
],
)
python_code = (
f'# this file is auto-generated by generate_self_schema.py, DO NOT edit manually\nself_schema = {schema}\n'
)
python_code = f"""# this file is auto-generated by generate_self_schema.py, DO NOT edit manually
from pydantic_core import SchemaValidator, SchemaSerializer
self_schema = {schema}\n"""
try:
from black import Mode, TargetVersion, format_file_contents
except ImportError:
Expand All @@ -236,5 +240,12 @@ def main() -> None:
print(f'Self schema definition written to {SAVE_PATH}')


class Ident(str):
    """A string whose `repr()` is its bare text, without surrounding quotes.

    The generated file is produced by formatting the schema dict into Python source,
    which uses `repr()` for every value. Wrapping a name in `Ident` therefore makes it
    appear in the output as an identifier (e.g. `SchemaValidator`) rather than as a
    string literal (`'SchemaValidator'`).
    """

    def __repr__(self) -> str:
        # Return the plain text so no quotes are added when the value is formatted.
        return str(self)


if __name__ == '__main__':
main()
56 changes: 52 additions & 4 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,17 @@
from typing import Literal

if TYPE_CHECKING:
from pydantic_core import PydanticUndefined
from pydantic_core import PydanticUndefined, SchemaSerializer, SchemaValidator
else:
# The initial build of pydantic_core requires PydanticUndefined to generate
# The initial build of pydantic_core requires some Rust structures to generate
# the core schema; so we need to conditionally skip it. mypy doesn't like
# this at all, hence the TYPE_CHECKING branch above.
try:
from pydantic_core import PydanticUndefined
from pydantic_core import PydanticUndefined, SchemaSerializer, SchemaValidator
except ImportError:
PydanticUndefined = object()
PydanticUndefined = 'PydanticUndefined'
SchemaValidator = 'SchemaValidator'
SchemaSerializer = 'SchemaSerializer'


ExtraBehavior = Literal['allow', 'forbid', 'ignore']
Expand Down Expand Up @@ -3605,6 +3607,50 @@ def definition_reference_schema(
return _dict_not_none(type='definition-ref', schema_ref=schema_ref, metadata=metadata, serialization=serialization)


class PrecompiledSchema(TypedDict, total=False):
    """Schema dict for the `'precompiled'` type.

    It carries a core schema together with a `SchemaValidator` and `SchemaSerializer`
    supplied by the caller, so that they can be reused rather than built again.
    """

    # Discriminator; always the literal 'precompiled'.
    type: Required[Literal['precompiled']]
    # The core schema that `validator` and `serializer` correspond to.
    schema: CoreSchema
    # Validator to reuse for this schema.
    validator: SchemaValidator
    # Serializer to reuse for this schema.
    serializer: SchemaSerializer
    # Optional reference name for this schema.
    ref: str
    # Any other information to include with the schema.
    metadata: Any


def precompiled_schema(
    schema: CoreSchema,
    validator: SchemaValidator,
    serializer: SchemaSerializer,
    ref: str | None = None,
    metadata: Any = None,
) -> PrecompiledSchema:
    """
    Returns a schema of type `'precompiled'`, which bundles a core schema with a `SchemaValidator` and a
    `SchemaSerializer` that the caller has already constructed, so they can be reused when this schema is
    nested inside another one, e.g.:

    ```py
    from pydantic_core import SchemaSerializer, SchemaValidator, core_schema

    inner = core_schema.int_schema()
    schema = core_schema.precompiled_schema(
        schema=inner,
        validator=SchemaValidator(inner),
        serializer=SchemaSerializer(inner),
    )
    assert schema['type'] == 'precompiled'
    ```

    Args:
        schema: The core schema that `validator` and `serializer` correspond to
        validator: The validator to reuse for `schema`
        serializer: The serializer to reuse for `schema`
        ref: optional unique identifier of the schema, used to reference the schema in other places
        metadata: Any other information you want to include with the schema, not used by pydantic-core
    """
    return _dict_not_none(
        type='precompiled', schema=schema, validator=validator, serializer=serializer, ref=ref, metadata=metadata
    )


MYPY = False
# See https://github.com/python/mypy/issues/14034 for details, in summary mypy is extremely slow to process this
# union which kills performance not just for pydantic, but even for code using pydantic
Expand Down Expand Up @@ -3658,6 +3704,7 @@ def definition_reference_schema(
DefinitionsSchema,
DefinitionReferenceSchema,
UuidSchema,
PrecompiledSchema,
]
elif False:
CoreSchema: TypeAlias = Mapping[str, Any]
Expand Down Expand Up @@ -3713,6 +3760,7 @@ def definition_reference_schema(
'definitions',
'definition-ref',
'uuid',
'precompiled',
]

CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']
Expand Down
5 changes: 4 additions & 1 deletion src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ mod ob_type;
mod shared;
mod type_serializers;

#[pyclass(module = "pydantic_core._pydantic_core")]
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaSerializer {
serializer: CombinedSerializer,
schema: PyObject,
definitions: Definitions<CombinedSerializer>,
expected_json_size: AtomicUsize,
config: SerializationConfig,
Expand Down Expand Up @@ -77,6 +78,7 @@ impl SchemaSerializer {
let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
Ok(Self {
serializer,
schema: schema.into(),
definitions: definitions_builder.finish()?,
expected_json_size: AtomicUsize::new(1024),
config: SerializationConfig::from_config(config)?,
Expand Down Expand Up @@ -183,6 +185,7 @@ impl SchemaSerializer {

fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.serializer.py_gc_traverse(&visit)?;
visit.call(&self.schema)?;
self.definitions.py_gc_traverse(&visit)?;
Ok(())
}
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ combined_serializer! {
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer;
TupleVariable: super::type_serializers::tuple::TupleVariableSerializer;
Precompiled: super::type_serializers::precompiled::PrecompiledSerializer;
}
}

Expand Down Expand Up @@ -250,6 +251,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::TuplePositional(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::TupleVariable(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Precompiled(inner) => inner.py_gc_traverse(visit),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/serializers/type_serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub mod literal;
pub mod model;
pub mod nullable;
pub mod other;
pub mod precompiled;
pub mod set_frozenset;
pub mod simple;
pub mod string;
Expand Down
80 changes: 80 additions & 0 deletions src/serializers/type_serializers/precompiled.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::borrow::Cow;

use pyo3::types::PyDict;
use pyo3::{intern, prelude::*};

use crate::build_tools::py_schema_err;
use crate::definitions::DefinitionsBuilder;
use crate::serializers::shared::TypeSerializer;
use crate::serializers::Extra;
use crate::tools::SchemaDict;
use crate::SchemaSerializer;

use super::{BuildSerializer, CombinedSerializer};

/// Serializer for the `"precompiled"` schema type.
///
/// Rather than building a serializer of its own, it keeps a reference to a
/// `SchemaSerializer` Python object and forwards every call to that object's
/// root serializer.
#[derive(Debug, Clone)]
pub struct PrecompiledSerializer {
    /// The `SchemaSerializer` read from the schema's `"serializer"` key.
    serializer: Py<SchemaSerializer>,
}

impl BuildSerializer for PrecompiledSerializer {
    const EXPECTED_TYPE: &'static str = "precompiled";

    /// Builds a [`PrecompiledSerializer`] from a `"precompiled"` schema dict.
    ///
    /// Reads the `"schema"` and `"serializer"` keys via `get_as_req`; the
    /// `"serializer"` value must be a `SchemaSerializer` instance, which is
    /// stored and reused as-is. `_config` and `_definitions` are not consulted.
    ///
    /// # Errors
    ///
    /// Propagates any error from either `get_as_req` lookup (presumably a
    /// missing key or a value of the wrong type — confirm in `SchemaDict`).
    fn build(
        schema: &PyDict,
        _config: Option<&PyDict>,
        _definitions: &mut DefinitionsBuilder<CombinedSerializer>,
    ) -> PyResult<CombinedSerializer> {
        let py = schema.py();
        // Looked up only so that the `"schema"` key stays mandatory; the value
        // itself is not used until the consistency check below is implemented.
        let _sub_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?;
        let serializer: PyRef<SchemaSerializer> = schema.get_as_req(intern!(py, "serializer"))?;

        // TODO: verify that `serializer` was built from `_sub_schema` and return
        // `py_schema_err!("precompiled schema mismatch")` otherwise. An identity
        // comparison (`serializer.schema.is(_sub_schema)`) was tried and left
        // disabled pending debugging.

        Ok(CombinedSerializer::Precompiled(PrecompiledSerializer {
            serializer: serializer.into(),
        }))
    }
}

impl_py_gc_traverse!(PrecompiledSerializer { serializer });

impl TypeSerializer for PrecompiledSerializer {
fn to_python(
&self,
value: &PyAny,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
self.serializer
.get()
.serializer
.to_python(value, include, exclude, extra)
}

fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
self.serializer.get().serializer.json_key(key, extra)
}

fn serde_serialize<S: serde::ser::Serializer>(
&self,
value: &PyAny,
serializer: S,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
self.serializer
.get()
.serializer
.serde_serialize(value, serializer, include, exclude, extra)
}

fn get_name(&self) -> &str {
self.serializer.get().serializer.get_name()
}
}
7 changes: 6 additions & 1 deletion src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ mod model;
mod model_fields;
mod none;
mod nullable;
mod precompiled;
mod set;
mod string;
mod time;
Expand Down Expand Up @@ -97,7 +98,7 @@ impl PySome {
}
}

#[pyclass(module = "pydantic_core._pydantic_core")]
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaValidator {
validator: CombinedValidator,
Expand Down Expand Up @@ -542,6 +543,8 @@ pub fn build_validator<'a>(
// recursive (self-referencing) models
definitions::DefinitionRefValidator,
definitions::DefinitionsValidatorBuilder,
// precompiled models
precompiled::PrecompiledValidator,
)
}

Expand Down Expand Up @@ -690,6 +693,8 @@ pub enum CombinedValidator {
DefinitionRef(definitions::DefinitionRefValidator),
// input dependent
JsonOrPython(json_or_python::JsonOrPython),
// reusing a sub-schema
Precompiled(precompiled::PrecompiledValidator),
}

/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,
Expand Down
65 changes: 65 additions & 0 deletions src/validators/precompiled.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use pyo3::types::PyDict;
use pyo3::{intern, prelude::*};

use crate::build_tools::py_schema_err;
use crate::definitions::DefinitionsBuilder;
use crate::errors::ValResult;
use crate::input::Input;
use crate::tools::SchemaDict;
use crate::SchemaValidator;

use super::{BuildValidator, CombinedValidator, ValidationState, Validator};

/// Validator for the `"precompiled"` schema type.
///
/// Rather than building a validator of its own, it keeps a reference to a
/// `SchemaValidator` Python object and forwards validation to that object's
/// root validator.
#[derive(Debug)]
pub struct PrecompiledValidator {
    /// The `SchemaValidator` read from the schema's `"validator"` key.
    validator: Py<SchemaValidator>,
}

impl BuildValidator for PrecompiledValidator {
    const EXPECTED_TYPE: &'static str = "precompiled";

    /// Builds a [`PrecompiledValidator`] from a `"precompiled"` schema dict.
    ///
    /// Reads the `"schema"` and `"validator"` keys via `get_as_req`; the
    /// `"validator"` value must be a `SchemaValidator` instance, which is
    /// stored and reused as-is. `_config` and `_definitions` are not consulted.
    ///
    /// # Errors
    ///
    /// Propagates any error from either `get_as_req` lookup (presumably a
    /// missing key or a value of the wrong type — confirm in `SchemaDict`).
    fn build(
        schema: &PyDict,
        _config: Option<&PyDict>,
        _definitions: &mut DefinitionsBuilder<CombinedValidator>,
    ) -> PyResult<CombinedValidator> {
        let py = schema.py();
        // Looked up only so that the `"schema"` key stays mandatory; the value
        // itself is not used until the consistency check below is implemented.
        let _sub_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?;
        let validator: PyRef<SchemaValidator> = schema.get_as_req(intern!(py, "validator"))?;

        // TODO: verify that `validator` was built from `_sub_schema` and return
        // `py_schema_err!("precompiled schema mismatch")` otherwise. An identity
        // comparison (`validator.schema.is(_sub_schema)`) was tried and left
        // disabled pending debugging.

        Ok(CombinedValidator::Precompiled(PrecompiledValidator {
            validator: validator.into(),
        }))
    }
}

impl_py_gc_traverse!(PrecompiledValidator { validator });

impl Validator for PrecompiledValidator {
fn validate<'data>(
&self,
py: Python<'data>,
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
self.validator.get().validator.validate(py, input, state)
}

fn different_strict_behavior(&self, ultra_strict: bool) -> bool {
self.validator.get().validator.different_strict_behavior(ultra_strict)
}

fn get_name(&self) -> &str {
self.validator.get().validator.get_name()
}

fn complete(&self) -> PyResult<()> {
// No need to complete a precompiled validator
Ok(())
}
}
Loading

0 comments on commit 827a905

Please sign in to comment.