Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate bytes based on ser_json_bytes #1308

Merged
merged 14 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ num-bigint = "0.4.6"
python3-dll-a = "0.2.10"
uuid = "1.9.1"
jiter = { version = "0.5", features = ["python"] }
hex = "0.4.3"

[lib]
name = "_pydantic_core"
Expand Down
8 changes: 4 additions & 4 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def to_json(
exclude_none: bool = False,
round_trip: bool = False,
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
bytes_mode: Literal['utf8', 'base64', 'hex'] = 'utf8',
inf_nan_mode: Literal['null', 'constants', 'strings'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
Expand All @@ -373,7 +373,7 @@ def to_json(
exclude_none: Whether to exclude fields that have a value of `None`.
round_trip: Whether to enable serialization and validation round-trip support.
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'`, `'base64'`, or `'hex'`.
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'`, `'constants'`, or `'strings'`.
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
`"<Unserializable {value_type} object>"` will be used.
Expand Down Expand Up @@ -427,7 +427,7 @@ def to_jsonable_python(
exclude_none: bool = False,
round_trip: bool = False,
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
bytes_mode: Literal['utf8', 'base64', 'hex'] = 'utf8',
inf_nan_mode: Literal['null', 'constants', 'strings'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
Expand All @@ -448,7 +448,7 @@ def to_jsonable_python(
exclude_none: Whether to exclude fields that have a value of `None`.
round_trip: Whether to enable serialization and validation round-trip support.
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'`, `'base64'`, or `'hex'`.
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'`, `'constants'`, or `'strings'`.
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
`"<Unserializable {value_type} object>"` will be used.
Expand Down
3 changes: 3 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class CoreConfig(TypedDict, total=False):
ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'.
ser_json_inf_nan: The serialization option for infinity and NaN values
in float fields. Default is 'null'.
val_json_bytes: The validation option for `bytes` values, complementing ser_json_bytes. Default is 'utf8'.
hide_input_in_errors: Whether to hide input data from `ValidationError` representation.
validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError.
Requires exceptiongroup backport pre Python 3.11.
Expand Down Expand Up @@ -107,6 +108,7 @@ class CoreConfig(TypedDict, total=False):
ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601'
ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
ser_json_inf_nan: Literal['null', 'constants', 'strings'] # default: 'null'
val_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
# used to hide input data from ValidationError repr
hide_input_in_errors: bool
validation_error_cause: bool # default: False
Expand Down Expand Up @@ -3904,6 +3906,7 @@ def definition_reference_schema(
'bytes_type',
'bytes_too_short',
'bytes_too_long',
'bytes_invalid_encoding',
'value_error',
'assertion_error',
'literal_error',
Expand Down
10 changes: 10 additions & 0 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ error_types! {
BytesTooLong {
max_length: {ctx_type: usize, ctx_fn: field_from_context},
},
BytesInvalidEncoding {
encoding: {ctx_type: String, ctx_fn: field_from_context},
encoding_error: {ctx_type: String, ctx_fn: field_from_context},
},
// ---------------------
// python errors from functions
ValueError {
Expand Down Expand Up @@ -515,6 +519,7 @@ impl ErrorType {
Self::BytesType {..} => "Input should be a valid bytes",
Self::BytesTooShort {..} => "Data should have at least {min_length} byte{expected_plural}",
Self::BytesTooLong {..} => "Data should have at most {max_length} byte{expected_plural}",
Self::BytesInvalidEncoding { .. } => "Data should be valid {encoding}: {encoding_error}",
Self::ValueError {..} => "Value error, {error}",
Self::AssertionError {..} => "Assertion failed, {error}",
Self::CustomError {..} => "", // custom errors are handled separately
Expand Down Expand Up @@ -664,6 +669,11 @@ impl ErrorType {
let expected_plural = plural_s(*max_length);
to_string_render!(tmpl, max_length, expected_plural)
}
Self::BytesInvalidEncoding {
encoding,
encoding_error,
..
} => render!(tmpl, encoding, encoding_error),
Self::ValueError { error, .. } => {
let error = &error
.as_ref()
Expand Down
3 changes: 2 additions & 1 deletion src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use pyo3::{intern, prelude::*};
use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::py_err;
use crate::validators::ValBytesMode;

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
Expand Down Expand Up @@ -71,7 +72,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {

fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch<EitherString<'_>>;

fn validate_bytes<'a>(&'a self, strict: bool) -> ValMatch<EitherBytes<'a, 'py>>;
fn validate_bytes<'a>(&'a self, strict: bool, mode: ValBytesMode) -> ValMatch<EitherBytes<'a, 'py>>;

fn validate_bool(&self, strict: bool) -> ValMatch<bool>;

Expand Down
23 changes: 19 additions & 4 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use strum::EnumMessage;
use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::lookup_key::{LookupKey, LookupPath};
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
Expand Down Expand Up @@ -106,9 +107,16 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
}
}

fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
fn validate_bytes<'a>(
&'a self,
_strict: bool,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
match self {
JsonValue::Str(s) => Ok(ValidationMatch::strict(s.as_bytes().into())),
JsonValue::Str(s) => match mode.deserialize_string(s) {
Ok(b) => Ok(ValidationMatch::strict(b)),
Err(e) => Err(ValError::new(e, self)),
},
_ => Err(ValError::new(ErrorTypeDefaults::BytesType, self)),
}
}
Expand Down Expand Up @@ -342,8 +350,15 @@ impl<'py> Input<'py> for str {
Ok(ValidationMatch::strict(self.into()))
}

fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
Ok(ValidationMatch::strict(self.as_bytes().into()))
/// Interpret this string as bytes according to `mode`.
///
/// `_strict` is ignored. On success the decoded bytes are wrapped with
/// `ValidationMatch::strict`; if `mode.deserialize_string` rejects the
/// text, its error type is attached to this input via `ValError::new`.
fn validate_bytes<'a>(
    &'a self,
    _strict: bool,
    mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
    mode.deserialize_string(self)
        .map(ValidationMatch::strict)
        .map_err(|err| ValError::new(err, self))
}

fn validate_bool(&self, _strict: bool) -> ValResult<ValidationMatch<bool>> {
Expand Down
12 changes: 10 additions & 2 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError,
use crate::tools::{extract_i64, safe_repr};
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::Exactness;
use crate::validators::ValBytesMode;
use crate::ArgsKwargs;

use super::datetime::{
Expand Down Expand Up @@ -174,7 +175,11 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
Err(ValError::new(ErrorTypeDefaults::StringType, self))
}

fn validate_bytes<'a>(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
fn validate_bytes<'a>(
&'a self,
strict: bool,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
if let Ok(py_bytes) = self.downcast_exact::<PyBytes>() {
return Ok(ValidationMatch::exact(py_bytes.into()));
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
Expand All @@ -185,7 +190,10 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
if !strict {
return if let Ok(py_str) = self.downcast::<PyString>() {
let str = py_string_str(py_str)?;
Ok(str.as_bytes().into())
match mode.deserialize_string(str) {
Ok(b) => Ok(b),
Err(e) => Err(ValError::new(e, self)),
}
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
Ok(py_byte_array.to_vec().into())
} else {
Expand Down
12 changes: 10 additions & 2 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::input::py_string_str;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::safe_repr;
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
Expand Down Expand Up @@ -105,9 +106,16 @@ impl<'py> Input<'py> for StringMapping<'py> {
}
}

fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
fn validate_bytes<'a>(
&'a self,
_strict: bool,
mode: ValBytesMode,
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
match self {
Self::String(s) => py_string_str(s).map(|b| ValidationMatch::strict(b.as_bytes().into())),
Self::String(s) => py_string_str(s).and_then(|b| match mode.deserialize_string(b) {
Ok(b) => Ok(ValidationMatch::strict(b)),
Err(e) => Err(ValError::new(e, self)),
}),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)),
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use std::sync::OnceLock;
use jiter::{map_json_error, PartialMode, PythonParse, StringCacheMode};
use pyo3::exceptions::PyTypeError;
use pyo3::{prelude::*, sync::GILOnceCell};
use serializers::BytesMode;
use validators::ValBytesMode;

// parse this first to get access to the contained macro
#[macro_use]
Expand Down Expand Up @@ -55,7 +57,7 @@ pub fn from_json<'py>(
allow_partial: bool,
) -> PyResult<Bound<'py, PyAny>> {
let v_match = data
.validate_bytes(false)
.validate_bytes(false, ValBytesMode { ser: BytesMode::Utf8 })
.map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?;
let json_either_bytes = v_match.into_inner();
let json_bytes = json_either_bytes.as_slice();
Expand Down
6 changes: 2 additions & 4 deletions src/serializers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub trait FromConfig {
macro_rules! serialization_mode {
($name:ident, $config_key:expr, $($variant:ident => $value:expr),* $(,)?) => {
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum $name {
pub enum $name {
#[default]
$($variant,)*
}
Expand Down Expand Up @@ -183,9 +183,7 @@ impl BytesMode {
Err(e) => Err(Error::custom(e.to_string())),
},
Self::Base64 => serializer.serialize_str(&base64::engine::general_purpose::URL_SAFE.encode(bytes)),
Self::Hex => {
serializer.serialize_str(&bytes.iter().fold(String::new(), |acc, b| acc + &format!("{b:02x}")))
}
Self::Hex => serializer.serialize_str(hex::encode(bytes).as_str()),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use pyo3::{PyTraverseError, PyVisit};
use crate::definitions::{Definitions, DefinitionsBuilder};
use crate::py_gc::PyGcTraverse;

pub(crate) use config::BytesMode;
use config::SerializationConfig;
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
use extra::{CollectWarnings, SerRecursionState, WarningsMode};
Expand Down
11 changes: 9 additions & 2 deletions src/validators/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ use crate::input::Input;

use crate::tools::SchemaDict;

use super::config::ValBytesMode;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct BytesValidator {
strict: bool,
bytes_mode: ValBytesMode,
}

impl BuildValidator for BytesValidator {
Expand All @@ -31,6 +33,7 @@ impl BuildValidator for BytesValidator {
} else {
Ok(Self {
strict: is_strict(schema, config)?,
bytes_mode: ValBytesMode::from_config(config)?,
}
.into())
}
Expand All @@ -47,7 +50,7 @@ impl Validator for BytesValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
input
.validate_bytes(state.strict_or(self.strict))
.validate_bytes(state.strict_or(self.strict), self.bytes_mode)
.map(|m| m.unpack(state).into_py(py))
}

Expand All @@ -59,6 +62,7 @@ impl Validator for BytesValidator {
#[derive(Debug, Clone)]
pub struct BytesConstrainedValidator {
strict: bool,
bytes_mode: ValBytesMode,
max_length: Option<usize>,
min_length: Option<usize>,
}
Expand All @@ -72,7 +76,9 @@ impl Validator for BytesConstrainedValidator {
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
let either_bytes = input.validate_bytes(state.strict_or(self.strict))?.unpack(state);
let either_bytes = input
.validate_bytes(state.strict_or(self.strict), self.bytes_mode)?
.unpack(state);
let len = either_bytes.len()?;

if let Some(min_length) = self.min_length {
Expand Down Expand Up @@ -110,6 +116,7 @@ impl BytesConstrainedValidator {
let py = schema.py();
Ok(Self {
strict: is_strict(schema, config)?,
bytes_mode: ValBytesMode::from_config(config)?,
min_length: schema.get_as(intern!(py, "min_length"))?,
max_length: schema.get_as(intern!(py, "max_length"))?,
}
Expand Down
49 changes: 49 additions & 0 deletions src/validators/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use std::borrow::Cow;
use std::str::FromStr;

use base64::Engine;
use pyo3::types::{PyDict, PyString};
use pyo3::{intern, prelude::*};

use crate::errors::ErrorType;
use crate::input::EitherBytes;
use crate::serializers::BytesMode;
use crate::tools::SchemaDict;

/// Mode used when validating `bytes` from string input.
///
/// Thin wrapper around the serializer's [`BytesMode`]; `from_config` fills it
/// from the `val_json_bytes` config key, and the derived `Default` yields
/// `BytesMode::default()`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct ValBytesMode {
/// The serializer-side mode whose string form `deserialize_string` decodes.
pub ser: BytesMode,
}
davidhewitt marked this conversation as resolved.
Show resolved Hide resolved

impl ValBytesMode {
    /// Build the mode from an optional config dict.
    ///
    /// Returns `Self::default()` when `config` is `None`. Otherwise the
    /// `val_json_bytes` key is looked up as a Python string: a missing key
    /// falls back to `BytesMode::default()`, a present one is parsed with
    /// `BytesMode::from_str`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the dict lookup, from converting the Python
    /// string to Rust text, or from parsing the mode name.
    pub fn from_config(config: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
        let Some(config_dict) = config else {
            return Ok(Self::default());
        };
        let py = config_dict.py();
        let ser = match config_dict.get_as::<Bound<'_, PyString>>(intern!(py, "val_json_bytes"))? {
            Some(raw) => BytesMode::from_str(&raw.to_cow()?)?,
            None => BytesMode::default(),
        };
        Ok(Self { ser })
    }

    /// Decode `s` into bytes according to the configured mode.
    ///
    /// * `Utf8` borrows the string's own bytes without copying.
    /// * `Base64` decodes with the `URL_SAFE` base64 engine.
    /// * `Hex` decodes with `hex::decode`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorType::BytesInvalidEncoding` carrying the encoding name
    /// (`"base64"` or `"hex"`) and the decoder's error message when `s` is
    /// not valid for the configured encoding.
    pub fn deserialize_string<'py>(self, s: &str) -> Result<EitherBytes<'_, 'py>, ErrorType> {
        // Both failing branches report the same error shape; only the
        // encoding label and the decoder's message differ.
        let invalid = |encoding: &str, encoding_error: String| ErrorType::BytesInvalidEncoding {
            encoding: encoding.to_string(),
            encoding_error,
            context: None,
        };

        match self.ser {
            BytesMode::Utf8 => Ok(EitherBytes::Cow(Cow::Borrowed(s.as_bytes()))),
            BytesMode::Base64 => base64::engine::general_purpose::URL_SAFE
                .decode(s)
                .map(EitherBytes::from)
                .map_err(|err| invalid("base64", err.to_string())),
            BytesMode::Hex => hex::decode(s)
                .map(EitherBytes::from)
                .map_err(|err| invalid("hex", err.to_string())),
        }
    }
}
Loading
Loading