From 1f18da208721d993d344934adc4bcc5bbfadb17f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 6 Nov 2023 11:32:53 +0000 Subject: [PATCH] jiter (#974) --- Cargo.lock | 85 +++++++++ Cargo.toml | 2 + python/pydantic_core/__init__.py | 2 + python/pydantic_core/_pydantic_core.pyi | 18 ++ src/errors/line_error.rs | 6 +- src/errors/location.rs | 28 +-- src/errors/mod.rs | 2 +- src/input/input_abstract.rs | 12 +- src/input/input_json.rs | 167 +++++++++--------- src/input/input_python.rs | 59 ++++--- src/input/input_string.rs | 13 +- src/input/mod.rs | 2 - src/input/parse_json.rs | 222 ------------------------ src/input/return_enums.rs | 16 +- src/input/shared.rs | 12 +- src/lazy_index_map.rs | 63 ------- src/lib.rs | 17 +- src/lookup_key.rs | 18 +- src/validators/arguments.rs | 2 +- src/validators/dataclass.rs | 2 +- src/validators/dict.rs | 2 +- src/validators/function.rs | 22 +-- src/validators/model_fields.rs | 2 +- src/validators/typed_dict.rs | 2 +- src/validators/union.rs | 4 +- tests/test_json.py | 9 +- tests/validators/test_decimal.py | 7 +- tests/validators/test_float.py | 37 +++- tests/validators/test_function.py | 15 +- tests/validators/test_int.py | 26 ++- 30 files changed, 385 insertions(+), 489 deletions(-) delete mode 100644 src/input/parse_json.rs delete mode 100644 src/lazy_index_map.rs diff --git a/Cargo.lock b/Cargo.lock index 14804774d..f16b93697 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,6 +136,84 @@ version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a" +[[package]] +name = "jiter" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b27d419c535bf7b50ad355278b1159cbf0cc8d507ea003d625b17bf0375720b8" +dependencies = [ + "ahash", + "lexical-core", + "num-bigint", + "num-traits", + "pyo3", + "smallvec", +] + +[[package]] +name = "lexical-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +dependencies = [ + "lexical-parse-integer", + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-parse-integer" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-util" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +dependencies = [ + "static_assertions", +] + +[[package]] +name = "lexical-write-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +dependencies = [ + "lexical-util", + "lexical-write-integer", + "static_assertions", +] + +[[package]] +name = "lexical-write-integer" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +dependencies = [ + "lexical-util", + "static_assertions", +] + [[package]] name = "libc" version = "0.2.147" @@ -249,6 +327,7 @@ dependencies = [ "base64", "enum_dispatch", "idna", + "jiter", "num-bigint", "pyo3", "pyo3-build-config", @@ -450,6 +529,12 @@ dependencies = [ "strum_macros", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "strum" version = "0.25.0" diff --git a/Cargo.toml b/Cargo.toml index cdd1b7056..211fdfe01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,8 @@ base64 = "0.21.5" num-bigint = "0.4.4" python3-dll-a = "0.2.7" uuid = "1.5.0" +jiter = {version = "0.0.4", features = ["python"]} +#jiter = {path = "../jiter", features = ["python"]} [lib] name = "_pydantic_core" diff --git a/python/pydantic_core/__init__.py b/python/pydantic_core/__init__.py index a46a77b7d..5b2655c91 100644 --- a/python/pydantic_core/__init__.py +++ b/python/pydantic_core/__init__.py @@ -22,6 +22,7 @@ Url, ValidationError, __version__, + from_json, to_json, to_jsonable_python, validate_core_schema, @@ -63,6 +64,7 @@ 'PydanticSerializationUnexpectedValue', 'TzInfo', 'to_json', + 'from_json', 'to_jsonable_python', 'validate_core_schema', ] diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 8ed3092a9..f28b7a12a 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -41,6 +41,7 @@ __all__ = [ 'PydanticUndefinedType', 'Some', 'to_json', + 'from_json', 'to_jsonable_python', 'list_all_errors', 'TzInfo', @@ -384,6 +385,23 @@ def to_json( JSON bytes. """ +def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> Any: + """ + Deserialize JSON data to a Python object. + + This is effectively a faster version of [`json.loads()`][json.loads]. + + Arguments: + data: The JSON data to deserialize. + allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default. + + Raises: + ValueError: If deserialization fails. + + Returns: + The deserialized Python object. + """ + def to_jsonable_python( value: Any, *, diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index e5d3c7bac..3ee4c7894 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -2,7 +2,9 @@ use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::PyDowncastError; -use crate::input::{Input, JsonInput}; +use jiter::JsonValue; + +use crate::input::Input; use super::location::{LocItem, Location}; use super::types::ErrorType; @@ -147,7 +149,7 @@ impl<'a> ValLineError<'a> { #[derive(Clone)] pub enum InputValue<'a> { PyAny(&'a PyAny), - JsonInput(JsonInput), + JsonInput(JsonValue), String(&'a str), } diff --git a/src/errors/location.rs b/src/errors/location.rs index e5c32d5e2..8acc2a039 100644 --- a/src/errors/location.rs +++ b/src/errors/location.rs @@ -3,12 +3,11 @@ use pyo3::once_cell::GILOnceCell; use std::fmt; use pyo3::prelude::*; -use pyo3::types::{PyList, PyString, PyTuple}; +use pyo3::types::{PyList, PyTuple}; use serde::ser::SerializeSeq; use serde::{Serialize, Serializer}; use crate::lookup_key::{LookupPath, PathItem}; -use crate::tools::extract_i64; /// Used to store individual items of the error location, e.g. a string for key/field names /// or a number for array indices. @@ -35,6 +34,12 @@ impl fmt::Display for LocItem { } } +// TODO rename to ToLocItem +pub trait AsLocItem { + // TODO rename to to_loc_item + fn as_loc_item(&self) -> LocItem; +} + impl From for LocItem { fn from(s: String) -> Self { Self::S(s) @@ -82,21 +87,6 @@ impl ToPyObject for LocItem { } } -impl TryFrom<&PyAny> for LocItem { - type Error = PyErr; - - fn try_from(loc_item: &PyAny) -> PyResult { - if let Ok(py_str) = loc_item.downcast::() { - let str = py_str.to_str()?.to_string(); - Ok(Self::S(str)) - } else if let Ok(int) = extract_i64(loc_item) { - Ok(Self::I(int)) - } else { - Err(PyTypeError::new_err("Item in a location must be a string or int")) - } - } -} - impl Serialize for LocItem { fn serialize(&self, serializer: S) -> Result where @@ -211,9 +201,9 @@ impl TryFrom> for Location { fn try_from(location: Option<&PyAny>) -> PyResult { if let Some(location) = location { let mut loc_vec: Vec = if let Ok(tuple) = location.downcast::() { - tuple.iter().map(LocItem::try_from).collect::>()? + tuple.iter().map(AsLocItem::as_loc_item).collect() } else if let Ok(list) = location.downcast::() { - list.iter().map(LocItem::try_from).collect::>()? + list.iter().map(AsLocItem::as_loc_item).collect() } else { return Err(PyTypeError::new_err( "Location must be a list or tuple of strings and ints", diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 6a253197f..bfc5b4329 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -7,7 +7,7 @@ mod validation_exception; mod value_exception; pub use self::line_error::{InputValue, ValError, ValLineError, ValResult}; -pub use self::location::LocItem; +pub use self::location::{AsLocItem, LocItem}; pub use self::types::{list_all_errors, ErrorType, ErrorTypeDefaults, Number}; pub use self::validation_exception::ValidationError; pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault}; diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 655ba24b9..52551ef42 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -4,13 +4,15 @@ use pyo3::exceptions::PyValueError; use pyo3::types::{PyDict, PyType}; use pyo3::{intern, prelude::*}; -use crate::errors::{InputValue, LocItem, ValResult}; +use jiter::JsonValue; + +use crate::errors::{AsLocItem, InputValue, ValResult}; use crate::tools::py_err; use crate::{PyMultiHostUrl, PyUrl}; use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; use super::return_enums::{EitherBytes, EitherInt, EitherString}; -use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, JsonInput}; +use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping}; #[derive(Debug, Clone, Copy)] pub enum InputType { @@ -46,9 +48,7 @@ impl TryFrom<&str> for InputType { /// the convention is to either implement: /// * `strict_*` & `lax_*` if they have different behavior /// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same -pub trait Input<'a>: fmt::Debug + ToPyObject { - fn as_loc_item(&self) -> LocItem; - +pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem { fn as_error_value(&'a self) -> InputValue<'a>; fn identity(&self) -> Option { @@ -89,7 +89,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { fn validate_dataclass_args(&'a self, dataclass_name: &str) -> ValResult<'a, GenericArguments<'a>>; - fn parse_json(&'a self) -> ValResult<'a, JsonInput>; + fn parse_json(&'a self) -> ValResult<'a, JsonValue>; fn validate_str(&'a self, strict: bool, coerce_numbers_to_str: bool) -> ValResult> { if strict { diff --git a/src/input/input_json.rs b/src/input/input_json.rs index e375f5755..ac552621d 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -1,46 +1,48 @@ use std::borrow::Cow; +use jiter::{JsonArray, JsonValue}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; use speedate::MicrosecondsPrecisionOverflowBehavior; use strum::EnumMessage; -use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::validators::decimal::create_decimal; use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; -use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int, string_to_vec}; +use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, - GenericIterator, GenericMapping, Input, JsonArgs, JsonArray, JsonInput, + GenericIterator, GenericMapping, Input, JsonArgs, }; -impl<'a> Input<'a> for JsonInput { - /// This is required by since JSON object keys are always strings, I don't think it can be called - #[cfg_attr(has_coverage_attribute, coverage(off))] +/// This is required but since JSON object keys are always strings, I don't think it can be called +impl AsLocItem for JsonValue { fn as_loc_item(&self) -> LocItem { match self { - JsonInput::Int(i) => (*i).into(), - JsonInput::String(s) => s.as_str().into(), + JsonValue::Int(i) => (*i).into(), + JsonValue::Str(s) => s.as_str().into(), v => format!("{v:?}").into(), } } +} +impl<'a> Input<'a> for JsonValue { fn as_error_value(&'a self) -> InputValue<'a> { - // cloning JsonInput is cheap due to use of Arc + // cloning JsonValue is cheap due to use of Arc InputValue::JsonInput(self.clone()) } fn is_none(&self) -> bool { - matches!(self, JsonInput::Null) + matches!(self, JsonValue::Null) } fn as_kwargs(&'a self, py: Python<'a>) -> Option<&'a PyDict> { match self { - JsonInput::Object(object) => { + JsonValue::Object(object) => { let dict = PyDict::new(py); for (k, v) in object.iter() { dict.set_item(k, v.to_object(py)).unwrap(); @@ -53,15 +55,15 @@ impl<'a> Input<'a> for JsonInput { fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { match self { - JsonInput::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), - JsonInput::Array(array) => Ok(JsonArgs::new(Some(array), None).into()), + JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), + JsonValue::Array(array) => Ok(JsonArgs::new(Some(array), None).into()), _ => Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)), } } fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<'a, GenericArguments<'a>> { match self { - JsonInput::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), + JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), _ => { let class_name = class_name.to_string(); Err(ValError::new( @@ -75,33 +77,32 @@ impl<'a> Input<'a> for JsonInput { } } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { match self { - JsonInput::String(s) => serde_json::from_str(s.as_str()).map_err(|e| map_json_err(self, e)), + JsonValue::Str(s) => JsonValue::parse(s.as_bytes(), true).map_err(|e| map_json_err(self, e)), _ => Err(ValError::new(ErrorTypeDefaults::JsonType, self)), } } fn strict_str(&'a self) -> ValResult> { match self { - JsonInput::String(s) => Ok(s.as_str().into()), + JsonValue::Str(s) => Ok(s.as_str().into()), _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult> { match self { - JsonInput::String(s) => Ok(s.as_str().into()), - JsonInput::BigInt(v) if coerce_numbers_to_str => Ok(v.to_string().into()), - JsonInput::Float(v) if coerce_numbers_to_str => Ok(v.to_string().into()), - JsonInput::Int(v) if coerce_numbers_to_str => Ok(v.to_string().into()), - JsonInput::Uint(v) if coerce_numbers_to_str => Ok(v.to_string().into()), + JsonValue::Str(s) => Ok(s.as_str().into()), + JsonValue::Int(i) if coerce_numbers_to_str => Ok(i.to_string().into()), + JsonValue::BigInt(b) if coerce_numbers_to_str => Ok(b.to_string().into()), + JsonValue::Float(f) if coerce_numbers_to_str => Ok(f.to_string().into()), _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } fn validate_bytes(&'a self, _strict: bool) -> ValResult> { match self { - JsonInput::String(s) => Ok(s.as_bytes().into()), + JsonValue::Str(s) => Ok(s.as_bytes().into()), _ => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), } } @@ -112,16 +113,16 @@ impl<'a> Input<'a> for JsonInput { fn strict_bool(&self) -> ValResult { match self { - JsonInput::Bool(b) => Ok(*b), + JsonValue::Bool(b) => Ok(*b), _ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), } } fn lax_bool(&self) -> ValResult { match self { - JsonInput::Bool(b) => Ok(*b), - JsonInput::String(s) => str_as_bool(self, s), - JsonInput::Int(int) => int_as_bool(self, *int), - JsonInput::Float(float) => match float_as_int(self, *float) { + JsonValue::Bool(b) => Ok(*b), + JsonValue::Str(s) => str_as_bool(self, s), + JsonValue::Int(int) => int_as_bool(self, *int), + JsonValue::Float(float) => match float_as_int(self, *float) { Ok(int) => int .as_bool() .ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)), @@ -133,60 +134,56 @@ impl<'a> Input<'a> for JsonInput { fn strict_int(&'a self) -> ValResult> { match self { - JsonInput::Int(i) => Ok(EitherInt::I64(*i)), - JsonInput::Uint(u) => Ok(EitherInt::U64(*u)), - JsonInput::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), + JsonValue::Int(i) => Ok(EitherInt::I64(*i)), + JsonValue::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), _ => Err(ValError::new(ErrorTypeDefaults::IntType, self)), } } fn lax_int(&'a self) -> ValResult> { match self { - JsonInput::Bool(b) => match *b { + JsonValue::Bool(b) => match *b { true => Ok(EitherInt::I64(1)), false => Ok(EitherInt::I64(0)), }, - JsonInput::Int(i) => Ok(EitherInt::I64(*i)), - JsonInput::Uint(u) => Ok(EitherInt::U64(*u)), - JsonInput::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), - JsonInput::Float(f) => float_as_int(self, *f), - JsonInput::String(str) => str_as_int(self, str), + JsonValue::Int(i) => Ok(EitherInt::I64(*i)), + JsonValue::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), + JsonValue::Float(f) => float_as_int(self, *f), + JsonValue::Str(str) => str_as_int(self, str), _ => Err(ValError::new(ErrorTypeDefaults::IntType, self)), } } fn ultra_strict_float(&'a self) -> ValResult> { match self { - JsonInput::Float(f) => Ok(EitherFloat::F64(*f)), + JsonValue::Float(f) => Ok(EitherFloat::F64(*f)), _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), } } fn strict_float(&'a self) -> ValResult> { match self { - JsonInput::Float(f) => Ok(EitherFloat::F64(*f)), - JsonInput::Int(i) => Ok(EitherFloat::F64(*i as f64)), - JsonInput::Uint(u) => Ok(EitherFloat::F64(*u as f64)), + JsonValue::Float(f) => Ok(EitherFloat::F64(*f)), + JsonValue::Int(i) => Ok(EitherFloat::F64(*i as f64)), _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), } } fn lax_float(&'a self) -> ValResult> { match self { - JsonInput::Bool(b) => match *b { + JsonValue::Bool(b) => match *b { true => Ok(EitherFloat::F64(1.0)), false => Ok(EitherFloat::F64(0.0)), }, - JsonInput::Float(f) => Ok(EitherFloat::F64(*f)), - JsonInput::Int(i) => Ok(EitherFloat::F64(*i as f64)), - JsonInput::Uint(u) => Ok(EitherFloat::F64(*u as f64)), - JsonInput::String(str) => str_as_float(self, str), + JsonValue::Float(f) => Ok(EitherFloat::F64(*f)), + JsonValue::Int(i) => Ok(EitherFloat::F64(*i as f64)), + JsonValue::Str(str) => str_as_float(self, str), _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), } } fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { match self { - JsonInput::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, py), + JsonValue::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, py), - JsonInput::String(..) | JsonInput::Int(..) | JsonInput::Uint(..) | JsonInput::BigInt(..) => { + JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => { create_decimal(self.to_object(py).into_ref(py), self, py) } _ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)), @@ -195,7 +192,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_dict(&'a self, _strict: bool) -> ValResult> { match self { - JsonInput::Object(dict) => Ok(dict.into()), + JsonValue::Object(dict) => Ok(dict.into()), _ => Err(ValError::new(ErrorTypeDefaults::DictType, self)), } } @@ -206,7 +203,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_list(&'a self, _strict: bool) -> ValResult> { match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::ListType, self)), } } @@ -218,7 +215,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_tuple(&'a self, _strict: bool) -> ValResult> { // just as in set's case, List has to be allowed match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::TupleType, self)), } } @@ -230,7 +227,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_set(&'a self, _strict: bool) -> ValResult> { // we allow a list here since otherwise it would be impossible to create a set from JSON match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::SetType, self)), } } @@ -242,7 +239,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_frozenset(&'a self, _strict: bool) -> ValResult> { // we allow a list here since otherwise it would be impossible to create a frozenset from JSON match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::FrozenSetType, self)), } } @@ -253,20 +250,20 @@ impl<'a> Input<'a> for JsonInput { fn extract_generic_iterable(&self) -> ValResult { match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), - JsonInput::String(s) => Ok(GenericIterable::JsonString(s)), - JsonInput::Object(object) => Ok(GenericIterable::JsonObject(object)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Str(s) => Ok(GenericIterable::JsonString(s)), + JsonValue::Object(object) => Ok(GenericIterable::JsonObject(object)), _ => Err(ValError::new(ErrorTypeDefaults::IterableType, self)), } } fn validate_iter(&self) -> ValResult { match self { - JsonInput::Array(a) => Ok(a.clone().into()), - JsonInput::String(s) => Ok(string_to_vec(s).into()), - JsonInput::Object(object) => { + JsonValue::Array(a) => Ok(a.clone().into()), + JsonValue::Str(s) => Ok(string_to_vec(s).into()), + JsonValue::Object(object) => { // return keys iterator to match python's behavior - let keys: JsonArray = JsonArray::new(object.keys().map(|k| JsonInput::String(k.clone())).collect()); + let keys: JsonArray = JsonArray::new(object.keys().map(|k| JsonValue::Str(k.clone())).collect()); Ok(keys.into()) } _ => Err(ValError::new(ErrorTypeDefaults::IterableType, self)), @@ -275,7 +272,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_date(&self, _strict: bool) -> ValResult { match self { - JsonInput::String(v) => bytes_as_date(self, v.as_bytes()), + JsonValue::Str(v) => bytes_as_date(self, v.as_bytes()), _ => Err(ValError::new(ErrorTypeDefaults::DateType, self)), } } @@ -291,16 +288,16 @@ impl<'a> Input<'a> for JsonInput { microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { match self { - JsonInput::String(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Str(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), _ => Err(ValError::new(ErrorTypeDefaults::TimeType, self)), } } fn lax_time(&self, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior) -> ValResult { match self { - JsonInput::String(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), - JsonInput::Int(v) => int_as_time(self, *v, 0), - JsonInput::Float(v) => float_as_time(self, *v), - JsonInput::BigInt(_) => Err(ValError::new( + JsonValue::Str(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Int(v) => int_as_time(self, *v, 0), + JsonValue::Float(v) => float_as_time(self, *v), + JsonValue::BigInt(_) => Err(ValError::new( ErrorType::TimeParsing { error: Cow::Borrowed( speedate::ParseError::TimeTooLarge @@ -320,7 +317,7 @@ impl<'a> Input<'a> for JsonInput { microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { match self { - JsonInput::String(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Str(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), _ => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), } } @@ -329,9 +326,9 @@ impl<'a> Input<'a> for JsonInput { microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { match self { - JsonInput::String(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), - JsonInput::Int(v) => int_as_datetime(self, *v, 0), - JsonInput::Float(v) => float_as_datetime(self, *v), + JsonValue::Str(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Int(v) => int_as_datetime(self, *v, 0), + JsonValue::Float(v) => float_as_datetime(self, *v), _ => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), } } @@ -341,7 +338,7 @@ impl<'a> Input<'a> for JsonInput { microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { match self { - JsonInput::String(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Str(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), _ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } @@ -350,29 +347,31 @@ impl<'a> Input<'a> for JsonInput { microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { match self { - JsonInput::String(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), - JsonInput::Int(v) => Ok(int_as_duration(self, *v)?.into()), - JsonInput::Float(v) => Ok(float_as_duration(self, *v)?.into()), + JsonValue::Str(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Int(v) => Ok(int_as_duration(self, *v)?.into()), + JsonValue::Float(v) => Ok(float_as_duration(self, *v)?.into()), _ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } } -impl BorrowInput for &'_ JsonInput { - type Input<'a> = JsonInput where Self: 'a; +impl BorrowInput for &'_ JsonValue { + type Input<'a> = JsonValue where Self: 'a; fn borrow_input(&self) -> &Self::Input<'_> { self } } -/// TODO: it would be good to get JsonInput and StringMapping string variants to go through this -/// implementation -/// Required for Dict keys so the string can behave like an Input -impl<'a> Input<'a> for String { +impl AsLocItem for String { fn as_loc_item(&self) -> LocItem { self.to_string().into() } +} +/// TODO: it would be good to get JsonInput and StringMapping string variants to go through this +/// implementation +/// Required for JSON Object keys so the string can behave like an Input +impl<'a> Input<'a> for String { fn as_error_value(&'a self) -> InputValue<'a> { InputValue::String(self) } @@ -398,8 +397,8 @@ impl<'a> Input<'a> for String { )) } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { - serde_json::from_str(self.as_str()).map_err(|e| map_json_err(self, e)) + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { + JsonValue::parse(self.as_bytes(), true).map_err(|e| map_json_err(self, e)) } fn strict_str(&'a self) -> ValResult> { @@ -504,3 +503,7 @@ impl BorrowInput for String { self } } + +fn string_to_vec(s: &str) -> JsonArray { + JsonArray::new(s.chars().map(|c| JsonValue::Str(c.to_string())).collect()) +} diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 33d7ca296..de59ebce0 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -9,9 +9,11 @@ use pyo3::types::{ #[cfg(not(PyPy))] use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues}; use pyo3::{intern, PyTypeInfo}; + +use jiter::JsonValue; use speedate::MicrosecondsPrecisionOverflowBehavior; -use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::tools::{extract_i64, safe_repr}; use crate::validators::decimal::{create_decimal, get_decimal_type}; use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl}; @@ -27,7 +29,7 @@ use super::shared::{ }; use super::{ py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, - GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs, + GenericIterable, GenericIterator, GenericMapping, Input, PyArgs, }; #[cfg(not(PyPy))] @@ -90,7 +92,7 @@ macro_rules! extract_dict_items { }; } -impl<'a> Input<'a> for PyAny { +impl AsLocItem for PyAny { fn as_loc_item(&self) -> LocItem { if let Ok(py_str) = self.downcast::() { py_str.to_string_lossy().as_ref().into() @@ -100,7 +102,9 @@ impl<'a> Input<'a> for PyAny { safe_repr(self).to_string().into() } } +} +impl<'a> Input<'a> for PyAny { fn as_error_value(&'a self) -> InputValue<'a> { InputValue::PyAny(self) } @@ -183,19 +187,20 @@ impl<'a> Input<'a> for PyAny { } } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { - if let Ok(py_bytes) = self.downcast::() { - serde_json::from_slice(py_bytes.as_bytes()).map_err(|e| map_json_err(self, e)) + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { + let bytes = if let Ok(py_bytes) = self.downcast::() { + py_bytes.as_bytes() } else if let Ok(py_str) = self.downcast::() { let str = py_string_str(py_str)?; - serde_json::from_str(str).map_err(|e| map_json_err(self, e)) + str.as_bytes() } else if let Ok(py_byte_array) = self.downcast::() { // Safety: from_slice does not run arbitrary Python code and the GIL is held so the - // bytes array will not be mutated while from_slice is reading it - serde_json::from_slice(unsafe { py_byte_array.as_bytes() }).map_err(|e| map_json_err(self, e)) + // bytes array will not be mutated while `JsonValue::parse` is reading it + unsafe { py_byte_array.as_bytes() } } else { - Err(ValError::new(ErrorTypeDefaults::JsonType, self)) - } + return Err(ValError::new(ErrorTypeDefaults::JsonType, self)); + }; + JsonValue::parse(bytes, true).map_err(|e| map_json_err(self, e)) } fn strict_str(&'a self) -> ValResult> { @@ -210,22 +215,6 @@ impl<'a> Input<'a> for PyAny { } } - fn exact_str(&'a self) -> ValResult> { - if let Ok(py_str) = PyString::try_from_exact(self) { - Ok(EitherString::Py(py_str)) - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) - } - } - - fn exact_int(&'a self) -> ValResult> { - if PyInt::is_exact_type_of(self) { - Ok(EitherInt::Py(self)) - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) - } - } - fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult> { if let Ok(py_str) = ::try_from_exact(self) { Ok(py_str.into()) @@ -352,6 +341,22 @@ impl<'a> Input<'a> for PyAny { } } + fn exact_int(&'a self) -> ValResult> { + if PyInt::is_exact_type_of(self) { + Ok(EitherInt::Py(self)) + } else { + Err(ValError::new(ErrorTypeDefaults::IntType, self)) + } + } + + fn exact_str(&'a self) -> ValResult> { + if let Ok(py_str) = PyString::try_from_exact(self) { + Ok(EitherString::Py(py_str)) + } else { + Err(ValError::new(ErrorTypeDefaults::IntType, self)) + } + } + fn ultra_strict_float(&'a self) -> ValResult> { if self.is_instance_of::() { Err(ValError::new(ErrorTypeDefaults::FloatType, self)) diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 72a32d897..b84908edf 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -1,9 +1,10 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; +use jiter::JsonValue; use speedate::MicrosecondsPrecisionOverflowBehavior; -use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::input::py_string_str; use crate::tools::safe_repr; use crate::validators::decimal::create_decimal; @@ -14,7 +15,7 @@ use super::datetime::{ use super::shared::{map_json_err, str_as_bool, str_as_float}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, - GenericIterator, GenericMapping, Input, JsonInput, + GenericIterator, GenericMapping, Input, }; #[derive(Debug)] @@ -52,14 +53,16 @@ impl<'py> StringMapping<'py> { } } -impl<'a> Input<'a> for StringMapping<'a> { +impl AsLocItem for StringMapping<'_> { fn as_loc_item(&self) -> LocItem { match self { Self::String(s) => s.to_string_lossy().as_ref().into(), Self::Mapping(d) => safe_repr(d).to_string().into(), } } +} +impl<'a> Input<'a> for StringMapping<'a> { fn as_error_value(&'a self) -> InputValue<'a> { match self { Self::String(s) => s.as_error_value(), @@ -83,11 +86,11 @@ impl<'a> Input<'a> for StringMapping<'a> { } } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { match self { Self::String(s) => { let str = py_string_str(s)?; - serde_json::from_str(str).map_err(|e| map_json_err(self, e)) + JsonValue::parse(str.as_bytes(), true).map_err(|e| map_json_err(self, e)) } Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::JsonType, self)), } diff --git a/src/input/mod.rs b/src/input/mod.rs index 22d774a8c..13c835f83 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -7,7 +7,6 @@ mod input_abstract; mod input_json; mod input_python; mod input_string; -mod parse_json; mod return_enums; mod shared; @@ -18,7 +17,6 @@ pub(crate) use datetime::{ }; pub(crate) use input_abstract::{BorrowInput, Input, InputType}; pub(crate) use input_string::StringMapping; -pub(crate) use parse_json::{JsonArray, JsonInput, JsonObject}; pub(crate) use return_enums::{ py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherFloat, EitherInt, EitherString, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Int, JsonArgs, JsonObjectGenericIterator, diff --git a/src/input/parse_json.rs b/src/input/parse_json.rs deleted file mode 100644 index 20a107669..000000000 --- a/src/input/parse_json.rs +++ /dev/null @@ -1,222 +0,0 @@ -use std::fmt; -use std::sync::Arc; - -use num_bigint::BigInt; -use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList}; -use serde::de::{Deserialize, DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor}; -use smallvec::SmallVec; - -use crate::lazy_index_map::LazyIndexMap; - -/// similar to serde `Value` but with int and float split -#[derive(Clone, Debug)] -pub enum JsonInput { - Null, - Bool(bool), - Int(i64), - BigInt(BigInt), - Uint(u64), - Float(f64), - String(String), - Array(JsonArray), - Object(JsonObject), -} -pub type JsonArray = Arc>; -pub type JsonObject = Arc>; - -impl ToPyObject for JsonInput { - fn to_object(&self, py: Python<'_>) -> PyObject { - match self { - Self::Null => py.None(), - Self::Bool(b) => b.into_py(py), - Self::Int(i) => i.into_py(py), - Self::BigInt(b) => b.to_object(py), - Self::Uint(i) => i.into_py(py), - Self::Float(f) => f.into_py(py), - Self::String(s) => s.into_py(py), - Self::Array(v) => PyList::new(py, v.iter().map(|v| v.to_object(py))).into_py(py), - Self::Object(o) => { - let dict = PyDict::new(py); - for (k, v) in o.iter() { - dict.set_item(k, v.to_object(py)).unwrap(); - } - dict.into_py(py) - } - } - } -} - -impl<'de> Deserialize<'de> for JsonInput { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct JsonVisitor; - - impl<'de> Visitor<'de> for JsonVisitor { - type Value = JsonInput; - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("any valid JSON value") - } - - fn visit_bool(self, value: bool) -> Result { - Ok(JsonInput::Bool(value)) - } - - fn visit_i64(self, value: i64) -> Result { - Ok(JsonInput::Int(value)) - } - - fn visit_u64(self, value: u64) -> Result { - match i64::try_from(value) { - Ok(i) => Ok(JsonInput::Int(i)), - Err(_) => Ok(JsonInput::Uint(value)), - } - } - - fn visit_f64(self, value: f64) -> Result { - Ok(JsonInput::Float(value)) - } - - fn visit_str(self, value: &str) -> Result - where - E: SerdeError, - { - Ok(JsonInput::String(value.to_string())) - } - - fn visit_string(self, value: String) -> Result { - Ok(JsonInput::String(value)) - } - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn visit_none(self) -> Result { - unreachable!() - } - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn visit_some(self, _: D) -> Result - where - D: serde::Deserializer<'de>, - { - unreachable!() - } - - fn visit_unit(self) -> Result { - Ok(JsonInput::Null) - } - - fn visit_seq(self, mut visitor: V) -> Result - where - V: SeqAccess<'de>, - { - let mut vec = SmallVec::new(); - - while let Some(elem) = visitor.next_element()? { - vec.push(elem); - } - - Ok(JsonInput::Array(JsonArray::new(vec))) - } - - fn visit_map(self, mut visitor: V) -> Result - where - V: MapAccess<'de>, - { - const SERDE_JSON_NUMBER: &str = "$serde_json::private::Number"; - match visitor.next_key_seed(KeyDeserializer)? { - Some(first_key) => { - let mut values = LazyIndexMap::new(); - let first_value = visitor.next_value()?; - - // serde_json will parse arbitrary precision numbers into a map - // structure with a "number" key and a String value - 'try_number: { - if first_key == SERDE_JSON_NUMBER { - // Just in case someone tries to actually store that key in a real map, - // keep parsing and continue as a map if so - - if let Some((key, value)) = visitor.next_entry::()? { - // Important to preserve order of the keys - values.insert(first_key, first_value); - values.insert(key, value); - break 'try_number; - } - - if let JsonInput::String(s) = &first_value { - // Normalize the string to either an int or float - let normalized = if s.chars().any(|c| c == '.' || c == 'E' || c == 'e') { - JsonInput::Float( - s.parse() - .map_err(|e| V::Error::custom(format!("expected a float: {e}")))?, - ) - } else if let Ok(i) = s.parse::() { - JsonInput::Int(i) - } else if let Ok(big) = s.parse::() { - JsonInput::BigInt(big) - } else { - // Failed to normalize, just throw it in the map and continue - values.insert(first_key, first_value); - break 'try_number; - }; - - return Ok(normalized); - }; - } else { - values.insert(first_key, first_value); - } - } - - while let Some((key, value)) = visitor.next_entry()? { - values.insert(key, value); - } - Ok(JsonInput::Object(Arc::new(values))) - } - None => Ok(JsonInput::Object(Arc::new(LazyIndexMap::new()))), - } - } - } - - deserializer.deserialize_any(JsonVisitor) - } -} - -struct KeyDeserializer; - -impl<'de> DeserializeSeed<'de> for KeyDeserializer { - type Value = String; - - fn deserialize(self, deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_str(self) - } -} - -impl<'de> Visitor<'de> for KeyDeserializer { - type Value = String; - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string key") - } - - fn visit_str(self, s: &str) -> Result - where - E: serde::de::Error, - { - Ok(s.to_string()) - } - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn visit_string(self, _: String) -> Result - where - E: serde::de::Error, - { - unreachable!() - } -} diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index daa9f39fe..412842a13 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -4,6 +4,7 @@ use std::ops::Rem; use std::slice::Iter as SliceIter; use std::str::FromStr; +use jiter::{JsonArray, JsonObject, JsonValue}; use num_bigint::BigInt; use pyo3::exceptions::PyTypeError; @@ -26,7 +27,6 @@ use crate::tools::py_err; use crate::validators::{CombinedValidator, ValidationState, Validator}; use super::input_string::StringMapping; -use super::parse_json::{JsonArray, JsonInput, JsonObject}; use super::{py_error_on_minusone, Input}; /// Container for all the collections (sized iterable containers) types, which @@ -50,7 +50,7 @@ pub enum GenericIterable<'a> { PyByteArray(&'a PyByteArray), Sequence(&'a PySequence), Iterator(&'a PyIterator), - JsonArray(&'a [JsonInput]), + JsonArray(&'a [JsonValue]), JsonObject(&'a JsonObject), JsonString(&'a String), } @@ -573,7 +573,7 @@ impl<'py> Iterator for AttributesGenericIterator<'py> { } pub struct JsonObjectGenericIterator<'py> { - object_iter: SliceIter<'py, (String, JsonInput)>, + object_iter: SliceIter<'py, (String, JsonValue)>, } impl<'py> JsonObjectGenericIterator<'py> { @@ -585,7 +585,7 @@ impl<'py> JsonObjectGenericIterator<'py> { } impl<'py> Iterator for JsonObjectGenericIterator<'py> { - type Item = ValResult<'py, (&'py String, &'py JsonInput)>; + type Item = ValResult<'py, (&'py String, &'py JsonValue)>; fn next(&mut self) -> Option { self.object_iter.next().map(|(key, value)| Ok((key, value))) @@ -653,7 +653,7 @@ pub struct GenericJsonIterator { } impl GenericJsonIterator { - pub fn next(&mut self, _py: Python) -> PyResult> { + pub fn next(&mut self, _py: Python) -> PyResult> { if self.index < self.array.len() { // panic here is impossible due to bounds check above; compiler should be // able to optimize it away even @@ -667,7 +667,7 @@ impl GenericJsonIterator { } pub fn input_as_error_value<'py>(&self, _py: Python<'py>) -> InputValue<'py> { - InputValue::JsonInput(JsonInput::Array(self.array.clone())) + InputValue::JsonInput(JsonValue::Array(self.array.clone())) } pub fn index(&self) -> usize { @@ -689,12 +689,12 @@ impl<'a> PyArgs<'a> { #[cfg_attr(debug_assertions, derive(Debug))] pub struct JsonArgs<'a> { - pub args: Option<&'a [JsonInput]>, + pub args: Option<&'a [JsonValue]>, pub kwargs: Option<&'a JsonObject>, } impl<'a> JsonArgs<'a> { - pub fn new(args: Option<&'a [JsonInput]>, kwargs: Option<&'a JsonObject>) -> Self { + pub fn new(args: Option<&'a [JsonValue]>, kwargs: Option<&'a JsonObject>) -> Self { Self { args, kwargs } } } diff --git a/src/input/shared.rs b/src/input/shared.rs index 105da4bcc..718210098 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -1,12 +1,12 @@ -use num_bigint::BigInt; use pyo3::sync::GILOnceCell; use pyo3::{intern, Py, PyAny, Python, ToPyObject}; +use jiter::JsonValueError; +use num_bigint::BigInt; + use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; -use super::parse_json::{JsonArray, JsonInput}; use super::{EitherFloat, EitherInt, Input}; - static ENUM_META_OBJECT: GILOnceCell> = GILOnceCell::new(); pub fn get_enum_meta_object(py: Python) -> Py { @@ -20,7 +20,7 @@ pub fn get_enum_meta_object(py: Python) -> Py { .clone() } -pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> { +pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: JsonValueError) -> ValError<'a> { ValError::new( ErrorType::JsonInvalid { error: error.to_string(), @@ -164,7 +164,3 @@ pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a Py } Ok(EitherInt::Py(numerator)) } - -pub fn string_to_vec(s: &str) -> JsonArray { - JsonArray::new(s.chars().map(|c| JsonInput::String(c.to_string())).collect()) -} diff --git a/src/lazy_index_map.rs b/src/lazy_index_map.rs deleted file mode 100644 index c5621f877..000000000 --- a/src/lazy_index_map.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::borrow::Borrow; -use std::cmp::{Eq, PartialEq}; -use std::fmt::Debug; -use std::hash::Hash; -use std::slice::Iter as SliceIter; -use std::sync::OnceLock; - -use ahash::AHashMap; -use smallvec::SmallVec; - -#[derive(Debug, Clone, Default)] -pub struct LazyIndexMap { - vec: SmallVec<[(K, V); 8]>, - map: OnceLock>, -} - -/// Like [IndexMap](https://docs.rs/indexmap/latest/indexmap/) but only builds the lookup map when it's needed. -impl LazyIndexMap -where - K: Clone + Debug + Eq + Hash, - V: Debug, -{ - pub fn new() -> Self { - Self { - vec: SmallVec::new(), - map: OnceLock::new(), - } - } - - pub fn insert(&mut self, key: K, value: V) { - if let Some(map) = self.map.get_mut() { - map.insert(key.clone(), self.vec.len()); - } - self.vec.push((key, value)); - } - - pub fn len(&self) -> usize { - self.vec.len() - } - - pub fn get(&self, key: &Q) -> Option<&V> - where - K: Borrow + PartialEq, - Q: Hash + Eq, - { - let map = self.map.get_or_init(|| { - self.vec - .iter() - .enumerate() - .map(|(index, (key, _))| (key.clone(), index)) - .collect() - }); - map.get(key).map(|&i| &self.vec[i].1) - } - - pub fn keys(&self) -> impl Iterator { - self.vec.iter().map(|(k, _)| k) - } - - pub fn iter(&self) -> SliceIter<'_, (K, V)> { - self.vec.iter() - } -} diff --git a/src/lib.rs b/src/lib.rs index b241cdb8a..f969c0657 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,8 @@ extern crate core; use std::sync::OnceLock; +use pyo3::exceptions::PyTypeError; +use pyo3::types::{PyByteArray, PyBytes, PyString}; use pyo3::{prelude::*, sync::GILOnceCell}; // parse this first to get access to the contained macro @@ -15,7 +17,6 @@ mod build_tools; mod definitions; mod errors; mod input; -mod lazy_index_map; mod lookup_key; mod recursion_guard; mod serializers; @@ -36,6 +37,19 @@ pub use serializers::{ }; pub use validators::{validate_core_schema, PySome, SchemaValidator}; +#[pyfunction(signature = (data, *, allow_inf_nan=true))] +pub fn from_json(py: Python, data: &PyAny, allow_inf_nan: bool) -> PyResult { + if let Ok(py_bytes) = data.downcast::() { + jiter::python_parse(py, py_bytes.as_bytes(), allow_inf_nan) + } else if let Ok(py_str) = data.downcast::() { + jiter::python_parse(py, py_str.to_str()?.as_bytes(), allow_inf_nan) + } else if let Ok(py_byte_array) = data.downcast::() { + jiter::python_parse(py, &py_byte_array.to_vec(), allow_inf_nan) + } else { + Err(PyTypeError::new_err("Expected bytes, bytearray or str")) + } +} + pub fn get_pydantic_core_version() -> &'static str { static PYDANTIC_CORE_VERSION: OnceLock = OnceLock::new(); @@ -95,6 +109,7 @@ fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(to_json, m)?)?; + m.add_function(wrap_pyfunction!(from_json, m)?)?; m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?; m.add_function(wrap_pyfunction!(list_all_errors, m)?)?; m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?; diff --git a/src/lookup_key.rs b/src/lookup_key.rs index bb7d7e3d7..f833c00af 100644 --- a/src/lookup_key.rs +++ b/src/lookup_key.rs @@ -5,9 +5,11 @@ use pyo3::exceptions::{PyAttributeError, PyTypeError}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyMapping, PyString}; +use jiter::{JsonObject, JsonValue}; + use crate::build_tools::py_schema_err; use crate::errors::{py_err_string, ErrorType, ValError, ValLineError, ValResult}; -use crate::input::{Input, JsonInput, JsonObject, StringMapping}; +use crate::input::{Input, StringMapping}; use crate::tools::{extract_i64, py_err}; /// Used for getting items from python dicts, python objects, or JSON objects, in different ways @@ -264,7 +266,7 @@ impl LookupKey { pub fn json_get<'data, 's>( &'s self, dict: &'data JsonObject, - ) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonInput)>> { + ) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonValue)>> { match self { Self::Simple { key, path, .. } => match dict.get(key) { Some(value) => Ok(Some((path, value))), @@ -289,13 +291,13 @@ impl LookupKey { // first step is different from the rest as we already know dict is JsonObject // because of above checks, we know that path should have at least one element, hence unwrap - let v: &JsonInput = match path_iter.next().unwrap().json_obj_get(dict) { + let v: &JsonValue = match path_iter.next().unwrap().json_obj_get(dict) { Some(v) => v, None => continue, }; // similar to above - // iterate over the path and plug each value into the JsonInput from the last step, starting with v + // iterate over the path and plug each value into the JsonValue from the last step, starting with v // from the first step, this could just be a loop but should be somewhat faster with a functional design if let Some(v) = path_iter.try_fold(v, |d, loc| loc.json_get(d)) { // Successfully found an item, return it @@ -481,10 +483,10 @@ impl PathItem { } } - pub fn json_get<'a>(&self, any_json: &'a JsonInput) -> Option<&'a JsonInput> { + pub fn json_get<'a>(&self, any_json: &'a JsonValue) -> Option<&'a JsonValue> { match any_json { - JsonInput::Object(v_obj) => self.json_obj_get(v_obj), - JsonInput::Array(v_array) => match self { + JsonValue::Object(v_obj) => self.json_obj_get(v_obj), + JsonValue::Array(v_array) => match self { Self::Pos(index) => v_array.get(*index), Self::Neg(index) => { if let Some(index) = v_array.len().checked_sub(*index) { @@ -499,7 +501,7 @@ impl PathItem { } } - pub fn json_obj_get<'a>(&self, json_obj: &'a JsonObject) -> Option<&'a JsonInput> { + pub fn json_obj_get<'a>(&self, json_obj: &'a JsonObject) -> Option<&'a JsonValue> { match self { Self::S(key, _) => json_obj.get(key), _ => None, diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index 7f406ba16..748b13338 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -6,7 +6,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::schema_or_config_same; -use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{GenericArguments, Input}; use crate::lookup_key::LookupKey; diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index b18faea2c..d93441ce0 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -7,7 +7,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; -use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{BorrowInput, GenericArguments, Input}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 5026afba3..c7df345ed 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -3,7 +3,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::is_strict; -use crate::errors::{ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ValError, ValLineError, ValResult}; use crate::input::BorrowInput; use crate::input::{ DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, diff --git a/src/validators/function.rs b/src/validators/function.rs index adb143696..66bbafbb9 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -1,16 +1,16 @@ use std::sync::Arc; -use pyo3::exceptions::{PyAssertionError, PyTypeError, PyValueError}; +use pyo3::exceptions::{PyAssertionError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyAny, PyDict, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; use crate::errors::{ - ErrorType, LocItem, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, ValResult, ValidationError, + AsLocItem, ErrorType, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, ValResult, ValidationError, }; use crate::input::Input; use crate::py_gc::PyGcTraverse; -use crate::tools::{function_name, py_err, safe_repr, SchemaDict}; +use crate::tools::{function_name, safe_repr, SchemaDict}; use crate::PydanticUseDefault; use super::generator::InternalValidator; @@ -406,13 +406,7 @@ struct ValidatorCallable { #[pymethods] impl ValidatorCallable { fn __call__(&mut self, py: Python, input_value: &PyAny, outer_location: Option<&PyAny>) -> PyResult { - let outer_location = match outer_location { - Some(ol) => match LocItem::try_from(ol) { - Ok(ol) => Some(ol), - Err(_) => return py_err!(PyTypeError; "outer_location must be a str or int"), - }, - None => None, - }; + let outer_location = outer_location.map(AsLocItem::as_loc_item); self.validator.validate(py, input_value, outer_location) } @@ -440,13 +434,7 @@ struct AssignmentValidatorCallable { #[pymethods] impl AssignmentValidatorCallable { fn __call__(&mut self, py: Python, input_value: &PyAny, outer_location: Option<&PyAny>) -> PyResult { - let outer_location = match outer_location { - Some(ol) => match LocItem::try_from(ol) { - Ok(ol) => Some(ol), - Err(_) => return py_err!(PyTypeError; "outer_location must be a str or int"), - }, - None => None, - }; + let outer_location = outer_location.map(AsLocItem::as_loc_item); self.validator.validate_assignment( py, input_value, diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index b79145c97..17ec81670 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -7,7 +7,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; -use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, StringMappingGenericIterator, diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index dab492da7..5839959e7 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -8,7 +8,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config, schema_or_config_same, ExtraBehavior}; -use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, StringMappingGenericIterator, diff --git a/src/validators/union.rs b/src/validators/union.rs index a8bd29d7d..837114408 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -9,7 +9,7 @@ use smallvec::SmallVec; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config}; -use crate::errors::{ErrorType, LocItem, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ValError, ValLineError, ValResult}; use crate::input::{GenericMapping, Input}; use crate::lookup_key::LookupKey; use crate::py_gc::PyGcTraverse; @@ -566,7 +566,7 @@ impl TaggedUnionValidator { if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) { return match validator.validate(py, input, state) { Ok(res) => Ok(res), - Err(err) => Err(err.with_outer_location(LocItem::try_from(tag.to_object(py).into_ref(py))?)), + Err(err) => Err(err.with_outer_location(tag.as_loc_item())), }; } match self.custom_error { diff --git a/tests/test_json.py b/tests/test_json.py index 9bba05c14..4ef8a1d40 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -4,7 +4,7 @@ from typing import List import pytest -from dirty_equals import IsList +from dirty_equals import IsFloatNan, IsList import pydantic_core from pydantic_core import ( @@ -358,3 +358,10 @@ def test_bad_repr(): to_json(b) assert to_json(b, serialize_unknown=True) == b'""' + + +def test_inf_nan_allow(): + v = SchemaValidator(core_schema.float_schema(allow_inf_nan=True)) + assert v.validate_json('Infinity') == float('inf') + assert v.validate_json('-Infinity') == float('-inf') + assert v.validate_json('NaN') == IsFloatNan() diff --git a/tests/validators/test_decimal.py b/tests/validators/test_decimal.py index 43b3d19b9..b9fabeaed 100644 --- a/tests/validators/test_decimal.py +++ b/tests/validators/test_decimal.py @@ -140,7 +140,8 @@ def test_decimal_strict_json(input_value, expected): {'ge': 0}, -0.1, Err( - 'Input should be greater than or equal to 0 [type=greater_than_equal, input_value=-0.1, input_type=float]' + 'Input should be greater than or equal to 0 ' + '[type=greater_than_equal, input_value=-0.1, input_type=float]' ), ), ({'gt': 0}, 0.1, Decimal('0.1')), @@ -150,14 +151,14 @@ def test_decimal_strict_json(input_value, expected): ({'le': 0}, 0.1, Err('Input should be less than or equal to 0')), ({'lt': 0, 'allow_inf_nan': True}, float('nan'), Err('Input should be less than 0')), ({'gt': 0, 'allow_inf_nan': True}, float('inf'), Decimal('inf')), + ({'allow_inf_nan': True}, float('-inf'), Decimal('-inf')), + ({'allow_inf_nan': True}, float('nan'), FunctionCheck(math.isnan)), ({'lt': 0}, 0, Err('Input should be less than 0')), ({'lt': 0.123456}, 1, Err('Input should be less than 0.123456')), ], ) def test_decimal_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected): v = py_and_json({'type': 'decimal', **kwargs}) - if v.validator_type == 'json' and isinstance(input_value, float) and not math.isfinite(input_value): - expected = Err('Invalid JSON') if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_test(input_value) diff --git a/tests/validators/test_float.py b/tests/validators/test_float.py index b18181fbb..35b04c3f9 100644 --- a/tests/validators/test_float.py +++ b/tests/validators/test_float.py @@ -4,9 +4,9 @@ from typing import Any, Dict import pytest -from dirty_equals import FunctionCheck, IsStr +from dirty_equals import FunctionCheck, IsFloatNan, IsStr -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, core_schema from ..conftest import Err, PyAndJson, plain_repr @@ -92,8 +92,6 @@ def test_float_strict(py_and_json: PyAndJson, input_value, expected): ) def test_float_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected): v = py_and_json({'type': 'float', **kwargs}) - if v.validator_type == 'json' and isinstance(input_value, float) and not math.isfinite(input_value): - expected = Err('Invalid JSON') if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_test(input_value) @@ -376,3 +374,34 @@ def test_string_with_underscores() -> None: v.validate_python(edge_case) with pytest.raises(ValidationError): v.validate_json(f'"{edge_case}"') + + +def test_allow_inf_nan_true_json() -> None: + v = SchemaValidator(core_schema.float_schema()) + + assert v.validate_json('123') == 123 + assert v.validate_json('NaN') == IsFloatNan() + assert v.validate_json('Infinity') == float('inf') + assert v.validate_json('-Infinity') == float('-inf') + + +def test_allow_inf_nan_false_json() -> None: + v = SchemaValidator(core_schema.float_schema(), core_schema.CoreConfig(allow_inf_nan=False)) + + assert v.validate_json('123') == 123 + with pytest.raises(ValidationError) as exc_info1: + v.validate_json('NaN') + # insert_assert(exc_info.value.errors()) + assert exc_info1.value.errors(include_url=False) == [ + {'type': 'finite_number', 'loc': (), 'msg': 'Input should be a finite number', 'input': IsFloatNan()} + ] + with pytest.raises(ValidationError) as exc_info2: + v.validate_json('Infinity') + assert exc_info2.value.errors(include_url=False) == [ + {'type': 'finite_number', 'loc': (), 'msg': 'Input should be a finite number', 'input': float('inf')} + ] + with pytest.raises(ValidationError) as exc_info3: + v.validate_json('-Infinity') + assert exc_info3.value.errors(include_url=False) == [ + {'type': 'finite_number', 'loc': (), 'msg': 'Input should be a finite number', 'input': float('-inf')} + ] diff --git a/tests/validators/test_function.py b/tests/validators/test_function.py index 9f94ceb1b..e5ccba1e3 100644 --- a/tests/validators/test_function.py +++ b/tests/validators/test_function.py @@ -289,8 +289,19 @@ def f(input_value, validator, info): v = SchemaValidator(core_schema.with_info_wrap_validator_function(f, core_schema.int_schema())) - with pytest.raises(TypeError, match='^outer_location must be a str or int$'): - v.validate_python(4) + assert v.validate_python(4) == 6 + + with pytest.raises(ValidationError) as exc_info: + v.validate_python('wrong') + # insert_assert(exc_info.value.errors(include_url=False)) + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'int_parsing', + 'loc': ("('4',)",), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } + ] def test_function_after(): diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index dedc2bd93..61acab7fb 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -6,7 +6,7 @@ import pytest from dirty_equals import IsStr -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, core_schema from ..conftest import Err, PyAndJson, plain_repr @@ -472,3 +472,27 @@ class PlainEnum(Enum): v_lax = v.validate_python(PlainEnum.ONE) assert v_lax == 1 assert type(v_lax) == int + + +def test_allow_inf_nan_true_json() -> None: + v = SchemaValidator(core_schema.int_schema(), core_schema.CoreConfig(allow_inf_nan=True)) + + assert v.validate_json('123') == 123 + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('NaN') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('Infinity') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('-Infinity') + + +def test_allow_inf_nan_false_json() -> None: + v = SchemaValidator(core_schema.int_schema(), core_schema.CoreConfig(allow_inf_nan=False)) + + assert v.validate_json('123') == 123 + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('NaN') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('Infinity') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('-Infinity')