Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

caching strings from JSON #1240

Merged
merged 13 commits into from
Mar 27, 2024
2 changes: 2 additions & 0 deletions .mypy-stubtest-allowlist
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# TODO: don't want to expose this staticmethod, requires https://github.com/PyO3/pyo3/issues/2384
pydantic_core._pydantic_core.PydanticUndefinedType.new
# As per #1240, from_json has custom logic to coverage the `cache_strings` kwarg
pydantic_core._pydantic_core.from_json
15 changes: 2 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ base64 = "0.21.7"
num-bigint = "0.4.4"
python3-dll-a = "0.2.7"
uuid = "1.7.0"
jiter = { version = "0.1.0", features = ["python"] }
jiter = { version = "0.1.1", features = ["python"] }

[lib]
name = "_pydantic_core"
Expand Down
10 changes: 8 additions & 2 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,12 @@ def to_json(
JSON bytes.
"""

def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True, cache_strings: bool = True) -> Any:
def from_json(
data: str | bytes | bytearray,
*,
allow_inf_nan: bool = True,
cache_strings: bool | Literal['all', 'keys', 'none'] = True,
) -> Any:
"""
Deserialize JSON data to a Python object.

Expand All @@ -400,7 +405,8 @@ def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True, cach
data: The JSON data to deserialize.
allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default.
cache_strings: Whether to cache strings to avoid constructing new Python objects,
this should have a significant impact on performance while increasing memory usage slightly.
this should have a significant impact on performance while increasing memory usage slightly,
`all/True` means cache all strings, `keys` means cache only dict keys, `none/False` means no caching.

Raises:
ValueError: If deserialization fails.
Expand Down
18 changes: 9 additions & 9 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::cmp::Ordering;
use std::ops::Rem;
use std::str::FromStr;

use jiter::{JsonArray, JsonValue};
use jiter::{JsonArray, JsonValue, StringCacheMode};
use num_bigint::BigInt;

use pyo3::exceptions::PyTypeError;
Expand Down Expand Up @@ -435,9 +435,15 @@ impl<'a> EitherString<'a> {
}
}

pub fn as_py_string(&'a self, py: Python<'a>) -> Bound<'a, PyString> {
pub fn as_py_string(&'a self, py: Python<'a>, cache_str: StringCacheMode) -> Bound<'a, PyString> {
match self {
Self::Cow(cow) => PyString::new_bound(py, cow),
Self::Cow(cow) => {
if matches!(cache_str, StringCacheMode::All) {
jiter::cached_py_string(py, cow.as_ref())
} else {
PyString::new_bound(py, cow.as_ref())
}
}
Self::Py(py_string) => py_string.clone(),
}
}
Expand All @@ -461,12 +467,6 @@ impl<'a> From<Bound<'a, PyString>> for EitherString<'a> {
}
}

impl<'a> IntoPy<PyObject> for EitherString<'a> {
fn into_py(self, py: Python<'_>) -> PyObject {
self.as_py_string(py).into_py(py)
}
}

