Skip to content

Commit

Permalink
support newest jiter behaviour (#1092)
Browse files Browse the repository at this point in the history
Co-authored-by: David Hewitt <mail@davidhewitt.dev>
  • Loading branch information
samuelcolvin and davidhewitt authored Nov 27, 2023
1 parent 5b63e7a commit c7daf16
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 113 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ base64 = "0.21.5"
num-bigint = "0.4.4"
python3-dll-a = "0.2.7"
uuid = "1.5.0"
jiter = {version = "0.0.4", features = ["python"]}
#jiter = {path = "../jiter", features = ["python"]}
jiter = {version = "0.0.5", features = ["python"]}

[lib]
name = "_pydantic_core"
Expand Down
4 changes: 3 additions & 1 deletion python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def to_json(
JSON bytes.
"""

def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> Any:
def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True, cache_strings: bool = True) -> Any:
"""
Deserialize JSON data to a Python object.
Expand All @@ -394,6 +394,8 @@ def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> A
Arguments:
data: The JSON data to deserialize.
allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default.
cache_strings: Whether to cache strings to avoid constructing new Python objects.
Caching should significantly improve performance while increasing memory usage slightly.
Raises:
ValueError: If deserialization fails.
Expand Down
4 changes: 0 additions & 4 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ use pyo3::exceptions::PyValueError;
use pyo3::types::{PyDict, PyType};
use pyo3::{intern, prelude::*};

use jiter::JsonValue;

use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, ValError, ValResult};
use crate::tools::py_err;
use crate::{PyMultiHostUrl, PyUrl};
Expand Down Expand Up @@ -89,8 +87,6 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized {

fn validate_dataclass_args(&'a self, dataclass_name: &str) -> ValResult<GenericArguments<'a>>;

fn parse_json(&'a self) -> ValResult<JsonValue>;

fn validate_str(
&'a self,
strict: bool,
Expand Down
13 changes: 1 addition & 12 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::datetime::{
float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime,
};
use super::return_enums::ValidationMatch;
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_float, str_as_int};
use super::{
BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
GenericIterator, GenericMapping, Input, JsonArgs,
Expand Down Expand Up @@ -84,13 +84,6 @@ impl<'a> Input<'a> for JsonValue {
}
}

/// Parse the JSON text held in a `JsonValue::Str` into a fresh `JsonValue`.
///
/// Any other variant is rejected with a `JsonType` error. Parse failures are
/// converted into a `ValError` via `map_json_err`.
fn parse_json(&'a self) -> ValResult<JsonValue> {
    if let JsonValue::Str(json_text) = self {
        // NOTE(review): the `true` flag is presumably `allow_inf_nan` - confirm against jiter.
        let parsed = JsonValue::parse(json_text.as_bytes(), true);
        parsed.map_err(|parse_err| map_json_err(self, parse_err))
    } else {
        Err(ValError::new(ErrorTypeDefaults::JsonType, self))
    }
}

fn exact_str(&'a self) -> ValResult<EitherString<'a>> {
match self {
JsonValue::Str(s) => Ok(s.as_str().into()),
Expand Down Expand Up @@ -367,10 +360,6 @@ impl<'a> Input<'a> for String {
))
}

/// Parse this string's bytes as JSON.
///
/// Parse failures are converted into a `ValError` via `map_json_err`.
fn parse_json(&'a self) -> ValResult<JsonValue> {
    let raw_bytes = self.as_bytes();
    // NOTE(review): the `true` flag is presumably `allow_inf_nan` - confirm against jiter.
    JsonValue::parse(raw_bytes, true).map_err(|parse_err| map_json_err(self, parse_err))
}

fn validate_str(
&'a self,
_strict: bool,
Expand Down
20 changes: 1 addition & 19 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use pyo3::types::{
use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};
use pyo3::{intern, PyTypeInfo};

use jiter::JsonValue;
use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
Expand All @@ -26,8 +25,7 @@ use super::datetime::{
};
use super::return_enums::ValidationMatch;
use super::shared::{
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float,
str_as_int,
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int,
};
use super::{
py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments,
Expand Down Expand Up @@ -195,22 +193,6 @@ impl<'a> Input<'a> for PyAny {
}
}

/// Parse JSON from a Python object into a `JsonValue`.
///
/// The input is tried, in order, as `PyBytes`, `PyString` and `PyByteArray`;
/// anything else is rejected with a `JsonType` error. Parse failures are
/// converted into a `ValError` via `map_json_err`.
fn parse_json(&'a self) -> ValResult<JsonValue> {
// Borrow the raw bytes from whichever supported Python type was passed in.
let bytes = if let Ok(py_bytes) = self.downcast::<PyBytes>() {
py_bytes.as_bytes()
} else if let Ok(py_str) = self.downcast::<PyString>() {
// Extracting the string data can fail, hence the `?`.
let str = py_string_str(py_str)?;
str.as_bytes()
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
// SAFETY: per the original author's note, parsing does not run arbitrary Python code and
// the GIL is held, so the bytearray will not be mutated while `JsonValue::parse` is reading it.
// NOTE(review): the original wording named `from_slice` rather than `JsonValue::parse` - confirm
// which function the guarantee was meant to cover.
unsafe { py_byte_array.as_bytes() }
} else {
return Err(ValError::new(ErrorTypeDefaults::JsonType, self));
};
// NOTE(review): the `true` flag is presumably `allow_inf_nan` - confirm against jiter.
JsonValue::parse(bytes, true).map_err(|e| map_json_err(self, e))
}

fn validate_str(
&'a self,
strict: bool,
Expand Down
13 changes: 1 addition & 12 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString};

use jiter::JsonValue;
use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
Expand All @@ -12,7 +11,7 @@ use crate::validators::decimal::create_decimal;
use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
};
use super::shared::{map_json_err, str_as_bool, str_as_float};
use super::shared::{str_as_bool, str_as_float};
use super::{
BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
GenericIterator, GenericMapping, Input, ValidationMatch,
Expand Down Expand Up @@ -86,16 +85,6 @@ impl<'a> Input<'a> for StringMapping<'a> {
}
}

/// Parse the string variant of a `StringMapping` as JSON.
///
/// The `Mapping` variant is rejected with a `JsonType` error. Parse failures
/// are converted into a `ValError` via `map_json_err`.
fn parse_json(&'a self) -> ValResult<JsonValue> {
    match self {
        Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::JsonType, self)),
        Self::String(py_str) => {
            // Extracting the string data can fail, hence the `?`.
            let text = py_string_str(py_str)?;
            // NOTE(review): the `true` flag is presumably `allow_inf_nan` - confirm against jiter.
            let parsed = JsonValue::parse(text.as_bytes(), true);
            parsed.map_err(|parse_err| map_json_err(self, parse_err))
        }
    }
}

fn validate_str(
&'a self,
_strict: bool,
Expand Down
13 changes: 1 addition & 12 deletions src/input/shared.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use pyo3::sync::GILOnceCell;
use pyo3::{intern, Py, PyAny, Python, ToPyObject};

use jiter::JsonValueError;
use num_bigint::BigInt;

use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};
use crate::errors::{ErrorTypeDefaults, ValError, ValResult};

use super::{EitherFloat, EitherInt, Input};
static ENUM_META_OBJECT: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
Expand All @@ -20,16 +19,6 @@ pub fn get_enum_meta_object(py: Python) -> Py<PyAny> {
.clone()
}

/// Wrap a jiter `JsonValueError` in a `JsonInvalid` validation error for `input`.
///
/// The error's `Display` text becomes the `error` field; no extra context is attached.
pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: JsonValueError) -> ValError {
    let message = error.to_string();
    let error_type = ErrorType::JsonInvalid {
        error: message,
        context: None,
    };
    ValError::new(error_type, input)
}

pub fn str_as_bool<'a>(input: &'a impl Input<'a>, str: &str) -> ValResult<bool> {
if str == "0"
|| str.eq_ignore_ascii_case("f")
Expand Down
22 changes: 10 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ extern crate core;
use std::sync::OnceLock;

use pyo3::exceptions::PyTypeError;
use pyo3::types::{PyByteArray, PyBytes, PyString};
use pyo3::{prelude::*, sync::GILOnceCell};

// parse this first to get access to the contained macro
Expand Down Expand Up @@ -37,17 +36,16 @@ pub use serializers::{
};
pub use validators::{validate_core_schema, PySome, SchemaValidator};

#[pyfunction(signature = (data, *, allow_inf_nan=true))]
pub fn from_json(py: Python, data: &PyAny, allow_inf_nan: bool) -> PyResult<PyObject> {
if let Ok(py_bytes) = data.downcast::<PyBytes>() {
jiter::python_parse(py, py_bytes.as_bytes(), allow_inf_nan)
} else if let Ok(py_str) = data.downcast::<PyString>() {
jiter::python_parse(py, py_str.to_str()?.as_bytes(), allow_inf_nan)
} else if let Ok(py_byte_array) = data.downcast::<PyByteArray>() {
jiter::python_parse(py, &py_byte_array.to_vec(), allow_inf_nan)
} else {
Err(PyTypeError::new_err("Expected bytes, bytearray or str"))
}
use crate::input::Input;

// Deserialize JSON data into a Python object using jiter.
//
// `data` is first passed through `validate_bytes`; if that fails, a `TypeError`
// with the message "Expected bytes, bytearray or str" is raised. `allow_inf_nan`
// and `cache_strings` are forwarded unchanged to `jiter::python_parse`, and parse
// failures are converted with `jiter::map_json_error`.
//
// Plain `//` comments are used on purpose: PyO3 turns `///` doc comments on a
// `#[pyfunction]` into the Python-visible docstring.
#[pyfunction(signature = (data, *, allow_inf_nan=true, cache_strings=true))]
pub fn from_json(py: Python, data: &PyAny, allow_inf_nan: bool, cache_strings: bool) -> PyResult<PyObject> {
    // NOTE(review): the `false` argument is presumably the `strict` flag - confirm against `Input::validate_bytes`.
    let json_either_bytes = data
        .validate_bytes(false)
        .map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?
        .into_inner();
    // `json_either_bytes` must stay alive for as long as the borrowed slice is in use.
    let json_bytes = json_either_bytes.as_slice();
    let parsed = jiter::python_parse(py, json_bytes, allow_inf_nan, cache_strings);
    parsed.map_err(|parse_err| jiter::map_json_error(json_bytes, &parse_err))
}

pub fn get_pydantic_core_version() -> &'static str {
Expand Down
53 changes: 45 additions & 8 deletions src/validators/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::errors::ValResult;
use crate::input::Input;
use jiter::JsonValue;

use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::{EitherBytes, Input, ValidationMatch};
use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
Expand Down Expand Up @@ -50,17 +52,52 @@ impl Validator for JsonValidator {
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<PyObject> {
let json_value = input.parse_json()?;
let v_match = validate_json_bytes(input)?;
let json_either_bytes = v_match.unpack(state);
let json_bytes = json_either_bytes.as_slice();
match self.validator {
Some(ref validator) => match validator.validate(py, &json_value, state) {
Ok(v) => Ok(v),
Err(err) => Err(err),
},
None => Ok(json_value.to_object(py)),
Some(ref validator) => {
let json_value = JsonValue::parse(json_bytes, true).map_err(|e| map_json_err(input, e, json_bytes))?;
validator.validate(py, &json_value, state)
}
None => {
let obj =
jiter::python_parse(py, json_bytes, true, true).map_err(|e| map_json_err(input, e, json_bytes))?;
Ok(obj)
}
}
}

fn get_name(&self) -> &str {
&self.name
}
}

/// Validate `input` as bytes, for use as raw JSON data.
///
/// Successful matches are returned untouched. When validation fails with line
/// errors, each one is passed through `map_bytes_error`; any other error kind
/// is returned unchanged.
pub fn validate_json_bytes<'data>(input: &'data impl Input<'data>) -> ValResult<ValidationMatch<EitherBytes<'data>>> {
    // NOTE(review): the `false` argument is presumably the `strict` flag - confirm against `Input::validate_bytes`.
    input.validate_bytes(false).map_err(|err| match err {
        ValError::LineErrors(line_errors) => {
            ValError::LineErrors(line_errors.into_iter().map(map_bytes_error).collect::<Vec<_>>())
        }
        other => other,
    })
}

/// Replace a `BytesType` line error with a `JsonType` error on the same input value.
///
/// Line errors of any other type are returned unchanged.
fn map_bytes_error(line_error: ValLineError) -> ValLineError {
    if matches!(line_error.error_type, ErrorType::BytesType { .. }) {
        // Build a fresh error that keeps only the offending input value.
        ValLineError::new_custom_input(ErrorTypeDefaults::JsonType, line_error.input_value)
    } else {
        line_error
    }
}

/// Wrap a `jiter::JsonError` in a `JsonInvalid` validation error for `input`.
///
/// `json_bytes` is the data that was being parsed; it is handed to
/// `JsonError::description` to build the error message. No extra context is attached.
pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: jiter::JsonError, json_bytes: &[u8]) -> ValError {
    let description = error.description(json_bytes);
    let error_type = ErrorType::JsonInvalid {
        error: description,
        context: None,
    };
    ValError::new(error_type, input)
}
51 changes: 29 additions & 22 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ impl SchemaValidator {
from_attributes,
context,
self_instance,
&mut RecursionGuard::default(),
)
.map_err(|e| self.prepare_validation_err(py, e, InputType::Python))
}
Expand All @@ -194,7 +193,6 @@ impl SchemaValidator {
from_attributes,
context,
self_instance,
&mut RecursionGuard::default(),
) {
Ok(_) => Ok(true),
Err(ValError::InternalErr(err)) => Err(err),
Expand All @@ -213,22 +211,18 @@ impl SchemaValidator {
context: Option<&PyAny>,
self_instance: Option<&PyAny>,
) -> PyResult<PyObject> {
let recursion_guard = &mut RecursionGuard::default();
match input.parse_json() {
Ok(input) => self
._validate(
py,
&input,
InputType::Json,
strict,
None,
context,
self_instance,
recursion_guard,
)
.map_err(|e| self.prepare_validation_err(py, e, InputType::Json)),
Err(err) => Err(self.prepare_validation_err(py, err, InputType::Json)),
}
let r = match json::validate_json_bytes(input) {
Ok(v_match) => self._validate_json(
py,
input,
v_match.into_inner().as_slice(),
strict,
context,
self_instance,
),
Err(err) => Err(err),
};
r.map_err(|e| self.prepare_validation_err(py, e, InputType::Json))
}

#[pyo3(signature = (input, *, strict=None, context=None))]
Expand All @@ -242,8 +236,7 @@ impl SchemaValidator {
let t = InputType::String;
let string_mapping = StringMapping::new_value(input).map_err(|e| self.prepare_validation_err(py, e, t))?;

let recursion_guard = &mut RecursionGuard::default();
match self._validate(py, &string_mapping, t, strict, None, context, None, recursion_guard) {
match self._validate(py, &string_mapping, t, strict, None, context, None) {
Ok(r) => Ok(r),
Err(e) => Err(self.prepare_validation_err(py, e, t)),
}
Expand Down Expand Up @@ -329,18 +322,32 @@ impl SchemaValidator {
from_attributes: Option<bool>,
context: Option<&'data PyAny>,
self_instance: Option<&PyAny>,
recursion_guard: &'data mut RecursionGuard,
) -> ValResult<PyObject>
where
's: 'data,
{
let mut recursion_guard = RecursionGuard::default();
let mut state = ValidationState::new(
Extra::new(strict, from_attributes, context, self_instance, input_type),
recursion_guard,
&mut recursion_guard,
);
self.validator.validate(py, input, &mut state)
}

// Parse `json_data` into a `JsonValue` and validate it as JSON input.
//
// `input` is the original Python object; it is only used to build the error
// (via `json::map_json_err`) when parsing fails. The parsed value is then
// passed to `_validate` with `InputType::Json` and no `from_attributes` override.
fn _validate_json(
    &self,
    py: Python,
    input: &PyAny,
    json_data: &[u8],
    strict: Option<bool>,
    context: Option<&PyAny>,
    self_instance: Option<&PyAny>,
) -> ValResult<PyObject> {
    // NOTE(review): the `true` flag is presumably `allow_inf_nan` - confirm against jiter.
    let parsed = jiter::JsonValue::parse(json_data, true);
    let json_value = parsed.map_err(|parse_err| json::map_json_err(input, parse_err, json_data))?;
    self._validate(py, &json_value, InputType::Json, strict, None, context, self_instance)
}

fn prepare_validation_err(&self, py: Python, error: ValError, input_type: InputType) -> PyErr {
ValidationError::from_val_error(
py,
Expand Down
Loading

0 comments on commit c7daf16

Please sign in to comment.