Skip to content

Commit

Permalink
Switch to val_json_bytes config key
Browse files Browse the repository at this point in the history
  • Loading branch information
josh-newman committed Jun 11, 2024
1 parent c1c84e3 commit 6e7fc01
Show file tree
Hide file tree
Showing 13 changed files with 68 additions and 32 deletions.
2 changes: 2 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class CoreConfig(TypedDict, total=False):
ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'.
ser_json_inf_nan: The serialization option for infinity and NaN values
in float fields. Default is 'null'.
val_json_bytes: The validation option for `bytes` values, complementing ser_json_bytes. Default is 'utf8'.
hide_input_in_errors: Whether to hide input data from `ValidationError` representation.
validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError.
Requires exceptiongroup backport pre Python 3.11.
Expand Down Expand Up @@ -107,6 +108,7 @@ class CoreConfig(TypedDict, total=False):
ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601'
ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
ser_json_inf_nan: Literal['null', 'constants', 'strings'] # default: 'null'
val_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
# used to hide input data from ValidationError repr
hide_input_in_errors: bool
validation_error_cause: bool # default: False
Expand Down
4 changes: 2 additions & 2 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use pyo3::{intern, prelude::*};

use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::lookup_key::{LookupKey, LookupPath};
use crate::serializers::config::BytesMode;
use crate::tools::py_err;
use crate::validators::config::ValBytesMode;

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
Expand Down Expand Up @@ -72,7 +72,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {

fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch<EitherString<'_>>;

fn validate_bytes<'a>(&'a self, strict: bool, mode: BytesMode) -> ValMatch<EitherBytes<'a, 'py>>;
fn validate_bytes<'a>(&'a self, strict: bool, mode: ValBytesMode) -> ValMatch<EitherBytes<'a, 'py>>;

fn validate_bool(&self, strict: bool) -> ValMatch<bool>;

Expand Down
6 changes: 3 additions & 3 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use strum::EnumMessage;

use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::lookup_key::{LookupKey, LookupPath};
use crate::serializers::config::BytesMode;
use crate::validators::config::ValBytesMode;
use crate::validators::decimal::create_decimal;