pub fn py_string_str<'a>(py_str: &'a Bound<'_, PyString>) -> ValResult<&'a str> {
py_str.to_str().map_err(|_| {
ValError::new_custom_input(
Expand Down
18 changes: 15 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ extern crate core;

use std::sync::OnceLock;

use jiter::StringCacheMode;
use pyo3::exceptions::PyTypeError;
use pyo3::{prelude::*, sync::GILOnceCell};

Expand Down Expand Up @@ -38,19 +39,30 @@ pub use validators::{validate_core_schema, PySome, SchemaValidator};

use crate::input::Input;

#[pyfunction(signature = (data, *, allow_inf_nan=true, cache_strings=true))]
#[derive(FromPyObject)]
pub enum CacheStringsArg {
Bool(bool),
Literal(StringCacheMode),
}

#[pyfunction(signature = (data, *, allow_inf_nan=true, cache_strings=CacheStringsArg::Bool(true)))]
pub fn from_json<'py>(
py: Python<'py>,
data: &Bound<'_, PyAny>,
allow_inf_nan: bool,
cache_strings: bool,
cache_strings: CacheStringsArg,
) -> PyResult<Bound<'py, PyAny>> {
let v_match = data
.validate_bytes(false)
.map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?;
let json_either_bytes = v_match.into_inner();
let json_bytes = json_either_bytes.as_slice();
jiter::python_parse(py, json_bytes, allow_inf_nan, cache_strings).map_err(|e| jiter::map_json_error(json_bytes, &e))
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved
let cache_mode = match cache_strings {
CacheStringsArg::Bool(b) => b.into(),
CacheStringsArg::Literal(mode) => mode,
};
jiter::python_parse(py, json_bytes, allow_inf_nan, cache_mode, false)
.map_err(|e| jiter::map_json_error(json_bytes, &e))
}

pub fn get_pydantic_core_version() -> &'static str {
Expand Down
9 changes: 6 additions & 3 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ impl BuildValidator for ArgumentsValidator {
for (arg_index, arg) in arguments_schema.iter().enumerate() {
let arg = arg.downcast::<PyDict>()?;

let name: String = arg.get_as_req(intern!(py, "name"))?;
let py_name: Bound<PyString> = arg.get_as_req(intern!(py, "name"))?;
let name = py_name.to_string();
let mode = arg.get_as::<Bound<'_, PyString>>(intern!(py, "mode"))?;
let mode = mode
.as_ref()
Expand All @@ -77,7 +78,7 @@ impl BuildValidator for ArgumentsValidator {
}
None => Some(LookupKey::from_string(py, &name)),
};
kwarg_key = Some(PyString::new_bound(py, &name).into());
kwarg_key = Some(py_name.into_py(py));
}

let schema = arg.get_as_req(intern!(py, "schema"))?;
Expand Down Expand Up @@ -274,7 +275,9 @@ impl Validator for ArgumentsValidator {
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
match self.var_kwargs_validator {
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => output_kwargs.set_item(either_str.as_py_string(py), value)?,
Ok(value) => {
output_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
Expand Down
10 changes: 7 additions & 3 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,10 @@ impl Validator for DataclassArgsValidator {
if let Some(ref validator) = self.extras_validator {
match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_dict.set_item(either_str.as_py_string(py), value)?;
output_dict.set_item(
either_str.as_py_string(py, state.cache_str()),
value,
)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
Expand All @@ -312,7 +315,8 @@ impl Validator for DataclassArgsValidator {
Err(err) => return Err(err),
}
} else {
output_dict.set_item(either_str.as_py_string(py), value)?;
output_dict
.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
}
}
Expand Down Expand Up @@ -455,7 +459,7 @@ impl BuildValidator for DataclassValidator {
let validator = build_validator(&sub_schema, config, definitions)?;

let post_init = if schema.get_as::<bool>(intern!(py, "post_init"))?.unwrap_or(false) {
Some(PyString::new_bound(py, "__post_init__").into())
Some(intern!(py, "__post_init__").into_py(py))
} else {
None
};
Expand Down
4 changes: 4 additions & 0 deletions src/validators/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ pub struct InternalValidator {
validation_mode: InputType,
hide_input_in_errors: bool,
validation_error_cause: bool,
cache_str: jiter::StringCacheMode,
}

impl fmt::Debug for InternalValidator {
Expand Down Expand Up @@ -250,6 +251,7 @@ impl InternalValidator {
validation_mode: extra.input_type,
hide_input_in_errors,
validation_error_cause,
cache_str: extra.cache_str,
}
}

Expand All @@ -268,6 +270,7 @@ impl InternalValidator {
from_attributes: self.from_attributes,
context: self.context.as_ref().map(|data| data.bind(py)),
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
cache_str: self.cache_str,
};
let mut state = ValidationState::new(extra, &mut self.recursion_guard);
state.exactness = self.exactness;
Expand Down Expand Up @@ -302,6 +305,7 @@ impl InternalValidator {
from_attributes: self.from_attributes,
context: self.context.as_ref().map(|data| data.bind(py)),
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
cache_str: self.cache_str,
};
let mut state = ValidationState::new(extra, &mut self.recursion_guard);
state.exactness = self.exactness;
Expand Down
4 changes: 2 additions & 2 deletions src/validators/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ impl Validator for JsonValidator {
validator.validate(py, &json_value, &mut json_state)
}
None => {
let obj =
jiter::python_parse(py, json_bytes, true, true).map_err(|e| map_json_err(input, e, json_bytes))?;
let obj = jiter::python_parse(py, json_bytes, true, true.into(), false)
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved
.map_err(|e| map_json_err(input, e, json_bytes))?;
Ok(obj.unbind())
}
}
Expand Down
32 changes: 29 additions & 3 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt::Debug;

use enum_dispatch::enum_dispatch;
use jiter::StringCacheMode;

use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
Expand Down Expand Up @@ -110,6 +111,7 @@ pub struct SchemaValidator {
title: PyObject,
hide_input_in_errors: bool,
validation_error_cause: bool,
cache_str: StringCacheMode,
}

#[pymethods]
Expand All @@ -135,6 +137,9 @@ impl SchemaValidator {
};
let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false);
let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false);
let cache_str: StringCacheMode = config
.get_as(intern!(py, "cache_strings"))?
.unwrap_or(StringCacheMode::All);
Ok(Self {
validator,
definitions,
Expand All @@ -143,6 +148,7 @@ impl SchemaValidator {
title,
hide_input_in_errors,
validation_error_cause,
cache_str,
})
}

Expand Down Expand Up @@ -262,6 +268,7 @@ impl SchemaValidator {
from_attributes,
context,
self_instance: None,
cache_str: self.cache_str,
};

let guard = &mut RecursionState::default();
Expand All @@ -285,6 +292,7 @@ impl SchemaValidator {
from_attributes: None,
context,
self_instance: None,
cache_str: self.cache_str,
};
let recursion_guard = &mut RecursionState::default();
let mut state = ValidationState::new(extra, recursion_guard);
Expand All @@ -300,10 +308,15 @@ impl SchemaValidator {

pub fn __repr__(&self, py: Python) -> String {
format!(
"SchemaValidator(title={:?}, validator={:#?}, definitions={:#?})",
"SchemaValidator(title={:?}, validator={:#?}, definitions={:#?}, cache_strings={})",
self.title.extract::<&str>(py).unwrap(),
self.validator,
self.definitions,
match self.cache_str {
StringCacheMode::All => "True",
StringCacheMode::Keys => "'keys'",
StringCacheMode::None => "False",
}
)
}

Expand Down Expand Up @@ -331,7 +344,14 @@ impl SchemaValidator {
) -> ValResult<PyObject> {
let mut recursion_guard = RecursionState::default();
let mut state = ValidationState::new(
Extra::new(strict, from_attributes, context, self_instance, input_type),
Extra::new(
strict,
from_attributes,
context,
self_instance,
input_type,
self.cache_str,
),
&mut recursion_guard,
);
self.validator.validate(py, input, &mut state)
Expand Down Expand Up @@ -384,7 +404,7 @@ impl<'py> SelfValidator<'py> {
let py = schema.py();
let mut recursion_guard = RecursionState::default();
let mut state = ValidationState::new(
Extra::new(strict, None, None, None, InputType::Python),
Extra::new(strict, None, None, None, InputType::Python, true.into()),
&mut recursion_guard,
);
match self.validator.validator.validate(py, schema, &mut state) {
Expand Down Expand Up @@ -414,6 +434,7 @@ impl<'py> SelfValidator<'py> {
title: "Self Schema".into_py(py),
hide_input_in_errors: false,
validation_error_cause: false,
cache_str: true.into(),
})
}
}
Expand Down Expand Up @@ -577,6 +598,8 @@ pub struct Extra<'a, 'py> {
pub context: Option<&'a Bound<'py, PyAny>>,
/// This is an instance of the model or dataclass being validated, when validation is performed from `__init__`
self_instance: Option<&'a Bound<'py, PyAny>>,
/// Whether to use a cache of short strings to accelerate python string construction
cache_str: StringCacheMode,
}

impl<'a, 'py> Extra<'a, 'py> {
Expand All @@ -586,6 +609,7 @@ impl<'a, 'py> Extra<'a, 'py> {
context: Option<&'a Bound<'py, PyAny>>,
self_instance: Option<&'a Bound<'py, PyAny>>,
input_type: InputType,
cache_str: StringCacheMode,
) -> Self {
Extra {
input_type,
Expand All @@ -594,6 +618,7 @@ impl<'a, 'py> Extra<'a, 'py> {
from_attributes,
context,
self_instance,
cache_str,
}
}
}
Expand All @@ -607,6 +632,7 @@ impl Extra<'_, '_> {
from_attributes: self.from_attributes,
context: self.context,
self_instance: self.self_instance,
cache_str: self.cache_str,
}
}
}
Expand Down
Loading
Loading