Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

caching strings from JSON #1240

Merged
merged 13 commits into from
Mar 27, 2024
2 changes: 2 additions & 0 deletions .mypy-stubtest-allowlist
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# TODO: don't want to expose this staticmethod, requires https://github.com/PyO3/pyo3/issues/2384
pydantic_core._pydantic_core.PydanticUndefinedType.new
# As per #1240, from_json has custom logic to coverage the `cache_strings` kwarg
pydantic_core._pydantic_core.from_json
15 changes: 2 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ base64 = "0.21.7"
num-bigint = "0.4.4"
python3-dll-a = "0.2.7"
uuid = "1.7.0"
jiter = { version = "0.1.0", features = ["python"] }
jiter = { version = "0.1.1", features = ["python"] }

[lib]
name = "_pydantic_core"
Expand Down
15 changes: 12 additions & 3 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -390,17 +390,26 @@ def to_json(
JSON bytes.
"""

def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True, cache_strings: bool = True) -> Any:
def from_json(
data: str | bytes | bytearray,
*,
allow_inf_nan: bool = True,
cache_strings: bool | Literal['all', 'keys', 'none'] = True,
allow_partial: bool = False,
) -> Any:
"""
Deserialize JSON data to a Python object.

This is effectively a faster version of `json.loads()`.
This is effectively a faster version of `json.loads()`, with some extra functionality.

Arguments:
data: The JSON data to deserialize.
allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default.
cache_strings: Whether to cache strings to avoid constructing new Python objects,
this should have a significant impact on performance while increasing memory usage slightly.
this should have a significant impact on performance while increasing memory usage slightly,
`all/True` means cache all strings, `keys` means cache only dict keys, `none/False` means no caching.
allow_partial: Whether to allow partial deserialization, if `True` JSON data is returned if the end of the
input is reached before the full object is deserialized, e.g. `["aa", "bb", "c` would return `['aa', 'bb']`.

Raises:
ValueError: If deserialization fails.
Expand Down
3 changes: 3 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class CoreConfig(TypedDict, total=False):
Requires exceptiongroup backport pre Python 3.11.
coerce_numbers_to_str: Whether to enable coercion of any `Number` type to `str` (not applicable in `strict` mode).
regex_engine: The regex engine to use for regex pattern validation. Default is 'rust-regex'. See `StringSchema`.
cache_strings: Whether to cache strings. Default is `True`, `True` or `'all'` is required to cache strings
during general validation since validators don't know if they're in a key or a value.
"""

title: str
Expand Down Expand Up @@ -110,6 +112,7 @@ class CoreConfig(TypedDict, total=False):
validation_error_cause: bool # default: False
coerce_numbers_to_str: bool # default: False
regex_engine: Literal['rust-regex', 'python-re'] # default: 'rust-regex'
cache_strings: Union[bool, Literal['all', 'keys', 'none']] # default: 'True'


IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None'
Expand Down
18 changes: 9 additions & 9 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::cmp::Ordering;
use std::ops::Rem;
use std::str::FromStr;

use jiter::{JsonArray, JsonValue};
use jiter::{JsonArray, JsonValue, StringCacheMode};
use num_bigint::BigInt;

use pyo3::exceptions::PyTypeError;
Expand Down Expand Up @@ -435,9 +435,15 @@ impl<'a> EitherString<'a> {
}
}

pub fn as_py_string(&'a self, py: Python<'a>) -> Bound<'a, PyString> {
pub fn as_py_string(&'a self, py: Python<'a>, cache_str: StringCacheMode) -> Bound<'a, PyString> {
match self {
Self::Cow(cow) => PyString::new_bound(py, cow),
Self::Cow(cow) => {
if matches!(cache_str, StringCacheMode::All) {
jiter::cached_py_string(py, cow.as_ref())
} else {
PyString::new_bound(py, cow.as_ref())
}
}
Self::Py(py_string) => py_string.clone(),
}
}
Expand All @@ -461,12 +467,6 @@ impl<'a> From<Bound<'a, PyString>> for EitherString<'a> {
}
}

impl<'a> IntoPy<PyObject> for EitherString<'a> {
fn into_py(self, py: Python<'_>) -> PyObject {
self.as_py_string(py).into_py(py)
}
}

pub fn py_string_str<'a>(py_str: &'a Bound<'_, PyString>) -> ValResult<&'a str> {
py_str.to_str().map_err(|_| {
ValError::new_custom_input(
Expand Down
19 changes: 16 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ extern crate core;

use std::sync::OnceLock;

use jiter::StringCacheMode;
use pyo3::exceptions::PyTypeError;
use pyo3::{prelude::*, sync::GILOnceCell};

Expand Down Expand Up @@ -38,19 +39,31 @@ pub use validators::{validate_core_schema, PySome, SchemaValidator};

use crate::input::Input;

#[pyfunction(signature = (data, *, allow_inf_nan=true, cache_strings=true))]
#[derive(FromPyObject)]
pub enum CacheStringsArg {
Bool(bool),
Literal(StringCacheMode),
}

#[pyfunction(signature = (data, *, allow_inf_nan=true, cache_strings=CacheStringsArg::Bool(true), allow_partial=false))]
pub fn from_json<'py>(
py: Python<'py>,
data: &Bound<'_, PyAny>,
allow_inf_nan: bool,
cache_strings: bool,
cache_strings: CacheStringsArg,
allow_partial: bool,
) -> PyResult<Bound<'py, PyAny>> {
let v_match = data
.validate_bytes(false)
.map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?;
let json_either_bytes = v_match.into_inner();
let json_bytes = json_either_bytes.as_slice();
jiter::python_parse(py, json_bytes, allow_inf_nan, cache_strings).map_err(|e| jiter::map_json_error(json_bytes, &e))
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved
let cache_mode = match cache_strings {
CacheStringsArg::Bool(b) => b.into(),
CacheStringsArg::Literal(mode) => mode,
};
jiter::python_parse(py, json_bytes, allow_inf_nan, cache_mode, allow_partial)
.map_err(|e| jiter::map_json_error(json_bytes, &e))
}

pub fn get_pydantic_core_version() -> &'static str {
Expand Down
9 changes: 6 additions & 3 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ impl BuildValidator for ArgumentsValidator {
for (arg_index, arg) in arguments_schema.iter().enumerate() {
let arg = arg.downcast::<PyDict>()?;

let name: String = arg.get_as_req(intern!(py, "name"))?;
let py_name: Bound<PyString> = arg.get_as_req(intern!(py, "name"))?;
let name = py_name.to_string();
let mode = arg.get_as::<Bound<'_, PyString>>(intern!(py, "mode"))?;
let mode = mode
.as_ref()
Expand All @@ -77,7 +78,7 @@ impl BuildValidator for ArgumentsValidator {
}
None => Some(LookupKey::from_string(py, &name)),
};
kwarg_key = Some(PyString::new_bound(py, &name).into());
kwarg_key = Some(py_name.into_py(py));
}

let schema = arg.get_as_req(intern!(py, "schema"))?;
Expand Down Expand Up @@ -274,7 +275,9 @@ impl Validator for ArgumentsValidator {
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
match self.var_kwargs_validator {
Some(ref validator) => match validator.validate(py, value.borrow_input(), state) {
Ok(value) => output_kwargs.set_item(either_str.as_py_string(py), value)?,
Ok(value) => {
output_kwargs.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(raw_key.clone()));
Expand Down
10 changes: 7 additions & 3 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,10 @@ impl Validator for DataclassArgsValidator {
if let Some(ref validator) = self.extras_validator {
match validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_dict.set_item(either_str.as_py_string(py), value)?;
output_dict.set_item(
either_str.as_py_string(py, state.cache_str()),
value,
)?;
}
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
Expand All @@ -312,7 +315,8 @@ impl Validator for DataclassArgsValidator {
Err(err) => return Err(err),
}
} else {
output_dict.set_item(either_str.as_py_string(py), value)?;
output_dict
.set_item(either_str.as_py_string(py, state.cache_str()), value)?;
}
}
}
Expand Down Expand Up @@ -455,7 +459,7 @@ impl BuildValidator for DataclassValidator {
let validator = build_validator(&sub_schema, config, definitions)?;

let post_init = if schema.get_as::<bool>(intern!(py, "post_init"))?.unwrap_or(false) {
Some(PyString::new_bound(py, "__post_init__").into())
Some(intern!(py, "__post_init__").into_py(py))
} else {
None
};
Expand Down
4 changes: 4 additions & 0 deletions src/validators/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ pub struct InternalValidator {
validation_mode: InputType,
hide_input_in_errors: bool,
validation_error_cause: bool,
cache_str: jiter::StringCacheMode,
}

impl fmt::Debug for InternalValidator {
Expand Down Expand Up @@ -250,6 +251,7 @@ impl InternalValidator {
validation_mode: extra.input_type,
hide_input_in_errors,
validation_error_cause,
cache_str: extra.cache_str,
}
}

Expand All @@ -268,6 +270,7 @@ impl InternalValidator {
from_attributes: self.from_attributes,
context: self.context.as_ref().map(|data| data.bind(py)),
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
cache_str: self.cache_str,
};
let mut state = ValidationState::new(extra, &mut self.recursion_guard);
state.exactness = self.exactness;
Expand Down Expand Up @@ -302,6 +305,7 @@ impl InternalValidator {
from_attributes: self.from_attributes,
context: self.context.as_ref().map(|data| data.bind(py)),
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
cache_str: self.cache_str,
};
let mut state = ValidationState::new(extra, &mut self.recursion_guard);
state.exactness = self.exactness;
Expand Down
4 changes: 2 additions & 2 deletions src/validators/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ impl Validator for JsonValidator {
validator.validate(py, &json_value, &mut json_state)
}
None => {
let obj =
jiter::python_parse(py, json_bytes, true, true).map_err(|e| map_json_err(input, e, json_bytes))?;
let obj = jiter::python_parse(py, json_bytes, true, state.cache_str(), false)
.map_err(|e| map_json_err(input, e, json_bytes))?;
Ok(obj.unbind())
}
}
Expand Down
Loading
Loading