use super::datetime::{
Expand Down Expand Up @@ -110,7 +110,7 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
fn validate_bytes<'a>(
&'a self,
_strict: bool,
mode: BytesMode,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
match self {
JsonValue::Str(s) => match mode.deserialize_string(s) {
Expand Down Expand Up @@ -353,7 +353,7 @@ impl<'py> Input<'py> for str {
fn validate_bytes<'a>(
&'a self,
_strict: bool,
mode: BytesMode,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
match mode.deserialize_string(self) {
Ok(b) => Ok(ValidationMatch::strict(b)),
Expand Down
8 changes: 6 additions & 2 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use pyo3::PyTypeCheck;
use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::serializers::config::BytesMode;
use crate::tools::{extract_i64, safe_repr};
use crate::validators::config::ValBytesMode;
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::Exactness;
use crate::ArgsKwargs;
Expand Down Expand Up @@ -175,7 +175,11 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
Err(ValError::new(ErrorTypeDefaults::StringType, self))
}

fn validate_bytes<'a>(&'a self, strict: bool, mode: BytesMode) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
fn validate_bytes<'a>(
&'a self,
strict: bool,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
if let Ok(py_bytes) = self.downcast_exact::<PyBytes>() {
return Ok(ValidationMatch::exact(py_bytes.into()));
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
Expand Down
4 changes: 2 additions & 2 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;
use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::input::py_string_str;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::serializers::config::BytesMode;
use crate::tools::safe_repr;
use crate::validators::config::ValBytesMode;
use crate::validators::decimal::create_decimal;

use super::datetime::{
Expand Down Expand Up @@ -109,7 +109,7 @@ impl<'py> Input<'py> for StringMapping<'py> {
fn validate_bytes<'a>(
&'a self,
_strict: bool,
mode: BytesMode,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
match self {
Self::String(s) => py_string_str(s).and_then(|b| match mode.deserialize_string(b) {
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use jiter::{map_json_error, PartialMode, PythonParse, StringCacheMode};
use pyo3::exceptions::PyTypeError;
use pyo3::{prelude::*, sync::GILOnceCell};
use serializers::config::BytesMode;
use validators::config::ValBytesMode;

// parse this first to get access to the contained macro
#[macro_use]
Expand Down Expand Up @@ -56,7 +57,7 @@ pub fn from_json<'py>(
allow_partial: bool,
) -> PyResult<Bound<'py, PyAny>> {
let v_match = data
.validate_bytes(false, BytesMode::Utf8)
.validate_bytes(false, ValBytesMode { ser: BytesMode::Utf8 })
.map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?;
let json_either_bytes = v_match.into_inner();
let json_bytes = json_either_bytes.as_slice();
Expand Down
14 changes: 1 addition & 13 deletions src/serializers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ use std::borrow::Cow;
use std::str::{from_utf8, FromStr, Utf8Error};

use base64::Engine;
use pyo3::exceptions::PyValueError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDelta, PyDict, PyString};

use serde::ser::Error;

use crate::build_tools::py_schema_err;
use crate::input::{EitherBytes, EitherTimedelta};
use crate::input::EitherTimedelta;
use crate::tools::SchemaDict;

use super::errors::py_err_se_err;
Expand Down Expand Up @@ -189,17 +188,6 @@ impl BytesMode {
}
}
}

pub fn deserialize_string<'a, 'py>(&self, s: &'a str) -> PyResult<EitherBytes<'a, 'py>> {
match self {
Self::Utf8 => Ok(EitherBytes::Cow(Cow::Borrowed(s.as_bytes()))),
Self::Base64 => match base64::engine::general_purpose::URL_SAFE.decode(s) {
Ok(bytes) => Ok(EitherBytes::from(bytes)),
Err(err) => Err(PyValueError::new_err(format!("Base64 decode error: {err}"))),
},
Self::Hex => Err(PyValueError::new_err("Hex deserialization is not supported")),
}
}
}

pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr {
Expand Down
10 changes: 5 additions & 5 deletions src/validators/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ use crate::build_tools::is_strict;
use crate::errors::{ErrorType, ValError, ValResult};
use crate::input::Input;

use crate::serializers::config::{BytesMode, FromConfig};
use crate::tools::SchemaDict;

use super::config::ValBytesMode;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct BytesValidator {
strict: bool,
bytes_mode: BytesMode,
bytes_mode: ValBytesMode,
}

impl BuildValidator for BytesValidator {
Expand All @@ -33,7 +33,7 @@ impl BuildValidator for BytesValidator {
} else {
Ok(Self {
strict: is_strict(schema, config)?,
bytes_mode: BytesMode::from_config(config)?,
bytes_mode: ValBytesMode::from_config(config)?,
}
.into())
}
Expand Down Expand Up @@ -62,7 +62,7 @@ impl Validator for BytesValidator {
#[derive(Debug, Clone)]
pub struct BytesConstrainedValidator {
strict: bool,
bytes_mode: BytesMode,
bytes_mode: ValBytesMode,
max_length: Option<usize>,
min_length: Option<usize>,
}
Expand Down Expand Up @@ -116,7 +116,7 @@ impl BytesConstrainedValidator {
let py = schema.py();
Ok(Self {
strict: is_strict(schema, config)?,
bytes_mode: BytesMode::from_config(config)?,
bytes_mode: ValBytesMode::from_config(config)?,
min_length: schema.get_as(intern!(py, "min_length"))?,
max_length: schema.get_as(intern!(py, "max_length"))?,
}
Expand Down
38 changes: 38 additions & 0 deletions src/validators/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::borrow::Cow;
use std::str::FromStr;

use base64::Engine;
use pyo3::exceptions::PyValueError;
use pyo3::types::{PyDict, PyString};
use pyo3::{intern, prelude::*};

use crate::input::EitherBytes;
use crate::serializers::config::BytesMode;
use crate::tools::SchemaDict;

#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct ValBytesMode {
pub ser: BytesMode,
}

impl ValBytesMode {
pub fn from_config(config: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<Bound<'_, PyString>>(intern!(config_dict.py(), "val_json_bytes"))?;
let ser_mode = raw_mode.map_or_else(|| Ok(BytesMode::default()), |raw| BytesMode::from_str(&raw.to_cow()?))?;
Ok(Self { ser: ser_mode })
}

pub fn deserialize_string<'a, 'py>(&self, s: &'a str) -> PyResult<EitherBytes<'a, 'py>> {
match self.ser {
BytesMode::Utf8 => Ok(EitherBytes::Cow(Cow::Borrowed(s.as_bytes()))),
BytesMode::Base64 => match base64::engine::general_purpose::URL_SAFE.decode(s) {
Ok(bytes) => Ok(EitherBytes::from(bytes)),
Err(err) => Err(PyValueError::new_err(format!("Base64 decode error: {err}"))),
},
BytesMode::Hex => Err(PyValueError::new_err("Hex deserialization is not supported")),
}
}
}
3 changes: 2 additions & 1 deletion src/validators/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::input::{EitherBytes, Input, InputType, ValidationMatch};
use crate::serializers::config::BytesMode;
use crate::tools::SchemaDict;

use super::config::ValBytesMode;
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug)]
Expand Down Expand Up @@ -88,7 +89,7 @@ impl Validator for JsonValidator {
pub fn validate_json_bytes<'a, 'py>(
input: &'a (impl Input<'py> + ?Sized),
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
match input.validate_bytes(false, BytesMode::Utf8) {
match input.validate_bytes(false, ValBytesMode { ser: BytesMode::Utf8 }) {
Ok(v_match) => Ok(v_match),
Err(ValError::LineErrors(e)) => Err(ValError::LineErrors(
e.into_iter().map(map_bytes_error).collect::<Vec<_>>(),
Expand Down
1 change: 1 addition & 0 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod bytes;
mod call;
mod callable;
mod chain;
pub(crate) mod config;
mod custom_error;
mod dataclass;
mod date;
Expand Down
3 changes: 2 additions & 1 deletion src/validators/uuid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::input::ValidationMatch;
use crate::serializers::config::BytesMode;
use crate::tools::SchemaDict;

use super::config::ValBytesMode;
use super::model::create_class;
use super::model::force_setattr;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator};
Expand Down Expand Up @@ -170,7 +171,7 @@ impl UuidValidator {
}
None => {
let either_bytes = input
.validate_bytes(true, BytesMode::Utf8)
.validate_bytes(true, ValBytesMode { ser: BytesMode::Utf8 })
.map_err(|_| ValError::new(ErrorTypeDefaults::UuidType, input))?
.into_inner();
let bytes_slice = either_bytes.as_slice();
Expand Down
4 changes: 2 additions & 2 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def test_json_bytes_base64_round_trip():
encoded = b'"aGVsbG8="'
assert to_json(data, bytes_mode='base64') == encoded

v = SchemaValidator({'type': 'bytes'}, {'ser_json_bytes': 'base64'})
v = SchemaValidator({'type': 'bytes'}, {'val_json_bytes': 'base64'})
assert v.validate_json(encoded) == data

with pytest.raises(ValueError):
Expand All @@ -392,6 +392,6 @@ def test_json_bytes_base64_round_trip():
assert to_json({'key': data}, bytes_mode='base64') == b'{"key":"aGVsbG8="}'
v = SchemaValidator(
{'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': {'type': 'bytes'}},
{'ser_json_bytes': 'base64'},
{'val_json_bytes': 'base64'},
)
assert v.validate_json('{"key":"aGVsbG8="}') == {'key': data}

0 comments on commit 6e7fc01

Please sign in to comment.