Skip to content

Commit

Permalink
support trailing-strings with allow_partial (#1539)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Nov 12, 2024
1 parent a3f13c7 commit be03c66
Show file tree
Hide file tree
Showing 21 changed files with 159 additions and 106 deletions.
5 changes: 4 additions & 1 deletion .mypy-stubtest-allowlist
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# TODO: don't want to expose this staticmethod, requires https://github.com/PyO3/pyo3/issues/2384
pydantic_core._pydantic_core.PydanticUndefinedType.new
# As per #1240, from_json has custom logic to cover the `cache_strings` kwarg
# See #1540 for discussion
pydantic_core._pydantic_core.from_json
pydantic_core._pydantic_core.SchemaValidator.validate_python
pydantic_core._pydantic_core.SchemaValidator.validate_json
pydantic_core._pydantic_core.SchemaValidator.validate_strings
# the `warnings` kwarg for SchemaSerializer functions has custom logic
pydantic_core._pydantic_core.SchemaSerializer.to_python
pydantic_core._pydantic_core.SchemaSerializer.to_json
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ base64 = "0.22.1"
num-bigint = "0.4.6"
python3-dll-a = "0.2.10"
uuid = "1.11.0"
jiter = { version = "0.7", features = ["python"] }
jiter = { version = "0.7.1", features = ["python"] }
hex = "0.4.3"

[lib]
Expand Down
102 changes: 51 additions & 51 deletions benches/main.rs

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,6 @@ require_change_file = false
[tool.pyright]
include = ['pydantic_core', 'tests/test_typing.py']
reportUnnecessaryTypeIgnoreComment = true

[tool.inline-snapshot.shortcuts]
fix = ["create", "fix"]
10 changes: 7 additions & 3 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class SchemaValidator:
from_attributes: bool | None = None,
context: Any | None = None,
self_instance: Any | None = None,
allow_partial: bool = False,
allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
) -> Any:
"""
Validate a Python object against the schema and return the validated object.
Expand All @@ -113,6 +113,7 @@ class SchemaValidator:
validation from the `__init__` method of a model.
allow_partial: Whether to allow partial validation; if `True`, errors in the last element of sequences
and mappings are ignored.
`'trailing-strings'` means any final unfinished JSON string is included in the result.
Raises:
ValidationError: If validation fails.
Expand Down Expand Up @@ -146,7 +147,7 @@ class SchemaValidator:
strict: bool | None = None,
context: Any | None = None,
self_instance: Any | None = None,
allow_partial: bool = False,
allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
) -> Any:
"""
Validate JSON data directly against the schema and return the validated Python object.
Expand All @@ -166,6 +167,7 @@ class SchemaValidator:
self_instance: An instance of a model to set attributes on from validation.
allow_partial: Whether to allow partial validation; if `True`, incomplete JSON will be parsed successfully
and errors in the last element of sequences and mappings are ignored.
`'trailing-strings'` means any final unfinished JSON string is included in the result.
Raises:
ValidationError: If validation fails or if the JSON data is invalid.
Expand All @@ -180,7 +182,7 @@ class SchemaValidator:
*,
strict: bool | None = None,
context: Any | None = None,
allow_partial: bool = False,
allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
) -> Any:
"""
Validate a string against the schema and return the validated Python object.
Expand All @@ -196,6 +198,7 @@ class SchemaValidator:
[`info.context`][pydantic_core.core_schema.ValidationInfo.context].
allow_partial: Whether to allow partial validation; if `True`, errors in the last element of sequences
and mappings are ignored.
`'trailing-strings'` means any final unfinished JSON string is included in the result.
Raises:
ValidationError: If validation fails or if the JSON data is invalid.
Expand Down Expand Up @@ -433,6 +436,7 @@ def from_json(
`all/True` means cache all strings, `keys` means cache only dict keys, `none/False` means no caching.
allow_partial: Whether to allow partial deserialization; if `True`, JSON data is returned when the end of the
input is reached before the full object is deserialized, e.g. `["aa", "bb", "c` would return `['aa', 'bb']`.
`'trailing-strings'` means any final unfinished JSON string is included in the result.
Raises:
ValueError: If deserialization fails.
Expand Down
15 changes: 12 additions & 3 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::cmp::Ordering;
use std::ops::Rem;
use std::str::FromStr;

use jiter::{JsonArray, JsonValue, StringCacheMode};
use jiter::{JsonArray, JsonValue, PartialMode, StringCacheMode};
use num_bigint::BigInt;

use pyo3::exceptions::PyTypeError;
Expand Down Expand Up @@ -128,9 +128,13 @@ pub(crate) fn validate_iter_to_vec<'py>(
) -> ValResult<Vec<PyObject>> {
let mut output: Vec<PyObject> = Vec::with_capacity(capacity);
let mut errors: Vec<ValLineError> = Vec::new();
let allow_partial = state.allow_partial;

for (index, is_last_partial, item_result) in state.enumerate_last_partial(iter) {
state.allow_partial = is_last_partial;
state.allow_partial = match is_last_partial {
true => allow_partial,
false => PartialMode::Off,
};
let item = item_result.map_err(|e| any_next_error!(py, e, max_length_check.input, index))?;
match validator.validate(py, item.borrow_input(), state) {
Ok(item) => {
Expand Down Expand Up @@ -202,8 +206,13 @@ pub(crate) fn validate_iter_to_set<'py>(
) -> ValResult<()> {
let mut errors: Vec<ValLineError> = Vec::new();

let allow_partial = state.allow_partial;

for (index, is_last_partial, item_result) in state.enumerate_last_partial(iter) {
state.allow_partial = is_last_partial;
state.allow_partial = match is_last_partial {
true => allow_partial,
false => PartialMode::Off,
};
let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?;
match validator.validate(py, item.borrow_input(), state) {
Ok(item) => {
Expand Down
4 changes: 2 additions & 2 deletions src/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl PyUrl {
pub fn py_new(py: Python, url: &Bound<'_, PyAny>) -> PyResult<Self> {
let schema_obj = SCHEMA_DEFINITION_URL
.get_or_init(py, || build_schema_validator(py, "url"))
.validate_python(py, url, None, None, None, None, false)?;
.validate_python(py, url, None, None, None, None, false.into())?;
schema_obj.extract(py)
}

Expand Down Expand Up @@ -225,7 +225,7 @@ impl PyMultiHostUrl {
pub fn py_new(py: Python, url: &Bound<'_, PyAny>) -> PyResult<Self> {
let schema_obj = SCHEMA_DEFINITION_MULTI_HOST_URL
.get_or_init(py, || build_schema_validator(py, "multi-host-url"))
.validate_python(py, url, None, None, None, None, false)?;
.validate_python(py, url, None, None, None, None, false.into())?;
schema_obj.extract(py)
}

Expand Down
3 changes: 1 addition & 2 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::build_tools::{schema_or_config_same, ExtraBehavior};
use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::{Arguments, BorrowInput, Input, KeywordArgs, PositionalArgs, ValidationMatch};
use crate::lookup_key::LookupKey;

use crate::tools::SchemaDict;

use super::validation_state::ValidationState;
Expand Down Expand Up @@ -189,7 +188,7 @@ impl Validator for ArgumentsValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
// this validator does not yet support partial validation, disable it to avoid incorrect results
state.allow_partial = false;
state.allow_partial = false.into();

let args = input.validate_args()?;

Expand Down
2 changes: 1 addition & 1 deletion src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl Validator for DataclassArgsValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
// this validator does not yet support partial validation, disable it to avoid incorrect results
state.allow_partial = false;
state.allow_partial = false.into();

let args = input.validate_dataclass_args(&self.dataclass_name)?;

Expand Down
2 changes: 1 addition & 1 deletion src/validators/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl Validator for DefinitionRefValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
// this validator does not yet support partial validation, disable it to avoid incorrect results
state.allow_partial = false;
state.allow_partial = false.into();

self.definition.read(|validator| {
let validator = validator.unwrap();
Expand Down
8 changes: 6 additions & 2 deletions src/validators/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ where
fn consume_iterator(self, iterator: impl Iterator<Item = ValResult<(Key, Value)>>) -> ValResult<PyObject> {
let output = PyDict::new_bound(self.py);
let mut errors: Vec<ValLineError> = Vec::new();
let allow_partial = self.state.allow_partial;

for (_, is_last_partial, item_result) in self.state.enumerate_last_partial(iterator) {
self.state.allow_partial = false;
self.state.allow_partial = false.into();
let (key, value) = item_result?;
let output_key = match self.key_validator.validate(self.py, key.borrow_input(), self.state) {
Ok(value) => Some(value),
Expand All @@ -125,7 +126,10 @@ where
Err(ValError::Omit) => continue,
Err(err) => return Err(err),
};
self.state.allow_partial = is_last_partial;
self.state.allow_partial = match is_last_partial {
true => allow_partial,
false => false.into(),
};
let output_value = match self.value_validator.validate(self.py, value.borrow_input(), self.state) {
Ok(value) => value,
Err(ValError::LineErrors(line_errors)) => {
Expand Down
6 changes: 3 additions & 3 deletions src/validators/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl Validator for GeneratorValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
// this validator does not yet support partial validation, disable it to avoid incorrect results
state.allow_partial = false;
state.allow_partial = false.into();

let iterator = input.validate_iter()?.into_static();
let validator = self.item_validator.as_ref().map(|v| {
Expand Down Expand Up @@ -282,7 +282,7 @@ impl InternalValidator {
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
cache_str: self.cache_str,
};
let mut state = ValidationState::new(extra, &mut self.recursion_guard, false);
let mut state = ValidationState::new(extra, &mut self.recursion_guard, false.into());
state.exactness = self.exactness;
let result = self
.validator
Expand Down Expand Up @@ -317,7 +317,7 @@ impl InternalValidator {
self_instance: self.self_instance.as_ref().map(|data| data.bind(py)),
cache_str: self.cache_str,
};
let mut state = ValidationState::new(extra, &mut self.recursion_guard, false);
let mut state = ValidationState::new(extra, &mut self.recursion_guard, false.into());
state.exactness = self.exactness;
let result = self.validator.validate(py, input, &mut state).map_err(|e| {
ValidationError::from_val_error(
Expand Down
8 changes: 2 additions & 6 deletions src/validators/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use jiter::{FloatMode, JsonValue, PartialMode, PythonParse};
use jiter::{FloatMode, JsonValue, PythonParse};

use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::{EitherBytes, Input, InputType, ValidationMatch};
Expand Down Expand Up @@ -70,11 +70,7 @@ impl Validator for JsonValidator {
let parse_builder = PythonParse {
allow_inf_nan: true,
cache_mode: state.cache_str(),
partial_mode: if state.allow_partial {
PartialMode::TrailingStrings
} else {
PartialMode::Off
},
partial_mode: state.allow_partial,
catch_duplicate_keys: false,
float_mode: FloatMode::Float,
};
Expand Down
26 changes: 13 additions & 13 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::Debug;

use enum_dispatch::enum_dispatch;
use jiter::StringCacheMode;
use jiter::{PartialMode, StringCacheMode};

use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
Expand Down Expand Up @@ -165,7 +165,7 @@ impl SchemaValidator {
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None, allow_partial=false))]
#[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None, allow_partial=PartialMode::Off))]
pub fn validate_python(
&self,
py: Python,
Expand All @@ -174,7 +174,7 @@ impl SchemaValidator {
from_attributes: Option<bool>,
context: Option<&Bound<'_, PyAny>>,
self_instance: Option<&Bound<'_, PyAny>>,
allow_partial: bool,
allow_partial: PartialMode,
) -> PyResult<PyObject> {
self._validate(
py,
Expand Down Expand Up @@ -207,7 +207,7 @@ impl SchemaValidator {
from_attributes,
context,
self_instance,
false,
false.into(),
) {
Ok(_) => Ok(true),
Err(ValError::InternalErr(err)) => Err(err),
Expand All @@ -217,15 +217,15 @@ impl SchemaValidator {
}
}

#[pyo3(signature = (input, *, strict=None, context=None, self_instance=None, allow_partial=false))]
#[pyo3(signature = (input, *, strict=None, context=None, self_instance=None, allow_partial=PartialMode::Off))]
pub fn validate_json(
&self,
py: Python,
input: &Bound<'_, PyAny>,
strict: Option<bool>,
context: Option<&Bound<'_, PyAny>>,
self_instance: Option<&Bound<'_, PyAny>>,
allow_partial: bool,
allow_partial: PartialMode,
) -> PyResult<PyObject> {
let r = match json::validate_json_bytes(input) {
Ok(v_match) => self._validate_json(
Expand All @@ -242,14 +242,14 @@ impl SchemaValidator {
r.map_err(|e| self.prepare_validation_err(py, e, InputType::Json))
}

#[pyo3(signature = (input, *, strict=None, context=None, allow_partial=false))]
#[pyo3(signature = (input, *, strict=None, context=None, allow_partial=PartialMode::Off))]
pub fn validate_strings(
&self,
py: Python,
input: Bound<'_, PyAny>,
strict: Option<bool>,
context: Option<&Bound<'_, PyAny>>,
allow_partial: bool,
allow_partial: PartialMode,
) -> PyResult<PyObject> {
let t = InputType::String;
let string_mapping = StringMapping::new_value(input).map_err(|e| self.prepare_validation_err(py, e, t))?;
Expand Down Expand Up @@ -283,7 +283,7 @@ impl SchemaValidator {
};

let guard = &mut RecursionState::default();
let mut state = ValidationState::new(extra, guard, false);
let mut state = ValidationState::new(extra, guard, false.into());
self.validator
.validate_assignment(py, &obj, field_name, &field_value, &mut state)
.map_err(|e| self.prepare_validation_err(py, e, InputType::Python))
Expand All @@ -306,7 +306,7 @@ impl SchemaValidator {
cache_str: self.cache_str,
};
let recursion_guard = &mut RecursionState::default();
let mut state = ValidationState::new(extra, recursion_guard, false);
let mut state = ValidationState::new(extra, recursion_guard, false.into());
let r = self.validator.default_value(py, None::<i64>, &mut state);
match r {
Ok(maybe_default) => match maybe_default {
Expand Down Expand Up @@ -352,7 +352,7 @@ impl SchemaValidator {
from_attributes: Option<bool>,
context: Option<&Bound<'py, PyAny>>,
self_instance: Option<&Bound<'py, PyAny>>,
allow_partial: bool,
allow_partial: PartialMode,
) -> ValResult<PyObject> {
let mut recursion_guard = RecursionState::default();
let mut state = ValidationState::new(
Expand All @@ -379,7 +379,7 @@ impl SchemaValidator {
strict: Option<bool>,
context: Option<&Bound<'_, PyAny>>,
self_instance: Option<&Bound<'_, PyAny>>,
allow_partial: bool,
allow_partial: PartialMode,
) -> ValResult<PyObject> {
let json_value = jiter::JsonValue::parse_with_config(json_data, true, allow_partial)
.map_err(|e| json::map_json_err(input, e, json_data))?;
Expand Down Expand Up @@ -430,7 +430,7 @@ impl<'py> SelfValidator<'py> {
let mut state = ValidationState::new(
Extra::new(strict, None, None, None, InputType::Python, true.into()),
&mut recursion_guard,
false,
false.into(),
);
match self.validator.validator.validate(py, schema, &mut state) {
Ok(schema_obj) => Ok(schema_obj.into_bound(py)),
Expand Down
2 changes: 1 addition & 1 deletion src/validators/model_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl Validator for ModelFieldsValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
// this validator does not yet support partial validation, disable it to avoid incorrect results
state.allow_partial = false;
state.allow_partial = false.into();

let strict = state.strict_or(self.strict);
let from_attributes = state.extra().from_attributes.unwrap_or(self.from_attributes);
Expand Down
Loading

0 comments on commit be03c66

Please sign in to comment.