From d5811d4328992353bc473b60cbc3a0576a8cfd7b Mon Sep 17 00:00:00 2001 From: Bas Schoenmaeckers Date: Tue, 10 Sep 2024 17:12:15 +0200 Subject: [PATCH] Do not require padding when decoding base64 bytes --- src/validators/config.rs | 18 ++++++++++++++---- tests/test_json.py | 6 ++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/validators/config.rs b/src/validators/config.rs index a43a7643a..a14104628 100644 --- a/src/validators/config.rs +++ b/src/validators/config.rs @@ -1,8 +1,9 @@ use std::borrow::Cow; use std::str::FromStr; -use base64::engine::general_purpose::{STANDARD, URL_SAFE}; -use base64::{DecodeError, Engine}; +use base64::engine::general_purpose::GeneralPurpose; +use base64::engine::{DecodePaddingMode, GeneralPurposeConfig}; +use base64::{alphabet, DecodeError, Engine}; use pyo3::types::{PyDict, PyString}; use pyo3::{intern, prelude::*}; @@ -11,6 +12,15 @@ use crate::input::EitherBytes; use crate::serializers::BytesMode; use crate::tools::SchemaDict; +const URL_SAFE_OPTIONAL_PADDING: GeneralPurpose = GeneralPurpose::new( + &alphabet::URL_SAFE, + GeneralPurposeConfig::new().with_decode_padding_mode(DecodePaddingMode::Indifferent), +); +const STANDARD_OPTIONAL_PADDING: GeneralPurpose = GeneralPurpose::new( + &alphabet::STANDARD, + GeneralPurposeConfig::new().with_decode_padding_mode(DecodePaddingMode::Indifferent), +); + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] pub struct ValBytesMode { pub ser: BytesMode, @@ -29,10 +39,10 @@ impl ValBytesMode { pub fn deserialize_string<'py>(self, s: &str) -> Result, ErrorType> { match self.ser { BytesMode::Utf8 => Ok(EitherBytes::Cow(Cow::Borrowed(s.as_bytes()))), - BytesMode::Base64 => URL_SAFE + BytesMode::Base64 => URL_SAFE_OPTIONAL_PADDING .decode(s) .or_else(|err| match err { - DecodeError::InvalidByte(_, b'/' | b'+') => STANDARD.decode(s), + DecodeError::InvalidByte(_, b'/' | b'+') => STANDARD_OPTIONAL_PADDING.decode(s), _ => Err(err), }) .map(EitherBytes::from) diff --git a/tests/test_json.py b/tests/test_json.py index f22850704..49175410d 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -401,6 +401,12 @@ def test_json_bytes_base64_round_trip(): assert v.validate_json(b'{"key":' + encoded_url + b'}') == {'key': data} +def test_json_bytes_base64_no_padding(): + v = SchemaValidator({'type': 'bytes'}, {'val_json_bytes': 'base64'}) + base_64_without_padding = "bm8tcGFkZGluZw" + assert v.validate_json(json.dumps(base_64_without_padding)) == b"no-padding" + + def test_json_bytes_base64_invalid(): v = SchemaValidator({'type': 'bytes'}, {'val_json_bytes': 'base64'}) wrong_input = 'wrong!'